Introduction
Gymnasium provides a standardized API for reinforcement learning. It defines how an agent observes the world, takes actions, and receives rewards. Whether the target is a game or a piece of industrial equipment, anything that conforms to the Gymnasium standard can be trained with the same code.
Getting to Know Gymnasium
With stable_baselines3, you only need to define a Gymnasium environment and focus on the reward mechanism for training, putting the emphasis on your application logic rather than on complex algorithms.
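To make that concrete, here is a minimal sketch (not part of this article's project) that trains PPO on the built-in CartPole-v1 environment; any environment that satisfies the Gymnasium API can be dropped in the same way:

```python
import gymnasium as gym
from stable_baselines3 import PPO

# Any Gymnasium-compliant environment works here; CartPole-v1 is just a stand-in.
env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)  # short run, for illustration only
model.save("ppo_cartpole")           # hypothetical save path
```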
Gymnasium provides a few core APIs:
| Method | Purpose | Returns |
|---|---|---|
| reset() | Resets the environment to its initial state and starts a new episode. | obs, info |
| step(action) | Advances the environment one step by executing an action. | obs, reward, terminated, truncated, info |
| render() | Visualizes the environment (renders an image or opens a window, depending on render_mode). | Depends on configuration (usually None or an np.array) |
| close() | Releases environment resources (closes windows, frees memory). | None |
The meaning of each return value:

- observation (Object): a description of the current state, e.g. enemy and player positions, the player's status, and so on.
- reward (Float): the reward obtained for the previous action.
- terminated (Bool): whether the episode ended because of task logic, e.g. reaching the goal or falling into lava.
- truncated (Bool): whether the episode ended because of an external limit, e.g. hitting a maximum of 500 steps.
- info (Dict): auxiliary diagnostic information; usually not used for training, but handy for custom debugging or extra statistics.
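As a quick illustration of these return values, here is a minimal interaction loop using the built-in CartPole-v1 environment (a random policy, purely to show the API; the same loop works unchanged for any Gymnasium environment):

```python
import gymnasium as gym

env = gym.make("CartPole-v1")
obs, info = env.reset(seed=42)
for _ in range(1000):
    action = env.action_space.sample()  # random action, for illustration
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:         # ended by task logic or by a step limit
        obs, info = env.reset()
env.close()
```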
Building an Environment by Hand
Example
Example description: build a simple game with pygame in which the player dodges falling blocks, then apply reinforcement learning using the reward mechanism we construct.
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import pygame
import random
import cv2
import os
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.env_checker import check_env

class MyEnv(gym.Env):
    def __init__(self, render_mode=None):
        super().__init__()
        # Initialize parameters
        self.width = 400
        self.height = 300
        self.player_size = 30
        self.enemy_size = 30
        self.render_mode = render_mode
        # Three discrete actions: 0 = stay, 1 = move left, 2 = move right
        self.action_space = spaces.Discrete(3)
        # The agent observes an 84x84 RGB image
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(84, 84, 3), dtype=np.uint8
        )
        pygame.init()
        if self.render_mode == "human":
            self.screen = pygame.display.set_mode((self.width, self.height))
        self.canvas = pygame.Surface((self.width, self.height))
        self.font = pygame.font.SysFont("monospace", 15)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.player_x = self.width // 2 - self.player_size // 2
        self.player_y = self.height - self.player_size - 10
        self.enemies = []
        self.score = 0
        self.frame_count = 0
        self.current_speed = 5
        self.spawn_rate = 30
        return self._get_obs(), {}

    def step(self, action):
        reward = 0
        terminated = False
        truncated = False
        move_speed = 8
        if action == 1 and self.player_x > 0:  # move left
            self.player_x -= move_speed
            reward -= 0.05  # small penalty for moving
        if action == 2 and self.player_x < self.width - self.player_size:  # move right
            self.player_x += move_speed
            reward -= 0.05  # small penalty for moving
        self.frame_count += 1
        # Difficulty scales with the score: faster enemies, more frequent spawns
        level = self.score // 5
        self.current_speed = 5 + level
        self.spawn_rate = max(10, 30 - level)
        if self.frame_count >= self.spawn_rate:
            self.frame_count = 0
            enemy_x = random.randint(0, self.width - self.enemy_size)
            self.enemies.append([enemy_x, 0])  # [x, y]
        # Iterate over a copy so that removing an enemy does not skip elements
        for enemy in list(self.enemies):
            enemy[1] += self.current_speed
            player_rect = pygame.Rect(self.player_x, self.player_y, self.player_size, self.player_size)
            enemy_rect = pygame.Rect(enemy[0], enemy[1], self.enemy_size, self.enemy_size)
            if player_rect.colliderect(enemy_rect):
                reward = -10  # collision ends the episode
                terminated = True
            elif enemy[1] > self.height:
                self.enemies.remove(enemy)
                self.score += 1
                reward = 1  # successfully dodged a block
        if not terminated:
            if self.score > 100:
                reward += 0.01  # extra survival bonus late in the game
            reward += 0.01  # base survival bonus per step
        obs = self._get_obs()
        if self.render_mode == "human":
            self._render_window()
        return obs, reward, terminated, truncated, {}

    def _get_obs(self):
        self.canvas.fill((0, 0, 0))
        pygame.draw.rect(self.canvas, (50, 150, 255), (self.player_x, self.player_y, self.player_size, self.player_size))
        for enemy in self.enemies:
            pygame.draw.rect(self.canvas, (255, 50, 50), (enemy[0], enemy[1], self.enemy_size, self.enemy_size))
        img_array = pygame.surfarray.array3d(self.canvas)
        # pygame gives (width, height, 3); transpose to the usual (height, width, 3)
        img_array = np.transpose(img_array, (1, 0, 2))
        obs = cv2.resize(img_array, (84, 84), interpolation=cv2.INTER_AREA)
        return obs.astype(np.uint8)

    def _render_window(self):
        self.screen.blit(self.canvas, (0, 0))
        text = self.font.render(f"Score: {self.score}", True, (255, 255, 255))
        self.screen.blit(text, (10, 10))
        pygame.display.flip()
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                raise SystemExit

    def close(self):
        # Release pygame resources when the environment is closed
        pygame.quit()

def train():
    log_dir = "logs/DodgeGame"
    os.makedirs(log_dir, exist_ok=True)
    env = MyEnv()
    check_env(env)
    print("Environment check passed...")
    model_path = "models/dodge_ai.zip"
    os.makedirs("models", exist_ok=True)  # make sure the save directory exists
    if not os.path.exists(model_path):
        print("🆕 No existing model found; training from scratch...")
        model = PPO(
            "CnnPolicy",
            env,
            verbose=1,
            tensorboard_log=log_dir,
            learning_rate=0.0001,
            n_steps=4096,
            batch_size=256,
            device="cuda")
        reset_timesteps = True
    else:
        print("Existing model found; loading it and continuing training...")
        model = PPO.load(
            model_path,
            env=env,
            device="cuda",
            custom_objects={"learning_rate": 0.0001, "n_steps": 4096, "batch_size": 256}
        )
        reset_timesteps = False
    print("Starting training...")
    model.learn(
        total_timesteps=50000,
        reset_num_timesteps=reset_timesteps
    )
    model.save("models/dodge_ai")
    print("Model saved!")
    env.close()

def predict():
    env = MyEnv(render_mode="human")
    model = PPO.load("models/dodge_ai", env=env, device="cuda")
    obs, _ = env.reset()
    clock = pygame.time.Clock()  # create the clock once, outside the loop
    while True:
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        if terminated or truncated:
            obs, _ = env.reset()
        clock.tick(30)

if __name__ == "__main__":
    train()
    predict()
Code Walkthrough
The code flow is as follows:
build the game environment -> train the model -> predict with the model
This article focuses on building the game environment, so the pygame-related code is only skimmed; the other two stages are covered in earlier articles.
Building the Game Environment
Initializing the Class
The class inherits from gym.Env:
class MyEnv(gym.Env):
The constructor __init__
def __init__(self, render_mode=None):
    super().__init__()
    # Initialize parameters
    self.width = 400
    self.height = 300
    self.player_size = 30
    self.enemy_size = 30
    self.render_mode = render_mode
    # Three discrete actions: 0 = stay, 1 = move left, 2 = move right
    self.action_space = spaces.Discrete(3)
    # The agent observes an 84x84 RGB image
    self.observation_space = spaces.Box(
        low=0, high=255, shape=(84, 84, 3), dtype=np.uint8
    )
    pygame.init()
    if self.render_mode == "human":
        self.screen = pygame.display.set_mode((self.width, self.height))
    self.canvas = pygame.Surface((self.width, self.height))
    self.font = pygame.font.SysFont("monospace", 15)
In the constructor, we mainly declare the training dimensions and the inputs:

- Input: `self.action_space` is a fixed attribute name from the parent class. `spaces.Discrete(3)` declares the number of possible inputs, e.g. three actions: move left, move right, and stay still.
- Observation dimensions: `self.observation_space` is also a fixed parent-class attribute name; `spaces.Box` declares the observation dimensions.
self.observation_space = spaces.Box(
    low=0, high=255, shape=(84, 84, 3), dtype=np.uint8
)
- low: the minimum value of each observed element
- high: the maximum value of each observed element
- shape: declares the dimensions; e.g. for an image observation, shape is (height, width, RGB channels); for a flat plane, (height, width)
- dtype: the type of each element; choosing np.uint8 here saves training cost (the default is floating point)
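For comparison, here is a small sketch of two possible observation-space declarations: the image space used in this article, and a hypothetical low-dimensional vector alternative:

```python
from gymnasium import spaces
import numpy as np

# The image observation used in this article: 84x84 RGB, one byte per channel.
image_obs = spaces.Box(low=0, high=255, shape=(84, 84, 3), dtype=np.uint8)

# A hypothetical vector alternative: player x plus the nearest enemy's (x, y),
# normalized to [0, 1]; cheaper to learn from than raw pixels.
vector_obs = spaces.Box(low=0.0, high=1.0, shape=(3,), dtype=np.float32)

print(image_obs.sample().shape)  # (84, 84, 3)
print(vector_obs.sample())       # e.g. [0.42 0.17 0.88]
```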
Episode Reset: reset
This amounts to reinitializing the game state, i.e. restarting the game. It returns the observation and an info dict (used for debug logging).
def reset(self, seed=None, options=None):
    super().reset(seed=seed)
    self.player_x = self.width // 2 - self.player_size // 2
    self.player_y = self.height - self.player_size - 10
    self.enemies = []
    self.score = 0
    self.frame_count = 0
    self.current_speed = 5
    self.spawn_rate = 30
    return self._get_obs(), {}
The observation _get_obs:
The frame is drawn with pygame and then given some simple processing with opencv:

- Transpose the axes (opencv's x/y axes are swapped relative to pygame's).
- Resize the frame to 84 × 84 (this improves training efficiency).
def _get_obs(self):
    self.canvas.fill((0, 0, 0))
    pygame.draw.rect(self.canvas, (50, 150, 255), (self.player_x, self.player_y, self.player_size, self.player_size))
    for enemy in self.enemies:
        pygame.draw.rect(self.canvas, (255, 50, 50), (enemy[0], enemy[1], self.enemy_size, self.enemy_size))
    img_array = pygame.surfarray.array3d(self.canvas)
    # pygame gives (width, height, 3); transpose to the usual (height, width, 3)
    img_array = np.transpose(img_array, (1, 0, 2))
    obs = cv2.resize(img_array, (84, 84), interpolation=cv2.INTER_AREA)
    return obs.astype(np.uint8)
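If you want to cut the observation cost further, a common trick (not used in this article's code) is to collapse the frame to a single grayscale channel; a sketch, assuming observation_space is changed to shape (84, 84, 1) to match:

```python
import cv2
import numpy as np

def to_gray_obs(img_array: np.ndarray) -> np.ndarray:
    """Resize to 84x84 and collapse RGB to one grayscale channel."""
    gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)           # (H, W)
    small = cv2.resize(gray, (84, 84), interpolation=cv2.INTER_AREA)
    return small[..., np.newaxis].astype(np.uint8)               # (84, 84, 1)
```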
Step: step (important)
This function is the core of reinforcement learning training: it defines the score we give the AI for each frame, i.e. each step.
The reward design is critical; it directly determines the quality of the trained AI.
In the code below (most of which is game logic), the key points are the reward settings:

- A penalty of 0.05 whenever the AI moves
- A reward of 0.01 per step while the AI is alive, and a survival reward of 0.02 once the game score exceeds 100
- A reward of 1 each time a block falls all the way off the bottom
- A penalty of 10 when colliding with a block
def step(self, action):
    reward = 0
    terminated = False
    truncated = False
    move_speed = 8
    if action == 1 and self.player_x > 0:  # move left
        self.player_x -= move_speed
        reward -= 0.05  # small penalty for moving
    if action == 2 and self.player_x < self.width - self.player_size:  # move right
        self.player_x += move_speed
        reward -= 0.05  # small penalty for moving
    self.frame_count += 1
    # Difficulty scales with the score: faster enemies, more frequent spawns
    level = self.score // 5
    self.current_speed = 5 + level
    self.spawn_rate = max(10, 30 - level)
    if self.frame_count >= self.spawn_rate:
        self.frame_count = 0
        enemy_x = random.randint(0, self.width - self.enemy_size)
        self.enemies.append([enemy_x, 0])  # [x, y]
    # Iterate over a copy so that removing an enemy does not skip elements
    for enemy in list(self.enemies):
        enemy[1] += self.current_speed
        player_rect = pygame.Rect(self.player_x, self.player_y, self.player_size, self.player_size)
        enemy_rect = pygame.Rect(enemy[0], enemy[1], self.enemy_size, self.enemy_size)
        if player_rect.colliderect(enemy_rect):
            reward = -10  # collision ends the episode
            terminated = True
        elif enemy[1] > self.height:
            self.enemies.remove(enemy)
            self.score += 1
            reward = 1  # successfully dodged a block
    if not terminated:
        if self.score > 100:
            reward += 0.01  # extra survival bonus late in the game
        reward += 0.01  # base survival bonus per step
    obs = self._get_obs()
    if self.render_mode == "human":
        self._render_window()
    return obs, reward, terminated, truncated, {}
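Before committing to a long training run, it can help to sanity-check the reward scale with a random agent; a minimal sketch (episodes end quickly here, since a random player soon collides with a block):

```python
# Roll out a few random episodes to eyeball step counts and reward totals.
env = MyEnv()
for episode in range(3):
    obs, info = env.reset()
    total_reward, steps = 0.0, 0
    terminated = truncated = False
    while not (terminated or truncated):
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        total_reward += reward
        steps += 1
    print(f"episode {episode}: steps={steps}, total_reward={total_reward:.2f}")
env.close()
```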
Displaying the Game Window
The code below is pure pygame for displaying the game window, so we will not explain it here.
def _render_window(self):
    self.screen.blit(self.canvas, (0, 0))
    text = self.font.render(f"Score: {self.score}", True, (255, 255, 255))
    self.screen.blit(text, (10, 10))
    pygame.display.flip()
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            pygame.quit()
            raise SystemExit
You have now officially become a hyperparameter tuner. Come give it a try!
If you ❤ enjoy ❤ this tutorial series, hit follow; more installments will arrive from time to time~