这段代码实现了一个使用 stable-baselines3 库和 retro 环境进行强化学习训练的过程。具体来说,代码实现了以下功能:

  1. 导入必要的库和模块:导入了 os、sys、retro、stable_baselines3 等库和模块。

  2. 设置全局配置:设置了 NUM_ENV(并行环境数量)和 LOG_DIR(日志目录)等配置项。同时还定义了一个线性调度器(linear_schedule),用于在训练过程中根据剩余训练进度线性调整学习率和裁剪范围(clip range)等参数。

  3. 定义环境:定义了 make_env 函数,它返回一个用于创建 retro 环境的工厂函数;创建出的环境依次经过 StreetFighterCustomWrapper 和 Monitor 包装。main 中通过 SubprocVecEnv 并行运行 NUM_ENV 个这样的环境。

  4. 创建模型:创建了一个 PPO 模型,并指定了训练所需的参数。

  5. 训练模型:调用 learn 方法训练模型。训练期间标准输出被重定向到日志文件,并通过 checkpoint_callback 定期保存模型检查点。

  6. 保存模型:将训练好的模型保存到指定路径。

import os
import sys
from contextlib import redirect_stdout

import retro
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.vec_env import SubprocVecEnv

from street_fighter_custom_wrapper import StreetFighterCustomWrapper

# Number of parallel emulator processes handed to SubprocVecEnv in main().
NUM_ENV: int = 2

# Log directory; created eagerly when this module is imported.
LOG_DIR: str = 'logs'
os.makedirs(name=LOG_DIR, exist_ok=True)

# Linear scheduler
def linear_schedule(initial_value, final_value=0.0):
    """Build a schedule that interpolates linearly between two values.

    Stable-Baselines3 calls schedules with ``progress_remaining``. That value
    goes from 1.0 at the start of training to 0.0 at the end. The returned
    callable therefore yields ``initial_value`` first and ``final_value`` last.

    Args:
        initial_value: Value at progress 1.0. A numeric string (e.g. read
            from a config file) is converted with ``float``.
        final_value: Value at progress 0.0. A numeric string is also
            accepted and converted with ``float``.

    Returns:
        A function ``scheduler(progress)`` returning
        ``final_value + progress * (initial_value - final_value)``.

    Raises:
        ValueError: if either argument is a non-numeric string, or if
            ``initial_value`` is given as a string that is not a positive
            number.
    """
    # Convert each bound independently. Previously ``final_value`` was only
    # converted when ``initial_value`` was a string, so a string
    # ``final_value`` with a numeric ``initial_value`` crashed later inside
    # the scheduler with a TypeError.
    if isinstance(final_value, str):
        final_value = float(final_value)
    if isinstance(initial_value, str):
        initial_value = float(initial_value)
        # Explicit check instead of ``assert`` so it is not stripped by
        # ``python -O``. ``not x > 0`` (rather than ``x <= 0``) also rejects NaN.
        if not initial_value > 0.0:
            raise ValueError(
                f"initial_value must be positive, got {initial_value!r}"
            )

    def scheduler(progress):
        # progress == 1.0 -> initial_value, progress == 0.0 -> final_value.
        return final_value + progress * (initial_value - final_value)

    return scheduler

def make_env(game, state, seed=0, *, rendering=True):
    """Return a zero-argument factory that builds one wrapped retro environment.

    SubprocVecEnv takes callables rather than environment instances, so each
    environment is constructed inside its own worker process. This function
    only captures the configuration; construction happens in ``_init``.

    Args:
        game: retro game identifier passed to ``retro.make``.
        state: retro save-state name passed to ``retro.make``.
        seed: seed applied to the wrapped environment via ``env.seed``.
        rendering: forwarded to ``StreetFighterCustomWrapper``. It defaults
            to ``True``, the previously hard-coded value. Passing ``False``
            presumably disables on-screen rendering for the worker; confirm
            against street_fighter_custom_wrapper.

    Returns:
        A callable that creates, wraps, seeds and returns the environment.
    """
    def _init():
        env = retro.make(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE,
        )
        env = StreetFighterCustomWrapper(env, rendering=rendering)
        # Monitor records per-episode reward/length for SB3's logger.
        env = Monitor(env)
        env.seed(seed)
        return env

    return _init

def main():
    """Train a PPO agent on Street Fighter II using parallel retro environments.

    - Builds ``NUM_ENV`` subprocess environments.
    - Trains a ``CnnPolicy`` PPO model with linearly decaying learning rate
      and clip range.
    - Writes periodic checkpoints and a final model to ``trained_models/``.
    - While training runs, redirects stdout to
      ``trained_models/training_log.txt``.

    The worker processes are always shut down, even if training raises
    (e.g. on Ctrl+C). Stdout is always restored.
    """
    # Set up the environment and model
    game = 'StreetFighterIISpecialChampionEdition-Genesis'
    env = SubprocVecEnv([make_env(game, state='Champion.Level12.RyuVsBison', seed=i) for i in range(NUM_ENV)])

    try:
        # Linear schedule for the learning rate.
        # Start:
        lr_schedule = linear_schedule(2.5e-4, 2.5e-6)
        # Fine-tune:
        # lr_schedule = linear_schedule(5.0e-5, 2.5e-6)

        # Linear schedule for the PPO clip range.
        # Start:
        clip_range_schedule = linear_schedule(0.15, 0.025)
        # Fine-tune:
        # clip_range_schedule = linear_schedule(0.075, 0.025)

        model = PPO(
            'CnnPolicy',
            env,
            device='cuda',
            verbose=1,
            n_steps=512,
            batch_size=512,
            n_epochs=4,
            gamma=0.94,
            learning_rate=lr_schedule,
            clip_range=clip_range_schedule,
            tensorboard_log=LOG_DIR,
        )

        # Set the save directory
        save_dir = 'trained_models'
        os.makedirs(save_dir, exist_ok=True)

        # To resume/fine-tune from a checkpoint instead, load it and override
        # the schedules:
        # model_path = 'trained_models/ppo_ryu_7000000_steps.zip'
        # custom_objects = {
        #     'learning_rate': lr_schedule,
        #     'clip_range': clip_range_schedule,
        #     'n_steps': 512
        # }
        # model = PPO.load(model_path, env=env, device='cuda', custom_objects=custom_objects)

        # Original author's note: 1 timestep = 6 frames (presumably a frame
        # skip inside StreetFighterCustomWrapper; confirm).
        # CheckpointCallback's save_freq counts VecEnv steps. A checkpoint is
        # therefore written every checkpoint_interval * NUM_ENV total timesteps.
        checkpoint_interval = 31250
        checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix='ppo_ryu')

        # Write the training output from stdout to a file. redirect_stdout
        # restores sys.stdout on exit, even if learn() raises. It exits before
        # the file is closed, so stdout never points at a closed file.
        log_file_path = os.path.join(save_dir, 'training_log.txt')
        with open(log_file_path, 'w') as log_file, redirect_stdout(log_file):
            model.learn(
                total_timesteps=int(100000000),  # summed over all NUM_ENV environments
                callback=[checkpoint_callback],
            )
    finally:
        # Shut down the worker processes even when training is interrupted.
        env.close()

    # Save the final model (only reached if training completed).
    model.save(os.path.join(save_dir, 'ppo_sf2_ryu_final.zip'))

if __name__ == "__main__":
    # Start training only when run as a script, not on import (worker
    # processes started with spawn/forkserver may re-import this module).
    main()
使用 Stable Baselines3 和 Retro 训练 Street Fighter II 冠军版 AI 玩家

原文地址: https://www.cveoy.top/t/topic/nMaK 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录