torch.stack()
stack在英文中有“堆叠的意思”。所以stack通常是把一些低纬(二维)的tensor堆叠为一个高维(三维)的tensor。
stack()官方解释:torch.stack[source] → Tensor :
函数目的: 沿着一个新维度对输入张量序列进行拼接 。其中序列中所有的 张量 都应该为相同形状。
outputs = torch.stack(inputs, dim=0) # → Tensor
参数:
- inputs : 待连接的张量序列。
注:python的序列数据只有list和tuple。 - dim : 新的维度, 必须在0到len(outputs)之间。
注:len(outputs)是生成数据的维度大小,也就是outputs的维度值。
例子
- 准备2个tensor数据,每个的shape都是[3,3]
T1 = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
T2 = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
- 测试stack函数
R0 = torch.stack((T1, T2), dim=0)
print("R0:\n", R0)
print("R0.shape:\n", R0.shape)
"""
R0:
tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],
[[10, 20, 30],
[40, 50, 60],
[70, 80, 90]]])
R0.shape:
torch.Size([2, 3, 3])
"""
R1 = torch.stack((T1, T2), dim=1)
print("R1.shape:\n", R1.shape)
"""
R1:
tensor([[[ 1, 2, 3],
[10, 20, 30]],
[[ 4, 5, 6],
[40, 50, 60]],
[[ 7, 8, 9],
[70, 80, 90]]])
R1.shape:
torch.Size([3, 2, 3])
"""
R2 = torch.stack((T1, T2), dim=2)
print("R2:\n", R2)
print("R2.shape:\n", R2.shape)
"""
R2:
tensor([[[ 1, 10],
[ 2, 20],
[ 3, 30]],
[[ 4, 40],
[ 5, 50],
[ 6, 60]],
[[ 7, 70],
[ 8, 80],
[ 9, 90]]])
R2.shape:
torch.Size([3, 3, 2])
"""
R3 = torch.stack((T1, T2), dim=3)
print("R3:\n", R3)
print("R3.shape:\n", R3.shape)
"""
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
"""
torch.cat()
一般torch.cat()是为了把函数torch.stack()得到tensor进行拼接而存在的。torch.cat() 和python中的内置函数cat(), 在使用和目的上,是没有区别的,区别在于前者操作对象是tensor。
函数目的:
在给定维度上对输入的张量序列seq 进行连接操作。
outputs = torch.cat(inputs, dim=0) → Tensor
参数:
- inputs : 待连接的张量序列,可以是任意相同Tensor类型的python 序列。
- dim : 选择的扩维, 必须在0到len(inputs[0])之间,沿着此维连接张量序列。
重点:
- 输入数据必须是序列,序列中数据是任意相同的shape的同类型tensor
- 维度不可以超过输入数据的任一个张量的维度
例子
- 准备数据,每个的shape都是[2,3]
x1 = torch.tensor([[11, 21, 31], [21, 31, 41]], dtype=torch.int)
print("x1:\n", x1)
print("x1.shape:\n", x1.shape)
'''
x1:
tensor([[11, 21, 31],
[21, 31, 41]], dtype=torch.int32)
x1.shape:
torch.Size([2, 3])
'''
x2 = torch.tensor([[12, 22, 32], [22, 32, 42]])
print("x2:\n", x2)
print("x2.shape:\n", x2.shape)
'''
x2:
tensor([[12, 22, 32],
[22, 32, 42]])
x2.shape:
torch.Size([2, 3])
'''
- 合成inputs
inputs = [x1, x2]
print("inputs:\n", inputs)
'''
inputs:
[tensor([[11, 21, 31],
[21, 31, 41]], dtype=torch.int32), tensor([[12, 22, 32],
[22, 32, 42]])]
'''
- 查看结果, 测试不同的dim拼接结果
R0 = torch.cat(inputs, dim=0)
print("R0:\n", R0)
print("R0.shape:\n", R0.shape)
'''
R0:
tensor([[11, 21, 31],
[21, 31, 41],
[12, 22, 32],
[22, 32, 42]])
R0.shape:
torch.Size([4, 3])
'''
R1 = torch.cat(inputs, dim=1)
print("R1:\n", R1)
print("R1.shape:\n", R1.shape)
'''
R1:
tensor([[11, 21, 31, 12, 22, 32],
[21, 31, 41, 22, 32, 42]])
R1.shape:
torch.Size([2, 6])
'''
R2 = torch.cat(inputs, dim=2)
print("R2:\n", R2)
print("R2.shape:\n", R2.shape)
'''
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
'''
总结
- torch.stack()是新增一维,属于增维操作(22,22 → 222)。
- torch.cat()是在特定维度上进行拼接(22, 22 → 2*4)。
参考链接:https://blog.csdn.net/qq_40507857/article/details/119854085