不平衡数据集的建模的技巧和策略

news/2024/11/22 10:06:32/

不平衡数据集是指一个类中的示例数量与另一类中的示例数量显著不同的情况。 例如在一个二元分类问题中,一个类只占总样本的一小部分,这被称为不平衡数据集。类不平衡会在构建机器学习模型时导致很多问题。

不平衡数据集的主要问题之一是模型可能会偏向多数类,从而导致预测少数类的性能不佳。 这是因为模型经过训练以最小化错误率,并且当多数类被过度代表时,模型倾向于更频繁地预测多数类。 这会导致更高的准确率得分,但少数类别得分较低。

另一个问题是,当模型暴露于新的、看不见的数据时,它可能无法很好地泛化。 这是因为该模型是在倾斜的数据集上训练的,可能无法处理测试数据中的不平衡。

在本文中,我们将讨论处理不平衡数据集和提高机器学习模型性能的各种技巧和策略。 将涵盖的一些技术包括重采样技术、代价敏感学习、使用适当的性能指标、集成方法和其他策略。 通过这些技巧,可以为不平衡的数据集构建有效的模型。

处理不平衡数据集的技巧

重采样技术是处理不平衡数据集的最流行方法之一。 这些技术涉及减少多数类中的示例数量或增加少数类中的示例数量。

欠采样可以从多数类中随机删除示例以减小其大小并平衡数据集。 这种技术简单易行,但会导致信息丢失,因为它会丢弃一些多数类示例。

过采样与欠采样相反,过采样随机复制少数类中的示例以增加其大小。 这种技术可能会导致过度拟合,因为模型是在少数类的重复示例上训练的。

SMOTE是一种更高级的技术,它创建少数类的合成示例,而不是复制现有示例。 这种技术有助于在不引入重复项的情况下平衡数据集。

代价敏感学习(Cost-sensitive learning)是另一种可用于处理不平衡数据集的技术。 在这种方法中,不同的错误分类成本被分配给不同的类别。 这意味着与错误分类多数类示例相比,模型因错误分类少数类示例而受到更严重的惩罚。

在处理不平衡的数据集时,使用适当的性能指标也很重要。 准确性并不总是最好的指标,因为在处理不平衡的数据集时它可能会产生误导。 相反,使用 AUC-ROC等指标可以更好地指示模型性能。

集成方法,例如 bagging 和 boosting,也可以有效地对不平衡数据集进行建模。 这些方法结合了多个模型的预测以提高整体性能。 Bagging 涉及独立训练多个模型并对它们的预测进行平均,而 boosting 涉及按顺序训练多个模型,其中每个模型都试图纠正前一个模型的错误。

重采样技术、成本敏感学习、使用适当的性能指标和集成方法是一些技巧和策略,可以帮助处理不平衡的数据集并提高机器学习模型的性能。

在不平衡数据集上提高模型性能的策略

收集更多数据是在不平衡数据集上提高模型性能的最直接策略之一。 通过增加少数类中的示例数量,模型将有更多信息可供学习,并且不太可能偏向多数类。 当少数类中的示例数量非常少时,此策略特别有用。

生成合成样本是另一种可用于提高模型性能的策略。 合成样本是人工创建的样本,与少数类中的真实样本相似。 这些样本可以使用 SMOTE等技术生成,该技术通过在现有示例之间进行插值来创建合成示例。 生成合成样本有助于平衡数据集并为模型提供更多示例以供学习。

使用领域知识来关注重要样本也是一种可行的策略,通过识别数据集中信息量最大的示例来提高模型性能。 例如,如果我们正在处理医学数据集,可能知道某些症状或实验室结果更能表明某种疾病。 通过关注这些例子可以提高模型准确预测少数类的能力。

最后可以使用异常检测等高级技术来识别和关注少数类示例。 这些技术可用于识别与多数类不同且可能是少数类示例的示例。 这可以通过识别数据集中信息量最大的示例来帮助提高模型性能。

在收集更多数据、生成合成样本、使用领域知识专注于重要样本以及使用异常检测等先进技术是一些可用于提高模型在不平衡数据集上的性能的策略。 这些策略可以帮助平衡数据集,为模型提供更多示例以供学习,并识别数据集中信息量最大的示例。

不平衡数据集的练习

这里我们使用信用卡欺诈分类的数据集演示处理不平衡数据的方法

 importpandasaspdimportnumpyasnpfromsklearn.preprocessingimportRobustScalerfromsklearn.linear_modelimportLogisticRegressionfromsklearn.model_selectionimporttrain_test_splitfromsklearn.metricsimportaccuracy_scorefromsklearn.metricsimportconfusion_matrix, classification_report,f1_score,recall_score,roc_auc_score, roc_curveimportmatplotlib.pyplotaspltimportseabornassnsfrommatplotlibimportrc,rcParamsimportitertoolsimportwarningswarnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=UserWarning)

读取数据

 df=pd.read_csv("creditcard.csv")df.head()print("Number of observations : " ,len(df))print("Number of variables : ", len(df.columns))#Number of observations :  284807#Number of variables :  31

查看数据集信息

 df.info()<class'pandas.core.frame.DataFrame'>RangeIndex: 284807entries, 0to284806Datacolumns (total31columns):#   Column  Non-Null Count   Dtype  ---  ------  --------------   -----  0   Time    284807non-null  float641   V1      284807non-null  float642   V2      284807non-null  float643   V3      284807non-null  float644   V4      284807non-null  float645   V5      284807non-null  float646   V6      284807non-null  float647   V7      284807non-null  float648   V8      284807non-null  float649   V9      284807non-null  float6410  V10     284807non-null  float6411  V11     284807non-null  float6412  V12     284807non-null  float6413  V13     284807non-null  float6414  V14     284807non-null  float6415  V15     284807non-null  float6416  V16     284807non-null  float6417  V17     284807non-null  float6418  V18     284807non-null  float6419  V19     284807non-null  float6420  V20     284807non-null  float6421  V21     284807non-null  float6422  V22     284807non-null  float6423  V23     284807non-null  float6424  V24     284807non-null  float6425  V25     284807non-null  float6426  V26     284807non-null  float6427  V27     284807non-null  float6428  V28     284807non-null  float6429  Amount  284807non-null  float6430  Class   284807non-null  int64  dtypes: float64(30), int64(1)memoryusage: 67.4MB

查看分类类别:

 f,ax=plt.subplots(1,2,figsize=(18,8))df['Class'].value_counts().plot.pie(explode=[0,0.1],autopct='%1.1f%%',ax=ax[0],shadow=True)ax[0].set_title('dağılım')ax[0].set_ylabel('')sns.countplot('Class',data=df,ax=ax[1])ax[1].set_title('Class')plt.show()

 rob_scaler=RobustScaler()df['Amount'] =rob_scaler.fit_transform(df['Amount'].values.reshape(-1,1))df['Time'] =rob_scaler.fit_transform(df['Time'].values.reshape(-1,1))df.head()

创建基类模型

 X=df.drop("Class", axis=1)y=df["Class"]X_train, X_test, y_train, y_test=train_test_split(X, y, test_size=0.20, random_state=123456)model=LogisticRegression(random_state=123456)model.fit(X_train, y_train)y_pred=model.predict(X_test)accuracy=accuracy_score(y_test, y_pred)print("Accuracy: %.3f"%(accuracy))

我们创建的模型的准确率评分为0.999。我们可以说我们的模型很完美吗?

混淆矩阵是一个用来描述分类模型的真实值在测试数据上的性能的表。它包含4种不同的估计值和实际值的组合。

 defplot_confusion_matrix(cm, classes,title='Confusion matrix',cmap=plt.cm.Blues):plt.rcParams.update({'font.size': 19})plt.imshow(cm, interpolation='nearest', cmap=cmap)plt.title(title,fontdict={'size':'16'})plt.colorbar()tick_marks=np.arange(len(classes))plt.xticks(tick_marks, classes, rotation=45,fontsize=12,color="blue")plt.yticks(tick_marks, classes,fontsize=12,color="blue")rc('font', weight='bold')fmt='.1f'thresh=cm.max()fori, jinitertools.product(range(cm.shape[0]), range(cm.shape[1])):plt.text(j, i, format(cm[i, j], fmt),horizontalalignment="center",color="red")plt.ylabel('True label',fontdict={'size':'16'})plt.xlabel('Predicted label',fontdict={'size':'16'})plt.tight_layout()plot_confusion_matrix(confusion_matrix(y_test, y_pred=y_pred), classes=['Non Fraud','Fraud'],title='Confusion matrix')

•非欺诈类共进行了56875次预测,其中56870次(TP)正确,5次(FP)错误。

•欺诈类共进行了87次预测,其中31次(FN)错误,56次(TN)正确。

该模型可以预测欺诈状态,准确率为0.99。但当检查混淆矩阵时,欺诈类的错误预测率相当高。也就是说该模型正确地预测了非欺诈类的概率为0.99。但是非欺诈类的观测值的数量高于欺诈类的观测值的数量,这拉搞了我们对准确率的计算,并且我们更加关注的是欺诈类的准确率,所以我们需要一个指标来衡量它的性能。

选择正确的指标

在处理不平衡数据集时,选择正确的指标来评估模型的性能非常重要。 传统指标,如准确性、精确度和召回率,可能不适用于不平衡的数据集,因为它们没有考虑数据中类别的分布。

经常用于不平衡数据集的一个指标是 F1 分数。 F1 分数是精确率和召回率的调和平均值,它提供了两个指标之间的平衡。 计算如下:

F1 = 2 * (precision * recall) / (precision + recall)

另一个经常用于不平衡数据集的指标是 AUC-ROC。 AUC-ROC 衡量模型区分正类和负类的能力。 它是通过绘制不同分类阈值下的TPR与FPR来计算的。 AUC-ROC 值的范围从 0.5(随机猜测)到 1.0(完美分类)。

 print(classification_report(y_test, y_pred))precision   recall   f1-score   support0       1.00      1.00      1.00     568751       0.92      0.64      0.76        87accuracy                           1.00     56962macroavg       0.96      0.82      0.88     56962weightedavg       1.00      1.00      1.00     56962

返回对0(非欺诈)类的预测有多少是正确的。查看混淆矩阵,56870 + 31 = 56901个非欺诈类预测,其中56870个预测正确。0类的精度值接近1 (56870 / 56901)

返回对1 (欺诈)类的预测有多少是正确的。查看混淆矩阵,5 + 56 = 61个欺诈类别预测,其中56个被正确估计。0类的精度为0.92 (56 / 61),可以看到差别还是很大的

过采样

通过复制少数类样本来稳定数据集。

随机过采样:通过添加从少数群体中随机选择的样本来平衡数据集。如果数据集很小,可以使用这种技术。可能会导致过拟合。randomoverampler方法接受sampling_strategy参数,当sampling_strategy = ’ minority '被调用时,它会增加minority类的数量,使其与majority类的数量相等。

我们可以在这个参数中输入一个浮点值。例如,假设我们的少数群体人数为1000人,多数群体人数为100人。如果我们说sampling_strategy = 0.5,少数类将被添加到500。

 y_train.value_counts()0    2274401       405Name: Class, dtype: int64fromimblearn.over_samplingimportRandomOverSampleroversample=RandomOverSampler(sampling_strategy='minority')X_randomover, y_randomover=oversample.fit_resample(X_train, y_train)

采样后训练

 model.fit(X_randomover, y_randomover)y_pred=model.predict(X_test)plot_confusion_matrix(confusion_matrix(y_test, y_pred=y_pred), classes=['Non Fraud','Fraud'],title='Confusion matrix')

应用随机过采样后,训练模型的精度值为0.97,出现了下降。但是从混淆矩阵来看,模型的欺诈类的正确估计率有所提高。

SMOTE 过采样:从少数群体中随机选取一个样本。然后,为这个样本找到k个最近的邻居。从k个最近的邻居中随机选取一个,将其与从少数类中随机选取的样本组合在特征空间中形成线段,形成合成样本。

 from imblearn.over_sampling import SMOTEoversample = SMOTE()X_smote, y_smote = oversample.fit_resample(X_train, y_train)

使用SMOTE后的数据训练

 model.fit(X_smote, y_smote)y_pred = model.predict(X_test)accuracy = accuracy_score(y_test, y_pred)plot_confusion_matrix(confusion_matrix(y_test, y_pred=y_pred), classes=['Non Fraud','Fraud'],title='Confusion matrix')

可以看到与基线模型相比,欺诈的准确率有所提高,但是比随机过采样有所下降,这可能是数据集的原因,因为SMOTE采样会生成心的数据,所以并不适合所有的数据集。

总结

在这篇文章中,我们讨论了处理不平衡数据集和提高机器学习模型性能的各种技巧和策略。不平衡的数据集可能是机器学习中的一个常见问题,并可能导致在预测少数类时表现不佳。

本文介绍了一些可用于平衡数据集的重采样技术,如欠采样、过采样和SMOTE。还讨论了成本敏感学习和使用适当的性能指标,如AUC-ROC,这可以提供更好的模型性能指示。

处理不平衡的数据集是具有挑战性的,但通过遵循本文讨论的技巧和策略,可以建立有效的模型准确预测少数群体。重要的是要记住最佳方法将取决于特定的数据集和问题,为了获得最佳结果,可能需要结合各种技术。因此,试验不同的技术并使用适当的指标评估它们的性能是很重要的。

https://avoid.overfit.cn/post/774ca6891f26470093970c074afceede

作者:Emine Bozkuş


http://www.ppmy.cn/news/20839.html

相关文章

Mac 打开JD-GUI报错:ERROR launching ‘JD-GUI‘

目录一、JD-GUI下载二、JD-GUI报错信息三、解决方案1、查找JD-GUI包内容2、修改universalJavaApplicationStub.sh文件一、JD-GUI下载 JD-GUI下载地址&#xff1a;https://github.com/java-decompiler/jd-gui/releases 二、JD-GUI报错信息 Mac系统版本&#xff1a;11.3 JD-GUI…

【数据结构初阶】第三篇——单链表

链表的概念及其结构 初始化链表 打印单链表 增加结点 头插 尾插 在给定位置之前插入 在给定位置之后插入 删除结点 头删 尾删 删除给定位置的结点 查找数据 修改数据 链表的概念及其结构 基本概念 链表是一种物理存储结构上非连续&#xff0c;非顺序的存储结构&a…

解决问题的方法论

概述 解决问题的能力是职场中最重要的能力之一&#xff0c;如何逻辑清晰、效率满满的解决问题&#xff0c;可参考以下4个步骤。 一、准确的界定问题 找出真正的问题。 准确的界定问题&#xff0c;避免被表面现象所迷惑。 《麦肯锡工具》中&#xff0c;给出一个标准的步骤&am…

电子技术——基本MOS放大器配置

电子技术——基本MOS放大器配置 上一节我们探究了一种MOS管的放大器实现&#xff0c;其实MOS放大器还有许多变种配置&#xff0c;在本节我们学习最基本的三大MOS放大器配置&#xff0c;分别是共栅极&#xff08;CG&#xff09;、共漏极&#xff08;CD&#xff09;、共源极&…

易控智驾:用最“接地气”的自动驾驶,写一本“矿区修炼手册”

CES2023刚刚在拉斯维加斯闭幕&#xff0c;作为行业风向标&#xff0c;本届展会上元宇宙、汽车技术等重要科技依然是大亮点。宝马、英特尔等厂商&#xff0c;依然带来了有趣的消费级产品&#xff0c;但也有更多的工业与制造业产品、方案&#xff0c;带着更多的科技智能属性脱颖而…

【C++算法图解专栏】一篇文章带你入门二分算法

✍个人博客&#xff1a;https://blog.csdn.net/Newin2020?spm1011.2415.3001.5343 &#x1f4e3;专栏定位&#xff1a;为 0 基础刚入门数据结构与算法的小伙伴提供详细的讲解&#xff0c;也欢迎大佬们一起交流~ &#x1f4da;专栏地址&#xff1a;https://blog.csdn.net/Newin…

element 日期组件实现只能选择小时或者只能选择小时、分钟

前言 在使用 element 框架时&#xff0c;总是会有一些满足不了现有项目需求的问题&#xff0c;这个时候就需要我们对 element 的组件进行改造&#xff0c;最近有一个需求就是要求日期组件只能选择年月日时&#xff0c;不要分钟和秒&#xff0c;找了一圈&#xff0c;发现 elemen…

linux内核之netlink通信

Linux内核(04)之netlink通信 Author&#xff1a;Onceday Date&#xff1a;2023年1月3日 漫漫长路&#xff0c;才刚刚开始… 参考文档&#xff1a; netlink 机制 binarydady 阿里云开发者社区linux中通用Netlink详解及使用剖析 binarydady 阿里云开发者社区RFC 3549 Linux N…