# printing
#' @param x An object of class `DTRreg`.
#' @rdname DTRreg
#' @importFrom utils tail
#' @export
print.DTRreg <- function(x, ...) {
  # Display a fitted DTRreg object: per-stage blip estimates (with standard
  # errors and 95% Wald intervals when a variance estimate is available),
  # non-regularity warnings, and the estimated decision rules for stages
  # whose treatment type is "bin" or "cts.l". Returns `x` invisibly.

  K <- x$K
  is_dtr <- x$setup$type == "DTR"
  se_requested <- x$setup$var.estim != "none"

  # fixed 4-decimal formatting used for every printed number
  fmt4 <- function(v) sprintf("%#.4f", v)

  # Render one stage's rule as "b1 + b2 name2 - b3 name3 > 0". The first
  # coefficient is printed without its name; a negative leading term reads
  # e.g. "-1.2345" instead of a dangling "-" glued onto the preceding text.
  format_rule <- function(psi) {
    est <- fmt4(abs(psi))
    neg <- !is.na(psi) & psi < 0.0
    lead <- paste0(if (isTRUE(neg[1L])) "-" else "", est[1L])
    rest <- paste(ifelse(neg[-1L], "-", "+"), est[-1L], names(psi)[-1L])
    paste0(paste(c(lead, rest), collapse = " "), " > 0")
  }

  # sep = "" so that no stray spaces are inserted between the pieces
  cat(if (is_dtr) paste0("DTR estimation over ", K, " stages:\n\n") else "",
      "Blip parameter estimates",
      if (se_requested) "\n" else " (standard errors not requested)\n",
      sep = "")

  for (k in seq_len(K)) {
    psi_k <- x$psi[[k]]
    cat("Stage ", k, " (n = ", x$analysis$n[[k]], ")\n", sep = "")

    # a stage's covariance entry may be the string "aborted"; isTRUE() keeps
    # NA comparisons from reaching `if`
    aborted <- is.list(x$covmat) && isTRUE(any(x$covmat[[k]] == "aborted"))

    if (!se_requested || aborted) {
      result <- data.frame("Estimate" = fmt4(psi_k))
    } else {
      # only the trailing length(psi_k) variances are used; presumably the
      # blip parameters are stored last in the covariance matrix
      se <- utils::tail(sqrt(diag(x$covmat[[k]])), length(psi_k))
      lower <- psi_k - 1.96 * se
      upper <- psi_k + 1.96 * se
      # check.names = FALSE keeps the "95%_CI_*" headers from being mangled
      # into syntactic names such as "X95._CI_Lower"
      result <- data.frame("Estimate" = fmt4(psi_k),
                           "Std_Error" = fmt4(se),
                           "95%_CI_Lower" = fmt4(pmin(lower, upper)),
                           "95%_CI_Upper" = fmt4(pmax(lower, upper)),
                           check.names = FALSE)
    }
    rownames(result) <- names(psi_k)
    print(result)
    cat("\n")
  }

  if (is_dtr) {
    # warn about possible non-regularity at binary-treatment stages
    nonreg <- x$nonreg
    if (any(!is.na(nonreg))) {
      # stages with an NA probability are never flagged, so any() cannot be NA
      flag <- unlist(x$analysis$cts) == "bin" & !is.na(nonreg) & nonreg > 0.05
      if (any(flag)) {
        plural <- sum(flag) > 1L
        cat("\nWarning: possible non-regularity at ",
            if (plural) "stages " else "stage ",
            paste(which(flag), collapse = ", "),
            if (plural) " (probs = " else " (prob = ",
            paste(format(nonreg[flag], digits = 3), collapse = ", "),
            if (plural) ", respectively)\n" else ")\n", sep = "")
      }
    }

    # rules are written out explicitly only for "bin" and "cts.l" stages;
    # show the header whenever at least one such rule follows
    rule_types <- unlist(x$analysis$cts)[seq_len(K)]
    if (any(rule_types %in% c("bin", "cts.l"))) {
      cat("Recommended dynamic treatment regimen:\n")
    }

    for (j in seq_len(K)) {
      rule <- format_rule(x$psi[[j]])
      if (x$analysis$cts[[j]] == "bin") {
        cat("Stage ", j, ": treat if ", rule, "\n", sep = "")
      } else if (x$analysis$cts[[j]] == "cts.l") {
        cat("Stage ", j, ": maximum treatment if ", rule,
            ", otherwise minimum treatment.\n", sep = "")
      }
    }
  }

  invisible(x)
}


#' @param object An object of class `DTRreg`.
#' @param ... Ignored.
#' @rdname DTRreg
#' @export
summary.DTRreg <- function(object, ...) {
  # No separate summary object is constructed: the summary is the printed
  # display, so this simply delegates to the print method and hands back
  # whatever that method returns.
  print.DTRreg(x = object)
}

#' Optimal Outcome Prediction for DTRs
#' 
#' Predicted outcome assuming optimal treatment (according to analysis via 
#'   G-estimation or dWOLS) was followed.  Assumes blip and treatment-free 
#'   models are correctly specified.
#'  
#' This function may be used in a similar fashion to more traditional modeling 
#'   commands (such as lm).  Users are referred to the primary `DTRreg()` 
#'   and `DWSurv()` help pages
#'   (and associated literature) for information concerning model specification.
#'   In particular, we note that the predict function assumes that the 
#'   treatment-free model has been correctly specified, as the treatment-free 
#'   parameters are used in the prediction process.
#'   
#' @param object A model object generated by the function `DTRreg()` or `DWSurv()`.
#' @param newdata A dataset (usually the data analyzed by DTRreg) for which 
#'   predicted outcomes are desired.  If omitted, the training data are used
#'   with the estimated optimal first-stage treatment.  If provided, it must
#'   contain no missing values, and variable names should correspond to those
#'   presented to `DTRreg()` or `DWSurv()`.
#' @param treat.range If treatment is continuous (rather than binary), a list
#'   whose first element is a vector of the form c(min,max) specifying the
#'   minimum and maximum value the treatment may take at stage 1.  Only used
#'   when `newdata` is provided.
#'   If unspecified, this will be inferred from the 
#'   treat.range provided with use of the original DTRreg command.  As such, if 
#'   no treatment range was specified there either, treat.range will be the 
#'   minimum and maximum observed first stage treatment.
#' @param ... Space for additional arguments (not currently used)
#' 
#' @return Predicted outcome values, as returned by `predict()` applied to the
#'   first-stage outcome model with the optimal stage 1 treatment substituted
#'   (exponentiated for `DWSurv` objects).
#' 
#' @references
#' Chakraborty, B., Moodie, E. E. M. (2013) \emph{Statistical Methods for 
#'   Dynamic Treatment Regimes}. New York: Springer.
#'   
#' Robins, J. M. (2004) \emph{Optimal structural nested models for optimal 
#'   sequential decisions}. In Proceedings of the Second Seattle Symposium on 
#'   Biostatistics, D. Y. Lin and P. J. Heagerty (eds), 189-326. 
#'   New York: Springer.
#'   
#' Wallace, M. P., Moodie, E. M. (2015) Doubly-Robust Dynamic Treatment Regimen 
#'   Estimation Via Weighted Least Squares. \emph{Biometrics} \bold{71}(3), 
#'   636-644 (doi:10.1111/biom.12306.)
#'   
#' @author Michael Wallace
#' 
#' @template data_setup
#' @template model_definitions
#' @template g_est
#' @examples
#'   
#' # predicted Y for optimal treatment
#' dat <- data.frame(X1, X2, A1, A2)
#' predict(mod1, newdata = dat)
#' @concept dynamic treatment regimens
#' @concept adaptive treatment strategies
#' @concept personalized medicine
#' @concept g-estimation
#' @concept dynamic weighted ordinary least squares
#' @export
predict.DTRreg <- function(object, newdata, treat.range = NULL, ...) {
  
  stopifnot("`treat.range` must be a list of length `stage`" = is.null(treat.range) ||
              {is.list(treat.range) && all(lengths(treat.range) == 2L)})

  tx_var <- object$setup$tx.vars[1L]
  tx_type <- object$analysis$cts[[1L]]
  models <- object$setup$models[[1L]]
  outcome_fit <- object$analysis$outcome.fit[[1L]]

  if (!missing(newdata)) {
    if (!all(stats::complete.cases(newdata))) {
      stop("if provided, `newdata` must be complete", call. = FALSE)
    }
    
    # helper object that computes the optimal stage 1 treatment for newdata;
    # unsupported types fail with a clear message instead of an undefined
    # variable further down
    cts_obj <- switch(
      tx_type,
      "bin" = Binary$new(tf.model = models$tf, 
                         blip.model = models$blip, 
                         tx.var = tx_var),
      "multinom" = MultiNom$new(tf.model = models$tf, 
                                blip.model = models$blip, 
                                tx.var = tx_var, 
                                tx.levels = levels(object$training_data$A[[1L]])),
      "cts.q" = {
        tx.range <- if (is.list(treat.range)) {
          treat.range[[1L]]
        } else {
          object$setup$tx.range[[1L]]
        }
        ContQuadraticBlip$new(tf.model = models$tf, 
                              blip.model = models$blip, 
                              tx.var = tx_var,
                              treat.range = tx.range)
      },
      stop("prediction with `newdata` is not supported for stage 1 ",
           "treatment type '", tx_type, "'", call. = FALSE)
    )
    newdata[[tx_var]] <- cts_obj$opt(outcome.fit = outcome_fit, data = newdata)
  } else {
    newdata <- object$training_data$data
    newdata[[tx_var]] <- object$analysis$opt.treat[[1L]]
  }

  if (tx_type == "cts.q") {
    # presumably the fitted outcome model includes the squared treatment under
    # this internal column name, so it must be recomputed for the new treatment
    newdata[["l__txvar2__l"]] <- newdata[[tx_var]]^2
  }

  pred <- stats::predict(outcome_fit, newdata = newdata)
  # DWSurv outcomes are modelled on the log scale (the prediction is exp()'d)
  if (inherits(object, "DWSurv")) pred <- exp(pred)
  pred
}

#' @rdname DTRreg
#' @export
coef.DTRreg <- function(object,...) {
  # Blip parameter estimates as a list with one element per stage, labelled
  # "stage_1", "stage_2", ...
  stage_labels <- paste0("stage_", seq_len(object$K))
  stats::setNames(object$psi, stage_labels)
}

#' Confidence Interval Calculations for DTRs
#' 
#' Confidence intervals for blip parameters, with the option of 
#'   constructing the confidence intervals using the percentile method 
#'   when bootstrap is used.
#'   
#' @param object A model object generated by the function DTRreg.
#' @param parm Not available for DTRreg objects (ignored).
#' @param level The confidence level required; a single number strictly
#'   between 0 and 1.
#' @param type Typical Wald-type confidence interval "se" (default) or 
#'   confidence intervals derived with the percentile method "percentile" 
#'   (bootstrap variance estimates only).
#' @param ... Space for additional arguments (not currently used).
#' 
#' @return A list with one element per stage (named `stage_1`, `stage_2`, ...),
#'   each a two-column matrix giving lower and upper confidence limits for each
#'   parameter. The columns are labelled as (1-level)/2 and 1 - (1-level)/2 in 
#'   percentage (by default 2.5\% and 97.5\%). If no variance estimate is
#'   available, a message is printed and `NULL` is returned invisibly.
#'   
#' @importFrom stats confint qnorm quantile
#' @export
confint.DTRreg <- function(object, parm = NULL, level = 0.95, 
                           type = c("se", "percentile"), ...) {
  
  type <- match.arg(type)
  stopifnot("`level` must be a single number strictly between 0 and 1" =
              is.numeric(level) && length(level) == 1L &&
              isTRUE(level > 0.0 && level < 1.0))
  
  if (object$setup$var.estim == "none") {
    cat("confint() only available when DTRreg called with a variance",
        "estimation option\n")
    return(invisible(NULL))
  }
  
  # A stage whose covariance entry is the string "aborted" has no usable
  # variance estimate; every stage is checked, not only the first.
  aborted <- vapply(object$covmat,
                    function(cm) isTRUE(any(cm == "aborted")),
                    logical(1L))
  if (any(aborted)) {
    cat("variance estimation aborted; unable to obtain confidence intervals\n")
    return(invisible(NULL))
  }
  
  if (type == "percentile" && object$setup$var.estim != "bootstrap") {
    stop("Percentile confidence intervals only available with bootstrap.",
         call. = FALSE)
  }
  
  K <- object$K
  probs <- c(1.0 - level, 1.0 + level) / 2.0
  limit_labels <- paste(100.0 * probs, "%")
  
  intervals <- lapply(seq_len(K), function(j) {
    est <- object$psi[[j]]
    ci <- matrix(NA_real_, nrow = length(est), ncol = 2L,
                 dimnames = list(names(est), limit_labels))
    
    if (type == "se") {
      # only the trailing length(est) variances are used, matching
      # print.DTRreg(); presumably the blip parameters are stored last
      se <- utils::tail(sqrt(diag(object$covmat[[j]])), length(est))
      ci[, 1L] <- est + stats::qnorm(probs[1L]) * se
      ci[, 2L] <- est + stats::qnorm(probs[2L]) * se
    } else {
      # percentile method: column-wise quantiles of the bootstrap estimates
      ci[, 1L] <- apply(object$psi.boot[[j]], 2L, stats::quantile,
                        probs = probs[1L])
      ci[, 2L] <- apply(object$psi.boot[[j]], 2L, stats::quantile,
                        probs = probs[2L])
    }
    ci
  })
  names(intervals) <- paste0("stage_", seq_len(K))
  intervals
}

#' Diagnostic Plots for DTR Estimation
#' 
#' Diagnostic plots for assessment of treatment, treatment-free, and blip models 
#'   following DTR estimation using DTRreg or DWSurv.
#'   
#' DTR estimation using G-estimation and dWOLS requires the specification of 
#'   three models: the treatment, treatment-free, and blip.  The treatment model 
#'   may be assessed via standard diagnostics, whereas the treatment-free and 
#'   blip models may be simultaneously assessed using diagnostic plots 
#'   introduced by Rich et al.  The plot() function first presents diagnostic 
#'   plots that assess the latter, plotting fitted values against residuals and 
#'   covariates following DTR estimation.  If there is any evidence of a 
#'   relationship between the variables in these plots, this is evidence that at 
#'   least one of the blip or treatment-free models is mis-specified.
#'   
#' Following these plots, the plot() function will present standard diagnostic 
#'   plots for the treatment model.  These are produced directly by the standard 
#'   plot() command applied to the models that were fit.  For example, if 
#'   treatment is binary, the resulting plots are the same as those that are 
#'   generated by the plot() command applied to a glm object for logistic 
#'   regression.
#' 
#' @param x A model object generated by the functions DTRreg and DWSurv.
#' @param ... Space for additional arguments (not currently used)
#' @references
#' Chakraborty, B., Moodie, E. E. M. (2013) \emph{Statistical Methods for 
#'   Dynamic Treatment Regimes}. New York: Springer.
#'   
#' Rich B., Moodie E. E. M., Stephens D. A., Platt R. W. (2010) Model 
#'   Checking with Residuals for G-estimation of Optimal Dynamic Treatment 
#'   Regimes. \emph{International Journal of Biostatistics} \bold{6}(2), 
#'   Article 12.
#'   
#' Robins, J. M. (2004) \emph{Optimal structural nested models for optimal 
#'   sequential decisions}. In Proceedings of the Second Seattle Symposium on 
#'   Biostatistics, D. Y. Lin and P. J. Heagerty (eds), 189-326. 
#'   New York: Springer.
#'   
#' Wallace, M. P., Moodie, E. M. (2015) Doubly-Robust Dynamic Treatment 
#'   Regimen Estimation Via Weighted Least Squares. \emph{Biometrics} 
#'   \bold{71}(3), 636-644 (doi:10.1111/biom.12306.)
#' @author Michael Wallace
#' @template data_setup
#' @template model_definitions
#' @template g_est
#' @examples
#'  
#' # model diagnostics: note treatment-free model is mis-specified
#' plot(mod1)
#' 
#' @import graphics
#' @importFrom stats lowess
#' @concept dynamic treatment regimens
#' @concept adaptive treatment strategies
#' @concept personalized medicine
#' @concept g-estimation
#' @concept dynamic weighted ordinary least squares
#' @export
plot.DTRreg <- function(x, ...) {

  prompt <- "Hit <Return> to see next plot:"

  # Scatter of residuals against `xvals` with a dashed zero line and a lowess
  # smooth; visible structure suggests blip/treatment-free misspecification.
  plot_residuals <- function(xvals, resid, xlab, main) {
    plot(xvals, resid, xlab = xlab, ylab = "Residuals", main = main)
    abline(h = 0.0, lty = 2L, col = 4L)
    points(stats::lowess(xvals, resid), type = "l", col = 2L, lwd = 2L)
  }

  # Standard plot() diagnostics for each stage's fitted nuisance model.
  # Logical (or missing) entries mean there is no model object to plot.
  plot_stage_models <- function(fits, header, caption) {
    for (j in seq_len(x$K)) {
      fit <- fits[[j]]
      if (!is.null(fit) && !is.logical(fit)) {
        cat("Stage", j, header)
        plot(fit, sub.caption = paste("Stage", j, caption))
        readline(prompt)
      }
    }
  }

  # outcome-model residual diagnostics, stage by stage
  for (j in seq_len(x$K)) {
    # rows of the training data that reached stage j
    cases <- x$analysis$last.stage >= j

    readline(prompt)
    cat("Stage", j, "outcome model:\n")

    fitted_j <- x$analysis$outcome.fit[[j]]$fitted.values
    resid_j <- x$analysis$Y[[j]] - fitted_j

    plot_residuals(fitted_j, resid_j,
                   xlab = "Fitted values",
                   main = paste("Stage", j, "Residuals vs Fitted"))

    # residuals against each blip covariate (intercept column dropped);
    # drop = FALSE keeps a one-column data frame from collapsing to a vector
    blip <- x$setup$models[[j]]$blip
    blip_mm <- stats::model.matrix(blip,
                                   x$training_data$data[cases, , drop = FALSE])
    if (attr(stats::terms(blip), "intercept") == 1L) {
      blip_mm <- blip_mm[, -1L, drop = FALSE]
    }

    for (i in seq_len(ncol(blip_mm))) {
      var_name <- colnames(blip_mm)[i]
      readline(prompt)
      cat("Stage", j, "blip function variable", var_name, "\n")
      plot_residuals(blip_mm[, i], resid_j,
                     xlab = var_name,
                     main = paste("Stage", j, "Residuals vs", var_name))
    }
  }

  plot_stage_models(x$analysis$tx.mod.fitted, "treatment model:\n", "treatment")

  if (inherits(x, "DWSurv")) {
    plot_stage_models(x$analysis$cens.mod.fitted,
                      "censoring model:\n", "censoring")
  } else {
    plot_stage_models(x$analysis$cc.mod.fitted,
                      "complete cases model:\n", "complete cases")
  }

  invisible(NULL)
}