Pytorch cat()与stack()函数详解

torch.cat()

cat为concatenate的缩写,意思为拼接,torch.cat()函数一般是用于张量拼接使用的

cat(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: _int = 0, *, out: Optional[Tensor] = None) -> Tensor:

可以看到cat()函数的参数,常用的参数为,第一个参数:可以选择元组或者列表,内部包含需要拼接的张量,需要按照顺序排列,第二个参数为dim,用于指定需要拼接的维度

import torch
import numpy as np

data1 = torch.randint(0, 10, [2, 3, 4])
data2 = torch.randint(0, 10, [2, 3, 4])

print(data1)
print(data2)
print("-" * 20)

print(torch.cat([data1, data2], dim=0))
print(torch.cat([data1, data2], dim=1))
print(torch.cat([data1, data2], dim=2))
# tensor([[[9, 4, 0, 0],
#          [3, 3, 7, 6],
#          [6, 1, 0, 8]],
# 
#         [[9, 1, 1, 2],
#          [1, 0, 6, 4],
#          [7, 9, 3, 9]]])
# tensor([[[3, 2, 6, 3],
#          [8, 3, 1, 1],
#          [0, 9, 2, 5]],
# 
#         [[2, 6, 7, 5],
#          [9, 1, 0, 1],
#         
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值