/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.Boostable;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.DecisionTree;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;
import java.util.logging.Logger;

/**
 * Trains a {@link DecisionTree} by recursively splitting nodes.
 *
 * <p>Starting from a root built over the whole training list, each node is split on the feature
 * its {@link DecisionTree.Node} selects. Each split produces a "feature present" and a
 * "feature absent" child. Growth of a branch stops when the configured maximum depth is reached
 * or when the node's split information gain falls below {@link #minInfoGainSplit}.
 */
public class DecisionTreeTrainer
extends ClassifierTrainer<DecisionTree>
implements Boostable {
    private static final Logger logger = MalletLogger.getLogger(DecisionTreeTrainer.class.getName());

    /** Default maximum tree depth; used by the field initializer and by {@link Factory}. */
    public static final int DEFAULT_MAX_DEPTH = 5;
    /** Default minimum information gain required before a node is split. */
    public static final double DEFAULT_MIN_INFO_GAIN_SPLIT = 0.001;

    /** Number of split levels allowed below the root; values {@code <= 0} leave the root unsplit. */
    int maxDepth = DEFAULT_MAX_DEPTH;
    /** A node whose split information gain is below this value is not split. */
    double minInfoGainSplit = DEFAULT_MIN_INFO_GAIN_SPLIT;
    /** Becomes {@code true} once {@link #train} has successfully built a classifier. */
    boolean finished = false;
    /** The most recently trained tree, or {@code null} before the first successful {@link #train}. */
    DecisionTree classifier = null;

    /**
     * Creates a trainer that grows trees at most {@code maxDepth} split levels deep.
     *
     * @param maxDepth maximum depth; values {@code <= 0} leave the root unsplit
     */
    public DecisionTreeTrainer(int maxDepth) {
        this.maxDepth = maxDepth;
    }

    /**
     * Creates a trainer with a maximum depth of 4.
     *
     * <p>NOTE(review): 4 is one less than {@link #DEFAULT_MAX_DEPTH} (5), which the field
     * initializer and {@link Factory} use. It is kept at 4 so existing callers see unchanged
     * trees. Confirm whether the mismatch is intentional.
     */
    public DecisionTreeTrainer() {
        this(4);
    }

    /**
     * Sets the maximum tree depth.
     *
     * @param maxDepth maximum depth; values {@code <= 0} leave the root unsplit
     * @return this trainer, for call chaining
     */
    public DecisionTreeTrainer setMaxDepth(int maxDepth) {
        this.maxDepth = maxDepth;
        return this;
    }

    /**
     * Sets the minimum information gain a node's split must reach for the node to be split.
     *
     * @param m the new threshold
     * @return this trainer, for call chaining
     */
    public DecisionTreeTrainer setMinInfoGainSplit(double m) {
        this.minInfoGainSplit = m;
        return this;
    }

    /** @return {@code true} once {@link #train} has successfully produced a classifier */
    @Override
    public boolean isFinishedTraining() {
        return this.finished;
    }

    /** @return the most recently trained tree, or {@code null} if none has been trained yet */
    @Override
    public DecisionTree getClassifier() {
        return this.classifier;
    }

    /**
     * Grows a decision tree over {@code trainingList} and stores it as this trainer's classifier.
     *
     * <p>The list's {@link InstanceList#getFeatureSelection() feature selection} is passed to the
     * root and to every split. After growth, a header line is written to {@code System.out},
     * followed by {@code root.print()}.
     *
     * @param trainingList the training instances
     * @return the trained tree (also available via {@link #getClassifier()})
     */
    @Override
    public DecisionTree train(InstanceList trainingList) {
        FeatureSelection selectedFeatures = trainingList.getFeatureSelection();
        DecisionTree.Node root = new DecisionTree.Node(trainingList, null, selectedFeatures);
        this.splitTree(root, selectedFeatures, 0);
        root.stopGrowth();
        // NOTE(review): output goes to System.out rather than the logger; kept for
        // compatibility with callers that may rely on it.
        System.out.println("DecisionTree learned:");
        root.print();
        this.classifier = new DecisionTree(trainingList.getPipe(), root);
        // Only report completion once the classifier actually exists, so a failure while
        // building the DecisionTree does not leave the trainer claiming to be finished.
        this.finished = true;
        return this.classifier;
    }

    /**
     * Recursively splits {@code node} and then its two children.
     *
     * <p>A node is left unsplit when {@code depth} has reached {@link #maxDepth} or when
     * {@code node.getSplitInfoGain()} is below {@link #minInfoGainSplit}.
     *
     * @param node the node to consider for splitting
     * @param selectedFeatures feature selection forwarded to {@code node.split}
     * @param depth depth of {@code node}, with the root at 0
     */
    protected void splitTree(DecisionTree.Node node, FeatureSelection selectedFeatures, int depth) {
        // ">=" rather than "==": with "==", a negative maxDepth never matched and the recursion
        // was bounded only by the info-gain threshold.
        if (depth >= this.maxDepth || node.getSplitInfoGain() < this.minInfoGainSplit) {
            return;
        }
        logger.info("Splitting feature \"" + node.getSplitFeature() + "\" infogain=" + node.getSplitInfoGain());
        node.split(selectedFeatures);
        this.splitTree(node.getFeaturePresentChild(), selectedFeatures, depth + 1);
        this.splitTree(node.getFeatureAbsentChild(), selectedFeatures, depth + 1);
    }

    /**
     * Base factory that creates {@link DecisionTreeTrainer}s configured from the static
     * {@link #maxDepth} and {@link #minInfoGainSplit} settings.
     *
     * <p>The settings are static, so they are shared by every subclass of this factory.
     */
    public static abstract class Factory
    extends ClassifierTrainer.Factory<DecisionTreeTrainer> {
        /** Maximum depth applied to each trainer created by any factory subclass. */
        protected static int maxDepth = DEFAULT_MAX_DEPTH;
        /** Minimum split information gain applied to each trainer created by any factory subclass. */
        protected static double minInfoGainSplit = DEFAULT_MIN_INFO_GAIN_SPLIT;

        /**
         * Creates a trainer that uses this factory's static settings.
         *
         * @param initialClassifier not used by this implementation
         * @return a newly configured trainer
         */
        @Override
        public DecisionTreeTrainer newClassifierTrainer(Classifier initialClassifier) {
            DecisionTreeTrainer t = new DecisionTreeTrainer();
            t.maxDepth = maxDepth;
            t.minInfoGainSplit = minInfoGainSplit;
            return t;
        }
    }
}

