stable_baselines3 Quick Start (Part 2): Training a Custom Game by Building a Gymnasium Environment

Introduction

Gymnasium provides a standardized API for reinforcement learning. It defines how an agent observes the world, how it takes actions, and how it receives rewards. Whether the task is a game or a piece of industrial equipment, anything that conforms to the Gymnasium standard can be trained with the same code.

Getting to Know Gymnasium

With stable_baselines3, you only need to define a Gymnasium environment and focus on the training reward scheme, putting the emphasis on your application logic rather than on complex algorithms.
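
As a quick illustration, here is a minimal sketch (using the built-in CartPole-v1 environment as a stand-in) of how little code training takes once an environment speaks the Gymnasium API:

import gymnasium as gym
from stable_baselines3 import PPO

# Any environment that follows the Gymnasium API can be trained this way;
# CartPole-v1 is just a built-in example.
env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)
env.close()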

Gymnasium provides a few core API methods:

Method | Purpose | Return value
reset() | Resets the environment to its initial state and starts a new episode. | obs, info
step(action) | Advances the environment one step by executing the given action. | obs, reward, terminated, truncated, info
render() | Visualizes the environment (renders an image or opens a window, depending on render_mode). | Depends on configuration (usually None or an np.array)
close() | Releases environment resources (closes windows, frees memory). | None

The meanings of these return values (see the interaction loop sketch after this list):

  • observation (Object): a description of the current state, e.g. enemy and player positions, the player's status, etc.
  • reward (Float): the reward earned by the previous action
  • terminated (Bool): whether the episode ended due to task logic, e.g. reaching the goal or falling into lava
  • truncated (Bool): whether the episode ended due to an external limit, e.g. hitting a maximum of 500 steps
  • info (Dict): auxiliary diagnostic information; usually unused during training, but handy for custom debugging or recording extra statistics
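
Putting these together, a minimal interaction loop (a sketch using the built-in CartPole-v1 environment and a random agent) looks like this:

import gymnasium as gym

env = gym.make("CartPole-v1")
obs, info = env.reset()

for _ in range(1000):
    action = env.action_space.sample()  # random action, just to exercise the API
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:  # task logic ended the episode, or a limit was hit
        obs, info = env.reset()
env.close()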

Building an Environment by Hand

Example

Example description: build a simple dodge-the-falling-blocks game with pygame, then apply reinforcement learning using a hand-crafted reward scheme.

import gymnasium as gym
from gymnasium import spaces
import numpy as np
import pygame
import random
import cv2
import os
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.env_checker import check_env

class MyEnv(gym.Env):
    def __init__(self, render_mode=None):
        super(MyEnv, self).__init__()

        # Initialize parameters
        self.width = 400
        self.height = 300
        self.player_size = 30
        self.enemy_size = 30
        self.render_mode = render_mode

        self.action_space = spaces.Discrete(3)

        self.observation_space = spaces.Box(
            low=0, high=255, shape=(84, 84, 3), dtype=np.uint8
        )

        pygame.init()
        if self.render_mode == "human":
            self.screen = pygame.display.set_mode((self.width, self.height))
        
        self.canvas = pygame.Surface((self.width, self.height))
        self.font = pygame.font.SysFont("monospace", 15)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)

        self.player_x = self.width // 2 - self.player_size // 2
        self.player_y = self.height - self.player_size - 10
        self.enemies = []
        self.score = 0
        self.frame_count = 0

        self.current_speed = 5
        self.spawn_rate = 30

        return self._get_obs(), {}

    def step(self, action):
        reward = 0
        terminated = False
        truncated = False

        move_speed = 8
        if action == 1 and self.player_x > 0:  # move left
            self.player_x -= move_speed
            reward -= 0.05

        if action == 2 and self.player_x < self.width - self.player_size:  # move right
            self.player_x += move_speed
            reward -= 0.05

        self.frame_count += 1

        # Difficulty scaling: one level per 5 points of score
        level = self.score // 5
        self.current_speed = 5 + level
        # Spawn faster as the level rises, but never more often than every 10 frames
        self.spawn_rate = max(10, 30 - level)

        if self.frame_count >= self.spawn_rate:
            self.frame_count = 0
            enemy_x = random.randint(0, self.width - self.enemy_size)
            self.enemies.append([enemy_x, 0]) # [x, y]

        # Iterate over a copy, since we may remove enemies mid-loop
        for enemy in self.enemies[:]:
            enemy[1] += self.current_speed

            player_rect = pygame.Rect(self.player_x, self.player_y, self.player_size, self.player_size)
            enemy_rect = pygame.Rect(enemy[0], enemy[1], self.enemy_size, self.enemy_size)

            if player_rect.colliderect(enemy_rect):
                reward = -10
                terminated = True
                break  # stop here so a later enemy cannot overwrite the penalty

            elif enemy[1] > self.height:
                self.enemies.remove(enemy)
                self.score += 1
                reward = 1
        
        if not terminated:
            if self.score > 100:
                reward += 0.01  # extra survival bonus at high scores
            reward += 0.01  # base survival bonus for every step alive

        obs = self._get_obs()

        if self.render_mode == "human":
            self._render_window()

        return obs, reward, terminated, truncated, {}

    def _get_obs(self):
        self.canvas.fill((0, 0, 0))
        pygame.draw.rect(self.canvas, (50, 150, 255), (self.player_x, self.player_y, self.player_size, self.player_size))
        
        for enemy in self.enemies:
            pygame.draw.rect(self.canvas, (255, 50, 50), (enemy[0], enemy[1], self.enemy_size, self.enemy_size))

        img_array = pygame.surfarray.array3d(self.canvas)
        img_array = np.transpose(img_array, (1, 0, 2))
        obs = cv2.resize(img_array, (84, 84), interpolation=cv2.INTER_AREA)

        return obs.astype(np.uint8)

    def _render_window(self):
        self.screen.blit(self.canvas, (0, 0))
        text = self.font.render(f"Score: {self.score}", True, (255, 255, 255))
        self.screen.blit(text, (10, 10))
        pygame.display.flip()

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                raise SystemExit  # stop cleanly; pygame calls would fail after quit()

def train():
    log_dir = "logs/DodgeGame"
    os.makedirs(log_dir, exist_ok=True)

    env = MyEnv()
    check_env(env)
    print("環境檢查通過...")

    model_path = "models/dodge_ai.zip"
    if not os.path.exists(model_path):
        print("🆕 未發現舊模型,從頭開始訓練...")
        model = PPO(
            "CnnPolicy", 
            env, 
            verbose=1,
            tensorboard_log=log_dir,
            learning_rate=0.0001,
            n_steps=4096,
            batch_size=256,
            device="cuda")
        reset_timesteps = True
    else:
        print("發現舊模型,加載並繼續訓練...")
        model = PPO.load(
            model_path, 
            env=env,      
            device="cuda",
            custom_objects={"learning_rate": 0.0001, "n_steps": 4096, "batch_size": 256}
        )
        reset_timesteps = False
   
    print("開始訓練...")

    model.learn(
        total_timesteps=50000,
        reset_num_timesteps=reset_timesteps
    )

    os.makedirs("models", exist_ok=True)  # make sure the save directory exists
    model.save("models/dodge_ai")
    print("Model saved!")
    env.close()

def predict():
    env = MyEnv(render_mode="human")
    model = PPO.load("models/dodge_ai", env=env, device="cuda")
    obs, _ = env.reset()

    clock = pygame.time.Clock()
    while True:
        action, _states = model.predict(obs, deterministic=True)

        obs, reward, terminated, truncated, info = env.step(action)

        if terminated or truncated:
            obs, _ = env.reset()

        clock.tick(30)  # cap the loop at 30 FPS (a fresh Clock each frame would not)

if __name__ == "__main__":
    train()

    predict()

Code Walkthrough

The overall flow is:
build the game environment -> train the model -> run the trained model
This post focuses on building the game environment, so the pygame-specific code is only skimmed; the other two stages were covered in the previous posts of this series.

Building the Game Environment

Class Initialization

The class inherits from gym.Env:

class MyEnv(gym.Env):

Constructor: __init__

def __init__(self, render_mode=None):
        super(MyEnv, self).__init__()

        # Initialize parameters
        self.width = 400
        self.height = 300
        self.player_size = 30
        self.enemy_size = 30
        self.render_mode = render_mode

        self.action_space = spaces.Discrete(3)

        self.observation_space = spaces.Box(
            low=0, high=255, shape=(84, 84, 3), dtype=np.uint8
        )

        pygame.init()
        if self.render_mode == "human":
            self.screen = pygame.display.set_mode((self.width, self.height))
        
        self.canvas = pygame.Surface((self.width, self.height))
        self.font = pygame.font.SysFont("monospace", 15)

The constructor's main job is to declare the action space (the inputs) and the observation space (the dimensions) used for training:

  • Inputs: self.action_space = spaces.Discrete(3). action_space is a fixed attribute name inherited from the base class, and spaces.Discrete(3) declares the number of discrete actions, e.g. three inputs: move left, move right, and stay still.
  • Observation dimensions: self.observation_space is likewise a fixed base-class attribute; spaces.Box declares the observation dimensions.

self.observation_space = spaces.Box(
    low=0, high=255, shape=(84, 84, 3), dtype=np.uint8
)

  1. low: the minimum value of each observation element
  2. high: the maximum value of each observation element
  3. shape: declares the dimensions. For example, an image observation has shape (height, width, RGB); observing a flat plane would be shape (height, width)
  4. dtype: the type of each element. Choosing np.uint8 here saves training cost; the default is floating point (see the sanity-check sketch after this list)
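
A quick way to sanity-check these declarations (a small sketch, assuming the MyEnv class defined above) is to sample from the spaces and verify that an observation actually belongs to the declared Box:

env = MyEnv()
print(env.action_space.sample())             # a random action: 0, 1 or 2
obs, info = env.reset()
print(obs.shape, obs.dtype)                  # (84, 84, 3) uint8
print(env.observation_space.contains(obs))   # True if obs matches the declared space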
Episode reset: reset

This is equivalent to reinitializing the game state: the game starts over. It returns the observation plus an info dict (used for debugging and logging).

def reset(self, seed=None, options=None):
        super().reset(seed=seed)

        self.player_x = self.width // 2 - self.player_size // 2
        self.player_y = self.height - self.player_size - 10
        self.enemies = []
        self.score = 0
        self.frame_count = 0

        self.current_speed = 5
        self.spawn_rate = 30

        return self._get_obs(), {}
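
A small usage note: reset() accepts an optional seed, and super().reset(seed=seed) seeds the RNG that Gymnasium exposes as self.np_random. A reproducible spawn could use it instead of the global random module (an optional refinement, not what the code above does):

obs, info = env.reset(seed=42)  # same seed, same RNG stream
enemy_x = env.np_random.integers(0, env.width - env.enemy_size)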

Observation: _get_obs

The frame drawn by pygame is post-processed with OpenCV in two simple steps (a small debugging sketch for viewing the result follows the code):

  1. Transpose the axes (pygame's surfarray is laid out as (width, height, channels), while OpenCV/NumPy images use (height, width, channels))
  2. Resize the frame to 84 × 84 (smaller inputs speed up training)

def _get_obs(self):
        self.canvas.fill((0, 0, 0))
        pygame.draw.rect(self.canvas, (50, 150, 255), (self.player_x, self.player_y, self.player_size, self.player_size))
        
        for enemy in self.enemies:
            pygame.draw.rect(self.canvas, (255, 50, 50), (enemy[0], enemy[1], self.enemy_size, self.enemy_size))

        img_array = pygame.surfarray.array3d(self.canvas)
        img_array = np.transpose(img_array, (1, 0, 2))
        obs = cv2.resize(img_array, (84, 84), interpolation=cv2.INTER_AREA)

        return obs.astype(np.uint8)
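
If you want to eyeball exactly what the agent sees (a small debugging sketch, not part of the training code), OpenCV can display the 84 × 84 observation directly:

obs, info = env.reset()
# pygame produces RGB, while cv2.imshow expects BGR
cv2.imshow("agent view", cv2.cvtColor(obs, cv2.COLOR_RGB2BGR))
cv2.waitKey(0)
cv2.destroyAllWindows()
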
Step: step (important)

This function is the core of reinforcement learning training: it defines the score we give the AI at each frame (each step).
The reward design is critical; it directly determines the quality of the trained agent.
In the code below (most of which is game logic), the rewards are set as follows:

  1. Penalize 0.05 points whenever the AI moves
  2. Reward 0.01 points per step survived, rising to 0.02 points per step once the game score exceeds 100
  3. Reward 1 point whenever an obstacle falls all the way past the bottom of the screen
  4. Penalize 10 points on collision with an obstacle

def step(self, action):
        reward = 0
        terminated = False
        truncated = False

        move_speed = 8
        if action == 1 and self.player_x > 0:  # move left
            self.player_x -= move_speed
            reward -= 0.05

        if action == 2 and self.player_x < self.width - self.player_size:  # move right
            self.player_x += move_speed
            reward -= 0.05

        self.frame_count += 1

        # Difficulty scaling: one level per 5 points of score
        level = self.score // 5
        self.current_speed = 5 + level
        # Spawn faster as the level rises, but never more often than every 10 frames
        self.spawn_rate = max(10, 30 - level)

        if self.frame_count >= self.spawn_rate:
            self.frame_count = 0
            enemy_x = random.randint(0, self.width - self.enemy_size)
            self.enemies.append([enemy_x, 0]) # [x, y]

        # Iterate over a copy, since we may remove enemies mid-loop
        for enemy in self.enemies[:]:
            enemy[1] += self.current_speed

            player_rect = pygame.Rect(self.player_x, self.player_y, self.player_size, self.player_size)
            enemy_rect = pygame.Rect(enemy[0], enemy[1], self.enemy_size, self.enemy_size)

            if player_rect.colliderect(enemy_rect):
                reward = -10
                terminated = True
                break  # stop here so a later enemy cannot overwrite the penalty

            elif enemy[1] > self.height:
                self.enemies.remove(enemy)
                self.score += 1
                reward = 1
        
        if not terminated:
            if self.score > 100:
                reward += 0.01  # extra survival bonus at high scores
            reward += 0.01  # base survival bonus for every step alive

        obs = self._get_obs()

        if self.render_mode == "human":
            self._render_window()

        return obs, reward, terminated, truncated, {}

Rendering the Game Window

The code below is pure pygame rendering; it just displays the game window, so we won't walk through it here.

def _render_window(self):
        self.screen.blit(self.canvas, (0, 0))
        text = self.font.render(f"Score: {self.score}", True, (255, 255, 255))
        self.screen.blit(text, (10, 10))
        pygame.display.flip()

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                raise SystemExit  # stop cleanly; pygame calls would fail after quit()
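
One optional refinement (a sketch, not part of the original code): Gymnasium also defines a render() method, listed in the API table above. For an "rgb_array" render mode it could simply return the same frame that _get_obs() already draws:

    def render(self):
        # Hypothetical addition: return the current frame as an RGB array
        # when render_mode == "rgb_array", matching the Gymnasium convention.
        if self.render_mode == "rgb_array":
            frame = pygame.surfarray.array3d(self.canvas)
            return np.transpose(frame, (1, 0, 2))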

You've now officially become a hyperparameter tuner. Go give it a try!

If you ❤ enjoy ❤ this tutorial series, give it a follow; updates will come from time to time~
