#' Print Generic For `mab`
#' @description
#' Custom Print Display for objects of `mab` class returned by [single_mab_simulation()].
#' Prevents the large list from being printed directly, and provides
#' useful information about the settings of each trial.
#' @param x A `mab` class object created by [single_mab_simulation()].
#' @param ... Further arguments passed to or from other methods.
#' @method print mab
#' @name print.mab
#' @returns `x`, invisibly. Called for its side effect of printing a text
#' summary of the settings used for the Multi-Arm Bandit trial.
#' @details
#' The items used to create the text summary can be found in the settings
#' element of the output object.
#'
#' `...` is provided to be compatible with `print()`, but no other arguments
#' change the output.
#' @export
#' @examples
#' # Running a Trial
#' x <- single_mab_simulation(
#'   data = tanf,
#'   algorithm = "thompson",
#'   assignment_method = "batch",
#'   period_length = 1750,
#'   prior_periods = "All",
#'   blocking = FALSE,
#'   whole_experiment = TRUE,
#'   perfect_assignment = TRUE,
#'   data_cols = c(
#'     id_col = "ic_case_id",
#'     success_col = "success",
#'     condition_col = "condition"
#'   )
#' )
#' print(x)
print.mab <- function(x, ...) {
  # Shared settings summary, followed by a closing rule.
  print_mab(x)
  base::cat("----------------------------------------------------- \n")
  # Print methods conventionally return their argument invisibly so the object
  # can still be assigned or piped after printing.
  invisible(x)
}
#-------------------------------------------------------------------------------
#' Print Helper for `mab` and `multiple.mab`
#' @description Shared console output for the print methods of the `mab` and
#' `multiple.mab` classes.
#' @name print_mab
#' @param mab A `mab` or `multiple.mab` object.
#' @returns Called for its side effect: writes a text summary of the settings
#' used for the Multi-Arm Bandit trial to the console.
#' @keywords internal
print_mab <- function(mab) {
  cfg <- mab$settings
  method <- cfg$assignment_method

  # Writes "<label> <value> <end>"; `cat()` supplies the single-space
  # separators, so each label carries its own alignment padding.
  field <- function(label, value, end = "\n") {
    base::cat(label, value, end)
  }

  base::cat(
    "Summary for MAB Procedure: \n ----------------------------------------------------- \n"
  )

  field("Bandit Algorithm:     ", cfg$algorithm)
  field("Control Augmentation: ", cfg$control_augment)
  field("Bandit Assignment:    ", 1 - cfg$random_assign_prop)
  field("Randomized Assignment:", cfg$random_assign_prop)
  field("Perfect Assignment:   ", cfg$perfect_assignment)
  field("Whole Experiment:     ", cfg$whole_experiment)
  if (cfg$blocking) {
    field("Blocking Variables:   ", cfg$block_cols)
  }
  field("Assignment Method:    ", method)

  # The period-length line is written in pieces: the number first, then a
  # unit suffix that depends on the assignment method.
  if (method %in% c("batch", "date")) {
    base::cat("Period Length:        ", cfg$period_length)
  }
  if (method == "batch") {
    base::cat(" People\n")
  }
  if (method == "date") {
    base::cat("", cfg$time_unit)
    # Pluralise the time unit when a period spans more than one of them.
    base::cat(if (cfg$period_length > 1) "s\n" else "\n")
  }

  field("Total Periods:        ", max(mab$bandits$period_number), "periods\n")
  field("Prior Periods:        ", cfg$prior_periods, "periods\n")
  field("Number of Treatments: ", length(cfg$conditions))
  if (cfg$control_augment > 0) {
    field("Control Group:        ", cfg$control)
  }
}

#------------------------------------------------------------------------------
#' Summary Generic For "mab" Class
#' @description
#' Summarizes the Results of a Single Multi-Arm Bandit Trial. Provides
#' confidence intervals around the AIPW estimates, final calculations
#' of the Thompson sampling probabilities or UCB1 values, and the number of observations assigned for each arm.
#' @param object A `mab` class object created by [single_mab_simulation()].
#' @param level Numeric value of length 1; the confidence level used for the intervals (e.g., 0.90, 0.95, 0.99).
#' Defaults to 0.95.
#' @param ... Additional arguments.
#' @method summary mab
#' @export
#' @details
#' The confidence intervals applied follow a standard normal distribution
#' because it is assumed the AIPW estimators are asymptotically normal as shown
#' in \href{https://www.pnas.org/doi/full/10.1073/pnas.2014602118}{Hadad et al. (2021)}.
#'
#' `...` is provided to be compatible with `summary()`; the function
#' does not use any additional arguments.
#'
#' All of the data provided to create a table like this is present in the object
#' created by [single_mab_simulation()] but
#' this provides a simple shortcut, which is useful when testing many
#' different simulations.
#'
#' @returns A tibble containing summary information from the trial with the columns:
#' \itemize{
#' \item `Treatment_Arm`: Contains the treatment condition.
#' \item `Probability_Of_Best_Arm`/`UCB1_Value`: Final Thompson sampling probabilities or UCB1 values for each treatment.
#' \item `estimated_probability_of_success`: The AIPW estimates for the probability of success for each treatment.
#' \item `SE`: The standard error for the AIPW estimates.
#' \item `lower_bound`: The lower bound on the normal confidence interval for the `estimated_probability_of_success`. Default is 95%.
#' \item `upper_bound`: The upper bound on the normal confidence interval for the `estimated_probability_of_success`. Default is 95%.
#' \item `num_assigned`: The number of observations assigned to each treatment under the simulated trial.
#' \item `level`: The confidence level for the confidence interval, default is 95%.
#' \item `periods`: The total number of periods of the simulation.
#' }
#' @references
#' Hadad, Vitor, David A. Hirshberg, Ruohan Zhan, Stefan Wager, and Susan Athey. 2021.
#' "Confidence Intervals for Policy Evaluation in Adaptive Experiments." \emph{Proceedings of the National Academy of Sciences of the United States of America} 118
#' (15): e2014602118. \doi{10.1073/pnas.2014602118}.
#'
#' @example inst/examples/summary.mab_example.R
summary.mab <- function(object, level = 0.95, ...) {
  check_level(level)

  periods <- base::max(object$bandits$period_number)

  # Name of the column holding the final bandit values; depends on the
  # algorithm that produced them.
  value_col <- switch(
    object$settings$algorithm,
    "ucb1" = "UCB1_Value",
    "thompson" = "Probability_Of_Best_Arm"
  )

  # AIPW rows only, with the arm label as character so it can be joined
  # against the pivoted column names below.
  aipw <- dplyr::filter(object$estimates, estimator == "AIPW")
  aipw <- dplyr::mutate(aipw, mab_condition = as.character(mab_condition))

  # Per-arm assignment counts, reshaped into a (value, mab_condition) tibble.
  arm_counts <- get_assignment_quantities(object, object$settings$conditions)
  counts_tbl <- dplyr::mutate(
    tibble::as_tibble(arm_counts),
    mab_condition = names(arm_counts)
  )

  # Two-sided standard-normal critical value for the requested level.
  z_crit <- base::abs(stats::qnorm((1 - level) / 2))

  # Bandit values from the row indexed by the last period, one row per arm.
  final_arms <- object$bandits[periods, ] |>
    tidyr::pivot_longer(
      cols = -period_number,
      names_to = "Treatment_Arm",
      values_to = value_col
    ) |>
    dplyr::select(-period_number)

  # Attach the estimates and build the normal confidence interval.
  with_intervals <- final_arms |>
    dplyr::left_join(aipw, by = c("Treatment_Arm" = "mab_condition")) |>
    dplyr::mutate(
      SE = sqrt(variance),
      lower_bound = mean - z_crit * SE,
      upper_bound = mean + z_crit * SE
    ) |>
    dplyr::select(-variance, -estimator)

  # Attach the assignment counts and the run-level metadata.
  with_intervals |>
    dplyr::left_join(counts_tbl, by = c("Treatment_Arm" = "mab_condition")) |>
    dplyr::rename(
      "estimated_probability_of_success" = "mean",
      "num_assigned" = "value"
    ) |>
    dplyr::mutate(
      level = level,
      periods = periods
    )
}
#------------------------------------------------------------------------------
#' Plot Generic for `mab` objects
#' @description Uses [ggplot2::ggplot()] to plot the results of a single
#' Multi-Arm-Bandit trial. Provides options to select the type of plot,
#' and to change how the plot looks. Objects created can be added to
#' with `+` like any other ggplot plot, but arguments to change
#' the underlying geom must be passed to the function initially.
#'
#' @method plot mab
#' @param x A `mab` class object created by [single_mab_simulation()]
#' @param type String; Type of plot requested; valid types are:
#' \itemize{
#' \item `arm`: Shows Thompson sampling probabilities or UCB1 values over the trial period.
#' \item `assign`: Shows cumulative assignment proportions over the trial period.
#' \item `estimate`: Shows AIPW estimates for success probability with
#' user specified normal confidence intervals based on their estimated variance.
#' }
#' @param save Logical; Whether or not to save the plot to disk; FALSE by default.
#' @param path String; File path to save the plot to when `save = TRUE`; passed to [ggplot2::ggsave()] as `filename`.
#' @inheritParams summary.mab
#' @param ... Arguments to pass to `ggplot2::geom_*` function (e.g. `color`, `linewidth`, `alpha`, `bins` etc.).
#' @details
#' This function provides minimalist plots to quickly view the results of any
#' Multi-Arm-Bandit trial, and can be customized through the `...`
#' inside the call and `+` afterwards. However, all the necessary data is
#' provided in the output of [single_mab_simulation()]; for extensive
#' customization or professional plots, it is highly recommended
#' to start completely from scratch and not use this method.
#'
#' The confidence intervals applied follow a standard normal distribution
#' because it is assumed the AIPW estimators are asymptotically normal as shown
#' in \href{https://www.pnas.org/doi/full/10.1073/pnas.2014602118}{Hadad et al. (2021)}.
#'
#' @references
#' Hadad, Vitor, David A. Hirshberg, Ruohan Zhan, Stefan Wager, and Susan Athey. 2021.
#' "Confidence Intervals for Policy Evaluation in Adaptive Experiments." \emph{Proceedings of the National Academy of Sciences of the United States of America} 118
#' (15): e2014602118. \doi{10.1073/pnas.2014602118}.
#' @export
#' @example inst/examples/plot.mab_example.R
#' @returns Minimal ggplot object, that can be customized and added to with `+` (to change `scales`, `labels`, `legend`, `theme`, etc.).

plot.mab <- function(x, type, level = .95, save = FALSE, path = NULL, ...) {
  rlang::check_installed("ggplot2")

  # Dispatch to the helper for the requested plot type; any other value
  # falls through to the abort. `level` is only used by the estimate plot.
  figure <- switch(
    type,
    "arm" = plot_arms(x = x, ...),
    "assign" = plot_assign(x = x, ...),
    "estimate" = plot_estimates(x = x, level = level, ...),
    rlang::abort("Invalid Type: Specify `arm`, `assign`, or `estimate`")
  )

  # NOTE(review): with `save = TRUE` and the default `path = NULL`, ggsave()
  # receives a NULL filename - confirm that the resulting error is acceptable.
  if (save) {
    ggplot2::ggsave(filename = path, plot = figure)
  }

  figure
}

#-------------------------------------------------------------------------------
#' @name plot_arms
#' @title Plot Treatment Arms Over Time
#' @description
#' Helper to [plot.mab()]; plots treatment arms over time.
#' @param x A `mab` object passed from [plot.mab()]
#' @inheritParams plot.mab
#' @returns Minimal ggplot object, that can be customized and added to with `+` (to change `scales`, `labels`, `legend`, `theme`, etc.).
#' @keywords internal
plot_arms <- function(x, ...) {
  rlang::check_installed("ggplot2")

  # Axis label and title depend on the algorithm; anything other than the two
  # handled values is rejected here instead of failing later on an undefined
  # label.
  text <- switch(
    x$settings$algorithm,
    "ucb1" = list(
      ylab = "UCB1 Values",
      title = "UCB1 Sampling Over Time"
    ),
    "thompson" = list(
      ylab = "Posterior Probability of Being Best Arm",
      title = "Thompson Sampling Over Time"
    ),
    rlang::abort(
      "Invalid Algorithm: `x$settings$algorithm` must be `ucb1` or `thompson`"
    )
  )

  # One row per (period, arm) so each arm can be drawn as its own line.
  long_values <- tidyr::pivot_longer(
    x$bandits,
    cols = -period_number,
    names_to = "condition",
    values_to = "probs"
  )

  # NOTE(review): the y axis is limited to [0, 1] for both algorithms; confirm
  # that UCB1 values cannot exceed 1, since ggplot2 drops out-of-range values.
  ggplot2::ggplot(
    long_values,
    ggplot2::aes(
      x = period_number,
      y = probs,
      color = condition
    )
  ) +
    ggplot2::geom_line(...) +
    ggplot2::scale_y_continuous(
      breaks = base::seq(0, 1, 0.1),
      limits = base::range(0, 1)
    ) +
    ggplot2::labs(
      x = "Assignment Period",
      y = text$ylab,
      title = text$title,
      color = "Treatment Arm"
    ) +
    ggplot2::theme_minimal()
}

#' @name plot_assign
#' @title Plot Cumulative Assignment Probability Over Time
#' @description
#' Helper to [plot.mab()]; plots, for each treatment arm, the running share of
#' all observations assigned to that arm by assignment period.
#' @param x A `mab` object passed from [plot.mab()]
#' @inheritParams plot.mab
#' @returns Minimal ggplot object, that can be customized and added to with `+` (to change `scales`, `labels`, `legend`, `theme`, etc.).
#' @keywords internal
plot_assign <- function(x, ...) {
  assignments <- x$final_data
  # Denominator for the shares: every row of the final data set.
  n_total <- nrow(assignments)

  # Count rows per (arm, period), convert to a share of all rows, then
  # accumulate within each arm.
  shares <- assignments |>
    dplyr::select(mab_condition, period_number) |>
    dplyr::arrange(period_number) |>
    dplyr::count(mab_condition, period_number) |>
    dplyr::mutate(n = n / n_total) |>
    dplyr::group_by(mab_condition) |>
    dplyr::mutate(cum_n = cumsum(n))

  mapping <- ggplot2::aes(
    x = period_number,
    y = cum_n,
    color = mab_condition
  )

  ggplot2::ggplot(shares, mapping) +
    ggplot2::geom_line(...) +
    ggplot2::labs(
      x = "Assignment Period",
      y = "Proportion of Data",
      title = "Cumulative Assignments Across Trial",
      color = "Treatment Arm"
    ) +
    ggplot2::scale_y_continuous(
      breaks = base::seq(0, 1, 0.1),
      limits = base::range(0, 1)
    ) +
    ggplot2::theme_minimal()
}

#' @name plot_estimates
#' @title Plot AIPW Estimates
#' @inheritParams plot.mab
#' @description
#' Helper to [plot.mab()]; draws a horizontal normal confidence interval
#' around the AIPW estimate of each treatment arm.
#' @returns Minimal ggplot object, that can be customized and added to with `+` (to change `scales`, `labels`, `legend`, `theme`, etc.).
#' @keywords internal
plot_estimates <- function(x, level = 0.95, ...) {
  rlang::check_installed("ggplot2")
  check_level(level)

  # Two-sided standard-normal critical value for the requested level.
  z_crit <- base::abs(stats::qnorm((1 - level) / 2))

  aipw <- dplyr::filter(x$estimates, estimator == "AIPW")

  # Interval bounds are mean +/- critical value * standard error; `...` is
  # forwarded to the error-bar geom.
  # NOTE(review): recent ggplot2 releases appear to deprecate
  # geom_errorbarh() - confirm against the minimum supported version.
  interval_layer <- ggplot2::geom_errorbarh(
    ggplot2::aes(
      xmin = mean - z_crit * sqrt(variance),
      xmax = mean + z_crit * sqrt(variance)
    ),
    ...
  )

  ggplot2::ggplot(aipw, ggplot2::aes(x = mean, y = mab_condition)) +
    interval_layer +
    ggplot2::labs(
      x = "Probability of Success (AIPW)",
      y = "Treatment Condition",
      title = "AIPW Estimated Success Probabilities"
    ) +
    ggplot2::theme_minimal()
}
#-------------------------------------------------------------------------------
#' Check Level
#' @description
#' Checking if the `level` argument in the S3 generic methods
#' is valid for a confidence interval.
#' @name check_level
#' @inheritParams plot.mab
#' @returns Throws an error if `level` is not a single, non-missing number
#' between 0 and 1 (inclusive); otherwise returns `NULL` invisibly.
#' @keywords internal
check_level <- function(level) {
  # `&&` short-circuits, so the range comparisons only ever see a single
  # non-missing number; NA and length > 1 input reach the abort below instead
  # of failing inside `if ()`.
  is_valid <- is.numeric(level) &&
    length(level) == 1L &&
    !is.na(level) &&
    level >= 0 &&
    level <= 1

  if (!is_valid) {
    rlang::abort(c(
      "`level` must be a number between 0 and 1",
      # toString() collapses vector input into one bullet.
      "x" = paste0("You passed: ", toString(level))
    ))
  }

  invisible(NULL)
}
