拖拽支持,性能优化,动态开启跟踪,识别框优化,动态识别框

This commit is contained in:
sulv 2024-10-10 22:54:09 +08:00
parent c2dd067a56
commit 8a0924f649
15 changed files with 524 additions and 1764 deletions

View File

@ -7,14 +7,15 @@ import com.ly.onnx.engine.InferenceEngine;
import com.ly.onnx.model.ModelInfo;
import com.ly.play.opencv.VideoPlayer;
import javax.swing.*;
import javax.swing.filechooser.FileNameExtensionFilter;
import javax.swing.filechooser.FileSystemView;
import java.awt.*;
import java.awt.datatransfer.DataFlavor;
import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
public class VideoInferenceApp extends JFrame {
@ -26,10 +27,9 @@ public class VideoInferenceApp extends JFrame {
private ModelManager modelManager;
public VideoInferenceApp() {
// 设置窗口标题
super("https://gitee.com/sulv0302/onnx-inference4j-play.git");
super("ONNX Inference Application");
// 初始化UI组件
initializeUI();
}
@ -49,13 +49,53 @@ public class VideoInferenceApp extends JFrame {
videoPanel = new VideoPanel();
videoPanel.setBackground(Color.BLACK);
// 模型列表区域
// 设置拖拽功能
videoPanel.setTransferHandler(new TransferHandler() {
@Override
public boolean canImport(TransferSupport support) {
return support.isDataFlavorSupported(DataFlavor.javaFileListFlavor);
}
@Override
public boolean importData(TransferSupport support) {
if (!canImport(support)) {
return false;
}
try {
// 获取拖拽的文件列表
List<File> files = (List<File>) support.getTransferable().getTransferData(DataFlavor.javaFileListFlavor);
for (File file : files) {
String fileName = file.getName().toLowerCase();
if (fileName.endsWith(".jpg") || fileName.endsWith(".jpeg") ||
fileName.endsWith(".png") || fileName.endsWith(".bmp") ||
fileName.endsWith(".gif")) {
// 加载并处理拖拽的图片文件
videoPlayer.loadImage(file.getAbsolutePath());
} else if (fileName.endsWith(".mp4") || fileName.endsWith(".avi") ||
fileName.endsWith(".mkv") || fileName.endsWith(".mov") ||
fileName.endsWith(".flv") || fileName.endsWith(".wmv")) {
// 加载并播放拖拽的视频文件
videoPlayer.loadVideo(file.getAbsolutePath());
}
}
} catch (Exception ex) {
ex.printStackTrace();
return false;
}
return true;
}
});
// 初始化 ModelManager不传递 videoPlayer
modelManager = new ModelManager();
modelManager.setPreferredSize(new Dimension(250, 0)); // 设置模型列表区域的宽度
// 初始化 VideoPlayer
// 初始化 VideoPlayer 并传递 modelManager
videoPlayer = new VideoPlayer(videoPanel, modelManager);
// videoPlayer 设置到 modelManager
modelManager.setVideoPlayer(videoPlayer);
// 使用 JSplitPane 分割视频区域和模型列表区域
JSplitPane splitPane = new JSplitPane(JSplitPane.HORIZONTAL_SPLIT, videoPanel, modelManager);
splitPane.setResizeWeight(0.8); // 视频区域初始占据80%的空间
@ -97,6 +137,10 @@ public class VideoInferenceApp extends JFrame {
JButton loadVideoButton = new JButton("选择视频文件");
loadVideoButton.setPreferredSize(new Dimension(150, 30));
// 图片文件选择按钮
JButton loadImageButton = new JButton("选择图片文件");
loadImageButton.setPreferredSize(new Dimension(150, 30));
// 模型文件选择按钮
JButton loadModelButton = new JButton("选择模型");
loadModelButton.setPreferredSize(new Dimension(150, 30));
@ -108,12 +152,19 @@ public class VideoInferenceApp extends JFrame {
JButton startPlayButton = new JButton("开始播放");
startPlayButton.setPreferredSize(new Dimension(100, 30));
// 添加目标跟踪复选框
JCheckBox trackingCheckBox = new JCheckBox("启用目标跟踪");
trackingCheckBox.setSelected(false); // 默认不启用目标跟踪
// 将按钮和输入框添加到顶部面板
topPanel.add(loadVideoButton);
topPanel.add(loadImageButton); // 添加图片按钮
topPanel.add(loadModelButton);
topPanel.add(new JLabel("流地址:"));
topPanel.add(streamUrlField);
topPanel.add(startPlayButton);
// 将复选框添加到顶部面板
topPanel.add(trackingCheckBox);
this.add(topPanel, BorderLayout.NORTH);
@ -126,6 +177,9 @@ public class VideoInferenceApp extends JFrame {
// 添加视频加载按钮的行为
loadVideoButton.addActionListener(e -> selectVideoFile());
// 添加图片加载按钮的行为
loadImageButton.addActionListener(e -> selectImageFile());
loadModelButton.addActionListener(e -> {
modelManager.loadModel(this);
DefaultListModel<ModelInfo> modelList = modelManager.getModelList();
@ -141,16 +195,31 @@ public class VideoInferenceApp extends JFrame {
}
});
// 为复选框添加监听器动态启用或禁用目标跟踪
trackingCheckBox.addActionListener(e -> {
boolean isSelected = trackingCheckBox.isSelected(); // 获取当前复选框状态
videoPlayer.setTrackingEnabled(isSelected); // 设置是否启用目标跟踪
});
// 播放按钮
playButton.addActionListener(e -> videoPlayer.playVideo());
playButton.addActionListener(e -> {
videoPlayer.playVideo();
});
// 暂停按钮
pauseButton.addActionListener(e -> videoPlayer.pauseVideo());
// // 重播按钮
// replayButton.addActionListener(e -> videoPlayer.replayVideo());
//
// 重播按钮
replayButton.addActionListener(e -> {
try {
// videoPlayer.loadVideo(videoPlayer.getCurrentVideoPath());
videoPlayer.playVideo();
} catch (Exception ex) {
ex.printStackTrace();
JOptionPane.showMessageDialog(this, "重播视频失败: " + ex.getMessage(), "错误", JOptionPane.ERROR_MESSAGE);
}
});
// // 后退5秒
// rewind5sButton.addActionListener(e -> videoPlayer.rewind(5000));
//
@ -195,6 +264,28 @@ public class VideoInferenceApp extends JFrame {
}
}
// 选择图片文件
private void selectImageFile() {
File desktopDir = FileSystemView.getFileSystemView().getHomeDirectory();
JFileChooser fileChooser = new JFileChooser(desktopDir);
fileChooser.setDialogTitle("选择图片文件");
// 设置图片文件过滤器支持常见的图片格式
FileNameExtensionFilter imageFilter = new FileNameExtensionFilter(
"图片文件 (*.jpg;*.jpeg;*.png;*.bmp;*.gif)", "jpg", "jpeg", "png", "bmp", "gif");
fileChooser.setFileFilter(imageFilter);
int returnValue = fileChooser.showOpenDialog(this);
if (returnValue == JFileChooser.APPROVE_OPTION) {
File selectedFile = fileChooser.getSelectedFile();
try {
videoPlayer.loadImage(selectedFile.getAbsolutePath());
} catch (Exception ex) {
ex.printStackTrace();
JOptionPane.showMessageDialog(this, "加载图片失败: " + ex.getMessage(), "错误", JOptionPane.ERROR_MESSAGE);
}
}
}
public static void main(String[] args) {
SwingUtilities.invokeLater(VideoInferenceApp::new);
}

View File

@ -1,410 +0,0 @@
package com.ly.lishi;
import ai.onnxruntime.*;
import com.alibaba.fastjson.JSON;
import com.ly.onnx.model.BoundingBox;
import com.ly.onnx.model.InferenceResult;
import lombok.Data;
import org.opencv.core.*;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import java.nio.FloatBuffer;
import java.util.*;
@Data
public class InferenceEngine {
private OrtEnvironment environment;
private OrtSession.SessionOptions sessionOptions;
private OrtSession session;
private String modelPath;
private List<String> labels;
// 用于存储图像预处理信息的类变量
private long[] inputShape = null;
static {
nu.pattern.OpenCV.loadLocally();
}
public InferenceEngine(String modelPath, List<String> labels) {
this.modelPath = modelPath;
this.labels = labels;
init();
}
public void init() {
try {
environment = OrtEnvironment.getEnvironment();
sessionOptions = new OrtSession.SessionOptions();
sessionOptions.addCUDA(0); // 使用 GPU
session = environment.createSession(modelPath, sessionOptions);
Map<String, NodeInfo> inputInfo = session.getInputInfo();
NodeInfo nodeInfo = inputInfo.values().iterator().next();
TensorInfo tensorInfo = (TensorInfo) nodeInfo.getInfo();
inputShape = tensorInfo.getShape(); // 从模型中获取输入形状
logModelInfo(session);
} catch (OrtException e) {
throw new RuntimeException("模型加载失败", e);
}
}
public InferenceResult infer(int w, int h, Map<String, Object> preprocessParams) {
long startTime = System.currentTimeMillis();
// Map 中获取偏移相关的变量
float[] inputData = (float[]) preprocessParams.get("inputData");
int origWidth = (int) preprocessParams.get("origWidth");
int origHeight = (int) preprocessParams.get("origHeight");
float scalingFactor = (float) preprocessParams.get("scalingFactor");
int xOffset = (int) preprocessParams.get("xOffset");
int yOffset = (int) preprocessParams.get("yOffset");
try {
Map<String, NodeInfo> inputInfo = session.getInputInfo();
String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入
long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状
// 创建输入张量时使用 CHW 格式的数据
OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape);
// 执行推理
long inferenceStart = System.currentTimeMillis();
OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor));
long inferenceEnd = System.currentTimeMillis();
System.out.println("模型推理耗时:" + (inferenceEnd - inferenceStart) + " ms");
// 解析推理结果
String outputName = session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出
float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状[1, N, 6]
// 设定置信度阈值
float confidenceThreshold = 0.25f; // 您可以根据需要调整
// 根据模型的输出结果解析边界框
List<BoundingBox> boxes = new ArrayList<>();
for (float[] data : outputData[0]) { // 遍历所有检测框
// 根据模型输出格式解析中心坐标和宽高
float x_center = data[0];
float y_center = data[1];
float width = data[2];
float height = data[3];
float confidence = data[4];
if (confidence >= confidenceThreshold) {
// 将中心坐标转换为左上角和右下角坐标
float x1 = x_center - width / 2;
float y1 = y_center - height / 2;
float x2 = x_center + width / 2;
float y2 = y_center + height / 2;
// 调整坐标减去偏移并除以缩放因子
float x1Adjusted = (x1 - xOffset) / scalingFactor;
float y1Adjusted = (y1 - yOffset) / scalingFactor;
float x2Adjusted = (x2 - xOffset) / scalingFactor;
float y2Adjusted = (y2 - yOffset) / scalingFactor;
// 确保坐标的正确顺序
float xMinAdjusted = Math.min(x1Adjusted, x2Adjusted);
float xMaxAdjusted = Math.max(x1Adjusted, x2Adjusted);
float yMinAdjusted = Math.min(y1Adjusted, y2Adjusted);
float yMaxAdjusted = Math.max(y1Adjusted, y2Adjusted);
// 确保坐标在原始图像范围内
int x = (int) Math.max(0, xMinAdjusted);
int y = (int) Math.max(0, yMinAdjusted);
int xMax = (int) Math.min(origWidth, xMaxAdjusted);
int yMax = (int) Math.min(origHeight, yMaxAdjusted);
int wBox = xMax - x;
int hBox = yMax - y;
// 仅当宽度和高度为正时才添加边界框
if (wBox > 0 && hBox > 0) {
// 使用您的单一标签
String label = labels.get(0);
boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence));
}
}
}
// 非极大值抑制NMS
long nmsStart = System.currentTimeMillis();
List<BoundingBox> nmsBoxes = nonMaximumSuppression(boxes, 0.5f);
System.out.println("检测到的标签:" + JSON.toJSONString(nmsBoxes));
if (!nmsBoxes.isEmpty()) {
for (BoundingBox box : nmsBoxes) {
System.out.println(box);
}
}
long nmsEnd = System.currentTimeMillis();
System.out.println("NMS 耗时:" + (nmsEnd - nmsStart) + " ms");
// 封装结果并返回
InferenceResult inferenceResult = new InferenceResult();
inferenceResult.setBoundingBoxes(nmsBoxes);
long endTime = System.currentTimeMillis();
System.out.println("一次推理总耗时:" + (endTime - startTime) + " ms");
return inferenceResult;
} catch (OrtException e) {
throw new RuntimeException("推理失败", e);
}
}
// 计算两个边界框的 IoU
private float computeIoU(BoundingBox box1, BoundingBox box2) {
int x1 = Math.max(box1.getX(), box2.getX());
int y1 = Math.max(box1.getY(), box2.getY());
int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth());
int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight());
int intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
int box1Area = box1.getWidth() * box1.getHeight();
int box2Area = box2.getWidth() * box2.getHeight();
return (float) intersectionArea / (box1Area + box2Area - intersectionArea);
}
// 非极大值抑制NMS方法
private List<BoundingBox> nonMaximumSuppression(List<BoundingBox> boxes, float iouThreshold) {
// 按置信度排序从高到低
boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence()));
List<BoundingBox> result = new ArrayList<>();
while (!boxes.isEmpty()) {
BoundingBox bestBox = boxes.remove(0);
result.add(bestBox);
Iterator<BoundingBox> iterator = boxes.iterator();
while (iterator.hasNext()) {
BoundingBox box = iterator.next();
if (computeIoU(bestBox, box) > iouThreshold) {
iterator.remove();
}
}
}
return result;
}
// 打印模型信息
private void logModelInfo(OrtSession session) {
System.out.println("模型输入信息:");
try {
for (Map.Entry<String, NodeInfo> entry : session.getInputInfo().entrySet()) {
String name = entry.getKey();
NodeInfo info = entry.getValue();
System.out.println("输入名称: " + name);
System.out.println("输入信息: " + info.toString());
}
} catch (OrtException e) {
throw new RuntimeException(e);
}
System.out.println("模型输出信息:");
try {
for (Map.Entry<String, NodeInfo> entry : session.getOutputInfo().entrySet()) {
String name = entry.getKey();
NodeInfo info = entry.getValue();
System.out.println("输出名称: " + name);
System.out.println("输出信息: " + info.toString());
}
} catch (OrtException e) {
throw new RuntimeException(e);
}
}
public static void main(String[] args) {
// 加载 OpenCV
// 初始化标签列表只有一个标签
List<String> labels = Arrays.asList("person");
// 创建 InferenceEngine 实例
InferenceEngine inferenceEngine = new InferenceEngine("C:\\Users\\ly\\Desktop\\person.onnx", labels);
for (int j = 0; j < 10; j++) {
try {
// 加载图片
Mat inputImage = Imgcodecs.imread("C:\\Users\\ly\\Desktop\\10230731212230.png");
// 预处理图像
long l1 = System.currentTimeMillis();
Map<String, Object> preprocessResult = inferenceEngine.preprocessImage(inputImage);
float[] inputData = (float[]) preprocessResult.get("inputData");
InferenceResult result = null;
for (int i = 0; i < 10; i++) {
long l = System.currentTimeMillis();
result = inferenceEngine.infer( 640, 640, preprocessResult);
System.out.println("" + (i + 1) + " 次推理耗时:" + (System.currentTimeMillis() - l) + " ms");
}
// 处理并显示结果
System.out.println("推理结果:");
for (BoundingBox box : result.getBoundingBoxes()) {
System.out.println(box);
}
// 可视化并保存带有边界框的图像
Mat outputImage = inferenceEngine.drawBoundingBoxes(inputImage, result.getBoundingBoxes());
// 保存图片到本地文件
String outputFilePath = "output_image_with_boxes.jpg";
Imgcodecs.imwrite(outputFilePath, outputImage);
System.out.println("已保存结果图片: " + outputFilePath);
} catch (Exception e) {
e.printStackTrace();
}
}
}
// 在图像上绘制边界框和标签
private Mat drawBoundingBoxes(Mat image, List<BoundingBox> boxes) {
for (BoundingBox box : boxes) {
// 绘制矩形边界框
Imgproc.rectangle(image, new Point(box.getX(), box.getY()),
new Point(box.getX() + box.getWidth(), box.getY() + box.getHeight()),
new Scalar(0, 0, 255), 2); // 红色边框
// 绘制标签文字和置信度
String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence());
int baseLine[] = new int[1];
Size labelSize = Imgproc.getTextSize(label, Imgproc.FONT_HERSHEY_SIMPLEX, 0.5, 1, baseLine);
int top = Math.max(box.getY(), (int) labelSize.height);
Imgproc.putText(image, label, new Point(box.getX(), top),
Imgproc.FONT_HERSHEY_SIMPLEX, 0.5, new Scalar(255, 255, 255), 1);
}
return image;
}
public Map<String, Object> preprocessImage(Mat image) {
int targetWidth = 640;
int targetHeight = 640;
int origWidth = image.width();
int origHeight = image.height();
// 计算缩放因子
float scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
// 计算新的图像尺寸
int newWidth = Math.round(origWidth * scalingFactor);
int newHeight = Math.round(origHeight * scalingFactor);
// 计算偏移量以居中图像
int xOffset = (targetWidth - newWidth) / 2;
int yOffset = (targetHeight - newHeight) / 2;
// 调整图像尺寸
Mat resizedImage = new Mat();
Imgproc.resize(image, resizedImage, new Size(newWidth, newHeight), 0, 0, Imgproc.INTER_LINEAR);
// 转换为 RGB 并归一化
Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB);
resizedImage.convertTo(resizedImage, CvType.CV_32FC3, 1.0 / 255.0);
// 创建填充后的图像
Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3);
Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight);
resizedImage.copyTo(paddedImage.submat(roi));
// 将图像数据转换为数组
int imageSize = targetWidth * targetHeight;
float[] chwData = new float[3 * imageSize];
float[] hwcData = new float[3 * imageSize];
paddedImage.get(0, 0, hwcData);
// 转换为 CHW 格式
int channelSize = imageSize;
for (int c = 0; c < 3; c++) {
for (int i = 0; i < imageSize; i++) {
chwData[c * channelSize + i] = hwcData[i * 3 + c];
}
}
// 释放图像资源
resizedImage.release();
paddedImage.release();
// 将预处理结果和偏移信息存入 Map
Map<String, Object> result = new HashMap<>();
result.put("inputData", chwData);
result.put("origWidth", origWidth);
result.put("origHeight", origHeight);
result.put("scalingFactor", scalingFactor);
result.put("xOffset", xOffset);
result.put("yOffset", yOffset);
return result;
}
// 图像预处理
// public float[] preprocessImage(Mat image) {
// int targetWidth = 640;
// int targetHeight = 640;
//
// origWidth = image.width();
// origHeight = image.height();
//
// // 计算缩放因子
// scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
//
// // 计算新的图像尺寸
// newWidth = Math.round(origWidth * scalingFactor);
// newHeight = Math.round(origHeight * scalingFactor);
//
// // 计算偏移量以居中图像
// xOffset = (targetWidth - newWidth) / 2;
// yOffset = (targetHeight - newHeight) / 2;
//
// // 调整图像尺寸
// Mat resizedImage = new Mat();
// Imgproc.resize(image, resizedImage, new Size(newWidth, newHeight), 0, 0, Imgproc.INTER_LINEAR);
//
// // 转换为 RGB 并归一化
// Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB);
// resizedImage.convertTo(resizedImage, CvType.CV_32FC3, 1.0 / 255.0);
//
// // 创建填充后的图像
// Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3);
// Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight);
// resizedImage.copyTo(paddedImage.submat(roi));
//
// // 将图像数据转换为数组
// int imageSize = targetWidth * targetHeight;
// float[] chwData = new float[3 * imageSize];
// float[] hwcData = new float[3 * imageSize];
// paddedImage.get(0, 0, hwcData);
//
// // 转换为 CHW 格式
// int channelSize = imageSize;
// for (int c = 0; c < 3; c++) {
// for (int i = 0; i < imageSize; i++) {
// chwData[c * channelSize + i] = hwcData[i * 3 + c];
// }
// }
//
// // 释放图像资源
// resizedImage.release();
// paddedImage.release();
//
// return chwData;
// }
}

View File

@ -1,378 +0,0 @@
package com.ly.lishi;
import com.ly.layout.VideoPanel;
import com.ly.model_load.ModelManager;
import com.ly.onnx.engine.InferenceEngine;
import com.ly.onnx.model.InferenceResult;
import com.ly.onnx.utils.DrawImagesUtils;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.Rect;
import org.opencv.core.Size;
import org.opencv.imgproc.Imgproc;
import org.opencv.videoio.VideoCapture;
import org.opencv.videoio.Videoio;
import javax.swing.*;
import java.awt.image.BufferedImage;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import static com.ly.onnx.utils.ImageUtils.matToBufferedImage;
public class VideoPlayer {
static {
// 加载 OpenCV
nu.pattern.OpenCV.loadLocally();
String OS = System.getProperty("os.name").toLowerCase();
if (OS.contains("win")) {
System.load(ClassLoader.getSystemResource("lib/win/opencv_videoio_ffmpeg470_64.dll").getPath());
}
}
private VideoCapture videoCapture;
private volatile boolean isPlaying = false;
private volatile boolean isPaused = false;
private Thread frameReadingThread;
private Thread inferenceThread;
private VideoPanel videoPanel;
private long videoDuration = 0; // 毫秒
private long currentTimestamp = 0; // 毫秒
private ModelManager modelManager;
private List<InferenceEngine> inferenceEngines = new ArrayList<>();
// 定义阻塞队列来缓冲转换后的数据
private BlockingQueue<FrameData> frameDataQueue = new LinkedBlockingQueue<>(10); // 队列容量可根据需要调整
public VideoPlayer(VideoPanel videoPanel, ModelManager modelManager) {
this.videoPanel = videoPanel;
this.modelManager = modelManager;
}
// 加载视频或流
public void loadVideo(String videoFilePathOrStreamUrl) throws Exception {
stopVideo();
if (videoFilePathOrStreamUrl.equals("0")) {
int cameraIndex = Integer.parseInt(videoFilePathOrStreamUrl);
videoCapture = new VideoCapture(cameraIndex);
if (!videoCapture.isOpened()) {
throw new Exception("无法打开摄像头");
}
videoDuration = 0; // 摄像头没有固定的时长
playVideo();
} else {
// 输入不是数字尝试打开视频文件
videoCapture = new VideoCapture(videoFilePathOrStreamUrl, Videoio.CAP_FFMPEG);
if (!videoCapture.isOpened()) {
throw new Exception("无法打开视频文件:" + videoFilePathOrStreamUrl);
}
double frameCount = videoCapture.get(Videoio.CAP_PROP_FRAME_COUNT);
double fps = videoCapture.get(Videoio.CAP_PROP_FPS);
if (fps <= 0 || Double.isNaN(fps)) {
fps = 25; // 默认帧率
}
videoDuration = (long) (frameCount / fps * 1000); // 转换为毫秒
}
// 显示第一帧
Mat frame = new Mat();
if (videoCapture.read(frame)) {
BufferedImage bufferedImage = matToBufferedImage(frame);
videoPanel.updateImage(bufferedImage);
currentTimestamp = 0;
} else {
throw new Exception("无法读取第一帧");
}
// 重置到视频开始位置
videoCapture.set(Videoio.CAP_PROP_POS_FRAMES, 0);
currentTimestamp = 0;
}
// 播放
public void playVideo() {
if (videoCapture == null || !videoCapture.isOpened()) {
JOptionPane.showMessageDialog(null, "请先加载视频文件或流。", "提示", JOptionPane.WARNING_MESSAGE);
return;
}
if (isPlaying) {
if (isPaused) {
isPaused = false; // 恢复播放
}
return;
}
isPlaying = true;
isPaused = false;
frameDataQueue.clear(); // 开始播放前清空队列
// 创建并启动帧读取和转换线程
frameReadingThread = new Thread(() -> {
try {
double fps = videoCapture.get(Videoio.CAP_PROP_FPS);
if (fps <= 0 || Double.isNaN(fps)) {
fps = 25; // 默认帧率
}
long frameDelay = (long) (1000 / fps);
while (isPlaying) {
if (Thread.currentThread().isInterrupted()) {
break;
}
if (isPaused) {
Thread.sleep(10);
continue;
}
Mat frame = new Mat();
if (!videoCapture.read(frame) || frame.empty()) {
isPlaying = false;
break;
}
long startTime = System.currentTimeMillis();
BufferedImage bufferedImage = matToBufferedImage(frame);
if (bufferedImage != null) {
// float[] floats = preprocessAndConvertBufferedImage(bufferedImage);
Map<String, Object> stringObjectMap = preprocessImage(frame);
// 创建 FrameData 对象并放入队列
FrameData frameData = new FrameData(bufferedImage, null,stringObjectMap);
frameDataQueue.put(frameData); // 阻塞如果队列已满
}
// 控制帧率
currentTimestamp = (long) videoCapture.get(Videoio.CAP_PROP_POS_MSEC);
// 控制播放速度
long processingTime = System.currentTimeMillis() - startTime;
long sleepTime = frameDelay - processingTime;
if (sleepTime > 0) {
Thread.sleep(sleepTime);
}
}
} catch (Exception ex) {
ex.printStackTrace();
} finally {
isPlaying = false;
}
});
// 创建并启动推理线程
inferenceThread = new Thread(() -> {
try {
while (isPlaying || !frameDataQueue.isEmpty()) {
if (Thread.currentThread().isInterrupted()) {
break;
}
if (isPaused) {
Thread.sleep(100);
continue;
}
FrameData frameData = frameDataQueue.poll(100, TimeUnit.MILLISECONDS); // 等待数据
if (frameData == null) {
continue; // 没有数据继续检查 isPlaying
}
BufferedImage bufferedImage = frameData.image;
Map<String, Object> floatObjectMap = frameData.floatObjectMap;
// 执行推理
List<InferenceResult> inferenceResults = new ArrayList<>();
for (InferenceEngine inferenceEngine : inferenceEngines) {
// 假设 InferenceEngine infer 方法接受 float 数组
// inferenceResults.add(inferenceEngine.infer( 640, 640,floatObjectMap));
}
// 绘制推理结果
DrawImagesUtils.drawInferenceResult(bufferedImage, inferenceResults);
// 更新绘制后图像
videoPanel.updateImage(bufferedImage);
}
} catch (Exception ex) {
ex.printStackTrace();
}
});
frameReadingThread.start();
inferenceThread.start();
}
// 暂停视频
public void pauseVideo() {
if (!isPlaying) {
return;
}
isPaused = true;
}
// 重播视频
public void replayVideo() {
try {
stopVideo(); // 停止当前播放
if (videoCapture != null) {
videoCapture.set(Videoio.CAP_PROP_POS_FRAMES, 0);
currentTimestamp = 0;
// 显示第一帧
Mat frame = new Mat();
if (videoCapture.read(frame)) {
BufferedImage bufferedImage = matToBufferedImage(frame);
videoPanel.updateImage(bufferedImage);
}
playVideo(); // 开始播放
}
} catch (Exception e) {
e.printStackTrace();
JOptionPane.showMessageDialog(null, "重播失败: " + e.getMessage(), "错误", JOptionPane.ERROR_MESSAGE);
}
}
// 停止视频
public void stopVideo() {
isPlaying = false;
isPaused = false;
if (frameReadingThread != null && frameReadingThread.isAlive()) {
frameReadingThread.interrupt();
}
if (inferenceThread != null && inferenceThread.isAlive()) {
inferenceThread.interrupt();
}
if (videoCapture != null) {
videoCapture.release();
videoCapture = null;
}
frameDataQueue.clear();
}
// 快进或后退
public void seekTo(long seekTime) {
if (videoCapture == null) return;
try {
isPaused = false; // 取消暂停
stopVideo(); // 停止当前播放
videoCapture.set(Videoio.CAP_PROP_POS_MSEC, seekTime);
currentTimestamp = seekTime;
Mat frame = new Mat();
if (videoCapture.read(frame)) {
BufferedImage bufferedImage = matToBufferedImage(frame);
videoPanel.updateImage(bufferedImage);
}
// 重新开始播放
playVideo();
} catch (Exception ex) {
ex.printStackTrace();
}
}
// 快进
public void fastForward(long milliseconds) {
long newTime = Math.min(currentTimestamp + milliseconds, videoDuration);
seekTo(newTime);
}
// 后退
public void rewind(long milliseconds) {
long newTime = Math.max(currentTimestamp - milliseconds, 0);
seekTo(newTime);
}
public void addInferenceEngines(InferenceEngine inferenceEngine) {
this.inferenceEngines.add(inferenceEngine);
}
// 定义一个内部类来存储帧数据
private static class FrameData {
public BufferedImage image;
public float[] floatArray;
public Map<String, Object> floatObjectMap;
public FrameData(BufferedImage image, float[] floatArray, Map<String, Object> floatObjectMap) {
this.image = image;
this.floatArray = floatArray;
this.floatObjectMap = floatObjectMap;
}
}
// 可选的预处理方法
public Map<String, Object> preprocessImage(Mat image) {
int targetWidth = 640;
int targetHeight = 640;
int origWidth = image.width();
int origHeight = image.height();
// 计算缩放因子
float scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
// 计算新的图像尺寸
int newWidth = Math.round(origWidth * scalingFactor);
int newHeight = Math.round(origHeight * scalingFactor);
// 调整图像尺寸
Mat resizedImage = new Mat();
Imgproc.resize(image, resizedImage, new Size(newWidth, newHeight));
// 转换为 RGB 并归一化
Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB);
resizedImage.convertTo(resizedImage, CvType.CV_32FC3, 1.0 / 255.0);
// 创建填充后的图像
Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3);
int xOffset = (targetWidth - newWidth) / 2;
int yOffset = (targetHeight - newHeight) / 2;
Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight);
resizedImage.copyTo(paddedImage.submat(roi));
// 将图像数据转换为数组
int imageSize = targetWidth * targetHeight;
float[] chwData = new float[3 * imageSize];
float[] hwcData = new float[3 * imageSize];
paddedImage.get(0, 0, hwcData);
// 转换为 CHW 格式
int channelSize = imageSize;
for (int c = 0; c < 3; c++) {
for (int i = 0; i < imageSize; i++) {
chwData[c * channelSize + i] = hwcData[i * 3 + c];
}
}
// 释放图像资源
resizedImage.release();
paddedImage.release();
// 将预处理结果和偏移信息存入 Map
Map<String, Object> result = new HashMap<>();
result.put("inputData", chwData);
result.put("origWidth", origWidth);
result.put("origHeight", origHeight);
result.put("scalingFactor", scalingFactor);
result.put("xOffset", xOffset);
result.put("yOffset", yOffset);
return result;
}
}

View File

@ -1,19 +1,30 @@
package com.ly.model_load;
import com.ly.file.FileEditor;
import com.ly.onnx.engine.InferenceEngine;
import com.ly.onnx.model.ModelInfo;
import com.ly.play.opencv.VideoPlayer;
import javax.swing.*;
import javax.swing.filechooser.FileNameExtensionFilter;
import javax.swing.filechooser.FileSystemView;
import java.awt.*;
import java.awt.datatransfer.DataFlavor;
import java.awt.event.MouseAdapter;
import java.awt.event.MouseEvent;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;
import java.io.BufferedReader;
public class ModelManager extends JPanel {
private DefaultListModel<ModelInfo> modelListModel;
private JList<ModelInfo> modelList;
private JPopupMenu popupMenu;
private VideoPlayer videoPlayer;
public ModelManager() {
setLayout(new BorderLayout());
@ -23,6 +34,45 @@ public class ModelManager extends JPanel {
JScrollPane modelScrollPane = new JScrollPane(modelList);
add(modelScrollPane, BorderLayout.CENTER);
// 创建右键菜单
popupMenu = new JPopupMenu();
JMenuItem deleteMenuItem = new JMenuItem("删除");
popupMenu.add(deleteMenuItem);
// 为模型列表添加右键菜单
modelList.addMouseListener(new MouseAdapter() {
public void mousePressed(MouseEvent e) {
if (e.isPopupTrigger()) { // 如果是右键触发
showPopup(e);
}
}
public void mouseReleased(MouseEvent e) {
if (e.isPopupTrigger()) { // 如果是右键触发
showPopup(e);
}
}
private void showPopup(MouseEvent e) {
int index = modelList.locationToIndex(e.getPoint());
if (index != -1) {
modelList.setSelectedIndex(index); // 选中右键点击的行
popupMenu.show(modelList, e.getX(), e.getY());
}
}
});
// 为删除菜单项添加操作
deleteMenuItem.addActionListener(e -> {
int selectedIndex = modelList.getSelectedIndex();
if (selectedIndex != -1) {
int confirmation = JOptionPane.showConfirmDialog(null, "确定要删除此模型吗?", "确认删除", JOptionPane.YES_NO_OPTION);
if (confirmation == JOptionPane.YES_OPTION) {
modelListModel.remove(selectedIndex); // 删除选中的模型
}
}
});
// 双击编辑标签文件
modelList.addMouseListener(new MouseAdapter() {
public void mouseClicked(MouseEvent e) {
@ -36,8 +86,78 @@ public class ModelManager extends JPanel {
}
}
});
// 设置拖拽功能处理模型和标签文件
setTransferHandler(new TransferHandler() {
@Override
public boolean canImport(TransferSupport support) {
return support.isDataFlavorSupported(DataFlavor.javaFileListFlavor);
}
@Override
public boolean importData(TransferSupport support) {
if (!canImport(support)) {
return false;
}
try {
// 获取拖拽的文件列表
List<File> files = (List<File>) support.getTransferable().getTransferData(DataFlavor.javaFileListFlavor);
if (files.size() == 2) { // 确保拖拽的是两个文件
File modelFile = null;
File labelFile = null;
for (File file : files) {
if (file.getName().endsWith(".onnx")) {
modelFile = file;
} else if (file.getName().endsWith(".txt")) {
labelFile = file;
}
}
if (modelFile != null && labelFile != null) {
// 确保 videoPlayer 被正确设置
if (videoPlayer == null) {
throw new IllegalStateException("VideoPlayer is not set in ModelManager.");
}
// 添加模型信息到列表
ModelInfo modelInfo = new ModelInfo(modelFile.getAbsolutePath(), labelFile.getAbsolutePath());
modelListModel.addElement(modelInfo);
// 读取标签文件内容转为 List<String>
List<String> labels = new ArrayList<>();
try (BufferedReader reader = new BufferedReader(new FileReader(labelFile))) {
String line;
while ((line = reader.readLine()) != null) {
labels.add(line.trim());
}
}
// 创建推理引擎并传递给 VideoPlayer
InferenceEngine inferenceEngine = new InferenceEngine(modelFile.getAbsolutePath(), labels);
videoPlayer.addInferenceEngines(inferenceEngine);
return true;
} else {
JOptionPane.showMessageDialog(null, "请拖入一个 .onnx 文件和一个 .txt 文件。", "提示", JOptionPane.WARNING_MESSAGE);
}
} else {
JOptionPane.showMessageDialog(null, "请拖入两个文件:一个 .onnx 文件和一个 .txt 文件。", "提示", JOptionPane.WARNING_MESSAGE);
}
} catch (Exception ex) {
ex.printStackTrace();
return false;
}
return false;
}
});
}
// 添加设置 VideoPlayer 的方法
public void setVideoPlayer(VideoPlayer videoPlayer) {
this.videoPlayer = videoPlayer;
}
// 加载模型
public void loadModel(JFrame parent) {
File desktopDir = FileSystemView.getFileSystemView().getHomeDirectory();

View File

@ -154,7 +154,6 @@ public class InferenceEngine {
if (wBox > 0 && hBox > 0) {
// 使用您的单一标签
String label = labels.get(0);
boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence));
}
}

View File

@ -1,297 +0,0 @@
package com.ly.onnx.engine;
import ai.onnxruntime.*;
import com.ly.onnx.model.BoundingBox;
import com.ly.onnx.model.InferenceResult;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.nio.FloatBuffer;
import java.util.List;
import java.util.*;
public class InferenceEngine_up {
OrtEnvironment environment = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
private String modelPath;
private List<String> labels;
// 添加用于存储图像预处理信息的类变量
private int origWidth;
private int origHeight;
private int newWidth;
private int newHeight;
private float scalingFactor;
private int xOffset;
private int yOffset;
public InferenceEngine_up(String modelPath, List<String> labels) {
this.modelPath = modelPath;
this.labels = labels;
init();
}
public void init() {
OrtSession session = null;
try {
sessionOptions.addCUDA(0);
session = environment.createSession(modelPath, sessionOptions);
} catch (OrtException e) {
throw new RuntimeException(e);
}
logModelInfo(session);
}
public InferenceResult infer(float[] inputData, int w, int h) {
// 创建ONNX输入Tensor
try (OrtSession session = environment.createSession(modelPath, sessionOptions)) {
Map<String, NodeInfo> inputInfo = session.getInputInfo();
String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入
long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状
OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape);
// 执行推理
OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor));
// 解析推理结果
String outputName = session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出
float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状[1, N, 5]
long l = System.currentTimeMillis();
// 设定置信度阈值
float confidenceThreshold = 0.5f; // 您可以根据需要调整
// 根据模型的输出结果解析边界框
List<BoundingBox> boxes = new ArrayList<>();
for (float[] data : outputData[0]) { // 遍历所有检测框
float confidence = data[4];
if (confidence >= confidenceThreshold) {
float xCenter = data[0];
float yCenter = data[1];
float widthBox = data[2];
float heightBox = data[3];
// 调整坐标减去偏移并除以缩放因子
float xCenterAdjusted = (xCenter - xOffset) / scalingFactor;
float yCenterAdjusted = (yCenter - yOffset) / scalingFactor;
float widthAdjusted = widthBox / scalingFactor;
float heightAdjusted = heightBox / scalingFactor;
// 计算左上角坐标
int x = (int) (xCenterAdjusted - widthAdjusted / 2);
int y = (int) (yCenterAdjusted - heightAdjusted / 2);
int wBox = (int) widthAdjusted;
int hBox = (int) heightAdjusted;
// 确保坐标在原始图像范围内
if (x < 0) x = 0;
if (y < 0) y = 0;
if (x + wBox > origWidth) wBox = origWidth - x;
if (y + hBox > origHeight) hBox = origHeight - y;
String label = "person"; // 由于只有一个类别
boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence));
}
}
// 非极大值抑制NMS
List<BoundingBox> nmsBoxes = nonMaximumSuppression(boxes, 0.5f);
System.out.println("耗时:"+((System.currentTimeMillis() - l)));
// 封装结果并返回
InferenceResult inferenceResult = new InferenceResult();
inferenceResult.setBoundingBoxes(nmsBoxes);
return inferenceResult;
} catch (OrtException e) {
throw new RuntimeException("推理失败", e);
}
}
// 计算两个边界框的 IoU
private float computeIoU(BoundingBox box1, BoundingBox box2) {
int x1 = Math.max(box1.getX(), box2.getX());
int y1 = Math.max(box1.getY(), box2.getY());
int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth());
int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight());
int intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
int box1Area = box1.getWidth() * box1.getHeight();
int box2Area = box2.getWidth() * box2.getHeight();
return (float) intersectionArea / (box1Area + box2Area - intersectionArea);
}
// 打印模型信息
private void logModelInfo(OrtSession session) {
System.out.println("模型输入信息:");
try {
for (Map.Entry<String, NodeInfo> entry : session.getInputInfo().entrySet()) {
String name = entry.getKey();
NodeInfo info = entry.getValue();
System.out.println("输入名称: " + name);
System.out.println("输入信息: " + info.toString());
}
} catch (OrtException e) {
throw new RuntimeException(e);
}
System.out.println("模型输出信息:");
try {
for (Map.Entry<String, NodeInfo> entry : session.getOutputInfo().entrySet()) {
String name = entry.getKey();
NodeInfo info = entry.getValue();
System.out.println("输出名称: " + name);
System.out.println("输出信息: " + info.toString());
}
} catch (OrtException e) {
throw new RuntimeException(e);
}
}
public static void main(String[] args) {
// 初始化标签列表
List<String> labels = Arrays.asList("person");
// 创建 InferenceEngine 实例
InferenceEngine_up inferenceEngine = new InferenceEngine_up("D:\\work\\work_space\\java\\onnx-inference4j-play\\src\\main\\resources\\model\\best.onnx", labels);
try {
// 加载图片
File imageFile = new File("C:\\Users\\ly\\Desktop\\resuouce\\image\\1.jpg");
BufferedImage inputImage = ImageIO.read(imageFile);
// 预处理图像
float[] inputData = inferenceEngine.preprocessImage(inputImage);
// 执行推理
InferenceResult result = null;
for (int i = 0; i < 100; i++) {
long l = System.currentTimeMillis();
result = inferenceEngine.infer(inputData, 640, 640);
System.out.println(System.currentTimeMillis() - l);
}
// 处理并显示结果
System.out.println("推理结果:");
for (BoundingBox box : result.getBoundingBoxes()) {
System.out.println(box);
}
// 可视化并保存带有边界框的图像
BufferedImage outputImage = inferenceEngine.drawBoundingBoxes(inputImage, result.getBoundingBoxes());
// 保存图片到本地文件
File outputFile = new File("output_image_with_boxes.jpg");
ImageIO.write(outputImage, "jpg", outputFile);
System.out.println("已保存结果图片: " + outputFile.getAbsolutePath());
} catch (Exception e) {
e.printStackTrace();
}
}
// 在图像上绘制边界框和标签
BufferedImage drawBoundingBoxes(BufferedImage image, List<BoundingBox> boxes) {
Graphics2D g = image.createGraphics();
g.setColor(Color.RED); // 设置绘制边界框的颜色
g.setStroke(new BasicStroke(2)); // 设置线条粗细
for (BoundingBox box : boxes) {
// 绘制矩形边界框
g.drawRect(box.getX(), box.getY(), box.getWidth(), box.getHeight());
// 绘制标签文字和置信度
String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence());
g.setFont(new Font("Arial", Font.PLAIN, 12));
g.drawString(label, box.getX(), box.getY() - 5);
}
g.dispose(); // 释放资源
return image;
}
// 图像预处理
float[] preprocessImage(BufferedImage image) {
int targetWidth = 640;
int targetHeight = 640;
origWidth = image.getWidth();
origHeight = image.getHeight();
// 计算缩放因子
scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
// 计算新的图像尺寸
newWidth = Math.round(origWidth * scalingFactor);
newHeight = Math.round(origHeight * scalingFactor);
// 计算偏移量以居中图像
xOffset = (targetWidth - newWidth) / 2;
yOffset = (targetHeight - newHeight) / 2;
// 创建一个新的BufferedImage
BufferedImage resizedImage = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_RGB);
Graphics2D g = resizedImage.createGraphics();
// 填充背景为黑色
g.setColor(Color.BLACK);
g.fillRect(0, 0, targetWidth, targetHeight);
// 绘制缩放后的图像到新的图像上
g.drawImage(image.getScaledInstance(newWidth, newHeight, Image.SCALE_SMOOTH), xOffset, yOffset, null);
g.dispose();
float[] inputData = new float[3 * targetWidth * targetHeight];
for (int c = 0; c < 3; c++) {
for (int y = 0; y < targetHeight; y++) {
for (int x = 0; x < targetWidth; x++) {
int rgb = resizedImage.getRGB(x, y);
float value = 0f;
if (c == 0) {
value = ((rgb >> 16) & 0xFF) / 255.0f; // Red channel
} else if (c == 1) {
value = ((rgb >> 8) & 0xFF) / 255.0f; // Green channel
} else if (c == 2) {
value = (rgb & 0xFF) / 255.0f; // Blue channel
}
inputData[c * targetWidth * targetHeight + y * targetWidth + x] = value;
}
}
}
return inputData;
}
// 非极大值抑制NMS方法
private List<BoundingBox> nonMaximumSuppression(List<BoundingBox> boxes, float iouThreshold) {
// 按置信度排序
boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence()));
List<BoundingBox> result = new ArrayList<>();
while (!boxes.isEmpty()) {
BoundingBox bestBox = boxes.remove(0);
result.add(bestBox);
Iterator<BoundingBox> iterator = boxes.iterator();
while (iterator.hasNext()) {
BoundingBox box = iterator.next();
if (computeIoU(bestBox, box) > iouThreshold) {
iterator.remove();
}
}
}
return result;
}
// 其他方法保持不变...
}

View File

@ -1,297 +0,0 @@
package com.ly.onnx.engine;
import ai.onnxruntime.*;
import com.ly.onnx.model.BoundingBox;
import com.ly.onnx.model.InferenceResult;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.nio.FloatBuffer;
import java.util.List;
import java.util.*;
public class InferenceEngine_up_1 {
OrtEnvironment environment = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions sessionOptions = null;
OrtSession session = null;
private String modelPath;
private List<String> labels;
// 添加用于存储图像预处理信息的类变量
private int origWidth;
private int origHeight;
private int newWidth;
private int newHeight;
private float scalingFactor;
private int xOffset;
private int yOffset;
public InferenceEngine_up_1(String modelPath, List<String> labels) {
this.modelPath = modelPath;
this.labels = labels;
init();
}
public void init() {
try {
sessionOptions = new OrtSession.SessionOptions();
sessionOptions.addCUDA(0);
session = environment.createSession(modelPath, sessionOptions);
} catch (OrtException e) {
throw new RuntimeException(e);
}
logModelInfo(session);
}
public InferenceResult infer(float[] inputData, int w, int h) {
// 创建ONNX输入Tensor
try {
Map<String, NodeInfo> inputInfo = session.getInputInfo();
String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入
long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状
OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape);
// 执行推理
OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor));
// 解析推理结果
String outputName = session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出
float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状[1, N, 5]
long l = System.currentTimeMillis();
// 设定置信度阈值
float confidenceThreshold = 0.5f; // 您可以根据需要调整
// 根据模型的输出结果解析边界框
List<BoundingBox> boxes = new ArrayList<>();
for (float[] data : outputData[0]) { // 遍历所有检测框
float confidence = data[4];
if (confidence >= confidenceThreshold) {
float xCenter = data[0];
float yCenter = data[1];
float widthBox = data[2];
float heightBox = data[3];
// 调整坐标减去偏移并除以缩放因子
float xCenterAdjusted = (xCenter - xOffset) / scalingFactor;
float yCenterAdjusted = (yCenter - yOffset) / scalingFactor;
float widthAdjusted = widthBox / scalingFactor;
float heightAdjusted = heightBox / scalingFactor;
// 计算左上角坐标
int x = (int) (xCenterAdjusted - widthAdjusted / 2);
int y = (int) (yCenterAdjusted - heightAdjusted / 2);
int wBox = (int) widthAdjusted;
int hBox = (int) heightAdjusted;
// 确保坐标在原始图像范围内
if (x < 0) x = 0;
if (y < 0) y = 0;
if (x + wBox > origWidth) wBox = origWidth - x;
if (y + hBox > origHeight) hBox = origHeight - y;
String label = "person"; // 由于只有一个类别
boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence));
}
}
// 非极大值抑制NMS
List<BoundingBox> nmsBoxes = nonMaximumSuppression(boxes, 0.5f);
System.out.println("耗时:"+((System.currentTimeMillis() - l)));
// 封装结果并返回
InferenceResult inferenceResult = new InferenceResult();
inferenceResult.setBoundingBoxes(nmsBoxes);
return inferenceResult;
} catch (OrtException e) {
throw new RuntimeException("推理失败", e);
}
}
// 计算两个边界框的 IoU
private float computeIoU(BoundingBox box1, BoundingBox box2) {
int x1 = Math.max(box1.getX(), box2.getX());
int y1 = Math.max(box1.getY(), box2.getY());
int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth());
int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight());
int intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
int box1Area = box1.getWidth() * box1.getHeight();
int box2Area = box2.getWidth() * box2.getHeight();
return (float) intersectionArea / (box1Area + box2Area - intersectionArea);
}
// 打印模型信息
private void logModelInfo(OrtSession session) {
System.out.println("模型输入信息:");
try {
for (Map.Entry<String, NodeInfo> entry : session.getInputInfo().entrySet()) {
String name = entry.getKey();
NodeInfo info = entry.getValue();
System.out.println("输入名称: " + name);
System.out.println("输入信息: " + info.toString());
}
} catch (OrtException e) {
throw new RuntimeException(e);
}
System.out.println("模型输出信息:");
try {
for (Map.Entry<String, NodeInfo> entry : session.getOutputInfo().entrySet()) {
String name = entry.getKey();
NodeInfo info = entry.getValue();
System.out.println("输出名称: " + name);
System.out.println("输出信息: " + info.toString());
}
} catch (OrtException e) {
throw new RuntimeException(e);
}
}
public static void main(String[] args) {
// 初始化标签列表
List<String> labels = Arrays.asList("person");
// 创建 InferenceEngine 实例
InferenceEngine_up_1 inferenceEngine = new InferenceEngine_up_1("D:\\work\\work_space\\java\\onnx-inference4j-play\\src\\main\\resources\\model\\best.onnx", labels);
try {
// 加载图片
File imageFile = new File("C:\\Users\\ly\\Desktop\\resuouce\\image\\1.jpg");
BufferedImage inputImage = ImageIO.read(imageFile);
// 预处理图像
float[] inputData = inferenceEngine.preprocessImage(inputImage);
// 执行推理
InferenceResult result = null;
for (int i = 0; i < 100; i++) {
long l = System.currentTimeMillis();
result = inferenceEngine.infer(inputData, 640, 640);
System.out.println(System.currentTimeMillis() - l);
}
// 处理并显示结果
System.out.println("推理结果:");
for (BoundingBox box : result.getBoundingBoxes()) {
System.out.println(box);
}
// 可视化并保存带有边界框的图像
BufferedImage outputImage = inferenceEngine.drawBoundingBoxes(inputImage, result.getBoundingBoxes());
// 保存图片到本地文件
File outputFile = new File("output_image_with_boxes.jpg");
ImageIO.write(outputImage, "jpg", outputFile);
System.out.println("已保存结果图片: " + outputFile.getAbsolutePath());
} catch (Exception e) {
e.printStackTrace();
}
}
// 在图像上绘制边界框和标签
private BufferedImage drawBoundingBoxes(BufferedImage image, List<BoundingBox> boxes) {
Graphics2D g = image.createGraphics();
g.setColor(Color.RED); // 设置绘制边界框的颜色
g.setStroke(new BasicStroke(2)); // 设置线条粗细
for (BoundingBox box : boxes) {
// 绘制矩形边界框
g.drawRect(box.getX(), box.getY(), box.getWidth(), box.getHeight());
// 绘制标签文字和置信度
String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence());
g.setFont(new Font("Arial", Font.PLAIN, 12));
g.drawString(label, box.getX(), box.getY() - 5);
}
g.dispose(); // 释放资源
return image;
}
// 图像预处理
private float[] preprocessImage(BufferedImage image) {
int targetWidth = 640;
int targetHeight = 640;
origWidth = image.getWidth();
origHeight = image.getHeight();
// 计算缩放因子
scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
// 计算新的图像尺寸
newWidth = Math.round(origWidth * scalingFactor);
newHeight = Math.round(origHeight * scalingFactor);
// 计算偏移量以居中图像
xOffset = (targetWidth - newWidth) / 2;
yOffset = (targetHeight - newHeight) / 2;
// 创建一个新的BufferedImage
BufferedImage resizedImage = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_RGB);
Graphics2D g = resizedImage.createGraphics();
// 填充背景为黑色
g.setColor(Color.BLACK);
g.fillRect(0, 0, targetWidth, targetHeight);
// 绘制缩放后的图像到新的图像上
g.drawImage(image.getScaledInstance(newWidth, newHeight, Image.SCALE_SMOOTH), xOffset, yOffset, null);
g.dispose();
float[] inputData = new float[3 * targetWidth * targetHeight];
for (int c = 0; c < 3; c++) {
for (int y = 0; y < targetHeight; y++) {
for (int x = 0; x < targetWidth; x++) {
int rgb = resizedImage.getRGB(x, y);
float value = 0f;
if (c == 0) {
value = ((rgb >> 16) & 0xFF) / 255.0f; // Red channel
} else if (c == 1) {
value = ((rgb >> 8) & 0xFF) / 255.0f; // Green channel
} else if (c == 2) {
value = (rgb & 0xFF) / 255.0f; // Blue channel
}
inputData[c * targetWidth * targetHeight + y * targetWidth + x] = value;
}
}
}
return inputData;
}
// 非极大值抑制NMS方法
private List<BoundingBox> nonMaximumSuppression(List<BoundingBox> boxes, float iouThreshold) {
// 按置信度排序
boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence()));
List<BoundingBox> result = new ArrayList<>();
while (!boxes.isEmpty()) {
BoundingBox bestBox = boxes.remove(0);
result.add(bestBox);
Iterator<BoundingBox> iterator = boxes.iterator();
while (iterator.hasNext()) {
BoundingBox box = iterator.next();
if (computeIoU(bestBox, box) > iouThreshold) {
iterator.remove();
}
}
}
return result;
}
// 其他方法保持不变...
}

View File

@ -1,314 +0,0 @@
package com.ly.onnx.engine;
import ai.onnxruntime.*;
import com.ly.onnx.model.BoundingBox;
import com.ly.onnx.model.InferenceResult;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.nio.FloatBuffer;
import java.util.List;
import java.util.*;
public class InferenceEngine_up_2 {
private OrtEnvironment environment;
private OrtSession.SessionOptions sessionOptions;
private OrtSession session; // session 作为类的成员变量
private String modelPath;
private List<String> labels;
// 添加用于存储图像预处理信息的类变量
private int origWidth;
private int origHeight;
private int newWidth;
private int newHeight;
private float scalingFactor;
private int xOffset;
private int yOffset;
public InferenceEngine_up_2(String modelPath, List<String> labels) {
this.modelPath = modelPath;
this.labels = labels;
init();
}
public void init() {
try {
environment = OrtEnvironment.getEnvironment();
sessionOptions = new OrtSession.SessionOptions();
sessionOptions.addCUDA(0); // 使用 GPU
session = environment.createSession(modelPath, sessionOptions);
logModelInfo(session);
} catch (OrtException e) {
throw new RuntimeException("模型加载失败", e);
}
}
public InferenceResult infer(float[] inputData, int w, int h) {
long startTime = System.currentTimeMillis();
try {
Map<String, NodeInfo> inputInfo = session.getInputInfo();
String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入
long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状
OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape);
// 执行推理
long inferenceStart = System.currentTimeMillis();
OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor));
long inferenceEnd = System.currentTimeMillis();
System.out.println("模型推理耗时:" + (inferenceEnd - inferenceStart) + " ms");
// 解析推理结果
String outputName = session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出
float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状[1, N, 5]
// 设定置信度阈值
float confidenceThreshold = 0.5f; // 您可以根据需要调整
// 根据模型的输出结果解析边界框
List<BoundingBox> boxes = new ArrayList<>();
for (float[] data : outputData[0]) { // 遍历所有检测框
float confidence = data[4];
if (confidence >= confidenceThreshold) {
float xCenter = data[0];
float yCenter = data[1];
float widthBox = data[2];
float heightBox = data[3];
// 调整坐标减去偏移并除以缩放因子
float xCenterAdjusted = (xCenter - xOffset) / scalingFactor;
float yCenterAdjusted = (yCenter - yOffset) / scalingFactor;
float widthAdjusted = widthBox / scalingFactor;
float heightAdjusted = heightBox / scalingFactor;
// 计算左上角坐标
int x = (int) (xCenterAdjusted - widthAdjusted / 2);
int y = (int) (yCenterAdjusted - heightAdjusted / 2);
int wBox = (int) widthAdjusted;
int hBox = (int) heightAdjusted;
// 确保坐标在原始图像范围内
if (x < 0) x = 0;
if (y < 0) y = 0;
if (x + wBox > origWidth) wBox = origWidth - x;
if (y + hBox > origHeight) hBox = origHeight - y;
String label = "person"; // 由于只有一个类别
boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence));
}
}
// 非极大值抑制NMS
long nmsStart = System.currentTimeMillis();
List<BoundingBox> nmsBoxes = nonMaximumSuppression(boxes, 0.5f);
long nmsEnd = System.currentTimeMillis();
System.out.println("NMS 耗时:" + (nmsEnd - nmsStart) + " ms");
// 封装结果并返回
InferenceResult inferenceResult = new InferenceResult();
inferenceResult.setBoundingBoxes(nmsBoxes);
long endTime = System.currentTimeMillis();
System.out.println("一次推理总耗时:" + (endTime - startTime) + " ms");
return inferenceResult;
} catch (OrtException e) {
throw new RuntimeException("推理失败", e);
}
}
// 计算两个边界框的 IoU
private float computeIoU(BoundingBox box1, BoundingBox box2) {
int x1 = Math.max(box1.getX(), box2.getX());
int y1 = Math.max(box1.getY(), box2.getY());
int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth());
int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight());
int intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
int box1Area = box1.getWidth() * box1.getHeight();
int box2Area = box2.getWidth() * box2.getHeight();
return (float) intersectionArea / (box1Area + box2Area - intersectionArea);
}
// 打印模型信息
private void logModelInfo(OrtSession session) {
System.out.println("模型输入信息:");
try {
for (Map.Entry<String, NodeInfo> entry : session.getInputInfo().entrySet()) {
String name = entry.getKey();
NodeInfo info = entry.getValue();
System.out.println("输入名称: " + name);
System.out.println("输入信息: " + info.toString());
}
} catch (OrtException e) {
throw new RuntimeException(e);
}
System.out.println("模型输出信息:");
try {
for (Map.Entry<String, NodeInfo> entry : session.getOutputInfo().entrySet()) {
String name = entry.getKey();
NodeInfo info = entry.getValue();
System.out.println("输出名称: " + name);
System.out.println("输出信息: " + info.toString());
}
} catch (OrtException e) {
throw new RuntimeException(e);
}
}
public static void main(String[] args) {
// 初始化标签列表
List<String> labels = Arrays.asList("person");
// 创建 InferenceEngine 实例
InferenceEngine_up_2 inferenceEngine = new InferenceEngine_up_2("D:\\work\\work_space\\java\\onnx-inference4j-play\\src\\main\\resources\\model\\best.onnx", labels);
try {
// 加载图片
File imageFile = new File("C:\\Users\\ly\\Desktop\\resuouce\\image\\1.jpg");
BufferedImage inputImage = ImageIO.read(imageFile);
// 预处理图像
long l1 = System.currentTimeMillis();
float[] inputData = inferenceEngine.preprocessImage(inputImage);
System.out.println(""+(System.currentTimeMillis() - l1));
// 执行推理
InferenceResult result = null;
for (int i = 0; i < 10; i++) {
long l = System.currentTimeMillis();
result = inferenceEngine.infer(inputData, 640, 640);
System.out.println("" + (i + 1) + " 次推理耗时:" + (System.currentTimeMillis() - l) + " ms");
}
// 处理并显示结果
System.out.println("推理结果:");
for (BoundingBox box : result.getBoundingBoxes()) {
System.out.println(box);
}
// 可视化并保存带有边界框的图像
BufferedImage outputImage = inferenceEngine.drawBoundingBoxes(inputImage, result.getBoundingBoxes());
// 保存图片到本地文件
File outputFile = new File("output_image_with_boxes.jpg");
ImageIO.write(outputImage, "jpg", outputFile);
System.out.println("已保存结果图片: " + outputFile.getAbsolutePath());
} catch (Exception e) {
e.printStackTrace();
}
}
// 在图像上绘制边界框和标签
private BufferedImage drawBoundingBoxes(BufferedImage image, List<BoundingBox> boxes) {
Graphics2D g = image.createGraphics();
g.setColor(Color.RED); // 设置绘制边界框的颜色
g.setStroke(new BasicStroke(2)); // 设置线条粗细
for (BoundingBox box : boxes) {
// 绘制矩形边界框
g.drawRect(box.getX(), box.getY(), box.getWidth(), box.getHeight());
// 绘制标签文字和置信度
String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence());
g.setFont(new Font("Arial", Font.PLAIN, 12));
g.drawString(label, box.getX(), box.getY() - 5);
}
g.dispose(); // 释放资源
return image;
}
// 图像预处理
public float[] preprocessImage(BufferedImage image) {
int targetWidth = 640;
int targetHeight = 640;
origWidth = image.getWidth();
origHeight = image.getHeight();
// 计算缩放因子
scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
// 计算新的图像尺寸
newWidth = Math.round(origWidth * scalingFactor);
newHeight = Math.round(origHeight * scalingFactor);
// 计算偏移量以居中图像
xOffset = (targetWidth - newWidth) / 2;
yOffset = (targetHeight - newHeight) / 2;
// 创建一个新的BufferedImage
BufferedImage resizedImage = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_RGB);
Graphics2D g = resizedImage.createGraphics();
// 填充背景为黑色
g.setColor(Color.BLACK);
g.fillRect(0, 0, targetWidth, targetHeight);
// 绘制缩放后的图像到新的图像上
g.drawImage(image.getScaledInstance(newWidth, newHeight, Image.SCALE_SMOOTH), xOffset, yOffset, null);
g.dispose();
float[] inputData = new float[3 * targetWidth * targetHeight];
// 开始计时
long preprocessStart = System.currentTimeMillis();
for (int c = 0; c < 3; c++) {
for (int y = 0; y < targetHeight; y++) {
for (int x = 0; x < targetWidth; x++) {
int rgb = resizedImage.getRGB(x, y);
float value = 0f;
if (c == 0) {
value = ((rgb >> 16) & 0xFF) / 255.0f; // Red channel
} else if (c == 1) {
value = ((rgb >> 8) & 0xFF) / 255.0f; // Green channel
} else if (c == 2) {
value = (rgb & 0xFF) / 255.0f; // Blue channel
}
inputData[c * targetWidth * targetHeight + y * targetWidth + x] = value;
}
}
}
// 结束计时
long preprocessEnd = System.currentTimeMillis();
System.out.println("图像预处理耗时:" + (preprocessEnd - preprocessStart) + " ms");
return inputData;
}
// 非极大值抑制NMS方法
private List<BoundingBox> nonMaximumSuppression(List<BoundingBox> boxes, float iouThreshold) {
// 按置信度排序从高到低
boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence()));
List<BoundingBox> result = new ArrayList<>();
while (!boxes.isEmpty()) {
BoundingBox bestBox = boxes.remove(0);
result.add(bestBox);
Iterator<BoundingBox> iterator = boxes.iterator();
while (iterator.hasNext()) {
BoundingBox box = iterator.next();
if (computeIoU(bestBox, box) > iouThreshold) {
iterator.remove();
}
}
}
return result;
}
}

View File

@ -10,6 +10,7 @@ public class BoundingBox {
private int height;
private String label;
private float confidence;
private long trackId;
// 构造函数getter setter 方法
@ -22,6 +23,5 @@ public class BoundingBox {
this.confidence = confidence;
}
// Getter Setter 方法
// ...
}

View File

@ -13,53 +13,164 @@ import java.util.List;
public class DrawImagesUtils {
// 使用HSL颜色生成更高级的颜色
public static Color hslToRgb(float hue, float saturation, float lightness) {
float c = (1 - Math.abs(2 * lightness - 1)) * saturation;
float x = c * (1 - Math.abs((hue / 60) % 2 - 1));
float m = lightness - c / 2;
float r = 0, g = 0, b = 0;
if (0 <= hue && hue < 60) {
r = c;
g = x;
} else if (60 <= hue && hue < 120) {
r = x;
g = c;
} else if (120 <= hue && hue < 180) {
g = c;
b = x;
} else if (180 <= hue && hue < 240) {
g = x;
b = c;
} else if (240 <= hue && hue < 300) {
r = x;
b = c;
} else if (300 <= hue && hue < 360) {
r = c;
b = x;
}
int rVal = (int) ((r + m) * 255);
int gVal = (int) ((g + m) * 255);
int bVal = (int) ((b + m) * 255);
return new Color(rVal, gVal, bVal);
}
// 根据模型索引生成颜色
private static Color generateColorForModel(int modelIndex, int totalModels) {
float hue = (360.0f / totalModels) * modelIndex; // 根据模型索引设置色相
return hslToRgb(hue, 0.7f, 0.5f); // 饱和度0.7亮度0.5
}
// BufferedImage 上绘制推理结果
public static void drawInferenceResult(BufferedImage bufferedImage, List<InferenceResult> inferenceResults) {
Graphics2D g2d = bufferedImage.createGraphics();
g2d.setFont(new Font("Arial", Font.PLAIN, 12));
g2d.setFont(new Font("Arial", Font.PLAIN, 24)); // 设置字体大小为24
int modelIndex = 0; // 模型索引
int totalModels = inferenceResults.size(); // 总模型数
for (InferenceResult result : inferenceResults) {
Color modelColor = generateColorForModel(modelIndex++, totalModels); // 为每个模型生成独立颜色
for (BoundingBox box : result.getBoundingBoxes()) {
// 绘制矩形
g2d.setColor(Color.RED);
// 绘制矩形框
g2d.setColor(modelColor);
g2d.setStroke(new BasicStroke(4)); // 设置线条粗细
g2d.drawRect(box.getX(), box.getY(), box.getWidth(), box.getHeight());
// 绘制标签
String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence());
// 获取字体度量
FontMetrics metrics = g2d.getFontMetrics();
int labelWidth = metrics.stringWidth(label);
int labelHeight = metrics.getHeight();
int labelHeight = metrics.getHeight() + 4; // 标签高度
String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence());
int labelWidth = metrics.stringWidth(label) + 10; // 标签宽度
// 确保文字不会超出图像
int y = Math.max(box.getY(), labelHeight);
String trackIdLabel = "TrackID: " + box.getTrackId();
int trackIdWidth = metrics.stringWidth(trackIdLabel) + 10; // TrackID标签宽度
int trackIdHeight = metrics.getHeight() + 4; // TrackID标签高度
// 绘制文字背景
g2d.setColor(Color.RED);
g2d.fillRect(box.getX(), y - labelHeight, labelWidth, labelHeight);
// 计算标签总高度
int totalLabelHeight = (box.getTrackId() > 0 ? trackIdHeight : 0) + labelHeight;
// 绘制文字
g2d.setColor(Color.WHITE);
g2d.drawString(label, box.getX(), y);
// 边距
int margin = 10;
// 检查上方是否有足够空间绘制标签
boolean canDrawAbove = box.getY() >= totalLabelHeight + margin;
if (canDrawAbove) {
// 在检测框上方绘制标签
int currentY = box.getY() - totalLabelHeight;
// 绘制 TrackID如果有
if (box.getTrackId() > 0) {
// 绘制 TrackID 背景
g2d.setColor(modelColor);
g2d.fillRect(box.getX(), currentY, trackIdWidth, trackIdHeight);
// 绘制 TrackID 文字
g2d.setColor(Color.BLACK);
g2d.drawString(trackIdLabel, box.getX() + 5, currentY + metrics.getAscent());
currentY += trackIdHeight;
}
// 绘制 classid 背景
g2d.setColor(modelColor);
g2d.fillRect(box.getX(), currentY, labelWidth, labelHeight);
// 绘制 classid 文字
g2d.setColor(Color.BLACK);
g2d.drawString(label, box.getX() + 5, currentY + metrics.getAscent());
} else {
// 如果上方空间不足则在检测框内部顶部绘制标签
int currentY = box.getY() + 5; // 内边距5
// 绘制半透明背景以提高可读性
int bgAlpha = 200; // 透明度0-255
Color backgroundColor = new Color(modelColor.getRed(), modelColor.getGreen(), modelColor.getBlue(), bgAlpha);
if (box.getTrackId() > 0) {
// 绘制 TrackID 背景
g2d.setColor(backgroundColor);
g2d.fillRect(box.getX(), currentY, trackIdWidth, trackIdHeight);
// 绘制 TrackID 文字
g2d.setColor(Color.BLACK);
g2d.drawString(trackIdLabel, box.getX() + 5, currentY + metrics.getAscent());
currentY += trackIdHeight;
}
// 绘制 classid 背景
g2d.setColor(backgroundColor);
g2d.fillRect(box.getX(), currentY, labelWidth, labelHeight);
// 绘制 classid 文字
g2d.setColor(Color.BLACK);
g2d.drawString(label, box.getX() + 5, currentY + metrics.getAscent());
}
}
}
g2d.dispose(); // 释放资源
}
// Mat 上绘制推理结果
// Mat 上绘制推理结果 (OpenCV 版本)
public static void drawInferenceResult(Mat image, List<InferenceResult> inferenceResults) {
int modelIndex = 0;
int totalModels = inferenceResults.size();
for (InferenceResult result : inferenceResults) {
Scalar modelColor = convertColorToScalar(generateColorForModel(modelIndex++, totalModels));
for (BoundingBox box : result.getBoundingBoxes()) {
// 绘制矩形
Point topLeft = new Point(box.getX(), box.getY());
Point bottomRight = new Point(box.getX() + box.getWidth(), box.getY() + box.getHeight());
Imgproc.rectangle(image, topLeft, bottomRight, new Scalar(0, 0, 255), 2); // 红色边框
Imgproc.rectangle(image, topLeft, bottomRight, modelColor, 3); // 加粗边框
// 绘制标签
String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence());
int font = Imgproc.FONT_HERSHEY_SIMPLEX;
double fontScale = 0.5;
int thickness = 1;
double fontScale = 0.7;
int thickness = 2;
// 计算文字大小
int[] baseLine = new int[1];
@ -71,12 +182,17 @@ public class DrawImagesUtils {
// 绘制文字背景
Imgproc.rectangle(image, new Point(topLeft.x, y - labelSize.height),
new Point(topLeft.x + labelSize.width, y + baseLine[0]),
new Scalar(0, 0, 255), Imgproc.FILLED);
modelColor, Imgproc.FILLED);
// 绘制文字
// 绘制黑色文字
Imgproc.putText(image, label, new Point(topLeft.x, y),
font, fontScale, new Scalar(255, 255, 255), thickness);
font, fontScale, new Scalar(0, 0, 0), thickness); // 黑色文字
}
}
}
// Color 转为 Scalar (用于 OpenCV)
private static Scalar convertColorToScalar(Color color) {
return new Scalar(color.getBlue(), color.getGreen(), color.getRed()); // OpenCV 中颜色顺序是 BGR
}
}

View File

@ -3,9 +3,12 @@ package com.ly.play.opencv;
import com.ly.layout.VideoPanel;
import com.ly.model_load.ModelManager;
import com.ly.onnx.engine.InferenceEngine;
import com.ly.onnx.model.BoundingBox;
import com.ly.onnx.model.InferenceResult;
import com.ly.onnx.utils.DrawImagesUtils;
import com.ly.track.SimpleTracker;
import org.opencv.core.*;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import org.opencv.videoio.VideoCapture;
import org.opencv.videoio.Videoio;
@ -37,9 +40,13 @@ public class VideoPlayer {
private Thread inferenceThread;
private VideoPanel videoPanel;
// 创建简单的跟踪器
SimpleTracker tracker = new SimpleTracker();
private long videoDuration = 0; // 毫秒
private long currentTimestamp = 0; // 毫秒
private boolean isTrackingEnabled;
private ModelManager modelManager;
private List<InferenceEngine> inferenceEngines = new ArrayList<>();
@ -178,8 +185,20 @@ public class VideoPlayer {
List<InferenceResult> inferenceResults = new ArrayList<>();
for (InferenceEngine inferenceEngine : inferenceEngines) {
// 假设 InferenceEngine infer 方法接受 float 数组
inferenceResults.add(inferenceEngine.infer(floatObjectMap));
InferenceResult infer = inferenceEngine.infer(floatObjectMap);
inferenceResults.add(infer);
}
// 合并所有模型的推理结果
List<BoundingBox> allBoundingBoxes = new ArrayList<>();
for (InferenceResult result : inferenceResults) {
allBoundingBoxes.addAll(result.getBoundingBoxes());
}
// 如果启用了目标跟踪则更新边界框并分配 trackId
if (isTrackingEnabled) {
tracker.update(allBoundingBoxes);
}
// 绘制推理结果
DrawImagesUtils.drawInferenceResult(bufferedImage, inferenceResults);
// 更新绘制后图像
@ -201,20 +220,22 @@ public class VideoPlayer {
isPaused = true;
}
// 设置是否启用目标跟踪
public void setTrackingEnabled(boolean enabled) {
this.isTrackingEnabled = enabled;
}
// 定义一个内部类来存储帧数据
private static class FrameData {
public BufferedImage image;
public Map<Integer, Object> floatObjectMap;
public FrameData(BufferedImage image, Map<Integer, Object> floatObjectMap) {
this.image = image;
this.floatObjectMap = floatObjectMap;
}
}
// 可选的预处理方法
public Map<Integer, Object> preprocessImage(Mat image) {
int origWidth = image.width();
@ -246,9 +267,8 @@ public class VideoPlayer {
inferenceEngine.setIndex(index.get());
}
continue;
}else {
} else {
index.getAndIncrement();
}
}
@ -330,6 +350,7 @@ public class VideoPlayer {
return dynamicInput;
}
public List<InferenceEngine> getInferenceEngines() {
return this.inferenceEngines;
}
@ -359,4 +380,45 @@ public class VideoPlayer {
this.inferenceEngines.add(inferenceEngine);
}
// 加载并处理图片
public void loadImage(String imagePath) throws Exception {
// 停止任何正在播放的视频
stopVideo();
// 读取图片
Mat image = Imgcodecs.imread(imagePath);
if (image.empty()) {
throw new Exception("无法读取图片文件:" + imagePath);
}
// 转换为 BufferedImage
BufferedImage bufferedImage = matToBufferedImage(image);
// 预处理图片
Map<Integer, Object> preprocessedData = preprocessImage(image);
// 执行推理
List<InferenceResult> inferenceResults = new ArrayList<>();
for (InferenceEngine inferenceEngine : inferenceEngines) {
InferenceResult infer = inferenceEngine.infer(preprocessedData);
inferenceResults.add(infer);
}
// 合并所有模型的推理结果
List<BoundingBox> allBoundingBoxes = new ArrayList<>();
for (InferenceResult result : inferenceResults) {
allBoundingBoxes.addAll(result.getBoundingBoxes());
}
// 如果启用了目标跟踪则更新边界框并分配 trackId
if (isTrackingEnabled) {
tracker.update(allBoundingBoxes);
}
// 绘制推理结果
DrawImagesUtils.drawInferenceResult(bufferedImage, inferenceResults);
// VideoPanel 上显示图片
videoPanel.updateImage(bufferedImage);
}
}

View File

@ -0,0 +1,73 @@
package com.ly.track;
import com.ly.onnx.model.BoundingBox;
import lombok.Data;
import java.awt.*;
import java.util.*;
import java.util.List;
public class SimpleTracker {
private Map<Long, TrackedObject> trackedObjects = new HashMap<>(); // 使用自定义 TrackedObject 来跟踪
private long currentTrackId = 0;
// 跟踪器更新方法
public List<BoundingBox> update(List<BoundingBox> detections) {
List<BoundingBox> updatedResults = new ArrayList<>();
for (BoundingBox detection : detections) {
boolean matched = false;
Point detectionCenter = getCenter(detection); // 获取当前检测目标的中心点
// 遍历现有的跟踪目标
for (Map.Entry<Long, TrackedObject> entry : trackedObjects.entrySet()) {
TrackedObject trackedObject = entry.getValue();
Point trackedCenter = getCenter(trackedObject.boundingBox);
// 使用中心点欧几里得距离进行匹配
double distance = euclideanDistance(detectionCenter, trackedCenter);
// 如果距离小于某个阈值认为是同一目标
if (distance < 50.0) { // 自定义距离阈值可以根据需要调整
detection.setTrackId(entry.getKey()); // 更新检测框的 trackId
trackedObject.update(detection); // 更新跟踪对象
matched = true;
break;
}
}
// 如果没有匹配到创建新的 trackId
if (!matched) {
long newTrackId = ++currentTrackId;
detection.setTrackId(newTrackId);
trackedObjects.put(newTrackId, new TrackedObject(detection));
}
updatedResults.add(detection);
}
// 清理丢失的目标
cleanupLostObjects();
return updatedResults;
}
// 计算目标的中心点
private Point getCenter(BoundingBox box) {
int centerX = box.getX() + box.getWidth() / 2;
int centerY = box.getY() + box.getHeight() / 2;
return new Point(centerX, centerY);
}
// 计算欧几里得距离
private double euclideanDistance(Point p1, Point p2) {
return Math.sqrt(Math.pow(p1.x - p2.x, 2) + Math.pow(p1.y - p2.y, 2));
}
// 清理丢失的跟踪对象例如不再检测到的对象
private void cleanupLostObjects() {
// 可以根据时间戳或其他条件来清理长时间没有更新的目标
trackedObjects.entrySet().removeIf(entry -> entry.getValue().isLost());
}
}

View File

@ -0,0 +1,23 @@
package com.ly.track;
import com.ly.onnx.model.BoundingBox;
public class TrackedObject {
public BoundingBox boundingBox;
private int lostFrames = 0; // 记录连续多少帧未检测到
public TrackedObject(BoundingBox initialBox) {
this.boundingBox = initialBox;
}
// 更新跟踪目标的位置
public void update(BoundingBox newBox) {
this.boundingBox = newBox;
lostFrames = 0; // 重置丢失计数
}
// 如果目标连续丢失多帧认为目标丢失
public boolean isLost() {
return lostFrames++ > 10; // 如果丢失超过10帧就认为目标丢失
}
}

View File

@ -1,28 +0,0 @@
package com.ly.utils;
import org.opencv.core.Core;
import org.opencv.core.Mat;
import org.opencv.videoio.VideoCapture;
public class OpenCVTest {
// static {
// nu.pattern.OpenCV.loadLocally();
// }
public static void main(String[] args) {
VideoCapture capture = new VideoCapture(0); // 打开默认摄像头
if (!capture.isOpened()) {
System.out.println("无法打开摄像头");
return;
}
Mat frame = new Mat();
if (capture.read(frame)) {
System.out.println("成功读取一帧图像");
} else {
System.out.println("无法读取图像");
}
capture.release();
}
}

Binary file not shown.