mmlab 中 Registry 类
utils
本文字数:1.3k 字 | 阅读时长 ≈ 5 min

mmlab 中 Registry 类

utils
本文字数:1.3k 字 | 阅读时长 ≈ 5 min

1. Registry 的功能

mmlab 中的 registry 类主要用来对 model 中的 backbone、neck 或者 dataset、optimizer 等进行一个构建,维护一个全局 key-value 对,当我们想更换某一部分的时候直接更改即可

例如我通过 registry 注册 backbone 类中有 ResNet18、ResNet50、VGG19 等等,假设我们需要实例化 ResNet,只需要更改配置文件中的如下参数即可

backbone=dict(
    type='ResNet', # 待实例化的类名
    depth=50, # 后面的都是对于的类初始化参数
    num_stages=4,
    out_indices=(0, 1, 2, 3),
    frozen_stages=1,
    norm_cfg=dict(type='BN', requires_grad=True),
    norm_eval=True,
    style='pytorch'),

其中 type 就是我们要更换的 backbone 类型,随后的是 backbone 对应的参数

2. Registry 的简单实现

上面提到 registry 类目的是维护一个全局的 key-value 对,所以我们需要一个全局变量,并不断完善它

下面的例子均来源于mmlab 知乎

_module_dict = dict()

# 定义装饰器函数
def register_module(name):
    def _register(cls):
        _module_dict[name] = cls
        return cls
    return _register

# 装饰器用法
@register_module('one_class')
class OneTest(object):
    pass

@register_module('two_class')
class TwoTest(object):
    pass

if __name__ == '__main__':
    # 通过注册类名实现自动实例化功能
    one_test = _module_dict['one_class']()
    print(one_test)

'''
<__main__.OneTest object at 0x7f4eee6bdc70>
'''

对上面的例子进行解释,函数运行到 @register_module('one_class') 时,会对 OneTest 类进行注册,注册后全局字典 _module_dict 会多出一个参数,同样的,运行到 TwoTest 中时依然会对其进行注册,全部注册完后,_module_dict 的内容如下所示,正好是两个 key-value 值

{'one_class': <class '__main__.OneTest'>, 'two_class': <class '__main__.TwoTest'>}

3. Registry 类实现

registry 的类实现就是 mmcv 中所使用的方法,方法非常简洁,registry 类如下

class Registry:
    def __init__(self, name):
        self._name = name  # 类名, 例如backbone, neck等
        self._module_dict = dict()  # 全局key-value对

    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):  # 判断是否为类
            raise TypeError('module must be a class, ' f'but got {type(module_class)}')

        if module_name is None:  # 如果没有名, 默认用注册的类名
            module_name = module_class.__name__
        if not force and module_name in self._module_dict:  # 是否覆盖重名类
            raise KeyError(f'{module_name} is already registered ' f'in {self.name}')
        self._module_dict[module_name] = module_class  # 加入key-value对

    # 装饰器函数
    def register_module(self, name=None, force=False, module=None):
        if module is not None:  # 如果给定module,直接增加到字典中即可
            self._register_module(module_class=module, module_name=name, force=force)
            return module
        
        def _register(cls):  # 装饰器用法
            self._register_module(module_class=cls, module_name=name, force=force)
            return cls
        return _register

registry 类有两个属性,_name_module_dict,第一个是 registry 的名字,比如 backbone,第二个保存 key-value 属性,比如 resnet 和其对应类。其余两个函数就是用来注册字典属性的了,和之前的简单实现类似,还有一些特殊参数,下面举例看一下其使用

第一个例子

这个例子表示 register 装饰器可以自动命名,比如我们不想让他用默认的 Converter1 类名,想给他命名为 abc,这样字典就是{“abc”: Converter1()}了,此外 force 为 True 表示,如果命名产生了重复,则会覆盖之前的类

@CONVERTERS.register_module(name="abc", force=True)
class Converter1(object):
    def __init__(self, a, b):
        self.a = a
        self.b = b

第二个例子

下面这个例子表示不适用装饰器的方法,而是通过函数对其注册

CONVERTERS.register_module(module=Converter1())

4. Register 的调用

我们维护了这么多 register 类,那么是如何调用的呢,mmlab 给出了答案,只需要调用 build_from_cfg 函数即可,build_from_cfg(cfg, register, default_args=None) 有三个参数

下面我们看一个例子就很容易理解了

@CONVERTERS.register_module()
class Converter1(object):
    def __init__(self, a, b):
        self.a = a
        self.b = b

converter_cfg = dict(type='Converter1', a=1, b=2)
converter = build_from_cfg(converter_cfg, CONVERTERS)

如上所示,我们注册了 Converter1 这个 key,注册的类别为 CONVERTERS,build_from_cfg 第一个参数为 converter_cfg,表示我们想要 Converter1 这个 key,他的参数为 a 和 b,第二个参数为 CONVERTERS 表示我们要在 CONVERTERS 这个注册类中寻找 Converter1 这个 key,找到之后就自动用 ab 参数初始化了,是不是很简单?

下面我们看一下 build_from_cfg 内部是如何实现的

代码如下,实现非常简单,首先如果 default_args 有值,则将其与 cfg 合并,随后取出 type 这个参数找到相应的 key,如果有这个 key 则返回类,否则如果实例化了就直接返回,如果都没有就报错,最后找到这个类后,用其配置参数也就是我们上面说的 ab 来初始化他并返回~就本例而言,我们会直接返回一个用 a=1,b=1 初始化的 Converter1 类,随后使用即可

def build_from_cfg(cfg, registry, default_args=None):
    args = cfg.copy()

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type') # 注册 str 类名
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)  # 从dict中根据key获取value
        if obj_cls is None:
            raise KeyError(f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):  # 如果已经实例化了,那就直接返回
        obj_cls = obj_type
    else:
        raise TypeError(f'type must be a str or valid type, but got {type(obj_type)}')

    return obj_cls(**args)  # 根据给定cfg参数初始化对应类
9月 09, 2024
9月 06, 2024