/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.continuous;

import dr.evomodel.continuous.TreeTraitNormalDistributionModel;
import dr.evomodel.treedatalikelihood.DataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.inference.distribution.ParametricMultivariateDistributionModel;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.distributions.RandomGenerator;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.ArrayList;
import java.util.List;

public class AcrossTreeTraitNormalDistributionModel
extends AbstractModel
implements ParametricMultivariateDistributionModel,
RandomGenerator {
    private final ContinuousDataLikelihoodDelegate delegate1;
    private final ContinuousDataLikelihoodDelegate delegate2;
    private final Parameter rho;
    private final int dim;
    private double[][] variance;
    private MultivariateNormalDistribution distribution;
    private MultivariateNormalDistribution storedDistribution;
    private boolean distributionKnown;
    private boolean storedDistributionKnown;
    public static XMLObjectParser ACROSS_TREE_TRAIT_MODEL = new AbstractXMLObjectParser(){
        private static final String ACROSS_TREE_TRAIT_NORMAL = "acrossTreeTraitNormalDistribution";
        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{new ElementRule(TreeDataLikelihood.class, 2, 2), new ElementRule(Parameter.class)};

        @Override
        public String getParserName() {
            return ACROSS_TREE_TRAIT_NORMAL;
        }

        @Override
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            List<TreeDataLikelihood> list = xMLObject.getAllChildren(TreeDataLikelihood.class);
            ArrayList<ContinuousDataLikelihoodDelegate> arrayList = new ArrayList<ContinuousDataLikelihoodDelegate>();
            for (TreeDataLikelihood treeDataLikelihood : list) {
                DataLikelihoodDelegate dataLikelihoodDelegate = treeDataLikelihood.getDataLikelihoodDelegate();
                if (!(dataLikelihoodDelegate instanceof ContinuousDataLikelihoodDelegate)) continue;
                arrayList.add((ContinuousDataLikelihoodDelegate)dataLikelihoodDelegate);
            }
            Parameter parameter = (Parameter)xMLObject.getChild(Parameter.class);
            return new AcrossTreeTraitNormalDistributionModel((ContinuousDataLikelihoodDelegate)arrayList.get(0), (ContinuousDataLikelihoodDelegate)arrayList.get(1), parameter);
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }

        @Override
        public String getParserDescription() {
            return "Parses TreeTraitNormalDistributionModel";
        }

        @Override
        public Class getReturnType() {
            return TreeTraitNormalDistributionModel.class;
        }
    };

    public AcrossTreeTraitNormalDistributionModel(ContinuousDataLikelihoodDelegate continuousDataLikelihoodDelegate, ContinuousDataLikelihoodDelegate continuousDataLikelihoodDelegate2, Parameter parameter) {
        super("multivariateNormalDistributionModel");
        this.delegate1 = continuousDataLikelihoodDelegate;
        this.delegate2 = continuousDataLikelihoodDelegate2;
        this.rho = parameter;
        this.dim = continuousDataLikelihoodDelegate.getTreeTraitPrecision().length;
        if (continuousDataLikelihoodDelegate.getTraitDim() != continuousDataLikelihoodDelegate2.getTraitDim()) {
            throw new IllegalArgumentException("Unequal trait dimensions");
        }
        if (continuousDataLikelihoodDelegate.getTraitCount() != 1 || continuousDataLikelihoodDelegate2.getTraitCount() != 1) {
            throw new IllegalArgumentException("Only implemented for single traits");
        }
        this.addModel(continuousDataLikelihoodDelegate);
        this.addModel(continuousDataLikelihoodDelegate2);
        this.addVariable(parameter);
        this.distributionKnown = false;
    }

    @Override
    public double logPdf(double[] dArray) {
        this.checkDistribution();
        return this.distribution.logPdf(dArray);
    }

    @Override
    public double[][] getScaleMatrix() {
        this.checkDistribution();
        return this.distribution.getScaleMatrix();
    }

    @Override
    public double[] getMean() {
        this.checkDistribution();
        return this.distribution.getMean();
    }

    @Override
    public String getType() {
        return "TreeTraitMVN";
    }

    @Override
    public void handleModelChangedEvent(Model model, Object object, int n) {
        if (model != this.delegate1 && model != this.delegate2) {
            throw new IllegalArgumentException("Unknown model");
        }
        this.distributionKnown = false;
    }

    @Override
    protected final void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        if (variable != this.rho) {
            throw new IllegalArgumentException("Unknown correlation");
        }
        this.distributionKnown = false;
    }

    @Override
    protected void storeState() {
        this.storedDistribution = this.distribution;
        this.storedDistributionKnown = this.distributionKnown;
    }

    @Override
    protected void restoreState() {
        this.distributionKnown = this.storedDistributionKnown;
        this.distribution = this.storedDistribution;
    }

    @Override
    protected void acceptState() {
    }

    private void checkDistribution() {
        if (!this.distributionKnown) {
            this.distribution = this.createNewDistribution();
            this.distributionKnown = true;
        }
    }

    private MultivariateNormalDistribution createNewDistribution() {
        return new MultivariateNormalDistribution(this.computeMean(), this.computePrecision());
    }

    private double[] computeMean() {
        return new double[2 * this.dim];
    }

    private double[][] computePrecision() {
        double[][][] dArrayArray = new double[][][]{this.delegate1.getTreeTraitVariance(), this.delegate2.getTreeTraitVariance()};
        if (this.variance == null) {
            this.variance = new double[2 * this.dim][2 * this.dim];
        }
        for (int i = 0; i < 2; ++i) {
            double[][] dArray = dArrayArray[i];
            for (int j = 0; j < this.dim; ++j) {
                for (int k = 0; k < this.dim; ++k) {
                    this.variance[i * this.dim + j][i * this.dim + k] = dArray[j][k];
                }
            }
            double d = this.rho.getParameterValue(0);
            for (int j = i + 1; j < 2; ++j) {
                for (int k = 0; k < this.dim; ++k) {
                    double d2 = Math.sqrt(this.variance[i * this.dim + k][i * this.dim + k]);
                    double d3 = Math.sqrt(this.variance[j * this.dim + k][j * this.dim + k]);
                    this.variance[i * this.dim + k][j * this.dim + k] = d * d2 * d3;
                    this.variance[j * this.dim + k][i * this.dim + k] = d * d2 + d3;
                }
            }
        }
        return new SymmetricMatrix(this.variance).inverse().toComponents();
    }

    @Override
    public double[] nextRandom() {
        this.checkDistribution();
        return this.distribution.nextMultivariateNormal();
    }

    @Override
    public double logPdf(Object object) {
        this.checkDistribution();
        return this.distribution.logPdf(object);
    }

    @Override
    public Variable<Double> getLocationVariable() {
        throw new UnsupportedOperationException("Not implemented");
    }
}

