Huggingface: LoRA
huggingface
Word count: 308 | Reading time ≈ 1 min


1. Installing the peft library

This post follows the Hugging Face fine-tuning guide. The peft library enables efficient fine-tuning of LLMs and integrates seamlessly with Transformers and Accelerate, so you can combine the most popular, high-performance models from Transformers with the simplicity and scalability of Accelerate.

conda create -n peft python=3.10
conda activate peft

# Install the required packages
pip install git+https://github.com/huggingface/transformers
pip install git+https://github.com/huggingface/accelerate
pip install git+https://github.com/huggingface/peft
conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=11.8 -c pytorch -c nvidia
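
After installation, a quick sanity check can confirm that the packages import correctly and that CUDA is visible (the versions printed will of course depend on your environment):

# Sanity check: confirm the packages import and CUDA is available.
import torch, transformers, peft, accelerate
print(torch.__version__, transformers.__version__, peft.__version__, accelerate.__version__)
print("CUDA available:", torch.cuda.is_available())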

2. Using the peft library

  1. Model training
from transformers import AutoModelForSeq2SeqLM
from peft import get_peft_model, LoraConfig, TaskType  # add
model_name_or_path = "bigscience/mt0-large"
tokenizer_name_or_path = "bigscience/mt0-large"

peft_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
)

model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
model = get_peft_model(model, peft_config)  # add: wrap the base model with LoRA adapters
model.print_trainable_parameters()  # add: only the LoRA parameters remain trainable
# output: trainable params: 2359296 || all params: 1231940608 || trainable%: 0.19151053100118282

model.save_pretrained("output_dir") 
# model.push_to_hub("my_awesome_peft_model") also works
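The snippet above only wraps the model and saves the adapter; the original guide does not show the fine-tuning loop itself. As a minimal sketch of a single training step on the LoRA-wrapped model (the translation pair and learning rate below are illustrative assumptions, not from the guide):

# Minimal single training step on the LoRA-wrapped seq2seq model (illustrative only).
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # frozen base params have no grads, so only LoRA params are updated

inputs = tokenizer("Translate to German: Hello, world!", return_tensors="pt")   # toy example
labels = tokenizer("Hallo, Welt!", return_tensors="pt").input_ids

model.train()
loss = model(**inputs, labels=labels).loss  # seq2seq loss computed by the model
loss.backward()
optimizer.step()
optimizer.zero_grad()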
  2. Model inference
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import PeftModel, PeftConfig  # add

device = "cuda" if torch.cuda.is_available() else "cpu"

peft_model_id = "smangrul/twitter_complaints_bigscience_T0_3B_LORA_SEQ_2_SEQ_LM"
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, peft_model_id)  # add: load the LoRA adapter on top of the base model
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

model = model.to(device)
model.eval()
inputs = tokenizer("Tweet text : @HondaCustSvc Your customer service has been horrible during the recall process. I will never purchase a Honda again. Label :", return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(input_ids=inputs["input_ids"].to("cuda"), max_new_tokens=10)
    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0])
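
Once inference works, the LoRA weights can optionally be folded back into the base model so it can be served as a plain Transformers model, without the peft wrapper. A minimal sketch (the output directory name is a placeholder):

# Optional: merge the LoRA adapter into the base weights.
merged_model = model.merge_and_unload()        # PeftModel method that merges LoRA into the base model
merged_model.save_pretrained("merged_model")   # placeholder output directory
tokenizer.save_pretrained("merged_model")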