# File: inula/packages/inula-code-generator/YOLOv10_cs/test_model.py
# (73 lines, 2.2 KiB, Python)

'''
Description: Test Model
Author: renlirong
Date: 2024-07-25 11:32:36
LastEditTime: 2024-09-13 11:25:45
LastEditors: renlirong
'''
import os
import cv2
import torch
import matplotlib.pyplot as plt
import json
from pathlib import Path
from ultralytics import YOLOv10
def _normalized_position(x1, y1, x2, y2, width, height):
    """Return the box (x1, y1)-(x2, y2) as fractions of the image size.

    Keys are "x", "y" (top-left corner) and "width", "height" (box extent),
    each divided by the corresponding image dimension.
    """
    return {
        "x": x1 / width,
        "y": y1 / height,
        "width": (x2 - x1) / width,
        "height": (y2 - y1) / height,
    }


def _draw_detection(img, x1, y1, x2, y2, label):
    """Draw a green rectangle and its text label onto ``img`` in place."""
    cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
    # Put the label just above the box, but never above the top edge of the
    # image: for boxes touching the top, y1 - 10 would be off-canvas and the
    # text would be clipped.
    text_y = max(y1 - 10, 15)
    cv2.putText(img, label, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)


model_path = 'runs/detect/train/weights/best.pt'
test_image_folder = 'datasets/page_seg/images/train_antd'
output_folder = 'runs/test/test_output_v10_train'
os.makedirs(output_folder, exist_ok=True)
model = YOLOv10(model_path)
confidence_threshold = 0.2
# Match extensions case-insensitively and including the dot, so "PAGE.PNG"
# is picked up while a file merely ending in "png" (e.g. "notapng") is not.
image_extensions = ('.png', '.jpg', '.jpeg')
# Sorted so that runs process (and report) images in a reproducible order.
image_paths = sorted(
    os.path.join(test_image_folder, name)
    for name in os.listdir(test_image_folder)
    if name.lower().endswith(image_extensions)
)

for image_path in image_paths:
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising;
        # skip unreadable/corrupt files instead of aborting the whole run.
        print('Skipping unreadable image:', image_path)
        continue

    # Ultralytics interprets numpy-array sources as BGR (OpenCV's native
    # channel order), so the cv2.imread result is passed as-is. Converting to
    # RGB first would feed the model swapped colour channels.
    results = model(img, conf=confidence_threshold)
    # Each row is used below as (x1, y1, x2, y2, confidence, class_index) in
    # pixel coordinates.
    detections = results[0].boxes.data
    height, width = img.shape[:2]

    detection_results = []
    for component_id, detection in enumerate(detections, start=1):
        # .tolist() yields plain Python floats (and moves GPU tensors to CPU).
        x1, y1, x2, y2, conf, cls = detection[:6].tolist()
        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
        class_name = model.names[int(cls)]
        detection_results.append({
            "id": f"component_{component_id}",
            "type": class_name,
            "position": _normalized_position(x1, y1, x2, y2, width, height),
        })
        # Drawing happens after inference, so annotations never reach the model.
        _draw_detection(img, x1, y1, x2, y2, f'{class_name} {conf:.2f}')

    # Save the annotated image and a same-named .json with the detections.
    image_name = os.path.basename(image_path)
    output_path = os.path.join(output_folder, image_name)
    cv2.imwrite(output_path, img)
    json_file_name = os.path.splitext(image_name)[0] + '.json'
    json_output_path = os.path.join(output_folder, json_file_name)
    with open(json_output_path, 'w', encoding='utf-8') as f:
        json.dump(detection_results, f, indent=4)

    # Optional: display the annotated image.
    # plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    # plt.show()

print('检测完成,结果已保存至', output_folder)