分析代码:@hydra.main(config_name="config") def main(conf: DictConfig): if not os.path.isdir(conf.path.feat_path): os.makedirs(conf.path.feat_path) if not os.path.isdir(conf.path.feat_train): os.makedirs(conf.path.feat_…
- 创建main函数并读取配置文件:
@hydra.main(config_name="config")
def main(conf: DictConfig):
    """Entry point driven by the hydra config named "config".

    Creates the feature output directories, then runs the feature
    extraction, training and evaluation stages, each one only when its
    flag (conf.set.features / conf.set.train / conf.set.eval) is set.
    """
- 创建特征保存路径:
# Create the feature output directory; exist_ok makes this a no-op when the
# directory is already there and avoids the check-then-create race.
os.makedirs(conf.path.feat_path, exist_ok=True)
- 创建训练集特征保存路径:
# Create the training-set feature directory; exist_ok makes this a no-op when
# the directory is already there and avoids the check-then-create race.
os.makedirs(conf.path.feat_train, exist_ok=True)
- 创建验证集特征保存路径:
# Create the evaluation-set feature directory; exist_ok makes this a no-op
# when the directory is already there and avoids the check-then-create race.
os.makedirs(conf.path.feat_eval, exist_ok=True)
- 如果需要进行特征提取,则执行以下代码:
# Feature extraction stage: runs only when conf.set.features is truthy.
if conf.set.features:
    print(" --Feature Extraction Stage--")

    # In "train" mode feature_transform returns two values, reported below
    # as the number of training samples and the shape of the dataset.
    Num_extract_train, data_shape = feature_transform(conf=conf, mode="train")
    print("Shape of dataset is {}".format(data_shape))
    print("Total training samples is {}".format(Num_extract_train))

    # In "eval" mode a single value is returned, reported below as the
    # number of samples used for evaluation.
    Num_extract_eval = feature_transform(conf=conf, mode='eval')
    print("Total number of samples used for evaluation: {}".format(Num_extract_eval))

    print(" --Feature Extraction Complete--")
- 如果需要进行模型训练,则执行以下代码:
# Training stage: runs only when conf.set.train is truthy.
if conf.set.train:
    # Make sure the model output directory exists (no-op if already present).
    os.makedirs(conf.path.Model, exist_ok=True)

    init_seed()

    # Load the train/validation arrays and wrap them as tensors.
    gen_train = Datagen(conf)
    X_train, Y_train, X_val, Y_val = gen_train.generate_train()
    X_tr = torch.tensor(X_train)
    Y_tr = torch.LongTensor(Y_train)
    X_val = torch.tensor(X_val)
    Y_val = torch.LongTensor(Y_val)

    # Episode geometry: every class contributes n_shot * 2 samples and every
    # batch holds k_way classes.
    # NOTE(review): the factor 2 presumably covers support + query halves;
    # confirm against train_protonet.
    samples_per_cls = conf.train.n_shot * 2
    batch_size_tr = samples_per_cls * conf.train.k_way
    batch_size_vd = batch_size_tr

    # Use the configured episode count when given; otherwise derive it from
    # how many full batches fit into each split.
    if conf.train.num_episodes is not None:
        num_episodes_tr = conf.train.num_episodes
        num_episodes_vd = conf.train.num_episodes
    else:
        num_episodes_tr = len(Y_train) // batch_size_tr
        num_episodes_vd = len(Y_val) // batch_size_vd

    # NOTE(review): the training sampler receives the raw Y_train while the
    # validation sampler receives the LongTensor Y_val (rebound above);
    # kept exactly as in the original.
    samplr_train = EpisodicBatchSampler(Y_train, num_episodes_tr, conf.train.k_way, samples_per_cls)
    samplr_valid = EpisodicBatchSampler(Y_val, num_episodes_vd, conf.train.k_way, samples_per_cls)

    train_dataset = torch.utils.data.TensorDataset(X_tr, Y_tr)
    valid_dataset = torch.utils.data.TensorDataset(X_val, Y_val)

    # Batches are produced by the episodic samplers, so shuffle stays False.
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_sampler=samplr_train,
        num_workers=0,
        pin_memory=True,
        shuffle=False,
    )
    valid_loader = torch.utils.data.DataLoader(
        dataset=valid_dataset,
        batch_sampler=samplr_valid,
        num_workers=0,
        pin_memory=True,
        shuffle=False,
    )

    # Pick the encoder: 'Resnet' selects ResNet, anything else ProtoNet.
    if conf.train.encoder == 'Resnet':
        encoder = ResNet()
    else:
        encoder = ProtoNet()

    best_acc, model = train_protonet(
        encoder, train_loader, valid_loader, conf, num_episodes_tr, num_episodes_vd
    )
    print("Best accuracy of the model on training set is {}".format(best_acc))
- 如果需要进行模型评估,则执行以下代码:
# Evaluation stage: runs only when conf.set.eval is truthy.
if conf.set.eval:
    # NOTE(review): the device is hard-coded to CUDA; confirm whether
    # evaluate_prototypes can also run on CPU before adding a fallback.
    device = 'cuda'
    init_seed()

    # Accumulators for the rows of the output CSV.
    name_arr = np.array([])
    onset_arr = np.array([])
    offset_arr = np.array([])

    all_feat_files = glob(os.path.join(conf.path.feat_eval, '*.h5'))
    for feat_file in all_feat_files:
        # basename is portable across path separators; splitext swaps only
        # the extension, so an "h5" inside the file stem is left untouched.
        feat_name = os.path.basename(feat_file)
        audio_name = os.path.splitext(feat_name)[0] + '.wav'
        print("Processing audio file : {}".format(audio_name))

        # The context manager guarantees the HDF5 file is closed even if
        # evaluation raises.
        with h5py.File(feat_file, 'r') as hdf_eval:
            strt_index_query = hdf_eval['start_index_query'][:][0]
            onset, offset = evaluate_prototypes(conf, hdf_eval, device, strt_index_query)

            # One output row per detected event, all tagged with this file.
            name = np.repeat(audio_name, len(onset))
            name_arr = np.append(name_arr, name)
            onset_arr = np.append(onset_arr, onset)
            offset_arr = np.append(offset_arr, offset)

    # Write all detections to <root_dir>/Eval_out.csv.
    df_out = pd.DataFrame({'Audiofilename': name_arr, 'Starttime': onset_arr, 'Endtime': offset_arr})
    csv_path = os.path.join(conf.path.root_dir, 'Eval_out.csv')
    df_out.to_csv(csv_path, index=False)

原文地址: http://www.cveoy.top/t/topic/d4K9 著作权归作者所有。请勿转载和采集!