PyTorch实现CNN音频降噪模型:完整代码及示例
由于缺乏具体的数据集和评估指标,以下代码仅供参考。
数据加载与处理:
import torch
import torchaudio
import os
import numpy as np
# --- Data loading / feature extraction ---
def process_data(path, target_sr=16000, n_fft=512, hop_length=256):
    """Load an audio file and convert it to a mono magnitude spectrogram.

    Args:
        path: path to an audio file readable by torchaudio.
        target_sr: sample rate to resample to (default 16 kHz).
        n_fft: STFT window size.
        hop_length: STFT hop length.

    Returns:
        Tensor of shape (1, 1, freq, time): magnitude spectrogram with a
        singleton channel axis, ready for Conv2d input.
    """
    waveform, sr = torchaudio.load(path)
    # Mix down to mono.
    waveform = waveform.mean(0, keepdim=True)
    # Resample only when the file's rate differs from the target.
    if sr != target_sr:
        waveform = torchaudio.transforms.Resample(sr, target_sr)(waveform)
    # Trim leading silence.
    waveform = torchaudio.transforms.Vad(sample_rate=target_sr)(waveform)
    # power=None keeps the complex STFT. The original used the default
    # power=2.0 (a real power spectrogram) and then called
    # torchaudio.functional.magphase on it, which expects a complex STFT
    # and was removed in torchaudio >= 0.11 — take the magnitude directly.
    spec = torchaudio.transforms.Spectrogram(
        n_fft=n_fft, hop_length=hop_length, power=None
    )(waveform)
    mag = spec.abs()
    # Add a channel dimension: (1, freq, time) -> (1, 1, freq, time).
    return mag.unsqueeze(1)
# --- Dataset over a directory of audio files ---
class AudioDataset(torch.utils.data.Dataset):
    """Dataset yielding one magnitude spectrogram per audio file.

    Each item is produced by `process_data`, i.e. a (1, 1, freq, time)
    tensor. NOTE(review): spectrogram time lengths differ between files,
    so the default DataLoader collate will fail for batch_size > 1 unless
    all clips are equal length — confirm against the data.
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir
        # Sort for deterministic ordering (os.listdir order is
        # platform-dependent) and keep regular files only, so stray
        # subdirectories don't crash __getitem__.
        self.files = sorted(
            f for f in os.listdir(root_dir)
            if os.path.isfile(os.path.join(root_dir, f))
        )

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = os.path.join(self.root_dir, self.files[idx])
        return process_data(path)
模型定义:
import torch.nn as nn
import torch.nn.functional as F
# --- Convolutional denoising network ---
class CNN(nn.Module):
    """Three-layer convolutional denoiser.

    Operates on (N, 1, freq, time) magnitude spectrograms; the output has
    exactly the same shape as the input.
    """

    def __init__(self):
        super(CNN, self).__init__()
        # 1 -> 32 -> 32 -> 1 channels; 3x3 kernels with padding=1 keep the
        # (freq, time) dimensions unchanged at every layer.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 1, kernel_size=3, padding=1)

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        # No activation on the last layer: the network emits an unbounded
        # magnitude estimate.
        return self.conv3(hidden)
训练与保存模型:
# --- One training epoch ---
def train(model, train_loader, optimizer, criterion, device):
    """Run a single training epoch and return the mean batch loss.

    The model is trained autoencoder-style: each batch serves as both the
    input and the reconstruction target.
    """
    model.train()
    total = 0.0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        loss = criterion(model(batch), batch)
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / len(train_loader)
# --- Checkpointing ---
def save_model(model, epoch, save_dir):
    """Write the model's state_dict to `<save_dir>/model_epoch<epoch>.pt`,
    creating the directory if it does not already exist."""
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(),
               os.path.join(save_dir, f'model_epoch{epoch}.pt'))
# --- Training driver (runs at module execution time, as in the original;
# the loop body below was flattened in the pasted source and is restored) ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()
train_dataset = AudioDataset(root_dir='data/train')
# NOTE(review): clips of different lengths yield spectrograms of different
# widths, so batch_size=32 assumes equal-length clips — confirm with data.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
for epoch in range(10):
    train_loss = train(model, train_loader, optimizer, criterion, device)
    print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}')
    # Checkpoint after every epoch.
    save_model(model, epoch + 1, save_dir='models')
测试与评估:
# 定义测试函数
def test(model, test_loader, device):
model.eval()
test_loss = 0
with torch.no_grad():
for batch_idx, data in enumerate(test_loader):
data = data.to(device)
output = model(data)
loss = criterion(output, data)
test_loss += loss.item()
return test_loss / len(test_loader)
# --- Load the final checkpoint for inference ---
model = CNN().to(device)
# map_location lets a checkpoint saved on GPU load on CPU-only machines;
# without it, torch.load raises on hosts with no CUDA device.
model.load_state_dict(torch.load('models/model_epoch10.pt', map_location=device))
# --- Inference: denoise one file and write the result ---
def evaluate(model, data_path, save_path, device):
    """Denoise the audio at `data_path` and save the waveform to `save_path`.

    `process_data` keeps only the STFT magnitude, so the phase is lost;
    Griffin-Lim estimates a waveform from the predicted magnitude. (The
    original passed a real magnitude tensor to `torchaudio.functional.istft`,
    which requires a complex STFT and was removed from torchaudio.)
    """
    model.eval()
    data = process_data(data_path).to(device)
    with torch.no_grad():
        mag = model(data)
    # (1, 1, F, T) -> (1, F, T); clamp: magnitudes must be non-negative.
    mag = mag.squeeze(1).clamp(min=0).cpu()
    # n_fft/hop_length must match the Spectrogram settings in process_data.
    waveform = torchaudio.transforms.GriffinLim(n_fft=512, hop_length=256)(mag)
    torchaudio.save(save_path, waveform, 16000)
# Run the denoiser on a sample file and write the result alongside it.
data_path = 'data/test/test.wav'
save_path = 'data/test/test_denoised.wav'
evaluate(model, data_path, save_path, device)
原文地址: https://www.cveoy.top/t/topic/nsMX 著作权归作者所有。请勿转载和采集!