1. from_pretrained
from_pretrained
函数是 HuggingFace 中 Transformers 库的一个重要函数,能够轻松地加载预训练的模型和相关配置。from_pretrained
函数可以用于加载各种预训练的 NLP 模型,如 BERT、GPT-2、RoBERTa 等
- pretrained_model_name_or_path:例如"bert-base-cased",或是已经下载好的 bert 的本地路径,例如"./bert"
- cache_dir:缓存目录,只给定了模型名字"bert-base-cased",则会自动下载到当前缓存路径
- low_cpu_mem_usage:一般是想将模型加载到 gpu,并且让此过程中最小的 cpu 内存消耗,以满足在小内存、大显存机器上的加载。默认为 False
例如,下面代码,我们选择了 bert-base-cased
模型,但是本地没有这个模型文件,就会自动下载到当前目录下的 ./bert
文件夹中,如果不给定 cache_dir 参数,则会自动下载到 /home/user_name/.cache/huggingface
文件夹下
model = AutoModel.from_pretrained("bert-base-cased", cache_dir="./bert")
2. from_pretrained 模型构建过程
如下是 from_pretrained 的代码
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
config = kwargs.pop("config", None)
trust_remote_code = kwargs.pop("trust_remote_code", None)
kwargs["_from_auto"] = True
hub_kwargs_names = [
"cache_dir",
"code_revision",
"force_download",
"local_files_only",
"proxies",
"resume_download",
"revision",
"subfolder",
"use_auth_token",
]
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
if not isinstance(config, PretrainedConfig):
kwargs_orig = copy.deepcopy(kwargs)
# ensure not to pollute the config object with torch_dtype="auto" - since it's
# meaningless in the context of the config object - torch.dtype values are acceptable
if kwargs.get("torch_dtype", None) == "auto":
_ = kwargs.pop("torch_dtype")
config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path,
return_unused_kwargs=True,
trust_remote_code=trust_remote_code,
**hub_kwargs,
**kwargs,
)
# if torch_dtype=auto was passed here, ensure to pass it on
if kwargs_orig.get("torch_dtype", None) == "auto":
kwargs["torch_dtype"] = "auto"
has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
has_local_code = type(config) in cls._model_mapping.keys()
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
)
if has_remote_code and trust_remote_code:
class_ref = config.auto_map[cls.__name__]
model_class = get_class_from_dynamic_module(
class_ref, pretrained_model_name_or_path, **hub_kwargs, **kwargs
)
_ = hub_kwargs.pop("code_revision", None)
if os.path.isdir(pretrained_model_name_or_path):
model_class.register_for_auto_class(cls.__name__)
else:
cls.register(config.__class__, model_class, exist_ok=True)
return model_class.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
)
elif type(config) in cls._model_mapping.keys():
model_class = _get_model_class(config, cls._model_mapping)
return model_class.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
)
raise ValueError(
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
)
代码解析过程如下
- 获取 config
下载的文件夹中有一个config.json
文件,从中读取模型相关参数等,我们以 bert-base-cased 为例
config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path,
return_unused_kwargs=True,
trust_remote_code=trust_remote_code,
**hub_kwargs,
**kwargs,
)
{
"architectures": [
"BertForMaskedLM"
],
"attention_probs_dropout_prob": 0.1,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"position_embedding_type": "absolute",
"transformers_version": "4.6.0.dev0",
"type_vocab_size": 2,
"use_cache": true,
"vocab_size": 28996
}
如上所示,这里的 "model_type": "bert"
会决定选取哪一个 config 类,AutoConfig.from_pretrained
进入以后会有一句 config_class = CONFIG_MAPPING[config_dict["model_type"]]
代码,也就是说他会选取 bert 模型对应的类,即 <class 'transformers.models.bert.configuration_bert.BertConfig'>
,随后就会用 config.json 文件里面的内容覆盖 BertConfig 类中的默认参数,最终得到一个 config 对象
[<class 'transformers.models.albert.configuration_albert.AlbertConfig'>,
<class 'transformers.models.align.configuration_align.AlignConfig'>,
<class 'transformers.models.altclip.configuration_altclip.AltCLIPConfig'>,
<class 'transformers.models.audio_spectrogram_transformer.configuration_audio_spectrogram_transformer.ASTConfig'>,
<class 'transformers.models.autoformer.configuration_autoformer.AutoformerConfig'>,
...]
- 初始化模型
接下来通过 config 中的内容来初始化模型,即用此BertConfig
类初始化模型,会执行到下面这句代码
model_class = _get_model_class(config, cls._model_mapping)
OrderedDict([('albert', 'AlbertModel'), ('align', 'AlignModel'), ('altclip', 'AltCLIPModel'), ('audio-spectrogram-transformer', 'ASTModel'), ('autoformer', 'AutoformerModel'), ('bark', 'BarkModel'), ('bart', 'BartModel'), ...])
在上述函数中,会将 BertConfig 和模型 type 在进行一个映射,从而得到 BertModel
,最终会得到 <class 'transformers.models.bert.modeling_bert.BertModel'>
ok,到这里我们梳理一下 from_pretrained 函数从加载 config.json 文件到得到 BertModel 做了什么。首先,我们通过 config.json 文件中的 "model_type": "bert"
来确定使用哪一个 config 类,然后用 config 类初始化一个 config 对象,最后用 config 对象和模型进行一个映射,得到模型类。
- 使用 config.json 的配置参数初始化模型
最终通过上述的 config 中的参数来初始化BertModel
模型,得到我们最终的模型
return model_class.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
)
3. from_pretrained 参数加载过程
一般来说,下载完的配置文件有以下模型,在模型构建完后,会自动识别里面的 pytorch_model.bin.index.json
文件,并根据他来索引 bin 文件,如果 bin 参数中包含模型中没有的参数,最终构建完毕后还会将未加载的参数进行输出
pytorch_model-00001-of-00002.bin
pytorch_model-00002-of-00002.bin
pytorch_model.bin.index.json
本文由 Yonghui Wang 创作,采用
知识共享署名4.0
国际许可协议进行许可
本站文章除注明转载/出处外,均为本站原创或翻译,转载前请务必署名
最后编辑时间为:
Dec 19, 2024 12:13 pm