簡介

本文介紹了一個基於多模態大模型的醫療圖像診斷項目。項目旨在通過訓練一個醫療領域的多模態大模型,提高醫生處理醫學圖像的效率,輔助診斷和治療。作者以家中老人的腦部CT為例,展示瞭如何利用MedTrinity-25M數據集訓練模型,經過數據準備、環境搭建、模型訓練及微調、最終驗證等步驟,成功使模型能夠識別CT圖像並給出具體的診斷意見,與專業醫生的診斷結果高度吻合。

前言

隨着多模態大模型的發展,其不僅限於文字處理,更能夠在圖像、視頻、音頻方面進行識別與理解。醫療領域中,醫生們往往需要對各種醫學圖像進行處理,以輔助診斷和治療。如果將多模態大模型與圖像診斷相結合,那麼這會極大地提升診斷效率。

項目目標

訓練一個醫療多模態大模型,用於圖像診斷。

剛好家裏老爺子近期略感頭疼,去醫院做了腦部CT,診斷患有垂體瘤,我將嘗試使用多模態大模型進行進一步診斷。

實現過程

1. 數據集準備

為了訓練模型,需要準備大量的醫學圖像數據。通過搜索我們找到以下訓練數據:

數據名稱:MedTrinity-25M
數據地址:https://github.com/UCSC-VLAA/MedTrinity-25M數據簡介:MedTrinity-25M數據集是一個用於醫學圖像分析和計算機視覺研究的大型數據集。
數據來源:該數據集由加州大學聖克魯茲分校(UCSC)提供,旨在促進醫學圖像處理和分析的研究。
數據量:MedTrinity-25M包含約2500萬條醫學圖像數據,涵蓋多種醫學成像技術,如CT、MRI和超聲等。

數據內容: 該數據集有兩份,分別是25Mdemo和25Mfull。

25Mdemo (約162,000條)數據集內容如下

數據集結構示例

25Mfull (約24,800,000條)數據集內容如下

數據集結構示例

2. 數據下載

2.1 安裝Hugging Face的Datasets庫
pip install datasets
2.2 下載數據集
from datasets import load_dataset

# 加載數據集
ds = load_dataset("UCSC-VLAA/MedTrinity-25M", "25M_demo", cache_dir="cache")

執行結果

數據集下載進度顯示

説明

  • 以上方法是使用HuggingFace的Datasets庫下載數據集,下載的路徑為當前腳本所在路徑下的cache文件夾。
  • 使用HuggingFace下載需要能夠訪問https://huggingface.co/並且在網站上申請數據集讀取權限才可以。
  • 如果沒有權限訪問HuggingFace,可以關注"一起AI技術"公眾號後,回覆“MedTrinity”獲取百度網盤下載地址。
2.3 預覽數據集
# 查看訓練集的前1個樣本
print(ds['train'][:1])

運行結果

{
    'image': [<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512 at 0x15DD6D06530>], 
    'id': ['8031efe0-1b5c-11ef-8929-000066532cad'], 
    'caption': ['The image is a non-contrasted computed tomography (CT) scan of the brain, showing the cerebral structures without any medical devices present. The region of interest, located centrally and in the middle of the image, exhibits an area of altered density, which is indicative of a brain hemorrhage. This area is distinct from the surrounding brain tissue, suggesting a possible hematoma or bleeding within the brain parenchyma. The location and characteristics of this abnormality may suggest a relationship with the surrounding brain tissue, potentially causing a mass effect or contributing to increased intracranial pressure.']
}

使用如下命令對數據集的圖片進行可視化查看:

# 可視化image內容
from PIL import Image
import matplotlib.pyplot as plt

image = ds['train'][0]['image']  # 獲取第一張圖像

plt.imshow(image)
plt.axis('off')  # 不顯示座標軸
plt.show()

運行結果

項目實戰:LLaMaFactory和Qwen2-VL-2B微調大模型實戰_數據集

3. 數據預處理

由於後續我們要通過LLama Factory進行多模態大模型微調,所以我們需要對上述的數據集進行預處理以符合LLama Factory的要求。

3.1 LLama Factory數據格式

查看LLama Factory的多模態數據格式要求如下:

[
  {
    "messages": [
      {
        "content": "<image>他們是誰?",
        "role": "user"
      },
      {
        "content": "他們是拜仁慕尼黑的凱恩和格雷茨卡。",
        "role": "assistant"
      },
      {
        "content": "他們在做什麼?",
        "role": "user"
      },
      {
        "content": "他們在足球場上慶祝。",
        "role": "assistant"
      }
    ],
    "images": [
      "mllm_demo_data/1.jpg"
    ]
  }
]
3.2 實現數據格式轉換腳本
from datasets import load_dataset
import os
import json
from PIL import Image

def save_images_and_json(ds, output_dir="mllm_data"):
    """
    將數據集中的圖像和對應的 JSON 信息保存到指定目錄。

    參數:
    ds: 數據集對象,包含圖像和標題。
    output_dir: 輸出目錄,默認為 "mllm_data"。
    """
    # 創建輸出目錄
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # 創建一個列表來存儲所有的消息和圖像信息
    all_data = []

    # 遍歷數據集中的每個項目
    for item in ds:
        img_path = f"{output_dir}/{item['id']}.jpg"  # 圖像保存路徑
        image = item["image"]  # 假設這裏是一個 PIL 圖像對象

        # 將圖像對象保存為文件
        image.save(img_path)  # 使用 PIL 的 save 方法

        # 添加消息和圖像信息到列表中
        all_data.append(
            {
                "messages": [
                    {
                        "content": "<image>圖片中的診斷結果是怎樣?",
                        "role": "user",
                    },
                    {
                        "content": item["caption"],  # 從數據集中獲取的標題
                        "role": "assistant",
                    },
                ],
                "images": [img_path],  # 圖像文件路徑
            }
        )

    # 創建 JSON 文件
    json_file_path = f"{output_dir}/mllm_data.json"
    with open(json_file_path, "w", encoding='utf-8') as f:
        json.dump(all_data, f, ensure_ascii=False)  # 確保中文字符正常顯示

if __name__ == "__main__":
    # 加載數據集
    ds = load_dataset("UCSC-VLAA/MedTrinity-25M", "25M_demo", cache_dir="cache")

    # 保存數據集中的圖像和 JSON 信息
    save_images_and_json(ds['train'])

運行結果

項目實戰:LLaMaFactory和Qwen2-VL-2B微調大模型實戰_模態_02

4. 模型下載

本次微調,我們使用阿里最新發布的多模態大模型:Qwen2-VL-2B-Instruct 作為底座模型。

模型説明地址:https://modelscope.cn/models/Qwen/Qwen2-VL-2B-Instruct

使用如下命令下載模型:

git lfs install
# 下載模型
git clone https://www.modelscope.cn/Qwen/Qwen2-VL-2B-Instruct.git

5. 環境準備

5.1 機器環境

硬件

  • 顯卡:4080 Super
  • 顯存:16GB

軟件

  • 系統:Ubuntu 20.04 LTS
  • python:3.10
  • pytorch:2.1.2 + cuda12.1
5.2 準備虛擬環境
# 創建python3.10版本虛擬環境
conda create --name train_env python=3.10

# 激活環境
conda activate train_env

# 安裝依賴包
pip install streamlit torch torchvision

# 安裝Qwen2建議的transformers版本
pip install git+https://github.com/huggingface/transformers

6. 準備訓練框架

下載並安裝LLamaFactory框架的具體步驟,請見【課程總結】day24(上):大模型三階段訓練方法(LLaMa Factory)中“準備訓練框架”部分內容,本章不再贅述。

6.1 修改LLaMaFactory源碼以適配transformer

由於Qwen2-VL使用的transformer的版本為4.47.0.dev0,LLamaFactory還不支持,所以需要修改LLaMaFactory的代碼,具體方法如下:

  1. 第一步:在llamafactory源碼中,找到check_dependencies()函數,這個函數位於src/llamafactory/extras/misc.py文件的第82行。
  2. 第二步:修改check_dependencies()函數並保存:
# 原始代碼
require_version("transformers>=4.41.2,<=4.45.2", "To fix: pip install transformers>=4.41.2,<=4.45.2")

# 修改後代碼
require_version("transformers>=4.41.2,<=4.47.0", "To fix: pip install transformers>=4.41.2,<=4.47.0")
  1. 第三步:重新啓動LLaMaFactory服務
llamafactory-cli webui

這個過程可能會提示ImportError: accelerate>=0.34.0 is required for a normal functioning of this module, but found accelerate==0.32.0.如遇到上述問題,可以重新安裝accelerate,如下:

# 卸載舊的 accelerate
pip uninstall accelerate

# 安裝新的 accelerate
pip install accelerate==0.34.0

7. 測試當前模型

  1. 第一步:啓動LLaMa Factory後,訪問http://0.0.0.0:7860
  2. 第二步:在web頁面配置模型路徑為4.步驟下載的模型路徑,並點擊加載模型

項目實戰:LLaMaFactory和Qwen2-VL-2B微調大模型實戰_數據集_03

  1. 第三步:上傳一張CT圖片並輸入問題:“請使用中文描述下這個圖像並給出你的診斷結果”

項目實戰:LLaMaFactory和Qwen2-VL-2B微調大模型實戰_數據集_04

由上圖可以看到,模型能夠識別到這是一個CT圖像,顯示了大概的位置以及相應的器官,但是並不能給出是否存在診斷結果。

8. 模型訓練

8.1 數據準備
  1. 第一步:將3.2步驟生成的mllm_data文件拷貝到LLaMaFactory的data目錄下
  2. 第二步:將4.步驟下載的底座模型Qwen2-VL拷貝到LLaMaFactory的model目錄下
  3. 第三步:修改LLaMaFactory data目錄下的dataset_info.json,增加自定義數據集:
"mllm_med": {
    "file_name": "mllm_data/mllm_data.json",
    "formatting": "sharegpt",
    "columns": {
        "messages": "messages",
        "images": "images"
    },
    "tags": {
        "role_tag": "role",
        "content_tag": "content",
        "user_tag": "user",
        "assistant_tag": "assistant"
    }
}
8.2 配置訓練參數

訪問LLaMaFactory的web頁面,配置微調的訓練參數:

  • Model name: Qwen2-VL-2B-Instruct
  • Model path: models/Qwen2-VL-2B-Instruct
  • Finetuning method: lora
  • Stage: Supervised Fine-Tuning
  • Dataset: mllm_med
  • Output dir: saves/Qwen2-VL/lora/Qwen2-VL-sft-demo1

配置參數中最好將save_steps設置大一點,否則訓練過程會生成非常多的訓練日誌,導致硬盤空間不足而訓練終止。

項目實戰:LLaMaFactory和Qwen2-VL-2B微調大模型實戰_模態_05

點擊Preview Command預覽命令行無誤後,點擊Run按鈕開始訓練。

訓練參數

llamafactory-cli train \
    --do_train True \
    --model_name_or_path models/Qwen2-VL-2B-Instruct \
    --preprocessing_num_workers 16 \
    --finetuning_type lora \
    --template qwen2_vl \
    --flash_attn auto \
    --dataset_dir data \
    --dataset mllm_med \
    --cutoff_len 1024 \
    --learning_rate 5e-05 \
    --num_train_epochs 3.0 \
    --max_samples 100000 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --max_grad_norm 1.0 \
    --logging_steps 5 \
    --save_steps 3000 \
    --warmup_steps 0 \
    --optim adamw_torch \
    --packing False \
    --report_to none \
    --output_dir saves/Qwen2-VL-2B/full/Qwen2-VL-sft-demo1 \
    --bf16 True \
    --plot_loss True \
    --ddp_timeout 180000000 \
    --include_num_input_tokens_seen True \
    --lora_rank 8 \
    --lora_alpha 16 \
    --lora_dropout 0 \
    --lora_target all

訓練過程

項目實戰:LLaMaFactory和Qwen2-VL-2B微調大模型實戰_模態_06

訓練的過程中,可以通過watch -n 1 nvidia-smi實時查看GPU顯存的消耗情況。

經過35小時的訓練,模型訓練完成,損失函數如下:

項目實戰:LLaMaFactory和Qwen2-VL-2B微調大模型實戰_模態_07

損失函數一般降低至1.2左右,太低會導致模型過擬合。

8.3 合併導出模型

接下來,我們將Lora補丁與原始模型合併導出:

  1. 切換到Expert標籤下
  2. Model path: 選擇Qwen2-VL的基座模型,即:models/Qwen2-VL-2B-Instruct
  3. Checkpoint path: 選擇lora微調的輸出路徑,即saves/Qwen2-VL/lora/Qwen2-VL-sft-demo1
  4. Export path:設置一個新的路徑,例如:Qwen2-VL-sft-final
  5. 點擊開始導出按鈕

導出完畢後,會在LLaMaFactory的根目錄下生成一個Qwen2-VL-sft-final的文件夾。

9. 模型驗證

9.1 模型效果對比
  1. 第一步:在LLaMa Factory中卸載之前的模型
  2. 第二步:在LLaMa Factory中加載導出的模型,並配置模型路徑為Qwen2-VL-sft-final
  3. 第三步:加載模型並上傳之前的CT圖片提問同樣的問題

項目實戰:LLaMaFactory和Qwen2-VL-2B微調大模型實戰_模態_08

可以看到,經過微調後的模型,可以給出具體區域存在的可能異常問題。

9.2 實際診斷

接下來,我將使用微調後的模型,為家裏老爺子的CT片做診斷,看看模型給出的診斷與大夫的異同點。

項目實戰:LLaMaFactory和Qwen2-VL-2B微調大模型實戰_數據_09

項目實戰:LLaMaFactory和Qwen2-VL-2B微調大模型實戰_數據_10

項目實戰:LLaMaFactory和Qwen2-VL-2B微調大模型實戰_模態_11

我總計測試了CT片上的52張局部結果,其中具有代表性的為上述三張,可以看到模型還是比較準確地診斷出:腦部有垂體瘤,可能會影響到眼部。這與大夫給出的診斷和後續檢查方案一致。

不足之處

訓練集

  1. 多模態:本次訓練只是採用了MedTrinity-25Mdemo數據集,如果使用MedTrinity-25Mfull數據集,效果應該會更好。
  2. 中英文:本次訓練集中使用的MedTrinity-25Mdemo數據集,只包含了英文數據,如果將英文標註翻譯為中文,提供中英文雙文數據集,相信效果會更好。
  3. 對話數據集:本次訓練只是使用了多模態數據集,如果增加中文對話(如:中文醫療對話數據-Chinese-medical-dialogue),相信效果會更好。

前端頁面

  1. 前端頁面:本次實踐曾使用streamlit構建前端頁面,以便圖片上傳和問題提出,但是在加載微調後的模型時,會出現:ValueError: No chat template is set for this processor問題,所以轉而使用LLaMaFactory的web頁面進行展示。
  2. 多個圖片推理:在Qwen2-VL的官方指導文檔中,提供了Multi image inference方法,本次未進行嘗試,相信將多個圖片交給大模型進行推理,效果會更好。

內容小結

  1. Qwen2-VL-2B作為多模態大模型,具備有非常強的多模態處理能力,除了能夠識別圖片內容,還可以進行相關的推理。
  2. 我們可以通過LLaMaFactory對模型進行微調,使得其具備醫療方面的處理能力。
  3. 微調數據集採用開源的MedTrinity-25M數據集,該數據集有兩個版本:25Mdemo和25Mfull。
  4. 訓練前需要對數據集進行預處理,使得其適配LLaMaFactory的微調格式。
  5. 經過微調後的多模態大模型,不但可以詳細地描述圖片中的內容,還可以給出可能的診斷結果。