From 493c40fb737b6610524fc64844dc9f791121cd23 Mon Sep 17 00:00:00 2001 From: jario Date: Wed, 6 Dec 2023 19:49:10 +0800 Subject: [PATCH] SpireCVDet supports adjusting the batch-size of outputing engines --- samples/SpireCVDet.cpp | 13 ++++--- samples/SpireCVSeg.cpp | 81 ++++++++++++++++++++++-------------------- 2 files changed, 52 insertions(+), 42 deletions(-) diff --git a/samples/SpireCVDet.cpp b/samples/SpireCVDet.cpp index 0a7353e..2825035 100644 --- a/samples/SpireCVDet.cpp +++ b/samples/SpireCVDet.cpp @@ -18,9 +18,9 @@ const static int kInputH_HD = 1280; const static int kInputW_HD = 1280; const static int kOutputSize = kMaxNumOutputBbox * sizeof(Detection) / sizeof(float) + 1; -bool parse_args(int argc, char** argv, std::string& wts, std::string& engine, bool& is_p6, float& gd, float& gw, std::string& img_dir, int& n_classes) { +bool parse_args(int argc, char** argv, std::string& wts, std::string& engine, bool& is_p6, float& gd, float& gw, std::string& img_dir, int& n_classes, int& n_batch) { if (argc < 4) return false; - if (std::string(argv[1]) == "-s" && (argc == 6 || argc == 8)) { + if (std::string(argv[1]) == "-s" && (argc == 6 || argc == 7 || argc == 8)) { wts = std::string(argv[2]); engine = std::string(argv[3]); n_classes = atoi(argv[4]); @@ -51,6 +51,9 @@ bool parse_args(int argc, char** argv, std::string& wts, std::string& engine, bo if (net.size() == 2 && net[1] == '6') { is_p6 = true; } + if (argc == 7) { + n_batch = atoi(argv[6]); + } } else { return false; } @@ -99,18 +102,20 @@ int main(int argc, char** argv) { float gd = 0.0f, gw = 0.0f; std::string img_dir; int n_classes; + int n_batch = 1; - if (!parse_args(argc, argv, wts_name, engine_name, is_p6, gd, gw, img_dir, n_classes)) { + if (!parse_args(argc, argv, wts_name, engine_name, is_p6, gd, gw, img_dir, n_classes, n_batch)) { std::cerr << "arguments not right!" << std::endl; std::cerr << "./SpireCVDet -s [.wts] [.engine] n_classes [n/s/m/l/x/n6/s6/m6/l6/x6 or c/c6 gd gw] // serialize model to plan file" << std::endl; // std::cerr << "./SpireCVDet -d [.engine] ../images // deserialize plan file and run inference" << std::endl; return -1; } std::cout << "n_classes: " << n_classes << std::endl; + std::cout << "max_batch: " << n_batch << std::endl; // Create a model using the API directly and serialize it to a file if (!wts_name.empty()) { - serialize_engine(kBatchSize, is_p6, gd, gw, wts_name, engine_name, n_classes); + serialize_engine(n_batch, is_p6, gd, gw, wts_name, engine_name, n_classes); return 0; } diff --git a/samples/SpireCVSeg.cpp b/samples/SpireCVSeg.cpp index ebc1ebc..d8e0c98 100644 --- a/samples/SpireCVSeg.cpp +++ b/samples/SpireCVSeg.cpp @@ -18,44 +18,47 @@ const static int kInputW = 640; const static int kOutputSize1 = kMaxNumOutputBbox * sizeof(Detection) / sizeof(float) + 1; const static int kOutputSize2 = 32 * (kInputH / 4) * (kInputW / 4); -bool parse_args(int argc, char** argv, std::string& wts, std::string& engine, float& gd, float& gw, std::string& img_dir, std::string& labels_filename, int& n_classes) { - if (argc < 4) return false; - if (std::string(argv[1]) == "-s" && (argc == 6 || argc == 8)) { - wts = std::string(argv[2]); - engine = std::string(argv[3]); - n_classes = atoi(argv[4]); - if (n_classes < 1) - return false; - auto net = std::string(argv[5]); - if (net[0] == 'n') { - gd = 0.33; - gw = 0.25; - } else if (net[0] == 's') { - gd = 0.33; - gw = 0.50; - } else if (net[0] == 'm') { - gd = 0.67; - gw = 0.75; - } else if (net[0] == 'l') { - gd = 1.0; - gw = 1.0; - } else if (net[0] == 'x') { - gd = 1.33; - gw = 1.25; - } else if (net[0] == 'c' && argc == 8) { - gd = atof(argv[6]); - gw = atof(argv[7]); - } else { - return false; - } - } else if (std::string(argv[1]) == "-d" && argc == 5) { - engine = std::string(argv[2]); - img_dir = std::string(argv[3]); - labels_filename = std::string(argv[4]); +bool parse_args(int argc, char** argv, std::string& wts, std::string& engine, float& gd, float& gw, std::string& img_dir, std::string& labels_filename, int& n_classes, int& n_batch) { + if (argc < 4) return false; + if (std::string(argv[1]) == "-s" && (argc == 6 || argc == 7 || argc == 8)) { + wts = std::string(argv[2]); + engine = std::string(argv[3]); + n_classes = atoi(argv[4]); + if (n_classes < 1) + return false; + auto net = std::string(argv[5]); + if (net[0] == 'n') { + gd = 0.33; + gw = 0.25; + } else if (net[0] == 's') { + gd = 0.33; + gw = 0.50; + } else if (net[0] == 'm') { + gd = 0.67; + gw = 0.75; + } else if (net[0] == 'l') { + gd = 1.0; + gw = 1.0; + } else if (net[0] == 'x') { + gd = 1.33; + gw = 1.25; + } else if (net[0] == 'c' && argc == 8) { + gd = atof(argv[6]); + gw = atof(argv[7]); } else { - return false; + return false; } - return true; + if (argc == 7) { + n_batch = atoi(argv[6]); + } + } else if (std::string(argv[1]) == "-d" && argc == 5) { + engine = std::string(argv[2]); + img_dir = std::string(argv[3]); + labels_filename = std::string(argv[4]); + } else { + return false; + } + return true; } @@ -98,19 +101,21 @@ int main(int argc, char** argv) { std::string labels_filename = ""; float gd = 0.0f, gw = 0.0f; int n_classes; + int n_batch = 1; std::string img_dir; - if (!parse_args(argc, argv, wts_name, engine_name, gd, gw, img_dir, labels_filename, n_classes)) { + if (!parse_args(argc, argv, wts_name, engine_name, gd, gw, img_dir, labels_filename, n_classes, n_batch)) { std::cerr << "arguments not right!" << std::endl; std::cerr << "./SpireCVSeg -s [.wts] [.engine] n_classes [n/s/m/l/x or c gd gw] // serialize model to plan file" << std::endl; // std::cerr << "./SpireCVSeg -d [.engine] ../images coco.txt // deserialize plan file, read the labels file and run inference" << std::endl; return -1; } std::cout << "n_classes: " << n_classes << std::endl; + std::cout << "max_batch: " << n_batch << std::endl; // Create a model using the API directly and serialize it to a file if (!wts_name.empty()) { - serialize_engine(kBatchSize, gd, gw, wts_name, engine_name, n_classes); + serialize_engine(n_batch, gd, gw, wts_name, engine_name, n_classes); return 0; }