NLP-transformer学习:(7)evaluate实践

embedded/2025/1/17 1:01:54/

transformer7evaluate__1">NLP-transformer学习:(7)evaluate 使用方法

在这里插入图片描述

打好基础,为了后面学习走得更远。
本章节是单独的 NLP-transformer学习 章节,主要实践了evaluate。同时,最近将学习代码传到:https://github.com/MexWayne/mexwayne_transformers-code,作者的代码版本有些细节我发现到目前不能完全行的通,为了尊重原作者,我这里保持了大部分的内容,并表明了来源,欢迎大家一起学习


提示:以下是本篇文章正文内容,下面案例可供参考

1 evaluate是什么?

evaluate库就是一个非常简单的机器学习评估库函数,封装了很多我们平差给你评估模型的函数。
地址:https://huggingface.co/evaluate-metric
我们可以在这个链接下看到有很多指标,比如 bleu sari precision
在这里插入图片描述

2 evaluate基本用法?

首先需要安装:
pip install evaluate

2.1 evaluate 调用与输入参数

代码:

import evaluateif __name__ == "__main__":# see the evalution function that evaluate supportprint(evaluate.list_evaluation_modules())#print(evaluate.list_evaluation_modules(include_community=False, with_details=True))# load the accuracy class accuracy = evaluate.load("accuracy")# introduce the accuracy functionsprint(accuracy.description)

结果:
在这里插入图片描述
但是看到这 介绍内容太少,我么不清楚输入怎么做

import evaluateif __name__ == "__main__":# load the accuracy class accuracy = evaluate.load("accuracy")# introduce the accuracy functions#print(accuracy.description)# the inputs help guide print(accuracy.inputs_description)

结果:
在这里插入图片描述
这样就会有很多详细,比如
在这里插入图片描述
这里 ,需要是 list of int 的 预测值(predictions)和真值(references),因为总共有6个,reference 和 predictions对的上的就3个,所以accuracy 是 0.5。
在这里插入图片描述
如果权重不同的还可以增加权证欧冠你,这样 accurarcy
在这里插入图片描述

2.2 evaluate 几种计算方式

(1)全局计算

import evaluateif __name__ == "__main__":######################################################## inputs and call # global accuracyaccuracy = evaluate.load("accuracy")results = accuracy.compute(references=[0,1,2,0,1,2], predictions=[0,1,1,2,1,0])print(results)# iterate accurarcyaccuracy = evaluate.load("accuracy")# for refs, preds in zip([[0,1],[0,1]],[2,8],[4,1], [[1,0],[0,1],[3,8],[4,1]]):for refs, preds in zip([0,1,2,3,4],[0,1,2,3,3]):accuracy.add(references=refs, precitions=preds)print(accuracy.compute())

第一个打印结果:
在这里插入图片描述
(2)全局计算

import evaluateif __name__ == "__main__":######################################################## inputs and call # global accuracyaccuracy = evaluate.load("accuracy")results = accuracy.compute(references=[0,1,2,0,1,2], predictions=[0,1,1,2,1,0])print(results)

在这里插入图片描述

(2)迭代计算:

import evaluateif __name__ == "__main__":# iterate accurarcyaccuracy = evaluate.load("accuracy")for refs, preds in zip([0,1,2,3,4],[0,1,2,3,3]):accuracy.add(references=refs, predictions=preds)print("iterate way:")print(accuracy.compute())

结果
在这里插入图片描述
(3)batch 迭代计算

import evaluateif __name__ == "__main__":# batch iterate accurarcyaccuracy = evaluate.load("accuracy")for refs, preds in zip([[0,1,2,3,4], [0,1,2,3,3], [8,8,8,8,8]],  # refs batches[[0,1,2,3,4], [0,1,2,3,5], [8,8,8,8,8]]):  # preds batchesaccuracy.add_batch(references=refs, predictions=preds)print("batch iterate way:")print(accuracy.compute())

结果:
在这里插入图片描述

(4)多指标
这里 我们选择 accuracy,precsion, f1, recall
在这里插入图片描述在这里插入图片描述
在这里插入图片描述

import evaluateif __name__ == "__main__":# multiple labelsclf_metrics = evaluate.combine(["accuracy", "f1", "recall", "precision", "XNLI", "SARI"])print(clf_metrics.compute(predictions=[0, 1, 1, 1, 1], references=[0, 1, 0, 1, 1]))

运行结果:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

3 评估可视化

有多个模型,然后都去做预测,我们需要可视化比较

import evaluateif __name__ == "__main__":##################################################### visualfrom evaluate.visualization import radar_plotdata = [{"accuracy": 0.99, "precision": 0.80, "f1": 0.95, "latency_in_seconds": 33.6 , "recall":0.5},{"accuracy": 0.98, "precision": 0.87, "f1": 0.91, "latency_in_seconds": 11.2 , "recall":0.5},{"accuracy": 0.98, "precision": 0.78, "f1": 0.88, "latency_in_seconds": 87.6 , "recall":0.6}, {"accuracy": 0.88, "precision": 0.78, "f1": 0.81, "latency_in_seconds": 101.6, "recall":0.7},{"accuracy": 0.78, "precision": 0.78, "f1": 0.81, "latency_in_seconds": 100.0, "recall":0.9}]model_names = ["Model 1", "Model 2", "Model 3", "Model 4", "Model 5"]plot = radar_plot(data=data, model_names=model_names)print(type(plot))plot.show()plot.savefig('radar.png')

结果
在这里插入图片描述


http://www.ppmy.cn/embedded/115528.html

相关文章

速盾:高防 CDN 怎么屏蔽恶意访问?

在当今网络环境中,恶意访问是网站和应用面临的一个严重问题。高防 CDN(Content Delivery Network,内容分发网络)作为一种强大的防护工具,可以有效地屏蔽恶意访问,保护网站和应用的安全。那么,高…

无人机集群路径规划:雾凇优化算法( rime optimization algorithm,RIME)求解无人机集群路径规划,提供MATLAB代码

一、单个无人机路径规划模型介绍 无人机三维路径规划是指在三维空间中为无人机规划一条合理的飞行路径,使其能够安全、高效地完成任务。路径规划是无人机自主飞行的关键技术之一,它可以通过算法和模型来确定无人机的航迹,以避开障碍物、优化…

【Android】BottomSheet基本用法总结(BottomSheetDialog,BottomSheetDialogFragment)

BottomSheet BottomSheet 是一种位于屏幕底部的面板,用于显示附加内容或选项。提供了从屏幕底部向上滑动显示内容的交互方式。这种设计模式在 Material Design 中被广泛推荐,因为它可以提供一种优雅且不干扰主屏幕内容的方式来展示额外信息或操作。 具体…

深入理解API和前后端网络请求流程

在现代web应用开发中,理解API和网络请求流程的细节至关重要。本文将深入探讨从用户操作到后端处理,再到前端展示的整个过程,包括每个环节的作用、原理和潜在的优化点。 一、API的本质与类型 1. API的定义与作用 API(应用程序编…

Study Plan For Algorithms - Part35

1. 滑动窗口的最大值 给定一个数组 nums 和滑动窗口的大小 k,请找出所有滑动窗口里的最大值。 方法一: def maxSlidingWindow(nums, k):queue []res []if not nums:return queuefor r in range(len(nums)):l r - k 1if l > 0 and queue[0] nums…

阿里云 Quick BI使用介绍

Quick BI使用介绍 文章目录 阿里云 Quick BI使用介绍1. 创建自己的quick bi服务器2. 新建数据源3. 上传文件和 使用4. 开始分析 -选仪表盘5. 提供的图表6. 一个图表的设置使用小结 阿里云 Quick BI使用介绍 Quick BI是一款全场景数据消费式的BI平台,秉承全场景消费…

keepalived高可用

keepalived是什么keepalived是集群管理中保证集群高可用的一个服务软件,用来防止单点故障。keepalived工作原理keepalived是以VRRP协议为实现基础的,VRRP全称Virtual Router Redundancy Protocol,即虚拟路由冗余协议。虚拟路由冗余协议&#…

安装MySQL驱动程序笔记一

安装MySQL驱动程序到计算机上,主要依赖于你使用的编程语言和开发环境。以下是一个通用的安装步骤,特别是针对Java环境,因为MySQL Connector/J是最常用的MySQL JDBC驱动程序。 1. 下载MySQL驱动程序 首先,你需要从MySQL官网&…