#' Simulate a single trajectory from an interval-censored multi-state model
#' with Weibull transition intensities
#' 
#' @description Simulate a single trajectory from a multi-state model quantified
#' by a transition matrix, with interval-censored transitions and Weibull 
#' distributed transition intensities. Allows for Weibull censoring in each of 
#' the states.
#' 
#' 
#' @param obstimes A numeric vector of times at which the subject is observed.
#' @param tmat A transition matrix as created by \code{\link[mstate:transMat]{transMat}}, 
#' with H rows and H columns indicating the states. The total number of possible 
#' transitions will be indicated by M.
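#' @param tmat2 A summary of the transitions in \code{tmat}, as created by 
#' \code{to.trans2(tmat)}, with (at least) the columns \code{from} and \code{to}. 
#' Row m must describe transition m, so that it matches element m of 
#' \code{shape} and \code{scale}.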
#' @param startstate The numeric starting state of the subject, can be chosen from 
#' any integer from 1 to H. By default, \code{startstate = 1}.
#' @param exact A numeric vector indicating which states are exactly observed. 
#' The transition time to exact states will be observed at exact times, regardless 
#' of the times in \code{obstimes}. No exact states if missing.
#' @param shape A numeric vector of length M indicating the shape of the Weibull 
#' transition intensity for the corresponding transition in \code{tmat}. See 
#' \code{help(dweibull)}.
#' @param scale A numeric vector of length M indicating the scale of the Weibull 
#' transition intensity for the corresponding transition in \code{tmat}. See 
#' \code{help(dweibull)}.
#' @param censshape A numeric vector of length H indicating the Weibull 
#' censoring shape in each of the states. If left missing, 
#' no censoring is applied.
#' @param censscale A numeric vector of length H indicating the Weibull censoring 
#' scale in each of the states. If left missing, no censoring is applied.
#' @param true_trajec Should the true (right-censored) trajectory be returned for
#' the subject as well? Default = \code{FALSE}.
#' 
#' 
#' @importFrom igraph is_dag
#' @importFrom mstate to.trans2 msfit probtrans
#' 
#' 
#' @details Suppose a subject arrives in state g at time s. If we wish to generate 
#' a survival time from that state according to a Weibull intensity in a clock forward 
#' model, we can apply the inverse transform method to the conditional cumulative Weibull intensity.
#' More specifically, letting \eqn{a}{a} denote the shape and \eqn{\sigma}{\sigma} denote the scale, 
#' the conditional survival function for \eqn{t > s}{t > s} is given by
#' \deqn{S(t|s) = \mathbf{P}(T \geq t | T \geq s) = \exp(\left( \frac{s}{\sigma} \right)^a - \left( \frac{t}{\sigma} \right)^a)}{S(t|s) = P(T >= t | T >= s) = exp((s/\sigma)^a - (t/\sigma)^a)}
#' The corresponding cumulative intensity is then given by:
#' \deqn{A(t|s) = -\log(S(t|s)) = \left( \frac{t}{\sigma} \right)^a - \left( \frac{s}{\sigma} \right)^a}{A(t|s) = - log(S(t|s)) = (t/\sigma)^a - (s/\sigma)^a}
#' And the inverse cumulative intensity is then:
#' \deqn{A^{-1}(t|s) = \sigma \sqrt[a]{t + \left( \frac{s}{\sigma} \right)^a}}{A^(-1)(t|s) = \sigma (t + (s/\sigma)^a)^(1/a)}
#' A conditional survival time is then generated by:
#' \deqn{T|s = A^{-1}(-\log(U)|s)}{T|s = A^(-1)(-log(U)|s)}
#' with \eqn{U}{U} a sample from the standard uniform distribution.
#' If we additionally have covariates (or frailties), the \eqn{-\log(U)}{-log(U)}
#' above should be replaced by \eqn{\frac{-\log(U)}{\exp(\beta X)}}{(-log(U))/(exp(beta X))}
#' with \eqn{\beta}{beta} and \eqn{X}{X} the coefficients and covariates respectively.
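#' Censoring times are drawn anew each time a state is entered and are measured 
#' from the time of entry into that state (i.e. on the sojourn-time scale).
#' Taking \code{censshape} to be 1 for all states, we obtain exponential 
#' censoring with rate 1/\code{censscale}.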
#' 
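#' @returns A matrix with 3 columns \code{time, state} and \code{cens}, indicating 
#' the observation time, the corresponding state and whether the subject has 
#' been censored in the state. Observations at and after the censoring time are 
#' removed. If \code{true_trajec = TRUE}, a list with elements 
#' \code{observed_trajectory} (this matrix) and \code{true_trajectory} (a matrix 
#' with columns \code{state, time} and \code{cens} holding the simulated, 
#' possibly right-censored, trajectory).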
#' 
#' @keywords internal
#' @noRd
#' 
#' @examples 
#' require(mstate)
#' gd <- sim_weibmsm(obstimes = seq(0, 20, 2), tmat = trans.illdeath(),
#'                   shape = c(1, 1, 1), scale = c(2, 10, 1))
#' 




sim1_weibmsm <- function(obstimes, tmat, tmat2, startstate = 1, exact, shape, scale, 
                        censshape, censscale, true_trajec = FALSE) {
  # Simulate one subject in two steps:
  #  1. the true (possibly right-censored) trajectory, from the start state at
  #     min(obstimes) until absorption, censoring, or passing max(obstimes);
  #  2. the panel observations of that trajectory at `obstimes`, plus the exact
  #     entry times into states listed in `exact`.
  #
  # No argument validation is performed here. Inputs are assumed consistent,
  # e.g. length(shape) == length(scale) == nrow(tmat2), censshape/censscale of
  # length nrow(tmat) and startstate in 1:nrow(tmat).

  # Argument post-processing ----------------------------------------------------

  # Censoring is only applied when both Weibull censoring parameters are given.
  cens_idx <- !missing(censshape) && !missing(censscale)
  has_exact <- !missing(exact)

  obstimes <- sort(unique(obstimes))
  min_obstime <- min(obstimes)
  max_obstime <- max(obstimes)

  # States whose row in tmat is entirely NA have no outgoing transitions.
  absorbing_states <- which(apply(is.na(tmat), 1, all))

  # True trajectory -------------------------------------------------------------

  # Kept as a list of vectors that are extended in place, which is cheaper than
  # repeatedly growing a matrix.
  current_time <- min_obstime
  current_state <- startstate
  current_cens <- 0
  true_trajectory <- list(state = current_state,
                          time = current_time,
                          cens = current_cens)
  time_idx <- 2

  while (current_time < max_obstime) {
    # Once absorbed or censored, the trajectory can no longer change.
    if (current_state %in% absorbing_states || current_cens != 0) {
      break
    }

    # Candidate transitions out of the current state (row numbers of tmat2).
    possible_transitions <- which(tmat2[, "from"] == current_state)

    # Inverse transform sampling: one latent (clock-forward) transition time
    # per candidate transition, conditional on occupying the state at
    # current_time. The earliest latent time wins.
    unif_outc <- runif(length(possible_transitions))
    possible_transition_times <- vapply(
      seq_along(possible_transitions),
      function(x) {
        inverse_weib_haz(t = -log(unif_outc[x]), s = current_time,
                         shape = shape[possible_transitions[x]],
                         scale = scale[possible_transitions[x]])
      },
      numeric(1)
    )
    which_min_transition_time <- which.min(possible_transition_times)
    min_transition_time <- possible_transition_times[which_min_transition_time]
    min_transition_number <- possible_transitions[which_min_transition_time]

    # Tentatively move to the sampled transition; overwritten below if
    # censoring occurs first.
    previous_time <- current_time
    previous_state <- current_state
    current_time <- min_transition_time
    current_state <- tmat2[min_transition_number, "to"]
    current_cens <- 0

    # Competing censoring in the state just left (only if it has censoring
    # parameters). The censoring time is drawn from time 0, i.e. on the
    # time-since-entry scale, so it is compared with the sojourn time rather
    # than with the chronological transition time.
    if (cens_idx &&
        !is.na(censshape[previous_state]) &&
        !is.na(censscale[previous_state])) {
      cens_time <- rweibull(1, shape = censshape[previous_state],
                            scale = censscale[previous_state])
      if (cens_time < min_transition_time - previous_time) {
        current_time <- previous_time + cens_time
        current_state <- previous_state
        current_cens <- 1
      }
    }

    # Entry into an exactly observed state is recorded at its exact time, as
    # long as it falls within the observation window.
    if (has_exact && current_cens == 0 &&
        current_state %in% exact && min_transition_time <= max_obstime) {
      obstimes <- sort(unique(c(obstimes, min_transition_time)))
    }

    true_trajectory[["time"]][time_idx] <- current_time
    true_trajectory[["state"]][time_idx] <- current_state
    true_trajectory[["cens"]][time_idx] <- current_cens
    time_idx <- time_idx + 1
  }

  # Columns (in list order): state, time, cens.
  true_trajectory <- matrix(unlist(true_trajectory), ncol = length(true_trajectory),
                            dimnames = list(c(), names(true_trajectory)))

  # Observed trajectory ---------------------------------------------------------

  # At each observation time the subject is in the state of the last true
  # trajectory entry at or before that time. Trajectory times are strictly
  # increasing and start at min(obstimes), so every index is >= 1.
  obs_idx <- findInterval(obstimes, true_trajectory[, "time"])
  observed_trajectory <- cbind(time = obstimes,
                               state = true_trajectory[obs_idx, "state"],
                               cens = true_trajectory[obs_idx, "cens"])

  # A censored subject is no longer observed: drop the first observation at
  # which censoring is seen and everything after it. Without censoring, the
  # cens column is all 0 and nothing is dropped.
  first_cens_obs <- match(1, observed_trajectory[, "cens"])
  if (!is.na(first_cens_obs)) {
    observed_trajectory <- observed_trajectory[seq_len(first_cens_obs - 1L), ,
                                               drop = FALSE]
  }

  if (true_trajec) {
    list(observed_trajectory = observed_trajectory,
         true_trajectory = true_trajectory)
  } else {
    observed_trajectory
  }
}



