diff --git a/src/main/java/com/ly/VideoInferenceApp.java b/src/main/java/com/ly/VideoInferenceApp.java index da9e96b..79b2510 100644 --- a/src/main/java/com/ly/VideoInferenceApp.java +++ b/src/main/java/com/ly/VideoInferenceApp.java @@ -25,10 +25,11 @@ public class VideoInferenceApp extends JFrame { private VideoPlayer videoPlayer; private ModelManager modelManager; + + public VideoInferenceApp() { // 设置窗口标题 - super("Video Inference Player"); - + super("https://gitee.com/sulv0302/onnx-inference4j-play.git"); // 初始化UI组件 initializeUI(); } diff --git a/src/main/java/com/ly/lishi/InferenceEngine.java b/src/main/java/com/ly/lishi/InferenceEngine.java new file mode 100644 index 0000000..63ffa66 --- /dev/null +++ b/src/main/java/com/ly/lishi/InferenceEngine.java @@ -0,0 +1,410 @@ +package com.ly.lishi; + +import ai.onnxruntime.*; +import com.alibaba.fastjson.JSON; +import com.ly.onnx.model.BoundingBox; +import com.ly.onnx.model.InferenceResult; +import lombok.Data; +import org.opencv.core.*; +import org.opencv.imgcodecs.Imgcodecs; +import org.opencv.imgproc.Imgproc; + +import java.nio.FloatBuffer; +import java.util.*; + +@Data +public class InferenceEngine { + + private OrtEnvironment environment; + private OrtSession.SessionOptions sessionOptions; + private OrtSession session; + + private String modelPath; + private List labels; + + // 用于存储图像预处理信息的类变量 + private long[] inputShape = null; + + static { + nu.pattern.OpenCV.loadLocally(); + } + + public InferenceEngine(String modelPath, List labels) { + this.modelPath = modelPath; + this.labels = labels; + init(); + } + + public void init() { + try { + environment = OrtEnvironment.getEnvironment(); + sessionOptions = new OrtSession.SessionOptions(); + sessionOptions.addCUDA(0); // 使用 GPU + session = environment.createSession(modelPath, sessionOptions); + Map inputInfo = session.getInputInfo(); + NodeInfo nodeInfo = inputInfo.values().iterator().next(); + TensorInfo tensorInfo = (TensorInfo) nodeInfo.getInfo(); + inputShape = tensorInfo.getShape(); // 从模型中获取输入形状 + logModelInfo(session); + } catch (OrtException e) { + throw new RuntimeException("模型加载失败", e); + } + } + + public InferenceResult infer(int w, int h, Map preprocessParams) { + long startTime = System.currentTimeMillis(); + + // 从 Map 中获取偏移相关的变量 + float[] inputData = (float[]) preprocessParams.get("inputData"); + int origWidth = (int) preprocessParams.get("origWidth"); + int origHeight = (int) preprocessParams.get("origHeight"); + float scalingFactor = (float) preprocessParams.get("scalingFactor"); + int xOffset = (int) preprocessParams.get("xOffset"); + int yOffset = (int) preprocessParams.get("yOffset"); + + try { + Map inputInfo = session.getInputInfo(); + String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入 + + long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状 + + // 创建输入张量时,使用 CHW 格式的数据 + OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape); + + // 执行推理 + long inferenceStart = System.currentTimeMillis(); + OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor)); + long inferenceEnd = System.currentTimeMillis(); + System.out.println("模型推理耗时:" + (inferenceEnd - inferenceStart) + " ms"); + + // 解析推理结果 + String outputName = session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出 + float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状:[1, N, 6] + + // 设定置信度阈值 + float confidenceThreshold = 0.25f; // 您可以根据需要调整 + + // 根据模型的输出结果解析边界框 + List boxes = new 
ArrayList<>(); + for (float[] data : outputData[0]) { // 遍历所有检测框 + // 根据模型输出格式,解析中心坐标和宽高 + float x_center = data[0]; + float y_center = data[1]; + float width = data[2]; + float height = data[3]; + float confidence = data[4]; + + if (confidence >= confidenceThreshold) { + // 将中心坐标转换为左上角和右下角坐标 + float x1 = x_center - width / 2; + float y1 = y_center - height / 2; + float x2 = x_center + width / 2; + float y2 = y_center + height / 2; + + // 调整坐标,减去偏移并除以缩放因子 + float x1Adjusted = (x1 - xOffset) / scalingFactor; + float y1Adjusted = (y1 - yOffset) / scalingFactor; + float x2Adjusted = (x2 - xOffset) / scalingFactor; + float y2Adjusted = (y2 - yOffset) / scalingFactor; + + // 确保坐标的正确顺序 + float xMinAdjusted = Math.min(x1Adjusted, x2Adjusted); + float xMaxAdjusted = Math.max(x1Adjusted, x2Adjusted); + float yMinAdjusted = Math.min(y1Adjusted, y2Adjusted); + float yMaxAdjusted = Math.max(y1Adjusted, y2Adjusted); + + // 确保坐标在原始图像范围内 + int x = (int) Math.max(0, xMinAdjusted); + int y = (int) Math.max(0, yMinAdjusted); + int xMax = (int) Math.min(origWidth, xMaxAdjusted); + int yMax = (int) Math.min(origHeight, yMaxAdjusted); + int wBox = xMax - x; + int hBox = yMax - y; + + // 仅当宽度和高度为正时,才添加边界框 + if (wBox > 0 && hBox > 0) { + // 使用您的单一标签 + String label = labels.get(0); + + boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence)); + } + } + } + + // 非极大值抑制(NMS) + long nmsStart = System.currentTimeMillis(); + List nmsBoxes = nonMaximumSuppression(boxes, 0.5f); + + System.out.println("检测到的标签:" + JSON.toJSONString(nmsBoxes)); + if (!nmsBoxes.isEmpty()) { + for (BoundingBox box : nmsBoxes) { + System.out.println(box); + } + } + long nmsEnd = System.currentTimeMillis(); + System.out.println("NMS 耗时:" + (nmsEnd - nmsStart) + " ms"); + + // 封装结果并返回 + InferenceResult inferenceResult = new InferenceResult(); + inferenceResult.setBoundingBoxes(nmsBoxes); + + long endTime = System.currentTimeMillis(); + System.out.println("一次推理总耗时:" + (endTime - startTime) + " ms"); + + return inferenceResult; + + } catch (OrtException e) { + throw new RuntimeException("推理失败", e); + } + } + + + // 计算两个边界框的 IoU + private float computeIoU(BoundingBox box1, BoundingBox box2) { + int x1 = Math.max(box1.getX(), box2.getX()); + int y1 = Math.max(box1.getY(), box2.getY()); + int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth()); + int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight()); + + int intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1); + int box1Area = box1.getWidth() * box1.getHeight(); + int box2Area = box2.getWidth() * box2.getHeight(); + + return (float) intersectionArea / (box1Area + box2Area - intersectionArea); + } + + // 非极大值抑制(NMS)方法 + private List nonMaximumSuppression(List boxes, float iouThreshold) { + // 按置信度排序(从高到低) + boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence())); + + List result = new ArrayList<>(); + + while (!boxes.isEmpty()) { + BoundingBox bestBox = boxes.remove(0); + result.add(bestBox); + + Iterator iterator = boxes.iterator(); + while (iterator.hasNext()) { + BoundingBox box = iterator.next(); + if (computeIoU(bestBox, box) > iouThreshold) { + iterator.remove(); + } + } + } + + return result; + } + + // 打印模型信息 + private void logModelInfo(OrtSession session) { + System.out.println("模型输入信息:"); + try { + for (Map.Entry entry : session.getInputInfo().entrySet()) { + String name = entry.getKey(); + NodeInfo info = entry.getValue(); + System.out.println("输入名称: " + name); + System.out.println("输入信息: " + 
info.toString()); + } + } catch (OrtException e) { + throw new RuntimeException(e); + } + + System.out.println("模型输出信息:"); + try { + for (Map.Entry entry : session.getOutputInfo().entrySet()) { + String name = entry.getKey(); + NodeInfo info = entry.getValue(); + System.out.println("输出名称: " + name); + System.out.println("输出信息: " + info.toString()); + } + } catch (OrtException e) { + throw new RuntimeException(e); + } + } + + public static void main(String[] args) { + // 加载 OpenCV 库 + + // 初始化标签列表(只有一个标签) + List labels = Arrays.asList("person"); + + // 创建 InferenceEngine 实例 + InferenceEngine inferenceEngine = new InferenceEngine("C:\\Users\\ly\\Desktop\\person.onnx", labels); + + for (int j = 0; j < 10; j++) { + try { + // 加载图片 + Mat inputImage = Imgcodecs.imread("C:\\Users\\ly\\Desktop\\10230731212230.png"); + + // 预处理图像 + long l1 = System.currentTimeMillis(); + Map preprocessResult = inferenceEngine.preprocessImage(inputImage); + float[] inputData = (float[]) preprocessResult.get("inputData"); + + InferenceResult result = null; + for (int i = 0; i < 10; i++) { + long l = System.currentTimeMillis(); + result = inferenceEngine.infer( 640, 640, preprocessResult); + System.out.println("第 " + (i + 1) + " 次推理耗时:" + (System.currentTimeMillis() - l) + " ms"); + } + + + // 处理并显示结果 + System.out.println("推理结果:"); + for (BoundingBox box : result.getBoundingBoxes()) { + System.out.println(box); + } + + // 可视化并保存带有边界框的图像 + Mat outputImage = inferenceEngine.drawBoundingBoxes(inputImage, result.getBoundingBoxes()); + + // 保存图片到本地文件 + String outputFilePath = "output_image_with_boxes.jpg"; + Imgcodecs.imwrite(outputFilePath, outputImage); + + System.out.println("已保存结果图片: " + outputFilePath); + + } catch (Exception e) { + e.printStackTrace(); + } + } + } + + // 在图像上绘制边界框和标签 + private Mat drawBoundingBoxes(Mat image, List boxes) { + for (BoundingBox box : boxes) { + // 绘制矩形边界框 + Imgproc.rectangle(image, new Point(box.getX(), box.getY()), + new Point(box.getX() + box.getWidth(), box.getY() + box.getHeight()), + new Scalar(0, 0, 255), 2); // 红色边框 + + // 绘制标签文字和置信度 + String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence()); + int baseLine[] = new int[1]; + Size labelSize = Imgproc.getTextSize(label, Imgproc.FONT_HERSHEY_SIMPLEX, 0.5, 1, baseLine); + int top = Math.max(box.getY(), (int) labelSize.height); + Imgproc.putText(image, label, new Point(box.getX(), top), + Imgproc.FONT_HERSHEY_SIMPLEX, 0.5, new Scalar(255, 255, 255), 1); + } + + return image; + } + + + public Map preprocessImage(Mat image) { + int targetWidth = 640; + int targetHeight = 640; + + int origWidth = image.width(); + int origHeight = image.height(); + + // 计算缩放因子 + float scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight); + + // 计算新的图像尺寸 + int newWidth = Math.round(origWidth * scalingFactor); + int newHeight = Math.round(origHeight * scalingFactor); + + // 计算偏移量以居中图像 + int xOffset = (targetWidth - newWidth) / 2; + int yOffset = (targetHeight - newHeight) / 2; + + // 调整图像尺寸 + Mat resizedImage = new Mat(); + Imgproc.resize(image, resizedImage, new Size(newWidth, newHeight), 0, 0, Imgproc.INTER_LINEAR); + + // 转换为 RGB 并归一化 + Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB); + resizedImage.convertTo(resizedImage, CvType.CV_32FC3, 1.0 / 255.0); + + // 创建填充后的图像 + Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3); + Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight); + resizedImage.copyTo(paddedImage.submat(roi)); + + // 
将图像数据转换为数组 + int imageSize = targetWidth * targetHeight; + float[] chwData = new float[3 * imageSize]; + float[] hwcData = new float[3 * imageSize]; + paddedImage.get(0, 0, hwcData); + + // 转换为 CHW 格式 + int channelSize = imageSize; + for (int c = 0; c < 3; c++) { + for (int i = 0; i < imageSize; i++) { + chwData[c * channelSize + i] = hwcData[i * 3 + c]; + } + } + + // 释放图像资源 + resizedImage.release(); + paddedImage.release(); + + // 将预处理结果和偏移信息存入 Map + Map result = new HashMap<>(); + result.put("inputData", chwData); + result.put("origWidth", origWidth); + result.put("origHeight", origHeight); + result.put("scalingFactor", scalingFactor); + result.put("xOffset", xOffset); + result.put("yOffset", yOffset); + + return result; + } + + + // 图像预处理 +// public float[] preprocessImage(Mat image) { +// int targetWidth = 640; +// int targetHeight = 640; +// +// origWidth = image.width(); +// origHeight = image.height(); +// +// // 计算缩放因子 +// scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight); +// +// // 计算新的图像尺寸 +// newWidth = Math.round(origWidth * scalingFactor); +// newHeight = Math.round(origHeight * scalingFactor); +// +// // 计算偏移量以居中图像 +// xOffset = (targetWidth - newWidth) / 2; +// yOffset = (targetHeight - newHeight) / 2; +// +// // 调整图像尺寸 +// Mat resizedImage = new Mat(); +// Imgproc.resize(image, resizedImage, new Size(newWidth, newHeight), 0, 0, Imgproc.INTER_LINEAR); +// +// // 转换为 RGB 并归一化 +// Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB); +// resizedImage.convertTo(resizedImage, CvType.CV_32FC3, 1.0 / 255.0); +// +// // 创建填充后的图像 +// Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3); +// Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight); +// resizedImage.copyTo(paddedImage.submat(roi)); +// +// // 将图像数据转换为数组 +// int imageSize = targetWidth * targetHeight; +// float[] chwData = new float[3 * imageSize]; +// float[] hwcData = new float[3 * imageSize]; +// paddedImage.get(0, 0, hwcData); +// +// // 转换为 CHW 格式 +// int channelSize = imageSize; +// for (int c = 0; c < 3; c++) { +// for (int i = 0; i < imageSize; i++) { +// chwData[c * channelSize + i] = hwcData[i * 3 + c]; +// } +// } +// +// // 释放图像资源 +// resizedImage.release(); +// paddedImage.release(); +// +// return chwData; +// } + +} diff --git a/src/main/java/com/ly/lishi/VideoPlayer.java b/src/main/java/com/ly/lishi/VideoPlayer.java new file mode 100644 index 0000000..b263eb5 --- /dev/null +++ b/src/main/java/com/ly/lishi/VideoPlayer.java @@ -0,0 +1,378 @@ +package com.ly.lishi; + +import com.ly.layout.VideoPanel; +import com.ly.model_load.ModelManager; +import com.ly.onnx.engine.InferenceEngine; +import com.ly.onnx.model.InferenceResult; +import com.ly.onnx.utils.DrawImagesUtils; +import org.opencv.core.CvType; +import org.opencv.core.Mat; +import org.opencv.core.Rect; +import org.opencv.core.Size; +import org.opencv.imgproc.Imgproc; +import org.opencv.videoio.VideoCapture; +import org.opencv.videoio.Videoio; + +import javax.swing.*; +import java.awt.image.BufferedImage; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import static com.ly.onnx.utils.ImageUtils.matToBufferedImage; + +public class VideoPlayer { + static { + // 加载 OpenCV 库 + nu.pattern.OpenCV.loadLocally(); + String OS = System.getProperty("os.name").toLowerCase(); + if 
(OS.contains("win")) { + System.load(ClassLoader.getSystemResource("lib/win/opencv_videoio_ffmpeg470_64.dll").getPath()); + } + } + + private VideoCapture videoCapture; + private volatile boolean isPlaying = false; + private volatile boolean isPaused = false; + private Thread frameReadingThread; + private Thread inferenceThread; + private VideoPanel videoPanel; + + private long videoDuration = 0; // 毫秒 + private long currentTimestamp = 0; // 毫秒 + + + + private ModelManager modelManager; + private List inferenceEngines = new ArrayList<>(); + + // 定义阻塞队列来缓冲转换后的数据 + private BlockingQueue frameDataQueue = new LinkedBlockingQueue<>(10); // 队列容量可根据需要调整 + + public VideoPlayer(VideoPanel videoPanel, ModelManager modelManager) { + this.videoPanel = videoPanel; + this.modelManager = modelManager; + } + + // 加载视频或流 + public void loadVideo(String videoFilePathOrStreamUrl) throws Exception { + stopVideo(); + if (videoFilePathOrStreamUrl.equals("0")) { + int cameraIndex = Integer.parseInt(videoFilePathOrStreamUrl); + videoCapture = new VideoCapture(cameraIndex); + if (!videoCapture.isOpened()) { + throw new Exception("无法打开摄像头"); + } + videoDuration = 0; // 摄像头没有固定的时长 + playVideo(); + } else { + // 输入不是数字,尝试打开视频文件 + videoCapture = new VideoCapture(videoFilePathOrStreamUrl, Videoio.CAP_FFMPEG); + if (!videoCapture.isOpened()) { + throw new Exception("无法打开视频文件:" + videoFilePathOrStreamUrl); + } + double frameCount = videoCapture.get(Videoio.CAP_PROP_FRAME_COUNT); + double fps = videoCapture.get(Videoio.CAP_PROP_FPS); + if (fps <= 0 || Double.isNaN(fps)) { + fps = 25; // 默认帧率 + } + videoDuration = (long) (frameCount / fps * 1000); // 转换为毫秒 + } + + // 显示第一帧 + Mat frame = new Mat(); + if (videoCapture.read(frame)) { + BufferedImage bufferedImage = matToBufferedImage(frame); + videoPanel.updateImage(bufferedImage); + currentTimestamp = 0; + } else { + throw new Exception("无法读取第一帧"); + } + + // 重置到视频开始位置 + videoCapture.set(Videoio.CAP_PROP_POS_FRAMES, 0); + currentTimestamp = 0; + } + + // 播放 + public void playVideo() { + if (videoCapture == null || !videoCapture.isOpened()) { + JOptionPane.showMessageDialog(null, "请先加载视频文件或流。", "提示", JOptionPane.WARNING_MESSAGE); + return; + } + + if (isPlaying) { + if (isPaused) { + isPaused = false; // 恢复播放 + } + return; + } + + isPlaying = true; + isPaused = false; + + frameDataQueue.clear(); // 开始播放前清空队列 + + // 创建并启动帧读取和转换线程 + frameReadingThread = new Thread(() -> { + try { + double fps = videoCapture.get(Videoio.CAP_PROP_FPS); + if (fps <= 0 || Double.isNaN(fps)) { + fps = 25; // 默认帧率 + } + long frameDelay = (long) (1000 / fps); + + while (isPlaying) { + if (Thread.currentThread().isInterrupted()) { + break; + } + if (isPaused) { + Thread.sleep(10); + continue; + } + + Mat frame = new Mat(); + if (!videoCapture.read(frame) || frame.empty()) { + isPlaying = false; + break; + } + + long startTime = System.currentTimeMillis(); + BufferedImage bufferedImage = matToBufferedImage(frame); + + if (bufferedImage != null) { +// float[] floats = preprocessAndConvertBufferedImage(bufferedImage); + Map stringObjectMap = preprocessImage(frame); + // 创建 FrameData 对象并放入队列 + FrameData frameData = new FrameData(bufferedImage, null,stringObjectMap); + frameDataQueue.put(frameData); // 阻塞,如果队列已满 + } + + // 控制帧率 + currentTimestamp = (long) videoCapture.get(Videoio.CAP_PROP_POS_MSEC); + + // 控制播放速度 + long processingTime = System.currentTimeMillis() - startTime; + long sleepTime = frameDelay - processingTime; + if (sleepTime > 0) { + Thread.sleep(sleepTime); + } + } + } catch (Exception ex) { + 
ex.printStackTrace(); + } finally { + isPlaying = false; + } + }); + + // 创建并启动推理线程 + inferenceThread = new Thread(() -> { + try { + while (isPlaying || !frameDataQueue.isEmpty()) { + if (Thread.currentThread().isInterrupted()) { + break; + } + if (isPaused) { + Thread.sleep(100); + continue; + } + + FrameData frameData = frameDataQueue.poll(100, TimeUnit.MILLISECONDS); // 等待数据 + if (frameData == null) { + continue; // 没有数据,继续检查 isPlaying + } + + BufferedImage bufferedImage = frameData.image; + Map floatObjectMap = frameData.floatObjectMap; + + // 执行推理 + List inferenceResults = new ArrayList<>(); + for (InferenceEngine inferenceEngine : inferenceEngines) { + // 假设 InferenceEngine 有 infer 方法接受 float 数组 +// inferenceResults.add(inferenceEngine.infer( 640, 640,floatObjectMap)); + } + // 绘制推理结果 + DrawImagesUtils.drawInferenceResult(bufferedImage, inferenceResults); + + // 更新绘制后图像 + videoPanel.updateImage(bufferedImage); + } + } catch (Exception ex) { + ex.printStackTrace(); + } + }); + + frameReadingThread.start(); + inferenceThread.start(); + } + + // 暂停视频 + public void pauseVideo() { + if (!isPlaying) { + return; + } + isPaused = true; + } + + // 重播视频 + public void replayVideo() { + try { + stopVideo(); // 停止当前播放 + if (videoCapture != null) { + videoCapture.set(Videoio.CAP_PROP_POS_FRAMES, 0); + currentTimestamp = 0; + + // 显示第一帧 + Mat frame = new Mat(); + if (videoCapture.read(frame)) { + BufferedImage bufferedImage = matToBufferedImage(frame); + videoPanel.updateImage(bufferedImage); + } + + playVideo(); // 开始播放 + } + } catch (Exception e) { + e.printStackTrace(); + JOptionPane.showMessageDialog(null, "重播失败: " + e.getMessage(), "错误", JOptionPane.ERROR_MESSAGE); + } + } + + // 停止视频 + public void stopVideo() { + isPlaying = false; + isPaused = false; + + if (frameReadingThread != null && frameReadingThread.isAlive()) { + frameReadingThread.interrupt(); + } + + if (inferenceThread != null && inferenceThread.isAlive()) { + inferenceThread.interrupt(); + } + + if (videoCapture != null) { + videoCapture.release(); + videoCapture = null; + } + + frameDataQueue.clear(); + } + + // 快进或后退 + public void seekTo(long seekTime) { + if (videoCapture == null) return; + try { + isPaused = false; // 取消暂停 + stopVideo(); // 停止当前播放 + videoCapture.set(Videoio.CAP_PROP_POS_MSEC, seekTime); + currentTimestamp = seekTime; + + Mat frame = new Mat(); + if (videoCapture.read(frame)) { + BufferedImage bufferedImage = matToBufferedImage(frame); + videoPanel.updateImage(bufferedImage); + } + + // 重新开始播放 + playVideo(); + + } catch (Exception ex) { + ex.printStackTrace(); + } + } + + // 快进 + public void fastForward(long milliseconds) { + long newTime = Math.min(currentTimestamp + milliseconds, videoDuration); + seekTo(newTime); + } + + // 后退 + public void rewind(long milliseconds) { + long newTime = Math.max(currentTimestamp - milliseconds, 0); + seekTo(newTime); + } + + public void addInferenceEngines(InferenceEngine inferenceEngine) { + this.inferenceEngines.add(inferenceEngine); + } + + // 定义一个内部类来存储帧数据 + private static class FrameData { + public BufferedImage image; + public float[] floatArray; + public Map floatObjectMap; + + public FrameData(BufferedImage image, float[] floatArray, Map floatObjectMap) { + this.image = image; + this.floatArray = floatArray; + this.floatObjectMap = floatObjectMap; + } + } + + + // 可选的预处理方法 + public Map preprocessImage(Mat image) { + int targetWidth = 640; + int targetHeight = 640; + + int origWidth = image.width(); + int origHeight = image.height(); + + // 计算缩放因子 + float scalingFactor 
= Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight); + + // 计算新的图像尺寸 + int newWidth = Math.round(origWidth * scalingFactor); + int newHeight = Math.round(origHeight * scalingFactor); + + // 调整图像尺寸 + Mat resizedImage = new Mat(); + Imgproc.resize(image, resizedImage, new Size(newWidth, newHeight)); + + // 转换为 RGB 并归一化 + Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB); + resizedImage.convertTo(resizedImage, CvType.CV_32FC3, 1.0 / 255.0); + + // 创建填充后的图像 + Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3); + int xOffset = (targetWidth - newWidth) / 2; + int yOffset = (targetHeight - newHeight) / 2; + Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight); + resizedImage.copyTo(paddedImage.submat(roi)); + + // 将图像数据转换为数组 + int imageSize = targetWidth * targetHeight; + float[] chwData = new float[3 * imageSize]; + float[] hwcData = new float[3 * imageSize]; + paddedImage.get(0, 0, hwcData); + + // 转换为 CHW 格式 + int channelSize = imageSize; + for (int c = 0; c < 3; c++) { + for (int i = 0; i < imageSize; i++) { + chwData[c * channelSize + i] = hwcData[i * 3 + c]; + } + } + + // 释放图像资源 + resizedImage.release(); + paddedImage.release(); + + // 将预处理结果和偏移信息存入 Map + Map result = new HashMap<>(); + result.put("inputData", chwData); + result.put("origWidth", origWidth); + result.put("origHeight", origHeight); + result.put("scalingFactor", scalingFactor); + result.put("xOffset", xOffset); + result.put("yOffset", yOffset); + + return result; + } + +} diff --git a/src/main/java/com/ly/onnx/engine/InferenceEngine.java b/src/main/java/com/ly/onnx/engine/InferenceEngine.java index bec81f7..793ff4f 100644 --- a/src/main/java/com/ly/onnx/engine/InferenceEngine.java +++ b/src/main/java/com/ly/onnx/engine/InferenceEngine.java @@ -5,6 +5,7 @@ import com.alibaba.fastjson.JSON; import com.ly.onnx.model.BoundingBox; import com.ly.onnx.model.InferenceResult; +import lombok.Data; import org.opencv.core.*; import org.opencv.imgcodecs.Imgcodecs; import org.opencv.imgproc.Imgproc; @@ -12,6 +13,7 @@ import org.opencv.imgproc.Imgproc; import java.nio.FloatBuffer; import java.util.*; +@Data public class InferenceEngine { private OrtEnvironment environment; @@ -21,14 +23,11 @@ public class InferenceEngine { private String modelPath; private List labels; + //preprocessParams输入数据的索引 + private int index; + // 用于存储图像预处理信息的类变量 - private int origWidth; - private int origHeight; - private int newWidth; - private int newHeight; - private float scalingFactor; - private int xOffset; - private int yOffset; + private long[] inputShape = null; static { nu.pattern.OpenCV.loadLocally(); @@ -46,28 +45,32 @@ public class InferenceEngine { sessionOptions = new OrtSession.SessionOptions(); sessionOptions.addCUDA(0); // 使用 GPU session = environment.createSession(modelPath, sessionOptions); + Map inputInfo = session.getInputInfo(); + NodeInfo nodeInfo = inputInfo.values().iterator().next(); + TensorInfo tensorInfo = (TensorInfo) nodeInfo.getInfo(); + inputShape = tensorInfo.getShape(); // 从模型中获取输入形状 logModelInfo(session); } catch (OrtException e) { throw new RuntimeException("模型加载失败", e); } } - public InferenceResult infer(float[] inputData, int w, int h, Map preprocessParams) { + public InferenceResult infer(Map preprocessParams) { long startTime = System.currentTimeMillis(); - + //获取对模型需要的输入大小 + Map params = (Map) preprocessParams.get(index); // 从 Map 中获取偏移相关的变量 - int origWidth = (int) preprocessParams.get("origWidth"); - int origHeight = (int) 
preprocessParams.get("origHeight"); - float scalingFactor = (float) preprocessParams.get("scalingFactor"); - int xOffset = (int) preprocessParams.get("xOffset"); - int yOffset = (int) preprocessParams.get("yOffset"); + float[] inputData = (float[]) params.get("inputData"); + int origWidth = (int) params.get("origWidth"); + int origHeight = (int) params.get("origHeight"); + float scalingFactor = (float) params.get("scalingFactor"); + int xOffset = (int) params.get("xOffset"); + int yOffset = (int) params.get("yOffset"); try { Map inputInfo = session.getInputInfo(); String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入 - long[] inputShape = {1, 3, h, w}; // 根据模型需求调整形状 - // 创建输入张量时,使用 CHW 格式的数据 OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape); @@ -223,188 +226,5 @@ public class InferenceEngine { } } - public static void main(String[] args) { - // 加载 OpenCV 库 - - // 初始化标签列表(只有一个标签) - List labels = Arrays.asList("person"); - - // 创建 InferenceEngine 实例 - InferenceEngine inferenceEngine = new InferenceEngine("C:\\Users\\ly\\Desktop\\person.onnx", labels); - - for (int j = 0; j < 10; j++) { - try { - // 加载图片 - Mat inputImage = Imgcodecs.imread("C:\\Users\\ly\\Desktop\\10230731212230.png"); - - // 预处理图像 - long l1 = System.currentTimeMillis(); - Map preprocessResult = inferenceEngine.preprocessImage(inputImage); - float[] inputData = (float[]) preprocessResult.get("inputData"); - - InferenceResult result = null; - for (int i = 0; i < 10; i++) { - long l = System.currentTimeMillis(); - result = inferenceEngine.infer(inputData, 640, 640, preprocessResult); - System.out.println("第 " + (i + 1) + " 次推理耗时:" + (System.currentTimeMillis() - l) + " ms"); - } - - - // 处理并显示结果 - System.out.println("推理结果:"); - for (BoundingBox box : result.getBoundingBoxes()) { - System.out.println(box); - } - - // 可视化并保存带有边界框的图像 - Mat outputImage = inferenceEngine.drawBoundingBoxes(inputImage, result.getBoundingBoxes()); - - // 保存图片到本地文件 - String outputFilePath = "output_image_with_boxes.jpg"; - Imgcodecs.imwrite(outputFilePath, outputImage); - - System.out.println("已保存结果图片: " + outputFilePath); - - } catch (Exception e) { - e.printStackTrace(); - } - } - } - - // 在图像上绘制边界框和标签 - private Mat drawBoundingBoxes(Mat image, List boxes) { - for (BoundingBox box : boxes) { - // 绘制矩形边界框 - Imgproc.rectangle(image, new Point(box.getX(), box.getY()), - new Point(box.getX() + box.getWidth(), box.getY() + box.getHeight()), - new Scalar(0, 0, 255), 2); // 红色边框 - - // 绘制标签文字和置信度 - String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence()); - int baseLine[] = new int[1]; - Size labelSize = Imgproc.getTextSize(label, Imgproc.FONT_HERSHEY_SIMPLEX, 0.5, 1, baseLine); - int top = Math.max(box.getY(), (int) labelSize.height); - Imgproc.putText(image, label, new Point(box.getX(), top), - Imgproc.FONT_HERSHEY_SIMPLEX, 0.5, new Scalar(255, 255, 255), 1); - } - - return image; - } - - - public Map preprocessImage(Mat image) { - int targetWidth = 640; - int targetHeight = 640; - - int origWidth = image.width(); - int origHeight = image.height(); - - // 计算缩放因子 - float scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight); - - // 计算新的图像尺寸 - int newWidth = Math.round(origWidth * scalingFactor); - int newHeight = Math.round(origHeight * scalingFactor); - - // 计算偏移量以居中图像 - int xOffset = (targetWidth - newWidth) / 2; - int yOffset = (targetHeight - newHeight) / 2; - - // 调整图像尺寸 - Mat resizedImage = new Mat(); - 
Imgproc.resize(image, resizedImage, new Size(newWidth, newHeight), 0, 0, Imgproc.INTER_LINEAR); - - // 转换为 RGB 并归一化 - Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB); - resizedImage.convertTo(resizedImage, CvType.CV_32FC3, 1.0 / 255.0); - - // 创建填充后的图像 - Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3); - Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight); - resizedImage.copyTo(paddedImage.submat(roi)); - - // 将图像数据转换为数组 - int imageSize = targetWidth * targetHeight; - float[] chwData = new float[3 * imageSize]; - float[] hwcData = new float[3 * imageSize]; - paddedImage.get(0, 0, hwcData); - - // 转换为 CHW 格式 - int channelSize = imageSize; - for (int c = 0; c < 3; c++) { - for (int i = 0; i < imageSize; i++) { - chwData[c * channelSize + i] = hwcData[i * 3 + c]; - } - } - - // 释放图像资源 - resizedImage.release(); - paddedImage.release(); - - // 将预处理结果和偏移信息存入 Map - Map result = new HashMap<>(); - result.put("inputData", chwData); - result.put("origWidth", origWidth); - result.put("origHeight", origHeight); - result.put("scalingFactor", scalingFactor); - result.put("xOffset", xOffset); - result.put("yOffset", yOffset); - - return result; - } - - - // 图像预处理 -// public float[] preprocessImage(Mat image) { -// int targetWidth = 640; -// int targetHeight = 640; -// -// origWidth = image.width(); -// origHeight = image.height(); -// -// // 计算缩放因子 -// scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight); -// -// // 计算新的图像尺寸 -// newWidth = Math.round(origWidth * scalingFactor); -// newHeight = Math.round(origHeight * scalingFactor); -// -// // 计算偏移量以居中图像 -// xOffset = (targetWidth - newWidth) / 2; -// yOffset = (targetHeight - newHeight) / 2; -// -// // 调整图像尺寸 -// Mat resizedImage = new Mat(); -// Imgproc.resize(image, resizedImage, new Size(newWidth, newHeight), 0, 0, Imgproc.INTER_LINEAR); -// -// // 转换为 RGB 并归一化 -// Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB); -// resizedImage.convertTo(resizedImage, CvType.CV_32FC3, 1.0 / 255.0); -// -// // 创建填充后的图像 -// Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3); -// Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight); -// resizedImage.copyTo(paddedImage.submat(roi)); -// -// // 将图像数据转换为数组 -// int imageSize = targetWidth * targetHeight; -// float[] chwData = new float[3 * imageSize]; -// float[] hwcData = new float[3 * imageSize]; -// paddedImage.get(0, 0, hwcData); -// -// // 转换为 CHW 格式 -// int channelSize = imageSize; -// for (int c = 0; c < 3; c++) { -// for (int i = 0; i < imageSize; i++) { -// chwData[c * channelSize + i] = hwcData[i * 3 + c]; -// } -// } -// -// // 释放图像资源 -// resizedImage.release(); -// paddedImage.release(); -// -// return chwData; -// } } diff --git a/src/main/java/com/ly/onnx/utils/DrawImagesUtils.java b/src/main/java/com/ly/onnx/utils/DrawImagesUtils.java index 240da5d..08bb5c9 100644 --- a/src/main/java/com/ly/onnx/utils/DrawImagesUtils.java +++ b/src/main/java/com/ly/onnx/utils/DrawImagesUtils.java @@ -7,16 +7,45 @@ import org.opencv.core.Point; import org.opencv.core.Scalar; import org.opencv.imgproc.Imgproc; +import java.awt.*; import java.awt.image.BufferedImage; import java.util.List; public class DrawImagesUtils { - public static void drawInferenceResult(BufferedImage bufferedImage, List result) { + public static void drawInferenceResult(BufferedImage bufferedImage, List inferenceResults) { + Graphics2D g2d = bufferedImage.createGraphics(); + g2d.setFont(new 
Font("Arial", Font.PLAIN, 12)); + for (InferenceResult result : inferenceResults) { + for (BoundingBox box : result.getBoundingBoxes()) { + // 绘制矩形 + g2d.setColor(Color.RED); + g2d.drawRect(box.getX(), box.getY(), box.getWidth(), box.getHeight()); + + // 绘制标签 + String label = box.getLabel() + " " + String.format("%.2f", box.getConfidence()); + FontMetrics metrics = g2d.getFontMetrics(); + int labelWidth = metrics.stringWidth(label); + int labelHeight = metrics.getHeight(); + + // 确保文字不会超出图像 + int y = Math.max(box.getY(), labelHeight); + + // 绘制文字背景 + g2d.setColor(Color.RED); + g2d.fillRect(box.getX(), y - labelHeight, labelWidth, labelHeight); + + // 绘制文字 + g2d.setColor(Color.WHITE); + g2d.drawString(label, box.getX(), y); + } + } + g2d.dispose(); // 释放资源 } + // 在 Mat 上绘制推理结果 public static void drawInferenceResult(Mat image, List inferenceResults) { for (InferenceResult result : inferenceResults) { diff --git a/src/main/java/com/ly/onnx/utils/ImageUtils.java b/src/main/java/com/ly/onnx/utils/ImageUtils.java index 0fe7286..6d29774 100644 --- a/src/main/java/com/ly/onnx/utils/ImageUtils.java +++ b/src/main/java/com/ly/onnx/utils/ImageUtils.java @@ -77,30 +77,5 @@ public class ImageUtils { return image; } - // 将 Mat 转换为 float 数组,适用于推理 - public static float[] matToFloatArray(Mat mat) { - // 假设 InferenceEngine 需要 RGB 格式的图像 - Mat rgbMat = new Mat(); - Imgproc.cvtColor(mat, rgbMat, Imgproc.COLOR_BGR2RGB); - - // 假设图像已经被预处理(缩放、归一化等),否则需要在这里添加预处理步骤 - - // 将 Mat 数据转换为 float 数组 - int channels = rgbMat.channels(); - int rows = rgbMat.rows(); - int cols = rgbMat.cols(); - float[] floatData = new float[channels * rows * cols]; - byte[] byteData = new byte[channels * rows * cols]; - rgbMat.get(0, 0, byteData); - for (int i = 0; i < floatData.length; i++) { - // 将 unsigned byte 转换为 float [0,1] - floatData[i] = (byteData[i] & 0xFF) / 255.0f; - } - rgbMat.release(); - return floatData; - } - - - } diff --git a/src/main/java/com/ly/play/opencv/VideoPlayer.java b/src/main/java/com/ly/play/opencv/VideoPlayer.java index d341b70..2f2204b 100644 --- a/src/main/java/com/ly/play/opencv/VideoPlayer.java +++ b/src/main/java/com/ly/play/opencv/VideoPlayer.java @@ -5,22 +5,17 @@ import com.ly.model_load.ModelManager; import com.ly.onnx.engine.InferenceEngine; import com.ly.onnx.model.InferenceResult; import com.ly.onnx.utils.DrawImagesUtils; -import org.opencv.core.CvType; -import org.opencv.core.Mat; -import org.opencv.core.Rect; -import org.opencv.core.Size; +import org.opencv.core.*; import org.opencv.imgproc.Imgproc; import org.opencv.videoio.VideoCapture; import org.opencv.videoio.Videoio; import javax.swing.*; import java.awt.image.BufferedImage; -import java.awt.image.DataBufferByte; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.*; +import java.util.*; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; import static com.ly.onnx.utils.ImageUtils.matToBufferedImage; @@ -44,6 +39,7 @@ public class VideoPlayer { private long videoDuration = 0; // 毫秒 private long currentTimestamp = 0; // 毫秒 + private ModelManager modelManager; private List inferenceEngines = new ArrayList<>(); @@ -72,8 +68,8 @@ public class VideoPlayer { if (!videoCapture.isOpened()) { throw new Exception("无法打开视频文件:" + videoFilePathOrStreamUrl); } - double frameCount = videoCapture.get(org.opencv.videoio.Videoio.CAP_PROP_FRAME_COUNT); - double fps = 
videoCapture.get(org.opencv.videoio.Videoio.CAP_PROP_FPS); + double frameCount = videoCapture.get(Videoio.CAP_PROP_FRAME_COUNT); + double fps = videoCapture.get(Videoio.CAP_PROP_FPS); if (fps <= 0 || Double.isNaN(fps)) { fps = 25; // 默认帧率 } @@ -91,7 +87,7 @@ public class VideoPlayer { } // 重置到视频开始位置 - videoCapture.set(org.opencv.videoio.Videoio.CAP_PROP_POS_FRAMES, 0); + videoCapture.set(Videoio.CAP_PROP_POS_FRAMES, 0); currentTimestamp = 0; } @@ -117,12 +113,11 @@ public class VideoPlayer { // 创建并启动帧读取和转换线程 frameReadingThread = new Thread(() -> { try { - double fps = videoCapture.get(org.opencv.videoio.Videoio.CAP_PROP_FPS); + double fps = videoCapture.get(Videoio.CAP_PROP_FPS); if (fps <= 0 || Double.isNaN(fps)) { fps = 25; // 默认帧率 } long frameDelay = (long) (1000 / fps); - while (isPlaying) { if (Thread.currentThread().isInterrupted()) { break; @@ -131,27 +126,19 @@ public class VideoPlayer { Thread.sleep(10); continue; } - Mat frame = new Mat(); if (!videoCapture.read(frame) || frame.empty()) { isPlaying = false; break; } - long startTime = System.currentTimeMillis(); BufferedImage bufferedImage = matToBufferedImage(frame); - - if (bufferedImage != null) { - float[] floats = preprocessAndConvertBufferedImage(bufferedImage); - - // 创建 FrameData 对象并放入队列 - FrameData frameData = new FrameData(bufferedImage, floats); - frameDataQueue.put(frameData); // 阻塞,如果队列已满 - } - + Map stringObjectMap = preprocessImage(frame); + // 创建 FrameData 对象并放入队列 + FrameData frameData = new FrameData(bufferedImage, stringObjectMap); + frameDataQueue.put(frameData); // 阻塞,如果队列已满 // 控制帧率 - currentTimestamp = (long) videoCapture.get(org.opencv.videoio.Videoio.CAP_PROP_POS_MSEC); - + currentTimestamp = (long) videoCapture.get(Videoio.CAP_PROP_POS_MSEC); // 控制播放速度 long processingTime = System.currentTimeMillis() - startTime; long sleepTime = frameDelay - processingTime; @@ -184,17 +171,16 @@ public class VideoPlayer { } BufferedImage bufferedImage = frameData.image; - float[] floatArray = frameData.floatArray; + Map floatObjectMap = frameData.floatObjectMap; // 执行推理 List inferenceResults = new ArrayList<>(); for (InferenceEngine inferenceEngine : inferenceEngines) { // 假设 InferenceEngine 有 infer 方法接受 float 数组 -// inferenceResults.add(inferenceEngine.infer(floatArray, 640, 640)); + inferenceResults.add(inferenceEngine.infer(floatObjectMap)); } // 绘制推理结果 DrawImagesUtils.drawInferenceResult(bufferedImage, inferenceResults); - // 更新绘制后图像 videoPanel.updateImage(bufferedImage); } @@ -202,7 +188,6 @@ public class VideoPlayer { ex.printStackTrace(); } }); - frameReadingThread.start(); inferenceThread.start(); } @@ -215,29 +200,6 @@ public class VideoPlayer { isPaused = true; } - // 重播视频 - public void replayVideo() { - try { - stopVideo(); // 停止当前播放 - if (videoCapture != null) { - videoCapture.set(org.opencv.videoio.Videoio.CAP_PROP_POS_FRAMES, 0); - currentTimestamp = 0; - - // 显示第一帧 - Mat frame = new Mat(); - if (videoCapture.read(frame)) { - BufferedImage bufferedImage = matToBufferedImage(frame); - videoPanel.updateImage(bufferedImage); - } - - playVideo(); // 开始播放 - } - } catch (Exception e) { - e.printStackTrace(); - JOptionPane.showMessageDialog(null, "重播失败: " + e.getMessage(), "错误", JOptionPane.ERROR_MESSAGE); - } - } - // 停止视频 public void stopVideo() { isPlaying = false; @@ -259,41 +221,6 @@ public class VideoPlayer { frameDataQueue.clear(); } - // 快进或后退 - public void seekTo(long seekTime) { - if (videoCapture == null) return; - try { - isPaused = false; // 取消暂停 - stopVideo(); // 停止当前播放 - 
videoCapture.set(org.opencv.videoio.Videoio.CAP_PROP_POS_MSEC, seekTime); - currentTimestamp = seekTime; - - Mat frame = new Mat(); - if (videoCapture.read(frame)) { - BufferedImage bufferedImage = matToBufferedImage(frame); - videoPanel.updateImage(bufferedImage); - } - - // 重新开始播放 - playVideo(); - - } catch (Exception ex) { - ex.printStackTrace(); - } - } - - // 快进 - public void fastForward(long milliseconds) { - long newTime = Math.min(currentTimestamp + milliseconds, videoDuration); - seekTo(newTime); - } - - // 后退 - public void rewind(long milliseconds) { - long newTime = Math.max(currentTimestamp - milliseconds, 0); - seekTo(newTime); - } - public void addInferenceEngines(InferenceEngine inferenceEngine) { this.inferenceEngines.add(inferenceEngine); } @@ -301,127 +228,111 @@ public class VideoPlayer { // 定义一个内部类来存储帧数据 private static class FrameData { public BufferedImage image; - public float[] floatArray; - - public FrameData(BufferedImage image, float[] floatArray) { + public Map floatObjectMap; + public FrameData(BufferedImage image, Map floatObjectMap) { this.image = image; - this.floatArray = floatArray; + this.floatObjectMap = floatObjectMap; } } - // 将 BufferedImage 预处理并转换为一维 float[] 数组 - public static float[] preprocessAndConvertBufferedImage(BufferedImage image) { - int targetWidth = 640; - int targetHeight = 640; - - // 将 BufferedImage 转换为 Mat - Mat matImage = bufferedImageToMat(image); - - // 原始图像尺寸 - int origWidth = matImage.width(); - int origHeight = matImage.height(); - - // 计算缩放因子 - float scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight); - - // 计算新的图像尺寸 - int newWidth = Math.round(origWidth * scalingFactor); - int newHeight = Math.round(origHeight * scalingFactor); - - // 调整图像尺寸 - Mat resizedImage = new Mat(); - Imgproc.resize(matImage, resizedImage, new Size(newWidth, newHeight)); - - // 转换为 RGB - Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB); - - // 创建目标图像并将调整后的图像填充到目标图像中 - Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3); - int xOffset = (targetWidth - newWidth) / 2; - int yOffset = (targetHeight - newHeight) / 2; - Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight); - resizedImage.copyTo(paddedImage.submat(roi)); - - // 将图像数据转换为输入所需的浮点数组 - int imageSize = targetWidth * targetHeight; - float[] inputData = new float[3 * imageSize]; - paddedImage.reshape(1, imageSize * 3).get(0, 0, inputData); - - // 释放资源 - matImage.release(); - resizedImage.release(); - paddedImage.release(); - - return inputData; - } - - // 辅助方法:将 BufferedImage 转换为 OpenCV 的 Mat 格式 - public static Mat bufferedImageToMat(BufferedImage bi) { - int width = bi.getWidth(); - int height = bi.getHeight(); - Mat mat = new Mat(height, width, CvType.CV_8UC3); - byte[] data = ((DataBufferByte) bi.getRaster().getDataBuffer()).getData(); - mat.put(0, 0, data); - return mat; - } // 可选的预处理方法 - public Map preprocessImage(Mat image) { - int targetWidth = 640; - int targetHeight = 640; - + public Map preprocessImage(Mat image) { int origWidth = image.width(); int origHeight = image.height(); + Map dynamicInput = new HashMap<>(); + //定义索引 + int index = 0; + for (InferenceEngine inferenceEngine : this.inferenceEngines) { + inferenceEngine.setIndex(index); + long[] inputShape = inferenceEngine.getInputShape(); + int targetWidth = (int) inputShape[2]; + int targetHeight = (int) inputShape[3]; + // 计算缩放因子 + float scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight); - // 计算缩放因子 
- float scalingFactor = Math.min((float) targetWidth / origWidth, (float) targetHeight / origHeight); - - // 计算新的图像尺寸 - int newWidth = Math.round(origWidth * scalingFactor); - int newHeight = Math.round(origHeight * scalingFactor); - - // 调整图像尺寸 - Mat resizedImage = new Mat(); - Imgproc.resize(image, resizedImage, new Size(newWidth, newHeight)); - - // 转换为 RGB 并归一化 - Imgproc.cvtColor(resizedImage, resizedImage, Imgproc.COLOR_BGR2RGB); - resizedImage.convertTo(resizedImage, CvType.CV_32FC3, 1.0 / 255.0); - - // 创建填充后的图像 - Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3); - int xOffset = (targetWidth - newWidth) / 2; - int yOffset = (targetHeight - newHeight) / 2; - Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight); - resizedImage.copyTo(paddedImage.submat(roi)); - - // 将图像数据转换为数组 - int imageSize = targetWidth * targetHeight; - float[] chwData = new float[3 * imageSize]; - float[] hwcData = new float[3 * imageSize]; - paddedImage.get(0, 0, hwcData); - - // 转换为 CHW 格式 - int channelSize = imageSize; - for (int c = 0; c < 3; c++) { - for (int i = 0; i < imageSize; i++) { - chwData[c * channelSize + i] = hwcData[i * 3 + c]; + //检查是否存在输入大小一致的 如果存在则跳过 + if (!dynamicInput.isEmpty()) { + for (Map.Entry entry : dynamicInput.entrySet()) { + Map input = (Map) entry.getValue(); + if (inputShape[2] == (long) input.get("targetHeight") || inputShape[3] == (long) input.get("targetWidth")) { + break; + } + } } + // 计算新的图像尺寸 + int newWidth = Math.round(origWidth * scalingFactor); + int newHeight = Math.round(origHeight * scalingFactor); + + // 调整图像尺寸 + Mat resizedImage = new Mat(); + Imgproc.resize(image, resizedImage, new Size(newWidth, newHeight), 0, 0, Imgproc.INTER_AREA); + + // 获取图像的尺寸 + int rows = resizedImage.rows(); + int cols = resizedImage.cols(); + // 准备存储浮点型数据的数组 + float[] floatData = new float[rows * cols * 3]; + + // 获取原始字节数据 + byte[] pixelData = new byte[rows * cols * 3]; + resizedImage.get(0, 0, pixelData); + + // 手动处理像素数据 + for (int i = 0; i < rows * cols; i++) { + int byteIndex = i * 3; + int floatIndex = i * 3; + // 读取 BGR 值并转换为 0.0 - 1.0 之间的浮点数 + float b = (pixelData[byteIndex] & 0xFF) / 255.0f; + float g = (pixelData[byteIndex + 1] & 0xFF) / 255.0f; + float r = (pixelData[byteIndex + 2] & 0xFF) / 255.0f; + // 将 BGR 转换为 RGB,并存储到浮点数组中 + floatData[floatIndex] = r; + floatData[floatIndex + 1] = g; + floatData[floatIndex + 2] = b; + } + + // 将浮点数组转换回 Mat 对象 + Mat floatImage = new Mat(rows, cols, CvType.CV_32FC3); + floatImage.put(0, 0, floatData); + + resizedImage = floatImage; + + // 创建填充后的图像 + Mat paddedImage = Mat.zeros(new Size(targetWidth, targetHeight), CvType.CV_32FC3); + int xOffset = (targetWidth - newWidth) / 2; + int yOffset = (targetHeight - newHeight) / 2; + Rect roi = new Rect(xOffset, yOffset, newWidth, newHeight); + resizedImage.copyTo(paddedImage.submat(roi)); + + // 将图像数据转换为数组 + int imageSize = targetWidth * targetHeight; + float[] chwData = new float[3 * imageSize]; + float[] hwcData = new float[3 * imageSize]; + paddedImage.get(0, 0, hwcData); + + // 转换为 CHW 格式 + int channelSize = imageSize; + for (int c = 0; c < 3; c++) { + for (int i = 0; i < imageSize; i++) { + chwData[c * channelSize + i] = hwcData[i * 3 + c]; + } + } + // 释放图像资源 + resizedImage.release(); + paddedImage.release(); + // 将预处理结果和偏移信息存入 Map + Map result = new HashMap<>(); + result.put("inputData", chwData); + result.put("origWidth", origWidth); + result.put("origHeight", origHeight); + result.put("scalingFactor", scalingFactor); + result.put("xOffset", xOffset); + 
result.put("yOffset", yOffset); + dynamicInput.put(index, result); + index++; } - - // 释放图像资源 - resizedImage.release(); - paddedImage.release(); - - // 将预处理结果和偏移信息存入 Map - Map result = new HashMap<>(); - result.put("inputData", chwData); - result.put("origWidth", origWidth); - result.put("origHeight", origHeight); - result.put("scalingFactor", scalingFactor); - result.put("xOffset", xOffset); - result.put("yOffset", yOffset); - - return result; + return dynamicInput; } + }