This Python code implements an audio enhancement model using a Generative Adversarial Network (GAN). The model is designed to learn how to remove noise from audio recordings. It consists of two main components:

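  1. Generator (G): This network takes noisy audio, together with a random latent vector z, as input and generates an enhanced (ideally clean) audio signal.
  2. Discriminator (D): This network distinguishes between real (clean) audio and generated (fake) audio. It is conditioned on the noisy input: the script feeds it pairs, such as the generated output concatenated with the noisy input along the channel dimension. Training pushes its output toward 1 for real audio and toward 0 for generated audio.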
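
To make the pairing concrete, here is a minimal NumPy sketch of how the real and generated inputs to the discriminator could be assembled. It assumes signals shaped (batch, channels, samples) and that a real pair stacks the clean and noisy signals the same way the script stacks the generated output with `train_noisy`. The sizes (a batch of 4, 16384 samples per clip) and the random signals are made up for illustration, and `np.concatenate(..., axis=1)` stands in for the script's `torch.cat(..., dim=1)`.

```python
import numpy as np

# Hypothetical sizes: 4 clips, 1 channel each, 16384 samples per clip.
batch, samples = 4, 16384
rng = np.random.default_rng(0)
clean = rng.standard_normal((batch, 1, samples)).astype(np.float32)
noisy = clean + 0.1 * rng.standard_normal((batch, 1, samples)).astype(np.float32)
enhanced = noisy.copy()  # stand-in for the generator output G(noisy, z)

# Real pair: (clean, noisy) stacked on the channel axis -> (batch, 2, samples).
real_pair = np.concatenate((clean, noisy), axis=1)
# Generated pair: (enhanced, noisy) with the same layout.
fake_pair = np.concatenate((enhanced, noisy), axis=1)

print(real_pair.shape, fake_pair.shape)  # (4, 2, 16384) (4, 2, 16384)
```

Because the noisy channel is identical in both pairs, the discriminator can only tell them apart through the first channel, judged in the context of the noisy recording.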

The model is trained adversarially, alternating between the two networks on each batch. First the discriminator is updated to tell clean audio apart from generated audio; then the generator is updated so that the discriminator scores its output as real, while an additional L1 term keeps the output close to the clean recording. Ideally, this back-and-forth pushes both networks to improve.
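
One alternating update can be sketched in a self-contained way with hypothetical toy networks (`G` and `D` as single `nn.Linear` layers on 4-sample vectors) in place of the project's `Generator` and `Discriminator`, and without the latent vector z or the reference batch. It keeps the script's least-squares targets, its 0.5 and 100 weights, and RMSprop, so treat it as a sketch of the update order, not of the actual model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy stand-ins: G maps a 4-sample noisy vector to a 4-sample estimate;
# D scores a (signal, noisy) pair, so its input has 8 features.
G = nn.Linear(4, 4)
D = nn.Linear(8, 1)
g_opt = torch.optim.RMSprop(G.parameters(), lr=1e-4)
d_opt = torch.optim.RMSprop(D.parameters(), lr=1e-4)

clean = torch.randn(16, 4)
noisy = clean + 0.1 * torch.randn(16, 4)


def train_step():
    # 1) Discriminator step: real pairs -> 1, generated pairs -> 0 (least squares).
    d_opt.zero_grad()
    real_scores = D(torch.cat((clean, noisy), dim=1))
    fake_scores = D(torch.cat((G(noisy).detach(), noisy), dim=1))  # detach: no grads into G
    d_loss = torch.mean((real_scores - 1.0) ** 2) + torch.mean(fake_scores ** 2)
    d_loss.backward()
    d_opt.step()

    # 2) Generator step: D should score generated pairs as real, plus a weighted L1 term.
    g_opt.zero_grad()
    enhanced = G(noisy)
    g_adv = 0.5 * torch.mean((D(torch.cat((enhanced, noisy), dim=1)) - 1.0) ** 2)
    g_loss = g_adv + 100 * torch.mean(torch.abs(enhanced - clean))
    g_loss.backward()  # also fills D's .grad, which d_opt.zero_grad() clears next step
    g_opt.step()       # only G's parameters are updated here
    return d_loss.item(), g_loss.item()


d_before = D.weight.detach().clone()
g_before = G.weight.detach().clone()
d_loss_val, g_loss_val = train_step()
print(f"d_loss={d_loss_val:.4f} g_loss={g_loss_val:.4f}")
```

Both networks' weights change after a single call, which is a quick sanity check that gradients actually reach each of them.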

Understanding the Metrics

  • d_clean_loss: The discriminator's error on real audio, computed as the mean squared difference between its scores for the real training pairs (train_batch) and the target 1. The discriminator is updated to reduce it; a low value means it confidently recognizes clean audio as real.

  • d_noisy_loss: The discriminator's error on generated (enhanced) audio, computed as the mean of its squared scores for generated pairs, since the target there is 0. A low value means it confidently flags the generator's output as fake.

  • g_loss: The generator's total loss. It is the sum of an adversarial term, 0.5 × mean((D(generated pair) - 1)^2), which is small when the discriminator scores the generated audio as real, and g_conditional_loss. It measures how well the generator fools the discriminator while staying close to the clean audio; lower is better.

  • g_conditional_loss: 100 times the mean absolute (L1) difference between the generated audio and the clean audio, sample by sample. It quantifies how closely the output matches the clean reference, so lower is better. Its absolute scale depends on the signal levels in the dataset, so track how it changes across epochs rather than comparing it with a fixed target.
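
The four quantities can be recomputed by hand with a small NumPy sketch that uses the same formulas as the training loop. The function name `segan_metrics` and the toy score and signal values are hypothetical. For simplicity it uses a single set of discriminator scores for generated pairs, whereas the script computes the generator's adversarial term with the discriminator after its update.

```python
import numpy as np


def segan_metrics(d_real, d_fake, enhanced, clean, l1_weight=100.0):
    """Recompute the four logged quantities with the same formulas as the training loop.

    d_real:   discriminator scores for real pairs (target 1)
    d_fake:   discriminator scores for generated pairs (target 0 for D, 1 for G)
    enhanced: generator output; clean: clean reference of the same shape
    """
    d_clean_loss = np.mean((d_real - 1.0) ** 2)         # D: push real scores toward 1
    d_noisy_loss = np.mean(d_fake ** 2)                 # D: push fake scores toward 0
    g_adversarial = 0.5 * np.mean((d_fake - 1.0) ** 2)  # G: wants fake scores near 1
    g_conditional_loss = l1_weight * np.mean(np.abs(enhanced - clean))  # weighted L1
    g_loss = g_adversarial + g_conditional_loss         # logged g_loss includes the L1 term
    return {
        "d_clean_loss": float(d_clean_loss),
        "d_noisy_loss": float(d_noisy_loss),
        "g_loss": float(g_loss),
        "g_conditional_loss": float(g_conditional_loss),
    }


# Made-up toy values, chosen so the arithmetic is easy to check by hand.
metrics = segan_metrics(
    d_real=np.array([1.0, 0.5]),
    d_fake=np.array([0.0, 0.5]),
    enhanced=np.array([0.5, -0.5, 0.25, -0.25]),
    clean=np.zeros(4),
)
print(metrics)
# {'d_clean_loss': 0.125, 'd_noisy_loss': 0.125, 'g_loss': 37.8125, 'g_conditional_loss': 37.5}
```

Subtracting g_conditional_loss from g_loss recovers the adversarial term (0.3125 here), which is a convenient way to read the two parts of the generator's loss separately from the logged values.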

Code Breakdown

import argparse
import os

import torch
import torch.nn as nn
from scipy.io import wavfile
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_preprocess import sample_rate
from model import Generator, Discriminator
from utils import AudioDataset, emphasis

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Audio Enhancement')
    parser.add_argument('--batch_size', default=32, type=int, help='train batch size')
    parser.add_argument('--num_epochs', default=50, type=int, help='train epochs number')

    opt = parser.parse_args()
    BATCH_SIZE = opt.batch_size
    NUM_EPOCHS = opt.num_epochs

    # load data
    print('loading data...')
    train_dataset = AudioDataset(data_type='train')
    test_dataset = AudioDataset(data_type='test')
    train_data_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    test_data_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    # generate reference batch
    ref_batch = train_dataset.reference_batch(BATCH_SIZE)

    # create D and G instances
    discriminator = Discriminator()
    generator = Generator()
    if torch.cuda.is_available():
        discriminator.cuda()
        generator.cuda()
        ref_batch = ref_batch.cuda()
    ref_batch = Variable(ref_batch)
    print('# generator parameters:', sum(param.numel() for param in generator.parameters()))
    print('# discriminator parameters:', sum(param.numel() for param in discriminator.parameters()))
    # optimizers
    g_optimizer = optim.RMSprop(generator.parameters(), lr=0.0001)
    d_optimizer = optim.RMSprop(discriminator.parameters(), lr=0.0001)

    for epoch in range(NUM_EPOCHS):
        train_bar = tqdm(train_data_loader)
        for train_batch, train_clean, train_noisy in train_bar:

            # latent vector - standard normal distribution
            z = torch.randn(train_batch.size(0), 1024, 8)
            if torch.cuda.is_available():
                train_batch, train_clean, train_noisy = train_batch.cuda(), train_clean.cuda(), train_noisy.cuda()
                z = z.cuda()

            # TRAIN D to recognize clean audio as clean
            # training batch pass (no torch.no_grad() here: D needs gradients to learn)
            discriminator.zero_grad()
            outputs = discriminator(train_batch, ref_batch)
            clean_loss = torch.mean((outputs - 1.0) ** 2)  # L2 loss - we want them all to be 1
            clean_loss.backward()

            # TRAIN D to recognize generated audio as noisy
            generated_outputs = generator(train_noisy, z)
            # detach() so this D update does not back-propagate into G
            outputs = discriminator(torch.cat((generated_outputs.detach(), train_noisy), dim=1), ref_batch)
            noisy_loss = torch.mean(outputs ** 2)  # L2 loss - we want them all to be 0
            noisy_loss.backward()

            # gradients of both terms have accumulated, i.e. d_loss = clean_loss + noisy_loss
            d_optimizer.step()  # update D parameters

            # TRAIN G so that D recognizes G(z) as real
            generator.zero_grad()
            generated_outputs = generator(train_noisy, z)
            gen_noise_pair = torch.cat((generated_outputs, train_noisy), dim=1)
            outputs = discriminator(gen_noise_pair, ref_batch)  # gradients must flow back through D into G

            g_loss_ = 0.5 * torch.mean((outputs - 1.0) ** 2)  # adversarial (least-squares) term
            # L1 loss between generated output and clean sample
            l1_dist = torch.abs(generated_outputs - train_clean)
            g_cond_loss = 100 * torch.mean(l1_dist)  # conditional loss
            g_loss = g_loss_ + g_cond_loss

            # backprop + optimize; only G is stepped here, and the D gradients this
            # produces are cleared by discriminator.zero_grad() on the next batch
            g_loss.backward()
            g_optimizer.step()

            train_bar.set_description(
                'Epoch {}: d_clean_loss {:.4f}, d_noisy_loss {:.4f}, g_loss {:.4f}, g_conditional_loss {:.4f}'
                .format(epoch + 1, clean_loss.item(), noisy_loss.item(), g_loss.item(), g_cond_loss.item()))

        # TEST model
        os.makedirs('results', exist_ok=True)  # wavfile.write() does not create missing directories
        test_bar = tqdm(test_data_loader, desc='Test model and save generated audios')
        with torch.no_grad():  # inference only: skip building the autograd graph
            for test_file_names, test_noisy in test_bar:
                z = torch.randn(test_noisy.size(0), 1024, 8)  # latent vector - standard normal
                if torch.cuda.is_available():
                    test_noisy, z = test_noisy.cuda(), z.cuda()
                fake_speech = generator(test_noisy, z).cpu().numpy()  # convert to numpy array
                fake_speech = emphasis(fake_speech, emph_coeff=0.95, pre=False)  # pre=False: de-emphasis direction (assumed from the flag name)

                for idx in range(fake_speech.shape[0]):
                    generated_sample = fake_speech[idx]
                    file_name = os.path.join('results',
                                             '{}_e{}.wav'.format(test_file_names[idx].replace('.npy', ''), epoch + 1))
                    wavfile.write(file_name, sample_rate, generated_sample.T)

        # save the model parameters for each epoch
        os.makedirs('epochs', exist_ok=True)  # torch.save() does not create missing directories
        g_path = os.path.join('epochs', 'generator-{}.pkl'.format(epoch + 1))
        d_path = os.path.join('epochs', 'discriminator-{}.pkl'.format(epoch + 1))
        torch.save(generator.state_dict(), g_path)
        torch.save(discriminator.state_dict(), d_path)
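
A common pitfall in a loop like this is running a forward pass inside torch.no_grad() and then calling requires_grad_(True) on the loss so that backward() stops raising an error. The loss is then disconnected from the network parameters, so every optimizer step is a no-op. The sketch below uses hypothetical nn.Linear stand-ins (disc, gen) to show the difference, and also why detach() is applied to the generator output during the discriminator step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
disc = nn.Linear(4, 1)  # hypothetical stand-in for the discriminator
gen = nn.Linear(4, 4)   # hypothetical stand-in for the generator
x = torch.randn(8, 4)

# Pitfall: a forward pass under no_grad() records no graph, so the loss is cut off from disc.
with torch.no_grad():
    frozen_loss = torch.mean((disc(x) - 1.0) ** 2)
print(frozen_loss.grad_fn)        # None: no history back to the parameters
frozen_loss.requires_grad_(True)  # silences the backward() error but cannot restore the graph
frozen_loss.backward()
frozen_grad = disc.weight.grad
print(frozen_grad)                # None: an optimizer step would change nothing

# Fix: run the forward pass with autograd enabled.
loss = torch.mean((disc(x) - 1.0) ** 2)
loss.backward()
print(disc.weight.grad is not None)  # True

# detach(): during the discriminator step, keep gradients out of the generator.
fake = gen(x)
d_fake_loss = torch.mean(disc(fake.detach()) ** 2)
d_fake_loss.backward()
print(gen.weight.grad)            # None: only disc accumulated gradients
```

Checking that a parameter's .grad is populated after the first backward() is a cheap way to catch this kind of silent failure before a long training run.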

What are the ideal values for the metrics?

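  • d_clean_loss and d_noisy_loss should not be expected to reach 0. The discriminator tries to lower them, but if both collapse toward 0 it is separating real from generated audio too easily, which usually means the generator is not yet producing convincing output. Values that stay in a stable range, neither exploding nor vanishing, are a healthier sign.
  • g_loss should decrease over training, but its adversarial part (g_loss - g_conditional_loss) is likewise not expected to reach 0 while the discriminator is still learning.
  • g_conditional_loss is, by construction, never larger than g_loss, because g_loss is the sum of a non-negative adversarial term and g_conditional_loss. The useful signal is a steady downward trend across epochs, which indicates that the generated audio is getting closer to the clean reference. The exact value depends on the complexity of the audio and how effectively the model removes noise.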
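
For a rough landmark, the least-squares objective used in the script has an idealized equilibrium that can be worked out directly. The derivation assumes an optimal discriminator and a generator whose (enhanced, noisy) pairs are distributed exactly like the real pairs, with p_d and p_g the densities of real and generated pairs x. Real training never reaches this point exactly, and the logged values are noisy batch estimates, so treat the numbers as reference points rather than targets.

```latex
\begin{aligned}
L_D &= \mathbb{E}_{x\sim p_d}\big[(D(x)-1)^2\big] + \mathbb{E}_{x\sim p_g}\big[D(x)^2\big]
     = \int \Big[p_d(x)\,\big(D(x)-1\big)^2 + p_g(x)\,D(x)^2\Big]\,dx \\
0 &= 2p_d(x)\,\big(D(x)-1\big) + 2p_g(x)\,D(x)
  \quad\Rightarrow\quad D^*(x) = \frac{p_d(x)}{p_d(x)+p_g(x)} \\
p_g = p_d &\;\Rightarrow\; D^* = \tfrac12:\quad
  \text{d\_clean\_loss} = \big(\tfrac12-1\big)^2 = \tfrac14,\quad
  \text{d\_noisy\_loss} = \big(\tfrac12\big)^2 = \tfrac14,\quad
  \text{g\_loss} - \text{g\_conditional\_loss} = \tfrac12\big(\tfrac12-1\big)^2 = \tfrac18
\end{aligned}
```

The first expectation corresponds to d_clean_loss and the second to d_noisy_loss; the script's 0.5 factor on the generator's adversarial term is what turns 1/4 into 1/8.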

It's important to note that the ideal values for these metrics depend on the specific audio dataset and the desired level of noise reduction. Because adversarial losses rarely decrease smoothly, monitor them alongside the per-epoch audio files the script writes to the results directory, and compare saved checkpoints by listening to their output rather than by loss values alone.

Audio Enhancement using Generative Adversarial Networks (GANs)

Original article: https://www.cveoy.top/t/topic/np4R. Copyright belongs to the author; do not repost or scrape.
