PyTorch 实现动态输入

server/2024/12/4 18:16:16/

使用 PyTorch 实现动态输入:支持训练和推理输入维度不一致的 CNN 和 LSTM/GRU 模型

在深度学习中,处理不同大小的输入数据是一个常见的挑战。许多实际应用需要模型能够灵活地处理可变长度的输入。本文将介绍如何使用 PyTorch 实现支持动态输入的 CNN 和 LSTM/GRU 模型,并打印每一层的输入和输出。

  • 卷积神经网络(CNN):CNN 通常用于处理图像数据。它通过卷积层提取局部特征,并能够处理不同大小的输入图像。通过使用全局池化层,CNN 可以将不同大小的特征图转换为固定大小的输出。

  • 长短期记忆网络(LSTM)和门控循环单元(GRU):LSTM 和 GRU 是处理序列数据的 RNN 变体。它们能够捕捉时间序列中的长期依赖关系,并支持可变长度的输入序列。

模型搭建

1. CNN 模型

我们将构建一个简单的 CNN 模型,支持动态输入大小,并打印每一层的输入和输出。

python">import torch
import torch.nn as nn
import torch.nn.functional as Fclass DynamicCNN(nn.Module):def __init__(self):super(DynamicCNN, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)self.pool = nn.AdaptiveAvgPool2d((1, 1))  # 自适应池化层self.fc = nn.Linear(32, 10)  # 输出10个类别def forward(self, x):print(f'Input to CNN: {x.shape}')x = F.relu(self.conv1(x))print(f'Output after conv1: {x.shape}')x = F.relu(self.conv2(x))print(f'Output after conv2: {x.shape}')x = self.pool(x)print(f'Output after pooling: {x.shape}')x = x.view(x.size(0), -1)  # 展平x = self.fc(x)print(f'Output after fc: {x.shape}')return x# 创建模型
cnn_model = DynamicCNN()# 测试动态输入
input_tensor_cnn = torch.randn(1, 3, 64, 64)  # 输入形状为 (batch_size, channels, height, width)
output_cnn = cnn_model(input_tensor_cnn)
python">Input to CNN: torch.Size([1, 3, 55, 64])
Output after conv1: torch.Size([1, 16, 53, 62])
Output after conv2: torch.Size([1, 32, 51, 60])
Output after pooling: torch.Size([1, 32, 1, 1])
Output after fc: torch.Size([1, 10])
python">Input to CNN: torch.Size([1, 3, 64, 64])
Output after conv1: torch.Size([1, 16, 62, 62])
Output after conv2: torch.Size([1, 32, 60, 60])
Output after pooling: torch.Size([1, 32, 1, 1])
Output after fc: torch.Size([1, 10])

2. LSTM/GRU 模型

接下来,我们将构建一个支持动态输入的 LSTM 模型,并打印每一层的输入和输出。

python">import torch
import torch.nn as nnclass DynamicLSTM(nn.Module):def __init__(self):super(DynamicLSTM, self).__init__()self.lstm = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)self.fc = nn.Linear(20, 1)  # 输出一个值def forward(self, x):print(f'Input to LSTM: {x.shape}')x, _ = self.lstm(x)print(f'Output after LSTM: {x.shape}')x = self.fc(x[:, -1, :])  # 取最后一个时间步的输出print(f'Output after fc: {x.shape}')return x# 创建模型
lstm_model = DynamicLSTM()# 测试动态输入
input_tensor_lstm = torch.randn(5, 15, 10)  # 输入形状为 (batch_size, seq_length, input_size)
output_lstm = lstm_model(input_tensor_lstm)
python">Input to LSTM: torch.Size([5, 15, 10])
Output after LSTM: torch.Size([5, 15, 20])
Output after fc: torch.Size([5, 1])
python">Input to LSTM: torch.Size([5, 20, 10])
Output after LSTM: torch.Size([5, 20, 20])
Output after fc: torch.Size([5, 1])

代码说明

  1. DynamicCNN:该模型包含两个卷积层和一个全连接层。使用自适应平均池化层将特征图的大小调整为 (1, 1),从而支持不同大小的输入图像。每一层的输入和输出形状在前向传播中被打印出来。

  2. DynamicLSTM:该模型包含一个 LSTM 层和一个全连接层。LSTM 层能够处理可变长度的输入序列,输出的形状在前向传播中被打印出来。


http://www.ppmy.cn/server/147342.html

相关文章

《山海经》:北山

《山海经》:北山 北山一经单狐山求如山(水马:形状与马相似,滑鱼:背部红色)带山(䑏疏:似马,一只角,鵸鵌:状乌鸦五彩斑斓,儵鱼&#xff…

游戏引擎学习第27天

仓库:https://gitee.com/mrxiao_com/2d_game 欢迎 项目的开始是从零开始构建一款完整的游戏,完全不依赖任何库或引擎。这样做有两个主要原因:首先,因为这非常有趣;其次,因为它非常具有教育意义。了解游戏开发的低层次…

静态页面 和 动态页面(Java Web开发)

1. 静态页面 1.1 什么是静态页面? 静态页面是指 HTML 文件直接存放在服务器上,不依赖后端逻辑处理而生成内容。客户端浏览器请求静态页面时,服务器直接将文件发送到客户端,浏览器负责渲染页面。 特点: 固定内容&am…

Lumos学习王佩丰Excel第十七讲:数学函数

一、认识函数 1、Round函数 Roundup函数 Rounddown函数 Int函数 Round函数:将数字四舍五入到给定的位数。当末位有效数字为 5 或大于 5 时,ROUND 向上舍入;当末位有效数字小于 5 时则向下舍入。ROUND(number, num_digits),其中“…

python基础(六)

进程和线程 进程 进程就是操作系统中执行的一个程序,操作系统以进程为单位分配存储空间,每个进程都有自己的地址空间、数据栈以及其他用于跟踪进程执行的辅助数据,操作系统管理所有进程的执行,为它们合理的分配资源。创新的进程…

【系统架构设计师论文】微服务架构设计与实践

一、微服务架构概述 1.1 什么是微服务架构? 微服务架构(Microservices Architecture)是一种将应用程序拆分为多个小型服务的架构模式,每个服务都可以独立开发、部署、扩展,并通过轻量级的通信协议(通常是 HTTP/REST 或 gRPC)进行交互。这种架构的核心思想是将大型单体…

VScode离线下载扩展安装

在使用VScode下在扩展插件时,返现VScode搜索不到插件,网上搜了好多方法,都不是常规操作,解决起来十分麻烦,可以利用离线下载安装的方式安装插件!亲测有效!!! 1.找到VScod…

2024.12.2工作复盘

1.今天学了什么? 简单的写了一篇博客,是关于参数校验的问题,参数校验,一个是前后端校验到底一不一致,一个是绕过前端校验,看后台的逻辑到底能不能校验住。 2.今天解决了什么问题? 3.今天完成…