# title: packages and functions for the molecular atlas of human granulopoiesis
# Last update: 28.10.24


# Packages ----------------------------------------------------------------

# Package manifests, grouped by analysis step. Nothing is installed or
# attached in this section (the install lines are commented out). Several
# entries are Bioconductor packages, which need BiocManager::install()
# rather than install.packages(). The trailing NULL entries are dropped by
# c(); they only let every real entry end with a comma.
packages_general <- c(
  'tidyverse', # always tidy
  'purrr',
  'magrittr', # put it in a pipe
  'broom',
  'digest', # create hash for data identity
  'janitor', # format coherent names
  'clipr', # copy/paste
  'ggalt', # arrange ggplots
  'ggpubr',  # annotate plots
  'colorspace',
  'readxl', # import Excel data
  'data.table', # formatting that tidy can't do
  "ggrepel", # non-overlapping labels in plots
  'Biobase', # Bioconductor
  'cowplot',
  'stevemisc', # helpful functions
  'remedy',
  'tinytex',
  'scales', # different scales in ggplot
  'ggh4x',
  'plyr', # note: masks several dplyr verbs if attached after dplyr
  'ggpp',
  'SOfun',  # data shaping package
  'makeunique',
  'knitr',
  NULL
)
# install.packages(packages_general)

packages_tables <- c(
  'DT',
  'rmarkdown',
  'formattable',
  'gt', # make tables
  "gtExtras", # gt table options
  'webshot2' # save tables as files
)

# if (!require("BiocManager", quietly = TRUE)) install.packages("BiocManager")

packages_missingValues <- c(
  "naniar",
  'visdat',
  NULL
)


packages_comparisons <- c(
  'UpSetR', # compare sets (UpSet plots)
  'ComplexUpset',
  'qdapTools',
  NULL
)

packages_mRNA <- c(
  'DESeq2', # Bioconductor
  'apeglm',
  "vsn"
)

packages_normalisation <- c(
  'vsn', # @Bioconductor
  'hexbin',
  'preprocessCore', #@Bioconductor
  'limma', # @Bioconductor
  'corrplot',
  'compositions',
  'DEGreport',
  'EDASeq'
)

packages_annotation <- c(
  'AnnotationDbi', # anno
  'org.Hs.eg.db', # anno
  'clusterProfiler') # anno KEGG


packages_enrichment <- c(
  'ReactomePA', # for reactome anno from cluster profiler
  'GOSemSim',
  'AnnotationDbi', # anno
  'org.Hs.eg.db', # anno
  'clusterProfiler',
  'DOSE',
  "enrichplot",
  'msigdbr',
  "GSEABase",
  'viper',
  'goseq')

# install.packages(packages_enrichment)  # Bioconductor entries need BiocManager::install()

packages_pca <- c(
  'PCAtools', # PCAs easy
  'factoextra', # diverse PCA plots
  'Rtsne' ) # t-SNE

packages_cluster <- c(
  'factoextra',
  'cluster',
  'dendextend', #dendrograms
  'ape' # phylogenetic trees
)

packages_heatmap <- c(
  # 'pheatmap',
  'dendextend',
  'RColorBrewer',
  'viridis',
  'dendsort',
  'ComplexHeatmap', # complex heatmaps
  'circlize', # colour functions for ComplexHeatmap
  'cluster',# clustering for ComplexHeatmap
  'magick',
  'seriation', # clustering of complex heatmap based on seriation
  NULL
)

packages_dex <- c(
  'limma',
  'statmod',
  'ffmanova',
  NULL
)

packages_dexViz <- c(
  'ggridges',
  "ggpointdensity",
  'vidger',
  NULL
)

packages_WGCNA <- c(
  'WGCNA',
  'CorLevelPlot',
  'gridExtra',
  NULL
)

packages_timecourse <- c(
  'maSigPro'
)

packages_mixomics <- c(
  'mixOmics',
  'caret'
)

packages_modeling <- c(
  'caret',
  'splines',
  'mgcv',
  'tidyclust',
  'tidymodels',
  NULL
)


# overwrite masked function to reinstate dplyr ----
# Bind the commonly masked verbs to their dplyr versions in this environment.
# Other packages in the manifests (e.g. plyr) export functions with the same
# names; after this, bare calls such as select() or mutate() always mean
# dplyr's version, whatever the package load order.
invisible(list2env(
  lapply(
    setNames(nm = c("select", "rename", "group_by", "summarise", "mutate")),
    function(verb) getExportedValue("dplyr", verb)
  ),
  envir = environment()
))



# Colors ----

## stage ----
# Stage palette, named from the earliest (MB) to the latest (PMN) stage.
# The hex codes are listed latest-first and reversed with rev().
colors_stage <- rev(c(
  "#c8a600",
  "#ffac40",
  "#ff7763",
  "#ca0068",
  "#a80092",
  "#a043ec",
  "#0000e3"))

names(colors_stage) <- c( "MB", "PM", "MC", "MM", "B", "S", "PMN")


# Collapsed palette: MB, PM, MC and one pooled "Mature" class.
colors_stage.c <- rev(c(
  "#c8a600",
  "#a80092",
  "#a043ec",
  "#0000e3"))

names(colors_stage.c) <- c( "MB", "PM", "MC", "Mature")


# Transition colours: one per step between consecutive stages, each coloured
# like the source stage of the step (e.g. "MB->PM" uses the MB colour).
# Six steps need exactly six colours; a seventh colour would be left with an
# NA name.
colors_progression.arrow <- rev(c(
  "#ffac40",
  "#ff7763",
  "#ca0068",
  "#a80092",
  "#a043ec",
  "#0000e3"))

names(colors_progression.arrow) <- c( "MB->PM", "PM->MC", "MC->MM", "MM->B", "B->S", "S->PMN")


# Same transition colours, keyed by the "XtoY" contrast names.
colors_progression.to <- rev(c(
  "#ffac40",
  "#ff7763",
  "#ca0068",
  "#a80092",
  "#a043ec",
  "#0000e3"))

names(colors_progression.to) <- c( "MBtoPM", "PMtoMC", "MCtoMM", "MMtoB", "BtoS", "StoPMN")


# Stage palette with MM, B and S dropped (positions 4-6); the remaining last
# colour (PMN) is relabelled as the pooled final class.
colors_stage_uniFin <- colors_stage[-c(4:6)]
# pal(colors_stage_uniFin)
names(colors_stage_uniFin) <- c(
  "MB", "PM", "MC", "fin (MM/B/PMN)")


# Transition colours for the `mic` series, which has no PM stage:
# five steps, five colours.
colors_progression.arrow.mic <- rev(c(
  "#ff7763",
  "#ca0068",
  "#a80092",
  "#a043ec",
  "#0000e3"))

names(colors_progression.arrow.mic) <- c( "MB->MC", "MC->MM", "MM->B", "B->S", "S->PMN")


# Colours for counts of up- and down-regulated features.
colors_UpDown_n <- c(
  '#f89540',
  '#7e03a8'
)

names(colors_UpDown_n) <- c('n_up', 'n_down')

color_up <- c('#f89540')
color_down <- c('#7e03a8')



# graphics ----


library(ggplot2)

# Theme fragments that blank axis titles; add them to a plot with `+`.
theme_nada <- theme(axis.title.x = element_blank(),
                    axis.title.y = element_blank())
theme_noY <- theme(axis.title.y = element_blank())
theme_noX <- theme(axis.title.x = element_blank())

# Point shapes for the two data types (mrn & pro):
# 1 = open circle, 16 = filled circle.
shps <- c(1, 16)





# Groups ----

# Stage order per data type.
levels_mrn <- c("MB", "PM", "MC", "MM", "B", "S", "PMN")
levels_pro <- c("MB", "PM", "MC", "MM", "B", "PMN")
levels_mic <- c("MB", "MC", "MM", "B", "S", "PMN")

# Consecutive-stage contrasts are built from the stage orders above:
# "<from>to<to>" contrast names and "<from> -> <to>" display labels.
levels_mrn_prog_to <- paste0(head(levels_mrn, -1), "to", tail(levels_mrn, -1))

comp_progression_mrn_to <- levels_mrn_prog_to
comp_progression_mrn_arrow <- paste(head(levels_mrn, -1), tail(levels_mrn, -1), sep = " -> ")
# NOTE(review): "PMvPMN" breaks the "<stage>vsPMN" pattern of its neighbours.
# It is kept verbatim because it presumably has to match contrast names
# defined elsewhere — confirm before renaming.
comp_vsPMN_mrn <- c("MBvsPMN", "PMvPMN", "MCvsPMN", "MMvsPMN", "BvsPMN", "SvsPMN")

# mRNA progression, but ending in B -> PMN like the protein series.
comp_progression_mrn_to_likePro <- c(head(comp_progression_mrn_to, -1), "BtoPMN")

comp_progression_pro_to <- paste0(head(levels_pro, -1), "to", tail(levels_pro, -1))
comp_progression_pro_arrow <- paste(head(levels_pro, -1), tail(levels_pro, -1), sep = " -> ")
comp_vsPMN_pro <- head(comp_vsPMN_mrn, -1)


comp_progression_mic_to <- paste0(head(levels_mic, -1), "to", tail(levels_mic, -1))
comp_vsPMN_mic <- paste0(head(levels_mic, -1), "vsPMN")



# Own functions ----

## general ----

# convert df to matrix
# df2m <- function(X) {
#   if (!methods::is(X, "matrix")) {
#     m <- as.matrix(X[, which(vapply(X, is.numeric, logical(1)))])
#   }
#   else {
#     m <- X
#   }
#   m
# }

# Annotate a table keyed by an `ENSG` column with SYMBOL, UNIPROT and ENTREZ
# columns looked up in the global annotation table `fan`.
# Unmatched IDs get NA; for IDs listed several times in `fan` the first
# occurrence wins (match() semantics).
add.fan2ENSG <- function(df){
  hit <- match(df$ENSG, fan$ENSG)
  df$SYMBOL <- fan$SYMBOL[hit]
  df$UNIPROT <- fan$UNIPROT[hit]
  df$ENTREZ <- fan$ENTREZ[hit]
  df
}

# Annotate a table keyed by a `UNIPROT` column with SYMBOL, ENSG and ENTREZ
# columns looked up in the global annotation table `fan`.
# Unmatched IDs get NA; the first matching row of `fan` is used.
add.fan2UNIPROT <- function(df){
  hit <- match(df$UNIPROT, fan$UNIPROT)
  df$SYMBOL <- fan$SYMBOL[hit]
  df$ENSG <- fan$ENSG[hit]
  df$ENTREZ <- fan$ENTREZ[hit]
  df
}

# Annotate a table keyed by a `SYMBOL` column with ENSG and UNIPROT columns
# looked up in the global annotation table `fan`.
# Unmatched symbols get NA; the first matching row of `fan` is used.
add.fan2SYMBOL <- function(df){
  hit <- match(df$SYMBOL, fan$SYMBOL)
  df$ENSG <- fan$ENSG[hit]
  df$UNIPROT <- fan$UNIPROT[hit]
  # disabled presence flags (depend on exp_mrnWc / exp_proWc):
  #df$mrn.present <- ifelse(df$ENSG %in% exp_mrnWc$ENSG, T, F)
  #df$pro.present <- ifelse(df$ENSG %in% exp_proWc$ENSG, T, F)
  df
}


# quickcheck expression values
# Box plot with jittered points of one gene's values per stage, for a quick
# visual check.
#
# @param expL Long table with columns SYMBOL, stage and value.
# @param SYM Gene symbol to plot.
# @return A ggplot object, filled with the global `colors_stage` palette.
plotL.SYMBOL <- function(expL, SYM){
  gene_rows <- filter(expL, SYMBOL == SYM)
  ggplot(gene_rows, aes(x = stage, y = value, fill = stage)) +
    geom_boxplot() +
    geom_jitter(color = "black", size = 0.4, alpha = 0.9) +
    scale_fill_manual(values = colors_stage)
}




# turn fid into rownames
# Move the `fid` (feature ID) column of a table into its row names.
#
# @param df data.frame or tibble with a `fid` column.
# @return A base data.frame with row names taken from `fid` and the `fid`
#   column removed. Errors if `fid` is missing (duplicate IDs also error,
#   via `rownames<-`).
FIDtoRow <- function(df){
  df <- as.data.frame(df)
  # Exact lookup: `df$fid` on a data.frame falls back to partial matching
  # (e.g. a `fid_x` column), and a missing column would silently reset the
  # row names to 1..n.
  if (!"fid" %in% names(df)) {
    stop("`df` has no 'fid' column", call. = FALSE)
  }
  rownames(df) <- df[["fid"]]
  df[["fid"]] <- NULL
  df
}

# Move the `id` (sample ID) column of a table into its row names.
# Identical to Id2Row().
#
# @param df data.frame or tibble with an `id` column.
# @return A base data.frame with row names taken from `id` and the `id`
#   column removed. Errors if `id` is missing.
IDtoRow <- function(df){
  df <- as.data.frame(df)
  # Exact lookup: `df$id` would partial-match e.g. `id_primary_sample` when
  # no `id` column exists.
  if (!"id" %in% names(df)) {
    stop("`df` has no 'id' column", call. = FALSE)
  }
  rownames(df) <- df[["id"]]
  df[["id"]] <- NULL
  df
}


# Rename the feature columns of an expression table from ENSG or UNIPROT IDs
# to gene symbols.
#
# @param exp Expression data whose column names are feature IDs.
# @param fan Annotation table with SYMBOL plus ENSG and/or UNIPROT columns.
# @param fid ID type currently used as column names: "ENSG" or "UNIPROT".
# @return `exp` with column names replaced by the matching SYMBOL (NA where
#   no match). Errors for any other `fid` value.
makeFIDcols_SYMBOL <- function(exp, fan, fid){
  # Fail loudly: continuing would hand back `exp` with its ID column names
  # unchanged, which is easy to miss downstream.
  if (!(is.character(fid) && length(fid) == 1L && fid %in% c("ENSG", "UNIPROT"))) {
    stop("Please define fid cols as ENSG or UNIPROT", call. = FALSE)
  }
  colnames(exp) <- fan$SYMBOL[match(colnames(exp), fan[[fid]])]
  exp
}

# add stage
# Attach each row's stage (looked up by `id` in the sample annotation) as a
# factor column with the given level order.
#
# @param df Table with an `id` column.
# @param san Sample annotation with `id` and `stage` columns.
# @param levels Factor levels for `stage`, in the desired order.
# @return `df` with a `stage` factor column (NA for unmatched ids).
add_stage <- function(df, san, levels){
  stage_of_row <- san$stage[match(df$id, san$id)]
  df$stage <- factor(stage_of_row, levels = levels)
  df
}

# create list with pattern
# Collect every object whose name starts with "df" into a named list
# (element names = object names).
# NOTE(review): this runs once, when this file is sourced, and only sees the
# `df*` objects that already exist in the calling environment at that moment;
# objects created later are not included. Consider moving it to the script
# that creates those data frames.
df_list <- mget(ls(pattern= "^df"))



# form medians
# Median value per feature and stage.
#
# @param exp Samples x features data frame (sample IDs as row names, feature
#   IDs as column names).
# @param san Sample annotation with `id` and `stage` columns.
# @return A tibble with columns fid, stage and median, still grouped by fid
#   (the grouping summarise() leaves behind).
turn_exp_into_long_medians <- function(exp, san){
  exp$id <- rownames(exp)
  long <- pivot_longer(exp, !id, names_to = "fid", values_to = "value")
  long$stage <- san$stage[match(long$id, san$id)]
  long %>%
    group_by(fid, stage) %>%
    dplyr::summarise(median = median(value, na.rm = TRUE))
}


# Translate Ensembl gene IDs to ENTREZ IDs using the global annotation table
# `fan` (NA where no match; first match wins).
#
# @param vector.ENSG Character vector of ENSG IDs.
# @return Vector of ENTREZ IDs, same length and order as the input. It is
#   returned visibly; the previous version ended in an assignment and so
#   returned its value invisibly.
turn.ENSGintoENTREZ <- function(vector.ENSG){
  fan$ENTREZ[match(vector.ENSG, fan$ENSG)]
}

# factorize san
# Convert every column of a sample-annotation table to a factor, except the
# identifier columns `id` and `id_primary_sample`.
#
# @param san Sample annotation (data.frame or tibble).
# @return `san` with all non-identifier columns converted by as.factor().
san_factorisator <- function(san){
  keep_as_is <- c("id", "id_primary_sample")
  to_factor <- Filter(function(col) !(col %in% keep_as_is), names(san))
  san[to_factor] <- lapply(san[to_factor], as.factor)
  san
}

# Move the `UNIPROT` column of a table into its row names.
#
# @param df data.frame or tibble with a `UNIPROT` column.
# @return A base data.frame with row names from `UNIPROT` and that column
#   removed. Errors if `UNIPROT` is missing.
UniToRow <- function(df){
  df <- as.data.frame(df)
  # Exact lookup instead of `df$UNIPROT`, which would partial-match or
  # silently reset the row names when the column is absent.
  if (!"UNIPROT" %in% names(df)) {
    stop("`df` has no 'UNIPROT' column", call. = FALSE)
  }
  rownames(df) <- df[["UNIPROT"]]
  df[["UNIPROT"]] <- NULL
  df
}

# Move the `ENSG` column of a table into its row names.
#
# @param df data.frame or tibble with an `ENSG` column.
# @return A base data.frame with row names from `ENSG` and that column
#   removed. Errors if `ENSG` is missing.
ENSG2Row <- function(df){
  df <- as.data.frame(df)
  # Exact lookup instead of `df$ENSG`, which would partial-match or silently
  # reset the row names when the column is absent.
  if (!"ENSG" %in% names(df)) {
    stop("`df` has no 'ENSG' column", call. = FALSE)
  }
  rownames(df) <- df[["ENSG"]]
  df[["ENSG"]] <- NULL
  df
}

# Move the `SYMBOL` column of a table into its row names.
#
# @param df data.frame or tibble with a `SYMBOL` column.
# @return A base data.frame with row names from `SYMBOL` and that column
#   removed. Errors if `SYMBOL` is missing (duplicate symbols also error,
#   via `rownames<-`).
SYM2Row <- function(df){
  df <- as.data.frame(df)
  # Exact lookup instead of `df$SYMBOL`, which would partial-match or
  # silently reset the row names when the column is absent.
  if (!"SYMBOL" %in% names(df)) {
    stop("`df` has no 'SYMBOL' column", call. = FALSE)
  }
  rownames(df) <- df[["SYMBOL"]]
  df[["SYMBOL"]] <- NULL
  df
}

# Move the `id` (sample ID) column of a table into its row names.
# Identical to IDtoRow().
#
# @param df data.frame or tibble with an `id` column.
# @return A base data.frame with row names from `id` and that column
#   removed. Errors if `id` is missing.
Id2Row <- function(df){
  df <- as.data.frame(df)
  # Exact lookup: `df$id` would partial-match e.g. `id_primary_sample` when
  # no `id` column exists.
  if (!"id" %in% names(df)) {
    stop("`df` has no 'id' column", call. = FALSE)
  }
  rownames(df) <- df[["id"]]
  df[["id"]] <- NULL
  df
}

# Transpose a table that has a `UNIPROT` ID column, so that the UNIPROT IDs
# become the column names and the remaining columns become rows.
#
# @param df data.frame or tibble with a `UNIPROT` column.
# @return data.frame of t() applied to the remaining columns (values are
#   coerced to a common type by t()), with UNIPROT IDs as column names.
transpose_UNIPROTcols <- function(df){
  # Exact lookup: `df$UNIPROT` would partial-match another column, and a
  # missing column would leave the result without proper column names.
  if (!"UNIPROT" %in% names(df)) {
    stop("`df` has no 'UNIPROT' column", call. = FALSE)
  }
  ids <- df[["UNIPROT"]]
  df[["UNIPROT"]] <- NULL
  df_t <- as.data.frame(t(df))
  colnames(df_t) <- ids
  df_t
}

# Transpose a table that has a `fid` (feature ID) column, so that the feature
# IDs become the column names and the remaining columns become rows.
#
# @param df data.frame or tibble with a `fid` column.
# @return data.frame of t() applied to the remaining columns, with feature
#   IDs as column names.
transpose_FIDcols <- function(df){
  # Exact lookup: `df$fid` would partial-match another column, and a missing
  # column would leave the result without proper column names.
  if (!"fid" %in% names(df)) {
    stop("`df` has no 'fid' column", call. = FALSE)
  }
  ids <- df[["fid"]]
  df[["fid"]] <- NULL
  df_t <- as.data.frame(t(df))
  colnames(df_t) <- ids
  df_t
}

# Transpose a table that has a `SYMBOL` column, so that the gene symbols
# become the column names and the remaining columns become rows.
#
# @param df data.frame or tibble with a `SYMBOL` column.
# @return data.frame of t() applied to the remaining columns, with symbols
#   as column names.
transpose_2SYMBOLcols <- function(df){
  # Exact lookup: `df$SYMBOL` would partial-match another column, and a
  # missing column would leave the result without proper column names.
  if (!"SYMBOL" %in% names(df)) {
    stop("`df` has no 'SYMBOL' column", call. = FALSE)
  }
  ids <- df[["SYMBOL"]]
  df[["SYMBOL"]] <- NULL
  df_t <- as.data.frame(t(df))
  colnames(df_t) <- ids
  df_t
}


# extract significant results
# Keep the rows of a differential-expression table with
# adj.P.Val < adj.p and |logFC| >= FC.
#
# @param df Table with `adj.P.Val` and `logFC` columns (e.g. a topTable result).
# @param adj.p Adjusted p-value cutoff (strict).
# @param FC Absolute log-fold-change cutoff (inclusive).
# @return The matching rows of `df`. Rows with a missing adj.P.Val or logFC
#   are dropped; indexing with NA would otherwise create all-NA rows.
get_sigs <- function(df, adj.p, FC){
  keep <- df$adj.P.Val < adj.p & abs(df$logFC) >= FC
  keep[is.na(keep)] <- FALSE
  df[keep, , drop = FALSE]
}

# normalize by healthy expression
# Express every sample relative to the healthy samples by subtracting, for
# each feature (row), the median over the healthy samples. On log-scale data
# this gives log fold-changes versus healthy.
#
# @param exp_df Features x samples matrix or data.frame (numeric).
# @param ids_healthy Column names of the healthy samples.
# @return `exp_df` minus the per-row healthy median.
create_hdFCs <- function(exp_df, ids_healthy) {
  # drop = FALSE keeps a one-column matrix when only one healthy sample
  # matches; a plain vector would make apply() below fail.
  exp_hd <- exp_df[, colnames(exp_df) %in% ids_healthy, drop = FALSE]
  median_hd <- apply(exp_hd, 1, median, na.rm = TRUE)
  # one median per row, recycled down every column
  exp_df - median_hd
}


# Uniques
# TRUE when `vector` contains no duplicated elements (NA counts as a value,
# so two NAs are a duplicate); TRUE for empty input.
isUnique <- function(vector){
  anyDuplicated(vector) == 0L
}

# NA in features per group
# Count non-missing values per feature and sample group, and keep the
# features that have more than one non-missing value in more than one group.
#
# @param exp_raw Samples x features data (sample IDs as row names, feature
#   IDs as column names). NOTE(review): layout taken from the previously
#   commented-out `t(exp_mox)` step — confirm against callers.
# @param san_raw Sample annotation with columns `id_moxI` (sample ID,
#   matching the row names of `exp_raw`) and `pop_gen` (group).
# @return A list with
#   exp_long             long table (fid, id_moxI, value, pop_gen),
#   notNAperGeneANDgroup non-NA counts per fid x pop_gen (grouped by fid),
#   fid_min2_in2Groups   IDs of the features passing the filter,
#   exp_mox_filtered     `exp_raw` restricted to those feature columns.
# All inputs come from the arguments. The body previously read the globals
# `exp`, `san_mox` and `exp_mox`, and `exp` otherwise resolves to base::exp.
NAInGroupsPerFeature <- function(exp_raw, san_raw){
  # features x samples, keeping the feature ID as a column
  exp_t <- as.data.frame(t(exp_raw))
  exp_t$fid <- rownames(exp_t)
  # long format, annotated with each sample's group
  exp_long <- exp_t %>%
    pivot_longer(!fid, names_to = 'id_moxI', values_to = "value")
  exp_long$pop_gen <- san_raw$pop_gen[match(exp_long$id_moxI, san_raw$id_moxI)]
  # non-NA values per feature and group (result stays grouped by fid)
  notNAperGeneANDgroup <- exp_long %>%
    group_by(fid, pop_gen) %>%
    summarise(sum_notNA = sum(!is.na(value)), .groups = "drop_last")
  # features with > 1 non-NA value in > 1 group
  fid_min2_in2Groups <- notNAperGeneANDgroup %>%
    group_by(fid) %>%
    filter(sum(sum_notNA > 1) > 1, sum_notNA > 1) %>%
    summarise(tot_valid = n(), valid_groups = str_c(pop_gen, collapse = ';')) %>%
    pull(fid) %>%
    unique()
  # drop = FALSE keeps a data frame even if a single feature passes
  exp_mox_filtered <- exp_raw[, colnames(exp_raw) %in% fid_min2_in2Groups, drop = FALSE]
  list(
    exp_long = exp_long,
    notNAperGeneANDgroup = notNAperGeneANDgroup,
    fid_min2_in2Groups = fid_min2_in2Groups,
    exp_mox_filtered = exp_mox_filtered
  )
}

# find all-NA columns (helper)
# Column positions of `df` in which every value is NA. For a 0-row input
# every column qualifies. Returns an unnamed integer vector.
Cols_AllMissing <- function(df){
  na_per_col <- colSums(is.na(df))
  unname(which(na_per_col == nrow(df)))
}

# format name prot to upper
# Upper-case the first character of each string; the rest is unchanged.
# NA stays NA and "" stays "".
firstup <- function(x) {
  initial <- substr(x, 1, 1)
  substr(x, 1, 1) <- toupper(initial)
  x
}



## specific for packages ----

### PCAtools ----

# Run PCAtools::pca() on a subset of samples.
#
# CAVE: every column of `exp` except the feature-ID column `fid` must be
# numeric, because the table goes straight into PCAtools::pca().
#
# @param group_ids Sample IDs (column names of `exp`) to include.
# @param exp Features x samples table with a `fid` column of feature IDs.
# @param san Sample annotation whose `id` column matches the sample columns.
# @param removeVar Fraction of lowest-variance features that PCAtools::pca()
#   removes (default 0.1, i.e. 10%, the previously hard-coded value).
# @return The object returned by PCAtools::pca().
prep_pca <- function(group_ids, exp, san, removeVar = 0.1) {
  # expression: requested samples plus the feature-ID column
  exp_pca <- exp[, colnames(exp) %in% c(group_ids, "fid"), drop = FALSE]
  exp_pca <- FIDtoRow(exp_pca)
  # drop features with any missing value
  exp_pca <- exp_pca[complete.cases(exp_pca), , drop = FALSE]
  # metadata: rows of the retained samples, in the same order as the columns
  sample_ids <- colnames(exp_pca)
  san <- as.data.frame(san)
  san_pca <- san[san$id %in% sample_ids, , drop = FALSE]
  san_pca <- san_pca[order(factor(san_pca$id, levels = sample_ids)), , drop = FALSE]
  rownames(san_pca) <- san_pca$id
  # the previous identical() check discarded its result; enforce it
  if (!identical(colnames(exp_pca), rownames(san_pca))) {
    stop("Sample columns of `exp` and `san$id` do not match one-to-one",
         call. = FALSE)
  }
  PCAtools::pca(exp_pca, metadata = san_pca, removeVar = removeVar)
}

## LIMMA ----

# limma fit: returns fit1, fit2, fit3 and the decideTests() summary.
#
# eBayes() runs with trend = TRUE, so the prior variance follows the
# mean-expression trend (a mean-variance relationship in the data). It also
# runs with robust = TRUE, so prior estimation is robust to outlier feature
# variances.
#
# @param df Expression matrix (features x samples) accepted by lmFit().
# @param design Design matrix.
# @param contrast Contrast matrix for contrasts.fit().
# @param p.value,lfc decideTests() thresholds (BH-adjusted p-value and
#   absolute log-fold-change). The defaults keep the previous 0.05 and 1.
# @return list(fit1 = lmFit() fit, fit2 = contrasts.fit() fit,
#   fit3 = eBayes() fit, summary = decideTests() result).
limma_getRes <- function(df, design, contrast, p.value = 0.05, lfc = 1) {
  fit1 <- lmFit(df, design)
  fit2 <- contrasts.fit(fit1, contrasts = contrast)
  fit3 <- eBayes(fit2, trend = TRUE, robust = TRUE)
  test_summary <- decideTests(fit3, p.value = p.value, lfc = lfc,
                              method = "separate", adjust.method = "BH")
  list(
    fit1 = fit1,
    fit2 = fit2,
    fit3 = fit3,
    summary = test_summary)
}



# Like limma_getRes(), but first quantile-normalises the samples with
# normalizeBetweenArrays(method = "quantile").
#
# @param df Expression matrix (features x samples).
# @param design Design matrix.
# @param contrast Contrast matrix for contrasts.fit().
# @param p.value,lfc decideTests() thresholds (BH-adjusted p-value and
#   absolute log-fold-change). The defaults keep the previous 0.05 and 1.
# @return list(fit1, fit2, fit3, summary). The normalised matrix itself is
#   not returned.
limma_getResn <- function(df, design, contrast, p.value = 0.05, lfc = 1) {
  exp_norm <- normalizeBetweenArrays(df, method = "quantile")
  fit1 <- lmFit(exp_norm, design)
  fit2 <- contrasts.fit(fit1, contrasts = contrast)
  fit3 <- eBayes(fit2, trend = TRUE, robust = TRUE)
  test_summary <- decideTests(fit3, p.value = p.value, lfc = lfc,
                              method = "separate", adjust.method = "BH")
  list(
    fit1 = fit1,
    fit2 = fit2,
    fit3 = fit3,
    summary = test_summary)
}


# Bin log-fold-changes into an `FC_level` column.
#
# Rows with FDR > fdr_cutoff are "NS". Rows with FDR <= fdr_cutoff are
# binned by logFC:
#   "down_4"/"down_3"/"down_2"/"down_1"/"down_0.5",
#   "NS" for |logFC| < 0.5,
#   "up_0.5"/"up_1"/"up_2"/"up_3"/"up_4".
# Bin edges belong to the bin with the larger |logFC| (e.g. logFC == 1 ->
# "up_1", logFC == -2 -> "down_2"). Previously, exact edge values matched no
# branch and became NA. A missing FDR, or a missing logFC with
# FDR <= fdr_cutoff, still gives NA.
#
# @param df Table with `FDR` and `logFC` columns.
# @param fdr_cutoff FDR threshold (default 5, the previously hard-coded value).
#   NOTE(review): a cutoff of 5 suggests FDR is in percent — confirm.
# @return `df` with an added character column `FC_level`.
FC_level_mutater <- function (df, fdr_cutoff = 5) {
  df %>% mutate(
    FC_level = case_when(
      is.na(FDR) ~ NA_character_,
      FDR > fdr_cutoff ~ "NS",
      # conditions are checked in order, so each line is the next bin up
      logFC <= -4 ~ "down_4",
      logFC <= -3 ~ "down_3",
      logFC <= -2 ~ "down_2",
      logFC <= -1 ~ "down_1",
      logFC <= -0.5 ~ "down_0.5",
      logFC < 0.5 ~ "NS",
      logFC < 1 ~ "up_0.5",
      logFC < 2 ~ "up_1",
      logFC < 3 ~ "up_2",
      logFC < 4 ~ "up_3",
      logFC >= 4 ~ "up_4"))
}


## cluster ----

# cluster size
# Diagnostic plots for choosing the number of hierarchical (hcut) clusters.
#
# @param exp Numeric data (observations x variables).
# @return list(wss = elbow plot, sil = average silhouette width plot,
#   gap = gap-statistic plot, DistMat = euclidean distance-matrix plot).
asses_custerSize <- function(exp) {
  wss_plot <- fviz_nbclust(exp, FUN = hcut, method = "wss") +
    labs(title = "Sums of squares")
  sil_plot <- fviz_nbclust(exp, FUN = hcut, method = "silhouette") +
    labs(title = "Silhouette width")
  # clusGap() bootstraps B = 50 reference sets, so this step uses the RNG.
  # NOTE(review): clusGap() forwards `nstart` to hcut(); nstart is a kmeans
  # argument — confirm hcut() accepts it.
  gap <- clusGap(exp, FUN = hcut, nstart = 25, K.max = 10, B = 50)
  gap_plot <- fviz_gap_stat(gap) + labs(title = "Gap statistic")
  d <- dist(exp, method = "euclidean")
  dist_plot <- fviz_dist(d,
                         lab_size = 5,
                         gradient = list(low = "#00AFBB", mid = "white", high = "#FC4E07")) +
    ggtitle("Distance matrix")
  list(
    wss = wss_plot,
    sil = sil_plot,
    gap = gap_plot,
    DistMat = dist_plot)
}

# kMeans
# calc kmeans
# k-means clustering with a fixed RNG seed, so the random starts and
# therefore the clustering are reproducible.
#
# @param df Numeric data (observations x variables) accepted by kmeans().
# @param centers_nr Number of clusters (or a matrix of initial centres).
# @param nstart_nr Number of random starts.
# @param seed Seed passed to set.seed() first (default 123, the previously
#   hard-coded value). Note that this resets the session's global RNG state.
# @return The `kmeans` object.
calc_kmeans <- function(df, centers_nr, nstart_nr, seed = 123){
  set.seed(seed)
  kmeans(df, centers = centers_nr, nstart = nstart_nr)
}

# plot kmeans
# Run k-means (k clusters, 50 random starts) on `exp` and plot the clusters.
#
# @param exp Numeric data (observations x variables).
# @param k Number of clusters, passed to calc_kmeans() as `centers_nr`.
# @param title Plot title.
# @return A ggplot object from fviz_cluster().
plot_kmeans <- function(exp, k, title) {
  # kept for reproducibility; calc_kmeans() also seeds the RNG itself
  set.seed(123)
  fit <- calc_kmeans(exp, k, 50)
  fviz_cluster(fit, data = exp, repel = TRUE) +
    ggtitle(title)
}

## others ----

# list2env(exp_list_prim, envir=.GlobalEnv)

## remove rows with too many NAs
# Keep the rows of `DF` that contain at most `n` missing values
# (default 0, i.e. complete rows only).
#
# @param DF data.frame or matrix.
# @param n Maximum number of NAs allowed per row.
# @return The retained rows. drop = FALSE keeps a data.frame/matrix even when
#   `DF` has a single column; previously it collapsed to a vector.
delete.na <- function(DF, n=0) {
  DF[rowSums(is.na(DF)) <= n, , drop = FALSE]
}

## find rows with all NA
# DF %>% filter_at(vars(!starts_with("U")), all_vars(is.na(.)))