[python] FastAPI: Designing a Tracking ID

Introduction

In real-world systems, tracing the full processing path of a single request through the logs by its tracking_id is a fairly common requirement. FastAPI does not provide this out of the box, so you have to implement it yourself. This article shows how to use contextvars to attach a tracking_id to the entire lifecycle of every request and record it in the logs.

What is contextvars?

Python 3.7 added the contextvars module to the standard library. As the name suggests, it provides "context variables", which are typically used to pass environment-like state around implicitly. It plays a role similar to threading.local(), but threading.local() is thread-scoped and isolates state between threads, whereas contextvars also works with coroutines in the asyncio ecosystem. PS: contextvars is not limited to async coroutines; it can also replace threading.local() in multi-threaded code.
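
As a minimal, self-contained sketch of this behaviour (the names REQUEST_ID and handler are only for illustration and are not part of the FastAPI example below), each asyncio task gets its own copy of the context, so setting the variable in one task does not leak into the others:

import asyncio
import contextvars

REQUEST_ID: contextvars.ContextVar[str] = contextvars.ContextVar("request_id", default="-")


async def handler(request_id: str) -> None:
    REQUEST_ID.set(request_id)  # visible only inside this task's context
    await asyncio.sleep(0.1)    # yield so the three tasks interleave
    # Each task still sees its own value; concurrent tasks do not overwrite each other.
    print(f"task {request_id} sees {REQUEST_ID.get()}")


async def main() -> None:
    await asyncio.gather(handler("req-1"), handler("req-2"), handler("req-3"))


asyncio.run(main())
# Expected output (order may vary):
# task req-1 sees req-1
# task req-2 sees req-2
# task req-3 sees req-3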

Basic usage

  1. First, create context.py.
import contextvars
from typing import Optional

TRACKING_ID: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    'tracking_id', 
    default=None
)

def get_tracking_id() -> Optional[str]:
    """用於依賴注入"""
    return TRACKING_ID.get()

  2. Create the middleware in middlewares.py, which adds the tracking_id to the request headers and response headers. A common scenario is a customer coming back with a tracking_id to dispute what happened with a request.
import uuid

from starlette.middleware.base import (BaseHTTPMiddleware,
                                       RequestResponseEndpoint)
from starlette.requests import Request
from starlette.responses import Response

from context import TRACKING_ID


class TrackingIDMiddleware(BaseHTTPMiddleware):
    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        tracking_id = str(uuid.uuid4())
        token = TRACKING_ID.set(tracking_id)
        # HTTP header values are conventionally encoded as latin-1
        request.scope["headers"].append((b"x-request-id", tracking_id.encode("latin-1")))

        try:
            resp = await call_next(request)
        finally:
            # Reset the tracking_id when the request ends, whether it succeeded or not,
            # so it does not leak into the next request
            TRACKING_ID.reset(token)

        # Optional: set the tracking ID header on the response
        resp.headers["X-Tracking-ID"] = tracking_id

        return resp

  3. Create the handler functions in handlers.py and check that the tracking_id can be read inside a handler.
import asyncio

from context import TRACKING_ID


async def mock_db_query():
    await asyncio.sleep(1)
    current_id = TRACKING_ID.get()
    print(f"This is mock_db_query. Current tracking ID: {current_id}")
    await asyncio.sleep(1)
  4. Create the entry point main.py.
import uvicorn
from fastapi import Depends, FastAPI
from fastapi.responses import PlainTextResponse
from starlette.background import BackgroundTasks

from context import TRACKING_ID, get_tracking_id
from handlers import mock_db_query
from middlewares import TrackingIDMiddleware

app = FastAPI()

app.add_middleware(TrackingIDMiddleware)


@app.get("/qwer")
async def get_qwer():
    """測試上下文變量傳遞"""
    current_id = TRACKING_ID.get()
    print(f"This is get qwer. Current tracking ID: {current_id}")
    return PlainTextResponse(f"Current tracking ID: {current_id}")


@app.get("/asdf")
async def get_asdf(tracking_id: str = Depends(get_tracking_id)):
    """測試依賴注入"""
    print(f"This is get asdf. tracking ID: {tracking_id}")
    await mock_db_query()
    return PlainTextResponse(f"Get request, tracking ID: {tracking_id}")

if __name__ == "__main__":
    uvicorn.run("main:app", host="127.0.0.1", port=8000, workers=4)
  5. Start the service and test the API with curl. The server console shows that the tracking_id is captured throughout each request.
This is get qwer. Current tracking ID: 01b0153f-4877-4ca0-ac35-ed88ab406452
INFO:     127.0.0.1:55708 - "GET /qwer HTTP/1.1" 200 OK
This is get asdf. tracking ID: 0be61d8d-11a0-4cb6-812f-51b9bfdc2639
This is mock_db_query. Current tracking ID: 0be61d8d-11a0-4cb6-812f-51b9bfdc2639
INFO:     127.0.0.1:55722 - "GET /asdf HTTP/1.1" 200 OK

Output from the curl commands:

$ curl -i http://127.0.0.1:8000/qwer
HTTP/1.1 200 OK
date: Sat, 29 Nov 2025 06:16:46 GMT
server: uvicorn
content-length: 57
content-type: text/plain; charset=utf-8
x-tracking-id: 01b0153f-4877-4ca0-ac35-ed88ab406452

Current tracking ID: 01b0153f-4877-4ca0-ac35-ed88ab406452

===== Another request =====

$ curl -i http://127.0.0.1:8000/asdf
HTTP/1.1 200 OK
date: Sat, 29 Nov 2025 06:16:49 GMT
server: uvicorn
content-length: 62
content-type: text/plain; charset=utf-8
x-tracking-id: 0be61d8d-11a0-4cb6-812f-51b9bfdc2639

Get request, tracking ID: 0be61d8d-11a0-4cb6-812f-51b9bfdc2639

Background-task APIs

FastAPI ships with BackgroundTasks (from starlette.background import BackgroundTasks), which lets an endpoint respond immediately and run the actual work in the background. Because the middleware above resets the tracking_id when the response is produced, a background coroutine may not be able to read it. That is the theory; in local tests the background coroutine could still read the tracking_id, possibly because of the low concurrency of local testing. Under the high concurrency of a production environment, it is safer to decouple things and pass the contextvars value explicitly.
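
A plausible explanation for that observation (a minimal sketch of asyncio semantics, not the exact internals of Starlette's BackgroundTasks) is that an asyncio task runs in a copy of the context taken when the task is created, so a later reset in the parent does not remove the value from the task's own copy:

import asyncio
import contextvars

VAR: contextvars.ContextVar[str] = contextvars.ContextVar("var", default="-")


async def child() -> None:
    await asyncio.sleep(0.1)
    # The task runs in a copy of the context taken at create_task() time,
    # so it still sees "request-1" even though the parent has reset the variable.
    print("child sees:", VAR.get())


async def main() -> None:
    token = VAR.set("request-1")
    task = asyncio.create_task(child())  # the current context is copied here
    VAR.reset(token)                     # the parent resets right away
    await task                           # prints: child sees: request-1


asyncio.run(main())

Whether the background work is scheduled before or after the middleware resets the variable is a framework-internal detail, which is exactly why passing the tracking_id explicitly is the safer choice.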

  1. The custom middleware class stays unchanged.
  2. Add an endpoint that schedules a background task.
from starlette.background import BackgroundTasks
from handlers import mock_background_task

@app.get("/zxcv")
async def get_zxcv(tasks: BackgroundTasks):
    """測試後台任務"""
    current_id = TRACKING_ID.get()
    print(f"This is get zxcv. The current id is {current_id}")

    # Pass the tracking_id explicitly
    tasks.add_task(mock_background_task, current_id)

    return PlainTextResponse(f"This is get zxcv. The current id is {current_id}")
  3. Handle the tracking_id explicitly in the background coroutine. When the task starts, it re-sets TRACKING_ID in its own execution context from the value passed in; when the task finishes, it cleans up the context it set itself.
# Additional imports needed in handlers.py for this function
from typing import Optional
from uuid import uuid4


async def mock_background_task(request_tracking_id: Optional[str]):
    if request_tracking_id is None:
        request_tracking_id = str(uuid4())
        print(f"WARNING: No tracking ID found. Generated a new one: {request_tracking_id}")
    token = TRACKING_ID.set(request_tracking_id)
    try:
        # Simulate a slow asynchronous background job
        await asyncio.sleep(5)
        print(f"This is mock background task. Current tracking ID: {request_tracking_id}")
    finally:
        # Make sure the tracking ID set by this task is always reset
        TRACKING_ID.reset(token)
  4. Start the main program and check that the tracking_id stays consistent.

Logging the tracking_id

So far the tracking_id has been added to print() calls by hand, which is tedious and easy to forget. It is better to have the logger pick it up automatically, which keeps the required code changes far less intrusive.

First, we need to wrap a small logging module. The standard library logging is used here; if you prefer loguru you can wrap loguru instead (if your production environment runs frequent and strict security scans, I recommend sticking with the standard library to avoid constantly upgrading third-party dependencies). If your application runs on Docker or Kubernetes with an existing console-log collector, the module only needs to write to the console. If the application runs directly on a server and a tool such as filebeat reads log files, then you also have to deal with unbounded log file growth, contention between multiple uvicorn worker processes writing to the same log file, and log writes blocking async coroutines under high concurrency.

The example logging module below addresses all of these problems. Its main features:

  • The tracking_id is picked up automatically; there is no need to log it by hand.
  • Logs are emitted as JSON, which makes them easy to search in an aggregation system such as ELK.
  • Logs are written both to standard output and to a log file; either output can be enabled on its own.
  • Log files are rotated daily and the last 7 days are kept by default, so file size does not grow without bound.
  • The main process and each worker process write to their own log file, avoiding contention on a shared file.
  • Log records first go onto a queue, so writing to files does not block async coroutines.
  1. Create pkg/log/log.py. Most of the work happens here. The only objects that need to be exposed are setup_logger and close_log_queue.
import json
import logging
import os
import re
import sys
from logging.handlers import QueueHandler, QueueListener, TimedRotatingFileHandler
from multiprocessing import current_process
from pathlib import Path
from queue import Queue
from typing import Optional

from context import TRACKING_ID

_queue_listener: Optional[QueueListener] = None
_queue_logger: Optional[Queue] = None
PATTERN_PROCESS_NAME = re.compile(r"SpawnProcess-(\d+)")


class JSONFormatter(logging.Formatter):
    """A logging formatter that outputs logs in JSON format."""
    def format(self, record: logging.LogRecord) -> str:
        log_record = {
            "@timestamp": self.formatTime(record, "%Y-%m-%dT%H:%M:%S%z"),
            "level": record.levelname,
            "name": record.name,
            # "taskName": getattr(record, "taskName", None),  # Record task name if needed
            "processName": record.processName,  # Record process name if needed
            "tracking_id": getattr(record, "tracking_id", None),
            "loc": "%s:%d" % (record.filename, record.lineno),
            "func": record.funcName,
            "message": record.getMessage(),
        }

        return json.dumps(log_record, ensure_ascii=False, default=str)


class TrackingIDFilter(logging.Filter):
    """A logging filter that adds tracking_id to log records.
    """
    def filter(self, record):
        record.tracking_id = TRACKING_ID.get()
        return True


def _setup_console_handler(level: int) -> logging.StreamHandler:
    """Setup a StreamHandler for console logging.
    
    Args:
        level (int): The logging level.
    """
    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(level)
    handler.setFormatter(JSONFormatter())
    return handler


def _setup_file_handler(
    level: int, log_path: str, rotate_days: int
) -> TimedRotatingFileHandler:
    """Setup a TimedRotatingFileHandler for logging.
    
    Args:
        level (int): The logging level.
        log_path (str): The log path.
        rotate_days (int): The number of days to keep log files.
    """
    handler = TimedRotatingFileHandler(
        filename=log_path,
        when="midnight",
        interval=1,
        backupCount=rotate_days,
        encoding="utf-8",
    )
    handler.setLevel(level)
    handler.setFormatter(JSONFormatter())
    return handler


def _setup_queue_handler(level: int, log_queue: Queue) -> QueueHandler:
    """Setup a QueueHandler for logging.
    
    Args:
        level (int): The logging level.
        log_queue (Queue): The log queue.
    """
    handler = QueueHandler(log_queue)
    handler.setLevel(level)
    return handler


def _get_spawn_process_number(name: str) -> str:
    """
    Get the spawn process number for log file naming.
    The server should be started with multiple processes using uvicorn's --workers option.
    Prevent issues caused by multiple processes writing to the same log file.

    Args:
        name (str): The name of the log file.

    Returns:
        str: The spawn process number for log file naming.
    """
    try:
        process_name = current_process().name
        pid = os.getpid()

        if process_name == "MainProcess":
            return name
        elif m := PATTERN_PROCESS_NAME.match(process_name):
            return f"{name}-sp{m.group(1)}"
        else:
            return f"{name}-{pid}"

    except Exception:
        return f"{name}-{os.getpid()}"


def _setup_logpath(log_dir: str, name: str) -> str:
    """Setup the log path.
    
    Args:
        log_dir (str): The log directory.
        name (str): The name of the log file. Example: "app"

    Returns:
        str: The log path.
    """
    main_name = _get_spawn_process_number(name)
    log_file = f"{main_name}.log"
    log_path = Path(log_dir) / log_file

    if not log_path.parent.exists():
        try:
            log_path.parent.mkdir(parents=True, exist_ok=True)
        except Exception as e:
            raise RuntimeError(
                f"Failed to create log directory: {log_path.parent}"
            ) from e
    return str(log_path)


def _validate(level: int, enable_console: bool, enable_file: bool, rotate_days: int) -> None:
    """Validate the log configuration.
    
    Args:
        level (int): The logging level.
        enable_console (bool): Whether to enable console logging.
        enable_file (bool): Whether to enable file logging.
        rotate_days (int): The number of days to keep log files.

    Raises:
        ValueError: If the log configuration is invalid.
    """
    if not enable_console and not enable_file:
        raise ValueError("At least one of enable_console or enable_file must be True.")

    if level not in [
        logging.DEBUG,
        logging.INFO,
        logging.WARNING,
        logging.ERROR,
        logging.CRITICAL,
    ]:
        raise ValueError("Invalid logging level specified.")

    if not isinstance(rotate_days, int) or rotate_days <= 0:
        raise ValueError("rotate_days must be a positive integer.")


def setup_logger(
    name: str = "app",
    level: int = logging.DEBUG,
    enable_console: bool = True,
    enable_file: bool = True,
    log_dir: str = "logs",
    rotate_days: int = 7,
) -> logging.Logger:
    """Setup a logger with console and/or file handlers.
    
    Args:
        name (str): The name of the logger. This will be used as the log file name prefix. Defaults to "app".
        level (int): The logging level. Defaults to logging.DEBUG.
        enable_console (bool): Whether to enable console logging. Defaults to True.
        enable_file (bool): Whether to enable file logging. Defaults to True.
        log_dir (str): The log directory. Defaults to "logs".
        rotate_days (int): The number of days to keep log files. Defaults to 7.

    Returns:
        logging.Logger: The configured logger.
    """
    logger = logging.getLogger(name)

    if logger.hasHandlers():
        return logger  # Logger is already configured

    _validate(level, enable_console, enable_file, rotate_days)

    logger.setLevel(level)
    logger.propagate = False  # Prevent log messages from being propagated to the root logger

    log_path = _setup_logpath(log_dir, name)

    handlers = []

    if enable_console:
        handlers.append(_setup_console_handler(level))

    if enable_file:
        handlers.append(_setup_file_handler(level, log_path, rotate_days))

    global _queue_logger, _queue_listener
    if not _queue_logger:
        _queue_logger = Queue(-1)

    queue_handler = _setup_queue_handler(level, _queue_logger)

    if not _queue_listener:
        _queue_listener = QueueListener(
            _queue_logger, *handlers, respect_handler_level=True
        )
        _queue_listener.start()

    logger.addHandler(queue_handler)
    logger.addFilter(TrackingIDFilter())

    return logger


def close_log_queue() -> None:
    """Close the log queue and stop the listener.
    This function should be called when the application is shutting down to ensure that the log queue is closed and the listener is stopped.
    """
    global _queue_listener
    if _queue_listener:
        _queue_listener.stop()
        _queue_listener = None

  2. Create pkg/log/__init__.py so other modules can import it conveniently; I also use this file to signal which objects of the module are public. The logger object is a singleton that other modules can use directly; there is usually no need to call setup_logger to build a separate logger. Some projects want every class to create its own logger, so setup_logger is exported as well (personally I don't think that is necessary).
from pathlib import Path

from .log import close_log_queue, setup_logger

logger = setup_logger(
    log_dir=str(Path(__file__).parent.parent.parent / "logs"),
)

__all__ = [
    "logger",
    "setup_logger",
    "close_log_queue",
]

  3. Call close_log_queue from FastAPI's lifespan.
from contextlib import asynccontextmanager

from pkg.log import close_log_queue

@asynccontextmanager
async def lifespan(app: FastAPI):
    try:
        yield
    finally:
        close_log_queue()
        print("Shutdown!")


app = FastAPI(lifespan=lifespan)
  4. Change print() to logger.info() and check that the tracking_id is written to the logs automatically. Part of the change looks like this:
from pkg.log import logger  # the shared logger replaces print()


async def mock_db_query():
    await asyncio.sleep(1)
    current_id = TRACKING_ID.get()
    logger.info(f"This is mock_db_query. Current tracking ID: {current_id}")
    await asyncio.sleep(1)
  5. Test with curl and inspect the logs.

The curl output:

$ curl -i  http://127.0.0.1:8000/asdf
HTTP/1.1 200 OK
date: Sat, 29 Nov 2025 13:07:09 GMT
content-length: 62
content-type: text/plain; charset=utf-8
x-tracking-id: 38a015f7-b0d3-41ea-a2b3-179bf608b4bb

Get request, tracking ID: 38a015f7-b0d3-41ea-a2b3-179bf608b4bb

The server console output:

{"@timestamp": "2025-11-29T21:07:09+0800", "level": "INFO", "name": "app", "processName": "SpawnProcess-3", "tracking_id": "38a015f7-b0d3-41ea-a2b3-179bf608b4bb", "loc": "main.py:39", "func": "get_asdf", "message": "This is get asdf. tracking ID: 38a015f7-b0d3-41ea-a2b3-179bf608b4bb"}
{"@timestamp": "2025-11-29T21:07:10+0800", "level": "INFO", "name": "app", "processName": "SpawnProcess-3", "tracking_id": "38a015f7-b0d3-41ea-a2b3-179bf608b4bb", "loc": "handlers.py:12", "func": "mock_db_query", "message": "This is mock_db_query. Current tracking ID: 38a015f7-b0d3-41ea-a2b3-179bf608b4bb"}

The log file contains:

{"@timestamp": "2025-11-29T21:07:09+0800", "level": "INFO", "name": "app", "processName": "SpawnProcess-3", "tracking_id": "38a015f7-b0d3-41ea-a2b3-179bf608b4bb", "loc": "main.py:39", "func": "get_asdf", "message": "This is get asdf. tracking ID: 38a015f7-b0d3-41ea-a2b3-179bf608b4bb"}
{"@timestamp": "2025-11-29T21:07:10+0800", "level": "INFO", "name": "app", "processName": "SpawnProcess-3", "tracking_id": "38a015f7-b0d3-41ea-a2b3-179bf608b4bb", "loc": "handlers.py:12", "func": "mock_db_query", "message": "This is mock_db_query. Current tracking ID: 38a015f7-b0d3-41ea-a2b3-179bf608b4bb"}

As the three outputs above show, the tracking_id in the log records stays consistent with the one in the log messages and in the response.

Addendum: another way to solve the multi-process file-write problem

While painstakingly debugging the logging module above, another way to solve the multi-process file-write problem came to mind: let an external tool write the log files and have the service write only to the console, similar to how Docker and Kubernetes environments handle it. This is only an outline; it should work in principle, and you can try it out if you are interested.

  1. Remove the file-writing part of the logging module and keep only the console output. Note that standard output can also block async coroutines, so the queue handler should be kept.
  2. On startup, use nohup to redirect the console output to a file, for example: nohup python main.py > logs/start.log 2>&1 &
  3. Configure logrotate to rotate that file (see the sketch below).
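
For step 3, a minimal logrotate configuration could look roughly like the following sketch; the paths /etc/logrotate.d/myapp and /opt/myapp/logs/start.log are placeholders, and copytruncate is used because the running process keeps the redirected file open:

# /etc/logrotate.d/myapp (hypothetical path)
/opt/myapp/logs/start.log {
    daily
    rotate 7
    missingok
    notifempty
    compress
    # truncate the original file in place instead of recreating it,
    # because the nohup redirection keeps the file descriptor open
    copytruncate
}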