onnx-inference4j-play/src/main/java/com/ly/onnx/engine/InferenceEngine.java

231 lines
9.1 KiB
Java
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package com.ly.onnx.engine;
import ai.onnxruntime.*;
import com.alibaba.fastjson.JSON;
import com.ly.onnx.model.BoundingBox;
import com.ly.onnx.model.InferenceResult;
import lombok.Data;
import org.opencv.core.*;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import java.nio.FloatBuffer;
import java.util.*;
@Data
public class InferenceEngine {
private OrtEnvironment environment;
private OrtSession.SessionOptions sessionOptions;
private OrtSession session;
private String modelPath;
private List<String> labels;
//preprocessParams输入数据的索引
private int index;
// 用于存储图像预处理信息的类变量
private long[] inputShape = null;
static {
nu.pattern.OpenCV.loadLocally();
}
public InferenceEngine(String modelPath, List<String> labels) {
this.modelPath = modelPath;
this.labels = labels;
init();
}
public void init() {
try {
environment = OrtEnvironment.getEnvironment();
sessionOptions = new OrtSession.SessionOptions();
sessionOptions.addCUDA(0); // 使用 GPU
session = environment.createSession(modelPath, sessionOptions);
Map<String, NodeInfo> inputInfo = session.getInputInfo();
NodeInfo nodeInfo = inputInfo.values().iterator().next();
TensorInfo tensorInfo = (TensorInfo) nodeInfo.getInfo();
inputShape = tensorInfo.getShape(); // 从模型中获取输入形状
logModelInfo(session);
} catch (OrtException e) {
throw new RuntimeException("模型加载失败", e);
}
}
public InferenceResult infer(Map<Integer, Object> preprocessParams) {
long startTime = System.currentTimeMillis();
//获取对模型需要的输入大小
Map<String, Object> params = (Map<String, Object>) preprocessParams.get(index);
// 从 Map 中获取偏移相关的变量
float[] inputData = (float[]) params.get("inputData");
int origWidth = (int) params.get("origWidth");
int origHeight = (int) params.get("origHeight");
float scalingFactor = (float) params.get("scalingFactor");
int xOffset = (int) params.get("xOffset");
int yOffset = (int) params.get("yOffset");
try {
Map<String, NodeInfo> inputInfo = session.getInputInfo();
String inputName = inputInfo.keySet().iterator().next(); // 假设只有一个输入
// 创建输入张量时,使用 CHW 格式的数据
OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), inputShape);
// 执行推理
long inferenceStart = System.currentTimeMillis();
OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor));
long inferenceEnd = System.currentTimeMillis();
System.out.println("模型推理耗时:" + (inferenceEnd - inferenceStart) + " ms");
// 解析推理结果
String outputName = session.getOutputInfo().keySet().iterator().next(); // 假设只有一个输出
float[][][] outputData = (float[][][]) result.get(outputName).get().getValue(); // 输出形状:[1, N, 6]
// 设定置信度阈值
float confidenceThreshold = 0.25f; // 您可以根据需要调整
// 根据模型的输出结果解析边界框
List<BoundingBox> boxes = new ArrayList<>();
for (float[] data : outputData[0]) { // 遍历所有检测框
// 根据模型输出格式,解析中心坐标和宽高
float x_center = data[0];
float y_center = data[1];
float width = data[2];
float height = data[3];
float confidence = data[4];
if (confidence >= confidenceThreshold) {
// 将中心坐标转换为左上角和右下角坐标
float x1 = x_center - width / 2;
float y1 = y_center - height / 2;
float x2 = x_center + width / 2;
float y2 = y_center + height / 2;
// 调整坐标,减去偏移并除以缩放因子
float x1Adjusted = (x1 - xOffset) / scalingFactor;
float y1Adjusted = (y1 - yOffset) / scalingFactor;
float x2Adjusted = (x2 - xOffset) / scalingFactor;
float y2Adjusted = (y2 - yOffset) / scalingFactor;
// 确保坐标的正确顺序
float xMinAdjusted = Math.min(x1Adjusted, x2Adjusted);
float xMaxAdjusted = Math.max(x1Adjusted, x2Adjusted);
float yMinAdjusted = Math.min(y1Adjusted, y2Adjusted);
float yMaxAdjusted = Math.max(y1Adjusted, y2Adjusted);
// 确保坐标在原始图像范围内
int x = (int) Math.max(0, xMinAdjusted);
int y = (int) Math.max(0, yMinAdjusted);
int xMax = (int) Math.min(origWidth, xMaxAdjusted);
int yMax = (int) Math.min(origHeight, yMaxAdjusted);
int wBox = xMax - x;
int hBox = yMax - y;
// 仅当宽度和高度为正时,才添加边界框
if (wBox > 0 && hBox > 0) {
// 使用您的单一标签
String label = labels.get(0);
boxes.add(new BoundingBox(x, y, wBox, hBox, label, confidence));
}
}
}
// 非极大值抑制NMS
long nmsStart = System.currentTimeMillis();
List<BoundingBox> nmsBoxes = nonMaximumSuppression(boxes, 0.5f);
System.out.println("检测到的标签:" + JSON.toJSONString(nmsBoxes));
if (!nmsBoxes.isEmpty()) {
for (BoundingBox box : nmsBoxes) {
System.out.println(box);
}
}
long nmsEnd = System.currentTimeMillis();
System.out.println("NMS 耗时:" + (nmsEnd - nmsStart) + " ms");
// 封装结果并返回
InferenceResult inferenceResult = new InferenceResult();
inferenceResult.setBoundingBoxes(nmsBoxes);
long endTime = System.currentTimeMillis();
System.out.println("一次推理总耗时:" + (endTime - startTime) + " ms");
return inferenceResult;
} catch (OrtException e) {
throw new RuntimeException("推理失败", e);
}
}
// 计算两个边界框的 IoU
private float computeIoU(BoundingBox box1, BoundingBox box2) {
int x1 = Math.max(box1.getX(), box2.getX());
int y1 = Math.max(box1.getY(), box2.getY());
int x2 = Math.min(box1.getX() + box1.getWidth(), box2.getX() + box2.getWidth());
int y2 = Math.min(box1.getY() + box1.getHeight(), box2.getY() + box2.getHeight());
int intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
int box1Area = box1.getWidth() * box1.getHeight();
int box2Area = box2.getWidth() * box2.getHeight();
return (float) intersectionArea / (box1Area + box2Area - intersectionArea);
}
// 非极大值抑制NMS方法
private List<BoundingBox> nonMaximumSuppression(List<BoundingBox> boxes, float iouThreshold) {
// 按置信度排序(从高到低)
boxes.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence()));
List<BoundingBox> result = new ArrayList<>();
while (!boxes.isEmpty()) {
BoundingBox bestBox = boxes.remove(0);
result.add(bestBox);
Iterator<BoundingBox> iterator = boxes.iterator();
while (iterator.hasNext()) {
BoundingBox box = iterator.next();
if (computeIoU(bestBox, box) > iouThreshold) {
iterator.remove();
}
}
}
return result;
}
// 打印模型信息
private void logModelInfo(OrtSession session) {
System.out.println("模型输入信息:");
try {
for (Map.Entry<String, NodeInfo> entry : session.getInputInfo().entrySet()) {
String name = entry.getKey();
NodeInfo info = entry.getValue();
System.out.println("输入名称: " + name);
System.out.println("输入信息: " + info.toString());
}
} catch (OrtException e) {
throw new RuntimeException(e);
}
System.out.println("模型输出信息:");
try {
for (Map.Entry<String, NodeInfo> entry : session.getOutputInfo().entrySet()) {
String name = entry.getKey();
NodeInfo info = entry.getValue();
System.out.println("输出名称: " + name);
System.out.println("输出信息: " + info.toString());
}
} catch (OrtException e) {
throw new RuntimeException(e);
}
}
}