diff --git a/CMakeLists.txt b/CMakeLists.txt index 5e3c8a0..44707a4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -275,6 +275,9 @@ target_link_libraries(GimbalUdpDetectionInfoSender sv_world) add_executable(EvalFpsOnVideo samples/test/eval_fps_on_video.cpp) target_link_libraries(EvalFpsOnVideo sv_world) +add_executable(EvalModelOnCocoVal samples/test/eval_mAP_on_coco_val/eval_mAP_on_coco_val.cpp) +target_link_libraries(EvalModelOnCocoVal sv_world) + include_directories(${CMAKE_CURRENT_SOURCE_DIR}/samples/calib) add_executable(CameraCalibrarion samples/calib/calibrate_camera_charuco.cpp) target_link_libraries(CameraCalibrarion ${OpenCV_LIBS}) diff --git a/samples/test/eval_mAP_on_coco_val/coco_eval.py b/samples/test/eval_mAP_on_coco_val/coco_eval.py new file mode 100644 index 0000000..39de0ed --- /dev/null +++ b/samples/test/eval_mAP_on_coco_val/coco_eval.py @@ -0,0 +1,37 @@ +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval +import os +import json + +if __name__ == '__main__': + path = os.path.abspath(os.path.join(os.getcwd(),"../../..")) + pred_json = 'pd_coco.json' + anno_json = path + '/val2017/instances_val2017.json' + + # use COCO API to load forecast results and annotations + cocoGt = COCO(anno_json) + with open(pred_json,'r') as file: + data = json.load(file) + + # align anno_json with pred_json category_id + gtCatDicts = {} + for anns in range(len(cocoGt.getCatIds())): + gtCatDicts[anns] = cocoGt.getCatIds()[anns] + + pdCatIds=list(set([d['category_id'] for d in data])) + + if not set(pdCatIds).issubset(set(cocoGt.getCatIds())): + for ins in data: + temp = int(gtCatDicts[ins['category_id']]) + ins['category_id'] = temp + + # load prediction results + cocoDt = cocoGt.loadRes(data) + + # create COCO eval object + cocoEval = COCOeval(cocoGt, cocoDt,'bbox') + + # assessment + cocoEval.evaluate() + cocoEval.accumulate() + cocoEval.summarize() diff --git a/samples/test/eval_mAP_on_coco_val/eval_mAP_on_coco_val.cpp b/samples/test/eval_mAP_on_coco_val/eval_mAP_on_coco_val.cpp new file mode 100644 index 0000000..12f6339 --- /dev/null +++ b/samples/test/eval_mAP_on_coco_val/eval_mAP_on_coco_val.cpp @@ -0,0 +1,94 @@ +#include +#include +// 包含SpireCV SDK头文件 +#include + +using namespace std; +using namespace cv; + +//extract name +std::string GetImageFileName(const std::string& imagePath) { + size_t lastSlash = imagePath.find_last_of("/\\"); + if (lastSlash == std::string::npos) { + return imagePath; + } else { + std::string fileName = imagePath.substr(lastSlash + 1); + size_t lastDot = fileName.find_last_of("."); + if (lastDot != std::string::npos) { + return fileName.substr(0, lastDot); + } + return fileName; + } +} + + +int main(int argc, char *argv[]) +{ + // 实例化 通用目标 检测器类 + sv::CommonObjectDetector cod; + // 手动导入相机参数,如果使用Amov的G1等吊舱或相机,则可以忽略该步骤,将自动下载相机参数文件 + cod.loadCameraParams(sv::get_home() + "/SpireCV/calib_webcam_640x480.yaml"); + + //load data + string val_path = sv::get_home() + "/SpireCV/val2017/val2017"; + vector val_image; + glob(val_path, val_image, false); + if (val_image.size() == 0) + { + printf("val_image error!!!\n"); + exit(1); + } + + //preds folder + std::string folder = sv::get_home() + "/SpireCV/val2017/preds"; + int checkStatus = std::system(("if [ -d \"" + folder + "\" ]; then echo; fi").c_str()); + if(checkStatus == 0) + { + int removeStatus = std::system(("rm -rf \"" + folder + "\"").c_str()); + if(removeStatus != 0) + { + printf("remove older preds folder error!!!\n"); + exit(1); + } + } + + int status = std::system(("mkdir \""+folder+"\"").c_str()); + if(status != 0) + { + printf("create preds folder error!!!\n"); + exit(1); + } + + + for (int i = 0; i < val_image.size(); i++) { + + //create pred file + std::string val_image_name = GetImageFileName(val_image[i]); + std::string filename = folder+"/"+ val_image_name + ".txt"; + std::ofstream file(filename); + file.is_open(); + file<