持续创作,加速成长!这是我参与「掘金日新计划 · 6 月更文挑战」的第3天, 点击查看活动详情
1 机器学习的“Hello World”
在机器学习内有一个被广泛使用和学习数据集是- MNIST数据集 ,它是一组由美国高中生和人口调查局员工手写的70000个数字的图片集合。也被称为机器学习领域的“Hello World”。该数据集也被常常用于各种分类算法的检测之中。
Scikit-Learn有一个sklearn.datasets模块,其中有各种工具类帮助我们下载数据和获取一些比较流行的数据集。我们可以用datasets提供的fetch_openml方法去下载openml的数据:
from sklearn.datasets import fetch_openml
# 这边需要注意如果不设置as_frame=false,则fetch_openml()会返回DataFrame格式的数据集
mnist = fetch_openml("mnist_784",version=1,as_frame=False)
mnist.keys()
# 我们可以查看到数据集中包含的内容
dict_keys(['data', 'target', 'frame', 'categories', 'feature_names', 'target_names', 'DESCR', 'details', 'url'])
fetch_openml的详细内容可以查看sklearn.datasets.fetch_openml的官网文档。我们能够很容易知道fetch_openml()会返回一个sklearn.utils.Bunch类型的data(如果我们参数这是了return_X_y=True,则直接返回data,target的元组),它类似于字典结构,我们可以用keys()输出其包含的内容:
DESCR键
描述数据集
data键
包含一个数组,每个实例为一行,每个特征为一列
target键
包含一个带标记的数组
我们只需要data和target
X, y =mnist['data'],mnist['target']
X.shape # 输出(70000, 784)
y.shape # 输出(70000,)
从输出中可以看出,一共有70000张图片,包含了784个特征。同时根据介绍openml官网对该集合的介绍如下:
从介绍上中我们可以在得知两个有用的信息,一是该数据集已经分好了训练集和测试集、二是图片是28*28像素大小。我们可以通过Matplotlib的imshow()拿其中一张图片看一下:
import matplotlib.pyplot as plt
import matplotlib as mpl
some_digit = X[0]
# 因为原数据集是每行放了784个特征,我们需要重新形成一个28*28的数组
some_digit_image = some_digit.reshape(28,28)
plt.imshow(some_digit_image, cmap="binary")
plt.axis("off")
plt.show()
图1-2 X[0]的图像
我们看到图片显示的是黑白图像,是因为我用了cmap='binary',我们还能够去matplotlib官网的Example页面搜索Colormap reference。里面详细讲了Matplotlib包含的colormaps。如下图1-3为我截取的一部分cmap,里面可以看到binary就是一个只有从白到黑的连续颜色集合。读者也可以试着改改cmap选项,让图片能够更加多彩。
图1-3 Sequential (2) colormaps
2 训练二元分类器
在我们开始使用数据集之前,应该要先创建一个测试集,同时最好还能够将测试集混淆。因为有些算法对于训练实例的顺序敏感,例如连续输入许多类似的实例,可能执行的性能不佳(当然我们也要注意例如股市价格或天气情况这类的时间序列数据,可能混洗之后反而不好)。
之前我们已经知道mnist数据集已经分好了训练集和测试集。那我们可以用训练集先来完成一个能够分辨5的"数字5检测器"(当然之后我们也会明白,这可以是1-9之间任意的数字检测)。
from sklearn.linear_model import SGDClassifier
# 由于mnist数据集给的标签是字符类型,需要使用np.astype将y集转换成数字集
y = y.astype(np.uint8)
# 分割训练集和测试集
X_train,X_test,y_train,y_test = X[:60000],X[:60000],y[:60000],y[60000:]
# 设置标签
# 使用该方法后,将y_train转化成只有bool值的标签
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)
# 引入SGDClassifier
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)
print( sgd_clf.predict([some_digit]) )
[True]
这次我选择了随机梯度下降(SGD)分类器,因为它能够很有效的处理非常大型的数据。还因SGD独立处理训练实例,一次一个也适合在线学习。这次的结果还是很不错的,我们的分类器猜对了。
3 性能测量
我们的分类器选好了,我们还需要测量其准确率如何。依旧是选择K-折交叉验证取评估SGDClassifier模型,但是这次我们的返回不是neg_mean_squared_error,而是accuracy。
from sklearn.model_selection import cross_val_score
cross_val_score(sgd_clf,X_train,y_train_5,cv=3,scoring="accuracy")
array([0.95035, 0.96035, 0.9604 ])
当我们为scoring选择了“accuracy”,这时的cross_val_score()函数,会调用sklearn.metrics.accuracy_score方法去比较我们的预测值和标签,最后返回每次的准确率。三个准确率似乎还挺高,但真的有这么高吗?
这里我们再看一个傻瓜分类器,他将所有的图都分成“非5”:
from sklearn.base import BaseEstimator
class Never5Classifier(BaseEstimator):
def fit(self, X, y=None):
def predict(self, X):
# 预测值都是False,表示非5
return np.zeros((len(X), 1), dtype=bool)
never_5_clf = Never5Classifier()
cross_val_score(never_5_clf, X_train, y_train_5, cv=3,scoring="accuracy")
array([0.91125, 0.90855, 0.90915])
看结果是不是很神奇!一个全猜非5的傻瓜分类器,怎么也有90%多准确率呢?其实也很好理解,5出现的概率为10%,不是5的概率不就有90%吗。因此我们对于分类器的评估方法还是有问题,也说明了准确率通常不能作为分类器的首要性能指标,尤其是处理一些偏数据集(某些数据类比其他类更多时)。
3.1 混淆矩阵
这时候我们需要引入一个更好的方法-混淆矩阵。其总体思路是统计A类别实例被分成B类的次数。
图3-1 混淆矩阵
我们通过图3-1可以了解混淆矩阵的基本概念,其中行表示实际的类别,列表示预测的类别:
第一行第一列是真正类(TP)
,表示真实为A类且预测A类
第一行第二列是假负类(FN)
,表示真实为A类但预测为B类
第二行第一列是假正类(FP)
,表示真实为B类但预测为A类
第二行第二列是真负类(TN)
,表示真实为B类且预测为B类
我在引入一个案例,例如我们将患有某种病的人和为患病的人进行统计后形成了如下的混淆矩阵:
图3-2 患病和未患的混淆矩阵
我们能够更加清楚了解混淆矩阵的含义:
每一行表示实际的样本数(本例实际患者15+16个,未患者22+35个)
每一列表示预测的结果数(本例预测患者15+22个,未患者16+35个)。
有了对混淆矩阵的基础认识,让我们继续回到MNIST数据集上。我们知道实际值,就是我们的targe标签,但是我们没有预测值,这里我们需要使用cross_val_predict()函数。从名字可以看出与cross_val_score()函数的区别在于,cross_val_predict()返回的只有预测值,他不会去和targe比较,返回评估分数。
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix
# 通过3-折交叉验证获得一个干净的预测
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)
# 使用confusion_matrix获取混淆矩阵
confusion_matrix(y_train_5, y_train_pred)
array([[53892, 687],
[ 1891, 3530]], dtype=int64)
返回的真正如下表:
表3-1 分类器的混淆矩阵
非5(False) 5(True) 非5(False) 53892 687 5(True) 1891 3530
看表格我们可以很清楚知道:53892个图片被正确分类成了“非5”,3530张图片被正确分成了“5”。
3.2 精度和召回率
我们需要引入新的指标了,既然我们能够知道正确和错误的分类数,那我们是不是就可以根据这些分类数获得分类器的精度:
公式3-1:精度(Precision)
我们可以知道TP+FP那就是实际的正类总数,TP是我们的预测正确的正类数,因此这个精度可以理解为我们预测正确的精度。当然最完美的精度就是
单独一个精度,只能够关注于正类而忽略了其他的信息,这显然不好。因此我们需要继续了解另一个指标-召回率,也称灵敏度(sensitivity) 或者 真正类率(rue positive rate,TPR),它是分类器正确检测到的正类实例的比率:
公式3-2:召回率(Recall)
TP+FN就是我们的预测的总数,那么其实召回率就是我们预测正确数在预测总数的比例。
接下来我将混淆矩阵、精度和召回率放在一起加深理解:
图3-3 混淆矩阵、精度和召回率
接下来我们继续使用scikit-learn计算精度和召回率:
from sklearn.metrics import precision_score, recall_score
precision_score(y_train_5, y_train_pred) # 输出 0.8370879772350012
recall_score(y_train_5, y_train_pred) # 输出 0.6511713705958311
显示再看我们的5-检测器似乎准确度就不是那么高了,当他说这是一个5时,只有83%的准确率,而且也只能检测出全部5中的65%。
我们可以将精度和召回率进行组合,形成新的指标,称为
公式3-3:
根据公式3-3,我们可知只有当精度和召回率都很高,
from sklearn.metrics import f1_score
f1_score(y_train_5, y_train_pred) #输出 0.7325171197343846
在
例如:在训练一个可供儿童观看的视频分类器时,我们希望的是能够精度高一些,尽可能排除一些不良视频,哪怕会误杀一部分好视频。但是如果我们在训练一些检测可疑人物的监控分类器时,又希望能够提高报警正确率也就是召回率高些,不然时不时给你一个错误警报,那我们也受不了啊。
事实上,我们现在基本都是精度和召回率之间呈现反比之趋势,要么精度高了召回率下降,反之亦然。
3.3 精度/召回率权衡
对于上述的问题,我们就需要进行精度/召回率权衡。OK,我们回到一开始使用过的随机梯度下降分类器(SGDClassifier)中,我们看一下他是怎么工作的。
对于每一个实例,他会基于决策函数先计算出一个分值,同时为了判断还有一个阈值,当分值大于阈值时,则将该实例判成正类,相反则是负类。
图3-4 阈值、精度和召回率
从图3-4 可以看到,假设我们的阈值在蓝色箭头处,我们可以找到3个真的5以及一个假的5,精度为75%,在总的4个5中,找到了3个,召回率也是75%。但是当我们移动阈值到橙色箭头处,就会发现精度增加了,但是召回也在下降。而将阈值移动到绿色箭头处,则是相反。
我们需要注意,召回率在阈值提高时一定是下降的,但是精度却不一定上升,例如蓝色箭头处,当我们将阈值提高一位数字,我们会发现,精度反而下降到了2/3=67%。
因此我们需要调整阈值,来调整分类器的精度和召回率。但是Scikit-Learn不允许我们直接修改阈值。那我们就需要自己根据他的决策分数进行调整,通过decision_function()可以访问决策分数:
y_score = sgd_clf.decision_function([some_digit])
y_score # 输出 array([2164.22030239])
# 先将阈值设置为0
threshold = 0
y_some_digit_pref = (y_score > threshold)
y_some_digit_pref # 输出 array([True])
# 调整阈值为5000
threshold = 5000
y_some_digit_pref = (y_score > threshold)
y_some_digit_pref # 输出 array([False])
很明显的当阈值为0时,分类器预测正确了该图片为5,但是阈值提高到5000,分类器就错估了这张图。因此如何决定阈值的高低需要继续计算:
# mpl在windows上绘图时,因为字体的原因,导致无法显示负号,这段代码用于解决负号显示问题
plt.rcParams['axes.unicode_minus'] = False
def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
plt.plot(thresholds, precisions[:-1], "b--", label="精度", linewidth=2)
plt.plot(thresholds, recalls[:-1], "g-", label="召回率", linewidth=2)
plt.legend(loc="center right", fontsize=16)
plt.xlabel("阈值", fontsize=16)
plt.grid(True)
plt.axis([-50000, 50000, 0, 1])
# 获得精度大于90%时recall和thresholds坐标
recall_90_precision = recalls[np.argmax(precisions >= 0.90)]
# threshold_90_precision 输出 3370.0194991439594
threshold_90_precision = thresholds[np.argmax(precisions >= 0.90)]
plt.figure(figsize=(8, 4)) # 画出精度和召回率
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
# 画出精度在0.9时,精度和召回率的与X轴的连接线
plt.plot([threshold_90_precision, threshold_90_precision], [0., 0.9], "r:")
# 画出精度在0.9时,精度和召回率的与y轴的连接线
plt.plot([-50000, threshold_90_precision], [0.9, 0.9], "r:")
plt.plot([-50000, threshold_90_precision], [recall_90_precision, recall_90_precision], "r:")
# 画出精度在0.9时,精度和召回率的点
plt.plot([threshold_90_precision], [0.9], "bo")
plt.plot([threshold_90_precision], [recall_90_precision], "ro")
plt.show()
图3-5 精度和召回率与决策阈值
上图中我们使用`np.argmax(precisions >= 0.90)`这个函数计算出precisions中大于0.9的集合中最大值的第一个索引,然后获得`recall_90_precision`和`threshold_90_precision`两个点,然后在函数图中显示出。
我们还可以绘制更方便查看精度/召回率的图,那就是直接绘制精度和召回率的关系:
def plot_precision_vs_recall(precisions, recalls):
plt.plot(recalls, precisions, "b-", linewidth=2)
plt.xlabel("召回率", fontsize=16)
plt.ylabel("精度", fontsize=16)
plt.axis([0, 1, 0, 1])
plt.grid(True)
plt.figure(figsize=(8, 6))
plot_precision_vs_recall(precisions, recalls)
plt.plot([recall_90_precision, recall_90_precision], [0., 0.9], "r:")
plt.plot([0.0, recall_90_precision], [0.9, 0.9], "r:")
plt.plot([recall_90_precision], [0.9], "ro")
plt.show()
图3-5 精度和召回率
怎么样是不是这样更加方便我们观察精度和召回率的关系,但是我们可要注意了,不管是高精度还是高召回率时,相对的召回率和精度都会变得很低!如:当召回率高于80%时,精度就开始大幅度下降。
还记得精度和阈值的关系吗,最后我们可以按照之前的方法获得一个精度为90%的分类器:
# 获得精度为90%阈值
threshold_90_precision = thresholds[np.argmax(precisions >= 0.90)]
# 根据阈值与决策分数,得出我们的对于图片的预测
y_train_pred_90 = (y_scores > threshold_90_precision)
# 看看我们的精确度
precision_score(y_train_5, y_train_pred_90)
# 输出 0.9
完美,我们成功获得精度为90%的分类器。