音频增强模型训练:使用生成对抗网络 (GAN) 和条件损失
import argparse
import os

import torch
import torch.nn as nn
from scipy.io import wavfile
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_preprocess import sample_rate
from model import Generator, Discriminator
from utils import AudioDataset, emphasis
# Script entry point: trains a GAN for audio enhancement.
#
# Each epoch:
#   1. updates the discriminator on real pairs (target 1) and on generated
#      pairs (target 0) with a least-squares loss,
#   2. updates the generator with an adversarial least-squares term plus a
#      weighted L1 term against the clean signal,
#   3. runs the generator on the test set and writes the results as WAV
#      files under ``results/``,
#   4. saves both state dicts under ``epochs/``.
#
# Command-line options:
#   --batch_size  mini-batch size for both loaders (default 64)
#   --num_epochs  number of training epochs (default 86)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Audio Enhancement')
    parser.add_argument('--batch_size', default=64, type=int, help='train batch size')
    parser.add_argument('--num_epochs', default=86, type=int, help='train epochs number')

    opt = parser.parse_args()
    BATCH_SIZE = opt.batch_size
    NUM_EPOCHS = opt.num_epochs

    # Single device handle instead of repeated torch.cuda.is_available() checks.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # wavfile.write and torch.save do not create missing parent directories.
    os.makedirs('results', exist_ok=True)
    os.makedirs('epochs', exist_ok=True)

    # load data
    print('loading data...')
    train_dataset = AudioDataset(data_type='train')
    test_dataset = AudioDataset(data_type='test')
    train_data_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    test_data_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    # generate reference batch (passed to the discriminator on every call)
    ref_batch = train_dataset.reference_batch(BATCH_SIZE).to(device)

    # create D and G instances
    discriminator = Discriminator().to(device)
    generator = Generator().to(device)
    print('# generator parameters:', sum(param.numel() for param in generator.parameters()))
    print('# discriminator parameters:', sum(param.numel() for param in discriminator.parameters()))

    # optimizers
    g_optimizer = optim.RMSprop(generator.parameters(), lr=0.0001)
    d_optimizer = optim.RMSprop(discriminator.parameters(), lr=0.0001)

    for epoch in range(NUM_EPOCHS):
        train_bar = tqdm(train_data_loader)
        # NOTE(review): train_batch is presumably the real (clean, noisy) pair
        # matching the torch.cat((generated, noisy), dim=1) input used for the
        # fake pass below - confirm against AudioDataset.
        for train_batch, train_clean, train_noisy in train_bar:
            train_batch = train_batch.to(device)
            train_clean = train_clean.to(device)
            train_noisy = train_noisy.to(device)

            # latent vector - standard normal distribution
            # NOTE(review): the (1024, 8) latent shape is hard-coded and must
            # match what Generator expects - confirm against model.py.
            z = torch.randn(train_batch.size(0), 1024, 8, device=device)

            # ---- TRAIN D to recognize clean audio as clean ----
            # The forward pass must run with autograd enabled; otherwise
            # backward() cannot reach the discriminator's parameters and
            # d_optimizer.step() becomes a no-op.
            discriminator.zero_grad()
            outputs = discriminator(train_batch, ref_batch)
            clean_loss = torch.mean((outputs - 1.0) ** 2)  # L2 loss - we want them all to be 1
            clean_loss.backward()

            # ---- TRAIN D to recognize generated audio as noisy ----
            # Only D is updated here, so the generator forward pass is the one
            # place where disabling autograd is correct.
            with torch.no_grad():
                fake_for_d = generator(train_noisy, z)
            outputs = discriminator(torch.cat((fake_for_d, train_noisy), dim=1), ref_batch)
            noisy_loss = torch.mean(outputs ** 2)  # L2 loss - we want them all to be 0
            noisy_loss.backward()

            # gradients of clean_loss and noisy_loss have accumulated, so this
            # step is equivalent to optimizing d_loss = clean_loss + noisy_loss
            d_optimizer.step()

            # ---- TRAIN G so that D recognizes G(z) as real ----
            # Both forward passes need autograd so the loss can propagate
            # through D back into G. Gradients that land on D here are cleared
            # by discriminator.zero_grad() at the top of the next iteration.
            generator.zero_grad()
            generated_outputs = generator(train_noisy, z)
            gen_noise_pair = torch.cat((generated_outputs, train_noisy), dim=1)
            outputs = discriminator(gen_noise_pair, ref_batch)

            g_loss_ = 0.5 * torch.mean((outputs - 1.0) ** 2)
            # L1 loss between generated output and clean sample
            l1_dist = torch.abs(generated_outputs - train_clean)
            g_cond_loss = 100 * torch.mean(l1_dist)  # conditional loss, weight 100
            g_loss = g_loss_ + g_cond_loss

            # backprop + optimize
            g_loss.backward()
            g_optimizer.step()

            train_bar.set_description(
                'Epoch {}: d_clean_loss {:.4f}, d_noisy_loss {:.4f}, g_loss {:.4f}, g_conditional_loss {:.4f}'
                .format(epoch + 1, clean_loss.item(), noisy_loss.item(), g_loss.item(), g_cond_loss.item()))

        # ---- TEST model ----
        test_bar = tqdm(test_data_loader, desc='Test model and save generated audios')
        for test_file_names, test_noisy in test_bar:
            test_noisy = test_noisy.to(device)
            z = torch.randn(test_noisy.size(0), 1024, 8, device=device)

            # inference only: no autograd graph needed
            with torch.no_grad():
                fake_speech = generator(test_noisy, z).cpu().numpy()  # convert to numpy array
            fake_speech = emphasis(fake_speech, emph_coeff=0.95, pre=False)

            for idx, generated_sample in enumerate(fake_speech):
                file_name = os.path.join(
                    'results',
                    '{}_e{}.wav'.format(test_file_names[idx].replace('.npy', ''), epoch + 1))
                wavfile.write(file_name, sample_rate, generated_sample.T)

        # save the model parameters for each epoch
        g_path = os.path.join('epochs', 'generator-{}.pkl'.format(epoch + 1))
        d_path = os.path.join('epochs', 'discriminator-{}.pkl'.format(epoch + 1))
        torch.save(generator.state_dict(), g_path)
        torch.save(discriminator.state_dict(), d_path)
上述代码中 `g_cond_loss = 100 * torch.mean(l1_dist)` 是什么意思?将 100 这个数字调为 10 对生成器的影响是什么?
g_cond_loss 是生成器的条件损失,用来衡量生成的音频与原始干净音频之间的差异。l1_dist 是生成音频与原始干净音频逐采样点之差的绝对值,torch.mean 再对其取平均,得到平均绝对误差(即 L1 损失)。将该均值乘以 100 是为了放大它在总损失 g_loss = g_loss_ + g_cond_loss 中的权重,使得生成器更加关注生成音频与原始干净音频之间的相似性。
将 100 调整为 10 会减小条件损失在总损失中的权重,使得生成器相对更加关注对抗损失(即骗过判别器)。这可能会导致生成的音频与原始干净音频的偏差变大,例如残留更多噪声或引入失真,而不是贴近原始的干净音频。因此,条件损失的权重需要根据具体的任务和数据集进行调整。
原文地址: https://www.cveoy.top/t/topic/nvyZ 著作权归作者所有。请勿转载和采集!