/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.hmc;

import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.PrecisionColumnProvider;
import dr.inference.hmc.PrecisionMatrixVectorProductProvider;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.loggers.NumberColumn;
import dr.inference.model.Parameter;
import dr.inference.operators.hmc.AbstractParticleOperator;
import dr.inference.operators.hmc.MassPreconditionScheduler;
import dr.inference.operators.hmc.MassPreconditioner;
import dr.inference.operators.hmc.MinimumTravelInformation;
import dr.math.matrixAlgebra.WrappedVector;
import dr.util.TaskPool;

/**
 * Shared trajectory-integration logic for zig-zag style particle operators.
 *
 * <p>{@link #integrateTrajectory} repeatedly asks the subclass for the next event
 * ({@link #getNextBounce}). It then either advances the state for the remaining travel time, or
 * processes the event and reflects the velocity components named by the event.
 *
 * <p>Subclasses supply the dynamics-specific hooks:
 * <ul>
 *   <li>{@link #drawInitialVelocity}</li>
 *   <li>{@link #getNextBounce}</li>
 *   <li>{@link #updatePositionAndMomentum}</li>
 *   <li>{@link #updateDynamics}</li>
 * </ul>
 *
 * <p>Event counters are exposed as log columns via {@link Loggable}.
 */
abstract class AbstractZigZagOperator
extends AbstractParticleOperator
implements Loggable {
    /** Task pool created only when the constructor's {@code threadCount} exceeds 1; otherwise {@code null}. */
    final TaskPool taskPool;
    protected static final boolean DEBUG = false;

    /**
     * Creates the operator.
     *
     * <p>Every argument except {@code threadCount} is forwarded unchanged to the
     * {@link AbstractParticleOperator} constructor.
     *
     * @param threadCount when greater than 1, a {@link TaskPool} is built over the gradient
     *                    dimension with this value as its second argument (presumably the number of
     *                    worker threads; TODO confirm against {@code TaskPool}). Otherwise
     *                    {@link #taskPool} is {@code null}.
     */
    AbstractZigZagOperator(GradientWrtParameterProvider gradientWrtParameterProvider, PrecisionMatrixVectorProductProvider precisionMatrixVectorProductProvider, PrecisionColumnProvider precisionColumnProvider, double d, AbstractParticleOperator.Options options, AbstractParticleOperator.NativeCodeOptions nativeCodeOptions, boolean bl, Parameter parameter, Parameter parameter2, int threadCount, MassPreconditioner massPreconditioner, MassPreconditionScheduler.Type type) {
        super(gradientWrtParameterProvider, precisionMatrixVectorProductProvider, precisionColumnProvider, d, options, nativeCodeOptions, bl, parameter, parameter2, massPreconditioner, type);
        this.taskPool = threadCount > 1 ? new TaskPool(gradientWrtParameterProvider.getDimension(), threadCount) : null;
    }

    /**
     * Runs one trajectory for a randomly drawn total travel time.
     *
     * <p>Events are processed one at a time until the travel-time budget is exhausted. The final
     * velocity is then passed to {@link #storeVelocity}. The argument vectors are handed to the
     * subclass update hooks, which presumably modify them in place.
     *
     * <p>NOTE(review): the names {@code position}/{@code momentum} follow the argument order of
     * {@link #updatePositionAndMomentum}; confirm against the superclass contract.
     *
     * @return always {@code 0.0}
     */
    @Override
    final double integrateTrajectory(WrappedVector position, WrappedVector momentum) {
        this.timer.startTimer("warmUp");
        WrappedVector velocity = this.drawInitialVelocity(momentum);
        WrappedVector gradient = this.getInitialGradient();
        WrappedVector precisionProduct = this.getPrecisionProduct(velocity);
        AbstractParticleOperator.BounceState bounceState = new AbstractParticleOperator.BounceState(this.drawTotalTravelTime());
        this.initializeNumEvent();
        this.timer.stopTimer("warmUp");

        this.timer.startTimer("integrateTrajectory");
        while (bounceState.isTimeRemaining()) {
            MinimumTravelInformation nextEvent = this.getNextBounce(position, velocity, precisionProduct, gradient, momentum);
            bounceState = this.doBounce(bounceState, nextEvent, position, velocity, precisionProduct, gradient, momentum);
        }
        this.timer.stopTimer("integrateTrajectory");

        this.storeVelocity(velocity);
        return 0.0;
    }

    /**
     * Advances the dynamics to whichever comes first: the end of the travel-time budget or the next event.
     *
     * @return if the budget ends first, a state of type {@code NONE} with index -1 and no time
     *         left. Otherwise, a state holding the event type, the event's first index and the
     *         remaining time after the event.
     */
    private AbstractParticleOperator.BounceState doBounce(AbstractParticleOperator.BounceState bounceState, MinimumTravelInformation nextEvent, WrappedVector position, WrappedVector velocity, WrappedVector precisionProduct, WrappedVector gradient, WrappedVector momentum) {
        this.timer.startTimer("doBounce");
        double remainingTime = bounceState.remainingTime;
        double eventTime = nextEvent.time;

        AbstractParticleOperator.BounceState result;
        if (remainingTime < eventTime) {
            // The budget runs out before the next event: just move to the end of the trajectory.
            this.updatePositionAndMomentum(position, velocity, precisionProduct, gradient, momentum, remainingTime);
            result = new AbstractParticleOperator.BounceState(AbstractParticleOperator.Type.NONE, -1, 0.0);
        } else {
            AbstractParticleOperator.Type eventType = nextEvent.type;
            int eventIndex = nextEvent.index[0];
            WrappedVector precisionColumn = this.getPrecisionColumn(eventIndex);
            this.updateDynamics(position, velocity, precisionProduct, gradient, momentum, precisionColumn, eventTime, nextEvent.index, eventType);
            AbstractZigZagOperator.reflectVelocity(velocity, nextEvent.index);
            result = new AbstractParticleOperator.BounceState(eventType, eventIndex, remainingTime - eventTime);
            this.recordEvents(eventType);
        }
        this.timer.stopTimer("doBounce");
        return result;
    }

    /**
     * Produces the velocity vector used for the trajectory.
     * It is called with the second argument of {@link #integrateTrajectory}.
     */
    abstract WrappedVector drawInitialVelocity(WrappedVector momentum);

    /**
     * Finds the next event.
     * Parameter names reflect the order in which {@link #integrateTrajectory} passes its vectors.
     */
    abstract MinimumTravelInformation getNextBounce(WrappedVector position, WrappedVector velocity, WrappedVector precisionProduct, WrappedVector gradient, WrappedVector momentum);

    /** Moves the state forward by {@code time} without processing an event. */
    abstract void updatePositionAndMomentum(WrappedVector position, WrappedVector velocity, WrappedVector precisionProduct, WrappedVector gradient, WrappedVector momentum, double time);

    /**
     * Moves the state forward by {@code time} up to an event at {@code index} of kind {@code type}.
     * The velocity reflection is applied by the caller afterwards.
     */
    abstract void updateDynamics(WrappedVector position, WrappedVector velocity, WrappedVector precisionProduct, WrappedVector gradient, WrappedVector momentum, WrappedVector precisionColumn, double time, int[] index, AbstractParticleOperator.Type type);

    /**
     * Returns the smallest strictly positive root of {@code -0.5*a*t^2 + b*t + c = 0}.
     *
     * @return the root, or {@code +Infinity} if there is none
     */
    static double findGradientRoot(double a, double b, double c) {
        return AbstractZigZagOperator.minimumPositiveRoot(-0.5 * a, b, c);
    }

    /**
     * Returns the time for coordinate {@code index} to reach 0 when moving at rate {@code velocity}.
     * This is {@code |position / velocity|} if the superclass reports the coordinate is heading
     * towards the binary boundary, otherwise {@code +Infinity}.
     */
    double findBinaryBoundaryTime(int index, double position, double velocity) {
        if (!this.headingTowardsBinaryBoundary(velocity, index)) {
            return Double.POSITIVE_INFINITY;
        }
        return Math.abs(position / velocity);
    }

    /**
     * Scans every categorical trait and returns the earliest boundary crossing among them.
     *
     * <p>A trait starts at coordinate {@code i} when {@code categoryClasses[i] > 0}. It spans the
     * next {@code categoryClasses[i] - 1} coordinates.
     *
     * @return the minimum time with global indices {@code {current, crossing}}, or
     *         {@code +Infinity} with {@code {-1, -1}} when there are no categorical traits or no
     *         crossing
     */
    MinimumTravelInformation findCategoricalBoundaryTime(double[] position, double[] velocity) {
        double minTime = Double.POSITIVE_INFINITY;
        int[] minIndex = new int[]{-1, -1};
        if (this.categoryClasses == null) {
            return new MinimumTravelInformation(minTime, minIndex);
        }
        int start = 0;
        while (start < position.length) {
            if (this.categoryClasses[start] <= 0) {
                ++start;
                continue;
            }
            int traitDim = this.categoryClasses[start] - 1;
            if (traitDim > 0) {
                double[] traitPosition = new double[traitDim];
                double[] traitVelocity = new double[traitDim];
                double[] traitMask = new double[traitDim];
                System.arraycopy(position, start, traitPosition, 0, traitDim);
                System.arraycopy(velocity, start, traitVelocity, 0, traitDim);
                System.arraycopy(this.observedDataMask, start, traitMask, 0, traitDim);
                MinimumTravelInformation traitInfo = this.findCategoricalBoundaryTimeOneTrait(traitPosition, traitVelocity, traitMask);
                if (traitInfo.time < minTime) {
                    minTime = traitInfo.time;
                    minIndex[0] = start + traitInfo.index[0];
                    minIndex[1] = start + traitInfo.index[1];
                }
            }
            // Always advance by at least one coordinate. A single-class entry (traitDim == 0)
            // previously left 'start' unchanged and looped forever.
            start += Math.max(traitDim, 1);
        }
        return new MinimumTravelInformation(minTime, minIndex);
    }

    /**
     * Finds the earliest crossing within a single categorical trait.
     *
     * <p>The current class is the first coordinate whose mask entry is {@code > 0}. For every other
     * coordinate {@code j} whose velocity gap {@code velocity[cur] - velocity[j]} is negative, the
     * position gap closes at {@code -(position[cur] - position[j]) / (velocity[cur] - velocity[j])}.
     * The smallest such time wins.
     *
     * @return the minimum time with trait-local indices {@code {cur, j}}, or {@code +Infinity} when
     *         {@link #isReferenceClass} holds or no gap closes
     */
    private MinimumTravelInformation findCategoricalBoundaryTimeOneTrait(double[] position, double[] velocity, double[] mask) {
        int[] minIndex = new int[]{-1, -1};
        double minTime = Double.POSITIVE_INFINITY;
        if (this.isReferenceClass(mask)) {
            return new MinimumTravelInformation(minTime, minIndex);
        }
        int current = -1;
        for (int i = 0; i < mask.length; ++i) {
            if (mask[i] > 0.0) {
                current = i;
                break;
            }
        }
        // NOTE(review): if the mask has a zero entry but no positive entry, 'current' stays -1
        // and the indexing below throws ArrayIndexOutOfBoundsException. Confirm this mask state
        // cannot occur.
        minIndex[0] = current;
        for (int j = 0; j < position.length; ++j) {
            double relativeVelocity = velocity[current] - velocity[j];
            if (relativeVelocity < 0.0) {
                double gap = position[current] - position[j];
                double crossingTime = -gap / relativeVelocity;
                if (crossingTime < minTime) {
                    minTime = crossingTime;
                    minIndex[1] = j;
                }
            }
        }
        return new MinimumTravelInformation(minTime, minIndex);
    }

    /** Returns {@code true} when no mask entry equals {@code 0.0}; an empty mask therefore returns {@code true}. */
    private boolean isReferenceClass(double[] mask) {
        for (double value : mask) {
            if (value == 0.0) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the smallest strictly positive root of {@code a*t^2 + b*t + c = 0}, or
     * {@code +Infinity} if there is none.
     *
     * <p>When {@code a == 0} the equation is solved as the linear {@code b*t + c = 0}. Without
     * this, the quadratic formula divides {@code 0/0} and yields NaN, which silently loses every
     * {@code <} comparison in callers.
     */
    private static double minimumPositiveRoot(double a, double b, double c) {
        if (a == 0.0) {
            if (b == 0.0) {
                return Double.POSITIVE_INFINITY;
            }
            double linearRoot = -c / b;
            return linearRoot > 0.0 ? linearRoot : Double.POSITIVE_INFINITY;
        }
        // Normalise so the leading coefficient is positive; then (-b - sqrt) / 2a is the smaller root.
        double signA = AbstractZigZagOperator.sign(a);
        a *= signA;
        b *= signA;
        c *= signA;
        double discriminant = b * b - 4.0 * a * c;
        if (discriminant < 0.0) {
            return Double.POSITIVE_INFINITY;
        }
        double sqrtDiscriminant = Math.sqrt(discriminant);
        double root = (-b - sqrtDiscriminant) / (2.0 * a);
        if (root <= 0.0) {
            root = (-b + sqrtDiscriminant) / (2.0 * a);
        }
        if (root <= 0.0) {
            root = Double.POSITIVE_INFINITY;
        }
        return root;
    }

    /** Negates entry {@code index} of {@code vector}. */
    static void reflectMomentum(WrappedVector vector, int index) {
        vector.set(index, -vector.get(index));
    }

    /** Sets entry {@code index} of {@code vector} to 0. */
    static void setZeroPosition(WrappedVector vector, int index) {
        vector.set(index, 0.0);
    }

    /** Copies entry {@code source} of {@code vector} into entry {@code target}. */
    static void setEqualPosition(WrappedVector vector, int source, int target) {
        vector.set(target, vector.get(source));
    }

    /** Sets entry {@code index} of {@code vector} to 0. */
    static void setZeroMomentum(WrappedVector vector, int index) {
        vector.set(index, 0.0);
    }

    /** Negates each entry of {@code velocity} named in {@code indices}; negative indices are skipped. */
    private static void reflectVelocity(WrappedVector velocity, int[] indices) {
        for (int index : indices) {
            if (index >= 0) {
                velocity.set(index, -velocity.get(index));
            }
        }
    }

    /**
     * Returns {@code true} when every element pair satisfies
     * {@code |(x[i] - y[i]) / (x[i] + y[i])| <= 1e-5}.
     *
     * <p>A NaN ratio (e.g. both elements 0) does not count as a mismatch. Assumes
     * {@code y.length >= x.length}.
     */
    protected boolean close(double[] dArray, double[] dArray2) {
        for (int i = 0; i < dArray.length; ++i) {
            double relativeDifference = Math.abs((dArray[i] - dArray2[i]) / (dArray[i] + dArray2[i]));
            if (relativeDifference > 1.0E-5) {
                return false;
            }
        }
        return true;
    }

    /** Returns 1, -1 or 0 according to the sign of {@code d}; NaN and signed zeros map to 0. */
    static int sign(double d) {
        return (int) Math.signum(d);
    }

    /** Exposes the total, gradient and boundary event counters as log columns. */
    @Override
    public LogColumn[] getColumns() {
        return new LogColumn[]{new NumberColumn("total events"){

            @Override
            public double getDoubleValue() {
                return AbstractZigZagOperator.this.numEvents;
            }
        }, new NumberColumn("gradient events"){

            @Override
            public double getDoubleValue() {
                return AbstractZigZagOperator.this.numGradientEvents;
            }
        }, new NumberColumn("boundary events"){

            @Override
            public double getDoubleValue() {
                return AbstractZigZagOperator.this.numBoundaryEvents;
            }
        }};
    }
}

