支持多模型不同输入支持

This commit is contained in:
sulv 2024-10-10 00:48:10 +08:00
parent 23cd0a7b21
commit 7a88b08500
3 changed files with 72 additions and 40 deletions

View File

@ -130,27 +130,33 @@ public class VideoInferenceApp extends JFrame {
modelManager.loadModel(this);
DefaultListModel<ModelInfo> modelList = modelManager.getModelList();
ArrayList<ModelInfo> models = Collections.list(modelList.elements());
for (ModelInfo modelInfo : models) {
if (modelInfo != null) {
videoPlayer.addInferenceEngines(new InferenceEngine(modelInfo.getModelFilePath(), modelInfo.getLabels()));
boolean alreadyAdded = videoPlayer.getInferenceEngines().stream()
.anyMatch(engine -> engine.getModelPath().equals(modelInfo.getModelFilePath()));
if (!alreadyAdded) {
videoPlayer.addInferenceEngines(new InferenceEngine(modelInfo.getModelFilePath(), modelInfo.getLabels()));
}
}
}
});
// 播放按钮
playButton.addActionListener(e -> videoPlayer.playVideo());
// 暂停按钮
pauseButton.addActionListener(e -> videoPlayer.pauseVideo());
// 重播按钮
replayButton.addActionListener(e -> videoPlayer.replayVideo());
// 后退5秒
rewind5sButton.addActionListener(e -> videoPlayer.rewind(5000));
// 快进5秒
fastForward5sButton.addActionListener(e -> videoPlayer.fastForward(5000));
// // 重播按钮
// replayButton.addActionListener(e -> videoPlayer.replayVideo());
//
// // 后退5秒
// rewind5sButton.addActionListener(e -> videoPlayer.rewind(5000));
//
// // 快进5秒
// fastForward5sButton.addActionListener(e -> videoPlayer.fastForward(5000));
// 开始播放按钮的行为
startPlayButton.addActionListener(e -> {

View File

@ -24,7 +24,7 @@ public class InferenceEngine {
private List<String> labels;
//preprocessParams输入数据的索引
private int index;
private Integer index;
// 用于存储图像预处理信息的类变量
private long[] inputShape = null;
@ -225,6 +225,8 @@ public class InferenceEngine {
throw new RuntimeException(e);
}
}
public String getModelPath() {
return this.modelPath;
}
}

View File

@ -16,6 +16,7 @@ import java.util.*;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import static com.ly.onnx.utils.ImageUtils.matToBufferedImage;
@ -200,30 +201,8 @@ public class VideoPlayer {
isPaused = true;
}
// 停止视频
public void stopVideo() {
isPlaying = false;
isPaused = false;
if (frameReadingThread != null && frameReadingThread.isAlive()) {
frameReadingThread.interrupt();
}
if (inferenceThread != null && inferenceThread.isAlive()) {
inferenceThread.interrupt();
}
if (videoCapture != null) {
videoCapture.release();
videoCapture = null;
}
frameDataQueue.clear();
}
public void addInferenceEngines(InferenceEngine inferenceEngine) {
this.inferenceEngines.add(inferenceEngine);
}
// 定义一个内部类来存储帧数据
private static class FrameData {
@ -242,24 +221,37 @@ public class VideoPlayer {
int origHeight = image.height();
Map<Integer, Object> dynamicInput = new HashMap<>();
//定义索引
int index = 0;
AtomicInteger index = new AtomicInteger(0);
for (InferenceEngine inferenceEngine : this.inferenceEngines) {
inferenceEngine.setIndex(index);
long[] inputShape = inferenceEngine.getInputShape();
int targetWidth = (int) inputShape[2];
int targetHeight = (int) inputShape[3];
// 计算缩放因子
float scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
//检查是否存在输入大小一致的 如果存在则跳过
if (!dynamicInput.isEmpty()) {
boolean flag = true; // 初始设为 true 表示需要跳过
for (Map.Entry<Integer, Object> entry : dynamicInput.entrySet()) {
Map<String, Object> input = (Map<String, Object>) entry.getValue();
if (inputShape[2] == (long) input.get("targetHeight") || inputShape[3] == (long) input.get("targetWidth")) {
break;
Integer targetHeightValue = (Integer) input.get("targetHeight");
Integer targetWidthValue = (Integer) input.get("targetWidth");
if (inputShape[2] == targetHeightValue && inputShape[3] == targetWidthValue) {
flag = false; // 如果找到相同尺寸设为 false 表示不需要跳过
}
}
if (!flag) { // 如果找到相同尺寸跳过处理
if (inferenceEngine.getIndex() == null) {
inferenceEngine.setIndex(index.get());
}
continue;
}else {
index.getAndIncrement();
}
}
// 计算新的图像尺寸
int newWidth = Math.round(origWidth * scalingFactor);
int newHeight = Math.round(origHeight * scalingFactor);
@ -321,18 +313,50 @@ public class VideoPlayer {
// 释放图像资源
resizedImage.release();
paddedImage.release();
floatImage.release();
// 将预处理结果和偏移信息存入 Map
Map<String, Object> result = new HashMap<>();
result.put("inputData", chwData);
result.put("origWidth", origWidth);
result.put("origHeight", origHeight);
result.put("targetWidth", targetWidth);
result.put("targetHeight", targetHeight);
result.put("scalingFactor", scalingFactor);
result.put("xOffset", xOffset);
result.put("yOffset", yOffset);
dynamicInput.put(index, result);
index++;
inferenceEngine.setIndex(index.get());
dynamicInput.put(index.get(), result);
}
return dynamicInput;
}
public List<InferenceEngine> getInferenceEngines() {
return this.inferenceEngines;
}
// 停止视频
public void stopVideo() {
isPlaying = false;
isPaused = false;
if (frameReadingThread != null && frameReadingThread.isAlive()) {
frameReadingThread.interrupt();
}
if (inferenceThread != null && inferenceThread.isAlive()) {
inferenceThread.interrupt();
}
if (videoCapture != null) {
videoCapture.release();
videoCapture = null;
}
frameDataQueue.clear();
}
public void addInferenceEngines(InferenceEngine inferenceEngine) {
this.inferenceEngines.add(inferenceEngine);
}
}