Face Recognition 0-04: insightFace - Model Training Code Annotated in Detail (the most complete guide)


The link below collects all of my notes on insightFace. If you spot any mistakes, please point them out and I will correct them right away; if you are interested, add me on WeChat (a944284742) to discuss the technology.
Face Recognition 0-00: insightFace table of contents: https://blog.csdn.net/weixin_43013761/article/details/99646731

Version changes

In the author's initial release, training was done with the code under insightface-master\src. I am using what is currently the newest version, which lives under insightface-master\recognition. By the time you read this post the author may have published yet another version, but that does not matter: my copy of the code is provided through the link above. Below we walk through insightface-master\recognition\train.py, the core training script. As described in an earlier post, we copied sample_config.py to config.py; this file supplies the configuration needed for model training.

config.py

import numpy as np
import os
from easydict import EasyDict as edict

# The settings in `config` are the base configuration; if the same key is set
# again later by a network / dataset / loss dictionary, it gets overwritten.
config = edict()

config.bn_mom = 0.9                 # momentum of the BatchNorm layers
config.workspace = 256              # workspace size (MB) that MXNet may use as scratch space
config.emb_size = 128               # dimension of the output feature embedding
config.ckpt_embedding = True        # save only the embedding (fc1) part of the network in checkpoints
config.net_se = 0                   # SE (squeeze-and-excitation) switch; the author was unsure about this one
config.net_act = 'prelu'            # activation function
config.net_unit = 3                 # residual unit version
config.net_input = 1                # input (stem) version
config.net_blocks = [1, 4, 6, 2]    # number of blocks in each stage
config.net_output = 'E'             # type of the output/embedding head ('GDC' is another option, see recognition\symbol\symbol_utils.py)
config.net_multiplier = 1.0         # width multiplier for the mobile networks
config.val_targets = ['lfw', 'cfp_fp', 'agedb_30']  # verification sets, i.e. the files with a .bin suffix
config.ce_loss = True               # also report the softmax cross-entropy loss value (for logging)
config.fc7_lr_mult = 1.0            # learning-rate multiplier for the fc7 layer
config.fc7_wd_mult = 1.0            # weight-decay multiplier for the fc7 layer
config.fc7_no_bias = False          # build fc7 without a bias term
config.max_steps = 0                # maximum number of training steps; 0 means no limit
config.data_rand_mirror = True      # randomly mirror (horizontally flip) training images
config.data_cutoff = False          # random cutoff (erasing) augmentation
config.data_color = 0               # color-jitter augmentation level
config.data_images_filter = 0       # filter out identities with too few images; the author was unsure about this one
config.count_flops = True           # whether to count the FLOPs of the network
config.memonger = False             # not working at the moment

# Network settings. Many network structures are defined below and they will not
# all be annotated one by one, since I have not studied every architecture in
# depth. Any of them can be selected at training time:
# r100 r100fc r50 r50v1 d169 d201 y1 m1 m05 mnas mnas025
network = edict()

network.r100 = edict()
network.r100.net_name = 'fresnet'
network.r100.num_layers = 100

network.r100fc = edict()
network.r100fc.net_name = 'fresnet'
network.r100fc.num_layers = 100
network.r100fc.net_output = 'FC'

network.r50 = edict()
network.r50.net_name = 'fresnet'
network.r50.num_layers = 50

network.r50v1 = edict()
network.r50v1.net_name = 'fresnet'
network.r50v1.num_layers = 50
network.r50v1.net_unit = 1

network.d169 = edict()
network.d169.net_name = 'fdensenet'
network.d169.num_layers = 169
network.d169.per_batch_size = 64
network.d169.densenet_dropout = 0.0

network.d201 = edict()
network.d201.net_name = 'fdensenet'
network.d201.num_layers = 201
network.d201.per_batch_size = 64
network.d201.densenet_dropout = 0.0

network.y1 = edict()
network.y1.net_name = 'fmobilefacenet'
network.y1.emb_size = 128
network.y1.net_output = 'GDC'

network.y2 = edict()
network.y2.net_name = 'fmobilefacenet'
network.y2.emb_size = 256
network.y2.net_output = 'GDC'
network.y2.net_blocks = [2, 8, 16, 4]

network.m1 = edict()
network.m1.net_name = 'fmobilenet'
network.m1.emb_size = 256
network.m1.net_output = 'GDC'
network.m1.net_multiplier = 1.0

network.m05 = edict()
network.m05.net_name = 'fmobilenet'
network.m05.emb_size = 256
network.m05.net_output = 'GDC'
network.m05.net_multiplier = 0.5

network.mnas = edict()
network.mnas.net_name = 'fmnasnet'
network.mnas.emb_size = 256
network.mnas.net_output = 'GDC'
network.mnas.net_multiplier = 1.0

network.mnas05 = edict()
network.mnas05.net_name = 'fmnasnet'
network.mnas05.emb_size = 256
network.mnas05.net_output = 'GDC'
network.mnas05.net_multiplier = 0.5

network.mnas025 = edict()
network.mnas025.net_name = 'fmnasnet'
network.mnas025.emb_size = 256
network.mnas025.net_output = 'GDC'
network.mnas025.net_multiplier = 0.25

# Dataset settings. Two datasets are defined, emore and retina; only one can be
# chosen for a training run. num_classes comes from the dataset's property file
# and is the number of face identities in the training data.
dataset = edict()

dataset.emore = edict()
dataset.emore.dataset = 'emore'
dataset.emore.dataset_path = '../../../2.dataset/1.officialData/1.traindata/faces_glint'
dataset.emore.num_classes = 180855
dataset.emore.image_shape = (112, 112, 3)
dataset.emore.val_targets = ['lfw', 'cfp_fp', 'agedb_30']

dataset.retina = edict()
dataset.retina.dataset = 'retina'
dataset.retina.dataset_path = '../datasets/ms1m-retinaface-t1'
dataset.retina.num_classes = 93431
dataset.retina.image_shape = (112, 112, 3)
dataset.retina.val_targets = ['lfw', 'cfp_fp', 'agedb_30']

# The loss settings are the part we care about most. Do not be put off by the
# three margins loss_m1, loss_m2 and loss_m3: to save code, the author folds
# several losses (nsoftmax, arcface, cosface, combined) into a single
# implementation parameterized by them.
loss = edict()

loss.softmax = edict()
loss.softmax.loss_name = 'softmax'

loss.nsoftmax = edict()
loss.nsoftmax.loss_name = 'margin_softmax'
loss.nsoftmax.loss_s = 64.0
loss.nsoftmax.loss_m1 = 1.0
loss.nsoftmax.loss_m2 = 0.0
loss.nsoftmax.loss_m3 = 0.0

loss.arcface = edict()
loss.arcface.loss_name = 'margin_softmax'
loss.arcface.loss_s = 64.0
loss.arcface.loss_m1 = 1.0
loss.arcface.loss_m2 = 0.5
loss.arcface.loss_m3 = 0.0

loss.cosface = edict()
loss.cosface.loss_name = 'margin_softmax'
loss.cosface.loss_s = 64.0
loss.cosface.loss_m1 = 1.0
loss.cosface.loss_m2 = 0.0
loss.cosface.loss_m3 = 0.35

loss.combined = edict()
loss.combined.loss_name = 'margin_softmax'
loss.combined.loss_s = 64.0
loss.combined.loss_m1 = 1.0
loss.combined.loss_m2 = 0.3
loss.combined.loss_m3 = 0.2

loss.triplet = edict()
loss.triplet.loss_name = 'triplet'
loss.triplet.images_per_identity = 5
loss.triplet.triplet_alpha = 0.3
loss.triplet.triplet_bag_size = 7200
loss.triplet.triplet_max_ap = 0.0
loss.triplet.per_batch_size = 60
loss.triplet.lr = 0.05

loss.atriplet = edict()
loss.atriplet.loss_name = 'atriplet'
loss.atriplet.images_per_identity = 5
loss.atriplet.triplet_alpha = 0.35
loss.atriplet.triplet_bag_size = 7200
loss.atriplet.triplet_max_ap = 0.0
loss.atriplet.per_batch_size = 60
loss.atriplet.lr = 0.05

# default settings
default = edict()

# default network
default.network = 'r100'
#default.pretrained = ''
default.pretrained = '../models/model-y1-test2/model'
default.pretrained_epoch = 0
# default dataset
default.dataset = 'emore'
default.loss = 'arcface'
default.frequent = 20               # print accuracy and other logs every 20 batches
default.verbose = 2000              # evaluate on the verification data every 2000 batches
default.kvstore = 'device'          # key-value store used for training
default.end_epoch = 10000           # last epoch
default.lr = 0.01                   # initial learning rate; lower it when the per-batch size is small
default.wd = 0.0005                 # weight decay
default.mom = 0.9                   # SGD momentum
default.per_batch_size = 48         # batch size per GPU; with two GPUs the effective batch size is 96
default.ckpt = 0                    # checkpoint saving option (see --ckpt in train.py)
default.lr_steps = '100000,160000,220000'   # at each of these steps the learning rate drops to one tenth
default.models_root = './models'    # where models are saved


# Update `config` (and any matching keys in `default`) with the chosen loss,
# network and dataset dictionaries.
def generate_config(_network, _dataset, _loss):
    for k, v in loss[_loss].items():
        config[k] = v
        if k in default:
            default[k] = v
    for k, v in network[_network].items():
        config[k] = v
        if k in default:
            default[k] = v
    for k, v in dataset[_dataset].items():
        config[k] = v
        if k in default:
            default[k] = v
    config.loss = _loss
    config.network = _network
    config.dataset = _dataset
    config.num_workers = 1
    if 'DMLC_NUM_WORKER' in os.environ:
        config.num_workers = int(os.environ['DMLC_NUM_WORKER'])
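
To make the role of generate_config concrete, here is a minimal usage sketch. It assumes the config.py above is importable (i.e. you run it from the recognition directory with easydict installed); the expected values follow directly from the dictionaries listed above:

# Minimal usage sketch; assumes the config.py shown above is on the import path.
from config import config, default, generate_config

generate_config('y1', 'emore', 'arcface')
print(config.net_name)     # 'fmobilefacenet'  (from network.y1)
print(config.emb_size)     # 128               (network.y1 overrides the base value)
print(config.num_classes)  # 180855            (from dataset.emore)
print(config.loss_m1, config.loss_m2, config.loss_m3)  # 1.0 0.5 0.0, the ArcFace margin
print(default.dataset)     # keys that also exist in `default` are updated as well

Note the merge order: loss first, then network, then dataset, so a key set by the dataset dictionary wins over one set by the network dictionary, and both win over the base values at the top of config.py.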

The annotations took a fair amount of effort; if they help you, please give the post a like, which is the biggest encouragement for me. Next, here is the insightface-master\recognition\train.py code:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import math
import random
import logging
import sklearn
import pickle
import numpy as np
import mxnet as mx
from mxnet import ndarray as nd
import argparse
import mxnet.optimizer as optimizer
from config import config, default, generate_config
from metric import *

sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'common'))
import flops_counter

sys.path.append(os.path.join(os.path.dirname(__file__), 'eval'))
import verification

sys.path.append(os.path.join(os.path.dirname(__file__), 'symbol'))
import fresnet
import fmobilefacenet
import fmobilenet
import fmnasnet
import fdensenet

print(mx.__file__)

logger = logging.getLogger()
logger.setLevel(logging.INFO)

args = None


def parse_args():
    parser = argparse.ArgumentParser(description='Train face network')
    # general
    # default dataset configuration
    parser.add_argument('--dataset', default=default.dataset, help='dataset config')
    # default network architecture
    parser.add_argument('--network', default=default.network, help='network config')
    # default loss function
    parser.add_argument('--loss', default=default.loss, help='loss config')
    # parse the arguments known so far and merge the chosen network/dataset/loss into config
    args, rest = parser.parse_known_args()
    generate_config(args.network, args.dataset, args.loss)
    # directory where models are saved
    parser.add_argument('--models-root', default=default.models_root, help='root directory to save model.')
    # pretrained model to load
    parser.add_argument('--pretrained', default=default.pretrained, help='pretrained model to load')
    # epoch of the pretrained model to load
    parser.add_argument('--pretrained-epoch', type=int, default=default.pretrained_epoch,
                        help='pretrained epoch to load')
    # whether and how to save checkpoints
    parser.add_argument('--ckpt', type=int, default=default.ckpt,
                        help='checkpoint saving option. 0: discard saving. 1: save when necessary. 2: always save')
    # run verification every `verbose` batches
    parser.add_argument('--verbose', type=int, default=default.verbose,
                        help='do verification testing and model saving every verbose batches')
    # learning rate
    parser.add_argument('--lr', type=float, default=default.lr, help='start learning rate')
    parser.add_argument('--lr-steps', type=str, default=default.lr_steps, help='steps of lr changing')
    # weight decay
    parser.add_argument('--wd', type=float, default=default.wd, help='weight decay')
    # SGD momentum
    parser.add_argument('--mom', type=float, default=default.mom, help='momentum')
    parser.add_argument('--frequent', type=int, default=default.frequent, help='')
    # samples per batch on each GPU
    parser.add_argument('--per-batch-size', type=int, default=default.per_batch_size, help='batch size in each context')
    # key-value store setting
    parser.add_argument('--kvstore', type=str, default=default.kvstore, help='kvstore setting')
    args = parser.parse_args()
    return args


def get_symbol(args):
    # build the backbone and get the embedding symbol
    embedding = eval(config.net_name).get_symbol()
    # placeholder variable that will hold the ground-truth labels
    all_label = mx.symbol.Variable('softmax_label')
    gt_label = all_label
    is_softmax = True
    # plain softmax loss
    if config.loss_name == 'softmax':
        # weight of the fully connected (fc7, classification) layer
        _weight = mx.symbol.Variable("fc7_weight", shape=(config.num_classes, config.emb_size),
                                     lr_mult=config.fc7_lr_mult, wd_mult=config.fc7_wd_mult, init=mx.init.Normal(0.01))
        # if no bias is configured, apply the fully connected layer directly
        if config.fc7_no_bias:
            fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, no_bias=True, num_hidden=config.num_classes,
                                        name='fc7')
        # otherwise create the bias first and then apply the fully connected layer
        else:
            _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
            fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=config.num_classes,
                                        name='fc7')
    # margin softmax loss (nsoftmax / arcface / cosface / combined)
    elif config.loss_name == 'margin_softmax':
        # create the weight placeholder for fc7
        _weight = mx.symbol.Variable("fc7_weight", shape=(config.num_classes, config.emb_size),
                                     lr_mult=config.fc7_lr_mult, wd_mult=config.fc7_wd_mult, init=mx.init.Normal(0.01))
        # scale factor s used by the loss
        s = config.loss_s
        # L2-normalize the weight and the embedding, then apply the fully connected layer
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=config.num_classes,
                                    name='fc7')
        # m1, m2 and m3 exist to fold the different margin losses into one implementation
        if config.loss_m1 != 1.0 or config.loss_m2 != 0.0 or config.loss_m3 != 0.0:
            if config.loss_m1 == 1.0 and config.loss_m2 == 0.0:
                s_m = s * config.loss_m3
                gt_one_hot = mx.sym.one_hot(gt_label, depth=config.num_classes, on_value=s_m, off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy / s
                t = mx.sym.arccos(cos_t)
                if config.loss_m1 != 1.0:
                    t = t * config.loss_m1
                if config.loss_m2 > 0.0:
                    t = t + config.loss_m2
                body = mx.sym.cos(t)
                if config.loss_m3 > 0.0:
                    body = body - config.loss_m3
                new_zy = body * s
                diff = new_zy - zy
                diff = mx.sym.expand_dims(diff, 1)
                gt_one_hot = mx.sym.one_hot(gt_label, depth=config.num_classes, on_value=1.0, off_value=0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, diff)
                fc7 = fc7 + body
    # triplet losses
    elif config.loss_name.find('triplet') >= 0:
        is_softmax = False
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
        anchor = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=args.per_batch_size // 3)
        positive = mx.symbol.slice_axis(nembedding, axis=0, begin=args.per_batch_size // 3,
                                        end=2 * args.per_batch_size // 3)
        negative = mx.symbol.slice_axis(nembedding, axis=0, begin=2 * args.per_batch_size // 3, end=args.per_batch_size)
        if config.loss_name == 'triplet':
            ap = anchor - positive
            an = anchor - negative
            ap = ap * ap
            an = an * an
            ap = mx.symbol.sum(ap, axis=1, keepdims=1)  # (T,1)
            an = mx.symbol.sum(an, axis=1, keepdims=1)  # (T,1)
            triplet_loss = mx.symbol.Activation(data=(ap - an + config.triplet_alpha), act_type='relu')
            triplet_loss = mx.symbol.mean(triplet_loss)
        else:
            ap = anchor * positive
            an = anchor * negative
            ap = mx.symbol.sum(ap, axis=1, keepdims=1)  # (T,1)
            an = mx.symbol.sum(an, axis=1, keepdims=1)  # (T,1)
            ap = mx.sym.arccos(ap)
            an = mx.sym.arccos(an)
            triplet_loss = mx.symbol.Activation(data=(ap - an + config.triplet_alpha), act_type='relu')
            triplet_loss = mx.symbol.mean(triplet_loss)
        triplet_loss = mx.symbol.MakeLoss(triplet_loss)
    out_list = [mx.symbol.BlockGrad(embedding)]
    # softmax-style losses
    if is_softmax:
        softmax = mx.symbol.SoftmaxOutput(data=fc7, label=gt_label, name='softmax', normalization='valid')
        out_list.append(softmax)
        if config.ce_loss:
            # ce_loss = mx.symbol.softmax_cross_entropy(data=fc7, label = gt_label, name='ce_loss')/args.per_batch_size
            body = mx.symbol.SoftmaxActivation(data=fc7)
            body = mx.symbol.log(body)
            _label = mx.sym.one_hot(gt_label, depth=config.num_classes, on_value=-1.0, off_value=0.0)
            body = body * _label
            ce_loss = mx.symbol.sum(body) / args.per_batch_size
            out_list.append(mx.symbol.BlockGrad(ce_loss))
    # triplet losses
    else:
        out_list.append(mx.sym.BlockGrad(gt_label))
        out_list.append(triplet_loss)
    # group all output symbols
    out = mx.symbol.Group(out_list)
    return out


def train_net(args):
    # decide whether to use GPU or CPU
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    # prefix for saved models
    prefix = os.path.join(args.models_root, '%s-%s-%s' % (args.network, args.loss, args.dataset), 'model')
    # directory in which models are saved
    prefix_dir = os.path.dirname(prefix)
    print('prefix', prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    # number of GPUs
    args.ctx_num = len(ctx)
    # compute the global batch size
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = config.image_shape[2]
    config.batch_size = args.batch_size
    # per-GPU batch size
    config.per_batch_size = args.per_batch_size
    # training data directory
    data_dir = config.dataset_path
    path_imgrec = None
    path_imglist = None
    # image size and sanity checks
    image_size = config.image_shape[0:2]
    assert len(image_size) == 2
    assert image_size[0] == image_size[1]
    print('image_size', image_size)
    # number of identities in the dataset
    print('num_classes', config.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    print('Called with argument:', args, config)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    # if no pretrained model is given, initialize the weights from scratch
    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
        sym = get_symbol(args)  # build the model
        if config.net_name == 'spherenet':
            data_shape_dict = {'data': (args.per_batch_size,) + data_shape}
            spherenet.init_weights(sym, data_shape_dict, args.num_layers)
    else:  # otherwise load the pretrained checkpoint
        print('loading', args.pretrained, args.pretrained_epoch)
        _, arg_params, aux_params = mx.model.load_checkpoint(args.pretrained, args.pretrained_epoch)
        sym = get_symbol(args)

    # count the FLOPs of the network
    if config.count_flops:
        all_layers = sym.get_internals()
        _sym = all_layers['fc1_output']
        FLOPs = flops_counter.count_flops(_sym, data=(1, 3, image_size[0], image_size[1]))
        _str = flops_counter.flops_str(FLOPs)
        print('Network FLOPs: %s' % _str)

    # label_name = 'softmax_label'
    # label_shape = (args.batch_size,)
    model = mx.mod.Module(
        context=mx.gpu(),
        symbol=sym,
    )
    val_dataiter = None

    # build the training data iterator; the triplet and softmax losses use
    # different iterators (how they differ is analysed in a later post)
    if config.loss_name.find('triplet') >= 0:
        from triplet_image_iter import FaceImageIter
        triplet_params = [config.triplet_bag_size, config.triplet_alpha, config.triplet_max_ap]
        train_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            shuffle=True,
            rand_mirror=config.data_rand_mirror,
            mean=mean,
            cutoff=config.data_cutoff,
            ctx_num=args.ctx_num,
            images_per_identity=config.images_per_identity,
            triplet_params=triplet_params,
            mx_model=model,
        )
        _metric = LossValueMetric()
        eval_metrics = [mx.metric.create(_metric)]
    else:
        from image_iter import FaceImageIter
        train_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            shuffle=True,
            rand_mirror=config.data_rand_mirror,
            mean=mean,
            cutoff=config.data_cutoff,
            color_jittering=config.data_color,
            images_filter=config.data_images_filter,
        )
        metric1 = AccMetric()
        eval_metrics = [mx.metric.create(metric1)]
        if config.ce_loss:
            metric2 = LossValueMetric()
            eval_metrics.append(mx.metric.create(metric2))

    # weight initialization
    if config.net_name == 'fresnet' or config.net_name == 'fmobilefacenet':
        initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2)  # resnet style
    else:
        initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)
    # initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    _rescale = 1.0 / args.ctx_num
    opt = optimizer.SGD(learning_rate=args.lr, momentum=args.mom, wd=args.wd, rescale_grad=_rescale)
    _cb = mx.callback.Speedometer(args.batch_size, args.frequent)

    # load all verification datasets
    ver_list = []
    ver_name_list = []
    for name in config.val_targets:
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    # evaluate the model on the verification sets
    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_list[i], model, args.batch_size, 10,
                                                                               None, None)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            # print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    # highest accuracies seen so far
    highest_acc = [0.0, 0.0]  # lfw and target
    # for i in range(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        # global global_step
        global_step[0] += 1
        mbatch = global_step[0]
        # drop the learning rate to one tenth at each lr step
        for step in lr_steps:
            if mbatch == step:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        # print a short status line every 1000 batches
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)

        # every `verbose` batches, run verification and possibly save a checkpoint
        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            is_highest = False
            if len(acc_list) > 0:
                # lfw_score = acc_list[0]
                # if lfw_score>highest_acc[0]:
                #  highest_acc[0] = lfw_score
                #  if lfw_score>=0.998:
                #    do_save = True
                score = sum(acc_list)
                if acc_list[-1] >= highest_acc[-1]:
                    if acc_list[-1] > highest_acc[-1]:
                        is_highest = True
                    else:
                        if score >= highest_acc[0]:
                            is_highest = True
                            highest_acc[0] = score
                    highest_acc[-1] = acc_list[-1]
                    # if lfw_score>=0.99:
                    #  do_save = True
            if is_highest:
                do_save = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt == 2:
                do_save = True
            elif args.ckpt == 3:
                msave = 1
            # save the model
            if do_save:
                print('saving', msave)
                arg, aux = model.get_params()
                if config.ckpt_embedding:
                    all_layers = model.symbol.get_internals()
                    _sym = all_layers['fc1_output']
                    _arg = {}
                    for k in arg:
                        if not k.startswith('fc7'):
                            _arg[k] = arg[k]
                    mx.model.save_checkpoint(prefix, msave, _sym, _arg, aux)
                else:
                    mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if config.max_steps > 0 and mbatch > config.max_steps:
            sys.exit(0)

    epoch_cb = None
    # wrap train_dataiter in an mx.io.PrefetchingIter iterator
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)

    model.fit(train_dataiter,
              begin_epoch=begin_epoch,
              num_epoch=999999,
              eval_data=val_dataiter,
              eval_metric=eval_metrics,
              kvstore=args.kvstore,
              optimizer=opt,
              # optimizer_params   = optimizer_params,
              initializer=initializer,
              arg_params=arg_params,
              aux_params=aux_params,
              allow_missing=True,
              batch_end_callback=_batch_callback,
              epoch_end_callback=epoch_cb)


def main():
    global args
    args = parse_args()
    train_net(args)


if __name__ == '__main__':
    main()
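
Two details of train_net that are easy to miss: the global batch size is per_batch_size multiplied by the number of visible GPUs (with gradients rescaled by 1/ctx_num in the SGD optimizer), and the learning rate is divided by 10 each time the global step reaches one of the --lr-steps values. A standalone sketch of that arithmetic, using the defaults from config.py and an assumed two-GPU setup:

# Standalone sketch; the GPU count is an assumed example, the rest mirrors the defaults above.
num_gpus = 2                              # e.g. CUDA_VISIBLE_DEVICES='0,1'
per_batch_size = 48                       # default.per_batch_size
batch_size = per_batch_size * num_gpus    # 96 samples per optimizer step
rescale_grad = 1.0 / num_gpus             # gradients are averaged over the contexts

lr = 0.01                                 # default.lr
lr_steps = [int(x) for x in '100000,160000,220000'.split(',')]
for mbatch in range(1, 250001):           # same rule as _batch_callback
    if mbatch in lr_steps:
        lr *= 0.1
print(batch_size, rescale_grad, lr)       # 96 0.5 ~1e-05 after all three decays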

With that, everything above is annotated except for def get_symbol(args). That function builds the loss functions and is relatively involved, so the next post explains it in detail. Remember to follow and like.
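
As a small preview of that post: in the margin_softmax branch of get_symbol, (loss_m1, loss_m2, loss_m3) parameterize a single combined margin, s * (cos(m1 * theta + m2) - m3), applied to the target-class logit. With (1.0, 0.5, 0.0) this is ArcFace, with (1.0, 0.0, 0.35) it is CosFace, and with (1.0, 0.0, 0.0) it falls back to the plain normalized softmax logit s * cos(theta). A standalone NumPy sketch of that transformation (not the MXNet symbol code itself):

import numpy as np

def combined_margin_logit(cos_theta, m1, m2, m3, s=64.0):
    # Apply the combined margin from config (loss_m1/m2/m3) to the target-class cosine.
    theta = np.arccos(np.clip(cos_theta, -1.0, 1.0))
    return s * (np.cos(m1 * theta + m2) - m3)

cos_theta = 0.8  # example cosine between an embedding and its class weight
print(combined_margin_logit(cos_theta, 1.0, 0.5, 0.0))   # ArcFace
print(combined_margin_logit(cos_theta, 1.0, 0.0, 0.35))  # CosFace
print(combined_margin_logit(cos_theta, 1.0, 0.0, 0.0))   # plain normalized softmax: 64 * 0.8 = 51.2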

