
TensorTrade: A Reinforcement-Learning-Based Framework for Quantitative Trading

Open a trading chart, pile on ten technical indicators, and then stare blankly at the screen with no idea what to do next: the scene is all too familiar to traders. What if you instead handed the historical data to a computer and told it to learn by trial and error? A profit earns a reward, a loss earns a penalty. Through constant attempts and failures it learns, eventually iterating toward a trading strategy that is not perfect, but at least logically self-consistent.

That is the core idea behind TensorTrade.

TensorTrade is an open-source Python framework focused on building and training trading algorithms with reinforcement learning (RL).

Data Acquisition and Feature Engineering

Here we use yfinance to fetch the data and pandas_ta to compute the technical indicators. Log returns, RSI, and MACD serve as a few basic feature inputs.

  pip install yfinance pandas_ta

import yfinance as yf  
import pandas_ta as ta  
import pandas as pd  

# Pick your ticker  
TICKER = "TTRD"  # TODO: change this to something real, e.g. "AAPL", "BTC-USD"  
TRAIN_START_DATE = "2021-02-09"  
TRAIN_END_DATE   = "2021-09-30"  
EVAL_START_DATE  = "2021-10-01"  
EVAL_END_DATE    = "2021-11-12"  

def build_dataset(ticker, start, end, filename):  
    # 1. Download hourly OHLCV data  
    df = yf.Ticker(ticker).history(  
        start=start,  
        end=end,  
        interval="60m"  
    )  
    # 2. Clean up  
    df = df.drop(["Dividends", "Stock Splits"], axis=1)  
    df["Volume"] = df["Volume"].astype(int)  
    # 3. Add some basic features  
    df.ta.log_return(append=True, length=16)  
    df.ta.rsi(append=True, length=14)  
    df.ta.macd(append=True, fast=12, slow=26)  
    # 4. Move Datetime from index to column  
    df = df.reset_index()  
    # 5. Save  
    df.to_csv(filename, index=False)  
    print(f"Saved {filename} with {len(df)} rows")  

build_dataset(TICKER, TRAIN_START_DATE, TRAIN_END_DATE, "training.csv")  
build_dataset(TICKER, EVAL_START_DATE,  EVAL_END_DATE,  "evaluation.csv")

Once the script finishes, training.csv and evaluation.csv appear in the working directory. They contain the raw OHLCV data plus a few precomputed indicators; this is the data the RL model will train on.
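
As a quick sanity check (my own addition, not part of the original post), load the generated file and inspect the columns that pandas_ta appended; the exact indicator column names (e.g. RSI_14, MACD_12_26_9) depend on your pandas_ta version:

import pandas as pd

df = pd.read_csv("training.csv", parse_dates=["Datetime"])
print(df.shape)                  # number of hourly rows and columns
print(df.columns.tolist())       # OHLCV plus the appended indicator columns
print(df.head())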

Building the TensorTrade Environment

Reinforcement learning cannot work on a CSV file directly. It needs a standard interactive environment: one that exposes the current state, accepts the agent's action, and returns a reward.

TensorTrade breaks this process into modules:

  • Instrument: defines what is traded (e.g. USD, TTRD).
  • Wallet: manages asset balances.
  • Portfolio: a collection of wallets.
  • Stream / DataFeed: handles the feature data streams.
  • reward_scheme / action_scheme: define what actions are available and how each move is scored.

  pip install tensortrade

Below is an environment factory function. It is deliberately kept lightweight so it can be plugged into Ray later:

import os
import pandas as pd  

from tensortrade.feed.core import DataFeed, Stream  
from tensortrade.oms.instruments import Instrument  
from tensortrade.oms.exchanges import Exchange, ExchangeOptions  
from tensortrade.oms.services.execution.simulated import execute_order  
from tensortrade.oms.wallets import Wallet, Portfolio  
import tensortrade.env.default as default  

def create_env(config):  
    """  
    Build a TensorTrade environment from a CSV.  
    config needs:  
      - csv_filename  
      - window_size  
      - reward_window_size  
      - max_allowed_loss  
    """  
    # 1. Read the dataset  
    dataset = (
        pd.read_csv(config["csv_filename"], parse_dates=["Datetime"])
        .bfill()   # backfill leading NaNs left by indicator warm-up periods
        .ffill()   # then forward-fill anything that remains
    )
    # 2. Price stream (we'll trade on Close)  
    commission = 0.0035  # 0.35%, tweak this to your broker  
    price = Stream.source(  
        list(dataset["Close"]), dtype="float"  
    ).rename("USD-TTRD")  
    options = ExchangeOptions(commission=commission)  
    exchange = Exchange("TTSE", service=execute_order, options=options)(price)  
    # 3. Instruments and wallets  
    USD = Instrument("USD", 2, "US Dollar")  
    TTRD = Instrument("TTRD", 2, "TensorTrade Corp")  # just a label  
    cash_wallet = Wallet(exchange, 1000 * USD)  # start with $1000  
    asset_wallet = Wallet(exchange, 0 * TTRD)   # start with zero TTRD  
    portfolio = Portfolio(USD, [cash_wallet, asset_wallet])  
    # 4. Renderer feed (optional, useful for plotting later)  
    renderer_feed = DataFeed([  
        Stream.source(list(dataset["Datetime"])).rename("date"),  
        Stream.source(list(dataset["Open"]), dtype="float").rename("open"),  
        Stream.source(list(dataset["High"]), dtype="float").rename("high"),  
        Stream.source(list(dataset["Low"]), dtype="float").rename("low"),  
        Stream.source(list(dataset["Close"]), dtype="float").rename("close"),  
        Stream.source(list(dataset["Volume"]), dtype="float").rename("volume"),  
    ])  
    renderer_feed.compile()  
    # 5. Feature feed for the RL agent  
    features = []  
    # Skip Datetime (first column) and stream everything else  
    for col in dataset.columns[1:]:  
        s = Stream.source(list(dataset[col]), dtype="float").rename(col)  
        features.append(s)  
    feed = DataFeed(features)  
    feed.compile()  
    # 6. Reward and action scheme  
    reward_scheme = default.rewards.SimpleProfit(  
        window_size=config["reward_window_size"]  
    )  
    action_scheme = default.actions.BSH(  
        cash=cash_wallet,  
        asset=asset_wallet  
    )  
    # 7. Put everything together in an environment  
    env = default.create(  
        portfolio=portfolio,  
        action_scheme=action_scheme,  
        reward_scheme=reward_scheme,  
        feed=feed,  
        renderer=[],  
        renderer_feed=renderer_feed,  
        window_size=config["window_size"],  
        max_allowed_loss=config["max_allowed_loss"]  
    )  
    return env

With that, the rules of the "game" are set: observe the last N candles and their indicators (the state), decide whether to buy, sell, or hold (the action), and aim to maximize profit over a period of time (the reward).
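
Before handing the environment to RLlib, it is worth a quick smoke test. Here is a minimal sketch of my own (not from the original post), assuming the classic gym-style API where step() returns four values:

env = create_env({
    "csv_filename": "training.csv",
    "window_size": 14,
    "reward_window_size": 7,
    "max_allowed_loss": 0.10,
})

obs = env.reset()
done = False
total_reward = 0.0
while not done:
    action = env.action_space.sample()          # random buy/sell/hold
    obs, reward, done, info = env.step(action)
    total_reward += reward
print("random-policy episode reward:", total_reward)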

Training the Model with Ray RLlib and PPO

With the environment built, the next step is to bring in Ray RLlib to handle the core RL logic.

We use PPO (Proximal Policy Optimization), which performs well in both continuous-control and discrete action spaces. To look for better solutions, we also run a simple hyperparameter grid search: a few candidate network architectures, learning rates, and minibatch sizes, each tried in turn.

  pip install "ray[rllib]"

訓練腳本如下:

import os
import ray  
from ray import tune  
from ray.tune.registry import register_env  

from your_module import create_env  # wherever you defined create_env  

# Some hyperparameter grids to try  
FC_SIZE = tune.grid_search([  
    [256, 256],  
    [1024],  
    [128, 64, 32],  
])  
LEARNING_RATE = tune.grid_search([  
    0.001,  
    0.0005,  
    0.00001,  
])  
MINIBATCH_SIZE = tune.grid_search([  
    5,  
    10,  
    20,  
])  
cwd = os.getcwd()  
# Register our custom environment with RLlib  
register_env("MyTrainingEnv", lambda cfg: create_env(cfg))  
env_config_training = {  
    "window_size": 14,  
    "reward_window_size": 7,  
    "max_allowed_loss": 0.10,  # cut episodes early if loss > 10%  
    "csv_filename": os.path.join(cwd, "training.csv"),  
}  
env_config_evaluation = {  
    "max_allowed_loss": 1.00,  
    "csv_filename": os.path.join(cwd, "evaluation.csv"),  
}  
ray.init(ignore_reinit_error=True)  
analysis = tune.run(  
    run_or_experiment="PPO",  
    name="MyExperiment1",  
    metric="episode_reward_mean",  
    mode="max",  
    stop={  
        "training_iteration": 5,  # small for demo, increase in real runs  
    },  
    config={  
        "env": "MyTrainingEnv",  
        "env_config": env_config_training,  
        "log_level": "WARNING",  
        "framework": "torch",     # or "tf"  
        "ignore_worker_failures": True,  
        "num_workers": 1,  
        "num_envs_per_worker": 1,  
        "num_gpus": 0,  
        "clip_rewards": True,  
        "lr": LEARNING_RATE,  
        "gamma": 0.50,            # discount factor  
        "observation_filter": "MeanStdFilter",  
        "model": {  
            "fcnet_hiddens": FC_SIZE,  
        },  
        "sgd_minibatch_size": MINIBATCH_SIZE,  
        "evaluation_interval": 1,  
        "evaluation_config": {  
            "env_config": env_config_evaluation,  
            "explore": False,     # no exploration during evaluation  
        },  
    },  
    num_samples=1,  
    keep_checkpoints_num=10,  
    checkpoint_freq=1,  
)

This code is essentially running a "trading-bot tournament". Ray trains multiple PPO agents in parallel over the parameter combinations defined above, tracks their mean episode reward, and keeps the best-performing checkpoints for later use.
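
To actually use the winning agent afterwards, restore it from its checkpoint. Below is a minimal sketch, assuming a pre-2.0 Ray release that matches the tune.run API above; PPOTrainer, get_best_trial, and get_best_checkpoint are the assumptions here, and the model and filter settings must mirror those of the winning trial:

import ray.rllib.agents.ppo as ppo

# Locate the best trial and its best checkpoint by mean episode reward
best_trial = analysis.get_best_trial(metric="episode_reward_mean", mode="max")
best_checkpoint = analysis.get_best_checkpoint(
    trial=best_trial, metric="episode_reward_mean", mode="max"
)

# Rebuild a trainer with a compatible config and load the trained weights.
# NOTE: "model" and "observation_filter" must match the winning trial.
agent = ppo.PPOTrainer(
    env="MyTrainingEnv",
    config={
        "env_config": {**env_config_training, **env_config_evaluation},
        "framework": "torch",
        "observation_filter": "MeanStdFilter",
        "explore": False,
    },
)
agent.restore(best_checkpoint)
# agent.compute_single_action(obs) now returns deterministic actions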

A Custom Reward Scheme (PBR)

The default SimpleProfit reward is easy to reason about, but in practice it is often too coarse. Sometimes the reward function needs to be reshaped around a specific trading logic, for example a position-based reward scheme, PBR (Position-Based Reward):

  • Track the current position (long or short).
  • Monitor price changes.
  • Reward = price change × position direction.

If the price rises while you are long, you get positive feedback; if the price falls while you are short, you also get positive feedback. The opposite cases are penalized.
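
A tiny worked example of that rule (my own illustration, not TensorTrade code):

# reward = price_change * position, where position is +1 (long) or -1 (short)
prev_price, price = 100.0, 102.0
price_change = price - prev_price        # +2.0

print(price_change * +1)   # +2.0 -> long while the price rose: rewarded
print(price_change * -1)   # -2.0 -> short while the price rose: penalized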

from tensortrade.env.default.rewards import TensorTradeRewardScheme
from tensortrade.feed.core import DataFeed, Stream

class PBR(TensorTradeRewardScheme):
    """  
    Position-Based Reward (PBR)  
    Rewards the agent based on price changes and its current position.  
    """  
    registered_name = "pbr"  
    def __init__(self, price: Stream):  
        super().__init__()  
        self.position = -1  # start flat/short  
        # Price differences  
        r = Stream.sensor(price, lambda p: p.value, dtype="float").diff()  
        # Position stream  
        position = Stream.sensor(self, lambda rs: rs.position, dtype="float")  
        # Reward = price_change * position  
        reward = (r * position).fillna(0).rename("reward")  
        self.feed = DataFeed([reward])  
        self.feed.compile()  
    def on_action(self, action: int):  
        # Simple mapping: action 0 = long, everything else = short  
        self.position = 1 if action == 0 else -1  
    def get_reward(self, portfolio):  
        return self.feed.next()["reward"]  
    def reset(self):  
        self.position = -1  
        self.feed.reset()

Wiring it in is simple: inside the create_env function, replace the original reward_scheme:

reward_scheme = PBR(price)
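
One detail worth adding (taken from the TensorTrade examples rather than the original post, so treat the attach call as an assumption about your library version): PBR only learns about trades through its on_action hook, so the BSH action scheme has to notify it by attaching the reward scheme as a listener:

action_scheme = default.actions.BSH(
    cash=cash_wallet,
    asset=asset_wallet
).attach(reward_scheme)   # BSH then calls reward_scheme.on_action(...) every step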

The benefit of this change is denser feedback. The agent no longer has to wait until it closes a position to learn whether it made money; every single step produces a signal about whether it was positioned on the right side of the move.

Next Steps and Recommendations

Getting this pipeline to run is only a start. Making it genuinely usable still takes a lot of work, for example:

  • Swap in real data: the TTRD in the code is just a placeholder; replace it with a real instrument (a stock, crypto, or an index).
  • Feature engineering: RSI and MACD are only a starting point; try ATR, Bollinger Bands, or features computed over longer timeframes.
  • Hyperparameter tuning: gamma (the discount factor) and window_size (the observation window) have a huge impact on the strategy's character and are worth sweeping carefully.
  • Benchmarking: this is the most important step. Compare the trained RL policy against Buy & Hold, and even against a random policy (a minimal baseline sketch follows this list). If it cannot beat a random policy, go back and re-examine everything.
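
For the benchmarking item above, here is a minimal sketch of the Buy & Hold baseline, computed straight from the evaluation CSV generated earlier (the commission constant mirrors the one used in the simulated exchange and is an assumption to adjust):

import pandas as pd

df = pd.read_csv("evaluation.csv", parse_dates=["Datetime"])

commission = 0.0035                      # same rate as the simulated exchange
start_price = df["Close"].iloc[0]
end_price = df["Close"].iloc[-1]

# Buy once at the first close, sell at the last close, pay commission both ways
buy_hold_return = (end_price * (1 - commission)) / (start_price * (1 + commission)) - 1
print(f"Buy & Hold return over the evaluation window: {buy_hold_return:.2%}")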

Finally, remember that this is research only, so do not take it straight to a live account. A model that dominates its training set is the norm; passing out-of-sample tests and paper trading is what actually counts.

https://avoid.overfit.cn/post/8c9e08414e514c73ab3aefd694294f79

Author: CodeBun
