LLM增强序列推荐(LLMESR)模型代码详解
论文讲解见 LLM-ESR Large Language Models Enhancement for Long-tailed Sequential Recommendation
三个模型(LLMESR_SASRec、LLMESR_GRU4Rec和LLMESR_Bert4Rec)的具体代码实现细节。
首先,理解整体的继承关系:
- LLMESR_SASRec继承自DualLLMSASRec,而DualLLMSASRec继承自SASRec_seq,SASRec_seq继承自SASRec
- LLMESR_GRU4Rec继承自DualLLMGRU4Rec,而DualLLMGRU4Rec继承自GRU4Rec
- LLMESR_Bert4Rec继承自DualLLMBert4Rec,而DualLLMBert4Rec继承自Bert4Rec
这三个模型都是在序列推荐模型的基础上增加了大语言模型(LLM)的知识。
从DualLLM系列模型开始分析,因为这是三个LLMESR模型直接继承的类。从DualLLMGRU4Rec的代码可以看到:
- 加载预训练的LLM物品嵌入
- 使用adapter将LLM嵌入维度调整到模型的隐藏维度
- 可能包含cross-attention机制
下面分析具体实现差异:
1. 基础序列编码机制
- GRU4Rec: 使用GRU(门控循环单元)处理序列数据
- SASRec: 使用单向自注意力机制处理序列
- Bert4Rec: 使用双向Transformer架构处理序列
2. 输入处理方式
- GRU4Rec: 直接处理序列输入
- SASRec: 需要位置编码
- Bert4Rec: 使用位置编码和mask机制
3. DualLLM的公共设计
- 所有DualLLM模型都加载预训练的LLM物品嵌入
- 通过adapter网络调整LLM嵌入维度
- 支持跨模态注意力(cross-attention)
4. LLMESR系列的增强
-
增加了对比学习/知识蒸馏机制
-
增加物品表示正则化选项
以下从三个方面详细对比分析三种模型的实现:基础架构、大语言模型集成方式和特征对齐机制。
一、基础序列建模实现
1. LLMESR_GRU4Rec - 基于GRU的序列建模
class GRU4RecBackbone(nn.Module):
def __init__(self, device, args):
super().__init__()
# GRU核心组件
self.gru = nn.GRU(
input_size=args.hidden_size,
hidden_size=args.hidden_size,
num_layers=args.gru_layer_num, # 多层GRU
batch_first=True
)
def forward(self, seqs, log_seqs):
# GRU处理序列数据,不需要位置编码
output, hidden = self.gru(seqs) # output: [batch_size, seq_len, hidden]
return output
GRU4Rec通过循环网络处理序列,记忆门控机制能够捕捉长期依赖,实现相对简单。
2. LLMESR_SASRec - 基于自注意力的序列建模
class SASRecBackbone(nn.Module):
def __init__(self, device, args):
super().__init__()
# 自注意力核心组件
self.attention_layernorms = nn.ModuleList()
self.attention_layers = nn.ModuleList()
self.forward_layernorms = nn.ModuleList()
self.forward_layers = nn.ModuleList()
for _ in range(args.block_num):
self.attention_layernorms.append(nn.LayerNorm(args.hidden_size))
self.attention_layers.append(
nn.MultiheadAttention(
args.hidden_size,
args.num_heads,
args.dropout_rate
)
)
self.forward_layernorms.append(nn.LayerNorm(args.hidden_size))
self.forward_layers.append(nn.Linear(args.hidden_size, args.hidden_size))
def forward(self, seqs, log_seqs):
# 生成时间线掩码(确保只看到历史信息)
timeline_mask = torch.BoolTensor(log_seqs == 0).to(self.dev)
seqs *= ~timeline_mask.unsqueeze(-1)
# 自注意力层处理
for i in range(len(self.attention_layers)):
# 自注意力处理
seqs = self.attention_layernorms[i](seqs)
Q = self.attention_layers[i](seqs, seqs, seqs,
attn_mask=~timeline_mask)
seqs = Q + seqs
# 前馈网络处理
seqs = self.forward_layernorms[i](seqs)
seqs = self.forward_layers[i](seqs)
seqs *= ~timeline_mask.unsqueeze(-1)
return seqs
SASRec使用单向自注意力机制,通过掩码确保只关注历史信息,能更好捕捉长序列中的依赖关系。
3. LLMESR_Bert4Rec - 基于Transformer编码器的序列建模
class BertBackbone(nn.Module):
def __init__(self, device, args):
super().__init__()
# 完整的Transformer编码器
encoder_layers = nn.TransformerEncoderLayer(
d_model=args.hidden_size,
nhead=args.num_heads,
dim_feedforward=args.hidden_size*4, # 更大的前馈网络
dropout=args.dropout_rate,
batch_first=True
)
self.transformer_encoder = nn.TransformerEncoder(
encoder_layers,
num_layers=args.block_num
)
def forward(self, seqs, log_seqs):
# 生成注意力掩码(允许双向注意力)
attn_mask = ~torch.tril(torch.ones(log_seqs.shape[1], log_seqs.shape[1])).bool()
output = self.transformer_encoder(
seqs,
mask=attn_mask.to(self.dev)
)
return output
Bert4Rec采用双向Transformer编码器架构,通过CLOZE任务预测,允许模型同时考虑上下文信息,表达能力最强。
二、大语言模型集成实现
三个模型共享的LLM集成代码(以DualLLMGRU4Rec为例):
# 加载LLM物品嵌入
llm_item_emb = pickle.load(open(os.path.join("data/"+args.dataset+"/handled/", "itm_emb_np.pkl"), "rb"))
# 添加填充向量和掩码向量
llm_item_emb = np.insert(llm_item_emb, 0, values=np.zeros((1, llm_item_emb.shape[1])), axis=0)
llm_item_emb = np.concatenate([llm_item_emb, np.zeros((1, llm_item_emb.shape[1]))], axis=0)
# 转换为Embedding层
self.llm_item_emb = nn.Embedding.from_pretrained(torch.Tensor(llm_item_emb))
self.llm_item_emb.weight.requires_grad = True
# 维度调整适配器
self.adapter = nn.Sequential(
nn.Linear(llm_item_emb.shape[1], int(llm_item_emb.shape[1] / 2)),
nn.Linear(int(llm_item_emb.shape[1] / 2), args.hidden_size)
)
# 条件性使用跨模态注意力
if self.use_cross_att:
self.cross_att = Multi_CrossAttention(args.hidden_size, self.num_heads)
嵌入层获取实现(以DualLLMSASRec为例):
def _get_embedding(self, log_seqs):
# 获取ID嵌入和LLM嵌入
id_emb = self.id_item_emb(log_seqs)
llm_emb = self.adapter(self.llm_item_emb(log_seqs))
# 根据配置使用跨模态注意力或简单融合
if self.use_cross_att:
seq_emb = self.cross_att(id_emb, llm_emb)
else:
seq_emb = id_emb + llm_emb
return seq_emb
三、特征对齐机制实现
三个LLMESR模型共享的对齐机制(以LLMESR_GRU4Rec为例):
def forward(self, seq, pos, neg, positions, **kwargs):
# 1. 获取主序列特征
log_feats = self.log2feats(seq)[:, -1, :]
# 2. 处理相似用户序列特征
sim_seq = kwargs["sim_seq"].view(-1, seq.shape[1])
sim_num = kwargs["sim_seq"].shape[1] # 每个用户的相似序列数量
sim_log_feats = self.log2feats(sim_seq)[:, -1, :]
sim_log_feats = sim_log_feats.detach().view(seq.shape[0], sim_num, -1)
sim_log_feats = torch.mean(sim_log_feats, dim=1) # 平均池化
# 3. 计算主BPR损失
pos_emb = self.id_item_emb(pos)
neg_emb = self.id_item_emb(neg)
pos_logits = (log_feats * pos_emb).sum(dim=-1)
neg_logits = (log_feats * neg_emb).sum(dim=-1)
loss = -torch.mean(torch.log(torch.sigmoid(pos_logits - neg_logits)))
# 4. 计算对齐损失
if self.user_sim_func == "cl":
# 对比学习方式对齐
align_loss = self.align(log_feats, sim_log_feats)
elif self.user_sim_func == "kd":
# 知识蒸馏方式对齐
align_loss = self.align(log_feats, sim_log_feats)
# 5. 可选的物品表示正则化
if self.item_reg:
unfold_item_id = torch.masked_select(seq, seq>0)
llm_item_emb = self.adapter(self.llm_item_emb(unfold_item_id))
id_item_emb = self.id_item_emb(unfold_item_id)
reg_loss = self.reg(llm_item_emb, id_item_emb)
loss += self.beta * reg_loss
# 6. 组合最终损失
loss += self.alpha * align_loss
return loss
四、三个模型的核心差异汇总
特性 | LLMESR_GRU4Rec | LLMESR_SASRec | LLMESR_Bert4Rec |
---|---|---|---|
序列处理 | GRU循环网络 | 单向自注意力 | 双向Transformer |
时间顺序处理 | 隐式 | 显式(位置编码+掩码) | 显式(位置编码+双向) |
特征维度 | 2*hidden_size | hidden_size | hidden_size |
特征获取 | log2feats返回最后隐藏状态 | log2feats返回各位置隐藏状态 | log2feats返回各位置输出 |
参数量 | 最小 | 中等 | 最大 |
五、模型对齐机制原理
# 对比学习实现
class Contrastive_Loss2(nn.Module):
def __init__(self, alpha=0.5, tau_min=0.1, tau_max=1.0):
super().__init__()
self.alpha = alpha
self.tau_min = tau_min
self.tau_max = tau_max
def forward(self, q, k):
# 计算余弦相似度矩阵
logits = torch.mm(q, k.T) # [batch_size, batch_size]
# 获取正样本相似度 (对角线元素)
pos_sim = torch.diag(logits)
# 动态温度计算(基于logits分布)
tau = torch.std(logits).clamp(self.tau_min, self.tau_max)
# InfoNCE损失计算
pos_exp = torch.exp(pos_sim / tau)
neg_exp = torch.exp(logits / tau).sum(dim=1) - pos_exp
loss = -torch.log(pos_exp / (pos_exp + neg_exp)).mean()
return loss
).clamp(self.tau_min, self.tau_max)
# InfoNCE损失计算
pos_exp = torch.exp(pos_sim / tau)
neg_exp = torch.exp(logits / tau).sum(dim=1) - pos_exp
loss = -torch.log(pos_exp / (pos_exp + neg_exp)).mean()
return loss
以上是三个LLMESR模型的核心实现细节,它们在原始序列推荐基础上巧妙集成了大语言模型知识,并通过对比学习或知识蒸馏实现了特征对齐,从而提升推荐效果。