## ----setup, include=FALSE-----------------------------------------------------
# Chunk defaults for the whole document: collapsed "#>" output, 8x8 figures,
# and no messages or warnings in the rendered output.
knitr::opts_chunk$set(
  collapse   = TRUE,
  comment    = "#>",
  message    = FALSE,
  warning    = FALSE,
  fig.width  = 8,
  fig.height = 8
)

library(ggplot2)
library(dplyr)
library(tidyr)

# Run data.table single-threaded. NOTE(review): presumably to keep document
# builds within CRAN's core limit — confirm; data.table is not used elsewhere
# in the visible code.
data.table::setDTthreads(1L)

# Silence summarise() grouping messages, avoid scientific notation, and print
# 5 significant digits.
options(
  dplyr.summarise.inform = FALSE,
  scipen                 = 999,
  digits                 = 5
)

theme_set(theme_bw(base_size = 12))

## ----dgp, class.source = "fold-show"------------------------------------------
# ---- Simulation parameters ----
# Every cohort-by-gender earnings slope is a multiple of one baseline slope.
baseline   <- 500
ratio_bias <- 3   # slope multiplier: late cohort relative to early cohort
ratio_apos <- 3   # slope multiplier: men relative to women
ages       <- 20:30
d_early    <- 25  # treatment age of the early cohort
d_late     <- 30  # treatment age of the late cohort

# Counterfactual (untreated) earnings slopes per year of age, named s_<sex>_<D>.
s_f_25 <- baseline
s_m_25 <- baseline * ratio_apos
s_f_30 <- s_f_25 * ratio_bias
s_m_30 <- s_m_25 * ratio_bias

# Earnings shift once treated (ATT), by gender.
att_f <- -6000
att_m <- -2000

# Simulate one cohort-by-gender group, one row per age.
#
# sex:       value stored in `female` (1 = women, 0 = men).
# D:         age at which the group becomes treated.
# slope:     counterfactual earnings slope per year of age (y_0 = slope * age).
# att_level: earnings shift added from age D onward.
# age_grid:  ages to simulate. It defaults to the global `ages`, which is the
#            original behavior, and is now passable explicitly instead of read
#            implicitly. (It cannot be named `ages` because `ages = ages` would
#            be a recursive default.)
#
# Returns a tibble with columns age, female, D, y_0, y_1, y.
make_group <- function(sex, D, slope, att_level, age_grid = ages) {
  tibble(
    age    = age_grid,
    female = sex,
    D      = D,
    y_0    = slope * age,                            # counterfactual (no treatment)
    y_1    = y_0 + ifelse(age >= D, att_level, 0),   # treated potential outcome
    # Observed outcome: every simulated group is eventually treated, and
    # y_1 equals y_0 before D, so y_1 is what is observed at every age.
    y      = y_1
  )
}

# Stack the four simulated groups in a fixed order:
# women-early, women-late, men-early, men-late.
data <- bind_rows(
  Map(
    make_group,
    c(1, 1, 0, 0),                         # female indicator
    c(d_early, d_late, d_early, d_late),   # treatment age
    c(s_f_25, s_f_30, s_m_25, s_m_30),     # counterfactual slope
    c(att_f, att_f, att_m, att_m)          # treatment effect
  )
)

## ----verify_pt_levels---------------------------------------------------------
# Counterfactual earnings change y_0(a) - y_0(base) for gender `g` in the
# cohort treated at age `d`. `base` defaults to that cohort's own last
# untreated age (d - 1), which matches the previous behavior of this function.
trends <- function(g, d, a, base = d - 1) {
  y0_at <- function(age_value) {
    data |>
      filter(D == d, female == g, age == age_value) |>
      pull(y_0)
  }
  y0_at(a) - y0_at(base)
}

# Parallel-trends gap at age `a` for gender `g`: early-cohort trend minus
# late-cohort trend, both measured from the SAME base age. Parallel trends
# compares changes over identical periods. The default base is the early
# cohort's last pre-treatment age (d_early - 1), which is the anchor the DID
# imputation uses. The gap is therefore zero at the base age and equals minus
# the DID parallel-trends bias at every other age.
diff_trends <- function(g, a, base = d_early - 1) {
  trends(g, d_early, a, base) - trends(g, d_late, a, base)
}

# Early-minus-late gap in counterfactual earnings trends, by gender and age.
expand_grid(female = c(0, 1), age = ages) |>
  mutate(
    gender = factor(
      if_else(female == 1, "Women", "Men"),
      levels = c("Women", "Men")
    ),
    diff_trend = purrr::map2_dbl(female, age, diff_trends)
  ) |>
  ggplot(aes(age, diff_trend, colour = gender)) +
  geom_line(linewidth = 1.1) +
  labs(
    title  = "Parallel Trends Violation in Levels by Gender",
    x      = "Age",
    y      = "Difference in counterfactual trends",
    colour = "Gender"
  ) +
  theme(legend.position = "bottom")

## ----verify_normalized_pt-----------------------------------------------------
# Counterfactual (untreated) earnings y_0 at age `a` for gender `g` in the
# cohort treated at age `d`. The default is the early cohort, whose y_0 is the
# APO that the normalized plots divide by. Previously this was hard-coded to 25.
get_apo <- function(g, a, d = d_early) {
  data |>
    dplyr::filter(D == d, female == g, age == a) |>
    dplyr::pull(y_0)
}

# Trend gap divided by the early cohort's counterfactual earnings (APO), so
# both genders can be compared on a relative scale.
expand_grid(female = c(0, 1), age = ages) |>
  mutate(
    gender = factor(
      if_else(female == 1, "Women", "Men"),
      levels = c("Women", "Men")
    ),
    y_0             = purrr::map2_dbl(female, age, get_apo),
    diff_trend      = purrr::map2_dbl(female, age, diff_trends),
    norm_diff_trend = diff_trend / y_0
  ) |>
  ggplot(aes(age, norm_diff_trend, colour = gender)) +
  geom_line(linewidth = 1.1) +
  scale_y_continuous(labels = scales::percent_format()) +
  labs(
    title  = "Normalized Parallel Trends Violation",
    x      = "Age",
    y      = "Normalized difference in counterfactual trends",
    colour = "Gender"
  ) +
  theme(legend.position = "bottom")

## ----fig1_setup, fig.height=4-------------------------------------------------
# Long-format data for Figure 1.
#   Early cohort (D = d_early): observed earnings (y) and the untreated
#     counterfactual (y_0).
#   Late cohort (D = d_late): only its untreated counterfactual y_0, which
#     serves as the comparison trend for the early cohort.
plot_df <- bind_rows(
  data %>%
    filter(D == d_early) %>%
    pivot_longer(cols = c(y_0, y), names_to = "series", values_to = "value"),
  data %>%
    filter(D == d_late) %>%
    transmute(age, female, D, series = "y_0", value = y_0)
) %>%
  mutate(
    gender = factor(
      if_else(female == 1, "Women", "Men"),
      levels = c("Women", "Men")
    ),
    cohort = factor(
      if_else(D == d_early, "Early (D=25)", "Late (D=30)"),
      levels = c("Early (D=25)", "Late (D=30)")
    ),
    # The three conditions are mutually exclusive; one label per plotted line.
    series_label = case_when(
      cohort == "Late (D=30)"  & series == "y_0" ~ "Late — Counterfactual",
      cohort == "Early (D=25)" & series == "y_0" ~ "Early — Counterfactual",
      cohort == "Early (D=25)" & series == "y"   ~ "Early — Observed"
    ),
    # Only the observed path is solid; every counterfactual is dashed.
    linetype = if_else(series_label == "Early — Observed", "solid", "dashed"),
    color = case_when(
      series_label == "Early — Counterfactual" ~ "CF Early",
      series_label == "Early — Observed"       ~ "Obs.",
      TRUE                                     ~ "CF Late"
    )
  )

# Figure 1: the early cohort's observed and counterfactual earnings next to
# the late cohort's counterfactual, one panel per gender.
ggplot(plot_df, aes(age, value, group = series_label)) +
  geom_line(aes(color = color, linetype = linetype), linewidth = 1.1) +
  scale_color_manual(
    values = c(
      "Obs."     = "#1f77b4",
      "CF Early" = "#9ecae1",
      "CF Late"  = "#fdae6b"
    )
  ) +
  scale_linetype_identity() +
  facet_wrap(vars(gender), nrow = 1) +
  labs(
    title = "Figure 1: Observed vs. Counterfactual Earnings",
    x     = "Age",
    y     = "Earnings",
    color = NULL
  ) +
  theme(legend.position = "bottom")

## ----fig2_did-----------------------------------------------------------------
# DID-imputed counterfactual for the early cohort, per gender (women, then men).
# Start at the early cohort's observed level at its last pre-treatment age
# (d_early - 1). Then add the late cohort's counterfactual change since that
# same age.
did_cf <- c(1, 0) %>%
  lapply(function(sex) {
    early_rows <- filter(data, female == sex, D == d_early)
    late_rows  <- filter(data, female == sex, D == d_late)

    anchor_early <- early_rows$y[early_rows$age == d_early - 1]
    anchor_late  <- late_rows$y_0[late_rows$age == d_early - 1]

    tibble(
      age    = late_rows$age,
      female = sex,
      D      = d_early,
      series = "y_cf_did",
      value  = anchor_early + (late_rows$y_0 - anchor_late)
    )
  }) %>%
  bind_rows() %>%
  mutate(
    gender = factor(
      if_else(female == 1, "Women", "Men"),
      levels = c("Women", "Men")
    ),
    cohort       = factor("Early (D=25)", levels = c("Early (D=25)", "Late (D=30)")),
    series_label = "Early — DID-imputed CF",
    linetype     = "dotdash",
    color        = "CF DID"
  )

# Figure 2: Figure 1 plus the DID-imputed counterfactual for the early cohort.
plot_df_plus <- bind_rows(plot_df, did_cf)

ggplot(plot_df_plus, aes(age, value, group = series_label)) +
  geom_line(aes(color = color, linetype = linetype), linewidth = 1.1) +
  scale_color_manual(
    values = c(
      "Obs."     = "#1f77b4",
      "CF Early" = "#9ecae1",
      "CF Late"  = "#fdae6b",
      "CF DID"   = "#2ca02c"
    )
  ) +
  scale_linetype_identity() +
  facet_wrap(vars(gender), nrow = 1) +
  labs(
    title    = "Figure 2: DID-Imputed Counterfactual",
    subtitle = "Green line = what DID thinks would have happened without treatment",
    x        = "Age",
    y        = "Earnings",
    color    = NULL
  ) +
  theme(legend.position = "bottom")

## ----fig3_decomposition-------------------------------------------------------
# True (simulated) quantities for the early cohort, one row per gender x age:
#   APO_true = counterfactual earnings y_0
#   APO_obs  = observed earnings y
#   ATE_true = APO_obs - APO_true (effect on the treated early cohort)
# Each (female, age) cell holds exactly one row, so this is a reshape, not an
# aggregation. transmute() is used instead of group_by() + summarise(), which
# would silently return multi-row groups if a cell were duplicated (the
# deprecation warning is hidden by the chunk's warning = FALSE).
# arrange() keeps the row order summarise() produced (female, then age).
early_true_age <- data %>%
  filter(D == d_early) %>%
  arrange(female, age) %>%
  transmute(female, age, APO_true = y_0, APO_obs = y) %>%
  mutate(ATE_true = APO_obs - APO_true)
stopifnot(!anyDuplicated(early_true_age[c("female", "age")]))

# DID-imputed APO for the early cohort, one row per gender x age.
early_did_age <- did_cf %>%
  arrange(female, age) %>%
  transmute(female, age, APO_did = value)
stopifnot(!anyDuplicated(early_did_age[c("female", "age")]))

# Join truth and DID imputation, then decompose the DID estimate:
#   PT_bias = APO_did - APO_true   (error in the imputed counterfactual)
#   ATE_did = APO_obs - APO_did    (equivalently ATE_true - PT_bias)
summary_age <- left_join(
  early_true_age,
  early_did_age,
  by = c("female", "age")
) %>%
  mutate(
    PT_bias = APO_did - APO_true,
    ATE_did = APO_obs - APO_did,
    gender  = factor(
      if_else(female == 1, "Women", "Men"),
      levels = c("Women", "Men")
    )
  )

# Stacked-bar inputs: each bar splits a DID quantity into its true part and
# the parallel-trends bias part.
#   APO row: APO_did = APO_true + PT_bias
#   ATE row: ATE_did = ATE_true + (-PT_bias)
apo_stack_age <- summary_age %>%
  transmute(
    gender,
    age       = factor(age),
    measure   = factor("APO", levels = c("APO", "ATE")),
    total     = APO_did,
    comp_true = APO_true,
    comp_bias = PT_bias
  ) %>%
  pivot_longer(
    c(comp_true, comp_bias),
    names_to  = "component",
    values_to = "value"
  ) %>%
  mutate(component = if_else(component == "comp_bias", "PT bias", "Truth"))

ate_stack_age <- summary_age %>%
  transmute(
    gender,
    age       = factor(age),
    measure   = factor("ATE", levels = c("APO", "ATE")),
    total     = ATE_did,
    comp_true = ATE_true,
    comp_bias = -PT_bias
  ) %>%
  pivot_longer(
    c(comp_true, comp_bias),
    names_to  = "component",
    values_to = "value"
  ) %>%
  mutate(component = if_else(component == "comp_bias", "PT bias", "Truth"))

bar_age_df <- bind_rows(apo_stack_age, ate_stack_age)

# Figure 3: per-age stacked bars splitting the DID APO (top row) and the DID
# ATE (bottom row) into the true component and the parallel-trends bias.
# The title now carries the "Figure 3:" numbering used by every other figure.
ggplot(bar_age_df, aes(x = age, y = value, fill = component)) +
  geom_col(width = 0.7) +
  facet_grid(rows = vars(measure), cols = vars(gender), scales = "free_y") +
  scale_fill_manual(
    values = c("Truth" = "#1f77b4", "PT bias" = "#fdae6b"),
    breaks = c("Truth", "PT bias")
  ) +
  labs(
    title = "Figure 3: Decomposing DID into Effect and Bias",
    subtitle = "Top row: DID APO | Bottom row: DID ATE",
    x = "Age", y = NULL, fill = NULL
  ) +
  theme(legend.position = "bottom")

## ----fig4_pt_ratio------------------------------------------------------------
# Ratios row for Figure 4: the bias term as it enters the ATE (-PT_bias),
# divided by the DID APO.
pt_ratio_df <- summary_age %>%
  transmute(
    gender, age = factor(age),
    measure = factor("Ratios", levels = c("APO","ATE","Ratios")),
    component = "PT / DID APO",
    value = -PT_bias / APO_did
  )

# Re-level the APO/ATE stacks so all three rows share one factor.
bar_age_three <- bind_rows(
  apo_stack_age %>% mutate(measure = factor("APO", levels = c("APO","ATE","Ratios"))),
  ate_stack_age %>% mutate(measure = factor("ATE", levels = c("APO","ATE","Ratios"))),
  pt_ratio_df
)

# Explicit `breaks` keep the legend in the same Truth-first order as Figures 3
# and 6. Without them the legend falls back to alphabetical order.
ggplot(bar_age_three, aes(x = age, y = value, fill = component)) +
  geom_col(width = 0.7) +
  facet_grid(rows = vars(measure), cols = vars(gender), scales = "free_y") +
  scale_fill_manual(
    values = c(
      "Truth" = "#1f77b4", "PT bias" = "#fdae6b", "PT / DID APO" = "#99d8c9"
    ),
    breaks = c("Truth", "PT bias", "PT / DID APO")
  ) +
  labs(
    title = "Figure 4: Normalizing the PT Bias",
    subtitle = "Bottom row shows PT bias / DID APO",
    x = "Age", y = NULL, fill = NULL
  ) +
  theme(legend.position = "bottom")

## ----fig5_full_ratio----------------------------------------------------------
# Ratios row for Figure 5. The normalized DID ATE splits as
#   ATE_did / APO_did = (-PT_bias / APO_did) + (ATE_true / APO_did).
ratios_df <- summary_age %>%
  transmute(
    gender, age = factor(age),
    measure = factor("Ratios", levels = c("APO","ATE","Ratios")),
    `PT ratio`  = -PT_bias / APO_did,
    `ATE ratio` = ATE_true / APO_did
  ) %>%
  pivot_longer(
    cols = c(`PT ratio`, `ATE ratio`),
    names_to = "component",
    values_to = "value"
  )

# Re-level the APO/ATE stacks so all three rows share one factor.
bar_age_three2 <- bind_rows(
  apo_stack_age %>% mutate(measure = factor("APO", levels = c("APO","ATE","Ratios"))),
  ate_stack_age %>% mutate(measure = factor("ATE", levels = c("APO","ATE","Ratios"))),
  ratios_df
)

# Explicit `breaks` keep the legend in the same Truth-first order as Figures 3
# and 6. Without them the legend falls back to alphabetical order.
ggplot(bar_age_three2, aes(x = age, y = value, fill = component)) +
  geom_col(width = 0.7) +
  facet_grid(rows = vars(measure), cols = vars(gender), scales = "free_y") +
  scale_fill_manual(
    values = c(
      "Truth" = "#1f77b4", "PT bias" = "#fdae6b",
      "PT ratio" = "#99d8c9", "ATE ratio" = "#2ca25f"
    ),
    breaks = c("Truth", "PT bias", "PT ratio", "ATE ratio")
  ) +
  labs(
    title = "Figure 5: Complete Decomposition of the Normalized Ratio",
    subtitle = "Bottom row: Light teal (drops out when differencing) + Dark green (what remains)",
    x = "Age", y = NULL, fill = NULL
  ) +
  theme(legend.position = "bottom")

## ----fig6_multiplicative------------------------------------------------------

## sanity check: ATE_true / APO_did == theta * (1 - PT_bias / APO_did), where
## theta = ATE_true / APO_true. This follows from APO_did = APO_true + PT_bias.
## The identity is asserted, not just printed, so a regression stops the
## build; the table is still printed as before.
ratio_check <- summary_age %>%
  transmute(
    gender, age = factor(age),
    measure = factor("Ratios", levels = c("APO","ATE","Ratios")),
    `ATE ratio` = ATE_true / APO_did,
    `ATE ratio check` = (ATE_true / APO_true) - (ATE_true / APO_true) * (PT_bias / APO_did)
  )
stopifnot(isTRUE(all.equal(ratio_check$`ATE ratio`, ratio_check$`ATE ratio check`)))
ratio_check

# Per-age multiplicative decomposition of the normalized DID ratio:
#   ATE_did / APO_did = theta_true + bias_A + bias_B, with
#   theta_true = ATE_true / APO_true         (true relative effect)
#   bias_A     = -PT_bias / APO_did          (additive bias term)
#   bias_B     = -theta_true * PT_bias / APO_did (bias scaled by the effect)
theta_df <- summary_age %>%
  transmute(
    gender,
    age        = factor(age),
    APO_true,
    ATE_true,
    APO_did,
    PT_bias,
    theta_true = ATE_true / APO_true,
    bias_A     = -PT_bias / APO_did,
    bias_B     = -theta_true * (PT_bias / APO_did)
  )

# Long "Ratios" rows for Figure 6: all Truth rows, then Bias A, then Bias B.
# Built inside local() so the helper objects do not leak into the global env.
theta_stack <- local({
  ratio_measure <- factor("Ratios", levels = c("APO", "ATE", "Ratios"))
  pieces <- list(
    "Truth"  = theta_df$theta_true,
    "Bias A" = theta_df$bias_A,
    "Bias B" = theta_df$bias_B
  )
  bind_rows(
    Map(
      function(label, vals) {
        theta_df %>%
          transmute(
            gender,
            age,
            measure   = ratio_measure,
            component = label,
            value     = vals
          )
      },
      names(pieces),
      pieces
    )
  )
})

# Figure 6: APO and ATE stacks plus the multiplicative ratio decomposition.
# The APO/ATE rows are re-leveled so all three rows share one factor.
three_levels_note <- NULL  # placeholder removed below; keeps no global state
rm(three_levels_note)
bar_age_three3 <- bind_rows(
  mutate(apo_stack_age, measure = factor("APO", levels = c("APO", "ATE", "Ratios"))),
  mutate(ate_stack_age, measure = factor("ATE", levels = c("APO", "ATE", "Ratios"))),
  theta_stack
)

ggplot(bar_age_three3, aes(age, value, fill = component)) +
  geom_col(width = 0.7) +
  scale_fill_manual(
    values = c(
      "Truth"   = "#1f77b4",
      "PT bias" = "#fdae6b",
      "Bias A"  = "#99d8c9",
      "Bias B"  = "#2ca25f"
    ),
    breaks = c("Truth", "PT bias", "Bias A", "Bias B"),
    labels = c("Truth", "PT bias", "PT / DID APO", "Truth × Bias")
  ) +
  facet_grid(rows = vars(measure), cols = vars(gender), scales = "free_y") +
  labs(
    title    = "Figure 6: The Multiplicative Bias Structure",
    subtitle = "Bottom row: Blue (true θ) + Teal (drops out) + Green (multiplicative bias)",
    x        = "Age",
    y        = NULL,
    fill     = NULL
  ) +
  theme(legend.position = "bottom")

