1. Описание данных
2. Исследование зависимостей и формулирование гипотез
2.1. Анализ бинарных показателей
2.2. Анализ непрерывных показателей
2.3. Анализ корреляционной матрицы
3. Построение моделей для прогнозирования оттока
3.1. Открытая модель
3.2. Закрытая модель
3.3. Voting
3.4. Stacking
4. Сравнение качества моделей
5. Пример использования модели для прогнозирования оттока
5.1. Обучение
5.2. Выбор порога принятия решения
5.3. Демонстрация прогнозирования
5.4. Предсказание с помощью модели стэкинга
5.5. Возможные ошибки
Ссылка на исходный датасет (kaggle): Telco Customer Churn
Описание работы:
Любой бизнес хочет максимизировать количество клиентов. Для достижения этой цели важно не только пытаться привлечь новых, но и удерживать уже существующих. Удержать клиента обойдется компании дешевле, чем привлечь нового. Кроме того, новый клиент может оказаться слабо заинтересованным в услугах бизнеса и с ним будет сложно работать, тогда как о старых клиентах уже есть необходимые данные по взаимодействию с сервисом.
Соответственно, прогнозируя отток, мы можем вовремя среагировать и попытаться удержать клиента, который хочет уйти. Опираясь на данные об услугах, которыми пользуется клиент, мы можем сделать ему специальное предложение, пытаясь изменить его решение об уходе от оператора. Благодаря этому задача удержания будет легче в реализации, чем задача привлечения новых пользователей, о которых мы еще ничего не знаем.
Вам предоставлен набор данных от телекоммуникационной компании. В данных содержится информация о почти шести тысячах пользователей, их демографических характеристиках, услугах, которыми они пользуются, длительности пользования услугами оператора, методе оплаты, размере оплаты.
Cтоит задача проанализировать данные и спрогнозировать отток пользователей (выявить людей, которые продлят контракт и которые не продлят). Работа должна включать в себя следующие обязательные пункты:
Во втором разделе обязательно должно присутствовать обоснование гипотез, подробное описание выявленных взаимосвязей, а также их визуализация.
В четвертом дополнительно должны быть сформулированы общие выводы работы.
Codebook
1. Описание данных
2. Исследование зависимостей и формулирование гипотез
3. Построение моделей для прогнозирования оттока
4. Сравнение качества моделей
telecom_users.csv
содержит следующие значения:
customerID
– id клиента
gender
– пол клиента (male/female)
SeniorCitizen
– яляется ли клиент пенсионером (1, 0)
Partner
– состоит ли клиент в браке (Yes, No)
Dependents
– есть ли у клиента иждивенцы (Yes, No)
tenure
– сколько месяцев человек являлся клиентом компании
PhoneService
– подключена ли услуга телефонной связи (Yes, No)
MultipleLines
– подключены ли несколько телефонных линий (Yes, No, No phone service)
InternetService
– интернет-провайдер клиента (DSL, Fiber optic, No)
OnlineSecurity
– подключена ли услуга онлайн-безопасности (Yes, No, No internet service)
OnlineBackup
– подключена ли услуга online backup (Yes, No, No internet service)
DeviceProtection
– есть ли у клиента страховка оборудования (Yes, No, No internet service)
TechSupport
– подключена ли услуга технической поддержки (Yes, No, No internet service)
StreamingTV
– подключена ли услуга стримингового телевидения (Yes, No, No internet service)
StreamingMovies
– подключена ли услуга стримингового кинотеатра (Yes, No, No internet service)
Contract
– тип контракта клиента (Month-to-month, One year, Two year)
PaperlessBilling
– пользуется ли клиент безбумажным биллингом (Yes, No)
PaymentMethod
– метод оплаты (Electronic check, Mailed check, Bank transfer (automatic), Credit card (automatic))
MonthlyCharges
– месячный размер оплаты на настоящий момент
TotalCharges
– общая сумма, которую клиент заплатил за услуги за все время
Churn
– произошел ли отток (Yes or No)
# Configuration
# Cache directory used for sklearn's pipeline caching. Set to `None` to disable caching.
# sklearn.pipeline.Pipeline(memory=CACHE_DIR)
CACHE_DIR = '_cache-ml-telecom-users' # `None` to disable
REMOVE_CACHE = False # remove cache directory after completion
# Utility functions
import re
def split_camel_case(cc_string):
"""
>>> split_camel_case('HTTP2Service')
['HTTP2', 'Service']
>>> split_camel_case('CellRangeA1Z99')
['Cell', 'Range', 'A1', 'Z99']
>>> split_camel_case('customerID')
['customer', 'ID']
"""
return re.split(r'(?<=\d)(?=\D)|(?<=[^A-Z\d])(?=[A-Z\d])|(?<!^)(?=[A-Z][a-z])', cc_string)
def camel_to_snake_case(cc_string):
""" camel_to_snake_case('customerID') -> 'customer_id' """
return '_'.join(split_camel_case(cc_string)).lower()
import itertools
import shutil
import warnings
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
df_raw = pd.read_csv('data/telecom_users.csv', index_col=0, na_values=[' '])
df_raw.dtypes.to_frame('Dtype').join(df_raw.head().T)
# Raw data imported:
Dtype | 1869 | 4528 | 6344 | 6739 | 432 | |
---|---|---|---|---|---|---|
customerID | object | 7010-BRBUU | 9688-YGXVR | 9286-DOJGF | 6994-KERXL | 2181-UAESM |
gender | object | Male | Female | Female | Male | Male |
SeniorCitizen | int64 | 0 | 0 | 1 | 0 | 0 |
Partner | object | Yes | No | Yes | No | No |
Dependents | object | Yes | No | No | No | No |
tenure | int64 | 72 | 44 | 38 | 4 | 2 |
PhoneService | object | Yes | Yes | Yes | Yes | Yes |
MultipleLines | object | Yes | No | Yes | No | No |
InternetService | object | No | Fiber optic | Fiber optic | DSL | DSL |
OnlineSecurity | object | No internet service | No | No | No | Yes |
OnlineBackup | object | No internet service | Yes | No | No | No |
DeviceProtection | object | No internet service | Yes | No | No | Yes |
TechSupport | object | No internet service | No | No | No | No |
StreamingTV | object | No internet service | Yes | No | No | No |
StreamingMovies | object | No internet service | No | No | Yes | No |
Contract | object | Two year | Month-to-month | Month-to-month | Month-to-month | Month-to-month |
PaperlessBilling | object | No | Yes | Yes | Yes | No |
PaymentMethod | object | Credit card (automatic) | Credit card (automatic) | Bank transfer (automatic) | Electronic check | Electronic check |
MonthlyCharges | float64 | 24.1 | 88.15 | 74.95 | 55.9 | 53.45 |
TotalCharges | float64 | 1734.65 | 3973.2 | 2869.85 | 238.5 | 119.5 |
Churn | object | No | No | Yes | No | No |
df_raw.info()
<class 'pandas.core.frame.DataFrame'> Int64Index: 5986 entries, 1869 to 860 Data columns (total 21 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 customerID 5986 non-null object 1 gender 5986 non-null object 2 SeniorCitizen 5986 non-null int64 3 Partner 5986 non-null object 4 Dependents 5986 non-null object 5 tenure 5986 non-null int64 6 PhoneService 5986 non-null object 7 MultipleLines 5986 non-null object 8 InternetService 5986 non-null object 9 OnlineSecurity 5986 non-null object 10 OnlineBackup 5986 non-null object 11 DeviceProtection 5986 non-null object 12 TechSupport 5986 non-null object 13 StreamingTV 5986 non-null object 14 StreamingMovies 5986 non-null object 15 Contract 5986 non-null object 16 PaperlessBilling 5986 non-null object 17 PaymentMethod 5986 non-null object 18 MonthlyCharges 5986 non-null float64 19 TotalCharges 5976 non-null float64 20 Churn 5986 non-null object dtypes: float64(2), int64(2), object(17) memory usage: 1.0+ MB
# Пропущенные значения содержатся только в строках, где tenure = 0
df_raw[df_raw.tenure==0].TotalCharges.isna().sum()
10
Результаты первичного осмотра:
План подготовки данных:
customerID
.pd.get_dummies()
).# Подготовка данных
df = df_raw.copy()
df = df.drop(columns='customerID') # (1)
df = df.fillna(0) # (2)
df = df.sort_index() # (3)
yes_no_columns = [
'Partner',
'Dependents',
'PhoneService',
'MultipleLines',
'OnlineSecurity',
'OnlineBackup',
'DeviceProtection',
'TechSupport',
'StreamingTV',
'StreamingMovies',
'PaperlessBilling',
'Churn',
]
df[yes_no_columns] = df[yes_no_columns] == 'Yes' # (4)
df.SeniorCitizen = df.SeniorCitizen == 1
df = pd.get_dummies(df, dtype=bool) # (5)
df = df.drop(columns=['gender_Female'])
df['InternetService_No'] = ~df['InternetService_No']
column_names_mapping = {
**dict(zip(df_raw.columns, df_raw.columns.map(camel_to_snake_case))),
'gender_Male': 'is_male',
'InternetService_DSL': 'internet_dsl',
'InternetService_Fiber optic': 'internet_fiber',
'InternetService_No': 'internet_service',
'Contract_Month-to-month': 'contract_one_month',
'Contract_One year': 'contract_one_year',
'Contract_Two year': 'contract_two_year',
'PaymentMethod_Bank transfer (automatic)': 'pay_auto_transfer',
'PaymentMethod_Credit card (automatic)': 'pay_auto_credit',
'PaymentMethod_Electronic check': 'pay_check_email',
'PaymentMethod_Mailed check': 'pay_check_mail',
}
df = df.rename(columns=column_names_mapping) # (6)
middle_columns = df.columns.drop(['churn', 'tenure', 'monthly_charges', 'total_charges'])
df = df[['churn', *middle_columns, 'tenure', 'monthly_charges', 'total_charges']] # reorder columns
df.head(10).T
0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 9 | 10 | |
---|---|---|---|---|---|---|---|---|---|---|
churn | False | False | True | False | True | True | False | False | False | False |
senior_citizen | False | False | False | False | False | False | False | False | False | False |
partner | True | False | False | False | False | False | False | False | False | True |
dependents | False | False | False | False | False | False | True | False | True | True |
phone_service | False | True | True | False | True | True | True | False | True | True |
multiple_lines | False | False | False | False | False | True | True | False | False | False |
online_security | False | True | True | True | False | False | False | True | True | True |
online_backup | True | False | True | False | False | False | True | False | True | False |
device_protection | False | True | False | True | False | True | False | False | False | False |
tech_support | False | False | False | True | False | False | False | False | False | False |
streaming_tv | False | False | False | False | False | True | True | False | False | False |
streaming_movies | False | False | False | False | False | True | False | False | False | False |
paperless_billing | True | False | True | False | True | True | True | False | False | True |
is_male | False | True | True | True | False | False | True | False | True | True |
internet_dsl | True | True | True | True | False | False | False | True | True | True |
internet_fiber | False | False | False | False | True | True | True | False | False | False |
internet_service | True | True | True | True | True | True | True | True | True | True |
contract_one_month | True | False | True | False | True | True | True | True | False | True |
contract_one_year | False | True | False | True | False | False | False | False | True | False |
contract_two_year | False | False | False | False | False | False | False | False | False | False |
pay_auto_transfer | False | False | False | True | False | False | False | False | True | False |
pay_auto_credit | False | False | False | False | False | False | True | False | False | False |
pay_check_email | True | False | False | False | True | True | False | False | False | False |
pay_check_mail | False | True | True | False | False | False | False | True | False | True |
tenure | 1 | 34 | 2 | 45 | 2 | 8 | 22 | 10 | 62 | 13 |
monthly_charges | 29.85 | 56.95 | 53.85 | 42.3 | 70.7 | 99.65 | 89.1 | 29.75 | 56.15 | 49.95 |
total_charges | 29.85 | 1889.5 | 108.15 | 1840.75 | 151.65 | 820.5 | 1949.4 | 301.9 | 3487.95 | 587.45 |
# Статистика по бинарным показателям (type=bool)
df_bool = df.loc[:, df.dtypes=='bool']
df_bool_true_rate = df_bool.sum() / df_bool.count()
df_bool_stats = pd.DataFrame({
'true_rate': df_bool_true_rate,
'false_rate': 1 - df_bool_true_rate})
(df_bool_stats * 100).plot.bar(figsize=(14, 5), rot=-30, stacked=True)
plt.axhline(df_bool_true_rate.churn * 100, c='k', lw='1', ls='--', label='Target Level')
plt.xticks(ha='left')
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.box()
plt.ylabel('%')
plt.title('Binary Features Stats')
pass
# Статистика по числовым (continuous) показателям
df.describe(exclude=bool).T
count | mean | std | min | 25% | 50% | 75% | max | |
---|---|---|---|---|---|---|---|---|
tenure | 5986.0 | 32.468760 | 24.516391 | 0.00 | 9.0000 | 29.000 | 56.0 | 72.00 |
monthly_charges | 5986.0 | 64.802213 | 30.114702 | 18.25 | 35.6500 | 70.400 | 89.9 | 118.75 |
total_charges | 5986.0 | 2294.221559 | 2274.164124 | 0.00 | 401.5875 | 1408.575 | 3841.5 | 8684.80 |
# Каждый бинарный показатель образует две группы: позитивную (=True) и негативную (=False).
# Например: contract_two_year и non_contract_two_year, is_male и non_is_male.
subdf_pos = df_bool.drop(columns='churn')
subdf_neg = ~subdf_pos.add_prefix('non_')
# Процент оттока (churn rate) считается по каждой группе, как позитивной так и негативной.
churn_pos = subdf_pos[df.churn].sum() / subdf_pos.sum()
churn_neg = subdf_neg[df.churn].sum() / subdf_neg.sum()
# Добавляем показатель общего оттока, в качестве "отправной точки" для оценки остальных уровней.
churn_rate_total = (df.sum() / df.count())['churn':'churn'].rename({'churn': 'churn_overall'})
churn_rates = pd.concat([churn_rate_total.rename(''), churn_pos, churn_neg])
churn_rates = churn_rates.sort_values()
overall_churn_index = churn_rates.index.tolist().index('churn_overall')
ax = (churn_rates * 100).to_frame().plot.bar(figsize=(15, 5), rot=-45, legend=None)
ax.axhline(churn_rates[overall_churn_index] * 100, c='k', lw=1, ls='--')
ax.patches[overall_churn_index].set_color('r')
plt.xticks(ha='left')
plt.title('Churn Per Group VS Overall Churn')
plt.xlabel('Group')
plt.ylabel('%')
plt.box()
pass
contract_two_year
, non_contract_one_month
, contract_one_year
)non_internet_service
)is_male
, non_is_male
)phone_service
, non_phone_service
)pay_check_email
)contract_one_month
)senior_citizen
)internet_fiber
)Признаки первой и третьей группы являются хорошими индикаторами оттока, а признаки второй группы наоборот – наименее показательны.
total_charges
¶# KDE plots
(df.tenure * df.monthly_charges + 500).plot.kde(color='green', alpha=.5, label=r'tenure * monthly_charges')
df.total_charges.plot.kde(color='red', alpha=.5, label='total_charges')
plt.box()
plt.legend();
# Correlation matrix
np.corrcoef(df.total_charges, df.tenure * df.monthly_charges)
array([[1. , 0.99956435], [0.99956435, 1. ]])
Показатель total_charges
близок к линейной комбинации двух показателей: tenure * monthly_charges
(корреляция $\approx 1$), однако как показывает практика, удаление этого показателя ухудшает качество прогнозирования.
# Коэффициенты асимметрии и эксцесса для числовых показателей
df_cont = df[['tenure', 'monthly_charges', 'total_charges']]
df_cont.agg(['skew', 'kurt'])
tenure | monthly_charges | total_charges | |
---|---|---|---|
skew | 0.233822 | -0.218678 | 0.951272 |
kurt | -1.386832 | -1.259671 | -0.260961 |
Распределение переменных не соответствует нормальному, о чем свидетельствуют коэффициенты асимметрии (skew) и избыточного эксцесса (kurt - 3
) отличные от нуля, а также форма гистограмм на следующем изображении.
# Scatter matrix
pd.plotting.scatter_matrix(df_cont, c=df.churn, figsize=(12, 12),
hist_kwds={'color': 'gray', 'bins': 25}, diagonal='hist', alpha=.3)
pass
# purple: churn=0, yellow: churn=1
# KDE plot
cmap = plt.rcParams['axes.prop_cycle'].by_key()['color']
axes = df_cont[df.churn].plot.kde(subplots=True, sharex=False, figsize=(8, 10), color=cmap[1])
df_cont[~df.churn].plot.kde(subplots=True, sharex=False, figsize=(8, 10), color=cmap[0], ax=axes)
for ax, title in zip(axes, df_cont.columns):
ax.legend(['churn=True', 'churn=False'])
ax.set_title(title)
axes[0].set_xlabel('Months')
axes[1].set_xlabel('Charge Sum')
axes[2].set_xlabel('Charge Sum')
plt.tight_layout()
tenure < 20
: показатель лояльности менее 20 месяцев.70 < monthly_charges < 110
: размер ежемесячной абонентской платы в диапазоне 70-110.25
: недорогие тарифные планы50
: тарифные планы среднего ценового диапазона75, 100
: тарифные планы высокого ценового диапазона# Tenure index frequency
(df.tenure.value_counts().head().rename_axis('tenure').to_frame('freq') / df.index.size).style.format('{:.0%}')
freq | |
---|---|
tenure | |
1 | 9% |
72 | 5% |
2 | 3% |
3 | 3% |
4 | 3% |
# Correlation matrix
df_corr = df.corr(method='spearman')
Коэффициент корреляции Пирсона (Pearson) $r$, используемый по-умолчанию, подразумевает (assumptions):
Поскольку для исследуемых данных эти условия не выполняются, используется коэффициент корреляции Спирмена (Spearman) $\rho$ ("ro"), т.к. он не подразумевает конкретного распределения (distribution-free).
# Plot correlation matrix as a heatmap
fig, ax = plt.subplots(figsize=(20, 20))
mask = np.triu(np.ones_like(df_corr, dtype=bool))
cmap = sns.diverging_palette(230, 20, as_cmap=True)
sns.heatmap(df_corr, cmap=cmap, mask=mask, square=True, cbar_kws={"shrink": .5}, annot=True, fmt='.2f',
center=0, lw=1)
plt.xticks(rotation=30, ha='right')
plt.yticks()
plt.tick_params(left=False)
plt.title('Correlation Matrix', size=20)
pass
Примечательные корреляции:
monthly_charges
↑↑ internet_fiber
(0.80) – высокая абонентская плата у пользователей услуги Интернет через оптоволокно (высокоскоростное подключение по выделенной линии).monthly_charges
↑↓ internet_dsl
(-0.22) – абонентская плата у пользователей услуги Интернет через DSL-модем (обычное подключение через телефонную линию) незначительно ниже средней.phone_service
↑↓ internet_dsl
(-0.45) – значительное число абонентов, пользующихся услугой Интернет через DSL-модем, не пользуются сотовой связью.churn
~ phone_service
(0.01) – отток не имеет корреляции с услугами сотовой связи.churn
~ is_male
(-0.01) – отток не имеет корреляции с полом абонента.churn
↑↑ internet_fiber
(0.30) – повышенный отток среди пользователей услуги Интернет через оптоволокно.churn
↑↑ contract_one_month
(0.40) – повышенный отток среди пользователей с краткосрочным контрактом. pay_check_email
↑↑ contract_one_month
(0.33) – пользователи с краткосрочным контрактом чаще предпочитают получать счет на email.internet_fiber
↑↑ multiple_lines
(0.37) – пользователи высокоскоросного подключения чаще подключают услугу конференц-связи. Услуга конференц-связи может свидетельствовать об отношении абонента к корпоративному сегменту, либо имеют место дорогие all-include тарифы.Гипотеза 1. К оттоку склонны новые клиенты.
Гипотеза 2. К оттоку склонны клиенты, с краткосрочным контрактом (1 месяц).
Гипотеза 3. К оттоку склонны пользователи дорогих тарифных планов (высокая абонентская плата).
Гипотеза 4. К оттоку склонны пользователи услуги Интернет через оптоволокно.
import time
import sklearn
import sklearn.neural_network
import sklearn.discriminant_analysis
import sklearn.ensemble
import sklearn.gaussian_process
import sklearn.linear_model
import sklearn.naive_bayes
import sklearn.neural_network
import sklearn.neighbors
import sklearn.svm
import sklearn.tree
import xgboost
from sklearn.preprocessing import RobustScaler
from sklearn.compose import ColumnTransformer, make_column_selector
from sklearn.pipeline import Pipeline
from sklearn.model_selection import cross_val_score
from sklearn.ensemble import StackingClassifier, VotingClassifier
from sklearn.tree import DecisionTreeClassifier, plot_tree
# Prevent matplotlib annotation overlapping using force-based text position auto-ajustment
from adjustText import adjust_text
White box – открытая модель с прозрачной логикой прогнозирования.
SEED = 42
# Подготавливаем матрицу объектов-признаков (X) и целевую переменную (y) для обучения моделей
df_shaffled = df.sample(frac=1, random_state=SEED)
X = df_shaffled.drop(columns='churn')
y = df_shaffled.churn
dtc = DecisionTreeClassifier(max_depth=4, class_weight='balanced', min_samples_leaf=0.01, random_state=SEED)
dtc.fit(X, y)
DecisionTreeClassifier(class_weight='balanced', max_depth=4, min_samples_leaf=0.01, random_state=42)
# Decision tree diagram
plt.figure(figsize=(60, 10))
plot_tree(dtc, filled=True, fontsize=14, feature_names=X.columns, class_names=['not churn', 'churn']);
(Jupyter Notebook: double click to zoom)
На диаграмме: левая ветвь ← True, правая ветвь → False.
# Feature importances
pd.Series(dtc.feature_importances_, index=X.columns).sort_values().plot.barh(figsize=(15, 10), width=.8, color='g')
plt.title('Feature Importances', size=20)
plt.box()
plt.xticks([])
pass
Простая модель дерева решений наглядно показывает наиболее значимые для прогнозирования оттока признаки:
contract_one_month
(подтверждение гипотезы 2)contract_one_year
, tenure
(подтверждение гипотезы 1 и 2)monthly_charges
, internet_fiber
(подтверждение гипотезы 3 и 4)cross_val_score(dtc, X, y, n_jobs=-1, cv=4, scoring='roc_auc').mean()
# ROC AUC Score:
0.8282613761909807
Модель на базе дерева решений имеет наглядную интерпретацию, но для достижения максимальной точности необходимо использовать более сложные закрытые модели и их комбинации (ансамбли).
Black box – закрытая модель, нацеленная на прогнозирование с максимальной точностью.
Оптимизация гиперпараметров выполнена с помощью визуального поиска по сетке.
classifiers = {
"QDA" : sklearn.discriminant_analysis.QuadraticDiscriminantAnalysis(
reg_param=.0001
),
"AdaBoost" : sklearn.ensemble.AdaBoostClassifier(
learning_rate=.5,
),
"ExtraTrees" : sklearn.ensemble.ExtraTreesClassifier(
min_samples_split=.01,
max_features=10,
bootstrap=True,
class_weight='balanced_subsample',
max_samples=.5,
),
"GradientBoosting" : sklearn.ensemble.GradientBoostingClassifier(
min_samples_leaf=.1,
),
"RandomForest" : sklearn.ensemble.RandomForestClassifier(
max_depth=6,
n_estimators=5000,
max_samples=.1,
min_samples_leaf=2,
min_samples_split=.0001,
),
"LogisticRegression" : sklearn.linear_model.LogisticRegression(
n_jobs=-1,
class_weight='balanced',
),
"SGD" : sklearn.linear_model.SGDClassifier(
loss='log',
penalty='elasticnet',
learning_rate='adaptive',
class_weight='balanced',
l1_ratio=1.0,
alpha=.001,
eta0=0.1,
max_iter=1000,
),
"GaussianNB" : sklearn.naive_bayes.GaussianNB(),
"MLP" : sklearn.neural_network.MLPClassifier(
early_stopping=True,
batch_size=100,
activation='tanh',
learning_rate_init=.01,
),
"KNeighbors" : sklearn.neighbors.KNeighborsClassifier(
n_jobs=-1,
n_neighbors=100,
),
"SVC" : sklearn.svm.SVC(
probability=True,
kernel='linear',
gamma='auto',
class_weight='balanced',
),
"NuSVC" : sklearn.svm.NuSVC(
probability=True,
kernel='linear',
gamma='auto',
class_weight='balanced',
),
"DecisionTree" : sklearn.tree.DecisionTreeClassifier(
min_samples_split=.1,
class_weight='balanced',
ccp_alpha=.0001,
),
"XGB" : xgboost.XGBClassifier(
use_label_encoder=False,
eval_metric='logloss',
n_jobs=-1,
verbosity=0,
max_depth=2,
reg_alpha=2,
learning_rate=0.1,
),
}
Утечка информации (data leakage) из тренировочных в тестовые данные происходит в результате предварительной трансформации всего датасета (например нормализация, до CV/Train/Test split), что приводит к завышенной оценке производительности модели. Чтобы избежать этого необходимо обрабатывать тренировочные и тестовые данные раздельно.
Наладить предобработку входных данных можно с помощью объекта Pipeline
.
# Для каждого классификатора создается конвейер, нормализующий числовые колонки во входных данных.
numeric_columns_selector = make_column_selector(dtype_include=np.number) # boolean dtype is not included
numeric_columns = numeric_columns_selector(X)
numeric_columns_idx = np.flatnonzero(X.columns.isin(numeric_columns))
preprocessor = ColumnTransformer(
remainder='passthrough',
transformers=[('scaler', RobustScaler(), numeric_columns_idx)]
)
# Classifier -> Pipeline
for k, clf in classifiers.items():
if 'random_state' in clf.get_params():
clf.set_params(random_state=SEED)
classifiers[k] = Pipeline([('pre', preprocessor), ('clf', clf)], memory=CACHE_DIR)
def _time_score_table_cv(estimators):
""" Train and evaluate each model. Return results as dataframe.
"""
columns = 'fit_time_ms', 'auc_score', 'clf_name'
table = []
# Warm up cache for fair results
if CACHE_DIR is not None:
list(estimators.values())[0].fit(X, y)
for i, [clf_name, clf] in enumerate(estimators.items()):
print('%d.' % i, clf_name, end=' ')
t1 = time.time_ns()
auc_score = cross_val_score(clf, X, y, n_jobs=-1, cv=4, scoring='roc_auc').mean()
t2 = time.time_ns()
fit_time_ms = (t2 - t1) / 1e6
print('(%.3f s)' % (fit_time_ms/1000))
table.append([fit_time_ms, auc_score, clf_name])
return pd.DataFrame(table, columns=columns)
def _plot_tpt(time_score_table, figsize=(12, 6)):
""" Create plot from dataframe with time-score data.
"""
display(time_score_table.nlargest(1, 'auc_score').rename_axis(columns='BEST').iloc[:, -2:])
fig, ax = plt.subplots(figsize=figsize)
color_cycle = itertools.cycle(mpl.cm.tab20.colors)
for row in time_score_table.itertuples():
ax.scatter(row.fit_time_ms, row.auc_score, 80, label=row.clf_name, color=next(color_cycle), zorder=10)
texts = time_score_table.apply(lambda row: ax.text(*row, zorder=11), axis=1)
plt.xscale('log')
plt.legend(loc='upper left', bbox_to_anchor=(1, 1))
plt.xlabel('Fit Time, ms')
plt.ylabel('ROC AUC Score')
plt.title('Time-Performance Trade-Off', size=20)
plt.grid(zorder=1)
plt.box()
plt.plot([0, 1], [.4, .7], transform=ax.transAxes, c='k', lw=1, ls='--', zorder=2)
adjust_text(texts, time_score_table.fit_time_ms.values, time_score_table.auc_score.values,
expand_points=(1.05, 1.5), expand_align=(1.05, 1.5))
bbox_props = dict(boxstyle="larrow", facecolor='white', lw=2)
ax.text(.2, .3, " Better ", ha="center", va="center", rotation=-80, size=15,
bbox=bbox_props, transform=ax.transAxes, zorder=3)
# Каждая модель обучается и оценивается на одном и том же наборе данных. Для оценки используется кросс-валидация
# с 4-мя сбалансированными (stratified) фолдами. Время обучения и усредненная оценка записывается
# в таблицу (pandas.DataFrame).
df_ts = _time_score_table_cv(classifiers)
H:\Portable\anaconda3\lib\site-packages\sklearn\discriminant_analysis.py:808: UserWarning: Variables are collinear warnings.warn("Variables are collinear")
0. QDA (0.170 s) 1. AdaBoost (1.348 s) 2. ExtraTrees (1.406 s) 3. GradientBoosting (1.768 s) 4. RandomForest (34.969 s) 5. LogisticRegression (0.360 s) 6. SGD (0.314 s) 7. GaussianNB (0.104 s) 8. MLP (3.550 s) 9. KNeighbors (1.783 s) 10. SVC (15.678 s) 11. NuSVC (11.460 s) 12. DecisionTree (0.085 s) 13. XGB (2.323 s)
_plot_tpt(df_ts)
BEST | auc_score | clf_name |
---|---|---|
13 | 0.847193 | XGB |
# Общий зачет классификаторов
scores = df_ts.set_index('clf_name')['auc_score']
# Для голосования и стэкинга отбираем все классификаторы кроме SVC и NuSVC
# т.к. они долго обучаются и имеют низкую точность.
classifiers_selected = classifiers.copy()
del classifiers_selected['SVC']
del classifiers_selected['NuSVC']
clf_vote = VotingClassifier(list(classifiers_selected.items()), voting='soft', n_jobs=-1)
(clf_vote_scores := cross_val_score(clf_vote, X, y, n_jobs=-1, cv=4, scoring='roc_auc'))
array([0.83778338, 0.82786467, 0.87135202, 0.84641185])
# Добавить среднюю оценку голосования в общий зачет
scores['Voting'] = clf_vote_scores.mean()
meta_alg = sklearn.ensemble.RandomForestClassifier(max_depth=6, max_samples=0.1, min_samples_leaf=2,
min_samples_split=0.0001, n_estimators=5000, random_state=SEED)
stack = StackingClassifier(list(classifiers_selected.items()), meta_alg, cv=3, n_jobs=-1)
(stacking_score := cross_val_score(stack, X, y, n_jobs=-1, cv=4, scoring='roc_auc'))
array([0.83752233, 0.83027135, 0.87074464, 0.85008494])
# Добавить среднюю оценку стэкинга в общий зачет
scores['Stacking'] = stacking_score.mean()
# Отсортировать и вывести таблицу с оценками
scores = scores.sort_values()
scores[::-1].to_frame()
auc_score | |
---|---|
clf_name | |
XGB | 0.847193 |
Stacking | 0.847156 |
GradientBoosting | 0.847148 |
Voting | 0.845853 |
RandomForest | 0.845676 |
AdaBoost | 0.844780 |
ExtraTrees | 0.843539 |
LogisticRegression | 0.842712 |
MLP | 0.842373 |
SGD | 0.841930 |
QDA | 0.834124 |
KNeighbors | 0.833526 |
NuSVC | 0.831547 |
GaussianNB | 0.828962 |
DecisionTree | 0.828499 |
SVC | 0.827724 |
# Вывести оценки в виде точечной диаграммы
plt.figure(figsize=(13, 5))
plt.scatter(scores, scores.index, s=200, c=mpl.cm.get_cmap('tab20', len(scores)).colors, zorder=2)
plt.box()
plt.grid(zorder=1)
plt.xlabel('ROC AUC Score')
plt.title('Качество моделей', size=20)
plt.tick_params(axis='y', which='both', left=False, right=False, labelright=True)
pass
Результаты:
XGBoost
классификатора.0.847193
.from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
# Первые 100 сэмплов датасета для демонстрации прогнозирования
X_train, X_test, y_train, y_test = X[100:], X[:100], y[100:], y[:100]
%%time
classifiers_fitted = classifiers.copy()
for clf_name, clf in classifiers_fitted.items():
print('Fitting: %s' % clf_name)
clf.fit(X_train, y_train)
print('Fitting: Voting')
classifiers_fitted['Voting'] = clf_vote.fit(X_train, y_train)
print('Fitting: Stacking')
classifiers_fitted['Stacking'] = stack.fit(X_train, y_train)
Fitting: QDA Fitting: AdaBoost
H:\Portable\anaconda3\lib\site-packages\sklearn\discriminant_analysis.py:808: UserWarning: Variables are collinear warnings.warn("Variables are collinear")
Fitting: ExtraTrees Fitting: GradientBoosting Fitting: RandomForest Fitting: LogisticRegression Fitting: SGD Fitting: GaussianNB Fitting: MLP Fitting: KNeighbors Fitting: SVC Fitting: NuSVC Fitting: DecisionTree Fitting: XGB Fitting: Voting Fitting: Stacking Wall time: 2min 18s
Для оценки производительности моделей была использована метрика ROC AUC, не зависящая от выбора порога принятия решения (threshold), далее "порог".
threshold_default = 0.5
stack_y_pred = classifiers_fitted['Stacking'].predict(X_test)
stack_y_prob = classifiers_fitted['Stacking'].predict_proba(X_test)[:, 1]
stack_y_prob_thresholded = stack_y_prob > threshold_default
sum(stack_y_pred == stack_y_prob_thresholded)
100
Метод predict()
всегда использует пороговое значение 0.5
, что может быть не лучшим выбором в условиях данной задачи.
target = y_test.to_frame()
target['stack_prob']=classifiers_fitted['Stacking'].predict_proba(X_test)[:, 1]
target = target.sort_values('stack_prob')
target['churn_missed'] = target.churn.cumsum()
df_prob = target.join(X_test)
df_prob.style.format('{:.2%}', subset=['stack_prob'])
churn | stack_prob | churn_missed | senior_citizen | partner | dependents | phone_service | multiple_lines | online_security | online_backup | device_protection | tech_support | streaming_tv | streaming_movies | paperless_billing | is_male | internet_dsl | internet_fiber | internet_service | contract_one_month | contract_one_year | contract_two_year | pay_auto_transfer | pay_auto_credit | pay_check_email | pay_check_mail | tenure | monthly_charges | total_charges | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
2680 | False | 0.62% | 0 | False | True | True | True | False | True | True | True | True | True | True | False | True | True | False | True | False | False | True | True | False | False | False | 71 | 85.450000 | 6029.900000 |
1764 | False | 0.74% | 0 | False | True | True | True | True | True | False | True | True | False | True | False | True | True | False | True | False | False | True | False | False | False | True | 63 | 75.700000 | 4676.700000 |
2247 | False | 0.83% | 0 | False | False | False | True | False | False | False | False | False | False | False | False | False | False | False | False | False | False | True | False | True | False | False | 40 | 19.650000 | 830.250000 |
3951 | False | 0.85% | 0 | False | True | False | True | True | True | True | True | True | True | False | False | True | True | False | True | False | False | True | False | False | False | True | 70 | 79.150000 | 5536.500000 |
5690 | False | 0.92% | 0 | False | False | False | False | False | True | False | True | True | True | True | True | True | True | False | True | False | False | True | False | True | False | False | 72 | 61.200000 | 4390.250000 |
4996 | False | 0.97% | 0 | False | False | False | True | True | True | True | False | True | False | True | False | False | True | False | True | False | False | True | False | False | False | True | 47 | 74.050000 | 3496.300000 |
4303 | False | 1.04% | 0 | False | True | False | True | True | False | False | False | False | False | False | False | False | False | False | False | False | False | True | False | False | False | True | 45 | 25.500000 | 1121.050000 |
6769 | False | 1.11% | 0 | False | False | False | True | False | False | False | False | False | False | False | False | True | False | False | False | False | False | True | False | True | False | False | 48 | 19.850000 | 916.000000 |
3576 | False | 1.17% | 0 | False | True | True | True | True | False | False | False | False | False | False | True | False | False | False | False | False | True | False | True | False | False | False | 72 | 23.300000 | 1623.150000 |
1610 | False | 1.23% | 0 | False | True | True | True | False | True | True | False | True | False | False | False | True | True | False | True | False | False | True | False | True | False | False | 51 | 60.500000 | 3121.450000 |
268 | True | 1.24% | 1 | False | False | False | True | False | False | False | False | False | False | False | False | True | False | False | False | False | False | True | False | False | False | True | 59 | 19.350000 | 1099.600000 |
5849 | False | 1.28% | 1 | False | True | False | True | False | False | False | False | False | False | False | True | True | False | False | False | False | False | True | True | False | False | False | 63 | 19.950000 | 1234.800000 |
6376 | False | 1.42% | 1 | False | False | False | True | False | False | False | False | False | False | False | False | True | False | False | False | False | False | True | True | False | False | False | 66 | 19.350000 | 1240.800000 |
640 | False | 1.47% | 1 | False | True | True | True | True | False | False | False | False | False | False | False | True | False | False | False | False | False | True | True | False | False | False | 68 | 25.400000 | 1620.200000 |
5487 | False | 1.60% | 1 | False | False | True | True | False | False | False | False | False | False | False | True | True | False | False | False | False | True | False | False | False | False | True | 33 | 20.150000 | 682.150000 |
4108 | False | 1.68% | 1 | False | True | True | True | False | False | False | False | False | False | False | True | True | False | False | False | False | True | False | True | False | False | False | 33 | 19.450000 | 600.250000 |
3234 | False | 1.73% | 1 | False | True | True | True | False | False | False | False | False | False | False | False | False | False | False | False | False | False | True | False | False | False | True | 24 | 19.700000 | 452.550000 |
7012 | False | 1.89% | 1 | False | True | False | True | True | True | True | False | True | True | True | True | False | True | False | True | False | False | True | False | False | True | False | 62 | 84.950000 | 5150.550000 |
1933 | False | 3.03% | 1 | False | False | False | True | False | False | False | False | False | False | False | True | True | False | False | False | False | True | False | False | False | False | True | 20 | 19.700000 | 415.900000 |
1909 | False | 3.22% | 1 | False | True | True | True | False | False | False | False | False | False | False | False | True | False | False | False | False | True | False | False | False | False | True | 7 | 20.650000 | 150.000000 |
4385 | False | 3.26% | 1 | False | False | False | True | False | False | False | False | False | False | False | True | True | False | False | False | False | True | False | False | False | False | True | 21 | 20.350000 | 422.700000 |
6092 | False | 3.36% | 1 | False | True | False | True | True | True | True | True | True | False | False | True | False | False | True | True | False | False | True | True | False | False | False | 72 | 94.250000 | 6849.750000 |
1906 | False | 3.44% | 1 | False | False | False | False | False | False | True | False | True | False | False | False | True | True | False | True | False | True | False | False | True | False | False | 40 | 36.000000 | 1382.900000 |
2982 | False | 3.57% | 1 | False | True | True | True | False | False | False | False | False | False | False | True | True | False | False | False | False | True | False | False | True | False | False | 52 | 20.850000 | 1071.600000 |
1936 | False | 4.29% | 1 | False | True | False | True | False | True | True | False | True | False | False | True | False | True | False | True | False | True | False | True | False | False | False | 51 | 60.500000 | 3145.150000 |
107 | False | 4.32% | 1 | False | False | False | False | False | True | False | False | False | False | False | False | False | True | False | True | False | True | False | False | False | False | True | 32 | 30.150000 | 927.650000 |
4822 | False | 4.45% | 1 | False | True | True | True | False | True | False | True | True | True | True | True | True | False | True | True | False | False | True | False | True | False | False | 72 | 104.900000 | 7559.550000 |
3966 | False | 4.78% | 1 | False | True | True | True | True | True | True | False | True | True | True | True | True | False | True | True | False | False | True | False | True | False | False | 72 | 107.700000 | 7919.800000 |
213 | False | 5.10% | 1 | False | True | True | True | False | False | False | False | False | False | False | False | False | False | False | False | False | True | False | False | False | True | False | 29 | 20.000000 | 540.050000 |
2538 | False | 5.84% | 1 | False | True | True | False | False | False | False | False | True | True | False | True | True | True | False | True | False | False | True | False | False | True | False | 34 | 40.550000 | 1325.850000 |
1368 | False | 7.26% | 1 | False | False | False | True | False | True | False | True | True | False | False | False | True | False | True | True | False | True | False | False | False | True | False | 54 | 84.400000 | 4484.050000 |
5065 | False | 7.34% | 1 | True | False | False | True | True | False | False | False | False | False | False | True | True | False | False | False | False | True | False | True | False | False | False | 33 | 24.900000 | 847.800000 |
5520 | False | 7.55% | 1 | False | False | False | True | True | True | False | True | False | False | False | True | False | False | True | True | False | True | False | False | True | False | False | 55 | 84.250000 | 4589.850000 |
877 | False | 7.93% | 1 | False | False | True | False | False | True | False | True | True | True | True | False | False | True | False | True | False | True | False | True | False | False | False | 51 | 60.150000 | 3077.000000 |
659 | False | 8.26% | 1 | False | True | False | True | False | False | False | False | False | False | False | False | True | False | False | False | True | False | False | True | False | False | False | 37 | 20.350000 | 697.650000 |
3691 | False | 8.46% | 1 | False | True | False | True | True | False | True | False | False | True | False | False | False | True | False | True | False | True | False | True | False | False | False | 40 | 63.900000 | 2635.000000 |
3094 | False | 8.47% | 1 | False | False | False | True | False | False | True | False | False | True | False | True | False | True | False | True | False | True | False | False | True | False | False | 39 | 58.600000 | 2224.500000 |
5264 | False | 8.47% | 1 | True | True | False | False | False | False | False | True | False | False | False | True | True | True | False | True | False | True | False | False | True | False | False | 69 | 29.800000 | 2134.300000 |
5482 | True | 9.10% | 2 | False | True | True | True | False | False | False | True | True | True | True | True | True | True | False | True | False | True | False | False | False | False | True | 33 | 73.900000 | 2405.050000 |
939 | False | 9.49% | 2 | False | False | False | True | False | False | False | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | True | 15 | 19.900000 | 320.450000 |
189 | False | 9.52% | 2 | False | True | False | True | True | False | True | False | False | False | False | False | False | True | False | True | True | False | False | False | True | False | False | 40 | 56.600000 | 2379.100000 |
2664 | False | 9.72% | 2 | False | True | True | True | True | True | True | False | False | True | False | True | True | False | True | True | False | True | False | False | True | False | False | 62 | 94.950000 | 5791.850000 |
6834 | False | 10.13% | 2 | False | False | False | False | False | False | True | False | False | False | True | True | False | True | False | True | False | False | True | False | False | False | True | 15 | 38.800000 | 603.000000 |
1231 | False | 11.50% | 2 | False | False | False | True | False | True | False | True | False | False | True | False | True | True | False | True | True | False | False | False | True | False | False | 20 | 64.400000 | 1398.600000 |
3361 | False | 11.85% | 2 | True | False | False | True | True | False | True | True | True | False | True | False | False | True | False | True | True | False | False | False | True | False | False | 64 | 74.650000 | 4869.350000 |
481 | True | 14.13% | 3 | False | True | False | False | False | True | False | True | False | False | True | True | False | True | False | True | True | False | False | True | False | False | False | 48 | 45.300000 | 2145.000000 |
102 | False | 15.08% | 3 | False | False | False | True | True | False | False | True | True | True | False | False | True | False | True | True | False | True | False | True | False | False | False | 38 | 95.000000 | 3605.600000 |
569 | False | 15.19% | 3 | False | False | False | False | False | False | True | True | True | False | False | True | False | True | False | True | False | True | False | False | True | False | False | 11 | 40.400000 | 422.600000 |
4046 | False | 15.25% | 3 | False | False | False | True | True | False | False | False | False | False | False | True | False | True | False | True | True | False | False | False | False | False | True | 44 | 50.150000 | 2139.100000 |
6537 | True | 15.31% | 4 | False | True | False | True | True | True | True | True | True | True | True | False | True | False | True | True | False | True | False | False | True | False | False | 70 | 115.650000 | 7968.850000 |
3417 | False | 16.91% | 4 | False | True | True | True | False | False | True | False | True | False | False | False | False | True | False | True | True | False | False | True | False | False | False | 8 | 56.300000 | 401.500000 |
467 | False | 16.95% | 4 | False | True | True | True | True | False | True | False | False | True | False | True | True | False | True | True | False | True | False | False | False | True | False | 72 | 89.700000 | 6588.950000 |
1235 | False | 17.73% | 4 | True | True | False | True | True | False | True | True | True | True | False | True | True | False | True | True | False | True | False | False | False | False | True | 61 | 98.300000 | 6066.550000 |
4352 | False | 19.26% | 4 | False | True | True | True | True | True | True | True | False | False | False | True | True | False | True | True | True | False | False | False | False | True | False | 64 | 91.800000 | 5960.500000 |
280 | False | 19.64% | 4 | False | False | False | True | False | True | True | False | False | False | False | False | False | True | False | True | True | False | False | False | False | False | True | 6 | 55.150000 | 322.900000 |
2437 | False | 19.88% | 4 | False | False | False | True | False | True | True | True | False | False | False | False | False | True | False | True | True | False | False | False | False | True | False | 7 | 61.400000 | 438.900000 |
5243 | False | 19.97% | 4 | False | False | False | True | False | False | False | False | True | False | True | True | True | False | True | True | False | True | False | False | True | False | False | 19 | 87.700000 | 1725.950000 |
4658 | False | 21.30% | 4 | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | True | True | False | False | False | False | False | True | 8 | 44.450000 | 369.300000 |
2810 | False | 21.30% | 4 | True | True | True | True | True | False | False | True | True | True | True | True | False | False | True | True | False | True | False | True | False | False | False | 65 | 103.900000 | 6767.100000 |
1417 | False | 22.80% | 4 | True | False | False | True | False | False | True | False | False | False | False | True | False | True | False | True | True | False | False | False | True | False | False | 18 | 49.850000 | 865.750000 |
4998 | False | 23.13% | 4 | True | False | False | True | True | False | True | True | True | True | True | True | True | False | True | True | False | True | False | True | False | False | False | 62 | 110.750000 | 7053.350000 |
4803 | False | 24.36% | 4 | False | True | False | True | True | True | False | False | False | True | False | False | True | False | True | True | True | False | False | True | False | False | False | 38 | 91.700000 | 3479.050000 |
6824 | True | 24.88% | 5 | True | False | False | True | True | False | True | False | False | False | False | True | False | False | True | True | True | False | False | False | True | False | False | 66 | 80.450000 | 5224.350000 |
1689 | False | 25.18% | 5 | False | True | True | False | False | True | True | True | False | True | True | False | True | True | False | True | True | False | False | False | False | False | True | 7 | 58.850000 | 465.700000 |
5150 | False | 26.37% | 5 | False | True | False | True | False | False | False | False | False | False | False | False | False | False | False | False | True | False | False | True | False | False | False | 2 | 20.100000 | 43.150000 |
2730 | False | 28.88% | 5 | False | False | False | True | True | False | True | False | True | True | True | True | True | False | True | True | False | True | False | True | False | False | False | 49 | 106.650000 | 5168.100000 |
4528 | False | 29.04% | 5 | False | False | False | True | False | False | True | True | False | True | False | True | False | False | True | True | True | False | False | False | True | False | False | 44 | 88.150000 | 3973.200000 |
2285 | True | 29.34% | 6 | False | True | False | True | True | True | False | True | False | False | True | False | True | False | True | True | True | False | False | False | False | True | False | 53 | 93.900000 | 5029.200000 |
4626 | False | 30.09% | 6 | True | False | False | True | True | False | False | True | False | False | True | True | False | False | True | True | True | False | False | True | False | False | False | 63 | 89.600000 | 5538.800000 |
6360 | True | 30.78% | 7 | False | False | False | True | False | False | False | False | False | False | False | True | False | False | False | False | True | False | False | False | False | False | True | 1 | 20.300000 | 20.300000 |
4549 | False | 34.01% | 7 | False | False | True | True | True | False | True | True | True | False | False | True | False | False | True | True | True | False | False | False | True | False | False | 17 | 92.700000 | 1556.850000 |
2847 | False | 34.18% | 7 | False | True | False | True | True | False | True | False | False | False | False | False | False | False | True | True | True | False | False | True | False | False | False | 26 | 79.300000 | 2015.800000 |
2682 | True | 34.36% | 8 | False | False | False | True | True | False | False | False | False | False | False | False | True | True | False | True | True | False | False | False | False | False | True | 4 | 50.400000 | 206.600000 |
388 | False | 35.24% | 8 | False | False | False | True | True | False | False | True | False | False | True | True | True | False | True | True | True | False | False | False | False | False | True | 44 | 90.400000 | 4063.000000 |
2922 | True | 35.51% | 9 | False | True | False | True | False | True | False | False | False | True | True | True | False | False | True | True | True | False | False | False | False | False | True | 28 | 92.350000 | 2602.900000 |
2735 | False | 35.94% | 9 | True | False | False | True | True | False | False | False | False | False | False | True | True | False | True | True | True | False | False | True | False | False | False | 52 | 72.950000 | 3829.750000 |
3868 | True | 36.19% | 10 | False | True | True | True | False | False | True | False | False | False | False | True | False | False | True | True | True | False | False | True | False | False | False | 21 | 74.050000 | 1565.700000 |
2991 | False | 38.83% | 10 | True | True | False | True | False | False | False | True | True | True | False | False | False | False | True | True | True | False | False | False | False | True | False | 37 | 90.000000 | 3371.750000 |
4338 | False | 40.57% | 10 | True | False | False | False | False | False | False | True | True | True | False | False | True | True | False | True | True | False | False | False | True | False | False | 5 | 45.400000 | 214.750000 |
6081 | False | 41.29% | 10 | False | False | False | True | False | True | False | False | False | False | False | True | True | True | False | True | True | False | False | False | True | False | False | 1 | 49.800000 | 49.800000 |
1838 | False | 41.48% | 10 | True | False | False | True | True | False | True | False | False | False | True | True | True | False | True | True | True | False | False | False | False | True | False | 54 | 90.050000 | 4931.800000 |
947 | False | 41.92% | 10 | True | True | True | True | True | False | False | True | False | True | False | True | False | False | True | True | True | False | False | False | False | False | True | 32 | 91.350000 | 2896.550000 |
6542 | True | 45.33% | 11 | False | False | False | True | True | False | False | False | False | False | False | True | True | False | True | True | True | False | False | False | False | False | True | 15 | 76.000000 | 1130.850000 |
951 | False | 48.31% | 11 | False | False | False | True | False | False | True | False | False | False | False | False | False | False | True | True | True | False | False | False | False | True | False | 8 | 75.600000 | 535.550000 |
1675 | True | 50.03% | 12 | False | False | False | True | False | False | False | False | False | True | False | True | False | False | True | True | True | False | False | False | False | True | False | 29 | 78.900000 | 2384.150000 |
506 | False | 53.39% | 12 | False | False | True | True | True | False | False | True | False | False | False | True | True | False | True | True | True | False | False | False | False | False | True | 11 | 78.000000 | 851.800000 |
389 | False | 53.63% | 12 | False | False | True | True | False | False | False | False | False | False | False | False | True | True | False | True | True | False | False | False | False | True | False | 1 | 44.000000 | 44.000000 |
898 | True | 54.50% | 13 | False | False | False | True | False | True | False | False | True | True | True | True | False | False | True | True | True | False | False | True | False | False | False | 12 | 98.900000 | 1120.950000 |
5746 | True | 59.34% | 14 | False | False | False | True | True | False | True | False | False | False | False | True | False | False | True | True | True | False | False | False | True | False | False | 10 | 81.000000 | 818.050000 |
1630 | False | 59.83% | 14 | True | True | False | True | True | False | True | True | False | True | True | True | True | False | True | True | True | False | False | False | False | True | False | 38 | 102.600000 | 4009.200000 |
6533 | False | 60.19% | 14 | True | True | False | True | True | False | True | False | True | True | True | True | True | False | True | True | True | False | False | False | False | True | False | 28 | 105.800000 | 2998.000000 |
4109 | True | 64.71% | 15 | False | False | False | False | False | False | False | False | False | False | False | True | True | True | False | True | True | False | False | False | False | True | False | 2 | 25.050000 | 56.350000 |
6445 | False | 65.11% | 15 | False | False | False | True | True | True | False | False | False | True | False | False | True | False | True | True | True | False | False | False | False | True | False | 4 | 90.650000 | 367.950000 |
1401 | True | 71.18% | 16 | True | False | False | True | False | False | True | True | False | True | True | True | False | False | True | True | True | False | False | False | True | False | False | 4 | 99.800000 | 442.850000 |
3919 | True | 72.52% | 17 | False | False | False | True | True | False | False | False | False | True | True | True | False | False | True | True | True | False | False | True | False | False | False | 6 | 93.550000 | 536.400000 |
639 | True | 74.67% | 18 | False | True | False | True | True | False | False | False | False | False | True | True | True | False | True | True | True | False | False | False | False | True | False | 10 | 85.250000 | 855.300000 |
4077 | False | 78.43% | 18 | True | True | True | True | True | False | False | False | False | True | True | True | False | False | True | True | True | False | False | False | False | True | False | 14 | 95.800000 | 1346.300000 |
6240 | True | 85.49% | 19 | False | False | True | True | True | False | False | False | False | True | True | True | True | False | True | True | True | False | False | False | False | True | False | 1 | 93.300000 | 93.300000 |
2397 | True | 86.24% | 20 | False | False | False | True | False | False | False | False | False | True | True | True | True | False | True | True | True | False | False | False | False | True | False | 1 | 88.350000 | 88.350000 |
6096 | False | 88.12% | 20 | True | False | False | True | False | False | False | False | False | False | False | True | False | False | True | True | True | False | False | False | False | True | False | 1 | 70.200000 | 70.200000 |
def plot_cmatrix(y_true, y_pred, ax=None, cmap='gray', title=None):
"""
Plot proper oriented confusion matrix with actual values in rows and predicted values
in columns. Used as a replacement for `sklearn.metrics.plot_confusion_matrix()` function.
"""
if ax is None:
ax = plt.gca()
disp = ConfusionMatrixDisplay(confusion_matrix(y_true, y_pred))
disp.plot(cmap=cmap, ax=ax)
ax.images[0].colorbar.remove()
if title is not None:
ax.set_title(title)
return ax
fig, [ax1, ax2] = plt.subplots(ncols=2, figsize=(10, 5))
plot_cmatrix(target.churn, target.stack_prob > .5, ax=ax1, title='threshold=50.0%')
plot_cmatrix(target.churn, target.stack_prob > .245, ax=ax2, title='threshold=24.5%')
pass
def get_conclusion(row):
if row.churn and not row.stack_pred:
return 'miss'
if not row.churn and row.stack_pred:
return 'false alarm'
return 'correct'
Прогнозирование с выбранным порогом: 24.5%
.245
и менее..245
.threshold = .245
df_pred = target.churn.to_frame()
df_pred['stack_pred'] = target.stack_prob > threshold
df_pred['conclusion'] = df_pred.apply(get_conclusion, axis='columns')
df_pred.join(X_test).style
churn | stack_pred | conclusion | senior_citizen | partner | dependents | phone_service | multiple_lines | online_security | online_backup | device_protection | tech_support | streaming_tv | streaming_movies | paperless_billing | is_male | internet_dsl | internet_fiber | internet_service | contract_one_month | contract_one_year | contract_two_year | pay_auto_transfer | pay_auto_credit | pay_check_email | pay_check_mail | tenure | monthly_charges | total_charges | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
2680 | False | False | correct | False | True | True | True | False | True | True | True | True | True | True | False | True | True | False | True | False | False | True | True | False | False | False | 71 | 85.450000 | 6029.900000 |
1764 | False | False | correct | False | True | True | True | True | True | False | True | True | False | True | False | True | True | False | True | False | False | True | False | False | False | True | 63 | 75.700000 | 4676.700000 |
2247 | False | False | correct | False | False | False | True | False | False | False | False | False | False | False | False | False | False | False | False | False | False | True | False | True | False | False | 40 | 19.650000 | 830.250000 |
3951 | False | False | correct | False | True | False | True | True | True | True | True | True | True | False | False | True | True | False | True | False | False | True | False | False | False | True | 70 | 79.150000 | 5536.500000 |
5690 | False | False | correct | False | False | False | False | False | True | False | True | True | True | True | True | True | True | False | True | False | False | True | False | True | False | False | 72 | 61.200000 | 4390.250000 |
4996 | False | False | correct | False | False | False | True | True | True | True | False | True | False | True | False | False | True | False | True | False | False | True | False | False | False | True | 47 | 74.050000 | 3496.300000 |
4303 | False | False | correct | False | True | False | True | True | False | False | False | False | False | False | False | False | False | False | False | False | False | True | False | False | False | True | 45 | 25.500000 | 1121.050000 |
6769 | False | False | correct | False | False | False | True | False | False | False | False | False | False | False | False | True | False | False | False | False | False | True | False | True | False | False | 48 | 19.850000 | 916.000000 |
3576 | False | False | correct | False | True | True | True | True | False | False | False | False | False | False | True | False | False | False | False | False | True | False | True | False | False | False | 72 | 23.300000 | 1623.150000 |
1610 | False | False | correct | False | True | True | True | False | True | True | False | True | False | False | False | True | True | False | True | False | False | True | False | True | False | False | 51 | 60.500000 | 3121.450000 |
268 | True | False | miss | False | False | False | True | False | False | False | False | False | False | False | False | True | False | False | False | False | False | True | False | False | False | True | 59 | 19.350000 | 1099.600000 |
5849 | False | False | correct | False | True | False | True | False | False | False | False | False | False | False | True | True | False | False | False | False | False | True | True | False | False | False | 63 | 19.950000 | 1234.800000 |
6376 | False | False | correct | False | False | False | True | False | False | False | False | False | False | False | False | True | False | False | False | False | False | True | True | False | False | False | 66 | 19.350000 | 1240.800000 |
640 | False | False | correct | False | True | True | True | True | False | False | False | False | False | False | False | True | False | False | False | False | False | True | True | False | False | False | 68 | 25.400000 | 1620.200000 |
5487 | False | False | correct | False | False | True | True | False | False | False | False | False | False | False | True | True | False | False | False | False | True | False | False | False | False | True | 33 | 20.150000 | 682.150000 |
4108 | False | False | correct | False | True | True | True | False | False | False | False | False | False | False | True | True | False | False | False | False | True | False | True | False | False | False | 33 | 19.450000 | 600.250000 |
3234 | False | False | correct | False | True | True | True | False | False | False | False | False | False | False | False | False | False | False | False | False | False | True | False | False | False | True | 24 | 19.700000 | 452.550000 |
7012 | False | False | correct | False | True | False | True | True | True | True | False | True | True | True | True | False | True | False | True | False | False | True | False | False | True | False | 62 | 84.950000 | 5150.550000 |
1933 | False | False | correct | False | False | False | True | False | False | False | False | False | False | False | True | True | False | False | False | False | True | False | False | False | False | True | 20 | 19.700000 | 415.900000 |
1909 | False | False | correct | False | True | True | True | False | False | False | False | False | False | False | False | True | False | False | False | False | True | False | False | False | False | True | 7 | 20.650000 | 150.000000 |
4385 | False | False | correct | False | False | False | True | False | False | False | False | False | False | False | True | True | False | False | False | False | True | False | False | False | False | True | 21 | 20.350000 | 422.700000 |
6092 | False | False | correct | False | True | False | True | True | True | True | True | True | False | False | True | False | False | True | True | False | False | True | True | False | False | False | 72 | 94.250000 | 6849.750000 |
1906 | False | False | correct | False | False | False | False | False | False | True | False | True | False | False | False | True | True | False | True | False | True | False | False | True | False | False | 40 | 36.000000 | 1382.900000 |
2982 | False | False | correct | False | True | True | True | False | False | False | False | False | False | False | True | True | False | False | False | False | True | False | False | True | False | False | 52 | 20.850000 | 1071.600000 |
1936 | False | False | correct | False | True | False | True | False | True | True | False | True | False | False | True | False | True | False | True | False | True | False | True | False | False | False | 51 | 60.500000 | 3145.150000 |
107 | False | False | correct | False | False | False | False | False | True | False | False | False | False | False | False | False | True | False | True | False | True | False | False | False | False | True | 32 | 30.150000 | 927.650000 |
4822 | False | False | correct | False | True | True | True | False | True | False | True | True | True | True | True | True | False | True | True | False | False | True | False | True | False | False | 72 | 104.900000 | 7559.550000 |
3966 | False | False | correct | False | True | True | True | True | True | True | False | True | True | True | True | True | False | True | True | False | False | True | False | True | False | False | 72 | 107.700000 | 7919.800000 |
213 | False | False | correct | False | True | True | True | False | False | False | False | False | False | False | False | False | False | False | False | False | True | False | False | False | True | False | 29 | 20.000000 | 540.050000 |
2538 | False | False | correct | False | True | True | False | False | False | False | False | True | True | False | True | True | True | False | True | False | False | True | False | False | True | False | 34 | 40.550000 | 1325.850000 |
1368 | False | False | correct | False | False | False | True | False | True | False | True | True | False | False | False | True | False | True | True | False | True | False | False | False | True | False | 54 | 84.400000 | 4484.050000 |
5065 | False | False | correct | True | False | False | True | True | False | False | False | False | False | False | True | True | False | False | False | False | True | False | True | False | False | False | 33 | 24.900000 | 847.800000 |
5520 | False | False | correct | False | False | False | True | True | True | False | True | False | False | False | True | False | False | True | True | False | True | False | False | True | False | False | 55 | 84.250000 | 4589.850000 |
877 | False | False | correct | False | False | True | False | False | True | False | True | True | True | True | False | False | True | False | True | False | True | False | True | False | False | False | 51 | 60.150000 | 3077.000000 |
659 | False | False | correct | False | True | False | True | False | False | False | False | False | False | False | False | True | False | False | False | True | False | False | True | False | False | False | 37 | 20.350000 | 697.650000 |
3691 | False | False | correct | False | True | False | True | True | False | True | False | False | True | False | False | False | True | False | True | False | True | False | True | False | False | False | 40 | 63.900000 | 2635.000000 |
3094 | False | False | correct | False | False | False | True | False | False | True | False | False | True | False | True | False | True | False | True | False | True | False | False | True | False | False | 39 | 58.600000 | 2224.500000 |
5264 | False | False | correct | True | True | False | False | False | False | False | True | False | False | False | True | True | True | False | True | False | True | False | False | True | False | False | 69 | 29.800000 | 2134.300000 |
5482 | True | False | miss | False | True | True | True | False | False | False | True | True | True | True | True | True | True | False | True | False | True | False | False | False | False | True | 33 | 73.900000 | 2405.050000 |
939 | False | False | correct | False | False | False | True | False | False | False | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | True | 15 | 19.900000 | 320.450000 |
189 | False | False | correct | False | True | False | True | True | False | True | False | False | False | False | False | False | True | False | True | True | False | False | False | True | False | False | 40 | 56.600000 | 2379.100000 |
2664 | False | False | correct | False | True | True | True | True | True | True | False | False | True | False | True | True | False | True | True | False | True | False | False | True | False | False | 62 | 94.950000 | 5791.850000 |
6834 | False | False | correct | False | False | False | False | False | False | True | False | False | False | True | True | False | True | False | True | False | False | True | False | False | False | True | 15 | 38.800000 | 603.000000 |
1231 | False | False | correct | False | False | False | True | False | True | False | True | False | False | True | False | True | True | False | True | True | False | False | False | True | False | False | 20 | 64.400000 | 1398.600000 |
3361 | False | False | correct | True | False | False | True | True | False | True | True | True | False | True | False | False | True | False | True | True | False | False | False | True | False | False | 64 | 74.650000 | 4869.350000 |
481 | True | False | miss | False | True | False | False | False | True | False | True | False | False | True | True | False | True | False | True | True | False | False | True | False | False | False | 48 | 45.300000 | 2145.000000 |
102 | False | False | correct | False | False | False | True | True | False | False | True | True | True | False | False | True | False | True | True | False | True | False | True | False | False | False | 38 | 95.000000 | 3605.600000 |
569 | False | False | correct | False | False | False | False | False | False | True | True | True | False | False | True | False | True | False | True | False | True | False | False | True | False | False | 11 | 40.400000 | 422.600000 |
4046 | False | False | correct | False | False | False | True | True | False | False | False | False | False | False | True | False | True | False | True | True | False | False | False | False | False | True | 44 | 50.150000 | 2139.100000 |
6537 | True | False | miss | False | True | False | True | True | True | True | True | True | True | True | False | True | False | True | True | False | True | False | False | True | False | False | 70 | 115.650000 | 7968.850000 |
3417 | False | False | correct | False | True | True | True | False | False | True | False | True | False | False | False | False | True | False | True | True | False | False | True | False | False | False | 8 | 56.300000 | 401.500000 |
467 | False | False | correct | False | True | True | True | True | False | True | False | False | True | False | True | True | False | True | True | False | True | False | False | False | True | False | 72 | 89.700000 | 6588.950000 |
1235 | False | False | correct | True | True | False | True | True | False | True | True | True | True | False | True | True | False | True | True | False | True | False | False | False | False | True | 61 | 98.300000 | 6066.550000 |
4352 | False | False | correct | False | True | True | True | True | True | True | True | False | False | False | True | True | False | True | True | True | False | False | False | False | True | False | 64 | 91.800000 | 5960.500000 |
280 | False | False | correct | False | False | False | True | False | True | True | False | False | False | False | False | False | True | False | True | True | False | False | False | False | False | True | 6 | 55.150000 | 322.900000 |
2437 | False | False | correct | False | False | False | True | False | True | True | True | False | False | False | False | False | True | False | True | True | False | False | False | False | True | False | 7 | 61.400000 | 438.900000 |
5243 | False | False | correct | False | False | False | True | False | False | False | False | True | False | True | True | True | False | True | True | False | True | False | False | True | False | False | 19 | 87.700000 | 1725.950000 |
4658 | False | False | correct | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | True | True | False | False | False | False | False | True | 8 | 44.450000 | 369.300000 |
2810 | False | False | correct | True | True | True | True | True | False | False | True | True | True | True | True | False | False | True | True | False | True | False | True | False | False | False | 65 | 103.900000 | 6767.100000 |
1417 | False | False | correct | True | False | False | True | False | False | True | False | False | False | False | True | False | True | False | True | True | False | False | False | True | False | False | 18 | 49.850000 | 865.750000 |
4998 | False | False | correct | True | False | False | True | True | False | True | True | True | True | True | True | True | False | True | True | False | True | False | True | False | False | False | 62 | 110.750000 | 7053.350000 |
4803 | False | False | correct | False | True | False | True | True | True | False | False | False | True | False | False | True | False | True | True | True | False | False | True | False | False | False | 38 | 91.700000 | 3479.050000 |
6824 | True | True | correct | True | False | False | True | True | False | True | False | False | False | False | True | False | False | True | True | True | False | False | False | True | False | False | 66 | 80.450000 | 5224.350000 |
1689 | False | True | false alarm | False | True | True | False | False | True | True | True | False | True | True | False | True | True | False | True | True | False | False | False | False | False | True | 7 | 58.850000 | 465.700000 |
5150 | False | True | false alarm | False | True | False | True | False | False | False | False | False | False | False | False | False | False | False | False | True | False | False | True | False | False | False | 2 | 20.100000 | 43.150000 |
2730 | False | True | false alarm | False | False | False | True | True | False | True | False | True | True | True | True | True | False | True | True | False | True | False | True | False | False | False | 49 | 106.650000 | 5168.100000 |
4528 | False | True | false alarm | False | False | False | True | False | False | True | True | False | True | False | True | False | False | True | True | True | False | False | False | True | False | False | 44 | 88.150000 | 3973.200000 |
2285 | True | True | correct | False | True | False | True | True | True | False | True | False | False | True | False | True | False | True | True | True | False | False | False | False | True | False | 53 | 93.900000 | 5029.200000 |
4626 | False | True | false alarm | True | False | False | True | True | False | False | True | False | False | True | True | False | False | True | True | True | False | False | True | False | False | False | 63 | 89.600000 | 5538.800000 |
6360 | True | True | correct | False | False | False | True | False | False | False | False | False | False | False | True | False | False | False | False | True | False | False | False | False | False | True | 1 | 20.300000 | 20.300000 |
4549 | False | True | false alarm | False | False | True | True | True | False | True | True | True | False | False | True | False | False | True | True | True | False | False | False | True | False | False | 17 | 92.700000 | 1556.850000 |
2847 | False | True | false alarm | False | True | False | True | True | False | True | False | False | False | False | False | False | False | True | True | True | False | False | True | False | False | False | 26 | 79.300000 | 2015.800000 |
2682 | True | True | correct | False | False | False | True | True | False | False | False | False | False | False | False | True | True | False | True | True | False | False | False | False | False | True | 4 | 50.400000 | 206.600000 |
388 | False | True | false alarm | False | False | False | True | True | False | False | True | False | False | True | True | True | False | True | True | True | False | False | False | False | False | True | 44 | 90.400000 | 4063.000000 |
2922 | True | True | correct | False | True | False | True | False | True | False | False | False | True | True | True | False | False | True | True | True | False | False | False | False | False | True | 28 | 92.350000 | 2602.900000 |
2735 | False | True | false alarm | True | False | False | True | True | False | False | False | False | False | False | True | True | False | True | True | True | False | False | True | False | False | False | 52 | 72.950000 | 3829.750000 |
3868 | True | True | correct | False | True | True | True | False | False | True | False | False | False | False | True | False | False | True | True | True | False | False | True | False | False | False | 21 | 74.050000 | 1565.700000 |
2991 | False | True | false alarm | True | True | False | True | False | False | False | True | True | True | False | False | False | False | True | True | True | False | False | False | False | True | False | 37 | 90.000000 | 3371.750000 |
4338 | False | True | false alarm | True | False | False | False | False | False | False | True | True | True | False | False | True | True | False | True | True | False | False | False | True | False | False | 5 | 45.400000 | 214.750000 |
6081 | False | True | false alarm | False | False | False | True | False | True | False | False | False | False | False | True | True | True | False | True | True | False | False | False | True | False | False | 1 | 49.800000 | 49.800000 |
1838 | False | True | false alarm | True | False | False | True | True | False | True | False | False | False | True | True | True | False | True | True | True | False | False | False | False | True | False | 54 | 90.050000 | 4931.800000 |
947 | False | True | false alarm | True | True | True | True | True | False | False | True | False | True | False | True | False | False | True | True | True | False | False | False | False | False | True | 32 | 91.350000 | 2896.550000 |
6542 | True | True | correct | False | False | False | True | True | False | False | False | False | False | False | True | True | False | True | True | True | False | False | False | False | False | True | 15 | 76.000000 | 1130.850000 |
951 | False | True | false alarm | False | False | False | True | False | False | True | False | False | False | False | False | False | False | True | True | True | False | False | False | False | True | False | 8 | 75.600000 | 535.550000 |
1675 | True | True | correct | False | False | False | True | False | False | False | False | False | True | False | True | False | False | True | True | True | False | False | False | False | True | False | 29 | 78.900000 | 2384.150000 |
506 | False | True | false alarm | False | False | True | True | True | False | False | True | False | False | False | True | True | False | True | True | True | False | False | False | False | False | True | 11 | 78.000000 | 851.800000 |
389 | False | True | false alarm | False | False | True | True | False | False | False | False | False | False | False | False | True | True | False | True | True | False | False | False | False | True | False | 1 | 44.000000 | 44.000000 |
898 | True | True | correct | False | False | False | True | False | True | False | False | True | True | True | True | False | False | True | True | True | False | False | True | False | False | False | 12 | 98.900000 | 1120.950000 |
5746 | True | True | correct | False | False | False | True | True | False | True | False | False | False | False | True | False | False | True | True | True | False | False | False | True | False | False | 10 | 81.000000 | 818.050000 |
1630 | False | True | false alarm | True | True | False | True | True | False | True | True | False | True | True | True | True | False | True | True | True | False | False | False | False | True | False | 38 | 102.600000 | 4009.200000 |
6533 | False | True | false alarm | True | True | False | True | True | False | True | False | True | True | True | True | True | False | True | True | True | False | False | False | False | True | False | 28 | 105.800000 | 2998.000000 |
4109 | True | True | correct | False | False | False | False | False | False | False | False | False | False | False | True | True | True | False | True | True | False | False | False | False | True | False | 2 | 25.050000 | 56.350000 |
6445 | False | True | false alarm | False | False | False | True | True | True | False | False | False | True | False | False | True | False | True | True | True | False | False | False | False | True | False | 4 | 90.650000 | 367.950000 |
1401 | True | True | correct | True | False | False | True | False | False | True | True | False | True | True | True | False | False | True | True | True | False | False | False | True | False | False | 4 | 99.800000 | 442.850000 |
3919 | True | True | correct | False | False | False | True | True | False | False | False | False | True | True | True | False | False | True | True | True | False | False | True | False | False | False | 6 | 93.550000 | 536.400000 |
639 | True | True | correct | False | True | False | True | True | False | False | False | False | False | True | True | True | False | True | True | True | False | False | False | False | True | False | 10 | 85.250000 | 855.300000 |
4077 | False | True | false alarm | True | True | True | True | True | False | False | False | False | True | True | True | False | False | True | True | True | False | False | False | False | True | False | 14 | 95.800000 | 1346.300000 |
6240 | True | True | correct | False | False | True | True | True | False | False | False | False | True | True | True | True | False | True | True | True | False | False | False | False | True | False | 1 | 93.300000 | 93.300000 |
2397 | True | True | correct | False | False | False | True | False | False | False | False | False | True | True | True | True | False | True | True | True | False | False | False | False | True | False | 1 | 88.350000 | 88.350000 |
6096 | False | True | false alarm | True | False | False | True | False | False | False | False | False | False | False | True | False | False | True | True | True | False | False | False | False | True | False | 1 | 70.200000 | 70.200000 |
Возможные исходы классификации (колонка conclusion):
correct
– верный прогнозfalse alarm
– ложное срабатывание (type I error)miss
– пропуск (type II error)clf = classifiers_fitted['Stacking']
sample = [1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0,
0, 66, 80.45, 5224.35]
clf.predict([sample])
array([False])
clf.predict_proba([sample])[:, 1]
array([0.2487913])
Ошибка 1
NotFittedError: This AdaBoostClassifier instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.
Перед использованием методов predict()
(или predict_proba()
) необходимо обучить классификатор с помощью метода fit()
.
Ошибка 2
ValueError: Expected 2D array, got 1D array instead:
Единственный сэмпл должен быть обернут в массив из 1-го элемента.
Неправильно:
clf.predict([1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0,
0, 66, 80.45, 5224.35])
Правильно:
clf.predict([[1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0,
0, 66, 80.45, 5224.35]])
Ошибка 3
ValueError: y should be a 1d array, got an array of shape (100, 2) instead.
Метод predict_proba()
в отличии от predict()
возвращает две колонки. Для бинарной классификации: первая колонка – вероятность class=False, вторая колонка – вероятность class=True.
Неправильно:
y_score = clf.predict_proba(X_test)
roc_auc_score(y_test, y_score)
Правильно:
y_score = clf.predict_proba(X_test)[:, 1]
roc_auc_score(y_test, y_score)
Ошибка 4
ValueError: Classification metrics can't handle a mix of binary and continuous targets
Метрика recall_score
работает с предсказаниями в виде бинарных значений (False/True, 0/1).
Метрика roc_auc_score
работает с предсказаниями в виде вероятностей ($x \in [0.0, 1.0]$).
Неправильно:
y_pred = clf.predict_proba(X_test)[:, 1]
recall_score(y_test, y_pred)
Правильно:
y_pred = clf.predict(X_test)
recall_score(y_test, y_pred)
# Автоматически удалить папку с кэшем (если REMOVE_CACHE=True)
if REMOVE_CACHE:
shutil.rmtree(CACHE_DIR)