/*
 * Decompiled with CFR 0.152.
 */
package opennlp.tools.ml.naivebayes;

import java.io.IOException;
import opennlp.tools.ml.AbstractEventTrainer;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.Context;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.EvalParameters;
import opennlp.tools.ml.model.MutableContext;
import opennlp.tools.ml.naivebayes.NaiveBayesModel;
import opennlp.tools.util.TrainingParameters;

/**
 * Trainer that estimates the parameters of a {@link NaiveBayesModel} from indexed
 * training events.
 *
 * <p>Training is a single counting pass over the events. For each occurrence of an event,
 * every predicate in the event's context receives an update for the event's outcome. The
 * update is {@code 1.0}, or the predicate's value when the indexer supplies values, and is
 * applied through {@link MutableContext#updateParameter}. Afterwards the training-set
 * accuracy of the resulting parameters is reported through {@code display}.
 *
 * <p>{@link #trainModel(DataIndexer)} stores the indexed data in instance fields. A single
 * instance should therefore not be used to train several models concurrently.
 */
public class NaiveBayesTrainer
extends AbstractEventTrainer {
    /**
     * Name of this trainer type. This is presumably the value selected via the training
     * parameters; confirm against callers.
     */
    public static final String NAIVE_BAYES_VALUE = "NAIVEBAYES";
    /** Number of distinct events, i.e. {@code contexts.length}. */
    private int numUniqueEvents;
    /** Total event count reported by the indexer; the denominator of the training accuracy. */
    private int numEvents;
    /** Number of predicates, i.e. {@code predLabels.length}. */
    private int numPreds;
    /** Number of outcomes, i.e. {@code outcomeLabels.length}. */
    private int numOutcomes;
    /** Predicate indices of each unique event's context. */
    private int[][] contexts;
    /** Per-predicate values of each unique event's context, or {@code null} if the indexer has none. */
    private float[][] values;
    /** Outcome index of each unique event. */
    private int[] outcomeList;
    /** How many times each unique event occurs. */
    private int[] numTimesEventsSeen;
    /** Outcome names, indexed by outcome index. */
    private String[] outcomeLabels;
    /** Predicate names, indexed by predicate index. */
    private String[] predLabels;

    /** Creates a trainer using the base class's default configuration. */
    public NaiveBayesTrainer() {
    }

    /**
     * Creates a trainer configured with the given training parameters.
     *
     * @param parameters training parameters forwarded to the base trainer
     */
    public NaiveBayesTrainer(TrainingParameters parameters) {
        super(parameters);
    }

    /**
     * Tells the base trainer whether events should be sorted and merged before indexing.
     *
     * @return always {@code false}
     */
    @Override
    public boolean isSortAndMerge() {
        return false;
    }

    /**
     * Trains a model from already indexed data by delegating to {@link #trainModel(DataIndexer)}.
     *
     * @param indexer source of the indexed training events
     * @return the trained {@link NaiveBayesModel}
     * @throws IOException declared by the overridden method; not thrown by this implementation
     */
    @Override
    public AbstractModel doTrain(DataIndexer indexer) throws IOException {
        return this.trainModel(indexer);
    }

    /**
     * Copies the indexed data into this trainer, computes the per-predicate/per-outcome
     * parameters, and wraps them in a {@link NaiveBayesModel}.
     *
     * @param di source of contexts, values, outcomes, labels and event counts
     * @return the trained {@link NaiveBayesModel}
     */
    public AbstractModel trainModel(DataIndexer di) {
        this.display("Incorporating indexed data for training...  \n");
        this.contexts = di.getContexts();
        this.values = di.getValues();
        this.numTimesEventsSeen = di.getNumTimesEventsSeen();
        this.numEvents = di.getNumEvents();
        this.numUniqueEvents = this.contexts.length;
        this.outcomeLabels = di.getOutcomeLabels();
        this.outcomeList = di.getOutcomeList();
        this.predLabels = di.getPredLabels();
        this.numPreds = this.predLabels.length;
        this.numOutcomes = this.outcomeLabels.length;
        this.display("done.\n");
        this.display("\tNumber of Event Tokens: " + this.numUniqueEvents + "\n");
        this.display("\t    Number of Outcomes: " + this.numOutcomes + "\n");
        this.display("\t  Number of Predicates: " + this.numPreds + "\n");
        this.display("Computing model parameters...\n");
        Context[] finalParameters = this.findParameters();
        this.display("...done.\n");
        return new NaiveBayesModel(finalParameters, this.predLabels, this.outcomeLabels);
    }

    /**
     * Builds one parameter context per predicate, covering every outcome, and feeds each
     * event occurrence into it. Then reports the training accuracy.
     *
     * @return one {@link MutableContext} per predicate, indexed by predicate index
     */
    private MutableContext[] findParameters() {
        // Every predicate carries a parameter slot for every outcome.
        int[] allOutcomesPattern = new int[this.numOutcomes];
        for (int oi = 0; oi < this.numOutcomes; ++oi) {
            allOutcomesPattern[oi] = oi;
        }
        // Declared with the concrete element type so it can be returned as MutableContext[]
        // (Java does not narrow Context[] implicitly) and updated without casts.
        MutableContext[] params = new MutableContext[this.numPreds];
        for (int pi = 0; pi < this.numPreds; ++pi) {
            params[pi] = new MutableContext(allOutcomesPattern, new double[this.numOutcomes]);
            for (int aoi = 0; aoi < this.numOutcomes; ++aoi) {
                params[pi].setParameter(aoi, 0.0);
            }
        }
        EvalParameters evalParams = new EvalParameters(params, this.numOutcomes);
        final double stepSize = 1.0;
        for (int ei = 0; ei < this.numUniqueEvents; ++ei) {
            int targetOutcome = this.outcomeList[ei];
            int[] eventContext = this.contexts[ei];
            // Hoisted: whether values exist does not change inside the loops below.
            float[] eventValues = this.values == null ? null : this.values[ei];
            // Each occurrence of an event contributes once.
            for (int ni = 0; ni < this.numTimesEventsSeen[ei]; ++ni) {
                for (int ci = 0; ci < eventContext.length; ++ci) {
                    double increment = eventValues == null
                            ? stepSize
                            : stepSize * (double) eventValues[ci];
                    params[eventContext[ci]].updateParameter(targetOutcome, increment);
                }
            }
        }
        this.trainingStats(evalParams);
        return params;
    }

    /**
     * Evaluates every training event occurrence against the given parameters and reports
     * the fraction whose highest-scoring outcome matches the event's outcome.
     *
     * @param evalParams parameters to evaluate
     * @return training accuracy in {@code [0, 1]}; {@code 0.0} when there are no events
     */
    private double trainingStats(EvalParameters evalParams) {
        int numCorrect = 0;
        for (int ei = 0; ei < this.numUniqueEvents; ++ei) {
            float[] eventValues = this.values == null ? null : this.values[ei];
            for (int ni = 0; ni < this.numTimesEventsSeen[ei]; ++ni) {
                double[] modelDistribution = new double[this.numOutcomes];
                NaiveBayesModel.eval(this.contexts[ei], eventValues, modelDistribution, evalParams, false);
                if (this.maxIndex(modelDistribution) == this.outcomeList[ei]) {
                    ++numCorrect;
                }
            }
        }
        // Avoid reporting NaN (0.0 / 0.0) when the indexer produced no events.
        double trainingAccuracy = this.numEvents > 0
                ? (double) numCorrect / (double) this.numEvents
                : 0.0;
        this.display("Stats: (" + numCorrect + "/" + this.numEvents + ") " + trainingAccuracy + "\n");
        return trainingAccuracy;
    }

    /**
     * Returns the index of the largest element; ties resolve to the lowest index.
     *
     * @param values scores to scan
     * @return index of the maximum, or {@code 0} for an empty array
     */
    private int maxIndex(double[] values) {
        int max = 0;
        for (int i = 1; i < values.length; ++i) {
            if (values[i] > values[max]) {
                max = i;
            }
        }
        return max;
    }
}

