/*
 * Decompiled with CFR 0.152.
 */
package edu.brandeis.glycodenovo.core;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Locale;
import java.util.Scanner;

/**
 * Ensemble of binary classification trees loaded from a text specification file.
 *
 * <p>The file {@code ionclassifier_<pos>_v_<neg>.txt} is read from a root directory. For each tree it
 * holds a {@code NumNodes} line followed by {@code Children}, {@code ClassProb}, {@code CutPoint} and
 * {@code CutVar} sections; a single {@code TrainedWeights} line gives one weight per tree. A feature
 * vector is scored by routing it to a leaf in each tree and averaging the weighted, thresholded
 * leaf probabilities (see {@link #getScore(ArrayList)}).
 */
public class MyClassifier {
    /** Per-tree weights in file order; entry {@code i} weights {@code cForest.get(i)}. */
    ArrayList<Double> trainedWeights;
    /** Trees in the order their {@code NumNodes} lines appear in the file. */
    ArrayList<ClassificationTree> cForest;

    /**
     * Loads the classifier {@code ionclassifier_<pos>_v_<neg>.txt} from {@code rootpath}.
     *
     * @param rootpath directory holding the spec file; a trailing separator is optional, and an empty
     *     string means the current working directory
     * @param pos character naming the first class in the file name
     * @param neg character naming the second class in the file name
     * @throws IllegalArgumentException if the file does not exist (the {@link FileNotFoundException} is
     *     attached as the cause) or if a section is truncated or appears before any {@code NumNodes}
     */
    MyClassifier(String rootpath, char pos, char neg) {
        this.trainedWeights = new ArrayList<>();
        this.cForest = new ArrayList<>();
        String fileName = "ionclassifier_" + pos + "_v_" + neg + ".txt";
        // File(parent, child) inserts the platform separator only when needed; hard-coding '\\'
        // broke lookups on non-Windows systems.
        File specFile = rootpath.isEmpty() ? new File(fileName) : new File(rootpath, fileName);
        try (Scanner sc = new Scanner(specFile)) {
            this.parse(sc);
        } catch (FileNotFoundException e) {
            throw new IllegalArgumentException(fileName + " doesn't exist", e);
        }
    }

    /** Reads every section of the spec file into {@link #cForest} and {@link #trainedWeights}. */
    private void parse(Scanner sc) {
        ClassificationTree currTree = null;
        while (sc.hasNextLine()) {
            Scanner lineSc = tokens(sc.nextLine().trim());
            if (!lineSc.hasNext()) {
                continue; // blank line
            }
            String keyword = lineSc.next();
            if (keyword.equals("Model")) {
                // The (1-based) model id must be an integer, but trees are stored in file order,
                // so the id itself is not needed to locate the tree.
                lineSc.nextInt();
            } else if (keyword.equals("NumNodes")) {
                // Later sections fill the tree just created, independent of any Model id.
                currTree = new ClassificationTree(lineSc.nextInt());
                this.cForest.add(currTree);
            } else if (keyword.equals("Children")) {
                readChildren(sc, requireTree(currTree, keyword));
            } else if (keyword.equals("ClassProb")) {
                readClassProbs(sc, requireTree(currTree, keyword));
            } else if (keyword.equals("CutPoint")) {
                this.readCutPoints(sc, requireTree(currTree, keyword));
            } else if (keyword.equals("CutVar")) {
                readCutVars(sc, requireTree(currTree, keyword));
            } else if (keyword.equals("TrainedWeights")) {
                this.readWeights(sc);
            }
            // Any other leading token is ignored.
        }
    }

    /** Tokenizes a line with locale-independent number syntax ('.' as the decimal point). */
    private static Scanner tokens(String line) {
        return new Scanner(line).useLocale(Locale.ROOT);
    }

    /** Returns the next trimmed line, failing clearly if the file ends inside {@code section}. */
    private static String nextSectionLine(Scanner sc, String section) {
        if (!sc.hasNextLine()) {
            throw new IllegalArgumentException("classifier file ended while reading " + section);
        }
        return sc.nextLine().trim();
    }

    /** Fails clearly when a per-tree section appears before any NumNodes line. */
    private static ClassificationTree requireTree(ClassificationTree tree, String section) {
        if (tree == null) {
            throw new IllegalArgumentException(section + " section appears before any NumNodes line");
        }
        return tree;
    }

    /** Reads one "left right" pair of 1-based child indices per node until EndChildren. */
    private static void readChildren(Scanner sc, ClassificationTree tree) {
        int node = 0;
        String line;
        while (!(line = nextSectionLine(sc, "Children")).equals("EndChildren")) {
            Scanner lineSc = tokens(line);
            int left = lineSc.nextInt();
            int right = lineSc.nextInt();
            tree.setChildren(node++, left, right);
        }
    }

    /** Reads one "lprob rprob" pair per node until EndClassProb. */
    private static void readClassProbs(Scanner sc, ClassificationTree tree) {
        int node = 0;
        String line;
        while (!(line = nextSectionLine(sc, "ClassProb")).equals("EndClassProb")) {
            Scanner lineSc = tokens(line);
            double left = lineSc.nextDouble();
            double right = lineSc.nextDouble();
            tree.setProb(node++, left, right);
        }
    }

    /** Reads one line of per-node split thresholds; unparseable tokens become -1.0. */
    private void readCutPoints(Scanner sc, ClassificationTree tree) {
        Scanner lineSc = tokens(nextSectionLine(sc, "CutPoint"));
        int node = 0;
        while (lineSc.hasNext()) {
            tree.getNode(node++).cutpoint = this.getDouble(lineSc.next());
        }
    }

    /** Reads one line of 1-based per-node feature indices and stores them 0-based. */
    private static void readCutVars(Scanner sc, ClassificationTree tree) {
        Scanner lineSc = tokens(nextSectionLine(sc, "CutVar"));
        int node = 0;
        while (lineSc.hasNext()) {
            tree.getNode(node++).cutvar = lineSc.nextInt() - 1;
        }
    }

    /** Reads the single line of per-tree weights, appending them to {@link #trainedWeights}. */
    private void readWeights(Scanner sc) {
        Scanner lineSc = tokens(nextSectionLine(sc, "TrainedWeights"));
        while (lineSc.hasNext()) {
            this.trainedWeights.add(lineSc.nextDouble());
        }
    }

    /**
     * Prints every tree and the trained weights to standard output for debugging.
     *
     * <p>Child links are printed as 0-based node indices, or -1 for a missing child.
     */
    void printClassifier() {
        for (int i = 0; i < this.cForest.size(); i++) {
            ClassificationTree currTree = this.cForest.get(i);
            System.out.println("\nModel " + i);
            System.out.println("\tNumNodes " + currTree.numNodes);
            System.out.println("\tChildren");
            for (CNode currNode : currTree.cTree) {
                System.out.println("\t\t" + currTree.toInt(currNode.left) + " " + currTree.toInt(currNode.right));
            }
            System.out.println("\tClassProb");
            for (CNode currNode : currTree.cTree) {
                System.out.println("\t\t" + currNode.lprob + " " + currNode.rprob);
            }
            System.out.print("\tCutPoint\n\t\t");
            for (CNode currNode : currTree.cTree) {
                System.out.print(currNode.cutpoint + " ");
            }
            System.out.print("\n\tCutVar\n\t\t");
            for (CNode currNode : currTree.cTree) {
                System.out.print(currNode.cutvar + " ");
            }
        }
        System.out.println("\n\nTrainedWeights");
        for (Double weight : this.trainedWeights) {
            System.out.print(weight + " ");
        }
    }

    /** Parses {@code s} as a double, returning -1.0 when it is not a valid number. */
    private double getDouble(String s) {
        try {
            return Double.parseDouble(s);
        } catch (NumberFormatException e) {
            return -1.0;
        }
    }

    /**
     * Scores a feature vector against the whole forest.
     *
     * <p>For each weight {@code i}, tree {@code i} contributes its {+1/-1} leaf votes scaled by that
     * weight; the sums are then divided by the number of weights. If no weights were loaded, both
     * components are NaN (0/0).
     *
     * @param massFeatures feature values, indexed by each node's 0-based {@code cutvar}
     * @return two-element array {left-class score, right-class score}
     */
    double[] getScore(ArrayList<Double> massFeatures) {
        double[] sum = new double[2];
        int count = this.trainedWeights.size();
        for (int i = 0; i < count; i++) {
            double weight = this.trainedWeights.get(i);
            double[] score = this.cForest.get(i).getScore(massFeatures);
            sum[0] += score[0] * weight;
            sum[1] += score[1] * weight;
        }
        sum[0] /= count;
        sum[1] /= count;
        return sum;
    }

    /** Returns the tree at position {@code index} in file order. */
    ClassificationTree getTree(int index) {
        return this.cForest.get(index);
    }

    /** One node of a classification tree; static because it needs no enclosing classifier. */
    private static class CNode {
        CNode left;
        CNode right;
        CNode parent;
        double lprob = 0.0;
        double rprob = 0.0;
        /** Split threshold; -1.0 until set (or when the file token was not a number). */
        double cutpoint = -1.0;
        /** 0-based feature index used for the split; -1 until set. */
        int cutvar = -1;
    }

    /** A single binary tree whose nodes are addressed by 0-based index; node 0 is the root. */
    class ClassificationTree {
        int numNodes;
        ArrayList<CNode> cTree;

        /** Creates a tree of {@code num} unlinked, default-valued nodes. */
        ClassificationTree(int num) {
            this.numNodes = num;
            this.cTree = new ArrayList<>(Math.max(num, 0));
            for (int i = 0; i < num; i++) {
                this.cTree.add(new CNode());
            }
        }

        /** Returns the root node's split threshold. */
        double rootCutPoint() {
            return this.cTree.get(0).cutpoint;
        }

        /** Returns the root node's 0-based split feature index. */
        int rootCutVar() {
            return this.cTree.get(0).cutvar;
        }

        /** Returns the node at 0-based index {@code nodeID}. */
        CNode getNode(int nodeID) {
            return this.cTree.get(nodeID);
        }

        /**
         * Links node {@code nodeID} to its children.
         *
         * @param left 1-based index of the left child, or a value {@code <= 0} for none
         * @param right 1-based index of the right child, or a value {@code <= 0} for none
         */
        void setChildren(int nodeID, int left, int right) {
            CNode node = this.getNode(nodeID);
            if (left > 0) {
                CNode lNode = this.getNode(left - 1);
                node.left = lNode;
                lNode.parent = node;
            }
            if (right > 0) {
                CNode rNode = this.getNode(right - 1);
                node.right = rNode;
                rNode.parent = node;
            }
        }

        /** Stores the two class probabilities of node {@code nodeID}. */
        void setProb(int nodeID, double left, double right) {
            CNode node = this.getNode(nodeID);
            node.lprob = left;
            node.rprob = right;
        }

        /** Returns the 0-based index of {@code node} in this tree, or -1 for {@code null}. */
        int toInt(CNode node) {
            if (node == null) {
                return -1;
            }
            return this.cTree.indexOf(node);
        }

        /**
         * Routes a feature vector from the root to a leaf.
         *
         * <p>A node with fewer than two children is treated as a leaf. At an internal node, values
         * strictly below the cut point go left and all others go right. Assumes every internal node
         * has a valid {@code cutvar}; a -1 index would throw.
         */
        CNode swim(ArrayList<Double> massFeatures) {
            CNode currNode = this.cTree.get(0);
            while (currNode.left != null && currNode.right != null) {
                double value = massFeatures.get(currNode.cutvar);
                currNode = value < currNode.cutpoint ? currNode.left : currNode.right;
            }
            return currNode;
        }

        /** Returns {+1 or -1, +1 or -1}: whether each leaf class probability is at least 0.5. */
        double[] getScore(ArrayList<Double> massFeatures) {
            CNode leaf = this.swim(massFeatures);
            return new double[]{leaf.lprob >= 0.5 ? 1.0 : -1.0, leaf.rprob >= 0.5 ? 1.0 : -1.0};
        }

        /** Writes each node's split and class probabilities to {@code filewriter}, two lines per node. */
        void printTree(FileWriter filewriter) throws IOException {
            for (CNode currNode : this.cTree) {
                filewriter.write("cutvar: " + currNode.cutvar + " cutpoint: " + currNode.cutpoint + "\n");
                filewriter.write("lprob: " + currNode.lprob + " rprob: " + currNode.rprob + "\n");
            }
        }
    }
}

