diff --git a/src/main/java/com/ly/VideoInferenceApp.java b/src/main/java/com/ly/VideoInferenceApp.java index b21b378..3ff8d95 100644 --- a/src/main/java/com/ly/VideoInferenceApp.java +++ b/src/main/java/com/ly/VideoInferenceApp.java @@ -7,14 +7,15 @@ import com.ly.onnx.engine.InferenceEngine; import com.ly.onnx.model.ModelInfo; import com.ly.play.opencv.VideoPlayer; - import javax.swing.*; import javax.swing.filechooser.FileNameExtensionFilter; import javax.swing.filechooser.FileSystemView; import java.awt.*; +import java.awt.datatransfer.DataFlavor; import java.io.File; import java.util.ArrayList; import java.util.Collections; +import java.util.List; public class VideoInferenceApp extends JFrame { @@ -26,10 +27,9 @@ public class VideoInferenceApp extends JFrame { private ModelManager modelManager; - public VideoInferenceApp() { // 设置窗口标题 - super("https://gitee.com/sulv0302/onnx-inference4j-play.git"); + super("ONNX Inference Application"); // 初始化UI组件 initializeUI(); } @@ -49,13 +49,53 @@ public class VideoInferenceApp extends JFrame { videoPanel = new VideoPanel(); videoPanel.setBackground(Color.BLACK); - // 模型列表区域 + // 设置拖拽功能 + videoPanel.setTransferHandler(new TransferHandler() { + @Override + public boolean canImport(TransferSupport support) { + return support.isDataFlavorSupported(DataFlavor.javaFileListFlavor); + } + + @Override + public boolean importData(TransferSupport support) { + if (!canImport(support)) { + return false; + } + try { + // 获取拖拽的文件列表 + List files = (List) support.getTransferable().getTransferData(DataFlavor.javaFileListFlavor); + for (File file : files) { + String fileName = file.getName().toLowerCase(); + if (fileName.endsWith(".jpg") || fileName.endsWith(".jpeg") || + fileName.endsWith(".png") || fileName.endsWith(".bmp") || + fileName.endsWith(".gif")) { + // 加载并处理拖拽的图片文件 + videoPlayer.loadImage(file.getAbsolutePath()); + } else if (fileName.endsWith(".mp4") || fileName.endsWith(".avi") || + fileName.endsWith(".mkv") || fileName.endsWith(".mov") || + fileName.endsWith(".flv") || fileName.endsWith(".wmv")) { + // 加载并播放拖拽的视频文件 + videoPlayer.loadVideo(file.getAbsolutePath()); + } + } + } catch (Exception ex) { + ex.printStackTrace(); + return false; + } + return true; + } + }); + + // 初始化 ModelManager(不传递 videoPlayer) modelManager = new ModelManager(); modelManager.setPreferredSize(new Dimension(250, 0)); // 设置模型列表区域的宽度 - // 初始化 VideoPlayer + // 初始化 VideoPlayer 并传递 modelManager videoPlayer = new VideoPlayer(videoPanel, modelManager); + // 将 videoPlayer 设置到 modelManager 中 + modelManager.setVideoPlayer(videoPlayer); + // 使用 JSplitPane 分割视频区域和模型列表区域 JSplitPane splitPane = new JSplitPane(JSplitPane.HORIZONTAL_SPLIT, videoPanel, modelManager); splitPane.setResizeWeight(0.8); // 视频区域初始占据80%的空间 @@ -97,6 +137,10 @@ public class VideoInferenceApp extends JFrame { JButton loadVideoButton = new JButton("选择视频文件"); loadVideoButton.setPreferredSize(new Dimension(150, 30)); + // 图片文件选择按钮 + JButton loadImageButton = new JButton("选择图片文件"); + loadImageButton.setPreferredSize(new Dimension(150, 30)); + // 模型文件选择按钮 JButton loadModelButton = new JButton("选择模型"); loadModelButton.setPreferredSize(new Dimension(150, 30)); @@ -108,12 +152,19 @@ public class VideoInferenceApp extends JFrame { JButton startPlayButton = new JButton("开始播放"); startPlayButton.setPreferredSize(new Dimension(100, 30)); + // 添加目标跟踪复选框 + JCheckBox trackingCheckBox = new JCheckBox("启用目标跟踪"); + trackingCheckBox.setSelected(false); // 默认不启用目标跟踪 + // 将按钮和输入框添加到顶部面板 topPanel.add(loadVideoButton); + topPanel.add(loadImageButton); // 添加图片按钮 topPanel.add(loadModelButton); topPanel.add(new JLabel("流地址:")); topPanel.add(streamUrlField); topPanel.add(startPlayButton); + // 将复选框添加到顶部面板 + topPanel.add(trackingCheckBox); this.add(topPanel, BorderLayout.NORTH); @@ -126,6 +177,9 @@ public class VideoInferenceApp extends JFrame { // 添加视频加载按钮的行为 loadVideoButton.addActionListener(e -> selectVideoFile()); + // 添加图片加载按钮的行为 + loadImageButton.addActionListener(e -> selectImageFile()); + loadModelButton.addActionListener(e -> { modelManager.loadModel(this); DefaultListModel modelList = modelManager.getModelList(); @@ -141,16 +195,31 @@ public class VideoInferenceApp extends JFrame { } }); + // 为复选框添加监听器,动态启用或禁用目标跟踪 + trackingCheckBox.addActionListener(e -> { + boolean isSelected = trackingCheckBox.isSelected(); // 获取当前复选框状态 + videoPlayer.setTrackingEnabled(isSelected); // 设置是否启用目标跟踪 + }); // 播放按钮 - playButton.addActionListener(e -> videoPlayer.playVideo()); + playButton.addActionListener(e -> { + videoPlayer.playVideo(); + }); // 暂停按钮 pauseButton.addActionListener(e -> videoPlayer.pauseVideo()); -// // 重播按钮 -// replayButton.addActionListener(e -> videoPlayer.replayVideo()); -// + // 重播按钮 + replayButton.addActionListener(e -> { + try { +// videoPlayer.loadVideo(videoPlayer.getCurrentVideoPath()); + videoPlayer.playVideo(); + } catch (Exception ex) { + ex.printStackTrace(); + JOptionPane.showMessageDialog(this, "重播视频失败: " + ex.getMessage(), "错误", JOptionPane.ERROR_MESSAGE); + } + }); + // // 后退5秒 // rewind5sButton.addActionListener(e -> videoPlayer.rewind(5000)); // @@ -195,6 +264,28 @@ public class VideoInferenceApp extends JFrame { } } + // 选择图片文件 + private void selectImageFile() { + File desktopDir = FileSystemView.getFileSystemView().getHomeDirectory(); + JFileChooser fileChooser = new JFileChooser(desktopDir); + fileChooser.setDialogTitle("选择图片文件"); + // 设置图片文件过滤器,支持常见的图片格式 + FileNameExtensionFilter imageFilter = new FileNameExtensionFilter( + "图片文件 (*.jpg;*.jpeg;*.png;*.bmp;*.gif)", "jpg", "jpeg", "png", "bmp", "gif"); + fileChooser.setFileFilter(imageFilter); + + int returnValue = fileChooser.showOpenDialog(this); + if (returnValue == JFileChooser.APPROVE_OPTION) { + File selectedFile = fileChooser.getSelectedFile(); + try { + videoPlayer.loadImage(selectedFile.getAbsolutePath()); + } catch (Exception ex) { + ex.printStackTrace(); + JOptionPane.showMessageDialog(this, "加载图片失败: " + ex.getMessage(), "错误", JOptionPane.ERROR_MESSAGE); + } + } + } + public static void main(String[] args) { SwingUtilities.invokeLater(VideoInferenceApp::new); } diff --git a/src/main/java/com/ly/lishi/InferenceEngine.java b/src/main/java/com/ly/lishi/InferenceEngine.java deleted file mode 100644 index 63ffa66..0000000 --- a/src/main/java/com/ly/lishi/InferenceEngine.java +++ /dev/null @@ -1,410 +0,0 @@ -package com.ly.lishi; - -import ai.onnxruntime.*; -import com.alibaba.fastjson.JSON; -import com.ly.onnx.model.BoundingBox; -import com.ly.onnx.model.InferenceResult; -import lombok.Data; -import org.opencv.core.*; -import org.opencv.imgcodecs.Imgcodecs; -import org.opencv.imgproc.Imgproc; - -import java.nio.FloatBuffer; -import java.util.*; - -@Data -public class InferenceEngine { - - private OrtEnvironment environment; - private OrtSession.SessionOptions sessionOptions; - private OrtSession session; - - private String modelPath; - private List labels; - - // 用于存储图像预处理信息的类变量 - private long[] inputShape = null; - - static { - nu.pattern.OpenCV.loadLocally(); - } - - public InferenceEngine(String modelPath, List labels) { - this.modelPath = modelPath; - this.labels = labels; - init(); - } - - public void init() { - try { - environment = OrtEnvironment.getEnvironment(); - sessionOptions = new OrtSession.SessionOptions(); - sessionOptions.addCUDA(0); // 使用 GPU - session = environment.createSession(modelPath, sessionOptions); - Map inputInfo = session.getInputInfo(); - NodeInfo nodeInfo = inputInfo.values().iterator().next(); - TensorInfo tensorInfo = (TensorInfo) nodeInfo.getInfo(); - inputShape = tensorInfo.getShape(); // 从模型中获取输入形状 - logModelInfo(session); - } catch (OrtException e) { - throw new RuntimeException("模型加载失败", e); - } - } - - public InferenceResult infer(int w, int h, Map preprocessParams) { - long startTime = System.currentTimeMillis(); - - // 从 Map 中获取偏移相关的变量 - float[] inputData = (float[]) preprocessParams.get("inputData"); - int origWidth = (int) preprocessParams.get("origWidth"); - int origHeight = (int) preprocessParams.get("origHeight"); - float scalingFactor = (float) preprocessParams.get("scalingFactor"); - int xOffset = (int) preprocessParams.get("xOffset"); - int yOffset = (int) preprocessParams.get("yOffset"); - - try { - Map inputInfo = session.getInputInfo(); - String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入 - - long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状 - - // 创建输入张量时,使用 CHW 格式的数据 - OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape); - - // 执行推理 - long inferenceStart = System.currentTimeMillis(); - OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor)); - long inferenceEnd = System.currentTimeMillis(); - System.out.println("模型推理耗时:" + (inferenceEnd - inferenceStart) + " ms"); - - // 解析推理结果 - String outputName = session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出 - float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状:[1, N, 6] - - // 设定置信度阈值 - float confidenceThreshold = 0.25f; // 您可以根据需要调整 - - // 根据模型的输出结果解析边界框 - List boxes = new ArrayList<>(); - for (float[] data : outputData[0]) { // 遍历所有检测框 - // 根据模型输出格式,解析中心坐标和宽高 - float x_center = data[0]; - float y_center = data[1]; - float width = data[2]; - float height = data[3]; - float confidence = data[4]; - - if (confidence >= confidenceThreshold) { - // 将中心坐标转换为左上角和右下角坐标 - float x1 = x_center - width / 2; - float y1 = y_center - height / 2; - float x2 = x_center + width / 2; - float y2 = y_center + height / 2; - - // 调整坐标,减去偏移并除以缩放因子 - float x1Adjusted = (x1 - xOffset) / scalingFactor; - float y1Adjusted = (y1 - yOffset) / scalingFactor; - float x2Adjusted = (x2 - xOffset) / scalingFactor; - float y2Adjusted = (y2 - yOffset) / scalingFactor; - - // 确保坐标的正确顺序 - float xMinAdjusted = Math.min(x1Adjusted, x2Adjusted); - float xMaxAdjusted = Math.max(x1Adjusted, x2Adjusted); - float yMinAdjusted = Math.min(y1Adjusted, y2Adjusted); - float yMaxAdjusted = Math.max(y1Adjusted, y2Adjusted); - - // 确保坐标在原始图像范围内 - int x = (int) Math.max(0, xMinAdjusted); - int y = (int) Math.max(0, yMinAdjusted); - int xMax = (int) Math.min(origWidth, xMaxAdjusted); - int yMax = (int) Math.min(origHeight, yMaxAdjusted); - int wBox = xMax - x; - int hBox = yMax - y; - - // 仅当宽度和高度为正时,才添加边界框 - if (wBox > 0 && hBox > 0) { - // 使用您的单一标签 - String label = labels.get(0); - - boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence)); - } - } - } - - // 非极大值抑制(NMS) - long nmsStart = System.currentTimeMillis(); - List nmsBoxes = nonMaximumSuppression(boxes, 0.5f); - - System.out.println("检测到的标签:" + JSON.toJSONString(nmsBoxes)); - if (!nmsBoxes.isEmpty()) { - for (BoundingBox box : nmsBoxes) { - System.out.println(box); - } - } - long nmsEnd = System.currentTimeMillis(); - System.out.println("NMS 耗时:" + (nmsEnd - nmsStart) + " ms"); - - // 封装结果并返回 - InferenceResult inferenceResult = new InferenceResult(); - inferenceResult.setBoundingBoxes(nmsBoxes); - - long endTime = System.currentTimeMillis(); - System.out.println("一次推理总耗时:" + (endTime - startTime) + " ms"); - - return inferenceResult; - - } catch (OrtException e) { - throw new RuntimeException("推理失败", e); - } - } - - - // 计算两个边界框的 IoU - private float computeIoU(BoundingBox box1, BoundingBox box2) { - int x1 = Math.max(box1.getX(), box2.getX()); - int y1 = Math.max(box1.getY(), box2.getY()); - int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth()); - int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight()); - - int intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1); - int box1Area = box1.getWidth() * box1.getHeight(); - int box2Area = box2.getWidth() * box2.getHeight(); - - return (float) intersectionArea / (box1Area + box2Area - intersectionArea); - } - - // 非极大值抑制(NMS)方法 - private List nonMaximumSuppression(List boxes, float iouThreshold) { - // 按置信度排序(从高到低) - boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence())); - - List result = new ArrayList<>(); - - while (!boxes.isEmpty()) { - BoundingBox bestBox = boxes.remove(0); - result.add(bestBox); - - Iterator iterator = boxes.iterator(); - while (iterator.hasNext()) { - BoundingBox box = iterator.next(); - if (computeIoU(bestBox, box) > iouThreshold) { - iterator.remove(); - } - } - } - - return result; - } - - // 打印模型信息 - private void logModelInfo(OrtSession session) { - System.out.println("模型输入信息:"); - try { - for (Map.Entry entry : session.getInputInfo().entrySet()) { - String name = entry.getKey(); - NodeInfo info = entry.getValue(); - System.out.println("输入名称: " + name); - System.out.println("输入信息: " + info.toString()); - } - } catch (OrtException e) { - throw new RuntimeException(e); - } - - System.out.println("模型输出信息:"); - try { - for (Map.Entry entry : session.getOutputInfo().entrySet()) { - String name = entry.getKey(); - NodeInfo info = entry.getValue(); - System.out.println("输出名称: " + name); - System.out.println("输出信息: " + info.toString()); - } - } catch (OrtException e) { - throw new RuntimeException(e); - } - } - - public static void main(String[] args) { - // 加载 OpenCV 库 - - // 初始化标签列表(只有一个标签) - List labels = Arrays.asList("person"); - - // 创建 InferenceEngine 实例 - InferenceEngine inferenceEngine = new InferenceEngine("C:\\Users\\ly\\Desktop\\person.onnx", labels); - - for (int j = 0; j < 10; j++) { - try { - // 加载图片 - Mat inputImage = Imgcodecs.imread("C:\\Users\\ly\\Desktop\\10230731212230.png"); - - // 预处理图像 - long l1 = System.currentTimeMillis(); - Map preprocessResult = inferenceEngine.preprocessImage(inputImage); - float[] inputData = (float[]) preprocessResult.get("inputData"); - - InferenceResult result = null; - for (int i = 0; i < 10; i++) { - long l = System.currentTimeMillis(); - result = inferenceEngine.infer( 640, 640, preprocessResult); - System.out.println("第 " + (i + 1) + " 次推理耗时:" + (System.currentTimeMillis() - l) + " ms"); - } - - - // 处理并显示结果 - System.out.println("推理结果:"); - for (BoundingBox box : result.getBoundingBoxes()) { - System.out.println(box); - } - - // 可视化并保存带有边界框的图像 - Mat outputImage = inferenceEngine.drawBoundingBoxes(inputImage, result.getBoundingBoxes()); - - // 保存图片到本地文件 - String outputFilePath = "output_image_with_boxes.jpg"; - Imgcodecs.imwrite(outputFilePath, outputImage); - - System.out.println("已保存结果图片: " + outputFilePath); - - } catch (Exception e) { - e.printStackTrace(); - } - } - } - - // 在图像上绘制边界框和标签 - private Mat drawBoundingBoxes(Mat image, List boxes) { - for (BoundingBox box : boxes) { - // 绘制矩形边界框 - Imgproc.rectangle(image, new Point(box.getX(), box.getY()), - new Point(box.getX() + box.getWidth(), box.getY() + box.getHeight()), - new Scalar(0, 0, 255), 2); // 红色边框 - - // 绘制标签文字和置信度 - String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence()); - int baseLine[] = new int[1]; - Size labelSize = Imgproc.getTextSize(label, Imgproc.FONT_HERSHEY_SIMPLEX, 0.5, 1, baseLine); - int top = Math.max(box.getY(), (int) labelSize.height); - Imgproc.putText(image, label, new Point(box.getX(), top), - Imgproc.FONT_HERSHEY_SIMPLEX, 0.5, new Scalar(255, 255, 255), 1); - } - - return image; - } - - - public Map preprocessImage(Mat image) { - int targetWidth = 640; - int targetHeight = 640; - - int origWidth = image.width(); - int origHeight = image.height(); - - // 计算缩放因子 - float scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight); - - // 计算新的图像尺寸 - int newWidth = Math.round(origWidth * scalingFactor); - int newHeight = Math.round(origHeight * scalingFactor); - - // 计算偏移量以居中图像 - int xOffset = (targetWidth - newWidth) / 2; - int yOffset = (targetHeight - newHeight) / 2; - - // 调整图像尺寸 - Mat resizedImage = new Mat(); - Imgproc.resize(image, resizedImage, new Size(newWidth, newHeight), 0, 0, Imgproc.INTER_LINEAR); - - // 转换为 RGB 并归一化 - Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB); - resizedImage.convertTo(resizedImage, CvType.CV_32FC3, 1.0 / 255.0); - - // 创建填充后的图像 - Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3); - Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight); - resizedImage.copyTo(paddedImage.submat(roi)); - - // 将图像数据转换为数组 - int imageSize = targetWidth * targetHeight; - float[] chwData = new float[3 * imageSize]; - float[] hwcData = new float[3 * imageSize]; - paddedImage.get(0, 0, hwcData); - - // 转换为 CHW 格式 - int channelSize = imageSize; - for (int c = 0; c < 3; c++) { - for (int i = 0; i < imageSize; i++) { - chwData[c * channelSize + i] = hwcData[i * 3 + c]; - } - } - - // 释放图像资源 - resizedImage.release(); - paddedImage.release(); - - // 将预处理结果和偏移信息存入 Map - Map result = new HashMap<>(); - result.put("inputData", chwData); - result.put("origWidth", origWidth); - result.put("origHeight", origHeight); - result.put("scalingFactor", scalingFactor); - result.put("xOffset", xOffset); - result.put("yOffset", yOffset); - - return result; - } - - - // 图像预处理 -// public float[] preprocessImage(Mat image) { -// int targetWidth = 640; -// int targetHeight = 640; -// -// origWidth = image.width(); -// origHeight = image.height(); -// -// // 计算缩放因子 -// scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight); -// -// // 计算新的图像尺寸 -// newWidth = Math.round(origWidth * scalingFactor); -// newHeight = Math.round(origHeight * scalingFactor); -// -// // 计算偏移量以居中图像 -// xOffset = (targetWidth - newWidth) / 2; -// yOffset = (targetHeight - newHeight) / 2; -// -// // 调整图像尺寸 -// Mat resizedImage = new Mat(); -// Imgproc.resize(image, resizedImage, new Size(newWidth, newHeight), 0, 0, Imgproc.INTER_LINEAR); -// -// // 转换为 RGB 并归一化 -// Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB); -// resizedImage.convertTo(resizedImage, CvType.CV_32FC3, 1.0 / 255.0); -// -// // 创建填充后的图像 -// Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3); -// Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight); -// resizedImage.copyTo(paddedImage.submat(roi)); -// -// // 将图像数据转换为数组 -// int imageSize = targetWidth * targetHeight; -// float[] chwData = new float[3 * imageSize]; -// float[] hwcData = new float[3 * imageSize]; -// paddedImage.get(0, 0, hwcData); -// -// // 转换为 CHW 格式 -// int channelSize = imageSize; -// for (int c = 0; c < 3; c++) { -// for (int i = 0; i < imageSize; i++) { -// chwData[c * channelSize + i] = hwcData[i * 3 + c]; -// } -// } -// -// // 释放图像资源 -// resizedImage.release(); -// paddedImage.release(); -// -// return chwData; -// } - -} diff --git a/src/main/java/com/ly/lishi/VideoPlayer.java b/src/main/java/com/ly/lishi/VideoPlayer.java deleted file mode 100644 index b263eb5..0000000 --- a/src/main/java/com/ly/lishi/VideoPlayer.java +++ /dev/null @@ -1,378 +0,0 @@ -package com.ly.lishi; - -import com.ly.layout.VideoPanel; -import com.ly.model_load.ModelManager; -import com.ly.onnx.engine.InferenceEngine; -import com.ly.onnx.model.InferenceResult; -import com.ly.onnx.utils.DrawImagesUtils; -import org.opencv.core.CvType; -import org.opencv.core.Mat; -import org.opencv.core.Rect; -import org.opencv.core.Size; -import org.opencv.imgproc.Imgproc; -import org.opencv.videoio.VideoCapture; -import org.opencv.videoio.Videoio; - -import javax.swing.*; -import java.awt.image.BufferedImage; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.TimeUnit; - -import static com.ly.onnx.utils.ImageUtils.matToBufferedImage; - -public class VideoPlayer { - static { - // 加载 OpenCV 库 - nu.pattern.OpenCV.loadLocally(); - String OS = System.getProperty("os.name").toLowerCase(); - if (OS.contains("win")) { - System.load(ClassLoader.getSystemResource("lib/win/opencv_videoio_ffmpeg470_64.dll").getPath()); - } - } - - private VideoCapture videoCapture; - private volatile boolean isPlaying = false; - private volatile boolean isPaused = false; - private Thread frameReadingThread; - private Thread inferenceThread; - private VideoPanel videoPanel; - - private long videoDuration = 0; // 毫秒 - private long currentTimestamp = 0; // 毫秒 - - - - private ModelManager modelManager; - private List inferenceEngines = new ArrayList<>(); - - // 定义阻塞队列来缓冲转换后的数据 - private BlockingQueue frameDataQueue = new LinkedBlockingQueue<>(10); // 队列容量可根据需要调整 - - public VideoPlayer(VideoPanel videoPanel, ModelManager modelManager) { - this.videoPanel = videoPanel; - this.modelManager = modelManager; - } - - // 加载视频或流 - public void loadVideo(String videoFilePathOrStreamUrl) throws Exception { - stopVideo(); - if (videoFilePathOrStreamUrl.equals("0")) { - int cameraIndex = Integer.parseInt(videoFilePathOrStreamUrl); - videoCapture = new VideoCapture(cameraIndex); - if (!videoCapture.isOpened()) { - throw new Exception("无法打开摄像头"); - } - videoDuration = 0; // 摄像头没有固定的时长 - playVideo(); - } else { - // 输入不是数字,尝试打开视频文件 - videoCapture = new VideoCapture(videoFilePathOrStreamUrl, Videoio.CAP_FFMPEG); - if (!videoCapture.isOpened()) { - throw new Exception("无法打开视频文件:" + videoFilePathOrStreamUrl); - } - double frameCount = videoCapture.get(Videoio.CAP_PROP_FRAME_COUNT); - double fps = videoCapture.get(Videoio.CAP_PROP_FPS); - if (fps <= 0 || Double.isNaN(fps)) { - fps = 25; // 默认帧率 - } - videoDuration = (long) (frameCount / fps * 1000); // 转换为毫秒 - } - - // 显示第一帧 - Mat frame = new Mat(); - if (videoCapture.read(frame)) { - BufferedImage bufferedImage = matToBufferedImage(frame); - videoPanel.updateImage(bufferedImage); - currentTimestamp = 0; - } else { - throw new Exception("无法读取第一帧"); - } - - // 重置到视频开始位置 - videoCapture.set(Videoio.CAP_PROP_POS_FRAMES, 0); - currentTimestamp = 0; - } - - // 播放 - public void playVideo() { - if (videoCapture == null || !videoCapture.isOpened()) { - JOptionPane.showMessageDialog(null, "请先加载视频文件或流。", "提示", JOptionPane.WARNING_MESSAGE); - return; - } - - if (isPlaying) { - if (isPaused) { - isPaused = false; // 恢复播放 - } - return; - } - - isPlaying = true; - isPaused = false; - - frameDataQueue.clear(); // 开始播放前清空队列 - - // 创建并启动帧读取和转换线程 - frameReadingThread = new Thread(() -> { - try { - double fps = videoCapture.get(Videoio.CAP_PROP_FPS); - if (fps <= 0 || Double.isNaN(fps)) { - fps = 25; // 默认帧率 - } - long frameDelay = (long) (1000 / fps); - - while (isPlaying) { - if (Thread.currentThread().isInterrupted()) { - break; - } - if (isPaused) { - Thread.sleep(10); - continue; - } - - Mat frame = new Mat(); - if (!videoCapture.read(frame) || frame.empty()) { - isPlaying = false; - break; - } - - long startTime = System.currentTimeMillis(); - BufferedImage bufferedImage = matToBufferedImage(frame); - - if (bufferedImage != null) { -// float[] floats = preprocessAndConvertBufferedImage(bufferedImage); - Map stringObjectMap = preprocessImage(frame); - // 创建 FrameData 对象并放入队列 - FrameData frameData = new FrameData(bufferedImage, null,stringObjectMap); - frameDataQueue.put(frameData); // 阻塞,如果队列已满 - } - - // 控制帧率 - currentTimestamp = (long) videoCapture.get(Videoio.CAP_PROP_POS_MSEC); - - // 控制播放速度 - long processingTime = System.currentTimeMillis() - startTime; - long sleepTime = frameDelay - processingTime; - if (sleepTime > 0) { - Thread.sleep(sleepTime); - } - } - } catch (Exception ex) { - ex.printStackTrace(); - } finally { - isPlaying = false; - } - }); - - // 创建并启动推理线程 - inferenceThread = new Thread(() -> { - try { - while (isPlaying || !frameDataQueue.isEmpty()) { - if (Thread.currentThread().isInterrupted()) { - break; - } - if (isPaused) { - Thread.sleep(100); - continue; - } - - FrameData frameData = frameDataQueue.poll(100, TimeUnit.MILLISECONDS); // 等待数据 - if (frameData == null) { - continue; // 没有数据,继续检查 isPlaying - } - - BufferedImage bufferedImage = frameData.image; - Map floatObjectMap = frameData.floatObjectMap; - - // 执行推理 - List inferenceResults = new ArrayList<>(); - for (InferenceEngine inferenceEngine : inferenceEngines) { - // 假设 InferenceEngine 有 infer 方法接受 float 数组 -// inferenceResults.add(inferenceEngine.infer( 640, 640,floatObjectMap)); - } - // 绘制推理结果 - DrawImagesUtils.drawInferenceResult(bufferedImage, inferenceResults); - - // 更新绘制后图像 - videoPanel.updateImage(bufferedImage); - } - } catch (Exception ex) { - ex.printStackTrace(); - } - }); - - frameReadingThread.start(); - inferenceThread.start(); - } - - // 暂停视频 - public void pauseVideo() { - if (!isPlaying) { - return; - } - isPaused = true; - } - - // 重播视频 - public void replayVideo() { - try { - stopVideo(); // 停止当前播放 - if (videoCapture != null) { - videoCapture.set(Videoio.CAP_PROP_POS_FRAMES, 0); - currentTimestamp = 0; - - // 显示第一帧 - Mat frame = new Mat(); - if (videoCapture.read(frame)) { - BufferedImage bufferedImage = matToBufferedImage(frame); - videoPanel.updateImage(bufferedImage); - } - - playVideo(); // 开始播放 - } - } catch (Exception e) { - e.printStackTrace(); - JOptionPane.showMessageDialog(null, "重播失败: " + e.getMessage(), "错误", JOptionPane.ERROR_MESSAGE); - } - } - - // 停止视频 - public void stopVideo() { - isPlaying = false; - isPaused = false; - - if (frameReadingThread != null && frameReadingThread.isAlive()) { - frameReadingThread.interrupt(); - } - - if (inferenceThread != null && inferenceThread.isAlive()) { - inferenceThread.interrupt(); - } - - if (videoCapture != null) { - videoCapture.release(); - videoCapture = null; - } - - frameDataQueue.clear(); - } - - // 快进或后退 - public void seekTo(long seekTime) { - if (videoCapture == null) return; - try { - isPaused = false; // 取消暂停 - stopVideo(); // 停止当前播放 - videoCapture.set(Videoio.CAP_PROP_POS_MSEC, seekTime); - currentTimestamp = seekTime; - - Mat frame = new Mat(); - if (videoCapture.read(frame)) { - BufferedImage bufferedImage = matToBufferedImage(frame); - videoPanel.updateImage(bufferedImage); - } - - // 重新开始播放 - playVideo(); - - } catch (Exception ex) { - ex.printStackTrace(); - } - } - - // 快进 - public void fastForward(long milliseconds) { - long newTime = Math.min(currentTimestamp + milliseconds, videoDuration); - seekTo(newTime); - } - - // 后退 - public void rewind(long milliseconds) { - long newTime = Math.max(currentTimestamp - milliseconds, 0); - seekTo(newTime); - } - - public void addInferenceEngines(InferenceEngine inferenceEngine) { - this.inferenceEngines.add(inferenceEngine); - } - - // 定义一个内部类来存储帧数据 - private static class FrameData { - public BufferedImage image; - public float[] floatArray; - public Map floatObjectMap; - - public FrameData(BufferedImage image, float[] floatArray, Map floatObjectMap) { - this.image = image; - this.floatArray = floatArray; - this.floatObjectMap = floatObjectMap; - } - } - - - // 可选的预处理方法 - public Map preprocessImage(Mat image) { - int targetWidth = 640; - int targetHeight = 640; - - int origWidth = image.width(); - int origHeight = image.height(); - - // 计算缩放因子 - float scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight); - - // 计算新的图像尺寸 - int newWidth = Math.round(origWidth * scalingFactor); - int newHeight = Math.round(origHeight * scalingFactor); - - // 调整图像尺寸 - Mat resizedImage = new Mat(); - Imgproc.resize(image, resizedImage, new Size(newWidth, newHeight)); - - // 转换为 RGB 并归一化 - Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB); - resizedImage.convertTo(resizedImage, CvType.CV_32FC3, 1.0 / 255.0); - - // 创建填充后的图像 - Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3); - int xOffset = (targetWidth - newWidth) / 2; - int yOffset = (targetHeight - newHeight) / 2; - Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight); - resizedImage.copyTo(paddedImage.submat(roi)); - - // 将图像数据转换为数组 - int imageSize = targetWidth * targetHeight; - float[] chwData = new float[3 * imageSize]; - float[] hwcData = new float[3 * imageSize]; - paddedImage.get(0, 0, hwcData); - - // 转换为 CHW 格式 - int channelSize = imageSize; - for (int c = 0; c < 3; c++) { - for (int i = 0; i < imageSize; i++) { - chwData[c * channelSize + i] = hwcData[i * 3 + c]; - } - } - - // 释放图像资源 - resizedImage.release(); - paddedImage.release(); - - // 将预处理结果和偏移信息存入 Map - Map result = new HashMap<>(); - result.put("inputData", chwData); - result.put("origWidth", origWidth); - result.put("origHeight", origHeight); - result.put("scalingFactor", scalingFactor); - result.put("xOffset", xOffset); - result.put("yOffset", yOffset); - - return result; - } - -} diff --git a/src/main/java/com/ly/model_load/ModelManager.java b/src/main/java/com/ly/model_load/ModelManager.java index d3dd348..3f36daa 100644 --- a/src/main/java/com/ly/model_load/ModelManager.java +++ b/src/main/java/com/ly/model_load/ModelManager.java @@ -1,19 +1,30 @@ package com.ly.model_load; import com.ly.file.FileEditor; +import com.ly.onnx.engine.InferenceEngine; import com.ly.onnx.model.ModelInfo; +import com.ly.play.opencv.VideoPlayer; import javax.swing.*; import javax.swing.filechooser.FileNameExtensionFilter; import javax.swing.filechooser.FileSystemView; import java.awt.*; +import java.awt.datatransfer.DataFlavor; import java.awt.event.MouseAdapter; import java.awt.event.MouseEvent; import java.io.File; +import java.io.FileReader; +import java.util.ArrayList; +import java.util.List; +import java.io.BufferedReader; + public class ModelManager extends JPanel { private DefaultListModel modelListModel; private JList modelList; + private JPopupMenu popupMenu; + private VideoPlayer videoPlayer; + public ModelManager() { setLayout(new BorderLayout()); @@ -23,6 +34,45 @@ public class ModelManager extends JPanel { JScrollPane modelScrollPane = new JScrollPane(modelList); add(modelScrollPane, BorderLayout.CENTER); + // 创建右键菜单 + popupMenu = new JPopupMenu(); + JMenuItem deleteMenuItem = new JMenuItem("删除"); + popupMenu.add(deleteMenuItem); + + // 为模型列表添加右键菜单 + modelList.addMouseListener(new MouseAdapter() { + public void mousePressed(MouseEvent e) { + if (e.isPopupTrigger()) { // 如果是右键触发 + showPopup(e); + } + } + + public void mouseReleased(MouseEvent e) { + if (e.isPopupTrigger()) { // 如果是右键触发 + showPopup(e); + } + } + + private void showPopup(MouseEvent e) { + int index = modelList.locationToIndex(e.getPoint()); + if (index != -1) { + modelList.setSelectedIndex(index); // 选中右键点击的行 + popupMenu.show(modelList, e.getX(), e.getY()); + } + } + }); + + // 为删除菜单项添加操作 + deleteMenuItem.addActionListener(e -> { + int selectedIndex = modelList.getSelectedIndex(); + if (selectedIndex != -1) { + int confirmation = JOptionPane.showConfirmDialog(null, "确定要删除此模型吗?", "确认删除", JOptionPane.YES_NO_OPTION); + if (confirmation == JOptionPane.YES_OPTION) { + modelListModel.remove(selectedIndex); // 删除选中的模型 + } + } + }); + // 双击编辑标签文件 modelList.addMouseListener(new MouseAdapter() { public void mouseClicked(MouseEvent e) { @@ -36,8 +86,78 @@ public class ModelManager extends JPanel { } } }); + + // 设置拖拽功能处理模型和标签文件 + setTransferHandler(new TransferHandler() { + @Override + public boolean canImport(TransferSupport support) { + return support.isDataFlavorSupported(DataFlavor.javaFileListFlavor); + } + + @Override + public boolean importData(TransferSupport support) { + if (!canImport(support)) { + return false; + } + try { + // 获取拖拽的文件列表 + List files = (List) support.getTransferable().getTransferData(DataFlavor.javaFileListFlavor); + if (files.size() == 2) { // 确保拖拽的是两个文件 + File modelFile = null; + File labelFile = null; + + for (File file : files) { + if (file.getName().endsWith(".onnx")) { + modelFile = file; + } else if (file.getName().endsWith(".txt")) { + labelFile = file; + } + } + + if (modelFile != null && labelFile != null) { + // 确保 videoPlayer 被正确设置 + if (videoPlayer == null) { + throw new IllegalStateException("VideoPlayer is not set in ModelManager."); + } + + // 添加模型信息到列表 + ModelInfo modelInfo = new ModelInfo(modelFile.getAbsolutePath(), labelFile.getAbsolutePath()); + modelListModel.addElement(modelInfo); + + // 读取标签文件内容,转为 List + List labels = new ArrayList<>(); + try (BufferedReader reader = new BufferedReader(new FileReader(labelFile))) { + String line; + while ((line = reader.readLine()) != null) { + labels.add(line.trim()); + } + } + + // 创建推理引擎并传递给 VideoPlayer + InferenceEngine inferenceEngine = new InferenceEngine(modelFile.getAbsolutePath(), labels); + videoPlayer.addInferenceEngines(inferenceEngine); + return true; + } else { + JOptionPane.showMessageDialog(null, "请拖入一个 .onnx 文件和一个 .txt 文件。", "提示", JOptionPane.WARNING_MESSAGE); + } + } else { + JOptionPane.showMessageDialog(null, "请拖入两个文件:一个 .onnx 文件和一个 .txt 文件。", "提示", JOptionPane.WARNING_MESSAGE); + } + } catch (Exception ex) { + ex.printStackTrace(); + return false; + } + return false; + } + }); } + // 添加设置 VideoPlayer 的方法 + public void setVideoPlayer(VideoPlayer videoPlayer) { + this.videoPlayer = videoPlayer; + } + + // 加载模型 public void loadModel(JFrame parent) { File desktopDir = FileSystemView.getFileSystemView().getHomeDirectory(); diff --git a/src/main/java/com/ly/onnx/engine/InferenceEngine.java b/src/main/java/com/ly/onnx/engine/InferenceEngine.java index fbc51cb..f540dbb 100644 --- a/src/main/java/com/ly/onnx/engine/InferenceEngine.java +++ b/src/main/java/com/ly/onnx/engine/InferenceEngine.java @@ -154,7 +154,6 @@ public class InferenceEngine { if (wBox > 0 && hBox > 0) { // 使用您的单一标签 String label = labels.get(0); - boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence)); } } diff --git a/src/main/java/com/ly/onnx/engine/InferenceEngine_up.java b/src/main/java/com/ly/onnx/engine/InferenceEngine_up.java deleted file mode 100644 index 6acd759..0000000 --- a/src/main/java/com/ly/onnx/engine/InferenceEngine_up.java +++ /dev/null @@ -1,297 +0,0 @@ -package com.ly.onnx.engine; - -import ai.onnxruntime.*; -import com.ly.onnx.model.BoundingBox; -import com.ly.onnx.model.InferenceResult; - -import javax.imageio.ImageIO; -import java.awt.*; -import java.awt.image.BufferedImage; -import java.io.File; -import java.nio.FloatBuffer; -import java.util.List; -import java.util.*; - -public class InferenceEngine_up { - - OrtEnvironment environment = OrtEnvironment.getEnvironment(); - OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions(); - - private String modelPath; - private List labels; - - // 添加用于存储图像预处理信息的类变量 - private int origWidth; - private int origHeight; - private int newWidth; - private int newHeight; - private float scalingFactor; - private int xOffset; - private int yOffset; - - public InferenceEngine_up(String modelPath, List labels) { - this.modelPath = modelPath; - this.labels = labels; - init(); - } - - public void init() { - OrtSession session = null; - try { - sessionOptions.addCUDA(0); - session = environment.createSession(modelPath, sessionOptions); - - } catch (OrtException e) { - throw new RuntimeException(e); - } - logModelInfo(session); - } - - public InferenceResult infer(float[] inputData, int w, int h) { - // 创建ONNX输入Tensor - try (OrtSession session = environment.createSession(modelPath, sessionOptions)) { - Map inputInfo = session.getInputInfo(); - String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入 - - long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状 - OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape); - - // 执行推理 - OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor)); - - // 解析推理结果 - String outputName = session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出 - float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状:[1, N, 5] - - long l = System.currentTimeMillis(); - // 设定置信度阈值 - float confidenceThreshold = 0.5f; // 您可以根据需要调整 - - // 根据模型的输出结果解析边界框 - List boxes = new ArrayList<>(); - for (float[] data : outputData[0]) { // 遍历所有检测框 - float confidence = data[4]; - if (confidence >= confidenceThreshold) { - float xCenter = data[0]; - float yCenter = data[1]; - float widthBox = data[2]; - float heightBox = data[3]; - - // 调整坐标,减去偏移并除以缩放因子 - float xCenterAdjusted = (xCenter - xOffset) / scalingFactor; - float yCenterAdjusted = (yCenter - yOffset) / scalingFactor; - float widthAdjusted = widthBox / scalingFactor; - float heightAdjusted = heightBox / scalingFactor; - - // 计算左上角坐标 - int x = (int) (xCenterAdjusted - widthAdjusted / 2); - int y = (int) (yCenterAdjusted - heightAdjusted / 2); - int wBox = (int) widthAdjusted; - int hBox = (int) heightAdjusted; - - // 确保坐标在原始图像范围内 - if (x < 0) x = 0; - if (y < 0) y = 0; - if (x + wBox > origWidth) wBox = origWidth - x; - if (y + hBox > origHeight) hBox = origHeight - y; - - String label = "person"; // 由于只有一个类别 - - boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence)); - } - } - - // 非极大值抑制(NMS) - List nmsBoxes = nonMaximumSuppression(boxes, 0.5f); - System.out.println("耗时:"+((System.currentTimeMillis() - l))); - // 封装结果并返回 - InferenceResult inferenceResult = new InferenceResult(); - inferenceResult.setBoundingBoxes(nmsBoxes); - return inferenceResult; - - } catch (OrtException e) { - throw new RuntimeException("推理失败", e); - } - } - - // 计算两个边界框的 IoU - private float computeIoU(BoundingBox box1, BoundingBox box2) { - int x1 = Math.max(box1.getX(), box2.getX()); - int y1 = Math.max(box1.getY(), box2.getY()); - int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth()); - int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight()); - - int intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1); - int box1Area = box1.getWidth() * box1.getHeight(); - int box2Area = box2.getWidth() * box2.getHeight(); - - return (float) intersectionArea / (box1Area + box2Area - intersectionArea); - } - - // 打印模型信息 - private void logModelInfo(OrtSession session) { - System.out.println("模型输入信息:"); - try { - for (Map.Entry entry : session.getInputInfo().entrySet()) { - String name = entry.getKey(); - NodeInfo info = entry.getValue(); - System.out.println("输入名称: " + name); - System.out.println("输入信息: " + info.toString()); - } - } catch (OrtException e) { - throw new RuntimeException(e); - } - - System.out.println("模型输出信息:"); - try { - for (Map.Entry entry : session.getOutputInfo().entrySet()) { - String name = entry.getKey(); - NodeInfo info = entry.getValue(); - System.out.println("输出名称: " + name); - System.out.println("输出信息: " + info.toString()); - } - } catch (OrtException e) { - throw new RuntimeException(e); - } - } - - public static void main(String[] args) { - // 初始化标签列表 - List labels = Arrays.asList("person"); - - // 创建 InferenceEngine 实例 - InferenceEngine_up inferenceEngine = new InferenceEngine_up("D:\\work\\work_space\\java\\onnx-inference4j-play\\src\\main\\resources\\model\\best.onnx", labels); - - try { - // 加载图片 - File imageFile = new File("C:\\Users\\ly\\Desktop\\resuouce\\image\\1.jpg"); - BufferedImage inputImage = ImageIO.read(imageFile); - - // 预处理图像 - float[] inputData = inferenceEngine.preprocessImage(inputImage); - - // 执行推理 - InferenceResult result = null; - for (int i = 0; i < 100; i++) { - long l = System.currentTimeMillis(); - result = inferenceEngine.infer(inputData, 640, 640); - System.out.println(System.currentTimeMillis() - l); - } - - // 处理并显示结果 - System.out.println("推理结果:"); - for (BoundingBox box : result.getBoundingBoxes()) { - System.out.println(box); - } - - // 可视化并保存带有边界框的图像 - BufferedImage outputImage = inferenceEngine.drawBoundingBoxes(inputImage, result.getBoundingBoxes()); - - // 保存图片到本地文件 - File outputFile = new File("output_image_with_boxes.jpg"); - ImageIO.write(outputImage, "jpg", outputFile); - - System.out.println("已保存结果图片: " + outputFile.getAbsolutePath()); - - } catch (Exception e) { - e.printStackTrace(); - } - } - - // 在图像上绘制边界框和标签 - BufferedImage drawBoundingBoxes(BufferedImage image, List boxes) { - Graphics2D g = image.createGraphics(); - g.setColor(Color.RED); // 设置绘制边界框的颜色 - g.setStroke(new BasicStroke(2)); // 设置线条粗细 - - for (BoundingBox box : boxes) { - // 绘制矩形边界框 - g.drawRect(box.getX(), box.getY(), box.getWidth(), box.getHeight()); - // 绘制标签文字和置信度 - String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence()); - g.setFont(new Font("Arial", Font.PLAIN, 12)); - g.drawString(label, box.getX(), box.getY() - 5); - } - - g.dispose(); // 释放资源 - return image; - } - - // 图像预处理 - float[] preprocessImage(BufferedImage image) { - int targetWidth = 640; - int targetHeight = 640; - - origWidth = image.getWidth(); - origHeight = image.getHeight(); - - // 计算缩放因子 - scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight); - - // 计算新的图像尺寸 - newWidth = Math.round(origWidth * scalingFactor); - newHeight = Math.round(origHeight * scalingFactor); - - // 计算偏移量以居中图像 - xOffset = (targetWidth - newWidth) / 2; - yOffset = (targetHeight - newHeight) / 2; - - // 创建一个新的BufferedImage - BufferedImage resizedImage = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_RGB); - Graphics2D g = resizedImage.createGraphics(); - - // 填充背景为黑色 - g.setColor(Color.BLACK); - g.fillRect(0, 0, targetWidth, targetHeight); - - // 绘制缩放后的图像到新的图像上 - g.drawImage(image.getScaledInstance(newWidth, newHeight, Image.SCALE_SMOOTH), xOffset, yOffset, null); - g.dispose(); - - float[] inputData = new float[3 * targetWidth * targetHeight]; - - for (int c = 0; c < 3; c++) { - for (int y = 0; y < targetHeight; y++) { - for (int x = 0; x < targetWidth; x++) { - int rgb = resizedImage.getRGB(x, y); - float value = 0f; - if (c == 0) { - value = ((rgb >> 16) & 0xFF) / 255.0f; // Red channel - } else if (c == 1) { - value = ((rgb >> 8) & 0xFF) / 255.0f; // Green channel - } else if (c == 2) { - value = (rgb & 0xFF) / 255.0f; // Blue channel - } - inputData[c * targetWidth * targetHeight + y * targetWidth + x] = value; - } - } - } - - return inputData; - } - - // 非极大值抑制(NMS)方法 - private List nonMaximumSuppression(List boxes, float iouThreshold) { - // 按置信度排序 - boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence())); - - List result = new ArrayList<>(); - - while (!boxes.isEmpty()) { - BoundingBox bestBox = boxes.remove(0); - result.add(bestBox); - - Iterator iterator = boxes.iterator(); - while (iterator.hasNext()) { - BoundingBox box = iterator.next(); - if (computeIoU(bestBox, box) > iouThreshold) { - iterator.remove(); - } - } - } - - return result; - } - - // 其他方法保持不变... -} diff --git a/src/main/java/com/ly/onnx/engine/InferenceEngine_up_1.java b/src/main/java/com/ly/onnx/engine/InferenceEngine_up_1.java deleted file mode 100644 index c4c732e..0000000 --- a/src/main/java/com/ly/onnx/engine/InferenceEngine_up_1.java +++ /dev/null @@ -1,297 +0,0 @@ -package com.ly.onnx.engine; - -import ai.onnxruntime.*; -import com.ly.onnx.model.BoundingBox; -import com.ly.onnx.model.InferenceResult; - -import javax.imageio.ImageIO; -import java.awt.*; -import java.awt.image.BufferedImage; -import java.io.File; -import java.nio.FloatBuffer; -import java.util.List; -import java.util.*; - -public class InferenceEngine_up_1 { - - OrtEnvironment environment = OrtEnvironment.getEnvironment(); - OrtSession.SessionOptions sessionOptions = null; - OrtSession session = null; - private String modelPath; - private List labels; - - // 添加用于存储图像预处理信息的类变量 - private int origWidth; - private int origHeight; - private int newWidth; - private int newHeight; - private float scalingFactor; - private int xOffset; - private int yOffset; - - public InferenceEngine_up_1(String modelPath, List labels) { - this.modelPath = modelPath; - this.labels = labels; - init(); - } - - public void init() { - - try { - sessionOptions = new OrtSession.SessionOptions(); - sessionOptions.addCUDA(0); - session = environment.createSession(modelPath, sessionOptions); - } catch (OrtException e) { - throw new RuntimeException(e); - } - logModelInfo(session); - } - - public InferenceResult infer(float[] inputData, int w, int h) { - // 创建ONNX输入Tensor - try { - Map inputInfo = session.getInputInfo(); - String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入 - - long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状 - OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape); - - // 执行推理 - OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor)); - - // 解析推理结果 - String outputName = session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出 - float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状:[1, N, 5] - - long l = System.currentTimeMillis(); - // 设定置信度阈值 - float confidenceThreshold = 0.5f; // 您可以根据需要调整 - - // 根据模型的输出结果解析边界框 - List boxes = new ArrayList<>(); - for (float[] data : outputData[0]) { // 遍历所有检测框 - float confidence = data[4]; - if (confidence >= confidenceThreshold) { - float xCenter = data[0]; - float yCenter = data[1]; - float widthBox = data[2]; - float heightBox = data[3]; - - // 调整坐标,减去偏移并除以缩放因子 - float xCenterAdjusted = (xCenter - xOffset) / scalingFactor; - float yCenterAdjusted = (yCenter - yOffset) / scalingFactor; - float widthAdjusted = widthBox / scalingFactor; - float heightAdjusted = heightBox / scalingFactor; - - // 计算左上角坐标 - int x = (int) (xCenterAdjusted - widthAdjusted / 2); - int y = (int) (yCenterAdjusted - heightAdjusted / 2); - int wBox = (int) widthAdjusted; - int hBox = (int) heightAdjusted; - - // 确保坐标在原始图像范围内 - if (x < 0) x = 0; - if (y < 0) y = 0; - if (x + wBox > origWidth) wBox = origWidth - x; - if (y + hBox > origHeight) hBox = origHeight - y; - - String label = "person"; // 由于只有一个类别 - - boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence)); - } - } - - // 非极大值抑制(NMS) - List nmsBoxes = nonMaximumSuppression(boxes, 0.5f); - System.out.println("耗时:"+((System.currentTimeMillis() - l))); - // 封装结果并返回 - InferenceResult inferenceResult = new InferenceResult(); - inferenceResult.setBoundingBoxes(nmsBoxes); - return inferenceResult; - - } catch (OrtException e) { - throw new RuntimeException("推理失败", e); - } - } - - // 计算两个边界框的 IoU - private float computeIoU(BoundingBox box1, BoundingBox box2) { - int x1 = Math.max(box1.getX(), box2.getX()); - int y1 = Math.max(box1.getY(), box2.getY()); - int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth()); - int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight()); - - int intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1); - int box1Area = box1.getWidth() * box1.getHeight(); - int box2Area = box2.getWidth() * box2.getHeight(); - - return (float) intersectionArea / (box1Area + box2Area - intersectionArea); - } - - // 打印模型信息 - private void logModelInfo(OrtSession session) { - System.out.println("模型输入信息:"); - try { - for (Map.Entry entry : session.getInputInfo().entrySet()) { - String name = entry.getKey(); - NodeInfo info = entry.getValue(); - System.out.println("输入名称: " + name); - System.out.println("输入信息: " + info.toString()); - } - } catch (OrtException e) { - throw new RuntimeException(e); - } - - System.out.println("模型输出信息:"); - try { - for (Map.Entry entry : session.getOutputInfo().entrySet()) { - String name = entry.getKey(); - NodeInfo info = entry.getValue(); - System.out.println("输出名称: " + name); - System.out.println("输出信息: " + info.toString()); - } - } catch (OrtException e) { - throw new RuntimeException(e); - } - } - - public static void main(String[] args) { - // 初始化标签列表 - List labels = Arrays.asList("person"); - - // 创建 InferenceEngine 实例 - InferenceEngine_up_1 inferenceEngine = new InferenceEngine_up_1("D:\\work\\work_space\\java\\onnx-inference4j-play\\src\\main\\resources\\model\\best.onnx", labels); - - try { - // 加载图片 - File imageFile = new File("C:\\Users\\ly\\Desktop\\resuouce\\image\\1.jpg"); - BufferedImage inputImage = ImageIO.read(imageFile); - - // 预处理图像 - float[] inputData = inferenceEngine.preprocessImage(inputImage); - - // 执行推理 - InferenceResult result = null; - for (int i = 0; i < 100; i++) { - long l = System.currentTimeMillis(); - result = inferenceEngine.infer(inputData, 640, 640); - System.out.println(System.currentTimeMillis() - l); - } - - // 处理并显示结果 - System.out.println("推理结果:"); - for (BoundingBox box : result.getBoundingBoxes()) { - System.out.println(box); - } - - // 可视化并保存带有边界框的图像 - BufferedImage outputImage = inferenceEngine.drawBoundingBoxes(inputImage, result.getBoundingBoxes()); - - // 保存图片到本地文件 - File outputFile = new File("output_image_with_boxes.jpg"); - ImageIO.write(outputImage, "jpg", outputFile); - - System.out.println("已保存结果图片: " + outputFile.getAbsolutePath()); - - } catch (Exception e) { - e.printStackTrace(); - } - } - - // 在图像上绘制边界框和标签 - private BufferedImage drawBoundingBoxes(BufferedImage image, List boxes) { - Graphics2D g = image.createGraphics(); - g.setColor(Color.RED); // 设置绘制边界框的颜色 - g.setStroke(new BasicStroke(2)); // 设置线条粗细 - - for (BoundingBox box : boxes) { - // 绘制矩形边界框 - g.drawRect(box.getX(), box.getY(), box.getWidth(), box.getHeight()); - // 绘制标签文字和置信度 - String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence()); - g.setFont(new Font("Arial", Font.PLAIN, 12)); - g.drawString(label, box.getX(), box.getY() - 5); - } - - g.dispose(); // 释放资源 - return image; - } - - // 图像预处理 - private float[] preprocessImage(BufferedImage image) { - int targetWidth = 640; - int targetHeight = 640; - - origWidth = image.getWidth(); - origHeight = image.getHeight(); - - // 计算缩放因子 - scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight); - - // 计算新的图像尺寸 - newWidth = Math.round(origWidth * scalingFactor); - newHeight = Math.round(origHeight * scalingFactor); - - // 计算偏移量以居中图像 - xOffset = (targetWidth - newWidth) / 2; - yOffset = (targetHeight - newHeight) / 2; - - // 创建一个新的BufferedImage - BufferedImage resizedImage = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_RGB); - Graphics2D g = resizedImage.createGraphics(); - - // 填充背景为黑色 - g.setColor(Color.BLACK); - g.fillRect(0, 0, targetWidth, targetHeight); - - // 绘制缩放后的图像到新的图像上 - g.drawImage(image.getScaledInstance(newWidth, newHeight, Image.SCALE_SMOOTH), xOffset, yOffset, null); - g.dispose(); - - float[] inputData = new float[3 * targetWidth * targetHeight]; - - for (int c = 0; c < 3; c++) { - for (int y = 0; y < targetHeight; y++) { - for (int x = 0; x < targetWidth; x++) { - int rgb = resizedImage.getRGB(x, y); - float value = 0f; - if (c == 0) { - value = ((rgb >> 16) & 0xFF) / 255.0f; // Red channel - } else if (c == 1) { - value = ((rgb >> 8) & 0xFF) / 255.0f; // Green channel - } else if (c == 2) { - value = (rgb & 0xFF) / 255.0f; // Blue channel - } - inputData[c * targetWidth * targetHeight + y * targetWidth + x] = value; - } - } - } - - return inputData; - } - - // 非极大值抑制(NMS)方法 - private List nonMaximumSuppression(List boxes, float iouThreshold) { - // 按置信度排序 - boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence())); - - List result = new ArrayList<>(); - - while (!boxes.isEmpty()) { - BoundingBox bestBox = boxes.remove(0); - result.add(bestBox); - - Iterator iterator = boxes.iterator(); - while (iterator.hasNext()) { - BoundingBox box = iterator.next(); - if (computeIoU(bestBox, box) > iouThreshold) { - iterator.remove(); - } - } - } - - return result; - } - - // 其他方法保持不变... -} diff --git a/src/main/java/com/ly/onnx/engine/InferenceEngine_up_2.java b/src/main/java/com/ly/onnx/engine/InferenceEngine_up_2.java deleted file mode 100644 index 2262dc4..0000000 --- a/src/main/java/com/ly/onnx/engine/InferenceEngine_up_2.java +++ /dev/null @@ -1,314 +0,0 @@ -package com.ly.onnx.engine; - -import ai.onnxruntime.*; -import com.ly.onnx.model.BoundingBox; -import com.ly.onnx.model.InferenceResult; - -import javax.imageio.ImageIO; -import java.awt.*; -import java.awt.image.BufferedImage; -import java.io.File; -import java.nio.FloatBuffer; -import java.util.List; -import java.util.*; - -public class InferenceEngine_up_2 { - - private OrtEnvironment environment; - private OrtSession.SessionOptions sessionOptions; - private OrtSession session; // 将 session 作为类的成员变量 - - private String modelPath; - private List labels; - - // 添加用于存储图像预处理信息的类变量 - private int origWidth; - private int origHeight; - private int newWidth; - private int newHeight; - private float scalingFactor; - private int xOffset; - private int yOffset; - - public InferenceEngine_up_2(String modelPath, List labels) { - this.modelPath = modelPath; - this.labels = labels; - init(); - } - - public void init() { - try { - environment = OrtEnvironment.getEnvironment(); - sessionOptions = new OrtSession.SessionOptions(); - sessionOptions.addCUDA(0); // 使用 GPU - session = environment.createSession(modelPath, sessionOptions); - logModelInfo(session); - } catch (OrtException e) { - throw new RuntimeException("模型加载失败", e); - } - } - - public InferenceResult infer(float[] inputData, int w, int h) { - long startTime = System.currentTimeMillis(); - - try { - Map inputInfo = session.getInputInfo(); - String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入 - - long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状 - OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape); - - // 执行推理 - long inferenceStart = System.currentTimeMillis(); - OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor)); - long inferenceEnd = System.currentTimeMillis(); - System.out.println("模型推理耗时:" + (inferenceEnd - inferenceStart) + " ms"); - - // 解析推理结果 - String outputName = session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出 - float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状:[1, N, 5] - - // 设定置信度阈值 - float confidenceThreshold = 0.5f; // 您可以根据需要调整 - - // 根据模型的输出结果解析边界框 - List boxes = new ArrayList<>(); - for (float[] data : outputData[0]) { // 遍历所有检测框 - float confidence = data[4]; - if (confidence >= confidenceThreshold) { - float xCenter = data[0]; - float yCenter = data[1]; - float widthBox = data[2]; - float heightBox = data[3]; - - // 调整坐标,减去偏移并除以缩放因子 - float xCenterAdjusted = (xCenter - xOffset) / scalingFactor; - float yCenterAdjusted = (yCenter - yOffset) / scalingFactor; - float widthAdjusted = widthBox / scalingFactor; - float heightAdjusted = heightBox / scalingFactor; - - // 计算左上角坐标 - int x = (int) (xCenterAdjusted - widthAdjusted / 2); - int y = (int) (yCenterAdjusted - heightAdjusted / 2); - int wBox = (int) widthAdjusted; - int hBox = (int) heightAdjusted; - - // 确保坐标在原始图像范围内 - if (x < 0) x = 0; - if (y < 0) y = 0; - if (x + wBox > origWidth) wBox = origWidth - x; - if (y + hBox > origHeight) hBox = origHeight - y; - - String label = "person"; // 由于只有一个类别 - - boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence)); - } - } - - // 非极大值抑制(NMS) - long nmsStart = System.currentTimeMillis(); - List nmsBoxes = nonMaximumSuppression(boxes, 0.5f); - long nmsEnd = System.currentTimeMillis(); - System.out.println("NMS 耗时:" + (nmsEnd - nmsStart) + " ms"); - - // 封装结果并返回 - InferenceResult inferenceResult = new InferenceResult(); - inferenceResult.setBoundingBoxes(nmsBoxes); - - long endTime = System.currentTimeMillis(); - System.out.println("一次推理总耗时:" + (endTime - startTime) + " ms"); - - return inferenceResult; - - } catch (OrtException e) { - throw new RuntimeException("推理失败", e); - } - } - - // 计算两个边界框的 IoU - private float computeIoU(BoundingBox box1, BoundingBox box2) { - int x1 = Math.max(box1.getX(), box2.getX()); - int y1 = Math.max(box1.getY(), box2.getY()); - int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth()); - int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight()); - - int intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1); - int box1Area = box1.getWidth() * box1.getHeight(); - int box2Area = box2.getWidth() * box2.getHeight(); - - return (float) intersectionArea / (box1Area + box2Area - intersectionArea); - } - - // 打印模型信息 - private void logModelInfo(OrtSession session) { - System.out.println("模型输入信息:"); - try { - for (Map.Entry entry : session.getInputInfo().entrySet()) { - String name = entry.getKey(); - NodeInfo info = entry.getValue(); - System.out.println("输入名称: " + name); - System.out.println("输入信息: " + info.toString()); - } - } catch (OrtException e) { - throw new RuntimeException(e); - } - - System.out.println("模型输出信息:"); - try { - for (Map.Entry entry : session.getOutputInfo().entrySet()) { - String name = entry.getKey(); - NodeInfo info = entry.getValue(); - System.out.println("输出名称: " + name); - System.out.println("输出信息: " + info.toString()); - } - } catch (OrtException e) { - throw new RuntimeException(e); - } - } - - public static void main(String[] args) { - // 初始化标签列表 - List labels = Arrays.asList("person"); - - // 创建 InferenceEngine 实例 - InferenceEngine_up_2 inferenceEngine = new InferenceEngine_up_2("D:\\work\\work_space\\java\\onnx-inference4j-play\\src\\main\\resources\\model\\best.onnx", labels); - - try { - // 加载图片 - File imageFile = new File("C:\\Users\\ly\\Desktop\\resuouce\\image\\1.jpg"); - BufferedImage inputImage = ImageIO.read(imageFile); - - // 预处理图像 - long l1 = System.currentTimeMillis(); - float[] inputData = inferenceEngine.preprocessImage(inputImage); - System.out.println("转"+(System.currentTimeMillis() - l1)); - // 执行推理 - InferenceResult result = null; - for (int i = 0; i < 10; i++) { - long l = System.currentTimeMillis(); - result = inferenceEngine.infer(inputData, 640, 640); - System.out.println("第 " + (i + 1) + " 次推理耗时:" + (System.currentTimeMillis() - l) + " ms"); - } - - // 处理并显示结果 - System.out.println("推理结果:"); - for (BoundingBox box : result.getBoundingBoxes()) { - System.out.println(box); - } - - // 可视化并保存带有边界框的图像 - BufferedImage outputImage = inferenceEngine.drawBoundingBoxes(inputImage, result.getBoundingBoxes()); - - // 保存图片到本地文件 - File outputFile = new File("output_image_with_boxes.jpg"); - ImageIO.write(outputImage, "jpg", outputFile); - - System.out.println("已保存结果图片: " + outputFile.getAbsolutePath()); - - } catch (Exception e) { - e.printStackTrace(); - } - } - - // 在图像上绘制边界框和标签 - private BufferedImage drawBoundingBoxes(BufferedImage image, List boxes) { - Graphics2D g = image.createGraphics(); - g.setColor(Color.RED); // 设置绘制边界框的颜色 - g.setStroke(new BasicStroke(2)); // 设置线条粗细 - - for (BoundingBox box : boxes) { - // 绘制矩形边界框 - g.drawRect(box.getX(), box.getY(), box.getWidth(), box.getHeight()); - // 绘制标签文字和置信度 - String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence()); - g.setFont(new Font("Arial", Font.PLAIN, 12)); - g.drawString(label, box.getX(), box.getY() - 5); - } - - g.dispose(); // 释放资源 - return image; - } - - // 图像预处理 - public float[] preprocessImage(BufferedImage image) { - int targetWidth = 640; - int targetHeight = 640; - - origWidth = image.getWidth(); - origHeight = image.getHeight(); - - // 计算缩放因子 - scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight); - - // 计算新的图像尺寸 - newWidth = Math.round(origWidth * scalingFactor); - newHeight = Math.round(origHeight * scalingFactor); - - // 计算偏移量以居中图像 - xOffset = (targetWidth - newWidth) / 2; - yOffset = (targetHeight - newHeight) / 2; - - // 创建一个新的BufferedImage - BufferedImage resizedImage = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_RGB); - Graphics2D g = resizedImage.createGraphics(); - - // 填充背景为黑色 - g.setColor(Color.BLACK); - g.fillRect(0, 0, targetWidth, targetHeight); - - // 绘制缩放后的图像到新的图像上 - g.drawImage(image.getScaledInstance(newWidth, newHeight, Image.SCALE_SMOOTH), xOffset, yOffset, null); - g.dispose(); - - float[] inputData = new float[3 * targetWidth * targetHeight]; - - // 开始计时 - long preprocessStart = System.currentTimeMillis(); - - for (int c = 0; c < 3; c++) { - for (int y = 0; y < targetHeight; y++) { - for (int x = 0; x < targetWidth; x++) { - int rgb = resizedImage.getRGB(x, y); - float value = 0f; - if (c == 0) { - value = ((rgb >> 16) & 0xFF) / 255.0f; // Red channel - } else if (c == 1) { - value = ((rgb >> 8) & 0xFF) / 255.0f; // Green channel - } else if (c == 2) { - value = (rgb & 0xFF) / 255.0f; // Blue channel - } - inputData[c * targetWidth * targetHeight + y * targetWidth + x] = value; - } - } - } - - // 结束计时 - long preprocessEnd = System.currentTimeMillis(); - System.out.println("图像预处理耗时:" + (preprocessEnd - preprocessStart) + " ms"); - - return inputData; - } - - // 非极大值抑制(NMS)方法 - private List nonMaximumSuppression(List boxes, float iouThreshold) { - // 按置信度排序(从高到低) - boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence())); - - List result = new ArrayList<>(); - - while (!boxes.isEmpty()) { - BoundingBox bestBox = boxes.remove(0); - result.add(bestBox); - - Iterator iterator = boxes.iterator(); - while (iterator.hasNext()) { - BoundingBox box = iterator.next(); - if (computeIoU(bestBox, box) > iouThreshold) { - iterator.remove(); - } - } - } - - return result; - } -} diff --git a/src/main/java/com/ly/onnx/model/BoundingBox.java b/src/main/java/com/ly/onnx/model/BoundingBox.java index 332a888..feb130c 100644 --- a/src/main/java/com/ly/onnx/model/BoundingBox.java +++ b/src/main/java/com/ly/onnx/model/BoundingBox.java @@ -10,6 +10,7 @@ public class BoundingBox { private int height; private String label; private float confidence; + private long trackId; // 构造函数、getter 和 setter 方法 @@ -22,6 +23,5 @@ public class BoundingBox { this.confidence = confidence; } - // Getter 和 Setter 方法 - // ... + } diff --git a/src/main/java/com/ly/onnx/utils/DrawImagesUtils.java b/src/main/java/com/ly/onnx/utils/DrawImagesUtils.java index 08bb5c9..186a100 100644 --- a/src/main/java/com/ly/onnx/utils/DrawImagesUtils.java +++ b/src/main/java/com/ly/onnx/utils/DrawImagesUtils.java @@ -13,53 +13,164 @@ import java.util.List; public class DrawImagesUtils { + // 使用HSL颜色生成更高级的颜色 + public static Color hslToRgb(float hue, float saturation, float lightness) { + float c = (1 - Math.abs(2 * lightness - 1)) * saturation; + float x = c * (1 - Math.abs((hue / 60) % 2 - 1)); + float m = lightness - c / 2; + float r = 0, g = 0, b = 0; + if (0 <= hue && hue < 60) { + r = c; + g = x; + } else if (60 <= hue && hue < 120) { + r = x; + g = c; + } else if (120 <= hue && hue < 180) { + g = c; + b = x; + } else if (180 <= hue && hue < 240) { + g = x; + b = c; + } else if (240 <= hue && hue < 300) { + r = x; + b = c; + } else if (300 <= hue && hue < 360) { + r = c; + b = x; + } + + int rVal = (int) ((r + m) * 255); + int gVal = (int) ((g + m) * 255); + int bVal = (int) ((b + m) * 255); + return new Color(rVal, gVal, bVal); + } + + // 根据模型索引生成颜色 + private static Color generateColorForModel(int modelIndex, int totalModels) { + float hue = (360.0f / totalModels) * modelIndex; // 根据模型索引设置色相 + return hslToRgb(hue, 0.7f, 0.5f); // 饱和度0.7,亮度0.5 + } + + // 在 BufferedImage 上绘制推理结果 public static void drawInferenceResult(BufferedImage bufferedImage, List inferenceResults) { Graphics2D g2d = bufferedImage.createGraphics(); - g2d.setFont(new Font("Arial", Font.PLAIN, 12)); + g2d.setFont(new Font("Arial", Font.PLAIN, 24)); // 设置字体大小为24 + + int modelIndex = 0; // 模型索引 + int totalModels = inferenceResults.size(); // 总模型数 for (InferenceResult result : inferenceResults) { + Color modelColor = generateColorForModel(modelIndex++, totalModels); // 为每个模型生成独立颜色 + for (BoundingBox box : result.getBoundingBoxes()) { - // 绘制矩形 - g2d.setColor(Color.RED); + // 绘制矩形框 + g2d.setColor(modelColor); + g2d.setStroke(new BasicStroke(4)); // 设置线条粗细 g2d.drawRect(box.getX(), box.getY(), box.getWidth(), box.getHeight()); - // 绘制标签 - String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence()); + // 获取字体度量 FontMetrics metrics = g2d.getFontMetrics(); - int labelWidth = metrics.stringWidth(label); - int labelHeight = metrics.getHeight(); + int labelHeight = metrics.getHeight() + 4; // 标签高度 + String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence()); + int labelWidth = metrics.stringWidth(label) + 10; // 标签宽度 - // 确保文字不会超出图像 - int y = Math.max(box.getY(), labelHeight); + String trackIdLabel = "TrackID: " + box.getTrackId(); + int trackIdWidth = metrics.stringWidth(trackIdLabel) + 10; // TrackID标签宽度 + int trackIdHeight = metrics.getHeight() + 4; // TrackID标签高度 - // 绘制文字背景 - g2d.setColor(Color.RED); - g2d.fillRect(box.getX(), y - labelHeight, labelWidth, labelHeight); + // 计算标签总高度 + int totalLabelHeight = (box.getTrackId() > 0 ? trackIdHeight : 0) + labelHeight; - // 绘制文字 - g2d.setColor(Color.WHITE); - g2d.drawString(label, box.getX(), y); + // 边距 + int margin = 10; + + // 检查上方是否有足够空间绘制标签 + boolean canDrawAbove = box.getY() >= totalLabelHeight + margin; + + if (canDrawAbove) { + // 在检测框上方绘制标签 + int currentY = box.getY() - totalLabelHeight; + + // 绘制 TrackID(如果有) + if (box.getTrackId() > 0) { + // 绘制 TrackID 背景 + g2d.setColor(modelColor); + g2d.fillRect(box.getX(), currentY, trackIdWidth, trackIdHeight); + + // 绘制 TrackID 文字 + g2d.setColor(Color.BLACK); + g2d.drawString(trackIdLabel, box.getX() + 5, currentY + metrics.getAscent()); + + currentY += trackIdHeight; + } + + // 绘制 classid 背景 + g2d.setColor(modelColor); + g2d.fillRect(box.getX(), currentY, labelWidth, labelHeight); + + // 绘制 classid 文字 + g2d.setColor(Color.BLACK); + g2d.drawString(label, box.getX() + 5, currentY + metrics.getAscent()); + } else { + // 如果上方空间不足,则在检测框内部顶部绘制标签 + int currentY = box.getY() + 5; // 内边距5 + + // 绘制半透明背景以提高可读性 + int bgAlpha = 200; // 透明度(0-255) + Color backgroundColor = new Color(modelColor.getRed(), modelColor.getGreen(), modelColor.getBlue(), bgAlpha); + + if (box.getTrackId() > 0) { + // 绘制 TrackID 背景 + g2d.setColor(backgroundColor); + g2d.fillRect(box.getX(), currentY, trackIdWidth, trackIdHeight); + + // 绘制 TrackID 文字 + g2d.setColor(Color.BLACK); + g2d.drawString(trackIdLabel, box.getX() + 5, currentY + metrics.getAscent()); + + currentY += trackIdHeight; + } + + // 绘制 classid 背景 + g2d.setColor(backgroundColor); + g2d.fillRect(box.getX(), currentY, labelWidth, labelHeight); + + // 绘制 classid 文字 + g2d.setColor(Color.BLACK); + g2d.drawString(label, box.getX() + 5, currentY + metrics.getAscent()); + } } } g2d.dispose(); // 释放资源 } - // 在 Mat 上绘制推理结果 + + + + + + + // 在 Mat 上绘制推理结果 (OpenCV 版本) public static void drawInferenceResult(Mat image, List inferenceResults) { + int modelIndex = 0; + int totalModels = inferenceResults.size(); + for (InferenceResult result : inferenceResults) { + Scalar modelColor = convertColorToScalar(generateColorForModel(modelIndex++, totalModels)); + for (BoundingBox box : result.getBoundingBoxes()) { // 绘制矩形 Point topLeft = new Point(box.getX(), box.getY()); Point bottomRight = new Point(box.getX() + box.getWidth(), box.getY() + box.getHeight()); - Imgproc.rectangle(image, topLeft, bottomRight, new Scalar(0, 0, 255), 2); // 红色边框 + Imgproc.rectangle(image, topLeft, bottomRight, modelColor, 3); // 加粗边框 // 绘制标签 String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence()); int font = Imgproc.FONT_HERSHEY_SIMPLEX; - double fontScale = 0.5; - int thickness = 1; + double fontScale = 0.7; + int thickness = 2; // 计算文字大小 int[] baseLine = new int[1]; @@ -71,12 +182,17 @@ public class DrawImagesUtils { // 绘制文字背景 Imgproc.rectangle(image, new Point(topLeft.x, y - labelSize.height), new Point(topLeft.x + labelSize.width, y + baseLine[0]), - new Scalar(0, 0, 255), Imgproc.FILLED); + modelColor, Imgproc.FILLED); - // 绘制文字 + // 绘制黑色文字 Imgproc.putText(image, label, new Point(topLeft.x, y), - font, fontScale, new Scalar(255, 255, 255), thickness); + font, fontScale, new Scalar(0, 0, 0), thickness); // 黑色文字 } } } + + // 将 Color 转为 Scalar (用于 OpenCV) + private static Scalar convertColorToScalar(Color color) { + return new Scalar(color.getBlue(), color.getGreen(), color.getRed()); // OpenCV 中颜色顺序是 BGR + } } diff --git a/src/main/java/com/ly/play/opencv/VideoPlayer.java b/src/main/java/com/ly/play/opencv/VideoPlayer.java index a3ed6b2..d8b6455 100644 --- a/src/main/java/com/ly/play/opencv/VideoPlayer.java +++ b/src/main/java/com/ly/play/opencv/VideoPlayer.java @@ -3,9 +3,12 @@ package com.ly.play.opencv; import com.ly.layout.VideoPanel; import com.ly.model_load.ModelManager; import com.ly.onnx.engine.InferenceEngine; +import com.ly.onnx.model.BoundingBox; import com.ly.onnx.model.InferenceResult; import com.ly.onnx.utils.DrawImagesUtils; +import com.ly.track.SimpleTracker; import org.opencv.core.*; +import org.opencv.imgcodecs.Imgcodecs; import org.opencv.imgproc.Imgproc; import org.opencv.videoio.VideoCapture; import org.opencv.videoio.Videoio; @@ -37,9 +40,13 @@ public class VideoPlayer { private Thread inferenceThread; private VideoPanel videoPanel; + // 创建简单的跟踪器 + SimpleTracker tracker = new SimpleTracker(); + private long videoDuration = 0; // 毫秒 private long currentTimestamp = 0; // 毫秒 + private boolean isTrackingEnabled; private ModelManager modelManager; private List inferenceEngines = new ArrayList<>(); @@ -178,8 +185,20 @@ public class VideoPlayer { List inferenceResults = new ArrayList<>(); for (InferenceEngine inferenceEngine : inferenceEngines) { // 假设 InferenceEngine 有 infer 方法接受 float 数组 - inferenceResults.add(inferenceEngine.infer(floatObjectMap)); + InferenceResult infer = inferenceEngine.infer(floatObjectMap); + inferenceResults.add(infer); } + + // 合并所有模型的推理结果 + List allBoundingBoxes = new ArrayList<>(); + for (InferenceResult result : inferenceResults) { + allBoundingBoxes.addAll(result.getBoundingBoxes()); + } + // 如果启用了目标跟踪,则更新边界框并分配 trackId + if (isTrackingEnabled) { + tracker.update(allBoundingBoxes); + } + // 绘制推理结果 DrawImagesUtils.drawInferenceResult(bufferedImage, inferenceResults); // 更新绘制后图像 @@ -201,20 +220,22 @@ public class VideoPlayer { isPaused = true; } - - + // 设置是否启用目标跟踪 + public void setTrackingEnabled(boolean enabled) { + this.isTrackingEnabled = enabled; + } // 定义一个内部类来存储帧数据 private static class FrameData { public BufferedImage image; public Map floatObjectMap; + public FrameData(BufferedImage image, Map floatObjectMap) { this.image = image; this.floatObjectMap = floatObjectMap; } } - // 可选的预处理方法 public Map preprocessImage(Mat image) { int origWidth = image.width(); @@ -246,9 +267,8 @@ public class VideoPlayer { inferenceEngine.setIndex(index.get()); } continue; - }else { + } else { index.getAndIncrement(); - } } @@ -330,6 +350,7 @@ public class VideoPlayer { return dynamicInput; } + public List getInferenceEngines() { return this.inferenceEngines; } @@ -359,4 +380,45 @@ public class VideoPlayer { this.inferenceEngines.add(inferenceEngine); } + // 加载并处理图片 + public void loadImage(String imagePath) throws Exception { + // 停止任何正在播放的视频 + stopVideo(); + + // 读取图片 + Mat image = Imgcodecs.imread(imagePath); + if (image.empty()) { + throw new Exception("无法读取图片文件:" + imagePath); + } + + // 转换为 BufferedImage + BufferedImage bufferedImage = matToBufferedImage(image); + + // 预处理图片 + Map preprocessedData = preprocessImage(image); + + // 执行推理 + List inferenceResults = new ArrayList<>(); + for (InferenceEngine inferenceEngine : inferenceEngines) { + InferenceResult infer = inferenceEngine.infer(preprocessedData); + inferenceResults.add(infer); + } + + // 合并所有模型的推理结果 + List allBoundingBoxes = new ArrayList<>(); + for (InferenceResult result : inferenceResults) { + allBoundingBoxes.addAll(result.getBoundingBoxes()); + } + + // 如果启用了目标跟踪,则更新边界框并分配 trackId + if (isTrackingEnabled) { + tracker.update(allBoundingBoxes); + } + + // 绘制推理结果 + DrawImagesUtils.drawInferenceResult(bufferedImage, inferenceResults); + + // 在 VideoPanel 上显示图片 + videoPanel.updateImage(bufferedImage); + } } diff --git a/src/main/java/com/ly/track/SimpleTracker.java b/src/main/java/com/ly/track/SimpleTracker.java new file mode 100644 index 0000000..fabcaa2 --- /dev/null +++ b/src/main/java/com/ly/track/SimpleTracker.java @@ -0,0 +1,73 @@ +package com.ly.track; + +import com.ly.onnx.model.BoundingBox; +import lombok.Data; + +import java.awt.*; +import java.util.*; +import java.util.List; + + +public class SimpleTracker { + private Map trackedObjects = new HashMap<>(); // 使用自定义 TrackedObject 来跟踪 + private long currentTrackId = 0; + + // 跟踪器更新方法 + public List update(List detections) { + List updatedResults = new ArrayList<>(); + + for (BoundingBox detection : detections) { + boolean matched = false; + Point detectionCenter = getCenter(detection); // 获取当前检测目标的中心点 + + // 遍历现有的跟踪目标 + for (Map.Entry entry : trackedObjects.entrySet()) { + TrackedObject trackedObject = entry.getValue(); + Point trackedCenter = getCenter(trackedObject.boundingBox); + + // 使用中心点欧几里得距离进行匹配 + double distance = euclideanDistance(detectionCenter, trackedCenter); + + // 如果距离小于某个阈值,认为是同一目标 + if (distance < 50.0) { // 自定义距离阈值,可以根据需要调整 + detection.setTrackId(entry.getKey()); // 更新检测框的 trackId + trackedObject.update(detection); // 更新跟踪对象 + matched = true; + break; + } + } + + // 如果没有匹配到,创建新的 trackId + if (!matched) { + long newTrackId = ++currentTrackId; + detection.setTrackId(newTrackId); + trackedObjects.put(newTrackId, new TrackedObject(detection)); + } + + updatedResults.add(detection); + } + + // 清理丢失的目标 + cleanupLostObjects(); + + return updatedResults; + } + + // 计算目标的中心点 + private Point getCenter(BoundingBox box) { + int centerX = box.getX() + box.getWidth() / 2; + int centerY = box.getY() + box.getHeight() / 2; + return new Point(centerX, centerY); + } + + // 计算欧几里得距离 + private double euclideanDistance(Point p1, Point p2) { + return Math.sqrt(Math.pow(p1.x - p2.x, 2) + Math.pow(p1.y - p2.y, 2)); + } + + // 清理丢失的跟踪对象(例如不再检测到的对象) + private void cleanupLostObjects() { + // 可以根据时间戳或其他条件来清理长时间没有更新的目标 + trackedObjects.entrySet().removeIf(entry -> entry.getValue().isLost()); + } +} diff --git a/src/main/java/com/ly/track/TrackedObject.java b/src/main/java/com/ly/track/TrackedObject.java new file mode 100644 index 0000000..fe23b13 --- /dev/null +++ b/src/main/java/com/ly/track/TrackedObject.java @@ -0,0 +1,23 @@ +package com.ly.track; + +import com.ly.onnx.model.BoundingBox; + +public class TrackedObject { + public BoundingBox boundingBox; + private int lostFrames = 0; // 记录连续多少帧未检测到 + + public TrackedObject(BoundingBox initialBox) { + this.boundingBox = initialBox; + } + + // 更新跟踪目标的位置 + public void update(BoundingBox newBox) { + this.boundingBox = newBox; + lostFrames = 0; // 重置丢失计数 + } + + // 如果目标连续丢失多帧,认为目标丢失 + public boolean isLost() { + return lostFrames++ > 10; // 如果丢失超过10帧,就认为目标丢失 + } +} diff --git a/src/main/java/com/ly/utils/OpenCVTest.java b/src/main/java/com/ly/utils/OpenCVTest.java deleted file mode 100644 index b25f50a..0000000 --- a/src/main/java/com/ly/utils/OpenCVTest.java +++ /dev/null @@ -1,28 +0,0 @@ -package com.ly.utils; - -import org.opencv.core.Core; -import org.opencv.core.Mat; -import org.opencv.videoio.VideoCapture; - -public class OpenCVTest { -// static { -// nu.pattern.OpenCV.loadLocally(); -// } - - public static void main(String[] args) { - VideoCapture capture = new VideoCapture(0); // 打开默认摄像头 - if (!capture.isOpened()) { - System.out.println("无法打开摄像头"); - return; - } - - Mat frame = new Mat(); - if (capture.read(frame)) { - System.out.println("成功读取一帧图像"); - } else { - System.out.println("无法读取图像"); - } - - capture.release(); - } -} diff --git a/src/main/resources/model/best.onnx b/src/main/resources/model/best.onnx index b9d7489..e69de29 100644 Binary files a/src/main/resources/model/best.onnx and b/src/main/resources/model/best.onnx differ