/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.hmc;

import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.PrecisionColumnProvider;
import dr.inference.hmc.PrecisionMatrixVectorProductProvider;
import dr.inference.hmc.ReversibleHMCProvider;
import dr.inference.model.Parameter;
import dr.inference.operators.hmc.AbstractParticleOperator;
import dr.inference.operators.hmc.AbstractZigZagOperator;
import dr.inference.operators.hmc.MinimumTravelInformation;
import dr.math.MathUtils;
import dr.math.matrixAlgebra.ReadableVector;
import dr.math.matrixAlgebra.WrappedVector;
import dr.util.TaskPool;
import dr.xml.Reportable;
import java.util.function.BinaryOperator;

/**
 * Reversible zig-zag particle operator.
 *
 * <p>The next event is the earliest, over all coordinates, of a boundary-hit time and a
 * gradient-root time. The search runs serially, or as a map-reduce over coordinate ranges when a
 * {@link TaskPool} is configured. Momenta are drawn with random sign and
 * {@code Exponential(1) * sqrt(mass_i)} magnitude; velocities are
 * {@code sign(momentum_i) / sqrt(mass_i)}.
 *
 * <p>NOTE(review): the vector roles used as parameter names below (position, velocity, action,
 * gradient, momentum, precision column) are inferred from how each vector is read and updated in
 * this class. Confirm them against {@link AbstractZigZagOperator}.
 */
public class ReversibleZigZagOperator
extends AbstractZigZagOperator
implements Reportable,
ReversibleHMCProvider {

    public ReversibleZigZagOperator(GradientWrtParameterProvider gradientProvider,
                                    PrecisionMatrixVectorProductProvider multiplicationProvider,
                                    PrecisionColumnProvider columnProvider,
                                    double weight,
                                    AbstractParticleOperator.Options runtimeOptions,
                                    Parameter mask,
                                    int threadCount) {
        super(gradientProvider, multiplicationProvider, columnProvider, weight, runtimeOptions,
                mask, threadCount);
    }

    @Override
    public String getOperatorName() {
        return "Zig-zag particle operator";
    }

    /**
     * Finds the earliest next event (boundary hit or gradient root) across all coordinates.
     * Dispatches to the parallel search when a task pool exists, otherwise to the serial one.
     */
    @Override
    MinimumTravelInformation getNextBounce(WrappedVector position,
                                           WrappedVector velocity,
                                           WrappedVector action,
                                           WrappedVector gradient,
                                           WrappedVector momentum) {
        timer.startTimer("getNext");
        final MinimumTravelInformation next = (taskPool != null)
                ? getNextBounceParallel(position, velocity, action, gradient, momentum)
                : getNextBounceSerial(position, velocity, action, gradient, momentum);
        timer.stopTimer("getNext");
        return next;
    }

    /** Searches every coordinate {@code [0, dim)} on the calling thread. */
    private MinimumTravelInformation getNextBounceSerial(WrappedVector position,
                                                         WrappedVector velocity,
                                                         WrappedVector action,
                                                         WrappedVector gradient,
                                                         WrappedVector momentum) {
        return getNextBounceImpl(0, position.getDim(),
                position.getBuffer(), velocity.getBuffer(), action.getBuffer(),
                gradient.getBuffer(), momentum.getBuffer());
    }

    /**
     * Splits the coordinate search across the task pool and keeps the result with the
     * smallest event time.
     */
    private MinimumTravelInformation getNextBounceParallel(WrappedVector position,
                                                           WrappedVector velocity,
                                                           WrappedVector action,
                                                           WrappedVector gradient,
                                                           WrappedVector momentum) {
        final double[] p = position.getBuffer();
        final double[] v = velocity.getBuffer();
        final double[] a = action.getBuffer();
        final double[] g = gradient.getBuffer();
        final double[] m = momentum.getBuffer();

        // The third callback argument (presumably a task/thread index) is not needed here.
        final TaskPool.RangeCallable<MinimumTravelInformation> map =
                (start, end, task) -> getNextBounceImpl(start, end, p, v, a, g, m);

        // Must be parameterized: with a raw BinaryOperator the lambda parameters are Object and
        // `.time` does not compile; the mapReduce result would also erase to Object.
        final BinaryOperator<MinimumTravelInformation> earliest =
                (lhs, rhs) -> lhs.time < rhs.time ? lhs : rhs;

        return taskPool.mapReduce(map, earliest);
    }

    /**
     * Scans coordinates {@code [begin, end)} for the earliest event.
     *
     * <p>For each coordinate it considers the boundary time from {@code findBoundaryTime} and the
     * gradient-root time from {@code findGradientRoot}. Ties keep the earlier candidate, because
     * only strictly smaller times replace the current minimum.
     *
     * @return the minimum time, its coordinate index and event type; if nothing beats
     *         {@code +Infinity}, returns {@code (+Infinity, -1, NONE)}
     */
    private MinimumTravelInformation getNextBounceImpl(int begin,
                                                       int end,
                                                       double[] position,
                                                       double[] velocity,
                                                       double[] action,
                                                       double[] gradient,
                                                       double[] momentum) {
        double minimumTime = Double.POSITIVE_INFINITY;
        int index = -1;
        AbstractParticleOperator.Type type = AbstractParticleOperator.Type.NONE;

        for (int i = begin; i < end; ++i) {
            final double boundaryTime = findBoundaryTime(i, position[i], velocity[i]);
            if (boundaryTime < minimumTime) {
                minimumTime = boundaryTime;
                index = i;
                type = AbstractParticleOperator.Type.BOUNDARY;
            }

            final double gradientTime = findGradientRoot(action[i], gradient[i], momentum[i]);
            if (gradientTime < minimumTime) {
                minimumTime = gradientTime;
                index = i;
                type = AbstractParticleOperator.Type.GRADIENT;
            }
        }

        return new MinimumTravelInformation(minimumTime, index, type);
    }

    /**
     * Draws a fresh momentum: each component is {@code ±Exponential(1) * sqrt(mass_i)} with a
     * fair random sign, then the mask is applied if one is configured.
     */
    @Override
    final WrappedVector drawInitialMomentum() {
        final WrappedVector mass = preconditioning.mass;
        final double[] momentum = new double[mass.getDim()];

        for (int i = 0, len = momentum.length; i < len; ++i) {
            final int direction = MathUtils.nextDouble() > 0.5 ? 1 : -1;
            momentum[i] = direction * MathUtils.nextExponential(1.0) * Math.sqrt(mass.get(i));
        }

        if (mask != null) {
            applyMask(momentum);
        }

        return new WrappedVector.Raw(momentum);
    }

    /** Builds the velocity {@code sign(momentum_i) / sqrt(mass_i)} for every coordinate. */
    @Override
    final WrappedVector drawInitialVelocity(WrappedVector momentum) {
        final WrappedVector mass = preconditioning.mass;
        final int dim = momentum.getDim();
        final double[] velocity = new double[dim];

        for (int i = 0; i < dim; ++i) {
            velocity[i] = sign(momentum.get(i)) / Math.sqrt(mass.get(i));
        }

        return new WrappedVector.Raw(velocity);
    }

    /**
     * Debug cross-check of the Java event search against the native implementation.
     * It is not called from within this class.
     *
     * @throws IllegalStateException if the native result differs from {@code expected}
     */
    private void testNative(MinimumTravelInformation expected,
                            WrappedVector position,
                            WrappedVector velocity,
                            WrappedVector action,
                            WrappedVector gradient,
                            WrappedVector momentum) {
        timer.startTimer("getNextC++");
        final MinimumTravelInformation actual = nativeZigZag.getNextReversibleEvent(
                position.getBuffer(), velocity.getBuffer(), action.getBuffer(),
                gradient.getBuffer(), momentum.getBuffer());
        timer.stopTimer("getNextC++");

        if (!expected.equals(actual)) {
            // Fail loudly without terminating the whole JVM (previously System.exit(-1)).
            throw new IllegalStateException(actual + " ?= " + expected);
        }
    }

    /**
     * Advances all coordinates by {@code time}, then applies the action change caused by
     * coordinate {@code index}.
     *
     * <p>Per coordinate, using the pre-update gradient and action values:
     * <ul>
     *   <li>{@code position += time * velocity}</li>
     *   <li>{@code momentum += time * gradient - time^2/2 * action}</li>
     *   <li>{@code gradient -= time * action}</li>
     *   <li>{@code action -= 2 * velocity[index] * column}</li>
     * </ul>
     */
    private void updateDynamics(double[] position,
                                double[] velocity,
                                double[] action,
                                double[] gradient,
                                double[] momentum,
                                double[] column,
                                double time,
                                int index) {
        final double halfTimeSquared = time * time / 2.0;
        final double twoV = 2.0 * velocity[index];

        for (int i = 0, len = position.length; i < len; ++i) {
            final double gi = gradient[i];
            final double ai = action[i];

            position[i] = position[i] + time * velocity[i];
            momentum[i] = momentum[i] + time * gi - halfTimeSquared * ai;
            gradient[i] = gi - time * ai;
            action[i] = ai - twoV * column[i];
        }
    }

    /**
     * Advances the dynamics to the event time, then handles the event at {@code index}.
     * A boundary event reflects the momentum; any other event zeroes that momentum component.
     */
    @Override
    void updateDynamics(WrappedVector position,
                        WrappedVector velocity,
                        WrappedVector action,
                        WrappedVector gradient,
                        WrappedVector momentum,
                        WrappedVector column,
                        double time,
                        int index,
                        AbstractParticleOperator.Type type) {
        updateDynamics(position.getBuffer(), velocity.getBuffer(), action.getBuffer(),
                gradient.getBuffer(), momentum.getBuffer(), column.getBuffer(), time, index);

        if (type == AbstractParticleOperator.Type.BOUNDARY) {
            reflectMomentum(momentum, position, index);
        } else {
            setZeroMomentum(momentum, index);
        }
    }

    /** Moves position along velocity and advances momentum, both over {@code time}. */
    @Override
    void updatePositionAndMomentum(WrappedVector position,
                                   WrappedVector velocity,
                                   WrappedVector action,
                                   WrappedVector gradient,
                                   WrappedVector momentum,
                                   double time) {
        updatePosition(position.getBuffer(), velocity.getBuffer(), time);
        updateMomentum(action.getBuffer(), gradient.getBuffer(), momentum.getBuffer(), time);
    }

    /**
     * Integrates the trajectory for {@code time} in the given direction and writes the final
     * position back to the parameter.
     *
     * <p>For {@code direction == -1} the momentum is negated for the integration. It is always
     * restored afterwards, even if integration throws, so the caller's vector keeps its
     * original orientation. Note that this also records {@code time} as the preconditioning
     * total travel time.
     */
    @Override
    public void reversiblePositionMomentumUpdate(WrappedVector position,
                                                 WrappedVector momentum,
                                                 int direction,
                                                 double time) {
        preconditioning.totalTravelTime = time;

        final boolean backward = direction == -1;
        if (backward) {
            negateVector(momentum);
        }
        try {
            integrateTrajectory(position, momentum);
        } finally {
            if (backward) {
                negateVector(momentum);
            }
        }

        ReadableVector.Utils.setParameter((ReadableVector) position, parameter);
    }

    @Override
    public double[] getInitialPosition() {
        return parameter.getParameterValues();
    }

    /** No reparameterization is applied, so the log-Jacobian is always zero. */
    @Override
    public double getParameterLogJacobian() {
        return 0.0;
    }

    @Override
    public void setParameter(double[] position) {
        ReadableVector.Utils.setParameter(position, parameter);
    }

    @Override
    public WrappedVector drawMomentum() {
        return drawInitialMomentum();
    }

    /** Returns {@code logLikelihood - kineticEnergy(momentum) - logJacobian}. */
    @Override
    public double getJointProbability(WrappedVector momentum) {
        return gradientProvider.getLikelihood().getLogLikelihood()
                - getKineticEnergy(momentum)
                - getParameterLogJacobian();
    }

    @Override
    public double getLogLikelihood() {
        return gradientProvider.getLikelihood().getLogLikelihood();
    }

    /**
     * Returns the L1 norm of the momentum, {@code sum_i |momentum_i|}.
     *
     * <p>NOTE(review): {@link #drawInitialMomentum()} scales each component by
     * {@code sqrt(mass_i)}, but this energy does not divide by it. The two agree only for unit
     * mass; confirm this is intended.
     */
    @Override
    public double getKineticEnergy(ReadableVector momentum) {
        final int dim = momentum.getDim();
        double energy = 0.0;
        for (int i = 0; i < dim; ++i) {
            energy += Math.abs(momentum.get(i));
        }
        return energy;
    }

    @Override
    public double getStepSize() {
        return preconditioning.totalTravelTime;
    }

    /** Negates every component of {@code vector} in place. */
    private void negateVector(WrappedVector vector) {
        final int dim = vector.getDim();
        for (int i = 0; i < dim; ++i) {
            vector.set(i, -vector.get(i));
        }
    }
}

