/*
 * Decompiled with CFR 0.152.
 */
package org.extratrees;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import org.extratrees.AbstractTrees;
import org.extratrees.BinaryTree;
import org.extratrees.Matrix;
import org.extratrees.ShuffledIterator;

/**
 * Regression forest whose trees split on randomly drawn thresholds.
 *
 * <p>At every node the input columns are visited in the order produced by a
 * {@link ShuffledIterator}. For each column that is not constant on the node's rows,
 * {@link #getNumRandomCuts()} thresholds are drawn with {@link Math#random()} between the
 * column's minimum and maximum, and the candidate with the highest relative variance
 * reduction of the output becomes the node's split.</p>
 */
public class ExtraTrees extends AbstractTrees<BinaryTree> {
    /** Training inputs; row {@code i} belongs to {@code output[i]}. */
    Matrix input;
    /** Training targets, one per row of {@link #input}. */
    double[] output;
    /** Element-wise squares of {@link #output}, cached for the variance computations. */
    double[] outputSq;
    /**
     * Numerical tolerance: a column whose value range is below it is treated as constant,
     * and a variance below its square is treated as zero.
     */
    static double zero = 1.0E-6;
    /** Indices {@code 0 .. ncols-1} of the input columns. */
    ArrayList<Integer> cols;
    /** Number of random thresholds tried per column at each node. */
    int numRandomCuts = 1;
    /** Whether thresholds are drawn one per equal-width sub-interval instead of from the full range. */
    boolean evenCuts = false;

    /**
     * Stores the training data and precomputes the squared outputs and the column index list.
     *
     * @param matrix training inputs
     * @param outputs training targets, one per row of {@code matrix}
     * @throws IllegalArgumentException if the number of rows and the number of targets differ
     */
    public ExtraTrees(Matrix matrix, double[] outputs) {
        if (matrix.nrows != outputs.length) {
            throw new IllegalArgumentException("Input and output do not have same length.");
        }
        this.input = matrix;
        this.output = outputs;
        this.outputSq = new double[outputs.length];
        for (int row = 0; row < outputs.length; ++row) {
            this.outputSq[row] = outputs[row] * outputs[row];
        }
        this.cols = new ArrayList<Integer>(matrix.ncols);
        for (int col = 0; col < matrix.ncols; ++col) {
            this.cols.add(col);
        }
    }

    /** @return whether thresholds are spread over equal-width sub-intervals of the column range */
    public boolean isEvenCuts() {
        return this.evenCuts;
    }

    /**
     * @param evenCuts {@code true} to draw the i-th threshold from the i-th of
     *        {@code numRandomCuts} equal-width sub-intervals, {@code false} to draw every
     *        threshold from the whole range
     */
    public void setEvenCuts(boolean evenCuts) {
        this.evenCuts = evenCuts;
    }

    /** @return number of random thresholds tried per column at each node */
    public int getNumRandomCuts() {
        return this.numRandomCuts;
    }

    /**
     * Sets the number of random thresholds tried per column at each node.
     *
     * @param numRandomCuts number of thresholds, at least 1
     * @throws IllegalArgumentException if the value is below 1; with no cut to evaluate no node
     *         could ever be split
     */
    public void setNumRandomCuts(int numRandomCuts) {
        if (numRandomCuts < 1) {
            throw new IllegalArgumentException(
                    "Number of random cuts must be at least 1, got " + numRandomCuts + ".");
        }
        this.numRandomCuts = numRandomCuts;
    }

    /**
     * Builds several trees on the same subset of training rows.
     *
     * @param nmin nodes with fewer rows than this become leaves
     * @param numTry maximum number of non-constant columns examined per node
     * @param numTrees number of trees to build
     * @param rowIds indices of the training rows to use
     * @return the trees, in build order
     */
    public ArrayList<BinaryTree> buildTrees(int nmin, int numTry, int numTrees, int[] rowIds) {
        ArrayList<BinaryTree> forest = new ArrayList<BinaryTree>(numTrees);
        // One iterator is shared by all trees; buildTree resets it at every node.
        ShuffledIterator<Integer> columnOrder = new ShuffledIterator<Integer>(this.cols);
        for (int t = 0; t < numTrees; ++t) {
            forest.add(this.buildTree(nmin, numTry, rowIds, columnOrder));
        }
        return forest;
    }

    /**
     * Averages the predictions of all trees for one input row.
     *
     * @param trees trees to evaluate
     * @param row input values, one per column
     * @return mean of the trees' values (NaN for an empty list, from 0/0)
     */
    public static double getValue(ArrayList<BinaryTree> trees, double[] row) {
        double sum = 0.0;
        for (BinaryTree tree : trees) {
            sum += tree.getValue(row);
        }
        return sum / (double) trees.size();
    }

    /**
     * Averages {@link BinaryTree#getValue(double[], int)} over all trees for one input row.
     *
     * @param trees trees to evaluate
     * @param row input values, one per column
     * @param nmin forwarded unchanged to each tree; {@link #buildTreeCV(int, int)} passes its
     *        candidate minimum node sizes here
     * @return mean of the trees' values (NaN for an empty list, from 0/0)
     */
    public static double getValue(ArrayList<BinaryTree> trees, double[] row, int nmin) {
        double sum = 0.0;
        for (BinaryTree tree : trees) {
            sum += tree.getValue(row, nmin);
        }
        return sum / (double) trees.size();
    }

    /**
     * Predicts every row of a matrix with this object's trees.
     *
     * @param matrix inputs to predict
     * @return one prediction per row
     */
    public double[] getValues(Matrix matrix) {
        return ExtraTrees.getValues(this.trees, matrix);
    }

    /**
     * Predicts every row of a matrix with the given trees.
     *
     * @param trees trees to evaluate
     * @param matrix inputs to predict
     * @return one prediction per row
     */
    public static double[] getValues(ArrayList<BinaryTree> trees, Matrix matrix) {
        double[] predictions = new double[matrix.nrows];
        double[] rowBuffer = new double[matrix.ncols];
        for (int row = 0; row < matrix.nrows; ++row) {
            ExtraTrees.fillRowBuffer(matrix, row, rowBuffer);
            predictions[row] = ExtraTrees.getValue(trees, rowBuffer);
        }
        return predictions;
    }

    /**
     * Builds one tree on all training rows.
     *
     * @param nmin nodes with fewer rows than this become leaves
     * @param numTry maximum number of non-constant columns examined per node
     * @return root of the new tree
     */
    @Override
    public BinaryTree buildTree(int nmin, int numTry) {
        int[] allRows = new int[this.output.length];
        for (int i = 0; i < allRows.length; ++i) {
            allRows[i] = i;
        }
        ShuffledIterator<Integer> columnOrder = new ShuffledIterator<Integer>(this.cols);
        return this.buildTree(nmin, numTry, allRows, columnOrder);
    }

    /**
     * Recursively builds the (sub)tree for the given training rows.
     *
     * <p>The node becomes a leaf when it has fewer than {@code nmin} rows, when the output
     * variance of its rows is below {@code zero * zero}, or when no candidate split with a
     * comparable score was found (for example because every column is constant on these
     * rows).</p>
     *
     * @param nmin nodes with fewer rows than this become leaves
     * @param numTry maximum number of non-constant columns examined per node
     * @param rowIds indices of the training rows that reach this node
     * @param columnOrder column iterator; it is reset here, so it can be shared across nodes
     * @return the node, with its value set to the mean output of {@code rowIds}
     */
    public BinaryTree buildTree(int nmin, int numTry, int[] rowIds, ShuffledIterator<Integer> columnOrder) {
        if (rowIds.length < nmin) {
            return this.makeLeaf(rowIds);
        }
        columnOrder.reset();
        final double rowCount = (double) rowIds.length;
        int columnsTried = 0;
        int bestColumn = -1;
        double bestScore = Double.NEGATIVE_INFINITY;
        double bestThreshold = Double.NaN;
        boolean leftIsConstant = false;
        boolean rightIsConstant = false;
        int bestLeftCount = 0;
        int bestRightCount = 0;
        while (columnOrder.hasNext()) {
            int column = columnOrder.next();
            double min = Double.POSITIVE_INFINITY;
            double max = Double.NEGATIVE_INFINITY;
            for (int rowId : rowIds) {
                double x = this.input.get(rowId, column);
                if (x < min) {
                    min = x;
                }
                if (x > max) {
                    max = x;
                }
            }
            double range = max - min;
            if (range < zero) {
                // Constant column: cannot be split and does not count towards numTry.
                continue;
            }
            for (int cut = 0; cut < this.numRandomCuts; ++cut) {
                double threshold = this.drawCutThreshold(min, range, cut);
                int leftCount = 0;
                int rightCount = 0;
                double leftSum = 0.0;
                double rightSum = 0.0;
                double leftSumSq = 0.0;
                double rightSumSq = 0.0;
                for (int rowId : rowIds) {
                    if (this.input.get(rowId, column) < threshold) {
                        ++leftCount;
                        leftSum += this.output[rowId];
                        leftSumSq += this.outputSq[rowId];
                    } else {
                        ++rightCount;
                        rightSum += this.output[rowId];
                        rightSumSq += this.outputSq[rowId];
                    }
                }
                double totalVar = (leftSumSq + rightSumSq) / rowCount
                        - Math.pow((leftSum + rightSum) / rowCount, 2.0);
                if (totalVar < zero * zero) {
                    // All outputs are (numerically) equal: nothing to gain from splitting.
                    return this.makeLeaf(rowIds);
                }
                // Variance as E[y^2] - E[y]^2; an empty side yields NaN, which the
                // comparison below rejects.
                double leftMean = leftSum / (double) leftCount;
                double rightMean = rightSum / (double) rightCount;
                double leftVar = leftSumSq / (double) leftCount - leftMean * leftMean;
                double rightVar = rightSumSq / (double) rightCount - rightMean * rightMean;
                // Fraction of the node's variance removed by the split.
                double score = 1.0
                        - ((double) leftCount * leftVar + (double) rightCount * rightVar) / rowCount / totalVar;
                if (score > bestScore) {
                    bestScore = score;
                    bestColumn = column;
                    bestThreshold = threshold;
                    leftIsConstant = leftVar < zero * zero;
                    rightIsConstant = rightVar < zero * zero;
                    bestLeftCount = leftCount;
                    bestRightCount = rightCount;
                }
            }
            if (++columnsTried >= numTry) {
                break;
            }
        }
        if (bestColumn < 0) {
            return this.makeLeaf(rowIds);
        }

        // Partition the rows with exactly the comparison used while scoring, so the
        // recorded counts match the array sizes.
        int[] leftRows = new int[bestLeftCount];
        int[] rightRows = new int[bestRightCount];
        int leftPos = 0;
        int rightPos = 0;
        for (int rowId : rowIds) {
            if (this.input.get(rowId, bestColumn) < bestThreshold) {
                leftRows[leftPos++] = rowId;
            } else {
                rightRows[rightPos++] = rowId;
            }
        }

        BinaryTree node = new BinaryTree();
        node.column = bestColumn;
        node.threshold = bestThreshold;
        node.nSuccessors = rowIds.length;
        // A side with constant output needs no further splitting.
        node.left = leftIsConstant ? this.makeLeaf(leftRows) : this.buildTree(nmin, numTry, leftRows, columnOrder);
        node.right = rightIsConstant ? this.makeLeaf(rightRows) : this.buildTree(nmin, numTry, rightRows, columnOrder);
        // Size-weighted mean of the children equals the mean output of this node's rows.
        node.value = node.left.value * (double) node.left.nSuccessors
                + node.right.value * (double) node.right.nSuccessors;
        node.value /= (double) node.nSuccessors;
        return node;
    }

    /**
     * Draws one candidate threshold inside {@code [min, min + range)}.
     *
     * @param min smallest column value on the node's rows
     * @param range difference between the largest and smallest column value
     * @param cut index of the cut, {@code 0 .. numRandomCuts-1}; only used with even cuts
     * @return the threshold
     */
    private double drawCutThreshold(double min, double range, int cut) {
        if (!this.evenCuts) {
            return Math.random() * range + min;
        }
        double lower = min + (double) cut * range / (double) this.numRandomCuts;
        double upper = min + (double) (cut + 1) * range / (double) this.numRandomCuts;
        return Math.random() * (upper - lower) + lower;
    }

    /**
     * Creates a leaf whose value is the mean output of the given rows.
     *
     * @param rowIds indices of the training rows in the leaf
     * @return the leaf (its value is NaN when {@code rowIds} is empty, from 0/0)
     */
    public BinaryTree makeLeaf(int[] rowIds) {
        BinaryTree leaf = new BinaryTree();
        leaf.value = 0.0;
        leaf.nSuccessors = rowIds.length;
        for (int rowId : rowIds) {
            leaf.value += this.output[rowId];
        }
        leaf.value /= (double) rowIds.length;
        return leaf;
    }

    /**
     * Generates a random regression problem for demos.
     *
     * <p>Inputs are uniform random numbers, except column 2 which is the constant 0.5. The
     * output is {@code column1 + 0.2 * column3}.</p>
     *
     * @param numRows number of rows
     * @param numCols number of columns, at least 4 because columns 1, 2 and 3 are used
     * @return an {@code ExtraTrees} instance holding the generated data (no trees built yet)
     * @throws IllegalArgumentException if {@code numCols} is below 4
     */
    public static ExtraTrees getSampleData(int numRows, int numCols) {
        if (numCols < 4) {
            throw new IllegalArgumentException(
                    "Sample data needs at least 4 columns, got " + numCols + ".");
        }
        double[] outputs = new double[numRows];
        double[] values = new double[numRows * numCols];
        for (int i = 0; i < values.length; ++i) {
            values[i] = Math.random();
        }
        Matrix matrix = new Matrix(values, numRows, numCols);
        for (int row = 0; row < outputs.length; ++row) {
            matrix.set(row, 2, 0.5);
            outputs[row] = matrix.get(row, 1) + 0.2 * matrix.get(row, 3);
        }
        return new ExtraTrees(matrix, outputs);
    }

    /**
     * Mean squared prediction error over all rows of a matrix.
     *
     * @param trees trees to evaluate
     * @param matrix inputs
     * @param outputs expected values, one per row of {@code matrix}
     * @return mean of the squared differences between prediction and expected value
     * @throws IllegalArgumentException if the number of rows and the number of outputs differ
     */
    public static double getMeanSqError(ArrayList<BinaryTree> trees, Matrix matrix, double[] outputs) {
        ExtraTrees.checkSameLength(matrix, outputs);
        double[] predictions = ExtraTrees.getValues(trees, matrix);
        double sum = 0.0;
        for (int row = 0; row < predictions.length; ++row) {
            sum += Math.pow(predictions[row] - outputs[row], 2.0);
        }
        return sum / (double) outputs.length;
    }

    /**
     * Mean squared prediction error over selected rows, using
     * {@link #getValue(ArrayList, double[], int)} for the predictions.
     *
     * @param trees trees to evaluate
     * @param matrix inputs
     * @param outputs expected values, indexed by row of {@code matrix}
     * @param nmin forwarded to {@link #getValue(ArrayList, double[], int)}
     * @param rowIds indices of the rows to evaluate
     * @return mean squared error over {@code rowIds} (NaN when {@code rowIds} is empty)
     */
    public static double getMeanSqError(ArrayList<BinaryTree> trees, Matrix matrix, double[] outputs, int nmin, int[] rowIds) {
        double sum = 0.0;
        double[] rowBuffer = new double[matrix.ncols];
        for (int rowId : rowIds) {
            ExtraTrees.fillRowBuffer(matrix, rowId, rowBuffer);
            sum += Math.pow(ExtraTrees.getValue(trees, rowBuffer, nmin) - outputs[rowId], 2.0);
        }
        return sum / (double) rowIds.length;
    }

    /**
     * Mean absolute prediction error over all rows of a matrix.
     *
     * @param trees trees to evaluate
     * @param matrix inputs
     * @param outputs expected values, one per row of {@code matrix}
     * @return mean of the absolute differences between prediction and expected value
     * @throws IllegalArgumentException if the number of rows and the number of outputs differ
     */
    public static double getMeanAbsError(ArrayList<BinaryTree> trees, Matrix matrix, double[] outputs) {
        ExtraTrees.checkSameLength(matrix, outputs);
        double[] predictions = ExtraTrees.getValues(trees, matrix);
        double sum = 0.0;
        for (int row = 0; row < outputs.length; ++row) {
            sum += Math.abs(predictions[row] - outputs[row]);
        }
        return sum / (double) outputs.length;
    }

    /**
     * Builds a forest whose minimum node size is chosen on a hold-out split.
     *
     * <p>The rows are shuffled; two thirds train a forest with minimum node size 2, and the
     * remaining third scores each candidate size from {2, 3, 5, 9, 14} through
     * {@link #getMeanSqError(ArrayList, Matrix, double[], int, int[])}. The candidate with the
     * lowest error (the first one on ties or when no error is comparable) is then used to
     * build the returned forest with the inherited three-argument {@code buildTrees}.</p>
     *
     * @param numTry maximum number of non-constant columns examined per node
     * @param numTrees number of trees per forest
     * @return the forest built with the selected minimum node size
     */
    public ArrayList<BinaryTree> buildTreeCV(int numTry, int numTrees) {
        int[] candidates = new int[]{2, 3, 5, 9, 14};
        int trainSize = (int) (2.0 / 3.0 * (double) this.output.length);

        // Shuffling the list view shuffles the backing array.
        Integer[] shuffled = new Integer[this.output.length];
        for (int i = 0; i < shuffled.length; ++i) {
            shuffled[i] = i;
        }
        Collections.shuffle(Arrays.asList(shuffled));

        int[] trainRows = new int[trainSize];
        int[] testRows = new int[this.output.length - trainSize];
        for (int i = 0; i < trainRows.length; ++i) {
            trainRows[i] = shuffled[i];
        }
        for (int i = 0; i < testRows.length; ++i) {
            testRows[i] = shuffled[i + trainRows.length];
        }

        ArrayList<BinaryTree> probeForest = this.buildTrees(2, numTry, numTrees, trainRows);
        double bestError = Double.POSITIVE_INFINITY;
        int bestNmin = candidates[0];
        for (int candidate : candidates) {
            double error = ExtraTrees.getMeanSqError(probeForest, this.input, this.output, candidate, testRows);
            if (error < bestError) {
                bestNmin = candidate;
                bestError = error;
            }
        }
        return this.buildTrees(bestNmin, numTry, numTrees);
    }

    /**
     * Demo: learns a forest on generated data, prints expected and predicted values for a
     * fresh sample, the timings, and two mean squared errors.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        int numRows = 10000;
        int numCols = 7;
        int numTrees = 15;

        // The first "Took:" line used to span two back-to-back timestamps and always
        // reported ~0s; it now times the training-data generation.
        long dataStart = System.currentTimeMillis();
        ExtraTrees learner = ExtraTrees.getSampleData(numRows, numCols);
        long dataEnd = System.currentTimeMillis();
        System.out.println("Took: " + (double) (dataEnd - dataStart) / 1000.0 + "s");

        long learnStart = System.currentTimeMillis();
        learner.learnTrees(2, 6, numTrees);
        ArrayList<BinaryTree> trees = learner.trees;
        long learnEnd = System.currentTimeMillis();

        ExtraTrees testData = ExtraTrees.getSampleData(1000, numCols);
        double[] predictions = ExtraTrees.getValues(trees, testData.input);
        for (int i = 0; i < testData.output.length; ++i) {
            System.out.print(String.format("%d\t%1.3f %1.3f", i, testData.output[i], predictions[i]));
            System.out.println();
        }
        System.out.println("Took: " + (double) (learnEnd - learnStart) / 1000.0 + "s");

        int[] allTestRows = new int[testData.output.length];
        for (int i = 0; i < allTestRows.length; ++i) {
            allTestRows[i] = i;
        }
        double fullError = ExtraTrees.getMeanSqError(trees, testData.input, testData.output);
        double nminError = ExtraTrees.getMeanSqError(trees, testData.input, testData.output, 5, allTestRows);
        System.out.println("Error: " + fullError);
        System.out.println("Error: " + nminError);
    }

    /** Copies one matrix row into {@code buffer}, which must hold at least {@code ncols} values. */
    private static void fillRowBuffer(Matrix matrix, int row, double[] buffer) {
        for (int col = 0; col < matrix.ncols; ++col) {
            buffer[col] = matrix.get(row, col);
        }
    }

    /** Rejects an output array whose length differs from the matrix's row count. */
    private static void checkSameLength(Matrix matrix, double[] outputs) {
        if (matrix.nrows != outputs.length) {
            throw new IllegalArgumentException("Input and output do not have same length.");
        }
    }
}

