/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.MaxEnt;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.Labeling;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.MalletProgressMessageLogger;
import cc.mallet.util.Maths;
import com.google.errorprone.annotations.Var;
import java.util.Arrays;
import java.util.Iterator;
import java.util.logging.Logger;

/**
 * Objective function for training a {@link MaxEnt} classifier: the weighted
 * label log-likelihood of the training instances minus an optional prior
 * penalty on the parameters.
 *
 * <p>Parameters, constraints and the cached gradient are flat arrays laid out
 * row-major by label: the entry for label {@code li} and feature {@code fi}
 * lives at {@code li * numFeatures + fi}. {@code numFeatures} is the data
 * alphabet size plus one; the extra slot ({@code defaultFeatureIndex}) holds
 * the per-label default feature.
 *
 * <p>The value and the gradient are cached and recomputed lazily; any change
 * to the parameters or to the prior configuration marks both caches stale.
 */
public class MaxEntOptimizableByLabelLikelihood implements Optimizable.ByGradientValue {
    private static final Logger logger = MalletLogger.getLogger(MaxEntOptimizableByLabelLikelihood.class.getName());
    private static final Logger progressLogger = MalletProgressMessageLogger.getLogger(MaxEntOptimizableByLabelLikelihood.class.getName() + "-pl");
    static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 1.0;
    static final double DEFAULT_HYPERBOLIC_PRIOR_SLOPE = 0.2;
    static final double DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS = 10.0;
    static final Class DEFAULT_MAXIMIZER_CLASS = LimitedMemoryBFGS.class;
    boolean usingHyperbolicPrior = false;
    boolean usingGaussianPrior = true;
    double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
    double hyperbolicPriorSlope = DEFAULT_HYPERBOLIC_PRIOR_SLOPE;
    double hyperbolicPriorSharpness = DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS;
    Class maximizerClass = DEFAULT_MAXIMIZER_CLASS;
    /** Current parameter vector (shared with the classifier when one is supplied). */
    double[] parameters;
    /** Weighted feature totals of the training data, accumulated per best label. */
    double[] constraints;
    /** Gradient buffer; holds only the expectation term until getValueGradient completes it. */
    double[] cachedGradient;
    MaxEnt theClassifier;
    InstanceList trainingList;
    double cachedValue;
    boolean cachedValueStale;
    boolean cachedGradientStale;
    int numLabels;
    int numFeatures;
    int defaultFeatureIndex;
    FeatureSelection featureSelection;
    FeatureSelection[] perLabelFeatureSelection;
    int numGetValueCalls = 0;
    int numGetValueGradientCalls = 0;

    /** Creates an unconfigured instance; no training data or parameters are set up. */
    public MaxEntOptimizableByLabelLikelihood() {
    }

    /**
     * Builds the objective for the given training set.
     *
     * <p>Stops growth of the target alphabet, allocates the parameter,
     * constraint and gradient arrays, adds the default feature to any feature
     * selection present, and accumulates the constraints from every labeled
     * instance.
     *
     * @param trainingSet       instances to train on; unlabeled instances are skipped
     * @param initialClassifier classifier whose parameters, feature selections and
     *                          default feature index are adopted, or {@code null}
     *                          to create a fresh classifier with zero parameters
     */
    public MaxEntOptimizableByLabelLikelihood(InstanceList trainingSet, MaxEnt initialClassifier) {
        this.trainingList = trainingSet;
        Alphabet fd = trainingSet.getDataAlphabet();
        LabelAlphabet ld = (LabelAlphabet) trainingSet.getTargetAlphabet();
        ld.stopGrowth();
        this.numLabels = ld.size();
        // One extra column per label for the default feature.
        this.numFeatures = fd.size() + 1;
        this.defaultFeatureIndex = this.numFeatures - 1;

        // Freshly allocated double arrays are already zero-filled.
        int numParameters = this.numLabels * this.numFeatures;
        this.parameters = new double[numParameters];
        this.constraints = new double[numParameters];
        this.cachedGradient = new double[numParameters];

        this.featureSelection = trainingSet.getFeatureSelection();
        this.perLabelFeatureSelection = trainingSet.getPerLabelFeatureSelection();
        // The default feature must always be selected.
        if (this.featureSelection != null) {
            this.featureSelection.add(this.defaultFeatureIndex);
        }
        if (this.perLabelFeatureSelection != null) {
            for (int i = 0; i < this.perLabelFeatureSelection.length; ++i) {
                this.perLabelFeatureSelection[i].add(this.defaultFeatureIndex);
            }
        }
        assert (this.featureSelection == null || this.perLabelFeatureSelection == null);

        if (initialClassifier != null) {
            // Continue from the supplied classifier: share its parameter array and settings.
            this.theClassifier = initialClassifier;
            this.parameters = this.theClassifier.parameters;
            this.featureSelection = this.theClassifier.featureSelection;
            this.perLabelFeatureSelection = this.theClassifier.perClassFeatureSelection;
            this.defaultFeatureIndex = this.theClassifier.defaultFeatureIndex;
            assert (initialClassifier.getInstancePipe() == trainingSet.getPipe());
        } else if (this.theClassifier == null) {
            this.theClassifier = new MaxEnt(trainingSet.getPipe(), this.parameters, this.featureSelection, this.perLabelFeatureSelection);
        }
        this.invalidateCaches();

        logger.fine("Number of instances in training list = " + this.trainingList.size());
        for (Instance inst : this.trainingList) {
            double instanceWeight = this.trainingList.getInstanceWeight(inst);
            Labeling labeling = inst.getLabeling();
            if (labeling == null) {
                continue;
            }
            FeatureVector fv = (FeatureVector) inst.getData();
            Alphabet fdict = fv.getAlphabet();
            assert (fdict == fd);
            assert (!Double.isNaN(instanceWeight)) : "instanceWeight is NaN";

            int li = labeling.getBestIndex();
            MatrixOps.rowPlusEquals(this.constraints, this.numFeatures, li, fv, instanceWeight);

            // Report NaN feature values; they are logged, not rejected.
            boolean hasNaN = false;
            for (int i = 0; i < fv.numLocations(); ++i) {
                if (Double.isNaN(fv.valueAtLocation(i))) {
                    logger.info("NaN for feature " + fdict.lookupObject(fv.indexAtLocation(i)).toString());
                    hasNaN = true;
                }
            }
            if (hasNaN) {
                logger.info("NaN in instance: " + inst.getName());
            }

            // The default feature has value 1.0 for every instance.
            this.constraints[li * this.numFeatures + this.defaultFeatureIndex] += instanceWeight;
        }
    }

    /** Returns the classifier whose parameters this objective reads and updates. */
    public MaxEnt getClassifier() {
        return this.theClassifier;
    }

    /** Returns the parameter at the given flat index. */
    @Override
    public double getParameter(int index) {
        return this.parameters[index];
    }

    /** Sets the parameter at the given flat index and marks the caches stale. */
    @Override
    public void setParameter(int index, double v) {
        this.invalidateCaches();
        this.parameters[index] = v;
    }

    /** Returns the length of the parameter vector. */
    @Override
    public int getNumParameters() {
        return this.parameters.length;
    }

    /**
     * Copies the current parameters into {@code buff}.
     *
     * @param buff destination array; must be non-null and exactly
     *             {@link #getNumParameters()} long
     * @throws IllegalArgumentException if {@code buff} is null or has the wrong
     *         length (a replacement array allocated here could never reach the
     *         caller, so the call would silently do nothing)
     */
    @Override
    public void getParameters(double[] buff) {
        if (buff == null || buff.length != this.parameters.length) {
            throw new IllegalArgumentException("buff must be a non-null array of length "
                    + this.parameters.length + ", got "
                    + (buff == null ? "null" : "length " + buff.length));
        }
        System.arraycopy(this.parameters, 0, buff, 0, this.parameters.length);
    }

    /**
     * Replaces the parameter vector with a copy of {@code buff} and marks the
     * caches stale. If the length differs, a new internal array is allocated.
     *
     * @param buff new parameter values; must not be null
     */
    @Override
    public void setParameters(double[] buff) {
        assert (buff != null);
        this.invalidateCaches();
        if (buff.length != this.parameters.length) {
            this.parameters = new double[buff.length];
        }
        System.arraycopy(buff, 0, this.parameters, 0, buff.length);
    }

    /**
     * Returns the objective value: the weighted log-probability the classifier
     * assigns to each instance's best label, summed over labeled instances,
     * minus the prior penalty.
     *
     * <p>As a side effect of a recomputation, the expectation term of the
     * gradient is accumulated into {@code cachedGradient} and the gradient is
     * marked stale so that {@link #getValueGradient(double[])} completes it.
     *
     * <p>If an instance yields an infinite value, a warning is logged and the
     * method returns immediately with that (negated) infinite value, without
     * adding the prior.
     *
     * @return the cached or freshly computed objective value
     */
    @Override
    public double getValue() {
        if (!this.cachedValueStale) {
            return this.cachedValue;
        }
        ++this.numGetValueCalls;
        this.cachedValue = 0.0;
        this.cachedGradientStale = true;
        MatrixOps.setAll(this.cachedGradient, 0.0);

        double[] scores = new double[this.trainingList.getTargetAlphabet().size()];
        for (Instance instance : this.trainingList) {
            double instanceWeight = this.trainingList.getInstanceWeight(instance);
            Labeling labeling = instance.getLabeling();
            if (labeling == null) {
                continue;
            }
            this.theClassifier.getClassificationScores(instance, scores);
            FeatureVector fv = (FeatureVector) instance.getData();
            int li = labeling.getBestIndex();

            double value = -(instanceWeight * Math.log(scores[li]));
            if (Double.isNaN(value)) {
                logger.fine("MaxEntTrainer: Instance " + instance.getName() + "has NaN value. log(scores)= " + Math.log(scores[li]) + " scores = " + scores[li] + " has instance weight = " + instanceWeight);
            }
            if (Double.isInfinite(value)) {
                logger.warning("Instance " + instance.getSource() + " has infinite value; skipping value and gradient");
                // NOTE(review): the value is marked fresh here while cachedGradient holds
                // only the expectations accumulated so far; a following getValueGradient
                // call builds on that partial gradient. Behavior kept as-is.
                this.cachedValue -= value;
                this.cachedValueStale = false;
                return -value;
            }
            this.cachedValue += value;

            // Subtract the model's expected feature values for every label with non-zero score.
            for (int si = 0; si < scores.length; ++si) {
                if (scores[si] == 0.0) {
                    continue;
                }
                assert (!Double.isInfinite(scores[si]));
                double negWeightedScore = -instanceWeight * scores[si];
                MatrixOps.rowPlusEquals(this.cachedGradient, this.numFeatures, si, fv, negWeightedScore);
                this.cachedGradient[this.numFeatures * si + this.defaultFeatureIndex] += negWeightedScore;
            }
        }

        double prior = this.computePrior();
        double oValue = this.cachedValue;
        // cachedValue so far is the negated log-likelihood; add the penalty, then flip the sign.
        this.cachedValue += prior;
        this.cachedValue *= -1.0;
        this.cachedValueStale = false;
        progressLogger.info("Value (labelProb=" + oValue + " prior=" + prior + ") loglikelihood = " + this.cachedValue);
        return this.cachedValue;
    }

    /** Computes the prior penalty for the current parameters; 0.0 when no prior is enabled. */
    private double computePrior() {
        double prior = 0.0;
        if (this.usingHyperbolicPrior) {
            for (int li = 0; li < this.numLabels; ++li) {
                for (int fi = 0; fi < this.numFeatures; ++fi) {
                    prior += this.hyperbolicPriorSlope / this.hyperbolicPriorSharpness * Math.log(Maths.cosh(this.hyperbolicPriorSharpness * this.parameters[li * this.numFeatures + fi]));
                }
            }
        } else if (this.usingGaussianPrior) {
            for (int li = 0; li < this.numLabels; ++li) {
                for (int fi = 0; fi < this.numFeatures; ++fi) {
                    double param = this.parameters[li * this.numFeatures + fi];
                    prior += param * param / (2.0 * this.gaussianPriorVariance);
                }
            }
        }
        return prior;
    }

    /**
     * Copies the gradient of the objective into {@code buffer}, recomputing it
     * if stale: constraints minus expectations, minus the Gaussian prior term
     * when enabled, with entries outside the feature selection zeroed.
     *
     * @param buffer destination; expected to be non-null and as long as the
     *               parameter vector (checked by assertion)
     * @throws UnsupportedOperationException if the hyperbolic prior is enabled;
     *         thrown before the cached gradient is modified so that a later
     *         call does not add the constraints twice
     */
    @Override
    public void getValueGradient(double[] buffer) {
        if (this.cachedGradientStale) {
            ++this.numGetValueGradientCalls;
            if (this.cachedValueStale) {
                // getValue() fills cachedGradient with the (negated) expectations.
                this.getValue();
            }
            // Fail before touching cachedGradient: it stays stale on failure, and a retry
            // would otherwise add the constraints a second time.
            if (this.usingHyperbolicPrior) {
                throw new UnsupportedOperationException("Hyperbolic prior not yet implemented.");
            }
            MatrixOps.plusEquals(this.cachedGradient, this.constraints);
            if (this.usingGaussianPrior) {
                MatrixOps.plusEquals(this.cachedGradient, this.parameters, -1.0 / this.gaussianPriorVariance);
            }
            MatrixOps.substitute(this.cachedGradient, Double.NEGATIVE_INFINITY, 0.0);
            // Zero the gradient of every feature that is not selected.
            for (int labelIndex = 0; labelIndex < this.numLabels; ++labelIndex) {
                FeatureSelection selection = this.perLabelFeatureSelection == null
                        ? this.featureSelection
                        : this.perLabelFeatureSelection[labelIndex];
                MatrixOps.rowSetAll(this.cachedGradient, this.numFeatures, labelIndex, 0.0, selection, false);
            }
            this.cachedGradientStale = false;
        }
        assert (buffer != null && buffer.length == this.parameters.length);
        System.arraycopy(this.cachedGradient, 0, buffer, 0, this.cachedGradient.length);
    }

    /** Returns how many times the gradient has been recomputed. */
    public int getValueGradientCalls() {
        return this.numGetValueGradientCalls;
    }

    /** Returns how many times the value has been recomputed. */
    public int getValueCalls() {
        return this.numGetValueCalls;
    }

    /** Enables the Gaussian prior (and disables the hyperbolic one). */
    public MaxEntOptimizableByLabelLikelihood useGaussianPrior() {
        this.usingGaussianPrior = true;
        this.usingHyperbolicPrior = false;
        this.invalidateCaches();
        return this;
    }

    /** Enables the hyperbolic prior (and disables the Gaussian one). */
    public MaxEntOptimizableByLabelLikelihood useHyperbolicPrior() {
        this.usingGaussianPrior = false;
        this.usingHyperbolicPrior = true;
        this.invalidateCaches();
        return this;
    }

    /** Disables both priors. */
    public MaxEntOptimizableByLabelLikelihood useNoPrior() {
        this.usingGaussianPrior = false;
        this.usingHyperbolicPrior = false;
        this.invalidateCaches();
        return this;
    }

    /** Sets the Gaussian prior variance and switches to the Gaussian prior. */
    public MaxEntOptimizableByLabelLikelihood setGaussianPriorVariance(double gaussianPriorVariance) {
        this.usingGaussianPrior = true;
        this.usingHyperbolicPrior = false;
        this.gaussianPriorVariance = gaussianPriorVariance;
        this.invalidateCaches();
        return this;
    }

    /** Sets the hyperbolic prior slope and switches to the hyperbolic prior. */
    public MaxEntOptimizableByLabelLikelihood setHyperbolicPriorSlope(double hyperbolicPriorSlope) {
        this.usingGaussianPrior = false;
        this.usingHyperbolicPrior = true;
        this.hyperbolicPriorSlope = hyperbolicPriorSlope;
        this.invalidateCaches();
        return this;
    }

    /** Sets the hyperbolic prior sharpness and switches to the hyperbolic prior. */
    public MaxEntOptimizableByLabelLikelihood setHyperbolicPriorSharpness(double hyperbolicPriorSharpness) {
        this.usingGaussianPrior = false;
        this.usingHyperbolicPrior = true;
        this.hyperbolicPriorSharpness = hyperbolicPriorSharpness;
        this.invalidateCaches();
        return this;
    }

    /** Marks the cached value and gradient stale so the next query recomputes them. */
    private void invalidateCaches() {
        this.cachedValueStale = true;
        this.cachedGradientStale = true;
    }
}

