#include "scalablebayesm.h"
#include <RcppArmadillo.h>

using namespace Rcpp;
using namespace arma;

// Gibbs sampler for a hierarchical linear model whose unit-level coefficients
// follow a Dirichlet-process (DP) mixture of normals.
//
// Every one of the R iterations draws, in this order:
//   1. the DP state (indicators `indic`, unique components `thetaStar_vector`,
//      `lambda_struct`, `alpha`) via rDPGibbs1, given the current unit-level
//      coefficients (net of Z * Delta' when drawdelta is true);
//   2. Delta via drawDeltaDP (only when drawdelta is true);
//   3. each unit's (beta_i, tau_i) via runiregG1, using the prior implied by
//      the unit's current component (and Delta * z_i when drawdelta).
// Every keep-th iteration stores mgout_struct.thetaNp1_vector[0] (mu, rooti)
// in compdraw and, when drawdelta, vectorise(Delta) in a row of Deltadraw.
// NOTE(review): thetaNp1 is presumably a draw of a new unit's component from
// the DP predictive — confirm against rDPGibbs1.
//
// regdata: list of per-unit lists with elements "y", "X", "XpX", "Xpy".
// Returns a List with "compdraw", "probdraw" and, if drawdelta, "Deltadraw".
// The arguments mubar, Amu, nu and a are accepted but not read here; V is used
// only for its column count (nvar).
//[[Rcpp::export]]
List rhierLinearDPParallel_rcpp_loop(List const& regdata,
                                     arma::mat const& Z,
                                     arma::vec const& deltabar,
                                     arma::mat const& Ad,
                                     List const& Prioralphalist,
                                     List const& lambda_hyper,
                                     arma::mat const& mubar,
                                     arma::mat const& Amu,
                                     int const& nu,
                                     arma::mat const& V,
                                     int nu_e,
                                     int maxuniq,
                                     int gridsize,
                                     arma::vec const& ssq,
                                     int R,
                                     int keep,
                                     int nprint,
                                     arma::mat olddelta,
                                     arma::vec const& a,
                                     arma::vec tau,
                                     bool drawdelta,
                                     double BayesmConstantA,
                                     int BayesmConstantnuInc,
                                     double BayesmConstantDPalpha, bool verbose){

  // keep is used as a divisor below; reject it before any work is done.
  if (keep <= 0) stop("keep must be a positive integer");

  const int nz = Z.n_cols;
  const int nvar = V.n_cols;
  const int nreg = regdata.size();

  // Convert the R list of per-unit data into a std::vector of structs once,
  // so the sampling loop works purely on Armadillo objects.
  Rcpp::List regdatai;
  std::vector<moments> regdata_vector;
  regdata_vector.reserve(nreg);
  for (int reg = 0; reg < nreg; reg++){
    regdatai = regdata[reg];
    moments regdatai_struct;
    regdatai_struct.y = as<vec>(regdatai["y"]);
    regdatai_struct.X = as<mat>(regdatai["X"]);
    regdatai_struct.XpX = as<mat>(regdatai["XpX"]);
    regdatai_struct.Xpy = as<vec>(regdatai["Xpy"]);
    regdata_vector.push_back(regdatai_struct);
  }

  // Initial DP state: every unit assigned to one component with mu = 0 and
  // rooti = I (indic is 1-based).
  ivec indic = ones<ivec>(nreg);
  murooti thetaStar0_struct;
  thetaStar0_struct.mu = zeros<vec>(nvar);
  thetaStar0_struct.rooti = eye(nvar, nvar);
  std::vector<murooti> thetaStar_vector(1, thetaStar0_struct);

  double alpha = BayesmConstantDPalpha;

  // Passed through to rDPGibbs1 exactly as before; it is never filled here.
  // NOTE(review): presumably populated inside rDPGibbs1 — confirm.
  vec q0v;

  // Component probabilities are not tracked for the DP; a constant 1 is stored
  // for every kept draw.
  const double oldprob = 1.0;

  // Prior on the DP concentration parameter alpha.
  priorAlpha priorAlpha_struct;
  priorAlpha_struct.power = Prioralphalist["power"];
  priorAlpha_struct.alphamin = Prioralphalist["alphamin"];
  priorAlpha_struct.alphamax = Prioralphalist["alphamax"];
  priorAlpha_struct.n = Prioralphalist["n"];

  // Starting values for the base-distribution hyperparameters.
  lambda lambda_struct;
  lambda_struct.mubar = zeros<vec>(nvar);
  lambda_struct.Amu = BayesmConstantA;
  lambda_struct.nu = nvar + BayesmConstantnuInc;
  lambda_struct.V = lambda_struct.nu * eye(nvar, nvar);

  // Storage for the kept draws (R / keep of them; integer division).
  const int nkeep = R / keep;
  mat oldbetas = zeros<mat>(nreg, nvar);
  vec probdraw = zeros<vec>(nkeep);
  // Deltadraw is only enlarged when it will actually be filled and returned.
  arma::mat Deltadraw(1,1);
  if (drawdelta) Deltadraw.zeros(nkeep, nz*nvar);
  Rcpp::List compdraw(nkeep);

  DPOut mgout_struct;
  unireg runiregout_struct;

  if ((nprint > 0) && verbose) startMcmcTimer();

  for (int rep = 0; rep < R; rep++){
    // 1. DP state | {beta_i}, Delta. The DP models beta_i - Delta * z_i when
    //    drawdelta, otherwise beta_i directly.
    mat dpdata = oldbetas;
    if (drawdelta){
      olddelta.reshape(nvar, nz);
      dpdata -= Z * trans(olddelta);
    }
    mgout_struct = rDPGibbs1(dpdata,
                             lambda_struct,
                             thetaStar_vector,
                             maxuniq,
                             indic,
                             q0v,
                             alpha,
                             priorAlpha_struct,
                             gridsize,
                             lambda_hyper);
    indic = mgout_struct.indic;
    lambda_struct = mgout_struct.lambda_struct;
    alpha = mgout_struct.alpha;
    thetaStar_vector = mgout_struct.thetaStar_vector;

    // 2. Delta | {beta_i}, indic, unique components.
    if (drawdelta){
      olddelta = drawDeltaDP(Z, oldbetas, indic, thetaStar_vector, deltabar, Ad);
      // Delta is nvar x nz; reshape keeps column-major element order.
      olddelta.reshape(nvar, nz);
    }

    // 3. (beta_i, tau_i) | component of unit i, Delta, unit i's data.
    for (int reg = 0; reg < nreg; reg++){
      const murooti& comp = thetaStar_vector[indic[reg] - 1]; // indic is 1-based
      // Prior mean of beta_i: component mean, shifted by Delta * z_i.
      vec betabar = comp.mu;
      if (drawdelta) betabar += olddelta * trans(Z.row(reg));

      // Prior precision from the component's rooti.
      mat Abeta = trans(comp.rooti) * comp.rooti;
      mat Abetabar = Abeta * betabar;

      // One Gibbs step of the univariate regression model: returns beta and
      // sigmasq for this unit.
      runiregout_struct = runiregG1(regdata_vector[reg].y, regdata_vector[reg].X,
                                    regdata_vector[reg].XpX, regdata_vector[reg].Xpy,
                                    tau[reg], Abeta, Abetabar, nu_e, ssq[reg]);
      oldbetas.row(reg) = trans(runiregout_struct.beta);
      tau[reg] = runiregout_struct.sigmasq; // error variance of unit reg
    }

    // Progress report every nprint-th iteration.
    if ((nprint > 0) && verbose && ((rep + 1) % nprint == 0)) infoMcmcTimer(rep, R);

    // Store every keep-th draw.
    if ((rep + 1) % keep == 0){
      const int mkeep = (rep + 1) / keep;
      probdraw(mkeep - 1) = oldprob;
      if (drawdelta) Deltadraw.row(mkeep - 1) = trans(vectorise(olddelta));
      const murooti& thetaNp10_struct = mgout_struct.thetaNp1_vector[0];
      compdraw[mkeep - 1] = List::create(List::create(
          Named("mu") = NumericVector(thetaNp10_struct.mu.begin(), thetaNp10_struct.mu.end()),
          Named("rooti") = thetaNp10_struct.rooti));
    }
  }

  if ((nprint > 0) && verbose) endMcmcTimer();

  if (drawdelta){
    return(Rcpp::List::create(Rcpp::Named("compdraw") = compdraw,
                              Rcpp::Named("probdraw") = probdraw,
                              Rcpp::Named("Deltadraw") = Deltadraw));
  } else {
    return(Rcpp::List::create(Rcpp::Named("compdraw") = compdraw,
                              Rcpp::Named("probdraw") = probdraw));
  }
}

