【Datawhale AI 夏令营】讯飞“基于术语词典干预的机器翻译挑战赛”

news/2024/9/13 22:34:30/ 标签: 人工智能, 机器翻译, 自然语言处理

背景

机器翻译具有悠长的发展历史,目前主流的机器翻译方法为神经网络翻译,如LSTM和transformer。在特定领域或行业中,由于机器翻译难以保证术语的一致性,导致翻译效果还不够理想。对于术语名词、人名地名等机器翻译不准确的结果,可以通过术语词典进行纠正,避免了混淆或歧义,最大限度提高翻译质量。

任务

基于术语词典干预的机器翻译挑战赛选择以英文为源语言,中文为目标语言的机器翻译。本次大赛除英文到中文的双语数据,还提供英中对照的术语词典。参赛队伍需要基于提供的训练数据样本从多语言机器翻译模型的构建与训练,并基于测试集以及术语词典,提供最终的翻译结果。

Baseline代码解读

首先导入相应的包

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchtext.data.utils import get_tokenizer
from collections import Counter
import random
from torch.utils.data import Subset, DataLoader
import time

随后定义数据集、Decoder类、Encoder类、Seq2seq类

# 定义数据集类
# 修改TranslationDataset类以处理术语
class TranslationDataset(Dataset):def __init__(self, filename, terminology):self.data = []with open(filename, 'r', encoding='utf-8') as f:for line in f:en, zh = line.strip().split('\t')self.data.append((en, zh))self.terminology = terminology# 创建词汇表,注意这里需要确保术语词典中的词也被包含在词汇表中self.en_tokenizer = get_tokenizer('basic_english')self.zh_tokenizer = list  # 使用字符级分词en_vocab = Counter(self.terminology.keys())  # 确保术语在词汇表中zh_vocab = Counter()for en, zh in self.data:en_vocab.update(self.en_tokenizer(en))zh_vocab.update(self.zh_tokenizer(zh))# 添加术语到词汇表self.en_vocab = ['<pad>', '<sos>', '<eos>'] + list(self.terminology.keys()) + [word for word, _ in en_vocab.most_common(10000)]self.zh_vocab = ['<pad>', '<sos>', '<eos>'] + [word for word, _ in zh_vocab.most_common(10000)]self.en_word2idx = {word: idx for idx, word in enumerate(self.en_vocab)}self.zh_word2idx = {word: idx for idx, word in enumerate(self.zh_vocab)}def __len__(self):return len(self.data)def __getitem__(self, idx):en, zh = self.data[idx]en_tensor = torch.tensor([self.en_word2idx.get(word, self.en_word2idx['<sos>']) for word in self.en_tokenizer(en)] + [self.en_word2idx['<eos>']])zh_tensor = torch.tensor([self.zh_word2idx.get(word, self.zh_word2idx['<sos>']) for word in self.zh_tokenizer(zh)] + [self.zh_word2idx['<eos>']])return en_tensor, zh_tensordef collate_fn(batch):en_batch, zh_batch = [], []for en_item, zh_item in batch:en_batch.append(en_item)zh_batch.append(zh_item)# 对英文和中文序列分别进行填充en_batch = nn.utils.rnn.pad_sequence(en_batch, padding_value=0, batch_first=True)zh_batch = nn.utils.rnn.pad_sequence(zh_batch, padding_value=0, batch_first=True)return en_batch, zh_batch
class Encoder(nn.Module):def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):super().__init__()self.embedding = nn.Embedding(input_dim, emb_dim)self.rnn = nn.GRU(emb_dim, hid_dim, n_layers, dropout=dropout, batch_first=True)self.dropout = nn.Dropout(dropout)def forward(self, src):# src shape: [batch_size, src_len]embedded = self.dropout(self.embedding(src))# embedded shape: [batch_size, src_len, emb_dim]outputs, hidden = self.rnn(embedded)# outputs shape: [batch_size, src_len, hid_dim]# hidden shape: [n_layers, batch_size, hid_dim]return outputs, hiddenclass Decoder(nn.Module):def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):super().__init__()self.output_dim = output_dimself.embedding = nn.Embedding(output_dim, emb_dim)self.rnn = nn.GRU(emb_dim, hid_dim, n_layers, dropout=dropout, batch_first=True)self.fc_out = nn.Linear(hid_dim, output_dim)self.dropout = nn.Dropout(dropout)def forward(self, input, hidden):# input shape: [batch_size, 1]# hidden shape: [n_layers, batch_size, hid_dim]embedded = self.dropout(self.embedding(input))# embedded shape: [batch_size, 1, emb_dim]output, hidden = self.rnn(embedded, hidden)# output shape: [batch_size, 1, hid_dim]# hidden shape: [n_layers, batch_size, hid_dim]prediction = self.fc_out(output.squeeze(1))# prediction shape: [batch_size, output_dim]return prediction, hiddenclass Seq2Seq(nn.Module):def __init__(self, encoder, decoder, device):super().__init__()self.encoder = encoderself.decoder = decoderself.device = devicedef forward(self, src, trg, teacher_forcing_ratio=0.5):# src shape: [batch_size, src_len]# trg shape: [batch_size, trg_len]batch_size = src.shape[0]trg_len = trg.shape[1]trg_vocab_size = self.decoder.output_dimoutputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device)_, hidden = self.encoder(src)input = trg[:, 0].unsqueeze(1)  # Start tokenfor t in range(1, trg_len):output, hidden = self.decoder(input, hidden)outputs[:, t, :] = outputteacher_force = random.random() < teacher_forcing_ratiotop1 = output.argmax(1)input = trg[:, t].unsqueeze(1) if teacher_force else top1.unsqueeze(1)return outputs

增加术语词典

# 新增术语词典加载部分
def load_terminology_dictionary(dict_file):terminology = {}with open(dict_file, 'r', encoding='utf-8') as f:for line in f:en_term, ch_term = line.strip().split('\t')terminology[en_term] = ch_termreturn terminology

训练模型

def train(model, iterator, optimizer, criterion, clip):model.train()epoch_loss = 0for i, (src, trg) in enumerate(iterator):src, trg = src.to(device), trg.to(device)optimizer.zero_grad()output = model(src, trg)output_dim = output.shape[-1]output = output[:, 1:].contiguous().view(-1, output_dim)trg = trg[:, 1:].contiguous().view(-1)loss = criterion(output, trg)loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), clip)optimizer.step()epoch_loss += loss.item()return epoch_loss / len(iterator)

主函数,设置批次大小和数据量

# 主函数
if __name__ == '__main__':start_time = time.time()  # 开始计时device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')#terminology = load_terminology_dictionary('../dataset/en-zh.dic')terminology = load_terminology_dictionary('../dataset/en-zh.dic')# 加载数据dataset = TranslationDataset('../dataset/train.txt',terminology = terminology)# 选择数据集的前N个样本进行训练N = 1000  #int(len(dataset) * 1)  # 或者你可以设置为数据集大小的一定比例,如 int(len(dataset) * 0.1)subset_indices = list(range(N))subset_dataset = Subset(dataset, subset_indices)train_loader = DataLoader(subset_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)# 定义模型参数INPUT_DIM = len(dataset.en_vocab)OUTPUT_DIM = len(dataset.zh_vocab)ENC_EMB_DIM = 256DEC_EMB_DIM = 256HID_DIM = 512N_LAYERS = 2ENC_DROPOUT = 0.5DEC_DROPOUT = 0.5# 初始化模型enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT)dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT)model = Seq2Seq(enc, dec, device).to(device)# 定义优化器和损失函数optimizer = optim.Adam(model.parameters())criterion = nn.CrossEntropyLoss(ignore_index=dataset.zh_word2idx['<pad>'])# 训练模型N_EPOCHS = 10CLIP = 1for epoch in range(N_EPOCHS):train_loss = train(model, train_loader, optimizer, criterion, CLIP)print(f'Epoch: {epoch+1:02} | Train Loss: {train_loss:.3f}')# 在训练循环结束后保存模型torch.save(model.state_dict(), './translation_model_GRU.pth')end_time = time.time()  # 结束计时# 计算并打印运行时间elapsed_time_minute = (end_time - start_time)/60print(f"Total running time: {elapsed_time_minute:.2f} minutes")

由于没有对代码进行任何修改,所以效果并不好

之后尝试修改N以及NEPOCH参数,来降低损失,从而提高分数


http://www.ppmy.cn/news/1475291.html

相关文章

大模型最新黑书:基于GPT-3、ChatGPT、GPT-4等Transformer架构的自然语言处理 PDF

今天给大家推荐一本丹尼斯罗斯曼(Denis Rothman)编写的关于大语言模型&#xff08;LLM&#xff09;权威教程<<大模型应用解决方案> 基于GPT-3、ChatGPT、GPT-4等Transformer架构的自然语言处理>&#xff01;Google工程总监Antonio Gulli作序&#xff0c;这含金量不…

1509.三次操作后最大值与最小值的最小差

1.题目描述 给你一个数组 nums 。 每次操作你可以选择 nums 中的任意一个元素并将它改成 任意值 。 在 执行最多三次移动后 &#xff0c;返回 nums 中最大值与最小值的最小差值。 示例 1&#xff1a; 输入&#xff1a;nums [5,3,2,4] 输出&#xff1a;0 解释&#xff1a;我们最…

2024年浙江省高考分数一分一段数据可视化

下图根据 2024 年浙江高考一分一段表绘制&#xff0c;可以看到&#xff0c;竞争最激烈的分数区间在620分到480分之间。 不过&#xff0c;浙江是考两次取最大&#xff0c;不是很有代表性。看看湖北的数据&#xff0c;580分到400分的区段都很卷。另外&#xff0c;从这个图也可以…

QT5.12.9 通过MinGW64 / MinGW32 cmake编译Opencv4.5.1

一、安装前准备: 1.安装QT,QT5.12.9官方下载链接:https://download.qt.io/archive/qt/5.12/5.12.9/ QT安装教程:https://blog.csdn.net/Mark_md/article/details/108614209 如果电脑是64位就编译器选择MinGW64,32位就选择MinGW32,我的是MinGW64。 2.opencv源码下载:h…

SchedulerLock分布式定时任务锁

1.pom中引入依赖&#xff0c;这里使用redis作为锁 <dependency><groupId>net.javacrumbs.shedlock</groupId><artifactId>shedlock-spring</artifactId><version>4.12.0</version></dependency><dependency><groupId…

Redis在项目中的17种使用场景

Redis 是一个开源的高性能键值对数据库&#xff0c;它以其内存中数据存储、键过期策略、持久化、事务、丰富的数据类型支持以及原子操作等特性&#xff0c;在许多项目中扮演着关键角色。以下是V哥整理的17个Redis在项目中常见的使用场景&#xff1a; 缓存&#xff1a;Redis 可以…

PHP全功能微信投票迷你平台系统小程序源码

&#x1f525;让决策变得超简单&#xff01;&#x1f389; &#x1f680;【一键创建&#xff0c;秒速启动】 嘿小伙伴们&#xff0c;你还在为组织投票而手忙脚乱吗&#xff1f;来试试这款全功能投票迷你微信小程序吧&#xff01;只需轻轻一点&#xff0c;无论是班级选举、社团…

硅纪元AI应用推荐 | 百度橙篇成新宠,能写万字长文

“硅纪元AI应用推荐”栏目&#xff0c;为您精选最新、最实用的人工智能应用&#xff0c;无论您是AI发烧友还是新手&#xff0c;都能在这里找到提升生活和工作的利器。与我们一起探索AI的无限可能&#xff0c;开启智慧新时代&#xff01; 百度橙篇&#xff0c;作为百度公司在202…

Python练习题(3)

1.使用requests模块获取这个json文件http://java-api.super-yx.com/html/hello.json 2.将获取到的json转为dict 3.将dict保存为hello.json文件 4.用文件流写一个copy(src,dst)函数,复制hello.json到C:\hello.json import requests import jsondef copy(src, dst):read_file o…

【泛型】学习笔记

1.工作中使用反射去创建对象 例子1Getterprivate int type;private Class<? extends AbstractActivity> clazz;ActivityType(int type, Class<? extends AbstractActivity> clazz) {this.type type;this.clazz clazz;}public AbstractActivity newInstance(Ac…

Spark底层原理:案例解析(第34天)

系列文章目录 一、Spark架构设计概述 二、Spark核心组件 三、Spark架构设计举例分析 四、Job调度流程详解 五、Spark交互流程详解 文章目录 系列文章目录前言一、Spark架构设计概述1. 集群资源管理器&#xff08;Cluster Manager&#xff09;2. 工作节点&#xff08;Worker No…

RabbitMQ中常用的三种交换机【Fanout、Direct、Topic】

目录 1、引入 2、Fanout交换机 案例&#xff1a;利用SpringAMQP演示Fanout交换机的使用 3、Direct交换机 案例&#xff1a;利用SpringAMQP演示Direct交换机的使用 4、Topic交换机 案例&#xff1a;利用SpringAMQP演示Topic交换机的使用 1、引入 真实的生产环境都会经过e…

mysql之导入测试数据

运维时经常要这样&#xff1a;mysql改表名&#xff0c;创建一个一样的表不含数据&#xff0c;复制旧表几条数据进去 改变表的名字&#xff1a; RENAME TABLE old_table_name TO new_table_name; 这将把原来的表old_table_name重命名为new_table_name。 创建一个一样的表结构…

MES实时监控食品加工过程中各环节的安全

在实时监控食品加工过程中各环节的安全风险方面&#xff0c;万界星空科技的MES&#xff08;制造执行系统&#xff09;解决方案发挥了至关重要的作用。以下是具体如何通过MES系统实现实时监控食品加工过程中各环节安全风险的详细阐述&#xff1a; 一、集成传感器与实时监控 MES…

1.1 - Android启动概览

第一章 系统启动流程分析 第一节 Android启动概览 Android启动概览可以从多个方面进行描述&#xff0c;包括启动流程、关键组件及其作用等。以下是一个详细的Android启动概览&#xff1a; 一、启动流程 Android设备的启动流程大致可以分为以下几个阶段&#xff1a; 上电与引导…

数据结构实操代码题~考研

作者主页: 知孤云出岫 目录 数据结构实操代码题题目一&#xff1a;实现栈&#xff08;Stack&#xff09;题目二&#xff1a;实现队列&#xff08;Queue&#xff09;题目三&#xff1a;实现二叉搜索树&#xff08;BST&#xff09;题目四&#xff1a;实现链表&#xff08;Linked…

虚幻引擎ue5如何调节物体锚点

当发现锚点不在物体上时&#xff0c;如何调节瞄点在物体上。 步骤1&#xff1a;按住鼠标中键拖动锚点&#xff0c;在透视图中多次调节锚点位置。 步骤2:在物体上点击鼠标右键点击-》锚定--》“设置为枢轴偏移”即可。

2974.最小数字游戏

1.题目描述 你有一个下标从 0 开始、长度为 偶数 的整数数组 nums &#xff0c;同时还有一个空数组 arr 。Alice 和 Bob 决定玩一个游戏&#xff0c;游戏中每一轮 Alice 和 Bob 都会各自执行一次操作。游戏规则如下&#xff1a; 每一轮&#xff0c;Alice 先从 nums 中移除一个 …

机器学习扫盲:优化算法、损失函数、评估指标、激活函数、网络架构

专栏介绍 1.专栏面向零基础或基础较差的机器学习入门的读者朋友,旨在利用实际代码案例和通俗化文字说明,使读者朋友快速上手机器学习及其相关知识体系。 2.专栏内容上包括数据采集、数据读写、数据预处理、分类\回归\聚类算法、可视化等技术。 3.需要强调的是,专栏仅介绍主…

MySQL8之mysql-community-server-debug的作用

mysql-community-server-debug是MySQL社区服务器的一个调试版本&#xff0c;它主要用于开发和调试MySQL数据库服务器。与标准的MySQL社区服务器版本相比&#xff0c;调试版本包含了额外的调试信息和工具&#xff0c;以帮助开发人员和数据库管理员诊断和解决MySQL服务器中的问题…