This vignette is a guide to policy_learn() and some of the associated S3 methods. The purpose of policy_learn() is to specify a policy learning algorithm and estimate an optimal policy. For details on the methodology, see the associated paper (Nordland and Holst 2023).

We consider a fixed two-stage problem as a general setup, simulate data using sim_two_stage(), and create a policy_data object using policy_data():
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))
pd
#> Policy data with n = 2000 observations and maximal K = 2 stages.
#>
#> action
#> stage 0 1 n
#> 1 1017 983 2000
#> 2 819 1181 2000
#>
#> Baseline covariates: B, BB
#> State covariates: L, C
#> Average utility: 0.84
policy_learn() specifies a policy learning algorithm via the type argument: Q-learning (ql), doubly robust Q-learning (drql), doubly robust blip learning (blip), policy tree learning (ptl), and outcome weighted learning (owl).

Because each policy learning type has its own set of control arguments, these are passed as a list using the control argument. To help the user set the required control arguments and to provide documentation, each type has a helper function control_type(), which sets the default control arguments and overwrites values if supplied by the user. As an example we specify a doubly robust blip learner:
pl_blip <- policy_learn(
  type = "blip",
  control = control_blip(
    blip_models = q_glm(formula = ~ BB + L + C)
  )
)
For details on the implementation, see Algorithm 3 in (Nordland and Holst 2023). The only required control argument for blip learning is a model input. The blip_models argument expects a q_model. In this case we input a simple linear model as implemented in q_glm.
The output of policy_learn() is again a function:
pl_blip
#> Policy learner with arguments:
#> policy_data, g_models=NULL, g_functions=NULL,
#> g_full_history=FALSE, q_models, q_full_history=FALSE
In order to apply the policy learner, we need to input a policy_data object and nuisance models g_models and q_models for computing the doubly robust score, as sketched below.
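A minimal sketch of applying the learner; the per-stage nuisance models g_glm() and q_glm() are illustrative choices, and the resulting object po_blip is used in the following sections:

# apply the blip learner to the policy data with simple glm nuisance models
po_blip <- pl_blip(
  pd,
  g_models = list(g_glm(), g_glm()),
  q_models = list(q_glm(), q_glm())
)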
As with policy_eval(), it is possible to cross-fit the doubly robust score used as input to the policy model. The number of folds for the cross-fitting procedure is provided via the L argument. By default, the cross-fitted nuisance models are not saved; they can be saved via the save_cross_fit_models argument:
pl_blip_cross <- policy_learn(
  type = "blip",
  control = control_blip(
    blip_models = q_glm(formula = ~ BB + L + C)
  ),
  L = 2,
  save_cross_fit_models = TRUE
)
po_blip_cross <- pl_blip_cross(
  pd,
  g_models = list(g_glm(), g_glm()),
  q_models = list(q_glm(), q_glm())
)
From a user perspective, nothing has changed. However, the policy object now contains each of the cross-fitted nuisance models:
po_blip_cross$g_functions_cf
#> $`1`
#> $stage_1
#> $model
#>
#> Call: NULL
#>
#> Coefficients:
#> (Intercept) L C B BBgroup2 BBgroup3
#> -0.18321 0.15191 0.90737 -0.03865 0.18927 0.15088
#>
#> Degrees of Freedom: 999 Total (i.e. Null); 994 Residual
#> Null Deviance: 1384
#> Residual Deviance: 1086 AIC: 1098
#>
#>
#> $stage_2
#> $model
#>
#> Call: NULL
#>
#> Coefficients:
#> (Intercept) L C B BBgroup2 BBgroup3
#> 0.24410 0.13150 0.99426 -0.02289 -0.41777 -0.17383
#>
#> Degrees of Freedom: 999 Total (i.e. Null); 994 Residual
#> Null Deviance: 1349
#> Residual Deviance: 1082 AIC: 1094
#>
#>
#> attr(,"full_history")
#> [1] FALSE
#>
#> $`2`
#> $stage_1
#> $model
#>
#> Call: NULL
#>
#> Coefficients:
#> (Intercept) L C B BBgroup2 BBgroup3
#> 0.113952 -0.240397 1.142507 -0.094362 -0.009235 -0.101783
#>
#> Degrees of Freedom: 999 Total (i.e. Null); 994 Residual
#> Null Deviance: 1386
#> Residual Deviance: 1065 AIC: 1077
#>
#>
#> $stage_2
#> $model
#>
#> Call: NULL
#>
#> Coefficients:
#> (Intercept) L C B BBgroup2 BBgroup3
#> 0.15426 0.01307 0.96485 -0.08554 -0.33532 -0.12597
#>
#> Degrees of Freedom: 999 Total (i.e. Null); 994 Residual
#> Null Deviance: 1357
#> Residual Deviance: 1102 AIC: 1114
#>
#>
#> attr(,"full_history")
#> [1] FALSE
Realistic policy learning is implemented for types ql, drql, blip and ptl (for a binary action set). The alpha argument sets the probability threshold for defining the realistic action set. For implementation details, see Algorithm 5 in (Nordland and Holst 2023). Here we set a 5% restriction:
pl_blip_alpha <- policy_learn(
  type = "blip",
  control = control_blip(
    blip_models = q_glm(formula = ~ BB + L + C)
  ),
  alpha = 0.05,
  L = 2
)
po_blip_alpha <- pl_blip_alpha(
  pd,
  g_models = list(g_glm(), g_glm()),
  q_models = list(q_glm(), q_glm())
)
The policy object now lists the alpha level as well as the g-model used to define the realistic action set:
po_blip_alpha$g_functions
#> $stage_1
#> $model
#>
#> Call: NULL
#>
#> Coefficients:
#> (Intercept) L C B BBgroup2 BBgroup3
#> -0.03295 -0.05107 1.02271 -0.06478 0.09582 0.02370
#>
#> Degrees of Freedom: 1999 Total (i.e. Null); 1994 Residual
#> Null Deviance: 2772
#> Residual Deviance: 2161 AIC: 2173
#>
#>
#> $stage_2
#> $model
#>
#> Call: NULL
#>
#> Coefficients:
#> (Intercept) L C B BBgroup2 BBgroup3
#> 0.19814 0.07355 0.97991 -0.05280 -0.37163 -0.14598
#>
#> Degrees of Freedom: 1999 Total (i.e. Null); 1994 Residual
#> Null Deviance: 2707
#> Residual Deviance: 2186 AIC: 2198
#>
#>
#> attr(,"full_history")
#> [1] FALSE
get_policy_functions()
A policy function is useful for evaluating a given policy, or for implementing or simulating from a single-stage policy. However, it is not suited for implementing or simulating from a learned multi-stage policy. To access the policy function for each stage we use get_policy_functions(). In this case we get the second stage policy function:
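A sketch of extracting the stage-two policy function from the policy object (the name pf2 is our own):

# extract the stage-2 policy function from the learned policy object
pf2 <- get_policy_functions(po_blip, stage = 2)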
The stage specific policy requires a data.table with named columns as input and returns a character vector with the recommended actions:
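As an illustration, assuming the stage policy function accepts the history as a data.table of the state and baseline covariates (the values below are made up):

# hypothetical stage-2 histories: state covariates L, C and baseline B, BB
pf2(
  data.table(L = c(1, 0.5),
             C = c(0, 1),
             B = c(0, 1),
             BB = c("group1", "group2"))
)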
get_policy()
Applying the policy learner returns a policy_object containing all of the components needed to specify the learned policy. In this case, the only component of the policy is a model for the blip function:
po_blip$blip_functions$stage_1$blip_model
#> $model
#>
#> Call: NULL
#>
#> Coefficients:
#> (Intercept) BBgroup2 BBgroup3 L C
#> 0.4076 0.2585 0.2231 0.1765 0.8624
#>
#> Degrees of Freedom: 1999 Total (i.e. Null); 1995 Residual
#> Null Deviance: 56820
#> Residual Deviance: 53220 AIC: 12250
#>
#> attr(,"class")
#> [1] "q_glm"
To access and apply the policy itself, use get_policy(). The result behaves as a policy, meaning that we can apply it to any (suitable) policy_data object to get the policy actions:
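A minimal sketch, applying the learned policy to the policy data from above (the name pol is our own):

# extract the learned policy and apply it to the policy data
pol <- get_policy(po_blip)
head(pol(pd))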
sessionInfo()
#> R version 4.4.1 (2024-06-14)
#> Platform: aarch64-apple-darwin23.5.0
#> Running under: macOS Sonoma 14.6.1
#>
#> Matrix products: default
#> BLAS: /Users/oano/.asdf/installs/R/4.4.1/lib/R/lib/libRblas.dylib
#> LAPACK: /Users/oano/.asdf/installs/R/4.4.1/lib/R/lib/libRlapack.dylib; LAPACK version 3.12.0
#>
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#>
#> time zone: Europe/Copenhagen
#> tzcode source: internal
#>
#> attached base packages:
#> [1] splines stats graphics grDevices utils datasets methods
#> [8] base
#>
#> other attached packages:
#> [1] ggplot2_3.5.1 data.table_1.15.4 polle_1.5
#> [4] SuperLearner_2.0-29 gam_1.22-4 foreach_1.5.2
#> [7] nnls_1.5
#>
#> loaded via a namespace (and not attached):
#> [1] sass_0.4.9 utf8_1.2.4 future_1.33.2
#> [4] lattice_0.22-6 listenv_0.9.1 digest_0.6.36
#> [7] magrittr_2.0.3 evaluate_0.24.0 grid_4.4.1
#> [10] iterators_1.0.14 mvtnorm_1.2-5 policytree_1.2.3
#> [13] fastmap_1.2.0 jsonlite_1.8.8 Matrix_1.7-0
#> [16] survival_3.6-4 fansi_1.0.6 scales_1.3.0
#> [19] numDeriv_2016.8-1.1 codetools_0.2-20 jquerylib_0.1.4
#> [22] lava_1.8.0 cli_3.6.3 rlang_1.1.4
#> [25] mets_1.3.4 parallelly_1.37.1 future.apply_1.11.2
#> [28] munsell_0.5.1 withr_3.0.0 cachem_1.1.0
#> [31] yaml_2.3.8 tools_4.4.1 parallel_4.4.1
#> [34] colorspace_2.1-0 ranger_0.16.0 globals_0.16.3
#> [37] vctrs_0.6.5 R6_2.5.1 lifecycle_1.0.4
#> [40] pkgconfig_2.0.3 timereg_2.0.5 progressr_0.14.0
#> [43] bslib_0.7.0 pillar_1.9.0 gtable_0.3.5
#> [46] Rcpp_1.0.13 glue_1.7.0 xfun_0.45
#> [49] tibble_3.2.1 highr_0.11 knitr_1.47
#> [52] farver_2.1.2 htmltools_0.5.8.1 rmarkdown_2.27
#> [55] labeling_0.4.3 compiler_4.4.1