/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.discrete;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.branchratemodel.ArbitraryBranchRates;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.ProcessSimulation;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.discrete.DiscreteTraitBranchRateDelegate;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.math.MachineAccuracy;
import dr.math.MultivariateFunction;
import dr.xml.Reportable;

/**
 * Supplies the gradient, and the diagonal of the Hessian, of a
 * {@link TreeDataLikelihood} log density with respect to a rate parameter.
 *
 * <p>Per-branch derivative values are read from tree traits registered on the
 * likelihood and multiplied by a chain-rule factor ({@link #getChainGradient}
 * and {@link #getChainSecondDerivative}), which subclasses may override.
 */
public class DiscreteTraitBranchRateGradient
        implements GradientWrtParameterProvider,
        HessianWrtParameterProvider,
        Reportable,
        Loggable {

    protected final TreeDataLikelihood treeDataLikelihood;
    /** Tree trait from which the per-branch gradient values are read. */
    protected final TreeTrait treeTraitProvider;
    protected final Tree tree;
    /** When {@code true}, {@link #getReport()} also includes the Hessian check. */
    protected final boolean useHessian;
    protected final Parameter rateParameter;
    /** The likelihood's branch-rate model if it is an {@link ArbitraryBranchRates}, otherwise {@code null}. */
    protected final ArbitraryBranchRates branchRateModel;

    /**
     * Log-likelihood expressed as a function of the rate parameter values.
     *
     * <p>{@code evaluate} writes the supplied values into {@link #rateParameter}
     * and does not restore the previous values afterwards.
     */
    protected MultivariateFunction numeric1 = new MultivariateFunction() {

        @Override
        public double evaluate(double[] argument) {
            for (int i = 0; i < argument.length; ++i) {
                rateParameter.setParameterValue(i, argument[i]);
            }
            return treeDataLikelihood.getLogLikelihood();
        }

        @Override
        public int getNumArguments() {
            return rateParameter.getDimension();
        }

        @Override
        public double getLowerBound(int index) {
            return 0.0;
        }

        @Override
        public double getUpperBound(int index) {
            return Double.POSITIVE_INFINITY;
        }
    };

    // Not referenced anywhere in this class; retained from the original declaration.
    private static final boolean DEBUG = true;
    /** Whether calls to {@link #getGradientLogDensity()} are counted. */
    protected static final boolean COUNT_TOTAL_OPERATIONS = true;
    /** Number of completed {@link #getGradientLogDensity()} calls. */
    protected long getGradientLogDensityCount = 0L;

    /**
     * Creates the gradient provider and, if the likelihood does not already
     * expose the gradient trait for {@code string}, registers a new
     * {@link DiscreteTraitBranchRateDelegate} on it.
     *
     * @param string                       trait name passed to {@link DiscreteTraitBranchRateDelegate}
     * @param treeDataLikelihood           likelihood to differentiate; must not be {@code null}
     * @param beagleDataLikelihoodDelegate delegate handed to a newly created {@link DiscreteTraitBranchRateDelegate}
     * @param parameter                    rate parameter the derivatives are taken with respect to
     * @param bl                           whether {@link #getReport()} should also check the Hessian
     * @throws NullPointerException  if {@code treeDataLikelihood} is {@code null}
     * @throws IllegalStateException if the gradient trait is still unavailable after registration
     * @throws RuntimeException      if the likelihood delegate reports a trait count other than 1
     */
    public DiscreteTraitBranchRateGradient(String string, TreeDataLikelihood treeDataLikelihood, BeagleDataLikelihoodDelegate beagleDataLikelihoodDelegate, Parameter parameter, boolean bl) {
        // Explicit check instead of `assert`: assertions are disabled unless the JVM runs with -ea.
        if (treeDataLikelihood == null) {
            throw new NullPointerException("treeDataLikelihood must not be null");
        }
        this.treeDataLikelihood = treeDataLikelihood;
        this.tree = treeDataLikelihood.getTree();
        this.rateParameter = parameter;
        this.useHessian = bl;

        BranchRateModel rateModel = treeDataLikelihood.getBranchRateModel();
        this.branchRateModel = rateModel instanceof ArbitraryBranchRates ? (ArbitraryBranchRates) rateModel : null;

        String gradientTraitName = DiscreteTraitBranchRateDelegate.getName(string);
        if (treeDataLikelihood.getTreeTrait(gradientTraitName) == null) {
            // The trait is not registered yet: create the delegate and expose its traits.
            DiscreteTraitBranchRateDelegate delegate = new DiscreteTraitBranchRateDelegate(string, treeDataLikelihood.getTree(), beagleDataLikelihoodDelegate);
            ProcessSimulation simulation = new ProcessSimulation(treeDataLikelihood, delegate);
            treeDataLikelihood.addTraits(simulation.getTreeTraits());
        }

        TreeTrait gradientTrait = treeDataLikelihood.getTreeTrait(gradientTraitName);
        if (gradientTrait == null) {
            // Fail here rather than with an unexplained NullPointerException on first use.
            throw new IllegalStateException("Tree trait '" + gradientTraitName + "' is not available from the tree data likelihood");
        }
        this.treeTraitProvider = gradientTrait;

        int traitCount = treeDataLikelihood.getDataLikelihoodDelegate().getTraitCount();
        if (traitCount != 1) {
            throw new RuntimeException("Not yet implemented for >1 traits");
        }
    }

    @Override
    public Likelihood getLikelihood() {
        return treeDataLikelihood;
    }

    @Override
    public Parameter getParameter() {
        return rateParameter;
    }

    @Override
    public int getDimension() {
        return getParameter().getDimension();
    }

    /**
     * Computes one second-derivative entry per non-root node:
     * {@code hessian * chainGradient^2 + gradient * chainSecondDerivative}.
     *
     * @return array of length {@code nodeCount - 1}, indexed by {@link #getParameterIndexFromNode}
     * @throws IllegalStateException if the Hessian tree trait is not registered on the likelihood
     */
    @Override
    public double[] getDiagonalHessianLogDensity() {
        final int nodeCount = tree.getNodeCount();
        // NOTE(review): assumes getParameterIndexFromNode maps every non-root node into
        // [0, nodeCount - 2]; confirm for the case where branchRateModel is null.
        double[] result = new double[nodeCount - 1];

        TreeTrait hessianTrait = treeDataLikelihood.getTreeTrait(DiscreteTraitBranchRateDelegate.HESSIAN_TRAIT_NAME);
        if (hessianTrait == null) {
            throw new IllegalStateException("Tree trait '" + DiscreteTraitBranchRateDelegate.HESSIAN_TRAIT_NAME + "' is not available from the tree data likelihood");
        }
        // Passing a null node requests the values for the whole tree in a single call.
        double[] branchHessian = (double[]) hessianTrait.getTrait(tree, null);
        double[] branchGradient = (double[]) treeTraitProvider.getTrait(tree, null);

        // The trait arrays are read sequentially, skipping the root.
        int traitIndex = 0;
        for (int i = 0; i < nodeCount; ++i) {
            NodeRef node = tree.getNode(i);
            if (tree.isRoot(node)) {
                continue;
            }
            int destinationIndex = getParameterIndexFromNode(node);
            double chainGradient = getChainGradient(tree, node);
            double chainSecondDerivative = getChainSecondDerivative(tree, node);
            result[destinationIndex] = branchHessian[traitIndex] * chainGradient * chainGradient
                    + branchGradient[traitIndex] * chainSecondDerivative;
            ++traitIndex;
        }
        return result;
    }

    /**
     * The full Hessian is not supported.
     *
     * @throws RuntimeException always
     */
    @Override
    public double[][] getHessianLogDensity() {
        throw new RuntimeException("Not yet implemented");
    }

    /**
     * Computes one gradient entry per non-root node as the trait value for that
     * node multiplied by {@link #getChainGradient}.
     *
     * @return array of length {@code nodeCount - 1}, indexed by {@link #getParameterIndexFromNode}
     */
    @Override
    public double[] getGradientLogDensity() {
        final int nodeCount = tree.getNodeCount();
        // NOTE(review): assumes getParameterIndexFromNode maps every non-root node into
        // [0, nodeCount - 2]; confirm for the case where branchRateModel is null.
        double[] result = new double[nodeCount - 1];

        // Passing a null node requests the values for the whole tree in a single call.
        double[] branchGradient = (double[]) treeTraitProvider.getTrait(tree, null);

        // The trait array is read sequentially, skipping the root.
        int traitIndex = 0;
        for (int i = 0; i < nodeCount; ++i) {
            NodeRef node = tree.getNode(i);
            if (tree.isRoot(node)) {
                continue;
            }
            result[getParameterIndexFromNode(node)] = branchGradient[traitIndex] * getChainGradient(tree, node);
            ++traitIndex;
        }

        if (COUNT_TOTAL_OPERATIONS) {
            ++getGradientLogDensityCount;
        }
        return result;
    }

    /**
     * Chain-rule factor applied to the trait gradient of a branch.
     *
     * @return the branch length of {@code nodeRef} in {@code tree}
     */
    protected double getChainGradient(Tree tree, NodeRef nodeRef) {
        return tree.getBranchLength(nodeRef);
    }

    /**
     * Second-order chain-rule factor used by {@link #getDiagonalHessianLogDensity()}.
     *
     * @return {@code 0.0} in this implementation
     */
    protected double getChainSecondDerivative(Tree tree, NodeRef nodeRef) {
        return 0.0;
    }

    /**
     * Maps a node to its slot in the result arrays: the index supplied by the
     * {@link ArbitraryBranchRates} model when there is one, otherwise the node number.
     */
    protected int getParameterIndexFromNode(NodeRef nodeRef) {
        if (branchRateModel == null) {
            return nodeRef.getNumber();
        }
        return branchRateModel.getParameterIndexFromNode(nodeRef);
    }

    /**
     * Checks that no entry has an absolute value below
     * {@code MachineAccuracy.SQRT_EPSILON * 1.2}.
     *
     * @return {@code false} as soon as one such entry is found, {@code true} otherwise
     */
    protected boolean valuesAreSufficientlyLarge(double[] dArray) {
        final double threshold = MachineAccuracy.SQRT_EPSILON * 1.2;
        for (double value : dArray) {
            if (Math.abs(value) < threshold) {
                return false;
            }
        }
        return true;
    }

    /**
     * Builds a report consisting of the gradient check, optionally the Hessian
     * check, then the call count, the trait provider and the likelihood report.
     */
    @Override
    public String getReport() {
        // Captured first, so the count and likelihood report reflect the state
        // before the checks below are run.
        StringBuilder statistics = new StringBuilder();
        statistics.append("\n\tgetGradientLogDensityCount = ").append(getGradientLogDensityCount).append("\n");
        statistics.append(treeTraitProvider.toString()).append("\n");
        statistics.append(treeDataLikelihood.getReport());

        StringBuilder report = new StringBuilder();
        report.append(GradientWrtParameterProvider.getReportAndCheckForError(this, 0.0, Double.POSITIVE_INFINITY, null));
        if (useHessian) {
            report.append("Hessian\n");
            report.append(HessianWrtParameterProvider.getReportAndCheckForError(this, null));
        }
        report.append(statistics);
        return report.toString();
    }

    /**
     * Exposes a single log column whose value is {@link #getReport()} preceded
     * by a newline; the report is produced when the column value is converted
     * to a string.
     */
    @Override
    public LogColumn[] getColumns() {
        Object lazyReport = new Object() {

            @Override
            public String toString() {
                return "\n" + getReport();
            }
        };
        return new LogColumn[]{new LogColumn.Default("gradient report", lazyReport)};
    }
}

