sort algorithms v0.2

This commit is contained in:
CZC-123 2023-08-07 17:05:53 +08:00
parent 0e539348d9
commit 8732e54980
6 changed files with 140 additions and 217 deletions

View File

@ -8,7 +8,7 @@ add_definitions(-DAPI_EXPORTS)
set(CMAKE_BUILD_TYPE "Release")
## JETSON, X86_CUDA
## JETSON, X86_CUDA, X86_INTEL
message(STATUS "System:${CMAKE_HOST_SYSTEM_PROCESSOR}")
if(NOT DEFINED PLATFORM)
message(FATAL_ERROR "PLATFORM NOT SPECIFIED!")
@ -23,6 +23,7 @@ else()
option(USE_CUDA "BUILD WITH CUDA." ON)
option(USE_FFMPEG "BUILD WITH FFMPEG." ON)
elseif(PLATFORM STREQUAL "X86_INTEL")
add_definitions(-DPLATFORM_X86_INTEL)
option(USE_FFMPEG "BUILD WITH FFMPEG." ON)
else()
message(FATAL_ERROR "UNSUPPORTED PLATFORM!")
@ -77,6 +78,7 @@ include_directories(
${CMAKE_CURRENT_SOURCE_DIR}/video_io
${CMAKE_CURRENT_SOURCE_DIR}/algorithm/ellipse_det
${CMAKE_CURRENT_SOURCE_DIR}/utils
${CMAKE_CURRENT_SOURCE_DIR}/algorithm/common_mot
)
if(USE_GSTREAMER)
@ -111,6 +113,7 @@ set(
include/sv_gimbal.h
include/sv_algorithm_base.h
include/sv_common_det.h
include/sv_common_mot.h
include/sv_landing_det.h
include/sv_tracking.h
include/sv_color_line.h
@ -152,6 +155,7 @@ set(spirecv_SRCS
algorithm/sv_algorithm_base.cpp
algorithm/ellipse_det/ellipse_detector.cpp
algorithm/common_det/sv_common_det.cpp
algorithm/common_mot/sv_common_mot.cpp
algorithm/landing_det/sv_landing_det.cpp
algorithm/tracking/sv_tracking.cpp
algorithm/color_line/sv_color_line.cpp
@ -187,7 +191,7 @@ list(APPEND spirecv_SRCS ${ALG_SRC_FILES})
endif()
if(USE_CUDA)
if(USE_CUDA) # PLATFORM_X86_CUDA & PLATFORM_JETSON
# CUDA
include_directories(/usr/local/cuda/include)
link_directories(/usr/local/cuda/lib64)
@ -199,24 +203,11 @@ if(USE_CUDA)
target_link_libraries(sv_yoloplugins nvinfer cudart)
cuda_add_library(sv_world SHARED ${spirecv_SRCS})
if(USE_GSTREAMER)
target_link_libraries(
sv_world ${OpenCV_LIBS}
sv_yoloplugins sv_gimbal
nvinfer cudart
gstrtspserver-1.0
)
else()
target_link_libraries(
sv_world ${OpenCV_LIBS}
sv_yoloplugins sv_gimbal
nvinfer cudart
)
endif()
if(USE_FFMPEG)
target_link_libraries(sv_world ${FFMPEG_LIBS} fmt)
endif()
target_link_libraries(
sv_world ${OpenCV_LIBS}
sv_yoloplugins sv_gimbal
nvinfer cudart
)
set(
YOLO_SRCS
@ -232,18 +223,20 @@ if(USE_CUDA)
cuda_add_executable(SpireCVSeg samples/SpireCVSeg.cpp ${YOLO_SRCS})
target_link_libraries(SpireCVSeg sv_world)
elseif(PLATFORM STREQUAL "X86_INTEL")
elseif(PLATFORM STREQUAL "X86_INTEL") # Links to Intel-OpenVINO libraries here
add_library(sv_world SHARED ${spirecv_SRCS})
target_link_libraries(
sv_world ${OpenCV_LIBS}
sv_gimbal
)
if(USE_GSTREAMER)
target_link_libraries(sv_world gstrtspserver-1.0)
endif()
if(USE_FFMPEG)
target_link_libraries(sv_world ${FFMPEG_LIBS} fmt)
endif()
endif()
if(USE_GSTREAMER)
target_link_libraries(sv_world gstrtspserver-1.0)
endif()
if(USE_FFMPEG)
target_link_libraries(sv_world ${FFMPEG_LIBS} fmt)
endif()
#demo
@ -278,12 +271,12 @@ target_link_libraries(GimbalLandingMarkerDetection sv_world)
add_executable(GimbalUdpDetectionInfoSender samples/demo/gimbal_udp_detection_info_sender.cpp)
target_link_libraries(GimbalUdpDetectionInfoSender sv_world)
add_executable(SortTracker samples/demo/common_object_sort.cpp)
target_link_libraries(SortTracker sv_world)
add_executable(EvalFpsOnVideo samples/test/eval_fps_on_video.cpp)
target_link_libraries(EvalFpsOnVideo sv_world)
add_executable(SortTracker samples/demo/common_object_tracking_SORT.cpp samples/demo/SORT_2.cpp)
target_link_libraries(SortTracker sv_world)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/samples/calib)
add_executable(CameraCalibrarion samples/calib/calibrate_camera_charuco.cpp)
target_link_libraries(CameraCalibrarion ${OpenCV_LIBS})

View File

@ -1,104 +1,51 @@
#include "SORT_2.h"
#include "sv_common_mot.h"
#include <cmath>
#include "sv_util.h"
using namespace std;
using namespace Eigen;
/*
void SORT::update(vector<VectorXd> detections) {
//allocate the memory for the vector in advance
int numDetections = detections.size();
detections.reserve(numDetections);
if (! tracklets_.size() || ! detections.size()){
for (int i = 0; i < detections.size(); i++)
{
//if(find(matches.begin(), matches.end(),make_pair(-1, i)) == matches.end())
if (detections[i](4) == -1) {
Tracklet tracklet;
tracklet.id = nextTrackletId_++;//
tracklet.bbox.segment<4>(0) = detections[i].segment<4>(0);
tracklet.age = 0;
tracklet.hits = 1;
tracklet.misses = 0;
//initate the motion
//pair<VectorXd,MatrixXd> motion = kf.initiate(bbox_xywh_to_wyah(detect_mat[i]));
tracklets_.push_back(tracklet);
}
}
}
else{
// create new tracklets for unmatched detections
for (int i = 0; i < detections.size(); i++)
{
if (detections[i](4)== -1) {
Tracklet tracklet;
tracklet.id = nextTrackletId_++;//
tracklet.bbox.segment<4>(0) = detections[i].segment<4>(0);
tracklet.age = 0;
tracklet.hits = 1;
tracklet.misses = 0;
tracklets_.push_back(tracklet);
}
namespace sv {
KalmanFilter::KalmanFilter(){
chi2inv95_ << 3.8415, 5.9915, 7.8147, 9.4877, 11.070, 12.592, 14.067, 15.507, 16.919;
F_=MatrixXd::Identity(8,8);
for (int i=0;i<4;i++){
F_(i,i+4)=1;
}
H_=MatrixXd::Identity(4,8);
std_weight_position=1./20;
std_weight_vel=1./160;
}
}*/
KalmanFilter::KalmanFilter(){
chi2inv95_ << 3.8415, 5.9915, 7.8147, 9.4877, 11.070, 12.592, 14.067, 15.507, 16.919;
/*
F_ << 1, 0, 0, 0, 1, 0, 0, 0,
0, 1, 0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 0, 1, 0,
0, 0, 0, 1, 0, 0, 0, 1,
0, 0, 0, 0, 1, 0, 0, 0,
0, 0, 0, 0, 0, 1, 0, 0,
0, 0, 0, 0, 0, 0, 1, 0,
0, 0, 0, 0, 0, 0, 0, 1;
H_ << 1, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 0, 0, 0, 0, 0,
0, 0, 1, 0, 0, 0, 0, 0,
0, 0, 0, 1, 0, 0, 0, 0;
*/
F_ = MatrixXd::Identity(8, 8);
for (int i = 0; i < 4; i++) {
F_(i, i + 4) = 1;
}
H_ = MatrixXd::Identity(4, 8);
std_weight_position = 1. / 20;
std_weight_vel = 1. / 160;
}
pair<Matrix<double,8,1>,Matrix<double,8,8>> KalmanFilter::initiate(Vector4d & bbox){
Matrix<double,8,1> mean;
mean << bbox(0), bbox(1), bbox(2)/bbox(3), bbox(3), 0, 0, 0, 0;
//cout<<"the a is:"<<bbox(2)/bbox(3)<<endl;
VectorXd stds(8);
stds << 2 * std_weight_position * mean(3), 2 * std_weight_position * mean(3), 0.01, 2 * std_weight_position * mean(3), 10 * std_weight_vel * mean(3), 10 * std_weight_vel * mean(3), 1e-5, 10 * std_weight_vel * mean(3);
MatrixXd squared = stds.array().square();
Matrix<double, 8,8> covariances;
covariances = squared.asDiagonal();
return make_pair(mean, covariances);
KalmanFilter::~KalmanFilter(){};
pair<Matrix<double,8,1>,Matrix<double,8,8>> KalmanFilter::initiate(Vector4d &bbox){
Matrix<double,8,1> mean;
mean<<bbox(0), bbox(1), bbox(2)/bbox(3), bbox(3), 0, 0, 0, 0;
VectorXd stds(8);
stds << 2*std_weight_position*mean(3), 2*std_weight_position*mean(3), 0.01, 2*std_weight_position*mean(3), 10*std_weight_vel*mean(3), 10*std_weight_vel*mean(3), 1e-5, 10*std_weight_vel*mean(3);
MatrixXd squared = stds.array().square();
Matrix<double,8,8> covariances;
covariances = squared.asDiagonal();
return make_pair(mean, covariances);
}
pair<Matrix<double,8,1>,Matrix<double,8,8>> KalmanFilter::update(Matrix<double, 8, 1> mean, Matrix<double, 8, 8> covariances, sv::Box & box) {
MatrixXd R_;
Vector4d stds;
pair<Matrix<double,8,1>,Matrix<double,8,8>> KalmanFilter::update(Matrix<double,8,1> mean, Matrix<double,8,8> covariances, sv::Box &box){
MatrixXd R_;
Vector4d stds;
stds << std_weight_position * mean(3), std_weight_position* mean(3), 0.1, std_weight_position* mean(3);
MatrixXd squared = stds.array().square();
R_ = squared.asDiagonal();
stds << std_weight_position * mean(3), std_weight_position* mean(3), 0.1, std_weight_position* mean(3);
MatrixXd squared = stds.array().square();
R_= squared.asDiagonal();
MatrixXd S = H_ * covariances * H_.transpose()+R_;
MatrixXd S = H_ * covariances * H_.transpose()+R_;
MatrixXd Kalman_gain = covariances * H_.transpose() * S.inverse();
MatrixXd Kalman_gain = covariances * H_.transpose() * S.inverse();
VectorXd measurement(4);
measurement << box.x1, box.y1, (box.x2-box.x1)/(box.y2-box.y1), box.y2-box.y1;
Matrix<double, 8, 1> new_mean = mean + Kalman_gain * (measurement - H_ * mean);
Matrix<double, 8, 8> new_covariances = (MatrixXd::Identity(8, 8) - Kalman_gain * H_) * covariances;
return make_pair(new_mean, new_covariances);
}
}
pair<Matrix<double, 8, 1>, Matrix<double, 8, 8>> KalmanFilter::predict(Matrix<double, 8, 1>mean, Matrix<double, 8, 8>covariances) {
pair<Matrix<double, 8, 1>, Matrix<double, 8, 8>> KalmanFilter::predict(Matrix<double, 8, 1>mean, Matrix<double, 8, 8>covariances) {
VectorXd stds(8);
stds << std_weight_position * mean(3), std_weight_position* mean(3), 0.01, std_weight_position* mean(3), std_weight_vel* mean(3), std_weight_vel* mean(3), 1e-5, std_weight_vel* mean(3);
@ -109,24 +56,25 @@ pair<Matrix<double, 8, 1>, Matrix<double, 8, 8>> KalmanFilter::predict(Matrix<do
return make_pair(pre_mean, pre_cov);
}
SORT::~SORT(){
}
void SORT::update(sv::TargetsInFrame & tgts){
KalmanFilter kf;
sv::KalmanFilter kf;
// if (! tracklets_.size() || ! tgts.targets.size())
if (! tracklets_.size())
{
Vector4d bbox;
for (int i = 0; i <tgts.targets.size(); i++)
{
//if(find(matches.begin(), matches.end(),make_pair(-1, i)) == matches.end())
//if (tgts.targets[i].tracked_id == 0) {
sv::Box box;
tgts.targets[i].getBox(box);
Tracklet tracklet;
tracklet.id = ++nextTrackletId_;//
tracklet.id = ++next_tracklet_id_;//
cout<<tracklet.id<<endl;
tgts.targets[i].tracked_id = nextTrackletId_;
//bbox << box.x1,box.y1,box.x2-box.x1,box.y2-box.y1;
tgts.targets[i].tracked_id = next_tracklet_id_;
tracklet.bbox << box.x1,box.y1,box.x2-box.x1,box.y2-box.y1;//x,y,w,h
tracklet.age = 0;
@ -135,8 +83,6 @@ void SORT::update(sv::TargetsInFrame & tgts){
//initate the motion
pair<Matrix<double,8,1>,Matrix<double,8,8>> motion = kf.initiate(tracklet.bbox);
//cout<<"mean:"<<motion.first<<endl;
//cout<<"mean:"<<tracklet.mean<<endl;
tracklet.mean=motion.first;
tracklet.covariance = motion.second;
@ -177,9 +123,6 @@ void SORT::update(sv::TargetsInFrame & tgts){
for (auto& match : matches) {
int trackletIndex = match.first;
int detectionIndex = match.second;
// cout<<"trackletIndex:"<<match.first<<endl;
// cout<<"detectionIndex:"<<match.second<<endl;
// cout<<"iou_eles:"<<iouMatrix[match.first][match.second]<<endl;
if (trackletIndex>=0 && detectionIndex>=0) {
if(iouMatrix[match.first][match.second]>=0) {
sv::Box box;
@ -204,10 +147,7 @@ void SORT::update(sv::TargetsInFrame & tgts){
sv::Box box;
tgts.targets[i].getBox(box);
Tracklet tracklet;
tracklet.id = ++nextTrackletId_;//
//Vector4d bbox;
//bbox << box.x1,box.y1,box.x2-box.x1,box.y2-box.y1;
//tracklet.bboxes.push_back(bbox);
tracklet.id = ++next_tracklet_id_;//
tracklet.bbox << box.x1,box.y1,box.x2-box.x1,box.y2-box.y1;
tracklet.age = 0;
@ -218,7 +158,7 @@ void SORT::update(sv::TargetsInFrame & tgts){
tracklet.mean = new_mean;
tracklet.covariance = new_covariance;
tgts.targets[i].tracked_id = nextTrackletId_;
tgts.targets[i].tracked_id = next_tracklet_id_;
tracklets_.push_back(tracklet);
}
}
@ -233,6 +173,7 @@ void SORT::update(sv::TargetsInFrame & tgts){
tracklets_ = newTracklets;
*/
}
vector<Tracklet> SORT::getTracklets() const{
return tracklets_;
}
@ -243,17 +184,7 @@ double SORT::iou(Tracklet & tracklet, sv::Box & box) {
double trackletY1 = tracklet.bbox(1);
double trackletX2 = tracklet.bbox(0)+tracklet.bbox(2);
double trackletY2 = tracklet.bbox(1)+tracklet.bbox(3);
/*
double trackletX1 = tracklet.bboxes.back()(0);
double trackletY1 = tracklet.bboxes.back()(1);
double trackletX2 = tracklet.bboxes.back()(0)+tracklet.bboxes.back()(2);
double trackletY2 = tracklet.bboxes.back()(1)+tracklet.bboxes.back()(3);
double trackletX1 = tracklet.x;
double trackletY1 = tracklet.y;
double trackletX2 = tracklet.x+tracklet.w;
double trackletY2 = tracklet.y+tracklet.h;
*/
double detectionX1 = box.x1;
double detectionY1 = box.y1;
double detectionX2 = box.x2;
@ -262,21 +193,21 @@ double SORT::iou(Tracklet & tracklet, sv::Box & box) {
double intersectionY1 = max(trackletY1, detectionY1);
double intersectionX2 = min(trackletX2, detectionX2);
double intersectionY2 = min(trackletY2, detectionY2);
//double intersectionArea = max(intersectionX2-intersectionX1, 0.0) * max(intersectionY2-intersectionY1, 0.0);
double w = (intersectionX2-intersectionX1>0.0) ? (intersectionX2-intersectionX1):0.0;
double h = (intersectionY2-intersectionY1>0.0) ? (intersectionY2-intersectionY1):0.0;
double intersectionArea = w*h;
double trackletArea = tracklet.bbox(2)*tracklet.bbox(3);
//double trackletArea = tracklet.w*tracklet.h;
double detectionArea = (box.x2-box.x1)*(box.y2-box.y1);
double unionArea = trackletArea + detectionArea - intersectionArea;
double iou = (1-intersectionArea / unionArea*1.0);
return iou;
}
vector<pair<int, int>> SORT::hungarian(vector<vector<double>> costMatrix) {
// vector<pair<int, int>> SORT::hungarian(array<array<double>> costMatrix) {
int numRows = costMatrix.size();
int numCols = costMatrix[0].size();
@ -352,3 +283,4 @@ bool SORT::Augment(const vector<vector<double>>& costMatrix, int row, vector<int
}
return false;
}
}

60
include/sv_common_mot.h Normal file
View File

@ -0,0 +1,60 @@
#ifndef __SV_COMMON_MOT__
#define __SV_COMMON_MOT__
#include <Eigen/Dense>
#include <iostream>
#include <vector>
#include <opencv2/opencv.hpp>
#include <array>
#include "sv_core.h"
namespace sv{
//define the tracklet struct to store the tracked objects.
struct Tracklet
{
/* data */
public:
Eigen::Vector4d bbox;//double x, y, w, h;
int id=0;
int age;
int hits;
int misses;
std::vector<double> features;
Eigen::Matrix<double,8,1> mean;
Eigen::Matrix<double,8,8> covariance;
};
class KalmanFilter {
public:
KalmanFilter();
~KalmanFilter();
std::pair<Eigen::Matrix<double,8,1>,Eigen::Matrix<double,8,8>> initiate(Eigen::Vector4d &bbox);
std::pair<Eigen::Matrix<double,8,1>,Eigen::Matrix<double,8,8>> update(Eigen::Matrix<double,8,1> mean, Eigen::Matrix<double,8,8> covariances, Box &box);
std::pair<Eigen::Matrix<double,8,1>,Eigen::Matrix<double,8,8>> predict(Eigen::Matrix<double,8,1> mean, Eigen::Matrix<double,8,8> covariances);
private:
Eigen::Matrix<double,8,8> F_;
Eigen::Matrix<double,4,8> H_;
Eigen::Matrix<double,9,1> chi2inv95_;
double std_weight_position;
double std_weight_vel;
};
class SORT {
public:
SORT(double iou_threshold, int max_age, int min_hits): iou_threshold_(iou_threshold),max_age_(max_age),min_hits_(min_hits),next_tracklet_id_(0) {};
~SORT();
void update(TargetsInFrame &tgts);
std::vector<Tracklet> getTracklets() const;
private:
double iou(Tracklet &tracklet, Box &box);
std::vector<std::pair<int,int>> hungarian(std::vector<std::vector<double>> costMatrix);
bool Augment(const std::vector<std::vector<double>>& costMatrix, int row, std::vector<int>& rowMatch, std::vector<int>& colMatch, std::vector<bool>& visited);
private:
double iou_threshold_;
int max_age_;
int min_hits_;
int next_tracklet_id_;
std::vector <Tracklet> tracklets_;
};
}
#endif

View File

@ -8,6 +8,7 @@
#include "sv_color_line.h"
#include "sv_video_input.h"
#include "sv_video_output.h"
#include "sv_common_mot.h"
#endif

View File

@ -1,62 +0,0 @@
#ifndef SORT_2_H
#define SORT_2_H
#include <Eigen/Dense>
#include <iostream>
#include <vector>
#include <opencv2/opencv.hpp>
#include <sv_world.h>
#include <array>
//define the tracklet struct to store the tracked objects
struct Tracklet {
public:
//Eigen::VectorXd bbox = Eigen::VectorXd::Zero(4);//x1,y1,w,h
//std::vector<Eigen::Vector4d> bboxes;
Eigen::Vector4d bbox;
//double x, y, w, h;
int id=0;
int age;
int hits;
int misses;
std::vector<double> features;
Eigen::Matrix<double,8,1> mean;
Eigen::Matrix<double,8,8> covariance;
};
class KalmanFilter {
public:
KalmanFilter();
std::pair<Eigen::Matrix<double,8,1>,Eigen::Matrix<double,8,8>> initiate(Eigen::Vector4d &bbox);
std::pair<Eigen::Matrix<double,8,1>,Eigen::Matrix<double,8,8>> update(Eigen::Matrix<double,8,1> mean,Eigen::Matrix<double,8,8> covariances,sv::Box & box);
std::pair<Eigen::Matrix<double,8,1>,Eigen::Matrix<double,8,8>> predict(Eigen::Matrix<double,8,1>mean,Eigen::Matrix<double,8,8>covariances);
private:
Eigen::Matrix<double,8,8> F_;
Eigen::Matrix<double,4,8> H_;
Eigen::Matrix<double,9,1> chi2inv95_;
double std_weight_position;
double std_weight_vel;
};
class SORT {
public:
SORT(double iouThreshold, int maxAge, int minHits): iouThreshold_(iouThreshold), maxAge_(maxAge), minHits_(minHits), nextTrackletId_(0) {};
//void update(std::vector<Eigen::VectorXd> detections);
void update(sv::TargetsInFrame & tgts);
std::vector<Tracklet> getTracklets() const;
private:
double iou(Tracklet & tracklet, sv::Box & box);
std::vector<std::pair<int, int>> hungarian(std::vector<std::vector<double>> costMatrix);
bool Augment(const std::vector<std::vector<double>>& costMatrix, int row, std::vector<int>& rowMatch, std::vector<int>& colMatch, std::vector<bool>& visited);
private:
double iouThreshold_;
int maxAge_;
int minHits_;
int nextTrackletId_;
std::vector <Tracklet> tracklets_;
};
#endif

View File

@ -7,7 +7,6 @@
#include "logging.h"
// #include <opencv2/opencv.hpp>
#include "SORT_2.h"
#define TRTCHECK(status) \
do \
@ -88,18 +87,18 @@ int main(int argc, char *argv[]) {
// 实例化 通用目标 检测器类
sv::CommonObjectDetector cod;
// 手动导入相机参数如果使用Amov的G1等吊舱或相机则可以忽略该步骤将自动下载相机参数文件
cod.loadCameraParams(sv::get_home() + "/SpireVision/calib_webcam_640x480.yaml");
cod.loadCameraParams(sv::get_home() + "/SpireCV/calib_webcam_640x480.yaml");
// 打开摄像头
// sv::Camera cap;
// cap.open(sv::CameraType::WEBCAM, 0);
cv::VideoCapture cap;
cap.open("/home/amov/SpireVision/video/predestrian/MOT17-01.mp4");
sv::Camera cap;
cap.open(sv::CameraType::WEBCAM, 0);
// cv::VideoCapture cap;
// cap.open("/home/nvidia/samples/video/MOT17-01.mp4");
// cap.open("/home/amov/SpireVision/video/vehicle/vehicle_02.mp4");
// 实例化OpenCV的Mat类用于内存单帧图像
cv::Mat img;
int frame_id = 0;
SORT tracker(0.5, 10, 3);
sv::SORT tracker(0.5, 10, 3);
while (1)
{
// 实例化SpireVision的 单帧检测结果 接口类 TargetsInFrame