时间序列生成数据,TransformerGAN

news/2024/10/22 16:30:52/

        简介:这个代码可以用于时间序列修复和生成。使用transformer提取单变量或者多变时间窗口的趋势分布情况。然后使用GAN生成分布类似的时间序列。

        此外,还实现了基于prompt的数据生成,比如指定生成某个月份的数据、某半个月的数据、某一个星期的数据。

1、模型架构

        如下图所示,生成器和鉴别器都使用Transformer的编码器部分提取时间序列的特征,然后鉴别器使用这些进行二分类、生成器使用这些特征生成伪造的数据。

        重点:在下面的图的基础上,我还添加了基于提示的生成代码,类似于AI提示绘画一样,因此可以指定生成一月份、二月份等任意指定周期的数据。

2、训练GAN的代码

        下面是GAN的训练部分。

python"># 训练GAN
num_epochs = 100
for epoch in range(num_epochs):for real_x,x_g,zz in loader: # 分别是真实值real_x、提示词信息x_g、噪声zzreal_data = real_xnoisy_data = x_g# Train Discriminatoroptimizer_D.zero_grad()out = discriminator(real_data)real_loss = criterion(discriminator(real_data), torch.ones(real_data.size(0), 1))fake_data = generator(noisy_data,zz)fake_loss = criterion(discriminator(fake_data.detach()), torch.zeros(fake_data.size(0), 1))d_loss = real_loss + fake_lossd_loss.backward()optimizer_D.step()# Train Generatoroptimizer_G.zero_grad()g_loss = criterion(discriminator(fake_data), torch.ones(fake_data.size(0), 1))g_loss.backward()optimizer_G.step()print(f'Epoch [{epoch+1}/{num_epochs}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}')

3、生成器代码

python">class Generator(nn.Module):def __init__(self, seq_len=8, patch_size=2, channels=1, num_classes=9, latent_dim=100, embed_dim=10, depth=1,num_heads=5, forward_drop_rate=0.5, attn_drop_rate=0.5):super(Generator, self).__init__()self.channels = channelsself.latent_dim = latent_dimself.seq_len = seq_lenself.embed_dim = embed_dimself.patch_size = patch_sizeself.depth = depthself.attn_drop_rate = attn_drop_rateself.forward_drop_rate = forward_drop_rateself.l1 = nn.Linear(self.latent_dim, self.seq_len * self.embed_dim)self.pos_embed = nn.Parameter(torch.zeros(1, self.seq_len, self.embed_dim))self.blocks = Gen_TransformerEncoder(depth=self.depth,emb_size = self.embed_dim,drop_p = self.attn_drop_rate,)self.deconv = nn.Sequential(nn.Conv2d(self.embed_dim, self.channels, 1, 1, 0))def forward(self, z):x = self.l1(z).view(-1, self.seq_len, self.embed_dim)x = x + self.pos_embedH, W = 1, self.seq_lenx = self.blocks(x)x = x.reshape(x.shape[0], 1, x.shape[1], x.shape[2])output = self.deconv(x.permute(0, 3, 1, 2))output = output.view(-1, self.channels, H, W)return output

4、生成数据和真实数据分布对比

        使用PCA和TSNE对生成的时间窗口数据进行降维,然后scatter这些二维点。如果生成的真实数据的互相混合在一起,说明模型学习到了真东西,也就是模型伪造的数据和真实数据分布是一样的,美滋滋。从下面的PCA可以看出,两者的分布还是近似的。

        进一步的,可以拟合两个二维正态分布,然后计算他们的KL散度作为一个评价指标。

5、生成数据展示

        上面是真实数据、下面是伪造的数据。由于只有几百个样本,以及参数都没有进行调整,但是效果还不错。

6、损失函数变化情况

        模型还是学习到了一点东西的。


http://www.ppmy.cn/news/1442443.html

相关文章

【GitHub】主页简历优化

【github主页】优化简历 写在最前面一、新建秘密仓库二、插件卡片配置1、仓库状态统计2、Most used languages(GitHub 常用语言统计)使用细则 3、Visitor Badge(GitHub 访客徽章)4、社交统计5、打字特效6、省略展示小猫 &#x1f…

HTTP如何自动跳转到HTTPS,免费SSL证书如何获取

如今HTTPS已经成为了网站标配,然而,对于一些刚刚起步的网站或是个人博客而言,如何自动跳转到HTTPS,以及免费SSL证书的获取,可能还是一个需要解决的问题。下面就来详细解答这两个问题。 我们需要先了解HTTP与HTTPS的区…

FreeBSD系统安装Tlsla P4专业AI计算卡(待续)

首先打开FreeBSD下的linux兼容服务 sysrc linux_enable"YES" service linux start 安装nvidia英伟达驱动 # 由shkhln提供,官网:https://github.com/shkhln/libc6-shim pkg install libc6-shim# nvidia驱动占用空间较大,需要2G空…

Win linux 下配置adb fastboot

一、Win配置adb & fastboot 环境变量 主机:Win10,除了adb fastboot需要设置变量之外,驱动直接安装即可 win下adb fastboot 下载地址:https://download.csdn.net/download/u012627628/89215420 win下qcom设备驱动下载地址&a…

最大连续1的个数 ||| ---- 滑动窗口

题目链接 题目: 分析: 题目中说可以将最多k个0翻转成1, 如果我们真的这样算就会十分麻烦, 所以我们可以换一种思路: 找到一个最长的子数组, 最多有k个0解法一: 暴力解法: 找到所有的最多有k个0的子字符串, 返回最长的解法二: 找到最长的子数组, 我们可以想到"滑动窗口算…

模拟LinkedList实现的双向循环链表

1. 前言 前文我们分别实现了不带哨兵的单链表,带哨兵节点的双向链表,接着我们实现带哨兵节点的双向循环链表.双向循环链表只需一个哨兵节点,该节点的prev指针和next指针都指向了自身哨兵节点. 2. 实现双向循环链表的代码 例 : //模拟双向…

三个目前主流的计算机视觉软件

计算机视觉是人工智能的一个重要分支,它涉及到使计算机能够理解和解释图像和视频数据。近年来,计算机视觉领域取得了显著的进展,尤其是在深度学习的帮助下。尽管如此,将计算机视觉的能力直接与人类的视觉能力进行比较并不完全准确…

深度学习之基础模型——循环神经网络RNN

相关资料 (1)What are Recurrent Neural Networks? | IBM (2)浅析循环神经网络(RNN)的反向求导过程 - 知乎 (zhihu.com) 总共有四篇 (3)循环神经网络(RNN)浅析 - 简书 (jianshu.co…