找回密码
 立即注册
首页 业界区 业界 stable_baseline3 快速入门(一): 训练第一个强化学习模 ...

stable_baseline3 快速入门(一): 训练第一个强化学习模型

轧岔 4 天前
简介

stable_baseline3 是一个基于 PyTorch 的强化学习算法开源库,里面集成了多种强化学习算法,使用这个开源库能够让我们不需要过度关注强化学习算法细节,专注于AI业务的开发。
环境配置
  1. pip install stable-baselines3
  2. pip install gymnasium
复制代码
这里stable-baselines3会默认安装pytroch框架,但是是不带cuda版本的,这就意味着我们无法利用我们的显卡对模型进行训练。
下载cuda版本的pytroch步骤如下:

  • 卸载原来版本的pytroch框架
  1. pip uninstall torch torchvision torchaudio -y
  2. #这个是针对RTX 30/40/50显卡的。
  3. pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
复制代码
如果其他版本请参考官网: https://pytorch.org/get-started/locally/
认识stable_baseline3

stable_baseline3提供了许多模型,如下列表:
名称动作空间建议应用场景核心优势PPO连续 & 离散全能选手,如机器人走动、金融交易、游戏 AI极其稳定,对超参数不敏感,支持大规模并行训练。DQN仅离散经典游戏(Atari)、开关控制、迷宫寻路理解简单,在离散控制领域非常经典且有效。SAC仅连续复杂物理模拟、机械臂抓取、自动驾驶探索效率极高,能自动寻找最优路径且不轻易陷入局部最优。TD3仅连续工业控制、无人机飞行、精密动作针对 DDPG 的缺陷做了改进,训练过程比 SAC 更平滑。A2C连续 & 离散简单逻辑测试、快速原型验证结构简单,虽然不如 PPO 稳定,但在特定并行环境下速度极快。在声明模型中,可以设置多种参数,这里列出常用的:
目前不需要搞懂都有什么作用,后面有文章会详细讲解

  • 训练参数


  • learning_rate:学习率
  • gamma:折扣因子
  • batch_size:更新模型使用数据量
  • verbose:打印信息模式。0-静默模式,1-信息模式,2-调试模式
  • device:指定训练设备cuda使用显卡,cpu使用cpu

  • 模型规则


  • MlpPolicy:多层感知机。适用于状态是数值场景(传感器等)
  • CnnPolicy:卷积神经网络。适用于状态是图像场景(游戏等)
训练第一个强化学习模型

案例

案例描述:训练一个gymnasium默认提供的游戏环境,平衡杆游戏。
  1. import gymnasium as gym
  2. from stable_baselines3 import PPO
  3. env = gym.make("CartPole-v1")
  4. model = PPO("MlpPolicy", env, verbose=1, device="cuda")
  5. print("开始训练...")
  6. model.learn(total_timesteps=10000)
  7. print("正在保存模型...")
  8. model.save("ppo_cartpole")
  9. print("正在读取模型...")
  10. env = gym.make("CartPole-v1", render_mode="human")
  11. loaded_model = PPO.load("ppo_cartpole", env=env)
  12. print("训练结束,开始演示...")
  13. obs, _ = env.reset()
  14. for i in range(1000):
  15.     action, _states = loaded_model.predict(obs, deterministic=True)
  16.     obs, reward, terminated, truncated, info = env.step(action)
  17.    
  18.     if terminated or truncated:
  19.         obs, _ = env.reset()
  20. env.close()
复制代码
代码解释

代码流程如下:
初始化环境模型->训练模型->保存模型->加载模型->模型预测
初始化环境模型

初始化模型以及游戏的环境
  1. env = gym.make("CartPole-v1")
  2. model = PPO("MlpPolicy", env, verbose=1, device="cuda")
  3. env = gym.make("CartPole-v1", render_mode="human")
复制代码

  • gym中的make方法利用默认的游戏环境,CartPole-v1是游戏名,下面有一个render_mode="human"参数,用于标识是否展示画面。训练时展示画面会降低训练的速度,一般在预测时才使用
训练模型
  1. model.learn(total_timesteps=10000)
复制代码

  • total_timesteps:训练10000次
保存模型
  1. model.save("ppo_cartpole")
复制代码

  • "ppo_cartpole" 为保存模型的名字,这里是保存在当前文件夹中。
加载模型
  1. loaded_model = PPO.load("ppo_cartpole", env=env)
复制代码

  • 第一个参数:刚刚保存的模型路径
  • 第二个参数:训练的环境
模型预测
  1. obs, _ = env.reset()
  2. for i in range(1000):
  3.     action, _states = loaded_model.predict(obs, deterministic=True)
  4.     obs, reward, terminated, truncated, info = env.step(action)
  5.    
  6.     if terminated or truncated:
  7.         obs, _ = env.reset()
复制代码

  • env.reset()重置环境,返回初始观测值obs和info(这里没用到)
  • 模型的predict方法用于根据观测值obs预测下一步行动。注意:deterministic参数要为True,不然会报错
  • 模型的step方法根据行动值返回结果。(这些都是什么后面文章会讲)
如果❤喜欢❤本系列教程,就点个关注吧,后续不定期更新~

来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

相关推荐

您需要登录后才可以回帖 登录 | 立即注册