Common Tensor Operations
pytorch
Word count: 939 | Reading time ≈ 5 min


1. Rearranging dimensions

view

Reshapes a tensor: conceptually the tensor is flattened to 1-D and then reinterpreted with the new shape.
view requires the tensor's memory to be contiguous.

import torch

input = torch.randn(2, 3, 4)
print(input.view(3, 2, 4))  # same 24 elements, reinterpreted as shape (3, 2, 4)
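
view does not copy data: it only reinterprets the same flat (row-major) element sequence with a new shape. A minimal sketch of this, with variable names chosen only for illustration:

x = torch.arange(6)                   # tensor([0, 1, 2, 3, 4, 5])
y = x.view(2, 3)                      # same storage, viewed as shape (2, 3)
print(y)                              # tensor([[0, 1, 2], [3, 4, 5]])
print(x.data_ptr() == y.data_ptr())   # True: no copy was made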

transpose

Swaps two dimensions of a tensor; the result is no longer contiguous in memory.
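
A quick sketch of this behavior (the shapes and names are chosen only for illustration):

x = torch.randn(2, 3, 4)
y = x.transpose(0, 2)        # swap dim 0 and dim 2 -> shape (4, 3, 2)
print(y.shape)               # torch.Size([4, 3, 2])
print(y.is_contiguous())     # False: same storage, different strides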

permute

Reorders any number of dimensions at once; the result is likewise no longer contiguous in memory.

Calling view after transpose or permute raises an error, because the rearrangement leaves the memory non-contiguous. Calling contiguous before view solves this, as shown below:

input = torch.randn(2, 3, 4)
input = input.permute(2, 1, 0)   # shape becomes (4, 3, 2)
print(input.is_contiguous())     # False
# input.view(2, 3, 4)            # wrong: view on a non-contiguous tensor raises an error
input = input.contiguous()       # copies the data into a contiguous layout
print(input.view(2, 3, 4))       # correct

2. Tensor.expand()

Tensor.expand(*sizes)

Expands the tensor to a larger size; sizes gives the target shape, and -1 leaves the corresponding dimension unchanged.

Returns a new view of the self tensor with singleton dimensions expanded to a larger size.

Passing -1 as the size for a dimension means not changing the size of that dimension.

Tensor can be also expanded to a larger number of dimensions, and the new ones will be appended at the front. For the new dimensions, the size cannot be set to -1.

Expanding a tensor does not allocate new memory, but only creates a new view on the existing tensor where a dimension of size one is expanded to a larger size by setting the stride to 0. Any dimension of size 1 can be expanded to an arbitrary value without allocating new memory.

x = torch.tensor([[1], [2], [3]])
print(x)
print(x.size())
'''
tensor([[1],
        [2],
        [3]])
torch.Size([3, 1])
'''


x_ex = x.expand(3, 4)
print(x_ex)
x_ex2 = x.expand(-1, 4)   # -1 means not changing the size of that dimension
print(x_ex2)
'''
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]])
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]])
'''
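
The remaining claims above are easy to check directly: extra dimensions are appended at the front, and no new memory is allocated because the expanded dimension simply gets stride 0. A small sketch reusing x from above (the name x_3d is just for illustration):

x_3d = x.expand(2, 3, 4)                 # x has shape (3, 1); the new dim of size 2 is added at the front
print(x_3d.shape)                        # torch.Size([2, 3, 4])
print(x.expand(3, 4).stride())           # (1, 0): the expanded dimension has stride 0
print(x.data_ptr() == x_3d.data_ptr())   # True: still the same storage, no copy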

3. torch.cat()

torch.cat(tensors, dim=0)

Concatenates tensors along dimension dim, which defaults to 0 (row-wise concatenation).

Concatenates the given sequence of tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty.

torch.cat() can be seen as an inverse operation for torch.split() and torch.chunk()

x = torch.randn(2, 3)
print(x)
'''
tensor([[ 0.7004, -0.0935, -0.2668],
        [ 0.7922,  0.9567,  1.4191]])
'''

x_row = torch.cat((x, x, x), 0)
print(x_row)
x_col = torch.cat((x, x), 1)
print(x_col)
'''
tensor([[ 0.7004, -0.0935, -0.2668],
        [ 0.7922,  0.9567,  1.4191],
        [ 0.7004, -0.0935, -0.2668],
        [ 0.7922,  0.9567,  1.4191],
        [ 0.7004, -0.0935, -0.2668],
        [ 0.7922,  0.9567,  1.4191]])
tensor([[ 0.7004, -0.0935, -0.2668,  0.7004, -0.0935, -0.2668],
        [ 0.7922,  0.9567,  1.4191,  0.7922,  0.9567,  1.4191]])
'''
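
As noted above, the tensors only have to agree in the non-concatenated dimensions; along dim itself the sizes may differ. A short sketch reusing x (shape (2, 3)), with y chosen only for illustration:

y = torch.randn(1, 3)
print(torch.cat((x, y), 0).shape)   # torch.Size([3, 3]): sizes 2 and 1 stacked along dim 0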

4. torch.split()

torch.split(tensor, split_size_or_sections, dim=0)

Splits a tensor into chunks. split_size_or_sections is either an int (the size of each chunk) or a list of chunk sizes, and dim is the dimension to split along (default 0, i.e. splitting along rows).

Splits the tensor into chunks. Each chunk is a view of the original tensor.

If split_size_or_sections is an integer type, then tensor will be split into equally sized chunks (if possible). Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by split_size.

If split_size_or_sections is a list, then tensor will be split into len(split_size_or_sections) chunks with sizes in dim according to split_size_or_sections.

a = torch.arange(10).reshape(5,2)
print(a)
'''
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])
'''


a1 = torch.split(a, 2)
print(a1)
print(a1[1]) # the split produced three chunks, so any one of them can be printed
'''
(tensor([[0, 1],
         [2, 3]]), 
 tensor([[4, 5],
         [6, 7]]), 
 tensor([[8, 9]]))
tensor([[4, 5],
        [6, 7]])
'''


a2 = torch.split(a, [1,4])
print(a2)
'''
(tensor([[0, 1]]), 
 tensor([[2, 3],
         [4, 5],
         [6, 7],
         [8, 9]]))
'''
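
Two of the claims above can be verified right here: each chunk returned by torch.split() is a view of a, and torch.cat() undoes the split. A sketch reusing a and a1 from above:

a1[0][0, 0] = 100                      # writing into a chunk...
print(a[0, 0])                         # tensor(100): ...modifies the original tensor
print(torch.equal(torch.cat(a1), a))   # True: cat reassembles the split chunks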

5. torch.chunk()

torch.chunk(tensor, chunks, dim=0)

torch.chunk() works much like torch.split(); the difference is that chunks must be an integer and gives the number of chunks rather than their size.

Splits a tensor into a specific number of chunks. Each chunk is a view of the input tensor.

Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by chunks.

b = torch.arange(10).reshape(2,5)
print(b)
'''
tensor([[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]])
'''


b1 = torch.chunk(b, 2, dim=1)
print(b1)
print(b1[1])
'''
(tensor([[0, 1, 2],
         [5, 6, 7]]), 
 tensor([[3, 4],
         [8, 9]]))
tensor([[3, 4],
        [8, 9]])
'''
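
The "last chunk will be smaller" behavior shows up as soon as chunks does not divide the dimension evenly; a small sketch reusing b (shape (2, 5)), with b3 chosen only for illustration:

b3 = torch.chunk(b, 3, dim=1)
print([t.shape for t in b3])
# [torch.Size([2, 2]), torch.Size([2, 2]), torch.Size([2, 1])]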

