CNN文本分类实现文本分类案例解析(附实例源码)

news/2025/2/22 16:11:49/

前言

  • 实现步骤
    • 1.导入所需的库和模块:
    • 2.设置随机种子:
    • 3.定义模型超参数:
    • 4.加载数据集:
    • 5.对文本进行填充和截断:
    • 6.构建模型:
    • 7.编译模型:
    • 8.训练模型:
    • 9.评估模型:
  • 完整代码

CNN(卷积神经网络)在文本分类任务中具有良好的特征提取能力、位置不变性、参数共享和处理大规模数据的优势,能够有效地学习文本的局部和全局特征,提高模型性能和泛化能力,所以本文将以CNN实现文本分类。

CNN对文本分类的支持主要提现在:

特征提取:CNN能够有效地提取文本中的局部特征。卷积层通过应用多个卷积核来捕获不同大小的n-gram特征,从而能够识别关键词、短语和句子结构等重要信息。

位置不变性:对于文本分类任务,特征的位置通常是不重要的。CNN中的池化层(如全局最大池化)能够保留特征的最显著信息,同时忽略其具体位置,这对于处理可变长度的文本输入非常有帮助。

参数共享:CNN中的卷积核在整个输入上共享参数,这意味着相同的特征可以在不同位置进行识别。这种参数共享能够极大地减少模型的参数量,降低过拟合的风险,并加快模型的训练速度。

处理大规模数据:CNN可以高效地处理大规模的文本数据。由于卷积和池化操作的局部性质,CNN在处理文本序列时具有较小的计算复杂度和内存消耗,使得它能够适应大规模的文本分类任务。

上下文建模:通过使用多个卷积核和不同的大小,CNN可以捕捉不同尺度的上下文信息。这有助于提高模型对文本的理解能力,并能够捕捉更长范围的依赖关系。

实现步骤

1.导入所需的库和模块:

import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Embedding, Conv1D, GlobalMaxPooling1D

这些库包括NumPy用于数据处理,TensorFlow用于构建和训练模型,以及Keras中的各种层和模型类。

2.设置随机种子:

np.random.seed(42)

设置随机种子,可以确保每次运行代码时生成的随机数是相同的,以便结果可重现。

3.定义模型超参数:

max_features = 5000  # 词汇表大小
max_length = 100  # 文本最大长度
embedding_dims = 50  # 词嵌入维度
filters = 250  # 卷积核数量
kernel_size = 3  # 卷积核大小
hidden_dims = 250  # 全连接层神经元数量
batch_size = 32  # 批处理大小
epochs = 5  # 训练迭代次数

这些超参数将影响模型的结构和训练过程。可自行调整。

4.加载数据集:

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=max_features)

示例中,使用的IMDB电影评论数据集,其中包含以数字表示的评论文本和相应的情感标签(正面或负面)。使用tf.keras.datasets.imdb.load_data函数可以方便地加载数据集,并指定num_words参数来限制词汇表的大小。

5.对文本进行填充和截断:

x_train = sequence.pad_sequences(x_train, maxlen=max_length)
x_test = sequence.pad_sequences(x_test, maxlen=max_length)

由于每条评论的长度可能不同,需要将它们统一到相同的长度。sequence.pad_sequences函数用于在文本序列前后进行填充或截断,使它们具有相同的长度。

6.构建模型:

model = Sequential()
model.add(Embedding(max_features, embedding_dims, input_length=max_length))
model.add(Dropout(0.2))
model.add(Conv1D(filters, kernel_size, padding='valid', activation='relu', strides=1))
model.add(GlobalMaxPooling1D())
model.add(Dense(hidden_dims, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(1, activation='sigmoid'))

这个模型使用Sequential模型类构建,依次添加了嵌入层(Embedding)、卷积层(Conv1D)、全局最大池化层(GlobalMaxPooling1D)和两个全连接层(Dense)。嵌入层将输入的整数序列转换为固定维度的词嵌入表示,卷积层通过应用多个卷积核来提取特征,全局最大池化层获取每个特征通道的最大值,而两个全连接层用于分类任务。

7.编译模型:

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

在编译模型之前,需要指定损失函数、优化器和评估指标。使用二元交叉熵作为损失函数,Adam优化器进行参数优化,并使用准确率作为评估指标。

8.训练模型:

model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_data=(x_test, y_test))

使用fit函数对模型进行训练。需要传入训练数据、标签,批处理大小、训练迭代次数,并可以指定验证集进行模型性能评估。

9.评估模型:

scores = model.evaluate(x_test, y_test, verbose=0)
print("Test accuracy:", scores[1])

使用evaluate函数评估模型在测试集上的性能,计算并打印出测试准确率。

完整代码

import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Embedding, Conv1D, GlobalMaxPooling1D# 设置随机种子
np.random.seed(42)# 定义模型超参数
max_features = 5000  # 词汇表大小
max_length = 100  # 文本最大长度
embedding_dims = 50  # 词嵌入维度
filters = 250  # 卷积核数量
kernel_size = 3  # 卷积核大小
hidden_dims = 250  # 全连接层神经元数量
batch_size = 32  # 批处理大小
epochs = 5  # 训练迭代次数# 加载数据集
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=max_features)# 对文本进行填充和截断,使其具有相同的长度
x_train = sequence.pad_sequences(x_train, maxlen=max_length)
x_test = sequence.pad_sequences(x_test, maxlen=max_length)# 构建模型
model = Sequential()
model.add(Embedding(max_features, embedding_dims, input_length=max_length))
model.add(Dropout(0.2))
model.add(Conv1D(filters, kernel_size, padding='valid', activation='relu', strides=1))
model.add(GlobalMaxPooling1D())
model.add(Dense(hidden_dims, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(1, activation='sigmoid'))# 编译模型
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])# 训练模型
model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_data=(x_test, y_test))# 评估模型
scores = model.evaluate(x_test, y_test, verbose=0)
print("Test accuracy:", scores[1])

http://www.ppmy.cn/news/69439.html

相关文章

低代码信创开发核心技术(一):基于Vue.js的描述依赖渲染DDR实现模型驱动的组件

前言 随着数字化转型的不断发展,低代码开发平台已成为企业快速建立自己的应用程序的首选方案。然而,实现这样一个平台需要具备高效、灵活和可定制化的能力。这正是基于描述依赖渲染(Description dependency rendering)所实现的。…

计算image1和image2之间的LPIPS指标的python代码

计算image1和image2之间的LPIPS指标的python代码 import cv2 import lpips import torchloss_fn_vgg lpips.LPIPS(netalex).to("cuda:0")image1_path F:\YXL\project\Restormer-mainV364_v1\YXL_dir\photo\photo_otput_result.png image2_path F:\YXL\project\Re…

3.编写油猴脚本之-helloword

3.编写油猴脚本之-helloword Start 通过上一篇文章的学习,我们安装完毕了油猴插件。今天我们来编写一个helloword的脚步,体验一下油猴。 1. 开始 点击油猴插件>添加新脚本 默认生成的脚本 // UserScript // name New Userscript // name…

Java构造方法

Java构造方法 Java构造方法是啥,有什么作用构造方法如何定义?构造方法该如何调用?案例(利用构造方法完成一个时间打印)构造方法必须与类名相同构造方法可以重载吗?啥是缺省构造器? Java构造方法是啥,有什么…

美团数据指标体系搭建实战

在美团商家版中,美团为商家搭建的数据指标体系,很好的指导了商家的经营发展方向以及提供经营状况概览。​ 本文通过体验美团商家版经营数据子功能,对美团商家版数据指标体系搭建的情况做出一个概述。 美团商家版的店铺子功能下,…

【Spark编程基础】第7章 Structured Streaming

系列文章目录 文章目录 系列文章目录前言第7章 Structured Streaming7.1 概述7.1.1 基本概念7.1.2 两种处理模型7.1.3 Structured Streaming 和 Spark SQL、Spark Streaming 关系 7.2 编写Structured Streaming程序的基本步骤7.3 输入源7.3.1 File源7.3.2 Kafka源7.3.3 Socket源…

【MySql】数据库设计过程

目录 概念数据库设计: 逻辑数据库设计: 物理数据库设计: ->需求分析(收集需求和理解需求,“源”) ->概念数据库设计(建立概念模型:"E-R图/IDEF1X") ->逻辑数据库设计&…

新星计划 uni-app 学习1

活动地址: 教程地址: 白话uni-app | uni-app官网 每日一学Vue脚手架中基础的ref属性与原生id区别_ref和id区别_lqj_本人的博客-CSDN博客 每日一学vue2:组件复用(详细讲解)、mixin(混入)、mo…