diff --git a/src/main/java/com/ly/VideoInferenceApp.java b/src/main/java/com/ly/VideoInferenceApp.java index 79b2510..8de3422 100644 --- a/src/main/java/com/ly/VideoInferenceApp.java +++ b/src/main/java/com/ly/VideoInferenceApp.java @@ -130,27 +130,33 @@ public class VideoInferenceApp extends JFrame { modelManager.loadModel(this); DefaultListModel modelList = modelManager.getModelList(); ArrayList models = Collections.list(modelList.elements()); + for (ModelInfo modelInfo : models) { if (modelInfo != null) { - videoPlayer.addInferenceEngines(new InferenceEngine(modelInfo.getModelFilePath(), modelInfo.getLabels())); + boolean alreadyAdded = videoPlayer.getInferenceEngines().stream() + .anyMatch(engine -> engine.getModelPath().equals(modelInfo.getModelFilePath())); + if (!alreadyAdded) { + videoPlayer.addInferenceEngines(new InferenceEngine(modelInfo.getModelFilePath(), modelInfo.getLabels())); + } } } }); + // 播放按钮 playButton.addActionListener(e -> videoPlayer.playVideo()); // 暂停按钮 pauseButton.addActionListener(e -> videoPlayer.pauseVideo()); - // 重播按钮 - replayButton.addActionListener(e -> videoPlayer.replayVideo()); - - // 后退5秒 - rewind5sButton.addActionListener(e -> videoPlayer.rewind(5000)); - - // 快进5秒 - fastForward5sButton.addActionListener(e -> videoPlayer.fastForward(5000)); +// // 重播按钮 +// replayButton.addActionListener(e -> videoPlayer.replayVideo()); +// +// // 后退5秒 +// rewind5sButton.addActionListener(e -> videoPlayer.rewind(5000)); +// +// // 快进5秒 +// fastForward5sButton.addActionListener(e -> videoPlayer.fastForward(5000)); // 开始播放按钮的行为 startPlayButton.addActionListener(e -> { diff --git a/src/main/java/com/ly/onnx/engine/InferenceEngine.java b/src/main/java/com/ly/onnx/engine/InferenceEngine.java index 793ff4f..0e62925 100644 --- a/src/main/java/com/ly/onnx/engine/InferenceEngine.java +++ b/src/main/java/com/ly/onnx/engine/InferenceEngine.java @@ -24,7 +24,7 @@ public class InferenceEngine { private List labels; //preprocessParams输入数据的索引 - private int index; + private Integer index; // 用于存储图像预处理信息的类变量 private long[] inputShape = null; @@ -225,6 +225,8 @@ public class InferenceEngine { throw new RuntimeException(e); } } - + public String getModelPath() { + return this.modelPath; + } } diff --git a/src/main/java/com/ly/play/opencv/VideoPlayer.java b/src/main/java/com/ly/play/opencv/VideoPlayer.java index 2f2204b..a3ed6b2 100644 --- a/src/main/java/com/ly/play/opencv/VideoPlayer.java +++ b/src/main/java/com/ly/play/opencv/VideoPlayer.java @@ -16,6 +16,7 @@ import java.util.*; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import static com.ly.onnx.utils.ImageUtils.matToBufferedImage; @@ -200,30 +201,8 @@ public class VideoPlayer { isPaused = true; } - // 停止视频 - public void stopVideo() { - isPlaying = false; - isPaused = false; - if (frameReadingThread != null && frameReadingThread.isAlive()) { - frameReadingThread.interrupt(); - } - if (inferenceThread != null && inferenceThread.isAlive()) { - inferenceThread.interrupt(); - } - - if (videoCapture != null) { - videoCapture.release(); - videoCapture = null; - } - - frameDataQueue.clear(); - } - - public void addInferenceEngines(InferenceEngine inferenceEngine) { - this.inferenceEngines.add(inferenceEngine); - } // 定义一个内部类来存储帧数据 private static class FrameData { @@ -242,24 +221,37 @@ public class VideoPlayer { int origHeight = image.height(); Map dynamicInput = new HashMap<>(); //定义索引 - int index = 0; + AtomicInteger index = new AtomicInteger(0); for (InferenceEngine inferenceEngine : this.inferenceEngines) { - inferenceEngine.setIndex(index); + long[] inputShape = inferenceEngine.getInputShape(); int targetWidth = (int) inputShape[2]; int targetHeight = (int) inputShape[3]; // 计算缩放因子 float scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight); - //检查是否存在输入大小一致的 如果存在则跳过 if (!dynamicInput.isEmpty()) { + boolean flag = true; // 初始设为 true 表示需要跳过 for (Map.Entry entry : dynamicInput.entrySet()) { Map input = (Map) entry.getValue(); - if (inputShape[2] == (long) input.get("targetHeight") || inputShape[3] == (long) input.get("targetWidth")) { - break; + Integer targetHeightValue = (Integer) input.get("targetHeight"); + Integer targetWidthValue = (Integer) input.get("targetWidth"); + if (inputShape[2] == targetHeightValue && inputShape[3] == targetWidthValue) { + flag = false; // 如果找到相同尺寸,设为 false 表示不需要跳过 } } + + if (!flag) { // 如果找到相同尺寸,跳过处理 + if (inferenceEngine.getIndex() == null) { + inferenceEngine.setIndex(index.get()); + } + continue; + }else { + index.getAndIncrement(); + + } } + // 计算新的图像尺寸 int newWidth = Math.round(origWidth * scalingFactor); int newHeight = Math.round(origHeight * scalingFactor); @@ -321,18 +313,50 @@ public class VideoPlayer { // 释放图像资源 resizedImage.release(); paddedImage.release(); + floatImage.release(); // 将预处理结果和偏移信息存入 Map Map result = new HashMap<>(); result.put("inputData", chwData); result.put("origWidth", origWidth); result.put("origHeight", origHeight); + result.put("targetWidth", targetWidth); + result.put("targetHeight", targetHeight); result.put("scalingFactor", scalingFactor); result.put("xOffset", xOffset); result.put("yOffset", yOffset); - dynamicInput.put(index, result); - index++; + inferenceEngine.setIndex(index.get()); + dynamicInput.put(index.get(), result); } + return dynamicInput; } + public List getInferenceEngines() { + return this.inferenceEngines; + } + + // 停止视频 + public void stopVideo() { + isPlaying = false; + isPaused = false; + + if (frameReadingThread != null && frameReadingThread.isAlive()) { + frameReadingThread.interrupt(); + } + + if (inferenceThread != null && inferenceThread.isAlive()) { + inferenceThread.interrupt(); + } + + if (videoCapture != null) { + videoCapture.release(); + videoCapture = null; + } + + frameDataQueue.clear(); + } + + public void addInferenceEngines(InferenceEngine inferenceEngine) { + this.inferenceEngines.add(inferenceEngine); + } }