import torch
from dataset import SEGAN_Dataset
from hparams import hparams
from model import Generator, Discriminator
import os
from torch.utils.data import Dataset,DataLoader
import torch.nn as nn
from torch.autograd import Variable

# Training constants that the loop below previously inlined as magic numbers.
LEARNING_RATE = 0.0001  # RMSprop learning rate, shared by G and D
WEIGHT_DECAY = 0.001  # RMSprop weight decay, shared by G and D
L1_WEIGHT = 100  # weight of the L1 (conditional) term in the generator loss
NUM_WORKERS = 8  # DataLoader worker processes
LOG_INTERVAL = 100  # optimisation steps between averaged-loss printouts


def _select_device():
    """Return the device to train on.

    The script used to hard-code ``cuda:1``, which fails as soon as a model
    is moved to it on a machine with fewer than two visible GPUs. Keep
    ``cuda:1`` when it exists, otherwise fall back to the first GPU, then
    to the CPU.
    """
    if torch.cuda.device_count() > 1:
        return torch.device("cuda:1")
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")


def main():
    """Train the SEGAN generator and discriminator.

    Each step performs one discriminator update followed by one generator
    update, using least-squares (L2) adversarial losses: D pushes its score
    for real (clean, noisy) pairs towards 1 and for (generated, noisy) pairs
    towards 0, while G pushes the score of its pairs towards 1. G also gets
    an L1 term between its output and the clean batch, weighted by
    ``L1_WEIGHT``.

    Averaged losses are printed every ``LOG_INTERVAL`` steps. The generator
    and discriminator ``state_dict``s are written to ``para.save_path`` after
    every epoch.
    """
    device = _select_device()

    # Load hyper-parameters and create the checkpoint directory.
    para = hparams()
    os.makedirs(para.save_path, exist_ok=True)

    generator = Generator().to(device)
    discriminator = Discriminator().to(device)

    g_optimizer = torch.optim.RMSprop(
        generator.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )
    d_optimizer = torch.optim.RMSprop(
        discriminator.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )

    m_dataset = SEGAN_Dataset(para)

    # Fixed reference batch passed to every discriminator call (presumably
    # for virtual batch normalisation -- TODO confirm against model.py).
    ref_batch = m_dataset.ref_batch(para.ref_batch_size).to(device)

    m_dataloader = DataLoader(
        m_dataset,
        batch_size=para.batch_size,
        shuffle=True,
        num_workers=NUM_WORKERS,
    )

    loss_d_all = 0.0
    loss_g_all = 0.0
    n_step = 0
    for epoch in range(para.n_epoch):
        for sample_batch in m_dataloader:
            # Item 0 is the clean signal, item 1 the noisy one.
            batch_clean = sample_batch[0].to(device)
            batch_noisy = sample_batch[1].to(device)

            # Latent noise z ~ N(0, 1), sampled directly on the training device.
            batch_z = torch.randn(
                batch_clean.size(0), para.size_z[0], para.size_z[1], device=device
            )

            # ---- Discriminator: real pairs -> 1, generated pairs -> 0.
            # zero_grad also clears the D grads left over from the previous
            # step's generator backward pass.
            discriminator.zero_grad()
            real_pair = torch.cat((batch_clean, batch_noisy), dim=1)
            clean_loss = torch.mean((discriminator(real_pair, ref_batch) - 1.0) ** 2)

            generated_outputs = generator(batch_noisy, batch_z)
            # detach(): the D update must not backpropagate into the generator.
            fake_pair = torch.cat((generated_outputs.detach(), batch_noisy), dim=1)
            noisy_loss = torch.mean(discriminator(fake_pair, ref_batch) ** 2)

            d_loss = clean_loss + noisy_loss
            d_loss.backward()
            d_optimizer.step()

            # ---- Generator: make D score generated pairs as real (-> 1) and
            # stay close to the clean signal in L1. The generator output from
            # above is reused: only D's weights have changed since then.
            generator.zero_grad()
            gen_pair = torch.cat((generated_outputs, batch_noisy), dim=1)
            g_adv_loss = 0.5 * torch.mean((discriminator(gen_pair, ref_batch) - 1.0) ** 2)
            g_cond_loss = L1_WEIGHT * torch.mean(torch.abs(generated_outputs - batch_clean))
            g_loss = g_adv_loss + g_cond_loss
            g_loss.backward()
            g_optimizer.step()

            # .item() keeps only the Python scalar, so no autograd graph is retained.
            n_step += 1
            loss_d_all += d_loss.item()
            loss_g_all += g_loss.item()
            if n_step % LOG_INTERVAL == 0:
                print(
                    f"epoch {epoch:04d} step {n_step:07d} "
                    f"d_loss {loss_d_all / LOG_INTERVAL:.6f} "
                    f"g_loss {loss_g_all / LOG_INTERVAL:.6f}"
                )
                loss_d_all = 0.0
                loss_g_all = 0.0

        # Checkpoint after every epoch so a finished (or interrupted) run
        # keeps its weights.
        torch.save(
            generator.state_dict(), os.path.join(para.save_path, f"G_{epoch}.pth")
        )
        torch.save(
            discriminator.state_dict(), os.path.join(para.save_path, f"D_{epoch}.pth")
        )


if __name__ == "__main__":
    main()
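# SEGAN 音频去噪模型训练代码 - 基于 PyTorch
# (SEGAN audio-denoising model training code, based on PyTorch.)
#
# 原文地址: https://www.cveoy.top/t/topic/nvVu 著作权归作者所有。请勿转载和采集!
# (Original article: https://www.cveoy.top/t/topic/nvVu. Copyright belongs to the author; do not repost or scrape.)
#
# 免费AI点我,无需注册和登录
# (Free AI, click here, no registration or login required.)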