ToPILImage&&ToTensor
pytorch
Word count: 563 | Reading time ≈ 2 min

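When building a dataset in PyTorch, you will often see two transforms: transforms.ToPILImage and transforms.ToTensor. The first converts a numpy array (or a tensor) into a PIL Image; the second converts a PIL Image (or a numpy array) into a tensor, which is convenient for computation. Keep the following points in mind when converting.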

An image can exist in three forms: PIL Image, tensor, and numpy array.

import cv2
import torchvision.transforms as transforms
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image


'''PIL image'''
image = Image.open("./food-11/0000.jpg")
print(image)  # <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=384x512 at 0x7F1CFE47DD60>
image2 = transforms.ToTensor()(image)  # convert to a tensor for computation
print(image2)  # tensor values lie in [0, 1]
print(image2.dtype)  # torch.float32
print(image2.shape)  # torch.Size([3, 512, 384])
# to display a tensor image with OpenCV, keep the following in mind:
# OpenCV works on numpy arrays with dtype uint8 and pixel range [0, 255]
# after ToTensor, the tensor has dtype float32 and range [0, 1]
# channel order: PIL and torch use RGB, OpenCV uses BGR
# dimension order: torch is CHW, numpy is HWC
# show the tensor with OpenCV
tensor_ = image2
array_ = tensor_.numpy()  # convert tensor to numpy
array_ = array_ * 255  # scale pixels from [0, 1] to [0, 255]
array_convert = np.rint(array_).astype(np.uint8)  # round, then cast to the uint8 that imshow expects
print("array_convert:", array_convert.shape)  # array_convert: (3, 512, 384)
array_convert = array_convert.transpose(1, 2, 0)  # CHW -> HWC: (512, 384, 3)
cv2.imshow("image", array_convert)  # colors differ from the original: OpenCV reads the RGB data as BGR
cv2.waitKey()
array_convert = cv2.cvtColor(array_convert, cv2.COLOR_RGB2BGR)  # convert RGB to BGR
cv2.imshow("image2", array_convert)  # now matches the original image
cv2.waitKey()



'''opencv'''
img = cv2.imread("./food-11/0000.jpg")
print(img)  # numpy array with values in [0, 255]
print(type(img))  # <class 'numpy.ndarray'>
print(img.dtype)  # uint8
print(img.shape)  # (512, 384, 3)
cv2.imshow("img", img)
cv2.waitKey()
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # without this, red and blue are swapped in the PIL image
img2 = transforms.ToPILImage()(img)
img2.show()
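
The layout and range rules noted in the comments above (HWC vs. CHW, [0, 255] vs. [0, 1], RGB vs. BGR) can be sketched with numpy alone. This is only an illustration on a synthetic array, not torchvision's or OpenCV's implementation; the helper names to_chw_float, to_hwc_uint8 and swap_rgb_bgr are hypothetical.

```python
import numpy as np


def to_chw_float(hwc_uint8):
    """Mimic the layout/range change of ToTensor: HWC uint8 -> CHW float32 in [0, 1]."""
    return hwc_uint8.transpose(2, 0, 1).astype(np.float32) / 255.0


def to_hwc_uint8(chw_float):
    """Reverse step: CHW float in [0, 1] -> HWC uint8 in [0, 255]."""
    scaled = np.clip(np.rint(chw_float * 255.0), 0, 255)
    return scaled.astype(np.uint8).transpose(1, 2, 0)


def swap_rgb_bgr(hwc):
    """Reverse the channel axis; the same operation maps RGB -> BGR and BGR -> RGB."""
    return hwc[..., ::-1]


# Synthetic RGB image (H=4, W=6), used instead of a file on disk.
rng = np.random.default_rng(0)
rgb = rng.integers(0, 256, size=(4, 6, 3), dtype=np.uint8)

chw = to_chw_float(rgb)        # shape (3, 4, 6), float32
restored = to_hwc_uint8(chw)   # shape (4, 6, 3), uint8
bgr = swap_rgb_bgr(restored)   # channel order an OpenCV display call would expect

print(chw.shape, chw.dtype)
print(np.array_equal(restored, rgb))
```

Rounding before the cast matters: casting a float straight to uint8 truncates, so a value such as 127.99999 would become 127 instead of 128.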

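If the image is a numpy array of float32 values, you can convert it directly to uint8 in the range [0, 255] with skimage.img_as_ubyte: img = skimage.img_as_ubyte(img). Note that img_as_ubyte expects float input already scaled to [0, 1]; negative values are clipped to 0.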
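
For readers without scikit-image installed, the same float-to-uint8 conversion can be approximated in plain numpy. The sketch below is a rough stand-in, not skimage's implementation; float_to_ubyte is a hypothetical helper, and it assumes the input is already scaled to [0, 1].

```python
import numpy as np


def float_to_ubyte(img):
    """Hypothetical helper: scale a float image in [0, 1] to uint8 in [0, 255]."""
    img = np.asarray(img, dtype=np.float32)
    # Assumption: values outside [0, 1] indicate an unnormalized image, so fail loudly.
    if img.min() < 0.0 or img.max() > 1.0:
        raise ValueError("expected float values in [0, 1]")
    # Round to the nearest integer before casting to avoid truncation.
    return np.rint(img * 255.0).astype(np.uint8)


sample = np.array([[0.0, 0.5, 1.0]], dtype=np.float32)
converted = float_to_ubyte(sample)
print(converted, converted.dtype)
```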

September 9, 2024