Training a Street Fighter II Champion Edition AI Player with Stable Baselines3 and Retro
This code implements a reinforcement learning training pipeline using the stable-baselines3 library and a Gym Retro environment. Specifically, it does the following:
- Import the required libraries and modules: os, sys, retro, stable_baselines3, and the project's custom wrapper.
- Set hyperparameters: NUM_ENV (number of parallel environments) and LOG_DIR (log directory), and define a linear scheduler (linear_schedule) used to anneal the learning rate and clip range during training.
- Define the environment: the make_env function creates a retro environment and wraps it (StreetFighterCustomWrapper and Monitor).
- Create the model: a PPO model is instantiated with the desired training parameters.
- Train the model: the learn function trains the model, while checkpoint_callback periodically saves it during training.
- Save the model: the final trained model is written to the specified path.
```python
import os
import sys

import retro
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.vec_env import SubprocVecEnv

from street_fighter_custom_wrapper import StreetFighterCustomWrapper

NUM_ENV = 2
LOG_DIR = 'logs'
os.makedirs(LOG_DIR, exist_ok=True)

# Linear scheduler
def linear_schedule(initial_value, final_value=0.0):
    if isinstance(initial_value, str):
        initial_value = float(initial_value)
        final_value = float(final_value)
        assert (initial_value > 0.0)

    def scheduler(progress):
        return final_value + progress * (initial_value - final_value)

    return scheduler
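
# Illustration (not part of the original script): stable-baselines3 calls schedules with
# progress_remaining, which decays from 1.0 at the start of training to 0.0 at the end, so
# linear_schedule(2.5e-4, 2.5e-6)(1.0) evaluates to 2.5e-4 (initial value) and
# linear_schedule(2.5e-4, 2.5e-6)(0.0) evaluates to 2.5e-6 (final value).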

def make_env(game, state, seed=0):
    def _init():
        env = retro.make(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE
        )
        env = StreetFighterCustomWrapper(env, rendering=True)
        env = Monitor(env)
        env.seed(seed)
        return env
    return _init
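
# Quick sanity check (illustrative, not part of the original script): build one environment
# directly and confirm it resets before launching the full vectorized training run.
#   test_env = make_env(game='StreetFighterIISpecialChampionEdition-Genesis',
#                       state='Champion.Level12.RyuVsBison')()
#   obs = test_env.reset()
#   test_env.close()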

def main():
    # Set up the environment and model
    game = 'StreetFighterIISpecialChampionEdition-Genesis'
    env = SubprocVecEnv([make_env(game, state='Champion.Level12.RyuVsBison', seed=i) for i in range(NUM_ENV)])

    # Set linear schedule for learning rate
    # Start
    lr_schedule = linear_schedule(2.5e-4, 2.5e-6)
    # fine-tune
    # lr_schedule = linear_schedule(5.0e-5, 2.5e-6)

    # Set linear scheduler for clip range
    # Start
    clip_range_schedule = linear_schedule(0.15, 0.025)
    # fine-tune
    # clip_range_schedule = linear_schedule(0.075, 0.025)

    model = PPO(
        'CnnPolicy',
        env,
        device='cuda',
        verbose=1,
        n_steps=512,
        batch_size=512,
        n_epochs=4,
        gamma=0.94,
        learning_rate=lr_schedule,
        clip_range=clip_range_schedule,
        tensorboard_log='logs'
    )

    # Set the save directory
    save_dir = 'trained_models'
    os.makedirs(save_dir, exist_ok=True)

    # Load the model from file
    # model_path = 'trained_models/ppo_ryu_7000000_steps.zip'

    # Load model and modify the learning rate and entropy coefficient
    # custom_objects = {
    #     'learning_rate': lr_schedule,
    #     'clip_range': clip_range_schedule,
    #     'n_steps': 512
    # }
    # model = PPO.load(model_path, env=env, device='cuda', custom_objects=custom_objects)

    # Set up callbacks
    # Note that 1 timestep = 6 frames
    checkpoint_interval = 31250  # checkpoint_interval * num_envs = total_steps_per_checkpoint
    checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix='ppo_ryu')

    # Writing the training logs from stdout to a file
    original_stdout = sys.stdout
    log_file_path = os.path.join(save_dir, 'training_log.txt')
    with open(log_file_path, 'w') as log_file:
        sys.stdout = log_file

        model.learn(
            total_timesteps=int(100000000),  # total_timesteps = stage_interval * num_envs * num_stages (1120 rounds)
            callback=[checkpoint_callback]  # , stage_increase_callback]
        )
        env.close()

    # Restore stdout
    sys.stdout = original_stdout

    # Save the final model
    model.save(os.path.join(save_dir, 'ppo_sf2_ryu_final.zip'))

if __name__ == '__main__':
    main()
```
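
After training finishes, the saved model can be loaded back and run against the game. The snippet below is a minimal evaluation sketch under the same assumptions as the training script: the street_fighter_custom_wrapper module is importable, the 'Champion.Level12.RyuVsBison' state is installed for retro, and the model file name comes from the script above. The episode count and deterministic=False are illustrative choices, not part of the original code.

```python
import retro
from stable_baselines3 import PPO

from street_fighter_custom_wrapper import StreetFighterCustomWrapper


def evaluate(model_path='trained_models/ppo_sf2_ryu_final.zip', episodes=3):
    # Build a single (non-vectorized) environment, wrapped the same way as in training.
    env = retro.make(
        game='StreetFighterIISpecialChampionEdition-Genesis',
        state='Champion.Level12.RyuVsBison',
        use_restricted_actions=retro.Actions.FILTERED,
        obs_type=retro.Observations.IMAGE
    )
    env = StreetFighterCustomWrapper(env, rendering=True)

    # Load the trained policy; SB3 wraps the raw env in a DummyVecEnv internally.
    model = PPO.load(model_path, env=env)

    for episode in range(episodes):
        obs = env.reset()
        done = False
        total_reward = 0.0
        while not done:
            action, _ = model.predict(obs, deterministic=False)
            obs, reward, done, info = env.step(action)
            total_reward += reward
        print(f'Episode {episode + 1}: total reward = {total_reward}')

    env.close()


if __name__ == '__main__':
    evaluate()
```

To resume training from a checkpoint instead of evaluating, the commented-out PPO.load call in the training script shows how custom_objects can swap in new learning-rate and clip-range schedules when the model is reloaded.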