sklearn 中的决策树
from sklearn.tree import DecisionTreeClassifierclf = DecisionTreeClassifier(criterion='entropy')
其中的 criterion 参数,就是决策树算法,可以选择 entropy,就是基于信息熵; 而 gini,就是基于基尼系数。
https://github.com/zhouwei713/DataAnalyse/tree/master/Decision_tree
— 数据预处理 —
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as matplot
import seaborn as sns
df = pd.read_csv( 'HR.csv', index_col= None)
# 检测是否有缺失数据
df.isnull().any()
>>>
satisfaction_level False
last_evaluation False
number_project False
average_montly_hours False
time_spend_company False
Work_accident False
left False
promotion_last_5years False
sales False
salary False
dtype: bool
数据很完整,没有缺失值
df.head()
satisfaction_level:员工对公司满意度last_evaluation:上一次公司对员工的评价number_project:该员工同时负责多少项目average_montly_hours:每个月工作的时长time_spend_company:在公司工作多久Work_accident:是否有个工作事故left:是否离开公司(1表示离职)promotion_last_5years:过去5年是否又被升职sales:部门salary:工资水平
— 数据分析 —
print(df.shape)>>>(14999, 10)
共有14999条数据,每条数据10个特征。 不过需要把 left 作为标签,实际为9个特征。
turnover_rate = df.left.value_counts() / len(df)print(turnover_rate)>>>0 0.7619171 0.238083Name: left, dtype: float64
离职率为0.24
turnover_Summary = df.groupby('left')print(turnover_Summary.mean())
— 相关性分析 —
corr = df.corr()sns.heatmap(corr, xticklabels=corr.columns.values, yticklabels=corr.columns.values)
两个特征的交叉处,颜色越浅,说明相关性越大。
— 数据集字符串转换成数字 —
print(df.dtypes)>>>satisfaction_level float64last_evaluation float64number_project int64average_montly_hours int64time_spend_company int64Work_accident int64left int64promotion_last_5years int64sales objectsalary objectdtype: object
sales 和 salary 两个特征需要转换
import pandas as pdpd.Series._accessors>>>{'cat', 'dt', 'sparse', 'str'}
可以看到,对于 series 类型数据,有如下四个属性可用。
cat: 用于分类数据(Categorical data)
str: 用于字符数据(String Object data)
dt: 用于时间数据(datetime-like data)
sparse: 用于系数矩阵
test = pd.Series([ 'str1', 'str2', 'str3'])
print(test.str.upper())
print(dir(test.str))
>>>
0 STR1
1 STR2
2 STR3
dtype: object
[... '_get_series_list', '_inferred_dtype', '_is_categorical', '_make_accessor', '_orig', '_parent', '_validate', '_wrap_result', 'capitalize', 'casefold', 'cat', 'center', 'contains', 'count', 'decode', 'encode', 'endswith', 'extract', 'extractall', 'find', 'findall', 'get', 'get_dummies', 'index'...]
df["sales"] = df["sales"].astype('category')df["salary"] = df["salary"].astype('category')然后再使用 cat.codes 来实现对整数的映射df["sales"] = df["sales"].cat.codesdf["salary"] = df["salary"].cat.codes
— 模型训练 —
target_name = 'left'X = df.drop('left', axis=1)y = df[target_name]划分训练集和测试集X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.15, random_state=123, stratify=y)
stratify 参数的作用是在训练集和测试集中,不同标签(离职与非离职)所占比例相同,即原数据集中比例是多少,训练集和测试集中比例也为多少。
from sklearn.metrics import roc_auc_score
clf = DecisionTreeClassifier(
criterion='entropy',
min_weight_fraction_leaf=0.01
)
clf = clf.fit(X_train,y_train)
clf_roc_auc = roc_auc_score(y_test, clf.predict(X_test))
print ( "决策树 AUC = %2.2f" % clf_roc_auc)
>>>
---决策树---
决策树 AUC = 0.93
这里使用了 ROC 和 AUC 来检查分类器的准确率,我们来看看它们的含义。
— ROC —
ROC 曲线有个很好的特性: 当测试集中的正负样本的分布变化的时候,ROC 曲线能够保持不变。 在实际的数据集中经常会出现类不平衡(class imbalance)现象,即负样本比正样本多很多(或者相反),而且测试数据中的正负样本的分布也可能随着时间变化。
from sklearn.metrics import roc_curve
clf_fpr, clf_tpr, clf_thresholds = roc_curve(y_test, clf.predict_proba(X_test)[:, 1])
plt.figure()
# 决策树 ROC
plt.plot(clf_fpr, clf_tpr, label= 'Decision Tree (area = %0.2f)' % clf_roc_auc)
plt.xlim([ 0.0, 1.0])
plt.ylim([ 0.0, 1.05])
plt.xlabel( 'False Positive Rate')
plt.ylabel( 'True Positive Rate')
plt.title( 'ROC Graph')
plt.legend(loc= "lower right")
plt.show()
— 决策树应用 —
importances = clf.feature_importances_
feat_names = df.drop([ 'left'],axis= 1).columns
indices = np.argsort(importances)[:: -1]
plt.figure(figsize=( 12, 6))
plt.title( "Feature importances by Decision Tree")
plt.bar( range( len(indices)), importances[indices], color= 'lightblue', align= "center")
plt.step( range( len(indices)), np.cumsum(importances[indices]), where= 'mid', label= 'Cumulative')
plt.xticks( range( len(indices)), feat_names[indices], rotation= 'vertical',fontsize= 14)
plt.xlim([ -1, len(indices)])
plt.show()