Huggingface: Saving and Loading Weights
huggingface
Word count: 519 | Reading time ≈ 2 min


1. Ways to Save a Model

1.1. model.save_pretrained

from transformers import BertTokenizer, BertModel

# Download (or load from the local cache) the pretrained tokenizer and model.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir="./bert")
model = BertModel.from_pretrained('bert-base-uncased', cache_dir="./bert")

# Write the tokenizer files and model weights to a local directory.
tokenizer.save_pretrained('./saved_model')
model.save_pretrained('./saved_model')

Left: the files downloaded for bert; right: the files written by save_pretrained.
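The saved directory can then be loaded back the same way as a hub checkpoint. A minimal sketch, reusing the classes from above:

from transformers import BertTokenizer, BertModel

# Load the tokenizer files and weights back from the local directory.
tokenizer = BertTokenizer.from_pretrained('./saved_model')
model = BertModel.from_pretrained('./saved_model')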

1.2. trainer.save_model

During training, you can also save the model through the Trainer:

from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased", cache_dir="./bert_case")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", cache_dir="./bert_case", num_labels=2)

# Tokenize the MRPC sentence pairs.
def tokenize_function(examples):
    return tokenizer(examples["sentence1"], examples["sentence2"], padding="max_length", truncation=True)

dataset = load_dataset("glue", "mrpc")
tokenized_data = dataset.map(tokenize_function, batched=True)

training_args = TrainingArguments(output_dir="./results", num_train_epochs=1, per_device_train_batch_size=8, per_device_eval_batch_size=8)
trainer = Trainer(model=model, args=training_args, train_dataset=tokenized_data["train"], eval_dataset=tokenized_data["validation"])

trainer.train()
trainer.evaluate()
trainer.save_model("./saved_trainer")

The list of files saved by the Trainer.
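The directory written by save_model loads back with from_pretrained as well. A minimal sketch (note that the script above did not pass the tokenizer to the Trainer, so only the model files are in the directory):

from transformers import AutoModelForSequenceClassification

# Reload the fine-tuned classifier from the directory written by save_model.
model = AutoModelForSequenceClassification.from_pretrained("./saved_trainer")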

1.3. trainer._save
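save_model is largely a thin wrapper: on the main process it delegates to the private _save method, which writes the weights and config (and the tokenizer, if one was attached) to the output directory. You can call it directly, but since it is private, its behavior and signature may change between transformers versions. A hedged sketch:

# _save is internal API: unlike save_model it performs no distributed /
# main-process checks, so only call it from a single process.
trainer._save("./saved_via_private")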

2. Saving Checkpoints

from transformers import BertForSequenceClassification, Trainer, TrainingArguments

train_dataset = [{"input_ids": [0, 1, 2, 3, 4, 5], "labels": 1}] * 100  # dummy training data
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", cache_dir="./bert")

training_args = TrainingArguments(
    output_dir="b", num_train_epochs=1, per_device_train_batch_size=1,
    save_steps=10,       # save a checkpoint every 10 steps
    save_total_limit=2,  # keep at most 2 checkpoints
)

trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()
trainer.save_model("b")

In the TrainingArguments above, save_steps=10 writes a checkpoint directory (checkpoint-10, checkpoint-20, ...) under output_dir every 10 optimizer steps, and save_total_limit=2 keeps only the two most recent checkpoints, deleting older ones automatically.
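To resume training from one of these checkpoints, Trainer.train accepts a resume_from_checkpoint argument:

# Resume from the most recent checkpoint found under output_dir.
trainer.train(resume_from_checkpoint=True)

# Or name a specific checkpoint directory (path shown for illustration).
trainer.train(resume_from_checkpoint="b/checkpoint-10")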

3. Saving the Trainer State

trainer.save_state()

This just writes a JSON file (trainer_state.json in the output directory) recording training information such as the global step, the log history, and the best checkpoint.
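Since it is plain JSON, the file is easy to inspect. A minimal sketch, assuming save_state was called on the Trainer from the checkpoint example (output_dir="b"):

import json

# trainer_state.json is written into the Trainer's output_dir.
with open("b/trainer_state.json") as f:
    state = json.load(f)

print(state["global_step"])  # number of optimizer steps taken
print(state["log_history"])  # logged metrics over the course of training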

4. Loading safetensors

import torch
from safetensors import safe_open

tensors = {}

# Load every tensor in the file directly onto GPU 0.
with safe_open("saved_trainer/model.safetensors", framework="pt", device=0) as f:
    for k in f.keys():
        tensors[k] = f.get_tensor(k)

# # Or load onto CPU instead:
# with safe_open("saved_trainer/model.safetensors", framework="pt", device="cpu") as f:
#     for k in f.keys():
#         tensors[k] = f.get_tensor(k)

# The Trainer also serializes its TrainingArguments; restore them with torch.load
# (newer PyTorch versions may require weights_only=False here).
args = torch.load('saved_trainer/training_args.bin')
print(args)
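For the common case of loading everything onto one device, safetensors.torch.load_file does the same job in one call, and the resulting dict can be fed straight into load_state_dict:

from safetensors.torch import load_file

# Load the whole file into a {name: tensor} dict (CPU by default).
state_dict = load_file("saved_trainer/model.safetensors")
# model.load_state_dict(state_dict)  # restore into a matching architecture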