Few-Shot Audio Event Detection with Prototypical Networks
This code implements a few-shot learning method for audio event detection. It consists of three stages, each enabled by a configuration flag: feature extraction, training, and evaluation.
In the feature extraction stage, Mel spectrograms are computed as the feature representation of the audio signal. In the training stage, a Prototypical Network (ProtoNet) is trained on these features.
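To illustrate what the Mel scale does to the frequency axis, here is a minimal NumPy sketch that computes Mel-spaced band edges. It assumes the common HTK-style formula m = 2595 * log10(1 + f / 700). The function names, the sample rate (22050 Hz), and the number of bands (8) are illustrative assumptions and are not taken from this project's feature_transform or configuration.

```python
import numpy as np

def hz_to_mel(f_hz):
    """HTK-style Mel scale: roughly linear below 1 kHz, logarithmic above."""
    return 2595.0 * np.log10(1.0 + np.asarray(f_hz, dtype=float) / 700.0)

def mel_to_hz(m):
    """Inverse of hz_to_mel."""
    return 700.0 * (10.0 ** (np.asarray(m, dtype=float) / 2595.0) - 1.0)

def mel_band_edges(n_mels, sr, fmin=0.0, fmax=None):
    """Return n_mels + 2 band-edge frequencies in Hz, evenly spaced in Mel."""
    fmax = sr / 2.0 if fmax is None else fmax
    mels = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
    return mel_to_hz(mels)

# Illustrative values only (assumed, not from the project's config).
edges = mel_band_edges(n_mels=8, sr=22050)
print(np.round(edges, 1))
```

The printed edges are close together at low frequencies and spread out at high frequencies, which is why Mel features give more resolution to the lower part of the spectrum.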
During training, the code first generates the training and validation splits. It then uses EpisodicBatchSampler to draw episodes: each batch contains k_way classes with 2 * n_shot samples per class. Two encoder architectures are available, ResNet and ProtoNet, with ProtoNet reported to perform better. Training uses the Adam optimizer with step-wise learning rate decay (StepLR), which multiplies the learning rate by a factor gamma every step_size epochs.
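The implementation of EpisodicBatchSampler is not shown here, so the following is only a sketch of how episodic sampling is commonly done. The class name SimpleEpisodicSampler, its seed argument, and the toy labels are hypothetical. It assumes each episode picks k_way classes at random and then draws samples_per_cls distinct examples from each.

```python
import numpy as np

class SimpleEpisodicSampler:
    """Hypothetical stand-in for EpisodicBatchSampler (assumed behavior).

    Each episode picks k_way classes at random and, for each class,
    samples_per_cls distinct example indices.
    """

    def __init__(self, labels, n_episodes, k_way, samples_per_cls, seed=0):
        self.labels = np.asarray(labels)
        self.n_episodes = n_episodes
        self.k_way = k_way
        self.samples_per_cls = samples_per_cls
        self.rng = np.random.default_rng(seed)
        # Map each class label to the dataset indices that carry it.
        self.class_indices = {c: np.flatnonzero(self.labels == c)
                              for c in np.unique(self.labels)}

    def __len__(self):
        return self.n_episodes

    def __iter__(self):
        classes = np.array(list(self.class_indices))
        for _ in range(self.n_episodes):
            chosen = self.rng.choice(classes, size=self.k_way, replace=False)
            batch = []
            for c in chosen:
                idx = self.rng.choice(self.class_indices[c],
                                      size=self.samples_per_cls, replace=False)
                batch.extend(idx.tolist())
            # One episode: a flat list of k_way * samples_per_cls indices.
            yield batch

# Toy labels: 4 classes with 6 examples each (illustrative data).
labels = np.repeat(np.arange(4), 6)
sampler = SimpleEpisodicSampler(labels, n_episodes=3, k_way=2, samples_per_cls=4)
episodes = list(sampler)
```

An object of this shape can be passed as batch_sampler to torch.utils.data.DataLoader, which accepts any iterable that yields lists of indices.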
After each epoch, the code prints the average training and validation loss and accuracy. Whenever the validation accuracy improves, the model is saved as the best model; the last model is saved as well. Finally, an evaluation stage processes the test features and writes the predicted onset and offset times to a CSV file.
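The per-epoch averages are taken over the tail of a list that keeps growing across epochs. The sketch below shows that bookkeeping in isolation, using made-up accuracy values. The variable names fake_epochs and best_epoch are illustrative and do not appear in the project.

```python
import numpy as np

num_batches = 3            # batches per epoch (illustrative)
val_acc = []               # running history across all epochs
best_val_acc = 0.0
best_epoch = None

# Made-up per-batch accuracies for three epochs.
fake_epochs = [[0.50, 0.60, 0.55], [0.70, 0.65, 0.75], [0.60, 0.70, 0.65]]

for epoch, batch_accs in enumerate(fake_epochs):
    val_acc.extend(batch_accs)
    # The trailing slice selects only the current epoch's batches.
    avg_acc = float(np.mean(val_acc[-num_batches:]))
    if avg_acc > best_val_acc:
        # In the training code, this is where the best checkpoint is saved.
        best_val_acc, best_epoch = avg_acc, epoch
```

This pattern is only correct if exactly num_batches values are appended per epoch; otherwise the slice mixes batches from neighboring epochs.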
Code Snippet:
import os
from glob import glob

import h5py
import hydra
import numpy as np
import pandas as pd
import torch
from omegaconf import DictConfig
from tqdm import tqdm

# loss_fn, feature_transform, Datagen, EpisodicBatchSampler, ResNet, ProtoNet,
# init_seed and evaluate_prototypes come from project modules not shown here.


def train_protonet(encoder, train_loader, valid_loader, conf, num_batches_tr, num_batches_vd):
    # The snippet omits the def line; this signature is inferred from the call in main().
    # device is not defined in the snippet either; choosing it this way is an assumption.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optim = torch.optim.Adam([{'params': encoder.parameters()}], lr=conf.train.lr_rate)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optim,
                                                   gamma=conf.train.scheduler_gamma,
                                                   step_size=conf.train.scheduler_step_size)
    num_epochs = conf.train.epochs
    best_model_path = conf.path.best_model
    last_model_path = conf.path.last_model
    train_loss, val_loss = [], []
    train_acc, val_acc = [], []
    best_val_acc = 0.0
    encoder.to(device)

    for epoch in range(num_epochs):
        print('Epoch {}'.format(epoch))

        # Training pass: one optimizer step per episode.
        encoder.train()
        for batch in tqdm(train_loader):
            optim.zero_grad()
            x, y = batch
            x = x.to(device)
            y = y.to(device)
            x_out = encoder(x)
            tr_loss, tr_acc = loss_fn(x_out, y, conf.train.n_shot)
            train_loss.append(tr_loss.item())
            train_acc.append(tr_acc.item())
            tr_loss.backward()
            optim.step()

        # Average over the batches of the current epoch only.
        avg_loss_tr = np.mean(train_loss[-num_batches_tr:])
        avg_acc_tr = np.mean(train_acc[-num_batches_tr:])
        print('Average train loss: {} Average training accuracy: {}'.format(avg_loss_tr, avg_acc_tr))
        lr_scheduler.step()

        # Validation pass: no gradients are needed.
        encoder.eval()
        with torch.no_grad():
            for batch in tqdm(valid_loader):
                x, y = batch
                x = x.to(device)
                y = y.to(device)
                x_val = encoder(x)
                valid_loss, valid_acc = loss_fn(x_val, y, conf.train.n_shot)
                val_loss.append(valid_loss.item())
                val_acc.append(valid_acc.item())

        avg_loss_vd = np.mean(val_loss[-num_batches_vd:])
        avg_acc_vd = np.mean(val_acc[-num_batches_vd:])
        print('Epoch {}, Validation loss {:.4f}, Validation accuracy {:.4f}'.format(epoch, avg_loss_vd, avg_acc_vd))

        if avg_acc_vd > best_val_acc:
            print('Saving the best model with validation accuracy {}'.format(avg_acc_vd))
            best_val_acc = avg_acc_vd
            torch.save({'encoder': encoder.state_dict()}, best_model_path)

    # Save the final state after the last epoch.
    torch.save({'encoder': encoder.state_dict()}, last_model_path)
    return best_val_acc, encoder

@hydra.main(config_name='config')
def main(conf: DictConfig):
    # Create the feature directories if they do not exist yet.
    for feat_dir in (conf.path.feat_path, conf.path.feat_train, conf.path.feat_eval):
        os.makedirs(feat_dir, exist_ok=True)

    if conf.set.features:
        print(' --Feature Extraction Stage--')
        Num_extract_train, data_shape = feature_transform(conf=conf, mode='train')
        print('Shape of dataset is {}'.format(data_shape))
        print('Total training samples is {}'.format(Num_extract_train))
        Num_extract_eval = feature_transform(conf=conf, mode='eval')
        print('Total number of samples used for evaluation: {}'.format(Num_extract_eval))
        print(' --Feature Extraction Complete--')

    if conf.set.train:
        os.makedirs(conf.path.Model, exist_ok=True)
        init_seed()
        gen_train = Datagen(conf)
        X_train, Y_train, X_val, Y_val = gen_train.generate_train()
        X_tr = torch.tensor(X_train)
        Y_tr = torch.LongTensor(Y_train)
        X_val = torch.tensor(X_val)
        Y_val = torch.LongTensor(Y_val)

        # Each class contributes 2 * n_shot samples to an episode.
        samples_per_cls = conf.train.n_shot * 2
        batch_size_tr = samples_per_cls * conf.train.k_way
        batch_size_vd = batch_size_tr
        if conf.train.num_episodes is not None:
            num_episodes_tr = conf.train.num_episodes
            num_episodes_vd = conf.train.num_episodes
        else:
            num_episodes_tr = len(Y_train) // batch_size_tr
            num_episodes_vd = len(Y_val) // batch_size_vd

        samplr_train = EpisodicBatchSampler(Y_train, num_episodes_tr, conf.train.k_way, samples_per_cls)
        samplr_valid = EpisodicBatchSampler(Y_val, num_episodes_vd, conf.train.k_way, samples_per_cls)
        train_dataset = torch.utils.data.TensorDataset(X_tr, Y_tr)
        valid_dataset = torch.utils.data.TensorDataset(X_val, Y_val)
        train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_sampler=samplr_train,
                                                   num_workers=0, pin_memory=True, shuffle=False)
        valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_sampler=samplr_valid,
                                                   num_workers=0, pin_memory=True, shuffle=False)

        encoder = ResNet() if conf.train.encoder == 'Resnet' else ProtoNet()
        best_acc, model = train_protonet(encoder, train_loader, valid_loader, conf,
                                         num_episodes_tr, num_episodes_vd)
        # train_protonet returns the best validation accuracy, not a training-set score.
        print('Best validation accuracy of the model is {}'.format(best_acc))

    if conf.set.eval:
        # Fall back to the CPU when no GPU is available.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        init_seed()
        name_arr = np.array([])
        onset_arr = np.array([])
        offset_arr = np.array([])

        for feat_file in glob(os.path.join(conf.path.feat_eval, '*.h5')):
            feat_name = os.path.basename(feat_file)
            audio_name = os.path.splitext(feat_name)[0] + '.wav'
            print('Processing audio file : {}'.format(audio_name))
            # The context manager closes the HDF5 file after each evaluation.
            with h5py.File(feat_file, 'r') as hdf_eval:
                strt_index_query = hdf_eval['start_index_query'][:][0]
                onset, offset = evaluate_prototypes(conf, hdf_eval, device, strt_index_query)
            name = np.repeat(audio_name, len(onset))
            name_arr = np.append(name_arr, name)
            onset_arr = np.append(onset_arr, onset)
            offset_arr = np.append(offset_arr, offset)

        df_out = pd.DataFrame({'Audiofilename': name_arr, 'Starttime': onset_arr, 'Endtime': offset_arr})
        csv_path = os.path.join(conf.path.root_dir, 'Eval_out.csv')
        df_out.to_csv(csv_path, index=False)


if __name__ == '__main__':
    main()
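The training loop calls loss_fn(x_out, y, conf.train.n_shot) without showing its definition. The following NumPy sketch shows what a prototypical loss typically computes, under stated assumptions: the first n_support examples of each class in an episode form the support set, the remaining ones are queries, and distances are squared Euclidean. The function name prototypical_loss and the toy data are hypothetical; the project's actual loss_fn operates on torch tensors and may differ in detail.

```python
import numpy as np

def prototypical_loss(embeddings, labels, n_support):
    """Hypothetical NumPy stand-in for loss_fn(x_out, y, n_shot).

    For each class in the episode, the first n_support examples form the
    support set and the rest are queries (an assumed convention).
    Returns (mean cross-entropy over queries, query accuracy).
    """
    embeddings = np.asarray(embeddings, dtype=float)
    labels = np.asarray(labels)
    classes = np.unique(labels)

    prototypes, queries, targets = [], [], []
    for k, c in enumerate(classes):
        idx = np.flatnonzero(labels == c)
        # A prototype is the mean embedding of the class's support examples.
        prototypes.append(embeddings[idx[:n_support]].mean(axis=0))
        queries.append(embeddings[idx[n_support:]])
        targets.append(np.full(len(idx) - n_support, k))
    prototypes = np.stack(prototypes)          # (n_classes, dim)
    queries = np.concatenate(queries)          # (n_queries, dim)
    targets = np.concatenate(targets)          # (n_queries,)

    # Squared Euclidean distance from every query to every prototype.
    diff = queries[:, None, :] - prototypes[None, :, :]
    dists = (diff ** 2).sum(axis=2)            # (n_queries, n_classes)

    # Log-softmax over negative distances, computed stably.
    logits = -dists
    logits = logits - logits.max(axis=1, keepdims=True)
    log_p = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))

    loss = -log_p[np.arange(len(targets)), targets].mean()
    acc = (log_p.argmax(axis=1) == targets).mean()
    return float(loss), float(acc)

# Toy 2-way episode with 2 support + 2 query examples per class.
rng = np.random.default_rng(0)
emb = np.concatenate([rng.normal(0.0, 0.1, (4, 3)), rng.normal(5.0, 0.1, (4, 3))])
lab = np.array([0, 0, 0, 0, 1, 1, 1, 1])
loss, acc = prototypical_loss(emb, lab, n_support=2)
```

With 2 * n_shot samples per class, as in the sampler setup above, this convention leaves n_shot queries per class in every episode.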
Key Features:
- Few-shot learning: The model can learn to detect new audio events with only a few labeled examples.
- Prototypical Networks: The model uses Prototypical Networks, which classify a query by its distance to class prototypes (the mean embeddings of the support examples).
- Mel spectrograms: Mel spectrograms are used as features to represent audio signals.
- ResNet and ProtoNet: The code provides support for both ResNet and ProtoNet model architectures.
- Episodic training: The model is trained using episodic training, which mimics the few-shot learning scenario.
- Evaluation stage: The code includes an evaluation stage for processing test data and outputting predictions.
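The evaluation stage collects onset and offset arrays for each file, but evaluate_prototypes itself is not shown. As a rough illustration of how frame-level predictions can be turned into event times, the sketch below thresholds per-frame probabilities and reads onsets and offsets from the rising and falling edges. The function name frames_to_events, the threshold, and the hop length in seconds are all assumptions made for this example.

```python
import numpy as np

def frames_to_events(probs, threshold=0.5, hop_seconds=0.01, start_time=0.0):
    """Convert per-frame positive-class probabilities to onset/offset times.

    threshold, hop_seconds and start_time are illustrative assumptions.
    Offsets are exclusive: they mark the first inactive frame after an event.
    """
    active = np.asarray(probs) > threshold
    # Pad with False so every run of active frames has a rising and a falling edge.
    padded = np.concatenate([[False], active, [False]]).astype(int)
    changes = np.diff(padded)
    onset_frames = np.flatnonzero(changes == 1)
    offset_frames = np.flatnonzero(changes == -1)
    onset = start_time + onset_frames * hop_seconds
    offset = start_time + offset_frames * hop_seconds
    return onset, offset

# Made-up probabilities with two events: frames 1-2 and frames 5-7.
probs = [0.1, 0.8, 0.9, 0.2, 0.1, 0.7, 0.6, 0.9]
onset, offset = frames_to_events(probs, hop_seconds=0.5)
```

The two returned arrays always have the same length, so they can be appended to the output table in the same way as the onset and offset arrays in the evaluation loop.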
Applications:
This code can be used for a variety of audio event detection tasks, including:
- Speech recognition: Detecting different types of speech sounds, such as phonemes or words.
- Music information retrieval: Identifying different musical instruments or genres.
- Environmental monitoring: Detecting sounds like bird calls, traffic, or machinery.
Original source: https://www.cveoy.top/t/topic/nMRm. Copyright belongs to the author. Do not reproduce or scrape.