From 0e539348d9323fec53298d6dc3bc4e7d7be9151c Mon Sep 17 00:00:00 2001 From: CZC-123 <73985910+CZC-123@users.noreply.github.com> Date: Thu, 20 Jul 2023 15:32:26 +0800 Subject: [PATCH] finish SORT Tracker V0.1, czc Signed-off-by: CZC-123 <73985910+CZC-123@users.noreply.github.com> --- CMakeLists.txt | 48 +- samples/demo/SORT_2.cpp | 354 +++++++++++++ samples/demo/SORT_2.h | 62 +++ samples/demo/common_object_tracking_SORT.cpp | 188 +++++++ samples/demo/logging.h | 504 +++++++++++++++++++ samples/demo/macros.h | 29 ++ 6 files changed, 1169 insertions(+), 16 deletions(-) create mode 100644 samples/demo/SORT_2.cpp create mode 100644 samples/demo/SORT_2.h create mode 100644 samples/demo/common_object_tracking_SORT.cpp create mode 100644 samples/demo/logging.h create mode 100644 samples/demo/macros.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 781e602..f639df5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,7 +8,7 @@ add_definitions(-DAPI_EXPORTS) set(CMAKE_BUILD_TYPE "Release") -## JETSON, X86_CUDA, X86_INTEL +## JETSON, X86_CUDA message(STATUS "System:${CMAKE_HOST_SYSTEM_PROCESSOR}") if(NOT DEFINED PLATFORM) message(FATAL_ERROR "PLATFORM NOT SPECIFIED!") @@ -53,6 +53,7 @@ endif() add_definitions(-DWITH_OCV470) find_package(OpenCV 4.7 REQUIRED) +find_package(Eigen3 REQUIRED) message(STATUS "OpenCV library status:") message(STATUS " version: ${OpenCV_VERSION}") message(STATUS " libraries: ${OpenCV_LIBS}") @@ -60,6 +61,7 @@ message(STATUS " include path: ${OpenCV_INCLUDE_DIRS}") include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +include_directories(${EIGEN3_INCLUDE_DIRS}) include_directories( ${CMAKE_CURRENT_SOURCE_DIR}/gimbal_ctrl/IOs/serial/include ${CMAKE_CURRENT_SOURCE_DIR}/gimbal_ctrl/driver/src/FIFO @@ -185,7 +187,7 @@ list(APPEND spirecv_SRCS ${ALG_SRC_FILES}) endif() -if(USE_CUDA) # PLATFORM_X86_CUDA & PLATFORM_JETSON +if(USE_CUDA) # CUDA include_directories(/usr/local/cuda/include) link_directories(/usr/local/cuda/lib64) @@ -197,11 +199,24 @@ if(USE_CUDA) # PLATFORM_X86_CUDA & PLATFORM_JETSON target_link_libraries(sv_yoloplugins nvinfer cudart) cuda_add_library(sv_world SHARED ${spirecv_SRCS}) - target_link_libraries( - sv_world ${OpenCV_LIBS} - sv_yoloplugins sv_gimbal - nvinfer cudart - ) + if(USE_GSTREAMER) + target_link_libraries( + sv_world ${OpenCV_LIBS} + sv_yoloplugins sv_gimbal + nvinfer cudart + gstrtspserver-1.0 + ) + else() + target_link_libraries( + sv_world ${OpenCV_LIBS} + sv_yoloplugins sv_gimbal + nvinfer cudart + ) + endif() + + if(USE_FFMPEG) + target_link_libraries(sv_world ${FFMPEG_LIBS} fmt) + endif() set( YOLO_SRCS @@ -217,20 +232,18 @@ if(USE_CUDA) # PLATFORM_X86_CUDA & PLATFORM_JETSON cuda_add_executable(SpireCVSeg samples/SpireCVSeg.cpp ${YOLO_SRCS}) target_link_libraries(SpireCVSeg sv_world) -elseif(PLATFORM STREQUAL "X86_INTEL") # Links to Intel-OpenVINO libraries here +elseif(PLATFORM STREQUAL "X86_INTEL") add_library(sv_world SHARED ${spirecv_SRCS}) target_link_libraries( sv_world ${OpenCV_LIBS} sv_gimbal ) -endif() - -if(USE_GSTREAMER) - target_link_libraries(sv_world gstrtspserver-1.0) -endif() - -if(USE_FFMPEG) - target_link_libraries(sv_world ${FFMPEG_LIBS} fmt) + if(USE_GSTREAMER) + target_link_libraries(sv_world gstrtspserver-1.0) + endif() + if(USE_FFMPEG) + target_link_libraries(sv_world ${FFMPEG_LIBS} fmt) + endif() endif() #demo @@ -268,6 +281,9 @@ target_link_libraries(GimbalUdpDetectionInfoSender sv_world) add_executable(EvalFpsOnVideo samples/test/eval_fps_on_video.cpp) target_link_libraries(EvalFpsOnVideo sv_world) +add_executable(SortTracker samples/demo/common_object_tracking_SORT.cpp samples/demo/SORT_2.cpp) +target_link_libraries(SortTracker sv_world) + include_directories(${CMAKE_CURRENT_SOURCE_DIR}/samples/calib) add_executable(CameraCalibrarion samples/calib/calibrate_camera_charuco.cpp) target_link_libraries(CameraCalibrarion ${OpenCV_LIBS}) diff --git a/samples/demo/SORT_2.cpp b/samples/demo/SORT_2.cpp new file mode 100644 index 0000000..6d9ef1e --- /dev/null +++ b/samples/demo/SORT_2.cpp @@ -0,0 +1,354 @@ +#include "SORT_2.h" +#include + +using namespace std; +using namespace Eigen; + +/* +void SORT::update(vector detections) { + //allocate the memory for the vector in advance + int numDetections = detections.size(); + detections.reserve(numDetections); + + if (! tracklets_.size() || ! detections.size()){ + for (int i = 0; i < detections.size(); i++) + { + //if(find(matches.begin(), matches.end(),make_pair(-1, i)) == matches.end()) + if (detections[i](4) == -1) { + Tracklet tracklet; + tracklet.id = nextTrackletId_++;// + tracklet.bbox.segment<4>(0) = detections[i].segment<4>(0); + tracklet.age = 0; + tracklet.hits = 1; + tracklet.misses = 0; + //initate the motion + //pair motion = kf.initiate(bbox_xywh_to_wyah(detect_mat[i])); + tracklets_.push_back(tracklet); + } + } + } + else{ + // create new tracklets for unmatched detections + for (int i = 0; i < detections.size(); i++) + { + if (detections[i](4)== -1) { + Tracklet tracklet; + tracklet.id = nextTrackletId_++;// + tracklet.bbox.segment<4>(0) = detections[i].segment<4>(0); + tracklet.age = 0; + tracklet.hits = 1; + tracklet.misses = 0; + tracklets_.push_back(tracklet); + } + } + } +}*/ +KalmanFilter::KalmanFilter(){ + chi2inv95_ << 3.8415, 5.9915, 7.8147, 9.4877, 11.070, 12.592, 14.067, 15.507, 16.919; + /* + F_ << 1, 0, 0, 0, 1, 0, 0, 0, + 0, 1, 0, 0, 0, 1, 0, 0, + 0, 0, 1, 0, 0, 0, 1, 0, + 0, 0, 0, 1, 0, 0, 0, 1, + 0, 0, 0, 0, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 1; + + H_ << 1, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 0, 0, 0, 0, + 0, 0, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 1, 0, 0, 0, 0; + */ + F_ = MatrixXd::Identity(8, 8); + for (int i = 0; i < 4; i++) { + F_(i, i + 4) = 1; + } + H_ = MatrixXd::Identity(4, 8); + std_weight_position = 1. / 20; + std_weight_vel = 1. / 160; + } + +pair,Matrix> KalmanFilter::initiate(Vector4d & bbox){ + Matrix mean; + mean << bbox(0), bbox(1), bbox(2)/bbox(3), bbox(3), 0, 0, 0, 0; + //cout<<"the a is:"< covariances; + covariances = squared.asDiagonal(); + return make_pair(mean, covariances); + } + +pair,Matrix> KalmanFilter::update(Matrix mean, Matrix covariances, sv::Box & box) { + MatrixXd R_; + Vector4d stds; + + stds << std_weight_position * mean(3), std_weight_position* mean(3), 0.1, std_weight_position* mean(3); + MatrixXd squared = stds.array().square(); + R_ = squared.asDiagonal(); + + MatrixXd S = H_ * covariances * H_.transpose()+R_; + MatrixXd Kalman_gain = covariances * H_.transpose() * S.inverse(); + VectorXd measurement(4); + measurement << box.x1, box.y1, (box.x2-box.x1)/(box.y2-box.y1), box.y2-box.y1; + Matrix new_mean = mean + Kalman_gain * (measurement - H_ * mean); + Matrix new_covariances = (MatrixXd::Identity(8, 8) - Kalman_gain * H_) * covariances; + return make_pair(new_mean, new_covariances); + } + +pair, Matrix> KalmanFilter::predict(Matrixmean, Matrixcovariances) { + + VectorXd stds(8); + stds << std_weight_position * mean(3), std_weight_position* mean(3), 0.01, std_weight_position* mean(3), std_weight_vel* mean(3), std_weight_vel* mean(3), 1e-5, std_weight_vel* mean(3); + MatrixXd squared = stds.array().square(); + MatrixXd Q_ = squared.asDiagonal(); + Matrix pre_mean = F_ * mean; + Matrix pre_cov = F_ * covariances * F_.transpose()+Q_; + return make_pair(pre_mean, pre_cov); + } + +void SORT::update(sv::TargetsInFrame & tgts){ + + KalmanFilter kf; + // if (! tracklets_.size() || ! tgts.targets.size()) + if (! tracklets_.size()) + { + Vector4d bbox; + for (int i = 0; i ,Matrix> motion = kf.initiate(tracklet.bbox); + //cout<<"mean:"< match_det(100,-1); + array match_det; + match_det.fill(-1); + //predict the next state of each tracklet + for (auto& tracklet : tracklets_) { + tracklet.age++; + pair, Matrix> motion = kf.predict(tracklet.mean, tracklet.covariance); + tracklet.bbox << motion.first(0),motion.first(1),motion.first(2)*motion.first(3),motion.first(3); + tracklet.mean = motion.first; + tracklet.covariance = motion.second; + + } + + //Match the detections to the existing tracklets + cout<<"the num of targets: "<> iouMatrix(tracklets_.size(), vector (tgts.targets.size(), 0)); // + for (int i = 0; i > matches = hungarian(iouMatrix); + for (auto& match : matches) { + int trackletIndex = match.first; + int detectionIndex = match.second; + // cout<<"trackletIndex:"<=0 && detectionIndex>=0) { + if(iouMatrix[match.first][match.second]>=0) { + sv::Box box; + tgts.targets[detectionIndex].getBox(box); + tracklets_[trackletIndex].age = 0; + tracklets_[trackletIndex].hits++; + tracklets_[trackletIndex].bbox << box.x1,box.y1,box.x2-box.x1,box.y2-box.y1; + + auto[mean,covariance]=kf.update(tracklets_[trackletIndex].mean,tracklets_[trackletIndex].covariance, box); + tracklets_[trackletIndex].mean = mean; + tracklets_[trackletIndex].covariance = covariance; + + tgts.targets[detectionIndex].tracked_id = tracklets_[trackletIndex].id; + match_det[detectionIndex]=detectionIndex; + } + } + } + // create new tracklets for unmatched detections + for (int i = 0; i < tgts.targets.size(); i++) + { + if (match_det[i]==-1) { + sv::Box box; + tgts.targets[i].getBox(box); + Tracklet tracklet; + tracklet.id = ++nextTrackletId_;// + //Vector4d bbox; + //bbox << box.x1,box.y1,box.x2-box.x1,box.y2-box.y1; + //tracklet.bboxes.push_back(bbox); + tracklet.bbox << box.x1,box.y1,box.x2-box.x1,box.y2-box.y1; + + tracklet.age = 0; + tracklet.hits = 1; + tracklet.misses = 0; + + auto [new_mean,new_covariance] = kf.initiate(tracklet.bbox); + tracklet.mean = new_mean; + tracklet.covariance = new_covariance; + + tgts.targets[i].tracked_id = nextTrackletId_; + tracklets_.push_back(tracklet); + } + } + } + /* + vector newTracklets; + for (auto& tracklet : tracklets_) { + if (tracklet.age < maxAge_ || tracklet.hits >= minHits_) { + newTracklets.push_back(tracklet); + } + } + tracklets_ = newTracklets; + */ +} +vector SORT::getTracklets() const{ + return tracklets_; + } + +double SORT::iou(Tracklet & tracklet, sv::Box & box) { + + double trackletX1 = tracklet.bbox(0); + double trackletY1 = tracklet.bbox(1); + double trackletX2 = tracklet.bbox(0)+tracklet.bbox(2); + double trackletY2 = tracklet.bbox(1)+tracklet.bbox(3); + /* + double trackletX1 = tracklet.bboxes.back()(0); + double trackletY1 = tracklet.bboxes.back()(1); + double trackletX2 = tracklet.bboxes.back()(0)+tracklet.bboxes.back()(2); + double trackletY2 = tracklet.bboxes.back()(1)+tracklet.bboxes.back()(3); + + double trackletX1 = tracklet.x; + double trackletY1 = tracklet.y; + double trackletX2 = tracklet.x+tracklet.w; + double trackletY2 = tracklet.y+tracklet.h; + */ + double detectionX1 = box.x1; + double detectionY1 = box.y1; + double detectionX2 = box.x2; + double detectionY2 = box.y2; + double intersectionX1 = max(trackletX1, detectionX1); + double intersectionY1 = max(trackletY1, detectionY1); + double intersectionX2 = min(trackletX2, detectionX2); + double intersectionY2 = min(trackletY2, detectionY2); + //double intersectionArea = max(intersectionX2-intersectionX1, 0.0) * max(intersectionY2-intersectionY1, 0.0); + double w = (intersectionX2-intersectionX1>0.0) ? (intersectionX2-intersectionX1):0.0; + double h = (intersectionY2-intersectionY1>0.0) ? (intersectionY2-intersectionY1):0.0; + double intersectionArea = w*h; + + double trackletArea = tracklet.bbox(2)*tracklet.bbox(3); + //double trackletArea = tracklet.w*tracklet.h; + double detectionArea = (box.x2-box.x1)*(box.y2-box.y1); + double unionArea = trackletArea + detectionArea - intersectionArea; + double iou = (1-intersectionArea / unionArea*1.0); + return iou; +} + +vector> SORT::hungarian(vector> costMatrix) { +// vector> SORT::hungarian(array> costMatrix) { + int numRows = costMatrix.size(); + int numCols = costMatrix[0].size(); + + const bool transposed = numCols > numRows; + // transpose the matrix if necessary + if (transposed) { + vector> transposedMatrix(numCols, vector(numRows)); + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numCols; j++) { + transposedMatrix[j][i] = costMatrix[i][j]; + } + } + costMatrix = transposedMatrix; + swap(numRows, numCols); + } + vectorrowMin (numRows, numeric_limits::infinity()); + vectorcolMin(numCols, numeric_limits::infinity()); + vectorrowMatch(numRows, -1); + vectorcolMatch(numCols, -1); + vector> matches; + //step1: Subtract the row minimums from each row + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numCols; j++) { + rowMin[i] = min(rowMin[i], costMatrix[i][j]); + } + for (int j = 0; j < numCols; j++) { + costMatrix[i][j] -= rowMin[i]; + } + } + //step2: substract the colcum minimums from each column + for (int j = 0; j < numCols; j++) { + for (int i = 0; i < numRows; i++) { + colMin[j] = min(colMin[j], costMatrix[i][j]); + } + for (int i = 0; i < numRows; i++) { + costMatrix[i][j] -= colMin[j]; + } + } + //step3: find a maximal matching + for (int i = 0; i < numRows; i++) { + vector visited(numCols, false); + Augment(costMatrix, i, rowMatch, colMatch, visited); + } + //step4: calculate the matches + matches.clear(); + for (int j = 0; j < numCols; j++) { + matches.push_back(make_pair(colMatch[j], j)); + } + if (transposed) { + //vector> transposedMatches; + for (auto& match : matches) { + //transposedMatches.push_back(make_pair(match.second, match.first)); + swap(match.first,match.second); + } + //matches = transposedMatches; + } + return matches; +} + +bool SORT::Augment(const vector>& costMatrix, int row, vector& rowMatch, vector& colMatch, vector& visited) +// bool SORT::Augment(const array& costMatrix, int row, vector& rowMatch, vector& colMatch, vector& visited) +{ + int numCols = costMatrix[0].size(); + for (int j = 0; j < numCols; j++) { + if (costMatrix[row][j] == 0 && !visited[j]) { + visited[j] = true; + if (colMatch[j] == -1 || Augment(costMatrix, colMatch[j], rowMatch, colMatch, visited)) { + rowMatch[row] = j; + colMatch[j] = row; + return true; + } + } + } + return false; +} \ No newline at end of file diff --git a/samples/demo/SORT_2.h b/samples/demo/SORT_2.h new file mode 100644 index 0000000..d222f60 --- /dev/null +++ b/samples/demo/SORT_2.h @@ -0,0 +1,62 @@ +#ifndef SORT_2_H +#define SORT_2_H + +#include +#include +#include +#include +#include +#include + + + +//define the tracklet struct to store the tracked objects + +struct Tracklet { + public: + //Eigen::VectorXd bbox = Eigen::VectorXd::Zero(4);//x1,y1,w,h + //std::vector bboxes; + Eigen::Vector4d bbox; + //double x, y, w, h; + int id=0; + int age; + int hits; + int misses; + std::vector features; + Eigen::Matrix mean; + Eigen::Matrix covariance; +}; + +class KalmanFilter { + public: + KalmanFilter(); + std::pair,Eigen::Matrix> initiate(Eigen::Vector4d &bbox); + std::pair,Eigen::Matrix> update(Eigen::Matrix mean,Eigen::Matrix covariances,sv::Box & box); + std::pair,Eigen::Matrix> predict(Eigen::Matrixmean,Eigen::Matrixcovariances); + private: + Eigen::Matrix F_; + Eigen::Matrix H_; + Eigen::Matrix chi2inv95_; + double std_weight_position; + double std_weight_vel; +}; + +class SORT { + public: + SORT(double iouThreshold, int maxAge, int minHits): iouThreshold_(iouThreshold), maxAge_(maxAge), minHits_(minHits), nextTrackletId_(0) {}; + //void update(std::vector detections); + void update(sv::TargetsInFrame & tgts); + std::vector getTracklets() const; + private: + double iou(Tracklet & tracklet, sv::Box & box); + std::vector> hungarian(std::vector> costMatrix); + bool Augment(const std::vector>& costMatrix, int row, std::vector& rowMatch, std::vector& colMatch, std::vector& visited); + private: + double iouThreshold_; + int maxAge_; + int minHits_; + int nextTrackletId_; + std::vector tracklets_; +}; + +#endif \ No newline at end of file diff --git a/samples/demo/common_object_tracking_SORT.cpp b/samples/demo/common_object_tracking_SORT.cpp new file mode 100644 index 0000000..3537f06 --- /dev/null +++ b/samples/demo/common_object_tracking_SORT.cpp @@ -0,0 +1,188 @@ +#include +#include +// 包含SpireVision SDK头文件 +#include +#include +#include +#include "logging.h" + +// #include +#include "SORT_2.h" + +#define TRTCHECK(status) \ + do \ + { \ + auto ret = (status); \ + if (ret != 0) \ + { \ + std::cerr << "Cuda failure: " << ret << std::endl; \ + abort(); \ + } \ + } while (0) + +using namespace std; +using namespace nvinfer1; +static Logger g_nvlogger; + +// using namespace cv; +using namespace Eigen; + +// nvinfer1::IExecutionContext *g_trt_context; +// int g_input_index; +// int g_output_index; +// void *g_p_buffers[2]; +// cudaStream_t g_cu_stream; +// float *g_p_data; +// float *g_p_prob; + + +// void tensorrt_model_init() +// { +// std::string trt_model_fn = "/home/amov/SpireVision/models/resnet_reid.engine"; +// char *trt_model_stream{nullptr}; +// size_t trt_model_size{0}; +// try +// { +// std::ifstream file(trt_model_fn, std::ios::binary); +// file.seekg(0, file.end); +// trt_model_size = file.tellg(); +// file.seekg(0, file.beg); +// trt_model_stream = new char[trt_model_size]; +// assert(trt_model_stream); +// file.read(trt_model_stream, trt_model_size); +// file.close(); +// } +// catch (const std::runtime_error &e) +// { +// throw std::runtime_error("Error loading the TensorRT model!"); +// } + +// // TensorRT +// IRuntime *runtime = nvinfer1::createInferRuntime(g_nvlogger); +// assert(runtime != nullptr); +// ICudaEngine *p_cu_engine = runtime->deserializeCudaEngine(trt_model_stream, trt_model_size); +// assert(p_cu_engine != nullptr); +// g_trt_context = p_cu_engine->createExecutionContext(); +// assert(g_trt_context != nullptr); +// delete[] trt_model_stream; +// const ICudaEngine &cu_engine = g_trt_context->getEngine(); +// assert(cu_engine.getNbBindings() == 2); + +// g_input_index = cu_engine.getBindingIndex("input"); +// g_output_index = cu_engine.getBindingIndex("output"); +// TRTCHECK(cudaMalloc(&g_p_buffers[g_input_index], 1 * 3 * 32 * 32 * sizeof(float))); +// TRTCHECK(cudaMalloc(&g_p_buffers[g_output_index], 1 * 11 * sizeof(float))); +// TRTCHECK(cudaStreamCreate(&g_cu_stream)); + +// g_p_data = new float[1 * 3 * 32 * 32]; +// g_p_prob = new float[1 * 11]; +// // Input +// TRTCHECK(cudaMemcpyAsync(g_p_buffers[g_input_index], g_p_data, 1 * 3 * 32 * 32 * sizeof(float), cudaMemcpyHostToDevice, g_cu_stream)); +// g_trt_context->enqueue(1, g_p_buffers, g_cu_stream, nullptr); +// // Output +// TRTCHECK(cudaMemcpyAsync(g_p_prob, g_p_buffers[g_output_index], 1 * 11 * sizeof(float), cudaMemcpyDeviceToHost, g_cu_stream)); +// cudaStreamSynchronize(g_cu_stream); +// } + +int main(int argc, char *argv[]) { + // 实例化 通用目标 检测器类 + sv::CommonObjectDetector cod; + // 手动导入相机参数,如果使用Amov的G1等吊舱或相机,则可以忽略该步骤,将自动下载相机参数文件 + cod.loadCameraParams(sv::get_home() + "/SpireVision/calib_webcam_640x480.yaml"); + + // 打开摄像头 + // sv::Camera cap; + // cap.open(sv::CameraType::WEBCAM, 0); + cv::VideoCapture cap; + cap.open("/home/amov/SpireVision/video/predestrian/MOT17-01.mp4"); + // cap.open("/home/amov/SpireVision/video/vehicle/vehicle_02.mp4"); + // 实例化OpenCV的Mat类,用于内存单帧图像 + cv::Mat img; + int frame_id = 0; + SORT tracker(0.5, 10, 3); + while (1) + { + // 实例化SpireVision的 单帧检测结果 接口类 TargetsInFrame + sv::TargetsInFrame tgts(frame_id++); + // 读取一帧图像到img + cap.read(img); + cv::resize(img, img, cv::Size(cod.image_width, cod.image_height)); + + // 执行通用目标检测 + cod.detect(img, tgts); + tracker.update(tgts); + //vector detections; + // 可视化检测结果,叠加到img上 + // sv::drawTargetsInFrame(img, tgts); + /* + for (int i=0; ienqueue(1, g_p_buffers, g_cu_stream, nullptr); + // Output + TRTCHECK(cudaMemcpyAsync(g_p_prob, g_p_buffers[g_output_index], 1 * 11 * sizeof(float), cudaMemcpyDeviceToHost, g_cu_stream)); + cudaStreamSynchronize(g_cu_stream); + + // Find max index + double max = 0; + int label = 0; + for (int i = 0; i < 11; ++i) + { + if (max < g_p_prob[i]) + { + max = g_p_prob[i]; + label = i; + } + } +*/ + //cv::imshow("roi", roi); + // cv::waitKey(10); + //tracker.update(detections); + //vector().swap(detections); + + // 可视化检测结果,叠加到img上 + sv::drawTargetsInFrame(img, tgts); + // 控制台打印通用目标检测结果 + printf("Frame-[%d]\n", frame_id); + // 打印当前检测的FPS + printf(" FPS = %.2f\n", tgts.fps); + // 打印当前相机的视场角(degree) + printf(" FOV (fx, fy) = (%.2f, %.2f)\n", tgts.fov_x, tgts.fov_y); + for (int i=0; i +#include +#include +#include +#include +#include +#include +#include "macros.h" + +using Severity = nvinfer1::ILogger::Severity; + +class LogStreamConsumerBuffer : public std::stringbuf +{ +public: + LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog) + : mOutput(stream) + , mPrefix(prefix) + , mShouldLog(shouldLog) + { + } + + LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other) + : mOutput(other.mOutput) + { + } + + ~LogStreamConsumerBuffer() + { + // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence + // std::streambuf::pptr() gives a pointer to the current position of the output sequence + // if the pointer to the beginning is not equal to the pointer to the current position, + // call putOutput() to log the output to the stream + if (pbase() != pptr()) + { + putOutput(); + } + } + + // synchronizes the stream buffer and returns 0 on success + // synchronizing the stream buffer consists of inserting the buffer contents into the stream, + // resetting the buffer and flushing the stream + virtual int sync() + { + putOutput(); + return 0; + } + + void putOutput() + { + if (mShouldLog) + { + // prepend timestamp + std::time_t timestamp = std::time(nullptr); + tm* tm_local = std::localtime(×tamp); + std::cout << "["; + std::cout << std::setw(2) << std::setfill('0') << 1 + tm_local->tm_mon << "/"; + std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_mday << "/"; + std::cout << std::setw(4) << std::setfill('0') << 1900 + tm_local->tm_year << "-"; + std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_hour << ":"; + std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_min << ":"; + std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_sec << "] "; + // std::stringbuf::str() gets the string contents of the buffer + // insert the buffer contents pre-appended by the appropriate prefix into the stream + mOutput << mPrefix << str(); + // set the buffer to empty + str(""); + // flush the stream + mOutput.flush(); + } + } + + void setShouldLog(bool shouldLog) + { + mShouldLog = shouldLog; + } + +private: + std::ostream& mOutput; + std::string mPrefix; + bool mShouldLog; +}; + +//! +//! \class LogStreamConsumerBase +//! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer +//! +class LogStreamConsumerBase +{ +public: + LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog) + : mBuffer(stream, prefix, shouldLog) + { + } + +protected: + LogStreamConsumerBuffer mBuffer; +}; + +//! +//! \class LogStreamConsumer +//! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages. +//! Order of base classes is LogStreamConsumerBase and then std::ostream. +//! This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field +//! in LogStreamConsumer and then the address of the buffer is passed to std::ostream. +//! This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream. +//! Please do not change the order of the parent classes. +//! +class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream +{ +public: + //! \brief Creates a LogStreamConsumer which logs messages with level severity. + //! Reportable severity determines if the messages are severe enough to be logged. + LogStreamConsumer(Severity reportableSeverity, Severity severity) + : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity) + , std::ostream(&mBuffer) // links the stream buffer with the stream + , mShouldLog(severity <= reportableSeverity) + , mSeverity(severity) + { + } + + LogStreamConsumer(LogStreamConsumer&& other) + : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog) + , std::ostream(&mBuffer) // links the stream buffer with the stream + , mShouldLog(other.mShouldLog) + , mSeverity(other.mSeverity) + { + } + + void setReportableSeverity(Severity reportableSeverity) + { + mShouldLog = mSeverity <= reportableSeverity; + mBuffer.setShouldLog(mShouldLog); + } + +private: + static std::ostream& severityOstream(Severity severity) + { + return severity >= Severity::kINFO ? std::cout : std::cerr; + } + + static std::string severityPrefix(Severity severity) + { + switch (severity) + { + case Severity::kINTERNAL_ERROR: return "[F] "; + case Severity::kERROR: return "[E] "; + case Severity::kWARNING: return "[W] "; + case Severity::kINFO: return "[I] "; + case Severity::kVERBOSE: return "[V] "; + default: assert(0); return ""; + } + } + + bool mShouldLog; + Severity mSeverity; +}; + +//! \class Logger +//! +//! \brief Class which manages logging of TensorRT tools and samples +//! +//! \details This class provides a common interface for TensorRT tools and samples to log information to the console, +//! and supports logging two types of messages: +//! +//! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal) +//! - Test pass/fail messages +//! +//! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is +//! that the logic for controlling the verbosity and formatting of sample output is centralized in one location. +//! +//! In the future, this class could be extended to support dumping test results to a file in some standard format +//! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run). +//! +//! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger +//! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT +//! library and messages coming from the sample. +//! +//! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the +//! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger +//! object. + +class Logger : public nvinfer1::ILogger +{ +public: + Logger(Severity severity = Severity::kWARNING) + : mReportableSeverity(severity) + { + } + + //! + //! \enum TestResult + //! \brief Represents the state of a given test + //! + enum class TestResult + { + kRUNNING, //!< The test is running + kPASSED, //!< The test passed + kFAILED, //!< The test failed + kWAIVED //!< The test was waived + }; + + //! + //! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger + //! \return The nvinfer1::ILogger associated with this Logger + //! + //! TODO Once all samples are updated to use this method to register the logger with TensorRT, + //! we can eliminate the inheritance of Logger from ILogger + //! + nvinfer1::ILogger& getTRTLogger() + { + return *this; + } + + //! + //! \brief Implementation of the nvinfer1::ILogger::log() virtual method + //! + //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the + //! inheritance from nvinfer1::ILogger + //! + void log(Severity severity, const char* msg) TRT_NOEXCEPT override + { + LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl; + } + + //! + //! \brief Method for controlling the verbosity of logging output + //! + //! \param severity The logger will only emit messages that have severity of this level or higher. + //! + void setReportableSeverity(Severity severity) + { + mReportableSeverity = severity; + } + + //! + //! \brief Opaque handle that holds logging information for a particular test + //! + //! This object is an opaque handle to information used by the Logger to print test results. + //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used + //! with Logger::reportTest{Start,End}(). + //! + class TestAtom + { + public: + TestAtom(TestAtom&&) = default; + + private: + friend class Logger; + + TestAtom(bool started, const std::string& name, const std::string& cmdline) + : mStarted(started) + , mName(name) + , mCmdline(cmdline) + { + } + + bool mStarted; + std::string mName; + std::string mCmdline; + }; + + //! + //! \brief Define a test for logging + //! + //! \param[in] name The name of the test. This should be a string starting with + //! "TensorRT" and containing dot-separated strings containing + //! the characters [A-Za-z0-9_]. + //! For example, "TensorRT.sample_googlenet" + //! \param[in] cmdline The command line used to reproduce the test + // + //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). + //! + static TestAtom defineTest(const std::string& name, const std::string& cmdline) + { + return TestAtom(false, name, cmdline); + } + + //! + //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments + //! as input + //! + //! \param[in] name The name of the test + //! \param[in] argc The number of command-line arguments + //! \param[in] argv The array of command-line arguments (given as C strings) + //! + //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). + static TestAtom defineTest(const std::string& name, int argc, char const* const* argv) + { + auto cmdline = genCmdlineString(argc, argv); + return defineTest(name, cmdline); + } + + //! + //! \brief Report that a test has started. + //! + //! \pre reportTestStart() has not been called yet for the given testAtom + //! + //! \param[in] testAtom The handle to the test that has started + //! + static void reportTestStart(TestAtom& testAtom) + { + reportTestResult(testAtom, TestResult::kRUNNING); + assert(!testAtom.mStarted); + testAtom.mStarted = true; + } + + //! + //! \brief Report that a test has ended. + //! + //! \pre reportTestStart() has been called for the given testAtom + //! + //! \param[in] testAtom The handle to the test that has ended + //! \param[in] result The result of the test. Should be one of TestResult::kPASSED, + //! TestResult::kFAILED, TestResult::kWAIVED + //! + static void reportTestEnd(const TestAtom& testAtom, TestResult result) + { + assert(result != TestResult::kRUNNING); + assert(testAtom.mStarted); + reportTestResult(testAtom, result); + } + + static int reportPass(const TestAtom& testAtom) + { + reportTestEnd(testAtom, TestResult::kPASSED); + return EXIT_SUCCESS; + } + + static int reportFail(const TestAtom& testAtom) + { + reportTestEnd(testAtom, TestResult::kFAILED); + return EXIT_FAILURE; + } + + static int reportWaive(const TestAtom& testAtom) + { + reportTestEnd(testAtom, TestResult::kWAIVED); + return EXIT_SUCCESS; + } + + static int reportTest(const TestAtom& testAtom, bool pass) + { + return pass ? reportPass(testAtom) : reportFail(testAtom); + } + + Severity getReportableSeverity() const + { + return mReportableSeverity; + } + +private: + //! + //! \brief returns an appropriate string for prefixing a log message with the given severity + //! + static const char* severityPrefix(Severity severity) + { + switch (severity) + { + case Severity::kINTERNAL_ERROR: return "[F] "; + case Severity::kERROR: return "[E] "; + case Severity::kWARNING: return "[W] "; + case Severity::kINFO: return "[I] "; + case Severity::kVERBOSE: return "[V] "; + default: assert(0); return ""; + } + } + + //! + //! \brief returns an appropriate string for prefixing a test result message with the given result + //! + static const char* testResultString(TestResult result) + { + switch (result) + { + case TestResult::kRUNNING: return "RUNNING"; + case TestResult::kPASSED: return "PASSED"; + case TestResult::kFAILED: return "FAILED"; + case TestResult::kWAIVED: return "WAIVED"; + default: assert(0); return ""; + } + } + + //! + //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity + //! + static std::ostream& severityOstream(Severity severity) + { + return severity >= Severity::kINFO ? std::cout : std::cerr; + } + + //! + //! \brief method that implements logging test results + //! + static void reportTestResult(const TestAtom& testAtom, TestResult result) + { + severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # " + << testAtom.mCmdline << std::endl; + } + + //! + //! \brief generate a command line string from the given (argc, argv) values + //! + static std::string genCmdlineString(int argc, char const* const* argv) + { + std::stringstream ss; + for (int i = 0; i < argc; i++) + { + if (i > 0) + ss << " "; + ss << argv[i]; + } + return ss.str(); + } + + Severity mReportableSeverity; +}; + +namespace +{ + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE +//! +//! Example usage: +//! +//! LOG_VERBOSE(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) +{ + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO +//! +//! Example usage: +//! +//! LOG_INFO(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_INFO(const Logger& logger) +{ + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING +//! +//! Example usage: +//! +//! LOG_WARN(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_WARN(const Logger& logger) +{ + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR +//! +//! Example usage: +//! +//! LOG_ERROR(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_ERROR(const Logger& logger) +{ + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR +// ("fatal" severity) +//! +//! Example usage: +//! +//! LOG_FATAL(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_FATAL(const Logger& logger) +{ + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR); +} + +} // anonymous namespace + +#endif // TENSORRT_LOGGING_H diff --git a/samples/demo/macros.h b/samples/demo/macros.h new file mode 100644 index 0000000..17339a2 --- /dev/null +++ b/samples/demo/macros.h @@ -0,0 +1,29 @@ +#ifndef __MACROS_H +#define __MACROS_H + +#include + +#ifdef API_EXPORTS +#if defined(_MSC_VER) +#define API __declspec(dllexport) +#else +#define API __attribute__((visibility("default"))) +#endif +#else + +#if defined(_MSC_VER) +#define API __declspec(dllimport) +#else +#define API +#endif +#endif // API_EXPORTS + +#if NV_TENSORRT_MAJOR >= 8 +#define TRT_NOEXCEPT noexcept +#define TRT_CONST_ENQUEUE const +#else +#define TRT_NOEXCEPT +#define TRT_CONST_ENQUEUE +#endif + +#endif // __MACROS_H