SpireCVDet supports adjusting the batch-size of outputing engines
This commit is contained in:
parent
43deec9daa
commit
493c40fb73
|
@ -18,9 +18,9 @@ const static int kInputH_HD = 1280;
|
|||
const static int kInputW_HD = 1280;
|
||||
const static int kOutputSize = kMaxNumOutputBbox * sizeof(Detection) / sizeof(float) + 1;
|
||||
|
||||
bool parse_args(int argc, char** argv, std::string& wts, std::string& engine, bool& is_p6, float& gd, float& gw, std::string& img_dir, int& n_classes) {
|
||||
bool parse_args(int argc, char** argv, std::string& wts, std::string& engine, bool& is_p6, float& gd, float& gw, std::string& img_dir, int& n_classes, int& n_batch) {
|
||||
if (argc < 4) return false;
|
||||
if (std::string(argv[1]) == "-s" && (argc == 6 || argc == 8)) {
|
||||
if (std::string(argv[1]) == "-s" && (argc == 6 || argc == 7 || argc == 8)) {
|
||||
wts = std::string(argv[2]);
|
||||
engine = std::string(argv[3]);
|
||||
n_classes = atoi(argv[4]);
|
||||
|
@ -51,6 +51,9 @@ bool parse_args(int argc, char** argv, std::string& wts, std::string& engine, bo
|
|||
if (net.size() == 2 && net[1] == '6') {
|
||||
is_p6 = true;
|
||||
}
|
||||
if (argc == 7) {
|
||||
n_batch = atoi(argv[6]);
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
@ -99,18 +102,20 @@ int main(int argc, char** argv) {
|
|||
float gd = 0.0f, gw = 0.0f;
|
||||
std::string img_dir;
|
||||
int n_classes;
|
||||
int n_batch = 1;
|
||||
|
||||
if (!parse_args(argc, argv, wts_name, engine_name, is_p6, gd, gw, img_dir, n_classes)) {
|
||||
if (!parse_args(argc, argv, wts_name, engine_name, is_p6, gd, gw, img_dir, n_classes, n_batch)) {
|
||||
std::cerr << "arguments not right!" << std::endl;
|
||||
std::cerr << "./SpireCVDet -s [.wts] [.engine] n_classes [n/s/m/l/x/n6/s6/m6/l6/x6 or c/c6 gd gw] // serialize model to plan file" << std::endl;
|
||||
// std::cerr << "./SpireCVDet -d [.engine] ../images // deserialize plan file and run inference" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
std::cout << "n_classes: " << n_classes << std::endl;
|
||||
std::cout << "max_batch: " << n_batch << std::endl;
|
||||
|
||||
// Create a model using the API directly and serialize it to a file
|
||||
if (!wts_name.empty()) {
|
||||
serialize_engine(kBatchSize, is_p6, gd, gw, wts_name, engine_name, n_classes);
|
||||
serialize_engine(n_batch, is_p6, gd, gw, wts_name, engine_name, n_classes);
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -18,44 +18,47 @@ const static int kInputW = 640;
|
|||
const static int kOutputSize1 = kMaxNumOutputBbox * sizeof(Detection) / sizeof(float) + 1;
|
||||
const static int kOutputSize2 = 32 * (kInputH / 4) * (kInputW / 4);
|
||||
|
||||
bool parse_args(int argc, char** argv, std::string& wts, std::string& engine, float& gd, float& gw, std::string& img_dir, std::string& labels_filename, int& n_classes) {
|
||||
if (argc < 4) return false;
|
||||
if (std::string(argv[1]) == "-s" && (argc == 6 || argc == 8)) {
|
||||
wts = std::string(argv[2]);
|
||||
engine = std::string(argv[3]);
|
||||
n_classes = atoi(argv[4]);
|
||||
if (n_classes < 1)
|
||||
return false;
|
||||
auto net = std::string(argv[5]);
|
||||
if (net[0] == 'n') {
|
||||
gd = 0.33;
|
||||
gw = 0.25;
|
||||
} else if (net[0] == 's') {
|
||||
gd = 0.33;
|
||||
gw = 0.50;
|
||||
} else if (net[0] == 'm') {
|
||||
gd = 0.67;
|
||||
gw = 0.75;
|
||||
} else if (net[0] == 'l') {
|
||||
gd = 1.0;
|
||||
gw = 1.0;
|
||||
} else if (net[0] == 'x') {
|
||||
gd = 1.33;
|
||||
gw = 1.25;
|
||||
} else if (net[0] == 'c' && argc == 8) {
|
||||
gd = atof(argv[6]);
|
||||
gw = atof(argv[7]);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
} else if (std::string(argv[1]) == "-d" && argc == 5) {
|
||||
engine = std::string(argv[2]);
|
||||
img_dir = std::string(argv[3]);
|
||||
labels_filename = std::string(argv[4]);
|
||||
bool parse_args(int argc, char** argv, std::string& wts, std::string& engine, float& gd, float& gw, std::string& img_dir, std::string& labels_filename, int& n_classes, int& n_batch) {
|
||||
if (argc < 4) return false;
|
||||
if (std::string(argv[1]) == "-s" && (argc == 6 || argc == 7 || argc == 8)) {
|
||||
wts = std::string(argv[2]);
|
||||
engine = std::string(argv[3]);
|
||||
n_classes = atoi(argv[4]);
|
||||
if (n_classes < 1)
|
||||
return false;
|
||||
auto net = std::string(argv[5]);
|
||||
if (net[0] == 'n') {
|
||||
gd = 0.33;
|
||||
gw = 0.25;
|
||||
} else if (net[0] == 's') {
|
||||
gd = 0.33;
|
||||
gw = 0.50;
|
||||
} else if (net[0] == 'm') {
|
||||
gd = 0.67;
|
||||
gw = 0.75;
|
||||
} else if (net[0] == 'l') {
|
||||
gd = 1.0;
|
||||
gw = 1.0;
|
||||
} else if (net[0] == 'x') {
|
||||
gd = 1.33;
|
||||
gw = 1.25;
|
||||
} else if (net[0] == 'c' && argc == 8) {
|
||||
gd = atof(argv[6]);
|
||||
gw = atof(argv[7]);
|
||||
} else {
|
||||
return false;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
if (argc == 7) {
|
||||
n_batch = atoi(argv[6]);
|
||||
}
|
||||
} else if (std::string(argv[1]) == "-d" && argc == 5) {
|
||||
engine = std::string(argv[2]);
|
||||
img_dir = std::string(argv[3]);
|
||||
labels_filename = std::string(argv[4]);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
@ -98,19 +101,21 @@ int main(int argc, char** argv) {
|
|||
std::string labels_filename = "";
|
||||
float gd = 0.0f, gw = 0.0f;
|
||||
int n_classes;
|
||||
int n_batch = 1;
|
||||
|
||||
std::string img_dir;
|
||||
if (!parse_args(argc, argv, wts_name, engine_name, gd, gw, img_dir, labels_filename, n_classes)) {
|
||||
if (!parse_args(argc, argv, wts_name, engine_name, gd, gw, img_dir, labels_filename, n_classes, n_batch)) {
|
||||
std::cerr << "arguments not right!" << std::endl;
|
||||
std::cerr << "./SpireCVSeg -s [.wts] [.engine] n_classes [n/s/m/l/x or c gd gw] // serialize model to plan file" << std::endl;
|
||||
// std::cerr << "./SpireCVSeg -d [.engine] ../images coco.txt // deserialize plan file, read the labels file and run inference" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
std::cout << "n_classes: " << n_classes << std::endl;
|
||||
std::cout << "max_batch: " << n_batch << std::endl;
|
||||
|
||||
// Create a model using the API directly and serialize it to a file
|
||||
if (!wts_name.empty()) {
|
||||
serialize_engine(kBatchSize, gd, gw, wts_name, engine_name, n_classes);
|
||||
serialize_engine(n_batch, gd, gw, wts_name, engine_name, n_classes);
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue