/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.coalescent;

import dr.evomodel.coalescent.OldAbstractCoalescentLikelihood;
import dr.evomodel.coalescent.OldGMRFSkyrideLikelihood;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.treedatalikelihood.discrete.NodeHeightProxyParameter;
import dr.evomodel.treedatalikelihood.discrete.NodeHeightTransform;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.math.MultivariateFunction;
import dr.math.NumericalDerivative;
import dr.math.matrixAlgebra.Vector;
import dr.xml.Reportable;

/**
 * Gradient provider for the {@link OldGMRFSkyrideLikelihood} log density, taken with respect to
 * either the tree's internal node heights or its coalescent intervals (selected by
 * {@link WrtParameter}).
 *
 * <p>The Hessian methods are not implemented analytically: {@link #getDiagonalHessianLogDensity()}
 * returns zeros and {@link #getHessianLogDensity()} throws. {@link #getReport()} prints the analytic
 * gradient next to a numeric estimate obtained from {@link NumericalDerivative}.
 */
public class GMRFSkyrideGradient
        implements GradientWrtParameterProvider, HessianWrtParameterProvider, Reportable {

    private final OldGMRFSkyrideLikelihood skyrideLikelihood;
    private final WrtParameter wrtParameter;
    private final Parameter parameter;
    private final OldAbstractCoalescentLikelihood.IntervalNodeMapping intervalNodeMapping;
    /** Optional transform; when null, {@link #parameter} is a proxy over the internal node heights. */
    private final NodeHeightTransform nodeHeightTransform;

    /**
     * The skyride log-likelihood viewed as a function of the parameter values; used only by
     * {@link #getReport()} for numeric differentiation.
     *
     * <p>Evaluating it has side effects: it writes the supplied values into the model (through the
     * transform when one is configured, otherwise directly into {@link #parameter}) and marks the
     * likelihood dirty before recomputing it.
     */
    MultivariateFunction numeric1 = new MultivariateFunction() {

        @Override
        public double evaluate(double[] values) {
            if (GMRFSkyrideGradient.this.nodeHeightTransform != null) {
                GMRFSkyrideGradient.this.wrtParameter.update(GMRFSkyrideGradient.this.nodeHeightTransform, values);
            } else {
                // Set every element silently, then fire a single change event for the whole vector.
                for (int i = 0; i < GMRFSkyrideGradient.this.parameter.getDimension(); ++i) {
                    GMRFSkyrideGradient.this.parameter.setParameterValueQuietly(i, values[i]);
                }
                GMRFSkyrideGradient.this.parameter.fireParameterChangedEvent();
            }
            GMRFSkyrideGradient.this.skyrideLikelihood.makeDirty();
            return GMRFSkyrideGradient.this.skyrideLikelihood.getLogLikelihood();
        }

        @Override
        public int getNumArguments() {
            return GMRFSkyrideGradient.this.getParameter().getDimension();
        }

        @Override
        public double getLowerBound(int n) {
            return 0.0;
        }

        @Override
        public double getUpperBound(int n) {
            return Double.POSITIVE_INFINITY;
        }
    };

    /**
     * @param oldGMRFSkyrideLikelihood the skyride likelihood to differentiate
     * @param wrtParameter             which quantity the gradient is taken with respect to
     * @param treeModel                tree whose internal node heights are proxied when
     *                                 {@code nodeHeightTransform} is null
     * @param nodeHeightTransform      optional transform; when non-null its parameter is used and
     *                                 numeric evaluations are routed through
     *                                 {@link WrtParameter#update}
     */
    public GMRFSkyrideGradient(OldGMRFSkyrideLikelihood oldGMRFSkyrideLikelihood, WrtParameter wrtParameter,
                               TreeModel treeModel, NodeHeightTransform nodeHeightTransform) {
        this.skyrideLikelihood = oldGMRFSkyrideLikelihood;
        this.intervalNodeMapping = this.skyrideLikelihood.getIntervalNodeMapping();
        this.wrtParameter = wrtParameter;
        this.nodeHeightTransform = nodeHeightTransform;
        this.parameter = nodeHeightTransform == null
                ? new NodeHeightProxyParameter("internalNodeHeights", treeModel, true)
                : nodeHeightTransform.getParameter();
    }

    @Override
    public Likelihood getLikelihood() {
        return this.skyrideLikelihood;
    }

    @Override
    public Parameter getParameter() {
        return this.parameter;
    }

    @Override
    public int getDimension() {
        return this.getParameter().getDimension();
    }

    /** Returns the analytic gradient, delegating to the configured {@link WrtParameter}. */
    @Override
    public double[] getGradientLogDensity() {
        return this.wrtParameter.getGradientLogDensity(this.skyrideLikelihood, this.intervalNodeMapping);
    }

    /**
     * Builds a diagnostic report comparing the analytic gradient and diagonal Hessian with numeric
     * estimates.
     *
     * <p>Numeric differentiation evaluates {@link #numeric1}, which writes perturbed values into the
     * model. The original parameter values are therefore restored after each numeric pass. The
     * restores sit in {@code finally} blocks so the model is put back even if a numeric pass throws.
     */
    @Override
    public String getReport() {
        double[] savedValues = this.getParameter().getParameterValues();

        double[] numericGradient;
        try {
            numericGradient = NumericalDerivative.gradient(this.numeric1, this.getParameter().getParameterValues());
        } finally {
            this.restoreParameterValues(savedValues);
        }

        double[] numericDiagonalHessian;
        try {
            numericDiagonalHessian = NumericalDerivative.diagonalHessian(this.numeric1, this.getParameter().getParameterValues());
        } finally {
            this.restoreParameterValues(savedValues);
        }

        StringBuilder stringBuilder = new StringBuilder();
        stringBuilder.append("analytic: ").append(new Vector(this.getGradientLogDensity()));
        stringBuilder.append("\n");
        stringBuilder.append("numeric: ").append(new Vector(numericGradient));
        stringBuilder.append("\n");
        stringBuilder.append("analytic diagonal Hessian: ").append(new Vector(this.getDiagonalHessianLogDensity()));
        stringBuilder.append("\n");
        stringBuilder.append("numeric diagonal Hessian: ").append(new Vector(numericDiagonalHessian));
        stringBuilder.append("\n");
        return stringBuilder.toString();
    }

    /** Writes {@code values} back into {@link #getParameter()}, element by element. */
    private void restoreParameterValues(double[] values) {
        Parameter target = this.getParameter();
        for (int i = 0; i < values.length; ++i) {
            target.setParameterValue(i, values[i]);
        }
    }

    /**
     * Placeholder: the analytic diagonal Hessian is not implemented. This method returns an array of
     * zeros of length {@link #getDimension()}.
     */
    @Override
    public double[] getDiagonalHessianLogDensity() {
        return new double[this.getDimension()];
    }

    /** Not implemented; always throws. */
    @Override
    public double[][] getHessianLogDensity() {
        throw new RuntimeException("Not yet implemented!");
    }

    /** Selects the quantity with respect to which the skyride gradient is expressed. */
    public enum WrtParameter {
        /** Gradient with respect to the coalescent intervals. */
        COALESCENT_INTERVAL {

            /**
             * Returns the suffix sums of the unsorted node-height gradient: entry {@code i} is the sum
             * of height-gradient entries {@code i..end}. This matches the chain rule if each height is
             * the cumulative sum of the intervals below it (assumption; confirm against
             * NodeHeightTransform). {@code intervalNodeMapping} is unused here.
             */
            @Override
            double[] getGradientLogDensity(OldGMRFSkyrideLikelihood likelihood,
                                           OldAbstractCoalescentLikelihood.IntervalNodeMapping intervalNodeMapping) {
                double[] heightGradient = getGradientLogDensityWrtUnsortedNodeHeight(likelihood);
                double[] intervalGradient = new double[heightGradient.length];
                double runningSum = 0.0;
                for (int i = heightGradient.length - 1; i >= 0; --i) {
                    runningSum += heightGradient[i];
                    intervalGradient[i] = runningSum;
                }
                return intervalGradient;
            }

            /**
             * Calls {@code nodeHeightTransform.inverse} on the values. The returned array is discarded,
             * so this presumably relies on the transform's side effects on the tree (confirm).
             */
            @Override
            void update(NodeHeightTransform nodeHeightTransform, double[] values) {
                nodeHeightTransform.inverse(values, 0, values.length);
            }
        },

        /** Gradient with respect to the internal node heights, ordered by node number. */
        NODE_HEIGHTS {

            /** Reorders the unsorted node-height gradient by node number via {@code intervalNodeMapping}. */
            @Override
            double[] getGradientLogDensity(OldGMRFSkyrideLikelihood likelihood,
                                           OldAbstractCoalescentLikelihood.IntervalNodeMapping intervalNodeMapping) {
                double[] heightGradient = getGradientLogDensityWrtUnsortedNodeHeight(likelihood);
                return intervalNodeMapping.sortByNodeNumbers(heightGradient);
            }

            /**
             * Calls {@code nodeHeightTransform.transform} on the values. The returned array is
             * discarded, so this presumably relies on the transform's side effects on the tree (confirm).
             */
            @Override
            void update(NodeHeightTransform nodeHeightTransform, double[] values) {
                nodeHeightTransform.transform(values, 0, values.length);
            }
        };

        /** Computes the analytic gradient in this parameterisation. */
        abstract double[] getGradientLogDensity(OldGMRFSkyrideLikelihood likelihood,
                                                OldAbstractCoalescentLikelihood.IntervalNodeMapping intervalNodeMapping);

        /** Pushes {@code values} into the model through {@code nodeHeightTransform} (numeric checks only). */
        abstract void update(NodeHeightTransform nodeHeightTransform, double[] values);

        /**
         * Computes the gradient with respect to node heights, with one entry per COALESCENT interval.
         *
         * <p>Entries are ordered by the interval order of the likelihood, not by node number. For the
         * n-th coalescent interval {@code i}, with {@code k(i)} = lineage count and {@code g} = the
         * pop-size parameter values, the entry is:
         * {@code (-exp(-g[n]) k(i)(k(i)-1) + exp(-g[n+1]) k(i+1)(k(i+1)-1)) / 2}.
         * The second term is dropped for the last coalescent entry and for the last interval.
         *
         * <p>NOTE(review): the second term reads the lineage count of interval {@code i + 1} whether
         * or not that interval is itself COALESCENT. Confirm this is intended when sampling events
         * are interleaved.
         */
        double[] getGradientLogDensityWrtUnsortedNodeHeight(OldGMRFSkyrideLikelihood likelihood) {
            double[] gradient = new double[likelihood.getCoalescentIntervalDimension()];
            double[] popSizes = likelihood.getPopSizeParameter().getParameterValues();
            int n = 0;
            for (int i = 0; i < likelihood.getIntervalCount(); ++i) {
                if (likelihood.getIntervalType(i) != OldAbstractCoalescentLikelihood.CoalescentEventType.COALESCENT) {
                    continue;
                }
                double weight = -Math.exp(-popSizes[n])
                        * (double) likelihood.getLineageCount(i)
                        * (double) (likelihood.getLineageCount(i) - 1);
                if (n < likelihood.getCoalescentIntervalDimension() - 1 && i < likelihood.getIntervalCount() - 1) {
                    weight -= -Math.exp(-popSizes[n + 1])
                            * (double) likelihood.getLineageCount(i + 1)
                            * (double) (likelihood.getLineageCount(i + 1) - 1);
                }
                gradient[n] = weight / 2.0;
                ++n;
            }
            return gradient;
        }
    }
}

