/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.substmodel;

import dr.evolution.datatype.DataType;
import dr.evolution.datatype.GeneralDataType;
import dr.evomodel.siteratemodel.SiteRateModel;
import dr.evomodel.substmodel.BaseSubstitutionModel;
import dr.evomodel.substmodel.EigenDecomposition;
import dr.evomodel.substmodel.FrequencyModel;
import dr.evomodel.substmodel.ProductChainFrequencyModel;
import dr.evomodel.substmodel.SubstitutionModel;
import dr.evomodel.substmodel.SubstitutionProcess;
import dr.inference.model.Model;
import dr.math.KroneckerOperation;
import dr.util.Citable;
import dr.util.Citation;
import dr.util.CommonCitations;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Substitution model defined on the product of the state spaces of several
 * base substitution models.
 *
 * <p>The number of states is the product of the base models' state counts.
 * The rate matrix and eigen-decomposition are assembled from the base models
 * with {@link KroneckerOperation#sum} and {@link KroneckerOperation#product}
 * and are rebuilt lazily whenever {@code updateMatrix} is set.
 *
 * <p>Optionally each base model is paired with a {@link SiteRateModel} whose
 * single category rate rescales that model's contribution. When
 * {@code forceAverageModel} is set, every chain position uses the element-wise
 * average of the base models' rate matrices (and the average of the rates)
 * instead of its own base model.
 */
public class ProductChainSubstitutionModel
extends BaseSubstitutionModel
implements Citable {
    /** Number of base models combined by this product chain (always &gt;= 1). */
    protected final int numBaseModel;
    /** The base substitution models, in chain order. */
    protected final List<SubstitutionModel> baseModels;
    /** Optional per-model rate models; {@code null} means "no rescaling". */
    protected final List<SiteRateModel> rateModels;
    /** State count of each base model's data type, indexed like {@link #baseModels}. */
    protected final int[] stateSizes;
    /** Frequency model built over the base models' frequency models. */
    protected final ProductChainFrequencyModel pcFreqModel;
    /** Flattened rate matrix of the product chain; {@code null} until first computed. */
    protected double[] rateMatrix = null;
    private final boolean forceAverageModel;
    /** Lazily created averaged process; reset to {@code null} on every model change. */
    private SubstitutionProcess averageModel = null;

    /**
     * Creates a product chain without rate models and without averaging.
     *
     * @param name       model name passed to the superclass
     * @param baseModels non-empty list of base substitution models
     */
    public ProductChainSubstitutionModel(String name, List<SubstitutionModel> baseModels) {
        this(name, baseModels, null);
    }

    /**
     * Creates a product chain with optional rate models and without averaging.
     *
     * @param name       model name passed to the superclass
     * @param baseModels non-empty list of base substitution models
     * @param rateModels per-model rate models, or {@code null} for none
     */
    public ProductChainSubstitutionModel(String name, List<SubstitutionModel> baseModels, List<SiteRateModel> rateModels) {
        this(name, baseModels, rateModels, false);
    }

    /**
     * Creates a product chain.
     *
     * @param name              model name passed to the superclass
     * @param baseModels        non-empty list of base substitution models
     * @param rateModels        per-model rate models (indexed like {@code baseModels}),
     *                          or {@code null} for none
     * @param forceAverageModel if {@code true}, use the averaged base model and
     *                          averaged rate at every chain position
     * @throws RuntimeException if {@code baseModels} is empty, or if any rate
     *                          model has more than one category
     */
    public ProductChainSubstitutionModel(String name, List<SubstitutionModel> baseModels, List<SiteRateModel> rateModels, boolean forceAverageModel) {
        super(name);
        this.baseModels = baseModels;
        this.rateModels = rateModels;
        this.forceAverageModel = forceAverageModel;
        this.numBaseModel = baseModels.size();
        if (this.numBaseModel == 0) {
            throw new RuntimeException("May not construct ProductChainSubstitutionModel with 0 base models");
        }
        if (rateModels != null) {
            for (SiteRateModel rateModel : rateModels) {
                if (rateModel.getCategoryCount() > 1) {
                    throw new RuntimeException("ProductChainSubstitutionModels with multiple categories not yet implemented");
                }
            }
        }

        List<FrequencyModel> frequencyModels = new ArrayList<FrequencyModel>();
        this.stateSizes = new int[this.numBaseModel];
        this.stateCount = 1;
        for (int i = 0; i < this.numBaseModel; ++i) {
            SubstitutionModel baseModel = baseModels.get(i);
            frequencyModels.add(baseModel.getFrequencyModel());
            int size = baseModel.getDataType().getStateCount();
            this.stateSizes[i] = size;
            this.stateCount *= size;
            this.addModel(baseModel);
            // Rate models are optional: the shorter constructors pass null here,
            // so only register them as sub-models when they were supplied.
            if (rateModels != null) {
                this.addModel(rateModels.get(i));
            }
        }
        this.pcFreqModel = new ProductChainFrequencyModel("pc", frequencyModels);
        this.addModel(this.pcFreqModel);
        this.dataType = new GeneralDataType(this.getCharacterStrings());
        this.updateMatrix = true;
    }

    @Override
    public Citation.Category getCategory() {
        return Citation.Category.SUBSTITUTION_MODELS;
    }

    @Override
    public String getDescription() {
        return "Product chain substitution model";
    }

    @Override
    public List<Citation> getCitations() {
        return Collections.singletonList(CommonCitations.OBRIEN_2009_LEARNING);
    }

    /**
     * Returns the eigen-decomposition of the product chain, recomputing it
     * (together with {@link #rateMatrix}) first if {@code updateMatrix} is set.
     * The recomputation is guarded by this object's monitor.
     *
     * @return the current eigen-decomposition
     */
    @Override
    public EigenDecomposition getEigenDecomposition() {
        synchronized (this) {
            if (this.updateMatrix) {
                this.computeKroneckerSumsAndProducts();
            }
        }
        return this.eigenDecomposition;
    }

    /**
     * Builds the state codes of the product data type. Models are processed
     * from last to first, each prepending its code, so a product state's code
     * is the concatenation of the base models' codes in chain order and the
     * first model's state is the slowest-varying index.
     */
    private String[] getCharacterStrings() {
        String[] codes = null;
        for (int i = this.numBaseModel - 1; i >= 0; --i) {
            codes = this.recursivelyAppendCharacterStates(this.baseModels.get(i).getDataType(), codes);
        }
        return codes;
    }

    /**
     * Reacts to a change in any registered sub-model: drops the cached
     * averaged model, delegates to the superclass and notifies listeners.
     */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
        // Invalidate the cache before anyone is notified, so that a listener
        // querying this model during notification cannot observe a stale average.
        this.averageModel = null;
        super.handleModelChangedEvent(model, object, n);
        this.fireModelChanged(model);
    }

    /**
     * Prepends every state code of {@code dataType} to every string in
     * {@code suffixes}.
     *
     * @param dataType data type whose state codes are prepended
     * @param suffixes existing code strings, or {@code null} to start from a
     *                 single empty string
     * @return array of length {@code stateCount * suffixes.length}, grouped by
     *         the state of {@code dataType}
     */
    private String[] recursivelyAppendCharacterStates(DataType dataType, String[] suffixes) {
        String[] tails = (suffixes != null) ? suffixes : new String[]{""};
        int tailCount = tails.length;
        int states = dataType.getStateCount();
        String[] combined = new String[tailCount * states];
        for (int state = 0; state < states; ++state) {
            String code = dataType.getCode(state);
            int offset = state * tailCount;
            for (int j = 0; j < tailCount; ++j) {
                combined[offset + j] = code + tails[j];
            }
        }
        return combined;
    }

    /**
     * Copies the flattened {@code stateCount x stateCount} rate matrix of the
     * product chain into {@code dArray}, recomputing it first if required.
     *
     * @param dArray destination; must hold at least {@code stateCount * stateCount} entries
     */
    @Override
    public void getInfinitesimalMatrix(double[] dArray) {
        // Called for its side effect: refreshes rateMatrix when updateMatrix is set.
        this.getEigenDecomposition();
        System.arraycopy(this.rateMatrix, 0, dArray, 0, this.stateCount * this.stateCount);
    }

    /**
     * Returns the rate used to scale base model {@code n}: that model's own
     * category-0 rate, or the mean category-0 rate over all rate models when
     * averaging is forced. Dereferences {@link #rateModels} without a null
     * check, so callers must ensure rate models were supplied.
     */
    double getRateForModel(int n) {
        if (!this.forceAverageModel) {
            return this.rateModels.get(n).getRateForCategory(0);
        }
        double total = 0.0;
        int count = this.rateModels.size();
        for (int i = 0; i < count; ++i) {
            total += this.rateModels.get(i).getRateForCategory(0);
        }
        return total / (double)count;
    }

    /**
     * Scales {@code dArray} by the rate of base model {@code n}.
     *
     * @param dArray values to scale (rate-matrix entries or eigenvalues)
     * @param n      index of the base model the values belong to
     * @return {@code dArray} itself when there are no rate models or the rate
     *         is exactly 1.0; otherwise a new scaled array
     */
    protected double[] scaleForProductChain(double[] dArray, int n) {
        if (this.rateModels == null) {
            return dArray;
        }
        double rate = this.getRateForModel(n);
        if (rate == 1.0) {
            return dArray;
        }
        double[] scaled = new double[dArray.length];
        for (int i = 0; i < scaled.length; ++i) {
            scaled[i] = rate * dArray[i];
        }
        return scaled;
    }

    /**
     * Creates a process whose rate matrix is the element-wise mean of all base
     * models' rate matrices. Only {@code getInfinitesimalMatrix} and
     * {@code getEigenDecomposition} are supported; both results are cached
     * inside the returned object.
     *
     * <p>NOTE(review): the matrix length is taken from the first caller's
     * array and the decomposition uses {@code stateSizes[0]}, so this appears
     * to assume all base models share one state count - confirm.
     */
    private SubstitutionProcess computeAverageModel() {
        return new SubstitutionProcess(){
            private double[] averageMatrix = null;
            private EigenDecomposition cachedDecomposition = null;

            @Override
            public void getTransitionProbabilities(double d, double[] dArray) {
                throw new RuntimeException("Should not be called");
            }

            @Override
            public EigenDecomposition getEigenDecomposition() {
                if (this.cachedDecomposition == null) {
                    int size = ProductChainSubstitutionModel.this.stateSizes[0];
                    double[] flat = new double[size * size];
                    this.getInfinitesimalMatrix(flat);
                    // Unflatten row by row into the 2-D form decomposeMatrix takes.
                    double[][] matrix = new double[size][size];
                    for (int row = 0; row < size; ++row) {
                        System.arraycopy(flat, row * size, matrix[row], 0, size);
                    }
                    this.cachedDecomposition = ProductChainSubstitutionModel.this.getDefaultEigenSystem(size).decomposeMatrix(matrix);
                }
                return this.cachedDecomposition;
            }

            @Override
            public FrequencyModel getFrequencyModel() {
                throw new RuntimeException("Should not be called");
            }

            @Override
            public void getInfinitesimalMatrix(double[] dArray) {
                if (this.averageMatrix == null) {
                    int length = dArray.length;
                    int modelCount = ProductChainSubstitutionModel.this.baseModels.size();
                    double[][] matrices = new double[modelCount][length];
                    for (int m = 0; m < modelCount; ++m) {
                        ProductChainSubstitutionModel.this.baseModels.get(m).getInfinitesimalMatrix(matrices[m]);
                    }
                    this.averageMatrix = new double[length];
                    for (int k = 0; k < length; ++k) {
                        double total = 0.0;
                        for (int m = 0; m < modelCount; ++m) {
                            total += matrices[m][k];
                        }
                        this.averageMatrix[k] = total / (double)modelCount;
                    }
                }
                System.arraycopy(this.averageMatrix, 0, dArray, 0, this.averageMatrix.length);
            }

            @Override
            public DataType getDataType() {
                throw new RuntimeException("Should not be called");
            }

            @Override
            public boolean canReturnComplexDiagonalization() {
                throw new RuntimeException("Should not be called");
            }
        };
    }

    /**
     * Returns the process to use at chain position {@code n}: the n-th base
     * model, or the shared (lazily created) averaged model when averaging is
     * forced.
     */
    private SubstitutionProcess getBaseModel(int n) {
        if (!this.forceAverageModel) {
            return this.baseModels.get(n);
        }
        if (this.averageModel == null) {
            this.averageModel = this.computeAverageModel();
        }
        return this.averageModel;
    }

    /**
     * Rebuilds {@link #rateMatrix} and the eigen-decomposition by folding the
     * base models together one at a time: rate matrices and eigenvalues are
     * combined with {@code KroneckerOperation.sum}, eigenvectors and
     * (transposed) inverse eigenvectors with {@code KroneckerOperation.product}.
     * Clears {@code updateMatrix} when done.
     */
    private void computeKroneckerSumsAndProducts() {
        int dim = this.stateSizes[0];
        SubstitutionProcess first = this.getBaseModel(0);
        double[] rates = new double[dim * dim];
        first.getInfinitesimalMatrix(rates);
        rates = this.scaleForProductChain(rates, 0);

        EigenDecomposition firstEigen = first.getEigenDecomposition();
        double[] eigenValues = this.scaleForProductChain(firstEigen.getEigenValues(), 0);
        double[] eigenVectors = firstEigen.getEigenVectors();
        // The inverse eigenvectors are accumulated in transposed form and
        // transposed back once after the loop.
        double[] inverseTransposed = ProductChainSubstitutionModel.transpose(firstEigen.getInverseEigenVectors(), dim);

        for (int i = 1; i < this.numBaseModel; ++i) {
            SubstitutionProcess next = this.getBaseModel(i);
            int nextDim = this.stateSizes[i];

            double[] nextRates = new double[nextDim * nextDim];
            next.getInfinitesimalMatrix(nextRates);
            nextRates = this.scaleForProductChain(nextRates, i);
            rates = KroneckerOperation.sum(rates, dim, nextRates, nextDim);

            EigenDecomposition nextEigen = next.getEigenDecomposition();
            double[] nextValues = this.scaleForProductChain(nextEigen.getEigenValues(), i);
            double[] nextVectors = nextEigen.getEigenVectors();
            double[] nextInverseTransposed = ProductChainSubstitutionModel.transpose(nextEigen.getInverseEigenVectors(), nextDim);

            eigenValues = KroneckerOperation.sum(eigenValues, nextValues);
            eigenVectors = KroneckerOperation.product(eigenVectors, dim, dim, nextVectors, nextDim, nextDim);
            inverseTransposed = KroneckerOperation.product(inverseTransposed, dim, dim, nextInverseTransposed, nextDim, nextDim);
            dim *= nextDim;
        }

        this.rateMatrix = rates;
        this.eigenDecomposition = new EigenDecomposition(eigenVectors, ProductChainSubstitutionModel.transpose(inverseTransposed, dim), eigenValues);
        this.updateMatrix = false;
    }

    /**
     * Returns the transpose of a flattened {@code n x n} matrix as a new array.
     *
     * @param dArray flattened square matrix with at least {@code n * n} entries
     * @param n      matrix dimension
     */
    private static double[] transpose(double[] dArray, int n) {
        double[] result = new double[n * n];
        for (int row = 0; row < n; ++row) {
            for (int col = 0; col < n; ++col) {
                result[col * n + row] = dArray[row * n + col];
            }
        }
        return result;
    }

    /** Returns the product-chain frequency model built in the constructor. */
    @Override
    public FrequencyModel getFrequencyModel() {
        return this.pcFreqModel;
    }

    /** Intentionally empty: this class does nothing in this hook. */
    @Override
    protected void frequenciesChanged() {
    }

    /** Intentionally empty: this class does nothing in this hook. */
    @Override
    protected void ratesChanged() {
    }

    /** Intentionally empty: this class does not fill {@code dArray}. */
    @Override
    protected void setupRelativeRates(double[] dArray) {
    }
}

