pytorch 设计随机种子
pytorch
本文字数:140 字 | 阅读时长 ≈ 1 min

pytorch 设计随机种子

pytorch
本文字数:140 字 | 阅读时长 ≈ 1 min

1. PyTorch 设置随机种子

在进行网络训练的时候为了之后可以成功复现当前结果,需要设置随机种子

废话少说,直接上函数,在 train.py 最初调用此函数即可

def init_seeds(seed=0):
    random.seed(seed)  # seed for module random
    np.random.seed(seed)  # seed for numpy
    torch.manual_seed(seed)  # seed for PyTorch CPU
    torch.cuda.manual_seed(seed)  # seed for current PyTorch GPU
    torch.cuda.manual_seed_all(seed)  # seed for all PyTorch GPUs
    if seed == 0:
        # if True, causes cuDNN to only use deterministic convolution algorithms. 
        torch.backends.cudnn.deterministic = True
        # if True, causes cuDNN to benchmark multiple convolution algorithms and select the fastest.
        torch.backends.cudnn.benchmark = False
9月 09, 2024