#' color_palette
#'
#' A pre-loaded color palette
#'
#' Returns a fixed set of twelve colors that can be passed to the charting
#' functions of this package (it is, for instance, the default of the
#' \code{colors} argument of \code{decomp_chart}).
#'
#' @export
#' @return character vector of colors in hexadecimal notation
color_palette = function() {
  # twelve hex codes, always returned in this exact order
  hex_codes = c(
    "#636363", "#143fba", "#006a53", "#209162",
    "#33eea0", "#8f9d35", "#fff200", "#ff5c00",
    "#6a000d", "#ff006e", "#b400da", "#280030"
  )
  hex_codes
}

#' decomping
#'
#' Variable decomposition of linear regression
#'
#' Calculates the decomposition of the independent variables based on an input model object.
#' This can be expanded by leveraging id variables (e.g. date) and categories (i.e. groups of variables).
#'
#' @export
#' @param model Model object of class \code{lm} (e.g. as returned by \code{run_model}). Besides the
#'   standard \code{lm} elements it must carry \code{model_table}, \code{raw_data}, \code{dv},
#'   \code{id_var} and \code{pool_var} (plus \code{trans_df} when \code{tail_window} is used).
#' @param de_normalise A boolean to specify whether to multiply contributions and fitted values
#'   by the mean of the raw dependent variable within each pool.
#' @param categories \code{data.frame} mapping variables to groups. It must contain the columns
#'   \code{variable} and \code{category}, and may contain \code{calc} (\code{"none"}, \code{"min"}
#'   or \code{"max"}). When missing or invalid, the categories of \code{model$model_table} are used.
#' @param tail_window for time series, number of additional id steps (at the most common id
#'   interval) to extend the decomposition by. Must be a single number >= 1, otherwise it is ignored.
#' @param verbose A boolean to specify whether to print warnings
#' @import tidyverse
#' @importFrom stats complete.cases
#' @importFrom methods is
#' @importFrom tidyr pivot_longer
#' @importFrom magrittr '%>%'
#' @importFrom reshape2 melt acast
#' @return a \code{list} of 3 \code{data.frame}'s representing the variable and category decomposition, and the fitted values.
#'   \code{NULL} is returned when the model is missing or incomplete.
#' @examples
#' run_model(data = mtcars,dv = 'mpg',ivs = c('wt','cyl','disp'),decompose=FALSE) %>% decomping()
decomping = function(model = NULL,
                     de_normalise = TRUE,
                     categories = NULL,
                     tail_window = NULL,
                     verbose = FALSE){

  # checks  ####

  # verbose must be a single TRUE/FALSE, otherwise `if (verbose)` fails
  if (!isTRUE(verbose) && !isFALSE(verbose)) {
    message("Warning: verbose provided must be logical (TRUE or FALSE). Setting to False.")
    verbose = FALSE
  }

  # check that a model is provided and that it is an lm object
  if (is.null(model)) {
    if (verbose) {
      message("Error: No model provided. Returning NULL.")
    }
    return(NULL)
  }
  if (!is(model, 'lm')) {
    if (verbose) {
      message("Error: model must be of type 'lm'. Returning NULL.")
    }
    return(NULL)
  }

  # the model_table drives the fixed (offset) variables, the tail window and
  # the default categories, so it is required
  # TODO: function for checking model object to use in all functions
  model_table = model$model_table
  if (is.null(model_table)) {
    if (verbose) {
      message("Error: model must contain a 'model_table'. Returning NULL.")
    }
    return(NULL)
  }

  # raw_data supplies the id and pool values of every observation
  raw_data = model$raw_data
  if (is.null(raw_data)) {
    if (verbose) {
      message("Error: model must contain 'raw_data'. Returning NULL.")
    }
    return(NULL)
  }

  if (!isTRUE(de_normalise) && !isFALSE(de_normalise)) {
    if (verbose) {
      message("Warning: de_normalise provided must be of type logical. Setting de_normalise to FALSE.")
    }
    de_normalise = FALSE
  }

  # a tail window must be a single number >= 1, otherwise seq() below fails
  if (!is.null(tail_window) &&
      !(is.numeric(tail_window) && length(tail_window) == 1 &&
        !is.na(tail_window) && tail_window >= 1)) {
    if (verbose) {
      message("Warning: tail_window must be a single number >= 1. Ignoring tail window.")
    }
    tail_window = NULL
  }

  # drop rows of raw_data containing NAs, warning only if there are any
  na_rows = !complete.cases(raw_data)
  if (any(na_rows)) {
    if (verbose) {
      message("Warning: NA's found in raw data will be dropped.")
    }
    raw_data = raw_data[!na_rows, , drop = FALSE]
  }

  # model pieces ####

  dv = model$dv
  id_var = model$id_var
  pool_var = model$pool_var

  # modelled data (dv first, then the transformed ivs) and raw values
  data = model$model
  actual = data %>% pull(!!sym(dv))
  raw_actual = raw_data %>% pull(!!sym(dv))
  id_var_values = raw_data %>% pull(!!sym(id_var))
  pool_var_values = raw_data %>% pull(!!sym(pool_var))

  # split the intercept from the remaining coefficients
  coef = model$coefficients
  intercept = coef[1]
  coef = coef[-1]

  # fixed (offset) variables enter the decomposition with a coefficient of 1
  fixed_flags = model_table$fixed
  fixed_vars = model_table$variable_t[!is.na(fixed_flags) & fixed_flags != '']
  if (length(fixed_vars) > 0) {
    fixed_coefs = rep(1, length(fixed_vars))
    names(fixed_coefs) = fixed_vars
    coef = c(coef, fixed_coefs)
  }

  # process ####

  # long table of actual, residual and predicted values per id and pool
  fitted_values = tibble(
    actual = c(actual),
    residual = model$residuals,
    predicted = model$fitted.values,
    id = id_var_values,
    pool = pool_var_values
  ) %>%
    pivot_longer(
      cols = c('actual', 'residual', 'predicted'),
      names_to = 'variable',
      values_to = 'value'
    ) %>%
    arrange(variable, id)

  if (!is.null(tail_window)) {
    # extend the id by tail_window steps of the most common (modal) id interval
    # WARNING: assumes every pool ends on the same id
    unique_ids = sort(unique(id_var_values))
    steps = diff(unique_ids)
    unique_steps = unique(steps)
    interval = unique_steps[which.max(tabulate(match(steps, unique_steps)))]
    id_ext = seq(max(unique_ids) + interval,
                 max(unique_ids) + (interval * tail_window),
                 interval)

    # blank rows (raw ivs set to 0) for every extended id x pool combination
    df_ext = expand.grid(id_ext, unique(pool_var_values), stringsAsFactors = FALSE)
    colnames(df_ext) = c(id_var, pool_var)
    raw_ivs = unique(model_table$variable)
    raw_ivs = raw_ivs[raw_ivs != '']
    df_ext[raw_ivs] = 0

    # re-apply the model's transformations to the raw data plus the blank rows
    transformed = apply_transformation(
      raw_data = raw_data %>%
        bind_rows(df_ext) %>%
        select(all_of(c(id_var, pool_var, raw_ivs))),
      model_table = model_table,
      trans_df = model$trans_df,
      pool_var = pool_var,
      verbose = verbose
    )

    decomp_ids = transformed %>% pull(!!sym(id_var))
    decomp_pools = transformed %>% pull(!!sym(pool_var))
    independent_variables = transformed %>% select(all_of(names(coef)))
  } else {
    decomp_ids = id_var_values
    decomp_pools = pool_var_values
    # every column after the dv holds an independent variable
    independent_variables = data[, -1]
  }

  # one contribution column per coefficient plus the intercept, then long format
  variable_decomp = tibble(
    "(Intercept)" = intercept,
    .decomping_contributions(independent_variables, coef),
    id = decomp_ids,
    pool = decomp_pools
  ) %>%
    pivot_longer(
      cols = all_of(c('(Intercept)', names(coef))),
      names_to = 'variable',
      values_to = 'contrib'
    ) %>%
    arrange(variable, id)

  # expose the id column under the model's id variable name
  names(fitted_values)[names(fitted_values) == "id"] = id_var
  names(variable_decomp)[names(variable_decomp) == "id"] = id_var

  # scale contributions and fitted values back by each pool's raw dv mean
  if (de_normalise) {
    pool_mean = tibble(raw_actual = raw_actual, pool = pool_var_values) %>%
      group_by(pool) %>%
      summarise(pool_mean = mean(raw_actual), .groups = "drop")

    variable_decomp = variable_decomp %>%
      left_join(pool_mean, by = "pool") %>%
      mutate(contrib = contrib * pool_mean) %>%
      select(-pool_mean)

    fitted_values = fitted_values %>%
      left_join(pool_mean, by = "pool") %>%
      mutate(value = value * pool_mean) %>%
      select(-pool_mean)
  }

  # categories ####

  categories = .decomping_resolve_categories(categories, model_table, verbose)
  if (is.null(categories)) {
    category_decomp = variable_decomp
  } else {
    category_decomp = .decomping_category_decomp(variable_decomp, categories, id_var, verbose)
  }

  # return a list of category, variable tables, and fitted values
  list(
    category_decomp = category_decomp,
    variable_decomp = variable_decomp,
    fitted_values = fitted_values
  )
}

# Multiply each independent variable column by its coefficient (matched by position).
.decomping_contributions = function(independent_variables, coef) {
  if (length(coef) == 1) {
    contributions = data.frame(independent_variables * coef)
  } else {
    contributions = data.frame(mapply(
      FUN = `*`,
      independent_variables,
      coef,
      SIMPLIFY = FALSE
    ))
  }
  # data.frame() may mangle names such as "(Intercept)"-like strings; restore them
  colnames(contributions) = names(coef)
  contributions
}

# Return a valid categories table, one built from model_table, or NULL when none is available.
.decomping_resolve_categories = function(categories, model_table, verbose) {
  if (is.data.frame(categories) &&
      all(c('variable', 'category') %in% colnames(categories))) {
    return(categories)
  }

  # NA or "" categories in model_table count as "no category"
  model_categories = model_table$category
  has_model_categories = any(!is.na(model_categories) & model_categories != "")

  if (is.null(categories)) {
    if (!has_model_categories && verbose) {
      message("Warning: No categories table provided and no categories found in model_table. Setting category_decomp = variable_decomp.")
    }
  } else if (!is.data.frame(categories)) {
    if (verbose) {
      if (has_model_categories) {
        message("Warning: categories provided must be of type data.frame. Using model_table categories.")
      } else {
        message("Warning: categories table provided is not a data.frame and no categories found in model_table. Setting category_decomp = variable_decomp.")
      }
    }
  } else {
    if (verbose) {
      message("Warning: categories provided must contain at least columns 'variable' and 'category'.")
      if (has_model_categories) {
        message("Warning: Using model_table categories.")
      } else {
        message("Warning: No categories found in model_table. Setting category_decomp = variable_decomp.")
      }
    }
  }

  if (!has_model_categories) {
    return(NULL)
  }

  # create categories from the model table, blank categories become "Other"
  model_table %>%
    select(variable, category) %>%
    mutate(category = if_else(category == '', 'Other', category),
           calc = 'none')
}

# Sum the variable decomposition into categories and apply the 'min'/'max' calcs.
.decomping_category_decomp = function(variable_decomp, categories, id_var, verbose) {
  if (!('calc' %in% colnames(categories))) {
    if (verbose) {
      message("Warning: categories type data.frame provided does not include a 'calc' column. Setting all categories to 'none'.")
    }
    categories$calc = 'none'
  }

  category_decomp = variable_decomp %>%
    left_join(categories, by = "variable") %>%
    mutate(
      # unmapped variables go to "Other", the intercept always goes to "Base"
      category = if_else(is.na(category), "Other", category),
      category = if_else(variable == "(Intercept)", "Base", category)
    ) %>%
    group_by(!!sym(id_var), category, pool) %>%
    summarise(contrib = sum(contrib), .groups = "drop") %>%
    rename(variable = category)

  base_rows = category_decomp$variable == "Base"
  base_value = category_decomp$contrib[base_rows]

  # "min" categories are shifted so their minimum is 0 and "max" categories so
  # their maximum is 0; the amount shifted is added to every Base row
  for (calc_type in c("min", "max")) {
    calc_categories = unique(categories$category[categories$calc %in% calc_type])
    for (cat in calc_categories) {
      rows = category_decomp$variable == cat
      # a category with no rows would give min/max of nothing (+/-Inf)
      if (!any(rows)) next
      cat_val = category_decomp$contrib[rows]
      shift = if (calc_type == "min") min(cat_val) else max(cat_val)
      category_decomp$contrib[rows] = cat_val - shift
      base_value = base_value + shift
    }
  }

  category_decomp$contrib[base_rows] = base_value
  category_decomp
}

#' decomp_chart
#'
#' Variable Decomposition Bar Chart
#'
#' Plot the variable, or category, decomposition as stacked bars over the id variable which can be supplied to the \code{decomping} function.
#' The actual and residual series of the fitted values are overlaid as lines.
#'
#' @param model Model object. When it contains a \code{decomp_list}, that is used in place of \code{decomp_list}.
#' @param decomp_list list object generated by the \code{decomping} function.
#' @param pool string specifying a group within the pool column to be filtered. When \code{NULL}, or not found, all pools are summed by id.
#' @param colors character vector of colors in hexadecimal notation
#' @param variable_decomp boolean specifying whether the chart should be based on the variable_decomp or the category_decomp from the \code{decomping} function.
#' @param verbose A boolean to specify whether to print warnings
#' @import plotly
#' @import tidyverse
#' @export
#' @return a \code{plotly} bar chart of the model's decomposition, or \code{NULL} when the inputs are invalid
decomp_chart = function(model = NULL,
                        decomp_list = NULL,
                        pool = NULL,
                        colors = color_palette(),
                        variable_decomp = FALSE,
                        verbose = FALSE) {

  # checks    ####

  # verbose must be a single TRUE/FALSE
  if (!isTRUE(verbose) && !isFALSE(verbose)) {
    message("Warning: verbose must be logical (TRUE or FALSE). Setting to False.")
    verbose = FALSE
  }

  # prefer the model's decomp_list, fall back to the one supplied directly
  if (!is.null(model) && !is.null(model$decomp_list)) {
    decomp_list = model$decomp_list
  }
  if (is.null(decomp_list)) {
    if (verbose) {
      message("Error: No decomp_list provided. Returning NULL. ")
    }
    return(NULL)
  }

  # pick the decomposition table to plot
  if (isTRUE(variable_decomp)) {
    decomp = decomp_list$variable_decomp
    title = 'Variable Decomposition'
  } else {
    decomp = decomp_list$category_decomp
    title = 'Category Decomposition'
  }

  if (!is.data.frame(decomp) ||
      !all(c("pool", "variable", "contrib") %in% colnames(decomp))) {
    if (verbose) {
      message("Error: decomp table must be a data.frame including 3 columns called 'pool', 'variable' and 'contrib'. Returning NULL. ")
    }
    return(NULL)
  }

  # actual / predicted / residual table
  fitted_values = decomp_list$fitted_values
  if (!is.data.frame(fitted_values) ||
      !all(c("pool", "variable", "value") %in% colnames(fitted_values))) {
    if (verbose) {
      message("Error: fitted_values table must include 3 columns called 'pool', 'variable' and 'value'. Returning NULL.")
    }
    return(NULL)
  }
  fitted_values = fitted_values[fitted_values$variable %in% c("actual", "predicted", "residual"), ]

  # the id variable name is the first column name
  id_var = colnames(decomp)[1]

  # filter by pool if it exists, otherwise sum all pools per id so that the
  # lines do not zig-zag between pools
  sum_pools = is.null(pool)
  if (!sum_pools && (length(pool) != 1 || !(pool %in% decomp$pool))) {
    if (verbose) {
      message("Warning: POOL ", pool, " not found. No POOL filtering applied; aggregating by id_var.")
    }
    sum_pools = TRUE
  } else if (sum_pools && verbose) {
    message("Warning: No pool provided. Aggregating by id_var.")
  }

  if (sum_pools) {
    fitted_values = fitted_values %>%
      group_by(variable, !!sym(id_var)) %>%
      summarise(value = sum(value), .groups = "drop")

    decomp = decomp %>%
      group_by(variable, !!sym(id_var)) %>%
      summarise(value = sum(contrib), .groups = "drop")
  } else {
    decomp = decomp[decomp$pool %in% pool, ] %>%
      rename(value = contrib)
    fitted_values = fitted_values[fitted_values$pool %in% pool, ]
  }

  # plot      ####

  plot_ly(data = decomp,
          x = ~ get(id_var)) %>%
    add_trace(type = "bar",
              y = ~ value,
              color = ~ variable,
              name = ~ variable,
              colors = colors) %>%
    add_lines(
      data = fitted_values %>%
        filter(variable == "actual"),
      x = ~ get(id_var),
      y = ~ value,
      line = list(color = c("black")),
      name = "actual"
    ) %>%
    add_lines(
      data = fitted_values %>%
        filter(variable == "residual"),
      x = ~ get(id_var),
      y = ~ value,
      line = list(color = c("red")),
      name = ~ variable
    ) %>%
    layout(barmode = "relative",
           plot_bgcolor  = "rgba(0, 0, 0, 0)",
           paper_bgcolor = "rgba(0, 0, 0, 0)",
           title = title,
           font = list(color = '#1c0022'),
           xaxis = list(title = id_var,
                        showgrid = FALSE,
                        zerolinecolor = "#1c0022"))
}


#' fit_chart
#'
#' Dependent Variable, Predictions and Residuals Line Chart
#'
#' Plot the dependent variable, predictions and Residuals as a line chart over the id variable which can be supplied to the \code{decomping} function.
#'
#' @param model Model object. When it contains a \code{decomp_list}, that is used in place of \code{decomp_list}.
#' @param decomp_list list object generated by the \code{decomping} function.
#' @param pool string specifying a group within the pool column to be filtered. When \code{NULL}, or not found, all pools are summed by id.
#' @param colors character vector of up to 3 colors in hexadecimal notation, used in order for the residual, actual and predicted lines. Missing entries fall back to the defaults.
#' @param verbose A boolean to specify whether to print warnings
#' @import plotly
#' @import tidyverse
#' @export
#' @return a \code{plotly} line chart of the model's prediction and actual, or \code{NULL} when the inputs are invalid
#' @examples
#' run_model(data = mtcars,dv = 'mpg',ivs = 'cyl') %>% fit_chart()
fit_chart = function(model = NULL,
                     decomp_list = NULL,
                     pool = NULL,
                     verbose = FALSE,
                     colors = NULL) {

  # checks  #####

  # verbose must be a single TRUE/FALSE
  if (!isTRUE(verbose) && !isFALSE(verbose)) {
    message("Warning: verbose must be logical (TRUE or FALSE). Setting to False.")
    verbose = FALSE
  }

  # prefer the model's decomp_list, fall back to the one supplied directly
  if (!is.null(model) && !is.null(model$decomp_list)) {
    decomp_list = model$decomp_list
  }
  if (is.null(decomp_list)) {
    if (verbose) {
      message("Error: No decomp_list provided. Returning NULL.")
    }
    return(NULL)
  }

  # actual / predicted / residual table
  fitted_values = decomp_list$fitted_values
  if (!is.data.frame(fitted_values) ||
      !all(c("pool", "variable", "value") %in% colnames(fitted_values))) {
    if (verbose) {
      message("Error: fitted_values table must include 3 columns called 'pool', 'variable' and 'value'. Returning NULL.")
    }
    return(NULL)
  }

  # the id variable name is the first column name
  id_var = colnames(fitted_values)[1]

  # filter by pool if it exists, otherwise sum all pools per id so that the
  # lines do not zig-zag between pools
  sum_pools = is.null(pool)
  if (!sum_pools && (length(pool) != 1 || !(pool %in% fitted_values$pool))) {
    if (verbose) {
      message("Warning: POOL ", pool, " not found. No POOL filtering applied; aggregating by id_var.")
    }
    sum_pools = TRUE
  } else if (sum_pools && verbose) {
    message("Warning: No pool provided. Aggregating by id_var.")
  }

  if (sum_pools) {
    fitted_values = fitted_values %>%
      group_by(variable, !!sym(id_var)) %>%
      summarise(value = sum(value), .groups = "drop")
  } else {
    fitted_values = fitted_values[fitted_values$pool %in% pool, ]
  }

  # colors: residual, actual, predicted; user colors override the defaults
  line_colors = c("#c90f3a", "#1c0022", "#00cf74")
  n_custom = min(length(colors), 3)
  line_colors[seq_len(n_custom)] = colors[seq_len(n_custom)]

  # plot    ####

  plot_ly(fitted_values) %>%
    add_lines(
      data = filter(fitted_values, variable == "residual"),
      x = ~ get(id_var),
      y = ~ value,
      line = list(color = line_colors[1]),
      color = ~ variable
    ) %>%
    add_lines(
      data = filter(fitted_values, variable == "actual"),
      x = ~ get(id_var),
      y = ~ value,
      line = list(color = line_colors[2]),
      color = ~ variable
    ) %>%
    add_lines(
      data = filter(fitted_values, variable == "predicted"),
      x = ~ get(id_var),
      y = ~ value,
      line = list(color = line_colors[3]),
      color = ~ variable
    ) %>%
    layout(
      plot_bgcolor  = "rgba(0, 0, 0, 0)",
      paper_bgcolor = "rgba(0, 0, 0, 0)",
      title = 'Fit Chart',
      font = list(color = '#1c0022'),
      xaxis = list(
        showgrid = FALSE,
        zerolinecolor = "#1c0022",
        title = id_var
      )
    )
}

#' add_total_pool
#'
#' Add an aggregated decomposition.
#'
#' When running a pooled model, it might be desirable to view the output of the \code{decomping()} function as an aggregate of all pools.
#' Rows summed over all pools, per id and variable, are appended with \code{pool = 'Total'}.
#'
#' @param model Model object. When it contains a \code{decomp_list}, that is used in place of \code{decomp_list}.
#' @param decomp_list list object generated by the \code{decomping} function.
#' @param verbose A boolean to specify whether to print warnings
#' @import dplyr
#' @export
#' @return a \code{list} of 3 \code{data.frame}'s representing the variable and category decomposition, and the fitted values,
#'   or \code{NULL} when no decomposition is available.
add_total_pool = function(
    model = NULL,
    decomp_list = NULL,
    verbose = FALSE){
  # checks ------------------

  # verbose must be a single TRUE/FALSE
  if (!isTRUE(verbose) && !isFALSE(verbose)) {
    message("Warning: verbose must be logical (TRUE or FALSE). Setting to False.")
    verbose = FALSE
  }

  # check inputs
  if (is.null(model) && is.null(decomp_list)) {
    message('Error: Neither "model" nor "decomp_list" were provided. Returning NULL')
    return(NULL)
  }

  # prefer the model's decomp_list, fall back to the one supplied directly
  if (!is.null(model)) {
    if (!is.null(model$decomp_list)) {
      if (verbose && !is.null(decomp_list)) {
        message('Warning: Both "model" and "decomp_list" have been provided. "model" will be used and "decomp_list" will be ignored.')
      }
      decomp_list = model$decomp_list
    } else if (verbose && !is.null(decomp_list)) {
      message('Warning: "model" does not contain "decomp_list". Using "decomp_list" provided.')
    }
  }
  if (is.null(decomp_list)) {
    message('Error: "model" does not contain "decomp_list" and no "decomp_list" was provided. Returning NULL')
    return(NULL)
  }

  # process -----------------------------------------------------------------

  # the id variable name is the first column of the decomposition tables
  id = colnames(decomp_list$variable_decomp)[1]

  # append rows summing value_col over all pools for each id and variable
  append_total = function(df, value_col) {
    totals = df %>%
      group_by(!!sym(id), variable) %>%
      summarise(across(all_of(value_col), sum), .groups = "drop") %>%
      mutate(pool = 'Total')
    bind_rows(df, totals)
  }

  decomp_list$variable_decomp = append_total(decomp_list$variable_decomp, "contrib")
  decomp_list$category_decomp = append_total(decomp_list$category_decomp, "contrib")
  decomp_list$fitted_values = append_total(decomp_list$fitted_values, "value")

  decomp_list
}

#' add_total_pool_to_data
#'
#' Add an aggregated set of observations to a \code{data.frame}
#'
#' Add an aggregated set of observations to a \code{data.frame} based on a "pool" variable provided.
#' For each value of \code{id_var}, every numeric column (other than the id and pool columns) is summed
#' across pools and appended as a row whose pool is \code{'Total'}. Non-numeric columns are \code{NA} in those rows.
#'
#' @param data A \code{data.frame} containing the pool and id variables provided.
#' @param pool_var A string representing the variable name of the pool variable.
#' @param id_var A string representing the variable name of the id variable.
#' @import dplyr
#' @export
#' @return a \code{data.frame} with additional observations.
add_total_pool_to_data = function(data, pool_var, id_var) {

  # numeric columns to aggregate; the id and pool columns are never summed
  # (a numeric pool column would otherwise be summed and then duplicated)
  is_num = vapply(data, is.numeric, logical(1))
  sum_cols = setdiff(names(data)[is_num], c(id_var, pool_var))

  # calculate the aggregated pool
  totals = data %>%
    group_by(!!sym(id_var)) %>%
    summarise(across(all_of(sum_cols), sum), .groups = "drop")
  totals[[pool_var]] = 'Total'

  # append the aggregated rows to the original data
  bind_rows(data, totals)
}




#' filter_decomp_pool
#'
#' Filter a model's decomposition based on a given pool.
#'
#' Filter all \code{data.frame}'s within a model's decomposition based on a given pool.
#'
#' @param decomp A \code{list} of \code{data.frame} from a model object or generated using \code{decomping}.
#' @param pool A single pool value (an entry of the \code{pool} column) whose rows should be kept.
#' @param verbose A boolean to specify whether to print warnings
#' @import dplyr
#' @export
#' @return a \code{list} of 3 \code{data.frame}'s representing the variable and category decomposition, and the fitted values.
#'   If \code{pool} is not found in the \code{pool} column of \code{decomp$variable_decomp}, \code{decomp} is returned unchanged.
filter_decomp_pool = function(decomp, pool, verbose = TRUE){
  # the pool must be one of the values of the decomposition's pool column
  # (not a column name)
  if (length(pool) != 1 || !(pool %in% decomp$variable_decomp$pool)) {
    if (verbose) {
      message("Error: pool string provided not found in the decomp 'pool' column.")
    }
    return(decomp)
  }

  # keep only the rows of the requested pool (NA pools are dropped)
  keep_pool = function(df) {
    if (is.null(df)) {
      return(df)
    }
    df[df$pool %in% pool, , drop = FALSE]
  }

  decomp$variable_decomp = keep_pool(decomp$variable_decomp)
  decomp$category_decomp = keep_pool(decomp$category_decomp)
  decomp$fitted_values = keep_pool(decomp$fitted_values)

  decomp
}



#' resid_hist_chart
#'
#' Histogram of Model Residuals
#'
#' Plot a histogram to visualise the distribution of residuals.
#' This is meant to assess the residual distribution's normality.
#'
#' @param model Model object. When it contains a \code{decomp_list}, that is used in place of \code{decomp_list}.
#' @param decomp_list list object generated by the \code{decomping} function.
#' @param pool string specifying a group within the pool column to be filtered. Ignored (with a warning when \code{verbose}) if not found.
#' @param color string specifying bar color
#' @param verbose A boolean to specify whether to print warnings
#' @import plotly
#' @import tidyverse
#' @export
#' @return a \code{plotly} histogram of the model's residuals, or \code{NULL} when no decomposition is available
resid_hist_chart = function(model = NULL,
                            decomp_list = NULL,
                            pool = NULL,
                            color = "black",
                            verbose = FALSE){
  # verbose must be a single TRUE/FALSE
  if (!isTRUE(verbose) && !isFALSE(verbose)) {
    message("Warning: verbose must be logical (TRUE or FALSE). Setting to False.")
    verbose = FALSE
  }

  # prefer the model's decomp_list, fall back to the one supplied directly
  if (!is.null(model) && !is.null(model$decomp_list)) {
    decomp_list = model$decomp_list
  }
  if (is.null(decomp_list)) {
    if (verbose) {
      message("Error: No decomp_list provided. Returning NULL.")
    }
    return(NULL)
  }

  # residual rows, optionally restricted to one pool
  residuals = decomp_list$fitted_values %>%
    filter(variable == "residual")

  if (!is.null(pool)) {
    if (pool %in% residuals$pool) {
      residuals = residuals %>%
        filter(pool == !!pool)
    } else if (verbose) {
      message("Warning: POOL ", pool, " not found. No POOL filtering applied.")
    }
  }

  df = tibble(residual = residuals$value)

  plot_ly(data = df) %>%
    add_histogram(histnorm = "probability",
                  x = ~residual,
                  marker = list(color = color,
                                line = list(color = "white",
                                            width = .5))) %>%
    layout(font = list(color = '#1c0022'),
           title = 'Residual Distribution',
           plot_bgcolor  = "rgba(0, 0, 0, 0)",
           paper_bgcolor = "rgba(0, 0, 0, 0)")
}

#' heteroskedasticity_chart
#'
#' Scatter of Residuals over dependent Variable
#'
#' Plot a scatter chart of the dependent variable (y axis) against the residuals (x axis).
#' This is meant to assess the consistency of the residuals' variance across the dependent variable.
#'
#' @param model Model object. When it contains a \code{decomp_list}, that is used in place of \code{decomp_list}.
#' @param decomp_list list object generated by the \code{decomping} function.
#' @param pool string specifying a group within the pool column to be filtered. Ignored (with a warning when \code{verbose}) if not found.
#' @param color string specifying marker color
#' @param verbose A boolean to specify whether to print warnings
#' @import plotly
#' @import tidyverse
#' @importFrom tidyr pivot_wider
#' @export
#' @return a \code{plotly} scatter chart of the model's dependent variable over residuals, or \code{NULL} when no decomposition is available
heteroskedasticity_chart = function(model = NULL,
                                    decomp_list = NULL,
                                    pool = NULL,
                                    color = "black",
                                    verbose = FALSE){

  # verbose must be a single TRUE/FALSE
  if (!isTRUE(verbose) && !isFALSE(verbose)) {
    message("Warning: verbose must be logical (TRUE or FALSE). Setting to False.")
    verbose = FALSE
  }

  # prefer the model's decomp_list, fall back to the one supplied directly
  if (!is.null(model) && !is.null(model$decomp_list)) {
    decomp_list = model$decomp_list
  }
  if (is.null(decomp_list)) {
    if (verbose) {
      message("Error: No decomp_list provided. Returning NULL.")
    }
    return(NULL)
  }

  # one row per id/pool with 'actual' and 'residual' columns
  # (decomping labels the fitted series 'predicted')
  df = decomp_list$fitted_values %>%
    filter(variable != 'predicted') %>%
    pivot_wider(names_from = variable)

  if (!is.null(pool)) {
    if (pool %in% df$pool) {
      df = df[df$pool %in% pool, ]
    } else if (verbose) {
      message('Warning: pool not found in data. Using full data.')
    }
  }

  plot_ly(data = df) %>%
    add_trace(
      x = ~ residual,
      y = ~ actual,
      type = 'scatter',
      mode = 'markers',
      marker = list(color = color,
                    line = list(color = "white",
                                width = .5))
    ) %>%
    layout(
      font = list(color = '#1c0022'),
      title = 'Heteroskedasticity',
      plot_bgcolor  = "rgba(0, 0, 0, 0)",
      paper_bgcolor = "rgba(0, 0, 0, 0)"
    )
}


#' acf_chart
#'
#' Bar chart of autocorrelation function
#'
#' A bar chart meant to assess the correlation of the residuals with lagged versions of themselves.
#' Bars are placed at their true lag (starting at lag 0). Dotted guide lines are drawn at
#' +/-0.2 and +/-0.4 as fixed visual references; they are not statistical confidence bounds.
#'
#' @param model Model object. When provided, its \code{decomp_list} element is used and the
#' \code{decomp_list} argument is ignored.
#' @param decomp_list list object generated by the \code{decomping} function.
#' @param pool string specifying a group within the pool column to be filtered
#' @param color string specifying bar color
#' @param verbose A boolean to specify whether to print warnings
#' @import plotly
#' @import tidyverse
#' @importFrom stats acf
#' @export
#' @return a \code{plotly} bar chart of the model's ACF, or \code{NULL} if neither
#' \code{model} nor \code{decomp_list} is provided
acf_chart = function(model = NULL,
                     decomp_list = NULL,
                     pool = NULL,
                     color = "black",
                     verbose = FALSE){

  # Check verbose
  if (!is.logical(verbose)) {
    message("Warning: verbose must be logical (TRUE or FALSE). Setting to False.")
    verbose = FALSE
  }

  # Check decomp_list , model: a model object takes precedence over decomp_list
  if (!is.null(model)) {
    decomp_list = model$decomp_list
  } else if (is.null(decomp_list)) {
    if (verbose) {
      message("Error: No decomp_list provided. Returning NULL.")
    }
    return(NULL)
  }

  residual_rows = decomp_list$fitted_values %>%
    filter(variable == "residual")
  residual = residual_rows %>%
    pull(value)

  if (!is.null(pool)) {
    # NOTE(review): the bare `pool` relies on tidyselect resolving it to a data
    # column named "pool" (data columns take precedence over the local string);
    # confirm fitted_values always carries such a column.
    pool_var = residual_rows %>%
      pull(pool)

    if (pool %in% pool_var) {
      residual = residual[pool_var %in% pool]
    } else if (verbose) {
      message('Warning: pool not found in data. Using full data.')
    }
  }

  # acf() reports lags starting at 0; use them directly so the bar for lag k
  # sits at x = k (row names of the acf array would start at 1).
  acf_res = acf(residual, plot = FALSE)
  acf_df = data.frame(lag = as.numeric(acf_res$lag),
                      acf = as.numeric(acf_res$acf))

  p = plot_ly() %>%
    add_trace(y = acf_df$acf,
              x = acf_df$lag,
              marker = list(color = color,
                            line = list(color = "white",
                                        width = .5)),
              type = "bar")

  # Fixed dotted guides: +/-0.2 lighter, +/-0.4 darker
  guide_y = c(0.2, -0.2, 0.4, -0.4)
  guide_color = c("rgba(0, 0, 0, 0.5)", "rgba(0, 0, 0, 0.5)",
                  "rgba(0, 0, 0, 0.75)", "rgba(0, 0, 0, 0.75)")
  for (g in seq_along(guide_y)) {
    p = p %>%
      add_trace(
        showlegend = FALSE,
        hoverinfo = 'skip',
        y = rep(guide_y[g], nrow(acf_df)),
        x = acf_df$lag,
        type = 'scatter',
        mode = 'lines',
        line = list(color = guide_color[g], dash = 'dot')
      )
  }

  p %>%
    layout(
      font = list(color = '#1c0022'),
      title = 'Autocorrelation Function',
      plot_bgcolor  = "rgba(0, 0, 0, 0)",
      paper_bgcolor = "rgba(0, 0, 0, 0)",
      xaxis = list(
        showgrid = FALSE,
        zerolinecolor = "#1c0022")
    )

}

#' response_curves
#'
#' Line chart of variable response curves
#'
#' Line chart of variable response curves visualising the relationship of each independent variable with the dependent variable.
#' Each curve is \code{coef * f(x)}, where \code{f} chains the model's non-time-series transformations
#' (rows of \code{model$trans_df} with \code{ts == FALSE}) using the parameters stored in
#' \code{model$output_model_table}.
#'
#' @param model Model object
#' @param x_min number specifying horizontal axis min (defaults to \code{-1e+05})
#' @param x_max number specifying horizontal axis max (defaults to \code{1e+05})
#' @param y_min number specifying vertical axis min (defaults to \code{x_min})
#' @param y_max number specifying vertical axis max (defaults to \code{x_max})
#' @param interval number specifying interval between points of the curve (defaults to a hundredth of \code{x_max - x_min})
#' @param trans_only a boolean specifying whether to display non-linear only \code{y = b*dim_rest(x)}
#' @param colors character vector of colors in hexadecimal notation, used by both the plotly and ggplot outputs
#' @param verbose A boolean to specify whether to print warnings (currently unused)
#' @param table A boolean to specify whether to return a \code{data.frame} of the response curves
#' @param points A boolean to specify whether to include the points from the data on the curve (plotly output only)
#' @param plotly A boolean to specify whether to use plotly (\code{TRUE}) or ggplot (\code{FALSE})
#' @param add_intercept A boolean to specify whether to include the intercept when calculating the curves
#' @importFrom ggplot2 ggplot geom_line scale_color_manual theme ggtitle ylab geom_vline geom_hline element_rect aes
#' @import plotly
#' @import tidyverse
#' @import tibble
#' @importFrom RColorBrewer brewer.pal
#' @importFrom stats na.omit
#' @export
#' @return a \code{plotly} (or \code{ggplot}) line chart of the model's response curves; a \code{data.frame}
#' with columns \code{value}, \code{variable} and \code{x} when \code{table = TRUE}; or \code{NULL} if
#' no variables remain to plot
#' @examples
#' model = run_model(data = mtcars,dv = 'mpg',ivs = c('disp'))
#' model %>%
#'    response_curves()
#' model = run_model(data = mtcars,dv = 'mpg',ivs = c('wt','cyl','disp'))
#'
#' model %>%
#'    response_curves()
#'
#' run_model(data = scale(mtcars) %>%
#'               data.frame(),
#'           dv = 'mpg',
#'           ivs = c('wt','cyl','disp')) %>%
#'    response_curves()
response_curves = function(
    model,
    x_min = NULL,
    x_max = NULL,
    y_min = NULL,
    y_max = NULL,
    interval = NULL,
    trans_only = FALSE,
    colors = color_palette(),
    plotly = TRUE,
    verbose = FALSE,
    table = FALSE,
    add_intercept = FALSE,
    points = FALSE){

  # defaults ####
  if (is.null(x_max)) x_max = 1e+05
  if (is.null(x_min)) x_min = -1e+05
  if (is.null(interval)) interval = (x_max - x_min)/100
  if (is.null(y_max)) y_max = x_max
  if (is.null(y_min)) y_min = x_min

  # process ####
  optim_table = model$output_model_table

  # only non-time-series transformations can be drawn as a function of x
  trans_df = model$trans_df %>%
    filter(ts == FALSE)

  if (trans_only) {
    # drop variables whose transformation cells are all empty
    optim_table = optim_table[!(((optim_table[trans_df$name] ==
                                    "") %>% data.frame() %>% rowSums()) == nrow(trans_df)),
    ]
  }
  optim_table = optim_table %>% filter(variable != "(Intercept)")
  optim_table = optim_table[c("variable", "variable_t",
                              trans_df$name, "coef")] %>% na.omit()
  if (nrow(optim_table) == 0) {
    message("Error: Check model and/or model_table.")
    return(NULL)
  }

  # Apply each transformation of trans_df (in row order) to `x` for variable
  # `var`. Parameters come from the comma-separated cell in
  # model$output_model_table and are bound to a, b, c, ... in a fresh
  # environment. new.env()'s parent is this closure's frame, so the current
  # `x` is reachable when run_text() (presumably evaluating the text in `env`)
  # runs the transformation expression.
  apply_trans = function(x, var) {
    for (j in seq_len(nrow(trans_df))) {
      param_vals = model$output_model_table %>%
        filter(variable_t == var) %>%
        pull(!!sym(trans_df$name[j])) %>%
        strsplit(split = ",")

      param_vals = param_vals[[1]] %>% as.numeric()
      # an empty cell means this transformation is not applied to `var`
      if (length(param_vals) == 0) {
        next
      }
      e = new.env()
      for (k in seq_along(param_vals)) {
        assign(letters[k], param_vals[k], envir = e)
      }
      x = trans_df$func[j] %>% run_text(env = e)
    }
    x
  }

  # Long data.frame (value, variable, x) of coef * f(x_in) for one variable
  build_curve = function(var, coef, x_in) {
    data.frame(value = apply_trans(x_in, var) * coef,
               variable = var,
               x = x_in) %>%
      mutate(value = as.numeric(value)) %>%
      mutate(variable = as.character(variable)) %>%
      mutate(x = as.numeric(x))
  }

  x_grid = seq(x_min, x_max, interval)
  curves_df = do.call(rbind, lapply(seq_len(nrow(optim_table)), function(i) {
    build_curve(optim_table$variable_t[i], optim_table$coef[i], x_grid)
  }))

  if (add_intercept) {
    curves_df = curves_df %>%
      mutate(value = value + model$coefficients[1])
  }

  curves_df = curves_df %>%
    filter(value >= y_min) %>%
    filter(value <= y_max)

  if (table) {
    return(curves_df)
  }

  if (plotly) {
    # plotly  ####
    p = plot_ly() %>%
      add_trace(data = curves_df, x = ~x, y = ~value,
                color = ~variable, mode = "lines", type = "scatter",
                colors = colors) %>%
      layout(plot_bgcolor = "rgba(0, 0, 0, 0)",
             paper_bgcolor = "rgba(0, 0, 0, 0)",
             font = list(color = "#1c0022"),
             xaxis = list(showgrid = FALSE),
             yaxis = list(title = model$dv),
             title = "Response Curves")

    if (points) {
      # Evaluate each curve at the variable's values in model$model. Frames are
      # collected in a list and row-bound so every variable's points are kept.
      raw_data = model$model
      points_df = do.call(rbind, lapply(seq_len(nrow(optim_table)), function(i) {
        var = optim_table$variable_t[i]
        build_curve(var, optim_table$coef[i], raw_data %>% pull(!!sym(var)))
      }))

      p = p %>%
        add_trace(data = points_df, x = ~x, y = ~value,
                  color = ~variable, mode = "markers", type = "scatter",
                  colors = colors)
    }
  } else {
    # ggplot  ####
    p = ggplot(data = curves_df, aes(x = x, y = value, col = variable)) +
      geom_line() +
      scale_color_manual(values = colors) +
      theme(
        panel.background = element_rect(fill = "white",
                                        colour = "white")) +
      ggtitle("Response Curves") +
      ylab(model$dv) +
      geom_vline(xintercept = 0) +
      geom_hline(yintercept = 0)
  }

  return(p)
}
