pytorch生成对抗网络

news/2025/2/7 1:27:31/

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

生成对抗网络(GAN,Generative Adversarial Network)是一种深度学习模型,由两个神经网络组成:生成器(Generator)和判别器(Discriminator)。这两个网络通过对抗过程共同训练,从而使生成器能够生成越来越真实的假数据。

GAN的基本工作原理:

  1. 生成器(G):它的任务是生成与真实数据相似的假数据。生成器通常从一个随机噪声(例如,均匀分布或高斯分布的噪声)开始,经过多层神经网络的处理,输出伪造的数据样本。

  2. 判别器(D):它的任务是区分输入数据是来自真实数据分布,还是生成器伪造的假数据。判别器通常是一个二分类器,其输出是一个表示“真实”或“假”的概率值。

训练过程:

  • 对抗过程:生成器和判别器相互博弈。生成器希望生成尽可能像真的数据,以骗过判别器;而判别器希望准确区分真假数据。最终,生成器会通过优化损失函数,使得生成的数据与真实数据尽可能相似,判别器的性能则被提升到一个极限,使得它不能再轻易地区分真假数据。
  • 数学公式:

  • 判别器的目标是最大化其输出的正确分类概率,即区分真假数据。
  • 生成器的目标是最小化其输出的“假数据”被判定为假的概率。

常见的GAN变种:

  1. DCGAN(Deep Convolutional GAN):使用卷积神经网络(CNN)来增强生成器和判别器的表现。
  2. WGAN(Wasserstein GAN):引入了Wasserstein距离,改进了训练稳定性。
  3. CycleGAN:能够在没有成对样本的情况下进行图像到图像的转换,例如将马变成斑马。

以下是一个简化的PyTorch GAN实现的框架,生成一个语音的梅尔频谱(假设已经处理了音频并提取了梅尔频谱特征)

import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import matplotlib.pyplot as plt# 生成器(Generator)
class Generator(nn.Module):def __init__(self, z_dim=100):super(Generator, self).__init__()self.fc = nn.Sequential(nn.Linear(z_dim, 128),nn.ReLU(),nn.Linear(128, 256),nn.ReLU(),nn.Linear(256, 512),nn.ReLU(),nn.Linear(512, 1024),nn.ReLU(),nn.Linear(1024, 80),  # 80表示梅尔频谱的时间步(例如:80个梅尔频率)nn.Tanh()  # 生成梅尔频谱,范围在[-1, 1]之间)def forward(self, z):return self.fc(z)# 判别器(Discriminator)
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.fc = nn.Sequential(nn.Linear(80, 512),  # 输入为梅尔频谱的时间步nn.LeakyReLU(0.2),nn.Linear(512, 256),nn.LeakyReLU(0.2),nn.Linear(256, 1),nn.Sigmoid()  # 输出判定是“真”还是“假”)def forward(self, x):return self.fc(x)# 初始化生成器和判别器
z_dim = 100
generator = Generator(z_dim)
discriminator = Discriminator()# 优化器
lr = 0.0002
g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))# 损失函数
criterion = nn.BCELoss()# 加载数据(假设已经提取了梅尔频谱特征,取一个示例)
def load_example_mel_spectrogram():# 假设这是一个真实梅尔频谱的示例,实际数据应从音频文件中提取mel = torch.rand((80))  # 生成一个假的梅尔频谱数据return mel.unsqueeze(0)  # 扩展维度以适应网络# 训练GAN
num_epochs = 1000
for epoch in range(num_epochs):# 真实数据real_data = load_example_mel_spectrogram()real_labels = torch.ones(real_data.size(0), 1)  # 标签为1表示真实数据# 假数据z = torch.randn(real_data.size(0), z_dim)  # 随机噪声fake_data = generator(z)fake_labels = torch.zeros(real_data.size(0), 1)  # 标签为0表示假数据# 训练判别器discriminator.zero_grad()real_loss = criterion(discriminator(real_data), real_labels)fake_loss = criterion(discriminator(fake_data.detach()), fake_labels)d_loss = (real_loss + fake_loss) / 2d_loss.backward()d_optimizer.step()# 训练生成器generator.zero_grad()g_loss = criterion(discriminator(fake_data), real_labels)  # 生成器希望判别器判定为真实g_loss.backward()g_optimizer.step()if epoch % 100 == 0:print(f"Epoch [{epoch}/{num_epochs}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")# 可视化生成的梅尔频谱(只显示最后一次生成的结果)if epoch == num_epochs - 1:plt.figure(figsize=(10, 4))plt.imshow(fake_data.detach().numpy(), aspect='auto', origin='lower')plt.title(f"Generated Mel Spectrogram - Epoch {epoch}")plt.colorbar()plt.show()# 测试阶段:使用训练好的生成器进行语音生成
z_test = torch.randn(1, z_dim)  # 创建一个新的随机噪声向量
generated_mel_spectrogram = generator(z_test)# 可视化生成的梅尔频谱
plt.figure(figsize=(10, 4))
plt.imshow(generated_mel_spectrogram.detach().numpy(), aspect='auto', origin='lower')
plt.title("Generated Mel Spectrogram from Test Data")
plt.colorbar()
plt.show()

解释:

  1. 测试阶段

    • 在训练完成后,我们使用一个新的随机噪声向量z_test来生成一个新的梅尔频谱。
    • generated_mel_spectrogram = generator(z_test)是生成梅尔频谱的过程。
  2. 可视化

    • 使用plt.imshow()来可视化生成的梅尔频谱图,origin='lower'是确保频谱图正确显示。
    • plt.colorbar()添加颜色条,以便更清晰地理解梅尔频谱的数值范围。

结果:

  • 在训练过程中,你会看到每个epoch的损失值,并在最后一次epoch时显示生成的梅尔频谱。
  • 在测试阶段,生成器会基于随机噪声生成一个新的梅尔频谱并进行可视化,帮助你观察最终模型生成的语音特征。

http://www.ppmy.cn/news/1569951.html

相关文章

ES6 变量解构赋值总结

1. 数组的解构赋值 1.1 基本用法 // 基本数组解构 const [a, b, c] [1, 2, 3]; console.log(a); // 1 console.log(b); // 2 console.log(c); // 3// 跳过某些值 const [x, , y] [1, 2, 3]; console.log(x); // 1 console.log(y); // 3// 解构剩余元素 const [first, ...re…

day37|完全背包基础+leetcode 518.零钱兑换II ,377.组合总和II

完全背包理论基础 完全背包与01背包的不同在于01背包的不同物品每个都只可以使用一次,但是完全背包的不同物品可以使用无数次 在01背包理论基础中,为了使得物品只被使用一次,我们采取倒序遍历来控制 回顾:>> for(int j …

STM32 DMA数据转运

DMA简介 DMA(Direct Memory Access)直接存储器存取 DMA可以提供外设和存储器或者存储器和存储器之间的高速数据传输,无须CPU干预,节省了CPU的资源 12个独立可配置的通道: DMA1(7个通道)&#xf…

RTMP 和 WebRTC

WebRTC(Web Real-Time Communication)和 RTMP(Real-Time Messaging Protocol)是两种完全不同的流媒体协议,设计目标、协议栈、交互流程和应用场景均有显著差异。以下是两者的详细对比,涵盖协议字段、交互流程及核心设计思想。 一、协议栈与设计目标对比 特性RTMPWebRTC传…

Spring Web MVC基础第一篇

目录 1.什么是Spring Web MVC? 2.创建Spring Web MVC项目 3.注解使用 3.1RequestMapping(路由映射) 3.2一般参数传递 3.3RequestParam(参数重命名) 3.4RequestBody(传递JSON数据) 3.5Pa…

JVM执行流程与架构(对应不同版本JDK)

直接上图(对应JDK8以及以后的HotSpot) 这里主要区分说明一下 方法区于 字符串常量池 的位置更迭: 方法区 JDK7 以及之前的版本将方法区存放在堆区域中的 永久代空间,堆的大小由虚拟机参数来控制。 JDK8 以及之后的版本将方法…

python-leetcode-验证二叉搜索树

98. 验证二叉搜索树 - 力扣(LeetCode) # Definition for a binary tree node. # class TreeNode: # def __init__(self, val0, leftNone, rightNone): # self.val val # self.left left # self.right right class Soluti…

信息安全、网络安全和数据安全的区别和联系

一、区别 1.信息安全 定义 信息安全是指为数据处理系统建立和采用的技术和管理的安全保护,保护计算机硬件、软件和数据不因偶然和恶意的原因而遭到破坏、更改和泄露。它的范围比较广泛,涵盖了信息的保密性、完整性和可用性等多个方面。 侧重点 更强…