[Deep Learning] PyTorch Deep Learning Practice - Lecture_13_RNN_Classifier


Table of Contents

  • 1. Problem Description
  • 2. OurModel
  • 3. Preparing the Data
    • 3.1 Data Convert
    • 3.2 Padding Data
    • 3.3 Label Convert
  • 4. Bidirectional RNN
  • 5. PyTorch Implementation
    • 5.1 Imports
    • 5.2 The make_tensors Function
    • 5.3 Converting a Name to a Character List
    • 5.4 The Name Dataset Class
    • 5.5 The RNN (GRU) Classifier
    • 5.6 Training Function
    • 5.7 Test Function
    • 5.8 Main Block
    • 5.9 Complete Code
    • 5.10 Run Output


1. Problem Description

Problem: given a person's name, predict which country it belongs to.

2. OurModel

(Model architecture figures from the lecture slides omitted.)

3. Preparing the Data

3.1 Data Convert

Since the input data consists entirely of English characters, we can use the ASCII code to convert each character into a numeric value.
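A minimal sketch of this conversion (the sample name is made up):

```python
# Each character of a name maps to its ASCII code via ord().
name = "Maclean"
codes = [ord(c) for c in name]
print(codes)  # [77, 97, 99, 108, 101, 97, 110]
```

Note that `ord()` returns the Unicode code point, which coincides with the ASCII code for English letters.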

3.2 Padding Data

Because the names vary in length, we need a padding step: shorter sequences are filled with zeros so that every sequence in a batch has the same length.
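A minimal sketch of the zero-padding, assuming the names have already been converted to ASCII lists (the sample sequences are made up):

```python
import torch

# Right-pad every sequence with zeros up to the longest one in the batch.
sequences = [[77, 97, 99], [76, 105], [79]]
lengths = [len(s) for s in sequences]
padded = torch.zeros(len(sequences), max(lengths), dtype=torch.long)
for i, (seq, n) in enumerate(zip(sequences, lengths)):
    padded[i, :n] = torch.tensor(seq)
print(padded.tolist())  # [[77, 97, 99], [76, 105, 0], [79, 0, 0]]
```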

3.3 Label Convert

(Figure omitted: each country name is mapped to an integer class index.)
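A minimal sketch of the label conversion: sort the unique country names and assign each one an integer index (the sample countries are made up):

```python
countries = ['Japan', 'China', 'Japan', 'Scottish']
country_list = sorted(set(countries))
country_dict = {name: idx for idx, name in enumerate(country_list)}
print(country_dict)  # {'China': 0, 'Japan': 1, 'Scottish': 2}
```

This is exactly what `NameDataset.getCountryDict()` does below.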

4. Bidirectional RNN

(Figure omitted: a bidirectional RNN processes the sequence in both directions and concatenates the final forward and backward hidden states.)
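The bidirectional idea can be sketched with a toy GRU (all sizes here are illustrative): the forward and backward passes each produce a final hidden state, and the classifier concatenates the two.

```python
import torch

gru = torch.nn.GRU(input_size=8, hidden_size=16, num_layers=2, bidirectional=True)
x = torch.randn(5, 3, 8)  # (seq_len, batch, input_size)
output, hidden = gru(x)
print(output.shape)  # (5, 3, 32): per-step outputs, hidden_size * 2 directions
print(hidden.shape)  # (4, 3, 16): num_layers * 2 directions of final states
# The last two slices are the final forward and backward states of the top layer:
hidden_cat = torch.cat([hidden[-2], hidden[-1]], dim=1)
print(hidden_cat.shape)  # (3, 32)
```

This `hidden_size * 2` vector is why the fully connected layer below has input dimension `hidden_size * self.n_directions`.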

5. PyTorch Implementation

5.1 Imports

import time
import torch
import csv
import gzip
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt

5.2 The make_tensors Function

def make_tensors(names, countries):
    sequences_and_lengths = [name2list(name) for name in names]
    name_sequences = [sl[0] for sl in sequences_and_lengths]
    seq_lengths = torch.LongTensor([sl[1] for sl in sequences_and_lengths])
    countries = countries.long()
    # Pad with zeros up to the longest name in the batch
    seq_tensor = torch.zeros(len(name_sequences), seq_lengths.max()).long()
    for idx, (seq, seq_len) in enumerate(zip(name_sequences, seq_lengths), 0):
        seq_tensor[idx, :seq_len] = torch.LongTensor(seq)
    # Sort by length in descending order, as pack_padded_sequence requires
    seq_lengths, perm_idx = seq_lengths.sort(dim=0, descending=True)
    seq_tensor = seq_tensor[perm_idx]
    countries = countries[perm_idx]
    return torch.LongTensor(seq_tensor), \
        torch.LongTensor(seq_lengths), \
        torch.LongTensor(countries)

5.3 Converting a Name to a Character List

def name2list(name):
    arr = [ord(c) for c in name]
    return arr, len(arr)

5.4 The Name Dataset Class

class NameDataset(Dataset):
    def __init__(self, is_train_set=True):
        filename = '../dataset/names_train.csv.gz' if is_train_set else '../dataset/names_test.csv.gz'
        with gzip.open(filename, 'rt') as f:
            reader = csv.reader(f)
            rows = list(reader)
        self.names = [row[0] for row in rows]
        self.len = len(self.names)
        self.countries = [row[1] for row in rows]
        self.country_list = list(sorted(set(self.countries)))
        self.country_dict = self.getCountryDict()
        self.country_num = len(self.country_list)

    # Build the country-name -> index dictionary
    def getCountryDict(self):
        country_dict = dict()
        for idx, country_name in enumerate(self.country_list, 0):
            country_dict[country_name] = idx
        return country_dict

    # Number of distinct countries
    def getCountriesNum(self):
        return self.country_num

    # Map a class index back to the country string
    def idx2country(self, index):
        return self.country_list[index]

    def __getitem__(self, index):
        return self.names[index], self.country_dict[self.countries[index]]

    def __len__(self):
        return self.len

5.5 The RNN (GRU) Classifier

class RNNClassifier(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_layers=1, bidirectional=True):
        super(RNNClassifier, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.n_directions = 2 if bidirectional else 1
        self.embedding = torch.nn.Embedding(input_size, hidden_size)
        self.gru = torch.nn.GRU(hidden_size, hidden_size, n_layers, bidirectional=bidirectional)
        self.fc = torch.nn.Linear(hidden_size * self.n_directions, output_size)

    def forward(self, input, seq_len):
        input = input.t()  # transpose to (seq_len, batch_size)
        batch_size = input.size(1)
        hidden = self._init_hidden(batch_size)
        embedding = self.embedding(input.to(device))
        # Pack the padded batch; seq_len must be sorted in descending order
        gru_input = pack_padded_sequence(embedding, seq_len)
        output, hidden = self.gru(gru_input.to(device), hidden.to(device))
        if self.n_directions == 2:
            # Concatenate the final forward and backward hidden states
            hidden_cat = torch.cat([hidden[-1], hidden[-2]], dim=1)
        else:
            hidden_cat = hidden[-1]
        fc_output = self.fc(hidden_cat)
        return fc_output

    def _init_hidden(self, batch_size):
        return torch.zeros(self.n_layers * self.n_directions, batch_size, self.hidden_size)
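The `pack_padded_sequence` call in `forward()` can be illustrated on a toy embedded batch (shapes here are illustrative). The lengths must already be sorted in descending order, which is exactly why `make_tensors` sorts the batch:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

embedded = torch.randn(4, 3, 8)    # (max_seq_len, batch, embedding_dim)
lengths = torch.tensor([4, 3, 1])  # sorted in descending order
packed = pack_padded_sequence(embedded, lengths)
# Padded positions are dropped: only 4 + 3 + 1 = 8 real time steps remain.
print(packed.data.shape)            # (8, 8)
print(packed.batch_sizes.tolist())  # [3, 2, 2, 1]: active sequences per step
```

Packing lets the GRU skip the zero padding entirely instead of running recurrent steps over it.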

5.6 Training Function

def trainModel():
    total_loss = 0
    print('=' * 20, 'Epoch', epoch, '=' * 20)
    for i, (names, countries) in enumerate(train_loader, 1):
        inputs, seq_len, target = make_tensors(names, countries)
        output = classifier(inputs, seq_len)
        loss = criterion(output.to(device), target.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if i % 10 == 0:
            print(f'[程序已运行{time.time() - start} 秒]', end='')
            print(f' - [{i * len(inputs)}/{len(train_set)}]', end='')
            print(f' , loss={total_loss / (i * len(inputs))}')

5.7 Test Function

def tttModel():
    correct = 0
    total = len(test_set)
    with torch.no_grad():
        for i, (names, countries) in enumerate(test_loader, 1):
            inputs, seq_len, target = make_tensors(names, countries)
            output = classifier(inputs, seq_len)
            pred = output.max(dim=1, keepdim=True)[1]
            correct += pred.eq(target.to(device).view_as(pred)).sum().item()
        percent = '%.2f' % (100 * correct / total)
        # Note: the printed message says "training set" (在训练集上), but this
        # function actually evaluates on the test set.
        print(f'在训练集上评估模型: Accuracy {correct}/{total} {percent}%')
    return correct / total

5.8 Main Block

if __name__ == '__main__':
    # Hyperparameters
    HIDDEN_SIZE = 100  # hidden state size
    BATCH_SIZE = 256
    N_LAYER = 2        # number of GRU layers
    N_EPOCHS = 50      # number of epochs
    N_CHARS = 128      # vocabulary size (ASCII)
    USE_GPU = True     # forward() references `device`, so this must stay True

    # Prepare the data
    train_set = NameDataset(is_train_set=True)
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    test_set = NameDataset(is_train_set=False)
    test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

    # Number of output classes
    N_COUNTRY = train_set.getCountriesNum()

    # Instantiate the RNN model
    classifier = RNNClassifier(N_CHARS, HIDDEN_SIZE, N_COUNTRY, N_LAYER)
    if USE_GPU:
        device = torch.device("cuda:0")
        classifier.to(device)

    # Loss function
    criterion = torch.nn.CrossEntropyLoss()
    # Optimizer
    optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)

    start = time.time()
    acc_list = []
    for epoch in range(1, N_EPOCHS + 1):
        trainModel()
        acc = tttModel()
        acc_list.append(acc)

    # Plot test accuracy per epoch
    plt.plot([i + 1 for i in range(len(acc_list))], acc_list)
    plt.show()

5.9 Complete Code

import time
import torch
import csv
import gzip
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt


# Convert a name to a list of ASCII codes plus its length
def name2list(name):
    arr = [ord(c) for c in name]
    return arr, len(arr)


# Build padded, length-sorted tensors from a batch of names and countries
def make_tensors(names, countries):
    sequences_and_lengths = [name2list(name) for name in names]
    name_sequences = [sl[0] for sl in sequences_and_lengths]
    seq_lengths = torch.LongTensor([sl[1] for sl in sequences_and_lengths])
    countries = countries.long()
    # Pad with zeros up to the longest name in the batch
    seq_tensor = torch.zeros(len(name_sequences), seq_lengths.max()).long()
    for idx, (seq, seq_len) in enumerate(zip(name_sequences, seq_lengths), 0):
        seq_tensor[idx, :seq_len] = torch.LongTensor(seq)
    # Sort by length in descending order, as pack_padded_sequence requires
    seq_lengths, perm_idx = seq_lengths.sort(dim=0, descending=True)
    seq_tensor = seq_tensor[perm_idx]
    countries = countries[perm_idx]
    return torch.LongTensor(seq_tensor), \
        torch.LongTensor(seq_lengths), \
        torch.LongTensor(countries)


# Name/country dataset
class NameDataset(Dataset):
    def __init__(self, is_train_set=True):
        filename = '../dataset/names_train.csv.gz' if is_train_set else '../dataset/names_test.csv.gz'
        with gzip.open(filename, 'rt') as f:
            reader = csv.reader(f)
            rows = list(reader)
        self.names = [row[0] for row in rows]
        self.len = len(self.names)
        self.countries = [row[1] for row in rows]
        self.country_list = list(sorted(set(self.countries)))
        self.country_dict = self.getCountryDict()
        self.country_num = len(self.country_list)

    # Build the country-name -> index dictionary
    def getCountryDict(self):
        country_dict = dict()
        for idx, country_name in enumerate(self.country_list, 0):
            country_dict[country_name] = idx
        return country_dict

    # Number of distinct countries
    def getCountriesNum(self):
        return self.country_num

    # Map a class index back to the country string
    def idx2country(self, index):
        return self.country_list[index]

    def __getitem__(self, index):
        return self.names[index], self.country_dict[self.countries[index]]

    def __len__(self):
        return self.len


# RNN (GRU) classifier
class RNNClassifier(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_layers=1, bidirectional=True):
        super(RNNClassifier, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.n_directions = 2 if bidirectional else 1
        self.embedding = torch.nn.Embedding(input_size, hidden_size)
        self.gru = torch.nn.GRU(hidden_size, hidden_size, n_layers, bidirectional=bidirectional)
        self.fc = torch.nn.Linear(hidden_size * self.n_directions, output_size)

    def forward(self, input, seq_len):
        input = input.t()  # transpose to (seq_len, batch_size)
        batch_size = input.size(1)
        hidden = self._init_hidden(batch_size)
        embedding = self.embedding(input.to(device))
        # Pack the padded batch; seq_len must be sorted in descending order
        gru_input = pack_padded_sequence(embedding, seq_len)
        output, hidden = self.gru(gru_input.to(device), hidden.to(device))
        if self.n_directions == 2:
            # Concatenate the final forward and backward hidden states
            hidden_cat = torch.cat([hidden[-1], hidden[-2]], dim=1)
        else:
            hidden_cat = hidden[-1]
        fc_output = self.fc(hidden_cat)
        return fc_output

    def _init_hidden(self, batch_size):
        return torch.zeros(self.n_layers * self.n_directions, batch_size, self.hidden_size)


# Training function
def trainModel():
    total_loss = 0
    print('=' * 20, 'Epoch', epoch, '=' * 20)
    for i, (names, countries) in enumerate(train_loader, 1):
        inputs, seq_len, target = make_tensors(names, countries)
        output = classifier(inputs, seq_len)
        loss = criterion(output.to(device), target.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if i % 10 == 0:
            print(f'[程序已运行{time.time() - start} 秒]', end='')
            print(f' - [{i * len(inputs)}/{len(train_set)}]', end='')
            print(f' , loss={total_loss / (i * len(inputs))}')


# Test function (evaluates on the test set, despite the printed message)
def tttModel():
    correct = 0
    total = len(test_set)
    with torch.no_grad():
        for i, (names, countries) in enumerate(test_loader, 1):
            inputs, seq_len, target = make_tensors(names, countries)
            output = classifier(inputs, seq_len)
            pred = output.max(dim=1, keepdim=True)[1]
            correct += pred.eq(target.to(device).view_as(pred)).sum().item()
        percent = '%.2f' % (100 * correct / total)
        print(f'在训练集上评估模型: Accuracy {correct}/{total} {percent}%')
    return correct / total


if __name__ == '__main__':
    # Hyperparameters
    HIDDEN_SIZE = 100  # hidden state size
    BATCH_SIZE = 256
    N_LAYER = 2        # number of GRU layers
    N_EPOCHS = 50      # number of epochs
    N_CHARS = 128      # vocabulary size (ASCII)
    USE_GPU = True     # forward() references `device`, so this must stay True

    # Prepare the data
    train_set = NameDataset(is_train_set=True)
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    test_set = NameDataset(is_train_set=False)
    test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

    # Number of output classes
    N_COUNTRY = train_set.getCountriesNum()

    # Instantiate the RNN model
    classifier = RNNClassifier(N_CHARS, HIDDEN_SIZE, N_COUNTRY, N_LAYER)
    if USE_GPU:
        device = torch.device("cuda:0")
        classifier.to(device)

    # Loss function
    criterion = torch.nn.CrossEntropyLoss()
    # Optimizer
    optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)

    start = time.time()
    acc_list = []
    for epoch in range(1, N_EPOCHS + 1):
        trainModel()
        acc = tttModel()
        acc_list.append(acc)

    # Plot test accuracy per epoch
    plt.plot([i + 1 for i in range(len(acc_list))], acc_list)
    plt.show()

5.10 Run Output

Test-set accuracy per epoch:
Console output:

==================== Epoch 1 ====================
[程序已运行0.44780421257019043] - [2560/13374] , loss=0.00890564052388072
[程序已运行0.6133613586425781] - [5120/13374] , loss=0.007599683245643973
[程序已运行0.7878952026367188] - [7680/13374] , loss=0.006917052948847413
[程序已运行0.9584388732910156] - [10240/13374] , loss=0.006446240993682295
[程序已运行1.1319754123687744] - [12800/13374] , loss=0.0060893167182803154
在训练集上评估模型: Accuracy 4472/6700 66.75%
==================== Epoch 2 ====================
[程序已运行1.5678095817565918] - [2560/13374] , loss=0.004120685928501189
[程序已运行1.7363591194152832] - [5120/13374] , loss=0.004010635381564498
[程序已运行1.9069037437438965] - [7680/13374] , loss=0.00399712462288638
[程序已运行2.0754520893096924] - [10240/13374] , loss=0.003917965857544914
[程序已运行2.2400126457214355] - [12800/13374] , loss=0.0038193828100338578
在训练集上评估模型: Accuracy 4984/6700 74.39%
==================== Epoch 3 ====================
[程序已运行2.7569212913513184] - [2560/13374] , loss=0.003344046091660857
[程序已运行2.928469181060791] - [5120/13374] , loss=0.003259772143792361
[程序已运行3.1019983291625977] - [7680/13374] , loss=0.003187967735963563
[程序已运行3.274538040161133] - [10240/13374] , loss=0.0031107491464354097
[程序已运行3.442089319229126] - [12800/13374] , loss=0.0030547570576891303
在训练集上评估模型: Accuracy 5251/6700 78.37%
==================== Epoch 4 ====================
[程序已运行3.8719401359558105] - [2560/13374] , loss=0.0026993038831278683
[程序已运行4.046473503112793] - [5120/13374] , loss=0.002674128650687635
[程序已运行4.216020822525024] - [7680/13374] , loss=0.0026587502487624686
[程序已运行4.38556694984436] - [10240/13374] , loss=0.0025898678810335695
[程序已运行4.550126552581787] - [12800/13374] , loss=0.0025764494528993966
在训练集上评估模型: Accuracy 5364/6700 80.06%
==================== Epoch 5 ====================
[程序已运行4.99593448638916] - [2560/13374] , loss=0.002204193570651114
[程序已运行5.161492109298706] - [5120/13374] , loss=0.002289206086425111
[程序已运行5.331039667129517] - [7680/13374] , loss=0.0022666183416731656
[程序已运行5.49859094619751] - [10240/13374] , loss=0.002261629086569883
[程序已运行5.671129941940308] - [12800/13374] , loss=0.0022295594890601933
在训练集上评估模型: Accuracy 5463/6700 81.54%
==================== Epoch 6 ====================
[程序已运行6.233642578125] - [2560/13374] , loss=0.0019896522513590752
[程序已运行6.456540107727051] - [5120/13374] , loss=0.001986617426155135
[程序已运行6.664837121963501] - [7680/13374] , loss=0.0019825143235114714
[程序已运行6.833385467529297] - [10240/13374] , loss=0.0020075739128515126
[程序已运行7.032860040664673] - [12800/13374] , loss=0.0020189793314784764
在训练集上评估模型: Accuracy 5509/6700 82.22%
==================== Epoch 7 ====================
[程序已运行7.519379138946533] - [2560/13374] , loss=0.0018858694704249502
[程序已运行7.713961362838745] - [5120/13374] , loss=0.0019061228667851537
[程序已运行7.924145698547363] - [7680/13374] , loss=0.0018647492048330604
[程序已运行8.105660200119019] - [10240/13374] , loss=0.0018422425084281713
[程序已运行8.272214412689209] - [12800/13374] , loss=0.001829312415793538
在训练集上评估模型: Accuracy 5554/6700 82.90%
==================== Epoch 8 ====================
[程序已运行8.776322841644287] - [2560/13374] , loss=0.0017062094528228044
[程序已运行8.956839084625244] - [5120/13374] , loss=0.0017181114875711502
[程序已运行9.128380060195923] - [7680/13374] , loss=0.0016691674555962285
[程序已运行9.299922227859497] - [10240/13374] , loss=0.001663037418620661
[程序已运行9.480438709259033] - [12800/13374] , loss=0.0016579296253621577
在训练集上评估模型: Accuracy 5622/6700 83.91%
==================== Epoch 9 ====================
[程序已运行10.042935371398926] - [2560/13374] , loss=0.00138091582339257
[程序已运行10.261351108551025] - [5120/13374] , loss=0.0014375038270372897
[程序已运行10.564762353897095] - [7680/13374] , loss=0.0014615616644732654
[程序已运行10.782928466796875] - [10240/13374] , loss=0.0014871433144435287
[程序已运行10.983396530151367] - [12800/13374] , loss=0.001485794959589839
在训练集上评估模型: Accuracy 5639/6700 84.16%
==================== Epoch 10 ====================
[程序已运行11.527957201004028] - [2560/13374] , loss=0.0013533846475183963
[程序已运行11.727681636810303] - [5120/13374] , loss=0.0013366769824642688
[程序已运行11.922213315963745] - [7680/13374] , loss=0.0013372401415836066
[程序已运行12.146097660064697] - [10240/13374] , loss=0.0013334597359062172
[程序已运行12.360528945922852] - [12800/13374] , loss=0.0013430006837006658
在训练集上评估模型: Accuracy 5671/6700 84.64%
==================== Epoch 11 ====================
[程序已运行12.899601221084595] - [2560/13374] , loss=0.0012242367956787348
[程序已运行13.09657597541809] - [5120/13374] , loss=0.0012018651730613783
[程序已运行13.297226905822754] - [7680/13374] , loss=0.0011939232140624275
[程序已运行13.494091510772705] - [10240/13374] , loss=0.0012114218538044953
[程序已运行13.68766736984253] - [12800/13374] , loss=0.0012003930215723812
在训练集上评估模型: Accuracy 5694/6700 84.99%
==================== Epoch 12 ====================
[程序已运行14.210429906845093] - [2560/13374] , loss=0.0009987134893890471
[程序已运行14.412419557571411] - [5120/13374] , loss=0.001071632924140431
[程序已运行14.66275668144226] - [7680/13374] , loss=0.001076526817632839
[程序已运行14.856532573699951] - [10240/13374] , loss=0.0010921183042228223
[程序已运行15.086088180541992] - [12800/13374] , loss=0.0010934956604614853
在训练集上评估模型: Accuracy 5662/6700 84.51%
==================== Epoch 13 ====================
[程序已运行15.603407859802246] - [2560/13374] , loss=0.0009401162387803197
[程序已运行15.79611325263977] - [5120/13374] , loss=0.000941314865485765
[程序已运行15.991790294647217] - [7680/13374] , loss=0.0009610804088879376
[程序已运行16.179932117462158] - [10240/13374] , loss=0.0009680362229119055
[程序已运行16.40265130996704] - [12800/13374] , loss=0.000982134909136221
在训练集上评估模型: Accuracy 5667/6700 84.58%
==================== Epoch 14 ====================
[程序已运行16.909828424453735] - [2560/13374] , loss=0.0007867097272537648
[程序已运行17.106796979904175] - [5120/13374] , loss=0.000815554044675082
[程序已运行17.292539358139038] - [7680/13374] , loss=0.0008367548192230364
[程序已运行17.472803592681885] - [10240/13374] , loss=0.0008467563966405578
[程序已运行17.639357805252075] - [12800/13374] , loss=0.0008704598969779909
在训练集上评估模型: Accuracy 5668/6700 84.60%
==================== Epoch 15 ====================
[程序已运行18.066723585128784] - [2560/13374] , loss=0.0007090812985552474
[程序已运行18.24724054336548] - [5120/13374] , loss=0.0007125684132915922
[程序已运行18.403822898864746] - [7680/13374] , loss=0.0007363049614165599
[程序已运行18.56738543510437] - [10240/13374] , loss=0.0007670420709473547
[程序已运行18.725961923599243] - [12800/13374] , loss=0.0007662988040829078
在训练集上评估模型: Accuracy 5691/6700 84.94%
==================== Epoch 16 ====================
[程序已运行19.231609582901] - [2560/13374] , loss=0.0006500366900581867
[程序已运行19.412126302719116] - [5120/13374] , loss=0.0006256612250581384
[程序已运行19.586659908294678] - [7680/13374] , loss=0.0006333356761994461
[程序已运行19.75919795036316] - [10240/13374] , loss=0.0006514595239423216
[程序已运行19.923759937286377] - [12800/13374] , loss=0.0006744604662526399
在训练集上评估模型: Accuracy 5667/6700 84.58%
==================== Epoch 17 ====================
[程序已运行20.44287872314453] - [2560/13374] , loss=0.0005562761740293354
[程序已运行20.62125062942505] - [5120/13374] , loss=0.0005988578370306641
[程序已运行20.796781301498413] - [7680/13374] , loss=0.0005926105329611649
[程序已运行20.967325687408447] - [10240/13374] , loss=0.000587682421610225
[程序已运行21.13288402557373] - [12800/13374] , loss=0.0005947937141172587
在训练集上评估模型: Accuracy 5690/6700 84.93%
==================== Epoch 18 ====================
[程序已运行21.58068537712097] - [2560/13374] , loss=0.0004808117635548115
[程序已运行21.750232458114624] - [5120/13374] , loss=0.0004957637938787229
[程序已运行21.91977834701538] - [7680/13374] , loss=0.0005180548294447362
[程序已运行22.09730362892151] - [10240/13374] , loss=0.0005155952792847529
[程序已运行22.261864185333252] - [12800/13374] , loss=0.000532688939711079
在训练集上评估模型: Accuracy 5659/6700 84.46%
==================== Epoch 19 ====================
[程序已运行22.743576526641846] - [2560/13374] , loss=0.00045145115873310713
[程序已运行22.916114330291748] - [5120/13374] , loss=0.0004681079401052557
[程序已运行23.086658239364624] - [7680/13374] , loss=0.0004730896607118969
[程序已运行23.260194301605225] - [10240/13374] , loss=0.00047693301239633
[程序已运行23.420764923095703] - [12800/13374] , loss=0.00048351394420024005
在训练集上评估模型: Accuracy 5686/6700 84.87%
==================== Epoch 20 ====================
[程序已运行23.879538536071777] - [2560/13374] , loss=0.000395867633051239
[程序已运行24.058061122894287] - [5120/13374] , loss=0.00040292235498782245
[程序已运行24.22611951828003] - [7680/13374] , loss=0.00041214636488196753
[程序已运行24.40364646911621] - [10240/13374] , loss=0.0004200873590889387
[程序已运行24.599121570587158] - [12800/13374] , loss=0.00042729771870654077
在训练集上评估模型: Accuracy 5683/6700 84.82%
==================== Epoch 21 ====================
[程序已运行25.093798637390137] - [2560/13374] , loss=0.0003232844435842708
[程序已运行25.266338109970093] - [5120/13374] , loss=0.0003470184303296264
[程序已运行25.42990016937256] - [7680/13374] , loss=0.00036114812683081255
[程序已运行25.605431079864502] - [10240/13374] , loss=0.00036963203929190057
[程序已运行25.77896738052368] - [12800/13374] , loss=0.00038419733085902406
在训练集上评估模型: Accuracy 5656/6700 84.42%
==================== Epoch 22 ====================
[程序已运行26.23973536491394] - [2560/13374] , loss=0.00035485800908645614
[程序已运行26.42025327682495] - [5120/13374] , loss=0.000331850739166839
[程序已运行26.59478497505188] - [7680/13374] , loss=0.00033587128127692266
[程序已运行26.76333522796631] - [10240/13374] , loss=0.0003520488495269092
[程序已运行26.939862489700317] - [12800/13374] , loss=0.00036097451025852935
在训练集上评估模型: Accuracy 5633/6700 84.07%
==================== Epoch 23 ====================
[程序已运行27.467453002929688] - [2560/13374] , loss=0.0003026763326488435
[程序已运行27.70481777191162] - [5120/13374] , loss=0.0003143761219689623
[程序已运行27.911266088485718] - [7680/13374] , loss=0.0003264661878347397
[程序已运行28.186530113220215] - [10240/13374] , loss=0.0003306483005871996
[程序已运行28.46631908416748] - [12800/13374] , loss=0.0003335028060246259
在训练集上评估模型: Accuracy 5673/6700 84.67%
==================== Epoch 24 ====================
[程序已运行29.179603815078735] - [2560/13374] , loss=0.00027135475975228476
[程序已运行29.430758237838745] - [5120/13374] , loss=0.00028713325373246333
[程序已运行29.679874897003174] - [7680/13374] , loss=0.0003023453294493568
[程序已运行29.934645414352417] - [10240/13374] , loss=0.000312889136330341
[程序已运行30.228501319885254] - [12800/13374] , loss=0.0003050918216467835
在训练集上评估模型: Accuracy 5648/6700 84.30%
==================== Epoch 25 ====================
[程序已运行30.873358488082886] - [2560/13374] , loss=0.00023736636503599585
[程序已运行31.12848401069641] - [5120/13374] , loss=0.000262370355630992
[程序已运行31.37047266960144] - [7680/13374] , loss=0.00028469724299308533
[程序已运行31.587894439697266] - [10240/13374] , loss=0.00029130332404747605
[程序已运行31.81080985069275] - [12800/13374] , loss=0.000298141545499675
在训练集上评估模型: Accuracy 5650/6700 84.33%
==================== Epoch 26 ====================
[程序已运行32.386191606521606] - [2560/13374] , loss=0.0002229748701211065
[程序已运行32.59639310836792] - [5120/13374] , loss=0.00024768941948423163
[程序已运行32.80773115158081] - [7680/13374] , loss=0.0002617782947102872
[程序已运行33.03980994224548] - [10240/13374] , loss=0.0002768368027318502
[程序已运行33.29946231842041] - [12800/13374] , loss=0.0002899381099268794
在训练集上评估模型: Accuracy 5641/6700 84.19%
==================== Epoch 27 ====================
[程序已运行33.88727021217346] - [2560/13374] , loss=0.00024374632048420608
[程序已运行34.09471583366394] - [5120/13374] , loss=0.00025897946034092456
[程序已运行34.289196252822876] - [7680/13374] , loss=0.0002538432978326455
[程序已运行34.47469925880432] - [10240/13374] , loss=0.000259564047519234
[程序已运行34.67316937446594] - [12800/13374] , loss=0.00027740672609070314
在训练集上评估模型: Accuracy 5646/6700 84.27%
==================== Epoch 28 ====================
[程序已运行35.17283344268799] - [2560/13374] , loss=0.00024596738221589477
[程序已运行35.37529134750366] - [5120/13374] , loss=0.00025625340931583195
[程序已运行35.612366676330566] - [7680/13374] , loss=0.00025634167104726657
[程序已运行35.83150577545166] - [10240/13374] , loss=0.00026004306309914684
[程序已运行36.03250527381897] - [12800/13374] , loss=0.00026288528344593944
在训练集上评估模型: Accuracy 5638/6700 84.15%
==================== Epoch 29 ====================
[程序已运行36.57138133049011] - [2560/13374] , loss=0.00021622761414619163
[程序已运行36.77258634567261] - [5120/13374] , loss=0.00022565958206541837
[程序已运行36.98577690124512] - [7680/13374] , loss=0.0002272359988031288
[程序已运行37.19248104095459] - [10240/13374] , loss=0.000239311121913488
[程序已运行37.42102932929993] - [12800/13374] , loss=0.00026318065298255535
在训练集上评估模型: Accuracy 5660/6700 84.48%
==================== Epoch 30 ====================
[程序已运行37.98303580284119] - [2560/13374] , loss=0.0002419873373582959
[程序已运行38.19469690322876] - [5120/13374] , loss=0.00024032749206526205
[程序已运行38.433921813964844] - [7680/13374] , loss=0.00023429312568623573
[程序已运行38.63934087753296] - [10240/13374] , loss=0.000236004658654565
[程序已运行38.855393171310425] - [12800/13374] , loss=0.00024584917380707336
在训练集上评估模型: Accuracy 5652/6700 84.36%
==================== Epoch 31 ====================
[程序已运行39.43484330177307] - [2560/13374] , loss=0.00018982139372383244
[程序已运行39.639156103134155] - [5120/13374] , loss=0.0001874874855275266
[程序已运行39.86286163330078] - [7680/13374] , loss=0.0002010059698174397
[程序已运行40.10820508003235] - [10240/13374] , loss=0.00022044739162083716
[程序已运行40.321282148361206] - [12800/13374] , loss=0.00023711820482276381
在训练集上评估模型: Accuracy 5650/6700 84.33%
==================== Epoch 32 ====================
[程序已运行40.8946692943573] - [2560/13374] , loss=0.00015288349750335327
[程序已运行41.10710620880127] - [5120/13374] , loss=0.0001839709879277507
[程序已运行41.33549523353577] - [7680/13374] , loss=0.00019845828088970545
[程序已运行41.57286047935486] - [10240/13374] , loss=0.00021456888480315684
[程序已运行41.8036105632782] - [12800/13374] , loss=0.0002269831285229884
在训练集上评估模型: Accuracy 5666/6700 84.57%
==================== Epoch 33 ====================
[程序已运行42.351365089416504] - [2560/13374] , loss=0.00020009858708363025
[程序已运行42.55391049385071] - [5120/13374] , loss=0.00021211478597251698
[程序已运行42.762874364852905] - [7680/13374] , loss=0.00021397633778785045
[程序已运行42.9832968711853] - [10240/13374] , loss=0.0002237053031421965
[程序已运行43.203397274017334] - [12800/13374] , loss=0.00022588698309846221
在训练集上评估模型: Accuracy 5663/6700 84.52%
==================== Epoch 34 ====================
[程序已运行43.75180983543396] - [2560/13374] , loss=0.000175720240076771
[程序已运行43.96774888038635] - [5120/13374] , loss=0.0001822654259740375
[程序已运行44.17185306549072] - [7680/13374] , loss=0.00019087508820424167
[程序已运行44.37160062789917] - [10240/13374] , loss=0.0002025523937845719
[程序已运行44.58753275871277] - [12800/13374] , loss=0.0002117827812617179
在训练集上评估模型: Accuracy 5661/6700 84.49%
==================== Epoch 35 ====================
[程序已运行45.14763021469116] - [2560/13374] , loss=0.00018244399325340055
[程序已运行45.38248586654663] - [5120/13374] , loss=0.0001909391281515127
[程序已运行45.58654308319092] - [7680/13374] , loss=0.00019028932996055422
[程序已运行45.78802943229675] - [10240/13374] , loss=0.00019807684529951075
[程序已运行45.99347996711731] - [12800/13374] , loss=0.00020920663388096728
在训练集上评估模型: Accuracy 5680/6700 84.78%
==================== Epoch 36 ====================
[程序已运行46.53318977355957] - [2560/13374] , loss=0.0002301990520209074
[程序已运行46.713706970214844] - [5120/13374] , loss=0.00019428682790021413
[程序已运行46.90220332145691] - [7680/13374] , loss=0.00020380564043686415
[程序已运行47.091697454452515] - [10240/13374] , loss=0.00020663877294282428
[程序已运行47.291162967681885] - [12800/13374] , loss=0.0002089952616370283
在训练集上评估模型: Accuracy 5676/6700 84.72%
==================== Epoch 37 ====================
[程序已运行47.806785345077515] - [2560/13374] , loss=0.0001326974183029961
[程序已运行47.99926948547363] - [5120/13374] , loss=0.00016608303412795066
[程序已运行48.17978763580322] - [7680/13374] , loss=0.00018860103252033394
[程序已运行48.39720582962036] - [10240/13374] , loss=0.00019829184930131305
[程序已运行48.613649129867554] - [12800/13374] , loss=0.00020632889354601502
在训练集上评估模型: Accuracy 5659/6700 84.46%
==================== Epoch 38 ====================
[程序已运行49.16527962684631] - [2560/13374] , loss=0.00017558708641445263
[程序已运行49.377071380615234] - [5120/13374] , loss=0.0001803442672098754
[程序已运行49.61643576622009] - [7680/13374] , loss=0.00019150104708387516
[程序已运行49.82301330566406] - [10240/13374] , loss=0.00020083691142644967
[程序已运行50.0354540348053] - [12800/13374] , loss=0.00020664259296609088
在训练集上评估模型: Accuracy 5665/6700 84.55%
==================== Epoch 39 ====================
[程序已运行50.57412886619568] - [2560/13374] , loss=0.0001647219163714908
[程序已运行50.781656980514526] - [5120/13374] , loss=0.00017860122316051274
[程序已运行50.985915660858154] - [7680/13374] , loss=0.00018149100748511652
[程序已运行51.19026255607605] - [10240/13374] , loss=0.00018654396444617305
[程序已运行51.396902322769165] - [12800/13374] , loss=0.00019622410647571087
在训练集上评估模型: Accuracy 5674/6700 84.69%
==================== Epoch 40 ====================
[程序已运行51.950305223464966] - [2560/13374] , loss=0.0001813760813092813
[程序已运行52.168620586395264] - [5120/13374] , loss=0.0002006525617616717
[程序已运行52.388041496276855] - [7680/13374] , loss=0.00020146279275650157
[程序已运行52.609692335128784] - [10240/13374] , loss=0.00020502966435742564
[程序已运行52.82223105430603] - [12800/13374] , loss=0.00020634495449485258
在训练集上评估模型: Accuracy 5660/6700 84.48%
==================== Epoch 41 ====================
[程序已运行53.38243508338928] - [2560/13374] , loss=0.00015071030429680832
[程序已运行53.587127923965454] - [5120/13374] , loss=0.00017189887112181168
[程序已运行53.806541204452515] - [7680/13374] , loss=0.00017305417253131358
[程序已运行54.02480912208557] - [10240/13374] , loss=0.00019079587727901525
[程序已运行54.2585232257843] - [12800/13374] , loss=0.00020163424400379882
在训练集上评估模型: Accuracy 5648/6700 84.30%
==================== Epoch 42 ====================
[程序已运行54.82281470298767] - [2560/13374] , loss=0.00013891297894588205
[程序已运行55.025827407836914] - [5120/13374] , loss=0.00015851605949137592
[程序已运行55.231240034103394] - [7680/13374] , loss=0.00016577779958121635
[程序已运行55.46313452720642] - [10240/13374] , loss=0.00018027596570391324
[程序已运行55.67412066459656] - [12800/13374] , loss=0.0001966867937153438
在训练集上评估模型: Accuracy 5665/6700 84.55%
==================== Epoch 43 ====================
[程序已运行56.20661902427673] - [2560/13374] , loss=0.00015996379297575914
[程序已运行56.41858148574829] - [5120/13374] , loss=0.00016856673537404276
[程序已运行56.639917612075806] - [7680/13374] , loss=0.00018089073128066958
[程序已运行56.8443865776062] - [10240/13374] , loss=0.0001840785967942793
[程序已运行57.06223940849304] - [12800/13374] , loss=0.00019384760729735718
在训练集上评估模型: Accuracy 5645/6700 84.25%
==================== Epoch 44 ====================
[程序已运行57.572874546051025] - [2560/13374] , loss=0.00015061586382216773
[程序已运行57.78630328178406] - [5120/13374] , loss=0.0001388985794619657
[程序已运行57.9957435131073] - [7680/13374] , loss=0.00016478789621032774
[程序已运行58.19321537017822] - [10240/13374] , loss=0.00017915765802172244
[程序已运行58.42410707473755] - [12800/13374] , loss=0.00018903483942267485
在训练集上评估模型: Accuracy 5671/6700 84.64%
==================== Epoch 45 ====================
[程序已运行58.94371771812439] - [2560/13374] , loss=0.00015125685386010445
[程序已运行59.1222403049469] - [5120/13374] , loss=0.0001521789192338474
[程序已运行59.306747913360596] - [7680/13374] , loss=0.00016950705078973746
[程序已运行59.48726511001587] - [10240/13374] , loss=0.00017524105423945003
[程序已运行59.67775559425354] - [12800/13374] , loss=0.0001805883324414026
在训练集上评估模型: Accuracy 5667/6700 84.58%
==================== Epoch 46 ====================
[程序已运行60.20154929161072] - [2560/13374] , loss=0.00014828720522928052
[程序已运行60.40001821517944] - [5120/13374] , loss=0.00015882042025623378
[程序已运行60.58452534675598] - [7680/13374] , loss=0.0001613501580626083
[程序已运行60.77501559257507] - [10240/13374] , loss=0.00016973716010397765
[程序已运行60.948551416397095] - [12800/13374] , loss=0.0001806799619225785
在训练集上评估模型: Accuracy 5659/6700 84.46%
==================== Epoch 47 ====================
[程序已运行61.3913676738739] - [2560/13374] , loss=0.00015815104998182506
[程序已运行61.57487750053406] - [5120/13374] , loss=0.00016703655710443855
[程序已运行61.75738835334778] - [7680/13374] , loss=0.00017824470463286465
[程序已运行61.92992830276489] - [10240/13374] , loss=0.00017852899200079264
[程序已运行62.09648156166077] - [12800/13374] , loss=0.00018316529472940603
在训练集上评估模型: Accuracy 5659/6700 84.46%
==================== Epoch 48 ====================
[程序已运行62.557249307632446] - [2560/13374] , loss=0.00015278960127034223
[程序已运行62.73379874229431] - [5120/13374] , loss=0.00015682993143855127
[程序已运行62.920299768447876] - [7680/13374] , loss=0.00016720300166828868
[程序已运行63.097824573516846] - [10240/13374] , loss=0.00017679908323771089
[程序已运行63.27834177017212] - [12800/13374] , loss=0.00018408266638289205
在训练集上评估模型: Accuracy 5632/6700 84.06%
==================== Epoch 49 ====================
[程序已运行63.809921741485596] - [2560/13374] , loss=0.00013651168810611126
[程序已运行63.99043846130371] - [5120/13374] , loss=0.00015391943106806137
[程序已运行64.16397380828857] - [7680/13374] , loss=0.00016998272246079675
[程序已运行64.34748315811157] - [10240/13374] , loss=0.0001738338686664065
[程序已运行64.52999496459961] - [12800/13374] , loss=0.0001867358752497239
在训练集上评估模型: Accuracy 5660/6700 84.48%
==================== Epoch 50 ====================
[程序已运行64.99973917007446] - [2560/13374] , loss=0.00014780706333112904
[程序已运行65.21316909790039] - [5120/13374] , loss=0.00015933765134832357
[程序已运行65.40266251564026] - [7680/13374] , loss=0.0001664065877169681
[程序已运行65.57220840454102] - [10240/13374] , loss=0.00017008764552883804
[程序已运行65.75272583961487] - [12800/13374] , loss=0.00018112511475919745
在训练集上评估模型: Accuracy 5646/6700 84.27%
