拖拽支持,性能优化,动态开启跟踪,识别框优化,动态识别框
This commit is contained in:
parent
c2dd067a56
commit
8a0924f649
|
@ -7,14 +7,15 @@ import com.ly.onnx.engine.InferenceEngine;
|
|||
import com.ly.onnx.model.ModelInfo;
|
||||
import com.ly.play.opencv.VideoPlayer;
|
||||
|
||||
|
||||
import javax.swing.*;
|
||||
import javax.swing.filechooser.FileNameExtensionFilter;
|
||||
import javax.swing.filechooser.FileSystemView;
|
||||
import java.awt.*;
|
||||
import java.awt.datatransfer.DataFlavor;
|
||||
import java.io.File;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public class VideoInferenceApp extends JFrame {
|
||||
|
||||
|
@ -26,10 +27,9 @@ public class VideoInferenceApp extends JFrame {
|
|||
private ModelManager modelManager;
|
||||
|
||||
|
||||
|
||||
public VideoInferenceApp() {
|
||||
// 设置窗口标题
|
||||
super("https://gitee.com/sulv0302/onnx-inference4j-play.git");
|
||||
super("ONNX Inference Application");
|
||||
// 初始化UI组件
|
||||
initializeUI();
|
||||
}
|
||||
|
@ -49,13 +49,53 @@ public class VideoInferenceApp extends JFrame {
|
|||
videoPanel = new VideoPanel();
|
||||
videoPanel.setBackground(Color.BLACK);
|
||||
|
||||
// 模型列表区域
|
||||
// 设置拖拽功能
|
||||
videoPanel.setTransferHandler(new TransferHandler() {
|
||||
@Override
|
||||
public boolean canImport(TransferSupport support) {
|
||||
return support.isDataFlavorSupported(DataFlavor.javaFileListFlavor);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean importData(TransferSupport support) {
|
||||
if (!canImport(support)) {
|
||||
return false;
|
||||
}
|
||||
try {
|
||||
// 获取拖拽的文件列表
|
||||
List<File> files = (List<File>) support.getTransferable().getTransferData(DataFlavor.javaFileListFlavor);
|
||||
for (File file : files) {
|
||||
String fileName = file.getName().toLowerCase();
|
||||
if (fileName.endsWith(".jpg") || fileName.endsWith(".jpeg") ||
|
||||
fileName.endsWith(".png") || fileName.endsWith(".bmp") ||
|
||||
fileName.endsWith(".gif")) {
|
||||
// 加载并处理拖拽的图片文件
|
||||
videoPlayer.loadImage(file.getAbsolutePath());
|
||||
} else if (fileName.endsWith(".mp4") || fileName.endsWith(".avi") ||
|
||||
fileName.endsWith(".mkv") || fileName.endsWith(".mov") ||
|
||||
fileName.endsWith(".flv") || fileName.endsWith(".wmv")) {
|
||||
// 加载并播放拖拽的视频文件
|
||||
videoPlayer.loadVideo(file.getAbsolutePath());
|
||||
}
|
||||
}
|
||||
} catch (Exception ex) {
|
||||
ex.printStackTrace();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
});
|
||||
|
||||
// 初始化 ModelManager(不传递 videoPlayer)
|
||||
modelManager = new ModelManager();
|
||||
modelManager.setPreferredSize(new Dimension(250, 0)); // 设置模型列表区域的宽度
|
||||
|
||||
// 初始化 VideoPlayer
|
||||
// 初始化 VideoPlayer 并传递 modelManager
|
||||
videoPlayer = new VideoPlayer(videoPanel, modelManager);
|
||||
|
||||
// 将 videoPlayer 设置到 modelManager 中
|
||||
modelManager.setVideoPlayer(videoPlayer);
|
||||
|
||||
// 使用 JSplitPane 分割视频区域和模型列表区域
|
||||
JSplitPane splitPane = new JSplitPane(JSplitPane.HORIZONTAL_SPLIT, videoPanel, modelManager);
|
||||
splitPane.setResizeWeight(0.8); // 视频区域初始占据80%的空间
|
||||
|
@ -97,6 +137,10 @@ public class VideoInferenceApp extends JFrame {
|
|||
JButton loadVideoButton = new JButton("选择视频文件");
|
||||
loadVideoButton.setPreferredSize(new Dimension(150, 30));
|
||||
|
||||
// 图片文件选择按钮
|
||||
JButton loadImageButton = new JButton("选择图片文件");
|
||||
loadImageButton.setPreferredSize(new Dimension(150, 30));
|
||||
|
||||
// 模型文件选择按钮
|
||||
JButton loadModelButton = new JButton("选择模型");
|
||||
loadModelButton.setPreferredSize(new Dimension(150, 30));
|
||||
|
@ -108,12 +152,19 @@ public class VideoInferenceApp extends JFrame {
|
|||
JButton startPlayButton = new JButton("开始播放");
|
||||
startPlayButton.setPreferredSize(new Dimension(100, 30));
|
||||
|
||||
// 添加目标跟踪复选框
|
||||
JCheckBox trackingCheckBox = new JCheckBox("启用目标跟踪");
|
||||
trackingCheckBox.setSelected(false); // 默认不启用目标跟踪
|
||||
|
||||
// 将按钮和输入框添加到顶部面板
|
||||
topPanel.add(loadVideoButton);
|
||||
topPanel.add(loadImageButton); // 添加图片按钮
|
||||
topPanel.add(loadModelButton);
|
||||
topPanel.add(new JLabel("流地址:"));
|
||||
topPanel.add(streamUrlField);
|
||||
topPanel.add(startPlayButton);
|
||||
// 将复选框添加到顶部面板
|
||||
topPanel.add(trackingCheckBox);
|
||||
|
||||
this.add(topPanel, BorderLayout.NORTH);
|
||||
|
||||
|
@ -126,6 +177,9 @@ public class VideoInferenceApp extends JFrame {
|
|||
// 添加视频加载按钮的行为
|
||||
loadVideoButton.addActionListener(e -> selectVideoFile());
|
||||
|
||||
// 添加图片加载按钮的行为
|
||||
loadImageButton.addActionListener(e -> selectImageFile());
|
||||
|
||||
loadModelButton.addActionListener(e -> {
|
||||
modelManager.loadModel(this);
|
||||
DefaultListModel<ModelInfo> modelList = modelManager.getModelList();
|
||||
|
@ -141,16 +195,31 @@ public class VideoInferenceApp extends JFrame {
|
|||
}
|
||||
});
|
||||
|
||||
// 为复选框添加监听器,动态启用或禁用目标跟踪
|
||||
trackingCheckBox.addActionListener(e -> {
|
||||
boolean isSelected = trackingCheckBox.isSelected(); // 获取当前复选框状态
|
||||
videoPlayer.setTrackingEnabled(isSelected); // 设置是否启用目标跟踪
|
||||
});
|
||||
|
||||
// 播放按钮
|
||||
playButton.addActionListener(e -> videoPlayer.playVideo());
|
||||
playButton.addActionListener(e -> {
|
||||
videoPlayer.playVideo();
|
||||
});
|
||||
|
||||
// 暂停按钮
|
||||
pauseButton.addActionListener(e -> videoPlayer.pauseVideo());
|
||||
|
||||
// // 重播按钮
|
||||
// replayButton.addActionListener(e -> videoPlayer.replayVideo());
|
||||
//
|
||||
// 重播按钮
|
||||
replayButton.addActionListener(e -> {
|
||||
try {
|
||||
// videoPlayer.loadVideo(videoPlayer.getCurrentVideoPath());
|
||||
videoPlayer.playVideo();
|
||||
} catch (Exception ex) {
|
||||
ex.printStackTrace();
|
||||
JOptionPane.showMessageDialog(this, "重播视频失败: " + ex.getMessage(), "错误", JOptionPane.ERROR_MESSAGE);
|
||||
}
|
||||
});
|
||||
|
||||
// // 后退5秒
|
||||
// rewind5sButton.addActionListener(e -> videoPlayer.rewind(5000));
|
||||
//
|
||||
|
@ -195,6 +264,28 @@ public class VideoInferenceApp extends JFrame {
|
|||
}
|
||||
}
|
||||
|
||||
// 选择图片文件
|
||||
private void selectImageFile() {
|
||||
File desktopDir = FileSystemView.getFileSystemView().getHomeDirectory();
|
||||
JFileChooser fileChooser = new JFileChooser(desktopDir);
|
||||
fileChooser.setDialogTitle("选择图片文件");
|
||||
// 设置图片文件过滤器,支持常见的图片格式
|
||||
FileNameExtensionFilter imageFilter = new FileNameExtensionFilter(
|
||||
"图片文件 (*.jpg;*.jpeg;*.png;*.bmp;*.gif)", "jpg", "jpeg", "png", "bmp", "gif");
|
||||
fileChooser.setFileFilter(imageFilter);
|
||||
|
||||
int returnValue = fileChooser.showOpenDialog(this);
|
||||
if (returnValue == JFileChooser.APPROVE_OPTION) {
|
||||
File selectedFile = fileChooser.getSelectedFile();
|
||||
try {
|
||||
videoPlayer.loadImage(selectedFile.getAbsolutePath());
|
||||
} catch (Exception ex) {
|
||||
ex.printStackTrace();
|
||||
JOptionPane.showMessageDialog(this, "加载图片失败: " + ex.getMessage(), "错误", JOptionPane.ERROR_MESSAGE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static void main(String[] args) {
|
||||
SwingUtilities.invokeLater(VideoInferenceApp::new);
|
||||
}
|
||||
|
|
|
@ -1,410 +0,0 @@
|
|||
package com.ly.lishi;
|
||||
|
||||
import ai.onnxruntime.*;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.ly.onnx.model.BoundingBox;
|
||||
import com.ly.onnx.model.InferenceResult;
|
||||
import lombok.Data;
|
||||
import org.opencv.core.*;
|
||||
import org.opencv.imgcodecs.Imgcodecs;
|
||||
import org.opencv.imgproc.Imgproc;
|
||||
|
||||
import java.nio.FloatBuffer;
|
||||
import java.util.*;
|
||||
|
||||
@Data
|
||||
public class InferenceEngine {
|
||||
|
||||
private OrtEnvironment environment;
|
||||
private OrtSession.SessionOptions sessionOptions;
|
||||
private OrtSession session;
|
||||
|
||||
private String modelPath;
|
||||
private List<String> labels;
|
||||
|
||||
// 用于存储图像预处理信息的类变量
|
||||
private long[] inputShape = null;
|
||||
|
||||
static {
|
||||
nu.pattern.OpenCV.loadLocally();
|
||||
}
|
||||
|
||||
public InferenceEngine(String modelPath, List<String> labels) {
|
||||
this.modelPath = modelPath;
|
||||
this.labels = labels;
|
||||
init();
|
||||
}
|
||||
|
||||
public void init() {
|
||||
try {
|
||||
environment = OrtEnvironment.getEnvironment();
|
||||
sessionOptions = new OrtSession.SessionOptions();
|
||||
sessionOptions.addCUDA(0); // 使用 GPU
|
||||
session = environment.createSession(modelPath, sessionOptions);
|
||||
Map<String, NodeInfo> inputInfo = session.getInputInfo();
|
||||
NodeInfo nodeInfo = inputInfo.values().iterator().next();
|
||||
TensorInfo tensorInfo = (TensorInfo) nodeInfo.getInfo();
|
||||
inputShape = tensorInfo.getShape(); // 从模型中获取输入形状
|
||||
logModelInfo(session);
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException("模型加载失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
public InferenceResult infer(int w, int h, Map<String, Object> preprocessParams) {
|
||||
long startTime = System.currentTimeMillis();
|
||||
|
||||
// 从 Map 中获取偏移相关的变量
|
||||
float[] inputData = (float[]) preprocessParams.get("inputData");
|
||||
int origWidth = (int) preprocessParams.get("origWidth");
|
||||
int origHeight = (int) preprocessParams.get("origHeight");
|
||||
float scalingFactor = (float) preprocessParams.get("scalingFactor");
|
||||
int xOffset = (int) preprocessParams.get("xOffset");
|
||||
int yOffset = (int) preprocessParams.get("yOffset");
|
||||
|
||||
try {
|
||||
Map<String, NodeInfo> inputInfo = session.getInputInfo();
|
||||
String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入
|
||||
|
||||
long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状
|
||||
|
||||
// 创建输入张量时,使用 CHW 格式的数据
|
||||
OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape);
|
||||
|
||||
// 执行推理
|
||||
long inferenceStart = System.currentTimeMillis();
|
||||
OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor));
|
||||
long inferenceEnd = System.currentTimeMillis();
|
||||
System.out.println("模型推理耗时:" + (inferenceEnd - inferenceStart) + " ms");
|
||||
|
||||
// 解析推理结果
|
||||
String outputName = session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出
|
||||
float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状:[1, N, 6]
|
||||
|
||||
// 设定置信度阈值
|
||||
float confidenceThreshold = 0.25f; // 您可以根据需要调整
|
||||
|
||||
// 根据模型的输出结果解析边界框
|
||||
List<BoundingBox> boxes = new ArrayList<>();
|
||||
for (float[] data : outputData[0]) { // 遍历所有检测框
|
||||
// 根据模型输出格式,解析中心坐标和宽高
|
||||
float x_center = data[0];
|
||||
float y_center = data[1];
|
||||
float width = data[2];
|
||||
float height = data[3];
|
||||
float confidence = data[4];
|
||||
|
||||
if (confidence >= confidenceThreshold) {
|
||||
// 将中心坐标转换为左上角和右下角坐标
|
||||
float x1 = x_center - width / 2;
|
||||
float y1 = y_center - height / 2;
|
||||
float x2 = x_center + width / 2;
|
||||
float y2 = y_center + height / 2;
|
||||
|
||||
// 调整坐标,减去偏移并除以缩放因子
|
||||
float x1Adjusted = (x1 - xOffset) / scalingFactor;
|
||||
float y1Adjusted = (y1 - yOffset) / scalingFactor;
|
||||
float x2Adjusted = (x2 - xOffset) / scalingFactor;
|
||||
float y2Adjusted = (y2 - yOffset) / scalingFactor;
|
||||
|
||||
// 确保坐标的正确顺序
|
||||
float xMinAdjusted = Math.min(x1Adjusted, x2Adjusted);
|
||||
float xMaxAdjusted = Math.max(x1Adjusted, x2Adjusted);
|
||||
float yMinAdjusted = Math.min(y1Adjusted, y2Adjusted);
|
||||
float yMaxAdjusted = Math.max(y1Adjusted, y2Adjusted);
|
||||
|
||||
// 确保坐标在原始图像范围内
|
||||
int x = (int) Math.max(0, xMinAdjusted);
|
||||
int y = (int) Math.max(0, yMinAdjusted);
|
||||
int xMax = (int) Math.min(origWidth, xMaxAdjusted);
|
||||
int yMax = (int) Math.min(origHeight, yMaxAdjusted);
|
||||
int wBox = xMax - x;
|
||||
int hBox = yMax - y;
|
||||
|
||||
// 仅当宽度和高度为正时,才添加边界框
|
||||
if (wBox > 0 && hBox > 0) {
|
||||
// 使用您的单一标签
|
||||
String label = labels.get(0);
|
||||
|
||||
boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 非极大值抑制(NMS)
|
||||
long nmsStart = System.currentTimeMillis();
|
||||
List<BoundingBox> nmsBoxes = nonMaximumSuppression(boxes, 0.5f);
|
||||
|
||||
System.out.println("检测到的标签:" + JSON.toJSONString(nmsBoxes));
|
||||
if (!nmsBoxes.isEmpty()) {
|
||||
for (BoundingBox box : nmsBoxes) {
|
||||
System.out.println(box);
|
||||
}
|
||||
}
|
||||
long nmsEnd = System.currentTimeMillis();
|
||||
System.out.println("NMS 耗时:" + (nmsEnd - nmsStart) + " ms");
|
||||
|
||||
// 封装结果并返回
|
||||
InferenceResult inferenceResult = new InferenceResult();
|
||||
inferenceResult.setBoundingBoxes(nmsBoxes);
|
||||
|
||||
long endTime = System.currentTimeMillis();
|
||||
System.out.println("一次推理总耗时:" + (endTime - startTime) + " ms");
|
||||
|
||||
return inferenceResult;
|
||||
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException("推理失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 计算两个边界框的 IoU
|
||||
private float computeIoU(BoundingBox box1, BoundingBox box2) {
|
||||
int x1 = Math.max(box1.getX(), box2.getX());
|
||||
int y1 = Math.max(box1.getY(), box2.getY());
|
||||
int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth());
|
||||
int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight());
|
||||
|
||||
int intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
|
||||
int box1Area = box1.getWidth() * box1.getHeight();
|
||||
int box2Area = box2.getWidth() * box2.getHeight();
|
||||
|
||||
return (float) intersectionArea / (box1Area + box2Area - intersectionArea);
|
||||
}
|
||||
|
||||
// 非极大值抑制(NMS)方法
|
||||
private List<BoundingBox> nonMaximumSuppression(List<BoundingBox> boxes, float iouThreshold) {
|
||||
// 按置信度排序(从高到低)
|
||||
boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence()));
|
||||
|
||||
List<BoundingBox> result = new ArrayList<>();
|
||||
|
||||
while (!boxes.isEmpty()) {
|
||||
BoundingBox bestBox = boxes.remove(0);
|
||||
result.add(bestBox);
|
||||
|
||||
Iterator<BoundingBox> iterator = boxes.iterator();
|
||||
while (iterator.hasNext()) {
|
||||
BoundingBox box = iterator.next();
|
||||
if (computeIoU(bestBox, box) > iouThreshold) {
|
||||
iterator.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// 打印模型信息
|
||||
private void logModelInfo(OrtSession session) {
|
||||
System.out.println("模型输入信息:");
|
||||
try {
|
||||
for (Map.Entry<String, NodeInfo> entry : session.getInputInfo().entrySet()) {
|
||||
String name = entry.getKey();
|
||||
NodeInfo info = entry.getValue();
|
||||
System.out.println("输入名称: " + name);
|
||||
System.out.println("输入信息: " + info.toString());
|
||||
}
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
System.out.println("模型输出信息:");
|
||||
try {
|
||||
for (Map.Entry<String, NodeInfo> entry : session.getOutputInfo().entrySet()) {
|
||||
String name = entry.getKey();
|
||||
NodeInfo info = entry.getValue();
|
||||
System.out.println("输出名称: " + name);
|
||||
System.out.println("输出信息: " + info.toString());
|
||||
}
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public static void main(String[] args) {
|
||||
// 加载 OpenCV 库
|
||||
|
||||
// 初始化标签列表(只有一个标签)
|
||||
List<String> labels = Arrays.asList("person");
|
||||
|
||||
// 创建 InferenceEngine 实例
|
||||
InferenceEngine inferenceEngine = new InferenceEngine("C:\\Users\\ly\\Desktop\\person.onnx", labels);
|
||||
|
||||
for (int j = 0; j < 10; j++) {
|
||||
try {
|
||||
// 加载图片
|
||||
Mat inputImage = Imgcodecs.imread("C:\\Users\\ly\\Desktop\\10230731212230.png");
|
||||
|
||||
// 预处理图像
|
||||
long l1 = System.currentTimeMillis();
|
||||
Map<String, Object> preprocessResult = inferenceEngine.preprocessImage(inputImage);
|
||||
float[] inputData = (float[]) preprocessResult.get("inputData");
|
||||
|
||||
InferenceResult result = null;
|
||||
for (int i = 0; i < 10; i++) {
|
||||
long l = System.currentTimeMillis();
|
||||
result = inferenceEngine.infer( 640, 640, preprocessResult);
|
||||
System.out.println("第 " + (i + 1) + " 次推理耗时:" + (System.currentTimeMillis() - l) + " ms");
|
||||
}
|
||||
|
||||
|
||||
// 处理并显示结果
|
||||
System.out.println("推理结果:");
|
||||
for (BoundingBox box : result.getBoundingBoxes()) {
|
||||
System.out.println(box);
|
||||
}
|
||||
|
||||
// 可视化并保存带有边界框的图像
|
||||
Mat outputImage = inferenceEngine.drawBoundingBoxes(inputImage, result.getBoundingBoxes());
|
||||
|
||||
// 保存图片到本地文件
|
||||
String outputFilePath = "output_image_with_boxes.jpg";
|
||||
Imgcodecs.imwrite(outputFilePath, outputImage);
|
||||
|
||||
System.out.println("已保存结果图片: " + outputFilePath);
|
||||
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 在图像上绘制边界框和标签
|
||||
private Mat drawBoundingBoxes(Mat image, List<BoundingBox> boxes) {
|
||||
for (BoundingBox box : boxes) {
|
||||
// 绘制矩形边界框
|
||||
Imgproc.rectangle(image, new Point(box.getX(), box.getY()),
|
||||
new Point(box.getX() + box.getWidth(), box.getY() + box.getHeight()),
|
||||
new Scalar(0, 0, 255), 2); // 红色边框
|
||||
|
||||
// 绘制标签文字和置信度
|
||||
String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence());
|
||||
int baseLine[] = new int[1];
|
||||
Size labelSize = Imgproc.getTextSize(label, Imgproc.FONT_HERSHEY_SIMPLEX, 0.5, 1, baseLine);
|
||||
int top = Math.max(box.getY(), (int) labelSize.height);
|
||||
Imgproc.putText(image, label, new Point(box.getX(), top),
|
||||
Imgproc.FONT_HERSHEY_SIMPLEX, 0.5, new Scalar(255, 255, 255), 1);
|
||||
}
|
||||
|
||||
return image;
|
||||
}
|
||||
|
||||
|
||||
public Map<String, Object> preprocessImage(Mat image) {
|
||||
int targetWidth = 640;
|
||||
int targetHeight = 640;
|
||||
|
||||
int origWidth = image.width();
|
||||
int origHeight = image.height();
|
||||
|
||||
// 计算缩放因子
|
||||
float scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
|
||||
|
||||
// 计算新的图像尺寸
|
||||
int newWidth = Math.round(origWidth * scalingFactor);
|
||||
int newHeight = Math.round(origHeight * scalingFactor);
|
||||
|
||||
// 计算偏移量以居中图像
|
||||
int xOffset = (targetWidth - newWidth) / 2;
|
||||
int yOffset = (targetHeight - newHeight) / 2;
|
||||
|
||||
// 调整图像尺寸
|
||||
Mat resizedImage = new Mat();
|
||||
Imgproc.resize(image, resizedImage, new Size(newWidth, newHeight), 0, 0, Imgproc.INTER_LINEAR);
|
||||
|
||||
// 转换为 RGB 并归一化
|
||||
Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB);
|
||||
resizedImage.convertTo(resizedImage, CvType.CV_32FC3, 1.0 / 255.0);
|
||||
|
||||
// 创建填充后的图像
|
||||
Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3);
|
||||
Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight);
|
||||
resizedImage.copyTo(paddedImage.submat(roi));
|
||||
|
||||
// 将图像数据转换为数组
|
||||
int imageSize = targetWidth * targetHeight;
|
||||
float[] chwData = new float[3 * imageSize];
|
||||
float[] hwcData = new float[3 * imageSize];
|
||||
paddedImage.get(0, 0, hwcData);
|
||||
|
||||
// 转换为 CHW 格式
|
||||
int channelSize = imageSize;
|
||||
for (int c = 0; c < 3; c++) {
|
||||
for (int i = 0; i < imageSize; i++) {
|
||||
chwData[c * channelSize + i] = hwcData[i * 3 + c];
|
||||
}
|
||||
}
|
||||
|
||||
// 释放图像资源
|
||||
resizedImage.release();
|
||||
paddedImage.release();
|
||||
|
||||
// 将预处理结果和偏移信息存入 Map
|
||||
Map<String, Object> result = new HashMap<>();
|
||||
result.put("inputData", chwData);
|
||||
result.put("origWidth", origWidth);
|
||||
result.put("origHeight", origHeight);
|
||||
result.put("scalingFactor", scalingFactor);
|
||||
result.put("xOffset", xOffset);
|
||||
result.put("yOffset", yOffset);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
// 图像预处理
|
||||
// public float[] preprocessImage(Mat image) {
|
||||
// int targetWidth = 640;
|
||||
// int targetHeight = 640;
|
||||
//
|
||||
// origWidth = image.width();
|
||||
// origHeight = image.height();
|
||||
//
|
||||
// // 计算缩放因子
|
||||
// scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
|
||||
//
|
||||
// // 计算新的图像尺寸
|
||||
// newWidth = Math.round(origWidth * scalingFactor);
|
||||
// newHeight = Math.round(origHeight * scalingFactor);
|
||||
//
|
||||
// // 计算偏移量以居中图像
|
||||
// xOffset = (targetWidth - newWidth) / 2;
|
||||
// yOffset = (targetHeight - newHeight) / 2;
|
||||
//
|
||||
// // 调整图像尺寸
|
||||
// Mat resizedImage = new Mat();
|
||||
// Imgproc.resize(image, resizedImage, new Size(newWidth, newHeight), 0, 0, Imgproc.INTER_LINEAR);
|
||||
//
|
||||
// // 转换为 RGB 并归一化
|
||||
// Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB);
|
||||
// resizedImage.convertTo(resizedImage, CvType.CV_32FC3, 1.0 / 255.0);
|
||||
//
|
||||
// // 创建填充后的图像
|
||||
// Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3);
|
||||
// Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight);
|
||||
// resizedImage.copyTo(paddedImage.submat(roi));
|
||||
//
|
||||
// // 将图像数据转换为数组
|
||||
// int imageSize = targetWidth * targetHeight;
|
||||
// float[] chwData = new float[3 * imageSize];
|
||||
// float[] hwcData = new float[3 * imageSize];
|
||||
// paddedImage.get(0, 0, hwcData);
|
||||
//
|
||||
// // 转换为 CHW 格式
|
||||
// int channelSize = imageSize;
|
||||
// for (int c = 0; c < 3; c++) {
|
||||
// for (int i = 0; i < imageSize; i++) {
|
||||
// chwData[c * channelSize + i] = hwcData[i * 3 + c];
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // 释放图像资源
|
||||
// resizedImage.release();
|
||||
// paddedImage.release();
|
||||
//
|
||||
// return chwData;
|
||||
// }
|
||||
|
||||
}
|
|
@ -1,378 +0,0 @@
|
|||
package com.ly.lishi;
|
||||
|
||||
import com.ly.layout.VideoPanel;
|
||||
import com.ly.model_load.ModelManager;
|
||||
import com.ly.onnx.engine.InferenceEngine;
|
||||
import com.ly.onnx.model.InferenceResult;
|
||||
import com.ly.onnx.utils.DrawImagesUtils;
|
||||
import org.opencv.core.CvType;
|
||||
import org.opencv.core.Mat;
|
||||
import org.opencv.core.Rect;
|
||||
import org.opencv.core.Size;
|
||||
import org.opencv.imgproc.Imgproc;
|
||||
import org.opencv.videoio.VideoCapture;
|
||||
import org.opencv.videoio.Videoio;
|
||||
|
||||
import javax.swing.*;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.BlockingQueue;
|
||||
import java.util.concurrent.LinkedBlockingQueue;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static com.ly.onnx.utils.ImageUtils.matToBufferedImage;
|
||||
|
||||
public class VideoPlayer {
|
||||
static {
|
||||
// 加载 OpenCV 库
|
||||
nu.pattern.OpenCV.loadLocally();
|
||||
String OS = System.getProperty("os.name").toLowerCase();
|
||||
if (OS.contains("win")) {
|
||||
System.load(ClassLoader.getSystemResource("lib/win/opencv_videoio_ffmpeg470_64.dll").getPath());
|
||||
}
|
||||
}
|
||||
|
||||
private VideoCapture videoCapture;
|
||||
private volatile boolean isPlaying = false;
|
||||
private volatile boolean isPaused = false;
|
||||
private Thread frameReadingThread;
|
||||
private Thread inferenceThread;
|
||||
private VideoPanel videoPanel;
|
||||
|
||||
private long videoDuration = 0; // 毫秒
|
||||
private long currentTimestamp = 0; // 毫秒
|
||||
|
||||
|
||||
|
||||
private ModelManager modelManager;
|
||||
private List<InferenceEngine> inferenceEngines = new ArrayList<>();
|
||||
|
||||
// 定义阻塞队列来缓冲转换后的数据
|
||||
private BlockingQueue<FrameData> frameDataQueue = new LinkedBlockingQueue<>(10); // 队列容量可根据需要调整
|
||||
|
||||
public VideoPlayer(VideoPanel videoPanel, ModelManager modelManager) {
|
||||
this.videoPanel = videoPanel;
|
||||
this.modelManager = modelManager;
|
||||
}
|
||||
|
||||
// 加载视频或流
|
||||
public void loadVideo(String videoFilePathOrStreamUrl) throws Exception {
|
||||
stopVideo();
|
||||
if (videoFilePathOrStreamUrl.equals("0")) {
|
||||
int cameraIndex = Integer.parseInt(videoFilePathOrStreamUrl);
|
||||
videoCapture = new VideoCapture(cameraIndex);
|
||||
if (!videoCapture.isOpened()) {
|
||||
throw new Exception("无法打开摄像头");
|
||||
}
|
||||
videoDuration = 0; // 摄像头没有固定的时长
|
||||
playVideo();
|
||||
} else {
|
||||
// 输入不是数字,尝试打开视频文件
|
||||
videoCapture = new VideoCapture(videoFilePathOrStreamUrl, Videoio.CAP_FFMPEG);
|
||||
if (!videoCapture.isOpened()) {
|
||||
throw new Exception("无法打开视频文件:" + videoFilePathOrStreamUrl);
|
||||
}
|
||||
double frameCount = videoCapture.get(Videoio.CAP_PROP_FRAME_COUNT);
|
||||
double fps = videoCapture.get(Videoio.CAP_PROP_FPS);
|
||||
if (fps <= 0 || Double.isNaN(fps)) {
|
||||
fps = 25; // 默认帧率
|
||||
}
|
||||
videoDuration = (long) (frameCount / fps * 1000); // 转换为毫秒
|
||||
}
|
||||
|
||||
// 显示第一帧
|
||||
Mat frame = new Mat();
|
||||
if (videoCapture.read(frame)) {
|
||||
BufferedImage bufferedImage = matToBufferedImage(frame);
|
||||
videoPanel.updateImage(bufferedImage);
|
||||
currentTimestamp = 0;
|
||||
} else {
|
||||
throw new Exception("无法读取第一帧");
|
||||
}
|
||||
|
||||
// 重置到视频开始位置
|
||||
videoCapture.set(Videoio.CAP_PROP_POS_FRAMES, 0);
|
||||
currentTimestamp = 0;
|
||||
}
|
||||
|
||||
// 播放
|
||||
public void playVideo() {
|
||||
if (videoCapture == null || !videoCapture.isOpened()) {
|
||||
JOptionPane.showMessageDialog(null, "请先加载视频文件或流。", "提示", JOptionPane.WARNING_MESSAGE);
|
||||
return;
|
||||
}
|
||||
|
||||
if (isPlaying) {
|
||||
if (isPaused) {
|
||||
isPaused = false; // 恢复播放
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
isPlaying = true;
|
||||
isPaused = false;
|
||||
|
||||
frameDataQueue.clear(); // 开始播放前清空队列
|
||||
|
||||
// 创建并启动帧读取和转换线程
|
||||
frameReadingThread = new Thread(() -> {
|
||||
try {
|
||||
double fps = videoCapture.get(Videoio.CAP_PROP_FPS);
|
||||
if (fps <= 0 || Double.isNaN(fps)) {
|
||||
fps = 25; // 默认帧率
|
||||
}
|
||||
long frameDelay = (long) (1000 / fps);
|
||||
|
||||
while (isPlaying) {
|
||||
if (Thread.currentThread().isInterrupted()) {
|
||||
break;
|
||||
}
|
||||
if (isPaused) {
|
||||
Thread.sleep(10);
|
||||
continue;
|
||||
}
|
||||
|
||||
Mat frame = new Mat();
|
||||
if (!videoCapture.read(frame) || frame.empty()) {
|
||||
isPlaying = false;
|
||||
break;
|
||||
}
|
||||
|
||||
long startTime = System.currentTimeMillis();
|
||||
BufferedImage bufferedImage = matToBufferedImage(frame);
|
||||
|
||||
if (bufferedImage != null) {
|
||||
// float[] floats = preprocessAndConvertBufferedImage(bufferedImage);
|
||||
Map<String, Object> stringObjectMap = preprocessImage(frame);
|
||||
// 创建 FrameData 对象并放入队列
|
||||
FrameData frameData = new FrameData(bufferedImage, null,stringObjectMap);
|
||||
frameDataQueue.put(frameData); // 阻塞,如果队列已满
|
||||
}
|
||||
|
||||
// 控制帧率
|
||||
currentTimestamp = (long) videoCapture.get(Videoio.CAP_PROP_POS_MSEC);
|
||||
|
||||
// 控制播放速度
|
||||
long processingTime = System.currentTimeMillis() - startTime;
|
||||
long sleepTime = frameDelay - processingTime;
|
||||
if (sleepTime > 0) {
|
||||
Thread.sleep(sleepTime);
|
||||
}
|
||||
}
|
||||
} catch (Exception ex) {
|
||||
ex.printStackTrace();
|
||||
} finally {
|
||||
isPlaying = false;
|
||||
}
|
||||
});
|
||||
|
||||
// 创建并启动推理线程
|
||||
inferenceThread = new Thread(() -> {
|
||||
try {
|
||||
while (isPlaying || !frameDataQueue.isEmpty()) {
|
||||
if (Thread.currentThread().isInterrupted()) {
|
||||
break;
|
||||
}
|
||||
if (isPaused) {
|
||||
Thread.sleep(100);
|
||||
continue;
|
||||
}
|
||||
|
||||
FrameData frameData = frameDataQueue.poll(100, TimeUnit.MILLISECONDS); // 等待数据
|
||||
if (frameData == null) {
|
||||
continue; // 没有数据,继续检查 isPlaying
|
||||
}
|
||||
|
||||
BufferedImage bufferedImage = frameData.image;
|
||||
Map<String, Object> floatObjectMap = frameData.floatObjectMap;
|
||||
|
||||
// 执行推理
|
||||
List<InferenceResult> inferenceResults = new ArrayList<>();
|
||||
for (InferenceEngine inferenceEngine : inferenceEngines) {
|
||||
// 假设 InferenceEngine 有 infer 方法接受 float 数组
|
||||
// inferenceResults.add(inferenceEngine.infer( 640, 640,floatObjectMap));
|
||||
}
|
||||
// 绘制推理结果
|
||||
DrawImagesUtils.drawInferenceResult(bufferedImage, inferenceResults);
|
||||
|
||||
// 更新绘制后图像
|
||||
videoPanel.updateImage(bufferedImage);
|
||||
}
|
||||
} catch (Exception ex) {
|
||||
ex.printStackTrace();
|
||||
}
|
||||
});
|
||||
|
||||
frameReadingThread.start();
|
||||
inferenceThread.start();
|
||||
}
|
||||
|
||||
// 暂停视频
|
||||
public void pauseVideo() {
|
||||
if (!isPlaying) {
|
||||
return;
|
||||
}
|
||||
isPaused = true;
|
||||
}
|
||||
|
||||
// 重播视频
|
||||
public void replayVideo() {
|
||||
try {
|
||||
stopVideo(); // 停止当前播放
|
||||
if (videoCapture != null) {
|
||||
videoCapture.set(Videoio.CAP_PROP_POS_FRAMES, 0);
|
||||
currentTimestamp = 0;
|
||||
|
||||
// 显示第一帧
|
||||
Mat frame = new Mat();
|
||||
if (videoCapture.read(frame)) {
|
||||
BufferedImage bufferedImage = matToBufferedImage(frame);
|
||||
videoPanel.updateImage(bufferedImage);
|
||||
}
|
||||
|
||||
playVideo(); // 开始播放
|
||||
}
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
JOptionPane.showMessageDialog(null, "重播失败: " + e.getMessage(), "错误", JOptionPane.ERROR_MESSAGE);
|
||||
}
|
||||
}
|
||||
|
||||
// 停止视频
|
||||
public void stopVideo() {
|
||||
isPlaying = false;
|
||||
isPaused = false;
|
||||
|
||||
if (frameReadingThread != null && frameReadingThread.isAlive()) {
|
||||
frameReadingThread.interrupt();
|
||||
}
|
||||
|
||||
if (inferenceThread != null && inferenceThread.isAlive()) {
|
||||
inferenceThread.interrupt();
|
||||
}
|
||||
|
||||
if (videoCapture != null) {
|
||||
videoCapture.release();
|
||||
videoCapture = null;
|
||||
}
|
||||
|
||||
frameDataQueue.clear();
|
||||
}
|
||||
|
||||
// 快进或后退
|
||||
public void seekTo(long seekTime) {
|
||||
if (videoCapture == null) return;
|
||||
try {
|
||||
isPaused = false; // 取消暂停
|
||||
stopVideo(); // 停止当前播放
|
||||
videoCapture.set(Videoio.CAP_PROP_POS_MSEC, seekTime);
|
||||
currentTimestamp = seekTime;
|
||||
|
||||
Mat frame = new Mat();
|
||||
if (videoCapture.read(frame)) {
|
||||
BufferedImage bufferedImage = matToBufferedImage(frame);
|
||||
videoPanel.updateImage(bufferedImage);
|
||||
}
|
||||
|
||||
// 重新开始播放
|
||||
playVideo();
|
||||
|
||||
} catch (Exception ex) {
|
||||
ex.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
// 快进
|
||||
public void fastForward(long milliseconds) {
|
||||
long newTime = Math.min(currentTimestamp + milliseconds, videoDuration);
|
||||
seekTo(newTime);
|
||||
}
|
||||
|
||||
// 后退
|
||||
public void rewind(long milliseconds) {
|
||||
long newTime = Math.max(currentTimestamp - milliseconds, 0);
|
||||
seekTo(newTime);
|
||||
}
|
||||
|
||||
public void addInferenceEngines(InferenceEngine inferenceEngine) {
|
||||
this.inferenceEngines.add(inferenceEngine);
|
||||
}
|
||||
|
||||
// 定义一个内部类来存储帧数据
|
||||
private static class FrameData {
|
||||
public BufferedImage image;
|
||||
public float[] floatArray;
|
||||
public Map<String, Object> floatObjectMap;
|
||||
|
||||
public FrameData(BufferedImage image, float[] floatArray, Map<String, Object> floatObjectMap) {
|
||||
this.image = image;
|
||||
this.floatArray = floatArray;
|
||||
this.floatObjectMap = floatObjectMap;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 可选的预处理方法
|
||||
public Map<String, Object> preprocessImage(Mat image) {
|
||||
int targetWidth = 640;
|
||||
int targetHeight = 640;
|
||||
|
||||
int origWidth = image.width();
|
||||
int origHeight = image.height();
|
||||
|
||||
// 计算缩放因子
|
||||
float scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
|
||||
|
||||
// 计算新的图像尺寸
|
||||
int newWidth = Math.round(origWidth * scalingFactor);
|
||||
int newHeight = Math.round(origHeight * scalingFactor);
|
||||
|
||||
// 调整图像尺寸
|
||||
Mat resizedImage = new Mat();
|
||||
Imgproc.resize(image, resizedImage, new Size(newWidth, newHeight));
|
||||
|
||||
// 转换为 RGB 并归一化
|
||||
Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB);
|
||||
resizedImage.convertTo(resizedImage, CvType.CV_32FC3, 1.0 / 255.0);
|
||||
|
||||
// 创建填充后的图像
|
||||
Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3);
|
||||
int xOffset = (targetWidth - newWidth) / 2;
|
||||
int yOffset = (targetHeight - newHeight) / 2;
|
||||
Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight);
|
||||
resizedImage.copyTo(paddedImage.submat(roi));
|
||||
|
||||
// 将图像数据转换为数组
|
||||
int imageSize = targetWidth * targetHeight;
|
||||
float[] chwData = new float[3 * imageSize];
|
||||
float[] hwcData = new float[3 * imageSize];
|
||||
paddedImage.get(0, 0, hwcData);
|
||||
|
||||
// 转换为 CHW 格式
|
||||
int channelSize = imageSize;
|
||||
for (int c = 0; c < 3; c++) {
|
||||
for (int i = 0; i < imageSize; i++) {
|
||||
chwData[c * channelSize + i] = hwcData[i * 3 + c];
|
||||
}
|
||||
}
|
||||
|
||||
// 释放图像资源
|
||||
resizedImage.release();
|
||||
paddedImage.release();
|
||||
|
||||
// 将预处理结果和偏移信息存入 Map
|
||||
Map<String, Object> result = new HashMap<>();
|
||||
result.put("inputData", chwData);
|
||||
result.put("origWidth", origWidth);
|
||||
result.put("origHeight", origHeight);
|
||||
result.put("scalingFactor", scalingFactor);
|
||||
result.put("xOffset", xOffset);
|
||||
result.put("yOffset", yOffset);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,19 +1,30 @@
|
|||
package com.ly.model_load;
|
||||
|
||||
import com.ly.file.FileEditor;
|
||||
import com.ly.onnx.engine.InferenceEngine;
|
||||
import com.ly.onnx.model.ModelInfo;
|
||||
import com.ly.play.opencv.VideoPlayer;
|
||||
|
||||
import javax.swing.*;
|
||||
import javax.swing.filechooser.FileNameExtensionFilter;
|
||||
import javax.swing.filechooser.FileSystemView;
|
||||
import java.awt.*;
|
||||
import java.awt.datatransfer.DataFlavor;
|
||||
import java.awt.event.MouseAdapter;
|
||||
import java.awt.event.MouseEvent;
|
||||
import java.io.File;
|
||||
import java.io.FileReader;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.io.BufferedReader;
|
||||
|
||||
|
||||
public class ModelManager extends JPanel {
|
||||
private DefaultListModel<ModelInfo> modelListModel;
|
||||
private JList<ModelInfo> modelList;
|
||||
private JPopupMenu popupMenu;
|
||||
private VideoPlayer videoPlayer;
|
||||
|
||||
|
||||
public ModelManager() {
|
||||
setLayout(new BorderLayout());
|
||||
|
@ -23,6 +34,45 @@ public class ModelManager extends JPanel {
|
|||
JScrollPane modelScrollPane = new JScrollPane(modelList);
|
||||
add(modelScrollPane, BorderLayout.CENTER);
|
||||
|
||||
// 创建右键菜单
|
||||
popupMenu = new JPopupMenu();
|
||||
JMenuItem deleteMenuItem = new JMenuItem("删除");
|
||||
popupMenu.add(deleteMenuItem);
|
||||
|
||||
// 为模型列表添加右键菜单
|
||||
modelList.addMouseListener(new MouseAdapter() {
|
||||
public void mousePressed(MouseEvent e) {
|
||||
if (e.isPopupTrigger()) { // 如果是右键触发
|
||||
showPopup(e);
|
||||
}
|
||||
}
|
||||
|
||||
public void mouseReleased(MouseEvent e) {
|
||||
if (e.isPopupTrigger()) { // 如果是右键触发
|
||||
showPopup(e);
|
||||
}
|
||||
}
|
||||
|
||||
private void showPopup(MouseEvent e) {
|
||||
int index = modelList.locationToIndex(e.getPoint());
|
||||
if (index != -1) {
|
||||
modelList.setSelectedIndex(index); // 选中右键点击的行
|
||||
popupMenu.show(modelList, e.getX(), e.getY());
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// 为删除菜单项添加操作
|
||||
deleteMenuItem.addActionListener(e -> {
|
||||
int selectedIndex = modelList.getSelectedIndex();
|
||||
if (selectedIndex != -1) {
|
||||
int confirmation = JOptionPane.showConfirmDialog(null, "确定要删除此模型吗?", "确认删除", JOptionPane.YES_NO_OPTION);
|
||||
if (confirmation == JOptionPane.YES_OPTION) {
|
||||
modelListModel.remove(selectedIndex); // 删除选中的模型
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// 双击编辑标签文件
|
||||
modelList.addMouseListener(new MouseAdapter() {
|
||||
public void mouseClicked(MouseEvent e) {
|
||||
|
@ -36,8 +86,78 @@ public class ModelManager extends JPanel {
|
|||
}
|
||||
}
|
||||
});
|
||||
|
||||
// 设置拖拽功能处理模型和标签文件
|
||||
setTransferHandler(new TransferHandler() {
|
||||
@Override
|
||||
public boolean canImport(TransferSupport support) {
|
||||
return support.isDataFlavorSupported(DataFlavor.javaFileListFlavor);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean importData(TransferSupport support) {
|
||||
if (!canImport(support)) {
|
||||
return false;
|
||||
}
|
||||
try {
|
||||
// 获取拖拽的文件列表
|
||||
List<File> files = (List<File>) support.getTransferable().getTransferData(DataFlavor.javaFileListFlavor);
|
||||
if (files.size() == 2) { // 确保拖拽的是两个文件
|
||||
File modelFile = null;
|
||||
File labelFile = null;
|
||||
|
||||
for (File file : files) {
|
||||
if (file.getName().endsWith(".onnx")) {
|
||||
modelFile = file;
|
||||
} else if (file.getName().endsWith(".txt")) {
|
||||
labelFile = file;
|
||||
}
|
||||
}
|
||||
|
||||
if (modelFile != null && labelFile != null) {
|
||||
// 确保 videoPlayer 被正确设置
|
||||
if (videoPlayer == null) {
|
||||
throw new IllegalStateException("VideoPlayer is not set in ModelManager.");
|
||||
}
|
||||
|
||||
// 添加模型信息到列表
|
||||
ModelInfo modelInfo = new ModelInfo(modelFile.getAbsolutePath(), labelFile.getAbsolutePath());
|
||||
modelListModel.addElement(modelInfo);
|
||||
|
||||
// 读取标签文件内容,转为 List<String>
|
||||
List<String> labels = new ArrayList<>();
|
||||
try (BufferedReader reader = new BufferedReader(new FileReader(labelFile))) {
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
labels.add(line.trim());
|
||||
}
|
||||
}
|
||||
|
||||
// 创建推理引擎并传递给 VideoPlayer
|
||||
InferenceEngine inferenceEngine = new InferenceEngine(modelFile.getAbsolutePath(), labels);
|
||||
videoPlayer.addInferenceEngines(inferenceEngine);
|
||||
return true;
|
||||
} else {
|
||||
JOptionPane.showMessageDialog(null, "请拖入一个 .onnx 文件和一个 .txt 文件。", "提示", JOptionPane.WARNING_MESSAGE);
|
||||
}
|
||||
} else {
|
||||
JOptionPane.showMessageDialog(null, "请拖入两个文件:一个 .onnx 文件和一个 .txt 文件。", "提示", JOptionPane.WARNING_MESSAGE);
|
||||
}
|
||||
} catch (Exception ex) {
|
||||
ex.printStackTrace();
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// 添加设置 VideoPlayer 的方法
|
||||
public void setVideoPlayer(VideoPlayer videoPlayer) {
|
||||
this.videoPlayer = videoPlayer;
|
||||
}
|
||||
|
||||
|
||||
// 加载模型
|
||||
public void loadModel(JFrame parent) {
|
||||
File desktopDir = FileSystemView.getFileSystemView().getHomeDirectory();
|
||||
|
|
|
@ -154,7 +154,6 @@ public class InferenceEngine {
|
|||
if (wBox > 0 && hBox > 0) {
|
||||
// 使用您的单一标签
|
||||
String label = labels.get(0);
|
||||
|
||||
boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,297 +0,0 @@
|
|||
package com.ly.onnx.engine;
|
||||
|
||||
import ai.onnxruntime.*;
|
||||
import com.ly.onnx.model.BoundingBox;
|
||||
import com.ly.onnx.model.InferenceResult;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.*;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.File;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.util.List;
|
||||
import java.util.*;
|
||||
|
||||
public class InferenceEngine_up {
|
||||
|
||||
OrtEnvironment environment = OrtEnvironment.getEnvironment();
|
||||
OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
|
||||
|
||||
private String modelPath;
|
||||
private List<String> labels;
|
||||
|
||||
// 添加用于存储图像预处理信息的类变量
|
||||
private int origWidth;
|
||||
private int origHeight;
|
||||
private int newWidth;
|
||||
private int newHeight;
|
||||
private float scalingFactor;
|
||||
private int xOffset;
|
||||
private int yOffset;
|
||||
|
||||
public InferenceEngine_up(String modelPath, List<String> labels) {
|
||||
this.modelPath = modelPath;
|
||||
this.labels = labels;
|
||||
init();
|
||||
}
|
||||
|
||||
public void init() {
|
||||
OrtSession session = null;
|
||||
try {
|
||||
sessionOptions.addCUDA(0);
|
||||
session = environment.createSession(modelPath, sessionOptions);
|
||||
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
logModelInfo(session);
|
||||
}
|
||||
|
||||
public InferenceResult infer(float[] inputData, int w, int h) {
|
||||
// 创建ONNX输入Tensor
|
||||
try (OrtSession session = environment.createSession(modelPath, sessionOptions)) {
|
||||
Map<String, NodeInfo> inputInfo = session.getInputInfo();
|
||||
String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入
|
||||
|
||||
long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状
|
||||
OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape);
|
||||
|
||||
// 执行推理
|
||||
OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor));
|
||||
|
||||
// 解析推理结果
|
||||
String outputName = session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出
|
||||
float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状:[1, N, 5]
|
||||
|
||||
long l = System.currentTimeMillis();
|
||||
// 设定置信度阈值
|
||||
float confidenceThreshold = 0.5f; // 您可以根据需要调整
|
||||
|
||||
// 根据模型的输出结果解析边界框
|
||||
List<BoundingBox> boxes = new ArrayList<>();
|
||||
for (float[] data : outputData[0]) { // 遍历所有检测框
|
||||
float confidence = data[4];
|
||||
if (confidence >= confidenceThreshold) {
|
||||
float xCenter = data[0];
|
||||
float yCenter = data[1];
|
||||
float widthBox = data[2];
|
||||
float heightBox = data[3];
|
||||
|
||||
// 调整坐标,减去偏移并除以缩放因子
|
||||
float xCenterAdjusted = (xCenter - xOffset) / scalingFactor;
|
||||
float yCenterAdjusted = (yCenter - yOffset) / scalingFactor;
|
||||
float widthAdjusted = widthBox / scalingFactor;
|
||||
float heightAdjusted = heightBox / scalingFactor;
|
||||
|
||||
// 计算左上角坐标
|
||||
int x = (int) (xCenterAdjusted - widthAdjusted / 2);
|
||||
int y = (int) (yCenterAdjusted - heightAdjusted / 2);
|
||||
int wBox = (int) widthAdjusted;
|
||||
int hBox = (int) heightAdjusted;
|
||||
|
||||
// 确保坐标在原始图像范围内
|
||||
if (x < 0) x = 0;
|
||||
if (y < 0) y = 0;
|
||||
if (x + wBox > origWidth) wBox = origWidth - x;
|
||||
if (y + hBox > origHeight) hBox = origHeight - y;
|
||||
|
||||
String label = "person"; // 由于只有一个类别
|
||||
|
||||
boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence));
|
||||
}
|
||||
}
|
||||
|
||||
// 非极大值抑制(NMS)
|
||||
List<BoundingBox> nmsBoxes = nonMaximumSuppression(boxes, 0.5f);
|
||||
System.out.println("耗时:"+((System.currentTimeMillis() - l)));
|
||||
// 封装结果并返回
|
||||
InferenceResult inferenceResult = new InferenceResult();
|
||||
inferenceResult.setBoundingBoxes(nmsBoxes);
|
||||
return inferenceResult;
|
||||
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException("推理失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
// 计算两个边界框的 IoU
|
||||
private float computeIoU(BoundingBox box1, BoundingBox box2) {
|
||||
int x1 = Math.max(box1.getX(), box2.getX());
|
||||
int y1 = Math.max(box1.getY(), box2.getY());
|
||||
int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth());
|
||||
int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight());
|
||||
|
||||
int intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
|
||||
int box1Area = box1.getWidth() * box1.getHeight();
|
||||
int box2Area = box2.getWidth() * box2.getHeight();
|
||||
|
||||
return (float) intersectionArea / (box1Area + box2Area - intersectionArea);
|
||||
}
|
||||
|
||||
// 打印模型信息
|
||||
private void logModelInfo(OrtSession session) {
|
||||
System.out.println("模型输入信息:");
|
||||
try {
|
||||
for (Map.Entry<String, NodeInfo> entry : session.getInputInfo().entrySet()) {
|
||||
String name = entry.getKey();
|
||||
NodeInfo info = entry.getValue();
|
||||
System.out.println("输入名称: " + name);
|
||||
System.out.println("输入信息: " + info.toString());
|
||||
}
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
System.out.println("模型输出信息:");
|
||||
try {
|
||||
for (Map.Entry<String, NodeInfo> entry : session.getOutputInfo().entrySet()) {
|
||||
String name = entry.getKey();
|
||||
NodeInfo info = entry.getValue();
|
||||
System.out.println("输出名称: " + name);
|
||||
System.out.println("输出信息: " + info.toString());
|
||||
}
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public static void main(String[] args) {
|
||||
// 初始化标签列表
|
||||
List<String> labels = Arrays.asList("person");
|
||||
|
||||
// 创建 InferenceEngine 实例
|
||||
InferenceEngine_up inferenceEngine = new InferenceEngine_up("D:\\work\\work_space\\java\\onnx-inference4j-play\\src\\main\\resources\\model\\best.onnx", labels);
|
||||
|
||||
try {
|
||||
// 加载图片
|
||||
File imageFile = new File("C:\\Users\\ly\\Desktop\\resuouce\\image\\1.jpg");
|
||||
BufferedImage inputImage = ImageIO.read(imageFile);
|
||||
|
||||
// 预处理图像
|
||||
float[] inputData = inferenceEngine.preprocessImage(inputImage);
|
||||
|
||||
// 执行推理
|
||||
InferenceResult result = null;
|
||||
for (int i = 0; i < 100; i++) {
|
||||
long l = System.currentTimeMillis();
|
||||
result = inferenceEngine.infer(inputData, 640, 640);
|
||||
System.out.println(System.currentTimeMillis() - l);
|
||||
}
|
||||
|
||||
// 处理并显示结果
|
||||
System.out.println("推理结果:");
|
||||
for (BoundingBox box : result.getBoundingBoxes()) {
|
||||
System.out.println(box);
|
||||
}
|
||||
|
||||
// 可视化并保存带有边界框的图像
|
||||
BufferedImage outputImage = inferenceEngine.drawBoundingBoxes(inputImage, result.getBoundingBoxes());
|
||||
|
||||
// 保存图片到本地文件
|
||||
File outputFile = new File("output_image_with_boxes.jpg");
|
||||
ImageIO.write(outputImage, "jpg", outputFile);
|
||||
|
||||
System.out.println("已保存结果图片: " + outputFile.getAbsolutePath());
|
||||
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
// 在图像上绘制边界框和标签
|
||||
BufferedImage drawBoundingBoxes(BufferedImage image, List<BoundingBox> boxes) {
|
||||
Graphics2D g = image.createGraphics();
|
||||
g.setColor(Color.RED); // 设置绘制边界框的颜色
|
||||
g.setStroke(new BasicStroke(2)); // 设置线条粗细
|
||||
|
||||
for (BoundingBox box : boxes) {
|
||||
// 绘制矩形边界框
|
||||
g.drawRect(box.getX(), box.getY(), box.getWidth(), box.getHeight());
|
||||
// 绘制标签文字和置信度
|
||||
String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence());
|
||||
g.setFont(new Font("Arial", Font.PLAIN, 12));
|
||||
g.drawString(label, box.getX(), box.getY() - 5);
|
||||
}
|
||||
|
||||
g.dispose(); // 释放资源
|
||||
return image;
|
||||
}
|
||||
|
||||
// 图像预处理
|
||||
float[] preprocessImage(BufferedImage image) {
|
||||
int targetWidth = 640;
|
||||
int targetHeight = 640;
|
||||
|
||||
origWidth = image.getWidth();
|
||||
origHeight = image.getHeight();
|
||||
|
||||
// 计算缩放因子
|
||||
scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
|
||||
|
||||
// 计算新的图像尺寸
|
||||
newWidth = Math.round(origWidth * scalingFactor);
|
||||
newHeight = Math.round(origHeight * scalingFactor);
|
||||
|
||||
// 计算偏移量以居中图像
|
||||
xOffset = (targetWidth - newWidth) / 2;
|
||||
yOffset = (targetHeight - newHeight) / 2;
|
||||
|
||||
// 创建一个新的BufferedImage
|
||||
BufferedImage resizedImage = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_RGB);
|
||||
Graphics2D g = resizedImage.createGraphics();
|
||||
|
||||
// 填充背景为黑色
|
||||
g.setColor(Color.BLACK);
|
||||
g.fillRect(0, 0, targetWidth, targetHeight);
|
||||
|
||||
// 绘制缩放后的图像到新的图像上
|
||||
g.drawImage(image.getScaledInstance(newWidth, newHeight, Image.SCALE_SMOOTH), xOffset, yOffset, null);
|
||||
g.dispose();
|
||||
|
||||
float[] inputData = new float[3 * targetWidth * targetHeight];
|
||||
|
||||
for (int c = 0; c < 3; c++) {
|
||||
for (int y = 0; y < targetHeight; y++) {
|
||||
for (int x = 0; x < targetWidth; x++) {
|
||||
int rgb = resizedImage.getRGB(x, y);
|
||||
float value = 0f;
|
||||
if (c == 0) {
|
||||
value = ((rgb >> 16) & 0xFF) / 255.0f; // Red channel
|
||||
} else if (c == 1) {
|
||||
value = ((rgb >> 8) & 0xFF) / 255.0f; // Green channel
|
||||
} else if (c == 2) {
|
||||
value = (rgb & 0xFF) / 255.0f; // Blue channel
|
||||
}
|
||||
inputData[c * targetWidth * targetHeight + y * targetWidth + x] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return inputData;
|
||||
}
|
||||
|
||||
// 非极大值抑制(NMS)方法
|
||||
private List<BoundingBox> nonMaximumSuppression(List<BoundingBox> boxes, float iouThreshold) {
|
||||
// 按置信度排序
|
||||
boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence()));
|
||||
|
||||
List<BoundingBox> result = new ArrayList<>();
|
||||
|
||||
while (!boxes.isEmpty()) {
|
||||
BoundingBox bestBox = boxes.remove(0);
|
||||
result.add(bestBox);
|
||||
|
||||
Iterator<BoundingBox> iterator = boxes.iterator();
|
||||
while (iterator.hasNext()) {
|
||||
BoundingBox box = iterator.next();
|
||||
if (computeIoU(bestBox, box) > iouThreshold) {
|
||||
iterator.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// 其他方法保持不变...
|
||||
}
|
|
@ -1,297 +0,0 @@
|
|||
package com.ly.onnx.engine;
|
||||
|
||||
import ai.onnxruntime.*;
|
||||
import com.ly.onnx.model.BoundingBox;
|
||||
import com.ly.onnx.model.InferenceResult;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.*;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.File;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.util.List;
|
||||
import java.util.*;
|
||||
|
||||
public class InferenceEngine_up_1 {
|
||||
|
||||
OrtEnvironment environment = OrtEnvironment.getEnvironment();
|
||||
OrtSession.SessionOptions sessionOptions = null;
|
||||
OrtSession session = null;
|
||||
private String modelPath;
|
||||
private List<String> labels;
|
||||
|
||||
// 添加用于存储图像预处理信息的类变量
|
||||
private int origWidth;
|
||||
private int origHeight;
|
||||
private int newWidth;
|
||||
private int newHeight;
|
||||
private float scalingFactor;
|
||||
private int xOffset;
|
||||
private int yOffset;
|
||||
|
||||
public InferenceEngine_up_1(String modelPath, List<String> labels) {
|
||||
this.modelPath = modelPath;
|
||||
this.labels = labels;
|
||||
init();
|
||||
}
|
||||
|
||||
public void init() {
|
||||
|
||||
try {
|
||||
sessionOptions = new OrtSession.SessionOptions();
|
||||
sessionOptions.addCUDA(0);
|
||||
session = environment.createSession(modelPath, sessionOptions);
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
logModelInfo(session);
|
||||
}
|
||||
|
||||
public InferenceResult infer(float[] inputData, int w, int h) {
|
||||
// 创建ONNX输入Tensor
|
||||
try {
|
||||
Map<String, NodeInfo> inputInfo = session.getInputInfo();
|
||||
String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入
|
||||
|
||||
long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状
|
||||
OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape);
|
||||
|
||||
// 执行推理
|
||||
OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor));
|
||||
|
||||
// 解析推理结果
|
||||
String outputName = session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出
|
||||
float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状:[1, N, 5]
|
||||
|
||||
long l = System.currentTimeMillis();
|
||||
// 设定置信度阈值
|
||||
float confidenceThreshold = 0.5f; // 您可以根据需要调整
|
||||
|
||||
// 根据模型的输出结果解析边界框
|
||||
List<BoundingBox> boxes = new ArrayList<>();
|
||||
for (float[] data : outputData[0]) { // 遍历所有检测框
|
||||
float confidence = data[4];
|
||||
if (confidence >= confidenceThreshold) {
|
||||
float xCenter = data[0];
|
||||
float yCenter = data[1];
|
||||
float widthBox = data[2];
|
||||
float heightBox = data[3];
|
||||
|
||||
// 调整坐标,减去偏移并除以缩放因子
|
||||
float xCenterAdjusted = (xCenter - xOffset) / scalingFactor;
|
||||
float yCenterAdjusted = (yCenter - yOffset) / scalingFactor;
|
||||
float widthAdjusted = widthBox / scalingFactor;
|
||||
float heightAdjusted = heightBox / scalingFactor;
|
||||
|
||||
// 计算左上角坐标
|
||||
int x = (int) (xCenterAdjusted - widthAdjusted / 2);
|
||||
int y = (int) (yCenterAdjusted - heightAdjusted / 2);
|
||||
int wBox = (int) widthAdjusted;
|
||||
int hBox = (int) heightAdjusted;
|
||||
|
||||
// 确保坐标在原始图像范围内
|
||||
if (x < 0) x = 0;
|
||||
if (y < 0) y = 0;
|
||||
if (x + wBox > origWidth) wBox = origWidth - x;
|
||||
if (y + hBox > origHeight) hBox = origHeight - y;
|
||||
|
||||
String label = "person"; // 由于只有一个类别
|
||||
|
||||
boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence));
|
||||
}
|
||||
}
|
||||
|
||||
// 非极大值抑制(NMS)
|
||||
List<BoundingBox> nmsBoxes = nonMaximumSuppression(boxes, 0.5f);
|
||||
System.out.println("耗时:"+((System.currentTimeMillis() - l)));
|
||||
// 封装结果并返回
|
||||
InferenceResult inferenceResult = new InferenceResult();
|
||||
inferenceResult.setBoundingBoxes(nmsBoxes);
|
||||
return inferenceResult;
|
||||
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException("推理失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
// 计算两个边界框的 IoU
|
||||
private float computeIoU(BoundingBox box1, BoundingBox box2) {
|
||||
int x1 = Math.max(box1.getX(), box2.getX());
|
||||
int y1 = Math.max(box1.getY(), box2.getY());
|
||||
int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth());
|
||||
int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight());
|
||||
|
||||
int intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
|
||||
int box1Area = box1.getWidth() * box1.getHeight();
|
||||
int box2Area = box2.getWidth() * box2.getHeight();
|
||||
|
||||
return (float) intersectionArea / (box1Area + box2Area - intersectionArea);
|
||||
}
|
||||
|
||||
// 打印模型信息
|
||||
private void logModelInfo(OrtSession session) {
|
||||
System.out.println("模型输入信息:");
|
||||
try {
|
||||
for (Map.Entry<String, NodeInfo> entry : session.getInputInfo().entrySet()) {
|
||||
String name = entry.getKey();
|
||||
NodeInfo info = entry.getValue();
|
||||
System.out.println("输入名称: " + name);
|
||||
System.out.println("输入信息: " + info.toString());
|
||||
}
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
System.out.println("模型输出信息:");
|
||||
try {
|
||||
for (Map.Entry<String, NodeInfo> entry : session.getOutputInfo().entrySet()) {
|
||||
String name = entry.getKey();
|
||||
NodeInfo info = entry.getValue();
|
||||
System.out.println("输出名称: " + name);
|
||||
System.out.println("输出信息: " + info.toString());
|
||||
}
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public static void main(String[] args) {
|
||||
// 初始化标签列表
|
||||
List<String> labels = Arrays.asList("person");
|
||||
|
||||
// 创建 InferenceEngine 实例
|
||||
InferenceEngine_up_1 inferenceEngine = new InferenceEngine_up_1("D:\\work\\work_space\\java\\onnx-inference4j-play\\src\\main\\resources\\model\\best.onnx", labels);
|
||||
|
||||
try {
|
||||
// 加载图片
|
||||
File imageFile = new File("C:\\Users\\ly\\Desktop\\resuouce\\image\\1.jpg");
|
||||
BufferedImage inputImage = ImageIO.read(imageFile);
|
||||
|
||||
// 预处理图像
|
||||
float[] inputData = inferenceEngine.preprocessImage(inputImage);
|
||||
|
||||
// 执行推理
|
||||
InferenceResult result = null;
|
||||
for (int i = 0; i < 100; i++) {
|
||||
long l = System.currentTimeMillis();
|
||||
result = inferenceEngine.infer(inputData, 640, 640);
|
||||
System.out.println(System.currentTimeMillis() - l);
|
||||
}
|
||||
|
||||
// 处理并显示结果
|
||||
System.out.println("推理结果:");
|
||||
for (BoundingBox box : result.getBoundingBoxes()) {
|
||||
System.out.println(box);
|
||||
}
|
||||
|
||||
// 可视化并保存带有边界框的图像
|
||||
BufferedImage outputImage = inferenceEngine.drawBoundingBoxes(inputImage, result.getBoundingBoxes());
|
||||
|
||||
// 保存图片到本地文件
|
||||
File outputFile = new File("output_image_with_boxes.jpg");
|
||||
ImageIO.write(outputImage, "jpg", outputFile);
|
||||
|
||||
System.out.println("已保存结果图片: " + outputFile.getAbsolutePath());
|
||||
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
// 在图像上绘制边界框和标签
|
||||
private BufferedImage drawBoundingBoxes(BufferedImage image, List<BoundingBox> boxes) {
|
||||
Graphics2D g = image.createGraphics();
|
||||
g.setColor(Color.RED); // 设置绘制边界框的颜色
|
||||
g.setStroke(new BasicStroke(2)); // 设置线条粗细
|
||||
|
||||
for (BoundingBox box : boxes) {
|
||||
// 绘制矩形边界框
|
||||
g.drawRect(box.getX(), box.getY(), box.getWidth(), box.getHeight());
|
||||
// 绘制标签文字和置信度
|
||||
String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence());
|
||||
g.setFont(new Font("Arial", Font.PLAIN, 12));
|
||||
g.drawString(label, box.getX(), box.getY() - 5);
|
||||
}
|
||||
|
||||
g.dispose(); // 释放资源
|
||||
return image;
|
||||
}
|
||||
|
||||
// 图像预处理
|
||||
private float[] preprocessImage(BufferedImage image) {
|
||||
int targetWidth = 640;
|
||||
int targetHeight = 640;
|
||||
|
||||
origWidth = image.getWidth();
|
||||
origHeight = image.getHeight();
|
||||
|
||||
// 计算缩放因子
|
||||
scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
|
||||
|
||||
// 计算新的图像尺寸
|
||||
newWidth = Math.round(origWidth * scalingFactor);
|
||||
newHeight = Math.round(origHeight * scalingFactor);
|
||||
|
||||
// 计算偏移量以居中图像
|
||||
xOffset = (targetWidth - newWidth) / 2;
|
||||
yOffset = (targetHeight - newHeight) / 2;
|
||||
|
||||
// 创建一个新的BufferedImage
|
||||
BufferedImage resizedImage = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_RGB);
|
||||
Graphics2D g = resizedImage.createGraphics();
|
||||
|
||||
// 填充背景为黑色
|
||||
g.setColor(Color.BLACK);
|
||||
g.fillRect(0, 0, targetWidth, targetHeight);
|
||||
|
||||
// 绘制缩放后的图像到新的图像上
|
||||
g.drawImage(image.getScaledInstance(newWidth, newHeight, Image.SCALE_SMOOTH), xOffset, yOffset, null);
|
||||
g.dispose();
|
||||
|
||||
float[] inputData = new float[3 * targetWidth * targetHeight];
|
||||
|
||||
for (int c = 0; c < 3; c++) {
|
||||
for (int y = 0; y < targetHeight; y++) {
|
||||
for (int x = 0; x < targetWidth; x++) {
|
||||
int rgb = resizedImage.getRGB(x, y);
|
||||
float value = 0f;
|
||||
if (c == 0) {
|
||||
value = ((rgb >> 16) & 0xFF) / 255.0f; // Red channel
|
||||
} else if (c == 1) {
|
||||
value = ((rgb >> 8) & 0xFF) / 255.0f; // Green channel
|
||||
} else if (c == 2) {
|
||||
value = (rgb & 0xFF) / 255.0f; // Blue channel
|
||||
}
|
||||
inputData[c * targetWidth * targetHeight + y * targetWidth + x] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return inputData;
|
||||
}
|
||||
|
||||
// 非极大值抑制(NMS)方法
|
||||
private List<BoundingBox> nonMaximumSuppression(List<BoundingBox> boxes, float iouThreshold) {
|
||||
// 按置信度排序
|
||||
boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence()));
|
||||
|
||||
List<BoundingBox> result = new ArrayList<>();
|
||||
|
||||
while (!boxes.isEmpty()) {
|
||||
BoundingBox bestBox = boxes.remove(0);
|
||||
result.add(bestBox);
|
||||
|
||||
Iterator<BoundingBox> iterator = boxes.iterator();
|
||||
while (iterator.hasNext()) {
|
||||
BoundingBox box = iterator.next();
|
||||
if (computeIoU(bestBox, box) > iouThreshold) {
|
||||
iterator.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// 其他方法保持不变...
|
||||
}
|
|
@ -1,314 +0,0 @@
|
|||
package com.ly.onnx.engine;
|
||||
|
||||
import ai.onnxruntime.*;
|
||||
import com.ly.onnx.model.BoundingBox;
|
||||
import com.ly.onnx.model.InferenceResult;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.*;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.File;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.util.List;
|
||||
import java.util.*;
|
||||
|
||||
public class InferenceEngine_up_2 {
|
||||
|
||||
private OrtEnvironment environment;
|
||||
private OrtSession.SessionOptions sessionOptions;
|
||||
private OrtSession session; // 将 session 作为类的成员变量
|
||||
|
||||
private String modelPath;
|
||||
private List<String> labels;
|
||||
|
||||
// 添加用于存储图像预处理信息的类变量
|
||||
private int origWidth;
|
||||
private int origHeight;
|
||||
private int newWidth;
|
||||
private int newHeight;
|
||||
private float scalingFactor;
|
||||
private int xOffset;
|
||||
private int yOffset;
|
||||
|
||||
public InferenceEngine_up_2(String modelPath, List<String> labels) {
|
||||
this.modelPath = modelPath;
|
||||
this.labels = labels;
|
||||
init();
|
||||
}
|
||||
|
||||
public void init() {
|
||||
try {
|
||||
environment = OrtEnvironment.getEnvironment();
|
||||
sessionOptions = new OrtSession.SessionOptions();
|
||||
sessionOptions.addCUDA(0); // 使用 GPU
|
||||
session = environment.createSession(modelPath, sessionOptions);
|
||||
logModelInfo(session);
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException("模型加载失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
public InferenceResult infer(float[] inputData, int w, int h) {
|
||||
long startTime = System.currentTimeMillis();
|
||||
|
||||
try {
|
||||
Map<String, NodeInfo> inputInfo = session.getInputInfo();
|
||||
String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入
|
||||
|
||||
long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状
|
||||
OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape);
|
||||
|
||||
// 执行推理
|
||||
long inferenceStart = System.currentTimeMillis();
|
||||
OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor));
|
||||
long inferenceEnd = System.currentTimeMillis();
|
||||
System.out.println("模型推理耗时:" + (inferenceEnd - inferenceStart) + " ms");
|
||||
|
||||
// 解析推理结果
|
||||
String outputName = session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出
|
||||
float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状:[1, N, 5]
|
||||
|
||||
// 设定置信度阈值
|
||||
float confidenceThreshold = 0.5f; // 您可以根据需要调整
|
||||
|
||||
// 根据模型的输出结果解析边界框
|
||||
List<BoundingBox> boxes = new ArrayList<>();
|
||||
for (float[] data : outputData[0]) { // 遍历所有检测框
|
||||
float confidence = data[4];
|
||||
if (confidence >= confidenceThreshold) {
|
||||
float xCenter = data[0];
|
||||
float yCenter = data[1];
|
||||
float widthBox = data[2];
|
||||
float heightBox = data[3];
|
||||
|
||||
// 调整坐标,减去偏移并除以缩放因子
|
||||
float xCenterAdjusted = (xCenter - xOffset) / scalingFactor;
|
||||
float yCenterAdjusted = (yCenter - yOffset) / scalingFactor;
|
||||
float widthAdjusted = widthBox / scalingFactor;
|
||||
float heightAdjusted = heightBox / scalingFactor;
|
||||
|
||||
// 计算左上角坐标
|
||||
int x = (int) (xCenterAdjusted - widthAdjusted / 2);
|
||||
int y = (int) (yCenterAdjusted - heightAdjusted / 2);
|
||||
int wBox = (int) widthAdjusted;
|
||||
int hBox = (int) heightAdjusted;
|
||||
|
||||
// 确保坐标在原始图像范围内
|
||||
if (x < 0) x = 0;
|
||||
if (y < 0) y = 0;
|
||||
if (x + wBox > origWidth) wBox = origWidth - x;
|
||||
if (y + hBox > origHeight) hBox = origHeight - y;
|
||||
|
||||
String label = "person"; // 由于只有一个类别
|
||||
|
||||
boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence));
|
||||
}
|
||||
}
|
||||
|
||||
// 非极大值抑制(NMS)
|
||||
long nmsStart = System.currentTimeMillis();
|
||||
List<BoundingBox> nmsBoxes = nonMaximumSuppression(boxes, 0.5f);
|
||||
long nmsEnd = System.currentTimeMillis();
|
||||
System.out.println("NMS 耗时:" + (nmsEnd - nmsStart) + " ms");
|
||||
|
||||
// 封装结果并返回
|
||||
InferenceResult inferenceResult = new InferenceResult();
|
||||
inferenceResult.setBoundingBoxes(nmsBoxes);
|
||||
|
||||
long endTime = System.currentTimeMillis();
|
||||
System.out.println("一次推理总耗时:" + (endTime - startTime) + " ms");
|
||||
|
||||
return inferenceResult;
|
||||
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException("推理失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
// 计算两个边界框的 IoU
|
||||
private float computeIoU(BoundingBox box1, BoundingBox box2) {
|
||||
int x1 = Math.max(box1.getX(), box2.getX());
|
||||
int y1 = Math.max(box1.getY(), box2.getY());
|
||||
int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth());
|
||||
int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight());
|
||||
|
||||
int intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
|
||||
int box1Area = box1.getWidth() * box1.getHeight();
|
||||
int box2Area = box2.getWidth() * box2.getHeight();
|
||||
|
||||
return (float) intersectionArea / (box1Area + box2Area - intersectionArea);
|
||||
}
|
||||
|
||||
// 打印模型信息
|
||||
private void logModelInfo(OrtSession session) {
|
||||
System.out.println("模型输入信息:");
|
||||
try {
|
||||
for (Map.Entry<String, NodeInfo> entry : session.getInputInfo().entrySet()) {
|
||||
String name = entry.getKey();
|
||||
NodeInfo info = entry.getValue();
|
||||
System.out.println("输入名称: " + name);
|
||||
System.out.println("输入信息: " + info.toString());
|
||||
}
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
System.out.println("模型输出信息:");
|
||||
try {
|
||||
for (Map.Entry<String, NodeInfo> entry : session.getOutputInfo().entrySet()) {
|
||||
String name = entry.getKey();
|
||||
NodeInfo info = entry.getValue();
|
||||
System.out.println("输出名称: " + name);
|
||||
System.out.println("输出信息: " + info.toString());
|
||||
}
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public static void main(String[] args) {
|
||||
// 初始化标签列表
|
||||
List<String> labels = Arrays.asList("person");
|
||||
|
||||
// 创建 InferenceEngine 实例
|
||||
InferenceEngine_up_2 inferenceEngine = new InferenceEngine_up_2("D:\\work\\work_space\\java\\onnx-inference4j-play\\src\\main\\resources\\model\\best.onnx", labels);
|
||||
|
||||
try {
|
||||
// 加载图片
|
||||
File imageFile = new File("C:\\Users\\ly\\Desktop\\resuouce\\image\\1.jpg");
|
||||
BufferedImage inputImage = ImageIO.read(imageFile);
|
||||
|
||||
// 预处理图像
|
||||
long l1 = System.currentTimeMillis();
|
||||
float[] inputData = inferenceEngine.preprocessImage(inputImage);
|
||||
System.out.println("转"+(System.currentTimeMillis() - l1));
|
||||
// 执行推理
|
||||
InferenceResult result = null;
|
||||
for (int i = 0; i < 10; i++) {
|
||||
long l = System.currentTimeMillis();
|
||||
result = inferenceEngine.infer(inputData, 640, 640);
|
||||
System.out.println("第 " + (i + 1) + " 次推理耗时:" + (System.currentTimeMillis() - l) + " ms");
|
||||
}
|
||||
|
||||
// 处理并显示结果
|
||||
System.out.println("推理结果:");
|
||||
for (BoundingBox box : result.getBoundingBoxes()) {
|
||||
System.out.println(box);
|
||||
}
|
||||
|
||||
// 可视化并保存带有边界框的图像
|
||||
BufferedImage outputImage = inferenceEngine.drawBoundingBoxes(inputImage, result.getBoundingBoxes());
|
||||
|
||||
// 保存图片到本地文件
|
||||
File outputFile = new File("output_image_with_boxes.jpg");
|
||||
ImageIO.write(outputImage, "jpg", outputFile);
|
||||
|
||||
System.out.println("已保存结果图片: " + outputFile.getAbsolutePath());
|
||||
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
// 在图像上绘制边界框和标签
|
||||
private BufferedImage drawBoundingBoxes(BufferedImage image, List<BoundingBox> boxes) {
|
||||
Graphics2D g = image.createGraphics();
|
||||
g.setColor(Color.RED); // 设置绘制边界框的颜色
|
||||
g.setStroke(new BasicStroke(2)); // 设置线条粗细
|
||||
|
||||
for (BoundingBox box : boxes) {
|
||||
// 绘制矩形边界框
|
||||
g.drawRect(box.getX(), box.getY(), box.getWidth(), box.getHeight());
|
||||
// 绘制标签文字和置信度
|
||||
String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence());
|
||||
g.setFont(new Font("Arial", Font.PLAIN, 12));
|
||||
g.drawString(label, box.getX(), box.getY() - 5);
|
||||
}
|
||||
|
||||
g.dispose(); // 释放资源
|
||||
return image;
|
||||
}
|
||||
|
||||
// 图像预处理
|
||||
public float[] preprocessImage(BufferedImage image) {
|
||||
int targetWidth = 640;
|
||||
int targetHeight = 640;
|
||||
|
||||
origWidth = image.getWidth();
|
||||
origHeight = image.getHeight();
|
||||
|
||||
// 计算缩放因子
|
||||
scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
|
||||
|
||||
// 计算新的图像尺寸
|
||||
newWidth = Math.round(origWidth * scalingFactor);
|
||||
newHeight = Math.round(origHeight * scalingFactor);
|
||||
|
||||
// 计算偏移量以居中图像
|
||||
xOffset = (targetWidth - newWidth) / 2;
|
||||
yOffset = (targetHeight - newHeight) / 2;
|
||||
|
||||
// 创建一个新的BufferedImage
|
||||
BufferedImage resizedImage = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_RGB);
|
||||
Graphics2D g = resizedImage.createGraphics();
|
||||
|
||||
// 填充背景为黑色
|
||||
g.setColor(Color.BLACK);
|
||||
g.fillRect(0, 0, targetWidth, targetHeight);
|
||||
|
||||
// 绘制缩放后的图像到新的图像上
|
||||
g.drawImage(image.getScaledInstance(newWidth, newHeight, Image.SCALE_SMOOTH), xOffset, yOffset, null);
|
||||
g.dispose();
|
||||
|
||||
float[] inputData = new float[3 * targetWidth * targetHeight];
|
||||
|
||||
// 开始计时
|
||||
long preprocessStart = System.currentTimeMillis();
|
||||
|
||||
for (int c = 0; c < 3; c++) {
|
||||
for (int y = 0; y < targetHeight; y++) {
|
||||
for (int x = 0; x < targetWidth; x++) {
|
||||
int rgb = resizedImage.getRGB(x, y);
|
||||
float value = 0f;
|
||||
if (c == 0) {
|
||||
value = ((rgb >> 16) & 0xFF) / 255.0f; // Red channel
|
||||
} else if (c == 1) {
|
||||
value = ((rgb >> 8) & 0xFF) / 255.0f; // Green channel
|
||||
} else if (c == 2) {
|
||||
value = (rgb & 0xFF) / 255.0f; // Blue channel
|
||||
}
|
||||
inputData[c * targetWidth * targetHeight + y * targetWidth + x] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 结束计时
|
||||
long preprocessEnd = System.currentTimeMillis();
|
||||
System.out.println("图像预处理耗时:" + (preprocessEnd - preprocessStart) + " ms");
|
||||
|
||||
return inputData;
|
||||
}
|
||||
|
||||
// 非极大值抑制(NMS)方法
|
||||
private List<BoundingBox> nonMaximumSuppression(List<BoundingBox> boxes, float iouThreshold) {
|
||||
// 按置信度排序(从高到低)
|
||||
boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence()));
|
||||
|
||||
List<BoundingBox> result = new ArrayList<>();
|
||||
|
||||
while (!boxes.isEmpty()) {
|
||||
BoundingBox bestBox = boxes.remove(0);
|
||||
result.add(bestBox);
|
||||
|
||||
Iterator<BoundingBox> iterator = boxes.iterator();
|
||||
while (iterator.hasNext()) {
|
||||
BoundingBox box = iterator.next();
|
||||
if (computeIoU(bestBox, box) > iouThreshold) {
|
||||
iterator.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
|
@ -10,6 +10,7 @@ public class BoundingBox {
|
|||
private int height;
|
||||
private String label;
|
||||
private float confidence;
|
||||
private long trackId;
|
||||
|
||||
// 构造函数、getter 和 setter 方法
|
||||
|
||||
|
@ -22,6 +23,5 @@ public class BoundingBox {
|
|||
this.confidence = confidence;
|
||||
}
|
||||
|
||||
// Getter 和 Setter 方法
|
||||
// ...
|
||||
|
||||
}
|
||||
|
|
|
@ -13,53 +13,164 @@ import java.util.List;
|
|||
|
||||
public class DrawImagesUtils {
|
||||
|
||||
// 使用HSL颜色生成更高级的颜色
|
||||
public static Color hslToRgb(float hue, float saturation, float lightness) {
|
||||
float c = (1 - Math.abs(2 * lightness - 1)) * saturation;
|
||||
float x = c * (1 - Math.abs((hue / 60) % 2 - 1));
|
||||
float m = lightness - c / 2;
|
||||
float r = 0, g = 0, b = 0;
|
||||
|
||||
if (0 <= hue && hue < 60) {
|
||||
r = c;
|
||||
g = x;
|
||||
} else if (60 <= hue && hue < 120) {
|
||||
r = x;
|
||||
g = c;
|
||||
} else if (120 <= hue && hue < 180) {
|
||||
g = c;
|
||||
b = x;
|
||||
} else if (180 <= hue && hue < 240) {
|
||||
g = x;
|
||||
b = c;
|
||||
} else if (240 <= hue && hue < 300) {
|
||||
r = x;
|
||||
b = c;
|
||||
} else if (300 <= hue && hue < 360) {
|
||||
r = c;
|
||||
b = x;
|
||||
}
|
||||
|
||||
int rVal = (int) ((r + m) * 255);
|
||||
int gVal = (int) ((g + m) * 255);
|
||||
int bVal = (int) ((b + m) * 255);
|
||||
return new Color(rVal, gVal, bVal);
|
||||
}
|
||||
|
||||
// 根据模型索引生成颜色
|
||||
private static Color generateColorForModel(int modelIndex, int totalModels) {
|
||||
float hue = (360.0f / totalModels) * modelIndex; // 根据模型索引设置色相
|
||||
return hslToRgb(hue, 0.7f, 0.5f); // 饱和度0.7,亮度0.5
|
||||
}
|
||||
|
||||
// 在 BufferedImage 上绘制推理结果
|
||||
public static void drawInferenceResult(BufferedImage bufferedImage, List<InferenceResult> inferenceResults) {
|
||||
Graphics2D g2d = bufferedImage.createGraphics();
|
||||
g2d.setFont(new Font("Arial", Font.PLAIN, 12));
|
||||
g2d.setFont(new Font("Arial", Font.PLAIN, 24)); // 设置字体大小为24
|
||||
|
||||
int modelIndex = 0; // 模型索引
|
||||
int totalModels = inferenceResults.size(); // 总模型数
|
||||
|
||||
for (InferenceResult result : inferenceResults) {
|
||||
Color modelColor = generateColorForModel(modelIndex++, totalModels); // 为每个模型生成独立颜色
|
||||
|
||||
for (BoundingBox box : result.getBoundingBoxes()) {
|
||||
// 绘制矩形
|
||||
g2d.setColor(Color.RED);
|
||||
// 绘制矩形框
|
||||
g2d.setColor(modelColor);
|
||||
g2d.setStroke(new BasicStroke(4)); // 设置线条粗细
|
||||
g2d.drawRect(box.getX(), box.getY(), box.getWidth(), box.getHeight());
|
||||
|
||||
// 绘制标签
|
||||
String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence());
|
||||
// 获取字体度量
|
||||
FontMetrics metrics = g2d.getFontMetrics();
|
||||
int labelWidth = metrics.stringWidth(label);
|
||||
int labelHeight = metrics.getHeight();
|
||||
int labelHeight = metrics.getHeight() + 4; // 标签高度
|
||||
String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence());
|
||||
int labelWidth = metrics.stringWidth(label) + 10; // 标签宽度
|
||||
|
||||
// 确保文字不会超出图像
|
||||
int y = Math.max(box.getY(), labelHeight);
|
||||
String trackIdLabel = "TrackID: " + box.getTrackId();
|
||||
int trackIdWidth = metrics.stringWidth(trackIdLabel) + 10; // TrackID标签宽度
|
||||
int trackIdHeight = metrics.getHeight() + 4; // TrackID标签高度
|
||||
|
||||
// 绘制文字背景
|
||||
g2d.setColor(Color.RED);
|
||||
g2d.fillRect(box.getX(), y - labelHeight, labelWidth, labelHeight);
|
||||
// 计算标签总高度
|
||||
int totalLabelHeight = (box.getTrackId() > 0 ? trackIdHeight : 0) + labelHeight;
|
||||
|
||||
// 绘制文字
|
||||
g2d.setColor(Color.WHITE);
|
||||
g2d.drawString(label, box.getX(), y);
|
||||
// 边距
|
||||
int margin = 10;
|
||||
|
||||
// 检查上方是否有足够空间绘制标签
|
||||
boolean canDrawAbove = box.getY() >= totalLabelHeight + margin;
|
||||
|
||||
if (canDrawAbove) {
|
||||
// 在检测框上方绘制标签
|
||||
int currentY = box.getY() - totalLabelHeight;
|
||||
|
||||
// 绘制 TrackID(如果有)
|
||||
if (box.getTrackId() > 0) {
|
||||
// 绘制 TrackID 背景
|
||||
g2d.setColor(modelColor);
|
||||
g2d.fillRect(box.getX(), currentY, trackIdWidth, trackIdHeight);
|
||||
|
||||
// 绘制 TrackID 文字
|
||||
g2d.setColor(Color.BLACK);
|
||||
g2d.drawString(trackIdLabel, box.getX() + 5, currentY + metrics.getAscent());
|
||||
|
||||
currentY += trackIdHeight;
|
||||
}
|
||||
|
||||
// 绘制 classid 背景
|
||||
g2d.setColor(modelColor);
|
||||
g2d.fillRect(box.getX(), currentY, labelWidth, labelHeight);
|
||||
|
||||
// 绘制 classid 文字
|
||||
g2d.setColor(Color.BLACK);
|
||||
g2d.drawString(label, box.getX() + 5, currentY + metrics.getAscent());
|
||||
} else {
|
||||
// 如果上方空间不足,则在检测框内部顶部绘制标签
|
||||
int currentY = box.getY() + 5; // 内边距5
|
||||
|
||||
// 绘制半透明背景以提高可读性
|
||||
int bgAlpha = 200; // 透明度(0-255)
|
||||
Color backgroundColor = new Color(modelColor.getRed(), modelColor.getGreen(), modelColor.getBlue(), bgAlpha);
|
||||
|
||||
if (box.getTrackId() > 0) {
|
||||
// 绘制 TrackID 背景
|
||||
g2d.setColor(backgroundColor);
|
||||
g2d.fillRect(box.getX(), currentY, trackIdWidth, trackIdHeight);
|
||||
|
||||
// 绘制 TrackID 文字
|
||||
g2d.setColor(Color.BLACK);
|
||||
g2d.drawString(trackIdLabel, box.getX() + 5, currentY + metrics.getAscent());
|
||||
|
||||
currentY += trackIdHeight;
|
||||
}
|
||||
|
||||
// 绘制 classid 背景
|
||||
g2d.setColor(backgroundColor);
|
||||
g2d.fillRect(box.getX(), currentY, labelWidth, labelHeight);
|
||||
|
||||
// 绘制 classid 文字
|
||||
g2d.setColor(Color.BLACK);
|
||||
g2d.drawString(label, box.getX() + 5, currentY + metrics.getAscent());
|
||||
}
|
||||
}
|
||||
}
|
||||
g2d.dispose(); // 释放资源
|
||||
}
|
||||
|
||||
|
||||
// 在 Mat 上绘制推理结果
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
// 在 Mat 上绘制推理结果 (OpenCV 版本)
|
||||
public static void drawInferenceResult(Mat image, List<InferenceResult> inferenceResults) {
|
||||
int modelIndex = 0;
|
||||
int totalModels = inferenceResults.size();
|
||||
|
||||
for (InferenceResult result : inferenceResults) {
|
||||
Scalar modelColor = convertColorToScalar(generateColorForModel(modelIndex++, totalModels));
|
||||
|
||||
for (BoundingBox box : result.getBoundingBoxes()) {
|
||||
// 绘制矩形
|
||||
Point topLeft = new Point(box.getX(), box.getY());
|
||||
Point bottomRight = new Point(box.getX() + box.getWidth(), box.getY() + box.getHeight());
|
||||
Imgproc.rectangle(image, topLeft, bottomRight, new Scalar(0, 0, 255), 2); // 红色边框
|
||||
Imgproc.rectangle(image, topLeft, bottomRight, modelColor, 3); // 加粗边框
|
||||
|
||||
// 绘制标签
|
||||
String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence());
|
||||
int font = Imgproc.FONT_HERSHEY_SIMPLEX;
|
||||
double fontScale = 0.5;
|
||||
int thickness = 1;
|
||||
double fontScale = 0.7;
|
||||
int thickness = 2;
|
||||
|
||||
// 计算文字大小
|
||||
int[] baseLine = new int[1];
|
||||
|
@ -71,12 +182,17 @@ public class DrawImagesUtils {
|
|||
// 绘制文字背景
|
||||
Imgproc.rectangle(image, new Point(topLeft.x, y - labelSize.height),
|
||||
new Point(topLeft.x + labelSize.width, y + baseLine[0]),
|
||||
new Scalar(0, 0, 255), Imgproc.FILLED);
|
||||
modelColor, Imgproc.FILLED);
|
||||
|
||||
// 绘制文字
|
||||
// 绘制黑色文字
|
||||
Imgproc.putText(image, label, new Point(topLeft.x, y),
|
||||
font, fontScale, new Scalar(255, 255, 255), thickness);
|
||||
font, fontScale, new Scalar(0, 0, 0), thickness); // 黑色文字
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 将 Color 转为 Scalar (用于 OpenCV)
|
||||
private static Scalar convertColorToScalar(Color color) {
|
||||
return new Scalar(color.getBlue(), color.getGreen(), color.getRed()); // OpenCV 中颜色顺序是 BGR
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,9 +3,12 @@ package com.ly.play.opencv;
|
|||
import com.ly.layout.VideoPanel;
|
||||
import com.ly.model_load.ModelManager;
|
||||
import com.ly.onnx.engine.InferenceEngine;
|
||||
import com.ly.onnx.model.BoundingBox;
|
||||
import com.ly.onnx.model.InferenceResult;
|
||||
import com.ly.onnx.utils.DrawImagesUtils;
|
||||
import com.ly.track.SimpleTracker;
|
||||
import org.opencv.core.*;
|
||||
import org.opencv.imgcodecs.Imgcodecs;
|
||||
import org.opencv.imgproc.Imgproc;
|
||||
import org.opencv.videoio.VideoCapture;
|
||||
import org.opencv.videoio.Videoio;
|
||||
|
@ -37,9 +40,13 @@ public class VideoPlayer {
|
|||
private Thread inferenceThread;
|
||||
private VideoPanel videoPanel;
|
||||
|
||||
// 创建简单的跟踪器
|
||||
SimpleTracker tracker = new SimpleTracker();
|
||||
|
||||
private long videoDuration = 0; // 毫秒
|
||||
private long currentTimestamp = 0; // 毫秒
|
||||
|
||||
private boolean isTrackingEnabled;
|
||||
|
||||
private ModelManager modelManager;
|
||||
private List<InferenceEngine> inferenceEngines = new ArrayList<>();
|
||||
|
@ -178,8 +185,20 @@ public class VideoPlayer {
|
|||
List<InferenceResult> inferenceResults = new ArrayList<>();
|
||||
for (InferenceEngine inferenceEngine : inferenceEngines) {
|
||||
// 假设 InferenceEngine 有 infer 方法接受 float 数组
|
||||
inferenceResults.add(inferenceEngine.infer(floatObjectMap));
|
||||
InferenceResult infer = inferenceEngine.infer(floatObjectMap);
|
||||
inferenceResults.add(infer);
|
||||
}
|
||||
|
||||
// 合并所有模型的推理结果
|
||||
List<BoundingBox> allBoundingBoxes = new ArrayList<>();
|
||||
for (InferenceResult result : inferenceResults) {
|
||||
allBoundingBoxes.addAll(result.getBoundingBoxes());
|
||||
}
|
||||
// 如果启用了目标跟踪,则更新边界框并分配 trackId
|
||||
if (isTrackingEnabled) {
|
||||
tracker.update(allBoundingBoxes);
|
||||
}
|
||||
|
||||
// 绘制推理结果
|
||||
DrawImagesUtils.drawInferenceResult(bufferedImage, inferenceResults);
|
||||
// 更新绘制后图像
|
||||
|
@ -201,20 +220,22 @@ public class VideoPlayer {
|
|||
isPaused = true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
// 设置是否启用目标跟踪
|
||||
public void setTrackingEnabled(boolean enabled) {
|
||||
this.isTrackingEnabled = enabled;
|
||||
}
|
||||
|
||||
// 定义一个内部类来存储帧数据
|
||||
private static class FrameData {
|
||||
public BufferedImage image;
|
||||
public Map<Integer, Object> floatObjectMap;
|
||||
|
||||
public FrameData(BufferedImage image, Map<Integer, Object> floatObjectMap) {
|
||||
this.image = image;
|
||||
this.floatObjectMap = floatObjectMap;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 可选的预处理方法
|
||||
public Map<Integer, Object> preprocessImage(Mat image) {
|
||||
int origWidth = image.width();
|
||||
|
@ -246,9 +267,8 @@ public class VideoPlayer {
|
|||
inferenceEngine.setIndex(index.get());
|
||||
}
|
||||
continue;
|
||||
}else {
|
||||
} else {
|
||||
index.getAndIncrement();
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -330,6 +350,7 @@ public class VideoPlayer {
|
|||
|
||||
return dynamicInput;
|
||||
}
|
||||
|
||||
public List<InferenceEngine> getInferenceEngines() {
|
||||
return this.inferenceEngines;
|
||||
}
|
||||
|
@ -359,4 +380,45 @@ public class VideoPlayer {
|
|||
this.inferenceEngines.add(inferenceEngine);
|
||||
}
|
||||
|
||||
// 加载并处理图片
|
||||
public void loadImage(String imagePath) throws Exception {
|
||||
// 停止任何正在播放的视频
|
||||
stopVideo();
|
||||
|
||||
// 读取图片
|
||||
Mat image = Imgcodecs.imread(imagePath);
|
||||
if (image.empty()) {
|
||||
throw new Exception("无法读取图片文件:" + imagePath);
|
||||
}
|
||||
|
||||
// 转换为 BufferedImage
|
||||
BufferedImage bufferedImage = matToBufferedImage(image);
|
||||
|
||||
// 预处理图片
|
||||
Map<Integer, Object> preprocessedData = preprocessImage(image);
|
||||
|
||||
// 执行推理
|
||||
List<InferenceResult> inferenceResults = new ArrayList<>();
|
||||
for (InferenceEngine inferenceEngine : inferenceEngines) {
|
||||
InferenceResult infer = inferenceEngine.infer(preprocessedData);
|
||||
inferenceResults.add(infer);
|
||||
}
|
||||
|
||||
// 合并所有模型的推理结果
|
||||
List<BoundingBox> allBoundingBoxes = new ArrayList<>();
|
||||
for (InferenceResult result : inferenceResults) {
|
||||
allBoundingBoxes.addAll(result.getBoundingBoxes());
|
||||
}
|
||||
|
||||
// 如果启用了目标跟踪,则更新边界框并分配 trackId
|
||||
if (isTrackingEnabled) {
|
||||
tracker.update(allBoundingBoxes);
|
||||
}
|
||||
|
||||
// 绘制推理结果
|
||||
DrawImagesUtils.drawInferenceResult(bufferedImage, inferenceResults);
|
||||
|
||||
// 在 VideoPanel 上显示图片
|
||||
videoPanel.updateImage(bufferedImage);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
package com.ly.track;
|
||||
|
||||
import com.ly.onnx.model.BoundingBox;
|
||||
import lombok.Data;
|
||||
|
||||
import java.awt.*;
|
||||
import java.util.*;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
public class SimpleTracker {
|
||||
private Map<Long, TrackedObject> trackedObjects = new HashMap<>(); // 使用自定义 TrackedObject 来跟踪
|
||||
private long currentTrackId = 0;
|
||||
|
||||
// 跟踪器更新方法
|
||||
public List<BoundingBox> update(List<BoundingBox> detections) {
|
||||
List<BoundingBox> updatedResults = new ArrayList<>();
|
||||
|
||||
for (BoundingBox detection : detections) {
|
||||
boolean matched = false;
|
||||
Point detectionCenter = getCenter(detection); // 获取当前检测目标的中心点
|
||||
|
||||
// 遍历现有的跟踪目标
|
||||
for (Map.Entry<Long, TrackedObject> entry : trackedObjects.entrySet()) {
|
||||
TrackedObject trackedObject = entry.getValue();
|
||||
Point trackedCenter = getCenter(trackedObject.boundingBox);
|
||||
|
||||
// 使用中心点欧几里得距离进行匹配
|
||||
double distance = euclideanDistance(detectionCenter, trackedCenter);
|
||||
|
||||
// 如果距离小于某个阈值,认为是同一目标
|
||||
if (distance < 50.0) { // 自定义距离阈值,可以根据需要调整
|
||||
detection.setTrackId(entry.getKey()); // 更新检测框的 trackId
|
||||
trackedObject.update(detection); // 更新跟踪对象
|
||||
matched = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有匹配到,创建新的 trackId
|
||||
if (!matched) {
|
||||
long newTrackId = ++currentTrackId;
|
||||
detection.setTrackId(newTrackId);
|
||||
trackedObjects.put(newTrackId, new TrackedObject(detection));
|
||||
}
|
||||
|
||||
updatedResults.add(detection);
|
||||
}
|
||||
|
||||
// 清理丢失的目标
|
||||
cleanupLostObjects();
|
||||
|
||||
return updatedResults;
|
||||
}
|
||||
|
||||
// 计算目标的中心点
|
||||
private Point getCenter(BoundingBox box) {
|
||||
int centerX = box.getX() + box.getWidth() / 2;
|
||||
int centerY = box.getY() + box.getHeight() / 2;
|
||||
return new Point(centerX, centerY);
|
||||
}
|
||||
|
||||
// 计算欧几里得距离
|
||||
private double euclideanDistance(Point p1, Point p2) {
|
||||
return Math.sqrt(Math.pow(p1.x - p2.x, 2) + Math.pow(p1.y - p2.y, 2));
|
||||
}
|
||||
|
||||
// 清理丢失的跟踪对象(例如不再检测到的对象)
|
||||
private void cleanupLostObjects() {
|
||||
// 可以根据时间戳或其他条件来清理长时间没有更新的目标
|
||||
trackedObjects.entrySet().removeIf(entry -> entry.getValue().isLost());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
package com.ly.track;
|
||||
|
||||
import com.ly.onnx.model.BoundingBox;
|
||||
|
||||
public class TrackedObject {
|
||||
public BoundingBox boundingBox;
|
||||
private int lostFrames = 0; // 记录连续多少帧未检测到
|
||||
|
||||
public TrackedObject(BoundingBox initialBox) {
|
||||
this.boundingBox = initialBox;
|
||||
}
|
||||
|
||||
// 更新跟踪目标的位置
|
||||
public void update(BoundingBox newBox) {
|
||||
this.boundingBox = newBox;
|
||||
lostFrames = 0; // 重置丢失计数
|
||||
}
|
||||
|
||||
// 如果目标连续丢失多帧,认为目标丢失
|
||||
public boolean isLost() {
|
||||
return lostFrames++ > 10; // 如果丢失超过10帧,就认为目标丢失
|
||||
}
|
||||
}
|
|
@ -1,28 +0,0 @@
|
|||
package com.ly.utils;
|
||||
|
||||
import org.opencv.core.Core;
|
||||
import org.opencv.core.Mat;
|
||||
import org.opencv.videoio.VideoCapture;
|
||||
|
||||
public class OpenCVTest {
|
||||
// static {
|
||||
// nu.pattern.OpenCV.loadLocally();
|
||||
// }
|
||||
|
||||
public static void main(String[] args) {
|
||||
VideoCapture capture = new VideoCapture(0); // 打开默认摄像头
|
||||
if (!capture.isOpened()) {
|
||||
System.out.println("无法打开摄像头");
|
||||
return;
|
||||
}
|
||||
|
||||
Mat frame = new Mat();
|
||||
if (capture.read(frame)) {
|
||||
System.out.println("成功读取一帧图像");
|
||||
} else {
|
||||
System.out.println("无法读取图像");
|
||||
}
|
||||
|
||||
capture.release();
|
||||
}
|
||||
}
|
Binary file not shown.
Loading…
Reference in New Issue