1. Installing the peft library
This article follows the Hugging Face fine-tuning documentation. peft is a library for parameter-efficient fine-tuning of LLMs. It integrates seamlessly with Transformers and Accelerate, so you get the most popular, high-performance models from Transformers together with the simplicity and scalability of Accelerate.
conda create -n peft python=3.10
conda activate peft
# install the required packages
pip install git+https://github.com/huggingface/transformers
pip install git+https://github.com/huggingface/accelerate
pip install git+https://github.com/huggingface/peft
conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=11.8 -c pytorch -c nvidia
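After installation, a quick sanity check (a minimal sketch of my own, not from the original guide) confirms that the packages import and that CUDA is visible:

import torch, transformers, accelerate, peft

# Print versions and CUDA availability to confirm the environment is usable.
print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("transformers:", transformers.__version__)
print("accelerate:", accelerate.__version__)
print("peft:", peft.__version__)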
2. Using the peft library
- Model training
from transformers import AutoModelForSeq2SeqLM
from peft import get_peft_model, LoraConfig, TaskType  # PEFT imports, added on top of the usual Transformers ones

model_name_or_path = "bigscience/mt0-large"
tokenizer_name_or_path = "bigscience/mt0-large"

peft_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
)

model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
model = get_peft_model(model, peft_config)  # wrap the base model with LoRA adapters
model.print_trainable_parameters()  # only the adapter weights are trainable
# output: trainable params: 2359296 || all params: 1231940608 || trainable%: 0.19151053100118282

model.save_pretrained("output_dir")  # saves only the adapter weights, not the full model
# model.push_to_hub("my_awesome_peft_model") also works
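The PEFT-wrapped model trains like any other Transformers model. As a hedged sketch (the Trainer hyperparameters and train_dataset are illustrative assumptions, not part of the original example), training could look like:

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="output_dir",
    per_device_train_batch_size=8,
    learning_rate=1e-3,  # LoRA typically tolerates a higher LR than full fine-tuning
    num_train_epochs=3,
)
# train_dataset is assumed to be a tokenized seq2seq dataset you have prepared.
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()

Because only the LoRA parameters require gradients (about 0.19% of all parameters, per the output above), gradient and optimizer-state memory shrink accordingly.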
- Model inference
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer  # AutoTokenizer added; it is used below
from peft import PeftModel, PeftConfig
import torch

peft_model_id = "smangrul/twitter_complaints_bigscience_T0_3B_LORA_SEQ_2_SEQ_LM"
device = "cuda" if torch.cuda.is_available() else "cpu"

config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, peft_model_id)  # load the LoRA adapter on top of the base model
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

model = model.to(device)
model.eval()

inputs = tokenizer("Tweet text : @HondaCustSvc Your customer service has been horrible during the recall process. I will never purchase a Honda again. Label :", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(input_ids=inputs["input_ids"].to(device), max_new_tokens=10)
print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0])
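For deployment, peft's LoRA support also lets you fold the adapter weights back into the base model with merge_and_unload(), removing the adapter indirection at generation time. A minimal sketch (the output path is an assumption):

# Merge the LoRA weights into the base model's weights and drop the PEFT wrapper.
merged_model = model.merge_and_unload()
merged_model.save_pretrained("merged_model_dir")  # illustrative path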