import argparse
import os

import torch
import torch.nn as nn
from scipy.io import wavfile
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_preprocess import sample_rate
from model import Generator, Discriminator
from utils import AudioDataset, emphasis

# Shape (channels, length) of the latent tensor fed to the generator alongside
# the noisy input; one such tensor is drawn per sample in the batch.
LATENT_SHAPE = (1024, 8)
# Weight applied to the L1 (conditional) term of the generator loss.
L1_WEIGHT = 100


def sample_latent(batch_size, device=None):
    """Draw a standard-normal latent tensor of shape (batch_size, 1024, 8).

    Replaces the deprecated ``nn.init.normal(torch.Tensor(...))`` idiom and
    allocates the tensor directly on ``device``.
    """
    return torch.randn(batch_size, *LATENT_SHAPE, device=device)


def train_discriminator_step(discriminator, generator, d_optimizer,
                             train_batch, train_noisy, z, ref_batch):
    """Run one discriminator update and return ``(clean_loss, noisy_loss)`` as floats.

    The discriminator is pushed towards 1 on ``train_batch`` and towards 0 on
    the generator output concatenated with ``train_noisy`` (least-squares
    losses). Both discriminator forward passes must run with autograd enabled;
    wrapping them in ``torch.no_grad()`` would leave the parameters without
    gradients and make ``d_optimizer.step()`` a no-op.
    """
    discriminator.zero_grad()

    # Real pass: L2 loss, we want every output to be 1.
    outputs = discriminator(train_batch, ref_batch)
    clean_loss = torch.mean((outputs - 1.0) ** 2)
    clean_loss.backward()

    # Fake pass: the generator is not being trained here, so its forward pass
    # is the one place in this step where no_grad is appropriate.
    with torch.no_grad():
        generated_outputs = generator(train_noisy, z)
    outputs = discriminator(torch.cat((generated_outputs, train_noisy), dim=1), ref_batch)
    noisy_loss = torch.mean(outputs ** 2)  # L2 loss, we want every output to be 0
    noisy_loss.backward()

    # Gradients of both passes have accumulated (d_loss = clean_loss + noisy_loss).
    d_optimizer.step()
    return clean_loss.item(), noisy_loss.item()


def train_generator_step(discriminator, generator, g_optimizer,
                         train_clean, train_noisy, z, ref_batch):
    """Run one generator update and return ``(g_loss, g_cond_loss)`` as floats.

    ``g_loss`` is the adversarial term ``0.5 * mean((D(G) - 1)^2)`` plus the
    conditional term ``L1_WEIGHT * mean(|G - clean|)``. The whole forward pass
    runs with autograd enabled so the gradient flows through the
    discriminator back into the generator parameters.
    """
    generator.zero_grad()
    generated_outputs = generator(train_noisy, z)
    gen_noise_pair = torch.cat((generated_outputs, train_noisy), dim=1)
    outputs = discriminator(gen_noise_pair, ref_batch)

    g_adv_loss = 0.5 * torch.mean((outputs - 1.0) ** 2)
    # L1 distance between the generated output and the clean sample.
    g_cond_loss = L1_WEIGHT * torch.mean(torch.abs(generated_outputs - train_clean))
    g_loss = g_adv_loss + g_cond_loss

    # Backprop + optimize. Gradients that land on the discriminator here are
    # discarded by discriminator.zero_grad() at the start of the next D step.
    g_loss.backward()
    g_optimizer.step()
    return g_loss.item(), g_cond_loss.item()


def main():
    """Train the generator/discriminator pair, writing test audio and checkpoints each epoch."""
    parser = argparse.ArgumentParser(description='Train Audio Enhancement')
    parser.add_argument('--batch_size', default=32, type=int, help='train batch size')
    parser.add_argument('--num_epochs', default=50, type=int, help='train epochs number')
    opt = parser.parse_args()
    batch_size = opt.batch_size
    num_epochs = opt.num_epochs

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # load data
    print('loading data...')
    train_dataset = AudioDataset(data_type='train')
    test_dataset = AudioDataset(data_type='test')
    train_data_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    test_data_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    # generate reference batch
    ref_batch = train_dataset.reference_batch(batch_size).to(device)

    # create D and G instances
    discriminator = Discriminator().to(device)
    generator = Generator().to(device)
    print('# generator parameters:', sum(param.numel() for param in generator.parameters()))
    print('# discriminator parameters:', sum(param.numel() for param in discriminator.parameters()))
    # optimizers
    g_optimizer = optim.RMSprop(generator.parameters(), lr=0.0001)
    d_optimizer = optim.RMSprop(discriminator.parameters(), lr=0.0001)

    # Make sure the output directories exist before the first write.
    os.makedirs('results', exist_ok=True)
    os.makedirs('epochs', exist_ok=True)

    for epoch in range(num_epochs):
        train_bar = tqdm(train_data_loader)
        for train_batch, train_clean, train_noisy in train_bar:
            train_batch = train_batch.to(device)
            train_clean = train_clean.to(device)
            train_noisy = train_noisy.to(device)
            # latent vector - normal distribution
            z = sample_latent(train_batch.size(0), device)

            clean_loss, noisy_loss = train_discriminator_step(
                discriminator, generator, d_optimizer, train_batch, train_noisy, z, ref_batch)
            g_loss, g_cond_loss = train_generator_step(
                discriminator, generator, g_optimizer, train_clean, train_noisy, z, ref_batch)

            train_bar.set_description(
                'Epoch {}: d_clean_loss {:.4f}, d_noisy_loss {:.4f}, g_loss {:.4f}, g_conditional_loss {:.4f}'
                .format(epoch + 1, clean_loss, noisy_loss, g_loss, g_cond_loss))

        # TEST model
        test_bar = tqdm(test_data_loader, desc='Test model and save generated audios')
        for test_file_names, test_noisy in test_bar:
            test_noisy = test_noisy.to(device)
            z = sample_latent(test_noisy.size(0), device)
            # Inference only: no autograd graph is needed here.
            with torch.no_grad():
                fake_speech = generator(test_noisy, z).cpu().numpy()
            fake_speech = emphasis(fake_speech, emph_coeff=0.95, pre=False)

            for idx in range(fake_speech.shape[0]):
                generated_sample = fake_speech[idx]
                file_name = os.path.join(
                    'results',
                    '{}_e{}.wav'.format(test_file_names[idx].replace('.npy', ''), epoch + 1))
                wavfile.write(file_name, sample_rate, generated_sample.T)

        # save the model parameters for each epoch
        g_path = os.path.join('epochs', 'generator-{}.pkl'.format(epoch + 1))
        d_path = os.path.join('epochs', 'discriminator-{}.pkl'.format(epoch + 1))
        torch.save(generator.state_dict(), g_path)
        torch.save(discriminator.state_dict(), d_path)


if __name__ == '__main__':
    main()

运行出现d_clean_loss 0.2637, d_noisy_loss 0.2530, g_loss 9.1083, g_conditional_loss 8.9884:分别代表什么意思,应该为多少比较合理,应该随着迭代越来越小吗

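内容:d_clean_loss 是判别器在真实训练样本(train_batch,包含真实干净音频)上的损失,即 mean((D(x) - 1)^2),表示判别器将真实干净音频判定为“真”(输出接近 1)的能力。d_noisy_loss 是判别器在“生成音频 + 带噪音频”拼接样本上的损失,即 mean(D(G)^2),表示判别器将生成器输出判定为“假”(输出接近 0)的能力。g_loss 是生成器的总损失,等于对抗损失 0.5 * mean((D(G) - 1)^2) 与条件损失之和。g_conditional_loss 是生成器的条件损失,即生成音频与真实干净音频之间 L1 距离的均值乘以 100,表示生成音频与干净音频之间的差距。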

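其中 g_conditional_loss(以及由它主导的 g_loss)通常应该随着训练迭代次数的增加而逐渐降低。判别器的两项损失则与生成器的对抗损失相互博弈,不一定单调下降:例如当判别器输出约为 0.5(无法区分真假)时,d_clean_loss 与 d_noisy_loss 都约为 0.25。此外,损失值的大小不仅受到训练次数的影响,还受到数据集、模型结构等因素的影响。因此,最终的合理损失值需要根据具体情况进行评估。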

Audio Enhancement Training with Generative Adversarial Networks (GANs)

原文地址: https://www.cveoy.top/t/topic/np4X 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录