73 lines
2.2 KiB
Python
73 lines
2.2 KiB
Python
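'''
Description: Run a trained YOLOv10 model over a folder of page images. For each
             image, save a copy with the detected components drawn on it and a
             JSON file listing each detection's id, class name and position
             (normalized to the image size).
Author: renlirong
Date: 2024-07-25 11:32:36
LastEditTime: 2024-09-13 11:25:45
LastEditors: renlirong
'''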
|
|
import os
|
|
import cv2
|
|
import torch
|
|
import matplotlib.pyplot as plt
|
|
import json
|
|
from pathlib import Path
|
|
from ultralytics import YOLOv10
|
|
|
|
model_path = 'runs/detect/train/weights/best.pt'
test_image_folder = 'datasets/page_seg/images/train_antd'
output_folder = 'runs/test/test_output_v10_train'
confidence_threshold = 0.2

# Lower-case and dot-prefixed so "scan.JPG" matches but a name like "notapng" does not.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def is_image_file(file_name, extensions=IMAGE_EXTENSIONS):
    """Return True if *file_name* ends with one of *extensions* (case-insensitive)."""
    return file_name.lower().endswith(extensions)


def list_images(folder, extensions=IMAGE_EXTENSIONS):
    """Return the paths of image files directly inside *folder*, sorted by name.

    Sorting makes the processing order reproducible; os.listdir order is arbitrary.
    """
    return [
        os.path.join(folder, name)
        for name in sorted(os.listdir(folder))
        if is_image_file(name, extensions)
    ]


def normalized_position(x1, y1, x2, y2, width, height):
    """Return the box (x1, y1)-(x2, y2) as fractions of an image of size width x height."""
    return {
        "x": x1 / width,
        "y": y1 / height,
        "width": (x2 - x1) / width,
        "height": (y2 - y1) / height,
    }


def annotate_image(img, detections, names):
    """Draw detections on *img* in place and return their JSON-ready records.

    Args:
        img: image array as returned by cv2.imread (drawn on in place).
        detections: rows whose first six values are x1, y1, x2, y2, conf, cls
            (e.g. results[0].boxes.data).
        names: mapping from class index to class name (e.g. model.names).

    Returns:
        A list of {"id", "type", "position"} dicts, one per detection, with
        ids numbered from component_1.
    """
    height, width = img.shape[:2]
    records = []
    for component_id, detection in enumerate(detections, start=1):
        # tolist() yields plain Python numbers (and moves GPU tensors to the host).
        x1, y1, x2, y2, conf, cls = detection[:6].tolist()
        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
        class_name = names[int(cls)]

        records.append({
            "id": f"component_{component_id}",
            "type": class_name,
            "position": normalized_position(x1, y1, x2, y2, width, height),
        })

        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Clamp the label baseline so boxes touching the top edge keep a visible label.
        label_origin = (x1, max(y1 - 10, 10))
        cv2.putText(img, f'{class_name} {conf:.2f}', label_origin,
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
    return records


def process_image(model, image_path, out_dir, conf_threshold):
    """Detect components in one image and save the annotated image plus a JSON file.

    Both outputs go to *out_dir*; the JSON shares the image's base name.
    Returns False (after printing a warning) if the image cannot be read,
    otherwise True.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals unreadable/corrupt files by returning None.
        print(f'Skipping unreadable image: {image_path}')
        return False

    # Ultralytics expects numpy inputs in BGR order (as cv2.imread returns),
    # so the image is passed as-is; converting to RGB would swap channels.
    results = model(img, conf=conf_threshold)
    detection_results = annotate_image(img, results[0].boxes.data, model.names)

    base_name = os.path.basename(image_path)
    cv2.imwrite(os.path.join(out_dir, base_name), img)

    json_output_path = os.path.join(out_dir, os.path.splitext(base_name)[0] + '.json')
    with open(json_output_path, 'w', encoding='utf-8') as f:
        json.dump(detection_results, f, indent=4)
    return True


def main():
    """Run detection on every image in test_image_folder and write results to output_folder."""
    os.makedirs(output_folder, exist_ok=True)
    model = YOLOv10(model_path)

    for image_path in list_images(test_image_folder):
        process_image(model, image_path, output_folder, confidence_threshold)

    # Optional: to inspect a result interactively, show an annotated image with
    # plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)); plt.show()
    print('检测完成,结果已保存至', output_folder)


if __name__ == "__main__":
    main()
|