简介
stable_baseline3 是一个基于 PyTorch 的强化学习算法开源库,里面集成了多种强化学习算法,使用这个开源库能够让我们不需要过度关注强化学习算法细节,专注于AI业务的开发。
环境配置
- pip install stable-baselines3
- pip install gymnasium
复制代码 这里stable-baselines3会默认安装pytroch框架,但是是不带cuda版本的,这就意味着我们无法利用我们的显卡对模型进行训练。
下载cuda版本的pytroch步骤如下:
- pip uninstall torch torchvision torchaudio -y
- #这个是针对RTX 30/40/50显卡的。
- pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
复制代码 如果其他版本请参考官网: https://pytorch.org/get-started/locally/
认识stable_baseline3
stable_baseline3提供了许多模型,如下列表:
名称动作空间建议应用场景核心优势PPO连续 & 离散全能选手,如机器人走动、金融交易、游戏 AI极其稳定,对超参数不敏感,支持大规模并行训练。DQN仅离散经典游戏(Atari)、开关控制、迷宫寻路理解简单,在离散控制领域非常经典且有效。SAC仅连续复杂物理模拟、机械臂抓取、自动驾驶探索效率极高,能自动寻找最优路径且不轻易陷入局部最优。TD3仅连续工业控制、无人机飞行、精密动作针对 DDPG 的缺陷做了改进,训练过程比 SAC 更平滑。A2C连续 & 离散简单逻辑测试、快速原型验证结构简单,虽然不如 PPO 稳定,但在特定并行环境下速度极快。在声明模型中,可以设置多种参数,这里列出常用的:
目前不需要搞懂都有什么作用,后面有文章会详细讲解
- learning_rate:学习率
- gamma:折扣因子
- batch_size:更新模型使用数据量
- verbose:打印信息模式。0-静默模式,1-信息模式,2-调试模式
- device:指定训练设备cuda使用显卡,cpu使用cpu
- MlpPolicy:多层感知机。适用于状态是数值场景(传感器等)
- CnnPolicy:卷积神经网络。适用于状态是图像场景(游戏等)
训练第一个强化学习模型
案例
案例描述:训练一个gymnasium默认提供的游戏环境,平衡杆游戏。- import gymnasium as gym
- from stable_baselines3 import PPO
- env = gym.make("CartPole-v1")
- model = PPO("MlpPolicy", env, verbose=1, device="cuda")
- print("开始训练...")
- model.learn(total_timesteps=10000)
- print("正在保存模型...")
- model.save("ppo_cartpole")
- print("正在读取模型...")
- env = gym.make("CartPole-v1", render_mode="human")
- loaded_model = PPO.load("ppo_cartpole", env=env)
- print("训练结束,开始演示...")
- obs, _ = env.reset()
- for i in range(1000):
- action, _states = loaded_model.predict(obs, deterministic=True)
- obs, reward, terminated, truncated, info = env.step(action)
-
- if terminated or truncated:
- obs, _ = env.reset()
- env.close()
复制代码 代码解释
代码流程如下:
初始化环境模型->训练模型->保存模型->加载模型->模型预测
初始化环境模型
初始化模型以及游戏的环境- env = gym.make("CartPole-v1")
- model = PPO("MlpPolicy", env, verbose=1, device="cuda")
- env = gym.make("CartPole-v1", render_mode="human")
复制代码
- gym中的make方法利用默认的游戏环境,CartPole-v1是游戏名,下面有一个render_mode="human"参数,用于标识是否展示画面。训练时展示画面会降低训练的速度,一般在预测时才使用
训练模型
- model.learn(total_timesteps=10000)
复制代码 保存模型
- model.save("ppo_cartpole")
复制代码
- "ppo_cartpole" 为保存模型的名字,这里是保存在当前文件夹中。
加载模型
- loaded_model = PPO.load("ppo_cartpole", env=env)
复制代码
- 第一个参数:刚刚保存的模型路径
- 第二个参数:训练的环境
模型预测
- obs, _ = env.reset()
- for i in range(1000):
- action, _states = loaded_model.predict(obs, deterministic=True)
- obs, reward, terminated, truncated, info = env.step(action)
-
- if terminated or truncated:
- obs, _ = env.reset()
复制代码
- env.reset()重置环境,返回初始观测值obs和info(这里没用到)
- 模型的predict方法用于根据观测值obs预测下一步行动。注意:deterministic参数要为True,不然会报错
- 模型的step方法根据行动值返回结果。(这些都是什么后面文章会讲)
如果❤喜欢❤本系列教程,就点个关注吧,后续不定期更新~
来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作! |