/*
 * Decompiled with CFR 0.152.
 */
package moa.evaluation.preview;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import moa.evaluation.preview.LearningCurve;
import moa.evaluation.preview.Preview;
import moa.evaluation.preview.PreviewCollection;
import moa.evaluation.preview.PreviewCollectionLearningCurveWrapper;

/**
 * Aggregates a multi-run (multi-fold) collection of previews into per-parameter-value mean and
 * standard-deviation previews.
 *
 * <p>The input holds one sub-collection per fold; each fold sub-collection holds one preview per
 * varied parameter value. For every parameter value, the measurements of all complete folds
 * (folds that contain a preview for every parameter value) are averaged entry by entry, and the
 * sample standard deviation (denominator {@code n - 1}) is computed alongside. Column 0 of every
 * entry is not aggregated: it is copied from the first complete fold.
 */
public class MeanPreviewCollection {
    /**
     * Number of leading measurement-name columns of the multi-run collection that are dropped
     * before building the mean/std curves. NOTE(review): presumably the ordering/index columns
     * contributed by the two nested PreviewCollection levels — confirm against
     * {@code PreviewCollection.getMeasurementNames()}.
     */
    private static final int NUM_LEADING_ID_COLUMNS = 4;

    /** The multi-run previews this object was built from (one sub-collection per fold). */
    PreviewCollection<PreviewCollection<Preview>> origMultiRunPreviews;
    /** One mean preview per varied parameter value. */
    PreviewCollection<Preview> meanPreviews;
    /** One standard-deviation preview per varied parameter value. */
    PreviewCollection<Preview> stdPreviews;

    /**
     * Builds the mean and standard-deviation previews for every varied parameter value.
     *
     * @param multiRunPreviews previews of all folds; each fold holds one preview per parameter
     *     value
     */
    public MeanPreviewCollection(PreviewCollection<PreviewCollection<Preview>> multiRunPreviews) {
        this.origMultiRunPreviews = multiRunPreviews;
        // NOTE(review): the std collection reuses the "mean preview entry id" ordering name; kept
        // unchanged in case consumers rely on both collections sharing column names.
        this.meanPreviews = new PreviewCollection<Preview>("mean preview entry id", "parameter value id",
                multiRunPreviews.taskClass, multiRunPreviews.variedParamName, multiRunPreviews.variedParamValues);
        this.stdPreviews = new PreviewCollection<Preview>("mean preview entry id", "parameter value id",
                multiRunPreviews.taskClass, multiRunPreviews.variedParamName, multiRunPreviews.variedParamValues);
        int numFolds = multiRunPreviews.subPreviews.size();
        int numParamValues = multiRunPreviews.variedParamValues.length;
        if (numFolds == 0 || numParamValues == 0) {
            // Nothing to aggregate; leave both collections empty instead of dividing by zero.
            return;
        }
        // NOTE(review): assumes every fold contributes numParamValues previews of equal length.
        int numEntriesPerPreview = multiRunPreviews.numEntries() / numFolds / numParamValues;
        for (int paramValue = 0; paramValue < numParamValues; ++paramValue) {
            this.constructMeanStdPreviewsForParam(numEntriesPerPreview, numParamValues, paramValue);
        }
    }

    /** Returns the collection of mean previews, one per varied parameter value. */
    public PreviewCollection<Preview> getMeanPreviews() {
        return this.meanPreviews;
    }

    /** Returns the collection of standard-deviation previews, one per varied parameter value. */
    public PreviewCollection<Preview> getStdPreviews() {
        return this.stdPreviews;
    }

    /**
     * Computes the mean and std learning curves for one parameter value and stores them at index
     * {@code paramValue} of {@link #meanPreviews} and {@link #stdPreviews}.
     */
    private void constructMeanStdPreviewsForParam(int numEntriesPerPreview, int numParamValues, int paramValue) {
        List<Preview> foldParamPreviews = this.collectCompleteFoldPreviews(numParamValues, paramValue);
        List<double[]> meanParamMeasurements = this.calculateMeanMeasurements(foldParamPreviews, numEntriesPerPreview);
        List<double[]> stdParamMeasurements =
                this.calculateStdMeasurements(foldParamPreviews, numEntriesPerPreview, meanParamMeasurements);

        String[] allMeasurementNames = this.origMultiRunPreviews.getMeasurementNames();
        String[] meanMeasurementNames =
                Arrays.copyOfRange(allMeasurementNames, NUM_LEADING_ID_COLUMNS, allMeasurementNames.length);
        // Column 0 is the ordering measurement (passed to the LearningCurve constructor) and keeps
        // its name; every other std column gets a "[std] " prefix.
        String[] stdMeasurementNames = new String[meanMeasurementNames.length];
        stdMeasurementNames[0] = meanMeasurementNames[0];
        for (int m = 1; m < meanMeasurementNames.length; ++m) {
            stdMeasurementNames[m] = "[std] " + meanMeasurementNames[m];
        }

        LearningCurve meanLearningCurve = new LearningCurve(meanMeasurementNames[0]);
        meanLearningCurve.setData(Arrays.asList(meanMeasurementNames), meanParamMeasurements);
        LearningCurve stdLearningCurve = new LearningCurve(stdMeasurementNames[0]);
        stdLearningCurve.setData(Arrays.asList(stdMeasurementNames), stdParamMeasurements);

        this.meanPreviews.setPreview(paramValue,
                new PreviewCollectionLearningCurveWrapper(meanLearningCurve, this.origMultiRunPreviews.taskClass));
        this.stdPreviews.setPreview(paramValue,
                new PreviewCollectionLearningCurveWrapper(stdLearningCurve, this.origMultiRunPreviews.taskClass));
    }

    /**
     * Collects the preview for {@code paramValue} from every fold that holds exactly
     * {@code numParamValues} previews. Other folds are skipped (presumably still-running folds —
     * confirm).
     */
    private List<Preview> collectCompleteFoldPreviews(int numParamValues, int paramValue) {
        List<Preview> result = new ArrayList<Preview>();
        for (PreviewCollection<Preview> foldPreview : this.origMultiRunPreviews.subPreviews) {
            if (foldPreview.getPreviews().size() == numParamValues) {
                result.add((Preview) foldPreview.getPreviews().get(paramValue));
            }
        }
        return result;
    }

    /**
     * Averages the given previews entry by entry.
     *
     * @param foldPreviews previews of one parameter value, one per complete fold
     * @param numEntriesPerPreview number of leading entries of each preview to aggregate
     * @return one array per entry (column 0 copied from the first preview, other columns
     *     averaged); empty when {@code foldPreviews} is empty
     */
    private List<double[]> calculateMeanMeasurements(List<Preview> foldPreviews, int numEntriesPerPreview) {
        List<double[]> sums = new ArrayList<double[]>(numEntriesPerPreview);
        for (Preview foldPreview : foldPreviews) {
            this.addPreviewMeasurementsToSum(sums, foldPreview, numEntriesPerPreview);
        }
        int numCompleteFolds = foldPreviews.size();
        // Iterate over what was accumulated so that zero complete folds yields an empty result.
        List<double[]> means = new ArrayList<double[]>(sums.size());
        for (double[] sumEntry : sums) {
            double[] meanEntry = new double[sumEntry.length];
            meanEntry[0] = sumEntry[0];
            for (int m = 1; m < sumEntry.length; ++m) {
                meanEntry[m] = sumEntry[m] / numCompleteFolds;
            }
            means.add(meanEntry);
        }
        return means;
    }

    /**
     * Computes the sample standard deviation of the given previews entry by entry.
     *
     * @param foldPreviews previews of one parameter value, one per complete fold
     * @param numEntriesPerPreview number of leading entries of each preview to aggregate
     * @param meanMeasurements per-entry means of {@code foldPreviews}
     * @return one array per entry (column 0 copied from the first preview, other columns holding
     *     the std); empty when {@code foldPreviews} is empty
     */
    private List<double[]> calculateStdMeasurements(List<Preview> foldPreviews, int numEntriesPerPreview,
            List<double[]> meanMeasurements) {
        List<double[]> squaredDiffSums = new ArrayList<double[]>(numEntriesPerPreview);
        for (Preview foldPreview : foldPreviews) {
            this.addPreviewMeasurementSquaredDiffsToSum(meanMeasurements, squaredDiffSums, foldPreview,
                    numEntriesPerPreview);
        }
        int numCompleteFolds = foldPreviews.size();
        // Bessel-corrected (n - 1) denominator. With a single fold the squared-diff sum is zero,
        // so dividing by 1 yields a std of 0.
        double denominator = numCompleteFolds > 1 ? numCompleteFolds - 1 : 1;
        List<double[]> stds = new ArrayList<double[]>(squaredDiffSums.size());
        for (double[] sumEntry : squaredDiffSums) {
            double[] stdEntry = new double[sumEntry.length];
            stdEntry[0] = sumEntry[0];
            for (int m = 1; m < sumEntry.length; ++m) {
                stdEntry[m] = Math.sqrt(sumEntry[m] / denominator);
            }
            stds.add(stdEntry);
        }
        return stds;
    }

    /**
     * Adds the first {@code numEntriesPerPreview} entries of {@code preview} to the running sums.
     * An accumulator entry is created lazily; its column 0 is copied from the first preview that
     * creates it and is never summed.
     */
    private void addPreviewMeasurementsToSum(List<double[]> measurementsSum, Preview preview, int numEntriesPerPreview) {
        List<double[]> previewMeasurements = preview.getData();
        for (int entryIdx = 0; entryIdx < numEntriesPerPreview; ++entryIdx) {
            double[] previewEntry = previewMeasurements.get(entryIdx);
            double[] sumEntry;
            if (entryIdx < measurementsSum.size()) {
                sumEntry = measurementsSum.get(entryIdx);
            } else {
                sumEntry = new double[previewEntry.length];
                sumEntry[0] = previewEntry[0];
                measurementsSum.add(sumEntry);
            }
            for (int m = 1; m < sumEntry.length; ++m) {
                sumEntry[m] += previewEntry[m];
            }
        }
    }

    /**
     * Adds the squared differences between {@code preview}'s entries and the per-entry means to
     * the running sums. Column 0 is handled as in {@link #addPreviewMeasurementsToSum}.
     */
    private void addPreviewMeasurementSquaredDiffsToSum(List<double[]> meanMeasurements,
            List<double[]> measurementsSquaredDiffSum, Preview preview, int numEntriesPerPreview) {
        List<double[]> previewMeasurements = preview.getData();
        for (int entryIdx = 0; entryIdx < numEntriesPerPreview; ++entryIdx) {
            double[] meanEntry = meanMeasurements.get(entryIdx);
            double[] previewEntry = previewMeasurements.get(entryIdx);
            double[] squaredDiffSumEntry;
            if (entryIdx < measurementsSquaredDiffSum.size()) {
                squaredDiffSumEntry = measurementsSquaredDiffSum.get(entryIdx);
            } else {
                squaredDiffSumEntry = new double[previewEntry.length];
                squaredDiffSumEntry[0] = previewEntry[0];
                measurementsSquaredDiffSum.add(squaredDiffSumEntry);
            }
            for (int m = 1; m < squaredDiffSumEntry.length; ++m) {
                double diff = meanEntry[m] - previewEntry[m];
                squaredDiffSumEntry[m] += diff * diff;
            }
        }
    }
}

