update
This commit is contained in:
parent
e2b80ce297
commit
02a3922c7b
17
pom.xml
17
pom.xml
|
@ -22,7 +22,7 @@
|
|||
<dependency>
|
||||
<groupId>com.microsoft.onnxruntime</groupId>
|
||||
<artifactId>onnxruntime_gpu</artifactId>
|
||||
<version>1.17.0</version>
|
||||
<version>1.16.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.bytedeco</groupId>
|
||||
|
@ -40,6 +40,21 @@
|
|||
<artifactId>ffmpeg-platform</artifactId>
|
||||
<version>5.0-1.5.7</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.openpnp</groupId>
|
||||
<artifactId>opencv</artifactId>
|
||||
<version>4.7.0-0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.projectlombok</groupId>
|
||||
<artifactId>lombok</artifactId>
|
||||
<version>1.18.32</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.alibaba</groupId>
|
||||
<artifactId>fastjson</artifactId>
|
||||
<version>1.2.83</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
|
||||
|
|
|
@ -3,13 +3,18 @@ package com.ly;
|
|||
import com.formdev.flatlaf.FlatLightLaf;
|
||||
import com.ly.layout.VideoPanel;
|
||||
import com.ly.model_load.ModelManager;
|
||||
import com.ly.play.VideoPlayer;
|
||||
import com.ly.onnx.engine.InferenceEngine;
|
||||
import com.ly.onnx.model.ModelInfo;
|
||||
import com.ly.play.opencv.VideoPlayer;
|
||||
|
||||
|
||||
import javax.swing.*;
|
||||
import javax.swing.filechooser.FileNameExtensionFilter;
|
||||
import javax.swing.filechooser.FileSystemView;
|
||||
import java.awt.*;
|
||||
import java.io.File;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
|
||||
public class VideoInferenceApp extends JFrame {
|
||||
|
||||
|
@ -43,13 +48,13 @@ public class VideoInferenceApp extends JFrame {
|
|||
videoPanel = new VideoPanel();
|
||||
videoPanel.setBackground(Color.BLACK);
|
||||
|
||||
// 初始化 VideoPlayer
|
||||
videoPlayer = new VideoPlayer(videoPanel);
|
||||
|
||||
// 模型列表区域
|
||||
modelManager = new ModelManager();
|
||||
modelManager.setPreferredSize(new Dimension(250, 0)); // 设置模型列表区域的宽度
|
||||
|
||||
// 初始化 VideoPlayer
|
||||
videoPlayer = new VideoPlayer(videoPanel, modelManager);
|
||||
|
||||
// 使用 JSplitPane 分割视频区域和模型列表区域
|
||||
JSplitPane splitPane = new JSplitPane(JSplitPane.HORIZONTAL_SPLIT, videoPanel, modelManager);
|
||||
splitPane.setResizeWeight(0.8); // 视频区域初始占据80%的空间
|
||||
|
@ -120,8 +125,16 @@ public class VideoInferenceApp extends JFrame {
|
|||
// 添加视频加载按钮的行为
|
||||
loadVideoButton.addActionListener(e -> selectVideoFile());
|
||||
|
||||
// 添加模型加载按钮的行为
|
||||
loadModelButton.addActionListener(e -> modelManager.loadModel(this));
|
||||
loadModelButton.addActionListener(e -> {
|
||||
modelManager.loadModel(this);
|
||||
DefaultListModel<ModelInfo> modelList = modelManager.getModelList();
|
||||
ArrayList<ModelInfo> models = Collections.list(modelList.elements());
|
||||
for (ModelInfo modelInfo : models) {
|
||||
if (modelInfo != null) {
|
||||
videoPlayer.addInferenceEngines(new InferenceEngine(modelInfo.getModelFilePath(), modelInfo.getLabels()));
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// 播放按钮
|
||||
playButton.addActionListener(e -> videoPlayer.playVideo());
|
||||
|
|
|
@ -1,42 +1,37 @@
|
|||
package com.ly.model_load;
|
||||
|
||||
|
||||
|
||||
import com.ly.file.FileEditor;
|
||||
import com.ly.onnx.model.ModelInfo;
|
||||
|
||||
import javax.swing.*;
|
||||
import javax.swing.filechooser.FileNameExtensionFilter;
|
||||
import javax.swing.filechooser.FileSystemView;
|
||||
import javax.swing.table.DefaultTableModel;
|
||||
import java.awt.*;
|
||||
import java.awt.event.MouseAdapter;
|
||||
import java.awt.event.MouseEvent;
|
||||
import java.io.File;
|
||||
|
||||
public class ModelManager extends JPanel {
|
||||
private DefaultListModel<String> modelListModel;
|
||||
private JList<String> modelList;
|
||||
private DefaultListModel<ModelInfo> modelListModel;
|
||||
private JList<ModelInfo> modelList;
|
||||
|
||||
public ModelManager() {
|
||||
setLayout(new BorderLayout());
|
||||
modelListModel = new DefaultListModel<>();
|
||||
modelList = new JList<>(modelListModel);
|
||||
modelList.setSelectionMode(ListSelectionModel.SINGLE_SELECTION); // 设置为单选
|
||||
JScrollPane modelScrollPane = new JScrollPane(modelList);
|
||||
add(modelScrollPane, BorderLayout.CENTER);
|
||||
|
||||
// 添加双击事件,编辑标签文件
|
||||
// 双击编辑标签文件
|
||||
modelList.addMouseListener(new MouseAdapter() {
|
||||
public void mouseClicked(MouseEvent e) {
|
||||
if (e.getClickCount() == 2) {
|
||||
int index = modelList.locationToIndex(e.getPoint());
|
||||
if (index >= 0) {
|
||||
String item = modelListModel.getElementAt(index);
|
||||
// 解析标签文件路径
|
||||
String[] parts = item.split("\n");
|
||||
if (parts.length >= 2) {
|
||||
String labelFilePath = parts[1].replace("标签文件: ", "").trim();
|
||||
FileEditor.openFileEditor(labelFilePath);
|
||||
}
|
||||
ModelInfo item = modelListModel.getElementAt(index);
|
||||
String labelFilePath = item.getLabelFilePath();
|
||||
FileEditor.openFileEditor(labelFilePath);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -45,12 +40,9 @@ public class ModelManager extends JPanel {
|
|||
|
||||
// 加载模型
|
||||
public void loadModel(JFrame parent) {
|
||||
// 获取桌面目录
|
||||
File desktopDir = FileSystemView.getFileSystemView().getHomeDirectory();
|
||||
JFileChooser fileChooser = new JFileChooser(desktopDir);
|
||||
fileChooser.setDialogTitle("选择模型文件");
|
||||
|
||||
// 设置模型文件过滤器,只显示 .onnx 文件
|
||||
FileNameExtensionFilter modelFilter = new FileNameExtensionFilter("ONNX模型文件 (*.onnx)", "onnx");
|
||||
fileChooser.setFileFilter(modelFilter);
|
||||
|
||||
|
@ -60,8 +52,6 @@ public class ModelManager extends JPanel {
|
|||
|
||||
// 选择对应的标签文件
|
||||
fileChooser.setDialogTitle("选择标签文件");
|
||||
|
||||
// 设置标签文件过滤器,只显示 .txt 文件
|
||||
FileNameExtensionFilter labelFilter = new FileNameExtensionFilter("标签文件 (*.txt)", "txt");
|
||||
fileChooser.setFileFilter(labelFilter);
|
||||
|
||||
|
@ -69,9 +59,9 @@ public class ModelManager extends JPanel {
|
|||
if (returnValue == JFileChooser.APPROVE_OPTION) {
|
||||
File labelFile = fileChooser.getSelectedFile();
|
||||
|
||||
// 将模型和标签文件添加到列表中
|
||||
String item = "模型文件: " + modelFile.getAbsolutePath() + "\n标签文件: " + labelFile.getAbsolutePath();
|
||||
modelListModel.addElement(item);
|
||||
// 添加模型信息到列表
|
||||
ModelInfo modelInfo = new ModelInfo(modelFile.getAbsolutePath(), labelFile.getAbsolutePath());
|
||||
modelListModel.addElement(modelInfo);
|
||||
} else {
|
||||
JOptionPane.showMessageDialog(parent, "未选择标签文件。", "提示", JOptionPane.WARNING_MESSAGE);
|
||||
}
|
||||
|
@ -79,4 +69,14 @@ public class ModelManager extends JPanel {
|
|||
JOptionPane.showMessageDialog(parent, "未选择模型文件。", "提示", JOptionPane.WARNING_MESSAGE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取选中的模型
|
||||
public ModelInfo getSelectedModel() {
|
||||
return modelList.getSelectedValue();
|
||||
}
|
||||
|
||||
// 如果需要在外部访问 modelList,可以添加以下方法
|
||||
public DefaultListModel<ModelInfo> getModelList() {
|
||||
return modelListModel;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
package com.ly.onnx;
|
||||
|
||||
import ai.onnxruntime.OrtEnvironment;
|
||||
import ai.onnxruntime.OrtSession;
|
||||
|
||||
public class OnnxModelInference {
|
||||
|
||||
private String modelFilePath;
|
||||
|
||||
private String labelFilePath;
|
||||
|
||||
private String[] labels;
|
||||
|
||||
OrtEnvironment environment = OrtEnvironment.getEnvironment();
|
||||
OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -0,0 +1,410 @@
|
|||
package com.ly.onnx.engine;
|
||||
|
||||
import ai.onnxruntime.*;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.ly.onnx.model.BoundingBox;
|
||||
import com.ly.onnx.model.InferenceResult;
|
||||
|
||||
import org.opencv.core.*;
|
||||
import org.opencv.imgcodecs.Imgcodecs;
|
||||
import org.opencv.imgproc.Imgproc;
|
||||
|
||||
import java.nio.FloatBuffer;
|
||||
import java.util.*;
|
||||
|
||||
public class InferenceEngine {
|
||||
|
||||
private OrtEnvironment environment;
|
||||
private OrtSession.SessionOptions sessionOptions;
|
||||
private OrtSession session;
|
||||
|
||||
private String modelPath;
|
||||
private List<String> labels;
|
||||
|
||||
// 用于存储图像预处理信息的类变量
|
||||
private int origWidth;
|
||||
private int origHeight;
|
||||
private int newWidth;
|
||||
private int newHeight;
|
||||
private float scalingFactor;
|
||||
private int xOffset;
|
||||
private int yOffset;
|
||||
|
||||
static {
|
||||
nu.pattern.OpenCV.loadLocally();
|
||||
}
|
||||
|
||||
public InferenceEngine(String modelPath, List<String> labels) {
|
||||
this.modelPath = modelPath;
|
||||
this.labels = labels;
|
||||
init();
|
||||
}
|
||||
|
||||
public void init() {
|
||||
try {
|
||||
environment = OrtEnvironment.getEnvironment();
|
||||
sessionOptions = new OrtSession.SessionOptions();
|
||||
sessionOptions.addCUDA(0); // 使用 GPU
|
||||
session = environment.createSession(modelPath, sessionOptions);
|
||||
logModelInfo(session);
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException("模型加载失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
public InferenceResult infer(float[] inputData, int w, int h, Map<String, Object> preprocessParams) {
|
||||
long startTime = System.currentTimeMillis();
|
||||
|
||||
// 从 Map 中获取偏移相关的变量
|
||||
int origWidth = (int) preprocessParams.get("origWidth");
|
||||
int origHeight = (int) preprocessParams.get("origHeight");
|
||||
float scalingFactor = (float) preprocessParams.get("scalingFactor");
|
||||
int xOffset = (int) preprocessParams.get("xOffset");
|
||||
int yOffset = (int) preprocessParams.get("yOffset");
|
||||
|
||||
try {
|
||||
Map<String, NodeInfo> inputInfo = session.getInputInfo();
|
||||
String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入
|
||||
|
||||
long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状
|
||||
|
||||
// 创建输入张量时,使用 CHW 格式的数据
|
||||
OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape);
|
||||
|
||||
// 执行推理
|
||||
long inferenceStart = System.currentTimeMillis();
|
||||
OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor));
|
||||
long inferenceEnd = System.currentTimeMillis();
|
||||
System.out.println("模型推理耗时:" + (inferenceEnd - inferenceStart) + " ms");
|
||||
|
||||
// 解析推理结果
|
||||
String outputName = session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出
|
||||
float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状:[1, N, 6]
|
||||
|
||||
// 设定置信度阈值
|
||||
float confidenceThreshold = 0.25f; // 您可以根据需要调整
|
||||
|
||||
// 根据模型的输出结果解析边界框
|
||||
List<BoundingBox> boxes = new ArrayList<>();
|
||||
for (float[] data : outputData[0]) { // 遍历所有检测框
|
||||
// 根据模型输出格式,解析中心坐标和宽高
|
||||
float x_center = data[0];
|
||||
float y_center = data[1];
|
||||
float width = data[2];
|
||||
float height = data[3];
|
||||
float confidence = data[4];
|
||||
|
||||
if (confidence >= confidenceThreshold) {
|
||||
// 将中心坐标转换为左上角和右下角坐标
|
||||
float x1 = x_center - width / 2;
|
||||
float y1 = y_center - height / 2;
|
||||
float x2 = x_center + width / 2;
|
||||
float y2 = y_center + height / 2;
|
||||
|
||||
// 调整坐标,减去偏移并除以缩放因子
|
||||
float x1Adjusted = (x1 - xOffset) / scalingFactor;
|
||||
float y1Adjusted = (y1 - yOffset) / scalingFactor;
|
||||
float x2Adjusted = (x2 - xOffset) / scalingFactor;
|
||||
float y2Adjusted = (y2 - yOffset) / scalingFactor;
|
||||
|
||||
// 确保坐标的正确顺序
|
||||
float xMinAdjusted = Math.min(x1Adjusted, x2Adjusted);
|
||||
float xMaxAdjusted = Math.max(x1Adjusted, x2Adjusted);
|
||||
float yMinAdjusted = Math.min(y1Adjusted, y2Adjusted);
|
||||
float yMaxAdjusted = Math.max(y1Adjusted, y2Adjusted);
|
||||
|
||||
// 确保坐标在原始图像范围内
|
||||
int x = (int) Math.max(0, xMinAdjusted);
|
||||
int y = (int) Math.max(0, yMinAdjusted);
|
||||
int xMax = (int) Math.min(origWidth, xMaxAdjusted);
|
||||
int yMax = (int) Math.min(origHeight, yMaxAdjusted);
|
||||
int wBox = xMax - x;
|
||||
int hBox = yMax - y;
|
||||
|
||||
// 仅当宽度和高度为正时,才添加边界框
|
||||
if (wBox > 0 && hBox > 0) {
|
||||
// 使用您的单一标签
|
||||
String label = labels.get(0);
|
||||
|
||||
boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 非极大值抑制(NMS)
|
||||
long nmsStart = System.currentTimeMillis();
|
||||
List<BoundingBox> nmsBoxes = nonMaximumSuppression(boxes, 0.5f);
|
||||
|
||||
System.out.println("检测到的标签:" + JSON.toJSONString(nmsBoxes));
|
||||
if (!nmsBoxes.isEmpty()) {
|
||||
for (BoundingBox box : nmsBoxes) {
|
||||
System.out.println(box);
|
||||
}
|
||||
}
|
||||
long nmsEnd = System.currentTimeMillis();
|
||||
System.out.println("NMS 耗时:" + (nmsEnd - nmsStart) + " ms");
|
||||
|
||||
// 封装结果并返回
|
||||
InferenceResult inferenceResult = new InferenceResult();
|
||||
inferenceResult.setBoundingBoxes(nmsBoxes);
|
||||
|
||||
long endTime = System.currentTimeMillis();
|
||||
System.out.println("一次推理总耗时:" + (endTime - startTime) + " ms");
|
||||
|
||||
return inferenceResult;
|
||||
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException("推理失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 计算两个边界框的 IoU
|
||||
private float computeIoU(BoundingBox box1, BoundingBox box2) {
|
||||
int x1 = Math.max(box1.getX(), box2.getX());
|
||||
int y1 = Math.max(box1.getY(), box2.getY());
|
||||
int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth());
|
||||
int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight());
|
||||
|
||||
int intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
|
||||
int box1Area = box1.getWidth() * box1.getHeight();
|
||||
int box2Area = box2.getWidth() * box2.getHeight();
|
||||
|
||||
return (float) intersectionArea / (box1Area + box2Area - intersectionArea);
|
||||
}
|
||||
|
||||
// 非极大值抑制(NMS)方法
|
||||
private List<BoundingBox> nonMaximumSuppression(List<BoundingBox> boxes, float iouThreshold) {
|
||||
// 按置信度排序(从高到低)
|
||||
boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence()));
|
||||
|
||||
List<BoundingBox> result = new ArrayList<>();
|
||||
|
||||
while (!boxes.isEmpty()) {
|
||||
BoundingBox bestBox = boxes.remove(0);
|
||||
result.add(bestBox);
|
||||
|
||||
Iterator<BoundingBox> iterator = boxes.iterator();
|
||||
while (iterator.hasNext()) {
|
||||
BoundingBox box = iterator.next();
|
||||
if (computeIoU(bestBox, box) > iouThreshold) {
|
||||
iterator.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// 打印模型信息
|
||||
private void logModelInfo(OrtSession session) {
|
||||
System.out.println("模型输入信息:");
|
||||
try {
|
||||
for (Map.Entry<String, NodeInfo> entry : session.getInputInfo().entrySet()) {
|
||||
String name = entry.getKey();
|
||||
NodeInfo info = entry.getValue();
|
||||
System.out.println("输入名称: " + name);
|
||||
System.out.println("输入信息: " + info.toString());
|
||||
}
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
System.out.println("模型输出信息:");
|
||||
try {
|
||||
for (Map.Entry<String, NodeInfo> entry : session.getOutputInfo().entrySet()) {
|
||||
String name = entry.getKey();
|
||||
NodeInfo info = entry.getValue();
|
||||
System.out.println("输出名称: " + name);
|
||||
System.out.println("输出信息: " + info.toString());
|
||||
}
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public static void main(String[] args) {
|
||||
// 加载 OpenCV 库
|
||||
|
||||
// 初始化标签列表(只有一个标签)
|
||||
List<String> labels = Arrays.asList("person");
|
||||
|
||||
// 创建 InferenceEngine 实例
|
||||
InferenceEngine inferenceEngine = new InferenceEngine("C:\\Users\\ly\\Desktop\\person.onnx", labels);
|
||||
|
||||
for (int j = 0; j < 10; j++) {
|
||||
try {
|
||||
// 加载图片
|
||||
Mat inputImage = Imgcodecs.imread("C:\\Users\\ly\\Desktop\\10230731212230.png");
|
||||
|
||||
// 预处理图像
|
||||
long l1 = System.currentTimeMillis();
|
||||
Map<String, Object> preprocessResult = inferenceEngine.preprocessImage(inputImage);
|
||||
float[] inputData = (float[]) preprocessResult.get("inputData");
|
||||
|
||||
InferenceResult result = null;
|
||||
for (int i = 0; i < 10; i++) {
|
||||
long l = System.currentTimeMillis();
|
||||
result = inferenceEngine.infer(inputData, 640, 640, preprocessResult);
|
||||
System.out.println("第 " + (i + 1) + " 次推理耗时:" + (System.currentTimeMillis() - l) + " ms");
|
||||
}
|
||||
|
||||
|
||||
// 处理并显示结果
|
||||
System.out.println("推理结果:");
|
||||
for (BoundingBox box : result.getBoundingBoxes()) {
|
||||
System.out.println(box);
|
||||
}
|
||||
|
||||
// 可视化并保存带有边界框的图像
|
||||
Mat outputImage = inferenceEngine.drawBoundingBoxes(inputImage, result.getBoundingBoxes());
|
||||
|
||||
// 保存图片到本地文件
|
||||
String outputFilePath = "output_image_with_boxes.jpg";
|
||||
Imgcodecs.imwrite(outputFilePath, outputImage);
|
||||
|
||||
System.out.println("已保存结果图片: " + outputFilePath);
|
||||
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 在图像上绘制边界框和标签
|
||||
private Mat drawBoundingBoxes(Mat image, List<BoundingBox> boxes) {
|
||||
for (BoundingBox box : boxes) {
|
||||
// 绘制矩形边界框
|
||||
Imgproc.rectangle(image, new Point(box.getX(), box.getY()),
|
||||
new Point(box.getX() + box.getWidth(), box.getY() + box.getHeight()),
|
||||
new Scalar(0, 0, 255), 2); // 红色边框
|
||||
|
||||
// 绘制标签文字和置信度
|
||||
String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence());
|
||||
int baseLine[] = new int[1];
|
||||
Size labelSize = Imgproc.getTextSize(label, Imgproc.FONT_HERSHEY_SIMPLEX, 0.5, 1, baseLine);
|
||||
int top = Math.max(box.getY(), (int) labelSize.height);
|
||||
Imgproc.putText(image, label, new Point(box.getX(), top),
|
||||
Imgproc.FONT_HERSHEY_SIMPLEX, 0.5, new Scalar(255, 255, 255), 1);
|
||||
}
|
||||
|
||||
return image;
|
||||
}
|
||||
|
||||
|
||||
public Map<String, Object> preprocessImage(Mat image) {
|
||||
int targetWidth = 640;
|
||||
int targetHeight = 640;
|
||||
|
||||
int origWidth = image.width();
|
||||
int origHeight = image.height();
|
||||
|
||||
// 计算缩放因子
|
||||
float scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
|
||||
|
||||
// 计算新的图像尺寸
|
||||
int newWidth = Math.round(origWidth * scalingFactor);
|
||||
int newHeight = Math.round(origHeight * scalingFactor);
|
||||
|
||||
// 计算偏移量以居中图像
|
||||
int xOffset = (targetWidth - newWidth) / 2;
|
||||
int yOffset = (targetHeight - newHeight) / 2;
|
||||
|
||||
// 调整图像尺寸
|
||||
Mat resizedImage = new Mat();
|
||||
Imgproc.resize(image, resizedImage, new Size(newWidth, newHeight), 0, 0, Imgproc.INTER_LINEAR);
|
||||
|
||||
// 转换为 RGB 并归一化
|
||||
Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB);
|
||||
resizedImage.convertTo(resizedImage, CvType.CV_32FC3, 1.0 / 255.0);
|
||||
|
||||
// 创建填充后的图像
|
||||
Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3);
|
||||
Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight);
|
||||
resizedImage.copyTo(paddedImage.submat(roi));
|
||||
|
||||
// 将图像数据转换为数组
|
||||
int imageSize = targetWidth * targetHeight;
|
||||
float[] chwData = new float[3 * imageSize];
|
||||
float[] hwcData = new float[3 * imageSize];
|
||||
paddedImage.get(0, 0, hwcData);
|
||||
|
||||
// 转换为 CHW 格式
|
||||
int channelSize = imageSize;
|
||||
for (int c = 0; c < 3; c++) {
|
||||
for (int i = 0; i < imageSize; i++) {
|
||||
chwData[c * channelSize + i] = hwcData[i * 3 + c];
|
||||
}
|
||||
}
|
||||
|
||||
// 释放图像资源
|
||||
resizedImage.release();
|
||||
paddedImage.release();
|
||||
|
||||
// 将预处理结果和偏移信息存入 Map
|
||||
Map<String, Object> result = new HashMap<>();
|
||||
result.put("inputData", chwData);
|
||||
result.put("origWidth", origWidth);
|
||||
result.put("origHeight", origHeight);
|
||||
result.put("scalingFactor", scalingFactor);
|
||||
result.put("xOffset", xOffset);
|
||||
result.put("yOffset", yOffset);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
// 图像预处理
|
||||
// public float[] preprocessImage(Mat image) {
|
||||
// int targetWidth = 640;
|
||||
// int targetHeight = 640;
|
||||
//
|
||||
// origWidth = image.width();
|
||||
// origHeight = image.height();
|
||||
//
|
||||
// // 计算缩放因子
|
||||
// scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
|
||||
//
|
||||
// // 计算新的图像尺寸
|
||||
// newWidth = Math.round(origWidth * scalingFactor);
|
||||
// newHeight = Math.round(origHeight * scalingFactor);
|
||||
//
|
||||
// // 计算偏移量以居中图像
|
||||
// xOffset = (targetWidth - newWidth) / 2;
|
||||
// yOffset = (targetHeight - newHeight) / 2;
|
||||
//
|
||||
// // 调整图像尺寸
|
||||
// Mat resizedImage = new Mat();
|
||||
// Imgproc.resize(image, resizedImage, new Size(newWidth, newHeight), 0, 0, Imgproc.INTER_LINEAR);
|
||||
//
|
||||
// // 转换为 RGB 并归一化
|
||||
// Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB);
|
||||
// resizedImage.convertTo(resizedImage, CvType.CV_32FC3, 1.0 / 255.0);
|
||||
//
|
||||
// // 创建填充后的图像
|
||||
// Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3);
|
||||
// Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight);
|
||||
// resizedImage.copyTo(paddedImage.submat(roi));
|
||||
//
|
||||
// // 将图像数据转换为数组
|
||||
// int imageSize = targetWidth * targetHeight;
|
||||
// float[] chwData = new float[3 * imageSize];
|
||||
// float[] hwcData = new float[3 * imageSize];
|
||||
// paddedImage.get(0, 0, hwcData);
|
||||
//
|
||||
// // 转换为 CHW 格式
|
||||
// int channelSize = imageSize;
|
||||
// for (int c = 0; c < 3; c++) {
|
||||
// for (int i = 0; i < imageSize; i++) {
|
||||
// chwData[c * channelSize + i] = hwcData[i * 3 + c];
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // 释放图像资源
|
||||
// resizedImage.release();
|
||||
// paddedImage.release();
|
||||
//
|
||||
// return chwData;
|
||||
// }
|
||||
|
||||
}
|
|
@ -0,0 +1,297 @@
|
|||
package com.ly.onnx.engine;
|
||||
|
||||
import ai.onnxruntime.*;
|
||||
import com.ly.onnx.model.BoundingBox;
|
||||
import com.ly.onnx.model.InferenceResult;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.*;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.File;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.util.List;
|
||||
import java.util.*;
|
||||
|
||||
public class InferenceEngine_up {
|
||||
|
||||
OrtEnvironment environment = OrtEnvironment.getEnvironment();
|
||||
OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
|
||||
|
||||
private String modelPath;
|
||||
private List<String> labels;
|
||||
|
||||
// 添加用于存储图像预处理信息的类变量
|
||||
private int origWidth;
|
||||
private int origHeight;
|
||||
private int newWidth;
|
||||
private int newHeight;
|
||||
private float scalingFactor;
|
||||
private int xOffset;
|
||||
private int yOffset;
|
||||
|
||||
public InferenceEngine_up(String modelPath, List<String> labels) {
|
||||
this.modelPath = modelPath;
|
||||
this.labels = labels;
|
||||
init();
|
||||
}
|
||||
|
||||
public void init() {
|
||||
OrtSession session = null;
|
||||
try {
|
||||
sessionOptions.addCUDA(0);
|
||||
session = environment.createSession(modelPath, sessionOptions);
|
||||
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
logModelInfo(session);
|
||||
}
|
||||
|
||||
public InferenceResult infer(float[] inputData, int w, int h) {
|
||||
// 创建ONNX输入Tensor
|
||||
try (OrtSession session = environment.createSession(modelPath, sessionOptions)) {
|
||||
Map<String, NodeInfo> inputInfo = session.getInputInfo();
|
||||
String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入
|
||||
|
||||
long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状
|
||||
OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape);
|
||||
|
||||
// 执行推理
|
||||
OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor));
|
||||
|
||||
// 解析推理结果
|
||||
String outputName = session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出
|
||||
float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状:[1, N, 5]
|
||||
|
||||
long l = System.currentTimeMillis();
|
||||
// 设定置信度阈值
|
||||
float confidenceThreshold = 0.5f; // 您可以根据需要调整
|
||||
|
||||
// 根据模型的输出结果解析边界框
|
||||
List<BoundingBox> boxes = new ArrayList<>();
|
||||
for (float[] data : outputData[0]) { // 遍历所有检测框
|
||||
float confidence = data[4];
|
||||
if (confidence >= confidenceThreshold) {
|
||||
float xCenter = data[0];
|
||||
float yCenter = data[1];
|
||||
float widthBox = data[2];
|
||||
float heightBox = data[3];
|
||||
|
||||
// 调整坐标,减去偏移并除以缩放因子
|
||||
float xCenterAdjusted = (xCenter - xOffset) / scalingFactor;
|
||||
float yCenterAdjusted = (yCenter - yOffset) / scalingFactor;
|
||||
float widthAdjusted = widthBox / scalingFactor;
|
||||
float heightAdjusted = heightBox / scalingFactor;
|
||||
|
||||
// 计算左上角坐标
|
||||
int x = (int) (xCenterAdjusted - widthAdjusted / 2);
|
||||
int y = (int) (yCenterAdjusted - heightAdjusted / 2);
|
||||
int wBox = (int) widthAdjusted;
|
||||
int hBox = (int) heightAdjusted;
|
||||
|
||||
// 确保坐标在原始图像范围内
|
||||
if (x < 0) x = 0;
|
||||
if (y < 0) y = 0;
|
||||
if (x + wBox > origWidth) wBox = origWidth - x;
|
||||
if (y + hBox > origHeight) hBox = origHeight - y;
|
||||
|
||||
String label = "person"; // 由于只有一个类别
|
||||
|
||||
boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence));
|
||||
}
|
||||
}
|
||||
|
||||
// 非极大值抑制(NMS)
|
||||
List<BoundingBox> nmsBoxes = nonMaximumSuppression(boxes, 0.5f);
|
||||
System.out.println("耗时:"+((System.currentTimeMillis() - l)));
|
||||
// 封装结果并返回
|
||||
InferenceResult inferenceResult = new InferenceResult();
|
||||
inferenceResult.setBoundingBoxes(nmsBoxes);
|
||||
return inferenceResult;
|
||||
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException("推理失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
// 计算两个边界框的 IoU
|
||||
private float computeIoU(BoundingBox box1, BoundingBox box2) {
|
||||
int x1 = Math.max(box1.getX(), box2.getX());
|
||||
int y1 = Math.max(box1.getY(), box2.getY());
|
||||
int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth());
|
||||
int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight());
|
||||
|
||||
int intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
|
||||
int box1Area = box1.getWidth() * box1.getHeight();
|
||||
int box2Area = box2.getWidth() * box2.getHeight();
|
||||
|
||||
return (float) intersectionArea / (box1Area + box2Area - intersectionArea);
|
||||
}
|
||||
|
||||
// 打印模型信息
|
||||
private void logModelInfo(OrtSession session) {
|
||||
System.out.println("模型输入信息:");
|
||||
try {
|
||||
for (Map.Entry<String, NodeInfo> entry : session.getInputInfo().entrySet()) {
|
||||
String name = entry.getKey();
|
||||
NodeInfo info = entry.getValue();
|
||||
System.out.println("输入名称: " + name);
|
||||
System.out.println("输入信息: " + info.toString());
|
||||
}
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
System.out.println("模型输出信息:");
|
||||
try {
|
||||
for (Map.Entry<String, NodeInfo> entry : session.getOutputInfo().entrySet()) {
|
||||
String name = entry.getKey();
|
||||
NodeInfo info = entry.getValue();
|
||||
System.out.println("输出名称: " + name);
|
||||
System.out.println("输出信息: " + info.toString());
|
||||
}
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public static void main(String[] args) {
|
||||
// 初始化标签列表
|
||||
List<String> labels = Arrays.asList("person");
|
||||
|
||||
// 创建 InferenceEngine 实例
|
||||
InferenceEngine_up inferenceEngine = new InferenceEngine_up("D:\\work\\work_space\\java\\onnx-inference4j-play\\src\\main\\resources\\model\\best.onnx", labels);
|
||||
|
||||
try {
|
||||
// 加载图片
|
||||
File imageFile = new File("C:\\Users\\ly\\Desktop\\resuouce\\image\\1.jpg");
|
||||
BufferedImage inputImage = ImageIO.read(imageFile);
|
||||
|
||||
// 预处理图像
|
||||
float[] inputData = inferenceEngine.preprocessImage(inputImage);
|
||||
|
||||
// 执行推理
|
||||
InferenceResult result = null;
|
||||
for (int i = 0; i < 100; i++) {
|
||||
long l = System.currentTimeMillis();
|
||||
result = inferenceEngine.infer(inputData, 640, 640);
|
||||
System.out.println(System.currentTimeMillis() - l);
|
||||
}
|
||||
|
||||
// 处理并显示结果
|
||||
System.out.println("推理结果:");
|
||||
for (BoundingBox box : result.getBoundingBoxes()) {
|
||||
System.out.println(box);
|
||||
}
|
||||
|
||||
// 可视化并保存带有边界框的图像
|
||||
BufferedImage outputImage = inferenceEngine.drawBoundingBoxes(inputImage, result.getBoundingBoxes());
|
||||
|
||||
// 保存图片到本地文件
|
||||
File outputFile = new File("output_image_with_boxes.jpg");
|
||||
ImageIO.write(outputImage, "jpg", outputFile);
|
||||
|
||||
System.out.println("已保存结果图片: " + outputFile.getAbsolutePath());
|
||||
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
// 在图像上绘制边界框和标签
|
||||
BufferedImage drawBoundingBoxes(BufferedImage image, List<BoundingBox> boxes) {
|
||||
Graphics2D g = image.createGraphics();
|
||||
g.setColor(Color.RED); // 设置绘制边界框的颜色
|
||||
g.setStroke(new BasicStroke(2)); // 设置线条粗细
|
||||
|
||||
for (BoundingBox box : boxes) {
|
||||
// 绘制矩形边界框
|
||||
g.drawRect(box.getX(), box.getY(), box.getWidth(), box.getHeight());
|
||||
// 绘制标签文字和置信度
|
||||
String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence());
|
||||
g.setFont(new Font("Arial", Font.PLAIN, 12));
|
||||
g.drawString(label, box.getX(), box.getY() - 5);
|
||||
}
|
||||
|
||||
g.dispose(); // 释放资源
|
||||
return image;
|
||||
}
|
||||
|
||||
// 图像预处理
|
||||
float[] preprocessImage(BufferedImage image) {
|
||||
int targetWidth = 640;
|
||||
int targetHeight = 640;
|
||||
|
||||
origWidth = image.getWidth();
|
||||
origHeight = image.getHeight();
|
||||
|
||||
// 计算缩放因子
|
||||
scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
|
||||
|
||||
// 计算新的图像尺寸
|
||||
newWidth = Math.round(origWidth * scalingFactor);
|
||||
newHeight = Math.round(origHeight * scalingFactor);
|
||||
|
||||
// 计算偏移量以居中图像
|
||||
xOffset = (targetWidth - newWidth) / 2;
|
||||
yOffset = (targetHeight - newHeight) / 2;
|
||||
|
||||
// 创建一个新的BufferedImage
|
||||
BufferedImage resizedImage = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_RGB);
|
||||
Graphics2D g = resizedImage.createGraphics();
|
||||
|
||||
// 填充背景为黑色
|
||||
g.setColor(Color.BLACK);
|
||||
g.fillRect(0, 0, targetWidth, targetHeight);
|
||||
|
||||
// 绘制缩放后的图像到新的图像上
|
||||
g.drawImage(image.getScaledInstance(newWidth, newHeight, Image.SCALE_SMOOTH), xOffset, yOffset, null);
|
||||
g.dispose();
|
||||
|
||||
float[] inputData = new float[3 * targetWidth * targetHeight];
|
||||
|
||||
for (int c = 0; c < 3; c++) {
|
||||
for (int y = 0; y < targetHeight; y++) {
|
||||
for (int x = 0; x < targetWidth; x++) {
|
||||
int rgb = resizedImage.getRGB(x, y);
|
||||
float value = 0f;
|
||||
if (c == 0) {
|
||||
value = ((rgb >> 16) & 0xFF) / 255.0f; // Red channel
|
||||
} else if (c == 1) {
|
||||
value = ((rgb >> 8) & 0xFF) / 255.0f; // Green channel
|
||||
} else if (c == 2) {
|
||||
value = (rgb & 0xFF) / 255.0f; // Blue channel
|
||||
}
|
||||
inputData[c * targetWidth * targetHeight + y * targetWidth + x] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return inputData;
|
||||
}
|
||||
|
||||
// 非极大值抑制(NMS)方法
|
||||
private List<BoundingBox> nonMaximumSuppression(List<BoundingBox> boxes, float iouThreshold) {
|
||||
// 按置信度排序
|
||||
boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence()));
|
||||
|
||||
List<BoundingBox> result = new ArrayList<>();
|
||||
|
||||
while (!boxes.isEmpty()) {
|
||||
BoundingBox bestBox = boxes.remove(0);
|
||||
result.add(bestBox);
|
||||
|
||||
Iterator<BoundingBox> iterator = boxes.iterator();
|
||||
while (iterator.hasNext()) {
|
||||
BoundingBox box = iterator.next();
|
||||
if (computeIoU(bestBox, box) > iouThreshold) {
|
||||
iterator.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// 其他方法保持不变...
|
||||
}
|
|
@ -0,0 +1,297 @@
|
|||
package com.ly.onnx.engine;
|
||||
|
||||
import ai.onnxruntime.*;
|
||||
import com.ly.onnx.model.BoundingBox;
|
||||
import com.ly.onnx.model.InferenceResult;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.*;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.File;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.util.List;
|
||||
import java.util.*;
|
||||
|
||||
public class InferenceEngine_up_1 {
|
||||
|
||||
OrtEnvironment environment = OrtEnvironment.getEnvironment();
|
||||
OrtSession.SessionOptions sessionOptions = null;
|
||||
OrtSession session = null;
|
||||
private String modelPath;
|
||||
private List<String> labels;
|
||||
|
||||
// 添加用于存储图像预处理信息的类变量
|
||||
private int origWidth;
|
||||
private int origHeight;
|
||||
private int newWidth;
|
||||
private int newHeight;
|
||||
private float scalingFactor;
|
||||
private int xOffset;
|
||||
private int yOffset;
|
||||
|
||||
public InferenceEngine_up_1(String modelPath, List<String> labels) {
|
||||
this.modelPath = modelPath;
|
||||
this.labels = labels;
|
||||
init();
|
||||
}
|
||||
|
||||
public void init() {
|
||||
|
||||
try {
|
||||
sessionOptions = new OrtSession.SessionOptions();
|
||||
sessionOptions.addCUDA(0);
|
||||
session = environment.createSession(modelPath, sessionOptions);
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
logModelInfo(session);
|
||||
}
|
||||
|
||||
public InferenceResult infer(float[] inputData, int w, int h) {
|
||||
// 创建ONNX输入Tensor
|
||||
try {
|
||||
Map<String, NodeInfo> inputInfo = session.getInputInfo();
|
||||
String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入
|
||||
|
||||
long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状
|
||||
OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape);
|
||||
|
||||
// 执行推理
|
||||
OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor));
|
||||
|
||||
// 解析推理结果
|
||||
String outputName = session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出
|
||||
float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状:[1, N, 5]
|
||||
|
||||
long l = System.currentTimeMillis();
|
||||
// 设定置信度阈值
|
||||
float confidenceThreshold = 0.5f; // 您可以根据需要调整
|
||||
|
||||
// 根据模型的输出结果解析边界框
|
||||
List<BoundingBox> boxes = new ArrayList<>();
|
||||
for (float[] data : outputData[0]) { // 遍历所有检测框
|
||||
float confidence = data[4];
|
||||
if (confidence >= confidenceThreshold) {
|
||||
float xCenter = data[0];
|
||||
float yCenter = data[1];
|
||||
float widthBox = data[2];
|
||||
float heightBox = data[3];
|
||||
|
||||
// 调整坐标,减去偏移并除以缩放因子
|
||||
float xCenterAdjusted = (xCenter - xOffset) / scalingFactor;
|
||||
float yCenterAdjusted = (yCenter - yOffset) / scalingFactor;
|
||||
float widthAdjusted = widthBox / scalingFactor;
|
||||
float heightAdjusted = heightBox / scalingFactor;
|
||||
|
||||
// 计算左上角坐标
|
||||
int x = (int) (xCenterAdjusted - widthAdjusted / 2);
|
||||
int y = (int) (yCenterAdjusted - heightAdjusted / 2);
|
||||
int wBox = (int) widthAdjusted;
|
||||
int hBox = (int) heightAdjusted;
|
||||
|
||||
// 确保坐标在原始图像范围内
|
||||
if (x < 0) x = 0;
|
||||
if (y < 0) y = 0;
|
||||
if (x + wBox > origWidth) wBox = origWidth - x;
|
||||
if (y + hBox > origHeight) hBox = origHeight - y;
|
||||
|
||||
String label = "person"; // 由于只有一个类别
|
||||
|
||||
boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence));
|
||||
}
|
||||
}
|
||||
|
||||
// 非极大值抑制(NMS)
|
||||
List<BoundingBox> nmsBoxes = nonMaximumSuppression(boxes, 0.5f);
|
||||
System.out.println("耗时:"+((System.currentTimeMillis() - l)));
|
||||
// 封装结果并返回
|
||||
InferenceResult inferenceResult = new InferenceResult();
|
||||
inferenceResult.setBoundingBoxes(nmsBoxes);
|
||||
return inferenceResult;
|
||||
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException("推理失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
// 计算两个边界框的 IoU
|
||||
private float computeIoU(BoundingBox box1, BoundingBox box2) {
|
||||
int x1 = Math.max(box1.getX(), box2.getX());
|
||||
int y1 = Math.max(box1.getY(), box2.getY());
|
||||
int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth());
|
||||
int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight());
|
||||
|
||||
int intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
|
||||
int box1Area = box1.getWidth() * box1.getHeight();
|
||||
int box2Area = box2.getWidth() * box2.getHeight();
|
||||
|
||||
return (float) intersectionArea / (box1Area + box2Area - intersectionArea);
|
||||
}
|
||||
|
||||
// 打印模型信息
|
||||
private void logModelInfo(OrtSession session) {
|
||||
System.out.println("模型输入信息:");
|
||||
try {
|
||||
for (Map.Entry<String, NodeInfo> entry : session.getInputInfo().entrySet()) {
|
||||
String name = entry.getKey();
|
||||
NodeInfo info = entry.getValue();
|
||||
System.out.println("输入名称: " + name);
|
||||
System.out.println("输入信息: " + info.toString());
|
||||
}
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
System.out.println("模型输出信息:");
|
||||
try {
|
||||
for (Map.Entry<String, NodeInfo> entry : session.getOutputInfo().entrySet()) {
|
||||
String name = entry.getKey();
|
||||
NodeInfo info = entry.getValue();
|
||||
System.out.println("输出名称: " + name);
|
||||
System.out.println("输出信息: " + info.toString());
|
||||
}
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public static void main(String[] args) {
|
||||
// 初始化标签列表
|
||||
List<String> labels = Arrays.asList("person");
|
||||
|
||||
// 创建 InferenceEngine 实例
|
||||
InferenceEngine_up_1 inferenceEngine = new InferenceEngine_up_1("D:\\work\\work_space\\java\\onnx-inference4j-play\\src\\main\\resources\\model\\best.onnx", labels);
|
||||
|
||||
try {
|
||||
// 加载图片
|
||||
File imageFile = new File("C:\\Users\\ly\\Desktop\\resuouce\\image\\1.jpg");
|
||||
BufferedImage inputImage = ImageIO.read(imageFile);
|
||||
|
||||
// 预处理图像
|
||||
float[] inputData = inferenceEngine.preprocessImage(inputImage);
|
||||
|
||||
// 执行推理
|
||||
InferenceResult result = null;
|
||||
for (int i = 0; i < 100; i++) {
|
||||
long l = System.currentTimeMillis();
|
||||
result = inferenceEngine.infer(inputData, 640, 640);
|
||||
System.out.println(System.currentTimeMillis() - l);
|
||||
}
|
||||
|
||||
// 处理并显示结果
|
||||
System.out.println("推理结果:");
|
||||
for (BoundingBox box : result.getBoundingBoxes()) {
|
||||
System.out.println(box);
|
||||
}
|
||||
|
||||
// 可视化并保存带有边界框的图像
|
||||
BufferedImage outputImage = inferenceEngine.drawBoundingBoxes(inputImage, result.getBoundingBoxes());
|
||||
|
||||
// 保存图片到本地文件
|
||||
File outputFile = new File("output_image_with_boxes.jpg");
|
||||
ImageIO.write(outputImage, "jpg", outputFile);
|
||||
|
||||
System.out.println("已保存结果图片: " + outputFile.getAbsolutePath());
|
||||
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
// 在图像上绘制边界框和标签
|
||||
private BufferedImage drawBoundingBoxes(BufferedImage image, List<BoundingBox> boxes) {
|
||||
Graphics2D g = image.createGraphics();
|
||||
g.setColor(Color.RED); // 设置绘制边界框的颜色
|
||||
g.setStroke(new BasicStroke(2)); // 设置线条粗细
|
||||
|
||||
for (BoundingBox box : boxes) {
|
||||
// 绘制矩形边界框
|
||||
g.drawRect(box.getX(), box.getY(), box.getWidth(), box.getHeight());
|
||||
// 绘制标签文字和置信度
|
||||
String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence());
|
||||
g.setFont(new Font("Arial", Font.PLAIN, 12));
|
||||
g.drawString(label, box.getX(), box.getY() - 5);
|
||||
}
|
||||
|
||||
g.dispose(); // 释放资源
|
||||
return image;
|
||||
}
|
||||
|
||||
// 图像预处理
|
||||
private float[] preprocessImage(BufferedImage image) {
|
||||
int targetWidth = 640;
|
||||
int targetHeight = 640;
|
||||
|
||||
origWidth = image.getWidth();
|
||||
origHeight = image.getHeight();
|
||||
|
||||
// 计算缩放因子
|
||||
scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
|
||||
|
||||
// 计算新的图像尺寸
|
||||
newWidth = Math.round(origWidth * scalingFactor);
|
||||
newHeight = Math.round(origHeight * scalingFactor);
|
||||
|
||||
// 计算偏移量以居中图像
|
||||
xOffset = (targetWidth - newWidth) / 2;
|
||||
yOffset = (targetHeight - newHeight) / 2;
|
||||
|
||||
// 创建一个新的BufferedImage
|
||||
BufferedImage resizedImage = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_RGB);
|
||||
Graphics2D g = resizedImage.createGraphics();
|
||||
|
||||
// 填充背景为黑色
|
||||
g.setColor(Color.BLACK);
|
||||
g.fillRect(0, 0, targetWidth, targetHeight);
|
||||
|
||||
// 绘制缩放后的图像到新的图像上
|
||||
g.drawImage(image.getScaledInstance(newWidth, newHeight, Image.SCALE_SMOOTH), xOffset, yOffset, null);
|
||||
g.dispose();
|
||||
|
||||
float[] inputData = new float[3 * targetWidth * targetHeight];
|
||||
|
||||
for (int c = 0; c < 3; c++) {
|
||||
for (int y = 0; y < targetHeight; y++) {
|
||||
for (int x = 0; x < targetWidth; x++) {
|
||||
int rgb = resizedImage.getRGB(x, y);
|
||||
float value = 0f;
|
||||
if (c == 0) {
|
||||
value = ((rgb >> 16) & 0xFF) / 255.0f; // Red channel
|
||||
} else if (c == 1) {
|
||||
value = ((rgb >> 8) & 0xFF) / 255.0f; // Green channel
|
||||
} else if (c == 2) {
|
||||
value = (rgb & 0xFF) / 255.0f; // Blue channel
|
||||
}
|
||||
inputData[c * targetWidth * targetHeight + y * targetWidth + x] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return inputData;
|
||||
}
|
||||
|
||||
// 非极大值抑制(NMS)方法
|
||||
private List<BoundingBox> nonMaximumSuppression(List<BoundingBox> boxes, float iouThreshold) {
|
||||
// 按置信度排序
|
||||
boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence()));
|
||||
|
||||
List<BoundingBox> result = new ArrayList<>();
|
||||
|
||||
while (!boxes.isEmpty()) {
|
||||
BoundingBox bestBox = boxes.remove(0);
|
||||
result.add(bestBox);
|
||||
|
||||
Iterator<BoundingBox> iterator = boxes.iterator();
|
||||
while (iterator.hasNext()) {
|
||||
BoundingBox box = iterator.next();
|
||||
if (computeIoU(bestBox, box) > iouThreshold) {
|
||||
iterator.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// 其他方法保持不变...
|
||||
}
|
|
@ -0,0 +1,314 @@
|
|||
package com.ly.onnx.engine;
|
||||
|
||||
import ai.onnxruntime.*;
|
||||
import com.ly.onnx.model.BoundingBox;
|
||||
import com.ly.onnx.model.InferenceResult;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.*;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.File;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.util.List;
|
||||
import java.util.*;
|
||||
|
||||
public class InferenceEngine_up_2 {
|
||||
|
||||
private OrtEnvironment environment;
|
||||
private OrtSession.SessionOptions sessionOptions;
|
||||
private OrtSession session; // 将 session 作为类的成员变量
|
||||
|
||||
private String modelPath;
|
||||
private List<String> labels;
|
||||
|
||||
// 添加用于存储图像预处理信息的类变量
|
||||
private int origWidth;
|
||||
private int origHeight;
|
||||
private int newWidth;
|
||||
private int newHeight;
|
||||
private float scalingFactor;
|
||||
private int xOffset;
|
||||
private int yOffset;
|
||||
|
||||
public InferenceEngine_up_2(String modelPath, List<String> labels) {
|
||||
this.modelPath = modelPath;
|
||||
this.labels = labels;
|
||||
init();
|
||||
}
|
||||
|
||||
public void init() {
|
||||
try {
|
||||
environment = OrtEnvironment.getEnvironment();
|
||||
sessionOptions = new OrtSession.SessionOptions();
|
||||
sessionOptions.addCUDA(0); // 使用 GPU
|
||||
session = environment.createSession(modelPath, sessionOptions);
|
||||
logModelInfo(session);
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException("模型加载失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
public InferenceResult infer(float[] inputData, int w, int h) {
|
||||
long startTime = System.currentTimeMillis();
|
||||
|
||||
try {
|
||||
Map<String, NodeInfo> inputInfo = session.getInputInfo();
|
||||
String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入
|
||||
|
||||
long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状
|
||||
OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape);
|
||||
|
||||
// 执行推理
|
||||
long inferenceStart = System.currentTimeMillis();
|
||||
OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor));
|
||||
long inferenceEnd = System.currentTimeMillis();
|
||||
System.out.println("模型推理耗时:" + (inferenceEnd - inferenceStart) + " ms");
|
||||
|
||||
// 解析推理结果
|
||||
String outputName = session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出
|
||||
float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状:[1, N, 5]
|
||||
|
||||
// 设定置信度阈值
|
||||
float confidenceThreshold = 0.5f; // 您可以根据需要调整
|
||||
|
||||
// 根据模型的输出结果解析边界框
|
||||
List<BoundingBox> boxes = new ArrayList<>();
|
||||
for (float[] data : outputData[0]) { // 遍历所有检测框
|
||||
float confidence = data[4];
|
||||
if (confidence >= confidenceThreshold) {
|
||||
float xCenter = data[0];
|
||||
float yCenter = data[1];
|
||||
float widthBox = data[2];
|
||||
float heightBox = data[3];
|
||||
|
||||
// 调整坐标,减去偏移并除以缩放因子
|
||||
float xCenterAdjusted = (xCenter - xOffset) / scalingFactor;
|
||||
float yCenterAdjusted = (yCenter - yOffset) / scalingFactor;
|
||||
float widthAdjusted = widthBox / scalingFactor;
|
||||
float heightAdjusted = heightBox / scalingFactor;
|
||||
|
||||
// 计算左上角坐标
|
||||
int x = (int) (xCenterAdjusted - widthAdjusted / 2);
|
||||
int y = (int) (yCenterAdjusted - heightAdjusted / 2);
|
||||
int wBox = (int) widthAdjusted;
|
||||
int hBox = (int) heightAdjusted;
|
||||
|
||||
// 确保坐标在原始图像范围内
|
||||
if (x < 0) x = 0;
|
||||
if (y < 0) y = 0;
|
||||
if (x + wBox > origWidth) wBox = origWidth - x;
|
||||
if (y + hBox > origHeight) hBox = origHeight - y;
|
||||
|
||||
String label = "person"; // 由于只有一个类别
|
||||
|
||||
boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence));
|
||||
}
|
||||
}
|
||||
|
||||
// 非极大值抑制(NMS)
|
||||
long nmsStart = System.currentTimeMillis();
|
||||
List<BoundingBox> nmsBoxes = nonMaximumSuppression(boxes, 0.5f);
|
||||
long nmsEnd = System.currentTimeMillis();
|
||||
System.out.println("NMS 耗时:" + (nmsEnd - nmsStart) + " ms");
|
||||
|
||||
// 封装结果并返回
|
||||
InferenceResult inferenceResult = new InferenceResult();
|
||||
inferenceResult.setBoundingBoxes(nmsBoxes);
|
||||
|
||||
long endTime = System.currentTimeMillis();
|
||||
System.out.println("一次推理总耗时:" + (endTime - startTime) + " ms");
|
||||
|
||||
return inferenceResult;
|
||||
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException("推理失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
// 计算两个边界框的 IoU
|
||||
private float computeIoU(BoundingBox box1, BoundingBox box2) {
|
||||
int x1 = Math.max(box1.getX(), box2.getX());
|
||||
int y1 = Math.max(box1.getY(), box2.getY());
|
||||
int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth());
|
||||
int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight());
|
||||
|
||||
int intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
|
||||
int box1Area = box1.getWidth() * box1.getHeight();
|
||||
int box2Area = box2.getWidth() * box2.getHeight();
|
||||
|
||||
return (float) intersectionArea / (box1Area + box2Area - intersectionArea);
|
||||
}
|
||||
|
||||
// 打印模型信息
|
||||
private void logModelInfo(OrtSession session) {
|
||||
System.out.println("模型输入信息:");
|
||||
try {
|
||||
for (Map.Entry<String, NodeInfo> entry : session.getInputInfo().entrySet()) {
|
||||
String name = entry.getKey();
|
||||
NodeInfo info = entry.getValue();
|
||||
System.out.println("输入名称: " + name);
|
||||
System.out.println("输入信息: " + info.toString());
|
||||
}
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
System.out.println("模型输出信息:");
|
||||
try {
|
||||
for (Map.Entry<String, NodeInfo> entry : session.getOutputInfo().entrySet()) {
|
||||
String name = entry.getKey();
|
||||
NodeInfo info = entry.getValue();
|
||||
System.out.println("输出名称: " + name);
|
||||
System.out.println("输出信息: " + info.toString());
|
||||
}
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public static void main(String[] args) {
|
||||
// 初始化标签列表
|
||||
List<String> labels = Arrays.asList("person");
|
||||
|
||||
// 创建 InferenceEngine 实例
|
||||
InferenceEngine_up_2 inferenceEngine = new InferenceEngine_up_2("D:\\work\\work_space\\java\\onnx-inference4j-play\\src\\main\\resources\\model\\best.onnx", labels);
|
||||
|
||||
try {
|
||||
// 加载图片
|
||||
File imageFile = new File("C:\\Users\\ly\\Desktop\\resuouce\\image\\1.jpg");
|
||||
BufferedImage inputImage = ImageIO.read(imageFile);
|
||||
|
||||
// 预处理图像
|
||||
long l1 = System.currentTimeMillis();
|
||||
float[] inputData = inferenceEngine.preprocessImage(inputImage);
|
||||
System.out.println("转"+(System.currentTimeMillis() - l1));
|
||||
// 执行推理
|
||||
InferenceResult result = null;
|
||||
for (int i = 0; i < 10; i++) {
|
||||
long l = System.currentTimeMillis();
|
||||
result = inferenceEngine.infer(inputData, 640, 640);
|
||||
System.out.println("第 " + (i + 1) + " 次推理耗时:" + (System.currentTimeMillis() - l) + " ms");
|
||||
}
|
||||
|
||||
// 处理并显示结果
|
||||
System.out.println("推理结果:");
|
||||
for (BoundingBox box : result.getBoundingBoxes()) {
|
||||
System.out.println(box);
|
||||
}
|
||||
|
||||
// 可视化并保存带有边界框的图像
|
||||
BufferedImage outputImage = inferenceEngine.drawBoundingBoxes(inputImage, result.getBoundingBoxes());
|
||||
|
||||
// 保存图片到本地文件
|
||||
File outputFile = new File("output_image_with_boxes.jpg");
|
||||
ImageIO.write(outputImage, "jpg", outputFile);
|
||||
|
||||
System.out.println("已保存结果图片: " + outputFile.getAbsolutePath());
|
||||
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
// 在图像上绘制边界框和标签
|
||||
private BufferedImage drawBoundingBoxes(BufferedImage image, List<BoundingBox> boxes) {
|
||||
Graphics2D g = image.createGraphics();
|
||||
g.setColor(Color.RED); // 设置绘制边界框的颜色
|
||||
g.setStroke(new BasicStroke(2)); // 设置线条粗细
|
||||
|
||||
for (BoundingBox box : boxes) {
|
||||
// 绘制矩形边界框
|
||||
g.drawRect(box.getX(), box.getY(), box.getWidth(), box.getHeight());
|
||||
// 绘制标签文字和置信度
|
||||
String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence());
|
||||
g.setFont(new Font("Arial", Font.PLAIN, 12));
|
||||
g.drawString(label, box.getX(), box.getY() - 5);
|
||||
}
|
||||
|
||||
g.dispose(); // 释放资源
|
||||
return image;
|
||||
}
|
||||
|
||||
// 图像预处理
|
||||
public float[] preprocessImage(BufferedImage image) {
|
||||
int targetWidth = 640;
|
||||
int targetHeight = 640;
|
||||
|
||||
origWidth = image.getWidth();
|
||||
origHeight = image.getHeight();
|
||||
|
||||
// 计算缩放因子
|
||||
scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
|
||||
|
||||
// 计算新的图像尺寸
|
||||
newWidth = Math.round(origWidth * scalingFactor);
|
||||
newHeight = Math.round(origHeight * scalingFactor);
|
||||
|
||||
// 计算偏移量以居中图像
|
||||
xOffset = (targetWidth - newWidth) / 2;
|
||||
yOffset = (targetHeight - newHeight) / 2;
|
||||
|
||||
// 创建一个新的BufferedImage
|
||||
BufferedImage resizedImage = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_RGB);
|
||||
Graphics2D g = resizedImage.createGraphics();
|
||||
|
||||
// 填充背景为黑色
|
||||
g.setColor(Color.BLACK);
|
||||
g.fillRect(0, 0, targetWidth, targetHeight);
|
||||
|
||||
// 绘制缩放后的图像到新的图像上
|
||||
g.drawImage(image.getScaledInstance(newWidth, newHeight, Image.SCALE_SMOOTH), xOffset, yOffset, null);
|
||||
g.dispose();
|
||||
|
||||
float[] inputData = new float[3 * targetWidth * targetHeight];
|
||||
|
||||
// 开始计时
|
||||
long preprocessStart = System.currentTimeMillis();
|
||||
|
||||
for (int c = 0; c < 3; c++) {
|
||||
for (int y = 0; y < targetHeight; y++) {
|
||||
for (int x = 0; x < targetWidth; x++) {
|
||||
int rgb = resizedImage.getRGB(x, y);
|
||||
float value = 0f;
|
||||
if (c == 0) {
|
||||
value = ((rgb >> 16) & 0xFF) / 255.0f; // Red channel
|
||||
} else if (c == 1) {
|
||||
value = ((rgb >> 8) & 0xFF) / 255.0f; // Green channel
|
||||
} else if (c == 2) {
|
||||
value = (rgb & 0xFF) / 255.0f; // Blue channel
|
||||
}
|
||||
inputData[c * targetWidth * targetHeight + y * targetWidth + x] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 结束计时
|
||||
long preprocessEnd = System.currentTimeMillis();
|
||||
System.out.println("图像预处理耗时:" + (preprocessEnd - preprocessStart) + " ms");
|
||||
|
||||
return inputData;
|
||||
}
|
||||
|
||||
// 非极大值抑制(NMS)方法
|
||||
private List<BoundingBox> nonMaximumSuppression(List<BoundingBox> boxes, float iouThreshold) {
|
||||
// 按置信度排序(从高到低)
|
||||
boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence()));
|
||||
|
||||
List<BoundingBox> result = new ArrayList<>();
|
||||
|
||||
while (!boxes.isEmpty()) {
|
||||
BoundingBox bestBox = boxes.remove(0);
|
||||
result.add(bestBox);
|
||||
|
||||
Iterator<BoundingBox> iterator = boxes.iterator();
|
||||
while (iterator.hasNext()) {
|
||||
BoundingBox box = iterator.next();
|
||||
if (computeIoU(bestBox, box) > iouThreshold) {
|
||||
iterator.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
package com.ly.onnx.model;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class BoundingBox {
|
||||
private int x;
|
||||
private int y;
|
||||
private int width;
|
||||
private int height;
|
||||
private String label;
|
||||
private float confidence;
|
||||
|
||||
// 构造函数、getter 和 setter 方法
|
||||
|
||||
public BoundingBox(int x, int y, int width, int height, String label, float confidence) {
|
||||
this.x = x;
|
||||
this.y = y;
|
||||
this.width = width;
|
||||
this.height = height;
|
||||
this.label = label;
|
||||
this.confidence = confidence;
|
||||
}
|
||||
|
||||
// Getter 和 Setter 方法
|
||||
// ...
|
||||
}
|
|
@ -0,0 +1,18 @@
|
|||
package com.ly.onnx.model;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class InferenceResult {
|
||||
private List<BoundingBox> boundingBoxes = new ArrayList<>();
|
||||
|
||||
public List<BoundingBox> getBoundingBoxes() {
|
||||
return boundingBoxes;
|
||||
}
|
||||
|
||||
public void setBoundingBoxes(List<BoundingBox> boundingBoxes) {
|
||||
this.boundingBoxes = boundingBoxes;
|
||||
}
|
||||
|
||||
// 其他需要的属性和方法
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
package com.ly.onnx.model;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.List;
|
||||
|
||||
public class ModelInfo {
|
||||
private String modelFilePath;
|
||||
private String labelFilePath;
|
||||
private List <String> labels;
|
||||
|
||||
public ModelInfo(String modelFilePath, String labelFilePath) {
|
||||
this.modelFilePath = modelFilePath;
|
||||
this.labelFilePath = labelFilePath;
|
||||
try {
|
||||
this.labels = Files.readAllLines(Paths.get(labelFilePath));
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public String getModelFilePath() {
|
||||
return modelFilePath;
|
||||
}
|
||||
|
||||
public String getLabelFilePath() {
|
||||
return labelFilePath;
|
||||
}
|
||||
|
||||
public List<String> getLabels() {
|
||||
return labels;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "模型文件: " + modelFilePath + "\n标签文件: " + labelFilePath;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
package com.ly.onnx.utils;
|
||||
|
||||
import com.ly.onnx.model.BoundingBox;
|
||||
import com.ly.onnx.model.InferenceResult;
|
||||
import org.opencv.core.Mat;
|
||||
import org.opencv.core.Point;
|
||||
import org.opencv.core.Scalar;
|
||||
import org.opencv.imgproc.Imgproc;
|
||||
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.util.List;
|
||||
|
||||
public class DrawImagesUtils {
|
||||
|
||||
|
||||
public static void drawInferenceResult(BufferedImage bufferedImage, List<InferenceResult> result) {
|
||||
|
||||
}
|
||||
|
||||
// 在 Mat 上绘制推理结果
|
||||
public static void drawInferenceResult(Mat image, List<InferenceResult> inferenceResults) {
|
||||
for (InferenceResult result : inferenceResults) {
|
||||
for (BoundingBox box : result.getBoundingBoxes()) {
|
||||
// 绘制矩形
|
||||
Point topLeft = new Point(box.getX(), box.getY());
|
||||
Point bottomRight = new Point(box.getX() + box.getWidth(), box.getY() + box.getHeight());
|
||||
Imgproc.rectangle(image, topLeft, bottomRight, new Scalar(0, 0, 255), 2); // 红色边框
|
||||
|
||||
// 绘制标签
|
||||
String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence());
|
||||
int font = Imgproc.FONT_HERSHEY_SIMPLEX;
|
||||
double fontScale = 0.5;
|
||||
int thickness = 1;
|
||||
|
||||
// 计算文字大小
|
||||
int[] baseLine = new int[1];
|
||||
org.opencv.core.Size labelSize = Imgproc.getTextSize(label, font, fontScale, thickness, baseLine);
|
||||
|
||||
// 确保文字不会超出图像
|
||||
int y = Math.max((int) topLeft.y, (int) labelSize.height);
|
||||
|
||||
// 绘制文字背景
|
||||
Imgproc.rectangle(image, new Point(topLeft.x, y - labelSize.height),
|
||||
new Point(topLeft.x + labelSize.width, y + baseLine[0]),
|
||||
new Scalar(0, 0, 255), Imgproc.FILLED);
|
||||
|
||||
// 绘制文字
|
||||
Imgproc.putText(image, label, new Point(topLeft.x, y),
|
||||
font, fontScale, new Scalar(255, 255, 255), thickness);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,106 @@
|
|||
package com.ly.onnx.utils;
|
||||
|
||||
import org.bytedeco.javacv.Frame;
|
||||
import org.opencv.core.*;
|
||||
import org.opencv.imgproc.Imgproc;
|
||||
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.nio.Buffer;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class ImageUtils {
|
||||
|
||||
// 辅助方法:将 BufferedImage 转换为浮点数组(根据您的模型需求)
|
||||
private static float[] preprocessImage(BufferedImage image) {
|
||||
int width = image.getWidth();
|
||||
int height = image.getHeight();
|
||||
float[] result = new float[width * height * 3]; // 假设是 RGB 图像
|
||||
int idx = 0;
|
||||
|
||||
for (int y = 0; y < height; y++) {
|
||||
for (int x = 0; x < width; x++) {
|
||||
int pixel = image.getRGB(x, y);
|
||||
// 分别获取 R, G, B 值并归一化(假设归一化到 [0, 1])
|
||||
result[idx++] = ((pixel >> 16) & 0xFF) / 255.0f; // Red
|
||||
result[idx++] = ((pixel >> 8) & 0xFF) / 255.0f; // Green
|
||||
result[idx++] = (pixel & 0xFF) / 255.0f; // Blue
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
public static float[] frameToFloatArray(Frame frame) {
|
||||
// 获取 Frame 的宽度和高度
|
||||
int width = frame.imageWidth;
|
||||
int height = frame.imageHeight;
|
||||
|
||||
// 获取 Frame 的像素数据
|
||||
Buffer buffer = frame.image[0]; // 获取图像数据缓冲区
|
||||
ByteBuffer byteBuffer = (ByteBuffer) buffer; // 假设图像数据是以字节缓冲存储
|
||||
|
||||
// 创建 float 数组来存储图像的 RGB 值
|
||||
float[] result = new float[width * height * 3]; // 假设是 RGB 格式图像
|
||||
int idx = 0;
|
||||
|
||||
// 遍历每个像素,提取 R, G, B 值并归一化到 [0, 1]
|
||||
for (int i = 0; i < byteBuffer.capacity(); i += 3) {
|
||||
// 提取 RGB 通道数据
|
||||
int r = byteBuffer.get(i) & 0xFF; // Red
|
||||
int g = byteBuffer.get(i + 1) & 0xFF; // Green
|
||||
int b = byteBuffer.get(i + 2) & 0xFF; // Blue
|
||||
|
||||
// 将 RGB 值归一化为 float 并存入数组
|
||||
result[idx++] = r / 255.0f;
|
||||
result[idx++] = g / 255.0f;
|
||||
result[idx++] = b / 255.0f;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
// 将 Mat 转换为 BufferedImage
|
||||
public static BufferedImage matToBufferedImage(Mat mat) {
|
||||
int type = BufferedImage.TYPE_3BYTE_BGR;
|
||||
if (mat.channels() == 1) {
|
||||
type = BufferedImage.TYPE_BYTE_GRAY;
|
||||
}
|
||||
int bufferSize = mat.channels() * mat.cols() * mat.rows();
|
||||
byte[] buffer = new byte[bufferSize];
|
||||
mat.get(0, 0, buffer); // 获取所有像素
|
||||
BufferedImage image = new BufferedImage(mat.cols(), mat.rows(), type);
|
||||
final byte[] targetPixels = ((java.awt.image.DataBufferByte) image.getRaster().getDataBuffer()).getData();
|
||||
System.arraycopy(buffer, 0, targetPixels, 0, buffer.length);
|
||||
return image;
|
||||
}
|
||||
|
||||
// 将 Mat 转换为 float 数组,适用于推理
|
||||
public static float[] matToFloatArray(Mat mat) {
|
||||
// 假设 InferenceEngine 需要 RGB 格式的图像
|
||||
Mat rgbMat = new Mat();
|
||||
Imgproc.cvtColor(mat, rgbMat, Imgproc.COLOR_BGR2RGB);
|
||||
|
||||
// 假设图像已经被预处理(缩放、归一化等),否则需要在这里添加预处理步骤
|
||||
|
||||
// 将 Mat 数据转换为 float 数组
|
||||
int channels = rgbMat.channels();
|
||||
int rows = rgbMat.rows();
|
||||
int cols = rgbMat.cols();
|
||||
float[] floatData = new float[channels * rows * cols];
|
||||
byte[] byteData = new byte[channels * rows * cols];
|
||||
rgbMat.get(0, 0, byteData);
|
||||
for (int i = 0; i < floatData.length; i++) {
|
||||
// 将 unsigned byte 转换为 float [0,1]
|
||||
floatData[i] = (byteData[i] & 0xFF) / 255.0f;
|
||||
}
|
||||
rgbMat.release();
|
||||
return floatData;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -1,280 +0,0 @@
|
|||
package com.ly.play;
|
||||
|
||||
import com.ly.layout.VideoPanel;
|
||||
import org.bytedeco.javacv.*;
|
||||
|
||||
import javax.swing.*;
|
||||
import java.awt.image.BufferedImage;
|
||||
|
||||
public class VideoPlayer {
|
||||
private FrameGrabber grabber;
|
||||
private Java2DFrameConverter converter = new Java2DFrameConverter();
|
||||
private boolean isPlaying = false;
|
||||
private boolean isPaused = false;
|
||||
private Thread videoThread;
|
||||
private VideoPanel videoPanel;
|
||||
|
||||
private long videoDuration = 0; // 毫秒
|
||||
private long currentTimestamp = 0; // 毫秒
|
||||
|
||||
public VideoPlayer(VideoPanel videoPanel) {
|
||||
this.videoPanel = videoPanel;
|
||||
}
|
||||
|
||||
// 加载视频或流
|
||||
// 加载视频或流
|
||||
public void loadVideo(String videoFilePathOrStreamUrl) throws Exception {
|
||||
stopVideo();
|
||||
|
||||
|
||||
if (videoFilePathOrStreamUrl.equals("0")) {
|
||||
int cameraIndex = Integer.parseInt(videoFilePathOrStreamUrl);
|
||||
grabber = new OpenCVFrameGrabber(cameraIndex);
|
||||
grabber.start();
|
||||
videoDuration = 0; // 摄像头没有固定的时长
|
||||
playVideo();
|
||||
} else {
|
||||
// 输入不是数字,尝试使用 FFmpegFrameGrabber 打开流或视频文件
|
||||
grabber = new FFmpegFrameGrabber(videoFilePathOrStreamUrl);
|
||||
grabber.start();
|
||||
videoDuration = grabber.getLengthInTime() / 1000; // 转换为毫秒
|
||||
}
|
||||
|
||||
|
||||
// 显示第一帧
|
||||
Frame frame;
|
||||
if (grabber instanceof OpenCVFrameGrabber) {
|
||||
frame = grabber.grab();
|
||||
} else {
|
||||
frame = grabber.grab();
|
||||
}
|
||||
if (frame != null && frame.image != null) {
|
||||
BufferedImage bufferedImage = converter.getBufferedImage(frame);
|
||||
videoPanel.updateImage(bufferedImage);
|
||||
currentTimestamp = 0;
|
||||
}
|
||||
|
||||
// 重置到视频开始位置
|
||||
if (grabber instanceof FFmpegFrameGrabber) {
|
||||
grabber.setTimestamp(0);
|
||||
}
|
||||
currentTimestamp = 0;
|
||||
}
|
||||
|
||||
public void playVideo() {
|
||||
if (grabber == null) {
|
||||
JOptionPane.showMessageDialog(null, "请先加载视频文件或流。", "提示", JOptionPane.WARNING_MESSAGE);
|
||||
return;
|
||||
}
|
||||
|
||||
if (isPlaying) {
|
||||
if (isPaused) {
|
||||
isPaused = false; // 恢复播放
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
isPlaying = true;
|
||||
isPaused = false;
|
||||
|
||||
videoThread = new Thread(() -> {
|
||||
try {
|
||||
if (grabber instanceof OpenCVFrameGrabber) {
|
||||
// 摄像头捕获
|
||||
while (isPlaying) {
|
||||
if (isPaused) {
|
||||
Thread.sleep(100);
|
||||
continue;
|
||||
}
|
||||
|
||||
Frame frame = grabber.grab();
|
||||
if (frame == null) {
|
||||
isPlaying = false;
|
||||
break;
|
||||
}
|
||||
|
||||
BufferedImage bufferedImage = converter.getBufferedImage(frame);
|
||||
if (bufferedImage != null) {
|
||||
videoPanel.updateImage(bufferedImage);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 视频文件或流
|
||||
double frameRate = grabber.getFrameRate();
|
||||
if (frameRate <= 0 || Double.isNaN(frameRate)) {
|
||||
frameRate = 25; // 默认帧率
|
||||
}
|
||||
long frameInterval = (long) (1000 / frameRate); // 每帧间隔时间(毫秒)
|
||||
long startTime = System.currentTimeMillis();
|
||||
long frameCount = 0;
|
||||
|
||||
while (isPlaying) {
|
||||
if (isPaused) {
|
||||
Thread.sleep(100);
|
||||
startTime += 100; // 调整开始时间以考虑暂停时间
|
||||
continue;
|
||||
}
|
||||
|
||||
Frame frame = grabber.grab();
|
||||
if (frame == null) {
|
||||
// 视频播放结束
|
||||
isPlaying = false;
|
||||
break;
|
||||
}
|
||||
|
||||
BufferedImage bufferedImage = converter.getBufferedImage(frame);
|
||||
if (bufferedImage != null) {
|
||||
videoPanel.updateImage(bufferedImage);
|
||||
|
||||
// 更新当前帧时间戳
|
||||
frameCount++;
|
||||
long expectedTime = frameCount * frameInterval;
|
||||
long actualTime = System.currentTimeMillis() - startTime;
|
||||
|
||||
currentTimestamp = grabber.getTimestamp() / 1000;
|
||||
|
||||
// 如果实际时间落后于预期时间,进行调整
|
||||
if (actualTime < expectedTime) {
|
||||
Thread.sleep(expectedTime - actualTime);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 视频播放完毕后,停止播放
|
||||
isPlaying = false;
|
||||
|
||||
} catch (Exception ex) {
|
||||
ex.printStackTrace();
|
||||
}
|
||||
});
|
||||
videoThread.start();
|
||||
}
|
||||
|
||||
// 暂停视频
|
||||
public void pauseVideo() {
|
||||
if (!isPlaying) {
|
||||
return;
|
||||
}
|
||||
isPaused = true;
|
||||
}
|
||||
|
||||
// 重播视频
|
||||
public void replayVideo() {
|
||||
try {
|
||||
if (grabber instanceof FFmpegFrameGrabber) {
|
||||
grabber.setTimestamp(0); // 重置到视频开始位置
|
||||
grabber.flush(); // 清除缓存
|
||||
currentTimestamp = 0;
|
||||
|
||||
// 显示第一帧
|
||||
Frame frame = grabber.grab();
|
||||
if (frame != null && frame.image != null) {
|
||||
BufferedImage bufferedImage = converter.getBufferedImage(frame);
|
||||
videoPanel.updateImage(bufferedImage);
|
||||
}
|
||||
|
||||
playVideo(); // 开始播放
|
||||
} else if (grabber instanceof OpenCVFrameGrabber) {
|
||||
// 对于摄像头,重播相当于重新开始播放
|
||||
playVideo();
|
||||
}
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
JOptionPane.showMessageDialog(null, "重播失败: " + e.getMessage(), "错误", JOptionPane.ERROR_MESSAGE);
|
||||
}
|
||||
}
|
||||
|
||||
// 停止视频
|
||||
public void stopVideo() {
|
||||
isPlaying = false;
|
||||
isPaused = false;
|
||||
|
||||
if (videoThread != null && videoThread.isAlive()) {
|
||||
try {
|
||||
videoThread.join();
|
||||
} catch (InterruptedException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
if (grabber != null) {
|
||||
try {
|
||||
grabber.stop();
|
||||
grabber.release();
|
||||
} catch (Exception ex) {
|
||||
ex.printStackTrace();
|
||||
}
|
||||
grabber = null;
|
||||
}
|
||||
}
|
||||
|
||||
// 快进或后退
|
||||
public void seekTo(long seekTime) {
|
||||
if (grabber == null) return;
|
||||
if (!(grabber instanceof FFmpegFrameGrabber)) {
|
||||
JOptionPane.showMessageDialog(null, "此操作仅支持视频文件和流。", "提示", JOptionPane.WARNING_MESSAGE);
|
||||
return;
|
||||
}
|
||||
try {
|
||||
isPaused = false; // 取消暂停
|
||||
isPlaying = false; // 停止当前播放线程
|
||||
if (videoThread != null && videoThread.isAlive()) {
|
||||
videoThread.join();
|
||||
}
|
||||
|
||||
grabber.setTimestamp(seekTime * 1000); // 转换为微秒
|
||||
grabber.flush(); // 清除缓存
|
||||
|
||||
Frame frame;
|
||||
do {
|
||||
frame = grabber.grab();
|
||||
if (frame == null) {
|
||||
break;
|
||||
}
|
||||
} while (frame.image == null); // 跳过没有图像的帧
|
||||
|
||||
if (frame != null && frame.image != null) {
|
||||
BufferedImage bufferedImage = converter.getBufferedImage(frame);
|
||||
videoPanel.updateImage(bufferedImage);
|
||||
|
||||
// 更新当前帧时间戳
|
||||
currentTimestamp = grabber.getTimestamp() / 1000;
|
||||
}
|
||||
|
||||
// 重新开始播放
|
||||
playVideo();
|
||||
|
||||
} catch (Exception ex) {
|
||||
ex.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
// 快进
|
||||
public void fastForward(long milliseconds) {
|
||||
if (!(grabber instanceof FFmpegFrameGrabber)) {
|
||||
JOptionPane.showMessageDialog(null, "此操作仅支持视频文件和流。", "提示", JOptionPane.WARNING_MESSAGE);
|
||||
return;
|
||||
}
|
||||
long newTime = Math.min(currentTimestamp + milliseconds, videoDuration);
|
||||
seekTo(newTime);
|
||||
}
|
||||
|
||||
// 后退
|
||||
public void rewind(long milliseconds) {
|
||||
if (!(grabber instanceof FFmpegFrameGrabber)) {
|
||||
JOptionPane.showMessageDialog(null, "此操作仅支持视频文件和流。", "提示", JOptionPane.WARNING_MESSAGE);
|
||||
return;
|
||||
}
|
||||
long newTime = Math.max(currentTimestamp - milliseconds, 0);
|
||||
seekTo(newTime);
|
||||
}
|
||||
|
||||
public long getVideoDuration() {
|
||||
return videoDuration;
|
||||
}
|
||||
|
||||
public FrameGrabber getGrabber() {
|
||||
return grabber;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,328 @@
|
|||
//package com.ly.play.ff;
|
||||
//
|
||||
//import com.ly.layout.VideoPanel;
|
||||
//import com.ly.model_load.ModelManager;
|
||||
//import com.ly.onnx.engine.InferenceEngine;
|
||||
//import com.ly.onnx.model.InferenceResult;
|
||||
//import com.ly.onnx.utils.DrawImagesUtils;
|
||||
//import com.ly.onnx.utils.ImageUtils;
|
||||
//import org.bytedeco.javacv.*;
|
||||
//
|
||||
//import javax.swing.*;
|
||||
//import java.awt.image.BufferedImage;
|
||||
//import java.util.ArrayList;
|
||||
//import java.util.List;
|
||||
//
|
||||
//public class VideoPlayer {
|
||||
// private FrameGrabber grabber;
|
||||
// private Java2DFrameConverter converter = new Java2DFrameConverter();
|
||||
// private boolean isPlaying = false;
|
||||
// private boolean isPaused = false;
|
||||
// private Thread videoThread;
|
||||
// private VideoPanel videoPanel;
|
||||
//
|
||||
// private long videoDuration = 0; // 毫秒
|
||||
// private long currentTimestamp = 0; // 毫秒
|
||||
//
|
||||
// ModelManager modelManager;
|
||||
// private List<InferenceEngine> inferenceEngines = new ArrayList<>();
|
||||
//
|
||||
// public VideoPlayer(VideoPanel videoPanel, ModelManager modelManager) {
|
||||
// this.videoPanel = videoPanel;
|
||||
// this.modelManager = modelManager;
|
||||
// System.out.println();
|
||||
// }
|
||||
//
|
||||
// // 加载视频或流
|
||||
// public void loadVideo(String videoFilePathOrStreamUrl) throws Exception {
|
||||
// stopVideo();
|
||||
// if (videoFilePathOrStreamUrl.equals("0")) {
|
||||
// int cameraIndex = Integer.parseInt(videoFilePathOrStreamUrl);
|
||||
// grabber = new OpenCVFrameGrabber(cameraIndex);
|
||||
// grabber.start();
|
||||
// videoDuration = 0; // 摄像头没有固定的时长
|
||||
// playVideo();
|
||||
// } else {
|
||||
// // 输入不是数字,尝试使用 FFmpegFrameGrabber 打开流或视频文件
|
||||
// grabber = new FFmpegFrameGrabber(videoFilePathOrStreamUrl);
|
||||
// grabber.start();
|
||||
// videoDuration = grabber.getLengthInTime() / 1000; // 转换为毫秒
|
||||
// }
|
||||
//
|
||||
//
|
||||
// // 显示第一帧
|
||||
// Frame frame;
|
||||
// if (grabber instanceof OpenCVFrameGrabber) {
|
||||
// frame = grabber.grab();
|
||||
// } else {
|
||||
// frame = grabber.grab();
|
||||
// }
|
||||
// if (frame != null && frame.image != null) {
|
||||
// BufferedImage bufferedImage = converter.getBufferedImage(frame);
|
||||
// videoPanel.updateImage(bufferedImage);
|
||||
// currentTimestamp = 0;
|
||||
// }
|
||||
//
|
||||
// // 重置到视频开始位置
|
||||
// if (grabber instanceof FFmpegFrameGrabber) {
|
||||
// grabber.setTimestamp(0);
|
||||
// }
|
||||
// currentTimestamp = 0;
|
||||
// }
|
||||
//
|
||||
//
|
||||
//
|
||||
//
|
||||
//
|
||||
// //播放
|
||||
// public void playVideo() {
|
||||
// if (grabber == null) {
|
||||
// JOptionPane.showMessageDialog(null, "请先加载视频文件或流。", "提示", JOptionPane.WARNING_MESSAGE);
|
||||
// return;
|
||||
// }
|
||||
//
|
||||
// if (inferenceEngines == null){
|
||||
// JOptionPane.showMessageDialog(null, "请先加载模型给文件。", "提示", JOptionPane.WARNING_MESSAGE);
|
||||
// return;
|
||||
// }
|
||||
//
|
||||
// if (isPlaying) {
|
||||
// if (isPaused) {
|
||||
// isPaused = false; // 恢复播放
|
||||
// }
|
||||
// return;
|
||||
// }
|
||||
//
|
||||
// isPlaying = true;
|
||||
// isPaused = false;
|
||||
//
|
||||
// videoThread = new Thread(() -> {
|
||||
// try {
|
||||
// if (grabber instanceof OpenCVFrameGrabber) {
|
||||
// // 摄像头捕获
|
||||
// while (isPlaying) {
|
||||
// if (isPaused) {
|
||||
// Thread.sleep(10);
|
||||
// continue;
|
||||
// }
|
||||
//
|
||||
// Frame frame = grabber.grab();
|
||||
// if (frame == null) {
|
||||
// isPlaying = false;
|
||||
// break;
|
||||
// }
|
||||
//
|
||||
// BufferedImage bufferedImage = converter.getBufferedImage(frame);
|
||||
// List<InferenceResult> inferenceResults = new ArrayList<>();
|
||||
// if (bufferedImage != null) {
|
||||
// float[] inputData = ImageUtils.frameToFloatArray(frame);
|
||||
// for (InferenceEngine inferenceEngine : inferenceEngines) {
|
||||
// inferenceResults.add(inferenceEngine.infer(inputData,640,640));
|
||||
// }
|
||||
// //绘制
|
||||
// DrawImagesUtils.drawInferenceResult(bufferedImage,inferenceResults);
|
||||
// //更新绘制后图像
|
||||
// videoPanel.updateImage(bufferedImage);
|
||||
// }
|
||||
// }
|
||||
// } else {
|
||||
// // 视频文件或流
|
||||
// double frameRate = grabber.getFrameRate();
|
||||
// if (frameRate <= 0 || Double.isNaN(frameRate)) {
|
||||
// frameRate = 25; // 默认帧率
|
||||
// }
|
||||
// long frameInterval = (long) (1000 / frameRate); // 每帧间隔时间(毫秒)
|
||||
// long startTime = System.currentTimeMillis();
|
||||
// long frameCount = 0;
|
||||
//
|
||||
// while (isPlaying) {
|
||||
// if (isPaused) {
|
||||
// Thread.sleep(100);
|
||||
// startTime += 100; // 调整开始时间以考虑暂停时间
|
||||
// continue;
|
||||
// }
|
||||
//
|
||||
// Frame frame = grabber.grab();
|
||||
// if (frame == null) {
|
||||
// // 视频播放结束
|
||||
// isPlaying = false;
|
||||
// break;
|
||||
// }
|
||||
//
|
||||
//
|
||||
//
|
||||
// BufferedImage bufferedImage = converter.getBufferedImage(frame);
|
||||
//
|
||||
//
|
||||
// List<InferenceResult> inferenceResults = new ArrayList<>();
|
||||
// if (bufferedImage != null) {
|
||||
// float[] inputData = ImageUtils.frameToFloatArray(frame);
|
||||
// for (InferenceEngine inferenceEngine : inferenceEngines) {
|
||||
// inferenceResults.add(inferenceEngine.infer(inputData,640,640));
|
||||
// }
|
||||
// //绘制
|
||||
// DrawImagesUtils.drawInferenceResult(bufferedImage,inferenceResults);
|
||||
// //更新绘制后图像
|
||||
// videoPanel.updateImage(bufferedImage);
|
||||
// }
|
||||
//
|
||||
// if (bufferedImage != null) {
|
||||
// videoPanel.updateImage(bufferedImage);
|
||||
//
|
||||
// // 更新当前帧时间戳
|
||||
// frameCount++;
|
||||
// long expectedTime = frameCount * frameInterval;
|
||||
// long actualTime = System.currentTimeMillis() - startTime;
|
||||
//
|
||||
// currentTimestamp = grabber.getTimestamp() / 1000;
|
||||
//
|
||||
// // 如果实际时间落后于预期时间,进行调整
|
||||
// if (actualTime < expectedTime) {
|
||||
// Thread.sleep(expectedTime - actualTime);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // 视频播放完毕后,停止播放
|
||||
// isPlaying = false;
|
||||
//
|
||||
// } catch (Exception ex) {
|
||||
// ex.printStackTrace();
|
||||
// }
|
||||
// });
|
||||
// videoThread.start();
|
||||
// }
|
||||
//
|
||||
// // 暂停视频
|
||||
// public void pauseVideo() {
|
||||
// if (!isPlaying) {
|
||||
// return;
|
||||
// }
|
||||
// isPaused = true;
|
||||
// }
|
||||
//
|
||||
// // 重播视频
|
||||
// public void replayVideo() {
|
||||
// try {
|
||||
// if (grabber instanceof FFmpegFrameGrabber) {
|
||||
// grabber.setTimestamp(0); // 重置到视频开始位置
|
||||
// grabber.flush(); // 清除缓存
|
||||
// currentTimestamp = 0;
|
||||
//
|
||||
// // 显示第一帧
|
||||
// Frame frame = grabber.grab();
|
||||
// if (frame != null && frame.image != null) {
|
||||
// BufferedImage bufferedImage = converter.getBufferedImage(frame);
|
||||
// videoPanel.updateImage(bufferedImage);
|
||||
// }
|
||||
//
|
||||
// playVideo(); // 开始播放
|
||||
// } else if (grabber instanceof OpenCVFrameGrabber) {
|
||||
// // 对于摄像头,重播相当于重新开始播放
|
||||
// playVideo();
|
||||
// }
|
||||
// } catch (Exception e) {
|
||||
// e.printStackTrace();
|
||||
// JOptionPane.showMessageDialog(null, "重播失败: " + e.getMessage(), "错误", JOptionPane.ERROR_MESSAGE);
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // 停止视频
|
||||
// public void stopVideo() {
|
||||
// isPlaying = false;
|
||||
// isPaused = false;
|
||||
//
|
||||
// if (videoThread != null && videoThread.isAlive()) {
|
||||
// try {
|
||||
// videoThread.join();
|
||||
// } catch (InterruptedException e) {
|
||||
// e.printStackTrace();
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// if (grabber != null) {
|
||||
// try {
|
||||
// grabber.stop();
|
||||
// grabber.release();
|
||||
// } catch (Exception ex) {
|
||||
// ex.printStackTrace();
|
||||
// }
|
||||
// grabber = null;
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // 快进或后退
|
||||
// public void seekTo(long seekTime) {
|
||||
// if (grabber == null) return;
|
||||
// if (!(grabber instanceof FFmpegFrameGrabber)) {
|
||||
// JOptionPane.showMessageDialog(null, "此操作仅支持视频文件和流。", "提示", JOptionPane.WARNING_MESSAGE);
|
||||
// return;
|
||||
// }
|
||||
// try {
|
||||
// isPaused = false; // 取消暂停
|
||||
// isPlaying = false; // 停止当前播放线程
|
||||
// if (videoThread != null && videoThread.isAlive()) {
|
||||
// videoThread.join();
|
||||
// }
|
||||
//
|
||||
// grabber.setTimestamp(seekTime * 1000); // 转换为微秒
|
||||
// grabber.flush(); // 清除缓存
|
||||
//
|
||||
// Frame frame;
|
||||
// do {
|
||||
// frame = grabber.grab();
|
||||
// if (frame == null) {
|
||||
// break;
|
||||
// }
|
||||
// } while (frame.image == null); // 跳过没有图像的帧
|
||||
//
|
||||
// if (frame != null && frame.image != null) {
|
||||
// BufferedImage bufferedImage = converter.getBufferedImage(frame);
|
||||
// videoPanel.updateImage(bufferedImage);
|
||||
//
|
||||
// // 更新当前帧时间戳
|
||||
// currentTimestamp = grabber.getTimestamp() / 1000;
|
||||
// }
|
||||
//
|
||||
// // 重新开始播放
|
||||
// playVideo();
|
||||
//
|
||||
// } catch (Exception ex) {
|
||||
// ex.printStackTrace();
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // 快进
|
||||
// public void fastForward(long milliseconds) {
|
||||
// if (!(grabber instanceof FFmpegFrameGrabber)) {
|
||||
// JOptionPane.showMessageDialog(null, "此操作仅支持视频文件和流。", "提示", JOptionPane.WARNING_MESSAGE);
|
||||
// return;
|
||||
// }
|
||||
// long newTime = Math.min(currentTimestamp + milliseconds, videoDuration);
|
||||
// seekTo(newTime);
|
||||
// }
|
||||
//
|
||||
// // 后退
|
||||
// public void rewind(long milliseconds) {
|
||||
// if (!(grabber instanceof FFmpegFrameGrabber)) {
|
||||
// JOptionPane.showMessageDialog(null, "此操作仅支持视频文件和流。", "提示", JOptionPane.WARNING_MESSAGE);
|
||||
// return;
|
||||
// }
|
||||
// long newTime = Math.max(currentTimestamp - milliseconds, 0);
|
||||
// seekTo(newTime);
|
||||
// }
|
||||
//
|
||||
// public long getVideoDuration() {
|
||||
// return videoDuration;
|
||||
// }
|
||||
//
|
||||
// public FrameGrabber getGrabber() {
|
||||
// return grabber;
|
||||
// }
|
||||
//
|
||||
// public void addInferenceEngines(InferenceEngine inferenceEngine){
|
||||
// this.inferenceEngines.add(inferenceEngine);
|
||||
// }
|
||||
//
|
||||
//}
|
|
@ -0,0 +1,427 @@
|
|||
package com.ly.play.opencv;
|
||||
|
||||
import com.ly.layout.VideoPanel;
|
||||
import com.ly.model_load.ModelManager;
|
||||
import com.ly.onnx.engine.InferenceEngine;
|
||||
import com.ly.onnx.model.InferenceResult;
|
||||
import com.ly.onnx.utils.DrawImagesUtils;
|
||||
import org.opencv.core.CvType;
|
||||
import org.opencv.core.Mat;
|
||||
import org.opencv.core.Rect;
|
||||
import org.opencv.core.Size;
|
||||
import org.opencv.imgproc.Imgproc;
|
||||
import org.opencv.videoio.VideoCapture;
|
||||
import org.opencv.videoio.Videoio;
|
||||
|
||||
import javax.swing.*;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.awt.image.DataBufferByte;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.*;
|
||||
|
||||
import static com.ly.onnx.utils.ImageUtils.matToBufferedImage;
|
||||
|
||||
public class VideoPlayer {
|
||||
static {
|
||||
// 加载 OpenCV 库
|
||||
nu.pattern.OpenCV.loadLocally();
|
||||
String OS = System.getProperty("os.name").toLowerCase();
|
||||
if (OS.contains("win")) {
|
||||
System.load(ClassLoader.getSystemResource("lib/win/opencv_videoio_ffmpeg470_64.dll").getPath());
|
||||
}
|
||||
}
|
||||
|
||||
private VideoCapture videoCapture;
|
||||
private volatile boolean isPlaying = false;
|
||||
private volatile boolean isPaused = false;
|
||||
private Thread frameReadingThread;
|
||||
private Thread inferenceThread;
|
||||
private VideoPanel videoPanel;
|
||||
|
||||
private long videoDuration = 0; // 毫秒
|
||||
private long currentTimestamp = 0; // 毫秒
|
||||
|
||||
private ModelManager modelManager;
|
||||
private List<InferenceEngine> inferenceEngines = new ArrayList<>();
|
||||
|
||||
// 定义阻塞队列来缓冲转换后的数据
|
||||
private BlockingQueue<FrameData> frameDataQueue = new LinkedBlockingQueue<>(10); // 队列容量可根据需要调整
|
||||
|
||||
public VideoPlayer(VideoPanel videoPanel, ModelManager modelManager) {
|
||||
this.videoPanel = videoPanel;
|
||||
this.modelManager = modelManager;
|
||||
}
|
||||
|
||||
// 加载视频或流
|
||||
public void loadVideo(String videoFilePathOrStreamUrl) throws Exception {
|
||||
stopVideo();
|
||||
if (videoFilePathOrStreamUrl.equals("0")) {
|
||||
int cameraIndex = Integer.parseInt(videoFilePathOrStreamUrl);
|
||||
videoCapture = new VideoCapture(cameraIndex);
|
||||
if (!videoCapture.isOpened()) {
|
||||
throw new Exception("无法打开摄像头");
|
||||
}
|
||||
videoDuration = 0; // 摄像头没有固定的时长
|
||||
playVideo();
|
||||
} else {
|
||||
// 输入不是数字,尝试打开视频文件
|
||||
videoCapture = new VideoCapture(videoFilePathOrStreamUrl, Videoio.CAP_FFMPEG);
|
||||
if (!videoCapture.isOpened()) {
|
||||
throw new Exception("无法打开视频文件:" + videoFilePathOrStreamUrl);
|
||||
}
|
||||
double frameCount = videoCapture.get(org.opencv.videoio.Videoio.CAP_PROP_FRAME_COUNT);
|
||||
double fps = videoCapture.get(org.opencv.videoio.Videoio.CAP_PROP_FPS);
|
||||
if (fps <= 0 || Double.isNaN(fps)) {
|
||||
fps = 25; // 默认帧率
|
||||
}
|
||||
videoDuration = (long) (frameCount / fps * 1000); // 转换为毫秒
|
||||
}
|
||||
|
||||
// 显示第一帧
|
||||
Mat frame = new Mat();
|
||||
if (videoCapture.read(frame)) {
|
||||
BufferedImage bufferedImage = matToBufferedImage(frame);
|
||||
videoPanel.updateImage(bufferedImage);
|
||||
currentTimestamp = 0;
|
||||
} else {
|
||||
throw new Exception("无法读取第一帧");
|
||||
}
|
||||
|
||||
// 重置到视频开始位置
|
||||
videoCapture.set(org.opencv.videoio.Videoio.CAP_PROP_POS_FRAMES, 0);
|
||||
currentTimestamp = 0;
|
||||
}
|
||||
|
||||
// 播放
|
||||
public void playVideo() {
|
||||
if (videoCapture == null || !videoCapture.isOpened()) {
|
||||
JOptionPane.showMessageDialog(null, "请先加载视频文件或流。", "提示", JOptionPane.WARNING_MESSAGE);
|
||||
return;
|
||||
}
|
||||
|
||||
if (isPlaying) {
|
||||
if (isPaused) {
|
||||
isPaused = false; // 恢复播放
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
isPlaying = true;
|
||||
isPaused = false;
|
||||
|
||||
frameDataQueue.clear(); // 开始播放前清空队列
|
||||
|
||||
// 创建并启动帧读取和转换线程
|
||||
frameReadingThread = new Thread(() -> {
|
||||
try {
|
||||
double fps = videoCapture.get(org.opencv.videoio.Videoio.CAP_PROP_FPS);
|
||||
if (fps <= 0 || Double.isNaN(fps)) {
|
||||
fps = 25; // 默认帧率
|
||||
}
|
||||
long frameDelay = (long) (1000 / fps);
|
||||
|
||||
while (isPlaying) {
|
||||
if (Thread.currentThread().isInterrupted()) {
|
||||
break;
|
||||
}
|
||||
if (isPaused) {
|
||||
Thread.sleep(10);
|
||||
continue;
|
||||
}
|
||||
|
||||
Mat frame = new Mat();
|
||||
if (!videoCapture.read(frame) || frame.empty()) {
|
||||
isPlaying = false;
|
||||
break;
|
||||
}
|
||||
|
||||
long startTime = System.currentTimeMillis();
|
||||
BufferedImage bufferedImage = matToBufferedImage(frame);
|
||||
|
||||
if (bufferedImage != null) {
|
||||
float[] floats = preprocessAndConvertBufferedImage(bufferedImage);
|
||||
|
||||
// 创建 FrameData 对象并放入队列
|
||||
FrameData frameData = new FrameData(bufferedImage, floats);
|
||||
frameDataQueue.put(frameData); // 阻塞,如果队列已满
|
||||
}
|
||||
|
||||
// 控制帧率
|
||||
currentTimestamp = (long) videoCapture.get(org.opencv.videoio.Videoio.CAP_PROP_POS_MSEC);
|
||||
|
||||
// 控制播放速度
|
||||
long processingTime = System.currentTimeMillis() - startTime;
|
||||
long sleepTime = frameDelay - processingTime;
|
||||
if (sleepTime > 0) {
|
||||
Thread.sleep(sleepTime);
|
||||
}
|
||||
}
|
||||
} catch (Exception ex) {
|
||||
ex.printStackTrace();
|
||||
} finally {
|
||||
isPlaying = false;
|
||||
}
|
||||
});
|
||||
|
||||
// 创建并启动推理线程
|
||||
inferenceThread = new Thread(() -> {
|
||||
try {
|
||||
while (isPlaying || !frameDataQueue.isEmpty()) {
|
||||
if (Thread.currentThread().isInterrupted()) {
|
||||
break;
|
||||
}
|
||||
if (isPaused) {
|
||||
Thread.sleep(100);
|
||||
continue;
|
||||
}
|
||||
|
||||
FrameData frameData = frameDataQueue.poll(100, TimeUnit.MILLISECONDS); // 等待数据
|
||||
if (frameData == null) {
|
||||
continue; // 没有数据,继续检查 isPlaying
|
||||
}
|
||||
|
||||
BufferedImage bufferedImage = frameData.image;
|
||||
float[] floatArray = frameData.floatArray;
|
||||
|
||||
// 执行推理
|
||||
List<InferenceResult> inferenceResults = new ArrayList<>();
|
||||
for (InferenceEngine inferenceEngine : inferenceEngines) {
|
||||
// 假设 InferenceEngine 有 infer 方法接受 float 数组
|
||||
// inferenceResults.add(inferenceEngine.infer(floatArray, 640, 640));
|
||||
}
|
||||
// 绘制推理结果
|
||||
DrawImagesUtils.drawInferenceResult(bufferedImage, inferenceResults);
|
||||
|
||||
// 更新绘制后图像
|
||||
videoPanel.updateImage(bufferedImage);
|
||||
}
|
||||
} catch (Exception ex) {
|
||||
ex.printStackTrace();
|
||||
}
|
||||
});
|
||||
|
||||
frameReadingThread.start();
|
||||
inferenceThread.start();
|
||||
}
|
||||
|
||||
// 暂停视频
|
||||
public void pauseVideo() {
|
||||
if (!isPlaying) {
|
||||
return;
|
||||
}
|
||||
isPaused = true;
|
||||
}
|
||||
|
||||
// 重播视频
|
||||
public void replayVideo() {
|
||||
try {
|
||||
stopVideo(); // 停止当前播放
|
||||
if (videoCapture != null) {
|
||||
videoCapture.set(org.opencv.videoio.Videoio.CAP_PROP_POS_FRAMES, 0);
|
||||
currentTimestamp = 0;
|
||||
|
||||
// 显示第一帧
|
||||
Mat frame = new Mat();
|
||||
if (videoCapture.read(frame)) {
|
||||
BufferedImage bufferedImage = matToBufferedImage(frame);
|
||||
videoPanel.updateImage(bufferedImage);
|
||||
}
|
||||
|
||||
playVideo(); // 开始播放
|
||||
}
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
JOptionPane.showMessageDialog(null, "重播失败: " + e.getMessage(), "错误", JOptionPane.ERROR_MESSAGE);
|
||||
}
|
||||
}
|
||||
|
||||
// 停止视频
|
||||
public void stopVideo() {
|
||||
isPlaying = false;
|
||||
isPaused = false;
|
||||
|
||||
if (frameReadingThread != null && frameReadingThread.isAlive()) {
|
||||
frameReadingThread.interrupt();
|
||||
}
|
||||
|
||||
if (inferenceThread != null && inferenceThread.isAlive()) {
|
||||
inferenceThread.interrupt();
|
||||
}
|
||||
|
||||
if (videoCapture != null) {
|
||||
videoCapture.release();
|
||||
videoCapture = null;
|
||||
}
|
||||
|
||||
frameDataQueue.clear();
|
||||
}
|
||||
|
||||
// 快进或后退
|
||||
public void seekTo(long seekTime) {
|
||||
if (videoCapture == null) return;
|
||||
try {
|
||||
isPaused = false; // 取消暂停
|
||||
stopVideo(); // 停止当前播放
|
||||
videoCapture.set(org.opencv.videoio.Videoio.CAP_PROP_POS_MSEC, seekTime);
|
||||
currentTimestamp = seekTime;
|
||||
|
||||
Mat frame = new Mat();
|
||||
if (videoCapture.read(frame)) {
|
||||
BufferedImage bufferedImage = matToBufferedImage(frame);
|
||||
videoPanel.updateImage(bufferedImage);
|
||||
}
|
||||
|
||||
// 重新开始播放
|
||||
playVideo();
|
||||
|
||||
} catch (Exception ex) {
|
||||
ex.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
// 快进
|
||||
public void fastForward(long milliseconds) {
|
||||
long newTime = Math.min(currentTimestamp + milliseconds, videoDuration);
|
||||
seekTo(newTime);
|
||||
}
|
||||
|
||||
// 后退
|
||||
public void rewind(long milliseconds) {
|
||||
long newTime = Math.max(currentTimestamp - milliseconds, 0);
|
||||
seekTo(newTime);
|
||||
}
|
||||
|
||||
public void addInferenceEngines(InferenceEngine inferenceEngine) {
|
||||
this.inferenceEngines.add(inferenceEngine);
|
||||
}
|
||||
|
||||
// 定义一个内部类来存储帧数据
|
||||
private static class FrameData {
|
||||
public BufferedImage image;
|
||||
public float[] floatArray;
|
||||
|
||||
public FrameData(BufferedImage image, float[] floatArray) {
|
||||
this.image = image;
|
||||
this.floatArray = floatArray;
|
||||
}
|
||||
}
|
||||
|
||||
// 将 BufferedImage 预处理并转换为一维 float[] 数组
|
||||
public static float[] preprocessAndConvertBufferedImage(BufferedImage image) {
|
||||
int targetWidth = 640;
|
||||
int targetHeight = 640;
|
||||
|
||||
// 将 BufferedImage 转换为 Mat
|
||||
Mat matImage = bufferedImageToMat(image);
|
||||
|
||||
// 原始图像尺寸
|
||||
int origWidth = matImage.width();
|
||||
int origHeight = matImage.height();
|
||||
|
||||
// 计算缩放因子
|
||||
float scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
|
||||
|
||||
// 计算新的图像尺寸
|
||||
int newWidth = Math.round(origWidth * scalingFactor);
|
||||
int newHeight = Math.round(origHeight * scalingFactor);
|
||||
|
||||
// 调整图像尺寸
|
||||
Mat resizedImage = new Mat();
|
||||
Imgproc.resize(matImage, resizedImage, new Size(newWidth, newHeight));
|
||||
|
||||
// 转换为 RGB
|
||||
Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB);
|
||||
|
||||
// 创建目标图像并将调整后的图像填充到目标图像中
|
||||
Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3);
|
||||
int xOffset = (targetWidth - newWidth) / 2;
|
||||
int yOffset = (targetHeight - newHeight) / 2;
|
||||
Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight);
|
||||
resizedImage.copyTo(paddedImage.submat(roi));
|
||||
|
||||
// 将图像数据转换为输入所需的浮点数组
|
||||
int imageSize = targetWidth * targetHeight;
|
||||
float[] inputData = new float[3 * imageSize];
|
||||
paddedImage.reshape(1, imageSize * 3).get(0, 0, inputData);
|
||||
|
||||
// 释放资源
|
||||
matImage.release();
|
||||
resizedImage.release();
|
||||
paddedImage.release();
|
||||
|
||||
return inputData;
|
||||
}
|
||||
|
||||
// 辅助方法:将 BufferedImage 转换为 OpenCV 的 Mat 格式
|
||||
public static Mat bufferedImageToMat(BufferedImage bi) {
|
||||
int width = bi.getWidth();
|
||||
int height = bi.getHeight();
|
||||
Mat mat = new Mat(height, width, CvType.CV_8UC3);
|
||||
byte[] data = ((DataBufferByte) bi.getRaster().getDataBuffer()).getData();
|
||||
mat.put(0, 0, data);
|
||||
return mat;
|
||||
}
|
||||
|
||||
// 可选的预处理方法
|
||||
public Map<String, Object> preprocessImage(Mat image) {
|
||||
int targetWidth = 640;
|
||||
int targetHeight = 640;
|
||||
|
||||
int origWidth = image.width();
|
||||
int origHeight = image.height();
|
||||
|
||||
// 计算缩放因子
|
||||
float scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight);
|
||||
|
||||
// 计算新的图像尺寸
|
||||
int newWidth = Math.round(origWidth * scalingFactor);
|
||||
int newHeight = Math.round(origHeight * scalingFactor);
|
||||
|
||||
// 调整图像尺寸
|
||||
Mat resizedImage = new Mat();
|
||||
Imgproc.resize(image, resizedImage, new Size(newWidth, newHeight));
|
||||
|
||||
// 转换为 RGB 并归一化
|
||||
Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB);
|
||||
resizedImage.convertTo(resizedImage, CvType.CV_32FC3, 1.0 / 255.0);
|
||||
|
||||
// 创建填充后的图像
|
||||
Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3);
|
||||
int xOffset = (targetWidth - newWidth) / 2;
|
||||
int yOffset = (targetHeight - newHeight) / 2;
|
||||
Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight);
|
||||
resizedImage.copyTo(paddedImage.submat(roi));
|
||||
|
||||
// 将图像数据转换为数组
|
||||
int imageSize = targetWidth * targetHeight;
|
||||
float[] chwData = new float[3 * imageSize];
|
||||
float[] hwcData = new float[3 * imageSize];
|
||||
paddedImage.get(0, 0, hwcData);
|
||||
|
||||
// 转换为 CHW 格式
|
||||
int channelSize = imageSize;
|
||||
for (int c = 0; c < 3; c++) {
|
||||
for (int i = 0; i < imageSize; i++) {
|
||||
chwData[c * channelSize + i] = hwcData[i * 3 + c];
|
||||
}
|
||||
}
|
||||
|
||||
// 释放图像资源
|
||||
resizedImage.release();
|
||||
paddedImage.release();
|
||||
|
||||
// 将预处理结果和偏移信息存入 Map
|
||||
Map<String, Object> result = new HashMap<>();
|
||||
result.put("inputData", chwData);
|
||||
result.put("origWidth", origWidth);
|
||||
result.put("origHeight", origHeight);
|
||||
result.put("scalingFactor", scalingFactor);
|
||||
result.put("xOffset", xOffset);
|
||||
result.put("yOffset", yOffset);
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
|
@ -1,13 +0,0 @@
|
|||
package com.ly.utils;
|
||||
|
||||
import org.bytedeco.javacv.FrameGrabber;
|
||||
import org.bytedeco.javacv.VideoInputFrameGrabber;
|
||||
|
||||
public class CameraDeviceLister {
|
||||
public static void main(String[] args) throws FrameGrabber.Exception {
|
||||
String[] deviceDescriptions = VideoInputFrameGrabber.getDeviceDescriptions();
|
||||
for (String deviceDescription : deviceDescriptions) {
|
||||
System.out.println("摄像头索引 " + ": " + deviceDescription);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
package com.ly.utils;
|
||||
|
||||
import org.opencv.core.Core;
|
||||
import org.opencv.core.Mat;
|
||||
import org.opencv.videoio.VideoCapture;
|
||||
|
||||
public class OpenCVTest {
|
||||
// static {
|
||||
// nu.pattern.OpenCV.loadLocally();
|
||||
// }
|
||||
|
||||
public static void main(String[] args) {
|
||||
VideoCapture capture = new VideoCapture(0); // 打开默认摄像头
|
||||
if (!capture.isOpened()) {
|
||||
System.out.println("无法打开摄像头");
|
||||
return;
|
||||
}
|
||||
|
||||
Mat frame = new Mat();
|
||||
if (capture.read(frame)) {
|
||||
System.out.println("成功读取一帧图像");
|
||||
} else {
|
||||
System.out.println("无法读取图像");
|
||||
}
|
||||
|
||||
capture.release();
|
||||
}
|
||||
}
|
|
@ -1,57 +0,0 @@
|
|||
package com.ly.utils;
|
||||
|
||||
import org.bytedeco.ffmpeg.global.avcodec;
|
||||
import org.bytedeco.javacv.*;
|
||||
|
||||
public class RTSPStreamer {
|
||||
|
||||
public static void main(String[] args) {
|
||||
String inputFile = "C:\\Users\\ly\\Desktop\\屏幕录制 2024-09-20 225443.mp4"; // 替换为您的本地视频文件路径
|
||||
String rtspUrl = "rtsp://localhost:8554/live"; // 替换为您的 RTSP 服务器地址
|
||||
|
||||
FFmpegFrameGrabber grabber = null;
|
||||
FFmpegFrameRecorder recorder = null;
|
||||
|
||||
try {
|
||||
// 初始化 FFmpegFrameGrabber 以从本地视频文件读取
|
||||
grabber = new FFmpegFrameGrabber(inputFile);
|
||||
grabber.start();
|
||||
|
||||
// 初始化 FFmpegFrameRecorder 以推流到 RTSP 服务器
|
||||
recorder = new FFmpegFrameRecorder(rtspUrl, grabber.getImageWidth(), grabber.getImageHeight(), grabber.getAudioChannels());
|
||||
recorder.setFormat("rtsp");
|
||||
recorder.setFrameRate(grabber.getFrameRate());
|
||||
recorder.setVideoBitrate(grabber.getVideoBitrate());
|
||||
recorder.setVideoCodec(avcodec.AV_CODEC_ID_H264); // 设置视频编码格式
|
||||
recorder.setAudioCodec(avcodec.AV_CODEC_ID_AAC); // 设置音频编码格式
|
||||
|
||||
// 设置 RTSP 传输选项(如果需要)
|
||||
recorder.setOption("rtsp_transport", "tcp");
|
||||
|
||||
recorder.start();
|
||||
|
||||
Frame frame;
|
||||
while ((frame = grabber.grab()) != null) {
|
||||
recorder.record(frame);
|
||||
}
|
||||
|
||||
System.out.println("推流完成。");
|
||||
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
} finally {
|
||||
try {
|
||||
if (recorder != null) {
|
||||
recorder.stop();
|
||||
recorder.release();
|
||||
}
|
||||
if (grabber != null) {
|
||||
grabber.stop();
|
||||
grabber.release();
|
||||
}
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue