博客 / 詳情

返回

stable_baseline3 快速入門(一): 訓練第一個強化學習模型

簡介

stable_baseline3 是一個基於 PyTorch 的強化學習算法開源庫,裏面集成了多種強化學習算法,使用這個開源庫能夠讓我們不需要過度關注強化學習算法細節,專注於AI業務的開發。

環境配置

pip install stable-baselines3
pip install gymnasium

這裏stable-baselines3會默認安裝pytroch框架,但是是不帶cuda版本的,這就意味着我們無法利用我們的顯卡對模型進行訓練。
下載cuda版本的pytroch步驟如下:

  1. 卸載原來版本的pytroch框架
pip uninstall torch torchvision torchaudio -y
#這個是針對RTX 30/40/50顯卡的。
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126

如果其他版本請參考官網: https://pytorch.org/get-started/locally/

認識stable_baseline3

stable_baseline3提供了許多模型,如下列表:

名稱 動作空間 建議應用場景 核心優勢
PPO 連續 & 離散 全能選手,如機器人走動、金融交易、遊戲 AI 極其穩定,對超參數不敏感,支持大規模並行訓練。
DQN 僅離散 經典遊戲(Atari)、開關控制、迷宮尋路 理解簡單,在離散控制領域非常經典且有效。
SAC 僅連續 複雜物理模擬、機械臂抓取、自動駕駛 探索效率極高,能自動尋找最優路徑且不輕易陷入局部最優。
TD3 僅連續 工業控制、無人機飛行、精密動作 針對 DDPG 的缺陷做了改進,訓練過程比 SAC 更平滑。
A2C 連續 & 離散 簡單邏輯測試、快速原型驗證 結構簡單,雖然不如 PPO 穩定,但在特定並行環境下速度極快。

聲明模型中,可以設置多種參數,這裏列出常用的:
目前不需要搞懂都有什麼作用,後面有文章會詳細講解

  1. 訓練參數
  • learning_rate:學習率
  • gamma:折扣因子
  • batch_size:更新模型使用數據量
  • verbose:打印信息模式。0-靜默模式,1-信息模式,2-調試模式
  • device:指定訓練設備cuda使用顯卡,cpu使用cpu
  1. 模型規則
  • MlpPolicy:多層感知機。適用於狀態是數值場景(傳感器等)
  • CnnPolicy:卷積神經網絡。適用於狀態是圖像場景(遊戲等)

訓練第一個強化學習模型

案例

案例描述:訓練一個gymnasium默認提供的遊戲環境,平衡杆遊戲。

import gymnasium as gym
from stable_baselines3 import PPO

env = gym.make("CartPole-v1")

model = PPO("MlpPolicy", env, verbose=1, device="cuda")

print("開始訓練...")
model.learn(total_timesteps=10000)

print("正在保存模型...")
model.save("ppo_cartpole")

print("正在讀取模型...")
env = gym.make("CartPole-v1", render_mode="human")
loaded_model = PPO.load("ppo_cartpole", env=env)

print("訓練結束,開始演示...")
obs, _ = env.reset()
for i in range(1000):
    action, _states = loaded_model.predict(obs, deterministic=True)

    obs, reward, terminated, truncated, info = env.step(action)
    
    if terminated or truncated:
        obs, _ = env.reset()

env.close()

代碼解釋

代碼流程如下:
初始化環境模型->訓練模型->保存模型->加載模型->模型預測

初始化環境模型

初始化模型以及遊戲的環境

env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env, verbose=1, device="cuda")

env = gym.make("CartPole-v1", render_mode="human")
  • gym中的make方法利用默認的遊戲環境,CartPole-v1是遊戲名,下面有一個render_mode="human"參數,用於標識是否展示畫面。訓練時展示畫面會降低訓練的速度,一般在預測時才使用
訓練模型
model.learn(total_timesteps=10000)
  • total_timesteps:訓練10000次
保存模型
model.save("ppo_cartpole")
  • "ppo_cartpole" 為保存模型的名字,這裏是保存在當前文件夾中。
加載模型
loaded_model = PPO.load("ppo_cartpole", env=env)
  • 第一個參數:剛剛保存的模型路徑
  • 第二個參數:訓練的環境
模型預測
obs, _ = env.reset()
for i in range(1000):
    action, _states = loaded_model.predict(obs, deterministic=True)

    obs, reward, terminated, truncated, info = env.step(action)
    
    if terminated or truncated:
        obs, _ = env.reset()
  • env.reset()重置環境,返回初始觀測值obsinfo(這裏沒用到)
  • 模型的predict方法用於根據觀測值obs預測下一步行動。注意:deterministic參數要為True,不然會報錯
  • 模型的step方法根據行動值返回結果。(這些都是什麼後面文章會講)

如果❤喜歡❤本系列教程,就點個關注吧,後續不定期更新~

user avatar
0 位用戶收藏了這個故事!

發佈 評論

Some HTML is okay.