torch.cat()
cat为concatenate的缩写,意思为拼接,torch.cat()函数一般是用于张量拼接使用的
cat(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: _int = 0, *, out: Optional[Tensor] = None) -> Tensor:
可以看到cat()函数的参数,常用的参数为,第一个参数:可以选择元组或者列表,内部包含需要拼接的张量,需要按照顺序排列,第二个参数为dim,用于指定需要拼接的维度
import torch
import numpy as np
data1 = torch.randint(0, 10, [2, 3, 4])
data2 = torch.randint(0, 10, [2, 3, 4])
print(data1)
print(data2)
print("-" * 20)
print(torch.cat([data1, data2], dim=0))
print(torch.cat([data1, data2], dim=1))
print(torch.cat([data1, data2], dim=2))
# tensor([[[9, 4, 0, 0],
# [3, 3, 7, 6],
# [6, 1, 0, 8]],
#
# [[9, 1, 1, 2],
# [1, 0, 6, 4],
# [7, 9, 3, 9]]])
# tensor([[[3, 2, 6, 3],
# [8, 3, 1, 1],
# [0, 9, 2, 5]],
#
# [[2, 6, 7, 5],
# [9, 1, 0, 1],
#