/*
 * Decompiled with CFR 0.152.
 */
package eu.amidst.core.learning.parametric.bayesian;

import eu.amidst.core.datastream.DataInstance;
import eu.amidst.core.datastream.DataOnMemory;
import eu.amidst.core.datastream.DataStream;
import eu.amidst.core.distribution.UnivariateDistribution;
import eu.amidst.core.inference.messagepassing.VMP;
import eu.amidst.core.learning.parametric.bayesian.BayesianParameterLearningAlgorithm;
import eu.amidst.core.learning.parametric.bayesian.SVB;
import eu.amidst.core.learning.parametric.bayesian.utils.DataPosterior;
import eu.amidst.core.learning.parametric.bayesian.utils.PlateuStructure;
import eu.amidst.core.learning.parametric.bayesian.utils.TransitionMethod;
import eu.amidst.core.learning.parametric.bayesian.utils.VMPLocalUpdates;
import eu.amidst.core.models.BayesianNetwork;
import eu.amidst.core.models.DAG;
import eu.amidst.core.utils.CompoundVector;
import eu.amidst.core.utils.Serialization;
import eu.amidst.core.utils.Vector;
import eu.amidst.core.variables.Variable;
import java.io.Serializable;
import java.util.Iterator;
import java.util.List;

/**
 * Stochastic variational inference (SVI) for Bayesian parameter learning.
 *
 * <p>The per-batch (local) variational updates are delegated to an {@code SVB} engine configured as a
 * non-sequential model. After every batch the global natural parameters are updated as
 * {@code lambda <- (1 - rho) * lambda + rho * (prior + (N / |B|) * s_B)}, where {@code s_B} is the vector
 * returned by {@code SVB.updateModelOnBatchParallel}, {@code N} is the value given to
 * {@link #setDataSetSize(long)}, {@code |B|} is the number of instances in the batch, and the step size is
 * {@code rho = (1 + t)^(-learningFactor)} at iteration {@code t}.
 */
public class StochasticVI
implements BayesianParameterLearningAlgorithm,
Serializable {
    private static final long serialVersionUID = 4107783324901370839L;
    /** Identifier string {@code "SVB"}; not referenced inside this class. */
    public static String SVB = "SVB";
    /** Identifier string {@code "PRIOR"}; not referenced inside this class. */
    public static String PRIOR = "PRIOR";
    /** Stream consumed by {@link #runLearning()}. */
    protected DataStream<DataInstance> dataStream;
    /** Structure of the model whose parameters are learnt. */
    protected DAG dag;
    /** Engine performing the local variational updates on each batch. */
    protected SVB svb = new SVB();
    /** Nominal number of instances per batch; also used as the SVB window size. */
    protected int batchSize = 100;
    /** Maximum local VMP iterations per batch; {@link #runLearning()} also uses it as its iteration cap. */
    protected int maximumLocalIterations = 100;
    /** Convergence threshold handed to the local VMP. */
    protected double localThreshold = 0.1;
    /**
     * Total number of instances N used to rescale batch statistics. If left at 0 the batch statistics are
     * multiplied by 0 and only the prior contributes to each update.
     */
    protected long dataSetSize;
    /** Time budget in seconds for {@link #runLearning()} (misspelled name kept for serialization compatibility). */
    private long timiLimit;
    /** Exponent of the step-size schedule {@code (1 + t)^(-learningFactor)}. */
    private double learningFactor = 0.75;
    /** Prior natural parameters, as returned by {@code SVB.getNaturalParameterPrior()}. */
    private CompoundVector prior;
    /** Posterior used to (re)start SVB: plateau posterior plus prior. */
    private CompoundVector initialPosterior;
    /** Current global natural parameters; mutated in place by every stochastic step. */
    private CompoundVector currentParam;
    /** Number of stochastic steps taken since the last (re)initialisation. */
    private int iteration;
    /** Whether the next {@link #updateModel} call first runs a full VMP pass on its batch. */
    private boolean firstBatch = true;

    /**
     * @return the nominal number of instances per batch
     */
    public int getBatchSize() {
        return this.batchSize;
    }

    /**
     * @return the seed of the underlying SVB engine
     */
    public int getSeed() {
        return this.svb.getSeed();
    }

    /**
     * Sets the exponent of the step-size schedule {@code (1 + t)^(-learningFactor)}.
     *
     * @param learningFactor the decay exponent
     */
    public void setLearningFactor(double learningFactor) {
        this.learningFactor = learningFactor;
    }

    /**
     * Sets the time budget of {@link #runLearning()}.
     *
     * @param seconds maximum compute time, in seconds
     */
    public void setTimiLimit(long seconds) {
        this.timiLimit = seconds;
    }

    /**
     * Correctly spelled equivalent of {@link #setTimiLimit(long)}.
     *
     * @param seconds maximum compute time, in seconds
     */
    public void setTimeLimit(long seconds) {
        this.timiLimit = seconds;
    }

    /**
     * Sets the total number of instances N used to rescale the statistics of each batch.
     *
     * @param dataSetSize number of instances in the full data set
     */
    public void setDataSetSize(long dataSetSize) {
        this.dataSetSize = dataSetSize;
    }

    /**
     * Creates a learner whose SVB engine is configured as a non-sequential model.
     */
    public StochasticVI() {
        this.svb.setNonSequentialModel(true);
    }

    /**
     * Chooses whether the first batch passed to {@link #updateModel} is first processed with a full VMP pass.
     *
     * @param firstBatch {@code true} to run the full VMP pass on the first batch
     */
    public void setVMPOnFirstBatch(boolean firstBatch) {
        this.firstBatch = firstBatch;
    }

    @Override
    public void setPlateuStructure(PlateuStructure plateuStructure) {
        this.svb.setPlateuStructure(plateuStructure);
    }

    /**
     * @param transitionMethod transition method forwarded to the SVB engine
     */
    public void setTransitionMethod(TransitionMethod transitionMethod) {
        this.svb.setTransitionMethod(transitionMethod);
    }

    /**
     * @param localThreshold convergence threshold of the local VMP
     */
    public void setLocalThreshold(double localThreshold) {
        this.localThreshold = localThreshold;
    }

    /**
     * @param maximumLocalIterations maximum number of local VMP iterations per batch
     */
    public void setMaximumLocalIterations(int maximumLocalIterations) {
        this.maximumLocalIterations = maximumLocalIterations;
    }

    /**
     * Sets the nominal number of instances per batch.
     *
     * @param batchSize a positive batch size
     * @throws IllegalArgumentException if {@code batchSize <= 0}
     */
    public void setBatchSize(int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive, got " + batchSize);
        }
        this.batchSize = batchSize;
    }

    /**
     * @return the underlying SVB engine
     */
    public SVB getSVB() {
        return this.svb;
    }

    /**
     * Initialises SVB with local VMP updates and sets the posterior to plateau posterior plus prior. The
     * global parameters are set to the prior and the step counter is reset.
     */
    @Override
    public void initLearning() {
        this.svb.getPlateuStructure().setVmp(new VMPLocalUpdates(this.svb.getPlateuStructure()));
        this.configureAndInitSVB();
        this.prior = this.svb.getNaturalParameterPrior();
        this.initialPosterior = Serialization.deepCopy(this.svb.getPlateuStructure().getPlateauNaturalParameterPosterior());
        this.initialPosterior.sum(this.prior);
        this.svb.updateNaturalParameterPosteriors(this.initialPosterior);
        // currentParam is rescaled and summed in place by every stochastic step. Take a private copy so it
        // can never alias this.prior, in case getNaturalParameterPrior() returns a shared instance.
        this.currentParam = Serialization.deepCopy(this.svb.getNaturalParameterPrior());
        this.iteration = 0;
    }

    /**
     * Applies the local VMP settings to the plateau's current VMP and then (re)initialises the SVB engine.
     */
    private void configureAndInitSVB() {
        this.svb.getPlateuStructure().getVMP().setMaxIter(this.maximumLocalIterations);
        this.svb.getPlateuStructure().getVMP().setThreshold(this.localThreshold);
        this.svb.setDAG(this.dag);
        this.svb.setWindowsSize(this.batchSize);
        this.svb.initLearning();
    }

    /**
     * Warm-starts the global parameters with a full VMP pass on the first batch, then switches back to
     * local VMP updates starting from that result. Relies on {@link #prior} having been set by
     * {@link #initLearning()}.
     *
     * @param firstBatch the first batch of data
     */
    private void updateFirstBatch(DataOnMemory<DataInstance> firstBatch) {
        // Phase 1: plain VMP on the first batch, starting from plateau posterior + prior.
        this.svb.getPlateuStructure().setVmp(new VMP());
        this.configureAndInitSVB();
        this.initialPosterior = Serialization.deepCopy(this.svb.getPlateuStructure().getPlateauNaturalParameterPosterior());
        this.initialPosterior.sum(this.prior);
        this.svb.updateNaturalParameterPosteriors(this.initialPosterior);
        this.svb.updateModel(firstBatch);
        this.currentParam = Serialization.deepCopy(this.svb.getPlateuStructure().getPlateauNaturalParameterPosterior());
        // Phase 2: back to local updates, restarted from the posterior learnt in phase 1.
        this.svb.getPlateuStructure().setVmp(new VMPLocalUpdates(this.svb.getPlateuStructure()));
        this.configureAndInitSVB();
        this.prior = this.svb.getNaturalParameterPrior();
        this.svb.updateNaturalParameterPosteriors(this.currentParam);
        this.iteration = 0;
    }

    /**
     * Performs one stochastic natural-gradient step on {@code batch} and pushes the result to SVB. Does
     * not advance {@link #iteration}; callers do.
     *
     * @param batch the batch of data
     * @return the step size applied, or {@code 0.0} when the batch is empty (no update performed)
     */
    private double stochasticStep(DataOnMemory<DataInstance> batch) {
        int instancesInBatch = batch.getNumberOfDataInstances();
        if (instancesInBatch == 0) {
            return 0.0;
        }
        CompoundVector newParam = this.svb.updateModelOnBatchParallel(batch).getVector();
        // Rescale by N / |B| using the real batch size: the final batch of a stream may be shorter
        // than batchSize, and scaling it by N / batchSize would bias the estimate.
        newParam.multiplyBy((double) this.dataSetSize / (double) instancesInBatch);
        newParam.sum((Vector) this.prior);
        double stepSize = Math.pow(1 + this.iteration, -this.learningFactor);
        newParam.multiplyBy(stepSize);
        this.currentParam.multiplyBy(1.0 - stepSize);
        this.currentParam.sum((Vector) newParam);
        this.svb.updateNaturalParameterPosteriors(this.currentParam);
        return stepSize;
    }

    /**
     * Performs one SVI step on {@code batch}. If enabled, the first call runs a full VMP pass on the batch
     * first (see {@link #setVMPOnFirstBatch(boolean)}). Empty batches are ignored.
     *
     * @param batch the batch of data
     * @return always {@link Double#NaN}; no bound is computed
     */
    @Override
    public double updateModel(DataOnMemory<DataInstance> batch) {
        if (batch.getNumberOfDataInstances() == 0) {
            // Nothing to learn from; leave all state, including the step-size schedule, untouched.
            return Double.NaN;
        }
        if (this.firstBatch) {
            this.updateFirstBatch(batch);
            this.firstBatch = false;
        }
        this.stochasticStep(batch);
        ++this.iteration;
        return Double.NaN;
    }

    /**
     * Not supported; use {@link #getBatchSize()}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public int getWindowsSize() {
        throw new UnsupportedOperationException("Use method getBatchSize() instead.");
    }

    /**
     * Not supported; use {@link #setBatchSize(int)}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setWindowsSize(int windowsSize) {
        throw new UnsupportedOperationException("Use method setBatchSize() instead.");
    }

    @Override
    public void setDataStream(DataStream<DataInstance> data) {
        this.dataStream = data;
    }

    /**
     * @return always {@link Double#NaN}; not computed by this learner
     */
    @Override
    public double getLogMarginalProbability() {
        return Double.NaN;
    }

    /**
     * Runs SVI over the data stream, cycling through it as many times as needed. Stops once the compute
     * time exceeds the time limit (in seconds) or the iteration counter exceeds
     * {@code maximumLocalIterations}. Stops early if the stream yields no batches. Progress is printed to
     * standard output.
     *
     * @throws IllegalStateException if no data stream has been set
     */
    @Override
    public void runLearning() {
        if (this.dataStream == null) {
            throw new IllegalStateException("No data stream set: call setDataStream() before runLearning().");
        }
        this.initLearning();
        double totalTimeElbo = 0.0;
        double totalTime = 0.0;
        Iterator<DataOnMemory<DataInstance>> iterator = this.dataStream.iterableOverBatches(this.batchSize).iterator();
        while (true) {
            long startBatch = System.nanoTime();
            if (!iterator.hasNext()) {
                // End of the current pass: start another pass over the same stream.
                iterator = this.dataStream.iterableOverBatches(this.batchSize).iterator();
                if (!iterator.hasNext()) {
                    // Empty (or non-restartable) stream: no batches left to learn from.
                    return;
                }
            }
            DataOnMemory<DataInstance> batch = iterator.next();
            double stepSize = this.stochasticStep(batch);
            // No ELBO is evaluated at present, so this interval is (almost) empty. It is kept so the reported
            // compute time excludes any ELBO evaluation placed between these two timestamps.
            long startBatchELBO = System.nanoTime();
            long endBatch = System.nanoTime();
            totalTimeElbo += (double) (endBatch - startBatchELBO);
            totalTime += (double) (endBatch - startBatch);
            System.out.println("TIME ELBO:" + totalTimeElbo / 1.0E9);
            System.out.println("SVI ELBO: " + this.iteration + ", " + stepSize + ", " + totalTime / 1.0E9 + " seconds " + totalTimeElbo / 1.0E9 + " seconds " + (totalTime - totalTimeElbo) / 1.0E9 + " seconds");
            boolean outOfTime = (totalTime - totalTimeElbo) / 1.0E9 > (double) this.timiLimit;
            // NOTE(review): the global iteration cap reuses maximumLocalIterations (the per-batch VMP cap).
            boolean outOfIterations = this.iteration > this.maximumLocalIterations;
            ++this.iteration;
            if (outOfTime || outOfIterations) {
                return;
            }
        }
    }

    @Override
    public void setDAG(DAG dag_) {
        this.dag = dag_;
    }

    @Override
    public void setSeed(int seed) {
        this.svb.setSeed(seed);
    }

    @Override
    public BayesianNetwork getLearntBayesianNetwork() {
        return this.svb.getLearntBayesianNetwork();
    }

    /**
     * No-op: batch updates always go through {@code SVB.updateModelOnBatchParallel}.
     */
    @Override
    public void setParallelMode(boolean parallelMode) {
    }

    @Override
    public void setOutput(boolean activateOutput) {
        this.svb.setOutput(activateOutput);
    }

    /**
     * Not supported by this learner.
     *
     * @return always {@code null}
     */
    @Override
    public List<DataPosterior> computePosterior(DataOnMemory<DataInstance> batch) {
        return null;
    }

    /**
     * Not supported by this learner.
     *
     * @return always {@code null}
     */
    @Override
    public List<DataPosterior> computePosterior(DataOnMemory<DataInstance> batch, List<Variable> latentVariables) {
        return null;
    }

    @Override
    public <E extends UnivariateDistribution> E getParameterPosterior(Variable parameter) {
        return this.svb.getParameterPosterior(parameter);
    }

    @Override
    public double predictedLogLikelihood(DataOnMemory<DataInstance> batch) {
        return this.svb.predictedLogLikelihood(batch);
    }
}

