/*
 * Decompiled with CFR 0.152.
 */
package weka.estimators;

import java.io.Serializable;
import no.uib.cipr.matrix.DenseCholesky;
import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Matrices;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.UpperSPDDenseMatrix;
import no.uib.cipr.matrix.Vector;
import weka.core.Utils;
import weka.estimators.MultivariateEstimator;

/**
 * Estimator for a multivariate Gaussian density.
 *
 * <p>The estimator stores the (weighted) mean vector, the inverse of the
 * (weighted) covariance matrix and the natural logarithm of the density's
 * normalizing factor. A ridge value is added to the diagonal of every
 * covariance matrix before it is factorized.
 */
public class MultivariateGaussianEstimator
implements MultivariateEstimator,
Serializable {
    /** Mean vector computed by the most recent call to an estimate method. */
    protected DenseVector mean;
    /** Inverse of the (ridge-adjusted) covariance matrix. */
    protected UpperSPDDenseMatrix covarianceInverse;
    /** Natural logarithm of the normalizing factor of the density. */
    protected double lnconstant;
    /** Value added to each diagonal element of the covariance matrix. */
    protected double m_Ridge = 1.0E-6;
    /** Natural logarithm of 2 * PI. */
    public static final double Log2PI = Math.log(Math.PI * 2);

    /**
     * Returns a textual description of the estimator.
     *
     * @return the log normalizing factor, the mean vector and the inverse
     *         covariance matrix, each followed by a line break
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("Natural logarithm of normalizing factor: ").append(this.lnconstant).append("\n\n");
        sb.append("Mean vector:\n\n").append(this.mean).append("\n");
        sb.append("Inverse of covariance matrix:\n\n").append(this.covarianceInverse).append("\n");
        return sb.toString();
    }

    /**
     * Returns the mean vector as an array.
     *
     * <p>NOTE(review): this hands out whatever {@code DenseVector.getData()}
     * returns, presumably the vector's backing array - callers should treat
     * it as read-only.
     *
     * @return the data of the mean vector
     */
    public double[] getMean() {
        return this.mean.getData();
    }

    /**
     * Computes the natural logarithm of the density at the given point.
     *
     * @param valuePassed the point at which to evaluate the density; its
     *        length must equal the dimension of the estimated mean
     * @return the log density at {@code valuePassed}
     * @throws IllegalArgumentException if the length of {@code valuePassed}
     *         differs from the dimension of the mean vector
     */
    @Override
    public double logDensity(double[] valuePassed) {
        int dim = this.mean.size();
        // A shorter point would silently be padded with zero deviations and a
        // longer one would fail with an index error, so reject both up front.
        if (valuePassed.length != dim) {
            throw new IllegalArgumentException("Length of the value (" + valuePassed.length
                + ") must match the dimension of the estimator (" + dim + ").");
        }
        DenseVector subtractedMean = new DenseVector(dim);
        for (int i = 0; i < dim; ++i) {
            subtractedMean.set(i, valuePassed[i] - this.mean.get(i));
        }
        // Quadratic form (x - mean)^T * covarianceInverse * (x - mean).
        Vector product = this.covarianceInverse.mult(subtractedMean, new DenseVector(dim));
        return this.lnconstant - 0.5 * subtractedMean.dot(product);
    }

    /**
     * Estimates the mean, the inverse covariance matrix and the log
     * normalizing factor from weighted observations.
     *
     * @param observations one row per observation, one column per dimension
     * @param weights one weight per observation; if {@code null}, every
     *        observation gets weight 1.0
     * @throws IllegalArgumentException if there are no observations
     */
    @Override
    public void estimate(double[][] observations, double[] weights) {
        if (observations == null || observations.length == 0) {
            throw new IllegalArgumentException("Cannot compute estimates with no data.");
        }
        if (weights == null) {
            weights = new double[observations.length];
            for (int i = 0; i < weights.length; ++i) {
                weights[i] = 1.0;
            }
        }
        this.mean = this.weightedMean(observations, weights);
        UpperSPDDenseMatrix cov = this.weightedCovariance(observations, weights, this.mean);
        this.setInverseAndLogConstant(cov, observations[0].length);
    }

    /**
     * Estimates a pooled model from several groups of weighted observations.
     *
     * <p>A weighted mean and covariance matrix are computed per non-empty
     * group. The group weight sums are passed through
     * {@code Utils.normalize} and then used as coefficients when summing the
     * group covariance matrices and the group means into the pooled
     * covariance and the pooled mean of this estimator.
     *
     * @param observations the observations, indexed by group, row and column
     * @param weights the observation weights, indexed by group and row
     * @return the mean of each group; the entry of an empty group is
     *         {@code null}
     * @throws IllegalArgumentException if every group is empty
     */
    public double[][] estimatePooled(double[][][] observations, double[][] weights) {
        int c = observations.length;
        int m = -1;
        for (int i = 0; i < c; ++i) {
            if (observations[i].length > 0) {
                m = observations[i][0].length;
            }
        }
        if (m == -1) {
            throw new IllegalArgumentException("Cannot compute pooled estimates with no data.");
        }
        Matrix[] groupCovariance = new Matrix[c];
        DenseVector[] groupMean = new DenseVector[c];
        double[] groupWeights = new double[c];
        for (int i = 0; i < c; ++i) {
            if (observations[i].length <= 0) continue;
            groupMean[i] = this.weightedMean(observations[i], weights[i]);
            groupCovariance[i] = this.weightedCovariance(observations[i], weights[i], groupMean[i]);
            groupWeights[i] = Utils.sum(weights[i]);
        }
        Utils.normalize(groupWeights);
        double[][] means = new double[c][];
        Matrix cov = new UpperSPDDenseMatrix(m);
        // Size the pooled mean by the data dimension: the first group may be
        // empty, in which case it has no mean vector to take a size from.
        this.mean = new DenseVector(m);
        for (int i = 0; i < c; ++i) {
            if (observations[i].length <= 0) continue;
            cov = cov.add(groupWeights[i], groupCovariance[i]);
            this.mean = (DenseVector)this.mean.add(groupWeights[i], groupMean[i]);
            means[i] = groupMean[i].getData();
        }
        this.setInverseAndLogConstant((UpperSPDDenseMatrix)cov, m);
        return means;
    }

    /**
     * Factorizes the covariance matrix and stores its inverse together with
     * the log normalizing factor of the density.
     *
     * @param cov the covariance matrix; it is handed to the Cholesky
     *        factorization directly
     * @param dim the number of rows and columns of {@code cov}
     */
    private void setInverseAndLogConstant(UpperSPDDenseMatrix cov, int dim) {
        DenseCholesky chol = new DenseCholesky(dim, true).factor(cov);
        this.covarianceInverse = new UpperSPDDenseMatrix(chol.solve(Matrices.identity(dim)));
        // The log determinant of the covariance is twice the sum of the logs
        // of the diagonal of its Cholesky factor.
        Matrix factor = chol.getU();
        double logDeterminant = 0.0;
        for (int i = 0; i < dim; ++i) {
            logDeterminant += Math.log(factor.get(i, i));
        }
        logDeterminant *= 2.0;
        this.lnconstant = -(Log2PI * (double)dim + logDeterminant) * 0.5;
    }

    /**
     * Computes the weighted mean of the rows of a matrix.
     *
     * @param matrix one row per observation
     * @param weights one weight per row
     * @return the weighted sum of the rows divided by the sum of the weights
     */
    private DenseVector weightedMean(double[][] matrix, double[] weights) {
        int rows = matrix.length;
        int cols = matrix[0].length;
        DenseVector mean = new DenseVector(cols);
        double sumOfWeights = 0.0;
        for (int i = 0; i < rows; ++i) {
            double[] row = matrix[i];
            double w = weights[i];
            for (int j = 0; j < cols; ++j) {
                mean.add(j, row[j] * w);
            }
            sumOfWeights += w;
        }
        mean.scale(1.0 / sumOfWeights);
        return mean;
    }

    /**
     * Computes the weighted covariance matrix of the rows of a matrix and
     * adds the ridge value to its diagonal.
     *
     * @param matrix one row per observation
     * @param weights one weight per row
     * @param mean the mean vector to centre the observations with
     * @return the ridge-adjusted weighted covariance matrix
     * @throws IllegalArgumentException if the size of {@code mean} differs
     *         from the number of columns of {@code matrix}
     */
    private UpperSPDDenseMatrix weightedCovariance(double[][] matrix, double[] weights, Vector mean) {
        int rows = matrix.length;
        int cols = matrix[0].length;
        if (mean.size() != cols) {
            throw new IllegalArgumentException("Length of the mean vector must match matrix.");
        }
        // The sum of weights is the same for every matrix entry, so it is
        // accumulated once instead of once per (i, j) pair.
        double sumOfWeights = 0.0;
        for (int k = 0; k < rows; ++k) {
            sumOfWeights += weights[k];
        }
        DenseMatrix covT = new DenseMatrix(cols, cols);
        for (int i = 0; i < cols; ++i) {
            double meanI = mean.get(i);
            for (int j = i; j < cols; ++j) {
                double meanJ = mean.get(j);
                double s = 0.0;
                for (int k = 0; k < rows; ++k) {
                    s += weights[k] * (matrix[k][j] - meanJ) * (matrix[k][i] - meanI);
                }
                s /= sumOfWeights;
                // Fill both triangles so the dense matrix is symmetric.
                covT.set(i, j, s);
                covT.set(j, i, s);
                if (i == j) {
                    covT.add(i, j, this.m_Ridge);
                }
            }
        }
        return new UpperSPDDenseMatrix(covT);
    }

    /**
     * Returns the tip text for the ridge property.
     *
     * @return a short description of the ridge parameter
     */
    public String ridgeTipText() {
        return "The value of the ridge parameter.";
    }

    /**
     * Returns the current ridge value.
     *
     * @return the value added to the diagonal of the covariance matrix
     */
    public double getRidge() {
        return this.m_Ridge;
    }

    /**
     * Sets the ridge value.
     *
     * @param newRidge the value to add to the diagonal of the covariance
     *        matrix in subsequent estimates
     */
    public void setRidge(double newRidge) {
        this.m_Ridge = newRidge;
    }

    /**
     * Approximates the integral of a one-dimensional estimator's density
     * over the unit interval using midpoints of equally wide cells.
     *
     * @param estimator the estimator whose density is integrated
     * @param numVals the number of cells
     * @return the approximate integral; cells with a NaN log density are
     *         skipped
     */
    private static double integrateUnitInterval(MultivariateGaussianEstimator estimator, int numVals) {
        double step = 1.0 / (double)numVals;
        double integral = 0.0;
        for (int i = 0; i < numVals; ++i) {
            double[] point = new double[]{((double)i + 0.5) * step};
            double logdens = estimator.logDensity(point);
            if (Double.isNaN(logdens)) continue;
            integral += Math.exp(logdens) * step;
        }
        return integral;
    }

    /**
     * Approximates the integral of a three-dimensional estimator's density
     * over the unit cube using midpoints of equally sized cells.
     *
     * @param estimator the estimator whose density is integrated
     * @param numVals the number of cells along each axis
     * @return the approximate integral; cells with a NaN log density are
     *         skipped
     */
    private static double integrateUnitCube(MultivariateGaussianEstimator estimator, int numVals) {
        double step = 1.0 / (double)numVals;
        double numCells = (double)(numVals * numVals * numVals);
        double integral = 0.0;
        for (int i = 0; i < numVals; ++i) {
            for (int j = 0; j < numVals; ++j) {
                for (int k = 0; k < numVals; ++k) {
                    double[] point = new double[]{((double)i + 0.5) * step, ((double)j + 0.5) * step, ((double)k + 0.5) * step};
                    double logdens = estimator.logDensity(point);
                    if (Double.isNaN(logdens)) continue;
                    integral += Math.exp(logdens) / numCells;
                }
            }
        }
        return integral;
    }

    /**
     * Demonstrates the estimator: fits five small data sets, prints each
     * fitted estimator and the approximate integral of its own density to
     * standard error.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // One-dimensional data.
        double[][] dataset1 = new double[][]{{0.49}, {0.46}, {0.51}, {0.55}};
        MultivariateGaussianEstimator mv1 = new MultivariateGaussianEstimator();
        mv1.estimate(dataset1, new double[]{0.7, 0.2, 0.05, 0.05});
        System.err.println(mv1);
        System.err.println("Approximate integral: " + integrateUnitInterval(mv1, 1000));

        // Three-dimensional data with four observations.
        double[][] dataset = new double[][]{
            {0.49, 0.51, 0.53},
            {0.46, 0.47, 0.52},
            {0.51, 0.49, 0.47},
            {0.55, 0.52, 0.54}};
        MultivariateGaussianEstimator mv = new MultivariateGaussianEstimator();
        mv.estimate(dataset, new double[]{2.0, 0.2, 0.05, 0.05});
        System.err.println(mv);
        System.err.println("Approximate integral: " + integrateUnitCube(mv, 200));

        // Same data with the first observation duplicated as a fifth row.
        double[][] dataset3 = new double[][]{
            {0.49, 0.51, 0.53},
            {0.46, 0.47, 0.52},
            {0.51, 0.49, 0.47},
            {0.55, 0.52, 0.54},
            {0.49, 0.51, 0.53}};
        MultivariateGaussianEstimator mv3 = new MultivariateGaussianEstimator();
        mv3.estimate(dataset3, new double[]{1.0, 0.2, 0.05, 0.05, 1.0});
        System.err.println(mv3);
        // Integrate the estimator that was just fitted, not the earlier one.
        System.err.println("Approximate integral: " + integrateUnitCube(mv3, 200));

        // Pooled estimate from two groups of different size.
        double[][][] dataset4 = new double[][][]{
            {
                {0.49, 0.51, 0.53},
                {0.49, 0.51, 0.53}},
            {
                {0.46, 0.47, 0.52},
                {0.51, 0.49, 0.47},
                {0.55, 0.52, 0.54}}};
        double[][] weights = new double[][]{{1.0, 3.0}, {2.0, 1.0, 1.0}};
        MultivariateGaussianEstimator mv4 = new MultivariateGaussianEstimator();
        mv4.estimatePooled(dataset4, weights);
        System.err.println(mv4);
        System.err.println("Approximate integral: " + integrateUnitCube(mv4, 200));

        // Pooled estimate from two equally sized, equally weighted groups.
        double[][][] dataset5 = new double[][][]{
            {
                {0.49, 0.51, 0.53},
                {0.49, 0.51, 0.53},
                {0.49, 0.51, 0.53},
                {0.49, 0.51, 0.53}},
            {
                {0.46, 0.47, 0.52},
                {0.46, 0.47, 0.52},
                {0.51, 0.49, 0.47},
                {0.55, 0.52, 0.54}}};
        double[][] weights2 = new double[][]{{1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0}};
        MultivariateGaussianEstimator mv5 = new MultivariateGaussianEstimator();
        mv5.estimatePooled(dataset5, weights2);
        System.err.println(mv5);
        System.err.println("Approximate integral: " + integrateUnitCube(mv5, 200));
    }
}

