/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.operators;

import dr.evolution.tree.MutableTreeModel;
import dr.evolution.tree.NodeRef;
import dr.evomodel.continuous.FullyConjugateMultivariateTraitLikelihood;
import dr.evomodel.continuous.LatentTruncation;
import dr.inference.model.CompoundParameter;
import dr.inference.model.MatrixParameter;
import dr.inference.model.Parameter;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.MathUtils;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.distributions.NormalDistribution;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.math.matrixAlgebra.Vector;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.logging.Logger;

/**
 * MCMC operator that redraws the latent trait values of one randomly chosen tip.
 *
 * <p>Each call to {@link #doOperation()} picks an external node uniformly at random and hands
 * it to {@link #sampleNode2(NodeRef)}, which draws from a multivariate normal built from the
 * conditional mean and conditional precision reported by the trait model, and keeps redrawing
 * until the {@link LatentTruncation} accepts the value for that tip (or the attempt budget is
 * exhausted).
 *
 * <p>An optional mask restricts which trait dimensions are redrawn: entries equal to
 * {@code 0.0} mark dimensions that are held fixed, every other value marks a dimension that is
 * resampled.
 *
 * <p>The class also carries an older, tree-traversal based code path
 * ({@link #doPostOrderTraversal(NodeRef)}, {@link #doPreOrderTraversal(NodeRef)},
 * {@link #sampleNode(NodeRef)}). Nothing in this class invokes it; callers that use
 * {@link #sampleNode(NodeRef)} must run both traversals themselves beforehand.
 */
public class LatentLiabilityGibbs
extends SimpleMCMCOperator {
    public static final String LATENT_LIABILITY_GIBBS_OPERATOR = "latentLiabilityGibbsOperator";
    public static final String TREE_MODEL = "treeModel";

    /** Maximum number of draws a single {@link #sampleNode2(NodeRef)} call will try. */
    private static final int MAX_ATTEMPTS = 10000;

    /** Logger name used by the private diagnostic helpers. */
    private static final String LOGGER_NAME = "dr.evomodel";

    private final LatentTruncation latentLiability;
    private final FullyConjugateMultivariateTraitLikelihood traitModel;
    private final CompoundParameter tipTraitParameter;
    protected double[] rootPriorMean;
    protected double rootPriorSampleSize;
    private final MatrixParameter precisionParam;
    private final MutableTreeModel treeModel;
    /** Number of trait dimensions, taken from the row count of the precision matrix. */
    private final int dim;

    /** Per-node means filled by {@link #doPostOrderTraversal(NodeRef)}; indexed [node][dimension]. */
    public double[][] postMeans;
    /** Per-node means filled by {@link #doPreOrderTraversal(NodeRef)}; indexed [node][dimension]. */
    public double[][] preMeans;
    /** Per-node scalar precisions filled by {@link #doPreOrderTraversal(NodeRef)}. */
    public double[] preP;
    /** Per-node scalar precisions filled by {@link #doPostOrderTraversal(NodeRef)}. */
    public double[] postP;

    private final Parameter mask;
    private final boolean hasMask;
    /** Number of leading entries of {@link #dontUpdate} that are in use. */
    private final int numFixed;
    /** Number of leading entries of {@link #doUpdate} that are in use. */
    private final int numUpdate;
    /** Indices of the dimensions that are resampled (first {@link #numUpdate} entries). */
    private final int[] doUpdate;
    /** Indices of the dimensions that are held fixed (first {@link #numFixed} entries). */
    private final int[] dontUpdate;

    /** XML parser that builds a {@link LatentLiabilityGibbs} from its XML element. */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser(){
        public static final String MASK = "mask";
        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
            AttributeRule.newDoubleRule("weight"),
            new ElementRule(FullyConjugateMultivariateTraitLikelihood.class, "The model for the latent random variables"),
            new ElementRule(LatentTruncation.class, "The model that links latent and observed variables"),
            new ElementRule(MASK, Parameter.class, "Mask: 1 for latent variables that should be sampled", true),
            new ElementRule(CompoundParameter.class, "The parameter of tip locations from the tree")
        };

        @Override
        public String getParserName() {
            return LatentLiabilityGibbs.LATENT_LIABILITY_GIBBS_OPERATOR;
        }

        /**
         * Reads the weight attribute, the trait likelihood, the latent truncation model, the
         * tip trait parameter and the optional mask, and constructs the operator.
         *
         * @throws XMLParseException if the element has fewer than three children
         */
        @Override
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            if (xMLObject.getChildCount() < 3) {
                throw new XMLParseException("Element with id = '" + xMLObject.getName() + "' should contain:\n\t 1 conjugate multivariateTraitLikelihood, 1 latentLiabilityLikelihood and one parameter \n");
            }
            double weight = xMLObject.getDoubleAttribute("weight");
            FullyConjugateMultivariateTraitLikelihood likelihood = (FullyConjugateMultivariateTraitLikelihood)xMLObject.getChild(FullyConjugateMultivariateTraitLikelihood.class);
            LatentTruncation truncation = (LatentTruncation)xMLObject.getChild(LatentTruncation.class);
            CompoundParameter tipTraits = (CompoundParameter)xMLObject.getChild(CompoundParameter.class);
            // The mask element is optional; null means "resample every dimension".
            Parameter maskParameter = null;
            if (xMLObject.hasChildNamed(MASK)) {
                maskParameter = (Parameter)xMLObject.getElementFirstChild(MASK);
            }
            return new LatentLiabilityGibbs(likelihood, truncation, tipTraits, maskParameter, weight);
        }

        @Override
        public String getParserDescription() {
            return "This element returns a gibbs sampler on tip latent trais for latent liability model.";
        }

        @Override
        public Class getReturnType() {
            return MCMCOperator.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };

    /**
     * Creates the operator.
     *
     * @param fullyConjugateMultivariateTraitLikelihood trait model; supplies the root prior,
     *        the precision parameter, the tree and the per-tip conditional moments
     * @param latentTruncation model asked whether a tip's current trait value is valid
     * @param compoundParameter per-tip trait parameters that this operator overwrites
     * @param parameter optional mask (may be {@code null}); its first {@code dim} entries are
     *        read, and an entry equal to {@code 0.0} freezes the matching dimension
     * @param d operator weight
     */
    public LatentLiabilityGibbs(FullyConjugateMultivariateTraitLikelihood fullyConjugateMultivariateTraitLikelihood, LatentTruncation latentTruncation, CompoundParameter compoundParameter, Parameter parameter, double d) {
        this.latentLiability = latentTruncation;
        this.traitModel = fullyConjugateMultivariateTraitLikelihood;
        this.tipTraitParameter = compoundParameter;
        this.rootPriorMean = fullyConjugateMultivariateTraitLikelihood.getPriorMean();
        this.rootPriorSampleSize = fullyConjugateMultivariateTraitLikelihood.getPriorSampleSize();
        this.precisionParam = (MatrixParameter)fullyConjugateMultivariateTraitLikelihood.getDiffusionModel().getPrecisionParameter();
        this.treeModel = fullyConjugateMultivariateTraitLikelihood.getTreeModel();
        this.dim = this.precisionParam.getRowDimension();
        this.mask = parameter;
        this.hasMask = parameter != null;

        final int nodeCount = this.treeModel.getNodeCount();
        this.postMeans = new double[nodeCount][this.dim];
        this.preMeans = new double[nodeCount][this.dim];
        this.preP = new double[nodeCount];
        this.postP = new double[nodeCount];

        // Partition the dimensions into "fixed" and "update" index lists. Both arrays are
        // allocated at full size; only the first fixedCount / updateCount slots are meaningful.
        final int[] fixedIndices = new int[this.dim];
        final int[] updateIndices = new int[this.dim];
        int fixedCount = 0;
        int updateCount = 0;
        if (this.hasMask) {
            for (int i = 0; i < this.dim; ++i) {
                if (parameter.getParameterValue(i) == 0.0) {
                    fixedIndices[fixedCount++] = i;
                } else {
                    updateIndices[updateCount++] = i;
                }
            }
        }
        this.dontUpdate = fixedIndices;
        this.doUpdate = updateIndices;
        this.numFixed = fixedCount;
        this.numUpdate = updateCount;

        this.setWeight(d);
    }

    /** Returns the number of steps this operator performs per call, which is always 1. */
    public int getStepCount() {
        return 1;
    }

    /** Sends a diagnostic message to the {@value #LOGGER_NAME} logger at INFO level. */
    private static void logInfo(StringBuilder message) {
        Logger.getLogger(LOGGER_NAME).info(message.toString());
    }

    /** Logs the first {@code dim} entries of row 0 of the matrix, concatenated without separators. */
    private void printInformation(MatrixParameter matrixParameter) {
        StringBuilder message = new StringBuilder("\n \n parameter \n");
        for (int col = 0; col < this.dim; ++col) {
            message.append(matrixParameter.getParameterValue(0, col));
        }
        logInfo(message);
    }

    /** Logs one value per tree node from the vector, concatenated without separators. */
    private void printInformation(double[] dArray) {
        StringBuilder message = new StringBuilder("\n \n double vector \n");
        final int nodeCount = this.treeModel.getNodeCount();
        for (int node = 0; node < nodeCount; ++node) {
            message.append(dArray[node]);
        }
        logInfo(message);
    }

    /** Logs column 0 of a per-node matrix, one value per tree node, without separators. */
    private void printInformation(double[][] dArray) {
        StringBuilder message = new StringBuilder("\n \n double matrix \n");
        final int nodeCount = this.treeModel.getNodeCount();
        // Only the first trait dimension is printed.
        for (int node = 0; node < nodeCount; ++node) {
            message.append(dArray[node][0]);
        }
        logInfo(message);
    }

    /** Logs a single value. */
    private void printInformation(double d) {
        StringBuilder message = new StringBuilder("\n \n double \n");
        message.append(d);
        logInfo(message);
    }

    /** Logs a label followed by two tabs and a value. */
    private void printInformation(double d, String string) {
        StringBuilder message = new StringBuilder("\n");
        message.append(string);
        message.append("\t\t");
        message.append(d);
        logInfo(message);
    }

    /**
     * Picks one external node uniformly at random, resamples its trait through
     * {@link #sampleNode2(NodeRef)} and fires a change event on the tip trait parameter.
     *
     * @return the value returned by {@link #sampleNode2(NodeRef)}
     */
    @Override
    public double doOperation() {
        final int tipIndex = MathUtils.nextInt(this.treeModel.getExternalNodeCount());
        final NodeRef tip = this.treeModel.getExternalNode(tipIndex);
        final double logRatio = this.sampleNode2(tip);
        this.tipTraitParameter.fireParameterChangedEvent();
        return logRatio;
    }

    /**
     * Recursively fills {@link #postMeans} and {@link #postP} for the subtree below
     * {@code nodeRef}.
     *
     * <p>A tip gets its current trait values as mean and the reciprocal of its rescaled branch
     * length as precision. An internal non-root node gets the precision-weighted average of its
     * two children's means; its precision is the children's summed precision combined with the
     * reciprocal of its own rescaled branch length as {@code a*b/(a+b)}. Nothing is stored for
     * the root itself. Only children 0 and 1 are visited.
     *
     * @param nodeRef root of the subtree to process
     */
    public void doPostOrderTraversal(NodeRef nodeRef) {
        final int nodeNumber = nodeRef.getNumber();

        if (this.treeModel.isExternal(nodeRef)) {
            final double[] tipTrait = this.getNodeTrait(nodeRef);
            System.arraycopy(tipTrait, 0, this.postMeans[nodeNumber], 0, this.dim);
            this.postP[nodeNumber] = 1.0 / this.traitModel.getRescaledBranchLengthForPrecision(nodeRef);
            return;
        }

        final NodeRef leftChild = this.treeModel.getChild(nodeRef, 0);
        final NodeRef rightChild = this.treeModel.getChild(nodeRef, 1);
        this.doPostOrderTraversal(leftChild);
        this.doPostOrderTraversal(rightChild);

        if (this.treeModel.isRoot(nodeRef)) {
            return;
        }

        final int leftNumber = leftChild.getNumber();
        final int rightNumber = rightChild.getNumber();
        final double leftP = this.postP[leftNumber];
        final double rightP = this.postP[rightNumber];
        final double branchP = 1.0 / this.traitModel.getRescaledBranchLengthForPrecision(nodeRef);
        final double childrenP = leftP + rightP;

        this.postP[nodeNumber] = childrenP * branchP / (childrenP + branchP);
        for (int i = 0; i < this.dim; ++i) {
            this.postMeans[nodeNumber][i] =
                    (leftP * this.postMeans[leftNumber][i] + rightP * this.postMeans[rightNumber][i]) / childrenP;
        }
    }

    /**
     * Returns the trait values stored for a node in the tip trait parameter.
     *
     * <p>The node number is used directly as the index into the compound parameter.
     *
     * @param nodeRef node whose values are read
     * @return the array returned by the parameter's {@code getParameterValues()}
     */
    public double[] getNodeTrait(NodeRef nodeRef) {
        return this.tipTraitParameter.getParameter(nodeRef.getNumber()).getParameterValues();
    }

    /**
     * Returns a single trait dimension stored for a node in the tip trait parameter.
     *
     * @param nodeRef node whose value is read
     * @param n index of the trait dimension
     */
    public double getNodeTrait(NodeRef nodeRef, int n) {
        return this.tipTraitParameter.getParameter(nodeRef.getNumber()).getParameterValue(n);
    }

    /**
     * Writes the first {@code dim} entries of {@code dArray} into the node's tip trait
     * parameter, then fires a change event on the trait model's parameter for that node.
     *
     * @param nodeRef node to update
     * @param dArray new trait values; must hold at least {@code dim} entries
     */
    public void setNodeTrait(NodeRef nodeRef, double[] dArray) {
        final int nodeNumber = nodeRef.getNumber();
        final Parameter tipTrait = this.tipTraitParameter.getParameter(nodeNumber);
        for (int i = 0; i < this.dim; ++i) {
            tipTrait.setParameterValue(i, dArray[i]);
        }
        // NOTE(review): the event is fired on traitModel's parameter, not on tipTraitParameter;
        // presumably both refer to the same underlying parameter - confirm.
        this.traitModel.getTraitParameter().getParameter(nodeNumber).fireParameterChangedEvent();
    }

    /**
     * Writes one trait dimension of a node. Unlike the array overload, this does not fire an
     * event on the trait model's parameter.
     *
     * @param nodeRef node to update
     * @param n index of the trait dimension
     * @param d new value
     */
    public void setNodeTrait(NodeRef nodeRef, int n, double d) {
        this.tipTraitParameter.getParameter(nodeRef.getNumber()).setParameterValue(n, d);
    }

    /**
     * Recursively fills {@link #preMeans} and {@link #preP} for {@code nodeRef} and everything
     * below it.
     *
     * <p>The root is seeded with {@link #rootPriorMean} and {@link #rootPriorSampleSize}. Any
     * other node combines its parent's pre-order values with its sister's post-order values
     * (precision-weighted mean), and the summed precision is combined with the reciprocal of
     * the node's own rescaled branch length as {@code a*b/(a+b)}. The sister's
     * {@link #postP}/{@link #postMeans} entries are read, so
     * {@link #doPostOrderTraversal(NodeRef)} has to have been run first.
     *
     * @param nodeRef node at which to start; its parent's entries must already be filled
     *        unless it is the root
     */
    public void doPreOrderTraversal(NodeRef nodeRef) {
        final int nodeNumber = nodeRef.getNumber();

        if (this.treeModel.isRoot(nodeRef)) {
            this.preP[nodeNumber] = this.rootPriorSampleSize;
            System.arraycopy(this.rootPriorMean, 0, this.preMeans[nodeNumber], 0, this.dim);
        } else {
            final int parentNumber = this.treeModel.getParent(nodeRef).getNumber();
            final int sisterNumber = this.getSisterNode(nodeRef).getNumber();
            final double parentP = this.preP[parentNumber];
            final double sisterP = this.postP[sisterNumber];
            final double branchP = 1.0 / this.traitModel.getRescaledBranchLengthForPrecision(nodeRef);
            final double aboveP = parentP + sisterP;

            this.preP[nodeNumber] = aboveP * branchP / (aboveP + branchP);
            for (int i = 0; i < this.dim; ++i) {
                this.preMeans[nodeNumber][i] =
                        (parentP * this.preMeans[parentNumber][i] + sisterP * this.postMeans[sisterNumber][i]) / aboveP;
            }
        }

        if (this.treeModel.isExternal(nodeRef)) {
            return;
        }
        this.doPreOrderTraversal(this.treeModel.getChild(nodeRef, 0));
        this.doPreOrderTraversal(this.treeModel.getChild(nodeRef, 1));
    }

    /**
     * Returns the other child of {@code nodeRef}'s parent, looking only at children 0 and 1.
     *
     * <p>The comparison is by reference: if child 0 is the very same object as
     * {@code nodeRef}, child 1 is returned; otherwise child 0 is returned.
     *
     * @param nodeRef a non-root node
     */
    public NodeRef getSisterNode(NodeRef nodeRef) {
        final NodeRef parent = this.treeModel.getParent(nodeRef);
        final NodeRef firstChild = this.treeModel.getChild(parent, 0);
        final NodeRef secondChild = this.treeModel.getChild(parent, 1);
        return firstChild == nodeRef ? secondChild : firstChild;
    }

    /**
     * Redraws one randomly chosen trait dimension of a node from a univariate normal.
     *
     * <p>The precision matrix used is {@code preP[node]} times the diffusion precision
     * parameter, and the mean vector is {@code preMeans[node]}; both must have been filled by
     * {@link #doPreOrderTraversal(NodeRef)} beforehand. The chosen dimension is drawn
     * conditionally on the node's current values in all other dimensions. No validity check
     * against the latent truncation model is made here.
     *
     * @param nodeRef node to resample
     * @return log density of the old value minus log density of the new value under the
     *         normal distribution that was sampled from
     */
    public double sampleNode(NodeRef nodeRef) {
        final int nodeNumber = nodeRef.getNumber();
        final double[] currentTrait = this.getNodeTrait(nodeRef);

        final double[] priorMean = new double[this.dim];
        System.arraycopy(this.preMeans[nodeNumber], 0, priorMean, 0, this.dim);

        final double scale = this.preP[nodeNumber];
        final double[][] scaledPrecision = new double[this.dim][this.dim];
        for (int row = 0; row < this.dim; ++row) {
            for (int col = 0; col < this.dim; ++col) {
                scaledPrecision[row][col] = scale * this.precisionParam.getParameterValue(row, col);
            }
        }

        // Random draws happen in this order: first the dimension, then the Gaussian deviate.
        final int entry = MathUtils.nextInt(this.dim);
        final double conditionalMean = this.getConditionalMean(entry, scaledPrecision, currentTrait, priorMean);
        final double conditionalSd = Math.sqrt(1.0 / scaledPrecision[entry][entry]);
        final double oldValue = this.getNodeTrait(nodeRef, entry);
        final double newValue = MathUtils.nextGaussian() * conditionalSd + conditionalMean;

        final NormalDistribution proposal = new NormalDistribution(conditionalMean, conditionalSd);
        final double logOld = proposal.logPdf(oldValue);
        final double logNew = proposal.logPdf(newValue);

        this.setNodeTrait(nodeRef, entry, newValue);
        this.traitModel.getTraitParameter().getParameter(nodeNumber).fireParameterChangedEvent();
        return logOld - logNew;
    }

    /**
     * Redraws the trait of a tip from a multivariate normal and repeats until the latent
     * truncation model accepts it, giving up after {@value #MAX_ATTEMPTS} draws.
     *
     * <p>Mean and precision come from the trait model's {@code getConditionalMean} and
     * {@code getConditionalPrecision} for the node. With a mask, only the unmasked dimensions
     * are drawn, from the normal obtained by conditioning on the current values of the fixed
     * dimensions; fixed dimensions keep their values. Every draw is written to the node before
     * it is checked.
     *
     * <p>NOTE(review): if no draw is accepted within the attempt budget, the node is left
     * holding the last (rejected) draw and the method returns as usual - confirm this is the
     * intended fallback.
     *
     * @param nodeRef tip to resample
     * @return log density of the old value minus log density of the final draw, both evaluated
     *         under the (untruncated) normal that was sampled from
     */
    public double sampleNode2(NodeRef nodeRef) {
        final int nodeNumber = nodeRef.getNumber();
        final double[] conditionalMean = this.traitModel.getConditionalMean(nodeNumber);
        final double[][] conditionalPrecision = this.traitModel.getConditionalPrecision(nodeNumber);
        final double[] oldTrait = this.getNodeTrait(nodeRef);

        final double logOld;
        final double logNew;

        if (this.hasMask) {
            // Extract the current values and the precision sub-block of the resampled dimensions.
            final double[] oldUpdated = new double[this.numUpdate];
            final double[][] updatePrecision = new double[this.numUpdate][this.numUpdate];
            for (int i = 0; i < this.numUpdate; ++i) {
                oldUpdated[i] = oldTrait[this.doUpdate[i]];
                for (int j = 0; j < this.numUpdate; ++j) {
                    updatePrecision[i][j] = conditionalPrecision[this.doUpdate[i]][this.doUpdate[j]];
                }
            }
            final double[] updateMean =
                    this.getComponentConditionalMean(conditionalPrecision, oldTrait, conditionalMean, updatePrecision);
            final MultivariateNormalDistribution proposal =
                    new MultivariateNormalDistribution(updateMean, updatePrecision);

            // Work on a copy so that the fixed dimensions carry over and oldTrait stays intact.
            final double[] candidate = oldTrait.clone();
            double[] drawn;
            int attempts = 0;
            boolean valid;
            do {
                drawn = proposal.nextMultivariateNormal();
                for (int i = 0; i < this.numUpdate; ++i) {
                    candidate[this.doUpdate[i]] = drawn[i];
                }
                this.setNodeTrait(nodeRef, candidate);
                valid = this.latentLiability.validTraitForTip(nodeNumber);
                ++attempts;
            } while (!valid && attempts < MAX_ATTEMPTS);

            logOld = proposal.logPdf(oldUpdated);
            logNew = proposal.logPdf(drawn);
        } else {
            final MultivariateNormalDistribution proposal =
                    new MultivariateNormalDistribution(conditionalMean, conditionalPrecision);

            double[] drawn;
            int attempts = 0;
            boolean valid;
            do {
                drawn = proposal.nextMultivariateNormal();
                this.setNodeTrait(nodeRef, drawn);
                valid = this.latentLiability.validTraitForTip(nodeNumber);
                ++attempts;
            } while (!valid && attempts < MAX_ATTEMPTS);

            logOld = proposal.logPdf(oldTrait);
            logNew = proposal.logPdf(drawn);
        }

        this.traitModel.getTraitParameter().getParameter(nodeNumber).fireParameterChangedEvent();
        return logOld - logNew;
    }

    /**
     * Computes the mean of the resampled dimensions conditional on the fixed ones.
     *
     * <p>With the precision matrix split into the update block {@code P_uu} and the cross
     * block {@code P_uf}, the result is
     * {@code mean_u - inverse(P_uu) * P_uf * (value_f - mean_f)}.
     *
     * @param precision full precision matrix
     * @param value current full trait vector (only the fixed entries are read)
     * @param mean full mean vector
     * @param updatePrecision the update block {@code P_uu} of {@code precision}
     * @return conditional mean, one entry per resampled dimension
     * @throws IllegalStateException if the matrix library reports a dimension mismatch
     */
    private double[] getComponentConditionalMean(double[][] precision, double[] value, double[] mean, double[][] updatePrecision) {
        final double[][] crossPrecision = new double[this.numUpdate][this.numFixed];
        for (int i = 0; i < this.numUpdate; ++i) {
            for (int j = 0; j < this.numFixed; ++j) {
                crossPrecision[i][j] = precision[this.doUpdate[i]][this.dontUpdate[j]];
            }
        }

        final double[] fixedDeviation = new double[this.numFixed];
        for (int j = 0; j < this.numFixed; ++j) {
            fixedDeviation[j] = value[this.dontUpdate[j]] - mean[this.dontUpdate[j]];
        }

        final SymmetricMatrix inverseUpdatePrecision = new SymmetricMatrix(updatePrecision).inverse();
        final Matrix crossMatrix = new Matrix(crossPrecision);
        final Vector shift;
        try {
            final Matrix regression = inverseUpdatePrecision.product(crossMatrix);
            shift = regression.product(new Vector(fixedDeviation));
        }
        catch (IllegalDimension illegalDimension) {
            // Continuing with a zero shift would silently produce a wrong conditional mean.
            throw new IllegalStateException(
                    "Dimension mismatch while computing the masked conditional mean (numUpdate="
                            + this.numUpdate + ", numFixed=" + this.numFixed + ")", illegalDimension);
        }

        final double[] conditionalMean = new double[this.numUpdate];
        for (int i = 0; i < this.numUpdate; ++i) {
            conditionalMean[i] = mean[this.doUpdate[i]] - shift.component(i);
        }
        return conditionalMean;
    }

    /**
     * Computes the mean of dimension {@code entry} conditional on all other dimensions:
     * {@code mean[entry] - sum_{i != entry} precision[entry][i] * (value[i] - mean[i]) / precision[entry][entry]}.
     *
     * @param entry dimension being conditioned
     * @param precision full precision matrix
     * @param value current full trait vector
     * @param mean full mean vector
     */
    private double getConditionalMean(int entry, double[][] precision, double[] value, double[] mean) {
        double weightedDeviation = 0.0;
        for (int i = 0; i < this.dim; ++i) {
            if (i != entry) {
                weightedDeviation += precision[entry][i] * (value[i] - mean[i]);
            }
        }
        return mean[entry] - weightedDeviation / precision[entry][entry];
    }

    /** Returns {@code null}; this operator offers no performance suggestion. */
    public String getPerformanceSuggestion() {
        return null;
    }

    @Override
    public String getOperatorName() {
        return LATENT_LIABILITY_GIBBS_OPERATOR;
    }
}

