wandb 记录 loss 和 image
image
本文字数:1.7k 字 | 阅读时长 ≈ 7 min

wandb 记录 loss 和 image

image
本文字数:1.7k 字 | 阅读时长 ≈ 7 min

wandb 是一个可视化库,他有 visdom 实时查看的优点以及能够永久的记录到云端

1. wandb 安装

登陆https://wandb.ai注册一个 wandb 账号,然后运行下面命令安装 wandb。使用时在终端输入 wandb login 按照提示输入即可

pip install wandb

使用手册

对于 PyTorch:Quick start
Document:使用手册

2. loss 曲线记录

2.1 step 记录

wandb 默认的是按照 step 记录的,也就是说每执行一次命令,就记录一次 step,下面举个例子说明

注意:下面例子均需要修改 wandb.init 内部的值,自己写个简单的 demo 修改即可

''' parser '''
parser = argparse.ArgumentParser(description='your description')
parser.add_argument('--xx', default=xx)
opt = parser.parse_args()

''' init wandb '''
wandb.init(project="project nmae", entity="your account", config=opt)

# 添加了loss曲线
for epoch in range(1, 20):
	wandb.log({"train loss": 10/(3*epoch)})
	wandb.log({"test loss": 10/(2*epoch)})

运行结果如下,横坐标为 step,从 0 开始,每运行一次 wandb.log 命令 step+1,也就是说 train loss 的曲线横坐标为[0, 2, 4, …],test loss 的曲线横坐标为[1, 3, 5, …],在训练网络的时候我们需要将横坐标改为 epoch 或者其他我们需要的值,怎么操作呢?看第二步

2.2 custom x-axis

这里提供两种自定义 x 轴的方法,其中第一种较为简单(推荐),第二种也可以学习一下

方法一(推荐)

只需要在 wandb.log 中多记录一条 epoch 即可,代码如下

for epoch in range(1, 20):
    wandb.log({"train loss": 10/(3*epoch), 'epoch':epoch})
    wandb.log({"test loss": 10/(2*epoch), 'epoch':epoch})

运行后结果如下,多了一条 epoch 曲线,但是 loss 曲线并没有变化啊,别急,看第二幅图,只需要点击右上角的 x 从 step 改变为 epoch 即可

改变后的曲线如下

方法二

此方法要使用 wandb.define_metric 方法,具体使用直接看例子就很容易明白,函数内的 step_metric 意思是将 custom_epoch 添加到 train loss 的横坐标当中去

''' init wandb '''
wandb.init(project="project nmae", entity="your account", config=opt)
wandb.define_metric("custom_epoch")
wandb.define_metric("train loss", step_metric='custom_epoch')

# 添加了loss曲线
for epoch in range(1, 20):
    wandb.log({"custom_epoch": epoch})
	wandb.log({"train loss": 10/(3*epoch)})
	wandb.log({"test loss": 10/(2*epoch)})

运行结果如下,运行完后需要点击 train loss 右上角的小笔,把他的 x 轴改为 custom_epoch 不然不会生效

3. Image 记录

官方教程:https://docs.wandb.ai/guides/track/log/media

3.1 单张 Image

这里 numpy 为 uint8、[0, 255]、HWC 类型,使用 wandb.Image 将其转化为 wandb 能够记录的图片,其中 caption 为图片的标题,即图片的名字,然后用 wandb.log 记录即可

''' init wandb '''
wandb.init(project="project nmae", entity="your account", config=opt)

img = np.arange(1, 301).reshape(10, 10, 3)
Img = wandb.Image(img, caption="I am an image")
wandb.log({"log an image": Img})

效果如下,不同于曲线,图片记录在 Media 模块下

3.2 多张 Image

如果我们想记录多张 image 怎么办呢?例如我们需要记录每个 epoch 的图片并查看他的变化

还记得 loss 曲线中讨论的 step 吗,wandb.log 每运行一次 step 就会加 1,对于曲线的绘制我们可以改变它的横坐标,但是图片的横坐标改变不了(我尝试过,怎么也改不了)。如果我们记录了 100 张图片,我们很难知道他们是哪个 epoch 的结果,需要从头数,这样很麻烦

下面依然给出两个解决方案

方法一(推荐)

先直接给出代码看效果,我们在记录图片的同时运行了 wandb.log 记录 loss,当我们运行到 epoch=10 时,我们的 step 到了 27(如下图),图片的索引(即 step)没法修改为 epoch,所以我们在存储图片的时候直接修改 caption 即可知道是哪个 epoch 的图片了

''' init wandb '''
wandb.init(project="project nmae", entity="your account", config=opt)

for epoch in range(1, 20):
    img = np.arange(1, 301).reshape(10, 10, 3)
    Img = wandb.Image(img, caption="epoch:{}".format(epoch))  # attention!!!
    wandb.log({"loglog": Img})
    wandb.log({"train loss": 10/(3*epoch)})
    wandb.log({"test loss": 10/(2*epoch)})

方法二
这个方法比较简单粗暴。。就是把所有的图片都记录下来。。这样会有很多个框框,直接看代码和运行结果吧,有多少个 epoch,在 media 下就有多少张图片,会显得比较冗杂

for epoch in range(1, 20):
    img = np.arange(1, 301).reshape(10, 10, 3)
    Img = wandb.Image(img, caption="I am an image")  
    wandb.log({"loglog{}".format(epoch): Img})							# attention!!!
    wandb.log({"train loss": 10/(3*epoch)})
    wandb.log({"test loss": 10/(2*epoch)})

两个方法对比

其实这两个方法就是命名略有不同,一个是命名在图片备注上,一个是明明在 log 中

''' 方法一 '''
Img = wandb.Image(img, caption="epoch:{}".format(epoch))  # here!!!
wandb.log({"loglog": Img})
''' 方法二 '''
Img = wandb.Image(img, caption="I am an image")  
wandb.log({"loglog{}".format(epoch): Img})			      # here!!!

3.3 Image 拼接

一般来说我们在一个 epoch 的时候想看看 GT 与 Rec 的区别,这就需要将两张图或者更多图拼接,效果基本上如下所示,直接给出代码

转化为 Numpy [255, unit8, RGB]

# 将tensor转为cpu
img, gt = img[0].detach().cpu(), gt[0].detach().cpu()
# 如果获取数据时用了-mean/var操作,后续需要在调回去
img = img*0.5+0.5
gt  = gt*0.5+0.5
# concat并转为numpy 255 uint8
ori_bgc = np.concatenate([img, gt], axis=2)
ori_bgc = np.transpose(ori_bgc, (1, 2, 0))
ori_bgc = Image.fromarray(np.clip(ori_bgc * 255.0, 0, 255.0).astype('uint8'))
# caption and log info
LtoH    = wandb.Image(ori_bgc, caption="epoch{}: results".format(epoch))
wandb.log({"LtoR: Origin BGColor Results": LtoH})

4. distribute 注意事项

注意在分布式训练中我们一般只记录一个 GPU 上的内容,也就是说如果我们将 images 进行记录并不会记录所有的图片。举个例子:我们验证集一共有 24 张,我们用了 6 张卡,在数据分配的时候每张卡分配了 4 个数据,在进行验证记录每个 epoch 结果的时候一般来说我们只将标号为 0 的那张卡进行记录,所以只记录了 4 幅图,并不会记录所有的 24 张

那么如果我们想记录所有的图怎么办?那就让所有的卡(0,1,2,3,4,5)都去记录,这样的一个缺点是最后会在 project 下生成 6 个记录,而不是一个

9月 09, 2024