**ResNet-SE + MFCC** 训练框架,包括 **数据加载、训练流程**,以及 **混淆矩阵** 可视化示例

devtools/2025/3/14 5:53:42/

1. 依赖库安装

如果你还没安装相关库,请先执行:

pip install torch torchaudio torchvision scikit-learn matplotlib tqdm

2. 数据加载

这里假设你有一个 音频分类数据集,其文件结构如下:

dataset/
│── train/
│   ├── class_0/
│   │   ├── audio_0.wav
│   │   ├── audio_1.wav
│   ├── class_1/
│   │   ├── audio_0.wav
│   │   ├── audio_1.wav
│── val/
│   ├── class_0/
│   ├── class_1/

实现数据加载器:

python">import os
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms# 音频数据集类
class AudioDataset(Dataset):def __init__(self, root_dir, sample_rate=16000, n_mfcc=40):self.root_dir = root_dirself.sample_rate = sample_rateself.n_mfcc = n_mfccself.classes = sorted(os.listdir(root_dir))  # 目录名作为类别self.file_paths = []self.labels = []for label, class_name in enumerate(self.classes):class_dir = os.path.join(root_dir, class_name)for file_name in os.listdir(class_dir):self.file_paths.append(os.path.join(class_dir, file_name))self.labels.append(label)self.mfcc_transform = torchaudio.transforms.MFCC(sample_rate=self.sample_rate,n_mfcc=self.n_mfcc,melkwargs={"n_fft": 400, "hop_length": 160, "n_mels": 64})def __len__(self):return len(self.file_paths)def __getitem__(self, idx):file_path = self.file_paths[idx]label = self.labels[idx]waveform, sr = torchaudio.load(file_path)if sr != self.sample_rate:resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sample_rate)waveform = resampler(waveform)mfcc = self.mfcc_transform(waveform).squeeze(0)  # (n_mfcc, time)mfcc = mfcc.unsqueeze(0).repeat(3, 1, 1)  # (3, n_mfcc, time) 适配 ResNetreturn mfcc, label# 创建数据加载器
def get_dataloaders(train_dir, val_dir, batch_size=32, num_workers=2):train_dataset = AudioDataset(train_dir)val_dataset = AudioDataset(val_dir)train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)return train_loader, val_loader

3. 训练和验证

python">import torch.optim as optim
from tqdm import tqdmdef train_model(model, train_loader, val_loader, num_epochs=10, lr=0.001, device="cuda"):model = model.to(device)criterion = torch.nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=lr)for epoch in range(num_epochs):print(f"Epoch [{epoch+1}/{num_epochs}]")# 训练阶段model.train()total_loss, correct, total = 0, 0, 0for inputs, labels in tqdm(train_loader):inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()total_loss += loss.item()total += labels.size(0)correct += (outputs.argmax(dim=1) == labels).sum().item()train_acc = correct / totalprint(f"Train Loss: {total_loss/len(train_loader):.4f}, Train Acc: {train_acc:.4f}")# 验证阶段model.eval()total_loss, correct, total = 0, 0, 0with torch.no_grad():for inputs, labels in tqdm(val_loader):inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)loss = criterion(outputs, labels)total_loss += loss.item()total += labels.size(0)correct += (outputs.argmax(dim=1) == labels).sum().item()val_acc = correct / totalprint(f"Val Loss: {total_loss/len(val_loader):.4f}, Val Acc: {val_acc:.4f}")return model

4. 混淆矩阵可视化

python">import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplaydef evaluate_model(model, val_loader, device="cuda"):model.eval()all_preds = []all_labels = []with torch.no_grad():for inputs, labels in tqdm(val_loader):inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)preds = outputs.argmax(dim=1).cpu().numpy()labels = labels.cpu().numpy()all_preds.extend(preds)all_labels.extend(labels)return np.array(all_labels), np.array(all_preds)def plot_confusion_matrix(y_true, y_pred, class_names):cm = confusion_matrix(y_true, y_pred)disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)disp.plot(cmap=plt.cm.Blues, values_format="d")plt.xticks(rotation=45)plt.show()

5. 运行完整训练流程

python">if __name__ == "__main__":train_dir = "dataset/train"val_dir = "dataset/val"batch_size = 32num_epochs = 10device = "cuda" if torch.cuda.is_available() else "cpu"# 加载数据train_loader, val_loader = get_dataloaders(train_dir, val_dir, batch_size)# 初始化模型model = ResNetSE(num_classes=len(os.listdir(train_dir)))# 训练模型trained_model = train_model(model, train_loader, val_loader, num_epochs=num_epochs, device=device)# 计算混淆矩阵y_true, y_pred = evaluate_model(trained_model, val_loader, device=device)# 绘制混淆矩阵class_names = sorted(os.listdir(train_dir))plot_confusion_matrix(y_true, y_pred, class_names)

6. 总结

数据加载

  • 通过 torchaudio 提取 MFCC 特征,并适配 ResNet 输入格式。

ResNet-SE 训练

  • 训练过程包含 Adam 优化器 + 交叉熵损失,支持 GPU 训练。

混淆矩阵可视化

  • 通过 sklearn 计算混淆矩阵,并绘制 分类效果图

改进方向

🚀 模型优化

  • 使用 ResNet-34/50 替代 ResNet-18 提升表达能力。
  • 结合 SpecAugment 增强数据,提高鲁棒性。

推理加速

  • 采用 TorchScript / ONNX 进行模型导出,提高部署效率。

💡 数据增强

  • 额外使用 时域和频域增强(如 torchaudio.transforms.TimeMasking)。

这样,你就能完整训练 ResNet-SE + MFCC 进行音频分类,并分析模型性能了!💪🚀


http://www.ppmy.cn/devtools/166937.html

相关文章

星越L_副驾驶屏使用讲解

目录 1.副驾驶屏在副驾驶前方 2.上方时间、网络、蓝牙显示 3.右侧导航可以返回主页,切换模式 4.空调控制 5.应用 6.应用商城下载应用 7.音乐 8.下滑屏幕 9.长按退出 10.一键息屏 1.副驾驶屏在副驾驶前方 此屏又叫娱乐屏,可以看电视,听音乐,终身免费会员,支持车机…

软件安全测评之渗透测试流程和工具分享

在信息技术快速发展的今天,软件安全显得尤为重要。软件渗透测试作为安全测评中重要的测试方法,是一种模拟攻击手段,用于识别应用程序、网络或系统中的安全漏洞。这一过程通常由专业的安全团队执行,目的是在攻击者可以利用这些漏洞…

Hutool RedisDS:Java开发中的Redis极简集成与高阶应用

在Java开发中,Redis作为高性能内存数据库,广泛应用于缓存、分布式锁等场景。然而原生的客户端操作涉及连接管理、序列化等繁琐细节。Hutool工具包提供的RedisDS模块,通过高度封装显著简化了这一过程。本文从实战角度解析其核心特性与使用技巧…

【机器学习】迁移学习(Transfer Learning)

迁移学习(Transfer Learning)作为一种机器学习方法,主要通过将源域中学到的知识迁移到目标域,解决目标域中数据不足或标注困难的问题,尤其在无监督学习如聚类任务中具有显著优势。迁移学习的关键思想包括领域适应、知识…

【DevOps】使用Azure DevOps为Azure静态网站配置多阶段部署

【DevOps】使用Azure DevOps为Azure静态网站配置多阶段部署 推荐超级课程: 本地离线DeepSeek AI方案部署实战教程【完全版】Docker快速入门到精通Kubernetes入门到大师通关课AWS云服务快速入门实战目录 【DevOps】使用Azure DevOps为Azure静态网站配置多阶段部署示例应用程序…

手写一个Tomcat

Tomcat 是一个广泛使用的开源 Java Servlet 容器,用于运行 Java Web 应用程序。虽然 Tomcat 本身功能强大且复杂,但通过手写一个简易版的 Tomcat,我们可以更好地理解其核心工作原理。本文将带你一步步实现一个简易版的 Tomcat,并深…

如何在Django中实现批量覆盖更新的示例

在使用Django进行开发时,数据的更新是一个常见的操作。有时候,我们需要对多个记录进行批量覆盖更新,这样可以提高效率,减少数据库的交互次数。本文将详细介绍如何在Django中实现批量覆盖更新,并提供示例代码来帮助你更…

篮球游戏(200分)

(200分)113.篮球游戏 篮球游戏 问题描述 幼儿园里有一个放倒的圆桶,它是一个线性结构,允许在桶的右边将篮球放入,可以在桶的左边和右边将篮球取出。每个篮球有单独的编号。老师可以连续放入一个或多个篮球,小朋友可以在桶左边或右边将篮球取出。当桶只有一个篮球时,必须…