# Python 2/3 compatibility shims (the scrape dropped the double underscores:
# these must be `__future__`, not the third-party `future` package).
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import os.path
import time

import ipdb  # kept for the commented-out debugging hooks below
import numpy as np
import tensorflow as tf

import SENN
import audio_reader

LR = 0.00001  # learning rate passed to SE_Net.train

FLAGS = tf.app.flags.FLAGS

# Directory to store the checkpoints.
tf.app.flags.DEFINE_string(
    'train_dir',
    './train_dir',
    '''Directory where to write event logs ''')

# Directory for the summaries (loss etc.).
tf.app.flags.DEFINE_string(
    'sum_dir',
    './sum_dir',
    '''Directory where to write summary ''')

# Noise directory.
tf.app.flags.DEFINE_string(
    'noise_dir',
    'D:/graduation_design/data/noisy_trainset_56spk_wav/noisy_trainset_56spk_wav',
    # '/home/nca/Downloads/raw_data/Nonspeech_train/',
    '''Directory where to load noise ''')

# Speech directory.
tf.app.flags.DEFINE_string(
    'speech_dir',
    'D:/graduation_design/data/clean_trainset_56spk_wav/clean_trainset_56spk_wav',
    # '/home/nca/Downloads/raw_data/speech_train/',
    '''Directory where to load speech ''')

# Validation noise directory.
tf.app.flags.DEFINE_string(
    'val_noise_dir',
    'D:/graduation_design/data/noisy_testset_wav/noisy_testset_wav',
    # '/home/nca/Downloads/raw_data/Nonspeech_test/',
    '''Directory where to load noise ''')

# Validation speech directory.
tf.app.flags.DEFINE_string(
    'val_speech_dir',
    'D:/graduation_design/data/clean_testset_wav/clean_testset_wav',
    # '/home/nca/Downloads/raw_data/speech_test/',
    '''Directory where to load noise ''')

tf.app.flags.DEFINE_integer('max_steps', 2000000000,
                            '''Number of batches to run.''')

# NOTE(review): the argparse block below duplicates the tf flags above with
# identical names and defaults; the rest of the script reads the argparse
# values, so the tf flags are effectively unused -- confirm before removing.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--train_dir', type=str, default='./train_dir',
                    help='Directory where to write event logs')
parser.add_argument('--sum_dir', type=str, default='./sum_dir',
                    help='Directory where to write summary')
parser.add_argument('--noise_dir', type=str,
                    default='D:/graduation_design/data/noisy_trainset_56spk_wav/noisy_trainset_56spk_wav',
                    help='Directory where to load noise')
parser.add_argument('--speech_dir', type=str,
                    default='D:/graduation_design/data/clean_trainset_56spk_wav/clean_trainset_56spk_wav',
                    help='Directory where to load speech')
parser.add_argument('--val_noise_dir', type=str,
                    default='D:/graduation_design/data/noisy_testset_wav/noisy_testset_wav',
                    help='Directory where to load noise')
parser.add_argument('--val_speech_dir', type=str,
                    default='D:/graduation_design/data/clean_testset_wav/clean_testset_wav',
                    help='Directory where to load noise')
parser.add_argument('--max_steps', type=int, default=2000000000,
                    help='Number of batches to run.')
args = parser.parse_args()

train_dir = args.train_dir
sum_dir = args.sum_dir
noise_dir = args.noise_dir
speech_dir = args.speech_dir
val_noise_dir = args.val_noise_dir
val_speech_dir = args.val_speech_dir
max_steps = args.max_steps

# STFT / batching parameters.
NFFT = 256  # number of fft points
NEFF = 129  # number of effective fft points
frame_move = 64  # hop size
batch_size = 128
N_IN = 8  # number of frames presented to the net
N_OUT = 1  # output frame number
validation_samples = 848824  # total numbers of the validation set
batch_of_val = np.floor(validation_samples / batch_size)

# After all the full batches, dequeue what is left to make sure
# all the samples in the validation set are consumed.
val_left_to_dequeue = validation_samples - batch_of_val * batch_size
val_loss = np.zeros([1000000])  # one slot per validation run; saved to disk

def train():
    """Train the speech-enhancement net with periodic validation/checkpointing.

    Builds one SE_NET graph fed by two Audio_reader queues (training and
    validation) and switches between them via the boolean ``is_val``
    placeholder with ``tf.cond``.  Every 100 steps the training loss and a
    summary are written; every 100000 steps the whole validation set is
    evaluated and the running loss history saved to ``val_loss2.npy``; every
    10000 steps a checkpoint is saved to ``train_dir``.

    Uses the module-level configuration (paths, NFFT/NEFF, batch_size, LR,
    max_steps, val_loss, ...).  Runs until ``max_steps`` batches are done.
    """
    coord = tf.train.Coordinator()

    # training data reader (speech + noise)
    audio_rd = audio_reader.Audio_reader(
        speech_dir, noise_dir, coord, N_IN, NFFT,
        frame_move, is_val=False)

    # validation data reader
    # NOTE(review): is_val=False here looks like a copy-paste slip for the
    # validation reader -- confirm against audio_reader.Audio_reader.
    val_audio_rd = audio_reader.Audio_reader(
        val_speech_dir, val_noise_dir, coord, N_IN, NFFT,
        frame_move, is_val=False)

    # flag selecting validation vs. training data at run time
    is_val = tf.placeholder(dtype=tf.bool, shape=())

    # speech enhancement net
    SE_Net = SENN.SE_NET(
        batch_size, NEFF, N_IN, N_OUT)

    # raw data frames from each queue
    train_data_frames = audio_rd.dequeue(batch_size)
    val_data_frames = val_audio_rd.dequeue(batch_size)

    # select which source feeds the net for this sess.run
    data_frames = tf.cond(
        is_val, lambda: val_data_frames, lambda: train_data_frames)

    # transform raw data into inputs for the net; not done in preprocessing
    # because it runs fast and avoids storing all the mixed samples
    images, targets = SE_Net.inputs(data_frames)

    # infer the clean speech
    inf_targets = SE_Net.inference(images, is_train=True)

    loss = SE_Net.loss(inf_targets, targets)  # compute loss
    train_op = SE_Net.train(loss, LR)  # optimizer step

    saver = tf.train.Saver(tf.all_variables())
    summary_op = tf.merge_all_summaries()
    init = tf.initialize_all_variables()

    sess = tf.Session()
    sess.run(init)

    audio_rd.start_threads(sess)  # start audio reading threads
    val_audio_rd.start_threads(sess)

    summary_writer = tf.train.SummaryWriter(
        sum_dir,
        sess.graph)

    # tracks how many validation passes have been recorded into val_loss
    val_loss_id = 0
    # while-loop instead of (x)range: xrange does not exist on Python 3 and
    # range(max_steps) would materialize a 2e9-element list on Python 2
    step = 0
    while step < max_steps:
        start_time = time.time()
        _, loss_value = sess.run(
            [train_op, loss], feed_dict={is_val: False})
        duration = time.time() - start_time

        assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

        # display training loss every 100 steps
        if step % 100 == 0:
            examples_per_sec = batch_size / duration
            sec_per_batch = float(duration)
            format_str = (
                '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                'sec/batch)')
            print(format_str % (datetime.now(), step, loss_value,
                                examples_per_sec, sec_per_batch))

        # write a summary every 100 steps
        if step % 100 == 0:
            summary_str = sess.run(
                summary_op, feed_dict={is_val: False})
            summary_writer.add_summary(summary_str, step)

        # run a full validation pass every 100000 steps (and at the end)
        if step % 100000 == 0 or (step + 1) == max_steps:
            np_val_loss = 0
            print('Doing validation, please wait ...')
            for _ in range(int(batch_of_val)):
                temp_loss, = sess.run(
                    [loss],
                    feed_dict={is_val: True})
                np_val_loss += temp_loss
            # NOTE(review): this only BUILDS a dequeue op and never runs it,
            # so the leftover validation samples are not actually consumed --
            # likely needs a sess.run (and building ops inside the loop leaks
            # graph nodes).  Kept as-is pending confirmation.
            val_audio_rd.dequeue(val_left_to_dequeue)
            mean_val_loss = np_val_loss / batch_of_val
            print('validation loss %.2f' % mean_val_loss)
            val_loss[val_loss_id] = mean_val_loss
            val_loss_id += 1
            np.save('val_loss2.npy', val_loss)

        # store the model every 10000 steps (and at the end)
        if step % 10000 == 0 or (step + 1) == max_steps:
            checkpoint_path = os.path.join(train_dir, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)

        step += 1

# Standard script entry guard: running the file trains the model, but
# importing it (e.g. for reuse of train()) does not kick off training.
if __name__ == '__main__':
    train()

# Speech Enhancement Neural Network (SENN) -- training script.
# Source / attribution: https://www.cveoy.top/t/topic/ntdz (all rights
# reserved by the original author).