Compatibility with scikit-learn

New in lifelines version 0.21.3 is a wrapper that lets you use lifelines’ regression models with scikit-learn’s APIs.


Warning: the API and functionality are still experimental. Please report any bugs or feature requests on our GitHub issue list.

from lifelines.utils.sklearn_adapter import sklearn_adapter

from lifelines import CoxPHFitter
from lifelines.datasets import load_rossi

X = load_rossi().drop('week', axis=1) # keep as a dataframe
Y = load_rossi().pop('week')

CoxRegression = sklearn_adapter(CoxPHFitter, event_col='arrest')
# CoxRegression is a class like the `LinearRegression` class or `SVC` class in scikit-learn

sk_cph = CoxRegression(penalizer=1e-5)
sk_cph.fit(X, Y)

print(sk_cph)
# SkLearnCoxPHFitter(alpha=0.05, penalizer=1e-5, strata=None, tie_method='Efron')

print(sk_cph.score(X, Y))


X must still be a DataFrame, and it should contain the event column (event_col) if one was specified.

If needed, the original lifelines instance is available as the lifelines_model attribute.
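Why X has to carry the event column: a lifelines fitter is fitted on a single DataFrame together with the names of its duration and event columns (for example `CoxPHFitter().fit(df, duration_col='week', event_col='arrest')`), so the wrapper has to put Y back next to X before fitting. The sketch below shows that recombination with plain lists of dicts standing in for DataFrames. `merge_for_fit`, the made-up rows, and the default `'week'` column name are illustrative assumptions, not lifelines internals.

```python
def merge_for_fit(X_rows, y, duration_col="week"):
    """Hypothetical helper: attach the target y to each row of X as the duration column.

    X_rows is a list of dicts (a stand-in for a DataFrame) that already holds the
    event column, e.g. 'arrest'. The input rows are not modified.
    """
    if len(X_rows) != len(y):
        raise ValueError("X and y must have the same number of rows")
    merged = []
    for row, duration in zip(X_rows, y):
        if duration_col in row:
            raise ValueError(f"X already has a {duration_col!r} column")
        # Copy the row and add the duration taken from y.
        merged.append({**row, duration_col: duration})
    return merged


# Made-up example rows, for illustration only.
X_rows = [{"arrest": 1, "age": 27}, {"arrest": 0, "age": 18}]
y = [20, 52]
df_rows = merge_for_fit(X_rows, y)
# A lifelines fitter would then receive these rows (as a DataFrame) together with
# duration_col="week" and event_col="arrest".
```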


The wrapped classes can even be used in more complex scikit-learn functions (e.g. cross_val_score) and classes (e.g. GridSearchCV):

import numpy as np
from lifelines import WeibullAFTFitter
from sklearn.model_selection import cross_val_score

base_class = sklearn_adapter(WeibullAFTFitter, event_col='arrest')
wf = base_class()

scores = cross_val_score(wf, X, Y, cv=5)
print(scores)
# [0.59037328 0.503427   0.55454545 0.59689534 0.62311068]
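Five scores come back because, with an integer cv and an estimator that scikit-learn does not treat as a classifier, cross_val_score splits the rows into cv unshuffled folds (KFold), fits a clone of the estimator on the other folds, and scores it on the held-out one. The stdlib sketch below reproduces only that index split, assuming the wrapper is handled as a regressor; `kfold_indices` is a hypothetical helper, not a scikit-learn function. The Rossi data has 432 rows, so the held-out folds contain 87, 87, 86, 86 and 86 rows.

```python
def kfold_indices(n_samples, n_splits=5):
    """Hypothetical helper: yield (train, test) index lists for unshuffled K-fold CV."""
    # Every fold gets n_samples // n_splits rows; the first (n_samples % n_splits)
    # folds get one extra row each.
    fold_sizes = [n_samples // n_splits] * n_splits
    for i in range(n_samples % n_splits):
        fold_sizes[i] += 1
    start = 0
    for size in fold_sizes:
        test = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_samples))
        yield train, test
        start += size


folds = list(kfold_indices(432, n_splits=5))
test_sizes = [len(test) for _, test in folds]
```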

from sklearn.model_selection import GridSearchCV

clf = GridSearchCV(wf, {
   "penalizer": 10.0 ** np.arange(-2, 3),
   "l1_ratio": [0, 1/3, 2/3],
   "model_ancillary": [True, False],
}, cv=4)
clf.fit(X, Y)

print(clf.best_estimator_)
# SkLearnWeibullAFTFitter(alpha=0.05, fit_intercept=True,
#                         l1_ratio=0.66666, model_ancillary=True,
#                         ...)
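GridSearchCV evaluates every combination in the parameter grid and, with cv=4, fits each combination four times. The stdlib sketch below enumerates the same grid with `itertools.product` to show the size of that search: 5 penalizer values × 3 l1_ratio values × 2 model_ancillary values = 30 candidates, i.e. 120 fits, plus one final refit of the best candidate on all the data when refit is left at its default of True. It is only an illustration of the search size, not how scikit-learn builds its grid internally.

```python
from itertools import product

param_grid = {
    # Same values as 10.0 ** np.arange(-2, 3): 0.01, 0.1, 1, 10, 100
    "penalizer": [10.0 ** k for k in range(-2, 3)],
    "l1_ratio": [0, 1/3, 2/3],
    "model_ancillary": [True, False],
}

# Every combination of the listed values is one candidate model.
candidates = [dict(zip(param_grid, values)) for values in product(*param_grid.values())]

n_folds = 4
n_fits = len(candidates) * n_folds  # fits during the search, before the final refit
```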



Note: lifelines.utils.sklearn_adapter() is currently designed to work only with right-censored data.


A note on saving these models: they can be saved with any serialization library, but to load them in a different script or program you may need to recreate the class first, because sklearn_adapter builds these classes at runtime. For example:

# needed to reload
from lifelines.utils.sklearn_adapter import sklearn_adapter
from lifelines import CoxPHFitter
sklearn_adapter(CoxPHFitter, event_col='arrest')

from joblib import load

model = load(...)
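The class has to be recreated because pickle-based tools (joblib included) store an object's class by module and name, not the class's code. A class built at runtime only exists once its factory has run, so a fresh program must call the factory again before loading. The stdlib sketch below demonstrates this with a hypothetical `toy_adapter` factory and a stand-in module named `toy_adapter_registry`; neither is part of lifelines.

```python
import pickle
import sys
import types

# Hypothetical stand-in module that holds the classes the factory builds.
_registry = types.ModuleType("toy_adapter_registry")
sys.modules[_registry.__name__] = _registry


def toy_adapter(fitter_name):
    """Hypothetical factory: build a wrapper class at runtime."""

    def __init__(self, penalizer=0.0):
        self.penalizer = penalizer

    name = "SkLearn" + fitter_name
    cls = type(name, (), {"__init__": __init__, "__module__": _registry.__name__})
    # pickle looks classes up by module + name, so the class must be reachable there.
    setattr(_registry, name, cls)
    return cls


Model = toy_adapter("CoxPHFitter")
blob = pickle.dumps(Model(penalizer=1e-5))  # what would be written to disk

# Simulate a fresh program in which the factory has not been called yet.
delattr(_registry, "SkLearnCoxPHFitter")
try:
    pickle.loads(blob)
    loaded_without_class = True
except (AttributeError, pickle.UnpicklingError):
    loaded_without_class = False

# Calling the factory again before loading makes the class findable again.
toy_adapter("CoxPHFitter")
restored = pickle.loads(blob)
```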