【PyTorch】torch.Tensor.expand() 函数:扩展(broadcasting)张量维度为 1 的轴

PyTorch .expand() 方法

1. expand() 的作用

expand() 方法用于 扩展(broadcasting)张量的维度,但不会真正复制数据。它主要用于将形状较小的张量扩展到较大的张量,使它们能够进行广播计算。与 torch.broadcast_tensors() 类似,expand() 主要用于 不占用额外内存 的情况下,使张量在计算时适配目标形状。


2. 语法

tensor.expand(*sizes)

或:

tensor.expand_as(other_tensor)
参数
  • sizes: 目标形状,必须与原始形状兼容,即:
    • 原张量维度为 1 的轴可以被扩展。
    • 不能扩展非 1 的轴,否则会报错。
  • expand_as(other_tensor): 让 tensor 变为 other_tensor 相同的形状(必须符合广播规则)。
返回值

返回一个新的张量,但不会真正复制数据,只是创建了一个新的视图(view)。


3. expand() 使用示例

3.1. 基础示例
import torch

x = torch.tensor([[1], [2], [3]])  # 形状: (3, 1)

# 扩展到 (3, 4)
x_expanded = x.expand(3, 4)
print(x_expanded)
print("Shape:", x_expanded.shape)

输出

tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]])
Shape: torch.Size([3, 4])

解析

  • 原张量 x 形状是 (3, 1)
  • expand(3, 4) 使得 1 维的部分被扩展成 4,但不会真正复制数据。

3.2. 使用 expand_as()
x = torch.tensor([[1], [2], [3]])  # 形状: (3, 1)
y = torch.empty(3, 4)  # 目标形状 (3, 4)

x_expanded = x.expand_as(y)
print(x_expanded)
print("Shape:", x_expanded.shape)

输出

tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]])
Shape: torch.Size([3, 4])

解析

  • expand_as(y)x 变成 y 的形状 (3, 4),效果等同于 expand(3, 4)

3.3. 验证 .expand() 不会复制数据
x = torch.tensor([[1], [2], [3]])  # 形状: (3, 1)
x_expanded = x.expand(3, 4)

print("Memory address comparison:", x.data_ptr() == x_expanded.data_ptr())

输出

Memory address comparison: True

解析

  • .expand() 不会分配新内存,它只是创建了共享相同数据的视图
  • data_ptr() 返回张量的内存地址,xx_expanded 共享相同的数据。

4. .expand() vs .repeat()

方法复制数据?内存效率用途
expand()不复制数据高效适用于广播计算,节省内存
repeat()复制数据占用更多内存适用于真正需要数据复制的情况
示例
x = torch.tensor([[1], [2], [3]])

x_expanded = x.expand(3, 4)
x_repeated = x.repeat(1, 4)

print("Expanded tensor:")
print(x_expanded)

print("\nRepeated tensor:")
print(x_repeated)

输出

Expanded tensor:
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]])

Repeated tensor:
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]])

区别

  • .expand() 不会复制数据,只是创建视图
  • .repeat() 会复制数据,新张量占用更多内存

5. .expand() 的局限性

5.1. 只能扩展维度为 1 的轴

如果 expand() 尝试扩展非 1 的轴,会报错:

x = torch.tensor([[1, 2, 3], [4, 5, 6]])  # 形状: (2, 3)

# 尝试扩展非 1 轴
try:
    x_expanded = x.expand(4, 3)
except RuntimeError as e:
    print("RuntimeError:", e)

输出

RuntimeError: The expanded size of the tensor (4) must match the existing size (2) at non-singleton dimension 0. Target sizes: [4, 3]. Tensor sizes: [2, 3].

原因

  • x 的第一维是 2,不能扩展成 4,因为 expand() 只能扩展尺寸为 1 的维度

6. 总结

方法是否复制数据?适用于
.expand()❌ 不复制数据,只是视图广播计算,节省内存
.repeat()✅ 复制数据真正需要数据复制的情况
.expand_as(other_tensor)❌ 作用等同于 .expand(*other_tensor.shape)快速匹配另一个张量的形状
关键点
  • .expand() 不会真正复制数据,而是创建共享相同数据的视图
  • 只能扩展 维度为 1 的轴,不能扩展非 1 轴。
  • .expand_as(other_tensor) 可以直接匹配另一个张量的形状。

适用场景

  • 适用于 广播计算,如 A + B,当 B 形状较小时,可以使用 .expand() 使其兼容。
  • 节省内存,避免不必要的数据复制。

如果你不想占用额外内存,但又需要扩展张量的形状,.expand() 是最佳选择

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彬彬侠

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值