传知代码-图神经网络长对话理解(论文复现)

代码以及视频讲解

本文所涉及所有资源均在传知代码平台可获取

概述

情感识别是人类对话理解的关键任务。随着多模态数据的概念,如语言、声音和面部表情,任务变得更加具有挑战性。作为典型解决方案,利用全局和局部上下文信息来预测对话中每个单个句子(即话语)的情感标签。具体来说,全局表示可以通过对话级别的跨模态交互建模来捕获。局部表示通常是通过发言者的时间信息或情感转变来推断的,这忽略了话语级别的重要因素。此外,大多数现有方法在统一输入中使用多模态的融合特征,而不利用模态特定的表示。针对这些问题,我们提出了一种名为“关系时序图神经网络与辅助跨模态交互(CORECT)”的新型神经网络框架,它以模态特定的方式有效捕获了对话级别的跨模态交互和话语级别的时序依赖,用于对话理解。大量实验证明了CORECT的有效性,通过在IEMOCAP和CMUMOSEI数据集上取得了多模态ERC任务的最新成果。

模型整体架构

在这里插入图片描述

特征提取

文本采用transformerde方式进行编码
在这里插入图片描述

音频,视频都采用全连接的方式进行编码
在这里插入图片描述

通过添加相应的讲话者嵌入来增强技术增强
在这里插入图片描述

关系时序图卷积网络(RT-GCN)

解读:RT-GCN旨在通过利用话语之间以及话语与其模态之间的多模态图来捕获对话中每个话语的局部上下文信息,关系时序图在一个模块中同时实现了上下文信息,与模态之间的信息的传递。对话中情感识别需要跨模态学习到信息,同时也需要学习上下文的信息,整合成一个模块的作用将两部分并行处理,降低模型的复杂程度,降低训练成本,降低训练难度。

建图方式,模态与模态之间有边相连,对话之间有边相连:

在这里插入图片描述

建图之后,用图transformer融合不同模态,以及不同语句的信息,得到处理之后特征向量:
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

两两交叉模态特征交互

跨模态的异质性经常提高了分析人类语言的难度。利用跨模态交互可能有助于揭示跨模态之间的“不对齐”特性和长期依赖关系。受到这一思想的启发(Tsai等人,2019),我们将配对的跨模态特征交互(P-CM)方法设计到我们提出的用于对话理解的框架中。

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

线性分类器

最后就是根据提取出来的特征进行情感分类了:

在这里插入图片描述

代码修改

这是对话中多模态情感识别(视觉,音频,文本)在数据集IEMOCAP目前为止的SOTA。在离线系统已经取得了相当不错的表现。(离线系统的意思是,是一段已经录制好的视频,而不是事实录制如线上开会)

但是却存在一个问题,输入的数据是已经给定的一个视频,分析某一句话的情感状态的时候,论文的方法使用了过去的信息,也使用了未来的信息,这样会在工业界实时应用场景存在一定的问题。

比如在开线上会议,需要检测开会双方的情绪,不可能用未来将要说的话预测现在的情绪。因为未来的话都还没被说话者说出来,此时,就不能参考到未来的语句来预测现在语句的情感信息。但是原文的方法在数据结构图的构建的时候,连接上了未来语句和现在语句的边,用图神经网络学习了之间的关联。

因此,修改建图方式,不考虑未来的情感信息,重新训练网络,得到了还可以接受的效果,精度大概在82%左右,原文的精度在84%左右,2%精度的牺牲解决了是否能实时的问题其实是值得的。

演示效果

在这里插入图片描述
在这里插入图片描述

核心逻辑

在这里可以粘贴您的核心代码逻辑:

# start#模型核心部分import torch
import torch.nn as nn
import torch.nn.functional as Ffrom .Classifier import Classifier
from .UnimodalEncoder import UnimodalEncoder
from .CrossmodalNet import CrossmodalNet
from .GraphModel import GraphModel
from .functions import multi_concat, feature_packing
import corectlog = corect.utils.get_logger()class CORECT(nn.Module):def __init__(self, args):super(CORECT, self).__init__()self.args = argsself.wp = args.wpself.wf = args.wfself.modalities = args.modalitiesself.n_modals = len(self.modalities)self.use_speaker = args.use_speakerg_dim = args.hidden_sizeh_dim = args.hidden_sizeic_dim = 0if not args.no_gnn:ic_dim = h_dim * self.n_modalsif not args.use_graph_transformer and (args.gcn_conv == "gat_gcn" or args.gcn_conv == "gcn_gat"):ic_dim = ic_dim * 2if args.use_graph_transformer:ic_dim *= args.graph_transformer_nheadsif args.use_crossmodal and self.n_modals > 1:ic_dim += h_dim * self.n_modals * (self.n_modals - 1)if self.args.no_gnn and (not self.args.use_crossmodal or self.n_modals == 1):ic_dim = h_dim * self.n_modalsa_dim = args.dataset_embedding_dims[args.dataset]['a']t_dim = args.dataset_embedding_dims[args.dataset]['t']v_dim = args.dataset_embedding_dims[args.dataset]['v']dataset_label_dict = {"iemocap": {"hap": 0, "sad": 1, "neu": 2, "ang": 3, "exc": 4, "fru": 5},"iemocap_4": {"hap": 0, "sad": 1, "neu": 2, "ang": 3},"mosei": {"Negative": 0, "Positive": 1},}dataset_speaker_dict = {"iemocap": 2,"iemocap_4": 2,"mosei":1,}tag_size = len(dataset_label_dict[args.dataset])self.n_speakers = dataset_speaker_dict[args.dataset]self.wp = args.wpself.wf = args.wfself.device = args.deviceself.encoder = UnimodalEncoder(a_dim, t_dim, v_dim, g_dim, args)self.speaker_embedding = nn.Embedding(self.n_speakers, g_dim)print(f"{args.dataset} speakers: {self.n_speakers}")if not args.no_gnn:self.graph_model = GraphModel(g_dim, h_dim, h_dim, self.device, args)print('CORECT --> Use GNN')if args.use_crossmodal and self.n_modals > 1:self.crossmodal = CrossmodalNet(g_dim, args)print('CORECT --> Use Crossmodal')elif self.n_modals == 1:print('CORECT --> Crossmodal not available when number of modalitiy is 1')self.clf = Classifier(ic_dim, h_dim, tag_size, args)self.rlog = {}def represent(self, data):# Encoding multimodal featurea = data['audio_tensor'] if 'a' in self.modalities else Nonet = data['text_tensor'] if 't' in self.modalities else Nonev = data['visual_tensor'] if 'v' in self.modalities else Nonea, t, v = self.encoder(a, t, v, data['text_len_tensor'])# Speaker embeddingif self.use_speaker:emb = self.speaker_embedding(data['speaker_tensor'])a = a + emb if a != None else Nonet = t + emb if t != None else Nonev = v + emb if v != None else None# Graph constructmultimodal_features = []if a != None:multimodal_features.append(a)if t != None:multimodal_features.append(t)if v != None:multimodal_features.append(v)out_encode = feature_packing(multimodal_features, data['text_len_tensor'])out_encode = multi_concat(out_encode, data['text_len_tensor'], self.n_modals)out = []if not self.args.no_gnn:out_graph = self.graph_model(multimodal_features, data['text_len_tensor'])out.append(out_graph)if self.args.use_crossmodal and self.n_modals > 1:out_cr = self.crossmodal(multimodal_features)out_cr = out_cr.permute(1, 0, 2)lengths = data['text_len_tensor']batch_size = lengths.size(0)cr_feat = []for j in range(batch_size):cur_len = lengths[j].item()cr_feat.append(out_cr[j,:cur_len])cr_feat = torch.cat(cr_feat, dim=0).to(self.device)out.append(cr_feat)if self.args.no_gnn and (not self.args.use_crossmodal or self.n_modals == 1):out = out_encodeelse:out = torch.cat(out, dim=-1)return outdef forward(self, data):graph_out = self.represent(data)out = self.clf(graph_out, data["text_len_tensor"])return outdef get_loss(self, data):graph_out = self.represent(data)loss = self.clf.get_loss(graph_out, data["label_tensor"], data["text_len_tensor"])return lossdef get_log(self):return self.rlog#图神经网络
import torch
import torch.nn as nn
from torch_geometric.nn import RGCNConv, TransformerConvimport corectclass GNN(nn.Module):def __init__(self, g_dim, h1_dim, h2_dim, num_relations, num_modals, args):super(GNN, self).__init__()self.args = argsself.num_modals = num_modalsif args.gcn_conv == "rgcn":print("GNN --> Use RGCN")self.conv1 = RGCNConv(g_dim, h1_dim, num_relations)if args.use_graph_transformer:print("GNN --> Use Graph Transformer")in_dim = h1_dimself.conv2 = TransformerConv(in_dim, h2_dim, heads=args.graph_transformer_nheads, concat=True)self.bn = nn.BatchNorm1d(h2_dim * args.graph_transformer_nheads)def forward(self, node_features, node_type, edge_index, edge_type):if self.args.gcn_conv == "rgcn":x = self.conv1(node_features, edge_index, edge_type)if self.args.use_graph_transformer:x = nn.functional.leaky_relu(self.bn(self.conv2(x, edge_index)))return x

使用方式&部署方式

首先建议安装conda,因为想要复现深度学习的代码,github上不同项目的环境差别太大,同时处理多个项目的时候很麻烦,在这里就不做conda安装的教程了,请自行学习。

安装pytorch:
请到pytorch官网找安装命令,尽量不要直接pip install
https://pytorch.org/get-started/previous-versions/

给大家直接对着我安装版本来下载,因为图神经网络的包版本要求很苛刻,版本对应不上很容易报错:
在这里插入图片描述
在这里插入图片描述

只要环境配置好了,找到这个文件,里面的代码粘贴到终端运行即可
在这里插入图片描述

温馨提示

1.数据集和已训练好的模型都在.md文件中有百度网盘链接,直接下载放到指定文件夹即可
2.注意,训练出来的模型是有硬件要求的,我是用cpu进行训练的,模型只能在cpu跑,如果想在gpu上跑,请进行重新训练
3.如果有朋友希望用苹果的gpu进行训练,虽然现在pytorch框架已经支持mps(mac版本的cuda可以这么理解)训练,但是很遗憾,图神经网络的包还不支持,不过不用担心,这个模型的训练量很小,我全程都是苹果笔记本完成训练的。

源码下载


http://www.ppmy.cn/ops/56865.html

相关文章

Go语言入门之变量、常量、指针以及数据类型

Go语言入门之变量、常量、指针以及数据类型 1.变量的声明和定义 var 变量名 变量类型// 声明单变量 var age int // 定义int类型的年龄,初始值为0// 声明多变量 var a, b int 1, 2// 声明变量不写数据类型可以自动判断 var a, b 123, "hello"// 变…

debian 12 PXE Server 批量部署系统

pxe server 前言 PXE(Preboot eXecution Environment,预启动执行环境)是一种网络启动协议,允许计算机通过网络启动而不是使用本地硬盘。PXE服务器是实现这一功能的服务器,它提供了启动镜像和引导加载程序,…

索引原理;为什么采用B+树?

在MySQL中,索引的原理是通过数据结构来快速查找数据。常见的索引数据结构有B树、B树和哈希表等。MySQL大多数存储引擎(如InnoDB)使用B树作为索引的数据结构。 为什么采用B树? 1. B树结构 B树是一种平衡树,它是在B树…

LabVIEW机器视觉技术在产品质量检测中有哪些应用实例

LabVIEW的机器视觉技术在产品质量检测中有广泛的应用,通过图像采集、处理和分析,实现对产品缺陷的自动检测、尺寸测量和定位校准,提高生产效率和产品质量。 1. 电子元器件质量检测 在电子制造业中,电子元器件的质量检测是确保产品…

红黑树,B+树,B树的结构原理及对比

红黑树 结构原理: 红黑树是一种自平衡的二叉搜索树,它通过在每个节点上增加一个颜色属性(红色或黑色)来确保树的平衡性。红黑树的平衡是通过一系列旋转和重新着色操作来实现的,这些操作在插入、删除节点时进行&#…

Redis的配置优化、数据类型、消息队列

文章目录 一、Redis的配置优化redis主要配置项CONFIG 动态修改配置慢查询持久化RDB模式AOF模式 Redis多实例Redis命令相关 二、Redis数据类型字符串string列表list集合 set有序集合sorted set哈希hash 三、消息队列生产者消费者模式发布者订阅者模式 一、Redis的配置优化 redi…

面向对象进阶基础练习

Java学习笔记(新手纯小白向) 第一章 JAVA基础概念 第二章 JAVA安装和环境配置 第三章 IntelliJ IDEA安装 第四章 运算符 第五章 运算符联系 第六章 判断与循环 第七章 判断与循环练习 第八章 循环高级综合 第九章 数组介绍及其内存图 第十章 数…

html5——CSS3_文本样式属性

目录 字体样式 字体类型 字体大小 字体风格 字体的粗细 文本样式 文本颜色 排版文本段落 文本修饰和垂直对齐 文本阴影 字体样式 字体类型 p{font-family:Verdana,"楷体";} body{font-family: Times,"Times New Roman", "楷体";} …

从零开始学习嵌入式----C语言数组指针

目录 拨开迷雾:深入浅出C语言数组指针 一、 数组与指针:剪不断理还乱的关系 二、 数组指针:指向数组的指针 三、 数组指针的应用场景 四、 总结 拨开迷雾:深入浅出C语言数组指针 数组和指针,在C语言的世界里&…

怎样优化 PostgreSQL 中对布尔类型数据的查询?

文章目录 一、索引的合理使用1. 常规 B-tree 索引2. 部分索引 二、查询编写技巧1. 避免不必要的类型转换2. 逻辑表达式的优化 三、表结构设计1. 避免过度细分的布尔列2. 规范化与反规范化 四、数据分布与分区1. 数据分布的考虑2. 表分区 五、数据库参数调整1. 相关配置参数2. 定…

java -Navicat的安装和使用

Navicat 是一款流行的数据库管理工具,支持多种数据库类型,包括 MySQL、MariaDB、SQL Server、Oracle、PostgreSQL 和 SQLite。以下是如何在 Java 环境中安装和使用 Navicat 的详细步骤。 ### 一、安装 Navicat #### 1. 下载 Navicat 前往 Navicat 官方…

开源大势所趋

一、开源项目的发展趋势 技术栈多样化与专业化:随着技术的不断进步,开源项目涵盖了从云计算、大数据、人工智能到区块链、物联网等各个领域,技术栈日益丰富和专业化。这种趋势使得开发者能够根据自己的需求选择最适合的技术工具,促…

Linux C语言基础 day10

目录 学习目标: 学习内容: 1.指针指向数组 1.1 指针与数组的关系 1.2 指针与一维数组关系实现 1.2.1 指针与一维数组的关系 1.2.2 指针指向一维整型数组作为函数参数传递 课外作业: 学习目标: 一周掌握 C基础知识 学习内…

2024年浙江省高考分数一分一段数据可视化

下图根据 2024 年浙江高考一分一段表绘制,可以看到,竞争最激烈的分数区间在620分到480分之间。 不过,浙江是考两次取最大,不是很有代表性。看看湖北的数据,580分到400分的区段都很卷。另外,从这个图也可以…

uniapp进行微信小程序开发,使用navigateBack返回到上一个页面时候,接口未刷新。

代码背景: 使用uniapp进行微信小程序开发时,有a和b两个页面,从a进入b页面后,通过uni.navigateBack()方法返回a页面时候,无法触发a页面的onShow函数里面的接口调用。 解决思路 uniapp官网页面通信 1.通过EventChann…

深入理解循环神经网络(RNN)

深入理解循环神经网络(RNN) 循环神经网络(Recurrent Neural Network, RNN)是一类专门处理序列数据的神经网络,广泛应用于自然语言处理、时间序列预测、语音识别等领域。本文将详细解释RNN的基本结构、工作原理以及其优…

UNIAPP_ReferenceError: TextEncoder is not defined 解决

错误信息 1、安装text-decoding npm install text-decoding2、main.js import { TextEncoder, TextDecoder } from text-decoding global.TextEncoder TextEncoder global.TextDecoder TextDecoder

OpenGL笔记二之glad加载opengl函数以及opengl-API(函数)初体验

OpenGL笔记二之glad加载opengl函数以及opengl-API(函数)初体验 总结自bilibili赵新政老师的教程 code review! 文章目录 OpenGL笔记二之glad加载opengl函数以及opengl-API(函数)初体验1.运行2.重点3.目录结构4.main.cpp5.CMakeList.txt 1.运行 2.重点 3.目录结构 01_GLFW_WI…

Vben admin 中 ApiSelect 类型的用法

ApiSelect作为一个接口下拉框选择的类型,其中还是有很多值得学习的功能作用: 一、参数及其功能解释 1、placeholder 显示提示文本 placeholder: 请选择人员, 2、labelField 下拉框所显示字段 labelField: nickname, 这里我需要显示人员名称&#x…

01. 课程简介

1. 课程简介 本课程的核心内容可以分为三个部分,分别是需要理解记忆的计算机底层基础,后端通用组件以及需要不断编码练习的数据结构和算法。 计算机底层基础可以包含计算机网络、操作系统、编译原理、计算机组成原理,后两者在面试中出现的频…