## ----setup, include = FALSE, warning = FALSE----------------------------------
# Global chunk options for the vignette.
# `comment` is the text knitr puts in front of each line of printed output and
# must be a string or NA. A logical FALSE is not "off": knitr pastes it in as
# the literal prefix "FALSE ". NA prints output without any prefix.
knitr::opts_chunk$set(comment = NA,
                      message = FALSE,
                      fig.align = "center")

## ----load-packages, message=FALSE---------------------------------------------
library(cellmig)
library(ggplot2)
library(ggforce)
library(patchwork)
library(rstan)

# Make ggplot2's black-and-white theme (base font size 10) the session-wide
# default, so every plot below picks it up.
ggplot2::theme_set(ggplot2::theme_bw(base_size = 10))

## ----sim-partial--------------------------------------------------------------
# Fix the random-number state so the simulated data set is reproducible.
set.seed(seed = 1253)

# Simulate a data set with the partially generative model: the treatment
# effects and the two variability SDs are given directly, and the prior_*
# entries supply means (_M) and SDs (_SD) for the remaining quantities
# (see ?gen_partial for the exact model).
y_p <- gen_partial(
  control = list(
    N_biorep = 8,     # number of plates
    N_techrep = 5,    # wells per treatment on each plate
    N_cell = 50,      # cells per well
    delta = c(-0.4, -0.2, -0.1, 0, 0.1, 0.2, 0.4), # true treatment effects
    sigma_bio = 0.1,  # biological variability
    sigma_tech = 0.05, # technical variability
    offset = 4,       # index of the control treatment (its delta is 0)
    prior_alpha_p_M = 0,      # mean of the plate baselines
    prior_alpha_p_SD = 0.1,   # SD of the plate baselines
    prior_kappa_mu_M = 1.5,   # mean of log(shape parameter)
    prior_kappa_mu_SD = 0.1,  # SD of log(shape parameter)
    prior_kappa_sigma_M = 0,  # mean of the shape-variability term
    prior_kappa_sigma_SD = 0.1 # SD of the shape-variability term
  )
)

## ----view-sim-data------------------------------------------------------------
# Compact overview of the simulated cell-level table (columns and types).
utils::str(object = y_p$y)

## ----plot-sim-raw, fig.width=7, fig.height=3----------------------------------
# Cell-level speeds per treatment group on a log scale; points are coloured
# by plate and grouped by well.
ggplot(data = y_p$y,
       mapping = aes(x = paste0("t=", group), y = v,
                     colour = paste0("p=", plate), group = well)) +
  geom_sina(size = 0.5) +
  scale_y_log10() +
  scale_color_grey(name = "Plate") +
  labs(x = "Treatment Group", y = "Migration Speed (µm/min)") +
  theme(legend.position = "top")

## ----plot-sim-well, fig.width=7, fig.height=3---------------------------------
# Mean speed of every well, keeping its treatment group and plate labels.
well_means <- aggregate(v ~ well + group + plate, data = y_p$y, FUN = mean)

# One point per well, on the same log scale and colouring as the cell plot.
ggplot(data = well_means,
       mapping = aes(x = paste0("t=", group), y = v,
                     colour = paste0("p=", plate), group = well)) +
  geom_sina(size = 1) +
  scale_y_log10() +
  scale_color_grey(name = "Plate") +
  labs(x = "Treatment Group", y = "Mean Well Speed") +
  theme(legend.position = "top")

## ----format-sim-data----------------------------------------------------------
# Working copy of the simulated cells, prepared for cellmig().
sim_data <- y_p$y

# Identifier columns become character vectors.
sim_data[c("well", "compound", "plate")] <-
  lapply(sim_data[c("well", "compound", "plate")], as.character)

# Flag the control treatment (group 4) with offset = 1, all others with 0.
# `%in%` never yields NA, so rows with a missing group get 0.
sim_data$offset <- as.numeric(sim_data$group %in% 4)

## ----fit-sim-data, fig.width=7, fig.height=3.5--------------------------------
# Fit the cellmig model to the simulated data with a short NUTS run on
# 2 chains / 2 cores. The control entries are handed to cellmig() unchanged;
# see ?cellmig for their exact meaning.
osd <- cellmig(
  x = sim_data,
  control = list(
    mcmc_warmup = 250,
    mcmc_steps = 800,
    mcmc_chains = 2,
    mcmc_cores = 2,
    mcmc_algorithm = "NUTS",
    adapt_delta = 0.8,
    max_treedepth = 10
  )
)

## ----check-alpha, fig.width=5, fig.height=3-----------------------------------
# Plate baselines on the exp() scale: posterior mean (black point) with the
# 2.5%-97.5% posterior interval (bar), and the values used in the simulation
# overlaid in red.
# geom_errorbarh() is deprecated in current ggplot2. A horizontal error bar is
# geom_errorbar(orientation = "y"), where `width` sets the cap size along y
# (the role `height` had in geom_errorbarh()).
ggplot(data = osd$posteriors$alpha_p) +
  geom_point(aes(y = plate_id, x = exp(mean))) +
  geom_errorbar(aes(y = plate_id, xmin = exp(X2.5.), xmax = exp(X97.5.)),
                orientation = "y", width = 0.1) +
  # Overlay true values in red
  geom_point(data = data.frame(v = exp(y_p$par$alpha_p),
                               plate = osd$posteriors$alpha_p$plate_id),
             aes(y = plate, x = v), col = "red", size = 2) +
  xlab(label = expression("Plate Baseline ("*alpha[p]*"')")) +
  ylab("Plate ID") +
  scale_y_continuous(breaks = osd$posteriors$alpha_p$plate_id)

## ----check-delta, fig.width=5, fig.height=3-----------------------------------
# Treatment effects: posterior mean (black point) with the 2.5%-97.5%
# posterior interval (bar), and the simulated true effects in red.
# Group 4 is the control (offset), so it is not estimated and has no row here.
# The red points take their y positions from the posterior table's own
# group_id column instead of a hard-coded 1:6. This keeps each true value on
# the row it belongs to even if the ids skip the control group. It assumes the
# rows are ordered by treatment with the control dropped.
# geom_errorbar(orientation = "y") replaces the deprecated geom_errorbarh().
ggplot(data = osd$posteriors$delta_t) +
  geom_point(aes(y = group_id, x = mean)) +
  geom_errorbar(aes(y = group_id, xmin = X2.5., xmax = X97.5.),
                orientation = "y", width = 0.1) +
  # True values (control excluded)
  geom_point(data = data.frame(delta = c(-0.4, -0.2, -0.1, 0.1, 0.2, 0.4),
                               group_id = osd$posteriors$delta_t$group_id),
             aes(y = group_id, x = delta), col = "red", size = 2) +
  xlab(label = expression("Treatment Effect ("*delta[t]*")")) +
  ylab("Treatment Group") +
  scale_y_continuous(breaks = osd$posteriors$delta_t$group_id)

## ----check-sigma, fig.width=5, fig.height=3-----------------------------------
# Posterior summaries of the three variability parameters.
# The stanfit plot method forwards its arguments to rstan::stan_plot(), whose
# parameter-selection argument is `pars`. The original `par` only worked
# through R's partial argument matching.
plot(osd$fit, pars = c("sigma_bio", "sigma_tech", "sigma_delta"))

## ----sim-full-control---------------------------------------------------------
# Settings for the fully generative simulator: the design size plus prior
# means (_M) and SDs (_SD) for the model's hyperparameters.
control_full <- list(
  # Design: plates, wells per treatment per plate, cells per well, groups
  N_biorep = 4, N_techrep = 3, N_cell = 30, N_group = 8,
  # Plate baselines
  prior_alpha_p_M = -0.5, prior_alpha_p_SD = 1,
  # Gamma shape parameter
  prior_kappa_mu_M = 1.5, prior_kappa_mu_SD = 1,
  prior_kappa_sigma_M = 0, prior_kappa_sigma_SD = 1,
  # Variability terms
  prior_sigma_bio_M = 0, prior_sigma_bio_SD = 1,
  prior_sigma_tech_M = 0, prior_sigma_tech_SD = 1,
  prior_sigma_delta_M = 0, prior_sigma_delta_SD = 1
)

## ----sim-full-----------------------------------------------------------------
# Simulate with the fully generative model, then inspect everything returned.
y_f <- gen_full(control = control_full)
utils::str(object = y_f)

## ----compare-sim-modes, fig.width=5, fig.height=5-----------------------------
# Compare a subset of velocities: pair the first n_cmp speeds of the fully
# and partially generative simulations row by row, on log-log axes.
# n_cmp is capped by the smaller data set. Indexing past the end (as a fixed
# 1:2000 would) silently pads with NA rows, which ggplot2 then drops with
# warnings.
n_cmp <- min(2000, length(y_f$y$v), length(y_p$y$v))
w <- data.frame(v_f = y_f$y$v[seq_len(n_cmp)],
                v_p = y_p$y$v[seq_len(n_cmp)])

# No trailing theme_bw(): the session-wide theme (theme_bw, base_size = 10)
# already applies. Adding theme_bw() again would reset the base font size to
# its default and make this plot inconsistent with the others.
ggplot(data = w) +
  geom_point(aes(x = v_f, y = v_p), size = 0.5, alpha = 0.5) +
  geom_density_2d(aes(x = v_f, y = v_p), col = "orange") +
  scale_x_log10(name = "Fully Generative Speed", limits = c(0.01, 10^4)) +
  scale_y_log10(name = "Partially Generative Speed", limits = c(0.01, 10^4)) +
  annotation_logticks(base = 10, sides = "lb")

## ----power-analysis, eval = FALSE, message=FALSE, warning=FALSE---------------
# # --- Configuration ---
# N_bioreps <- c(3, 6, 9)      # Replicate scenarios to test
# N_sim <- 10                  # Number of simulations per scenario
# true_deltas <- c(-0.3, -0.15, 0, 0.2, 0.4) # Effects to test
# offset <- 3                  # Control group index
# 
# # Store results
# deltas <- vector(mode = "list", length = length(N_bioreps) * N_sim)
# i <- 1
# 
# for(N_biorep in N_bioreps) {
#   for(b in 1:N_sim) {
# 
#     # 1. Simulate data
#     y_p <- gen_partial(control = list(
#       N_biorep = N_biorep,
#       N_techrep = 3,
#       N_cell = 40,
#       delta = true_deltas,
#       sigma_bio = 0.1,
#       sigma_tech = 0.05,
#       offset = offset,
#       prior_alpha_p_M = -0.5,
#       prior_alpha_p_SD = 0.1,
#       prior_kappa_mu_M = 1.5,
#       prior_kappa_mu_SD = 0.1,
#       prior_kappa_sigma_M = 0,
#       prior_kappa_sigma_SD = 0.1
#     ))
# 
#     # 2. Format data
#     sim_data <- y_p$y
#     sim_data$well <- as.character(sim_data$well)
#     sim_data$compound <- as.character(sim_data$compound)
#     sim_data$plate <- as.character(sim_data$plate)
#     sim_data$offset <- 0
#     sim_data$offset[sim_data$group == offset] <- 1
# 
#     # 3. Fit model
#     o <- cellmig(x = sim_data,
#                  control = list(mcmc_warmup = 300,
#                                 mcmc_steps = 1000,
#                                 mcmc_chains = 1,
#                                 mcmc_cores = 1, # Raise together with mcmc_chains for speed
#                                 mcmc_algorithm = "NUTS",
#                                 adapt_delta = 0.8,
#                                 max_treedepth = 10))
# 
#     # 4. Evaluate Performance
#     delta <- o$posteriors$delta_t
#     delta$b <- b
#     delta$N_biorep <- N_biorep
#     delta$true_deltas <- true_deltas[-offset]
# 
#     # True Positive: 95% credible interval (2.5%-97.5% quantiles) excludes 0 AND includes true value
#     delta$TP <- (delta$X2.5. <= delta$true_deltas &
#                    delta$X97.5. >= delta$true_deltas) &
#       !(delta$X2.5. <= 0 & delta$X97.5. >= 0)
# 
#     deltas[[i]] <- delta
#     i <- i + 1
#   }
# }
# 
# # Combine results
# deltas <- do.call(rbind, deltas)

## ----plot-power, eval = FALSE, fig.height=4, fig.width = 7--------------------
# ggplot(data = aggregate(TP ~ N_biorep + true_deltas, data = deltas, FUN = sum)) +
#   geom_point(aes(x = N_biorep, y = TP, col = abs(true_deltas),
#                  group = as.factor(true_deltas)), size = 2, alpha = 0.5) +
#   geom_line(aes(x = N_biorep, y = TP, col = abs(true_deltas),
#                 group = as.factor(true_deltas)), alpha = 0.5) +
#   ylab("Number of True Positives (TPs)") +
#   xlab("Number of Biological Replicates") +
#   scale_color_distiller(name = expression("|"*delta[t]*"|"), palette = "Spectral") +
#   theme_bw()

## -----------------------------------------------------------------------------
# Record the R version and loaded packages used to build this document.
print(utils::sessionInfo())

