SpireCVDet supports adjusting the batch-size of outputing engines

This commit is contained in:
jario 2023-12-06 19:49:10 +08:00
parent 43deec9daa
commit 493c40fb73
2 changed files with 52 additions and 42 deletions

View File

@ -18,9 +18,9 @@ const static int kInputH_HD = 1280;
const static int kInputW_HD = 1280;
const static int kOutputSize = kMaxNumOutputBbox * sizeof(Detection) / sizeof(float) + 1;
bool parse_args(int argc, char** argv, std::string& wts, std::string& engine, bool& is_p6, float& gd, float& gw, std::string& img_dir, int& n_classes) {
bool parse_args(int argc, char** argv, std::string& wts, std::string& engine, bool& is_p6, float& gd, float& gw, std::string& img_dir, int& n_classes, int& n_batch) {
if (argc < 4) return false;
if (std::string(argv[1]) == "-s" && (argc == 6 || argc == 8)) {
if (std::string(argv[1]) == "-s" && (argc == 6 || argc == 7 || argc == 8)) {
wts = std::string(argv[2]);
engine = std::string(argv[3]);
n_classes = atoi(argv[4]);
@ -51,6 +51,9 @@ bool parse_args(int argc, char** argv, std::string& wts, std::string& engine, bo
if (net.size() == 2 && net[1] == '6') {
is_p6 = true;
}
if (argc == 7) {
n_batch = atoi(argv[6]);
}
} else {
return false;
}
@ -99,18 +102,20 @@ int main(int argc, char** argv) {
float gd = 0.0f, gw = 0.0f;
std::string img_dir;
int n_classes;
int n_batch = 1;
if (!parse_args(argc, argv, wts_name, engine_name, is_p6, gd, gw, img_dir, n_classes)) {
if (!parse_args(argc, argv, wts_name, engine_name, is_p6, gd, gw, img_dir, n_classes, n_batch)) {
std::cerr << "arguments not right!" << std::endl;
std::cerr << "./SpireCVDet -s [.wts] [.engine] n_classes [n/s/m/l/x/n6/s6/m6/l6/x6 or c/c6 gd gw] // serialize model to plan file" << std::endl;
// std::cerr << "./SpireCVDet -d [.engine] ../images // deserialize plan file and run inference" << std::endl;
return -1;
}
std::cout << "n_classes: " << n_classes << std::endl;
std::cout << "max_batch: " << n_batch << std::endl;
// Create a model using the API directly and serialize it to a file
if (!wts_name.empty()) {
serialize_engine(kBatchSize, is_p6, gd, gw, wts_name, engine_name, n_classes);
serialize_engine(n_batch, is_p6, gd, gw, wts_name, engine_name, n_classes);
return 0;
}

View File

@ -18,44 +18,47 @@ const static int kInputW = 640;
const static int kOutputSize1 = kMaxNumOutputBbox * sizeof(Detection) / sizeof(float) + 1;
const static int kOutputSize2 = 32 * (kInputH / 4) * (kInputW / 4);
bool parse_args(int argc, char** argv, std::string& wts, std::string& engine, float& gd, float& gw, std::string& img_dir, std::string& labels_filename, int& n_classes) {
if (argc < 4) return false;
if (std::string(argv[1]) == "-s" && (argc == 6 || argc == 8)) {
wts = std::string(argv[2]);
engine = std::string(argv[3]);
n_classes = atoi(argv[4]);
if (n_classes < 1)
return false;
auto net = std::string(argv[5]);
if (net[0] == 'n') {
gd = 0.33;
gw = 0.25;
} else if (net[0] == 's') {
gd = 0.33;
gw = 0.50;
} else if (net[0] == 'm') {
gd = 0.67;
gw = 0.75;
} else if (net[0] == 'l') {
gd = 1.0;
gw = 1.0;
} else if (net[0] == 'x') {
gd = 1.33;
gw = 1.25;
} else if (net[0] == 'c' && argc == 8) {
gd = atof(argv[6]);
gw = atof(argv[7]);
} else {
return false;
}
} else if (std::string(argv[1]) == "-d" && argc == 5) {
engine = std::string(argv[2]);
img_dir = std::string(argv[3]);
labels_filename = std::string(argv[4]);
bool parse_args(int argc, char** argv, std::string& wts, std::string& engine, float& gd, float& gw, std::string& img_dir, std::string& labels_filename, int& n_classes, int& n_batch) {
if (argc < 4) return false;
if (std::string(argv[1]) == "-s" && (argc == 6 || argc == 7 || argc == 8)) {
wts = std::string(argv[2]);
engine = std::string(argv[3]);
n_classes = atoi(argv[4]);
if (n_classes < 1)
return false;
auto net = std::string(argv[5]);
if (net[0] == 'n') {
gd = 0.33;
gw = 0.25;
} else if (net[0] == 's') {
gd = 0.33;
gw = 0.50;
} else if (net[0] == 'm') {
gd = 0.67;
gw = 0.75;
} else if (net[0] == 'l') {
gd = 1.0;
gw = 1.0;
} else if (net[0] == 'x') {
gd = 1.33;
gw = 1.25;
} else if (net[0] == 'c' && argc == 8) {
gd = atof(argv[6]);
gw = atof(argv[7]);
} else {
return false;
return false;
}
return true;
if (argc == 7) {
n_batch = atoi(argv[6]);
}
} else if (std::string(argv[1]) == "-d" && argc == 5) {
engine = std::string(argv[2]);
img_dir = std::string(argv[3]);
labels_filename = std::string(argv[4]);
} else {
return false;
}
return true;
}
@ -98,19 +101,21 @@ int main(int argc, char** argv) {
std::string labels_filename = "";
float gd = 0.0f, gw = 0.0f;
int n_classes;
int n_batch = 1;
std::string img_dir;
if (!parse_args(argc, argv, wts_name, engine_name, gd, gw, img_dir, labels_filename, n_classes)) {
if (!parse_args(argc, argv, wts_name, engine_name, gd, gw, img_dir, labels_filename, n_classes, n_batch)) {
std::cerr << "arguments not right!" << std::endl;
std::cerr << "./SpireCVSeg -s [.wts] [.engine] n_classes [n/s/m/l/x or c gd gw] // serialize model to plan file" << std::endl;
// std::cerr << "./SpireCVSeg -d [.engine] ../images coco.txt // deserialize plan file, read the labels file and run inference" << std::endl;
return -1;
}
std::cout << "n_classes: " << n_classes << std::endl;
std::cout << "max_batch: " << n_batch << std::endl;
// Create a model using the API directly and serialize it to a file
if (!wts_name.empty()) {
serialize_engine(kBatchSize, gd, gw, wts_name, engine_name, n_classes);
serialize_engine(n_batch, gd, gw, wts_name, engine_name, n_classes);
return 0;
}