手写单层RNN网络,后续更新

devtools/2025/2/5 19:24:28/

文章目录

  • 1. 原理
  • 2. pytorch 源码,只是测试版,后续持续优化

1. 原理

根据如下公式,简单的手写实现单层的RNN神经网络,加强代码功能和对网络的理解能力
在这里插入图片描述

2. pytorch 源码,只是测试版,后续持续优化

import torch
import torch.nn as nn
import torch.nn.functional as Ftorch.set_printoptions(precision=3, sci_mode=False)
torch.manual_seed(23435)if __name__ == "__main__":run_code = 0input_size = 4hidden_size = 3num_layers = 1batch_first = Truesingle_rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=batch_first)print(single_rnn)for name in single_rnn.named_parameters():print(name)single_rnn_weight_ih_l0 = single_rnn.weight_ih_l0single_rnn_weight_hh_l0 = single_rnn.weight_hh_l0single_rnn_bias_ih_l0 = single_rnn.bias_ih_l0single_rnn_bias_hh_l0 = single_rnn.bias_hh_l0# print(f"single_rnn_weight_ih_l0=\n{single_rnn_weight_ih_l0}")# input --> batch_size,seq_len,feature_mapin_batch_size = 1in_seq_len = 2in_feature_map = input_sizeinput_matrix = torch.randn(in_batch_size, in_seq_len, in_feature_map)output_matrix, output_hn = single_rnn(input_matrix)print(f"output_matrix=\n{output_matrix}")print(f"output_hn=\n{output_hn}")test_output0 = input_matrix @ single_rnn_weight_ih_l0.T + single_rnn_bias_ih_l0ht_1 = torch.zeros_like(test_output0)print(f"ht_1=\n{ht_1}")print(f"ht_1.shape=\n{ht_1.shape}")test_output1 = ht_1 @ single_rnn_weight_hh_l0.T + single_rnn_bias_hh_l0test_output = torch.tanh(test_output1 + test_output0)ht_1[:,1, :] = test_output[:,0, :]test_output1 = ht_1 @ single_rnn_weight_hh_l0.T + single_rnn_bias_hh_l0test_output = torch.tanh(test_output1 + test_output0)print(f"test_output=\n{test_output}")print(f"test_output.shape=\n{test_output.shape}")
  • 结果:经计算,通过pytorch官方的API输出的结果和自定义的结果一致!!!
RNN(4, 3, batch_first=True)
('weight_ih_l0', Parameter containing:
tensor([[ 0.413,  0.044,  0.243,  0.171],[-0.093,  0.250, -0.499, -0.450],[-0.571,  0.220,  0.464, -0.154]], requires_grad=True))
('weight_hh_l0', Parameter containing:
tensor([[-0.403,  0.165, -0.244],[ 0.216, -0.511, -0.441],[ 0.133,  0.278, -0.211]], requires_grad=True))
('bias_ih_l0', Parameter containing:
tensor([ 0.115, -0.493,  0.555], requires_grad=True))
('bias_hh_l0', Parameter containing:
tensor([-0.309, -0.504,  0.311], requires_grad=True))
output_matrix=
tensor([[[ 0.243, -0.467, -0.554],[-0.013, -0.802, -0.490]]], grad_fn=<TransposeBackward1>)
output_hn=
tensor([[[-0.013, -0.802, -0.490]]], grad_fn=<StackBackward0>)
ht_1=
tensor([[[0., 0., 0.],[0., 0., 0.]]])
ht_1.shape=
torch.Size([1, 2, 3])
test_output=
tensor([[[ 0.243, -0.467, -0.554],[-0.013, -0.802, -0.490]]], grad_fn=<TanhBackward0>)
test_output.shape=
torch.Size([1, 2, 3])

http://www.ppmy.cn/devtools/156359.html

相关文章

力扣面试150 长度最小的子数组 滑动窗口

Problem: 209. 长度最小的子数组 参考题解 滑动窗口 class Solution {public int minSubArrayLen(int target, int[] nums) {int n nums.length;int ans n 1;int sum 0; // 子数组元素和int left 0; // 子数组左端点for (int right 0; right < n; right) { // 枚举…

使用 Docker(Podman) 部署 MongoDB 数据库及使用详解

在现代开发环境中&#xff0c;容器化技术&#xff08;如 Docker 和 Podman&#xff09;已成为部署和管理应用程序的标准方式。本文将详细介绍如何使用 Podman/Docker 部署 MongoDB 数据库&#xff0c;并确保其他应用程序容器能够通过 Docker 网络成功连接到 MongoDB。我们将逐步…

Docker 部署教程jenkins

Docker 部署 jenkins 教程 Jenkins 官方网站 Jenkins 是一个开源的自动化服务器&#xff0c;主要用于持续集成&#xff08;CI&#xff09;和持续交付&#xff08;CD&#xff09;过程。它帮助开发人员自动化构建、测试和部署应用程序&#xff0c;显著提高软件开发的效率和质量…

Redis详解

一、介绍 Redis是一个基于内存的key-value结构数据库 基于内存存储&#xff0c;读写性能高适合存储热点数据企业应用广泛 二、安装 官网&#xff1a;Redis中文网 三、数据类型 四、常用命令 1.字符串操作命令 2.哈希操作命令 3.列表操作命令 4.集合操作命令 5.有序集合操作命…

billd-live 一款开源、免费、技术先进的直播系统

一、简介 Billd-Live是一个基于Vue3、WebRTC、Node、SRS和FFmpeg等技术搭建的直播间系统&#xff0c;支持在线Web和安卓端查看。它实现了类似于bilibili的Web在线直播功能&#xff0c;允许用户发布直播并观看他人的直播内容。 二、功能 原生 webrtc 推拉流 srs webrtc 推流&…

Unity打包安卓报错sdk version 0.0 < 26.0(亲测解决)

问题描述和尝试解决方案&#xff1a; Unity打包安卓报错sdk version 0.0 < 26.0高版本Unity手动指定SDK地址时&#xff0c;比较容易出现上述错误高手支招1&#xff1a;修改sdk的tools文件夹中package.xml的obsolete"false"无解&#xff0c;因为打开platform-tool…

excel实用问题:提取文字当中的数字进行运算

0、前言&#xff1a; 这里汇总在使用excel工作过程中遇到的问题&#xff0c;excel使用wps版本&#xff0c;小规模数据我们自己提取数据可行&#xff0c;大规模数据就有些难受了&#xff0c;因此就产生了如下处理办法。 需求&#xff1a;需要把所有文字当中的数字提取出来&…

【怎么用系列】短视频戒除—1—对推荐算法进行干扰

如今推荐算法已经渗透到人们生活的方方面面&#xff0c;尤其是抖音等短视频核心就是推荐算法。 【短视频的危害】 1> 会让人变笨&#xff0c;慢慢让人丧失注意力与专注力 2> 让人丧失阅读长文的能力 3> 让人沉浸在一个又一个快感与嗨点当中。当我们刷短视频时&#x…