PyTorch用map函数进行tokenize之后,为什么需要remove_columns?

news/2024/9/30 2:10:33/

遇到的问题

下面以GLUE数据集中的tokenize为例,下面是典型的处理方法,经过一个map函数,使用lambda表达式对数据进行tokenize。后面会跟上remove_columns这句话,这是为什么呢?

python">if(task_name=="sst2"):tokenized_datasets = raw_dataset.map(lambda sample: sst2_tokenization(sample, tokenizer))tokenized_datasets = tokenized_datasets.remove_columns(['sentence', "label", "idx"])elif(task_name=="cola"):tokenized_datasets = raw_dataset.map(lambda sample: cola_tokenization(sample, tokenizer))tokenized_datasets = tokenized_datasets.remove_columns(['sentence', "label", "idx"])elif(task_name=="mnli"):tokenized_datasets = raw_dataset.map(lambda sample: mnli_tokenization(sample, tokenizer))tokenized_datasets = tokenized_datasets.remove_columns(["hypothesis", 'premise', "label", "idx"])elif(task_name=="mrpc"):tokenized_datasets = raw_dataset.map(lambda sample: mrpc_tokenization(sample, tokenizer))tokenized_datasets = tokenized_datasets.remove_columns(["sentence1", 'sentence2', "label", "idx"])

上面代码使用的tokenize函数,比如下面两个sst2_tokenization和cola_tokenization

def sst2_tokenization(example, tokenizer):prefix = "Classification task:sst2. Choose a label from list:[0, 1] for this context: "input_str = example["sentence"]input_str = prefix + input_strinput_str += " Among them, 0 represents negative, 1 represents positive. Answer:"label_dict = {"-1":"-1", "0":"0", "1":"1"}label = example["label"]target_str = label_dict[str(label)]tokenized_data = tokenizer(input_str, return_tensors="pt", truncation=True, max_length=256, padding = "max_length")input_ids = tokenized_data.input_ids.squeeze(0)attention_mask = tokenized_data.attention_mask.squeeze(0)target_ids = tokenizer(target_str, return_tensors="pt", truncation=True, max_length=3, padding = "max_length").input_idstarget_ids = target_ids.squeeze(0)output_batch = {"input_ids": input_ids,"labels": target_ids,"attention_mask": attention_mask}return output_batchdef cola_tokenization(example, tokenizer):prefix = "Classification task:cola. Choose a label from list:[0, 1] for this context: "input_str = example["sentence"]input_str = prefix + input_strinput_str += " Among them, 0 represents unacceptable, 1 represents acceptable. Answer:"label_dict = {"-1":"-1", "0":"0", "1":"1"}label = example["label"]target_str = label_dict[str(label)]tokenized_data = tokenizer(input_str, return_tensors="pt", truncation=True, max_length=256, padding = "max_length")input_ids = tokenized_data.input_ids.squeeze(0)attention_mask = tokenized_data.attention_mask.squeeze(0)target_ids = tokenizer(target_str, return_tensors="pt", truncation=True, max_length=3, padding = "max_length").input_idstarget_ids = target_ids.squeeze(0)output_batch = {"input_ids": input_ids,"labels": target_ids,"attention_mask": attention_mask}return output_batch

参考的代码来自:https://github.com/mlwu22/TPT/blob/main/data_process.py

解答

这是我问GPT的问题:

经过cola_tokenization之后,不是结果只有 “input_ids”: input_ids, # 经过编码的输入句子
“labels”: target_ids, # 经过编码的目标标签
“attention_mask”: attention_mask # 注意力掩码吗,为什么还有’sentence’, “label”, "idx"这几个字段?

下面是回答:

你理解得没错,经过 cola_tokenization 之后,处理过的数据应该只包含 input_idslabelsattention_mask 这些字段。那么为什么在 tokenized_datasets 中还会存在 sentencelabelidx 这几个字段呢?这是因为 map 函数默认会保留原始数据集中的所有列,除非你显式地删除它们。

让我们一步一步来看:

1. 原始数据集的结构

raw_dataset 原始数据集包含以下字段:

  • sentence: 句子的文本内容
  • label: 标签,表示句子是否符合语法
  • idx: 每个样本的唯一索引,标识数据集中的位置

2. map 函数的默认行为

当你使用 map 函数时,它会应用你定义的函数(这里是 cola_tokenization)来处理每个样本。cola_tokenization 返回的是一个新的字典,包含 input_idslabelsattention_mask但是map 函数的默认行为是将这些新生成的字段添加到原始数据集的每个样本上,同时保留原有的字段,除非你告诉它删除这些字段。

因此,经过 cola_tokenization 处理后,数据集中的每个样本可能包含以下字段:

  • 原始字段:
    • sentence: 原始的句子文本
    • label: 标签
    • idx: 样本索引
  • 新生成的字段:
    • input_ids: 经过 tokenizer 处理的输入
    • attention_mask: 输入的注意力掩码
    • labels: 模型的目标标签

3. remove_columns 的作用

你要做的是删除这些原始字段,因为它们对模型训练没有用了。通过 remove_columns(['sentence', "label", "idx"]),你明确地删除了 sentencelabelidx 这几个原始字段,这样数据集中就只剩下模型需要的 input_idsattention_masklabels

总结

map 函数的默认行为是保留原始数据集的字段,所以即使你已经添加了 input_idsattention_masklabels,原来的 sentencelabelidx 依然存在。通过 remove_columns,你才能将这些原始字段删除,避免它们在训练中产生冗余信息。

后记

2024年9月26日写于上海,基于GPT4o大模型生成。


http://www.ppmy.cn/news/1532193.html

相关文章

TypeScript是基于LLM上层研发的受益者

TypeScript优在哪里 TypeScript是一种由微软开发的开源编程语言,它是JavaScript的一个超集,添加了类型系统和一些其他特性。TypeScript的优势在于: 静态类型检查:TypeScript的最大卖点是它的静态类型系统。这允许开发者在编写代码…

Python 异步编程:如何从入门到精通

在现代编程实践中,异步编程已成为提高程序性能和效率的关键技术之一。Python,作为一种动态、解释型的高级编程语言,提供了强大的异步编程能力。本文将带你从异步编程的基础知识入手,逐步深入到实际应用,探索Python异步…

Groupby_SQL和pandas等效例子

在SQL中使用PARTITION BY与在Pandas中使用groupby().agg()有一定的相似性,因为它们都是用于分组数据并对每组应用某种聚合函数。 例一 下面将给出一个SQL的例子和等效的Pandas代码。 SQL 示例 假设我们有一个名为employees的表,包含以下列&#xff1a…

javascript手写实现instanceof函数 介绍判断数组的方法

1 手写实现instanceof函数 instanceof 是 JavaScript 中的一个运算符,用于检测一个对象是否是某个构造函数的实例。它的语法是 object instanceof Constructor,如果 object 是 Constructor.prototype 的一个实例,返回 true,否则返…

LSTM预测未来30天销售额

加入深度实战社区:www.zzgcz.com,免费学习所有深度学习实战项目。 1. 项目简介 本项目旨在利用深度学习中的长短期记忆网络(LSTM)来预测未来30天的销售额。LSTM模型能够处理时序数据中的长期依赖问题,因此在销售额预测这类涉及时…

k8s_资源管理介绍

资源管理介绍 在k8s中,所有内容都抽象成资源,用户需要通过操作资源来管理k8s k8s本身就是一个集群系统,用户可以在集群中部署服务,在k8s集群中运行一个个的容器,将指定的程序部署到容器中 k8s最小的管理单元是pod&…

【.NET 8 实战--孢子记账--从单体到微服务】--用户(删除/修改/查询)

本文将继续讲解和用户相关的接口,包括用户的删除、修改和查询 一、需求 我们先来看一下这篇文章的目标也就是需求表: 编号需求标题需求内容1删除对某一个用户进行删除,注意:只能删除用户,和用户相关联的数据不能删除…

【C语言内存管理】第七章 内存管理中的常见问题

第七章 内存管理中的常见问题 内存管理是C语言编程中的一个关键环节,错误的内存操作容易导致严重的程序漏洞和系统崩溃问题。以下详细讨论内存管理中的常见问题,并提供相应的代码示例和预防方法。 1. 内存泄漏 内存泄漏是指程序在堆上分配内存后&…