#!/usr/bin/env python3

import argparse
import functools
import pathlib
import re
import statistics
import sys
import tempfile

import numpy
import pandas
import plotly.express
import tabulate

def parse_lnt(lines, aggregate=statistics.median):
    """
    Parse lines in LNT format and return a list of dictionaries of the form:

        [
            {
                'benchmark': <benchmark1>,
                <metric1>: float,
                <metric2>: float,
                ...
            },
            {
                'benchmark': <benchmark2>,
                <metric1>: float,
                <metric2>: float,
                ...
            },
            ...
        ]

    Each non-blank line must consist of exactly two whitespace-separated fields,
    `<benchmark>.<metric>` and `<value>`. The metric is whatever follows the last `.`
    of the first field, so benchmark names may themselves contain dots. Blank lines
    are skipped, and benchmarks are returned in the order they are first encountered.

    If a metric has multiple values associated to it, they are aggregated into a single
    value using the provided aggregation function.

    A ValueError is raised for lines that do not have exactly two fields, whose
    identifier contains no `.`, or whose value is not a valid float.
    """
    results = {}
    for line in lines:
        line = line.strip()
        if not line:
            continue

        # Splitting on any run of whitespace tolerates tabs and repeated spaces while
        # still rejecting lines that carry more (or fewer) than two fields.
        (identifier, value) = line.split()
        # The metric never contains a dot but the benchmark name might, so split the
        # metric off from the right.
        (benchmark, metric) = identifier.rsplit('.', 1)

        entry = results.setdefault(benchmark, {'benchmark': benchmark})
        entry.setdefault(metric, []).append(float(value))

    # Collapse the samples gathered for each metric into a single number. The
    # 'benchmark' key holds a string rather than a list and is left untouched.
    for entry in results.values():
        for (metric, samples) in entry.items():
            if isinstance(samples, list):
                entry[metric] = aggregate(samples)

    return list(results.values())

def plain_text_comparison(data, metric, baseline_name=None, candidate_name=None):
    """
    Create a tabulated comparison of the baseline and the candidate for the given metric.

    `data` must contain the columns `benchmark`, `<metric>_0` (baseline), `<metric>_1`
    (candidate), `difference` and `percent`. The returned string holds one row per
    benchmark followed by a `Geomean` row that compares the geometric mean of each series.

    Each geometric mean is computed over the benchmarks that have a value in that series.
    When it cannot be computed (the series has no values, or contains values for which
    the geometric mean is undefined), the corresponding cells are left blank instead of
    aborting the whole report.
    """
    data = data.replace(numpy.nan, None) # avoid NaNs in tabulate output
    baseline = f'{metric}_0'
    candidate = f'{metric}_1'
    headers = ['Benchmark', baseline_name, candidate_name, 'Difference', '% Difference']
    fmt = (None, '.2f', '.2f', '.2f', '.2%')

    # Work on an object-typed copy so that the summary row appended below never touches
    # the caller's frame, and so that blank (None) cells are not coerced back to NaN.
    table = data[['benchmark', baseline, candidate, 'difference', 'percent']].astype(object)

    def geomean(column):
        values = data[column].dropna()
        if values.empty:
            return None
        try:
            return statistics.geometric_mean(values)
        except statistics.StatisticsError:
            # The geometric mean is undefined for this series (e.g. negative values).
            return None

    # Compute the geomean and report on their difference
    geomean_0 = geomean(baseline)
    geomean_1 = geomean(candidate)
    difference = None
    percent = None
    if geomean_0 is not None and geomean_1 is not None:
        difference = geomean_1 - geomean_0
        if geomean_0 != 0:
            percent = difference / geomean_0

    summary = pandas.DataFrame([['Geomean', geomean_0, geomean_1, difference, percent]],
                               columns=table.columns, dtype=object)
    table = pandas.concat([table, summary], ignore_index=True)

    return tabulate.tabulate(table.set_index('benchmark'), headers=headers, floatfmt=fmt, numalign='right')

def create_chart(data, metric, subtitle=None, series_names=None):
    """
    Build a grouped bar chart of the given metric, with one group per benchmark and one
    bar per series.

    The `<metric>_<i>` column of `data` is displayed under the name `series_names[i]`.
    The chart title is the series names joined by ' vs ', and `subtitle` is passed
    through to plotly unchanged. Returns the plotly figure.
    """
    display_names = {}
    for (position, name) in enumerate(series_names):
        display_names[f'{metric}_{position}'] = name
    renamed = data.rename(columns=display_names)

    figure = plotly.express.bar(
        renamed,
        title=' vs '.join(series_names),
        subtitle=subtitle,
        x='benchmark',
        y=series_names,
        barmode='group',
    )
    # The legend and the tick labels already identify everything; drop the axis captions.
    figure.update_layout(xaxis_title='', yaxis_title='', legend_title='')
    return figure

def produce_kpis(data, noise, extrema, series, series_names, meta_candidate, title):
    """
    Print a table of KPIs comparing every candidate series against the baseline.

    `series` lists the column names of `data` to compare; `series[0]` is the baseline and
    every other entry is a candidate, displayed under the matching entry of `series_names`.
    For each candidate, the following fractions of the total number of benchmarks in
    `data` are reported:

        - benchmarks where the candidate is faster (smaller value) than the baseline by
          more than `noise`,
        - benchmarks within `noise` of the baseline,
        - benchmarks where the candidate is slower than the baseline by more than `noise`,
        - benchmarks where the candidate is slower by at least `extrema`,
        - benchmarks where the candidate is faster by at least `extrema`.

    `noise` and `extrema` are relative differences with respect to the baseline (0.05
    means 5%). Benchmarks missing a value in either series fall in none of the categories
    but still count toward the total. `meta_candidate` is the name used for the candidates
    in the row labels, and `title` (possibly None) is used as the header of the label
    column. When `data` contains no benchmark at all, every KPI is reported as 0.
    """
    total = len(data)
    addendum = f"{noise:.0%} noise threshold, based on {total} benchmarks"
    top_addendum = f"by >= {extrema:.0%}, {noise:.0%} noise threshold, based on {total} benchmarks"
    headers = [title if title else '']
    columns = [[
        f'Benchmarks where {meta_candidate} is faster than {series_names[0]} ({addendum})',
        f'Neutral benchmarks ({addendum})',
        f'Benchmarks where {meta_candidate} is slower than {series_names[0]} ({addendum})',
        f'Worst performers ({top_addendum})',
        f'Best performers ({top_addendum})',
    ]]
    fmt = [None]

    def compute_kpis(base, cand):
        magnitude = ((data[cand] - data[base]) / data[base]).abs()
        is_faster = data[base] > data[cand]
        is_slower = data[base] < data[cand]
        # Same order as the row labels above.
        masks = [
            is_faster & (magnitude > noise),
            magnitude <= noise,
            is_slower & (magnitude > noise),
            is_slower & (magnitude >= extrema),
            is_faster & (magnitude >= extrema),
        ]
        if total == 0:
            # Nothing to compare (e.g. the filter matched no benchmark): avoid dividing by zero.
            return [0.0] * len(masks)
        return [int(mask.sum()) / total for mask in masks]

    baseline = series[0]
    for (i, candidate) in enumerate(series[1:], start=1):
        headers.append(series_names[i])
        columns.append(compute_kpis(baseline, candidate))
        fmt.append('.2%')

    # `columns` holds one list per table column; transpose it into rows for tabulate.
    rows = list(zip(*columns))
    print(tabulate.tabulate(rows, headers=headers, floatfmt=fmt))

def main(argv):
    """
    Command-line entry point: compare several sets of benchmark results in LNT format.

    `argv` is the list of command-line arguments, without the program name. The input
    files are parsed, joined on the benchmark name, optionally filtered and sorted, and
    then rendered as a plain-text table (`--format=text`), a HTML chart (`--format=chart`)
    or a KPI summary (`--format=kpi`).

    Invalid combinations of arguments are reported through `parser.error`, which prints
    a usage message and exits the process.
    """
    parser = argparse.ArgumentParser(
        prog='compare-benchmarks',
        description='Compare the results of multiple sets of benchmarks in LNT format.',
        epilog='This script depends on the modules listed in `libcxx/utils/requirements.txt`.')
    parser.add_argument('files', type=argparse.FileType('r'), nargs='+',
        help='Path to LNT format files containing the benchmark results to compare. In the text format, '
             'exactly two files must be compared.')
    parser.add_argument('--output', '-o', type=pathlib.Path, required=False,
        help='Path of a file where to output the resulting comparison. If the output format is `text`, '
             'default to stdout. If the output format is `chart`, default to a temporary file which is '
             'opened automatically once generated, but not removed after creation.')
    parser.add_argument('--metric', type=str, default='execution_time',
        help='The metric to compare. LNT data may contain multiple metrics (e.g. code size, execution time, etc) -- '
             'this option allows selecting which metric is being analyzed. The default is `execution_time`.')
    parser.add_argument('--filter', type=str, required=False,
        help='An optional regular expression used to filter the benchmarks included in the comparison. '
             'Only benchmarks whose names match the regular expression will be included.')
    parser.add_argument('--ignore-under', type=float, required=False,
        help='Ignore benchmarks whose value (in absolute terms) is less than the provided float for all '
             'the data sets being compared. This allows ignoring benchmarks that are likely to contain '
             'a significant amount of noise.')
    parser.add_argument('--sort', type=str, required=False, default='benchmark',
                        choices=['benchmark', 'baseline', 'candidate', 'percent_diff'],
        help='Optional sorting criteria for displaying results. By default, results are displayed in '
             'alphabetical order of the benchmark. Supported sorting criteria are: '
             '`benchmark` (sort using the alphabetical name of the benchmark), '
             '`baseline` (sort using the absolute number of the baseline run), '
             '`candidate` (sort using the absolute number of the candidate run), '
             'and `percent_diff` (sort using the percent difference between the baseline and the candidate). '
             'Note that when more than two input files are compared, the only valid sorting order is `benchmark`.')
    parser.add_argument('--format', type=str, choices=['text', 'chart', 'kpi'], default='text',
        help='Select the output format. `text` generates a plain-text comparison in tabular form, `chart` '
             'generates a self-contained HTML graph that can be opened in a browser, and `kpi` generates a '
             'summary report based on a few KPIs. The default is `text`.')
    parser.add_argument('--open', action='store_true',
        help='Whether to automatically open the generated HTML file when finished. This option only makes sense '
             'when the output format is `chart`.')
    parser.add_argument('--series-names', type=str, required=False,
        help='Optional comma-delimited list of names to use for the various series. By default, we use '
             'Baseline and Candidate for two input files, and CandidateN for subsequent inputs.')
    parser.add_argument('--subtitle', type=str, required=False,
        help='Optional subtitle to use for the chart. This can be used to help identify the contents of the chart. '
             'This option cannot be used with the plain text output.')
    parser.add_argument('--noise-threshold', type=float, required=False,
        help='Noise threshold used by KPIs to determine noise. This is a floating point number between '
             '0 and 1 that represents the percentage of difference required between two results in order '
             'for them not to be considered "within the noise" of each other.')
    parser.add_argument('--top-performer-threshold', type=float, required=False, default=0.5,
        help='Threshold percent used by KPIs to determine top (and worst) performers. This is a floating '
             'point number between 0 and 1 that represents the percentage of difference required to consider '
             'a benchmark to be a top/worst performer. For example, if this number is 0.5, we consider top/worst '
             'performers in the data to be benchmarks that have at least 50%% of difference between the baseline '
             'and the candidate.')
    parser.add_argument('--meta-candidate', type=str, required=False,
        help='The name to use for the candidate when producing a KPI report. Required for --format=kpi.')
    parser.add_argument('--discard-benchmarks-introduced-after', type=str, required=False,
        help='Discard benchmarks introduced after the given candidate. This is useful to stabilize reports '
             'when new benchmarks are introduced as time goes on, which would change the total number of '
             'benchmarks and hence appear to retroactively change the report for previous candidates. '
             'If used, the name used here must correspond to the name of a series (as passed to or defaulted '
             'via `--series-names`).')
    args = parser.parse_args(argv)

    # Validate arguments (the values admissible for various arguments depend on other
    # arguments, the number of inputs, etc). Note that parser.error() never returns.
    if args.format == 'text':
        if len(args.files) != 2:
            parser.error('--format=text requires exactly two input files to compare')
        if args.subtitle is not None:
            parser.error('Passing --subtitle makes no sense with --format=text')
        if args.open:
            parser.error('Passing --open makes no sense with --format=text')

    if args.format == 'kpi':
        if args.open:
            parser.error('Passing --open makes no sense with --format=kpi')
        if args.noise_threshold is None:
            parser.error('--format=kpi requires passing a --noise-threshold')
        if args.meta_candidate is None:
            parser.error('--format=kpi requires passing a --meta-candidate')

    if len(args.files) != 2 and args.sort != 'benchmark':
        parser.error('Using any sort order other than `benchmark` requires exactly two input files.')

    if args.series_names is None:
        args.series_names = ['Baseline']
        if len(args.files) == 2:
            args.series_names += ['Candidate']
        elif len(args.files) > 2:
            args.series_names.extend(f'Candidate{n}' for n in range(1, len(args.files)))
    else:
        args.series_names = args.series_names.split(',')
        if len(args.series_names) != len(args.files):
            parser.error(f'Passed incorrect number of series names: got {len(args.series_names)} series names but {len(args.files)} inputs to compare')

    if args.format == 'kpi' and args.discard_benchmarks_introduced_after is not None:
        if args.discard_benchmarks_introduced_after not in args.series_names:
            known_series = ', '.join(args.series_names)
            parser.error(f'--discard-benchmarks-introduced-after must name one of the series being compared '
                         f'({known_series}), got `{args.discard_benchmarks_introduced_after}`')

    # Parse the raw LNT data and store each input in a dataframe
    lnt_inputs = []
    for file in args.files:
        try:
            lnt_inputs.append(parse_lnt(file.readlines()))
        finally:
            # argparse maps `-` to sys.stdin, which is not ours to close.
            if file is not sys.stdin:
                file.close()
    series = [f'{args.metric}_{i}' for (i, _) in enumerate(lnt_inputs)]
    inputs = [pandas.DataFrame(lnt).rename(columns={args.metric: s}) for (s, lnt) in zip(series, lnt_inputs)]

    # Everything below indexes the data by series, so fail early (and clearly) if one of
    # the inputs contains no data for the requested metric.
    for (file, s, frame) in zip(args.files, series, inputs):
        if s not in frame.columns:
            parser.error(f'Input `{file.name}` does not contain any data for the metric `{args.metric}`')

    # Join the inputs into a single dataframe
    data = functools.reduce(lambda a, b: a.merge(b, how='outer', on='benchmark'), inputs)

    # If we have exactly two data sets, compute additional info in new columns
    if len(lnt_inputs) == 2:
        data['difference'] = data[f'{args.metric}_1'] - data[f'{args.metric}_0']
        data['percent'] = data['difference'] / data[f'{args.metric}_0']

    if args.filter is not None:
        pattern = re.compile(args.filter)
        keeplist = [b for b in data['benchmark'] if pattern.search(b) is not None]
        data = data[data['benchmark'].isin(keeplist)]

    if args.ignore_under is not None:
        data = data[~(data[series] < args.ignore_under).all(axis=1)]

    # Sort the data by the appropriate criteria. The columns other than `benchmark` only
    # exist when comparing exactly two inputs, which was validated above.
    sort_columns = {
        'benchmark': 'benchmark',
        'baseline': f'{args.metric}_0',
        'candidate': f'{args.metric}_1',
        'percent_diff': 'percent',
    }
    data = data.sort_values(by=sort_columns[args.sort])

    if args.format == 'chart':
        figure = create_chart(data, args.metric, subtitle=args.subtitle, series_names=args.series_names)
        do_open = args.output is None or args.open
        if args.output is not None:
            output = args.output
        else:
            # Reserve a temporary file that outlives this process, as documented for
            # --output. With the default delete=True, the file would be removed as soon
            # as the (unreferenced) object is collected and only its name would be reused.
            with tempfile.NamedTemporaryFile(suffix='.html', delete=False) as temporary:
                output = temporary.name
        plotly.io.write_html(figure, file=output, auto_open=do_open)
    elif args.format == 'kpi':
        if args.discard_benchmarks_introduced_after is not None:
            index = args.series_names.index(args.discard_benchmarks_introduced_after)
            # A benchmark was introduced after the given series if it has no result in
            # that series but has one in any later series.
            reference = series[index]
            for candidate in series[index+1:]:
                data = data[~(data[reference].isna() & data[candidate].notna())]
        produce_kpis(data, noise=args.noise_threshold,
                           extrema=args.top_performer_threshold,
                           series=series,
                           series_names=args.series_names,
                           meta_candidate=args.meta_candidate,
                           title=args.subtitle)
    else:
        diff = plain_text_comparison(data, args.metric, baseline_name=args.series_names[0],
                                                        candidate_name=args.series_names[1])
        diff += '\n'
        if args.output is not None:
            with open(args.output, 'w') as out:
                out.write(diff)
        else:
            sys.stdout.write(diff)

# Run the comparison only when this file is executed as a script (not when it is
# imported), forwarding the command-line arguments without the program name.
if __name__ == '__main__':
    main(sys.argv[1:])
