#' Plot Association Surface (3D or Perspective)
#'
#' @description
#' Produces a 3D perspective plot of the estimated semi-parametric association
#' surface \eqn{\hat{f}_{rs}(\eta_1, \eta_2)} for a specified transition.
#'
#' @details
#' The surface is evaluated on an \code{n_grid} by \code{n_grid} grid spanning
#' the 5th to 95th percentiles of the first two \code{eta_} columns of the
#' transition's data. All other predictors are held fixed: covariates listed
#' in \code{object$covariates} at their median (most frequent value for factor
#' or character columns), a third \code{eta_} column (if present) and
#' \code{time_in_state} at their medians, and \code{status} at 1. Predictions
#' are requested on the link scale.
#'
#' @param object A \code{"jmSurface"} object.
#' @param transition Character string specifying which transition to plot.
#' @param n_grid Integer grid resolution (at least 2). Default \code{40}.
#' @param theta,phi Viewing angles for \code{persp}. Defaults \code{-30, 25}.
#' @param col Vector of colors used to shade facets by height. A single color
#'   paints every facet. Default \code{hcl.colors(50, "viridis")}.
#' @param main Title. If \code{NULL}, auto-generated.
#' @param ... Additional arguments passed to \code{persp}.
#'
#' @return Invisibly returns the prediction grid with fitted values in
#'   column \code{z}.
#'
#' @export
plot_surface <- function(object, transition, n_grid = 40,
                         theta = -30, phi = 25,
                         col = NULL, main = NULL, ...) {

  if (!inherits(object, "jmSurface")) {
    stop("object must be of class 'jmSurface'")
  }
  if (!(transition %in% names(object$gam_fits))) {
    stop("Transition '", transition, "' not found. Available: ",
         paste(names(object$gam_fits), collapse = ", "))
  }
  if (!is.numeric(n_grid) || length(n_grid) != 1L || is.na(n_grid) ||
      n_grid < 2) {
    stop("n_grid must be a single number >= 2")
  }

  gf <- object$gam_fits[[transition]]
  ed <- object$eta_data[[transition]]
  eta_cols <- grep("^eta_", names(ed), value = TRUE)

  if (length(eta_cols) < 2) {
    stop("Need at least 2 biomarker summaries for a surface plot")
  }

  if (is.null(col)) col <- grDevices::hcl.colors(50, "viridis")
  if (is.null(main)) main <- paste("Association Surface:", transition)

  ## Reference value for a held-fixed covariate. median() stops on factors,
  ## so categorical columns use their most frequent value instead.
  typical_value <- function(v) {
    if (is.factor(v) || is.character(v)) {
      counts <- table(v)
      top <- if (length(counts) > 0) {
        names(counts)[which.max(counts)]
      } else {
        NA_character_
      }
      if (is.factor(v)) {
        factor(top, levels = levels(v), ordered = is.ordered(v))
      } else {
        top
      }
    } else {
      median(v, na.rm = TRUE)
    }
  }

  ## Build prediction grid over the central 90% of each biomarker summary
  e1r <- quantile(ed[[eta_cols[1]]], c(0.05, 0.95), na.rm = TRUE)
  e2r <- quantile(ed[[eta_cols[2]]], c(0.05, 0.95), na.rm = TRUE)
  e1s <- seq(e1r[1], e1r[2], length.out = n_grid)
  e2s <- seq(e2r[1], e2r[2], length.out = n_grid)

  g <- expand.grid(x1 = e1s, x2 = e2s)
  names(g) <- eta_cols[1:2]

  ## Hold every other predictor at a reference value
  for (cv in object$covariates) {
    if (cv %in% names(ed)) {
      g[[cv]] <- typical_value(ed[[cv]])
    }
  }
  if (length(eta_cols) >= 3) {
    g[[eta_cols[3]]] <- median(ed[[eta_cols[3]]], na.rm = TRUE)
  }
  g$time_in_state <- median(ed$time_in_state, na.rm = TRUE)
  g$status <- 1

  ## Link-scale prediction; if that call fails, fall back to the first
  ## column of the term-wise predictions.
  g$z <- tryCatch(
    predict(gf, newdata = g, type = "link"),
    error = function(e) predict(gf, newdata = g, type = "terms")[, 1]
  )

  ## expand.grid varies the first axis fastest, so filling by column with
  ## length(e1s) rows gives z_mat[i, j] = z(e1s[i], e2s[j]).
  z_mat <- matrix(g$z, nrow = length(e1s))

  ## Color each facet by the mean height of its four corners
  nrz <- nrow(z_mat)
  ncz <- ncol(z_mat)
  zfacet <- (z_mat[-1, -1] + z_mat[-1, -ncz] +
             z_mat[-nrz, -1] + z_mat[-nrz, -ncz]) / 4
  nbcol <- length(col)
  ## cut() rejects fewer than 2 intervals, so a single color is used as-is.
  facetcol <- if (nbcol < 2L) col else col[cut(zfacet, breaks = nbcol)]

  persp(e1s, e2s, z_mat,
        xlab = gsub("eta_", "", eta_cols[1]),
        ylab = gsub("eta_", "", eta_cols[2]),
        zlab = "log-Hazard",
        theta = theta, phi = phi,
        col = facetcol, shade = 0.3,
        main = main, ...)

  invisible(g)
}


#' Contour Heatmap of Association Surface
#'
#' @description
#' Produces a filled contour (heatmap) of the estimated association surface,
#' identifying "danger zones" of elevated transition risk.
#'
#' @details
#' The surface is evaluated on an \code{n_grid} by \code{n_grid} grid spanning
#' the 5th to 95th percentiles of the first two \code{eta_} columns of the
#' transition's data. All other predictors are held fixed: covariates listed
#' in \code{object$covariates} at their median (most frequent value for factor
#' or character columns), a third \code{eta_} column (if present) and
#' \code{time_in_state} at their medians, and \code{status} at 1. Predictions
#' are requested on the link scale.
#'
#' @param object A \code{"jmSurface"} object.
#' @param transition Character string specifying which transition.
#' @param n_grid Integer grid resolution (at least 2). Default \code{50}.
#' @param col Color palette: either a function taking the number of colors
#'   \code{n}, or a vector of colors that is interpolated to the number of
#'   contour levels. Default (\code{NULL}) uses
#'   \code{hcl.colors(n, "YlOrRd", rev = TRUE)}.
#' @param main Title. If \code{NULL}, auto-generated.
#' @param ... Additional arguments passed to \code{filled.contour}.
#'
#' @return Invisibly returns the prediction grid with fitted values in
#'   column \code{z}.
#'
#' @export
contour_heatmap <- function(object, transition, n_grid = 50,
                            col = NULL, main = NULL, ...) {

  if (!inherits(object, "jmSurface")) {
    stop("object must be of class 'jmSurface'")
  }
  if (!(transition %in% names(object$gam_fits))) {
    stop("Transition '", transition, "' not found. Available: ",
         paste(names(object$gam_fits), collapse = ", "))
  }
  if (!is.numeric(n_grid) || length(n_grid) != 1L || is.na(n_grid) ||
      n_grid < 2) {
    stop("n_grid must be a single number >= 2")
  }

  gf <- object$gam_fits[[transition]]
  ed <- object$eta_data[[transition]]
  eta_cols <- grep("^eta_", names(ed), value = TRUE)

  if (length(eta_cols) < 2) {
    stop("Need at least 2 biomarker summaries for a heatmap")
  }

  if (is.null(main)) main <- paste("Heatmap:", transition)

  ## filled.contour() wants a palette *function*; accept a function or a
  ## color vector from the caller so that `col` is actually honored.
  palette_fun <- if (is.null(col) || length(col) == 0L) {
    function(n) grDevices::hcl.colors(n, "YlOrRd", rev = TRUE)
  } else if (is.function(col)) {
    col
  } else if (length(col) < 2L) {
    ## colorRampPalette() cannot interpolate a single color
    function(n) rep(col, length.out = n)
  } else {
    grDevices::colorRampPalette(col)
  }

  ## Reference value for a held-fixed covariate. median() stops on factors,
  ## so categorical columns use their most frequent value instead.
  typical_value <- function(v) {
    if (is.factor(v) || is.character(v)) {
      counts <- table(v)
      top <- if (length(counts) > 0) {
        names(counts)[which.max(counts)]
      } else {
        NA_character_
      }
      if (is.factor(v)) {
        factor(top, levels = levels(v), ordered = is.ordered(v))
      } else {
        top
      }
    } else {
      median(v, na.rm = TRUE)
    }
  }

  ## Build prediction grid over the central 90% of each biomarker summary
  e1r <- quantile(ed[[eta_cols[1]]], c(0.05, 0.95), na.rm = TRUE)
  e2r <- quantile(ed[[eta_cols[2]]], c(0.05, 0.95), na.rm = TRUE)
  e1s <- seq(e1r[1], e1r[2], length.out = n_grid)
  e2s <- seq(e2r[1], e2r[2], length.out = n_grid)

  g <- expand.grid(x1 = e1s, x2 = e2s)
  names(g) <- eta_cols[1:2]

  ## Hold every other predictor at a reference value
  for (cv in object$covariates) {
    if (cv %in% names(ed)) {
      g[[cv]] <- typical_value(ed[[cv]])
    }
  }
  if (length(eta_cols) >= 3) {
    g[[eta_cols[3]]] <- median(ed[[eta_cols[3]]], na.rm = TRUE)
  }
  g$time_in_state <- median(ed$time_in_state, na.rm = TRUE)
  g$status <- 1

  ## Link-scale prediction; if that call fails, fall back to the first
  ## column of the term-wise predictions.
  g$z <- tryCatch(
    predict(gf, newdata = g, type = "link"),
    error = function(e) predict(gf, newdata = g, type = "terms")[, 1]
  )

  ## expand.grid varies the first axis fastest, so filling by column with
  ## length(e1s) rows gives z_mat[i, j] = z(e1s[i], e2s[j]).
  z_mat <- matrix(g$z, nrow = length(e1s))

  ## key.title is evaluated lazily by filled.contour() while the color key
  ## is the active plot, so title() must stay an unevaluated argument here.
  filled.contour(e1s, e2s, z_mat,
                 color.palette = palette_fun,
                 xlab = gsub("eta_", "", eta_cols[1]),
                 ylab = gsub("eta_", "", eta_cols[2]),
                 main = main,
                 key.title = title(main = "log-HR", cex.main = 0.9),
                 ...)

  invisible(g)
}


#' Marginal Effect Slices
#'
#' @description
#' Produces marginal effect slice plots showing the effect of one biomarker
#' on the log-hazard at fixed quantiles (by default Q25, Q50, Q75) of the
#' other. Diverging slices indicate interaction; parallel slices indicate
#' additivity.
#'
#' @details
#' The first \code{eta_} column is varied over its 5th to 95th percentile
#' range while the second is fixed at each requested quantile. All other
#' predictors are held fixed: covariates listed in \code{object$covariates}
#' at their median (most frequent value for factor or character columns), a
#' third \code{eta_} column (if present) and \code{time_in_state} at their
#' medians, and \code{status} at 1. Predictions are requested on the link
#' scale. The y-axis is padded by 5\% of the plotted range on each side.
#'
#' @param object A \code{"jmSurface"} object.
#' @param transition Character string specifying which transition.
#' @param n_points Integer number of evaluation points. Default \code{60}.
#' @param quantiles Numeric vector of quantile probabilities for slicing.
#'   Default \code{c(0.25, 0.50, 0.75)}.
#' @param colors Character vector of colors for each slice. If there are more
#'   slices than colors, the last color is reused. Default dark slate
#'   (\code{"#264653"}), coral (\code{"#e76f51"}), teal (\code{"#2a9d8f"}).
#' @param main Title. If \code{NULL}, auto-generated.
#' @param ... Additional arguments passed to \code{plot}.
#'
#' @return Invisibly returns the data frame of slice values, one row per
#'   evaluation point and slice, with the prediction in column \code{f_hat}.
#'
#' @export
marginal_slices <- function(object, transition, n_points = 60,
                            quantiles = c(0.25, 0.50, 0.75),
                            colors = c("#264653", "#e76f51", "#2a9d8f"),
                            main = NULL, ...) {

  if (!inherits(object, "jmSurface")) {
    stop("object must be of class 'jmSurface'")
  }
  if (!(transition %in% names(object$gam_fits))) {
    stop("Transition '", transition, "' not found. Available: ",
         paste(names(object$gam_fits), collapse = ", "))
  }
  if (length(quantiles) < 1L) {
    stop("quantiles must contain at least one probability")
  }
  if (length(colors) < 1L) {
    stop("colors must contain at least one color")
  }

  gf <- object$gam_fits[[transition]]
  ed <- object$eta_data[[transition]]
  eta_cols <- grep("^eta_", names(ed), value = TRUE)

  if (length(eta_cols) < 2) {
    stop("Need at least 2 biomarker summaries for marginal slices")
  }

  if (is.null(main)) main <- paste("Marginal Effect Slices:", transition)

  ## Reference value for a held-fixed covariate. median() stops on factors,
  ## so categorical columns use their most frequent value instead.
  typical_value <- function(v) {
    if (is.factor(v) || is.character(v)) {
      counts <- table(v)
      top <- if (length(counts) > 0) {
        names(counts)[which.max(counts)]
      } else {
        NA_character_
      }
      if (is.factor(v)) {
        factor(top, levels = levels(v), ordered = is.ordered(v))
      } else {
        top
      }
    } else {
      median(v, na.rm = TRUE)
    }
  }

  e1r <- quantile(ed[[eta_cols[1]]], c(0.05, 0.95), na.rm = TRUE)
  e1s <- seq(e1r[1], e1r[2], length.out = n_points)
  q2 <- quantile(ed[[eta_cols[2]]], quantiles, na.rm = TRUE)
  n_slices <- length(q2)

  ## One color / line type per slice, computed once so that the drawn lines
  ## and the legend can never disagree. The first three line types match the
  ## historical 1/2/4 sequence; further slices get distinct types.
  slice_col <- colors[pmin(seq_len(n_slices), length(colors))]
  slice_lty <- rep_len(c(1, 2, 4, 3, 5, 6), n_slices)

  all_slices <- vector("list", n_slices)
  for (i in seq_len(n_slices)) {
    s <- data.frame(x = e1s)
    names(s) <- eta_cols[1]
    s[[eta_cols[2]]] <- q2[i]

    ## Hold every other predictor at a reference value
    for (cv in object$covariates) {
      if (cv %in% names(ed)) {
        s[[cv]] <- typical_value(ed[[cv]])
      }
    }
    if (length(eta_cols) >= 3) {
      s[[eta_cols[3]]] <- median(ed[[eta_cols[3]]], na.rm = TRUE)
    }
    s$time_in_state <- median(ed$time_in_state, na.rm = TRUE)
    s$status <- 1

    ## Link-scale prediction; if that call fails, fall back to the first
    ## column of the term-wise predictions.
    s$f_hat <- tryCatch(
      predict(gf, newdata = s, type = "link"),
      error = function(e) predict(gf, newdata = s, type = "terms")[, 1]
    )

    s$quantile <- paste0("Q", round(quantiles[i] * 100))
    s$q_val <- q2[i]
    all_slices[[i]] <- s
  }

  slices_df <- do.call(rbind, all_slices)

  ## Plot. Pad the y-range additively: multiplying both limits by a factor
  ## shifts the window away from the curves whenever they share a sign.
  f_range <- range(slices_df$f_hat, na.rm = TRUE)
  pad <- 0.05 * diff(f_range)
  if (is.finite(pad) && pad == 0) {
    ## Flat curves: fall back to a small window around the single value
    pad <- max(0.05 * abs(f_range[1]), 0.05)
  }
  ylim <- f_range + c(-pad, pad)

  plot(NA, xlim = range(e1s), ylim = ylim,
       xlab = gsub("eta_", "", eta_cols[1]),
       ylab = "Partial log-HR",
       main = main, ...)

  for (i in seq_len(n_slices)) {
    s <- all_slices[[i]]
    lines(s[[eta_cols[1]]], s$f_hat, col = slice_col[i],
          lwd = 2.5, lty = slice_lty[i])
  }

  mk2_name <- gsub("eta_", "", eta_cols[2])
  legend("topright",
         legend = paste0(mk2_name, " at Q", round(quantiles * 100),
                         " (", round(q2, 1), ")"),
         col = slice_col,
         lwd = 2.5, lty = slice_lty,
         bty = "n", cex = 0.85)

  invisible(slices_df)
}


#' Plot method for jmSurface objects
#'
#' @description Draws one of the association-surface displays for a fitted
#'   \code{jmSurface} object. With no further arguments this is the
#'   perspective plot (\code{plot_surface}) of the first transition.
#' @param x A \code{jmSurface} object.
#' @param transition Which transition to plot. When \code{NULL} (the
#'   default), the first entry of \code{x$transitions} is used.
#' @param type Display to draw: \code{"surface"} calls \code{plot_surface},
#'   \code{"heatmap"} calls \code{contour_heatmap}, and \code{"slices"} calls
#'   \code{marginal_slices}.
#' @param ... Additional arguments forwarded to the selected plotting
#'   function.
#' @return Invisibly returns the data frame produced by the selected
#'   plotting function (\code{plot_surface}, \code{contour_heatmap}, or
#'   \code{marginal_slices}).
#' @export
plot.jmSurface <- function(x, transition = NULL,
                           type = c("surface", "heatmap", "slices"), ...) {
  # Resolve the display type first so an invalid choice fails before any
  # plotting work is attempted.
  type <- match.arg(type)

  if (is.null(transition)) {
    transition <- x$transitions[1]
  }

  # The selected call is the last expression evaluated, so its (invisible)
  # return value becomes this method's return value.
  if (type == "surface") {
    plot_surface(x, transition, ...)
  } else if (type == "heatmap") {
    contour_heatmap(x, transition, ...)
  } else {
    marginal_slices(x, transition, ...)
  }
}
