
Compatibility with scikit-learn
New in lifelines version 0.21.3 is a wrapper that allows you to use lifelines’ regression models with scikit-learn’s APIs.
Note
The API and functionality are still experimental. Please report any bugs or feature requests on our GitHub issue list.
from lifelines.utils.sklearn_adapter import sklearn_adapter
from lifelines import CoxPHFitter
from lifelines.datasets import load_rossi
rossi = load_rossi()
X = rossi.drop('week', axis=1)  # keep as a DataFrame; it still contains the 'arrest' event column
Y = rossi.pop('week')           # the durations become the scikit-learn target
CoxRegression = sklearn_adapter(CoxPHFitter, event_col='arrest')
# CoxRegression is a class like the `LinearRegression` class or `SVC` class in scikit-learn
sk_cph = CoxRegression(penalizer=1e-5)
sk_cph.fit(X, Y)
print(sk_cph)
"""
SkLearnCoxPHFitter(alpha=0.05, penalizer=1e-5, strata=None, tie_method='Efron')
"""
sk_cph.predict(X)
sk_cph.score(X, Y)
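Assuming the adapter's default scoring method is Harrell's concordance index, score(X, Y) measures how well the predicted durations rank the observed ones. The standard-library sketch below shows what that number means: among comparable pairs (the subject with the shorter time had an observed event), it is the fraction where the model also predicts the shorter survival for that subject. The helper harrell_c_index is hypothetical and simplified (pairs with tied times are skipped); it is not lifelines' own concordance_index implementation.

```python
from itertools import combinations


def harrell_c_index(times, predicted, events):
    """Simplified Harrell's C: a higher predicted value means longer predicted survival.

    A pair is comparable when the subject with the earlier time had an observed event.
    Tied predictions count as half-concordant; pairs with tied times are skipped.
    """
    concordant = 0.0
    comparable = 0
    for i, j in combinations(range(len(times)), 2):
        if times[i] == times[j]:
            continue  # simplification: ignore tied durations
        # Order the pair so that `a` is the subject with the shorter observed time.
        a, b = (i, j) if times[i] < times[j] else (j, i)
        if not events[a]:
            continue  # `a` was censored, so we do not know who failed first
        comparable += 1
        if predicted[a] < predicted[b]:
            concordant += 1.0
        elif predicted[a] == predicted[b]:
            concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Hypothetical toy data: observed durations and event indicators (1 = event observed).
times = [2, 4, 5, 7, 9]
events = [1, 0, 1, 1, 0]
print(harrell_c_index(times, times, events))            # 1.0: perfect ranking
print(harrell_c_index(times, times[::-1], events))      # 0.0: fully reversed ranking
print(harrell_c_index(times, [5, 5, 5, 5, 5], events))  # 0.5: constant predictions
```

A value around 0.5 corresponds to random ranking, which is a useful reference point when reading cross-validation scores.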
Note
The X variable still needs to be a DataFrame, and should contain the event-occurred column (event_col) if it exists.
If needed, the original lifelines instance is available as the lifelines_model attribute.
sk_cph.lifelines_model.print_summary()
The wrapped classes can even be used in more complex scikit-learn functions (e.g. cross_val_score) and classes (e.g. GridSearchCV):
import numpy as np
from lifelines import WeibullAFTFitter
from sklearn.model_selection import cross_val_score
base_class = sklearn_adapter(WeibullAFTFitter, event_col='arrest')
wf = base_class()
scores = cross_val_score(wf, X, Y, cv=5)
print(scores)
"""
[0.59037328 0.503427 0.55454545 0.59689534 0.62311068]
"""
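With cv=5, cross_val_score fits and scores the estimator five times, each time holding out a different fold of rows. Assuming scikit-learn does not flag the wrapper as a classifier, it splits with an unshuffled KFold, so each test fold is a contiguous block of rows. The sketch below reproduces that splitting rule with the standard library; kfold_indices is a hypothetical helper, not a scikit-learn function.

```python
def kfold_indices(n_samples, n_splits):
    """Yield (train, test) index lists the way an unshuffled k-fold split does.

    The first n_samples % n_splits folds receive one extra sample each.
    """
    if not 2 <= n_splits <= n_samples:
        raise ValueError("need 2 <= n_splits <= n_samples")
    base, extra = divmod(n_samples, n_splits)
    start = 0
    for fold in range(n_splits):
        size = base + (1 if fold < extra else 0)
        test = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_samples))
        yield train, test
        start += size


# The Rossi dataset has 432 rows, so cv=5 gives test folds of 87, 87, 86, 86, 86 rows.
sizes = [len(test) for _, test in kfold_indices(432, 5)]
print(sizes)  # [87, 87, 86, 86, 86]
```

Because the folds are not shuffled, any ordering in the rows of X carries over into the folds; passing a shuffled splitter as cv is one way to avoid that.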
from sklearn.model_selection import GridSearchCV
clf = GridSearchCV(wf, {
    "penalizer": 10.0 ** np.arange(-2, 3),
    "l1_ratio": [0, 1/3, 2/3],
    "model_ancillary": [True, False],
}, cv=4)
clf.fit(X, Y)
print(clf.best_estimator_)
"""
SkLearnWeibullAFTFitter(alpha=0.05, fit_intercept=True,
l1_ratio=0.66666, model_ancillary=True,
penalizer=0.01)
"""
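GridSearchCV evaluates every combination in the parameter grid: 5 penalizer values × 3 l1_ratio values × 2 model_ancillary settings give 30 candidates, each fitted on 4 folds, so 120 fits plus one final refit of the best candidate on all the data (refit=True is the default). For each candidate it works on a fresh, unfitted copy of the estimator with its constructor parameters replaced, which is why the wrapper must expose its constructor arguments through get_params/set_params. The standard-library sketch below mimics that bookkeeping with a hypothetical ToyEstimator and simplified clone/parameter_grid helpers; it illustrates the scikit-learn conventions, not scikit-learn's or lifelines' actual code.

```python
import inspect
from itertools import product


class ToyEstimator:
    """Hypothetical estimator following scikit-learn's constructor conventions."""

    def __init__(self, penalizer=0.0, l1_ratio=0.0, model_ancillary=False):
        # Store each argument unchanged under the same name, as scikit-learn expects.
        self.penalizer = penalizer
        self.l1_ratio = l1_ratio
        self.model_ancillary = model_ancillary

    def get_params(self):
        # Read the constructor arguments back from the matching attributes.
        names = [p for p in inspect.signature(type(self).__init__).parameters if p != "self"]
        return {name: getattr(self, name) for name in names}

    def set_params(self, **params):
        valid = self.get_params()
        for name, value in params.items():
            if name not in valid:
                raise ValueError(f"invalid parameter {name!r}")
            setattr(self, name, value)
        return self


def clone(estimator):
    """Return a new, unfitted estimator with the same constructor parameters."""
    return type(estimator)(**estimator.get_params())


def parameter_grid(grid):
    """Yield one dict per combination, iterating keys in sorted order like ParameterGrid."""
    keys = sorted(grid)
    for values in product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))


grid = {
    "penalizer": [10.0 ** k for k in range(-2, 3)],
    "l1_ratio": [0, 1/3, 2/3],
    "model_ancillary": [True, False],
}
base = ToyEstimator()
candidates = [clone(base).set_params(**params) for params in parameter_grid(grid)]
print(len(candidates), len(candidates) * 4)  # 30 120
```

Cloning before each fit keeps the candidates independent of one another and leaves the original estimator untouched.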
Note
The lifelines.utils.sklearn_adapter() wrapper is currently only designed to work with right-censored data.
Serialization
A note on saving these models: saving can be done with any serialization library, but to load them in a different script or program, you may need to recreate the class first, because the wrapper class is generated at runtime by sklearn_adapter. For example:
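# needed to reload: recreate the wrapper class before loading the saved model
from lifelines.utils.sklearn_adapter import sklearn_adapter
from lifelines import CoxPHFitter
sklearn_adapter(CoxPHFitter, event_col='arrest')  # the returned class does not need to be stored
from joblib import load
model = load(...)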
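Why recreating the class helps: pickle (and joblib, which builds on it) stores a class by reference, as a module name plus a class name, and looks it up again at load time. A class generated at runtime only exists after the code that generates it has run. The standard-library sketch below reproduces the effect with a hypothetical toy_adapter factory that registers its generated class on a module object; it mirrors the idea under that assumption and is not lifelines' actual code.

```python
import pickle
import sys
import types

# Hypothetical module object standing in for wherever the adapter keeps its classes.
registry = types.ModuleType("toy_adapter_registry")
sys.modules[registry.__name__] = registry


def toy_adapter(fitter_name):
    """Hypothetical factory: build a wrapper class at runtime and register it by name."""

    def __init__(self, **params):
        self.params = dict(params)  # plain dict, so pickling needs no other class lookups

    name = "SkLearn" + fitter_name
    klass = type(name, (), {"__init__": __init__, "__module__": registry.__name__})
    setattr(registry, name, klass)  # pickle finds classes via module name + class name
    return klass


Wrapped = toy_adapter("ToyFitter")
payload = pickle.dumps(Wrapped(penalizer=0.1))

# Simulate a fresh program in which the generated class was never created.
delattr(registry, "SkLearnToyFitter")
try:
    pickle.loads(payload)
    reloaded_without_class = True
except (AttributeError, pickle.UnpicklingError):
    reloaded_without_class = False

# Re-running the factory first (the "needed to reload" step) makes loading work again.
toy_adapter("ToyFitter")
model = pickle.loads(payload)
print(reloaded_without_class, model.params)  # False {'penalizer': 0.1}
```

The reloaded object is an instance of the newly generated class, which only needs the same module and class name as the one that was saved.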