使用 Detectron2 训练自定义目标检测模型
Detectron2 是 Facebook AI Research (FAIR) 开源的一款高效的目标检测框架,支持多种先进算法,如 Faster R-CNN、Mask R-CNN、RetinaNet 等。本文将手把手教你如何使用 Detectron2 训练自定义目标检测数据集。
🚀 一、环境准备
建议使用 Python 3.8+ 和 PyTorch 对应版本,并配置合适的 CUDA 环境。
1. 安装 Detectron2
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install opencv-python pycocotools
pip install git+https://github.com/facebookresearch/detectron2.git
验证是否安装成功:
python -m detectron2.utils.collect_env
📁 二、准备数据集(COCO 格式)
Detectron2 支持的格式为 COCO JSON 格式。
1. 数据集目录结构示意
datasets/
└── my_dataset/
├── annotations/
│ ├── instances_train.json
│ └── instances_val.json
├── train/
└── val/
2. 生成 COCO JSON
推荐使用 LabelMe + labelme2coco.py
脚本或 Roboflow 导出生成 COCO 格式的 JSON 标注文件。
🔗 三、注册数据集
from detectron2.data.datasets import register_coco_instances
register_coco_instances("my_dataset_train", {},
"datasets/my_dataset/annotations/instances_train.json",
"datasets/my_dataset/train")
register_coco_instances("my_dataset_val", {},
"datasets/my_dataset/annotations/instances_val.json",
"datasets/my_dataset/val")
💡 四、设置配置 Config
from detectron2.engine import DefaultTrainer
from detectron2.config import get_cfg
from detectron2 import model_zoo
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("my_dataset_train",)
cfg.DATASETS.TEST = ("my_dataset_val",)
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025
cfg.SOLVER.MAX_ITER = 3000
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 3 # 按你的类别数修改
🧑🏫 五、训练模型
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()
输出将保存在 output/
文件夹中:
model_final.pth
:最终模型权重last_checkpoint
:最新模型路径- TensorBoard 日志:用于查看 loss、mAP、学习率等曲线
🧾 六、训练过程输出说明
1. 权重文件
output/model_final.pth
output/last_checkpoint
2. TensorBoard 日志文件
tensorboard --logdir output/
3. 控制台评估输出示例:
Evaluation results for bbox:
AP50: 85.3
AP75: 71.2
APs: 45.6
APm: 60.4
APl: 78.9
🧪 七、模型推理测试
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
import cv2
predictor = DefaultPredictor(cfg)
image = cv2.imread("datasets/my_dataset/val/image1.jpg")
outputs = predictor(image)
v = Visualizer(image[:, :, ::-1], MetadataCatalog.get("my_dataset_train"), scale=1.0)
out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
cv2.imshow("result", out.get_image()[:, :, ::-1])
cv2.waitKey(0)
推理输出结构说明
outputs["instances"]
包含:
pred_boxes
:边界框(tensor[N, 4])scores
:置信度pred_classes
:预测的类别 ID- (可选)
pred_masks
:实例分割掩码(Mask R-CNN)
示例输出:
Instances(num_instances=2, fields=[
pred_boxes: tensor([[ 34.1, 45.0, 180.5, 200.2], [201.1, 40.0, 305.5, 210.3]]),
scores: tensor([0.95, 0.88]),
pred_classes: tensor([0, 2])
])
✅ 总结
使用 Detectron2 训练自定义目标检测模型的流程如下:
- 安装环境
- 准备数据(COCO 格式)
- 注册数据集
- 配置模型参数
- 启动训练过程
- 使用模型进行推理并可视化
- 分析训练输出和模型性能
Detectron2 提供了强大的算法基线与灵活的配置接口,非常适合中高级开发者进行实际项目开发与研究。