优化加载模型慢,首次推理慢

This commit is contained in:
sulv 2024-10-10 20:11:19 +08:00
parent 7a88b08500
commit c2dd067a56
2 changed files with 27 additions and 2 deletions

View File

@ -130,7 +130,6 @@ public class VideoInferenceApp extends JFrame {
modelManager.loadModel(this);
DefaultListModel<ModelInfo> modelList = modelManager.getModelList();
ArrayList<ModelInfo> models = Collections.list(modelList.elements());
for (ModelInfo modelInfo : models) {
if (modelInfo != null) {
boolean alreadyAdded = videoPlayer.getInferenceEngines().stream()

View File

@ -12,6 +12,9 @@ import org.opencv.imgproc.Imgproc;
import java.nio.FloatBuffer;
import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
@Data
public class InferenceEngine {
@ -36,7 +39,30 @@ public class InferenceEngine {
/**
 * Creates an engine for the given model and kicks off loading in the background.
 *
 * <p>Initialization is delegated entirely to {@link #initAsync()}, which runs
 * {@code init()} followed by {@code warmUp()} on a worker thread. Calling
 * {@code init()} here as well would load the model twice and block the caller
 * for the duration of the first load.
 *
 * @param modelPath path of the model to load
 * @param labels    labels associated with the model
 */
public InferenceEngine(String modelPath, List<String> labels) {
    this.modelPath = modelPath;
    this.labels = labels;
    initAsync();
}
/**
 * Runs model initialization asynchronously: {@code init()} followed by
 * {@code warmUp()} on a dedicated single worker thread.
 *
 * <p>The executor is shut down right after the task is submitted. An orderly
 * shutdown still lets the submitted task run to completion, but releases the
 * worker thread afterwards instead of leaking one idle thread per call.
 */
public void initAsync() {
    ExecutorService executor = Executors.newSingleThreadExecutor();
    try {
        executor.execute(() -> {
            init();
            warmUp();
        });
    } finally {
        // shutdown() does not cancel the task queued above; it only stops the
        // executor from accepting new work and lets the thread exit when done.
        executor.shutdown();
    }
}
/**
 * Runs one inference on an all-zero input so that one-time setup costs
 * (e.g. the CUDA context) are paid before the first real inference.
 *
 * <p>The dummy buffer is sized from every entry of {@code inputShape}.
 * Both the input tensor and the run result are closed via try-with-resources,
 * including when the run throws.
 *
 * @throws RuntimeException if the shape is not fully positive, the element
 *                          count overflows an {@code int}, or the run fails
 */
public void warmUp() {
    try {
        // Element count = product of all dimensions, checked for overflow.
        int elementCount = 1;
        for (long dim : inputShape) {
            if (dim <= 0) {
                // A non-positive dimension (presumably a dynamic axis) cannot size a buffer.
                throw new IllegalStateException(
                        "input shape has a non-positive dimension: " + Arrays.toString(inputShape));
            }
            elementCount = Math.toIntExact(Math.multiplyExact((long) elementCount, dim));
        }
        float[] dummyInput = new float[elementCount];
        String inputName = session.getInputInfo().keySet().iterator().next();

        // The run result is declared as AutoCloseable only so it gets closed;
        // its contents are irrelevant for a warm-up.
        try (OnnxTensor inputTensor =
                     OnnxTensor.createTensor(environment, FloatBuffer.wrap(dummyInput), inputShape);
             AutoCloseable warmUpResult =
                     session.run(Collections.singletonMap(inputName, inputTensor))) {
            System.out.println("预热推理完成,首次推理性能已优化。");
        }
    } catch (Exception e) {
        throw new RuntimeException("模型预热失败", e);
    }
}
public void init() {