/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.substmodel;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchmodel.BranchModel;
import dr.evomodel.branchmodel.EpochBranchModel;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.siteratemodel.GammaSiteRateModel;
import dr.evomodel.substmodel.ComplexSubstitutionModel;
import dr.evomodel.substmodel.GlmSubstitutionModel;
import dr.evomodel.substmodel.MarkovJumpsSubstitutionModel;
import dr.evomodel.substmodel.SubstitutionModel;
import dr.evomodel.tree.TreeStatistic;
import dr.inference.markovjumps.MarkovJumpsCore;
import dr.inference.markovjumps.MarkovJumpsType;
import dr.inference.model.Parameter;
import dr.xml.Reportable;
import java.util.List;

/**
 * Tree statistic that, for every off-diagonal rate of a {@link GlmSubstitutionModel}, compares the
 * expected number of transitions of that type with and without the rate's random effect.
 * <p>
 * For rate index {@code i} the raw value is
 * {@code count(full model) - count(null model)}. Here {@code count} is the expected number of
 * {@code fromState[i] -> toState[i]} jumps from a {@link MarkovJumpsSubstitutionModel}
 * ({@link MarkovJumpsType#COUNTS}). Starting states are weighted by the frequency model's
 * frequencies. The whole tree is treated as one branch whose length is the summed branch lengths
 * in expected substitutions. Results are averaged over gamma rate categories when a site model is
 * supplied, and multiplied by the number of sites.
 * <p>
 * The null model alters only rate {@code i}. If {@code nullIsZero} is true, that rate is divided by
 * {@code exp(randomEffect[i])}. Otherwise the rate is set to {@code 0}.
 * <p>
 * When {@code threshold > 0} the raw difference is converted to an indicator: {@code 1.0} if its
 * absolute value exceeds the threshold, else {@code 0.0}.
 */
public class SubstitutionModelRandomEffectClassifier
extends TreeStatistic
implements Reportable {
    /** Number of off-diagonal rates: {@code nStates * (nStates - 1)}. */
    private final int dim;
    private final int nStates;
    /** Multiplier applied to the per-site expected count. */
    private final int nSites;
    /** Indicator threshold; values {@code <= 0} mean "report the raw difference". */
    private final double threshold;
    private Tree tree;
    private final GlmSubstitutionModel glmSubstitutionModel;
    private final EpochBranchModel epochBranchModel;
    private final GammaSiteRateModel siteModel;
    private final BranchRateModel branchRates;
    /** Rates fed to {@link #proxy}; rewritten by {@link #makeProxyModel} before every count. */
    private final Parameter proxyRates;
    private final ComplexSubstitutionModel proxy;
    private final MarkovJumpsSubstitutionModel markovJumps;
    private final boolean usingRateVariation;
    private final boolean usingEpochs;
    /** {@code epochUsesTargetModel[k]} is true when epoch model {@code k} is the target GLM model. */
    private final boolean[] epochUsesTargetModel;
    private final boolean nullIsZero;
    /** Source state of each off-diagonal rate index. */
    private final int[] fromState;
    /** Destination state of each off-diagonal rate index. */
    private final int[] toState;

    /**
     * Creates the classifier.
     *
     * @param string               statistic id, passed to the superclass
     * @param tree                 tree whose branch lengths are summed into the total time
     * @param glmSubstitutionModel target GLM substitution model; must not be {@code null}
     * @param epochBranchModel     optional epoch model ({@code null} if not used). When given, the
     *                             GLM model must appear in at least one epoch, and only time spent
     *                             in those epochs is counted.
     * @param branchRateModel      branch rates used to convert time to substitutions; {@code null}
     *                             is treated as a rate of {@code 1.0} on every branch
     * @param gammaSiteRateModel   optional site-rate model ({@code null} if not used) whose categories
     *                             are averaged over
     * @param n                    number of sites; the per-site count is multiplied by this
     * @param d                    indicator threshold; {@code <= 0} reports raw differences
     * @param bl                   {@code true}: the null model removes the random effect;
     *                             {@code false}: the null model sets the rate to zero
     * @throws RuntimeException if the GLM model is {@code null}, or if it is not used by any epoch
     *                          of the supplied epoch model
     */
    public SubstitutionModelRandomEffectClassifier(String string, Tree tree, GlmSubstitutionModel glmSubstitutionModel, EpochBranchModel epochBranchModel, BranchRateModel branchRateModel, GammaSiteRateModel gammaSiteRateModel, int n, double d, boolean bl) {
        super(string);
        this.tree = tree;
        this.glmSubstitutionModel = glmSubstitutionModel;
        this.epochBranchModel = epochBranchModel;
        this.siteModel = gammaSiteRateModel;
        this.branchRates = branchRateModel;
        this.nullIsZero = bl;
        this.usingRateVariation = gammaSiteRateModel != null;
        this.usingEpochs = epochBranchModel != null;

        // The parameter is statically typed as GlmSubstitutionModel, so the only way to fail the
        // original "instanceof" check is a null argument; the message is kept for compatibility.
        if (glmSubstitutionModel == null) {
            throw new RuntimeException("SubstitutionModelRandomEffectClassifier only works for GLM substitution models.");
        }

        if (usingEpochs) {
            List<SubstitutionModel> epochModels = epochBranchModel.getSubstitutionModels();
            this.epochUsesTargetModel = new boolean[epochModels.size()];
            boolean found = false;
            for (int k = 0; k < epochModels.size(); ++k) {
                // Identity comparison: we want epochs that use this exact model instance.
                if (epochModels.get(k) == glmSubstitutionModel) {
                    epochUsesTargetModel[k] = true;
                    found = true;
                }
            }
            if (!found) {
                throw new RuntimeException("Cannot find specified GLM substitution model (id: " + glmSubstitutionModel.getId() + ") in specified epoch model (id: " + epochBranchModel.getId() + ")");
            }
        } else {
            this.epochUsesTargetModel = new boolean[0];
        }

        this.nSites = n;
        this.nStates = glmSubstitutionModel.getFrequencyModel().getDataType().getStateCount();
        this.dim = nStates * (nStates - 1);
        this.threshold = d;

        // Index layout: first half = upper triangle (from < to) in row-major order, second half =
        // the same pairs reversed. NOTE(review): presumably this matches the ordering produced by
        // GlmSubstitutionModel.setupRelativeRates; verify if that method changes.
        this.fromState = new int[dim];
        this.toState = new int[dim];
        int half = dim / 2;
        int index = 0;
        for (int from = 0; from < nStates - 1; ++from) {
            for (int to = from + 1; to < nStates; ++to) {
                fromState[index] = from;
                toState[index] = to;
                fromState[index + half] = to;
                toState[index + half] = from;
                ++index;
            }
        }

        this.proxyRates = new Parameter.Default(dim);
        this.proxy = new ComplexSubstitutionModel("internalGlmProxyForSubstitutionModelRandomEffectClassifier", glmSubstitutionModel.getDataType(), glmSubstitutionModel.getFrequencyModel(), proxyRates);
        this.markovJumps = new MarkovJumpsSubstitutionModel(proxy, MarkovJumpsType.COUNTS);
    }

    /** Replaces the tree whose length is used for the expected counts. */
    @Override
    public void setTree(Tree tree) {
        this.tree = tree;
    }

    /** Returns the tree whose length is used for the expected counts. */
    @Override
    public Tree getTree() {
        return tree;
    }

    /** Returns the number of off-diagonal rates, i.e. one statistic value per rate. */
    @Override
    public int getDimension() {
        return dim;
    }

    /**
     * Loads the GLM's relative rates into the proxy model, optionally altering one rate to form
     * the null model.
     *
     * @param rateIndex    off-diagonal rate to alter when building the null model
     * @param useFullModel {@code true} to copy the GLM rates unchanged
     */
    private void makeProxyModel(int rateIndex, boolean useFullModel) {
        double[] rates = new double[dim];
        glmSubstitutionModel.setupRelativeRates(rates);
        if (!useFullModel) {
            if (nullIsZero) {
                // Undo the (log-scale) random effect for this rate only.
                // NOTE(review): assumes random effect 0 has one entry per off-diagonal rate, in the
                // same order as setupRelativeRates; confirm.
                double[] randomEffects = glmSubstitutionModel.getGeneralizedLinearModel().getRandomEffect(0).getParameterValues();
                rates[rateIndex] /= Math.exp(randomEffects[rateIndex]);
            } else {
                rates[rateIndex] = 0.0;
            }
        }
        for (int i = 0; i < dim; ++i) {
            proxyRates.setParameterValue(i, rates[i]);
        }
    }

    /**
     * Returns the time (in tree height units) that the branch above {@code nodeRef} spends in
     * epochs that use the target GLM model.
     */
    private double getEpochContribution(NodeRef nodeRef) {
        BranchModel.Mapping mapping = epochBranchModel.getBranchModelMapping(nodeRef);
        int[] order = mapping.getOrder();
        double[] weights = mapping.getWeights();
        double time = 0.0;
        for (int i = 0; i < order.length; ++i) {
            if (epochUsesTargetModel[order[i]]) {
                time += weights[i];
            }
        }
        return time;
    }

    /** Returns the rate of the branch above {@code node}; a missing rate model means rate 1. */
    private double getBranchRate(NodeRef node) {
        return branchRates == null ? 1.0 : branchRates.getBranchRate(tree, node);
    }

    /**
     * Sums time times branch rate over all non-root branches.
     * <p>
     * With epochs, only time spent under the target model counts. The branch rate is applied in
     * both cases so that the result is consistently in expected substitutions.
     */
    private double getTreeLengthInSubstitutions() {
        double length = 0.0;
        NodeRef root = tree.getRoot();
        for (int i = 0; i < tree.getNodeCount(); ++i) {
            NodeRef node = tree.getNode(i);
            if (node == root) {
                continue;
            }
            double time = usingEpochs
                    ? getEpochContribution(node)
                    : tree.getNodeHeight(tree.getParent(node)) - tree.getNodeHeight(node);
            length += time * getBranchRate(node);
        }
        return length;
    }

    /** Applies the optional threshold: raw value if {@code threshold <= 0}, else a 0/1 indicator. */
    private double getDoubleResult(double d) {
        if (threshold <= 0.0) {
            return d;
        }
        return Math.abs(d) > threshold ? 1.0 : 0.0;
    }

    /**
     * Computes the frequency-weighted expected jump count over one time span.
     *
     * @param rateIndex           off-diagonal rate whose transition is registered (and altered
     *                            for the null model)
     * @param time                time span in expected substitutions (already category-scaled)
     * @param useFullModel        {@code true} for the full model, {@code false} for the null model
     * @param countAllTransitions {@code true} to register every transition rather than only
     *                            {@code fromState[rateIndex] -> toState[rateIndex]}
     * @return {@code sum_i freq[i] * sum_j joint[i][j]}, with {@code joint} taken from
     *         {@link MarkovJumpsSubstitutionModel#computeJointStatMarkovJumps}
     */
    private double getCountForRateCategory(int rateIndex, double time, boolean useFullModel, boolean countAllTransitions) {
        makeProxyModel(rateIndex, useFullModel);

        double[] registration = new double[nStates * nStates];
        if (countAllTransitions) {
            MarkovJumpsCore.fillRegistrationMatrix(registration, nStates);
        } else {
            MarkovJumpsCore.fillRegistrationMatrix(registration, fromState[rateIndex], toState[rateIndex], nStates, 1.0);
        }
        markovJumps.setRegistration(registration);

        double[] jointCounts = new double[nStates * nStates];
        markovJumps.computeJointStatMarkovJumps(time, jointCounts);

        double[] frequencies = glmSubstitutionModel.getFrequencyModel().getFrequencies();
        double expected = 0.0;
        for (int i = 0; i < nStates; ++i) {
            for (int j = 0; j < nStates; ++j) {
                expected += jointCounts[i * nStates + j] * frequencies[i];
            }
        }
        return expected;
    }

    /**
     * Computes the expected count for all sites. It averages over site-rate categories when a
     * site model is present.
     *
     * @param treeLength total tree length in expected substitutions
     */
    private double getCount(int rateIndex, boolean useFullModel, boolean countAllTransitions, double treeLength) {
        double perSite;
        if (usingRateVariation) {
            perSite = 0.0;
            for (int c = 0; c < siteModel.getCategoryCount(); ++c) {
                double categoryTime = treeLength * siteModel.getRateForCategory(c);
                perSite += getCountForRateCategory(rateIndex, categoryTime, useFullModel, countAllTransitions) * siteModel.getProportionForCategory(c);
            }
        } else {
            perSite = getCountForRateCategory(rateIndex, treeLength, useFullModel, countAllTransitions);
        }
        return perSite * (double) nSites;
    }

    /** Returns full-model count minus null-model count for one rate; tree length is computed once. */
    private double getCountDifferences(int rateIndex) {
        double treeLength = getTreeLengthInSubstitutions();
        return getCount(rateIndex, true, false, treeLength) - getCount(rateIndex, false, false, treeLength);
    }

    /**
     * Returns the (optionally thresholded) count difference for off-diagonal rate {@code n}.
     *
     * @param n rate index in {@code [0, getDimension())}
     */
    @Override
    public double getStatisticValue(int n) {
        return getDoubleResult(getCountDifferences(n));
    }
}

