动态

详情 返回 返回

Google開源Tunix:JAX生態的LLM微調方案來了 - 动态 详情

JAX生態這兩年在LLM訓練這塊追趕得挺快。PyTorch雖然還是主流但JAX在並行計算、TPU加速和API組合性上確實有些獨特的優勢。Google今天放出了Tunix這個庫,專門做LLM的後訓練——微調、強化學習、知識蒸餾這些都能搞。

Tunix是什麼

這是個構建在JAX之上的後訓練庫,和Flax NNX集成得比較緊密。主要解決三類問題:

  • 監督微調(Supervised Fine-Tuning)
  • 強化學習(Reinforcement Learning)
  • 知識蒸餾(Knowledge Distillation)

現在還在早期開發階段,功能在持續迭代,支持的模型也在慢慢擴展。

核心功能

監督微調:既支持全參數微調,也支持LoRA和Q-LoRA這類參數高效的方法。內存和算力受限的時候,PEFT方案還是挺實用的。

強化學習:實現了幾個主流算法:PPO(Proximal Policy Optimization)、GRPO(Group Relative Policy Optimization)、還有token級別的GSPO。另外還有DPO(Direct Preference Optimization)做偏好對齊,這個在RLHF場景用得比較多。

知識蒸餾:支持幾種策略,包括基於logit的概率分佈匹配、注意力機制的轉移和投影、跨架構的特徵池化與投影。這幾種方法在不同場景下各有用處。

庫的設計比較模塊化,組件可以自由組合,想擴展自定義流程也不算麻煩。分佈式訓練支持數據並行(DP)、完全分片數據並行(FSDP)和張量並行(TP),對TPU做了專門優化。

安裝

三種裝法:

從PyPI裝(推薦):

 pip install "tunix[prod]"

或者直接從GitHub主分支:

 pip install git+https://github.com/google/tunix

開發模式從源碼裝:

 git clone https://github.com/google/tunix.git  
 cd tunix  
 pip install -e".[dev]"

TPU上用QLoRA微調Gemma

拿個英譯法的任務來演示。用的是Google的Gemma 2B模型,跑在TPU v5e-8上。

環境準備

 pip install -q kagglehub safetensors tensorflow tensorflow_datasets tensorboardX transformers grain datasets  
 pip install -q git+https://github.com/google/tunix  
 pip install -q git+https://github.com/google/qwix  
   
 # Flax需要升級到最新版
 pip uninstall -q -y flax  
 pip install -q git+https://github.com/google/flax.git

完整流程

第一步,從Kaggle拉預訓練checkpoint:

 import kagglehub  
   
 model_path = "google/gemma/flax/2b"  
 kaggle_ckpt_path = kagglehub.model_download(model_path)

初始化模型和tokenizer:

 from flax import nnx  
from tunix.models.gemma import model as gemma_lib, params as params_lib  
from tunix.generate import tokenizer_adapter as tokenizer_lib  

base_model = gemma_lib.Transformer.from_params(  
    params_lib.load_and_format_params(kaggle_ckpt_path, "2b"),  
    version="2b"  
)  
 tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=f"{kaggle_ckpt_path}/tokenizer.model")

掛上QLoRA adapter:

 import qwix  

lora_provider = qwix.LoraProvider(  
    module_path=".*(q_einsum|kv_einsum|proj)",  
    rank=16,  
    alpha=2.0,  
    weight_qtype="nf4"  # enable QLoRA quantization
)  
 lora_model = qwix.apply_lora_to_model(base_model, lora_provider)

這裏rank設成16,alpha是2.0,weight_qtype指定nf4量化格式。

加載訓練數據:

 from tunix.examples.data import translation_dataset  

train_ds, validation_ds = translation_dataset.create_datasets(  
    dataset_name="mtnt/en-fr",  
    global_batch_size=16,  
    max_target_length=256,  
    num_train_epochs=3,  
    tokenizer=tokenizer,  
 )

用的是mtnt的英法平行語料,batch size 16,目標序列最長256個token。

開始訓練:

 from tunix.sft import peft_trainer, utils  
import optax  

trainer=peft_trainer.PeftTrainer(  
    lora_model,  
    optimizer=optax.adamw(1e-3),  
    config=peft_trainer.TrainingConfig(max_steps=100)  
)  
 trainer.train(train_ds, validation_ds)

優化器用AdamW,學習率1e-3,跑100步看看效果。

推理測試:

訓練完直接用adapter過的模型做生成。Tunix提供了Sampler工具:

 from tunix.generate import sampler as sampler_lib  

# initialize sampler
sampler = sampler_lib.Sampler(  
    transformer=lora_model,  
    tokenizer=tokenizer,  
    cache_config=sampler_lib.CacheConfig(  
        cache_size=256,  
        num_layers=base_model.num_layers,  
        num_kv_heads=base_model.num_kv_heads,  
        head_dim=base_model.head_dim,  
    ),  
)  

# test prompts
input_batch = [  
    "Translate this into French:\nHello, my name is Morgane.\n",  
    "Translate this into French:\nThis dish is delicious!\n",  
    "Translate this into French:\nI am a student.\n",  
    "Translate this into French:\nHow's the weather today?\n",  
]  

# generate predictions
out_data = sampler(  
    input_strings=input_batch,  
    max_generation_steps=20,  
)  

# print results
for input_string, out_string in zip(input_batch, out_data.text):  
    print(f"----------------------")  
    print(f"Prompt:\n{input_string}")  
     print(f"Output:\n{out_string}")

如果用的是QLoRA,把lora_model換成qlora_model就行。生產環境可以考慮把adapter合併回基模型,推理延遲能降下來。

總結

100步訓練之後,模型已經能生成一些翻譯結果了,雖然質量還不夠好。多訓練一段時間,準確率會明顯提升,而且內存開銷和訓練速度都保持在不錯的水平。

Tunix現在還比較新,但已經能看出一些潛力。TPU優先的設計、模塊化的API、LoRA/QLoRA支持、完整的分佈式訓練策略,這些對做LLM適配研究的人來説都挺有用。

後續應該會繼續擴展支持的模型類型和訓練算法,值得關注。

地址:https://avoid.overfit.cn/post/c434311d8a894922b6c52ea179cf8d97

作者:Abish Pius

user avatar leeqvip 头像 skysailstar 头像
点赞 2 用户, 点赞了这篇动态!
点赞

Add a new 评论

Some HTML is okay.