神经网络初始化 (init) 介绍

server/2025/1/16 0:20:20/

文章目录

    • 引言
    • 1. 初始化的重要性
      • 1.1 打破对称性
      • 1.2 控制方差
      • 1.3 加速收敛与提高泛化能力
    • 2. 常见的初始化方法及其应用场景
    • 3. 如何设置初始化
    • 4. 基于 BERT 的文本分类如何进行初始化
      • 4.1 项目背景
      • 4.2 模型构建
      • 4.3 模型训练与评估
      • 4.4 结果分析
    • 结论
    • 参考资料

引言

深度学习的世界中,构建一个高效且性能优异的神经网络模型需要综合考虑多个因素。尽管选择合适的架构和优化算法至关重要,但权重初始化这一环节同样不容忽视。合适的初始化策略不仅能加速模型的收敛速度,提升训练稳定性,还能显著影响最终的模型性能。

1. 初始化的重要性

在深入探讨各种初始化方法之前,首先需要理解权重初始化神经网络训练中的关键作用。

1.1 打破对称性

如果所有神经元的权重初始化为相同的值,网络在训练初期将无法学习到多样化的特征。这是因为在前向传播和反向传播过程中,每个神经元都会计算出相同的输出和梯度,导致它们在训练过程中同步更新,学习到相同的内容。打破对称性通过确保每个神经元的初始权重具有一定的随机性,使得每个神经元能够独立地探索不同的特征空间,从而提高模型的表达能力。

1.2 控制方差

在深层网络中,信号通过多层传播可能会逐渐放大或缩小,导致梯度消失或爆炸。这些问题会严重影响模型的训练效果,尤其是在反向传播阶段。合理的权重初始化能够保持每一层输出的方差稳定,确保信号在整个网络中均匀传播,避免梯度消失或爆炸,从而促进稳定的训练过程。

1.3 加速收敛与提高泛化能力

正确的初始化策略能够引导损失函数的优化过程,使其更容易找到好的局部最小值或全局最小值。这不仅能加快训练速度,还能提升模型的最终性能。此外,合理的初始化还能够帮助模型更快地进入一个具备良好泛化能力的参数区域,提升其在未见数据上的表现。

2. 常见的初始化方法及其应用场景

根据不同的激活函数和网络架构,存在多种权重初始化方法。以下是几种常见且有效的初始化策略及其适用场景。

2.1 Xavier/Glorot 初始化

适用场景:适用于激活函数为 Sigmoid 或 Tanh 的神经网络

原理:Xavier 初始化通过维持每一层输入和输出信号的方差一致,防止梯度在传播过程中逐渐消失或爆炸。具体来说,它根据输入和输出神经元的数量来设定权重的初始化范围,通常采用均匀分布或正态分布。

示例

import torch.nn as nn# Xavier 初始化示例
linear = nn.Linear(in_features=256, out_features=128)
nn.init.xavier_uniform_(linear.weight)

2.2 He 初始化

适用场景:专为 ReLU 及其变体设计的神经网络

原理:He 初始化考虑到 ReLU 激活函数的非负输出特性,调整了初始化时权重的方差,使其更适合 ReLU 的特性。这样可以有效地保持信号在前向传播过程中的标准差不变,避免梯度消失。

示例

import torch.nn as nn# He 初始化示例
conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3)
nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')

2.3 正交初始化

适用场景:特别适合循环神经网络(RNNs)和非常深的前馈网络。

原理:正交初始化通过确保权重矩阵的列彼此正交且具有单位长度,有效地防止了梯度消失或爆炸的问题。对于 RNN 来说,正交矩阵能够保持序列数据的长时间依赖关系,提升模型的表现。

示例

import torch.nn as nn# 正交初始化示例
linear = nn.Linear(in_features=256, out_features=256)
nn.init.orthogonal_(linear.weight)

2.4 其他初始化方法

除了上述主要方法外,还有一些其他初始化策略,尽管在现代实践中使用较少,但在特定场景下也有其应用价值。

  • 初始化:将权重初始化为零。这种方法通常不推荐用于隐藏层,因为会导致对称性问题,但可以用于初始化偏置项。

    import torch.nn as nn# 零初始化示例
    linear = nn.Linear(in_features=256, out_features=128)
    nn.init.constant_(linear.weight, 0)
    
  • 随机初始化:使用标准正态分布或均匀分布随机初始化权重。虽然简单,但在深层网络中可能导致梯度问题。

    import torch.nn as nn# 随机初始化示例
    linear = nn.Linear(in_features=256, out_features=128)
    nn.init.normal_(linear.weight, mean=0.0, std=0.02)
    
  • 稀疏初始化初始化部分权重为非零,其他为零。适用于希望网络具有稀疏连接的场景。

    import torch.nn as nn# 稀疏初始化示例
    linear = nn.Linear(in_features=256, out_features=128)
    nn.init.sparse_(linear.weight, sparsity=0.1)
    

3. 如何设置初始化

深度学习框架如 PyTorch 中,权重初始化可以通过自定义初始化方法或直接利用内置函数来实现。以下以一个简单的卷积神经网络(CNN)为例,展示如何在构造函数中应用不同的初始化策略。

import torch.nn as nnclass SimpleCNN(nn.Module):def __init__(self, num_classes=10):super(SimpleCNN, self).__init__()# 定义网络层self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)self.bn1 = nn.BatchNorm2d(32)self.relu = nn.ReLU(inplace=True)self.pool = nn.MaxPool2d(kernel_size=2, stride=2)self.fc = nn.Linear(32 * 16 * 16, num_classes)  # 假设输入图像大小为32x32# 初始化权重self._initialize_weights()def _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):# 使用 He 初始化nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):# 批归一化层权重初始化为1,偏置初始化为0nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):# 使用 Xavier 初始化nn.init.xavier_uniform_(m.weight)if m.bias is not None:nn.init.constant_(m.bias, 0)def forward(self, x):x = self.relu(self.bn1(self.conv1(x)))x = self.pool(x)x = x.view(x.size(0), -1)x = self.fc(x)return x

解析

  • 卷积层(Conv2d):采用 He 初始化,适合 ReLU 激活函数,确保信号在前向传播中的稳定。

  • 批归一化层(BatchNorm2d):权重初始化为1,偏置初始化为0,保证初始状态下批归一化层的标准化效果。

  • 全连接层(Linear):采用 Xavier 初始化,适合 Sigmoid 或 Tanh 激活函数,保持输入和输出信号的方差一致。

4. 基于 BERT 的文本分类如何进行初始化

为了更深入地理解权重初始化的实际应用,本文将通过一个具体的文本分类任务,展示如何在预训练模型 BERT 的基础上进行初始化和微调。

4.1 项目背景

文本分类是自然语言处理中的基础任务之一,广泛应用于情感分析、垃圾邮件检测、话题分类等场景。近年来,预训练语言模型如 BERT(Bidirectional Encoder Representations from Transformers)因其强大的语言理解能力,成为文本分类任务的首选基础模型。

4.2 模型构建

以下代码展示了如何构建一个基于 BERT 的文本分类器,并在新增加的分类层上应用权重初始化

from transformers import BertModel, BertTokenizer
import torch.nn as nnclass BertForTextClassification(nn.Module):def __init__(self, num_labels=2, dropout_rate=0.3):super(BertForTextClassification, self).__init__()# 加载预训练的 BERT 模型self.bert = BertModel.from_pretrained('bert-base-uncased')self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')        # 冻结 BERT 模型的参数以加快训练速度(可选)for param in self.bert.parameters():param.requires_grad = False       # 定义分类头self.dropout = nn.Dropout(dropout_rate)self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)        # 初始化分类层的权重self._initialize_weights()def _initialize_weights(self):# 使用 Xavier 初始化分类层权重nn.init.xavier_uniform_(self.classifier.weight)if self.classifier.bias is not None:nn.init.zeros_(self.classifier.bias)def forward(self, input_ids, attention_mask=None, token_type_ids=None):# 获取 BERT 的输出outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)# 取 [CLS] 标记的输出作为分类依据cls_output = outputs.last_hidden_state[:, 0, :]cls_output = self.dropout(cls_output)logits = self.classifier(cls_output)return logits

关键步骤解析

  1. 加载预训练模型

    self.bert = BertModel.from_pretrained('bert-base-uncased')
    self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    

    通过 transformers 库加载预训练的 BERT 模型及其对应的分词器。

  2. 冻结预训练模型参数(可选)

    for param in self.bert.parameters():param.requires_grad = False
    

    冻结 BERT 模型的参数,仅训练新增加的分类层,能够显著减少训练时间和计算资源消耗,适用于数据量较小的场景。

  3. 定义分类头

    self.dropout = nn.Dropout(dropout_rate)
    self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
    

    使用 dropout 层防止过拟合,并通过全连接层将 BERT 的输出映射到分类标签空间。

  4. 初始化分类层权重

    self._initialize_weights()
    

    为新增加的分类层应用 Xavier 初始化,确保其在训练开始时具有良好的表现。

4.3 模型训练与评估

在训练过程中,合理的初始化策略能够帮助模型更快地收敛,并在有限的训练迭代中达到较好的性能。以下是一个简单的训练和评估流程示例:

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AdamW
from sklearn.metrics import accuracy_score# 假设已定义好文本数据集和数据加载器
class TextDataset(Dataset):def __init__(self, texts, labels, tokenizer, max_length=128):self.texts = textsself.labels = labelsself.tokenizer = tokenizerself.max_length = max_length    def __len__(self):return len(self.texts)    def __getitem__(self, idx):encoding = self.tokenizer.encode_plus(self.texts[idx],add_special_tokens=True,max_length=self.max_length,padding='max_length',truncation=True,return_attention_mask=True,return_tensors='pt')return {'input_ids': encoding['input_ids'].flatten(),'attention_mask': encoding['attention_mask'].flatten(),'labels': torch.tensor(self.labels[idx], dtype=torch.long)}# 初始化数据集和数据加载器
train_dataset = TextDataset(train_texts, train_labels, model.tokenizer)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataset = TextDataset(val_texts, val_labels, model.tokenizer)
val_loader = DataLoader(val_dataset, batch_size=32)# 初始化模型、优化器和损失函数
model = BertForTextClassification(num_labels=2)
optimizer = AdamW(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()# 训练循环
for epoch in range(3):  # 假设训练3个epochmodel.train()for batch in train_loader:optimizer.zero_grad()input_ids = batch['input_ids']attention_mask = batch['attention_mask']labels = batch['labels']outputs = model(input_ids=input_ids, attention_mask=attention_mask)loss = criterion(outputs, labels)loss.backward()optimizer.step()    # 评估model.eval()all_preds, all_labels = [], []with torch.no_grad():for batch in val_loader:input_ids = batch['input_ids']attention_mask = batch['attention_mask']labels = batch['labels']outputs = model(input_ids=input_ids, attention_mask=attention_mask)preds = torch.argmax(outputs, dim=1)all_preds.extend(preds.cpu().numpy())all_labels.extend(labels.cpu().numpy())acc = accuracy_score(all_labels, all_preds)print(f"Epoch {epoch + 1} - Validation Accuracy: {acc:.4f}")

关键点

  • 数据预处理:使用 BERT 的分词器将文本转换为模型可接受的输入格式,包括 input_idsattention_mask

  • 冻结与微调:根据具体需求,可以选择冻结 BERT 的部分或全部参数,仅训练新增加的层,或进行全模型微调。

  • 优化器与损失函数:使用 AdamW 优化器和交叉熵损失函数,适用于分类任务。

4.4 结果分析

通过合理的权重初始化和预训练模型的优势,基于 BERT 的文本分类器在多个标准数据集上表现出色。例如,在情感分析任务中,冻结 BERT 参数并仅训练分类层的模型,能够在较短的训练时间内达到接近全模型微调的性能,同时显著减少计算资源的消耗。

结论

权重初始化神经网络训练中扮演着至关重要的角色。合适的初始化策略不仅能够打破对称性,控制信号的方差,还能加速模型的收敛,提高泛化能力。本文系统地介绍了几种常见的初始化方法及其适用场景,并通过基于 BERT 的文本分类示例,展示了如何在实际项目中应用这些初始化策略。

参考资料

  • Glorot, X., & Bengio, Y. (2010). Understanding the difficulty of training deep feedforward neural networks.

  • He, K., Zhang, X., Ren, S., & Sun, J. (2015). Delving deep into rectifiers: Surpassing human-level performance on imagenet classification.

  • BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding

  • PyTorch 官方文档 - 权重初始化


http://www.ppmy.cn/server/158697.html

相关文章

分布式ID的实现方案

1. 什么是分布式ID ​ 对于低访问量的系统来说,无需对数据库进行分库分表,单库单表完全可以应对,但是随着系统访问量的上升,单表单库的访问压力逐渐增大,这时候就需要采用分库分表的方案,来缓解压力。 ​…

<C++学习> C++ Boost 字符串操作教程

C Boost 字符串操作教程 Boost 提供了一些实用的库来增强 C 的字符串操作能力,特别是 Boost.StringAlgo 和其他与字符串相关的工具。这些库为字符串处理提供了更高效、更简洁的方法,相比标准库功能更为丰富。 1. Boost.StringAlgo 简介 Boost.StringAl…

数据结构:栈(Stack)和队列(Queue)—面试题(二)

1. 用队列实现栈。 习题链接https://leetcode.cn/problems/implement-stack-using-queues/description/描述: 请你仅使用两个队列实现一个后入先出(LIFO)的栈,并支持普通栈的全部四种操作(push、top、pop 和 empty&a…

JVM之垃圾回收器ZGC概述以及垃圾回收器总结的详细解析

ZGC ZGC 收集器是一个可伸缩的、低延迟的垃圾收集器,基于 Region 内存布局的,不设分代,使用了读屏障、染色指针和内存多重映射等技术来实现可并发的标记压缩算法 在 CMS 和 G1 中都用到了写屏障,而 ZGC 用到了读屏障 染色指针&a…

2025年01月13日Github流行趋势

1. 项目名称:Jobs_Applier_AI_Agent 项目地址url:https://github.com/feder-cr/Jobs_Applier_AI_Agent项目语言:Python历史star数:25929今日star数:401项目维护者:surapuramakhil, feder-cr, cjbbb, sarob…

13:00面试,13:08就出来了,问的问题有点变态。。。

从小厂出来,没想到在另一家公司又寄了。 到这家公司开始上班,加班是每天必不可少的,看在钱给的比较多的份上,就不太计较了。没想到9月一纸通知,所有人不准加班,加班费不仅没有了,薪资还要降40%…

linux手动安装mysql5.7

一、下载mysql5.7 1、可以去官方网站下载mysql-5.7.24-linux-glibc2.12-x86_64.tar压缩包: https://downloads.mysql.com/archives/community/ 2、在线下载,使用wget命令,直接从官网下载到linux服务器上 wget https://downloads.mysql.co…

第432场周赛:跳过交替单元格的之字形遍历、机器人可以获得的最大金币数、图的最大边权的最小值、统计 K 次操作以内得到非递减子数组的数目

Q1、跳过交替单元格的之字形遍历 1、题目描述 给你一个 m x n 的二维数组 grid,数组由 正整数 组成。 你的任务是以 之字形 遍历 grid,同时跳过每个 交替 的单元格。 之字形遍历的定义如下: 从左上角的单元格 (0, 0) 开始。在当前行中向…