Huggingface 的 from_pretrained
huggingface
本文字数:1.2k 字 | 阅读时长 ≈ 5 min

Huggingface 的 from_pretrained

huggingface
本文字数:1.2k 字 | 阅读时长 ≈ 5 min

1. from_pretrained

from_pretrained 函数是 HuggingFace 中 Transformers 库的一个重要函数,能够轻松地加载预训练的模型和相关配置。from_pretrained 函数可以用于加载各种预训练的 NLP 模型,如 BERT、GPT-2、RoBERTa 等

例如,下面代码,我们选择了 bert-base-cased 模型,但是本地没有这个模型文件,就会自动下载到当前目录下的 ./bert 文件夹中,如果不给定 cache_dir 参数,则会自动下载到 /home/user_name/.cache/huggingface 文件夹下

model = AutoModel.from_pretrained("bert-base-cased", cache_dir="./bert")

2. from_pretrained 模型构建过程

如下是 from_pretrained 的代码

def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
    config = kwargs.pop("config", None)
    trust_remote_code = kwargs.pop("trust_remote_code", None)
    kwargs["_from_auto"] = True
    hub_kwargs_names = [
        "cache_dir",
        "code_revision",
        "force_download",
        "local_files_only",
        "proxies",
        "resume_download",
        "revision",
        "subfolder",
        "use_auth_token",
    ]
    hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
    if not isinstance(config, PretrainedConfig):
        kwargs_orig = copy.deepcopy(kwargs)
        # ensure not to pollute the config object with torch_dtype="auto" - since it's
        # meaningless in the context of the config object - torch.dtype values are acceptable
        if kwargs.get("torch_dtype", None) == "auto":
            _ = kwargs.pop("torch_dtype")

        config, kwargs = AutoConfig.from_pretrained(
            pretrained_model_name_or_path,
            return_unused_kwargs=True,
            trust_remote_code=trust_remote_code,
            **hub_kwargs,
            **kwargs,
        )

        # if torch_dtype=auto was passed here, ensure to pass it on
        if kwargs_orig.get("torch_dtype", None) == "auto":
            kwargs["torch_dtype"] = "auto"

    has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
    has_local_code = type(config) in cls._model_mapping.keys()
    trust_remote_code = resolve_trust_remote_code(
        trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
    )
    if has_remote_code and trust_remote_code:
        class_ref = config.auto_map[cls.__name__]
        model_class = get_class_from_dynamic_module(
            class_ref, pretrained_model_name_or_path, **hub_kwargs, **kwargs
        )
        _ = hub_kwargs.pop("code_revision", None)
        if os.path.isdir(pretrained_model_name_or_path):
            model_class.register_for_auto_class(cls.__name__)
        else:
            cls.register(config.__class__, model_class, exist_ok=True)
        return model_class.from_pretrained(
            pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
        )
    elif type(config) in cls._model_mapping.keys():
        model_class = _get_model_class(config, cls._model_mapping)
        return model_class.from_pretrained(
            pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
        )
    raise ValueError(
        f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
        f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
    )

代码解析过程如下

  1. 获取 config
    下载的文件夹中有一个 config.json 文件,从中读取模型相关参数等,我们以 bert-base-cased 为例
config, kwargs = AutoConfig.from_pretrained(
    pretrained_model_name_or_path,
    return_unused_kwargs=True,
    trust_remote_code=trust_remote_code,
    **hub_kwargs,
    **kwargs,
)
{
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.6.0.dev0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 28996
}

如上所示,这里的 "model_type": "bert" 会决定选取哪一个 config 类,AutoConfig.from_pretrained 进入以后会有一句 config_class = CONFIG_MAPPING[config_dict["model_type"]] 代码,也就是说他会选取 bert 模型对应的类,即 <class 'transformers.models.bert.configuration_bert.BertConfig'>,随后就会用 config.json 文件里面的内容覆盖 BertConfig 类中的默认参数,最终得到一个 config 对象

[<class 'transformers.models.albert.configuration_albert.AlbertConfig'>,
 <class 'transformers.models.align.configuration_align.AlignConfig'>, 
 <class 'transformers.models.altclip.configuration_altclip.AltCLIPConfig'>, 
 <class 'transformers.models.audio_spectrogram_transformer.configuration_audio_spectrogram_transformer.ASTConfig'>, 
 <class 'transformers.models.autoformer.configuration_autoformer.AutoformerConfig'>, 
 ...]
  1. 初始化模型
    接下来通过 config 中的内容来初始化模型,即用此 BertConfig 类初始化模型,会执行到下面这句代码
model_class = _get_model_class(config, cls._model_mapping)
OrderedDict([('albert', 'AlbertModel'), ('align', 'AlignModel'), ('altclip', 'AltCLIPModel'), ('audio-spectrogram-transformer', 'ASTModel'), ('autoformer', 'AutoformerModel'), ('bark', 'BarkModel'), ('bart', 'BartModel'), ...])

在上述函数中,会将 BertConfig 和模型 type 在进行一个映射,从而得到 BertModel,最终会得到 <class 'transformers.models.bert.modeling_bert.BertModel'>
ok,到这里我们梳理一下 from_pretrained 函数从加载 config.json 文件到得到 BertModel 做了什么。首先,我们通过 config.json 文件中的 "model_type": "bert" 来确定使用哪一个 config 类,然后用 config 类初始化一个 config 对象,最后用 config 对象和模型进行一个映射,得到模型类。

  1. 使用 config.json 的配置参数初始化模型
    最终通过上述的 config 中的参数来初始化 BertModel 模型,得到我们最终的模型
return model_class.from_pretrained(
    pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
)

3. from_pretrained 参数加载过程

一般来说,下载完的配置文件有以下模型,在模型构建完后,会自动识别里面的 pytorch_model.bin.index.json 文件,并根据他来索引 bin 文件,如果 bin 参数中包含模型中没有的参数,最终构建完毕后还会将未加载的参数进行输出

pytorch_model-00001-of-00002.bin
pytorch_model-00002-of-00002.bin
pytorch_model.bin.index.json