Better interface to explore trees · Issue #75 · ropensci/aorsf · GitHub
Open
@bcjaeger

This is based on a conversation with @emilyriederer that started in #33. The data that aorsf exports are enough to reproduce some steps of growing a tree, but they are not intuitive enough for exploratory hypothesis testing. The example below shows what can be done with the limited tools aorsf currently has, and also how awkward it is to reproduce something inside an aorsf object.

Emily, would you be willing to use the code below as a starting point to test your ideas about feature generation? I'd be grateful for any thoughts you have on sources of friction, and we could plan a few exported functions to address the largest pain points. I've already hit one: results from forests that sample with replacement can't be replicated, because the bootstrap weights are generated in C++ and aren't exported. That's why I set sample_with_replacement = FALSE in the example. There's plenty more that could be improved. =]
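
Here is a minimal base-R sketch of that friction. It assumes the bootstrap weights are per-row draw counts, which is an assumption because the C++ internals aren't shown here. Sampling with replacement can be written as one integer weight per row. Rows with weight 0 are out-of-bag, but in-bag rows may have weights above 1. Knowing only which rows are out-of-bag therefore doesn't recover the weighted sample a tree was grown on. All names below are hypothetical.

```r
set.seed(329)
n_rows <- 10

# draw a bootstrap sample of row indices (with replacement)
boot_rows <- sample(n_rows, size = n_rows, replace = TRUE)

# the same sample expressed as one integer weight per row (assumed format)
boot_wts <- tabulate(boot_rows, nbins = n_rows)

# out-of-bag rows are the ones that were never drawn (weight 0)
oob_rows <- which(boot_wts == 0)

# in-bag rows can carry weights > 1; the out-of-bag rows alone
# can't tell you these counts
boot_wts[boot_wts > 0]
```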

# replicating some of aorsf's internal logic in R
# this works but there are several little things that I would
# not anticipate a user knowing about unless they were very
# familiar with the cpp back-end.

library(aorsf)
library(survival)
library(magrittr)

data_example <- pbc_orsf

# ordered factors are a nuisance for model.matrix() so convert to int
data_example$stage <- as.integer(data_example$stage)
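
# Toy illustration (hypothetical data, not part of the aorsf workflow):
# with R's default contrasts, model.matrix() expands an ordered factor
# into polynomial-contrast columns (.L, .Q, .C), while the integer
# version stays a single numeric column.
toy_stage <- factor(1:4, levels = 1:4, ordered = TRUE)
colnames(model.matrix(~ toy_stage))
# -> "(Intercept)" "toy_stage.L" "toy_stage.Q" "toy_stage.C"
colnames(model.matrix(~ as.integer(toy_stage)))
# -> "(Intercept)" "as.integer(toy_stage)"
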
# drop unused variables here instead of dropping them in the formula.
# otherwise, the column indices from the orsf fit will be off.
data_example$id <- NULL

set.seed(123)

# Need to set sample_with_replacement to false to replicate results.
# if it isn't false, then bootstrap weights are generated inside of
# C++ to simulate bootstrapping, and those weights aren't exported to R,
# so we won't be able to replicate results from R.

fit <- orsf(data_example, time + status ~ ., 
            sample_with_replacement = FALSE, 
            leaf_min_obs = 60)

# you can tell the first tree in this forest has only two leaves by
# printing its list of leaf predictions:
fit$forest$leaf_pred_prob[[1]]
#> [[1]]
#> numeric(0)
#> 
#> [[2]]
#>  [1] 0.9898990 0.9797980 0.9695918 0.9591660 0.9486257 0.9366178 0.9241296
#>  [8] 0.9096901 0.8942716 0.8780121 0.8604519 0.8408961 0.8154144 0.7862925
#> [15] 0.7548408 0.7045181 0.6340663
#> 
#> [[3]]
#>  [1] 0.98666667 0.97333333 0.96000000 0.94666667 0.93333333 0.92000000
#>  [7] 0.90666667 0.89333333 0.88000000 0.86666667 0.85333333 0.84000000
#> [13] 0.82666667 0.81333333 0.80000000 0.78666667 0.77333333 0.76000000
#> [19] 0.74666667 0.73333333 0.72000000 0.70666667 0.69333333 0.68000000
#> [25] 0.66666667 0.65333333 0.64000000 0.62666667 0.61333333 0.60000000
#> [31] 0.57333333 0.55968254 0.54495405 0.53022556 0.51549708 0.50076859
#> [37] 0.48604010 0.47085135 0.45566259 0.43938893 0.42108106 0.40194101
#> [43] 0.38280096 0.36028326 0.33776555 0.31178359 0.28343963 0.25509566
#> [49] 0.22675170 0.19840774 0.17006378 0.14171981 0.10628986 0.07085991
#> [55] 0.03542995
# it shows no predictions for the root node (node 0, the first item in
# the list) and then gives predictions for nodes 1 and 2 (the second and
# third items). A non-empty value in the leaf predictions indicates that
# the given node is a leaf.
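
# Toy illustration (hypothetical values that only mimic the structure
# printed above): since only leaves have non-empty predictions, lengths()
# finds the leaf nodes; subtracting 1 gives 0-based node ids.
toy_leaf_pred <- list(numeric(0), c(0.99, 0.98), c(0.98, 0.97, 0.96))
toy_leaf_ids <- which(lengths(toy_leaf_pred) > 0) - 1
toy_leaf_ids
# -> 1 2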

# generate the data sample for the first tree:
# first, for survival forests, the data need to be sorted by time
# (ascending) and then by status (descending)
data_example_sort <- 
 data_example[order(data_example$time, -data_example$status), ]
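
# Toy illustration (hypothetical data): ordering by time ascending and
# then by -status puts events (status = 1) ahead of censored
# observations (status = 0) that share the same time.
toy_surv <- data.frame(time = c(5, 3, 5, 3), status = c(0, 1, 1, 0))
toy_surv_sort <- toy_surv[order(toy_surv$time, -toy_surv$status), ]
toy_surv_sort$status
# -> 1 0 1 0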

# second, pull the out-of-bag rows for the first tree
# (add 1 to the rows b/c C++ indexing starts at 0 and R starts at 1)
rows_oob <- fit$forest$rows_oobag[[1]] + 1

# this is the data that the first tree in the forest uses
tree_sample <- data_example_sort[-rows_oob, ]
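
# Caution, shown with a hypothetical toy data frame: negative indexing
# with an empty index vector selects zero rows instead of all rows, so
# x[-rows, ] would silently return nothing if a tree's out-of-bag vector
# were ever empty. Indexing with setdiff() avoids that edge case.
toy_idx_df <- data.frame(a = 1:3, b = c("x", "y", "z"))
toy_rows_oob <- integer(0)
nrow(toy_idx_df[-toy_rows_oob, ])
# -> 0
nrow(toy_idx_df[setdiff(seq_len(nrow(toy_idx_df)), toy_rows_oob), ])
# -> 3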

# the data are scaled before growing the tree, so do that here too.
means <- fit$get_means()
sds <- fit$get_stdev()

for(i in names(means)){
 tree_sample[[i]] <- tree_sample[[i]] %>% 
  subtract(means[i]) %>% 
  divide_by(sds[i])
}
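
# Base-R alternative to the magrittr loop above (toy data frame and
# hypothetical means/sds): Map() applies (x - mean) / sd to each column
# that has a stored mean, and other columns (here, c) are left as-is.
toy_scale_df <- data.frame(a = c(1, 2, 3), b = c(10, 20, 30),
                           c = c("x", "y", "z"), stringsAsFactors = FALSE)
toy_means <- c(a = 2, b = 20)
toy_sds <- c(a = 1, b = 10)
toy_cols <- names(toy_means)
toy_scale_df[toy_cols] <- Map(function(x, m, s) (x - m) / s,
                              toy_scale_df[toy_cols],
                              toy_means[toy_cols],
                              toy_sds[toy_cols])
toy_scale_df$a
# -> -1 0 1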

# linear combination coefficients used for the first node of the first tree
node_betas <- fit$forest$coef_values[[1]][[1]]

# column index of the variables (add 1 b/c C++ indexing starts at 0)
coef_indices <- fit$forest$coef_indices[[1]][[1]] + 1

# use column index to get the names of the variables
node_variables <- fit$get_names_x(ref_coded = TRUE)[coef_indices]

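# remove the '_' between variable name and factor level, b/c model.matrix()
# doesn't use a separator (note: str_remove() only removes the first '_')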
node_variables <- stringr::str_remove(node_variables, '_')

# A bit of a hack here. aorsf has its own way of getting an x matrix,
# but it's not exported. However, it's similar to using model.matrix.
xmat <- model.matrix(~ ., data = tree_sample)[, node_variables]

# get the linear combination of predictors
linear_pred <- xmat %*% node_betas
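
# Toy illustration (hypothetical numbers and column names): the node's
# linear combination is a weighted sum of the selected (already scaled)
# columns, so the matrix product matches an explicit per-row sum.
toy_x <- matrix(c(1, 2, 3, 4,
                  0, 1, 0, 1),
                ncol = 2, dimnames = list(NULL, c("x1", "x2")))
toy_beta <- c(0.5, -1)
toy_lp <- drop(toy_x %*% toy_beta)
toy_lp
# -> 0.5 0.0 1.5 1.0
all.equal(toy_lp, toy_x[, "x1"] * 0.5 + toy_x[, "x2"] * -1)
# -> TRUE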

# the cut-point used to split data for first node of first tree
node_cp <- fit$forest$cutpoint[[1]][1]

# count how many go left
sum(linear_pred <= node_cp)
#> [1] 99

# this means 99 observations go to node 1, the left child of 
# the current node, which is node 0:
fit$forest$child_left[[1]][1]
#> [1] 1

# and this is how many go to node 2, the right child of node 0
sum(linear_pred > node_cp)
#> [1] 75
# the right child is always the left child + 1, so the right child is 2.
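
# Toy illustration (hypothetical numbers): route each observation by
# comparing its linear combination to the cut-point, using the rule
# noted above that the right child's index is the left child's + 1.
toy_lp <- c(0.5, 0, 1.5, 1)
toy_cutpoint <- 0.75
toy_child_left <- 1
toy_node <- ifelse(toy_lp <= toy_cutpoint, toy_child_left, toy_child_left + 1)
toy_node
# -> 1 1 2 2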

# find the rows of the right traveling sample
went_right <- which(linear_pred > node_cp)

# get predicted survival probabilities for observations in this leaf node
surv <- survfit(Surv(time, status) ~ 1, data = tree_sample[went_right, ])

# assert equality between our leaf predictions from R 
# versus the leaf predictions stored in the forests
testthat::expect_equal(
 fit$forest$leaf_pred_prob[[1]][[3]],
 unique(surv$surv)
)

Created on 2025-03-11 with reprex v2.1.1
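
The final check compares a leaf's stored predictions with unique(surv$surv). The sketch below uses a hypothetical five-row leaf sample (toy data, not from pbc_orsf) to show why unique() is needed. survfit() reports one Kaplan-Meier estimate per distinct time. At a time with only censored observations, the estimate doesn't change, so the value repeats. The claim that aorsf stores leaf predictions only where the estimate changes is an assumption inferred from the passing expect_equal() check, not something confirmed here. Filtering on n.event > 0 is shown as an alternative that keeps event times explicitly.

```r
library(survival)

# hypothetical leaf sample: 3 events and 2 censored observations
toy_leaf <- data.frame(time = c(1, 2, 2, 3, 4), status = c(1, 1, 0, 1, 0))
toy_km <- survfit(Surv(time, status) ~ 1, data = toy_leaf)

# one estimate per distinct time (1, 2, 3, 4); time 4 has no events
toy_km$surv

# the same product-limit estimate by hand: cumprod(1 - events / at risk)
toy_manual <- cumprod(1 - toy_km$n.event / toy_km$n.risk)

# unique() drops the repeated value at the censoring-only time...
toy_unique <- unique(toy_km$surv)

# ...and so does keeping only times with at least one event
toy_event_only <- toy_km$surv[toy_km$n.event > 0]
```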
