BERT & GNN 微调

from transformers import BertModel, BertTokenizer
import torch.nn.functional as F
from torch import nn
import torch
import pickle
import numpy as np
from tqdm import tqdm
import math
import logging

# Send INFO-and-above log records to 'evaluation_results.log', each line
# prefixed with a timestamp and the level name.
logging.basicConfig(
    filename='evaluation_results.log',
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

class BERTModel(nn.Module):
    """Encode each item's list of candidate responses into a 100-d vector.

    Every response is encoded with BERT and mean-pooled over its real
    (non-padding) tokens. The per-response vectors of one item are then
    averaged and projected from 768 to 100 dimensions with the learnable
    ``w`` / ``b`` parameters.
    """

    def __init__(self, bert_model_name='../bert-base-uncased'):
        super(BERTModel, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        # 768 -> 100 projection kept as raw parameters (not nn.Linear) so the
        # saved state_dict keys remain 'w' and 'b'.
        self.w = nn.Parameter(torch.Tensor(100, 768))
        self.b = nn.Parameter(torch.Tensor(100))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the projection ``w`` / ``b`` uniformly in [-0.1, 0.1].

        Only the projection is initialised here. Iterating over
        ``self.parameters()`` would also overwrite every pretrained BERT
        weight, so fine-tuning would start from random noise.
        """
        stdv = 1.0 / math.sqrt(100)
        for weight in (self.w, self.b):
            weight.data.uniform_(-stdv, stdv)

    @staticmethod
    def _masked_mean(hidden, mask):
        """Average ``hidden`` (n, seq, dim) over the positions where ``mask`` (n, seq) is non-zero."""
        weights = mask.unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * weights).sum(dim=1)
        # clamp avoids division by zero for an (unexpected) all-padding row
        counts = weights.sum(dim=1).clamp(min=1.0)
        return summed / counts

    def forward(self, batch_text):
        """Return a (len(batch_text), 100) tensor, one row per item.

        Args:
            batch_text: sequence of items. Each item is presumably a list of
                response strings (the training loop builds it with
                ``str.split(";")``).

        Uses the module-level ``tokenizer``. Inputs are moved to the device
        holding this module's parameters.
        """
        param_device = self.w.device
        item_embeddings = []
        for response_list in batch_text:
            # Encode all responses of one item in a single BERT call.
            inputs = tokenizer(list(response_list), return_tensors='pt', max_length=128,
                               truncation=True, padding=True).to(param_device)
            output = self.bert(**inputs)
            # Pool only over real tokens; padding states would dilute the mean.
            response_embeddings = self._masked_mean(output.last_hidden_state,
                                                    inputs['attention_mask'])
            mean_response_embedding = response_embeddings.mean(dim=0)
            item_embeddings.append(F.linear(mean_response_embedding, self.w, self.b))
        return torch.stack(item_embeddings, dim=0)

# InfoNCE loss
def info_nce_loss(hidden1, hidden2, temperature=0.07):
    """Contrastive InfoNCE loss pairing row i of ``hidden1`` with row i of ``hidden2``.

    Builds the pairwise cosine-similarity matrix between every row of
    ``hidden1`` and every row of ``hidden2`` and divides it by
    ``temperature``. It then applies cross-entropy with the diagonal as the
    positive class, so the other rows of ``hidden2`` act as negatives.

    Args:
        hidden1: anchor embeddings; dimension 0 is the batch.
        hidden2: positive embeddings; presumably the same batch size as
            ``hidden1`` (TODO confirm against callers).
        temperature: similarity scale; smaller values sharpen the softmax.

    Returns:
        Scalar loss tensor (cross-entropy averaged over the anchors).
    """
    n_anchors = hidden1.size(0)
    pairwise = F.cosine_similarity(hidden1[:, None], hidden2[None], dim=-1)
    logits = pairwise / temperature
    positives = torch.arange(n_anchors, device=hidden1.device)
    return F.cross_entropy(logits, positives)

def KLAlignmentModel(hidden1, hidden2):
    """Return KL(softmax(hidden1) || softmax(hidden2)) over the last dimension.

    Despite the class-like name this is a plain loss function; both inputs are
    treated as logits. With ``reduction='batchmean'`` the pointwise terms are
    summed and divided by ``hidden2.size(0)``.

    The code uses ``log_softmax`` rather than ``log(softmax(x) + 1e-8)``. The
    epsilon form clamps log-probabilities at about -18.4, so once softmax
    underflows the divergence and its gradient are badly underestimated.
    """
    log_p1 = F.log_softmax(hidden1, dim=-1)  # target distribution (log space)
    log_p2 = F.log_softmax(hidden2, dim=-1)  # approximating distribution (log space)
    return F.kl_div(log_p2, log_p1, reduction='batchmean', log_target=True)

# Prefer GPU 3 (the original target) when it exists. Otherwise fall back to the
# first GPU, or to the CPU when CUDA is unavailable. A hard-coded 'cuda:3'
# fails with an invalid-device error on machines with fewer than four GPUs.
if torch.cuda.is_available():
    device = torch.device('cuda:3' if torch.cuda.device_count() > 3 else 'cuda:0')
else:
    device = torch.device('cpu')
tokenizer = BertTokenizer.from_pretrained('../bert-base-uncased')

# BERT encoder plus 768->100 projection head
bert_model = BERTModel().to(device)

# Adam over every trainable parameter: the BERT encoder and the projection w/b.
optimizer = torch.optim.Adam(bert_model.parameters(), lr=1e-4)

# Load the training text. Despite the .txt extension both files are pickles.
# NOTE(review): pickle.load can execute arbitrary code from the file; only load
# files produced by your own pipeline.
# Each record's first field is a ';'-separated string of candidate responses.
file_path = "../LLM/data/beauty2014/train_prompts_with_candidate_response.txt"
with open(file_path, 'rb') as f:
    data = pickle.load(f)

# Load the pre-computed GNN item embeddings. The training loop indexes them by
# batch number; presumably one entry per batch of `batch_size` records (confirm).
gnn_path = "../LLM/data/beauty2014/gnn_item_embeddings.txt"
with open(gnn_path, 'rb') as f:
    gnn_item_embeddings = pickle.load(f)

def batch_generator(data, batch_size):
    """Lazily yield consecutive slices of ``data`` with ``batch_size`` elements each.

    The final slice is shorter when ``len(data)`` is not a multiple of
    ``batch_size``.
    """
    starts = range(0, len(data), batch_size)
    yield from (data[start:start + batch_size] for start in starts)

batch_size = 100

# Fine-tune the model
num_epochs = 5
# from_pretrained() returns BERT in eval mode (dropout disabled); switch the
# whole model to training mode before fine-tuning.
bert_model.train()
for epoch in range(num_epochs):
    logging.info(epoch)
    print("----------------------------------第", epoch, "次-----------------------------")
    # Accumulate a Python float. Summing the loss tensors themselves would keep
    # every batch's autograd graph alive until the end of the epoch.
    all_loss = 0.0
    for index, batch in enumerate(tqdm(batch_generator(data, batch_size))):
        # Each record's first field is a ';'-separated string of candidate responses.
        batch_text = [raw_texts[0].split(";") for raw_texts in batch]
        batch_text_embedding = bert_model(batch_text)
        # GNN embeddings are stored per batch, in the same order as batch_generator.
        gnn_embedding = torch.as_tensor(gnn_item_embeddings[index], dtype=torch.float32, device=device)
        loss = info_nce_loss(batch_text_embedding, gnn_embedding)
        loss_value = loss.item()
        all_loss += loss_value
        if index % 100 == 0:
            logging.info(loss_value)
            print(loss_value)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    logging.info(f"Epoch {epoch}, Loss: {all_loss}")
    print(f"Epoch {epoch}, Loss: {all_loss}")

# Save the fine-tuned weights
torch.save(bert_model.state_dict(), 'bert_model.pth')

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值