#!/usr/bin/perl

use Modern::Perl '2018';
use Fatal qw(open);

# All paths used by this script are relative, so refuse to run unless the
# current directory contains ./tools/genStan (presumably this script itself --
# TODO confirm).
die "Must be run from top level" if !-e './tools/genStan';

# Directory that receives every generated .stan file.
my $dest = 'src/stan_files';

# Stan `functions` block that includes cmp_prob2.stan. It is written right
# after the license include in every generated file except unidim_adapt.stan,
# which emits its own block.
# NOTE: the `#include` line is Stan source inside the string, not a Perl comment.
my $itemModel = 'functions {
#include /functions/cmp_prob2.stan
}
';

# Dimension declarations placed at the top of the `data` block of every model
# (interpolated by mkUnidim and mvtDataCommon). No trailing newline.
my $commonData =
'  // dimensions
  int<lower=1> NPA;             // number of players or objects or things
  int<lower=1> NCMP;            // number of unique comparisons
  int<lower=1> N;               // number of observations';

# Per-comparison response arrays, all of length NCMP, placed in the `data`
# block of every model (interpolated by mkUnidim and mvtDataCommon).
# No trailing newline.
my $responseData =
'  // response data
  int<lower=1, upper=NPA> pa1[NCMP];        // PA1 for observation N
  int<lower=1, upper=NPA> pa2[NCMP];        // PA2 for observation N
  int weight[NCMP];
  int pick[NCMP];
  int refresh[NCMP];';

# Pointwise log-likelihood code for the unidimensional model; mkUnidim appends
# it to `generated quantities` when its $ll argument is true. `prob` is only
# recomputed when refresh[cmp] is nonzero, and each comparison's value is
# written weight[cmp] times into log_lik.
# NOTE: the cmp_probs call here is hard-coded to the (scale, alpha, ...)
# argument list.
my $unidim_ll =
'  vector[NTHRESH*2 + 1] prob;
  vector[N] log_lik;
  int cur = 1;

  for (cmp in 1:NCMP) {
    real ll;
    if (refresh[cmp]) {
      prob = cmp_probs(scale, alpha, theta[pa1[cmp]], theta[pa2[cmp]], cumTh);
    }
    ll = categorical_lpmf(rcat[cmp] | prob);
    for (wx in 1:weight[cmp]) {
      log_lik[cur] = ll;
      cur = cur + 1;
    }
  }
';

# Prior statements interpolated into the `model` blocks.
# The multivariate threshold prior is currently a plain copy of the
# unidimensional one.
my $threshold_prior = 'threshold ~ normal(0, 2.0);';
my $multivariateThresholdPrior = "$threshold_prior";
my $alpha_prior = 'alpha ~ exponential(0.1);';

# Build the Stan source (from the `data` block through `generated quantities`)
# of a unidimensional model.
#
#   $adapt - false: `scale` is supplied as data, `alpha` is a sampled
#            parameter with the shared alpha prior, and cmp_probs is called
#            as cmp_probs(scale, alpha, ...).
#            true:  `varCorrection` is supplied as data, `sigma` is the
#            sampled parameter (lognormal(1, 1) prior) and the sd of theta,
#            `scale` is derived from sd(theta) in transformed parameters, and
#            cmp_probs is called without an alpha argument.
#   $ll    - true to append the pointwise log_lik code ($unidim_ll) to
#            `generated quantities`.
#
# Returns the model text as a single string; the caller is responsible for
# writing the license include and the `functions` block in front of it.
sub mkUnidim {
    my ($adapt, $ll) = @_;
    my $data      = $adapt ? 'varCorrection' : 'scale';
    my $par       = $adapt ? 'sigma' : 'alpha';
    my $scaleDef  = $adapt
        ? "\n  real scale = (sd(theta) ^ varCorrection)/1.749;"
        : '';
    my $theta_sd  = $adapt ? 'sigma' : '1.0';
    my $prior     = $adapt ? 'sigma ~ lognormal(1, 1);' : $alpha_prior;
    my $cmp_alpha = $adapt ? '' : ', alpha';
    my $llBody    = $ll ? $unidim_ll : '';

    # The shared log-likelihood fragment hard-codes the (scale, alpha, ...)
    # call. Rewrite it with the same argument list the model block uses so
    # the adaptive variant, which declares no `alpha`, stays valid Stan.
    # For the non-adaptive variant this substitution is a no-op.
    $llBody =~ s/cmp_probs\(scale, alpha,/cmp_probs(scale$cmp_alpha,/g;

    return qq[data {
$commonData
  int<lower=1> NTHRESH;         // number of thresholds
  real $data;
$responseData
}
transformed data {
  int rcat[NCMP];

  for (cmp in 1:NCMP) {
    rcat[cmp] = pick[cmp] + NTHRESH + 1;
  }
}
parameters {
  vector[NPA] theta;
  vector[NTHRESH] threshold;
  real<lower=0> $par;
}
transformed parameters {
  vector[NTHRESH] cumTh = cumulative_sum(threshold);$scaleDef
}
model {
  vector[NTHRESH*2 + 1] prob;
  $prior
  theta ~ normal(0, $theta_sd);
  $threshold_prior
  for (cmp in 1:NCMP) {
    if (refresh[cmp]) {
      prob = cmp_probs(scale$cmp_alpha, theta[pa1[cmp]], theta[pa2[cmp]], cumTh);
    }
    if (weight[cmp] == 1) {
      target += categorical_lpmf(rcat[cmp] | prob);
    } else {
      target += weight[cmp] * categorical_lpmf(rcat[cmp] | prob);
    }
  }
}
generated quantities {
  real thetaVar = variance(theta);
$llBody}
];
}

# Compose the `data` and `transformed data` blocks shared by the multivariate
# models.
#
#   $extra - optional Stan declaration. When defined it is emitted on its own
#            line, indented two spaces, directly after the `scale`
#            declaration; when undefined nothing is added.
#
# Returns the Stan text as one string that ends at the closing brace of
# `transformed data`, with no trailing newline.
sub mvtDataCommon {
    my ($extra) = @_;

    # Only a supplied declaration contributes a line of its own.
    my $extraDecl = defined $extra ? "\n  $extra" : '';

    my $dataBlock = qq[data {
$commonData
  int<lower=1> NITEMS;
  int<lower=1> NTHRESH[NITEMS];         // number of thresholds
  int<lower=1> TOFFSET[NITEMS];
  vector[NITEMS] scale;$extraDecl
$responseData
  int item[NCMP];
}
];

    # Nothing to interpolate here, so a non-interpolating quote is enough.
    my $transformedDataBlock = q[transformed data {
  int totalThresholds = sum(NTHRESH);
  int rcat[NCMP];
  for (cmp in 1:NCMP) {
    rcat[cmp] = pick[cmp] + NTHRESH[item[cmp]] + 1;
  }
}];

    return $dataBlock . $transformedDataBlock;
}

# Fragments shared by the multivariate models (mkCorModel and mkFactorModel).

# `parameters` declaration: one threshold vector covering all items, of
# length totalThresholds (declared in the text produced by mvtDataCommon).
my $multivariateThresholdParam =
  'vector[totalThresholds] threshold;';

# `transformed parameters` declaration of the cumulative thresholds.
my $multivariateThresholdTParamDecl =
  'vector[totalThresholds] cumTh;';

# `transformed parameters` statements: for each item, cumTh over the slice
# TOFFSET[ix] .. TOFFSET[ix] + NTHRESH[ix] - 1 is the cumulative sum of the
# same slice of `threshold`.
my $multivariateThresholdTParam =
'for (ix in 1:NITEMS) {
    int from = TOFFSET[ix];
    int to = TOFFSET[ix] + NTHRESH[ix] - 1;
    cumTh[from:to] = cumulative_sum(threshold[from:to]);
  }';

# Local declarations for the `model` block: `prob` is sized for the item with
# the most thresholds and `probSize` tracks how much of it is in use.
my $multivariateQuickLikelihoodDecl =
  'vector[max(NTHRESH)*2 + 1] prob;
  int probSize;';

# `model` block likelihood. `prob` and `probSize` are recomputed only when
# refresh[cmp] is nonzero and are otherwise carried over from the previous
# iteration; each comparison then adds weight[cmp] times its
# categorical_lpmf to target.
# NOTE(review): presumably refresh is nonzero for the first comparison, since
# prob is otherwise read before being set -- confirm against the caller.
my $multivariateQuickLikelihood =
'for (cmp in 1:NCMP) {
    if (refresh[cmp]) {
      int ix = item[cmp];
      int from = TOFFSET[ix];
      int to = TOFFSET[ix] + NTHRESH[ix] - 1;
      probSize = (2*NTHRESH[ix]+1);
      prob[:probSize] = cmp_probs(scale[ix], alpha[ix],
               theta[pa1[cmp], ix],
               theta[pa2[cmp], ix], cumTh[from:to]);
    }
    if (weight[cmp] == 1) {
      target += categorical_lpmf(rcat[cmp] | prob[:probSize]);
    } else {
      target += weight[cmp] * categorical_lpmf(rcat[cmp] | prob[:probSize]);
    }
  }';

# Local declarations for the pointwise log-likelihood code in
# `generated quantities`. The text starts with a newline because it is
# interpolated directly after the block's opening brace.
my $multivariateLikelihoodDecl =
  '
  vector[max(NTHRESH)*2 + 1] prob;
  int probSize;
  vector[N] log_lik;
  int cur = 1;
';

# Pointwise log-likelihood code for `generated quantities`: the same refresh
# logic as the model-block likelihood, but each comparison's
# categorical_lpmf is written weight[cmp] times into log_lik instead of being
# added to target.
my $multivariateLikelihood = '

  for (cmp in 1:NCMP) {
    real ll;
    if (refresh[cmp]) {
      int ix = item[cmp];
      int from = TOFFSET[ix];
      int to = TOFFSET[ix] + NTHRESH[ix] - 1;
      probSize = (2*NTHRESH[ix]+1);
      prob[:probSize] = cmp_probs(scale[ix], alpha[ix],
               theta[pa1[cmp], ix],
               theta[pa2[cmp], ix], cumTh[from:to]);
    }
    ll = categorical_lpmf(rcat[cmp] | prob[:probSize]);
    for (wx in 1:weight[cmp]) {
      log_lik[cur] = ll;
      cur = cur + 1;
    }
  }';

# Build the Stan source of the multivariate correlation model.
#
#   $ll - true to add the pointwise log_lik declarations and code to
#         `generated quantities`.
#
# Returns the model text as one string: the shared data blocks from
# mvtDataCommon() followed by the parameters, transformed parameters, model
# and generated quantities blocks.
sub mkCorModel {
    my ($ll) = @_;

    # Both log-likelihood fragments are either present or absent together.
    my ($logLikDecl, $logLikCode) = ('', '');
    if ($ll) {
        $logLikDecl = $multivariateLikelihoodDecl;
        $logLikCode = $multivariateLikelihood;
    }

    my $parametersBlock = qq[
parameters {
  $multivariateThresholdParam
  vector<lower=0>[NITEMS] alpha;
  matrix[NPA,NITEMS]      rawTheta;
  cholesky_factor_corr[NITEMS] rawThetaCorChol;
}
];

    my $transformedParametersBlock = qq[transformed parameters {
  $multivariateThresholdTParamDecl
  matrix[NPA,NITEMS]      theta;

  // non-centered parameterization due to thin data
  for (pa in 1:NPA) {
    theta[pa,] = (rawThetaCorChol * rawTheta[pa,]')';
  }
  $multivariateThresholdTParam
}
];

    my $modelBlock = qq[model {
  $multivariateQuickLikelihoodDecl

  rawThetaCorChol ~ lkj_corr_cholesky(2);
  for (pa in 1:NPA) {
    rawTheta[pa,] ~ normal(0,1);
  }
  $multivariateThresholdPrior
  $alpha_prior
  $multivariateQuickLikelihood
}
];

    my $generatedQuantitiesBlock = qq[generated quantities {$logLikDecl
  corr_matrix[NITEMS] thetaCor;
  thetaCor = multiply_lower_tri_self_transpose(rawThetaCorChol);$logLikCode
}
];

    return mvtDataCommon()
        . $parametersBlock
        . $transformedParametersBlock
        . $modelBlock
        . $generatedQuantitiesBlock;
}

# Build the Stan source of the multivariate single-factor model.
#
#   $ll - true to add the pointwise log_lik declarations and code to
#         `generated quantities`.
#
# Returns the model text as one string: the shared data blocks from
# mvtDataCommon() (with `alpha` declared as data) followed by the parameters,
# transformed parameters, model and generated quantities blocks.
sub mkFactorModel {
    my ($ll) = @_;

    # Both log-likelihood fragments are either present or absent together.
    my ($logLikDecl, $logLikCode) = ('', '');
    if ($ll) {
        $logLikDecl = $multivariateLikelihoodDecl;
        $logLikCode = $multivariateLikelihood;
    }

    my $parametersBlock = qq[
parameters {
  $multivariateThresholdParam
  matrix[NPA,NITEMS] rawUniqueTheta; // do not interpret, see uniqueTheta
  vector[NPA] rawFactor;      // do not interpret, see factor
  vector[NITEMS] rawLoadings; // do not interpret, see factorLoadings
  vector[NITEMS] rawUnique;      // do not interpret, see unique
}
];

    my $transformedParametersBlock = qq[transformed parameters {
  $multivariateThresholdTParamDecl
  matrix[NPA,NITEMS] theta;
  for (pa in 1:NPA) {
    // theta must be normal(0,1) distributed
    theta[pa,] = ((rawFactor[pa] * rawLoadings)' +
                  rawUniqueTheta[pa,] .* rawUnique');
  }
  $multivariateThresholdTParam
}
];

    my $modelBlock = qq[model {
  $multivariateQuickLikelihoodDecl

  $multivariateThresholdPrior
  target += NITEMS * normal_lpdf(rawFactor | 0, 1.0);
  for (ix in 1:NITEMS) {
    rawLoadings[ix] ~ normal(0, 1);
    rawUnique[ix] ~ normal(1, 1);
  }
  for (pa in 1:NPA) {
    rawUniqueTheta[pa,] ~ normal(0, 1.0);
  }
  $multivariateQuickLikelihood
}
];

    # The generated quantities re-export the raw parameters under their
    # interpretable names, flipping signs so `unique` is non-negative and the
    # first factor loading is non-negative.
    my $generatedQuantitiesBlock = qq[generated quantities {$logLikDecl
  vector[NITEMS] sigma;
  vector[NITEMS] factorLoadings = rawLoadings;
  vector[NITEMS] factorProp;
  vector[NPA] factor = rawFactor;
  row_vector[NITEMS] unique = rawUnique';
  matrix[NPA,NITEMS] uniqueTheta = rawUniqueTheta;

  for (fx in 1:NITEMS) {
    sigma[fx] = sd(theta[,fx]);
    if (unique[fx] < 0) {
      unique[fx] = -unique[fx];
      uniqueTheta[,fx] = -uniqueTheta[,fx];
    }
  }
  if (factorLoadings[1] < 0) {
    factorLoadings = -factorLoadings;
    factor = -factor;
  }
  for (fx in 1:NITEMS) {
    // https://www.tandfonline.com/doi/full/10.1080/00031305.2018.1549100
    real resid = variance(rawUniqueTheta[,fx] * rawUnique[fx]);
    real pred = variance(rawFactor * rawLoadings[fx]);
    factorProp[fx] = pred / (pred + resid);
    if (factorLoadings[fx] < 0) factorProp[fx] = -factorProp[fx];
  }$logLikCode
}
];

    return mvtDataCommon('vector[NITEMS] alpha;')
        . $parametersBlock
        . $transformedParametersBlock
        . $modelBlock
        . $generatedQuantitiesBlock;
}

# Write every generated Stan program into $dest.
#
# Each file consists of the license include, a `functions` block that pulls
# in the cmp_probs implementation, and the model text from the matching
# generator. Only the adaptive unidimensional model includes cmp_prob1.stan;
# every other model uses the shared $itemModel block.
{
    my $adaptItemModel = 'functions {
#include /functions/cmp_prob1.stan
}
';

    # [ output file name, functions block, model text ]
    my @programs = (
        [ 'unidim.stan',         $itemModel,      mkUnidim(0, 0)   ],
        [ 'unidim_adapt.stan',   $adaptItemModel, mkUnidim(1, 0)   ],
        [ 'unidim_ll.stan',      $itemModel,      mkUnidim(0, 1)   ],
        [ 'correlation.stan',    $itemModel,      mkCorModel(0)    ],
        [ 'correlation_ll.stan', $itemModel,      mkCorModel(1)    ],
        [ 'factor.stan',         $itemModel,      mkFactorModel(0) ],
        [ 'factor_ll.stan',      $itemModel,      mkFactorModel(1) ],
    );

    for my $program (@programs) {
        my ($name, $functions, $model) = @{$program};
        my $path = "$dest/$name";

        # Three-argument open keeps the mode separate from the path, so no
        # character in the path can be interpreted as part of the mode.
        open(my $fh, '>', $path) or die "Cannot open $path: $!";
        print {$fh} "#include /pre/license.stan\n", $functions, $model
            or die "Cannot write $path: $!";

        # Buffered write errors only surface when the handle is closed.
        close($fh) or die "Cannot close $path: $!";
    }
}
