## Virtual base class shared by every quadrature-grid cache.
## Declaring the interface here lets differently configured caches live in a
## single nimbleFunctionList and be called through the same methods:
##   cacheQuadGrid - store nodes, weights, mode index and build settings
##   nodes/weights - retrieve all nodes/weights, or the one at `idx`
##   modeIndex     - row index of the mode node
##   gridSize      - number of cached nodes
##   checkGrid     - whether the cache matches the requested settings
QUAD_CACHE_BASE <- nimbleFunctionVirtual(
    run = function() {},
    methods = list(
        cacheQuadGrid = function(levels = double(),
                                 nodes = double(2),
                                 wgts = double(1),
                                 modeI = integer(),
                                 prune = double(0, default = 0)) {
        },
        nodes = function(idx = integer(0, default = 0)) {
            returnType(double(2))
        },
        weights = function(idx = integer(0, default = 0)) {
            returnType(double(1))
        },
        modeIndex = function() {
            returnType(integer())
        },
        gridSize = function() {
            returnType(integer())
        },
        checkGrid = function(levels = double(0, default = -1),
                             d = integer(0, default = 1),
                             prune = double(0, default = 0)) {
            returnType(logical())
        }
    )
)

#' Log sum exponential.
#'
#' Compute the sum of two log values on the real scale and return on log scale.
#'
#' @name logSumExp
#'
#' @param log1 scalar of value 1 to exponentiate and sum.
#' @param log2 scalar of value 2 to exponentiate and sum with value 1.
#'
#' @details Adds two values from the log scale at the exponential scale, and then logs it. When values are really negative,
#' this function is numerically stable and reduces the chance of underflow. It is a two value version of a log-sum-exponential.
#'
#' @return \code{logSumExp} returns \code{log(exp(log1) + exp(log2))}
#'
#' @export
logSumExp <- nimbleFunction(run = function(log1 = double(), log2 = double()) {
    ## Factor out the larger argument so exp() only ever sees a non-positive
    ## difference; this avoids overflow/underflow of exp(log1) and exp(log2).
    ## NOTE: if both arguments are -Inf the difference is NaN, so the result is NaN.
    if (log1 > log2) {
        ans <- log(1 + exp(log2 - log1)) + log1
    } else {
        ans <- log(1 + exp(log1 - log2)) + log2
    }
    returnType(double())
    return(ans)
}, buildDerivs = list(run = list()))

#' Caching system for building multiple quadrature grids.
#'
#' Save the quadrature grid generated by a chosen rule and return it upon request.
#'
#' @details
#' This function is intended to be used within another NIMBLE function to conveniently cache multiple quadrature grids.
#' It cannot be compiled on its own; it must be included in a \code{nimbleFunctionList} whose virtual base class is
#' \code{QUAD_CACHE_BASE}.
#'
#' Methods:
#' \itemize{
#'   \item \code{cacheQuadGrid} stores the nodes (one row per node), the weights, the index of the mode node, and the
#'   \code{levels} and \code{prune} settings used to build the grid. The dimension is taken from the number of node columns.
#'   \item \code{checkGrid} returns \code{TRUE} only if a grid has been cached and it was built with the requested
#'   \code{levels}, \code{d}, and \code{prune}.
#'   \item \code{nodes} and \code{weights} return the node/weight at \code{idx} when \code{idx > 0}, the mode node/weight when
#'   \code{idx = -1} and a mode node exists, and otherwise all cached nodes/weights.
#'   \item \code{modeIndex} and \code{gridSize} return the cached mode index (-1 if there is none) and the number of nodes.
#' }
#'
#' @author Paul van Dam-Bates
#'
#' @export
quadGridCache <- nimbleFunction(
    contains = QUAD_CACHE_BASE,
    setup = function() {
        nodes_cached <- matrix(0, nrow = 1, ncol = 1)
        ## Length 2 presumably so NIMBLE types it as a vector rather than a scalar.
        weights_cached <- c(0, 0)
        ## Integer to match modeIndex()'s return type; -1 means "no mode node".
        modeIndex_cached <- -1L
        nGrid_cached <- 0L
        gridBuilt <- FALSE
        prune_ <- 0
        levels_ <- 0
        constructionRule_ <- "notset"
        d_ <- 1
    },
    run = function() {
    },
    methods = list(
        ## Store a freshly built grid together with the settings used to build it.
        cacheQuadGrid = function(levels = double(), nodes = double(2), wgts = double(1),
                                 modeI = integer(), prune = double(0, default = 0)) {
            nodes_cached <<- nodes
            weights_cached <<- wgts
            modeIndex_cached <<- modeI
            nGrid_cached <<- dim(nodes_cached)[1]
            gridBuilt <<- TRUE
            levels_ <<- levels
            prune_ <<- prune
            d_ <<- dim(nodes_cached)[2]
        },
        ## TRUE when a grid is cached and was built with these settings.
        ## The `prune` default matches the QUAD_CACHE_BASE declaration (0 = no pruning).
        checkGrid = function(levels = double(0, default = -1), d = integer(0, default = 1),
                             prune = double(0, default = 0)) {
            returnType(logical())
            if (!gridBuilt | (levels != levels_) | (prune_ != prune) | (d_ != d))
                return(FALSE)
            return(TRUE)
        },
        ## idx > 0: that node as a 1-row matrix; idx == -1: the mode node if one
        ## exists; anything else: all cached nodes.
        nodes = function(idx = integer(0, default = 0)) {
            returnType(double(2))
            if (idx > 0) {
                if (idx > nGrid_cached) stop("Trying to access more quadrature points than available.")
                return(matrix(nodes_cached[idx, ], nrow = 1))
            }
            if (idx == -1 & modeIndex_cached > 0)
                return(matrix(nodes_cached[modeIndex_cached, ], nrow = 1))
            return(nodes_cached)
        },
        ## Same indexing convention as nodes(), returning weights.
        weights = function(idx = integer(0, default = 0)) {
            returnType(double(1))
            if (idx > 0) {
                if (idx > nGrid_cached) stop("Trying to access more quadrature points than available.")
                return(numeric(value = weights_cached[idx], length = 1))
            }
            if (idx == -1 & modeIndex_cached > 0)
                return(numeric(value = weights_cached[modeIndex_cached], length = 1))
            return(weights_cached)
        },
        modeIndex = function() {
            returnType(integer())
            return(modeIndex_cached)
        },
        gridSize = function() {
            returnType(integer())
            return(nGrid_cached)
        }
    )
)

#' Configure Quadrature Grids
#'
#' Takes requested quadrature rules, builds the associated quadrature grids, and caches them. Updates 
#' all the features of the grids on call and returns nodes and weights.
#' 
#' @name configureQuadGrid
#'
#' @param d Number of dimensions.
#' @param levels Number of quadrature points to generate in each dimension (or the level of accuracy of a sparse grid).
#' @param quadRule Default quadRule to be used. Options are "AGHQ" or "CCD". May also be a user supplied quadrature rule as a nimbleFunction.
#' @param control list to control how the quadrature rules are built. See details for options. 
#'
#' @details
#'  Different options for building quadrature rules can be specified by `control` list.  These include
#' 
#' \itemize{
#'    \item \code{quadRules} which includes all the different rules that are being requested. Built-in options are "AGHQ", "CCD", and "AGHQSPARSE".
#'      The rule also determines how a grid is combined into multiple dimensions: "AGHQ" uses a product construction, which creates a grid of repeated
#'      nodes in each dimension, while "AGHQSPARSE" applies the standard sparse (Smolyak) construction. "CCD" is built directly in \code{d} dimensions.
#'    \item \code{CCD_f0} multiplier for the CCD grid for how much past the radius sqrt(d) to extend the nodes. Unless an advanced user, keep at default of 1.1.
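#'    \item \code{prune} the proportion of quadrature points to remove (between 0 and 1, default 0). Each weight is adjusted by the standard multivariate normal density at its node, and the points with the smallest adjusted weights are removed. Pruning is not applied to CCD grids.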
#'    \item \code{userConstruction} choose method to construct multivariate grid. If "MULTI" then user provided a multivariate construction in the provided function. If "PRODUCT"
#'    then a product rule construction is used to generate quadrature points in each dimension. If "SPARSE", then a Smolyak rule is applied.
#' }
#'
#' Quadrature grids are generally based on adaptive Gauss-Hermite (GH) quadrature which is expanded via a product or sparse rule into multiple dimensions. Sparse grids are
#' built following the Smolyak rule (Heiss and Winschel, 2008) and demonstrated in the package \pkg{mvQuad} (Weiser, 2023). Pruning is also implemented as described in 
#' \enc{Jäckel}{Jaeckel} (2005), where weights are adjusted by the value of a standard multivariate normal at that node, and nodes are removed until some threshold is met.
#' 
#' The available methods that can be called by this function once it is set up are:
#'
#' \itemize{
#'   \item \code{buildGrid} method for creating the quadrature grid. Inputs are \code{method}, one of the configured rules ("AGHQ", "CCD", "AGHQSPARSE", "USER"), to choose the active quadrature rule;
#'   \code{nQuad}, the number of quadrature points per dimension (or sparse-grid level); and \code{prune}, the proportion of points to remove.
#'   Default behaviour for all inputs is to use the values that were last requested.
#'
#'   \item \code{setDim} Allows the user to change the dimension of the quadrature grids which will reset all grids in use.
#'
#'   \item \code{nodes}, \code{weights}, \code{gridSize}, and \code{modeIndex} give access to the user for the details of the quadrature grid in use. This is either based
#'   on the last call to \code{buildGrid}, or by choosing a different grid with \code{setRule}. \code{nodes} and \code{weights} return all nodes and weights if no values
#'   are passed, or if an index is passed, the node and weight associated with that index. Passing -1 indicates that the mode should be returned, which in this case is all zeros; if the grid has no node at the mode, all nodes and weights are returned.
#' }
#'
#'
#' @author Paul van Dam-Bates
#'
#' @references
#' Heiss, F. and Winschel V. (2008). Likelihood approximation by numerical integration on sparse grids. Journal of Econometrics 144 (1), 62–80.
#'
#' Weiser, C. (2023). _mvQuad: Methods for Multivariate Quadrature._. (R package version 1.0-8), <https://CRAN.R-project.org/package=mvQuad>.
#'
#' \enc{Jäckel}{Jaeckel}, P. (2005). A note on multivariate gauss-hermite quadrature. London: ABN-Amro. Re.
#'
#' @examples
#'
#' library(mvQuad)
#' RmvQuad <- function(levels, d) {
#'    out <- mvQuad::createNIGrid(dim=d, type = "GHe", level=levels, ndConstruction = "sparse")
#'    cbind(out$weights, out$nodes)
#' }
#' nimMVQuad <- nimbleRcall(function(levels = integer(), d = integer()){},
#'                          Rfun = "RmvQuad", returnType = double(2))
#' myQuadRule <- nimbleFunction(
#'      contains = QUAD_RULE_BASE,
#'      name = "quadRule_USER",
#'      setup = function() {},
#'      run = function() {},
#'      methods = list(
#'          buildGrid = function(levels = integer(0, default = 0), d = integer(0, default = 1)) {
#'              output <- nimMVQuad(levels, d)
#'              returnType(double(2))
#'              return(output)
#'          }
#'      )
#'  )
#'
#' quadGrid_user <- configureQuadGrid(d=2, levels=3, quadRule = myQuadRule,
#'                    control = list(quadRules = c("AGHQ", "CCD", "AGHQSPARSE"),
#'                              userConstruction = "MULTI"))
#'
#' @export
configureQuadGrid <- nimbleFunction(
    name = "quadGridClass",
    setup = function(d = 1, levels = 3, quadRule = "AGHQ", control = list()) {
        ## Built-in quadrature rules; a user rule must be supplied as a nimbleFunction via `quadRule`.
        possibleRules <- c("AGHQ", "CCD", "AGHQSPARSE")
        
        quadRules <- extractControlElement(control, "quadRules", NULL)
        userType <- extractControlElement(control, "userConstruction", "MULTI")
        if(!userType %in% c("MULTI", "SPARSE", "PRODUCT"))
          stop("Error:  The types of user grid contstructions are either MULTI (default) or constructed based on univariate rules by SPARSE or PRODUCT combinations.")
          
        ccd_f0 <- extractControlElement(control, "CCD_f0", 1.1)

        prune_ <- extractControlElement(control, "prune", 0)
        if (prune_ > 1 | prune_ < 0)
            stop("Can only prune a proportion of quadrature points.")
        
        ## Resolve the default rule. A user-supplied rule is validated and labelled "USER".
        if(is.function(quadRule)){
            defaultRule <- "USER"
            if(!is.nfGenerator(quadRule))
                stop("User-provided quadrature rule must be a nimbleFunction")
            if(!identical(QUAD_RULE_BASE,environment(quadRule)$contains))
                stop("User-provided quadrature rule must set `contains = QUAD_RULE_BASE`")
            if(!"buildGrid" %in% names(environment(quadRule)$methods) ||
               !all(c('levels','d') %in% names(formals(environment(quadRule)$methods$buildGrid))))
               stop("User-provided quadrature rule must provide `buildGrid` method with arguments `levels` and `d`")
        } else{
          defaultRule <- quadRule
        }
        if (!all(quadRules %in% possibleRules))
            stop("Error:  Only AGHQ, CCD, and AGHQSPARSE rules are currently implemented. User rules must be supplied as a function to `quadRule`.")
            
        ## The default rule is always built; prepend it if it was not requested explicitly.
        if (!any(defaultRule == quadRules))
            quadRules <- c(defaultRule, quadRules)

        ## Tolerance used when comparing adjusted log weights during pruning.
        numError <- 1e-16

        quadGridCache_nfl <- nimbleFunctionList(QUAD_CACHE_BASE)
        quadRule_nfl <- nimbleFunctionList(QUAD_RULE_BASE)

        ## Index of each rule within the function lists (-1 when not configured).
        I_AGHQ <- I_CCD <- I_USER <- I_AGHQSPARSE <- as.numeric(-1)
        I_RULE <- which(quadRules == defaultRule)[1]

        nRules <- length(quadRules)
        for (i in 1:nRules) {
            inum <- as.numeric(i)
            if (quadRules[i] == "AGHQ") {
                I_AGHQ <- inum
                quadRule_nfl[[i]] <- quadRule_GH(type="GHe")
            } else if (quadRules[i] == "CCD") {
                I_CCD <- inum
                quadRule_nfl[[i]] <- quadRule_CCD(f0 = ccd_f0)
            } else if (quadRules[i] == "AGHQSPARSE") {
                I_AGHQSPARSE <- inum
                quadRule_nfl[[i]] <- quadRule_GH(type="GHe")
            } else if (quadRules[i] == "USER") {
                I_USER <- inum
                quadRule_nfl[[i]] <- quadRule()
            } else {
              stop("An unrecognized quadRule was detected.")
            }
            quadGridCache_nfl[[i]] <- quadGridCache()            
        }
        gridBuilt <- rep(FALSE, nRules)
        nGrid <- numeric(nRules)
        modeI <- -1L + integer(nRules)
        ## Pad length-1 setup vectors, presumably so NIMBLE types them as vectors.
        ## -1L keeps modeI integer (c(integer, -1) would silently make it double).
        if(nRules == 1){
          gridBuilt <- c(gridBuilt, FALSE)
          nGrid <- c(nGrid, 0)
          quadRules <- c(quadRules, "NULL")
          modeI <- c(modeI, -1L)
        }
    },
    run = function() {
    },
    methods = list(
        ## Make the requested grid current, rebuilding it only when the cached
        ## grid does not match. "NULL" / -1 arguments keep the previous values.
        buildGrid = function(method = character(0, default = "NULL"),
                             nQuad = integer(0, default = -1), 
                             prune = double(0, default = -1)) {
            if ( method != "NULL" ) setRule(method)
            if ( nQuad != -1 ) levels <<- nQuad
            if ( prune != -1 ) {
              if ( prune > 1 | prune < 0 ) stop("Can only prune a proportion of quadrature points.")
              prune_ <<- prune
            }
            if ( !quadGridCache_nfl[[I_RULE]]$checkGrid(levels = levels, d = d, prune = prune_) ) {
              newGrid()
            }
            modeI[I_RULE] <<- quadGridCache_nfl[[I_RULE]]$modeIndex()
            nGrid[I_RULE] <<- quadGridCache_nfl[[I_RULE]]$gridSize()
            ## The cached grid is current whether it was rebuilt or reused, so the
            ## accessors below need not call buildGrid() again.
            gridBuilt[I_RULE] <<- TRUE
        },
        ## Build the active rule's grid (columns: weight, then d node coordinates),
        ## locate its mode, optionally prune it, and cache it.
        newGrid = function(){
          ## No sparse or product rule for d=1. User can provide a multivariate grid similar to CCD as well.
          if ( I_RULE == I_CCD | (I_RULE == I_USER & userType == "MULTI") | d == 1){
            nodes_wgts <- quadRule_nfl[[I_RULE]]$buildGrid(levels = levels, d = d)
          } else {
            if ( I_RULE == I_AGHQSPARSE | (I_RULE == I_USER & userType == "SPARSE") ){
              nodes_wgts <- sparse_construction()
            } else {
              levels_vec <- numeric(value = levels, length = d) ## In theory we can choose a different number of nodes per dimension.
              nodes_wgts <- product_construction(levels_vec = levels_vec)
            }
          }
          if(dim(nodes_wgts)[2] != d+1)
            stop("Quadrature Grid is not the correct number of dimensions.")

          findModeIndex(nodes_wgts)
          
          if(prune_ > 0)
            nodes_wgts <- pruneGrid(nodes_wgts)
          
          nGrid[I_RULE] <<- dim(nodes_wgts)[1]
          quadGridCache_nfl[[I_RULE]]$cacheQuadGrid(levels = levels, nodes = matrix(nodes_wgts[,2:(d+1)], ncol = d, nrow = nGrid[I_RULE]),
                                                      wgts = nodes_wgts[,1], modeI = modeI[I_RULE], prune = prune_)
            
          gridBuilt[I_RULE] <<- TRUE
        },
        ## Record the row of the node at the origin (the mode), or -1 if none.
        findModeIndex = function(nodes_wgts = double(2)){
          modei <- -1L
          if(I_RULE == I_CCD) {
            ## The CCD grid is taken to have its mode in the first row.
            modei <- 1
          } else {
            nQ <- dim(nodes_wgts)[1]
            if(nQ %% 2 == 0 & ( I_RULE == I_AGHQ )){
              ## A product GH grid with an even number of points has no node at zero.
              modei <- -1L
            } else {
              modei <- ceiling(nQ/2)
              if (sum(abs(nodes_wgts[modei, 2:(d + 1)])) > 1e-15) {
                ## The centre row is not at zero: search every row, and report
                ## no mode (-1) rather than a non-mode row if none is found.
                modei <- -1L
                for (i in 1:nQ) {
                    if (sum(abs(nodes_wgts[i, 2:(d + 1)])) < 1e-15) modei <- i
                }
              }
            }
          }
          modeI[I_RULE] <<- as.integer(modei)
        },
        ## Prune the grid (Jaeckel 2005): drop the ceiling(nQ * prune_) nodes with the
        ## smallest weight * standard-normal density, keeping the mode index in sync.
        pruneGrid = function( nodes_wgts = double(2) ) {
            if ( prune_ > 1 | prune_ < 0 ) stop("Can only prune a proportion of quadrature points.")
            nQ <- dim(nodes_wgts)[1]
            ## CCD grids are never pruned, and prune_ == 0 means nothing to remove.
            if (I_RULE != I_CCD & prune_ > 0) {
                ntrim <- 0
                ## Adjust weights:
                log_weights_adj <- numeric(value = 0, length = nQ)
                for (i in 1:nQ) {
                    log_weights_adj[i] <- sum(dnorm(nodes_wgts[i,2:(d+1)], mean = 0, sd = 1, log = TRUE)) +
                        log(nodes_wgts[i,1])
                }
                ## Bubble sort (ascending) as we don't have a quantile function.
                order <- 1:nQ
                for( k in 1:(nQ-1) ){
                  i <- 1
                  while( i < nQ-k+1 ){
                    if(log_weights_adj[order[i]] > log_weights_adj[order[i+1]]) {
                      tmp <- order[i+1]
                      order[i+1] <- order[i]
                      order[i] <- tmp
                    }
                  i <- i+1
                  }
                }
                ntrim <- ceiling(nQ*prune_)
                q <- log_weights_adj[order][ntrim]
                ## Matching weights can be numerically different and want to trim symmetrically.
                ## NOTE(review): numError is absolute; for |q| > ~1 it is below double
                ## precision, so ties differing by rounding error may not be trimmed together.
                keep <- which(log_weights_adj > q + numError)
                if(dim(keep)[1] >= 3) {
                  nodes_wgts <- nodes_wgts[keep,]
                  if (modeI[I_RULE] > 0) {
                      modei <- which(keep == modeI[I_RULE] )
                      if (dim(modei)[1] > 0) modeI[I_RULE]  <<- modei[1] else modeI[I_RULE] <<- -1
                  }
                } else {
                  stop("Will not prune to less than 3 quadrature points. Choose another pruning proportion or switch to Laplace, one quadrature node.")
                }
            }
            returnType(double(2))
            return(nodes_wgts)
        },
        ## Tensor-product grid from the active univariate rule, with levels_vec[j]
        ## points in dimension j. Columns: product weight, then node coordinates.
        product_construction = function(levels_vec = double(1)) {
            nQ <- prod(levels_vec)
            nodes_wgts <- matrix(1, nrow = nQ, ncol = d + 1)

            ## Get quad grid
            nodes_mat <- matrix(0, nrow = d, ncol = max(levels_vec))
            weights_vec <- matrix(0, nrow = d, ncol = max(levels_vec))
            for (i in 1:d) {
                nodesi <- quadRule_nfl[[I_RULE]]$buildGrid(levels_vec[i], d = 1)
                nodes_mat[i, 1:levels_vec[i]] <- nodesi[, 2]
                weights_vec[i, 1:levels_vec[i]] <- nodesi[, 1]
            }

            ## swp[j]: number of consecutive rows sharing the same node in dimension j.
            swp <- nimNumeric(value = 0, length = d)
            swp[1] <- 1
            for (ii in 2:d) swp[ii] <- prod(levels_vec[1:(ii - 1)])

            ## Do Product Rule: Repeat x for each dimension swp times.
            for (j in 1:d) {
                idx <- 1
                for (ii in 1:nQ) {
                    nodes_wgts[ii, j + 1] <- nodes_mat[j, idx]
                    nodes_wgts[ii, 1] <- nodes_wgts[ii, 1] * weights_vec[j, idx]
                    k <- ii %% swp[j]
                    if (k == 0)
                        idx <- idx + 1
                    if (idx > levels_vec[j])
                        idx <- 1
                }
            }
            returnType(double(2))
            return(nodes_wgts)
        },
        ## Smolyak sparse grid (Heiss and Winschel, 2008): a signed combination of
        ## product subgrids, with the repeated centre node merged into one row.
        sparse_construction = function() {
            if( I_RULE == I_AGHQSPARSE | (I_RULE == I_USER & userType == "SPARSE") ) {
              minq <- max(0, levels - d)
              maxq <- levels - 1

              ## Number of subgrids = sum over q of choose(q + d - 1, d - 1).
              noSubGrids <- sum(factorial(minq:maxq + d - 1) / (factorial(minq:maxq) * factorial(d - 1)))

              gridCombos <- matrix(0, nrow = noSubGrids, ncol = d)
              start <- 1
              for (q in minq:maxq) {
                  tmpCombos <- drop_algorithm(d, d + q)
                  nq <- dim(tmpCombos)[1]
                  gridCombos[start:(start + nq - 1), ] <- tmpCombos
                  start <- start + nq
              }

              totalPts <- 0
              for (i in 1:noSubGrids) {
                  totalPts <- totalPts + prod(gridCombos[i, ])
              }
              firstZero <- 0
              zeroIndex <- 0
              extraZeros <- 0
              nodes_wgts <- matrix(0, nrow = totalPts, ncol = d + 1)
              cnt <- 1
              for (i in 1:noSubGrids) {
                  zeroi <- FALSE              
                  q <- sum(gridCombos[i, ]) - d
                  wgtadj <- (-1)^(maxq - q) * factorial(d - 1) / (factorial(d + q - levels) *
                      factorial(levels - q - 1))
                  nodes_prod <- product_construction(levels_vec = gridCombos[i, ])
                  ni <- dim(nodes_prod)[1]
                  ## Adjust weights:
                  nodes_prod[,1] <- nodes_prod[, 1] * wgtadj
                  ## Sum repeated generations of the zeros (modes): a subgrid with an odd
                  ## number of points in every dimension has its centre at the origin.
                  if(all(ceiling(gridCombos[i,] / 2) != gridCombos[i,] / 2)){
                    zi <- ceiling(ni/2)
                    if(firstZero == 0){
                      firstZero <- i
                      zeroIndex <- cnt + zi - 1
                    }else{
                      zeroi <- TRUE
                    }
                  }
                  if(zeroi){
                    nodes_wgts[cnt:(cnt + zi - 2), ] <- nodes_prod[1:(zi-1), ]
                    nodes_wgts[zeroIndex, 1] <- nodes_wgts[zeroIndex, 1] + nodes_prod[zi, 1] ## Aggregate the Mode.
                    nodes_wgts[(cnt+zi-1):(cnt + ni - 2), ] <- nodes_prod[(zi+1):ni, ]
                    cnt <- cnt + dim(nodes_prod)[1] - 1
                  }else{
                    nodes_wgts[cnt:(cnt + ni - 1), ] <- nodes_prod
                    cnt <- cnt + dim(nodes_prod)[1]
                  }
              }
              nodes_wgts <- nodes_wgts[1:(cnt-1),]
            } else {
              stop("Trying to apply sparse construction to an invalid quadrature rule.")
            }
            returnType(double(2))
            return(nodes_wgts)            
        },
        ## Make `method` the active rule; stops if it was not configured in setup.
        setRule = function(method = character(0, default = "AGHQ")) {
            success <- FALSE
            i <- 1
            while(i <= nRules & !success){
              if(quadRules[i] == method){
                I_RULE <<- i
                success <- TRUE
              }
              i <- i+1
            }
            if( !success )
              stop("Quadrature Rule being requested was either not created or is invalid. Choose a valid quadrature rule.")
        },
        ## Change the grid dimension; every rule is marked unbuilt.
        setDim = function(ndim = integer(0, default = 1)) {
            if (ndim <= 0) stop("Can't input negative dimensions") else d <<- ndim
            ## Make sure the next grid gets built.
            for( i in 1:nRules ) gridBuilt[i] <<- FALSE
        },
        ## Weights of the active grid (idx = -1 selects the mode when one exists).
        weights = function(idx = integer(0, default = 0)) {
            if (!gridBuilt[I_RULE]) buildGrid()
            if (idx == -1 & modeI[I_RULE] > 0) idx <- modeI[I_RULE] 
            returnType(double(1))
            return(quadGridCache_nfl[[I_RULE]]$weights(idx = idx))
        },
        ## Nodes of the active grid (idx = -1 selects the mode when one exists).
        nodes = function(idx = integer(0, default = 0)) {
            if (!gridBuilt[I_RULE]) buildGrid()
            if (idx == -1 & modeI[I_RULE]  > 0) idx <- modeI[I_RULE]
            returnType(double(2))
            return(quadGridCache_nfl[[I_RULE]]$nodes(idx = idx))
        },
        gridSize = function() {
            if (!gridBuilt[I_RULE]) buildGrid()
            returnType(integer())
            return(nGrid[I_RULE])
        },
        modeIndex = function() {
            if (!gridBuilt[I_RULE]) buildGrid()
            returnType(double())
            return(modeI[I_RULE] )
        }
    )
)  ## End of configureQuadGrid

## Virtual base class for a cache used to simulate the posterior distribution of
## the random (latent) effects, following Stringer. Simulation needs the inner
## mode and the inner Cholesky factor of the negative Hessian, stored per index.
## Not exported
INNER_CACHE_BASE <- nimbleFunctionVirtual(
    run = function() {},
    methods = list(
        ## Allocate storage for the grid size and number of latent effects.
        buildCache = function(nGridUpdate = integer(),
                              nLatents = integer()) {
        },
        ## Store a weight for index `idx`.
        cache_weights = function(weight = double(),
                                 idx = integer()) {
        },
        ## Store the inner mode for index `idx`.
        cache_inner_mode = function(mode = double(1),
                                    idx = integer()) {
        },
        ## Store the inner Cholesky factor of the negative Hessian for index `idx`.
        cache_inner_negHessChol = function(negHessChol = double(2),
                                           idx = integer()) {
        },
        gridSize = function() {
            returnType(integer())
        },
        weights = function() {
            returnType(double(1))
        },
        ## Draw `n` simulated rows.
        simulate = function(n = integer()) {
            returnType(double(2))
        }
    )
)

## For doing simulation of random-effects save outer mode and negHessian.  
## Need wgt*density and inner mode and inner cholesky for each quad node
# Not exported
inner_cache_methods = nimbleFunction(
    contains = INNER_CACHE_BASE,
    setup = function(nre = 0, nGrid = 0, condIndepSets = NULL, nCondIndepSets = 1) {
        verbose <- isTRUE(nimble::getNimbleOption('verbose'))
        innerMode <- matrix(0, nrow = 1, ncol = 1)
        innerNegHessChol <- array(0, c(1, 1, 1))
        wgtsDens <- c(1, -1)
        cacheBuilt <- FALSE
        if (is.null(condIndepSets)) {
            condIndepSets <- nre  ## Assuming all one set.
            nCondIndepSets <- 1  ## If NULL then this is not relevant.
        }
        if (length(condIndepSets) == 1) {
            condIndepSets <- c(condIndepSets, -1)  ##  Make sure it's a vector.
        }
    },
    run = function() {
    },
    methods = list(
        ## (Re)allocate storage for `nGridUpdate` quadrature nodes and `nLatents`
        ## latent effects. A negative nGridUpdate (the default) keeps the current nGrid.
        buildCache = function(nGridUpdate = integer(0, default = -1), nLatents = integer()) {
            nre <<- nLatents
            ## If the cond independent sets don't match up, don't use.
            if (nre != sum(condIndepSets[1:nCondIndepSets])) {
                if(verbose)
                  print("  Warning: Not able to simulate latent effects from conditionally independent sets.")
                condIndepSets <<- numeric(value = nre, length = 1)
                nCondIndepSets <<- 1
            }

            if (nGridUpdate > 0 & nGridUpdate != nGrid) {
                nGrid <<- nGridUpdate
                cacheBuilt <<- FALSE
            }
            ## Storage sized for a different number of latent effects is stale.
            if (cacheBuilt) {
                if (dim(innerMode)[2] != nre) cacheBuilt <<- FALSE
            }
            if (!cacheBuilt) {
                ## Only adopt a non-negative size; the -1 default must not be used
                ## as an allocation length.
                if (nGridUpdate >= 0) nGrid <<- nGridUpdate
                wgtsDens <<- numeric(value = 0, length = nGrid)
                innerMode <<- matrix(0, nrow = nGrid, ncol = nre)
                innerNegHessChol <<- array(0, c(nGrid, nre, nre))
                cacheBuilt <<- TRUE
            }
        },
        ## Note to self, this wgt will be density*wgt, strictly for simulating.
        cache_weights = function(weight = double(), idx = integer()) {
            if(idx > nGrid | idx <= 0)
              stop("Trying to cache weights larger than we expect the quadrature grid to be.")
            wgtsDens[idx] <<- weight
        },
        cache_inner_mode = function(mode = double(1), idx = integer()) {
            if(idx > nGrid | idx <= 0)
              stop("Trying to cache latent mode larger than we expect the quadrature grid to be.")
            innerMode[idx, ] <<- mode
        },
        ## Note potentially storing a lot of zeros here. Could break it into a list of cond indpt sets.
        cache_inner_negHessChol = function(negHessChol = double(2), idx = integer()) {
            if(idx > nGrid | idx <= 0)
              stop("Trying to cache latent hessian larger than we expect the quadrature grid to be.")
        
            innerNegHessChol[idx, , ] <<- negHessChol
        },
        weights = function() {
            returnType(double(1))
            return(wgtsDens)
        },
        gridSize = function(){
          returnType(integer())
          return(nGrid)
        },
        ## Each row: column 1 is the sampled quadrature index k (drawn with
        ## probability proportional to wgtsDens), followed by nre latent values drawn
        ## per conditionally independent set from a multivariate normal with mean
        ## innerMode[k, ] and precision Cholesky innerNegHessChol[k, , ].
        simulate = function(n = integer()) {
            val <- matrix(0, nrow = n, ncol = nre + 1)
            simwgt <- wgtsDens/sum(wgtsDens)  ## Did log sum exp when doing input.

            ## Simulate theta points first.  Seems efficient to separate to not
            ## initiate too many index vectors for cond indpt sets.
            for (i in 1:n) {
                k <- rcat(1, prob = simwgt)
                val[i, 1] <- k
                jStart <- 1
                for (j in 1:nCondIndepSets) {
                    val[i, (jStart + 1):(jStart + condIndepSets[j])] <-
                        rmnorm_chol(n = 1, mean = innerMode[k, jStart:(jStart + condIndepSets[j] - 1)],
                                    cholesky = innerNegHessChol[k, jStart:(jStart + condIndepSets[j] - 1),
                                                                jStart:(jStart + condIndepSets[j] - 1)],
                                    prec_param = TRUE)
                    jStart <- jStart + condIndepSets[j]
                }
            }
            returnType(double(2))
            return(val)
        }
    )
)
