SDML 模型训练脚本 - 使用 Python 和 PyTorch
这是一个 Python 脚本,用于训练 SDML 模型。脚本中包含了许多参数,例如学习率、批次大小、输出形状、数据集等等,这些参数可以通过命令行输入或在脚本中手动设置。脚本中的 main 函数是程序的入口,它会创建一个 Solver 对象并进行训练。其中,Solver 是 SDML 模型的核心类,它封装了训练和测试的逻辑。
此外,还有一些其他的辅助函数和代码,例如设置随机种子、保存训练结果等等。总的来说,这个脚本的功能比较复杂,需要有一定的 Python 编程经验才能理解。
以下是脚本的代码:
def main(config):
from SDML import Solver
solver = Solver(config)
cudnn.benchmark = True
return solver.train()
if __name__ == '__main__':
CUDA_LAUNCH_BLOCKING=3
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--compute_all', type=bool, default=False)
parser.add_argument('--mode', type=str, default='train')
parser.add_argument('--just_valid', type=bool, default=False) # wiki, pascal, nus-wide, xmedianet
parser.add_argument('--multiprocessing', type=bool, default=True)
parser.add_argument('--running_time', type=bool, default=False)
parser.add_argument('--cuda_list', type=list, default=[3])
parser.add_argument('--lr', type=list, default=[1e-4, 2e-4, 2e-4, 2e-4, 2e-4])
parser.add_argument('--beta1', type=float, default=0.5)
parser.add_argument('--beta2', type=float, default=0.999)
parser.add_argument('--batch_size', type=int, default=100)
parser.add_argument('--output_shape', type=int, default=512)
parser.add_argument('--alpha', type=float, default=0.5)
parser.add_argument('--datasets', type=str, default='mv_cross') # xmedia, wiki_doc2vec, MSCOCO_doc2vec, nus_wide_doc2vec mvcross
parser.add_argument('--view_id', type=int, default=-1)
parser.add_argument('--sample_interval', type=int, default=1)
parser.add_argument('--epochs', type=int, default=200)
config = parser.parse_args()
seed = 123
print('seed: ' + str(seed))
import numpy as np
np.random.seed(seed)
import random as rn
rn.seed(seed)
import os
os.environ['PYTHONHASHSEED'] = str(seed)
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import torch
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
from torch.backends import cudnn
cudnn.enabled = False
results = main(config)
# print(config)
# import scipy.io as sio
# if config.running_time:
# runing_time = []
# for i in range(1):
# print('%d-th running time test', i)
# results = main(config)
# runing_time.append(results)
# print('average running time: %f', np.mean(runing_time))
# else:
# results = main(config)
# if config.just_valid:
# sio.savemat('para_results/params_' + config.datasets + '_' + str(config.batch_size) + '_' + str(config.output_shape) + '_' + str(config.alpha) + '_' + str(config.epochs) + '_' + str(config.lr) + '_loss.mat', {'val_d_loss': np.array(results[0]), 'tr_d_loss': np.array(results[1]), 'tr_ae_loss': np.array(results[2])})
# else:
# sio.savemat('results/params_' + config.datasets + '_' + str(config.batch_size) + '_' + str(config.output_shape) + '_' + str(config.alpha) + '_' + str(config.epochs) + '_' + str(config.lr) + '_resutls.mat', {'results': np.array(results)})
原文地址: https://www.cveoy.top/t/topic/nW3w 著作权归作者所有。请勿转载和采集!