contiguous
pytorch
本文字数:1.3k 字 | 阅读时长 ≈ 6 min

contiguous

pytorch
本文字数:1.3k 字 | 阅读时长 ≈ 6 min

1. tensor 在内存中的存储

1.1 信息区和存储区

tensor 在内存中的存储包含信息区存储区

高维数组在内存中是按照行优先顺序存储的,什么是行优先顺序?假设我们有一个(3, 4)的 tensor,他其实是按照一维数组的方式存储的,只不过在 tensor 的信息区记录了他的 size 和 stride 导致实际上展示出的数组是二维的,size 为(3, 4)

二维数组

内存中的一维形式

接下来我们看一个例子,例子表明 tensor 中的元素在内存上是连续的,并且也证明了确实是行优先顺序存储

tensor = torch.tensor([[[1 ,2, 3, 4], [5, 6, 7, 8], [9, 10,11,12]], 
                       [[13,14,15,16],[17,18,19,20],[21,22,23,24]]])
print(tensor.is_contiguous())
for i in range(2):
    for j in range(3):
        for k in range(4):
            print(tensor[i][j][k].data_ptr(), end=' ')
'''
True
140430616343104 140430616343112 140430616343120 140430616343128
140430616343136 140430616343144 140430616343152 140430616343160 
140430616343168 140430616343176 140430616343184 140430616343192 
140430616343200 140430616343208 140430616343216 140430616343224 
140430616343232 140430616343240 140430616343248 140430616343256 
140430616343264 140430616343272 140430616343280 140430616343288 
'''

1.2 shape && stride

继续上述的例子我们来看一下在信息区的 shape 和 stride 属性,对于(2, 3, 4)维的 tensor 他的 shape 为(2, 3 ,4),stride 为(12, 4, 1)

shape

shape 很容易理解,就是 tensor 的维度,上述例子为(2, 3, 4)的 tensor,维度就为(2, 3, 4)

stride

stride 代表着多维索引的步长,每一步都代表内存上的偏移量+1,对于(2, 3, 4)维度的 tensor:stride+1 代表着(dim2)+1,stride+4 代表其余 dim 不变,(dim1)+1,stride+12 代表其余 dim 不变,(dim0)+1,如下图所示

图示 stride

stride 计算方法

$$
stride_{i} = stride_{i+1} * size_{i+1}~~~~i\in[0, n-2]
$$

对于shape(2, 3, 4)的tensor,计算如下(stride3=1)

$$
\begin{aligned}
stride_{2} = stride_{3} * shape_{3}=14=4 \
stride_{1} = stride_{2} * shape_{2}=4
3=12
\end{aligned}
$$

stride = [1] # 初始化第一个元素  
# 从后往前遍历迭代生成 stride  
for i in range(len(tensor.size())-2, -1, -1):  
    stride.insert(0, stride[0] * tensor.shape[i+1])  
print(stride)          # [12, 4, 1]
print(tensor.stride()) # (12, 4, 1)

==理解了 tensor 在内存中的存储之后,我们再来看 contiguous==

2. contiguous

contiguous

返回一个连续内存的 tensor

Returns a contiguous in memory tensor containing the same data as self tensor. If self tensor is already in the specified memory format, this function returns the self tensor.

2.1 什么时候用 contiguous 呢?

简单理解就是tensor 在内存地址中的存储顺序与实际的一维索引顺序不一致时使用,如下所示,对上面的 tensor 进行一维索引,结果为[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],对原 tensor 运用 transpose 进行转置,在对其进行一维索引,结果为[1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12],这时索引顺序发生了变化,所以需要用 contiguous

注意:不论怎么变化每个元素对应的地址是不变的,比如 11 对应的地址为 x11,transpose 之后 11 依然对应 x11,那么变化的是什么呢?还记得 tensor 分为信息区和存储区吗,存储区是不变化的,变化的是信息区的 shape,stride 等信息,有时间以后做介绍~

代码示例

tensor = torch.tensor([[1,2,3,4],[5,6,7,8],[9,10,11,12]])
print(tensor)
print(tensor.is_contiguous())
tensor = tensor.transpose(1, 0)
print(tensor)
print(tensor.is_contiguous())
'''
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])
True
tensor([[ 1,  5,  9],
        [ 2,  6, 10],
        [ 3,  7, 11],
        [ 4,  8, 12]])
False
'''

2.2 为什么要用 contiguous

有人可能会有疑问,既然上述情况索引与之前不一样了(不连续了),为什么要让他变连续呢?因为 pytorch 的某些操作需要索引和内存连续,比如 view

代码示例(接着上面的例子)

tensor = tensor.contiguous()
print(tensor.is_contiguous())
tensor = tensor.view(3, 4)
print(tensor)
'''
True
tensor([[ 1,  5,  9,  2],
        [ 6, 10,  3,  7],
        [11,  4,  8, 12]])
'''

如果不用 contiguous 会报以下错误

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

2.3 为什么 contiguous 能有效?

contiguous 用了一种简单粗暴的方法,既然你之前的索引和内存不连续了,那我就重新开辟一块连续的内存给他加上索引即可

代码示例,从下面代码中 stride 变化可以看出,transpose 之后的 tensor 确实是改变了信息区的信息

tensor = torch.tensor([[1,2,3,4],[5,6,7,8],[9,10,11,12]])
print(tensor.is_contiguous())	# True
for i in range(3):
    for j in range(4):
        print(tensor[i][j], tensor[i][j].data_ptr(), end=' ')
    print()
print(tensor.stride()) 			# (4, 1)
'''
True
tensor(1) 140430616321664 tensor(2) 140430616321672 tensor(3) 140430616321680 tensor(4) 140430616321688 
tensor(5) 140430616321696 tensor(6) 140430616321704 tensor(7) 140430616321712 tensor(8) 140430616321720 
tensor(9) 140430616321728 tensor(10) 140430616321736 tensor(11) 140430616321744 tensor(12) 140430616321752 
(4, 1)
'''

tensor = tensor.transpose(1, 0)
print(tensor.is_contiguous())	# False
for i in range(4):
    for j in range(3):
        print(tensor[i][j], tensor[i][j].data_ptr(), end=' ')
    print()
print(tensor.stride()) 			# (1, 4) changed
'''
False
tensor(1) 140430616321664 tensor(5) 140430616321696 tensor(9) 140430616321728 
tensor(2) 140430616321672 tensor(6) 140430616321704 tensor(10) 140430616321736 
tensor(3) 140430616321680 tensor(7) 140430616321712 tensor(11) 140430616321744 
tensor(4) 140430616321688 tensor(8) 140430616321720 tensor(12) 140430616321752 
(1, 4)
'''

tensor = tensor.contiguous()
print(tensor.is_contiguous())	# True
for i in range(4):
    for j in range(3):
        print(tensor[i][j], tensor[i][j].data_ptr(), end=' ')
    print()
print(tensor.stride())      	# (3, 1)
'''
True
tensor(1) 140431681244608 tensor(5) 140431681244616 tensor(9) 140431681244624 
tensor(2) 140431681244632 tensor(6) 140431681244640 tensor(10) 140431681244648 
tensor(3) 140431681244656 tensor(7) 140431681244664 tensor(11) 140431681244672 
tensor(4) 140431681244680 tensor(8) 140431681244688 tensor(12) 140431681244696 
(3, 1)
'''
4月 06, 2025
3月 10, 2025
12月 31, 2024