DCCRN语音增强模型训练脚本
DCCRN语音增强模型训练脚本
该代码是一个DCCRN(Deep Complex Convolutional Recurrent Network)的训练脚本,用于去除语音信号中的噪声。
import wav_loader as loader
import net_config as net_config
import pickle
from torch.utils.data import DataLoader
import module as model_cov_bn
from si_snr import *
import train_utils
import os
########################################################################
# Change the path to the path on your computer
dns_home = r'F:\Traindata\DNS-Challenge\make_data' # dir of dns-datas
save_file = './logs' # model save
########################################################################
batch_size = 400 # calculate batch_size
load_batch = 100 # load batch_size(not calculate)
device = torch.device('cuda:0') # device
lr = 0.001 # learning_rate
# load train and test name , train:test=4:1
if os.path.exists(r'./train_test_names.data'):
train_test = pickle.load(open('./train_test_names.data', 'rb'))
else:
train_test = train_utils.get_train_test_name(dns_home)
train_noisy_names, train_clean_names, test_noisy_names, test_clean_names = \
train_utils.get_all_names(train_test, dns_home=dns_home)
train_dataset = loader.WavDataset(train_noisy_names, train_clean_names, frame_dur=37.5)
test_dataset = loader.WavDataset(test_noisy_names, test_clean_names, frame_dur=37.5)
# dataloader
train_dataloader = DataLoader(train_dataset, batch_size=load_batch, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=load_batch, shuffle=True)
dccrn = model_cov_bn.DCCRN_(
n_fft=512, hop_len=int(6.25 * 16000 / 1000), net_params=net_config.get_net_params(), batch_size=batch_size,
device=device, win_length=int((25 * 16000 / 1000))).to(device)
optimizer = torch.optim.Adam(dccrn.parameters(), lr=lr)
criterion = SiSnr()
train_utils.train(model=dccrn, optimizer=optimizer, criterion=criterion, train_iter=train_dataloader,
test_iter=test_dataloader, max_epoch=500, device=device, batch_size=batch_size, log_path=save_file,
just_test=False)
代码详细解释
-
导入必要的包和模块
wav_loader:用于加载音频数据net_config:定义了DCCRN的网络结构pickle:用于保存和读取Python对象torch.utils.data.DataLoader:用于加载数据module:包含DCCRN模型定义si_snr:用于计算语音质量指标train_utils:包含训练相关函数os:用于操作文件系统
-
定义训练参数
batch_size:训练批次大小load_batch:加载批次大小device:训练设备lr:学习率
-
加载训练和测试数据集
- 从文件中加载训练和测试样本名称
- 使用
loader.WavDataset创建数据集 - 使用
DataLoader创建数据加载器
-
实例化DCCRN模型
- 使用
model_cov_bn.DCCRN_创建DCCRN模型 - 将模型移至指定设备
- 使用
-
定义优化器和损失函数
- 使用
torch.optim.Adam创建优化器 - 使用
SiSnr创建损失函数
- 使用
-
进行训练
- 使用
train_utils.train函数进行模型训练 - 传入模型、优化器、损失函数、训练和测试数据加载器、最大训练轮数等参数
- 训练过程中会将模型保存到指定的日志路径
- 使用
总结
该代码是一个完整的DCCRN模型训练脚本,包含了数据加载、模型定义、优化器和损失函数定义、训练过程等步骤。通过该脚本,可以训练一个能够有效去除语音噪声的DCCRN模型。
原文地址: https://www.cveoy.top/t/topic/ntMJ 著作权归作者所有。请勿转载和采集!