支持多模型不同输入支持

This commit is contained in:
sulv 2024-10-10 00:48:10 +08:00
parent 23cd0a7b21
commit 7a88b08500
3 changed files with 72 additions and 40 deletions

View File

@ -130,27 +130,33 @@ public class VideoInferenceApp extends JFrame {
modelManager.loadModel(this); modelManager.loadModel(this);
DefaultListModel<ModelInfo> modelList = modelManager.getModelList(); DefaultListModel<ModelInfo> modelList = modelManager.getModelList();
ArrayList<ModelInfo> models = Collections.list(modelList.elements()); ArrayList<ModelInfo> models = Collections.list(modelList.elements());
for (ModelInfo modelInfo : models) { for (ModelInfo modelInfo : models) {
if (modelInfo != null) { if (modelInfo != null) {
videoPlayer.addInferenceEngines(new InferenceEngine(modelInfo.getModelFilePath(), modelInfo.getLabels())); boolean alreadyAdded = videoPlayer.getInferenceEngines().stream()
.anyMatch(engine -> engine.getModelPath().equals(modelInfo.getModelFilePath()));
if (!alreadyAdded) {
videoPlayer.addInferenceEngines(new InferenceEngine(modelInfo.getModelFilePath(), modelInfo.getLabels()));
}
} }
} }
}); });
// 播放按钮 // 播放按钮
playButton.addActionListener(e -> videoPlayer.playVideo()); playButton.addActionListener(e -> videoPlayer.playVideo());
// 暂停按钮 // 暂停按钮
pauseButton.addActionListener(e -> videoPlayer.pauseVideo()); pauseButton.addActionListener(e -> videoPlayer.pauseVideo());
// 重播按钮 // // 重播按钮
replayButton.addActionListener(e -> videoPlayer.replayVideo()); // replayButton.addActionListener(e -> videoPlayer.replayVideo());
//
// 后退5秒 // // 后退5秒
rewind5sButton.addActionListener(e -> videoPlayer.rewind(5000)); // rewind5sButton.addActionListener(e -> videoPlayer.rewind(5000));
//
// 快进5秒 // // 快进5秒
fastForward5sButton.addActionListener(e -> videoPlayer.fastForward(5000)); // fastForward5sButton.addActionListener(e -> videoPlayer.fastForward(5000));
// 开始播放按钮的行为 // 开始播放按钮的行为
startPlayButton.addActionListener(e -> { startPlayButton.addActionListener(e -> {

View File

@ -24,7 +24,7 @@ public class InferenceEngine {
private List<String> labels; private List<String> labels;
//preprocessParams输入数据的索引 //preprocessParams输入数据的索引
private int index; private Integer index;
// 用于存储图像预处理信息的类变量 // 用于存储图像预处理信息的类变量
private long[] inputShape = null; private long[] inputShape = null;
@ -225,6 +225,8 @@ public class InferenceEngine {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }
public String getModelPath() {
return this.modelPath;
}
} }

View File

@ -16,6 +16,7 @@ import java.util.*;
import java.util.concurrent.BlockingQueue; import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import static com.ly.onnx.utils.ImageUtils.matToBufferedImage; import static com.ly.onnx.utils.ImageUtils.matToBufferedImage;
@ -200,30 +201,8 @@ public class VideoPlayer {
isPaused = true; isPaused = true;
} }
// 停止视频
public void stopVideo() {
isPlaying = false;
isPaused = false;
if (frameReadingThread != null && frameReadingThread.isAlive()) {
frameReadingThread.interrupt();
}
if (inferenceThread != null && inferenceThread.isAlive()) {
inferenceThread.interrupt();
}
if (videoCapture != null) {
videoCapture.release();
videoCapture = null;
}
frameDataQueue.clear();
}
public void addInferenceEngines(InferenceEngine inferenceEngine) {
this.inferenceEngines.add(inferenceEngine);
}
// 定义一个内部类来存储帧数据 // 定义一个内部类来存储帧数据
private static class FrameData { private static class FrameData {
@ -242,24 +221,37 @@ public class VideoPlayer {
int origHeight = image.height(); int origHeight = image.height();
Map<Integer, Object> dynamicInput = new HashMap<>(); Map<Integer, Object> dynamicInput = new HashMap<>();
//定义索引 //定义索引
int index = 0; AtomicInteger index = new AtomicInteger(0);
for (InferenceEngine inferenceEngine : this.inferenceEngines) { for (InferenceEngine inferenceEngine : this.inferenceEngines) {
inferenceEngine.setIndex(index);
long[] inputShape = inferenceEngine.getInputShape(); long[] inputShape = inferenceEngine.getInputShape();
int targetWidth = (int) inputShape[2]; int targetWidth = (int) inputShape[2];
int targetHeight = (int) inputShape[3]; int targetHeight = (int) inputShape[3];
// 计算缩放因子 // 计算缩放因子
float scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight); float scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
//检查是否存在输入大小一致的 如果存在则跳过 //检查是否存在输入大小一致的 如果存在则跳过
if (!dynamicInput.isEmpty()) { if (!dynamicInput.isEmpty()) {
boolean flag = true; // 初始设为 true 表示需要跳过
for (Map.Entry<Integer, Object> entry : dynamicInput.entrySet()) { for (Map.Entry<Integer, Object> entry : dynamicInput.entrySet()) {
Map<String, Object> input = (Map<String, Object>) entry.getValue(); Map<String, Object> input = (Map<String, Object>) entry.getValue();
if (inputShape[2] == (long) input.get("targetHeight") || inputShape[3] == (long) input.get("targetWidth")) { Integer targetHeightValue = (Integer) input.get("targetHeight");
break; Integer targetWidthValue = (Integer) input.get("targetWidth");
if (inputShape[2] == targetHeightValue && inputShape[3] == targetWidthValue) {
flag = false; // 如果找到相同尺寸设为 false 表示不需要跳过
} }
} }
if (!flag) { // 如果找到相同尺寸跳过处理
if (inferenceEngine.getIndex() == null) {
inferenceEngine.setIndex(index.get());
}
continue;
}else {
index.getAndIncrement();
}
} }
// 计算新的图像尺寸 // 计算新的图像尺寸
int newWidth = Math.round(origWidth * scalingFactor); int newWidth = Math.round(origWidth * scalingFactor);
int newHeight = Math.round(origHeight * scalingFactor); int newHeight = Math.round(origHeight * scalingFactor);
@ -321,18 +313,50 @@ public class VideoPlayer {
// 释放图像资源 // 释放图像资源
resizedImage.release(); resizedImage.release();
paddedImage.release(); paddedImage.release();
floatImage.release();
// 将预处理结果和偏移信息存入 Map // 将预处理结果和偏移信息存入 Map
Map<String, Object> result = new HashMap<>(); Map<String, Object> result = new HashMap<>();
result.put("inputData", chwData); result.put("inputData", chwData);
result.put("origWidth", origWidth); result.put("origWidth", origWidth);
result.put("origHeight", origHeight); result.put("origHeight", origHeight);
result.put("targetWidth", targetWidth);
result.put("targetHeight", targetHeight);
result.put("scalingFactor", scalingFactor); result.put("scalingFactor", scalingFactor);
result.put("xOffset", xOffset); result.put("xOffset", xOffset);
result.put("yOffset", yOffset); result.put("yOffset", yOffset);
dynamicInput.put(index, result); inferenceEngine.setIndex(index.get());
index++; dynamicInput.put(index.get(), result);
} }
return dynamicInput; return dynamicInput;
} }
public List<InferenceEngine> getInferenceEngines() {
return this.inferenceEngines;
}
// 停止视频
public void stopVideo() {
isPlaying = false;
isPaused = false;
if (frameReadingThread != null && frameReadingThread.isAlive()) {
frameReadingThread.interrupt();
}
if (inferenceThread != null && inferenceThread.isAlive()) {
inferenceThread.interrupt();
}
if (videoCapture != null) {
videoCapture.release();
videoCapture = null;
}
frameDataQueue.clear();
}
public void addInferenceEngines(InferenceEngine inferenceEngine) {
this.inferenceEngines.add(inferenceEngine);
}
} }