/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableIteratedSingleClassifierEnhancer;
import weka.classifiers.Sourcable;
import weka.classifiers.rules.ZeroR;
import weka.classifiers.trees.DecisionStump;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Randomizable;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/**
 * Boosts a nominal-class base classifier with the AdaBoost.M1 method
 * (Freund and Schapire, 1996).
 *
 * <p>Each boosting iteration trains the next member of {@code m_Classifiers}
 * on the current instance weights. It trains directly on the weighted data
 * when the base learner is a {@link WeightedInstancesHandler} and resampling
 * is not requested, and on a weighted resample otherwise. The member's
 * training error rate {@code e} is then measured, and misclassified instances
 * are up-weighted by {@code (1 - e) / e}. Prediction is a vote in which each
 * member backs its predicted class with weight {@code log((1 - e) / e)}.
 *
 * <p>If the training data contain only the class attribute, a {@link ZeroR}
 * model is built and used instead.
 */
public class AdaBoostM1
extends RandomizableIteratedSingleClassifierEnhancer
implements WeightedInstancesHandler,
Sourcable,
TechnicalInformationHandler {
    static final long serialVersionUID = -7378107808933117974L;

    /** Maximum resampling attempts per iteration while a resample keeps giving zero training error. */
    private static final int MAX_NUM_RESAMPLING_ITERATIONS = 10;

    /** Vote weight {@code log((1 - e) / e)} of each performed iteration's classifier. */
    protected double[] m_Betas;

    /** Number of boosting iterations actually performed (may be fewer than requested). */
    protected int m_NumIterationsPerformed;

    /** Percentage of total weight mass used to train each base classifier; 100 means all of it. */
    protected int m_WeightThreshold = 100;

    /** Whether boosting resamples the data instead of passing weights to the base learner. */
    protected boolean m_UseResampling;

    /** Number of class values in the training data. */
    protected int m_NumClasses;

    /** Fallback model used when the data contain only the class attribute; {@code null} otherwise. */
    protected Classifier m_ZeroR;

    /** Creates a booster whose default base learner is a {@link DecisionStump}. */
    public AdaBoostM1() {
        this.m_Classifier = new DecisionStump();
    }

    /**
     * Returns a description of this classifier for display in the GUI.
     *
     * @return the description, followed by the bibliographic reference
     */
    public String globalInfo() {
        return "Class for boosting a nominal class classifier using the Adaboost M1 method. Only nominal class problems can be tackled. Often dramatically improves performance, but sometimes overfits.\n\nFor more information, see\n\n" + this.getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for the AdaBoost.M1 paper.
     *
     * @return the technical information for this method
     */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Yoav Freund and Robert E. Schapire");
        result.setValue(TechnicalInformation.Field.TITLE, "Experiments with a new boosting algorithm");
        result.setValue(TechnicalInformation.Field.BOOKTITLE, "Thirteenth International Conference on Machine Learning");
        result.setValue(TechnicalInformation.Field.YEAR, "1996");
        result.setValue(TechnicalInformation.Field.PAGES, "148-156");
        result.setValue(TechnicalInformation.Field.PUBLISHER, "Morgan Kaufmann");
        result.setValue(TechnicalInformation.Field.ADDRESS, "San Francisco");
        return result;
    }

    /**
     * Returns the class name of the default base classifier.
     *
     * @return the fully qualified name of {@link DecisionStump}
     */
    protected String defaultClassifierString() {
        return "weka.classifiers.trees.DecisionStump";
    }

    /**
     * Selects the heaviest instances whose weights together exceed the given
     * fraction of the total weight mass.
     *
     * <p>Instances are taken in decreasing order of weight. Selection stops
     * after the cumulative weight exceeds the target, but only where the next
     * instance has a different weight, so ties at the cut-off are kept
     * together.
     *
     * @param data     the weighted instances to select from
     * @param quantile fraction of the total weight mass to retain
     * @return a new dataset containing copies of the selected instances
     */
    protected Instances selectWeightQuantile(Instances data, double quantile) {
        int numInstances = data.numInstances();
        Instances selected = new Instances(data, numInstances);
        double[] weights = new double[numInstances];
        double totalWeight = 0.0;
        for (int i = 0; i < numInstances; ++i) {
            weights[i] = data.instance(i).weight();
            totalWeight += weights[i];
        }
        double weightMassToSelect = totalWeight * quantile;

        // Utils.sort yields indices in ascending weight order; walk from the end
        // to visit the heaviest instances first.
        int[] sortedIndices = Utils.sort(weights);
        double cumulativeWeight = 0.0;
        for (int i = numInstances - 1; i >= 0; --i) {
            int index = sortedIndices[i];
            selected.add((Instance) data.instance(index).copy());
            cumulativeWeight += weights[index];
            if (cumulativeWeight > weightMassToSelect && i > 0
                    && weights[index] != weights[sortedIndices[i - 1]]) {
                break;
            }
        }
        if (this.m_Debug) {
            System.err.println("Selected " + selected.numInstances() + " out of " + numInstances);
        }
        return selected;
    }

    /**
     * Lists the options of this booster (-P, -Q) followed by those of the superclass.
     *
     * @return an enumeration of {@link Option}s
     */
    public Enumeration listOptions() {
        Vector<Option> newVector = new Vector<Option>();
        newVector.addElement(new Option("\tPercentage of weight mass to base training on.\n\t(default 100, reduce to around 90 speed up)", "P", 1, "-P <num>"));
        newVector.addElement(new Option("\tUse resampling for boosting.", "Q", 0, "-Q"));
        Enumeration enu = super.listOptions();
        while (enu.hasMoreElements()) {
            newVector.addElement((Option) enu.nextElement());
        }
        return newVector.elements();
    }

    /**
     * Parses the options of this booster, then hands the array to the superclass.
     *
     * <p>{@code -P <num>} sets the weight threshold (100 when absent);
     * {@code -Q} enables resampling.
     *
     * @param options the option array
     * @throws Exception if an option value cannot be parsed or the superclass rejects an option
     */
    public void setOptions(String[] options) throws Exception {
        String thresholdString = Utils.getOption('P', options);
        if (thresholdString.length() != 0) {
            this.setWeightThreshold(Integer.parseInt(thresholdString));
        } else {
            this.setWeightThreshold(100);
        }
        this.setUseResampling(Utils.getFlag('Q', options));
        super.setOptions(options);
    }

    /**
     * Returns the current settings as an option array, this class's options first.
     *
     * @return the option array
     */
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        if (this.getUseResampling()) {
            options.add("-Q");
        }
        options.add("-P");
        options.add("" + this.getWeightThreshold());
        for (String superOption : super.getOptions()) {
            options.add(superOption);
        }
        return options.toArray(new String[options.size()]);
    }

    /** @return tip text for the weight threshold property */
    public String weightThresholdTipText() {
        return "Weight threshold for weight pruning.";
    }

    /** @param threshold percentage of weight mass used to train each base classifier */
    public void setWeightThreshold(int threshold) {
        this.m_WeightThreshold = threshold;
    }

    /** @return percentage of weight mass used to train each base classifier */
    public int getWeightThreshold() {
        return this.m_WeightThreshold;
    }

    /** @return tip text for the resampling property */
    public String useResamplingTipText() {
        return "Whether resampling is used instead of reweighting.";
    }

    /** @param useResampling whether to resample instead of passing weights to the base learner */
    public void setUseResampling(boolean useResampling) {
        this.m_UseResampling = useResampling;
    }

    /** @return whether resampling is used instead of reweighting */
    public boolean getUseResampling() {
        return this.m_UseResampling;
    }

    /**
     * Returns the superclass capabilities restricted to nominal/binary classes,
     * keeping only the class types the superclass itself reports.
     *
     * @return the capabilities of this classifier
     */
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        // Remember the supported class types before they are cleared below.
        boolean handlesNominalClass = result.handles(Capabilities.Capability.NOMINAL_CLASS);
        boolean handlesBinaryClass = result.handles(Capabilities.Capability.BINARY_CLASS);
        result.disableAllClasses();
        result.disableAllClassDependencies();
        if (handlesNominalClass) {
            result.enable(Capabilities.Capability.NOMINAL_CLASS);
        }
        if (handlesBinaryClass) {
            result.enable(Capabilities.Capability.BINARY_CLASS);
        }
        return result;
    }

    /**
     * Builds the boosted model.
     *
     * <p>Instances with a missing class are removed from a copy of the data.
     * If only the class attribute remains, a {@link ZeroR} model is used
     * instead. Otherwise boosting runs by reweighting or by resampling (see
     * {@link #getUseResampling()}).
     *
     * @param data the training data
     * @throws Exception if the data cannot be handled or a base classifier fails
     */
    public void buildClassifier(Instances data) throws Exception {
        super.buildClassifier(data);
        this.getCapabilities().testWithFail(data);
        data = new Instances(data);
        data.deleteWithMissingClass();
        if (data.numAttributes() == 1) {
            System.err.println("Cannot build model (only class attribute present in data!), using ZeroR model instead!");
            this.m_ZeroR = new ZeroR();
            this.m_ZeroR.buildClassifier(data);
            return;
        }
        this.m_ZeroR = null;
        this.m_NumClasses = data.numClasses();
        if (!this.m_UseResampling && this.m_Classifier instanceof WeightedInstancesHandler) {
            this.buildClassifierWithWeights(data);
        } else {
            this.buildClassifierUsingResampling(data);
        }
    }

    /**
     * Boosts by training each member on a weighted resample of the training data.
     *
     * @param data the training data (missing classes already removed)
     * @throws Exception if a base classifier or the evaluation fails
     */
    protected void buildClassifierUsingResampling(Instances data) throws Exception {
        int numInstances = data.numInstances();
        Random randomInstance = new Random(this.m_Seed);
        this.m_Betas = new double[this.m_Classifiers.length];
        this.m_NumIterationsPerformed = 0;

        // Work on a copy whose weights are normalized to sum to one.
        Instances training = new Instances(data, 0, numInstances);
        double sumOfWeights = training.sumOfWeights();
        for (int i = 0; i < training.numInstances(); ++i) {
            training.instance(i).setWeight(training.instance(i).weight() / sumOfWeights);
        }

        while (this.m_NumIterationsPerformed < this.m_Classifiers.length) {
            if (this.m_Debug) {
                System.err.println("Training classifier " + (this.m_NumIterationsPerformed + 1));
            }
            Classifier current = this.m_Classifiers[this.m_NumIterationsPerformed];
            Instances sample = this.m_WeightThreshold < 100
                    ? this.selectWeightQuantile(training, (double) this.m_WeightThreshold / 100.0)
                    : new Instances(training);
            double[] weights = new double[sample.numInstances()];
            for (int i = 0; i < weights.length; ++i) {
                weights[i] = sample.instance(i).weight();
            }

            // A zero error rate would make log((1 - e) / e) infinite, so retry
            // with a fresh resample a bounded number of times.
            double errorRate;
            int resamplingIterations = 0;
            do {
                Instances resampled = sample.resampleWithWeights(randomInstance, weights);
                current.buildClassifier(resampled);
                Evaluation evaluation = new Evaluation(data);
                evaluation.evaluateModel(current, training);
                errorRate = evaluation.errorRate();
            } while (Utils.eq(errorRate, 0.0) && ++resamplingIterations < MAX_NUM_RESAMPLING_ITERATIONS);

            if (!this.completeIteration(training, errorRate)) {
                break;
            }
        }
    }

    /**
     * Multiplies the weight of every instance misclassified by the current
     * iteration's classifier by {@code reweight}, then rescales all weights so
     * that their total is unchanged.
     *
     * @param training the weighted training data, modified in place
     * @param reweight factor applied to misclassified instances
     * @throws Exception if classification fails
     */
    protected void setWeights(Instances training, double reweight) throws Exception {
        double oldSumOfWeights = training.sumOfWeights();
        Classifier current = this.m_Classifiers[this.m_NumIterationsPerformed];
        Enumeration enu = training.enumerateInstances();
        while (enu.hasMoreElements()) {
            Instance instance = (Instance) enu.nextElement();
            if (!Utils.eq(current.classifyInstance(instance), instance.classValue())) {
                instance.setWeight(instance.weight() * reweight);
            }
        }
        double newSumOfWeights = training.sumOfWeights();
        enu = training.enumerateInstances();
        while (enu.hasMoreElements()) {
            Instance instance = (Instance) enu.nextElement();
            instance.setWeight(instance.weight() * oldSumOfWeights / newSumOfWeights);
        }
    }

    /**
     * Boosts by passing the instance weights directly to the base learner.
     *
     * @param data the training data (missing classes already removed)
     * @throws Exception if a base classifier or the evaluation fails
     */
    protected void buildClassifierWithWeights(Instances data) throws Exception {
        int numInstances = data.numInstances();
        Random randomInstance = new Random(this.m_Seed);
        this.m_Betas = new double[this.m_Classifiers.length];
        this.m_NumIterationsPerformed = 0;
        Instances training = new Instances(data, 0, numInstances);

        while (this.m_NumIterationsPerformed < this.m_Classifiers.length) {
            if (this.m_Debug) {
                System.err.println("Training classifier " + (this.m_NumIterationsPerformed + 1));
            }
            Classifier current = this.m_Classifiers[this.m_NumIterationsPerformed];
            Instances sample = this.m_WeightThreshold < 100
                    ? this.selectWeightQuantile(training, (double) this.m_WeightThreshold / 100.0)
                    : new Instances(training, 0, numInstances);
            if (current instanceof Randomizable) {
                ((Randomizable) current).setSeed(randomInstance.nextInt());
            }
            current.buildClassifier(sample);
            Evaluation evaluation = new Evaluation(data);
            evaluation.evaluateModel(current, training);
            if (!this.completeIteration(training, evaluation.errorRate())) {
                break;
            }
        }
    }

    /**
     * Finishes the current boosting iteration given its training error rate.
     *
     * <p>An error rate of at least 0.5 or exactly 0 ends boosting. When that
     * happens on the very first iteration, that classifier is kept as the sole
     * model. Otherwise the iteration's beta is stored, the training weights are
     * updated, and the iteration counter advances.
     *
     * @param training  the weighted training data, reweighted in place
     * @param errorRate the current classifier's error rate on {@code training}
     * @return {@code true} to continue boosting, {@code false} to stop
     * @throws Exception if reweighting fails
     */
    private boolean completeIteration(Instances training, double errorRate) throws Exception {
        if (Utils.grOrEq(errorRate, 0.5) || Utils.eq(errorRate, 0.0)) {
            if (this.m_NumIterationsPerformed == 0) {
                this.m_NumIterationsPerformed = 1;
            }
            return false;
        }
        double reweight = (1.0 - errorRate) / errorRate;
        this.m_Betas[this.m_NumIterationsPerformed] = Math.log(reweight);
        if (this.m_Debug) {
            System.err.println("\terror rate = " + errorRate + "  beta = " + this.m_Betas[this.m_NumIterationsPerformed]);
        }
        // setWeights reads the classifier at the current index, so advance afterwards.
        this.setWeights(training, reweight);
        ++this.m_NumIterationsPerformed;
        return true;
    }

    /**
     * Returns the class distribution for an instance.
     *
     * <p>With a single iteration, the base classifier's own distribution is
     * returned. Otherwise each member adds its beta to the class it predicts,
     * and the sums are converted to probabilities with
     * {@link Utils#logs2probs(double[])}.
     *
     * @param instance the instance to classify
     * @return the predicted class distribution
     * @throws Exception if no model has been built or a member fails
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (this.m_ZeroR != null) {
            return this.m_ZeroR.distributionForInstance(instance);
        }
        if (this.m_NumIterationsPerformed == 0) {
            throw new Exception("No model built");
        }
        if (this.m_NumIterationsPerformed == 1) {
            return this.m_Classifiers[0].distributionForInstance(instance);
        }
        double[] sums = new double[instance.numClasses()];
        for (int i = 0; i < this.m_NumIterationsPerformed; ++i) {
            double predicted = this.m_Classifiers[i].classifyInstance(instance);
            // Weka reports "no prediction" as a missing value (NaN), and (int) NaN
            // is 0. Skip it so an abstaining member does not vote for class 0.
            if (Double.isNaN(predicted)) {
                continue;
            }
            sums[(int) predicted] += this.m_Betas[i];
        }
        return Utils.logs2probs(sums);
    }

    /**
     * Generates Java source code equivalent to the built model.
     *
     * <p>When the ZeroR fallback is in use, its own source is returned.
     * Otherwise a class is emitted whose {@code classify} method combines the
     * performed iterations' members, followed by the source of those members.
     *
     * @param className the name of the generated class
     * @return the generated source
     * @throws Exception if no model is built or a model is not {@link Sourcable}
     */
    public String toSource(String className) throws Exception {
        if (this.m_ZeroR != null) {
            // The fallback replaces the ensemble; iteration state may be stale here.
            if (!(this.m_ZeroR instanceof Sourcable)) {
                throw new Exception("Fallback model " + this.m_ZeroR.getClass().getName() + " is not Sourcable");
            }
            return ((Sourcable) this.m_ZeroR).toSource(className);
        }
        if (this.m_NumIterationsPerformed == 0) {
            throw new Exception("No model built yet");
        }
        if (!(this.m_Classifiers[0] instanceof Sourcable)) {
            throw new Exception("Base learner " + this.m_Classifier.getClass().getName() + " is not Sourcable");
        }
        StringBuffer text = new StringBuffer("class ");
        text.append(className).append(" {\n\n");
        text.append("  public static double classify(Object [] i) {\n");
        if (this.m_NumIterationsPerformed == 1) {
            text.append("    return " + className + "_0.classify(i);\n");
        } else {
            text.append("    double [] sums = new double [" + this.m_NumClasses + "];\n");
            text.append("    double p;\n");
            for (int n = 0; n < this.m_NumIterationsPerformed; ++n) {
                // Mirror distributionForInstance: a NaN prediction casts to 0, so skip it.
                text.append("    p = " + className + '_' + n + ".classify(i);\n");
                text.append("    if (!Double.isNaN(p)) sums[(int) p] += " + this.m_Betas[n] + ";\n");
            }
            text.append("    double maxV = sums[0];\n    int maxI = 0;\n    for (int j = 1; j < " + this.m_NumClasses + "; j++) {\n" + "      if (sums[j] > maxV) { maxV = sums[j]; maxI = j; }\n" + "    }\n    return (double) maxI;\n");
        }
        text.append("  }\n}\n");
        // Only members of performed iterations are referenced above; any others may be untrained.
        for (int n = 0; n < this.m_NumIterationsPerformed; ++n) {
            text.append(((Sourcable) this.m_Classifiers[n]).toSource(className + '_' + n));
        }
        return text.toString();
    }

    /**
     * Returns a textual description of the model.
     *
     * @return the description
     */
    public String toString() {
        if (this.m_ZeroR != null) {
            StringBuffer buf = new StringBuffer();
            String simpleName = this.getClass().getName().replaceAll(".*\\.", "");
            buf.append(simpleName + "\n");
            // Underline the name with one '=' per character.
            buf.append(simpleName.replaceAll(".", "=") + "\n\n");
            buf.append("Warning: No model could be built, hence ZeroR model is used:\n\n");
            buf.append(this.m_ZeroR.toString());
            return buf.toString();
        }
        StringBuffer text = new StringBuffer();
        if (this.m_NumIterationsPerformed == 0) {
            text.append("AdaBoostM1: No model built yet.\n");
        } else if (this.m_NumIterationsPerformed == 1) {
            text.append("AdaBoostM1: No boosting possible, one classifier used!\n");
            text.append(this.m_Classifiers[0].toString() + "\n");
        } else {
            text.append("AdaBoostM1: Base classifiers and their weights: \n\n");
            for (int i = 0; i < this.m_NumIterationsPerformed; ++i) {
                text.append(this.m_Classifiers[i].toString() + "\n\n");
                text.append("Weight: " + Utils.roundDouble(this.m_Betas[i], 2) + "\n\n");
            }
            text.append("Number of performed Iterations: " + this.m_NumIterationsPerformed + "\n");
        }
        return text.toString();
    }

    /**
     * Runs the classifier from the command line.
     *
     * @param args the command-line options
     */
    public static void main(String[] args) {
        AdaBoostM1.runClassifier(new AdaBoostM1(), args);
    }
}

