bert预训练模型怎么用
时间: 2025-06-19 12:36:49 浏览: 16
### 如何使用 BERT 预训练模型进行微调和预测
#### 加载预训练模型
为了加载并使用预训练的 BERT 模型,通常会借助 Hugging Face 的 `transformers` 库来简化这一过程。通过该库可以直接获取官方发布的各种版本的 BERT 模型及其对应的分词器。
```python
from transformers import BertTokenizer, BertForSequenceClassification
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=2)
```
这段代码展示了如何从指定路径加载中文版的基础 BERT 模型用于序列分类任务,并设置类别数量为二类[^3]。
#### 数据准备与处理
对于输入数据而言,需要先将其转换成适合喂给 BERT 模型的形式。这涉及到文本编码、填充至固定长度等操作:
```python
def preprocess(texts):
inputs = tokenizer(
texts,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt"
)
return inputs
```
此函数接收一批次原始字符串形式的文本列表作为输入,并返回经过适当格式化后的 PyTorch 张量对象以便后续传递给模型进行推理或训练。
#### 微调模型
当拥有了已经标注好类别的样本集之后就可以着手于对基础 BERT 进行特定下游任务上的调整即所谓的“微调”。这里以简单的二元情感分析为例说明具体做法:
```python
import torch
from torch.utils.data import DataLoader, Dataset
class TextDataset(Dataset):
def __init__(self, texts, labels=None):
self.texts = texts
self.labels = labels
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
item = {key: val[idx] for key, val in preprocess([self.texts[idx]]).items()}
if self.labels is not None:
item['labels'] = torch.tensor(self.labels[idx])
return item
train_texts = ["我喜欢这部电影", "我不喜欢这本书"]
train_labels = [1, 0]
dataset = TextDataset(train_texts, train_labels)
dataloader = DataLoader(dataset, batch_size=8)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
for epoch in range(3): # 训练三个epoch
model.train()
total_loss = 0.
for batch in dataloader:
optimizer.zero_grad()
outputs = model(**batch)
loss = outputs.loss
total_loss += loss.item()
loss.backward()
optimizer.step()
```
上述脚本定义了一个简易的数据集类用来封装待训练的短句与其对应的情感极性标签;接着创建了相应的迭代器供批量读取之用;最后实现了标准的小批次梯度下降更新逻辑完成一轮完整的参数优化循环。
#### 执行预测
一旦完成了必要的迁移学习步骤,则可利用所得到的新权重来进行未知实例所属类型的判定工作:
```python
test_text = ["这部剧真好看"]
with torch.no_grad():
predictions = model(**preprocess(test_text))
predicted_class_id = predictions.logits.argmax().item()
print(f'Predicted class ID: {predicted_class_id}')
```
以上片段演示了怎样针对单条记录执行前向传播计算从而获得其最有可能归属的目标编号。
阅读全文
相关推荐


















