## ----setup, include = FALSE, warning = FALSE----------------------------------
# Global chunk defaults for the vignette.
# `comment` is the character prefix knitr puts before every line of printed
# output; it must be a string, and "" removes the default "##" prefix
# (a logical FALSE would be used literally as the text "FALSE").
# Warnings and messages are hidden in the rendered document.
knitr::opts_chunk$set(comment = "",
                      warning = FALSE,
                      message = FALSE)

## -----------------------------------------------------------------------------
library(cellmig)
library(ggplot2)
library(ggforce)
library(rstan)
# Every ggplot below uses the black-and-white theme with 10-pt base text.
theme_set(theme_bw(base_size = 10))

## ----eval=FALSE---------------------------------------------------------------
# if (!require("BiocManager", quietly = TRUE))
#     install.packages("BiocManager")
# 
# BiocManager::install("cellmig")

## ----load-data----------------------------------------------------------------
# Load the example dataset `d` shipped with cellmig, then inspect its
# structure and its first rows.
data(list = "d", package = "cellmig")
str(object = d)
head(x = d)

## ----fig.width=7, fig.height=6------------------------------------------------
# Cell-level migration velocity `v` per dose, one panel per compound,
# coloured by plate and drawn on a log10 y axis.
# The global theme (theme_bw, base_size = 10) is inherited rather than
# re-added here, so this figure matches the others. Only a colour guide is
# customised because no shape aesthetic is mapped.
ggplot(data = d) +
  facet_wrap(facets = ~paste0("compound=", compound),
             scales = "free_y", ncol = 2) +
  geom_sina(aes(x = as.factor(dose), col = plate, y = v, group = well),
            size = 0.5) +
  theme(legend.position = "top",
        strip.text.x = element_text(
          margin = margin(0.03, 0, 0.03, 0, "cm")
        )) +
  ylab(label = "migration velocity") +
  xlab(label = "") +
  scale_color_grey() +
  # Enlarge legend keys so the grey plate colours are distinguishable.
  guides(color = guide_legend(override.aes = list(size = 3))) +
  scale_y_log10() +
  annotation_logticks(base = 10, sides = "l")

## ----fig.width=7, fig.height=6------------------------------------------------
# Mean velocity per well: one row per well/plate/compound/dose combination.
dm <- aggregate(v ~ well + plate + compound + dose, data = d, FUN = mean)

# Well-level mean velocities per dose, one panel per compound, coloured by
# plate and drawn on a log10 y axis.
# The global theme (theme_bw, base_size = 10) is inherited rather than
# re-added here. Only a colour guide is customised because no shape
# aesthetic is mapped.
ggplot(data = dm) +
  facet_wrap(facets = ~paste0("compound=", compound),
             scales = "free_y", ncol = 2) +
  geom_sina(aes(x = as.factor(dose), col = plate, y = v, group = well),
            size = 1.5, alpha = 0.7) +
  theme(legend.position = "top",
        strip.text.x = element_text(
          margin = margin(0.03, 0, 0.03, 0, "cm")
        )) +
  ylab(label = "migration velocity") +
  xlab(label = "") +
  scale_color_grey() +
  # Enlarge legend keys so the grey plate colours are distinguishable.
  guides(color = guide_legend(override.aes = list(size = 3))) +
  scale_y_log10() +
  annotation_logticks(base = 10, sides = "l")

## ----fit-model, fig.width=7, fig.height=3.5-----------------------------------
# Fit the cellmig model to `d` with a short MCMC run:
# 300 warmup iterations, 1000 sampling iterations, 2 chains on 2 cores.
o <- cellmig(
  x = d,
  control = list(
    mcmc_warmup = 300,
    mcmc_steps = 1000,
    mcmc_chains = 2,
    mcmc_cores = 2
  )
)

## ----view-delta---------------------------------------------------------------
# Tabulate the posterior summary of the treatment effects, rounded to 2 digits.
knitr::kable(x = o$posteriors$delta_t, digits = 2)

## ----plot-delta, fig.width=6, fig.height=4------------------------------------
# Treatment effects on the log scale: posterior mean per dose with interval
# bounds from columns X2.5. / X97.5. (presumably the 2.5% and 97.5%
# posterior quantiles), one coloured line per compound.
ggplot(data = o$posteriors$delta_t,
       mapping = aes(x = dose, y = mean, col = compound)) +
  geom_line(aes(group = compound)) +
  geom_point() +
  geom_errorbar(aes(ymin = X2.5., ymax = X97.5.), width = 0.1) +
  ylab(label = expression("Log-Fold Change ("*delta*")")) +
  xlab("Dose") +
  theme(legend.position = "top")

## ----plot-fold-change, fig.width=6, fig.height=4------------------------------
# The same treatment effects back-transformed with exp(): the mean and both
# interval bounds are exponentiated to show fold changes per dose.
ggplot(data = o$posteriors$delta_t,
       mapping = aes(x = dose, y = exp(mean), col = compound)) +
  geom_line(aes(group = compound)) +
  geom_point() +
  geom_errorbar(aes(ymin = exp(X2.5.), ymax = exp(X97.5.)), width = 0.1) +
  ylab(label = expression("Fold Change ("*delta*"')")) +
  xlab("Dose") +
  theme(legend.position = "top")

## -----------------------------------------------------------------------------
# Pairwise comparisons between treatment groups, kept on the log scale
u <- get_pairs(exponentiate = FALSE, x = o)

## -----------------------------------------------------------------------------
# Visualize the matrix of rho values
print(u$plot_rho)

## -----------------------------------------------------------------------------
# Visualize the matrix of pi values
print(u$plot_pi)

## -----------------------------------------------------------------------------
# Visualize the volcano plot
print(u$plot_volcano)

## ----get-groups---------------------------------------------------------------
# Collect the group labels known to the fitted model and show the first six
groups <- get_groups(x = o)
head(groups, n = 6L)

## ----plot-violins, fig.width=7, fig.height=3----------------------------------
# Contrast every group against the reference "C2|D1" (compound 2, dose 1)
# on the log scale, then draw the resulting violin plot.
u_violin <- get_violins(
  x = o,
  from_groups = groups$group,
  to_group = "C2|D1",
  exponentiate = FALSE
)
print(u_violin$plot)

## ----ppc-cell, fig.width=6, fig.height=9--------------------------------------
# Cell-level posterior predictive check in 3 wrapped columns, log10 y axis
g <- get_ppc_violins(x = o, ncol = 3, wrap = TRUE)
g + scale_y_log10()

## ----ppc-well, fig.width=5, fig.height=5--------------------------------------
# Posterior predictive check of the means; assign and display in one step
(g <- get_ppc_means(x = o))

## -----------------------------------------------------------------------------
# Approximate leave-one-out cross-validation of the fitted Stan model
loo_o <- rstan::loo(o$fit)

# Summary of the LOO estimates together with the Pareto k diagnostic table
print(x = loo_o)

# Pareto k estimate for every cell
plot(x = loo_o)

# Indices of cells whose Pareto k exceeds 0.7; these deserve a closer look
which(x = loo_o$pointwise[, "influence_pareto_k"] > 0.7)

## ----rhat, fig.width=3, fig.height=3------------------------------------------
# Distribution of the R-hat convergence statistic across model parameters
rstan::stan_rhat(object = o$fit)

## ----mcmc-summary-------------------------------------------------------------
# Per-quantity posterior summary matrix; histogram of its effective sample
# size column (n_eff)
summary_stats <- summary(object = o$fit)$summary
hist(x = summary_stats[, "n_eff"])

## -----------------------------------------------------------------------------
# Sampler health checks: divergences, tree depth saturation and E-BFMI
rstan::check_hmc_diagnostics(o$fit)

## ----plot-variance, fig.height=3, fig.width=7---------------------------------
# Plate-specific baseline effects: posterior mean per plate with interval
# bounds from columns X2.5. / X97.5. (presumably 2.5% / 97.5% quantiles).
# geom_errorbar(orientation = "y") replaces the deprecated geom_errorbarh();
# `width` sets the cap size along the discrete y axis.
g_alpha_p <- ggplot(data = o$posteriors$alpha_p) +
  geom_errorbar(aes(y = plate, xmin = X2.5., xmax = X97.5.),
                width = 0.2, orientation = "y") +
  geom_point(aes(y = plate, x = mean)) +
  xlab("Plate Effect (log-scale)")

# Variance parameters (Biological vs Technical vs Treatment).
# The three posterior summaries are stacked into one long table, so a single
# error-bar layer and a single point layer draw all of them.
sigma_parts <- list(
  "Biological (Plate)"  = o$posteriors$sigma_bio,
  "Technical (Well)"    = o$posteriors$sigma_tech,
  "Treatment Variation" = o$posteriors$sigma_delta
)
sigma_df <- do.call(rbind, Map(function(s, component) {
  data.frame(component = component, mean = s$mean,
             lower = s$X2.5., upper = s$X97.5.)
}, sigma_parts, names(sigma_parts)))

g_sigma <- ggplot(data = sigma_df, mapping = aes(y = component)) +
  geom_errorbar(aes(xmin = lower, xmax = upper),
                width = 0.2, orientation = "y") +
  geom_point(aes(x = mean)) +
  xlab("Standard Deviation") +
  # The component names label the axis themselves; no y title is needed.
  ylab(NULL)

# Place the two panels side by side. The patchwork call is explicit because
# patchwork is never attached with library() in this script.
patchwork::wrap_plots(g_alpha_p, g_sigma, nrow = 1)

## ----dose-response, fig.width=8, fig.height=5---------------------------------
# Dose-response profile on the fold-change scale, with relative column
# widths 0.7 : 1 : 2
get_dose_response_profile(exponentiate = TRUE, x = o) +
  patchwork::plot_layout(widths = c(0.7, 1, 2))

## -----------------------------------------------------------------------------
# Record the R version, platform and loaded packages for reproducibility
utils::sessionInfo()

