簡介
stable_baseline3 是一個基於 PyTorch 的強化學習算法開源庫,裏面集成了多種強化學習算法,使用這個開源庫能夠讓我們不需要過度關注強化學習算法細節,專注於AI業務的開發。
環境配置
pip install stable-baselines3
pip install gymnasium
這裏stable-baselines3會默認安裝pytroch框架,但是是不帶cuda版本的,這就意味着我們無法利用我們的顯卡對模型進行訓練。
下載cuda版本的pytroch步驟如下:
- 卸載原來版本的
pytroch框架
pip uninstall torch torchvision torchaudio -y
#這個是針對RTX 30/40/50顯卡的。
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
如果其他版本請參考官網: https://pytorch.org/get-started/locally/
認識stable_baseline3
stable_baseline3提供了許多模型,如下列表:
| 名稱 | 動作空間 | 建議應用場景 | 核心優勢 |
|---|---|---|---|
| PPO | 連續 & 離散 | 全能選手,如機器人走動、金融交易、遊戲 AI | 極其穩定,對超參數不敏感,支持大規模並行訓練。 |
| DQN | 僅離散 | 經典遊戲(Atari)、開關控制、迷宮尋路 | 理解簡單,在離散控制領域非常經典且有效。 |
| SAC | 僅連續 | 複雜物理模擬、機械臂抓取、自動駕駛 | 探索效率極高,能自動尋找最優路徑且不輕易陷入局部最優。 |
| TD3 | 僅連續 | 工業控制、無人機飛行、精密動作 | 針對 DDPG 的缺陷做了改進,訓練過程比 SAC 更平滑。 |
| A2C | 連續 & 離散 | 簡單邏輯測試、快速原型驗證 | 結構簡單,雖然不如 PPO 穩定,但在特定並行環境下速度極快。 |
在聲明模型中,可以設置多種參數,這裏列出常用的:
目前不需要搞懂都有什麼作用,後面有文章會詳細講解
- 訓練參數
learning_rate:學習率gamma:折扣因子batch_size:更新模型使用數據量verbose:打印信息模式。0-靜默模式,1-信息模式,2-調試模式device:指定訓練設備cuda使用顯卡,cpu使用cpu
- 模型規則
MlpPolicy:多層感知機。適用於狀態是數值場景(傳感器等)CnnPolicy:卷積神經網絡。適用於狀態是圖像場景(遊戲等)
訓練第一個強化學習模型
案例
案例描述:訓練一個gymnasium默認提供的遊戲環境,平衡杆遊戲。
import gymnasium as gym
from stable_baselines3 import PPO
env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env, verbose=1, device="cuda")
print("開始訓練...")
model.learn(total_timesteps=10000)
print("正在保存模型...")
model.save("ppo_cartpole")
print("正在讀取模型...")
env = gym.make("CartPole-v1", render_mode="human")
loaded_model = PPO.load("ppo_cartpole", env=env)
print("訓練結束,開始演示...")
obs, _ = env.reset()
for i in range(1000):
action, _states = loaded_model.predict(obs, deterministic=True)
obs, reward, terminated, truncated, info = env.step(action)
if terminated or truncated:
obs, _ = env.reset()
env.close()
代碼解釋
代碼流程如下:
初始化環境模型->訓練模型->保存模型->加載模型->模型預測
初始化環境模型
初始化模型以及遊戲的環境
env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env, verbose=1, device="cuda")
env = gym.make("CartPole-v1", render_mode="human")
gym中的make方法利用默認的遊戲環境,CartPole-v1是遊戲名,下面有一個render_mode="human"參數,用於標識是否展示畫面。訓練時展示畫面會降低訓練的速度,一般在預測時才使用
訓練模型
model.learn(total_timesteps=10000)
total_timesteps:訓練10000次
保存模型
model.save("ppo_cartpole")
"ppo_cartpole"為保存模型的名字,這裏是保存在當前文件夾中。
加載模型
loaded_model = PPO.load("ppo_cartpole", env=env)
- 第一個參數:剛剛保存的模型路徑
- 第二個參數:訓練的環境
模型預測
obs, _ = env.reset()
for i in range(1000):
action, _states = loaded_model.predict(obs, deterministic=True)
obs, reward, terminated, truncated, info = env.step(action)
if terminated or truncated:
obs, _ = env.reset()
env.reset()重置環境,返回初始觀測值obs和info(這裏沒用到)- 模型的
predict方法用於根據觀測值obs預測下一步行動。注意:deterministic參數要為True,不然會報錯 - 模型的
step方法根據行動值返回結果。(這些都是什麼後面文章會講)
如果❤喜歡❤本系列教程,就點個關注吧,後續不定期更新~