/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.lazy;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Evaluation;
import weka.classifiers.SingleClassifierEnhancer;
import weka.classifiers.UpdateableClassifier;
import weka.classifiers.trees.DecisionStump;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.UnsupportedAttributeTypeException;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/**
 * Locally weighted learning (LWL).
 *
 * <p>Training only stores the data and the observed per-attribute value ranges. For each test
 * instance, the base classifier is rebuilt on a copy of the training data in which every
 * instance is re-weighted by a kernel function of its Euclidean distance to the test instance.
 * Distances are first divided by a bandwidth: the distance to the k-th nearest training
 * instance, or, when all neighbours are used, to the (n-1)-th nearest of the n training
 * instances (at least the nearest one).
 *
 * <p>The base classifier must implement {@link WeightedInstancesHandler}.
 */
public class LWL extends SingleClassifierEnhancer
        implements UpdateableClassifier, WeightedInstancesHandler {
    /** Stored training data; instances with a missing class value are removed. */
    protected Instances m_Train;
    /** Per-attribute minimum seen so far (NaN until a non-missing value has been seen). */
    protected double[] m_Min;
    /** Per-attribute maximum seen so far (NaN until a non-missing value has been seen). */
    protected double[] m_Max;
    /** If true, numeric values are NOT scaled by the observed range in distance computation. */
    protected boolean m_NoAttribNorm = false;
    /** Number of neighbours that sets the bandwidth; only used when {@code m_UseAllK} is false. */
    protected int m_kNN = -1;
    /** Weighting kernel; one of the kernel constants below. */
    protected int m_WeightKernel = 0;
    /** If true, the bandwidth is derived from all training instances. */
    protected boolean m_UseAllK = true;
    protected static final int LINEAR = 0;
    protected static final int EPANECHNIKOV = 1;
    protected static final int TRICUBE = 2;
    protected static final int INVERSE = 3;
    protected static final int GAUSS = 4;
    protected static final int CONSTANT = 5;

    /** Attribute type code treated as numeric by {@link #difference} (weka.core.Attribute.NUMERIC). */
    private static final int ATTRIBUTE_NUMERIC = 0;
    /** Attribute type code treated as nominal by {@link #difference} (weka.core.Attribute.NOMINAL). */
    private static final int ATTRIBUTE_NOMINAL = 1;

    /**
     * Returns a description of this classifier for display in the GUI.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Class for performing locally weighted learning. Can do classification (e.g. using naive Bayes) or regression (e.g. using linear regression). The base learner needs to implement WeightedInstancesHandler. For more info, see\n\nEibe Frank, Mark Hall, and Bernhard Pfahringer (2003). \"Locally Weighted Naive Bayes\". Conference on Uncertainty in AI.\n\nAtkeson, C., A. Moore, and S. Schaal (1996) \"Locally weighted learning\" AI Reviews.";
    }

    /** Creates an LWL learner that uses a DecisionStump as the base classifier. */
    public LWL() {
        this.m_Classifier = new DecisionStump();
    }

    /**
     * Returns the class name of the default base classifier.
     *
     * @return the default base classifier's class name
     */
    protected String defaultClassifierString() {
        return "weka.classifiers.trees.DecisionStump";
    }

    /**
     * Lists this learner's options followed by those of the superclass.
     *
     * @return an enumeration of {@link Option}s
     */
    public Enumeration listOptions() {
        Vector<Option> vector = new Vector<Option>(3);
        vector.addElement(new Option("\tDo not normalize numeric attributes' values in distance calculation.\n\t(default DO normalization)", "N", 0, "-N"));
        vector.addElement(new Option("\tSet the number of neighbours used to set the kernel bandwidth.\n\t(default all)", "K", 1, "-K <number of neighbours>"));
        // Kernel 5 (Constant) is accepted by setWeightingKernel, so it is listed here as well.
        vector.addElement(new Option("\tSet the weighting kernel shape to use. 0=Linear, 1=Epanechnikov,\n\t2=Tricube, 3=Inverse, 4=Gaussian, 5=Constant.\n\t(default 0 = Linear)", "U", 1, "-U <number of weighting method>"));
        Enumeration enumeration = super.listOptions();
        while (enumeration.hasMoreElements()) {
            vector.addElement((Option) enumeration.nextElement());
        }
        return vector.elements();
    }

    /**
     * Parses {@code -K}, {@code -U} and {@code -N}, then passes the rest to the superclass.
     * A missing {@code -K} means "all neighbours"; a missing {@code -U} means the linear kernel.
     *
     * @param options the option array (consumed entries are blanked by Utils)
     * @throws Exception if an option cannot be parsed
     */
    public void setOptions(String[] options) throws Exception {
        String knn = Utils.getOption('K', options);
        this.setKNN(knn.length() != 0 ? Integer.parseInt(knn) : 0);
        String kernel = Utils.getOption('U', options);
        this.setWeightingKernel(kernel.length() != 0 ? Integer.parseInt(kernel) : 0);
        this.setDontNormalize(Utils.getFlag('N', options));
        super.setOptions(options);
    }

    /**
     * Returns the current settings followed by the superclass options.
     *
     * @return the option array; the {@code -N} slot is an empty string when the flag is unset
     */
    public String[] getOptions() {
        String[] superOptions = super.getOptions();
        String[] options = new String[superOptions.length + 5];
        int current = 0;
        options[current++] = "-U";
        options[current++] = String.valueOf(this.getWeightingKernel());
        options[current++] = "-K";
        options[current++] = String.valueOf(this.getKNN());
        options[current++] = this.getDontNormalize() ? "-N" : "";
        System.arraycopy(superOptions, 0, options, current, superOptions.length);
        return options;
    }

    /** @return tip text for the KNN property */
    public String KNNTipText() {
        return "How many neighbours are used to determine the width of the weighting function (<= 0 means all neighbours).";
    }

    /**
     * Sets the number of neighbours used to determine the bandwidth.
     *
     * @param n the number of neighbours; values &lt;= 0 select all neighbours (stored as 0)
     */
    public void setKNN(int n) {
        if (n <= 0) {
            this.m_kNN = 0;
            this.m_UseAllK = true;
        } else {
            this.m_kNN = n;
            this.m_UseAllK = false;
        }
    }

    /** @return the number of neighbours (0 means all, or -1 if never set) */
    public int getKNN() {
        return this.m_kNN;
    }

    /** @return tip text for the weightingKernel property */
    public String weightingKernelTipText() {
        return "Determines weighting function. [0 = Linear, 1 = Epanechnikov, 2 = Tricube, 3 = Inverse, 4 = Gaussian and 5 = Constant. (default 0 = Linear)].";
    }

    /**
     * Sets the weighting kernel. Values outside {@code LINEAR..CONSTANT} are silently ignored.
     *
     * @param n the kernel code
     */
    public void setWeightingKernel(int n) {
        if (n < LINEAR || n > CONSTANT) {
            return;
        }
        this.m_WeightKernel = n;
    }

    /** @return the current kernel code */
    public int getWeightingKernel() {
        return this.m_WeightKernel;
    }

    /** @return tip text for the dontNormalize property */
    public String dontNormalizeTipText() {
        return "Turns off normalization for attribute values in distance calculation.";
    }

    /** @return true if numeric attribute normalization is switched off */
    public boolean getDontNormalize() {
        return this.m_NoAttribNorm;
    }

    /** @param bl true to switch off numeric attribute normalization */
    public void setDontNormalize(boolean bl) {
        this.m_NoAttribNorm = bl;
    }

    /** @return the smallest value seen so far for attribute {@code n} (NaN if none) */
    protected double getAttributeMin(int n) {
        return this.m_Min[n];
    }

    /** @return the largest value seen so far for attribute {@code n} (NaN if none) */
    protected double getAttributeMax(int n) {
        return this.m_Max[n];
    }

    /**
     * Stores a copy of the training data (minus missing-class instances) and records the
     * per-attribute value ranges. The base classifier itself is built lazily at prediction time.
     *
     * @param instances the training data
     * @throws Exception if the base classifier cannot handle weights, no class is set, or
     *     string attributes are present
     */
    public void buildClassifier(Instances instances) throws Exception {
        if (!(this.m_Classifier instanceof WeightedInstancesHandler)) {
            throw new IllegalArgumentException("Classifier must be a WeightedInstancesHandler!");
        }
        if (instances.classIndex() < 0) {
            throw new Exception("No class attribute assigned to instances");
        }
        if (instances.checkForStringAttributes()) {
            throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
        }
        this.m_Train = new Instances(instances, 0, instances.numInstances());
        this.m_Train.deleteWithMissingClass();
        int numAttributes = this.m_Train.numAttributes();
        this.m_Min = new double[numAttributes];
        this.m_Max = new double[numAttributes];
        for (int a = 0; a < numAttributes; ++a) {
            this.m_Max[a] = Double.NaN;
            this.m_Min[a] = Double.NaN;
        }
        for (int i = 0; i < this.m_Train.numInstances(); ++i) {
            this.updateMinMax(this.m_Train.instance(i));
        }
    }

    /**
     * Adds a training instance (ignored if its class value is missing).
     *
     * @param instance the new instance
     * @throws Exception if its header does not match the training data
     */
    public void updateClassifier(Instance instance) throws Exception {
        if (!this.m_Train.equalHeaders(instance.dataset())) {
            throw new Exception("Incompatible instance types");
        }
        if (!instance.classIsMissing()) {
            this.updateMinMax(instance);
            this.m_Train.add(instance);
        }
    }

    /**
     * Builds the base classifier on the kernel-weighted training data and uses it to predict
     * {@code instance}. Side effect: the test instance also widens {@code m_Min}/{@code m_Max}.
     *
     * @param instance the instance to classify
     * @return the base classifier's distribution for {@code instance}
     * @throws Exception if there is no training data, all training instances coincide with
     *     {@code instance}, or the base classifier fails
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (this.m_Train.numInstances() == 0) {
            throw new Exception("No training instances!");
        }
        this.updateMinMax(instance);

        int numTrain = this.m_Train.numInstances();
        // k = 1-based rank of the neighbour whose distance becomes the bandwidth. It is clamped
        // to >= 1: with a single training instance numTrain - 1 is 0, which used to dereference
        // an empty heap (NullPointerException) or read index -1 of the sorted order.
        int k = Math.max(numTrain - 1, 1);
        MyHeap heap;
        if (!this.m_UseAllK && this.m_kNN > 0 && this.m_kNN < k) {
            k = this.m_kNN;
            heap = new MyHeap(k);
        } else {
            heap = new MyHeap(numTrain);
        }

        // Holds distances first and, after the kernel is applied, instance weights.
        double[] values = this.lwlDistances(instance, k, heap);
        int[] order = Utils.sort(values);
        if (this.m_Debug) {
            System.out.println("Instance Distances");
            for (int i = 0; i < order.length; ++i) {
                System.out.println("" + values[order[i]]);
            }
        }

        double bandwidth = this.lwlBandwidth(values, order, k);
        for (int i = 0; i < values.length; ++i) {
            values[i] = this.lwlKernelWeight(values[i] / bandwidth);
        }
        if (this.m_Debug) {
            System.out.println("Instance Weights");
            for (int i = 0; i < order.length; ++i) {
                System.out.println("" + values[order[i]]);
            }
        }

        Instances neighbourhood = this.lwlNeighbourhood(values, order);
        this.m_Classifier.buildClassifier(neighbourhood);
        if (this.m_Debug) {
            System.out.println("Classifying test instance: " + instance);
            System.out.println("Built base classifier:\n" + this.m_Classifier.toString());
        }
        return this.m_Classifier.distributionForInstance(instance);
    }

    /**
     * Computes the distance from {@code instance} to every training instance.
     *
     * <p>For the linear, Epanechnikov and tricube kernels, a max-heap keeps the k smallest
     * distances seen so far. Once it holds k entries, each further computation is abandoned
     * with {@code Double.MAX_VALUE} as soon as it exceeds the heap's largest distance. Those
     * kernels give weight 0 beyond the bandwidth, so the exact value is not needed.
     */
    private double[] lwlDistances(Instance instance, int k, MyHeap heap) throws Exception {
        boolean prune = this.m_WeightKernel == LINEAR
                || this.m_WeightKernel == EPANECHNIKOV
                || this.m_WeightKernel == TRICUBE;
        double[] distances = new double[this.m_Train.numInstances()];
        for (int i = 0; i < distances.length; ++i) {
            Instance train = this.m_Train.instance(i);
            if (!prune) {
                distances[i] = this.distance(instance, train);
            } else if (i < k) {
                distances[i] = this.distance(instance, train);
                heap.put(i, distances[i]);
            } else {
                MyHeapElement farthest = heap.peek();
                distances[i] = this.distance(instance, train, farthest.distance);
                if (distances[i] < farthest.distance) {
                    heap.get();
                    heap.put(i, distances[i]);
                }
            }
        }
        return distances;
    }

    /**
     * Returns the distance of the k-th nearest instance or, if that is zero, the first larger
     * distance further along the sorted order.
     * NOTE(review): when pruning was used, instances past the heap threshold hold
     * Double.MAX_VALUE rather than their true distance, so this fallback may return it.
     */
    private double lwlBandwidth(double[] distances, int[] order, int k) throws Exception {
        double bandwidth = distances[order[k - 1]];
        if (bandwidth <= 0.0) {
            for (int i = k; i < order.length; ++i) {
                if (distances[order[i]] > bandwidth) {
                    bandwidth = distances[order[i]];
                    break;
                }
            }
            if (bandwidth <= 0.0) {
                throw new Exception("All training instances coincide with test instance!");
            }
        }
        return bandwidth;
    }

    /**
     * Maps a bandwidth-scaled distance to a weight using the current kernel. An unknown kernel
     * code leaves the scaled distance unchanged.
     */
    private double lwlKernelWeight(double x) {
        switch (this.m_WeightKernel) {
            case LINEAR:
                return Math.max(1.0001 - x, 0.0);
            case EPANECHNIKOV:
                return x <= 1.0 ? 0.75 * (1.0001 - x * x) : 0.0;
            case TRICUBE:
                return x <= 1.0 ? Math.pow(1.0001 - Math.pow(x, 3.0), 3.0) : 0.0;
            case INVERSE:
                return 1.0 / (1.0 + x);
            case GAUSS:
                return Math.exp(-x * x);
            case CONSTANT:
                return x <= 1.0 ? 1.0 : 0.0;
            default:
                return x;
        }
    }

    /**
     * Copies training instances in increasing-distance order, multiplying each weight by its
     * kernel weight, and stops at the first kernel weight below 1e-20. Weights are then
     * rescaled so that their sum equals the sum of the original weights of the kept instances.
     */
    private Instances lwlNeighbourhood(double[] weights, int[] order) {
        Instances neighbourhood = new Instances(this.m_Train, 0);
        double sumOriginal = 0.0;
        double sumWeighted = 0.0;
        for (int i = 0; i < order.length; ++i) {
            double kernelWeight = weights[order[i]];
            if (kernelWeight < 1.0E-20) {
                break;
            }
            Instance copy = (Instance) this.m_Train.instance(order[i]).copy();
            sumOriginal += copy.weight();
            sumWeighted += copy.weight() * kernelWeight;
            copy.setWeight(copy.weight() * kernelWeight);
            neighbourhood.add(copy);
        }
        if (this.m_Debug) {
            System.out.println("Kept " + neighbourhood.numInstances() + " out of " + this.m_Train.numInstances() + " instances");
        }
        for (int i = 0; i < neighbourhood.numInstances(); ++i) {
            Instance kept = neighbourhood.instance(i);
            kept.setWeight(kept.weight() * sumOriginal / sumWeighted);
        }
        return neighbourhood;
    }

    /** @return a textual description of the learner's configuration */
    public String toString() {
        if (this.m_Train == null) {
            return "Locally weighted learning: No model built yet.";
        }
        StringBuilder text = new StringBuilder("Locally weighted learning\n===========================\n");
        text.append("Using classifier: ").append(this.m_Classifier.getClass().getName()).append("\n");
        switch (this.m_WeightKernel) {
            case LINEAR:
                text.append("Using linear weighting kernels\n");
                break;
            case EPANECHNIKOV:
                text.append("Using epanechnikov weighting kernels\n");
                break;
            case TRICUBE:
                text.append("Using tricube weighting kernels\n");
                break;
            case INVERSE:
                text.append("Using inverse-distance weighting kernels\n");
                break;
            case GAUSS:
                text.append("Using gaussian weighting kernels\n");
                break;
            case CONSTANT:
                text.append("Using constant weighting kernels\n");
                break;
            default:
                break;
        }
        text.append("Using ").append(this.m_UseAllK ? "all" : "" + this.m_kNN).append(" neighbours");
        return text.toString();
    }

    /** Distance without early termination (cut-off sqrt(Double.MAX_VALUE)). */
    private double distance(Instance instance, Instance instance2) throws Exception {
        return this.distance(instance, instance2, Math.sqrt(Double.MAX_VALUE));
    }

    /** Distance that returns Double.MAX_VALUE once it is known to exceed {@code cutOff}. */
    private double distance(Instance instance, Instance instance2, double cutOff) throws Exception {
        return this.euclideanDistance(instance, instance2, cutOff);
    }

    /**
     * Euclidean distance over all non-class attributes. It merges the (possibly sparse) value
     * lists of both instances; an index present in only one of them is compared against 0.
     * Returns Double.MAX_VALUE as soon as the running sum of squares exceeds cutOff^2.
     */
    private double euclideanDistance(Instance first, Instance second, double cutOff) {
        int numAttributes = this.m_Train.numAttributes();
        int classIndex = this.m_Train.classIndex();
        double cutOffSquared = cutOff * cutOff;
        double sumSquared = 0.0;
        int p1 = 0;
        int p2 = 0;
        while (p1 < first.numValues() || p2 < second.numValues()) {
            int index1 = p1 >= first.numValues() ? numAttributes : first.index(p1);
            int index2 = p2 >= second.numValues() ? numAttributes : second.index(p2);
            if (index1 == classIndex) {
                ++p1;
                continue;
            }
            if (index2 == classIndex) {
                ++p2;
                continue;
            }
            double diff;
            if (index1 == index2) {
                diff = this.difference(index1, first.valueSparse(p1), second.valueSparse(p2));
                ++p1;
                ++p2;
            } else if (index1 > index2) {
                diff = this.difference(index2, 0.0, second.valueSparse(p2));
                ++p2;
            } else {
                diff = this.difference(index1, first.valueSparse(p1), 0.0);
                ++p1;
            }
            sumSquared += diff * diff;
            if (sumSquared > cutOffSquared) {
                return Double.MAX_VALUE;
            }
        }
        return Math.sqrt(sumSquared);
    }

    /**
     * Per-attribute difference. Nominal: 0 if equal, 1 if different or either is missing.
     * Numeric: difference of (optionally range-normalised) values; with one value missing,
     * the missing one is assumed to lie at the farther end of the observed range.
     * Other attribute types contribute 0.
     */
    private double difference(int index, double val1, double val2) {
        switch (this.m_Train.attribute(index).type()) {
            case ATTRIBUTE_NOMINAL:
                if (Instance.isMissingValue(val1) || Instance.isMissingValue(val2) || (int) val1 != (int) val2) {
                    return 1.0;
                }
                return 0.0;
            case ATTRIBUTE_NUMERIC:
                if (Instance.isMissingValue(val1) || Instance.isMissingValue(val2)) {
                    if (Instance.isMissingValue(val1) && Instance.isMissingValue(val2)) {
                        return this.m_NoAttribNorm ? this.m_Max[index] - this.m_Min[index] : 1.0;
                    }
                    double present = Instance.isMissingValue(val2) ? val1 : val2;
                    if (!this.m_NoAttribNorm) {
                        double scaled = this.norm(present, index);
                        return scaled < 0.5 ? 1.0 - scaled : scaled;
                    }
                    if (this.m_Max[index] - present > present - this.m_Min[index]) {
                        return this.m_Max[index] - present;
                    }
                    return present - this.m_Min[index];
                }
                return !this.m_NoAttribNorm ? this.norm(val1, index) - this.norm(val2, index) : val1 - val2;
            default:
                return 0.0;
        }
    }

    /** Scales {@code value} by the observed range of attribute {@code index}; 0 if the range is unknown or empty. */
    private double norm(double value, int index) {
        if (Double.isNaN(this.m_Min[index]) || Utils.eq(this.m_Max[index], this.m_Min[index])) {
            return 0.0;
        }
        return (value - this.m_Min[index]) / (this.m_Max[index] - this.m_Min[index]);
    }

    /** Widens m_Min/m_Max with the non-missing values of {@code instance}. */
    private void updateMinMax(Instance instance) {
        for (int i = 0; i < this.m_Train.numAttributes(); ++i) {
            if (instance.isMissing(i)) {
                continue;
            }
            double value = instance.value(i);
            if (Double.isNaN(this.m_Min[i])) {
                this.m_Min[i] = value;
                this.m_Max[i] = value;
            } else if (value < this.m_Min[i]) {
                this.m_Min[i] = value;
            } else if (value > this.m_Max[i]) {
                this.m_Max[i] = value;
            }
        }
    }

    /**
     * Command-line entry point: evaluates LWL with the given options.
     *
     * @param stringArray the command-line options
     */
    public static void main(String[] stringArray) {
        try {
            System.out.println(Evaluation.evaluateModel(new LWL(), stringArray));
        }
        catch (Exception exception) {
            exception.printStackTrace();
            System.err.println(exception.getMessage());
        }
    }

    /** A (training-instance index, distance) pair stored in {@link MyHeap}. */
    private static class MyHeapElement {
        int index;
        double distance;

        public MyHeapElement(int n, double d) {
            this.distance = d;
            this.index = n;
        }
    }

    /**
     * A fixed-capacity max-heap on distance, stored 1-based in an array. Slot 0 is a sentinel
     * whose {@code index} field holds the current number of elements.
     */
    private static class MyHeap {
        MyHeapElement[] m_heap = null;

        public MyHeap(int n) {
            if (n % 2 == 0) {
                ++n;
            }
            this.m_heap = new MyHeapElement[n + 1];
            this.m_heap[0] = new MyHeapElement(0, 0.0);
        }

        /** @return the number of elements in the heap */
        public int size() {
            return this.m_heap[0].index;
        }

        /** @return the element with the largest distance, or null if the heap is empty */
        public MyHeapElement peek() {
            return this.m_heap[1];
        }

        /** Removes and returns the element with the largest distance. */
        public MyHeapElement get() throws Exception {
            int size = this.m_heap[0].index;
            if (size == 0) {
                throw new Exception("No elements present in the heap");
            }
            MyHeapElement top = this.m_heap[1];
            this.m_heap[1] = this.m_heap[size];
            this.m_heap[size] = null; // do not leave a stale reference past the end
            this.m_heap[0].index = size - 1;
            this.downheap();
            return top;
        }

        /** Inserts a new (index, distance) element. */
        public void put(int n, double d) throws Exception {
            if (this.m_heap[0].index + 1 > this.m_heap.length - 1) {
                throw new Exception("the number of elements cannot exceed the initially set maximum limit");
            }
            ++this.m_heap[0].index;
            this.m_heap[this.m_heap[0].index] = new MyHeapElement(n, d);
            this.upheap();
        }

        /** Restores heap order by moving the last element up. */
        private void upheap() {
            int n = this.m_heap[0].index;
            while (n > 1 && this.m_heap[n].distance > this.m_heap[n / 2].distance) {
                MyHeapElement tmp = this.m_heap[n];
                this.m_heap[n] = this.m_heap[n / 2];
                this.m_heap[n /= 2] = tmp;
            }
        }

        /**
         * Restores heap order by moving the root down. Only children within the current size
         * are compared; on equal children the right child is chosen.
         */
        private void downheap() {
            int size = this.m_heap[0].index;
            int parent = 1;
            while (2 * parent <= size) {
                int child = 2 * parent;
                if (child + 1 <= size && this.m_heap[child + 1].distance >= this.m_heap[child].distance) {
                    child = child + 1;
                }
                if (!(this.m_heap[parent].distance < this.m_heap[child].distance)) {
                    break;
                }
                MyHeapElement tmp = this.m_heap[parent];
                this.m_heap[parent] = this.m_heap[child];
                this.m_heap[child] = tmp;
                parent = child;
            }
        }
    }
}

