#' Evaluate Strategy
#' 
#' Given an unevaluated strategy, an initial number of 
#' individuals and a number of cycles to compute, returns
#' the evaluated version of the objects and the count of 
#' individuals per state per model cycle.
#' 
#' `init` need not be integer. E.g. `c(A = 1, B = 0.5, C =
#' 0.1, ...)`.
#' 
#' If `state_time` is used in the transition object or in
#' the state values, the concerned states are expanded
#' (according to `expand_limit`) before evaluation. Values
#' are computed on the expanded states. In the returned
#' `counts`, the sub-state columns of each expanded state
#' are summed back into a single column named after the
#' original state. `state_time` cannot be combined with a
#' partitioned survival object.
#' 
#' @param strategy An `uneval_strategy` object.
#' @param parameters Optional. An object generated by 
#'   [define_parameters()].
#' @param cycles positive integer. Number of Markov Cycles 
#'   to compute.
#' @param init numeric vector, same length as number of 
#'   model states. Number of individuals in each model state
#'   at the beginning.
#' @param method Counting method.
#' @param expand_limit A named vector of state expansion 
#'   limits.
#' @param inflow Numeric vector, similar to `init`. Number
#'   of new individuals in each state per cycle.
#' @param strategy_name Name of the strategy.
#'   
#' @return An `eval_strategy` object (actually a list of 
#'   evaluated parameters, matrix, states and cycles 
#'   counts).
#'   
#' @example inst/examples/example_eval_strategy.R
#'   
#' @keywords internal
eval_strategy <- function(strategy, parameters, cycles, 
                          init, method, expand_limit,
                          inflow, strategy_name) {
  
  # Check the length first, so that a vector `cycles` is
  # rejected before being compared element-wise.
  stopifnot(
    length(cycles) == 1,
    cycles > 0
  )
  
  uneval_transition <- get_transition(strategy)
  uneval_states <- get_states(strategy)
  
  # Interpolate the parameter definitions into the transition
  # and state expressions; 'state_time' detection below is
  # run on these interpolated versions.
  i_parameters <- interp_heemod(parameters)
  
  i_uneval_transition <- interp_heemod(
    uneval_transition,
    more = as_expr_list(i_parameters)
  )
  
  i_uneval_states <- interp_heemod(
    uneval_states,
    more = as_expr_list(i_parameters)
  )
  
  td_tm <- has_state_time(i_uneval_transition)
  td_st <- has_state_time(i_uneval_states)
  
  # States are expanded only if 'state_time' is used
  # somewhere in the transition object or the state values.
  expand <- any(c(td_tm, td_st))
  
  if (expand) {
    
    if (inherits(uneval_transition, "part_surv")) {
      stop("Cannot use 'state_time' with partitionned survival.")
    }
    
    uneval_transition <- i_uneval_transition
    uneval_states <- i_uneval_states
    
    # parameters not needed anymore because of interp
    parameters <- define_parameters()
    
    # from cells to rows: one flag per matrix row (state),
    # TRUE if any cell of that row uses 'state_time'
    td_tm <- td_tm %>% 
      matrix(
        nrow = get_matrix_order(uneval_transition), 
        byrow = TRUE
      ) %>% 
      apply(1, any)
    
    to_expand <- sort(unique(c(
      get_state_names(uneval_transition)[td_tm],
      get_state_names(uneval_states)[td_st]
    )))
    
    message(sprintf(
      "%s: detected use of 'state_time', expanding state%s: %s.",
      strategy_name,
      plur(length(to_expand)),
      paste(to_expand, collapse = ", ")
    ))
    
    for (st in to_expand) {
      init <- expand_state(
        init, state_name = st, cycles = expand_limit[st]
      )
      
      inflow <- expand_state(
        inflow, state_name = st, cycles = expand_limit[st]
      )
    }
    
    for (st in to_expand) {
      uneval_transition <- expand_state(
        x = uneval_transition,
        state_pos = which(get_state_names(uneval_transition) == st),
        state_name = st,
        cycles = expand_limit[st]
      )
      
      uneval_states <- expand_state(
        x = uneval_states,
        state_name = st,
        cycles = expand_limit[st]
      )
    }
  }
  
  parameters <- eval_parameters(parameters,
                                cycles = cycles,
                                strategy_name = strategy_name)
  
  e_init <- unlist(eval_init(x = init, parameters[1, ]))
  e_inflow <- eval_inflow(x = inflow, parameters)
  
  if (any(is.na(e_init)) || any(is.na(e_inflow))) {
    stop("Missing values not allowed in 'init' or 'inflow'.")
  }
  
  if (! any(e_init > 0)) {
    stop("At least one init count must be > 0.")
  }
  
  states <- eval_state_list(uneval_states, parameters)
  
  transition <- eval_transition(uneval_transition,
                                parameters)
  
  count_table <- compute_counts(
    x = transition,
    init = e_init,
    inflow = e_inflow
  ) %>% 
    correct_counts(method = method)
  
  # Values are computed on the (possibly expanded) counts,
  # before the sub-states are collapsed below.
  values <- compute_values(states, count_table)
  
  if (expand) {
    for (st in to_expand) {
      # Sub-state columns '.<state>_1' ... '.<state>_<limit + 1>'
      exp_cols <- sprintf(".%s_%i", st, seq_len(expand_limit[st] + 1))
      
      count_table[[st]] <- rowSums(count_table[exp_cols])
      # Logical indexing rather than `-which()`: `-which()` on
      # an empty match would select no columns at all.
      count_table <- count_table[!(names(count_table) %in% exp_cols)]
    }
  }
  
  structure(
    list(
      parameters = parameters,
      transition = transition,
      states = states,
      counts = count_table,
      values = values,
      e_init = e_init,
      e_inflow = e_inflow,
      n_indiv = sum(e_init, unlist(e_inflow)),
      cycles = cycles,
      expand_limit = expand_limit
    ),
    class = "eval_strategy"
  )
}

# Generic accessor for the evaluated initial counts of a
# strategy.
get_eval_init <- function(x) UseMethod("get_eval_init")

# `eval_strategy` objects store these counts in their
# `e_init` element.
get_eval_init.eval_strategy <- function(x) x[["e_init"]]

# Generic accessor for the evaluated inflow of a strategy.
get_eval_inflow <- function(x) UseMethod("get_eval_inflow")

# `eval_strategy` objects store the inflow in their
# `e_inflow` element.
get_eval_inflow.eval_strategy <- function(x) x[["e_inflow"]]

# Generic accessor for the total number of individuals of a
# strategy.
get_n_indiv <- function(x) UseMethod("get_n_indiv")

# `eval_strategy` objects store this total in their
# `n_indiv` element.
get_n_indiv.eval_strategy <- function(x) x[["n_indiv"]]

#' Compute Count of Individual in Each State per Cycle
#' 
#' Given an initial number of individuals and an evaluated 
#' transition object, returns the number of individuals per 
#' state per cycle.
#' 
#' This function does not take a counting `method`: in
#' `eval_strategy()` its result is passed to
#' `correct_counts(method = method)`, where the counting
#' method is applied.
#' 
#' @param x An `eval_matrix` or
#'   `eval_part_surv` object.
#' @param ... Arguments passed to methods. The
#'   `eval_matrix` method takes `init` (numeric vector, same
#'   length as the number of model states: number of
#'   individuals in each state at the beginning) and
#'   `inflow` (similar to `init`: number of new individuals
#'   in each state per cycle).
#'   
#' @return A `cycle_counts` object.
#'   
#' @keywords internal
compute_counts <- function(x, ...) {
  UseMethod("compute_counts")
}

#' @export
compute_counts.eval_matrix <- function(x, init,
                                       inflow,
                                       ...) {
  # Propagate counts through the elements of `x` (one
  # transition matrix per element). Starting from `init`,
  # each step adds row k of `inflow` to the current counts
  # and multiplies the result by x[[k]]. The result has
  # length(x) + 1 rows (the initial counts first) and one
  # column per state.
  n_states <- get_matrix_order(x)
  
  if (length(init) != n_states) {
    stop(sprintf(
      "Length of 'init' vector (%i) differs from the number of states (%i).",
      length(init),
      n_states
    ))
  }
  
  if (length(inflow) != n_states) {
    stop(sprintf(
      "Length of 'inflow' vector (%i) differs from the number of states (%i).",
      length(inflow),
      n_states
    ))
  }
  
  # Explicit preallocated accumulation. This replaces the
  # former `Reduce()` callback that advanced a counter with
  # `<<-`; the semantics are the same.
  n_steps <- length(x)
  list_counts <- vector("list", n_steps + 1L)
  list_counts[[1L]] <- init
  for (k in seq_len(n_steps)) {
    list_counts[[k + 1L]] <-
      (list_counts[[k]] + unlist(inflow[k, ])) %*% x[[k]]
  }
  
  # `dplyr::as.tbl()` is deprecated since dplyr 1.0.0;
  # `as_tibble()` is its documented replacement.
  res <- dplyr::as_tibble(
    as.data.frame(
      matrix(
        unlist(list_counts),
        byrow = TRUE,
        ncol = n_states
      )
    )
  )
  
  colnames(res) <- get_state_names(x)
  
  structure(res, class = c("cycle_counts", class(res)))
}

#' Compute State Values per Cycle
#' 
#' Weights each state's values by the number of individuals
#' in that state and sums over states, giving the total
#' value of each state value for every cycle.
#' 
#' @param states An object of class `eval_state_list`.
#' @param counts An object of class `cycle_counts`.
#'   
#' @return A data.frame with a `markov_cycle` column followed
#'   by one column per state value, and one row per cycle.
#'   
#' @keywords internal
compute_values <- function(states, counts) {
  st_names <- get_state_names(states)
  val_names <- get_state_value_names(states)
  n_cycle <- nrow(counts)
  n_val <- length(val_names)
  n_state <- length(st_names)
  
  # Stack every state table (its columns include
  # `markov_cycle` plus the values) into a
  # cycle x column x state array, then drop `markov_cycle`.
  mc_pos <- match("markov_cycle", names(states[[1]]))
  value_cube <- array(
    unlist(states),
    dim = c(n_cycle, n_val + 1, n_state)
  )
  value_cube <- value_cube[, -mc_pos, , drop = FALSE]
  
  # Counts laid out as cycle x state, recycled along a third
  # (value) axis, then permuted to cycle x value x state so
  # that they align element-wise with `value_cube`.
  count_cube <- aperm(
    array(
      unlist(counts[, st_names]),
      dim = c(n_cycle, n_state, n_val)
    ),
    c(1, 3, 2)
  )
  
  # Weight by counts and sum over the state axis.
  totals <- rowSums(value_cube * count_cube, dims = 2)
  out <- data.frame(markov_cycle = states[[1]]$markov_cycle, totals)
  names(out)[-1] <- val_names
  out
}
