🍉 CSDN 叶庭云:https://yetingyun.blog.csdn.net/
在使用 PyTorch 进行编程时,给定一个形状为 torch.Size([36, 1, 1000, 2048])
的张量是 output.hidden_states
。请详细解释并说明以下代码的功能与作用:mean_embeddings = torch.mean(output.hidden_states, dim=-2).squeeze()
。具体阐述此操作对输入张量进行了怎样的变换。
下面我们逐步拆解这行代码对张量的变换含义:
mean_embeddings = torch.mean(output.hidden_states, dim=-2).squeeze()
假设原始张量:
output.hidden_states.shape == torch.Size([36, 1, 1000, 2048])
各维度含义(常见于 Transformer 类模型的输出)为:
-
36:隐藏层数
-
1:“分段(segment)” 维度
-
1000:序列长度(token 数或时序步数)
-
2048:隐藏层特征维度(hidden size)
1. torch.mean(…, dim=-2):dim=-2 指的是从后往前数的第 2 个维度,也就是原张量的 “1000” 这一维(序列长度维度)。torch.mean 会沿该维度对所有元素取算术平均。因此,执行后的中间结果张量形状:
torch.mean(output.hidden_states, dim=-2).shape
--> [36, 1, 2048]
含义:对每个隐藏层、每个 “分段”,把 1000 个 Token 的 2048 维向量做平均,得到一个长度 2048 的 “全序列平均” 向量。该序列将所有 Token 的上下文信息融合为一个固定长度的向量,常用于文本分类、相似度计算等需要句子级表示的场景。
2. .squeeze():squeeze() 会删除所有维度大小为 1 的维度。在这里,当前张量是 [36, 1, 2048]
,其中中间的 “1” 维会被去掉。最终形状变为:
[36, 2048]
总结一下整体变换流程:
-
输入:
[36, 1, 1000, 2048]
-
沿序列维度(1000)求平均 →
torch.mean(..., dim=-2)
。结果形状[36, 1, 2048]
。每个样本得到每层的一个 2048 维的 “全序列平均” 向量 -
去除冗余维度 → .squeeze(),
[36, 1, 2048] → [36, 2048]
这样,mean_embeddings
就是对原始序列在 Token 维度上进行 Token 平均后,去掉多余维度得到的最终特征表示。它方便后续操作,使我们能够直接获得每个样本的 2048
维向量。