/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.continuous.hmc;

import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.continuous.hmc.LinearOrderTreePrecisionTraitProductProvider;
import dr.evomodel.treedatalikelihood.DataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.continuous.IntegratedFactorAnalysisLikelihood;
import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType;
import dr.evomodel.treedatalikelihood.preorder.WrappedNormalSufficientStatistics;
import dr.evomodel.treedatalikelihood.preorder.WrappedTipFullConditionalDistributionDelegate;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.CompoundLikelihood;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.inference.model.VariableListener;
import dr.math.matrixAlgebra.ReadableMatrix;
import dr.math.matrixAlgebra.ReadableVector;
import dr.math.matrixAlgebra.WrappedMatrix;
import dr.math.matrixAlgebra.WrappedVector;
import dr.math.matrixAlgebra.WritableVector;
import dr.math.matrixAlgebra.missingData.MissingOps;
import dr.util.StopWatch;
import dr.util.TaskPool;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.Reportable;
import dr.xml.XMLObject;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.ArrayList;
import java.util.List;
import org.ejml.data.DenseMatrix64F;

public class IntegratedLoadingsGradient
implements GradientWrtParameterProvider,
VariableListener,
Reportable {
    private final TreeTrait<List<WrappedNormalSufficientStatistics>> fullConditionalDensity;
    private final IntegratedFactorAnalysisLikelihood factorAnalysisLikelihood;
    private final int dimTrait;
    private final int dimFactors;
    private final Tree tree;
    private final Likelihood likelihood;
    private final double[] data;
    private final boolean[] missing;
    private final ThreadUseProvider threadUseProvider;
    private final RemainderCompProvider remainderCompProvider;
    private final TaskPool taskPool;
    private StopWatch[] stopWatches;
    private static final boolean TIMING = false;
    private static final boolean DEBUG = false;
    private static final String PARSER_NAME = "integratedFactorAnalysisLoadingsGradient";
    private static final String THREAD_TYPE = "threadType";
    private static final String PARALLEL = "parallel";
    private static final String SERIAL = "serial";
    private static final String REMAINDER_COMPUTATION = "remainderComputation";
    private static final String FULL = "full";
    private static final String SKIP = "skip";
    public static AbstractXMLObjectParser PARSER = new AbstractXMLObjectParser(){
        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{new ElementRule(IntegratedFactorAnalysisLikelihood.class), new ElementRule(TreeDataLikelihood.class), new ElementRule(TaskPool.class, true), AttributeRule.newStringRule("threadType", true), AttributeRule.newStringRule("remainderComputation", true)};

        @Override
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            RemainderCompProvider remainderCompProvider;
            ThreadUseProvider threadUseProvider;
            TreeDataLikelihood treeDataLikelihood = (TreeDataLikelihood)xMLObject.getChild(TreeDataLikelihood.class);
            IntegratedFactorAnalysisLikelihood integratedFactorAnalysisLikelihood = (IntegratedFactorAnalysisLikelihood)xMLObject.getChild(IntegratedFactorAnalysisLikelihood.class);
            DataLikelihoodDelegate dataLikelihoodDelegate = treeDataLikelihood.getDataLikelihoodDelegate();
            if (!(dataLikelihoodDelegate instanceof ContinuousDataLikelihoodDelegate)) {
                throw new XMLParseException("TODO");
            }
            ContinuousDataLikelihoodDelegate continuousDataLikelihoodDelegate = (ContinuousDataLikelihoodDelegate)dataLikelihoodDelegate;
            TaskPool taskPool = (TaskPool)xMLObject.getChild(TaskPool.class);
            String string = xMLObject.getAttribute(IntegratedLoadingsGradient.THREAD_TYPE, IntegratedLoadingsGradient.PARALLEL);
            if (string.equalsIgnoreCase(IntegratedLoadingsGradient.PARALLEL)) {
                threadUseProvider = ThreadUseProvider.PARALLEL;
            } else if (string.equalsIgnoreCase(IntegratedLoadingsGradient.SERIAL)) {
                threadUseProvider = ThreadUseProvider.SERIAL;
            } else {
                throw new XMLParseException("The attribute threadType must have values \"parallel\" or \"serial\".");
            }
            String string2 = xMLObject.getAttribute(IntegratedLoadingsGradient.REMAINDER_COMPUTATION, IntegratedLoadingsGradient.SKIP);
            if (string2.equalsIgnoreCase(IntegratedLoadingsGradient.SKIP)) {
                remainderCompProvider = RemainderCompProvider.SKIP;
            } else if (string2.equalsIgnoreCase(IntegratedLoadingsGradient.FULL)) {
                remainderCompProvider = RemainderCompProvider.FULL;
            } else {
                throw new XMLParseException("The attribute remainderComputation must have values \"skip\" or \"full\".");
            }
            if (taskPool != null && string != IntegratedLoadingsGradient.PARALLEL) {
                throw new XMLParseException("Cannot simultaneously provide taskPool and threadType=\"" + string + "\". Please either change to " + IntegratedLoadingsGradient.THREAD_TYPE + "=\"" + IntegratedLoadingsGradient.PARALLEL + "\" or remove the " + "taskPool" + " element.");
            }
            return new IntegratedLoadingsGradient(treeDataLikelihood, continuousDataLikelihoodDelegate, integratedFactorAnalysisLikelihood, taskPool, threadUseProvider, remainderCompProvider);
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }

        @Override
        public String getParserDescription() {
            return "Generates a gradient provider for the loadings matrix when factors are integrated out";
        }

        @Override
        public Class getReturnType() {
            return IntegratedLoadingsGradient.class;
        }

        @Override
        public String getParserName() {
            return IntegratedLoadingsGradient.PARSER_NAME;
        }
    };

    private IntegratedLoadingsGradient(TreeDataLikelihood treeDataLikelihood, ContinuousDataLikelihoodDelegate continuousDataLikelihoodDelegate, IntegratedFactorAnalysisLikelihood integratedFactorAnalysisLikelihood, TaskPool taskPool, ThreadUseProvider threadUseProvider, RemainderCompProvider remainderCompProvider) {
        this.factorAnalysisLikelihood = integratedFactorAnalysisLikelihood;
        String string = integratedFactorAnalysisLikelihood.getModelName();
        String string2 = WrappedTipFullConditionalDistributionDelegate.getName(string);
        if (treeDataLikelihood.getTreeTrait(string2) == null) {
            continuousDataLikelihoodDelegate.addWrappedFullConditionalDensityTrait(string);
        }
        this.fullConditionalDensity = LinearOrderTreePrecisionTraitProductProvider.castTreeTrait(treeDataLikelihood.getTreeTrait(string2));
        this.tree = treeDataLikelihood.getTree();
        this.dimTrait = integratedFactorAnalysisLikelihood.getDataDimension();
        this.dimFactors = integratedFactorAnalysisLikelihood.getNumberOfFactors();
        CompoundParameter compoundParameter = integratedFactorAnalysisLikelihood.getParameter();
        this.data = compoundParameter.getParameterValues();
        compoundParameter.addVariableListener(this);
        this.missing = this.getMissing(integratedFactorAnalysisLikelihood.getMissingDataIndices(), compoundParameter.getDimension());
        ArrayList<Likelihood> arrayList = new ArrayList<Likelihood>();
        arrayList.add(treeDataLikelihood);
        arrayList.add(integratedFactorAnalysisLikelihood);
        this.likelihood = new CompoundLikelihood(arrayList);
        TaskPool taskPool2 = this.taskPool = taskPool != null ? taskPool : new TaskPool(this.tree.getExternalNodeCount(), 1);
        if (this.taskPool.getNumTaxon() != this.tree.getExternalNodeCount()) {
            throw new IllegalArgumentException("Incorrectly specified TaskPool");
        }
        this.threadUseProvider = threadUseProvider;
        this.remainderCompProvider = remainderCompProvider;
    }

    private boolean[] getMissing(List<Integer> list, int n) {
        boolean[] blArray = new boolean[n];
        for (int n2 : list) {
            blArray[n2] = true;
        }
        return blArray;
    }

    @Override
    public Likelihood getLikelihood() {
        return this.likelihood;
    }

    @Override
    public Parameter getParameter() {
        return this.factorAnalysisLikelihood.getLoadings();
    }

    @Override
    public int getDimension() {
        return this.dimFactors * this.dimTrait;
    }

    private ReadableMatrix shiftToSecondMoment(WrappedMatrix wrappedMatrix, ReadableVector readableVector) {
        assert (wrappedMatrix.getMajorDim() == wrappedMatrix.getMinorDim());
        assert (wrappedMatrix.getMajorDim() == readableVector.getDim());
        int n = wrappedMatrix.getMajorDim();
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                wrappedMatrix.set(i, j, wrappedMatrix.get(i, j) + readableVector.get(i) * readableVector.get(j));
            }
        }
        return wrappedMatrix;
    }

    private WrappedNormalSufficientStatistics getWeightedAverage(ReadableVector readableVector, ReadableMatrix readableMatrix, ReadableVector readableVector2, ReadableMatrix readableMatrix2) {
        assert (readableVector.getDim() == readableVector2.getDim());
        assert (readableMatrix.getDim() == readableMatrix2.getDim());
        assert (readableVector.getDim() == readableMatrix.getMinorDim());
        assert (readableVector.getDim() == readableMatrix.getMajorDim());
        WrappedVector.Raw raw = new WrappedVector.Raw(new double[readableVector.getDim()], 0, this.dimFactors);
        DenseMatrix64F denseMatrix64F = new DenseMatrix64F(this.dimFactors, this.dimFactors);
        DenseMatrix64F denseMatrix64F2 = new DenseMatrix64F(this.dimFactors, this.dimFactors);
        WrappedMatrix.WrappedDenseMatrix wrappedDenseMatrix = new WrappedMatrix.WrappedDenseMatrix(denseMatrix64F);
        WrappedMatrix.WrappedDenseMatrix wrappedDenseMatrix2 = new WrappedMatrix.WrappedDenseMatrix(denseMatrix64F2);
        MissingOps.add(readableMatrix, readableMatrix2, wrappedDenseMatrix);
        MissingOps.safeInvert2(denseMatrix64F, denseMatrix64F2, false);
        MissingOps.weightedAverage(readableVector, readableMatrix, readableVector2, readableMatrix2, (WritableVector)raw, wrappedDenseMatrix2, this.dimFactors);
        return new WrappedNormalSufficientStatistics(raw, wrappedDenseMatrix, wrappedDenseMatrix2);
    }

    @Override
    public double[] getGradientLogDensity() {
        double[][] dArray = new double[this.taskPool.getNumThreads()][this.getDimension()];
        WrappedVector.Parameter parameter = new WrappedVector.Parameter(this.factorAnalysisLikelihood.getPrecision());
        ReadableMatrix readableMatrix = ReadableMatrix.Utils.transposeProxy(new WrappedMatrix.MatrixParameter(this.factorAnalysisLikelihood.getLoadings()));
        double[] dArray2 = this.factorAnalysisLikelihood.getPrecision().getParameterValues();
        double[] dArray3 = ReadableMatrix.Utils.toArray(new WrappedMatrix.MatrixParameter(this.factorAnalysisLikelihood.getLoadings()));
        assert (parameter.getDim() == this.dimTrait);
        assert (readableMatrix.getMajorDim() == this.dimFactors);
        assert (readableMatrix.getMinorDim() == this.dimTrait);
        if (this.remainderCompProvider.computeRemainder()) {
            this.likelihood.getLogLikelihood();
        }
        List<WrappedNormalSufficientStatistics> list = this.fullConditionalDensity.getTrait(this.tree, null);
        assert (list.size() == this.tree.getExternalNodeCount());
        if (!this.threadUseProvider.usePool()) {
            int n3 = this.tree.getExternalNodeCount();
            for (int i = 0; i < n3; ++i) {
                this.computeGradientForOneTaxon(0, i, readableMatrix, dArray3, parameter, dArray2, list.get(i), dArray);
            }
        } else {
            this.taskPool.fork((n, n2) -> this.computeGradientForOneTaxon(n2, n, readableMatrix, dArray3, parameter, dArray2, (WrappedNormalSufficientStatistics)list.get(n), dArray));
        }
        return IntegratedLoadingsGradient.join(dArray);
    }

    private void computeGradientForOneTaxon(int n, int n2, ReadableMatrix readableMatrix, double[] dArray, ReadableVector readableVector, double[] dArray2, WrappedNormalSufficientStatistics wrappedNormalSufficientStatistics, double[][] dArray3) {
        WrappedNormalSufficientStatistics wrappedNormalSufficientStatistics2 = this.getTipKernel(n2);
        WrappedVector wrappedVector = wrappedNormalSufficientStatistics2.getMean();
        WrappedMatrix wrappedMatrix = wrappedNormalSufficientStatistics2.getPrecision();
        WrappedVector wrappedVector2 = wrappedNormalSufficientStatistics.getMean();
        WrappedMatrix wrappedMatrix2 = wrappedNormalSufficientStatistics.getPrecision();
        WrappedMatrix wrappedMatrix3 = wrappedNormalSufficientStatistics.getVariance();
        WrappedNormalSufficientStatistics wrappedNormalSufficientStatistics3 = this.getWeightedAverage(wrappedVector2, wrappedMatrix2, wrappedVector, wrappedMatrix);
        WrappedVector wrappedVector3 = wrappedNormalSufficientStatistics3.getMean();
        WrappedMatrix wrappedMatrix4 = wrappedNormalSufficientStatistics3.getVariance();
        ReadableMatrix readableMatrix2 = this.shiftToSecondMoment(wrappedMatrix4, wrappedVector3);
        double[] dArray4 = ReadableMatrix.Utils.toArray(readableMatrix2);
        for (int i = 0; i < this.dimFactors; ++i) {
            double d = wrappedVector3.get(i);
            for (int j = 0; j < this.dimTrait; ++j) {
                if (this.missing[n2 * this.dimTrait + j]) continue;
                double d2 = 0.0;
                for (int k = 0; k < this.dimFactors; ++k) {
                    d2 += dArray4[i * this.dimFactors + k] * dArray[j * this.dimFactors + k];
                }
                double[] dArray5 = dArray3[n];
                int n3 = i * this.dimTrait + j;
                dArray5[n3] = dArray5[n3] + (d * this.data[n2 * this.dimTrait + j] - d2) * dArray2[j];
            }
        }
    }

    private static double[] join(double[][] dArray) {
        int n = dArray.length;
        int n2 = dArray[0].length;
        double[] dArray2 = dArray[0];
        for (int i = 1; i < n; ++i) {
            double[] dArray3 = dArray[i];
            for (int j = 0; j < n2; ++j) {
                int n3 = j;
                dArray2[n3] = dArray2[n3] + dArray3[j];
            }
        }
        return dArray2;
    }

    private WrappedNormalSufficientStatistics getTipKernel(int n) {
        double[] dArray = this.factorAnalysisLikelihood.getTipPartial(n, false);
        return new WrappedNormalSufficientStatistics(dArray, 0, this.dimFactors, null, PrecisionType.FULL);
    }

    @Override
    public String getReport() {
        String string = "";
        string = string + GradientWrtParameterProvider.getReportAndCheckForError(this, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, null);
        return string;
    }

    private String timingInfo() {
        StringBuilder stringBuilder = new StringBuilder("\nTiming in IntegratedLoadingsGradient\n");
        for (StopWatch stopWatch : this.stopWatches) {
            stringBuilder.append("\t").append(stopWatch.toString()).append("\n");
            stopWatch.reset();
        }
        return stringBuilder.toString();
    }

    @Override
    public void variableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        throw new RuntimeException("Trait data is not cached");
    }

    private static enum RemainderCompProvider {
        FULL{

            @Override
            boolean computeRemainder() {
                return true;
            }
        }
        ,
        SKIP{

            @Override
            boolean computeRemainder() {
                return false;
            }
        };


        abstract boolean computeRemainder();
    }

    private static enum ThreadUseProvider {
        PARALLEL{

            @Override
            boolean usePool() {
                return true;
            }
        }
        ,
        SERIAL{

            @Override
            boolean usePool() {
                return false;
            }
        };


        abstract boolean usePool();
    }
}

