PyTorch .expand()
方法
1. expand()
的作用
expand()
方法用于 扩展(broadcasting)张量的维度,但不会真正复制数据。它主要用于将形状较小的张量扩展到较大的张量,使它们能够进行广播计算。与 torch.broadcast_tensors()
类似,expand()
主要用于 不占用额外内存 的情况下,使张量在计算时适配目标形状。
2. 语法
tensor.expand(*sizes)
或:
tensor.expand_as(other_tensor)
参数
sizes
: 目标形状,必须与原始形状兼容,即:- 原张量维度为 1 的轴可以被扩展。
- 不能扩展非 1 的轴,否则会报错。
expand_as(other_tensor)
: 让tensor
变为other_tensor
相同的形状(必须符合广播规则)。
返回值
返回一个新的张量,但不会真正复制数据,只是创建了一个新的视图(view)。
3. expand()
使用示例
3.1. 基础示例
import torch
x = torch.tensor([[1], [2], [3]]) # 形状: (3, 1)
# 扩展到 (3, 4)
x_expanded = x.expand(3, 4)
print(x_expanded)
print("Shape:", x_expanded.shape)
输出
tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]])
Shape: torch.Size([3, 4])
解析:
- 原张量
x
形状是(3, 1)
。 expand(3, 4)
使得1
维的部分被扩展成4
,但不会真正复制数据。
3.2. 使用 expand_as()
x = torch.tensor([[1], [2], [3]]) # 形状: (3, 1)
y = torch.empty(3, 4) # 目标形状 (3, 4)
x_expanded = x.expand_as(y)
print(x_expanded)
print("Shape:", x_expanded.shape)
输出
tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]])
Shape: torch.Size([3, 4])
解析:
expand_as(y)
让x
变成y
的形状(3, 4)
,效果等同于expand(3, 4)
。
3.3. 验证 .expand()
不会复制数据
x = torch.tensor([[1], [2], [3]]) # 形状: (3, 1)
x_expanded = x.expand(3, 4)
print("Memory address comparison:", x.data_ptr() == x_expanded.data_ptr())
输出
Memory address comparison: True
解析:
.expand()
不会分配新内存,它只是创建了共享相同数据的视图。data_ptr()
返回张量的内存地址,x
和x_expanded
共享相同的数据。
4. .expand()
vs .repeat()
方法 | 复制数据? | 内存效率 | 用途 |
---|---|---|---|
expand() | 不复制数据 | 高效 | 适用于广播计算,节省内存 |
repeat() | 复制数据 | 占用更多内存 | 适用于真正需要数据复制的情况 |
示例
x = torch.tensor([[1], [2], [3]])
x_expanded = x.expand(3, 4)
x_repeated = x.repeat(1, 4)
print("Expanded tensor:")
print(x_expanded)
print("\nRepeated tensor:")
print(x_repeated)
输出
Expanded tensor:
tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]])
Repeated tensor:
tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]])
区别
.expand()
不会复制数据,只是创建视图。.repeat()
会复制数据,新张量占用更多内存。
5. .expand()
的局限性
5.1. 只能扩展维度为 1 的轴
如果 expand()
尝试扩展非 1 的轴,会报错:
x = torch.tensor([[1, 2, 3], [4, 5, 6]]) # 形状: (2, 3)
# 尝试扩展非 1 轴
try:
x_expanded = x.expand(4, 3)
except RuntimeError as e:
print("RuntimeError:", e)
输出
RuntimeError: The expanded size of the tensor (4) must match the existing size (2) at non-singleton dimension 0. Target sizes: [4, 3]. Tensor sizes: [2, 3].
原因
x
的第一维是2
,不能扩展成4
,因为expand()
只能扩展尺寸为 1 的维度。
6. 总结
方法 | 是否复制数据? | 适用于 |
---|---|---|
.expand() | ❌ 不复制数据,只是视图 | 广播计算,节省内存 |
.repeat() | ✅ 复制数据 | 真正需要数据复制的情况 |
.expand_as(other_tensor) | ❌ 作用等同于 .expand(*other_tensor.shape) | 快速匹配另一个张量的形状 |
关键点
.expand()
不会真正复制数据,而是创建共享相同数据的视图。- 只能扩展 维度为 1 的轴,不能扩展非 1 轴。
.expand_as(other_tensor)
可以直接匹配另一个张量的形状。
适用场景
- 适用于 广播计算,如
A + B
,当B
形状较小时,可以使用.expand()
使其兼容。 - 节省内存,避免不必要的数据复制。
如果你不想占用额外内存,但又需要扩展张量的形状,.expand()
是最佳选择。