0%

Pytorch Usage

Pytorch Usage

.permute() .transpose() VS. tensor.view()——.contiguous()

transpose、permute 操作虽然没有修改底层一维数组,但是新建了一份Tensor元信息,并在新的元信息中的 重新指定 stride。

torch.view 方法约定了不修改数组本身,只是使用新的形状查看数据。如果我们在 transpose、permute 操作后执行 view,Pytorch 会抛出错误!

看这个实验:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
>>>t = torch.arange(12).reshape(3,4)
>>>t
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
>>>t.stride()
(4, 1) # 0维度之间的stride是4,1维度之间的stride是1
>>>t2 = t.transpose(0,1)
>>>t2
tensor([[ 0, 4, 8],
[ 1, 5, 9],
[ 2, 6, 10],
[ 3, 7, 11]])
>>>t2.stride()
(1, 4) # 0维度之间的stride是1,1维度之间的stride是4
>>>t.data_ptr() == t2.data_ptr() # 底层数据是同一个一维数组
True
>>>t.is_contiguous(),t2.is_contiguous() # t连续,t2不连续
(True, False)

# 连续意味着当前看到的和实际存储的顺序应该一致!

使用contiguous方法后返回新Tensor t3,重新开辟了一块内存,并使用 t2 的顺序存储底层数据。

1
2
3
4
5
6
7
8
t3 = t2.contiguous()  # 返回一个新的tensor!
>>>t3
tensor([[ 0, 4, 8],
[ 1, 5, 9],
[ 2, 6, 10],
[ 3, 7, 11]])
>>>t3.data_ptr() == t2.data_ptr() # 底层数据不是同一个一维数组
False

可以发现 t与t2 底层数据指针一致,t3 与 t2 底层数据指针不一致,说明确实重新开辟了内存空间。

总结

transposepermute 后使用 contiguous 方法则会重新开辟一块内存空间,保证数据在逻辑顺序和内存中是一致的,连续内存布局减少了CPU对对内存的请求次数(访问内存比访问寄存器慢100倍),相当于空间换时间。

Reference

https://zhuanlan.zhihu.com/p/64551412

Donate comment here.