Pre Merge pull request !9 from 褚昭晨/czc
This commit is contained in:
commit
bbd9e18725
|
@ -54,6 +54,7 @@ endif()
|
||||||
|
|
||||||
add_definitions(-DWITH_OCV470)
|
add_definitions(-DWITH_OCV470)
|
||||||
find_package(OpenCV 4.7 REQUIRED)
|
find_package(OpenCV 4.7 REQUIRED)
|
||||||
|
find_package(Eigen3 REQUIRED)
|
||||||
message(STATUS "OpenCV library status:")
|
message(STATUS "OpenCV library status:")
|
||||||
message(STATUS " version: ${OpenCV_VERSION}")
|
message(STATUS " version: ${OpenCV_VERSION}")
|
||||||
message(STATUS " libraries: ${OpenCV_LIBS}")
|
message(STATUS " libraries: ${OpenCV_LIBS}")
|
||||||
|
@ -61,6 +62,7 @@ message(STATUS " include path: ${OpenCV_INCLUDE_DIRS}")
|
||||||
|
|
||||||
|
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
|
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
|
||||||
|
include_directories(${EIGEN3_INCLUDE_DIRS})
|
||||||
include_directories(
|
include_directories(
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/gimbal_ctrl/IOs/serial/include
|
${CMAKE_CURRENT_SOURCE_DIR}/gimbal_ctrl/IOs/serial/include
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/gimbal_ctrl/driver/src/FIFO
|
${CMAKE_CURRENT_SOURCE_DIR}/gimbal_ctrl/driver/src/FIFO
|
||||||
|
@ -76,6 +78,7 @@ include_directories(
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/video_io
|
${CMAKE_CURRENT_SOURCE_DIR}/video_io
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/algorithm/ellipse_det
|
${CMAKE_CURRENT_SOURCE_DIR}/algorithm/ellipse_det
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/utils
|
${CMAKE_CURRENT_SOURCE_DIR}/utils
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/algorithm/common_mot
|
||||||
)
|
)
|
||||||
|
|
||||||
if(USE_GSTREAMER)
|
if(USE_GSTREAMER)
|
||||||
|
@ -110,6 +113,7 @@ set(
|
||||||
include/sv_gimbal.h
|
include/sv_gimbal.h
|
||||||
include/sv_algorithm_base.h
|
include/sv_algorithm_base.h
|
||||||
include/sv_common_det.h
|
include/sv_common_det.h
|
||||||
|
include/sv_common_mot.h
|
||||||
include/sv_landing_det.h
|
include/sv_landing_det.h
|
||||||
include/sv_tracking.h
|
include/sv_tracking.h
|
||||||
include/sv_color_line.h
|
include/sv_color_line.h
|
||||||
|
@ -151,6 +155,7 @@ set(spirecv_SRCS
|
||||||
algorithm/sv_algorithm_base.cpp
|
algorithm/sv_algorithm_base.cpp
|
||||||
algorithm/ellipse_det/ellipse_detector.cpp
|
algorithm/ellipse_det/ellipse_detector.cpp
|
||||||
algorithm/common_det/sv_common_det.cpp
|
algorithm/common_det/sv_common_det.cpp
|
||||||
|
algorithm/common_mot/sv_common_mot.cpp
|
||||||
algorithm/landing_det/sv_landing_det.cpp
|
algorithm/landing_det/sv_landing_det.cpp
|
||||||
algorithm/tracking/sv_tracking.cpp
|
algorithm/tracking/sv_tracking.cpp
|
||||||
algorithm/color_line/sv_color_line.cpp
|
algorithm/color_line/sv_color_line.cpp
|
||||||
|
@ -266,6 +271,9 @@ target_link_libraries(GimbalLandingMarkerDetection sv_world)
|
||||||
add_executable(GimbalUdpDetectionInfoSender samples/demo/gimbal_udp_detection_info_sender.cpp)
|
add_executable(GimbalUdpDetectionInfoSender samples/demo/gimbal_udp_detection_info_sender.cpp)
|
||||||
target_link_libraries(GimbalUdpDetectionInfoSender sv_world)
|
target_link_libraries(GimbalUdpDetectionInfoSender sv_world)
|
||||||
|
|
||||||
|
add_executable(SortTracker samples/demo/common_object_sort.cpp)
|
||||||
|
target_link_libraries(SortTracker sv_world)
|
||||||
|
|
||||||
add_executable(EvalFpsOnVideo samples/test/eval_fps_on_video.cpp)
|
add_executable(EvalFpsOnVideo samples/test/eval_fps_on_video.cpp)
|
||||||
target_link_libraries(EvalFpsOnVideo sv_world)
|
target_link_libraries(EvalFpsOnVideo sv_world)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,286 @@
|
||||||
|
#include "sv_common_mot.h"
|
||||||
|
#include <cmath>
|
||||||
|
#include "sv_util.h"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace Eigen;
|
||||||
|
|
||||||
|
namespace sv {
|
||||||
|
KalmanFilter::KalmanFilter(){
|
||||||
|
chi2inv95_ << 3.8415, 5.9915, 7.8147, 9.4877, 11.070, 12.592, 14.067, 15.507, 16.919;
|
||||||
|
F_=MatrixXd::Identity(8,8);
|
||||||
|
for (int i=0;i<4;i++){
|
||||||
|
F_(i,i+4)=1;
|
||||||
|
}
|
||||||
|
H_=MatrixXd::Identity(4,8);
|
||||||
|
std_weight_position=1./20;
|
||||||
|
std_weight_vel=1./160;
|
||||||
|
}
|
||||||
|
KalmanFilter::~KalmanFilter(){};
|
||||||
|
pair<Matrix<double,8,1>,Matrix<double,8,8>> KalmanFilter::initiate(Vector4d &bbox){
|
||||||
|
Matrix<double,8,1> mean;
|
||||||
|
mean<<bbox(0), bbox(1), bbox(2)/bbox(3), bbox(3), 0, 0, 0, 0;
|
||||||
|
VectorXd stds(8);
|
||||||
|
stds << 2*std_weight_position*mean(3), 2*std_weight_position*mean(3), 0.01, 2*std_weight_position*mean(3), 10*std_weight_vel*mean(3), 10*std_weight_vel*mean(3), 1e-5, 10*std_weight_vel*mean(3);
|
||||||
|
MatrixXd squared = stds.array().square();
|
||||||
|
Matrix<double,8,8> covariances;
|
||||||
|
covariances = squared.asDiagonal();
|
||||||
|
return make_pair(mean, covariances);
|
||||||
|
}
|
||||||
|
|
||||||
|
pair<Matrix<double,8,1>,Matrix<double,8,8>> KalmanFilter::update(Matrix<double,8,1> mean, Matrix<double,8,8> covariances, sv::Box &box){
|
||||||
|
MatrixXd R_;
|
||||||
|
Vector4d stds;
|
||||||
|
|
||||||
|
stds << std_weight_position * mean(3), std_weight_position* mean(3), 0.1, std_weight_position* mean(3);
|
||||||
|
MatrixXd squared = stds.array().square();
|
||||||
|
R_= squared.asDiagonal();
|
||||||
|
MatrixXd S = H_ * covariances * H_.transpose()+R_;
|
||||||
|
|
||||||
|
MatrixXd Kalman_gain = covariances * H_.transpose() * S.inverse();
|
||||||
|
VectorXd measurement(4);
|
||||||
|
measurement << box.x1, box.y1, (box.x2-box.x1)/(box.y2-box.y1), box.y2-box.y1;
|
||||||
|
Matrix<double, 8, 1> new_mean = mean + Kalman_gain * (measurement - H_ * mean);
|
||||||
|
Matrix<double, 8, 8> new_covariances = (MatrixXd::Identity(8, 8) - Kalman_gain * H_) * covariances;
|
||||||
|
return make_pair(new_mean, new_covariances);
|
||||||
|
}
|
||||||
|
|
||||||
|
pair<Matrix<double, 8, 1>, Matrix<double, 8, 8>> KalmanFilter::predict(Matrix<double, 8, 1>mean, Matrix<double, 8, 8>covariances) {
|
||||||
|
|
||||||
|
VectorXd stds(8);
|
||||||
|
stds << std_weight_position * mean(3), std_weight_position* mean(3), 0.01, std_weight_position* mean(3), std_weight_vel* mean(3), std_weight_vel* mean(3), 1e-5, std_weight_vel* mean(3);
|
||||||
|
MatrixXd squared = stds.array().square();
|
||||||
|
MatrixXd Q_ = squared.asDiagonal();
|
||||||
|
Matrix<double,8,1> pre_mean = F_ * mean;
|
||||||
|
Matrix<double,8,8> pre_cov = F_ * covariances * F_.transpose()+Q_;
|
||||||
|
return make_pair(pre_mean, pre_cov);
|
||||||
|
}
|
||||||
|
|
||||||
|
SORT::~SORT(){
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
void SORT::update(sv::TargetsInFrame & tgts){
|
||||||
|
|
||||||
|
sv::KalmanFilter kf;
|
||||||
|
// if (! tracklets_.size() || ! tgts.targets.size())
|
||||||
|
if (! tracklets_.size())
|
||||||
|
{
|
||||||
|
Vector4d bbox;
|
||||||
|
for (int i = 0; i <tgts.targets.size(); i++)
|
||||||
|
{
|
||||||
|
sv::Box box;
|
||||||
|
tgts.targets[i].getBox(box);
|
||||||
|
Tracklet tracklet;
|
||||||
|
tracklet.id = ++next_tracklet_id_;//
|
||||||
|
cout<<tracklet.id<<endl;
|
||||||
|
tgts.targets[i].tracked_id = next_tracklet_id_;
|
||||||
|
|
||||||
|
tracklet.bbox << box.x1,box.y1,box.x2-box.x1,box.y2-box.y1;//x,y,w,h
|
||||||
|
tracklet.age = 0;
|
||||||
|
tracklet.hits = 1;
|
||||||
|
tracklet.misses = 0;
|
||||||
|
//initate the motion
|
||||||
|
|
||||||
|
pair<Matrix<double,8,1>,Matrix<double,8,8>> motion = kf.initiate(tracklet.bbox);
|
||||||
|
tracklet.mean=motion.first;
|
||||||
|
tracklet.covariance = motion.second;
|
||||||
|
|
||||||
|
tracklets_.push_back(tracklet);
|
||||||
|
//}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
for (int i = 0; i < tgts.targets.size(); i++)
|
||||||
|
{
|
||||||
|
tgts.targets[i].tracked_id = 0;
|
||||||
|
}
|
||||||
|
// vector<int> match_det(100,-1);
|
||||||
|
array<int,100> match_det;
|
||||||
|
match_det.fill(-1);
|
||||||
|
//predict the next state of each tracklet
|
||||||
|
for (auto& tracklet : tracklets_) {
|
||||||
|
tracklet.age++;
|
||||||
|
pair<Matrix<double, 8, 1>, Matrix<double, 8, 8>> motion = kf.predict(tracklet.mean, tracklet.covariance);
|
||||||
|
tracklet.bbox << motion.first(0),motion.first(1),motion.first(2)*motion.first(3),motion.first(3);
|
||||||
|
tracklet.mean = motion.first;
|
||||||
|
tracklet.covariance = motion.second;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
//Match the detections to the existing tracklets
|
||||||
|
cout<<"the num of targets: "<<tgts.targets.size()<<endl;
|
||||||
|
cout<<"the num of tracklets: "<<tracklets_.size()<<endl;
|
||||||
|
vector<vector<double>> iouMatrix(tracklets_.size(), vector<double> (tgts.targets.size(), 0)); //
|
||||||
|
for (int i = 0; i <tracklets_.size(); i++) {
|
||||||
|
for (int j = 0; j < tgts.targets.size(); j++) {
|
||||||
|
sv::Box box;
|
||||||
|
tgts.targets[j].getBox(box);
|
||||||
|
iouMatrix[i][j] = iou(tracklets_[i], box);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
vector<pair<int,int>> matches = hungarian(iouMatrix);
|
||||||
|
for (auto& match : matches) {
|
||||||
|
int trackletIndex = match.first;
|
||||||
|
int detectionIndex = match.second;
|
||||||
|
if (trackletIndex>=0 && detectionIndex>=0) {
|
||||||
|
if(iouMatrix[match.first][match.second]>=0) {
|
||||||
|
sv::Box box;
|
||||||
|
tgts.targets[detectionIndex].getBox(box);
|
||||||
|
tracklets_[trackletIndex].age = 0;
|
||||||
|
tracklets_[trackletIndex].hits++;
|
||||||
|
tracklets_[trackletIndex].bbox << box.x1,box.y1,box.x2-box.x1,box.y2-box.y1;
|
||||||
|
|
||||||
|
auto[mean,covariance]=kf.update(tracklets_[trackletIndex].mean,tracklets_[trackletIndex].covariance, box);
|
||||||
|
tracklets_[trackletIndex].mean = mean;
|
||||||
|
tracklets_[trackletIndex].covariance = covariance;
|
||||||
|
|
||||||
|
tgts.targets[detectionIndex].tracked_id = tracklets_[trackletIndex].id;
|
||||||
|
match_det[detectionIndex]=detectionIndex;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// create new tracklets for unmatched detections
|
||||||
|
for (int i = 0; i < tgts.targets.size(); i++)
|
||||||
|
{
|
||||||
|
if (match_det[i]==-1) {
|
||||||
|
sv::Box box;
|
||||||
|
tgts.targets[i].getBox(box);
|
||||||
|
Tracklet tracklet;
|
||||||
|
tracklet.id = ++next_tracklet_id_;//
|
||||||
|
tracklet.bbox << box.x1,box.y1,box.x2-box.x1,box.y2-box.y1;
|
||||||
|
|
||||||
|
tracklet.age = 0;
|
||||||
|
tracklet.hits = 1;
|
||||||
|
tracklet.misses = 0;
|
||||||
|
|
||||||
|
auto [new_mean,new_covariance] = kf.initiate(tracklet.bbox);
|
||||||
|
tracklet.mean = new_mean;
|
||||||
|
tracklet.covariance = new_covariance;
|
||||||
|
|
||||||
|
tgts.targets[i].tracked_id = next_tracklet_id_;
|
||||||
|
tracklets_.push_back(tracklet);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
vector<Tracklet> newTracklets;
|
||||||
|
for (auto& tracklet : tracklets_) {
|
||||||
|
if (tracklet.age < maxAge_ || tracklet.hits >= minHits_) {
|
||||||
|
newTracklets.push_back(tracklet);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tracklets_ = newTracklets;
|
||||||
|
*/
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<Tracklet> SORT::getTracklets() const{
|
||||||
|
return tracklets_;
|
||||||
|
}
|
||||||
|
|
||||||
|
double SORT::iou(Tracklet & tracklet, sv::Box & box) {
|
||||||
|
|
||||||
|
double trackletX1 = tracklet.bbox(0);
|
||||||
|
double trackletY1 = tracklet.bbox(1);
|
||||||
|
double trackletX2 = tracklet.bbox(0)+tracklet.bbox(2);
|
||||||
|
double trackletY2 = tracklet.bbox(1)+tracklet.bbox(3);
|
||||||
|
|
||||||
|
double detectionX1 = box.x1;
|
||||||
|
double detectionY1 = box.y1;
|
||||||
|
double detectionX2 = box.x2;
|
||||||
|
double detectionY2 = box.y2;
|
||||||
|
double intersectionX1 = max(trackletX1, detectionX1);
|
||||||
|
double intersectionY1 = max(trackletY1, detectionY1);
|
||||||
|
double intersectionX2 = min(trackletX2, detectionX2);
|
||||||
|
double intersectionY2 = min(trackletY2, detectionY2);
|
||||||
|
|
||||||
|
double w = (intersectionX2-intersectionX1>0.0) ? (intersectionX2-intersectionX1):0.0;
|
||||||
|
double h = (intersectionY2-intersectionY1>0.0) ? (intersectionY2-intersectionY1):0.0;
|
||||||
|
double intersectionArea = w*h;
|
||||||
|
|
||||||
|
double trackletArea = tracklet.bbox(2)*tracklet.bbox(3);
|
||||||
|
|
||||||
|
double detectionArea = (box.x2-box.x1)*(box.y2-box.y1);
|
||||||
|
double unionArea = trackletArea + detectionArea - intersectionArea;
|
||||||
|
double iou = (1-intersectionArea / unionArea*1.0);
|
||||||
|
|
||||||
|
return iou;
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<pair<int, int>> SORT::hungarian(vector<vector<double>> costMatrix) {
|
||||||
|
int numRows = costMatrix.size();
|
||||||
|
int numCols = costMatrix[0].size();
|
||||||
|
|
||||||
|
const bool transposed = numCols > numRows;
|
||||||
|
// transpose the matrix if necessary
|
||||||
|
if (transposed) {
|
||||||
|
vector<vector<double>> transposedMatrix(numCols, vector<double>(numRows));
|
||||||
|
for (int i = 0; i < numRows; i++) {
|
||||||
|
for (int j = 0; j < numCols; j++) {
|
||||||
|
transposedMatrix[j][i] = costMatrix[i][j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
costMatrix = transposedMatrix;
|
||||||
|
swap(numRows, numCols);
|
||||||
|
}
|
||||||
|
vector<double>rowMin (numRows, numeric_limits<double>::infinity());
|
||||||
|
vector<double>colMin(numCols, numeric_limits<double>::infinity());
|
||||||
|
vector<int>rowMatch(numRows, -1);
|
||||||
|
vector<int>colMatch(numCols, -1);
|
||||||
|
vector<pair<int, int>> matches;
|
||||||
|
//step1: Subtract the row minimums from each row
|
||||||
|
for (int i = 0; i < numRows; i++) {
|
||||||
|
for (int j = 0; j < numCols; j++) {
|
||||||
|
rowMin[i] = min(rowMin[i], costMatrix[i][j]);
|
||||||
|
}
|
||||||
|
for (int j = 0; j < numCols; j++) {
|
||||||
|
costMatrix[i][j] -= rowMin[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//step2: substract the colcum minimums from each column
|
||||||
|
for (int j = 0; j < numCols; j++) {
|
||||||
|
for (int i = 0; i < numRows; i++) {
|
||||||
|
colMin[j] = min(colMin[j], costMatrix[i][j]);
|
||||||
|
}
|
||||||
|
for (int i = 0; i < numRows; i++) {
|
||||||
|
costMatrix[i][j] -= colMin[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//step3: find a maximal matching
|
||||||
|
for (int i = 0; i < numRows; i++) {
|
||||||
|
vector<bool> visited(numCols, false);
|
||||||
|
Augment(costMatrix, i, rowMatch, colMatch, visited);
|
||||||
|
}
|
||||||
|
//step4: calculate the matches
|
||||||
|
matches.clear();
|
||||||
|
for (int j = 0; j < numCols; j++) {
|
||||||
|
matches.push_back(make_pair(colMatch[j], j));
|
||||||
|
}
|
||||||
|
if (transposed) {
|
||||||
|
//vector<pair<int,int>> transposedMatches;
|
||||||
|
for (auto& match : matches) {
|
||||||
|
//transposedMatches.push_back(make_pair(match.second, match.first));
|
||||||
|
swap(match.first,match.second);
|
||||||
|
}
|
||||||
|
//matches = transposedMatches;
|
||||||
|
}
|
||||||
|
return matches;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SORT::Augment(const vector<vector<double>>& costMatrix, int row, vector<int>& rowMatch, vector<int>& colMatch, vector<bool>& visited)
|
||||||
|
// bool SORT::Augment(const array<array<double>& costMatrix, int row, vector<int>& rowMatch, vector<int>& colMatch, vector<bool>& visited)
|
||||||
|
{
|
||||||
|
int numCols = costMatrix[0].size();
|
||||||
|
for (int j = 0; j < numCols; j++) {
|
||||||
|
if (costMatrix[row][j] == 0 && !visited[j]) {
|
||||||
|
visited[j] = true;
|
||||||
|
if (colMatch[j] == -1 || Augment(costMatrix, colMatch[j], rowMatch, colMatch, visited)) {
|
||||||
|
rowMatch[row] = j;
|
||||||
|
colMatch[j] = row;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,60 @@
|
||||||
|
#ifndef __SV_COMMON_MOT__
|
||||||
|
#define __SV_COMMON_MOT__
|
||||||
|
|
||||||
|
#include <Eigen/Dense>
|
||||||
|
#include <iostream>
|
||||||
|
#include <vector>
|
||||||
|
#include <opencv2/opencv.hpp>
|
||||||
|
#include <array>
|
||||||
|
#include "sv_core.h"
|
||||||
|
|
||||||
|
namespace sv{
|
||||||
|
//define the tracklet struct to store the tracked objects.
|
||||||
|
struct Tracklet
|
||||||
|
{
|
||||||
|
/* data */
|
||||||
|
public:
|
||||||
|
Eigen::Vector4d bbox;//double x, y, w, h;
|
||||||
|
int id=0;
|
||||||
|
int age;
|
||||||
|
int hits;
|
||||||
|
int misses;
|
||||||
|
std::vector<double> features;
|
||||||
|
Eigen::Matrix<double,8,1> mean;
|
||||||
|
Eigen::Matrix<double,8,8> covariance;
|
||||||
|
};
|
||||||
|
|
||||||
|
class KalmanFilter {
|
||||||
|
public:
|
||||||
|
KalmanFilter();
|
||||||
|
~KalmanFilter();
|
||||||
|
std::pair<Eigen::Matrix<double,8,1>,Eigen::Matrix<double,8,8>> initiate(Eigen::Vector4d &bbox);
|
||||||
|
std::pair<Eigen::Matrix<double,8,1>,Eigen::Matrix<double,8,8>> update(Eigen::Matrix<double,8,1> mean, Eigen::Matrix<double,8,8> covariances, Box &box);
|
||||||
|
std::pair<Eigen::Matrix<double,8,1>,Eigen::Matrix<double,8,8>> predict(Eigen::Matrix<double,8,1> mean, Eigen::Matrix<double,8,8> covariances);
|
||||||
|
private:
|
||||||
|
Eigen::Matrix<double,8,8> F_;
|
||||||
|
Eigen::Matrix<double,4,8> H_;
|
||||||
|
Eigen::Matrix<double,9,1> chi2inv95_;
|
||||||
|
double std_weight_position;
|
||||||
|
double std_weight_vel;
|
||||||
|
};
|
||||||
|
|
||||||
|
class SORT {
|
||||||
|
public:
|
||||||
|
SORT(double iou_threshold, int max_age, int min_hits): iou_threshold_(iou_threshold),max_age_(max_age),min_hits_(min_hits),next_tracklet_id_(0) {};
|
||||||
|
~SORT();
|
||||||
|
void update(TargetsInFrame &tgts);
|
||||||
|
std::vector<Tracklet> getTracklets() const;
|
||||||
|
private:
|
||||||
|
double iou(Tracklet &tracklet, Box &box);
|
||||||
|
std::vector<std::pair<int,int>> hungarian(std::vector<std::vector<double>> costMatrix);
|
||||||
|
bool Augment(const std::vector<std::vector<double>>& costMatrix, int row, std::vector<int>& rowMatch, std::vector<int>& colMatch, std::vector<bool>& visited);
|
||||||
|
private:
|
||||||
|
double iou_threshold_;
|
||||||
|
int max_age_;
|
||||||
|
int min_hits_;
|
||||||
|
int next_tracklet_id_;
|
||||||
|
std::vector <Tracklet> tracklets_;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
#endif
|
|
@ -8,6 +8,7 @@
|
||||||
#include "sv_color_line.h"
|
#include "sv_color_line.h"
|
||||||
#include "sv_video_input.h"
|
#include "sv_video_input.h"
|
||||||
#include "sv_video_output.h"
|
#include "sv_video_output.h"
|
||||||
|
#include "sv_common_mot.h"
|
||||||
|
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -0,0 +1,187 @@
|
||||||
|
#include <iostream>
|
||||||
|
#include <string>
|
||||||
|
// 包含SpireVision SDK头文件
|
||||||
|
#include <sv_world.h>
|
||||||
|
#include <NvInfer.h>
|
||||||
|
#include <cuda_runtime_api.h>
|
||||||
|
#include "logging.h"
|
||||||
|
|
||||||
|
// #include <opencv2/opencv.hpp>
|
||||||
|
|
||||||
|
#define TRTCHECK(status) \
|
||||||
|
do \
|
||||||
|
{ \
|
||||||
|
auto ret = (status); \
|
||||||
|
if (ret != 0) \
|
||||||
|
{ \
|
||||||
|
std::cerr << "Cuda failure: " << ret << std::endl; \
|
||||||
|
abort(); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace nvinfer1;
|
||||||
|
static Logger g_nvlogger;
|
||||||
|
|
||||||
|
// using namespace cv;
|
||||||
|
using namespace Eigen;
|
||||||
|
|
||||||
|
// nvinfer1::IExecutionContext *g_trt_context;
|
||||||
|
// int g_input_index;
|
||||||
|
// int g_output_index;
|
||||||
|
// void *g_p_buffers[2];
|
||||||
|
// cudaStream_t g_cu_stream;
|
||||||
|
// float *g_p_data;
|
||||||
|
// float *g_p_prob;
|
||||||
|
|
||||||
|
|
||||||
|
// void tensorrt_model_init()
|
||||||
|
// {
|
||||||
|
// std::string trt_model_fn = "/home/amov/SpireVision/models/resnet_reid.engine";
|
||||||
|
// char *trt_model_stream{nullptr};
|
||||||
|
// size_t trt_model_size{0};
|
||||||
|
// try
|
||||||
|
// {
|
||||||
|
// std::ifstream file(trt_model_fn, std::ios::binary);
|
||||||
|
// file.seekg(0, file.end);
|
||||||
|
// trt_model_size = file.tellg();
|
||||||
|
// file.seekg(0, file.beg);
|
||||||
|
// trt_model_stream = new char[trt_model_size];
|
||||||
|
// assert(trt_model_stream);
|
||||||
|
// file.read(trt_model_stream, trt_model_size);
|
||||||
|
// file.close();
|
||||||
|
// }
|
||||||
|
// catch (const std::runtime_error &e)
|
||||||
|
// {
|
||||||
|
// throw std::runtime_error("Error loading the TensorRT model!");
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // TensorRT
|
||||||
|
// IRuntime *runtime = nvinfer1::createInferRuntime(g_nvlogger);
|
||||||
|
// assert(runtime != nullptr);
|
||||||
|
// ICudaEngine *p_cu_engine = runtime->deserializeCudaEngine(trt_model_stream, trt_model_size);
|
||||||
|
// assert(p_cu_engine != nullptr);
|
||||||
|
// g_trt_context = p_cu_engine->createExecutionContext();
|
||||||
|
// assert(g_trt_context != nullptr);
|
||||||
|
// delete[] trt_model_stream;
|
||||||
|
// const ICudaEngine &cu_engine = g_trt_context->getEngine();
|
||||||
|
// assert(cu_engine.getNbBindings() == 2);
|
||||||
|
|
||||||
|
// g_input_index = cu_engine.getBindingIndex("input");
|
||||||
|
// g_output_index = cu_engine.getBindingIndex("output");
|
||||||
|
// TRTCHECK(cudaMalloc(&g_p_buffers[g_input_index], 1 * 3 * 32 * 32 * sizeof(float)));
|
||||||
|
// TRTCHECK(cudaMalloc(&g_p_buffers[g_output_index], 1 * 11 * sizeof(float)));
|
||||||
|
// TRTCHECK(cudaStreamCreate(&g_cu_stream));
|
||||||
|
|
||||||
|
// g_p_data = new float[1 * 3 * 32 * 32];
|
||||||
|
// g_p_prob = new float[1 * 11];
|
||||||
|
// // Input
|
||||||
|
// TRTCHECK(cudaMemcpyAsync(g_p_buffers[g_input_index], g_p_data, 1 * 3 * 32 * 32 * sizeof(float), cudaMemcpyHostToDevice, g_cu_stream));
|
||||||
|
// g_trt_context->enqueue(1, g_p_buffers, g_cu_stream, nullptr);
|
||||||
|
// // Output
|
||||||
|
// TRTCHECK(cudaMemcpyAsync(g_p_prob, g_p_buffers[g_output_index], 1 * 11 * sizeof(float), cudaMemcpyDeviceToHost, g_cu_stream));
|
||||||
|
// cudaStreamSynchronize(g_cu_stream);
|
||||||
|
// }
|
||||||
|
|
||||||
|
int main(int argc, char *argv[]) {
|
||||||
|
// 实例化 通用目标 检测器类
|
||||||
|
sv::CommonObjectDetector cod;
|
||||||
|
// 手动导入相机参数,如果使用Amov的G1等吊舱或相机,则可以忽略该步骤,将自动下载相机参数文件
|
||||||
|
cod.loadCameraParams(sv::get_home() + "/SpireCV/calib_webcam_640x480.yaml");
|
||||||
|
|
||||||
|
// 打开摄像头
|
||||||
|
sv::Camera cap;
|
||||||
|
cap.open(sv::CameraType::WEBCAM, 0);
|
||||||
|
// cv::VideoCapture cap;
|
||||||
|
// cap.open("/home/nvidia/samples/video/MOT17-01.mp4");
|
||||||
|
// cap.open("/home/amov/SpireVision/video/vehicle/vehicle_02.mp4");
|
||||||
|
// 实例化OpenCV的Mat类,用于内存单帧图像
|
||||||
|
cv::Mat img;
|
||||||
|
int frame_id = 0;
|
||||||
|
sv::SORT tracker(0.5, 10, 3);
|
||||||
|
while (1)
|
||||||
|
{
|
||||||
|
// 实例化SpireVision的 单帧检测结果 接口类 TargetsInFrame
|
||||||
|
sv::TargetsInFrame tgts(frame_id++);
|
||||||
|
// 读取一帧图像到img
|
||||||
|
cap.read(img);
|
||||||
|
cv::resize(img, img, cv::Size(cod.image_width, cod.image_height));
|
||||||
|
|
||||||
|
// 执行通用目标检测
|
||||||
|
cod.detect(img, tgts);
|
||||||
|
tracker.update(tgts);
|
||||||
|
//vector<VectorXd> detections;
|
||||||
|
// 可视化检测结果,叠加到img上
|
||||||
|
// sv::drawTargetsInFrame(img, tgts);
|
||||||
|
/*
|
||||||
|
for (int i=0; i<tgts.targets.size(); i++)
|
||||||
|
{
|
||||||
|
sv::Box box;
|
||||||
|
tgts.targets[i].getBox(box);
|
||||||
|
cv::Mat roi = img(cv::Rect(box.x1, box.y1, box.x2-box.x1, box.y2-box.y1)).clone();
|
||||||
|
VectorXd detect(5);
|
||||||
|
detect.fill(-1.0);
|
||||||
|
detect << box.x1,box.y1,box.x2-box.x1,box.y2-box.y1,-1.0;
|
||||||
|
detections.push_back(detect);
|
||||||
|
}
|
||||||
|
tracker.update(detections);
|
||||||
|
|
||||||
|
|
||||||
|
// Input
|
||||||
|
TRTCHECK(cudaMemcpyAsync(g_p_buffers[g_input_index], g_p_data, 1 * 3 * 32 * 32 * sizeof(float), cudaMemcpyHostToDevice, g_cu_stream));
|
||||||
|
g_trt_context->enqueue(1, g_p_buffers, g_cu_stream, nullptr);
|
||||||
|
// Output
|
||||||
|
TRTCHECK(cudaMemcpyAsync(g_p_prob, g_p_buffers[g_output_index], 1 * 11 * sizeof(float), cudaMemcpyDeviceToHost, g_cu_stream));
|
||||||
|
cudaStreamSynchronize(g_cu_stream);
|
||||||
|
|
||||||
|
// Find max index
|
||||||
|
double max = 0;
|
||||||
|
int label = 0;
|
||||||
|
for (int i = 0; i < 11; ++i)
|
||||||
|
{
|
||||||
|
if (max < g_p_prob[i])
|
||||||
|
{
|
||||||
|
max = g_p_prob[i];
|
||||||
|
label = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
//cv::imshow("roi", roi);
|
||||||
|
// cv::waitKey(10);
|
||||||
|
//tracker.update(detections);
|
||||||
|
//vector<Detection>().swap(detections);
|
||||||
|
|
||||||
|
// 可视化检测结果,叠加到img上
|
||||||
|
sv::drawTargetsInFrame(img, tgts);
|
||||||
|
// 控制台打印通用目标检测结果
|
||||||
|
printf("Frame-[%d]\n", frame_id);
|
||||||
|
// 打印当前检测的FPS
|
||||||
|
printf(" FPS = %.2f\n", tgts.fps);
|
||||||
|
// 打印当前相机的视场角(degree)
|
||||||
|
printf(" FOV (fx, fy) = (%.2f, %.2f)\n", tgts.fov_x, tgts.fov_y);
|
||||||
|
for (int i=0; i<tgts.targets.size(); i++)
|
||||||
|
{
|
||||||
|
printf("Frame-[%d], Object-[%d]\n", frame_id, i);
|
||||||
|
// 打印每个目标的中心位置,cx,cy的值域为[0, 1]
|
||||||
|
printf(" Object Center (cx, cy) = (%.3f, %.3f)\n", tgts.targets[i].cx, tgts.targets[i].cy);
|
||||||
|
// 打印每个目标的外接矩形框的宽度、高度,w,h的值域为(0, 1]
|
||||||
|
printf(" Object Size (w, h) = (%.3f, %.3f)\n", tgts.targets[i].w, tgts.targets[i].h);
|
||||||
|
// 打印每个目标的置信度
|
||||||
|
printf(" Object Score = %.3f\n", tgts.targets[i].score);
|
||||||
|
// 打印每个目标的类别,字符串类型
|
||||||
|
printf(" Object Category = %s, Category ID = [%d]\n", tgts.targets[i].category.c_str(), tgts.targets[i].category_id);
|
||||||
|
// 打印每个目标的视线角,跟相机视场相关
|
||||||
|
printf(" Object Line-of-sight (ax, ay) = (%.3f, %.3f)\n", tgts.targets[i].los_ax, tgts.targets[i].los_ay);
|
||||||
|
// 打印每个目标的3D位置(在相机坐标系下),跟目标实际长宽、相机参数相关
|
||||||
|
printf(" Object Position = (x, y, z) = (%.3f, %.3f, %.3f)\n", tgts.targets[i].px, tgts.targets[i].py, tgts.targets[i].pz);
|
||||||
|
// print tracked_id
|
||||||
|
printf(" Object tracked_id = %d\n", tgts.targets[i].tracked_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 显示检测结果img
|
||||||
|
cv::imshow("img", img);
|
||||||
|
cv::waitKey(10);
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
|
@ -0,0 +1,504 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef TENSORRT_LOGGING_H
|
||||||
|
#define TENSORRT_LOGGING_H
|
||||||
|
|
||||||
|
#include "NvInferRuntimeCommon.h"
|
||||||
|
#include <cassert>
|
||||||
|
#include <ctime>
|
||||||
|
#include <iomanip>
|
||||||
|
#include <iostream>
|
||||||
|
#include <ostream>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include "macros.h"
|
||||||
|
|
||||||
|
using Severity = nvinfer1::ILogger::Severity;
|
||||||
|
|
||||||
|
class LogStreamConsumerBuffer : public std::stringbuf
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog)
|
||||||
|
: mOutput(stream)
|
||||||
|
, mPrefix(prefix)
|
||||||
|
, mShouldLog(shouldLog)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other)
|
||||||
|
: mOutput(other.mOutput)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
~LogStreamConsumerBuffer()
|
||||||
|
{
|
||||||
|
// std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence
|
||||||
|
// std::streambuf::pptr() gives a pointer to the current position of the output sequence
|
||||||
|
// if the pointer to the beginning is not equal to the pointer to the current position,
|
||||||
|
// call putOutput() to log the output to the stream
|
||||||
|
if (pbase() != pptr())
|
||||||
|
{
|
||||||
|
putOutput();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// synchronizes the stream buffer and returns 0 on success
|
||||||
|
// synchronizing the stream buffer consists of inserting the buffer contents into the stream,
|
||||||
|
// resetting the buffer and flushing the stream
|
||||||
|
virtual int sync()
|
||||||
|
{
|
||||||
|
putOutput();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void putOutput()
|
||||||
|
{
|
||||||
|
if (mShouldLog)
|
||||||
|
{
|
||||||
|
// prepend timestamp
|
||||||
|
std::time_t timestamp = std::time(nullptr);
|
||||||
|
tm* tm_local = std::localtime(×tamp);
|
||||||
|
std::cout << "[";
|
||||||
|
std::cout << std::setw(2) << std::setfill('0') << 1 + tm_local->tm_mon << "/";
|
||||||
|
std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_mday << "/";
|
||||||
|
std::cout << std::setw(4) << std::setfill('0') << 1900 + tm_local->tm_year << "-";
|
||||||
|
std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_hour << ":";
|
||||||
|
std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_min << ":";
|
||||||
|
std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_sec << "] ";
|
||||||
|
// std::stringbuf::str() gets the string contents of the buffer
|
||||||
|
// insert the buffer contents pre-appended by the appropriate prefix into the stream
|
||||||
|
mOutput << mPrefix << str();
|
||||||
|
// set the buffer to empty
|
||||||
|
str("");
|
||||||
|
// flush the stream
|
||||||
|
mOutput.flush();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void setShouldLog(bool shouldLog)
|
||||||
|
{
|
||||||
|
mShouldLog = shouldLog;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::ostream& mOutput;
|
||||||
|
std::string mPrefix;
|
||||||
|
bool mShouldLog;
|
||||||
|
};
|
||||||
|
|
||||||
|
//!
|
||||||
|
//! \class LogStreamConsumerBase
|
||||||
|
//! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer
|
||||||
|
//!
|
||||||
|
class LogStreamConsumerBase
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog)
|
||||||
|
: mBuffer(stream, prefix, shouldLog)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
LogStreamConsumerBuffer mBuffer;
|
||||||
|
};
|
||||||
|
|
||||||
|
//!
|
||||||
|
//! \class LogStreamConsumer
|
||||||
|
//! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages.
|
||||||
|
//! Order of base classes is LogStreamConsumerBase and then std::ostream.
|
||||||
|
//! This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field
|
||||||
|
//! in LogStreamConsumer and then the address of the buffer is passed to std::ostream.
|
||||||
|
//! This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream.
|
||||||
|
//! Please do not change the order of the parent classes.
|
||||||
|
//!
|
||||||
|
class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
//! \brief Creates a LogStreamConsumer which logs messages with level severity.
|
||||||
|
//! Reportable severity determines if the messages are severe enough to be logged.
|
||||||
|
LogStreamConsumer(Severity reportableSeverity, Severity severity)
|
||||||
|
: LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity)
|
||||||
|
, std::ostream(&mBuffer) // links the stream buffer with the stream
|
||||||
|
, mShouldLog(severity <= reportableSeverity)
|
||||||
|
, mSeverity(severity)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
LogStreamConsumer(LogStreamConsumer&& other)
|
||||||
|
: LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog)
|
||||||
|
, std::ostream(&mBuffer) // links the stream buffer with the stream
|
||||||
|
, mShouldLog(other.mShouldLog)
|
||||||
|
, mSeverity(other.mSeverity)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
void setReportableSeverity(Severity reportableSeverity)
|
||||||
|
{
|
||||||
|
mShouldLog = mSeverity <= reportableSeverity;
|
||||||
|
mBuffer.setShouldLog(mShouldLog);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
static std::ostream& severityOstream(Severity severity)
|
||||||
|
{
|
||||||
|
return severity >= Severity::kINFO ? std::cout : std::cerr;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string severityPrefix(Severity severity)
|
||||||
|
{
|
||||||
|
switch (severity)
|
||||||
|
{
|
||||||
|
case Severity::kINTERNAL_ERROR: return "[F] ";
|
||||||
|
case Severity::kERROR: return "[E] ";
|
||||||
|
case Severity::kWARNING: return "[W] ";
|
||||||
|
case Severity::kINFO: return "[I] ";
|
||||||
|
case Severity::kVERBOSE: return "[V] ";
|
||||||
|
default: assert(0); return "";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool mShouldLog;
|
||||||
|
Severity mSeverity;
|
||||||
|
};
|
||||||
|
|
||||||
|
//! \class Logger
|
||||||
|
//!
|
||||||
|
//! \brief Class which manages logging of TensorRT tools and samples
|
||||||
|
//!
|
||||||
|
//! \details This class provides a common interface for TensorRT tools and samples to log information to the console,
|
||||||
|
//! and supports logging two types of messages:
|
||||||
|
//!
|
||||||
|
//! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal)
|
||||||
|
//! - Test pass/fail messages
|
||||||
|
//!
|
||||||
|
//! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is
|
||||||
|
//! that the logic for controlling the verbosity and formatting of sample output is centralized in one location.
|
||||||
|
//!
|
||||||
|
//! In the future, this class could be extended to support dumping test results to a file in some standard format
|
||||||
|
//! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run).
|
||||||
|
//!
|
||||||
|
//! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger
|
||||||
|
//! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT
|
||||||
|
//! library and messages coming from the sample.
|
||||||
|
//!
|
||||||
|
//! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the
|
||||||
|
//! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger
|
||||||
|
//! object.
|
||||||
|
|
||||||
|
class Logger : public nvinfer1::ILogger
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
Logger(Severity severity = Severity::kWARNING)
|
||||||
|
: mReportableSeverity(severity)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
//!
|
||||||
|
//! \enum TestResult
|
||||||
|
//! \brief Represents the state of a given test
|
||||||
|
//!
|
||||||
|
enum class TestResult
|
||||||
|
{
|
||||||
|
kRUNNING, //!< The test is running
|
||||||
|
kPASSED, //!< The test passed
|
||||||
|
kFAILED, //!< The test failed
|
||||||
|
kWAIVED //!< The test was waived
|
||||||
|
};
|
||||||
|
|
||||||
|
//!
|
||||||
|
//! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger
|
||||||
|
//! \return The nvinfer1::ILogger associated with this Logger
|
||||||
|
//!
|
||||||
|
//! TODO Once all samples are updated to use this method to register the logger with TensorRT,
|
||||||
|
//! we can eliminate the inheritance of Logger from ILogger
|
||||||
|
//!
|
||||||
|
nvinfer1::ILogger& getTRTLogger()
|
||||||
|
{
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
//!
|
||||||
|
//! \brief Implementation of the nvinfer1::ILogger::log() virtual method
|
||||||
|
//!
|
||||||
|
//! Note samples should not be calling this function directly; it will eventually go away once we eliminate the
|
||||||
|
//! inheritance from nvinfer1::ILogger
|
||||||
|
//!
|
||||||
|
void log(Severity severity, const char* msg) TRT_NOEXCEPT override
|
||||||
|
{
|
||||||
|
LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
//!
|
||||||
|
//! \brief Method for controlling the verbosity of logging output
|
||||||
|
//!
|
||||||
|
//! \param severity The logger will only emit messages that have severity of this level or higher.
|
||||||
|
//!
|
||||||
|
void setReportableSeverity(Severity severity)
|
||||||
|
{
|
||||||
|
mReportableSeverity = severity;
|
||||||
|
}
|
||||||
|
|
||||||
|
//!
|
||||||
|
//! \brief Opaque handle that holds logging information for a particular test
|
||||||
|
//!
|
||||||
|
//! This object is an opaque handle to information used by the Logger to print test results.
|
||||||
|
//! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used
|
||||||
|
//! with Logger::reportTest{Start,End}().
|
||||||
|
//!
|
||||||
|
class TestAtom
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
TestAtom(TestAtom&&) = default;
|
||||||
|
|
||||||
|
private:
|
||||||
|
friend class Logger;
|
||||||
|
|
||||||
|
TestAtom(bool started, const std::string& name, const std::string& cmdline)
|
||||||
|
: mStarted(started)
|
||||||
|
, mName(name)
|
||||||
|
, mCmdline(cmdline)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
bool mStarted;
|
||||||
|
std::string mName;
|
||||||
|
std::string mCmdline;
|
||||||
|
};
|
||||||
|
|
||||||
|
//!
|
||||||
|
//! \brief Define a test for logging
|
||||||
|
//!
|
||||||
|
//! \param[in] name The name of the test. This should be a string starting with
|
||||||
|
//! "TensorRT" and containing dot-separated strings containing
|
||||||
|
//! the characters [A-Za-z0-9_].
|
||||||
|
//! For example, "TensorRT.sample_googlenet"
|
||||||
|
//! \param[in] cmdline The command line used to reproduce the test
|
||||||
|
//
|
||||||
|
//! \return a TestAtom that can be used in Logger::reportTest{Start,End}().
|
||||||
|
//!
|
||||||
|
static TestAtom defineTest(const std::string& name, const std::string& cmdline)
|
||||||
|
{
|
||||||
|
return TestAtom(false, name, cmdline);
|
||||||
|
}
|
||||||
|
|
||||||
|
//!
|
||||||
|
//! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments
|
||||||
|
//! as input
|
||||||
|
//!
|
||||||
|
//! \param[in] name The name of the test
|
||||||
|
//! \param[in] argc The number of command-line arguments
|
||||||
|
//! \param[in] argv The array of command-line arguments (given as C strings)
|
||||||
|
//!
|
||||||
|
//! \return a TestAtom that can be used in Logger::reportTest{Start,End}().
|
||||||
|
static TestAtom defineTest(const std::string& name, int argc, char const* const* argv)
|
||||||
|
{
|
||||||
|
auto cmdline = genCmdlineString(argc, argv);
|
||||||
|
return defineTest(name, cmdline);
|
||||||
|
}
|
||||||
|
|
||||||
|
//!
|
||||||
|
//! \brief Report that a test has started.
|
||||||
|
//!
|
||||||
|
//! \pre reportTestStart() has not been called yet for the given testAtom
|
||||||
|
//!
|
||||||
|
//! \param[in] testAtom The handle to the test that has started
|
||||||
|
//!
|
||||||
|
static void reportTestStart(TestAtom& testAtom)
|
||||||
|
{
|
||||||
|
reportTestResult(testAtom, TestResult::kRUNNING);
|
||||||
|
assert(!testAtom.mStarted);
|
||||||
|
testAtom.mStarted = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
//!
|
||||||
|
//! \brief Report that a test has ended.
|
||||||
|
//!
|
||||||
|
//! \pre reportTestStart() has been called for the given testAtom
|
||||||
|
//!
|
||||||
|
//! \param[in] testAtom The handle to the test that has ended
|
||||||
|
//! \param[in] result The result of the test. Should be one of TestResult::kPASSED,
|
||||||
|
//! TestResult::kFAILED, TestResult::kWAIVED
|
||||||
|
//!
|
||||||
|
static void reportTestEnd(const TestAtom& testAtom, TestResult result)
|
||||||
|
{
|
||||||
|
assert(result != TestResult::kRUNNING);
|
||||||
|
assert(testAtom.mStarted);
|
||||||
|
reportTestResult(testAtom, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
static int reportPass(const TestAtom& testAtom)
|
||||||
|
{
|
||||||
|
reportTestEnd(testAtom, TestResult::kPASSED);
|
||||||
|
return EXIT_SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int reportFail(const TestAtom& testAtom)
|
||||||
|
{
|
||||||
|
reportTestEnd(testAtom, TestResult::kFAILED);
|
||||||
|
return EXIT_FAILURE;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int reportWaive(const TestAtom& testAtom)
|
||||||
|
{
|
||||||
|
reportTestEnd(testAtom, TestResult::kWAIVED);
|
||||||
|
return EXIT_SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int reportTest(const TestAtom& testAtom, bool pass)
|
||||||
|
{
|
||||||
|
return pass ? reportPass(testAtom) : reportFail(testAtom);
|
||||||
|
}
|
||||||
|
|
||||||
|
Severity getReportableSeverity() const
|
||||||
|
{
|
||||||
|
return mReportableSeverity;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
//!
|
||||||
|
//! \brief returns an appropriate string for prefixing a log message with the given severity
|
||||||
|
//!
|
||||||
|
static const char* severityPrefix(Severity severity)
|
||||||
|
{
|
||||||
|
switch (severity)
|
||||||
|
{
|
||||||
|
case Severity::kINTERNAL_ERROR: return "[F] ";
|
||||||
|
case Severity::kERROR: return "[E] ";
|
||||||
|
case Severity::kWARNING: return "[W] ";
|
||||||
|
case Severity::kINFO: return "[I] ";
|
||||||
|
case Severity::kVERBOSE: return "[V] ";
|
||||||
|
default: assert(0); return "";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//!
|
||||||
|
//! \brief returns an appropriate string for prefixing a test result message with the given result
|
||||||
|
//!
|
||||||
|
static const char* testResultString(TestResult result)
|
||||||
|
{
|
||||||
|
switch (result)
|
||||||
|
{
|
||||||
|
case TestResult::kRUNNING: return "RUNNING";
|
||||||
|
case TestResult::kPASSED: return "PASSED";
|
||||||
|
case TestResult::kFAILED: return "FAILED";
|
||||||
|
case TestResult::kWAIVED: return "WAIVED";
|
||||||
|
default: assert(0); return "";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//!
|
||||||
|
//! \brief returns an appropriate output stream (cout or cerr) to use with the given severity
|
||||||
|
//!
|
||||||
|
static std::ostream& severityOstream(Severity severity)
|
||||||
|
{
|
||||||
|
return severity >= Severity::kINFO ? std::cout : std::cerr;
|
||||||
|
}
|
||||||
|
|
||||||
|
//!
|
||||||
|
//! \brief method that implements logging test results
|
||||||
|
//!
|
||||||
|
static void reportTestResult(const TestAtom& testAtom, TestResult result)
|
||||||
|
{
|
||||||
|
severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # "
|
||||||
|
<< testAtom.mCmdline << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
//!
|
||||||
|
//! \brief generate a command line string from the given (argc, argv) values
|
||||||
|
//!
|
||||||
|
static std::string genCmdlineString(int argc, char const* const* argv)
|
||||||
|
{
|
||||||
|
std::stringstream ss;
|
||||||
|
for (int i = 0; i < argc; i++)
|
||||||
|
{
|
||||||
|
if (i > 0)
|
||||||
|
ss << " ";
|
||||||
|
ss << argv[i];
|
||||||
|
}
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
Severity mReportableSeverity;
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
|
||||||
|
//!
|
||||||
|
//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE
|
||||||
|
//!
|
||||||
|
//! Example usage:
|
||||||
|
//!
|
||||||
|
//! LOG_VERBOSE(logger) << "hello world" << std::endl;
|
||||||
|
//!
|
||||||
|
inline LogStreamConsumer LOG_VERBOSE(const Logger& logger)
|
||||||
|
{
|
||||||
|
return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE);
|
||||||
|
}
|
||||||
|
|
||||||
|
//!
|
||||||
|
//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO
|
||||||
|
//!
|
||||||
|
//! Example usage:
|
||||||
|
//!
|
||||||
|
//! LOG_INFO(logger) << "hello world" << std::endl;
|
||||||
|
//!
|
||||||
|
inline LogStreamConsumer LOG_INFO(const Logger& logger)
|
||||||
|
{
|
||||||
|
return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO);
|
||||||
|
}
|
||||||
|
|
||||||
|
//!
|
||||||
|
//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING
|
||||||
|
//!
|
||||||
|
//! Example usage:
|
||||||
|
//!
|
||||||
|
//! LOG_WARN(logger) << "hello world" << std::endl;
|
||||||
|
//!
|
||||||
|
inline LogStreamConsumer LOG_WARN(const Logger& logger)
|
||||||
|
{
|
||||||
|
return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING);
|
||||||
|
}
|
||||||
|
|
||||||
|
//!
|
||||||
|
//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR
|
||||||
|
//!
|
||||||
|
//! Example usage:
|
||||||
|
//!
|
||||||
|
//! LOG_ERROR(logger) << "hello world" << std::endl;
|
||||||
|
//!
|
||||||
|
inline LogStreamConsumer LOG_ERROR(const Logger& logger)
|
||||||
|
{
|
||||||
|
return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
//!
|
||||||
|
//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR
|
||||||
|
// ("fatal" severity)
|
||||||
|
//!
|
||||||
|
//! Example usage:
|
||||||
|
//!
|
||||||
|
//! LOG_FATAL(logger) << "hello world" << std::endl;
|
||||||
|
//!
|
||||||
|
inline LogStreamConsumer LOG_FATAL(const Logger& logger)
|
||||||
|
{
|
||||||
|
return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
#endif // TENSORRT_LOGGING_H
|
|
@ -0,0 +1,29 @@
|
||||||
|
#ifndef __MACROS_H
|
||||||
|
#define __MACROS_H
|
||||||
|
|
||||||
|
#include <NvInfer.h>
|
||||||
|
|
||||||
|
#ifdef API_EXPORTS
|
||||||
|
#if defined(_MSC_VER)
|
||||||
|
#define API __declspec(dllexport)
|
||||||
|
#else
|
||||||
|
#define API __attribute__((visibility("default")))
|
||||||
|
#endif
|
||||||
|
#else
|
||||||
|
|
||||||
|
#if defined(_MSC_VER)
|
||||||
|
#define API __declspec(dllimport)
|
||||||
|
#else
|
||||||
|
#define API
|
||||||
|
#endif
|
||||||
|
#endif // API_EXPORTS
|
||||||
|
|
||||||
|
#if NV_TENSORRT_MAJOR >= 8
|
||||||
|
#define TRT_NOEXCEPT noexcept
|
||||||
|
#define TRT_CONST_ENQUEUE const
|
||||||
|
#else
|
||||||
|
#define TRT_NOEXCEPT
|
||||||
|
#define TRT_CONST_ENQUEUE
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif // __MACROS_H
|
Loading…
Reference in New Issue