llama factory 是如何加载数据集 通过对数据集加载的代码的理解编写自定义数据集训练代码

news/2024/11/28 17:55:18/

第一层从训练代码追踪到以下代码

def get_dataset(tokenizer: "PreTrainedTokenizer",model_args: "ModelArguments",data_args: "DataArguments",training_args: "Seq2SeqTrainingArguments",stage: Literal["pt", "sft", "rm", "ppo"],# split: Optional[str] = "train", # TODO: add split
) -> Union["Dataset", "IterableDataset"]:template = get_template_and_fix_tokenizer(tokenizer, data_args.template)if data_args.train_on_prompt and template.efficient_eos:raise ValueError("Current template does not support `train_on_prompt`.")# Load from cacheif data_args.cache_path is not None:if os.path.exists(data_args.cache_path):logger.warning("Loading dataset from disk will ignore other data arguments.")dataset = load_from_disk(data_args.cache_path)if data_args.streaming:dataset = dataset.to_iterable_dataset()return datasetwith training_args.main_process_first(desc="load dataset"):all_datasets = []for dataset_attr in get_dataset_list(data_args):all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))dataset = merge_dataset(all_datasets, data_args, training_args)with training_args.main_process_first(desc="pre-process dataset"):preprocess_func, print_function = get_preprocess_and_print_func(tokenizer, template, data_args, training_args, stage)column_names = list(next(iter(dataset)).keys())kwargs = {}if not data_args.streaming:kwargs = dict(num_proc=data_args.preprocessing_num_workers,load_from_cache_file=(not data_args.overwrite_cache),desc="Running tokenizer on dataset",)dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):if training_args.should_save:dataset.save_to_disk(data_args.cache_path)logger.info("Dataset cache saved at {}.".format(data_args.cache_path))if training_args.should_log:try:print_function(next(iter(dataset)))except StopIteration:raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")return dataset

这段Python代码定义了一个名为get_dataset的函数,其目的是根据给定的参数加载和预处理一个数据集。下面是该函数的逐步解读:

  1. 函数参数
    • tokenizer: 一个预训练的tokenizer对象,用于处理文本数据。
    • model_args, data_args, training_args: 分别包含模型、数据和训练的参数。
    • stage: 指定当前的训练阶段,如"pt"(预训练)、“sft”(监督微调)、“rm”(奖励模型训练)或"ppo"(PPO训练)。
    • split: 指定数据集的分割,默认为"train"。
  2. 函数逻辑
    • 首先,获取模板并修复tokenizer(get_template_and_fix_tokenizer函数未在代码中给出)。
    • 检查是否支持train_on_prompt功能,如果不支持则抛出错误。
    • 尝试从磁盘加载数据集。如果设置了cache_path且该路径下数据集存在,则直接从磁盘加载,忽略其他数据参数。如果需要流式传输,则将数据集转换为可迭代的。
    • 如果数据集不存在或需要重新生成,则使用get_dataset_list函数获取所有数据集属性,并使用load_single_dataset函数为每个属性加载数据集。然后,使用merge_dataset函数合并所有数据集。
    • 对数据集进行预处理。预处理函数preprocess_func和打印函数print_functionget_preprocess_and_print_func函数返回。预处理包括将数据集的每一行映射到tokenizer。如果不在流式传输模式下,还会使用多进程进行预处理。
    • 如果设置了cache_path,并且尚未创建,则将数据集保存到磁盘。
    • 如果需要日志记录,则打印数据集的一个样本。
  3. 函数返回
    返回一个数据集对象,可以是普通的Dataset或可迭代的IterableDataset
    这个函数的主要目的是提供一个统一的接口来加载、合并和预处理数据集,同时支持缓存和流式传输,适用于不同的训练阶段。

第二层 阅读加载单个数据的代码

def load_single_dataset(dataset_attr: "DatasetAttr",model_args: "ModelArguments",data_args: "DataArguments",
):logger.info("Loading dataset {}...".format(dataset_attr))data_path, data_name, data_dir, data_files = None, None, None, Noneif dataset_attr.load_from in ["hf_hub", "ms_hub"]:data_path = dataset_attr.dataset_namedata_name = dataset_attr.subsetdata_dir = dataset_attr.folderelif dataset_attr.load_from == "script":data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)data_name = dataset_attr.subsetdata_dir = dataset_attr.folderelif dataset_attr.load_from == "file":data_files = []local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)if os.path.isdir(local_path):  # is directoryfor file_name in os.listdir(local_path):data_files.append(os.path.join(local_path, file_name))if data_path is None:data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):raise ValueError("File types should be identical.")elif os.path.isfile(local_path):  # is filedata_files.append(local_path)data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)else:raise ValueError("File not found.")if data_path is None:raise ValueError("File extension must be txt, csv, json or jsonl.")checksum(data_files, dataset_attr.file_sha1)else:raise NotImplementedErrorif dataset_attr.load_from == "ms_hub":try:from modelscope import MsDatasetfrom modelscope.utils.config_ds import MS_DATASETS_CACHEcache_dir = model_args.cache_dir or MS_DATASETS_CACHEdataset = MsDataset.load(dataset_name=data_path,subset_name=data_name,data_dir=data_dir,data_files=data_files,split=data_args.split,cache_dir=cache_dir,token=model_args.ms_hub_token,use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),).to_hf_dataset()except ImportError:raise ImportError("Please install modelscope via `pip install modelscope -U`")else:if "trust_remote_code" in inspect.signature(load_dataset).parameters:  # for datasets==2.16.0kwargs = {"trust_remote_code": True}else:kwargs = {}dataset = load_dataset(path=data_path,name=data_name,data_dir=data_dir,data_files=data_files,split=data_args.split,cache_dir=model_args.cache_dir,token=model_args.hf_hub_token,streaming=(data_args.streaming and (dataset_attr.load_from != "file")),**kwargs,)if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=Truedataset = dataset.to_iterable_dataset()  # TODO: add num shards parameterif data_args.max_samples is not None:  # truncate datasetnum_samples = min(data_args.max_samples, len(dataset))dataset = dataset.select(range(num_samples))return align_dataset(dataset, dataset_attr, data_args)

是一个独立文件读取的Python函数,用于根据提供的参数加载数据集。下面是该函数的中文解释:

  1. 日志记录:记录开始加载数据集的信息。
  2. 确定数据路径和名称:根据数据集的来源(“hf_hub”、“ms_hub”、“script”或“file”),计算数据集文件的正确路径。
  3. 校验和验证:如果数据集是从本地文件加载的,函数会根据dataset_attr中提供的预期值校验文件的有效SHA1校验和。
  4. 数据集加载:使用datasets库中的load_dataset函数加载数据集。加载数据集的参数根据来源和提供的额外参数确定。
  5. 流调整:如果设置了data_args.streaming且数据集是从文件加载的,则将数据集转换为可迭代的,更适合流式传输的数据集。
  6. 数据集截断:如果设置了data_args.max_samples,则截断数据集到指定的样本数。
  7. 对齐数据集:调用align_dataset函数将数据集与dataset_attrdata_args对齐。这个函数在提供的代码中没有定义,所以它的确切行为是未知的。
  8. 返回数据集:返回已加载和处理的数据集。
    请注意,该函数假设存在某些变量和函数,如loggerosinspectload_dataset,这些都是Python代码中的典型内容。此外,align_dataset在提供的代码中被引用,但没有定义,这表明可能还有其他代码定义了这个函数及其行为。

http://www.ppmy.cn/news/1379723.html

相关文章

SpringCloudGateway理论与实践

文章目录 网关介绍为什么需要网关Gateway 使用gateway pom依赖yml 配置重启测试总结 断言过滤器工厂路由过滤器的种类请求头过滤器默认过滤器全局过滤器总结 Gateway解决跨域 网关介绍 Spring Cloud Gateway 是一个基于Spring Framework 5,由Spring Cloud团队开发的…

代码随想录 贪心算法-中等题目-序列问题

376.摆动序列 376. 摆动序列 中等 如果连续数字之间的差严格地在正数和负数之间交替,则数字序列称为 摆动序列 。第一个差(如果存在的话)可能是正数或负数。仅有一个元素或者含两个不等元素的序列也视作摆动序列。 例如, [1, 7…

skynet cluster集群笔记

skynet cluster集群笔记 前言cluster相关方法说明集群设计方案:集群中常遇到的问题:注意事项: 前言 skynet 是一个基于事件驱动的分布式游戏服务器框架,支持构建高性能、高并发的网络程序。在 skynet中,集群是指将多个…

2024年春招程序员个人简历范本(精选5篇|附模板)

HR浏览一份简历也就25秒左右,如果你连「好简历」都没有,怎么能找到好工作呢? 如果你不懂得如何在简历上展示自己,或者觉得怎么改简历都不出彩,那请你一定仔细读完。 Java开发工程师简历范本> 性别 男 年龄 24 学历 本科 张三 专业 计算机科学与技术 毕业院校 …

C++ set 容器

1.6 C set 容器 一般性的 Set 实现而言,是无序的,在 C 中,std::set 是有序的容器,它基于红黑树(Red-Black Tree)实现,并且会根据元素的键值进行排序。因此,std::set 中的元素总是按…

Docker 的资源控制

目录 Docker 的资源控制为什么需要资源控制?控制内存使用限制 CPU 使用磁盘 I/O 控制网络带宽限制实践建议 Docker 资源控制:保障性能与稳定性Docker资源控制概览内存限制CPU限制磁盘 I/O 控制网络带宽管理实际应用 Docker 启动后的 更新资源管理 Docker…

[蓝桥杯]接龙数列(C语言)

目录 题目链接 题目理解 解题思路 完整代码 重难点解答 *dp数组的具体用法 *对于dp[b]dp[a]1>dp[b]?dp[a]1:dp[b]的解释 题目链接 [蓝桥杯 2023 省 B] 接龙数列 - 洛谷 题目理解 这道题让我们求任给的一串数字,若想让其变成接龙数列最少需要删除的数字…

查找字符串在Text文本中的位置

public static Vector3 GetStringPositionAtText(Text text, string strFragment) {int strFragmentIndex text.text.IndexOf(strFragment); //-1表示不包含strFragmentVector3 stringPos Vector3.zero;if (strFragmentIndex > -1){Vector3 firstPos GetCharPositionAtTe…