

#' Nested cross-validation for caret
#'
#' This function applies nested cross-validation (CV) to training of models
#' using the `caret` package. The function also allows the option of embedded
#' filtering of predictors for feature selection nested within the outer loop of
#' CV. Predictions on the outer test folds are brought back together and error
#' estimation/accuracy is determined. The default is 10x10 nested CV.
#'
#' @param y Response vector. For classification this should be a factor.
#' @param x Matrix or dataframe of predictors
#' @param filterFUN Filter function, e.g. [ttest_filter] or [relieff_filter].
#'   Any function can be provided and is passed `y` and `x`. Must return a
#'   character vector with names of filtered predictors.
#' @param filter_options List of additional arguments passed to the filter
#'   function specified by `filterFUN`.
#' @param outer_method String of either `"cv"` or `"LOOCV"` specifying whether
#'   to do k-fold CV or leave one out CV (LOOCV) for the outer folds
#' @param n_outer_folds Number of outer CV folds
#' @param outer_folds Optional list containing indices of test folds for outer
#'   CV. If supplied, `n_outer_folds` is ignored.
#' @param metric A string that specifies what summary metric will be used to
#'   select the optimal model. By default, "logLoss" is used for classification
#'   and "RMSE" is used for regression. Note this differs from the default
#'   setting in caret which uses "Accuracy" for classification. See details.
#' @param trControl A list of values generated by the `caret` function
#'   [trainControl]. This defines how inner CV training through `caret` is
#'   performed. Default for the inner loop is 10-fold CV. See
#'   http://topepo.github.io/caret/using-your-own-model-in-train.html.
#' @param tuneGrid Data frame of tuning values, see [caret::train].
#' @param savePredictions Indicates whether hold-out predictions for each inner
#'   CV fold should be saved for ROC curves, accuracy etc see
#'   [caret::trainControl]. Default is `"final"` to capture predictions for
#'   inner CV ROC.
#' @param cv.cores Number of cores for parallel processing of the outer loops.
#'   NOTE: this uses `parallel::mclapply` on unix/mac and `parallel::parLapply`
#'   on windows.
#' @param na.option Character value specifying how `NA`s are dealt with.
#'   `"omit"` is equivalent to `na.action = na.omit`. `"omitcol"` removes cases
#'   if there are `NA` in 'y', but columns (predictors) containing `NA` are
#'   removed from 'x' to preserve cases. Any other value means that `NA` are
#'   ignored (a message is given).
#' @param ... Arguments passed to [caret::train]
#' @return An object with S3 class "nestcv.train"
#'   \item{call}{the matched call}
#'   \item{output}{Predictions on the left-out outer folds}
#'   \item{outer_result}{List object of results from each outer fold containing
#'   predictions on left-out outer folds, caret result and number of filtered
#'   predictors at each fold.}
#'   \item{dimx}{dimensions of `x`}
#'   \item{outer_method}{the `outer_method` used, either `"cv"` or `"LOOCV"`}
#'   \item{outer_folds}{List of indices of outer test folds}
#'   \item{final_fit}{Final fitted caret model using best tune parameters}
#'   \item{final_vars}{Column names of filtered predictors entering final model}
#'   \item{roc}{ROC AUC for binary classification where available.}
#'   \item{trControl}{`caret::trainControl` object used for inner CV}
#'   \item{bestTunes}{best tuned parameters from each outer fold}
#'   \item{finalTune}{final parameters used for final model}
#'   \item{summary}{Overall performance summary. Accuracy and balanced accuracy
#'   for classification. ROC AUC for binary classification. RMSE for
#'   regression.}
#' @details Parallelisation is performed on the outer folds using `mclapply` on
#'   unix/mac and `parLapply` on windows.
#'   For classification `metric` defaults to using 'logLoss' with the
#'   `trControl` arguments `classProbs = TRUE, summaryFunction = mnLogLoss`,
#'   rather than 'Accuracy' which is the default classification metric in
#'   `caret`. See [trainControl]. LogLoss is arguably more consistent than
#'   Accuracy for tuning parameters in datasets with small sample size.
#'
#'   Models can be fitted with a single set of fixed parameters, in which case
#'   `trControl` defaults to `trainControl(method = "none")` which disables
#'   inner CV as it is unnecessary. See
#'   https://topepo.github.io/caret/model-training-and-tuning.html#fitting-models-without-parameter-tuning
#'
#' @author Myles Lewis
#' @examples
#' \donttest{
#' ## sigmoid function
#' sigmoid <- function(x) {1 / (1 + exp(-x))}
#' 
#' ## load iris dataset and simulate a binary outcome
#' data(iris)
#' x <- iris[, 1:4]
#' colnames(x) <- c("marker1", "marker2", "marker3", "marker4")
#' x <- as.data.frame(apply(x, 2, scale))
#' y2 <- sigmoid(0.5 * x$marker1 + 2 * x$marker2) > runif(nrow(x))
#' y2 <- factor(y2, labels = c("class1", "class2"))
#' 
#' ## Example using random forest with caret
#' cvrf <- nestcv.train(y2, x, method = "rf",
#'                      n_outer_folds = 3,
#'                      cv.cores = 2)
#' summary(cvrf)
#' 
#' ## Example of glmnet tuned using caret
#' ## set up small tuning grid for quick execution
#' ## length.out of 20-100 is usually recommended for lambda
#' ## and more alpha values ranging from 0-1
#' tg <- expand.grid(lambda = exp(seq(log(2e-3), log(1e0), length.out = 5)),
#'                   alpha = 1)
#' 
#' ncv <- nestcv.train(y = y2, x = x,
#'                     method = "glmnet",
#'                     n_outer_folds = 3,
#'                     tuneGrid = tg, cv.cores = 2)
#' summary(ncv)
#' 
#' ## plot tuning for outer fold #1
#' plot(ncv$outer_result[[1]]$fit, xTrans = log)
#' 
#' ## plot final ROC curve
#' plot(ncv$roc)
#' 
#' ## plot ROC for left-out inner folds
#' inroc <- innercv_roc(ncv)
#' plot(inroc)
#' }
#' @importFrom caret createFolds train trainControl mnLogLoss confusionMatrix
#'   defaultSummary
#' @importFrom data.table rbindlist
#' @importFrom parallel mclapply
#' @importFrom pROC roc
#' @importFrom stats predict setNames
#' @export
#' 
nestcv.train <- function(y, x,
                         filterFUN = NULL,
                         filter_options = NULL,
                         outer_method = c("cv", "LOOCV"),
                         n_outer_folds = 10,
                         outer_folds = NULL,
                         cv.cores = 1,
                         metric = ifelse(is.factor(y), "logLoss", "RMSE"),
                         trControl = NULL,
                         tuneGrid = NULL,
                         savePredictions = "final",
                         na.option = "pass",
                         ...) {
  nestcv.call <- match.call(expand.dots = TRUE)
  outer_method <- match.arg(outer_method)

  # checkxy() is defined elsewhere; its `r` and `c` elements are used here as
  # row and column selectors. drop = FALSE keeps `x` two-dimensional even if
  # only one predictor column is retained.
  ok <- checkxy(y, x, na.option)
  y <- y[ok$r]
  x <- x[ok$r, ok$c, drop = FALSE]

  # Default inner loop: 10-fold CV, tuned on logLoss for classification
  if (is.null(trControl)) {
    trControl <- if (is.factor(y)) {
      trainControl(method = "cv",
                   number = 10,
                   classProbs = TRUE,
                   savePredictions = savePredictions,
                   summaryFunction = mnLogLoss)
    } else {
      trainControl(method = "cv",
                   number = 10,
                   savePredictions = savePredictions)
    }
  }
  # switch off inner CV if tuneGrid is single row; class probabilities are
  # only requested for classification
  if (!is.null(tuneGrid) && nrow(tuneGrid) == 1) {
    trControl <- trainControl(method = "none", classProbs = is.factor(y))
  }
  if (is.null(outer_folds)) {
    outer_folds <- switch(outer_method,
                          cv = createFolds(y, k = n_outer_folds),
                          LOOCV = seq_along(y))
  }

  if (.Platform$OS.type == "windows" && cv.cores >= 2) {
    # `...` cannot be listed in `varlist` (which holds variable names), so the
    # extra arguments are captured in `dots`, exported by name and spliced
    # into the worker call with do.call().
    dots <- list(...)
    cl <- parallel::makeCluster(cv.cores)
    outer_res <- tryCatch({
      parallel::clusterExport(cl,
                              varlist = c("outer_folds", "y", "x",
                                          "filterFUN", "filter_options",
                                          "metric", "trControl", "tuneGrid",
                                          "nestcv.trainCore", "dots"),
                              envir = environment())
      parallel::parLapply(cl = cl, outer_folds, function(test) {
        core_args <- c(list(test = test, y = y, x = x,
                            filterFUN = filterFUN,
                            filter_options = filter_options,
                            metric = metric, trControl = trControl,
                            tuneGrid = tuneGrid),
                       dots)
        do.call(nestcv.trainCore, core_args)
      })
    }, finally = parallel::stopCluster(cl))  # release workers even on error
  } else {
    outer_res <- parallel::mclapply(outer_folds, function(test) {
      nestcv.trainCore(test, y, x,
                       filterFUN, filter_options,
                       metric, trControl, tuneGrid, ...)
    }, mc.cores = cv.cores)
  }

  # Gather predictions on the left-out outer folds
  predslist <- lapply(outer_res, "[[", "preds")
  output <- as.data.frame(data.table::rbindlist(predslist))
  if (!is.null(rownames(x))) {
    rownames(output) <- unlist(lapply(predslist, rownames))
  }
  summary <- predSummary(output)
  caret.roc <- NULL
  if (is.factor(y) && nlevels(y) == 2) {
    caret.roc <- pROC::roc(output$testy, output$predyp, direction = "<",
                           quiet = TRUE)
  }

  # One row of best tuning parameters per outer fold. Row names follow the
  # actual number of folds, which differs from `n_outer_folds` for LOOCV or
  # user-supplied `outer_folds`.
  bestTunes <- lapply(outer_res, function(res) res$fit$bestTune)
  bestTunes <- as.data.frame(data.table::rbindlist(bestTunes))
  rownames(bestTunes) <- paste("Fold", seq_len(nrow(bestTunes)))

  # Final parameters: mean across folds for numeric parameters; the most
  # frequent value for non-numeric ones (which colMeans() cannot handle)
  numeric_cols <- vapply(bestTunes, is.numeric, logical(1))
  finalTune <- as.list(bestTunes[1, , drop = FALSE])
  finalTune[numeric_cols] <-
    as.list(colMeans(bestTunes[, numeric_cols, drop = FALSE]))
  finalTune[!numeric_cols] <- lapply(bestTunes[!numeric_cols], function(v) {
    uv <- unique(v)
    uv[which.max(tabulate(match(v, uv)))]
  })
  finalTune <- data.frame(finalTune)

  # Apply the filter to the whole dataset, then fit the final model with the
  # final parameters and no resampling
  filtx <- if (is.null(filterFUN)) x else {
    args <- c(list(y = y, x = x), filter_options)
    fset <- do.call(filterFUN, args)
    x[, fset, drop = FALSE]
  }
  fitControl <- trainControl(method = "none", classProbs = is.factor(y))
  final_fit <- caret::train(x = filtx, y = y,
                            trControl = fitControl,
                            tuneGrid = finalTune, ...)

  out <- list(call = nestcv.call,
              output = output,
              outer_result = outer_res,
              outer_method = outer_method,
              outer_folds = outer_folds,
              dimx = dim(x),
              final_fit = final_fit,
              final_vars = colnames(filtx),
              roc = caret.roc,
              trControl = trControl,
              bestTunes = bestTunes,
              finalTune = finalTune,
              summary = summary)
  class(out) <- "nestcv.train"
  out
}


# Fit and evaluate a single outer fold (internal worker for nestcv.train).
#
# `test` holds the indices of the left-out samples. The optional filter is
# applied to the training samples only, a caret model is trained on them and
# predictions are made for the left-out samples.
#
# Returns a list with:
#   preds   data frame of predictions (`predy`), observed values (`testy`)
#           and, when `y` is a factor, `predyp` (the second column of the
#           class probability predictions)
#   fit     the fitted caret model for this fold
#   nfilter number of predictors remaining after filtering
nestcv.trainCore <- function(test, y, x,
                             filterFUN, filter_options,
                             metric, trControl, tuneGrid, ...) {
  # drop = FALSE throughout: a single test index (LOOCV) or a single selected
  # predictor must not collapse the predictors to a plain vector
  filtx <- if (is.null(filterFUN)) x else {
    args <- c(list(y = y[-test], x = x[-test, , drop = FALSE]),
              filter_options)
    fset <- do.call(filterFUN, args)
    x[, fset, drop = FALSE]
  }
  train_x <- filtx[-test, , drop = FALSE]
  test_x <- filtx[test, , drop = FALSE]

  fit <- caret::train(x = train_x, y = y[-test],
                      metric = metric,
                      trControl = trControl,
                      tuneGrid = tuneGrid, ...)
  predy <- predict(fit, newdata = test_x)
  preds <- data.frame(predy = predy, testy = y[test])
  if (is.factor(y)) {
    predyp <- predict(fit, newdata = test_x, type = "prob")
    # only the second probability column is kept
    preds$predyp <- predyp[, 2]
  }
  rownames(preds) <- rownames(x)[test]
  list(preds = preds,
       fit = fit,
       nfilter = ncol(filtx))
}


#' Summarise a nested cross-validation caret fit
#'
#' Prints the filter, outer and inner loop settings, data dimensions, the best
#' tuning parameters and number of filtered predictors per outer fold, the
#' final parameters and the overall performance summary.
#'
#' @param object A "nestcv.train" object
#' @param digits Number of significant digits used when printing
#' @param ... Unused
#' @return Invisibly, a list with elements `dimx`, `folds`, `final_param` and
#'   `result`.
#' @export
summary.nestcv.train <- function(object, 
                                 digits = max(3L, getOption("digits") - 3L), 
                                 ...) {
  cat("Nested cross-validation with caret\n")
  filter_expr <- object$call$filterFUN
  if (!is.null(filter_expr)) {
    # The matched call may hold a symbol, a call such as `pkg::fun` or an
    # inline function; deparse so cat() can print any of them
    cat("Filter: ", paste(deparse(filter_expr), collapse = " "), "\n")
  } else {
    cat("No filter\n")
  }
  cat("Outer loop: ", switch(object$outer_method,
                             cv = paste0(length(object$outer_folds), "-fold cv\n"),
                             LOOCV = "leave-one-out CV\n"))
  inner_method <- object$trControl$method
  if (identical(inner_method, "none")) {
    # inner CV is disabled, so a fold count would be misleading
    cat("Inner loop: ", "none (fixed tuning parameters)\n")
  } else {
    cat("Inner loop: ", paste0(object$trControl$number, "-fold ",
                               inner_method, "\n"))
  }
  cat(object$dimx[1], "observations,", object$dimx[2], "predictors\n\n")

  # Per-fold table: best tuning parameters plus number of filtered predictors
  nfilter <- unlist(lapply(object$outer_result, "[[", "nfilter"))
  foldres <- object$bestTunes
  foldres$n.filter <- nfilter
  print(foldres, digits = digits, print.gap = 2L)
  cat("\nFinal parameters:\n")
  print(object$finalTune, digits = digits, print.gap = 2L, row.names = FALSE)
  cat("\nResult:\n")
  print(object$summary, digits = digits, print.gap = 2L)
  out <- list(dimx = object$dimx, folds = foldres,
              final_param = object$finalTune, result = object$summary)
  invisible(out)
}


#' Predict from a nested cross-validation caret fit
#'
#' Predicts from the final fitted model (`object$final_fit`) using only the
#' predictors that entered it (`object$final_vars`).
#'
#' @param object A "nestcv.train" object
#' @param newdata Matrix or dataframe whose column names include all of
#'   `object$final_vars`
#' @param ... Further arguments passed on to `predict` for the final model
#' @return The result of calling `predict` on the final model.
#' @method predict nestcv.train
#' @export
predict.nestcv.train <- function(object, newdata, ...) {
  if (!all(object$final_vars %in% colnames(newdata))) {
    stop("newdata is missing some predictors", call. = FALSE)
  }
  # drop = FALSE keeps newdata two-dimensional when the final model uses a
  # single predictor
  predict(object$final_fit,
          newdata = newdata[, object$final_vars, drop = FALSE], ...)
}
