from transformers import BertModel, BertTokenizer
import torch.nn.functional as F
from torch import nn
import torch
import pickle
import numpy as np
from tqdm import tqdm
import math
import logging
logging.basicConfig(filename='evaluation_results.log',  # log file name
                    level=logging.INFO,  # log level
                    format='%(asctime)s - %(levelname)s - %(message)s')  # log format
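# BERTModel wraps a pretrained BERT encoder plus a learned 768 -> 100 affine
# projection (self.w, self.b), so the mean-pooled response embeddings land in
# the same 100-dimensional space as the GNN item embeddings loaded further down.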
class BERTModel(nn.Module):
    def __init__(self, bert_model_name='../bert-base-uncased'):
        super(BERTModel, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.w = nn.Parameter(torch.Tensor(100, 768))
        self.b = nn.Parameter(torch.Tensor(100))
        self.reset_parameters()
    def reset_parameters(self):
        # Initialize only the projection parameters; looping over self.parameters()
        # would also overwrite the pretrained BERT weights.
        stdv = 1.0 / math.sqrt(100)
        for weight in (self.w, self.b):
            weight.data.uniform_(-stdv, stdv)
    def forward(self, batch_text):
        # Encode the text inputs; `tokenizer` and `device` are module-level globals.
        all_response_embeddings = []
        for response_list in batch_text:
            response_embedding = []
            for response in response_list:
                inputs = tokenizer(response, return_tensors='pt', max_length=128, truncation=True, padding='max_length').to(device)
                output = self.bert(**inputs)
                # Mean-pool the token embeddings -> (1, 768)
                embedding = output.last_hidden_state.mean(dim=1)
                response_embedding.append(embedding)
            # Average over the responses, then project 768 -> 100
            response_embedding = torch.stack(response_embedding).to(device)
            mean_response_embedding = torch.mean(response_embedding.squeeze(1), dim=0)
            t_embedding = torch.matmul(mean_response_embedding, self.w.T) + self.b
            all_response_embeddings.append(t_embedding)
        text_embedding = torch.stack(all_response_embeddings, dim=0).to(device)
        return text_embedding
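# Usage sketch (illustrative only, with made-up response strings): forward()
# takes a list of response lists and returns one 100-d embedding per list:
#   emb = bert_model([["great moisturizer", "smells nice"], ["too oily"]])
#   emb.shape  # torch.Size([2, 100])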
# InfoNCE loss
def info_nce_loss(hidden1, hidden2, temperature=0.07):
    batch_size = hidden1.size(0)
    similarity_matrix = F.cosine_similarity(hidden1.unsqueeze(1), hidden2.unsqueeze(0), dim=-1) / temperature
    labels = torch.arange(batch_size).to(hidden1.device)
    loss = F.cross_entropy(similarity_matrix, labels)
    return loss
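# info_nce_loss treats row i of hidden1 and row i of hidden2 as the positive pair
# and every other row as a negative: it is cross-entropy over the (B, B) matrix of
# temperature-scaled cosine similarities. Minimal sketch with random tensors:
#   h1, h2 = torch.randn(4, 100), torch.randn(4, 100)
#   info_nce_loss(h1, h2)  # scalar tensor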
def KLAlignmentModel(hidden1, hidden2):
    # Softmax both views into distributions (the 1e-8 is a numerical-stability epsilon),
    # then compute KL(p_hidden1 || p_hidden2); F.kl_div expects log-probabilities as
    # input and probabilities as target.
    p_hidden1 = F.softmax(hidden1, dim=-1) + 1e-8
    p_hidden2 = F.softmax(hidden2, dim=-1) + 1e-8
    kl_loss = F.kl_div(p_hidden2.log(), p_hidden1, reduction='batchmean')
    return kl_loss
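# KLAlignmentModel is an alternative alignment objective; it is defined here but
# not called in the training loop below, which uses info_nce_loss instead.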
device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
tokenizer = BertTokenizer.from_pretrained('../bert-base-uncased')
# BERT model
bert_model = BERTModel().to(device)
# Optimizer
optimizer = torch.optim.Adam(bert_model.parameters(), lr=1e-4)
# Load the text data (prompts with candidate responses)
file_path = "../LLM/data/beauty2014/train_prompts_with_candidate_response.txt"
data = pickle.load(open(file_path, 'rb'))
# Load the GNN item embeddings
gnn_path = "../LLM/data/beauty2014/gnn_item_embeddings.txt"
gnn_item_embeddings = pickle.load(open(gnn_path, 'rb'))
# gnn_item_embeddings = np.vstack([item[i].reshape(1, -1) for item in gnn_item_embeddings for i in range(len(item))])
def batch_generator(data, batch_size):
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size]
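# batch_generator yields consecutive slices of `data`; the last batch may be smaller
# than batch_size, e.g. list(batch_generator(list(range(5)), 2)) -> [[0, 1], [2, 3], [4]].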
batch_size = 100
# Fine-tune the model
num_epochs = 5
for epoch in range(num_epochs):
    logging.info(f"Epoch {epoch}")
    print(f"---------------------------------- epoch {epoch} ----------------------------------")
    index = 0
    all_loss = 0
    for batch in tqdm(batch_generator(data, batch_size)):
        batch_text = []
        for raw_texts in batch:
            response_list = raw_texts[0].split(";")
            batch_text.append(response_list)
        batch_text_embedding = bert_model(batch_text)
        gnn_embedding = torch.Tensor(gnn_item_embeddings[index]).to(device)
        loss = info_nce_loss(batch_text_embedding, gnn_embedding)
        all_loss += loss.item()  # accumulate as a float so the computation graph is not retained
        if index % 100 == 0:
            logging.info(loss.item())
            print(loss.item())
        index += 1
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    logging.info(f"Epoch {epoch}, Loss: {all_loss}")
    print(f"Epoch {epoch}, Loss: {all_loss}")
# Save the fine-tuned weights
torch.save(bert_model.state_dict(), 'bert_model.pth')
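# To reload later (sketch, assuming the same '../bert-base-uncased' checkpoint is available):
#   model = BERTModel()
#   model.load_state_dict(torch.load('bert_model.pth', map_location=device))
#   model.eval()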