pytorch张量分块投影示例代码

server/2025/1/16 0:22:26/

张量的投影操作

背景

张量投影 是深度学习中常见的操作,将输入张量通过线性变换映射到另一个空间。例如:
Y=W⋅X+b
其中:

  • X: 输入张量(形状可能为 (B,M,K),即批量维度、序列维度、特征维度)。
  • W: 权重矩阵((K,N),将 K 维投影到 N 维)。
  • b: 偏置向量(可选,(N,))。
  • Y: 输出张量(形状 (B,M,N))。

对于巨大张量 XX,直接计算 W⋅XW⋅X 可能会因为显存不足导致 OOM(Out of Memory)。因此,分块操作是一种有效的解决方案。


分块投影的操作方法

原理

将输入张量 X 沿着某个维度(通常是 序列维度 M 或 批量维度 B)分成多个小块,分别进行线性变换,再将结果拼接起来。

具体步骤
  1. 定义分块大小

    • 根据显存限制和硬件特性,确定每次可以处理的块大小(chunk_size)。
  2. 迭代计算

    • 将输入张量 X 按 序列维度 M(或其他维度)进行切片。
    • 对每个切片分别进行线性投影操作。
    • 将每次的结果存储起来,最后拼接成完整输出。

分块投影计算函数代码:

import torchdef block_projection(X, W, b=None, chunk_size=64):"""Perform block-wise tensor projection.Args:X: Input tensor of shape (B, M, K)W: Weight matrix of shape (K, N)b: Bias vector of shape (N,) or Nonechunk_size: Size of each block along the M dimensionReturns:Y: Output tensor of shape (B, M, N)"""B, M, K = X.shape

http://www.ppmy.cn/server/158698.html

相关文章

神经网络初始化 (init) 介绍

文章目录 引言1. 初始化的重要性1.1 打破对称性1.2 控制方差1.3 加速收敛与提高泛化能力 2. 常见的初始化方法及其应用场景2.1 Xavier/Glorot 初始化2.2 He 初始化2.3 正交初始化2.4 其他初始化方法 3. 如何设置初始化4. 基于 BERT 的文本分类如何进行初始化4.1 项目背景4.2 模…

分布式ID的实现方案

1. 什么是分布式ID ​ 对于低访问量的系统来说,无需对数据库进行分库分表,单库单表完全可以应对,但是随着系统访问量的上升,单表单库的访问压力逐渐增大,这时候就需要采用分库分表的方案,来缓解压力。 ​…

<C++学习> C++ Boost 字符串操作教程

C Boost 字符串操作教程 Boost 提供了一些实用的库来增强 C 的字符串操作能力,特别是 Boost.StringAlgo 和其他与字符串相关的工具。这些库为字符串处理提供了更高效、更简洁的方法,相比标准库功能更为丰富。 1. Boost.StringAlgo 简介 Boost.StringAl…

数据结构:栈(Stack)和队列(Queue)—面试题(二)

1. 用队列实现栈。 习题链接https://leetcode.cn/problems/implement-stack-using-queues/description/描述: 请你仅使用两个队列实现一个后入先出(LIFO)的栈,并支持普通栈的全部四种操作(push、top、pop 和 empty&a…

JVM之垃圾回收器ZGC概述以及垃圾回收器总结的详细解析

ZGC ZGC 收集器是一个可伸缩的、低延迟的垃圾收集器,基于 Region 内存布局的,不设分代,使用了读屏障、染色指针和内存多重映射等技术来实现可并发的标记压缩算法 在 CMS 和 G1 中都用到了写屏障,而 ZGC 用到了读屏障 染色指针&a…

2025年01月13日Github流行趋势

1. 项目名称:Jobs_Applier_AI_Agent 项目地址url:https://github.com/feder-cr/Jobs_Applier_AI_Agent项目语言:Python历史star数:25929今日star数:401项目维护者:surapuramakhil, feder-cr, cjbbb, sarob…

13:00面试,13:08就出来了,问的问题有点变态。。。

从小厂出来,没想到在另一家公司又寄了。 到这家公司开始上班,加班是每天必不可少的,看在钱给的比较多的份上,就不太计较了。没想到9月一纸通知,所有人不准加班,加班费不仅没有了,薪资还要降40%…

linux手动安装mysql5.7

一、下载mysql5.7 1、可以去官方网站下载mysql-5.7.24-linux-glibc2.12-x86_64.tar压缩包: https://downloads.mysql.com/archives/community/ 2、在线下载,使用wget命令,直接从官网下载到linux服务器上 wget https://downloads.mysql.co…