【机器学习300问】48、如何绘制ROC曲线?

您所在的位置:网站首页 如何看roc曲线有没造假过 【机器学习300问】48、如何绘制ROC曲线?

【机器学习300问】48、如何绘制ROC曲线?

2024-07-16 23:43| 来源: 网络整理| 查看: 265

        ROC曲线(受试者工作特征曲线)是一种用于可视化评估二分类模型性能的指标。特别是在不同阈值情况下模型对正类和负类的区分能力。那么“阈值”到底是个什么呢?ROC曲线中的每一个点到底是什么意思?

一、ROC曲线的绘制【理论】

        二分类器(模型)输出的是预测样本的正类概率,模型在预测完所有样本的概率后会对其进行降序排序。假设一个样本被二分类器预测输出的概率是0.6,那么到底这个样本是正类还是负类呢?如果我们认为超过0.5的概率就是正类,那么显然该样本的预测标签为“正”。但如果我们认为超过0.6才算正类,那么样本的标签就成“负”的了。所以“阈值”就是人们判定预测结果到底正还是负的一个依据。

        阈值,预测概率大于该阈值样本判定为正,预测概率小于该阈值样本判定为负。ROC曲线绘制的过程,就是逐渐调整阈值,计算每次调整的阈值对应的(FPR,TPR),并在表格上绘制出该点的位置,最后把所有点连起来就得到了ROC曲线。

二、ROC曲线的绘制【实践】 (1)来点数据 序号真实标签模型输出概率(降序排列)110.95210.9310.85410.8510.75600.7700.65810.6900.551000.51110.451200.41300.351400.31510.251600.21700.151800.11900.052010.0

        假设测试集中有20个样本,如上表所示按照概率降序排列。 分别列出了样本序号,样本真实的分类,模型预测输出的概率。

(2)文字演示

        当阈值为正无穷的时候,也就是说哪怕样本的概率是1,也没有一个样本被模型认为是正类,分类器认为全部都是负的,此时的FP=TP=0,显然FPR=TPR=0,在曲线上的坐标就是(0,0)

        当阈值设定为0.9的时候,上表中样本1和2都被预测为正。此时的P=9,TP=2得到TPR=2/9=0.22。此时没有预测错的样本FP=0算出FPR=0/11=0。最终的在曲线上的坐标就是(0,0.22)

        依次按照文字描述的过程,就可以计算得到所有阈值(这里我们将预测值的分度值设定成0.1,从1.0逐渐下降至0.0)坐标。将点连城线就得到了ROC曲线。

(3)代码演示 ① 导入必要的库 import numpy as np import pandas as pd from sklearn.metrics import roc_curve, auc import matplotlib.pyplot as plt ② 构造测试集 # 假设我们有如上表格所示的数据存储在一个DataFrame中 sample_data = pd.DataFrame({ '真实标签': [1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1], '模型输出概率': [0.95, 0.9, 0.85, 0.8, 0.75, 0.7, 0.65, 0.6, 0.55, 0.5, 0.45, 0.4, 0.35, 0.3, 0.25, 0.2, 0.15, 0.1, 0.05, 0.0] }) # 可以输出查看一下sample_data

 ③ 绘制ROC曲线 # 将'真实标签'转化为二进制形式(通常真实标签会被编码为0和1) true_labels = sample_data['真实标签'].astype(int) # 获取'模型输出概率' predicted_probs = sample_data['模型输出概率'] # 计算ROC曲线所需的各项指标 fpr, tpr, _ = roc_curve(true_labels, predicted_probs, pos_label=1) # 计算曲线下面积(AUC) roc_auc = auc(fpr, tpr) # 绘制ROC曲线 plt.figure() plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % roc_auc) plt.plot([0, 1], [0, 1], 'k--') # 平行于坐标轴的直线,代表随机猜测的结果 plt.xlim([0.0, 1.0]) plt.ylim([0.0, 1.0]) plt.xlabel('False Positive Rate') plt.ylabel('True Positive Rate') plt.title('Receiver Operating Characteristic Curve') plt.legend(loc="lower right") plt.show()

        如果我们在图中把0.1,0.2一直到1这十个阈值标出来的话,就是下面这个图:

        在我们文字演示时,设定当阈值=0.9的时候,对应的坐标(0.0.22)在图中很清晰的现实出来了。 上图的代码阈值刻度是sklearn.metrics.roc_curve 函数依据模型输出的概率得分y_score,以排序后从最小到最大的顺序依次作为阈值,计算出每个阈值下的真阳性率(TPR)和假阳性率(FPR),从而生成一系列坐标点绘制成ROC曲线。



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3