diff --git a/pom.xml b/pom.xml
index cb9742a..ede1555 100644
--- a/pom.xml
+++ b/pom.xml
@@ -22,7 +22,7 @@
        <groupId>com.microsoft.onnxruntime</groupId>
        <artifactId>onnxruntime_gpu</artifactId>
-       <version>1.17.0</version>
+       <version>1.16.0</version>
org.bytedeco
@@ -40,6 +40,21 @@
        <artifactId>ffmpeg-platform</artifactId>
        <version>5.0-1.5.7</version>
+       <dependency>
+           <groupId>org.openpnp</groupId>
+           <artifactId>opencv</artifactId>
+           <version>4.7.0-0</version>
+       </dependency>
+       <dependency>
+           <groupId>org.projectlombok</groupId>
+           <artifactId>lombok</artifactId>
+           <version>1.18.32</version>
+       </dependency>
+       <dependency>
+           <groupId>com.alibaba</groupId>
+           <artifactId>fastjson</artifactId>
+           <version>1.2.83</version>
+       </dependency>
diff --git a/src/main/java/com/ly/VideoInferenceApp.java b/src/main/java/com/ly/VideoInferenceApp.java
index d693bc9..da9e96b 100644
--- a/src/main/java/com/ly/VideoInferenceApp.java
+++ b/src/main/java/com/ly/VideoInferenceApp.java
@@ -3,13 +3,18 @@ package com.ly;
import com.formdev.flatlaf.FlatLightLaf;
import com.ly.layout.VideoPanel;
import com.ly.model_load.ModelManager;
-import com.ly.play.VideoPlayer;
+import com.ly.onnx.engine.InferenceEngine;
+import com.ly.onnx.model.ModelInfo;
+import com.ly.play.opencv.VideoPlayer;
+
import javax.swing.*;
import javax.swing.filechooser.FileNameExtensionFilter;
import javax.swing.filechooser.FileSystemView;
import java.awt.*;
import java.io.File;
+import java.util.ArrayList;
+import java.util.Collections;
public class VideoInferenceApp extends JFrame {
@@ -43,13 +48,13 @@ public class VideoInferenceApp extends JFrame {
videoPanel = new VideoPanel();
videoPanel.setBackground(Color.BLACK);
- // 初始化 VideoPlayer
- videoPlayer = new VideoPlayer(videoPanel);
-
// 模型列表区域
modelManager = new ModelManager();
modelManager.setPreferredSize(new Dimension(250, 0)); // 设置模型列表区域的宽度
+ // 初始化 VideoPlayer
+ videoPlayer = new VideoPlayer(videoPanel, modelManager);
+
// 使用 JSplitPane 分割视频区域和模型列表区域
JSplitPane splitPane = new JSplitPane(JSplitPane.HORIZONTAL_SPLIT, videoPanel, modelManager);
splitPane.setResizeWeight(0.8); // 视频区域初始占据80%的空间
@@ -120,8 +125,16 @@ public class VideoInferenceApp extends JFrame {
// 添加视频加载按钮的行为
loadVideoButton.addActionListener(e -> selectVideoFile());
- // 添加模型加载按钮的行为
- loadModelButton.addActionListener(e -> modelManager.loadModel(this));
+ loadModelButton.addActionListener(e -> {
+ modelManager.loadModel(this);
DefaultListModel<ModelInfo> modelList = modelManager.getModelList();
ArrayList<ModelInfo> models = Collections.list(modelList.elements());
+ for (ModelInfo modelInfo : models) {
+ if (modelInfo != null) {
+ videoPlayer.addInferenceEngines(new InferenceEngine(modelInfo.getModelFilePath(), modelInfo.getLabels()));
+ }
+ }
+ });
// 播放按钮
playButton.addActionListener(e -> videoPlayer.playVideo());
diff --git a/src/main/java/com/ly/model_load/ModelManager.java b/src/main/java/com/ly/model_load/ModelManager.java
index f80eb6d..d3dd348 100644
--- a/src/main/java/com/ly/model_load/ModelManager.java
+++ b/src/main/java/com/ly/model_load/ModelManager.java
@@ -1,42 +1,37 @@
package com.ly.model_load;
-
-
import com.ly.file.FileEditor;
+import com.ly.onnx.model.ModelInfo;
import javax.swing.*;
import javax.swing.filechooser.FileNameExtensionFilter;
import javax.swing.filechooser.FileSystemView;
-import javax.swing.table.DefaultTableModel;
import java.awt.*;
import java.awt.event.MouseAdapter;
import java.awt.event.MouseEvent;
import java.io.File;
public class ModelManager extends JPanel {
- private DefaultListModel modelListModel;
- private JList modelList;
+ private DefaultListModel<ModelInfo> modelListModel;
+ private JList<ModelInfo> modelList;
public ModelManager() {
setLayout(new BorderLayout());
modelListModel = new DefaultListModel<>();
modelList = new JList<>(modelListModel);
+ modelList.setSelectionMode(ListSelectionModel.SINGLE_SELECTION); // 设置为单选
JScrollPane modelScrollPane = new JScrollPane(modelList);
add(modelScrollPane, BorderLayout.CENTER);
- // 添加双击事件,编辑标签文件
+ // 双击编辑标签文件
modelList.addMouseListener(new MouseAdapter() {
public void mouseClicked(MouseEvent e) {
if (e.getClickCount() == 2) {
int index = modelList.locationToIndex(e.getPoint());
if (index >= 0) {
- String item = modelListModel.getElementAt(index);
- // 解析标签文件路径
- String[] parts = item.split("\n");
- if (parts.length >= 2) {
- String labelFilePath = parts[1].replace("标签文件: ", "").trim();
- FileEditor.openFileEditor(labelFilePath);
- }
+ ModelInfo item = modelListModel.getElementAt(index);
+ String labelFilePath = item.getLabelFilePath();
+ FileEditor.openFileEditor(labelFilePath);
}
}
}
@@ -45,12 +40,9 @@ public class ModelManager extends JPanel {
// 加载模型
public void loadModel(JFrame parent) {
- // 获取桌面目录
File desktopDir = FileSystemView.getFileSystemView().getHomeDirectory();
JFileChooser fileChooser = new JFileChooser(desktopDir);
fileChooser.setDialogTitle("选择模型文件");
-
- // 设置模型文件过滤器,只显示 .onnx 文件
FileNameExtensionFilter modelFilter = new FileNameExtensionFilter("ONNX模型文件 (*.onnx)", "onnx");
fileChooser.setFileFilter(modelFilter);
@@ -60,8 +52,6 @@ public class ModelManager extends JPanel {
// 选择对应的标签文件
fileChooser.setDialogTitle("选择标签文件");
-
- // 设置标签文件过滤器,只显示 .txt 文件
FileNameExtensionFilter labelFilter = new FileNameExtensionFilter("标签文件 (*.txt)", "txt");
fileChooser.setFileFilter(labelFilter);
@@ -69,9 +59,9 @@ public class ModelManager extends JPanel {
if (returnValue == JFileChooser.APPROVE_OPTION) {
File labelFile = fileChooser.getSelectedFile();
- // 将模型和标签文件添加到列表中
- String item = "模型文件: " + modelFile.getAbsolutePath() + "\n标签文件: " + labelFile.getAbsolutePath();
- modelListModel.addElement(item);
+ // 添加模型信息到列表
+ ModelInfo modelInfo = new ModelInfo(modelFile.getAbsolutePath(), labelFile.getAbsolutePath());
+ modelListModel.addElement(modelInfo);
} else {
JOptionPane.showMessageDialog(parent, "未选择标签文件。", "提示", JOptionPane.WARNING_MESSAGE);
}
@@ -79,4 +69,14 @@ public class ModelManager extends JPanel {
JOptionPane.showMessageDialog(parent, "未选择模型文件。", "提示", JOptionPane.WARNING_MESSAGE);
}
}
-}
\ No newline at end of file
+
+ // 获取选中的模型
+ public ModelInfo getSelectedModel() {
+ return modelList.getSelectedValue();
+ }
+
+ // 如果需要在外部访问 modelList,可以添加以下方法
+ public DefaultListModel<ModelInfo> getModelList() {
+ return modelListModel;
+ }
+}
diff --git a/src/main/java/com/ly/onnx/OnnxModelInference.java b/src/main/java/com/ly/onnx/OnnxModelInference.java
new file mode 100644
index 0000000..e6916b6
--- /dev/null
+++ b/src/main/java/com/ly/onnx/OnnxModelInference.java
@@ -0,0 +1,20 @@
+package com.ly.onnx;
+
+import ai.onnxruntime.OrtEnvironment;
+import ai.onnxruntime.OrtSession;
+
+public class OnnxModelInference {
+
+ private String modelFilePath;
+
+ private String labelFilePath;
+
+ private String[] labels;
+
+ OrtEnvironment environment = OrtEnvironment.getEnvironment();
+ OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
+
+
+
+
+}
diff --git a/src/main/java/com/ly/onnx/engine/InferenceEngine.java b/src/main/java/com/ly/onnx/engine/InferenceEngine.java
new file mode 100644
index 0000000..bec81f7
--- /dev/null
+++ b/src/main/java/com/ly/onnx/engine/InferenceEngine.java
@@ -0,0 +1,410 @@
+package com.ly.onnx.engine;
+
+import ai.onnxruntime.*;
+import com.alibaba.fastjson.JSON;
+import com.ly.onnx.model.BoundingBox;
+import com.ly.onnx.model.InferenceResult;
+
+import org.opencv.core.*;
+import org.opencv.imgcodecs.Imgcodecs;
+import org.opencv.imgproc.Imgproc;
+
+import java.nio.FloatBuffer;
+import java.util.*;
+
+public class InferenceEngine {
+
+ private OrtEnvironment environment;
+ private OrtSession.SessionOptions sessionOptions;
+ private OrtSession session;
+
+ private String modelPath;
+ private List<String> labels;
+
+ // 用于存储图像预处理信息的类变量
+ private int origWidth;
+ private int origHeight;
+ private int newWidth;
+ private int newHeight;
+ private float scalingFactor;
+ private int xOffset;
+ private int yOffset;
+
+ static {
+ nu.pattern.OpenCV.loadLocally();
+ }
+
+ public InferenceEngine(String modelPath, List<String> labels) {
+ this.modelPath = modelPath;
+ this.labels = labels;
+ init();
+ }
+
+ public void init() {
+ try {
+ environment = OrtEnvironment.getEnvironment();
+ sessionOptions = new OrtSession.SessionOptions();
+ sessionOptions.addCUDA(0); // 使用 GPU
+ session = environment.createSession(modelPath, sessionOptions);
+ logModelInfo(session);
+ } catch (OrtException e) {
+ throw new RuntimeException("模型加载失败", e);
+ }
+ }
+
+ public InferenceResult infer(float[] inputData, int w, int h, Map<String, Object> preprocessParams) {
+ long startTime = System.currentTimeMillis();
+
+ // 从 Map 中获取偏移相关的变量
+ int origWidth = (int) preprocessParams.get("origWidth");
+ int origHeight = (int) preprocessParams.get("origHeight");
+ float scalingFactor = (float) preprocessParams.get("scalingFactor");
+ int xOffset = (int) preprocessParams.get("xOffset");
+ int yOffset = (int) preprocessParams.get("yOffset");
+
+ try {
+ Map<String, NodeInfo> inputInfo = session.getInputInfo();
+ String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入
+
+ long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状
+
+ // 创建输入张量时,使用 CHW 格式的数据
+ OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape);
+
+ // 执行推理
+ long inferenceStart = System.currentTimeMillis();
+ OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor));
+ long inferenceEnd = System.currentTimeMillis();
+ System.out.println("模型推理耗时:" + (inferenceEnd - inferenceStart) + " ms");
+
+ // 解析推理结果
+ String outputName = session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出
+ float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状:[1, N, 6]
+
+ // 设定置信度阈值
+ float confidenceThreshold = 0.25f; // 您可以根据需要调整
+
+ // 根据模型的输出结果解析边界框
+ List<BoundingBox> boxes = new ArrayList<>();
+ for (float[] data : outputData[0]) { // 遍历所有检测框
+ // 根据模型输出格式,解析中心坐标和宽高
+ float x_center = data[0];
+ float y_center = data[1];
+ float width = data[2];
+ float height = data[3];
+ float confidence = data[4];
+
+ if (confidence >= confidenceThreshold) {
+ // 将中心坐标转换为左上角和右下角坐标
+ float x1 = x_center - width / 2;
+ float y1 = y_center - height / 2;
+ float x2 = x_center + width / 2;
+ float y2 = y_center + height / 2;
+
+ // 调整坐标,减去偏移并除以缩放因子
+ float x1Adjusted = (x1 - xOffset) / scalingFactor;
+ float y1Adjusted = (y1 - yOffset) / scalingFactor;
+ float x2Adjusted = (x2 - xOffset) / scalingFactor;
+ float y2Adjusted = (y2 - yOffset) / scalingFactor;
+
+ // 确保坐标的正确顺序
+ float xMinAdjusted = Math.min(x1Adjusted, x2Adjusted);
+ float xMaxAdjusted = Math.max(x1Adjusted, x2Adjusted);
+ float yMinAdjusted = Math.min(y1Adjusted, y2Adjusted);
+ float yMaxAdjusted = Math.max(y1Adjusted, y2Adjusted);
+
+ // 确保坐标在原始图像范围内
+ int x = (int) Math.max(0, xMinAdjusted);
+ int y = (int) Math.max(0, yMinAdjusted);
+ int xMax = (int) Math.min(origWidth, xMaxAdjusted);
+ int yMax = (int) Math.min(origHeight, yMaxAdjusted);
+ int wBox = xMax - x;
+ int hBox = yMax - y;
+
+ // 仅当宽度和高度为正时,才添加边界框
+ if (wBox > 0 && hBox > 0) {
+ // 使用您的单一标签
+ String label = labels.get(0);
+
+ boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence));
+ }
+ }
+ }
+
+ // 非极大值抑制(NMS)
+ long nmsStart = System.currentTimeMillis();
+ List<BoundingBox> nmsBoxes = nonMaximumSuppression(boxes, 0.5f);
+
+ System.out.println("检测到的标签:" + JSON.toJSONString(nmsBoxes));
+ if (!nmsBoxes.isEmpty()) {
+ for (BoundingBox box : nmsBoxes) {
+ System.out.println(box);
+ }
+ }
+ long nmsEnd = System.currentTimeMillis();
+ System.out.println("NMS 耗时:" + (nmsEnd - nmsStart) + " ms");
+
+ // 封装结果并返回
+ InferenceResult inferenceResult = new InferenceResult();
+ inferenceResult.setBoundingBoxes(nmsBoxes);
+
+ long endTime = System.currentTimeMillis();
+ System.out.println("一次推理总耗时:" + (endTime - startTime) + " ms");
+
+ return inferenceResult;
+
+ } catch (OrtException e) {
+ throw new RuntimeException("推理失败", e);
+ }
+ }
+
+
+ // 计算两个边界框的 IoU
+ private float computeIoU(BoundingBox box1, BoundingBox box2) {
+ int x1 = Math.max(box1.getX(), box2.getX());
+ int y1 = Math.max(box1.getY(), box2.getY());
+ int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth());
+ int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight());
+
+ int intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
+ int box1Area = box1.getWidth() * box1.getHeight();
+ int box2Area = box2.getWidth() * box2.getHeight();
+
+ return (float) intersectionArea / (box1Area + box2Area - intersectionArea);
+ }
+
+ // 非极大值抑制(NMS)方法
+ private List<BoundingBox> nonMaximumSuppression(List<BoundingBox> boxes, float iouThreshold) {
+ // 按置信度排序(从高到低)
+ boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence()));
+
+ List<BoundingBox> result = new ArrayList<>();
+
+ while (!boxes.isEmpty()) {
+ BoundingBox bestBox = boxes.remove(0);
+ result.add(bestBox);
+
+ Iterator<BoundingBox> iterator = boxes.iterator();
+ while (iterator.hasNext()) {
+ BoundingBox box = iterator.next();
+ if (computeIoU(bestBox, box) > iouThreshold) {
+ iterator.remove();
+ }
+ }
+ }
+
+ return result;
+ }
+
+ // 打印模型信息
+ private void logModelInfo(OrtSession session) {
+ System.out.println("模型输入信息:");
+ try {
+ for (Map.Entry<String, NodeInfo> entry : session.getInputInfo().entrySet()) {
+ String name = entry.getKey();
+ NodeInfo info = entry.getValue();
+ System.out.println("输入名称: " + name);
+ System.out.println("输入信息: " + info.toString());
+ }
+ } catch (OrtException e) {
+ throw new RuntimeException(e);
+ }
+
+ System.out.println("模型输出信息:");
+ try {
+ for (Map.Entry<String, NodeInfo> entry : session.getOutputInfo().entrySet()) {
+ String name = entry.getKey();
+ NodeInfo info = entry.getValue();
+ System.out.println("输出名称: " + name);
+ System.out.println("输出信息: " + info.toString());
+ }
+ } catch (OrtException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ public static void main(String[] args) {
+ // 加载 OpenCV 库
+
+ // 初始化标签列表(只有一个标签)
+ List<String> labels = Arrays.asList("person");
+
+ // 创建 InferenceEngine 实例
+ InferenceEngine inferenceEngine = new InferenceEngine("C:\\Users\\ly\\Desktop\\person.onnx", labels);
+
+ for (int j = 0; j < 10; j++) {
+ try {
+ // 加载图片
+ Mat inputImage = Imgcodecs.imread("C:\\Users\\ly\\Desktop\\10230731212230.png");
+
+ // 预处理图像
+ long l1 = System.currentTimeMillis();
+ Map<String, Object> preprocessResult = inferenceEngine.preprocessImage(inputImage);
+ float[] inputData = (float[]) preprocessResult.get("inputData");
+
+ InferenceResult result = null;
+ for (int i = 0; i < 10; i++) {
+ long l = System.currentTimeMillis();
+ result = inferenceEngine.infer(inputData, 640, 640, preprocessResult);
+ System.out.println("第 " + (i + 1) + " 次推理耗时:" + (System.currentTimeMillis() - l) + " ms");
+ }
+
+
+ // 处理并显示结果
+ System.out.println("推理结果:");
+ for (BoundingBox box : result.getBoundingBoxes()) {
+ System.out.println(box);
+ }
+
+ // 可视化并保存带有边界框的图像
+ Mat outputImage = inferenceEngine.drawBoundingBoxes(inputImage, result.getBoundingBoxes());
+
+ // 保存图片到本地文件
+ String outputFilePath = "output_image_with_boxes.jpg";
+ Imgcodecs.imwrite(outputFilePath, outputImage);
+
+ System.out.println("已保存结果图片: " + outputFilePath);
+
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+ }
+
+ // 在图像上绘制边界框和标签
+ private Mat drawBoundingBoxes(Mat image, List<BoundingBox> boxes) {
+ for (BoundingBox box : boxes) {
+ // 绘制矩形边界框
+ Imgproc.rectangle(image, new Point(box.getX(), box.getY()),
+ new Point(box.getX() + box.getWidth(), box.getY() + box.getHeight()),
+ new Scalar(0, 0, 255), 2); // 红色边框
+
+ // 绘制标签文字和置信度
+ String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence());
+ int baseLine[] = new int[1];
+ Size labelSize = Imgproc.getTextSize(label, Imgproc.FONT_HERSHEY_SIMPLEX, 0.5, 1, baseLine);
+ int top = Math.max(box.getY(), (int) labelSize.height);
+ Imgproc.putText(image, label, new Point(box.getX(), top),
+ Imgproc.FONT_HERSHEY_SIMPLEX, 0.5, new Scalar(255, 255, 255), 1);
+ }
+
+ return image;
+ }
+
+
+ public Map<String, Object> preprocessImage(Mat image) {
+ int targetWidth = 640;
+ int targetHeight = 640;
+
+ int origWidth = image.width();
+ int origHeight = image.height();
+
+ // 计算缩放因子
+ float scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
+
+ // 计算新的图像尺寸
+ int newWidth = Math.round(origWidth * scalingFactor);
+ int newHeight = Math.round(origHeight * scalingFactor);
+
+ // 计算偏移量以居中图像
+ int xOffset = (targetWidth - newWidth) / 2;
+ int yOffset = (targetHeight - newHeight) / 2;
+
+ // 调整图像尺寸
+ Mat resizedImage = new Mat();
+ Imgproc.resize(image, resizedImage, new Size(newWidth, newHeight), 0, 0, Imgproc.INTER_LINEAR);
+
+ // 转换为 RGB 并归一化
+ Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB);
+ resizedImage.convertTo(resizedImage, CvType.CV_32FC3, 1.0 / 255.0);
+
+ // 创建填充后的图像
+ Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3);
+ Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight);
+ resizedImage.copyTo(paddedImage.submat(roi));
+
+ // 将图像数据转换为数组
+ int imageSize = targetWidth * targetHeight;
+ float[] chwData = new float[3 * imageSize];
+ float[] hwcData = new float[3 * imageSize];
+ paddedImage.get(0, 0, hwcData);
+
+ // 转换为 CHW 格式
+ int channelSize = imageSize;
+ for (int c = 0; c < 3; c++) {
+ for (int i = 0; i < imageSize; i++) {
+ chwData[c * channelSize + i] = hwcData[i * 3 + c];
+ }
+ }
+
+ // 释放图像资源
+ resizedImage.release();
+ paddedImage.release();
+
+ // 将预处理结果和偏移信息存入 Map
+ Map<String, Object> result = new HashMap<>();
+ result.put("inputData", chwData);
+ result.put("origWidth", origWidth);
+ result.put("origHeight", origHeight);
+ result.put("scalingFactor", scalingFactor);
+ result.put("xOffset", xOffset);
+ result.put("yOffset", yOffset);
+
+ return result;
+ }
+
+
+ // 图像预处理
+// public float[] preprocessImage(Mat image) {
+// int targetWidth = 640;
+// int targetHeight = 640;
+//
+// origWidth = image.width();
+// origHeight = image.height();
+//
+// // 计算缩放因子
+// scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
+//
+// // 计算新的图像尺寸
+// newWidth = Math.round(origWidth * scalingFactor);
+// newHeight = Math.round(origHeight * scalingFactor);
+//
+// // 计算偏移量以居中图像
+// xOffset = (targetWidth - newWidth) / 2;
+// yOffset = (targetHeight - newHeight) / 2;
+//
+// // 调整图像尺寸
+// Mat resizedImage = new Mat();
+// Imgproc.resize(image, resizedImage, new Size(newWidth, newHeight), 0, 0, Imgproc.INTER_LINEAR);
+//
+// // 转换为 RGB 并归一化
+// Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB);
+// resizedImage.convertTo(resizedImage, CvType.CV_32FC3, 1.0 / 255.0);
+//
+// // 创建填充后的图像
+// Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3);
+// Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight);
+// resizedImage.copyTo(paddedImage.submat(roi));
+//
+// // 将图像数据转换为数组
+// int imageSize = targetWidth * targetHeight;
+// float[] chwData = new float[3 * imageSize];
+// float[] hwcData = new float[3 * imageSize];
+// paddedImage.get(0, 0, hwcData);
+//
+// // 转换为 CHW 格式
+// int channelSize = imageSize;
+// for (int c = 0; c < 3; c++) {
+// for (int i = 0; i < imageSize; i++) {
+// chwData[c * channelSize + i] = hwcData[i * 3 + c];
+// }
+// }
+//
+// // 释放图像资源
+// resizedImage.release();
+// paddedImage.release();
+//
+// return chwData;
+// }
+
+}
diff --git a/src/main/java/com/ly/onnx/engine/InferenceEngine_up.java b/src/main/java/com/ly/onnx/engine/InferenceEngine_up.java
new file mode 100644
index 0000000..6acd759
--- /dev/null
+++ b/src/main/java/com/ly/onnx/engine/InferenceEngine_up.java
@@ -0,0 +1,297 @@
+package com.ly.onnx.engine;
+
+import ai.onnxruntime.*;
+import com.ly.onnx.model.BoundingBox;
+import com.ly.onnx.model.InferenceResult;
+
+import javax.imageio.ImageIO;
+import java.awt.*;
+import java.awt.image.BufferedImage;
+import java.io.File;
+import java.nio.FloatBuffer;
+import java.util.List;
+import java.util.*;
+
+public class InferenceEngine_up {
+
+ OrtEnvironment environment = OrtEnvironment.getEnvironment();
+ OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
+
+ private String modelPath;
+ private List<String> labels;
+
+ // 添加用于存储图像预处理信息的类变量
+ private int origWidth;
+ private int origHeight;
+ private int newWidth;
+ private int newHeight;
+ private float scalingFactor;
+ private int xOffset;
+ private int yOffset;
+
+ public InferenceEngine_up(String modelPath, List<String> labels) {
+ this.modelPath = modelPath;
+ this.labels = labels;
+ init();
+ }
+
+ public void init() {
+ OrtSession session = null;
+ try {
+ sessionOptions.addCUDA(0);
+ session = environment.createSession(modelPath, sessionOptions);
+
+ } catch (OrtException e) {
+ throw new RuntimeException(e);
+ }
+ logModelInfo(session);
+ }
+
+ public InferenceResult infer(float[] inputData, int w, int h) {
+ // 创建ONNX输入Tensor
+ try (OrtSession session = environment.createSession(modelPath, sessionOptions)) {
Map<String, NodeInfo> inputInfo = session.getInputInfo();
+ String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入
+
+ long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状
+ OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape);
+
+ // 执行推理
+ OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor));
+
+ // 解析推理结果
+ String outputName = session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出
+ float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状:[1, N, 5]
+
+ long l = System.currentTimeMillis();
+ // 设定置信度阈值
+ float confidenceThreshold = 0.5f; // 您可以根据需要调整
+
+ // 根据模型的输出结果解析边界框
+ List<BoundingBox> boxes = new ArrayList<>();
+ for (float[] data : outputData[0]) { // 遍历所有检测框
+ float confidence = data[4];
+ if (confidence >= confidenceThreshold) {
+ float xCenter = data[0];
+ float yCenter = data[1];
+ float widthBox = data[2];
+ float heightBox = data[3];
+
+ // 调整坐标,减去偏移并除以缩放因子
+ float xCenterAdjusted = (xCenter - xOffset) / scalingFactor;
+ float yCenterAdjusted = (yCenter - yOffset) / scalingFactor;
+ float widthAdjusted = widthBox / scalingFactor;
+ float heightAdjusted = heightBox / scalingFactor;
+
+ // 计算左上角坐标
+ int x = (int) (xCenterAdjusted - widthAdjusted / 2);
+ int y = (int) (yCenterAdjusted - heightAdjusted / 2);
+ int wBox = (int) widthAdjusted;
+ int hBox = (int) heightAdjusted;
+
+ // 确保坐标在原始图像范围内
+ if (x < 0) x = 0;
+ if (y < 0) y = 0;
+ if (x + wBox > origWidth) wBox = origWidth - x;
+ if (y + hBox > origHeight) hBox = origHeight - y;
+
+ String label = "person"; // 由于只有一个类别
+
+ boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence));
+ }
+ }
+
+ // 非极大值抑制(NMS)
+ List<BoundingBox> nmsBoxes = nonMaximumSuppression(boxes, 0.5f);
+ System.out.println("耗时:"+((System.currentTimeMillis() - l)));
+ // 封装结果并返回
+ InferenceResult inferenceResult = new InferenceResult();
+ inferenceResult.setBoundingBoxes(nmsBoxes);
+ return inferenceResult;
+
+ } catch (OrtException e) {
+ throw new RuntimeException("推理失败", e);
+ }
+ }
+
+ // 计算两个边界框的 IoU
+ private float computeIoU(BoundingBox box1, BoundingBox box2) {
+ int x1 = Math.max(box1.getX(), box2.getX());
+ int y1 = Math.max(box1.getY(), box2.getY());
+ int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth());
+ int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight());
+
+ int intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
+ int box1Area = box1.getWidth() * box1.getHeight();
+ int box2Area = box2.getWidth() * box2.getHeight();
+
+ return (float) intersectionArea / (box1Area + box2Area - intersectionArea);
+ }
+
+ // 打印模型信息
+ private void logModelInfo(OrtSession session) {
+ System.out.println("模型输入信息:");
+ try {
+ for (Map.Entry<String, NodeInfo> entry : session.getInputInfo().entrySet()) {
+ String name = entry.getKey();
+ NodeInfo info = entry.getValue();
+ System.out.println("输入名称: " + name);
+ System.out.println("输入信息: " + info.toString());
+ }
+ } catch (OrtException e) {
+ throw new RuntimeException(e);
+ }
+
+ System.out.println("模型输出信息:");
+ try {
+ for (Map.Entry<String, NodeInfo> entry : session.getOutputInfo().entrySet()) {
+ String name = entry.getKey();
+ NodeInfo info = entry.getValue();
+ System.out.println("输出名称: " + name);
+ System.out.println("输出信息: " + info.toString());
+ }
+ } catch (OrtException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ public static void main(String[] args) {
+ // 初始化标签列表
+ List<String> labels = Arrays.asList("person");
+
+ // 创建 InferenceEngine 实例
+ InferenceEngine_up inferenceEngine = new InferenceEngine_up("D:\\work\\work_space\\java\\onnx-inference4j-play\\src\\main\\resources\\model\\best.onnx", labels);
+
+ try {
+ // 加载图片
+ File imageFile = new File("C:\\Users\\ly\\Desktop\\resuouce\\image\\1.jpg");
+ BufferedImage inputImage = ImageIO.read(imageFile);
+
+ // 预处理图像
+ float[] inputData = inferenceEngine.preprocessImage(inputImage);
+
+ // 执行推理
+ InferenceResult result = null;
+ for (int i = 0; i < 100; i++) {
+ long l = System.currentTimeMillis();
+ result = inferenceEngine.infer(inputData, 640, 640);
+ System.out.println(System.currentTimeMillis() - l);
+ }
+
+ // 处理并显示结果
+ System.out.println("推理结果:");
+ for (BoundingBox box : result.getBoundingBoxes()) {
+ System.out.println(box);
+ }
+
+ // 可视化并保存带有边界框的图像
+ BufferedImage outputImage = inferenceEngine.drawBoundingBoxes(inputImage, result.getBoundingBoxes());
+
+ // 保存图片到本地文件
+ File outputFile = new File("output_image_with_boxes.jpg");
+ ImageIO.write(outputImage, "jpg", outputFile);
+
+ System.out.println("已保存结果图片: " + outputFile.getAbsolutePath());
+
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+
+ // 在图像上绘制边界框和标签
+ BufferedImage drawBoundingBoxes(BufferedImage image, List<BoundingBox> boxes) {
+ Graphics2D g = image.createGraphics();
+ g.setColor(Color.RED); // 设置绘制边界框的颜色
+ g.setStroke(new BasicStroke(2)); // 设置线条粗细
+
+ for (BoundingBox box : boxes) {
+ // 绘制矩形边界框
+ g.drawRect(box.getX(), box.getY(), box.getWidth(), box.getHeight());
+ // 绘制标签文字和置信度
+ String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence());
+ g.setFont(new Font("Arial", Font.PLAIN, 12));
+ g.drawString(label, box.getX(), box.getY() - 5);
+ }
+
+ g.dispose(); // 释放资源
+ return image;
+ }
+
+ // 图像预处理
+ float[] preprocessImage(BufferedImage image) {
+ int targetWidth = 640;
+ int targetHeight = 640;
+
+ origWidth = image.getWidth();
+ origHeight = image.getHeight();
+
+ // 计算缩放因子
+ scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
+
+ // 计算新的图像尺寸
+ newWidth = Math.round(origWidth * scalingFactor);
+ newHeight = Math.round(origHeight * scalingFactor);
+
+ // 计算偏移量以居中图像
+ xOffset = (targetWidth - newWidth) / 2;
+ yOffset = (targetHeight - newHeight) / 2;
+
+ // 创建一个新的BufferedImage
+ BufferedImage resizedImage = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_RGB);
+ Graphics2D g = resizedImage.createGraphics();
+
+ // 填充背景为黑色
+ g.setColor(Color.BLACK);
+ g.fillRect(0, 0, targetWidth, targetHeight);
+
+ // 绘制缩放后的图像到新的图像上
+ g.drawImage(image.getScaledInstance(newWidth, newHeight, Image.SCALE_SMOOTH), xOffset, yOffset, null);
+ g.dispose();
+
+ float[] inputData = new float[3 * targetWidth * targetHeight];
+
+ for (int c = 0; c < 3; c++) {
+ for (int y = 0; y < targetHeight; y++) {
+ for (int x = 0; x < targetWidth; x++) {
+ int rgb = resizedImage.getRGB(x, y);
+ float value = 0f;
+ if (c == 0) {
+ value = ((rgb >> 16) & 0xFF) / 255.0f; // Red channel
+ } else if (c == 1) {
+ value = ((rgb >> 8) & 0xFF) / 255.0f; // Green channel
+ } else if (c == 2) {
+ value = (rgb & 0xFF) / 255.0f; // Blue channel
+ }
+ inputData[c * targetWidth * targetHeight + y * targetWidth + x] = value;
+ }
+ }
+ }
+
+ return inputData;
+ }
+
+ // 非极大值抑制(NMS)方法
+ private List<BoundingBox> nonMaximumSuppression(List<BoundingBox> boxes, float iouThreshold) {
+ // 按置信度排序
+ boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence()));
+
+ List<BoundingBox> result = new ArrayList<>();
+
+ while (!boxes.isEmpty()) {
+ BoundingBox bestBox = boxes.remove(0);
+ result.add(bestBox);
+
+ Iterator<BoundingBox> iterator = boxes.iterator();
+ while (iterator.hasNext()) {
+ BoundingBox box = iterator.next();
+ if (computeIoU(bestBox, box) > iouThreshold) {
+ iterator.remove();
+ }
+ }
+ }
+
+ return result;
+ }
+
+ // 其他方法保持不变...
+}
diff --git a/src/main/java/com/ly/onnx/engine/InferenceEngine_up_1.java b/src/main/java/com/ly/onnx/engine/InferenceEngine_up_1.java
new file mode 100644
index 0000000..c4c732e
--- /dev/null
+++ b/src/main/java/com/ly/onnx/engine/InferenceEngine_up_1.java
@@ -0,0 +1,297 @@
+package com.ly.onnx.engine;
+
+import ai.onnxruntime.*;
+import com.ly.onnx.model.BoundingBox;
+import com.ly.onnx.model.InferenceResult;
+
+import javax.imageio.ImageIO;
+import java.awt.*;
+import java.awt.image.BufferedImage;
+import java.io.File;
+import java.nio.FloatBuffer;
+import java.util.List;
+import java.util.*;
+
+public class InferenceEngine_up_1 {
+
+ OrtEnvironment environment = OrtEnvironment.getEnvironment();
+ OrtSession.SessionOptions sessionOptions = null;
+ OrtSession session = null;
+ private String modelPath;
+ private List labels;
+
+ // 添加用于存储图像预处理信息的类变量
+ private int origWidth;
+ private int origHeight;
+ private int newWidth;
+ private int newHeight;
+ private float scalingFactor;
+ private int xOffset;
+ private int yOffset;
+
+ public InferenceEngine_up_1(String modelPath, List labels) {
+ this.modelPath = modelPath;
+ this.labels = labels;
+ init();
+ }
+
+ public void init() {
+
+ try {
+ sessionOptions = new OrtSession.SessionOptions();
+ sessionOptions.addCUDA(0);
+ session = environment.createSession(modelPath, sessionOptions);
+ } catch (OrtException e) {
+ throw new RuntimeException(e);
+ }
+ logModelInfo(session);
+ }
+
+ public InferenceResult infer(float[] inputData, int w, int h) {
+ // 创建ONNX输入Tensor
+ try {
+ Map inputInfo = session.getInputInfo();
+ String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入
+
+ long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状
+ OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape);
+
+ // 执行推理
+ OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor));
+
+ // 解析推理结果
+ String outputName = session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出
+ float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状:[1, N, 5]
+
+ long l = System.currentTimeMillis();
+ // 设定置信度阈值
+ float confidenceThreshold = 0.5f; // 您可以根据需要调整
+
+ // 根据模型的输出结果解析边界框
+ List boxes = new ArrayList<>();
+ for (float[] data : outputData[0]) { // 遍历所有检测框
+ float confidence = data[4];
+ if (confidence >= confidenceThreshold) {
+ float xCenter = data[0];
+ float yCenter = data[1];
+ float widthBox = data[2];
+ float heightBox = data[3];
+
+ // 调整坐标,减去偏移并除以缩放因子
+ float xCenterAdjusted = (xCenter - xOffset) / scalingFactor;
+ float yCenterAdjusted = (yCenter - yOffset) / scalingFactor;
+ float widthAdjusted = widthBox / scalingFactor;
+ float heightAdjusted = heightBox / scalingFactor;
+
+ // 计算左上角坐标
+ int x = (int) (xCenterAdjusted - widthAdjusted / 2);
+ int y = (int) (yCenterAdjusted - heightAdjusted / 2);
+ int wBox = (int) widthAdjusted;
+ int hBox = (int) heightAdjusted;
+
+ // 确保坐标在原始图像范围内
+ if (x < 0) x = 0;
+ if (y < 0) y = 0;
+ if (x + wBox > origWidth) wBox = origWidth - x;
+ if (y + hBox > origHeight) hBox = origHeight - y;
+
+ String label = "person"; // 由于只有一个类别
+
+ boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence));
+ }
+ }
+
+ // 非极大值抑制(NMS)
+ List nmsBoxes = nonMaximumSuppression(boxes, 0.5f);
+ System.out.println("耗时:"+((System.currentTimeMillis() - l)));
+ // 封装结果并返回
+ InferenceResult inferenceResult = new InferenceResult();
+ inferenceResult.setBoundingBoxes(nmsBoxes);
+ return inferenceResult;
+
+ } catch (OrtException e) {
+ throw new RuntimeException("推理失败", e);
+ }
+ }
+
+ // 计算两个边界框的 IoU
+ private float computeIoU(BoundingBox box1, BoundingBox box2) {
+ int x1 = Math.max(box1.getX(), box2.getX());
+ int y1 = Math.max(box1.getY(), box2.getY());
+ int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth());
+ int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight());
+
+ int intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
+ int box1Area = box1.getWidth() * box1.getHeight();
+ int box2Area = box2.getWidth() * box2.getHeight();
+
+ return (float) intersectionArea / (box1Area + box2Area - intersectionArea);
+ }
+
+ // 打印模型信息
+ private void logModelInfo(OrtSession session) {
+ System.out.println("模型输入信息:");
+ try {
+ for (Map.Entry entry : session.getInputInfo().entrySet()) {
+ String name = entry.getKey();
+ NodeInfo info = entry.getValue();
+ System.out.println("输入名称: " + name);
+ System.out.println("输入信息: " + info.toString());
+ }
+ } catch (OrtException e) {
+ throw new RuntimeException(e);
+ }
+
+ System.out.println("模型输出信息:");
+ try {
+ for (Map.Entry entry : session.getOutputInfo().entrySet()) {
+ String name = entry.getKey();
+ NodeInfo info = entry.getValue();
+ System.out.println("输出名称: " + name);
+ System.out.println("输出信息: " + info.toString());
+ }
+ } catch (OrtException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ public static void main(String[] args) {
+ // 初始化标签列表
+ List labels = Arrays.asList("person");
+
+ // 创建 InferenceEngine 实例
+ InferenceEngine_up_1 inferenceEngine = new InferenceEngine_up_1("D:\\work\\work_space\\java\\onnx-inference4j-play\\src\\main\\resources\\model\\best.onnx", labels);
+
+ try {
+ // 加载图片
+ File imageFile = new File("C:\\Users\\ly\\Desktop\\resuouce\\image\\1.jpg");
+ BufferedImage inputImage = ImageIO.read(imageFile);
+
+ // 预处理图像
+ float[] inputData = inferenceEngine.preprocessImage(inputImage);
+
+ // 执行推理
+ InferenceResult result = null;
+ for (int i = 0; i < 100; i++) {
+ long l = System.currentTimeMillis();
+ result = inferenceEngine.infer(inputData, 640, 640);
+ System.out.println(System.currentTimeMillis() - l);
+ }
+
+ // 处理并显示结果
+ System.out.println("推理结果:");
+ for (BoundingBox box : result.getBoundingBoxes()) {
+ System.out.println(box);
+ }
+
+ // 可视化并保存带有边界框的图像
+ BufferedImage outputImage = inferenceEngine.drawBoundingBoxes(inputImage, result.getBoundingBoxes());
+
+ // 保存图片到本地文件
+ File outputFile = new File("output_image_with_boxes.jpg");
+ ImageIO.write(outputImage, "jpg", outputFile);
+
+ System.out.println("已保存结果图片: " + outputFile.getAbsolutePath());
+
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+
+ // 在图像上绘制边界框和标签
+ private BufferedImage drawBoundingBoxes(BufferedImage image, List boxes) {
+ Graphics2D g = image.createGraphics();
+ g.setColor(Color.RED); // 设置绘制边界框的颜色
+ g.setStroke(new BasicStroke(2)); // 设置线条粗细
+
+ for (BoundingBox box : boxes) {
+ // 绘制矩形边界框
+ g.drawRect(box.getX(), box.getY(), box.getWidth(), box.getHeight());
+ // 绘制标签文字和置信度
+ String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence());
+ g.setFont(new Font("Arial", Font.PLAIN, 12));
+ g.drawString(label, box.getX(), box.getY() - 5);
+ }
+
+ g.dispose(); // 释放资源
+ return image;
+ }
+
+ // 图像预处理
+ private float[] preprocessImage(BufferedImage image) {
+ int targetWidth = 640;
+ int targetHeight = 640;
+
+ origWidth = image.getWidth();
+ origHeight = image.getHeight();
+
+ // 计算缩放因子
+ scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
+
+ // 计算新的图像尺寸
+ newWidth = Math.round(origWidth * scalingFactor);
+ newHeight = Math.round(origHeight * scalingFactor);
+
+ // 计算偏移量以居中图像
+ xOffset = (targetWidth - newWidth) / 2;
+ yOffset = (targetHeight - newHeight) / 2;
+
+ // 创建一个新的BufferedImage
+ BufferedImage resizedImage = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_RGB);
+ Graphics2D g = resizedImage.createGraphics();
+
+ // 填充背景为黑色
+ g.setColor(Color.BLACK);
+ g.fillRect(0, 0, targetWidth, targetHeight);
+
+ // 绘制缩放后的图像到新的图像上
+ g.drawImage(image.getScaledInstance(newWidth, newHeight, Image.SCALE_SMOOTH), xOffset, yOffset, null);
+ g.dispose();
+
+ float[] inputData = new float[3 * targetWidth * targetHeight];
+
+ for (int c = 0; c < 3; c++) {
+ for (int y = 0; y < targetHeight; y++) {
+ for (int x = 0; x < targetWidth; x++) {
+ int rgb = resizedImage.getRGB(x, y);
+ float value = 0f;
+ if (c == 0) {
+ value = ((rgb >> 16) & 0xFF) / 255.0f; // Red channel
+ } else if (c == 1) {
+ value = ((rgb >> 8) & 0xFF) / 255.0f; // Green channel
+ } else if (c == 2) {
+ value = (rgb & 0xFF) / 255.0f; // Blue channel
+ }
+ inputData[c * targetWidth * targetHeight + y * targetWidth + x] = value;
+ }
+ }
+ }
+
+ return inputData;
+ }
+
+ // 非极大值抑制(NMS)方法
+ private List nonMaximumSuppression(List boxes, float iouThreshold) {
+ // 按置信度排序
+ boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence()));
+
+ List result = new ArrayList<>();
+
+ while (!boxes.isEmpty()) {
+ BoundingBox bestBox = boxes.remove(0);
+ result.add(bestBox);
+
+ Iterator iterator = boxes.iterator();
+ while (iterator.hasNext()) {
+ BoundingBox box = iterator.next();
+ if (computeIoU(bestBox, box) > iouThreshold) {
+ iterator.remove();
+ }
+ }
+ }
+
+ return result;
+ }
+
+    // (end of class — no further methods)
+}
diff --git a/src/main/java/com/ly/onnx/engine/InferenceEngine_up_2.java b/src/main/java/com/ly/onnx/engine/InferenceEngine_up_2.java
new file mode 100644
index 0000000..2262dc4
--- /dev/null
+++ b/src/main/java/com/ly/onnx/engine/InferenceEngine_up_2.java
@@ -0,0 +1,314 @@
+package com.ly.onnx.engine;
+
+import ai.onnxruntime.*;
+import com.ly.onnx.model.BoundingBox;
+import com.ly.onnx.model.InferenceResult;
+
+import javax.imageio.ImageIO;
+import java.awt.*;
+import java.awt.image.BufferedImage;
+import java.io.File;
+import java.nio.FloatBuffer;
+import java.util.List;
+import java.util.*;
+
+public class InferenceEngine_up_2 {
+
+ private OrtEnvironment environment;
+ private OrtSession.SessionOptions sessionOptions;
+ private OrtSession session; // 将 session 作为类的成员变量
+
+ private String modelPath;
+ private List labels;
+
+ // 添加用于存储图像预处理信息的类变量
+ private int origWidth;
+ private int origHeight;
+ private int newWidth;
+ private int newHeight;
+ private float scalingFactor;
+ private int xOffset;
+ private int yOffset;
+
+ public InferenceEngine_up_2(String modelPath, List labels) {
+ this.modelPath = modelPath;
+ this.labels = labels;
+ init();
+ }
+
+ public void init() {
+ try {
+ environment = OrtEnvironment.getEnvironment();
+ sessionOptions = new OrtSession.SessionOptions();
+ sessionOptions.addCUDA(0); // 使用 GPU
+ session = environment.createSession(modelPath, sessionOptions);
+ logModelInfo(session);
+ } catch (OrtException e) {
+ throw new RuntimeException("模型加载失败", e);
+ }
+ }
+
+ public InferenceResult infer(float[] inputData, int w, int h) {
+ long startTime = System.currentTimeMillis();
+
+ try {
+ Map inputInfo = session.getInputInfo();
+ String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入
+
+ long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状
+ OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape);
+
+ // 执行推理
+ long inferenceStart = System.currentTimeMillis();
+ OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor));
+ long inferenceEnd = System.currentTimeMillis();
+ System.out.println("模型推理耗时:" + (inferenceEnd - inferenceStart) + " ms");
+
+ // 解析推理结果
+ String outputName = session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出
+ float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状:[1, N, 5]
+
+ // 设定置信度阈值
+ float confidenceThreshold = 0.5f; // 您可以根据需要调整
+
+ // 根据模型的输出结果解析边界框
+ List boxes = new ArrayList<>();
+ for (float[] data : outputData[0]) { // 遍历所有检测框
+ float confidence = data[4];
+ if (confidence >= confidenceThreshold) {
+ float xCenter = data[0];
+ float yCenter = data[1];
+ float widthBox = data[2];
+ float heightBox = data[3];
+
+ // 调整坐标,减去偏移并除以缩放因子
+ float xCenterAdjusted = (xCenter - xOffset) / scalingFactor;
+ float yCenterAdjusted = (yCenter - yOffset) / scalingFactor;
+ float widthAdjusted = widthBox / scalingFactor;
+ float heightAdjusted = heightBox / scalingFactor;
+
+ // 计算左上角坐标
+ int x = (int) (xCenterAdjusted - widthAdjusted / 2);
+ int y = (int) (yCenterAdjusted - heightAdjusted / 2);
+ int wBox = (int) widthAdjusted;
+ int hBox = (int) heightAdjusted;
+
+ // 确保坐标在原始图像范围内
+ if (x < 0) x = 0;
+ if (y < 0) y = 0;
+ if (x + wBox > origWidth) wBox = origWidth - x;
+ if (y + hBox > origHeight) hBox = origHeight - y;
+
+ String label = "person"; // 由于只有一个类别
+
+ boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence));
+ }
+ }
+
+ // 非极大值抑制(NMS)
+ long nmsStart = System.currentTimeMillis();
+ List nmsBoxes = nonMaximumSuppression(boxes, 0.5f);
+ long nmsEnd = System.currentTimeMillis();
+ System.out.println("NMS 耗时:" + (nmsEnd - nmsStart) + " ms");
+
+ // 封装结果并返回
+ InferenceResult inferenceResult = new InferenceResult();
+ inferenceResult.setBoundingBoxes(nmsBoxes);
+
+ long endTime = System.currentTimeMillis();
+ System.out.println("一次推理总耗时:" + (endTime - startTime) + " ms");
+
+ return inferenceResult;
+
+ } catch (OrtException e) {
+ throw new RuntimeException("推理失败", e);
+ }
+ }
+
+ // 计算两个边界框的 IoU
+ private float computeIoU(BoundingBox box1, BoundingBox box2) {
+ int x1 = Math.max(box1.getX(), box2.getX());
+ int y1 = Math.max(box1.getY(), box2.getY());
+ int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth());
+ int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight());
+
+ int intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
+ int box1Area = box1.getWidth() * box1.getHeight();
+ int box2Area = box2.getWidth() * box2.getHeight();
+
+ return (float) intersectionArea / (box1Area + box2Area - intersectionArea);
+ }
+
+ // 打印模型信息
+ private void logModelInfo(OrtSession session) {
+ System.out.println("模型输入信息:");
+ try {
+ for (Map.Entry entry : session.getInputInfo().entrySet()) {
+ String name = entry.getKey();
+ NodeInfo info = entry.getValue();
+ System.out.println("输入名称: " + name);
+ System.out.println("输入信息: " + info.toString());
+ }
+ } catch (OrtException e) {
+ throw new RuntimeException(e);
+ }
+
+ System.out.println("模型输出信息:");
+ try {
+ for (Map.Entry entry : session.getOutputInfo().entrySet()) {
+ String name = entry.getKey();
+ NodeInfo info = entry.getValue();
+ System.out.println("输出名称: " + name);
+ System.out.println("输出信息: " + info.toString());
+ }
+ } catch (OrtException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ public static void main(String[] args) {
+ // 初始化标签列表
+ List labels = Arrays.asList("person");
+
+ // 创建 InferenceEngine 实例
+ InferenceEngine_up_2 inferenceEngine = new InferenceEngine_up_2("D:\\work\\work_space\\java\\onnx-inference4j-play\\src\\main\\resources\\model\\best.onnx", labels);
+
+ try {
+ // 加载图片
+ File imageFile = new File("C:\\Users\\ly\\Desktop\\resuouce\\image\\1.jpg");
+ BufferedImage inputImage = ImageIO.read(imageFile);
+
+ // 预处理图像
+ long l1 = System.currentTimeMillis();
+ float[] inputData = inferenceEngine.preprocessImage(inputImage);
+ System.out.println("转"+(System.currentTimeMillis() - l1));
+ // 执行推理
+ InferenceResult result = null;
+ for (int i = 0; i < 10; i++) {
+ long l = System.currentTimeMillis();
+ result = inferenceEngine.infer(inputData, 640, 640);
+ System.out.println("第 " + (i + 1) + " 次推理耗时:" + (System.currentTimeMillis() - l) + " ms");
+ }
+
+ // 处理并显示结果
+ System.out.println("推理结果:");
+ for (BoundingBox box : result.getBoundingBoxes()) {
+ System.out.println(box);
+ }
+
+ // 可视化并保存带有边界框的图像
+ BufferedImage outputImage = inferenceEngine.drawBoundingBoxes(inputImage, result.getBoundingBoxes());
+
+ // 保存图片到本地文件
+ File outputFile = new File("output_image_with_boxes.jpg");
+ ImageIO.write(outputImage, "jpg", outputFile);
+
+ System.out.println("已保存结果图片: " + outputFile.getAbsolutePath());
+
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+
+ // 在图像上绘制边界框和标签
+ private BufferedImage drawBoundingBoxes(BufferedImage image, List boxes) {
+ Graphics2D g = image.createGraphics();
+ g.setColor(Color.RED); // 设置绘制边界框的颜色
+ g.setStroke(new BasicStroke(2)); // 设置线条粗细
+
+ for (BoundingBox box : boxes) {
+ // 绘制矩形边界框
+ g.drawRect(box.getX(), box.getY(), box.getWidth(), box.getHeight());
+ // 绘制标签文字和置信度
+ String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence());
+ g.setFont(new Font("Arial", Font.PLAIN, 12));
+ g.drawString(label, box.getX(), box.getY() - 5);
+ }
+
+ g.dispose(); // 释放资源
+ return image;
+ }
+
+ // 图像预处理
+ public float[] preprocessImage(BufferedImage image) {
+ int targetWidth = 640;
+ int targetHeight = 640;
+
+ origWidth = image.getWidth();
+ origHeight = image.getHeight();
+
+ // 计算缩放因子
+ scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
+
+ // 计算新的图像尺寸
+ newWidth = Math.round(origWidth * scalingFactor);
+ newHeight = Math.round(origHeight * scalingFactor);
+
+ // 计算偏移量以居中图像
+ xOffset = (targetWidth - newWidth) / 2;
+ yOffset = (targetHeight - newHeight) / 2;
+
+ // 创建一个新的BufferedImage
+ BufferedImage resizedImage = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_RGB);
+ Graphics2D g = resizedImage.createGraphics();
+
+ // 填充背景为黑色
+ g.setColor(Color.BLACK);
+ g.fillRect(0, 0, targetWidth, targetHeight);
+
+ // 绘制缩放后的图像到新的图像上
+ g.drawImage(image.getScaledInstance(newWidth, newHeight, Image.SCALE_SMOOTH), xOffset, yOffset, null);
+ g.dispose();
+
+ float[] inputData = new float[3 * targetWidth * targetHeight];
+
+ // 开始计时
+ long preprocessStart = System.currentTimeMillis();
+
+ for (int c = 0; c < 3; c++) {
+ for (int y = 0; y < targetHeight; y++) {
+ for (int x = 0; x < targetWidth; x++) {
+ int rgb = resizedImage.getRGB(x, y);
+ float value = 0f;
+ if (c == 0) {
+ value = ((rgb >> 16) & 0xFF) / 255.0f; // Red channel
+ } else if (c == 1) {
+ value = ((rgb >> 8) & 0xFF) / 255.0f; // Green channel
+ } else if (c == 2) {
+ value = (rgb & 0xFF) / 255.0f; // Blue channel
+ }
+ inputData[c * targetWidth * targetHeight + y * targetWidth + x] = value;
+ }
+ }
+ }
+
+ // 结束计时
+ long preprocessEnd = System.currentTimeMillis();
+ System.out.println("图像预处理耗时:" + (preprocessEnd - preprocessStart) + " ms");
+
+ return inputData;
+ }
+
+ // 非极大值抑制(NMS)方法
+ private List nonMaximumSuppression(List boxes, float iouThreshold) {
+ // 按置信度排序(从高到低)
+ boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence()));
+
+ List result = new ArrayList<>();
+
+ while (!boxes.isEmpty()) {
+ BoundingBox bestBox = boxes.remove(0);
+ result.add(bestBox);
+
+ Iterator iterator = boxes.iterator();
+ while (iterator.hasNext()) {
+ BoundingBox box = iterator.next();
+ if (computeIoU(bestBox, box) > iouThreshold) {
+ iterator.remove();
+ }
+ }
+ }
+
+ return result;
+ }
+}
diff --git a/src/main/java/com/ly/onnx/model/BoundingBox.java b/src/main/java/com/ly/onnx/model/BoundingBox.java
new file mode 100644
index 0000000..332a888
--- /dev/null
+++ b/src/main/java/com/ly/onnx/model/BoundingBox.java
@@ -0,0 +1,27 @@
+package com.ly.onnx.model;
+
+import lombok.Data;
+
+@Data
+public class BoundingBox {
+ private int x;
+ private int y;
+ private int width;
+ private int height;
+ private String label;
+ private float confidence;
+
+    // Explicit constructor; accessors are generated by Lombok's @Data annotation.
+
+ public BoundingBox(int x, int y, int width, int height, String label, float confidence) {
+ this.x = x;
+ this.y = y;
+ this.width = width;
+ this.height = height;
+ this.label = label;
+ this.confidence = confidence;
+ }
+
+    // Getters and setters are generated at compile time by Lombok's @Data,
+    // so no hand-written accessor methods are needed here.
+}
diff --git a/src/main/java/com/ly/onnx/model/InferenceResult.java b/src/main/java/com/ly/onnx/model/InferenceResult.java
new file mode 100644
index 0000000..f9e8ea8
--- /dev/null
+++ b/src/main/java/com/ly/onnx/model/InferenceResult.java
@@ -0,0 +1,18 @@
+package com.ly.onnx.model;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class InferenceResult {
+ private List boundingBoxes = new ArrayList<>();
+
+ public List getBoundingBoxes() {
+ return boundingBoxes;
+ }
+
+ public void setBoundingBoxes(List boundingBoxes) {
+ this.boundingBoxes = boundingBoxes;
+ }
+
+    // Additional fields and helper methods can be added here as needed.
+}
diff --git a/src/main/java/com/ly/onnx/model/ModelInfo.java b/src/main/java/com/ly/onnx/model/ModelInfo.java
new file mode 100644
index 0000000..3d0865b
--- /dev/null
+++ b/src/main/java/com/ly/onnx/model/ModelInfo.java
@@ -0,0 +1,39 @@
+package com.ly.onnx.model;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.List;
+
+public class ModelInfo {
+ private String modelFilePath;
+ private String labelFilePath;
+ private List labels;
+
+ public ModelInfo(String modelFilePath, String labelFilePath) {
+ this.modelFilePath = modelFilePath;
+ this.labelFilePath = labelFilePath;
+ try {
+ this.labels = Files.readAllLines(Paths.get(labelFilePath));
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ public String getModelFilePath() {
+ return modelFilePath;
+ }
+
+ public String getLabelFilePath() {
+ return labelFilePath;
+ }
+
+ public List getLabels() {
+ return labels;
+ }
+
+ @Override
+ public String toString() {
+ return "模型文件: " + modelFilePath + "\n标签文件: " + labelFilePath;
+ }
+}
diff --git a/src/main/java/com/ly/onnx/utils/DrawImagesUtils.java b/src/main/java/com/ly/onnx/utils/DrawImagesUtils.java
new file mode 100644
index 0000000..240da5d
--- /dev/null
+++ b/src/main/java/com/ly/onnx/utils/DrawImagesUtils.java
@@ -0,0 +1,53 @@
+package com.ly.onnx.utils;
+
+import com.ly.onnx.model.BoundingBox;
+import com.ly.onnx.model.InferenceResult;
+import org.opencv.core.Mat;
+import org.opencv.core.Point;
+import org.opencv.core.Scalar;
+import org.opencv.imgproc.Imgproc;
+
+import java.awt.image.BufferedImage;
+import java.util.List;
+
+public class DrawImagesUtils {
+
+
+ public static void drawInferenceResult(BufferedImage bufferedImage, List result) {
+
+ }
+
+ // 在 Mat 上绘制推理结果
+ public static void drawInferenceResult(Mat image, List inferenceResults) {
+ for (InferenceResult result : inferenceResults) {
+ for (BoundingBox box : result.getBoundingBoxes()) {
+ // 绘制矩形
+ Point topLeft = new Point(box.getX(), box.getY());
+ Point bottomRight = new Point(box.getX() + box.getWidth(), box.getY() + box.getHeight());
+ Imgproc.rectangle(image, topLeft, bottomRight, new Scalar(0, 0, 255), 2); // 红色边框
+
+ // 绘制标签
+ String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence());
+ int font = Imgproc.FONT_HERSHEY_SIMPLEX;
+ double fontScale = 0.5;
+ int thickness = 1;
+
+ // 计算文字大小
+ int[] baseLine = new int[1];
+ org.opencv.core.Size labelSize = Imgproc.getTextSize(label, font, fontScale, thickness, baseLine);
+
+ // 确保文字不会超出图像
+ int y = Math.max((int) topLeft.y, (int) labelSize.height);
+
+ // 绘制文字背景
+ Imgproc.rectangle(image, new Point(topLeft.x, y - labelSize.height),
+ new Point(topLeft.x + labelSize.width, y + baseLine[0]),
+ new Scalar(0, 0, 255), Imgproc.FILLED);
+
+ // 绘制文字
+ Imgproc.putText(image, label, new Point(topLeft.x, y),
+ font, fontScale, new Scalar(255, 255, 255), thickness);
+ }
+ }
+ }
+}
diff --git a/src/main/java/com/ly/onnx/utils/ImageUtils.java b/src/main/java/com/ly/onnx/utils/ImageUtils.java
new file mode 100644
index 0000000..0fe7286
--- /dev/null
+++ b/src/main/java/com/ly/onnx/utils/ImageUtils.java
@@ -0,0 +1,106 @@
+package com.ly.onnx.utils;
+
+import org.bytedeco.javacv.Frame;
+import org.opencv.core.*;
+import org.opencv.imgproc.Imgproc;
+
+import java.awt.image.BufferedImage;
+import java.nio.Buffer;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+
+public class ImageUtils {
+
+ // 辅助方法:将 BufferedImage 转换为浮点数组(根据您的模型需求)
+ private static float[] preprocessImage(BufferedImage image) {
+ int width = image.getWidth();
+ int height = image.getHeight();
+ float[] result = new float[width * height * 3]; // 假设是 RGB 图像
+ int idx = 0;
+
+ for (int y = 0; y < height; y++) {
+ for (int x = 0; x < width; x++) {
+ int pixel = image.getRGB(x, y);
+ // 分别获取 R, G, B 值并归一化(假设归一化到 [0, 1])
+ result[idx++] = ((pixel >> 16) & 0xFF) / 255.0f; // Red
+ result[idx++] = ((pixel >> 8) & 0xFF) / 255.0f; // Green
+ result[idx++] = (pixel & 0xFF) / 255.0f; // Blue
+ }
+ }
+ return result;
+ }
+
+
+
+
+ public static float[] frameToFloatArray(Frame frame) {
+ // 获取 Frame 的宽度和高度
+ int width = frame.imageWidth;
+ int height = frame.imageHeight;
+
+ // 获取 Frame 的像素数据
+ Buffer buffer = frame.image[0]; // 获取图像数据缓冲区
+        ByteBuffer byteBuffer = (ByteBuffer) buffer; // NOTE(review): assumes interleaved 8-bit pixel data; a non-ByteBuffer frame (e.g. 16-bit) would throw ClassCastException — confirm upstream frame format
+
+ // 创建 float 数组来存储图像的 RGB 值
+ float[] result = new float[width * height * 3]; // 假设是 RGB 格式图像
+ int idx = 0;
+
+ // 遍历每个像素,提取 R, G, B 值并归一化到 [0, 1]
+ for (int i = 0; i < byteBuffer.capacity(); i += 3) {
+ // 提取 RGB 通道数据
+ int r = byteBuffer.get(i) & 0xFF; // Red
+ int g = byteBuffer.get(i + 1) & 0xFF; // Green
+ int b = byteBuffer.get(i + 2) & 0xFF; // Blue
+
+ // 将 RGB 值归一化为 float 并存入数组
+ result[idx++] = r / 255.0f;
+ result[idx++] = g / 255.0f;
+ result[idx++] = b / 255.0f;
+ }
+
+ return result;
+ }
+ // 将 Mat 转换为 BufferedImage
+ public static BufferedImage matToBufferedImage(Mat mat) {
+ int type = BufferedImage.TYPE_3BYTE_BGR;
+ if (mat.channels() == 1) {
+ type = BufferedImage.TYPE_BYTE_GRAY;
+ }
+ int bufferSize = mat.channels() * mat.cols() * mat.rows();
+ byte[] buffer = new byte[bufferSize];
+ mat.get(0, 0, buffer); // 获取所有像素
+ BufferedImage image = new BufferedImage(mat.cols(), mat.rows(), type);
+ final byte[] targetPixels = ((java.awt.image.DataBufferByte) image.getRaster().getDataBuffer()).getData();
+ System.arraycopy(buffer, 0, targetPixels, 0, buffer.length);
+ return image;
+ }
+
+ // 将 Mat 转换为 float 数组,适用于推理
+ public static float[] matToFloatArray(Mat mat) {
+ // 假设 InferenceEngine 需要 RGB 格式的图像
+ Mat rgbMat = new Mat();
+ Imgproc.cvtColor(mat, rgbMat, Imgproc.COLOR_BGR2RGB);
+
+ // 假设图像已经被预处理(缩放、归一化等),否则需要在这里添加预处理步骤
+
+ // 将 Mat 数据转换为 float 数组
+ int channels = rgbMat.channels();
+ int rows = rgbMat.rows();
+ int cols = rgbMat.cols();
+ float[] floatData = new float[channels * rows * cols];
+ byte[] byteData = new byte[channels * rows * cols];
+ rgbMat.get(0, 0, byteData);
+ for (int i = 0; i < floatData.length; i++) {
+ // 将 unsigned byte 转换为 float [0,1]
+ floatData[i] = (byteData[i] & 0xFF) / 255.0f;
+ }
+ rgbMat.release();
+ return floatData;
+ }
+
+
+
+
+}
diff --git a/src/main/java/com/ly/play/VideoPlayer.java b/src/main/java/com/ly/play/VideoPlayer.java
deleted file mode 100644
index 1e91524..0000000
--- a/src/main/java/com/ly/play/VideoPlayer.java
+++ /dev/null
@@ -1,280 +0,0 @@
-package com.ly.play;
-
-import com.ly.layout.VideoPanel;
-import org.bytedeco.javacv.*;
-
-import javax.swing.*;
-import java.awt.image.BufferedImage;
-
-public class VideoPlayer {
- private FrameGrabber grabber;
- private Java2DFrameConverter converter = new Java2DFrameConverter();
- private boolean isPlaying = false;
- private boolean isPaused = false;
- private Thread videoThread;
- private VideoPanel videoPanel;
-
- private long videoDuration = 0; // 毫秒
- private long currentTimestamp = 0; // 毫秒
-
- public VideoPlayer(VideoPanel videoPanel) {
- this.videoPanel = videoPanel;
- }
-
- // 加载视频或流
- // 加载视频或流
- public void loadVideo(String videoFilePathOrStreamUrl) throws Exception {
- stopVideo();
-
-
- if (videoFilePathOrStreamUrl.equals("0")) {
- int cameraIndex = Integer.parseInt(videoFilePathOrStreamUrl);
- grabber = new OpenCVFrameGrabber(cameraIndex);
- grabber.start();
- videoDuration = 0; // 摄像头没有固定的时长
- playVideo();
- } else {
- // 输入不是数字,尝试使用 FFmpegFrameGrabber 打开流或视频文件
- grabber = new FFmpegFrameGrabber(videoFilePathOrStreamUrl);
- grabber.start();
- videoDuration = grabber.getLengthInTime() / 1000; // 转换为毫秒
- }
-
-
- // 显示第一帧
- Frame frame;
- if (grabber instanceof OpenCVFrameGrabber) {
- frame = grabber.grab();
- } else {
- frame = grabber.grab();
- }
- if (frame != null && frame.image != null) {
- BufferedImage bufferedImage = converter.getBufferedImage(frame);
- videoPanel.updateImage(bufferedImage);
- currentTimestamp = 0;
- }
-
- // 重置到视频开始位置
- if (grabber instanceof FFmpegFrameGrabber) {
- grabber.setTimestamp(0);
- }
- currentTimestamp = 0;
- }
-
- public void playVideo() {
- if (grabber == null) {
- JOptionPane.showMessageDialog(null, "请先加载视频文件或流。", "提示", JOptionPane.WARNING_MESSAGE);
- return;
- }
-
- if (isPlaying) {
- if (isPaused) {
- isPaused = false; // 恢复播放
- }
- return;
- }
-
- isPlaying = true;
- isPaused = false;
-
- videoThread = new Thread(() -> {
- try {
- if (grabber instanceof OpenCVFrameGrabber) {
- // 摄像头捕获
- while (isPlaying) {
- if (isPaused) {
- Thread.sleep(100);
- continue;
- }
-
- Frame frame = grabber.grab();
- if (frame == null) {
- isPlaying = false;
- break;
- }
-
- BufferedImage bufferedImage = converter.getBufferedImage(frame);
- if (bufferedImage != null) {
- videoPanel.updateImage(bufferedImage);
- }
- }
- } else {
- // 视频文件或流
- double frameRate = grabber.getFrameRate();
- if (frameRate <= 0 || Double.isNaN(frameRate)) {
- frameRate = 25; // 默认帧率
- }
- long frameInterval = (long) (1000 / frameRate); // 每帧间隔时间(毫秒)
- long startTime = System.currentTimeMillis();
- long frameCount = 0;
-
- while (isPlaying) {
- if (isPaused) {
- Thread.sleep(100);
- startTime += 100; // 调整开始时间以考虑暂停时间
- continue;
- }
-
- Frame frame = grabber.grab();
- if (frame == null) {
- // 视频播放结束
- isPlaying = false;
- break;
- }
-
- BufferedImage bufferedImage = converter.getBufferedImage(frame);
- if (bufferedImage != null) {
- videoPanel.updateImage(bufferedImage);
-
- // 更新当前帧时间戳
- frameCount++;
- long expectedTime = frameCount * frameInterval;
- long actualTime = System.currentTimeMillis() - startTime;
-
- currentTimestamp = grabber.getTimestamp() / 1000;
-
- // 如果实际时间落后于预期时间,进行调整
- if (actualTime < expectedTime) {
- Thread.sleep(expectedTime - actualTime);
- }
- }
- }
- }
-
- // 视频播放完毕后,停止播放
- isPlaying = false;
-
- } catch (Exception ex) {
- ex.printStackTrace();
- }
- });
- videoThread.start();
- }
-
- // 暂停视频
- public void pauseVideo() {
- if (!isPlaying) {
- return;
- }
- isPaused = true;
- }
-
- // 重播视频
- public void replayVideo() {
- try {
- if (grabber instanceof FFmpegFrameGrabber) {
- grabber.setTimestamp(0); // 重置到视频开始位置
- grabber.flush(); // 清除缓存
- currentTimestamp = 0;
-
- // 显示第一帧
- Frame frame = grabber.grab();
- if (frame != null && frame.image != null) {
- BufferedImage bufferedImage = converter.getBufferedImage(frame);
- videoPanel.updateImage(bufferedImage);
- }
-
- playVideo(); // 开始播放
- } else if (grabber instanceof OpenCVFrameGrabber) {
- // 对于摄像头,重播相当于重新开始播放
- playVideo();
- }
- } catch (Exception e) {
- e.printStackTrace();
- JOptionPane.showMessageDialog(null, "重播失败: " + e.getMessage(), "错误", JOptionPane.ERROR_MESSAGE);
- }
- }
-
- // 停止视频
- public void stopVideo() {
- isPlaying = false;
- isPaused = false;
-
- if (videoThread != null && videoThread.isAlive()) {
- try {
- videoThread.join();
- } catch (InterruptedException e) {
- e.printStackTrace();
- }
- }
-
- if (grabber != null) {
- try {
- grabber.stop();
- grabber.release();
- } catch (Exception ex) {
- ex.printStackTrace();
- }
- grabber = null;
- }
- }
-
- // 快进或后退
- public void seekTo(long seekTime) {
- if (grabber == null) return;
- if (!(grabber instanceof FFmpegFrameGrabber)) {
- JOptionPane.showMessageDialog(null, "此操作仅支持视频文件和流。", "提示", JOptionPane.WARNING_MESSAGE);
- return;
- }
- try {
- isPaused = false; // 取消暂停
- isPlaying = false; // 停止当前播放线程
- if (videoThread != null && videoThread.isAlive()) {
- videoThread.join();
- }
-
- grabber.setTimestamp(seekTime * 1000); // 转换为微秒
- grabber.flush(); // 清除缓存
-
- Frame frame;
- do {
- frame = grabber.grab();
- if (frame == null) {
- break;
- }
- } while (frame.image == null); // 跳过没有图像的帧
-
- if (frame != null && frame.image != null) {
- BufferedImage bufferedImage = converter.getBufferedImage(frame);
- videoPanel.updateImage(bufferedImage);
-
- // 更新当前帧时间戳
- currentTimestamp = grabber.getTimestamp() / 1000;
- }
-
- // 重新开始播放
- playVideo();
-
- } catch (Exception ex) {
- ex.printStackTrace();
- }
- }
-
- // 快进
- public void fastForward(long milliseconds) {
- if (!(grabber instanceof FFmpegFrameGrabber)) {
- JOptionPane.showMessageDialog(null, "此操作仅支持视频文件和流。", "提示", JOptionPane.WARNING_MESSAGE);
- return;
- }
- long newTime = Math.min(currentTimestamp + milliseconds, videoDuration);
- seekTo(newTime);
- }
-
- // 后退
- public void rewind(long milliseconds) {
- if (!(grabber instanceof FFmpegFrameGrabber)) {
- JOptionPane.showMessageDialog(null, "此操作仅支持视频文件和流。", "提示", JOptionPane.WARNING_MESSAGE);
- return;
- }
- long newTime = Math.max(currentTimestamp - milliseconds, 0);
- seekTo(newTime);
- }
-
- public long getVideoDuration() {
- return videoDuration;
- }
-
- public FrameGrabber getGrabber() {
- return grabber;
- }
-}
diff --git a/src/main/java/com/ly/play/ff/VideoPlayer.java b/src/main/java/com/ly/play/ff/VideoPlayer.java
new file mode 100644
index 0000000..1d5a46b
--- /dev/null
+++ b/src/main/java/com/ly/play/ff/VideoPlayer.java
@@ -0,0 +1,328 @@
+//package com.ly.play.ff;
+//
+//import com.ly.layout.VideoPanel;
+//import com.ly.model_load.ModelManager;
+//import com.ly.onnx.engine.InferenceEngine;
+//import com.ly.onnx.model.InferenceResult;
+//import com.ly.onnx.utils.DrawImagesUtils;
+//import com.ly.onnx.utils.ImageUtils;
+//import org.bytedeco.javacv.*;
+//
+//import javax.swing.*;
+//import java.awt.image.BufferedImage;
+//import java.util.ArrayList;
+//import java.util.List;
+//
+//public class VideoPlayer {
+// private FrameGrabber grabber;
+// private Java2DFrameConverter converter = new Java2DFrameConverter();
+// private boolean isPlaying = false;
+// private boolean isPaused = false;
+// private Thread videoThread;
+// private VideoPanel videoPanel;
+//
+// private long videoDuration = 0; // 毫秒
+// private long currentTimestamp = 0; // 毫秒
+//
+// ModelManager modelManager;
+// private List inferenceEngines = new ArrayList<>();
+//
+// public VideoPlayer(VideoPanel videoPanel, ModelManager modelManager) {
+// this.videoPanel = videoPanel;
+// this.modelManager = modelManager;
+// System.out.println();
+// }
+//
+// // 加载视频或流
+// public void loadVideo(String videoFilePathOrStreamUrl) throws Exception {
+// stopVideo();
+// if (videoFilePathOrStreamUrl.equals("0")) {
+// int cameraIndex = Integer.parseInt(videoFilePathOrStreamUrl);
+// grabber = new OpenCVFrameGrabber(cameraIndex);
+// grabber.start();
+// videoDuration = 0; // 摄像头没有固定的时长
+// playVideo();
+// } else {
+// // 输入不是数字,尝试使用 FFmpegFrameGrabber 打开流或视频文件
+// grabber = new FFmpegFrameGrabber(videoFilePathOrStreamUrl);
+// grabber.start();
+// videoDuration = grabber.getLengthInTime() / 1000; // 转换为毫秒
+// }
+//
+//
+// // 显示第一帧
+// Frame frame;
+// if (grabber instanceof OpenCVFrameGrabber) {
+// frame = grabber.grab();
+// } else {
+// frame = grabber.grab();
+// }
+// if (frame != null && frame.image != null) {
+// BufferedImage bufferedImage = converter.getBufferedImage(frame);
+// videoPanel.updateImage(bufferedImage);
+// currentTimestamp = 0;
+// }
+//
+// // 重置到视频开始位置
+// if (grabber instanceof FFmpegFrameGrabber) {
+// grabber.setTimestamp(0);
+// }
+// currentTimestamp = 0;
+// }
+//
+//
+//
+//
+//
+// //播放
+// public void playVideo() {
+// if (grabber == null) {
+// JOptionPane.showMessageDialog(null, "请先加载视频文件或流。", "提示", JOptionPane.WARNING_MESSAGE);
+// return;
+// }
+//
+// if (inferenceEngines == null){
+//            JOptionPane.showMessageDialog(null, "请先加载模型文件。", "提示", JOptionPane.WARNING_MESSAGE);
+// return;
+// }
+//
+// if (isPlaying) {
+// if (isPaused) {
+// isPaused = false; // 恢复播放
+// }
+// return;
+// }
+//
+// isPlaying = true;
+// isPaused = false;
+//
+// videoThread = new Thread(() -> {
+// try {
+// if (grabber instanceof OpenCVFrameGrabber) {
+// // 摄像头捕获
+// while (isPlaying) {
+// if (isPaused) {
+// Thread.sleep(10);
+// continue;
+// }
+//
+// Frame frame = grabber.grab();
+// if (frame == null) {
+// isPlaying = false;
+// break;
+// }
+//
+// BufferedImage bufferedImage = converter.getBufferedImage(frame);
+// List inferenceResults = new ArrayList<>();
+// if (bufferedImage != null) {
+// float[] inputData = ImageUtils.frameToFloatArray(frame);
+// for (InferenceEngine inferenceEngine : inferenceEngines) {
+// inferenceResults.add(inferenceEngine.infer(inputData,640,640));
+// }
+// //绘制
+// DrawImagesUtils.drawInferenceResult(bufferedImage,inferenceResults);
+// //更新绘制后图像
+// videoPanel.updateImage(bufferedImage);
+// }
+// }
+// } else {
+// // 视频文件或流
+// double frameRate = grabber.getFrameRate();
+// if (frameRate <= 0 || Double.isNaN(frameRate)) {
+// frameRate = 25; // 默认帧率
+// }
+// long frameInterval = (long) (1000 / frameRate); // 每帧间隔时间(毫秒)
+// long startTime = System.currentTimeMillis();
+// long frameCount = 0;
+//
+// while (isPlaying) {
+// if (isPaused) {
+// Thread.sleep(100);
+// startTime += 100; // 调整开始时间以考虑暂停时间
+// continue;
+// }
+//
+// Frame frame = grabber.grab();
+// if (frame == null) {
+// // 视频播放结束
+// isPlaying = false;
+// break;
+// }
+//
+//
+//
+// BufferedImage bufferedImage = converter.getBufferedImage(frame);
+//
+//
+// List inferenceResults = new ArrayList<>();
+// if (bufferedImage != null) {
+// float[] inputData = ImageUtils.frameToFloatArray(frame);
+// for (InferenceEngine inferenceEngine : inferenceEngines) {
+// inferenceResults.add(inferenceEngine.infer(inputData,640,640));
+// }
+// //绘制
+// DrawImagesUtils.drawInferenceResult(bufferedImage,inferenceResults);
+// //更新绘制后图像
+// videoPanel.updateImage(bufferedImage);
+// }
+//
+// if (bufferedImage != null) {
+// videoPanel.updateImage(bufferedImage);
+//
+// // 更新当前帧时间戳
+// frameCount++;
+// long expectedTime = frameCount * frameInterval;
+// long actualTime = System.currentTimeMillis() - startTime;
+//
+// currentTimestamp = grabber.getTimestamp() / 1000;
+//
+// // 如果实际时间落后于预期时间,进行调整
+// if (actualTime < expectedTime) {
+// Thread.sleep(expectedTime - actualTime);
+// }
+// }
+// }
+// }
+//
+// // 视频播放完毕后,停止播放
+// isPlaying = false;
+//
+// } catch (Exception ex) {
+// ex.printStackTrace();
+// }
+// });
+// videoThread.start();
+// }
+//
+// // 暂停视频
+// public void pauseVideo() {
+// if (!isPlaying) {
+// return;
+// }
+// isPaused = true;
+// }
+//
+// // 重播视频
+// public void replayVideo() {
+// try {
+// if (grabber instanceof FFmpegFrameGrabber) {
+// grabber.setTimestamp(0); // 重置到视频开始位置
+// grabber.flush(); // 清除缓存
+// currentTimestamp = 0;
+//
+// // 显示第一帧
+// Frame frame = grabber.grab();
+// if (frame != null && frame.image != null) {
+// BufferedImage bufferedImage = converter.getBufferedImage(frame);
+// videoPanel.updateImage(bufferedImage);
+// }
+//
+// playVideo(); // 开始播放
+// } else if (grabber instanceof OpenCVFrameGrabber) {
+// // 对于摄像头,重播相当于重新开始播放
+// playVideo();
+// }
+// } catch (Exception e) {
+// e.printStackTrace();
+// JOptionPane.showMessageDialog(null, "重播失败: " + e.getMessage(), "错误", JOptionPane.ERROR_MESSAGE);
+// }
+// }
+//
+// // 停止视频
+// public void stopVideo() {
+// isPlaying = false;
+// isPaused = false;
+//
+// if (videoThread != null && videoThread.isAlive()) {
+// try {
+// videoThread.join();
+// } catch (InterruptedException e) {
+// e.printStackTrace();
+// }
+// }
+//
+// if (grabber != null) {
+// try {
+// grabber.stop();
+// grabber.release();
+// } catch (Exception ex) {
+// ex.printStackTrace();
+// }
+// grabber = null;
+// }
+// }
+//
+// // 快进或后退
+// public void seekTo(long seekTime) {
+// if (grabber == null) return;
+// if (!(grabber instanceof FFmpegFrameGrabber)) {
+// JOptionPane.showMessageDialog(null, "此操作仅支持视频文件和流。", "提示", JOptionPane.WARNING_MESSAGE);
+// return;
+// }
+// try {
+// isPaused = false; // 取消暂停
+// isPlaying = false; // 停止当前播放线程
+// if (videoThread != null && videoThread.isAlive()) {
+// videoThread.join();
+// }
+//
+// grabber.setTimestamp(seekTime * 1000); // 转换为微秒
+// grabber.flush(); // 清除缓存
+//
+// Frame frame;
+// do {
+// frame = grabber.grab();
+// if (frame == null) {
+// break;
+// }
+// } while (frame.image == null); // 跳过没有图像的帧
+//
+// if (frame != null && frame.image != null) {
+// BufferedImage bufferedImage = converter.getBufferedImage(frame);
+// videoPanel.updateImage(bufferedImage);
+//
+// // 更新当前帧时间戳
+// currentTimestamp = grabber.getTimestamp() / 1000;
+// }
+//
+// // 重新开始播放
+// playVideo();
+//
+// } catch (Exception ex) {
+// ex.printStackTrace();
+// }
+// }
+//
+// // 快进
+// public void fastForward(long milliseconds) {
+// if (!(grabber instanceof FFmpegFrameGrabber)) {
+// JOptionPane.showMessageDialog(null, "此操作仅支持视频文件和流。", "提示", JOptionPane.WARNING_MESSAGE);
+// return;
+// }
+// long newTime = Math.min(currentTimestamp + milliseconds, videoDuration);
+// seekTo(newTime);
+// }
+//
+// // 后退
+// public void rewind(long milliseconds) {
+// if (!(grabber instanceof FFmpegFrameGrabber)) {
+// JOptionPane.showMessageDialog(null, "此操作仅支持视频文件和流。", "提示", JOptionPane.WARNING_MESSAGE);
+// return;
+// }
+// long newTime = Math.max(currentTimestamp - milliseconds, 0);
+// seekTo(newTime);
+// }
+//
+// public long getVideoDuration() {
+// return videoDuration;
+// }
+//
+// public FrameGrabber getGrabber() {
+// return grabber;
+// }
+//
+// public void addInferenceEngines(InferenceEngine inferenceEngine){
+// this.inferenceEngines.add(inferenceEngine);
+// }
+//
+//}
diff --git a/src/main/java/com/ly/play/opencv/VideoPlayer.java b/src/main/java/com/ly/play/opencv/VideoPlayer.java
new file mode 100644
index 0000000..d341b70
--- /dev/null
+++ b/src/main/java/com/ly/play/opencv/VideoPlayer.java
@@ -0,0 +1,427 @@
+package com.ly.play.opencv;
+
+import com.ly.layout.VideoPanel;
+import com.ly.model_load.ModelManager;
+import com.ly.onnx.engine.InferenceEngine;
+import com.ly.onnx.model.InferenceResult;
+import com.ly.onnx.utils.DrawImagesUtils;
+import org.opencv.core.CvType;
+import org.opencv.core.Mat;
+import org.opencv.core.Rect;
+import org.opencv.core.Size;
+import org.opencv.imgproc.Imgproc;
+import org.opencv.videoio.VideoCapture;
+import org.opencv.videoio.Videoio;
+
+import javax.swing.*;
+import java.awt.image.BufferedImage;
+import java.awt.image.DataBufferByte;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.*;
+
+import static com.ly.onnx.utils.ImageUtils.matToBufferedImage;
+
+public class VideoPlayer {
+    // NOTE(review): generic type parameters (e.g. List<InferenceEngine>,
+    // BlockingQueue<FrameData>) appear to have been stripped from this diff;
+    // restore them when applying.
+    static {
+        // Load the OpenCV native library once per JVM.
+        nu.pattern.OpenCV.loadLocally();
+        String OS = System.getProperty("os.name").toLowerCase();
+        if (OS.contains("win")) {
+            // On Windows, also load the bundled FFmpeg video-I/O plugin so
+            // VideoCapture can open files/streams via CAP_FFMPEG.
+            // NOTE(review): getSystemResource(...).getPath() only works when the
+            // resource is an on-disk file; inside a jar (or a path containing
+            // spaces, which getPath() percent-encodes) System.load will fail —
+            // confirm packaging, or extract the DLL to a temp file first.
+            System.load(ClassLoader.getSystemResource("lib/win/opencv_videoio_ffmpeg470_64.dll").getPath());
+        }
+    }
+
+    private VideoCapture videoCapture;
+    // volatile: written by the UI thread, read by both worker threads.
+    private volatile boolean isPlaying = false;
+    private volatile boolean isPaused = false;
+    private Thread frameReadingThread;   // decodes + preprocesses frames
+    private Thread inferenceThread;      // consumes frames, draws results
+    private VideoPanel videoPanel;
+
+    private long videoDuration = 0;      // total length in milliseconds (0 for live cameras)
+    private long currentTimestamp = 0;   // current playback position in milliseconds
+
+    private ModelManager modelManager;
+    private List inferenceEngines = new ArrayList<>();
+
+    // Bounded hand-off buffer between the reader and the inference thread;
+    // put() blocks when full, throttling the reader to inference speed.
+    private BlockingQueue frameDataQueue = new LinkedBlockingQueue<>(10); // 队列容量可根据需要调整
+
+    /**
+     * @param videoPanel   surface the decoded/annotated frames are painted on
+     * @param modelManager source of loaded models (kept for the UI; inference
+     *                     engines are registered via addInferenceEngines)
+     */
+    public VideoPlayer(VideoPanel videoPanel, ModelManager modelManager) {
+        this.videoPanel = videoPanel;
+        this.modelManager = modelManager;
+    }
+
+    // 加载视频或流
+    /**
+     * Opens a video source and shows its first frame.
+     *
+     * @param videoFilePathOrStreamUrl "0" for the default camera, otherwise a
+     *                                 file path or stream URL opened via FFmpeg
+     * @throws Exception if the source cannot be opened or the first frame read
+     */
+    public void loadVideo(String videoFilePathOrStreamUrl) throws Exception {
+        stopVideo();
+        if (videoFilePathOrStreamUrl.equals("0")) {
+            int cameraIndex = Integer.parseInt(videoFilePathOrStreamUrl);
+            videoCapture = new VideoCapture(cameraIndex);
+            if (!videoCapture.isOpened()) {
+                throw new Exception("无法打开摄像头");
+            }
+            videoDuration = 0; // a live camera has no fixed duration
+            playVideo();
+            // FIX: return here. The original fell through and read the first
+            // frame / reset CAP_PROP_POS_FRAMES on this thread while the
+            // playback thread started above was already reading from the same
+            // VideoCapture, which is not thread-safe (and seeking a live
+            // camera is meaningless anyway).
+            return;
+        }
+
+        // Not a camera index: open as a file or stream via the FFmpeg backend.
+        videoCapture = new VideoCapture(videoFilePathOrStreamUrl, Videoio.CAP_FFMPEG);
+        if (!videoCapture.isOpened()) {
+            throw new Exception("无法打开视频文件:" + videoFilePathOrStreamUrl);
+        }
+        double frameCount = videoCapture.get(org.opencv.videoio.Videoio.CAP_PROP_FRAME_COUNT);
+        double fps = videoCapture.get(org.opencv.videoio.Videoio.CAP_PROP_FPS);
+        if (fps <= 0 || Double.isNaN(fps)) {
+            fps = 25; // fall back to a sane default when metadata is missing
+        }
+        videoDuration = (long) (frameCount / fps * 1000); // milliseconds
+
+        // Show the first frame so the panel is not blank before playback.
+        Mat frame = new Mat();
+        if (videoCapture.read(frame)) {
+            BufferedImage bufferedImage = matToBufferedImage(frame);
+            videoPanel.updateImage(bufferedImage);
+            currentTimestamp = 0;
+        } else {
+            throw new Exception("无法读取第一帧");
+        }
+
+        // Rewind so playback starts from the beginning.
+        videoCapture.set(org.opencv.videoio.Videoio.CAP_PROP_POS_FRAMES, 0);
+        currentTimestamp = 0;
+    }
+
+    // 播放
+    /**
+     * Starts (or resumes) playback. Spawns two threads: a reader that decodes
+     * frames, preprocesses them and feeds the bounded queue, and an inference
+     * thread that drains the queue, runs the engines, draws results and
+     * updates the panel. Calling this while already playing only un-pauses.
+     */
+    public void playVideo() {
+        if (videoCapture == null || !videoCapture.isOpened()) {
+            JOptionPane.showMessageDialog(null, "请先加载视频文件或流。", "提示", JOptionPane.WARNING_MESSAGE);
+            return;
+        }
+
+        if (isPlaying) {
+            if (isPaused) {
+                isPaused = false; // resume playback
+            }
+            return;
+        }
+
+        isPlaying = true;
+        isPaused = false;
+
+        frameDataQueue.clear(); // drop stale frames before starting
+
+        // Reader thread: decode, preprocess, enqueue.
+        frameReadingThread = new Thread(() -> {
+            try {
+                double fps = videoCapture.get(org.opencv.videoio.Videoio.CAP_PROP_FPS);
+                if (fps <= 0 || Double.isNaN(fps)) {
+                    fps = 25; // default frame rate when metadata is missing
+                }
+                long frameDelay = (long) (1000 / fps);
+
+                while (isPlaying) {
+                    if (Thread.currentThread().isInterrupted()) {
+                        break;
+                    }
+                    if (isPaused) {
+                        Thread.sleep(10);
+                        continue;
+                    }
+
+                    Mat frame = new Mat();
+                    if (!videoCapture.read(frame) || frame.empty()) {
+                        // End of stream (or read failure): stop playback.
+                        isPlaying = false;
+                        break;
+                    }
+
+                    long startTime = System.currentTimeMillis();
+                    BufferedImage bufferedImage = matToBufferedImage(frame);
+
+                    if (bufferedImage != null) {
+                        float[] floats = preprocessAndConvertBufferedImage(bufferedImage);
+
+                        // Pair the displayable image with its model input.
+                        FrameData frameData = new FrameData(bufferedImage, floats);
+                        frameDataQueue.put(frameData); // blocks when the queue is full
+                    }
+
+                    // Track the current playback position (ms).
+                    currentTimestamp = (long) videoCapture.get(org.opencv.videoio.Videoio.CAP_PROP_POS_MSEC);
+
+                    // Pace the loop so decoding roughly matches the source fps.
+                    long processingTime = System.currentTimeMillis() - startTime;
+                    long sleepTime = frameDelay - processingTime;
+                    if (sleepTime > 0) {
+                        Thread.sleep(sleepTime);
+                    }
+                }
+            } catch (Exception ex) {
+                ex.printStackTrace();
+            } finally {
+                isPlaying = false;
+            }
+        });
+
+        // Inference thread: drain queue, run engines, draw, display.
+        inferenceThread = new Thread(() -> {
+            try {
+                // Keep draining after the reader stops until the queue is empty.
+                while (isPlaying || !frameDataQueue.isEmpty()) {
+                    if (Thread.currentThread().isInterrupted()) {
+                        break;
+                    }
+                    if (isPaused) {
+                        Thread.sleep(100);
+                        continue;
+                    }
+
+                    FrameData frameData = frameDataQueue.poll(100, TimeUnit.MILLISECONDS); // bounded wait
+                    if (frameData == null) {
+                        continue; // no data yet — re-check isPlaying
+                    }
+
+                    BufferedImage bufferedImage = frameData.image;
+                    float[] floatArray = frameData.floatArray;
+
+                    // Run every registered engine on the preprocessed frame.
+                    // NOTE(review): the actual infer call is commented out, so
+                    // inferenceResults is always empty and nothing is drawn —
+                    // re-enable once InferenceEngine.infer is wired up.
+                    List inferenceResults = new ArrayList<>();
+                    for (InferenceEngine inferenceEngine : inferenceEngines) {
+                        // 假设 InferenceEngine 有 infer 方法接受 float 数组
+//                        inferenceResults.add(inferenceEngine.infer(floatArray, 640, 640));
+                    }
+                    // Overlay detection results on the frame.
+                    DrawImagesUtils.drawInferenceResult(bufferedImage, inferenceResults);
+
+                    // Push the annotated frame to the UI.
+                    videoPanel.updateImage(bufferedImage);
+                }
+            } catch (Exception ex) {
+                ex.printStackTrace();
+            }
+        });
+
+        frameReadingThread.start();
+        inferenceThread.start();
+    }
+
+    // 暂停视频
+    /** Pauses playback; does nothing when no video is playing. */
+    public void pauseVideo() {
+        if (isPlaying) {
+            isPaused = true;
+        }
+    }
+
+    // 重播视频
+    /**
+     * Restarts playback from the beginning of the current source.
+     *
+     * NOTE(review): BUG — stopVideo() releases the capture and sets
+     * videoCapture to null, so the null-check below always fails and this
+     * method silently does nothing. Fixing it requires retaining the source
+     * path/URL so the capture can be reopened (or stopping the threads
+     * without releasing the capture).
+     */
+    public void replayVideo() {
+        try {
+            stopVideo(); // stop current playback (also nulls videoCapture — see note)
+            if (videoCapture != null) {
+                videoCapture.set(org.opencv.videoio.Videoio.CAP_PROP_POS_FRAMES, 0);
+                currentTimestamp = 0;
+
+                // Show the first frame before restarting playback.
+                Mat frame = new Mat();
+                if (videoCapture.read(frame)) {
+                    BufferedImage bufferedImage = matToBufferedImage(frame);
+                    videoPanel.updateImage(bufferedImage);
+                }
+
+                playVideo(); // start playing again
+            }
+        } catch (Exception e) {
+            e.printStackTrace();
+            JOptionPane.showMessageDialog(null, "重播失败: " + e.getMessage(), "错误", JOptionPane.ERROR_MESSAGE);
+        }
+    }
+
+    // 停止视频
+    /**
+     * Stops playback, waits for the worker threads to exit, and releases the
+     * capture device. Safe to call when nothing is playing.
+     */
+    public void stopVideo() {
+        isPlaying = false;
+        isPaused = false;
+
+        // FIX: join the workers (bounded wait) before releasing the capture.
+        // The original only interrupted them, so the reader thread could still
+        // be inside videoCapture.read(...) when release() ran below.
+        haltThread(frameReadingThread);
+        haltThread(inferenceThread);
+        frameReadingThread = null;
+        inferenceThread = null;
+
+        if (videoCapture != null) {
+            videoCapture.release();
+            videoCapture = null;
+        }
+
+        frameDataQueue.clear();
+    }
+
+    /** Interrupts {@code t} and briefly waits for it to die; never joins the calling thread itself. */
+    private static void haltThread(Thread t) {
+        if (t == null || !t.isAlive() || t == Thread.currentThread()) {
+            return;
+        }
+        t.interrupt();
+        try {
+            t.join(500); // bounded wait so the UI cannot hang forever
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt(); // preserve interrupt status
+        }
+    }
+
+    // 快进或后退
+    /**
+     * Seeks to an absolute position and resumes playback from there.
+     *
+     * @param seekTime target position in milliseconds
+     */
+    public void seekTo(long seekTime) {
+        if (videoCapture == null) return;
+        try {
+            // FIX: the original called stopVideo() here, which releases the
+            // capture and sets videoCapture to null — the set(...) call below
+            // then threw a guaranteed NullPointerException. Instead, stop the
+            // worker threads WITHOUT releasing the capture.
+            isPlaying = false;
+            isPaused = false;
+            if (frameReadingThread != null && frameReadingThread.isAlive()
+                    && frameReadingThread != Thread.currentThread()) {
+                frameReadingThread.interrupt();
+                frameReadingThread.join(500);
+            }
+            if (inferenceThread != null && inferenceThread.isAlive()
+                    && inferenceThread != Thread.currentThread()) {
+                inferenceThread.interrupt();
+                inferenceThread.join(500);
+            }
+            frameDataQueue.clear(); // drop frames decoded before the seek
+
+            videoCapture.set(org.opencv.videoio.Videoio.CAP_PROP_POS_MSEC, seekTime);
+            currentTimestamp = seekTime;
+
+            // Show the frame at the new position immediately.
+            Mat frame = new Mat();
+            if (videoCapture.read(frame)) {
+                BufferedImage bufferedImage = matToBufferedImage(frame);
+                videoPanel.updateImage(bufferedImage);
+            }
+
+            // Resume playback from the new position.
+            playVideo();
+
+        } catch (Exception ex) {
+            ex.printStackTrace();
+        }
+    }
+
+    // 快进
+    /** Skips forward by {@code milliseconds}, clamped to the end of the video. */
+    public void fastForward(long milliseconds) {
+        seekTo(Math.min(currentTimestamp + milliseconds, videoDuration));
+    }
+
+    // 后退
+    /** Skips backward by {@code milliseconds}, clamped to the start of the video. */
+    public void rewind(long milliseconds) {
+        seekTo(Math.max(currentTimestamp - milliseconds, 0));
+    }
+
+    /**
+     * Registers an engine to run on every decoded frame during playback.
+     *
+     * @param inferenceEngine engine appended to the list consumed by the inference thread
+     */
+    public void addInferenceEngines(InferenceEngine inferenceEngine) {
+        this.inferenceEngines.add(inferenceEngine);
+    }
+
+    // 定义一个内部类来存储帧数据
+    /**
+     * Immutable pairing of a decoded frame (for display) with its
+     * preprocessed float tensor (for inference), passed through the queue.
+     */
+    private static class FrameData {
+        // FIX: fields made final — this is a write-once value holder handed
+        // between threads; finals also give safe publication guarantees.
+        public final BufferedImage image;
+        public final float[] floatArray;
+
+        public FrameData(BufferedImage image, float[] floatArray) {
+            this.image = image;
+            this.floatArray = floatArray;
+        }
+    }
+
+    // 将 BufferedImage 预处理并转换为一维 float[] 数组
+    /**
+     * Letterboxes an image onto a 640x640 canvas, converts to RGB, normalizes
+     * to [0,1], and returns the pixels as a planar CHW float array — the same
+     * layout and scaling produced by {@link #preprocessImage(Mat)}.
+     *
+     * @param image source frame (any BufferedImage type)
+     * @return float[3*640*640] in CHW order, values in [0,1]
+     */
+    public static float[] preprocessAndConvertBufferedImage(BufferedImage image) {
+        int targetWidth = 640;
+        int targetHeight = 640;
+
+        Mat matImage = bufferedImageToMat(image);
+
+        int origWidth = matImage.width();
+        int origHeight = matImage.height();
+
+        // Uniform scale so the image fits inside the target square.
+        float scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
+        int newWidth = Math.round(origWidth * scalingFactor);
+        int newHeight = Math.round(origHeight * scalingFactor);
+
+        Mat resizedImage = new Mat();
+        Imgproc.resize(matImage, resizedImage, new Size(newWidth, newHeight));
+        Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB);
+
+        // FIX: convert to float and normalize BEFORE copying onto the float
+        // canvas. The original copied an 8-bit Mat into a CV_32FC3 submat,
+        // which fails at runtime on the type mismatch, and it never divided
+        // by 255 (unlike the sibling preprocessImage).
+        resizedImage.convertTo(resizedImage, CvType.CV_32FC3, 1.0 / 255.0);
+
+        // Center the resized image on a zero (black) letterbox canvas.
+        Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3);
+        int xOffset = (targetWidth - newWidth) / 2;
+        int yOffset = (targetHeight - newHeight) / 2;
+        Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight);
+        resizedImage.copyTo(paddedImage.submat(roi));
+
+        // Read the interleaved HWC floats, then repack as planar CHW
+        // (matching preprocessImage's output layout).
+        int imageSize = targetWidth * targetHeight;
+        float[] hwcData = new float[3 * imageSize];
+        paddedImage.get(0, 0, hwcData);
+        float[] inputData = new float[3 * imageSize];
+        for (int c = 0; c < 3; c++) {
+            for (int i = 0; i < imageSize; i++) {
+                inputData[c * imageSize + i] = hwcData[i * 3 + c];
+            }
+        }
+
+        // Release native buffers promptly.
+        matImage.release();
+        resizedImage.release();
+        paddedImage.release();
+
+        return inputData;
+    }
+
+    // 辅助方法:将 BufferedImage 转换为 OpenCV 的 Mat 格式
+    /**
+     * Converts a BufferedImage to an 8-bit 3-channel BGR Mat.
+     *
+     * FIX: the original cast the raster buffer to DataBufferByte directly,
+     * which throws ClassCastException for any image that is not
+     * TYPE_3BYTE_BGR (e.g. TYPE_INT_RGB); redraw into a BGR image first.
+     *
+     * @param bi source image (any type)
+     * @return new CV_8UC3 Mat containing a copy of the pixels
+     */
+    public static Mat bufferedImageToMat(BufferedImage bi) {
+        BufferedImage bgr = bi;
+        if (bi.getType() != BufferedImage.TYPE_3BYTE_BGR) {
+            bgr = new BufferedImage(bi.getWidth(), bi.getHeight(), BufferedImage.TYPE_3BYTE_BGR);
+            Graphics2D g = bgr.createGraphics();
+            g.drawImage(bi, 0, 0, null);
+            g.dispose();
+        }
+        byte[] data = ((DataBufferByte) bgr.getRaster().getDataBuffer()).getData();
+        Mat mat = new Mat(bgr.getHeight(), bgr.getWidth(), CvType.CV_8UC3);
+        mat.put(0, 0, data);
+        return mat;
+    }
+
+    // 可选的预处理方法
+    /**
+     * Letterboxes a BGR Mat onto a 640x640 canvas, normalizes to [0,1] RGB,
+     * and returns the planar CHW tensor plus the geometry needed to map
+     * detections back onto the original image.
+     *
+     * NOTE(review): the return type reads as a raw Map here — the generic
+     * parameters (presumably Map<String, Object>) appear stripped by the diff.
+     *
+     * @param image source frame in BGR order (OpenCV default)
+     * @return map with keys: "inputData" (float[] CHW), "origWidth",
+     *         "origHeight", "scalingFactor", "xOffset", "yOffset"
+     */
+    public Map preprocessImage(Mat image) {
+        int targetWidth = 640;
+        int targetHeight = 640;
+
+        int origWidth = image.width();
+        int origHeight = image.height();
+
+        // Uniform scale so the image fits inside the target square.
+        float scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
+
+        // Dimensions after aspect-preserving resize.
+        int newWidth = Math.round(origWidth * scalingFactor);
+        int newHeight = Math.round(origHeight * scalingFactor);
+
+        // Resize the image.
+        Mat resizedImage = new Mat();
+        Imgproc.resize(image, resizedImage, new Size(newWidth, newHeight));
+
+        // Convert to RGB and normalize to [0,1].
+        Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB);
+        resizedImage.convertTo(resizedImage, CvType.CV_32FC3, 1.0 / 255.0);
+
+        // Center the resized image on a zero (black) letterbox canvas.
+        Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3);
+        int xOffset = (targetWidth - newWidth) / 2;
+        int yOffset = (targetHeight - newHeight) / 2;
+        Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight);
+        resizedImage.copyTo(paddedImage.submat(roi));
+
+        // Pull the interleaved HWC floats out of the Mat.
+        int imageSize = targetWidth * targetHeight;
+        float[] chwData = new float[3 * imageSize];
+        float[] hwcData = new float[3 * imageSize];
+        paddedImage.get(0, 0, hwcData);
+
+        // Repack HWC -> planar CHW (the layout ONNX models expect).
+        int channelSize = imageSize;
+        for (int c = 0; c < 3; c++) {
+            for (int i = 0; i < imageSize; i++) {
+                chwData[c * channelSize + i] = hwcData[i * 3 + c];
+            }
+        }
+
+        // Release native buffers promptly.
+        resizedImage.release();
+        paddedImage.release();
+
+        // Bundle the tensor with the letterbox geometry for un-mapping boxes.
+        Map result = new HashMap<>();
+        result.put("inputData", chwData);
+        result.put("origWidth", origWidth);
+        result.put("origHeight", origHeight);
+        result.put("scalingFactor", scalingFactor);
+        result.put("xOffset", xOffset);
+        result.put("yOffset", yOffset);
+
+        return result;
+    }
+}
diff --git a/src/main/java/com/ly/utils/CameraDeviceLister.java b/src/main/java/com/ly/utils/CameraDeviceLister.java
deleted file mode 100644
index 8f59cd1..0000000
--- a/src/main/java/com/ly/utils/CameraDeviceLister.java
+++ /dev/null
@@ -1,13 +0,0 @@
-package com.ly.utils;
-
-import org.bytedeco.javacv.FrameGrabber;
-import org.bytedeco.javacv.VideoInputFrameGrabber;
-
-public class CameraDeviceLister {
- public static void main(String[] args) throws FrameGrabber.Exception {
- String[] deviceDescriptions = VideoInputFrameGrabber.getDeviceDescriptions();
- for (String deviceDescription : deviceDescriptions) {
- System.out.println("摄像头索引 " + ": " + deviceDescription);
- }
- }
-}
diff --git a/src/main/java/com/ly/utils/OpenCVTest.java b/src/main/java/com/ly/utils/OpenCVTest.java
new file mode 100644
index 0000000..b25f50a
--- /dev/null
+++ b/src/main/java/com/ly/utils/OpenCVTest.java
@@ -0,0 +1,28 @@
+package com.ly.utils;
+
+import org.opencv.core.Core;
+import org.opencv.core.Mat;
+import org.opencv.videoio.VideoCapture;
+
+public class OpenCVTest {
+    // FIX: the native-library loader was commented out; without it the
+    // VideoCapture constructor below throws UnsatisfiedLinkError when this
+    // class is run standalone.
+    static {
+        nu.pattern.OpenCV.loadLocally();
+    }
+
+    /** Smoke test: opens the default camera and tries to grab one frame. */
+    public static void main(String[] args) {
+        VideoCapture capture = new VideoCapture(0); // open the default camera
+        if (!capture.isOpened()) {
+            System.out.println("无法打开摄像头");
+            return;
+        }
+
+        Mat frame = new Mat();
+        if (capture.read(frame)) {
+            System.out.println("成功读取一帧图像");
+        } else {
+            System.out.println("无法读取图像");
+        }
+
+        frame.release(); // release the native frame buffer as well
+        capture.release();
+    }
+}
diff --git a/src/main/java/com/ly/utils/RTSPStreamer.java b/src/main/java/com/ly/utils/RTSPStreamer.java
deleted file mode 100644
index 2cb6845..0000000
--- a/src/main/java/com/ly/utils/RTSPStreamer.java
+++ /dev/null
@@ -1,57 +0,0 @@
-package com.ly.utils;
-
-import org.bytedeco.ffmpeg.global.avcodec;
-import org.bytedeco.javacv.*;
-
-public class RTSPStreamer {
-
- public static void main(String[] args) {
- String inputFile = "C:\\Users\\ly\\Desktop\\屏幕录制 2024-09-20 225443.mp4"; // 替换为您的本地视频文件路径
- String rtspUrl = "rtsp://localhost:8554/live"; // 替换为您的 RTSP 服务器地址
-
- FFmpegFrameGrabber grabber = null;
- FFmpegFrameRecorder recorder = null;
-
- try {
- // 初始化 FFmpegFrameGrabber 以从本地视频文件读取
- grabber = new FFmpegFrameGrabber(inputFile);
- grabber.start();
-
- // 初始化 FFmpegFrameRecorder 以推流到 RTSP 服务器
- recorder = new FFmpegFrameRecorder(rtspUrl, grabber.getImageWidth(), grabber.getImageHeight(), grabber.getAudioChannels());
- recorder.setFormat("rtsp");
- recorder.setFrameRate(grabber.getFrameRate());
- recorder.setVideoBitrate(grabber.getVideoBitrate());
- recorder.setVideoCodec(avcodec.AV_CODEC_ID_H264); // 设置视频编码格式
- recorder.setAudioCodec(avcodec.AV_CODEC_ID_AAC); // 设置音频编码格式
-
- // 设置 RTSP 传输选项(如果需要)
- recorder.setOption("rtsp_transport", "tcp");
-
- recorder.start();
-
- Frame frame;
- while ((frame = grabber.grab()) != null) {
- recorder.record(frame);
- }
-
- System.out.println("推流完成。");
-
- } catch (Exception e) {
- e.printStackTrace();
- } finally {
- try {
- if (recorder != null) {
- recorder.stop();
- recorder.release();
- }
- if (grabber != null) {
- grabber.stop();
- grabber.release();
- }
- } catch (Exception e) {
- e.printStackTrace();
- }
- }
- }
-}
diff --git a/src/main/resources/lib/win/opencv_videoio_ffmpeg470_64.dll b/src/main/resources/lib/win/opencv_videoio_ffmpeg470_64.dll
new file mode 100644
index 0000000..798d5cd
Binary files /dev/null and b/src/main/resources/lib/win/opencv_videoio_ffmpeg470_64.dll differ
diff --git a/src/main/resources/model/best.onnx b/src/main/resources/model/best.onnx
new file mode 100644
index 0000000..b9d7489
Binary files /dev/null and b/src/main/resources/model/best.onnx differ