Hung-yi Lee (李宏毅) 2020 HW9: Unsupervised Learning: Dimensionality Reduction, Clustering, Autoencoders

news/2024/11/17 6:48:40/


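Contents

  • Dataset
  • Assignment
    • Task 1
    • Task 2
    • Task 3
    • Data
  • Downloading the dataset
  • Preparing the training data
  • Utility functions
  • Model
  • Training
  • Dimensionality reduction and clustering
  • Question 1 (plotting)
  • Question 2
  • Question 3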

Dataset

  • valX.npy
  • valY.npy
  • trainX_new.npy

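Assignment

Task 1

Use at least two methods (for example the autoencoder architecture, the optimizer, data preprocessing, the downstream dimensionality-reduction method, or the clustering algorithm) to improve the accuracy of the baseline code.

  • Report the accuracy before and after the improvement.
  • For both the original and the improved method, plot the dimensionality-reduced val data (the embeddings) colored by their corresponding labels.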
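A minimal sketch of the two Task 1 ingredients, assuming valY holds binary labels (so the clustering step produces two clusters) and that the encoder's embeddings arrive as a NumPy array of shape (N, ...). `pca_2d` and `cal_acc` are hypothetical helper names; PCA via SVD is used here only as a simple 2-D projection for plotting, not necessarily the reduction method the baseline itself uses.

```python
import numpy as np

def pca_2d(embeddings):
    """Project (N, ...) embeddings onto their top-2 principal components via SVD."""
    x = embeddings.reshape(len(embeddings), -1).astype(np.float64)
    x = x - x.mean(axis=0)                      # center each feature
    # Rows of vt are the principal directions, sorted by singular value
    _, _, vt = np.linalg.svd(x, full_matrices=False)
    return x @ vt[:2].T                          # (N, 2) coordinates for plotting

def cal_acc(gt, pred):
    """Accuracy of a 2-cluster assignment whose cluster ids are arbitrary.

    Cluster 0 may correspond to label 1, so take the better of the two
    possible label matchings.
    """
    acc = np.mean(gt == pred)
    return max(acc, 1 - acc)
```

With `xy = pca_2d(latents)`, a call such as `plt.scatter(xy[:, 0], xy[:, 1], c=valY)` colors each point by its label. Taking `max(acc, 1 - acc)` matters because a clustering algorithm may name the positive class either 0 or 1.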

Task 2

Using your autoencoder with the highest accuracy, take the 6 images at indices 1, 2, 3, 6, 7, 9 from trainX.

  • Plot their original images alongside the reconstructed images.
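A sketch for laying out the Task 2 figure with NumPy alone, assuming the reconstructions have already been converted back to (K, 32, 32, 3) uint8 images (the inverse of `preprocess`). `make_comparison_grid` is a hypothetical helper name.

```python
import numpy as np

def make_comparison_grid(originals, reconstructions):
    """Tile images into a 2-row grid: originals on top, reconstructions below.

    Both inputs are (K, H, W, C) uint8 arrays in the same order.
    """
    assert originals.shape == reconstructions.shape
    top = np.concatenate(list(originals), axis=1)          # (H, K*W, C)
    bottom = np.concatenate(list(reconstructions), axis=1) # (H, K*W, C)
    return np.concatenate([top, bottom], axis=0)           # (2*H, K*W, C)

# Indices required by Task 2
indexes = [1, 2, 3, 6, 7, 9]
```

`plt.imshow(make_comparison_grid(trainX[indexes], recon))` would then show each original directly above its reconstruction.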

Task 3

Pick at least 10 checkpoints from the autoencoder's training process.

  • For those checkpoints, plot the model's reconstruction error (MSE computed over all of trainX) and its val accuracy.
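A sketch of the Task 3 bookkeeping. It assumes checkpoints are named the way the code in the Training section saves them (`./checkpoints/checkpoint_{epoch}.pth`, every 10 epochs, 1000 epochs in total). `pick_checkpoints` and `reconstruction_mse` are hypothetical helpers. For each chosen path you would load the state dict, reconstruct all of trainX in the preprocessed [-1, 1] space, compute `reconstruction_mse`, compute val accuracy with the Task 1 pipeline, and then plot both lists against the checkpoint epochs.

```python
import numpy as np

def pick_checkpoints(n_epoch=1000, save_every=10, k=10):
    """Return k evenly spaced checkpoint paths among those saved every `save_every` epochs."""
    saved = list(range(save_every, n_epoch + 1, save_every))  # 10, 20, ..., n_epoch
    step = max(1, len(saved) // k)
    chosen = saved[step - 1::step][:k]
    return ['./checkpoints/checkpoint_{}.pth'.format(e) for e in chosen]

def reconstruction_mse(original, reconstructed):
    """Mean squared error over every element of every image.

    Matches nn.MSELoss with its default reduction='mean' when applied to the
    whole dataset at once.
    """
    diff = original.astype(np.float64) - reconstructed.astype(np.float64)
    return float(np.mean(diff ** 2))
```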

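Data

Load the data with np.load(). valX.npy and valY.npy are only for checking how well training worked; they must not be used for training.

trainX.npy

  • 8500 RGB images in total, each of size 32 * 32 * 3
  • shape (8500, 32, 32, 3)

valX.npy

  • Do not use for training
  • 500 RGB images in total, each of size 32 * 32 * 3
  • shape (500, 32, 32, 3)

valY.npy

  • Do not use for training
  • Labels corresponding to valX.npy
  • shape (500,)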
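Because the assignment states exact shapes, a quick check right after `np.load()` catches a wrong or truncated file early. A minimal sketch; `check_shapes` is a hypothetical helper and the expected values are the ones listed above.

```python
import numpy as np

def check_shapes(trainX, valX, valY):
    """Sanity-check the loaded arrays against the shapes stated in the assignment."""
    assert trainX.shape == (8500, 32, 32, 3), trainX.shape
    assert valX.shape == (500, 32, 32, 3), valX.shape
    assert valY.shape == (500,), valY.shape
    # Raw pixels should still be in 0..255 before preprocessing
    assert trainX.min() >= 0 and trainX.max() <= 255
```

Usage: `check_shapes(np.load('trainX.npy'), np.load('valX.npy'), np.load('valY.npy'))`.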

Downloading the dataset

Create the checkpoints folder

```
# !gdown --id '1BZb2AqOHHaad7Mo82St1qTBaXo_xtcUc' --output trainX.npy
# !gdown --id '152NKCpj8S_zuIx3bQy0NN5oqpvBjdPIq' --output valX.npy
# !gdown --id '1_hRGsFtm5KEazUg2ZvPZcuNScGF-ANh4' --output valY.npy
!mkdir checkpoints
!ls
```
mkdir: cannot create directory 'checkpoints': File exists
checkpoints	       trainX.npy
p1_baseline.png        valX.npy
prediction.csv	       valY.npy
prediction_invert.csv  李宏毅机器学习2020-作业9:无监督学习.ipynb
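The mkdir error in the output above only means the folder already exists from an earlier run. An idempotent alternative in Python (or `!mkdir -p checkpoints` in the notebook) avoids it; `ensure_dir` is a hypothetical helper name.

```python
import os

def ensure_dir(path='checkpoints'):
    """Create the directory if needed; unlike plain mkdir, re-running is not an error."""
    os.makedirs(path, exist_ok=True)
    return os.path.isdir(path)
```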

Preparing the training data

Define our preprocess step: linearly map the image values from integers in 0~255 to floats in -1~1.

```python
import numpy as np

def preprocess(image_list):
    """ Normalize Image and Permute (N,H,W,C) to (N,C,H,W)
    Args:
      image_list: List of images (N, 32, 32, 3)
    Returns:
      image_list: List of images (N, 3, 32, 32)
    """
    image_list = np.array(image_list)
    image_list = np.transpose(image_list, (0, 3, 1, 2))  # channels-last -> channels-first
    image_list = (image_list / 255.0) * 2 - 1             # [0, 255] -> [-1, 1]
    image_list = image_list.astype(np.float32)
    return image_list
```
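Task 2 needs the reverse mapping to display reconstructions. A sketch of that inverse, assuming decoder outputs come back as a (N, 3, 32, 32) float array in [-1, 1] (the decoder ends in Tanh); `postprocess` is a hypothetical name.

```python
import numpy as np

def postprocess(image_array):
    """Inverse of preprocess: (N, C, H, W) floats in [-1, 1] -> (N, H, W, C) uint8 in [0, 255]."""
    x = np.transpose(image_array, (0, 2, 3, 1))     # back to channels-last
    x = (x + 1) / 2 * 255                           # undo the [-1, 1] scaling
    return np.clip(np.rint(x), 0, 255).astype(np.uint8)
```

Rounding before the uint8 cast avoids off-by-one pixel values that plain truncation of float results can cause.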

Custom Dataset

```python
from torch.utils.data import Dataset

class Image_Dataset(Dataset):
    def __init__(self, image_list):
        self.image_list = image_list

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        images = self.image_list[idx]
        return images
```

Load the training data and preprocess it, then wrap the preprocessed data in the dataset we need. Do not use valX and valY for training.

```python
import numpy as np
from torch.utils.data import DataLoader

trainX = np.load('trainX.npy')
trainX_preprocessed = preprocess(trainX)
img_dataset = Image_Dataset(trainX_preprocessed)
```

Utility functions

Here are some useful functions: one counts the model's parameters (needed for the report), and the other fixes the random seeds used in training (for reproducibility).

```python
import random
import numpy as np
import torch

def count_parameters(model, only_trainable=False):
    if only_trainable:
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
    else:
        return sum(p.numel() for p in model.parameters())

def same_seeds(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    np.random.seed(seed)  # Numpy module.
    random.seed(seed)  # Python random module.
    torch.backends.cudnn.benchmark = False  # disable cuDNN auto-tuning of convolution algorithms
    torch.backends.cudnn.deterministic = True  # always use the same (deterministic) convolution algorithm
```
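As a cross-check for the report, the parameter count of the AE defined in the Model section can be worked out by hand: each Conv2d or ConvTranspose2d layer (groups=1, with bias) has in × out × k × k weights plus out biases, while ReLU, MaxPool2d and Tanh have none. A sketch with the layer sizes copied from that definition; `conv_params` is a hypothetical helper, and `count_parameters(AE())` should report the same total as long as the architecture is unchanged.

```python
def conv_params(c_in, c_out, k):
    """Weights (c_in * c_out * k * k) plus one bias per output channel.

    The same count holds for nn.Conv2d and nn.ConvTranspose2d with groups=1, bias=True.
    """
    return c_in * c_out * k * k + c_out

# (in_channels, out_channels, kernel_size) copied from the AE definition
encoder_layers = [(3, 64, 3), (64, 128, 3), (128, 256, 3), (256, 512, 3)]
decoder_layers = [(512, 256, 3), (256, 128, 5), (128, 64, 9), (64, 3, 17)]

encoder_total = sum(conv_params(*layer) for layer in encoder_layers)
decoder_total = sum(conv_params(*layer) for layer in decoder_layers)
print(encoder_total, decoder_total, encoder_total + decoder_total)
# prints 1550976 2718339 4269315
```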

Model

Define our baseline autoencoder.
ConvTranspose2d: transposed convolution (sometimes called deconvolution)

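For the model improvement, I only made the encoder and decoder one layer deeper, which improved the results. The only hyperparameter I changed was the number of epochs, which I set to 1000.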
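A worked shape trace for this architecture, using the output-size formula from the PyTorch ConvTranspose2d documentation (`conv_transpose_out` is a hypothetical helper). Four 2×2 max-pools take 32 down to 2; with stride 1 and no padding, each transposed conv adds kernel_size - 1, so kernels 3, 5, 9, 17 take 2 back up to 32.

```python
def conv_transpose_out(h_in, kernel, stride=1, padding=0, output_padding=0, dilation=1):
    """Output size formula from the PyTorch ConvTranspose2d documentation."""
    return (h_in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

# Encoder: each 3x3 conv with padding=1 keeps the size, each MaxPool2d(2) halves it
h = 32
for _ in range(4):
    h //= 2
print('latent spatial size:', h)  # 2, so the latent code is 512 x 2 x 2

# Decoder: stride-1 transposed convs grow the size by (kernel - 1)
sizes = []
for k in (3, 5, 9, 17):
    h = conv_transpose_out(h, k)
    sizes.append(h)
print('decoder sizes:', sizes)  # [4, 8, 16, 32]
```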

```python
import torch.nn as nn

class AE(nn.Module):
    def __init__(self):
        super(AE, self).__init__()
        # Encoder: 3x32x32 -> 512x2x2 (each MaxPool2d(2) halves height and width)
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(2),
            nn.Conv2d(256, 512, 3, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(2)
        )
        # Decoder: 512x2x2 -> 3x32x32 (each stride-1 transposed conv grows the size by kernel_size - 1)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 3, stride=1),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 5, stride=1),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 9, stride=1),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 17, stride=1),
            nn.Tanh()  # outputs in [-1, 1], the same range as preprocess
        )

    def forward(self, x):
        x1 = self.encoder(x)  # latent code
        x = self.decoder(x1)  # reconstruction
        return x1, x
```

```
!nvidia-smi
```
Thu Nov  4 17:03:39 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.91.03    Driver Version: 460.91.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  GeForce RTX 3090    Off  | 00000000:1A:00.0 Off |                  N/A |
| 57%   70C    P2   325W / 350W |   8107MiB / 24268MiB |     91%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 3090    Off  | 00000000:68:00.0 Off |                  N/A |
|  0%   29C    P8    25W / 350W |    299MiB / 24265MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A     11145      C   python                           8103MiB |
|    1   N/A  N/A      2432      G   /usr/lib/xorg/Xorg                 14MiB |
|    1   N/A  N/A      4006      G   /usr/bin/gnome-shell               17MiB |
|    1   N/A  N/A      4984      G   /usr/lib/xorg/Xorg                 70MiB |
|    1   N/A  N/A      5058      G   /usr/lib/xorg/Xorg                 18MiB |
|    1   N/A  N/A      5233      G   /usr/bin/gnome-shell              100MiB |
|    1   N/A  N/A      5384      G   /usr/bin/gnome-shell               36MiB |
|    1   N/A  N/A      6548      G   ...2179,14311511775341437302       36MiB |
+-----------------------------------------------------------------------------+


Training

This is the main training stage. We first pass the prepared dataset to the dataloader. Once the dataloader, model, loss criterion and optimizer are ready, training can start. When training finishes, we save the model.
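Note that the `loss` value printed in the training log is `epoch_loss`, the sum of `loss.item()` over every batch in the epoch, not a per-pixel MSE. With 8500 images and batch_size=64 there are 133 batches (the last one holds 52 images), so dividing by 133 gives an approximate average batch MSE. A small sketch; `mean_batch_loss` is a hypothetical helper.

```python
import math

n_train = 8500   # images in trainX
batch_size = 64  # as in the DataLoader call
num_batches = math.ceil(n_train / batch_size)  # 133; the last batch has 8500 - 132 * 64 = 52 images

def mean_batch_loss(epoch_loss, num_batches=num_batches):
    """Convert the printed per-epoch sum of batch losses into an average per-batch MSE.

    Approximate: the smaller last batch is weighted the same as the full ones.
    """
    return epoch_loss / num_batches
```

For example, the first epoch's 30.54165 corresponds to about 0.230 per batch, and epoch 889's 2.96926 to about 0.0223. This also means the printed numbers are not directly comparable across runs that use a different batch size.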

```python
import torch
from torch import optim

same_seeds(0)

# Prepare the dataloader, model, loss criterion and optimizer
model = AE().cuda()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-5)
n_epoch = 1000
img_dataloader = DataLoader(img_dataset, batch_size=64, shuffle=True)

# Main training loop
model.train()
for epoch in range(n_epoch):
    epoch_loss = 0
    for data in img_dataloader:
        img = data.cuda()
        output1, output = model(img)
        loss = criterion(output, img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    # Save a checkpoint every 10 epochs (once, after all batches of the epoch)
    if (epoch + 1) % 10 == 0:
        torch.save(model.state_dict(), './checkpoints/checkpoint_{}.pth'.format(epoch + 1))
    print('epoch [{}/{}], loss:{:.5f}'.format(epoch + 1, n_epoch, epoch_loss))

# Save the model after training
torch.save(model.state_dict(), './checkpoints/last_checkpoint.pth')
```
epoch [1/1000], loss:30.54165
epoch [2/1000], loss:26.34405
epoch [3/1000], loss:21.83250
epoch [4/1000], loss:19.13653
epoch [5/1000], loss:16.89123
epoch [6/1000], loss:15.81137
epoch [7/1000], loss:15.24495
epoch [8/1000], loss:14.82142
epoch [9/1000], loss:14.43517
epoch [10/1000], loss:14.08439
epoch [11/1000], loss:13.73920
epoch [12/1000], loss:13.40639
epoch [13/1000], loss:13.08327
epoch [14/1000], loss:12.66554
epoch [15/1000], loss:12.26715
epoch [16/1000], loss:11.93717
epoch [17/1000], loss:11.67487
epoch [18/1000], loss:11.45737
epoch [19/1000], loss:11.28208
epoch [20/1000], loss:11.08628
epoch [21/1000], loss:10.94622
epoch [22/1000], loss:10.80847
epoch [23/1000], loss:10.70417
epoch [24/1000], loss:10.58255
epoch [25/1000], loss:10.48495
epoch [26/1000], loss:10.39527
epoch [27/1000], loss:10.30006
epoch [28/1000], loss:10.20910
epoch [29/1000], loss:10.13124
epoch [30/1000], loss:10.04456
epoch [31/1000], loss:9.96836
epoch [32/1000], loss:9.88246
epoch [33/1000], loss:9.81235
epoch [34/1000], loss:9.72425
epoch [35/1000], loss:9.65545
epoch [36/1000], loss:9.57657
epoch [37/1000], loss:9.51310
epoch [38/1000], loss:9.45421
epoch [39/1000], loss:9.38250
epoch [40/1000], loss:9.31712
epoch [41/1000], loss:9.25833
epoch [42/1000], loss:9.20196
epoch [43/1000], loss:9.14868
epoch [44/1000], loss:9.08939
epoch [45/1000], loss:9.02597
epoch [46/1000], loss:8.95911
epoch [47/1000], loss:8.91480
epoch [48/1000], loss:8.86116
epoch [49/1000], loss:8.79443
epoch [50/1000], loss:8.73779
epoch [51/1000], loss:8.68570
epoch [52/1000], loss:8.62910
epoch [53/1000], loss:8.57338
epoch [54/1000], loss:8.53807
epoch [55/1000], loss:8.48156
epoch [56/1000], loss:8.43463
epoch [57/1000], loss:8.39641
epoch [58/1000], loss:8.34074
epoch [59/1000], loss:8.30465
epoch [60/1000], loss:8.27341
epoch [61/1000], loss:8.23230
epoch [62/1000], loss:8.18089
epoch [63/1000], loss:8.15129
epoch [64/1000], loss:8.11520
epoch [65/1000], loss:8.07959
epoch [66/1000], loss:8.04687
epoch [67/1000], loss:8.02380
epoch [68/1000], loss:7.98933
epoch [69/1000], loss:7.95649
epoch [70/1000], loss:7.92910
epoch [71/1000], loss:7.88972
epoch [72/1000], loss:7.85813
epoch [73/1000], loss:7.82851
epoch [74/1000], loss:7.81065
epoch [75/1000], loss:7.78497
epoch [76/1000], loss:7.73110
epoch [77/1000], loss:7.71461
epoch [78/1000], loss:7.68887
epoch [79/1000], loss:7.65523
epoch [80/1000], loss:7.63705
epoch [81/1000], loss:7.61096
epoch [82/1000], loss:7.57877
epoch [83/1000], loss:7.54703
epoch [84/1000], loss:7.52961
epoch [85/1000], loss:7.48876
epoch [86/1000], loss:7.46642
epoch [87/1000], loss:7.43804
epoch [88/1000], loss:7.41458
epoch [89/1000], loss:7.38298
epoch [90/1000], loss:7.38157
epoch [91/1000], loss:7.34053
epoch [92/1000], loss:7.32307
epoch [93/1000], loss:7.28897
epoch [94/1000], loss:7.27476
epoch [95/1000], loss:7.25432
epoch [96/1000], loss:7.23210
epoch [97/1000], loss:7.20764
epoch [98/1000], loss:7.17726
epoch [99/1000], loss:7.16785
epoch [100/1000], loss:7.14477
epoch [101/1000], loss:7.12776
epoch [102/1000], loss:7.10490
epoch [103/1000], loss:7.08108
epoch [104/1000], loss:7.06430
epoch [105/1000], loss:7.04382
epoch [106/1000], loss:7.01336
epoch [107/1000], loss:7.00099
epoch [108/1000], loss:6.97758
epoch [109/1000], loss:6.95376
epoch [110/1000], loss:6.94354
epoch [111/1000], loss:6.91744
epoch [112/1000], loss:6.91015
epoch [113/1000], loss:6.88055
epoch [114/1000], loss:6.86521
epoch [115/1000], loss:6.84671
epoch [116/1000], loss:6.82973
epoch [117/1000], loss:6.80817
epoch [118/1000], loss:6.78769
epoch [119/1000], loss:6.77140
epoch [120/1000], loss:6.76178
epoch [121/1000], loss:6.74296
epoch [122/1000], loss:6.71641
epoch [123/1000], loss:6.69564
epoch [124/1000], loss:6.67923
epoch [125/1000], loss:6.66339
epoch [126/1000], loss:6.64667
epoch [127/1000], loss:6.62993
epoch [128/1000], loss:6.60127
epoch [129/1000], loss:6.58229
epoch [130/1000], loss:6.57563
epoch [131/1000], loss:6.55139
epoch [132/1000], loss:6.53123
epoch [133/1000], loss:6.51448
epoch [134/1000], loss:6.49753
epoch [135/1000], loss:6.46827
epoch [136/1000], loss:6.45886
epoch [137/1000], loss:6.43451
epoch [138/1000], loss:6.41819
epoch [139/1000], loss:6.39429
epoch [140/1000], loss:6.38479
epoch [141/1000], loss:6.36964
epoch [142/1000], loss:6.34008
epoch [143/1000], loss:6.32599
epoch [144/1000], loss:6.30631
epoch [145/1000], loss:6.29071
epoch [146/1000], loss:6.27065
epoch [147/1000], loss:6.25629
epoch [148/1000], loss:6.23477
epoch [149/1000], loss:6.22027
epoch [150/1000], loss:6.20892
epoch [151/1000], loss:6.18379
epoch [152/1000], loss:6.16717
epoch [153/1000], loss:6.15294
epoch [154/1000], loss:6.13922
epoch [155/1000], loss:6.12273
epoch [156/1000], loss:6.09983
epoch [157/1000], loss:6.09613
epoch [158/1000], loss:6.08098
epoch [159/1000], loss:6.06648
epoch [160/1000], loss:6.05687
epoch [161/1000], loss:6.03163
epoch [162/1000], loss:6.00917
epoch [163/1000], loss:6.00572
epoch [164/1000], loss:5.99157
epoch [165/1000], loss:5.97707
epoch [166/1000], loss:5.96627
epoch [167/1000], loss:5.96171
epoch [168/1000], loss:5.93227
epoch [169/1000], loss:5.92656
epoch [170/1000], loss:5.92673
epoch [171/1000], loss:5.90135
epoch [172/1000], loss:5.89017
epoch [173/1000], loss:5.87263
epoch [174/1000], loss:5.86483
epoch [175/1000], loss:5.85099
epoch [176/1000], loss:5.83615
epoch [177/1000], loss:5.83101
epoch [178/1000], loss:5.82030
epoch [179/1000], loss:5.82544
epoch [180/1000], loss:5.78977
epoch [181/1000], loss:5.78293
epoch [182/1000], loss:5.77460
epoch [183/1000], loss:5.76192
epoch [184/1000], loss:5.75049
epoch [185/1000], loss:5.74188
epoch [186/1000], loss:5.73882
epoch [187/1000], loss:5.72205
epoch [188/1000], loss:5.70864
epoch [189/1000], loss:5.70273
epoch [190/1000], loss:5.69353
epoch [191/1000], loss:5.68343
epoch [192/1000], loss:5.67216
epoch [193/1000], loss:5.66239
epoch [194/1000], loss:5.65125
epoch [195/1000], loss:5.63932
epoch [196/1000], loss:5.63388
epoch [197/1000], loss:5.62116
epoch [198/1000], loss:5.61385
epoch [199/1000], loss:5.61483
epoch [200/1000], loss:5.59609
epoch [201/1000], loss:5.57955
epoch [202/1000], loss:5.57469
epoch [203/1000], loss:5.56383
epoch [204/1000], loss:5.55489
epoch [205/1000], loss:5.54320
epoch [206/1000], loss:5.52971
epoch [207/1000], loss:5.53083
epoch [208/1000], loss:5.51958
epoch [209/1000], loss:5.50594
epoch [210/1000], loss:5.50194
epoch [211/1000], loss:5.49632
epoch [212/1000], loss:5.47642
epoch [213/1000], loss:5.47105
epoch [214/1000], loss:5.46103
epoch [215/1000], loss:5.45136
epoch [216/1000], loss:5.44927
epoch [217/1000], loss:5.43372
epoch [218/1000], loss:5.43229
epoch [219/1000], loss:5.41293
epoch [220/1000], loss:5.40677
epoch [221/1000], loss:5.39713
epoch [222/1000], loss:5.39402
epoch [223/1000], loss:5.38856
epoch [224/1000], loss:5.37551
epoch [225/1000], loss:5.36045
epoch [226/1000], loss:5.35389
epoch [227/1000], loss:5.34672
epoch [228/1000], loss:5.33802
epoch [229/1000], loss:5.33105
epoch [230/1000], loss:5.32277
epoch [231/1000], loss:5.30828
epoch [232/1000], loss:5.29910
epoch [233/1000], loss:5.29399
epoch [234/1000], loss:5.28984
epoch [235/1000], loss:5.27597
epoch [236/1000], loss:5.26934
epoch [237/1000], loss:5.26663
epoch [238/1000], loss:5.25943
epoch [239/1000], loss:5.24395
epoch [240/1000], loss:5.24214
epoch [241/1000], loss:5.23017
epoch [242/1000], loss:5.21525
epoch [243/1000], loss:5.21001
epoch [244/1000], loss:5.20533
epoch [245/1000], loss:5.19778
epoch [246/1000], loss:5.19444
epoch [247/1000], loss:5.17834
epoch [248/1000], loss:5.17032
epoch [249/1000], loss:5.16573
epoch [250/1000], loss:5.16030
epoch [251/1000], loss:5.15691
epoch [252/1000], loss:5.14337
epoch [253/1000], loss:5.13357
epoch [254/1000], loss:5.12614
epoch [255/1000], loss:5.12397
epoch [256/1000], loss:5.11111
epoch [257/1000], loss:5.09905
epoch [258/1000], loss:5.09718
epoch [259/1000], loss:5.09271
epoch [260/1000], loss:5.08443
epoch [261/1000], loss:5.07630
epoch [262/1000], loss:5.06473
epoch [263/1000], loss:5.06329
epoch [264/1000], loss:5.05452
epoch [265/1000], loss:5.04306
epoch [266/1000], loss:5.04899
epoch [267/1000], loss:5.03139
epoch [268/1000], loss:5.02383
epoch [269/1000], loss:5.01982
epoch [270/1000], loss:5.01273
epoch [271/1000], loss:5.00642
epoch [272/1000], loss:4.99454
epoch [273/1000], loss:4.99690
epoch [274/1000], loss:4.98375
epoch [275/1000], loss:4.98370
epoch [276/1000], loss:4.96812
epoch [277/1000], loss:4.96210
epoch [278/1000], loss:4.96167
epoch [279/1000], loss:4.94264
epoch [280/1000], loss:4.94708
epoch [281/1000], loss:4.93381
epoch [282/1000], loss:4.92656
epoch [283/1000], loss:4.92751
epoch [284/1000], loss:4.91519
epoch [285/1000], loss:4.90649
epoch [286/1000], loss:4.90130
epoch [287/1000], loss:4.89965
epoch [288/1000], loss:4.88647
epoch [289/1000], loss:4.88522
epoch [290/1000], loss:4.87119
epoch [291/1000], loss:4.86967
epoch [292/1000], loss:4.86545
epoch [293/1000], loss:4.85670
epoch [294/1000], loss:4.84635
epoch [295/1000], loss:4.84253
epoch [296/1000], loss:4.84705
epoch [297/1000], loss:4.82709
epoch [298/1000], loss:4.82251
epoch [299/1000], loss:4.81915
epoch [300/1000], loss:4.81493
epoch [301/1000], loss:4.80140
epoch [302/1000], loss:4.79302
epoch [303/1000], loss:4.79099
epoch [304/1000], loss:4.78271
epoch [305/1000], loss:4.77509
epoch [306/1000], loss:4.76755
epoch [307/1000], loss:4.76485
epoch [308/1000], loss:4.76169
epoch [309/1000], loss:4.75328
epoch [310/1000], loss:4.74254
epoch [311/1000], loss:4.74224
epoch [312/1000], loss:4.74067
epoch [313/1000], loss:4.72933
epoch [314/1000], loss:4.71486
epoch [315/1000], loss:4.71784
epoch [316/1000], loss:4.70222
epoch [317/1000], loss:4.70290
epoch [318/1000], loss:4.69542
epoch [319/1000], loss:4.69025
epoch [320/1000], loss:4.68246
epoch [321/1000], loss:4.67295
epoch [322/1000], loss:4.67523
epoch [323/1000], loss:4.67207
epoch [324/1000], loss:4.66636
epoch [325/1000], loss:4.64616
epoch [326/1000], loss:4.64512
epoch [327/1000], loss:4.64286
epoch [328/1000], loss:4.63428
epoch [329/1000], loss:4.62759
epoch [330/1000], loss:4.62275
epoch [331/1000], loss:4.61570
epoch [332/1000], loss:4.61228
epoch [333/1000], loss:4.60109
epoch [334/1000], loss:4.60413
epoch [335/1000], loss:4.58950
epoch [336/1000], loss:4.59071
epoch [337/1000], loss:4.58295
epoch [338/1000], loss:4.57782
epoch [339/1000], loss:4.57129
epoch [340/1000], loss:4.56505
epoch [341/1000], loss:4.56037
epoch [342/1000], loss:4.55598
epoch [343/1000], loss:4.54537
epoch [344/1000], loss:4.54019
epoch [345/1000], loss:4.53571
epoch [346/1000], loss:4.53185
epoch [347/1000], loss:4.53183
epoch [348/1000], loss:4.52009
epoch [349/1000], loss:4.51411
epoch [350/1000], loss:4.50916
epoch [351/1000], loss:4.50595
epoch [352/1000], loss:4.50171
epoch [353/1000], loss:4.49431
epoch [354/1000], loss:4.48945
epoch [355/1000], loss:4.48904
epoch [356/1000], loss:4.47484
epoch [357/1000], loss:4.47601
epoch [358/1000], loss:4.46283
epoch [359/1000], loss:4.46043
epoch [360/1000], loss:4.45623
epoch [361/1000], loss:4.45144
epoch [387/1000], loss:4.32588
epoch [388/1000], loss:4.31738
epoch [389/1000], loss:4.31798
epoch [390/1000], loss:4.31714
epoch [391/1000], loss:4.30985
epoch [392/1000], loss:4.29957
epoch [393/1000], loss:4.29696
epoch [394/1000], loss:4.29420
epoch [395/1000], loss:4.28667
epoch [396/1000], loss:4.28612
epoch [397/1000], loss:4.27635
epoch [398/1000], loss:4.27332
epoch [399/1000], loss:4.27225
epoch [400/1000], loss:4.26569
epoch [401/1000], loss:4.26683
epoch [402/1000], loss:4.25562
epoch [403/1000], loss:4.24940
epoch [404/1000], loss:4.24415
epoch [405/1000], loss:4.24422
epoch [406/1000], loss:4.24053
epoch [407/1000], loss:4.23612
epoch [408/1000], loss:4.23212
epoch [409/1000], loss:4.23014
epoch [410/1000], loss:4.22054
epoch [411/1000], loss:4.21572
epoch [412/1000], loss:4.21339
epoch [413/1000], loss:4.20922
epoch [414/1000], loss:4.20910
epoch [415/1000], loss:4.20353
epoch [416/1000], loss:4.19610
epoch [417/1000], loss:4.19232
epoch [418/1000], loss:4.18926
epoch [419/1000], loss:4.18134
epoch [420/1000], loss:4.17638
epoch [421/1000], loss:4.17397
epoch [422/1000], loss:4.17142
epoch [423/1000], loss:4.16676
epoch [424/1000], loss:4.17102
epoch [425/1000], loss:4.15542
epoch [426/1000], loss:4.15438
epoch [427/1000], loss:4.15161
epoch [428/1000], loss:4.14431
epoch [429/1000], loss:4.14308
epoch [430/1000], loss:4.14248
epoch [431/1000], loss:4.13705
epoch [432/1000], loss:4.13069
epoch [433/1000], loss:4.12359
epoch [434/1000], loss:4.12440
epoch [435/1000], loss:4.12047
epoch [436/1000], loss:4.11715
epoch [437/1000], loss:4.11095
epoch [438/1000], loss:4.10556
epoch [439/1000], loss:4.10342
epoch [440/1000], loss:4.10314
epoch [441/1000], loss:4.09450
epoch [442/1000], loss:4.08683
epoch [443/1000], loss:4.08545
epoch [444/1000], loss:4.08673
epoch [445/1000], loss:4.07830
epoch [446/1000], loss:4.07518
epoch [447/1000], loss:4.06704
epoch [448/1000], loss:4.06815
epoch [449/1000], loss:4.06158
epoch [450/1000], loss:4.06410
epoch [451/1000], loss:4.05870
epoch [452/1000], loss:4.05462
epoch [453/1000], loss:4.04799
epoch [454/1000], loss:4.04455
epoch [455/1000], loss:4.03678
epoch [456/1000], loss:4.04038
epoch [457/1000], loss:4.03390
epoch [458/1000], loss:4.02727
epoch [459/1000], loss:4.02408
epoch [460/1000], loss:4.02337
epoch [461/1000], loss:4.01824
epoch [462/1000], loss:4.01433
epoch [463/1000], loss:4.00995
epoch [464/1000], loss:4.00826
epoch [465/1000], loss:4.00209
epoch [466/1000], loss:4.00384
epoch [467/1000], loss:3.99173
epoch [468/1000], loss:3.99856
epoch [469/1000], loss:3.99148
epoch [470/1000], loss:3.98304
epoch [471/1000], loss:3.98313
epoch [472/1000], loss:3.97725
epoch [473/1000], loss:3.97736
epoch [474/1000], loss:3.97326
epoch [475/1000], loss:3.96900
epoch [476/1000], loss:3.96096
epoch [477/1000], loss:3.96076
epoch [478/1000], loss:3.96005
epoch [479/1000], loss:3.95441
epoch [480/1000], loss:3.95287
epoch [481/1000], loss:3.94587
epoch [482/1000], loss:3.94024
epoch [483/1000], loss:3.93922
epoch [484/1000], loss:3.93559
epoch [485/1000], loss:3.93831
epoch [486/1000], loss:3.92520
epoch [487/1000], loss:3.92634
epoch [488/1000], loss:3.92151
epoch [489/1000], loss:3.91649
epoch [490/1000], loss:3.91573
epoch [491/1000], loss:3.91516
epoch [492/1000], loss:3.90679
epoch [493/1000], loss:3.90961
epoch [494/1000], loss:3.89975
epoch [495/1000], loss:3.89675
epoch [496/1000], loss:3.89311
epoch [497/1000], loss:3.89344
epoch [498/1000], loss:3.89109
epoch [499/1000], loss:3.88556
epoch [500/1000], loss:3.87982
epoch [501/1000], loss:3.87826
epoch [502/1000], loss:3.87651
epoch [503/1000], loss:3.87134
epoch [504/1000], loss:3.86625
epoch [505/1000], loss:3.86563
epoch [506/1000], loss:3.86109
epoch [507/1000], loss:3.86168
epoch [508/1000], loss:3.85732
epoch [509/1000], loss:3.84998
epoch [510/1000], loss:3.85233
epoch [511/1000], loss:3.84760
epoch [512/1000], loss:3.84713
epoch [513/1000], loss:3.83537
epoch [514/1000], loss:3.83900
epoch [515/1000], loss:3.82796
epoch [516/1000], loss:3.82622
epoch [517/1000], loss:3.83100
epoch [518/1000], loss:3.82413
epoch [519/1000], loss:3.81903
epoch [520/1000], loss:3.81732
epoch [521/1000], loss:3.81084
epoch [522/1000], loss:3.81144
epoch [523/1000], loss:3.80305
epoch [524/1000], loss:3.80411
epoch [525/1000], loss:3.80302
epoch [526/1000], loss:3.79430
epoch [527/1000], loss:3.79282
epoch [528/1000], loss:3.79408
epoch [529/1000], loss:3.79307
epoch [530/1000], loss:3.78673
epoch [531/1000], loss:3.78254
epoch [532/1000], loss:3.77649
epoch [533/1000], loss:3.77460
epoch [534/1000], loss:3.77207
epoch [535/1000], loss:3.76966
epoch [536/1000], loss:3.76757
epoch [537/1000], loss:3.76382
epoch [538/1000], loss:3.75726
epoch [539/1000], loss:3.76330
epoch [540/1000], loss:3.75130
epoch [541/1000], loss:3.74979
epoch [542/1000], loss:3.74968
epoch [543/1000], loss:3.73983
epoch [544/1000], loss:3.73901
epoch [545/1000], loss:3.73932
epoch [546/1000], loss:3.73718
epoch [547/1000], loss:3.73794
epoch [548/1000], loss:3.72818
epoch [549/1000], loss:3.72528
epoch [550/1000], loss:3.72475
epoch [551/1000], loss:3.71988
epoch [552/1000], loss:3.71729
epoch [553/1000], loss:3.71119
epoch [554/1000], loss:3.71207
epoch [555/1000], loss:3.71167
epoch [556/1000], loss:3.70275
epoch [557/1000], loss:3.70654
epoch [558/1000], loss:3.69792
epoch [559/1000], loss:3.69927
epoch [560/1000], loss:3.69409
epoch [561/1000], loss:3.69188
epoch [562/1000], loss:3.68632
epoch [563/1000], loss:3.68308
epoch [564/1000], loss:3.68161
epoch [565/1000], loss:3.68463
epoch [566/1000], loss:3.67181
epoch [567/1000], loss:3.67101
epoch [568/1000], loss:3.66956
epoch [569/1000], loss:3.66723
epoch [570/1000], loss:3.66829
epoch [571/1000], loss:3.66422
epoch [572/1000], loss:3.66120
epoch [573/1000], loss:3.65323
epoch [574/1000], loss:3.65280
epoch [575/1000], loss:3.65279
epoch [576/1000], loss:3.64698
epoch [577/1000], loss:3.64525
epoch [578/1000], loss:3.64385
epoch [579/1000], loss:3.63892
epoch [580/1000], loss:3.63570
epoch [581/1000], loss:3.63038
epoch [582/1000], loss:3.63306
epoch [583/1000], loss:3.62456
epoch [584/1000], loss:3.62961
epoch [585/1000], loss:3.61710
epoch [586/1000], loss:3.62218
epoch [587/1000], loss:3.61367
epoch [588/1000], loss:3.61351
epoch [589/1000], loss:3.61048
epoch [590/1000], loss:3.60863
epoch [591/1000], loss:3.60503
epoch [592/1000], loss:3.60068
epoch [593/1000], loss:3.59856
epoch [594/1000], loss:3.59472
epoch [595/1000], loss:3.59365
epoch [596/1000], loss:3.59324
epoch [597/1000], loss:3.58769
epoch [598/1000], loss:3.58214
epoch [599/1000], loss:3.58244
epoch [600/1000], loss:3.57799
epoch [601/1000], loss:3.57877
epoch [602/1000], loss:3.57055
epoch [603/1000], loss:3.57307
epoch [604/1000], loss:3.57202
epoch [605/1000], loss:3.56517
epoch [606/1000], loss:3.56280
epoch [607/1000], loss:3.56200
epoch [608/1000], loss:3.56267
epoch [609/1000], loss:3.55470
epoch [610/1000], loss:3.55250
epoch [611/1000], loss:3.54826
epoch [612/1000], loss:3.55154
epoch [613/1000], loss:3.54208
epoch [614/1000], loss:3.54206
epoch [615/1000], loss:3.54105
epoch [616/1000], loss:3.53665
epoch [617/1000], loss:3.53198
epoch [618/1000], loss:3.52956
epoch [619/1000], loss:3.52716
epoch [620/1000], loss:3.52535
epoch [621/1000], loss:3.52693
epoch [622/1000], loss:3.51926
epoch [623/1000], loss:3.51655
epoch [624/1000], loss:3.51352
epoch [625/1000], loss:3.51410
epoch [626/1000], loss:3.50871
epoch [627/1000], loss:3.50490
epoch [628/1000], loss:3.50470
epoch [629/1000], loss:3.50429
epoch [630/1000], loss:3.50063
epoch [631/1000], loss:3.49522
epoch [632/1000], loss:3.49489
epoch [633/1000], loss:3.49385
epoch [634/1000], loss:3.48804
epoch [635/1000], loss:3.48522
epoch [636/1000], loss:3.48331
epoch [637/1000], loss:3.47941
epoch [638/1000], loss:3.47592
epoch [639/1000], loss:3.47459
epoch [640/1000], loss:3.47359
epoch [641/1000], loss:3.47270
epoch [642/1000], loss:3.46967
epoch [643/1000], loss:3.46600
epoch [644/1000], loss:3.46549
epoch [645/1000], loss:3.46019
epoch [646/1000], loss:3.45748
epoch [647/1000], loss:3.45389
epoch [648/1000], loss:3.44896
epoch [649/1000], loss:3.44991
epoch [650/1000], loss:3.44311
epoch [651/1000], loss:3.44865
epoch [652/1000], loss:3.44133
epoch [653/1000], loss:3.43858
epoch [654/1000], loss:3.44189
epoch [655/1000], loss:3.43480
epoch [656/1000], loss:3.43255
epoch [657/1000], loss:3.42989
epoch [658/1000], loss:3.42864
epoch [659/1000], loss:3.42396
epoch [660/1000], loss:3.42112
epoch [661/1000], loss:3.42302
epoch [662/1000], loss:3.41736
epoch [663/1000], loss:3.41416
epoch [664/1000], loss:3.41132
epoch [665/1000], loss:3.41046
epoch [666/1000], loss:3.40492
epoch [667/1000], loss:3.40502
epoch [668/1000], loss:3.40614
epoch [669/1000], loss:3.40063
epoch [670/1000], loss:3.40028
epoch [671/1000], loss:3.39271
epoch [672/1000], loss:3.39536
epoch [673/1000], loss:3.39127
epoch [674/1000], loss:3.38746
epoch [675/1000], loss:3.38874
epoch [676/1000], loss:3.38427
epoch [677/1000], loss:3.38143
epoch [678/1000], loss:3.37742
epoch [679/1000], loss:3.37587
epoch [680/1000], loss:3.37513
epoch [681/1000], loss:3.37196
epoch [682/1000], loss:3.36916
epoch [683/1000], loss:3.36594
epoch [684/1000], loss:3.36606
epoch [685/1000], loss:3.36292
epoch [686/1000], loss:3.35892
epoch [687/1000], loss:3.35532
epoch [688/1000], loss:3.35597
epoch [689/1000], loss:3.35689
epoch [690/1000], loss:3.34953
epoch [691/1000], loss:3.34964
epoch [692/1000], loss:3.34474
epoch [693/1000], loss:3.34500
epoch [694/1000], loss:3.34074
epoch [695/1000], loss:3.34088
epoch [696/1000], loss:3.33748
epoch [697/1000], loss:3.33662
epoch [698/1000], loss:3.33202
epoch [699/1000], loss:3.33229
epoch [700/1000], loss:3.32739
epoch [701/1000], loss:3.32630
epoch [702/1000], loss:3.32807
epoch [703/1000], loss:3.32146
epoch [704/1000], loss:3.31806
epoch [705/1000], loss:3.31831
epoch [706/1000], loss:3.31332
epoch [707/1000], loss:3.31269
epoch [708/1000], loss:3.30964
epoch [709/1000], loss:3.30984
epoch [710/1000], loss:3.30538
epoch [711/1000], loss:3.30281
epoch [712/1000], loss:3.30262
epoch [713/1000], loss:3.29772
epoch [714/1000], loss:3.29625
epoch [715/1000], loss:3.29219
epoch [716/1000], loss:3.29506
epoch [717/1000], loss:3.28936
epoch [718/1000], loss:3.28897
epoch [719/1000], loss:3.29049
epoch [720/1000], loss:3.28375
epoch [721/1000], loss:3.28123
epoch [722/1000], loss:3.27900
epoch [723/1000], loss:3.27359
epoch [724/1000], loss:3.27611
epoch [725/1000], loss:3.27433
epoch [726/1000], loss:3.27112
epoch [727/1000], loss:3.26646
epoch [728/1000], loss:3.26737
epoch [729/1000], loss:3.26536
epoch [730/1000], loss:3.26612
epoch [731/1000], loss:3.26075
epoch [732/1000], loss:3.26027
epoch [733/1000], loss:3.25291
epoch [734/1000], loss:3.25916
epoch [735/1000], loss:3.24919
epoch [736/1000], loss:3.25470
epoch [737/1000], loss:3.24516
epoch [738/1000], loss:3.24314
epoch [739/1000], loss:3.24429
epoch [740/1000], loss:3.24261
epoch [741/1000], loss:3.23813
epoch [742/1000], loss:3.23578
epoch [743/1000], loss:3.23666
epoch [744/1000], loss:3.23200
epoch [745/1000], loss:3.23238
epoch [746/1000], loss:3.22988
epoch [747/1000], loss:3.22826
epoch [748/1000], loss:3.23023
epoch [749/1000], loss:3.22209
epoch [750/1000], loss:3.21966
epoch [751/1000], loss:3.21754
epoch [752/1000], loss:3.21620
epoch [753/1000], loss:3.21760
epoch [754/1000], loss:3.21165
epoch [755/1000], loss:3.21131
epoch [756/1000], loss:3.21038
epoch [757/1000], loss:3.20712
epoch [758/1000], loss:3.20317
epoch [759/1000], loss:3.20223
epoch [760/1000], loss:3.20180
epoch [761/1000], loss:3.20010
epoch [762/1000], loss:3.19946
epoch [763/1000], loss:3.19183
epoch [764/1000], loss:3.19291
epoch [765/1000], loss:3.18863
epoch [766/1000], loss:3.18918
epoch [767/1000], loss:3.18898
epoch [768/1000], loss:3.18414
epoch [769/1000], loss:3.18572
epoch [770/1000], loss:3.18738
epoch [771/1000], loss:3.17861
epoch [772/1000], loss:3.17652
epoch [773/1000], loss:3.17587
epoch [774/1000], loss:3.17144
epoch [775/1000], loss:3.17319
epoch [776/1000], loss:3.17009
epoch [777/1000], loss:3.16943
epoch [778/1000], loss:3.16559
epoch [779/1000], loss:3.16415
epoch [780/1000], loss:3.16417
epoch [781/1000], loss:3.16414
epoch [782/1000], loss:3.15878
epoch [783/1000], loss:3.15620
epoch [784/1000], loss:3.15162
epoch [785/1000], loss:3.15188
epoch [786/1000], loss:3.15056
epoch [787/1000], loss:3.14792
epoch [788/1000], loss:3.14884
epoch [789/1000], loss:3.14594
epoch [790/1000], loss:3.14544
epoch [791/1000], loss:3.14156
epoch [792/1000], loss:3.13851
epoch [793/1000], loss:3.13792
epoch [794/1000], loss:3.13770
epoch [795/1000], loss:3.13333
epoch [796/1000], loss:3.13036
epoch [797/1000], loss:3.12862
epoch [798/1000], loss:3.13088
epoch [799/1000], loss:3.12679
epoch [800/1000], loss:3.12329
epoch [801/1000], loss:3.12549
epoch [802/1000], loss:3.12244
epoch [803/1000], loss:3.11828
epoch [804/1000], loss:3.11357
epoch [805/1000], loss:3.11698
epoch [806/1000], loss:3.11326
epoch [807/1000], loss:3.11584
epoch [808/1000], loss:3.10921
epoch [809/1000], loss:3.10769
epoch [810/1000], loss:3.10721
epoch [811/1000], loss:3.10426
epoch [812/1000], loss:3.10207
epoch [813/1000], loss:3.09837
epoch [814/1000], loss:3.09836
epoch [815/1000], loss:3.09801
epoch [816/1000], loss:3.09438
epoch [817/1000], loss:3.09267
epoch [818/1000], loss:3.09224
epoch [819/1000], loss:3.08851
epoch [820/1000], loss:3.08578
epoch [821/1000], loss:3.08942
epoch [822/1000], loss:3.08425
epoch [823/1000], loss:3.08528
epoch [824/1000], loss:3.08140
epoch [825/1000], loss:3.07830
epoch [826/1000], loss:3.07588
epoch [827/1000], loss:3.07775
epoch [828/1000], loss:3.07456
epoch [829/1000], loss:3.07019
epoch [830/1000], loss:3.07405
epoch [831/1000], loss:3.06494
epoch [832/1000], loss:3.06572
epoch [833/1000], loss:3.06405
epoch [834/1000], loss:3.06366
epoch [835/1000], loss:3.05963
epoch [836/1000], loss:3.05978
epoch [837/1000], loss:3.05587
epoch [838/1000], loss:3.05641
epoch [839/1000], loss:3.05452
epoch [840/1000], loss:3.05307
epoch [841/1000], loss:3.04878
epoch [842/1000], loss:3.05134
epoch [843/1000], loss:3.04592
epoch [844/1000], loss:3.04432
epoch [845/1000], loss:3.04292
epoch [846/1000], loss:3.04020
epoch [847/1000], loss:3.04101
epoch [848/1000], loss:3.04131
epoch [849/1000], loss:3.03655
epoch [850/1000], loss:3.03434
epoch [851/1000], loss:3.03037
epoch [852/1000], loss:3.03011
epoch [853/1000], loss:3.03031
epoch [854/1000], loss:3.02658
epoch [855/1000], loss:3.02762
epoch [856/1000], loss:3.02805
epoch [857/1000], loss:3.02052
epoch [858/1000], loss:3.02101
epoch [859/1000], loss:3.01820
epoch [860/1000], loss:3.01740
epoch [861/1000], loss:3.01673
epoch [862/1000], loss:3.01265
epoch [863/1000], loss:3.00953
epoch [864/1000], loss:3.01045
epoch [865/1000], loss:3.00850
epoch [866/1000], loss:3.01031
epoch [867/1000], loss:3.00408
epoch [868/1000], loss:3.00111
epoch [869/1000], loss:3.00130
epoch [870/1000], loss:3.00163
epoch [871/1000], loss:2.99810
epoch [872/1000], loss:2.99874
epoch [873/1000], loss:2.99178
epoch [874/1000], loss:2.99280
epoch [875/1000], loss:2.99230
epoch [876/1000], loss:2.98815
epoch [877/1000], loss:2.98851
epoch [878/1000], loss:2.98612
epoch [879/1000], loss:2.98797
epoch [880/1000], loss:2.98337
epoch [881/1000], loss:2.98161
epoch [882/1000], loss:2.98003
epoch [883/1000], loss:2.97484
epoch [884/1000], loss:2.97611
epoch [885/1000], loss:2.97621
epoch [886/1000], loss:2.97396
epoch [887/1000], loss:2.96927
epoch [888/1000], loss:2.96680
epoch [889/1000], loss:2.96926
epoch [890/1000], loss:2.96575
epoch [891/1000], loss:2.96431
epoch [892/1000], loss:2.96193
epoch [893/1000], loss:2.95761
epoch [894/1000], loss:2.96028
epoch [895/1000], loss:2.96046
epoch [896/1000], loss:2.95814
epoch [897/1000], loss:2.95228
epoch [898/1000], loss:2.94921
epoch [899/1000], loss:2.95213
epoch [900/1000], loss:2.94890
epoch [901/1000], loss:2.94738
epoch [902/1000], loss:2.94390
epoch [903/1000], loss:2.94118
epoch [904/1000], loss:2.94426
epoch [905/1000], loss:2.94239
epoch [906/1000], loss:2.93883
epoch [907/1000], loss:2.93823
epoch [908/1000], loss:2.93640
epoch [909/1000], loss:2.93234
epoch [910/1000], loss:2.93235
epoch [911/1000], loss:2.92981
epoch [912/1000], loss:2.93039
epoch [913/1000], loss:2.93373
epoch [914/1000], loss:2.92795
epoch [915/1000], loss:2.92420
epoch [916/1000], loss:2.92136
epoch [917/1000], loss:2.91813
epoch [918/1000], loss:2.91754
epoch [919/1000], loss:2.91795
epoch [920/1000], loss:2.91643
epoch [921/1000], loss:2.91321
epoch [922/1000], loss:2.91369
epoch [923/1000], loss:2.91094
epoch [924/1000], loss:2.91049
epoch [925/1000], loss:2.90867
epoch [926/1000], loss:2.90595
epoch [927/1000], loss:2.90455
epoch [928/1000], loss:2.90523
epoch [929/1000], loss:2.90355
epoch [930/1000], loss:2.90085
epoch [931/1000], loss:2.89791
epoch [932/1000], loss:2.89439
epoch [933/1000], loss:2.89587
epoch [934/1000], loss:2.89358
epoch [935/1000], loss:2.89229
epoch [936/1000], loss:2.88939
epoch [937/1000], loss:2.89070
epoch [938/1000], loss:2.88834
epoch [939/1000], loss:2.88700
epoch [940/1000], loss:2.88633
epoch [941/1000], loss:2.88195
epoch [942/1000], loss:2.88308
epoch [943/1000], loss:2.87824
epoch [944/1000], loss:2.87709
epoch [945/1000], loss:2.87709
epoch [946/1000], loss:2.87699
epoch [947/1000], loss:2.87330
epoch [948/1000], loss:2.87141
epoch [949/1000], loss:2.87136
epoch [950/1000], loss:2.86982
epoch [951/1000], loss:2.86829
epoch [952/1000], loss:2.86615
epoch [953/1000], loss:2.86325
epoch [954/1000], loss:2.86094
epoch [955/1000], loss:2.86219
epoch [956/1000], loss:2.85894
epoch [957/1000], loss:2.86180
epoch [958/1000], loss:2.85887
epoch [959/1000], loss:2.85384
epoch [960/1000], loss:2.85410
epoch [961/1000], loss:2.85243
epoch [962/1000], loss:2.85051
epoch [963/1000], loss:2.84668
epoch [964/1000], loss:2.84494
epoch [965/1000], loss:2.84352
epoch [966/1000], loss:2.84500
epoch [967/1000], loss:2.84642
epoch [968/1000], loss:2.83922
epoch [969/1000], loss:2.83965
epoch [970/1000], loss:2.84072
epoch [971/1000], loss:2.83823
epoch [972/1000], loss:2.83543
epoch [973/1000], loss:2.83415
epoch [974/1000], loss:2.83639
epoch [975/1000], loss:2.82995
epoch [976/1000], loss:2.82914
epoch [977/1000], loss:2.82669
epoch [978/1000], loss:2.83094
epoch [979/1000], loss:2.82190
epoch [980/1000], loss:2.82548
epoch [981/1000], loss:2.82011
epoch [982/1000], loss:2.82137
epoch [983/1000], loss:2.81966
epoch [984/1000], loss:2.81743
epoch [985/1000], loss:2.81949
epoch [986/1000], loss:2.81346
epoch [987/1000], loss:2.81393
epoch [988/1000], loss:2.81204
epoch [989/1000], loss:2.81101
epoch [990/1000], loss:2.81068
epoch [991/1000], loss:2.80631
epoch [992/1000], loss:2.80828
epoch [993/1000], loss:2.80407
epoch [994/1000], loss:2.80417
epoch [995/1000], loss:2.80385
epoch [996/1000], loss:2.80122
epoch [997/1000], loss:2.80091
epoch [998/1000], loss:2.79750
epoch [999/1000], loss:2.79585
epoch [1000/1000], loss:2.79409

Dimensionality Reduction and Clustering

import numpy as np

def cal_acc(gt, pred):
    """ Computes categorization accuracy of our task.
    Args:
      gt: Ground truth labels (N, )
      pred: Predicted labels (N, )
    Returns:
      acc: Accuracy (0~1 scalar)
    """
    # Calculate correct predictions
    correct = np.sum(gt == pred)
    acc = correct / gt.shape[0]
    # This is binary unsupervised clustering, so take max(acc, 1-acc):
    # we only care whether the images were split into the right two groups
    return max(acc, 1 - acc)
import matplotlib.pyplot as plt

def plot_scatter(feat, label, savefig=None):
    """ Plot Scatter Image.
    Args:
      feat: the (x, y) coordinate of clustering result, shape: (N, 2)
      label: ground truth label of image (0/1), shape: (N,)
    Returns:
      None
    """
    X = feat[:, 0]
    Y = feat[:, 1]
    scatter = plt.scatter(X, Y, c=label)
    # Build legend entries from the color mapping; a bare plt.legend() finds
    # no labeled artists and only prints "No handles with labels found"
    plt.legend(*scatter.legend_elements(), loc='best')
    if savefig is not None:
        plt.savefig(savefig)
    plt.show()
    return
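K-means returns arbitrary cluster ids, so cluster 0 may match either true class. For that reason `cal_acc` gives a prediction and its flipped version the same score. The NumPy check below illustrates this property. It uses a standalone copy of `cal_acc`, and the label arrays are made up for the example.

```python
import numpy as np

def cal_acc(gt, pred):
    # Same logic as the notebook's cal_acc: accuracy made invariant to swapping 0/1
    acc = np.sum(gt == pred) / gt.shape[0]
    return max(acc, 1 - acc)

gt = np.array([0, 0, 1, 1, 1, 0, 1, 0])    # hypothetical ground-truth labels
pred = np.array([1, 1, 0, 0, 0, 1, 0, 0])  # groups mostly found, but ids swapped
flipped = np.abs(1 - pred)                 # same operation as the notebook's invert()

print(cal_acc(gt, pred))     # 0.875
print(cal_acc(gt, flipped))  # 0.875
```

This is also why uploading `prediction_invert.csv` is enough when Kaggle reports an accuracy below 0.5.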

Next, we use the trained model to predict the class of each image in the testing data.

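The testing data here is the same as the training data, so we build the dataloader with the same Dataset class. The only difference from training is that shuffle is set to False. This keeps the predictions in the same order as the images and their ids.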

Once the model and the dataloader are ready, we can run prediction.

We only need the encoder output (the latents). Each image is assigned a class by clustering its latent vector.
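The notebook's `predict` chains three steps: Kernel PCA, then t-SNE down to 2-D, then 2-cluster k-means. The sketch below runs the same chain on the CPU without an autoencoder, so each stage and its output shape can be inspected. The "latents" are a hypothetical stand-in (two well-separated Gaussian blobs in 64-D), and the sizes (`n_components=20`, 200 samples) are illustrative choices, not the notebook's settings.

```python
import numpy as np
from sklearn.decomposition import KernelPCA
from sklearn.manifold import TSNE
from sklearn.cluster import MiniBatchKMeans

rng = np.random.default_rng(0)
# Hypothetical stand-in for encoder latents: two tight, well-separated blobs in 64-D
latents = np.vstack([rng.normal(-2.0, 0.3, size=(100, 64)),
                     rng.normal(2.0, 0.3, size=(100, 64))])
labels = np.array([0] * 100 + [1] * 100)

# Stage 1: Kernel PCA with an RBF kernel to a moderate dimension
kpca = KernelPCA(n_components=20, kernel='rbf').fit_transform(latents)
# Stage 2: t-SNE down to 2-D (usable for both clustering and plotting)
emb = TSNE(n_components=2, init='pca', random_state=0).fit_transform(kpca)
# Stage 3: k-means with two clusters on the 2-D embedding
pred = MiniBatchKMeans(n_clusters=2, random_state=0, n_init=3).fit(emb).labels_

acc = np.mean(pred == labels)
acc = max(acc, 1 - acc)  # cluster ids are arbitrary, as in cal_acc
print(kpca.shape, emb.shape, acc)
```

On real autoencoder latents the two classes are far less separable, so accuracy depends heavily on the latent quality.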

import numpy as np
import torch
from torch.utils.data import DataLoader
from sklearn.decomposition import KernelPCA
from sklearn.manifold import TSNE
from sklearn.cluster import MiniBatchKMeans

# Principal Component Analysis (PCA) is the most popular dimensionality-reduction algorithm:
# it finds the hyperplane closest to the data and projects every point onto it,
# choosing the hyperplane that preserves the most variance.
# Kernel PCA (kPCA) is unsupervised, so there is no obvious performance metric for choosing
# the best kernel and hyperparameters; dimensionality reduction is usually a preparation step
# for a downstream task (classification, or clustering as here).

def inference(X, model, batch_size=256):
    X = preprocess(X)
    dataset = Image_Dataset(X)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    latents = []
    with torch.no_grad():  # inference only, no gradients needed
        for i, x in enumerate(dataloader):
            x = torch.FloatTensor(x)
            vec, img = model(x.cuda())
            # view(batch, -1) works like reshape: keep img.size()[0] rows and let -1 infer
            # the number of columns, i.e. flatten each latent into one vector
            vec = vec.view(img.size()[0], -1).cpu().numpy()
            if i == 0:
                latents = vec
            else:
                latents = np.concatenate((latents, vec), axis=0)  # stack along axis 0
    print('Latents Shape:', latents.shape)
    return latents

def predict(latents):
    # First dimension reduction: Kernel PCA with an RBF kernel
    # n_components: number of principal components (features) to keep
    # n_jobs=-1: use all CPUs
    transformer = KernelPCA(n_components=200, kernel='rbf', n_jobs=-1)
    # fit_transform(X) is fit(X) followed by transform(X) on the same data;
    # transform alone only works on an already fitted transformer and does not refit it
    kpca = transformer.fit_transform(latents)
    print('First Reduction Shape:', kpca.shape)
    # Second dimension reduction: t-SNE down to 2-D
    X_embedded = TSNE(n_components=2).fit_transform(kpca)
    print('Second Reduction Shape:', X_embedded.shape)
    # Clustering
    # n_clusters: number of cluster centers (default 8)
    # random_state: int, RandomState instance or None; seeds the random number generator
    pred = MiniBatchKMeans(n_clusters=2, random_state=0).fit(X_embedded)
    pred = [int(i) for i in pred.labels_]
    pred = np.array(pred)
    return pred, X_embedded

def invert(pred):
    return np.abs(1 - pred)

def save_prediction(pred, out_csv='prediction.csv'):
    with open(out_csv, 'w') as f:
        f.write('id,label\n')
        for i, p in enumerate(pred):
            f.write(f'{i},{p}\n')
    print(f'Save prediction to {out_csv}.')

# Load model
model = AE().cuda()
model.load_state_dict(torch.load('./checkpoints/last_checkpoint.pth'))
model.eval()

# Prepare data
trainX = np.load('trainX.npy')

# Predict
latents = inference(X=trainX, model=model)
pred, X_embedded = predict(latents)

# Save the predictions and upload them to Kaggle
save_prediction(pred, 'prediction.csv')

# This is an unsupervised binary problem: we only care whether the images were split into two groups.
# If the file above scores below 0.5 on Kaggle, just flip the labels
save_prediction(invert(pred), 'prediction_invert.csv')
Latents Shape: (8500, 2048)
First Reduction Shape: (8500, 200)
Second Reduction Shape: (8500, 2)
Save prediction to prediction.csv.
Save prediction to prediction_invert.csv.
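The comment in `predict` about `fit_transform` versus `transform` can be checked directly. In the small scikit-learn sketch below, calling `transform` on data the model was fitted on reproduces the `fit_transform` output. New samples are projected with the same fitted mapping and nothing is refitted. The random feature matrices are hypothetical.

```python
import numpy as np
from sklearn.decomposition import KernelPCA

rng = np.random.default_rng(0)
X_train = rng.normal(size=(50, 8))  # hypothetical training features
X_new = rng.normal(size=(5, 8))     # hypothetical unseen samples

kpca = KernelPCA(n_components=2, kernel='rbf')
Z_fit = kpca.fit_transform(X_train)  # fit, then project the training data
Z_again = kpca.transform(X_train)    # reuse the fitted projection, no refitting
Z_new = kpca.transform(X_new)        # unseen samples, same mapping

print(np.allclose(Z_fit, Z_again, atol=1e-6), Z_new.shape)  # True (5, 2)
```

Calling `fit_transform(X_new)` instead would fit a new projection on `X_new` alone, so its coordinates would not be comparable to `Z_fit`.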

Question 1 (Plotting)

Plot the dimensionality-reduction result (embedding) of the val data, colored by the corresponding labels.

valX = np.load('valX.npy')
valY = np.load('valY.npy')

# ==============================================
#  This demonstrates the plot for the baseline model;
#  for the report, also draw one for your improved model.
# ==============================================
model.load_state_dict(torch.load('./checkpoints/last_checkpoint.pth'))
model.eval()
latents = inference(valX, model)
pred_from_latent, emb_from_latent = predict(latents)
acc_latent = cal_acc(valY, pred_from_latent)
print('The clustering accuracy is:', acc_latent)
print('The clustering result:')
plot_scatter(emb_from_latent, valY, savefig='p1_baseline.png')
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
No handles with labels found to put in legend.
Second Reduction Shape: (500, 2)
The clustering accuracy is: 0.75
The clustering result:


Question 2

Using the autoencoder with your highest test accuracy, take the 6 images at indices 1, 2, 3, 6, 7, 9 from trainX. Plot each original image together with its reconstruction.

import matplotlib.pyplot as plt
import numpy as np
import torch

# Plot the original images (top row)
plt.figure(figsize=(10, 4))
indexes = [1, 2, 3, 6, 7, 9]
imgs = trainX[indexes, ]
for i, img in enumerate(imgs):
    plt.subplot(2, 6, i + 1, xticks=[], yticks=[])
    plt.imshow(img)

# Plot the reconstructed images (bottom row)
inp = torch.Tensor(trainX_preprocessed[indexes, ]).cuda()
with torch.no_grad():
    _, recs = model(inp)
# Map the decoder output from [-1, 1] back to [0, 1] and permute (N, C, H, W) -> (N, H, W, C)
recs = ((recs + 1) / 2).cpu().numpy()
recs = recs.transpose(0, 2, 3, 1)
for i, img in enumerate(recs):
    plt.subplot(2, 6, 6 + i + 1, xticks=[], yticks=[])
    plt.imshow(img)
plt.tight_layout()
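The reconstruction plot works by undoing `preprocess`. The decoder output is in [-1, 1] with a channel-first layout. `plt.imshow` expects floats in [0, 1] (or 0 to 255 integers) in (H, W, C) order. The sketch below shows that inverse mapping and checks the round trip on random images. The helper name `postprocess` is hypothetical, and `preprocess` is copied here so the block runs on its own.

```python
import numpy as np

def preprocess(image_list):
    # Same as the notebook: (N, H, W, C) uint8 in [0, 255] -> (N, C, H, W) float32 in [-1, 1]
    image_list = np.transpose(np.array(image_list), (0, 3, 1, 2))
    return ((image_list / 255.0) * 2 - 1).astype(np.float32)

def postprocess(x):
    # Hypothetical inverse: (N, C, H, W) in [-1, 1] -> (N, H, W, C) in [0, 1] for plt.imshow
    return ((x + 1) / 2).transpose(0, 2, 3, 1)

rng = np.random.default_rng(0)
imgs = rng.integers(0, 256, size=(6, 32, 32, 3), dtype=np.uint8)  # fake 32x32 RGB images
restored = postprocess(preprocess(imgs))
print(restored.shape, np.allclose(restored, imgs / 255.0, atol=1e-6))  # (6, 32, 32, 3) True
```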


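Question 3

During autoencoder training, pick at least 10 checkpoints. For those checkpoints, plot the model's train reconstruction error (MSE over all of trainX) against its val accuracy, and briefly describe what you observe.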
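The reconstruction error here is the per-element MSE over all of trainX. The loop accumulates a summed squared error and an element count across batches, then divides once at the end. This gives exactly the full-dataset MSE however the data is batched. Averaging per-batch means would over-weight a smaller final batch. The NumPy sketch below checks the equivalence on random stand-in arrays, with shapes chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, size=(130, 3, 4, 4))  # stand-in for preprocessed images
rec = x + rng.normal(0, 0.1, size=x.shape)   # stand-in for reconstructions

err, n = 0.0, 0
for start in range(0, len(x), 64):           # batches of 64, 64 and 2 samples
    xb, rb = x[start:start + 64], rec[start:start + 64]
    err += np.sum((xb - rb) ** 2)            # like MSELoss(reduction='sum')
    n += xb.size                             # like x.flatten().size(0)

full_mse = np.mean((x - rec) ** 2)
print(np.isclose(err / n, full_mse))         # True
```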

import os
import glob

# Sort checkpoints numerically by epoch, so checkpoint_100.pth comes after checkpoint_90.pth
checkpoints_list = sorted(glob.glob('checkpoints/checkpoint_*.pth'),
                          key=lambda x: int(os.path.splitext(os.path.basename(x))[0].split('_')[-1]))
print(checkpoints_list)

# Load data
dataset = Image_Dataset(trainX_preprocessed)
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)
points = []
with torch.no_grad():
    for i, checkpoint in enumerate(checkpoints_list):
        print('[{}/{}] {}'.format(i + 1, len(checkpoints_list), checkpoint))
        model.load_state_dict(torch.load(checkpoint))
        model.eval()
        err = 0
        n = 0
        for x in dataloader:
            x = x.cuda()
            _, rec = model(x)
            # Sum squared errors and count elements, so err / n is the per-element MSE over trainX
            err += torch.nn.MSELoss(reduction='sum')(x, rec).item()
            n += x.flatten().size(0)
        print('Reconstruction error (MSE):', err / n)
        latents = inference(X=valX, model=model)
        pred, X_embedded = predict(latents)
        acc = cal_acc(valY, pred)
        print('Accuracy:', acc)
        points.append((err / n, acc))
['checkpoints/checkpoint_10.pth', 'checkpoints/checkpoint_20.pth', 'checkpoints/checkpoint_30.pth', 'checkpoints/checkpoint_40.pth', 'checkpoints/checkpoint_50.pth', 'checkpoints/checkpoint_60.pth', 'checkpoints/checkpoint_70.pth', 'checkpoints/checkpoint_80.pth', 'checkpoints/checkpoint_90.pth', 'checkpoints/checkpoint_100.pth', 'checkpoints/checkpoint_110.pth', 'checkpoints/checkpoint_120.pth', 'checkpoints/checkpoint_130.pth', 'checkpoints/checkpoint_140.pth', 'checkpoints/checkpoint_150.pth', 'checkpoints/checkpoint_160.pth', 'checkpoints/checkpoint_170.pth', 'checkpoints/checkpoint_180.pth', 'checkpoints/checkpoint_190.pth', 'checkpoints/checkpoint_200.pth', 'checkpoints/checkpoint_210.pth', 'checkpoints/checkpoint_220.pth', 'checkpoints/checkpoint_230.pth', 'checkpoints/checkpoint_240.pth', 'checkpoints/checkpoint_250.pth', 'checkpoints/checkpoint_260.pth', 'checkpoints/checkpoint_270.pth', 'checkpoints/checkpoint_280.pth', 'checkpoints/checkpoint_290.pth', 'checkpoints/checkpoint_300.pth', 'checkpoints/checkpoint_310.pth', 'checkpoints/checkpoint_320.pth', 'checkpoints/checkpoint_330.pth', 'checkpoints/checkpoint_340.pth', 'checkpoints/checkpoint_350.pth', 'checkpoints/checkpoint_360.pth', 'checkpoints/checkpoint_370.pth', 'checkpoints/checkpoint_380.pth', 'checkpoints/checkpoint_390.pth', 'checkpoints/checkpoint_400.pth', 'checkpoints/checkpoint_410.pth', 'checkpoints/checkpoint_420.pth', 'checkpoints/checkpoint_430.pth', 'checkpoints/checkpoint_440.pth', 'checkpoints/checkpoint_450.pth', 'checkpoints/checkpoint_460.pth', 'checkpoints/checkpoint_470.pth', 'checkpoints/checkpoint_480.pth', 'checkpoints/checkpoint_490.pth', 'checkpoints/checkpoint_500.pth', 'checkpoints/checkpoint_510.pth', 'checkpoints/checkpoint_520.pth', 'checkpoints/checkpoint_530.pth', 'checkpoints/checkpoint_540.pth', 'checkpoints/checkpoint_550.pth', 'checkpoints/checkpoint_560.pth', 'checkpoints/checkpoint_570.pth', 'checkpoints/checkpoint_580.pth', 'checkpoints/checkpoint_590.pth', 
'checkpoints/checkpoint_600.pth', 'checkpoints/checkpoint_610.pth', 'checkpoints/checkpoint_620.pth', 'checkpoints/checkpoint_630.pth', 'checkpoints/checkpoint_640.pth', 'checkpoints/checkpoint_650.pth', 'checkpoints/checkpoint_660.pth', 'checkpoints/checkpoint_670.pth', 'checkpoints/checkpoint_680.pth', 'checkpoints/checkpoint_690.pth', 'checkpoints/checkpoint_700.pth', 'checkpoints/checkpoint_710.pth', 'checkpoints/checkpoint_720.pth', 'checkpoints/checkpoint_730.pth', 'checkpoints/checkpoint_740.pth', 'checkpoints/checkpoint_750.pth', 'checkpoints/checkpoint_760.pth', 'checkpoints/checkpoint_770.pth', 'checkpoints/checkpoint_780.pth', 'checkpoints/checkpoint_790.pth', 'checkpoints/checkpoint_800.pth', 'checkpoints/checkpoint_810.pth', 'checkpoints/checkpoint_820.pth', 'checkpoints/checkpoint_830.pth', 'checkpoints/checkpoint_840.pth', 'checkpoints/checkpoint_850.pth', 'checkpoints/checkpoint_860.pth', 'checkpoints/checkpoint_870.pth', 'checkpoints/checkpoint_880.pth', 'checkpoints/checkpoint_890.pth', 'checkpoints/checkpoint_900.pth', 'checkpoints/checkpoint_910.pth', 'checkpoints/checkpoint_920.pth', 'checkpoints/checkpoint_930.pth', 'checkpoints/checkpoint_940.pth', 'checkpoints/checkpoint_950.pth', 'checkpoints/checkpoint_960.pth', 'checkpoints/checkpoint_970.pth', 'checkpoints/checkpoint_980.pth', 'checkpoints/checkpoint_990.pth', 'checkpoints/checkpoint_1000.pth']
[1/100] checkpoints/checkpoint_10.pth
Reconstruction error (MSE): 0.10465650191961551
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.56
[2/100] checkpoints/checkpoint_20.pth
Reconstruction error (MSE): 0.08282024884691426
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.564
[3/100] checkpoints/checkpoint_30.pth
Reconstruction error (MSE): 0.07529751972123688
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.552
[4/100] checkpoints/checkpoint_40.pth
Reconstruction error (MSE): 0.06996455570295745
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.528
[5/100] checkpoints/checkpoint_50.pth
Reconstruction error (MSE): 0.06556768215403837
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.53
[6/100] checkpoints/checkpoint_60.pth
Reconstruction error (MSE): 0.0623151860704609
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.524
[7/100] checkpoints/checkpoint_70.pth
Reconstruction error (MSE): 0.05941213181439568
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.536
[8/100] checkpoints/checkpoint_80.pth
Reconstruction error (MSE): 0.057350127874636184
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.52
[9/100] checkpoints/checkpoint_90.pth
Reconstruction error (MSE): 0.05522508699753705
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.516
[10/100] checkpoints/checkpoint_100.pth
Reconstruction error (MSE): 0.05384457483478621
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.53
[11/100] checkpoints/checkpoint_110.pth
Reconstruction error (MSE): 0.05195554022695504
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.5
[12/100] checkpoints/checkpoint_120.pth
Reconstruction error (MSE): 0.05074959627787272
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.524
[13/100] checkpoints/checkpoint_130.pth
Reconstruction error (MSE): 0.04992709094402837
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.514
[14/100] checkpoints/checkpoint_140.pth
Reconstruction error (MSE): 0.04817914684146058
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.502
[15/100] checkpoints/checkpoint_150.pth
Reconstruction error (MSE): 0.04657277587815827
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.512
[16/100] checkpoints/checkpoint_160.pth
Reconstruction error (MSE): 0.045626810316945994
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.504
[17/100] checkpoints/checkpoint_170.pth
Reconstruction error (MSE): 0.04440261214387183
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.548
[18/100] checkpoints/checkpoint_180.pth
Reconstruction error (MSE): 0.04345548491384469
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.53
[19/100] checkpoints/checkpoint_190.pth
Reconstruction error (MSE): 0.04282478637321323
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.518
[20/100] checkpoints/checkpoint_200.pth
Reconstruction error (MSE): 0.042173867076051
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.522
[21/100] checkpoints/checkpoint_210.pth
Reconstruction error (MSE): 0.041361579988517014
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.504
[22/100] checkpoints/checkpoint_220.pth
Reconstruction error (MSE): 0.040615920683916874
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.522
[23/100] checkpoints/checkpoint_230.pth
Reconstruction error (MSE): 0.039873808785980826
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.522
[24/100] checkpoints/checkpoint_240.pth
Reconstruction error (MSE): 0.03932966136932373
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[25/100] checkpoints/checkpoint_250.pth
Reconstruction error (MSE): 0.038771576432620775
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[26/100] checkpoints/checkpoint_260.pth
Reconstruction error (MSE): 0.0381339080099966
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.504
[27/100] checkpoints/checkpoint_270.pth
Reconstruction error (MSE): 0.03751208638209923
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.504
[28/100] checkpoints/checkpoint_280.pth
Reconstruction error (MSE): 0.037052626366708794
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[29/100] checkpoints/checkpoint_290.pth
Reconstruction error (MSE): 0.03666375287373861
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.528
[30/100] checkpoints/checkpoint_300.pth
Reconstruction error (MSE): 0.03611967169069776
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[31/100] checkpoints/checkpoint_310.pth
Reconstruction error (MSE): 0.03564991631227381
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.518
[32/100] checkpoints/checkpoint_320.pth
Reconstruction error (MSE): 0.035199516689076144
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.514
[33/100] checkpoints/checkpoint_330.pth
Reconstruction error (MSE): 0.034691137108148314
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.524
[34/100] checkpoints/checkpoint_340.pth
Reconstruction error (MSE): 0.03432022960513246
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.51
[35/100] checkpoints/checkpoint_350.pth
Reconstruction error (MSE): 0.033855246824376725
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.52
[36/100] checkpoints/checkpoint_360.pth
Reconstruction error (MSE): 0.03339220189113243
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.53
[37/100] checkpoints/checkpoint_370.pth
Reconstruction error (MSE): 0.03329624884736304
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.502
[38/100] checkpoints/checkpoint_380.pth
Reconstruction error (MSE): 0.03264928217495189
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.53
[39/100] checkpoints/checkpoint_390.pth
Reconstruction error (MSE): 0.03237577991859586
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.534
[40/100] checkpoints/checkpoint_400.pth
Reconstruction error (MSE): 0.03208851829229617
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.514
[41/100] checkpoints/checkpoint_410.pth
Reconstruction error (MSE): 0.0316866933037253
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[42/100] checkpoints/checkpoint_420.pth
Reconstruction error (MSE): 0.031364078933117434
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.51
[43/100] checkpoints/checkpoint_430.pth
Reconstruction error (MSE): 0.031136608348173254
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[44/100] checkpoints/checkpoint_440.pth
Reconstruction error (MSE): 0.0309632776297775
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.528
[45/100] checkpoints/checkpoint_450.pth
Reconstruction error (MSE): 0.030496950392629587
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.534
[46/100] checkpoints/checkpoint_460.pth
Reconstruction error (MSE): 0.030128193126005284
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.51
[47/100] checkpoints/checkpoint_470.pth
Reconstruction error (MSE): 0.029998875262690527
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[48/100] checkpoints/checkpoint_480.pth
Reconstruction error (MSE): 0.029572404412662283
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.51
[49/100] checkpoints/checkpoint_490.pth
Reconstruction error (MSE): 0.02939370559243595
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.508
[50/100] checkpoints/checkpoint_500.pth
Reconstruction error (MSE): 0.02911538221321854
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.554
[51/100] checkpoints/checkpoint_510.pth
Reconstruction error (MSE): 0.02889633548960966
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.546
[52/100] checkpoints/checkpoint_520.pth
Reconstruction error (MSE): 0.02860628096262614
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.648
[53/100] checkpoints/checkpoint_530.pth
Reconstruction error (MSE): 0.028405724600249645
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.54
[54/100] checkpoints/checkpoint_540.pth
Reconstruction error (MSE): 0.028084655219433353
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[55/100] checkpoints/checkpoint_550.pth
Reconstruction error (MSE): 0.02798689774905934
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.538
[56/100] checkpoints/checkpoint_560.pth
Reconstruction error (MSE): 0.027731095856311276
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.552
[57/100] checkpoints/checkpoint_570.pth
Reconstruction error (MSE): 0.027528591081207875
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.522
[58/100] checkpoints/checkpoint_580.pth
Reconstruction error (MSE): 0.02748092877631094
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.516
[59/100] checkpoints/checkpoint_590.pth
Reconstruction error (MSE): 0.027148202690423704
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.536
[60/100] checkpoints/checkpoint_600.pth
Reconstruction error (MSE): 0.02693716204400156
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.538
[61/100] checkpoints/checkpoint_610.pth
Reconstruction error (MSE): 0.02663602849548938
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.656
[62/100] checkpoints/checkpoint_620.pth
Reconstruction error (MSE): 0.026486863996468338
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.51
[63/100] checkpoints/checkpoint_630.pth
Reconstruction error (MSE): 0.026279585034239526
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.518
[64/100] checkpoints/checkpoint_640.pth
Reconstruction error (MSE): 0.02615043982337503
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.61
[65/100] checkpoints/checkpoint_650.pth
Reconstruction error (MSE): 0.025924385631785674
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.534
[66/100] checkpoints/checkpoint_660.pth
Reconstruction error (MSE): 0.025687772582559023
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.59
[67/100] checkpoints/checkpoint_670.pth
Reconstruction error (MSE): 0.025555453281776577
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.628
[68/100] checkpoints/checkpoint_680.pth
Reconstruction error (MSE): 0.025691534911884983
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.648
[69/100] checkpoints/checkpoint_690.pth
Reconstruction error (MSE): 0.025101487271925984
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.708
[70/100] checkpoints/checkpoint_700.pth
Reconstruction error (MSE): 0.02504801980186911
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.732
[71/100] checkpoints/checkpoint_710.pth
Reconstruction error (MSE): 0.02484492769428328
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.752
[72/100] checkpoints/checkpoint_720.pth
Reconstruction error (MSE): 0.02478704075719796
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.76
[73/100] checkpoints/checkpoint_730.pth
Reconstruction error (MSE): 0.02446424291648117
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.608
[74/100] checkpoints/checkpoint_740.pth
Reconstruction error (MSE): 0.024349503433003145
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.508
[75/100] checkpoints/checkpoint_750.pth
Reconstruction error (MSE): 0.02417324640236649
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.764
[76/100] checkpoints/checkpoint_760.pth
Reconstruction error (MSE): 0.024010706882850796
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.726
[77/100] checkpoints/checkpoint_770.pth
Reconstruction error (MSE): 0.02394120900771197
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.762
[78/100] checkpoints/checkpoint_780.pth
Reconstruction error (MSE): 0.023713757514953613
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.55
[79/100] checkpoints/checkpoint_790.pth
Reconstruction error (MSE): 0.02374166191325468
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.554
[80/100] checkpoints/checkpoint_800.pth
Reconstruction error (MSE): 0.023461397339315977
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.64
[81/100] checkpoints/checkpoint_810.pth
Reconstruction error (MSE): 0.023291605500613943
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.756
[82/100] checkpoints/checkpoint_820.pth
Reconstruction error (MSE): 0.023138159677094105
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.526
[83/100] checkpoints/checkpoint_830.pth
Reconstruction error (MSE): 0.02306466459760479
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.792
[84/100] checkpoints/checkpoint_840.pth
Reconstruction error (MSE): 0.022922015835257138
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.664
[85/100] checkpoints/checkpoint_850.pth
Reconstruction error (MSE): 0.022727084767584706
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.67
[86/100] checkpoints/checkpoint_860.pth
Reconstruction error (MSE): 0.022709223756603166
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.76
[87/100] checkpoints/checkpoint_870.pth
Reconstruction error (MSE): 0.022506213861353257
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.542
[88/100] checkpoints/checkpoint_880.pth
Reconstruction error (MSE): 0.022330569921755323
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.682
[89/100] checkpoints/checkpoint_890.pth
Reconstruction error (MSE): 0.022259797694636325
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.792
[90/100] checkpoints/checkpoint_900.pth
Reconstruction error (MSE): 0.022161509541904226
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.534
[91/100] checkpoints/checkpoint_910.pth
Reconstruction error (MSE): 0.022015575642679253
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.712
[92/100] checkpoints/checkpoint_920.pth
Reconstruction error (MSE): 0.021944920754900166
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.758
[93/100] checkpoints/checkpoint_930.pth
Reconstruction error (MSE): 0.021774335898605047
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.59
[94/100] checkpoints/checkpoint_940.pth
Reconstruction error (MSE): 0.021657160151238534
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.76
[95/100] checkpoints/checkpoint_950.pth
Reconstruction error (MSE): 0.021555810619803037
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.698
[96/100] checkpoints/checkpoint_960.pth
Reconstruction error (MSE): 0.021441521494996313
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.528
[97/100] checkpoints/checkpoint_970.pth
Reconstruction error (MSE): 0.02138799679513071
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.802
[98/100] checkpoints/checkpoint_980.pth
Reconstruction error (MSE): 0.021166577629014558
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.508
[99/100] checkpoints/checkpoint_990.pth
Reconstruction error (MSE): 0.02112917330685784
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.736
[100/100] checkpoints/checkpoint_1000.pth
Reconstruction error (MSE): 0.02094145799150654
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.554
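Every checkpoint in the log goes through the same evaluation. Latents of shape (500, 2048) are reduced to (500, 200), then to (500, 2), clustered into two groups and scored against valY. The log only prints the shapes, so what follows is a minimal sketch under assumptions. It uses scikit-learn's KernelPCA (RBF kernel) for the first reduction, TSNE for the second and MiniBatchKMeans for clustering, and `reduce_and_cluster` / `cal_acc` are illustrative names. Because the cluster ids 0/1 are arbitrary, accuracy is taken as max(acc, 1 - acc). That would be consistent with no accuracy in the log falling below 0.5.

```python
import numpy as np
from sklearn.cluster import MiniBatchKMeans
from sklearn.decomposition import KernelPCA
from sklearn.manifold import TSNE


def reduce_and_cluster(latents, n_first=200, seed=0):
    """Assumed pipeline: KernelPCA -> t-SNE -> 2-way k-means (illustrative, not the original code)."""
    # Stage 1: RBF kernel PCA, e.g. (500, 2048) -> (500, 200).
    kpca = KernelPCA(n_components=n_first, kernel="rbf", random_state=seed).fit_transform(latents)
    # Stage 2: t-SNE down to 2-D, e.g. (500, 200) -> (500, 2).
    embedded = TSNE(n_components=2, random_state=seed).fit_transform(kpca)
    # Cluster the 2-D embedding into two groups.
    km = MiniBatchKMeans(n_clusters=2, random_state=seed, n_init=3).fit(embedded)
    return km.labels_.astype(int), embedded


def cal_acc(gt, pred):
    """Accuracy that ignores which cluster id was assigned to which class."""
    acc = float(np.mean(gt == pred))
    return max(acc, 1.0 - acc)


# Toy usage on synthetic data: two Gaussian blobs, 30 points each, in 32 dimensions.
rng = np.random.default_rng(0)
gt = np.repeat([0, 1], 30)
latents = rng.normal(size=(60, 32)) + 4.0 * gt[:, None]  # class 1 shifted by +4 in every dim
labels, embedded = reduce_and_cluster(latents, n_first=10)
acc = cal_acc(gt, labels)
print("embedded shape:", embedded.shape, "accuracy:", acc)
```

Fixing `random_state` for each stochastic step makes the evaluation repeatable, so differences between checkpoints are not partly caused by evaluation randomness. How much of the jitter seen in the log that would remove is not something the log itself shows.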
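import matplotlib.pyplot as plt

# points holds one (MSE, val accuracy) pair per checkpoint; unzip into two sequences
ps = list(zip(*points))
plt.figure(figsize=(6, 6))
plt.subplot(211, title='Reconstruction error (MSE)').plot(ps[0])
plt.subplot(212, title='Accuracy (val)').plot(ps[1])
plt.xlabel('checkpoint index')
plt.tight_layout()
plt.show()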
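If the `points` list from the evaluation loop is no longer in memory, the same (MSE, accuracy) pairs can be recovered from the printed log. Here is a minimal sketch, assuming the log text is available as a string in exactly the format shown above. `parse_checkpoint_log` is a hypothetical helper, not part of the assignment code.

```python
import re

# Patterns for the three fields needed from each checkpoint's log entry.
CKPT_RE = re.compile(r"checkpoint_(\d+)\.pth")
MSE_RE = re.compile(r"Reconstruction error \(MSE\):\s*([0-9.eE+-]+)")
ACC_RE = re.compile(r"Accuracy:\s*([0-9.eE+-]+)")


def parse_checkpoint_log(text):
    """Return (epoch, mse, acc) tuples, one per checkpoint entry in the log text."""
    records = []
    current = {}
    for line in text.splitlines():
        if m := CKPT_RE.search(line):
            # A "[k/100] checkpoints/checkpoint_N.pth" line starts a new entry.
            current = {"epoch": int(m.group(1))}
        elif m := MSE_RE.search(line):
            current["mse"] = float(m.group(1))
        elif m := ACC_RE.search(line):
            current["acc"] = float(m.group(1))
            # Accuracy is printed last, so the entry is complete here.
            if "epoch" in current and "mse" in current:
                records.append((current["epoch"], current["mse"], current["acc"]))
            current = {}
    return records


# Two entries copied from the log above.
SAMPLE_LOG = """\
[99/100] checkpoints/checkpoint_990.pth
Reconstruction error (MSE): 0.02112917330685784
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.736
[100/100] checkpoints/checkpoint_1000.pth
Reconstruction error (MSE): 0.02094145799150654
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.554
"""

records = parse_checkpoint_log(SAMPLE_LOG)
# (MSE, accuracy) pairs in checkpoint order, ready for the plotting code.
points = [(mse, acc) for _, mse, acc in records]
print(records)
```

This assumes `points` stores (MSE, accuracy) per checkpoint, which is what the `ps[0]` / `ps[1]` usage in the plotting code suggests.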

(Figure: Reconstruction error (MSE) and Accuracy (val) plotted against checkpoint index)

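Validation accuracy swings wildly from one checkpoint to the next, even as the reconstruction error keeps trending down. Unsupervised learning really is hard to train.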
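To quantify the jitter, the last ten checkpoints in the log (epochs 910 to 1000) can be summarized with NumPy alone. The values below are copied from the log, and nothing else is assumed. Over this stretch the MSE falls at every checkpoint, while accuracy ranges from 0.508 to 0.802. So lower reconstruction error by itself does not predict better clustering here.

```python
import numpy as np

# (epoch, reconstruction MSE, val accuracy) for the last ten checkpoints in the log.
LAST_TEN = [
    (910, 0.022015575642679253, 0.712),
    (920, 0.021944920754900166, 0.758),
    (930, 0.021774335898605047, 0.590),
    (940, 0.021657160151238534, 0.760),
    (950, 0.021555810619803037, 0.698),
    (960, 0.021441521494996313, 0.528),
    (970, 0.02138799679513071, 0.802),
    (980, 0.021166577629014558, 0.508),
    (990, 0.02112917330685784, 0.736),
    (1000, 0.02094145799150654, 0.554),
]

mse = np.array([m for _, m, _ in LAST_TEN])
acc = np.array([a for _, _, a in LAST_TEN])

mse_always_falls = bool(np.all(np.diff(mse) < 0))  # True if every checkpoint lowers the MSE
acc_range = acc.max() - acc.min()                   # spread of val accuracy
mean_swing = np.abs(np.diff(acc)).mean()            # average |change| between consecutive checkpoints
corr = np.corrcoef(mse, acc)[0, 1]                  # Pearson correlation of MSE vs. accuracy

print(f"MSE falls at every step: {mse_always_falls}")
print(f"accuracy: {acc.min():.3f} to {acc.max():.3f} (spread {acc_range:.3f})")
print(f"mean |change| in accuracy between checkpoints: {mean_swing:.3f}")
print(f"corr(MSE, accuracy): {corr:.2f}")
```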

