"""Train a conditional GAN for audio enhancement.

Each epoch updates the discriminator and the generator on every training
batch. It then writes enhanced WAV files for the test set to ``results/`` and
saves both networks' state dicts to ``epochs/``.
"""
import argparse
import os

import torch
import torch.nn as nn
from scipy.io import wavfile
from torch import optim
from torch.autograd import Variable  # legacy; Variable was merged into Tensor in PyTorch 0.4
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_preprocess import sample_rate
from model import Generator, Discriminator
from utils import AudioDataset, emphasis

# Per-sample shape of the latent noise z passed to Generator alongside the
# noisy input. Hard-coded in the original script; presumably it must match
# model.Generator's bottleneck -- confirm before changing.
LATENT_SHAPE = (1024, 8)
# Weight of the L1 (conditional) term relative to the adversarial term.
L1_WEIGHT = 100
RESULTS_DIR = 'results'
EPOCHS_DIR = 'epochs'


def _parse_args():
    """Parse the command-line training options (batch size, epoch count)."""
    parser = argparse.ArgumentParser(description='Train Audio Enhancement')
    parser.add_argument('--batch_size', default=64, type=int, help='train batch size')
    parser.add_argument('--num_epochs', default=86, type=int, help='train epochs number')
    return parser.parse_args()


def _sample_latent(batch_size, device):
    """Return standard-normal noise of shape ``(batch_size, *LATENT_SHAPE)`` on ``device``."""
    return torch.randn(batch_size, *LATENT_SHAPE, device=device)


def _train_step(generator, discriminator, g_optimizer, d_optimizer,
                train_batch, train_clean, train_noisy, ref_batch, device):
    """Run one discriminator update followed by one generator update.

    Args:
        generator, discriminator: the two networks (already on ``device``).
        g_optimizer, d_optimizer: optimizers over each network's parameters.
        train_batch: real input for D (presumably clean/noisy concatenated
            along dim 1, matching how the fake pair is built below -- confirm
            against AudioDataset).
        train_clean: clean target audio for the L1 term.
        train_noisy: noisy audio fed to the generator.
        ref_batch: fixed reference batch passed to every D forward pass.
        device: torch.device to run on.

    Returns:
        Python floats ``(d_clean_loss, d_noisy_loss, g_loss, g_cond_loss)``.
    """
    train_batch = train_batch.to(device)
    train_clean = train_clean.to(device)
    train_noisy = train_noisy.to(device)
    z = _sample_latent(train_batch.size(0), device)

    # --- Discriminator: least-squares loss, real pairs -> 1, generated -> 0 ---
    discriminator.zero_grad()
    outputs = discriminator(train_batch, ref_batch)
    clean_loss = torch.mean((outputs - 1.0) ** 2)
    clean_loss.backward()

    # detach(): the D update must not backpropagate into the generator.
    generated_outputs = generator(train_noisy, z).detach()
    outputs = discriminator(torch.cat((generated_outputs, train_noisy), dim=1), ref_batch)
    noisy_loss = torch.mean(outputs ** 2)
    noisy_loss.backward()
    # Gradients of clean_loss + noisy_loss have been accumulated above.
    d_optimizer.step()

    # --- Generator: push D's output on generated pairs toward 1, plus L1 ---
    generator.zero_grad()
    generated_outputs = generator(train_noisy, z)
    gen_noise_pair = torch.cat((generated_outputs, train_noisy), dim=1)
    outputs = discriminator(gen_noise_pair, ref_batch)
    g_adv_loss = 0.5 * torch.mean((outputs - 1.0) ** 2)
    g_cond_loss = L1_WEIGHT * torch.mean(torch.abs(generated_outputs - train_clean))
    g_loss = g_adv_loss + g_cond_loss
    # This backward also fills .grad on D's parameters. Only g_optimizer steps
    # here, and D's grads are cleared by discriminator.zero_grad() next step.
    g_loss.backward()
    g_optimizer.step()

    return clean_loss.item(), noisy_loss.item(), g_loss.item(), g_cond_loss.item()


def _generate_test_audio(generator, test_data_loader, device, epoch):
    """Run the generator over the test set and write one WAV per file to RESULTS_DIR."""
    test_bar = tqdm(test_data_loader, desc='Test model and save generated audios')
    with torch.no_grad():  # inference only: no autograd graph needed
        for test_file_names, test_noisy in test_bar:
            test_noisy = test_noisy.to(device)
            z = _sample_latent(test_noisy.size(0), device)
            fake_speech = generator(test_noisy, z).cpu().numpy()
            # pre=False presumably undoes the pre-emphasis applied in
            # preprocessing (see utils.emphasis) -- confirm.
            fake_speech = emphasis(fake_speech, emph_coeff=0.95, pre=False)

            for file_name, generated_sample in zip(test_file_names, fake_speech):
                out_path = os.path.join(
                    RESULTS_DIR,
                    '{}_e{}.wav'.format(file_name.replace('.npy', ''), epoch + 1))
                wavfile.write(out_path, sample_rate, generated_sample.T)


def _save_checkpoint(generator, discriminator, epoch):
    """Save both networks' state dicts for this (0-based) epoch under EPOCHS_DIR."""
    g_path = os.path.join(EPOCHS_DIR, 'generator-{}.pkl'.format(epoch + 1))
    d_path = os.path.join(EPOCHS_DIR, 'discriminator-{}.pkl'.format(epoch + 1))
    torch.save(generator.state_dict(), g_path)
    torch.save(discriminator.state_dict(), d_path)


def main():
    """Entry point: load data, build the models, then train/test/checkpoint per epoch."""
    opt = _parse_args()
    batch_size = opt.batch_size
    num_epochs = opt.num_epochs
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # wavfile.write and torch.save do not create missing parent directories.
    os.makedirs(RESULTS_DIR, exist_ok=True)
    os.makedirs(EPOCHS_DIR, exist_ok=True)

    # load data
    print('loading data...')
    train_dataset = AudioDataset(data_type='train')
    test_dataset = AudioDataset(data_type='test')
    train_data_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                                   shuffle=True, num_workers=0)
    test_data_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
                                  shuffle=False, num_workers=0)
    # Fixed reference batch, passed to every discriminator forward pass.
    ref_batch = train_dataset.reference_batch(batch_size).to(device)

    # create D and G instances
    discriminator = Discriminator().to(device)
    generator = Generator().to(device)
    print('# generator parameters:', sum(param.numel() for param in generator.parameters()))
    print('# discriminator parameters:', sum(param.numel() for param in discriminator.parameters()))

    g_optimizer = optim.RMSprop(generator.parameters(), lr=0.0001)
    d_optimizer = optim.RMSprop(discriminator.parameters(), lr=0.0001)

    for epoch in range(num_epochs):
        train_bar = tqdm(train_data_loader)
        for train_batch, train_clean, train_noisy in train_bar:
            clean_loss, noisy_loss, g_loss, g_cond_loss = _train_step(
                generator, discriminator, g_optimizer, d_optimizer,
                train_batch, train_clean, train_noisy, ref_batch, device)
            train_bar.set_description(
                'Epoch {}: d_clean_loss {:.4f}, d_noisy_loss {:.4f}, g_loss {:.4f}, g_conditional_loss {:.4f}'
                .format(epoch + 1, clean_loss, noisy_loss, g_loss, g_cond_loss))

        _generate_test_audio(generator, test_data_loader, device, epoch)
        # save the model parameters for each epoch
        _save_checkpoint(generator, discriminator, epoch)


if __name__ == '__main__':
    main()

Code Explanation

Import Necessary Libraries

The code starts by importing the necessary libraries for training a GAN model for audio enhancement:

  • argparse for parsing command-line arguments.
  • os for file and directory operations.
  • torch, torch.nn, optim, torch.autograd, and torch.utils.data for PyTorch functionality (tensors, neural-network modules, optimizers, the legacy Variable wrapper, and batched data loading).
  • scipy.io.wavfile for writing audio files in WAV format.
  • tqdm for displaying progress bars.
  • Custom modules: data_preprocess, model, and utils.

Define Training Parameters

  • batch_size: The number of audio samples processed in each training iteration (default 64).
  • num_epochs: The total number of training epochs (default 86).

Load Data

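  • The code creates AudioDataset objects for the training and test data. Each training batch yields three tensors (unpacked as train_batch, train_clean and train_noisy), which are presumably the clean/noisy pair fed to the discriminator, the clean audio, and the noisy audio. Each test batch yields file names and the corresponding noisy audio.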
  • DataLoader objects are created to iterate over the datasets in batches.

Create Model Instances

  • Discriminator and Generator instances are created. These are the core components of the GAN model.
  • The models are moved to the GPU if available.

Define Optimizers

  • RMSprop optimizers (learning rate 0.0001) are defined for the generator and the discriminator; each one updates its own network's parameters during training.

Training Loop

The training loop iterates over the specified number of epochs.

  • For Each Epoch:
    • The code iterates over the training data in batches using train_data_loader.
    • For Each Batch:
      • Train Discriminator:
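        • The discriminator is trained with a least-squares loss to recognize clean audio as real and generated audio as fake: its outputs on real (clean) pairs are pushed toward 1 and its outputs on generated pairs toward 0.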
        • The discriminator's parameters are updated using d_optimizer.step().
      • Train Generator:
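        • The generator is trained to produce audio that the discriminator considers real (i.e., clean). This adversarial term is combined with a weighted L1 loss toward the clean target (see Conditional Loss below).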
        • The generator's parameters are updated using g_optimizer.step().

Testing Loop

  • After each epoch, the generator is run on the test dataset using test_data_loader. No quantitative metric is computed.
  • The generated audio samples are written as WAV files to the results directory, named <test file name>_e<epoch>.wav.

Save Model Parameters

  • The parameters (state dicts) of the generator and discriminator are saved at the end of each epoch to epochs/generator-<epoch>.pkl and epochs/discriminator-<epoch>.pkl.

Conditional Loss

  • The line g_cond_loss = 100 * torch.mean(l1_dist) calculates the conditional loss for the generator.
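  • The conditional loss is the mean absolute (L1) difference between the generated audio $\hat{x}$ and the clean audio $x$, multiplied by a weight factor of 100. Combined with the adversarial term, the full generator loss is $L_G = \tfrac{1}{2}\,\operatorname{mean}\big((D(\hat{x}, \tilde{x}) - 1)^2\big) + 100\,\operatorname{mean}\big(\lvert \hat{x} - x \rvert\big)$. Here $\tilde{x}$ is the noisy input, and $D(\hat{x}, \tilde{x})$ is the discriminator output on the concatenated generated/noisy pair.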
  • This loss encourages the generator to produce audio that is similar to the clean audio. This is achieved by minimizing the difference between the two audio signals.

Key Points

  • The code implements a GAN model for audio enhancement, using a generator and a discriminator.
  • The discriminator is trained to distinguish between real and generated audio, while the generator learns to produce audio that fools the discriminator.
  • The use of conditional loss helps to ensure that the generated audio is close to the clean audio.
  • After each epoch the generator enhances the test set and the results are saved to disk for inspection, so progress can be monitored by listening to them (no numeric score is computed).
  • The model parameters are saved at the end of each epoch to allow for later restoration.

Potential Enhancements

  • Experiment with different GAN architectures and loss functions.
  • Implement techniques for improving the stability of GAN training, such as gradient penalty or spectral normalization.
  • Investigate the use of transfer learning to leverage pre-trained models for audio processing tasks.
Audio Enhancement with Generative Adversarial Networks (GANs)

原文地址: https://www.cveoy.top/t/topic/nvyl 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录