/*
 * Decompiled with CFR 0.152.
 */
package dr.math.distributions;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import dr.inference.distribution.RandomField;
import dr.inference.model.GradientProvider;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.distributions.RandomFieldDistribution;
import dr.math.matrixAlgebra.RobustEigenDecomposition;
import java.util.Arrays;

/**
 * Gaussian Markov random field over a one-dimensional chain of {@code dim} nodes.
 *
 * <p>The precision matrix {@code Q} is symmetric tri-diagonal and is rebuilt lazily by
 * {@link #getQ()}. With scalar precision {@code tau}:
 * <ul>
 *   <li>the diagonal holds {@code tau} times the sum of the weights of the (at most two) edges
 *       touching each node;</li>
 *   <li>the off-diagonal holds {@code -tau} times the edge weight.</li>
 * </ul>
 * Without a weight provider every edge weight is 1.
 *
 * <p>When a lambda parameter is present, every off-diagonal entry is additionally multiplied by
 * {@code lambda}. The field is treated as improper (rank-deficient precision) when lambda is
 * absent or exactly 1.
 */
public class GaussianMarkovRandomField extends RandomFieldDistribution {

    public static final String TYPE = "GaussianMarkovRandomField";

    /** Eigenvalues with absolute value at or below this are ignored in the (pseudo-)determinant. */
    private static final double EIGENVALUE_TOLERANCE = 1.0E-6;
    private static final double HALF_LOG_TWO_PI = Math.log(Math.PI * 2) / 2.0;

    protected final int dim;
    private final Parameter meanParameter;
    private final Parameter precisionParameter;
    private final Parameter lambdaParameter;
    private final RandomField.WeightProvider weightProvider;

    /** Cached mean vector; valid while {@code meanKnown} is true. */
    private final double[] mean;
    /** Cached precision matrix; valid while {@code qKnown} is true. */
    final SymmetricTriDiagonalMatrix Q;
    /** Copy of {@code Q} taken in {@link #storeState()}. */
    private final SymmetricTriDiagonalMatrix savedQ;
    private boolean meanKnown;
    boolean qKnown;
    private boolean savedQKnown;

    /**
     * Log of the product of the non-zero eigenvalues of the unit-weight chain Laplacian.
     * It is 0 when matching was not requested or a weight provider is used.
     */
    private final double logMatchTerm;

    /**
     * Creates the field and registers its parameters and weight provider for change events.
     *
     * @param name                   name of this distribution
     * @param dim                    number of nodes in the chain
     * @param precisionParam         scalar precision {@code tau} (only element 0 is read)
     * @param meanParam              mean of dimension 1 or {@code dim}; may be null (zero mean)
     * @param lambdaParam            optional scale applied to the off-diagonal; may be null
     * @param weightProvider         optional edge weights; may be null (all weights are 1)
     * @param matchPseudoDeterminant whether the closed-form improper log pseudo-determinant
     *                               includes the constant eigenvalue term of the chain Laplacian
     */
    public GaussianMarkovRandomField(String name, int dim, Parameter precisionParam,
                                     Parameter meanParam, Parameter lambdaParam,
                                     RandomField.WeightProvider weightProvider,
                                     boolean matchPseudoDeterminant) {
        super(name);
        this.dim = dim;
        this.meanParameter = meanParam;
        this.precisionParameter = precisionParam;
        this.lambdaParameter = lambdaParam;
        this.weightProvider = weightProvider;

        // getMean() supports a null mean (treated as zero), so only register a mean that exists.
        if (meanParam != null) {
            this.addVariable(meanParam);
        }
        this.addVariable(precisionParam);
        if (lambdaParam != null) {
            this.addVariable(lambdaParam);
        }
        if (weightProvider != null) {
            this.addModel(weightProvider);
        }

        this.mean = new double[dim];
        this.Q = new SymmetricTriDiagonalMatrix(dim);
        this.savedQ = new SymmetricTriDiagonalMatrix(dim);
        this.logMatchTerm = matchPseudoDeterminant ? this.matchPseudoDeterminantTerm(dim) : 0.0;
        this.meanKnown = false;
        this.qKnown = false;
    }

    /**
     * Returns the mean vector, recomputing it from the mean parameter if it is stale.
     * A null mean parameter gives zeros; a one-dimensional parameter is broadcast to every node.
     *
     * @return the cached mean array (not a copy)
     */
    @Override
    public double[] getMean() {
        if (!this.meanKnown) {
            if (this.meanParameter == null) {
                Arrays.fill(this.mean, 0.0);
            } else if (this.meanParameter.getDimension() == 1) {
                Arrays.fill(this.mean, this.meanParameter.getParameterValue(0));
            } else {
                for (int i = 0; i < this.mean.length; ++i) {
                    this.mean[i] = this.meanParameter.getParameterValue(i);
                }
            }
            this.meanKnown = true;
        }
        return this.mean;
    }

    /**
     * Returns the tri-diagonal precision matrix, rebuilding it if it is stale.
     *
     * @return the cached matrix (not a copy)
     */
    protected SymmetricTriDiagonalMatrix getQ() {
        if (!this.qKnown) {
            final double tau = this.precisionParameter.getParameterValue(0);
            final int last = this.dim - 1;
            if (this.weightProvider == null) {
                // Unit weights: tau times the chain graph Laplacian.
                this.Q.diagonal[0] = tau;
                for (int i = 1; i < last; ++i) {
                    this.Q.diagonal[i] = 2.0 * tau;
                }
                this.Q.diagonal[last] = tau;
                Arrays.fill(this.Q.offDiagonal, -tau);
            } else {
                this.Q.diagonal[0] = tau * this.weightProvider.weight(0, 1);
                for (int i = 1; i < last; ++i) {
                    this.Q.diagonal[i] = tau * (this.weightProvider.weight(i - 1, i)
                            + this.weightProvider.weight(i, i + 1));
                }
                this.Q.diagonal[last] = tau * this.weightProvider.weight(last - 1, last);
                for (int i = 0; i < last; ++i) {
                    this.Q.offDiagonal[i] = -tau * this.weightProvider.weight(i, i + 1);
                }
            }
            if (this.lambdaParameter != null) {
                final double lambda = this.lambdaParameter.getParameterValue(0);
                for (int i = 0; i < last; ++i) {
                    this.Q.offDiagonal[i] *= lambda;
                }
            }
            this.qKnown = true;
        }
        return this.Q;
    }

    /** Expands a tri-diagonal matrix into a dense square array. */
    private static double[][] makePrecisionMatrix(SymmetricTriDiagonalMatrix precision) {
        final int n = precision.diagonal.length;
        final double[][] dense = new double[n][n];
        for (int i = 0; i < n; ++i) {
            dense[i][i] = precision.diagonal[i];
        }
        for (int i = 0; i < n - 1; ++i) {
            dense[i][i + 1] = precision.offDiagonal[i];
            dense[i + 1][i] = precision.offDiagonal[i];
        }
        return dense;
    }

    /** The field is improper when there is no lambda parameter or lambda equals exactly 1. */
    private boolean isImproper() {
        return this.lambdaParameter == null || this.lambdaParameter.getParameterValue(0) == 1.0;
    }

    /**
     * Returns a gradient provider for the precision or the mean parameter.
     *
     * @throws UnsupportedOperationException for the lambda parameter (not yet implemented)
     * @throws RuntimeException              for any other parameter
     */
    @Override
    public GradientProvider getGradientWrt(Parameter parameter) {
        if (parameter == null) {
            throw new RuntimeException("Unknown parameter");
        }
        if (parameter == this.precisionParameter) {
            return new GradientProvider() {
                @Override
                public int getDimension() {
                    return 1;
                }

                @Override
                public double[] getGradientLogDensity(Object x) {
                    return new double[]{
                            GaussianMarkovRandomField.this.gradientWrtPrecision((double[]) x)};
                }
            };
        }
        if (parameter == this.meanParameter) {
            return new GradientProvider() {
                @Override
                public int getDimension() {
                    return GaussianMarkovRandomField.this.meanParameter.getDimension();
                }

                @Override
                public double[] getGradientLogDensity(Object x) {
                    return GaussianMarkovRandomField.this.gradientWrtMean((double[]) x);
                }
            };
        }
        if (parameter == this.lambdaParameter) {
            throw new UnsupportedOperationException("Not yet implemented");
        }
        throw new RuntimeException("Unknown parameter");
    }

    /** Gradient of the log density w.r.t. the scalar precision, evaluated at {@code x}. */
    private double gradientWrtPrecision(double[] x) {
        return gradLogPdfWrtPrecision(x, this.getMean(), this.getQ(),
                this.precisionParameter.getParameterValue(0), this.isImproper());
    }

    /**
     * Computes the gradient of the log density w.r.t. the mean parameter, evaluated at {@code x}.
     * {@link #gradLogPdf} returns {@code Q (mean - x)}, so the gradient w.r.t. the mean is its
     * negation. For a scalar (broadcast) mean, the per-node contributions are summed.
     */
    private double[] gradientWrtMean(double[] x) {
        final double[] gradX = gradLogPdf(x, this.getMean(), this.getQ());
        final int meanDim = this.meanParameter.getDimension();
        if (meanDim == this.dim) {
            for (int i = 0; i < this.dim; ++i) {
                gradX[i] = -gradX[i];
            }
            return gradX;
        }
        if (meanDim == 1) {
            double sum = 0.0;
            for (int i = 0; i < this.dim; ++i) {
                sum -= gradX[i];
            }
            return new double[]{sum};
        }
        throw new IllegalArgumentException("Unknown mean parameter structure");
    }

    @Override
    public String getType() {
        return TYPE;
    }

    /**
     * Computes the sum over k = 1..n-1 of log(2 - 2 cos(k pi / n)), the log product of the non-zero
     * eigenvalues of the unit-weight chain Laplacian on n nodes.
     *
     * <p>It is computed whenever no weight provider is set, regardless of the current lambda value,
     * because lambda may change after construction. With a weight provider it returns 0, since the
     * eigen-decomposition in {@link #getLogDeterminant()} then yields the full term.
     */
    private double matchPseudoDeterminantTerm(int n) {
        if (this.weightProvider != null) {
            return 0.0;
        }
        double sum = 0.0;
        for (int k = 1; k < n; ++k) {
            sum += Math.log(2.0 - 2.0 * Math.cos((double) k * Math.PI / (double) n));
        }
        return sum;
    }

    /**
     * Computes the log (pseudo-)determinant of the current precision matrix.
     *
     * <p>In the improper, unit-weight case this uses the closed form
     * {@code (dim - 1) log(tau) + logMatchTerm}. Otherwise it sums the logs of the eigenvalues of
     * {@code Q} whose magnitude exceeds {@link #EIGENVALUE_TOLERANCE}.
     */
    private double getLogDeterminant() {
        final SymmetricTriDiagonalMatrix precision = this.getQ();
        if (this.isImproper() && this.weightProvider == null) {
            return (double) (this.dim - 1) * Math.log(this.precisionParameter.getParameterValue(0))
                    + this.logMatchTerm;
        }
        final RobustEigenDecomposition decomposition =
                new RobustEigenDecomposition(new DenseDoubleMatrix2D(makePrecisionMatrix(precision)));
        final DoubleMatrix1D eigenvalues = decomposition.getRealEigenvalues();
        double logDet = 0.0;
        for (int i = 0; i < eigenvalues.size(); ++i) {
            final double value = eigenvalues.get(i);
            // Skip (numerically) zero eigenvalues so improper fields yield a pseudo-determinant.
            if (Math.abs(value) > EIGENVALUE_TOLERANCE) {
                logDet += Math.log(value);
            }
        }
        return logDet;
    }

    /** Returns the current precision matrix as a dense array. */
    @Override
    public double[][] getScaleMatrix() {
        return makePrecisionMatrix(this.getQ());
    }

    @Override
    public Variable<Double> getLocationVariable() {
        return this.meanParameter;
    }

    /** Computes the log density of {@code x} under the current mean and precision. */
    @Override
    public double logPdf(double[] x) {
        return logPdf(x, this.getMean(), this.getQ(), this.isImproper(), this.getLogDeterminant());
    }

    /**
     * Computes the derivative of the log density w.r.t. the scalar precision {@code tau}.
     * The result is {@code 0.5 * (rank - SSE) / tau}, where {@code SSE} already includes
     * {@code tau} through {@code Q} and {@code rank} is {@code n - 1} when improper.
     */
    public static double gradLogPdfWrtPrecision(double[] x, double[] mean,
                                                SymmetricTriDiagonalMatrix precision,
                                                double tau, boolean improper) {
        final int rank = improper ? x.length - 1 : x.length;
        return 0.5 * ((double) rank - getSSE(x, mean, precision)) / tau;
    }

    /**
     * Computes the gradient of the log density w.r.t. {@code x}, i.e. {@code Q (mean - x)}.
     * Handles a single-element field, which has no off-diagonal entries.
     */
    public static double[] gradLogPdf(double[] x, double[] mean,
                                      SymmetricTriDiagonalMatrix precision) {
        final int n = x.length;
        final double[] residual = new double[n];
        for (int i = 0; i < n; ++i) {
            residual[i] = mean[i] - x[i];
        }
        final double[] gradient = new double[n];
        for (int i = 0; i < n; ++i) {
            double value = precision.diagonal[i] * residual[i];
            if (i > 0) {
                value += precision.offDiagonal[i - 1] * residual[i - 1];
            }
            if (i < n - 1) {
                value += precision.offDiagonal[i] * residual[i + 1];
            }
            gradient[i] = value;
        }
        return gradient;
    }

    /** Returns the dense Hessian of the log density w.r.t. {@code x}, which is {@code -Q}. */
    public static double[][] hessianLogPdf(double[] x, SymmetricTriDiagonalMatrix precision) {
        final int n = x.length;
        final double[][] hessian = new double[n][n];
        for (int i = 0; i < n; ++i) {
            hessian[i][i] = -precision.diagonal[i];
        }
        for (int i = 0; i < n - 1; ++i) {
            hessian[i][i + 1] = -precision.offDiagonal[i];
            hessian[i + 1][i] = -precision.offDiagonal[i];
        }
        return hessian;
    }

    /**
     * Returns the diagonal of the Hessian of the log density w.r.t. {@code x}, i.e. {@code -Q_ii},
     * consistent with {@link #hessianLogPdf}.
     */
    public static double[] diagonalHessianLogPdf(double[] x, SymmetricTriDiagonalMatrix precision) {
        final int n = x.length;
        final double[] hessian = new double[n];
        for (int i = 0; i < n; ++i) {
            hessian[i] = -precision.diagonal[i];
        }
        return hessian;
    }

    /** Computes the log density: normalizing constant minus half the quadratic form. */
    private static double logPdf(double[] x, double[] mean, SymmetricTriDiagonalMatrix precision,
                                 boolean improper, double logDeterminant) {
        return getLogNormalization(x.length, improper, logDeterminant)
                - 0.5 * getSSE(x, mean, precision);
    }

    /** Computes the quadratic form {@code (x - mean)^T Q (x - mean)} for tri-diagonal {@code Q}. */
    private static double getSSE(double[] x, double[] mean, SymmetricTriDiagonalMatrix precision) {
        final int n = x.length;
        final double[] delta = new double[n];
        for (int i = 0; i < n; ++i) {
            delta[i] = x[i] - mean[i];
        }
        double sse = 0.0;
        for (int i = 0; i < n - 1; ++i) {
            sse += precision.diagonal[i] * delta[i] * delta[i]
                    + 2.0 * precision.offDiagonal[i] * delta[i] * delta[i + 1];
        }
        sse += precision.diagonal[n - 1] * delta[n - 1] * delta[n - 1];
        return sse;
    }

    /** Computes {@code -rank/2 * log(2 pi) + logDet/2}, where {@code rank} is {@code n - 1} when improper. */
    private static double getLogNormalization(int n, boolean improper, double logDeterminant) {
        final int rank = improper ? n - 1 : n;
        return (double) (-rank) * HALF_LOG_TWO_PI + 0.5 * logDeterminant;
    }

    @Override
    public int getDimension() {
        return this.dim;
    }

    @Override
    public double[] getGradientLogDensity(Object x) {
        return gradLogPdf((double[]) x, this.getMean(), this.getQ());
    }

    @Override
    public double[] getDiagonalHessianLogDensity(Object x) {
        return diagonalHessianLogPdf((double[]) x, this.getQ());
    }

    @Override
    public double[][] getHessianLogDensity(Object x) {
        return hessianLogPdf((double[]) x, this.getQ());
    }

    /** Sampling is not supported. */
    @Override
    public double[] nextRandom() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /** Invalidates the cached precision matrix when the registered weight provider changes. */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        if (this.weightProvider != null && model == this.weightProvider) {
            this.qKnown = false;
        } else {
            throw new IllegalArgumentException("Unknown model");
        }
    }

    /** Invalidates the cache (mean or precision matrix) that depends on the changed variable. */
    @Override
    protected void handleVariableChangedEvent(Variable variable, int index,
                                              Variable.ChangeType changeType) {
        if (variable == this.meanParameter) {
            this.meanKnown = false;
        } else if (variable == this.precisionParameter || variable == this.lambdaParameter) {
            this.qKnown = false;
        } else {
            throw new IllegalArgumentException("Unknown variable");
        }
    }

    /** Saves the precision matrix (if currently valid) together with its validity flag. */
    @Override
    protected void storeState() {
        if (this.qKnown) {
            this.Q.copyTo(this.savedQ);
        }
        this.savedQKnown = this.qKnown;
    }

    /**
     * Restores the saved precision matrix by swapping storage arrays. The mean is not saved,
     * so it is simply marked stale.
     */
    @Override
    protected void restoreState() {
        this.meanKnown = false;
        this.qKnown = this.savedQKnown;
        if (this.qKnown) {
            this.savedQ.swap(this.Q);
        }
    }

    @Override
    protected void acceptState() {
    }

    /**
     * Symmetric tri-diagonal matrix stored as its main diagonal (length n) and its
     * first off-diagonal (length n - 1).
     */
    static class SymmetricTriDiagonalMatrix {
        double[] diagonal;
        double[] offDiagonal;

        /** Allocates a zero matrix of order {@code n}. */
        SymmetricTriDiagonalMatrix(int n) {
            this(new double[n], new double[n - 1]);
        }

        /** Wraps the given arrays without copying them. */
        SymmetricTriDiagonalMatrix(double[] diagonal, double[] offDiagonal) {
            this.diagonal = diagonal;
            this.offDiagonal = offDiagonal;
        }

        /** Copies this matrix's entries into {@code target}, which must have the same order. */
        void copyTo(SymmetricTriDiagonalMatrix target) {
            System.arraycopy(this.diagonal, 0, target.diagonal, 0, this.diagonal.length);
            System.arraycopy(this.offDiagonal, 0, target.offDiagonal, 0, this.offDiagonal.length);
        }

        /** Exchanges the backing arrays of this matrix and {@code other} without copying data. */
        void swap(SymmetricTriDiagonalMatrix other) {
            final double[] diagonalTmp = this.diagonal;
            this.diagonal = other.diagonal;
            other.diagonal = diagonalTmp;
            final double[] offDiagonalTmp = this.offDiagonal;
            this.offDiagonal = other.offDiagonal;
            other.offDiagonal = offDiagonalTmp;
        }
    }
}

