Spider 数据集上实现nlp2sql训练任务

news/2025/2/10 16:40:44/

NLP2SQL(自然语言处理到 SQL 查询的转换)是一个重要的自然语言处理(NLP)任务,其目标是将用户的自然语言问题转换为相应的 SQL 查询。这一任务在许多场景下具有广泛的应用,尤其是在与数据库交互的场景中,例如数据分析、业务智能和问答系统。

任务目标
  • 理解自然语言: 理解用户输入的自然语言问题,包括意图、实体和上下文。
  • 生成 SQL 查询: 将理解后的信息转换为正确的 SQL 查询,以从数据库中检索所需的数据。

例如

输入: 用户的自然语言问题,“获取 Gelderland 区的总人口。”

输出: 对应的 SQL 查询,SELECT population FROM districts WHERE name = 'Gelderland';

Spider 是一个难度最大数据集

耶鲁大学在2018年新提出的一个大规模的NL2SQL(Text-to-SQL)数据集。
该数据集包含了10,181条自然语言问句、分布在200个独立数据库中的5,693条SQL,内容覆盖了138个不同的领域。
涉及的SQL语法最全面,是目前难度最大的NL2SQL数据集。

下载查看spider数据集内容

Question 1: How many singers do we have ? ||| concert_singer
SQL: select count(*) from singer

Question 2: What is the total number of singers ? ||| concert_singer
SQL: select count(*) from singer

Question 3: Show name , country , age for all singers ordered by age from the oldest to the youngest . ||| concert_singer
SQL: select name , country , age from singer order by age desc

...

首先需要转换为Spider的标准格式(参考tables.jsontrain.json):

{"db_id": "concert_singer","question": "Show name, country, age...","query": "SELECT name, country, age FROM singer ORDER BY age DESC","schema": {"table_names": ["singer"],"column_names": [[0, "name", "text"],[0, "country", "text"],[0, "age", "int"]]}
}

拆分为table.json的原因可能涉及到数据组织和重用。每个数据库的结构(表、列、外键)在多个问题中都会被重复使用。如果每个问题都附带完整的schema信息,会导致数据冗余,增加存储和处理的开销。所以,将schema单独存储为table.json,可以让不同的数据条目引用同一个数据库模式,减少重复数据。拆分后的结构需要更高效的数据管理,例如在训练模型时,根据每个问题的db_id去table.json中查找对应的schema信息。这样做的好处是当多个问题属于同一个数据库时,不需要每次都重复加载schema提高了效率。

column_names 表示数据库表中每一列的详细信息。具体来说,column_names 是一个列表,其中每个元素都是一个包含三个部分的子列表:

  1. 表索引(0):表示该列属于哪个表。在这个例子中,所有列都属于第一个表(索引为 0)。
  2. 列名("name"、"country"、"age"):表示列的名称。
  3. 数据类型("text"、"int"):表示该列的数据类型,例如文本(text)或整数(int)。

实现下面逻辑转换原始数据

def extract_columns_from_sql(sql):# 使用正则表达式匹配 SELECT 语句中的列名match = re.search(r"SELECT\s+(.*?)\s+FROM", sql, re.IGNORECASE)if match:# 提取列名columns = match.group(1).split(",")# 构建 column_names 列表column_names = []for index, column in enumerate(columns):column = column.strip()  # 去除多余的空格data_type = "text"  # 默认数据类型为 text,可以根据需要修改# 添加到 column_names 列表,假设所有列类型为 textcolumn_names.append([0, column, data_type])return column_namesreturn []# 从 dev.sql 文件读取数据
def load_sql_data(file_path):data_list = []with open(file_path, 'r', encoding='utf-8') as f:  # 指定编码为 UTF-8lines = f.readlines()for i in range(0, len(lines), 3):  # 每三行一组question_line = lines[i].strip()sql_line = lines[i + 1].strip()if not question_line or not sql_line:continue# 提取问题和 SQLquestion = question_line.split(': ', 1)[1].strip()  # 获取问题内容sql = sql_line.split(': ', 1)[1].strip()  # 获取 SQL 查询# 提取表名db_id = question_line.split('|||')[-1].strip()  # 从问题行获取表名question = question.split('|||')[0].strip()target_sql = preprocess(question, db_id, sql)data_list.append({"input_text": f"Translate to SQL: {question} [SEP] Tables: {db_id}","target_sql": json.dumps(target_sql)  # 将目标 SQL 转换为 JSON 格式字符串})return data_list

选择Tokenizer.from_pretrained("t5-base") 是用于加载 T5(Text-to-Text Transfer Transformer)模型的分词器。T5 是一个强大的自然语言处理模型,能够处理各种文本任务(如翻译、摘要、问答等),并且将所有任务视为文本到文本的转换。

from transformers import T5Tokenizertokenizer = T5Tokenizer.from_pretrained("t5-base")def preprocess(question, db_id, sql):# 提取列名column_names = extract_columns_from_sql(sql)# 构建目标格式target_sql = {"db_id": db_id,"question": question,"query": sql,"schema": {"table_names": [db_id],"column_names": column_names}}return target_sql# 示例数据
question = "Show name, country, age for all singers ordered by age from the oldest to the youngest."
schema = "singer(name, country, age)"
sql = "SELECT name, country, age FROM singer ORDER BY age DESC"input_text, target_sql = preprocess(question, schema, sql)
# input_text = "Translate to SQL: Show name... [SEP] Tables: singer(name, country, age)"
# target_sql = "select name, country, age from singer order by age desc"
print('input_text', input_text)
print('target_sql', target_sql)

所有nlp任务都涉及的需要token化,使用t5-base 做tokenize

def tokenize_function(examples):model_inputs = tokenizer(examples["input_text"],max_length=512,truncation=True,padding="max_length")with tokenizer.as_target_tokenizer():labels = tokenizer(examples["target_sql"],max_length=512,truncation=True,padding="max_length")model_inputs["labels"] = labels["input_ids"]return model_inputs

使用 tokenizer.as_target_tokenizer() 上下文管理器,确保目标文本(即 SQL 查询)被正确处理。目标文本也经过编码,转换为 token IDs,并同样进行填充和截断。将目标文本的编码结果(token IDs)存储在 model_inputs["labels"] 中。这是模型在训练时需要的输出,用于计算损失。最终返回一个字典 model_inputs,它包含了模型的输入和对应的标签。这种结构使得模型在训练时可以直接使用。

最后组织下训练代码

tokenized_datasets = dataset.map(tokenize_function, batched=True)# 加载模型
model = T5ForConditionalGeneration.from_pretrained("t5-base")# 训练参数
training_args = Seq2SeqTrainingArguments(output_dir="./results",evaluation_strategy="epoch",learning_rate=3e-5,per_device_train_batch_size=8,per_device_eval_batch_size=8,num_train_epochs=100,predict_with_generate=True,run_name="spider"
)# 开始训练
trainer = Seq2SeqTrainer(model=model,args=training_args,train_dataset=tokenized_datasets["train"] if 'train' in tokenized_datasets else tokenized_datasets,eval_dataset=tokenized_datasets["test"] if 'test' in tokenized_datasets else None,data_collator=DataCollatorForSeq2Seq(tokenizer)
)trainer.train()

这里使用的是Seq2SeqTrainer, 它是 Hugging Face 的 transformers 库中用于序列到序列(Seq2Seq)任务的训练器。它为处理诸如翻译、文本生成和问答等任务提供了一个高层次的接口,简化了训练过程。以下是 Seq2SeqTrainer 的主要功能和特点:

  1. 简化训练流程Seq2SeqTrainer 封装了许多常见的训练步骤,如数据加载、模型训练、评估和预测,使得用户可以更专注于模型和数据,而不必处理繁琐的训练细节。

  2. 支持多种训练参数: 通过 Seq2SeqTrainingArguments 类,可以灵活配置训练参数,如学习率、批量大小、训练轮数、评估策略等。

  3. 自动处理填充和截断: 在处理输入和输出序列时,Seq2SeqTrainer 可以自动填充和截断序列,以确保它们适应模型的输入要求。

  4. 集成评估和监控: 支持在训练过程中进行模型评估,并可以根据评估指标(如损失)监控训练进度。用户可以设置评估频率和评估数据集

开始训练,进行100次epoch

训练监控在 Weights & Biases ,Seq2SeqTrainer 能够向 Weights & Biases (wandb) 传输训练监控数据,主要是因为它内置了与 wandb 的集成。以下是一些关键点,解释了这一过程:

  1. 自动集成:当你使用 Seq2SeqTrainer 时,它会自动检测 wandb 的安装并在初始化时配置相关设置。这意味着你无需手动设置 wandb。

  2. 回调功能Trainer 类提供了回调功能,可以在训练过程中记录各种指标(如损失、准确率等)。这些指标会被自动发送到 wandb。

  3. 配置管理training_args 中的参数可以指定 wandb 的项目名称、运行名称等,从而更好地组织和管理实验。

  4. 训练循环:在每个训练和评估周期结束时,Trainer 会调用相应的回调函数,将重要的训练信息(如损失、学习率等)记录到 wandb。

  5. 可视化:通过 wandb,你可以实时监控训练过程,包括损失曲线、模型性能等,帮助你更好地理解模型的训练动态。

多次试验还可以比较训练性能

训练结束, 损失收敛到0.05410315271151268

{'eval_loss': 0.008576861582696438, 'eval_runtime': 1.3883, 'eval_samples_per_second': 74.912, 'eval_steps_per_second': 5.042, 'epoch': 100.0}
{'train_runtime': 2914.0548, 'train_samples_per_second': 31.914, 'train_steps_per_second': 2.025, 'train_loss': 0.05410315271151268, 'epoch': 100.0}
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5900/5900 [48:31<00:00,  2.03it/s]
wandb:
wandb: 🚀 View run spider at: https://wandb.ai/chenruithinking-4th-paradigm/huggingface/runs/dkccvpp4
wandb: Find logs at: wandb/run-20250207_112702-dkccvpp4/logs

测试下预测能力

import os
from transformers import T5Tokenizer, T5ForConditionalGeneration# 设置 NCCL 环境变量
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"# 加载分词器
tokenizer = T5Tokenizer.from_pretrained("t5-base")model = T5ForConditionalGeneration.from_pretrained("./results/t5-sql-model")
tokenizer.save_pretrained("./results/t5-sql-model")def generate_sql(question, db_id):input_text = f"Translate to SQL: {question} [SEP] Tables: {db_id}"input_ids = tokenizer.encode(input_text, return_tensors="pt")  # 使▒~T▒ PyTorch ▒~Z~D▒| ▒~G~O▒| ▒▒~Ooutput = model.generate(input_ids,max_length=512,num_beams=5,  # 或者尝试其他解码策略early_stopping=True)print('output', output)generated_sql = tokenizer.decode(output[0], skip_special_tokens=True)return generated_sqlquestion = "How many singers do we have ?"
db_id = "concert_singer"
evaluation_output = generate_sql(question, db_id)
print("evaluation_output:", evaluation_output)

输出结果

evaluation_output: "db_id": "concert_singer", "question": "How many singers do we have ?", "query": "select count(*) from singer", "schema": "table_names": ["concert_singer"], "column_names": [[0, "count(*)", "text"]]


http://www.ppmy.cn/news/1570912.html

相关文章

【Linux】之【Get√】nmcli device wifi list 与 wpa_cli scan 和 wpa_cli scan_result 区别

nmcli device wifi list 是 NetworkManager 的命令行工具 nmcli 的一部分&#xff0c;它用于列出当前可用的无线网络。它的作用和 wpa_cli 的扫描功能类似&#xff0c;但有一些不同点。 1. nmcli device wifi list 功能&#xff1a; nmcli device wifi list 命令用于显示当前…

数据库如何清空重置索引,MySQL PostgreSQL SQLite SQL Server

要彻底清空数据库并重置自增ID&#xff08;索引&#xff09;&#xff0c;具体操作取决于您使用的数据库管理系统(DBMS)。以下是针对几种常见数据库的说明&#xff1a; MySQL 对于MySQL&#xff0c;您可以使用如下命令来删除表中的所有数据&#xff0c;并将自增计数器重置。 T…

边缘计算网关驱动智慧煤矿智能升级——实时预警、低延时决策与数字孪生护航矿山安全高效运营

迈向智能化煤矿管理新时代 工业物联网和边缘计算技术的迅猛发展&#xff0c;煤矿安全生产与高效运营正迎来全新变革。传统煤矿监控模式由于现场环境复杂、数据采集和传输延时较高&#xff0c;已难以满足当下高标准的安全管理要求。为此&#xff0c;借助边缘计算网关的实时数据…

基于springboot+vue的文物管理系统的设计与实现

开发语言&#xff1a;Java框架&#xff1a;springbootJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包&#xff1a;…

HiveQL命令(二)- 数据表操作

文章目录 前言一、数据表操作1. 创建表1.1 语法及解释1.2 内部表1.2.1 创建内部表示例 1.3 外部表1.3.1 创建外部表示例 2. 查看表2.1 查看当前数据库中所有表2.2 查看表信息2.2.1 语法及解释2.2.2 查看表信息示例 3. 修改表3.1 重命名表3.1.1 语法3.1.2 示例 3.2 修改表属性3.…

Spring Cloud工程搭建

目录 工程搭建 搭建父子工程 创建父工程 Spring Cloud版本 创建子项目-订单服务 声明项⽬依赖 和 项⽬构建插件 创建子项目-商品服务 声明项⽬依赖 和 项⽬构建插件 工程搭建 因为拆分成了微服务&#xff0c;所以要拆分出多个项目&#xff0c;但是IDEA只能一个窗口有一…

Matplotlib基础01( 基本绘图函数/多图布局/图形嵌套/绘图属性)

Matplotlib基础 Matplotlib是一个用于绘制静态、动态和交互式图表的Python库&#xff0c;广泛应用于数据可视化领域。它是Python中最常用的绘图库之一&#xff0c;提供了多种功能&#xff0c;可以生成高质量的图表。 Matplotlib是数据分析、机器学习等领域数据可视化的重要工…

apisix网关ip-restriction插件使用说明

ip-restriction插件可以在网关层进行客户端请求ip拦截。 当然了&#xff0c;一般不推荐使用该方法&#xff0c;专业的事专业工具做。建议有条件&#xff0c;还是上防火墙或者waf来做。 官方文档&#xff1a;ip-restriction | Apache APISIX -- Cloud-Native API Gateway whit…