pytorch 中输出模型参数名和梯度的一些操作
pytorch
本文字数:406 字 | 阅读时长 ≈ 1 min

pytorch 中输出模型参数名和梯度的一些操作

pytorch
本文字数:406 字 | 阅读时长 ≈ 1 min

在 PyTorch 中,我们经常需要查看和分析模型的参数信息。本文将介绍几个常用的参数查看和统计方法。

1. 查看模型参数信息

1.1 查看需要梯度的参数

以下代码可以列出模型中所有需要计算梯度的参数名称:

def print_trainable_parameters(model):
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"Trainable parameter: {name}")

1.2 查看非 GPU 参数

检查哪些参数没有被加载到 GPU 上:

def print_cpu_parameters(model):
    for name, param in model.named_parameters():
        if not param.is_cuda:
            print(f"Parameter on CPU: {name}")

1.3 查看 Float32 类型参数

列出所有 float32 类型的参数:

def print_float32_parameters(model):
    for name, param in model.named_parameters():
        if param.dtype == torch.float32:
            print(f"Float32 parameter: {name}")

2. 统计模型参数量

以下代码可以计算模型的总参数量,并转换为十亿级别显示:

def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    total_params_billion = total_params / 1e9
    
    # 可选:打印每层参数详情
    for name, param in model.named_parameters():
        print(f"{name}: {param.shape}, 参数量: {param.numel():,}")

    print(f"模型参数统计:")
    print(f"总参数量: {total_params:,}")
    print(f"十亿级参数量: {total_params_billion:.3f}B")

使用示例

# 假设已经定义了模型
model = YourModel()

# 打印需要训练的参数
print_trainable_parameters(model)

# 统计模型参数量
count_parameters(model)

注意事项

9月 09, 2024
9月 06, 2024