
Compute ensemble model outputs as a linear pool, otherwise known as a distributional mixture, of component model outputs for each combination of model task, output type, and output type id. Supported output types include mean, quantile, cdf, and pmf.

Usage

linear_pool(
  model_outputs,
  weights = NULL,
  weights_col_name = "weight",
  model_id = "hub-ensemble",
  task_id_cols = NULL,
  n_samples = 10000,
  ...
)

Arguments

model_outputs

an object of class model_out_tbl with component model outputs (e.g., predictions).

weights

an optional data.frame with component model weights. If provided, it should have a column named model_id and a column containing model weights. Optionally, it may contain additional columns corresponding to task ID variables, output_type, or output_type_id, if weights are specific to values of those variables. The default is NULL, in which case an equally weighted ensemble is calculated. The weights should be validated before they are passed to this function.
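For illustration, a weights data frame for three hypothetical component models a, b, and c could look like the following. The second variant adds a target column so that weights differ by target. The model IDs, target values, and weight values are made up for this sketch; the column name weight matches the default weights_col_name.

```r
# With weights = NULL an equally weighted ensemble is used; here explicit
# weights are given for three hypothetical component models.
weights_global <- data.frame(
  model_id = c("a", "b", "c"),
  weight = c(0.25, 0.5, 0.25)
)

# Weights that depend on a task ID variable (hypothetical `target` values):
# one row per combination of model_id and target.
weights_by_target <- data.frame(
  model_id = rep(c("a", "b", "c"), times = 2),
  target = rep(c("inc death", "inc hosp"), each = 3),
  weight = c(0.25, 0.5, 0.25, 0.4, 0.4, 0.2)
)

# In this example the weights within each target sum to 1.
tapply(weights_by_target$weight, weights_by_target$target, sum)
```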

weights_col_name

character string naming the column in weights that contains the model weights. Defaults to "weight".

model_id

character string with the identifier to use for the ensemble model. Defaults to "hub-ensemble".

task_id_cols

character vector with names of columns in model_outputs that specify modeling tasks. Defaults to NULL, in which case all columns in model_outputs other than "model_id", "output_type", "output_type_id", and "value" are used as task ids.
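The default task ID columns described above amount to a set difference on column names. A minimal sketch, using a hypothetical helper default_task_id_cols() that is not part of the package API, and a made-up single-row model output table:

```r
# Hypothetical helper mirroring the documented default for `task_id_cols`:
# every column except the four standard model output columns.
default_task_id_cols <- function(model_outputs) {
  std_cols <- c("model_id", "output_type", "output_type_id", "value")
  setdiff(colnames(model_outputs), std_cols)
}

example_outputs <- data.frame(
  model_id = "a",
  location = "US",
  target = "inc death",
  horizon = 1,
  output_type = "mean",
  output_type_id = NA,
  value = 10
)

default_task_id_cols(example_outputs)
# returns c("location", "target", "horizon")
```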

n_samples

a numeric value specifying the number of samples to use when calculating quantiles from an estimated quantile function (relevant only for output_type "quantile"). Defaults to 1e4.

...

additional parameters passed to distfromq::make_q_fn, specifying how to estimate a quantile function from the provided quantile levels and quantile values for output_type "quantile".

Value

a model_out_tbl object of ensemble predictions. Note that any additional columns in the input model_outputs are dropped.

Details

The underlying computation depends on the output_type. When the output_type is cdf, pmf, or mean, this function calls simple_ensemble to calculate a (weighted) mean of the component model outputs. For the CDF and PMF, this is the definition of a linear pool: the pooled probability at each output_type_id is the weighted average of the component probabilities. For the mean output type, it is justified by the fact that the mean of the linear pool equals the (weighted) mean of the means of the component distributions.
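The two identities above can be checked numerically in base R. This is an illustrative re-derivation, not the package's simple_ensemble code; it assumes three normal components with means -3, 0, and 3, standard deviation 1, and weights 0.25, 0.5, and 0.25 (the same setup as the example below). A PMF would be pooled the same way as the CDF values shown here.

```r
# Three normal components with weights w and means mu (sd = 1).
w <- c(0.25, 0.5, 0.25)
mu <- c(-3, 0, 3)
x <- c(-2, 0, 2)  # points (output_type_id values) at which CDFs are reported

# CDF output type: the linear pool's CDF at each x is the weighted
# mean of the component CDF values at that x.
component_cdfs <- sapply(mu, function(m) pnorm(x, mean = m))  # rows: x, cols: models
pool_cdf <- as.vector(component_cdfs %*% w)

# Mean output type: the mean of the mixture equals the weighted mean of
# the component means. Compare with a Monte Carlo draw from the mixture.
pool_mean <- sum(w * mu)
set.seed(1)
z <- sample(seq_along(w), 1e5, replace = TRUE, prob = w)  # component labels
draws <- rnorm(1e5, mean = mu[z])
c(exact = pool_mean, monte_carlo = mean(draws))
```

By symmetry the pooled CDF at x = 0 is 0.5, and the exact pooled mean is 0, with the Monte Carlo estimate close to it.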

When the output_type is quantile, we obtain the quantiles of a linear pool in three steps:

  1. Interpolate and extrapolate from the provided quantiles for each component model to obtain an estimate of the CDF of that distribution.

  2. Draw samples from the distribution for each component model. To reduce Monte Carlo variability, we use quasi-random samples corresponding to quantiles of the estimated distribution.

  3. Collect the samples from all component models and extract the desired quantiles.

Steps 1 and 2 in this process are performed by distfromq::make_q_fn.
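The three steps can be sketched in base R under simplifying assumptions. This is not the package's implementation: the helper lp_quantiles_sketch() is hypothetical, step 1 swaps in plain linear interpolation via stats::approxfun (with rule = 2, values are held constant beyond the outermost quantiles, so there is no tail extrapolation, unlike distfromq::make_q_fn), and step 2 allocates quasi-random samples to components in proportion to their weights.

```r
# Hypothetical sketch of a quantile linear pool (not the package's code).
# q_levels: quantile levels shared by all components
# q_values: matrix with one column of quantile values per component model
# w: component weights; out_levels: quantile levels to report
lp_quantiles_sketch <- function(q_levels, q_values, w, out_levels,
                                n_samples = 10000) {
  w <- w / sum(w)
  pooled <- unlist(lapply(seq_along(w), function(m) {
    # Step 1: estimate the component quantile function by linear
    # interpolation; rule = 2 holds values constant beyond the endpoints.
    q_fn <- stats::approxfun(q_levels, q_values[, m], rule = 2)
    # Step 2: quasi-random sample = quantiles at evenly spaced levels,
    # with the number of draws proportional to the component weight.
    n_m <- round(n_samples * w[m])
    q_fn((seq_len(n_m) - 0.5) / n_m)
  }))
  # Step 3: pool the samples and extract the requested quantiles.
  unname(stats::quantile(pooled, probs = out_levels))
}

# Usage: three normal components (means -3, 0, 3; sd 1) given as quantiles.
q_levels <- seq(0.01, 0.99, by = 0.01)
mu <- c(-3, 0, 3)
w <- c(0.25, 0.5, 0.25)
q_values <- sapply(mu, function(m) qnorm(q_levels, mean = m))
res <- lp_quantiles_sketch(q_levels, q_values, w, out_levels = c(0.25, 0.5, 0.75))

# Reference value: invert the exact mixture CDF numerically.
true_q25 <- uniroot(function(x) sum(w * pnorm(x, mean = mu)) - 0.25,
                    c(-10, 10))$root
c(sketch = res[1], exact = true_q25)
```

Allocating samples by weight keeps the pooled sample an unweighted draw from the mixture, so step 3 can use an ordinary empirical quantile.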

Examples

# We illustrate the calculation of a linear pool when we have quantiles from the
# component models. We take the components to be normal distributions with
# means -3, 0, and 3, all with standard deviation 1, and weights 0.25, 0.5, and 0.25.
library(hubEnsembles)
library(purrr)
component_ids <- letters[1:3]
component_weights <- c(0.25, 0.5, 0.25)
component_means <- c(-3, 0, 3)

lp_qs <- seq(from = -5, to = 5, by = 0.25) # linear pool quantiles, expected outputs
ps <- rep(0, length(lp_qs))
for (m in seq_len(3)) {
  ps <- ps + component_weights[m] * pnorm(lp_qs, mean = component_means[m])
}

component_qs <- purrr::map(component_means, ~ qnorm(ps, mean = .x)) |> unlist()
component_outputs <- data.frame(
  stringsAsFactors = FALSE,
  model_id = rep(component_ids, each = length(lp_qs)),
  target = "inc death",
  output_type = "quantile",
  output_type_id = ps,
  value = component_qs)

lp_from_component_qs <- linear_pool(
  component_outputs,
  weights = data.frame(model_id = component_ids, weight = component_weights))

head(lp_from_component_qs)
#> # A tibble: 6 × 5
#>   model_id     target    output_type output_type_id value
#>   <chr>        <chr>     <chr>                <dbl> <dbl>
#> 1 hub-ensemble inc death quantile           0.00569 -5.00
#> 2 hub-ensemble inc death quantile           0.0100  -4.75
#> 3 hub-ensemble inc death quantile           0.0167  -4.50
#> 4 hub-ensemble inc death quantile           0.0264  -4.25
#> 5 hub-ensemble inc death quantile           0.0397  -4.00
#> 6 hub-ensemble inc death quantile           0.0567  -3.75
all.equal(lp_from_component_qs$value, lp_qs, tolerance = 1e-3,
          check.attributes = FALSE)
#> [1] TRUE