/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.meta;

import java.math.BigInteger;
import java.util.Arrays;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.core.driftdetection.ADWIN;
import moa.classifiers.trees.LimAttHoeffdingTree;
import moa.core.Measurement;
import moa.options.ClassOption;
import moa.options.FlagOption;
import moa.options.FloatOption;
import moa.options.IntOption;
import weka.core.Instance;
import weka.core.Utils;

/**
 * Stacking ensemble whose members each learn from a fixed-size subset of the input attributes.
 *
 * <p>One copy of the base learner is created per combination of {@code numAttributes} input
 * attributes (or of {@code numInputAttributes - numAttributes} attributes when {@code bigTrees} is
 * set). When the base learner is a {@link LimAttHoeffdingTree}, each copy is restricted to its own
 * attribute combination. Member predictions are normalized, smoothed with {@code oddsOffset} and
 * turned into log-odds, which are combined by an online per-class sigmoid perceptron.
 *
 * <p>Every member has its own {@link ADWIN} error monitor. When a monitor signals an increase in
 * error, the affected member (or, with {@code adwinReplaceWorstClassifier}, the member with the
 * highest estimated error) is reset and its stacking weights are zeroed.
 */
public class LimAttClassifier extends AbstractClassifier {
    private static final long serialVersionUID = 1L;
    /** Base learner that is copied once per attribute combination. */
    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "trees.LimAttHoeffdingTree");
    /** Attributes per model (or the number left out, when {@link #bigTreesOption} is set). */
    public IntOption numAttributesOption = new IntOption("numAttributes", 'n', "The number of attributes to use per model.", 1, 1, Integer.MAX_VALUE);
    /** Not read anywhere in this class. */
    public FloatOption weightShrinkOption = new FloatOption("weightShrink", 'w', "The number to multiply the weight misclassified counts.", 0.5, 0.0, 3.4028234663852886E38);
    /** Confidence parameter passed to every {@link ADWIN} detector. */
    public FloatOption deltaAdwinOption = new FloatOption("deltaAdwin", 'a', "Delta of Adwin change detection", 0.002, 0.0, 1.0);
    /** Additive smoothing applied to every class probability before taking log-odds. */
    public FloatOption oddsOffsetOption = new FloatOption("oddsOffset", 'o', "Offset for odds to avoid probabilities that are zero.", 0.001, 0.0, 3.4028234663852886E38);
    /** When set, prediction uses only the members with the largest summed stacking weights. */
    public FlagOption pruneOption = new FlagOption("prune", 'x', "Enable pruning.");
    /** When set, each model uses all input attributes but {@code numAttributes} of them. */
    public FlagOption bigTreesOption = new FlagOption("bigTrees", 'b', "Use m-n attributes on the trees.");
    /** Number of members used for prediction when pruning (clamped to the ensemble size). */
    public IntOption numEnsemblePruningOption = new IntOption("numEnsemblePruning", 'm', "The pruned number of classifiers to use to predict.", 10, 1, Integer.MAX_VALUE);
    /** When set, a detected change resets only the member with the highest estimated error. */
    public FlagOption adwinReplaceWorstClassifierOption = new FlagOption("adwinReplaceWorstClassifier", 'z', "When one Adwin detects change, replace worst classifier.");
    /** Ensemble members, one per attribute combination. */
    protected Classifier[] ensemble;
    /** Error-rate change detector for each member, indexed like {@link #ensemble}. */
    protected ADWIN[] ADError;
    /** Number of changes that led to a member reset since the ensemble was built. */
    protected int numberOfChangesDetected;
    /** Not used in this class. */
    protected int[][] matrixCodes;
    /** Not used in this class. */
    protected boolean initMatrixCodes = false;
    /** True until the ensemble has been built from the first training instance. */
    protected boolean initClassifiers = false;
    /** Number of attributes given to each member. */
    protected int numberAttributes = 1;
    /** Counter driving the decaying perceptron learning rate. */
    protected int numInstances = 0;
    /** Base learning rate of the stacking perceptron. */
    public FloatOption learningRatioOption = new FloatOption("learningRatio", 'e', "Learning ratio", 1.0);
    /** Weight-decay (L2 penalty) factor of the stacking perceptron. */
    public FloatOption penaltyFactorOption = new FloatOption("lambda", 'p', "Lambda", 0.0);
    /** Value the learning-rate counter restarts from on reset or detected change. */
    public IntOption initialNumInstancesOption = new IntOption("initialNumInstances", 'i', "initialNumInstances", 10);
    /** Stacking weights {@code [numClasses][ensemble.length + 1]}; the last column is the bias. */
    protected double[][] weightAttribute;
    /** True until the stacking weights have been (re)initialized by the next training instance. */
    protected boolean reset;

    @Override
    public String getPurposeString() {
        return "Ensemble Combining Restricted Hoeffding Trees using Stacking";
    }

    /**
     * Schedules the ensemble and the stacking weights to be rebuilt on the next training instance.
     * Both depend on the instance header (attribute and class counts).
     */
    @Override
    public void resetLearningImpl() {
        this.initClassifiers = true;
        this.reset = true;
    }

    /**
     * Trains the ensemble on one instance.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>Builds the ensemble if needed.</li>
     *   <li>Computes every member's log-odds before any member is updated.</li>
     *   <li>Feeds each member's correctness to its ADWIN detector and resets members on change.</li>
     *   <li>Updates the stacking perceptron.</li>
     *   <li>Trains every member on the instance.</li>
     * </ol>
     *
     * @param inst the labelled training instance
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        if (this.initClassifiers) {
            this.initEnsemble(inst);
        }
        int numClasses = inst.numClasses();
        // One row per member plus a trailing row; row count - 1 is the member count in prediction().
        double[][] votes = new double[this.ensemble.length + 1][];
        for (int i = 0; i < this.ensemble.length; ++i) {
            votes[i] = this.computeLogOdds(this.ensemble[i].getVotesForInstance(inst), numClasses);
        }
        votes[this.ensemble.length] = new double[numClasses];

        Instance weightedInst = (Instance) inst.copy();
        if (this.adwinReplaceWorstClassifierOption.isSet()) {
            this.replaceWorstMemberOnChange(weightedInst, numClasses);
        } else {
            this.resetChangedMembers(weightedInst, numClasses);
        }

        this.trainOnInstanceImplPerceptron(numClasses, (int) inst.classValue(), votes);
        for (Classifier member : this.ensemble) {
            member.trainOnInstance(inst);
        }
    }

    /** Builds one base-learner copy and one ADWIN detector per attribute combination. */
    private void initEnsemble(Instance inst) {
        // numAttributes() counts the class attribute too.
        int numInputAttributes = inst.numAttributes() - 1;
        this.numberAttributes = this.numAttributesOption.getValue();
        if (this.bigTreesOption.isSet()) {
            this.numberAttributes = numInputAttributes - this.numAttributesOption.getValue();
        }
        CombinationGenerator combinations = new CombinationGenerator(numInputAttributes, this.numberAttributes);
        BigInteger total = combinations.getTotal();
        if (total.bitLength() > 31) {
            throw new IllegalArgumentException("Too many attribute combinations (" + total + ") to build an ensemble");
        }
        int numberClassifiers = total.intValue();

        Classifier baseLearner = (Classifier) this.getPreparedClassOption(this.baseLearnerOption);
        baseLearner.resetLearning();
        this.ensemble = new Classifier[numberClassifiers];
        this.ADError = new ADWIN[numberClassifiers];
        for (int i = 0; i < numberClassifiers; ++i) {
            this.ensemble[i] = baseLearner.copy();
            this.ADError[i] = new ADWIN(this.deltaAdwinOption.getValue());
        }
        this.numberOfChangesDetected = 0;

        if (baseLearner instanceof LimAttHoeffdingTree) {
            for (int i = 0; combinations.hasMore(); ++i) {
                ((LimAttHoeffdingTree) this.ensemble[i]).setlistAttributes(combinations.getNext());
            }
        }
        this.initClassifiers = false;
    }

    /**
     * Converts a member's raw vote array into smoothed per-class log-odds.
     *
     * <p>Steps:
     * <ol>
     *   <li>Normalizes the votes to sum to 1. Non-positive or NaN totals are treated as
     *       all-zero votes.</li>
     *   <li>Adds {@code oddsOffset} to every class.</li>
     *   <li>Returns {@code log(p_c / (sum - p_c))} for every class {@code c < numClasses}.</li>
     * </ol>
     *
     * <p>Classes missing from a short vote array still receive the smoothed value instead of a
     * neutral 0. Entries beyond {@code numClasses} are ignored.
     *
     * @param rawVotes   vote array returned by a member; it is not modified
     * @param numClasses number of classes of the stream
     * @return log-odds of length {@code numClasses}
     */
    private double[] computeLogOdds(double[] rawVotes, int numClasses) {
        double offset = this.oddsOffsetOption.getValue();
        double[] smoothed = new double[numClasses];
        Arrays.fill(smoothed, offset);
        double rawSum = Utils.sum(rawVotes);
        double total = numClasses * offset;
        if (!Double.isNaN(rawSum) && rawSum > 0.0) {
            int limit = Math.min(rawVotes.length, numClasses);
            for (int c = 0; c < limit; ++c) {
                double p = rawVotes[c] / rawSum;
                smoothed[c] += p;
                total += p;
            }
        }
        double[] logOdds = new double[numClasses];
        for (int c = 0; c < numClasses; ++c) {
            logOdds[c] = Math.log(smoothed[c] / (total - smoothed[c]));
        }
        return logOdds;
    }

    /**
     * Updates every member's ADWIN detector.
     *
     * <p>Whenever a detector fires, the learning-rate counter restarts. Any member whose estimated
     * error went up is reset.
     */
    private void resetChangedMembers(Instance inst, int numClasses) {
        for (int i = 0; i < this.ensemble.length; ++i) {
            boolean correct = this.ensemble[i].correctlyClassifies(inst);
            double previousError = this.ADError[i].getEstimation();
            if (!this.ADError[i].setInput(correct ? 0.0 : 1.0)) {
                continue;
            }
            this.numInstances = this.initialNumInstancesOption.getValue();
            if (this.ADError[i].getEstimation() > previousError) {
                ++this.numberOfChangesDetected;
                this.resetMember(i, numClasses);
            }
        }
    }

    /**
     * Updates every member's ADWIN detector.
     *
     * <p>If at least one detector reports increased error, resets the member with the highest
     * positive error estimate.
     */
    private void replaceWorstMemberOnChange(Instance inst, int numClasses) {
        boolean changeDetected = false;
        for (int i = 0; i < this.ensemble.length; ++i) {
            boolean correct = this.ensemble[i].correctlyClassifies(inst);
            double previousError = this.ADError[i].getEstimation();
            // setInput must run for every member, so it is evaluated before the comparison.
            if (this.ADError[i].setInput(correct ? 0.0 : 1.0) && this.ADError[i].getEstimation() > previousError) {
                changeDetected = true;
            }
        }
        if (!changeDetected) {
            return;
        }
        ++this.numberOfChangesDetected;
        double maxError = 0.0;
        int worst = -1;
        for (int i = 0; i < this.ensemble.length; ++i) {
            double error = this.ADError[i].getEstimation();
            if (error > maxError) {
                maxError = error;
                worst = i;
            }
        }
        if (worst != -1) {
            this.resetMember(worst, numClasses);
        }
    }

    /** Resets member {@code index}: fresh model, fresh ADWIN detector and zeroed stacking weights. */
    private void resetMember(int index, int numClasses) {
        this.ensemble[index].resetLearning();
        this.ADError[index] = new ADWIN(this.deltaAdwinOption.getValue());
        if (this.weightAttribute != null) {
            for (int c = 0; c < numClasses; ++c) {
                this.weightAttribute[c][index] = 0.0;
            }
        }
    }

    /**
     * Predicts class scores by combining (optionally pruned) member log-odds with the stacking
     * perceptron.
     *
     * @param inst the instance to classify
     * @return an empty array before the ensemble is built; otherwise one sigmoid score per class.
     *     All scores are 0 until the stacking weights exist.
     */
    @Override
    public double[] getVotesForInstance(Instance inst) {
        if (this.initClassifiers) {
            return new double[0];
        }
        int numClasses = inst.numClasses();
        if (this.reset || this.weightAttribute == null) {
            // No stacking weights yet; getVotesForInstancePerceptron would also answer all zeros.
            return new double[numClasses];
        }
        int sizeEnsemble = this.ensemble.length;
        if (this.pruneOption.isSet()) {
            sizeEnsemble = Math.min(this.numEnsemblePruningOption.getValue(), this.ensemble.length);
        }
        int[] bestClassifiers = this.selectBestClassifiers(sizeEnsemble);
        double[][] votes = new double[sizeEnsemble + 1][];
        for (int k = 0; k < sizeEnsemble; ++k) {
            votes[k] = this.computeLogOdds(this.ensemble[bestClassifiers[k]].getVotesForInstance(inst), numClasses);
        }
        votes[sizeEnsemble] = new double[numClasses];
        return this.getVotesForInstancePerceptron(votes, bestClassifiers, numClasses);
    }

    /**
     * Chooses which members take part in prediction.
     *
     * <p>Without pruning this is simply the first {@code sizeEnsemble} members. With pruning, it is
     * the {@code sizeEnsemble} members whose stacking weights, summed over classes, are largest.
     * Ties at the cut-off are broken by member index.
     */
    private int[] selectBestClassifiers(int sizeEnsemble) {
        int[] best = new int[sizeEnsemble];
        if (!this.pruneOption.isSet()) {
            for (int k = 0; k < sizeEnsemble; ++k) {
                best[k] = k;
            }
            return best;
        }
        double[] weight = new double[this.ensemble.length];
        for (double[] classWeights : this.weightAttribute) {
            for (int j = 0; j < this.ensemble.length; ++j) {
                weight[j] += classWeights[j];
            }
        }
        // Sort a copy: the unsorted array must stay aligned with member indices.
        double[] sorted = weight.clone();
        Arrays.sort(sorted);
        double cutValue = sorted[this.ensemble.length - sizeEnsemble];
        int k = 0;
        for (int j = 0; j < this.ensemble.length && k < sizeEnsemble; ++j) {
            if (weight[j] >= cutValue) {
                best[k++] = j;
            }
        }
        return best;
    }

    /** Declared randomizable, although no random source is used in this class. */
    @Override
    public boolean isRandomizable() {
        return true;
    }

    /** Intentionally produces no textual model description. */
    @Override
    public void getModelDescription(StringBuilder out, int indent) {
    }

    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return new Measurement[]{new Measurement("ensemble size", this.ensemble != null ? (double) this.ensemble.length : 0.0), new Measurement("change detections", this.numberOfChangesDetected)};
    }

    /** Returns a shallow copy of the members, or an empty array before the ensemble is built. */
    @Override
    public Classifier[] getSubClassifiers() {
        return this.ensemble == null ? new Classifier[0] : this.ensemble.clone();
    }

    /**
     * Performs one gradient step of the per-class sigmoid stacking perceptron.
     *
     * <p>On the first call after a reset, the weights are allocated with uniform member weights
     * {@code 1 / (votes.length - 1)} and zero bias. The learning rate decays as
     * {@code 2 * learningRatio / (numInstances + members + 2)}. The update includes an L2 penalty
     * {@code lambda}.
     *
     * @param numClasses  number of classes
     * @param actualClass index of the true class
     * @param votes       member log-odds {@code [members + 1][numClasses]} (last row unused)
     */
    public void trainOnInstanceImplPerceptron(int numClasses, int actualClass, double[][] votes) {
        if (this.reset) {
            this.reset = false;
            this.weightAttribute = new double[numClasses][votes.length];
            double initialWeight = 1.0 / ((double) votes.length - 1.0);
            for (double[] row : this.weightAttribute) {
                // The last column (bias) stays 0.
                Arrays.fill(row, 0, votes.length - 1, initialWeight);
            }
            this.numInstances = this.initialNumInstancesOption.getValue();
        }
        double learningRatio = this.learningRatioOption.getValue() * 2.0 / ((double) (this.numInstances + (votes.length - 1)) + 2.0);
        double lambda = this.penaltyFactorOption.getValue();
        ++this.numInstances;
        // All predictions are taken before any weight changes.
        double[] preds = new double[numClasses];
        for (int c = 0; c < numClasses; ++c) {
            preds[c] = this.prediction(votes, c);
        }
        for (int c = 0; c < numClasses; ++c) {
            double target = (c == actualClass) ? 1.0 : 0.0;
            double delta = (target - preds[c]) * preds[c] * (1.0 - preds[c]);
            double[] row = this.weightAttribute[c];
            for (int j = 0; j < this.ensemble.length; ++j) {
                row[j] += learningRatio * (delta * votes[j][c] - lambda * row[j]);
            }
            row[row.length - 1] += learningRatio * delta;
        }
    }

    /**
     * Computes the sigmoid score of class {@code classVal} from a pruned subset of members.
     *
     * @param votes           log-odds of the selected members (last row unused)
     * @param bestClassifiers member index for each row of {@code votes}
     * @param classVal        class whose score is computed
     * @return sigmoid of the weighted sum plus the class bias (last weight column)
     */
    public double predictionPruning(double[][] votes, int[] bestClassifiers, int classVal) {
        double[] weights = this.weightAttribute[classVal];
        double sum = 0.0;
        for (int k = 0; k < votes.length - 1; ++k) {
            sum += weights[bestClassifiers[k]] * votes[k][classVal];
        }
        sum += weights[weights.length - 1];
        return 1.0 / (1.0 + Math.exp(-sum));
    }

    /**
     * Computes the sigmoid score of class {@code classVal} from all members.
     *
     * @param votes    log-odds of every member, in member order (last row unused)
     * @param classVal class whose score is computed
     * @return sigmoid of the weighted sum plus the class bias (last weight column)
     */
    public double prediction(double[][] votes, int classVal) {
        double[] weights = this.weightAttribute[classVal];
        double sum = 0.0;
        for (int i = 0; i < votes.length - 1; ++i) {
            sum += weights[i] * votes[i][classVal];
        }
        sum += weights[weights.length - 1];
        return 1.0 / (1.0 + Math.exp(-sum));
    }

    /**
     * Scores every class with {@link #predictionPruning}.
     *
     * @return one score per class, or all zeros while the stacking weights are not initialized
     */
    public double[] getVotesForInstancePerceptron(double[][] votesEnsemble, int[] bestClassifiers, int numClasses) {
        double[] votes = new double[numClasses];
        if (!this.reset) {
            for (int c = 0; c < votes.length; ++c) {
                votes[c] = this.predictionPruning(votesEnsemble, bestClassifiers, c);
            }
        }
        return votes;
    }

    /**
     * Enumerates all {@code r}-element combinations of {@code {0, ..., n-1}} in lexicographic order.
     * Each combination is returned as a fresh array of ascending indices.
     */
    public class CombinationGenerator {
        private final int[] a;
        private final int n;
        private final int r;
        private BigInteger numLeft;
        private final BigInteger total;

        /**
         * @param n size of the set to choose from; must be at least 1
         * @param r elements per combination; must satisfy {@code 0 <= r <= n}
         * @throws IllegalArgumentException if these bounds are violated
         */
        public CombinationGenerator(int n, int r) {
            if (r > n) {
                throw new IllegalArgumentException("r (" + r + ") must not exceed n (" + n + ")");
            }
            if (n < 1) {
                throw new IllegalArgumentException("n must be at least 1, was " + n);
            }
            if (r < 0) {
                throw new IllegalArgumentException("r must not be negative, was " + r);
            }
            this.n = n;
            this.r = r;
            this.a = new int[r];
            BigInteger nFact = this.getFactorial(n);
            BigInteger rFact = this.getFactorial(r);
            BigInteger nminusrFact = this.getFactorial(n - r);
            this.total = nFact.divide(rFact.multiply(nminusrFact));
            this.reset();
        }

        /** Restarts the enumeration at {@code {0, 1, ..., r-1}}. */
        public void reset() {
            for (int i = 0; i < this.a.length; ++i) {
                this.a[i] = i;
            }
            // BigInteger is immutable, so sharing the reference is safe.
            this.numLeft = this.total;
        }

        /** @return number of combinations not yet returned */
        public BigInteger getNumLeft() {
            return this.numLeft;
        }

        /** @return whether {@link #getNext()} can be called again */
        public boolean hasMore() {
            return this.numLeft.signum() > 0;
        }

        /** @return total number of combinations, {@code n! / (r! (n-r)!)} */
        public BigInteger getTotal() {
            return this.total;
        }

        private BigInteger getFactorial(int n) {
            BigInteger fact = BigInteger.ONE;
            for (int i = n; i > 1; --i) {
                fact = fact.multiply(BigInteger.valueOf(i));
            }
            return fact;
        }

        /**
         * Returns the next combination.
         *
         * @return a fresh array of {@code r} ascending indices
         * @throws IllegalStateException if all combinations have already been returned
         */
        public int[] getNext() {
            if (!this.hasMore()) {
                throw new IllegalStateException("No combinations left");
            }
            if (!this.numLeft.equals(this.total)) {
                // Advance the rightmost index that can still grow, then reset the tail after it.
                int i = this.r - 1;
                while (this.a[i] == this.n - this.r + i) {
                    --i;
                }
                ++this.a[i];
                for (int j = i + 1; j < this.r; ++j) {
                    this.a[j] = this.a[i] + j - i;
                }
            }
            this.numLeft = this.numLeft.subtract(BigInteger.ONE);
            return this.a.clone();
        }
    }
}

