/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.discrete;

import dr.evomodel.treedatalikelihood.discrete.NodeHeightGradientForDiscreteTrait;
import dr.evomodel.treedatalikelihood.discrete.NodeHeightTransform;
import dr.inference.model.Parameter;
import dr.math.MultivariateFunction;
import dr.math.NumericalDerivative;
import dr.math.matrixAlgebra.Vector;
import dr.util.Transform;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.ElementRule;
import dr.xml.Reportable;
import dr.xml.XMLObject;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.ArrayList;

/**
 * Reportable diagnostic that prints gradients pushed through a
 * {@link NodeHeightTransform} (labelled "Peeling" in the report) next to
 * finite-difference gradients obtained from {@link NumericalDerivative}
 * (labelled "Numeric").
 *
 * <p>Three comparisons are reported: the unweighted log-likelihood, the
 * log-likelihood minus the log-Jacobian of the node-height transform, and the
 * log-likelihood minus the log-Jacobian of the composed "real line" transform
 * assembled in the constructor.
 *
 * <p>NOTE(review): evaluating the numeric functions changes model state (they
 * call {@code inverse(...)} or set node heights directly) and nothing in this
 * class restores it afterwards; the order of the computations in
 * {@link #getReport()} is therefore significant.
 */
public class NodeHeightTransformTest
implements Reportable {
    /** XML element name under which {@link #PARSER} is registered. */
    private static final String NODE_HEIGHT_TRANSFORM_TEST = "nodeHeightTransformTest";

    /**
     * XML parser that builds a {@link NodeHeightTransformTest} from one child of
     * each of the types {@link NodeHeightTransform},
     * {@link NodeHeightGradientForDiscreteTrait} and {@link Parameter}.
     */
    public static AbstractXMLObjectParser PARSER = new AbstractXMLObjectParser(){

        @Override
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            final NodeHeightTransform transformChild = (NodeHeightTransform)xMLObject.getChild(NodeHeightTransform.class);
            final NodeHeightGradientForDiscreteTrait gradientChild = (NodeHeightGradientForDiscreteTrait)xMLObject.getChild(NodeHeightGradientForDiscreteTrait.class);
            final Parameter parameterChild = (Parameter)xMLObject.getChild(Parameter.class);
            return new NodeHeightTransformTest(transformChild, gradientChild, parameterChild);
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
                new ElementRule(NodeHeightTransform.class),
                new ElementRule(NodeHeightGradientForDiscreteTrait.class),
                new ElementRule(Parameter.class)
            };
            return rules;
        }

        @Override
        public String getParserDescription() {
            // No description is provided for this parser.
            return null;
        }

        @Override
        public Class getReturnType() {
            return NodeHeightTransformTest.class;
        }

        @Override
        public String getParserName() {
            return NodeHeightTransformTest.NODE_HEIGHT_TRANSFORM_TEST;
        }
    };

    private final NodeHeightTransform nodeHeightTransform;
    private final NodeHeightGradientForDiscreteTrait nodeHeightGradient;
    /** Parameter handed to the constructor; stored but not read again in this class. */
    private final Parameter ratios;
    /** Composition of an element-wise transform array with {@link #nodeHeightTransform}. */
    private final Transform.ComposeMultivariable realLineTransform;

    /**
     * Objective for the unweighted comparison: applies the node-height
     * transform's {@code inverse} to the argument and returns the current
     * log-likelihood. Every argument is bounded to [0, 1].
     */
    protected MultivariateFunction numericUnweighted = new MultivariateFunction(){

        @Override
        public double evaluate(double[] dArray) {
            // presumably inverse(...) writes the implied node heights into the model — confirm
            NodeHeightTransformTest.this.nodeHeightTransform.inverse(dArray);
            return NodeHeightTransformTest.this.currentLogLikelihood();
        }

        @Override
        public int getNumArguments() {
            return NodeHeightTransformTest.this.nodeHeightTransform.getDimension();
        }

        @Override
        public double getLowerBound(int n) {
            return 0.0;
        }

        @Override
        public double getUpperBound(int n) {
            return 1.0;
        }
    };

    /**
     * Objective for the weighted comparison: like {@link #numericUnweighted},
     * but the node-height transform's log-Jacobian at the argument is
     * subtracted from the log-likelihood. Every argument is bounded to [0, 1].
     */
    protected MultivariateFunction numericWeighted = new MultivariateFunction(){

        @Override
        public double evaluate(double[] dArray) {
            NodeHeightTransformTest.this.nodeHeightTransform.inverse(dArray);
            // Log-likelihood is read first, then the log-Jacobian, then the difference is taken.
            final double logLikelihood = NodeHeightTransformTest.this.currentLogLikelihood();
            final double logJacobian = NodeHeightTransformTest.this.nodeHeightTransform.getLogJacobian(dArray);
            return logLikelihood - logJacobian;
        }

        @Override
        public int getNumArguments() {
            return NodeHeightTransformTest.this.nodeHeightTransform.getDimension();
        }

        @Override
        public double getLowerBound(int n) {
            return 0.0;
        }

        @Override
        public double getUpperBound(int n) {
            return 1.0;
        }
    };

    /**
     * Objective for the "multiple weighted" comparison: maps the unbounded
     * argument through the inverse of {@link #realLineTransform}, writes the
     * result into the node-height parameter, marks the likelihood dirty and
     * returns the log-likelihood minus the composed transform's log-Jacobian.
     * Arguments are unbounded.
     */
    protected MultivariateFunction numericMultipleWeighted = new MultivariateFunction(){

        @Override
        public double evaluate(double[] dArray) {
            final double[] heights = NodeHeightTransformTest.this.realLineTransform.inverse(dArray, 0, dArray.length);
            final Parameter nodeHeights = NodeHeightTransformTest.this.nodeHeightTransform.getNodeHeights();
            // Set every value quietly, then fire a single change event below.
            int index = 0;
            while (index < heights.length) {
                nodeHeights.setParameterValueQuietly(index, heights[index]);
                ++index;
            }
            NodeHeightTransformTest.this.nodeHeightTransform.getNodeHeights().fireParameterChangedEvent();
            NodeHeightTransformTest.this.nodeHeightGradient.getLikelihood().makeDirty();
            final double logLikelihood = NodeHeightTransformTest.this.currentLogLikelihood();
            final double logJacobian = NodeHeightTransformTest.this.realLineTransform.logJacobian(heights, 0, dArray.length);
            return logLikelihood - logJacobian;
        }

        @Override
        public int getNumArguments() {
            return NodeHeightTransformTest.this.nodeHeightCount();
        }

        @Override
        public double getLowerBound(int n) {
            return Double.NEGATIVE_INFINITY;
        }

        @Override
        public double getUpperBound(int n) {
            return Double.POSITIVE_INFINITY;
        }
    };

    /**
     * Stores the collaborators and assembles {@link #realLineTransform}.
     *
     * <p>The element-wise part consists of one {@code LogitTransform} per
     * dimension of {@code parameter}, preceded by a single {@code LogTransform}
     * when the transform's own parameter has a different dimension than
     * {@code parameter}.
     *
     * @param nodeHeightTransform transform between node heights and its parameter
     * @param nodeHeightGradientForDiscreteTrait provider of the analytic gradient and likelihood
     * @param parameter parameter whose dimension fixes the number of logit transforms
     */
    public NodeHeightTransformTest(NodeHeightTransform nodeHeightTransform, NodeHeightGradientForDiscreteTrait nodeHeightGradientForDiscreteTrait, Parameter parameter) {
        this.nodeHeightTransform = nodeHeightTransform;
        this.nodeHeightGradient = nodeHeightGradientForDiscreteTrait;
        this.ratios = parameter;

        final ArrayList<Transform> elementWise = new ArrayList<Transform>();
        // presumably the extra leading entry is a positive-valued height component — confirm
        final boolean dimensionsDiffer = nodeHeightTransform.getParameter().getDimension() != parameter.getDimension();
        if (dimensionsDiffer) {
            elementWise.add(new Transform.LogTransform());
        }
        int added = 0;
        while (added < parameter.getDimension()) {
            elementWise.add(new Transform.LogitTransform());
            ++added;
        }
        final Transform.Array outer = new Transform.Array(elementWise, nodeHeightTransform.getParameter());
        this.realLineTransform = new Transform.ComposeMultivariable(outer, nodeHeightTransform);
    }

    /**
     * Builds the report: the gradient provider's own report followed by three
     * "Peeling" versus "Numeric" gradient comparisons.
     *
     * @return the concatenated report text
     */
    @Override
    public String getReport() {
        final String gradientReport = this.nodeHeightGradient.getReport();
        final double[] nodeHeightGradientValues = this.nodeHeightGradient.getGradientLogDensity();
        final int gradientLength = nodeHeightGradientValues.length;

        // Analytic and numeric results are computed alternately, in this exact
        // order, because the numeric evaluations modify model state.
        final double[] peelingUnweighted = this.nodeHeightTransform.updateGradientUnWeightedLogDensity(nodeHeightGradientValues, currentNodeHeightValues(), 0, gradientLength);
        final double[] unweightedStart = this.nodeHeightTransform.transform(currentNodeHeightValues());
        final double[] finiteDiffUnweighted = NumericalDerivative.gradient(this.numericUnweighted, unweightedStart);

        final double[] peelingWeighted = this.nodeHeightTransform.updateGradientLogDensity(nodeHeightGradientValues, currentNodeHeightValues(), 0, gradientLength);
        final double[] weightedStart = this.nodeHeightTransform.transform(currentNodeHeightValues());
        final double[] finiteDiffWeighted = NumericalDerivative.gradient(this.numericWeighted, weightedStart);

        final StringBuilder report = new StringBuilder();
        appendComparison(report, "\nGradient wrt Unweighted LogLikelihood:", peelingUnweighted, finiteDiffUnweighted);
        appendComparison(report, "\nGradient wrt Weighted LogLikelihood:", peelingWeighted, finiteDiffWeighted);

        final double[] peelingMultiple = this.realLineTransform.updateGradientLogDensity(nodeHeightGradientValues, currentNodeHeightValues(), 0, nodeHeightCount());
        final double[] multipleStart = this.realLineTransform.transform(currentNodeHeightValues(), 0, nodeHeightCount());
        final double[] finiteDiffMultiple = NumericalDerivative.gradient(this.numericMultipleWeighted, multipleStart);
        appendComparison(report, "\nGradient wrt Multiple Weighted LogLikelihood:", peelingMultiple, finiteDiffMultiple);

        return gradientReport + report.toString();
    }

    /** Returns the log-likelihood currently reported by the gradient provider's likelihood. */
    private double currentLogLikelihood() {
        return this.nodeHeightGradient.getLikelihood().getLogLikelihood();
    }

    /** Returns the current values of the transform's node-height parameter. */
    private double[] currentNodeHeightValues() {
        return this.nodeHeightTransform.getNodeHeights().getParameterValues();
    }

    /** Returns the dimension of the transform's node-height parameter. */
    private int nodeHeightCount() {
        return this.nodeHeightTransform.getNodeHeights().getDimension();
    }

    /** Appends one heading followed by its "Peeling" and "Numeric" gradient lines. */
    private static void appendComparison(StringBuilder report, String heading, double[] peeling, double[] numeric) {
        report.append(heading);
        report.append("\nPeeling: ").append(new Vector(peeling));
        report.append("\nNumeric: ").append(new Vector(numeric));
    }
}

