# Copyright 2022 DARWIN EU (C)
#
# This file is part of DrugUtilisation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#' Summarise the large scale characteristics of a cohort table
#'
#' For each requested table, analysis and window, the concepts recorded for the
#' cohort entries are counted per cohort (overall and for every strata) and
#' reported as counts and percentages.
#'
#' @param cohort The cohort to characterise.
#' @param strata Stratification list.
#' @param window Temporal windows that we want to characterize. A single
#' window supplied as a vector is wrapped into a list.
#' @param eventInWindow Tables to characterise the events in the window.
#' @param episodeInWindow Tables to characterise the episodes in the window.
#' At least one of `eventInWindow` and `episodeInWindow` must be provided.
#' @param includeSource Whether to include source concepts.
#' @param minCellCount All counts lower than minCellCount will be obscured.
#' @param minimumFrequency Minimum frequency covariates to report.
#' @param cdm A cdm reference.
#'
#' @return The output of this function is a `ResultSummary` containing the
#' relevant information.
#'
#' @export
#'
summariseLargeScaleCharacteristics <- function(cohort,
                                               strata = list(),
                                               window = list(
                                                 c(-Inf, -366), c(-365, -31),
                                                 c(-30, -1), c(0, 0), c(1, 30),
                                                 c(31, 365), c(366, Inf)
                                               ),
                                               eventInWindow = NULL,
                                               episodeInWindow = NULL,
                                               includeSource = FALSE,
                                               minCellCount = 5,
                                               minimumFrequency = 0.005,
                                               cdm = attr(cohort, "cdm_reference")) {
  if (!is.list(window)) {
    window <- list(window)
  }

  # initial checks
  checkX(cohort)
  checkStrata(strata, cohort)
  checkWindow(window)
  tables <- c(
    namesTable$table_name, paste("ATC", c("1st", "2nd", "3rd", "4th", "5th"))
  )
  checkmate::assertSubset(eventInWindow, tables)
  checkmate::assertSubset(episodeInWindow, tables)
  if (is.null(eventInWindow) && is.null(episodeInWindow)) {
    cli::cli_abort("'eventInWindow' or 'episodeInWindow' must be provided")
  }
  checkmate::assertLogical(includeSource, any.missing = FALSE, len = 1)
  checkmate::assertIntegerish(
    minCellCount, lower = 0, any.missing = FALSE, len = 1
  )
  checkmate::assertNumber(minimumFrequency, lower = 0, upper = 1)
  checkCdm(cdm)
  assertWriteSchema(cdm)

  # add names to windows
  names(window) <- gsub("_", " ", gsub("m", "-", getWindowNames(window)))

  # random tablePrefix
  tablePrefix <- c(sample(letters, 5, TRUE), "_") %>% paste0(collapse = "")

  # initial table
  x <- getInitialTable(cohort, tablePrefix)

  # eliminate permanent tables when the function exits, so the intermediate
  # tables are also removed if one of the following steps fails
  on.exit(
    CDMConnector::dropTable(cdm = cdm, name = dplyr::starts_with(tablePrefix)),
    add = TRUE
  )

  # get analysis table
  analyses <- getAnalyses(eventInWindow, episodeInWindow)

  minWindow <- min(unlist(window))
  maxWindow <- max(unlist(window))

  # perform lsc
  lsc <- NULL
  for (tab in unique(analyses$table)) {
    analysesTable <- analyses %>% dplyr::filter(.data$table == .env$tab)
    table <- getTable(tab, x, includeSource, minWindow, maxWindow, tablePrefix)
    for (k in seq_len(nrow(analysesTable))) {
      type <- analysesTable$type[k]
      analysis <- analysesTable$analysis[k]

      # source concepts are summarised right after the standard ones, only
      # for tables that have a source concept column
      analysesToRun <- analysis
      if (includeSource && analysis == "standard" &&
        !is.na(getSourceConceptName(tab))) {
        analysesToRun <- c(analysesToRun, "source")
      }

      for (currentAnalysis in analysesToRun) {
        tableAnalysis <- getTableAnalysis(
          table, type, currentAnalysis, tablePrefix
        )
        for (win in seq_along(window)) {
          tableWindow <- getTableWindow(
            tableAnalysis, window[[win]], tablePrefix
          )
          lsc <- lsc %>%
            dplyr::bind_rows(
              summariseConcept(cohort, tableWindow, strata, tablePrefix) %>%
                dplyr::mutate(
                  "window_name" = names(window)[win],
                  "table_name" = .env$tab,
                  "analysis" = .env$currentAnalysis,
                  "type" = .env$type
                )
            )
        }
      }
    }
  }

  # calculate denominators
  den <- denominatorCounts(cohort, x, strata, window, tablePrefix)

  # format results
  results <- lsc %>%
    formatLscResult(den, cdm, minimumFrequency, minCellCount)

  # return (the permanent tables are dropped by the on.exit handler)
  return(results)
}

#' Add columns with the large scale characteristics of a cohort table
#'
#' @param cohort The cohort to characterise.
#' @param window Temporal windows that we want to characterize. A single
#' window supplied as a vector is wrapped into a list.
#' @param eventInWindow Tables to characterise the events in the window.
#' @param episodeInWindow Tables to characterise the episodes in the window.
#' At least one of `eventInWindow` and `episodeInWindow` must be provided.
#' @param includeSource Whether to include source concepts.
#' @param minimumFrequency Minimum frequency covariates to report. Within a
#' window, a concept is kept when its count is at least `minimumFrequency`
#' times the number of rows of the initial table of cohort entries.
#'
#' @return The output of this function is the cohort with the new created
#' columns. The new columns are named `lsc_<window name>_<concept id>` and
#' contain 1 or 0.
#'
#' @export
#'
addLargeScaleCharacteristics <- function(cohort,
                                         window = list(c(0, Inf)),
                                         eventInWindow = NULL,
                                         episodeInWindow = NULL,
                                         includeSource = FALSE,
                                         minimumFrequency = 0.005) {
  if (!is.list(window)) {
    window <- list(window)
  }

  # initial checks
  checkX(cohort)
  checkWindow(window)
  tables <- c(
    namesTable$table_name, paste("ATC", c("1st", "2nd", "3rd", "4th", "5th"))
  )
  checkmate::assertSubset(eventInWindow, tables)
  checkmate::assertSubset(episodeInWindow, tables)
  if (is.null(eventInWindow) && is.null(episodeInWindow)) {
    cli::cli_abort("'eventInWindow' or 'episodeInWindow' must be provided")
  }
  checkmate::assertLogical(includeSource, any.missing = FALSE, len = 1)
  checkmate::assertNumber(minimumFrequency, lower = 0, upper = 1)
  cdm <- attr(cohort, "cdm_reference")
  checkCdm(cdm)
  assertWriteSchema(cdm)

  # add names to windows: internal unique names are used while computing and
  # translated to the "lsc_<window name>" prefix through `dic`
  winNams <- unlist(getWindowNames(window))
  nams <- uniqueVariableName(length(window))
  dic <- dplyr::tibble(window_name = nams, window_nam = paste0("lsc_", winNams))
  names(window) <- nams

  # random tablePrefix
  tablePrefix <- c(sample(letters, 5, TRUE), "_") %>% paste0(collapse = "")

  # initial table
  x <- getInitialTable(cohort, tablePrefix)

  # eliminate permanent tables when the function exits, so the intermediate
  # tables are also removed if one of the following steps fails
  on.exit(
    CDMConnector::dropTable(cdm = cdm, name = dplyr::starts_with(tablePrefix)),
    add = TRUE
  )

  # minimum count
  numEntries <- x %>%
    dplyr::ungroup() %>%
    dplyr::tally() %>%
    dplyr::pull()
  minimumCount <- numEntries * minimumFrequency

  # get analysis table
  analyses <- getAnalyses(eventInWindow, episodeInWindow)

  minWindow <- min(unlist(window))
  maxWindow <- max(unlist(window))

  lsc <- NULL
  for (tab in unique(analyses$table)) {
    analysesTable <- analyses %>% dplyr::filter(.data$table == .env$tab)
    table <- getTable(tab, x, includeSource, minWindow, maxWindow, tablePrefix)
    for (k in seq_len(nrow(analysesTable))) {
      type <- analysesTable$type[k]
      analysis <- analysesTable$analysis[k]

      # source concepts are processed right after the standard ones, only for
      # tables that have a source concept column
      analysesToRun <- analysis
      if (includeSource && analysis == "standard" &&
        !is.na(getSourceConceptName(tab))) {
        analysesToRun <- c(analysesToRun, "source")
      }

      for (currentAnalysis in analysesToRun) {
        tableAnalysis <- getTableAnalysis(
          table, type, currentAnalysis, tablePrefix
        )
        for (win in seq_along(window)) {
          tableWindow <- getTableWindow(
            tableAnalysis, window[[win]], tablePrefix
          )
          lsc <- lsc %>%
            trimCounts(
              tableWindow, minimumCount, tablePrefix, names(window)[win]
            )
        }
      }
    }
  }

  # one column per window and concept, filled with 1 where a record exists
  newColumns <- lsc %>%
    dplyr::select(
      "subject_id", "cohort_start_date", "concept", "window_name"
    ) %>%
    dplyr::inner_join(dic, by = "window_name", copy = TRUE) %>%
    dplyr::mutate(
      value = 1,
      concept = as.character(as.integer(.data$concept)),
      name = paste0(.data$window_nam, "_", .data$concept)
    ) %>%
    dplyr::select("subject_id", "cohort_start_date", "name", "value") %>%
    tidyr::pivot_wider(
      names_from = "name", values_from = "value"
    )

  # add new columns; entries without a record get 0 instead of NA
  originalCols <- colnames(cohort)
  cohort <- cohort %>%
    dplyr::left_join(
      newColumns,
      by = c("subject_id", "cohort_start_date")
    ) %>%
    dplyr::mutate(dplyr::across(
      !dplyr::all_of(originalCols), ~ dplyr::if_else(is.na(.x), 0, 1)
    )) %>%
    CDMConnector::computeQuery()

  # return (the permanent tables are dropped by the on.exit handler)
  return(cohort)
}

# Build the table of analyses to perform, with columns `table`, `type`
# ("event" or "episode") and `analysis`. Plain table names are analysed as
# "standard"; ATC levels are mapped to drug_exposure and icd10 groupings to
# condition_occurrence, keeping the grouping name as the analysis.
getAnalyses <- function(eventInWindow, episodeInWindow) {
  atcLevels <- paste("ATC", c("1st", "2nd", "3rd", "4th", "5th"))
  icd10Levels <- c("icd10 chapter", "icd10 subchapter")
  groupings <- c(atcLevels, icd10Levels)

  standardEvent <- dplyr::tibble(
    table = eventInWindow[!(eventInWindow %in% groupings)],
    type = "event", analysis = "standard"
  )
  standardEpisode <- dplyr::tibble(
    table = episodeInWindow[!(episodeInWindow %in% groupings)],
    type = "episode", analysis = "standard"
  )
  atcEvent <- dplyr::tibble(
    table = "drug_exposure", type = "event",
    analysis = eventInWindow[eventInWindow %in% atcLevels]
  )
  atcEpisode <- dplyr::tibble(
    table = "drug_exposure", type = "episode",
    analysis = episodeInWindow[episodeInWindow %in% atcLevels]
  )
  icd10Event <- dplyr::tibble(
    table = "condition_occurrence", type = "event",
    analysis = eventInWindow[eventInWindow %in% icd10Levels]
  )
  icd10Episode <- dplyr::tibble(
    table = "condition_occurrence", type = "episode",
    analysis = episodeInWindow[episodeInWindow %in% icd10Levels]
  )

  # a NULL input leaves one of the columns out of its tibble; binding fills
  # it with NA and drop_na() then discards those placeholder rows
  allAnalyses <- dplyr::bind_rows(
    standardEvent, standardEpisode, atcEvent, atcEpisode, icd10Event,
    icd10Episode
  )
  tidyr::drop_na(allAnalyses)
}
# Create the permanent "<tablePrefix>individuals" table with one row per
# distinct (subject_id, cohort_start_date, start_obs, end_obs) and a row
# number identifier `obs_id`.
getInitialTable <- function(cohort, tablePrefix) {
  withObservation <- addDemographics(
    cohort,
    age = FALSE, sex = FALSE, priorObservationName = "start_obs",
    futureObservationName = "end_obs"
  )

  # start_obs is negated; presumably it is returned as a positive number of
  # days and has to become an offset relative to cohort_start_date - TODO
  # confirm against addDemographics
  individuals <- withObservation %>%
    dplyr::mutate(start_obs = -.data$start_obs) %>%
    dplyr::select("subject_id", "cohort_start_date", "start_obs", "end_obs") %>%
    dplyr::distinct()

  # number the rows in a fixed order, then clear the window ordering
  numbered <- individuals %>%
    dbplyr::window_order(.data$subject_id, .data$cohort_start_date) %>%
    dplyr::mutate(obs_id = dplyr::row_number()) %>%
    dbplyr::window_order()

  CDMConnector::computeQuery(
    numbered,
    name = paste0(tablePrefix, "individuals"), temporary = FALSE,
    schema = writeSchema(cohort), overwrite = TRUE
  )
}
# Read the records of the cdm table `tab` for the individuals in `x` and store
# them in the permanent "<tablePrefix>table" table.
#
# The result has the columns of `x` (minus start_obs and end_obs) plus
# `start_diff` and `end_diff` (differences between cohort_start_date and the
# record start/end dates as computed by CDMConnector::datediff), `standard`
# (concept id) and, when requested and available, `source` (source concept
# id). Records are kept only when they overlap [start_obs, end_obs] and the
# span covered by all the windows ([minWindow, maxWindow]).
getTable <- function(tab, x, includeSource, minWindow, maxWindow, tablePrefix) {
  cdm <- attr(x, "cdm_reference")

  # tables without an end date column use the start date as end date
  startName <- getStartName(tab)
  endName <- getEndName(tab)
  if (is.na(endName)) {
    endName <- startName
  }

  toSelect <- c(
    "subject_id" = "person_id",
    "start_diff" = startName,
    "end_diff" = endName,
    "standard" = getConceptName(tab)
  )

  # the source concept column is only selected when requested and when the
  # table has one; an NA column name would make the selection fail
  sourceName <- getSourceConceptName(tab)
  if (includeSource && !is.na(sourceName)) {
    toSelect <- c(toSelect, "source" = sourceName)
  }

  table <- cdm[[tab]] %>%
    dplyr::select(dplyr::all_of(toSelect)) %>%
    dplyr::inner_join(x, by = "subject_id") %>%
    # records with a missing end date end on their start date
    dplyr::mutate(end_diff = dplyr::if_else(
      is.na(.data$end_diff), .data$start_diff, .data$end_diff
    )) %>%
    dplyr::mutate(start_diff = !!CDMConnector::datediff(
      "cohort_start_date", "start_diff"
    )) %>%
    dplyr::mutate(end_diff = !!CDMConnector::datediff(
      "cohort_start_date", "end_diff"
    )) %>%
    dplyr::filter(
      .data$end_diff >= .data$start_obs & .data$start_diff <= .data$end_obs
    )

  # infinite limits need no filter
  if (!is.infinite(minWindow)) {
    table <- table %>%
      dplyr::filter(.data$end_diff >= .env$minWindow)
  }
  if (!is.infinite(maxWindow)) {
    table <- table %>%
      dplyr::filter(.data$start_diff <= .env$maxWindow)
  }

  table %>%
    dplyr::select(-"start_obs", -"end_obs") %>%
    CDMConnector::computeQuery(
      name = paste0(tablePrefix, "table"), temporary = FALSE,
      schema = writeSchema(x), overwrite = TRUE
    )
}
# Get the "write_schema" attribute of the cdm reference attached to `x`.
# Returns NULL when either attribute is missing.
writeSchema <- function(x) {
  cdmReference <- attr(x, "cdm_reference")
  attr(cdmReference, "write_schema")
}
# Count, for every cohort in the cohort set, how many rows of `tableWindow`
# there are per concept: overall and for each strata. The counts are collected
# and labelled with group_name = "Cohort name" and group_level = cohort name.
# Returns NULL when the cohort set has no cohorts.
summariseConcept <- function(cohort, tableWindow, strata, tablePrefix) {
  cohortSet <- CDMConnector::cohortSet(cohort)
  counts <- NULL
  for (cohortName in cohortSet$cohort_name) {
    cohortId <- cohortSet %>%
      dplyr::filter(.data$cohort_name == .env$cohortName) %>%
      dplyr::pull("cohort_definition_id")
    cohortEntries <- cohort %>%
      dplyr::filter(.data$cohort_definition_id == .env$cohortId)

    # rows of the window that belong to this cohort, with the strata columns
    tableWindowCohort <- tableWindow %>%
      dplyr::inner_join(
        cohortEntries,
        by = c("subject_id", "cohort_start_date")
      ) %>%
      dplyr::select(
        "obs_id", "concept", dplyr::all_of(unique(unlist(strata)))
      ) %>%
      CDMConnector::computeQuery(
        name = paste0(tablePrefix, "table_window_cohort"), temporary = FALSE,
        schema = writeSchema(cohort), overwrite = TRUE
      )

    overallCounts <- tableWindowCohort %>%
      dplyr::group_by(.data$concept) %>%
      dplyr::summarise(count = as.numeric(dplyr::n()), .groups = "drop") %>%
      dplyr::collect() %>%
      dplyr::mutate(strata_name = "Overall", strata_level = "Overall")

    cohortCounts <- overallCounts %>%
      dplyr::bind_rows(summariseStrataCounts(tableWindowCohort, strata)) %>%
      dplyr::mutate(group_name = "Cohort name", group_level = cohortName)

    counts <- dplyr::bind_rows(counts, cohortCounts)
  }
  counts
}
# Count the rows of `tableWindowCohort` per concept within each strata. Every
# element of `strata` is a set of columns; their values are pasted into
# `strata_level` and their names into `strata_name`, joined by " and ".
# Returns NULL when `strata` is empty.
summariseStrataCounts <- function(tableWindowCohort, strata) {
  counts <- NULL
  for (k in seq_along(strata)) {
    grouped <- dplyr::group_by(
      tableWindowCohort, dplyr::pick(c("concept", strata[[k]]))
    )
    strataCounts <- grouped %>%
      dplyr::summarise(count = as.numeric(dplyr::n()), .groups = "drop") %>%
      dplyr::collect()
    strataCounts <- tidyr::unite(
      strataCounts,
      col = "strata_level", dplyr::all_of(strata[[k]]), sep = " and "
    )
    strataCounts <- dplyr::mutate(
      strataCounts,
      strata_name = paste0(strata[[k]], collapse = " and ")
    )
    counts <- dplyr::union_all(counts, strataCounts)
  }
  counts
}
# Compute the denominators of every window: the initial table is treated as a
# table with a single concept ("denominator") whose interval is
# [start_obs, end_obs], and it is summarised like any other table.
denominatorCounts <- function(cohort, x, strata, window, tablePrefix) {
  observation <- dplyr::rename(
    x,
    "start_diff" = "start_obs", "end_diff" = "end_obs"
  )
  observation <- dplyr::mutate(observation, concept = "denominator")

  denominators <- NULL
  for (i in seq_along(window)) {
    inWindow <- getTableWindow(observation, window[[i]], tablePrefix)
    windowCounts <- dplyr::mutate(
      summariseConcept(cohort, inWindow, strata, tablePrefix),
      window_name = names(window)[i]
    )
    denominators <- dplyr::bind_rows(denominators, windowCounts)
  }
  denominators
}
# Turn the raw counts into the final result: join the denominators, compute
# the percentage, drop rows below minCellCount or minimumFrequency, reshape to
# one row per estimate and add the cdm name and the concept names.
formatLscResult <- function(lsc, den, cdm, minimumFrequency, minCellCount) {
  # denominators below minCellCount are discarded, and with them (inner join)
  # every count that refers to them
  denominators <- den %>%
    dplyr::rename("denominator" = "count") %>%
    dplyr::filter(.data$denominator >= .env$minCellCount) %>%
    dplyr::select(-"concept")

  reported <- lsc %>%
    dplyr::inner_join(
      denominators,
      by = c(
        "strata_name", "strata_level", "group_name", "group_level",
        "window_name"
      )
    ) %>%
    dplyr::mutate(percentage = 100 * .data$count / .data$denominator) %>%
    dplyr::select(-"denominator") %>%
    dplyr::filter(
      .data$count >= .env$minCellCount,
      .data$percentage >= 100 * .env$minimumFrequency
    )

  estimates <- reported %>%
    tidyr::pivot_longer(
      cols = c("count", "percentage"), names_to = "estimate_type",
      values_to = "estimate"
    ) %>%
    addCdmName(cdm = cdm) %>%
    dplyr::mutate(
      estimate = as.character(.data$estimate),
      result_type = "Summarised Large Scale Characteristics"
    )

  conceptNames <- addConceptName(lsc, cdm)
  estimates %>%
    dplyr::inner_join(conceptNames, by = c("concept", "analysis")) %>%
    dplyr::select(
      "result_type", "cdm_name", "group_name", "group_level", "strata_name",
      "strata_level", "table_name", "type", "analysis", "concept",
      "variable" = "concept_name", "variable_level" = "window_name",
      "estimate_type", "estimate"
    )
}
# Look up in the cdm concept table the name of every distinct concept of
# `lsc`. Returns a collected table with columns concept, concept_name and
# analysis.
addConceptName <- function(lsc, cdm) {
  conceptIds <- dplyr::distinct(dplyr::select(lsc, "concept", "analysis"))
  conceptIds <- dplyr::mutate(
    conceptIds,
    concept = as.numeric(.data$concept)
  )

  # the local ids are copied to the database to perform the join there
  cdm[["concept"]] %>%
    dplyr::select("concept" = "concept_id", "concept_name") %>%
    dplyr::inner_join(conceptIds, by = "concept", copy = TRUE) %>%
    dplyr::collect()
}
# Prepare `table` for one analysis: the column that identifies the concept is
# renamed to `concept` and the other concept columns are dropped. For any
# analysis other than "standard" or "source" the standard concepts are mapped
# to their group through getCodesGroup().
getTableAnalysis <- function(table, type, analysis, tablePrefix) {
  # events are points in time: the end is the same as the start
  if (type == "event") {
    table <- dplyr::mutate(table, "end_diff" = .data$start_diff)
  }

  isGrouping <- !(analysis %in% c("standard", "source"))
  if (isGrouping) {
    standardTable <- table %>%
      dplyr::rename("concept" = "standard") %>%
      dplyr::select(-dplyr::any_of("source"))
    return(getCodesGroup(standardTable, analysis, tablePrefix))
  }

  table %>%
    dplyr::rename("concept" = dplyr::all_of(analysis)) %>%
    dplyr::select(-dplyr::any_of(c("standard", "source")))
}
# Replace the concepts of `table` by the group they belong to and store the
# result in the permanent "<tablePrefix>table_group" table.
#
# For ATC analyses the groups are the ATC concepts whose concept_class_id
# equals `analysis`, linked to the concepts of `table` through
# concept_ancestor (group = ancestor, concept = descendant). Rows whose
# concept has no group are dropped by the inner join.
getCodesGroup <- function(table, analysis, tablePrefix) {
  cdm <- attr(table, "cdm_reference")
  atcLevels <- paste("ATC", c("1st", "2nd", "3rd", "4th", "5th"))
  if (analysis %in% atcLevels) {
    codes <- cdm[["concept"]] %>%
      dplyr::filter(.data$vocabulary_id == "ATC") %>%
      dplyr::filter(.data$concept_class_id == .env$analysis) %>%
      dplyr::select("concept_new" = "concept_id") %>%
      dplyr::inner_join(
        cdm[["concept_ancestor"]] %>%
          dplyr::select(
            "concept_new" = "ancestor_concept_id",
            "concept" = "descendant_concept_id"
          ),
        by = "concept_new"
      )
  } else {
    codes <- cdm[["concept"]] %>%
      dplyr::filter(.data$vocabulary_id == "ICD10") %>%
      dplyr::filter(.data$concept_class_id == .env$analysis) %>%
      dplyr::select("concept_new" = "concept_id")
    # TODO
    # NOTE(review): `codes` has no `concept` column in this branch, so the
    # join below looks like it cannot work until the mapping from the ICD10
    # groups to the concepts of `table` is implemented - confirm
  }
  table <- table %>%
    dplyr::inner_join(codes, by = "concept") %>%
    dplyr::select(-"concept") %>%
    dplyr::rename("concept" = "concept_new") %>%
    CDMConnector::computeQuery(
      name = paste0(tablePrefix, "table_group"), temporary = FALSE,
      schema = writeSchema(table), overwrite = TRUE
    )
  return(table)
}
# Keep the rows of `table` whose [start_diff, end_diff] interval overlaps the
# window and store the distinct (subject_id, cohort_start_date, obs_id,
# concept) combinations in the permanent "<tablePrefix>table_window" table.
# An infinite window limit means that side is not filtered.
getTableWindow <- function(table, window, tablePrefix) {
  lower <- window[1]
  upper <- window[2]
  openLower <- is.infinite(lower)
  openUpper <- is.infinite(upper)

  if (openLower && openUpper) {
    inWindow <- table
  } else if (openLower) {
    inWindow <- dplyr::filter(table, .data$start_diff <= .env$upper)
  } else if (openUpper) {
    inWindow <- dplyr::filter(table, .data$end_diff >= .env$lower)
  } else {
    inWindow <- dplyr::filter(
      table,
      .data$end_diff >= .env$lower &
        .data$start_diff <= .env$upper
    )
  }

  inWindow <- dplyr::distinct(dplyr::select(
    inWindow, "subject_id", "cohort_start_date", "obs_id", "concept"
  ))
  CDMConnector::computeQuery(
    inWindow,
    name = paste0(tablePrefix, "table_window"), temporary = FALSE,
    schema = writeSchema(table), overwrite = TRUE
  )
}
# Keep the rows of `tableWindow` whose concept appears at least `minimumCount`
# times, label them with `winName` and add them to the permanent
# "<tablePrefix>lsc" table: the table is created when `lsc` is NULL and
# appended to otherwise.
trimCounts <- function(lsc, tableWindow, minimumCount, tablePrefix, winName) {
  frequentConcepts <- tableWindow %>%
    dplyr::group_by(.data$concept) %>%
    dplyr::summarise(count = dplyr::n(), .groups = "drop") %>%
    dplyr::filter(.data$count >= .env$minimumCount) %>%
    dplyr::select("concept")

  kept <- tableWindow %>%
    dplyr::inner_join(frequentConcepts, by = "concept") %>%
    dplyr::mutate("window_name" = .env$winName)

  if (!is.null(lsc)) {
    return(CDMConnector::appendPermanent(
      kept,
      name = paste0(tablePrefix, "lsc"), schema = writeSchema(tableWindow)
    ))
  }
  CDMConnector::computeQuery(
    kept,
    name = paste0(tablePrefix, "lsc"), temporary = FALSE,
    schema = writeSchema(tableWindow), overwrite = TRUE
  )
}
