/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.hmc;

import dr.inference.model.DerivativeOrder;
import dr.inference.model.DerivativeProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.math.MultivariateFunction;
import dr.math.NumericalDerivative;
import dr.math.matrixAlgebra.Vector;
import dr.xml.Reportable;
import java.util.List;
import java.util.logging.Logger;

/**
 * Supplies derivatives of a likelihood's log density with respect to a {@link Parameter}, together with
 * static helpers for comparing those analytic derivatives against finite-difference estimates.
 */
public interface DerivativeWrtParameterProvider {

    /**
     * Relative tolerance constant (0.1). It is not referenced inside this interface; presumably callers
     * pass it as the {@code tolerance} argument of {@link #getReportAndCheckForError} — TODO confirm.
     */
    public static final Double TOLERANCE = 0.1;

    /** Returns the likelihood whose log density is differentiated. */
    public Likelihood getLikelihood();

    /** Returns the parameter with respect to which derivatives are taken. */
    public Parameter getParameter();

    /**
     * Returns the dimension of the derivative of the given order.
     * NOTE(review): presumably the length of the array returned by {@link #getDerivativeLogDensity} — confirm.
     */
    public int getDimension(DerivativeOrder order);

    /**
     * Returns the derivative of the log density of the given order.
     * NOTE(review): presumably evaluated at the parameter's current values, as {@link ParameterWrapper} does.
     */
    public double[] getDerivativeLogDensity(DerivativeOrder order);

    /** Returns the highest derivative order this provider reports it can compute. */
    public DerivativeOrder getHighestOrder();

    /**
     * Combines the highest orders of several providers into the lowest one among them.
     *
     * <p>The result is the provider order with the smallest {@link DerivativeOrder#getValue()}, and never
     * exceeds {@link DerivativeOrder#FULL_HESSIAN}, which is the starting bound.
     *
     * @param providers the providers to combine
     * @return {@link DerivativeOrder#ZEROTH} if {@code providers} is empty, otherwise the minimum order
     */
    public static DerivativeOrder getHighestOrder(List<DerivativeWrtParameterProvider> providers) {
        if (providers.isEmpty()) {
            return DerivativeOrder.ZEROTH;
        }
        DerivativeOrder lowest = DerivativeOrder.FULL_HESSIAN;
        for (DerivativeWrtParameterProvider provider : providers) {
            // Query each provider once and lower the bound when it reports a smaller order.
            DerivativeOrder order = provider.getHighestOrder();
            if (order.getValue() < lowest.getValue()) {
                lowest = order;
            }
        }
        return lowest;
    }

    /**
     * Builds a textual comparison of analytic and numeric derivatives and, optionally, checks them.
     *
     * <p>The report starts with {@code header}, followed by an {@code analytic:} line and a {@code numeric :}
     * line. The two checks below apply only when {@code checkValues} is true.
     * <ul>
     *   <li>If the arrays differ in length, a "Length mismatch" line is appended, the report is logged and a
     *       {@link MismatchException} is thrown.</li>
     *   <li>Otherwise, for each component, the relative difference
     *       {@code 2 * (analytic - numeric) / (analytic + numeric)} is compared against {@code tolerance}. At
     *       the first component whose absolute relative difference exceeds it, a "Difference @" line (1-based
     *       index) is appended, the report is logged and a {@link MismatchException} is thrown.</li>
     * </ul>
     * Logging goes to the {@code dr.inference.hmc} logger at INFO. A component where both values are zero
     * yields NaN and is therefore never flagged.
     *
     * @param header      text placed at the start of the report
     * @param analytic    analytically computed derivative
     * @param numeric     numerically estimated derivative
     * @param checkValues whether to compare the two arrays
     * @param tolerance   maximum allowed absolute relative difference (used only when checking)
     * @return the report text when no mismatch is detected
     * @throws MismatchException if checking is enabled and the derivatives disagree
     */
    public static String makeReport(String header, double[] analytic, double[] numeric, boolean checkValues,
                                    double tolerance) throws MismatchException {
        StringBuilder report = new StringBuilder(header);
        report.append("analytic: ").append(new Vector(analytic));
        report.append("\n");
        report.append("numeric : ").append(new Vector(numeric));
        if (checkValues) {
            if (analytic.length != numeric.length) {
                // Without this guard the loop below would either overrun `numeric` or silently
                // compare only a prefix of the two arrays.
                report.append("\nLength mismatch: analytic ").append(analytic.length)
                        .append(" numeric ").append(numeric.length).append("\n");
                Logger.getLogger("dr.inference.hmc").info(report.toString());
                throw new MismatchException();
            }
            for (int i = 0; i < analytic.length; ++i) {
                double relativeDifference = 2.0 * (analytic[i] - numeric[i]) / (analytic[i] + numeric[i]);
                // NaN (both components zero) compares false here and so is not reported.
                if (Math.abs(relativeDifference) > tolerance) {
                    report.append("\nDifference @ ").append(i + 1).append(": ")
                            .append(analytic[i]).append(" ").append(numeric[i])
                            .append(" ").append(relativeDifference).append("\n");
                    Logger.getLogger("dr.inference.hmc").info(report.toString());
                    throw new MismatchException();
                }
            }
        }
        return report.toString();
    }

    /**
     * Compares {@code provider}'s analytic derivative with a finite-difference estimate and returns the report.
     *
     * @param provider   the provider under test
     * @param order      derivative order to check (only {@link DerivativeOrder#GRADIENT} is implemented numerically)
     * @param lowerBound lower bound reported for every argument of the numeric function
     * @param upperBound upper bound reported for every argument of the numeric function
     * @param tolerance  relative tolerance, or {@code null} to produce the report without checking
     * @return the comparison report
     * @throws RuntimeException on a mismatch. Its message is the mismatch message, or else the parameter
     *                          name, or else {@code "Gradient check failure"}; the {@link MismatchException}
     *                          is attached as the cause.
     */
    public static String getReportAndCheckForError(DerivativeWrtParameterProvider provider, DerivativeOrder order,
                                                   double lowerBound, double upperBound, Double tolerance) {
        try {
            return new CheckDerivativeNumerically(provider, order, lowerBound, upperBound, tolerance).getReport();
        } catch (MismatchException mismatch) {
            String message = mismatch.getMessage();
            if (message == null) {
                message = provider.getParameter().getParameterName();
            }
            if (message == null) {
                message = "Gradient check failure";
            }
            // Keep the original exception as the cause so its stack trace is not lost.
            throw new RuntimeException(message, mismatch);
        }
    }

    /**
     * Compares a provider's analytic derivative with a central finite-difference estimate obtained through
     * {@link NumericalDerivative#gradient}, by perturbing the provider's parameter.
     */
    public static class CheckDerivativeNumerically {
        private final DerivativeWrtParameterProvider provider;
        private final DerivativeOrder type;
        private final Parameter parameter;
        private final double lowerBound;
        private final double upperBound;
        private final boolean checkValues;
        private final double tolerance;

        /** Log likelihood viewed as a function of the parameter values; used for finite differencing. */
        private final MultivariateFunction numeric = new MultivariateFunction() {

            @Override
            public double evaluate(double[] argument) {
                // A qualified `this` is required: a bare `this` is the anonymous function object,
                // which has no setParameter method.
                CheckDerivativeNumerically.this.setParameter(argument);
                if (type == DerivativeOrder.GRADIENT) {
                    return provider.getLikelihood().getLogLikelihood();
                }
                throw new RuntimeException("Not yet implemented");
            }

            @Override
            public int getNumArguments() {
                return parameter.getDimension();
            }

            @Override
            public double getLowerBound(int n) {
                return lowerBound;
            }

            @Override
            public double getUpperBound(int n) {
                return upperBound;
            }
        };

        /**
         * @param provider   the provider whose derivative is checked
         * @param order      derivative order to compare
         * @param lowerBound lower bound returned for every argument of {@link #numeric}
         * @param upperBound upper bound returned for every argument of {@link #numeric}
         * @param tolerance  relative tolerance; {@code null} disables the value check
         */
        CheckDerivativeNumerically(DerivativeWrtParameterProvider provider, DerivativeOrder order,
                                   double lowerBound, double upperBound, Double tolerance) {
            this.provider = provider;
            this.type = order;
            this.parameter = provider.getParameter();
            this.lowerBound = lowerBound;
            this.upperBound = upperBound;
            this.checkValues = tolerance != null;
            this.tolerance = this.checkValues ? tolerance : 0.0;
        }

        /** Writes {@code values} into the parameter quietly, then fires a single change event. */
        private void setParameter(double[] values) {
            for (int i = 0; i < values.length; ++i) {
                this.parameter.setParameterValueQuietly(i, values[i]);
            }
            this.parameter.fireParameterChangedEvent();
        }

        /** Estimates the gradient numerically and restores the parameter's values afterwards. */
        private double[] getNumericalGradient() {
            double[] saved = this.parameter.getParameterValues();
            try {
                // Pass a separate array so the saved copy is unaffected by any perturbation
                // performed on the argument.
                return NumericalDerivative.gradient(this.numeric, this.parameter.getParameterValues());
            } finally {
                // Restore the pre-check values even if evaluation throws.
                this.setParameter(saved);
            }
        }

        /**
         * Computes the analytic derivative, then the numeric one, and builds the comparison report.
         *
         * @return the report text
         * @throws MismatchException if checking is enabled and the two derivatives disagree
         */
        public String getReport() throws MismatchException {
            double[] analytic = this.provider.getDerivativeLogDensity(this.type);
            double[] numerical = this.getNumericalGradient();
            return DerivativeWrtParameterProvider.makeReport("Gradient\n", analytic, numerical,
                    this.checkValues, this.tolerance);
        }
    }

    /** Signals that analytic and numeric derivatives disagree; thrown by {@link #makeReport}. */
    public static class MismatchException
    extends Exception {
    }

    /**
     * Adapts a {@link DerivativeProvider} into a {@link DerivativeWrtParameterProvider} by evaluating it at the
     * current values of a fixed {@link Parameter}.
     */
    public static class ParameterWrapper
    implements DerivativeWrtParameterProvider,
    Reportable {
        final DerivativeProvider provider;
        final Parameter parameter;
        final Likelihood likelihood;

        /**
         * @param provider   source of derivatives, evaluated at {@code parameter}'s values
         * @param parameter  parameter the derivatives are taken with respect to
         * @param likelihood likelihood returned by {@link #getLikelihood()}
         */
        public ParameterWrapper(DerivativeProvider provider, Parameter parameter, Likelihood likelihood) {
            this.provider = provider;
            this.parameter = parameter;
            this.likelihood = likelihood;
        }

        @Override
        public Likelihood getLikelihood() {
            return this.likelihood;
        }

        @Override
        public Parameter getParameter() {
            return this.parameter;
        }

        /** Delegates to {@link DerivativeOrder#getDerivativeDimension} with the parameter's dimension. */
        @Override
        public int getDimension(DerivativeOrder order) {
            return order.getDerivativeDimension(this.parameter.getDimension());
        }

        /**
         * Evaluates the wrapped provider at the parameter's current values.
         * With assertions enabled, fails if {@code order} exceeds the provider's highest order.
         */
        @Override
        public double[] getDerivativeLogDensity(DerivativeOrder order) {
            assert (this.provider.getHighestOrder().getValue() >= order.getValue());
            return this.provider.getDerivativeLogDensity(this.parameter.getParameterValues(), order);
        }

        @Override
        public DerivativeOrder getHighestOrder() {
            return this.provider.getHighestOrder();
        }

        /**
         * Returns {@code null}; this wrapper produces no report.
         * NOTE(review): {@link Reportable} callers may expect non-null text — confirm.
         */
        @Override
        public String getReport() {
            return null;
        }
    }
}

