/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.branchratemodel;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTraitProvider;
import dr.evolution.tree.TreeUtils;
import dr.evomodel.branchratemodel.AbstractBranchRateModel;
import dr.evomodel.branchratemodel.ArbitraryBranchRates;
import dr.evomodel.branchratemodel.DifferentiableBranchRates;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.IntegratedSquaredGPApproximation;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * Time-varying branch rate model. Each branch rate is the average of an
 * {@link IntegratedSquaredGPApproximation} over the branch's rescaled time interval, multiplied by
 * half of the root height recorded at construction.
 *
 * <p>Node heights are mapped onto a rescaled time axis by
 * {@code t(h) = (rootHeight / 2 - h) / (currentRootHeight / 2)}. Here {@code rootHeight} is the
 * root height captured in the constructor and {@code currentRootHeight} is the root height at
 * evaluation time. While the two coincide, the root maps to -1 and height 0 maps to +1.
 */
public class NonParametricBranchRateModel extends AbstractBranchRateModel
        implements DifferentiableBranchRates, Citable {

    private final Tree tree;
    private final Parameter coefficients;
    private final double degree;
    private final double boundary;
    private final Parameter mean;
    private final Parameter marginalVariance;
    private final Parameter lengthScale;

    // Cached per-branch values (before the rootHeight / 2 scaling applied in getBranchRate),
    // indexed by getParameterIndexFromNode.
    private boolean nodeRatesKnown;
    private boolean storedNodeRatesKnown;
    private double[] nodeRates;
    private double[] storedNodeRates;

    // Scratch buffer for per-coefficient partial derivatives. It is indexed by coefficient, so it
    // must hold at least coefficients.getDimension() entries.
    private double[] temp;

    // Root height captured at construction.
    // NOTE(review): never refreshed, even if the tree's root height later changes — confirm this
    // fixed time scale is intended.
    private final double rootHeight;

    /**
     * Creates the model and registers the tree (when it is a {@link TreeModel}) and all
     * parameters as dependencies.
     *
     * @param name             model name passed to the superclass
     * @param tree             tree whose branches receive rates
     * @param coefficients     coefficients handed to the GP approximation
     * @param degree           passed to {@link IntegratedSquaredGPApproximation} as-is
     * @param boundary         passed to {@link IntegratedSquaredGPApproximation} as-is
     * @param mean             mean parameter; its first untransformed value is used
     * @param marginalVariance marginal-variance parameter; its first value is used
     * @param lengthScale      length-scale parameter; its first value is used
     */
    public NonParametricBranchRateModel(String name, Tree tree, Parameter coefficients,
                                        double degree, double boundary, Parameter mean,
                                        Parameter marginalVariance, Parameter lengthScale) {
        super(name);
        this.tree = tree;
        this.coefficients = coefficients;
        this.degree = degree;
        this.boundary = boundary;
        this.mean = mean;
        this.marginalVariance = marginalVariance;
        this.lengthScale = lengthScale;
        if (tree instanceof TreeModel) {
            this.addModel((TreeModel) tree);
        }
        this.addVariable(coefficients);
        this.addVariable(mean);
        this.addVariable(marginalVariance);
        this.addVariable(lengthScale);
        this.nodeRatesKnown = false;
        // One entry per non-root node, i.e. per branch.
        this.nodeRates = new double[tree.getNodeCount() - 1];
        // Gradient writes temp[i] for every coefficient i, so the buffer is sized by coefficient
        // count rather than branch count (the latter overflows when there are more coefficients
        // than branches).
        this.temp = new double[coefficients.getDimension()];
        this.rootHeight = tree.getNodeHeight(tree.getRoot());
    }

    /**
     * Returns the rate of the branch above {@code nodeRef}. If any dependency changed since the
     * last evaluation, all cached rates are recomputed first.
     */
    @Override
    public double getBranchRate(Tree tree, NodeRef nodeRef) {
        assert tree == this.tree;
        if (!nodeRatesKnown) {
            calculateNodeGeneric(new TreeTraversal.Rate(nodeRates));
            nodeRatesKnown = true;
        }
        return nodeRates[getParameterIndexFromNode(nodeRef)] * (rootHeight / 2.0);
    }

    /**
     * Chains a per-branch gradient into a gradient with respect to the coefficients.
     *
     * <p>For every branch b and coefficient i, adds
     * {@code gradient[b] * dIntegral_b/dc_i / branchLength_b} to the result.
     * NOTE(review): {@link #getBranchRate} scales rates by {@code rootHeight / 2}, but that factor
     * is not applied here. If {@code gradient} holds d(logL)/d(branch rate), the result may need
     * that factor — confirm against a numerical gradient.
     *
     * @param gradient per-branch gradient, indexed like {@link #getParameterIndexFromNode}
     * @param value    unused here
     * @param from     must be 0
     * @param to       must be {@code coefficients.getDimension() - 1}
     * @return a new array of length {@code coefficients.getDimension()}
     */
    @Override
    public double[] updateGradientLogDensity(double[] gradient, double[] value, int from, int to) {
        assert from == 0;
        assert to == coefficients.getDimension() - 1;
        final int dim = coefficients.getDimension();
        if (temp.length < dim) {
            // e.g. the coefficient dimension changed after construction
            temp = new double[dim];
        }
        // Java zero-initialises new arrays, so no explicit fill is needed.
        final double[] result = new double[dim];
        calculateNodeGeneric(new TreeTraversal.Gradient(result, gradient, temp));
        return result;
    }

    /**
     * Visits every branch below the root, applying {@code treeTraversal} to its rescaled
     * interval. The approximation is built once per traversal: the parameters it reads cannot
     * change while the traversal runs.
     */
    private void calculateNodeGeneric(TreeTraversal treeTraversal) {
        final NodeRef root = tree.getRoot();
        final double currentRootHeight = tree.getNodeHeight(root);
        final double halfCurrentRootHeight = currentRootHeight / 2.0;
        final IntegratedSquaredGPApproximation approximation = createApproximation();
        traverseTreeByBranchGeneric(currentRootHeight, tree.getChild(root, 0), treeTraversal,
                halfCurrentRootHeight, approximation);
        traverseTreeByBranchGeneric(currentRootHeight, tree.getChild(root, 1), treeTraversal,
                halfCurrentRootHeight, approximation);
    }

    /** Builds the GP approximation from the current coefficient and hyper-parameter values. */
    private IntegratedSquaredGPApproximation createApproximation() {
        final int dim = coefficients.getDimension();
        final double[] values = new double[dim];
        for (int i = 0; i < dim; ++i) {
            values[i] = coefficients.getParameterValue(i);
        }
        return new IntegratedSquaredGPApproximation(values, degree, boundary,
                mean.getParameterUntransformedValue(0),
                marginalVariance.getParameterValue(0),
                lengthScale.getParameterValue(0));
    }

    /**
     * Pre-order recursion over the subtree rooted at {@code node}.
     *
     * <p>For the branch from {@code node}'s parent (at {@code parentHeight}) down to {@code node},
     * the rescaled interval [start, end] is computed. The traversal is invoked only when the
     * interval has positive length.
     */
    private void traverseTreeByBranchGeneric(double parentHeight, NodeRef node,
                                             TreeTraversal treeTraversal,
                                             double halfCurrentRootHeight,
                                             IntegratedSquaredGPApproximation approximation) {
        final double nodeHeight = tree.getNodeHeight(node);
        final double start = (rootHeight / 2.0 - parentHeight) / halfCurrentRootHeight;
        final double end = (rootHeight / 2.0 - nodeHeight) / halfCurrentRootHeight;
        if (end > start) {
            treeTraversal.calculate(getParameterIndexFromNode(node), start, end, approximation);
        }
        if (!tree.isExternal(node)) {
            traverseTreeByBranchGeneric(nodeHeight, tree.getChild(node, 0), treeTraversal,
                    halfCurrentRootHeight, approximation);
            traverseTreeByBranchGeneric(nodeHeight, tree.getChild(node, 1), treeTraversal,
                    halfCurrentRootHeight, approximation);
        }
    }

    @Override
    public double getBranchRateDifferential(Tree tree, NodeRef nodeRef) {
        return 1.0;
    }

    @Override
    public double getBranchRateSecondDifferential(Tree tree, NodeRef nodeRef) {
        throw new RuntimeException("Not yet implemented");
    }

    /**
     * Invalidates cached rates when the tree changes. The constructor registers only the tree as
     * a sub-model, so any other source is rejected before state is touched.
     */
    @Override
    public void handleModelChangedEvent(Model model, Object object, int index) {
        if (model != tree) {
            throw new IllegalArgumentException("How did we get here?");
        }
        nodeRatesKnown = false;
        fireModelChanged();
    }

    /** Invalidates cached rates when any registered parameter changes. */
    @Override
    protected final void handleVariableChangedEvent(Variable variable, int index,
                                                    Variable.ChangeType changeType) {
        nodeRatesKnown = false;
        fireModelChanged();
    }

    /** Copies the cached rates and their validity flag into the stored slot. */
    @Override
    protected void storeState() {
        if (storedNodeRates == null) {
            storedNodeRates = new double[nodeRates.length];
        }
        System.arraycopy(nodeRates, 0, storedNodeRates, 0, nodeRates.length);
        storedNodeRatesKnown = nodeRatesKnown;
    }

    /** Swaps the current and stored rate buffers (no copy) and restores the validity flag. */
    @Override
    protected void restoreState() {
        final double[] swap = nodeRates;
        nodeRates = storedNodeRates;
        storedNodeRates = swap;
        nodeRatesKnown = storedNodeRatesKnown;
    }

    @Override
    protected void acceptState() {
    }

    @Override
    public Parameter getRateParameter() {
        return coefficients;
    }

    /**
     * Maps a node to a dense branch index by skipping the root: node numbers above the root's
     * number are shifted down by one.
     */
    @Override
    public int getParameterIndexFromNode(NodeRef nodeRef) {
        int index = nodeRef.getNumber();
        if (index > tree.getRoot().getNumber()) {
            --index;
        }
        return index;
    }

    @Override
    public ArbitraryBranchRates.BranchRateTransform getTransform() {
        throw new RuntimeException("Not yet implemented");
    }

    @Override
    public double[] updateDiagonalHessianLogDensity(double[] gradient, double[] hessian,
                                                    double[] value, int from, int to) {
        throw new RuntimeException("Not yet implemented");
    }

    @Override
    public Citation.Category getCategory() {
        return Citation.Category.MOLECULAR_CLOCK;
    }

    @Override
    public String getDescription() {
        return "Time-varying branch rate model";
    }

    @Override
    public List<Citation> getCitations() {
        return Collections.singletonList(new Citation(new Author[]{
                new Author("P", "Datta"),
                new Author("P", "Lemey"),
                new Author("MA", "Suchard")}, Citation.Status.IN_PREPARATION));
    }

    /** Returns the tree in Newick form, annotated with this model's traits. */
    @Override
    public String toString() {
        return TreeUtils.newick(tree, new TreeTraitProvider[]{this});
    }

    /** Per-branch action applied by {@link #calculateNodeGeneric}. */
    static interface TreeTraversal {
        /**
         * @param index         branch index (see {@link #getParameterIndexFromNode})
         * @param start         rescaled time at the parent end of the branch
         * @param end           rescaled time at the child end of the branch ({@code end > start})
         * @param approximation GP approximation shared across the whole traversal
         */
        void calculate(int index, double start, double end,
                       IntegratedSquaredGPApproximation approximation);

        /** Stores integral / interval-length for each branch into the supplied array. */
        public static class Rate implements TreeTraversal {
            private final double[] nodeRates;

            Rate(double[] nodeRates) {
                this.nodeRates = nodeRates;
            }

            @Override
            public void calculate(int index, double start, double end,
                                  IntegratedSquaredGPApproximation approximation) {
                final double length = end - start;
                nodeRates[index] = approximation.getIntegral(start, end) / length;
            }
        }

        /**
         * Accumulates, into {@code gradientCoefficients}, each branch's per-coefficient
         * derivative (divided by interval length) weighted by that branch's entry in
         * {@code gradientNodes}. {@code temp} must be at least as long as
         * {@code gradientCoefficients}.
         */
        public static class Gradient implements TreeTraversal {
            private final double[] gradientCoefficients;
            private final double[] gradientNodes;
            private final double[] temp;

            Gradient(double[] gradientCoefficients, double[] gradientNodes, double[] temp) {
                this.gradientCoefficients = gradientCoefficients;
                this.gradientNodes = gradientNodes;
                this.temp = temp;
            }

            @Override
            public void calculate(int index, double start, double end,
                                  IntegratedSquaredGPApproximation approximation) {
                final double length = end - start;
                final double nodeGradient = gradientNodes[index];
                for (int i = 0; i < gradientCoefficients.length; ++i) {
                    temp[i] = approximation.getGradientWrtCoefficient(start, end, i) / length;
                    gradientCoefficients[i] += temp[i] * nodeGradient;
                }
            }
        }
    }
}

