Tensorflow2基础代码实战系列之CNN文本分类实战

news/2024/12/29 15:13:34/

深度学习框架Tensorflow2系列

注:大家觉得博客好的话,别忘了点赞收藏呀,本人每周都会更新关于人工智能和大数据相关的内容,内容多为原创,Python Java Scala SQL 代码,CV NLP 推荐系统等,Spark Flink Kafka Hbase Hive Flume等等~写的都是纯干货,各种顶会的论文解读,一起进步。
这个系列主要和大家分享深度学习框架Tensorflow2的各种api,从基础开始。
#博学谷IT学习技术支持#


文章目录

  • 深度学习框架Tensorflow2系列
  • 前言
  • 一、文本分类任务实战
  • 二、数据集介绍
  • 三、CNN模型解读
  • 四、实战代码
    • 1.数据预处理
    • 2.定义模型
    • 3.模型训练
  • 总结


前言

通过CNN文本分类实战案例,学习Tensorflow2中一些API


一、文本分类任务实战

任务介绍:
数据集构建:影评数据集进行情感分析(分类任务)
词向量模型:加载训练好的词向量或者自己训练都可以
序列网络模型:训练RNN模型进行识别

二、数据集介绍

训练和测试集都是比较简单的电影评价数据集,标签为0和1的二分类,表示对电影的喜欢和不喜欢
在这里插入图片描述

三、CNN模型解读

在这里插入图片描述
通过不同尺度的卷积核[(2,3,4),word_dim] 来提取单词的特征,再进行max_pooling得到一个特征值,最后把所有尺度得到的特征值拼接在一起后,通过全连接进行分类。

四、实战代码

1.数据预处理

这里直接加载默认数据集,通过pad_sequences进行截断和填充操作
得到训练集和测试集大小都为(25000,300)
各25000个样本,每个样本长度为300
(25000, 300)
(25000, 300)

import warnings
warnings.filterwarnings("ignore")
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.sequence import pad_sequencesnum_features = 3000
sequence_length = 300
embedding_dimension = 100
(x_train, y_train), (x_test, y_test) = keras.datasets.imdb.load_data(num_words=num_features)
x_train = pad_sequences(x_train, maxlen=sequence_length)
x_test = pad_sequences(x_test, maxlen=sequence_length)
print(x_train.shape)
print(x_test.shape)
print(y_train.shape)
print(y_test.shape)

2.定义模型

# 多种卷积核,相当于单词数
filter_sizes=[3,4,5]
def convolution():inn = layers.Input(shape=(sequence_length, embedding_dimension, 1))#3维的cnns = []for size in filter_sizes:conv = layers.Conv2D(filters=64, kernel_size=(size, embedding_dimension),strides=1, padding='valid', activation='relu')(inn)#需要将多种卷积后的特征图池化成一个特征pool = layers.MaxPool2D(pool_size=(sequence_length-size+1, 1), padding='valid')(conv)cnns.append(pool)# 将得到的特征拼接在一起outt = layers.concatenate(cnns)model = keras.Model(inputs=inn, outputs=outt)return modeldef cnn_mulfilter():model = keras.Sequential([layers.Embedding(input_dim=num_features, output_dim=embedding_dimension,input_length=sequence_length),layers.Reshape((sequence_length, embedding_dimension, 1)),convolution(),layers.Flatten(),layers.Dense(10, activation='relu'),layers.Dropout(0.2),layers.Dense(1, activation='sigmoid')])model.compile(optimizer=keras.optimizers.Adam(),loss=keras.losses.BinaryCrossentropy(),metrics=['accuracy'])return modelmodel = cnn_mulfilter()
model.summary()

在这里插入图片描述
得到的模型结构如上图,其中192来自3种不同卷积核通过max_pooling之后相加得到的结果(64*3=192),每种卷积卷积之后得到一个向量,通过max_pooling之后得到一个特征值,每种卷积核设置filters=64所有最终一个卷积核得到64个值。

3.模型训练

deom级别测试代码

history = model.fit(x_train, y_train, batch_size=64, epochs=5, validation_split=0.1)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.legend(['training', 'valiation'], loc='upper left')
plt.show()

总结

通过CNN文本分类任务代码案例实战,学习Tensorflow2的各种api。


http://www.ppmy.cn/news/84218.html

相关文章

Cobalt Strike工具基本使用

Cobalt Strike 安装启动启动server端启动client目标机器连接 工具基使用用户驱动攻击屏幕截图进程列表键盘记录文件管理远程vnc远程代理端口扫描 生成后门被攻击者运行后门文件后查看结果 钓鱼攻击信息收集网站克隆文件下载 安装 网盘地址:链接:https:/…

(汇编) 基于VS的x86汇编基础指令

文章目录 环境汇编基础标志位常用指令 vs配置END 环境 visual studio 选择x86运行 示例代码 /** | 32位 | 16位 | 高8位 | 低8位 | | ---- | ---- | ----- | ----- | | EAX | AX | AH | AL |*/ #include <iostream>int main() {int32_t x 1;int32_t y 2;//…

多模态对话语言模型-VisualGLM-6B

多模态对话语言模型-VisualGLM-6B 一、简介二、使用模型推理三、部署工具网页版 DemoAPI部署四、example五、交流一、简介 VisualGLM-6B 是一个开源的,支持图像、中文和英文的多模态对话语言模型,语言模型基于 ChatGLM-6B,具有 62 亿参数;图像部分通过训练 BLIP2-Qformer 构…

多地住建局推广工程资料电子化,帮助工程企业“降本增效”

工程资料签署和管理是每个在建工程绕不开的课题&#xff0c;庞大的签署量、动则几十万的签署成本如何优化&#xff1b;有关部门的合规审查如何过关…纸质工程资料需要面对的难题还有很多&#xff1a; 麻烦&#xff1a;从工程立项申报、审批、设计、施工到验收等全过程中产生的大…

如何使用Python和wxPython构建一个HTML Title提取工具

以下代码可以用于以下场景&#xff1a; 在Web开发中&#xff0c;获取网页中的Title内容&#xff0c;以用于页面SEO。在数据挖掘和分析中&#xff0c;获取包含Title信息的HTML页面&#xff0c;以进行进一步的文本处理和分析。在一些需要从HTML源代码中获取元数据的应用中&#…

最全iOS 上架指南

一、基本需求信息。 1、苹果开发人员账户&#xff08;公司已经可以无需申请&#xff0c;需要开启开发者功能&#xff0c;每年99美元&#xff09; 2、开发好应用程序 二、证书 上架版本需要使用正式证书。 1、创建Apple Developer证书 2、上传证书Sign In - Apple 3、点击开发者…

【ChatGPT】不会用ChatGPT?这几个镜像网站解决你的烦恼。

个人主页&#xff1a;【&#x1f60a;个人主页】 文章目录 前言ChatGPT介绍WoChatA TalkChatGPT Next WebAI EDUCHATGPTSITES 前言 还在为需要魔法才能与ChatGPT见上一面而叹息吗&#xff0c;今我就为大家汇总了国内能使用ChatGPT的方法。 也就是用国内的镜像网站玩ChatGPT&…

AI别来搅局,ChatGPT的世界不懂低代码

ChatGPT单月访问量再创新高 根据SimilarWeb统计&#xff0c;ChatGPT上月全球访问量17.6亿次&#xff0c;已超越必应、鸭鸭走DuckDuckGo等其他国际搜索引擎&#xff0c;并达到谷歌的2%&#xff0c;百度的60%。 这会&#xff0c;程序员失业的段子又得再来一遍了&#xff1a; 拖…