#' @title VCG Sampler for Energy Distance Balancing
#' @description This function performs energy distance based balancing and selects a subset from pool based on energy distance to approximate a randomized control trial. Optionally, it visualizes the balancing results.
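#' @param formula A formula specifying the treated indicator and covariates, e.g., `treated ~ cov1 + cov2 | stratum`. The treated variable must be binary (0 = pool, 1 = treated). Terms may only be joined with `+`; up to three stratum variables may follow `|`, giving at most 8 strata.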
#' @param data A data frame containing the variables specified in the formula.
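#' @param n Integer. Number of observations to sample from the pool, or a vector with one value per stratum. If the length of `n` does not match the number of strata, its first element is used for every stratum.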
#' @param c_w Optional: Vector of positive weights for covariates, reflecting the relative importance of the covariates for balancing.
#' @param random Logical. If `TRUE`, pool units are sampled at random without replacement, with probability proportional to their energy weights; otherwise, the pool units with the largest energy weights are selected (deterministic). Default: FALSE
#' @param plot Logical. If `TRUE`, returns a visualization of the balancing effect.
#' @return If `plot = TRUE`, returns a list with:
#' \itemize{
#'   \item A data frame with added columns:
#'     \itemize{
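#'       \item `VCG`: 1 for pool units selected into the VCG, 0 for unselected pool units, and `NA` for treated units and for rows dropped because of missing values.
#'       \item `e_weights`: Energy weights of the pool units used for selection (treated units get 1).
#'       \item `<treated>_balanced`: A factor that takes the treated level for treated units, the pool level for selected VCG units, and `NA` for unselected pool units.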
#'     }
#'   \item A ggplot2 object showing the median and MAD differences before and after balancing,
#'         with a 95% permutation ellipse as an approximation for typical random deviations.
#' }
#' If `plot = FALSE`, returns only the modified data frame.
#' @details
#'
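#' Energy weights for the pool are obtained by minimising the energy distance between the weighted pool and the treated group (a quadratic program solved with OSQP).
#' If random is set to FALSE, the function selects the `n` pool units with the largest energy weights and assigns them to the VCG group.
#' If random is set to TRUE, the function samples `n` pool units without replacement, with sampling probability proportional to their energy weights.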
#' The quality of covariate balancing is visualized using differences in medians and median absolute deviations (MADs).
#' Permutation ellipses are generated by randomly permuting the pool and treated groups to estimate usual (random) variability.
#' Only the X and Y axes are computed directly; the ellipse is interpolated between the axes.
#' This method is intended as a visual approximation rather than a precise statistical test.
#' @examples
#'
#' dat   <- data.frame(
#'   cov1  = rnorm(50, 10, 1),
#'   cov2  = rnorm(50, 7,  1),
#'   cov3  = rnorm(50, 5,  1),
#'   treated = rep(c(0, 1), c(35, 15))
#' )
#'   VCG_sampler(treated ~ cov1 + cov2 + cov3, data=dat, n=5)
#'
#' @rdname VCG_sampler
#' @export
#' @importFrom osqp solve_osqp osqpSettings
#' @importFrom ggforce geom_ellipse
#' @importFrom ggplot2 ggplot geom_point geom_segment geom_text geom_vline geom_hline labs theme_minimal xlab ylab aes arrow unit xlim ylim expansion scale_y_continuous scale_x_continuous
#' @importFrom stats median mad quantile sd as.formula density dist na.omit
#' @importFrom grDevices rgb






#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%#
# Energy sampling
VCG_sampler <- function(formula, data, n, c_w = NULL, random = FALSE,
                        plot = TRUE) {
  # Bind the names used inside aes() so R CMD check does not report them as
  # undefined global variables.
  x0 <- y0 <- a <- b <- angle <- x <- y <- xend <- yend <- label <- NULL

  # Keep only finite covariate weights (is.finite() is FALSE for NA, NaN and
  # +/-Inf). Fewer than two weights carry no relative importance, so the
  # covariate weighting is switched off in that case.
  if (!is.null(c_w)) {
    c_w <- c_w[is.finite(c_w)]
    if (length(c_w) < 2) c_w <- NULL
  }

  # ---- Helpers ---------------------------------------------------------------

  # Energy-balancing weights for the pool.
  # covs:  covariates (vector, matrix or data frame), one row per unit.
  # treat: group indicator; its first factor level is the pool, its second
  #        level the treated group.
  # c_w:   optional positive weights, one per covariate column, multiplied
  #        into the (scaled) covariates before distances are computed.
  # w:     optional base weights (default: 1 for every unit).
  # scale: if TRUE, covariates are standardised with scale().
  # Returns a numeric vector of length(treat) holding the OSQP solution for
  # the pool units and 1 for every other unit.
  energy_sampler_w <- function(covs, treat, c_w = NULL, w = NULL,
                               scale = TRUE) {
    n <- length(treat)
    # NROW() is nrow() for matrices/data frames and length() for vectors.
    if (n != NROW(covs)) {
      stop('Length of treat weight must be the same as number of rows in covs')
    }
    if (is.null(w)) w <- rep(1, n)

    if (scale) covs <- scale(covs)

    if (!is.null(c_w)) {
      if (length(c_w) != NCOL(covs)) {
        stop('Length of covariate weight (c_w) must be the same as number of cols in covs')
      }
      if (any(c_w <= 0)) {
        stop('Covariate weights of 0 or negative are not meaningful here. In this case, the variable should be removed from the analysis entirely.')
      }
      covs <- sweep(covs, 2, c_w, `*`)
    }

    # Pairwise Euclidean distances between all units.
    d <- as.matrix(dist(covs, method = "euclidean"))

    treat <- as.factor(treat)
    t0 <- which(treat == levels(treat)[1])
    t1 <- which(treat == levels(treat)[2])
    n0 <- length(t0)
    n1 <- length(t1)

    # Normalise base weights to mean 1 within each group, then to sum 1.
    w[t0] <- w[t0] / mean(w[t0])
    w[t1] <- w[t1] / mean(w[t1])
    w0 <- w[t0] / n0
    w1 <- w[t1] / n1

    # Quadratic program in the pool multipliers x (effective weights w0 * x):
    #   minimise 2 * sum_ij x_i w0_i w1_j d_ij - sum_ik x_i x_k w0_i w0_k d_ik
    # (the x-dependent part of the weighted energy distance) plus a small
    # ridge, subject to sum(w0 * x) = 1 and x >= 1e-8.
    # drop = FALSE keeps matrix shapes when a group has a single unit.
    P <- -d[t0, t0, drop = FALSE] * tcrossprod(w0)
    q <- 2 * as.vector(w1 %*% d[t1, t0, drop = FALSE]) * w0
    Amat <- cbind(diag(n0), w0)
    lvec <- c(rep(1e-8, n0), 1)
    uvec <- c(rep(Inf, n0), 1)
    # Ridge on the diagonal for numerical stability.
    diag(P) <- diag(P) + 1e-4 * w0^2 / 2

    settings <- osqp::osqpSettings(eps_abs = 1e-8, eps_rel = 1e-6,
                                   max_iter = 5e4, polish = TRUE,
                                   verbose = FALSE)
    # OSQP minimises 0.5 * x'Px + q'x, hence the factor 2 on P.
    result <- osqp::solve_osqp(P = 2 * P, q = q, A = t(Amat), l = lvec,
                               u = uvec, pars = settings)

    wout <- rep(1, n)
    wout[t0] <- result$x
    wout
  }

  # Per-covariate difference in medians: level lvls[1] minus level lvls[2]
  # of column `group`. Rows whose group is NA contribute only NA values,
  # which na.rm removes.
  median_diff <- function(data, lvls, group, vars) {
    vapply(vars, function(var) {
      stats::median(data[data[[group]] == lvls[1], var], na.rm = TRUE) -
        stats::median(data[data[[group]] == lvls[2], var], na.rm = TRUE)
    }, numeric(1))
  }

  # Same as median_diff(), for the median absolute deviation.
  mad_diff <- function(data, lvls, group, vars) {
    vapply(vars, function(var) {
      stats::mad(data[data[[group]] == lvls[1], var], na.rm = TRUE) -
        stats::mad(data[data[[group]] == lvls[2], var], na.rm = TRUE)
    }, numeric(1))
  }

  # 95% permutation ellipse: shuffle the group labels R times, take the 2.5%
  # and 97.5% quantiles of all median / MAD differences and return a one-row
  # data frame describing the axis-aligned ellipse spanning those ranges.
  perm_ellipse <- function(d, group, lvls, vars, R = 1000) {
    Mm <- Msd <- matrix(NA, nrow = R, ncol = length(vars))
    g <- d[[group]]
    for (i in seq_len(R)) {
      # Equivalent to sample(g), without sample()'s 1:x case for length 1.
      d$per_group <- g[sample.int(length(g))]
      Mm[i, ] <- median_diff(d, lvls, "per_group", vars)
      Msd[i, ] <- mad_diff(d, lvls, "per_group", vars)
    }
    mrange <- quantile(as.vector(Mm), probs = c(0.025, 0.975), na.rm = TRUE)
    sdrange <- quantile(as.vector(Msd), probs = c(0.025, 0.975), na.rm = TRUE)
    data.frame(x0 = mean(mrange), y0 = mean(sdrange),
               a = mean(abs(mrange - mean(mrange))),
               b = mean(abs(sdrange - mean(sdrange))), angle = 0)
  }

  # Shared theme, axis expansion and labels of the balancing plot (these are
  # not layers, so they can be added after the geoms).
  finish_plot <- function(p) {
    p +
      theme_minimal() +
      scale_y_continuous(expand = expansion(mult = c(0, 0.05))) +
      scale_x_continuous(expand = expansion(mult = c(0.1, 0.1))) +
      xlab('VCG - TG (median diff.)') +
      ylab('VCG - TG (MAD diff.)') +
      labs(title = "Energy balancing results",
           subtitle = "95% ellipse of permutations")
  }

  # ---- Formula parsing -------------------------------------------------------

  formula_parts <- as.character(formula)
  if (any(grepl('*', formula_parts, fixed = TRUE))) stop('Only + for covariates and | for stratum variables are allowed.')
  if (any(grepl(':', formula_parts, fixed = TRUE))) stop('Only + for covariates and | for stratum variables are allowed.')
  fnames   <- all.vars(as.formula(formula))
  fparts   <- all.names(as.formula(formula))
  leftpart <- fnames[1]   # treated indicator
  vars     <- fnames[-1]  # covariates (and, for now, stratum variables)
  balanced <- paste0(leftpart, '_balanced')

  # ---- Stratified sampling: run the unstratified sampler per stratum ---------

  if (any(fparts == '|')) {
    stratum <- strsplit(as.character(as.formula(formula))[3], "\\|")[[1]]
    stratum <- trimws(strsplit(stratum[2], "\\+")[[1]])
    if (length(stratum) > 3) stop('max 3 stratum variables are allowed!')
    vars <- vars[which(!vars %in% stratum)]

    # drop = TRUE keeps only stratum combinations that occur in the data; an
    # unobserved combination would otherwise reach the sampler with zero rows.
    data$in_stratum <- interaction(data[stratum], drop = TRUE)
    if (anyNA(data$in_stratum)) {
      warning('Rows with missing stratum values belong to no stratum and are dropped from the output.', call. = FALSE)
    }
    strata <- levels(data$in_stratum)
    if (length(strata) > 8) stop('Too many strata, a maximum of 8 are allowed!')
    # One n per stratum; a mismatched vector falls back to its first element.
    if (length(n) != length(strata)) n <- n[1]
    if (length(n) == 1) n <- rep(n, length(strata))

    new_formula <- as.formula(paste(leftpart, '~', paste(vars, collapse = '+')))
    data_save <- do.call(rbind, lapply(seq_along(strata), function(k) {
      data_k <- data[which(data$in_stratum == strata[k]), ]
      VCG_sampler(new_formula, data = data_k, n = n[k], c_w = c_w,
                  random = random, plot = FALSE)
    }))

    if (!plot) return(data_save)

    s_colors <- c("#00A091", "#FE7500", "#7A00E6", "#F5C142", "#694A3E",
                  "#1F70A6", "#EBA0AB", "#6CADD1")

    data_out <- data_save
    data_out[, vars] <- robust_scale(data_out[, vars],
                                     group = data_out[, leftpart])
    data_out[, leftpart] <- as.factor(data_out[, leftpart])
    lvls <- levels(data_out[, leftpart])
    if (length(lvls) != 2) stop('The left part of the formula must be a binary variable')

    # Subset once per stratum rather than inside every permutation.
    by_stratum <- lapply(strata, function(s) {
      data_out[which(data_out$in_stratum == s), ]
    })

    p <- ggplot() +
      geom_vline(xintercept = 0, color = "#7A00E6", linewidth = 1, linetype = 3) +
      geom_hline(yintercept = 0, color = "#7A00E6", linewidth = 1, linetype = 3)

    # All ellipses are drawn before any arrows so the arrows stay on top.
    for (k in seq_along(strata)) {
      p <- p + ggforce::geom_ellipse(
        data = perm_ellipse(by_stratum[[k]], leftpart, lvls, vars),
        aes(x0 = x0, y0 = y0, a = a, b = b, angle = angle),
        fill = s_colors[k], color = s_colors[k], alpha = 0.15
      )
    }

    for (k in seq_along(strata)) {
      dk <- by_stratum[[k]]
      arrows <- data.frame(
        label = paste0(substr(vars, 1, 10), ' (', strata[k], ')'),
        x     = median_diff(dk, lvls, leftpart, vars),
        y     = mad_diff(dk, lvls, leftpart, vars),
        xend  = median_diff(dk, lvls, balanced, vars),
        yend  = mad_diff(dk, lvls, balanced, vars)
      )
      # Same covariate-weight label format as the unstratified plot.
      if (!is.null(c_w)) arrows$label <- paste0(arrows$label, '(*', c_w, ')')

      p <- p +
        geom_point(data = arrows, aes(x = x, y = y), color = s_colors[k],
                   size = 3, shape = 16) +
        geom_segment(data = arrows, aes(x = x, y = y, xend = xend, yend = yend),
                     arrow = arrow(length = unit(0.1, "inches")),
                     color = s_colors[k], linewidth = 1.25) +
        geom_text(data = arrows, aes(x = x, y = y, label = label),
                  vjust = -0.5, color = "black", size = 5)
    }

    return(list(data_save, finish_plot(p)))
  }

  # ---- Unstratified sampling -------------------------------------------------

  data$insideID <- seq_len(nrow(data))
  data_save <- data
  data <- na.omit(data[, c(leftpart, vars, 'insideID')])
  data[, vars] <- robust_scale(data[, vars], group = data[, leftpart])
  # Scaled data with the group column still in its original type; used for
  # the "before balancing" statistics and the permutation ellipse.
  data_0 <- data
  data$VCG <- NA
  data[, leftpart] <- as.factor(data[, leftpart])
  lvls <- levels(data[, leftpart])
  if (length(lvls) != 2) stop('The left part of the formula must be a binary variable')
  pool_idx  <- which(data[, leftpart] == lvls[1])
  treat_idx <- which(data[, leftpart] == lvls[2])
  data$VCG[pool_idx] <- 0

  data$e_weights <- energy_sampler_w(data[, vars], data[, leftpart],
                                     c_w = c_w, scale = TRUE)

  # Never try to select more units than the pool holds.
  n_take <- min(n[1], length(pool_idx))
  if (n[1] > length(pool_idx)) {
    warning('n (', n[1], ') exceeds the number of pool units (',
            length(pool_idx), '); all pool units are selected.', call. = FALSE)
  }

  if (!random) {
    # Deterministic: the n_take pool units with the largest energy weights.
    # Ranking (instead of filtering by ">= cut-off" and taking the first rows)
    # prevents units tied at the cut-off, e.g. many at the solver's lower
    # bound, from displacing higher-weighted units later in the data.
    ranked <- pool_idx[order(data$e_weights[pool_idx], decreasing = TRUE)]
    data$VCG[ranked[seq_len(n_take)]] <- 1
  } else {
    # Random: clip negative weights to 0 and normalise over the pool only, so
    # the stored e_weights are the per-draw sampling probabilities; then
    # sample without replacement. sample.int() avoids sample()'s 1:x
    # behaviour when the pool has a single unit.
    pool_w <- pmax(data$e_weights[pool_idx], 0)
    data$e_weights[pool_idx] <- pool_w / sum(pool_w, na.rm = TRUE)
    picked <- sample.int(length(pool_idx), size = n_take, replace = FALSE,
                         prob = data$e_weights[pool_idx])
    data$VCG[pool_idx[picked]] <- 1
  }

  data$e_weights[treat_idx] <- 1

  # Balanced indicator: treated -> lvls[2], selected pool -> lvls[1],
  # unselected pool -> NA. Explicit levels keep factor() valid even when one
  # of the two codes is absent (e.g. n = 0).
  newrep <- ifelse(is.na(data$VCG), 1, 0)
  newrep[which(data$VCG == 0)] <- NA
  data[[balanced]] <- factor(newrep, levels = c(0, 1), labels = lvls)

  # Merge the results back onto all original rows; rows removed by na.omit()
  # receive NA in the new columns.
  data_out <- merge(data_save,
                    data[, c('VCG', 'e_weights', 'insideID', balanced)],
                    by = 'insideID', all.x = TRUE)
  data_out$insideID <- NULL

  if (!plot) return(data_out)

  ellipse_data <- perm_ellipse(data_0, leftpart, lvls, vars)

  # Arrows from the pre-balancing to the post-balancing differences.
  arrows <- data.frame(
    label = substr(vars, 1, 10),
    x     = median_diff(data_0, lvls, leftpart, vars),
    y     = mad_diff(data_0, lvls, leftpart, vars),
    xend  = median_diff(data, lvls, balanced, vars),
    yend  = mad_diff(data, lvls, balanced, vars)
  )
  if (!is.null(c_w)) arrows$label <- paste0(arrows$label, '(*', c_w, ')')

  p <- ggplot() +
    ggforce::geom_ellipse(data = ellipse_data,
                          aes(x0 = x0, y0 = y0, a = a, b = b, angle = angle),
                          fill = "#7A00E6", color = NA, alpha = 0.25) +
    geom_vline(xintercept = 0, color = "#7A00E6", linewidth = 1, linetype = 3) +
    geom_hline(yintercept = 0, color = "#7A00E6", linewidth = 1, linetype = 3) +
    geom_point(data = arrows, aes(x = x, y = y), color = "#00A091",
               size = 3, shape = 16) +
    geom_segment(data = arrows, aes(x = x, y = y, xend = xend, yend = yend),
                 arrow = arrow(length = unit(0.1, "inches")),
                 color = "#00A091", linewidth = 1.25) +
    geom_text(data = arrows, aes(x = x, y = y, label = label),
              vjust = -0.5, color = "black", size = 5)

  list(data_out, finish_plot(p))
}
#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%#


