"""Scikit-learn wrapper for ranger survival."""
import typing as t

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_array
from sklearn.utils.validation import check_is_fitted

from skranger import ranger
from skranger.tree.base import BaseRangerTree

if t.TYPE_CHECKING:  # pragma: no cover
    from skranger.ensemble.survival import RangerForestSurvival
class RangerTreeSurvival(BaseRangerTree, BaseEstimator):
    r"""Ranger Survival implementation for scikit-survival.

    Provides a sksurv interface to the Ranger C++ library using Cython. This
    estimator grows a single survival tree (``num_trees`` is always 1).

    :param bool verbose: Enable ranger's verbose logging
    :param int/callable mtry: The number of features to split on each node. When a
        callable is passed, the function must accept a single parameter which is the
        number of features passed, and return some value between 1 and the number of
        features.
    :param str importance: One of ``none``, ``impurity``, ``impurity_corrected``,
        ``permutation``.
    :param int min_node_size: The minimal node size.
    :param int max_depth: The maximal tree depth; 0 means unlimited.
    :param bool replace: Sample with replacement.
    :param float sample_fraction: The fraction of observations to sample. The default
        is 1 when sampling with replacement, and 0.632 otherwise.
    :param bool keep_inbag: If true, save how often observations are in-bag in each
        tree. These will be stored in the ``ranger_forest_`` attribute under the key
        ``"inbag_counts"``.
    :param list inbag: A list of size ``n_estimators``, containing inbag counts for each
        observation. Can be used for stratified sampling.
    :param str split_rule: One of ``logrank``, ``extratrees``, ``C``, or ``maxstat``,
        default ``logrank``.
    :param int num_random_splits: The number of random splits to consider for the
        ``extratrees`` splitrule.
    :param float alpha: Significance threshold to allow splitting for the ``maxstat``
        split rule.
    :param float minprop: Lower quantile of covariate distribution to be considered for
        splitting for ``maxstat`` split rule.
    :param str respect_categorical_features: One of ``ignore``, ``order``, ``partition``.
        The default is ``partition`` for the ``extratrees`` splitrule, otherwise the
        default is ``ignore``.
    :param bool scale_permutation_importance: For ``permutation`` importance,
        scale permutation importance by standard error as in (Breiman 2001).
    :param bool local_importance: For ``permutation`` importance, calculate and
        return local importance values as in (Breiman 2001).
    :param list regularization_factor: A vector of regularization factors for the
        features.
    :param bool regularization_usedepth: Whether to consider depth in regularization.
    :param bool holdout: Hold-out all samples with case weight 0 and use these for
        feature importance and prediction error.
    :param bool oob_error: Whether to calculate out-of-bag prediction error.
    :param int seed: Random seed value.

    :ivar int n_features_in\_: The number of features (columns) from the fit input ``X``.
    :ivar list feature_names\_: Names for the features of the fit input ``X``; these
        are the byte-encoded column indices (``b"0"``, ``b"1"``, ...).
    :ivar dict ranger_forest\_: The returned result object from calling C++ ranger.
    :ivar int mtry\_: The mtry value as determined if ``mtry`` is callable, otherwise
        it is the same as ``mtry``.
    :ivar float sample_fraction\_: The sample fraction determined by input validation.
    :ivar list regularization_factor\_: The regularization factors determined by input
        validation.
    :ivar list unordered_feature_names\_: The unordered feature names determined by
        input validation.
    :ivar int split_rule\_: The split rule integer corresponding to ranger enum
        ``SplitRule``.
    :ivar bool use_regularization_factor\_: Input validation determined bool for using
        regularization factor input parameter.
    :ivar str respect_categorical_features\_: Input validation determined string
        respecting categorical features.
    :ivar int importance_mode\_: The importance mode integer corresponding to ranger
        enum ``ImportanceMode``.
    :ivar int tree_type\_: The ranger tree type integer (5, ``TREE_SURVIVAL``).
    :ivar ndarray event_times\_: The unique death times reported by ranger
        (``"unique_death_times"``).
    :ivar ndarray cumulative_hazard_function\_: The ``"cumulative_hazard_function"``
        entry of the fitted ranger forest, stored as an object array.
    :ivar ndarray feature_importances\_: The variable importances from ranger.
    """

    def __init__(
        self,
        *,
        verbose=False,
        mtry=0,
        importance="none",
        min_node_size=0,
        max_depth=0,
        replace=True,
        sample_fraction=None,
        keep_inbag=False,
        inbag=None,
        split_rule="logrank",
        num_random_splits=1,
        alpha=0.5,
        minprop=0.1,
        respect_categorical_features=None,
        scale_permutation_importance=False,
        local_importance=False,
        regularization_factor=None,
        regularization_usedepth=False,
        holdout=False,
        oob_error=False,
        seed=42,
    ):
        # scikit-learn convention: __init__ only stores the hyperparameters;
        # all validation happens in ``fit``.
        self.verbose = verbose
        self.mtry = mtry
        self.importance = importance
        self.min_node_size = min_node_size
        self.max_depth = max_depth
        self.replace = replace
        self.sample_fraction = sample_fraction
        self.keep_inbag = keep_inbag
        self.inbag = inbag
        self.split_rule = split_rule
        self.num_random_splits = num_random_splits
        self.alpha = alpha
        self.minprop = minprop
        self.respect_categorical_features = respect_categorical_features
        self.scale_permutation_importance = scale_permutation_importance
        self.local_importance = local_importance
        self.regularization_factor = regularization_factor
        self.regularization_usedepth = regularization_usedepth
        self.holdout = holdout
        self.oob_error = oob_error
        self.seed = seed

    @classmethod
    def from_forest(cls, forest: "RangerForestSurvival", idx: int):
        """Extract a tree from a forest.

        :param RangerForestSurvival forest: A trained RangerForestSurvival instance
        :param int idx: The tree index from the forest to extract.
        :return: A fitted ``RangerTreeSurvival`` holding only tree ``idx``.
        """
        # Even though we have a tree object, we keep the exact same dictionary structure
        # that exists in the forests, so that we can reuse the Cython entrypoints.
        # We also copy over some instance attributes from the trained forest.

        # params
        instance = cls(
            verbose=forest.verbose,
            mtry=forest.mtry,
            importance=forest.importance,
            min_node_size=forest.min_node_size,
            max_depth=forest.max_depth,
            replace=forest.replace,
            sample_fraction=forest.sample_fraction,
            keep_inbag=forest.keep_inbag,
            inbag=forest.inbag,
            split_rule=forest.split_rule,
            num_random_splits=forest.num_random_splits,
            alpha=forest.alpha,
            minprop=forest.minprop,
            respect_categorical_features=forest.respect_categorical_features,
            scale_permutation_importance=forest.scale_permutation_importance,
            local_importance=forest.local_importance,
            regularization_factor=forest.regularization_factor,
            regularization_usedepth=forest.regularization_usedepth,
            holdout=forest.holdout,
            oob_error=forest.oob_error,
            seed=forest.seed,
        )

        # forest
        ranger_forest = {}
        for key, value in forest.ranger_forest_.items():
            if key != "forest":
                ranger_forest[key] = value
                continue
            # Inside "forest", entries stored as a list of lists are treated as
            # per-tree data: keep only the ``idx``-th element (still wrapped in a
            # list). Any other entry is copied by reference.
            ranger_forest[key] = {
                fk: [fv[idx]]
                if isinstance(fv, list) and len(fv) > 0 and isinstance(fv[0], list)
                else fv
                for fk, fv in value.items()
            }
        ranger_forest["num_trees"] = 1
        instance.ranger_forest_ = ranger_forest

        # vars
        instance.n_features_in_ = forest.n_features_in_
        instance.feature_names_ = forest.feature_names_
        instance.sample_fraction_ = forest.sample_fraction_
        instance.mtry_ = forest.mtry_
        instance.regularization_factor_ = forest.regularization_factor_
        instance.split_rule_ = forest.split_rule_
        instance.use_regularization_factor_ = forest.use_regularization_factor_
        instance.respect_categorical_features_ = forest.respect_categorical_features_
        instance.importance_mode_ = forest.importance_mode_
        instance.tree_type_ = forest.tree_type_
        instance.event_times_ = forest.event_times_
        return instance

    def fit(
        self,
        X,
        y,
        sample_weight=None,
        split_select_weights=None,
        always_split_features=None,
        categorical_features=None,
    ):
        """Fit the ranger survival tree using training data.

        :param array2d X: training input features
        :param y: training input targets, rows of (bool, float) representing
            (event status, time) following the scikit-survival convention. Must
            support ``tolist()`` (e.g. a numpy structured array).
        :param array1d sample_weight: optional weights for input samples
        :param list split_select_weights: Vector of weights between 0 and 1 of
            probabilities to select features for splitting. Can be a single vector or a
            vector of vectors with one vector per tree.
        :param list always_split_features: Features which should always be selected for
            splitting. A list of column index values.
        :param list categorical_features: A list of column index values which should be
            considered categorical, or unordered.
        :return: ``self``
        """
        self.tree_type_ = 5  # tree_type, TREE_SURVIVAL

        # Check input
        X = check_array(X)

        # convert 1d array of 2tuples to 2d array
        # ranger expects the time first, and status second
        # since we follow the scikit-survival convention, we fliplr
        yr = np.fliplr(np.array(y.tolist()))

        # Check the init parameters
        self._validate_parameters(X, y, sample_weight)

        # Set X info
        self.feature_names_ = [str(c).encode() for c in range(X.shape[1])]
        self._check_n_features(X, reset=True)

        # Check weights
        sample_weight, use_sample_weight = self._check_sample_weight(sample_weight, X)
        (
            always_split_features,
            use_always_split_features,
        ) = self._check_always_split_features(always_split_features)

        (
            categorical_features,
            use_categorical_features,
        ) = self._check_categorical_features(categorical_features)

        (
            split_select_weights,
            use_split_select_weights,
        ) = self._check_split_select_weights(split_select_weights)

        # Fit the forest
        self.ranger_forest_ = ranger.ranger(
            self.tree_type_,
            np.asfortranarray(X.astype("float64")),
            np.asfortranarray(yr.astype("float64")),
            self.feature_names_,  # variable_names
            self.mtry_,
            1,  # num_trees
            self.verbose,
            self.seed,
            1,  # num_threads
            True,  # write_forest
            self.importance_mode_,
            self.min_node_size,
            split_select_weights,
            use_split_select_weights,
            always_split_features,  # always_split_variable_names
            use_always_split_features,  # use_always_split_variable_names
            False,  # prediction_mode
            {},  # loaded_forest
            self.replace,  # sample_with_replacement
            False,  # probability
            categorical_features,  # unordered_feature_names
            use_categorical_features,  # use_unordered_features
            False,  # save_memory
            self.split_rule_,
            sample_weight,  # case_weights
            use_sample_weight,  # use_case_weights
            {},  # class_weights
            False,  # predict_all
            self.keep_inbag,
            self.sample_fraction_,
            self.alpha,
            self.minprop,
            self.holdout,
            1,  # prediction_type
            self.num_random_splits,
            self.oob_error,
            self.max_depth,
            self.inbag or [],
            bool(self.inbag),  # use_inbag
            self.regularization_factor_,
            # Use the flag derived during parameter validation (as ``_predict``
            # does); hard-coding False here silently ignored regularization_factor.
            self.use_regularization_factor_,  # use_regularization_factor
            self.regularization_usedepth,
        )

        self.event_times_ = np.array(
            self.ranger_forest_["forest"]["unique_death_times"]
        )
        # dtype to suppress warning about ragged nested sequences
        self.cumulative_hazard_function_ = np.array(
            self.ranger_forest_["forest"]["cumulative_hazard_function"], dtype=object
        )

        # When no weights are in effect ``sample_weight`` is empty. Test its
        # length instead of ``!= []``: for an ndarray, ``!= []`` is an elementwise
        # comparison that errors or yields an ambiguous array.
        if len(sample_weight) == 0:
            sample_weight = np.ones(len(X))
        terminal_node_forest = self._get_terminal_node_forest(X)
        terminal_nodes = np.atleast_2d(terminal_node_forest["predictions"]).astype(int)
        self._set_leaf_samples(terminal_nodes)
        self._set_node_values(np.array(y.tolist()), sample_weight)
        self._set_n_classes()
        return self

    def _predict(self, X):
        """Run ranger in prediction mode against the stored tree.

        :param array2d X: prediction input features
        :return: The raw result dict returned by ``ranger.ranger``.
        """
        check_is_fitted(self)
        X = check_array(X)
        self._check_n_features(X, reset=False)

        result = ranger.ranger(
            self.tree_type_,
            np.asfortranarray(X.astype("float64")),
            np.asfortranarray([[]]),
            self.feature_names_,  # variable_names
            self.mtry_,
            1,  # num_trees
            self.verbose,
            self.seed,
            1,  # num_threads
            False,  # write_forest
            self.importance_mode_,
            self.min_node_size,
            [],
            False,  # use_split_select_weights
            [],  # always_split_variable_names
            False,  # use_always_split_variable_names
            True,  # prediction_mode
            self.ranger_forest_["forest"],  # loaded_forest
            self.replace,  # sample_with_replacement
            False,  # probability
            [],  # unordered_feature_names
            False,  # use_unordered_features
            False,  # save_memory
            self.split_rule_,
            [],  # case_weights
            False,  # use_case_weights
            {},  # class_weights
            False,  # predict_all
            self.keep_inbag,
            [1],  # sample_fraction
            self.alpha,
            self.minprop,
            self.holdout,
            1,  # prediction_type
            self.num_random_splits,
            self.oob_error,
            self.max_depth,
            self.inbag or [],
            bool(self.inbag),  # use_inbag
            self.regularization_factor_,
            self.use_regularization_factor_,
            self.regularization_usedepth,
        )
        return result

    def predict_cumulative_hazard_function(self, X):
        """Predict cumulative hazard function.

        :param array2d X: prediction input features
        :return: ranger's ``"predictions"`` as an at-least-2d array (presumably
            one row per sample and one column per entry of ``event_times_``).
        """
        result = self._predict(X)
        return np.atleast_2d(result["predictions"])

    def predict_survival_function(self, X):
        """Predict survival function, computed as ``exp(-CHF)``.

        :param array2d X: prediction input features
        """
        chf = self.predict_cumulative_hazard_function(X)
        return np.exp(-chf)

    def predict(self, X):
        """Predict risk score.

        The score is the cumulative hazard function summed along axis 1.

        :param array2d X: prediction input features
        """
        chf = self.predict_cumulative_hazard_function(X)
        return chf.sum(1)

    def _more_tags(self):
        """Return scikit-learn estimator tags for this estimator."""
        return {
            "requires_y": True,
            "_xfail_checks": {
                "check_sample_weights_invariance": "zero sample_weight is not equivalent to removing samples",
            },
        }