/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.functions;

import com.github.javacliparser.FloatOption;
import com.yahoo.labs.samoa.instances.Instance;
import moa.classifiers.functions.SGD;
import moa.core.DoubleVector;

/**
 * Linear learner built on {@link SGD} whose update step is scaled, separately for every weight
 * and for the bias, by the inverse square root of that parameter's accumulated squared gradients.
 *
 * <p>The constructor replaces the inherited regularisation and learning-rate options with copies
 * whose defaults are 0.0 and 0.01. The option values are copied into the working fields by
 * {@link #resetLearningImpl()}.
 */
public class AdaGrad
extends SGD {
    private static final long serialVersionUID = -3732968666673530291L;

    /** Constant added to the step-size denominators and to the regularisation decay denominator. */
    protected double m_epsilon = 1.0E-8;

    /** Option whose value is copied into {@link #m_epsilon} by {@link #resetLearningImpl()}. */
    public FloatOption epsilonOption = new FloatOption("epsilon", 'p', "epsilon parameter.", 1.0E-8);

    /** Per-weight running sum of squared gradients; created together with the weight vector. */
    protected DoubleVector m_velocity;

    /** Running sum of squared gradients of the bias term. */
    protected double m_biasVelocity;

    @Override
    public String getPurposeString() {
        return "An online optimiser for learning various linear models (binary class SVM, binary class logistic regression and linear regression).";
    }

    /**
     * Sets the epsilon constant used by subsequent training steps.
     *
     * @param eps the new epsilon value
     */
    public void setEpsilon(double eps) {
        this.m_epsilon = eps;
    }

    /**
     * @return the epsilon constant currently used by training steps
     */
    public double getEpsilon() {
        return this.m_epsilon;
    }

    /**
     * Creates the learner, rebuilding the inherited lambda and learning-rate options with the
     * same name, CLI character and purpose but with defaults of 0.0 and 0.01 respectively.
     */
    public AdaGrad() {
        this.lambdaRegularizationOption = new FloatOption(this.lambdaRegularizationOption.getName(), this.lambdaRegularizationOption.getCLIChar(), this.lambdaRegularizationOption.getPurpose(), 0.0);
        this.learningRateOption = new FloatOption(this.learningRateOption.getName(), this.learningRateOption.getCLIChar(), this.learningRateOption.getPurpose(), 0.01);
    }

    /**
     * Resets the inherited learner state and reloads lambda, learning rate, epsilon and the loss
     * function from their options.
     */
    @Override
    public void resetLearningImpl() {
        this.reset();
        this.setLambda(this.lambdaRegularizationOption.getValue());
        this.setLearningRate(this.learningRateOption.getValue());
        this.setEpsilon(this.epsilonOption.getValue());
        this.setLossFunction(this.lossFunctionOption.getChosenIndex());
    }

    /**
     * Performs one training step on {@code instance}.
     *
     * <p>The weight vector and both gradient accumulators are (re)created whenever no weight
     * vector exists. Instances with a missing class value are ignored after that initialisation.
     *
     * @param instance the training example
     */
    @Override
    public void trainOnInstanceImpl(Instance instance) {
        if (this.m_weights == null) {
            this.m_weights = new DoubleVector();
            this.m_velocity = new DoubleVector();
            this.m_bias = 0.0;
            // The bias accumulator must restart together with the per-weight accumulators;
            // otherwise a reset learner would keep taking the shrunken bias steps of its past.
            this.m_biasVelocity = 0.0;
            this.m_weights.setValue(instance.numAttributes(), 0.0);
            this.m_velocity.setValue(instance.numAttributes(), 0.0);
        }
        if (instance.classIsMissing()) {
            return;
        }

        int classIndex = instance.classIndex();
        double z = AdaGrad.dotProd(instance, this.m_weights, classIndex) + this.m_bias;
        double dldz = this.lossGradientWrtOutput(instance, z);

        // Bias step: its gradient is dldz itself.
        this.m_biasVelocity += dldz * dldz;
        this.m_bias -= this.m_learningRate / (Math.sqrt(this.m_biasVelocity) + this.m_epsilon) * dldz;

        // Only indices stored in the instance can have a non-zero gradient, so a single pass
        // over them replaces building a dense gradient vector and sweeping every weight.
        double decay = this.m_lambda / (this.m_t + this.m_epsilon);
        int stored = instance.numValues();
        for (int i = 0; i < stored; ++i) {
            int idx = instance.index(i);
            double value = instance.valueSparse(i);
            // The class attribute is the label, not a feature, and a NaN value would turn the
            // weight and its accumulator into NaN for good.
            if (idx == classIndex || Double.isNaN(value)) {
                continue;
            }
            double g = value * dldz + decay * this.m_weights.getValue(idx);
            this.m_velocity.addToValue(idx, g * g);
            this.m_weights.addToValue(idx, -(this.m_learningRate / (Math.sqrt(this.m_velocity.getValue(idx)) + this.m_epsilon)) * g);
        }
        this.m_t += 1.0;
    }

    /**
     * Computes the derivative of the loss with respect to the linear output {@code z}.
     *
     * <p>For a non-nominal class this is {@code z - classValue}. For a nominal class the label
     * is 0 when the class value is 0 and 1 otherwise; when {@code m_loss == 1} the result is
     * {@code sigmoid(z) - label}, and for every other {@code m_loss} value the label is mapped
     * to -1/+1 and the result is {@code -label} while {@code label * z < 1}, else 0.
     *
     * @param instance the training example supplying the class value
     * @param z the linear output (dot product plus bias) for {@code instance}
     * @return the derivative of the loss with respect to {@code z}
     */
    private double lossGradientWrtOutput(Instance instance, double z) {
        if (!instance.classAttribute().isNominal()) {
            return z - instance.classValue();
        }
        double label = instance.classValue() == 0.0 ? 0.0 : 1.0;
        if (this.m_loss == 1) {
            double yhat = 1.0 / (1.0 + Math.exp(-z));
            return yhat - label;
        }
        double signedLabel = label * 2.0 - 1.0;
        return signedLabel * z < 1.0 ? -signedLabel : 0.0;
    }

    @Override
    protected String getModelName() {
        return "AdaGrad";
    }
}

