/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.mi;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.MultiInstanceCapabilitiesHandler;
import weka.core.Optimization;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;

/**
 * Multi-instance logistic regression ("Modified Logistic Regression").
 *
 * <p>Each training example is a bag whose instances are stored in the relational
 * attribute at index 1. Instance-level scores come from a linear logistic model
 * (intercept plus one coefficient per bag attribute). They are combined into a
 * bag-level posterior in one of three ways, selected by {@code -A}:
 * <ul>
 *   <li>standard MI assumption: a bag is negative only if all its instances are
 *       negative;</li>
 *   <li>collective assumption, arithmetic mean of the instance posteriors;</li>
 *   <li>collective assumption, sigmoid of the mean instance score ("geometric"
 *       mean).</li>
 * </ul>
 * The coefficients minimise the negative bag log-likelihood plus a ridge penalty
 * on all coefficients except the intercept. Training data is standardised
 * internally, and the coefficients are mapped back to the original attribute
 * scale after fitting.
 */
public class MILR
extends Classifier
implements OptionHandler,
MultiInstanceCapabilitiesHandler {
    static final long serialVersionUID = 1996101190172373826L;

    /** Fitted coefficients: index 0 is the intercept, index k belongs to bag attribute k-1. */
    protected double[] m_Par;

    /** Number of class values in the training data. */
    protected int m_NumClasses;

    /** Ridge penalty applied to every coefficient except the intercept. */
    protected double m_Ridge = 1.0E-6;

    /** Class index of every training bag. */
    protected int[] m_Classes;

    /** Training values indexed as [bag][attribute][instance within bag]. */
    protected double[][][] m_Data;

    /** Header of the relational (bag) attribute; used for attribute names. */
    protected Instances m_Attributes;

    /** Per-attribute location used for standardisation and for imputing missing values. */
    protected double[] xMean = null;

    /** Per-attribute spread used for standardisation; 0 means "leave unstandardised". */
    protected double[] xSD = null;

    /** Selected bag-level model; one of the ALGORITHMTYPE_* constants. */
    protected int m_AlgorithmType = ALGORITHMTYPE_DEFAULT;

    public static final int ALGORITHMTYPE_DEFAULT = 0;
    public static final int ALGORITHMTYPE_ARITHMETIC = 1;
    public static final int ALGORITHMTYPE_GEOMETRIC = 2;
    public static final Tag[] TAGS_ALGORITHMTYPE = new Tag[]{new Tag(0, "standard MI assumption"), new Tag(1, "collective MI assumption, arithmetic mean for posteriors"), new Tag(2, "collective MI assumption, geometric mean for posteriors")};

    /**
     * Returns a short description of the classifier for the GUI.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "Uses either standard or collective multi-instance assumption, but within linear regression. For the collective assumption, it offers arithmetic or geometric mean for the posteriors.";
    }

    /**
     * Lists the command-line options: -D (debug), -R (ridge), -A (algorithm type).
     *
     * @return an enumeration of {@link Option}s
     */
    public Enumeration listOptions() {
        Vector<Option> vector = new Vector<Option>();
        vector.addElement(new Option("\tTurn on debugging output.", "D", 0, "-D"));
        vector.addElement(new Option("\tSet the ridge in the log-likelihood.", "R", 1, "-R <ridge>"));
        vector.addElement(new Option("\tDefines the type of algorithm:\n\t 0. standard MI assumption\n\t 1. collective MI assumption, arithmetic mean for posteriors\n\t 2. collective MI assumption, geometric mean for posteriors", "A", 1, "-A [0|1|2]"));
        return vector.elements();
    }

    /**
     * Parses the options. A missing -R resets the ridge to 1.0E-6, and a missing -A
     * resets the algorithm to the standard MI assumption.
     *
     * @param stringArray the options; recognised entries are consumed
     * @throws Exception if a numeric option value cannot be parsed
     */
    public void setOptions(String[] stringArray) throws Exception {
        this.setDebug(Utils.getFlag('D', stringArray));
        String ridge = Utils.getOption('R', stringArray);
        this.setRidge(ridge.length() != 0 ? Double.parseDouble(ridge) : 1.0E-6);
        String type = Utils.getOption('A', stringArray);
        int typeId = type.length() != 0 ? Integer.parseInt(type) : ALGORITHMTYPE_DEFAULT;
        this.setAlgorithmType(new SelectedTag(typeId, TAGS_ALGORITHMTYPE));
    }

    /**
     * Returns the current settings in a form accepted by {@link #setOptions}.
     *
     * @return the option strings
     */
    public String[] getOptions() {
        Vector<String> vector = new Vector<String>();
        if (this.getDebug()) {
            vector.add("-D");
        }
        vector.add("-R");
        vector.add("" + this.getRidge());
        vector.add("-A");
        vector.add("" + this.m_AlgorithmType);
        return vector.toArray(new String[vector.size()]);
    }

    /** @return GUI tip text for the ridge property */
    public String ridgeTipText() {
        return "The ridge in the log-likelihood.";
    }

    /** @param d the ridge penalty applied to the non-intercept coefficients */
    public void setRidge(double d) {
        this.m_Ridge = d;
    }

    /** @return the ridge penalty */
    public double getRidge() {
        return this.m_Ridge;
    }

    /** @return GUI tip text for the algorithm-type property */
    public String algorithmTypeTipText() {
        return "The mean type for the posteriors.";
    }

    /** @return the selected algorithm type as a tag of {@link #TAGS_ALGORITHMTYPE} */
    public SelectedTag getAlgorithmType() {
        return new SelectedTag(this.m_AlgorithmType, TAGS_ALGORITHMTYPE);
    }

    /**
     * Selects the bag-level model. Tags that do not come from
     * {@link #TAGS_ALGORITHMTYPE} (identity comparison) are ignored.
     *
     * @param selectedTag the requested algorithm type
     */
    public void setAlgorithmType(SelectedTag selectedTag) {
        if (selectedTag.getTags() == TAGS_ALGORITHMTYPE) {
            this.m_AlgorithmType = selectedTag.getSelectedTag().getID();
        }
    }

    /**
     * Bag-level capabilities: a nominal bag id, the relational bag attribute, a
     * binary class, and multi-instance data only.
     *
     * @return the capabilities of this classifier
     */
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.disableAll();
        capabilities.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.RELATIONAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.MISSING_VALUES);
        capabilities.enable(Capabilities.Capability.BINARY_CLASS);
        capabilities.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        capabilities.enable(Capabilities.Capability.ONLY_MULTIINSTANCE);
        return capabilities;
    }

    /**
     * Capabilities required of the data inside a bag (no class attribute).
     *
     * @return the capabilities for the bag contents
     */
    public Capabilities getMultiInstanceCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.disableAll();
        capabilities.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.MISSING_VALUES);
        capabilities.disableAllClasses();
        capabilities.enable(Capabilities.Capability.NO_CLASS);
        return capabilities;
    }

    /**
     * Fits the model.
     *
     * <p>Steps: extract the bag contents (attribute 1) into {@link #m_Data};
     * compute per-attribute location and spread; standardise the data and impute
     * missing values; minimise the penalised negative log-likelihood; finally map
     * the coefficients back to the original attribute scale.
     *
     * @param instances the training bags (not modified; a copy is used)
     * @throws Exception if the data fails the capability test or the optimiser fails
     */
    public void buildClassifier(Instances instances) throws Exception {
        this.getCapabilities().testWithFail(instances);
        instances = new Instances(instances);
        instances.deleteWithMissingClass();

        this.m_NumClasses = instances.numClasses();
        this.m_Attributes = instances.attribute(1).relation();
        int numAttributes = this.m_Attributes.numAttributes();
        int numBags = instances.numInstances();
        this.m_Data = new double[numBags][numAttributes][];
        this.m_Classes = new int[numBags];
        this.xMean = new double[numAttributes];
        this.xSD = new double[numAttributes];

        // bagsWithoutValue[k]: bags in which attribute k is missing for every instance
        int[] bagsWithoutValue = new int[numAttributes];
        double numPositive = 0.0;
        double numNegative = 0.0;

        if (this.m_Debug) {
            System.out.println("Extracting data...");
        }
        for (int b = 0; b < numBags; ++b) {
            Instance bag = instances.instance(b);
            this.m_Classes[b] = (int) bag.classValue();
            Instances members = bag.relationalValue(1);
            int bagSize = members.numInstances();
            for (int k = 0; k < numAttributes; ++k) {
                double[] column = new double[bagSize];
                this.m_Data[b][k] = column;
                double sum = 0.0;
                double sumSquares = 0.0;
                double observed = 0.0;
                for (int i = 0; i < bagSize; ++i) {
                    Instance member = members.instance(i);
                    if (member.isMissing(k)) {
                        column[i] = Double.NaN;
                        continue;
                    }
                    double value = member.value(k);
                    column[i] = value;
                    sum += value;
                    sumSquares += value * value;
                    observed += 1.0;
                }
                if (observed > 0.0) {
                    // accumulate this bag's mean and mean square for attribute k
                    this.xMean[k] += sum / observed;
                    this.xSD[k] += sumSquares / observed;
                } else {
                    bagsWithoutValue[k]++;
                }
            }
            if (this.m_Classes[b] == 1) {
                numPositive += 1.0;
            } else {
                numNegative += 1.0;
            }
        }

        for (int k = 0; k < numAttributes; ++k) {
            double n = (double) (numBags - bagsWithoutValue[k]);
            if (n < 1.0) {
                // never observed: avoid 0/0; missing values are imputed as 0 below
                this.xMean[k] = 0.0;
                this.xSD[k] = 0.0;
                continue;
            }
            this.xMean[k] = this.xMean[k] / n;
            if (n < 2.0) {
                // the (n - 1) denominator would be 0; leave the attribute unstandardised
                this.xSD[k] = 0.0;
                continue;
            }
            this.xSD[k] = Math.sqrt(Math.abs(this.xSD[k] / (n - 1.0) - this.xMean[k] * this.xMean[k] * n / (n - 1.0)));
        }

        if (this.m_Debug) {
            System.out.println("Descriptives...");
            System.out.println(numNegative + " bags have class 0 and " + numPositive + " bags have class 1");
            System.out.println("\n Variable     Avg       SD    ");
            for (int k = 0; k < numAttributes; ++k) {
                System.out.println(Utils.doubleToString(k, 8, 4) + Utils.doubleToString(this.xMean[k], 10, 4) + Utils.doubleToString(this.xSD[k], 10, 4));
            }
        }

        // Standardise and impute. Missing values must always be replaced, including for
        // attributes that are not standardised (spread 0). Otherwise NaN reaches the
        // likelihood. The imputed value is the mean, matching distributionForInstance.
        for (int b = 0; b < numBags; ++b) {
            for (int k = 0; k < numAttributes; ++k) {
                double[] column = this.m_Data[b][k];
                boolean standardise = this.xSD[k] != 0.0;
                for (int i = 0; i < column.length; ++i) {
                    if (Double.isNaN(column[i])) {
                        column[i] = standardise ? 0.0 : this.xMean[k];
                    } else if (standardise) {
                        column[i] = (column[i] - this.xMean[k]) / this.xSD[k];
                    }
                }
            }
        }

        if (this.m_Debug) {
            System.out.println("\nIteration History...");
        }
        // Start from intercept = smoothed log prior odds and zero coefficients; no bounds.
        double[] start = new double[numAttributes + 1];
        start[0] = Math.log((numPositive + 1.0) / (numNegative + 1.0));
        double[][] bounds = new double[2][start.length];
        for (int k = 0; k < start.length; ++k) {
            bounds[0][k] = Double.NaN;
            bounds[1][k] = Double.NaN;
        }
        OptEng optEng = new OptEng(this.m_AlgorithmType);
        optEng.setDebug(this.m_Debug);
        this.m_Par = optEng.findArgmin(start, bounds);
        // findArgmin returns null when it stops at its iteration limit; resume from there
        while (this.m_Par == null) {
            this.m_Par = optEng.getVarbValues();
            if (this.m_Debug) {
                System.out.println("200 iterations finished, not enough!");
            }
            this.m_Par = optEng.findArgmin(this.m_Par, bounds);
        }
        if (this.m_Debug) {
            System.out.println(" -------------<Converged>--------------");
        }

        // Diagnostic output only; it needs the coefficients on the standardised scale.
        if (this.m_Debug && this.m_AlgorithmType == ALGORITHMTYPE_ARITHMETIC) {
            this.printAttributeImportance();
        }

        // Map coefficients from the standardised scale back to the original one.
        for (int k = 1; k <= numAttributes; ++k) {
            if (this.xSD[k - 1] == 0.0) {
                continue;
            }
            this.m_Par[k] /= this.xSD[k - 1];
            this.m_Par[0] -= this.m_Par[k] * this.xMean[k - 1];
        }
    }

    /**
     * Prints each bag attribute with its |coefficient| as a percentage of the
     * largest |coefficient|, largest first. Call it before the coefficients are
     * rescaled.
     */
    private void printAttributeImportance() {
        int numAttributes = this.m_Par.length - 1;
        if (numAttributes < 1) {
            return;
        }
        double[] magnitude = new double[numAttributes];
        for (int k = 0; k < numAttributes; ++k) {
            magnitude[k] = Math.abs(this.m_Par[k + 1]);
        }
        int[] order = Utils.sort(magnitude);
        double largest = magnitude[order[order.length - 1]];
        for (int r = order.length - 1; r >= 0; --r) {
            double percent = largest > 0.0 ? magnitude[order[r]] * 100.0 / largest : 0.0;
            System.out.println(this.m_Attributes.attribute(order[r]).name() + "\t" + percent);
        }
    }

    /**
     * Computes the class distribution of a bag. Missing instance values are
     * replaced by the training mean of their attribute.
     *
     * @param instance the bag; its contents are read from attribute 1
     * @return {P(class 0), P(class 1)}
     * @throws Exception if the bag cannot be read
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        Instances members = instance.relationalValue(1);
        int bagSize = members.numInstances();
        int numAttributes = members.numAttributes();

        // linear score (intercept + coefficients . features) of every instance
        double[] scores = new double[bagSize];
        for (int i = 0; i < bagSize; ++i) {
            Instance member = members.instance(i);
            double[] features = new double[numAttributes + 1];
            features[0] = 1.0;
            for (int k = 0; k < numAttributes; ++k) {
                features[k + 1] = member.isMissing(k) ? this.xMean[k] : member.value(k);
            }
            double score = 0.0;
            for (int k = 0; k < this.m_Par.length; ++k) {
                score += this.m_Par[k] * features[k];
            }
            scores[i] = score;
        }

        double[] distribution = new double[2];
        switch (this.m_AlgorithmType) {
            case ALGORITHMTYPE_DEFAULT: {
                // negative only if every instance is negative: prod 1 / (1 + e^score)
                double logAllNegative = 0.0;
                for (int i = 0; i < bagSize; ++i) {
                    logAllNegative -= Math.log(1.0 + Math.exp(scores[i]));
                }
                distribution[0] = Math.exp(logAllNegative);
                distribution[1] = 1.0 - distribution[0];
                break;
            }
            case ALGORITHMTYPE_ARITHMETIC: {
                // arithmetic mean of the instance probabilities of class 0
                double sumNegative = 0.0;
                for (int i = 0; i < bagSize; ++i) {
                    sumNegative += 1.0 / (1.0 + Math.exp(scores[i]));
                }
                distribution[0] = sumNegative / (double) bagSize;
                distribution[1] = 1.0 - distribution[0];
                break;
            }
            case ALGORITHMTYPE_GEOMETRIC: {
                // sigmoid of the mean instance score
                double meanScore = 0.0;
                for (int i = 0; i < bagSize; ++i) {
                    meanScore += scores[i] / (double) bagSize;
                }
                distribution[1] = 1.0 / (1.0 + Math.exp(-meanScore));
                distribution[0] = 1.0 - distribution[1];
                break;
            }
            default:
                break;
        }
        return distribution;
    }

    /**
     * Describes the fitted model: mean type, coefficients, intercept and odds
     * ratios, one variable per line.
     *
     * @return the model description, or a notice if no model is built
     */
    public String toString() {
        String title = "Modified Logistic Regression";
        if (this.m_Par == null) {
            return title + ": No model built yet.";
        }
        StringBuilder out = new StringBuilder(title);
        out.append("\nMean type: ").append(this.getAlgorithmType().getSelectedTag().getReadable()).append("\n");
        out.append("\nCoefficients...\nVariable      Coeff.\n");
        for (int k = 1; k < this.m_Par.length; ++k) {
            out.append(this.m_Attributes.attribute(k - 1).name());
            out.append(" ").append(Utils.doubleToString(this.m_Par[k], 12, 4));
            out.append("\n");
        }
        out.append("Intercept:");
        out.append(" ").append(Utils.doubleToString(this.m_Par[0], 10, 4));
        out.append("\n");
        out.append("\nOdds Ratios...\nVariable         O.R.\n");
        for (int k = 1; k < this.m_Par.length; ++k) {
            double oddsRatio = Math.exp(this.m_Par[k]);
            out.append(" ").append(this.m_Attributes.attribute(k - 1).name());
            // very large ratios are printed in full instead of fixed-width
            out.append(" ").append(oddsRatio > 1.0E10 ? "" + oddsRatio : Utils.doubleToString(oddsRatio, 12, 4));
            out.append("\n");
        }
        return out.toString();
    }

    /** @return the revision string */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 5527 $");
    }

    /**
     * Command-line entry point.
     *
     * @param stringArray the classifier options
     */
    public static void main(String[] stringArray) {
        MILR.runClassifier(new MILR(), stringArray);
    }

    /**
     * Optimisation problem: the ridge-penalised negative bag log-likelihood of the
     * (standardised) training data in {@link MILR#m_Data}, and its gradient.
     * Parameter index 0 is the intercept; index k is the coefficient of attribute k-1.
     */
    private class OptEng
    extends Optimization {
        /** Bag-level model; one of the MILR.ALGORITHMTYPE_* constants. */
        private final int m_Type;

        public OptEng(int type) {
            this.m_Type = type;
        }

        /** Linear score of instance i in the bag: sum of features times coefficients, plus the intercept. */
        private double linearScore(double[][] bag, int i, double[] par) {
            double score = 0.0;
            for (int k = bag.length - 1; k >= 0; --k) {
                score += bag[k][i] * par[k + 1];
            }
            return score + par[0];
        }

        /** Feature k of instance i, where feature 0 is the constant 1 of the intercept. */
        private double feature(double[][] bag, int k, int i) {
            return k > 0 ? bag[k - 1][i] : 1.0;
        }

        protected double objectiveFunction(double[] par) {
            double nll = 0.0;
            for (int b = 0; b < MILR.this.m_Classes.length; ++b) {
                double[][] bag = MILR.this.m_Data[b];
                boolean positive = MILR.this.m_Classes[b] == 1;
                switch (this.m_Type) {
                    case MILR.ALGORITHMTYPE_DEFAULT:
                        nll += this.standardNll(bag, positive, par);
                        break;
                    case MILR.ALGORITHMTYPE_ARITHMETIC:
                        nll += this.arithmeticNll(bag, positive, par);
                        break;
                    case MILR.ALGORITHMTYPE_GEOMETRIC:
                        nll += this.geometricNll(bag, positive, par);
                        break;
                    default:
                        break;
                }
            }
            // ridge penalty; the intercept (index 0) is not penalised
            for (int k = 1; k < par.length; ++k) {
                nll += MILR.this.m_Ridge * par[k] * par[k];
            }
            return nll;
        }

        /** Bag NLL under the standard MI assumption: P(negative) = prod 1 / (1 + e^score). */
        private double standardNll(double[][] bag, boolean positive, double[] par) {
            int size = bag[0].length;
            double negativeNll = 0.0;
            double logAllNegative = 0.0;
            for (int i = 0; i < size; ++i) {
                double odds = Math.exp(this.linearScore(bag, i, par));
                if (positive) {
                    logAllNegative -= Math.log(1.0 + odds);
                } else {
                    negativeNll += Math.log(1.0 + odds);
                }
            }
            return positive ? -Math.log(1.0 - Math.exp(logAllNegative)) : negativeNll;
        }

        /** Bag NLL using the arithmetic mean of the instance probabilities of the bag's label. */
        private double arithmeticNll(double[][] bag, boolean positive, double[] par) {
            int size = bag[0].length;
            double meanProb = 0.0;
            for (int i = 0; i < size; ++i) {
                double odds = Math.exp(this.linearScore(bag, i, par));
                if (positive) {
                    meanProb += 1.0 - 1.0 / (1.0 + odds);
                } else {
                    meanProb += 1.0 / (1.0 + odds);
                }
            }
            meanProb /= (double) size;
            return -Math.log(meanProb);
        }

        /** Bag NLL when the bag posterior is the sigmoid of the mean instance score. */
        private double geometricNll(double[][] bag, boolean positive, double[] par) {
            int size = bag[0].length;
            double signedMean = 0.0;
            for (int i = 0; i < size; ++i) {
                double score = this.linearScore(bag, i, par);
                if (positive) {
                    signedMean -= score / (double) size;
                } else {
                    signedMean += score / (double) size;
                }
            }
            return Math.log(1.0 + Math.exp(signedMean));
        }

        protected double[] evaluateGradient(double[] par) {
            double[] grad = new double[par.length];
            for (int b = 0; b < MILR.this.m_Classes.length; ++b) {
                double[][] bag = MILR.this.m_Data[b];
                boolean positive = MILR.this.m_Classes[b] == 1;
                switch (this.m_Type) {
                    case MILR.ALGORITHMTYPE_DEFAULT:
                        this.addStandardGradient(bag, positive, par, grad);
                        break;
                    case MILR.ALGORITHMTYPE_ARITHMETIC:
                        this.addArithmeticGradient(bag, positive, par, grad);
                        break;
                    case MILR.ALGORITHMTYPE_GEOMETRIC:
                        this.addGeometricGradient(bag, positive, par, grad);
                        break;
                    default:
                        break;
                }
            }
            // derivative of the ridge penalty
            for (int k = 1; k < par.length; ++k) {
                grad[k] += 2.0 * MILR.this.m_Ridge * par[k];
            }
            return grad;
        }

        /** Adds the gradient of {@link #standardNll} for one bag to grad. */
        private void addStandardGradient(double[][] bag, boolean positive, double[] par, double[] grad) {
            int size = bag[0].length;
            double logInvAllNegative = 0.0;
            double[] weighted = new double[grad.length];
            for (int i = 0; i < size; ++i) {
                double odds = Math.exp(this.linearScore(bag, i, par));
                double p = odds / (1.0 + odds);
                if (positive) {
                    logInvAllNegative -= Math.log(1.0 - p);
                }
                for (int k = 0; k < par.length; ++k) {
                    weighted[k] += this.feature(bag, k, i) * p;
                }
            }
            double invAllNegative = Math.exp(logInvAllNegative);
            for (int k = 0; k < grad.length; ++k) {
                if (positive) {
                    grad[k] -= weighted[k] / (invAllNegative - 1.0);
                } else {
                    grad[k] += weighted[k];
                }
            }
        }

        /** Adds the gradient of {@link #arithmeticNll} for one bag to grad. */
        private void addArithmeticGradient(double[][] bag, boolean positive, double[] par, double[] grad) {
            int size = bag[0].length;
            double sumProb = 0.0;
            double[] weighted = new double[par.length];
            for (int i = 0; i < size; ++i) {
                double odds = Math.exp(this.linearScore(bag, i, par));
                if (positive) {
                    sumProb += odds / (1.0 + odds);
                } else {
                    sumProb += 1.0 / (1.0 + odds);
                }
                for (int k = 0; k < par.length; ++k) {
                    weighted[k] += this.feature(bag, k, i) * odds / ((1.0 + odds) * (1.0 + odds));
                }
            }
            for (int k = 0; k < grad.length; ++k) {
                if (positive) {
                    grad[k] -= weighted[k] / sumProb;
                } else {
                    grad[k] += weighted[k] / sumProb;
                }
            }
        }

        /** Adds the gradient of {@link #geometricNll} for one bag to grad. */
        private void addGeometricGradient(double[][] bag, boolean positive, double[] par, double[] grad) {
            int size = bag[0].length;
            double signedMean = 0.0;
            double[] meanFeature = new double[par.length];
            for (int i = 0; i < size; ++i) {
                double score = this.linearScore(bag, i, par);
                if (positive) {
                    signedMean -= score / (double) size;
                    for (int k = 0; k < grad.length; ++k) {
                        meanFeature[k] -= this.feature(bag, k, i) / (double) size;
                    }
                } else {
                    signedMean += score / (double) size;
                    for (int k = 0; k < grad.length; ++k) {
                        meanFeature[k] += this.feature(bag, k, i) / (double) size;
                    }
                }
            }
            double e = Math.exp(signedMean);
            for (int k = 0; k < par.length; ++k) {
                grad[k] += e * meanFeature[k] / (1.0 + e);
            }
        }

        public String getRevision() {
            return RevisionUtils.extract("$Revision: 5527 $");
        }
    }
}

