diff --git a/pom.xml b/pom.xml index cb9742a..ede1555 100644 --- a/pom.xml +++ b/pom.xml @@ -22,7 +22,7 @@ com.microsoft.onnxruntime onnxruntime_gpu - 1.17.0 + 1.16.0 org.bytedeco @@ -40,6 +40,21 @@ ffmpeg-platform 5.0-1.5.7 + + org.openpnp + opencv + 4.7.0-0 + + + org.projectlombok + lombok + 1.18.32 + + + com.alibaba + fastjson + 1.2.83 + diff --git a/src/main/java/com/ly/VideoInferenceApp.java b/src/main/java/com/ly/VideoInferenceApp.java index d693bc9..da9e96b 100644 --- a/src/main/java/com/ly/VideoInferenceApp.java +++ b/src/main/java/com/ly/VideoInferenceApp.java @@ -3,13 +3,18 @@ package com.ly; import com.formdev.flatlaf.FlatLightLaf; import com.ly.layout.VideoPanel; import com.ly.model_load.ModelManager; -import com.ly.play.VideoPlayer; +import com.ly.onnx.engine.InferenceEngine; +import com.ly.onnx.model.ModelInfo; +import com.ly.play.opencv.VideoPlayer; + import javax.swing.*; import javax.swing.filechooser.FileNameExtensionFilter; import javax.swing.filechooser.FileSystemView; import java.awt.*; import java.io.File; +import java.util.ArrayList; +import java.util.Collections; public class VideoInferenceApp extends JFrame { @@ -43,13 +48,13 @@ public class VideoInferenceApp extends JFrame { videoPanel = new VideoPanel(); videoPanel.setBackground(Color.BLACK); - // 初始化 VideoPlayer - videoPlayer = new VideoPlayer(videoPanel); - // 模型列表区域 modelManager = new ModelManager(); modelManager.setPreferredSize(new Dimension(250, 0)); // 设置模型列表区域的宽度 + // 初始化 VideoPlayer + videoPlayer = new VideoPlayer(videoPanel, modelManager); + // 使用 JSplitPane 分割视频区域和模型列表区域 JSplitPane splitPane = new JSplitPane(JSplitPane.HORIZONTAL_SPLIT, videoPanel, modelManager); splitPane.setResizeWeight(0.8); // 视频区域初始占据80%的空间 @@ -120,8 +125,16 @@ public class VideoInferenceApp extends JFrame { // 添加视频加载按钮的行为 loadVideoButton.addActionListener(e -> selectVideoFile()); - // 添加模型加载按钮的行为 - loadModelButton.addActionListener(e -> modelManager.loadModel(this)); + loadModelButton.addActionListener(e 
-> { + modelManager.loadModel(this); + DefaultListModel modelList = modelManager.getModelList(); + ArrayList models = Collections.list(modelList.elements()); + for (ModelInfo modelInfo : models) { + if (modelInfo != null) { + videoPlayer.addInferenceEngines(new InferenceEngine(modelInfo.getModelFilePath(), modelInfo.getLabels())); + } + } + }); // 播放按钮 playButton.addActionListener(e -> videoPlayer.playVideo()); diff --git a/src/main/java/com/ly/model_load/ModelManager.java b/src/main/java/com/ly/model_load/ModelManager.java index f80eb6d..d3dd348 100644 --- a/src/main/java/com/ly/model_load/ModelManager.java +++ b/src/main/java/com/ly/model_load/ModelManager.java @@ -1,42 +1,37 @@ package com.ly.model_load; - - import com.ly.file.FileEditor; +import com.ly.onnx.model.ModelInfo; import javax.swing.*; import javax.swing.filechooser.FileNameExtensionFilter; import javax.swing.filechooser.FileSystemView; -import javax.swing.table.DefaultTableModel; import java.awt.*; import java.awt.event.MouseAdapter; import java.awt.event.MouseEvent; import java.io.File; public class ModelManager extends JPanel { - private DefaultListModel modelListModel; - private JList modelList; + private DefaultListModel modelListModel; + private JList modelList; public ModelManager() { setLayout(new BorderLayout()); modelListModel = new DefaultListModel<>(); modelList = new JList<>(modelListModel); + modelList.setSelectionMode(ListSelectionModel.SINGLE_SELECTION); // 设置为单选 JScrollPane modelScrollPane = new JScrollPane(modelList); add(modelScrollPane, BorderLayout.CENTER); - // 添加双击事件,编辑标签文件 + // 双击编辑标签文件 modelList.addMouseListener(new MouseAdapter() { public void mouseClicked(MouseEvent e) { if (e.getClickCount() == 2) { int index = modelList.locationToIndex(e.getPoint()); if (index >= 0) { - String item = modelListModel.getElementAt(index); - // 解析标签文件路径 - String[] parts = item.split("\n"); - if (parts.length >= 2) { - String labelFilePath = parts[1].replace("标签文件: ", "").trim(); - 
FileEditor.openFileEditor(labelFilePath); - } + ModelInfo item = modelListModel.getElementAt(index); + String labelFilePath = item.getLabelFilePath(); + FileEditor.openFileEditor(labelFilePath); } } } @@ -45,12 +40,9 @@ public class ModelManager extends JPanel { // 加载模型 public void loadModel(JFrame parent) { - // 获取桌面目录 File desktopDir = FileSystemView.getFileSystemView().getHomeDirectory(); JFileChooser fileChooser = new JFileChooser(desktopDir); fileChooser.setDialogTitle("选择模型文件"); - - // 设置模型文件过滤器,只显示 .onnx 文件 FileNameExtensionFilter modelFilter = new FileNameExtensionFilter("ONNX模型文件 (*.onnx)", "onnx"); fileChooser.setFileFilter(modelFilter); @@ -60,8 +52,6 @@ public class ModelManager extends JPanel { // 选择对应的标签文件 fileChooser.setDialogTitle("选择标签文件"); - - // 设置标签文件过滤器,只显示 .txt 文件 FileNameExtensionFilter labelFilter = new FileNameExtensionFilter("标签文件 (*.txt)", "txt"); fileChooser.setFileFilter(labelFilter); @@ -69,9 +59,9 @@ public class ModelManager extends JPanel { if (returnValue == JFileChooser.APPROVE_OPTION) { File labelFile = fileChooser.getSelectedFile(); - // 将模型和标签文件添加到列表中 - String item = "模型文件: " + modelFile.getAbsolutePath() + "\n标签文件: " + labelFile.getAbsolutePath(); - modelListModel.addElement(item); + // 添加模型信息到列表 + ModelInfo modelInfo = new ModelInfo(modelFile.getAbsolutePath(), labelFile.getAbsolutePath()); + modelListModel.addElement(modelInfo); } else { JOptionPane.showMessageDialog(parent, "未选择标签文件。", "提示", JOptionPane.WARNING_MESSAGE); } @@ -79,4 +69,14 @@ public class ModelManager extends JPanel { JOptionPane.showMessageDialog(parent, "未选择模型文件。", "提示", JOptionPane.WARNING_MESSAGE); } } -} \ No newline at end of file + + // 获取选中的模型 + public ModelInfo getSelectedModel() { + return modelList.getSelectedValue(); + } + + // 如果需要在外部访问 modelList,可以添加以下方法 + public DefaultListModel getModelList() { + return modelListModel; + } +} diff --git a/src/main/java/com/ly/onnx/OnnxModelInference.java b/src/main/java/com/ly/onnx/OnnxModelInference.java new 
file mode 100644 index 0000000..e6916b6 --- /dev/null +++ b/src/main/java/com/ly/onnx/OnnxModelInference.java @@ -0,0 +1,20 @@ +package com.ly.onnx; + +import ai.onnxruntime.OrtEnvironment; +import ai.onnxruntime.OrtSession; + +public class OnnxModelInference { + + private String modelFilePath; + + private String labelFilePath; + + private String[] labels; + + OrtEnvironment environment = OrtEnvironment.getEnvironment(); + OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions(); + + + + +} diff --git a/src/main/java/com/ly/onnx/engine/InferenceEngine.java b/src/main/java/com/ly/onnx/engine/InferenceEngine.java new file mode 100644 index 0000000..bec81f7 --- /dev/null +++ b/src/main/java/com/ly/onnx/engine/InferenceEngine.java @@ -0,0 +1,410 @@ +package com.ly.onnx.engine; + +import ai.onnxruntime.*; +import com.alibaba.fastjson.JSON; +import com.ly.onnx.model.BoundingBox; +import com.ly.onnx.model.InferenceResult; + +import org.opencv.core.*; +import org.opencv.imgcodecs.Imgcodecs; +import org.opencv.imgproc.Imgproc; + +import java.nio.FloatBuffer; +import java.util.*; + +public class InferenceEngine { + + private OrtEnvironment environment; + private OrtSession.SessionOptions sessionOptions; + private OrtSession session; + + private String modelPath; + private List labels; + + // 用于存储图像预处理信息的类变量 + private int origWidth; + private int origHeight; + private int newWidth; + private int newHeight; + private float scalingFactor; + private int xOffset; + private int yOffset; + + static { + nu.pattern.OpenCV.loadLocally(); + } + + public InferenceEngine(String modelPath, List labels) { + this.modelPath = modelPath; + this.labels = labels; + init(); + } + + public void init() { + try { + environment = OrtEnvironment.getEnvironment(); + sessionOptions = new OrtSession.SessionOptions(); + sessionOptions.addCUDA(0); // 使用 GPU + session = environment.createSession(modelPath, sessionOptions); + logModelInfo(session); + } catch (OrtException e) { + throw new 
RuntimeException("模型加载失败", e); + } + } + + public InferenceResult infer(float[] inputData, int w, int h, Map preprocessParams) { + long startTime = System.currentTimeMillis(); + + // 从 Map 中获取偏移相关的变量 + int origWidth = (int) preprocessParams.get("origWidth"); + int origHeight = (int) preprocessParams.get("origHeight"); + float scalingFactor = (float) preprocessParams.get("scalingFactor"); + int xOffset = (int) preprocessParams.get("xOffset"); + int yOffset = (int) preprocessParams.get("yOffset"); + + try { + Map inputInfo = session.getInputInfo(); + String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入 + + long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状 + + // 创建输入张量时,使用 CHW 格式的数据 + OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape); + + // 执行推理 + long inferenceStart = System.currentTimeMillis(); + OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor)); + long inferenceEnd = System.currentTimeMillis(); + System.out.println("模型推理耗时:" + (inferenceEnd - inferenceStart) + " ms"); + + // 解析推理结果 + String outputName = session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出 + float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状:[1, N, 6] + + // 设定置信度阈值 + float confidenceThreshold = 0.25f; // 您可以根据需要调整 + + // 根据模型的输出结果解析边界框 + List boxes = new ArrayList<>(); + for (float[] data : outputData[0]) { // 遍历所有检测框 + // 根据模型输出格式,解析中心坐标和宽高 + float x_center = data[0]; + float y_center = data[1]; + float width = data[2]; + float height = data[3]; + float confidence = data[4]; + + if (confidence >= confidenceThreshold) { + // 将中心坐标转换为左上角和右下角坐标 + float x1 = x_center - width / 2; + float y1 = y_center - height / 2; + float x2 = x_center + width / 2; + float y2 = y_center + height / 2; + + // 调整坐标,减去偏移并除以缩放因子 + float x1Adjusted = (x1 - xOffset) / scalingFactor; + float y1Adjusted = (y1 - yOffset) / scalingFactor; + float x2Adjusted = (x2 - xOffset) 
/ scalingFactor; + float y2Adjusted = (y2 - yOffset) / scalingFactor; + + // 确保坐标的正确顺序 + float xMinAdjusted = Math.min(x1Adjusted, x2Adjusted); + float xMaxAdjusted = Math.max(x1Adjusted, x2Adjusted); + float yMinAdjusted = Math.min(y1Adjusted, y2Adjusted); + float yMaxAdjusted = Math.max(y1Adjusted, y2Adjusted); + + // 确保坐标在原始图像范围内 + int x = (int) Math.max(0, xMinAdjusted); + int y = (int) Math.max(0, yMinAdjusted); + int xMax = (int) Math.min(origWidth, xMaxAdjusted); + int yMax = (int) Math.min(origHeight, yMaxAdjusted); + int wBox = xMax - x; + int hBox = yMax - y; + + // 仅当宽度和高度为正时,才添加边界框 + if (wBox > 0 && hBox > 0) { + // 使用您的单一标签 + String label = labels.get(0); + + boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence)); + } + } + } + + // 非极大值抑制(NMS) + long nmsStart = System.currentTimeMillis(); + List nmsBoxes = nonMaximumSuppression(boxes, 0.5f); + + System.out.println("检测到的标签:" + JSON.toJSONString(nmsBoxes)); + if (!nmsBoxes.isEmpty()) { + for (BoundingBox box : nmsBoxes) { + System.out.println(box); + } + } + long nmsEnd = System.currentTimeMillis(); + System.out.println("NMS 耗时:" + (nmsEnd - nmsStart) + " ms"); + + // 封装结果并返回 + InferenceResult inferenceResult = new InferenceResult(); + inferenceResult.setBoundingBoxes(nmsBoxes); + + long endTime = System.currentTimeMillis(); + System.out.println("一次推理总耗时:" + (endTime - startTime) + " ms"); + + return inferenceResult; + + } catch (OrtException e) { + throw new RuntimeException("推理失败", e); + } + } + + + // 计算两个边界框的 IoU + private float computeIoU(BoundingBox box1, BoundingBox box2) { + int x1 = Math.max(box1.getX(), box2.getX()); + int y1 = Math.max(box1.getY(), box2.getY()); + int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth()); + int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight()); + + int intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1); + int box1Area = box1.getWidth() * box1.getHeight(); + int box2Area = box2.getWidth() * 
box2.getHeight(); + + return (float) intersectionArea / (box1Area + box2Area - intersectionArea); + } + + // 非极大值抑制(NMS)方法 + private List nonMaximumSuppression(List boxes, float iouThreshold) { + // 按置信度排序(从高到低) + boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence())); + + List result = new ArrayList<>(); + + while (!boxes.isEmpty()) { + BoundingBox bestBox = boxes.remove(0); + result.add(bestBox); + + Iterator iterator = boxes.iterator(); + while (iterator.hasNext()) { + BoundingBox box = iterator.next(); + if (computeIoU(bestBox, box) > iouThreshold) { + iterator.remove(); + } + } + } + + return result; + } + + // 打印模型信息 + private void logModelInfo(OrtSession session) { + System.out.println("模型输入信息:"); + try { + for (Map.Entry entry : session.getInputInfo().entrySet()) { + String name = entry.getKey(); + NodeInfo info = entry.getValue(); + System.out.println("输入名称: " + name); + System.out.println("输入信息: " + info.toString()); + } + } catch (OrtException e) { + throw new RuntimeException(e); + } + + System.out.println("模型输出信息:"); + try { + for (Map.Entry entry : session.getOutputInfo().entrySet()) { + String name = entry.getKey(); + NodeInfo info = entry.getValue(); + System.out.println("输出名称: " + name); + System.out.println("输出信息: " + info.toString()); + } + } catch (OrtException e) { + throw new RuntimeException(e); + } + } + + public static void main(String[] args) { + // 加载 OpenCV 库 + + // 初始化标签列表(只有一个标签) + List labels = Arrays.asList("person"); + + // 创建 InferenceEngine 实例 + InferenceEngine inferenceEngine = new InferenceEngine("C:\\Users\\ly\\Desktop\\person.onnx", labels); + + for (int j = 0; j < 10; j++) { + try { + // 加载图片 + Mat inputImage = Imgcodecs.imread("C:\\Users\\ly\\Desktop\\10230731212230.png"); + + // 预处理图像 + long l1 = System.currentTimeMillis(); + Map preprocessResult = inferenceEngine.preprocessImage(inputImage); + float[] inputData = (float[]) preprocessResult.get("inputData"); + + InferenceResult result = null; + for (int i 
= 0; i < 10; i++) { + long l = System.currentTimeMillis(); + result = inferenceEngine.infer(inputData, 640, 640, preprocessResult); + System.out.println("第 " + (i + 1) + " 次推理耗时:" + (System.currentTimeMillis() - l) + " ms"); + } + + + // 处理并显示结果 + System.out.println("推理结果:"); + for (BoundingBox box : result.getBoundingBoxes()) { + System.out.println(box); + } + + // 可视化并保存带有边界框的图像 + Mat outputImage = inferenceEngine.drawBoundingBoxes(inputImage, result.getBoundingBoxes()); + + // 保存图片到本地文件 + String outputFilePath = "output_image_with_boxes.jpg"; + Imgcodecs.imwrite(outputFilePath, outputImage); + + System.out.println("已保存结果图片: " + outputFilePath); + + } catch (Exception e) { + e.printStackTrace(); + } + } + } + + // 在图像上绘制边界框和标签 + private Mat drawBoundingBoxes(Mat image, List boxes) { + for (BoundingBox box : boxes) { + // 绘制矩形边界框 + Imgproc.rectangle(image, new Point(box.getX(), box.getY()), + new Point(box.getX() + box.getWidth(), box.getY() + box.getHeight()), + new Scalar(0, 0, 255), 2); // 红色边框 + + // 绘制标签文字和置信度 + String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence()); + int baseLine[] = new int[1]; + Size labelSize = Imgproc.getTextSize(label, Imgproc.FONT_HERSHEY_SIMPLEX, 0.5, 1, baseLine); + int top = Math.max(box.getY(), (int) labelSize.height); + Imgproc.putText(image, label, new Point(box.getX(), top), + Imgproc.FONT_HERSHEY_SIMPLEX, 0.5, new Scalar(255, 255, 255), 1); + } + + return image; + } + + + public Map preprocessImage(Mat image) { + int targetWidth = 640; + int targetHeight = 640; + + int origWidth = image.width(); + int origHeight = image.height(); + + // 计算缩放因子 + float scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight); + + // 计算新的图像尺寸 + int newWidth = Math.round(origWidth * scalingFactor); + int newHeight = Math.round(origHeight * scalingFactor); + + // 计算偏移量以居中图像 + int xOffset = (targetWidth - newWidth) / 2; + int yOffset = (targetHeight - newHeight) / 2; + + // 调整图像尺寸 + Mat 
resizedImage = new Mat(); + Imgproc.resize(image, resizedImage, new Size(newWidth, newHeight), 0, 0, Imgproc.INTER_LINEAR); + + // 转换为 RGB 并归一化 + Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB); + resizedImage.convertTo(resizedImage, CvType.CV_32FC3, 1.0 / 255.0); + + // 创建填充后的图像 + Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3); + Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight); + resizedImage.copyTo(paddedImage.submat(roi)); + + // 将图像数据转换为数组 + int imageSize = targetWidth * targetHeight; + float[] chwData = new float[3 * imageSize]; + float[] hwcData = new float[3 * imageSize]; + paddedImage.get(0, 0, hwcData); + + // 转换为 CHW 格式 + int channelSize = imageSize; + for (int c = 0; c < 3; c++) { + for (int i = 0; i < imageSize; i++) { + chwData[c * channelSize + i] = hwcData[i * 3 + c]; + } + } + + // 释放图像资源 + resizedImage.release(); + paddedImage.release(); + + // 将预处理结果和偏移信息存入 Map + Map result = new HashMap<>(); + result.put("inputData", chwData); + result.put("origWidth", origWidth); + result.put("origHeight", origHeight); + result.put("scalingFactor", scalingFactor); + result.put("xOffset", xOffset); + result.put("yOffset", yOffset); + + return result; + } + + + // 图像预处理 +// public float[] preprocessImage(Mat image) { +// int targetWidth = 640; +// int targetHeight = 640; +// +// origWidth = image.width(); +// origHeight = image.height(); +// +// // 计算缩放因子 +// scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight); +// +// // 计算新的图像尺寸 +// newWidth = Math.round(origWidth * scalingFactor); +// newHeight = Math.round(origHeight * scalingFactor); +// +// // 计算偏移量以居中图像 +// xOffset = (targetWidth - newWidth) / 2; +// yOffset = (targetHeight - newHeight) / 2; +// +// // 调整图像尺寸 +// Mat resizedImage = new Mat(); +// Imgproc.resize(image, resizedImage, new Size(newWidth, newHeight), 0, 0, Imgproc.INTER_LINEAR); +// +// // 转换为 RGB 并归一化 +// Imgproc.cvtColor(resizedImage, 
resizedImage, Imgproc.COLOR_BGR2RGB); +// resizedImage.convertTo(resizedImage, CvType.CV_32FC3, 1.0 / 255.0); +// +// // 创建填充后的图像 +// Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3); +// Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight); +// resizedImage.copyTo(paddedImage.submat(roi)); +// +// // 将图像数据转换为数组 +// int imageSize = targetWidth * targetHeight; +// float[] chwData = new float[3 * imageSize]; +// float[] hwcData = new float[3 * imageSize]; +// paddedImage.get(0, 0, hwcData); +// +// // 转换为 CHW 格式 +// int channelSize = imageSize; +// for (int c = 0; c < 3; c++) { +// for (int i = 0; i < imageSize; i++) { +// chwData[c * channelSize + i] = hwcData[i * 3 + c]; +// } +// } +// +// // 释放图像资源 +// resizedImage.release(); +// paddedImage.release(); +// +// return chwData; +// } + +} diff --git a/src/main/java/com/ly/onnx/engine/InferenceEngine_up.java b/src/main/java/com/ly/onnx/engine/InferenceEngine_up.java new file mode 100644 index 0000000..6acd759 --- /dev/null +++ b/src/main/java/com/ly/onnx/engine/InferenceEngine_up.java @@ -0,0 +1,297 @@ +package com.ly.onnx.engine; + +import ai.onnxruntime.*; +import com.ly.onnx.model.BoundingBox; +import com.ly.onnx.model.InferenceResult; + +import javax.imageio.ImageIO; +import java.awt.*; +import java.awt.image.BufferedImage; +import java.io.File; +import java.nio.FloatBuffer; +import java.util.List; +import java.util.*; + +public class InferenceEngine_up { + + OrtEnvironment environment = OrtEnvironment.getEnvironment(); + OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions(); + + private String modelPath; + private List labels; + + // 添加用于存储图像预处理信息的类变量 + private int origWidth; + private int origHeight; + private int newWidth; + private int newHeight; + private float scalingFactor; + private int xOffset; + private int yOffset; + + public InferenceEngine_up(String modelPath, List labels) { + this.modelPath = modelPath; + this.labels = labels; + init(); + 
} + + public void init() { + OrtSession session = null; + try { + sessionOptions.addCUDA(0); + session = environment.createSession(modelPath, sessionOptions); + + } catch (OrtException e) { + throw new RuntimeException(e); + } + logModelInfo(session); + } + + public InferenceResult infer(float[] inputData, int w, int h) { + // 创建ONNX输入Tensor + try (OrtSession session = environment.createSession(modelPath, sessionOptions)) { + Map inputInfo = session.getInputInfo(); + String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入 + + long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状 + OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape); + + // 执行推理 + OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor)); + + // 解析推理结果 + String outputName = session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出 + float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状:[1, N, 5] + + long l = System.currentTimeMillis(); + // 设定置信度阈值 + float confidenceThreshold = 0.5f; // 您可以根据需要调整 + + // 根据模型的输出结果解析边界框 + List boxes = new ArrayList<>(); + for (float[] data : outputData[0]) { // 遍历所有检测框 + float confidence = data[4]; + if (confidence >= confidenceThreshold) { + float xCenter = data[0]; + float yCenter = data[1]; + float widthBox = data[2]; + float heightBox = data[3]; + + // 调整坐标,减去偏移并除以缩放因子 + float xCenterAdjusted = (xCenter - xOffset) / scalingFactor; + float yCenterAdjusted = (yCenter - yOffset) / scalingFactor; + float widthAdjusted = widthBox / scalingFactor; + float heightAdjusted = heightBox / scalingFactor; + + // 计算左上角坐标 + int x = (int) (xCenterAdjusted - widthAdjusted / 2); + int y = (int) (yCenterAdjusted - heightAdjusted / 2); + int wBox = (int) widthAdjusted; + int hBox = (int) heightAdjusted; + + // 确保坐标在原始图像范围内 + if (x < 0) x = 0; + if (y < 0) y = 0; + if (x + wBox > origWidth) wBox = origWidth - x; + if (y + hBox > origHeight) hBox = origHeight 
- y; + + String label = "person"; // 由于只有一个类别 + + boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence)); + } + } + + // 非极大值抑制(NMS) + List nmsBoxes = nonMaximumSuppression(boxes, 0.5f); + System.out.println("耗时:"+((System.currentTimeMillis() - l))); + // 封装结果并返回 + InferenceResult inferenceResult = new InferenceResult(); + inferenceResult.setBoundingBoxes(nmsBoxes); + return inferenceResult; + + } catch (OrtException e) { + throw new RuntimeException("推理失败", e); + } + } + + // 计算两个边界框的 IoU + private float computeIoU(BoundingBox box1, BoundingBox box2) { + int x1 = Math.max(box1.getX(), box2.getX()); + int y1 = Math.max(box1.getY(), box2.getY()); + int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth()); + int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight()); + + int intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1); + int box1Area = box1.getWidth() * box1.getHeight(); + int box2Area = box2.getWidth() * box2.getHeight(); + + return (float) intersectionArea / (box1Area + box2Area - intersectionArea); + } + + // 打印模型信息 + private void logModelInfo(OrtSession session) { + System.out.println("模型输入信息:"); + try { + for (Map.Entry entry : session.getInputInfo().entrySet()) { + String name = entry.getKey(); + NodeInfo info = entry.getValue(); + System.out.println("输入名称: " + name); + System.out.println("输入信息: " + info.toString()); + } + } catch (OrtException e) { + throw new RuntimeException(e); + } + + System.out.println("模型输出信息:"); + try { + for (Map.Entry entry : session.getOutputInfo().entrySet()) { + String name = entry.getKey(); + NodeInfo info = entry.getValue(); + System.out.println("输出名称: " + name); + System.out.println("输出信息: " + info.toString()); + } + } catch (OrtException e) { + throw new RuntimeException(e); + } + } + + public static void main(String[] args) { + // 初始化标签列表 + List labels = Arrays.asList("person"); + + // 创建 InferenceEngine 实例 + InferenceEngine_up inferenceEngine = new 
InferenceEngine_up("D:\\work\\work_space\\java\\onnx-inference4j-play\\src\\main\\resources\\model\\best.onnx", labels); + + try { + // 加载图片 + File imageFile = new File("C:\\Users\\ly\\Desktop\\resuouce\\image\\1.jpg"); + BufferedImage inputImage = ImageIO.read(imageFile); + + // 预处理图像 + float[] inputData = inferenceEngine.preprocessImage(inputImage); + + // 执行推理 + InferenceResult result = null; + for (int i = 0; i < 100; i++) { + long l = System.currentTimeMillis(); + result = inferenceEngine.infer(inputData, 640, 640); + System.out.println(System.currentTimeMillis() - l); + } + + // 处理并显示结果 + System.out.println("推理结果:"); + for (BoundingBox box : result.getBoundingBoxes()) { + System.out.println(box); + } + + // 可视化并保存带有边界框的图像 + BufferedImage outputImage = inferenceEngine.drawBoundingBoxes(inputImage, result.getBoundingBoxes()); + + // 保存图片到本地文件 + File outputFile = new File("output_image_with_boxes.jpg"); + ImageIO.write(outputImage, "jpg", outputFile); + + System.out.println("已保存结果图片: " + outputFile.getAbsolutePath()); + + } catch (Exception e) { + e.printStackTrace(); + } + } + + // 在图像上绘制边界框和标签 + BufferedImage drawBoundingBoxes(BufferedImage image, List boxes) { + Graphics2D g = image.createGraphics(); + g.setColor(Color.RED); // 设置绘制边界框的颜色 + g.setStroke(new BasicStroke(2)); // 设置线条粗细 + + for (BoundingBox box : boxes) { + // 绘制矩形边界框 + g.drawRect(box.getX(), box.getY(), box.getWidth(), box.getHeight()); + // 绘制标签文字和置信度 + String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence()); + g.setFont(new Font("Arial", Font.PLAIN, 12)); + g.drawString(label, box.getX(), box.getY() - 5); + } + + g.dispose(); // 释放资源 + return image; + } + + // 图像预处理 + float[] preprocessImage(BufferedImage image) { + int targetWidth = 640; + int targetHeight = 640; + + origWidth = image.getWidth(); + origHeight = image.getHeight(); + + // 计算缩放因子 + scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight); + + // 计算新的图像尺寸 + newWidth = 
Math.round(origWidth * scalingFactor); + newHeight = Math.round(origHeight * scalingFactor); + + // 计算偏移量以居中图像 + xOffset = (targetWidth - newWidth) / 2; + yOffset = (targetHeight - newHeight) / 2; + + // 创建一个新的BufferedImage + BufferedImage resizedImage = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_RGB); + Graphics2D g = resizedImage.createGraphics(); + + // 填充背景为黑色 + g.setColor(Color.BLACK); + g.fillRect(0, 0, targetWidth, targetHeight); + + // 绘制缩放后的图像到新的图像上 + g.drawImage(image.getScaledInstance(newWidth, newHeight, Image.SCALE_SMOOTH), xOffset, yOffset, null); + g.dispose(); + + float[] inputData = new float[3 * targetWidth * targetHeight]; + + for (int c = 0; c < 3; c++) { + for (int y = 0; y < targetHeight; y++) { + for (int x = 0; x < targetWidth; x++) { + int rgb = resizedImage.getRGB(x, y); + float value = 0f; + if (c == 0) { + value = ((rgb >> 16) & 0xFF) / 255.0f; // Red channel + } else if (c == 1) { + value = ((rgb >> 8) & 0xFF) / 255.0f; // Green channel + } else if (c == 2) { + value = (rgb & 0xFF) / 255.0f; // Blue channel + } + inputData[c * targetWidth * targetHeight + y * targetWidth + x] = value; + } + } + } + + return inputData; + } + + // 非极大值抑制(NMS)方法 + private List nonMaximumSuppression(List boxes, float iouThreshold) { + // 按置信度排序 + boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence())); + + List result = new ArrayList<>(); + + while (!boxes.isEmpty()) { + BoundingBox bestBox = boxes.remove(0); + result.add(bestBox); + + Iterator iterator = boxes.iterator(); + while (iterator.hasNext()) { + BoundingBox box = iterator.next(); + if (computeIoU(bestBox, box) > iouThreshold) { + iterator.remove(); + } + } + } + + return result; + } + + // 其他方法保持不变... 
+} diff --git a/src/main/java/com/ly/onnx/engine/InferenceEngine_up_1.java b/src/main/java/com/ly/onnx/engine/InferenceEngine_up_1.java new file mode 100644 index 0000000..c4c732e --- /dev/null +++ b/src/main/java/com/ly/onnx/engine/InferenceEngine_up_1.java @@ -0,0 +1,297 @@ +package com.ly.onnx.engine; + +import ai.onnxruntime.*; +import com.ly.onnx.model.BoundingBox; +import com.ly.onnx.model.InferenceResult; + +import javax.imageio.ImageIO; +import java.awt.*; +import java.awt.image.BufferedImage; +import java.io.File; +import java.nio.FloatBuffer; +import java.util.List; +import java.util.*; + +public class InferenceEngine_up_1 { + + OrtEnvironment environment = OrtEnvironment.getEnvironment(); + OrtSession.SessionOptions sessionOptions = null; + OrtSession session = null; + private String modelPath; + private List labels; + + // 添加用于存储图像预处理信息的类变量 + private int origWidth; + private int origHeight; + private int newWidth; + private int newHeight; + private float scalingFactor; + private int xOffset; + private int yOffset; + + public InferenceEngine_up_1(String modelPath, List labels) { + this.modelPath = modelPath; + this.labels = labels; + init(); + } + + public void init() { + + try { + sessionOptions = new OrtSession.SessionOptions(); + sessionOptions.addCUDA(0); + session = environment.createSession(modelPath, sessionOptions); + } catch (OrtException e) { + throw new RuntimeException(e); + } + logModelInfo(session); + } + + public InferenceResult infer(float[] inputData, int w, int h) { + // 创建ONNX输入Tensor + try { + Map inputInfo = session.getInputInfo(); + String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入 + + long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状 + OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape); + + // 执行推理 + OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor)); + + // 解析推理结果 + String outputName = 
session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出 + float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状:[1, N, 5] + + long l = System.currentTimeMillis(); + // 设定置信度阈值 + float confidenceThreshold = 0.5f; // 您可以根据需要调整 + + // 根据模型的输出结果解析边界框 + List boxes = new ArrayList<>(); + for (float[] data : outputData[0]) { // 遍历所有检测框 + float confidence = data[4]; + if (confidence >= confidenceThreshold) { + float xCenter = data[0]; + float yCenter = data[1]; + float widthBox = data[2]; + float heightBox = data[3]; + + // 调整坐标,减去偏移并除以缩放因子 + float xCenterAdjusted = (xCenter - xOffset) / scalingFactor; + float yCenterAdjusted = (yCenter - yOffset) / scalingFactor; + float widthAdjusted = widthBox / scalingFactor; + float heightAdjusted = heightBox / scalingFactor; + + // 计算左上角坐标 + int x = (int) (xCenterAdjusted - widthAdjusted / 2); + int y = (int) (yCenterAdjusted - heightAdjusted / 2); + int wBox = (int) widthAdjusted; + int hBox = (int) heightAdjusted; + + // 确保坐标在原始图像范围内 + if (x < 0) x = 0; + if (y < 0) y = 0; + if (x + wBox > origWidth) wBox = origWidth - x; + if (y + hBox > origHeight) hBox = origHeight - y; + + String label = "person"; // 由于只有一个类别 + + boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence)); + } + } + + // 非极大值抑制(NMS) + List nmsBoxes = nonMaximumSuppression(boxes, 0.5f); + System.out.println("耗时:"+((System.currentTimeMillis() - l))); + // 封装结果并返回 + InferenceResult inferenceResult = new InferenceResult(); + inferenceResult.setBoundingBoxes(nmsBoxes); + return inferenceResult; + + } catch (OrtException e) { + throw new RuntimeException("推理失败", e); + } + } + + // 计算两个边界框的 IoU + private float computeIoU(BoundingBox box1, BoundingBox box2) { + int x1 = Math.max(box1.getX(), box2.getX()); + int y1 = Math.max(box1.getY(), box2.getY()); + int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth()); + int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight()); + + int 
intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1); + int box1Area = box1.getWidth() * box1.getHeight(); + int box2Area = box2.getWidth() * box2.getHeight(); + + return (float) intersectionArea / (box1Area + box2Area - intersectionArea); + } + + // 打印模型信息 + private void logModelInfo(OrtSession session) { + System.out.println("模型输入信息:"); + try { + for (Map.Entry entry : session.getInputInfo().entrySet()) { + String name = entry.getKey(); + NodeInfo info = entry.getValue(); + System.out.println("输入名称: " + name); + System.out.println("输入信息: " + info.toString()); + } + } catch (OrtException e) { + throw new RuntimeException(e); + } + + System.out.println("模型输出信息:"); + try { + for (Map.Entry entry : session.getOutputInfo().entrySet()) { + String name = entry.getKey(); + NodeInfo info = entry.getValue(); + System.out.println("输出名称: " + name); + System.out.println("输出信息: " + info.toString()); + } + } catch (OrtException e) { + throw new RuntimeException(e); + } + } + + public static void main(String[] args) { + // 初始化标签列表 + List labels = Arrays.asList("person"); + + // 创建 InferenceEngine 实例 + InferenceEngine_up_1 inferenceEngine = new InferenceEngine_up_1("D:\\work\\work_space\\java\\onnx-inference4j-play\\src\\main\\resources\\model\\best.onnx", labels); + + try { + // 加载图片 + File imageFile = new File("C:\\Users\\ly\\Desktop\\resuouce\\image\\1.jpg"); + BufferedImage inputImage = ImageIO.read(imageFile); + + // 预处理图像 + float[] inputData = inferenceEngine.preprocessImage(inputImage); + + // 执行推理 + InferenceResult result = null; + for (int i = 0; i < 100; i++) { + long l = System.currentTimeMillis(); + result = inferenceEngine.infer(inputData, 640, 640); + System.out.println(System.currentTimeMillis() - l); + } + + // 处理并显示结果 + System.out.println("推理结果:"); + for (BoundingBox box : result.getBoundingBoxes()) { + System.out.println(box); + } + + // 可视化并保存带有边界框的图像 + BufferedImage outputImage = inferenceEngine.drawBoundingBoxes(inputImage, 
result.getBoundingBoxes()); + + // 保存图片到本地文件 + File outputFile = new File("output_image_with_boxes.jpg"); + ImageIO.write(outputImage, "jpg", outputFile); + + System.out.println("已保存结果图片: " + outputFile.getAbsolutePath()); + + } catch (Exception e) { + e.printStackTrace(); + } + } + + // 在图像上绘制边界框和标签 + private BufferedImage drawBoundingBoxes(BufferedImage image, List boxes) { + Graphics2D g = image.createGraphics(); + g.setColor(Color.RED); // 设置绘制边界框的颜色 + g.setStroke(new BasicStroke(2)); // 设置线条粗细 + + for (BoundingBox box : boxes) { + // 绘制矩形边界框 + g.drawRect(box.getX(), box.getY(), box.getWidth(), box.getHeight()); + // 绘制标签文字和置信度 + String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence()); + g.setFont(new Font("Arial", Font.PLAIN, 12)); + g.drawString(label, box.getX(), box.getY() - 5); + } + + g.dispose(); // 释放资源 + return image; + } + + // 图像预处理 + private float[] preprocessImage(BufferedImage image) { + int targetWidth = 640; + int targetHeight = 640; + + origWidth = image.getWidth(); + origHeight = image.getHeight(); + + // 计算缩放因子 + scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight); + + // 计算新的图像尺寸 + newWidth = Math.round(origWidth * scalingFactor); + newHeight = Math.round(origHeight * scalingFactor); + + // 计算偏移量以居中图像 + xOffset = (targetWidth - newWidth) / 2; + yOffset = (targetHeight - newHeight) / 2; + + // 创建一个新的BufferedImage + BufferedImage resizedImage = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_RGB); + Graphics2D g = resizedImage.createGraphics(); + + // 填充背景为黑色 + g.setColor(Color.BLACK); + g.fillRect(0, 0, targetWidth, targetHeight); + + // 绘制缩放后的图像到新的图像上 + g.drawImage(image.getScaledInstance(newWidth, newHeight, Image.SCALE_SMOOTH), xOffset, yOffset, null); + g.dispose(); + + float[] inputData = new float[3 * targetWidth * targetHeight]; + + for (int c = 0; c < 3; c++) { + for (int y = 0; y < targetHeight; y++) { + for (int x = 0; x < targetWidth; x++) { + int 
rgb = resizedImage.getRGB(x, y); + float value = 0f; + if (c == 0) { + value = ((rgb >> 16) & 0xFF) / 255.0f; // Red channel + } else if (c == 1) { + value = ((rgb >> 8) & 0xFF) / 255.0f; // Green channel + } else if (c == 2) { + value = (rgb & 0xFF) / 255.0f; // Blue channel + } + inputData[c * targetWidth * targetHeight + y * targetWidth + x] = value; + } + } + } + + return inputData; + } + + // 非极大值抑制(NMS)方法 + private List nonMaximumSuppression(List boxes, float iouThreshold) { + // 按置信度排序 + boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence())); + + List result = new ArrayList<>(); + + while (!boxes.isEmpty()) { + BoundingBox bestBox = boxes.remove(0); + result.add(bestBox); + + Iterator iterator = boxes.iterator(); + while (iterator.hasNext()) { + BoundingBox box = iterator.next(); + if (computeIoU(bestBox, box) > iouThreshold) { + iterator.remove(); + } + } + } + + return result; + } + + // 其他方法保持不变... +} diff --git a/src/main/java/com/ly/onnx/engine/InferenceEngine_up_2.java b/src/main/java/com/ly/onnx/engine/InferenceEngine_up_2.java new file mode 100644 index 0000000..2262dc4 --- /dev/null +++ b/src/main/java/com/ly/onnx/engine/InferenceEngine_up_2.java @@ -0,0 +1,314 @@ +package com.ly.onnx.engine; + +import ai.onnxruntime.*; +import com.ly.onnx.model.BoundingBox; +import com.ly.onnx.model.InferenceResult; + +import javax.imageio.ImageIO; +import java.awt.*; +import java.awt.image.BufferedImage; +import java.io.File; +import java.nio.FloatBuffer; +import java.util.List; +import java.util.*; + +public class InferenceEngine_up_2 { + + private OrtEnvironment environment; + private OrtSession.SessionOptions sessionOptions; + private OrtSession session; // 将 session 作为类的成员变量 + + private String modelPath; + private List labels; + + // 添加用于存储图像预处理信息的类变量 + private int origWidth; + private int origHeight; + private int newWidth; + private int newHeight; + private float scalingFactor; + private int xOffset; + private int yOffset; + + public 
InferenceEngine_up_2(String modelPath, List labels) { + this.modelPath = modelPath; + this.labels = labels; + init(); + } + + public void init() { + try { + environment = OrtEnvironment.getEnvironment(); + sessionOptions = new OrtSession.SessionOptions(); + sessionOptions.addCUDA(0); // 使用 GPU + session = environment.createSession(modelPath, sessionOptions); + logModelInfo(session); + } catch (OrtException e) { + throw new RuntimeException("模型加载失败", e); + } + } + + public InferenceResult infer(float[] inputData, int w, int h) { + long startTime = System.currentTimeMillis(); + + try { + Map inputInfo = session.getInputInfo(); + String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入 + + long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状 + OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape); + + // 执行推理 + long inferenceStart = System.currentTimeMillis(); + OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor)); + long inferenceEnd = System.currentTimeMillis(); + System.out.println("模型推理耗时:" + (inferenceEnd - inferenceStart) + " ms"); + + // 解析推理结果 + String outputName = session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出 + float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状:[1, N, 5] + + // 设定置信度阈值 + float confidenceThreshold = 0.5f; // 您可以根据需要调整 + + // 根据模型的输出结果解析边界框 + List boxes = new ArrayList<>(); + for (float[] data : outputData[0]) { // 遍历所有检测框 + float confidence = data[4]; + if (confidence >= confidenceThreshold) { + float xCenter = data[0]; + float yCenter = data[1]; + float widthBox = data[2]; + float heightBox = data[3]; + + // 调整坐标,减去偏移并除以缩放因子 + float xCenterAdjusted = (xCenter - xOffset) / scalingFactor; + float yCenterAdjusted = (yCenter - yOffset) / scalingFactor; + float widthAdjusted = widthBox / scalingFactor; + float heightAdjusted = heightBox / scalingFactor; + + // 计算左上角坐标 + int x = (int) (xCenterAdjusted - 
widthAdjusted / 2); + int y = (int) (yCenterAdjusted - heightAdjusted / 2); + int wBox = (int) widthAdjusted; + int hBox = (int) heightAdjusted; + + // 确保坐标在原始图像范围内 + if (x < 0) x = 0; + if (y < 0) y = 0; + if (x + wBox > origWidth) wBox = origWidth - x; + if (y + hBox > origHeight) hBox = origHeight - y; + + String label = "person"; // 由于只有一个类别 + + boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence)); + } + } + + // 非极大值抑制(NMS) + long nmsStart = System.currentTimeMillis(); + List nmsBoxes = nonMaximumSuppression(boxes, 0.5f); + long nmsEnd = System.currentTimeMillis(); + System.out.println("NMS 耗时:" + (nmsEnd - nmsStart) + " ms"); + + // 封装结果并返回 + InferenceResult inferenceResult = new InferenceResult(); + inferenceResult.setBoundingBoxes(nmsBoxes); + + long endTime = System.currentTimeMillis(); + System.out.println("一次推理总耗时:" + (endTime - startTime) + " ms"); + + return inferenceResult; + + } catch (OrtException e) { + throw new RuntimeException("推理失败", e); + } + } + + // 计算两个边界框的 IoU + private float computeIoU(BoundingBox box1, BoundingBox box2) { + int x1 = Math.max(box1.getX(), box2.getX()); + int y1 = Math.max(box1.getY(), box2.getY()); + int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth()); + int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight()); + + int intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1); + int box1Area = box1.getWidth() * box1.getHeight(); + int box2Area = box2.getWidth() * box2.getHeight(); + + return (float) intersectionArea / (box1Area + box2Area - intersectionArea); + } + + // 打印模型信息 + private void logModelInfo(OrtSession session) { + System.out.println("模型输入信息:"); + try { + for (Map.Entry entry : session.getInputInfo().entrySet()) { + String name = entry.getKey(); + NodeInfo info = entry.getValue(); + System.out.println("输入名称: " + name); + System.out.println("输入信息: " + info.toString()); + } + } catch (OrtException e) { + throw new RuntimeException(e); + } + + 
System.out.println("模型输出信息:"); + try { + for (Map.Entry entry : session.getOutputInfo().entrySet()) { + String name = entry.getKey(); + NodeInfo info = entry.getValue(); + System.out.println("输出名称: " + name); + System.out.println("输出信息: " + info.toString()); + } + } catch (OrtException e) { + throw new RuntimeException(e); + } + } + + public static void main(String[] args) { + // 初始化标签列表 + List labels = Arrays.asList("person"); + + // 创建 InferenceEngine 实例 + InferenceEngine_up_2 inferenceEngine = new InferenceEngine_up_2("D:\\work\\work_space\\java\\onnx-inference4j-play\\src\\main\\resources\\model\\best.onnx", labels); + + try { + // 加载图片 + File imageFile = new File("C:\\Users\\ly\\Desktop\\resuouce\\image\\1.jpg"); + BufferedImage inputImage = ImageIO.read(imageFile); + + // 预处理图像 + long l1 = System.currentTimeMillis(); + float[] inputData = inferenceEngine.preprocessImage(inputImage); + System.out.println("转"+(System.currentTimeMillis() - l1)); + // 执行推理 + InferenceResult result = null; + for (int i = 0; i < 10; i++) { + long l = System.currentTimeMillis(); + result = inferenceEngine.infer(inputData, 640, 640); + System.out.println("第 " + (i + 1) + " 次推理耗时:" + (System.currentTimeMillis() - l) + " ms"); + } + + // 处理并显示结果 + System.out.println("推理结果:"); + for (BoundingBox box : result.getBoundingBoxes()) { + System.out.println(box); + } + + // 可视化并保存带有边界框的图像 + BufferedImage outputImage = inferenceEngine.drawBoundingBoxes(inputImage, result.getBoundingBoxes()); + + // 保存图片到本地文件 + File outputFile = new File("output_image_with_boxes.jpg"); + ImageIO.write(outputImage, "jpg", outputFile); + + System.out.println("已保存结果图片: " + outputFile.getAbsolutePath()); + + } catch (Exception e) { + e.printStackTrace(); + } + } + + // 在图像上绘制边界框和标签 + private BufferedImage drawBoundingBoxes(BufferedImage image, List boxes) { + Graphics2D g = image.createGraphics(); + g.setColor(Color.RED); // 设置绘制边界框的颜色 + g.setStroke(new BasicStroke(2)); // 设置线条粗细 + + for (BoundingBox box : boxes) { 
+ // 绘制矩形边界框 + g.drawRect(box.getX(), box.getY(), box.getWidth(), box.getHeight()); + // 绘制标签文字和置信度 + String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence()); + g.setFont(new Font("Arial", Font.PLAIN, 12)); + g.drawString(label, box.getX(), box.getY() - 5); + } + + g.dispose(); // 释放资源 + return image; + } + + // 图像预处理 + public float[] preprocessImage(BufferedImage image) { + int targetWidth = 640; + int targetHeight = 640; + + origWidth = image.getWidth(); + origHeight = image.getHeight(); + + // 计算缩放因子 + scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight); + + // 计算新的图像尺寸 + newWidth = Math.round(origWidth * scalingFactor); + newHeight = Math.round(origHeight * scalingFactor); + + // 计算偏移量以居中图像 + xOffset = (targetWidth - newWidth) / 2; + yOffset = (targetHeight - newHeight) / 2; + + // 创建一个新的BufferedImage + BufferedImage resizedImage = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_RGB); + Graphics2D g = resizedImage.createGraphics(); + + // 填充背景为黑色 + g.setColor(Color.BLACK); + g.fillRect(0, 0, targetWidth, targetHeight); + + // 绘制缩放后的图像到新的图像上 + g.drawImage(image.getScaledInstance(newWidth, newHeight, Image.SCALE_SMOOTH), xOffset, yOffset, null); + g.dispose(); + + float[] inputData = new float[3 * targetWidth * targetHeight]; + + // 开始计时 + long preprocessStart = System.currentTimeMillis(); + + for (int c = 0; c < 3; c++) { + for (int y = 0; y < targetHeight; y++) { + for (int x = 0; x < targetWidth; x++) { + int rgb = resizedImage.getRGB(x, y); + float value = 0f; + if (c == 0) { + value = ((rgb >> 16) & 0xFF) / 255.0f; // Red channel + } else if (c == 1) { + value = ((rgb >> 8) & 0xFF) / 255.0f; // Green channel + } else if (c == 2) { + value = (rgb & 0xFF) / 255.0f; // Blue channel + } + inputData[c * targetWidth * targetHeight + y * targetWidth + x] = value; + } + } + } + + // 结束计时 + long preprocessEnd = System.currentTimeMillis(); + System.out.println("图像预处理耗时:" + 
(preprocessEnd - preprocessStart) + " ms"); + + return inputData; + } + + // 非极大值抑制(NMS)方法 + private List nonMaximumSuppression(List boxes, float iouThreshold) { + // 按置信度排序(从高到低) + boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence())); + + List result = new ArrayList<>(); + + while (!boxes.isEmpty()) { + BoundingBox bestBox = boxes.remove(0); + result.add(bestBox); + + Iterator iterator = boxes.iterator(); + while (iterator.hasNext()) { + BoundingBox box = iterator.next(); + if (computeIoU(bestBox, box) > iouThreshold) { + iterator.remove(); + } + } + } + + return result; + } +} diff --git a/src/main/java/com/ly/onnx/model/BoundingBox.java b/src/main/java/com/ly/onnx/model/BoundingBox.java new file mode 100644 index 0000000..332a888 --- /dev/null +++ b/src/main/java/com/ly/onnx/model/BoundingBox.java @@ -0,0 +1,27 @@ +package com.ly.onnx.model; + +import lombok.Data; + +@Data +public class BoundingBox { + private int x; + private int y; + private int width; + private int height; + private String label; + private float confidence; + + // 构造函数、getter 和 setter 方法 + + public BoundingBox(int x, int y, int width, int height, String label, float confidence) { + this.x = x; + this.y = y; + this.width = width; + this.height = height; + this.label = label; + this.confidence = confidence; + } + + // Getter 和 Setter 方法 + // ... 
+} diff --git a/src/main/java/com/ly/onnx/model/InferenceResult.java b/src/main/java/com/ly/onnx/model/InferenceResult.java new file mode 100644 index 0000000..f9e8ea8 --- /dev/null +++ b/src/main/java/com/ly/onnx/model/InferenceResult.java @@ -0,0 +1,18 @@ +package com.ly.onnx.model; + +import java.util.ArrayList; +import java.util.List; + +public class InferenceResult { + private List boundingBoxes = new ArrayList<>(); + + public List getBoundingBoxes() { + return boundingBoxes; + } + + public void setBoundingBoxes(List boundingBoxes) { + this.boundingBoxes = boundingBoxes; + } + + // 其他需要的属性和方法 +} diff --git a/src/main/java/com/ly/onnx/model/ModelInfo.java b/src/main/java/com/ly/onnx/model/ModelInfo.java new file mode 100644 index 0000000..3d0865b --- /dev/null +++ b/src/main/java/com/ly/onnx/model/ModelInfo.java @@ -0,0 +1,39 @@ +package com.ly.onnx.model; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.List; + +public class ModelInfo { + private String modelFilePath; + private String labelFilePath; + private List labels; + + public ModelInfo(String modelFilePath, String labelFilePath) { + this.modelFilePath = modelFilePath; + this.labelFilePath = labelFilePath; + try { + this.labels = Files.readAllLines(Paths.get(labelFilePath)); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public String getModelFilePath() { + return modelFilePath; + } + + public String getLabelFilePath() { + return labelFilePath; + } + + public List getLabels() { + return labels; + } + + @Override + public String toString() { + return "模型文件: " + modelFilePath + "\n标签文件: " + labelFilePath; + } +} diff --git a/src/main/java/com/ly/onnx/utils/DrawImagesUtils.java b/src/main/java/com/ly/onnx/utils/DrawImagesUtils.java new file mode 100644 index 0000000..240da5d --- /dev/null +++ b/src/main/java/com/ly/onnx/utils/DrawImagesUtils.java @@ -0,0 +1,53 @@ +package com.ly.onnx.utils; + +import 
com.ly.onnx.model.BoundingBox; +import com.ly.onnx.model.InferenceResult; +import org.opencv.core.Mat; +import org.opencv.core.Point; +import org.opencv.core.Scalar; +import org.opencv.imgproc.Imgproc; + +import java.awt.image.BufferedImage; +import java.util.List; + +public class DrawImagesUtils { + + + public static void drawInferenceResult(BufferedImage bufferedImage, List result) { + + } + + // 在 Mat 上绘制推理结果 + public static void drawInferenceResult(Mat image, List inferenceResults) { + for (InferenceResult result : inferenceResults) { + for (BoundingBox box : result.getBoundingBoxes()) { + // 绘制矩形 + Point topLeft = new Point(box.getX(), box.getY()); + Point bottomRight = new Point(box.getX() + box.getWidth(), box.getY() + box.getHeight()); + Imgproc.rectangle(image, topLeft, bottomRight, new Scalar(0, 0, 255), 2); // 红色边框 + + // 绘制标签 + String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence()); + int font = Imgproc.FONT_HERSHEY_SIMPLEX; + double fontScale = 0.5; + int thickness = 1; + + // 计算文字大小 + int[] baseLine = new int[1]; + org.opencv.core.Size labelSize = Imgproc.getTextSize(label, font, fontScale, thickness, baseLine); + + // 确保文字不会超出图像 + int y = Math.max((int) topLeft.y, (int) labelSize.height); + + // 绘制文字背景 + Imgproc.rectangle(image, new Point(topLeft.x, y - labelSize.height), + new Point(topLeft.x + labelSize.width, y + baseLine[0]), + new Scalar(0, 0, 255), Imgproc.FILLED); + + // 绘制文字 + Imgproc.putText(image, label, new Point(topLeft.x, y), + font, fontScale, new Scalar(255, 255, 255), thickness); + } + } + } +} diff --git a/src/main/java/com/ly/onnx/utils/ImageUtils.java b/src/main/java/com/ly/onnx/utils/ImageUtils.java new file mode 100644 index 0000000..0fe7286 --- /dev/null +++ b/src/main/java/com/ly/onnx/utils/ImageUtils.java @@ -0,0 +1,106 @@ +package com.ly.onnx.utils; + +import org.bytedeco.javacv.Frame; +import org.opencv.core.*; +import org.opencv.imgproc.Imgproc; + +import java.awt.image.BufferedImage; +import 
java.nio.Buffer; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +public class ImageUtils { + + // 辅助方法:将 BufferedImage 转换为浮点数组(根据您的模型需求) + private static float[] preprocessImage(BufferedImage image) { + int width = image.getWidth(); + int height = image.getHeight(); + float[] result = new float[width * height * 3]; // 假设是 RGB 图像 + int idx = 0; + + for (int y = 0; y < height; y++) { + for (int x = 0; x < width; x++) { + int pixel = image.getRGB(x, y); + // 分别获取 R, G, B 值并归一化(假设归一化到 [0, 1]) + result[idx++] = ((pixel >> 16) & 0xFF) / 255.0f; // Red + result[idx++] = ((pixel >> 8) & 0xFF) / 255.0f; // Green + result[idx++] = (pixel & 0xFF) / 255.0f; // Blue + } + } + return result; + } + + + + + public static float[] frameToFloatArray(Frame frame) { + // 获取 Frame 的宽度和高度 + int width = frame.imageWidth; + int height = frame.imageHeight; + + // 获取 Frame 的像素数据 + Buffer buffer = frame.image[0]; // 获取图像数据缓冲区 + ByteBuffer byteBuffer = (ByteBuffer) buffer; // 假设图像数据是以字节缓冲存储 + + // 创建 float 数组来存储图像的 RGB 值 + float[] result = new float[width * height * 3]; // 假设是 RGB 格式图像 + int idx = 0; + + // 遍历每个像素,提取 R, G, B 值并归一化到 [0, 1] + for (int i = 0; i < byteBuffer.capacity(); i += 3) { + // 提取 RGB 通道数据 + int r = byteBuffer.get(i) & 0xFF; // Red + int g = byteBuffer.get(i + 1) & 0xFF; // Green + int b = byteBuffer.get(i + 2) & 0xFF; // Blue + + // 将 RGB 值归一化为 float 并存入数组 + result[idx++] = r / 255.0f; + result[idx++] = g / 255.0f; + result[idx++] = b / 255.0f; + } + + return result; + } + // 将 Mat 转换为 BufferedImage + public static BufferedImage matToBufferedImage(Mat mat) { + int type = BufferedImage.TYPE_3BYTE_BGR; + if (mat.channels() == 1) { + type = BufferedImage.TYPE_BYTE_GRAY; + } + int bufferSize = mat.channels() * mat.cols() * mat.rows(); + byte[] buffer = new byte[bufferSize]; + mat.get(0, 0, buffer); // 获取所有像素 + BufferedImage image = new BufferedImage(mat.cols(), mat.rows(), type); + final byte[] targetPixels = 
((java.awt.image.DataBufferByte) image.getRaster().getDataBuffer()).getData(); + System.arraycopy(buffer, 0, targetPixels, 0, buffer.length); + return image; + } + + // 将 Mat 转换为 float 数组,适用于推理 + public static float[] matToFloatArray(Mat mat) { + // 假设 InferenceEngine 需要 RGB 格式的图像 + Mat rgbMat = new Mat(); + Imgproc.cvtColor(mat, rgbMat, Imgproc.COLOR_BGR2RGB); + + // 假设图像已经被预处理(缩放、归一化等),否则需要在这里添加预处理步骤 + + // 将 Mat 数据转换为 float 数组 + int channels = rgbMat.channels(); + int rows = rgbMat.rows(); + int cols = rgbMat.cols(); + float[] floatData = new float[channels * rows * cols]; + byte[] byteData = new byte[channels * rows * cols]; + rgbMat.get(0, 0, byteData); + for (int i = 0; i < floatData.length; i++) { + // 将 unsigned byte 转换为 float [0,1] + floatData[i] = (byteData[i] & 0xFF) / 255.0f; + } + rgbMat.release(); + return floatData; + } + + + + +} diff --git a/src/main/java/com/ly/play/VideoPlayer.java b/src/main/java/com/ly/play/VideoPlayer.java deleted file mode 100644 index 1e91524..0000000 --- a/src/main/java/com/ly/play/VideoPlayer.java +++ /dev/null @@ -1,280 +0,0 @@ -package com.ly.play; - -import com.ly.layout.VideoPanel; -import org.bytedeco.javacv.*; - -import javax.swing.*; -import java.awt.image.BufferedImage; - -public class VideoPlayer { - private FrameGrabber grabber; - private Java2DFrameConverter converter = new Java2DFrameConverter(); - private boolean isPlaying = false; - private boolean isPaused = false; - private Thread videoThread; - private VideoPanel videoPanel; - - private long videoDuration = 0; // 毫秒 - private long currentTimestamp = 0; // 毫秒 - - public VideoPlayer(VideoPanel videoPanel) { - this.videoPanel = videoPanel; - } - - // 加载视频或流 - // 加载视频或流 - public void loadVideo(String videoFilePathOrStreamUrl) throws Exception { - stopVideo(); - - - if (videoFilePathOrStreamUrl.equals("0")) { - int cameraIndex = Integer.parseInt(videoFilePathOrStreamUrl); - grabber = new OpenCVFrameGrabber(cameraIndex); - grabber.start(); - videoDuration = 0; 
// 摄像头没有固定的时长 - playVideo(); - } else { - // 输入不是数字,尝试使用 FFmpegFrameGrabber 打开流或视频文件 - grabber = new FFmpegFrameGrabber(videoFilePathOrStreamUrl); - grabber.start(); - videoDuration = grabber.getLengthInTime() / 1000; // 转换为毫秒 - } - - - // 显示第一帧 - Frame frame; - if (grabber instanceof OpenCVFrameGrabber) { - frame = grabber.grab(); - } else { - frame = grabber.grab(); - } - if (frame != null && frame.image != null) { - BufferedImage bufferedImage = converter.getBufferedImage(frame); - videoPanel.updateImage(bufferedImage); - currentTimestamp = 0; - } - - // 重置到视频开始位置 - if (grabber instanceof FFmpegFrameGrabber) { - grabber.setTimestamp(0); - } - currentTimestamp = 0; - } - - public void playVideo() { - if (grabber == null) { - JOptionPane.showMessageDialog(null, "请先加载视频文件或流。", "提示", JOptionPane.WARNING_MESSAGE); - return; - } - - if (isPlaying) { - if (isPaused) { - isPaused = false; // 恢复播放 - } - return; - } - - isPlaying = true; - isPaused = false; - - videoThread = new Thread(() -> { - try { - if (grabber instanceof OpenCVFrameGrabber) { - // 摄像头捕获 - while (isPlaying) { - if (isPaused) { - Thread.sleep(100); - continue; - } - - Frame frame = grabber.grab(); - if (frame == null) { - isPlaying = false; - break; - } - - BufferedImage bufferedImage = converter.getBufferedImage(frame); - if (bufferedImage != null) { - videoPanel.updateImage(bufferedImage); - } - } - } else { - // 视频文件或流 - double frameRate = grabber.getFrameRate(); - if (frameRate <= 0 || Double.isNaN(frameRate)) { - frameRate = 25; // 默认帧率 - } - long frameInterval = (long) (1000 / frameRate); // 每帧间隔时间(毫秒) - long startTime = System.currentTimeMillis(); - long frameCount = 0; - - while (isPlaying) { - if (isPaused) { - Thread.sleep(100); - startTime += 100; // 调整开始时间以考虑暂停时间 - continue; - } - - Frame frame = grabber.grab(); - if (frame == null) { - // 视频播放结束 - isPlaying = false; - break; - } - - BufferedImage bufferedImage = converter.getBufferedImage(frame); - if (bufferedImage != null) { - 
videoPanel.updateImage(bufferedImage); - - // 更新当前帧时间戳 - frameCount++; - long expectedTime = frameCount * frameInterval; - long actualTime = System.currentTimeMillis() - startTime; - - currentTimestamp = grabber.getTimestamp() / 1000; - - // 如果实际时间落后于预期时间,进行调整 - if (actualTime < expectedTime) { - Thread.sleep(expectedTime - actualTime); - } - } - } - } - - // 视频播放完毕后,停止播放 - isPlaying = false; - - } catch (Exception ex) { - ex.printStackTrace(); - } - }); - videoThread.start(); - } - - // 暂停视频 - public void pauseVideo() { - if (!isPlaying) { - return; - } - isPaused = true; - } - - // 重播视频 - public void replayVideo() { - try { - if (grabber instanceof FFmpegFrameGrabber) { - grabber.setTimestamp(0); // 重置到视频开始位置 - grabber.flush(); // 清除缓存 - currentTimestamp = 0; - - // 显示第一帧 - Frame frame = grabber.grab(); - if (frame != null && frame.image != null) { - BufferedImage bufferedImage = converter.getBufferedImage(frame); - videoPanel.updateImage(bufferedImage); - } - - playVideo(); // 开始播放 - } else if (grabber instanceof OpenCVFrameGrabber) { - // 对于摄像头,重播相当于重新开始播放 - playVideo(); - } - } catch (Exception e) { - e.printStackTrace(); - JOptionPane.showMessageDialog(null, "重播失败: " + e.getMessage(), "错误", JOptionPane.ERROR_MESSAGE); - } - } - - // 停止视频 - public void stopVideo() { - isPlaying = false; - isPaused = false; - - if (videoThread != null && videoThread.isAlive()) { - try { - videoThread.join(); - } catch (InterruptedException e) { - e.printStackTrace(); - } - } - - if (grabber != null) { - try { - grabber.stop(); - grabber.release(); - } catch (Exception ex) { - ex.printStackTrace(); - } - grabber = null; - } - } - - // 快进或后退 - public void seekTo(long seekTime) { - if (grabber == null) return; - if (!(grabber instanceof FFmpegFrameGrabber)) { - JOptionPane.showMessageDialog(null, "此操作仅支持视频文件和流。", "提示", JOptionPane.WARNING_MESSAGE); - return; - } - try { - isPaused = false; // 取消暂停 - isPlaying = false; // 停止当前播放线程 - if (videoThread != null && videoThread.isAlive()) 
{ - videoThread.join(); - } - - grabber.setTimestamp(seekTime * 1000); // 转换为微秒 - grabber.flush(); // 清除缓存 - - Frame frame; - do { - frame = grabber.grab(); - if (frame == null) { - break; - } - } while (frame.image == null); // 跳过没有图像的帧 - - if (frame != null && frame.image != null) { - BufferedImage bufferedImage = converter.getBufferedImage(frame); - videoPanel.updateImage(bufferedImage); - - // 更新当前帧时间戳 - currentTimestamp = grabber.getTimestamp() / 1000; - } - - // 重新开始播放 - playVideo(); - - } catch (Exception ex) { - ex.printStackTrace(); - } - } - - // 快进 - public void fastForward(long milliseconds) { - if (!(grabber instanceof FFmpegFrameGrabber)) { - JOptionPane.showMessageDialog(null, "此操作仅支持视频文件和流。", "提示", JOptionPane.WARNING_MESSAGE); - return; - } - long newTime = Math.min(currentTimestamp + milliseconds, videoDuration); - seekTo(newTime); - } - - // 后退 - public void rewind(long milliseconds) { - if (!(grabber instanceof FFmpegFrameGrabber)) { - JOptionPane.showMessageDialog(null, "此操作仅支持视频文件和流。", "提示", JOptionPane.WARNING_MESSAGE); - return; - } - long newTime = Math.max(currentTimestamp - milliseconds, 0); - seekTo(newTime); - } - - public long getVideoDuration() { - return videoDuration; - } - - public FrameGrabber getGrabber() { - return grabber; - } -} diff --git a/src/main/java/com/ly/play/ff/VideoPlayer.java b/src/main/java/com/ly/play/ff/VideoPlayer.java new file mode 100644 index 0000000..1d5a46b --- /dev/null +++ b/src/main/java/com/ly/play/ff/VideoPlayer.java @@ -0,0 +1,328 @@ +//package com.ly.play.ff; +// +//import com.ly.layout.VideoPanel; +//import com.ly.model_load.ModelManager; +//import com.ly.onnx.engine.InferenceEngine; +//import com.ly.onnx.model.InferenceResult; +//import com.ly.onnx.utils.DrawImagesUtils; +//import com.ly.onnx.utils.ImageUtils; +//import org.bytedeco.javacv.*; +// +//import javax.swing.*; +//import java.awt.image.BufferedImage; +//import java.util.ArrayList; +//import java.util.List; +// +//public class VideoPlayer 
{ +// private FrameGrabber grabber; +// private Java2DFrameConverter converter = new Java2DFrameConverter(); +// private boolean isPlaying = false; +// private boolean isPaused = false; +// private Thread videoThread; +// private VideoPanel videoPanel; +// +// private long videoDuration = 0; // 毫秒 +// private long currentTimestamp = 0; // 毫秒 +// +// ModelManager modelManager; +// private List inferenceEngines = new ArrayList<>(); +// +// public VideoPlayer(VideoPanel videoPanel, ModelManager modelManager) { +// this.videoPanel = videoPanel; +// this.modelManager = modelManager; +// System.out.println(); +// } +// +// // 加载视频或流 +// public void loadVideo(String videoFilePathOrStreamUrl) throws Exception { +// stopVideo(); +// if (videoFilePathOrStreamUrl.equals("0")) { +// int cameraIndex = Integer.parseInt(videoFilePathOrStreamUrl); +// grabber = new OpenCVFrameGrabber(cameraIndex); +// grabber.start(); +// videoDuration = 0; // 摄像头没有固定的时长 +// playVideo(); +// } else { +// // 输入不是数字,尝试使用 FFmpegFrameGrabber 打开流或视频文件 +// grabber = new FFmpegFrameGrabber(videoFilePathOrStreamUrl); +// grabber.start(); +// videoDuration = grabber.getLengthInTime() / 1000; // 转换为毫秒 +// } +// +// +// // 显示第一帧 +// Frame frame; +// if (grabber instanceof OpenCVFrameGrabber) { +// frame = grabber.grab(); +// } else { +// frame = grabber.grab(); +// } +// if (frame != null && frame.image != null) { +// BufferedImage bufferedImage = converter.getBufferedImage(frame); +// videoPanel.updateImage(bufferedImage); +// currentTimestamp = 0; +// } +// +// // 重置到视频开始位置 +// if (grabber instanceof FFmpegFrameGrabber) { +// grabber.setTimestamp(0); +// } +// currentTimestamp = 0; +// } +// +// +// +// +// +// //播放 +// public void playVideo() { +// if (grabber == null) { +// JOptionPane.showMessageDialog(null, "请先加载视频文件或流。", "提示", JOptionPane.WARNING_MESSAGE); +// return; +// } +// +// if (inferenceEngines == null){ +// JOptionPane.showMessageDialog(null, "请先加载模型给文件。", "提示", JOptionPane.WARNING_MESSAGE); 
+// return; +// } +// +// if (isPlaying) { +// if (isPaused) { +// isPaused = false; // 恢复播放 +// } +// return; +// } +// +// isPlaying = true; +// isPaused = false; +// +// videoThread = new Thread(() -> { +// try { +// if (grabber instanceof OpenCVFrameGrabber) { +// // 摄像头捕获 +// while (isPlaying) { +// if (isPaused) { +// Thread.sleep(10); +// continue; +// } +// +// Frame frame = grabber.grab(); +// if (frame == null) { +// isPlaying = false; +// break; +// } +// +// BufferedImage bufferedImage = converter.getBufferedImage(frame); +// List inferenceResults = new ArrayList<>(); +// if (bufferedImage != null) { +// float[] inputData = ImageUtils.frameToFloatArray(frame); +// for (InferenceEngine inferenceEngine : inferenceEngines) { +// inferenceResults.add(inferenceEngine.infer(inputData,640,640)); +// } +// //绘制 +// DrawImagesUtils.drawInferenceResult(bufferedImage,inferenceResults); +// //更新绘制后图像 +// videoPanel.updateImage(bufferedImage); +// } +// } +// } else { +// // 视频文件或流 +// double frameRate = grabber.getFrameRate(); +// if (frameRate <= 0 || Double.isNaN(frameRate)) { +// frameRate = 25; // 默认帧率 +// } +// long frameInterval = (long) (1000 / frameRate); // 每帧间隔时间(毫秒) +// long startTime = System.currentTimeMillis(); +// long frameCount = 0; +// +// while (isPlaying) { +// if (isPaused) { +// Thread.sleep(100); +// startTime += 100; // 调整开始时间以考虑暂停时间 +// continue; +// } +// +// Frame frame = grabber.grab(); +// if (frame == null) { +// // 视频播放结束 +// isPlaying = false; +// break; +// } +// +// +// +// BufferedImage bufferedImage = converter.getBufferedImage(frame); +// +// +// List inferenceResults = new ArrayList<>(); +// if (bufferedImage != null) { +// float[] inputData = ImageUtils.frameToFloatArray(frame); +// for (InferenceEngine inferenceEngine : inferenceEngines) { +// inferenceResults.add(inferenceEngine.infer(inputData,640,640)); +// } +// //绘制 +// DrawImagesUtils.drawInferenceResult(bufferedImage,inferenceResults); +// //更新绘制后图像 +// 
videoPanel.updateImage(bufferedImage); +// } +// +// if (bufferedImage != null) { +// videoPanel.updateImage(bufferedImage); +// +// // 更新当前帧时间戳 +// frameCount++; +// long expectedTime = frameCount * frameInterval; +// long actualTime = System.currentTimeMillis() - startTime; +// +// currentTimestamp = grabber.getTimestamp() / 1000; +// +// // 如果实际时间落后于预期时间,进行调整 +// if (actualTime < expectedTime) { +// Thread.sleep(expectedTime - actualTime); +// } +// } +// } +// } +// +// // 视频播放完毕后,停止播放 +// isPlaying = false; +// +// } catch (Exception ex) { +// ex.printStackTrace(); +// } +// }); +// videoThread.start(); +// } +// +// // 暂停视频 +// public void pauseVideo() { +// if (!isPlaying) { +// return; +// } +// isPaused = true; +// } +// +// // 重播视频 +// public void replayVideo() { +// try { +// if (grabber instanceof FFmpegFrameGrabber) { +// grabber.setTimestamp(0); // 重置到视频开始位置 +// grabber.flush(); // 清除缓存 +// currentTimestamp = 0; +// +// // 显示第一帧 +// Frame frame = grabber.grab(); +// if (frame != null && frame.image != null) { +// BufferedImage bufferedImage = converter.getBufferedImage(frame); +// videoPanel.updateImage(bufferedImage); +// } +// +// playVideo(); // 开始播放 +// } else if (grabber instanceof OpenCVFrameGrabber) { +// // 对于摄像头,重播相当于重新开始播放 +// playVideo(); +// } +// } catch (Exception e) { +// e.printStackTrace(); +// JOptionPane.showMessageDialog(null, "重播失败: " + e.getMessage(), "错误", JOptionPane.ERROR_MESSAGE); +// } +// } +// +// // 停止视频 +// public void stopVideo() { +// isPlaying = false; +// isPaused = false; +// +// if (videoThread != null && videoThread.isAlive()) { +// try { +// videoThread.join(); +// } catch (InterruptedException e) { +// e.printStackTrace(); +// } +// } +// +// if (grabber != null) { +// try { +// grabber.stop(); +// grabber.release(); +// } catch (Exception ex) { +// ex.printStackTrace(); +// } +// grabber = null; +// } +// } +// +// // 快进或后退 +// public void seekTo(long seekTime) { +// if (grabber == null) return; +// if 
(!(grabber instanceof FFmpegFrameGrabber)) { +// JOptionPane.showMessageDialog(null, "此操作仅支持视频文件和流。", "提示", JOptionPane.WARNING_MESSAGE); +// return; +// } +// try { +// isPaused = false; // 取消暂停 +// isPlaying = false; // 停止当前播放线程 +// if (videoThread != null && videoThread.isAlive()) { +// videoThread.join(); +// } +// +// grabber.setTimestamp(seekTime * 1000); // 转换为微秒 +// grabber.flush(); // 清除缓存 +// +// Frame frame; +// do { +// frame = grabber.grab(); +// if (frame == null) { +// break; +// } +// } while (frame.image == null); // 跳过没有图像的帧 +// +// if (frame != null && frame.image != null) { +// BufferedImage bufferedImage = converter.getBufferedImage(frame); +// videoPanel.updateImage(bufferedImage); +// +// // 更新当前帧时间戳 +// currentTimestamp = grabber.getTimestamp() / 1000; +// } +// +// // 重新开始播放 +// playVideo(); +// +// } catch (Exception ex) { +// ex.printStackTrace(); +// } +// } +// +// // 快进 +// public void fastForward(long milliseconds) { +// if (!(grabber instanceof FFmpegFrameGrabber)) { +// JOptionPane.showMessageDialog(null, "此操作仅支持视频文件和流。", "提示", JOptionPane.WARNING_MESSAGE); +// return; +// } +// long newTime = Math.min(currentTimestamp + milliseconds, videoDuration); +// seekTo(newTime); +// } +// +// // 后退 +// public void rewind(long milliseconds) { +// if (!(grabber instanceof FFmpegFrameGrabber)) { +// JOptionPane.showMessageDialog(null, "此操作仅支持视频文件和流。", "提示", JOptionPane.WARNING_MESSAGE); +// return; +// } +// long newTime = Math.max(currentTimestamp - milliseconds, 0); +// seekTo(newTime); +// } +// +// public long getVideoDuration() { +// return videoDuration; +// } +// +// public FrameGrabber getGrabber() { +// return grabber; +// } +// +// public void addInferenceEngines(InferenceEngine inferenceEngine){ +// this.inferenceEngines.add(inferenceEngine); +// } +// +//} diff --git a/src/main/java/com/ly/play/opencv/VideoPlayer.java b/src/main/java/com/ly/play/opencv/VideoPlayer.java new file mode 100644 index 0000000..d341b70 --- /dev/null +++ 
package com.ly.play.opencv;

import com.ly.layout.VideoPanel;
import com.ly.model_load.ModelManager;
import com.ly.onnx.engine.InferenceEngine;
import com.ly.onnx.model.InferenceResult;
import com.ly.onnx.utils.DrawImagesUtils;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.Rect;
import org.opencv.core.Size;
import org.opencv.imgproc.Imgproc;
import org.opencv.videoio.VideoCapture;
import org.opencv.videoio.Videoio;

import javax.swing.*;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.io.IOException;
import java.io.InputStream;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.*;

import static com.ly.onnx.utils.ImageUtils.matToBufferedImage;

/**
 * Plays a camera, video file or stream through OpenCV and shows the frames on a {@link VideoPanel}.
 *
 * <p>Playback uses two worker threads connected by a bounded queue: a reader thread that decodes
 * and pre-processes frames, and an inference thread that draws results and updates the panel.
 *
 * <p>The control methods ({@code loadVideo}, {@code playVideo}, {@code stopVideo}, ...) are not
 * synchronized. NOTE(review): they are presumably all driven from the Swing event thread; confirm
 * against the callers before invoking them from several threads.
 */
public class VideoPlayer {

    /** Side length of the square model input, in pixels. */
    private static final int INPUT_WIDTH = 640;
    private static final int INPUT_HEIGHT = 640;

    /** Frame rate assumed when the source does not report a usable one. */
    private static final double DEFAULT_FPS = 25;

    /** Maximum number of decoded frames waiting for the inference thread. */
    private static final int FRAME_QUEUE_CAPACITY = 10;

    /** Upper bound on how long a stop request waits for one worker thread to finish. */
    private static final long WORKER_JOIN_TIMEOUT_MS = 2000;

    /** Classpath location of the FFmpeg video backend that is pre-loaded on Windows. */
    private static final String FFMPEG_BACKEND_RESOURCE = "lib/win/opencv_videoio_ffmpeg470_64.dll";

    static {
        // Load the OpenCV native library, then the FFmpeg backend on Windows.
        nu.pattern.OpenCV.loadLocally();
        String osName = System.getProperty("os.name", "").toLowerCase(Locale.ROOT);
        if (osName.contains("win")) {
            loadBundledLibrary(FFMPEG_BACKEND_RESOURCE);
        }
    }

    private VideoCapture videoCapture;
    private volatile boolean isPlaying = false;
    private volatile boolean isPaused = false;
    private volatile Thread frameReadingThread;
    private volatile Thread inferenceThread;
    private final VideoPanel videoPanel;

    private long videoDuration = 0; // milliseconds; 0 when unknown (camera / live stream)
    private volatile long currentTimestamp = 0; // milliseconds; written by the reader thread

    private final ModelManager modelManager;

    // Engines can be added while the inference thread is iterating, hence the copy-on-write list.
    private final List<InferenceEngine> inferenceEngines = new CopyOnWriteArrayList<>();

    // Bounded queue that hands decoded frames from the reader thread to the inference thread.
    private final BlockingQueue<FrameData> frameDataQueue = new LinkedBlockingQueue<>(FRAME_QUEUE_CAPACITY);

    /**
     * Creates a player that renders onto the given panel.
     *
     * @param videoPanel   panel that receives every displayed frame
     * @param modelManager model manager of the application (stored, not used by this class yet)
     */
    public VideoPlayer(VideoPanel videoPanel, ModelManager modelManager) {
        this.videoPanel = videoPanel;
        this.modelManager = modelManager;
    }

    /**
     * Loads a native library that ships on the classpath.
     *
     * <p>The resource URL is converted through {@link URL#toURI()} so that encoded characters
     * (for example spaces) are decoded. A resource that is not a plain file, such as an entry
     * inside a jar, is first extracted to a temporary directory under its original file name.
     *
     * @param resourcePath classpath-relative path of the library
     * @throws IllegalStateException if the resource is missing or cannot be extracted
     */
    private static void loadBundledLibrary(String resourcePath) {
        URL resource = ClassLoader.getSystemResource(resourcePath);
        if (resource == null) {
            throw new IllegalStateException("Bundled native library not found on classpath: " + resourcePath);
        }
        try {
            Path libraryFile;
            if ("file".equals(resource.getProtocol())) {
                libraryFile = Paths.get(resource.toURI());
            } else {
                String fileName = resourcePath.substring(resourcePath.lastIndexOf('/') + 1);
                libraryFile = Files.createTempDirectory("opencv-native").resolve(fileName);
                try (InputStream in = resource.openStream()) {
                    Files.copy(in, libraryFile, StandardCopyOption.REPLACE_EXISTING);
                }
                libraryFile.toFile().deleteOnExit();
            }
            System.load(libraryFile.toAbsolutePath().toString());
        } catch (IOException | URISyntaxException e) {
            throw new IllegalStateException("Failed to load bundled native library: " + resourcePath, e);
        }
    }

    /**
     * Opens a camera, video file or stream and shows its first frame.
     *
     * <p>The string {@code "0"} selects the default camera and starts playback immediately. Any
     * other value is opened through the FFmpeg backend and left positioned at the first frame.
     *
     * @param videoFilePathOrStreamUrl {@code "0"} for the default camera, otherwise a path or URL
     * @throws IllegalArgumentException if the argument is {@code null}
     * @throws Exception if the source cannot be opened or its first frame cannot be read
     */
    public void loadVideo(String videoFilePathOrStreamUrl) throws Exception {
        if (videoFilePathOrStreamUrl == null) {
            throw new IllegalArgumentException("videoFilePathOrStreamUrl must not be null");
        }
        stopVideo();

        boolean isCamera = "0".equals(videoFilePathOrStreamUrl);
        VideoCapture capture;
        if (isCamera) {
            capture = new VideoCapture(Integer.parseInt(videoFilePathOrStreamUrl));
            if (!capture.isOpened()) {
                capture.release();
                throw new IOException("无法打开摄像头");
            }
            videoDuration = 0; // a camera has no fixed duration
        } else {
            capture = new VideoCapture(videoFilePathOrStreamUrl, Videoio.CAP_FFMPEG);
            if (!capture.isOpened()) {
                capture.release();
                throw new IOException("无法打开视频文件:" + videoFilePathOrStreamUrl);
            }
            double frameCount = capture.get(Videoio.CAP_PROP_FRAME_COUNT);
            // Live streams can report a non-positive frame count; treat that as "unknown".
            videoDuration = Math.max(0L, (long) (frameCount / frameRateOf(capture) * 1000));
        }
        videoCapture = capture;

        // Show the first frame before any worker thread touches the capture.
        if (!showNextFrame()) {
            stopVideo();
            throw new IOException("无法读取第一帧");
        }
        currentTimestamp = 0;

        if (isCamera) {
            playVideo();
        } else {
            // Rewind so that playback starts with the frame that was just previewed.
            videoCapture.set(Videoio.CAP_PROP_POS_FRAMES, 0);
        }
    }

    /**
     * Starts playback, or resumes it when it is paused.
     *
     * <p>Shows a warning dialog and returns when no source is loaded.
     */
    public void playVideo() {
        if (videoCapture == null || !videoCapture.isOpened()) {
            JOptionPane.showMessageDialog(null, "请先加载视频文件或流。", "提示", JOptionPane.WARNING_MESSAGE);
            return;
        }

        if (isPlaying) {
            isPaused = false; // resume
            return;
        }

        // A previous run that ended on its own may still be draining the queue.
        stopWorkers();

        isPlaying = true;
        isPaused = false;

        final VideoCapture capture = videoCapture;
        frameReadingThread = new Thread(() -> readFrames(capture), "video-frame-reader");
        inferenceThread = new Thread(this::runInference, "video-inference");

        frameReadingThread.start();
        inferenceThread.start();
    }

    /**
     * Reader-thread body: decodes frames at the source frame rate and queues them for inference.
     *
     * @param capture the capture to read from; passed explicitly so that a stale reader can never
     *                pick up a capture that was opened after it started
     */
    private void readFrames(VideoCapture capture) {
        Mat frame = new Mat();
        try {
            long frameDelay = (long) (1000 / frameRateOf(capture));

            while (isPlaying && !Thread.currentThread().isInterrupted()) {
                if (isPaused) {
                    Thread.sleep(10);
                    continue;
                }

                long startTime = System.currentTimeMillis();
                if (!capture.read(frame) || frame.empty()) {
                    break; // end of stream
                }

                BufferedImage bufferedImage = matToBufferedImage(frame);
                if (bufferedImage != null) {
                    float[] floats = preprocessAndConvertBufferedImage(bufferedImage);
                    frameDataQueue.put(new FrameData(bufferedImage, floats)); // blocks while the queue is full
                }

                currentTimestamp = (long) capture.get(Videoio.CAP_PROP_POS_MSEC);

                // Sleep for whatever is left of this frame's time slot.
                long sleepTime = frameDelay - (System.currentTimeMillis() - startTime);
                if (sleepTime > 0) {
                    Thread.sleep(sleepTime);
                }
            }
        } catch (InterruptedException e) {
            // Normal shutdown path: keep the interrupt flag and exit quietly.
            Thread.currentThread().interrupt();
        } catch (Exception ex) {
            ex.printStackTrace();
        } finally {
            frame.release();
            // Only the current reader may end the session; a stale one must not stop a newer run.
            if (frameReadingThread == Thread.currentThread()) {
                isPlaying = false;
            }
        }
    }

    /** Inference-thread body: takes queued frames, draws the results and updates the panel. */
    private void runInference() {
        try {
            while ((isPlaying || !frameDataQueue.isEmpty()) && !Thread.currentThread().isInterrupted()) {
                if (isPaused) {
                    Thread.sleep(100);
                    continue;
                }

                FrameData frameData = frameDataQueue.poll(100, TimeUnit.MILLISECONDS);
                if (frameData == null) {
                    continue; // nothing queued yet; re-check the loop condition
                }

                List<InferenceResult> inferenceResults = new ArrayList<>();
                for (InferenceEngine inferenceEngine : inferenceEngines) {
                    // NOTE(review): the call below is disabled in the original code and the infer
                    // signature is not visible here. Re-enable once it is confirmed.
                    // inferenceResults.add(inferenceEngine.infer(frameData.floatArray, 640, 640));
                }

                DrawImagesUtils.drawInferenceResult(frameData.image, inferenceResults);
                videoPanel.updateImage(frameData.image);
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } catch (Exception ex) {
            ex.printStackTrace();
        }
    }

    /** Pauses playback; has no effect when nothing is playing. */
    public void pauseVideo() {
        if (!isPlaying) {
            return;
        }
        isPaused = true;
    }

    /**
     * Restarts playback from the beginning of the loaded source.
     *
     * <p>Shows the same warning dialog as {@link #playVideo()} when nothing is loaded, and an
     * error dialog when the restart fails.
     */
    public void replayVideo() {
        if (videoCapture == null) {
            playVideo(); // reports that nothing is loaded
            return;
        }
        try {
            // Stop the workers only: the capture has to stay open to be rewound.
            stopWorkers();
            videoCapture.set(Videoio.CAP_PROP_POS_FRAMES, 0);
            currentTimestamp = 0;

            if (showNextFrame()) {
                videoCapture.set(Videoio.CAP_PROP_POS_FRAMES, 0);
            }

            playVideo();
        } catch (Exception e) {
            e.printStackTrace();
            JOptionPane.showMessageDialog(null, "重播失败: " + e.getMessage(), "错误", JOptionPane.ERROR_MESSAGE);
        }
    }

    /** Stops playback and releases the capture; a source must be loaded again before playing. */
    public void stopVideo() {
        stopWorkers();

        if (videoCapture != null) {
            videoCapture.release();
            videoCapture = null;
        }
    }

    /**
     * Jumps to the given position and resumes playback from there.
     *
     * @param seekTime target position in milliseconds; negative values are treated as 0
     */
    public void seekTo(long seekTime) {
        if (videoCapture == null) {
            return;
        }
        try {
            // Stop the workers only: the capture has to stay open to be repositioned.
            stopWorkers();

            long target = Math.max(0L, seekTime);
            videoCapture.set(Videoio.CAP_PROP_POS_MSEC, target);
            currentTimestamp = target;

            showNextFrame();
            playVideo();
        } catch (Exception ex) {
            ex.printStackTrace();
        }
    }

    /**
     * Skips forward.
     *
     * @param milliseconds distance to skip; the target is capped at the duration when it is known
     */
    public void fastForward(long milliseconds) {
        long newTime = currentTimestamp + milliseconds;
        if (videoDuration > 0) {
            newTime = Math.min(newTime, videoDuration);
        }
        seekTo(newTime);
    }

    /**
     * Skips backward.
     *
     * @param milliseconds distance to skip; the target is floored at 0
     */
    public void rewind(long milliseconds) {
        long newTime = Math.max(currentTimestamp - milliseconds, 0);
        seekTo(newTime);
    }

    /**
     * Registers an engine for the inference thread.
     *
     * @param inferenceEngine engine to add
     */
    public void addInferenceEngines(InferenceEngine inferenceEngine) {
        this.inferenceEngines.add(inferenceEngine);
    }

    /** Stops both worker threads and empties the queue while keeping the capture open. */
    private void stopWorkers() {
        isPlaying = false;
        isPaused = false;

        interruptAndJoin(frameReadingThread);
        interruptAndJoin(inferenceThread);
        frameReadingThread = null;
        inferenceThread = null;

        frameDataQueue.clear();
    }

    /** Interrupts a live worker and waits, for a bounded time, until it has finished. */
    private static void interruptAndJoin(Thread worker) {
        if (worker == null || worker == Thread.currentThread() || !worker.isAlive()) {
            return;
        }
        worker.interrupt();
        try {
            worker.join(WORKER_JOIN_TIMEOUT_MS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Reads one frame from the current capture and shows it on the panel.
     *
     * @return {@code true} if a frame was read and displayed
     * @throws Exception whatever the frame conversion reports (its signature is not visible here)
     */
    private boolean showNextFrame() throws Exception {
        Mat frame = new Mat();
        try {
            if (!videoCapture.read(frame) || frame.empty()) {
                return false;
            }
            videoPanel.updateImage(matToBufferedImage(frame));
            return true;
        } finally {
            frame.release();
        }
    }

    /** Returns the frame rate reported by the capture, or the default when it is unusable. */
    private static double frameRateOf(VideoCapture capture) {
        double fps = capture.get(Videoio.CAP_PROP_FPS);
        return (fps <= 0 || Double.isNaN(fps)) ? DEFAULT_FPS : fps;
    }

    /** Scale factor that fits an image of the given size into the model input, keeping its aspect ratio. */
    private static float letterboxScale(int width, int height) {
        return Math.min((float) INPUT_WIDTH / width, (float) INPUT_HEIGHT / height);
    }

    /** Scales a side length, never returning less than one pixel. */
    private static int scaledLength(int length, float scalingFactor) {
        return Math.max(1, Math.round(length * scalingFactor));
    }

    /** A displayed frame together with its pre-processed model input. */
    private static class FrameData {
        public final BufferedImage image;
        public final float[] floatArray;

        public FrameData(BufferedImage image, float[] floatArray) {
            this.image = image;
            this.floatArray = floatArray;
        }
    }

    /**
     * Letterboxes an image to the model input size and returns it as a flat {@code float[]}.
     *
     * <p>The image is resized with its aspect ratio preserved, converted to RGB and centred on a
     * black canvas. The result is interleaved (height, width, channel) with values in 0..255.
     * NOTE(review): unlike {@link #preprocessImage(Mat)}, this method neither divides by 255 nor
     * reorders to CHW; confirm which layout the inference engines expect.
     *
     * @param image source image
     * @return array of length {@code 3 * 640 * 640}
     */
    public static float[] preprocessAndConvertBufferedImage(BufferedImage image) {
        Mat matImage = bufferedImageToMat(image);
        Mat resizedImage = new Mat();
        Mat paddedImage = Mat.zeros(new Size(INPUT_WIDTH, INPUT_HEIGHT), CvType.CV_32FC3);
        Mat target = null;
        try {
            int origWidth = matImage.width();
            int origHeight = matImage.height();

            float scalingFactor = letterboxScale(origWidth, origHeight);
            int newWidth = scaledLength(origWidth, scalingFactor);
            int newHeight = scaledLength(origHeight, scalingFactor);

            Imgproc.resize(matImage, resizedImage, new Size(newWidth, newHeight));
            Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB);

            // copyTo() re-allocates a destination whose type differs from the source, which would
            // detach it from the canvas and leave the canvas black. Convert to float first.
            resizedImage.convertTo(resizedImage, CvType.CV_32FC3);

            int xOffset = (INPUT_WIDTH - newWidth) / 2;
            int yOffset = (INPUT_HEIGHT - newHeight) / 2;
            target = paddedImage.submat(new Rect(xOffset, yOffset, newWidth, newHeight));
            resizedImage.copyTo(target);

            float[] inputData = new float[3 * INPUT_WIDTH * INPUT_HEIGHT];
            paddedImage.get(0, 0, inputData);
            return inputData;
        } finally {
            if (target != null) {
                target.release();
            }
            matImage.release();
            resizedImage.release();
            paddedImage.release();
        }
    }

    /**
     * Copies a {@link BufferedImage} into a new 8-bit, 3-channel BGR {@link Mat}.
     *
     * <p>Images that are not already {@code TYPE_3BYTE_BGR} are first redrawn into that format, so
     * int-packed, grayscale and alpha images are accepted as well.
     *
     * @param bi source image
     * @return a new Mat that the caller must release
     */
    public static Mat bufferedImageToMat(BufferedImage bi) {
        int width = bi.getWidth();
        int height = bi.getHeight();

        BufferedImage bgrImage = bi;
        if (bi.getType() != BufferedImage.TYPE_3BYTE_BGR) {
            bgrImage = new BufferedImage(width, height, BufferedImage.TYPE_3BYTE_BGR);
            Graphics2D graphics = bgrImage.createGraphics();
            try {
                graphics.drawImage(bi, 0, 0, null);
            } finally {
                graphics.dispose();
            }
        }

        byte[] data = ((DataBufferByte) bgrImage.getRaster().getDataBuffer()).getData();
        Mat mat = new Mat(height, width, CvType.CV_8UC3);
        mat.put(0, 0, data);
        return mat;
    }

    /**
     * Alternative pre-processing: letterboxes a BGR image, scales it to 0..1 and reorders it to CHW.
     *
     * @param image source image; it is not modified or released
     * @return map with keys {@code inputData} (float[] in CHW order), {@code origWidth},
     *         {@code origHeight}, {@code scalingFactor}, {@code xOffset} and {@code yOffset}
     * @throws IllegalArgumentException if the image is empty
     */
    public Map<String, Object> preprocessImage(Mat image) {
        int origWidth = image.width();
        int origHeight = image.height();
        if (origWidth <= 0 || origHeight <= 0) {
            throw new IllegalArgumentException("image must not be empty");
        }

        float scalingFactor = letterboxScale(origWidth, origHeight);
        int newWidth = scaledLength(origWidth, scalingFactor);
        int newHeight = scaledLength(origHeight, scalingFactor);
        int xOffset = (INPUT_WIDTH - newWidth) / 2;
        int yOffset = (INPUT_HEIGHT - newHeight) / 2;

        int imageSize = INPUT_WIDTH * INPUT_HEIGHT;
        float[] hwcData = new float[3 * imageSize];

        Mat resizedImage = new Mat();
        Mat paddedImage = Mat.zeros(new Size(INPUT_WIDTH, INPUT_HEIGHT), CvType.CV_32FC3);
        Mat target = null;
        try {
            Imgproc.resize(image, resizedImage, new Size(newWidth, newHeight));

            // Convert to RGB and normalise to 0..1.
            Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB);
            resizedImage.convertTo(resizedImage, CvType.CV_32FC3, 1.0 / 255.0);

            target = paddedImage.submat(new Rect(xOffset, yOffset, newWidth, newHeight));
            resizedImage.copyTo(target);
            paddedImage.get(0, 0, hwcData);
        } finally {
            if (target != null) {
                target.release();
            }
            resizedImage.release();
            paddedImage.release();
        }

        // Reorder interleaved HWC pixels into planar CHW.
        float[] chwData = new float[3 * imageSize];
        for (int c = 0; c < 3; c++) {
            for (int i = 0; i < imageSize; i++) {
                chwData[c * imageSize + i] = hwcData[i * 3 + c];
            }
        }

        Map<String, Object> result = new HashMap<>();
        result.put("inputData", chwData);
        result.put("origWidth", origWidth);
        result.put("origHeight", origHeight);
        result.put("scalingFactor", scalingFactor);
        result.put("xOffset", xOffset);
        result.put("yOffset", yOffset);

        return result;
    }
}
(String deviceDescription : deviceDescriptions) { - System.out.println("摄像头索引 " + ": " + deviceDescription); - } - } -} diff --git a/src/main/java/com/ly/utils/OpenCVTest.java b/src/main/java/com/ly/utils/OpenCVTest.java new file mode 100644 index 0000000..b25f50a --- /dev/null +++ b/src/main/java/com/ly/utils/OpenCVTest.java @@ -0,0 +1,28 @@ +package com.ly.utils; + +import org.opencv.core.Core; +import org.opencv.core.Mat; +import org.opencv.videoio.VideoCapture; + +public class OpenCVTest { +// static { +// nu.pattern.OpenCV.loadLocally(); +// } + + public static void main(String[] args) { + VideoCapture capture = new VideoCapture(0); // 打开默认摄像头 + if (!capture.isOpened()) { + System.out.println("无法打开摄像头"); + return; + } + + Mat frame = new Mat(); + if (capture.read(frame)) { + System.out.println("成功读取一帧图像"); + } else { + System.out.println("无法读取图像"); + } + + capture.release(); + } +} diff --git a/src/main/java/com/ly/utils/RTSPStreamer.java b/src/main/java/com/ly/utils/RTSPStreamer.java deleted file mode 100644 index 2cb6845..0000000 --- a/src/main/java/com/ly/utils/RTSPStreamer.java +++ /dev/null @@ -1,57 +0,0 @@ -package com.ly.utils; - -import org.bytedeco.ffmpeg.global.avcodec; -import org.bytedeco.javacv.*; - -public class RTSPStreamer { - - public static void main(String[] args) { - String inputFile = "C:\\Users\\ly\\Desktop\\屏幕录制 2024-09-20 225443.mp4"; // 替换为您的本地视频文件路径 - String rtspUrl = "rtsp://localhost:8554/live"; // 替换为您的 RTSP 服务器地址 - - FFmpegFrameGrabber grabber = null; - FFmpegFrameRecorder recorder = null; - - try { - // 初始化 FFmpegFrameGrabber 以从本地视频文件读取 - grabber = new FFmpegFrameGrabber(inputFile); - grabber.start(); - - // 初始化 FFmpegFrameRecorder 以推流到 RTSP 服务器 - recorder = new FFmpegFrameRecorder(rtspUrl, grabber.getImageWidth(), grabber.getImageHeight(), grabber.getAudioChannels()); - recorder.setFormat("rtsp"); - recorder.setFrameRate(grabber.getFrameRate()); - recorder.setVideoBitrate(grabber.getVideoBitrate()); - 
recorder.setVideoCodec(avcodec.AV_CODEC_ID_H264); // 设置视频编码格式 - recorder.setAudioCodec(avcodec.AV_CODEC_ID_AAC); // 设置音频编码格式 - - // 设置 RTSP 传输选项(如果需要) - recorder.setOption("rtsp_transport", "tcp"); - - recorder.start(); - - Frame frame; - while ((frame = grabber.grab()) != null) { - recorder.record(frame); - } - - System.out.println("推流完成。"); - - } catch (Exception e) { - e.printStackTrace(); - } finally { - try { - if (recorder != null) { - recorder.stop(); - recorder.release(); - } - if (grabber != null) { - grabber.stop(); - grabber.release(); - } - } catch (Exception e) { - e.printStackTrace(); - } - } - } -} diff --git a/src/main/resources/lib/win/opencv_videoio_ffmpeg470_64.dll b/src/main/resources/lib/win/opencv_videoio_ffmpeg470_64.dll new file mode 100644 index 0000000..798d5cd Binary files /dev/null and b/src/main/resources/lib/win/opencv_videoio_ffmpeg470_64.dll differ diff --git a/src/main/resources/model/best.onnx b/src/main/resources/model/best.onnx new file mode 100644 index 0000000..b9d7489 Binary files /dev/null and b/src/main/resources/model/best.onnx differ