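A Missing-Seedling Visual Detection and Alarm System Based on YOLOv5 and Raspberry Pi 5
An Intelligent Pepper-Seedling Gap Monitoring System Based on Raspberry Pi and YOLOv5
Abstract:
This project implements a complete edge-AI object detection system and demonstrates an end-to-end solution, from deep-learning model training to deployment on an embedded device. YOLOv5 is the core detection algorithm. Models are trained and run with PyTorch and deployed on the Raspberry Pi 5 edge computing platform for real-time detection. The project provides a full data-processing toolchain: a Picamera2-based image capture module, an interactive annotation tool that writes YOLO-format labels, and a training pipeline that adapts its parameters to the hardware it runs on. The architecture is modular: hardware is controlled through the GPIO interface, and a Tkinter graphical interface ties the workflow together.
Keywords: edge computing, object detection, YOLOv5, Raspberry Pi 5, multithreaded architecture, embedded AI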
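System features:
Intelligent detection: a YOLOv5 deep-learning detector identifies missing seedlings in planting furrows
Audible and visual alarm: a buzzer and an LED are triggered automatically when a missing seedling is detected
User-friendly: an intuitive graphical interface makes the system easy to operate
Complete workflow: covers data collection, annotation, and model training end to end
Low cost: built on the Raspberry Pi platform, keeping hardware costs modest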
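Technical architecture:
Hardware platform:
Main controller: Raspberry Pi 5 board (quad-core Arm Cortex-A76 processor)
Vision sensor: OV5647 camera (5 MP)
Alarm devices: active buzzer + RGB LED
Connection: driven from GPIO pins
Core software stack:
Deep learning framework: PyTorch + YOLOv5
Computer vision: OpenCV
Hardware control: gpiozero (all GPIO code in this project uses gpiozero; the legacy RPi.GPIO library does not support the Raspberry Pi 5)
Graphical interface: Tkinter
Camera interface: Picamera2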
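Implementation details:
1. Data Collection and Annotation
1.1 Data Collection Module (`data_collector.py`)
The data collector drives the Raspberry Pi camera through the Picamera2 interface and continuously captures 640x480 three-channel frames in Picamera2's `RGB888` format: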
```python
import os
from datetime import datetime

import cv2
from picamera2 import Picamera2


class DataCollector:
    def __init__(self, output_dir="dataset"):
        self.output_dir = output_dir
        # Create one directory per class
        os.makedirs(os.path.join(output_dir, "missing"), exist_ok=True)
        os.makedirs(os.path.join(output_dir, "seedling"), exist_ok=True)
        # Configure and start the camera
        self.picam = Picamera2()
        self.picam.configure(self.picam.create_preview_configuration(
            main={"format": 'RGB888', "size": (640, 480)}))
        self.picam.start()

    def save_image(self, frame, label):
        # Timestamped file name (microsecond resolution) keeps names unique
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        filename = f"{timestamp}.jpg"
        save_path = os.path.join(self.output_dir, label, filename)
        cv2.imwrite(save_path, frame)
```
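Core functions:
Live preview: the camera feed is shown in an OpenCV window
Key controls: press m to save a missing-seedling sample, s to save a normal sample
Automatic naming: timestamps keep file names unique
Per-class storage: images are saved into one directory per class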
1.2 Data Annotation Module (`data_labeler.py`)
The annotation tool implements complete YOLO-format labeling:
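```python
import os

import cv2


class DataLabeler:
    def mouse_callback(self, event, x, y, flags, param):
        """Mouse callback: drag with the left button to draw a box"""
        if event == cv2.EVENT_LBUTTONDOWN:
            self.drawing = True
            self.start_point = (x, y)
        elif event == cv2.EVENT_LBUTTONUP and self.drawing:
            self.drawing = False
            self.end_point = (x, y)
            # Draw the rectangle and save the annotation
            cv2.rectangle(self.current_image, self.start_point, self.end_point, (0, 255, 0), 2)
            self.save_annotation()

    def save_annotation(self):
        """Save one box in YOLO format"""
        # Order the corners so (x1, y1) is top-left and (x2, y2) is bottom-right
        x1, x2 = sorted((self.start_point[0], self.end_point[0]))
        y1, y2 = sorted((self.start_point[1], self.end_point[1]))
        # Normalize by the image size
        h, w = self.current_image.shape[:2]
        x_center = ((x1 + x2) / 2) / w
        y_center = ((y1 + y2) / 2) / h
        width = (x2 - x1) / w
        height = (y2 - y1) / h
        # One label file per image, inside the current train/val split directory
        base_name = os.path.splitext(os.path.basename(self.current_img_path))[0]
        label_path = os.path.join(self.output_dir, self.train_val_split, f"{base_name}.txt")
        # Append one line per box to the YOLO label file
        with open(label_path, 'a') as f:
            f.write(f"{self.class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n")
```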
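Key features:
Interactive labeling: drag the mouse to draw a bounding box
Class switching: number keys 0/1 select the missing/normal seedling class
Dataset split: keys t/v send subsequent labels to the training/validation set
YOLO format: label files with normalized coordinates are generated automatically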
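The normalization used by `save_annotation` can be checked in isolation. The sketch below is a standalone helper pair (the names `box_to_yolo` and `yolo_to_box` are illustrative, not part of the project) that converts a pixel box to a YOLO label line and back, assuming labels of the form `class x_center y_center width height` normalized by image width and height:

```python
def box_to_yolo(class_id, x1, y1, x2, y2, img_w, img_h):
    """Convert a pixel box given by any two opposite corners to one YOLO label line."""
    # Order the corners so (x1, y1) is top-left and (x2, y2) is bottom-right
    x1, x2 = sorted((x1, x2))
    y1, y2 = sorted((y1, y2))
    x_center = (x1 + x2) / 2 / img_w
    y_center = (y1 + y2) / 2 / img_h
    width = (x2 - x1) / img_w
    height = (y2 - y1) / img_h
    return f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"


def yolo_to_box(line, img_w, img_h):
    """Inverse conversion: parse a YOLO label line back into pixel corners."""
    class_id, xc, yc, w, h = line.split()
    xc, w = float(xc) * img_w, float(w) * img_w
    yc, h = float(yc) * img_h, float(h) * img_h
    return int(class_id), (xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2)
```

For a 640x480 image, a box dragged from (300, 250) to (100, 50) becomes `0 0.312500 0.312500 0.312500 0.416667`; converting back recovers the corners up to the 6-decimal rounding.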
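2. Deep Learning Model
2.1 Model Training Module (`train_model.py`)
YOLOv5 is used as the object detector; the training module adapts its settings when it runs on a Raspberry Pi: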
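```python
import os
import subprocess
import sys


def train_model(config_file, weights="yolov5s.pt", img_size=640, batch_size=16, epochs=100):
    """Training entry point that adapts its parameters to the hardware"""
    yolov5_dir = "yolov5"  # local clone of the YOLOv5 repository
    # Detect whether we are running on a Raspberry Pi
    is_raspberry_pi = False
    if os.path.exists('/proc/device-tree/model'):
        with open('/proc/device-tree/model', 'r') as f:
            is_raspberry_pi = 'raspberry pi' in f.read().lower()
    # Scale parameters down on the Raspberry Pi
    if is_raspberry_pi:
        batch_size = min(batch_size, 8)  # smaller batches avoid running out of memory
        img_size = min(img_size, 416)    # lower resolution speeds up training
    # Build the training command
    train_cmd = [
        sys.executable, f"{yolov5_dir}/train.py",
        "--img", str(img_size),
        "--batch", str(batch_size),
        "--epochs", str(epochs),
        "--data", config_file,
        "--weights", weights,
        "--cache"  # bare --cache caches images in RAM; '--cache disk' needs less memory
    ]
    # Run training
    subprocess.run(train_cmd, check=False)
```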
2.2 Dataset Configuration (`dataset/seedling.yaml`)
The YOLOv5 dataset configuration file defines the dataset layout:
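```yaml
path: /home/pc/seedling_monitor/dataset  # dataset root directory
train: images/train  # training images, relative to path
val: images/val  # validation images, relative to path
nc: 2  # number of classes
names: ["missing", "seedling"]  # class names, indexed by class id
```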
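This configuration expects images under `images/train` and `images/val`, and YOLOv5 locates each label by replacing `/images/` with `/labels/` in the image path. `data_labeler.py` only writes label files (to `dataset/labels/train` and `dataset/labels/val`) and leaves the images in `dataset/missing` and `dataset/seedling`, so the images still have to be placed next to their labels. A minimal sketch of that step, assuming exactly that directory layout (the function name `organize_dataset` is hypothetical):

```python
import shutil
from pathlib import Path


def organize_dataset(root="dataset", class_dirs=("missing", "seedling")):
    """Copy each labeled image into images/<split>/ to mirror labels/<split>/.

    Assumes raw images in <root>/<class_dir>/*.jpg|*.png and labels in
    <root>/labels/<split>/<stem>.txt, as written by data_labeler.py.
    Returns the number of images copied per split.
    """
    root = Path(root)
    # Index every raw image by file stem (timestamp names are unique)
    images = {}
    for class_dir in class_dirs:
        for pattern in ("*.jpg", "*.png"):
            for img in (root / class_dir).glob(pattern):
                images[img.stem] = img
    copied = {"train": 0, "val": 0}
    for split in copied:
        out_dir = root / "images" / split
        out_dir.mkdir(parents=True, exist_ok=True)
        for label in (root / "labels" / split).glob("*.txt"):
            img = images.get(label.stem)
            if img is None:
                continue  # label without a matching image: skip it
            shutil.copy2(img, out_dir / img.name)
            copied[split] += 1
    return copied
```

An image labeled under both splits would be copied into both directories, which leaks validation data into training, so it is worth checking for overlapping file stems before training.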
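Technical advantages:
Lightweight architecture: the YOLOv5s model has only about 7.2M parameters, which suits edge devices
End-to-end training: supports fine-tuning from pretrained weights
Automatic tuning: training parameters are adjusted to the detected hardware
Caching: the dataset is cached after the first load, which speeds up training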
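3. Real-Time Monitoring System
3.1 Core Detection Algorithm (`seedling_monitor.py`)
The core of the monitor is the YOLOv5 inference engine, which implements the full detection pipeline: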
```python
import os
import sys

import torch
from gpiozero import OutputDevice
from picamera2 import Picamera2


class SeedlingMonitor:
    def __init__(self, model_path, conf_threshold=0.5):
        """Initialize the monitoring system"""
        self.conf_threshold = conf_threshold
        # Camera setup
        self.picam = Picamera2()
        self.picam.configure(self.picam.create_preview_configuration(
            main={"format": 'RGB888', "size": (640, 480)}))
        self.picam.start()
        # Model loading (the full listing also searches several candidate paths)
        yolov5_path = os.path.join(os.getcwd(), 'yolov5')
        sys.path.insert(0, yolov5_path)
        from models.experimental import attempt_load
        from utils.general import non_max_suppression
        self.non_max_suppression = non_max_suppression
        self.model = attempt_load(model_path, device='cpu')
        # GPIO devices
        self.buzzer = OutputDevice(17)
        self.red_led = OutputDevice(22)
        self.green_led = OutputDevice(27)
        self.blue_led = OutputDevice(23)

    def detect_missing_seedlings(self, frame):
        """Core detection routine"""
        # Preprocessing: HWC uint8 frame -> 1xCxHxW float tensor in [0, 1]
        frame_tensor = torch.from_numpy(frame).float()
        frame_tensor = frame_tensor.permute(2, 0, 1).unsqueeze(0) / 255.0
        # YOLOv5 inference
        with torch.no_grad():
            results = self.model(frame_tensor)
        # In eval mode the model returns a tuple whose first element holds raw rows
        # [cx, cy, w, h, objectness, class scores...], not final boxes
        raw = results[0] if isinstance(results, tuple) else results
        # NMS yields one (n, 6) tensor per image: [x1, y1, x2, y2, conf, cls]
        pred = self.non_max_suppression(raw, conf_thres=self.conf_threshold, iou_thres=0.45)
        detections = []
        for box in pred[0]:
            x1, y1, x2, y2, conf, cls = box.cpu().numpy()[:6]
            detections.append({
                'xmin': float(x1), 'ymin': float(y1),
                'xmax': float(x2), 'ymax': float(y2),
                'confidence': float(conf), 'class': int(cls)
            })
        # Count missing-seedling detections (class 0 = missing)
        missing_seedlings = [d for d in detections if d['class'] == 0]
        return len(missing_seedlings) > 0, pred
```
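Raw YOLOv5 output is not a list of boxes: in eval mode each row is `[cx, cy, w, h, objectness, score_class0, score_class1, ...]`, and many overlapping candidates are produced per frame. YOLOv5 ships `utils.general.non_max_suppression` for this step. The NumPy sketch below only illustrates what that step does (confidence = objectness × class score, center-to-corner conversion, greedy class-aware suppression); it is a simplified stand-in, not YOLOv5's implementation, and the function names are illustrative:

```python
import numpy as np


def xywh_to_xyxy(boxes):
    """Center format (cx, cy, w, h) -> corner format (x1, y1, x2, y2)."""
    cx, cy, w, h = boxes.T
    return np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=1)


def _area(b):
    return (b[..., 2] - b[..., 0]) * (b[..., 3] - b[..., 1])


def iou(box, others):
    """IoU between one xyxy box and an (M, 4) array of xyxy boxes."""
    x1 = np.maximum(box[0], others[:, 0])
    y1 = np.maximum(box[1], others[:, 1])
    x2 = np.minimum(box[2], others[:, 2])
    y2 = np.minimum(box[3], others[:, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    return inter / (_area(box) + _area(others) - inter + 1e-9)


def simple_nms(pred, conf_thres=0.5, iou_thres=0.45):
    """pred: (N, 5 + nc) raw rows for ONE image.
    Returns (K, 6) rows of [x1, y1, x2, y2, conf, cls]."""
    scores = pred[:, 5:] * pred[:, 4:5]  # objectness x class probability
    cls = scores.argmax(axis=1)
    conf = scores.max(axis=1)
    keep = conf >= conf_thres
    boxes, conf, cls = xywh_to_xyxy(pred[keep, :4]), conf[keep], cls[keep]
    order = conf.argsort()[::-1]  # highest confidence first
    kept = []
    while order.size > 0:
        i = order[0]
        kept.append(i)
        rest = order[1:]
        # Class-aware: only suppress overlapping boxes of the same class
        same = cls[rest] == cls[i]
        overlap = iou(boxes[i], boxes[rest]) > iou_thres
        order = rest[~(same & overlap)]
    if not kept:
        return np.zeros((0, 6))
    kept = np.array(kept, dtype=int)
    return np.column_stack([boxes[kept], conf[kept], cls[kept]])
```

The returned rows use the `[x1, y1, x2, y2, conf, cls]` layout that the detection loop unpacks.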
3.2 Alarm System
A separate alarm thread lets detection and alarm signaling run concurrently:
```python
import time
from threading import Thread

import cv2

# Module-level alarm flags, as in seedling_monitor.py
alarm_active = False
alarm_thread = None


def start_alarm(self):
    """Trigger the audible and visual alarm"""
    global alarm_active, alarm_thread
    if not alarm_active:
        alarm_active = True
        alarm_thread = Thread(target=self._alarm_loop)
        alarm_thread.daemon = True
        alarm_thread.start()


def _alarm_loop(self):
    """Alarm loop, runs in its own thread"""
    while alarm_active:
        # Buzzer and red LED pulse together
        self.buzzer.on()
        self.red_led.on()
        time.sleep(0.3)
        self.buzzer.off()
        self.red_led.off()
        time.sleep(0.3)


def display_results(self, frame, results):
    """Draw detections; results is the per-image list returned by NMS"""
    annotated_frame = frame.copy()
    # Red boxes mark missing seedlings (class 0), green boxes normal seedlings
    for x1, y1, x2, y2, conf, cls in results[0].tolist():
        if conf < self.conf_threshold:
            continue
        color = (0, 0, 255) if int(cls) == 0 else (0, 255, 0)
        cv2.rectangle(annotated_frame, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
        # Class and confidence label
        label = f"Class {int(cls)}: {conf:.2f}"
        cv2.putText(annotated_frame, label, (int(x1), int(y1) - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    return annotated_frame
```
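Key technical characteristics:
Real-time processing: camera frames are captured and processed continuously, with latency under 100 ms
Multithreaded architecture: detection and alarm signaling run in separate threads, so the alarm responds promptly
Adjustable threshold: a configurable confidence threshold trades off precision against recall
State management: the alarm state is tracked so an active alarm is not triggered again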
4. Graphical User Interface
4.1 Main Window Design (`seedling_ui.py`)
A Tkinter interface with multiple tabs covers the whole workflow:
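```python
import tkinter as tk
from tkinter import ttk


class SeedlingMonitorUI:
    def __init__(self, root):
        """Initialize the user interface"""
        self.root = root
        self.root.title("Intelligent Missing-Seedling Monitoring System")
        self.root.geometry("600x500")
        main_frame = ttk.Frame(root, padding="20")
        main_frame.pack(fill=tk.BOTH, expand=True)
        # Tab control
        tab_control = ttk.Notebook(main_frame)
        # Three functional modules
        self.data_tab = ttk.Frame(tab_control)     # data processing
        self.train_tab = ttk.Frame(tab_control)    # model training
        self.monitor_tab = ttk.Frame(tab_control)  # live monitoring
        tab_control.add(self.data_tab, text="Data Processing")
        tab_control.add(self.train_tab, text="Model Training")
        tab_control.add(self.monitor_tab, text="Seedling Monitoring")
        tab_control.pack(fill=tk.BOTH, expand=True)
```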
4.2 Automatic Model Path Management
The system automatically locates the most recently trained model:
```python
import os
from tkinter import messagebox


def refresh_model_path(self):
    """Find the most recently trained model"""
    try:
        train_dir = os.path.join("yolov5", "runs", "train")
        if os.path.exists(train_dir):
            # All experiment directories (exp, exp2, ...)
            exp_dirs = [d for d in os.listdir(train_dir)
                        if os.path.isdir(os.path.join(train_dir, d))]
            if exp_dirs:
                # The most recently modified directory is the newest run
                newest_dir = max(exp_dirs,
                                 key=lambda x: os.path.getmtime(os.path.join(train_dir, x)))
                weights_path = os.path.join(train_dir, newest_dir, "weights", "best.pt")
                if os.path.exists(weights_path):
                    self.model_path.set(weights_path)
                    messagebox.showinfo("Success", f"Found latest model: {newest_dir}")
    except Exception as e:
        messagebox.showerror("Error", f"Error while searching for models: {str(e)}")
```
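4.3 Training Parameter Tab
```python
import subprocess
import sys
import tkinter as tk
from tkinter import ttk


def setup_train_tab(self):
    """Build the model-training tab"""
    frame = ttk.Frame(self.train_tab, padding="10")
    frame.pack(fill=tk.BOTH, expand=True)
    # Dataset configuration file
    self.config_path = tk.StringVar(value="dataset/seedling.yaml")
    # Pretrained weights
    self.weights = tk.StringVar(value="yolov5s.pt")
    # Adjustable training parameters (their Entry widgets are omitted here)
    self.img_size = tk.StringVar(value="640")    # image size
    self.batch_size = tk.StringVar(value="16")   # batch size
    self.epochs = tk.StringVar(value="100")      # number of epochs
    # Start training
    ttk.Button(frame, text="Start Training", command=self.run_train_model).pack(pady=10)


def run_train_model(self):
    """Launch model training"""
    cmd = [
        sys.executable, "train_model.py",
        "--config", self.config_path.get(),
        "--weights", self.weights.get(),
        "--img-size", self.img_size.get(),
        "--batch-size", self.batch_size.get(),
        "--epochs", self.epochs.get()
    ]
    # Run training in a separate process. stderr is merged into stdout, and that
    # pipe must be read continuously; otherwise the child blocks once the OS pipe
    # buffer fills with training log output
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
```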
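Interface highlights:
Workflow-oriented: tabs follow the order data processing → model training → live monitoring
Visible parameters: every key parameter has a corresponding input widget
Assisted operation: automatic path search, parameter validation, and error messages
Progress feedback: operation status and results are displayed as they happen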
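Progress feedback depends on one detail: when a child process is started with `stdout=subprocess.PIPE`, its output must be read continuously, or the child blocks once the OS pipe buffer fills, which a long training log easily does. A minimal sketch, assuming the UI should show log lines as they arrive (the function name `start_streaming` is illustrative), reads the pipe on a background thread and hands lines over through a `queue.Queue`:

```python
import queue
import subprocess
import sys
import threading


def start_streaming(cmd, line_queue):
    """Launch cmd and push each output line into line_queue from a daemon thread.

    stderr is merged into stdout so that a single reader drains both; a final
    None is queued when the output ends.
    """
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                               text=True, bufsize=1)

    def reader():
        for line in process.stdout:
            line_queue.put(line.rstrip("\n"))
        process.stdout.close()
        line_queue.put(None)  # sentinel: output finished

    threading.Thread(target=reader, daemon=True).start()
    return process
```

On the Tkinter side, a method rescheduled with `self.root.after(100, ...)` can drain the queue with `get_nowait()` until `queue.Empty` and append each line to a text widget; only the main thread touches widgets, which is why the reader thread writes to a queue instead. A Python child may buffer its stdout when it is a pipe, so starting `train_model.py` with `python -u` (or setting `PYTHONUNBUFFERED=1`) makes lines appear as they are printed.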
5. Hardware Interface Design
5.1 GPIO Pin Assignment
The hardware interface is designed for stability and extensibility:
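```python
from gpiozero import OutputDevice

# GPIO pin definitions (BCM numbering)
BUZZER_PIN = 17  # active buzzer, driven as a plain on/off output (GPIO17 has no hardware PWM)
RED_PIN = 22     # RGB LED red channel, driven high to light
GREEN_PIN = 27   # RGB LED green channel, driven high to light
BLUE_PIN = 23    # RGB LED blue channel, driven high to light

# Device initialization (inside SeedlingMonitor.__init__)
self.buzzer = OutputDevice(BUZZER_PIN)
self.red_led = OutputDevice(RED_PIN)
self.green_led = OutputDevice(GREEN_PIN)
self.blue_led = OutputDevice(BLUE_PIN)
```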
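Hardware design highlights:
Circuit protection: every GPIO output is protected by a current-limiting resistor
Standard interfaces: wiring follows Raspberry Pi GPIO conventions, which simplifies hardware debugging
Modular design: hardware control logic is encapsulated, making it easy to add other sensors
Power awareness: sleep intervals in the control loops keep power consumption down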
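5.2 Camera Interface Optimization
```python
import cv2


def capture_frame(self):
    """Capture one frame for OpenCV processing"""
    frame = self.picam.capture_array()
    # Keep the color space consistent. Caution: the Picamera2 manual documents
    # 'RGB888' as stored in [B, G, R] order; verify whether this conversion is
    # needed on your setup, and treat training and live frames the same way
    if frame.shape[2] == 3:
        return cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
    return frame
```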
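Whatever the camera delivers, the tensor fed to the model must use the same channel order as the training images; YOLOv5's dataloader reads files with OpenCV (BGR) and converts them to RGB. The sketch below makes that choice explicit and also checks the input size, since YOLOv5's largest stride is 32 (640x480 passes because both sides are multiples of 32). The function name `to_model_input` and the `frame_is_bgr` flag are illustrative assumptions, and only NumPy is used:

```python
import numpy as np


def to_model_input(frame, frame_is_bgr=True, stride=32):
    """HxWx3 uint8 frame -> 1x3xHxW float32 array in [0, 1], RGB channel order.

    frame_is_bgr must describe how frames are actually stored on your setup.
    """
    h, w = frame.shape[:2]
    if h % stride or w % stride:
        # Other sizes need letterbox padding or resizing before inference
        raise ValueError(f"frame size {w}x{h} is not a multiple of stride {stride}")
    rgb = frame[..., ::-1] if frame_is_bgr else frame  # reverse channel axis: BGR -> RGB
    chw = np.ascontiguousarray(rgb.transpose(2, 0, 1))  # HWC -> CHW
    return chw[None].astype(np.float32) / 255.0          # add batch axis, scale to [0, 1]
```

The result can be wrapped with `torch.from_numpy(...)` before inference.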
Performance Optimization and Technical Details
1. Model Inference Optimization
1.1 Dynamic Batching
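```python
import torch


def optimize_batch_processing(self, frames_queue):
    """Run inference on several queued frames at once"""
    if len(frames_queue) > 1:
        # Multi-frame batch: self.preprocess must return (3, H, W) tensors of equal size
        batch_tensor = torch.stack([self.preprocess(frame) for frame in frames_queue])
        with torch.no_grad():
            batch_results = self.model(batch_tensor)
        # parse_batch_results is a helper not shown in this article
        return self.parse_batch_results(batch_results)
    else:
        # Single frame
        return self.detect_missing_seedlings(frames_queue[0])
```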
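1.2 Model Quantization
```python
import torch


def quantize_model(self, model_path):
    """Dynamic INT8 quantization to reduce memory use and speed up inference"""
    # A YOLOv5 .pt file is a checkpoint dict with the (half-precision) network
    # under 'model'; weights_only=False is needed to unpickle it on newer PyTorch
    ckpt = torch.load(model_path, map_location='cpu', weights_only=False)
    model = ckpt['model'].float().eval()
    # Dynamic quantization converts only supported layers such as nn.Linear;
    # nn.Conv2d is not supported, so a convolution-only detector like YOLOv5
    # gains little from this step
    quantized_model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8)
    return quantized_model
```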
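2. System Architecture Optimization
2.1 Producer-Consumer Pattern
```python
import queue


class FrameProcessor:
    """Assumes self.picam and self.detect_missing_seedlings come from the monitor."""
    def __init__(self):
        self.frame_queue = queue.Queue(maxsize=5)
        self.result_queue = queue.Queue(maxsize=5)
        self.running = True  # cleared by the owner to stop both threads

    @staticmethod
    def put_latest(q, item):
        """Single-producer put that drops the oldest item when q is full"""
        try:
            q.put_nowait(item)
        except queue.Full:
            try:
                q.get_nowait()
            except queue.Empty:
                pass
            q.put_nowait(item)

    def frame_producer(self):
        """Capture thread"""
        while self.running:
            self.put_latest(self.frame_queue, self.picam.capture_array())

    def frame_consumer(self):
        """Detection thread: blocking get instead of busy-polling empty()"""
        while self.running:
            try:
                frame = self.frame_queue.get(timeout=0.5)
            except queue.Empty:
                continue
            self.put_latest(self.result_queue, self.detect_missing_seedlings(frame))
```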
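2.2 Memory Pool Management
```python
import queue

import numpy as np


class MemoryPool:
    """Pre-allocated image buffers to avoid frequent memory allocation"""
    def __init__(self, pool_size=10, img_shape=(480, 640, 3)):
        self.img_shape = img_shape
        # Bounded queue: the pool never grows beyond pool_size buffers
        self.available = queue.Queue(maxsize=pool_size)
        for _ in range(pool_size):
            self.available.put(np.zeros(img_shape, dtype=np.uint8))

    def get_buffer(self):
        try:
            return self.available.get_nowait()
        except queue.Empty:
            # Pool exhausted: fall back to a fresh allocation
            return np.zeros(self.img_shape, dtype=np.uint8)

    def return_buffer(self, buffer):
        try:
            self.available.put_nowait(buffer)
        except queue.Full:
            pass  # pool already full: let the extra buffer be garbage-collected
```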
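3. Hardware Adaptation
3.1 Raspberry Pi-Specific Tuning
```python
import os


def setup_raspberry_pi_optimization(self):
    """Raspberry Pi-specific system tuning (requires sudo rights)"""
    # Set the CPU frequency governor to 'performance'
    os.system("echo performance | sudo tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor")
    # GPU memory split; takes effect only after a reboot and may have no effect on a Raspberry Pi 5
    os.system("sudo raspi-config nonint do_memory_split 128")
    # Disable unneeded services ('disable' applies from the next boot; '--now' also stops it immediately)
    os.system("sudo systemctl disable bluetooth")
    # CPU affinity: allow this process to run on all four cores
    os.sched_setaffinity(0, {0, 1, 2, 3})
```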
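3.2 Camera Parameter Tuning
```python
def optimize_camera_settings(self):
    """Fix camera settings for more consistent detection"""
    camera_config = self.picam.create_preview_configuration(
        main={"format": 'RGB888', "size": (640, 480)},
        controls={
            "ExposureTime": 10000,  # fixed exposure in microseconds (10 ms); tune for field lighting
            "AnalogueGain": 1.0,    # fixed gain
            "AeEnable": False,      # disable auto exposure
            "AwbEnable": False,     # disable auto white balance
            "Brightness": 0.0,      # brightness
            "Contrast": 1.0,        # contrast
            "Sharpness": 1.0,       # sharpness
        }
    )
    # Picamera2 accepts a new configuration only while the camera is stopped
    self.picam.stop()
    self.picam.configure(camera_config)
    self.picam.start()
```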
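4. Error Handling and Fault Tolerance
4.1 Automatic Restart
```python
import time


class SystemMonitor:
    """Counts errors and restarts detection after too many.

    Assumes the owner provides self.picam, self.model_path and self.load_model.
    """
    def __init__(self):
        self.error_count = 0
        self.max_errors = 5
        self.restart_threshold = 10  # not used in this excerpt

    def handle_error(self, error):
        """Count an error and restart once max_errors is reached"""
        self.error_count += 1
        print(f"Error detected: {error}, error count: {self.error_count}")
        if self.error_count >= self.max_errors:
            self.restart_detection_system()
            self.error_count = 0

    def restart_detection_system(self):
        """Restart the detection pipeline"""
        print("Restarting detection system...")
        self.picam.stop()
        time.sleep(2)
        self.picam.start()
        # Reload the model
        self.model = self.load_model(self.model_path)
```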
4.2 System Health Monitoring
```python
import psutil


def monitor_system_health(self):
    """System health checks; the reaction helpers are not shown in this article"""
    # CPU temperature (reported as 'cpu_thermal' on Raspberry Pi OS)
    temps = psutil.sensors_temperatures()
    if 'cpu_thermal' in temps:
        cpu_temp = temps['cpu_thermal'][0].current
        if cpu_temp > 70:  # too hot
            print(f"Warning: CPU temperature high: {cpu_temp}°C")
            self.reduce_processing_frequency()
    # Memory usage
    memory = psutil.virtual_memory()
    if memory.percent > 85:  # memory usage too high
        print(f"Warning: memory usage high: {memory.percent}%")
        self.cleanup_memory()
    # Disk space
    disk = psutil.disk_usage('/')
    if disk.percent > 90:  # disk almost full
        print(f"Warning: low disk space, {disk.percent}% used")
        self.cleanup_old_logs()
```
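If `psutil` is not installed, the same temperature can be read directly from sysfs. A small sketch, assuming the SoC sensor is exposed as `thermal_zone0` (as on Raspberry Pi OS), where the kernel reports millidegrees Celsius; the function name `read_cpu_temp_c` is illustrative:

```python
from pathlib import Path


def read_cpu_temp_c(path="/sys/class/thermal/thermal_zone0/temp"):
    """Return the SoC temperature in degrees Celsius, or None if unavailable.

    The file holds millidegrees, e.g. '52312' means 52.312 °C.
    """
    try:
        return int(Path(path).read_text().strip()) / 1000.0
    except (OSError, ValueError):
        return None
```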
5. Performance Testing and Benchmarks
5.1 Benchmark Code
```python
import time
import tracemalloc


def performance_benchmark(self):
    """System performance benchmark"""
    # FPS test: capture + inference over 100 frames
    start_time = time.perf_counter()
    frame_count = 0
    for _ in range(100):
        frame = self.capture_frame()
        missing_detected, results = self.detect_missing_seedlings(frame)
        frame_count += 1
    fps = frame_count / (time.perf_counter() - start_time)
    print(f"Average FPS: {fps:.2f}")
    # Mean time per frame (capture + inference)
    print(f"Average time per frame: {1000 / fps:.2f} ms")
    # Memory test. tracemalloc sees only allocations made through Python's
    # allocator (NumPy arrays included); PyTorch tensor storage is not counted
    tracemalloc.start()
    for _ in range(50):
        frame = self.capture_frame()
        self.detect_missing_seedlings(frame)
    current, peak = tracemalloc.get_traced_memory()
    print(f"Current traced memory: {current / 1024 / 1024:.2f} MB")
    print(f"Peak traced memory: {peak / 1024 / 1024:.2f} MB")
    tracemalloc.stop()
```
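The benchmark above folds capture and inference into one number. A variant that times each stage separately and reports percentiles shows which stage limits the frame rate. In this sketch `capture` and `infer` are any callables (for example `self.capture_frame` and `self.detect_missing_seedlings`), the name `benchmark` is illustrative, and the first few iterations are discarded as warm-up:

```python
import statistics
import time


def benchmark(capture, infer, n_frames=100, warmup=5):
    """Time capture and inference separately; per-stage stats are in milliseconds."""
    cap_ms, inf_ms = [], []
    for i in range(warmup + n_frames):
        t0 = time.perf_counter()
        frame = capture()
        t1 = time.perf_counter()
        infer(frame)
        t2 = time.perf_counter()
        if i >= warmup:  # skip first-call overhead
            cap_ms.append((t1 - t0) * 1000)
            inf_ms.append((t2 - t1) * 1000)

    def summarize(xs):
        qs = statistics.quantiles(xs, n=100)  # 99 cut points: qs[49] = p50, qs[94] = p95
        return {"mean": statistics.fmean(xs), "p50": qs[49], "p95": qs[94]}

    total = [c + f for c, f in zip(cap_ms, inf_ms)]
    return {"capture": summarize(cap_ms), "inference": summarize(inf_ms),
            "fps": 1000 / statistics.fmean(total)}
```

A p95 well above the mean usually points to intermittent stalls (thermal throttling, garbage collection) rather than a uniformly slow model.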
6. Intelligent Hardware Control
GPIO state machine:
```python
class AlarmStateMachine:
    states = ['IDLE', 'DETECTING', 'ALARMING', 'COOLDOWN']

    def transition(self, current_state, event):
        """State transition logic"""
        if current_state == 'IDLE' and event == 'missing_detected':
            return 'ALARMING'
        elif current_state == 'ALARMING' and event == 'no_missing':
            return 'COOLDOWN'
        # ... other transitions
```
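Adaptive alarm strategy:
Alarm frequency adjusts dynamically with detection confidence
Consecutive-detection confirmation reduces false alarms
Alarm-fatigue protection prevents excessive alarming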
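Consecutive-detection confirmation and alarm-fatigue protection can be combined with the IDLE/ALARMING/COOLDOWN states of the state machine above. The sketch below is one possible policy, not the project's code; the class name `AlarmDebouncer` and the parameters `on_frames`, `off_frames`, and `cooldown_s` (with their defaults) are assumptions to tune in the field:

```python
class AlarmDebouncer:
    """Decide when to raise and clear the alarm from per-frame detection results.

    on_frames:  consecutive positive frames required before alarming
    off_frames: consecutive negative frames required before clearing
    cooldown_s: seconds after clearing before a new alarm may start
    """

    def __init__(self, on_frames=3, off_frames=5, cooldown_s=10.0):
        self.on_frames, self.off_frames, self.cooldown_s = on_frames, off_frames, cooldown_s
        self.state = "IDLE"
        self.hits = self.misses = 0
        self.cleared_at = None

    def update(self, missing_detected, now):
        """Feed one frame result; `now` is a timestamp in seconds. Returns the new state."""
        if missing_detected:
            self.hits, self.misses = self.hits + 1, 0
        else:
            self.hits, self.misses = 0, self.misses + 1
        # Cooldown expires first, so a sustained detection can re-alarm in the same call
        if self.state == "COOLDOWN" and now - self.cleared_at >= self.cooldown_s:
            self.state = "IDLE"
        if self.state == "IDLE" and self.hits >= self.on_frames:
            self.state = "ALARMING"
        elif self.state == "ALARMING" and self.misses >= self.off_frames:
            self.state, self.cleared_at = "COOLDOWN", now
        return self.state
```

In the main loop, the monitor would call `start_alarm()` when `update()` returns `'ALARMING'` and `stop_alarm()` otherwise, passing `time.monotonic()` as `now`.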
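7. Production-Grade Robustness
Fault tolerance:
Watchdog timer: automatically detects a system deadlock and restarts
Exception recovery: falls back to a backup model if loading the model fails
Hardware fault detection: monitors the state of GPIO devices
System monitoring: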
```python
import numpy as np


class SystemHealthMonitor:
    def __init__(self):
        self.metrics = {
            'fps': [],
            'cpu_temp': [],
            'memory_usage': [],
            'detection_accuracy': []
        }

    def log_performance(self, fps, accuracy):
        """Record performance metrics and react to sustained low FPS"""
        self.metrics['fps'].append(fps)
        self.metrics['detection_accuracy'].append(accuracy)
        if len(self.metrics['fps']) > 100:
            # Moving average over the last 100 samples
            avg_fps = np.mean(self.metrics['fps'][-100:])
            if avg_fps < 10:  # FPS too low
                self.trigger_optimization()  # helper not shown in this article
```
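8. Core Algorithm Analysis
1. YOLOv5 Detection Pipeline
```python
import numpy as np
import torch


def detection_pipeline(self, image):
    """Complete detection pipeline"""
    # 1. Preprocessing
    processed_img = self.preprocess(image)
    # 2. Model inference
    with torch.no_grad():
        predictions = self.model(processed_img)
    # 3. Post-processing (e.g. NMS) into a list of detection dicts
    detections = self.postprocess(predictions)
    # 4. Business logic
    missing_count = sum(1 for det in detections if det['class'] == 0)
    # Guard against frames without detections: np.mean([]) returns nan with a warning
    confidence_avg = float(np.mean([det['confidence'] for det in detections])) if detections else 0.0
    return {
        'missing_detected': missing_count > 0,
        'missing_count': missing_count,
        'avg_confidence': confidence_avg,
        'detections': detections
    }
```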
2. Adaptive Threshold Algorithm
```python
class AdaptiveThreshold:
    def __init__(self, initial_threshold=0.5):
        self.threshold = initial_threshold
        self.history = []
        self.false_positive_rate = 0.0

    def update_threshold(self, ground_truth, predictions):
        """Adjust the threshold based on measured performance"""
        # Precision and recall at the current threshold (needs labeled ground truth)
        precision, recall = self.calculate_metrics(ground_truth, predictions)
        # Low precision means too many false alarms: raise the threshold
        if precision < 0.8:
            self.threshold = min(0.9, self.threshold + 0.05)
        # Low recall means too many missed seedlings: lower the threshold
        elif recall < 0.8:
            self.threshold = max(0.3, self.threshold - 0.05)
```
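`calculate_metrics` is not defined in the code above. A minimal version, assuming greedy IoU ≥ 0.5 matching between predictions and labeled boxes (the input formats below are assumptions), could look like this sketch. Because it needs ground truth, the adaptive threshold can only be tuned on labeled data such as the validation set, not on unlabeled field frames.

```python
def box_iou(a, b):
    """IoU of two (x1, y1, x2, y2) boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def calculate_metrics(ground_truth, predictions, iou_thres=0.5):
    """Return (precision, recall).

    ground_truth: list of (cls, box); predictions: list of (cls, conf, box),
    already filtered by the current confidence threshold. Each ground-truth
    box can be matched at most once, highest-confidence prediction first.
    """
    matched = set()
    tp = 0
    for cls, _conf, box in sorted(predictions, key=lambda p: -p[1]):
        best, best_iou = None, iou_thres
        for j, (gt_cls, gt_box) in enumerate(ground_truth):
            if j in matched or gt_cls != cls:
                continue
            v = box_iou(box, gt_box)
            if v >= best_iou:
                best, best_iou = j, v
        if best is not None:
            matched.add(best)
            tp += 1
    precision = tp / len(predictions) if predictions else 1.0
    recall = tp / len(ground_truth) if ground_truth else 1.0
    return precision, recall
```

Returning a precision of 1.0 when nothing is predicted is a deliberate convention here: it lets `update_threshold` take the low-recall branch and lower the threshold instead of raising it.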
data_collector.py
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import cv2
import time
import os
from datetime import datetime
from picamera2 import Picamera2


class DataCollector:
    def __init__(self, output_dir="dataset"):
        """
        Initialize the data collector
        :param output_dir: output directory
        """
        # Create the output directories
        self.output_dir = output_dir
        os.makedirs(os.path.join(output_dir, "missing"), exist_ok=True)
        os.makedirs(os.path.join(output_dir, "seedling"), exist_ok=True)
        # Initialize the camera
        self.picam = Picamera2()
        self.picam.configure(self.picam.create_preview_configuration(main={"format": 'RGB888', "size": (640, 480)}))
        self.picam.start()
        time.sleep(2)  # wait for the camera to initialize
        print(f"Data collector initialized; images will be saved to {output_dir}")

    def capture_frame(self):
        """Capture one frame"""
        frame = self.picam.capture_array()
        # Convert RGB to BGR for OpenCV display.
        # Caution: the Picamera2 manual documents 'RGB888' as stored in [B, G, R]
        # order; check saved images for swapped red/blue, and make sure
        # seedling_monitor.py treats live frames the same way
        return cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) if frame.shape[2] == 3 else frame

    def save_image(self, frame, label):
        """
        Save an image to the directory of its label
        :param frame: image
        :param label: label ('missing' or 'seedling')
        """
        # Timestamped file name
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        filename = f"{timestamp}.jpg"
        # Destination path
        save_path = os.path.join(self.output_dir, label, filename)
        # Write the image
        cv2.imwrite(save_path, frame)
        print(f"Saved {label} sample: {save_path}")
        return save_path

    def run(self):
        """Run the data collector"""
        print("Starting data collection...")
        print("Keys:")
        print("  'm' - save as a missing-seedling sample")
        print("  's' - save as a normal-seedling sample")
        print("  'q' - quit")
        try:
            while True:
                # Capture an image
                frame = self.capture_frame()
                # Show it
                cv2.imshow("Data Collector", frame)
                # Handle key presses
                key = cv2.waitKey(1) & 0xFF
                if key == ord('m'):
                    # Save as a missing-seedling sample
                    self.save_image(frame, "missing")
                elif key == ord('s'):
                    # Save as a normal-seedling sample
                    self.save_image(frame, "seedling")
                elif key == ord('q'):
                    # Quit
                    break
        except KeyboardInterrupt:
            print("Interrupted by user")
        finally:
            # Release resources
            self.cleanup()

    def cleanup(self):
        """Release resources"""
        # Stop the camera
        self.picam.stop()
        # Close all OpenCV windows
        cv2.destroyAllWindows()
        print("Data collector closed")


def main():
    # Create and run the data collector
    collector = DataCollector()
    collector.run()


if __name__ == "__main__":
    main()
```
data_labeler.py
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import cv2
import os
import glob
import numpy as np


class DataLabeler:
    def __init__(self, data_dir="dataset", output_dir="dataset/labels"):
        """
        Initialize the annotation tool
        :param data_dir: data directory
        :param output_dir: label output directory
        """
        self.data_dir = data_dir
        self.output_dir = output_dir
        # Create the output directories
        os.makedirs(os.path.join(output_dir, "train"), exist_ok=True)
        os.makedirs(os.path.join(output_dir, "val"), exist_ok=True)
        # State
        self.current_image = None
        self.current_img_path = None
        self.drawing = False
        self.start_point = None
        self.end_point = None
        self.class_id = 0  # 0: missing, 1: seedling
        self.train_val_split = "train"  # training set by default
        # Window name
        self.window_name = "Data Labeler"
        print("Annotation tool initialized")

    def load_images(self, class_name):
        """
        Load the images of one class
        :param class_name: class name
        :return: list of image paths
        """
        image_paths = []
        # All JPEG images in the class directory
        pattern = os.path.join(self.data_dir, class_name, "*.jpg")
        image_paths.extend(glob.glob(pattern))
        # PNG images as well
        pattern = os.path.join(self.data_dir, class_name, "*.png")
        image_paths.extend(glob.glob(pattern))
        return sorted(image_paths)

    def mouse_callback(self, event, x, y, flags, param):
        """Mouse callback that handles box drawing"""
        if event == cv2.EVENT_LBUTTONDOWN:
            # Start drawing a rectangle
            self.drawing = True
            self.start_point = (x, y)
        elif event == cv2.EVENT_MOUSEMOVE and self.drawing:
            # Show the rectangle while dragging
            img_copy = self.current_image.copy()
            cv2.rectangle(img_copy, self.start_point, (x, y), (0, 255, 0), 2)
            cv2.imshow(self.window_name, img_copy)
        elif event == cv2.EVENT_LBUTTONUP and self.drawing:
            # Finish the rectangle
            self.drawing = False
            self.end_point = (x, y)
            # Draw the final rectangle
            cv2.rectangle(self.current_image, self.start_point, self.end_point, (0, 255, 0), 2)
            cv2.imshow(self.window_name, self.current_image)
            # Save the annotation
            self.save_annotation()

    def save_annotation(self):
        """Save the current box"""
        if self.start_point is None or self.end_point is None:
            return
        # Make sure (x1, y1) is the top-left corner and (x2, y2) the bottom-right
        x1, y1 = min(self.start_point[0], self.end_point[0]), min(self.start_point[1], self.end_point[1])
        x2, y2 = max(self.start_point[0], self.end_point[0]), max(self.start_point[1], self.end_point[1])
        # Ignore accidental clicks that would produce (almost) zero-size boxes
        if x2 - x1 < 2 or y2 - y1 < 2:
            return
        # Image size
        h, w = self.current_image.shape[:2]
        # Normalized center point, width and height
        x_center = ((x1 + x2) / 2) / w
        y_center = ((y1 + y2) / 2) / h
        width = (x2 - x1) / w
        height = (y2 - y1) / h
        # Label file name
        base_name = os.path.splitext(os.path.basename(self.current_img_path))[0]
        label_path = os.path.join(self.output_dir, self.train_val_split, f"{base_name}.txt")
        # Append to the label file (YOLO format)
        with open(label_path, 'a') as f:
            f.write(f"{self.class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n")
        print(f"Annotation saved to {label_path}")

    def clear_annotations(self):
        """Delete the labels of the current image (in the current split)"""
        base_name = os.path.splitext(os.path.basename(self.current_img_path))[0]
        label_path = os.path.join(self.output_dir, self.train_val_split, f"{base_name}.txt")
        if os.path.exists(label_path):
            os.remove(label_path)
            print(f"Cleared annotations: {label_path}")

    def run(self):
        """Run the annotation tool"""
        # Load all images
        missing_images = self.load_images("missing")
        seedling_images = self.load_images("seedling")
        all_images = missing_images + seedling_images
        if not all_images:
            print("No images found; run the data collection tool first")
            return
        # Create the window
        cv2.namedWindow(self.window_name)
        cv2.setMouseCallback(self.window_name, self.mouse_callback)
        # Index of the current image
        current_idx = 0
        print("Starting annotation...")
        print("Keys:")
        print("  'a' - previous image")
        print("  'd' - next image")
        print("  '0' - set class to missing")
        print("  '1' - set class to seedling")
        print("  't' - save subsequent labels to the training set")
        print("  'v' - save subsequent labels to the validation set")
        print("  'c' - clear all annotations of the current image")
        print("  'q' - quit")
        try:
            while True:
                # Load the current image
                self.current_img_path = all_images[current_idx]
                img = cv2.imread(self.current_img_path)
                self.current_image = img.copy()
                # Overlay class, split, and progress information
                class_name = "Missing" if self.class_id == 0 else "Seedling"
                cv2.putText(self.current_image, f"Class: {class_name}", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
                cv2.putText(self.current_image, f"Set: {self.train_val_split}", (10, 60),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
                cv2.putText(self.current_image, f"Image: {current_idx+1}/{len(all_images)}", (10, 90),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
                # Show the image
                cv2.imshow(self.window_name, self.current_image)
                # Handle key presses (mouse events are processed while waiting)
                key = cv2.waitKey(0) & 0xFF
                if key == ord('a'):
                    # Previous image
                    current_idx = (current_idx - 1) % len(all_images)
                elif key == ord('d'):
                    # Next image
                    current_idx = (current_idx + 1) % len(all_images)
                elif key == ord('0'):
                    # Class: missing
                    self.class_id = 0
                    print("Class set to missing")
                elif key == ord('1'):
                    # Class: seedling
                    self.class_id = 1
                    print("Class set to seedling")
                elif key == ord('t'):
                    # Training set
                    self.train_val_split = "train"
                    print("Labels will be saved to the training set")
                elif key == ord('v'):
                    # Validation set
                    self.train_val_split = "val"
                    print("Labels will be saved to the validation set")
                elif key == ord('c'):
                    # Clear annotations
                    self.clear_annotations()
                elif key == ord('q'):
                    # Quit
                    break
        except KeyboardInterrupt:
            print("Interrupted by user")
        finally:
            # Close all OpenCV windows
            cv2.destroyAllWindows()
            print("Annotation tool closed")


def main():
    # Create and run the annotation tool
    labeler = DataLabeler()
    labeler.run()


if __name__ == "__main__":
    main()
```
train_model.py
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import subprocess
import argparse
import sys
import platform


def train_model(config_file, weights="yolov5s.pt", img_size=640, batch_size=16, epochs=100):
    """
    Train a YOLOv5 model
    :param config_file: path of the dataset configuration file
    :param weights: pretrained weights
    :param img_size: image size
    :param batch_size: batch size
    :param epochs: number of training epochs
    """
    # Are we inside a virtual environment?
    in_venv = sys.prefix != sys.base_prefix
    # Are we running on a Raspberry Pi?
    is_raspberry_pi = False
    try:
        if os.path.exists('/proc/device-tree/model'):
            with open('/proc/device-tree/model', 'r') as f:
                is_raspberry_pi = 'raspberry pi' in f.read().lower()
    except OSError:
        pass
    # On a Raspberry Pi outside a virtual environment, ask the user to activate one
    if is_raspberry_pi and not in_venv:
        if os.path.exists("seedling_venv"):
            print("Warning: on a Raspberry Pi, training should run inside the virtual environment.")
            print("Activate it first with:")
            print("source seedling_venv/bin/activate")
            print("and then rerun the training command.")
            choice = input("Try to train anyway? (y/n): ")
            if choice.lower() != 'y':
                print("Training cancelled.")
                return
        else:
            print("Warning: no virtual environment found. Run install_dependencies.py first to install the dependencies.")
            choice = input("Try to train anyway? (y/n): ")
            if choice.lower() != 'y':
                print("Training cancelled.")
                return
    # Make sure the YOLOv5 directory exists
    yolov5_dir = "yolov5"
    if not os.path.exists(yolov5_dir):
        print("Cloning the YOLOv5 repository...")
        try:
            subprocess.run(["git", "clone", "https://github.com/ultralytics/yolov5.git", yolov5_dir], check=True)
        except Exception as e:
            print(f"Failed to clone the YOLOv5 repository: {e}")
            print("Make sure git is installed, or download the YOLOv5 code manually.")
            return
    # Make sure the dependencies are installed
    print("Checking dependencies...")
    try:
        # First check that torch can be imported
        import_cmd = "import torch; print('PyTorch installed, version:', torch.__version__)"
        result = subprocess.run([sys.executable, "-c", import_cmd],
                                capture_output=True, text=True, check=False)
        if result.returncode != 0:
            print("PyTorch is not installed or failed to import.")
            print("Error message:", result.stderr)
            print("Run install_dependencies.py first to install the required dependencies.")
            return
        else:
            print(result.stdout.strip())
        # Install the YOLOv5 requirements
        requirements_path = os.path.join(yolov5_dir, "requirements.txt")
        print(f"Installing YOLOv5 requirements (from {requirements_path})...")
        # On a Raspberry Pi, skip packages that may be incompatible
        if is_raspberry_pi:
            # Read requirements.txt
            with open(requirements_path, 'r') as f:
                requirements = f.readlines()
            # Drop packages that are unnecessary or possibly incompatible
            filtered_reqs = []
            excluded_pkgs = ['torch', 'torchvision', 'opencv-python']
            for req in requirements:
                req = req.strip()
                if not req or req.startswith('#'):
                    continue
                exclude = False
                for pkg in excluded_pkgs:
                    if req.startswith(pkg):
                        exclude = True
                        break
                if not exclude:
                    filtered_reqs.append(req)
            # Write a temporary requirements file
            temp_req_path = "temp_requirements.txt"
            with open(temp_req_path, 'w') as f:
                f.write('\n'.join(filtered_reqs))
            # Install the filtered requirements
            pip_cmd = [sys.executable, "-m", "pip", "install", "-r", temp_req_path]
        else:
            # Not a Raspberry Pi: install everything
            pip_cmd = [sys.executable, "-m", "pip", "install", "-r", requirements_path]
        # Run pip
        try:
            subprocess.run(pip_cmd, check=True)
        except Exception as e:
            print(f"Error while installing dependencies: {e}")
            print("Some dependencies may be missing; training may be affected.")
    except Exception as e:
        print(f"Error while checking dependencies: {e}")
        print("Continuing with training, but it may fail.")
    # Build the training command
    train_cmd = [
        sys.executable, f"{yolov5_dir}/train.py",
        "--img", str(img_size),
        "--batch", str(batch_size),
        "--epochs", str(epochs),
        "--data", config_file,
        "--weights", weights,
        "--cache"  # bare --cache caches images in RAM; '--cache disk' needs less memory
    ]
    # Reduce batch size and image size on a Raspberry Pi
    if is_raspberry_pi:
        print("Note: training on a Raspberry Pi; lower batch size and image size help avoid running out of memory.")
        if batch_size > 8:
            batch_size = 8
            train_cmd[5] = str(batch_size)
            print(f"Batch size adjusted to {batch_size}")
        if img_size > 416:
            img_size = 416
            train_cmd[3] = str(img_size)
            print(f"Image size adjusted to {img_size}")
    # Print the training command
    print("\nRunning training command:")
    print(" ".join(train_cmd))
    # Run training
    try:
        result = subprocess.run(train_cmd, check=False)
        if result.returncode != 0:
            print("\nTraining failed!")
            print("Possible causes:")
            print("1. PyTorch or other dependencies are not installed correctly")
            print("2. Out of memory")
            print("3. The dataset configuration is wrong")
            print("\nCheck the error messages, fix the problem, and try again.")
        else:
            print("\nTraining finished!")
            # Explain how to use the trained model
            print("\nTo monitor with the trained model, update the model path in the UI.")
            print(f"The trained model is saved at: {yolov5_dir}/runs/train/exp*/weights/best.pt")
    except Exception as e:
        print(f"\nError while running the training command: {e}")


def parse_arguments():
    """Parse command-line arguments"""
    parser = argparse.ArgumentParser(description="Train a YOLOv5 model to detect missing seedlings in planting furrows")
    parser.add_argument("--config", type=str, default="dataset/seedling.yaml",
                        help="path of the dataset configuration file")
    parser.add_argument("--weights", type=str, default="yolov5s.pt",
                        help="pretrained weights (yolov5s.pt, yolov5m.pt, yolov5l.pt, etc.)")
    parser.add_argument("--img-size", type=int, default=640,
                        help="training image size")
    parser.add_argument("--batch-size", type=int, default=16,
                        help="training batch size")
    parser.add_argument("--epochs", type=int, default=100,
                        help="number of training epochs")
    return parser.parse_args()


if __name__ == "__main__":
    # Parse command-line arguments
    args = parse_arguments()
    # Train the model
    train_model(
        config_file=args.config,
        weights=args.weights,
        img_size=args.img_size,
        batch_size=args.batch_size,
        epochs=args.epochs
    )
```
seedling.yaml
```yaml
path: /home/pc/seedling_monitor/dataset
train: images/train
val: images/val
nc: 2
names: ["missing", "seedling"]
```
seedling_monitor.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import cv2
import torch
import numpy as np
import time
from gpiozero import OutputDevice
from threading import Thread
from picamera2 import Picamera2
import os
import sys
# 设置GPIO引脚
BUZZER_PIN = 17 # 蜂鸣器引脚
RED_PIN = 22 # RGB LED红色引脚
GREEN_PIN = 27 # RGB LED绿色引脚
BLUE_PIN = 23 # RGB LED蓝色引脚
# 全局变量,用于控制报警状态
alarm_active = False
alarm_thread = None
class SeedlingMonitor:
def __init__(self, model_path, conf_threshold=0.5):
"""
初始化漏苗监测系统
:param model_path: YOLOv5模型路径
:param conf_threshold: 检测置信度阈值
"""
# 保存置信度阈值
self.conf_threshold = conf_threshold
# 初始化摄像头
self.picam = Picamera2()
self.picam.configure(self.picam.create_preview_configuration(main={"format": 'RGB888', "size": (640, 480)}))
self.picam.start()
time.sleep(2) # 等待摄像头初始化
# 打印调试信息
print(f"尝试加载模型,路径: {model_path}")
print(f"当前工作目录: {os.getcwd()}")
print(f"文件是否存在: {os.path.exists(model_path)}")
# 检查模型文件是否存在
if not os.path.exists(model_path):
print(f"错误: 模型文件不存在: {model_path}")
print(f"尝试使用绝对路径: {os.path.abspath(model_path)}")
# 尝试在不同位置查找模型
possible_paths = [
model_path,
os.path.join('yolov5/runs/train/exp/weights', os.path.basename(model_path)),
os.path.join('./yolov5/runs/train/exp/weights', os.path.basename(model_path)),
os.path.join(os.getcwd(), model_path)
]
found = False
for path in possible_paths:
if os.path.exists(path):
print(f"找到模型文件: {path}")
model_path = path
found = True
break
if not found:
raise FileNotFoundError(f"无法找到模型文件: {model_path}")
try:
# 直接从本地加载模型
print("尝试导入YOLOv5模型...")
# 添加yolov5目录到系统路径
yolov5_path = os.path.join(os.getcwd(), 'yolov5')
if yolov5_path not in sys.path:
sys.path.insert(0, yolov5_path)
print(f"已添加路径: {yolov5_path}")
print(f"系统路径: {sys.path}")
# 尝试导入
try:
from models.experimental import attempt_load
print("成功导入YOLOv5模块")
except ImportError as e:
print(f"导入YOLOv5模块失败: {e}")
print("尝试备用导入方法...")
sys.path.insert(0, os.path.join(os.getcwd(), 'yolov5'))
from models.experimental import attempt_load
# 加载模型
print(f"加载模型文件: {model_path}")
self.model = attempt_load(model_path, device='cpu')
print("模型加载成功!")
except Exception as e:
print(f"加载模型时出错: {str(e)}")
raise
# 初始化GPIO设备
try:
self.buzzer = OutputDevice(BUZZER_PIN)
self.red_led = OutputDevice(RED_PIN)
self.green_led = OutputDevice(GREEN_PIN)
self.blue_led = OutputDevice(BLUE_PIN)
# 关闭所有设备
self.buzzer.off()
self.red_led.off()
self.green_led.off()
self.blue_led.off()
except Exception as e:
print(f"初始化GPIO设备时出错: {str(e)}")
# 继续运行,即使GPIO设备初始化失败
print("漏苗监测系统初始化完成")
def capture_frame(self):
"""捕获摄像头帧"""
frame = self.picam.capture_array()
return frame
def detect_missing_seedlings(self, frame):
"""
检测图像中是否存在漏苗
:param frame: 输入图像
:return: (是否有漏苗, 检测结果)
"""
# 将NumPy数组转换为PyTorch Tensor
frame_tensor = torch.from_numpy(frame).float()
# 调整通道顺序:(H, W, C) -> (C, H, W)
frame_tensor = frame_tensor.permute(2, 0, 1)
# 添加批次维度
frame_tensor = frame_tensor.unsqueeze(0)
# 归一化像素值到[0,1]
frame_tensor = frame_tensor / 255.0
# 使用YOLOv5模型进行检测
with torch.no_grad():
results = self.model(frame_tensor)
# 打印结果类型和结构,帮助调试
print(f"模型结果类型: {type(results)}")
if isinstance(results, tuple):
print(f"结果元组长度: {len(results)}")
# 处理YOLO v5新版本返回的结果格式
# 新版本模型可能返回元组,第一个元素是预测结果
if isinstance(results, tuple):
pred = results[0]
else:
pred = results
# 手动处理预测结果
detections = []
# 确保pred不为空且有检测结果
if pred is not None and len(pred) > 0:
# 获取第一张图像的预测结果
boxes = pred[0]
if len(boxes) > 0:
# 打印第一个框的形状以便调试
print(f"检测框形状: {boxes[0].shape}")
print(f"检测框内容示例: {boxes[0]}")
# 提取边界框、置信度和类别信息
for box in boxes:
# 确保是在CPU上且转为numpy
box_np = box.cpu().numpy()
# 不管box有多少元素,我们只需要前6个
# 前4个是坐标,第5个是置信度,第6个是类别
x1, y1, x2, y2 = box_np[0], box_np[1], box_np[2], box_np[3]
conf = box_np[4]
cls = box_np[5]
if conf >= self.conf_threshold: # 使用设置的置信度阈值
detections.append({
'xmin': float(x1),
'ymin': float(y1),
'xmax': float(x2),
'ymax': float(y2),
'confidence': float(conf),
'class': int(cls)
})
# 找出漏苗类型的检测结果(假设类别0是漏苗)
missing_seedlings = [d for d in detections if d['class'] == 0]
print(f"检测到 {len(missing_seedlings)} 个漏苗")
# 返回是否检测到漏苗以及结果
return len(missing_seedlings) > 0, results
def display_results(self, frame, results):
"""在图像上显示检测结果"""
# 创建图像副本以在其上绘制
annotated_frame = frame.copy()
# 处理YOLO v5新版本返回的结果格式
if isinstance(results, tuple):
pred = results[0]
else:
pred = results
# 确保pred不为空且有检测结果
if pred is not None and len(pred) > 0:
# 获取第一张图像的预测结果
boxes = pred[0]
if len(boxes) > 0:
# 提取边界框、置信度和类别信息并绘制
for box in boxes:
# 确保是在CPU上且转为numpy
box_np = box.cpu().numpy()
# 获取必要的信息
x1, y1, x2, y2 = box_np[0], box_np[1], box_np[2], box_np[3]
conf = box_np[4]
cls = box_np[5]
if conf >= self.conf_threshold: # 使用设置的置信度阈值
# 根据类别选择颜色(假设0是漏苗,用红色)
color = (0, 0, 255) if int(cls) == 0 else (0, 255, 0)
# 绘制边界框
cv2.rectangle(
annotated_frame,
(int(x1), int(y1)),
(int(x2), int(y2)),
color,
2
)
# 添加类别标签和置信度
label = f"Class {int(cls)}: {conf:.2f}"
cv2.putText(
annotated_frame,
label,
(int(x1), int(y1) - 10),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
color,
2
)
return annotated_frame
def start_alarm(self):
"""触发报警(蜂鸣器和红色LED闪烁)"""
global alarm_active, alarm_thread
# 如果报警已经激活,不需要再次启动
if alarm_active:
return
# 设置报警状态为激活
alarm_active = True
# 创建并启动报警线程
alarm_thread = Thread(target=self._alarm_loop)
alarm_thread.daemon = True
alarm_thread.start()
print("检测到漏苗!报警已触发")
def stop_alarm(self):
"""停止报警"""
global alarm_active
# 设置报警状态为非激活
alarm_active = False
# 确保所有设备都关闭
self.buzzer.off()
self.red_led.off()
# 打开绿色LED指示正常状态
self.green_led.on()
time.sleep(0.5)
self.green_led.off()
print("报警已停止")
    def _alarm_loop(self):
        """Alarm loop (runs in a separate thread)."""
        while alarm_active:
            # Buzzer and red LED on
            self.buzzer.on()
            self.red_led.on()
            time.sleep(0.3)
            # Buzzer and red LED off
            self.buzzer.off()
            self.red_led.off()
            time.sleep(0.3)
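    def run(self):
        """Main loop of the monitoring system."""
        try:
            print("Monitoring for missing seedlings...")
            while True:
                # Capture a frame
                frame = self.capture_frame()
                # Detect missing seedlings
                missing_detected, results = self.detect_missing_seedlings(frame)
                # Draw the detection results
                annotated_frame = self.display_results(frame, results)
                # Show the annotated frame in a window
                cv2.imshow("Seedling Monitor", annotated_frame)
                # Start or stop the alarm according to the detection result
                if missing_detected:
                    # Missing seedling found: trigger the alarm
                    self.start_alarm()
                else:
                    # No missing seedling: stop the alarm
                    self.stop_alarm()
                # Press 'q' to leave the loop
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        except KeyboardInterrupt:
            print("Interrupted by user")
        finally:
            # Release resources
            self.cleanup()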
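    def cleanup(self):
        """Release all resources."""
        # Stop the alarm thread first so it cannot toggle devices after they are closed
        self.stop_alarm()
        # Switch off and close all GPIO devices
        self.buzzer.off()
        self.red_led.off()
        self.green_led.off()
        self.blue_led.off()
        self.buzzer.close()
        self.red_led.close()
        self.green_led.close()
        self.blue_led.close()
        # Stop the camera
        self.picam.stop()
        # Close all OpenCV windows
        cv2.destroyAllWindows()
        print("System shut down, resources released")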
def main():
    # Path to the trained YOLOv5 model
    model_path = "best.pt"  # replace with the path to your trained model
    # Create and run the monitoring system
    monitor = SeedlingMonitor(model_path=model_path, conf_threshold=0.5)
    monitor.run()
if __name__ == "__main__":
    main()
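The per-row parsing in `detect_missing_seedlings` and the start/stop decisions made in `run` can be checked without a camera, GPIO hardware or PyTorch. The following is a minimal sketch, assuming (as the code above does) that each detection row is laid out as `[x1, y1, x2, y2, confidence, class]` and that class 0 means a missing seedling. `parse_detections`, `has_missing` and `AlarmLatch` are hypothetical helpers for illustration, not part of the project.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence


def parse_detections(rows: Sequence[Sequence[float]],
                     conf_threshold: float) -> List[Dict[str, float]]:
    """Convert rows laid out as [x1, y1, x2, y2, conf, cls, ...] into dicts,
    keeping only rows whose confidence reaches the threshold."""
    detections = []
    for row in rows:
        # Only the first six values are used, as in detect_missing_seedlings
        x1, y1, x2, y2, conf, cls = (float(v) for v in row[:6])
        if conf >= conf_threshold:
            detections.append({
                "xmin": x1, "ymin": y1, "xmax": x2, "ymax": y2,
                "confidence": conf, "class": int(cls),
            })
    return detections


def has_missing(detections: List[Dict[str, float]], missing_class: int = 0) -> bool:
    """True if any detection belongs to the (assumed) missing-seedling class."""
    return any(d["class"] == missing_class for d in detections)


@dataclass
class AlarmLatch:
    """Edge-triggered alarm state: report only the frames on which the alarm
    should be switched on ("start") or off ("stop")."""
    active: bool = False

    def update(self, missing_detected: bool) -> Optional[str]:
        if missing_detected and not self.active:
            self.active = True
            return "start"
        if not missing_detected and self.active:
            self.active = False
            return "stop"
        return None  # no state change on this frame


if __name__ == "__main__":
    rows = [
        [10, 20, 110, 220, 0.91, 0],   # confident class-0 box
        [50, 60, 150, 260, 0.30, 0],   # below the threshold, dropped
        [200, 40, 300, 240, 0.88, 1],  # a different class
    ]
    dets = parse_detections(rows, conf_threshold=0.5)
    print(len(dets), has_missing(dets))  # 2 True

    latch = AlarmLatch()
    print([latch.update(m) for m in [False, True, True, False, False]])
    # [None, 'start', None, 'stop', None]
```

Mapping `"start"` to `start_alarm()` and `"stop"` to `stop_alarm()` means the green-LED blink and the log messages happen once per state change rather than once per frame.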
seedling_ui.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import sys
import subprocess
import tkinter as tk
from tkinter import ttk, messagebox, filedialog
class SeedlingMonitorUI:
    def __init__(self, root):
        """
        Initialize the UI.
        :param root: Tkinter root window
        """
        self.root = root
        self.root.title("Smart Missing-Seedling Monitoring System")
        self.root.geometry("600x500")
        # Configure widget styles
        style = ttk.Style()
        style.configure("TButton", padding=10, font=('Helvetica', 12))
        style.configure("TLabel", font=('Helvetica', 12))
        style.configure("Header.TLabel", font=('Helvetica', 14, 'bold'))
        # Main frame
        main_frame = ttk.Frame(root, padding="20")
        main_frame.pack(fill=tk.BOTH, expand=True)
        # Title
        ttk.Label(main_frame, text="Smart Missing-Seedling Monitoring System", style="Header.TLabel").pack(pady=(0, 20))
        # Tabbed notebook
        tab_control = ttk.Notebook(main_frame)
        # Data tab (collection, labeling, dataset preparation)
        self.data_tab = ttk.Frame(tab_control)
        tab_control.add(self.data_tab, text="Data Processing")
        # Training tab
        self.train_tab = ttk.Frame(tab_control)
        tab_control.add(self.train_tab, text="Model Training")
        # Monitoring tab
        self.monitor_tab = ttk.Frame(tab_control)
        tab_control.add(self.monitor_tab, text="Seedling Monitoring")
        tab_control.pack(expand=1, fill="both")
        # Build the contents of each tab
        self.setup_data_tab()
        self.setup_train_tab()
        self.setup_monitor_tab()
    def setup_data_tab(self):
        """Build the data-processing tab."""
        frame = ttk.Frame(self.data_tab, padding="10")
        frame.pack(fill=tk.BOTH, expand=True)
        # Data collection button
        ttk.Button(frame, text="Collect New Data",
                   command=self.run_data_collector).pack(fill=tk.X, pady=10)
        # Data labeling button
        ttk.Button(frame, text="Label Data",
                   command=self.run_data_labeler).pack(fill=tk.X, pady=10)
        # Dataset preparation button
        ttk.Button(frame, text="Prepare Dataset",
                   command=self.run_prepare_dataset).pack(fill=tk.X, pady=10)
        # Validation split ratio
        val_frame = ttk.Frame(frame)
        val_frame.pack(fill=tk.X, pady=10)
        ttk.Label(val_frame, text="Validation split:").pack(side=tk.LEFT, padx=(0, 10))
        self.val_split = tk.StringVar(value="0.2")
        val_entry = ttk.Entry(val_frame, textvariable=self.val_split, width=5)
        val_entry.pack(side=tk.LEFT)
    def setup_train_tab(self):
        """Build the model-training tab."""
        frame = ttk.Frame(self.train_tab, padding="10")
        frame.pack(fill=tk.BOTH, expand=True)
        # Dataset config file
        config_frame = ttk.Frame(frame)
        config_frame.pack(fill=tk.X, pady=10)
        ttk.Label(config_frame, text="Config file:").pack(side=tk.LEFT, padx=(0, 10))
        self.config_path = tk.StringVar(value="dataset/seedling.yaml")
        config_entry = ttk.Entry(config_frame, textvariable=self.config_path, width=30)
        config_entry.pack(side=tk.LEFT, expand=True, fill=tk.X)
        ttk.Button(config_frame, text="Browse", command=self.browse_config).pack(side=tk.LEFT, padx=(10, 0))
        # Pretrained weights
        weights_frame = ttk.Frame(frame)
        weights_frame.pack(fill=tk.X, pady=10)
        ttk.Label(weights_frame, text="Pretrained weights:").pack(side=tk.LEFT, padx=(0, 10))
        self.weights = tk.StringVar(value="yolov5s.pt")
        weights_entry = ttk.Entry(weights_frame, textvariable=self.weights, width=30)
        weights_entry.pack(side=tk.LEFT, expand=True, fill=tk.X)
        # Image size
        img_frame = ttk.Frame(frame)
        img_frame.pack(fill=tk.X, pady=10)
        ttk.Label(img_frame, text="Image size:").pack(side=tk.LEFT, padx=(0, 10))
        self.img_size = tk.StringVar(value="640")
        img_entry = ttk.Entry(img_frame, textvariable=self.img_size, width=5)
        img_entry.pack(side=tk.LEFT)
        # Batch size
        batch_frame = ttk.Frame(frame)
        batch_frame.pack(fill=tk.X, pady=10)
        ttk.Label(batch_frame, text="Batch size:").pack(side=tk.LEFT, padx=(0, 10))
        self.batch_size = tk.StringVar(value="16")
        batch_entry = ttk.Entry(batch_frame, textvariable=self.batch_size, width=5)
        batch_entry.pack(side=tk.LEFT)
        # Number of training epochs
        epochs_frame = ttk.Frame(frame)
        epochs_frame.pack(fill=tk.X, pady=10)
        ttk.Label(epochs_frame, text="Epochs:").pack(side=tk.LEFT, padx=(0, 10))
        self.epochs = tk.StringVar(value="100")
        epochs_entry = ttk.Entry(epochs_frame, textvariable=self.epochs, width=5)
        epochs_entry.pack(side=tk.LEFT)
        # Start-training button
        ttk.Button(frame, text="Start Training",
                   command=self.run_train_model).pack(fill=tk.X, pady=20)
    def setup_monitor_tab(self):
        """Build the missing-seedling monitoring tab."""
        frame = ttk.Frame(self.monitor_tab, padding="10")
        frame.pack(fill=tk.BOTH, expand=True)
        # Model path
        model_frame = ttk.Frame(frame)
        model_frame.pack(fill=tk.X, pady=10)
        ttk.Label(model_frame, text="Model path:").pack(side=tk.LEFT, padx=(0, 10))
        # Fallback path, used if no newer training run can be found
        default_model_path = "yolov5/runs/train/exp/weights/best.pt"
        # Try to locate the most recent training run
        try:
            yolov5_dir = "yolov5"
            if os.path.exists(yolov5_dir):
                train_dir = os.path.join(yolov5_dir, "runs", "train")
                if os.path.exists(train_dir):
                    # Collect all experiment directories (exp, exp2, ...)
                    exp_dirs = [d for d in os.listdir(train_dir) if os.path.isdir(os.path.join(train_dir, d))]
                    if exp_dirs:
                        # Sort by modification time and take the newest directory
                        newest_dir = sorted(exp_dirs, key=lambda x: os.path.getmtime(os.path.join(train_dir, x)), reverse=True)[0]
                        weights_path = os.path.join(train_dir, newest_dir, "weights", "best.pt")
                        if os.path.exists(weights_path):
                            default_model_path = weights_path
        except Exception:
            # On any error, keep the fallback path
            pass
        self.model_path = tk.StringVar(value=default_model_path)
        model_entry = ttk.Entry(model_frame, textvariable=self.model_path, width=30)
        model_entry.pack(side=tk.LEFT, expand=True, fill=tk.X)
        ttk.Button(model_frame, text="Browse", command=self.browse_model).pack(side=tk.LEFT, padx=(10, 0))
        # Refresh button: search again for the newest trained model
        ttk.Button(model_frame, text="Refresh", command=self.refresh_model_path).pack(side=tk.LEFT, padx=(5, 0))
        # Confidence threshold
        conf_frame = ttk.Frame(frame)
        conf_frame.pack(fill=tk.X, pady=10)
        ttk.Label(conf_frame, text="Confidence threshold:").pack(side=tk.LEFT, padx=(0, 10))
        self.conf_threshold = tk.StringVar(value="0.5")
        conf_entry = ttk.Entry(conf_frame, textvariable=self.conf_threshold, width=5)
        conf_entry.pack(side=tk.LEFT)
        # Start-monitoring button
        ttk.Button(frame, text="Start Monitoring",
                   command=self.run_monitor).pack(fill=tk.X, pady=20)
    def browse_config(self):
        """Browse for a dataset config file."""
        file_path = filedialog.askopenfilename(
            title="Select config file",
            filetypes=[("YAML files", "*.yaml"), ("All files", "*.*")]
        )
        if file_path:
            self.config_path.set(file_path)
    def browse_model(self):
        """Browse for a model file."""
        file_path = filedialog.askopenfilename(
            title="Select model file",
            filetypes=[("PyTorch models", "*.pt"), ("All files", "*.*")]
        )
        if file_path:
            self.model_path.set(file_path)
    def run_data_collector(self):
        """Launch the data collector."""
        try:
            subprocess.Popen([sys.executable, "data_collector.py"])
        except Exception as e:
            messagebox.showerror("Error", f"Error while starting the data collector: {e}")
    def run_data_labeler(self):
        """Launch the data labeling tool."""
        try:
            subprocess.Popen([sys.executable, "data_labeler.py"])
        except Exception as e:
            messagebox.showerror("Error", f"Error while starting the labeling tool: {e}")
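    def run_prepare_dataset(self):
        """Run the dataset preparation script."""
        try:
            # Parsing here at least rejects a non-numeric validation split
            val_split = float(self.val_split.get())
            # Build the command.
            # NOTE: val_split is not forwarded to prepare_dataset.py, so the value
            # entered in the UI currently has no effect; pass it through whatever
            # command-line option that script actually accepts.
            cmd = [sys.executable, "prepare_dataset.py"]
            # Start the process
            process = subprocess.Popen(cmd)
            # Tell the user the job is running
            messagebox.showinfo("Info", "Preparing the dataset, please wait...")
            # Wait for the process to finish (this blocks the Tk event loop,
            # so the window does not respond until the script exits)
            process.wait()
            # Report the result
            if process.returncode == 0:
                messagebox.showinfo("Success", "Dataset preparation finished!")
            else:
                messagebox.showerror("Error", "Dataset preparation failed; check the console output.")
        except Exception as e:
            messagebox.showerror("Error", f"Error while preparing the dataset: {e}")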
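    def run_train_model(self):
        """Launch model training."""
        try:
            # Read the parameters
            config_file = self.config_path.get()
            weights = self.weights.get()
            img_size = int(self.img_size.get())
            batch_size = int(self.batch_size.get())
            epochs = int(self.epochs.get())
            # Build the command
            cmd = [
                sys.executable, "train_model.py",
                "--config", config_file,
                "--weights", weights,
                "--img-size", str(img_size),
                "--batch-size", str(batch_size),
                "--epochs", str(epochs)
            ]
            # Inform the user
            messagebox.showinfo(
                "Info",
                "Model training will start in a separate process and may take a while to finish.\n"
                "Do not close the terminal while training is running.\n\n"
                "Note: make sure PyTorch and its dependencies are installed before training:\n"
                "pip install torch torchvision opencv-python"
            )
            # Start training without waiting for it to finish
            subprocess.Popen(cmd)
            # Remind the user to refresh the model path once training is done
            messagebox.showinfo(
                "Notice",
                "Training has been started!\n"
                "When it finishes, use the refresh button next to the model path on the monitoring tab to load the new model."
            )
        except Exception as e:
            messagebox.showerror("Error", f"Error while starting training: {e}")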
    def run_monitor(self):
        """Launch the missing-seedling monitor."""
        try:
            # Read the parameters
            model_path = self.model_path.get()
            conf_threshold = float(self.conf_threshold.get())
            # Warn (but continue) if the model file does not exist
            if not os.path.exists(model_path):
                messagebox.showwarning("Warning", f"The model path may be incorrect: {model_path}\nLoading will still be attempted.")
            # Write a temporary launcher script for the monitor.
            # {!r} inserts Python literals, so paths containing quotes or
            # backslashes are escaped correctly in the generated code.
            with open("run_monitor_temp.py", "w", encoding="utf-8") as f:
                f.write("""#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import sys
import os
import traceback
from seedling_monitor import SeedlingMonitor
def main():
    # Catch and report start-up errors
    try:
        model_path = {!r}
        conf_threshold = {!r}
        print("=" * 50)
        print("Starting the missing-seedling monitoring system")
        print(f"Model path: {{model_path}}")
        print(f"Confidence threshold: {{conf_threshold}}")
        print(f"Working directory: {{os.getcwd()}}")
        print(f"Model file exists: {{os.path.exists(model_path)}}")
        print("=" * 50)
        # Create and run the monitoring system
        monitor = SeedlingMonitor(model_path=model_path, conf_threshold=conf_threshold)
        monitor.run()
    except FileNotFoundError as e:
        print(f"Error: file not found: {{e}}")
        print("Check the model path, or try an absolute path")
    except ImportError as e:
        print(f"Error: missing Python dependency: {{e}}")
        print("Install all dependencies: pip install torch torchvision opencv-python gpiozero")
    except Exception as e:
        print(f"Error: {{e}}")
        print("Traceback:")
        traceback.print_exc()
if __name__ == "__main__":
    main()
""".format(model_path, conf_threshold))
            # Inform the user
            messagebox.showinfo(
                "Info",
                "Starting the missing-seedling monitoring system, please wait...\n"
                "If a missing-dependency error appears, install the requirements first:\n"
                "pip install torch torchvision opencv-python gpiozero\n\n"
                "See the terminal output for detailed error messages"
            )
            # Start the launcher in its own process
            subprocess.Popen([sys.executable, "run_monitor_temp.py"])
        except Exception as e:
            messagebox.showerror("Error", f"Error while starting the monitoring system: {e}")
    def refresh_model_path(self):
        """Refresh the model path by searching for the newest trained model."""
        try:
            yolov5_dir = "yolov5"
            if os.path.exists(yolov5_dir):
                train_dir = os.path.join(yolov5_dir, "runs", "train")
                if os.path.exists(train_dir):
                    # Collect all experiment directories (exp, exp2, ...)
                    exp_dirs = [d for d in os.listdir(train_dir) if os.path.isdir(os.path.join(train_dir, d))]
                    if exp_dirs:
                        # Sort by modification time and take the newest directory
                        newest_dir = sorted(exp_dirs, key=lambda x: os.path.getmtime(os.path.join(train_dir, x)), reverse=True)[0]
                        weights_path = os.path.join(train_dir, newest_dir, "weights", "best.pt")
                        if os.path.exists(weights_path):
                            self.model_path.set(weights_path)
                            messagebox.showinfo("Success", f"Found the latest model: {weights_path}")
                            return
            messagebox.showinfo("Notice", "No trained model found; please select a model file manually")
        except Exception as e:
            messagebox.showerror("Error", f"Error while searching for the model: {e}")
def main():
    # Create the root window
    root = tk.Tk()
    app = SeedlingMonitorUI(root)  # keep a reference to the UI object
    # Create the menu bar
    menu_bar = tk.Menu(root)
    root.config(menu=menu_bar)
    # "Tools" menu
    tools_menu = tk.Menu(menu_bar, tearoff=0)
    menu_bar.add_cascade(label="Tools", menu=tools_menu)
    # Menu entry that launches the dependency installer
    tools_menu.add_command(label="Install Dependencies", command=install_dependencies)
    root.mainloop()
def install_dependencies():
    """Launch the dependency installer script."""
    try:
        # Check that install_dependencies.py exists
        if not os.path.exists("install_dependencies.py"):
            messagebox.showerror("Error", "Installer script install_dependencies.py not found")
            return
        # Inform the user
        messagebox.showinfo(
            "Info",
            "The dependency installer will start; follow its prompts.\n"
            "Do not close the terminal during installation."
        )
        # Start the installer
        subprocess.Popen([sys.executable, "install_dependencies.py"])
    except Exception as e:
        messagebox.showerror("Error", f"Error while starting the installer: {e}")
if __name__ == "__main__":
    main()
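Two parts of `seedling_ui.py` could be factored into small helpers. `setup_monitor_tab` and `refresh_model_path` repeat the same search for the newest `runs/train/<exp>/weights/best.pt`, and `run_prepare_dataset` calls `process.wait()` on the Tk thread, which freezes the window until the script exits. The sketch below is hypothetical: `find_latest_weights` and `watch_process` are illustrative names, not part of the project. It assumes the YOLOv5 directory layout used above and that `schedule` behaves like Tk's `root.after(ms, callback)`. Unlike the original lookup, `find_latest_weights` falls back to older runs when the newest one has no `best.pt` yet (for example, while training is still in progress).

```python
import os
import subprocess
from typing import Callable, Optional


def find_latest_weights(train_dir: str = os.path.join("yolov5", "runs", "train")) -> Optional[str]:
    """Return weights/best.pt from the most recently modified run directory
    that has one, or None if no such file exists."""
    if not os.path.isdir(train_dir):
        return None
    run_dirs = [os.path.join(train_dir, d) for d in os.listdir(train_dir)
                if os.path.isdir(os.path.join(train_dir, d))]
    # Newest first, using directory mtime as the UI code does
    for run_dir in sorted(run_dirs, key=os.path.getmtime, reverse=True):
        weights = os.path.join(run_dir, "weights", "best.pt")
        if os.path.exists(weights):
            return weights
    return None  # no run with a best.pt found


def watch_process(process: subprocess.Popen,
                  schedule: Callable[[int, Callable[[], None]], object],
                  on_done: Callable[[int], None],
                  interval_ms: int = 500) -> None:
    """Poll a child process without blocking the caller and call
    on_done(returncode) once it has exited. `schedule(ms, fn)` is expected
    to behave like Tk's root.after(ms, fn)."""
    def check() -> None:
        code = process.poll()  # None while the child is still running
        if code is None:
            schedule(interval_ms, check)  # check again later
        else:
            on_done(code)
    check()
```

In `run_prepare_dataset`, the `process.wait()` call and the two result dialogs could then become `watch_process(process, self.root.after, on_done)`, where `on_done(code)` shows the success dialog when `code == 0` and the error dialog otherwise. `setup_monitor_tab` and `refresh_model_path` could both call `find_latest_weights()` instead of repeating the directory scan.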