代码解读 | Hybrid Transformers for Music Source Separation[06]

embedded/2024/10/18 9:17:39/

一、背景

        0、Hybrid Transformer 论文解读

        1、代码复现|Demucs Music Source Separation_demucs架构原理-CSDN博客

        2、Hybrid Transformer 各个模块对应的代码具体在工程的哪个地方

        3、Hybrid Transformer 各个模块的底层到底是个啥(初步感受)?

        4、Hybrid Transformer 各个模块处理后,数据的维度大小是咋变换的?

        5、Hybrid Transformer 拆解STFT模块

        6、Hybrid Transformer 拆解频域编码模块


        从模块上划分,Hybrid Transformer Demucs 共包含 (STFT模块、时域编码模块、频域编码模块、Cross-Domain Transformer Encoder模块、时域解码模块、频域解码模块ISTFT模块)7个模块。已完成解读:STFT模块、频域编码模块(时域编码和频域编码类似,后续不再解读时域编码模块),待解读:Cross-Domain Transformer Encoder模块。

        本篇目标:拆解频域解码模块ISTFT模块的底层。时域解码和频域解码原理类似(后续不再拆解时域解码模块)。

二、频域解码模块

python">
class HDecLayer(nn.Module):def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False,freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True,context_freq=True, rewrite=True):"""Same as HEncLayer but for decoder. See `HEncLayer` for documentation."""super().__init__()norm_fn = lambda d: nn.Identity()  # noqaif norm:norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqaif pad:pad = kernel_size // 4else:pad = 0self.pad = padself.last = lastself.freq = freqself.chin = chinself.empty = emptyself.stride = strideself.kernel_size = kernel_sizeself.norm = normself.context_freq = context_freqklass = nn.Conv1dklass_tr = nn.ConvTranspose1dif freq:kernel_size = [kernel_size, 1]stride = [stride, 1]klass = nn.Conv2dklass_tr = nn.ConvTranspose2dself.conv_tr = klass_tr(chin, chout, kernel_size, stride)self.norm2 = norm_fn(chout)if self.empty:returnself.rewrite = Noneif rewrite:if context_freq:self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)else:self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,[0, context])self.norm1 = norm_fn(2 * chin)self.dconv = Noneif dconv:self.dconv = DConv(chin, **dconv_kw)def forward(self, x, skip, length):if self.freq and x.dim() == 3:B, C, T = x.shapex = x.view(B, self.chin, -1, T)if not self.empty:x = x + skipif self.rewrite:y = F.glu(self.norm1(self.rewrite(x)), dim=1)else:y = xif self.dconv:if self.freq:B, C, Fr, T = y.shapey = y.permute(0, 2, 1, 3).reshape(-1, C, T)y = self.dconv(y)if self.freq:y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)else:y = xassert skip is Nonez = self.norm2(self.conv_tr(y))print('self.pad,self.last:', self.pad,self.last)if self.freq:if self.pad:z = z[..., self.pad:-self.pad, :]else:z = z[..., self.pad:self.pad + length]assert z.shape[-1] == length, (z.shape[-1], length)if not self.last:z = F.gelu(z)return z, y

        频域解码模块的核心代码如上所示。在上一篇频域编码模块的基础上,继续贴出完善之后的频域编解码模块全景图。

编码层:Conv2d+Norm1+GELU,  Norm1:Identity()

解码层:(Conv2d+Norm1+GLU)+(ConvTranspose2d+Norm2+倒数第二个维度裁剪+GELU),    Norm1\Norm2:Identity()

残差连接:(Conv1d+GroupNorm+GELU +Conv1d+GroupNorm+GLU+LayerScale())+(Conv2d+Norm2+GLU),Norm2:Identity() ,备注:Identity可以理解成直通

python">#频域编码层1-4的Conv2d分别是:
Conv2d(4, 48, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(48, 96, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(96, 192, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(192, 384, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))#频域解码层4-1的Conv2d和ConvTranspose2d
Conv2d(384, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ConvTranspose2d(384, 192, kernel_size=(8, 1), stride=(4, 1)) 
Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ConvTranspose2d(192, 96, kernel_size=(8, 1), stride=(4, 1))
Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ConvTranspose2d(96, 48, kernel_size=(8, 1), stride=(4, 1))
Conv2d(48, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
ConvTranspose2d(48, 16, kernel_size=(8, 1), stride=(4, 1))

        残差连接模块如下所示。

python">#残差连接1
DConv((layers): ModuleList((0): Sequential((0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(1,))(1): GroupNorm(1, 6, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 96, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale())(1): Sequential((0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))(1): GroupNorm(1, 6, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 96, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale()))
)
Conv2d(48, 96, kernel_size=(1, 1), stride=(1, 1))#残差连接2
DConv((layers): ModuleList((0): Sequential((0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(1,))(1): GroupNorm(1, 12, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 192, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale())(1): Sequential((0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))(1): GroupNorm(1, 12, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 192, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale()))
)
Conv2d(96, 192, kernel_size=(1, 1), stride=(1, 1))#残差连接3
DConv((layers): ModuleList((0): Sequential((0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(1,))(1): GroupNorm(1, 24, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 384, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale())(1): Sequential((0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))(1): GroupNorm(1, 24, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 384, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale()))
)
Conv2d(192, 384, kernel_size=(1, 1), stride=(1, 1))#残差连接4
DConv((layers): ModuleList((0): Sequential((0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(1,))(1): GroupNorm(1, 48, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 768, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale())(1): Sequential((0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))(1): GroupNorm(1, 48, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 768, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale()))
)
Conv2d(384, 768, kernel_size=(1, 1), stride=(1, 1))

三、ISTFT模块

        ISTFT模块的核心代码如下所示。

python">import torch as th
def ispectro(z, hop_length=None, length=None, pad=0):*other, freqs, frames = z.shapen_fft = 2 * freqs - 2z = z.view(-1, freqs, frames)win_length = n_fft // (1 + pad)is_mps = z.device.type == 'mps'if is_mps:z = z.cpu()x = th.istft(z,n_fft,hop_length,window=th.hann_window(win_length).to(z.real),win_length=win_length,normalized=True,length=length,center=True)_, length = x.shapereturn x.view(*other, length)

        其中,torch.istft【逆短时傅里叶变换(Inverse Short Time Fourier Transform,ISTFT)】,该函数期望是torch.stft函数的逆过程。它具有相同的参数(加上一个可选参数length),并且应该返回原始信号的最小二乘估计。算法将根据NOLA条件(非零重叠)进行检查。

#### torch.istft接口参数####
input (Tensor): 输入张量,期望是`torch.stft`的输出,可以是复数形式(`channel`, `fft_size`, `n_frame`),或者是实数形式(`channel`, `fft_size`, `n_frame`, 2),其中`channel`维度是可选的。

       deprecated:: 1.8.0
            实数输入已废弃,请使用`stft(..., return_complex=True)`返回的复数输入代替。
n_fft (int): 傅里叶变换的大小。
hop_length (Optional[int]): 相邻滑动窗口帧之间的距离。(默认:`n_fft // 4`)
win_length (Optional[int]): 窗口帧和STFT滤波器的大小。(默认:`n_fft`)
window (Optional[torch.Tensor]): 可选的窗函数。(默认:`torch.ones(win_length)`)
center (bool): 指示输入是否在两边进行了填充,使得第`t`帧位于时间`t × hop_length`处居中。(默认:`True`)
normalized (bool): 指示STFT是否被标准化。(默认:`False`)
onesided (Optional[bool]): 指示STFT是否为单边谱。(默认:如果输入尺寸中的`n_fft != fft_size`则为`True`)
length (Optional[int]): 修剪信号的长度,即原始信号的长度。(默认:整个信号)
return_complex (Optional[bool]):指示输出是否应为复数,或者输入是否应假定源自实信号和窗函数。注意,这与`onesided=True`不兼容。(默认:`False`)

        频域解码模块和ISTFT模块解读完毕。还剩一个Cross-Domain Transformer Encoder模块没有解读。后面又来新的活了,希望能把demucs落地~。


        感谢阅读,最近开始写公众号(分享好用的AI工具),欢迎大家一起见证我的成长(桂圆学AI)


http://www.ppmy.cn/embedded/48586.html

相关文章

Ceph入门到精通-Crimson 和 Classic Ceph OSD 架构之间的区别

Crimson 和 Classic Ceph OSD 架构之间的区别 在典型的 ceph-osd 架构中,messenger 线程从 wire 读取客户端消息,它将消息放在 OP 队列中。然后,osd-op thread-pool 会提取消息,并创建一个事务并将其排队到 BlueStore,当前的默认 ObjectStore 实现。然后,BlueStore 的 k…

Vue项目中,利用iframe在线预览pdf/图片等

在components下新建 src\components\iFrame\index.vue <template><div v-loading"loading" :style"height: height"><iframe :src"src" frameborder"no" style"width: 100%; height: 100%" scrolling"…

江苏哪些行业需要服务器托管?

服务器托管顾名思义就是用户委托具有完善设备的机房、良好网络和丰富运营经验的服务商管理其计算机系统&#xff0c;使企业的服务器能够更加安全、稳定和高效的运行&#xff0c;那在江苏都有哪些行业需要服务器托管服务呢&#xff1f;本文就来大概介绍一下。 首先让我们来一起了…

u盘数据要在哪台电脑上恢复?u盘数据恢复后保存在哪里

在数字化时代&#xff0c;U盘已成为我们日常生活中不可或缺的数据存储设备。然而&#xff0c;由于各种原因&#xff0c;U盘中的数据可能会意外丢失&#xff0c;这时数据恢复就显得尤为重要。但是&#xff0c;很多人对于在哪台电脑上进行U盘数据恢复以及恢复后的数据应保存在哪里…

Pod中使用自定义服务账号调用自定义资源

一、背景 1、从开发角度 &#xff08;1&#xff09;服务通过容器化部署的方式运行在云环境的Pod中&#xff0c;然而在 Kubernetes 中&#xff0c;Pod 中的服务不能直接通过 client-go 访问 Kubernetes 资源&#xff0c;而是需要通过 Kubernetes API Server 来进行访问。clien…

数据结构——队列(Queue)详解

1.队列&#xff08;Queue&#xff09; 1.1概念 队列&#xff1a;只允许在一端进行插入数据操作&#xff0c;在另一端进行删除数据操作的特殊线性表&#xff0c;队列具有先进先出FIFO(First In First Out)的性质 入队列&#xff1a;进行插入操作的一端称为队尾(Tail/Rear) 出…

SQL 窗口函数

1.窗口函数之排序函数 RANK, DENSE_RANK, ROW_NUMBER RANK函数 计算排序时,如果存在相同位次的记录,则会跳过之后的位次 有 3 条记录排在第 1 位时: 1 位、1 位、1 位、4 位…DENSE_RANK函数 同样是计算排序,即使存在相同位次的记录,也不会跳过之后的位次 有 3 条记录排在…

Prometheus+Grafana监控MySQL

一、准备 grafana服务器&#xff1a;192.168.48.136Prometheus服务器&#xff1a;192.168.48.136被监控服务器&#xff1a;192.168.48.134、192.168.48.135查看时间是否同步 二、安装prometheus server 【2.1】安装 # 解压安装包 tar -zxvf prometheus-2.52.0.linux-amd64.t…