LLM-ESR实验代码讲解

LLM增强序列推荐(LLMESR)模型代码详解

论文讲解见 LLM-ESR Large Language Models Enhancement for Long-tailed Sequential Recommendation
image-20250303114202902

三个模型(LLMESR_SASRec、LLMESR_GRU4Rec和LLMESR_Bert4Rec)的具体代码实现细节。

首先,理解整体的继承关系:

  1. LLMESR_SASRec继承自DualLLMSASRec,而DualLLMSASRec继承自SASRec_seq,SASRec_seq继承自SASRec
  2. LLMESR_GRU4Rec继承自DualLLMGRU4Rec,而DualLLMGRU4Rec继承自GRU4Rec
  3. LLMESR_Bert4Rec继承自DualLLMBert4Rec,而DualLLMBert4Rec继承自Bert4Rec

这三个模型都是在序列推荐模型的基础上增加了大语言模型(LLM)的知识。

从DualLLM系列模型开始分析,因为这是三个LLMESR模型直接继承的类。从DualLLMGRU4Rec的代码可以看到:

  1. 加载预训练的LLM物品嵌入
  2. 使用adapter将LLM嵌入维度调整到模型的隐藏维度
  3. 可能包含cross-attention机制

下面分析具体实现差异:

1. 基础序列编码机制

  • GRU4Rec: 使用GRU(门控循环单元)处理序列数据
  • SASRec: 使用单向自注意力机制处理序列
  • Bert4Rec: 使用双向Transformer架构处理序列

2. 输入处理方式

  • GRU4Rec: 直接处理序列输入
  • SASRec: 需要位置编码
  • Bert4Rec: 使用位置编码和mask机制

3. DualLLM的公共设计

  • 所有DualLLM模型都加载预训练的LLM物品嵌入
  • 通过adapter网络调整LLM嵌入维度
  • 支持跨模态注意力(cross-attention)

4. LLMESR系列的增强

  • 增加了对比学习/知识蒸馏机制

  • 增加物品表示正则化选项

以下从三个方面详细对比分析三种模型的实现:基础架构、大语言模型集成方式和特征对齐机制。

一、基础序列建模实现

1. LLMESR_GRU4Rec - 基于GRU的序列建模

class GRU4RecBackbone(nn.Module):
    def __init__(self, device, args):
        super().__init__()
        # GRU核心组件
        self.gru = nn.GRU(
            input_size=args.hidden_size, 
            hidden_size=args.hidden_size,  
            num_layers=args.gru_layer_num,  # 多层GRU
            batch_first=True
        )
    
    def forward(self, seqs, log_seqs):
        # GRU处理序列数据,不需要位置编码
        output, hidden = self.gru(seqs)  # output: [batch_size, seq_len, hidden]
        return output

GRU4Rec通过循环网络处理序列,记忆门控机制能够捕捉长期依赖,实现相对简单。

2. LLMESR_SASRec - 基于自注意力的序列建模

class SASRecBackbone(nn.Module):
    def __init__(self, device, args):
        super().__init__()
        # 自注意力核心组件
        self.attention_layernorms = nn.ModuleList()
        self.attention_layers = nn.ModuleList()
        self.forward_layernorms = nn.ModuleList()
        self.forward_layers = nn.ModuleList()
        
        for _ in range(args.block_num):
            self.attention_layernorms.append(nn.LayerNorm(args.hidden_size))
            self.attention_layers.append(
                nn.MultiheadAttention(
                    args.hidden_size, 
                    args.num_heads, 
                    args.dropout_rate
                )
            )
            self.forward_layernorms.append(nn.LayerNorm(args.hidden_size))
            self.forward_layers.append(nn.Linear(args.hidden_size, args.hidden_size))
            
    def forward(self, seqs, log_seqs):
        # 生成时间线掩码(确保只看到历史信息)
        timeline_mask = torch.BoolTensor(log_seqs == 0).to(self.dev)
        seqs *= ~timeline_mask.unsqueeze(-1)
        
        # 自注意力层处理
        for i in range(len(self.attention_layers)):
            # 自注意力处理
            seqs = self.attention_layernorms[i](seqs)
            Q = self.attention_layers[i](seqs, seqs, seqs, 
                                         attn_mask=~timeline_mask)
            seqs = Q + seqs
            
            # 前馈网络处理
            seqs = self.forward_layernorms[i](seqs)
            seqs = self.forward_layers[i](seqs)
            seqs *= ~timeline_mask.unsqueeze(-1)
            
        return seqs

SASRec使用单向自注意力机制,通过掩码确保只关注历史信息,能更好捕捉长序列中的依赖关系。

3. LLMESR_Bert4Rec - 基于Transformer编码器的序列建模

class BertBackbone(nn.Module):
    def __init__(self, device, args):
        super().__init__()
        # 完整的Transformer编码器
        encoder_layers = nn.TransformerEncoderLayer(
            d_model=args.hidden_size, 
            nhead=args.num_heads,
            dim_feedforward=args.hidden_size*4,  # 更大的前馈网络
            dropout=args.dropout_rate,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layers, 
            num_layers=args.block_num
        )
        
    def forward(self, seqs, log_seqs):
        # 生成注意力掩码(允许双向注意力)
        attn_mask = ~torch.tril(torch.ones(log_seqs.shape[1], log_seqs.shape[1])).bool()
        output = self.transformer_encoder(
            seqs, 
            mask=attn_mask.to(self.dev)
        )
        return output

Bert4Rec采用双向Transformer编码器架构,通过CLOZE任务预测,允许模型同时考虑上下文信息,表达能力最强。

二、大语言模型集成实现

三个模型共享的LLM集成代码(以DualLLMGRU4Rec为例):

# 加载LLM物品嵌入
llm_item_emb = pickle.load(open(os.path.join("data/"+args.dataset+"/handled/", "itm_emb_np.pkl"), "rb"))
# 添加填充向量和掩码向量
llm_item_emb = np.insert(llm_item_emb, 0, values=np.zeros((1, llm_item_emb.shape[1])), axis=0)
llm_item_emb = np.concatenate([llm_item_emb, np.zeros((1, llm_item_emb.shape[1]))], axis=0)
# 转换为Embedding层
self.llm_item_emb = nn.Embedding.from_pretrained(torch.Tensor(llm_item_emb))    
self.llm_item_emb.weight.requires_grad = True

# 维度调整适配器
self.adapter = nn.Sequential(
    nn.Linear(llm_item_emb.shape[1], int(llm_item_emb.shape[1] / 2)),
    nn.Linear(int(llm_item_emb.shape[1] / 2), args.hidden_size)
)

# 条件性使用跨模态注意力
if self.use_cross_att:
    self.cross_att = Multi_CrossAttention(args.hidden_size, self.num_heads)

嵌入层获取实现(以DualLLMSASRec为例):

def _get_embedding(self, log_seqs):
    # 获取ID嵌入和LLM嵌入
    id_emb = self.id_item_emb(log_seqs)
    llm_emb = self.adapter(self.llm_item_emb(log_seqs))
    
    # 根据配置使用跨模态注意力或简单融合
    if self.use_cross_att:
        seq_emb = self.cross_att(id_emb, llm_emb)
    else:
        seq_emb = id_emb + llm_emb
        
    return seq_emb

三、特征对齐机制实现

三个LLMESR模型共享的对齐机制(以LLMESR_GRU4Rec为例):

def forward(self, seq, pos, neg, positions, **kwargs):
    # 1. 获取主序列特征
    log_feats = self.log2feats(seq)[:, -1, :]
    
    # 2. 处理相似用户序列特征
    sim_seq = kwargs["sim_seq"].view(-1, seq.shape[1])
    sim_num = kwargs["sim_seq"].shape[1]  # 每个用户的相似序列数量
    sim_log_feats = self.log2feats(sim_seq)[:, -1, :]
    sim_log_feats = sim_log_feats.detach().view(seq.shape[0], sim_num, -1)
    sim_log_feats = torch.mean(sim_log_feats, dim=1)  # 平均池化
    
    # 3. 计算主BPR损失
    pos_emb = self.id_item_emb(pos)
    neg_emb = self.id_item_emb(neg)
    pos_logits = (log_feats * pos_emb).sum(dim=-1)
    neg_logits = (log_feats * neg_emb).sum(dim=-1)
    loss = -torch.mean(torch.log(torch.sigmoid(pos_logits - neg_logits)))
    
    # 4. 计算对齐损失
    if self.user_sim_func == "cl":
        # 对比学习方式对齐
        align_loss = self.align(log_feats, sim_log_feats)
    elif self.user_sim_func == "kd":
        # 知识蒸馏方式对齐
        align_loss = self.align(log_feats, sim_log_feats)
    
    # 5. 可选的物品表示正则化
    if self.item_reg:
        unfold_item_id = torch.masked_select(seq, seq>0)
        llm_item_emb = self.adapter(self.llm_item_emb(unfold_item_id))
        id_item_emb = self.id_item_emb(unfold_item_id)
        reg_loss = self.reg(llm_item_emb, id_item_emb)
        loss += self.beta * reg_loss
    
    # 6. 组合最终损失
    loss += self.alpha * align_loss
    
    return loss

四、三个模型的核心差异汇总

特性LLMESR_GRU4RecLLMESR_SASRecLLMESR_Bert4Rec
序列处理GRU循环网络单向自注意力双向Transformer
时间顺序处理隐式显式(位置编码+掩码)显式(位置编码+双向)
特征维度2*hidden_sizehidden_sizehidden_size
特征获取log2feats返回最后隐藏状态log2feats返回各位置隐藏状态log2feats返回各位置输出
参数量最小中等最大

五、模型对齐机制原理

# 对比学习实现
class Contrastive_Loss2(nn.Module):
    def __init__(self, alpha=0.5, tau_min=0.1, tau_max=1.0):
        super().__init__()
        self.alpha = alpha
        self.tau_min = tau_min
        self.tau_max = tau_max

    def forward(self, q, k):
        # 计算余弦相似度矩阵
        logits = torch.mm(q, k.T)  # [batch_size, batch_size]
        
        # 获取正样本相似度 (对角线元素)
        pos_sim = torch.diag(logits)
        
        # 动态温度计算(基于logits分布)
        tau = torch.std(logits).clamp(self.tau_min, self.tau_max)
        
        # InfoNCE损失计算
        pos_exp = torch.exp(pos_sim / tau)
        neg_exp = torch.exp(logits / tau).sum(dim=1) - pos_exp
        loss = -torch.log(pos_exp / (pos_exp + neg_exp)).mean()
        
        return loss

).clamp(self.tau_min, self.tau_max)

    # InfoNCE损失计算
    pos_exp = torch.exp(pos_sim / tau)
    neg_exp = torch.exp(logits / tau).sum(dim=1) - pos_exp
    loss = -torch.log(pos_exp / (pos_exp + neg_exp)).mean()
    
    return loss

以上是三个LLMESR模型的核心实现细节,它们在原始序列推荐基础上巧妙集成了大语言模型知识,并通过对比学习或知识蒸馏实现了特征对齐,从而提升推荐效果。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

三水编程

感谢客官打赏~~~

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值