Common Object Detector supports input of HR images

This commit is contained in:
jario 2023-10-27 20:42:14 +08:00
parent c46fe93dbb
commit 47bb722038
4 changed files with 179 additions and 55 deletions

View File

@ -103,7 +103,7 @@ void infer_seg(IExecutionContext& context, cudaStream_t& stream, void **buffers,
CUDA_CHECK(cudaMemcpyAsync(output2, buffers[2], batchSize * kOutputSize2 * sizeof(float), cudaMemcpyDeviceToHost, stream)); CUDA_CHECK(cudaMemcpyAsync(output2, buffers[2], batchSize * kOutputSize2 * sizeof(float), cudaMemcpyDeviceToHost, stream));
cudaStreamSynchronize(stream); cudaStreamSynchronize(stream);
} }
void CommonObjectDetectorCUDAImpl::_prepare_buffers(int input_h, int input_w) { void CommonObjectDetectorCUDAImpl::_prepare_buffers(int input_h, int input_w, int batchsize) {
assert(this->_engine->getNbBindings() == 2); assert(this->_engine->getNbBindings() == 2);
// In order to bind the buffers, we need to know the names of the input and output tensors. // In order to bind the buffers, we need to know the names of the input and output tensors.
// Note that indices are guaranteed to be less than IEngine::getNbBindings() // Note that indices are guaranteed to be less than IEngine::getNbBindings()
@ -112,12 +112,12 @@ void CommonObjectDetectorCUDAImpl::_prepare_buffers(int input_h, int input_w) {
assert(inputIndex == 0); assert(inputIndex == 0);
assert(outputIndex == 1); assert(outputIndex == 1);
// Create GPU buffers on device // Create GPU buffers on device
CUDA_CHECK(cudaMalloc((void**)&(this->_gpu_buffers[0]), kBatchSize * 3 * input_h * input_w * sizeof(float))); CUDA_CHECK(cudaMalloc((void**)&(this->_gpu_buffers[0]), batchsize * 3 * input_h * input_w * sizeof(float)));
CUDA_CHECK(cudaMalloc((void**)&(this->_gpu_buffers[1]), kBatchSize * kOutputSize * sizeof(float))); CUDA_CHECK(cudaMalloc((void**)&(this->_gpu_buffers[1]), batchsize * kOutputSize * sizeof(float)));
this->_cpu_output_buffer = new float[kBatchSize * kOutputSize]; this->_cpu_output_buffer = new float[batchsize * kOutputSize];
} }
void CommonObjectDetectorCUDAImpl::_prepare_buffers_seg(int input_h, int input_w) { void CommonObjectDetectorCUDAImpl::_prepare_buffers_seg(int input_h, int input_w, int batchsize) {
assert(this->_engine->getNbBindings() == 3); assert(this->_engine->getNbBindings() == 3);
// In order to bind the buffers, we need to know the names of the input and output tensors. // In order to bind the buffers, we need to know the names of the input and output tensors.
// Note that indices are guaranteed to be less than IEngine::getNbBindings() // Note that indices are guaranteed to be less than IEngine::getNbBindings()
@ -129,13 +129,13 @@ void CommonObjectDetectorCUDAImpl::_prepare_buffers_seg(int input_h, int input_w
assert(outputIndex2 == 2); assert(outputIndex2 == 2);
// Create GPU buffers on device // Create GPU buffers on device
CUDA_CHECK(cudaMalloc((void**)&(this->_gpu_buffers[0]), kBatchSize * 3 * input_h * input_w * sizeof(float))); CUDA_CHECK(cudaMalloc((void**)&(this->_gpu_buffers[0]), batchsize * 3 * input_h * input_w * sizeof(float)));
CUDA_CHECK(cudaMalloc((void**)&(this->_gpu_buffers[1]), kBatchSize * kOutputSize1 * sizeof(float))); CUDA_CHECK(cudaMalloc((void**)&(this->_gpu_buffers[1]), batchsize * kOutputSize1 * sizeof(float)));
CUDA_CHECK(cudaMalloc((void**)&(this->_gpu_buffers[2]), kBatchSize * kOutputSize2 * sizeof(float))); CUDA_CHECK(cudaMalloc((void**)&(this->_gpu_buffers[2]), batchsize * kOutputSize2 * sizeof(float)));
// Alloc CPU buffers // Alloc CPU buffers
this->_cpu_output_buffer1 = new float[kBatchSize * kOutputSize1]; this->_cpu_output_buffer1 = new float[batchsize * kOutputSize1];
this->_cpu_output_buffer2 = new float[kBatchSize * kOutputSize2]; this->_cpu_output_buffer2 = new float[batchsize * kOutputSize2];
} }
void deserialize_engine(std::string& engine_name, IRuntime** runtime, ICudaEngine** engine, IExecutionContext** context) { void deserialize_engine(std::string& engine_name, IRuntime** runtime, ICudaEngine** engine, IExecutionContext** context) {
std::ifstream file(engine_name, std::ios::binary); std::ifstream file(engine_name, std::ios::binary);
@ -172,7 +172,8 @@ void CommonObjectDetectorCUDAImpl::cudaDetect(
std::vector<float>& boxes_h_, std::vector<float>& boxes_h_,
std::vector<int>& boxes_label_, std::vector<int>& boxes_label_,
std::vector<float>& boxes_score_, std::vector<float>& boxes_score_,
std::vector<cv::Mat>& boxes_seg_ std::vector<cv::Mat>& boxes_seg_,
bool input_4k_
) )
{ {
#ifdef WITH_CUDA #ifdef WITH_CUDA
@ -183,9 +184,51 @@ void CommonObjectDetectorCUDAImpl::cudaDetect(
double thrs_nms = base_->getThrsNms(); double thrs_nms = base_->getThrsNms();
std::vector<cv::Mat> img_batch; std::vector<cv::Mat> img_batch;
img_batch.push_back(img_); if (input_4k_)
// Preprocess {
cuda_batch_preprocess(img_batch, this->_gpu_buffers[0], input_w, input_h, this->_stream); if (img_.cols == 3840 && img_.rows == 2160)
{
cv::Mat patch1, patch2, patch3, patch4, patch5, patch6;
img_.colRange(200, 1480).rowRange(0, 1280).copyTo(patch1);
img_.colRange(1280, 2560).rowRange(0, 1280).copyTo(patch2);
img_.colRange(2360, 3640).rowRange(0, 1280).copyTo(patch3);
img_.colRange(200, 1480).rowRange(880, 2160).copyTo(patch4);
img_.colRange(1280, 2560).rowRange(880, 2160).copyTo(patch5);
img_.colRange(2360, 3640).rowRange(880, 2160).copyTo(patch6);
img_batch.push_back(patch1);
img_batch.push_back(patch2);
img_batch.push_back(patch3);
img_batch.push_back(patch4);
img_batch.push_back(patch5);
img_batch.push_back(patch6);
}
else
{
throw std::runtime_error("SpireCV (106) Input image is NOT 4K (3840 x 2160)!");
}
if (with_segmentation)
{
throw std::runtime_error("SpireCV (106) Resolution 4K DO NOT Support Segmentation!");
}
}
else
{
img_batch.push_back(img_);
}
if (input_4k_)
{
// Preprocess
cuda_batch_preprocess(img_batch, this->_gpu_buffers[0], 1280, 1280, this->_stream);
}
else
{
// Preprocess
cuda_batch_preprocess(img_batch, this->_gpu_buffers[0], input_w, input_h, this->_stream);
}
// Run inference // Run inference
if (with_segmentation) if (with_segmentation)
@ -194,7 +237,14 @@ void CommonObjectDetectorCUDAImpl::cudaDetect(
} }
else else
{ {
infer(*this->_context, this->_stream, (void**)this->_gpu_buffers, this->_cpu_output_buffer, kBatchSize); if (input_4k_)
{
infer(*this->_context, this->_stream, (void**)this->_gpu_buffers, this->_cpu_output_buffer, 6);
}
else
{
infer(*this->_context, this->_stream, (void**)this->_gpu_buffers, this->_cpu_output_buffer, kBatchSize);
}
} }
// NMS // NMS
@ -208,45 +258,102 @@ void CommonObjectDetectorCUDAImpl::cudaDetect(
batch_nms(res_batch, this->_cpu_output_buffer, img_batch.size(), kOutputSize, thrs_conf, thrs_nms); batch_nms(res_batch, this->_cpu_output_buffer, img_batch.size(), kOutputSize, thrs_conf, thrs_nms);
} }
std::vector<Detection> res = res_batch[0];
std::vector<cv::Mat> masks; if (input_4k_)
if (with_segmentation)
{ {
masks = process_mask(&(this->_cpu_output_buffer2[0]), kOutputSize2, res, input_h, input_w); for (size_t k = 0; k < res_batch.size(); k++)
}
for (size_t j = 0; j < res.size(); j++) {
cv::Rect r = get_rect(img_, res[j].bbox, input_h, input_w);
if (r.x < 0) r.x = 0;
if (r.y < 0) r.y = 0;
if (r.x + r.width >= img_.cols) r.width = img_.cols - r.x - 1;
if (r.y + r.height >= img_.rows) r.height = img_.rows - r.y - 1;
if (r.width > 5 && r.height > 5)
{ {
// cv::rectangle(img_show, r, cv::Scalar(0, 0, 255), 2); std::vector<Detection> res = res_batch[k];
// cv::putText(img_show, vehiclenames[(int)res[j].class_id], cv::Point(r.x, r.y - 1), cv::FONT_HERSHEY_PLAIN, 2.2, cv::Scalar(0, 0, 255), 2); for (size_t j = 0; j < res.size(); j++)
boxes_x_.push_back(r.x);
boxes_y_.push_back(r.y);
boxes_w_.push_back(r.width);
boxes_h_.push_back(r.height);
boxes_label_.push_back((int)res[j].class_id);
boxes_score_.push_back(res[j].conf);
if (with_segmentation)
{ {
cv::Mat mask_j = masks[j].clone(); cv::Rect r = get_rect(img_batch[k], res[j].bbox, 1280, 1280);
boxes_seg_.push_back(mask_j); if (r.x < 0) r.x = 0;
if (r.y < 0) r.y = 0;
if (r.x + r.width >= 1280) r.width = 1280 - r.x - 1;
if (r.y + r.height >= 1280) r.height = 1280 - r.y - 1;
if (r.width > 3 && r.height > 3)
{
if (0 == k)
{
boxes_x_.push_back(r.x + 200);
boxes_y_.push_back(r.y);
}
else if (1 == k)
{
boxes_x_.push_back(r.x + 1280);
boxes_y_.push_back(r.y);
}
else if (2 == k)
{
boxes_x_.push_back(r.x + 2360);
boxes_y_.push_back(r.y);
}
else if (3 == k)
{
boxes_x_.push_back(r.x + 200);
boxes_y_.push_back(r.y + 880);
}
else if (4 == k)
{
boxes_x_.push_back(r.x + 1280);
boxes_y_.push_back(r.y + 880);
}
else if (5 == k)
{
boxes_x_.push_back(r.x + 2360);
boxes_y_.push_back(r.y + 880);
}
boxes_w_.push_back(r.width);
boxes_h_.push_back(r.height);
boxes_label_.push_back((int)res[j].class_id);
boxes_score_.push_back(res[j].conf);
}
} }
} }
} }
else
{
std::vector<Detection> res = res_batch[0];
std::vector<cv::Mat> masks;
if (with_segmentation)
{
masks = process_mask(&(this->_cpu_output_buffer2[0]), kOutputSize2, res, input_h, input_w);
}
for (size_t j = 0; j < res.size(); j++)
{
cv::Rect r = get_rect(img_, res[j].bbox, input_h, input_w);
if (r.x < 0) r.x = 0;
if (r.y < 0) r.y = 0;
if (r.x + r.width >= img_.cols) r.width = img_.cols - r.x - 1;
if (r.y + r.height >= img_.rows) r.height = img_.rows - r.y - 1;
if (r.width > 5 && r.height > 5)
{
// cv::rectangle(img_show, r, cv::Scalar(0, 0, 255), 2);
// cv::putText(img_show, vehiclenames[(int)res[j].class_id], cv::Point(r.x, r.y - 1), cv::FONT_HERSHEY_PLAIN, 2.2, cv::Scalar(0, 0, 255), 2);
boxes_x_.push_back(r.x);
boxes_y_.push_back(r.y);
boxes_w_.push_back(r.width);
boxes_h_.push_back(r.height);
boxes_label_.push_back((int)res[j].class_id);
boxes_score_.push_back(res[j].conf);
if (with_segmentation)
{
cv::Mat mask_j = masks[j].clone();
boxes_seg_.push_back(mask_j);
}
}
}
}
#endif #endif
} }
bool CommonObjectDetectorCUDAImpl::cudaSetup(CommonObjectDetectorBase* base_) bool CommonObjectDetectorCUDAImpl::cudaSetup(CommonObjectDetectorBase* base_, bool input_4k_)
{ {
#ifdef WITH_CUDA #ifdef WITH_CUDA
std::string dataset = base_->getDataset(); std::string dataset = base_->getDataset();
@ -273,6 +380,11 @@ bool CommonObjectDetectorCUDAImpl::cudaSetup(CommonObjectDetectorBase* base_)
throw std::runtime_error("SpireCV (104) Error loading the CommonObject TensorRT model (File Not Exist)"); throw std::runtime_error("SpireCV (104) Error loading the CommonObject TensorRT model (File Not Exist)");
} }
if (input_4k_ && with_segmentation)
{
throw std::runtime_error("SpireCV (106) Resolution 4K DO NOT Support Segmentation!");
}
deserialize_engine(engine_fn, &this->_runtime, &this->_engine, &this->_context); deserialize_engine(engine_fn, &this->_runtime, &this->_engine, &this->_context);
CUDA_CHECK(cudaStreamCreate(&this->_stream)); CUDA_CHECK(cudaStreamCreate(&this->_stream));
@ -282,12 +394,20 @@ bool CommonObjectDetectorCUDAImpl::cudaSetup(CommonObjectDetectorBase* base_)
if (with_segmentation) if (with_segmentation)
{ {
// Prepare cpu and gpu buffers // Prepare cpu and gpu buffers
this->_prepare_buffers_seg(input_h, input_w); this->_prepare_buffers_seg(input_h, input_w, 1);
} }
else else
{ {
// Prepare cpu and gpu buffers if (input_4k_)
this->_prepare_buffers(input_h, input_w); {
// Prepare cpu and gpu buffers
this->_prepare_buffers(input_h, input_w, 6);
}
else
{
// Prepare cpu and gpu buffers
this->_prepare_buffers(input_h, input_w, 1);
}
} }
return true; return true;
#endif #endif

View File

@ -26,7 +26,7 @@ public:
CommonObjectDetectorCUDAImpl(); CommonObjectDetectorCUDAImpl();
~CommonObjectDetectorCUDAImpl(); ~CommonObjectDetectorCUDAImpl();
bool cudaSetup(CommonObjectDetectorBase* base_); bool cudaSetup(CommonObjectDetectorBase* base_, bool input_4k_);
void cudaDetect( void cudaDetect(
CommonObjectDetectorBase* base_, CommonObjectDetectorBase* base_,
cv::Mat img_, cv::Mat img_,
@ -36,12 +36,13 @@ public:
std::vector<float>& boxes_h_, std::vector<float>& boxes_h_,
std::vector<int>& boxes_label_, std::vector<int>& boxes_label_,
std::vector<float>& boxes_score_, std::vector<float>& boxes_score_,
std::vector<cv::Mat>& boxes_seg_ std::vector<cv::Mat>& boxes_seg_,
bool input_4k_
); );
#ifdef WITH_CUDA #ifdef WITH_CUDA
void _prepare_buffers_seg(int input_h, int input_w); void _prepare_buffers_seg(int input_h, int input_w, int batchsize);
void _prepare_buffers(int input_h, int input_w); void _prepare_buffers(int input_h, int input_w, int batchsize);
nvinfer1::IExecutionContext* _context; nvinfer1::IExecutionContext* _context;
nvinfer1::IRuntime* _runtime; nvinfer1::IRuntime* _runtime;
nvinfer1::ICudaEngine* _engine; nvinfer1::ICudaEngine* _engine;

View File

@ -12,8 +12,9 @@
namespace sv { namespace sv {
CommonObjectDetector::CommonObjectDetector() CommonObjectDetector::CommonObjectDetector(bool input_4k)
{ {
this->_input_4k = input_4k;
#ifdef WITH_CUDA #ifdef WITH_CUDA
this->_cuda_impl = new CommonObjectDetectorCUDAImpl; this->_cuda_impl = new CommonObjectDetectorCUDAImpl;
#endif #endif
@ -25,7 +26,7 @@ CommonObjectDetector::~CommonObjectDetector()
bool CommonObjectDetector::setupImpl() bool CommonObjectDetector::setupImpl()
{ {
#ifdef WITH_CUDA #ifdef WITH_CUDA
return this->_cuda_impl->cudaSetup(this); return this->_cuda_impl->cudaSetup(this, this->_input_4k);
#endif #endif
return false; return false;
} }
@ -51,7 +52,8 @@ void CommonObjectDetector::detectImpl(
boxes_h_, boxes_h_,
boxes_label_, boxes_label_,
boxes_score_, boxes_score_,
boxes_seg_ boxes_seg_,
this->_input_4k
); );
#endif #endif
} }

View File

@ -16,7 +16,7 @@ class CommonObjectDetectorCUDAImpl;
class CommonObjectDetector : public CommonObjectDetectorBase class CommonObjectDetector : public CommonObjectDetectorBase
{ {
public: public:
CommonObjectDetector(); CommonObjectDetector(bool input_4k=false);
~CommonObjectDetector(); ~CommonObjectDetector();
protected: protected:
bool setupImpl(); bool setupImpl();
@ -32,6 +32,7 @@ protected:
); );
CommonObjectDetectorCUDAImpl* _cuda_impl; CommonObjectDetectorCUDAImpl* _cuda_impl;
bool _input_4k;
}; };