分析代码 if not ospathisdirconfpathfeat_path osmakedirsconfpathfeat_path if not ospathisdirconfpathfeat_train osmakedirsconfpathfeat_train if not ospathisdirconfpathfeat_eval
- 检查特征路径是否存在,如果不存在则创建
if not os.path.isdir(conf.path.feat_path):
os.makedirs(conf.path.feat_path)
- 检查训练集特征存储路径是否存在,如果不存在则创建
if not os.path.isdir(conf.path.feat_train):
os.makedirs(conf.path.feat_train)
- 检查验证集特征存储路径是否存在,如果不存在则创建
if not os.path.isdir(conf.path.feat_eval):
os.makedirs(conf.path.feat_eval)
- 如果需要进行特征提取,则调用函数进行特征提取,并输出数据集的形状及样本数
if conf.set.features:
Num_extract_train,data_shape = feature_transform(conf=conf,mode="train")
print("Shape of dataset is {}".format(data_shape))
print("Total training samples is {}".format(Num_extract_train))
Num_extract_eval = feature_transform(conf=conf,mode='eval')
print("Total number of samples used for evaluation: {}".format(Num_extract_eval))
- 如果需要进行模型训练,则创建模型保存路径,调用函数进行训练,并输出训练集上的最佳准确率
if conf.set.train:
if not os.path.isdir(conf.path.Model):
os.makedirs(conf.path.Model)
best_acc,model = train_protonet(encoder,train_loader,valid_loader,conf,num_episodes_tr,num_episodes_vd)
print("Best accuracy of the model on training set is {}".format(best_acc))
- 如果需要进行模型评估,则对所有待评估的特征文件进行循环,调用函数进行评估,并输出评估结果保存至CSV文件中
if conf.set.eval:
all_feat_files = [file for file in glob(os.path.join(conf.path.feat_eval,'*.h5'))]
for feat_file in all_feat_files:
onset,offset = evaluate_prototypes(conf,hdf_eval,device,strt_index_query)
df_out.to_csv(csv_path,index=False)
``
原文地址: https://www.cveoy.top/t/topic/d4LT 著作权归作者所有。请勿转载和采集!