RMST and survival forests


This is a follow-up to my previous post on integrating survival curves and computing restricted mean survival times (RMST), which concluded with:

[…] it can be replicated with every model that yields predictions for the survival function \(S(t)\).

This made me think: can we actually compute the difference in RMST after fitting a random survival forest? Well, it turns out that yes, we can!
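For reference, the quantity we are after is the area under the survival curve up to a time horizon \(\tau\) (five years, i.e. 365 × 5 days, in this post), and the treatment effect is the difference between the two arms:

\[
\mathrm{RMST}(\tau) = \int_0^{\tau} S(t) \, dt, \qquad
\Delta\mathrm{RMST}(\tau) = \int_0^{\tau} S_1(t) \, dt - \int_0^{\tau} S_0(t) \, dt,
\]

where \(S_1(t)\) and \(S_0(t)\) denote the survival functions of the treated (hormon = 1) and untreated (hormon = 0) groups. Any model that returns \(S(t)\) on a grid of times can therefore be plugged in, as long as we have a way to integrate the predicted curve numerically.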

Let’s use the same dataset from the German breast cancer study:

```r
library(haven)
brcancer <- read_dta("https://www.stata-press.com/data/r16/brcancer.dta")
```

We will be using the randomForestSRC package to fit a random survival forest. There are other options (such as the cforest function from the party package), but we’ll stick with randomForestSRC for simplicity.

```r
library(randomForestSRC)
```

The model we fit is the same model as before, with just treatment as a binary covariate:

```r
set.seed(295682735)

fit <- rfsrc(Surv(rectime, censrec) ~ hormon, data = brcancer)
```

Again, for simplicity, we'll use the default arguments of rfsrc (e.g. for the number of trees to grow); we also set a seed for reproducibility. Let's print the model fit:

```r
fit
##                          Sample size: 686
##                     Number of deaths: 299
##                      Number of trees: 1000
##            Forest terminal node size: 15
##        Average no. of terminal nodes: 2
## No. of variables tried at each split: 1
##               Total no. of variables: 1
##        Resampling used to grow trees: swor
##     Resample size used to grow trees: 434
##                             Analysis: RSF
##                               Family: surv
##                       Splitting rule: logrank *random*
##        Number of random split points: 10
##                           Error rate: 67.66%
```

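We don't really care about model accuracy here (for survival forests, randomForestSRC reports the error rate as one minus Harrell's concordance index), as all of this is just a proof of concept.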

randomForestSRC provides a predict method, which we can use to obtain predictions for a test dataset of two individuals (one treated, one untreated):

```r
data_grid <- expand.grid(
  rectime = 1,
  censrec = 0,
  hormon = unique(brcancer$hormon)
)
fit_prediction <- predict(fit, newdata = data_grid)
```

We need to provide values for rectime and censrec, even though they are not used by predict; by default, predictions are computed at each distinct observed event time.

Interestingly, a variety of predictions are returned (e.g. survival probability, cumulative hazard, etc.). We can easily extract the survival predictions, a matrix with one row per individual in the test dataset and one column per distinct observed event time:

```r
survival_prediction <- fit_prediction$survival

class(survival_prediction)
## [1] "matrix" "array"
dim(survival_prediction)
## [1]   2 270
```
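Each row of this matrix is a survival curve evaluated only at the event times stored in fit_prediction$time.interest, so reading off \(S(t)\) at an arbitrary time means treating each row as a step function. The base R sketch below does that with findInterval; the helper survival_at is hypothetical (not part of randomForestSRC), and event_times and surv_matrix are made-up stand-ins for the real predictions. It assumes \(S(t) = 1\) before the first event time and a right-continuous step in between event times.

```r
# Hypothetical toy inputs standing in for fit_prediction$time.interest
# and fit_prediction$survival (rows = individuals, columns = event times)
event_times <- c(72, 98, 113)
surv_matrix <- rbind(
  c(0.99, 0.97, 0.95),  # individual 1
  c(1.00, 0.98, 0.96)   # individual 2
)

# Evaluate the step-function survival S(t) for every individual at times t.
# Assumption: S(t) = 1 before the first event time; otherwise S(t) takes the
# value at the last event time <= t (right-continuous step function).
survival_at <- function(times, surv, t) {
  idx <- findInterval(t, times)   # number of event times <= each t
  padded <- cbind(1, surv)        # extra first column: S(t) = 1 before any event
  padded[, idx + 1, drop = FALSE] # one row per individual, one column per t
}

survival_at(event_times, surv_matrix, c(50, 100, 500))
```

With the real fit, survival_at(fit_prediction$time.interest, survival_prediction, 365 * c(1, 3, 5)) would return the predicted 1-, 3- and 5-year survival for both individuals.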

We then reshape the predictions into a tidy dataset:

```r
# First reshape...
tidy_survival_prediction <- data.frame(fit_prediction$time.interest, t(survival_prediction))
names(tidy_survival_prediction) <- c("rectime", data_grid$hormon)
head(tidy_survival_prediction)
##   rectime         0         1
## 1      72 0.9975730 1.0000000
## 2      98 0.9952686 1.0000000
## 3     113 0.9929122 1.0000000
## 4     120 0.9905855 1.0000000
## 5     160 0.9881725 1.0000000
## 6     169 0.9881725 0.9958175
```
```r
# Second reshape...
library(tidyr)
tidy_survival_prediction <- pivot_longer(
  tidy_survival_prediction,
  cols = 2:3,
  names_to = "hormon",
  values_to = "S_hat_RF"
)
head(tidy_survival_prediction)
## # A tibble: 6 x 3
##   rectime hormon S_hat_RF
##     <dbl> <chr>     <dbl>
## 1      72 0         0.998
## 2      72 1         1    
## 3      98 0         0.995
## 4      98 1         1    
## 5     113 0         0.993
## 6     113 1         1
```
```r
# Convert hormon back to numeric (pivot_longer stored the column names as character),
# then add a descriptive arm label for pretty plotting:
tidy_survival_prediction$hormon <- as.numeric(tidy_survival_prediction$hormon)
tidy_survival_prediction$hormon2 <- ifelse(tidy_survival_prediction$hormon == 0, "Control arm", "Treatment arm")
```

We can finally plot the fitted survival curves from the random survival forest model:

```r
library(ggplot2)
ggplot(tidy_survival_prediction, aes(x = rectime, y = S_hat_RF, linetype = hormon2)) +
  geom_line() +
  coord_cartesian(ylim = c(0, 1)) +
  scale_x_continuous(breaks = 365 * seq(6)) +
  theme(legend.position = c(0, 0), legend.justification = c(0, 0)) +
  labs(x = "Follow-up time (days)", y = "Fitted survival", linetype = "")
```

Not too bad!

Let’s now compute the fitted survival curves from the flexible parametric model, as a comparison:

```r
# Fit...
library(rstpm2)
fit_rstpm2 <- stpm2(Surv(rectime, censrec) ~ hormon, data = brcancer, df = 5)

# Predict...
tidy_survival_prediction$S_hat_FPM <- predict(
  fit_rstpm2,
  type = "surv",
  newdata = tidy_survival_prediction
)

# ...and plot!
ggplot(tidy_survival_prediction, aes(x = rectime, linetype = hormon2)) +
  geom_line(aes(y = S_hat_RF, color = "Random forest")) +
  geom_line(aes(y = S_hat_FPM, color = "FPM")) +
  coord_cartesian(ylim = c(0, 1)) +
  scale_x_continuous(breaks = 365 * seq(6)) +
  scale_color_manual(values = c("red", "black")) +
  theme(legend.position = c(0, 0), legend.justification = c(0, 0)) +
  labs(x = "Follow-up time (days)", y = "Fitted survival", linetype = "", color = "")
```

The two models seem to agree fairly well.
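To put a number on "fairly well", one option is to compare the two fitted curves on their shared time grid, for example via the largest and the average absolute difference in predicted survival. The helper curve_gap below is a hypothetical sketch, and the survival probabilities in the toy call are made up.

```r
# Summarise how far apart two fitted survival curves are, assuming both are
# evaluated on the same time grid (same length, same order).
curve_gap <- function(s_a, s_b) {
  stopifnot(length(s_a) == length(s_b))
  gap <- abs(s_a - s_b)
  c(max_abs_diff = max(gap), mean_abs_diff = mean(gap))
}

# Toy example with made-up survival probabilities on a common grid
curve_gap(c(1.00, 0.95, 0.90), c(0.99, 0.96, 0.85))
```

On the real predictions this would be curve_gap(tidy_survival_prediction$S_hat_RF, tidy_survival_prediction$S_hat_FPM), possibly computed separately within each arm.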

Let’s finally calculate the difference in RMST using the predictions from the random forest model:

```r
idx_0 <- which(tidy_survival_prediction$hormon == 0 & tidy_survival_prediction$rectime <= 365 * 5)
df_0 <- tidy_survival_prediction[idx_0, ]
int_spline_0 <- splinefun(
  x = df_0$rectime,
  y = df_0$S_hat_RF,
  method = "natural"
)

idx_1 <- which(tidy_survival_prediction$hormon == 1 & tidy_survival_prediction$rectime <= 365 * 5)
df_1 <- tidy_survival_prediction[idx_1, ]
int_spline_1 <- splinefun(
  x = df_1$rectime,
  y = df_1$S_hat_RF,
  method = "natural"
)
```

Here we use the whole fitted survival curve (every predicted time point up to five years) to fit the spline interpolation. The RMST can finally be calculated as:

```r
RMST_0 <- integrate(f = int_spline_0, lower = 0, upper = 365 * 5)$value
RMST_0
## [1] 1260.773
RMST_1 <- integrate(f = int_spline_1, lower = 0, upper = 365 * 5)$value
RMST_1
## [1] 1413.315
```

The difference in RMST is therefore RMST_1 - RMST_0 = 1413.315 - 1260.773 ≈ 153 days, or equivalently about 0.418 years. Remember: using a flexible parametric model, the estimated difference in RMST was 140 days (0.382 years).
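As a cross-check on the spline interpolation: the forest's survival predictions are step functions, so the area under them can also be computed exactly as a sum of rectangles. The helper rmst_step below is a hypothetical sketch, not a randomForestSRC function; it assumes sorted time points, \(S(t) = 1\) before the first one, and a constant value between consecutive time points.

```r
# Area under a right-continuous step survival function from 0 to tau.
# Assumes `times` is sorted, `surv` holds S(t) at those times, and S(t) = 1
# before the first time point.
rmst_step <- function(times, surv, tau) {
  keep <- times <= tau
  t_knots <- c(0, times[keep], tau)  # boundaries of the constant pieces
  s_vals <- c(1, surv[keep])         # S(t) on each piece [t_k, t_{k+1})
  sum(diff(t_knots) * s_vals)        # sum of rectangle areas
}

# Toy check: S = 1 on [0, 1), 0.5 on [1, 2), 0.25 on [2, 3)
rmst_step(times = c(1, 2), surv = c(0.5, 0.25), tau = 3)
```

Applied to the data frames built above, rmst_step(df_1$rectime, df_1$S_hat_RF, tau = 365 * 5) - rmst_step(df_0$rectime, df_0$S_hat_RF, tau = 365 * 5) gives a step-function version of the difference in RMST. It need not match the spline-based estimate exactly, since the natural spline smooths over the steps and extrapolates before the first event time.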

In conclusion, yes, it is possible to calculate RMST after fitting a random survival forest model. The example that was described above is kind of silly (I mean, we included a single binary covariate in a random forest model), but still kind of useful to illustrate how to do that in R. I guess I satisfied my curiosity — for now…