1. 依赖库安装
如果你还没有安装相关库，请先执行：
pip install torch torchaudio torchvision scikit-learn matplotlib tqdm
2. 数据加载
这里假设你有一个音频分类数据集，每个子目录名即为一个类别，其文件结构如下：
dataset/
├── train/
│   ├── class_0/
│   │   ├── audio_0.wav
│   │   └── audio_1.wav
│   └── class_1/
│       ├── audio_0.wav
│       └── audio_1.wav
└── val/
    ├── class_0/
    └── class_1/
实现数据加载器:
import os
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


# Audio dataset class
class AudioDataset(Dataset):
    """Audio classification dataset laid out as ``root_dir/<class_name>/<file>``.

    Each sub-directory of ``root_dir`` is one class; its index in the sorted
    list of sub-directory names is the integer label. Every regular file inside
    a class directory is loaded with ``torchaudio.load``.

    Each item is ``(features, label)`` where ``features`` is an MFCC tensor of
    shape ``(3, n_mfcc, time)``: the mono MFCC repeated three times so it can be
    fed to a 3-channel (image-style) ResNet.

    Args:
        root_dir: directory containing one sub-directory per class.
        sample_rate: target sample rate; clips at other rates are resampled.
        n_mfcc: number of MFCC coefficients per frame.
    """

    def __init__(self, root_dir, sample_rate=16000, n_mfcc=40):
        self.root_dir = root_dir
        self.sample_rate = sample_rate
        self.n_mfcc = n_mfcc
        # Directory names are the classes; stray files in root_dir (e.g.
        # .DS_Store) are skipped instead of crashing os.listdir below.
        self.classes = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        self.file_paths = []
        self.labels = []
        for label, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            # Sorted so sample indices are reproducible across file systems.
            for file_name in sorted(os.listdir(class_dir)):
                path = os.path.join(class_dir, file_name)
                if os.path.isfile(path):
                    self.file_paths.append(path)
                    self.labels.append(label)
        self.mfcc_transform = torchaudio.transforms.MFCC(
            sample_rate=self.sample_rate,
            n_mfcc=self.n_mfcc,
            melkwargs={"n_fft": 400, "hop_length": 160, "n_mels": 64},
        )
        # One Resample transform per source rate, built lazily, instead of a
        # new transform for every clip.
        self._resamplers = {}

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        label = self.labels[idx]
        waveform, sr = torchaudio.load(file_path)  # (channels, frames)
        # Down-mix multi-channel audio to mono. Otherwise the MFCC keeps a
        # channel dim > 1, squeeze(0) is a no-op and repeat(3, 1, 1) fails.
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sr != self.sample_rate:
            resampler = self._resamplers.get(sr)
            if resampler is None:
                resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sample_rate)
                self._resamplers[sr] = resampler
            waveform = resampler(waveform)
        mfcc = self.mfcc_transform(waveform).squeeze(0)  # (n_mfcc, time)
        mfcc = mfcc.unsqueeze(0).repeat(3, 1, 1)  # (3, n_mfcc, time) to fit ResNet
        return mfcc, label
def pad_collate(batch):
    """Collate ``(features, label)`` pairs whose last (time) axis may differ.

    Each feature tensor is right-padded with zeros along its last dimension to
    the longest item in the batch, then all are stacked. When every item already
    has the same length, the result equals plain stacking.

    Returns:
        (features, labels): features stacked on a new leading batch axis, and
        labels as a ``torch.long`` tensor.
    """
    features, labels = zip(*batch)
    max_len = max(f.shape[-1] for f in features)
    # Zero is a simple fill value; padding/cropping waveforms to a fixed
    # duration is an alternative if padded frames hurt accuracy.
    padded = [torch.nn.functional.pad(f, (0, max_len - f.shape[-1])) for f in features]
    return torch.stack(padded), torch.tensor(labels, dtype=torch.long)


def get_dataloaders(train_dir, val_dir, batch_size=32, num_workers=2):
    """Build the training and validation DataLoaders.

    Args:
        train_dir: root of the training split (one sub-directory per class).
        val_dir: root of the validation split.
        batch_size: samples per batch for both loaders.
        num_workers: worker processes per loader.

    Returns:
        (train_loader, val_loader); only the training loader shuffles.
    """
    # NOTE: each AudioDataset derives labels from its own sorted directory
    # names, so val_dir is assumed to contain the same class folders as train_dir.
    train_dataset = AudioDataset(train_dir)
    val_dataset = AudioDataset(val_dir)
    # Clips of different durations give MFCCs with different time lengths,
    # which the default collate_fn cannot stack, so pad per batch instead.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, collate_fn=pad_collate)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, collate_fn=pad_collate)
    return train_loader, val_loader
3. 训练和验证
import torch.optim as optim

try:
    from tqdm import tqdm
except ImportError:
    # tqdm only draws progress bars; without it, iterate the loader directly.
    def tqdm(iterable, *args, **kwargs):
        return iterable


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over *loader*: train when *optimizer* is given, else evaluate.

    Returns:
        (mean_loss, accuracy), both averaged per sample; ``(0.0, 0.0)`` when
        the loader yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for inputs, labels in tqdm(loader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch_size = labels.size(0)
            # Weight the batch-mean loss by batch size so a short final batch
            # does not skew the epoch average.
            total_loss += loss.item() * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
    if total == 0:
        return 0.0, 0.0
    return total_loss / total, correct / total


def train_model(model, train_loader, val_loader, num_epochs=10, lr=0.001, device="cuda"):
    """Train *model* with Adam + cross-entropy, validating after every epoch.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        train_loader: yields (inputs, labels) batches for training.
        val_loader: yields (inputs, labels) batches for validation.
        num_epochs: number of passes over train_loader.
        lr: Adam learning rate.
        device: device the model and batches are moved to.

    Returns:
        The trained model (moved to *device*), left in eval mode.
    """
    model = model.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for epoch in range(num_epochs):
        print(f"Epoch [{epoch+1}/{num_epochs}]")
        # Training phase
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        # Validation phase
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
    return model
4. 混淆矩阵可视化
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay


def evaluate_model(model, val_loader, device="cuda"):
    """Collect ground-truth labels and argmax predictions over *val_loader*.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        val_loader: yields (inputs, labels) batches.
        device: device the inputs are moved to before the forward pass.

    Returns:
        (y_true, y_pred) as 1-D NumPy arrays, in loader order.
    """
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader):
            # Labels are only collected on the CPU, so they need not visit the device.
            outputs = model(inputs.to(device))
            all_preds.extend(outputs.argmax(dim=1).cpu().tolist())
            all_labels.extend(labels.tolist())
    return np.array(all_labels), np.array(all_preds)


def plot_confusion_matrix(y_true, y_pred, class_names):
    """Plot a confusion matrix with one row/column per entry of *class_names*.

    Integer label ``i`` is displayed as ``class_names[i]``.

    Returns:
        The ``ConfusionMatrixDisplay``, so callers can save or restyle the figure.
    """
    # Pin the label set so the matrix is always len(class_names) square.
    # Otherwise a class absent from both y_true and y_pred is dropped and the
    # display_labels no longer match the matrix size.
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    disp.plot(cmap=plt.cm.Blues, values_format="d")
    plt.xticks(rotation=45)
    plt.show()
    return disp
5. 运行完整训练流程
if __name__ == "__main__":
    train_dir = "dataset/train"
    val_dir = "dataset/val"
    batch_size = 32
    num_epochs = 10
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load data
    train_loader, val_loader = get_dataloaders(train_dir, val_dir, batch_size)

    # Use the dataset's own class list so num_classes and the plot labels
    # follow exactly the label mapping used in training (a raw os.listdir is
    # unordered and would also count stray non-class files).
    class_names = train_loader.dataset.classes

    # Initialize the model.
    # NOTE(review): ResNetSE is not defined anywhere in this document; it must
    # be defined or imported before this script can run.
    model = ResNetSE(num_classes=len(class_names))

    # Train the model
    trained_model = train_model(model, train_loader, val_loader, num_epochs=num_epochs, device=device)

    # Collect predictions for the confusion matrix
    y_true, y_pred = evaluate_model(trained_model, val_loader, device=device)

    # Plot the confusion matrix
    plot_confusion_matrix(y_true, y_pred, class_names)
6. 总结
✅ 数据加载
- 通过 `torchaudio` 提取 MFCC 特征，并复制为 3 通道以适配 ResNet 输入格式。
✅ ResNet-SE 训练
- 训练过程采用 Adam 优化器 + 交叉熵损失，支持 GPU 训练（注意：`ResNetSE` 模型类需自行定义或导入，本文未给出其实现）。
✅ 混淆矩阵可视化
- 通过 `sklearn` 计算混淆矩阵，并绘制分类效果图。
改进方向
🚀 模型优化
- 使用 ResNet-34/50 替代 ResNet-18 提升表达能力。
- 结合 SpecAugment 增强数据,提高鲁棒性。
⚡ 推理加速
- 采用 TorchScript / ONNX 进行模型导出,提高部署效率。
💡 数据增强
- 额外使用时域和频域增强（如 `torchaudio.transforms.TimeMasking`、`torchaudio.transforms.FrequencyMasking`）。
这样,你就能完整训练 ResNet-SE + MFCC 进行音频分类,并分析模型性能了!💪🚀