policy_learn

library("data.table")
library("polle")

This vignette is a guide to policy_learn() and some of the associated S3 methods. The purpose of policy_learn is to specify a policy learning algorithm and estimate an optimal policy. For details on the methodology, see the associated paper (Nordland and Holst 2023).

As a general setup, we consider a fixed two-stage problem. We simulate data using sim_two_stage() and create a policy_data object using policy_data():

d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))
pd
#> Policy data with n = 2000 observations and maximal K = 2 stages.
#> 
#>      action
#> stage    0    1    n
#>     1 1017  983 2000
#>     2  819 1181 2000
#> 
#> Baseline covariates: B, BB
#> State covariates: L, C
#> Average utility: 0.84

Specifying and applying a policy learner

policy_learn() specifies a policy learning algorithm via the type argument: Q-learning (ql), doubly robust Q-learning (drql), doubly robust blip learning (blip), policy tree learning (ptl), and outcome weighted learning (owl).

Because each policy learning type has its own control arguments, these are passed as a list via the control argument. To help the user set the required control arguments, and to provide documentation, each type has a helper function control_<type>() (e.g., control_blip()), which sets the default control arguments and overwrites them with any values supplied by the user.
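The defaults-then-override behaviour of such helpers can be illustrated with base R's modifyList(). This is a minimal sketch of the pattern only: the helper control_sketch() and its argument names (model, n_folds, verbose) are hypothetical and are not the arguments of any polle control function, and the check for unknown names is a design choice of the sketch.

```r
# Hypothetical control helper: start from defaults, let user-supplied
# values override them, and reject names that are not known defaults.
control_sketch <- function(...) {
  defaults <- list(model = "glm", n_folds = 1, verbose = FALSE)
  user <- list(...)
  unknown <- setdiff(names(user), names(defaults))
  if (length(unknown) > 0) {
    stop("Unknown control argument(s): ", paste(unknown, collapse = ", "))
  }
  modifyList(defaults, user)
}

ctrl <- control_sketch(model = "linear")
str(ctrl)  # model is overwritten; n_folds and verbose keep their defaults
```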

As an example we specify a doubly robust blip learner:

pl_blip <- policy_learn(
  type = "blip",
  control = control_blip(
    blip_models = q_glm(formula = ~ BB + L + C)
  )
)

For details on the implementation, see Algorithm 3 in (Nordland and Holst 2023). The only required control argument for blip learning is a model for the blip function, supplied via blip_models, which expects a q_model. In this case we input a simple linear model as implemented in q_glm().
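For intuition, the quantities involved can be written out for a simplified single-stage problem with a binary action A in {0, 1}, history H and utility U. The notation is ours and not necessarily that of the paper: g(a | H) = P(A = a | H) is the propensity (estimated by the g-models), Q(H, a) = E[U | H, A = a] is the outcome regression (estimated by the q-models), Z_a is the doubly robust score for action a, B is the blip function (estimated by blip_models), and d is the induced decision rule, here assuming a zero threshold:

```latex
\begin{aligned}
Z_a &= Q(H, a) + \frac{\mathbb{1}\{A = a\}}{g(a \mid H)}\bigl\{U - Q(H, a)\bigr\}, \qquad a \in \{0, 1\},\\
B(h) &= \mathbb{E}\bigl[Z_1 - Z_0 \mid H = h\bigr], \qquad d(h) = \mathbb{1}\{B(h) > 0\}.
\end{aligned}
```

Roughly speaking, with two stages the construction is applied backwards, starting at stage 2; see Algorithm 3 for the exact multi-stage procedure.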

The output of policy_learn() is again a function:

pl_blip
#> Policy learner with arguments:
#> policy_data, g_models=NULL, g_functions=NULL,
#> g_full_history=FALSE, q_models, q_full_history=FALSE
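This two-step design (specify the learner first, supply the data and nuisance models later) is the function factory pattern. A minimal base R sketch of the pattern follows; make_learner() and its toy return value are hypothetical and say nothing about the internals of polle.

```r
# Hypothetical function factory: the outer call fixes the algorithm settings,
# the returned closure is later called with the data and nuisance models.
make_learner <- function(type, control = list()) {
  force(type)
  force(control)
  function(policy_data, g_models = NULL, q_models) {
    # A real learner would fit models here; this toy version only records
    # which settings and inputs it received.
    list(type = type,
         control = control,
         n = nrow(policy_data),
         uses_g_models = !is.null(g_models),
         q_models = q_models)
  }
}

learner <- make_learner("blip", control = list(blip_model = "linear"))
fit <- learner(data.frame(id = 1:3), q_models = "glm")
```

Because the settings are captured in the closure, the same learner can be reused on several data sets, for example inside a resampling loop.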

To apply the policy learner, we need to input a policy_data object together with the nuisance models g_models and q_models, which are used to compute the doubly robust score:


(po_blip <- pl_blip(
  pd,
  g_models = list(g_glm(), g_glm()),
  q_models = list(q_glm(), q_glm())
 ))
#> Policy object with list elements:
#> blip_functions, q_functions, action_set, stage_action_sets,
#> alpha, threshold, K
#> Use 'get_policy' to get the associated policy.
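To make the role of the nuisance models concrete, here is a simplified single-stage sketch in base R. A logistic regression plays the role of the g-model, a linear regression with an action interaction plays the role of the q-model, the doubly robust score for each action is Q(H, a) + 1{A = a} / g(a | H) * (U - Q(H, a)), and the difference of the two scores is regressed on the covariate to obtain a blip model. The simulated data, the single stage and all variable names are assumptions of the sketch; polle's two-stage implementation (Algorithm 3 in Nordland and Holst 2023) differs in detail, e.g., through the backward recursion over stages and optional cross-fitting.

```r
# Simplified single-stage sketch of doubly robust blip learning (base R only).
# Hypothetical simulated data: covariate x, binary action A, utility U.
set.seed(1)
n <- 2000
x <- rnorm(n)
A <- rbinom(n, size = 1, prob = plogis(0.5 * x))
U <- x + A * (1 - x) + rnorm(n)  # true blip: 1 - x
dat <- data.frame(x = x, A = A, U = U)

# Nuisance models, playing the roles of g_glm() and q_glm()
g_fit <- glm(A ~ x, family = binomial(), data = dat)  # g(1 | H)
q_fit <- lm(U ~ x * A, data = dat)                   # Q(H, a)

g1 <- predict(g_fit, type = "response")
q1 <- predict(q_fit, newdata = transform(dat, A = 1))
q0 <- predict(q_fit, newdata = transform(dat, A = 0))

# Doubly robust scores for each action and their difference (blip score)
z1 <- q1 + (A == 1) / g1 * (U - q1)
z0 <- q0 + (A == 0) / (1 - g1) * (U - q0)
dat$Z <- z1 - z0

# Blip model, playing the role of blip_models, and the induced policy:
# recommend action 1 when the estimated blip is positive
blip_fit <- lm(Z ~ x, data = dat)
d <- as.integer(predict(blip_fit) > 0)
```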

Cross-fitting the doubly robust score

As with policy_eval(), it is possible to cross-fit the doubly robust score used as input to the policy model. The number of folds for the cross-fitting procedure is provided via the L argument. By default, the cross-fitted nuisance models are not saved; they can be saved via the save_cross_fit_models argument:

pl_blip_cross <- policy_learn(
  type = "blip",
  control = control_blip(
    blip_models = q_glm(formula = ~ BB + L + C)
  ),
  L = 2,
  save_cross_fit_models = TRUE
)
po_blip_cross <- pl_blip_cross(
   pd,
   g_models = list(g_glm(), g_glm()),
   q_models = list(q_glm(), q_glm())
 )

From a user perspective, nothing has changed. However, the policy object now contains each of the cross-fitted nuisance models:

po_blip_cross$g_functions_cf
#> $`1`
#> $stage_1
#> $model
#> 
#> Call:  NULL
#> 
#> Coefficients:
#> (Intercept)            L            C            B     BBgroup2     BBgroup3  
#>    -0.18321      0.15191      0.90737     -0.03865      0.18927      0.15088  
#> 
#> Degrees of Freedom: 999 Total (i.e. Null);  994 Residual
#> Null Deviance:       1384 
#> Residual Deviance: 1086  AIC: 1098
#> 
#> 
#> $stage_2
#> $model
#> 
#> Call:  NULL
#> 
#> Coefficients:
#> (Intercept)            L            C            B     BBgroup2     BBgroup3  
#>     0.24410      0.13150      0.99426     -0.02289     -0.41777     -0.17383  
#> 
#> Degrees of Freedom: 999 Total (i.e. Null);  994 Residual
#> Null Deviance:       1349 
#> Residual Deviance: 1082  AIC: 1094
#> 
#> 
#> attr(,"full_history")
#> [1] FALSE
#> 
#> $`2`
#> $stage_1
#> $model
#> 
#> Call:  NULL
#> 
#> Coefficients:
#> (Intercept)            L            C            B     BBgroup2     BBgroup3  
#>    0.113952    -0.240397     1.142507    -0.094362    -0.009235    -0.101783  
#> 
#> Degrees of Freedom: 999 Total (i.e. Null);  994 Residual
#> Null Deviance:       1386 
#> Residual Deviance: 1065  AIC: 1077
#> 
#> 
#> $stage_2
#> $model
#> 
#> Call:  NULL
#> 
#> Coefficients:
#> (Intercept)            L            C            B     BBgroup2     BBgroup3  
#>     0.15426      0.01307      0.96485     -0.08554     -0.33532     -0.12597  
#> 
#> Degrees of Freedom: 999 Total (i.e. Null);  994 Residual
#> Null Deviance:       1357 
#> Residual Deviance: 1102  AIC: 1114
#> 
#> 
#> attr(,"full_history")
#> [1] FALSE
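The cross-fitting idea can be sketched in base R for a single propensity model: the observations are split into L folds, a model is fitted on all but one fold, and the predictions for the held-out fold come from that model. With n = 2000 and L = 2, each fitted model uses 1000 observations, which matches the 999 total degrees of freedom in the output above. The small simulated data set and variable names are hypothetical, and the sketch is not polle's implementation.

```r
# Minimal cross-fitting sketch for a propensity model (base R only).
set.seed(1)
n <- 200
dat <- data.frame(x = rnorm(n))
dat$A <- rbinom(n, size = 1, prob = plogis(dat$x))

L <- 2
folds <- sample(rep(seq_len(L), length.out = n))  # random fold labels
g_cf <- numeric(n)                                # cross-fitted predictions
models <- vector("list", L)                       # saved fold models

for (l in seq_len(L)) {
  train <- folds != l
  fit <- glm(A ~ x, family = binomial(), data = dat[train, ])
  models[[l]] <- fit
  # predictions for fold l come from a model that never saw fold l
  g_cf[!train] <- predict(fit, newdata = dat[!train, ], type = "response")
}
```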

Realistic policy learning

Realistic policy learning is implemented for types ql, drql, blip and ptl (for a binary action set). The alpha argument sets the probability threshold for defining the realistic action set. For implementation details, see Algorithm 5 in (Nordland and Holst 2023). Here we set a 5% restriction:

pl_blip_alpha <- policy_learn(
  type = "blip",
  control = control_blip(
    blip_models = q_glm(formula = ~ BB + L + C)
  ),
  alpha = 0.05,
  L = 2
)
po_blip_alpha <- pl_blip_alpha(
   pd,
   g_models = list(g_glm(), g_glm()),
   q_models = list(q_glm(), q_glm())
 )

The policy object now lists the alpha level as well as the g-model used to define the realistic action set:

po_blip_alpha$alpha
#> [1] 0.05
po_blip_alpha$g_functions
#> $stage_1
#> $model
#> 
#> Call:  NULL
#> 
#> Coefficients:
#> (Intercept)            L            C            B     BBgroup2     BBgroup3  
#>    -0.03295     -0.05107      1.02271     -0.06478      0.09582      0.02370  
#> 
#> Degrees of Freedom: 1999 Total (i.e. Null);  1994 Residual
#> Null Deviance:       2772 
#> Residual Deviance: 2161  AIC: 2173
#> 
#> 
#> $stage_2
#> $model
#> 
#> Call:  NULL
#> 
#> Coefficients:
#> (Intercept)            L            C            B     BBgroup2     BBgroup3  
#>     0.19814      0.07355      0.97991     -0.05280     -0.37163     -0.14598  
#> 
#> Degrees of Freedom: 1999 Total (i.e. Null);  1994 Residual
#> Null Deviance:       2707 
#> Residual Deviance: 2186  AIC: 2198
#> 
#> 
#> attr(,"full_history")
#> [1] FALSE
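A minimal base R sketch of how a probability threshold can restrict a binary decision follows. It assumes that an action is considered realistic when its estimated propensity is at least alpha, and that the unrestricted rule recommends action 1 when the estimated blip is positive; both the boundary handling and the helper realistic_policy() are assumptions of the sketch, and the input values are hypothetical. See Algorithm 5 in (Nordland and Holst 2023) for the actual procedure.

```r
# Hypothetical helper: restrict a blip-based binary decision to realistic
# actions. Assumes alpha < 0.5, so at most one action is unrealistic.
# g1:   estimated propensities P(A = 1 | H)
# blip: estimated blips (expected gain of action 1 over action 0)
realistic_policy <- function(g1, blip, alpha) {
  d <- ifelse(blip > 0, 1L, 0L)  # unrestricted recommendation
  d[g1 < alpha] <- 0L            # action 1 too rare: recommend 0
  d[(1 - g1) < alpha] <- 1L      # action 0 too rare: recommend 1
  d
}

g1 <- c(0.02, 0.50, 0.97, 0.30, 0.99)
blip <- c(2.0, -0.5, -1.0, 0.8, 0.3)
d <- realistic_policy(g1, blip, alpha = 0.05)
```

Compared with alpha = 0, the restriction changes the recommendation for the first and third subjects, whose preferred action is too rarely observed given their history.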

Implementation/Simulation and get_policy_functions()

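A policy (as returned by get_policy()) is well suited for evaluating a given policy, and even for implementing or simulating from a single-stage policy. However, it is not convenient for implementing or simulating from a learned multi-stage policy: it takes a complete policy_data object as input, whereas in a simulation the stage 2 history only becomes available after the stage 1 action has been taken. To access the policy function for each stage, we use get_policy_functions(). In this case we get the second stage policy function: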

pf_blip <- get_policy_functions(po_blip, stage = 2)

The stage-specific policy function requires a data.table with named columns as input and returns a character vector with the recommended actions:

pf_blip(
  H = data.table(BB = c("group2", "group1"),
                 L = c(1, 0),
                 C = c(1, 2))
)
#> [1] "1" "1"
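The sequential structure that makes stage-specific functions necessary can be sketched in base R. Here pf_1() and pf_2() are hypothetical stand-ins that mimic the interface above (a history in, a character vector of actions out), the histories are data.frames rather than data.tables to keep the sketch dependency-free, and the stage 2 dynamics are invented for illustration.

```r
# Hypothetical stage-specific policy functions mimicking the interface above:
# each takes a history H and returns a character vector of actions.
pf_1 <- function(H) ifelse(H$C > 0, "1", "0")
pf_2 <- function(H) ifelse(H$L + H$C > 1, "1", "0")

set.seed(1)
n <- 5
H_1 <- data.frame(L = rnorm(n), C = rnorm(n))
a_1 <- pf_1(H_1)  # stage 1 actions from the stage 1 history

# Hypothetical dynamics: the stage 2 state depends on the stage 1 action,
# so H_2 cannot be built before a_1 is known.
H_2 <- data.frame(L = H_1$L + as.numeric(a_1), C = rnorm(n))
a_2 <- pf_2(H_2)
```

In practice the stand-ins would be replaced by the learned functions returned by get_policy_functions() for stages 1 and 2, with histories containing the columns used by the fitted models.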

Policy objects and get_policy()

Applying the policy learner returns a policy object containing all of the components needed to specify the learned policy. In this case, the only component of the policy is a model for the blip function:

po_blip$blip_functions$stage_1$blip_model
#> $model
#> 
#> Call:  NULL
#> 
#> Coefficients:
#> (Intercept)     BBgroup2     BBgroup3            L            C  
#>      0.4076       0.2585       0.2231       0.1765       0.8624  
#> 
#> Degrees of Freedom: 1999 Total (i.e. Null);  1995 Residual
#> Null Deviance:       56820 
#> Residual Deviance: 53220     AIC: 12250
#> 
#> attr(,"class")
#> [1] "q_glm"
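To see how the blip model translates into a decision, the stage 1 blip for a given history can be computed by hand from the printed (rounded) coefficients, with group1 as the reference level of BB. This sketch assumes that action "1" is recommended when the estimated blip is positive (a zero threshold) and ignores any realistic-action restriction, which was not used for po_blip; the two histories are hypothetical.

```r
# Rounded stage 1 blip coefficients copied from the output above
coefs <- c(`(Intercept)` = 0.4076, BBgroup2 = 0.2585, BBgroup3 = 0.2231,
           L = 0.1765, C = 0.8624)

# Two hypothetical stage 1 histories, dummy-coded like the model:
# history 1: BB = "group1", L = 0,  C = 1
# history 2: BB = "group3", L = -1, C = -1
X <- rbind(c(1, 0, 0,  0,  1),
           c(1, 0, 1, -1, -1))
colnames(X) <- names(coefs)

blip <- drop(X %*% coefs)               # 0.4076 + 0.8624 and 0.4076 + 0.2231 - 0.1765 - 0.8624
action <- ifelse(blip > 0, "1", "0")    # assumed zero threshold
```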

To access and apply the policy itself, use get_policy(). The returned object behaves like any other policy, meaning that we can apply it to any (suitable) policy_data object to get the policy actions:

get_policy(po_blip)(pd) |> head(4)
#> Key: <id, stage>
#>       id stage      d
#>    <int> <int> <char>
#> 1:     1     1      1
#> 2:     1     2      1
#> 3:     2     1      0
#> 4:     2     2      0

SessionInfo

sessionInfo()
#> R version 4.4.1 (2024-06-14)
#> Platform: aarch64-apple-darwin23.5.0
#> Running under: macOS Sonoma 14.6.1
#> 
#> Matrix products: default
#> BLAS:   /Users/oano/.asdf/installs/R/4.4.1/lib/R/lib/libRblas.dylib 
#> LAPACK: /Users/oano/.asdf/installs/R/4.4.1/lib/R/lib/libRlapack.dylib;  LAPACK version 3.12.0
#> 
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#> 
#> time zone: Europe/Copenhagen
#> tzcode source: internal
#> 
#> attached base packages:
#> [1] splines   stats     graphics  grDevices utils     datasets  methods  
#> [8] base     
#> 
#> other attached packages:
#> [1] ggplot2_3.5.1       data.table_1.15.4   polle_1.5          
#> [4] SuperLearner_2.0-29 gam_1.22-4          foreach_1.5.2      
#> [7] nnls_1.5           
#> 
#> loaded via a namespace (and not attached):
#>  [1] sass_0.4.9          utf8_1.2.4          future_1.33.2      
#>  [4] lattice_0.22-6      listenv_0.9.1       digest_0.6.36      
#>  [7] magrittr_2.0.3      evaluate_0.24.0     grid_4.4.1         
#> [10] iterators_1.0.14    mvtnorm_1.2-5       policytree_1.2.3   
#> [13] fastmap_1.2.0       jsonlite_1.8.8      Matrix_1.7-0       
#> [16] survival_3.6-4      fansi_1.0.6         scales_1.3.0       
#> [19] numDeriv_2016.8-1.1 codetools_0.2-20    jquerylib_0.1.4    
#> [22] lava_1.8.0          cli_3.6.3           rlang_1.1.4        
#> [25] mets_1.3.4          parallelly_1.37.1   future.apply_1.11.2
#> [28] munsell_0.5.1       withr_3.0.0         cachem_1.1.0       
#> [31] yaml_2.3.8          tools_4.4.1         parallel_4.4.1     
#> [34] colorspace_2.1-0    ranger_0.16.0       globals_0.16.3     
#> [37] vctrs_0.6.5         R6_2.5.1            lifecycle_1.0.4    
#> [40] pkgconfig_2.0.3     timereg_2.0.5       progressr_0.14.0   
#> [43] bslib_0.7.0         pillar_1.9.0        gtable_0.3.5       
#> [46] Rcpp_1.0.13         glue_1.7.0          xfun_0.45          
#> [49] tibble_3.2.1        highr_0.11          knitr_1.47         
#> [52] farver_2.1.2        htmltools_0.5.8.1   rmarkdown_2.27     
#> [55] labeling_0.4.3      compiler_4.4.1

References

Nordland, Andreas, and Klaus K. Holst. 2023. “Policy Learning with the Polle Package.” https://doi.org/10.48550/arXiv.2212.02335.