From 0136d3c2dada5859b077d5d8f53f1a26c9136bff Mon Sep 17 00:00:00 2001 From: jario-jin Date: Fri, 16 Jun 2023 10:42:02 +0800 Subject: [PATCH] fst commit --- .gitignore | 42 + CMakeLists.txt | 378 ++ LICENSE | 202 + README.md | 80 +- .../common_det/cuda/common_det_cuda_impl.cpp | 310 ++ .../common_det/cuda/common_det_cuda_impl.h | 58 + .../common_det/cuda/yolov7/calibrator.cpp | 97 + algorithm/common_det/cuda/yolov7/calibrator.h | 36 + algorithm/common_det/cuda/yolov7/config.h | 55 + algorithm/common_det/cuda/yolov7/cuda_utils.h | 18 + algorithm/common_det/cuda/yolov7/logging.h | 504 +++ algorithm/common_det/cuda/yolov7/macros.h | 29 + algorithm/common_det/cuda/yolov7/model.cpp | 628 +++ algorithm/common_det/cuda/yolov7/model.h | 16 + .../common_det/cuda/yolov7/postprocess.cpp | 198 + .../common_det/cuda/yolov7/postprocess.h | 16 + .../common_det/cuda/yolov7/preprocess.cu | 153 + algorithm/common_det/cuda/yolov7/preprocess.h | 15 + algorithm/common_det/cuda/yolov7/types.h | 17 + algorithm/common_det/cuda/yolov7/utils.h | 70 + algorithm/common_det/cuda/yolov7/yololayer.cu | 280 ++ algorithm/common_det/cuda/yolov7/yololayer.h | 106 + algorithm/common_det/sv_common_det.cpp | 62 + algorithm/ellipse_det/ellipse_detector.cpp | 3863 +++++++++++++++++ algorithm/ellipse_det/ellipse_detector.h | 1379 ++++++ .../cuda/landing_det_cuda_impl.cpp | 160 + .../landing_det/cuda/landing_det_cuda_impl.h | 48 + algorithm/landing_det/sv_landing_det.cpp | 48 + algorithm/sv_algorithm_base.cpp | 1142 +++++ .../tracking/ocv470/tracking_ocv470_impl.cpp | 135 + .../tracking/ocv470/tracking_ocv470_impl.h | 40 + algorithm/tracking/sv_tracking.cpp | 41 + build_on_jetson.sh | 9 + build_on_x86_cuda.sh | 9 + gimbal_ctrl/IOs/serial/README.md | 63 + .../IOs/serial/include/serial/impl/unix.h | 221 + .../IOs/serial/include/serial/impl/win.h | 207 + .../IOs/serial/include/serial/serial.h | 796 ++++ .../IOs/serial/include/serial/v8stdint.h | 57 + .../src/impl/list_ports/list_ports_linux.cc | 336 ++ .../src/impl/list_ports/list_ports_osx.cc | 286 ++ .../src/impl/list_ports/list_ports_win.cc | 152 + gimbal_ctrl/IOs/serial/src/impl/unix.cc | 1084 +++++ gimbal_ctrl/IOs/serial/src/impl/win.cc | 646 +++ gimbal_ctrl/IOs/serial/src/serial.cc | 432 ++ gimbal_ctrl/driver/src/FIFO/Ring_Fifo.cc | 212 + gimbal_ctrl/driver/src/FIFO/Ring_Fifo.h | 47 + gimbal_ctrl/driver/src/G1/g1_gimbal_crc32.h | 93 + .../driver/src/G1/g1_gimbal_driver.cpp | 245 ++ gimbal_ctrl/driver/src/G1/g1_gimbal_driver.h | 68 + .../driver/src/G1/g1_gimbal_funtion.cpp | 118 + gimbal_ctrl/driver/src/G1/g1_gimbal_struct.h | 91 + gimbal_ctrl/driver/src/G2/g2_gimbal_crc.h | 166 + .../driver/src/G2/g2_gimbal_driver.cpp | 243 ++ gimbal_ctrl/driver/src/G2/g2_gimbal_driver.h | 90 + .../driver/src/G2/g2_gimbal_funtion.cpp | 81 + .../driver/src/G2/g2_gimbal_iap_funtion.cpp | 357 ++ gimbal_ctrl/driver/src/G2/g2_gimbal_struct.h | 81 + .../driver/src/Q10f/Q10f_gimbal_crc32.h | 27 + .../driver/src/Q10f/Q10f_gimbal_driver.cpp | 258 ++ .../driver/src/Q10f/Q10f_gimbal_driver.h | 71 + .../driver/src/Q10f/Q10f_gimbal_funtion.cpp | 180 + .../driver/src/Q10f/Q10f_gimbal_struct.h | 105 + gimbal_ctrl/driver/src/amov_gimabl.cpp | 239 + gimbal_ctrl/driver/src/amov_gimbal.h | 118 + gimbal_ctrl/driver/src/amov_gimbal_struct.h | 74 + gimbal_ctrl/sv_gimbal.cpp | 411 ++ gimbal_ctrl/sv_gimbal_io.hpp | 68 + include/sv_algorithm_base.h | 172 + include/sv_common_det.h | 39 + include/sv_core.h | 8 + include/sv_gimbal.h | 162 + include/sv_landing_det.h | 33 + include/sv_tracking.h | 32 + 
include/sv_video_base.h | 399 ++ include/sv_video_input.h | 27 + include/sv_video_output.h | 53 + include/sv_world.h | 12 + samples/SpireCVDet.cpp | 120 + samples/SpireCVSeg.cpp | 120 + samples/calib/aruco_samples_utility.hpp | 48 + samples/calib/calibrate_camera_charuco.cpp | 293 ++ samples/demo/aruco_detection.cpp | 74 + samples/demo/camera_reading.cpp | 27 + samples/demo/common_object_detection.cpp | 72 + .../demo/detection_with_clicked_tracking.cpp | 193 + samples/demo/ellipse_detection.cpp | 70 + samples/demo/landing_marker_detection.cpp | 72 + samples/demo/single_object_tracking.cpp | 139 + samples/demo/udp_detection_info_receiver.cpp | 248 ++ samples/demo/udp_detection_info_sender.cpp | 60 + samples/demo/video_saving.cpp | 51 + samples/demo/video_streaming.cpp | 35 + scripts/common/download_test_videos.sh | 28 + scripts/common/ffmpeg425-install.sh | 63 + scripts/common/gst-install-orin.sh | 23 + scripts/common/gst-install.sh | 19 + scripts/common/models-converting.sh | 41 + scripts/common/models-downloading.sh | 107 + scripts/common/opencv470-install.sh | 56 + .../jetson/opencv470-jetpack511-install.sh | 54 + scripts/x86-cuda/x86-gst-install.sh | 19 + scripts/x86-cuda/x86-opencv470-install.sh | 56 + .../x86-ubuntu2004-cuda-cudnn-11-6.sh | 98 + utils/gason.cpp | 396 ++ utils/gason.h | 145 + utils/sv_crclib.cpp | 575 +++ utils/sv_crclib.h | 32 + utils/sv_util.cpp | 139 + utils/sv_util.h | 34 + video_io/ffmpeg/bs_common.h | 49 + video_io/ffmpeg/bs_push_streamer.cpp | 377 ++ video_io/ffmpeg/bs_push_streamer.h | 93 + video_io/ffmpeg/bs_video_saver.cpp | 392 ++ video_io/ffmpeg/bs_video_saver.h | 90 + .../gstreamer/streamer_gstreamer_impl.cpp | 102 + video_io/gstreamer/streamer_gstreamer_impl.h | 48 + video_io/gstreamer/writer_gstreamer_impl.cpp | 61 + video_io/gstreamer/writer_gstreamer_impl.h | 37 + video_io/sv_video_base.cpp | 1310 ++++++ video_io/sv_video_input.cpp | 77 + video_io/sv_video_output.cpp | 151 + 122 files changed, 25173 insertions(+), 25 deletions(-) create mode 100644 CMakeLists.txt create mode 100644 LICENSE create mode 100644 algorithm/common_det/cuda/common_det_cuda_impl.cpp create mode 100644 algorithm/common_det/cuda/common_det_cuda_impl.h create mode 100644 algorithm/common_det/cuda/yolov7/calibrator.cpp create mode 100644 algorithm/common_det/cuda/yolov7/calibrator.h create mode 100644 algorithm/common_det/cuda/yolov7/config.h create mode 100644 algorithm/common_det/cuda/yolov7/cuda_utils.h create mode 100644 algorithm/common_det/cuda/yolov7/logging.h create mode 100644 algorithm/common_det/cuda/yolov7/macros.h create mode 100644 algorithm/common_det/cuda/yolov7/model.cpp create mode 100644 algorithm/common_det/cuda/yolov7/model.h create mode 100644 algorithm/common_det/cuda/yolov7/postprocess.cpp create mode 100644 algorithm/common_det/cuda/yolov7/postprocess.h create mode 100644 algorithm/common_det/cuda/yolov7/preprocess.cu create mode 100644 algorithm/common_det/cuda/yolov7/preprocess.h create mode 100644 algorithm/common_det/cuda/yolov7/types.h create mode 100644 algorithm/common_det/cuda/yolov7/utils.h create mode 100644 algorithm/common_det/cuda/yolov7/yololayer.cu create mode 100644 algorithm/common_det/cuda/yolov7/yololayer.h create mode 100644 algorithm/common_det/sv_common_det.cpp create mode 100644 algorithm/ellipse_det/ellipse_detector.cpp create mode 100644 algorithm/ellipse_det/ellipse_detector.h create mode 100644 algorithm/landing_det/cuda/landing_det_cuda_impl.cpp create mode 100644 algorithm/landing_det/cuda/landing_det_cuda_impl.h create mode 
100644 algorithm/landing_det/sv_landing_det.cpp create mode 100644 algorithm/sv_algorithm_base.cpp create mode 100644 algorithm/tracking/ocv470/tracking_ocv470_impl.cpp create mode 100644 algorithm/tracking/ocv470/tracking_ocv470_impl.h create mode 100644 algorithm/tracking/sv_tracking.cpp create mode 100755 build_on_jetson.sh create mode 100755 build_on_x86_cuda.sh create mode 100755 gimbal_ctrl/IOs/serial/README.md create mode 100755 gimbal_ctrl/IOs/serial/include/serial/impl/unix.h create mode 100755 gimbal_ctrl/IOs/serial/include/serial/impl/win.h create mode 100755 gimbal_ctrl/IOs/serial/include/serial/serial.h create mode 100755 gimbal_ctrl/IOs/serial/include/serial/v8stdint.h create mode 100755 gimbal_ctrl/IOs/serial/src/impl/list_ports/list_ports_linux.cc create mode 100755 gimbal_ctrl/IOs/serial/src/impl/list_ports/list_ports_osx.cc create mode 100755 gimbal_ctrl/IOs/serial/src/impl/list_ports/list_ports_win.cc create mode 100755 gimbal_ctrl/IOs/serial/src/impl/unix.cc create mode 100755 gimbal_ctrl/IOs/serial/src/impl/win.cc create mode 100755 gimbal_ctrl/IOs/serial/src/serial.cc create mode 100755 gimbal_ctrl/driver/src/FIFO/Ring_Fifo.cc create mode 100755 gimbal_ctrl/driver/src/FIFO/Ring_Fifo.h create mode 100755 gimbal_ctrl/driver/src/G1/g1_gimbal_crc32.h create mode 100755 gimbal_ctrl/driver/src/G1/g1_gimbal_driver.cpp create mode 100755 gimbal_ctrl/driver/src/G1/g1_gimbal_driver.h create mode 100755 gimbal_ctrl/driver/src/G1/g1_gimbal_funtion.cpp create mode 100755 gimbal_ctrl/driver/src/G1/g1_gimbal_struct.h create mode 100755 gimbal_ctrl/driver/src/G2/g2_gimbal_crc.h create mode 100755 gimbal_ctrl/driver/src/G2/g2_gimbal_driver.cpp create mode 100755 gimbal_ctrl/driver/src/G2/g2_gimbal_driver.h create mode 100644 gimbal_ctrl/driver/src/G2/g2_gimbal_funtion.cpp create mode 100755 gimbal_ctrl/driver/src/G2/g2_gimbal_iap_funtion.cpp create mode 100755 gimbal_ctrl/driver/src/G2/g2_gimbal_struct.h create mode 100755 gimbal_ctrl/driver/src/Q10f/Q10f_gimbal_crc32.h create mode 100755 gimbal_ctrl/driver/src/Q10f/Q10f_gimbal_driver.cpp create mode 100755 gimbal_ctrl/driver/src/Q10f/Q10f_gimbal_driver.h create mode 100755 gimbal_ctrl/driver/src/Q10f/Q10f_gimbal_funtion.cpp create mode 100755 gimbal_ctrl/driver/src/Q10f/Q10f_gimbal_struct.h create mode 100755 gimbal_ctrl/driver/src/amov_gimabl.cpp create mode 100755 gimbal_ctrl/driver/src/amov_gimbal.h create mode 100755 gimbal_ctrl/driver/src/amov_gimbal_struct.h create mode 100644 gimbal_ctrl/sv_gimbal.cpp create mode 100644 gimbal_ctrl/sv_gimbal_io.hpp create mode 100644 include/sv_algorithm_base.h create mode 100644 include/sv_common_det.h create mode 100644 include/sv_core.h create mode 100644 include/sv_gimbal.h create mode 100644 include/sv_landing_det.h create mode 100644 include/sv_tracking.h create mode 100644 include/sv_video_base.h create mode 100644 include/sv_video_input.h create mode 100644 include/sv_video_output.h create mode 100644 include/sv_world.h create mode 100644 samples/SpireCVDet.cpp create mode 100644 samples/SpireCVSeg.cpp create mode 100644 samples/calib/aruco_samples_utility.hpp create mode 100644 samples/calib/calibrate_camera_charuco.cpp create mode 100644 samples/demo/aruco_detection.cpp create mode 100644 samples/demo/camera_reading.cpp create mode 100644 samples/demo/common_object_detection.cpp create mode 100644 samples/demo/detection_with_clicked_tracking.cpp create mode 100644 samples/demo/ellipse_detection.cpp create mode 100644 samples/demo/landing_marker_detection.cpp create mode 100644 
samples/demo/single_object_tracking.cpp create mode 100644 samples/demo/udp_detection_info_receiver.cpp create mode 100644 samples/demo/udp_detection_info_sender.cpp create mode 100644 samples/demo/video_saving.cpp create mode 100644 samples/demo/video_streaming.cpp create mode 100644 scripts/common/download_test_videos.sh create mode 100644 scripts/common/ffmpeg425-install.sh create mode 100644 scripts/common/gst-install-orin.sh create mode 100644 scripts/common/gst-install.sh create mode 100644 scripts/common/models-converting.sh create mode 100644 scripts/common/models-downloading.sh create mode 100644 scripts/common/opencv470-install.sh create mode 100644 scripts/jetson/opencv470-jetpack511-install.sh create mode 100644 scripts/x86-cuda/x86-gst-install.sh create mode 100644 scripts/x86-cuda/x86-opencv470-install.sh create mode 100644 scripts/x86-cuda/x86-ubuntu2004-cuda-cudnn-11-6.sh create mode 100644 utils/gason.cpp create mode 100644 utils/gason.h create mode 100644 utils/sv_crclib.cpp create mode 100644 utils/sv_crclib.h create mode 100644 utils/sv_util.cpp create mode 100644 utils/sv_util.h create mode 100644 video_io/ffmpeg/bs_common.h create mode 100644 video_io/ffmpeg/bs_push_streamer.cpp create mode 100644 video_io/ffmpeg/bs_push_streamer.h create mode 100644 video_io/ffmpeg/bs_video_saver.cpp create mode 100644 video_io/ffmpeg/bs_video_saver.h create mode 100644 video_io/gstreamer/streamer_gstreamer_impl.cpp create mode 100644 video_io/gstreamer/streamer_gstreamer_impl.h create mode 100644 video_io/gstreamer/writer_gstreamer_impl.cpp create mode 100644 video_io/gstreamer/writer_gstreamer_impl.h create mode 100644 video_io/sv_video_base.cpp create mode 100644 video_io/sv_video_input.cpp create mode 100644 video_io/sv_video_output.cpp diff --git a/.gitignore b/.gitignore index 259148f..e377bf6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,21 @@ +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +eggs/ +.eggs/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST +.idea/ + # Prerequisites *.d @@ -30,3 +48,27 @@ *.exe *.out *.app + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# VSCode Editor +.vscode/ + diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..bba4797 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,378 @@ +cmake_minimum_required(VERSION 3.0 FATAL_ERROR) +cmake_policy(SET CMP0054 NEW) + +set(PROJECT_VERSION 0.2.0) +project(SpireCV VERSION ${PROJECT_VERSION} LANGUAGES CXX) + +add_definitions(-DAPI_EXPORTS) +set(CMAKE_BUILD_TYPE "Release") + + +## JETSON, X86_CUDA +message(STATUS "System:${CMAKE_HOST_SYSTEM_PROCESSOR}") +if(NOT DEFINED PLATFORM) + message(FATAL_ERROR "PLATFORM NOT SPECIFIED!") +else() + message(STATUS "PLATFORM: ${PLATFORM}") + if(PLATFORM STREQUAL "JETSON") + add_definitions(-DPLATFORM_JETSON) + option(USE_CUDA "BUILD WITH CUDA." ON) + option(USE_GSTREAMER "BUILD WITH GSTREAMER." ON) + elseif(PLATFORM STREQUAL "X86_CUDA") + add_definitions(-DPLATFORM_X86_CUDA) + option(USE_CUDA "BUILD WITH CUDA." ON) + option(USE_FFMPEG "BUILD WITH FFMPEG." 
ON) + else() + message(FATAL_ERROR "UNSUPPORTED PLATFORM!") + endif() +endif() + + +if(USE_CUDA) + add_definitions(-DWITH_CUDA) + option(CUDA_USE_STATIC_CUDA_RUNTIME OFF) + find_package(CUDA REQUIRED) + message(STATUS "CUDA: ON") +endif() + + +if(USE_GSTREAMER) + add_definitions(-DWITH_GSTREAMER) + message(STATUS "GSTREAMER: ON") +endif() + +if(USE_FFMPEG) + add_definitions(-DWITH_FFMPEG) + find_package(fmt REQUIRED) + set(FFMPEG_LIBS libavutil.so libavcodec.so libavformat.so libavdevice.so libavfilter.so libswscale.so) + message(STATUS "WITH_FFMPEG: ON") +endif() + + +add_definitions(-DWITH_OCV470) +find_package(OpenCV 4.7 REQUIRED) +message(STATUS "OpenCV library status:") +message(STATUS " version: ${OpenCV_VERSION}") +message(STATUS " libraries: ${OpenCV_LIBS}") +message(STATUS " include path: ${OpenCV_INCLUDE_DIRS}") + + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +include_directories( + ${CMAKE_CURRENT_SOURCE_DIR}/gimbal_ctrl/IOs/serial/include + ${CMAKE_CURRENT_SOURCE_DIR}/gimbal_ctrl/driver/src/FIFO + ${CMAKE_CURRENT_SOURCE_DIR}/gimbal_ctrl/driver/src/G1 + ${CMAKE_CURRENT_SOURCE_DIR}/gimbal_ctrl/driver/src/G2 + ${CMAKE_CURRENT_SOURCE_DIR}/gimbal_ctrl/driver/src/Q10f + ${CMAKE_CURRENT_SOURCE_DIR}/gimbal_ctrl/driver/src + ${CMAKE_CURRENT_SOURCE_DIR}/gimbal_ctrl + ${CMAKE_CURRENT_SOURCE_DIR}/algorithm/common_det/cuda + ${CMAKE_CURRENT_SOURCE_DIR}/algorithm/landing_det/cuda + ${CMAKE_CURRENT_SOURCE_DIR}/algorithm/tracking/ocv470 + ${CMAKE_CURRENT_SOURCE_DIR}/video_io + ${CMAKE_CURRENT_SOURCE_DIR}/algorithm/ellipse_det + ${CMAKE_CURRENT_SOURCE_DIR}/utils +) + +if(USE_GSTREAMER) + include_directories(${CMAKE_CURRENT_SOURCE_DIR}/video_io/gstreamer) + if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "aarch64") + include_directories( + "/usr/include/gstreamer-1.0" + "/usr/local/include/gstreamer-1.0" + "/usr/include/glib-2.0" + "/usr/lib/aarch64-linux-gnu/glib-2.0/include" + ) + elseif(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64") + include_directories( + "/usr/include/gstreamer-1.0" + "/usr/local/include/gstreamer-1.0" + "/usr/include/glib-2.0" + "/usr/lib/x86_64-linux-gnu/glib-2.0/include" + ) + endif() +endif() + +if(USE_FFMPEG) + include_directories(${CMAKE_CURRENT_SOURCE_DIR}/video_io/ffmpeg) +endif() + + +# Public header +set( + public_HEADS + include/sv_core.h + include/sv_video_base.h + include/sv_gimbal.h + include/sv_algorithm_base.h + include/sv_common_det.h + include/sv_landing_det.h + include/sv_tracking.h + include/sv_video_input.h + include/sv_video_output.h + include/sv_world.h +) + +# Gimbal Sources +set(serial_SRCS + gimbal_ctrl/IOs/serial/src/serial.cc +) +list(APPEND serial_SRCS gimbal_ctrl/IOs/serial/src/impl/unix.cc) +list(APPEND serial_SRCS gimbal_ctrl/IOs/serial/src/impl/list_ports/list_ports_linux.cc) + +set(driver_SRCS + gimbal_ctrl/driver/src/FIFO/Ring_Fifo.cc +) +file(GLOB DRV_LIB_FILES ${CMAKE_CURRENT_SOURCE_DIR}/gimbal_ctrl/driver/src/G1/*.cpp) +list(APPEND driver_SRCS ${DRV_LIB_FILES}) +file(GLOB DRV_LIB_FILES ${CMAKE_CURRENT_SOURCE_DIR}/gimbal_ctrl/driver/src/G2/*.cpp) +list(APPEND driver_SRCS ${DRV_LIB_FILES}) +file(GLOB DRV_LIB_FILES ${CMAKE_CURRENT_SOURCE_DIR}/gimbal_ctrl/driver/src/Q10f/*.cpp) +list(APPEND driver_SRCS ${DRV_LIB_FILES}) +file(GLOB DRV_LIB_FILES ${CMAKE_CURRENT_SOURCE_DIR}/gimbal_ctrl/driver/src/*.cpp) +list(APPEND driver_SRCS ${DRV_LIB_FILES}) + +set(gimbal_SRCS + gimbal_ctrl/sv_gimbal.cpp + gimbal_ctrl/sv_gimbal_io.hpp +) + +# Gimbal Lib +add_library(sv_gimbal SHARED ${serial_SRCS} ${driver_SRCS} ${gimbal_SRCS}) 
+target_link_libraries(sv_gimbal rt pthread) + + +set(spirecv_SRCS + algorithm/sv_algorithm_base.cpp + algorithm/ellipse_det/ellipse_detector.cpp + algorithm/common_det/sv_common_det.cpp + algorithm/landing_det/sv_landing_det.cpp + algorithm/tracking/sv_tracking.cpp +) + +file(GLOB ALG_SRC_FILES ${CMAKE_CURRENT_SOURCE_DIR}/algorithm/tracking/ocv470/*.cpp) +list(APPEND spirecv_SRCS ${ALG_SRC_FILES}) +file(GLOB ALG_SRC_FILES ${CMAKE_CURRENT_SOURCE_DIR}/video_io/*.cpp) +list(APPEND spirecv_SRCS ${ALG_SRC_FILES}) +file(GLOB ALG_SRC_FILES ${CMAKE_CURRENT_SOURCE_DIR}/utils/*.cpp) +list(APPEND spirecv_SRCS ${ALG_SRC_FILES}) + +if(USE_CUDA) + list(APPEND spirecv_SRCS algorithm/common_det/cuda/yolov7/preprocess.cu) + file(GLOB ALG_SRC_FILES ${CMAKE_CURRENT_SOURCE_DIR}/algorithm/common_det/cuda/*.cpp) + list(APPEND spirecv_SRCS ${ALG_SRC_FILES}) + file(GLOB ALG_SRC_FILES ${CMAKE_CURRENT_SOURCE_DIR}/algorithm/common_det/cuda/yolov7/*.cpp) + list(APPEND spirecv_SRCS ${ALG_SRC_FILES}) + file(GLOB ALG_SRC_FILES ${CMAKE_CURRENT_SOURCE_DIR}/algorithm/landing_det/cuda/*.cpp) + list(APPEND spirecv_SRCS ${ALG_SRC_FILES}) +endif() + +if(USE_FFMPEG) + file(GLOB ALG_SRC_FILES ${CMAKE_CURRENT_SOURCE_DIR}/video_io/ffmpeg/*.cpp) + list(APPEND spirecv_SRCS ${ALG_SRC_FILES}) +endif() + +if(USE_GSTREAMER) +file(GLOB ALG_SRC_FILES ${CMAKE_CURRENT_SOURCE_DIR}/video_io/gstreamer/*.cpp) +list(APPEND spirecv_SRCS ${ALG_SRC_FILES}) +endif() + + +if(USE_CUDA) + # CUDA + include_directories(/usr/local/cuda/include) + link_directories(/usr/local/cuda/lib64) + # TensorRT + include_directories(/usr/include/x86_64-linux-gnu) + link_directories(/usr/lib/x86_64-linux-gnu) + # Add library + cuda_add_library(sv_yoloplugins SHARED algorithm/common_det/cuda/yolov7/yololayer.cu) + target_link_libraries(sv_yoloplugins nvinfer cudart) + + cuda_add_library(sv_world SHARED ${spirecv_SRCS}) + if(USE_GSTREAMER) + target_link_libraries( + sv_world ${OpenCV_LIBS} + sv_yoloplugins sv_gimbal + nvinfer cudart + gstrtspserver-1.0 + ) + else() + target_link_libraries( + sv_world ${OpenCV_LIBS} + sv_yoloplugins sv_gimbal + nvinfer cudart + ) + endif() + + if(USE_FFMPEG) + target_link_libraries(sv_world ${FFMPEG_LIBS} fmt) + endif() + + set( + YOLO_SRCS + algorithm/common_det/cuda/yolov7/preprocess.cu + algorithm/common_det/cuda/yolov7/postprocess.cpp + algorithm/common_det/cuda/yolov7/model.cpp + algorithm/common_det/cuda/yolov7/calibrator.cpp + ) + + cuda_add_executable(SpireCVDet samples/SpireCVDet.cpp ${YOLO_SRCS}) + target_link_libraries(SpireCVDet sv_world) + + cuda_add_executable(SpireCVSeg samples/SpireCVSeg.cpp ${YOLO_SRCS}) + target_link_libraries(SpireCVSeg sv_world) + +elseif(PLATFORM STREQUAL "X86_CPU") + add_library(sv_world SHARED ${spirecv_SRCS}) + target_link_libraries( + sv_world ${OpenCV_LIBS} + sv_gimbal + ) + if(USE_GSTREAMER) + target_link_libraries(sv_world gstrtspserver-1.0) + endif() + if(USE_FFMPEG) + target_link_libraries(sv_world ${FFMPEG_LIBS} fmt) + endif() +endif() + + +add_executable(ArucoDetection samples/demo/aruco_detection.cpp) +target_link_libraries(ArucoDetection sv_world) +add_executable(CameraReading samples/demo/camera_reading.cpp) +target_link_libraries(CameraReading sv_world) +add_executable(CommonObjectDetection samples/demo/common_object_detection.cpp) +target_link_libraries(CommonObjectDetection sv_world) +add_executable(DetectionWithClickedTracking samples/demo/detection_with_clicked_tracking.cpp) +target_link_libraries(DetectionWithClickedTracking sv_world) +add_executable(EllipseDetection 
samples/demo/ellipse_detection.cpp) +target_link_libraries(EllipseDetection sv_world) +add_executable(LandingMarkerDetection samples/demo/landing_marker_detection.cpp) +target_link_libraries(LandingMarkerDetection sv_world) +add_executable(SingleObjectTracking samples/demo/single_object_tracking.cpp) +target_link_libraries(SingleObjectTracking sv_world) +add_executable(UdpDetectionInfoReceiver samples/demo/udp_detection_info_receiver.cpp) +target_link_libraries(UdpDetectionInfoReceiver sv_world) +add_executable(UdpDetectionInfoSender samples/demo/udp_detection_info_sender.cpp) +target_link_libraries(UdpDetectionInfoSender sv_world) +add_executable(VideoSaving samples/demo/video_saving.cpp) +target_link_libraries(VideoSaving sv_world) +add_executable(VideoStreaming samples/demo/video_streaming.cpp) +target_link_libraries(VideoStreaming sv_world) + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/samples/calib) +add_executable(CameraCalibrarion samples/calib/calibrate_camera_charuco.cpp) +target_link_libraries(CameraCalibrarion ${OpenCV_LIBS}) + + +message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}") +if (NOT DEFINED SV_INSTALL_PREFIX) + set(SV_INSTALL_PREFIX ${CMAKE_INSTALL_PREFIX}) + message(STATUS "SV_INSTALL_PREFIX: ${SV_INSTALL_PREFIX}") +else() + message(STATUS "SV_INSTALL_PREFIX: ${SV_INSTALL_PREFIX}") +endif() + + +if(USE_CUDA) + install(TARGETS sv_gimbal sv_yoloplugins sv_world + LIBRARY DESTINATION lib + ) + install(TARGETS SpireCVDet SpireCVSeg + RUNTIME DESTINATION bin + ) +elseif(PLATFORM STREQUAL "X86_CPU") + install(TARGETS sv_world + LIBRARY DESTINATION lib + ) +endif() + +install(FILES ${public_HEADS} + DESTINATION include +) + + +if(PLATFORM STREQUAL "JETSON") +file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/build/${PROJECT_NAME}Config.cmake.in [[ +@PACKAGE_INIT@ +find_package(OpenCV 4 REQUIRED) +link_directories(/usr/local/cuda/lib64) +set(SV_INCLUDE_DIRS + @SV_INSTALL_PREFIX@/include + /usr/include/x86_64-linux-gnu + /usr/local/cuda/include + ${OpenCV_INCLUDE_DIRS} + /usr/include/gstreamer-1.0 + /usr/local/include/gstreamer-1.0 + /usr/include/glib-2.0 + /usr/lib/aarch64-linux-gnu/glib-2.0/include +) +set(SV_LIBRARIES + @SV_INSTALL_PREFIX@/lib/libsv_yoloplugins.so + @SV_INSTALL_PREFIX@/lib/libsv_world.so + @SV_INSTALL_PREFIX@/lib/libsv_gimbal.so + ${OpenCV_LIBS} + nvinfer cudart rt pthread + gstrtspserver-1.0 +) +]]) +elseif(PLATFORM STREQUAL "X86_CUDA") +file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/build/${PROJECT_NAME}Config.cmake.in [[ +@PACKAGE_INIT@ +find_package(OpenCV 4 REQUIRED) +find_package(fmt REQUIRED) +link_directories(/usr/local/cuda/lib64) +set(SV_INCLUDE_DIRS + @SV_INSTALL_PREFIX@/include + /usr/include/x86_64-linux-gnu + /usr/local/cuda/include + ${OpenCV_INCLUDE_DIRS} +) +set(SV_LIBRARIES + @SV_INSTALL_PREFIX@/lib/libsv_yoloplugins.so + @SV_INSTALL_PREFIX@/lib/libsv_world.so + @SV_INSTALL_PREFIX@/lib/libsv_gimbal.so + ${OpenCV_LIBS} + nvinfer cudart rt pthread + @FFMPEG_LIBS@ fmt +) +]]) +elseif(PLATFORM STREQUAL "X86_CPU") +file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/build/${PROJECT_NAME}Config.cmake.in [[ +@PACKAGE_INIT@ +find_package(OpenCV 4 REQUIRED) +find_package(fmt REQUIRED) +set(SV_INCLUDE_DIRS + @SV_INSTALL_PREFIX@/include + /usr/include/x86_64-linux-gnu + ${OpenCV_INCLUDE_DIRS} +) +set(SV_LIBRARIES + @SV_INSTALL_PREFIX@/lib/libsv_world.so + @SV_INSTALL_PREFIX@/lib/libsv_gimbal.so + ${OpenCV_LIBS} + rt pthread + @FFMPEG_LIBS@ fmt +) +]]) +endif() + + +include(CMakePackageConfigHelpers) +write_basic_package_version_file( + 
${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}Config-version.cmake + VERSION ${PROJECT_VERSION} + COMPATIBILITY AnyNewerVersion +) +configure_package_config_file(${CMAKE_CURRENT_BINARY_DIR}/build/${PROJECT_NAME}Config.cmake.in + ${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}Config.cmake + INSTALL_DESTINATION lib/cmake/${PROJECT_NAME} +) +install(FILES + ${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}Config.cmake + ${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}Config-version.cmake + DESTINATION lib/cmake/${PROJECT_NAME} +) + + diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index b173686..a7cd067 100644 --- a/README.md +++ b/README.md @@ -1,37 +1,67 @@ -# SpireCV +# SpireCV 智能感知算法库 -#### 介绍 -SpireCV是一个专为智能无人系统打造的边缘实时感知SDK,主要功能包括相机/吊舱控制、视频保存与推流、目标探测识别与跟踪、边缘数据管理迭代等。旨在为移动机器人开发者提供高性能、高可靠、接口简洁、功能丰富的视觉感知能力。 +## 项目概况 -#### 软件架构 -软件架构说明 +SpireCV是一个专为**智能无人系统**打造的**边缘实时感知SDK**,主要功能包括**相机/吊舱控制**、**视频保存与推流**、**目标探测识别与跟踪**、**边缘数据管理迭代**等。旨在为移动机器人开发者提供高性能、高可靠、接口简洁、功能丰富的视觉感知能力。 + - Github:https://github.com/amov-lab/SpireCV + - Gitee:https://gitee.com/amovlab/SpireCV + - **开源项目,维护不易,还烦请点一个star收藏,谢谢支持!** -#### 安装教程 +## 快速入门 -1. xxxx -2. xxxx -3. xxxx + - 安装及使用:[SpireCV使用手册](https://wiki.amovlab.com/public/spirecv-wiki/) + - 需掌握C++语言基础、CMake编译工具基础。 + - 需要掌握OpenCV视觉库基础,了解CUDA、OpenVINO、RKNN和CANN等计算库。 + - 需要了解ROS基本概念及基本操作。 -#### 使用说明 + - 答疑及交流: + - 答疑论坛(官方定期答疑,推荐):[阿木社区-SpireCV问答专区](https://bbs.amovlab.com/) + - 添加微信jiayue199506(备注消息:SpireCV)进入SpireCV智能感知算法库交流群。 + - B站搜索并关注“阿木社区”,开发团队定期直播答疑。 -1. xxxx -2. xxxx -3. xxxx +## 项目框架 -#### 参与贡献 +#### 主要框架如图所示: -1. Fork 本仓库 -2. 新建 Feat_xxx 分支 -3. 提交代码 -4. 新建 Pull Request + +#### 目前支持情况: + - **功能层**: + - [x] 视频算法模块(提供接口统一、性能高效、功能多样的感知算法) + - [x] 视频输入、保存与推流模块(提供稳定、跨平台的视频读写能力) + - [x] 相机、吊舱控制模块(针对典型硬件生态打通接口,易使用) + - [x] 感知信息交互模块(提供UDP通信协议) + - [x] [ROS接口](https://gitee.com/amovlab1/spirecv-ros.git) + - **平台层**: + - [x] X86+Nvidia GPU(推荐10系、20系、30系显卡) + - [x] Jetson(AGX Orin/Xavier、Orin NX/Nano、Xavier NX) + - [ ] Intel CPU(推进中) + - [ ] Rockchip(推进中) + - [ ] HUAWEI Ascend(推进中) -#### 特技 +## 功能展示 + - **二维码检测** + + + - **起降标志检测** + + + - **椭圆检测** + + + - **目标框选跟踪** + + + - **通用目标检测** + + + - **低延迟推流** + + -1. 使用 Readme\_XXX.md 来支持不同的语言,例如 Readme\_en.md, Readme\_zh.md -2. Gitee 官方博客 [blog.gitee.com](https://blog.gitee.com) -3. 你可以 [https://gitee.com/explore](https://gitee.com/explore) 这个地址来了解 Gitee 上的优秀开源项目 -4. [GVP](https://gitee.com/gvp) 全称是 Gitee 最有价值开源项目,是综合评定出的优秀开源项目 -5. Gitee 官方提供的使用手册 [https://gitee.com/help](https://gitee.com/help) -6. 
Gitee 封面人物是一档用来展示 Gitee 会员风采的栏目 [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/) +## 版权声明 + + - 本项目受 Apache License 2.0 协议保护。 + - 本项目仅限个人使用,请勿用于商业用途。 + - 如利用本项目进行营利活动,阿木实验室将追究侵权行为。 diff --git a/algorithm/common_det/cuda/common_det_cuda_impl.cpp b/algorithm/common_det/cuda/common_det_cuda_impl.cpp new file mode 100644 index 0000000..d86f467 --- /dev/null +++ b/algorithm/common_det/cuda/common_det_cuda_impl.cpp @@ -0,0 +1,310 @@ +#include "common_det_cuda_impl.h" +#include +#include + +#define SV_MODEL_DIR "/SpireCV/models/" +#define SV_ROOT_DIR "/SpireCV/" + + +#ifdef WITH_CUDA +#include "yolov7/cuda_utils.h" +#include "yolov7/logging.h" +#include "yolov7/utils.h" +#include "yolov7/preprocess.h" +#include "yolov7/postprocess.h" +#include "yolov7/model.h" +#define TRTCHECK(status) \ + do \ + { \ + auto ret = (status); \ + if (ret != 0) \ + { \ + std::cerr << "Cuda failure: " << ret << std::endl; \ + abort(); \ + } \ + } while (0) + +#define DEVICE 0 // GPU id +#define BATCH_SIZE 1 +#define MAX_IMAGE_INPUT_SIZE_THRESH 3000 * 3000 // ensure it exceed the maximum size in the input images ! +#endif + + +namespace sv { + +using namespace cv; + + +#ifdef WITH_CUDA +using namespace nvinfer1; +static Logger g_nvlogger; +const static int kOutputSize = kMaxNumOutputBbox * sizeof(Detection) / sizeof(float) + 1; +const static int kOutputSize1 = kMaxNumOutputBbox * sizeof(Detection) / sizeof(float) + 1; +const static int kOutputSize2 = 32 * (640 / 4) * (640 / 4); +#endif + + +CommonObjectDetectorCUDAImpl::CommonObjectDetectorCUDAImpl() +{ +#ifdef WITH_CUDA + this->_gpu_buffers[0] = nullptr; + this->_gpu_buffers[1] = nullptr; + this->_gpu_buffers[2] = nullptr; + this->_cpu_output_buffer = nullptr; + this->_cpu_output_buffer1 = nullptr; + this->_cpu_output_buffer2 = nullptr; + this->_context = nullptr; + this->_engine = nullptr; + this->_runtime = nullptr; +#endif +} + + +CommonObjectDetectorCUDAImpl::~CommonObjectDetectorCUDAImpl() +{ +#ifdef WITH_CUDA + // Release stream and buffers + cudaStreamDestroy(_stream); + if (_gpu_buffers[0]) + CUDA_CHECK(cudaFree(_gpu_buffers[0])); + if (_gpu_buffers[1]) + CUDA_CHECK(cudaFree(_gpu_buffers[1])); + if (_gpu_buffers[2]) + CUDA_CHECK(cudaFree(_gpu_buffers[2])); + if (_cpu_output_buffer) + delete[] _cpu_output_buffer; + if (_cpu_output_buffer1) + delete[] _cpu_output_buffer1; + if (_cpu_output_buffer2) + delete[] _cpu_output_buffer2; + cuda_preprocess_destroy(); + // Destroy the engine + if (_context) + _context->destroy(); + if (_engine) + _engine->destroy(); + if (_runtime) + _runtime->destroy(); +#endif +} + + +#ifdef WITH_CUDA +void infer(IExecutionContext& context, cudaStream_t& stream, void** gpu_buffers, float* output, int batchsize) { + context.enqueue(batchsize, gpu_buffers, stream, nullptr); + // context.enqueueV2(gpu_buffers, stream, nullptr); + CUDA_CHECK(cudaMemcpyAsync(output, gpu_buffers[1], batchsize * kOutputSize * sizeof(float), cudaMemcpyDeviceToHost, stream)); + cudaStreamSynchronize(stream); +} +void infer_seg(IExecutionContext& context, cudaStream_t& stream, void **buffers, float* output1, float* output2, int batchSize) { + context.enqueue(batchSize, buffers, stream, nullptr); + // context.enqueueV2(buffers, stream, nullptr); + CUDA_CHECK(cudaMemcpyAsync(output1, buffers[1], batchSize * kOutputSize1 * sizeof(float), cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaMemcpyAsync(output2, buffers[2], batchSize * kOutputSize2 * sizeof(float), cudaMemcpyDeviceToHost, stream)); + cudaStreamSynchronize(stream); +} +void 
CommonObjectDetectorCUDAImpl::_prepare_buffers(int input_h, int input_w) { + assert(this->_engine->getNbBindings() == 2); + // In order to bind the buffers, we need to know the names of the input and output tensors. + // Note that indices are guaranteed to be less than IEngine::getNbBindings() + const int inputIndex = this->_engine->getBindingIndex(kInputTensorName); + const int outputIndex = this->_engine->getBindingIndex(kOutputTensorName); + assert(inputIndex == 0); + assert(outputIndex == 1); + // Create GPU buffers on device + CUDA_CHECK(cudaMalloc((void**)&(this->_gpu_buffers[0]), kBatchSize * 3 * input_h * input_w * sizeof(float))); + CUDA_CHECK(cudaMalloc((void**)&(this->_gpu_buffers[1]), kBatchSize * kOutputSize * sizeof(float))); + + this->_cpu_output_buffer = new float[kBatchSize * kOutputSize]; +} +void CommonObjectDetectorCUDAImpl::_prepare_buffers_seg(int input_h, int input_w) { + assert(this->_engine->getNbBindings() == 3); + // In order to bind the buffers, we need to know the names of the input and output tensors. + // Note that indices are guaranteed to be less than IEngine::getNbBindings() + const int inputIndex = this->_engine->getBindingIndex(kInputTensorName); + const int outputIndex1 = this->_engine->getBindingIndex(kOutputTensorName); + const int outputIndex2 = this->_engine->getBindingIndex("proto"); + assert(inputIndex == 0); + assert(outputIndex1 == 1); + assert(outputIndex2 == 2); + + // Create GPU buffers on device + CUDA_CHECK(cudaMalloc((void**)&(this->_gpu_buffers[0]), kBatchSize * 3 * input_h * input_w * sizeof(float))); + CUDA_CHECK(cudaMalloc((void**)&(this->_gpu_buffers[1]), kBatchSize * kOutputSize1 * sizeof(float))); + CUDA_CHECK(cudaMalloc((void**)&(this->_gpu_buffers[2]), kBatchSize * kOutputSize2 * sizeof(float))); + + // Alloc CPU buffers + this->_cpu_output_buffer1 = new float[kBatchSize * kOutputSize1]; + this->_cpu_output_buffer2 = new float[kBatchSize * kOutputSize2]; +} +void deserialize_engine(std::string& engine_name, IRuntime** runtime, ICudaEngine** engine, IExecutionContext** context) { + std::ifstream file(engine_name, std::ios::binary); + if (!file.good()) { + std::cerr << "read " << engine_name << " error!" 
<< std::endl; + assert(false); + } + size_t size = 0; + file.seekg(0, file.end); + size = file.tellg(); + file.seekg(0, file.beg); + char* serialized_engine = new char[size]; + assert(serialized_engine); + file.read(serialized_engine, size); + file.close(); + + *runtime = createInferRuntime(g_nvlogger); + assert(*runtime); + *engine = (*runtime)->deserializeCudaEngine(serialized_engine, size); + assert(*engine); + *context = (*engine)->createExecutionContext(); + assert(*context); + delete[] serialized_engine; +} +#endif + + +void CommonObjectDetectorCUDAImpl::cudaDetect( + CommonObjectDetectorBase* base_, + cv::Mat img_, + std::vector& boxes_x_, + std::vector& boxes_y_, + std::vector& boxes_w_, + std::vector& boxes_h_, + std::vector& boxes_label_, + std::vector& boxes_score_, + std::vector& boxes_seg_ +) +{ +#ifdef WITH_CUDA + int input_h = base_->getInputH(); + int input_w = base_->getInputW(); + bool with_segmentation = base_->withSegmentation(); + double thrs_conf = base_->getThrsConf(); + double thrs_nms = base_->getThrsNms(); + + std::vector img_batch; + img_batch.push_back(img_); + // Preprocess + cuda_batch_preprocess(img_batch, this->_gpu_buffers[0], input_w, input_h, this->_stream); + + // Run inference + if (with_segmentation) + { + infer_seg(*this->_context, this->_stream, (void**)this->_gpu_buffers, this->_cpu_output_buffer1, this->_cpu_output_buffer2, kBatchSize); + } + else + { + infer(*this->_context, this->_stream, (void**)this->_gpu_buffers, this->_cpu_output_buffer, kBatchSize); + } + + // NMS + std::vector> res_batch; + if (with_segmentation) + { + batch_nms(res_batch, this->_cpu_output_buffer1, img_batch.size(), kOutputSize1, thrs_conf, thrs_nms); + } + else + { + batch_nms(res_batch, this->_cpu_output_buffer, img_batch.size(), kOutputSize, thrs_conf, thrs_nms); + } + + std::vector res = res_batch[0]; + std::vector masks; + if (with_segmentation) + { + masks = process_mask(&(this->_cpu_output_buffer2[0]), kOutputSize2, res, input_h, input_w); + } + + + + for (size_t j = 0; j < res.size(); j++) { + cv::Rect r = get_rect(img_, res[j].bbox, input_h, input_w); + if (r.x < 0) r.x = 0; + if (r.y < 0) r.y = 0; + if (r.x + r.width >= img_.cols) r.width = img_.cols - r.x - 1; + if (r.y + r.height >= img_.rows) r.height = img_.rows - r.y - 1; + if (r.width > 5 && r.height > 5) + { + // cv::rectangle(img_show, r, cv::Scalar(0, 0, 255), 2); + // cv::putText(img_show, vehiclenames[(int)res[j].class_id], cv::Point(r.x, r.y - 1), cv::FONT_HERSHEY_PLAIN, 2.2, cv::Scalar(0, 0, 255), 2); + boxes_x_.push_back(r.x); + boxes_y_.push_back(r.y); + boxes_w_.push_back(r.width); + boxes_h_.push_back(r.height); + + boxes_label_.push_back((int)res[j].class_id); + boxes_score_.push_back(res[j].conf); + + if (with_segmentation) + { + cv::Mat mask_j = masks[j].clone(); + boxes_seg_.push_back(mask_j); + } + } + } + +#endif +} + +bool CommonObjectDetectorCUDAImpl::cudaSetup(CommonObjectDetectorBase* base_) +{ +#ifdef WITH_CUDA + std::string dataset = base_->getDataset(); + int input_h = base_->getInputH(); + int input_w = base_->getInputW(); + bool with_segmentation = base_->withSegmentation(); + double thrs_conf = base_->getThrsConf(); + double thrs_nms = base_->getThrsNms(); + + std::string engine_fn = get_home() + SV_MODEL_DIR + dataset + ".engine"; + if (input_w == 1280) + { + engine_fn = get_home() + SV_MODEL_DIR + dataset + "_HD.engine"; + } + if (with_segmentation) + { + base_->setInputH(640); + base_->setInputW(640); + engine_fn = get_home() + SV_MODEL_DIR + dataset + "_SEG.engine"; + } + 
std::cout << "Load: " << engine_fn << std::endl; + if (!is_file_exist(engine_fn)) + { + throw std::runtime_error("SpireCV (104) Error loading the CommonObject TensorRT model (File Not Exist)"); + } + + deserialize_engine(engine_fn, &this->_runtime, &this->_engine, &this->_context); + CUDA_CHECK(cudaStreamCreate(&this->_stream)); + + // Init CUDA preprocessing + cuda_preprocess_init(kMaxInputImageSize); + + if (with_segmentation) + { + // Prepare cpu and gpu buffers + this->_prepare_buffers_seg(input_h, input_w); + } + else + { + // Prepare cpu and gpu buffers + this->_prepare_buffers(input_h, input_w); + } + return true; +#endif + return false; +} + + + + + + + + + + + + + +} + diff --git a/algorithm/common_det/cuda/common_det_cuda_impl.h b/algorithm/common_det/cuda/common_det_cuda_impl.h new file mode 100644 index 0000000..cb44faf --- /dev/null +++ b/algorithm/common_det/cuda/common_det_cuda_impl.h @@ -0,0 +1,58 @@ +#ifndef __SV_COMMON_DET_CUDA__ +#define __SV_COMMON_DET_CUDA__ + +#include "sv_core.h" +#include +#include +#include +#include +#include + + + +#ifdef WITH_CUDA +#include +#include +#endif + + + +namespace sv { + + +class CommonObjectDetectorCUDAImpl +{ +public: + CommonObjectDetectorCUDAImpl(); + ~CommonObjectDetectorCUDAImpl(); + + bool cudaSetup(CommonObjectDetectorBase* base_); + void cudaDetect( + CommonObjectDetectorBase* base_, + cv::Mat img_, + std::vector& boxes_x_, + std::vector& boxes_y_, + std::vector& boxes_w_, + std::vector& boxes_h_, + std::vector& boxes_label_, + std::vector& boxes_score_, + std::vector& boxes_seg_ + ); + +#ifdef WITH_CUDA + void _prepare_buffers_seg(int input_h, int input_w); + void _prepare_buffers(int input_h, int input_w); + nvinfer1::IExecutionContext* _context; + nvinfer1::IRuntime* _runtime; + nvinfer1::ICudaEngine* _engine; + cudaStream_t _stream; + float* _gpu_buffers[3]; + float* _cpu_output_buffer; + float* _cpu_output_buffer1; + float* _cpu_output_buffer2; +#endif +}; + + +} +#endif diff --git a/algorithm/common_det/cuda/yolov7/calibrator.cpp b/algorithm/common_det/cuda/yolov7/calibrator.cpp new file mode 100644 index 0000000..ed7ce19 --- /dev/null +++ b/algorithm/common_det/cuda/yolov7/calibrator.cpp @@ -0,0 +1,97 @@ +#include "calibrator.h" +#include "cuda_utils.h" +#include "utils.h" + +#include +#include +#include +#include +#include + +static cv::Mat preprocess_img(cv::Mat& img, int input_w, int input_h) { + int w, h, x, y; + float r_w = input_w / (img.cols * 1.0); + float r_h = input_h / (img.rows * 1.0); + if (r_h > r_w) { + w = input_w; + h = r_w * img.rows; + x = 0; + y = (input_h - h) / 2; + } else { + w = r_h * img.cols; + h = input_h; + x = (input_w - w) / 2; + y = 0; + } + cv::Mat re(h, w, CV_8UC3); + cv::resize(img, re, re.size(), 0, 0, cv::INTER_LINEAR); + cv::Mat out(input_h, input_w, CV_8UC3, cv::Scalar(128, 128, 128)); + re.copyTo(out(cv::Rect(x, y, re.cols, re.rows))); + return out; +} + +Int8EntropyCalibrator2::Int8EntropyCalibrator2(int batchsize, int input_w, int input_h, const char* img_dir, const char* calib_table_name, const char* input_blob_name, bool read_cache) + : batchsize_(batchsize), + input_w_(input_w), + input_h_(input_h), + img_idx_(0), + img_dir_(img_dir), + calib_table_name_(calib_table_name), + input_blob_name_(input_blob_name), + read_cache_(read_cache) { + input_count_ = 3 * input_w * input_h * batchsize; + CUDA_CHECK(cudaMalloc(&device_input_, input_count_ * sizeof(float))); + read_files_in_dir(img_dir, img_files_); +} + +Int8EntropyCalibrator2::~Int8EntropyCalibrator2() { + 
CUDA_CHECK(cudaFree(device_input_)); +} + +int Int8EntropyCalibrator2::getBatchSize() const TRT_NOEXCEPT { + return batchsize_; +} + +bool Int8EntropyCalibrator2::getBatch(void* bindings[], const char* names[], int nbBindings) TRT_NOEXCEPT { + if (img_idx_ + batchsize_ > (int)img_files_.size()) { + return false; + } + + std::vector<cv::Mat> input_imgs_; + for (int i = img_idx_; i < img_idx_ + batchsize_; i++) { + std::cout << img_files_[i] << " " << i << std::endl; + cv::Mat temp = cv::imread(img_dir_ + img_files_[i]); + if (temp.empty()) { + std::cerr << "Fatal error: image cannot open!" << std::endl; + return false; + } + cv::Mat pr_img = preprocess_img(temp, input_w_, input_h_); + input_imgs_.push_back(pr_img); + } + img_idx_ += batchsize_; + cv::Mat blob = cv::dnn::blobFromImages(input_imgs_, 1.0 / 255.0, cv::Size(input_w_, input_h_), cv::Scalar(0, 0, 0), true, false); + + CUDA_CHECK(cudaMemcpy(device_input_, blob.ptr<float>(0), input_count_ * sizeof(float), cudaMemcpyHostToDevice)); + assert(!strcmp(names[0], input_blob_name_)); + bindings[0] = device_input_; + return true; +} + +const void* Int8EntropyCalibrator2::readCalibrationCache(size_t& length) TRT_NOEXCEPT { + std::cout << "reading calib cache: " << calib_table_name_ << std::endl; + calib_cache_.clear(); + std::ifstream input(calib_table_name_, std::ios::binary); + input >> std::noskipws; + if (read_cache_ && input.good()) { + std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(), std::back_inserter(calib_cache_)); + } + length = calib_cache_.size(); + return length ? calib_cache_.data() : nullptr; +} + +void Int8EntropyCalibrator2::writeCalibrationCache(const void* cache, size_t length) TRT_NOEXCEPT { + std::cout << "writing calib cache: " << calib_table_name_ << " size: " << length << std::endl; + std::ofstream output(calib_table_name_, std::ios::binary); + output.write(reinterpret_cast<const char*>(cache), length); +} + diff --git a/algorithm/common_det/cuda/yolov7/calibrator.h b/algorithm/common_det/cuda/yolov7/calibrator.h new file mode 100644 index 0000000..ed77b5f --- /dev/null +++ b/algorithm/common_det/cuda/yolov7/calibrator.h @@ -0,0 +1,36 @@ +#pragma once + +#include "macros.h" +#include +#include + +//! \class Int8EntropyCalibrator2 +//! +//! \brief Implements Entropy calibrator 2. +//! CalibrationAlgoType is kENTROPY_CALIBRATION_2. +//!
+class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 { + public: + Int8EntropyCalibrator2(int batchsize, int input_w, int input_h, const char* img_dir, const char* calib_table_name, const char* input_blob_name, bool read_cache = true); + + virtual ~Int8EntropyCalibrator2(); + int getBatchSize() const TRT_NOEXCEPT override; + bool getBatch(void* bindings[], const char* names[], int nbBindings) TRT_NOEXCEPT override; + const void* readCalibrationCache(size_t& length) TRT_NOEXCEPT override; + void writeCalibrationCache(const void* cache, size_t length) TRT_NOEXCEPT override; + + private: + int batchsize_; + int input_w_; + int input_h_; + int img_idx_; + std::string img_dir_; + std::vector<std::string> img_files_; + size_t input_count_; + std::string calib_table_name_; + const char* input_blob_name_; + bool read_cache_; + void* device_input_; + std::vector<char> calib_cache_; +}; + diff --git a/algorithm/common_det/cuda/yolov7/config.h b/algorithm/common_det/cuda/yolov7/config.h new file mode 100644 index 0000000..a6e96f0 --- /dev/null +++ b/algorithm/common_det/cuda/yolov7/config.h @@ -0,0 +1,55 @@ +#pragma once + +/* -------------------------------------------------------- + * These configs are related to the tensorrt model; if these are changed, + * please re-compile and re-serialize the tensorrt model. + * --------------------------------------------------------*/ + +// For INT8, you need to prepare the calibration dataset, please refer to +// https://github.com/wang-xinyu/tensorrtx/tree/master/yolov5#int8-quantization +#define USE_FP16 // set USE_INT8 or USE_FP16 or USE_FP32 + +// These are used to define input/output tensor names, +// you can set them to whatever you want. +const static char* kInputTensorName = "data"; +const static char* kOutputTensorName = "prob"; + +// Detection and segmentation models' number of classes +// constexpr static int kNumClass = 80; + +// Classification model's number of classes +constexpr static int kClsNumClass = 1000; + +constexpr static int kBatchSize = 1; + +// Yolo's input width and height must be divisible by 32 +// constexpr static int kInputH = 640; +// constexpr static int kInputW = 640; + +// Classification model's input shape +constexpr static int kClsInputH = 224; +constexpr static int kClsInputW = 224; + +// Maximum number of output bounding boxes from the yololayer plugin, +// i.e. the maximum number of output bounding boxes before NMS. +constexpr static int kMaxNumOutputBbox = 1000; + +constexpr static int kNumAnchor = 3; + +// Bboxes whose confidence is lower than kIgnoreThresh will be ignored by the yololayer plugin. +constexpr static float kIgnoreThresh = 0.1f; + +/* -------------------------------------------------------- + * These configs are NOT related to the tensorrt model; if these are changed, + * please re-compile, but there is no need to re-serialize the tensorrt model.
+ * --------------------------------------------------------*/ + +// NMS overlapping thresh and final detection confidence thresh +const static float kNmsThresh = 0.45f; +const static float kConfThresh = 0.5f; + +const static int kGpuId = 0; + +// If your image size is larger than 4096 * 3112, please increase this value +const static int kMaxInputImageSize = 4096 * 3112; + diff --git a/algorithm/common_det/cuda/yolov7/cuda_utils.h b/algorithm/common_det/cuda/yolov7/cuda_utils.h new file mode 100644 index 0000000..8fbd319 --- /dev/null +++ b/algorithm/common_det/cuda/yolov7/cuda_utils.h @@ -0,0 +1,18 @@ +#ifndef TRTX_CUDA_UTILS_H_ +#define TRTX_CUDA_UTILS_H_ + +#include + +#ifndef CUDA_CHECK +#define CUDA_CHECK(callstr)\ + {\ + cudaError_t error_code = callstr;\ + if (error_code != cudaSuccess) {\ + std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__;\ + assert(0);\ + }\ + } +#endif // CUDA_CHECK + +#endif // TRTX_CUDA_UTILS_H_ + diff --git a/algorithm/common_det/cuda/yolov7/logging.h b/algorithm/common_det/cuda/yolov7/logging.h new file mode 100644 index 0000000..6b79a8b --- /dev/null +++ b/algorithm/common_det/cuda/yolov7/logging.h @@ -0,0 +1,504 @@ +/* + * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef TENSORRT_LOGGING_H +#define TENSORRT_LOGGING_H + +#include "NvInferRuntimeCommon.h" +#include +#include +#include +#include +#include +#include +#include +#include "macros.h" + +using Severity = nvinfer1::ILogger::Severity; + +class LogStreamConsumerBuffer : public std::stringbuf +{ +public: + LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog) + : mOutput(stream) + , mPrefix(prefix) + , mShouldLog(shouldLog) + { + } + + LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other) + : mOutput(other.mOutput) + { + } + + ~LogStreamConsumerBuffer() + { + // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence + // std::streambuf::pptr() gives a pointer to the current position of the output sequence + // if the pointer to the beginning is not equal to the pointer to the current position, + // call putOutput() to log the output to the stream + if (pbase() != pptr()) + { + putOutput(); + } + } + + // synchronizes the stream buffer and returns 0 on success + // synchronizing the stream buffer consists of inserting the buffer contents into the stream, + // resetting the buffer and flushing the stream + virtual int sync() + { + putOutput(); + return 0; + } + + void putOutput() + { + if (mShouldLog) + { + // prepend timestamp + std::time_t timestamp = std::time(nullptr); + tm* tm_local = std::localtime(×tamp); + std::cout << "["; + std::cout << std::setw(2) << std::setfill('0') << 1 + tm_local->tm_mon << "/"; + std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_mday << "/"; + std::cout << std::setw(4) << std::setfill('0') << 1900 + tm_local->tm_year << "-"; + std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_hour << ":"; + std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_min << ":"; + std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_sec << "] "; + // std::stringbuf::str() gets the string contents of the buffer + // insert the buffer contents pre-appended by the appropriate prefix into the stream + mOutput << mPrefix << str(); + // set the buffer to empty + str(""); + // flush the stream + mOutput.flush(); + } + } + + void setShouldLog(bool shouldLog) + { + mShouldLog = shouldLog; + } + +private: + std::ostream& mOutput; + std::string mPrefix; + bool mShouldLog; +}; + +//! +//! \class LogStreamConsumerBase +//! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer +//! +class LogStreamConsumerBase +{ +public: + LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog) + : mBuffer(stream, prefix, shouldLog) + { + } + +protected: + LogStreamConsumerBuffer mBuffer; +}; + +//! +//! \class LogStreamConsumer +//! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages. +//! Order of base classes is LogStreamConsumerBase and then std::ostream. +//! This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field +//! in LogStreamConsumer and then the address of the buffer is passed to std::ostream. +//! This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream. +//! Please do not change the order of the parent classes. +//! +class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream +{ +public: + //! \brief Creates a LogStreamConsumer which logs messages with level severity. + //! 
Reportable severity determines if the messages are severe enough to be logged. + LogStreamConsumer(Severity reportableSeverity, Severity severity) + : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity) + , std::ostream(&mBuffer) // links the stream buffer with the stream + , mShouldLog(severity <= reportableSeverity) + , mSeverity(severity) + { + } + + LogStreamConsumer(LogStreamConsumer&& other) + : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog) + , std::ostream(&mBuffer) // links the stream buffer with the stream + , mShouldLog(other.mShouldLog) + , mSeverity(other.mSeverity) + { + } + + void setReportableSeverity(Severity reportableSeverity) + { + mShouldLog = mSeverity <= reportableSeverity; + mBuffer.setShouldLog(mShouldLog); + } + +private: + static std::ostream& severityOstream(Severity severity) + { + return severity >= Severity::kINFO ? std::cout : std::cerr; + } + + static std::string severityPrefix(Severity severity) + { + switch (severity) + { + case Severity::kINTERNAL_ERROR: return "[F] "; + case Severity::kERROR: return "[E] "; + case Severity::kWARNING: return "[W] "; + case Severity::kINFO: return "[I] "; + case Severity::kVERBOSE: return "[V] "; + default: assert(0); return ""; + } + } + + bool mShouldLog; + Severity mSeverity; +}; + +//! \class Logger +//! +//! \brief Class which manages logging of TensorRT tools and samples +//! +//! \details This class provides a common interface for TensorRT tools and samples to log information to the console, +//! and supports logging two types of messages: +//! +//! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal) +//! - Test pass/fail messages +//! +//! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is +//! that the logic for controlling the verbosity and formatting of sample output is centralized in one location. +//! +//! In the future, this class could be extended to support dumping test results to a file in some standard format +//! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run). +//! +//! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger +//! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT +//! library and messages coming from the sample. +//! +//! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the +//! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger +//! object. + +class Logger : public nvinfer1::ILogger +{ +public: + Logger(Severity severity = Severity::kWARNING) + : mReportableSeverity(severity) + { + } + + //! + //! \enum TestResult + //! \brief Represents the state of a given test + //! + enum class TestResult + { + kRUNNING, //!< The test is running + kPASSED, //!< The test passed + kFAILED, //!< The test failed + kWAIVED //!< The test was waived + }; + + //! + //! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger + //! \return The nvinfer1::ILogger associated with this Logger + //! + //! TODO Once all samples are updated to use this method to register the logger with TensorRT, + //! we can eliminate the inheritance of Logger from ILogger + //! 
+ nvinfer1::ILogger& getTRTLogger() + { + return *this; + } + + //! + //! \brief Implementation of the nvinfer1::ILogger::log() virtual method + //! + //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the + //! inheritance from nvinfer1::ILogger + //! + void log(Severity severity, const char* msg) TRT_NOEXCEPT override + { + LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl; + } + + //! + //! \brief Method for controlling the verbosity of logging output + //! + //! \param severity The logger will only emit messages that have severity of this level or higher. + //! + void setReportableSeverity(Severity severity) + { + mReportableSeverity = severity; + } + + //! + //! \brief Opaque handle that holds logging information for a particular test + //! + //! This object is an opaque handle to information used by the Logger to print test results. + //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used + //! with Logger::reportTest{Start,End}(). + //! + class TestAtom + { + public: + TestAtom(TestAtom&&) = default; + + private: + friend class Logger; + + TestAtom(bool started, const std::string& name, const std::string& cmdline) + : mStarted(started) + , mName(name) + , mCmdline(cmdline) + { + } + + bool mStarted; + std::string mName; + std::string mCmdline; + }; + + //! + //! \brief Define a test for logging + //! + //! \param[in] name The name of the test. This should be a string starting with + //! "TensorRT" and containing dot-separated strings containing + //! the characters [A-Za-z0-9_]. + //! For example, "TensorRT.sample_googlenet" + //! \param[in] cmdline The command line used to reproduce the test + // + //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). + //! + static TestAtom defineTest(const std::string& name, const std::string& cmdline) + { + return TestAtom(false, name, cmdline); + } + + //! + //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments + //! as input + //! + //! \param[in] name The name of the test + //! \param[in] argc The number of command-line arguments + //! \param[in] argv The array of command-line arguments (given as C strings) + //! + //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). + static TestAtom defineTest(const std::string& name, int argc, char const* const* argv) + { + auto cmdline = genCmdlineString(argc, argv); + return defineTest(name, cmdline); + } + + //! + //! \brief Report that a test has started. + //! + //! \pre reportTestStart() has not been called yet for the given testAtom + //! + //! \param[in] testAtom The handle to the test that has started + //! + static void reportTestStart(TestAtom& testAtom) + { + reportTestResult(testAtom, TestResult::kRUNNING); + assert(!testAtom.mStarted); + testAtom.mStarted = true; + } + + //! + //! \brief Report that a test has ended. + //! + //! \pre reportTestStart() has been called for the given testAtom + //! + //! \param[in] testAtom The handle to the test that has ended + //! \param[in] result The result of the test. Should be one of TestResult::kPASSED, + //! TestResult::kFAILED, TestResult::kWAIVED + //! 
+ static void reportTestEnd(const TestAtom& testAtom, TestResult result) + { + assert(result != TestResult::kRUNNING); + assert(testAtom.mStarted); + reportTestResult(testAtom, result); + } + + static int reportPass(const TestAtom& testAtom) + { + reportTestEnd(testAtom, TestResult::kPASSED); + return EXIT_SUCCESS; + } + + static int reportFail(const TestAtom& testAtom) + { + reportTestEnd(testAtom, TestResult::kFAILED); + return EXIT_FAILURE; + } + + static int reportWaive(const TestAtom& testAtom) + { + reportTestEnd(testAtom, TestResult::kWAIVED); + return EXIT_SUCCESS; + } + + static int reportTest(const TestAtom& testAtom, bool pass) + { + return pass ? reportPass(testAtom) : reportFail(testAtom); + } + + Severity getReportableSeverity() const + { + return mReportableSeverity; + } + +private: + //! + //! \brief returns an appropriate string for prefixing a log message with the given severity + //! + static const char* severityPrefix(Severity severity) + { + switch (severity) + { + case Severity::kINTERNAL_ERROR: return "[F] "; + case Severity::kERROR: return "[E] "; + case Severity::kWARNING: return "[W] "; + case Severity::kINFO: return "[I] "; + case Severity::kVERBOSE: return "[V] "; + default: assert(0); return ""; + } + } + + //! + //! \brief returns an appropriate string for prefixing a test result message with the given result + //! + static const char* testResultString(TestResult result) + { + switch (result) + { + case TestResult::kRUNNING: return "RUNNING"; + case TestResult::kPASSED: return "PASSED"; + case TestResult::kFAILED: return "FAILED"; + case TestResult::kWAIVED: return "WAIVED"; + default: assert(0); return ""; + } + } + + //! + //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity + //! + static std::ostream& severityOstream(Severity severity) + { + return severity >= Severity::kINFO ? std::cout : std::cerr; + } + + //! + //! \brief method that implements logging test results + //! + static void reportTestResult(const TestAtom& testAtom, TestResult result) + { + severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # " + << testAtom.mCmdline << std::endl; + } + + //! + //! \brief generate a command line string from the given (argc, argv) values + //! + static std::string genCmdlineString(int argc, char const* const* argv) + { + std::stringstream ss; + for (int i = 0; i < argc; i++) + { + if (i > 0) + ss << " "; + ss << argv[i]; + } + return ss.str(); + } + + Severity mReportableSeverity; +}; + +namespace +{ + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE +//! +//! Example usage: +//! +//! LOG_VERBOSE(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) +{ + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO +//! +//! Example usage: +//! +//! LOG_INFO(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_INFO(const Logger& logger) +{ + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING +//! +//! Example usage: +//! +//! LOG_WARN(logger) << "hello world" << std::endl; +//! 
+inline LogStreamConsumer LOG_WARN(const Logger& logger) +{ + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR +//! +//! Example usage: +//! +//! LOG_ERROR(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_ERROR(const Logger& logger) +{ + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR +// ("fatal" severity) +//! +//! Example usage: +//! +//! LOG_FATAL(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_FATAL(const Logger& logger) +{ + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR); +} + +} // anonymous namespace + +#endif // TENSORRT_LOGGING_H diff --git a/algorithm/common_det/cuda/yolov7/macros.h b/algorithm/common_det/cuda/yolov7/macros.h new file mode 100644 index 0000000..17339a2 --- /dev/null +++ b/algorithm/common_det/cuda/yolov7/macros.h @@ -0,0 +1,29 @@ +#ifndef __MACROS_H +#define __MACROS_H + +#include + +#ifdef API_EXPORTS +#if defined(_MSC_VER) +#define API __declspec(dllexport) +#else +#define API __attribute__((visibility("default"))) +#endif +#else + +#if defined(_MSC_VER) +#define API __declspec(dllimport) +#else +#define API +#endif +#endif // API_EXPORTS + +#if NV_TENSORRT_MAJOR >= 8 +#define TRT_NOEXCEPT noexcept +#define TRT_CONST_ENQUEUE const +#else +#define TRT_NOEXCEPT +#define TRT_CONST_ENQUEUE +#endif + +#endif // __MACROS_H diff --git a/algorithm/common_det/cuda/yolov7/model.cpp b/algorithm/common_det/cuda/yolov7/model.cpp new file mode 100644 index 0000000..467cd47 --- /dev/null +++ b/algorithm/common_det/cuda/yolov7/model.cpp @@ -0,0 +1,628 @@ +#include "model.h" +#include "calibrator.h" +#include "config.h" +#include "yololayer.h" + +#include +#include +#include +#include +#include +#include + +using namespace nvinfer1; + +// TensorRT weight files have a simple space delimited format: +// [type] [size] +static std::map loadWeights(const std::string file) { + std::cout << "Loading weights: " << file << std::endl; + std::map weightMap; + + // Open weights file + std::ifstream input(file); + assert(input.is_open() && "Unable to load weight file. 
please check if the .wts file path is right!!!!!!"); + + // Read number of weight blobs + int32_t count; + input >> count; + assert(count > 0 && "Invalid weight map file."); + + while (count--) { + Weights wt{ DataType::kFLOAT, nullptr, 0 }; + uint32_t size; + + // Read name and type of blob + std::string name; + input >> name >> std::dec >> size; + wt.type = DataType::kFLOAT; + + // Load blob + uint32_t* val = reinterpret_cast(malloc(sizeof(val) * size)); + for (uint32_t x = 0, y = size; x < y; ++x) { + input >> std::hex >> val[x]; + } + wt.values = val; + + wt.count = size; + weightMap[name] = wt; + } + + return weightMap; +} + +static int get_width(int x, float gw, int divisor = 8) { + return int(ceil((x * gw) / divisor)) * divisor; +} + +static int get_depth(int x, float gd) { + if (x == 1) return 1; + int r = round(x * gd); + if (x * gd - int(x * gd) == 0.5 && (int(x * gd) % 2) == 0) { + --r; + } + return std::max(r, 1); +} + +static IScaleLayer* addBatchNorm2d(INetworkDefinition *network, std::map& weightMap, ITensor& input, std::string lname, float eps) { + float* gamma = (float*)weightMap[lname + ".weight"].values; + float* beta = (float*)weightMap[lname + ".bias"].values; + float* mean = (float*)weightMap[lname + ".running_mean"].values; + float* var = (float*)weightMap[lname + ".running_var"].values; + int len = weightMap[lname + ".running_var"].count; + + float* scval = reinterpret_cast(malloc(sizeof(float) * len)); + for (int i = 0; i < len; i++) { + scval[i] = gamma[i] / sqrt(var[i] + eps); + } + Weights scale{ DataType::kFLOAT, scval, len }; + + float* shval = reinterpret_cast(malloc(sizeof(float) * len)); + for (int i = 0; i < len; i++) { + shval[i] = beta[i] - mean[i] * gamma[i] / sqrt(var[i] + eps); + } + Weights shift{ DataType::kFLOAT, shval, len }; + + float* pval = reinterpret_cast(malloc(sizeof(float) * len)); + for (int i = 0; i < len; i++) { + pval[i] = 1.0; + } + Weights power{ DataType::kFLOAT, pval, len }; + + weightMap[lname + ".scale"] = scale; + weightMap[lname + ".shift"] = shift; + weightMap[lname + ".power"] = power; + IScaleLayer* scale_1 = network->addScale(input, ScaleMode::kCHANNEL, shift, scale, power); + assert(scale_1); + return scale_1; +} + +static ILayer* convBlock(INetworkDefinition *network, std::map& weightMap, ITensor& input, int outch, int ksize, int s, int g, std::string lname) { + Weights emptywts{ DataType::kFLOAT, nullptr, 0 }; + int p = ksize / 3; + IConvolutionLayer* conv1 = network->addConvolutionNd(input, outch, DimsHW{ ksize, ksize }, weightMap[lname + ".conv.weight"], emptywts); + assert(conv1); + conv1->setStrideNd(DimsHW{ s, s }); + conv1->setPaddingNd(DimsHW{ p, p }); + conv1->setNbGroups(g); + conv1->setName((lname + ".conv").c_str()); + IScaleLayer* bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), lname + ".bn", 1e-3); + + // silu = x * sigmoid + auto sig = network->addActivation(*bn1->getOutput(0), ActivationType::kSIGMOID); + assert(sig); + auto ew = network->addElementWise(*bn1->getOutput(0), *sig->getOutput(0), ElementWiseOperation::kPROD); + assert(ew); + return ew; +} + +static ILayer* focus(INetworkDefinition *network, std::map& weightMap, ITensor& input, int inch, int outch, int ksize, std::string lname, int input_h, int input_w) { + ISliceLayer* s1 = network->addSlice(input, Dims3{ 0, 0, 0 }, Dims3{ inch, input_h / 2, input_w / 2 }, Dims3{ 1, 2, 2 }); + ISliceLayer* s2 = network->addSlice(input, Dims3{ 0, 1, 0 }, Dims3{ inch, input_h / 2, input_w / 2 }, Dims3{ 1, 2, 2 }); + ISliceLayer* s3 = 
network->addSlice(input, Dims3{ 0, 0, 1 }, Dims3{ inch, input_h / 2, input_w / 2 }, Dims3{ 1, 2, 2 }); + ISliceLayer* s4 = network->addSlice(input, Dims3{ 0, 1, 1 }, Dims3{ inch, input_h / 2, input_w / 2 }, Dims3{ 1, 2, 2 }); + ITensor* inputTensors[] = { s1->getOutput(0), s2->getOutput(0), s3->getOutput(0), s4->getOutput(0) }; + auto cat = network->addConcatenation(inputTensors, 4); + auto conv = convBlock(network, weightMap, *cat->getOutput(0), outch, ksize, 1, 1, lname + ".conv"); + return conv; +} + +static ILayer* bottleneck(INetworkDefinition *network, std::map& weightMap, ITensor& input, int c1, int c2, bool shortcut, int g, float e, std::string lname) { + auto cv1 = convBlock(network, weightMap, input, (int)((float)c2 * e), 1, 1, 1, lname + ".cv1"); + auto cv2 = convBlock(network, weightMap, *cv1->getOutput(0), c2, 3, 1, g, lname + ".cv2"); + if (shortcut && c1 == c2) { + auto ew = network->addElementWise(input, *cv2->getOutput(0), ElementWiseOperation::kSUM); + return ew; + } + return cv2; +} + +static ILayer* bottleneckCSP(INetworkDefinition *network, std::map& weightMap, ITensor& input, int c1, int c2, int n, bool shortcut, int g, float e, std::string lname) { + Weights emptywts{ DataType::kFLOAT, nullptr, 0 }; + int c_ = (int)((float)c2 * e); + auto cv1 = convBlock(network, weightMap, input, c_, 1, 1, 1, lname + ".cv1"); + auto cv2 = network->addConvolutionNd(input, c_, DimsHW{ 1, 1 }, weightMap[lname + ".cv2.weight"], emptywts); + ITensor* y1 = cv1->getOutput(0); + for (int i = 0; i < n; i++) { + auto b = bottleneck(network, weightMap, *y1, c_, c_, shortcut, g, 1.0, lname + ".m." + std::to_string(i)); + y1 = b->getOutput(0); + } + auto cv3 = network->addConvolutionNd(*y1, c_, DimsHW{ 1, 1 }, weightMap[lname + ".cv3.weight"], emptywts); + + ITensor* inputTensors[] = { cv3->getOutput(0), cv2->getOutput(0) }; + auto cat = network->addConcatenation(inputTensors, 2); + + IScaleLayer* bn = addBatchNorm2d(network, weightMap, *cat->getOutput(0), lname + ".bn", 1e-4); + auto lr = network->addActivation(*bn->getOutput(0), ActivationType::kLEAKY_RELU); + lr->setAlpha(0.1); + + auto cv4 = convBlock(network, weightMap, *lr->getOutput(0), c2, 1, 1, 1, lname + ".cv4"); + return cv4; +} + +static ILayer* C3(INetworkDefinition *network, std::map& weightMap, ITensor& input, int c1, int c2, int n, bool shortcut, int g, float e, std::string lname) { + int c_ = (int)((float)c2 * e); + auto cv1 = convBlock(network, weightMap, input, c_, 1, 1, 1, lname + ".cv1"); + auto cv2 = convBlock(network, weightMap, input, c_, 1, 1, 1, lname + ".cv2"); + ITensor *y1 = cv1->getOutput(0); + for (int i = 0; i < n; i++) { + auto b = bottleneck(network, weightMap, *y1, c_, c_, shortcut, g, 1.0, lname + ".m." 
+ std::to_string(i)); + y1 = b->getOutput(0); + } + + ITensor* inputTensors[] = { y1, cv2->getOutput(0) }; + auto cat = network->addConcatenation(inputTensors, 2); + + auto cv3 = convBlock(network, weightMap, *cat->getOutput(0), c2, 1, 1, 1, lname + ".cv3"); + return cv3; +} + +static ILayer* SPP(INetworkDefinition *network, std::map& weightMap, ITensor& input, int c1, int c2, int k1, int k2, int k3, std::string lname) { + int c_ = c1 / 2; + auto cv1 = convBlock(network, weightMap, input, c_, 1, 1, 1, lname + ".cv1"); + + auto pool1 = network->addPoolingNd(*cv1->getOutput(0), PoolingType::kMAX, DimsHW{ k1, k1 }); + pool1->setPaddingNd(DimsHW{ k1 / 2, k1 / 2 }); + pool1->setStrideNd(DimsHW{ 1, 1 }); + auto pool2 = network->addPoolingNd(*cv1->getOutput(0), PoolingType::kMAX, DimsHW{ k2, k2 }); + pool2->setPaddingNd(DimsHW{ k2 / 2, k2 / 2 }); + pool2->setStrideNd(DimsHW{ 1, 1 }); + auto pool3 = network->addPoolingNd(*cv1->getOutput(0), PoolingType::kMAX, DimsHW{ k3, k3 }); + pool3->setPaddingNd(DimsHW{ k3 / 2, k3 / 2 }); + pool3->setStrideNd(DimsHW{ 1, 1 }); + + ITensor* inputTensors[] = { cv1->getOutput(0), pool1->getOutput(0), pool2->getOutput(0), pool3->getOutput(0) }; + auto cat = network->addConcatenation(inputTensors, 4); + + auto cv2 = convBlock(network, weightMap, *cat->getOutput(0), c2, 1, 1, 1, lname + ".cv2"); + return cv2; +} + +static ILayer* SPPF(INetworkDefinition *network, std::map& weightMap, ITensor& input, int c1, int c2, int k, std::string lname) { + int c_ = c1 / 2; + auto cv1 = convBlock(network, weightMap, input, c_, 1, 1, 1, lname + ".cv1"); + + auto pool1 = network->addPoolingNd(*cv1->getOutput(0), PoolingType::kMAX, DimsHW{ k, k }); + pool1->setPaddingNd(DimsHW{ k / 2, k / 2 }); + pool1->setStrideNd(DimsHW{ 1, 1 }); + auto pool2 = network->addPoolingNd(*pool1->getOutput(0), PoolingType::kMAX, DimsHW{ k, k }); + pool2->setPaddingNd(DimsHW{ k / 2, k / 2 }); + pool2->setStrideNd(DimsHW{ 1, 1 }); + auto pool3 = network->addPoolingNd(*pool2->getOutput(0), PoolingType::kMAX, DimsHW{ k, k }); + pool3->setPaddingNd(DimsHW{ k / 2, k / 2 }); + pool3->setStrideNd(DimsHW{ 1, 1 }); + ITensor* inputTensors[] = { cv1->getOutput(0), pool1->getOutput(0), pool2->getOutput(0), pool3->getOutput(0) }; + auto cat = network->addConcatenation(inputTensors, 4); + auto cv2 = convBlock(network, weightMap, *cat->getOutput(0), c2, 1, 1, 1, lname + ".cv2"); + return cv2; +} + +static ILayer* Proto(INetworkDefinition* network, std::map& weightMap, ITensor& input, int c_, int c2, std::string lname) { + auto cv1 = convBlock(network, weightMap, input, c_, 3, 1, 1, lname + ".cv1"); + + auto upsample = network->addResize(*cv1->getOutput(0)); + assert(upsample); + upsample->setResizeMode(ResizeMode::kNEAREST); + const float scales[] = {1, 2, 2}; + upsample->setScales(scales, 3); + + auto cv2 = convBlock(network, weightMap, *upsample->getOutput(0), c_, 3, 1, 1, lname + ".cv2"); + auto cv3 = convBlock(network, weightMap, *cv2->getOutput(0), c2, 1, 1, 1, lname + ".cv3"); + assert(cv3); + return cv3; +} + +static std::vector> getAnchors(std::map& weightMap, std::string lname) { + std::vector> anchors; + Weights wts = weightMap[lname + ".anchor_grid"]; + int anchor_len = kNumAnchor * 2; + for (int i = 0; i < wts.count / anchor_len; i++) { + auto *p = (const float*)wts.values + i * anchor_len; + std::vector anchor(p, p + anchor_len); + anchors.push_back(anchor); + } + return anchors; +} + +static IPluginV2Layer* addYoLoLayer(INetworkDefinition *network, std::map& weightMap, std::string lname, std::vector 
dets, int input_h, int input_w, int n_classes, bool is_segmentation = false) { + auto creator = getPluginRegistry()->getPluginCreator("YoloLayer_TRT", "1"); + auto anchors = getAnchors(weightMap, lname); + PluginField plugin_fields[2]; + int netinfo[5] = {n_classes, input_w, input_h, kMaxNumOutputBbox, (int)is_segmentation}; + plugin_fields[0].data = netinfo; + plugin_fields[0].length = 5; + plugin_fields[0].name = "netinfo"; + plugin_fields[0].type = PluginFieldType::kFLOAT32; + + //load strides from Detect layer + assert(weightMap.find(lname + ".strides") != weightMap.end() && "Not found `strides`, please check gen_wts.py!!!"); + Weights strides = weightMap[lname + ".strides"]; + auto *p = (const float*)(strides.values); + std::vector scales(p, p + strides.count); + + std::vector kernels; + for (size_t i = 0; i < anchors.size(); i++) { + YoloKernel kernel; + kernel.width = input_w / scales[i]; + kernel.height = input_h / scales[i]; + memcpy(kernel.anchors, &anchors[i][0], anchors[i].size() * sizeof(float)); + kernels.push_back(kernel); + } + plugin_fields[1].data = &kernels[0]; + plugin_fields[1].length = kernels.size(); + plugin_fields[1].name = "kernels"; + plugin_fields[1].type = PluginFieldType::kFLOAT32; + PluginFieldCollection plugin_data; + plugin_data.nbFields = 2; + plugin_data.fields = plugin_fields; + IPluginV2 *plugin_obj = creator->createPlugin("yololayer", &plugin_data); + std::vector input_tensors; + for (auto det: dets) { + input_tensors.push_back(det->getOutput(0)); + } + auto yolo = network->addPluginV2(&input_tensors[0], input_tensors.size(), *plugin_obj); + return yolo; +} + +ICudaEngine* build_det_engine(unsigned int maxBatchSize, IBuilder* builder, IBuilderConfig* config, DataType dt, float& gd, float& gw, std::string& wts_name, int input_h, int input_w, int n_classes) { + INetworkDefinition* network = builder->createNetworkV2(0U); + + // Create input tensor of shape {3, input_h, input_w} + ITensor* data = network->addInput(kInputTensorName, dt, Dims3{ 3, input_h, input_w }); + assert(data); + std::map weightMap = loadWeights(wts_name); + + // Backbone + auto conv0 = convBlock(network, weightMap, *data, get_width(64, gw), 6, 2, 1, "model.0"); + assert(conv0); + auto conv1 = convBlock(network, weightMap, *conv0->getOutput(0), get_width(128, gw), 3, 2, 1, "model.1"); + auto bottleneck_CSP2 = C3(network, weightMap, *conv1->getOutput(0), get_width(128, gw), get_width(128, gw), get_depth(3, gd), true, 1, 0.5, "model.2"); + auto conv3 = convBlock(network, weightMap, *bottleneck_CSP2->getOutput(0), get_width(256, gw), 3, 2, 1, "model.3"); + auto bottleneck_csp4 = C3(network, weightMap, *conv3->getOutput(0), get_width(256, gw), get_width(256, gw), get_depth(6, gd), true, 1, 0.5, "model.4"); + auto conv5 = convBlock(network, weightMap, *bottleneck_csp4->getOutput(0), get_width(512, gw), 3, 2, 1, "model.5"); + auto bottleneck_csp6 = C3(network, weightMap, *conv5->getOutput(0), get_width(512, gw), get_width(512, gw), get_depth(9, gd), true, 1, 0.5, "model.6"); + auto conv7 = convBlock(network, weightMap, *bottleneck_csp6->getOutput(0), get_width(1024, gw), 3, 2, 1, "model.7"); + auto bottleneck_csp8 = C3(network, weightMap, *conv7->getOutput(0), get_width(1024, gw), get_width(1024, gw), get_depth(3, gd), true, 1, 0.5, "model.8"); + auto spp9 = SPPF(network, weightMap, *bottleneck_csp8->getOutput(0), get_width(1024, gw), get_width(1024, gw), 5, "model.9"); + + // Head + auto conv10 = convBlock(network, weightMap, *spp9->getOutput(0), get_width(512, gw), 1, 1, 1, "model.10"); + 
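  // [Editor's note - illustrative aside, not part of the original patch]
  // gd/gw are the YOLOv5-style depth and width multiples applied through the
  // get_depth()/get_width() calls above. Assuming, for example, the common
  // "s"-variant values gd = 0.33 and gw = 0.50:
  //   get_width(1024, 0.50) = int(ceil(1024 * 0.50 / 8)) * 8 = 512 channels
  //   get_depth(9,    0.33) = max(round(9 * 0.33), 1)        = 3 repeats
  // so "model.9" (SPPF) is built with 512 channels and the deepest backbone
  // C3 stage shrinks from 9 to 3 bottlenecks.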
+ auto upsample11 = network->addResize(*conv10->getOutput(0)); + assert(upsample11); + upsample11->setResizeMode(ResizeMode::kNEAREST); + upsample11->setOutputDimensions(bottleneck_csp6->getOutput(0)->getDimensions()); + + ITensor* inputTensors12[] = { upsample11->getOutput(0), bottleneck_csp6->getOutput(0) }; + auto cat12 = network->addConcatenation(inputTensors12, 2); + auto bottleneck_csp13 = C3(network, weightMap, *cat12->getOutput(0), get_width(1024, gw), get_width(512, gw), get_depth(3, gd), false, 1, 0.5, "model.13"); + auto conv14 = convBlock(network, weightMap, *bottleneck_csp13->getOutput(0), get_width(256, gw), 1, 1, 1, "model.14"); + + auto upsample15 = network->addResize(*conv14->getOutput(0)); + assert(upsample15); + upsample15->setResizeMode(ResizeMode::kNEAREST); + upsample15->setOutputDimensions(bottleneck_csp4->getOutput(0)->getDimensions()); + + ITensor* inputTensors16[] = { upsample15->getOutput(0), bottleneck_csp4->getOutput(0) }; + auto cat16 = network->addConcatenation(inputTensors16, 2); + + auto bottleneck_csp17 = C3(network, weightMap, *cat16->getOutput(0), get_width(512, gw), get_width(256, gw), get_depth(3, gd), false, 1, 0.5, "model.17"); + + // Detect + IConvolutionLayer* det0 = network->addConvolutionNd(*bottleneck_csp17->getOutput(0), 3 * (n_classes + 5), DimsHW{ 1, 1 }, weightMap["model.24.m.0.weight"], weightMap["model.24.m.0.bias"]); + auto conv18 = convBlock(network, weightMap, *bottleneck_csp17->getOutput(0), get_width(256, gw), 3, 2, 1, "model.18"); + ITensor* inputTensors19[] = { conv18->getOutput(0), conv14->getOutput(0) }; + auto cat19 = network->addConcatenation(inputTensors19, 2); + auto bottleneck_csp20 = C3(network, weightMap, *cat19->getOutput(0), get_width(512, gw), get_width(512, gw), get_depth(3, gd), false, 1, 0.5, "model.20"); + IConvolutionLayer* det1 = network->addConvolutionNd(*bottleneck_csp20->getOutput(0), 3 * (n_classes + 5), DimsHW{ 1, 1 }, weightMap["model.24.m.1.weight"], weightMap["model.24.m.1.bias"]); + auto conv21 = convBlock(network, weightMap, *bottleneck_csp20->getOutput(0), get_width(512, gw), 3, 2, 1, "model.21"); + ITensor* inputTensors22[] = { conv21->getOutput(0), conv10->getOutput(0) }; + auto cat22 = network->addConcatenation(inputTensors22, 2); + auto bottleneck_csp23 = C3(network, weightMap, *cat22->getOutput(0), get_width(1024, gw), get_width(1024, gw), get_depth(3, gd), false, 1, 0.5, "model.23"); + IConvolutionLayer* det2 = network->addConvolutionNd(*bottleneck_csp23->getOutput(0), 3 * (n_classes + 5), DimsHW{ 1, 1 }, weightMap["model.24.m.2.weight"], weightMap["model.24.m.2.bias"]); + + auto yolo = addYoLoLayer(network, weightMap, "model.24", std::vector{det0, det1, det2}, input_h, input_w, n_classes); + yolo->getOutput(0)->setName(kOutputTensorName); + network->markOutput(*yolo->getOutput(0)); + + // Engine config + builder->setMaxBatchSize(maxBatchSize); + config->setMaxWorkspaceSize(16 * (1 << 20)); // 16MB +#if defined(USE_FP16) + config->setFlag(BuilderFlag::kFP16); +#elif defined(USE_INT8) + std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl; + assert(builder->platformHasFastInt8()); + config->setFlag(BuilderFlag::kINT8); + Int8EntropyCalibrator2* calibrator = new Int8EntropyCalibrator2(1, input_w, input_h, "./coco_calib/", "int8calib.table", kInputTensorName); + config->setInt8Calibrator(calibrator); +#endif + + std::cout << "Building engine, please wait for a while..." 
<< std::endl; + ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config); + std::cout << "Build engine successfully!" << std::endl; + + // Don't need the network any more + network->destroy(); + + // Release host memory + for (auto& mem : weightMap) { + free((void*)(mem.second.values)); + } + + return engine; +} + +ICudaEngine* build_det_p6_engine(unsigned int maxBatchSize, IBuilder* builder, IBuilderConfig* config, DataType dt, float& gd, float& gw, std::string& wts_name, int input_h, int input_w, int n_classes) { + INetworkDefinition* network = builder->createNetworkV2(0U); + + // Create input tensor of shape {3, input_h, input_w} + ITensor* data = network->addInput(kInputTensorName, dt, Dims3{ 3, input_h, input_w }); + assert(data); + + std::map weightMap = loadWeights(wts_name); + + // Backbone + auto conv0 = convBlock(network, weightMap, *data, get_width(64, gw), 6, 2, 1, "model.0"); + auto conv1 = convBlock(network, weightMap, *conv0->getOutput(0), get_width(128, gw), 3, 2, 1, "model.1"); + auto c3_2 = C3(network, weightMap, *conv1->getOutput(0), get_width(128, gw), get_width(128, gw), get_depth(3, gd), true, 1, 0.5, "model.2"); + auto conv3 = convBlock(network, weightMap, *c3_2->getOutput(0), get_width(256, gw), 3, 2, 1, "model.3"); + auto c3_4 = C3(network, weightMap, *conv3->getOutput(0), get_width(256, gw), get_width(256, gw), get_depth(6, gd), true, 1, 0.5, "model.4"); + auto conv5 = convBlock(network, weightMap, *c3_4->getOutput(0), get_width(512, gw), 3, 2, 1, "model.5"); + auto c3_6 = C3(network, weightMap, *conv5->getOutput(0), get_width(512, gw), get_width(512, gw), get_depth(9, gd), true, 1, 0.5, "model.6"); + auto conv7 = convBlock(network, weightMap, *c3_6->getOutput(0), get_width(768, gw), 3, 2, 1, "model.7"); + auto c3_8 = C3(network, weightMap, *conv7->getOutput(0), get_width(768, gw), get_width(768, gw), get_depth(3, gd), true, 1, 0.5, "model.8"); + auto conv9 = convBlock(network, weightMap, *c3_8->getOutput(0), get_width(1024, gw), 3, 2, 1, "model.9"); + auto c3_10 = C3(network, weightMap, *conv9->getOutput(0), get_width(1024, gw), get_width(1024, gw), get_depth(3, gd), true, 1, 0.5, "model.10"); + auto sppf11 = SPPF(network, weightMap, *c3_10->getOutput(0), get_width(1024, gw), get_width(1024, gw), 5, "model.11"); + + // Head + auto conv12 = convBlock(network, weightMap, *sppf11->getOutput(0), get_width(768, gw), 1, 1, 1, "model.12"); + auto upsample13 = network->addResize(*conv12->getOutput(0)); + assert(upsample13); + upsample13->setResizeMode(ResizeMode::kNEAREST); + upsample13->setOutputDimensions(c3_8->getOutput(0)->getDimensions()); + ITensor* inputTensors14[] = { upsample13->getOutput(0), c3_8->getOutput(0) }; + auto cat14 = network->addConcatenation(inputTensors14, 2); + auto c3_15 = C3(network, weightMap, *cat14->getOutput(0), get_width(1536, gw), get_width(768, gw), get_depth(3, gd), false, 1, 0.5, "model.15"); + + auto conv16 = convBlock(network, weightMap, *c3_15->getOutput(0), get_width(512, gw), 1, 1, 1, "model.16"); + auto upsample17 = network->addResize(*conv16->getOutput(0)); + assert(upsample17); + upsample17->setResizeMode(ResizeMode::kNEAREST); + upsample17->setOutputDimensions(c3_6->getOutput(0)->getDimensions()); + ITensor* inputTensors18[] = { upsample17->getOutput(0), c3_6->getOutput(0) }; + auto cat18 = network->addConcatenation(inputTensors18, 2); + auto c3_19 = C3(network, weightMap, *cat18->getOutput(0), get_width(1024, gw), get_width(512, gw), get_depth(3, gd), false, 1, 0.5, "model.19"); + + auto conv20 = 
convBlock(network, weightMap, *c3_19->getOutput(0), get_width(256, gw), 1, 1, 1, "model.20"); + auto upsample21 = network->addResize(*conv20->getOutput(0)); + assert(upsample21); + upsample21->setResizeMode(ResizeMode::kNEAREST); + upsample21->setOutputDimensions(c3_4->getOutput(0)->getDimensions()); + ITensor* inputTensors21[] = { upsample21->getOutput(0), c3_4->getOutput(0) }; + auto cat22 = network->addConcatenation(inputTensors21, 2); + auto c3_23 = C3(network, weightMap, *cat22->getOutput(0), get_width(512, gw), get_width(256, gw), get_depth(3, gd), false, 1, 0.5, "model.23"); + + auto conv24 = convBlock(network, weightMap, *c3_23->getOutput(0), get_width(256, gw), 3, 2, 1, "model.24"); + ITensor* inputTensors25[] = { conv24->getOutput(0), conv20->getOutput(0) }; + auto cat25 = network->addConcatenation(inputTensors25, 2); + auto c3_26 = C3(network, weightMap, *cat25->getOutput(0), get_width(1024, gw), get_width(512, gw), get_depth(3, gd), false, 1, 0.5, "model.26"); + + auto conv27 = convBlock(network, weightMap, *c3_26->getOutput(0), get_width(512, gw), 3, 2, 1, "model.27"); + ITensor* inputTensors28[] = { conv27->getOutput(0), conv16->getOutput(0) }; + auto cat28 = network->addConcatenation(inputTensors28, 2); + auto c3_29 = C3(network, weightMap, *cat28->getOutput(0), get_width(1536, gw), get_width(768, gw), get_depth(3, gd), false, 1, 0.5, "model.29"); + + auto conv30 = convBlock(network, weightMap, *c3_29->getOutput(0), get_width(768, gw), 3, 2, 1, "model.30"); + ITensor* inputTensors31[] = { conv30->getOutput(0), conv12->getOutput(0) }; + auto cat31 = network->addConcatenation(inputTensors31, 2); + auto c3_32 = C3(network, weightMap, *cat31->getOutput(0), get_width(2048, gw), get_width(1024, gw), get_depth(3, gd), false, 1, 0.5, "model.32"); + + // Detect + IConvolutionLayer* det0 = network->addConvolutionNd(*c3_23->getOutput(0), 3 * (n_classes + 5), DimsHW{ 1, 1 }, weightMap["model.33.m.0.weight"], weightMap["model.33.m.0.bias"]); + IConvolutionLayer* det1 = network->addConvolutionNd(*c3_26->getOutput(0), 3 * (n_classes + 5), DimsHW{ 1, 1 }, weightMap["model.33.m.1.weight"], weightMap["model.33.m.1.bias"]); + IConvolutionLayer* det2 = network->addConvolutionNd(*c3_29->getOutput(0), 3 * (n_classes + 5), DimsHW{ 1, 1 }, weightMap["model.33.m.2.weight"], weightMap["model.33.m.2.bias"]); + IConvolutionLayer* det3 = network->addConvolutionNd(*c3_32->getOutput(0), 3 * (n_classes + 5), DimsHW{ 1, 1 }, weightMap["model.33.m.3.weight"], weightMap["model.33.m.3.bias"]); + + auto yolo = addYoLoLayer(network, weightMap, "model.33", std::vector{det0, det1, det2, det3}, input_h, input_w, n_classes); + yolo->getOutput(0)->setName(kOutputTensorName); + network->markOutput(*yolo->getOutput(0)); + + // Engine config + builder->setMaxBatchSize(maxBatchSize); + config->setMaxWorkspaceSize(16 * (1 << 20)); // 16MB +#if defined(USE_FP16) + config->setFlag(BuilderFlag::kFP16); +#elif defined(USE_INT8) + std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl; + assert(builder->platformHasFastInt8()); + config->setFlag(BuilderFlag::kINT8); + Int8EntropyCalibrator2* calibrator = new Int8EntropyCalibrator2(1, input_w, input_h, "./coco_calib/", "int8calib.table", kInputTensorName); + config->setInt8Calibrator(calibrator); +#endif + + std::cout << "Building engine, please wait for a while..." << std::endl; + ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config); + std::cout << "Build engine successfully!" 
<< std::endl; + + // Don't need the network any more + network->destroy(); + + // Release host memory + for (auto& mem : weightMap) { + free((void*)(mem.second.values)); + } + + return engine; +} + +ICudaEngine* build_cls_engine(unsigned int maxBatchSize, IBuilder* builder, IBuilderConfig* config, DataType dt, float& gd, float& gw, std::string& wts_name) { + INetworkDefinition* network = builder->createNetworkV2(0U); + + // Create input tensor + ITensor* data = network->addInput(kInputTensorName, dt, Dims3{ 3, kClsInputH, kClsInputW }); + assert(data); + std::map weightMap = loadWeights(wts_name); + + // Backbone + auto conv0 = convBlock(network, weightMap, *data, get_width(64, gw), 6, 2, 1, "model.0"); + assert(conv0); + auto conv1 = convBlock(network, weightMap, *conv0->getOutput(0), get_width(128, gw), 3, 2, 1, "model.1"); + auto bottleneck_CSP2 = C3(network, weightMap, *conv1->getOutput(0), get_width(128, gw), get_width(128, gw), get_depth(3, gd), true, 1, 0.5, "model.2"); + auto conv3 = convBlock(network, weightMap, *bottleneck_CSP2->getOutput(0), get_width(256, gw), 3, 2, 1, "model.3"); + auto bottleneck_csp4 = C3(network, weightMap, *conv3->getOutput(0), get_width(256, gw), get_width(256, gw), get_depth(6, gd), true, 1, 0.5, "model.4"); + auto conv5 = convBlock(network, weightMap, *bottleneck_csp4->getOutput(0), get_width(512, gw), 3, 2, 1, "model.5"); + auto bottleneck_csp6 = C3(network, weightMap, *conv5->getOutput(0), get_width(512, gw), get_width(512, gw), get_depth(9, gd), true, 1, 0.5, "model.6"); + auto conv7 = convBlock(network, weightMap, *bottleneck_csp6->getOutput(0), get_width(1024, gw), 3, 2, 1, "model.7"); + auto bottleneck_csp8 = C3(network, weightMap, *conv7->getOutput(0), get_width(1024, gw), get_width(1024, gw), get_depth(3, gd), true, 1, 0.5, "model.8"); + + // Head + auto conv_class = convBlock(network, weightMap, *bottleneck_csp8->getOutput(0), 1280, 1, 1, 1, "model.9.conv"); + IPoolingLayer* pool2 = network->addPoolingNd(*conv_class->getOutput(0), PoolingType::kAVERAGE, DimsHW{7, 7}); + assert(pool2); + IFullyConnectedLayer* yolo = network->addFullyConnected(*pool2->getOutput(0), kClsNumClass, weightMap["model.9.linear.weight"], weightMap["model.9.linear.bias"]); + assert(yolo); + + yolo->getOutput(0)->setName(kOutputTensorName); + network->markOutput(*yolo->getOutput(0)); + + // Engine config + builder->setMaxBatchSize(maxBatchSize); + config->setMaxWorkspaceSize(16 * (1 << 20)); // 16MB + +#if defined(USE_FP16) + config->setFlag(BuilderFlag::kFP16); +#elif defined(USE_INT8) + std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl; + assert(builder->platformHasFastInt8()); + config->setFlag(BuilderFlag::kINT8); + Int8EntropyCalibrator2* calibrator = new Int8EntropyCalibrator2(1, kClsInputW, kClsInputW, "./coco_calib/", "int8calib.table", kInputTensorName); + config->setInt8Calibrator(calibrator); +#endif + + std::cout << "Building engine, please wait for a while..." << std::endl; + ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config); + std::cout << "Build engine successfully!" 
<< std::endl; + + // Don't need the network any more + network->destroy(); + + // Release host memory + for (auto& mem : weightMap) { + free((void*)(mem.second.values)); + } + + return engine; +} + +ICudaEngine* build_seg_engine(unsigned int maxBatchSize, IBuilder* builder, IBuilderConfig* config, DataType dt, float& gd, float& gw, std::string& wts_name, int input_h, int input_w, int n_classes) { + INetworkDefinition* network = builder->createNetworkV2(0U); + ITensor* data = network->addInput(kInputTensorName, dt, Dims3{ 3, input_h, input_w }); + assert(data); + std::map weightMap = loadWeights(wts_name); + + // Backbone + auto conv0 = convBlock(network, weightMap, *data, get_width(64, gw), 6, 2, 1, "model.0"); + assert(conv0); + auto conv1 = convBlock(network, weightMap, *conv0->getOutput(0), get_width(128, gw), 3, 2, 1, "model.1"); + auto bottleneck_CSP2 = C3(network, weightMap, *conv1->getOutput(0), get_width(128, gw), get_width(128, gw), get_depth(3, gd), true, 1, 0.5, "model.2"); + auto conv3 = convBlock(network, weightMap, *bottleneck_CSP2->getOutput(0), get_width(256, gw), 3, 2, 1, "model.3"); + auto bottleneck_csp4 = C3(network, weightMap, *conv3->getOutput(0), get_width(256, gw), get_width(256, gw), get_depth(6, gd), true, 1, 0.5, "model.4"); + auto conv5 = convBlock(network, weightMap, *bottleneck_csp4->getOutput(0), get_width(512, gw), 3, 2, 1, "model.5"); + auto bottleneck_csp6 = C3(network, weightMap, *conv5->getOutput(0), get_width(512, gw), get_width(512, gw), get_depth(9, gd), true, 1, 0.5, "model.6"); + auto conv7 = convBlock(network, weightMap, *bottleneck_csp6->getOutput(0), get_width(1024, gw), 3, 2, 1, "model.7"); + auto bottleneck_csp8 = C3(network, weightMap, *conv7->getOutput(0), get_width(1024, gw), get_width(1024, gw), get_depth(3, gd), true, 1, 0.5, "model.8"); + auto spp9 = SPPF(network, weightMap, *bottleneck_csp8->getOutput(0), get_width(1024, gw), get_width(1024, gw), 5, "model.9"); + + // Head + auto conv10 = convBlock(network, weightMap, *spp9->getOutput(0), get_width(512, gw), 1, 1, 1, "model.10"); + + auto upsample11 = network->addResize(*conv10->getOutput(0)); + assert(upsample11); + upsample11->setResizeMode(ResizeMode::kNEAREST); + upsample11->setOutputDimensions(bottleneck_csp6->getOutput(0)->getDimensions()); + + ITensor* inputTensors12[] = { upsample11->getOutput(0), bottleneck_csp6->getOutput(0) }; + auto cat12 = network->addConcatenation(inputTensors12, 2); + auto bottleneck_csp13 = C3(network, weightMap, *cat12->getOutput(0), get_width(1024, gw), get_width(512, gw), get_depth(3, gd), false, 1, 0.5, "model.13"); + auto conv14 = convBlock(network, weightMap, *bottleneck_csp13->getOutput(0), get_width(256, gw), 1, 1, 1, "model.14"); + + auto upsample15 = network->addResize(*conv14->getOutput(0)); + assert(upsample15); + upsample15->setResizeMode(ResizeMode::kNEAREST); + upsample15->setOutputDimensions(bottleneck_csp4->getOutput(0)->getDimensions()); + + ITensor* inputTensors16[] = { upsample15->getOutput(0), bottleneck_csp4->getOutput(0) }; + auto cat16 = network->addConcatenation(inputTensors16, 2); + + auto bottleneck_csp17 = C3(network, weightMap, *cat16->getOutput(0), get_width(512, gw), get_width(256, gw), get_depth(3, gd), false, 1, 0.5, "model.17"); + + // Segmentation + IConvolutionLayer* det0 = network->addConvolutionNd(*bottleneck_csp17->getOutput(0), 3 * (32 + n_classes + 5), DimsHW{ 1, 1 }, weightMap["model.24.m.0.weight"], weightMap["model.24.m.0.bias"]); + auto conv18 = convBlock(network, weightMap, *bottleneck_csp17->getOutput(0), 
get_width(256, gw), 3, 2, 1, "model.18"); + ITensor* inputTensors19[] = { conv18->getOutput(0), conv14->getOutput(0) }; + auto cat19 = network->addConcatenation(inputTensors19, 2); + auto bottleneck_csp20 = C3(network, weightMap, *cat19->getOutput(0), get_width(512, gw), get_width(512, gw), get_depth(3, gd), false, 1, 0.5, "model.20"); + IConvolutionLayer* det1 = network->addConvolutionNd(*bottleneck_csp20->getOutput(0), 3 * (32 + n_classes + 5), DimsHW{ 1, 1 }, weightMap["model.24.m.1.weight"], weightMap["model.24.m.1.bias"]); + auto conv21 = convBlock(network, weightMap, *bottleneck_csp20->getOutput(0), get_width(512, gw), 3, 2, 1, "model.21"); + ITensor* inputTensors22[] = { conv21->getOutput(0), conv10->getOutput(0) }; + auto cat22 = network->addConcatenation(inputTensors22, 2); + auto bottleneck_csp23 = C3(network, weightMap, *cat22->getOutput(0), get_width(1024, gw), get_width(1024, gw), get_depth(3, gd), false, 1, 0.5, "model.23"); + IConvolutionLayer* det2 = network->addConvolutionNd(*bottleneck_csp23->getOutput(0), 3 * (32 + n_classes + 5), DimsHW{ 1, 1 }, weightMap["model.24.m.2.weight"], weightMap["model.24.m.2.bias"]); + + auto yolo = addYoLoLayer(network, weightMap, "model.24", std::vector{det0, det1, det2}, input_h, input_w, n_classes, true); + yolo->getOutput(0)->setName(kOutputTensorName); + network->markOutput(*yolo->getOutput(0)); + + auto proto = Proto(network, weightMap, *bottleneck_csp17->getOutput(0), get_width(256, gw), 32, "model.24.proto"); + proto->getOutput(0)->setName("proto"); + network->markOutput(*proto->getOutput(0)); + + // Engine config + builder->setMaxBatchSize(maxBatchSize); + config->setMaxWorkspaceSize(16 * (1 << 20)); // 16MB +#if defined(USE_FP16) + config->setFlag(BuilderFlag::kFP16); +#elif defined(USE_INT8) + std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl; + assert(builder->platformHasFastInt8()); + config->setFlag(BuilderFlag::kINT8); + Int8EntropyCalibrator2* calibrator = new Int8EntropyCalibrator2(1, input_w, input_h, "./coco_calib/", "int8calib.table", kInputTensorName); + config->setInt8Calibrator(calibrator); +#endif + + std::cout << "Building engine, please wait for a while..." << std::endl; + ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config); + std::cout << "Build engine successfully!" 
<< std::endl; + + // Don't need the network any more + network->destroy(); + + // Release host memory + for (auto& mem : weightMap) { + free((void*)(mem.second.values)); + } + + return engine; +} + diff --git a/algorithm/common_det/cuda/yolov7/model.h b/algorithm/common_det/cuda/yolov7/model.h new file mode 100644 index 0000000..3153c64 --- /dev/null +++ b/algorithm/common_det/cuda/yolov7/model.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include + +nvinfer1::ICudaEngine* build_det_engine(unsigned int maxBatchSize, nvinfer1::IBuilder* builder, + nvinfer1::IBuilderConfig* config, nvinfer1::DataType dt, + float& gd, float& gw, std::string& wts_name, int input_h, int input_w, int n_classes); + +nvinfer1::ICudaEngine* build_det_p6_engine(unsigned int maxBatchSize, nvinfer1::IBuilder* builder, + nvinfer1::IBuilderConfig* config, nvinfer1::DataType dt, + float& gd, float& gw, std::string& wts_name, int input_h, int input_w, int n_classes); + +nvinfer1::ICudaEngine* build_cls_engine(unsigned int maxBatchSize, nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config, nvinfer1::DataType dt, float& gd, float& gw, std::string& wts_name); + +nvinfer1::ICudaEngine* build_seg_engine(unsigned int maxBatchSize, nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config, nvinfer1::DataType dt, float& gd, float& gw, std::string& wts_name, int input_h, int input_w, int n_classes); diff --git a/algorithm/common_det/cuda/yolov7/postprocess.cpp b/algorithm/common_det/cuda/yolov7/postprocess.cpp new file mode 100644 index 0000000..289540b --- /dev/null +++ b/algorithm/common_det/cuda/yolov7/postprocess.cpp @@ -0,0 +1,198 @@ +#include "postprocess.h" +#include "utils.h" + + +cv::Rect get_rect(cv::Mat& img, float bbox[4], int input_h, int input_w) +{ + float l, r, t, b; + float r_w = input_w / (img.cols * 1.0); + float r_h = input_h / (img.rows * 1.0); + if (r_h > r_w) { + l = bbox[0] - bbox[2] / 2.f; + r = bbox[0] + bbox[2] / 2.f; + t = bbox[1] - bbox[3] / 2.f - (input_h - r_w * img.rows) / 2; + b = bbox[1] + bbox[3] / 2.f - (input_h - r_w * img.rows) / 2; + l = l / r_w; + r = r / r_w; + t = t / r_w; + b = b / r_w; + } else { + l = bbox[0] - bbox[2] / 2.f - (input_w - r_h * img.cols) / 2; + r = bbox[0] + bbox[2] / 2.f - (input_w - r_h * img.cols) / 2; + t = bbox[1] - bbox[3] / 2.f; + b = bbox[1] + bbox[3] / 2.f; + l = l / r_h; + r = r / r_h; + t = t / r_h; + b = b / r_h; + } + return cv::Rect(round(l), round(t), round(r - l), round(b - t)); +} + +static float iou(float lbox[4], float rbox[4]) { + float interBox[] = { + (std::max)(lbox[0] - lbox[2] / 2.f , rbox[0] - rbox[2] / 2.f), //left + (std::min)(lbox[0] + lbox[2] / 2.f , rbox[0] + rbox[2] / 2.f), //right + (std::max)(lbox[1] - lbox[3] / 2.f , rbox[1] - rbox[3] / 2.f), //top + (std::min)(lbox[1] + lbox[3] / 2.f , rbox[1] + rbox[3] / 2.f), //bottom + }; + + if (interBox[2] > interBox[3] || interBox[0] > interBox[1]) + return 0.0f; + + float interBoxS = (interBox[1] - interBox[0])*(interBox[3] - interBox[2]); + return interBoxS / (lbox[2] * lbox[3] + rbox[2] * rbox[3] - interBoxS); +} + +static bool cmp(const Detection& a, const Detection& b) { + return a.conf > b.conf; +} + +void nms(std::vector& res, float* output, float conf_thresh, float nms_thresh) { + int det_size = sizeof(Detection) / sizeof(float); + std::map> m; + for (int i = 0; i < output[0] && i < kMaxNumOutputBbox; i++) { + if (output[1 + det_size * i + 4] <= conf_thresh) continue; + Detection det; + memcpy(&det, &output[1 + det_size * i], det_size * sizeof(float)); + if 
(m.count(det.class_id) == 0) m.emplace(det.class_id, std::vector()); + m[det.class_id].push_back(det); + } + for (auto it = m.begin(); it != m.end(); it++) { + auto& dets = it->second; + std::sort(dets.begin(), dets.end(), cmp); + for (size_t m = 0; m < dets.size(); ++m) { + auto& item = dets[m]; + res.push_back(item); + for (size_t n = m + 1; n < dets.size(); ++n) { + if (iou(item.bbox, dets[n].bbox) > nms_thresh) { + dets.erase(dets.begin() + n); + --n; + } + } + } + } +} + +void batch_nms(std::vector>& res_batch, float *output, int batch_size, int output_size, float conf_thresh, float nms_thresh) { + res_batch.resize(batch_size); + for (int i = 0; i < batch_size; i++) { + nms(res_batch[i], &output[i * output_size], conf_thresh, nms_thresh); + } +} + +void draw_bbox(std::vector& img_batch, std::vector>& res_batch, int input_h, int input_w) { + for (size_t i = 0; i < img_batch.size(); i++) { + auto& res = res_batch[i]; + cv::Mat img = img_batch[i]; + for (size_t j = 0; j < res.size(); j++) { + cv::Rect r = get_rect(img, res[j].bbox, input_h, input_w); + cv::rectangle(img, r, cv::Scalar(0x27, 0xC1, 0x36), 2); + cv::putText(img, std::to_string((int)res[j].class_id), cv::Point(r.x, r.y - 1), cv::FONT_HERSHEY_PLAIN, 1.2, cv::Scalar(0xFF, 0xFF, 0xFF), 2); + } + } +} + +static cv::Rect get_downscale_rect(float bbox[4], float scale) { + float left = bbox[0] - bbox[2] / 2; + float top = bbox[1] - bbox[3] / 2; + float right = bbox[0] + bbox[2] / 2; + float bottom = bbox[1] + bbox[3] / 2; + left /= scale; + top /= scale; + right /= scale; + bottom /= scale; + return cv::Rect(round(left), round(top), round(right - left), round(bottom - top)); +} + +std::vector process_mask(const float* proto, int proto_size, std::vector& dets, int input_h, int input_w) { + std::vector masks; + for (size_t i = 0; i < dets.size(); i++) { + cv::Mat mask_mat = cv::Mat::zeros(input_h / 4, input_w / 4, CV_32FC1); + auto r = get_downscale_rect(dets[i].bbox, 4); + if (r.x < 0) r.x = 0; + if (r.y < 0) r.y = 0; + if (r.x + r.width > mask_mat.cols) r.width = mask_mat.cols - r.x; + if (r.y + r.height > mask_mat.rows) r.height = mask_mat.rows - r.y; + // printf(" %d, %d, %d, %d, (%d, %d)\n", r.x, r.y, r.width, r.height, mask_mat.cols, mask_mat.rows); + + for (int x = r.x; x < r.x + r.width; x++) { + for (int y = r.y; y < r.y + r.height; y++) { + float e = 0.0f; + for (int j = 0; j < 32; j++) { + e += dets[i].mask[j] * proto[j * proto_size / 32 + y * mask_mat.cols + x]; + } + e = 1.0f / (1.0f + expf(-e)); + mask_mat.at(y, x) = e; + } + } + + cv::resize(mask_mat, mask_mat, cv::Size(input_w, input_h)); + masks.push_back(mask_mat); + } + return masks; +} + +cv::Mat scale_mask(cv::Mat mask, cv::Mat img, int input_h, int input_w) { + int x, y, w, h; + float r_w = input_w / (img.cols * 1.0); + float r_h = input_h / (img.rows * 1.0); + if (r_h > r_w) { + w = input_w; + h = r_w * img.rows; + x = 0; + y = (input_h - h) / 2; + } else { + w = r_h * img.cols; + h = input_h; + x = (input_w - w) / 2; + y = 0; + } + cv::Rect r(x, y, w, h); + cv::Mat res; + cv::resize(mask(r), res, img.size()); + return res; +} + +void draw_mask_bbox(cv::Mat& img, std::vector& dets, std::vector& masks, std::unordered_map& labels_map, int input_h, int input_w) { + static std::vector colors = {0xFF3838, 0xFF9D97, 0xFF701F, 0xFFB21D, 0xCFD231, 0x48F90A, + 0x92CC17, 0x3DDB86, 0x1A9334, 0x00D4BB, 0x2C99A8, 0x00C2FF, + 0x344593, 0x6473FF, 0x0018EC, 0x8438FF, 0x520085, 0xCB38FF, + 0xFF95C8, 0xFF37C7}; + for (size_t i = 0; i < dets.size(); i++) { + cv::Mat img_mask = 
scale_mask(masks[i], img, input_h, input_w); + auto color = colors[(int)dets[i].class_id % colors.size()]; + auto bgr = cv::Scalar(color & 0xFF, color >> 8 & 0xFF, color >> 16 & 0xFF); + + cv::Rect r = get_rect(img, dets[i].bbox, input_h, input_w); + for (int x = r.x; x < r.x + r.width; x++) { + for (int y = r.y; y < r.y + r.height; y++) { + float val = img_mask.at(y, x); + if (val <= 0.5) continue; + img.at(y, x)[0] = img.at(y, x)[0] / 2 + bgr[0] / 2; + img.at(y, x)[1] = img.at(y, x)[1] / 2 + bgr[1] / 2; + img.at(y, x)[2] = img.at(y, x)[2] / 2 + bgr[2] / 2; + } + } + + cv::rectangle(img, r, bgr, 2); + + // Get the size of the text + cv::Size textSize = cv::getTextSize(labels_map[(int)dets[i].class_id] + " " + to_string_with_precision(dets[i].conf), cv::FONT_HERSHEY_PLAIN, 1.2, 2, NULL); + // Set the top left corner of the rectangle + cv::Point topLeft(r.x, r.y - textSize.height); + + // Set the bottom right corner of the rectangle + cv::Point bottomRight(r.x + textSize.width, r.y + textSize.height); + + // Set the thickness of the rectangle lines + int lineThickness = 2; + + // Draw the rectangle on the image + cv::rectangle(img, topLeft, bottomRight, bgr, -1); + + cv::putText(img, labels_map[(int)dets[i].class_id] + " " + to_string_with_precision(dets[i].conf), cv::Point(r.x, r.y + 4), cv::FONT_HERSHEY_PLAIN, 1.2, cv::Scalar::all(0xFF), 2); + + } +} + diff --git a/algorithm/common_det/cuda/yolov7/postprocess.h b/algorithm/common_det/cuda/yolov7/postprocess.h new file mode 100644 index 0000000..4555bd7 --- /dev/null +++ b/algorithm/common_det/cuda/yolov7/postprocess.h @@ -0,0 +1,16 @@ +#pragma once + +#include "types.h" +#include + +cv::Rect get_rect(cv::Mat& img, float bbox[4], int input_h, int input_w); + +void nms(std::vector& res, float *output, float conf_thresh, float nms_thresh = 0.5); + +void batch_nms(std::vector>& batch_res, float *output, int batch_size, int output_size, float conf_thresh, float nms_thresh = 0.5); + +void draw_bbox(std::vector& img_batch, std::vector>& res_batch, int input_h, int input_w); + +std::vector process_mask(const float* proto, int proto_size, std::vector& dets, int input_h, int input_w); + +void draw_mask_bbox(cv::Mat& img, std::vector& dets, std::vector& masks, std::unordered_map& labels_map, int input_h, int input_w); diff --git a/algorithm/common_det/cuda/yolov7/preprocess.cu b/algorithm/common_det/cuda/yolov7/preprocess.cu new file mode 100644 index 0000000..8de0093 --- /dev/null +++ b/algorithm/common_det/cuda/yolov7/preprocess.cu @@ -0,0 +1,153 @@ +#include "preprocess.h" +#include "cuda_utils.h" + +static uint8_t* img_buffer_host = nullptr; +static uint8_t* img_buffer_device = nullptr; + +struct AffineMatrix { + float value[6]; +}; + +__global__ void warpaffine_kernel( + uint8_t* src, int src_line_size, int src_width, + int src_height, float* dst, int dst_width, + int dst_height, uint8_t const_value_st, + AffineMatrix d2s, int edge) { + int position = blockDim.x * blockIdx.x + threadIdx.x; + if (position >= edge) return; + + float m_x1 = d2s.value[0]; + float m_y1 = d2s.value[1]; + float m_z1 = d2s.value[2]; + float m_x2 = d2s.value[3]; + float m_y2 = d2s.value[4]; + float m_z2 = d2s.value[5]; + + int dx = position % dst_width; + int dy = position / dst_width; + float src_x = m_x1 * dx + m_y1 * dy + m_z1 + 0.5f; + float src_y = m_x2 * dx + m_y2 * dy + m_z2 + 0.5f; + float c0, c1, c2; + + if (src_x <= -1 || src_x >= src_width || src_y <= -1 || src_y >= src_height) { + // out of range + c0 = const_value_st; + c1 = const_value_st; + c2 = 
const_value_st; + } else { + int y_low = floorf(src_y); + int x_low = floorf(src_x); + int y_high = y_low + 1; + int x_high = x_low + 1; + + uint8_t const_value[] = {const_value_st, const_value_st, const_value_st}; + float ly = src_y - y_low; + float lx = src_x - x_low; + float hy = 1 - ly; + float hx = 1 - lx; + float w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + uint8_t* v1 = const_value; + uint8_t* v2 = const_value; + uint8_t* v3 = const_value; + uint8_t* v4 = const_value; + + if (y_low >= 0) { + if (x_low >= 0) + v1 = src + y_low * src_line_size + x_low * 3; + + if (x_high < src_width) + v2 = src + y_low * src_line_size + x_high * 3; + } + + if (y_high < src_height) { + if (x_low >= 0) + v3 = src + y_high * src_line_size + x_low * 3; + + if (x_high < src_width) + v4 = src + y_high * src_line_size + x_high * 3; + } + + c0 = w1 * v1[0] + w2 * v2[0] + w3 * v3[0] + w4 * v4[0]; + c1 = w1 * v1[1] + w2 * v2[1] + w3 * v3[1] + w4 * v4[1]; + c2 = w1 * v1[2] + w2 * v2[2] + w3 * v3[2] + w4 * v4[2]; + } + + // bgr to rgb + float t = c2; + c2 = c0; + c0 = t; + + // normalization + c0 = c0 / 255.0f; + c1 = c1 / 255.0f; + c2 = c2 / 255.0f; + + // rgbrgbrgb to rrrgggbbb + int area = dst_width * dst_height; + float* pdst_c0 = dst + dy * dst_width + dx; + float* pdst_c1 = pdst_c0 + area; + float* pdst_c2 = pdst_c1 + area; + *pdst_c0 = c0; + *pdst_c1 = c1; + *pdst_c2 = c2; +} + +void cuda_preprocess( + uint8_t* src, int src_width, int src_height, + float* dst, int dst_width, int dst_height, + cudaStream_t stream) { + + int img_size = src_width * src_height * 3; + // copy data to pinned memory + memcpy(img_buffer_host, src, img_size); + // copy data to device memory + CUDA_CHECK(cudaMemcpyAsync(img_buffer_device, img_buffer_host, img_size, cudaMemcpyHostToDevice, stream)); + + AffineMatrix s2d, d2s; + float scale = std::min(dst_height / (float)src_height, dst_width / (float)src_width); + + s2d.value[0] = scale; + s2d.value[1] = 0; + s2d.value[2] = -scale * src_width * 0.5 + dst_width * 0.5; + s2d.value[3] = 0; + s2d.value[4] = scale; + s2d.value[5] = -scale * src_height * 0.5 + dst_height * 0.5; + + cv::Mat m2x3_s2d(2, 3, CV_32F, s2d.value); + cv::Mat m2x3_d2s(2, 3, CV_32F, d2s.value); + cv::invertAffineTransform(m2x3_s2d, m2x3_d2s); + + memcpy(d2s.value, m2x3_d2s.ptr(0), sizeof(d2s.value)); + + int jobs = dst_height * dst_width; + int threads = 256; + int blocks = ceil(jobs / (float)threads); + + warpaffine_kernel<<>>( + img_buffer_device, src_width * 3, src_width, + src_height, dst, dst_width, + dst_height, 128, d2s, jobs); +} + +void cuda_batch_preprocess(std::vector& img_batch, + float* dst, int dst_width, int dst_height, + cudaStream_t stream) { + int dst_size = dst_width * dst_height * 3; + for (size_t i = 0; i < img_batch.size(); i++) { + cuda_preprocess(img_batch[i].ptr(), img_batch[i].cols, img_batch[i].rows, &dst[dst_size * i], dst_width, dst_height, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + } +} + +void cuda_preprocess_init(int max_image_size) { + // prepare input data in pinned memory + CUDA_CHECK(cudaMallocHost((void**)&img_buffer_host, max_image_size * 3)); + // prepare input data in device memory + CUDA_CHECK(cudaMalloc((void**)&img_buffer_device, max_image_size * 3)); +} + +void cuda_preprocess_destroy() { + CUDA_CHECK(cudaFree(img_buffer_device)); + CUDA_CHECK(cudaFreeHost(img_buffer_host)); +} + diff --git a/algorithm/common_det/cuda/yolov7/preprocess.h b/algorithm/common_det/cuda/yolov7/preprocess.h new file mode 100644 index 0000000..c0dc1aa --- /dev/null +++ 
b/algorithm/common_det/cuda/yolov7/preprocess.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include +#include + +void cuda_preprocess_init(int max_image_size); +void cuda_preprocess_destroy(); +void cuda_preprocess(uint8_t* src, int src_width, int src_height, + float* dst, int dst_width, int dst_height, + cudaStream_t stream); +void cuda_batch_preprocess(std::vector& img_batch, + float* dst, int dst_width, int dst_height, + cudaStream_t stream); + diff --git a/algorithm/common_det/cuda/yolov7/types.h b/algorithm/common_det/cuda/yolov7/types.h new file mode 100644 index 0000000..8004eda --- /dev/null +++ b/algorithm/common_det/cuda/yolov7/types.h @@ -0,0 +1,17 @@ +#pragma once + +#include "config.h" + +struct YoloKernel { + int width; + int height; + float anchors[kNumAnchor * 2]; +}; + +struct alignas(float) Detection { + float bbox[4]; // center_x center_y w h + float conf; // bbox_conf * cls_conf + float class_id; + float mask[32]; +}; + diff --git a/algorithm/common_det/cuda/yolov7/utils.h b/algorithm/common_det/cuda/yolov7/utils.h new file mode 100644 index 0000000..2dea946 --- /dev/null +++ b/algorithm/common_det/cuda/yolov7/utils.h @@ -0,0 +1,70 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +static inline int read_files_in_dir(const char* p_dir_name, std::vector& file_names) { + DIR *p_dir = opendir(p_dir_name); + if (p_dir == nullptr) { + return -1; + } + + struct dirent* p_file = nullptr; + while ((p_file = readdir(p_dir)) != nullptr) { + if (strcmp(p_file->d_name, ".") != 0 && + strcmp(p_file->d_name, "..") != 0) { + //std::string cur_file_name(p_dir_name); + //cur_file_name += "/"; + //cur_file_name += p_file->d_name; + std::string cur_file_name(p_file->d_name); + file_names.push_back(cur_file_name); + } + } + + closedir(p_dir); + return 0; +} + +// Function to trim leading and trailing whitespace from a string +static inline std::string trim_leading_whitespace(const std::string& str) { + size_t first = str.find_first_not_of(' '); + if (std::string::npos == first) { + return str; + } + size_t last = str.find_last_not_of(' '); + return str.substr(first, (last - first + 1)); +} + +// Src: https://stackoverflow.com/questions/16605967 +static inline std::string to_string_with_precision(const float a_value, const int n = 2) { + std::ostringstream out; + out.precision(n); + out << std::fixed << a_value; + return out.str(); +} + +static inline int read_labels(const std::string labels_filename, std::unordered_map& labels_map) { + + std::ifstream file(labels_filename); + // Read each line of the file + std::string line; + int index = 0; + while (std::getline(file, line)) { + // Strip the line of any leading or trailing whitespace + line = trim_leading_whitespace(line); + + // Add the stripped line to the labels_map, using the loop index as the key + labels_map[index] = line; + index++; + } + // Close the file + file.close(); + + return 0; +} + diff --git a/algorithm/common_det/cuda/yolov7/yololayer.cu b/algorithm/common_det/cuda/yolov7/yololayer.cu new file mode 100644 index 0000000..d80a9a4 --- /dev/null +++ b/algorithm/common_det/cuda/yolov7/yololayer.cu @@ -0,0 +1,280 @@ +#include "yololayer.h" +#include "cuda_utils.h" + +#include +#include +#include + +namespace Tn { +template +void write(char*& buffer, const T& val) { + *reinterpret_cast(buffer) = val; + buffer += sizeof(T); +} + +template +void read(const char*& buffer, T& val) { + val = *reinterpret_cast(buffer); + buffer += sizeof(T); +} +} + +namespace nvinfer1 { 
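Editor's note: the Tn::write/read helpers above stream raw POD fields through a char buffer, so the field order used in serialize() must exactly match the order read back by the deserializing constructor. Below is a minimal, self-contained round-trip sketch under that assumption (illustrative only, not part of this patch; it re-declares the helpers with memcpy so the snippet stands alone):

#include <cassert>
#include <cstring>

namespace Tn {
template <typename T>
void write(char*& buffer, const T& val) {
    std::memcpy(buffer, &val, sizeof(T));   // copy the field, then advance the cursor
    buffer += sizeof(T);
}
template <typename T>
void read(const char*& buffer, T& val) {
    std::memcpy(&val, buffer, sizeof(T));   // read the field, then advance the cursor
    buffer += sizeof(T);
}
}  // namespace Tn

int main() {
    char buf[sizeof(int) + sizeof(float)];
    char* w = buf;
    Tn::write(w, 80);      // e.g. a class count
    Tn::write(w, 0.45f);   // e.g. an NMS threshold
    const char* r = buf;
    int classes = 0; float thresh = 0.f;
    Tn::read(r, classes);  // fields must be read back in the same order they were written
    Tn::read(r, thresh);
    assert(classes == 80 && thresh == 0.45f);
    return 0;
}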
+YoloLayerPlugin::YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, bool is_segmentation, const std::vector& vYoloKernel) { + mClassCount = classCount; + mYoloV5NetWidth = netWidth; + mYoloV5NetHeight = netHeight; + mMaxOutObject = maxOut; + is_segmentation_ = is_segmentation; + mYoloKernel = vYoloKernel; + mKernelCount = vYoloKernel.size(); + + CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*))); + size_t AnchorLen = sizeof(float)* kNumAnchor * 2; + for (int ii = 0; ii < mKernelCount; ii++) { + CUDA_CHECK(cudaMalloc(&mAnchor[ii], AnchorLen)); + const auto& yolo = mYoloKernel[ii]; + CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice)); + } +} + +YoloLayerPlugin::~YoloLayerPlugin() { + for (int ii = 0; ii < mKernelCount; ii++) { + CUDA_CHECK(cudaFree(mAnchor[ii])); + } + CUDA_CHECK(cudaFreeHost(mAnchor)); +} + +// create the plugin at runtime from a byte stream +YoloLayerPlugin::YoloLayerPlugin(const void* data, size_t length) { + using namespace Tn; + const char *d = reinterpret_cast(data), *a = d; + read(d, mClassCount); + read(d, mThreadCount); + read(d, mKernelCount); + read(d, mYoloV5NetWidth); + read(d, mYoloV5NetHeight); + read(d, mMaxOutObject); + read(d, is_segmentation_); + mYoloKernel.resize(mKernelCount); + auto kernelSize = mKernelCount * sizeof(YoloKernel); + memcpy(mYoloKernel.data(), d, kernelSize); + d += kernelSize; + CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*))); + size_t AnchorLen = sizeof(float)* kNumAnchor * 2; + for (int ii = 0; ii < mKernelCount; ii++) { + CUDA_CHECK(cudaMalloc(&mAnchor[ii], AnchorLen)); + const auto& yolo = mYoloKernel[ii]; + CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice)); + } + assert(d == a + length); +} + +void YoloLayerPlugin::serialize(void* buffer) const TRT_NOEXCEPT { + using namespace Tn; + char* d = static_cast(buffer), *a = d; + write(d, mClassCount); + write(d, mThreadCount); + write(d, mKernelCount); + write(d, mYoloV5NetWidth); + write(d, mYoloV5NetHeight); + write(d, mMaxOutObject); + write(d, is_segmentation_); + auto kernelSize = mKernelCount * sizeof(YoloKernel); + memcpy(d, mYoloKernel.data(), kernelSize); + d += kernelSize; + + assert(d == a + getSerializationSize()); +} + +size_t YoloLayerPlugin::getSerializationSize() const TRT_NOEXCEPT { + size_t s = sizeof(mClassCount) + sizeof(mThreadCount) + sizeof(mKernelCount); + s += sizeof(YoloKernel) * mYoloKernel.size(); + s += sizeof(mYoloV5NetWidth) + sizeof(mYoloV5NetHeight); + s += sizeof(mMaxOutObject) + sizeof(is_segmentation_); + return s; +} + +int YoloLayerPlugin::initialize() TRT_NOEXCEPT { + return 0; +} + +Dims YoloLayerPlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims) TRT_NOEXCEPT { + //output the result to channel + int totalsize = mMaxOutObject * sizeof(Detection) / sizeof(float); + return Dims3(totalsize + 1, 1, 1); +} + +// Set plugin namespace +void YoloLayerPlugin::setPluginNamespace(const char* pluginNamespace) TRT_NOEXCEPT { + mPluginNamespace = pluginNamespace; +} + +const char* YoloLayerPlugin::getPluginNamespace() const TRT_NOEXCEPT { + return mPluginNamespace; +} + +// Return the DataType of the plugin output at the requested index +DataType YoloLayerPlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT { + return DataType::kFLOAT; +} + +// Return true if output tensor is broadcast across a batch. 
+bool YoloLayerPlugin::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const TRT_NOEXCEPT { + return false; +} + +// Return true if plugin can use input that is broadcast across batch without replication. +bool YoloLayerPlugin::canBroadcastInputAcrossBatch(int inputIndex) const TRT_NOEXCEPT { + return false; +} + +void YoloLayerPlugin::configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) TRT_NOEXCEPT {} + +// Attach the plugin object to an execution context and grant the plugin the access to some context resource. +void YoloLayerPlugin::attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {} + +// Detach the plugin object from its execution context. +void YoloLayerPlugin::detachFromContext() TRT_NOEXCEPT {} + +const char* YoloLayerPlugin::getPluginType() const TRT_NOEXCEPT { + return "YoloLayer_TRT"; +} + +const char* YoloLayerPlugin::getPluginVersion() const TRT_NOEXCEPT { + return "1"; +} + +void YoloLayerPlugin::destroy() TRT_NOEXCEPT { + delete this; +} + +// Clone the plugin +IPluginV2IOExt* YoloLayerPlugin::clone() const TRT_NOEXCEPT { + YoloLayerPlugin* p = new YoloLayerPlugin(mClassCount, mYoloV5NetWidth, mYoloV5NetHeight, mMaxOutObject, is_segmentation_, mYoloKernel); + p->setPluginNamespace(mPluginNamespace); + return p; +} + +__device__ float Logist(float data) { return 1.0f / (1.0f + expf(-data)); }; + +__global__ void CalDetection(const float *input, float *output, int noElements, + const int netwidth, const int netheight, int maxoutobject, int yoloWidth, + int yoloHeight, const float anchors[kNumAnchor * 2], int classes, int outputElem, bool is_segmentation) { + + int idx = threadIdx.x + blockDim.x * blockIdx.x; + if (idx >= noElements) return; + + int total_grid = yoloWidth * yoloHeight; + int bnIdx = idx / total_grid; + idx = idx - total_grid * bnIdx; + int info_len_i = 5 + classes; + if (is_segmentation) info_len_i += 32; + const float* curInput = input + bnIdx * (info_len_i * total_grid * kNumAnchor); + + for (int k = 0; k < kNumAnchor; ++k) { + float box_prob = Logist(curInput[idx + k * info_len_i * total_grid + 4 * total_grid]); + if (box_prob < kIgnoreThresh) continue; + int class_id = 0; + float max_cls_prob = 0.0; + for (int i = 5; i < 5 + classes; ++i) { + float p = Logist(curInput[idx + k * info_len_i * total_grid + i * total_grid]); + if (p > max_cls_prob) { + max_cls_prob = p; + class_id = i - 5; + } + } + float *res_count = output + bnIdx * outputElem; + int count = (int)atomicAdd(res_count, 1); + if (count >= maxoutobject) return; + char *data = (char*)res_count + sizeof(float) + count * sizeof(Detection); + Detection *det = (Detection*)(data); + + int row = idx / yoloWidth; + int col = idx % yoloWidth; + + det->bbox[0] = (col - 0.5f + 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 0 * total_grid])) * netwidth / yoloWidth; + det->bbox[1] = (row - 0.5f + 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 1 * total_grid])) * netheight / yoloHeight; + + det->bbox[2] = 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 2 * total_grid]); + det->bbox[2] = det->bbox[2] * det->bbox[2] * anchors[2 * k]; + det->bbox[3] = 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 3 * total_grid]); + det->bbox[3] = det->bbox[3] * det->bbox[3] * anchors[2 * k + 1]; + det->conf = box_prob * max_cls_prob; + det->class_id = class_id; + + for (int i = 0; is_segmentation && i < 32; i++) { + 
det->mask[i] = curInput[idx + k * info_len_i * total_grid + (i + 5 + classes) * total_grid]; + } + } +} + +void YoloLayerPlugin::forwardGpu(const float* const* inputs, float *output, cudaStream_t stream, int batchSize) { + int outputElem = 1 + mMaxOutObject * sizeof(Detection) / sizeof(float); + for (int idx = 0; idx < batchSize; ++idx) { + CUDA_CHECK(cudaMemsetAsync(output + idx * outputElem, 0, sizeof(float), stream)); + } + int numElem = 0; + for (unsigned int i = 0; i < mYoloKernel.size(); ++i) { + const auto& yolo = mYoloKernel[i]; + numElem = yolo.width * yolo.height * batchSize; + if (numElem < mThreadCount) mThreadCount = numElem; + + CalDetection << < (numElem + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream >> > + (inputs[i], output, numElem, mYoloV5NetWidth, mYoloV5NetHeight, mMaxOutObject, yolo.width, yolo.height, (float*)mAnchor[i], mClassCount, outputElem, is_segmentation_); + } +} + + +int YoloLayerPlugin::enqueue(int batchSize, const void* const* inputs, void* TRT_CONST_ENQUEUE* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT { + forwardGpu((const float* const*)inputs, (float*)outputs[0], stream, batchSize); + return 0; +} + +PluginFieldCollection YoloPluginCreator::mFC{}; +std::vector YoloPluginCreator::mPluginAttributes; + +YoloPluginCreator::YoloPluginCreator() { + mPluginAttributes.clear(); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +const char* YoloPluginCreator::getPluginName() const TRT_NOEXCEPT { + return "YoloLayer_TRT"; +} + +const char* YoloPluginCreator::getPluginVersion() const TRT_NOEXCEPT { + return "1"; +} + +const PluginFieldCollection* YoloPluginCreator::getFieldNames() TRT_NOEXCEPT { + return &mFC; +} + +IPluginV2IOExt* YoloPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT { + assert(fc->nbFields == 2); + assert(strcmp(fc->fields[0].name, "netinfo") == 0); + assert(strcmp(fc->fields[1].name, "kernels") == 0); + int *p_netinfo = (int*)(fc->fields[0].data); + int class_count = p_netinfo[0]; + int input_w = p_netinfo[1]; + int input_h = p_netinfo[2]; + int max_output_object_count = p_netinfo[3]; + bool is_segmentation = (bool)p_netinfo[4]; + std::vector kernels(fc->fields[1].length); + memcpy(&kernels[0], fc->fields[1].data, kernels.size() * sizeof(YoloKernel)); + YoloLayerPlugin* obj = new YoloLayerPlugin(class_count, input_w, input_h, max_output_object_count, is_segmentation, kernels); + obj->setPluginNamespace(mNamespace.c_str()); + return obj; +} + +IPluginV2IOExt* YoloPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT { + // This object will be deleted when the network is destroyed, which will + // call YoloLayerPlugin::destroy() + YoloLayerPlugin* obj = new YoloLayerPlugin(serialData, serialLength); + obj->setPluginNamespace(mNamespace.c_str()); + return obj; +} +} + diff --git a/algorithm/common_det/cuda/yolov7/yololayer.h b/algorithm/common_det/cuda/yolov7/yololayer.h new file mode 100644 index 0000000..a73190b --- /dev/null +++ b/algorithm/common_det/cuda/yolov7/yololayer.h @@ -0,0 +1,106 @@ +#pragma once + +#include "types.h" +#include "macros.h" + +#include +#include + +namespace nvinfer1 { +class API YoloLayerPlugin : public IPluginV2IOExt { +public: + YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, bool is_segmentation, const std::vector& vYoloKernel); + YoloLayerPlugin(const void* data, size_t length); + ~YoloLayerPlugin(); + + int getNbOutputs() const 
TRT_NOEXCEPT override { return 1; } + + Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) TRT_NOEXCEPT override; + + int initialize() TRT_NOEXCEPT override; + + virtual void terminate() TRT_NOEXCEPT override {}; + + virtual size_t getWorkspaceSize(int maxBatchSize) const TRT_NOEXCEPT override { return 0; } + + virtual int enqueue(int batchSize, const void* const* inputs, void*TRT_CONST_ENQUEUE* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT override; + + virtual size_t getSerializationSize() const TRT_NOEXCEPT override; + + virtual void serialize(void* buffer) const TRT_NOEXCEPT override; + + bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const TRT_NOEXCEPT override { + return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT; + } + + const char* getPluginType() const TRT_NOEXCEPT override; + + const char* getPluginVersion() const TRT_NOEXCEPT override; + + void destroy() TRT_NOEXCEPT override; + + IPluginV2IOExt* clone() const TRT_NOEXCEPT override; + + void setPluginNamespace(const char* pluginNamespace) TRT_NOEXCEPT override; + + const char* getPluginNamespace() const TRT_NOEXCEPT override; + + DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override; + + bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const TRT_NOEXCEPT override; + + bool canBroadcastInputAcrossBatch(int inputIndex) const TRT_NOEXCEPT override; + + void attachToContext( + cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) TRT_NOEXCEPT override; + + void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) TRT_NOEXCEPT override; + + void detachFromContext() TRT_NOEXCEPT override; + + private: + void forwardGpu(const float* const* inputs, float *output, cudaStream_t stream, int batchSize = 1); + int mThreadCount = 256; + const char* mPluginNamespace; + int mKernelCount; + int mClassCount; + int mYoloV5NetWidth; + int mYoloV5NetHeight; + int mMaxOutObject; + bool is_segmentation_; + std::vector mYoloKernel; + void** mAnchor; +}; + +class API YoloPluginCreator : public IPluginCreator { + public: + YoloPluginCreator(); + + ~YoloPluginCreator() override = default; + + const char* getPluginName() const TRT_NOEXCEPT override; + + const char* getPluginVersion() const TRT_NOEXCEPT override; + + const PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override; + + IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT override; + + IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override; + + void setPluginNamespace(const char* libNamespace) TRT_NOEXCEPT override { + mNamespace = libNamespace; + } + + const char* getPluginNamespace() const TRT_NOEXCEPT override { + return mNamespace.c_str(); + } + + private: + std::string mNamespace; + static PluginFieldCollection mFC; + static std::vector mPluginAttributes; +}; +REGISTER_TENSORRT_PLUGIN(YoloPluginCreator); +}; + diff --git a/algorithm/common_det/sv_common_det.cpp b/algorithm/common_det/sv_common_det.cpp new file mode 100644 index 0000000..d1c5853 --- /dev/null +++ b/algorithm/common_det/sv_common_det.cpp @@ -0,0 +1,62 @@ +#include "sv_common_det.h" +#include +#include + +#ifdef WITH_CUDA +#include +#include +#include "common_det_cuda_impl.h" +#endif + + +namespace sv { + 
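Editor's note: the CommonObjectDetector implementation that follows hides its GPU backend behind a raw impl pointer plus #ifdef WITH_CUDA guards, so CPU-only builds still compile and simply report setup failure. A stripped-down sketch of that dispatch pattern (FooDetector/FooCudaImpl are hypothetical stand-ins, illustrative only):

#include <iostream>
#include <memory>

#ifdef WITH_CUDA
struct FooCudaImpl {
    bool setup() { return true; }          // engine loading would happen here
};
#endif

class FooDetector {
public:
    bool setup() {
#ifdef WITH_CUDA
        return impl_->setup();             // GPU path
#else
        return false;                      // no backend compiled in
#endif
    }
private:
#ifdef WITH_CUDA
    std::unique_ptr<FooCudaImpl> impl_ = std::make_unique<FooCudaImpl>();
#endif
};

int main() {
    FooDetector det;
    std::cout << (det.setup() ? "backend ready" : "no backend") << "\n";
    return 0;
}

Unlike this sketch, the patch allocates CommonObjectDetectorCUDAImpl with new in the constructor and never frees it in the empty destructor; a smart pointer, or an explicit delete, would avoid that leak.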
+ +CommonObjectDetector::CommonObjectDetector() +{ + this->_cuda_impl = new CommonObjectDetectorCUDAImpl; +} +CommonObjectDetector::~CommonObjectDetector() +{ +} + +bool CommonObjectDetector::setupImpl() +{ +#ifdef WITH_CUDA + return this->_cuda_impl->cudaSetup(this); +#endif + return false; +} + +void CommonObjectDetector::detectImpl( + cv::Mat img_, + std::vector& boxes_x_, + std::vector& boxes_y_, + std::vector& boxes_w_, + std::vector& boxes_h_, + std::vector& boxes_label_, + std::vector& boxes_score_, + std::vector& boxes_seg_ +) +{ +#ifdef WITH_CUDA + this->_cuda_impl->cudaDetect( + this, + img_, + boxes_x_, + boxes_y_, + boxes_w_, + boxes_h_, + boxes_label_, + boxes_score_, + boxes_seg_ + ); +#endif +} + + + + + +} + diff --git a/algorithm/ellipse_det/ellipse_detector.cpp b/algorithm/ellipse_det/ellipse_detector.cpp new file mode 100644 index 0000000..afe1d8b --- /dev/null +++ b/algorithm/ellipse_det/ellipse_detector.cpp @@ -0,0 +1,3863 @@ +#include "ellipse_detector.h" +#include +#include +#include + +using namespace std; +using namespace cv; + + + +#ifndef USE_OMP +int omp_get_max_threads() { return 1; } +int omp_get_thread_num() { return 0; } +// int omp_set_num_threads(int){ return 0; } +#endif + + +namespace yaed { + +void _list_dir(std::string dir, std::vector& files, std::string suffixs, bool r) { + // assert(_endswith(dir, "/") || _endswith(dir, "\\")); + + DIR *pdir; + struct dirent *ent; + string childpath; + string absolutepath; + pdir = opendir(dir.c_str()); + assert(pdir != NULL); + + vector suffixd(0); + if (!suffixs.empty() && suffixs != "") { + suffixd = _split(suffixs, "|"); + } + + while ((ent = readdir(pdir)) != NULL) { + if (ent->d_type & DT_DIR) { + if (strcmp(ent->d_name, ".") == 0 || strcmp(ent->d_name, "..") == 0) { + continue; + } + if (r) { // If need to traverse subdirectories + childpath = dir + ent->d_name; + _list_dir(childpath, files); + } + } + else { + if (suffixd.size() > 0) { + bool can_push = false; + for (int i = 0; i < (int)suffixd.size(); i++) { + if (_endswith(ent->d_name, suffixd[i])) + can_push = true; + } + if (can_push) { + absolutepath = dir + ent->d_name; + files.push_back(ent->d_name); // filepath + } + } + else { + absolutepath = dir + ent->d_name; + files.push_back(ent->d_name); // filepath + } + } + } + sort(files.begin(), files.end()); //sort names +} + +vector _split(const string& srcstr, const string& delimeter) +{ + vector ret(0); //use ret save the spilted reault + if (srcstr.empty()) //judge the arguments + { + return ret; + } + string::size_type pos_begin = srcstr.find_first_not_of(delimeter); //find first element of srcstr + + string::size_type dlm_pos; //the delimeter postion + string temp; //use third-party temp to save splited element + while (pos_begin != string::npos) //if not a next of end, continue spliting + { + dlm_pos = srcstr.find(delimeter, pos_begin); //find the delimeter symbol + if (dlm_pos != string::npos) + { + temp = srcstr.substr(pos_begin, dlm_pos - pos_begin); + pos_begin = dlm_pos + delimeter.length(); + } + else + { + temp = srcstr.substr(pos_begin); + pos_begin = dlm_pos; + } + if (!temp.empty()) + ret.push_back(temp); + } + return ret; +} + +bool _startswith(const std::string& str, const std::string& start) +{ + size_t srclen = str.size(); + size_t startlen = start.size(); + if (srclen >= startlen) + { + string temp = str.substr(0, startlen); + if (temp == start) + return true; + } + + return false; +} + +bool _endswith(const std::string& str, const std::string& end) +{ + size_t srclen = str.size(); + 
size_t endlen = end.size(); + if (srclen >= endlen) + { + string temp = str.substr(srclen - endlen, endlen); + if (temp == end) + return true; + } + + return false; +} + +int inline randint(int l, int u) +{ + return l + rand() % (u - l + 1); + // rand() % (u-l+1) -> [0, u-l] + // [0, u-l] -> [l, u] +} + +void _randperm(int n, int m, int arr[], bool sort_) +{ + int* x = (int*)malloc(sizeof(int)*n); + for (int i = 0; i < n; ++i) + x[i] = i; + for (int i = 0; i < m; ++i) + { + int j = randint(i, n - 1); + int t = x[i]; x[i] = x[j]; x[j] = t; // swap(x[i], x[j]); + } + if (sort_) + sort(x, x + m); + for (int i = 0; i < m; ++i) + arr[i] = x[i]; + free(x); +} + +#define FL_PI 3.14159265358979323846f +#define FL_1_2_PI 1.57079632679f +#define FL_2__PI 6.28318530718 +float _atan2(float y, float x) +{ + float ang(0); + if (x > 0) + ang = atanf(y / x); + else if (y >= 0 && x < 0) + ang = atanf(y / x) + FL_PI; + else if (y < 0 && x < 0) + ang = atanf(y / x) - FL_PI; + else if (y > 0 && x == 0) + ang = FL_1_2_PI; + else if (y < 0 && x == 0) + ang = -FL_1_2_PI; + else // (y == 0 && x == 0) + ang = INFINITY; + // if (ang < 0) ang += FL_2__PI; + return ang; +} + +void _mean_std(std::vector& vec, float& mean, float& std) +{ + float sum = std::accumulate(std::begin(vec), std::end(vec), 0.0); + mean = sum / vec.size(); + + float accum = 0.0; + std::for_each(std::begin(vec), std::end(vec), [&](const double d) { + accum += (d - mean)*(d - mean); + }); + + std = sqrt(accum / (vec.size() - 1)); +} + +float _get_min_angle_PI(float alpha, float beta) +{ + float pi = float(CV_PI); + float pi2 = float(2.0 * CV_PI); + + //normalize data in [0, 2*pi] + float a = fmod(alpha + pi2, pi2); + float b = fmod(beta + pi2, pi2); + + //normalize data in [0, pi] + if (a > pi) + a -= pi; + if (b > pi) + b -= pi; + + if (a > b) + { + swap(a, b); + } + + float diff1 = b - a; + float diff2 = pi - diff1; + return min(diff1, diff2); +} + + +void _load_ellipse_GT(const string& gt_file_name, vector& gt_ellipses, bool is_angle_radians) +{ + ifstream in(gt_file_name); + if (!in.good()) + { + cout << "Error opening: " << gt_file_name << endl; + return; + } + + unsigned n; + in >> n; + + gt_ellipses.clear(); + gt_ellipses.reserve(n); + + while (in.good() && n--) + { + Ellipse e; + in >> e.xc_ >> e.yc_ >> e.a_ >> e.b_ >> e.rad_; + + if (!is_angle_radians) + { + // convert to radians + e.rad_ = float(e.rad_ * CV_PI / 180.0); + } + + if (e.a_ < e.b_) + { + float temp = e.a_; + e.a_ = e.b_; + e.b_ = temp; + + e.rad_ = e.rad_ + float(0.5*CV_PI); + } + + e.rad_ = fmod(float(e.rad_ + 2.f*CV_PI), float(CV_PI)); + e.score_ = 1.f; + gt_ellipses.push_back(e); + } + in.close(); +} + +void _load_ellipse_DT(const string& dt_file_name, vector& dt_ellipses, bool is_angle_radians) +{ + ifstream in(dt_file_name); + if (!in.good()) + { + cout << "Error opening: " << dt_file_name << endl; + return; + } + + unsigned n; + in >> n; + + dt_ellipses.clear(); + dt_ellipses.reserve(n); + + while (in.good() && n--) + { + Ellipse e; + in >> e.xc_ >> e.yc_ >> e.a_ >> e.b_ >> e.rad_ >> e.score_; + + if (!is_angle_radians) + { + // convert to radians + e.rad_ = float(e.rad_ * CV_PI / 180.0); + } + + if (e.a_ < e.b_) + { + float temp = e.a_; + e.a_ = e.b_; + e.b_ = temp; + + e.rad_ = e.rad_ + float(0.5*CV_PI); + } + + e.rad_ = fmod(float(e.rad_ + 2.f*CV_PI), float(CV_PI)); + e.score_ = 1.f; + dt_ellipses.push_back(e); + } + in.close(); +} + +bool _ellipse_overlap(const Mat1b& gt, const Mat1b& dt, float th) +{ + float f_and = float(countNonZero(gt & dt)); + float f_or = 
float(countNonZero(gt | dt)); + float f_sim = f_and / f_or; + + return (f_sim >= th); +} + +float _ellipse_overlap_real(const Mat1b& gt, const Mat1b& dt) +{ + float f_and = float(countNonZero(gt & dt)); + float f_or = float(countNonZero(gt | dt)); + float f_sim = f_and / f_or; + + return f_sim; +} + +int _bool_count(const std::vector vb) +{ + int counter = 0; + for (unsigned i = 0; i& ell_gt, const vector& ell_dt, const Mat3b& img) +{ + float threshold_overlap = 0.8f; + // float threshold = 0.95f; + + unsigned sz_gt = ell_gt.size(); + unsigned sz_dt = ell_dt.size(); + + unsigned sz_dt_min = unsigned(min(1000, int(sz_dt))); + + vector mat_gts(sz_gt); + vector mat_dts(sz_dt_min); + + // Draw each ground-Truth ellipse + for (unsigned i = 0; i vec_gt(sz_gt, false); + vector vec_dt(sz_dt_min, false); + + // Each row in the matrix has one means the ellipse be found + for (unsigned int i = 0; i < sz_dt_min; ++i) + { + for (unsigned int j = 0; j < sz_gt; ++j) + { + bool b_found = overlap(j, i) != 0; + if (b_found) + { + vec_gt[j] = true; + vec_dt[i] = true; + } + } + } + + float tp = _bool_count(vec_gt); + float fn = int(sz_gt) - tp; + float fp = sz_dt - _bool_count(vec_dt); // !!!! sz_dt - _bool_count(vec_dt); // + + float pr(0.f); + float re(0.f); + float fmeasure(0.f); + + if (tp == 0) { + if (fp == 0) { + pr = 1.f; + re = 0.f; + fmeasure = (2.f * pr * re) / (pr + re); + } + else { + pr = 0.f; + re = 0.f; + fmeasure = 0.f; + } + } + else { + pr = float(tp) / float(tp + fp); + re = float(tp) / float(tp + fn); + fmeasure = (2.f * pr * re) / (pr + re); + } + + return fmeasure; +} + +float _ellipse_evaluate(vector& image_fns, vector& gt_fns, vector& dt_fns, bool gt_angle_radians) +{ + float fmeasure(0.f); + for (int i = 0; i < image_fns.size(); i++) { + Mat3b image = imread(image_fns[i]); + + vector ell_gt, ell_dt; + _load_ellipse_GT(gt_fns[i], ell_gt, gt_angle_radians); + _load_ellipse_DT(dt_fns[i], ell_dt); + + int tp, fn, fp; + fmeasure += _ellipse_evaluate_one(ell_gt, ell_dt, image); + + } + + fmeasure /= image_fns.size(); + return fmeasure; +} + +Point2f inline _lineCrossPoint(Point2f l1p1, Point2f l1p2, Point2f l2p1, Point2f l2p2) +{ + Point2f crossPoint; + float k1, k2, b1, b2; + if (l1p1.x == l1p2.x&&l2p1.x == l2p2.x) { + crossPoint = Point2f(0, 0); // invalid point + return crossPoint; + } + if (l1p1.x == l1p2.x) + { + crossPoint.x = l1p1.x; + k2 = (l2p2.y - l2p1.y) / (l2p2.x - l2p1.x); + b2 = l2p1.y - k2*l2p1.x; + crossPoint.y = k2*crossPoint.x + b2; + return crossPoint; + } + if (l2p1.x == l2p2.x) + { + crossPoint.x = l2p1.x; + k2 = (l1p2.y - l1p1.y) / (l1p2.x - l1p1.x); + b2 = l1p1.y - k2*l1p1.x; + crossPoint.y = k2*crossPoint.x + b2; + return crossPoint; + } + + k1 = (l1p2.y - l1p1.y) / (l1p2.x - l1p1.x); + k2 = (l2p2.y - l2p1.y) / (l2p2.x - l2p1.x); + b1 = l1p1.y - k1*l1p1.x; + b2 = l2p1.y - k2*l2p1.x; + if (k1 == k2) + { + crossPoint = Point2f(0, 0); // invalid point + } + else + { + crossPoint.x = (b2 - b1) / (k1 - k2); + crossPoint.y = k1*crossPoint.x + b1; + } + return crossPoint; +} + +void inline _point2Mat(Point2f p1, Point2f p2, float mat[2][2]) +{ + mat[0][0] = p1.x; + mat[0][1] = p1.y; + mat[1][0] = p2.x; + mat[1][1] = p2.y; +} + +float _value4SixPoints(cv::Point2f p3, cv::Point2f p2, cv::Point2f p1, cv::Point2f p4, cv::Point2f p5, cv::Point2f p6) +{ + float result = 1; + Mat A, B, C; + float matB[2][2], matC[2][2]; + Point2f v, w, u; + v = _lineCrossPoint(p1, p2, p3, p4); + w = _lineCrossPoint(p5, p6, p3, p4); + u = _lineCrossPoint(p5, p6, p1, p2); + + _point2Mat(u, v, 
matB); + _point2Mat(p1, p2, matC); + B = Mat(2, 2, CV_32F, matB); + C = Mat(2, 2, CV_32F, matC); + A = C*B.inv(); + + // cout<<"u:\t"<(0, 0)*A.at(1, 0) / (A.at(0, 1)*A.at(1, 1)); + + _point2Mat(p3, p4, matC); + _point2Mat(v, w, matB); + B = Mat(2, 2, CV_32F, matB); + C = Mat(2, 2, CV_32F, matC); + A = C*B.inv(); + result *= A.at(0, 0)*A.at(1, 0) / (A.at(0, 1)*A.at(1, 1)); + + _point2Mat(p5, p6, matC); + _point2Mat(w, u, matB); + B = Mat(2, 2, CV_32F, matB); + C = Mat(2, 2, CV_32F, matC); + A = C*B.inv(); + result *= A.at(0, 0)*A.at(1, 0) / (A.at(0, 1)*A.at(1, 1)); + return result; +} + +/*----------------------------------------------------------------------------*/ +/** Compute ellipse foci, given ellipse params. +*/ +void _ellipse_foci(float *param, float *foci) +{ + float f; + /* check parameters */ + if (param == NULL) fprintf(stderr, "ellipse_foci: invalid input ellipse."); + if (foci == NULL) fprintf(stderr, "ellipse_foci: 'foci' must be non null."); + + f = sqrt(param[2] * param[2] - param[3] * param[3]); + foci[0] = param[0] + f * cos(param[4]); + foci[1] = param[1] + f * sin(param[4]); + foci[2] = param[0] - f * cos(param[4]); + foci[3] = param[1] - f * sin(param[4]); +} + +/*----------------------------------------------------------------------------*/ +/** Signed angle difference. +*/ +float angle_diff_signed(float a, float b) +{ + a -= b; + while (a <= -M_PI) a += M_2__PI; + while (a > M_PI) a -= M_2__PI; + return a; +} + +/*----------------------------------------------------------------------------*/ +/** Absolute value angle difference. +*/ +float _angle_diff(float a, float b) +{ + a -= b; + while (a <= -M_PI) a += M_2__PI; + while (a > M_PI) a -= M_2__PI; + if (a < 0.0) a = -a; + return a; +} + +/*----------------------------------------------------------------------------*/ +/** Compute the angle of the normal to a point belonging to an ellipse +using the focal property. 
+*/ +float _ellipse_normal_angle(float x, float y, float *foci) +{ + float tmp1, tmp2, tmp3, theta; + /* check parameters */ + if (foci == NULL) fprintf(stderr, "ellipse_normal_angle: 'foci' must be non null"); + + tmp1 = atan2(y - foci[1], x - foci[0]); + tmp2 = atan2(y - foci[3], x - foci[2]); + tmp3 = angle_diff_signed(tmp1, tmp2); + + theta = tmp1 - tmp3 / 2.0; + while (theta <= -M_PI) theta += M_2__PI; + while (theta > M_PI) theta -= M_2__PI; + return theta; +} + + +void cv_canny(cv::Mat& src, cv::Mat& dst, + cv::Mat& dx, cv::Mat& dy, + int aperture_size, bool L2gradient, double percent_ne) { + + cv::AutoBuffer buffer; + std::vector stack; + uchar **stack_top = 0, **stack_bottom = 0; + + if (CV_MAT_TYPE(src.type()) != CV_8UC1 || + CV_MAT_TYPE(dst.type()) != CV_8UC1 || + CV_MAT_TYPE(dx.type()) != CV_16SC1 || + CV_MAT_TYPE(dy.type()) != CV_16SC1) + CV_Error(CV_StsUnsupportedFormat, ""); + + if (!CV_ARE_SIZES_EQ(&src, &dst)) + CV_Error(CV_StsUnmatchedSizes, ""); + + aperture_size &= INT_MAX; + if ((aperture_size & 1) == 0 || aperture_size < 3 || aperture_size > 7) + CV_Error(CV_StsBadFlag, ""); + + int i, j; + CvSize size; + size.width = src.cols; + size.height = src.rows; + + //cv::Sobel(src, dx, 1, 0, aperture_size); + //cv::Sobel(src, dy, 0, 1, aperture_size); + cv::Sobel(src, dx, CV_16S, 1, 0, aperture_size); + cv::Sobel(src, dy, CV_16S, 0, 1, aperture_size); + + // double min, max; + // cv::minMaxLoc(Mat(dx->rows, dx->cols, CV_16SC1, dx->data.fl), &min, &max); + // cout << "min: " << min << ", max: " << max << endl; + + Mat1f magGrad(size.height, size.width, 0.f); + float maxGrad(0); + float val(0); + for (i = 0; i < size.height; ++i) + { + float* _pmag = magGrad.ptr(i); + //const short* _dx = (short*)(dx.data.ptr + dx.step*i); + //const short* _dy = (short*)(dy.data.ptr + dy.step*i); + for (j = 0; j < size.width; ++j) + { + //val = float(abs(_dx[j]) + abs(_dy[j])); + val = float(abs(dx.at(i, j)) + abs(dy.at(i, j))); + _pmag[j] = val; + maxGrad = (val > maxGrad) ? 
val : maxGrad; + } + } + // cout << "maxGrad: " << maxGrad << endl; + + // set magic numbers + const int NUM_BINS = 64; + const double percent_of_pixels_not_edges = percent_ne; + const double threshold_ratio = 0.3; + + // compute histogram + int bin_size = cvFloor(maxGrad / float(NUM_BINS) + 0.5f) + 1; + if (bin_size < 1) bin_size = 1; + int bins[NUM_BINS] = { 0 }; + for (i = 0; i < size.height; ++i) + { + float *_pmag = magGrad.ptr(i); + for (j = 0; j < size.width; ++j) + { + int hgf = int(_pmag[j]); + bins[int(_pmag[j]) / bin_size]++; + } + } + // for (int i = 0; i < NUM_BINS; i++) + // cout << "BIN " << i << " :" << bins[i] << endl; + + // Select the thresholds + float total(0.f); + float target = float(size.height * size.width * percent_of_pixels_not_edges); + int low_thresh, high_thresh(0); + + while (total < target) + { + total += bins[high_thresh]; + high_thresh++; + } + high_thresh *= bin_size; + low_thresh = cvFloor(threshold_ratio * float(high_thresh)); + // cout << "low_thresh: " << low_thresh << ", high_thresh: " << high_thresh << endl; + + int low, high, maxsize; + int* mag_buf[3]; + uchar* map; + ptrdiff_t mapstep; + + if (L2gradient) { + Cv32suf ul, uh; + ul.f = (float)low_thresh; + uh.f = (float)high_thresh; + + low = ul.i; + high = uh.i; + } + else { + low = cvFloor(low_thresh); + high = cvFloor(high_thresh); + } + + buffer.allocate((size.width + 2)*(size.height + 2) + (size.width + 2) * 3 * sizeof(int)); + // cout << sizeof(int) << ", " << (size.width + 2)*(size.height + 2) + (size.width + 2) * 3 * sizeof(int) << endl; + mag_buf[0] = (int*)(char*)buffer; + mag_buf[1] = mag_buf[0] + size.width + 2; + mag_buf[2] = mag_buf[1] + size.width + 2; + map = (uchar*)(mag_buf[2] + size.width + 2); + mapstep = size.width + 2; + + maxsize = MAX(1 << 10, size.width*size.height / 10); + stack.resize(maxsize); + stack_top = stack_bottom = &stack[0]; + + memset(mag_buf[0], 0, (size.width + 2) * sizeof(int)); + memset(map, 1, mapstep); + memset(map + mapstep*(size.height + 1), 1, mapstep); + + /* sector numbers + (Top-Left Origin) + + 1 2 3 + * * * + * * * + 0*******0 + * * * + * * * + 3 2 1 + */ + +#define CANNY_PUSH(d) *(d) = (uchar)2, *stack_top++ = (d) +#define CANNY_POP(d) (d) = *--stack_top + + CvMat mag_row = cvMat(1, size.width, CV_32F); + + // Mat push_show = Mat::zeros(size.height+1, size.width+1, CV_8U); + + // calculate magnitude and angle of gradient, perform non-maxima supression. 
+ // fill the map with one of the following values: + // 0 - the pixel might belong to an edge + // 1 - the pixel can not belong to an edge + // 2 - the pixel does belong to an edge + for (i = 0; i <= size.height; i++) + { + int* _mag = mag_buf[(i > 0) + 1] + 1; + float* _magf = (float*)_mag; + //const short* _dx = (short*)(dx->data.ptr + dx->step*i); + //const short* _dy = (short*)(dy->data.ptr + dy->step*i); + uchar* _map; + int x, y; + ptrdiff_t magstep1, magstep2; + int prev_flag = 0; + + if (i < size.height) + { + _mag[-1] = _mag[size.width] = 0; + + if (!L2gradient) + for (j = 0; j < size.width; j++) + _mag[j] = abs(dx.at(i, j)) + abs(dy.at(i, j)); + else + { + for (j = 0; j < size.width; j++) + { + x = dx.at(i, j); y = dy.at(i, j); + _magf[j] = (float)std::sqrt((double)x*x + (double)y*y); + } + } + } + else + memset(_mag - 1, 0, (size.width + 2) * sizeof(int)); + + // at the very beginning we do not have a complete ring + // buffer of 3 magnitude rows for non-maxima suppression + if (i == 0) + continue; + + _map = map + mapstep*i + 1; + _map[-1] = _map[size.width] = 1; + + _mag = mag_buf[1] + 1; // take the central row + //_dx = (short*)(dx->data.ptr + dx->step*(i - 1)); + //_dy = (short*)(dy->data.ptr + dy->step*(i - 1)); + + magstep1 = mag_buf[2] - mag_buf[1]; + magstep2 = mag_buf[0] - mag_buf[1]; + + if ((stack_top - stack_bottom) + size.width > maxsize) + { + int sz = (int)(stack_top - stack_bottom); + maxsize = MAX(maxsize * 3 / 2, maxsize + 8); + stack.resize(maxsize); + stack_bottom = &stack[0]; + stack_top = stack_bottom + sz; + } + +#define CANNY_SHIFT 15 +#define TG22 (int)(0.4142135623730950488016887242097*(1<(i, j); + y = dy.at(i, j); + int s = x ^ y; // XOR + int m = _mag[j]; + + x = abs(x); + y = abs(y); + if (m > low) + { + int tg22x = x * TG22; + int tg67x = tg22x + ((x + x) << CANNY_SHIFT); + int tmp = 1 << CANNY_SHIFT; + y <<= CANNY_SHIFT; + + if (y < tg22x) { + if (m > _mag[j - 1] && m >= _mag[j + 1]) { + if (m > high && !prev_flag && _map[j - mapstep] != 2) { + CANNY_PUSH(_map + j); // push_show.at(i, j) = 255; + prev_flag = 1; + } + else { + _map[j] = (uchar)0; + } + continue; + } + } + else if (y > tg67x) { + if (m > _mag[j + magstep2] && m >= _mag[j + magstep1]) { + if (m > high && !prev_flag && _map[j - mapstep] != 2) { + CANNY_PUSH(_map + j); // push_show.at(i, j) = 255; + prev_flag = 1; + } + else { + _map[j] = (uchar)0; + } + continue; + } + } + else { + s = s < 0 ? 
-1 : 1; + if (m > _mag[j + magstep2 - s] && m > _mag[j + magstep1 + s]) { + if (m > high && !prev_flag && _map[j - mapstep] != 2) { + CANNY_PUSH(_map + j); // push_show.at(i, j) = 255; + prev_flag = 1; + } + else { + _map[j] = (uchar)0; + } + continue; + } + } + } + prev_flag = 0; + _map[j] = (uchar)1; + } + + // scroll the ring buffer + _mag = mag_buf[0]; + mag_buf[0] = mag_buf[1]; + mag_buf[1] = mag_buf[2]; + mag_buf[2] = _mag; + } + + // imshow("mag", push_show); waitKey(); + // now track the edges (hysteresis thresholding) + while (stack_top > stack_bottom) + { + uchar* m; + if ((stack_top - stack_bottom) + 8 > maxsize) + { + int sz = (int)(stack_top - stack_bottom); + maxsize = MAX(maxsize * 3 / 2, maxsize + 8); + stack.resize(maxsize); + stack_bottom = &stack[0]; + stack_top = stack_bottom + sz; + } + + CANNY_POP(m); + + if (!m[-1]) + CANNY_PUSH(m - 1); + if (!m[1]) + CANNY_PUSH(m + 1); + if (!m[-mapstep - 1]) + CANNY_PUSH(m - mapstep - 1); + if (!m[-mapstep]) + CANNY_PUSH(m - mapstep); + if (!m[-mapstep + 1]) + CANNY_PUSH(m - mapstep + 1); + if (!m[mapstep - 1]) + CANNY_PUSH(m + mapstep - 1); + if (!m[mapstep]) + CANNY_PUSH(m + mapstep); + if (!m[mapstep + 1]) + CANNY_PUSH(m + mapstep + 1); + } + + // the final pass, form the final image + for (i = 0; i < size.height; i++) { + const uchar* _map = map + mapstep*(i + 1) + 1; + //uchar* _dst = dst->data.ptr + dst->step*i; + + for (j = 0; j < size.width; j++) { + // if (_map[j] == 2) + // cout << (int)_map[j] << ", " << (int)(_map[j] >> 1) << ", " << (int)(uchar)-(_map[j] >> 1) << endl; + dst.at(i, j) = (uchar)-(_map[j] >> 1); + } + } +} + +void _tag_canny(InputArray image, OutputArray _edges, + OutputArray _sobel_x, OutputArray _sobel_y, + int apertureSize, bool L2gradient, double percent_ne) { + + Mat src = image.getMat(); + _edges.create(src.size(), CV_8U); + _sobel_x.create(src.size(), CV_16S); + _sobel_y.create(src.size(), CV_16S); + + Mat c_src = src; + Mat c_dst = _edges.getMat(); + Mat c_dx = _sobel_x.getMat(); + Mat c_dy = _sobel_y.getMat(); + + cv_canny(c_src, c_dst, + c_dx, c_dy, + apertureSize, L2gradient, percent_ne); +} + +void _find_contours_eight(cv::Mat1b& image, std::vector& segments, int iMinLength) +{ + vector ims; ims.resize(8); + for (int i = 0; i < 8; i++) { + ims[i] = Mat1b::zeros(image.size()); + } + for (int r = 0; r < image.rows; r++) { + uchar* _e8 = image.ptr(r); + vector _es; _es.resize(8); + for (int i = 0; i < 8; i++) + _es[i] = ims[i].ptr(r); + + for (int c = 0; c < image.cols; c++) { + for (int i = 0; i < 8; i++) { + if (_e8[c] == (uchar)(i + 1)) { + _es[i][c] = (uchar)255; + } + } + } + } + + segments.resize(8); + for (int i = 0; i < 8; i++) { + _tag_find_contours(ims[i], segments[i], iMinLength); + } +} + +void _tag_find_contours(cv::Mat1b& image, VVP& segments, int iMinLength) { +#define RG_STACK_SIZE 8192*4 + + + // use stack to speed up the processing of global (even at the expense of memory occupied) + int stack2[RG_STACK_SIZE]; +#define RG_PUSH2(a) (stack2[sp2] = (a) , sp2++) +#define RG_POP2(a) (sp2-- , (a) = stack2[sp2]) + + // use stack to speed up the processing of global (even at the expense of memory occupied) + Point stack3[RG_STACK_SIZE]; +#define RG_PUSH3(a) (stack3[sp3] = (a) , sp3++) +#define RG_POP3(a) (sp3-- , (a) = stack3[sp3]) + + int i, w, h, iDim; + int x, y; + int x2, y2; + int sp2; // stack pointer + int sp3; + + Mat_ src = image.clone(); + w = src.cols; + h = src.rows; + iDim = w*h; + + Point point; + for (y = 0; y0) + { // rg traditional + RG_POP2(i); + x2 = i%w; + y2 = i / 
w; + + point.x = x2; + point.y = y2; + + if (src(y2, x2)) + { + RG_PUSH3(point); + src(y2, x2) = 0; + } + + // insert the new points in the stack only if there are + // and they are points labelled + + // 4 connected + // left + if (x2>0 && (src(y2, x2 - 1) != 0)) + RG_PUSH2(i - 1); + // up + if (y2>0 && (src(y2 - 1, x2) != 0)) + RG_PUSH2(i - w); + // down + if (y20 && y2>0 && (src(y2 - 1, x2 - 1) != 0)) + RG_PUSH2(i - w - 1); + if (x2>0 && y20 && (src(y2 - 1, x2 + 1) != 0)) + RG_PUSH2(i - w + 1); + if (x2= iMinLength) + { + vector component; + component.reserve(sp3); + + // push it to the points + for (i = 0; i src = image.clone(); + w = src.cols; + h = src.rows; + iDim = w*h; + + Point point; + for (y = 0; y component; + point.x = x; + point.y = y; + component.push_back(point); + x2 = x; + y2 = y; + + bool found = true; + sp2 = 0; + while (found && component.size() < 3) + { + found = false; + if (x2 > 0 && y2 < h - 1 && (src(y2 + 1, x2 - 1) != 0)) + { + src(y2 + 1, x2 - 1) = 0; if (!found) { found = true; point.x = x2 - 1; point.y = y2 + 1; component.push_back(point); } + } + if (x2 < w - 1 && y2 < h - 1 && (src(y2 + 1, x2 + 1) != 0)) + { + src(y2 + 1, x2 + 1) = 0; if (!found) { found = true; point.x = x2 + 1; point.y = y2 + 1; component.push_back(point); } + } + if (y2 < h - 1 && (src(y2 + 1, x2) != 0)) + { + src(y2 + 1, x2) = 0; if (!found) { found = true; point.x = x2; point.y = y2 + 1; component.push_back(point); } + } + if (x2 > 0 && (src(y2, x2 - 1) != 0)) + { + src(y2, x2 - 1) = 0; if (!found) { found = true; point.x = x2 - 1; point.y = y2; component.push_back(point); } + } + if (x2 < w - 1 && (src(y2, x2 + 1) != 0)) + { + src(y2, x2 + 1) = 0; if (!found) { found = true; point.x = x2 + 1; point.y = y2; component.push_back(point); } + } + if (x2 > 0 && y2 > 0 && (src(y2 - 1, x2 - 1) != 0)) + { + src(y2 - 1, x2 - 1) = 0; if (!found) { found = true; point.x = x2 - 1; point.y = y2 - 1; component.push_back(point); } + } + if (x2 < w - 1 && y2 > 0 && (src(y2 - 1, x2 + 1) != 0)) + { + src(y2 - 1, x2 + 1) = 0; if (!found) { found = true; point.x = x2 + 1; point.y = y2 - 1; component.push_back(point); } + } + if (y2 > 0 && (src(y2 - 1, x2) != 0)) + { + src(y2 - 1, x2) = 0; if (!found) { found = true; point.x = x2; point.y = y2 - 1; component.push_back(point); } + } + if (found) + { + sp2++; x2 = component[sp2].x; y2 = component[sp2].y; + } + } + + if (component.size() < 3) continue; + ang = _atan2(component[2].y - component[0].y, component[2].x - component[0].x); + sp1 = 0; + + found = true; + while (found) + { + ang_dmin = 1e3; + found = false; + if (x2 > 0 && y2 < h - 1 && (src(y2 + 1, x2 - 1) != 0)) + { + ang_t = _atan2(y2 + 1 - component[sp1].y, x2 - 1 - component[sp1].x); + ang_d = abs(ang - ang_t); + src(y2 + 1, x2 - 1) = 0; if (ang_d < ang_dmin) { ang_dmin = ang_d; ang_i = ang_t; xn = x2 - 1; yn = y2 + 1; } + } + if (x2 < w - 1 && y2 < h - 1 && (src(y2 + 1, x2 + 1) != 0)) + { + ang_t = _atan2(y2 + 1 - component[sp1].y, x2 + 1 - component[sp1].x); + ang_d = abs(ang - ang_t); + src(y2 + 1, x2 + 1) = 0; if (ang_d < ang_dmin) { ang_dmin = ang_d; ang_i = ang_t; xn = x2 + 1; yn = y2 + 1; } + } + if (y2 < h - 1 && (src(y2 + 1, x2) != 0)) + { + ang_t = _atan2(y2 + 1 - component[sp1].y, x2 - component[sp1].x); + ang_d = abs(ang - ang_t); + src(y2 + 1, x2) = 0; if (ang_d < ang_dmin) { ang_dmin = ang_d; ang_i = ang_t; xn = x2; yn = y2 + 1; } + } + if (x2 > 0 && (src(y2, x2 - 1) != 0)) + { + ang_t = _atan2(y2 - component[sp1].y, x2 - 1 - component[sp1].x); + ang_d = abs(ang - ang_t); + src(y2, 
x2 - 1) = 0; if (ang_d < ang_dmin) { ang_dmin = ang_d; ang_i = ang_t; xn = x2 - 1; yn = y2; } + } + if (x2 < w - 1 && (src(y2, x2 + 1) != 0)) + { + ang_t = _atan2(y2 - component[sp1].y, x2 + 1 - component[sp1].x); + ang_d = abs(ang - ang_t); + src(y2, x2 + 1) = 0; if (ang_d < ang_dmin) { ang_dmin = ang_d; ang_i = ang_t; xn = x2 + 1; yn = y2; } + } + if (x2 > 0 && y2 > 0 && (src(y2 - 1, x2 - 1) != 0)) + { + ang_t = _atan2(y2 - 1 - component[sp1].y, x2 - 1 - component[sp1].x); + ang_d = abs(ang - ang_t); + src(y2 - 1, x2 - 1) = 0; if (ang_d < ang_dmin) { ang_dmin = ang_d; ang_i = ang_t; xn = x2 - 1; yn = y2 - 1; } + } + if (x2 < w - 1 && y2 > 0 && (src(y2 - 1, x2 + 1) != 0)) + { + ang_t = _atan2(y2 - 1 - component[sp1].y, x2 + 1 - component[sp1].x); + ang_d = abs(ang - ang_t); + src(y2 - 1, x2 + 1) = 0; if (ang_d < ang_dmin) { ang_dmin = ang_d; ang_i = ang_t; xn = x2 + 1; yn = y2 - 1; } + } + if (y2 > 0 && (src(y2 - 1, x2) != 0)) + { + ang_t = _atan2(y2 - 1 - component[sp1].y, x2 - component[sp1].x); + ang_d = abs(ang - ang_t); + src(y2 - 1, x2) = 0; if (ang_d < ang_dmin) { ang_dmin = ang_d; ang_i = ang_t; xn = x2; yn = y2 - 1; } + } + if (ang_dmin < M_1_2_PI) + { + found = true; point.x = xn; point.y = yn; component.push_back(point); + x2 = xn; + y2 = yn; + sp1++; sp2++; + if (sp2 >= 9 && sp2 % 3 == 0) + { + float a1 = _atan2(component[sp2 - 6].y - component[sp2 - 9].y, component[sp2 - 6].x - component[sp2 - 9].x); + float a2 = _atan2(component[sp2 - 3].y - component[sp2 - 6].y, component[sp2 - 3].x - component[sp2 - 6].x); + theta_s = a1 - a2; + + a1 = _atan2(component[sp2 - 3].y - component[sp2 - 6].y, component[sp2 - 3].x - component[sp2 - 6].x); + a2 = _atan2(component[sp2].y - component[sp2 - 3].y, component[sp2].x - component[sp2 - 3].x); + theta_i = a1 - a2; + if (abs(theta_s - theta_i) > 0.6) break; + } + ang = ang_i; + } + } + + if (component.size() >= iMinLength) + { + segments.push_back(component); + } + } + } + } +} + + + +void _tag_show_contours(cv::Mat1b& image, VVP& segments, const char* title) +{ + Mat3b contoursIm(image.rows, image.cols, Vec3b(0, 0, 0)); + for (size_t i = 0; i& segments8, const char* title) +{ + Mat3b contoursIm(image.rows, image.cols, Vec3b(0, 0, 0)); + for (size_t i = 0; i < segments8.size(); ++i) + { + Vec3b color(rand() % 255, 128 + rand() % 127, 128 + rand() % 127); + for (size_t j = 0; j < segments8[i].size(); ++j) + { + for (size_t k = 0; k < segments8[i][j].size(); ++k) + contoursIm(segments8[i][j][k]) = color; + } + } + imshow(title, contoursIm); +} + +bool _SortBottomLeft2TopRight(const Point& lhs, const Point& rhs) +{ + if (lhs.x == rhs.x) + { + return lhs.y > rhs.y; + } + return lhs.x < rhs.x; +} + +bool _SortBottomLeft2TopRight2f(const Point2f& lhs, const Point2f& rhs) +{ + if (lhs.x == rhs.x) + { + return lhs.y > rhs.y; + } + return lhs.x < rhs.x; +} + +bool _SortTopLeft2BottomRight(const Point& lhs, const Point& rhs) +{ + if (lhs.x == rhs.x) + { + return lhs.y < rhs.y; + } + return lhs.x < rhs.x; +} + + + + + + +// #define DEBUG_SPEED +// #define DEBUG_ELLFIT +// #define DEBUG_PREPROCESSING + +// #define DISCARD_TCN +#define DISCARD_TCN2 +#define DISCARD_CONSTRAINT_OBOX + +// #define DISCARD_CONSTRAINT_CONVEXITY +// #define DISCARD_CONSTRAINT_POSITION +#define CONSTRAINT_CNC_1 +// #define CONSTRAINT_CNC_2 +// #define CONSTRAINT_CNC_3 +// #define DISCARD_CONSTRAINT_CENTER + +// #define T_CNC 0.2f +// #define T_TCN_L 0.4f // filter lines +// #define T_TCN_P 0.6f + +// #define Thre_r 0.2f + + +void _concate_arcs(VP& arc1, VP& arc2, VP& arc3, VP& 
arc) +{ + for (int i = 0; i < arc1.size(); i++) + arc.push_back(arc1[i]); + for (int i = 0; i < arc2.size(); i++) + arc.push_back(arc2[i]); + for (int i = 0; i < arc3.size(); i++) + arc.push_back(arc3[i]); +} + +void _concate_arcs(VP& arc1, VP& arc2, VP& arc) +{ + for (int i = 0; i < arc1.size(); i++) + arc.push_back(arc1[i]); + for (int i = 0; i < arc2.size(); i++) + arc.push_back(arc2[i]); +} + +EllipseDetector::EllipseDetector(void) +{ + // Default Parameters Settings + szPreProcessingGaussKernel_ = Size(5, 5); + dPreProcessingGaussSigma_ = 1.0; + fThrArcPosition_ = 1.0f; + fMaxCenterDistance_ = 100.0f * 0.05f; + fMaxCenterDistance2_ = fMaxCenterDistance_ * fMaxCenterDistance_; + iMinEdgeLength_ = 16; + fMinOrientedRectSide_ = 3.0f; + fDistanceToEllipseContour_ = 0.1f; + fMinScore_ = 0.7f; + fMinReliability_ = 0.5f; + uNs_ = 16; + dPercentNe_ = 0.9; + + fT_CNC_ = 0.2f; + fT_TCN_L_ = 0.4f; // filter lines + fT_TCN_P_ = 0.6f; + fThre_r_ = 0.2f; + + srand(unsigned(time(NULL))); +} + +EllipseDetector::~EllipseDetector(void) +{ +} + +void EllipseDetector::SetParameters(Size szPreProcessingGaussKernel, + double dPreProcessingGaussSigma, // Gaussian blur for denoising + float fThPosition, // position constraint threshold; arcs below it are discarded + float fMaxCenterDistance, // center-distance constraint + int iMinEdgeLength, // minimum arc length, typically 8-16; shorter arcs are discarded + float fMinOrientedRectSide, // arc-filtering parameter: minimum short-side length of the oriented bounding rect + float fDistanceToEllipseContour, // fitting score distance d, typically a few pixels + float fMinScore, // 0.5 - 0.7, final ellipse score threshold + float fMinReliability, // threshold on the ratio of arc length to ellipse perimeter + int iNs, // centrality constraint: distance between the two arc centers and the ellipse center, typically 22 + double dPercentNe, // edge-detection threshold; larger values keep more arcs, 0.9 - 0.99 + float fT_CNC, // CNC decision threshold, normally not tuned + float fT_TCN_L, + float fT_TCN_P, + float fThre_r +) +{ + szPreProcessingGaussKernel_ = szPreProcessingGaussKernel; + dPreProcessingGaussSigma_ = dPreProcessingGaussSigma; + fThrArcPosition_ = fThPosition; + fMaxCenterDistance_ = fMaxCenterDistance; + iMinEdgeLength_ = iMinEdgeLength; + fMinOrientedRectSide_ = fMinOrientedRectSide; + fDistanceToEllipseContour_ = fDistanceToEllipseContour; + fMinScore_ = fMinScore; + fMinReliability_ = fMinReliability; + uNs_ = iNs; + dPercentNe_ = dPercentNe; + + fT_CNC_ = fT_CNC; + fT_TCN_L_ = fT_TCN_L; // filter lines + fT_TCN_P_ = fT_TCN_P; + fThre_r_ = fThre_r; + + fMaxCenterDistance2_ = fMaxCenterDistance_ * fMaxCenterDistance_; +} + +void EllipseDetector::SetMCD(float fMaxCenterDistance) +{ + fMaxCenterDistance_ = fMaxCenterDistance; + fMaxCenterDistance2_ = fMaxCenterDistance_ * fMaxCenterDistance_; +} + +void EllipseDetector::RemoveStraightLine(VVP& segments, VVP& segments_update, int id) +{ + int countedges = 0; + // For each edge + for (int i = 0; i < segments.size(); ++i) + { + VP& edgeSegment = segments[i]; + +#ifndef DISCARD_CONSTRAINT_OBOX + // Selection strategy - Step 1 - See Sect [3.1.2] of the paper + // Constraint on axes aspect ratio + RotatedRect oriented = minAreaRect(edgeSegment); + float o_min = min(oriented.size.width, oriented.size.height); + + if (o_min < fMinOrientedRectSide_) + { + countedges++; + continue; + } +#endif + + // Order edge points of the same arc + if (id == 0 || id == 1 || id == 4 || id == 5) { + sort(edgeSegment.begin(), edgeSegment.end(), _SortTopLeft2BottomRight); + } + else if (id == 2 || id == 3 || id == 6 || id == 7) { + sort(edgeSegment.begin(), edgeSegment.end(), _SortBottomLeft2TopRight); + } + + int iEdgeSegmentSize = int(edgeSegment.size()); + + // Get extrema of the arc + Point& left = edgeSegment[0]; + Point& right = edgeSegment[iEdgeSegmentSize - 1]; + +#ifndef DISCARD_TCN +#ifndef DISCARD_TCN2 + int flag = 0; + for (int j =
0; j fT_TCN_L_) { + flag = 1; + break; + } + } + if (0 == flag) { + countedges++; + continue; + } +#endif +#ifndef DISCARD_TCN1 + Point& mid = edgeSegment[iEdgeSegmentSize / 2]; + float data[] = { float(left.x), float(left.y), 1.0f, float(mid.x), float(mid.y), + 1.0f, float(right.x), float(right.y), 1.0f }; + Mat threePoints(Size(3, 3), CV_32FC1, data); + double ans = determinant(threePoints); + + float dx = 1.0f*(left.x - right.x); + float dy = 1.0f*(left.y - right.y); + float edgelength2 = dx*dx + dy*dy; + // double TCNl = ans / edgelength2; + double TCNl = ans / (2 * pow(edgelength2, fT_TCN_P_)); + if (abs(TCNl) < fT_TCN_L_) { + countedges++; + continue; + } +#endif +#endif + + segments_update.push_back(edgeSegment); + } +} + +void EllipseDetector::Detect(cv::Mat& I, std::vector& ellipses) +{ + countOfFindEllipse_ = 0; + countOfGetFastCenter_ = 0; + + Mat1b gray; + cvtColor(I, gray, CV_BGR2GRAY); + + + // Set the image size + szIm_ = I.size(); + + // Initialize temporary data structures + Mat1b DP = Mat1b::zeros(szIm_); // arcs along positive diagonal + Mat1b DN = Mat1b::zeros(szIm_); // arcs along negative diagonal + + + ACC_N_SIZE = 101; + ACC_R_SIZE = 180; + ACC_A_SIZE = max(szIm_.height, szIm_.width); + + // Allocate accumulators + accN = new int[ACC_N_SIZE]; + accR = new int[ACC_R_SIZE]; + accA = new int[ACC_A_SIZE]; + + + // Other temporary + + unordered_map centers; // hash map for reusing already computed EllipseData + + PreProcessing(gray, DP, DN); + + + points_1.clear(); + points_2.clear(); + points_3.clear(); + points_4.clear(); + // Detect edges and find convexities + DetectEdges13(DP, points_1, points_3); + DetectEdges24(DN, points_2, points_4); + + + Triplets124(points_1, points_2, points_4, centers, ellipses); + Triplets231(points_2, points_3, points_1, centers, ellipses); + Triplets342(points_3, points_4, points_2, centers, ellipses); + Triplets413(points_4, points_1, points_3, centers, ellipses); + + sort(ellipses.begin(), ellipses.end()); + + + // Free accumulator memory + delete[] accN; + delete[] accR; + delete[] accA; + + // Cluster detections + ClusterEllipses(ellipses); +} + +void EllipseDetector::Detect(Mat3b& I, vector& ellipses) +{ + countOfFindEllipse_ = 0; + countOfGetFastCenter_ = 0; + +#ifdef DEBUG_SPEED + Tic(0); // prepare data structure +#endif + + // Convert to grayscale + Mat1b gray; + cvtColor(I, gray, CV_BGR2GRAY); + + // Set the image size + szIm_ = I.size(); + + // Initialize temporary data structures + Mat1b DP = Mat1b::zeros(szIm_); // arcs along positive diagonal + Mat1b DN = Mat1b::zeros(szIm_); // arcs along negative diagonal + + // Initialize accumulator dimensions + ACC_N_SIZE = 101; + ACC_R_SIZE = 180; + ACC_A_SIZE = max(szIm_.height, szIm_.width); + + // Allocate accumulators + accN = new int[ACC_N_SIZE]; + accR = new int[ACC_R_SIZE]; + accA = new int[ACC_A_SIZE]; + + // Other temporary + + unordered_map centers; // hash map for reusing already computed EllipseData + +#ifdef DEBUG_SPEED + Toc(0, "prepare data structure"); // prepare data structure +#endif + + // Preprocessing + // From input image I, find edge point with coarse convexity along positive (DP) or negative (DN) diagonal + PreProcessing(gray, DP, DN); + +#ifdef DEBUG_SPEED + Tic(3); // preprocessing +#endif + + points_1.clear(); + points_2.clear(); + points_3.clear(); + points_4.clear(); + // Detect edges and find convexities + DetectEdges13(DP, points_1, points_3); + DetectEdges24(DN, points_2, points_4); + +#ifdef DEBUG_SPEED + Toc(3, "preprocessing_2"); // preprocessing 
+#endif + + // #define DEBUG_PREPROCESSING_S4 +#ifdef DEBUG_PREPROCESSING_S4 + Mat3b out(I.rows, I.cols, Vec3b(0, 0, 0)); + for (unsigned i = 0; i < points_1.size(); ++i) + { + // Vec3b color(rand()%255, 128+rand()%127, 128+rand()%127); + Vec3b color(255, 0, 0); + for (unsigned j = 0; j < points_1[i].size(); ++j) + out(points_1[i][j]) = color; + } + for (unsigned i = 0; i < points_2.size(); ++i) + { + // Vec3b color(rand()%255, 128+rand()%127, 128+rand()%127); + Vec3b color(0, 255, 0); + for (unsigned j = 0; j < points_2[i].size(); ++j) + out(points_2[i][j]) = color; + } + for (unsigned i = 0; i < points_3.size(); ++i) + { + // Vec3b color(rand()%255, 128+rand()%127, 128+rand()%127); + Vec3b color(0, 0, 255); + for (unsigned j = 0; j < points_3[i].size(); ++j) + out(points_3[i][j]) = color; + } + for (unsigned i = 0; i < points_4.size(); ++i) + { + // Vec3b color(rand()%255, 128+rand()%127, 128+rand()%127); + Vec3b color(255, 0, 255); + for (unsigned j = 0; j < points_4[i].size(); ++j) + out(points_4[i][j]) = color; + } + cv::imshow("PreProcessing->Output", out); + waitKey(); +#endif + +#ifdef DEBUG_SPEED + Tic(4); // grouping +#endif + + + Triplets124(points_1, points_2, points_4, centers, ellipses); + Triplets231(points_2, points_3, points_1, centers, ellipses); + Triplets342(points_3, points_4, points_2, centers, ellipses); + Triplets413(points_4, points_1, points_3, centers, ellipses); + +#ifdef DEBUG_SPEED + Toc(4, "grouping"); // grouping +#endif + +#ifdef DEBUG_SPEED + Tic(5); +#endif + + // Sort detected ellipses with respect to score + sort(ellipses.begin(), ellipses.end()); + + // Free accumulator memory + delete[] accN; + delete[] accR; + delete[] accA; + + // Cluster detections + ClusterEllipses(ellipses); + +#ifdef DEBUG_SPEED + Toc(5, "cluster detections"); +#endif +} + +void EllipseDetector::PreProcessing(Mat1b& I, Mat1b& arcs8) +{ + GaussianBlur(I, I, szPreProcessingGaussKernel_, dPreProcessingGaussSigma_); + + // Temp variables + Mat1b E; // edge mask + Mat1s DX, DY; // sobel derivatives + + _tag_canny(I, E, DX, DY, 3, false, dPercentNe_); // Detect edges + + // Mat1f dxf, dyf; + // normalize(DX, dxf, 0, 1, NORM_MINMAX); + // normalize(DY, dyf, 0, 1, NORM_MINMAX); + //Mat1f dx, dy; + //DX.convertTo(dx, CV_32F); + //DY.convertTo(dy, CV_32F); + //// DYDX_ = dy / dx; + + ////Mat1f edx, edy; + ////Sobel(E, edx, CV_32F, 1, 0); + ////Sobel(E, edy, CV_32F, 0, 1); + + //DYDX_ = -1 / (dy / dx); + //CV_Assert(DYDX_.type() == CV_32F); + + //cv::GaussianBlur(dx, dx, Size(5, 5), 0, 0); + //cv::GaussianBlur(dy, dy, Size(5, 5), 0, 0); + //cv::phase(dx, dy, EO_); + +#define SIGN(n) ((n)<=0?((n)<0?-1:0):1) + + // cout << SIGN(0) << " " << SIGN(-1) << " " << SIGN(1) << endl; + + int sign_x, sign_y, sign_xy; + for (int i = 0; i < szIm_.height; ++i) { + short* _dx = DX.ptr(i); + short* _dy = DY.ptr(i); + uchar* _e = E.ptr(i); + uchar* _arc = arcs8.ptr(i); + + for (int j = 0; j < szIm_.width; ++j) { + if (!(_e[j] <= 0)) // !!!!!! 
+ { + sign_x = SIGN(_dx[j]); + sign_y = SIGN(_dy[j]); + sign_xy = SIGN(abs(_dx[j]) - abs(_dy[j])); + if (sign_x == 1 && sign_y == 1 && sign_xy == 1) { + _arc[j] = (uchar)3; + } + else if (sign_x == 1 && sign_y == 1 && sign_xy == -1) { + _arc[j] = (uchar)4; + } + else if (sign_x == 1 && sign_y == -1 && sign_xy == -1) { + _arc[j] = (uchar)1; + } + else if (sign_x == 1 && sign_y == -1 && sign_xy == 1) { + _arc[j] = (uchar)2; + } + else if (sign_x == -1 && sign_y == -1 && sign_xy == 1) { + _arc[j] = (uchar)7; + } + else if (sign_x == -1 && sign_y == -1 && sign_xy == -1) { + _arc[j] = (uchar)8; + } + else if (sign_x == -1 && sign_y == 1 && sign_xy == -1) { + _arc[j] = (uchar)5; + } + else if (sign_x == -1 && sign_y == 1 && sign_xy == 1) { + _arc[j] = (uchar)6; + } + } + } + } +} + +void EllipseDetector::PreProcessing(Mat1b& I, Mat1b& DP, Mat1b& DN) +{ +#ifdef DEBUG_SPEED + Tic(1); // edge detection +#endif + + GaussianBlur(I, I, szPreProcessingGaussKernel_, dPreProcessingGaussSigma_); + + // Temp variables + Mat1b E; // edge mask + Mat1s DX, DY; // sobel derivatives + + // Detect edges + _tag_canny(I, E, DX, DY, 3, false, dPercentNe_); + + Mat1f dx, dy; + DX.convertTo(dx, CV_32F); + DY.convertTo(dy, CV_32F); + //// cv::GaussianBlur(dx, dx, Size(5, 5), 0, 0); + //// cv::GaussianBlur(dy, dy, Size(5, 5), 0, 0); + cv::phase(dx, dy, EO_); + +#ifdef DEBUG_PREPROCESSING + imshow("PreProcessing->Edge", E); waitKey(50); +#endif + +#ifdef DEBUG_SPEED + Toc(1, "edge detection"); // edge detection + Tic(2); // preprocessing +#endif + + float M_3_2_PI = M_PI + M_1_2_PI; + // For each edge points, compute the edge direction + for (int i = 0; i < szIm_.height; ++i) { + float* _o = EO_.ptr(i); + uchar* _e = E.ptr(i); + uchar* _dp = DP.ptr(i); + uchar* _dn = DN.ptr(i); + + for (int j = 0; j < szIm_.width; ++j) { + if (!(_e[j] <= 0)) // !!!!!! 
+ { + if (_o[j] == 0 || _o[j] == M_1_2_PI || _o[j] == M_PI || _o[j] == M_3_2_PI) { + _dn[j] = (uchar)255; _dp[j] = (uchar)255; + } + else if ((_o[j] > 0 && _o[j] < M_1_2_PI) || (_o[j] > M_PI && _o[j] < M_3_2_PI)) _dn[j] = (uchar)255; + else _dp[j] = (uchar)255; + } + } + } + +#ifdef DEBUG_PREPROCESSING + imshow("PreProcessing->DP", DP); waitKey(50); + imshow("PreProcessing->DN", DN); waitKey(); +#endif +#ifdef DEBUG_SPEED + Toc(2, "preprocessing"); // preprocessing +#endif +} + +void EllipseDetector::DetectEdges13(Mat1b& DP, VVP& points_1, VVP& points_3) +{ + // Vector of connected edge points + VVP contours; + int countedges = 0; + // Labeling 8-connected edge points, discarding edge too small + _tag_find_contours(DP, contours, iMinEdgeLength_); // put small arc edges to a vector + +#ifdef DEBUG_PREPROCESSING + Mat1b DP_show = DP.clone(); + _tag_show_contours(DP_show, contours, "PreProcessing->Contours13"); waitKey(); +#endif + + + // For each edge + for (int i = 0; i < contours.size(); ++i) + { + VP& edgeSegment = contours[i]; + +#ifndef DISCARD_CONSTRAINT_OBOX + // Selection strategy - Step 1 - See Sect [3.1.2] of the paper + // Constraint on axes aspect ratio + RotatedRect oriented = minAreaRect(edgeSegment); + float o_min = min(oriented.size.width, oriented.size.height); + + if (o_min < fMinOrientedRectSide_) + { + countedges++; + continue; + } +#endif + + // Order edge points of the same arc + sort(edgeSegment.begin(), edgeSegment.end(), _SortTopLeft2BottomRight); + int iEdgeSegmentSize = int(edgeSegment.size()); + + // Get extrema of the arc + Point& left = edgeSegment[0]; + Point& right = edgeSegment[iEdgeSegmentSize - 1]; + +#ifndef DISCARD_TCN +#ifndef DISCARD_TCN2 + int flag = 0; + for (int j = 0; j fT_TCN_L_) { + flag = 1; + break; + } + } + if (0 == flag) { + countedges++; + continue; + } +#endif +#ifndef DISCARD_TCN1 + Point& mid = edgeSegment[iEdgeSegmentSize / 2]; + float data[] = { float(left.x), float(left.y), 1.0f, float(mid.x), float(mid.y), + 1.0f, float(right.x), float(right.y), 1.0f }; + Mat threePoints(Size(3, 3), CV_32FC1, data); + double ans = determinant(threePoints); + + float dx = 1.0f*(left.x - right.x); + float dy = 1.0f*(left.y - right.y); + float edgelength2 = dx*dx + dy*dy; + // double TCNl = ans / edgelength2; + double TCNl = ans / (2 * pow(edgelength2, fT_TCN_P_)); + if (abs(TCNl) < fT_TCN_L_) { + countedges++; + continue; + } +#endif +#endif + + // Find convexity - See Sect [3.1.3] of the paper + int iCountTop = 0; + int xx = left.x; + for (int k = 1; k < iEdgeSegmentSize; ++k) + { + if (edgeSegment[k].x == xx) continue; + + iCountTop += (edgeSegment[k].y - left.y); + xx = edgeSegment[k].x; + } + + int width = abs(right.x - left.x) + 1; + int height = abs(right.y - left.y) + 1; + int iCountBottom = (width * height) - iEdgeSegmentSize - iCountTop; + + if (iCountBottom > iCountTop) + { // 1 + points_1.push_back(edgeSegment); + } + else if (iCountBottom < iCountTop) + { // 3 + points_3.push_back(edgeSegment); + } + } + +} + +void EllipseDetector::DetectEdges24(Mat1b& DN, VVP& points_2, VVP& points_4) +{ + // Vector of connected edge points + VVP contours; + int countedges = 0; + /// Labeling 8-connected edge points, discarding edge too small + _tag_find_contours(DN, contours, iMinEdgeLength_); + +#ifdef DEBUG_PREPROCESSING + _tag_show_contours(DN, contours, "PreProcessing->Contours24"); waitKey(); +#endif + + int iContoursSize = unsigned(contours.size()); + + + // For each edge + for (int i = 0; i < iContoursSize; ++i) + { + VP& edgeSegment = contours[i]; 
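+ // Each segment passes through the same selection cascade as in DetectEdges13:
+ // (1) reject arcs whose minimum oriented-bounding-box side is below
+ //     fMinOrientedRectSide_;
+ // (2) reject near-straight arcs, e.g. via the three-point measure
+ //     TCNl = det(M) / (2 * |left - right|^(2 * fT_TCN_P_)),
+ //     where M stacks (left, mid, right) as homogeneous rows, so |TCNl| is
+ //     roughly the triangle area over a power of the chord length and values
+ //     below fT_TCN_L_ indicate an almost straight segment;
+ // (3) classify the surviving arc as class II or IV from the convexity test
+ //     further below.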
+ +#ifndef DISCARD_CONSTRAINT_OBOX + // Selection strategy - Step 1 - See Sect [3.1.2] of the paper + // Constraint on axes aspect ratio + RotatedRect oriented = minAreaRect(edgeSegment); + float o_min = min(oriented.size.width, oriented.size.height); + + if (o_min < fMinOrientedRectSide_) + { + countedges++; + continue; + } +#endif + + // Order edge points of the same arc + sort(edgeSegment.begin(), edgeSegment.end(), _SortBottomLeft2TopRight); + int iEdgeSegmentSize = unsigned(edgeSegment.size()); + + // Get extrema of the arc + Point& left = edgeSegment[0]; + Point& right = edgeSegment[iEdgeSegmentSize - 1]; + +#ifndef DISCARD_TCN +#ifndef DISCARD_TCN2 + int flag = 0; + for (int j = 0; j fT_TCN_L_) { + flag = 1; + break; + } + } + if (0 == flag) { + countedges++; + continue; + } +#endif +#ifndef DISCARD_TCN1 + Point& mid = edgeSegment[iEdgeSegmentSize / 2]; + float data[] = { float(left.x), float(left.y), 1.0f, float(mid.x), float(mid.y), + 1.0f, float(right.x), float(right.y), 1.0f }; + Mat threePoints(Size(3, 3), CV_32FC1, data); + double ans = determinant(threePoints); + + float dx = 1.0f*(left.x - right.x); + float dy = 1.0f*(left.y - right.y); + float edgelength2 = dx*dx + dy*dy; + // double TCNl = ans / edgelength2; + double TCNl = ans / (2 * pow(edgelength2, fT_TCN_P_)); + if (abs(TCNl) < fT_TCN_L_) { + countedges++; + continue; + } +#endif +#endif + + // Find convexity - See Sect [3.1.3] of the paper + int iCountBottom = 0; + int xx = left.x; + for (int k = 1; k < iEdgeSegmentSize; ++k) + { + if (edgeSegment[k].x == xx) continue; + + iCountBottom += (left.y - edgeSegment[k].y); + xx = edgeSegment[k].x; + } + + int width = abs(right.x - left.x) + 1; + int height = abs(right.y - left.y) + 1; + int iCountTop = (width *height) - iEdgeSegmentSize - iCountBottom; + + if (iCountBottom > iCountTop) + { + // 2 + points_2.push_back(edgeSegment); + } + else if (iCountBottom < iCountTop) + { + // 4 + points_4.push_back(edgeSegment); + } + } + +} + +float inline ed2(const cv::Point& A, const cv::Point& B) +{ + return float(((B.x - A.x)*(B.x - A.x) + (B.y - A.y)*(B.y - A.y))); +} + + +#define T124 pjf,pjm,pjl,pif,pim,pil // origin +#define T231 pil,pim,pif,pjf,pjm,pjl +#define T342 pif,pim,pil,pjf,pjm,pjl +#define T413 pif,pim,pil,pjl,pjm,pjf + + +// Verify triplets of arcs with convexity: i=1, j=2, k=4 +void EllipseDetector::Triplets124(VVP& pi, + VVP& pj, + VVP& pk, + unordered_map& data, + vector& ellipses +) +{ + // get arcs length + ushort sz_i = ushort(pi.size()); + ushort sz_j = ushort(pj.size()); + ushort sz_k = ushort(pk.size()); + + // For each edge i + for (ushort i = 0; i < sz_i; ++i) + { + VP& edge_i = pi[i]; + ushort sz_ei = ushort(edge_i.size()); + + Point& pif = edge_i[0]; + Point& pim = edge_i[sz_ei / 2]; + Point& pil = edge_i[sz_ei - 1]; + + // 1,2 -> reverse 1, swap + VP rev_i(edge_i.size()); + reverse_copy(edge_i.begin(), edge_i.end(), rev_i.begin()); + + // For each edge j + for (ushort j = 0; j < sz_j; ++j) + { + vector ellipses_i; + + VP& edge_j = pj[j]; + ushort sz_ej = ushort(edge_j.size()); + + Point& pjf = edge_j[0]; + Point& pjm = edge_j[sz_ej / 2]; + Point& pjl = edge_j[sz_ej - 1]; + +#ifndef DISCARD_CONSTRAINT_POSITION + + //if (sqrt((pjl.x - pif.x)*(pjl.x - pif.x) + (pjl.y - pif.y)*(pjl.y - pif.y)) > MAX(edge_i.size(), edge_j.size())) + // continue; + double tm1 = _tic(); + // CONSTRAINTS on position + if (pjl.x > pif.x + fThrArcPosition_) //is right + continue; + +#endif + + tm1 = _tic(); + if (_ed2(pil, pjf) / _ed2(pif, pjl) < fThre_r_) + continue; + +#ifdef 
CONSTRAINT_CNC_1 + tm1 = _tic(); + // cnc constraint1 2se se1//pil,pim,pif,pjf,pjm,pjl pjf,pjm,pjl,pif,pim,pil + if (fabs(_value4SixPoints(T124) - 1) > fT_CNC_) + continue; +#endif + + tm1 = _tic(); + EllipseData data_ij; + uint key_ij = GenerateKey(PAIR_12, i, j); + // If the data for the pair i-j have not been computed yet + if (data.count(key_ij) == 0) + { + // 1,2 -> reverse 1, swap + // Compute data! + GetFastCenter(edge_j, rev_i, data_ij); + // Insert computed data in the hash table + data.insert(pair(key_ij, data_ij)); + } + else + { + // Otherwise, just lookup the data in the hash table + data_ij = data.at(key_ij); + } + + // for each edge k + for (ushort k = 0; k < sz_k; ++k) + { + VP& edge_k = pk[k]; + ushort sz_ek = ushort(edge_k.size()); + + Point& pkf = edge_k[0]; + Point& pkm = edge_k[sz_ek / 2]; + Point& pkl = edge_k[sz_ek - 1]; + +#ifndef DISCARD_CONSTRAINT_POSITION + // CONSTRAINTS on position + if (pkl.y < pil.y - fThrArcPosition_) + continue; +#endif +#ifdef CONSTRAINT_CNC_2 + // cnc constraint2 + if (fabs(_value4SixPoints(pif, pim, pil, pkf, pkm, pkl) - 1) > fT_CNC_) + continue; +#endif +#ifdef CONSTRAINT_CNC_3 + // cnc constraint3 + if (fabs(_value4SixPoints(pjf, pjm, pjl, pkf, pkm, pkl) - 1) > fT_CNC_) + continue; +#endif + + uint key_ik = GenerateKey(PAIR_14, i, k); + + // Find centers + EllipseData data_ik; + + // If the data for the pair i-k have not been computed yet + if (data.count(key_ik) == 0) + { + // 1,4 -> ok + // Compute data! + GetFastCenter(edge_i, edge_k, data_ik); + // Insert computed data in the hash table + data.insert(pair(key_ik, data_ik)); + } + else + { + // Otherwise, just lookup the data in the hash table + data_ik = data.at(key_ik); + } + + // INVALID CENTERS + if (!data_ij.isValid || !data_ik.isValid) + { + continue; + } + +#ifndef DISCARD_CONSTRAINT_CENTER + // Selection strategy - Step 3. 
See Sect [3.2.2] in the paper + // The computed centers are not close enough + if (ed2(data_ij.Cab, data_ik.Cab) > fMaxCenterDistance2_) + { + // discard + continue; + } +#endif + + // If all constraints of the selection strategy have been satisfied, + // we can start estimating the ellipse parameters + + // Find ellipse parameters + // Get the coordinates of the center (xc, yc) + Point2f center = GetCenterCoordinates(data_ij, data_ik); + + Ellipse ell; + // Find remaining paramters (A,B,rho) + FindEllipses(center, edge_i, edge_j, edge_k, data_ij, data_ik, ell); + ellipses_i.push_back(ell); + } + /*-----------------------------------------------------------------*/ + int rt = -1; + float rs = 0; + for (int t = 0; t < (int)ellipses_i.size(); t++) + { + if (ellipses_i[t].score_ > rs) + { + rs = ellipses_i[t].score_; rt = t; + } + } + if (rt > -1) + { + ellipses.push_back(ellipses_i[rt]); + } + /*-----------------------------------------------------------------*/ + } + } +} + +void EllipseDetector::Triplets231(VVP& pi, + VVP& pj, + VVP& pk, + unordered_map& data, + vector& ellipses +) +{ + ushort sz_i = ushort(pi.size()); + ushort sz_j = ushort(pj.size()); + ushort sz_k = ushort(pk.size()); + + // For each edge i + for (ushort i = 0; i < sz_i; ++i) + { + VP& edge_i = pi[i]; + ushort sz_ei = ushort(edge_i.size()); + + Point& pif = edge_i[0]; + Point& pim = edge_i[sz_ei / 2]; + Point& pil = edge_i[sz_ei - 1]; + + VP rev_i(edge_i.size()); + reverse_copy(edge_i.begin(), edge_i.end(), rev_i.begin()); + + // For each edge j + for (ushort j = 0; j < sz_j; ++j) + { + vector ellipses_i; + + VP& edge_j = pj[j]; + ushort sz_ej = ushort(edge_j.size()); + + Point& pjf = edge_j[0]; + Point& pjm = edge_j[sz_ej / 2]; + Point& pjl = edge_j[sz_ej - 1]; + +#ifndef DISCARD_CONSTRAINT_POSITION + //if (sqrt((pjf.x - pif.x)*(pjf.x - pif.x) + (pjf.y - pif.y)*(pjf.y - pif.y)) > MAX(edge_i.size(), edge_j.size())) + // continue; + + double tm1 = _tic(); + // CONSTRAINTS on position + if (pjf.y < pif.y - fThrArcPosition_) + continue; + +#endif + + tm1 = _tic(); + if (_ed2(pif, pjf) / _ed2(pil, pjl) < fThre_r_) + continue; + +#ifdef CONSTRAINT_CNC_1 + tm1 = _tic(); + // cnc constraint1 2es se3 // pif,pim,pil,pjf,pjm,pjl pil,pim,pif,pjf,pjm,pjl + if (fabs(_value4SixPoints(T231) - 1) > fT_CNC_) + continue; +#endif + + tm1 = _tic(); + VP rev_j(edge_j.size()); + reverse_copy(edge_j.begin(), edge_j.end(), rev_j.begin()); + + EllipseData data_ij; + uint key_ij = GenerateKey(PAIR_23, i, j); + if (data.count(key_ij) == 0) + { + // 2,3 -> reverse 2,3 + GetFastCenter(rev_i, rev_j, data_ij); + data.insert(pair(key_ij, data_ij)); + } + else + { + data_ij = data.at(key_ij); + } + + // For each edge k + for (ushort k = 0; k < sz_k; ++k) + { + VP& edge_k = pk[k]; + + ushort sz_ek = ushort(edge_k.size()); + + Point& pkf = edge_k[0]; + Point& pkm = edge_k[sz_ek / 2]; + Point& pkl = edge_k[sz_ek - 1]; + +#ifndef DISCARD_CONSTRAINT_POSITION + // CONSTRAINTS on position + if (pkf.x < pil.x - fThrArcPosition_) + continue; +#endif + +#ifdef CONSTRAINT_CNC_2 + // cnc constraint2 + if (fabs(_value4SixPoints(pif, pim, pil, pkf, pkm, pkl) - 1) > fT_CNC_) + continue; +#endif +#ifdef CONSTRAINT_CNC_3 + // cnc constraint3 + if (fabs(_value4SixPoints(pjf, pjm, pjl, pkf, pkm, pkl) - 1) > fT_CNC_) + continue; +#endif + uint key_ik = GenerateKey(PAIR_12, k, i); + + // Find centers + + EllipseData data_ik; + + + if (data.count(key_ik) == 0) + { + // 2,1 -> reverse 1 + VP rev_k(edge_k.size()); + reverse_copy(edge_k.begin(), edge_k.end(), rev_k.begin()); + 
+ GetFastCenter(edge_i, rev_k, data_ik); + data.insert(pair(key_ik, data_ik)); + } + else + { + data_ik = data.at(key_ik); + } + + // INVALID CENTERS + if (!data_ij.isValid || !data_ik.isValid) + { + continue; + } + +#ifndef DISCARD_CONSTRAINT_CENTER + // CONSTRAINT ON CENTERS + if (ed2(data_ij.Cab, data_ik.Cab) > fMaxCenterDistance2_) + { + // discard + continue; + } +#endif + // Find ellipse parameters + Point2f center = GetCenterCoordinates(data_ij, data_ik); + + Ellipse ell; + FindEllipses(center, edge_i, edge_j, edge_k, data_ij, data_ik, ell); + ellipses_i.push_back(ell); + } + /*-----------------------------------------------------------------*/ + int rt = -1; + float rs = 0; + for (int t = 0; t < (int)ellipses_i.size(); t++) + { + if (ellipses_i[t].score_ > rs) + { + rs = ellipses_i[t].score_; rt = t; + } + } + if (rt > -1) + { + ellipses.push_back(ellipses_i[rt]); + } + /*-----------------------------------------------------------------*/ + } + } +} + +void EllipseDetector::Triplets342(VVP& pi, + VVP& pj, + VVP& pk, + unordered_map& data, + vector& ellipses +) +{ + ushort sz_i = ushort(pi.size()); + ushort sz_j = ushort(pj.size()); + ushort sz_k = ushort(pk.size()); + + // For each edge i + for (ushort i = 0; i < sz_i; ++i) + { + VP& edge_i = pi[i]; + ushort sz_ei = ushort(edge_i.size()); + + Point& pif = edge_i[0]; + Point& pim = edge_i[sz_ei / 2]; + Point& pil = edge_i[sz_ei - 1]; + + VP rev_i(edge_i.size()); + reverse_copy(edge_i.begin(), edge_i.end(), rev_i.begin()); + + // For each edge j + for (ushort j = 0; j < sz_j; ++j) + { + vector ellipses_i; + + VP& edge_j = pj[j]; + ushort sz_ej = ushort(edge_j.size()); + + Point& pjf = edge_j[0]; + Point& pjm = edge_j[sz_ej / 2]; + Point& pjl = edge_j[sz_ej - 1]; + +#ifndef DISCARD_CONSTRAINT_POSITION + //if (sqrt((pjf.x - pil.x)*(pjf.x - pil.x) + (pjf.y - pil.y)*(pjf.y - pil.y)) > MAX(edge_i.size(), edge_j.size())) + // continue; + + double tm1 = _tic(); + // CONSTRAINTS on position + if (pjf.x < pil.x - fThrArcPosition_) // is left + continue; + +#endif + + tm1 = _tic(); + if (_ed2(pil, pjf) / _ed2(pif, pjl) < fThre_r_) + continue; + +#ifdef CONSTRAINT_CNC_1 + tm1 = _tic(); + // cnc constraint1 3se se4 // pil,pim,pif,pjf,pjm,pjl pif,pim,pil,pjf,pjm,pjl + if (fabs(_value4SixPoints(T342) - 1) > fT_CNC_) + continue; + +#endif + + tm1 = _tic(); + VP rev_j(edge_j.size()); + reverse_copy(edge_j.begin(), edge_j.end(), rev_j.begin()); + + EllipseData data_ij; + uint key_ij = GenerateKey(PAIR_34, i, j); + + if (data.count(key_ij) == 0) + { + // 3,4 -> reverse 4 + + GetFastCenter(edge_i, rev_j, data_ij); + data.insert(pair(key_ij, data_ij)); + } + else + { + data_ij = data.at(key_ij); + } + + + // For each edge k + for (ushort k = 0; k < sz_k; ++k) + { + VP& edge_k = pk[k]; + ushort sz_ek = ushort(edge_k.size()); + + Point& pkf = edge_k[0]; + Point& pkm = edge_k[sz_ek / 2]; + Point& pkl = edge_k[sz_ek - 1]; + +#ifndef DISCARD_CONSTRAINT_POSITION + // CONSTRAINTS on position + if (pkf.y > pif.y + fThrArcPosition_) + continue; +#endif + +#ifdef CONSTRAINT_CNC_2 + // cnc constraint2 + if (fabs(_value4SixPoints(pif, pim, pil, pkf, pkm, pkl) - 1) > fT_CNC_) + continue; +#endif +#ifdef CONSTRAINT_CNC_3 + // cnc constraint3 + if (fabs(_value4SixPoints(pjf, pjm, pjl, pkf, pkm, pkl) - 1) > fT_CNC_) + continue; +#endif + uint key_ik = GenerateKey(PAIR_23, k, i); + + // Find centers + + EllipseData data_ik; + + + + if (data.count(key_ik) == 0) + { + // 3,2 -> reverse 3,2 + + VP rev_k(edge_k.size()); + reverse_copy(edge_k.begin(), edge_k.end(), 
rev_k.begin()); + + GetFastCenter(rev_i, rev_k, data_ik); + + data.insert(pair(key_ik, data_ik)); + } + else + { + data_ik = data.at(key_ik); + } + + + // INVALID CENTERS + if (!data_ij.isValid || !data_ik.isValid) + { + continue; + } + +#ifndef DISCARD_CONSTRAINT_CENTER + // CONSTRAINT ON CENTERS + if (ed2(data_ij.Cab, data_ik.Cab) > fMaxCenterDistance2_) + { + // discard + continue; + } +#endif + // Find ellipse parameters + Point2f center = GetCenterCoordinates(data_ij, data_ik); + + Ellipse ell; + FindEllipses(center, edge_i, edge_j, edge_k, data_ij, data_ik, ell); + ellipses_i.push_back(ell); + } + /*-----------------------------------------------------------------*/ + int rt = -1; + float rs = 0; + for (int t = 0; t < (int)ellipses_i.size(); t++) + { + if (ellipses_i[t].score_ > rs) + { + rs = ellipses_i[t].score_; rt = t; + } + } + if (rt > -1) + { + ellipses.push_back(ellipses_i[rt]); + } + /*-----------------------------------------------------------------*/ + } + } +} + +void EllipseDetector::Triplets413(VVP& pi, + VVP& pj, + VVP& pk, + unordered_map& data, + vector& ellipses +) +{ + ushort sz_i = ushort(pi.size()); + ushort sz_j = ushort(pj.size()); + ushort sz_k = ushort(pk.size()); + + // For each edge i + for (ushort i = 0; i < sz_i; ++i) + { + VP& edge_i = pi[i]; + ushort sz_ei = ushort(edge_i.size()); + + Point& pif = edge_i[0]; + Point& pim = edge_i[sz_ei / 2]; + Point& pil = edge_i[sz_ei - 1]; + + VP rev_i(edge_i.size()); + reverse_copy(edge_i.begin(), edge_i.end(), rev_i.begin()); + + // For each edge j + for (ushort j = 0; j < sz_j; ++j) + { + vector ellipses_i; + + VP& edge_j = pj[j]; + ushort sz_ej = ushort(edge_j.size()); + + Point& pjf = edge_j[0]; + Point& pjm = edge_j[sz_ej / 2]; + Point& pjl = edge_j[sz_ej - 1]; + +#ifndef DISCARD_CONSTRAINT_POSITION + //if (sqrt((pjl.x - pil.x)*(pjl.x - pil.x) + (pjl.y - pil.y)*(pjl.y - pil.y)) > MAX(edge_i.size(), edge_j.size())) + // continue; + + double tm1 = _tic(); + // CONSTRAINTS on position + if (pjl.y > pil.y + fThrArcPosition_) // is below + continue; + +#endif + + tm1 = _tic(); + if (_ed2(pif, pjf) / _ed2(pil, pjl) < fThre_r_) + continue; + +#ifdef CONSTRAINT_CNC_1 + tm1 = _tic(); + // cnc constraint1 4se es1 // pif,pim,pil,pjf,pjm,pjl pil,pim,pif,pjl,pjm,pjf pif,pim,pil,pjl,pjm,pjf + if (fabs(_value4SixPoints(T413) - 1) > fT_CNC_) + continue; +#endif + + tm1 = _tic(); + EllipseData data_ij; + uint key_ij = GenerateKey(PAIR_14, j, i); + + if (data.count(key_ij) == 0) + { + // 4,1 -> OK + GetFastCenter(edge_i, edge_j, data_ij); + data.insert(pair(key_ij, data_ij)); + } + else + { + data_ij = data.at(key_ij); + } + + // For each edge k + for (ushort k = 0; k < sz_k; ++k) + { + VP& edge_k = pk[k]; + ushort sz_ek = ushort(edge_k.size()); + + Point& pkf = edge_k[0]; + Point& pkm = edge_k[sz_ek / 2]; + Point& pkl = edge_k[sz_ek - 1]; + +#ifndef DISCARD_CONSTRAINT_POSITION + // CONSTRAINTS on position + if (pkl.x > pif.x + fThrArcPosition_) + continue; +#endif + +#ifdef CONSTRAINT_CNC_2 + // cnc constraint2 + if (fabs(_value4SixPoints(pif, pim, pil, pkf, pkm, pkl) - 1) > fT_CNC_) + continue; +#endif +#ifdef CONSTRAINT_CNC_3 + // cnc constraint2 + if (fabs(_value4SixPoints(pjf, pjm, pjl, pkf, pkm, pkl) - 1) > fT_CNC_) + continue; +#endif + uint key_ik = GenerateKey(PAIR_34, k, i); + + // Find centers + + EllipseData data_ik; + + + + if (data.count(key_ik) == 0) + { + // 4,3 -> reverse 4 + GetFastCenter(rev_i, edge_k, data_ik); + data.insert(pair(key_ik, data_ik)); + } + else + { + data_ik = data.at(key_ik); + } + + // INVALID 
CENTERS + if (!data_ij.isValid || !data_ik.isValid) + { + continue; + } + +#ifndef DISCARD_CONSTRAINT_CENTER + // CONSTRAIN ON CENTERS + if (ed2(data_ij.Cab, data_ik.Cab) > fMaxCenterDistance2_) + { + // discard + continue; + } +#endif + // Find ellipse parameters + Point2f center = GetCenterCoordinates(data_ij, data_ik); + + Ellipse ell; + FindEllipses(center, edge_i, edge_j, edge_k, data_ij, data_ik, ell); + ellipses_i.push_back(ell); + } + /*-----------------------------------------------------------------*/ + int rt = -1; + float rs = 0; + for (int t = 0; t < (int)ellipses_i.size(); t++) + { + if (ellipses_i[t].score_ > rs) + { + rs = ellipses_i[t].score_; rt = t; + } + } + if (rt > -1) + { + ellipses.push_back(ellipses_i[rt]); + } + /*-----------------------------------------------------------------*/ + } + } +} + +int EllipseDetector::FindMaxK(const int* v) const +{ + int max_val = 0; + int max_idx = 0; + for (int i = 0; i max_val) ? max_val = v[i], max_idx = i : (int)NULL; + } + + return max_idx + 90; +} + +int EllipseDetector::FindMaxN(const int* v) const +{ + int max_val = 0; + int max_idx = 0; + for (int i = 0; i max_val) ? max_val = v[i], max_idx = i : (int)NULL; + } + + return max_idx; +} + +int EllipseDetector::FindMaxA(const int* v) const +{ + int max_val = 0; + int max_idx = 0; + for (int i = 0; i max_val) ? max_val = v[i], max_idx = i : (int)NULL; + } + + return max_idx; +} + +// Most important function for detecting ellipses. See Sect[3.2.3] of the paper +void EllipseDetector::FindEllipses(Point2f& center, + VP& edge_i, VP& edge_j, VP& edge_k, + EllipseData& data_ij, EllipseData& data_ik, Ellipse& ell) +{ + countOfFindEllipse_++; + // Find ellipse parameters + + // 0-initialize accumulators + memset(accN, 0, sizeof(int)*ACC_N_SIZE); + memset(accR, 0, sizeof(int)*ACC_R_SIZE); + memset(accA, 0, sizeof(int)*ACC_A_SIZE); + + // estimation + // Get size of the 4 vectors of slopes (2 pairs of arcs) + int sz_ij1 = int(data_ij.Sa.size()); + int sz_ij2 = int(data_ij.Sb.size()); + int sz_ik1 = int(data_ik.Sa.size()); + int sz_ik2 = int(data_ik.Sb.size()); + + // Get the size of the 3 arcs + size_t sz_ei = edge_i.size(); + size_t sz_ej = edge_j.size(); + size_t sz_ek = edge_k.size(); + + // Center of the estimated ellipse + float a0 = center.x; + float b0 = center.y; + + + // Estimation of remaining parameters + // Uses 4 combinations of parameters. See Table 1 and Sect [3.2.3] of the paper. + // ij1 and ik + { + float q1 = data_ij.ra; + float q3 = data_ik.ra; + float q5 = data_ik.rb; + + for (int ij1 = 0; ij1 < sz_ij1; ++ij1) + { + float q2 = data_ij.Sa[ij1]; // need iter \A3\BF + + float q1xq2 = q1*q2; + // ij1 and ik1 + for (int ik1 = 0; ik1 < sz_ik1; ++ik1) + { + float q4 = data_ik.Sa[ik1]; // need iter \A3\BF + + float q3xq4 = q3*q4; + + // See Eq. 
[13-18] in the paper + + float a = (q1xq2 - q3xq4); // gama + float b = (q3xq4 + 1)*(q1 + q2) - (q1xq2 + 1)*(q3 + q4); // beta + float Kp = (-b + sqrt(b*b + 4 * a*a)) / (2 * a); // K+ + float zplus = ((q1 - Kp)*(q2 - Kp)) / ((1 + q1*Kp)*(1 + q2*Kp)); + // check zplus and K is linear + if (zplus >= 0.0f) continue; + + float Np = sqrt(-zplus); // N+ + float rho = atan(Kp); // rho tmp + int rhoDeg; + if (Np > 1.f) + { + Np = 1.f / Np; + rhoDeg = cvRound((rho * 180 / CV_PI) + 180) % 180; // [0,180) + } + else + { + rhoDeg = cvRound((rho * 180 / CV_PI) + 90) % 180; // [0,180) rho angel rep and norm + } + + int iNp = cvRound(Np * 100); // [0, 100] + + if (0 <= iNp && iNp < ACC_N_SIZE && + 0 <= rhoDeg && rhoDeg < ACC_R_SIZE + ) + { // why iter all. beacause zplus and K is not linear? + ++accN[iNp]; // Increment N accumulator + ++accR[rhoDeg]; // Increment R accumulator + } + } + // ij1 and ik2 + for (int ik2 = 0; ik2 < sz_ik2; ++ik2) + { + float q4 = data_ik.Sb[ik2]; + + float q5xq4 = q5*q4; + + // See Eq. [13-18] in the paper + + float a = (q1xq2 - q5xq4); + float b = (q5xq4 + 1)*(q1 + q2) - (q1xq2 + 1)*(q5 + q4); + float Kp = (-b + sqrt(b*b + 4 * a*a)) / (2 * a); + float zplus = ((q1 - Kp)*(q2 - Kp)) / ((1 + q1*Kp)*(1 + q2*Kp)); + + if (zplus >= 0.0f) + { + continue; + } + + float Np = sqrt(-zplus); + float rho = atan(Kp); + int rhoDeg; + if (Np > 1.f) + { + Np = 1.f / Np; + rhoDeg = cvRound((rho * 180 / CV_PI) + 180) % 180; // [0,180) + } + else + { + rhoDeg = cvRound((rho * 180 / CV_PI) + 90) % 180; // [0,180) + } + + int iNp = cvRound(Np * 100); // [0, 100] + + if (0 <= iNp && iNp < ACC_N_SIZE && + 0 <= rhoDeg && rhoDeg < ACC_R_SIZE + ) + { + ++accN[iNp]; // Increment N accumulator + ++accR[rhoDeg]; // Increment R accumulator + } + } + + } + } + + // ij2 and ik + { + float q1 = data_ij.rb; + float q3 = data_ik.rb; + float q5 = data_ik.ra; + + for (int ij2 = 0; ij2 < sz_ij2; ++ij2) + { + float q2 = data_ij.Sb[ij2]; + + float q1xq2 = q1*q2; + // ij2 and ik2 + for (int ik2 = 0; ik2 < sz_ik2; ++ik2) + { + float q4 = data_ik.Sb[ik2]; + + float q3xq4 = q3*q4; + + // See Eq. [13-18] in the paper + + float a = (q1xq2 - q3xq4); + float b = (q3xq4 + 1)*(q1 + q2) - (q1xq2 + 1)*(q3 + q4); + float Kp = (-b + sqrt(b*b + 4 * a*a)) / (2 * a); + float zplus = ((q1 - Kp)*(q2 - Kp)) / ((1 + q1*Kp)*(1 + q2*Kp)); + + if (zplus >= 0.0f) + { + continue; + } + + float Np = sqrt(-zplus); + float rho = atan(Kp); + int rhoDeg; + if (Np > 1.f) + { + Np = 1.f / Np; + rhoDeg = cvRound((rho * 180 / CV_PI) + 180) % 180; // [0,180) + } + else + { + rhoDeg = cvRound((rho * 180 / CV_PI) + 90) % 180; // [0,180) + } + + int iNp = cvRound(Np * 100); // [0, 100] + + if (0 <= iNp && iNp < ACC_N_SIZE && + 0 <= rhoDeg && rhoDeg < ACC_R_SIZE + ) + { + ++accN[iNp]; // Increment N accumulator + ++accR[rhoDeg]; // Increment R accumulator + } + } + + // ij2 and ik1 + for (int ik1 = 0; ik1 < sz_ik1; ++ik1) + { + float q4 = data_ik.Sa[ik1]; + + float q5xq4 = q5*q4; + + // See Eq. 
[13-18] in the paper + + float a = (q1xq2 - q5xq4); + float b = (q5xq4 + 1)*(q1 + q2) - (q1xq2 + 1)*(q5 + q4); + float Kp = (-b + sqrt(b*b + 4 * a*a)) / (2 * a); + float zplus = ((q1 - Kp)*(q2 - Kp)) / ((1 + q1*Kp)*(1 + q2*Kp)); + + if (zplus >= 0.0f) + { + continue; + } + + float Np = sqrt(-zplus); + float rho = atan(Kp); + int rhoDeg; + if (Np > 1.f) + { + Np = 1.f / Np; + rhoDeg = cvRound((rho * 180 / CV_PI) + 180) % 180; // [0,180) + } + else + { + rhoDeg = cvRound((rho * 180 / CV_PI) + 90) % 180; // [0,180) + } + + int iNp = cvRound(Np * 100); // [0, 100] + + if (0 <= iNp && iNp < ACC_N_SIZE && + 0 <= rhoDeg && rhoDeg < ACC_R_SIZE + ) + { + ++accN[iNp]; // Increment N accumulator + ++accR[rhoDeg]; // Increment R accumulator + } + } + + } + } + + // Find peak in N and K accumulator + int iN = FindMaxN(accN); + int iK = FindMaxK(accR); + + // Recover real values + float fK = float(iK); + float Np = float(iN) * 0.01f; + float rho = fK * float(CV_PI) / 180.f; // deg 2 rad + float Kp = tan(rho); + + // Estimate A. See Eq. [19 - 22] in Sect [3.2.3] of the paper + // + // may optm + for (ushort l = 0; l < sz_ei; ++l) + { + Point& pp = edge_i[l]; + float sk = 1.f / sqrt(Kp*Kp + 1.f); // cos rho + float x0 = ((pp.x - a0) * sk) + (((pp.y - b0)*Kp) * sk); // may optm + float y0 = -(((pp.x - a0) * Kp) * sk) + ((pp.y - b0) * sk); // may optm + float Ax = sqrt((x0*x0*Np*Np + y0*y0) / ((Np*Np)*(1.f + Kp*Kp))); + int A = cvRound(abs(Ax / cos(rho))); // may optm + if ((0 <= A) && (A < ACC_A_SIZE)) + { + ++accA[A]; + } + } + + for (ushort l = 0; l < sz_ej; ++l) + { + Point& pp = edge_j[l]; + float sk = 1.f / sqrt(Kp*Kp + 1.f); + float x0 = ((pp.x - a0) * sk) + (((pp.y - b0)*Kp) * sk); + float y0 = -(((pp.x - a0) * Kp) * sk) + ((pp.y - b0) * sk); + float Ax = sqrt((x0*x0*Np*Np + y0*y0) / ((Np*Np)*(1.f + Kp*Kp))); + int A = cvRound(abs(Ax / cos(rho))); + if ((0 <= A) && (A < ACC_A_SIZE)) + { + ++accA[A]; + } + } + + for (ushort l = 0; l < sz_ek; ++l) + { + Point& pp = edge_k[l]; + float sk = 1.f / sqrt(Kp*Kp + 1.f); + float x0 = ((pp.x - a0) * sk) + (((pp.y - b0)*Kp) * sk); + float y0 = -(((pp.x - a0) * Kp) * sk) + ((pp.y - b0) * sk); + float Ax = sqrt((x0*x0*Np*Np + y0*y0) / ((Np*Np)*(1.f + Kp*Kp))); + int A = cvRound(abs(Ax / cos(rho))); + if ((0 <= A) && (A < ACC_A_SIZE)) + { + ++accA[A]; + } + } + + // Find peak in A accumulator + int A = FindMaxA(accA); + float fA = float(A); + + // Find B value. See Eq [23] in the paper + float fB = abs(fA * Np); + + // Got all ellipse parameters! + // Ellipse ell(a0, b0, fA, fB, fmod(rho + float(CV_PI)*2.f, float(CV_PI))); + ell.xc_ = a0; + ell.yc_ = b0; + ell.a_ = fA; + ell.b_ = fB; + ell.rad_ = fmod(rho + float(CV_PI)*2.f, float(CV_PI)); + + // estimation end + // validation start + // Get the score. 
See Sect [3.3.1] in the paper + + // Find the number of edge pixel lying on the ellipse + float _cos = cos(-ell.rad_); + float _sin = sin(-ell.rad_); + + float invA2 = 1.f / (ell.a_ * ell.a_); + float invB2 = 1.f / (ell.b_ * ell.b_); + + float invNofPoints = 1.f / float(sz_ei + sz_ej + sz_ek); + int counter_on_perimeter = 0; + float probc_on_perimeter = 0; + + for (ushort l = 0; l < sz_ei; ++l) + { + float tx = float(edge_i[l].x) - ell.xc_; + float ty = float(edge_i[l].y) - ell.yc_; + float rx = (tx*_cos - ty*_sin); + float ry = (tx*_sin + ty*_cos); + + float h = (rx*rx)*invA2 + (ry*ry)*invB2; + if (abs(h - 1.f) < fDistanceToEllipseContour_) + { + ++counter_on_perimeter; + probc_on_perimeter += 1; // (fDistanceToEllipseContour_ - abs(h - 1.f)) / fDistanceToEllipseContour_; + } + } + + for (ushort l = 0; l < sz_ej; ++l) + { + float tx = float(edge_j[l].x) - ell.xc_; + float ty = float(edge_j[l].y) - ell.yc_; + float rx = (tx*_cos - ty*_sin); + float ry = (tx*_sin + ty*_cos); + + float h = (rx*rx)*invA2 + (ry*ry)*invB2; + if (abs(h - 1.f) < fDistanceToEllipseContour_) + { + ++counter_on_perimeter; + probc_on_perimeter += 1; // (fDistanceToEllipseContour_ - abs(h - 1.f)) / fDistanceToEllipseContour_; + } + } + + for (ushort l = 0; l < sz_ek; ++l) + { + float tx = float(edge_k[l].x) - ell.xc_; + float ty = float(edge_k[l].y) - ell.yc_; + float rx = (tx*_cos - ty*_sin); + float ry = (tx*_sin + ty*_cos); + + float h = (rx*rx)*invA2 + (ry*ry)*invB2; + if (abs(h - 1.f) < fDistanceToEllipseContour_) + { + ++counter_on_perimeter; + probc_on_perimeter += 1; // (fDistanceToEllipseContour_ - abs(h - 1.f)) / fDistanceToEllipseContour_; + } + } + + + // no points found on the ellipse + if (counter_on_perimeter <= 0) + { + // validation + return; + } + + // Compute score + // float score = float(counter_on_perimeter) * invNofPoints; + float score = probc_on_perimeter * invNofPoints; + if (score < fMinScore_) + { + // validation + return; + } + + // Compute reliability + // this metric is not described in the paper, mostly due to space limitations. + // The main idea is that for a given ellipse (TD) even if the score is high, the arcs + // can cover only a small amount of the contour of the estimated ellipse. + // A low reliability indicate that the arcs form an elliptic shape by chance, but do not underlie + // an actual ellipse. The value is normalized between 0 and 1. + // The default value is 0.4. + + // It is somehow similar to the "Angular Circumreference Ratio" saliency criteria + // as in the paper: + // D. K. Prasad, M. K. Leung, S.-Y. Cho, Edge curvature and convexity + // based ellipse detection method, Pattern Recognition 45 (2012) 3204-3221. 
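+ // Concretely, the code below rotates the two endpoints of each arc into the
+ // ellipse-aligned frame and takes the L1 distance between them (di, dj, dk).
+ // The reliability is then
+ //   rel = min(1, (di + dj + dk) / (3 * (a_ + b_)))
+ // which approaches 1 when the three arcs span most of the estimated ellipse
+ // and stays small when they cover only a short portion of its contour.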
+ + float di, dj, dk; + { + Point2f p1(float(edge_i[0].x), float(edge_i[0].y)); + Point2f p2(float(edge_i[sz_ei - 1].x), float(edge_i[sz_ei - 1].y)); + p1.x -= ell.xc_; + p1.y -= ell.yc_; + p2.x -= ell.xc_; + p2.y -= ell.yc_; + Point2f r1((p1.x*_cos - p1.y*_sin), (p1.x*_sin + p1.y*_cos)); + Point2f r2((p2.x*_cos - p2.y*_sin), (p2.x*_sin + p2.y*_cos)); + di = abs(r2.x - r1.x) + abs(r2.y - r1.y); + } + { + Point2f p1(float(edge_j[0].x), float(edge_j[0].y)); + Point2f p2(float(edge_j[sz_ej - 1].x), float(edge_j[sz_ej - 1].y)); + p1.x -= ell.xc_; + p1.y -= ell.yc_; + p2.x -= ell.xc_; + p2.y -= ell.yc_; + Point2f r1((p1.x*_cos - p1.y*_sin), (p1.x*_sin + p1.y*_cos)); + Point2f r2((p2.x*_cos - p2.y*_sin), (p2.x*_sin + p2.y*_cos)); + dj = abs(r2.x - r1.x) + abs(r2.y - r1.y); + } + { + Point2f p1(float(edge_k[0].x), float(edge_k[0].y)); + Point2f p2(float(edge_k[sz_ek - 1].x), float(edge_k[sz_ek - 1].y)); + p1.x -= ell.xc_; + p1.y -= ell.yc_; + p2.x -= ell.xc_; + p2.y -= ell.yc_; + Point2f r1((p1.x*_cos - p1.y*_sin), (p1.x*_sin + p1.y*_cos)); + Point2f r2((p2.x*_cos - p2.y*_sin), (p2.x*_sin + p2.y*_cos)); + dk = abs(r2.x - r1.x) + abs(r2.y - r1.y); + } + + // This allows to get rid of thick edges + float rel = min(1.f, ((di + dj + dk) / (3 * (ell.a_ + ell.b_)))); + + if (rel < fMinReliability_) + { + // validation + return; + } + if (_isnan(rel)) + return; + + // Assign the new score! + ell.score_ = (score + rel) * 0.5f; // need to change + // ell.score_ = (score*rel); // need to change + + if (ell.score_ < fMinScore_) + { + return; + } + // The tentative detection has been confirmed. Save it! + // ellipses.push_back(ell); + + // Validation end +} + + + +// Ellipse clustering procedure. See Sect [3.3.2] in the paper. +void EllipseDetector::ClusterEllipses(vector& ellipses) +{ + float th_Da = 0.1f; + float th_Db = 0.1f; + float th_Dr = 0.1f; + + float th_Dc_ratio = 0.1f; + float th_Dr_circle = 0.9f; + + int iNumOfEllipses = int(ellipses.size()); + if (iNumOfEllipses == 0) return; + + // The first ellipse is assigned to a cluster + vector clusters; + clusters.push_back(ellipses[0]); + + // bool bFoundCluster = false; + + for (int i = 1; i th_Dc) + { + //not same cluster + continue; + } + + // a + float Da = abs(e1.a_ - e2.a_) / max(e1.a_, e2.a_); + if (Da > th_Da) + { + //not same cluster + continue; + } + + // b + float Db = abs(e1.b_ - e2.b_) / min(e1.b_, e2.b_); + if (Db > th_Db) + { + //not same cluster + continue; + } + + // angle + float Dr = GetMinAnglePI(e1.rad_, e2.rad_) / float(CV_PI); + if ((Dr > th_Dr) && (ba_e1 < th_Dr_circle) && (ba_e2 < th_Dr_circle)) + { + //not same cluster + continue; + } + + // Same cluster as e2 + bFoundCluster = true;// + // Discard, no need to create a new cluster + break; + } + + if (!bFoundCluster) + { + // Create a new cluster + clusters.push_back(e1); + } + } + + clusters.swap(ellipses); +} + +float EllipseDetector::GetMedianSlope(vector& med, Point2f& M, vector& slopes) +{ + // input med slopes, output M, return slope + // med : vector of points + // M : centroid of the points in med + // slopes : vector of the slopes + + unsigned iNofPoints = unsigned(med.size()); + // CV_Assert(iNofPoints >= 2); + + unsigned halfSize = iNofPoints >> 1; + unsigned quarterSize = halfSize >> 1; + + vector xx, yy; + slopes.reserve(halfSize); + xx.reserve(iNofPoints); + yy.reserve(iNofPoints); + + for (unsigned i = 0; i < halfSize; ++i) + { + Point2f& p1 = med[i]; + Point2f& p2 = med[halfSize + i]; + + xx.push_back(p1.x); + xx.push_back(p2.x); + yy.push_back(p1.y); + 
yy.push_back(p2.y); + + float den = (p2.x - p1.x); + float num = (p2.y - p1.y); + + if (den == 0) den = 0.00001f; + + slopes.push_back(num / den); + } + + nth_element(slopes.begin(), slopes.begin() + quarterSize, slopes.end()); + nth_element(xx.begin(), xx.begin() + halfSize, xx.end()); + nth_element(yy.begin(), yy.begin() + halfSize, yy.end()); + M.x = xx[halfSize]; + M.y = yy[halfSize]; + + return slopes[quarterSize]; +} + +int inline sgn(float val) { + return (0.f < val) - (val < 0.f); +} + +void EllipseDetector::GetFastCenter(vector& e1, vector& e2, EllipseData& data) +{ + countOfGetFastCenter_++; + data.isValid = true; + + unsigned size_1 = unsigned(e1.size()); + unsigned size_2 = unsigned(e2.size()); + + unsigned hsize_1 = size_1 >> 1; + unsigned hsize_2 = size_2 >> 1; + + Point& med1 = e1[hsize_1]; + Point& med2 = e2[hsize_2]; + + Point2f M12, M34; + float q2, q4; + + {// First to second Reference slope + float dx_ref = float(e1[0].x - med2.x); + float dy_ref = float(e1[0].y - med2.y); + + if (dy_ref == 0) dy_ref = 0.00001f; + + float m_ref = dy_ref / dx_ref; + data.ra = m_ref; + + // Find points with same slope as reference + vector med; + med.reserve(hsize_2); + + unsigned minPoints = (uNs_ < hsize_2) ? uNs_ : hsize_2; // parallel chords + + vector indexes(minPoints); + if (uNs_ < hsize_2) + { // hsize_2 bigger than uNs_ + unsigned iSzBin = hsize_2 / unsigned(uNs_); + unsigned iIdx = hsize_2 + (iSzBin / 2); + + for (unsigned i = 0; i> 1; + while (end - begin > 2) + { + float x2 = float(e1[j].x); + float y2 = float(e1[j].y); + float res = ((x2 - x1) * dy_ref) - ((y2 - y1) * dx_ref); + int sign_res = sgn(res); + + if (sign_res == 0) + { + // found + med.push_back(Point2f((x2 + x1)* 0.5f, (y2 + y1)* 0.5f)); + break; + } + + if (sign_res + sign_begin == 0) + { + sign_end = sign_res; + end = j; + } + else + { + sign_begin = sign_res; + begin = j; + } + j = (begin + end) >> 1; + } + // search end error ? + med.push_back(Point2f((e1[j].x + x1)* 0.5f, (e1[j].y + y1)* 0.5f)); + } + + if (med.size() < 2) + { + data.isValid = false; + return; + } + + /**************************************** + Mat3b out(480, 640, Vec3b(0, 0, 0)); + Vec3b color(0, 255, 0); + for (int ci = 0; ci < med.size(); ci++) + circle(out, med[ci], 2, color); + imshow("test", out); waitKey(100); + ****************************************/ + q2 = GetMedianSlope(med, M12, data.Sa); //get Sa ta = q2 Ma + } + + {// Second to first + // Reference slope + float dx_ref = float(med1.x - e2[0].x); + float dy_ref = float(med1.y - e2[0].y); + + if (dy_ref == 0) dy_ref = 0.00001f; + + float m_ref = dy_ref / dx_ref; + data.rb = m_ref; + + // Find points with same slope as reference + vector med; + med.reserve(hsize_1); + + uint minPoints = (uNs_ < hsize_1) ? 
uNs_ : hsize_1; + + vector indexes(minPoints); + if (uNs_ < hsize_1) + { + unsigned iSzBin = hsize_1 / unsigned(uNs_); + unsigned iIdx = hsize_1 + (iSzBin / 2); + + for (unsigned i = 0; i> 1; + + while (end - begin > 2) + { + float x2 = float(e2[j].x); + float y2 = float(e2[j].y); + float res = ((x2 - x1) * dy_ref) - ((y2 - y1) * dx_ref); + int sign_res = sgn(res); + + if (sign_res == 0) + { + //found + med.push_back(Point2f((x2 + x1)* 0.5f, (y2 + y1)* 0.5f)); + break; + } + + if (sign_res + sign_begin == 0) + { + sign_end = sign_res; + end = j; + } + else + { + sign_begin = sign_res; + begin = j; + } + j = (begin + end) >> 1; + } + + med.push_back(Point2f((e2[j].x + x1)* 0.5f, (e2[j].y + y1)* 0.5f)); + } + + if (med.size() < 2) + { + data.isValid = false; + return; + } + q4 = GetMedianSlope(med, M34, data.Sb); + } + + if (q2 == q4) + { + data.isValid = false; + return; + } + + float invDen = 1 / (q2 - q4); + data.Cab.x = (M34.y - q4*M34.x - M12.y + q2*M12.x) * invDen; + data.Cab.y = (q2*M34.y - q4*M12.y + q2*q4*(M12.x - M34.x)) * invDen; + data.ta = q2; + data.tb = q4; + data.Ma = M12; + data.Mb = M34; +} + +float EllipseDetector::GetMinAnglePI(float alpha, float beta) +{ + float pi = float(CV_PI); + float pi2 = float(2.0 * CV_PI); + + // normalize data in [0, 2*pi] + float a = fmod(alpha + pi2, pi2); + float b = fmod(beta + pi2, pi2); + + // normalize data in [0, pi] + if (a > pi) + a -= pi; + if (b > pi) + b -= pi; + + if (a > b) + { + swap(a, b); + } + + float diff1 = b - a; + float diff2 = pi - diff1; + return min(diff1, diff2); +} + +// Get the coordinates of the center, given the intersection of the estimated lines. See Fig. [8] in Sect [3.2.3] in the paper. +Point2f EllipseDetector::GetCenterCoordinates(EllipseData& data_ij, EllipseData& data_ik) +{ + float xx[7]; + float yy[7]; + + xx[0] = data_ij.Cab.x; + xx[1] = data_ik.Cab.x; + yy[0] = data_ij.Cab.y; + yy[1] = data_ik.Cab.y; + + { + // 1-1 + float q2 = data_ij.ta; + float q4 = data_ik.ta; + Point2f& M12 = data_ij.Ma; + Point2f& M34 = data_ik.Ma; + + float invDen = 1 / (q2 - q4); + xx[2] = (M34.y - q4*M34.x - M12.y + q2*M12.x) * invDen; + yy[2] = (q2*M34.y - q4*M12.y + q2*q4*(M12.x - M34.x)) * invDen; + } + + { + // 1-2 + float q2 = data_ij.ta; + float q4 = data_ik.tb; + Point2f& M12 = data_ij.Ma; + Point2f& M34 = data_ik.Mb; + + float invDen = 1 / (q2 - q4); + xx[3] = (M34.y - q4*M34.x - M12.y + q2*M12.x) * invDen; + yy[3] = (q2*M34.y - q4*M12.y + q2*q4*(M12.x - M34.x)) * invDen; + } + + { + // 2-2 + float q2 = data_ij.tb; + float q4 = data_ik.tb; + Point2f& M12 = data_ij.Mb; + Point2f& M34 = data_ik.Mb; + + float invDen = 1 / (q2 - q4); + xx[4] = (M34.y - q4*M34.x - M12.y + q2*M12.x) * invDen; + yy[4] = (q2*M34.y - q4*M12.y + q2*q4*(M12.x - M34.x)) * invDen; + } + + { + // 2-1 + float q2 = data_ij.tb; + float q4 = data_ik.ta; + Point2f& M12 = data_ij.Mb; + Point2f& M34 = data_ik.Ma; + + float invDen = 1 / (q2 - q4); + xx[5] = (M34.y - q4*M34.x - M12.y + q2*M12.x) * invDen; + yy[5] = (q2*M34.y - q4*M12.y + q2*q4*(M12.x - M34.x)) * invDen; + } + + xx[6] = (xx[0] + xx[1]) * 0.5f; + yy[6] = (yy[0] + yy[1]) * 0.5f; + + + // Median + nth_element(xx, xx + 3, xx + 7); + nth_element(yy, yy + 3, yy + 7); + float xc = xx[3]; + float yc = yy[3]; + + return Point2f(xc, yc); +} + +uint inline EllipseDetector::GenerateKey(uchar pair, ushort u, ushort v) +{ + return (pair << 30) + (u << 15) + v; +} + +// Draw at most iTopN detected ellipses. 
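+// A minimal usage sketch (illustrative only; the image path, iTopN/thickness
+// values and variable names are assumptions, not taken from this file, and
+// namespace qualifiers are omitted):
+//
+//   EllipseDetector detector;
+//   std::vector<Ellipse> ellipses;
+//   cv::Mat frame = cv::imread("test.png");               // BGR input
+//   detector.Detect(frame, ellipses);                     // converted to gray internally
+//   detector.DrawDetectedEllipses(frame, ellipses, 1, 2); // top-ranked ellipse only
+//   // passing iTopN == 0 draws every detection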
+void EllipseDetector::DrawDetectedEllipses(Mat& output, vector& ellipses, int iTopN, int thickness) +{ + int sz_ell = int(ellipses.size()); + int n = (iTopN == 0) ? sz_ell : min(iTopN, sz_ell); + for (int i = 0; i < n; ++i) + { + Ellipse& e = ellipses[n - i - 1]; + int g = cvRound(e.score_ * 255.f); + Scalar color(0, g, 0); + ellipse(output, Point(cvRound(e.xc_), cvRound(e.yc_)), Size(cvRound(e.a_), cvRound(e.b_)), e.rad_*180.0 / CV_PI, 0.0, 360.0, color, thickness); + // cv::circle(output, Point(e.xc_, e.yc_), 2, Scalar(0, 0, 255), 2); + } +} + +bool inline convex_check(VP& vp1, int start, int end) +{ + int x_min(4096), x_max(0), y_min(4096), y_max(0); + int integral_u(0), integral_d(0); + for (int i = start; i <= end; i++) + { + Point& val = vp1[i]; + x_min = MIN(x_min, val.x); + x_max = MAX(x_max, val.x); + y_min = MIN(y_min, val.y); + y_max = MAX(y_max, val.y); + } + for (int i = start; i <= end; i++) + { + Point& val = vp1[i]; + integral_u += (val.y - y_min); + integral_d += (y_max - val.y); + } + if (integral_u > integral_d) + return false; + else + return true; +} + +bool inline concave_check(VP& vp1, int start, int end) +{ + int x_min(4096), x_max(0), y_min(4096), y_max(0); + int integral_u(0), integral_d(0); + for (int i = start; i <= end; i++) + { + Point& val = vp1[i]; + x_min = MIN(x_min, val.x); + x_max = MAX(x_max, val.x); + y_min = MIN(y_min, val.y); + y_max = MAX(y_max, val.y); + } + for (int i = start; i <= end; i++) + { + Point& val = vp1[i]; + integral_u += (val.y - y_min); + integral_d += (y_max - val.y); + } + if (integral_u < integral_d) + return false; + else + return true; +} + +void EllipseDetector::ArcsCheck1234(VVP& points_1, VVP& points_2, VVP& points_3, VVP& points_4) +{ + static int nchecks = 21; + int i, j; + VVP vps_1, vps_2, vps_3, vps_4; + Mat3b out(480, 640, Vec3b(0, 0, 0)); + for (i = 0; i < points_2.size(); ++i) + { + // Vec3b color(rand()%255, 128+rand()%127, 128+rand()%127); + Vec3b color(255, 0, 0); + for (j = 0; j < points_2[i].size(); ++j) + out(points_2[i][j]) = color; + } + imshow("out", out); waitKey(); + + for (i = 0; i < points_1.size(); i++) + { + VP& vp1 = points_1[i]; + if (vp1.size() > nchecks) + { + VP vpn; + for (j = 0; j <= vp1.size() - nchecks; j++) + { + if (convex_check(vp1, j, j + nchecks - 1)) + { + vpn.push_back(vp1[j]); + } + else + { + vps_1.push_back(vpn); vpn.clear(); + } + } + vps_1.push_back(vpn); + } + else + { + cout << "==== small arc I ====" << endl; + } + } + for (i = 0; i < points_2.size(); i++) + { + VP& vp1 = points_2[i]; + if (vp1.size() > nchecks) + { + VP vpn; + for (j = 0; j <= vp1.size() - nchecks; j++) + { + if (concave_check(vp1, j, j + nchecks - 1)) + { + vpn.push_back(vp1[j]); + } + else + { + vps_2.push_back(vpn); vpn.clear(); + } + } + vps_2.push_back(vpn); + } + else + { + cout << "==== small arc II ====" << endl; + } + } + Mat3b out2(480, 640, Vec3b(0, 0, 0)); + for (i = 0; i < vps_2.size(); ++i) + { + // Vec3b color(rand()%255, 128+rand()%127, 128+rand()%127); + Vec3b color(255, 0, 0); + for (j = 0; j < vps_2[i].size(); ++j) + out2(vps_2[i][j]) = color; + } + imshow("out2", out2); waitKey(); + for (i = 0; i < points_3.size(); i++) + { + VP& vp1 = points_3[i]; + if (vp1.size() > nchecks) + { + VP vpn; + for (j = 0; j <= vp1.size() - nchecks; j++) + { + if (concave_check(vp1, j, j + nchecks - 1)) + { + vpn.push_back(vp1[j]); + } + else + { + vps_3.push_back(vpn); vpn.clear(); + } + } + vps_3.push_back(vpn); + } + else + { + cout << "==== small arc III ====" << endl; + } + } + for (i = 0; i < 
points_4.size(); i++) + { + VP& vp1 = points_4[i]; + if (vp1.size() > nchecks) + { + VP vpn; + for (j = 0; j <= vp1.size() - nchecks; j++) + { + if (convex_check(vp1, j, j + nchecks - 1)) + { + vpn.push_back(vp1[j]); + } + else + { + vps_4.push_back(vpn); vpn.clear(); + } + } + vps_4.push_back(vpn); + } + else + { + cout << "==== small arc IV ====" << endl; + } + } + points_1 = vps_1; points_2 = vps_2; points_3 = vps_3; points_4 = vps_4; +} + +} + diff --git a/algorithm/ellipse_det/ellipse_detector.h b/algorithm/ellipse_det/ellipse_detector.h new file mode 100644 index 0000000..4518a19 --- /dev/null +++ b/algorithm/ellipse_det/ellipse_detector.h @@ -0,0 +1,1379 @@ +#ifndef SPIRE_ELLIPSEDETECTOR_H +#define SPIRE_ELLIPSEDETECTOR_H + +#include +#include +#include +#include +#include +#include + + + +#ifdef _WIN32 +/* +* Define architecture flags so we don't need to include windows.h. +* Avoiding windows.h makes it simpler to use windows sockets in conjunction +* with dirent.h. +*/ +#if !defined(_68K_) && !defined(_MPPC_) && !defined(_X86_) && !defined(_IA64_) && !defined(_AMD64_) && defined(_M_IX86) +# define _X86_ +#endif +#if !defined(_68K_) && !defined(_MPPC_) && !defined(_X86_) && !defined(_IA64_) && !defined(_AMD64_) && defined(_M_AMD64) +#define _AMD64_ +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* Indicates that d_type field is available in dirent structure */ +#define _DIRENT_HAVE_D_TYPE + +/* Indicates that d_namlen field is available in dirent structure */ +#define _DIRENT_HAVE_D_NAMLEN + +/* Entries missing from MSVC 6.0 */ +#if !defined(FILE_ATTRIBUTE_DEVICE) +# define FILE_ATTRIBUTE_DEVICE 0x40 +#endif + +/* File type and permission flags for stat(), general mask */ +#if !defined(S_IFMT) +# define S_IFMT _S_IFMT +#endif + +/* Directory bit */ +#if !defined(S_IFDIR) +# define S_IFDIR _S_IFDIR +#endif + +/* Character device bit */ +#if !defined(S_IFCHR) +# define S_IFCHR _S_IFCHR +#endif + +/* Pipe bit */ +#if !defined(S_IFFIFO) +# define S_IFFIFO _S_IFFIFO +#endif + +/* Regular file bit */ +#if !defined(S_IFREG) +# define S_IFREG _S_IFREG +#endif + +/* Read permission */ +#if !defined(S_IREAD) +# define S_IREAD _S_IREAD +#endif + +/* Write permission */ +#if !defined(S_IWRITE) +# define S_IWRITE _S_IWRITE +#endif + +/* Execute permission */ +#if !defined(S_IEXEC) +# define S_IEXEC _S_IEXEC +#endif + +/* Pipe */ +#if !defined(S_IFIFO) +# define S_IFIFO _S_IFIFO +#endif + +/* Block device */ +#if !defined(S_IFBLK) +# define S_IFBLK 0 +#endif + +/* Link */ +#if !defined(S_IFLNK) +# define S_IFLNK 0 +#endif + +/* Socket */ +#if !defined(S_IFSOCK) +# define S_IFSOCK 0 +#endif + +/* Read user permission */ +#if !defined(S_IRUSR) +# define S_IRUSR S_IREAD +#endif + +/* Write user permission */ +#if !defined(S_IWUSR) +# define S_IWUSR S_IWRITE +#endif + +/* Execute user permission */ +#if !defined(S_IXUSR) +# define S_IXUSR 0 +#endif + +/* Read group permission */ +#if !defined(S_IRGRP) +# define S_IRGRP 0 +#endif + +/* Write group permission */ +#if !defined(S_IWGRP) +# define S_IWGRP 0 +#endif + +/* Execute group permission */ +#if !defined(S_IXGRP) +# define S_IXGRP 0 +#endif + +/* Read others permission */ +#if !defined(S_IROTH) +# define S_IROTH 0 +#endif + +/* Write others permission */ +#if !defined(S_IWOTH) +# define S_IWOTH 0 +#endif + +/* Execute others permission */ +#if !defined(S_IXOTH) +# define S_IXOTH 0 +#endif + +/* Maximum length of file name */ +#if !defined(PATH_MAX) +# define PATH_MAX 
MAX_PATH +#endif +#if !defined(FILENAME_MAX) +# define FILENAME_MAX MAX_PATH +#endif +#if !defined(NAME_MAX) +# define NAME_MAX FILENAME_MAX +#endif + +/* File type flags for d_type */ +#define DT_UNKNOWN 0 +#define DT_REG S_IFREG +#define DT_DIR S_IFDIR +#define DT_FIFO S_IFIFO +#define DT_SOCK S_IFSOCK +#define DT_CHR S_IFCHR +#define DT_BLK S_IFBLK +#define DT_LNK S_IFLNK + +/* Macros for converting between st_mode and d_type */ +#define IFTODT(mode) ((mode) & S_IFMT) +#define DTTOIF(type) (type) + +/* +* File type macros. Note that block devices, sockets and links cannot be +* distinguished on Windows and the macros S_ISBLK, S_ISSOCK and S_ISLNK are +* only defined for compatibility. These macros should always return false +* on Windows. +*/ +#if !defined(S_ISFIFO) +# define S_ISFIFO(mode) (((mode) & S_IFMT) == S_IFIFO) +#endif +#if !defined(S_ISDIR) +# define S_ISDIR(mode) (((mode) & S_IFMT) == S_IFDIR) +#endif +#if !defined(S_ISREG) +# define S_ISREG(mode) (((mode) & S_IFMT) == S_IFREG) +#endif +#if !defined(S_ISLNK) +# define S_ISLNK(mode) (((mode) & S_IFMT) == S_IFLNK) +#endif +#if !defined(S_ISSOCK) +# define S_ISSOCK(mode) (((mode) & S_IFMT) == S_IFSOCK) +#endif +#if !defined(S_ISCHR) +# define S_ISCHR(mode) (((mode) & S_IFMT) == S_IFCHR) +#endif +#if !defined(S_ISBLK) +# define S_ISBLK(mode) (((mode) & S_IFMT) == S_IFBLK) +#endif + +/* Return the exact length of d_namlen without zero terminator */ +#define _D_EXACT_NAMLEN(p) ((p)->d_namlen) + +/* Return number of bytes needed to store d_namlen */ +#define _D_ALLOC_NAMLEN(p) (PATH_MAX) + + +#ifdef __cplusplus +extern "C" { +#endif + + + /* Wide-character version */ + struct _wdirent { + /* Always zero */ + long d_ino; + + /* Structure size */ + unsigned short d_reclen; + + /* Length of name without \0 */ + size_t d_namlen; + + /* File type */ + int d_type; + + /* File name */ + wchar_t d_name[PATH_MAX]; + }; + typedef struct _wdirent _wdirent; + + struct _WDIR { + /* Current directory entry */ + struct _wdirent ent; + + /* Private file data */ + WIN32_FIND_DATAW data; + + /* True if data is valid */ + int cached; + + /* Win32 search handle */ + HANDLE handle; + + /* Initial directory name */ + wchar_t *patt; + }; + typedef struct _WDIR _WDIR; + + static _WDIR *_wopendir(const wchar_t *dirname); + static struct _wdirent *_wreaddir(_WDIR *dirp); + static int _wclosedir(_WDIR *dirp); + static void _wrewinddir(_WDIR* dirp); + + + /* For compatibility with Symbian */ +#define wdirent _wdirent +#define WDIR _WDIR +#define wopendir _wopendir +#define wreaddir _wreaddir +#define wclosedir _wclosedir +#define wrewinddir _wrewinddir + + + /* Multi-byte character versions */ + struct dirent { + /* Always zero */ + long d_ino; + + /* Structure size */ + unsigned short d_reclen; + + /* Length of name without \0 */ + size_t d_namlen; + + /* File type */ + int d_type; + + /* File name */ + char d_name[PATH_MAX]; + }; + typedef struct dirent dirent; + + struct DIR { + struct dirent ent; + struct _WDIR *wdirp; + }; + typedef struct DIR DIR; + + static DIR *opendir(const char *dirname); + static struct dirent *readdir(DIR *dirp); + static int closedir(DIR *dirp); + static void rewinddir(DIR* dirp); + + + /* Internal utility functions */ + static WIN32_FIND_DATAW *dirent_first(_WDIR *dirp); + static WIN32_FIND_DATAW *dirent_next(_WDIR *dirp); + + static int dirent_mbstowcs_s( + size_t *pReturnValue, + wchar_t *wcstr, + size_t sizeInWords, + const char *mbstr, + size_t count); + + static int dirent_wcstombs_s( + size_t *pReturnValue, + char *mbstr, 
+ size_t sizeInBytes, + const wchar_t *wcstr, + size_t count); + + static void dirent_set_errno(int error); + + /* + * Open directory stream DIRNAME for read and return a pointer to the + * internal working area that is used to retrieve individual directory + * entries. + */ + static _WDIR* + _wopendir( + const wchar_t *dirname) + { + _WDIR *dirp = NULL; + int error; + + /* Must have directory name */ + if (dirname == NULL || dirname[0] == '\0') { + dirent_set_errno(ENOENT); + return NULL; + } + + /* Allocate new _WDIR structure */ + dirp = (_WDIR*)malloc(sizeof(struct _WDIR)); + if (dirp != NULL) { + DWORD n; + + /* Reset _WDIR structure */ + dirp->handle = INVALID_HANDLE_VALUE; + dirp->patt = NULL; + dirp->cached = 0; + + /* Compute the length of full path plus zero terminator */ + n = GetFullPathNameW(dirname, 0, NULL, NULL); + + /* Allocate room for absolute directory name and search pattern */ + dirp->patt = (wchar_t*)malloc(sizeof(wchar_t) * n + 16); + if (dirp->patt) { + + /* + * Convert relative directory name to an absolute one. This + * allows rewinddir() to function correctly even when current + * working directory is changed between opendir() and rewinddir(). + */ + n = GetFullPathNameW(dirname, n, dirp->patt, NULL); + if (n > 0) { + wchar_t *p; + + /* Append search pattern \* to the directory name */ + p = dirp->patt + n; + if (dirp->patt < p) { + switch (p[-1]) { + case '\\': + case '/': + case ':': + /* Directory ends in path separator, e.g. c:\temp\ */ + /*NOP*/; + break; + + default: + /* Directory name doesn't end in path separator */ + *p++ = '\\'; + } + } + *p++ = '*'; + *p = '\0'; + + /* Open directory stream and retrieve the first entry */ + if (dirent_first(dirp)) { + /* Directory stream opened successfully */ + error = 0; + } + else { + /* Cannot retrieve first entry */ + error = 1; + dirent_set_errno(ENOENT); + } + + } + else { + /* Cannot retrieve full path name */ + dirent_set_errno(ENOENT); + error = 1; + } + + } + else { + /* Cannot allocate memory for search pattern */ + error = 1; + } + + } + else { + /* Cannot allocate _WDIR structure */ + error = 1; + } + + /* Clean up in case of error */ + if (error && dirp) { + _wclosedir(dirp); + dirp = NULL; + } + + return dirp; + } + + /* + * Read next directory entry. The directory entry is returned in dirent + * structure in the d_name field. Individual directory entries returned by + * this function include regular files, sub-directories, pseudo-directories + * "." and ".." as well as volume labels, hidden files and system files. + */ + static struct _wdirent* + _wreaddir( + _WDIR *dirp) + { + WIN32_FIND_DATAW *datap; + struct _wdirent *entp; + + /* Read next directory entry */ + datap = dirent_next(dirp); + if (datap) { + size_t n; + DWORD attr; + + /* Pointer to directory entry to return */ + entp = &dirp->ent; + + /* + * Copy file name as wide-character string. If the file name is too + * long to fit in to the destination buffer, then truncate file name + * to PATH_MAX characters and zero-terminate the buffer. 
+ */ + n = 0; + while (n + 1 < PATH_MAX && datap->cFileName[n] != 0) { + entp->d_name[n] = datap->cFileName[n]; + n++; + } + dirp->ent.d_name[n] = 0; + + /* Length of file name excluding zero terminator */ + entp->d_namlen = n; + + /* File type */ + attr = datap->dwFileAttributes; + if ((attr & FILE_ATTRIBUTE_DEVICE) != 0) { + entp->d_type = DT_CHR; + } + else if ((attr & FILE_ATTRIBUTE_DIRECTORY) != 0) { + entp->d_type = DT_DIR; + } + else { + entp->d_type = DT_REG; + } + + /* Reset dummy fields */ + entp->d_ino = 0; + entp->d_reclen = sizeof(struct _wdirent); + + } + else { + + /* Last directory entry read */ + entp = NULL; + + } + + return entp; + } + + /* + * Close directory stream opened by opendir() function. This invalidates the + * DIR structure as well as any directory entry read previously by + * _wreaddir(). + */ + static int + _wclosedir( + _WDIR *dirp) + { + int ok; + if (dirp) { + + /* Release search handle */ + if (dirp->handle != INVALID_HANDLE_VALUE) { + FindClose(dirp->handle); + dirp->handle = INVALID_HANDLE_VALUE; + } + + /* Release search pattern */ + if (dirp->patt) { + free(dirp->patt); + dirp->patt = NULL; + } + + /* Release directory structure */ + free(dirp); + ok = /*success*/0; + + } + else { + /* Invalid directory stream */ + dirent_set_errno(EBADF); + ok = /*failure*/-1; + } + return ok; + } + + /* + * Rewind directory stream such that _wreaddir() returns the very first + * file name again. + */ + static void + _wrewinddir( + _WDIR* dirp) + { + if (dirp) { + /* Release existing search handle */ + if (dirp->handle != INVALID_HANDLE_VALUE) { + FindClose(dirp->handle); + } + + /* Open new search handle */ + dirent_first(dirp); + } + } + + /* Get first directory entry (internal) */ + static WIN32_FIND_DATAW* + dirent_first( + _WDIR *dirp) + { + WIN32_FIND_DATAW *datap; + + /* Open directory and retrieve the first entry */ + dirp->handle = FindFirstFileW(dirp->patt, &dirp->data); + if (dirp->handle != INVALID_HANDLE_VALUE) { + + /* a directory entry is now waiting in memory */ + datap = &dirp->data; + dirp->cached = 1; + + } + else { + + /* Failed to re-open directory: no directory entry in memory */ + dirp->cached = 0; + datap = NULL; + + } + return datap; + } + + /* Get next directory entry (internal) */ + static WIN32_FIND_DATAW* + dirent_next( + _WDIR *dirp) + { + WIN32_FIND_DATAW *p; + + /* Get next directory entry */ + if (dirp->cached != 0) { + + /* A valid directory entry already in memory */ + p = &dirp->data; + dirp->cached = 0; + + } + else if (dirp->handle != INVALID_HANDLE_VALUE) { + + /* Get the next directory entry from stream */ + if (FindNextFileW(dirp->handle, &dirp->data) != FALSE) { + /* Got a file */ + p = &dirp->data; + } + else { + /* The very last entry has been processed or an error occured */ + FindClose(dirp->handle); + dirp->handle = INVALID_HANDLE_VALUE; + p = NULL; + } + + } + else { + + /* End of directory stream reached */ + p = NULL; + + } + + return p; + } + + /* + * Open directory stream using plain old C-string. 
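+ * The multi-byte name is converted to a wide-character string internally and the stream is opened with _wopendir().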
+ */ + static DIR* + opendir( + const char *dirname) + { + struct DIR *dirp; + int error; + + /* Must have directory name */ + if (dirname == NULL || dirname[0] == '\0') { + dirent_set_errno(ENOENT); + return NULL; + } + + /* Allocate memory for DIR structure */ + dirp = (DIR*)malloc(sizeof(struct DIR)); + if (dirp) { + wchar_t wname[PATH_MAX]; + size_t n; + + /* Convert directory name to wide-character string */ + error = dirent_mbstowcs_s(&n, wname, PATH_MAX, dirname, PATH_MAX); + if (!error) { + + /* Open directory stream using wide-character name */ + dirp->wdirp = _wopendir(wname); + if (dirp->wdirp) { + /* Directory stream opened */ + error = 0; + } + else { + /* Failed to open directory stream */ + error = 1; + } + + } + else { + /* + * Cannot convert file name to wide-character string. This + * occurs if the string contains invalid multi-byte sequences or + * the output buffer is too small to contain the resulting + * string. + */ + error = 1; + } + + } + else { + /* Cannot allocate DIR structure */ + error = 1; + } + + /* Clean up in case of error */ + if (error && dirp) { + free(dirp); + dirp = NULL; + } + + return dirp; + } + + /* + * Read next directory entry. + * + * When working with text consoles, please note that file names returned by + * readdir() are represented in the default ANSI code page while any output to + * console is typically formatted on another code page. Thus, non-ASCII + * characters in file names will not usually display correctly on console. The + * problem can be fixed in two ways: (1) change the character set of console + * to 1252 using chcp utility and use Lucida Console font, or (2) use + * _cprintf function when writing to console. The _cprinf() will re-encode + * ANSI strings to the console code page so many non-ASCII characters will + * display correcly. + */ + static struct dirent* + readdir( + DIR *dirp) + { + WIN32_FIND_DATAW *datap; + struct dirent *entp; + + /* Read next directory entry */ + datap = dirent_next(dirp->wdirp); + if (datap) { + size_t n; + int error; + + /* Attempt to convert file name to multi-byte string */ + error = dirent_wcstombs_s( + &n, dirp->ent.d_name, PATH_MAX, datap->cFileName, PATH_MAX); + + /* + * If the file name cannot be represented by a multi-byte string, + * then attempt to use old 8+3 file name. This allows traditional + * Unix-code to access some file names despite of unicode + * characters, although file names may seem unfamiliar to the user. + * + * Be ware that the code below cannot come up with a short file + * name unless the file system provides one. At least + * VirtualBox shared folders fail to do this. + */ + if (error && datap->cAlternateFileName[0] != '\0') { + error = dirent_wcstombs_s( + &n, dirp->ent.d_name, PATH_MAX, + datap->cAlternateFileName, PATH_MAX); + } + + if (!error) { + DWORD attr; + + /* Initialize directory entry for return */ + entp = &dirp->ent; + + /* Length of file name excluding zero terminator */ + entp->d_namlen = n - 1; + + /* File attributes */ + attr = datap->dwFileAttributes; + if ((attr & FILE_ATTRIBUTE_DEVICE) != 0) { + entp->d_type = DT_CHR; + } + else if ((attr & FILE_ATTRIBUTE_DIRECTORY) != 0) { + entp->d_type = DT_DIR; + } + else { + entp->d_type = DT_REG; + } + + /* Reset dummy fields */ + entp->d_ino = 0; + entp->d_reclen = sizeof(struct dirent); + + } + else { + /* + * Cannot convert file name to multi-byte string so construct + * an errornous directory entry and return that. 
Note that + * we cannot return NULL as that would stop the processing + * of directory entries completely. + */ + entp = &dirp->ent; + entp->d_name[0] = '?'; + entp->d_name[1] = '\0'; + entp->d_namlen = 1; + entp->d_type = DT_UNKNOWN; + entp->d_ino = 0; + entp->d_reclen = 0; + } + + } + else { + /* No more directory entries */ + entp = NULL; + } + + return entp; + } + + /* + * Close directory stream. + */ + static int + closedir( + DIR *dirp) + { + int ok; + if (dirp) { + + /* Close wide-character directory stream */ + ok = _wclosedir(dirp->wdirp); + dirp->wdirp = NULL; + + /* Release multi-byte character version */ + free(dirp); + + } + else { + + /* Invalid directory stream */ + dirent_set_errno(EBADF); + ok = /*failure*/-1; + + } + return ok; + } + + /* + * Rewind directory stream to beginning. + */ + static void + rewinddir( + DIR* dirp) + { + /* Rewind wide-character string directory stream */ + _wrewinddir(dirp->wdirp); + } + + /* Convert multi-byte string to wide character string */ + static int + dirent_mbstowcs_s( + size_t *pReturnValue, + wchar_t *wcstr, + size_t sizeInWords, + const char *mbstr, + size_t count) + { + int error; + +#if defined(_MSC_VER) && _MSC_VER >= 1400 + + /* Microsoft Visual Studio 2005 or later */ + error = mbstowcs_s(pReturnValue, wcstr, sizeInWords, mbstr, count); + +#else + + /* Older Visual Studio or non-Microsoft compiler */ + size_t n; + + /* Convert to wide-character string (or count characters) */ + n = mbstowcs(wcstr, mbstr, sizeInWords); + if (!wcstr || n < count) { + + /* Zero-terminate output buffer */ + if (wcstr && sizeInWords) { + if (n >= sizeInWords) { + n = sizeInWords - 1; + } + wcstr[n] = 0; + } + + /* Length of resuting multi-byte string WITH zero terminator */ + if (pReturnValue) { + *pReturnValue = n + 1; + } + + /* Success */ + error = 0; + + } + else { + + /* Could not convert string */ + error = 1; + + } + +#endif + + return error; + } + + /* Convert wide-character string to multi-byte string */ + static int + dirent_wcstombs_s( + size_t *pReturnValue, + char *mbstr, + size_t sizeInBytes, /* max size of mbstr */ + const wchar_t *wcstr, + size_t count) + { + int error; + +#if defined(_MSC_VER) && _MSC_VER >= 1400 + + /* Microsoft Visual Studio 2005 or later */ + error = wcstombs_s(pReturnValue, mbstr, sizeInBytes, wcstr, count); + +#else + + /* Older Visual Studio or non-Microsoft compiler */ + size_t n; + + /* Convert to multi-byte string (or count the number of bytes needed) */ + n = wcstombs(mbstr, wcstr, sizeInBytes); + if (!mbstr || n < count) { + + /* Zero-terminate output buffer */ + if (mbstr && sizeInBytes) { + if (n >= sizeInBytes) { + n = sizeInBytes - 1; + } + mbstr[n] = '\0'; + } + + /* Lenght of resulting multi-bytes string WITH zero-terminator */ + if (pReturnValue) { + *pReturnValue = n + 1; + } + + /* Success */ + error = 0; + + } + else { + + /* Cannot convert string */ + error = 1; + + } + +#endif + + return error; + } + + /* Set errno variable */ + static void + dirent_set_errno( + int error) + { +#if defined(_MSC_VER) && _MSC_VER >= 1400 + + /* Microsoft Visual Studio 2005 and later */ + _set_errno(error); + +#else + + /* Non-Microsoft compiler or older Microsoft compiler */ + errno = error; + +#endif + } + + +#ifdef __cplusplus +} +#endif +#include +#else +#include +// #include +#endif + +#ifdef USE_OMP +#include +#else +int omp_get_max_threads(); +int omp_get_thread_num(); +// int omp_set_num_threads(int); +#endif + + +namespace yaed { + +typedef std::vector VP; +typedef std::vector< VP > VVP; +typedef 
unsigned int uint; + + +void _list_dir(std::string dir, std::vector& files, std::string suffixs = "", bool r = false); + +std::vector _split(const std::string& srcstr, const std::string& delimeter); +bool _startswith(const std::string& str, const std::string& start); +bool _endswith(const std::string& str, const std::string& end); +void _randperm(int n, int m, int arr[], bool sort_ = true); + +/***************** math-related functions ****************/ +float _atan2(float y, float x); +void _mean_std(std::vector& vec, float& mean, float& std); +int inline _sgn(float val) { return (0.f < val) - (val < 0.f); } +float inline _ed2(const cv::Point& A, const cv::Point& B) +{ + return float(((B.x - A.x)*(B.x - A.x) + (B.y - A.y)*(B.y - A.y))); +} +float _get_min_angle_PI(float alpha, float beta); + +double inline _tic() +{ + return (double)cv::getTickCount(); +} +double inline _toc(double tic) // ms +{ + return ((double)cv::getTickCount() - tic)*1000. / cv::getTickFrequency(); +} +inline int _isnan(double x) { return x != x; } + + +void _tag_canny(cv::InputArray image, cv::OutputArray _edges, + cv::OutputArray _sobel_x, cv::OutputArray _sobel_y, + int apertureSize, bool L2gradient, double percent_ne); + + +void _find_contours_oneway(cv::Mat1b& image, VVP& segments, int iMinLength); +void _find_contours_eight(cv::Mat1b& image, std::vector& segments, int iMinLength); +void _show_contours_eight(cv::Mat1b& image, std::vector& segments, const char* title); + +void _tag_find_contours(cv::Mat1b& image, VVP& segments, int iMinLength); +void _tag_show_contours(cv::Mat1b& image, VVP& segments, const char* title); +void _tag_show_contours(cv::Size& imsz, VVP& segments, const char* title); + +bool _SortBottomLeft2TopRight(const cv::Point& lhs, const cv::Point& rhs); +bool _SortBottomLeft2TopRight2f(const cv::Point2f& lhs, const cv::Point2f& rhs); +bool _SortTopLeft2BottomRight(const cv::Point& lhs, const cv::Point& rhs); + + +#ifndef M_PI +#define M_PI 3.14159265358979323846 +#endif +#define M_2__PI 6.28318530718 +#define M_1_2_PI 1.57079632679 + +// Elliptical struct definition +class Ellipse +{ +public: + float xc_; + float yc_; + float a_; + float b_; + float rad_; + float score_; + + // Elliptic General equations Ax^2 + Bxy + Cy^2 + Dx + Ey + 1 = 0 + float A_; + float B_; + float C_; + float D_; + float E_; + float F_; + + Ellipse() : xc_(0.f), yc_(0.f), a_(0.f), b_(0.f), rad_(0.f), score_(0.f), + A_(0.f), B_(0.f), C_(0.f), D_(0.f), E_(0.f), F_(1.f) {} + Ellipse(float xc, float yc, float a, float b, float rad, float score = 0.f) : xc_(xc), yc_(yc), a_(a), b_(b), rad_(rad), score_(score) {} + Ellipse(const Ellipse& other) : xc_(other.xc_), yc_(other.yc_), a_(other.a_), b_(other.b_), rad_(other.rad_), score_(other.score_), + A_(other.A_), B_(other.B_), C_(other.C_), D_(other.D_), E_(other.E_) {} + + void Draw(cv::Mat& img, const cv::Scalar& color, const int thickness) + { + if (IsValid()) + ellipse(img, cv::Point(cvRound(xc_), cvRound(yc_)), cv::Size(cvRound(a_), cvRound(b_)), rad_ * 180.0 / CV_PI, 0.0, 360.0, color, thickness); + } + + void Draw(cv::Mat3b& img, const int thickness) + { + cv::Scalar color(0, cvFloor(255.f * score_), 0); + if (IsValid()) + ellipse(img, cv::Point(cvRound(xc_), cvRound(yc_)), cv::Size(cvRound(a_), cvRound(b_)), rad_ * 180.0 / CV_PI, 0.0, 360.0, color, thickness); + } + + bool operator<(const Ellipse& other) const + { // use for sorting + if (score_ == other.score_) + { + float lhs_e = b_ / a_; + float rhs_e = other.b_ / other.a_; + if (lhs_e == rhs_e) + { + return false; + } 
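+ // Equal scores are tie-broken by the axis ratio b/a: the more circular ellipse sorts first.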
+ return lhs_e > rhs_e; + } + return score_ > other.score_; + } + + // Elliptic general equation Ax^2 + Bxy + Cy^2 + Dx + Ey + F = 0 + void TransferFromGeneral() { + float denominator = (B_*B_ - 4 * A_*C_); + + xc_ = (2 * C_*D_ - B_*E_) / denominator; + yc_ = (2 * A_*E_ - B_*D_) / denominator; + + float pre = 2 * (A_*E_*E_ + C_*D_*D_ - B_*D_*E_ + denominator*F_); + float lst = sqrt((A_ - C_)*(A_ - C_) + B_*B_); + + a_ = -sqrt(pre*(A_ + C_ + lst)) / denominator; + b_ = -sqrt(pre*(A_ + C_ - lst)) / denominator; + + if (B_ == 0 && A_ < C_) + rad_ = 0; + else if (B_ == 0 && A_ > C_) + rad_ = CV_PI / 2; + else + rad_ = atan((C_ - A_ - lst) / B_); + } + + // Elliptic general equation Ax^2 + Bxy + Cy^2 + Dx + Ey + F = 0 + void TransferToGeneral() { + A_ = a_*a_*sin(rad_)*sin(rad_) + b_*b_*cos(rad_)*cos(rad_); + B_ = 2.f*(b_*b_ - a_*a_)*sin(rad_)*cos(rad_); + C_ = a_*a_*cos(rad_)*cos(rad_) + b_*b_*sin(rad_)*sin(rad_); + D_ = -2.f*A_*xc_ - B_*yc_; + E_ = -B_*xc_ - 2.f*C_*yc_; + F_ = A_*xc_*xc_ + B_*xc_*yc_ + C_*yc_*yc_ - a_*a_*b_*b_; + } + + void GetRectangle(cv::Rect& rect) { + float sin_theta = sin(-rad_); + float cos_theta = cos(-rad_); + float A = a_*a_ * sin_theta * sin_theta + b_* b_ * cos_theta * cos_theta; + float B = 2 * (a_* a_ - b_ * b_) * sin_theta * cos_theta; + float C = a_* a_ * cos_theta * cos_theta + b_ * b_ * sin_theta * sin_theta; + float F = - a_ * a_ * b_ * b_; + + float y = sqrt(4 * A * F / (B * B - 4 * A * C)); + float y1 = -abs(y), y2 = abs(y); + float x = sqrt(4 * C * F / (B * B - 4 * C * A)); + float x1 = -abs(x), x2 = abs(x); + + rect = cv::Rect(int(round(xc_ + x1)), int(round(yc_ + y1)), int(round(x2 - x1)), int(round(y2 - y1))); + } + + float Perimeter() { + // return 2*CV_PI*b_ + 4*(a_ - b_); + return CV_PI*(3.f*(a_ + b_) - sqrt((3.f*a_ + b_)*(a_ + 3.f*b_))); + } + + float Area() { + return CV_PI*a_*b_; + } + + bool IsValid() { + bool nan = isnan(xc_) || isnan(yc_) || isnan(a_) || isnan(b_) || isnan(rad_); + return !nan; + } +}; + +// Data available after selection strategy.
+// They are kept in an associative array to: +// 1) avoid recomputing data when starting from same arcs +// 2) be reused in firther proprecessing +struct EllipseData +{ + bool isValid; + float ta; // arc_a center line gradient + float tb; // arc_b + float ra; // gradient of a (slope of start of chord_1 and center of chord_2) + float rb; // gradient of b (slope of center of chord_1 and last of chord_2) + cv::Point2f Ma; // arc_a center of element + cv::Point2f Mb; // arc_b + cv::Point2f Cab; // center of ellipse + std::vector Sa; // arc_a's center line of parallel chords + std::vector Sb; // arc_b's center line of parallel chords +}; + +struct EllipseThreePoint +{ + bool isValid; + cv::Point Cab; + VP ArcI; + VP ArcJ; + VP ArcK; +}; + +/********************** EllipseFitting functions **********************/ +void _ellipse_foci(float *param, float *foci); +float _ellipse_normal_angle(float x, float y, float *foci); +float _angle_diff(float a, float b); + +/*************************** CNC functions ****************************/ +float _value4SixPoints(cv::Point2f p3, cv::Point2f p2, cv::Point2f p1, cv::Point2f p4, cv::Point2f p5, cv::Point2f p6); + + +/**************** ellipse-evaluation-related functions ****************/ +void _load_ellipse_GT(const std::string& gt_file_name, std::vector & gt_ellipses, bool is_angle_radians = true); +void _load_ellipse_DT(const std::string& dt_file_name, std::vector & dt_ellipses, bool is_angle_radians = true); + +bool _ellipse_overlap(const cv::Mat1b& gt, const cv::Mat1b& dt, float th); +float _ellipse_overlap_real(const cv::Mat1b& gt, const cv::Mat1b& dt); +int _bool_count(const std::vector vb); +float _ellipse_evaluate_one(const std::vector& ell_gt, const std::vector& ell_dt, const cv::Mat3b& img); +float _ellipse_evaluate(std::vector& image_fns, std::vector& gt_fns, std::vector& dt_fns, + bool gt_angle_radians = true); + + + + +class EllipseDetector +{ + // Parameters + + // Preprocessing - Gaussian filter. See Sect [] in the paper + cv::Size szPreProcessingGaussKernel_; // size of the Gaussian filter in preprocessing step + double dPreProcessingGaussSigma_; // sigma of the Gaussian filter in the preprocessing step + + + // Selection strategy - Step 1 - Discard noisy or straight arcs. See Sect [] in the paper + int iMinEdgeLength_; // minimum edge size + float fMinOrientedRectSide_; // minumum size of the oriented bounding box containing the arc + float fMaxRectAxesRatio_; // maximum aspect ratio of the oriented bounding box containing the arc + + // Selection strategy - Step 2 - Remove according to mutual convexities. See Sect [] in the paper + float fThrArcPosition_; + + // Selection Strategy - Step 3 - Number of points considered for slope estimation when estimating the center. See Sect [] in the paper + unsigned uNs_; // Find at most Ns parallel chords. + + // Selection strategy - Step 3 - Discard pairs of arcs if their estimated center is not close enough. See Sect [] in the paper + float fMaxCenterDistance_; // maximum distance in pixel between 2 center points + float fMaxCenterDistance2_; // _fMaxCenterDistance * _fMaxCenterDistance + + // Validation - Points within a this threshold are considered to lie on the ellipse contour. See Sect [] in the paper + float fDistanceToEllipseContour_; // maximum distance between a point and the contour. See equation [] in the paper + + // Validation - Assign a score. 
See Sect [] in the paper + float fMinScore_; // minimum score to confirm a detection + float fMinReliability_; // minimum auxiliary score to confirm a detection + + double dPercentNe_; + + float fT_CNC_; + float fT_TCN_L_; // filter lines + float fT_TCN_P_; + float fThre_r_; + + // auxiliary variables + cv::Size szIm_; // input image size + + std::vector times_; // times_ is a vector containing the execution time of each step. + + int ACC_N_SIZE; // size of accumulator N = B/A + int ACC_R_SIZE; // size of accumulator R = rho = atan(K) + int ACC_A_SIZE; // size of accumulator A + + int* accN; // pointer to accumulator N + int* accR; // pointer to accumulator R + int* accA; // pointer to accumulator A + + cv::Mat1f EO_; + + VVP points_1, points_2, points_3, points_4; // vector of points, one for each convexity class + +public: + + // Constructor and Destructor + EllipseDetector(void); + ~EllipseDetector(void); + + // Detect the ellipses in the gray image + void Detect(cv::Mat3b& I, std::vector& ellipses); + void Detect(cv::Mat& I, std::vector& ellipses); + + // Draw the first iTopN ellipses on output + void DrawDetectedEllipses(cv::Mat& output, std::vector& ellipses, int iTopN = 0, int thickness = 2); + + // Set the parameters of the detector + void SetParameters(cv::Size szPreProcessingGaussKernelSize, + double dPreProcessingGaussSigma, + float fThPosition, + float fMaxCenterDistance, + int iMinEdgeLength, + float fMinOrientedRectSide, + float fDistanceToEllipseContour, + float fMinScore, + float fMinReliability, + int iNs, + double dPercentNe, + float fT_CNC, + float fT_TCN_L, + float fT_TCN_P, + float fThre_r + ); + + void SetMCD(float fMaxCenterDistance); + + // Return the execution time + double GetExecTime() { + double time_all(0); + for (size_t i = 0; i < times_.size(); i++) time_all += times_[i]; + return time_all; + } + std::vector GetTimes() { return times_; } + + float countOfFindEllipse_; + float countOfGetFastCenter_; + +private: + + // keys for hash table + static const ushort PAIR_12 = 0x00; + static const ushort PAIR_23 = 0x01; + static const ushort PAIR_34 = 0x02; + static const ushort PAIR_14 = 0x03; + + // generate keys from pair and indicse + uint inline GenerateKey(uchar pair, ushort u, ushort v); + + void PreProcessing(cv::Mat1b& I, cv::Mat1b& arcs8); + void RemoveStraightLine(VVP& segments, VVP& segments_update, int id = 0); + void PreProcessing(cv::Mat1b& I, cv::Mat1b& DP, cv::Mat1b& DN); + + void ClusterEllipses(std::vector& ellipses); + + // int FindMaxK(const std::vector& v) const; + // int FindMaxN(const std::vector& v) const; + // int FindMaxA(const std::vector& v) const; + + int FindMaxK(const int* v) const; + int FindMaxN(const int* v) const; + int FindMaxA(const int* v) const; + + float GetMedianSlope(std::vector& med, cv::Point2f& M, std::vector& slopes); + void GetFastCenter(std::vector& e1, std::vector& e2, EllipseData& data); + float GetMinAnglePI(float alpha, float beta); + + void DetectEdges13(cv::Mat1b& DP, VVP& points_1, VVP& points_3); + void DetectEdges24(cv::Mat1b& DN, VVP& points_2, VVP& points_4); + + void ArcsCheck1234(VVP& points_1, VVP& points_2, VVP& points_3, VVP& points_4); + + void FindEllipses(cv::Point2f& center, + VP& edge_i, + VP& edge_j, + VP& edge_k, + EllipseData& data_ij, + EllipseData& data_ik, + Ellipse& ell + ); + + cv::Point2f GetCenterCoordinates(EllipseData& data_ij, EllipseData& data_ik); + + void Triplets124(VVP& pi, + VVP& pj, + VVP& pk, + std::unordered_map& data, + std::vector& ellipses + ); + + void Triplets231(VVP& pi, 
+ VVP& pj, + VVP& pk, + std::unordered_map& data, + std::vector& ellipses + ); + + void Triplets342(VVP& pi, + VVP& pj, + VVP& pk, + std::unordered_map& data, + std::vector& ellipses + ); + + void Triplets413(VVP& pi, + VVP& pj, + VVP& pk, + std::unordered_map& data, + std::vector& ellipses + ); + + void Tic(unsigned idx = 0) //start + { + while (idx >= timesSign_.size()) { + timesSign_.push_back(0); + times_.push_back(.0); + } + timesSign_[idx] = 0; + timesSign_[idx]++; + times_[idx] = (double)cv::getTickCount(); + } + + void Toc(unsigned idx = 0, std::string step = "") //stop + { + assert(timesSign_[idx] == 1); + timesSign_[idx]++; + times_[idx] = ((double)cv::getTickCount() - times_[idx])*1000. / cv::getTickFrequency(); + // #ifdef DEBUG_SPEED + std::cout << "Cost time: " << times_[idx] << " ms [" << idx << "] - " << step << std::endl; + if (idx == times_.size() - 1) + std::cout << "Totally cost time: " << this->GetExecTime() << " ms" << std::endl; + // #endif + } + +private: + std::vector timesSign_; +}; + +} + +#endif // SPIRE_ELLIPSEDETECTOR_H diff --git a/algorithm/landing_det/cuda/landing_det_cuda_impl.cpp b/algorithm/landing_det/cuda/landing_det_cuda_impl.cpp new file mode 100644 index 0000000..a6fa04a --- /dev/null +++ b/algorithm/landing_det/cuda/landing_det_cuda_impl.cpp @@ -0,0 +1,160 @@ +#include "landing_det_cuda_impl.h" +#include +#include + +#define SV_MODEL_DIR "/SpireCV/models/" +#define SV_ROOT_DIR "/SpireCV/" + + +#ifdef WITH_CUDA +#include "yolov7/logging.h" +#define TRTCHECK(status) \ + do \ + { \ + auto ret = (status); \ + if (ret != 0) \ + { \ + std::cerr << "Cuda failure: " << ret << std::endl; \ + abort(); \ + } \ + } while (0) + +#define DEVICE 0 // GPU id +#define BATCH_SIZE 1 +#define MAX_IMAGE_INPUT_SIZE_THRESH 3000 * 3000 // ensure it exceed the maximum size in the input images ! 
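+// TRTCHECK wraps CUDA runtime calls, e.g. TRTCHECK(cudaMalloc(&buf, n)); it prints the CUDA error code and aborts on failure.
+// Note: the landing-marker classifier below uses a fixed 1x3x32x32 input binding, so BATCH_SIZE matches the batch dimension configured in cudaSetup().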
+#endif + + +namespace sv { + +using namespace cv; + + +#ifdef WITH_CUDA +using namespace nvinfer1; +static Logger g_nvlogger; +#endif + + +LandingMarkerDetectorCUDAImpl::LandingMarkerDetectorCUDAImpl() +{ +} + + +LandingMarkerDetectorCUDAImpl::~LandingMarkerDetectorCUDAImpl() +{ +} + + +bool LandingMarkerDetectorCUDAImpl::cudaSetup() +{ +#ifdef WITH_CUDA + std::string trt_model_fn = get_home() + SV_MODEL_DIR + "LandingMarker.engine"; + if (!is_file_exist(trt_model_fn)) + { + throw std::runtime_error("SpireCV (104) Error loading the LandingMarker TensorRT model (File Not Exist)"); + } + char *trt_model_stream{nullptr}; + size_t trt_model_size{0}; + try + { + std::ifstream file(trt_model_fn, std::ios::binary); + file.seekg(0, file.end); + trt_model_size = file.tellg(); + file.seekg(0, file.beg); + trt_model_stream = new char[trt_model_size]; + assert(trt_model_stream); + file.read(trt_model_stream, trt_model_size); + file.close(); + } + catch (const std::runtime_error &e) + { + throw std::runtime_error("SpireCV (104) Error loading the TensorRT model!"); + } + + // TensorRT + IRuntime *runtime = nvinfer1::createInferRuntime(g_nvlogger); + assert(runtime != nullptr); + ICudaEngine *p_cu_engine = runtime->deserializeCudaEngine(trt_model_stream, trt_model_size); + assert(p_cu_engine != nullptr); + this->_trt_context = p_cu_engine->createExecutionContext(); + assert(this->_trt_context != nullptr); + + delete[] trt_model_stream; + const ICudaEngine &cu_engine = this->_trt_context->getEngine(); + assert(cu_engine.getNbBindings() == 2); + + this->_input_index = cu_engine.getBindingIndex("input"); + this->_output_index = cu_engine.getBindingIndex("output"); + TRTCHECK(cudaMalloc(&_p_buffers[this->_input_index], 1 * 3 * 32 * 32 * sizeof(float))); + TRTCHECK(cudaMalloc(&_p_buffers[this->_output_index], 1 * 11 * sizeof(float))); + TRTCHECK(cudaStreamCreate(&_cu_stream)); + + auto input_dims = nvinfer1::Dims4{1, 3, 32, 32}; + this->_trt_context->setBindingDimensions(this->_input_index, input_dims); + + this->_p_data = new float[1 * 3 * 32 * 32]; + this->_p_prob = new float[1 * 11]; + // Input + TRTCHECK(cudaMemcpyAsync(_p_buffers[this->_input_index], this->_p_data, 1 * 3 * 32 * 32 * sizeof(float), cudaMemcpyHostToDevice, this->_cu_stream)); + // this->_trt_context->enqueue(1, _p_buffers, this->_cu_stream, nullptr); + this->_trt_context->enqueueV2(_p_buffers, this->_cu_stream, nullptr); + // Output + TRTCHECK(cudaMemcpyAsync(this->_p_prob, _p_buffers[this->_output_index], 1 * 11 * sizeof(float), cudaMemcpyDeviceToHost, this->_cu_stream)); + cudaStreamSynchronize(this->_cu_stream); + return true; +#endif + return false; +} + + +void LandingMarkerDetectorCUDAImpl::cudaRoiCNN( + std::vector& input_rois_, + std::vector& output_labels_ +) +{ +#ifdef WITH_CUDA + output_labels_.clear(); + for (int i=0; i_p_data[col + row * 32] = ((float)uc_pixel[0] - 136.20f) / 44.77f; + this->_p_data[col + row * 32 + 32 * 32] = ((float)uc_pixel[1] - 141.50f) / 44.20f; + this->_p_data[col + row * 32 + 32 * 32 * 2] = ((float)uc_pixel[2] - 145.41f) / 44.30f; + uc_pixel += 3; + } + } + + // Input + TRTCHECK(cudaMemcpyAsync(_p_buffers[this->_input_index], this->_p_data, 1 * 3 * 32 * 32 * sizeof(float), cudaMemcpyHostToDevice, this->_cu_stream)); + // this->_trt_context->enqueue(1, _p_buffers, this->_cu_stream, nullptr); + this->_trt_context->enqueueV2(_p_buffers, this->_cu_stream, nullptr); + // Output + TRTCHECK(cudaMemcpyAsync(this->_p_prob, _p_buffers[this->_output_index], 1 * 11 * sizeof(float), cudaMemcpyDeviceToHost, 
this->_cu_stream)); + cudaStreamSynchronize(this->_cu_stream); + + // Find max index + double max = 0; + int label = 0; + for (int i = 0; i < 11; ++i) + { + if (max < this->_p_prob[i]) + { + max = this->_p_prob[i]; + label = i; + } + } + output_labels_.push_back(label); + } +#endif +} + + +} + diff --git a/algorithm/landing_det/cuda/landing_det_cuda_impl.h b/algorithm/landing_det/cuda/landing_det_cuda_impl.h new file mode 100644 index 0000000..69f8b69 --- /dev/null +++ b/algorithm/landing_det/cuda/landing_det_cuda_impl.h @@ -0,0 +1,48 @@ +#ifndef __SV_LANDING_DET_CUDA__ +#define __SV_LANDING_DET_CUDA__ + +#include "sv_core.h" +#include +#include +#include +#include +#include + + + +#ifdef WITH_CUDA +#include +#include +#endif + + + +namespace sv { + + +class LandingMarkerDetectorCUDAImpl +{ +public: + LandingMarkerDetectorCUDAImpl(); + ~LandingMarkerDetectorCUDAImpl(); + + bool cudaSetup(); + void cudaRoiCNN( + std::vector& input_rois_, + std::vector& output_labels_ + ); + +#ifdef WITH_CUDA + float *_p_data; + float *_p_prob; + nvinfer1::IExecutionContext *_trt_context; + int _input_index; + int _output_index; + void *_p_buffers[2]; + cudaStream_t _cu_stream; +#endif +}; + + +} +#endif diff --git a/algorithm/landing_det/sv_landing_det.cpp b/algorithm/landing_det/sv_landing_det.cpp new file mode 100644 index 0000000..dd0f732 --- /dev/null +++ b/algorithm/landing_det/sv_landing_det.cpp @@ -0,0 +1,48 @@ +#include "sv_landing_det.h" +#include +#include +#ifdef WITH_CUDA +#include +#include +#include "landing_det_cuda_impl.h" +#endif + + +namespace sv { + + +LandingMarkerDetector::LandingMarkerDetector() +{ + this->_cuda_impl = new LandingMarkerDetectorCUDAImpl; +} +LandingMarkerDetector::~LandingMarkerDetector() +{ +} + +bool LandingMarkerDetector::setupImpl() +{ +#ifdef WITH_CUDA + return this->_cuda_impl->cudaSetup(); +#endif + return false; +} + +void LandingMarkerDetector::roiCNN( + std::vector& input_rois_, + std::vector& output_labels_ +) +{ +#ifdef WITH_CUDA + this->_cuda_impl->cudaRoiCNN( + input_rois_, + output_labels_ + ); +#endif +} + + + + + +} + diff --git a/algorithm/sv_algorithm_base.cpp b/algorithm/sv_algorithm_base.cpp new file mode 100644 index 0000000..f99142f --- /dev/null +++ b/algorithm/sv_algorithm_base.cpp @@ -0,0 +1,1142 @@ +#include "sv_algorithm_base.h" +#include +#include +#include +#include +#include "gason.h" +#include +#include "sv_util.h" +#include "ellipse_detector.h" + +#define SV_MODEL_DIR "/SpireCV/models/" +#define SV_ROOT_DIR "/SpireCV/" + + +namespace sv { + +using namespace cv; +using namespace cv::dnn; + + +void _cameraMatrix2Fov(cv::Mat camera_matrix_, int img_w_, int img_h_, double& fov_x_, double& fov_y_) +{ + fov_x_ = 2 * atan(img_w_ / 2. / camera_matrix_.at(0, 0)) * SV_RAD2DEG; + fov_y_ = 2 * atan(img_h_ / 2. 
/ camera_matrix_.at(1, 1)) * SV_RAD2DEG; +} + + + +CameraAlgorithm::CameraAlgorithm() +{ + // this->_value = NULL; + // this->_allocator = NULL; + this->_t0 = std::chrono::system_clock::now(); + + this->alg_params_fn = _get_home() + SV_ROOT_DIR + "sv_algorithm_params.json"; + // std::cout << "CameraAlgorithm->alg_params_fn: " << this->alg_params_fn << std::endl; + // if (_is_file_exist(params_fn)) + // this->loadAlgorithmParams(params_fn); +} +CameraAlgorithm::~CameraAlgorithm() +{ + // if (_value) delete _value; + // if (_allocator) delete _allocator; +} + +void CameraAlgorithm::loadCameraParams(std::string yaml_fn_) +{ + cv::FileStorage fs(yaml_fn_, cv::FileStorage::READ); + if (!fs.isOpened()) + { + throw std::runtime_error("SpireCV (104) Camera calib file NOT exist!"); + } + fs["camera_matrix"] >> this->camera_matrix; + fs["distortion_coefficients"] >> this->distortion; + fs["image_width"] >> this->image_width; + fs["image_height"] >> this->image_height; + + if (this->camera_matrix.rows != 3 || this->camera_matrix.cols != 3 || + this->distortion.rows != 1 || this->distortion.cols != 5 || + this->image_width == 0 || this->image_height == 0) + { + throw std::runtime_error("SpireCV (106) Camera parameters reading ERROR!"); + } + + _cameraMatrix2Fov(this->camera_matrix, this->image_width, this->image_height, this->fov_x, this->fov_y); + // std::cout << this->fov_x << ", " << this->fov_y << std::endl; +} + + +void CameraAlgorithm::loadAlgorithmParams(std::string json_fn_) +{ + this->alg_params_fn = json_fn_; +} + + +ArucoDetector::ArucoDetector() +{ + _params_loaded = false; + _dictionary = nullptr; +} + + +void ArucoDetector::_load() +{ + JsonValue all_value; + JsonAllocator allocator; + std::cout << "Load: [" << this->alg_params_fn << "]" << std::endl; + _load_all_json(this->alg_params_fn, all_value, allocator); + + JsonValue aruco_params_value; + _parser_algorithm_params("ArucoDetector", all_value, aruco_params_value); + + _dictionary_id = 10; + // _detector_params = aruco::DetectorParameters::create(); + _detector_params = new aruco::DetectorParameters; + for (auto i : aruco_params_value) { + if ("_dictionary_id" == std::string(i->key)) { + _dictionary_id = i->value.toNumber(); + } + else if ("adaptiveThreshConstant" == std::string(i->key)) { + // std::cout << "adaptiveThreshConstant (old, new): " << _detector_params->adaptiveThreshConstant << ", " << i->value.toNumber() << std::endl; + _detector_params->adaptiveThreshConstant = i->value.toNumber(); + } + else if ("adaptiveThreshWinSizeMax" == std::string(i->key)) { + // std::cout << "adaptiveThreshWinSizeMax (old, new): " << _detector_params->adaptiveThreshWinSizeMax << ", " << i->value.toNumber() << std::endl; + _detector_params->adaptiveThreshWinSizeMax = i->value.toNumber(); + } + else if ("adaptiveThreshWinSizeMin" == std::string(i->key)) { + // std::cout << "adaptiveThreshWinSizeMin (old, new): " << _detector_params->adaptiveThreshWinSizeMin << ", " << i->value.toNumber() << std::endl; + _detector_params->adaptiveThreshWinSizeMin = i->value.toNumber(); + } + else if ("adaptiveThreshWinSizeStep" == std::string(i->key)) { + // std::cout << "adaptiveThreshWinSizeStep (old, new): " << _detector_params->adaptiveThreshWinSizeStep << ", " << i->value.toNumber() << std::endl; + _detector_params->adaptiveThreshWinSizeStep = i->value.toNumber(); + } + else if ("aprilTagCriticalRad" == std::string(i->key)) { + // std::cout << "aprilTagCriticalRad (old, new): " << _detector_params->aprilTagCriticalRad << ", " << i->value.toNumber() << 
std::endl; + _detector_params->aprilTagCriticalRad = i->value.toNumber(); + } + else if ("aprilTagDeglitch" == std::string(i->key)) { + // std::cout << "aprilTagDeglitch (old, new): " << _detector_params->aprilTagDeglitch << ", " << i->value.toNumber() << std::endl; + _detector_params->aprilTagDeglitch = i->value.toNumber(); + } + else if ("aprilTagMaxLineFitMse" == std::string(i->key)) { + // std::cout << "aprilTagMaxLineFitMse (old, new): " << _detector_params->aprilTagMaxLineFitMse << ", " << i->value.toNumber() << std::endl; + _detector_params->aprilTagMaxLineFitMse = i->value.toNumber(); + } + else if ("aprilTagMaxNmaxima" == std::string(i->key)) { + // std::cout << "aprilTagMaxNmaxima (old, new): " << _detector_params->aprilTagMaxNmaxima << ", " << i->value.toNumber() << std::endl; + _detector_params->aprilTagMaxNmaxima = i->value.toNumber(); + } + else if ("aprilTagMinClusterPixels" == std::string(i->key)) { + // std::cout << "aprilTagMinClusterPixels (old, new): " << _detector_params->aprilTagMinClusterPixels << ", " << i->value.toNumber() << std::endl; + _detector_params->aprilTagMinClusterPixels = i->value.toNumber(); + } + else if ("aprilTagMinWhiteBlackDiff" == std::string(i->key)) { + // std::cout << "aprilTagMinWhiteBlackDiff (old, new): " << _detector_params->aprilTagMinWhiteBlackDiff << ", " << i->value.toNumber() << std::endl; + _detector_params->aprilTagMinWhiteBlackDiff = i->value.toNumber(); + } + else if ("aprilTagQuadDecimate" == std::string(i->key)) { + // std::cout << "aprilTagQuadDecimate (old, new): " << _detector_params->aprilTagQuadDecimate << ", " << i->value.toNumber() << std::endl; + _detector_params->aprilTagQuadDecimate = i->value.toNumber(); + } + else if ("aprilTagQuadSigma" == std::string(i->key)) { + // std::cout << "aprilTagQuadSigma (old, new): " << _detector_params->aprilTagQuadSigma << ", " << i->value.toNumber() << std::endl; + _detector_params->aprilTagQuadSigma = i->value.toNumber(); + } + else if ("cornerRefinementMaxIterations" == std::string(i->key)) { + // std::cout << "cornerRefinementMaxIterations (old, new): " << _detector_params->cornerRefinementMaxIterations << ", " << i->value.toNumber() << std::endl; + _detector_params->cornerRefinementMaxIterations = i->value.toNumber(); + } + else if ("cornerRefinementMethod" == std::string(i->key)) { + // std::cout << "cornerRefinementMethod (old, new): " << _detector_params->cornerRefinementMethod << ", " << i->value.toNumber() << std::endl; + // _detector_params->cornerRefinementMethod = i->value.toNumber(); + int method = (int) i->value.toNumber(); + if (method == 1) + { + _detector_params->cornerRefinementMethod = cv::aruco::CornerRefineMethod::CORNER_REFINE_SUBPIX; + } + else if (method == 2) + { + _detector_params->cornerRefinementMethod = cv::aruco::CornerRefineMethod::CORNER_REFINE_CONTOUR; + } + else if (method == 3) + { + _detector_params->cornerRefinementMethod = cv::aruco::CornerRefineMethod::CORNER_REFINE_APRILTAG; + } + else + { + _detector_params->cornerRefinementMethod = cv::aruco::CornerRefineMethod::CORNER_REFINE_NONE; + } + } + else if ("cornerRefinementMinAccuracy" == std::string(i->key)) { + // std::cout << "cornerRefinementMinAccuracy (old, new): " << _detector_params->cornerRefinementMinAccuracy << ", " << i->value.toNumber() << std::endl; + _detector_params->cornerRefinementMinAccuracy = i->value.toNumber(); + } + else if ("cornerRefinementWinSize" == std::string(i->key)) { + // std::cout << "cornerRefinementWinSize (old, new): " << _detector_params->cornerRefinementWinSize 
<< ", " << i->value.toNumber() << std::endl; + _detector_params->cornerRefinementWinSize = i->value.toNumber(); + } + else if ("detectInvertedMarker" == std::string(i->key)) { + bool json_tf = false; + if (i->value.getTag() == JSON_TRUE) json_tf = true; + // std::cout << "detectInvertedMarker (old, new): " << _detector_params->detectInvertedMarker << ", " << json_tf << std::endl; + _detector_params->detectInvertedMarker = json_tf; + } + else if ("errorCorrectionRate" == std::string(i->key)) { + // std::cout << "errorCorrectionRate (old, new): " << _detector_params->errorCorrectionRate << ", " << i->value.toNumber() << std::endl; + _detector_params->errorCorrectionRate = i->value.toNumber(); + } + else if ("markerBorderBits" == std::string(i->key)) { + // std::cout << "markerBorderBits (old, new): " << _detector_params->markerBorderBits << ", " << i->value.toNumber() << std::endl; + _detector_params->markerBorderBits = i->value.toNumber(); + } + else if ("maxErroneousBitsInBorderRate" == std::string(i->key)) { + // std::cout << "maxErroneousBitsInBorderRate (old, new): " << _detector_params->maxErroneousBitsInBorderRate << ", " << i->value.toNumber() << std::endl; + _detector_params->maxErroneousBitsInBorderRate = i->value.toNumber(); + } + else if ("maxMarkerPerimeterRate" == std::string(i->key)) { + // std::cout << "maxMarkerPerimeterRate (old, new): " << _detector_params->maxMarkerPerimeterRate << ", " << i->value.toNumber() << std::endl; + _detector_params->maxMarkerPerimeterRate = i->value.toNumber(); + } + else if ("minCornerDistanceRate" == std::string(i->key)) { + // std::cout << "minCornerDistanceRate (old, new): " << _detector_params->minCornerDistanceRate << ", " << i->value.toNumber() << std::endl; + _detector_params->minCornerDistanceRate = i->value.toNumber(); + } + else if ("minDistanceToBorder" == std::string(i->key)) { + // std::cout << "minDistanceToBorder (old, new): " << _detector_params->minDistanceToBorder << ", " << i->value.toNumber() << std::endl; + _detector_params->minDistanceToBorder = i->value.toNumber(); + } + else if ("minMarkerDistanceRate" == std::string(i->key)) { + // std::cout << "minMarkerDistanceRate (old, new): " << _detector_params->minMarkerDistanceRate << ", " << i->value.toNumber() << std::endl; + _detector_params->minMarkerDistanceRate = i->value.toNumber(); + } + else if ("minMarkerLengthRatioOriginalImg" == std::string(i->key)) { + // std::cout << "minMarkerLengthRatioOriginalImg (old, new): " << _detector_params->minMarkerLengthRatioOriginalImg << ", " << i->value.toNumber() << std::endl; + _detector_params->minMarkerLengthRatioOriginalImg = i->value.toNumber(); + } + else if ("minMarkerPerimeterRate" == std::string(i->key)) { + // std::cout << "minMarkerPerimeterRate (old, new): " << _detector_params->minMarkerPerimeterRate << ", " << i->value.toNumber() << std::endl; + _detector_params->minMarkerPerimeterRate = i->value.toNumber(); + } + else if ("minOtsuStdDev" == std::string(i->key)) { + // std::cout << "minOtsuStdDev (old, new): " << _detector_params->minOtsuStdDev << ", " << i->value.toNumber() << std::endl; + _detector_params->minOtsuStdDev = i->value.toNumber(); + } + else if ("minSideLengthCanonicalImg" == std::string(i->key)) { + // std::cout << "minSideLengthCanonicalImg (old, new): " << _detector_params->minSideLengthCanonicalImg << ", " << i->value.toNumber() << std::endl; + _detector_params->minSideLengthCanonicalImg = i->value.toNumber(); + } + else if ("perspectiveRemoveIgnoredMarginPerCell" == std::string(i->key)) { + // 
std::cout << "perspectiveRemoveIgnoredMarginPerCell (old, new): " << _detector_params->perspectiveRemoveIgnoredMarginPerCell << ", " << i->value.toNumber() << std::endl; + _detector_params->perspectiveRemoveIgnoredMarginPerCell = i->value.toNumber(); + } + else if ("perspectiveRemovePixelPerCell" == std::string(i->key)) { + // std::cout << "perspectiveRemovePixelPerCell (old, new): " << _detector_params->perspectiveRemovePixelPerCell << ", " << i->value.toNumber() << std::endl; + _detector_params->perspectiveRemovePixelPerCell = i->value.toNumber(); + } + else if ("polygonalApproxAccuracyRate" == std::string(i->key)) { + // std::cout << "polygonalApproxAccuracyRate (old, new): " << _detector_params->polygonalApproxAccuracyRate << ", " << i->value.toNumber() << std::endl; + _detector_params->polygonalApproxAccuracyRate = i->value.toNumber(); + } + else if ("useAruco3Detection" == std::string(i->key)) { + bool json_tf = false; + if (i->value.getTag() == JSON_TRUE) json_tf = true; + // std::cout << "useAruco3Detection (old, new): " << _detector_params->useAruco3Detection << ", " << json_tf << std::endl; + _detector_params->useAruco3Detection = json_tf; + } + else if ("markerIds" == std::string(i->key) && i->value.getTag() == JSON_ARRAY) { + int jcnt = 0; + for (auto j : i->value) { + if (jcnt == 0 && j->value.toNumber() == -1) { + _ids_need.push_back(-1); + break; + } + else { + _ids_need.push_back(j->value.toNumber()); + } + } + } + else if ("markerLengths" == std::string(i->key) && i->value.getTag() == JSON_ARRAY) { + for (auto j : i->value) { + if (_ids_need.size() > 0 && _ids_need[0] == -1) { + _lengths_need.push_back(j->value.toNumber()); + break; + } + else { + _lengths_need.push_back(j->value.toNumber()); + } + } + } + } + + if (_ids_need.size() == 0) _ids_need.push_back(-1); + if (_lengths_need.size() != _ids_need.size()) { + throw std::runtime_error("SpireCV (106) Parameter markerIds.length != markerLengths.length!"); + } + + // for (int id : _ids_need) + // std::cout << "_ids_need: " << id << std::endl; + // for (double l : _lengths_need) + // std::cout << "_lengths_need: " << l << std::endl; +} + + +void ArucoDetector::detect(cv::Mat img_, TargetsInFrame& tgts_) +{ + if (!_params_loaded) + { + this->_load(); + _params_loaded = true; + } + if (img_.cols != this->image_width || img_.rows != this->image_height) + { + char msg[256]; + sprintf(msg, "SpireCV (106) Calib camera SIZE(%d) != Input image SIZE(%d)!", this->image_width, img_.cols); + throw std::runtime_error(msg); + } + // std::cout << "_dictionary_id: " << _dictionary_id << std::endl; + // Ptr dictionary = aruco::getPredefinedDictionary(aruco::PredefinedDictionaryType(_dictionary_id)); + if (this->_dictionary == nullptr) + { + this->_dictionary = new aruco::Dictionary; + *(this->_dictionary) = aruco::getPredefinedDictionary(aruco::PredefinedDictionaryType(_dictionary_id)); + } + + std::vector ids, ids_final; + std::vector > corners, corners_final, rejected; + std::vector rvecs, tvecs; + + // detect markers and estimate pose + aruco::detectMarkers(img_, this->_dictionary, corners, ids, _detector_params, rejected); + + + if (ids.size() > 0) + { + if (_ids_need[0] == -1) + { + // std::cout << this->camera_matrix << std::endl; + aruco::estimatePoseSingleMarkers(corners, _lengths_need[0], this->camera_matrix, this->distortion, rvecs, tvecs); + ids_final = ids; + corners_final = corners; + } + else + { + for (int i=0; i<_ids_need.size(); i++) + { + int id_need = _ids_need[i]; + double length_need = _lengths_need[i]; + std::vector 
t_rvecs, t_tvecs; + std::vector > t_corners; + for (int j=0; j 0) + { + aruco::estimatePoseSingleMarkers(t_corners, length_need, this->camera_matrix, this->distortion, t_rvecs, t_tvecs); + for (auto t_rvec : t_rvecs) + rvecs.push_back(t_rvec); + for (auto t_tvec : t_tvecs) + tvecs.push_back(t_tvec); + } + } + } + } + + // aruco::drawDetectedMarkers(img_, corners_final, ids_final); + tgts_.setSize(img_.cols, img_.rows); + + // tgts_.fov_x = this->fov_x; + // tgts_.fov_y = this->fov_y; + tgts_.setFOV(this->fov_x, this->fov_y); + auto t1 = std::chrono::system_clock::now(); + tgts_.setFPS(1000.0 / std::chrono::duration_cast(t1 - this->_t0).count()); + this->_t0 = std::chrono::system_clock::now(); + tgts_.setTimeNow(); + + if (ids_final.size() > 0) + { + for (int i=0; icamera_matrix); + tgts_.targets.push_back(tgt); + + // Box b; + // tgt.getBox(b); + // cv::circle(img_, cv::Point(int(corners_final[i][0].x), int(corners_final[i][0].y)), 5, cv::Scalar(0,0,255), 2); + // cv::circle(img_, cv::Point(int(corners_final[i][1].x), int(corners_final[i][1].y)), 5, cv::Scalar(255,0,0), 2); + // cv::rectangle(img_, cv::Rect(b.x1, b.y1, b.x2-b.x1+1, b.y2-b.y1+1), cv::Scalar(0,0,255), 1, 1, 0); + } + } + + tgts_.type = MissionType::ARUCO_DET; + + // imshow("img", img_); + // waitKey(10); +} + + + +EllipseDetector::EllipseDetector() +{ + this->_ed = NULL; + this->_max_center_distance_ratio = 0.05f; + this->_params_loaded = false; +} +EllipseDetector::~EllipseDetector() +{ + if (_ed) { delete _ed; _ed = NULL; } +} + +void EllipseDetector::detect(cv::Mat img_, TargetsInFrame& tgts_) +{ + if (!_params_loaded) + { + this->_load(); + _params_loaded = true; + } + + float fMaxCenterDistance = sqrt(float(img_.cols*img_.cols + img_.rows*img_.rows)) * this->_max_center_distance_ratio; + _ed->SetMCD(fMaxCenterDistance); + std::vector ellsCned; + _ed->Detect(img_, ellsCned); + + tgts_.setSize(img_.cols, img_.rows); + tgts_.setFOV(this->fov_x, this->fov_y); + auto t1 = std::chrono::system_clock::now(); + tgts_.setFPS(1000.0 / std::chrono::duration_cast(t1 - this->_t0).count()); + this->_t0 = std::chrono::system_clock::now(); + tgts_.setTimeNow(); + + for (yaed::Ellipse ell : ellsCned) + { + Target tgt; + tgt.setEllipse(ell.xc_, ell.yc_, ell.a_, ell.b_, ell.rad_, ell.score_, tgts_.width, tgts_.height, this->camera_matrix, this->_radius_in_meter); + tgts_.targets.push_back(tgt); + } + + tgts_.type = MissionType::ELLIPSE_DET; +} + +LandingMarkerDetectorBase::LandingMarkerDetectorBase() +{ + // this->_params_loaded = false; + // std::string params_fn = _get_home() + SV_ROOT_DIR + "sv_algorithm_params.json"; + // if (_is_file_exist(params_fn)) + // this->loadAlgorithmParams(params_fn); + setupImpl(); +} + + +LandingMarkerDetectorBase::~LandingMarkerDetectorBase() +{ + +} + +bool LandingMarkerDetectorBase::isParamsLoaded() +{ + return this->_params_loaded; +} +int LandingMarkerDetectorBase::getMaxCandidates() +{ + return this->_max_candidates; +} +std::vector LandingMarkerDetectorBase::getLabelsNeed() +{ + return this->_labels_need; +} + + +void LandingMarkerDetectorBase::detect(cv::Mat img_, TargetsInFrame& tgts_) +{ + if (!_params_loaded) + { + this->_load(); + this->_loadLabels(); + _params_loaded = true; + } + + float fMaxCenterDistance = sqrt(float(img_.cols*img_.cols + img_.rows*img_.rows)) * this->_max_center_distance_ratio; + _ed->SetMCD(fMaxCenterDistance); + std::vector ellsCned; + _ed->Detect(img_, ellsCned); + + tgts_.setSize(img_.cols, img_.rows); + tgts_.setFOV(this->fov_x, this->fov_y); + auto t1 = 
std::chrono::system_clock::now(); + tgts_.setFPS(1000.0 / std::chrono::duration_cast(t1 - this->_t0).count()); + this->_t0 = std::chrono::system_clock::now(); + tgts_.setTimeNow(); + + static std::vector s_label2str = {"neg", "h", "x", "1", "2", "3", "4", "5", "6", "7", "8"}; + int cand_cnt = 0; + std::vector input_rois; + while (cand_cnt < this->_max_candidates && ellsCned.size() > cand_cnt) + { + yaed::Ellipse e = ellsCned[cand_cnt++]; + + cv::Rect rect; + e.GetRectangle(rect); + int x1 = rect.x; + int y1 = rect.y; + int x2 = rect.x + rect.width; + int y2 = rect.y + rect.height; + if (x1 < 0) x1 = 0; + if (y1 < 0) y1 = 0; + if (x2 > img_.cols - 1) x2 = img_.cols - 1; + if (y2 > img_.rows - 1) y2 = img_.rows - 1; + if (x2 - x1 < 5 || y2 - y1 < 5) continue; + rect.x = x1; + rect.y = y1; + rect.width = x2 - x1; + rect.height = y2 - y1; + + cv::Mat e_roi = img_(rect); + cv::resize(e_roi, e_roi, cv::Size(32, 32)); + + input_rois.push_back(e_roi); + } + + std::vector output_labels; + roiCNN(input_rois, output_labels); + if (input_rois.size() != output_labels.size()) + throw std::runtime_error("SpireCV (106) input_rois.size() != output_labels.size()"); + + for (int i=0; i_labels_need[j] == s_label2str[label]) + { + need = true; + } + } + if (!need) label = 0; + + yaed::Ellipse e = ellsCned[i]; + if (label > 0) + { + Target tgt; + tgt.setEllipse(e.xc_, e.yc_, e.a_, e.b_, e.rad_, e.score_, tgts_.width, tgts_.height, this->camera_matrix, this->_radius_in_meter); + tgt.setCategory(s_label2str[label], label); + tgts_.targets.push_back(tgt); + } + } + + tgts_.type = MissionType::LANDMARK_DET; +} +bool LandingMarkerDetectorBase::setupImpl() +{ + return false; +} +void LandingMarkerDetectorBase::roiCNN(std::vector& input_rois_, std::vector& output_labels_) +{ + +} + + +void LandingMarkerDetectorBase::_loadLabels() +{ + JsonValue all_value; + JsonAllocator allocator; + _load_all_json(this->alg_params_fn, all_value, allocator); + + JsonValue landing_params_value; + _parser_algorithm_params("LandingMarkerDetector", all_value, landing_params_value); + + for (auto i : landing_params_value) { + if ("labels" == std::string(i->key) && i->value.getTag() == JSON_ARRAY) { + for (auto j : i->value) { + this->_labels_need.push_back(j->value.toString()); + } + } + else if ("maxCandidates" == std::string(i->key)) { + this->_max_candidates = i->value.toNumber(); + // std::cout << "maxCandidates: " << this->_max_candidates << std::endl; + } + } + setupImpl(); +} + + +void EllipseDetector::detectAllInDirectory(std::string input_img_dir_, std::string output_json_dir_) +{ + if (!_params_loaded) + { + this->_load(); + _params_loaded = true; + } + + std::vector files; + yaed::_list_dir(input_img_dir_, files, "jpg"); + + for (size_t i=0; i 9 && files[i].substr(8, 1) == "_") + { + label_str = files[i].substr(9, 1); + std::cout << label_str << std::endl; + } + + cv::Mat resultImage = img.clone(); + + float fMaxCenterDistance = sqrt(float(img.cols*img.cols + img.rows*img.rows)) * this->_max_center_distance_ratio; + _ed->SetMCD(fMaxCenterDistance); + std::vector ellsCned; + _ed->Detect(img, ellsCned); + + std::ofstream ofs(output_json_dir_ + "/" + files[i] + ".json"); + std::string inst_str = ""; + int j = 0; + char buf[1024*32]; + for (yaed::Ellipse e : ellsCned) + { + cv::Rect rect; + e.GetRectangle(rect); + cv::rectangle(resultImage, rect, (0,0,255), 1); + + sprintf(buf, "{\"category_name\":\"%s\",\"bbox\":[%d,%d,%d,%d],\"area\":%d,\"score\":%.3f}", label_str.c_str(), rect.x, rect.y, rect.width, rect.height, 
rect.width*rect.height, e.score_); + inst_str += std::string(buf); + if (j < ellsCned.size() - 1) + inst_str += ","; + // ofs << e.xc_ << "," << e.yc_ << "," << e.a_ << "," << e.b_ << "," << e.rad_ << "," << e.score_ << std::endl; + j++; + } + + sprintf(buf, "{\"file_name\":\"%s\",\"height\":%d,\"width\":%d,\"annos\":[%s]}", files[i].c_str(), img.rows, img.cols, inst_str.c_str()); + ofs << buf << std::endl; + ofs.close(); + + cv::imshow("img", resultImage); + cv::waitKey(100); + } +} + +void EllipseDetector::_load() +{ + JsonValue all_value; + JsonAllocator allocator; + _load_all_json(this->alg_params_fn, all_value, allocator); + + JsonValue ell_params_value; + _parser_algorithm_params("EllipseDetector", all_value, ell_params_value); + + cv::Size szPreProcessingGaussKernel; + double dPreProcessingGaussSigma; + float fThPosition; + float fMaxCenterDistance; + int iMinEdgeLength; + float fMinOrientedRectSide; + float fDistanceToEllipseContour; + float fMinScore; + float fMinReliability; + int iNs; + double dPercentNe; + float fT_CNC; + float fT_TCN_L; + float fT_TCN_P; + float fThre_R; + + for (auto i : ell_params_value) { + if ("preProcessingGaussKernel" == std::string(i->key)) { + int sigma = i->value.toNumber(); + szPreProcessingGaussKernel = cv::Size(sigma, sigma); + // std::cout << "preProcessingGaussKernel: " << sigma << std::endl; + } + else if ("preProcessingGaussSigma" == std::string(i->key)) { + dPreProcessingGaussSigma = i->value.toNumber(); + // std::cout << "preProcessingGaussSigma: " << dPreProcessingGaussSigma << std::endl; + } + else if ("thPosition" == std::string(i->key)) { + fThPosition = i->value.toNumber(); + // std::cout << "thPosition: " << fThPosition << std::endl; + } + else if ("maxCenterDistance" == std::string(i->key)) { + this->_max_center_distance_ratio = i->value.toNumber(); + fMaxCenterDistance = sqrt(float(this->image_width*this->image_width + this->image_height*this->image_height)) * this->_max_center_distance_ratio; + // std::cout << "maxCenterDistance: " << this->_max_center_distance_ratio << std::endl; + } + else if ("minEdgeLength" == std::string(i->key)) { + iMinEdgeLength = i->value.toNumber(); + // std::cout << "minEdgeLength: " << iMinEdgeLength << std::endl; + } + else if ("minOrientedRectSide" == std::string(i->key)) { + fMinOrientedRectSide = i->value.toNumber(); + // std::cout << "minOrientedRectSide: " << fMinOrientedRectSide << std::endl; + } + else if ("distanceToEllipseContour" == std::string(i->key)) { + fDistanceToEllipseContour = i->value.toNumber(); + // std::cout << "distanceToEllipseContour: " << fDistanceToEllipseContour << std::endl; + } + else if ("minScore" == std::string(i->key)) { + fMinScore = i->value.toNumber(); + // std::cout << "minScore: " << fMinScore << std::endl; + } + else if ("minReliability" == std::string(i->key)) { + fMinReliability = i->value.toNumber(); + // std::cout << "minReliability: " << fMinReliability << std::endl; + } + else if ("ns" == std::string(i->key)) { + iNs = i->value.toNumber(); + // std::cout << "ns: " << iNs << std::endl; + } + else if ("percentNe" == std::string(i->key)) { + dPercentNe = i->value.toNumber(); + // std::cout << "percentNe: " << dPercentNe << std::endl; + } + else if ("T_CNC" == std::string(i->key)) { + fT_CNC = i->value.toNumber(); + // std::cout << "T_CNC: " << fT_CNC << std::endl; + } + else if ("T_TCN_L" == std::string(i->key)) { + fT_TCN_L = i->value.toNumber(); + // std::cout << "T_TCN_L: " << fT_TCN_L << std::endl; + } + else if ("T_TCN_P" == std::string(i->key)) { + 
fT_TCN_P = i->value.toNumber(); + // std::cout << "T_TCN_P: " << fT_TCN_P << std::endl; + } + else if ("thRadius" == std::string(i->key)) { + fThre_R = i->value.toNumber(); + // std::cout << "thRadius: " << fThre_R << std::endl; + } + else if ("radiusInMeter" == std::string(i->key)) { + this->_radius_in_meter = i->value.toNumber(); + // std::cout << "radiusInMeter: " << this->_radius_in_meter << std::endl; + } + } + + if (_ed) { delete _ed; _ed = NULL; } + _ed = new yaed::EllipseDetector; + _ed->SetParameters(szPreProcessingGaussKernel, dPreProcessingGaussSigma, fThPosition, fMaxCenterDistance, iMinEdgeLength, fMinOrientedRectSide, fDistanceToEllipseContour, fMinScore, fMinReliability, iNs, dPercentNe, fT_CNC, fT_TCN_L, fT_TCN_P, fThre_R); +} + + +SingleObjectTrackerBase::SingleObjectTrackerBase() +{ + this->_params_loaded = false; +} +SingleObjectTrackerBase::~SingleObjectTrackerBase() +{ + +} +bool SingleObjectTrackerBase::isParamsLoaded() +{ + return this->_params_loaded; +} +std::string SingleObjectTrackerBase::getAlgorithm() +{ + return this->_algorithm; +} +int SingleObjectTrackerBase::getBackend() +{ + return this->_backend; +} +int SingleObjectTrackerBase::getTarget() +{ + return this->_target; +} + +void SingleObjectTrackerBase::warmUp() +{ + cv::Mat testim = cv::Mat::zeros(640, 480, CV_8UC3); + this->init(testim, cv::Rect(10, 10, 100, 100)); + TargetsInFrame testtgts(0); + this->track(testim, testtgts); +} +void SingleObjectTrackerBase::init(cv::Mat img_, const cv::Rect& bounding_box_) +{ + if (!this->_params_loaded) + { + this->_load(); + this->_params_loaded = true; + } + + if (bounding_box_.width < 4 || bounding_box_.height < 4) + { + throw std::runtime_error("SpireCV (106) Tracking box size < (4, 4), too small!"); + } + if (bounding_box_.x < 0 || bounding_box_.y < 0 || bounding_box_.x + bounding_box_.width > img_.cols || bounding_box_.y + bounding_box_.height > img_.rows) + { + throw std::runtime_error("SpireCV (106) Tracking box not in the Input Image!"); + } + + initImpl(img_, bounding_box_); +} +void SingleObjectTrackerBase::track(cv::Mat img_, TargetsInFrame& tgts_) +{ + Rect rect; + bool ok = trackImpl(img_, rect); + + tgts_.setSize(img_.cols, img_.rows); + tgts_.setFOV(this->fov_x, this->fov_y); + auto t1 = std::chrono::system_clock::now(); + tgts_.setFPS(1000.0 / std::chrono::duration_cast(t1 - this->_t0).count()); + this->_t0 = std::chrono::system_clock::now(); + tgts_.setTimeNow(); + + if (ok) + { + Target tgt; + tgt.setBox(rect.x, rect.y, rect.x+rect.width, rect.y+rect.height, img_.cols, img_.rows); + tgt.setTrackID(1); + tgt.setLOS(tgt.cx, tgt.cy, this->camera_matrix, img_.cols, img_.rows); + tgts_.targets.push_back(tgt); + } + + tgts_.type = MissionType::TRACKING; +} +bool SingleObjectTrackerBase::setupImpl() +{ + return false; +} +void SingleObjectTrackerBase::initImpl(cv::Mat img_, const cv::Rect& bounding_box_) +{ + +} +bool SingleObjectTrackerBase::trackImpl(cv::Mat img_, cv::Rect& output_bbox_) +{ + return false; +} +void SingleObjectTrackerBase::_load() +{ + JsonValue all_value; + JsonAllocator allocator; + _load_all_json(this->alg_params_fn, all_value, allocator); + + JsonValue tracker_params_value; + _parser_algorithm_params("SingleObjectTracker", all_value, tracker_params_value); + + for (auto i : tracker_params_value) { + if ("algorithm" == std::string(i->key)) { + this->_algorithm = i->value.toString(); + std::cout << "algorithm: " << this->_algorithm << std::endl; + } + else if ("backend" == std::string(i->key)) { + this->_backend = 
i->value.toNumber(); + } + else if ("target" == std::string(i->key)) { + this->_target = i->value.toNumber(); + } + } + + setupImpl(); +} + + +CommonObjectDetectorBase::CommonObjectDetectorBase() // : CameraAlgorithm() +{ + this->_params_loaded = false; + // std::cout << "CommonObjectDetectorBase->_params_loaded: " << this->_params_loaded << std::endl; +} +CommonObjectDetectorBase::~CommonObjectDetectorBase() +{ + +} + +bool CommonObjectDetectorBase::isParamsLoaded() +{ + return this->_params_loaded; +} +std::string CommonObjectDetectorBase::getDataset() +{ + return this->_dataset; +} +std::vector CommonObjectDetectorBase::getClassNames() +{ + return this->_class_names; +} +std::vector CommonObjectDetectorBase::getClassWs() +{ + return this->_class_ws; +} +std::vector CommonObjectDetectorBase::getClassHs() +{ + return this->_class_hs; +} +int CommonObjectDetectorBase::getInputH() +{ + return this->_input_h; +} +int CommonObjectDetectorBase::getInputW() +{ + return this->_input_w; +} +int CommonObjectDetectorBase::getClassNum() +{ + return this->_n_classes; +} +int CommonObjectDetectorBase::getOutputSize() +{ + return this->_output_size; +} +double CommonObjectDetectorBase::getThrsNms() +{ + return this->_thrs_nms; +} +double CommonObjectDetectorBase::getThrsConf() +{ + return this->_thrs_conf; +} +int CommonObjectDetectorBase::useWidthOrHeight() +{ + return this->_use_width_or_height; +} +bool CommonObjectDetectorBase::withSegmentation() +{ + return this->_with_segmentation; +} +void CommonObjectDetectorBase::setInputH(int h_) +{ + this->_input_h = h_; +} +void CommonObjectDetectorBase::setInputW(int w_) +{ + this->_input_w = w_; +} + +void CommonObjectDetectorBase::warmUp() +{ + cv::Mat testim = cv::Mat::zeros(640, 480, CV_8UC3); + TargetsInFrame testtgts(0); + this->detect(testim, testtgts); +} + +void CommonObjectDetectorBase::detect(cv::Mat img_, TargetsInFrame& tgts_, Box* roi_, int img_w_, int img_h_) +{ + if (!this->_params_loaded) + { + this->_load(); + this->_params_loaded = true; + } + + if (nullptr != roi_ && img_w_ > 0 && img_h_ > 0) + tgts_.setSize(img_w_, img_h_); + else + tgts_.setSize(img_.cols, img_.rows); + + tgts_.setFOV(this->fov_x, this->fov_y); + auto t1 = std::chrono::system_clock::now(); + tgts_.setFPS(1000.0 / std::chrono::duration_cast(t1 - this->_t0).count()); + this->_t0 = std::chrono::system_clock::now(); + tgts_.setTimeNow(); + + std::vector boxes_x; + std::vector boxes_y; + std::vector boxes_w; + std::vector boxes_h; + std::vector boxes_label; + std::vector boxes_score; + std::vector boxes_seg; + detectImpl(img_, boxes_x, boxes_y, boxes_w, boxes_h, boxes_label, boxes_score, boxes_seg); + + size_t n_objs = boxes_x.size(); + if (n_objs != boxes_y.size() || n_objs != boxes_w.size() || n_objs != boxes_h.size() || n_objs != boxes_label.size() || n_objs != boxes_score.size()) + throw std::runtime_error("SpireCV (106) Error in detectImpl(), Vector Size Not Equal!"); + + if (this->_with_segmentation && n_objs != boxes_seg.size()) + throw std::runtime_error("SpireCV (106) Error in detectImpl(), Vector Size Not Equal!"); + + for (int j=0; j= img_.cols) ow = img_.cols - ox - 1; + if (oy + oh >= img_.rows) oh = img_.rows - oy - 1; + if (ow > 5 && oh > 5) + { + Target tgt; + if (nullptr != roi_ && img_w_ > 0 && img_h_ > 0) + tgt.setBox(roi_->x1 + ox, roi_->y1 + oy, roi_->x1 + ox + ow, roi_->y1 + oy + oh, img_w_, img_h_); + else + tgt.setBox(ox, oy, ox+ow, oy+oh, img_.cols, img_.rows); + + int cat_id = boxes_label[j]; + tgt.setCategory(this->_class_names[cat_id], cat_id); 
+ if (nullptr != roi_ && img_w_ > 0 && img_h_ > 0) + tgt.setLOS(tgt.cx, tgt.cy, this->camera_matrix, img_w_, img_h_); + else + tgt.setLOS(tgt.cx, tgt.cy, this->camera_matrix, img_.cols, img_.rows); + tgt.score = boxes_score[j]; + if (this->_use_width_or_height == 0) + { + double z = this->camera_matrix.at(0, 0) * this->_class_ws[cat_id] / ow; + double x = tan(tgt.los_ax / SV_RAD2DEG) * z; + double y = tan(tgt.los_ay / SV_RAD2DEG) * z; + tgt.setPosition(x, y, z); + } + else if (this->_use_width_or_height == 1) + { + double z = this->camera_matrix.at(1, 1) * this->_class_hs[cat_id] / oh; + double x = tan(tgt.los_ax / SV_RAD2DEG) * z; + double y = tan(tgt.los_ay / SV_RAD2DEG) * z; + tgt.setPosition(x, y, z); + } + + if (this->_with_segmentation) + { + cv::Mat mask_j = boxes_seg[j].clone(); + int maskh = mask_j.rows, maskw = mask_j.cols; + assert(maskh == maskw); + + if (img_.cols > img_.rows) + { + int cut_h = (int)round((img_.rows * 1. / img_.cols) * maskh); + int gap_h = (int)round((maskh - cut_h) / 2.); + mask_j = mask_j.rowRange(gap_h, gap_h + cut_h); + } + else if (img_.cols < img_.rows) + { + int cut_w = (int)round((img_.cols * 1. / img_.rows) * maskh); + int gap_w = (int)round((maskh - cut_w) / 2.); + mask_j = mask_j.colRange(gap_w, gap_w + cut_w); + } + + if (nullptr != roi_ && img_w_ > 0 && img_h_ > 0) + { + cv::resize(mask_j, mask_j, cv::Size(img_.cols, img_.rows)); + + cv::Mat mask_out = cv::Mat::zeros(img_h_, img_w_, CV_32FC1); + mask_j.copyTo(mask_out(cv::Rect(roi_->x1, roi_->y1, mask_j.cols, mask_j.rows))); + + tgt.setMask(mask_out); + } + else + { + tgt.setMask(mask_j); + } + } + + tgts_.targets.push_back(tgt); + } + } + + tgts_.type = MissionType::COMMON_DET; +} + +bool CommonObjectDetectorBase::setupImpl() +{ + return false; +} + +void CommonObjectDetectorBase::detectImpl( + cv::Mat img_, + std::vector& boxes_x_, + std::vector& boxes_y_, + std::vector& boxes_w_, + std::vector& boxes_h_, + std::vector& boxes_label_, + std::vector& boxes_score_, + std::vector& boxes_seg_ +) +{ + +} + +void CommonObjectDetectorBase::_load() +{ + JsonValue all_value; + JsonAllocator allocator; + _load_all_json(this->alg_params_fn, all_value, allocator); + + JsonValue detector_params_value; + _parser_algorithm_params("CommonObjectDetector", all_value, detector_params_value); + + // std::cout << _get_home() + "/.spire/" << std::endl; + // stuff we know about the network and the input/output blobs + this->_input_h = 640; + this->_input_w = 640; + this->_n_classes = 1; + this->_thrs_nms = 0.6; + this->_thrs_conf = 0.4; + this->_use_width_or_height = 0; + + for (auto i : detector_params_value) { + + if ("dataset" == std::string(i->key)) { + this->_dataset = i->value.toString(); + std::cout << "dataset: " << this->_dataset << std::endl; + } + else if ("inputSize" == std::string(i->key)) { + // std::cout << "inputSize (old, new): " << this->_input_w << ", " << i->value.toNumber() << std::endl; + this->_input_w = i->value.toNumber(); + if (this->_input_w != 640 && this->_input_w != 1280) + { + throw std::runtime_error("SpireCV (106) inputSize should be 640 or 1280!"); + } + this->_input_h = this->_input_w; + } + else if ("nmsThrs" == std::string(i->key)) { + // std::cout << "nmsThrs (old, new): " << this->_thrs_nms << ", " << i->value.toNumber() << std::endl; + this->_thrs_nms = i->value.toNumber(); + } + else if ("scoreThrs" == std::string(i->key)) { + // std::cout << "scoreThrs (old, new): " << this->_thrs_conf << ", " << i->value.toNumber() << std::endl; + this->_thrs_conf = i->value.toNumber(); + } + 
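[Editor's note] The monocular distance estimate in detect() above follows the pinhole model: with focal length f_x (pixels, from the camera matrix), a known class width W (meters) and a detected box width w (pixels), depth is z = f_x * W / w, and the lateral offsets come from the line-of-sight angles (assuming SV_RAD2DEG is the usual 180/pi degrees-per-radian factor). A minimal sketch of that arithmetic, with purely illustrative numbers:

    #include <cmath>
    #include <cstdio>

    int main()
    {
        const double DEG2RAD = 3.14159265358979323846 / 180.0;

        // Illustrative values only, not taken from any real calibration.
        double fx          = 1000.0;  // camera_matrix(0,0), pixels
        double class_width = 0.5;     // prior physical width of the class, meters
        double box_width   = 100.0;   // detected box width, pixels
        double los_ax_deg  = 5.0;     // line-of-sight azimuth, degrees
        double los_ay_deg  = -2.0;    // line-of-sight elevation, degrees

        double z = fx * class_width / box_width;        // 1000 * 0.5 / 100 = 5 m
        double x = std::tan(los_ax_deg * DEG2RAD) * z;  // lateral offset
        double y = std::tan(los_ay_deg * DEG2RAD) * z;  // vertical offset

        std::printf("x = %.3f m, y = %.3f m, z = %.3f m\n", x, y, z);
        return 0;
    }

When useWidthOrHeight is 1 the same formula is applied with the class height prior and camera_matrix(1,1) instead, as the code above shows.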
else if ("useWidthOrHeight" == std::string(i->key)) { + // std::cout << "useWidthOrHeight (old, new): " << this->_use_width_or_height << ", " << i->value.toNumber() << std::endl; + this->_use_width_or_height = i->value.toNumber(); + } + else if ("withSegmentation" == std::string(i->key)) { + bool json_tf = false; + if (i->value.getTag() == JSON_TRUE) json_tf = true; + this->_with_segmentation = json_tf; + } + else if ("dataset" + this->_dataset == std::string(i->key)) { + if (i->value.getTag() == JSON_OBJECT) { + for (auto j : i->value) { + // std::cout << j->key << std::endl; + _class_names.push_back(std::string(j->key)); + if (j->value.getTag() == JSON_ARRAY) + { + int k_cnt = 0; + for (auto k : j->value) { + // std::cout << k->value.toNumber() << std::endl; + if (k_cnt == 0) _class_ws.push_back(k->value.toNumber()); + else if (k_cnt == 1) _class_hs.push_back(k->value.toNumber()); + k_cnt ++; + } + } + } + } + } + } + + setupImpl(); + +} + + + + + + + + + + + + + +} + diff --git a/algorithm/tracking/ocv470/tracking_ocv470_impl.cpp b/algorithm/tracking/ocv470/tracking_ocv470_impl.cpp new file mode 100644 index 0000000..6e46f23 --- /dev/null +++ b/algorithm/tracking/ocv470/tracking_ocv470_impl.cpp @@ -0,0 +1,135 @@ +#include "tracking_ocv470_impl.h" +#include +#include + +#define SV_MODEL_DIR "/SpireCV/models/" +#define SV_ROOT_DIR "/SpireCV/" + + +namespace sv { + +using namespace cv; + + +SingleObjectTrackerOCV470Impl::SingleObjectTrackerOCV470Impl() +{ +} + +SingleObjectTrackerOCV470Impl::~SingleObjectTrackerOCV470Impl() +{ +} + + +bool SingleObjectTrackerOCV470Impl::ocv470Setup(SingleObjectTrackerBase* base_) +{ + this->_algorithm = base_->getAlgorithm(); + this->_backend = base_->getBackend(); + this->_target = base_->getTarget(); + +#ifdef WITH_OCV470 + std::string net = get_home() + SV_MODEL_DIR + "dasiamrpn_model.onnx"; + std::string kernel_cls1 = get_home() + SV_MODEL_DIR + "dasiamrpn_kernel_cls1.onnx"; + std::string kernel_r1 = get_home() + SV_MODEL_DIR + "dasiamrpn_kernel_r1.onnx"; + + std::string backbone = get_home() + SV_MODEL_DIR + "nanotrack_backbone_sim.onnx"; + std::string neckhead = get_home() + SV_MODEL_DIR + "nanotrack_head_sim.onnx"; + + try + { + TrackerNano::Params nano_params; + nano_params.backbone = samples::findFile(backbone); + nano_params.neckhead = samples::findFile(neckhead); + nano_params.backend = this->_backend; + nano_params.target = this->_target; + + _nano = TrackerNano::create(nano_params); + } + catch (const cv::Exception& ee) + { + std::cerr << "Exception: " << ee.what() << std::endl; + std::cout << "Can't load the network by using the following files:" << std::endl; + std::cout << "nanoBackbone : " << backbone << std::endl; + std::cout << "nanoNeckhead : " << neckhead << std::endl; + } + + try + { + TrackerDaSiamRPN::Params params; + params.model = samples::findFile(net); + params.kernel_cls1 = samples::findFile(kernel_cls1); + params.kernel_r1 = samples::findFile(kernel_r1); + params.backend = this->_backend; + params.target = this->_target; + _siam_rpn = TrackerDaSiamRPN::create(params); + } + catch (const cv::Exception& ee) + { + std::cerr << "Exception: " << ee.what() << std::endl; + std::cout << "Can't load the network by using the following files:" << std::endl; + std::cout << "siamRPN : " << net << std::endl; + std::cout << "siamKernelCL1 : " << kernel_cls1 << std::endl; + std::cout << "siamKernelR1 : " << kernel_r1 << std::endl; + } + return true; +#endif + return false; +} + + +void SingleObjectTrackerOCV470Impl::ocv470Init(cv::Mat img_, 
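[Editor's note] The three _load() routines above (EllipseDetector, SingleObjectTrackerBase, CommonObjectDetectorBase) all pull their settings from the same JSON parameter file referenced by alg_params_fn, one named section each. The sketch below is a hypothetical example of that file, reconstructed only from the keys parsed in this patch; every value is illustrative, not a shipped default.

    // Hypothetical contents of the algorithm parameter file read by _load().
    // Key names mirror the ones parsed above; values are illustrative only.
    static const char* kExampleAlgParams = R"json(
    {
      "EllipseDetector": {
        "preProcessingGaussKernel": 5,
        "preProcessingGaussSigma": 1.3,
        "thPosition": 1.0,
        "maxCenterDistance": 0.05,
        "minEdgeLength": 9,
        "minOrientedRectSide": 3.0,
        "distanceToEllipseContour": 0.2,
        "minScore": 0.7,
        "minReliability": 0.5,
        "ns": 16,
        "percentNe": 0.9,
        "T_CNC": 0.2,
        "T_TCN_L": 0.4,
        "T_TCN_P": 0.6,
        "thRadius": 0.1,
        "radiusInMeter": 0.5
      },
      "SingleObjectTracker": {
        "algorithm": "nano",
        "backend": 0,
        "target": 0
      },
      "CommonObjectDetector": {
        "dataset": "COCO",
        "inputSize": 640,
        "nmsThrs": 0.6,
        "scoreThrs": 0.4,
        "useWidthOrHeight": 1,
        "withSegmentation": false,
        "datasetCOCO": {
          "person": [0.5, 1.7],
          "car": [1.8, 1.5]
        }
      }
    }
    )json";

Note that the detector only accepts inputSize values of 640 or 1280, and the per-class arrays under "dataset<name>" supply the width/height priors consumed by the position estimation step above.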
const cv::Rect& bounding_box_) +{ +#ifdef WITH_OCV470 + if (this->_algorithm == "kcf") + { + TrackerKCF::Params params; + _kcf = TrackerKCF::create(params); + _kcf->init(img_, bounding_box_); + } + else if (this->_algorithm == "csrt") + { + TrackerCSRT::Params params; + _csrt = TrackerCSRT::create(params); + _csrt->init(img_, bounding_box_); + } + else if (this->_algorithm == "siamrpn") + { + _siam_rpn->init(img_, bounding_box_); + } + else if (this->_algorithm == "nano") + { + _nano->init(img_, bounding_box_); + } +#endif +} + + +bool SingleObjectTrackerOCV470Impl::ocv470Track(cv::Mat img_, cv::Rect& output_bbox_) +{ +#ifdef WITH_OCV470 + bool ok = false; + if (this->_algorithm == "kcf") + { + ok = _kcf->update(img_, output_bbox_); + } + else if (this->_algorithm == "csrt") + { + ok = _csrt->update(img_, output_bbox_); + } + else if (this->_algorithm == "siamrpn") + { + ok = _siam_rpn->update(img_, output_bbox_); + } + else if (this->_algorithm == "nano") + { + ok = _nano->update(img_, output_bbox_); + } + return ok; +#endif + return false; +} + + + + +} + diff --git a/algorithm/tracking/ocv470/tracking_ocv470_impl.h b/algorithm/tracking/ocv470/tracking_ocv470_impl.h new file mode 100644 index 0000000..7d7c759 --- /dev/null +++ b/algorithm/tracking/ocv470/tracking_ocv470_impl.h @@ -0,0 +1,40 @@ +#ifndef __SV_TRACKING_OCV470__ +#define __SV_TRACKING_OCV470__ + +#include "sv_core.h" +#include +#include +#include +#include +#include + + + +namespace sv { + + +class SingleObjectTrackerOCV470Impl +{ +public: + SingleObjectTrackerOCV470Impl(); + ~SingleObjectTrackerOCV470Impl(); + + bool ocv470Setup(SingleObjectTrackerBase* base_); + void ocv470Init(cv::Mat img_, const cv::Rect& bounding_box_); + bool ocv470Track(cv::Mat img_, cv::Rect& output_bbox_); + + std::string _algorithm; + int _backend; + int _target; + +#ifdef WITH_OCV470 + cv::Ptr _siam_rpn; + cv::Ptr _kcf; + cv::Ptr _csrt; + cv::Ptr _nano; +#endif +}; + + +} +#endif diff --git a/algorithm/tracking/sv_tracking.cpp b/algorithm/tracking/sv_tracking.cpp new file mode 100644 index 0000000..094a7a7 --- /dev/null +++ b/algorithm/tracking/sv_tracking.cpp @@ -0,0 +1,41 @@ +#include "sv_tracking.h" +#include +#include +#include "tracking_ocv470_impl.h" + + +namespace sv { + + +SingleObjectTracker::SingleObjectTracker() +{ + this->_ocv470_impl = new SingleObjectTrackerOCV470Impl; +} +SingleObjectTracker::~SingleObjectTracker() +{ +} + +bool SingleObjectTracker::setupImpl() +{ +#ifdef WITH_OCV470 + return this->_ocv470_impl->ocv470Setup(this); +#endif + return false; +} +void SingleObjectTracker::initImpl(cv::Mat img_, const cv::Rect& bounding_box_) +{ +#ifdef WITH_OCV470 + this->_ocv470_impl->ocv470Init(img_, bounding_box_); +#endif +} +bool SingleObjectTracker::trackImpl(cv::Mat img_, cv::Rect& output_bbox_) +{ +#ifdef WITH_OCV470 + return this->_ocv470_impl->ocv470Track(img_, output_bbox_); +#endif + return false; +} + + +} + diff --git a/build_on_jetson.sh b/build_on_jetson.sh new file mode 100755 index 0000000..65a5735 --- /dev/null +++ b/build_on_jetson.sh @@ -0,0 +1,9 @@ +#!/bin/bash -e + +rm -rf build +mkdir build +cd build +cmake .. -DPLATFORM=JETSON +make -j4 +sudo make install + diff --git a/build_on_x86_cuda.sh b/build_on_x86_cuda.sh new file mode 100755 index 0000000..2b9fd36 --- /dev/null +++ b/build_on_x86_cuda.sh @@ -0,0 +1,9 @@ +#!/bin/bash -e + +rm -rf build +mkdir build +cd build +cmake .. 
-DPLATFORM=X86_CUDA +make -j4 +sudo make install + diff --git a/gimbal_ctrl/IOs/serial/README.md b/gimbal_ctrl/IOs/serial/README.md new file mode 100755 index 0000000..c5d8d0b --- /dev/null +++ b/gimbal_ctrl/IOs/serial/README.md @@ -0,0 +1,63 @@ +# Serial Communication Library + +[![Build Status](https://travis-ci.org/wjwwood/serial.svg?branch=master)](https://travis-ci.org/wjwwood/serial)*(Linux and OS X)* [![Build Status](https://ci.appveyor.com/api/projects/status/github/wjwwood/serial)](https://ci.appveyor.com/project/wjwwood/serial)*(Windows)* + +This is a cross-platform library for interfacing with rs-232 serial like ports written in C++. It provides a modern C++ interface with a workflow designed to look and feel like PySerial, but with the speed and control provided by C++. + +This library is in use in several robotics related projects and can be built and installed to the OS like most unix libraries with make and then sudo make install, but because it is a catkin project it can also be built along side other catkin projects in a catkin workspace. + +Serial is a class that provides the basic interface common to serial libraries (open, close, read, write, etc..) and requires no extra dependencies. It also provides tight control over timeouts and control over handshaking lines. + +### Documentation + +Website: http://wjwwood.github.io/serial/ + +API Documentation: http://wjwwood.github.io/serial/doc/1.1.0/index.html + +### Dependencies + +Required: +* [catkin](http://www.ros.org/wiki/catkin) - cmake and Python based buildsystem +* [cmake](http://www.cmake.org) - buildsystem +* [Python](http://www.python.org) - scripting language + * [empy](http://www.alcyone.com/pyos/empy/) - Python templating library + * [catkin_pkg](http://pypi.python.org/pypi/catkin_pkg/) - Runtime Python library for catkin + +Optional (for documentation): +* [Doxygen](http://www.doxygen.org/) - Documentation generation tool +* [graphviz](http://www.graphviz.org/) - Graph visualization software + +### Install + +Get the code: + + git clone https://github.com/wjwwood/serial.git + +Build: + + make + +Build and run the tests: + + make test + +Build the documentation: + + make doc + +Install: + + make install + +### License + +[The MIT License](LICENSE) + +### Authors + +William Woodall +John Harrison + +### Contact + +William Woodall diff --git a/gimbal_ctrl/IOs/serial/include/serial/impl/unix.h b/gimbal_ctrl/IOs/serial/include/serial/impl/unix.h new file mode 100755 index 0000000..0fb38f2 --- /dev/null +++ b/gimbal_ctrl/IOs/serial/include/serial/impl/unix.h @@ -0,0 +1,221 @@ +/*! + * \file serial/impl/unix.h + * \author William Woodall + * \author John Harrison + * \version 0.1 + * + * \section LICENSE + * + * The MIT License + * + * Copyright (c) 2012 William Woodall, John Harrison + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
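[Editor's note] As a concrete illustration of the PySerial-like workflow the serial README describes, here is a minimal, hypothetical sketch built on the serial::Serial API declared later in this patch; the port name, baudrate, command string and timeout are placeholders:

    #include <iostream>
    #include <string>
    #include "serial/serial.h"

    int main()
    {
        // Open /dev/ttyUSB0 at 115200 baud with a simple 1-second total timeout.
        serial::Serial port("/dev/ttyUSB0", 115200,
                            serial::Timeout::simpleTimeout(1000));

        if (!port.isOpen())
            return 1;

        // Write a command and read one newline-terminated reply.
        port.write(std::string("PING\n"));
        std::string reply = port.readline();   // returns on "\n" or timeout

        std::cout << "reply: " << reply << std::endl;
        port.close();
        return 0;
    }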
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + * \section DESCRIPTION + * + * This provides a unix based pimpl for the Serial class. This implementation is + * based off termios.h and uses select for multiplexing the IO ports. + * + */ + +#if !defined(_WIN32) + +#ifndef SERIAL_IMPL_UNIX_H +#define SERIAL_IMPL_UNIX_H + +#include "serial/serial.h" + +#include + +namespace serial { + +using std::size_t; +using std::string; +using std::invalid_argument; + +using serial::SerialException; +using serial::IOException; + +class MillisecondTimer { +public: + MillisecondTimer(const uint32_t millis); + int64_t remaining(); + +private: + static timespec timespec_now(); + timespec expiry; +}; + +class serial::Serial::SerialImpl { +public: + SerialImpl (const string &port, + unsigned long baudrate, + bytesize_t bytesize, + parity_t parity, + stopbits_t stopbits, + flowcontrol_t flowcontrol); + + virtual ~SerialImpl (); + + void + open (); + + void + close (); + + bool + isOpen () const; + + size_t + available (); + + bool + waitReadable (uint32_t timeout); + + void + waitByteTimes (size_t count); + + size_t + read (uint8_t *buf, size_t size = 1); + + size_t + write (const uint8_t *data, size_t length); + + void + flush (); + + void + flushInput (); + + void + flushOutput (); + + void + sendBreak (int duration); + + void + setBreak (bool level); + + void + setRTS (bool level); + + void + setDTR (bool level); + + bool + waitForChange (); + + bool + getCTS (); + + bool + getDSR (); + + bool + getRI (); + + bool + getCD (); + + void + setPort (const string &port); + + string + getPort () const; + + void + setTimeout (Timeout &timeout); + + Timeout + getTimeout () const; + + void + setBaudrate (unsigned long baudrate); + + unsigned long + getBaudrate () const; + + void + setBytesize (bytesize_t bytesize); + + bytesize_t + getBytesize () const; + + void + setParity (parity_t parity); + + parity_t + getParity () const; + + void + setStopbits (stopbits_t stopbits); + + stopbits_t + getStopbits () const; + + void + setFlowcontrol (flowcontrol_t flowcontrol); + + flowcontrol_t + getFlowcontrol () const; + + void + readLock (); + + void + readUnlock (); + + void + writeLock (); + + void + writeUnlock (); + +protected: + void reconfigurePort (); + +private: + string port_; // Path to the file descriptor + int fd_; // The current file descriptor + + bool is_open_; + bool xonxoff_; + bool rtscts_; + + Timeout timeout_; // Timeout for read operations + unsigned long baudrate_; // Baudrate + uint32_t byte_time_ns_; // Nanoseconds to transmit/receive a single byte + + parity_t parity_; // Parity + bytesize_t bytesize_; // Size of the bytes + stopbits_t stopbits_; // Stop Bits + flowcontrol_t flowcontrol_; // Flow Control + + // Mutex used to lock the read functions + pthread_mutex_t read_mutex; + // Mutex used to lock the write functions + pthread_mutex_t write_mutex; +}; + +} + +#endif // SERIAL_IMPL_UNIX_H + +#endif // !defined(_WIN32) diff --git a/gimbal_ctrl/IOs/serial/include/serial/impl/win.h b/gimbal_ctrl/IOs/serial/include/serial/impl/win.h new file mode 100755 index 
0000000..2c0c6cd --- /dev/null +++ b/gimbal_ctrl/IOs/serial/include/serial/impl/win.h @@ -0,0 +1,207 @@ +/*! + * \file serial/impl/windows.h + * \author William Woodall + * \author John Harrison + * \version 0.1 + * + * \section LICENSE + * + * The MIT License + * + * Copyright (c) 2012 William Woodall, John Harrison + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + * \section DESCRIPTION + * + * This provides a windows implementation of the Serial class interface. + * + */ + +#if defined(_WIN32) + +#ifndef SERIAL_IMPL_WINDOWS_H +#define SERIAL_IMPL_WINDOWS_H + +#include "serial/serial.h" + +#include "windows.h" + +namespace serial { + +using std::string; +using std::wstring; +using std::invalid_argument; + +using serial::SerialException; +using serial::IOException; + +class serial::Serial::SerialImpl { +public: + SerialImpl (const string &port, + unsigned long baudrate, + bytesize_t bytesize, + parity_t parity, + stopbits_t stopbits, + flowcontrol_t flowcontrol); + + virtual ~SerialImpl (); + + void + open (); + + void + close (); + + bool + isOpen () const; + + size_t + available (); + + bool + waitReadable (uint32_t timeout); + + void + waitByteTimes (size_t count); + + size_t + read (uint8_t *buf, size_t size = 1); + + size_t + write (const uint8_t *data, size_t length); + + void + flush (); + + void + flushInput (); + + void + flushOutput (); + + void + sendBreak (int duration); + + void + setBreak (bool level); + + void + setRTS (bool level); + + void + setDTR (bool level); + + bool + waitForChange (); + + bool + getCTS (); + + bool + getDSR (); + + bool + getRI (); + + bool + getCD (); + + void + setPort (const string &port); + + string + getPort () const; + + void + setTimeout (Timeout &timeout); + + Timeout + getTimeout () const; + + void + setBaudrate (unsigned long baudrate); + + unsigned long + getBaudrate () const; + + void + setBytesize (bytesize_t bytesize); + + bytesize_t + getBytesize () const; + + void + setParity (parity_t parity); + + parity_t + getParity () const; + + void + setStopbits (stopbits_t stopbits); + + stopbits_t + getStopbits () const; + + void + setFlowcontrol (flowcontrol_t flowcontrol); + + flowcontrol_t + getFlowcontrol () const; + + void + readLock (); + + void + readUnlock (); + + void + writeLock (); + + void + writeUnlock (); + +protected: + void reconfigurePort (); + +private: + wstring port_; // Path to the file descriptor + HANDLE fd_; + + bool is_open_; + + Timeout timeout_; // Timeout for read operations + unsigned 
long baudrate_; // Baudrate + + parity_t parity_; // Parity + bytesize_t bytesize_; // Size of the bytes + stopbits_t stopbits_; // Stop Bits + flowcontrol_t flowcontrol_; // Flow Control + + // Mutex used to lock the read functions + HANDLE read_mutex; + // Mutex used to lock the write functions + HANDLE write_mutex; +}; + +} + +#endif // SERIAL_IMPL_WINDOWS_H + +#endif // if defined(_WIN32) diff --git a/gimbal_ctrl/IOs/serial/include/serial/serial.h b/gimbal_ctrl/IOs/serial/include/serial/serial.h new file mode 100755 index 0000000..4790629 --- /dev/null +++ b/gimbal_ctrl/IOs/serial/include/serial/serial.h @@ -0,0 +1,796 @@ +/*! + * \file serial/serial.h + * \author William Woodall + * \author John Harrison + * \version 0.1 + * + * \section LICENSE + * + * The MIT License + * + * Copyright (c) 2012 William Woodall + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + * \section DESCRIPTION + * + * This provides a cross platform interface for interacting with Serial Ports. + */ + +#ifndef SERIAL_H +#define SERIAL_H + +#include +#include +#include +#include +#include +#include +#include +#include + +#define THROW(exceptionClass, message) throw exceptionClass(__FILE__, \ + __LINE__, (message)) + +namespace serial +{ + + /*! + * Enumeration defines the possible bytesizes for the serial port. + */ + typedef enum + { + fivebits = 5, + sixbits = 6, + sevenbits = 7, + eightbits = 8 + } bytesize_t; + + /*! + * Enumeration defines the possible parity types for the serial port. + */ + typedef enum + { + parity_none = 0, + parity_odd = 1, + parity_even = 2, + parity_mark = 3, + parity_space = 4 + } parity_t; + + /*! + * Enumeration defines the possible stopbit types for the serial port. + */ + typedef enum + { + stopbits_one = 1, + stopbits_two = 2, + stopbits_one_point_five + } stopbits_t; + + /*! + * Enumeration defines the possible flowcontrol types for the serial port. + */ + typedef enum + { + flowcontrol_none = 0, + flowcontrol_software, + flowcontrol_hardware + } flowcontrol_t; + + /*! + * Structure for setting the timeout of the serial port, times are + * in milliseconds. + * + * In order to disable the interbyte timeout, set it to Timeout::max(). + */ + struct Timeout + { +#ifdef max +#undef max +#endif + static uint32_t max() + { + return std::numeric_limits::max(); + } + /*! + * Convenience function to generate Timeout structs using a + * single absolute timeout. 
+ * + * \param timeout A long that defines the time in milliseconds until a + * timeout occurs after a call to read or write is made. + * + * \return Timeout struct that represents this simple timeout provided. + */ + static Timeout simpleTimeout(uint32_t timeout) + { + return Timeout(max(), timeout, 0, timeout, 0); + } + + /*! Number of milliseconds between bytes received to timeout on. */ + uint32_t inter_byte_timeout; + /*! A constant number of milliseconds to wait after calling read. */ + uint32_t read_timeout_constant; + /*! A multiplier against the number of requested bytes to wait after + * calling read. + */ + uint32_t read_timeout_multiplier; + /*! A constant number of milliseconds to wait after calling write. */ + uint32_t write_timeout_constant; + /*! A multiplier against the number of requested bytes to wait after + * calling write. + */ + uint32_t write_timeout_multiplier; + + explicit Timeout(uint32_t inter_byte_timeout_ = 0, + uint32_t read_timeout_constant_ = 0, + uint32_t read_timeout_multiplier_ = 0, + uint32_t write_timeout_constant_ = 0, + uint32_t write_timeout_multiplier_ = 0) + : inter_byte_timeout(inter_byte_timeout_), + read_timeout_constant(read_timeout_constant_), + read_timeout_multiplier(read_timeout_multiplier_), + write_timeout_constant(write_timeout_constant_), + write_timeout_multiplier(write_timeout_multiplier_) + { + } + }; + + /*! + * Class that provides a portable serial port interface. + */ + class Serial + { + public: + /*! + * Creates a Serial object and opens the port if a port is specified, + * otherwise it remains closed until serial::Serial::open is called. + * + * \param port A std::string containing the address of the serial port, + * which would be something like 'COM1' on Windows and '/dev/ttyS0' + * on Linux. + * + * \param baudrate An unsigned 32-bit integer that represents the baudrate + * + * \param timeout A serial::Timeout struct that defines the timeout + * conditions for the serial port. \see serial::Timeout + * + * \param bytesize Size of each byte in the serial transmission of data, + * default is eightbits, possible values are: fivebits, sixbits, sevenbits, + * eightbits + * + * \param parity Method of parity, default is parity_none, possible values + * are: parity_none, parity_odd, parity_even + * + * \param stopbits Number of stop bits used, default is stopbits_one, + * possible values are: stopbits_one, stopbits_one_point_five, stopbits_two + * + * \param flowcontrol Type of flowcontrol used, default is + * flowcontrol_none, possible values are: flowcontrol_none, + * flowcontrol_software, flowcontrol_hardware + * + * \throw serial::PortNotOpenedException + * \throw serial::IOException + * \throw std::invalid_argument + */ + Serial(const std::string &port = "", + uint32_t baudrate = 9600, + Timeout timeout = Timeout(), + bytesize_t bytesize = eightbits, + parity_t parity = parity_none, + stopbits_t stopbits = stopbits_one, + flowcontrol_t flowcontrol = flowcontrol_none); + + /*! Destructor */ + virtual ~Serial(); + + /*! + * Opens the serial port as long as the port is set and the port isn't + * already open. + * + * If the port is provided to the constructor then an explicit call to open + * is not needed. + * + * \see Serial::Serial + * + * \throw std::invalid_argument + * \throw serial::SerialException + * \throw serial::IOException + */ + void + open(); + + /*! Gets the open status of the serial port. + * + * \return Returns true if the port is open, false otherwise. + */ + bool + isOpen() const; + + /*! 
Closes the serial port. */ + void + close(); + + /*! Return the number of characters in the buffer. */ + size_t + available(); + + /*! Block until there is serial data to read or read_timeout_constant + * number of milliseconds have elapsed. The return value is true when + * the function exits with the port in a readable state, false otherwise + * (due to timeout or select interruption). */ + bool + waitReadable(); + + /*! Block for a period of time corresponding to the transmission time of + * count characters at present serial settings. This may be used in con- + * junction with waitReadable to read larger blocks of data from the + * port. */ + void + waitByteTimes(size_t count); + + /*! Read a given amount of bytes from the serial port into a given buffer. + * + * The read function will return in one of three cases: + * * The number of requested bytes was read. + * * In this case the number of bytes requested will match the size_t + * returned by read. + * * A timeout occurred, in this case the number of bytes read will not + * match the amount requested, but no exception will be thrown. One of + * two possible timeouts occurred: + * * The inter byte timeout expired, this means that number of + * milliseconds elapsed between receiving bytes from the serial port + * exceeded the inter byte timeout. + * * The total timeout expired, which is calculated by multiplying the + * read timeout multiplier by the number of requested bytes and then + * added to the read timeout constant. If that total number of + * milliseconds elapses after the initial call to read a timeout will + * occur. + * * An exception occurred, in this case an actual exception will be thrown. + * + * \param buffer An uint8_t array of at least the requested size. + * \param size A size_t defining how many bytes to be read. + * + * \return A size_t representing the number of bytes read as a result of the + * call to read. + * + * \throw serial::PortNotOpenedException + * \throw serial::SerialException + */ + size_t + read(uint8_t *buffer, size_t size); + + /*! Read a given amount of bytes from the serial port into a give buffer. + * + * \param buffer A reference to a std::vector of uint8_t. + * \param size A size_t defining how many bytes to be read. + * + * \return A size_t representing the number of bytes read as a result of the + * call to read. + * + * \throw serial::PortNotOpenedException + * \throw serial::SerialException + */ + size_t + read(std::vector &buffer, size_t size = 1); + + /*! Read a given amount of bytes from the serial port into a give buffer. + * + * \param buffer A reference to a std::string. + * \param size A size_t defining how many bytes to be read. + * + * \return A size_t representing the number of bytes read as a result of the + * call to read. + * + * \throw serial::PortNotOpenedException + * \throw serial::SerialException + */ + size_t + read(std::string &buffer, size_t size = 1); + + /*! Read a given amount of bytes from the serial port and return a string + * containing the data. + * + * \param size A size_t defining how many bytes to be read. + * + * \return A std::string containing the data read from the port. + * + * \throw serial::PortNotOpenedException + * \throw serial::SerialException + */ + std::string + read(size_t size = 1); + + /*! Reads in a line or until a given delimiter has been processed. + * + * Reads from the serial port until a single line has been read. + * + * \param buffer A std::string reference used to store the data. 
+ * \param size A maximum length of a line, defaults to 65536 (2^16) + * \param eol A string to match against for the EOL. + * + * \return A size_t representing the number of bytes read. + * + * \throw serial::PortNotOpenedException + * \throw serial::SerialException + */ + size_t + readline(std::string &buffer, size_t size = 65536, std::string eol = "\n"); + + /*! Reads in a line or until a given delimiter has been processed. + * + * Reads from the serial port until a single line has been read. + * + * \param size A maximum length of a line, defaults to 65536 (2^16) + * \param eol A string to match against for the EOL. + * + * \return A std::string containing the line. + * + * \throw serial::PortNotOpenedException + * \throw serial::SerialException + */ + std::string + readline(size_t size = 65536, std::string eol = "\n"); + + /*! Reads in multiple lines until the serial port times out. + * + * This requires a timeout > 0 before it can be run. It will read until a + * timeout occurs and return a list of strings. + * + * \param size A maximum length of combined lines, defaults to 65536 (2^16) + * + * \param eol A string to match against for the EOL. + * + * \return A vector containing the lines. + * + * \throw serial::PortNotOpenedException + * \throw serial::SerialException + */ + std::vector + readlines(size_t size = 65536, std::string eol = "\n"); + + /*! Write a string to the serial port. + * + * \param data A const reference containing the data to be written + * to the serial port. + * + * \param size A size_t that indicates how many bytes should be written from + * the given data buffer. + * + * \return A size_t representing the number of bytes actually written to + * the serial port. + * + * \throw serial::PortNotOpenedException + * \throw serial::SerialException + * \throw serial::IOException + */ + size_t + write(const uint8_t *data, size_t size); + + /*! Write a string to the serial port. + * + * \param data A const reference containing the data to be written + * to the serial port. + * + * \return A size_t representing the number of bytes actually written to + * the serial port. + * + * \throw serial::PortNotOpenedException + * \throw serial::SerialException + * \throw serial::IOException + */ + size_t + write(const std::vector &data); + + /*! Write a string to the serial port. + * + * \param data A const reference containing the data to be written + * to the serial port. + * + * \return A size_t representing the number of bytes actually written to + * the serial port. + * + * \throw serial::PortNotOpenedException + * \throw serial::SerialException + * \throw serial::IOException + */ + size_t + write(const std::string &data); + + /*! Sets the serial port identifier. + * + * \param port A const std::string reference containing the address of the + * serial port, which would be something like 'COM1' on Windows and + * '/dev/ttyS0' on Linux. + * + * \throw std::invalid_argument + */ + void + setPort(const std::string &port); + + /*! Gets the serial port identifier. + * + * \see Serial::setPort + * + * \throw std::invalid_argument + */ + std::string + getPort() const; + + /*! Sets the timeout for reads and writes using the Timeout struct. + * + * There are two timeout conditions described here: + * * The inter byte timeout: + * * The inter_byte_timeout component of serial::Timeout defines the + * maximum amount of time, in milliseconds, between receiving bytes on + * the serial port that can pass before a timeout occurs. 
Setting this + * to zero will prevent inter byte timeouts from occurring. + * * Total time timeout: + * * The constant and multiplier component of this timeout condition, + * for both read and write, are defined in serial::Timeout. This + * timeout occurs if the total time since the read or write call was + * made exceeds the specified time in milliseconds. + * * The limit is defined by multiplying the multiplier component by the + * number of requested bytes and adding that product to the constant + * component. In this way if you want a read call, for example, to + * timeout after exactly one second regardless of the number of bytes + * you asked for then set the read_timeout_constant component of + * serial::Timeout to 1000 and the read_timeout_multiplier to zero. + * This timeout condition can be used in conjunction with the inter + * byte timeout condition with out any problems, timeout will simply + * occur when one of the two timeout conditions is met. This allows + * users to have maximum control over the trade-off between + * responsiveness and efficiency. + * + * Read and write functions will return in one of three cases. When the + * reading or writing is complete, when a timeout occurs, or when an + * exception occurs. + * + * A timeout of 0 enables non-blocking mode. + * + * \param timeout A serial::Timeout struct containing the inter byte + * timeout, and the read and write timeout constants and multipliers. + * + * \see serial::Timeout + */ + void + setTimeout(Timeout &timeout); + + /*! Sets the timeout for reads and writes. */ + void + setTimeout(uint32_t inter_byte_timeout, uint32_t read_timeout_constant, + uint32_t read_timeout_multiplier, uint32_t write_timeout_constant, + uint32_t write_timeout_multiplier) + { + Timeout timeout(inter_byte_timeout, read_timeout_constant, + read_timeout_multiplier, write_timeout_constant, + write_timeout_multiplier); + return setTimeout(timeout); + } + + /*! Gets the timeout for reads in seconds. + * + * \return A Timeout struct containing the inter_byte_timeout, and read + * and write timeout constants and multipliers. + * + * \see Serial::setTimeout + */ + Timeout + getTimeout() const; + + /*! Sets the baudrate for the serial port. + * + * Possible baudrates depends on the system but some safe baudrates include: + * 110, 300, 600, 1200, 2400, 4800, 9600, 14400, 19200, 28800, 38400, 56000, + * 57600, 115200 + * Some other baudrates that are supported by some comports: + * 128000, 153600, 230400, 256000, 460800, 500000, 921600 + * + * \param baudrate An integer that sets the baud rate for the serial port. + * + * \throw std::invalid_argument + */ + void + setBaudrate(uint32_t baudrate); + + /*! Gets the baudrate for the serial port. + * + * \return An integer that sets the baud rate for the serial port. + * + * \see Serial::setBaudrate + * + * \throw std::invalid_argument + */ + uint32_t + getBaudrate() const; + + /*! Sets the bytesize for the serial port. + * + * \param bytesize Size of each byte in the serial transmission of data, + * default is eightbits, possible values are: fivebits, sixbits, sevenbits, + * eightbits + * + * \throw std::invalid_argument + */ + void + setBytesize(bytesize_t bytesize); + + /*! Gets the bytesize for the serial port. + * + * \see Serial::setBytesize + * + * \throw std::invalid_argument + */ + bytesize_t + getBytesize() const; + + /*! Sets the parity for the serial port. 
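[Editor's note] To make the total-timeout rule above concrete: for a read of n requested bytes the deadline is read_timeout_constant + read_timeout_multiplier * n milliseconds, so a constant of 1000 with a multiplier of 0 times out after exactly one second regardless of n. A small sketch of that arithmetic using the Timeout struct declared above (values illustrative):

    #include <cstdint>
    #include <cstdio>
    #include "serial/serial.h"

    int main()
    {
        // "Exactly one second, regardless of byte count": constant only.
        serial::Timeout t(serial::Timeout::max(), // inter-byte timeout disabled
                          1000,                   // read_timeout_constant (ms)
                          0,                      // read_timeout_multiplier
                          1000,                   // write_timeout_constant (ms)
                          0);                     // write_timeout_multiplier

        uint32_t requested = 64;                  // bytes asked for in read()
        uint32_t deadline_ms =
            t.read_timeout_constant + t.read_timeout_multiplier * requested;

        std::printf("read deadline: %u ms\n", deadline_ms);  // prints 1000 ms
        return 0;
    }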
+ * + * \param parity Method of parity, default is parity_none, possible values + * are: parity_none, parity_odd, parity_even + * + * \throw std::invalid_argument + */ + void + setParity(parity_t parity); + + /*! Gets the parity for the serial port. + * + * \see Serial::setParity + * + * \throw std::invalid_argument + */ + parity_t + getParity() const; + + /*! Sets the stopbits for the serial port. + * + * \param stopbits Number of stop bits used, default is stopbits_one, + * possible values are: stopbits_one, stopbits_one_point_five, stopbits_two + * + * \throw std::invalid_argument + */ + void + setStopbits(stopbits_t stopbits); + + /*! Gets the stopbits for the serial port. + * + * \see Serial::setStopbits + * + * \throw std::invalid_argument + */ + stopbits_t + getStopbits() const; + + /*! Sets the flow control for the serial port. + * + * \param flowcontrol Type of flowcontrol used, default is flowcontrol_none, + * possible values are: flowcontrol_none, flowcontrol_software, + * flowcontrol_hardware + * + * \throw std::invalid_argument + */ + void + setFlowcontrol(flowcontrol_t flowcontrol); + + /*! Gets the flow control for the serial port. + * + * \see Serial::setFlowcontrol + * + * \throw std::invalid_argument + */ + flowcontrol_t + getFlowcontrol() const; + + /*! Flush the input and output buffers */ + void + flush(); + + /*! Flush only the input buffer */ + void + flushInput(); + + /*! Flush only the output buffer */ + void + flushOutput(); + + /*! Sends the RS-232 break signal. See tcsendbreak(3). */ + void + sendBreak(int duration); + + /*! Set the break condition to a given level. Defaults to true. */ + void + setBreak(bool level = true); + + /*! Set the RTS handshaking line to the given level. Defaults to true. */ + void + setRTS(bool level = true); + + /*! Set the DTR handshaking line to the given level. Defaults to true. */ + void + setDTR(bool level = true); + + /*! + * Blocks until CTS, DSR, RI, CD changes or something interrupts it. + * + * Can throw an exception if an error occurs while waiting. + * You can check the status of CTS, DSR, RI, and CD once this returns. + * Uses TIOCMIWAIT via ioctl if available (mostly only on Linux) with a + * resolution of less than +-1ms and as good as +-0.2ms. Otherwise a + * polling method is used which can give +-2ms. + * + * \return Returns true if one of the lines changed, false if something else + * occurred. + * + * \throw SerialException + */ + bool + waitForChange(); + + /*! Returns the current status of the CTS line. */ + bool + getCTS(); + + /*! Returns the current status of the DSR line. */ + bool + getDSR(); + + /*! Returns the current status of the RI line. */ + bool + getRI(); + + /*! Returns the current status of the CD line. 
*/ + bool + getCD(); + + private: + // Disable copy constructors + Serial(const Serial &); + Serial &operator=(const Serial &); + + // Pimpl idiom, d_pointer + class SerialImpl; + SerialImpl *pimpl_; + + // Scoped Lock Classes + class ScopedReadLock; + class ScopedWriteLock; + + // Read common function + size_t + read_(uint8_t *buffer, size_t size); + // Write common function + size_t + write_(const uint8_t *data, size_t length); + }; + + class SerialException : public std::exception + { + // Disable copy constructors + SerialException &operator=(const SerialException &); + std::string e_what_; + + public: + SerialException(const char *description) + { + std::stringstream ss; + ss << "SerialException " << description << " failed."; + e_what_ = ss.str(); + } + SerialException(const SerialException &other) : e_what_(other.e_what_) {} + virtual ~SerialException() throw() {} + virtual const char *what() const throw() + { + return e_what_.c_str(); + } + }; + + class IOException : public std::exception + { + // Disable copy constructors + IOException &operator=(const IOException &); + std::string file_; + int line_; + std::string e_what_; + int errno_; + + public: + explicit IOException(std::string file, int line, int errnum) + : file_(file), line_(line), errno_(errnum) + { + std::stringstream ss; +#if defined(_WIN32) && !defined(__MINGW32__) + char error_str[1024]; + strerror_s(error_str, 1024, errnum); +#else + char *error_str = strerror(errnum); +#endif + ss << "IO Exception (" << errno_ << "): " << error_str; + ss << ", file " << file_ << ", line " << line_ << "."; + e_what_ = ss.str(); + } + explicit IOException(std::string file, int line, const char *description) + : file_(file), line_(line), errno_(0) + { + std::stringstream ss; + ss << "IO Exception: " << description; + ss << ", file " << file_ << ", line " << line_ << "."; + e_what_ = ss.str(); + } + virtual ~IOException() throw() {} + IOException(const IOException &other) : line_(other.line_), e_what_(other.e_what_), errno_(other.errno_) {} + + int getErrorNumber() const { return errno_; } + + virtual const char *what() const throw() + { + return e_what_.c_str(); + } + }; + + class PortNotOpenedException : public std::exception + { + // Disable copy constructors + const PortNotOpenedException &operator=(PortNotOpenedException); + std::string e_what_; + + public: + PortNotOpenedException(const char *description) + { + std::stringstream ss; + ss << "PortNotOpenedException " << description << " failed."; + e_what_ = ss.str(); + } + PortNotOpenedException(const PortNotOpenedException &other) : e_what_(other.e_what_) {} + virtual ~PortNotOpenedException() throw() {} + virtual const char *what() const throw() + { + return e_what_.c_str(); + } + }; + + /*! + * Structure that describes a serial device. + */ + struct PortInfo + { + + /*! Address of the serial port (this can be passed to the constructor of Serial). */ + std::string port; + + /*! Human readable description of serial device if available. */ + std::string description; + + /*! Hardware ID (e.g. VID:PID of USB serial devices) or "n/a" if not available. */ + std::string hardware_id; + }; + + /* Lists the serial ports available on the system + * + * Returns a vector of available serial ports, each represented + * by a serial::PortInfo data structure: + * + * \return vector of serial::PortInfo. 
+ */ + std::vector + list_ports(); + +} // namespace serial + +#endif diff --git a/gimbal_ctrl/IOs/serial/include/serial/v8stdint.h b/gimbal_ctrl/IOs/serial/include/serial/v8stdint.h new file mode 100755 index 0000000..f3be96b --- /dev/null +++ b/gimbal_ctrl/IOs/serial/include/serial/v8stdint.h @@ -0,0 +1,57 @@ +// This header is from the v8 google project: +// http://code.google.com/p/v8/source/browse/trunk/include/v8stdint.h + +// Copyright 2012 the V8 project authors. All rights reserved. +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Load definitions of standard types. + +#ifndef V8STDINT_H_ +#define V8STDINT_H_ + +#include +#include + +#if defined(_WIN32) && !defined(__MINGW32__) + +typedef signed char int8_t; +typedef unsigned char uint8_t; +typedef short int16_t; // NOLINT +typedef unsigned short uint16_t; // NOLINT +typedef int int32_t; +typedef unsigned int uint32_t; +typedef __int64 int64_t; +typedef unsigned __int64 uint64_t; +// intptr_t and friends are defined in crtdefs.h through stdio.h. + +#else + +#include + +#endif + +#endif // V8STDINT_H_ diff --git a/gimbal_ctrl/IOs/serial/src/impl/list_ports/list_ports_linux.cc b/gimbal_ctrl/IOs/serial/src/impl/list_ports/list_ports_linux.cc new file mode 100755 index 0000000..db2afb2 --- /dev/null +++ b/gimbal_ctrl/IOs/serial/src/impl/list_ports/list_ports_linux.cc @@ -0,0 +1,336 @@ +#if defined(__linux__) + +/* + * Copyright (c) 2014 Craig Lilley + * This software is made available under the terms of the MIT licence. 
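[Editor's note] A brief, hypothetical sketch of enumerating ports with list_ports() and the PortInfo structure declared above; the printed fields follow the struct's own documentation:

    #include <iostream>
    #include <vector>
    #include "serial/serial.h"

    int main()
    {
        std::vector<serial::PortInfo> ports = serial::list_ports();

        for (const serial::PortInfo& info : ports)
        {
            // port:        device address, e.g. "/dev/ttyUSB0" or "COM3"
            // description: human-readable description, or the device name
            // hardware_id: e.g. "USB VID:PID=xxxx:xxxx ..." or "n/a"
            std::cout << info.port << "  "
                      << info.description << "  "
                      << info.hardware_id << std::endl;
        }
        return 0;
    }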
+ * A copy of the licence can be obtained from: + * http://opensource.org/licenses/MIT + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "serial/serial.h" + +using serial::PortInfo; +using std::istringstream; +using std::ifstream; +using std::getline; +using std::vector; +using std::string; +using std::cout; +using std::endl; + +static vector glob(const vector& patterns); +static string basename(const string& path); +static string dirname(const string& path); +static bool path_exists(const string& path); +static string realpath(const string& path); +static string usb_sysfs_friendly_name(const string& sys_usb_path); +static vector get_sysfs_info(const string& device_path); +static string read_line(const string& file); +static string usb_sysfs_hw_string(const string& sysfs_path); +static string format(const char* format, ...); + +vector +glob(const vector& patterns) +{ + vector paths_found; + + if(patterns.size() == 0) + return paths_found; + + glob_t glob_results; + + int glob_retval = glob(patterns[0].c_str(), 0, NULL, &glob_results); + + vector::const_iterator iter = patterns.begin(); + + while(++iter != patterns.end()) + { + glob_retval = glob(iter->c_str(), GLOB_APPEND, NULL, &glob_results); + } + + for(int path_index = 0; path_index < glob_results.gl_pathc; path_index++) + { + paths_found.push_back(glob_results.gl_pathv[path_index]); + } + + globfree(&glob_results); + + return paths_found; +} + +string +basename(const string& path) +{ + size_t pos = path.rfind("/"); + + if(pos == std::string::npos) + return path; + + return string(path, pos+1, string::npos); +} + +string +dirname(const string& path) +{ + size_t pos = path.rfind("/"); + + if(pos == std::string::npos) + return path; + else if(pos == 0) + return "/"; + + return string(path, 0, pos); +} + +bool +path_exists(const string& path) +{ + struct stat sb; + + if( stat(path.c_str(), &sb ) == 0 ) + return true; + + return false; +} + +string +realpath(const string& path) +{ + char* real_path = realpath(path.c_str(), NULL); + + string result; + + if(real_path != NULL) + { + result = real_path; + + free(real_path); + } + + return result; +} + +string +usb_sysfs_friendly_name(const string& sys_usb_path) +{ + unsigned int device_number = 0; + + istringstream( read_line(sys_usb_path + "/devnum") ) >> device_number; + + string manufacturer = read_line( sys_usb_path + "/manufacturer" ); + + string product = read_line( sys_usb_path + "/product" ); + + string serial = read_line( sys_usb_path + "/serial" ); + + if( manufacturer.empty() && product.empty() && serial.empty() ) + return ""; + + return format("%s %s %s", manufacturer.c_str(), product.c_str(), serial.c_str() ); +} + +vector +get_sysfs_info(const string& device_path) +{ + string device_name = basename( device_path ); + + string friendly_name; + + string hardware_id; + + string sys_device_path = format( "/sys/class/tty/%s/device", device_name.c_str() ); + + if( device_name.compare(0,6,"ttyUSB") == 0 ) + { + sys_device_path = dirname( dirname( realpath( sys_device_path ) ) ); + + if( path_exists( sys_device_path ) ) + { + friendly_name = usb_sysfs_friendly_name( sys_device_path ); + + hardware_id = usb_sysfs_hw_string( sys_device_path ); + } + } + else if( device_name.compare(0,6,"ttyACM") == 0 ) + { + sys_device_path = dirname( realpath( sys_device_path ) ); + + if( path_exists( sys_device_path ) ) + { + friendly_name = usb_sysfs_friendly_name( sys_device_path ); + + hardware_id = 
usb_sysfs_hw_string( sys_device_path ); + } + } + else + { + // Try to read ID string of PCI device + + string sys_id_path = sys_device_path + "/id"; + + if( path_exists( sys_id_path ) ) + hardware_id = read_line( sys_id_path ); + } + + if( friendly_name.empty() ) + friendly_name = device_name; + + if( hardware_id.empty() ) + hardware_id = "n/a"; + + vector result; + result.push_back(friendly_name); + result.push_back(hardware_id); + + return result; +} + +string +read_line(const string& file) +{ + ifstream ifs(file.c_str(), ifstream::in); + + string line; + + if(ifs) + { + getline(ifs, line); + } + + return line; +} + +string +format(const char* format, ...) +{ + va_list ap; + + size_t buffer_size_bytes = 256; + + string result; + + char* buffer = (char*)malloc(buffer_size_bytes); + + if( buffer == NULL ) + return result; + + bool done = false; + + unsigned int loop_count = 0; + + while(!done) + { + va_start(ap, format); + + int return_value = vsnprintf(buffer, buffer_size_bytes, format, ap); + + if( return_value < 0 ) + { + done = true; + } + else if( return_value >= buffer_size_bytes ) + { + // Realloc and try again. + + buffer_size_bytes = return_value + 1; + + char* new_buffer_ptr = (char*)realloc(buffer, buffer_size_bytes); + + if( new_buffer_ptr == NULL ) + { + done = true; + } + else + { + buffer = new_buffer_ptr; + } + } + else + { + result = buffer; + done = true; + } + + va_end(ap); + + if( ++loop_count > 5 ) + done = true; + } + + free(buffer); + + return result; +} + +string +usb_sysfs_hw_string(const string& sysfs_path) +{ + string serial_number = read_line( sysfs_path + "/serial" ); + + if( serial_number.length() > 0 ) + { + serial_number = format( "SNR=%s", serial_number.c_str() ); + } + + string vid = read_line( sysfs_path + "/idVendor" ); + + string pid = read_line( sysfs_path + "/idProduct" ); + + return format("USB VID:PID=%s:%s %s", vid.c_str(), pid.c_str(), serial_number.c_str() ); +} + +vector +serial::list_ports() +{ + vector results; + + vector search_globs; + search_globs.push_back("/dev/ttyACM*"); + search_globs.push_back("/dev/ttyS*"); + search_globs.push_back("/dev/ttyUSB*"); + search_globs.push_back("/dev/tty.*"); + search_globs.push_back("/dev/cu.*"); + search_globs.push_back("/dev/rfcomm*"); + + vector devices_found = glob( search_globs ); + + vector::iterator iter = devices_found.begin(); + + while( iter != devices_found.end() ) + { + string device = *iter++; + + vector sysfs_info = get_sysfs_info( device ); + + string friendly_name = sysfs_info[0]; + + string hardware_id = sysfs_info[1]; + + PortInfo device_entry; + device_entry.port = device; + device_entry.description = friendly_name; + device_entry.hardware_id = hardware_id; + + results.push_back( device_entry ); + + } + + return results; +} + +#endif // defined(__linux__) diff --git a/gimbal_ctrl/IOs/serial/src/impl/list_ports/list_ports_osx.cc b/gimbal_ctrl/IOs/serial/src/impl/list_ports/list_ports_osx.cc new file mode 100755 index 0000000..333c55c --- /dev/null +++ b/gimbal_ctrl/IOs/serial/src/impl/list_ports/list_ports_osx.cc @@ -0,0 +1,286 @@ +#if defined(__APPLE__) + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include "serial/serial.h" + +using serial::PortInfo; +using std::string; +using std::vector; + +#define HARDWARE_ID_STRING_LENGTH 128 + +string cfstring_to_string( CFStringRef cfstring ); +string get_device_path( io_object_t& serial_port ); +string get_class_name( io_object_t& obj ); +io_registry_entry_t get_parent_iousb_device( 
io_object_t& serial_port ); +string get_string_property( io_object_t& device, const char* property ); +uint16_t get_int_property( io_object_t& device, const char* property ); +string rtrim(const string& str); + +string +cfstring_to_string( CFStringRef cfstring ) +{ + char cstring[MAXPATHLEN]; + string result; + + if( cfstring ) + { + Boolean success = CFStringGetCString( cfstring, + cstring, + sizeof(cstring), + kCFStringEncodingASCII ); + + if( success ) + result = cstring; + } + + return result; +} + +string +get_device_path( io_object_t& serial_port ) +{ + CFTypeRef callout_path; + string device_path; + + callout_path = IORegistryEntryCreateCFProperty( serial_port, + CFSTR(kIOCalloutDeviceKey), + kCFAllocatorDefault, + 0 ); + + if (callout_path) + { + if( CFGetTypeID(callout_path) == CFStringGetTypeID() ) + device_path = cfstring_to_string( static_cast(callout_path) ); + + CFRelease(callout_path); + } + + return device_path; +} + +string +get_class_name( io_object_t& obj ) +{ + string result; + io_name_t class_name; + kern_return_t kern_result; + + kern_result = IOObjectGetClass( obj, class_name ); + + if( kern_result == KERN_SUCCESS ) + result = class_name; + + return result; +} + +io_registry_entry_t +get_parent_iousb_device( io_object_t& serial_port ) +{ + io_object_t device = serial_port; + io_registry_entry_t parent = 0; + io_registry_entry_t result = 0; + kern_return_t kern_result = KERN_FAILURE; + string name = get_class_name(device); + + // Walk the IO Registry tree looking for this devices parent IOUSBDevice. + while( name != "IOUSBDevice" ) + { + kern_result = IORegistryEntryGetParentEntry( device, + kIOServicePlane, + &parent ); + + if(kern_result != KERN_SUCCESS) + { + result = 0; + break; + } + + device = parent; + + name = get_class_name(device); + } + + if(kern_result == KERN_SUCCESS) + result = device; + + return result; +} + +string +get_string_property( io_object_t& device, const char* property ) +{ + string property_name; + + if( device ) + { + CFStringRef property_as_cfstring = CFStringCreateWithCString ( + kCFAllocatorDefault, + property, + kCFStringEncodingASCII ); + + CFTypeRef name_as_cfstring = IORegistryEntryCreateCFProperty( + device, + property_as_cfstring, + kCFAllocatorDefault, + 0 ); + + if( name_as_cfstring ) + { + if( CFGetTypeID(name_as_cfstring) == CFStringGetTypeID() ) + property_name = cfstring_to_string( static_cast(name_as_cfstring) ); + + CFRelease(name_as_cfstring); + } + + if(property_as_cfstring) + CFRelease(property_as_cfstring); + } + + return property_name; +} + +uint16_t +get_int_property( io_object_t& device, const char* property ) +{ + uint16_t result = 0; + + if( device ) + { + CFStringRef property_as_cfstring = CFStringCreateWithCString ( + kCFAllocatorDefault, + property, + kCFStringEncodingASCII ); + + CFTypeRef number = IORegistryEntryCreateCFProperty( device, + property_as_cfstring, + kCFAllocatorDefault, + 0 ); + + if(property_as_cfstring) + CFRelease(property_as_cfstring); + + if( number ) + { + if( CFGetTypeID(number) == CFNumberGetTypeID() ) + { + bool success = CFNumberGetValue( static_cast(number), + kCFNumberSInt16Type, + &result ); + + if( !success ) + result = 0; + } + + CFRelease(number); + } + + } + + return result; +} + +string rtrim(const string& str) +{ + string result = str; + + string whitespace = " \t\f\v\n\r"; + + std::size_t found = result.find_last_not_of(whitespace); + + if (found != std::string::npos) + result.erase(found+1); + else + result.clear(); + + return result; +} + +vector +serial::list_ports(void) +{ 
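+ // Descriptive note: this Apple implementation builds an IOKit matching dictionary for
+ // kIOSerialBSDServiceValue / kIOSerialBSDAllTypes, iterates the matching services, and for
+ // each callout device walks the IORegistry up to the parent IOUSBDevice (see
+ // get_parent_iousb_device above) to recover the USB product name, vendor name, serial
+ // number and the idVendor/idProduct pair used to build the hardware_id string.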
+ vector devices_found; + CFMutableDictionaryRef classes_to_match; + io_iterator_t serial_port_iterator; + io_object_t serial_port; + mach_port_t master_port; + kern_return_t kern_result; + + kern_result = IOMasterPort(MACH_PORT_NULL, &master_port); + + if(kern_result != KERN_SUCCESS) + return devices_found; + + classes_to_match = IOServiceMatching(kIOSerialBSDServiceValue); + + if (classes_to_match == NULL) + return devices_found; + + CFDictionarySetValue( classes_to_match, + CFSTR(kIOSerialBSDTypeKey), + CFSTR(kIOSerialBSDAllTypes) ); + + kern_result = IOServiceGetMatchingServices(master_port, classes_to_match, &serial_port_iterator); + + if (KERN_SUCCESS != kern_result) + return devices_found; + + while ( (serial_port = IOIteratorNext(serial_port_iterator)) ) + { + string device_path = get_device_path( serial_port ); + io_registry_entry_t parent = get_parent_iousb_device( serial_port ); + IOObjectRelease(serial_port); + + if( device_path.empty() ) + continue; + + PortInfo port_info; + port_info.port = device_path; + port_info.description = "n/a"; + port_info.hardware_id = "n/a"; + + string device_name = rtrim( get_string_property( parent, "USB Product Name" ) ); + string vendor_name = rtrim( get_string_property( parent, "USB Vendor Name") ); + string description = rtrim( vendor_name + " " + device_name ); + if( !description.empty() ) + port_info.description = description; + + string serial_number = rtrim(get_string_property( parent, "USB Serial Number" ) ); + uint16_t vendor_id = get_int_property( parent, "idVendor" ); + uint16_t product_id = get_int_property( parent, "idProduct" ); + + if( vendor_id && product_id ) + { + char cstring[HARDWARE_ID_STRING_LENGTH]; + + if(serial_number.empty()) + serial_number = "None"; + + int ret = snprintf( cstring, HARDWARE_ID_STRING_LENGTH, "USB VID:PID=%04x:%04x SNR=%s", + vendor_id, + product_id, + serial_number.c_str() ); + + if( (ret >= 0) && (ret < HARDWARE_ID_STRING_LENGTH) ) + port_info.hardware_id = cstring; + } + + devices_found.push_back(port_info); + } + + IOObjectRelease(serial_port_iterator); + return devices_found; +} + +#endif // defined(__APPLE__) diff --git a/gimbal_ctrl/IOs/serial/src/impl/list_ports/list_ports_win.cc b/gimbal_ctrl/IOs/serial/src/impl/list_ports/list_ports_win.cc new file mode 100755 index 0000000..7da40c4 --- /dev/null +++ b/gimbal_ctrl/IOs/serial/src/impl/list_ports/list_ports_win.cc @@ -0,0 +1,152 @@ +#if defined(_WIN32) + +/* + * Copyright (c) 2014 Craig Lilley + * This software is made available under the terms of the MIT licence. 
+ * A copy of the licence can be obtained from: + * http://opensource.org/licenses/MIT + */ + +#include "serial/serial.h" +#include +#include +#include +#include +#include +#include + +using serial::PortInfo; +using std::vector; +using std::string; + +static const DWORD port_name_max_length = 256; +static const DWORD friendly_name_max_length = 256; +static const DWORD hardware_id_max_length = 256; + +// Convert a wide Unicode string to an UTF8 string +std::string utf8_encode(const std::wstring &wstr) +{ + int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wstr[0], (int)wstr.size(), NULL, 0, NULL, NULL); + std::string strTo( size_needed, 0 ); + WideCharToMultiByte (CP_UTF8, 0, &wstr[0], (int)wstr.size(), &strTo[0], size_needed, NULL, NULL); + return strTo; +} + +vector +serial::list_ports() +{ + vector devices_found; + + HDEVINFO device_info_set = SetupDiGetClassDevs( + (const GUID *) &GUID_DEVCLASS_PORTS, + NULL, + NULL, + DIGCF_PRESENT); + + unsigned int device_info_set_index = 0; + SP_DEVINFO_DATA device_info_data; + + device_info_data.cbSize = sizeof(SP_DEVINFO_DATA); + + while(SetupDiEnumDeviceInfo(device_info_set, device_info_set_index, &device_info_data)) + { + device_info_set_index++; + + // Get port name + + HKEY hkey = SetupDiOpenDevRegKey( + device_info_set, + &device_info_data, + DICS_FLAG_GLOBAL, + 0, + DIREG_DEV, + KEY_READ); + + TCHAR port_name[port_name_max_length]; + DWORD port_name_length = port_name_max_length; + + LONG return_code = RegQueryValueEx( + hkey, + _T("PortName"), + NULL, + NULL, + (LPBYTE)port_name, + &port_name_length); + + RegCloseKey(hkey); + + if(return_code != EXIT_SUCCESS) + continue; + + if(port_name_length > 0 && port_name_length <= port_name_max_length) + port_name[port_name_length-1] = '\0'; + else + port_name[0] = '\0'; + + // Ignore parallel ports + + if(_tcsstr(port_name, _T("LPT")) != NULL) + continue; + + // Get port friendly name + + TCHAR friendly_name[friendly_name_max_length]; + DWORD friendly_name_actual_length = 0; + + BOOL got_friendly_name = SetupDiGetDeviceRegistryProperty( + device_info_set, + &device_info_data, + SPDRP_FRIENDLYNAME, + NULL, + (PBYTE)friendly_name, + friendly_name_max_length, + &friendly_name_actual_length); + + if(got_friendly_name == TRUE && friendly_name_actual_length > 0) + friendly_name[friendly_name_actual_length-1] = '\0'; + else + friendly_name[0] = '\0'; + + // Get hardware ID + + TCHAR hardware_id[hardware_id_max_length]; + DWORD hardware_id_actual_length = 0; + + BOOL got_hardware_id = SetupDiGetDeviceRegistryProperty( + device_info_set, + &device_info_data, + SPDRP_HARDWAREID, + NULL, + (PBYTE)hardware_id, + hardware_id_max_length, + &hardware_id_actual_length); + + if(got_hardware_id == TRUE && hardware_id_actual_length > 0) + hardware_id[hardware_id_actual_length-1] = '\0'; + else + hardware_id[0] = '\0'; + + #ifdef UNICODE + std::string portName = utf8_encode(port_name); + std::string friendlyName = utf8_encode(friendly_name); + std::string hardwareId = utf8_encode(hardware_id); + #else + std::string portName = port_name; + std::string friendlyName = friendly_name; + std::string hardwareId = hardware_id; + #endif + + PortInfo port_entry; + port_entry.port = portName; + port_entry.description = friendlyName; + port_entry.hardware_id = hardwareId; + + devices_found.push_back(port_entry); + } + + SetupDiDestroyDeviceInfoList(device_info_set); + + return devices_found; +} + +#endif // #if defined(_WIN32) diff --git a/gimbal_ctrl/IOs/serial/src/impl/unix.cc b/gimbal_ctrl/IOs/serial/src/impl/unix.cc new file 
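For reference, a minimal usage sketch for the three list_ports() backends above. It relies only on serial::list_ports() and the PortInfo fields (port, description, hardware_id) declared in serial/serial.h; the printing style is illustrative only.

#include <iostream>
#include <vector>

#include "serial/serial.h"

int main()
{
  // Enumerate every serial device reported by the platform-specific backend.
  std::vector<serial::PortInfo> ports = serial::list_ports();

  for (std::vector<serial::PortInfo>::const_iterator it = ports.begin();
       it != ports.end(); ++it)
  {
    std::cout << it->port << " | " << it->description
              << " | " << it->hardware_id << std::endl;
  }

  return 0;
}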
mode 100755 index 0000000..a40b0fa --- /dev/null +++ b/gimbal_ctrl/IOs/serial/src/impl/unix.cc @@ -0,0 +1,1084 @@ +/* Copyright 2012 William Woodall and John Harrison + * + * Additional Contributors: Christopher Baker @bakercp + */ + +#if !defined(_WIN32) + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(__linux__) +# include +#endif + +#include +#include +#include +#ifdef __MACH__ +#include +#include +#include +#endif + +#include "serial/impl/unix.h" + +#ifndef TIOCINQ +#ifdef FIONREAD +#define TIOCINQ FIONREAD +#else +#define TIOCINQ 0x541B +#endif +#endif + +#if defined(MAC_OS_X_VERSION_10_3) && (MAC_OS_X_VERSION_MIN_REQUIRED >= MAC_OS_X_VERSION_10_3) +#include +#endif + +using std::string; +using std::stringstream; +using std::invalid_argument; +using serial::MillisecondTimer; +using serial::Serial; +using serial::SerialException; +using serial::PortNotOpenedException; +using serial::IOException; + + +MillisecondTimer::MillisecondTimer (const uint32_t millis) + : expiry(timespec_now()) +{ + int64_t tv_nsec = expiry.tv_nsec + (millis * 1e6); + if (tv_nsec >= 1e9) { + int64_t sec_diff = tv_nsec / static_cast (1e9); + expiry.tv_nsec = tv_nsec % static_cast(1e9); + expiry.tv_sec += sec_diff; + } else { + expiry.tv_nsec = tv_nsec; + } +} + +int64_t +MillisecondTimer::remaining () +{ + timespec now(timespec_now()); + int64_t millis = (expiry.tv_sec - now.tv_sec) * 1e3; + millis += (expiry.tv_nsec - now.tv_nsec) / 1e6; + return millis; +} + +timespec +MillisecondTimer::timespec_now () +{ + timespec time; +# ifdef __MACH__ // OS X does not have clock_gettime, use clock_get_time + clock_serv_t cclock; + mach_timespec_t mts; + host_get_clock_service(mach_host_self(), SYSTEM_CLOCK, &cclock); + clock_get_time(cclock, &mts); + mach_port_deallocate(mach_task_self(), cclock); + time.tv_sec = mts.tv_sec; + time.tv_nsec = mts.tv_nsec; +# else + clock_gettime(CLOCK_MONOTONIC, &time); +# endif + return time; +} + +timespec +timespec_from_ms (const uint32_t millis) +{ + timespec time; + time.tv_sec = millis / 1e3; + time.tv_nsec = (millis - (time.tv_sec * 1e3)) * 1e6; + return time; +} + +Serial::SerialImpl::SerialImpl (const string &port, unsigned long baudrate, + bytesize_t bytesize, + parity_t parity, stopbits_t stopbits, + flowcontrol_t flowcontrol) + : port_ (port), fd_ (-1), is_open_ (false), xonxoff_ (false), rtscts_ (false), + baudrate_ (baudrate), parity_ (parity), + bytesize_ (bytesize), stopbits_ (stopbits), flowcontrol_ (flowcontrol) +{ + pthread_mutex_init(&this->read_mutex, NULL); + pthread_mutex_init(&this->write_mutex, NULL); + if (port_.empty () == false) + open (); +} + +Serial::SerialImpl::~SerialImpl () +{ + close(); + pthread_mutex_destroy(&this->read_mutex); + pthread_mutex_destroy(&this->write_mutex); +} + +void +Serial::SerialImpl::open () +{ + if (port_.empty ()) { + throw invalid_argument ("Empty port is invalid."); + } + if (is_open_ == true) { + throw SerialException ("Serial port already open."); + } + + fd_ = ::open (port_.c_str(), O_RDWR | O_NOCTTY | O_NONBLOCK); + + if (fd_ == -1) { + switch (errno) { + case EINTR: + // Recurse because this is a recoverable error. 
+ open (); + return; + case ENFILE: + case EMFILE: + THROW (IOException, "Too many file handles open."); + default: + THROW (IOException, errno); + } + } + + reconfigurePort(); + is_open_ = true; +} + +void +Serial::SerialImpl::reconfigurePort () +{ + if (fd_ == -1) { + // Can only operate on a valid file descriptor + THROW (IOException, "Invalid file descriptor, is the serial port open?"); + } + + struct termios options; // The options for the file descriptor + + if (tcgetattr(fd_, &options) == -1) { + THROW (IOException, "::tcgetattr"); + } + + // set up raw mode / no echo / binary + options.c_cflag |= (tcflag_t) (CLOCAL | CREAD); + options.c_lflag &= (tcflag_t) ~(ICANON | ECHO | ECHOE | ECHOK | ECHONL | + ISIG | IEXTEN); //|ECHOPRT + + options.c_oflag &= (tcflag_t) ~(OPOST); + options.c_iflag &= (tcflag_t) ~(INLCR | IGNCR | ICRNL | IGNBRK); +#ifdef IUCLC + options.c_iflag &= (tcflag_t) ~IUCLC; +#endif +#ifdef PARMRK + options.c_iflag &= (tcflag_t) ~PARMRK; +#endif + + // setup baud rate + bool custom_baud = false; + speed_t baud; + switch (baudrate_) { +#ifdef B0 + case 0: baud = B0; break; +#endif +#ifdef B50 + case 50: baud = B50; break; +#endif +#ifdef B75 + case 75: baud = B75; break; +#endif +#ifdef B110 + case 110: baud = B110; break; +#endif +#ifdef B134 + case 134: baud = B134; break; +#endif +#ifdef B150 + case 150: baud = B150; break; +#endif +#ifdef B200 + case 200: baud = B200; break; +#endif +#ifdef B300 + case 300: baud = B300; break; +#endif +#ifdef B600 + case 600: baud = B600; break; +#endif +#ifdef B1200 + case 1200: baud = B1200; break; +#endif +#ifdef B1800 + case 1800: baud = B1800; break; +#endif +#ifdef B2400 + case 2400: baud = B2400; break; +#endif +#ifdef B4800 + case 4800: baud = B4800; break; +#endif +#ifdef B7200 + case 7200: baud = B7200; break; +#endif +#ifdef B9600 + case 9600: baud = B9600; break; +#endif +#ifdef B14400 + case 14400: baud = B14400; break; +#endif +#ifdef B19200 + case 19200: baud = B19200; break; +#endif +#ifdef B28800 + case 28800: baud = B28800; break; +#endif +#ifdef B57600 + case 57600: baud = B57600; break; +#endif +#ifdef B76800 + case 76800: baud = B76800; break; +#endif +#ifdef B38400 + case 38400: baud = B38400; break; +#endif +#ifdef B115200 + case 115200: baud = B115200; break; +#endif +#ifdef B128000 + case 128000: baud = B128000; break; +#endif +#ifdef B153600 + case 153600: baud = B153600; break; +#endif +#ifdef B230400 + case 230400: baud = B230400; break; +#endif +#ifdef B256000 + case 256000: baud = B256000; break; +#endif +#ifdef B460800 + case 460800: baud = B460800; break; +#endif +#ifdef B500000 + case 500000: baud = B500000; break; +#endif +#ifdef B576000 + case 576000: baud = B576000; break; +#endif +#ifdef B921600 + case 921600: baud = B921600; break; +#endif +#ifdef B1000000 + case 1000000: baud = B1000000; break; +#endif +#ifdef B1152000 + case 1152000: baud = B1152000; break; +#endif +#ifdef B1500000 + case 1500000: baud = B1500000; break; +#endif +#ifdef B2000000 + case 2000000: baud = B2000000; break; +#endif +#ifdef B2500000 + case 2500000: baud = B2500000; break; +#endif +#ifdef B3000000 + case 3000000: baud = B3000000; break; +#endif +#ifdef B3500000 + case 3500000: baud = B3500000; break; +#endif +#ifdef B4000000 + case 4000000: baud = B4000000; break; +#endif + default: + custom_baud = true; + } + if (custom_baud == false) { +#ifdef _BSD_SOURCE + ::cfsetspeed(&options, baud); +#else + ::cfsetispeed(&options, baud); + ::cfsetospeed(&options, baud); +#endif + } + + // setup char len + options.c_cflag &= 
(tcflag_t) ~CSIZE; + if (bytesize_ == eightbits) + options.c_cflag |= CS8; + else if (bytesize_ == sevenbits) + options.c_cflag |= CS7; + else if (bytesize_ == sixbits) + options.c_cflag |= CS6; + else if (bytesize_ == fivebits) + options.c_cflag |= CS5; + else + throw invalid_argument ("invalid char len"); + // setup stopbits + if (stopbits_ == stopbits_one) + options.c_cflag &= (tcflag_t) ~(CSTOPB); + else if (stopbits_ == stopbits_one_point_five) + // ONE POINT FIVE same as TWO.. there is no POSIX support for 1.5 + options.c_cflag |= (CSTOPB); + else if (stopbits_ == stopbits_two) + options.c_cflag |= (CSTOPB); + else + throw invalid_argument ("invalid stop bit"); + // setup parity + options.c_iflag &= (tcflag_t) ~(INPCK | ISTRIP); + if (parity_ == parity_none) { + options.c_cflag &= (tcflag_t) ~(PARENB | PARODD); + } else if (parity_ == parity_even) { + options.c_cflag &= (tcflag_t) ~(PARODD); + options.c_cflag |= (PARENB); + } else if (parity_ == parity_odd) { + options.c_cflag |= (PARENB | PARODD); + } +#ifdef CMSPAR + else if (parity_ == parity_mark) { + options.c_cflag |= (PARENB | CMSPAR | PARODD); + } + else if (parity_ == parity_space) { + options.c_cflag |= (PARENB | CMSPAR); + options.c_cflag &= (tcflag_t) ~(PARODD); + } +#else + // CMSPAR is not defined on OSX. So do not support mark or space parity. + else if (parity_ == parity_mark || parity_ == parity_space) { + throw invalid_argument ("OS does not support mark or space parity"); + } +#endif // ifdef CMSPAR + else { + throw invalid_argument ("invalid parity"); + } + // setup flow control + if (flowcontrol_ == flowcontrol_none) { + xonxoff_ = false; + rtscts_ = false; + } + if (flowcontrol_ == flowcontrol_software) { + xonxoff_ = true; + rtscts_ = false; + } + if (flowcontrol_ == flowcontrol_hardware) { + xonxoff_ = false; + rtscts_ = true; + } + // xonxoff +#ifdef IXANY + if (xonxoff_) + options.c_iflag |= (IXON | IXOFF); //|IXANY) + else + options.c_iflag &= (tcflag_t) ~(IXON | IXOFF | IXANY); +#else + if (xonxoff_) + options.c_iflag |= (IXON | IXOFF); + else + options.c_iflag &= (tcflag_t) ~(IXON | IXOFF); +#endif + // rtscts +#ifdef CRTSCTS + if (rtscts_) + options.c_cflag |= (CRTSCTS); + else + options.c_cflag &= (unsigned long) ~(CRTSCTS); +#elif defined CNEW_RTSCTS + if (rtscts_) + options.c_cflag |= (CNEW_RTSCTS); + else + options.c_cflag &= (unsigned long) ~(CNEW_RTSCTS); +#else +#error "OS Support seems wrong." +#endif + + // http://www.unixwiz.net/techtips/termios-vmin-vtime.html + // this basically sets the read call up to be a polling read, + // but we are using select to ensure there is data available + // to read before each call, so we should never needlessly poll + options.c_cc[VMIN] = 0; + options.c_cc[VTIME] = 0; + + // activate settings + ::tcsetattr (fd_, TCSANOW, &options); + + // apply custom baud rate, if any + if (custom_baud == true) { + // OS X support +#if defined(MAC_OS_X_VERSION_10_4) && (MAC_OS_X_VERSION_MIN_REQUIRED >= MAC_OS_X_VERSION_10_4) + // Starting with Tiger, the IOSSIOSPEED ioctl can be used to set arbitrary baud rates + // other than those specified by POSIX. The driver for the underlying serial hardware + // ultimately determines which baud rates can be used. This ioctl sets both the input + // and output speed. 
+ speed_t new_baud = static_cast (baudrate_); + // PySerial uses IOSSIOSPEED=0x80045402 + if (-1 == ioctl (fd_, IOSSIOSPEED, &new_baud, 1)) { + THROW (IOException, errno); + } + // Linux Support +#elif defined(__linux__) && defined (TIOCSSERIAL) + struct serial_struct ser; + + if (-1 == ioctl (fd_, TIOCGSERIAL, &ser)) { + THROW (IOException, errno); + } + + // set custom divisor + ser.custom_divisor = ser.baud_base / static_cast (baudrate_); + // update flags + ser.flags &= ~ASYNC_SPD_MASK; + ser.flags |= ASYNC_SPD_CUST; + + if (-1 == ioctl (fd_, TIOCSSERIAL, &ser)) { + THROW (IOException, errno); + } +#else + throw invalid_argument ("OS does not currently support custom bauds"); +#endif + } + + // Update byte_time_ based on the new settings. + uint32_t bit_time_ns = 1e9 / baudrate_; + byte_time_ns_ = bit_time_ns * (1 + bytesize_ + parity_ + stopbits_); + + // Compensate for the stopbits_one_point_five enum being equal to int 3, + // and not 1.5. + if (stopbits_ == stopbits_one_point_five) { + byte_time_ns_ += ((1.5 - stopbits_one_point_five) * bit_time_ns); + } +} + +void +Serial::SerialImpl::close () +{ + if (is_open_ == true) { + if (fd_ != -1) { + int ret; + ret = ::close (fd_); + if (ret == 0) { + fd_ = -1; + } else { + THROW (IOException, errno); + } + } + is_open_ = false; + } +} + +bool +Serial::SerialImpl::isOpen () const +{ + return is_open_; +} + +size_t +Serial::SerialImpl::available () +{ + if (!is_open_) { + return 0; + } + int count = 0; + if (-1 == ioctl (fd_, TIOCINQ, &count)) { + THROW (IOException, errno); + } else { + return static_cast (count); + } +} + +bool +Serial::SerialImpl::waitReadable (uint32_t timeout) +{ + // Setup a select call to block for serial data or a timeout + fd_set readfds; + FD_ZERO (&readfds); + FD_SET (fd_, &readfds); + timespec timeout_ts (timespec_from_ms (timeout)); + int r = pselect (fd_ + 1, &readfds, NULL, NULL, &timeout_ts, NULL); + + if (r < 0) { + // Select was interrupted + if (errno == EINTR) { + return false; + } + // Otherwise there was some error + THROW (IOException, errno); + } + // Timeout occurred + if (r == 0) { + return false; + } + // This shouldn't happen, if r > 0 our fd has to be in the list! + if (!FD_ISSET (fd_, &readfds)) { + THROW (IOException, "select reports ready to read, but our fd isn't" + " in the list, this shouldn't happen!"); + } + // Data available to read. + return true; +} + +void +Serial::SerialImpl::waitByteTimes (size_t count) +{ + timespec wait_time = { 0, static_cast(byte_time_ns_ * count)}; + pselect (0, NULL, NULL, NULL, &wait_time, NULL); +} + +size_t +Serial::SerialImpl::read (uint8_t *buf, size_t size) +{ + // If the port is not open, throw + if (!is_open_) { + throw PortNotOpenedException ("Serial::read"); + } + size_t bytes_read = 0; + + // Calculate total timeout in milliseconds t_c + (t_m * N) + long total_timeout_ms = timeout_.read_timeout_constant; + total_timeout_ms += timeout_.read_timeout_multiplier * static_cast (size); + MillisecondTimer total_timeout(total_timeout_ms); + + // Pre-fill buffer with available bytes + { + ssize_t bytes_read_now = ::read (fd_, buf, size); + if (bytes_read_now > 0) { + bytes_read = bytes_read_now; + } + } + + while (bytes_read < size) { + int64_t timeout_remaining_ms = total_timeout.remaining(); + if (timeout_remaining_ms <= 0) { + // Timed out + break; + } + // Timeout for the next select is whichever is less of the remaining + // total read timeout and the inter-byte timeout. 
+ uint32_t timeout = std::min(static_cast (timeout_remaining_ms), + timeout_.inter_byte_timeout); + // Wait for the device to be readable, and then attempt to read. + if (waitReadable(timeout)) { + // If it's a fixed-length multi-byte read, insert a wait here so that + // we can attempt to grab the whole thing in a single IO call. Skip + // this wait if a non-max inter_byte_timeout is specified. + if (size > 1 && timeout_.inter_byte_timeout == Timeout::max()) { + size_t bytes_available = available(); + if (bytes_available + bytes_read < size) { + waitByteTimes(size - (bytes_available + bytes_read)); + } + } + // This should be non-blocking returning only what is available now + // Then returning so that select can block again. + ssize_t bytes_read_now = + ::read (fd_, buf + bytes_read, size - bytes_read); + // read should always return some data as select reported it was + // ready to read when we get to this point. + if (bytes_read_now < 1) { + // Disconnected devices, at least on Linux, show the + // behavior that they are always ready to read immediately + // but reading returns nothing. + throw SerialException ("device reports readiness to read but " + "returned no data (device disconnected?)"); + } + // Update bytes_read + bytes_read += static_cast (bytes_read_now); + // If bytes_read == size then we have read everything we need + if (bytes_read == size) { + break; + } + // If bytes_read < size then we have more to read + if (bytes_read < size) { + continue; + } + // If bytes_read > size then we have over read, which shouldn't happen + if (bytes_read > size) { + throw SerialException ("read over read, too many bytes where " + "read, this shouldn't happen, might be " + "a logical error!"); + } + } + } + return bytes_read; +} + +size_t +Serial::SerialImpl::write (const uint8_t *data, size_t length) +{ + if (is_open_ == false) { + throw PortNotOpenedException ("Serial::write"); + } + fd_set writefds; + size_t bytes_written = 0; + + // Calculate total timeout in milliseconds t_c + (t_m * N) + long total_timeout_ms = timeout_.write_timeout_constant; + total_timeout_ms += timeout_.write_timeout_multiplier * static_cast (length); + MillisecondTimer total_timeout(total_timeout_ms); + + bool first_iteration = true; + while (bytes_written < length) { + int64_t timeout_remaining_ms = total_timeout.remaining(); + // Only consider the timeout if it's not the first iteration of the loop + // otherwise a timeout of 0 won't be allowed through + if (!first_iteration && (timeout_remaining_ms <= 0)) { + // Timed out + break; + } + first_iteration = false; + + timespec timeout(timespec_from_ms(timeout_remaining_ms)); + + FD_ZERO (&writefds); + FD_SET (fd_, &writefds); + + // Do the select + int r = pselect (fd_ + 1, NULL, &writefds, NULL, &timeout, NULL); + + // Figure out what happened by looking at select's response 'r' + /** Error **/ + if (r < 0) { + // Select was interrupted, try again + if (errno == EINTR) { + continue; + } + // Otherwise there was some error + THROW (IOException, errno); + } + /** Timeout **/ + if (r == 0) { + break; + } + /** Port ready to write **/ + if (r > 0) { + // Make sure our file descriptor is in the ready to write list + if (FD_ISSET (fd_, &writefds)) { + // This will write some + ssize_t bytes_written_now = + ::write (fd_, data + bytes_written, length - bytes_written); + + // even though pselect returned readiness the call might still be + // interrupted. In that case simply retry. 
+ if (bytes_written_now == -1 && errno == EINTR) { + continue; + } + + // write should always return some data as select reported it was + // ready to write when we get to this point. + if (bytes_written_now < 1) { + // Disconnected devices, at least on Linux, show the + // behavior that they are always ready to write immediately + // but writing returns nothing. + std::stringstream strs; + strs << "device reports readiness to write but " + "returned no data (device disconnected?)"; + strs << " errno=" << errno; + strs << " bytes_written_now= " << bytes_written_now; + strs << " bytes_written=" << bytes_written; + strs << " length=" << length; + throw SerialException(strs.str().c_str()); + } + // Update bytes_written + bytes_written += static_cast (bytes_written_now); + // If bytes_written == size then we have written everything we need to + if (bytes_written == length) { + break; + } + // If bytes_written < size then we have more to write + if (bytes_written < length) { + continue; + } + // If bytes_written > size then we have over written, which shouldn't happen + if (bytes_written > length) { + throw SerialException ("write over wrote, too many bytes where " + "written, this shouldn't happen, might be " + "a logical error!"); + } + } + // This shouldn't happen, if r > 0 our fd has to be in the list! + THROW (IOException, "select reports ready to write, but our fd isn't" + " in the list, this shouldn't happen!"); + } + } + return bytes_written; +} + +void +Serial::SerialImpl::setPort (const string &port) +{ + port_ = port; +} + +string +Serial::SerialImpl::getPort () const +{ + return port_; +} + +void +Serial::SerialImpl::setTimeout (serial::Timeout &timeout) +{ + timeout_ = timeout; +} + +serial::Timeout +Serial::SerialImpl::getTimeout () const +{ + return timeout_; +} + +void +Serial::SerialImpl::setBaudrate (unsigned long baudrate) +{ + baudrate_ = baudrate; + if (is_open_) + reconfigurePort (); +} + +unsigned long +Serial::SerialImpl::getBaudrate () const +{ + return baudrate_; +} + +void +Serial::SerialImpl::setBytesize (serial::bytesize_t bytesize) +{ + bytesize_ = bytesize; + if (is_open_) + reconfigurePort (); +} + +serial::bytesize_t +Serial::SerialImpl::getBytesize () const +{ + return bytesize_; +} + +void +Serial::SerialImpl::setParity (serial::parity_t parity) +{ + parity_ = parity; + if (is_open_) + reconfigurePort (); +} + +serial::parity_t +Serial::SerialImpl::getParity () const +{ + return parity_; +} + +void +Serial::SerialImpl::setStopbits (serial::stopbits_t stopbits) +{ + stopbits_ = stopbits; + if (is_open_) + reconfigurePort (); +} + +serial::stopbits_t +Serial::SerialImpl::getStopbits () const +{ + return stopbits_; +} + +void +Serial::SerialImpl::setFlowcontrol (serial::flowcontrol_t flowcontrol) +{ + flowcontrol_ = flowcontrol; + if (is_open_) + reconfigurePort (); +} + +serial::flowcontrol_t +Serial::SerialImpl::getFlowcontrol () const +{ + return flowcontrol_; +} + +void +Serial::SerialImpl::flush () +{ + if (is_open_ == false) { + throw PortNotOpenedException ("Serial::flush"); + } + tcdrain (fd_); +} + +void +Serial::SerialImpl::flushInput () +{ + if (is_open_ == false) { + throw PortNotOpenedException ("Serial::flushInput"); + } + tcflush (fd_, TCIFLUSH); +} + +void +Serial::SerialImpl::flushOutput () +{ + if (is_open_ == false) { + throw PortNotOpenedException ("Serial::flushOutput"); + } + tcflush (fd_, TCOFLUSH); +} + +void +Serial::SerialImpl::sendBreak (int duration) +{ + if (is_open_ == false) { + throw PortNotOpenedException ("Serial::sendBreak"); + 
} + tcsendbreak (fd_, static_cast (duration / 4)); +} + +void +Serial::SerialImpl::setBreak (bool level) +{ + if (is_open_ == false) { + throw PortNotOpenedException ("Serial::setBreak"); + } + + if (level) { + if (-1 == ioctl (fd_, TIOCSBRK)) + { + stringstream ss; + ss << "setBreak failed on a call to ioctl(TIOCSBRK): " << errno << " " << strerror(errno); + throw(SerialException(ss.str().c_str())); + } + } else { + if (-1 == ioctl (fd_, TIOCCBRK)) + { + stringstream ss; + ss << "setBreak failed on a call to ioctl(TIOCCBRK): " << errno << " " << strerror(errno); + throw(SerialException(ss.str().c_str())); + } + } +} + +void +Serial::SerialImpl::setRTS (bool level) +{ + if (is_open_ == false) { + throw PortNotOpenedException ("Serial::setRTS"); + } + + int command = TIOCM_RTS; + + if (level) { + if (-1 == ioctl (fd_, TIOCMBIS, &command)) + { + stringstream ss; + ss << "setRTS failed on a call to ioctl(TIOCMBIS): " << errno << " " << strerror(errno); + throw(SerialException(ss.str().c_str())); + } + } else { + if (-1 == ioctl (fd_, TIOCMBIC, &command)) + { + stringstream ss; + ss << "setRTS failed on a call to ioctl(TIOCMBIC): " << errno << " " << strerror(errno); + throw(SerialException(ss.str().c_str())); + } + } +} + +void +Serial::SerialImpl::setDTR (bool level) +{ + if (is_open_ == false) { + throw PortNotOpenedException ("Serial::setDTR"); + } + + int command = TIOCM_DTR; + + if (level) { + if (-1 == ioctl (fd_, TIOCMBIS, &command)) + { + stringstream ss; + ss << "setDTR failed on a call to ioctl(TIOCMBIS): " << errno << " " << strerror(errno); + throw(SerialException(ss.str().c_str())); + } + } else { + if (-1 == ioctl (fd_, TIOCMBIC, &command)) + { + stringstream ss; + ss << "setDTR failed on a call to ioctl(TIOCMBIC): " << errno << " " << strerror(errno); + throw(SerialException(ss.str().c_str())); + } + } +} + +bool +Serial::SerialImpl::waitForChange () +{ +#ifndef TIOCMIWAIT + +while (is_open_ == true) { + + int status; + + if (-1 == ioctl (fd_, TIOCMGET, &status)) + { + stringstream ss; + ss << "waitForChange failed on a call to ioctl(TIOCMGET): " << errno << " " << strerror(errno); + throw(SerialException(ss.str().c_str())); + } + else + { + if (0 != (status & TIOCM_CTS) + || 0 != (status & TIOCM_DSR) + || 0 != (status & TIOCM_RI) + || 0 != (status & TIOCM_CD)) + { + return true; + } + } + + usleep(1000); + } + + return false; +#else + int command = (TIOCM_CD|TIOCM_DSR|TIOCM_RI|TIOCM_CTS); + + if (-1 == ioctl (fd_, TIOCMIWAIT, &command)) { + stringstream ss; + ss << "waitForDSR failed on a call to ioctl(TIOCMIWAIT): " + << errno << " " << strerror(errno); + throw(SerialException(ss.str().c_str())); + } + return true; +#endif +} + +bool +Serial::SerialImpl::getCTS () +{ + if (is_open_ == false) { + throw PortNotOpenedException ("Serial::getCTS"); + } + + int status; + + if (-1 == ioctl (fd_, TIOCMGET, &status)) + { + stringstream ss; + ss << "getCTS failed on a call to ioctl(TIOCMGET): " << errno << " " << strerror(errno); + throw(SerialException(ss.str().c_str())); + } + else + { + return 0 != (status & TIOCM_CTS); + } +} + +bool +Serial::SerialImpl::getDSR () +{ + if (is_open_ == false) { + throw PortNotOpenedException ("Serial::getDSR"); + } + + int status; + + if (-1 == ioctl (fd_, TIOCMGET, &status)) + { + stringstream ss; + ss << "getDSR failed on a call to ioctl(TIOCMGET): " << errno << " " << strerror(errno); + throw(SerialException(ss.str().c_str())); + } + else + { + return 0 != (status & TIOCM_DSR); + } +} + +bool +Serial::SerialImpl::getRI () +{ + if (is_open_ == 
false) { + throw PortNotOpenedException ("Serial::getRI"); + } + + int status; + + if (-1 == ioctl (fd_, TIOCMGET, &status)) + { + stringstream ss; + ss << "getRI failed on a call to ioctl(TIOCMGET): " << errno << " " << strerror(errno); + throw(SerialException(ss.str().c_str())); + } + else + { + return 0 != (status & TIOCM_RI); + } +} + +bool +Serial::SerialImpl::getCD () +{ + if (is_open_ == false) { + throw PortNotOpenedException ("Serial::getCD"); + } + + int status; + + if (-1 == ioctl (fd_, TIOCMGET, &status)) + { + stringstream ss; + ss << "getCD failed on a call to ioctl(TIOCMGET): " << errno << " " << strerror(errno); + throw(SerialException(ss.str().c_str())); + } + else + { + return 0 != (status & TIOCM_CD); + } +} + +void +Serial::SerialImpl::readLock () +{ + int result = pthread_mutex_lock(&this->read_mutex); + if (result) { + THROW (IOException, result); + } +} + +void +Serial::SerialImpl::readUnlock () +{ + int result = pthread_mutex_unlock(&this->read_mutex); + if (result) { + THROW (IOException, result); + } +} + +void +Serial::SerialImpl::writeLock () +{ + int result = pthread_mutex_lock(&this->write_mutex); + if (result) { + THROW (IOException, result); + } +} + +void +Serial::SerialImpl::writeUnlock () +{ + int result = pthread_mutex_unlock(&this->write_mutex); + if (result) { + THROW (IOException, result); + } +} + +#endif // !defined(_WIN32) diff --git a/gimbal_ctrl/IOs/serial/src/impl/win.cc b/gimbal_ctrl/IOs/serial/src/impl/win.cc new file mode 100755 index 0000000..889e06f --- /dev/null +++ b/gimbal_ctrl/IOs/serial/src/impl/win.cc @@ -0,0 +1,646 @@ +#if defined(_WIN32) + +/* Copyright 2012 William Woodall and John Harrison */ + +#include + +#include "serial/impl/win.h" + +using std::string; +using std::wstring; +using std::stringstream; +using std::invalid_argument; +using serial::Serial; +using serial::Timeout; +using serial::bytesize_t; +using serial::parity_t; +using serial::stopbits_t; +using serial::flowcontrol_t; +using serial::SerialException; +using serial::PortNotOpenedException; +using serial::IOException; + +inline wstring +_prefix_port_if_needed(const wstring &input) +{ + static wstring windows_com_port_prefix = L"\\\\.\\"; + if (input.compare(0, windows_com_port_prefix.size(), windows_com_port_prefix) != 0) + { + return windows_com_port_prefix + input; + } + return input; +} + +Serial::SerialImpl::SerialImpl (const string &port, unsigned long baudrate, + bytesize_t bytesize, + parity_t parity, stopbits_t stopbits, + flowcontrol_t flowcontrol) + : port_ (port.begin(), port.end()), fd_ (INVALID_HANDLE_VALUE), is_open_ (false), + baudrate_ (baudrate), parity_ (parity), + bytesize_ (bytesize), stopbits_ (stopbits), flowcontrol_ (flowcontrol) +{ + if (port_.empty () == false) + open (); + read_mutex = CreateMutex(NULL, false, NULL); + write_mutex = CreateMutex(NULL, false, NULL); +} + +Serial::SerialImpl::~SerialImpl () +{ + this->close(); + CloseHandle(read_mutex); + CloseHandle(write_mutex); +} + +void +Serial::SerialImpl::open () +{ + if (port_.empty ()) { + throw invalid_argument ("Empty port is invalid."); + } + if (is_open_ == true) { + throw SerialException ("Serial port already open."); + } + + // See: https://github.com/wjwwood/serial/issues/84 + wstring port_with_prefix = _prefix_port_if_needed(port_); + LPCWSTR lp_port = port_with_prefix.c_str(); + fd_ = CreateFileW(lp_port, + GENERIC_READ | GENERIC_WRITE, + 0, + 0, + OPEN_EXISTING, + FILE_ATTRIBUTE_NORMAL, + 0); + + if (fd_ == INVALID_HANDLE_VALUE) { + DWORD create_file_err = GetLastError(); + 
stringstream ss; + switch (create_file_err) { + case ERROR_FILE_NOT_FOUND: + // Use this->getPort to convert to a std::string + ss << "Specified port, " << this->getPort() << ", does not exist."; + THROW (IOException, ss.str().c_str()); + default: + ss << "Unknown error opening the serial port: " << create_file_err; + THROW (IOException, ss.str().c_str()); + } + } + + reconfigurePort(); + is_open_ = true; +} + +void +Serial::SerialImpl::reconfigurePort () +{ + if (fd_ == INVALID_HANDLE_VALUE) { + // Can only operate on a valid file descriptor + THROW (IOException, "Invalid file descriptor, is the serial port open?"); + } + + DCB dcbSerialParams = {0}; + + dcbSerialParams.DCBlength=sizeof(dcbSerialParams); + + if (!GetCommState(fd_, &dcbSerialParams)) { + //error getting state + THROW (IOException, "Error getting the serial port state."); + } + + // setup baud rate + switch (baudrate_) { +#ifdef CBR_0 + case 0: dcbSerialParams.BaudRate = CBR_0; break; +#endif +#ifdef CBR_50 + case 50: dcbSerialParams.BaudRate = CBR_50; break; +#endif +#ifdef CBR_75 + case 75: dcbSerialParams.BaudRate = CBR_75; break; +#endif +#ifdef CBR_110 + case 110: dcbSerialParams.BaudRate = CBR_110; break; +#endif +#ifdef CBR_134 + case 134: dcbSerialParams.BaudRate = CBR_134; break; +#endif +#ifdef CBR_150 + case 150: dcbSerialParams.BaudRate = CBR_150; break; +#endif +#ifdef CBR_200 + case 200: dcbSerialParams.BaudRate = CBR_200; break; +#endif +#ifdef CBR_300 + case 300: dcbSerialParams.BaudRate = CBR_300; break; +#endif +#ifdef CBR_600 + case 600: dcbSerialParams.BaudRate = CBR_600; break; +#endif +#ifdef CBR_1200 + case 1200: dcbSerialParams.BaudRate = CBR_1200; break; +#endif +#ifdef CBR_1800 + case 1800: dcbSerialParams.BaudRate = CBR_1800; break; +#endif +#ifdef CBR_2400 + case 2400: dcbSerialParams.BaudRate = CBR_2400; break; +#endif +#ifdef CBR_4800 + case 4800: dcbSerialParams.BaudRate = CBR_4800; break; +#endif +#ifdef CBR_7200 + case 7200: dcbSerialParams.BaudRate = CBR_7200; break; +#endif +#ifdef CBR_9600 + case 9600: dcbSerialParams.BaudRate = CBR_9600; break; +#endif +#ifdef CBR_14400 + case 14400: dcbSerialParams.BaudRate = CBR_14400; break; +#endif +#ifdef CBR_19200 + case 19200: dcbSerialParams.BaudRate = CBR_19200; break; +#endif +#ifdef CBR_28800 + case 28800: dcbSerialParams.BaudRate = CBR_28800; break; +#endif +#ifdef CBR_57600 + case 57600: dcbSerialParams.BaudRate = CBR_57600; break; +#endif +#ifdef CBR_76800 + case 76800: dcbSerialParams.BaudRate = CBR_76800; break; +#endif +#ifdef CBR_38400 + case 38400: dcbSerialParams.BaudRate = CBR_38400; break; +#endif +#ifdef CBR_115200 + case 115200: dcbSerialParams.BaudRate = CBR_115200; break; +#endif +#ifdef CBR_128000 + case 128000: dcbSerialParams.BaudRate = CBR_128000; break; +#endif +#ifdef CBR_153600 + case 153600: dcbSerialParams.BaudRate = CBR_153600; break; +#endif +#ifdef CBR_230400 + case 230400: dcbSerialParams.BaudRate = CBR_230400; break; +#endif +#ifdef CBR_256000 + case 256000: dcbSerialParams.BaudRate = CBR_256000; break; +#endif +#ifdef CBR_460800 + case 460800: dcbSerialParams.BaudRate = CBR_460800; break; +#endif +#ifdef CBR_921600 + case 921600: dcbSerialParams.BaudRate = CBR_921600; break; +#endif + default: + // Try to blindly assign it + dcbSerialParams.BaudRate = baudrate_; + } + + // setup char len + if (bytesize_ == eightbits) + dcbSerialParams.ByteSize = 8; + else if (bytesize_ == sevenbits) + dcbSerialParams.ByteSize = 7; + else if (bytesize_ == sixbits) + dcbSerialParams.ByteSize = 6; + else if (bytesize_ == fivebits) + 
dcbSerialParams.ByteSize = 5; + else + throw invalid_argument ("invalid char len"); + + // setup stopbits + if (stopbits_ == stopbits_one) + dcbSerialParams.StopBits = ONESTOPBIT; + else if (stopbits_ == stopbits_one_point_five) + dcbSerialParams.StopBits = ONE5STOPBITS; + else if (stopbits_ == stopbits_two) + dcbSerialParams.StopBits = TWOSTOPBITS; + else + throw invalid_argument ("invalid stop bit"); + + // setup parity + if (parity_ == parity_none) { + dcbSerialParams.Parity = NOPARITY; + } else if (parity_ == parity_even) { + dcbSerialParams.Parity = EVENPARITY; + } else if (parity_ == parity_odd) { + dcbSerialParams.Parity = ODDPARITY; + } else if (parity_ == parity_mark) { + dcbSerialParams.Parity = MARKPARITY; + } else if (parity_ == parity_space) { + dcbSerialParams.Parity = SPACEPARITY; + } else { + throw invalid_argument ("invalid parity"); + } + + // setup flowcontrol + if (flowcontrol_ == flowcontrol_none) { + dcbSerialParams.fOutxCtsFlow = false; + dcbSerialParams.fRtsControl = RTS_CONTROL_DISABLE; + dcbSerialParams.fOutX = false; + dcbSerialParams.fInX = false; + } + if (flowcontrol_ == flowcontrol_software) { + dcbSerialParams.fOutxCtsFlow = false; + dcbSerialParams.fRtsControl = RTS_CONTROL_DISABLE; + dcbSerialParams.fOutX = true; + dcbSerialParams.fInX = true; + } + if (flowcontrol_ == flowcontrol_hardware) { + dcbSerialParams.fOutxCtsFlow = true; + dcbSerialParams.fRtsControl = RTS_CONTROL_HANDSHAKE; + dcbSerialParams.fOutX = false; + dcbSerialParams.fInX = false; + } + + // activate settings + if (!SetCommState(fd_, &dcbSerialParams)){ + CloseHandle(fd_); + THROW (IOException, "Error setting serial port settings."); + } + + // Setup timeouts + COMMTIMEOUTS timeouts = {0}; + timeouts.ReadIntervalTimeout = timeout_.inter_byte_timeout; + timeouts.ReadTotalTimeoutConstant = timeout_.read_timeout_constant; + timeouts.ReadTotalTimeoutMultiplier = timeout_.read_timeout_multiplier; + timeouts.WriteTotalTimeoutConstant = timeout_.write_timeout_constant; + timeouts.WriteTotalTimeoutMultiplier = timeout_.write_timeout_multiplier; + if (!SetCommTimeouts(fd_, &timeouts)) { + THROW (IOException, "Error setting timeouts."); + } +} + +void +Serial::SerialImpl::close () +{ + if (is_open_ == true) { + if (fd_ != INVALID_HANDLE_VALUE) { + int ret; + ret = CloseHandle(fd_); + if (ret == 0) { + stringstream ss; + ss << "Error while closing serial port: " << GetLastError(); + THROW (IOException, ss.str().c_str()); + } else { + fd_ = INVALID_HANDLE_VALUE; + } + } + is_open_ = false; + } +} + +bool +Serial::SerialImpl::isOpen () const +{ + return is_open_; +} + +size_t +Serial::SerialImpl::available () +{ + if (!is_open_) { + return 0; + } + COMSTAT cs; + if (!ClearCommError(fd_, NULL, &cs)) { + stringstream ss; + ss << "Error while checking status of the serial port: " << GetLastError(); + THROW (IOException, ss.str().c_str()); + } + return static_cast(cs.cbInQue); +} + +bool +Serial::SerialImpl::waitReadable (uint32_t /*timeout*/) +{ + THROW (IOException, "waitReadable is not implemented on Windows."); + return false; +} + +void +Serial::SerialImpl::waitByteTimes (size_t /*count*/) +{ + THROW (IOException, "waitByteTimes is not implemented on Windows."); +} + +size_t +Serial::SerialImpl::read (uint8_t *buf, size_t size) +{ + if (!is_open_) { + throw PortNotOpenedException ("Serial::read"); + } + DWORD bytes_read; + if (!ReadFile(fd_, buf, static_cast(size), &bytes_read, NULL)) { + stringstream ss; + ss << "Error while reading from the serial port: " << GetLastError(); + THROW (IOException, 
ss.str().c_str()); + } + return (size_t) (bytes_read); +} + +size_t +Serial::SerialImpl::write (const uint8_t *data, size_t length) +{ + if (is_open_ == false) { + throw PortNotOpenedException ("Serial::write"); + } + DWORD bytes_written; + if (!WriteFile(fd_, data, static_cast(length), &bytes_written, NULL)) { + stringstream ss; + ss << "Error while writing to the serial port: " << GetLastError(); + THROW (IOException, ss.str().c_str()); + } + return (size_t) (bytes_written); +} + +void +Serial::SerialImpl::setPort (const string &port) +{ + port_ = wstring(port.begin(), port.end()); +} + +string +Serial::SerialImpl::getPort () const +{ + return string(port_.begin(), port_.end()); +} + +void +Serial::SerialImpl::setTimeout (serial::Timeout &timeout) +{ + timeout_ = timeout; + if (is_open_) { + reconfigurePort (); + } +} + +serial::Timeout +Serial::SerialImpl::getTimeout () const +{ + return timeout_; +} + +void +Serial::SerialImpl::setBaudrate (unsigned long baudrate) +{ + baudrate_ = baudrate; + if (is_open_) { + reconfigurePort (); + } +} + +unsigned long +Serial::SerialImpl::getBaudrate () const +{ + return baudrate_; +} + +void +Serial::SerialImpl::setBytesize (serial::bytesize_t bytesize) +{ + bytesize_ = bytesize; + if (is_open_) { + reconfigurePort (); + } +} + +serial::bytesize_t +Serial::SerialImpl::getBytesize () const +{ + return bytesize_; +} + +void +Serial::SerialImpl::setParity (serial::parity_t parity) +{ + parity_ = parity; + if (is_open_) { + reconfigurePort (); + } +} + +serial::parity_t +Serial::SerialImpl::getParity () const +{ + return parity_; +} + +void +Serial::SerialImpl::setStopbits (serial::stopbits_t stopbits) +{ + stopbits_ = stopbits; + if (is_open_) { + reconfigurePort (); + } +} + +serial::stopbits_t +Serial::SerialImpl::getStopbits () const +{ + return stopbits_; +} + +void +Serial::SerialImpl::setFlowcontrol (serial::flowcontrol_t flowcontrol) +{ + flowcontrol_ = flowcontrol; + if (is_open_) { + reconfigurePort (); + } +} + +serial::flowcontrol_t +Serial::SerialImpl::getFlowcontrol () const +{ + return flowcontrol_; +} + +void +Serial::SerialImpl::flush () +{ + if (is_open_ == false) { + throw PortNotOpenedException ("Serial::flush"); + } + FlushFileBuffers (fd_); +} + +void +Serial::SerialImpl::flushInput () +{ + if (is_open_ == false) { + throw PortNotOpenedException("Serial::flushInput"); + } + PurgeComm(fd_, PURGE_RXCLEAR); +} + +void +Serial::SerialImpl::flushOutput () +{ + if (is_open_ == false) { + throw PortNotOpenedException("Serial::flushOutput"); + } + PurgeComm(fd_, PURGE_TXCLEAR); +} + +void +Serial::SerialImpl::sendBreak (int /*duration*/) +{ + THROW (IOException, "sendBreak is not supported on Windows."); +} + +void +Serial::SerialImpl::setBreak (bool level) +{ + if (is_open_ == false) { + throw PortNotOpenedException ("Serial::setBreak"); + } + if (level) { + EscapeCommFunction (fd_, SETBREAK); + } else { + EscapeCommFunction (fd_, CLRBREAK); + } +} + +void +Serial::SerialImpl::setRTS (bool level) +{ + if (is_open_ == false) { + throw PortNotOpenedException ("Serial::setRTS"); + } + if (level) { + EscapeCommFunction (fd_, SETRTS); + } else { + EscapeCommFunction (fd_, CLRRTS); + } +} + +void +Serial::SerialImpl::setDTR (bool level) +{ + if (is_open_ == false) { + throw PortNotOpenedException ("Serial::setDTR"); + } + if (level) { + EscapeCommFunction (fd_, SETDTR); + } else { + EscapeCommFunction (fd_, CLRDTR); + } +} + +bool +Serial::SerialImpl::waitForChange () +{ + if (is_open_ == false) { + throw PortNotOpenedException 
("Serial::waitForChange"); + } + DWORD dwCommEvent; + + if (!SetCommMask(fd_, EV_CTS | EV_DSR | EV_RING | EV_RLSD)) { + // Error setting communications mask + return false; + } + + if (!WaitCommEvent(fd_, &dwCommEvent, NULL)) { + // An error occurred waiting for the event. + return false; + } else { + // Event has occurred. + return true; + } +} + +bool +Serial::SerialImpl::getCTS () +{ + if (is_open_ == false) { + throw PortNotOpenedException ("Serial::getCTS"); + } + DWORD dwModemStatus; + if (!GetCommModemStatus(fd_, &dwModemStatus)) { + THROW (IOException, "Error getting the status of the CTS line."); + } + + return (MS_CTS_ON & dwModemStatus) != 0; +} + +bool +Serial::SerialImpl::getDSR () +{ + if (is_open_ == false) { + throw PortNotOpenedException ("Serial::getDSR"); + } + DWORD dwModemStatus; + if (!GetCommModemStatus(fd_, &dwModemStatus)) { + THROW (IOException, "Error getting the status of the DSR line."); + } + + return (MS_DSR_ON & dwModemStatus) != 0; +} + +bool +Serial::SerialImpl::getRI() +{ + if (is_open_ == false) { + throw PortNotOpenedException ("Serial::getRI"); + } + DWORD dwModemStatus; + if (!GetCommModemStatus(fd_, &dwModemStatus)) { + THROW (IOException, "Error getting the status of the RI line."); + } + + return (MS_RING_ON & dwModemStatus) != 0; +} + +bool +Serial::SerialImpl::getCD() +{ + if (is_open_ == false) { + throw PortNotOpenedException ("Serial::getCD"); + } + DWORD dwModemStatus; + if (!GetCommModemStatus(fd_, &dwModemStatus)) { + // Error in GetCommModemStatus; + THROW (IOException, "Error getting the status of the CD line."); + } + + return (MS_RLSD_ON & dwModemStatus) != 0; +} + +void +Serial::SerialImpl::readLock() +{ + if (WaitForSingleObject(read_mutex, INFINITE) != WAIT_OBJECT_0) { + THROW (IOException, "Error claiming read mutex."); + } +} + +void +Serial::SerialImpl::readUnlock() +{ + if (!ReleaseMutex(read_mutex)) { + THROW (IOException, "Error releasing read mutex."); + } +} + +void +Serial::SerialImpl::writeLock() +{ + if (WaitForSingleObject(write_mutex, INFINITE) != WAIT_OBJECT_0) { + THROW (IOException, "Error claiming write mutex."); + } +} + +void +Serial::SerialImpl::writeUnlock() +{ + if (!ReleaseMutex(write_mutex)) { + THROW (IOException, "Error releasing write mutex."); + } +} + +#endif // #if defined(_WIN32) + diff --git a/gimbal_ctrl/IOs/serial/src/serial.cc b/gimbal_ctrl/IOs/serial/src/serial.cc new file mode 100755 index 0000000..a9e6f84 --- /dev/null +++ b/gimbal_ctrl/IOs/serial/src/serial.cc @@ -0,0 +1,432 @@ +/* Copyright 2012 William Woodall and John Harrison */ +#include + +#if !defined(_WIN32) && !defined(__OpenBSD__) && !defined(__FreeBSD__) +# include +#endif + +#if defined (__MINGW32__) +# define alloca __builtin_alloca +#endif + +#include "serial/serial.h" + +#ifdef _WIN32 +#include "serial/impl/win.h" +#else +#include "serial/impl/unix.h" +#endif + +using std::invalid_argument; +using std::min; +using std::numeric_limits; +using std::vector; +using std::size_t; +using std::string; + +using serial::Serial; +using serial::SerialException; +using serial::IOException; +using serial::bytesize_t; +using serial::parity_t; +using serial::stopbits_t; +using serial::flowcontrol_t; + +class Serial::ScopedReadLock { +public: + ScopedReadLock(SerialImpl *pimpl) : pimpl_(pimpl) { + this->pimpl_->readLock(); + } + ~ScopedReadLock() { + this->pimpl_->readUnlock(); + } +private: + // Disable copy constructors + ScopedReadLock(const ScopedReadLock&); + const ScopedReadLock& operator=(ScopedReadLock); + + SerialImpl *pimpl_; +}; + 
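ScopedReadLock above (and ScopedWriteLock below) are plain RAII guards: the constructor acquires the implementation's read or write mutex and the destructor releases it, so every return path and every thrown exception in the public Serial methods unlocks automatically. A minimal standalone sketch of the same pattern, with illustrative names only:

#include <pthread.h>

// RAII guard sketch: lock on construction, unlock on destruction, non-copyable
// like the ScopedReadLock / ScopedWriteLock helpers used in serial.cc.
class ScopedLockSketch
{
public:
  explicit ScopedLockSketch(pthread_mutex_t *m) : m_(m) { pthread_mutex_lock(m_); }
  ~ScopedLockSketch() { pthread_mutex_unlock(m_); }
private:
  ScopedLockSketch(const ScopedLockSketch &);
  ScopedLockSketch &operator=(const ScopedLockSketch &);
  pthread_mutex_t *m_;
};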
+class Serial::ScopedWriteLock { +public: + ScopedWriteLock(SerialImpl *pimpl) : pimpl_(pimpl) { + this->pimpl_->writeLock(); + } + ~ScopedWriteLock() { + this->pimpl_->writeUnlock(); + } +private: + // Disable copy constructors + ScopedWriteLock(const ScopedWriteLock&); + const ScopedWriteLock& operator=(ScopedWriteLock); + SerialImpl *pimpl_; +}; + +Serial::Serial (const string &port, uint32_t baudrate, serial::Timeout timeout, + bytesize_t bytesize, parity_t parity, stopbits_t stopbits, + flowcontrol_t flowcontrol) + : pimpl_(new SerialImpl (port, baudrate, bytesize, parity, + stopbits, flowcontrol)) +{ + pimpl_->setTimeout(timeout); +} + +Serial::~Serial () +{ + delete pimpl_; +} + +void +Serial::open () +{ + pimpl_->open (); +} + +void +Serial::close () +{ + pimpl_->close (); +} + +bool +Serial::isOpen () const +{ + return pimpl_->isOpen (); +} + +size_t +Serial::available () +{ + return pimpl_->available (); +} + +bool +Serial::waitReadable () +{ + serial::Timeout timeout(pimpl_->getTimeout ()); + return pimpl_->waitReadable(timeout.read_timeout_constant); +} + +void +Serial::waitByteTimes (size_t count) +{ + pimpl_->waitByteTimes(count); +} + +size_t +Serial::read_ (uint8_t *buffer, size_t size) +{ + return this->pimpl_->read (buffer, size); +} + +size_t +Serial::read (uint8_t *buffer, size_t size) +{ + ScopedReadLock lock(this->pimpl_); + return this->pimpl_->read (buffer, size); +} + +size_t +Serial::read (std::vector &buffer, size_t size) +{ + ScopedReadLock lock(this->pimpl_); + uint8_t *buffer_ = new uint8_t[size]; + size_t bytes_read = 0; + + try { + bytes_read = this->pimpl_->read (buffer_, size); + } + catch (const std::exception &e) { + delete[] buffer_; + throw; + } + + buffer.insert (buffer.end (), buffer_, buffer_+bytes_read); + delete[] buffer_; + return bytes_read; +} + +size_t +Serial::read (std::string &buffer, size_t size) +{ + ScopedReadLock lock(this->pimpl_); + uint8_t *buffer_ = new uint8_t[size]; + size_t bytes_read = 0; + try { + bytes_read = this->pimpl_->read (buffer_, size); + } + catch (const std::exception &e) { + delete[] buffer_; + throw; + } + buffer.append (reinterpret_cast(buffer_), bytes_read); + delete[] buffer_; + return bytes_read; +} + +string +Serial::read (size_t size) +{ + std::string buffer; + this->read (buffer, size); + return buffer; +} + +size_t +Serial::readline (string &buffer, size_t size, string eol) +{ + ScopedReadLock lock(this->pimpl_); + size_t eol_len = eol.length (); + uint8_t *buffer_ = static_cast + (alloca (size * sizeof (uint8_t))); + size_t read_so_far = 0; + while (true) + { + size_t bytes_read = this->read_ (buffer_ + read_so_far, 1); + read_so_far += bytes_read; + if (bytes_read == 0) { + break; // Timeout occured on reading 1 byte + } + if(read_so_far < eol_len) continue; + if (string (reinterpret_cast + (buffer_ + read_so_far - eol_len), eol_len) == eol) { + break; // EOL found + } + if (read_so_far == size) { + break; // Reached the maximum read length + } + } + buffer.append(reinterpret_cast (buffer_), read_so_far); + return read_so_far; +} + +string +Serial::readline (size_t size, string eol) +{ + std::string buffer; + this->readline (buffer, size, eol); + return buffer; +} + +vector +Serial::readlines (size_t size, string eol) +{ + ScopedReadLock lock(this->pimpl_); + std::vector lines; + size_t eol_len = eol.length (); + uint8_t *buffer_ = static_cast + (alloca (size * sizeof (uint8_t))); + size_t read_so_far = 0; + size_t start_of_line = 0; + while (read_so_far < size) { + size_t bytes_read = this->read_ 
(buffer_+read_so_far, 1); + read_so_far += bytes_read; + if (bytes_read == 0) { + if (start_of_line != read_so_far) { + lines.push_back ( + string (reinterpret_cast (buffer_ + start_of_line), + read_so_far - start_of_line)); + } + break; // Timeout occured on reading 1 byte + } + if(read_so_far < eol_len) continue; + if (string (reinterpret_cast + (buffer_ + read_so_far - eol_len), eol_len) == eol) { + // EOL found + lines.push_back( + string(reinterpret_cast (buffer_ + start_of_line), + read_so_far - start_of_line)); + start_of_line = read_so_far; + } + if (read_so_far == size) { + if (start_of_line != read_so_far) { + lines.push_back( + string(reinterpret_cast (buffer_ + start_of_line), + read_so_far - start_of_line)); + } + break; // Reached the maximum read length + } + } + return lines; +} + +size_t +Serial::write (const string &data) +{ + ScopedWriteLock lock(this->pimpl_); + return this->write_ (reinterpret_cast(data.c_str()), + data.length()); +} + +size_t +Serial::write (const std::vector &data) +{ + ScopedWriteLock lock(this->pimpl_); + return this->write_ (&data[0], data.size()); +} + +size_t +Serial::write (const uint8_t *data, size_t size) +{ + ScopedWriteLock lock(this->pimpl_); + return this->write_(data, size); +} + +size_t +Serial::write_ (const uint8_t *data, size_t length) +{ + return pimpl_->write (data, length); +} + +void +Serial::setPort (const string &port) +{ + ScopedReadLock rlock(this->pimpl_); + ScopedWriteLock wlock(this->pimpl_); + bool was_open = pimpl_->isOpen (); + if (was_open) close(); + pimpl_->setPort (port); + if (was_open) open (); +} + +string +Serial::getPort () const +{ + return pimpl_->getPort (); +} + +void +Serial::setTimeout (serial::Timeout &timeout) +{ + pimpl_->setTimeout (timeout); +} + +serial::Timeout +Serial::getTimeout () const { + return pimpl_->getTimeout (); +} + +void +Serial::setBaudrate (uint32_t baudrate) +{ + pimpl_->setBaudrate (baudrate); +} + +uint32_t +Serial::getBaudrate () const +{ + return uint32_t(pimpl_->getBaudrate ()); +} + +void +Serial::setBytesize (bytesize_t bytesize) +{ + pimpl_->setBytesize (bytesize); +} + +bytesize_t +Serial::getBytesize () const +{ + return pimpl_->getBytesize (); +} + +void +Serial::setParity (parity_t parity) +{ + pimpl_->setParity (parity); +} + +parity_t +Serial::getParity () const +{ + return pimpl_->getParity (); +} + +void +Serial::setStopbits (stopbits_t stopbits) +{ + pimpl_->setStopbits (stopbits); +} + +stopbits_t +Serial::getStopbits () const +{ + return pimpl_->getStopbits (); +} + +void +Serial::setFlowcontrol (flowcontrol_t flowcontrol) +{ + pimpl_->setFlowcontrol (flowcontrol); +} + +flowcontrol_t +Serial::getFlowcontrol () const +{ + return pimpl_->getFlowcontrol (); +} + +void Serial::flush () +{ + ScopedReadLock rlock(this->pimpl_); + ScopedWriteLock wlock(this->pimpl_); + pimpl_->flush (); +} + +void Serial::flushInput () +{ + ScopedReadLock lock(this->pimpl_); + pimpl_->flushInput (); +} + +void Serial::flushOutput () +{ + ScopedWriteLock lock(this->pimpl_); + pimpl_->flushOutput (); +} + +void Serial::sendBreak (int duration) +{ + pimpl_->sendBreak (duration); +} + +void Serial::setBreak (bool level) +{ + pimpl_->setBreak (level); +} + +void Serial::setRTS (bool level) +{ + pimpl_->setRTS (level); +} + +void Serial::setDTR (bool level) +{ + pimpl_->setDTR (level); +} + +bool Serial::waitForChange() +{ + return pimpl_->waitForChange(); +} + +bool Serial::getCTS () +{ + return pimpl_->getCTS (); +} + +bool Serial::getDSR () +{ + return pimpl_->getDSR (); +} + +bool 
Serial::getRI () +{ + return pimpl_->getRI (); +} + +bool Serial::getCD () +{ + return pimpl_->getCD (); +} diff --git a/gimbal_ctrl/driver/src/FIFO/Ring_Fifo.cc b/gimbal_ctrl/driver/src/FIFO/Ring_Fifo.cc new file mode 100755 index 0000000..ba507a3 --- /dev/null +++ b/gimbal_ctrl/driver/src/FIFO/Ring_Fifo.cc @@ -0,0 +1,212 @@ +/* + * @Description : + * @Author : Aiyangsky + * @Date : 2022-08-26 21:42:10 + * @LastEditors : Aiyangsky + * @LastEditTime : 2022-08-27 03:43:49 + * @FilePath : \mavlink\src\route\Ring_Fifo.c + */ + +#include + +#include "Ring_Fifo.h" + +/** + * @description: + * @param {RING_FIFO_CB_T} *fifo fifo struct pointer + * @param {unsigned short} cell_size sizeof(cell) + * @param {unsigned char} *buffer fifo buffer address + * @param {unsigned int} buffer_lenght sizeof(buffer) + * @return {*} + * @note : + */ +void Ring_Fifo_init(RING_FIFO_CB_T *fifo, unsigned short cell_size, + unsigned char *buffer, unsigned int buffer_lenght) +{ + fifo->cell_size = cell_size; + + fifo->start = buffer; + // Remainder is taken to avoid splicing in the output so as to improve the efficiency + fifo->end = buffer + buffer_lenght - (buffer_lenght % cell_size); + fifo->in = buffer; + fifo->out = buffer; + fifo->curr_number = 0; + fifo->max_number = buffer_lenght / cell_size; +} + +/** + * @description: add a cell to fifo + * @param {RING_FIFO_CB_T} *fifo fifo struct pointer + * @param {void} *data cell data [in] + * @return {*} Success or fail + * @note : failed if without space + */ +bool Ring_Fifo_in_cell(RING_FIFO_CB_T *fifo, void *data) +{ + unsigned char *next; + unsigned char *ptemp = fifo->in; + bool ret = false; + + LOCK(); + + if (fifo->curr_number < fifo->max_number) + { + next = fifo->in + fifo->cell_size; + if (next >= fifo->end) + { + next = fifo->start; + } + fifo->in = next; + fifo->curr_number++; + memcpy(ptemp, data, fifo->cell_size); + ret = true; + } + + UNLOCK(); + + return ret; +} + +/** + * @description: add a series of cells to fifo + * @param {RING_FIFO_CB_T} *fifo + * @param {void} *data cells data [in] + * @param {unsigned short} number expect add number of cells + * @return {*} number of successful add + * @note : + */ +unsigned short Ring_Fifo_in_cells(RING_FIFO_CB_T *fifo, void *data, unsigned short number) +{ + // Number of remaining storable cells is described to simplify the calculation in the copying process. + unsigned short diff = fifo->max_number - fifo->curr_number; + unsigned short count_temp, count_temp_r; + unsigned char *next; + unsigned char *ptemp = fifo->in; + unsigned short ret; + + LOCK(); + + if (diff > number) + { + ret = number; + } + else if (diff > 0 && diff < number) + { + ret = diff; + } + else + { + ret = 0; + } + + count_temp = fifo->cell_size * ret; + next = fifo->in + count_temp; + // Moving the write pointer and the number of stored cells before + // copying data reduces the likelihood of multithreaded write conflicts. 
+ fifo->curr_number += ret; + + if (next < fifo->end) + { + fifo->in = next; + memcpy(ptemp, data, count_temp); + } + else + { + count_temp_r = fifo->end - fifo->in; + next = fifo->start + count_temp - count_temp_r; + fifo->in = next; + memcpy(ptemp, data, count_temp_r); + memcpy(fifo->start, ((unsigned char *)data) + count_temp_r, count_temp - count_temp_r); + } + + UNLOCK(); + + return ret; +} + +/** + * @description: output a cell + * @param {RING_FIFO_CB_T} *fifo + * @param {void} *data cell data [out] + * @return {*} Success or fail + * @note : fail if without cell + */ +bool Ring_Fifo_out_cell(RING_FIFO_CB_T *fifo, void *data) +{ + unsigned char *next; + unsigned char *ptemp = fifo->out; + bool ret = false; + + LOCK(); + + if (fifo->curr_number > 0) + { + next = fifo->out + fifo->cell_size; + if (next >= fifo->end) + { + next = fifo->start; + } + fifo->out = next; + fifo->curr_number--; + memcpy(data, ptemp, fifo->cell_size); + ret = true; + } + + UNLOCK(); + + return ret; +} + +/** + * @description: output a series of cells in fifo + * @param {RING_FIFO_CB_T} *fifo + * @param {void} *data cells data [out] + * @param {unsigned short} number expect out number of cells + * @return {*} number of successful output + * @note : + */ +unsigned short Ring_Fifo_out_cells(RING_FIFO_CB_T *fifo, void *data, unsigned short number) +{ + unsigned char *next; + unsigned char *ptemp = fifo->out; + unsigned short count_temp, count_temp_r; + unsigned short ret; + + LOCK(); + + if (fifo->curr_number > number) + { + ret = number; + } + else if (fifo->curr_number < number && fifo->curr_number > 0) + { + ret = fifo->curr_number; + } + else + { + ret = 0; + } + + count_temp = fifo->cell_size * ret; + next = fifo->out + count_temp; + + fifo->curr_number -= ret; + + if (next < fifo->end) + { + fifo->out = next; + memcpy(data, ptemp, count_temp); + } + else + { + count_temp_r = fifo->end - fifo->in; + next = fifo->start + count_temp - count_temp_r; + fifo->out = next; + memcpy(data, ptemp, count_temp_r); + memcpy(((unsigned char *)data) + count_temp_r, fifo->start, count_temp - count_temp_r); + } + + UNLOCK(); + + return ret; +} diff --git a/gimbal_ctrl/driver/src/FIFO/Ring_Fifo.h b/gimbal_ctrl/driver/src/FIFO/Ring_Fifo.h new file mode 100755 index 0000000..f50087c --- /dev/null +++ b/gimbal_ctrl/driver/src/FIFO/Ring_Fifo.h @@ -0,0 +1,47 @@ +/* + * @Description : + * @Author : Aiyangsky + * @Date : 2022-08-26 21:42:02 + * @LastEditors: L LC @amov + * @LastEditTime: 2023-03-03 16:12:37 + * @FilePath: \host\gimbal-sdk-multi-platform\src\FIFO\Ring_Fifo.h + */ + +#ifndef RING_FIFO_H +#define RING_FIFO_H + +#include "stdbool.h" + +#ifdef __cplusplus +extern "C" +{ +#endif + +#define LOCK() +#define UNLOCK() + + typedef struct + { + unsigned char *start; + unsigned char *in; + unsigned char *out; + unsigned char *end; + + unsigned short curr_number; + unsigned short max_number; + unsigned short cell_size; + } RING_FIFO_CB_T; + + void Ring_Fifo_init(RING_FIFO_CB_T *fifo, unsigned short cell_size, + unsigned char *buffer, unsigned int buffer_lenght); + bool Ring_Fifo_in_cell(RING_FIFO_CB_T *fifo, void *data); + unsigned short Ring_Fifo_in_cells(RING_FIFO_CB_T *fifo, void *data, unsigned short number); + bool Ring_Fifo_out_cell(RING_FIFO_CB_T *fifo, void *data); + unsigned short Ring_Fifo_out_cells(RING_FIFO_CB_T *fifo, void *data, unsigned short number); + + +#ifdef __cplusplus +} +#endif + +#endif \ No newline at end of file diff --git a/gimbal_ctrl/driver/src/G1/g1_gimbal_crc32.h 
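The Ring_Fifo module declared above is a fixed-cell-size ring buffer with no locking of its own (LOCK()/UNLOCK() expand to nothing), so callers such as the gimbal drivers below wrap every access in their own mutex. A minimal usage sketch with an illustrative 16-cell queue of 32-bit items:

```cpp
#include <cstdint>
#include <cstdio>

extern "C" {
#include "Ring_Fifo.h"
}

int main() {
    // Backing storage for 16 cells of sizeof(uint32_t) bytes each.
    static uint8_t storage[16 * sizeof(uint32_t)];
    RING_FIFO_CB_T queue;
    Ring_Fifo_init(&queue, sizeof(uint32_t), storage, sizeof(storage));

    uint32_t in = 42;
    if (!Ring_Fifo_in_cell(&queue, &in)) {
        std::printf("queue full\n");
    }

    uint32_t out = 0;
    if (Ring_Fifo_out_cell(&queue, &out)) {
        std::printf("got %u\n", (unsigned)out);   // prints "got 42"
    }
    return 0;
}
```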
b/gimbal_ctrl/driver/src/G1/g1_gimbal_crc32.h new file mode 100755 index 0000000..a974124 --- /dev/null +++ b/gimbal_ctrl/driver/src/G1/g1_gimbal_crc32.h @@ -0,0 +1,93 @@ +/* + * @Description: + * @Author: L LC @amov + * @Date: 2022-10-27 18:10:06 + * @LastEditors: L LC @amov + * @LastEditTime: 2022-10-28 14:10:02 + * @FilePath: \amov-gimbal-sdk\src\G1\g1_gimbal_crc32.h + */ +#ifndef G1_GIMBAL_CRC32_H +#define G1_GIMBAL_CRC32_H + +namespace G1 +{ + const unsigned int Crc32Table[256] = { + 0x00000000, 0x04C11DB7, 0x09823B6E, 0x0D4326D9, 0x130476DC, 0x17C56B6B, + 0x1A864DB2, 0x1E475005, 0x2608EDB8, 0x22C9F00F, 0x2F8AD6D6, 0x2B4BCB61, + 0x350C9B64, 0x31CD86D3, 0x3C8EA00A, 0x384FBDBD, 0x4C11DB70, 0x48D0C6C7, + 0x4593E01E, 0x4152FDA9, 0x5F15ADAC, 0x5BD4B01B, 0x569796C2, 0x52568B75, + 0x6A1936C8, 0x6ED82B7F, 0x639B0DA6, 0x675A1011, 0x791D4014, 0x7DDC5DA3, + 0x709F7B7A, 0x745E66CD, 0x9823B6E0, 0x9CE2AB57, 0x91A18D8E, 0x95609039, + 0x8B27C03C, 0x8FE6DD8B, 0x82A5FB52, 0x8664E6E5, 0xBE2B5B58, 0xBAEA46EF, + 0xB7A96036, 0xB3687D81, 0xAD2F2D84, 0xA9EE3033, 0xA4AD16EA, 0xA06C0B5D, + 0xD4326D90, 0xD0F37027, 0xDDB056FE, 0xD9714B49, 0xC7361B4C, 0xC3F706FB, + 0xCEB42022, 0xCA753D95, 0xF23A8028, 0xF6FB9D9F, 0xFBB8BB46, 0xFF79A6F1, + 0xE13EF6F4, 0xE5FFEB43, 0xE8BCCD9A, 0xEC7DD02D, 0x34867077, 0x30476DC0, + 0x3D044B19, 0x39C556AE, 0x278206AB, 0x23431B1C, 0x2E003DC5, 0x2AC12072, + 0x128E9DCF, 0x164F8078, 0x1B0CA6A1, 0x1FCDBB16, 0x018AEB13, 0x054BF6A4, + 0x0808D07D, 0x0CC9CDCA, 0x7897AB07, 0x7C56B6B0, 0x71159069, 0x75D48DDE, + 0x6B93DDDB, 0x6F52C06C, 0x6211E6B5, 0x66D0FB02, 0x5E9F46BF, 0x5A5E5B08, + 0x571D7DD1, 0x53DC6066, 0x4D9B3063, 0x495A2DD4, 0x44190B0D, 0x40D816BA, + 0xACA5C697, 0xA864DB20, 0xA527FDF9, 0xA1E6E04E, 0xBFA1B04B, 0xBB60ADFC, + 0xB6238B25, 0xB2E29692, 0x8AAD2B2F, 0x8E6C3698, 0x832F1041, 0x87EE0DF6, + 0x99A95DF3, 0x9D684044, 0x902B669D, 0x94EA7B2A, 0xE0B41DE7, 0xE4750050, + 0xE9362689, 0xEDF73B3E, 0xF3B06B3B, 0xF771768C, 0xFA325055, 0xFEF34DE2, + 0xC6BCF05F, 0xC27DEDE8, 0xCF3ECB31, 0xCBFFD686, 0xD5B88683, 0xD1799B34, + 0xDC3ABDED, 0xD8FBA05A, 0x690CE0EE, 0x6DCDFD59, 0x608EDB80, 0x644FC637, + 0x7A089632, 0x7EC98B85, 0x738AAD5C, 0x774BB0EB, 0x4F040D56, 0x4BC510E1, + 0x46863638, 0x42472B8F, 0x5C007B8A, 0x58C1663D, 0x558240E4, 0x51435D53, + 0x251D3B9E, 0x21DC2629, 0x2C9F00F0, 0x285E1D47, 0x36194D42, 0x32D850F5, + 0x3F9B762C, 0x3B5A6B9B, 0x0315D626, 0x07D4CB91, 0x0A97ED48, 0x0E56F0FF, + 0x1011A0FA, 0x14D0BD4D, 0x19939B94, 0x1D528623, 0xF12F560E, 0xF5EE4BB9, + 0xF8AD6D60, 0xFC6C70D7, 0xE22B20D2, 0xE6EA3D65, 0xEBA91BBC, 0xEF68060B, + 0xD727BBB6, 0xD3E6A601, 0xDEA580D8, 0xDA649D6F, 0xC423CD6A, 0xC0E2D0DD, + 0xCDA1F604, 0xC960EBB3, 0xBD3E8D7E, 0xB9FF90C9, 0xB4BCB610, 0xB07DABA7, + 0xAE3AFBA2, 0xAAFBE615, 0xA7B8C0CC, 0xA379DD7B, 0x9B3660C6, 0x9FF77D71, + 0x92B45BA8, 0x9675461F, 0x8832161A, 0x8CF30BAD, 0x81B02D74, 0x857130C3, + 0x5D8A9099, 0x594B8D2E, 0x5408ABF7, 0x50C9B640, 0x4E8EE645, 0x4A4FFBF2, + 0x470CDD2B, 0x43CDC09C, 0x7B827D21, 0x7F436096, 0x7200464F, 0x76C15BF8, + 0x68860BFD, 0x6C47164A, 0x61043093, 0x65C52D24, 0x119B4BE9, 0x155A565E, + 0x18197087, 0x1CD86D30, 0x029F3D35, 0x065E2082, 0x0B1D065B, 0x0FDC1BEC, + 0x3793A651, 0x3352BBE6, 0x3E119D3F, 0x3AD08088, 0x2497D08D, 0x2056CD3A, + 0x2D15EBE3, 0x29D4F654, 0xC5A92679, 0xC1683BCE, 0xCC2B1D17, 0xC8EA00A0, + 0xD6AD50A5, 0xD26C4D12, 0xDF2F6BCB, 0xDBEE767C, 0xE3A1CBC1, 0xE760D676, + 0xEA23F0AF, 0xEEE2ED18, 0xF0A5BD1D, 0xF464A0AA, 0xF9278673, 0xFDE69BC4, + 0x89B8FD09, 0x8D79E0BE, 0x803AC667, 0x84FBDBD0, 0x9ABC8BD5, 0x9E7D9662, + 0x933EB0BB, 0x97FFAD0C, 
0xAFB010B1, 0xAB710D06, 0xA6322BDF, 0xA2F33668, + 0xBCB4666D, 0xB8757BDA, 0xB5365D03, 0xB1F740B4}; + + static inline unsigned int CRC32Software(const unsigned char *pData, unsigned short Length) + { + unsigned int nReg; + unsigned int nTemp = 0; + unsigned short i, n; + + nReg = 0xFFFFFFFF; + for (n = 0; n < Length; n++) + { + nReg ^= (unsigned int)pData[n]; + + for (i = 0; i < 4; i++) + { + nTemp = Crc32Table[(unsigned char)((nReg >> 24) & 0xff)]; + nReg <<= 8; + nReg ^= nTemp; + } + } + return nReg; + } + + static inline unsigned char CheckSum(unsigned char *pData, unsigned short Lenght) + { + unsigned short temp = 0; + unsigned short i = 0; + for (i = 0; i < Lenght; i++) + { + temp += pData[i]; + } + return temp & 0XFF; + } + +} // namespace name + +#endif \ No newline at end of file diff --git a/gimbal_ctrl/driver/src/G1/g1_gimbal_driver.cpp b/gimbal_ctrl/driver/src/G1/g1_gimbal_driver.cpp new file mode 100755 index 0000000..1ef00aa --- /dev/null +++ b/gimbal_ctrl/driver/src/G1/g1_gimbal_driver.cpp @@ -0,0 +1,245 @@ +/* + * @Description: + * @Author: L LC @amov + * @Date: 2022-10-27 18:10:06 + * @LastEditors: L LC @amov + * @LastEditTime: 2023-04-18 10:12:46 + * @FilePath: /gimbal-sdk-multi-platform/src/G1/g1_gimbal_driver.cpp + */ +#include "g1_gimbal_driver.h" +#include "g1_gimbal_crc32.h" +#include "string.h" + +/** + * The function creates a new instance of the g1GimbalDriver class, which is a subclass of the + * IamovGimbalBase class + * + * @param _IO The IOStreamBase object that will be used to communicate with the gimbal. + */ +g1GimbalDriver::g1GimbalDriver(amovGimbal::IOStreamBase *_IO) : amovGimbal::IamovGimbalBase(_IO) +{ + memset(&rxQueue, 0, sizeof(RING_FIFO_CB_T)); + memset(&txQueue, 0, sizeof(RING_FIFO_CB_T)); + + rxBuffer = (uint8_t *)malloc(MAX_QUEUE_SIZE * sizeof(G1::GIMBAL_FRAME_T)); + if (rxBuffer == NULL) + { + std::cout << "Receive buffer creation failed! Size : " << MAX_QUEUE_SIZE << std::endl; + exit(1); + } + txBuffer = (uint8_t *)malloc(MAX_QUEUE_SIZE * sizeof(G1::GIMBAL_FRAME_T)); + if (txBuffer == NULL) + { + free(rxBuffer); + std::cout << "Send buffer creation failed! Size : " << MAX_QUEUE_SIZE << std::endl; + exit(1); + } + + Ring_Fifo_init(&rxQueue, sizeof(G1::GIMBAL_FRAME_T), rxBuffer, MAX_QUEUE_SIZE * sizeof(G1::GIMBAL_FRAME_T)); + Ring_Fifo_init(&txQueue, sizeof(G1::GIMBAL_FRAME_T), txBuffer, MAX_QUEUE_SIZE * sizeof(G1::GIMBAL_FRAME_T)); + + parserState = G1::GIMBAL_SERIAL_STATE_IDLE; +} + +/** + * The function takes a command, a pointer to a payload, and the size of the payload. It then copies + * the payload into the tx buffer, calculates the checksum, and then calculates the CRC32 of the + * payload. It then copies the CRC32 into the tx buffer, and then copies the tx buffer into the txQueue + * + * @param uint32_t 4 bytes + * @param pPayload pointer to the data to be sent + * @param payloadSize the size of the payload + * + * @return The size of the data to be sent. 
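For reference, a one-byte G1 payload framed this way occupies 10 bytes on the wire: a 5-byte header (head 0xAE, version 0x01, length, command, and a checksum over the three bytes starting at the version field), the payload itself, and a trailing CRC32 over the payload in host byte order. A standalone sketch of the same layout, outside the driver class:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

#include "g1_gimbal_crc32.h"   // G1::CheckSum, G1::CRC32Software

// Frame a G1 command into `out` and return the number of bytes written.
// Mirrors g1GimbalDriver::pack(); `out` must hold payloadSize + 9 bytes.
static size_t frameG1(uint8_t cmd, const uint8_t *payload, uint8_t payloadSize,
                      uint8_t *out)
{
    out[0] = 0xAE;                        // G1_SERIAL_HEAD
    out[1] = 0x01;                        // G1_SERIAL_VERSION
    out[2] = payloadSize;                 // length
    out[3] = cmd;
    out[4] = G1::CheckSum(&out[1], 3);    // checksum over version..cmd
    std::memcpy(&out[5], payload, payloadSize);
    uint32_t crc = G1::CRC32Software(&out[5], payloadSize);
    std::memcpy(&out[5 + payloadSize], &crc, sizeof(crc));
    return 5 + payloadSize + sizeof(crc); // G1_PAYLOAD_OFFSET + length + CRC32
}
```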
+ */ +uint32_t g1GimbalDriver::pack(IN uint32_t cmd, uint8_t *pPayload, uint8_t payloadSize) +{ + uint32_t ret = 0; + G1::GIMBAL_FRAME_T txTemp; + + txTemp.head = G1_SERIAL_HEAD; + txTemp.version = G1_SERIAL_VERSION; + txTemp.lenght = payloadSize; + txTemp.cmd = cmd; + txTemp.checksum = G1::CheckSum((unsigned char *)&txTemp.version, 3); + memcpy(txTemp.payload, pPayload, payloadSize); + txTemp.crc.u32 = G1::CRC32Software(txTemp.payload, payloadSize); + memcpy(txTemp.payload + payloadSize, txTemp.crc.u8, sizeof(uint32_t)); + + txMutex.lock(); + if (Ring_Fifo_in_cell(&txQueue, &txTemp)) + { + ret = txTemp.lenght + G1_PAYLOAD_OFFSET + sizeof(uint32_t); + } + txMutex.unlock(); + + return ret; +} + +/** + * > This function is used to get a packet from the receive queue + * + * @param void This is the type of data that will be stored in the queue. + * + * @return A boolean value. + */ +bool g1GimbalDriver::getRxPack(OUT void *pack) +{ + bool state = false; + rxMutex.lock(); + state = Ring_Fifo_out_cell(&rxQueue, pack); + rxMutex.unlock(); + return state; +} + +void g1GimbalDriver::convert(void *buf) +{ + G1::GIMBAL_FRAME_T *temp; + temp = reinterpret_cast(buf); + switch (temp->cmd) + { + case G1::GIMBAL_CMD_RCV_POS: + G1::GIMBAL_RCV_POS_MSG_T *tempPos; + tempPos = reinterpret_cast(((uint8_t *)buf) + G1_PAYLOAD_OFFSET); + mState.lock(); + state.abs.yaw = tempPos->IMU_yaw * G1_SCALE_FACTOR; + state.abs.roll = tempPos->IMU_roll * G1_SCALE_FACTOR; + state.abs.pitch = tempPos->IMU_pitch * G1_SCALE_FACTOR; + state.rel.yaw = tempPos->HALL_yaw * G1_SCALE_FACTOR; + state.rel.roll = tempPos->HALL_roll * G1_SCALE_FACTOR; + state.rel.pitch = tempPos->HALL_pitch * G1_SCALE_FACTOR; + updateGimbalStateCallback(state.abs.roll, state.abs.pitch, state.abs.yaw, + state.rel.roll, state.rel.pitch, state.rel.yaw, + state.fov.x, state.fov.y); + mState.unlock(); + break; + + default: + std::cout << "Undefined frame from G1 : "; + for (uint16_t i = 0; i < temp->lenght + G1_PAYLOAD_OFFSET + sizeof(uint32_t); i++) + { + printf("%02X ", ((uint8_t *)buf)[i]); + } + std::cout << std::endl; + break; + } +} + +/** + * The function is called by the main thread to send a command to the gimbal. + * + * The function first checks to see if the serial port is busy and if it is open. If it is not busy and + * it is open, the function locks the txMutex and then checks to see if there is a command in the + * txQueue. If there is a command in the txQueue, the function copies the command to the tx buffer and + * then unlocks the txMutex. The function then sends the command to the gimbal. + * + * The txQueue is a ring buffer that holds commands that are waiting to be sent to the gimbal. The + * txQueue is a ring buffer because the gimbal can only process one command at a time. If the gimbal is + * busy processing a command, the command will be placed in the txQueue and sent to the gimbal when the + * gimbal is ready to receive the command. + */ +void g1GimbalDriver::send(void) +{ + if (!IO->isBusy() && IO->isOpen()) + { + bool state = false; + txMutex.lock(); + state = Ring_Fifo_out_cell(&txQueue, &tx); + txMutex.unlock(); + if (state) + { + IO->outPutBytes((uint8_t *)&tx, tx.lenght + G1_PAYLOAD_OFFSET + sizeof(uint32_t)); + } + } +} + +/** + * It's a state machine that parses a serial stream of bytes into a struct + * + * @param uint8_t unsigned char + * + * @return A boolean value. 
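One property of the byte-wise parsers in this patch (G1 here, and the G2 and Q10f ones later) is that the payload cursor and remaining-byte counter live in function-local static variables, so the parsing state is shared by every driver instance in the process. That is fine for one gimbal on one port, but a second instance fed from another thread would corrupt the first one's frame. A sketch of keeping that cursor per instance instead:

```cpp
#include <cstddef>
#include <cstdint>

// Per-instance payload cursor, as an alternative to the function-static
// pRx / payloadLenghte pair used inside parser().
class PayloadCursor {
public:
    void begin(uint8_t *dst, size_t expected) {
        dst_ = dst;
        remaining_ = expected;
    }
    // Stores one byte; returns true once the expected payload is complete.
    bool feed(uint8_t byte) {
        *dst_++ = byte;
        return --remaining_ == 0;
    }

private:
    uint8_t *dst_ = nullptr;
    size_t remaining_ = 0;
};
```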
+ */ +bool g1GimbalDriver::parser(IN uint8_t byte) +{ + bool state = false; + static uint8_t payloadLenghte = 0; + static uint8_t *pRx = NULL; + + switch (parserState) + { + case G1::GIMBAL_SERIAL_STATE_IDLE: + if (byte == G1_SERIAL_HEAD) + { + rx.head = byte; + parserState = G1::GIMBAL_SERIAL_STATE_VERSION; + } + break; + + case G1::GIMBAL_SERIAL_STATE_VERSION: + if (byte == G1_SERIAL_VERSION) + { + rx.version = byte; + payloadLenghte = 0; + parserState = G1::GIMBAL_SERIAL_STATE_LENGHT; + } + else + { + rx.head = 0; + parserState = G1::GIMBAL_SERIAL_STATE_IDLE; + } + break; + + case G1::GIMBAL_SERIAL_STATE_LENGHT: + payloadLenghte = byte + 4; + rx.lenght = byte; + parserState = G1::GIMBAL_SERIAL_STATE_CMD; + break; + + case G1::GIMBAL_SERIAL_STATE_CMD: + rx.cmd = byte; + parserState = G1::GIMBAL_SERIAL_STATE_CHECK; + break; + + case G1::GIMBAL_SERIAL_STATE_CHECK: + rx.checksum = byte; + if (G1::CheckSum((unsigned char *)&rx.version, 3) == byte) + { + + parserState = G1::GIMBAL_SERIAL_STATE_PAYLOAD; + pRx = rx.payload; + } + else + { + memset(&rx, 0, 5); + parserState = G1::GIMBAL_SERIAL_STATE_IDLE; + } + break; + + case G1::GIMBAL_SERIAL_STATE_PAYLOAD: + *pRx = byte; + payloadLenghte--; + pRx++; + if (payloadLenghte <= 0) + { + if (*((uint32_t *)(pRx - sizeof(uint32_t))) == G1::CRC32Software(rx.payload, rx.lenght)) + { + state = true; + rxMutex.lock(); + Ring_Fifo_in_cell(&rxQueue, &rx); + rxMutex.unlock(); + } + else + { + memset(&rx, 0, sizeof(G1::GIMBAL_FRAME_T)); + } + parserState = G1::GIMBAL_SERIAL_STATE_IDLE; + } + break; + + default: + parserState = G1::GIMBAL_SERIAL_STATE_IDLE; + break; + } + return state; +} diff --git a/gimbal_ctrl/driver/src/G1/g1_gimbal_driver.h b/gimbal_ctrl/driver/src/G1/g1_gimbal_driver.h new file mode 100755 index 0000000..19ed884 --- /dev/null +++ b/gimbal_ctrl/driver/src/G1/g1_gimbal_driver.h @@ -0,0 +1,68 @@ +/* + * @Description: G1吊舱的驱动文件 + * @Author: L LC @amov + * @Date: 2022-10-28 12:24:21 + * @LastEditors: L LC @amov + * @LastEditTime: 2023-03-13 12:29:17 + * @FilePath: \gimbal-sdk-multi-platform\src\G1\g1_gimbal_driver.h + */ +#include "../amov_gimbal.h" +#include "g1_gimbal_struct.h" +#include +#include +#include + +#ifndef __G1_DRIVER_H +#define __G1_DRIVER_H + +extern "C" +{ +#include "Ring_Fifo.h" +} + +class g1GimbalDriver : protected amovGimbal::IamovGimbalBase +{ +private: + G1::GIMBAL_CMD_PARSER_STATE_T parserState; + G1::GIMBAL_FRAME_T rx; + G1::GIMBAL_FRAME_T tx; + + std::mutex rxMutex; + uint8_t *rxBuffer; + RING_FIFO_CB_T rxQueue; + std::mutex txMutex; + uint8_t *txBuffer; + RING_FIFO_CB_T txQueue; + + bool parser(IN uint8_t byte); + void send(void); + + void convert(void *buf); + uint32_t pack(IN uint32_t cmd, uint8_t *pPayload, uint8_t payloadSize); + bool getRxPack(OUT void *pack); + +public: + // funtions + uint32_t setGimabalPos(const amovGimbal::AMOV_GIMBAL_POS_T &pos); + uint32_t setGimabalSpeed(const amovGimbal::AMOV_GIMBAL_POS_T &speed); + uint32_t setGimabalFollowSpeed(const amovGimbal::AMOV_GIMBAL_POS_T &followSpeed); + uint32_t setGimabalHome(void); + + uint32_t takePic(void); + uint32_t setVideo(const amovGimbal::AMOV_GIMBAL_VIDEO_T newState); + + // builds + static amovGimbal::IamovGimbalBase *creat(amovGimbal::IOStreamBase *_IO) + { + return new g1GimbalDriver(_IO); + } + + g1GimbalDriver(amovGimbal::IOStreamBase *_IO); + ~g1GimbalDriver() + { + free(rxBuffer); + free(txBuffer); + } +}; + +#endif diff --git a/gimbal_ctrl/driver/src/G1/g1_gimbal_funtion.cpp b/gimbal_ctrl/driver/src/G1/g1_gimbal_funtion.cpp new file mode 
100755 index 0000000..270fbb4 --- /dev/null +++ b/gimbal_ctrl/driver/src/G1/g1_gimbal_funtion.cpp @@ -0,0 +1,118 @@ +/* + * @Description: + * @Author: L LC @amov + * @Date: 2023-03-02 10:00:52 + * @LastEditors: L LC @amov + * @LastEditTime: 2023-03-17 18:29:33 + * @FilePath: \gimbal-sdk-multi-platform\src\G1\g1_gimbal_funtion.cpp + */ +#include "g1_gimbal_driver.h" +#include "g1_gimbal_crc32.h" +#include "string.h" + +/** + * It sets the gimbal position. + * + * @param pos the position of the gimbal + * + * @return The return value is the number of bytes written to the buffer. + */ +uint32_t g1GimbalDriver::setGimabalPos(const amovGimbal::AMOV_GIMBAL_POS_T &pos) +{ + G1::GIMBAL_SET_POS_MSG_T temp; + temp.mode = G1::GIMBAL_CMD_POS_MODE_ANGLE; + temp.angle_pitch = pos.pitch / G1_SCALE_FACTOR; + temp.angle_roll = pos.roll / G1_SCALE_FACTOR; + temp.angle_yaw = pos.yaw / G1_SCALE_FACTOR; + temp.speed_pitch = state.maxFollow.pitch; + temp.speed_roll = state.maxFollow.roll; + temp.speed_yaw = state.maxFollow.yaw; + return pack(G1::GIMBAL_CMD_SET_POS, reinterpret_cast(&temp), sizeof(G1::GIMBAL_SET_POS_MSG_T)); +} + +/** + * It takes a struct of type amovGimbal::AMOV_GIMBAL_POS_T and converts it to a struct of type + * G1::GIMBAL_SET_POS_MSG_T + * + * @param speed the speed of the gimbal + * + * @return The return value is the number of bytes written to the buffer. + */ +uint32_t g1GimbalDriver::setGimabalSpeed(const amovGimbal::AMOV_GIMBAL_POS_T &speed) +{ + G1::GIMBAL_SET_POS_MSG_T temp; + temp.mode = G1::GIMBAL_CMD_POS_MODE_SPEED; + temp.angle_pitch = 0; + temp.angle_roll = 0; + temp.angle_yaw = 0; + temp.speed_pitch = speed.pitch / G1_SCALE_FACTOR; + temp.speed_roll = speed.roll / G1_SCALE_FACTOR; + temp.speed_yaw = speed.yaw / G1_SCALE_FACTOR; + return pack(G1::GIMBAL_CMD_SET_POS, reinterpret_cast(&temp), sizeof(G1::GIMBAL_SET_POS_MSG_T)); +} + +/** + * This function sets the gimbal's follow speed + * + * @param followSpeed the speed of the gimbal + * + * @return The return value is the number of bytes written to the buffer. + */ +uint32_t g1GimbalDriver::setGimabalFollowSpeed(const amovGimbal::AMOV_GIMBAL_POS_T &followSpeed) +{ + state.maxFollow.pitch = followSpeed.pitch / G1_SCALE_FACTOR; + state.maxFollow.roll = followSpeed.roll / G1_SCALE_FACTOR; + state.maxFollow.yaw = followSpeed.yaw / G1_SCALE_FACTOR; + return 0; +} + +/** + * This function sets the gimbal to its home position + * + * @return The return value is the number of bytes written to the buffer. + */ +uint32_t g1GimbalDriver::setGimabalHome(void) +{ + G1::GIMBAL_SET_POS_MSG_T temp; + temp.mode = G1::GIMBAL_CMD_POS_MODE_HOME; + temp.speed_pitch = state.maxFollow.pitch; + temp.speed_roll = state.maxFollow.roll; + temp.speed_yaw = state.maxFollow.yaw; + return pack(G1::GIMBAL_CMD_SET_POS, reinterpret_cast(&temp), sizeof(G1::GIMBAL_SET_POS_MSG_T)); +} + +/** + * It takes a picture. + * + * @return The return value is the number of bytes written to the serial port. + */ +uint32_t g1GimbalDriver::takePic(void) +{ + uint8_t temp = G1::GIMBAL_CMD_CAMERA_TACK; + return pack(G1::GIMBAL_CMD_CAMERA, &temp, sizeof(uint8_t)); +} + +/** + * The function sets the video state of the gimbal + * + * @param newState The new state of the video. + * + * @return The return value is the number of bytes written to the serial port. 
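The position and speed setters above convert the SDK's floating-point degrees (and degrees per second) into the G1's int16 wire units by dividing by G1_SCALE_FACTOR (0.01f); convert() applies the inverse when attitude frames come back. The same arithmetic in isolation:

```cpp
#include <cstdint>

// With G1_SCALE_FACTOR == 0.01f, a request of 30.0 degrees becomes the raw
// value 3000; truncation toward zero matches the implicit int16_t conversion
// in setGimabalPos()/setGimabalSpeed().
inline int16_t g1DegreesToRaw(float degrees) {
    return static_cast<int16_t>(degrees / 0.01f);
}

// Inverse mapping used in convert() for the received IMU and Hall angles.
inline float g1RawToDegrees(int16_t raw) {
    return raw * 0.01f;
}
```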
+ */ +uint32_t g1GimbalDriver::setVideo(const amovGimbal::AMOV_GIMBAL_VIDEO_T newState) +{ + uint8_t temp = G1::GIMBAL_CMD_CAMERA_REC; + + mState.lock(); + if(state.video == amovGimbal::AMOV_GIMBAL_VIDEO_TAKE) + { + state.video = amovGimbal::AMOV_GIMBAL_VIDEO_OFF; + } + else + { + state.video = amovGimbal::AMOV_GIMBAL_VIDEO_TAKE; + } + mState.unlock(); + + return pack(G1::GIMBAL_CMD_CAMERA, &temp, sizeof(uint8_t)); +} diff --git a/gimbal_ctrl/driver/src/G1/g1_gimbal_struct.h b/gimbal_ctrl/driver/src/G1/g1_gimbal_struct.h new file mode 100755 index 0000000..c7e99c5 --- /dev/null +++ b/gimbal_ctrl/driver/src/G1/g1_gimbal_struct.h @@ -0,0 +1,91 @@ +/* + * @Description: + * @Author: L LC @amov + * @Date: 2022-10-27 18:10:07 + * @LastEditors: L LC @amov + * @LastEditTime: 2023-03-17 18:12:57 + * @FilePath: \gimbal-sdk-multi-platform\src\G1\g1_gimbal_struct.h + */ +#ifndef G1_GIMBAL_STRUCT_H +#define G1_GIMBAL_STRUCT_H + +#include +namespace G1 +{ +#define G1_MAX_GIMBAL_PAYLOAD 256 +#define G1_PAYLOAD_OFFSET 5 +#define G1_SCALE_FACTOR 0.01f +#define G1_SERIAL_HEAD 0XAE +#define G1_SERIAL_VERSION 0X01 + + typedef enum + { + GIMBAL_CMD_SET_POS = 0X85, + GIMBAL_CMD_CAMERA = 0X86, + GIMBAL_CMD_RCV_POS = 0X87 + } GIMBAL_CMD_T; + + typedef enum + { + GIMBAL_CMD_POS_MODE_SPEED = 1, + GIMBAL_CMD_POS_MODE_ANGLE = 2, + GIMBAL_CMD_POS_MODE_HOME = 3 + } GIMBAL_CMD_POS_MODE_T; + + typedef enum + { + GIMBAL_CMD_CAMERA_REC = 1, + GIMBAL_CMD_CAMERA_TACK = 2 + } GIMBAL_CMD_CAMERA_T; + + typedef enum + { + GIMBAL_SERIAL_STATE_IDLE, + GIMBAL_SERIAL_STATE_VERSION, + GIMBAL_SERIAL_STATE_LENGHT, + GIMBAL_SERIAL_STATE_CMD, + GIMBAL_SERIAL_STATE_CHECK, + GIMBAL_SERIAL_STATE_PAYLOAD, + } GIMBAL_CMD_PARSER_STATE_T; + +#pragma pack(1) + typedef struct + { + uint8_t head; + uint8_t version; + uint8_t lenght; + uint8_t cmd; + uint8_t checksum; + uint8_t payload[G1_MAX_GIMBAL_PAYLOAD + sizeof(uint32_t)]; + union + { + uint8_t u8[4]; + uint32_t u32; + } crc; + } GIMBAL_FRAME_T; + + typedef struct + { + uint8_t mode; + int16_t angle_roll; + int16_t angle_pitch; + int16_t angle_yaw; + int16_t speed_roll; + int16_t speed_pitch; + int16_t speed_yaw; + } GIMBAL_SET_POS_MSG_T; + + typedef struct + { + int16_t IMU_roll; + int16_t IMU_pitch; + int16_t IMU_yaw; + int16_t HALL_roll; + int16_t HALL_pitch; + int16_t HALL_yaw; + } GIMBAL_RCV_POS_MSG_T; + +#pragma pack() + +} +#endif \ No newline at end of file diff --git a/gimbal_ctrl/driver/src/G2/g2_gimbal_crc.h b/gimbal_ctrl/driver/src/G2/g2_gimbal_crc.h new file mode 100755 index 0000000..fff2e6c --- /dev/null +++ b/gimbal_ctrl/driver/src/G2/g2_gimbal_crc.h @@ -0,0 +1,166 @@ + +#ifndef __G2_GIMBAL_CHECK_H +#define __G2_GIMBAL_CHECK_H + +namespace G2 +{ +#include "stdint.h" + + const uint16_t crc16_tab[256] = { + 0x0000, 0x1021, 0x2042, 0x3063, 0x4084, 0x50a5, 0x60c6, 0x70e7, + 0x8108, 0x9129, 0xa14a, 0xb16b, 0xc18c, 0xd1ad, 0xe1ce, 0xf1ef, + 0x1231, 0x0210, 0x3273, 0x2252, 0x52b5, 0x4294, 0x72f7, 0x62d6, + 0x9339, 0x8318, 0xb37b, 0xa35a, 0xd3bd, 0xc39c, 0xf3ff, 0xe3de, + 0x2462, 0x3443, 0x0420, 0x1401, 0x64e6, 0x74c7, 0x44a4, 0x5485, + 0xa56a, 0xb54b, 0x8528, 0x9509, 0xe5ee, 0xf5cf, 0xc5ac, 0xd58d, + 0x3653, 0x2672, 0x1611, 0x0630, 0x76d7, 0x66f6, 0x5695, 0x46b4, + 0xb75b, 0xa77a, 0x9719, 0x8738, 0xf7df, 0xe7fe, 0xd79d, 0xc7bc, + 0x48c4, 0x58e5, 0x6886, 0x78a7, 0x0840, 0x1861, 0x2802, 0x3823, + 0xc9cc, 0xd9ed, 0xe98e, 0xf9af, 0x8948, 0x9969, 0xa90a, 0xb92b, + 0x5af5, 0x4ad4, 0x7ab7, 0x6a96, 0x1a71, 0x0a50, 0x3a33, 0x2a12, + 0xdbfd, 0xcbdc, 0xfbbf, 0xeb9e, 0x9b79, 0x8b58, 0xbb3b, 
0xab1a, + 0x6ca6, 0x7c87, 0x4ce4, 0x5cc5, 0x2c22, 0x3c03, 0x0c60, 0x1c41, + 0xedae, 0xfd8f, 0xcdec, 0xddcd, 0xad2a, 0xbd0b, 0x8d68, 0x9d49, + 0x7e97, 0x6eb6, 0x5ed5, 0x4ef4, 0x3e13, 0x2e32, 0x1e51, 0x0e70, + 0xff9f, 0xefbe, 0xdfdd, 0xcffc, 0xbf1b, 0xaf3a, 0x9f59, 0x8f78, + 0x9188, 0x81a9, 0xb1ca, 0xa1eb, 0xd10c, 0xc12d, 0xf14e, 0xe16f, + 0x1080, 0x00a1, 0x30c2, 0x20e3, 0x5004, 0x4025, 0x7046, 0x6067, + 0x83b9, 0x9398, 0xa3fb, 0xb3da, 0xc33d, 0xd31c, 0xe37f, 0xf35e, + 0x02b1, 0x1290, 0x22f3, 0x32d2, 0x4235, 0x5214, 0x6277, 0x7256, + 0xb5ea, 0xa5cb, 0x95a8, 0x8589, 0xf56e, 0xe54f, 0xd52c, 0xc50d, + 0x34e2, 0x24c3, 0x14a0, 0x0481, 0x7466, 0x6447, 0x5424, 0x4405, + 0xa7db, 0xb7fa, 0x8799, 0x97b8, 0xe75f, 0xf77e, 0xc71d, 0xd73c, + 0x26d3, 0x36f2, 0x0691, 0x16b0, 0x6657, 0x7676, 0x4615, 0x5634, + 0xd94c, 0xc96d, 0xf90e, 0xe92f, 0x99c8, 0x89e9, 0xb98a, 0xa9ab, + 0x5844, 0x4865, 0x7806, 0x6827, 0x18c0, 0x08e1, 0x3882, 0x28a3, + 0xcb7d, 0xdb5c, 0xeb3f, 0xfb1e, 0x8bf9, 0x9bd8, 0xabbb, 0xbb9a, + 0x4a75, 0x5a54, 0x6a37, 0x7a16, 0x0af1, 0x1ad0, 0x2ab3, 0x3a92, + 0xfd2e, 0xed0f, 0xdd6c, 0xcd4d, 0xbdaa, 0xad8b, 0x9de8, 0x8dc9, + 0x7c26, 0x6c07, 0x5c64, 0x4c45, 0x3ca2, 0x2c83, 0x1ce0, 0x0cc1, + 0xef1f, 0xff3e, 0xcf5d, 0xdf7c, 0xaf9b, 0xbfba, 0x8fd9, 0x9ff8, + 0x6e17, 0x7e36, 0x4e55, 0x5e74, 0x2e93, 0x3eb2, 0x0ed1, 0x1ef0}; + + /** + * "For each byte in the data, shift the CRC register left by 8 bits, XOR the CRC register with the CRC + * table value for the byte, and then shift the CRC register right by 8 bits." + * + * The CRC table is a 256-byte array of 16-bit values. The index into the table is the byte value. + * The value in the table is the CRC value for that byte. The CRC table is generated by the following + * function: + * + * @param data pointer to the data to be checked + * @param len the length of the data to be checked + * + * @return The CRC value. 
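checkCrc16 is used by the G2 driver both to seal outgoing frames and to validate incoming ones. A small sketch of both directions, assuming the CRC trails the protected bytes in host byte order (which is what the memcpy in g2GimbalDriver::pack() produces):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

#include "g2_gimbal_crc.h"   // G2::checkCrc16

// Append a CRC16 over the first `len` bytes of `frame`; returns the new length.
static std::size_t sealFrame(uint8_t *frame, uint32_t len) {
    uint16_t crc = G2::checkCrc16(frame, len);
    std::memcpy(frame + len, &crc, sizeof(crc));
    return len + sizeof(crc);
}

// Validate a frame whose last two bytes carry that CRC.
static bool frameOk(uint8_t *frame, uint32_t totalLen) {
    if (totalLen < sizeof(uint16_t)) {
        return false;
    }
    uint16_t stored;
    std::memcpy(&stored, frame + totalLen - sizeof(stored), sizeof(stored));
    return stored == G2::checkCrc16(frame, totalLen - sizeof(stored));
}
```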
+ * @note 16 bit CRC with polynomial x^16+x^12+x^5+1 + */ + static inline uint16_t checkCrc16(uint8_t *pData, uint32_t len) + { + uint16_t crc = 0XFFFF; + uint32_t idx = 0; + + for (idx = 0; idx < len; idx++) + { + crc = crc16_tab[((crc >> 8) ^ pData[idx]) & 0xFF] ^ (crc << 8); + } + return crc; + } + + const unsigned int Crc32Table[256] = { + 0x00000000, 0x04C11DB7, 0x09823B6E, 0x0D4326D9, 0x130476DC, 0x17C56B6B, + 0x1A864DB2, 0x1E475005, 0x2608EDB8, 0x22C9F00F, 0x2F8AD6D6, 0x2B4BCB61, + 0x350C9B64, 0x31CD86D3, 0x3C8EA00A, 0x384FBDBD, 0x4C11DB70, 0x48D0C6C7, + 0x4593E01E, 0x4152FDA9, 0x5F15ADAC, 0x5BD4B01B, 0x569796C2, 0x52568B75, + 0x6A1936C8, 0x6ED82B7F, 0x639B0DA6, 0x675A1011, 0x791D4014, 0x7DDC5DA3, + 0x709F7B7A, 0x745E66CD, 0x9823B6E0, 0x9CE2AB57, 0x91A18D8E, 0x95609039, + 0x8B27C03C, 0x8FE6DD8B, 0x82A5FB52, 0x8664E6E5, 0xBE2B5B58, 0xBAEA46EF, + 0xB7A96036, 0xB3687D81, 0xAD2F2D84, 0xA9EE3033, 0xA4AD16EA, 0xA06C0B5D, + 0xD4326D90, 0xD0F37027, 0xDDB056FE, 0xD9714B49, 0xC7361B4C, 0xC3F706FB, + 0xCEB42022, 0xCA753D95, 0xF23A8028, 0xF6FB9D9F, 0xFBB8BB46, 0xFF79A6F1, + 0xE13EF6F4, 0xE5FFEB43, 0xE8BCCD9A, 0xEC7DD02D, 0x34867077, 0x30476DC0, + 0x3D044B19, 0x39C556AE, 0x278206AB, 0x23431B1C, 0x2E003DC5, 0x2AC12072, + 0x128E9DCF, 0x164F8078, 0x1B0CA6A1, 0x1FCDBB16, 0x018AEB13, 0x054BF6A4, + 0x0808D07D, 0x0CC9CDCA, 0x7897AB07, 0x7C56B6B0, 0x71159069, 0x75D48DDE, + 0x6B93DDDB, 0x6F52C06C, 0x6211E6B5, 0x66D0FB02, 0x5E9F46BF, 0x5A5E5B08, + 0x571D7DD1, 0x53DC6066, 0x4D9B3063, 0x495A2DD4, 0x44190B0D, 0x40D816BA, + 0xACA5C697, 0xA864DB20, 0xA527FDF9, 0xA1E6E04E, 0xBFA1B04B, 0xBB60ADFC, + 0xB6238B25, 0xB2E29692, 0x8AAD2B2F, 0x8E6C3698, 0x832F1041, 0x87EE0DF6, + 0x99A95DF3, 0x9D684044, 0x902B669D, 0x94EA7B2A, 0xE0B41DE7, 0xE4750050, + 0xE9362689, 0xEDF73B3E, 0xF3B06B3B, 0xF771768C, 0xFA325055, 0xFEF34DE2, + 0xC6BCF05F, 0xC27DEDE8, 0xCF3ECB31, 0xCBFFD686, 0xD5B88683, 0xD1799B34, + 0xDC3ABDED, 0xD8FBA05A, 0x690CE0EE, 0x6DCDFD59, 0x608EDB80, 0x644FC637, + 0x7A089632, 0x7EC98B85, 0x738AAD5C, 0x774BB0EB, 0x4F040D56, 0x4BC510E1, + 0x46863638, 0x42472B8F, 0x5C007B8A, 0x58C1663D, 0x558240E4, 0x51435D53, + 0x251D3B9E, 0x21DC2629, 0x2C9F00F0, 0x285E1D47, 0x36194D42, 0x32D850F5, + 0x3F9B762C, 0x3B5A6B9B, 0x0315D626, 0x07D4CB91, 0x0A97ED48, 0x0E56F0FF, + 0x1011A0FA, 0x14D0BD4D, 0x19939B94, 0x1D528623, 0xF12F560E, 0xF5EE4BB9, + 0xF8AD6D60, 0xFC6C70D7, 0xE22B20D2, 0xE6EA3D65, 0xEBA91BBC, 0xEF68060B, + 0xD727BBB6, 0xD3E6A601, 0xDEA580D8, 0xDA649D6F, 0xC423CD6A, 0xC0E2D0DD, + 0xCDA1F604, 0xC960EBB3, 0xBD3E8D7E, 0xB9FF90C9, 0xB4BCB610, 0xB07DABA7, + 0xAE3AFBA2, 0xAAFBE615, 0xA7B8C0CC, 0xA379DD7B, 0x9B3660C6, 0x9FF77D71, + 0x92B45BA8, 0x9675461F, 0x8832161A, 0x8CF30BAD, 0x81B02D74, 0x857130C3, + 0x5D8A9099, 0x594B8D2E, 0x5408ABF7, 0x50C9B640, 0x4E8EE645, 0x4A4FFBF2, + 0x470CDD2B, 0x43CDC09C, 0x7B827D21, 0x7F436096, 0x7200464F, 0x76C15BF8, + 0x68860BFD, 0x6C47164A, 0x61043093, 0x65C52D24, 0x119B4BE9, 0x155A565E, + 0x18197087, 0x1CD86D30, 0x029F3D35, 0x065E2082, 0x0B1D065B, 0x0FDC1BEC, + 0x3793A651, 0x3352BBE6, 0x3E119D3F, 0x3AD08088, 0x2497D08D, 0x2056CD3A, + 0x2D15EBE3, 0x29D4F654, 0xC5A92679, 0xC1683BCE, 0xCC2B1D17, 0xC8EA00A0, + 0xD6AD50A5, 0xD26C4D12, 0xDF2F6BCB, 0xDBEE767C, 0xE3A1CBC1, 0xE760D676, + 0xEA23F0AF, 0xEEE2ED18, 0xF0A5BD1D, 0xF464A0AA, 0xF9278673, 0xFDE69BC4, + 0x89B8FD09, 0x8D79E0BE, 0x803AC667, 0x84FBDBD0, 0x9ABC8BD5, 0x9E7D9662, + 0x933EB0BB, 0x97FFAD0C, 0xAFB010B1, 0xAB710D06, 0xA6322BDF, 0xA2F33668, + 0xBCB4666D, 0xB8757BDA, 0xB5365D03, 0xB1F740B4}; + + /** + * For each byte in the input data, XOR 
the current CRC value with the byte, then shift the CRC value + * left 8 bits, and XOR the CRC value with the CRC table value for the byte + * + * @param pData pointer to the data to be CRC'd + * @param Length The length of the data to be CRC'd. + * + * @return The CRC32 value of the data. + */ + static inline uint32_t checkCRC32(uint8_t *pData, uint32_t Length) + { + unsigned int nReg; + unsigned int nTemp = 0; + unsigned short i, n; + + nReg = 0xFFFFFFFF; + for (n = 0; n < Length; n++) + { + nReg ^= (unsigned int)pData[n]; + + for (i = 0; i < 4; i++) + { + nTemp = Crc32Table[(unsigned char)((nReg >> 24) & 0xff)]; + nReg <<= 8; + nReg ^= nTemp; + } + } + return nReg; + } + + /** + * It takes a pointer to an array of bytes and the length of the array, and returns the sum of the + * bytes in the array + * + * @param pData The data to be calculated + * @param Lenght The length of the data to be sent. + * + * @return The sum of the bytes in the array. + */ + static inline unsigned char CheckSum(unsigned char *pData, unsigned short Lenght) + { + unsigned short temp = 0; + unsigned short i = 0; + for (i = 0; i < Lenght; i++) + { + temp += pData[i]; + } + return temp & 0XFF; + } + +} + +#endif \ No newline at end of file diff --git a/gimbal_ctrl/driver/src/G2/g2_gimbal_driver.cpp b/gimbal_ctrl/driver/src/G2/g2_gimbal_driver.cpp new file mode 100755 index 0000000..cd43ab4 --- /dev/null +++ b/gimbal_ctrl/driver/src/G2/g2_gimbal_driver.cpp @@ -0,0 +1,243 @@ +/* + * @Description: + * @Author: L LC @amov + * @Date: 2023-03-01 10:12:58 + * @LastEditors: L LC @amov + * @LastEditTime: 2023-04-11 17:33:42 + * @FilePath: /gimbal-sdk-multi-platform/src/G2/g2_gimbal_driver.cpp + */ + +#include "g2_gimbal_driver.h" +#include "g2_gimbal_crc.h" +#include "string.h" + +/** + * The function creates a new instance of the g2GimbalDriver class, which is a subclass of the + * IamovGimbalBase class + * + * @param _IO The IOStreamBase class that is used to communicate with the gimbal. + */ +g2GimbalDriver::g2GimbalDriver(amovGimbal::IOStreamBase *_IO) : amovGimbal::IamovGimbalBase(_IO) +{ + memset(&rxQueue, 0, sizeof(RING_FIFO_CB_T)); + memset(&txQueue, 0, sizeof(RING_FIFO_CB_T)); + + rxBuffer = (uint8_t *)malloc(MAX_QUEUE_SIZE * sizeof(G2::GIMBAL_FRAME_T)); + if (rxBuffer == NULL) + { + std::cout << "Receive buffer creation failed! Size : " << MAX_QUEUE_SIZE << std::endl; + exit(1); + } + txBuffer = (uint8_t *)malloc(MAX_QUEUE_SIZE * sizeof(G2::GIMBAL_FRAME_T)); + if (txBuffer == NULL) + { + free(rxBuffer); + std::cout << "Send buffer creation failed! Size : " << MAX_QUEUE_SIZE << std::endl; + exit(1); + } + + Ring_Fifo_init(&rxQueue, sizeof(G2::GIMBAL_FRAME_T), rxBuffer, MAX_QUEUE_SIZE * sizeof(G2::GIMBAL_FRAME_T)); + Ring_Fifo_init(&txQueue, sizeof(G2::GIMBAL_FRAME_T), txBuffer, MAX_QUEUE_SIZE * sizeof(G2::GIMBAL_FRAME_T)); + + parserState = G2::GIMBAL_SERIAL_STATE_IDEL; +} + +/** + * It takes a command, a pointer to a payload, and the size of the payload, and then it puts the + * command, the payload, and the CRC into a ring buffer + * + * @param uint32_t 4 bytes + * @param pPayload pointer to the data to be sent + * @param payloadSize the size of the payload in bytes + * + * @return The number of bytes in the packet. 
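Concretely, a G2 frame with an n-byte payload occupies n + 8 bytes on the wire: a 6-byte header (head 0xAF, version 0x02, target, source, length, command), the payload, and a 16-bit CRC computed over the header plus payload. An illustrative layout for a single-byte payload (field values are examples, not captured traffic):

```cpp
#include <cstdint>

// On-wire shape of a G2 frame carrying one payload byte (9 bytes total).
// Offsets mirror GIMBAL_FRAME_T in g2_gimbal_struct.h; every field is a byte,
// so no packing pragma is needed here.
struct G2WireExample {
    uint8_t head;      // 0xAF  (G2_SERIAL_HEAD)
    uint8_t version;   // 0x02  (G2_SERIAL_VERSION)
    uint8_t target;    // remote node id
    uint8_t source;    // own node id
    uint8_t len;       // 1
    uint8_t command;   // e.g. one of the IAP_COMMAND_* values
    uint8_t data[1];   // payload
    uint8_t crc[2];    // checkCrc16 over the 7 bytes above
};
static_assert(sizeof(G2WireExample) == 9, "header(6) + payload(1) + crc(2)");
```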
+ */ +uint32_t g2GimbalDriver::pack(IN uint32_t cmd, uint8_t *pPayload, uint8_t payloadSize) +{ + uint32_t ret = 0; + G2::GIMBAL_FRAME_T txTemp; + + txTemp.head = G2_SERIAL_HEAD; + txTemp.version = G2_SERIAL_VERSION; + txTemp.len = payloadSize; + txTemp.command = cmd; + txTemp.source = self; + txTemp.target = remote; + memcpy(txTemp.data, pPayload, payloadSize); + txTemp.crc.f16 = G2::checkCrc16((uint8_t *)&txTemp, txTemp.len + G2_PAYLOAD_OFFSET); + memcpy(txTemp.data + payloadSize, txTemp.crc.f8, sizeof(uint16_t)); + + txMutex.lock(); + if (Ring_Fifo_in_cell(&txQueue, &txTemp)) + { + ret = txTemp.len + G2_PAYLOAD_OFFSET + sizeof(uint16_t); + } + txMutex.unlock(); + + return ret; +} + +/** + * > This function is used to get a packet from the receive queue + * + * @param void This is the type of data that will be stored in the queue. + * + * @return A boolean value. + */ +bool g2GimbalDriver::getRxPack(OUT void *pack) +{ + bool state = false; + rxMutex.lock(); + state = Ring_Fifo_out_cell(&rxQueue, pack); + rxMutex.unlock(); + return state; +} + +/** + * The function takes a pointer to a buffer, casts it to a pointer to a G2::GIMBAL_FRAME_T, and then + * checks the command field of the frame. If the command is G2::IAP_COMMAND_BLOCK_END, it locks the + * mutex, and then unlocks it. Otherwise, it prints out the contents of the buffer + * + * @param buf pointer to the data received from the gimbal + */ +void g2GimbalDriver::convert(void *buf) +{ + G2::GIMBAL_FRAME_T *temp; + temp = reinterpret_cast(buf); + switch (temp->command) + { + case G2::IAP_COMMAND_BLOCK_END: + mState.lock(); + updateGimbalStateCallback(state.abs.roll, state.abs.pitch, state.abs.yaw, + state.rel.roll, state.rel.pitch, state.rel.yaw, + state.fov.x, state.fov.y); + mState.unlock(); + break; + + default: + std::cout << "Undefined frame from G2 : "; + for (uint16_t i = 0; i < temp->len + G2_PAYLOAD_OFFSET + sizeof(uint32_t); i++) + { + printf("%02X ", ((uint8_t *)buf)[i]); + } + std::cout << std::endl; + break; + } +} + +/** + * If the serial port is not busy and is open, then lock the txMutex, get the next byte from the + * txQueue, unlock the txMutex, and send the byte + */ +void g2GimbalDriver::send(void) +{ + if (!IO->isBusy() && IO->isOpen()) + { + bool state = false; + txMutex.lock(); + state = Ring_Fifo_out_cell(&txQueue, &tx); + txMutex.unlock(); + if (state) + { + IO->outPutBytes((uint8_t *)&tx, tx.len + G2_PAYLOAD_OFFSET + sizeof(uint16_t)); + } + } +} + +/** + * The function is called every time a byte is received from the serial port. It parses the byte and + * stores it in a buffer. When the buffer is full, it checks the CRC and if it's correct, it stores the + * buffer in a queue + * + * @param uint8_t unsigned char + * + * @return The parser function is returning a boolean value. 
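One receive-path detail worth calling out: the parser below stores the first CRC byte into crc.f8[1] and the second into crc.f8[0], so on the usual little-endian target the CRC travels high byte first. An endianness-independent way to express the same reassembly:

```cpp
#include <cstdint>

// Rebuild a CRC16 that arrives high byte first, without relying on the host
// byte order that the f8/f16 union in GIMBAL_FRAME_T implicitly assumes.
inline uint16_t crc16FromWire(uint8_t firstByte, uint8_t secondByte) {
    return static_cast<uint16_t>((firstByte << 8) | secondByte);
}
```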
+ */ +bool g2GimbalDriver::parser(IN uint8_t byte) +{ + bool state = false; + static uint8_t payloadLenghte = 0; + static uint8_t *pRx = NULL; + + switch (parserState) + { + case G2::GIMBAL_SERIAL_STATE_IDEL: + if (byte == G2_SERIAL_HEAD) + { + rx.head = byte; + parserState = G2::GIMBAL_SERIAL_STATE_HEAD_RCV; + } + break; + + case G2::GIMBAL_SERIAL_STATE_HEAD_RCV: + if (byte == G2_SERIAL_VERSION) + { + rx.version = byte; + parserState = G2::GIMBAL_SERIAL_STATE_VERSION_RCV; + } + else + { + rx.head = 0; + parserState = G2::GIMBAL_SERIAL_STATE_IDEL; + } + break; + + case G2::GIMBAL_SERIAL_STATE_VERSION_RCV: + rx.target = byte; + parserState = G2::GIMBAL_SERIAL_STATE_TARGET_RCV; + break; + + case G2::GIMBAL_SERIAL_STATE_TARGET_RCV: + rx.source = byte; + parserState = G2::GIMBAL_SERIAL_STATE_SOURCE_RCV; + break; + + case G2::GIMBAL_SERIAL_STATE_SOURCE_RCV: + rx.len = byte; + parserState = G2::GIMBAL_SERIAL_STATE_LENGHT_RCV; + pRx = rx.data; + payloadLenghte = byte; + break; + + case G2::GIMBAL_SERIAL_STATE_LENGHT_RCV: + rx.command = byte; + parserState = G2::GIMBAL_SERIAL_STATE_DATA_RCV; + break; + + case G2::GIMBAL_SERIAL_STATE_DATA_RCV: + *pRx = byte; + payloadLenghte--; + if (payloadLenghte == 0) + { + parserState = G2::GIMBAL_SERIAL_STATE_CRC_RCV1; + } + break; + + case G2::GIMBAL_SERIAL_STATE_CRC_RCV1: + rx.crc.f8[1] = byte; + parserState = G2::GIMBAL_SERIAL_STATE_END; + break; + + case G2::GIMBAL_SERIAL_STATE_END: + rx.crc.f8[0] = byte; + + if (rx.crc.f16 == G2::checkCrc16((uint8_t *)&rx, G2_PAYLOAD_OFFSET + rx.len)) + { + state = true; + rxMutex.lock(); + Ring_Fifo_in_cell(&rxQueue, &rx); + rxMutex.unlock(); + } + else + { + memset(&rx, 0, sizeof(G2::GIMBAL_FRAME_T)); + } + parserState = G2::GIMBAL_SERIAL_STATE_IDEL; + break; + + default: + parserState = G2::GIMBAL_SERIAL_STATE_IDEL; + break; + } + return state; +} \ No newline at end of file diff --git a/gimbal_ctrl/driver/src/G2/g2_gimbal_driver.h b/gimbal_ctrl/driver/src/G2/g2_gimbal_driver.h new file mode 100755 index 0000000..98cd2c3 --- /dev/null +++ b/gimbal_ctrl/driver/src/G2/g2_gimbal_driver.h @@ -0,0 +1,90 @@ +/* + * @Description: + * @Author: L LC @amov + * @Date: 2023-03-01 10:02:24 + * @LastEditors: L LC @amov + * @LastEditTime: 2023-03-13 12:29:33 + * @FilePath: \gimbal-sdk-multi-platform\src\G2\g2_gimbal_driver.h + */ +#include "../amov_gimbal.h" +#include "g2_gimbal_struct.h" +#include +#include +#include + +#ifndef __G2_DRIVER_H +#define __G2_DRIVER_H + +extern "C" +{ +#include "Ring_Fifo.h" +} + +class g2GimbalDriver : protected amovGimbal::IamovGimbalBase +{ +private: + G2::GIMBAL_CMD_PARSER_STATE_T parserState; + G2::GIMBAL_FRAME_T rx; + G2::GIMBAL_FRAME_T tx; + + std::mutex rxMutex; + uint8_t *rxBuffer; + RING_FIFO_CB_T rxQueue; + std::mutex txMutex; + uint8_t *txBuffer; + RING_FIFO_CB_T txQueue; + + uint8_t self; + uint8_t remote; + + bool parser(IN uint8_t byte); + void send(void); + + void convert(void *buf); + uint32_t pack(IN uint32_t cmd, uint8_t *pPayload, uint8_t payloadSize); + bool getRxPack(OUT void *pack); + +public: + void nodeSet(SET uint32_t _self, SET uint32_t _remote) + { + self = _self; + remote = _remote; + } + + // funtion + uint32_t setGimabalPos(const amovGimbal::AMOV_GIMBAL_POS_T &pos); + uint32_t setGimabalSpeed(const amovGimbal::AMOV_GIMBAL_POS_T &speed); + uint32_t setGimabalFollowSpeed(const amovGimbal::AMOV_GIMBAL_POS_T &followSpeed); + uint32_t setGimabalHome(void); + + uint32_t takePic(void); + uint32_t setVideo(const amovGimbal::AMOV_GIMBAL_VIDEO_T newState); + +#ifdef AMOV_HOST + // 
iap funtion (内部源码模式提供功能 lib模式下不可见) + bool iapGetSoftInfo(std::string &info); + bool iapGetHardInfo(std::string &info); + bool iapJump(G2::GIMBAL_IAP_STATE_T &state); + bool iapFlashErase(G2::GIMBAL_IAP_STATE_T &state); + bool iapSendBlockInfo(uint32_t &startAddr, uint32_t &crc32); + bool iapSendBlockData(uint8_t offset, uint8_t *data); + bool iapFlashWrite(uint32_t &crc32, G2::GIMBAL_IAP_STATE_T &state); + + // 判断是否需要跳转 + bool iapJumpCheck(std::string &info) { return true; } +#endif + + static amovGimbal::IamovGimbalBase *creat(amovGimbal::IOStreamBase *_IO) + { + return new g2GimbalDriver(_IO); + } + + g2GimbalDriver(amovGimbal::IOStreamBase *_IO); + ~g2GimbalDriver() + { + free(rxBuffer); + free(txBuffer); + } +}; + +#endif diff --git a/gimbal_ctrl/driver/src/G2/g2_gimbal_funtion.cpp b/gimbal_ctrl/driver/src/G2/g2_gimbal_funtion.cpp new file mode 100644 index 0000000..e7b0ca9 --- /dev/null +++ b/gimbal_ctrl/driver/src/G2/g2_gimbal_funtion.cpp @@ -0,0 +1,81 @@ +/* + * @Description: + * @Author: L LC @amov + * @Date: 2023-03-13 11:58:54 + * @LastEditors: L LC @amov + * @LastEditTime: 2023-03-13 12:31:58 + * @FilePath: \gimbal-sdk-multi-platform\src\G2\g2_gimbal_funtion.cpp + */ +#include "g2_gimbal_driver.h" +#include "g2_gimbal_crc.h" +#include "string.h" + +/** + * It sets the gimbal position. + * + * @param pos the position of the gimbal + * + * @return The return value is the number of bytes written to the buffer. + */ +uint32_t g2GimbalDriver::setGimabalPos(const amovGimbal::AMOV_GIMBAL_POS_T &pos) +{ + return 0; +} + +/** + * It takes a struct of type amovGimbal::AMOV_GIMBAL_POS_T and converts it to a struct of type + * G1::GIMBAL_SET_POS_MSG_T + * + * @param speed the speed of the gimbal + * + * @return The return value is the number of bytes written to the buffer. + */ +uint32_t g2GimbalDriver::setGimabalSpeed(const amovGimbal::AMOV_GIMBAL_POS_T &speed) +{ + return 0; +} + +/** + * This function sets the gimbal's follow speed + * + * @param followSpeed the speed of the gimbal + * + * @return The return value is the number of bytes written to the buffer. + */ +uint32_t g2GimbalDriver::setGimabalFollowSpeed(const amovGimbal::AMOV_GIMBAL_POS_T &followSpeed) +{ + return 0; +} + +/** + * This function sets the gimbal to its home position + * + * @return The return value is the number of bytes written to the buffer. + */ +uint32_t g2GimbalDriver::setGimabalHome(void) +{ + return 0; +} + +/** + * It takes a picture. + * + * @return The return value is the number of bytes written to the serial port. + */ +uint32_t g2GimbalDriver::takePic(void) +{ + return 0; +} + +/** + * The function sets the video state of the gimbal + * + * @param newState The new state of the video. + * + * @return The return value is the number of bytes written to the serial port. 
+ */ +uint32_t g2GimbalDriver::setVideo(const amovGimbal::AMOV_GIMBAL_VIDEO_T newState) +{ + + return 0; +} diff --git a/gimbal_ctrl/driver/src/G2/g2_gimbal_iap_funtion.cpp b/gimbal_ctrl/driver/src/G2/g2_gimbal_iap_funtion.cpp new file mode 100755 index 0000000..8024c5f --- /dev/null +++ b/gimbal_ctrl/driver/src/G2/g2_gimbal_iap_funtion.cpp @@ -0,0 +1,357 @@ +/* + * @Description: + * @Author: L LC @amov + * @Date: 2023-03-02 11:16:52 + * @LastEditors: L LC @amov + * @LastEditTime: 2023-04-18 10:13:08 + * @FilePath: /gimbal-sdk-multi-platform/src/G2/g2_gimbal_iap_funtion.cpp + */ + +#ifdef AMOV_HOST + +#include "g2_gimbal_driver.h" +#include "g2_gimbal_crc.h" +#include "string.h" + +#include + +#define MAX_WAIT_TIME_MS 2000 + +/** + * It gets the software information from the gimbal. + * + * @param info the string to store the information + * + * @return a boolean value. + */ +bool g2GimbalDriver::iapGetSoftInfo(std::string &info) +{ + uint8_t temp = 0; + bool ret = false; + G2::GIMBAL_FRAME_T ack; + + pack(G2::IAP_COMMAND_SOFT_INFO, &temp, 1); + + std::chrono::milliseconds startMs = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()); + + while (1) + { + if (getRxPack(&ack)) + { + if (ack.command == G2::IAP_COMMAND_SOFT_INFO && + ack.target == self && + ack.source == remote) + { + info = (char *)ack.data; + std::cout << info << std::endl; + ret = true; + break; + } + } + + std::chrono::milliseconds nowMs = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()); + if ((nowMs - startMs) > std::chrono::milliseconds(MAX_WAIT_TIME_MS)) + { + break; + } + } + return ret; +} + +/** + * It gets the hardware information of the gimbal. + * + * @param info the string to store the hardware information + * + * @return a boolean value. + */ +bool g2GimbalDriver::iapGetHardInfo(std::string &info) +{ + uint8_t temp = 0; + bool ret = false; + G2::GIMBAL_FRAME_T ack; + + pack(G2::IAP_COMMAND_HARDWARE_INFO, &temp, 1); + + std::chrono::milliseconds startMs = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()); + + while (1) + { + if (getRxPack(&ack)) + { + if (ack.command == G2::IAP_COMMAND_HARDWARE_INFO && + ack.target == self && + ack.source == remote) + { + info = (char *)ack.data; + std::cout << info << std::endl; + ret = true; + break; + } + } + + std::chrono::milliseconds nowMs = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()); + if ((nowMs - startMs) > std::chrono::milliseconds(MAX_WAIT_TIME_MS)) + { + break; + } + } + return ret; +} + +/** + * It sends a command to the gimbal to jump to the bootloader, and then waits for a response from the + * bootloader + * + * @param state the state of the gimbal, 0: normal, 1: iap + * + * @return The return value is a boolean. + */ +bool g2GimbalDriver::iapJump(G2::GIMBAL_IAP_STATE_T &state) +{ + uint8_t temp = 0; + bool ret = true; + G2::GIMBAL_FRAME_T ack; + + pack(G2::IAP_COMMAND_JUMP, &temp, 1); + + std::chrono::milliseconds startMs = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()); + + // It fails if the specified message is received. 
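Every IAP call in this file repeats the pattern that starts with the loop below: send a request with pack(), then poll getRxPack() until a frame with the matching command/target/source arrives or MAX_WAIT_TIME_MS elapses. A generic sketch of that wait (the callable names are placeholders for getRxPack and the per-command field checks; the real loops busy-poll, so the short sleep is an optional refinement):

```cpp
#include <chrono>
#include <thread>

// Poll `tryReceive` until `matches` accepts a frame or the timeout expires.
template <typename Frame, typename TryReceive, typename Matches>
bool waitForAck(TryReceive tryReceive, Matches matches, Frame &out,
                std::chrono::milliseconds timeout)
{
    const auto deadline = std::chrono::steady_clock::now() + timeout;
    while (std::chrono::steady_clock::now() < deadline) {
        if (tryReceive(out) && matches(out)) {
            return true;
        }
        std::this_thread::sleep_for(std::chrono::milliseconds(1));
    }
    return false;
}
```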
+ while (1) + { + if (getRxPack(&ack)) + { + if (ack.command == G2::IAP_COMMAND_JUMP && + ack.target == self && + ack.source == remote) + { + state = (G2::GIMBAL_IAP_STATE_T)ack.data[1]; + ret = false; + break; + } + } + + std::chrono::milliseconds nowMs = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()); + if ((nowMs - startMs) > std::chrono::milliseconds(MAX_WAIT_TIME_MS)) + { + break; + } + } + return ret; +} + +/** + * The function sends a command to the gimbal to erase the flash memory + * + * @param state The state of the IAP process. + * + * @return The return value is a boolean. + */ +bool g2GimbalDriver::iapFlashErase(G2::GIMBAL_IAP_STATE_T &state) +{ + uint8_t temp = 0; + bool ret = false; + G2::GIMBAL_FRAME_T ack; + + pack(G2::IAP_COMMAND_FLASH_ERASE, &temp, 1); + + std::chrono::milliseconds startMs = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()); + + while (1) + { + if (getRxPack(&ack)) + { + if (ack.command == G2::IAP_COMMAND_FLASH_ERASE && + ack.target == self && + ack.source == remote) + { + state = (G2::GIMBAL_IAP_STATE_T)ack.data[1]; + ret = true; + break; + } + } + + std::chrono::milliseconds nowMs = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()); + if ((nowMs - startMs) > std::chrono::milliseconds(MAX_WAIT_TIME_MS)) + { + break; + } + } + return ret; +} + +/** + * It sends a block of data to the gimbal, and waits for an acknowledgement + * + * @param startAddr the start address of the block to be sent + * @param crc32 The CRC32 of the data to be sent. + * + * @return a boolean value. + */ +bool g2GimbalDriver::iapSendBlockInfo(uint32_t &startAddr, uint32_t &crc32) +{ + union + { + uint32_t f32; + uint8_t f8[4]; + } temp; + uint8_t buf[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + bool ret = false; + G2::GIMBAL_FRAME_T ack; + + temp.f32 = startAddr; + memcpy(buf, temp.f8, sizeof(uint32_t)); + temp.f32 = crc32; + memcpy(buf, temp.f8, sizeof(uint32_t)); + + pack(G2::IAP_COMMAND_BOLCK_INFO, buf, 8); + + std::chrono::milliseconds startMs = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()); + + while (1) + { + if (getRxPack(&ack)) + { + if (ack.command == G2::IAP_COMMAND_BOLCK_INFO && + ack.target == self && + ack.source == remote) + { + ret = true; + for (uint8_t i = 0; i < 8; i++) + { + if (buf[i] != ack.data[i]) + { + ret = false; + } + } + break; + } + } + + std::chrono::milliseconds nowMs = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()); + if ((nowMs - startMs) > std::chrono::milliseconds(MAX_WAIT_TIME_MS)) + { + break; + } + } + return ret; +} + +/** + * It sends a block of data to the gimbal, and waits for an acknowledgement + * + * @param offset the offset of the data block in the file + * @param data pointer to the data to be sent + * + * @return The return value is a boolean. 
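Separately, one thing to double-check in iapSendBlockInfo above: both memcpy calls write to offset 0 of buf, so the CRC32 overwrites the start address before the 8-byte block-info payload is sent. If the intent is start address in bytes 0..3 and CRC32 in bytes 4..7, as the doc comment and the 8-byte echo comparison suggest, the packing would look like:

```cpp
#include <cstdint>
#include <cstring>

// Pack the block info as [startAddr | crc32] in host byte order, matching the
// union/memcpy style used elsewhere in this file.
static void packBlockInfo(uint32_t startAddr, uint32_t crc32, uint8_t out[8])
{
    std::memcpy(out, &startAddr, sizeof(startAddr));
    std::memcpy(out + sizeof(startAddr), &crc32, sizeof(crc32));
}
```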
+ */ +bool g2GimbalDriver::iapSendBlockData(uint8_t offset, uint8_t *data) +{ + bool ret = false; + G2::GIMBAL_FRAME_T ack; + + pack(G2::IAP_COMMAND_BLOCK_START + offset, data, 64); + + std::chrono::milliseconds startMs = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()); + + while (1) + { + if (getRxPack(&ack)) + { + if (ack.command == G2::IAP_COMMAND_BLOCK_START + offset && + ack.target == self && + ack.source == remote) + { + ret = true; + for (uint8_t i = 0; i < 64; i++) + { + if (data[i] != ack.data[i]) + { + ret = false; + } + } + break; + } + } + + std::chrono::milliseconds nowMs = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()); + if ((nowMs - startMs) > std::chrono::milliseconds(MAX_WAIT_TIME_MS)) + { + break; + } + } + return ret; +} + +/** + * The function sends a block of data to the gimbal, and waits for an acknowledgement + * + * @param crc32 The CRC32 of the data block + * @param state The state of the IAP process. + * + * @return The return value is a boolean. + */ +bool g2GimbalDriver::iapFlashWrite(uint32_t &crc32, G2::GIMBAL_IAP_STATE_T &state) +{ + bool ret = false; + G2::GIMBAL_FRAME_T ack; + + union + { + uint32_t f32; + uint8_t f8[4]; + } temp; + + temp.f32 = crc32; + + pack(G2::IAP_COMMAND_BLOCK_WRITE, temp.f8, sizeof(uint32_t)); + + std::chrono::milliseconds startMs = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()); + + while (1) + { + if (getRxPack(&ack)) + { + if (ack.command == G2::IAP_COMMAND_BLOCK_WRITE && + ack.target == self && + ack.source == remote) + { + state = (G2::GIMBAL_IAP_STATE_T)ack.data[4]; + ret = true; + for (uint8_t i = 0; i < 4; i++) + { + if (temp.f8[i] != ack.data[i]) + { + ret = false; + } + } + break; + } + } + + std::chrono::milliseconds nowMs = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()); + if ((nowMs - startMs) > std::chrono::milliseconds(MAX_WAIT_TIME_MS)) + { + break; + } + } + return ret; +} + +#endif + diff --git a/gimbal_ctrl/driver/src/G2/g2_gimbal_struct.h b/gimbal_ctrl/driver/src/G2/g2_gimbal_struct.h new file mode 100755 index 0000000..221d0dc --- /dev/null +++ b/gimbal_ctrl/driver/src/G2/g2_gimbal_struct.h @@ -0,0 +1,81 @@ +/* + * @Description: + * @Author: L LC @amov + * @Date: 2023-03-01 09:21:57 + * @LastEditors: L LC @amov + * @LastEditTime: 2023-04-18 10:13:23 + * @FilePath: /gimbal-sdk-multi-platform/src/G2/g2_gimbal_struct.h + */ + +#ifndef G2_GIMBAL_STRUCT_H +#define G2_GIMBAL_STRUCT_H + +#include + +namespace G2 +{ + +#define G2_MAX_GIMBAL_PAYLOAD 64 +#define G2_PAYLOAD_OFFSET 6 +#define G2_SCALE_FACTOR 0.01f +#define G2_SERIAL_HEAD 0XAF +#define G2_SERIAL_VERSION 0X02 + + typedef enum + { + IAP_COMMAND_JUMP = 80, + IAP_COMMAND_FLASH_ERASE, + IAP_COMMAND_BOLCK_INFO, + IAP_COMMAND_BLOCK_WRITE, + IAP_COMMAND_SOFT_INFO, + IAP_COMMAND_HARDWARE_INFO, + IAP_COMMAND_BLOCK_START, + IAP_COMMAND_BLOCK_END = 117, + } GIMBAL_CMD_T; + + typedef enum + { + IAP_STATE_FAILD = 0, + IAP_STATE_SUCCEED, + IAP_STATE_READY, + IAP_STATE_FIRMWARE_BROKE, + IAP_STATE_JUMP_FAILD, + IAP_STATE_ADDR_ERR, + IAP_STATE_CRC_ERR, + IAP_STATE_WRITE_ERR, + IAP_STATE_WRITE_TIMEOUT, + } GIMBAL_IAP_STATE_T; + + typedef enum + { + GIMBAL_SERIAL_STATE_IDEL = 0, + GIMBAL_SERIAL_STATE_HEAD_RCV, + GIMBAL_SERIAL_STATE_VERSION_RCV, + GIMBAL_SERIAL_STATE_TARGET_RCV, + GIMBAL_SERIAL_STATE_SOURCE_RCV, + GIMBAL_SERIAL_STATE_LENGHT_RCV, + GIMBAL_SERIAL_STATE_DATA_RCV, + GIMBAL_SERIAL_STATE_CRC_RCV1, + GIMBAL_SERIAL_STATE_END, + } 
GIMBAL_CMD_PARSER_STATE_T; + +#pragma pack(1) + typedef struct + { + uint8_t head; + uint8_t version; + uint8_t target; + uint8_t source; + uint8_t len; + uint8_t command; + uint8_t data[G2_MAX_GIMBAL_PAYLOAD]; + union + { + uint8_t f8[2]; + uint16_t f16; + } crc; + } GIMBAL_FRAME_T; +#pragma pack(0) + +} +#endif \ No newline at end of file diff --git a/gimbal_ctrl/driver/src/Q10f/Q10f_gimbal_crc32.h b/gimbal_ctrl/driver/src/Q10f/Q10f_gimbal_crc32.h new file mode 100755 index 0000000..709b2d8 --- /dev/null +++ b/gimbal_ctrl/driver/src/Q10f/Q10f_gimbal_crc32.h @@ -0,0 +1,27 @@ +/* + * @Description: + * @Author: L LC @amov + * @Date: 2022-10-27 18:10:06 + * @LastEditors: L LC @amov + * @LastEditTime: 2023-03-23 17:24:23 + * @FilePath: /gimbal-sdk-multi-platform/src/Q10f/Q10f_gimbal_crc32.h + */ +#ifndef Q10F_GIMBAL_CRC32_H +#define Q10F_GIMBAL_CRC32_H + +namespace Q10f +{ + static inline unsigned char CheckSum(unsigned char *pData, unsigned short Lenght) + { + unsigned short temp = 0; + unsigned short i = 0; + for (i = 0; i < Lenght; i++) + { + temp += pData[i]; + } + return temp & 0XFF; + } + +} // namespace name + +#endif \ No newline at end of file diff --git a/gimbal_ctrl/driver/src/Q10f/Q10f_gimbal_driver.cpp b/gimbal_ctrl/driver/src/Q10f/Q10f_gimbal_driver.cpp new file mode 100755 index 0000000..2b098d1 --- /dev/null +++ b/gimbal_ctrl/driver/src/Q10f/Q10f_gimbal_driver.cpp @@ -0,0 +1,258 @@ +/* + * @Description: + * @Author: L LC @amov + * @Date: 2022-10-27 18:10:06 + * @LastEditors: L LC @amov + * @LastEditTime: 2023-04-11 17:29:58 + * @FilePath: /gimbal-sdk-multi-platform/src/Q10f/Q10f_gimbal_driver.cpp + */ +#include "Q10f_gimbal_driver.h" +#include "Q10f_gimbal_crc32.h" +#include "string.h" + +/** + * The function creates a new instance of the g1GimbalDriver class, which is a subclass of the + * IamovGimbalBase class + * + * @param _IO The IOStreamBase object that will be used to communicate with the gimbal. + */ +Q10fGimbalDriver::Q10fGimbalDriver(amovGimbal::IOStreamBase *_IO) : amovGimbal::IamovGimbalBase(_IO) +{ + memset(&rxQueue, 0, sizeof(RING_FIFO_CB_T)); + memset(&txQueue, 0, sizeof(RING_FIFO_CB_T)); + + rxBuffer = (uint8_t *)malloc(MAX_QUEUE_SIZE * sizeof(Q10f::GIMBAL_FRAME_T)); + if (rxBuffer == NULL) + { + std::cout << "Receive buffer creation failed! Size : " << MAX_QUEUE_SIZE << std::endl; + exit(1); + } + txBuffer = (uint8_t *)malloc(MAX_QUEUE_SIZE * sizeof(Q10f::GIMBAL_FRAME_T)); + if (txBuffer == NULL) + { + free(rxBuffer); + std::cout << "Send buffer creation failed! Size : " << MAX_QUEUE_SIZE << std::endl; + exit(1); + } + + Ring_Fifo_init(&rxQueue, sizeof(Q10f::GIMBAL_FRAME_T), rxBuffer, MAX_QUEUE_SIZE * sizeof(Q10f::GIMBAL_FRAME_T)); + Ring_Fifo_init(&txQueue, sizeof(Q10f::GIMBAL_FRAME_T), txBuffer, MAX_QUEUE_SIZE * sizeof(Q10f::GIMBAL_FRAME_T)); + + parserState = Q10f::GIMBAL_SERIAL_STATE_IDLE; + + // Initialize and enable attitude data return (50Hz) + uint8_t cmd = 0XFF; + pack(Q10f::GIMBAL_CMD_SET_FEEDBACK_H, &cmd, 1); + pack(Q10f::GIMBAL_CMD_SET_FEEDBACK_L, &cmd, 1); + cmd = 0X00; + pack(Q10f::GIMBAL_CMD_OPEN_FEEDBACK, &cmd, 1); +} + +/** + * The function takes a command, a pointer to a payload, and the size of the payload. It then copies + * the payload into the tx buffer, calculates the checksum, and then calculates the CRC32 of the + * payload. 
It then copies the CRC32 into the tx buffer, and then copies the tx buffer into the txQueue + * + * @param uint32_t 4 bytes + * @param pPayload pointer to the data to be sent + * @param payloadSize the size of the payload + * + * @return The size of the data to be sent. + */ +uint32_t Q10fGimbalDriver::pack(IN uint32_t cmd, uint8_t *pPayload, uint8_t payloadSize) +{ + uint32_t ret = 0; + Q10f::GIMBAL_FRAME_T txTemp; + + txTemp.head = cmd; + memcpy(txTemp.data, pPayload, payloadSize); + + if (cmd != Q10f::GIMBAL_CMD_SET_POS) + { + payloadSize--; + } + else + { + txTemp.data[payloadSize] = Q10f::CheckSum(pPayload, payloadSize); + } + txTemp.len = payloadSize; + + txMutex.lock(); + if (Ring_Fifo_in_cell(&txQueue, &txTemp)) + { + ret = payloadSize + sizeof(uint32_t) + sizeof(uint8_t); + } + txMutex.unlock(); + + return ret; +} + +/** + * > This function is used to get a packet from the receive queue + * + * @param void This is the type of data that will be stored in the queue. + * + * @return A boolean value. + */ +bool Q10fGimbalDriver::getRxPack(OUT void *pack) +{ + bool state = false; + rxMutex.lock(); + state = Ring_Fifo_out_cell(&rxQueue, pack); + rxMutex.unlock(); + return state; +} + +void Q10fGimbalDriver::convert(void *buf) +{ + Q10f::GIMBAL_FRAME_T *temp; + temp = reinterpret_cast(buf); + switch (temp->head) + { + case Q10f::GIMBAL_CMD_RCV_STATE: + Q10f::GIMBAL_RCV_POS_MSG_T *tempPos; + tempPos = reinterpret_cast(((uint8_t *)buf) + Q10F_PAYLOAD_OFFSET); + mState.lock(); + state.abs.yaw = tempPos->yawIMUAngle * Q10F_SCALE_FACTOR_ANGLE; + state.abs.roll = tempPos->rollIMUAngle * Q10F_SCALE_FACTOR_ANGLE; + state.abs.pitch = tempPos->pitchIMUAngle * Q10F_SCALE_FACTOR_ANGLE; + state.rel.yaw = tempPos->rollStatorRotorAngle * Q10F_SCALE_FACTOR_SPEED; + state.rel.roll = tempPos->rollStatorRotorAngle * Q10F_SCALE_FACTOR_SPEED; + state.rel.pitch = tempPos->pitchStatorRotorAngle * Q10F_SCALE_FACTOR_SPEED; + updateGimbalStateCallback(state.abs.roll, state.abs.pitch, state.abs.yaw, + state.rel.roll, state.rel.pitch, state.rel.yaw, + state.fov.x, state.fov.y); + mState.unlock(); + + break; + default: + std::cout << "Undefined frame from Q10f : "; + for (uint16_t i = 0; i < temp->len + Q10F_PAYLOAD_OFFSET; i++) + { + printf("%02X ", ((uint8_t *)buf)[i]); + } + std::cout << std::endl; + break; + } +} + +/** + * The function is called by the main thread to send a command to the gimbal. + * + * The function first checks to see if the serial port is busy and if it is open. If it is not busy and + * it is open, the function locks the txMutex and then checks to see if there is a command in the + * txQueue. If there is a command in the txQueue, the function copies the command to the tx buffer and + * then unlocks the txMutex. The function then sends the command to the gimbal. + * + * The txQueue is a ring buffer that holds commands that are waiting to be sent to the gimbal. The + * txQueue is a ring buffer because the gimbal can only process one command at a time. If the gimbal is + * busy processing a command, the command will be placed in the txQueue and sent to the gimbal when the + * gimbal is ready to receive the command. 
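+ *
+ * Illustrative producer/consumer sketch (internal driver flow, not public API;
+ * payload/len are placeholder names):
+ *
+ *   pack(Q10f::GIMBAL_CMD_SET_POS, payload, len);  // producer: frame is queued into txQueue
+ *   send();                                        // consumer: one frame is popped and written via IO->outPutBytes()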
+ */ +void Q10fGimbalDriver::send(void) +{ + if (!IO->isBusy() && IO->isOpen()) + { + bool state = false; + txMutex.lock(); + state = Ring_Fifo_out_cell(&txQueue, &tx); + txMutex.unlock(); + if (state) + { + IO->outPutBytes((uint8_t *)&tx, tx.len + Q10F_PAYLOAD_OFFSET + sizeof(uint8_t)); + } + } +} + +/** + * It's a state machine that parses a serial stream of bytes into a struct + * + * @param uint8_t unsigned char + * + * @return A boolean value. + */ +bool Q10fGimbalDriver::parser(IN uint8_t byte) +{ + bool state = false; + static uint8_t payloadLenghte = 0; + static uint8_t *pRx = NULL; + uint8_t suncheck; + + switch (parserState) + { + case Q10f::GIMBAL_SERIAL_STATE_IDLE: + if (byte == ((Q10f::GIMBAL_CMD_RCV_STATE & 0X000000FF) >> 0)) + { + parserState = Q10f::GIMBAL_SERIAL_STATE_HEAD1; + } + break; + + case Q10f::GIMBAL_SERIAL_STATE_HEAD1: + if (byte == ((Q10f::GIMBAL_CMD_RCV_STATE & 0X0000FF00) >> 8)) + { + parserState = Q10f::GIMBAL_SERIAL_STATE_HEAD2; + } + else + { + parserState = Q10f::GIMBAL_SERIAL_STATE_IDLE; + } + break; + + case Q10f::GIMBAL_SERIAL_STATE_HEAD2: + if (byte == ((Q10f::GIMBAL_CMD_RCV_STATE & 0X00FF0000) >> 16)) + { + parserState = Q10f::GIMBAL_SERIAL_STATE_HEAD3; + } + else + { + parserState = Q10f::GIMBAL_SERIAL_STATE_IDLE; + } + break; + + case Q10f::GIMBAL_SERIAL_STATE_HEAD3: + if (byte == ((Q10f::GIMBAL_CMD_RCV_STATE & 0XFF000000) >> 24)) + { + parserState = Q10f::GIMBAL_SERIAL_STATE_DATE; + payloadLenghte = sizeof(Q10f::GIMBAL_RCV_POS_MSG_T); + pRx = rx.data; + rx.head = Q10f::GIMBAL_CMD_RCV_STATE; + } + else + { + parserState = Q10f::GIMBAL_SERIAL_STATE_IDLE; + } + break; + + case Q10f::GIMBAL_SERIAL_STATE_DATE: + *pRx = byte; + payloadLenghte--; + pRx++; + if (payloadLenghte == 0) + { + parserState = Q10f::GIMBAL_SERIAL_STATE_CHECK; + } + break; + + case Q10f::GIMBAL_SERIAL_STATE_CHECK: + suncheck = Q10f::CheckSum(rx.data, sizeof(Q10f::GIMBAL_RCV_POS_MSG_T)); + if (byte == suncheck) + { + state = true; + rxMutex.lock(); + Ring_Fifo_in_cell(&rxQueue, &rx); + rxMutex.unlock(); + } + else + { + memset(&rx, 0, sizeof(Q10f::GIMBAL_FRAME_T)); + } + parserState = Q10f::GIMBAL_SERIAL_STATE_IDLE; + break; + + default: + parserState = Q10f::GIMBAL_SERIAL_STATE_IDLE; + break; + } + + return state; +} diff --git a/gimbal_ctrl/driver/src/Q10f/Q10f_gimbal_driver.h b/gimbal_ctrl/driver/src/Q10f/Q10f_gimbal_driver.h new file mode 100755 index 0000000..0def5a1 --- /dev/null +++ b/gimbal_ctrl/driver/src/Q10f/Q10f_gimbal_driver.h @@ -0,0 +1,71 @@ +/* + * @Description: Q10f吊舱的驱动文件 + * @Author: L LC @amov + * @Date: 2022-10-28 12:24:21 + * @LastEditors: L LC @amov + * @LastEditTime: 2023-03-28 17:01:00 + * @FilePath: /gimbal-sdk-multi-platform/src/Q10f/Q10f_gimbal_driver.h + */ +#include "../amov_gimbal.h" +#include "Q10f_gimbal_struct.h" +#include +#include +#include + +#ifndef __Q10F_DRIVER_H +#define __Q10F_DRIVER_H + +extern "C" +{ +#include "Ring_Fifo.h" +} + +class Q10fGimbalDriver : protected amovGimbal::IamovGimbalBase +{ +private: + Q10f::GIMBAL_SERIAL_STATE_T parserState; + Q10f::GIMBAL_FRAME_T rx; + Q10f::GIMBAL_FRAME_T tx; + + std::mutex rxMutex; + uint8_t *rxBuffer; + RING_FIFO_CB_T rxQueue; + std::mutex txMutex; + uint8_t *txBuffer; + RING_FIFO_CB_T txQueue; + + bool parser(IN uint8_t byte); + void send(void); + + void convert(void *buf); + uint32_t pack(IN uint32_t cmd, uint8_t *pPayload, uint8_t payloadSize); + bool getRxPack(OUT void *pack); + +public: + // funtions + uint32_t setGimabalPos(const amovGimbal::AMOV_GIMBAL_POS_T &pos); + uint32_t 
setGimabalSpeed(const amovGimbal::AMOV_GIMBAL_POS_T &speed); + uint32_t setGimabalFollowSpeed(const amovGimbal::AMOV_GIMBAL_POS_T &followSpeed); + uint32_t setGimabalHome(void); + + uint32_t setGimbalZoom(amovGimbal::AMOV_GIMBAL_ZOOM_T zoom, float targetRate = 0); + uint32_t setGimbalFocus(amovGimbal::AMOV_GIMBAL_ZOOM_T zoom, float targetRate = 0); + + uint32_t takePic(void); + uint32_t setVideo(const amovGimbal::AMOV_GIMBAL_VIDEO_T newState); + + // builds + static amovGimbal::IamovGimbalBase *creat(amovGimbal::IOStreamBase *_IO) + { + return new Q10fGimbalDriver(_IO); + } + + Q10fGimbalDriver(amovGimbal::IOStreamBase *_IO); + ~Q10fGimbalDriver() + { + free(rxBuffer); + free(txBuffer); + } +}; + +#endif diff --git a/gimbal_ctrl/driver/src/Q10f/Q10f_gimbal_funtion.cpp b/gimbal_ctrl/driver/src/Q10f/Q10f_gimbal_funtion.cpp new file mode 100755 index 0000000..dff1e8d --- /dev/null +++ b/gimbal_ctrl/driver/src/Q10f/Q10f_gimbal_funtion.cpp @@ -0,0 +1,180 @@ +/* + * @Description: + * @Author: L LC @amov + * @Date: 2023-03-02 10:00:52 + * @LastEditors: L LC @amov + * @LastEditTime: 2023-03-29 11:47:18 + * @FilePath: /gimbal-sdk-multi-platform/src/Q10f/Q10f_gimbal_funtion.cpp + */ +#include "Q10f_gimbal_driver.h" +#include "Q10f_gimbal_crc32.h" +#include "string.h" + +/** + * It sets the gimbal position. + * + * @param pos the position of the gimbal + * + * @return The return value is the number of bytes written to the buffer. + */ +uint32_t Q10fGimbalDriver::setGimabalPos(const amovGimbal::AMOV_GIMBAL_POS_T &pos) +{ + Q10f::GIMBAL_SET_POS_MSG_T temp; + temp.modeR = Q10f::GIMBAL_CMD_POS_MODE_ANGLE_SPEED; + temp.modeP = Q10f::GIMBAL_CMD_POS_MODE_ANGLE_SPEED; + temp.modeY = Q10f::GIMBAL_CMD_POS_MODE_ANGLE_SPEED; + + temp.angleP = pos.pitch / Q10F_SCALE_FACTOR_ANGLE; + temp.angleR = pos.roll / Q10F_SCALE_FACTOR_ANGLE; + temp.angleY = pos.yaw / Q10F_SCALE_FACTOR_ANGLE; + temp.speedP = state.maxFollow.pitch; + temp.speedR = state.maxFollow.roll; + temp.speedY = state.maxFollow.yaw; + return pack(Q10f::GIMBAL_CMD_SET_POS, reinterpret_cast(&temp), sizeof(Q10f::GIMBAL_SET_POS_MSG_T)); +} + +/** + * It takes a struct of type amovGimbal::AMOV_GIMBAL_POS_T and converts it to a struct of type + * G1::GIMBAL_SET_POS_MSG_T + * + * @param speed the speed of the gimbal + * + * @return The return value is the number of bytes written to the buffer. + */ +uint32_t Q10fGimbalDriver::setGimabalSpeed(const amovGimbal::AMOV_GIMBAL_POS_T &speed) +{ + Q10f::GIMBAL_SET_POS_MSG_T temp; + temp.modeR = Q10f::GIMBAL_CMD_POS_MODE_SPEED; + temp.modeP = Q10f::GIMBAL_CMD_POS_MODE_SPEED; + temp.modeY = Q10f::GIMBAL_CMD_POS_MODE_SPEED; + + temp.angleP = 0; + temp.angleR = 0; + temp.angleY = 0; + temp.speedP = speed.pitch / 0.1220740379f; + temp.speedR = speed.roll / 0.1220740379f; + temp.speedY = speed.yaw / 0.1220740379f; + return pack(Q10f::GIMBAL_CMD_SET_POS, reinterpret_cast(&temp), sizeof(Q10f::GIMBAL_SET_POS_MSG_T)); +} + +/** + * This function sets the gimbal's follow speed + * + * @param followSpeed the speed of the gimbal + * + * @return The return value is the number of bytes written to the buffer. 
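+ *
+ * Note: the converted rates are only cached in state.maxFollow and are applied as the
+ * speed fields of the next setGimabalPos() frame; this call itself writes nothing to
+ * the serial port and always returns 0.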
+ */ +uint32_t Q10fGimbalDriver::setGimabalFollowSpeed(const amovGimbal::AMOV_GIMBAL_POS_T &followSpeed) +{ + state.maxFollow.pitch = followSpeed.pitch / 0.1220740379f; + state.maxFollow.roll = followSpeed.roll / 0.1220740379f; + state.maxFollow.yaw = followSpeed.yaw / 0.1220740379f; + return 0; +} + +/** + * This function sets the gimbal to its home position + * + * @return The return value is the number of bytes written to the buffer. + */ +uint32_t Q10fGimbalDriver::setGimabalHome(void) +{ + // amovGimbal::AMOV_GIMBAL_POS_T home; + // home.pitch = 0; + // home.roll = 0; + // home.yaw = 0; + // return setGimabalPos(home); + + const static uint8_t cmd[5] = {0X00, 0X00, 0X03, 0X03, 0XFF}; + return pack(Q10f::GIMBAL_CMD_HOME, (uint8_t *)cmd, sizeof(cmd)); +} + +/** + * It takes a picture. + * + * @return The return value is the number of bytes written to the serial port. + */ +uint32_t Q10fGimbalDriver::takePic(void) +{ + const static uint8_t cmd[2] = {0X01, 0XFF}; + + return pack(Q10f::GIMBAL_CMD_CAMERA, (uint8_t *)cmd, sizeof(cmd)); +} + +/** + * The function sets the video state of the gimbal + * + * @param newState The new state of the video. + * + * @return The return value is the number of bytes written to the serial port. + */ +uint32_t Q10fGimbalDriver::setVideo(const amovGimbal::AMOV_GIMBAL_VIDEO_T newState) +{ + uint8_t cmd[2] = {0X01, 0XFF}; + + if (newState == amovGimbal::AMOV_GIMBAL_VIDEO_TAKE) + { + cmd[0] = 0X02; + state.video = amovGimbal::AMOV_GIMBAL_VIDEO_TAKE; + } + else + { + cmd[0] = 0X03; + state.video = amovGimbal::AMOV_GIMBAL_VIDEO_OFF; + } + + return pack(Q10f::GIMBAL_CMD_CAMERA, (uint8_t *)cmd, sizeof(cmd)); +} + +uint32_t Q10fGimbalDriver::setGimbalZoom(amovGimbal::AMOV_GIMBAL_ZOOM_T zoom, float targetRate) +{ + uint8_t cmd[5] = {0X00, 0X00, 0X00, 0X00, 0XFF}; + if (targetRate == 0.0f) + { + cmd[1] = 0XFF; + switch (zoom) + { + case amovGimbal::AMOV_GIMBAL_ZOOM_IN: + cmd[0] = Q10f::GIMBAL_CMD_ZOOM_IN; + break; + case amovGimbal::AMOV_GIMBAL_ZOOM_OUT: + cmd[0] = Q10f::GIMBAL_CMD_ZOOM_OUT; + break; + case amovGimbal::AMOV_GIMBAL_ZOOM_STOP: + cmd[0] = Q10f::GIMBAL_CMD_ZOOM_STOP; + break; + default: + break; + } + return pack(Q10f::GIMBAL_CMD_ZOOM, (uint8_t *)cmd, 2); + } + else + { + uint16_t count = (targetRate / Q10F_MAX_ZOOM) * Q10F_MAX_ZOOM_COUNT; + cmd[0] = count & 0XF000 >> 12; + cmd[1] = count & 0X0F00 >> 8; + cmd[2] = count & 0X00F0 >> 4; + cmd[3] = count & 0X000F >> 0; + return pack(Q10f::GIMBAL_CMD_ZOOM_DIRECT, (uint8_t *)cmd, 5); + } +} + +uint32_t Q10fGimbalDriver::setGimbalFocus(amovGimbal::AMOV_GIMBAL_ZOOM_T zoom, float targetRate) +{ + uint8_t cmd[2] = {0X00, 0XFF}; + switch (zoom) + { + case amovGimbal::AMOV_GIMBAL_ZOOM_IN: + cmd[0] = Q10f::GIMBAL_CMD_ZOOM_IN; + break; + case amovGimbal::AMOV_GIMBAL_ZOOM_OUT: + cmd[0] = Q10f::GIMBAL_CMD_ZOOM_OUT; + break; + case amovGimbal::AMOV_GIMBAL_ZOOM_STOP: + cmd[0] = Q10f::GIMBAL_CMD_ZOOM_STOP; + break; + default: + break; + } + return pack(Q10f::GIMBAL_CMD_FOCUS, (uint8_t *)cmd, 2); +} diff --git a/gimbal_ctrl/driver/src/Q10f/Q10f_gimbal_struct.h b/gimbal_ctrl/driver/src/Q10f/Q10f_gimbal_struct.h new file mode 100755 index 0000000..c5061d8 --- /dev/null +++ b/gimbal_ctrl/driver/src/Q10f/Q10f_gimbal_struct.h @@ -0,0 +1,105 @@ +/* + * @Description: + * @Author: L LC @amov + * @Date: 2022-10-27 18:10:07 + * @LastEditors: L LC @amov + * @LastEditTime: 2023-03-28 18:15:47 + * @FilePath: /gimbal-sdk-multi-platform/src/Q10f/Q10f_gimbal_struct.h + */ +#ifndef Q10F_GIMBAL_STRUCT_H +#define Q10F_GIMBAL_STRUCT_H + 
+#include +namespace Q10f +{ +#define Q10F_MAX_GIMBAL_PAYLOAD 64 +#define Q10F_PAYLOAD_OFFSET 4 +#define Q10F_SCALE_FACTOR_ANGLE 0.02197f +#define Q10F_SCALE_FACTOR_SPEED 0.06103f +#define Q10F_MAX_ZOOM 10.0f +#define Q10F_MAX_ZOOM_COUNT 0X4000 + + typedef enum + { + GIMBAL_CMD_SET_POS = 0X100F01FF, + GIMBAL_CMD_GET = 0X3D003D3E, + GIMBAL_CMD_FOCUS = 0X08040181, + GIMBAL_CMD_ZOOM = 0X07040181, + GIMBAL_CMD_ZOOM_DIRECT = 0X47040181, + GIMBAL_CMD_HOME = 0X010A0181, + GIMBAL_CMD_CAMERA = 0X68040181, + GIMBAL_CMD_RCV_STATE = 0X721A583E, + GIMBAL_CMD_SET_FEEDBACK_L = 0X143055AA, + GIMBAL_CMD_SET_FEEDBACK_H = 0X003155AA, + GIMBAL_CMD_OPEN_FEEDBACK =0X3E003E3E, + GIMBAL_CMD_CLOSE_FEEDBACK =0X3D003D3E, + } GIMBAL_CMD_T; + + typedef enum + { + GIMBAL_CMD_POS_MODE_NO = 0X00, + GIMBAL_CMD_POS_MODE_SPEED = 0X01, + GIMBAL_CMD_POS_MODE_ANGLE = 0X02, + GIMBAL_CMD_POS_MODE_ANGLE_SPEED = 0X03, + } GIMBAL_CMD_POS_MODE_T; + + typedef enum + { + GIMBAL_CMD_ZOOM_IN = 0X27, + GIMBAL_CMD_ZOOM_OUT = 0X37, + GIMBAL_CMD_ZOOM_STOP = 0X00, + } GIMBAL_CMD_ZOOM_T; + + typedef enum + { + GIMBAL_SERIAL_STATE_IDLE, + GIMBAL_SERIAL_STATE_HEAD1, + GIMBAL_SERIAL_STATE_HEAD2, + GIMBAL_SERIAL_STATE_HEAD3, + GIMBAL_SERIAL_STATE_DATE, + GIMBAL_SERIAL_STATE_CHECK + } GIMBAL_SERIAL_STATE_T; + +#pragma pack(1) + typedef struct + { + uint32_t head; + uint8_t data[Q10F_MAX_GIMBAL_PAYLOAD]; + uint8_t checkSum; + uint8_t len; + } GIMBAL_FRAME_T; + + typedef struct + { + uint8_t modeR; + uint8_t modeP; + uint8_t modeY; + int16_t speedR; + int16_t angleR; + int16_t speedP; + int16_t angleP; + int16_t speedY; + int16_t angleY; + } GIMBAL_SET_POS_MSG_T; + + typedef struct + { + uint16_t timeStamp; + int16_t rollIMUAngle; + int16_t pitchIMUAngle; + int16_t yawIMUAngle; + int16_t rollTAGAngle; + int16_t pitchTAGAngle; + int16_t yawTAGAngle; + int16_t rollTAGSpeed; + int16_t pitchTAGSpeed; + int16_t yawTAGSpeed; + int16_t rollStatorRotorAngle; + int16_t pitchStatorRotorAngle; + int16_t yawStatorRotorAngle; + } GIMBAL_RCV_POS_MSG_T; + +#pragma pack() + +} +#endif \ No newline at end of file diff --git a/gimbal_ctrl/driver/src/amov_gimabl.cpp b/gimbal_ctrl/driver/src/amov_gimabl.cpp new file mode 100755 index 0000000..cc0d7f3 --- /dev/null +++ b/gimbal_ctrl/driver/src/amov_gimabl.cpp @@ -0,0 +1,239 @@ +/* + * @Description: + * @Author: L LC @amov + * @Date: 2022-10-28 11:54:11 + * @LastEditors: L LC @amov + * @LastEditTime: 2023-04-11 18:13:25 + * @FilePath: /gimbal-sdk-multi-platform/src/amov_gimabl.cpp + */ + +#include "amov_gimbal.h" +#include "g1_gimbal_driver.h" +#include "g2_gimbal_driver.h" +#include "Q10f_gimbal_driver.h" + +#include +#include +#include +#include + +#define MAX_PACK_SIZE 280 +typedef enum +{ + AMOV_GIMBAL_TYPE_NULL, + AMOV_GIMBAL_TYPE_G1 = 1, + AMOV_GIMBAL_TYPE_G2, + AMOV_GIMBAL_TYPE_Q10, +} AMOV_GIMBAL_TYPE_T; + +namespace amovGimbal +{ + typedef amovGimbal::IamovGimbalBase *(*createCallback)(amovGimbal::IOStreamBase *_IO); + typedef std::map callbackMap; + std::map amovGimbalTypeList = + { + {"G1", AMOV_GIMBAL_TYPE_G1}, + {"G2", AMOV_GIMBAL_TYPE_G2}, + {"Q10f", AMOV_GIMBAL_TYPE_Q10}}; + + callbackMap amovGimbals = + { + {"G1", g1GimbalDriver::creat}, + {"G2", g2GimbalDriver::creat}, + {"Q10f", Q10fGimbalDriver::creat}}; +} + +/* The amovGimbalCreator class is a factory class that creates an instance of the amovGimbal class */ +// Factory used to create the gimbal instance +class amovGimbalCreator +{ +public: + static amovGimbal::IamovGimbalBase *createAmovGimbal(const std::string &type, amovGimbal::IOStreamBase *_IO) 
+ { + amovGimbal::callbackMap::iterator temp = amovGimbal::amovGimbals.find(type); + + if (temp != amovGimbal::amovGimbals.end()) + { + return (temp->second)(_IO); + } + std::cout << type << " is Unsupported device type!" << std::endl; + return NULL; + } + +private: + amovGimbalCreator() + { + } + static amovGimbalCreator *pInstance; + static amovGimbalCreator *getInstance() + { + if (pInstance == NULL) + { + pInstance = new amovGimbalCreator(); + } + return pInstance; + } + + ~amovGimbalCreator(); +}; + +/** + * "If the input byte is available, then parse it." + * + * The function is a loop that runs forever. It calls the IO->inPutByte() function to get a byte from + * the serial port. If the byte is available, then it calls the parser() function to parse the byte + */ +void amovGimbal::IamovGimbalBase::parserLoop(void) +{ + uint8_t temp; + + while (1) + { + if (IO->inPutByte(&temp)) + { + parser(temp); + } + } +} + +void amovGimbal::IamovGimbalBase::sendLoop(void) +{ + while (1) + { + send(); + } +} + +void amovGimbal::IamovGimbalBase::mainLoop(void) +{ + uint8_t tempBuffer[MAX_PACK_SIZE]; + + while (1) + { + if (getRxPack(tempBuffer)) + { + convert(tempBuffer); + } + } +} + +/** + * It starts two threads, one for reading data from the serial port and one for sending data to the + * serial port + */ +void amovGimbal::IamovGimbalBase::startStack(void) +{ + if (!IO->isOpen()) + { + IO->open(); + } + + std::thread mainLoop(&IamovGimbalBase::parserLoop, this); + std::thread sendLoop(&IamovGimbalBase::sendLoop, this); + mainLoop.detach(); + sendLoop.detach(); +} + +/** + * The function creates a thread that runs the mainLoop function + */ +void amovGimbal::IamovGimbalBase::parserAuto(pStateInvoke callback) +{ + this->updateGimbalStateCallback = callback; + std::thread mainLoop(&IamovGimbalBase::mainLoop, this); + mainLoop.detach(); +} + +amovGimbal::AMOV_GIMBAL_STATE_T amovGimbal::IamovGimbalBase::getGimabalState(void) +{ + mState.lock(); + AMOV_GIMBAL_STATE_T temp = state; + mState.unlock(); + return temp; +} + +amovGimbal::IamovGimbalBase::~IamovGimbalBase() +{ + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + IO->close(); +} + +/** + * Default implementation of interface functions, not pure virtual functions for ease of extension. 
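+ *
+ * Drivers that do not support a given feature simply inherit these stubs, which do
+ * nothing and return 0 wherever a byte count is expected.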
+ */ +void amovGimbal::IamovGimbalBase::nodeSet(SET uint32_t _self, SET uint32_t _remote) +{ + return; +} + +uint32_t amovGimbal::IamovGimbalBase::setGimabalPos(const amovGimbal::AMOV_GIMBAL_POS_T &pos) +{ + return 0; +} + +uint32_t amovGimbal::IamovGimbalBase::setGimabalSpeed(const amovGimbal::AMOV_GIMBAL_POS_T &speed) +{ + return 0; +} + +uint32_t amovGimbal::IamovGimbalBase::setGimabalFollowSpeed(const amovGimbal::AMOV_GIMBAL_POS_T &followSpeed) +{ + return 0; +} + +uint32_t amovGimbal::IamovGimbalBase::setGimabalHome(void) +{ + return 0; +} + +uint32_t amovGimbal::IamovGimbalBase::setGimbalZoom(amovGimbal::AMOV_GIMBAL_ZOOM_T zoom, float targetRate) +{ + return 0; +} + +uint32_t amovGimbal::IamovGimbalBase::setGimbalFocus(amovGimbal::AMOV_GIMBAL_ZOOM_T zoom, float targetRate) +{ + return 0; +} + +uint32_t amovGimbal::IamovGimbalBase::setGimbalROI(const amovGimbal::AMOV_GIMBAL_ROI_T area) +{ + return 0; +} + +uint32_t amovGimbal::IamovGimbalBase::takePic(void) +{ + return 0; +} + +uint32_t amovGimbal::IamovGimbalBase::setVideo(const amovGimbal::AMOV_GIMBAL_VIDEO_T newState) +{ + return 0; +} + +/** + * The function creates a new gimbal object, which is a pointer to a new amovGimbal object, which is a + * pointer to a new Gimbal object, which is a pointer to a new IOStreamBase object + * + * @param type the type of the device, which is the same as the name of the class + * @param _IO The IOStreamBase object that is used to communicate with the device. + * @param _self the node ID of the device + * @param _remote the node ID of the remote device + */ +amovGimbal::gimbal::gimbal(const std::string &type, IOStreamBase *_IO, + uint32_t _self, uint32_t _remote) +{ + typeName = type; + IO = _IO; + + dev = amovGimbalCreator::createAmovGimbal(typeName, IO); + + dev->nodeSet(_self, _remote); +} + +amovGimbal::gimbal::~gimbal() +{ + // 先干掉请求线程 + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + delete dev; +} diff --git a/gimbal_ctrl/driver/src/amov_gimbal.h b/gimbal_ctrl/driver/src/amov_gimbal.h new file mode 100755 index 0000000..52a42f1 --- /dev/null +++ b/gimbal_ctrl/driver/src/amov_gimbal.h @@ -0,0 +1,118 @@ +/* + * @Description: External interface of amov gimbals + * @Author: L LC @amov + * @Date: 2022-10-27 18:34:26 + * @LastEditors: L LC @amov + * @LastEditTime: 2023-04-18 11:42:05 + * @FilePath: /spirecv-gimbal-sdk/gimbal_ctrl/driver/src/amov_gimbal.h + */ + +#ifndef AMOV_GIMBAL_H +#define AMOV_GIMBAL_H + +#include +#include +#include + +#include +#include +#include + +#include "amov_gimbal_struct.h" + +#define MAX_QUEUE_SIZE 50 + +namespace amovGimbal +{ +#define IN +#define OUT +#define SET + + static inline void idleCallback(double &frameAngleRoll, double &frameAnglePitch, double &frameAngleYaw, + double &imuAngleRoll, double &imuAnglePitch, double &imuAngleYaw, + double &fovX, double &fovY) + { + } + + // Control data input and output + class IOStreamBase + { + public: + IOStreamBase() {} + virtual ~IOStreamBase() {} + + virtual bool open() = 0; + virtual bool close() = 0; + virtual bool isOpen() = 0; + virtual bool isBusy() = 0; + // These two functions need to be thread-safe + virtual bool inPutByte(IN uint8_t *byte) = 0; + virtual uint32_t outPutBytes(IN uint8_t *byte, uint32_t lenght) = 0; + }; + + class IamovGimbalBase + { + protected: + AMOV_GIMBAL_STATE_T state; + std::mutex mState; + IOStreamBase *IO; + pStateInvoke updateGimbalStateCallback; + + virtual bool parser(IN uint8_t byte) = 0; + virtual void send(void) = 0; + virtual void convert(void *buf) = 0; + virtual 
uint32_t pack(IN uint32_t cmd, uint8_t *pPayload, uint8_t payloadSize) = 0; + virtual bool getRxPack(OUT void *pack) = 0; + + void parserLoop(void); + void sendLoop(void); + void mainLoop(void); + + public: + IamovGimbalBase(SET IOStreamBase *_IO) + { + IO = _IO; + } + virtual ~IamovGimbalBase(); + + void setParserCallback(pStateInvoke callback) + { + this->updateGimbalStateCallback = callback; + } + + // Protocol stack function items + virtual void startStack(void); + virtual void parserAuto(pStateInvoke callback = idleCallback); + virtual void nodeSet(SET uint32_t _self, SET uint32_t _remote); + + // functions + virtual AMOV_GIMBAL_STATE_T getGimabalState(void); + virtual uint32_t setGimabalPos(const AMOV_GIMBAL_POS_T &pos); + virtual uint32_t setGimabalSpeed(const AMOV_GIMBAL_POS_T &speed); + virtual uint32_t setGimabalFollowSpeed(const AMOV_GIMBAL_POS_T &followSpeed); + virtual uint32_t setGimabalHome(void); + virtual uint32_t setGimbalZoom(AMOV_GIMBAL_ZOOM_T zoom, float targetRate = 0); + virtual uint32_t setGimbalFocus(AMOV_GIMBAL_ZOOM_T zoom, float targetRate = 0); + virtual uint32_t setGimbalROI(const AMOV_GIMBAL_ROI_T area); + virtual uint32_t takePic(void); + virtual uint32_t setVideo(const AMOV_GIMBAL_VIDEO_T newState); + }; + + class gimbal + { + private: + std::string typeName; + IOStreamBase *IO; + + public: + IamovGimbalBase *dev; + std::string name() + { + return typeName; + } + gimbal(const std::string &type, IOStreamBase *_IO, + uint32_t _self = 0x02, uint32_t _remote = 0X80); + ~gimbal(); + }; +} +#endif \ No newline at end of file diff --git a/gimbal_ctrl/driver/src/amov_gimbal_struct.h b/gimbal_ctrl/driver/src/amov_gimbal_struct.h new file mode 100755 index 0000000..f7bebae --- /dev/null +++ b/gimbal_ctrl/driver/src/amov_gimbal_struct.h @@ -0,0 +1,74 @@ +/* + * @Description: Common Data Structures of gimbal + * @Author: L LC @amov + * @Date: 2022-10-31 11:56:43 + * @LastEditors: L LC @amov + * @LastEditTime: 2023-04-18 10:12:33 + * @FilePath: /gimbal-sdk-multi-platform/src/amov_gimbal_struct.h + */ + +#include + +#ifndef __AMOV_GIMABL_STRUCT_H +#define __AMOV_GIMABL_STRUCT_H + +namespace amovGimbal +{ + typedef void (*pStateInvoke)(double &frameAngleRoll, double &frameAnglePitch, double &frameAngleYaw, + double &imuAngleRoll, double &imuAnglePitch, double &imuAngleYaw, + double &fovX, double &fovY); + + typedef enum + { + AMOV_GIMBAL_MODE_LOCK, + AMOV_GIMBAL_MODE_NULOCK, + } AMOV_GIMBAL_MODE_T; + + typedef enum + { + AMOV_GIMBAL_VIDEO_TAKE, + AMOV_GIMBAL_VIDEO_OFF + } AMOV_GIMBAL_VIDEO_T; + + typedef enum + { + AMOV_GIMBAL_ZOOM_IN, + AMOV_GIMBAL_ZOOM_OUT, + AMOV_GIMBAL_ZOOM_STOP + } AMOV_GIMBAL_ZOOM_T; + + typedef struct + { + double yaw; + double roll; + double pitch; + } AMOV_GIMBAL_POS_T; + + typedef struct + { + double x; + double y; + }AMOV_GIMBAL_FOV_T; + + + typedef struct + { + AMOV_GIMBAL_MODE_T workMode; + AMOV_GIMBAL_VIDEO_T video; + AMOV_GIMBAL_POS_T abs; + AMOV_GIMBAL_POS_T rel; + AMOV_GIMBAL_POS_T maxFollow; + AMOV_GIMBAL_FOV_T fov; + } AMOV_GIMBAL_STATE_T; + + typedef struct + { + uint32_t centreX; + uint32_t centreY; + uint32_t hight; + uint32_t width; + } AMOV_GIMBAL_ROI_T; + +} // namespace amovGimbal + +#endif diff --git a/gimbal_ctrl/sv_gimbal.cpp b/gimbal_ctrl/sv_gimbal.cpp new file mode 100644 index 0000000..233ef41 --- /dev/null +++ b/gimbal_ctrl/sv_gimbal.cpp @@ -0,0 +1,411 @@ +/* + * @Description: + * @Author: L LC @amov + * @Date: 2023-04-12 09:12:52 + * @LastEditors: L LC @amov + * @LastEditTime: 2023-04-18 11:37:42 + * @FilePath: 
/spirecv-gimbal-sdk/gimbal_ctrl/sv_gimbal.cpp + */ +#include "amov_gimbal.h" +#include "amov_gimbal_struct.h" + +#include "sv_gimbal.h" +#include "sv_gimbal_io.hpp" + +#include +#include +#include + +/** + * This function sets the serial port for a Gimbal object. + * + * @param port The parameter "port" is a constant reference to a string object. It is used to set the + * serial port for a Gimbal object. + */ +void sv::Gimbal::setSerialPort(const std::string &port) +{ + this->m_serial_port = port; +} + +/** + * This function sets the baud rate for the serial port of a Gimbal object. + * + * @param baud_rate baud_rate is an integer parameter that represents the baud rate (bits per second) + * for the serial port. It is used to set the communication speed between the Gimbal and the device it + * is connected to via the serial port. + */ +void sv::Gimbal::setSerialPort(int baud_rate) +{ + this->m_serial_baud_rate = baud_rate; +} + +/** + * This function sets the serial port parameters for a Gimbal object. + * + * @param byte_size The number of bits in each byte of serial data. It can be 5, 6, 7, 8, or 9 bits. + * @param parity Parity refers to the method of error detection in serial communication. It is used to + * ensure that the data transmitted between two devices is accurate and error-free. There are three + * types of parity: even, odd, and none. Even parity means that the number of 1s in the data byte plus + * @param stop_bits Stop bits refer to the number of bits used to indicate the end of a character. It + * is a parameter used in serial communication to ensure that the receiver knows when a character has + * ended. Common values for stop bits are 1 or 2. + * @param flowcontrol GimablSerialFlowControl is an enumeration type that represents the flow control + * settings for the serial port. It can have one of the following values: + * @param time_out The time-out parameter is an integer value that specifies the maximum amount of time + * to wait for a response from the serial port before timing out. If no response is received within + * this time, the function will return an error. + */ +void sv::Gimbal::setSerialPort(GimablSerialByteSize byte_size, GimablSerialParity parity, + GimablSerialStopBits stop_bits, GimablSerialFlowControl flowcontrol, + int time_out) +{ + this->m_serial_byte_size = (int)byte_size; + this->m_serial_parity = (int)parity; + this->m_serial_stopbits = (int)stop_bits; + this->m_serial_flowcontrol = (int)flowcontrol; + this->m_serial_timeout = (int)time_out; +} + +/** + * This function sets the network IP address for a Gimbal object in C++. + * + * @param ip The parameter "ip" is a constant reference to a string. It is used to set the value of the + * member variable "m_net_ip" in the class "Gimbal". + */ +void sv::Gimbal::setNetIp(const std::string &ip) +{ + this->m_net_ip = ip; +} + +/** + * This function sets the network port for a Gimbal object in C++. + * + * @param port The "port" parameter is an integer value that represents the network port number that + * the Gimbal object will use for communication. This function sets the value of the "m_net_port" + * member variable of the Gimbal object to the value passed in as the "port" parameter. + */ +void sv::Gimbal::setNetPort(const int &port) +{ + this->m_net_port = port; +} + +/** + * The function sets a parser callback for a gimbal device. + * + * @param callback callback is a function pointer of type sv::PStateInvoke. 
It is a callback function + * that will be invoked when the state of the Gimbal device changes. The function takes a single + * parameter of type sv::PState, which represents the new state of the Gimbal device. + */ +void sv::Gimbal::setStateCallback(sv::PStateInvoke callback) +{ + amovGimbal::gimbal *pdevTemp = (amovGimbal::gimbal *)this->dev; + pdevTemp->dev->setParserCallback(callback); +} + +/** + * The function opens a communication interface with a gimbal device and sets up a parser to handle + * incoming data. + * + * @param callback callback is a function pointer to a PStateInvoke function, which is a callback + * function that will be invoked when the gimbal receives a new packet of data. The function takes in a + * PState object as its argument, which contains the current state of the gimbal. The purpose of the + * callback function + * + * @return A boolean value is being returned. + */ +bool sv::Gimbal::open(PStateInvoke callback) +{ + if (this->m_gimbal_link == GimbalLink::SERIAL) + { + this->IO = new UART(this->m_serial_port, + (uint32_t)this->m_serial_baud_rate, + serial::Timeout::simpleTimeout(this->m_serial_timeout), + (serial::bytesize_t)this->m_serial_byte_size, + (serial::parity_t)this->m_serial_parity, + (serial::stopbits_t)this->m_serial_stopbits, + (serial::flowcontrol_t)this->m_serial_flowcontrol); + } + // Subsequent additions + else if (this->m_gimbal_link == sv::GimbalLink::ETHERNET_TCP) + { + return false; + } + else if (this->m_gimbal_link == sv::GimbalLink::ETHERNET_UDP) + { + return false; + } + else + { + throw "Error: Unsupported communication interface class!!!"; + return false; + } + std::string driverName; + switch (this->m_gimbal_type) + { + case sv::GimbalType::G1: + driverName = "G1"; + break; + case sv::GimbalType::Q10f: + driverName = "Q10f"; + break; + + default: + throw "Error: Unsupported driver!!!"; + return false; + break; + } + this->dev = new amovGimbal::gimbal(driverName, (amovGimbal::IOStreamBase *)this->IO); + + amovGimbal::gimbal *pdevTemp = (amovGimbal::gimbal *)this->dev; + + pdevTemp->dev->startStack(); + pdevTemp->dev->parserAuto(callback); + + return true; +} + +/** + * This function sets the home position of a gimbal device and returns a boolean value indicating + * success or failure. + * + * @return A boolean value is being returned. If the function call `setGimabalHome()` returns a value + * greater than 0, then `true` is returned. Otherwise, `false` is returned. + */ +bool sv::Gimbal::setHome() +{ + amovGimbal::gimbal *pdevTemp = (amovGimbal::gimbal *)this->dev; + if (pdevTemp->dev->setGimabalHome() > 0) + { + return true; + } + else + { + return false; + } +} + +/** + * This function sets the zoom level of a gimbal device and returns a boolean indicating success or + * failure. + * + * @param x The zoom level to set for the gimbal. It should be a positive double value. + * + * @return This function returns a boolean value. It returns true if the gimbal zoom is successfully + * set to the specified value, and false if the specified value is less than or equal to zero or if the + * setGimbalZoom function call returns a value less than or equal to zero. + */ +bool sv::Gimbal::setZoom(double x) +{ + amovGimbal::gimbal *pdevTemp = (amovGimbal::gimbal *)this->dev; + + if (x <= 0.0) + { + return false; + } + + if (pdevTemp->dev->setGimbalZoom(amovGimbal::AMOV_GIMBAL_ZOOM_STOP, x) > 0) + { + return true; + } + else + { + return false; + } +} + +/** + * This function sets the auto zoom state of a gimbal device. 
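+ * (Internally the state value is forwarded unchanged to setGimbalZoom() with
+ * targetRate = 0, i.e. it is interpreted as an AMOV_GIMBAL_ZOOM_T command.)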
+ * + * @param state The state parameter is an integer that represents the desired state of the auto zoom + * feature. It is used to enable or disable the auto zoom feature of the gimbal. A value of 1 enables + * the auto zoom feature, while a value of 0 disables it. + * + * @return This function returns a boolean value. It returns true if the setGimbalZoom function call is + * successful and false if it fails. + */ +bool sv::Gimbal::setAutoZoom(int state) +{ + amovGimbal::gimbal *pdevTemp = (amovGimbal::gimbal *)this->dev; + + if (pdevTemp->dev->setGimbalZoom((amovGimbal::AMOV_GIMBAL_ZOOM_T)state, 0.0f) > 0) + { + return true; + } + else + { + return false; + } +} + +/** + * This function sets the autofocus state of a gimbal device and returns a boolean indicating success + * or failure. + * + * @param state The state parameter is an integer that represents the desired autofocus state. It is + * likely that a value of 1 represents autofocus enabled and a value of 0 represents autofocus + * disabled, but without more context it is impossible to say for certain. + * + * @return This function returns a boolean value. It returns true if the setGimbalFocus function call + * is successful and returns a value greater than 0, and false if the function call fails and returns a + * value less than or equal to 0. + */ +bool sv::Gimbal::setAutoFocus(int state) +{ + amovGimbal::gimbal *pdevTemp = (amovGimbal::gimbal *)this->dev; + + if (pdevTemp->dev->setGimbalFocus((amovGimbal::AMOV_GIMBAL_ZOOM_T)state, 0.0f) > 0) + { + return true; + } + else + { + return false; + } +} + +/** + * The function takes a photo using a gimbal device and returns true if successful, false otherwise. + * + * @return A boolean value is being returned. It will be true if the function call to takePic() returns + * a value greater than 0, and false otherwise. + */ +bool sv::Gimbal::takePhoto() +{ + amovGimbal::gimbal *pdevTemp = (amovGimbal::gimbal *)this->dev; + + if (pdevTemp->dev->takePic() > 0) + { + return true; + } + else + { + return false; + } +} + +/** + * The function takes a state parameter and sets the video state of a gimbal device accordingly. + * + * @param state The state parameter is an integer that determines the desired state of the video + * recording function of the Gimbal device. It can have two possible values: 0 for turning off the + * video recording and 1 for starting the video recording. + * + * @return a boolean value. It returns true if the video state was successfully set to the desired + * state (either off or take), and false if there was an error in setting the state. + */ +bool sv::Gimbal::takeVideo(int state) +{ + amovGimbal::gimbal *pdevTemp = (amovGimbal::gimbal *)this->dev; + + amovGimbal::AMOV_GIMBAL_VIDEO_T newState; + switch (state) + { + case 0: + newState = amovGimbal::AMOV_GIMBAL_VIDEO_OFF; + break; + case 1: + newState = amovGimbal::AMOV_GIMBAL_VIDEO_TAKE; + break; + default: + newState = amovGimbal::AMOV_GIMBAL_VIDEO_OFF; + break; + } + + if (pdevTemp->dev->setVideo(newState) > 0) + { + return true; + } + else + { + return false; + } +} + +/** + * This function returns the current state of the video on a gimbal device. + * + * @return an integer value that represents the state of the video on the gimbal. If the video is being + * taken, it returns 1. If the video is off, it returns 0. If the state is unknown, it throws an + * exception with the message "Unknown state information!!!". 
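+ *
+ * Typical check (illustrative, assuming an opened sv::Gimbal instance named gimbal):
+ *   if (gimbal.getVideoState() == 1) { ... }   // 1: recording, 0: not recording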
+ */ +int sv::Gimbal::getVideoState() +{ + amovGimbal::gimbal *pdevTemp = (amovGimbal::gimbal *)this->dev; + int ret; + amovGimbal::AMOV_GIMBAL_STATE_T state; + state = pdevTemp->dev->getGimabalState(); + if (state.video == amovGimbal::AMOV_GIMBAL_VIDEO_TAKE) + { + ret = 1; + } + else if (state.video == amovGimbal::AMOV_GIMBAL_VIDEO_OFF) + { + ret = 0; + } + else + { + throw "Unknown state information!!!"; + } + return ret; +} + +/** + * The function sets the angle and rate of a gimbal using Euler angles. + * + * @param roll The desired roll angle of the gimbal in degrees. + * @param pitch The desired pitch angle of the gimbal in degrees. + * @param yaw The desired yaw angle in degrees. Yaw is the rotation around the vertical axis. + * @param roll_rate The rate at which the gimbal should rotate around the roll axis, in degrees per + * second. + * @param pitch_rate The desired pitch rotation rate in degrees per second. If it is set to 0, it will + * default to 360 degrees per second. + * @param yaw_rate The rate at which the yaw angle of the gimbal should change, in degrees per second. + */ +void sv::Gimbal::setAngleEuler(double roll, double pitch, double yaw, + double roll_rate, double pitch_rate, double yaw_rate) +{ + amovGimbal::gimbal *pdevTemp = (amovGimbal::gimbal *)this->dev; + + amovGimbal::AMOV_GIMBAL_POS_T temp; + + if (pitch_rate == 0.0f) + pitch_rate = 360; + if (roll_rate == 0.0f) + roll_rate = 360; + if (yaw_rate == 0.0f) + yaw_rate = 360; + + temp.pitch = pitch_rate; + temp.roll = roll_rate; + temp.yaw = yaw_rate; + pdevTemp->dev->setGimabalFollowSpeed(temp); + temp.pitch = pitch; + temp.roll = roll; + temp.yaw = yaw; + pdevTemp->dev->setGimabalPos(temp); +} + +/** + * This function sets the angle rate of a gimbal using Euler angles. + * + * @param roll_rate The rate of change of the roll angle of the gimbal, measured in degrees per second. + * @param pitch_rate The rate of change of pitch angle in degrees per second. + * @param yaw_rate The rate of change of the yaw angle of the gimbal. Yaw is the rotation of the gimbal + * around the vertical axis. 
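+ *
+ * Example (illustrative, assuming an opened sv::Gimbal instance named gimbal):
+ *   gimbal.setAngleRateEuler(0.0, 10.0, -5.0);   // roll_rate, pitch_rate, yaw_rate in degree/s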
+ */ +void sv::Gimbal::setAngleRateEuler(double roll_rate, double pitch_rate, double yaw_rate) +{ + amovGimbal::gimbal *pdevTemp = (amovGimbal::gimbal *)this->dev; + + amovGimbal::AMOV_GIMBAL_POS_T temp; + temp.pitch = pitch_rate; + temp.roll = roll_rate; + temp.yaw = yaw_rate; + pdevTemp->dev->setGimabalSpeed(temp); +} + +sv::Gimbal::~Gimbal() +{ + delete (amovGimbal::gimbal *)this->dev; + delete (amovGimbal::IOStreamBase *)this->IO; +} diff --git a/gimbal_ctrl/sv_gimbal_io.hpp b/gimbal_ctrl/sv_gimbal_io.hpp new file mode 100644 index 0000000..0b0e1c7 --- /dev/null +++ b/gimbal_ctrl/sv_gimbal_io.hpp @@ -0,0 +1,68 @@ +/* + * @Description: + * @Author: L LC @amov + * @Date: 2023-04-12 12:22:09 + * @LastEditors: L LC @amov + * @LastEditTime: 2023-04-13 10:17:21 + * @FilePath: /spirecv-gimbal-sdk/gimbal_ctrl/sv_gimbal_io.hpp + */ +#ifndef __SV_GIMABL_IO_H +#define __SV_GIMABL_IO_H + +#include "amov_gimbal.h" +#include "serial/serial.h" + +class UART : public amovGimbal::IOStreamBase +{ +private: + serial::Serial *ser; + +public: + virtual bool open() + { + ser->open(); + return true; + } + virtual bool close() + { + ser->close(); + return true; + } + virtual bool isOpen() + { + return ser->isOpen(); + } + virtual bool isBusy() + { + return false; + } + + virtual bool inPutByte(IN uint8_t *byte) + { + if (ser->available() > 0) + { + ser->read(byte, 1); + return true; + } + return false; + } + + virtual uint32_t outPutBytes(IN uint8_t *byte, uint32_t lenght) + { + return ser->write(byte, lenght); + } + + UART(const std::string &port, uint32_t baudrate, serial::Timeout timeout, + serial::bytesize_t bytesize, serial::parity_t parity, serial::stopbits_t stopbits, + serial::flowcontrol_t flowcontrol) + { + ser = new serial::Serial(port, baudrate, timeout, bytesize, parity, stopbits, flowcontrol); + } + ~UART() + { + ser->close(); + delete ser; + } +}; + +#endif \ No newline at end of file diff --git a/include/sv_algorithm_base.h b/include/sv_algorithm_base.h new file mode 100644 index 0000000..1e17993 --- /dev/null +++ b/include/sv_algorithm_base.h @@ -0,0 +1,172 @@ +#ifndef __SV_ALGORITHM__ +#define __SV_ALGORITHM__ + +#include "sv_video_base.h" +#include +#include +#include +#include +#include + + +namespace yaed { +class EllipseDetector; +} + +namespace sv { + +// union JsonValue; +// class JsonAllocator; + +class CameraAlgorithm +{ +public: + CameraAlgorithm(); + ~CameraAlgorithm(); + void loadCameraParams(std::string yaml_fn_); + void loadAlgorithmParams(std::string json_fn_); + + cv::Mat camera_matrix; + cv::Mat distortion; + int image_width; + int image_height; + double fov_x; + double fov_y; + + std::string alg_params_fn; +protected: + // JsonValue* _value; + // JsonAllocator* _allocator; + std::chrono::system_clock::time_point _t0; +}; + + +class ArucoDetector : public CameraAlgorithm +{ +public: + ArucoDetector(); + void detect(cv::Mat img_, TargetsInFrame& tgts_); +private: + void _load(); + bool _params_loaded; + cv::Ptr _detector_params; + cv::Ptr _dictionary; + int _dictionary_id; + std::vector _ids_need; + std::vector _lengths_need; +}; + + +class EllipseDetector : public CameraAlgorithm +{ +public: + EllipseDetector(); + ~EllipseDetector(); + void detectAllInDirectory(std::string input_img_dir_, std::string output_json_dir_); + void detect(cv::Mat img_, TargetsInFrame& tgts_); +protected: + void _load(); + bool _params_loaded; + yaed::EllipseDetector* _ed; + float _max_center_distance_ratio; + double _radius_in_meter; +}; + +class LandingMarkerDetectorBase : public EllipseDetector 
+{ +public: + LandingMarkerDetectorBase(); + ~LandingMarkerDetectorBase(); + void detect(cv::Mat img_, TargetsInFrame& tgts_); + + bool isParamsLoaded(); + int getMaxCandidates(); + std::vector getLabelsNeed(); +protected: + virtual bool setupImpl(); + virtual void roiCNN(std::vector& input_rois_, std::vector& output_labels_); + void _loadLabels(); + int _max_candidates; + std::vector _labels_need; +}; + + +class SingleObjectTrackerBase : public CameraAlgorithm +{ +public: + SingleObjectTrackerBase(); + ~SingleObjectTrackerBase(); + void warmUp(); + void init(cv::Mat img_, const cv::Rect& bounding_box_); + void track(cv::Mat img_, TargetsInFrame& tgts_); + + bool isParamsLoaded(); + std::string getAlgorithm(); + int getBackend(); + int getTarget(); +protected: + virtual bool setupImpl(); + virtual void initImpl(cv::Mat img_, const cv::Rect& bounding_box_); + virtual bool trackImpl(cv::Mat img_, cv::Rect& output_bbox_); + void _load(); + bool _params_loaded; + std::string _algorithm; + int _backend; + int _target; +}; + + +class CommonObjectDetectorBase : public CameraAlgorithm +{ +public: + CommonObjectDetectorBase(); + ~CommonObjectDetectorBase(); + void warmUp(); + void detect(cv::Mat img_, TargetsInFrame& tgts_, Box* roi_=nullptr, int img_w_=0, int img_h_=0); + + bool isParamsLoaded(); + std::string getDataset(); + std::vector getClassNames(); + std::vector getClassWs(); + std::vector getClassHs(); + int getInputH(); + void setInputH(int h_); + int getInputW(); + void setInputW(int w_); + int getClassNum(); + int getOutputSize(); + double getThrsNms(); + double getThrsConf(); + int useWidthOrHeight(); + bool withSegmentation(); +protected: + virtual bool setupImpl(); + virtual void detectImpl( + cv::Mat img_, + std::vector& boxes_x_, + std::vector& boxes_y_, + std::vector& boxes_w_, + std::vector& boxes_h_, + std::vector& boxes_label_, + std::vector& boxes_score_, + std::vector& boxes_seg_ + ); + void _load(); + bool _params_loaded; + std::string _dataset; + std::vector _class_names; + std::vector _class_ws; + std::vector _class_hs; + int _input_h; + int _input_w; + int _n_classes; + int _output_size; + double _thrs_nms; + double _thrs_conf; + int _use_width_or_height; + bool _with_segmentation; +}; + + +} +#endif diff --git a/include/sv_common_det.h b/include/sv_common_det.h new file mode 100644 index 0000000..1bb14fe --- /dev/null +++ b/include/sv_common_det.h @@ -0,0 +1,39 @@ +#ifndef __SV_COMMON_DET__ +#define __SV_COMMON_DET__ + +#include "sv_core.h" +#include +#include +#include +#include +#include + + +namespace sv { + +class CommonObjectDetectorCUDAImpl; + +class CommonObjectDetector : public CommonObjectDetectorBase +{ +public: + CommonObjectDetector(); + ~CommonObjectDetector(); +protected: + bool setupImpl(); + void detectImpl( + cv::Mat img_, + std::vector& boxes_x_, + std::vector& boxes_y_, + std::vector& boxes_w_, + std::vector& boxes_h_, + std::vector& boxes_label_, + std::vector& boxes_score_, + std::vector& boxes_seg_ + ); + + CommonObjectDetectorCUDAImpl* _cuda_impl; +}; + + +} +#endif diff --git a/include/sv_core.h b/include/sv_core.h new file mode 100644 index 0000000..f895ecb --- /dev/null +++ b/include/sv_core.h @@ -0,0 +1,8 @@ +#ifndef __SV_CORE__ +#define __SV_CORE__ + +#include "sv_video_base.h" +#include "sv_gimbal.h" +#include "sv_algorithm_base.h" + +#endif diff --git a/include/sv_gimbal.h b/include/sv_gimbal.h new file mode 100644 index 0000000..e63bc18 --- /dev/null +++ b/include/sv_gimbal.h @@ -0,0 +1,162 @@ +/* + * @Description: + * @Author: jario-jin 
@amov + * @Date: 2023-04-12 09:12:52 + * @LastEditors: L LC @amov + * @LastEditTime: 2023-04-18 11:49:27 + * @FilePath: /spirecv-gimbal-sdk/include/sv_gimbal.h + */ +#ifndef __SV_GIMBAL__ +#define __SV_GIMBAL__ + +#include + +namespace sv +{ + + typedef void (*PStateInvoke)(double &frame_ang_r, double &frame_ang_p, double &frame_ang_y, + double &imu_ang_r, double &imu_ang_p, double &imu_ang_y, + double &fov_x, double &fov_y); + + enum class GimbalType + { + G1, + Q10f + }; + enum class GimbalLink + { + SERIAL, + ETHERNET_TCP, + ETHERNET_UDP + }; + + enum class GimablSerialByteSize + { + FIVE_BYTES = 5, + SIX_BYTES = 6, + SEVEN_BYTES = 7, + EIGHT_BYTES = 8, + }; + + enum class GimablSerialParity + { + PARITY_NONE = 0, + PARITY_ODD = 1, + PARITY_EVEN = 2, + PARITY_MARK = 3, + PARITY_SPACE = 4, + }; + + enum class GimablSerialStopBits + { + STOPBITS_ONE = 1, + STOPBITS_TWO = 2, + STOPBITS_ONE_POINT_FIVE = 3, + }; + + enum class GimablSerialFlowControl + { + FLOWCONTROL_NONE = 0, + FLOWCONTROL_SOFTWARE = 1, + FLOWCONTROL_HARDWARE = 2, + }; + + static inline void emptyCallback(double &frameAngleRoll, double &frameAnglePitch, double &frameAngleYaw, + double &imuAngleRoll, double &imuAnglePitch, double &imuAngleYaw, + double &fovX, double &fovY) + { + } + + //! A gimbal control and state reading class. + /*! + A common gimbal control class for vary type of gimbals. + e.g. AMOV G1 + */ + class Gimbal + { + private: + // Device pointers + void *dev; + void *IO; + + // Generic serial interface parameters list & default parameters + std::string m_serial_port = "/dev/ttyUSB0"; + int m_serial_baud_rate = 115200; + int m_serial_byte_size = (int)GimablSerialByteSize::EIGHT_BYTES; + int m_serial_parity = (int)GimablSerialParity::PARITY_NONE; + int m_serial_stopbits = (int)GimablSerialStopBits::STOPBITS_ONE; + int m_serial_flowcontrol = (int)GimablSerialFlowControl::FLOWCONTROL_NONE; + int m_serial_timeout = 500; + + // Ethernet interface parameters list & default parameters + std::string m_net_ip = "192.168.2.64"; + int m_net_port = 9090; + + GimbalType m_gimbal_type; + GimbalLink m_gimbal_link; + + public: + //! Constructor + /*! + \param serial_port: string like '/dev/ttyUSB0' in linux sys. + \param baud_rate: serial baud rate, e.g. 115200 + */ + Gimbal(GimbalType gtype = GimbalType::G1, GimbalLink ltype = GimbalLink::SERIAL) + { + m_gimbal_type = gtype; + m_gimbal_link = ltype; + } + ~Gimbal(); + // set Generic serial interface parameters + void setSerialPort(const std::string &port); + void setSerialPort(const int baud_rate); + void setSerialPort(GimablSerialByteSize byte_size, GimablSerialParity parity, + GimablSerialStopBits stop_bits, GimablSerialFlowControl flowcontrol, + int time_out = 500); + + // set Ethernet interface parameters + void setNetIp(const std::string &ip); + void setNetPort(const int &port); + + // Create a device instance + void setStateCallback(PStateInvoke callback); + bool open(PStateInvoke callback = emptyCallback); + + // Funtions + bool setHome(); + bool setZoom(double x); + bool setAutoZoom(int state); + bool setAutoFocus(int state); + bool takePhoto(); + bool takeVideo(int state); + int getVideoState(); + + //! Set gimbal angles + /*! 
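+        Example (illustrative, assuming an opened Gimbal instance named g):
+            g.setAngleEuler(0, -30, 45);  // roll/pitch/yaw in degree; rate arguments default to 0, which the driver treats as 360 degree/s
+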
+ \param roll: eular roll angle (-60, 60) degree + \param pitch: eular pitch angle (-135, 135) degree + \param yaw: eular yaw angle (-150, 150) degree + \param roll_rate: roll angle rate, degree/s + \param pitch_rate: pitch angle rate, degree/s + \param yaw_rate: yaw angle rate, degree/s + */ + void setAngleEuler( + double roll, + double pitch, + double yaw, + double roll_rate = 0, + double pitch_rate = 0, + double yaw_rate = 0); + //! Set gimbal angle rates + /*! + \param roll_rate: roll angle rate, degree/s + \param pitch_rate: pitch angle rate, degree/s + \param yaw_rate: yaw angle rate, degree/s + */ + void setAngleRateEuler( + double roll_rate, + double pitch_rate, + double yaw_rate); + }; +} +#endif diff --git a/include/sv_landing_det.h b/include/sv_landing_det.h new file mode 100644 index 0000000..493e021 --- /dev/null +++ b/include/sv_landing_det.h @@ -0,0 +1,33 @@ +#ifndef __SV_LANDING_DET__ +#define __SV_LANDING_DET__ + +#include "sv_core.h" +#include +#include +#include +#include +#include + + +namespace sv { + +class LandingMarkerDetectorCUDAImpl; + +class LandingMarkerDetector : public LandingMarkerDetectorBase +{ +public: + LandingMarkerDetector(); + ~LandingMarkerDetector(); +protected: + bool setupImpl(); + void roiCNN( + std::vector& input_rois_, + std::vector& output_labels_ + ); + + LandingMarkerDetectorCUDAImpl* _cuda_impl; +}; + + +} +#endif diff --git a/include/sv_tracking.h b/include/sv_tracking.h new file mode 100644 index 0000000..7895362 --- /dev/null +++ b/include/sv_tracking.h @@ -0,0 +1,32 @@ +#ifndef __SV_TRACKING__ +#define __SV_TRACKING__ + +#include "sv_core.h" +#include +#include +#include +#include +#include + + + +namespace sv { + +class SingleObjectTrackerOCV470Impl; + +class SingleObjectTracker : public SingleObjectTrackerBase +{ +public: + SingleObjectTracker(); + ~SingleObjectTracker(); +protected: + bool setupImpl(); + void initImpl(cv::Mat img_, const cv::Rect& bounding_box_); + bool trackImpl(cv::Mat img_, cv::Rect& output_bbox_); + + SingleObjectTrackerOCV470Impl* _ocv470_impl; +}; + + +} +#endif diff --git a/include/sv_video_base.h b/include/sv_video_base.h new file mode 100644 index 0000000..47ee94d --- /dev/null +++ b/include/sv_video_base.h @@ -0,0 +1,399 @@ +#ifndef __SV_VIDEOIO__ +#define __SV_VIDEOIO__ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include // for sockaddr_in + +#define SV_RAD2DEG 57.2957795 +// #define X86_PLATFORM +// #define JETSON_PLATFORM + + +namespace sv { + +//! The rectangle bounding-box of an object. +class Box +{ +public: + Box(); + + int x1; + int y1; + int x2; + int y2; + + //! Set the parameters of the bounding-box by XYXY-format. + /*! + \param x1_: The x-axis pixel coordinates of the top-left point. + \param y1_: The y-axis pixel coordinates of the top-left point. + \param x2_: The x-axis pixel coordinates of the bottom-right point. + \param y2_: The y-axis pixel coordinates of the bottom-right point. + */ + void setXYXY(int x1_, int y1_, int x2_, int y2_); + //! Set the parameters of the bounding-box by XYWH-format. + /*! + \param x1_: The x-axis pixel coordinates of the top-left point. + \param y1_: The y-axis pixel coordinates of the top-left point. + \param w_: The width of the bounding rectangle. + \param h_: The height of the bounding rectangle. + */ + void setXYWH(int x_, int y_, int w_, int h_); +}; + + +//! Description class for a single target detection result. +/*! 
+ Support multiple description methods, + such as bounding box, segmentation, ellipse, three-dimensional position, etc. +*/ +class Target +{ +public: + Target(); + + //! X coordinate of object center point, [0, 1], (Required) + double cx; + //! Y coordinate of object center point, [0, 1], (Required) + double cy; + //! Object-width / image-width, (0, 1] + double w; + //! Object-height / image-heigth, (0, 1] + double h; + + //! Objectness, Confidence, (0, 1] + double score; + //! Category of target. + std::string category; + //! Category ID of target. + int category_id; + //! The same target in different frames shares a unique ID. + int tracked_id; + + //! X coordinate of object position in Camera-Frame (unit: meter). + double px; + //! Y coordinate of object position in Camera-Frame (unit: meter). + double py; + //! Z coordinate of object position in Camera-Frame (unit: meter). + double pz; + + //! Line of sight (LOS) angle on X-axis (unit: degree). + double los_ax; + //! Line of sight (LOS) angle on Y-axis (unit: degree). + double los_ay; + //! The angle of the target in the image coordinate system, (unit: degree) [-180, 180]. + double yaw_a; + + //! Whether the height&width of the target can be obtained. + bool has_hw; + //! Whether the category of the target can be obtained. + bool has_category; + //! Whether the tracking-ID of the target can be obtained. + bool has_tid; + //! Whether the 3D-position of the target can be obtained. + bool has_position; + //! Whether the LOS-angle of the target can be obtained. + bool has_los; + //! Whether the segmentation of the target can be obtained. + bool has_seg; + //! Whether the bounding-box of the target can be obtained. + bool has_box; + //! Whether the ellipse-parameters of the target can be obtained. + bool has_ell; + //! Whether the aruco-parameters of the target can be obtained. + bool has_aruco; + //! Whether the direction of the target can be obtained. + bool has_yaw; + + void setCategory(std::string cate_, int cate_id_); + void setLOS(double cx_, double cy_, cv::Mat camera_matrix_, int img_w_, int img_h_); + void setTrackID(int id_); + void setPosition(double x_, double y_, double z_); + void setBox(int x1_, int y1_, int x2_, int y2_, int img_w_, int img_h_); + void setAruco(int id_, std::vector corners_, cv::Vec3d rvecs_, cv::Vec3d tvecs_, int img_w_, int img_h_, cv::Mat camera_matrix_); + void setEllipse(double xc_, double yc_, double a_, double b_, double rad_, double score_, int img_w_, int img_h_, cv::Mat camera_matrix_, double radius_in_meter_); + void setYaw(double vec_x_, double vec_y); + void setMask(cv::Mat mask_); + cv::Mat getMask(); + + bool getBox(Box& b); + bool getAruco(int& id, std::vector &corners); + bool getEllipse(double& xc_, double& yc_, double& a_, double& b_, double& rad_); + std::string getJsonStr(); + +private: + //! segmentation [[x1,y1, x2,y2, x3,y3,...],...] + /*! + SEG variables: (_s_) segmentation, segmentation_size_h, segmentation_size_w, segmentation_counts, area + */ + std::vector > _s_segmentation; + int _s_segmentation_size_h; + int _s_segmentation_size_w; + std::string _s_segmentation_counts; + cv::Mat _mask; + double _s_area; + //! bounding box [x, y, w, h] + /*! + BOX variables: (_b_) box + */ + Box _b_box; // x,y,w,h + //! ellipse x-axis center + /*! + ELL variables: (_e_) xc, yc, a, b, rad + */ + double _e_xc; + double _e_yc; + double _e_a; + double _e_b; + double _e_rad; + //! Aruco Marker ID + /*! 
+ ARUCO variables: (_a_) id, corners, rvecs, tvecs + */ + int _a_id; + std::vector _a_corners; + cv::Vec3d _a_rvecs; + cv::Vec3d _a_tvecs; +}; + + +enum class MissionType {NONE, COMMON_DET, TRACKING, ARUCO_DET, LANDMARK_DET, ELLIPSE_DET}; + +//! This class describes all objects in a single frame image. +/*! + 1. Contains multiple Target instances. + 2. Describes the ID of the current frame, image width and height, current field of view, etc. + 3. Describes the processed image sub-regions and supports local region detection. +*/ +class TargetsInFrame +{ +public: + TargetsInFrame(int frame_id_); + + //! Frame number. + int frame_id; + //! Frame/image height. + int height; + //! Frame/image width. + int width; + + //! Detection frame per second (FPS). + double fps; + //! The x-axis field of view (FOV) of the current camera. + double fov_x; + //! The y-axis field of view (FOV) of the current camera. + double fov_y; + + //! 吊舱俯仰角 + double pod_patch; + //! 吊舱滚转角 + double pod_roll; + //! 吊舱航向角,东向为0,东北天为正,范围[-180,180] + double pod_yaw; + + //! 当前经度 + double longitude; + //! 当前纬度 + double latitude; + //! 当前飞行高度 + double altitude; + + //! 飞行速度,x轴,东北天坐标系 + double uav_vx; + //! 飞行速度,y轴,东北天坐标系 + double uav_vy; + //! 飞行速度,z轴,东北天坐标系 + double uav_vz; + //! 当前光照强度,Lux + double illumination; + + //! Whether the detection FPS can be obtained. + bool has_fps; + //! Whether the FOV can be obtained. + bool has_fov; + //! Whether the processed image sub-region can be obtained. + bool has_roi; + + bool has_pod_info; + bool has_uav_pos; + bool has_uav_vel; + bool has_ill; + + MissionType type; + + //! The processed image sub-region, if size>0, it means no full image detection. + std::vector rois; + //! Detected Target Instances. + std::vector targets; + std::string date_captured; + + void setTimeNow(); + void setFPS(double fps_); + void setFOV(double fov_x_, double fov_y_); + void setSize(int width_, int height_); + std::string getJsonStr(); +}; + + +class UDPServer { +public: + UDPServer(std::string dest_ip="127.0.0.1", int port=20166); + ~UDPServer(); + + void send(const TargetsInFrame& tgts_); +private: + struct sockaddr_in _servaddr; + int _sockfd; +}; + + +class VideoWriterBase { +public: + VideoWriterBase(); + ~VideoWriterBase(); + + void setup(std::string file_path, cv::Size size, double fps=25.0, bool with_targets=false); + void write(cv::Mat image, TargetsInFrame tgts=TargetsInFrame(0)); + void release(); + + cv::Size getSize(); + double getFps(); + std::string getFilePath(); + bool isRunning(); +protected: + virtual bool setupImpl(std::string file_name_); + virtual bool isOpenedImpl(); + virtual void writeImpl(cv::Mat img_); + virtual void releaseImpl(); + void _init(); + void _run(); + + bool _is_running; + cv::Size _image_size; + double _fps; + bool _with_targets; + int _fid; + int _fcnt; + + std::thread _tt; + // cv::VideoWriter _writer; + std::ofstream _targets_ofs; + std::string _file_path; + + std::queue _image_to_write; + std::queue _tgts_to_write; +}; + + +class VideoStreamerBase { +public: + VideoStreamerBase(); + ~VideoStreamerBase(); + + void setup(cv::Size size, int port=8554, int bitrate=2, std::string url="/live"); // 2M + void stream(cv::Mat image); + void release(); + + cv::Size getSize(); + int getPort(); + std::string getUrl(); + int getBitrate(); + bool isRunning(); +protected: + virtual bool setupImpl(); + virtual bool isOpenedImpl(); + virtual void writeImpl(cv::Mat image); + virtual void releaseImpl(); + void _run(); + + bool _is_running; + cv::Size _stream_size; + int _port; + 
std::string _url; + int _bitrate; + std::thread _tt; + std::stack _image_to_stream; +}; + + +enum class CameraType {NONE, WEBCAM, G1, Q10}; + +class CameraBase { +public: + CameraBase(CameraType type=CameraType::NONE, int id=0); + ~CameraBase(); + void open(CameraType type=CameraType::WEBCAM, int id=0); + bool read(cv::Mat& image); + void release(); + + int getW(); + int getH(); + int getFps(); + std::string getIp(); + int getPort(); + double getBrightness(); + double getContrast(); + double getSaturation(); + double getHue(); + double getExposure(); + bool isRunning(); + void setWH(int width, int height); + void setFps(int fps); + void setIp(std::string ip); + void setPort(int port); + void setBrightness(double brightness); + void setContrast(double contrast); + void setSaturation(double saturation); + void setHue(double hue); + void setExposure(double exposure); +protected: + virtual void openImpl(); + void _run(); + + bool _is_running; + bool _is_updated; + std::thread _tt; + cv::VideoCapture _cap; + cv::Mat _frame; + CameraType _type; + int _camera_id; + + int _width; + int _height; + int _fps; + std::string _ip; + int _port; + double _brightness; + double _contrast; + double _saturation; + double _hue; + double _exposure; +}; + + +void drawTargetsInFrame( + cv::Mat& img_, + const TargetsInFrame& tgts_, + bool with_all=true, + bool with_category=false, + bool with_tid=false, + bool with_seg=false, + bool with_box=false, + bool with_ell=false, + bool with_aruco=false, + bool with_yaw=false +); +std::string get_home(); +bool is_file_exist(std::string& fn); +void list_dir(std::string dir, std::vector& files, std::string suffixs="", bool r=false); + + +} +#endif diff --git a/include/sv_video_input.h b/include/sv_video_input.h new file mode 100644 index 0000000..600d878 --- /dev/null +++ b/include/sv_video_input.h @@ -0,0 +1,27 @@ +#ifndef __SV_VIDEO_INPUT__ +#define __SV_VIDEO_INPUT__ + +#include "sv_core.h" +#include +#include +#include +#include +#include + + +namespace sv { + + +class Camera : public CameraBase +{ +public: + Camera(); + ~Camera(); +protected: + void openImpl(); + +}; + + +} +#endif diff --git a/include/sv_video_output.h b/include/sv_video_output.h new file mode 100644 index 0000000..807879a --- /dev/null +++ b/include/sv_video_output.h @@ -0,0 +1,53 @@ +#ifndef __SV_VIDEO_OUTPUT__ +#define __SV_VIDEO_OUTPUT__ + +#include "sv_core.h" +#include +#include +#include +#include +#include + + +class BsVideoSaver; +class BsPushStreamer; + +namespace sv { + +class VideoWriterGstreamerImpl; +class VideoStreamerGstreamerImpl; + +class VideoWriter : public VideoWriterBase +{ +public: + VideoWriter(); + ~VideoWriter(); +protected: + bool setupImpl(std::string file_name_); + bool isOpenedImpl(); + void writeImpl(cv::Mat img_); + void releaseImpl(); + + VideoWriterGstreamerImpl* _gstreamer_impl; + BsVideoSaver* _ffmpeg_impl; +}; + + +class VideoStreamer : public VideoStreamerBase +{ +public: + VideoStreamer(); + ~VideoStreamer(); +protected: + bool setupImpl(); + bool isOpenedImpl(); + void writeImpl(cv::Mat img_); + void releaseImpl(); + + VideoStreamerGstreamerImpl* _gstreamer_impl; + BsPushStreamer* _ffmpeg_impl; +}; + + +} +#endif diff --git a/include/sv_world.h b/include/sv_world.h new file mode 100644 index 0000000..5746122 --- /dev/null +++ b/include/sv_world.h @@ -0,0 +1,12 @@ +#ifndef __SV__WORLD__ +#define __SV__WORLD__ + +#include "sv_core.h" +#include "sv_common_det.h" +#include "sv_landing_det.h" +#include "sv_tracking.h" +#include "sv_video_input.h" +#include 
"sv_video_output.h" + + +#endif diff --git a/samples/SpireCVDet.cpp b/samples/SpireCVDet.cpp new file mode 100644 index 0000000..0a7353e --- /dev/null +++ b/samples/SpireCVDet.cpp @@ -0,0 +1,120 @@ +#include "yolov7/cuda_utils.h" +#include "yolov7/logging.h" +#include "yolov7/utils.h" +#include "yolov7/preprocess.h" +#include "yolov7/postprocess.h" +#include "yolov7/model.h" + +#include +#include +#include + +using namespace nvinfer1; + +static Logger gLogger; +const static int kInputH = 640; +const static int kInputW = 640; +const static int kInputH_HD = 1280; +const static int kInputW_HD = 1280; +const static int kOutputSize = kMaxNumOutputBbox * sizeof(Detection) / sizeof(float) + 1; + +bool parse_args(int argc, char** argv, std::string& wts, std::string& engine, bool& is_p6, float& gd, float& gw, std::string& img_dir, int& n_classes) { + if (argc < 4) return false; + if (std::string(argv[1]) == "-s" && (argc == 6 || argc == 8)) { + wts = std::string(argv[2]); + engine = std::string(argv[3]); + n_classes = atoi(argv[4]); + if (n_classes < 1) + return false; + auto net = std::string(argv[5]); + if (net[0] == 'n') { + gd = 0.33; + gw = 0.25; + } else if (net[0] == 's') { + gd = 0.33; + gw = 0.50; + } else if (net[0] == 'm') { + gd = 0.67; + gw = 0.75; + } else if (net[0] == 'l') { + gd = 1.0; + gw = 1.0; + } else if (net[0] == 'x') { + gd = 1.33; + gw = 1.25; + } else if (net[0] == 'c' && argc == 8) { + gd = atof(argv[6]); + gw = atof(argv[7]); + } else { + return false; + } + if (net.size() == 2 && net[1] == '6') { + is_p6 = true; + } + } else { + return false; + } + return true; +} + +void serialize_engine(unsigned int max_batchsize, bool& is_p6, float& gd, float& gw, std::string& wts_name, std::string& engine_name, int n_classes) { + // Create builder + IBuilder* builder = createInferBuilder(gLogger); + IBuilderConfig* config = builder->createBuilderConfig(); + + // Create model to populate the network, then set the outputs and create an engine + ICudaEngine *engine = nullptr; + if (is_p6) { + engine = build_det_p6_engine(max_batchsize, builder, config, DataType::kFLOAT, gd, gw, wts_name, kInputH_HD, kInputW_HD, n_classes); + } else { + engine = build_det_engine(max_batchsize, builder, config, DataType::kFLOAT, gd, gw, wts_name, kInputH, kInputW, n_classes); + } + assert(engine != nullptr); + + // Serialize the engine + IHostMemory* serialized_engine = engine->serialize(); + assert(serialized_engine != nullptr); + + // Save engine to file + std::ofstream p(engine_name, std::ios::binary); + if (!p) { + std::cerr << "Could not open plan output file" << std::endl; + assert(false); + } + p.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); + + // Close everything down + engine->destroy(); + builder->destroy(); + config->destroy(); + serialized_engine->destroy(); +} + +int main(int argc, char** argv) { + cudaSetDevice(kGpuId); + + std::string wts_name = ""; + std::string engine_name = ""; + bool is_p6 = false; + float gd = 0.0f, gw = 0.0f; + std::string img_dir; + int n_classes; + + if (!parse_args(argc, argv, wts_name, engine_name, is_p6, gd, gw, img_dir, n_classes)) { + std::cerr << "arguments not right!" 
<< std::endl; + std::cerr << "./SpireCVDet -s [.wts] [.engine] n_classes [n/s/m/l/x/n6/s6/m6/l6/x6 or c/c6 gd gw] // serialize model to plan file" << std::endl; + // std::cerr << "./SpireCVDet -d [.engine] ../images // deserialize plan file and run inference" << std::endl; + return -1; + } + std::cout << "n_classes: " << n_classes << std::endl; + + // Create a model using the API directly and serialize it to a file + if (!wts_name.empty()) { + serialize_engine(kBatchSize, is_p6, gd, gw, wts_name, engine_name, n_classes); + return 0; + } + + + return 0; +} + diff --git a/samples/SpireCVSeg.cpp b/samples/SpireCVSeg.cpp new file mode 100644 index 0000000..ebc1ebc --- /dev/null +++ b/samples/SpireCVSeg.cpp @@ -0,0 +1,120 @@ +#include "yolov7/config.h" +#include "yolov7/cuda_utils.h" +#include "yolov7/logging.h" +#include "yolov7/utils.h" +#include "yolov7/preprocess.h" +#include "yolov7/postprocess.h" +#include "yolov7/model.h" + +#include +#include +#include + +using namespace nvinfer1; + +static Logger gLogger; +const static int kInputH = 640; +const static int kInputW = 640; +const static int kOutputSize1 = kMaxNumOutputBbox * sizeof(Detection) / sizeof(float) + 1; +const static int kOutputSize2 = 32 * (kInputH / 4) * (kInputW / 4); + +bool parse_args(int argc, char** argv, std::string& wts, std::string& engine, float& gd, float& gw, std::string& img_dir, std::string& labels_filename, int& n_classes) { + if (argc < 4) return false; + if (std::string(argv[1]) == "-s" && (argc == 6 || argc == 8)) { + wts = std::string(argv[2]); + engine = std::string(argv[3]); + n_classes = atoi(argv[4]); + if (n_classes < 1) + return false; + auto net = std::string(argv[5]); + if (net[0] == 'n') { + gd = 0.33; + gw = 0.25; + } else if (net[0] == 's') { + gd = 0.33; + gw = 0.50; + } else if (net[0] == 'm') { + gd = 0.67; + gw = 0.75; + } else if (net[0] == 'l') { + gd = 1.0; + gw = 1.0; + } else if (net[0] == 'x') { + gd = 1.33; + gw = 1.25; + } else if (net[0] == 'c' && argc == 8) { + gd = atof(argv[6]); + gw = atof(argv[7]); + } else { + return false; + } + } else if (std::string(argv[1]) == "-d" && argc == 5) { + engine = std::string(argv[2]); + img_dir = std::string(argv[3]); + labels_filename = std::string(argv[4]); + } else { + return false; + } + return true; +} + + +void serialize_engine(unsigned int max_batchsize, float& gd, float& gw, std::string& wts_name, std::string& engine_name, int n_classes) { + // Create builder + IBuilder* builder = createInferBuilder(gLogger); + IBuilderConfig* config = builder->createBuilderConfig(); + + // Create model to populate the network, then set the outputs and create an engine + ICudaEngine *engine = nullptr; + + engine = build_seg_engine(max_batchsize, builder, config, DataType::kFLOAT, gd, gw, wts_name, kInputH, kInputW, n_classes); + + assert(engine != nullptr); + + // Serialize the engine + IHostMemory* serialized_engine = engine->serialize(); + assert(serialized_engine != nullptr); + + // Save engine to file + std::ofstream p(engine_name, std::ios::binary); + if (!p) { + std::cerr << "Could not open plan output file" << std::endl; + assert(false); + } + p.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); + + // Close everything down + engine->destroy(); + builder->destroy(); + config->destroy(); + serialized_engine->destroy(); +} + +int main(int argc, char** argv) { + cudaSetDevice(kGpuId); + + std::string wts_name = ""; + std::string engine_name = ""; + std::string labels_filename = ""; + float gd = 0.0f, gw = 0.0f; + int 
n_classes; + + std::string img_dir; + if (!parse_args(argc, argv, wts_name, engine_name, gd, gw, img_dir, labels_filename, n_classes)) { + std::cerr << "arguments not right!" << std::endl; + std::cerr << "./SpireCVSeg -s [.wts] [.engine] n_classes [n/s/m/l/x or c gd gw] // serialize model to plan file" << std::endl; + // std::cerr << "./SpireCVSeg -d [.engine] ../images coco.txt // deserialize plan file, read the labels file and run inference" << std::endl; + return -1; + } + std::cout << "n_classes: " << n_classes << std::endl; + + // Create a model using the API directly and serialize it to a file + if (!wts_name.empty()) { + serialize_engine(kBatchSize, gd, gw, wts_name, engine_name, n_classes); + return 0; + } + + + return 0; +} + diff --git a/samples/calib/aruco_samples_utility.hpp b/samples/calib/aruco_samples_utility.hpp new file mode 100644 index 0000000..c1cfe62 --- /dev/null +++ b/samples/calib/aruco_samples_utility.hpp @@ -0,0 +1,48 @@ +#include +#include +#include +#include + +namespace { +inline static bool readCameraParameters(std::string filename, cv::Mat &camMatrix, cv::Mat &distCoeffs) { + cv::FileStorage fs(filename, cv::FileStorage::READ); + if (!fs.isOpened()) + return false; + fs["camera_matrix"] >> camMatrix; + fs["distortion_coefficients"] >> distCoeffs; + return true; +} + +inline static bool saveCameraParams(const std::string &filename, cv::Size imageSize, float aspectRatio, int flags, + const cv::Mat &cameraMatrix, const cv::Mat &distCoeffs, double totalAvgErr) { + cv::FileStorage fs(filename, cv::FileStorage::WRITE); + if (!fs.isOpened()) + return false; + + time_t tt; + time(&tt); + struct tm *t2 = localtime(&tt); + char buf[1024]; + strftime(buf, sizeof(buf) - 1, "%c", t2); + + fs << "calibration_time" << buf; + fs << "image_width" << imageSize.width; + fs << "image_height" << imageSize.height; + + if (flags & cv::CALIB_FIX_ASPECT_RATIO) fs << "aspectRatio" << aspectRatio; + + if (flags != 0) { + sprintf(buf, "flags: %s%s%s%s", + flags & cv::CALIB_USE_INTRINSIC_GUESS ? "+use_intrinsic_guess" : "", + flags & cv::CALIB_FIX_ASPECT_RATIO ? "+fix_aspectRatio" : "", + flags & cv::CALIB_FIX_PRINCIPAL_POINT ? "+fix_principal_point" : "", + flags & cv::CALIB_ZERO_TANGENT_DIST ? "+zero_tangent_dist" : ""); + } + fs << "flags" << flags; + fs << "camera_matrix" << cameraMatrix; + fs << "distortion_coefficients" << distCoeffs; + fs << "avg_reprojection_error" << totalAvgErr; + return true; +} + +} diff --git a/samples/calib/calibrate_camera_charuco.cpp b/samples/calib/calibrate_camera_charuco.cpp new file mode 100644 index 0000000..7337dd2 --- /dev/null +++ b/samples/calib/calibrate_camera_charuco.cpp @@ -0,0 +1,293 @@ +/* +By downloading, copying, installing or using the software you agree to this +license. If you do not agree to this license, do not download, install, +copy or use the software. + License Agreement + For Open Source Computer Vision Library + (3-clause BSD License) +Copyright (C) 2013, OpenCV Foundation, all rights reserved. +Third party copyrights are property of their respective owners. +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. 
+ * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + * Neither the names of the copyright holders nor the names of the contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. +This software is provided by the copyright holders and contributors "as is" and +any express or implied warranties, including, but not limited to, the implied +warranties of merchantability and fitness for a particular purpose are +disclaimed. In no event shall copyright holders or contributors be liable for +any direct, indirect, incidental, special, exemplary, or consequential damages +(including, but not limited to, procurement of substitute goods or services; +loss of use, data, or profits; or business interruption) however caused +and on any theory of liability, whether in contract, strict liability, +or tort (including negligence or otherwise) arising in any way out of +the use of this software, even if advised of the possibility of such damage. +*/ + + +#include +#include +#include +#include +#include +#include +#include "aruco_samples_utility.hpp" + +using namespace std; +using namespace cv; + +namespace { +const char* about = + "Calibration using a ChArUco board\n" + " To capture a frame for calibration, press 'c',\n" + " If input comes from video, press any key for next frame\n" + " To finish capturing, press 'ESC' key and calibration starts.\n"; +const char* keys = + "{w | | Number of squares in X direction }" + "{h | | Number of squares in Y direction }" + "{sl | | Square side length (in meters) }" + "{ml | | Marker side length (in meters) }" + "{d | | dictionary: DICT_4X4_50=0, DICT_4X4_100=1, DICT_4X4_250=2," + "DICT_4X4_1000=3, DICT_5X5_50=4, DICT_5X5_100=5, DICT_5X5_250=6, DICT_5X5_1000=7, " + "DICT_6X6_50=8, DICT_6X6_100=9, DICT_6X6_250=10, DICT_6X6_1000=11, DICT_7X7_50=12," + "DICT_7X7_100=13, DICT_7X7_250=14, DICT_7X7_1000=15, DICT_ARUCO_ORIGINAL = 16}" + "{cd | | Input file with custom dictionary }" + "{@outfile | | Output file with calibrated camera parameters }" + "{v | | Input from video file, if ommited, input comes from camera }" + "{ci | 0 | Camera id if input doesnt come from video (-v) }" + "{dp | | File of marker detector parameters }" + "{rs | false | Apply refind strategy }" + "{zt | false | Assume zero tangential distortion }" + "{a | | Fix aspect ratio (fx/fy) to this value }" + "{pc | false | Fix the principal point at the center }" + "{sc | false | Show detected chessboard corners after calibration }"; +} + + +int main(int argc, char *argv[]) { + CommandLineParser parser(argc, argv, keys); + parser.about(about); + + if(argc < 7) { + parser.printMessage(); + return 0; + } + + int squaresX = parser.get("w"); + int squaresY = parser.get("h"); + float squareLength = parser.get("sl"); + float markerLength = parser.get("ml"); + string outputFile = parser.get(0); + + bool showChessboardCorners = parser.get("sc"); + + int calibrationFlags = 0; + float aspectRatio = 1; + if(parser.has("a")) { + calibrationFlags |= CALIB_FIX_ASPECT_RATIO; + aspectRatio = parser.get("a"); + } + if(parser.get("zt")) calibrationFlags |= CALIB_ZERO_TANGENT_DIST; + if(parser.get("pc")) calibrationFlags |= CALIB_FIX_PRINCIPAL_POINT; + + Ptr detectorParams = makePtr(); + if(parser.has("dp")) { + FileStorage fs(parser.get("dp"), FileStorage::READ); + bool readOk = 
detectorParams->readDetectorParameters(fs.root()); + if(!readOk) { + cerr << "Invalid detector parameters file" << endl; + return 0; + } + } + + bool refindStrategy = parser.get("rs"); + int camId = parser.get("ci"); + String video; + + if(parser.has("v")) { + video = parser.get("v"); + } + + if(!parser.check()) { + parser.printErrors(); + return 0; + } + + VideoCapture inputVideo; + int waitTime; + if(!video.empty()) { + inputVideo.open(video); + waitTime = 0; + } else { + inputVideo.open(camId); + waitTime = 10; + } + + aruco::Dictionary dictionary = aruco::getPredefinedDictionary(0); + if (parser.has("d")) { + int dictionaryId = parser.get("d"); + dictionary = aruco::getPredefinedDictionary(aruco::PredefinedDictionaryType(dictionaryId)); + } + else if (parser.has("cd")) { + FileStorage fs(parser.get("cd"), FileStorage::READ); + bool readOk = dictionary.aruco::Dictionary::readDictionary(fs.root()); + if(!readOk) { + cerr << "Invalid dictionary file" << endl; + return 0; + } + } + else { + cerr << "Dictionary not specified" << endl; + return 0; + } + + // create charuco board object + Ptr charucoboard = new aruco::CharucoBoard(Size(squaresX, squaresY), squareLength, markerLength, dictionary); + Ptr board = charucoboard.staticCast(); + + // collect data from each frame + vector< vector< vector< Point2f > > > allCorners; + vector< vector< int > > allIds; + vector< Mat > allImgs; + Size imgSize; + + while(inputVideo.grab()) { + Mat image, imageCopy; + inputVideo.retrieve(image); + + vector< int > ids; + vector< vector< Point2f > > corners, rejected; + + // detect markers + aruco::detectMarkers(image, makePtr(dictionary), corners, ids, detectorParams, rejected); + + // refind strategy to detect more markers + if(refindStrategy) aruco::refineDetectedMarkers(image, board, corners, ids, rejected); + + // interpolate charuco corners + Mat currentCharucoCorners, currentCharucoIds; + if(ids.size() > 0) + aruco::interpolateCornersCharuco(corners, ids, image, charucoboard, currentCharucoCorners, + currentCharucoIds); + + // draw results + image.copyTo(imageCopy); + if(ids.size() > 0) aruco::drawDetectedMarkers(imageCopy, corners); + + if(currentCharucoCorners.total() > 0) + aruco::drawDetectedCornersCharuco(imageCopy, currentCharucoCorners, currentCharucoIds); + + putText(imageCopy, "Press 'c' to add current frame. 
'ESC' to finish and calibrate", + Point(10, 20), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(255, 0, 0), 2); + + imshow("out", imageCopy); + char key = (char)waitKey(waitTime); + if(key == 27) break; + if(key == 'c' && ids.size() > 0) { + cout << "Frame captured" << endl; + allCorners.push_back(corners); + allIds.push_back(ids); + allImgs.push_back(image); + imgSize = image.size(); + } + } + + if(allIds.size() < 1) { + cerr << "Not enough captures for calibration" << endl; + return 0; + } + + Mat cameraMatrix, distCoeffs; + vector< Mat > rvecs, tvecs; + double repError; + + if(calibrationFlags & CALIB_FIX_ASPECT_RATIO) { + cameraMatrix = Mat::eye(3, 3, CV_64F); + cameraMatrix.at< double >(0, 0) = aspectRatio; + } + + // prepare data for calibration + vector< vector< Point2f > > allCornersConcatenated; + vector< int > allIdsConcatenated; + vector< int > markerCounterPerFrame; + markerCounterPerFrame.reserve(allCorners.size()); + for(unsigned int i = 0; i < allCorners.size(); i++) { + markerCounterPerFrame.push_back((int)allCorners[i].size()); + for(unsigned int j = 0; j < allCorners[i].size(); j++) { + allCornersConcatenated.push_back(allCorners[i][j]); + allIdsConcatenated.push_back(allIds[i][j]); + } + } + + // calibrate camera using aruco markers + double arucoRepErr; + arucoRepErr = aruco::calibrateCameraAruco(allCornersConcatenated, allIdsConcatenated, + markerCounterPerFrame, board, imgSize, cameraMatrix, + distCoeffs, noArray(), noArray(), calibrationFlags); + + // prepare data for charuco calibration + int nFrames = (int)allCorners.size(); + vector< Mat > allCharucoCorners; + vector< Mat > allCharucoIds; + vector< Mat > filteredImages; + allCharucoCorners.reserve(nFrames); + allCharucoIds.reserve(nFrames); + + for(int i = 0; i < nFrames; i++) { + // interpolate using camera parameters + Mat currentCharucoCorners, currentCharucoIds; + aruco::interpolateCornersCharuco(allCorners[i], allIds[i], allImgs[i], charucoboard, + currentCharucoCorners, currentCharucoIds, cameraMatrix, + distCoeffs); + + allCharucoCorners.push_back(currentCharucoCorners); + allCharucoIds.push_back(currentCharucoIds); + filteredImages.push_back(allImgs[i]); + } + + if(allCharucoCorners.size() < 4) { + cerr << "Not enough corners for calibration" << endl; + return 0; + } + + // calibrate camera using charuco + repError = + aruco::calibrateCameraCharuco(allCharucoCorners, allCharucoIds, charucoboard, imgSize, + cameraMatrix, distCoeffs, rvecs, tvecs, calibrationFlags); + + bool saveOk = saveCameraParams(outputFile, imgSize, aspectRatio, calibrationFlags, + cameraMatrix, distCoeffs, repError); + if(!saveOk) { + cerr << "Cannot save output file" << endl; + return 0; + } + + cout << "Rep Error: " << repError << endl; + cout << "Rep Error Aruco: " << arucoRepErr << endl; + cout << "Calibration saved to " << outputFile << endl; + + // show interpolated charuco corners for debugging + if(showChessboardCorners) { + for(unsigned int frame = 0; frame < filteredImages.size(); frame++) { + Mat imageCopy = filteredImages[frame].clone(); + if(allIds[frame].size() > 0) { + + if(allCharucoCorners[frame].total() > 0) { + aruco::drawDetectedCornersCharuco( imageCopy, allCharucoCorners[frame], + allCharucoIds[frame]); + } + } + + imshow("out", imageCopy); + char key = (char)waitKey(0); + if(key == 27) break; + } + } + + return 0; +} diff --git a/samples/demo/aruco_detection.cpp b/samples/demo/aruco_detection.cpp new file mode 100644 index 0000000..0229400 --- /dev/null +++ b/samples/demo/aruco_detection.cpp @@ -0,0 +1,74 @@ +#include 
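The calibrate_camera_charuco tool that ends just above writes camera_matrix and distortion_coefficients into a YAML file through saveCameraParams, and the demo programs that follow load parameter files of the same shape (for example calib_webcam_640x480.yaml). A minimal sketch of reading such a file back with the readCameraParameters helper from aruco_samples_utility.hpp; this is illustrative only and not part of the patch, and the file name is just an assumed example of a previous calibration run:

    #include <iostream>
    #include <opencv2/opencv.hpp>
    #include "aruco_samples_utility.hpp"

    int main() {
        cv::Mat camMatrix, distCoeffs;
        // Assumed output file of an earlier calibrate_camera_charuco run.
        if (!readCameraParameters("calib_webcam_640x480.yaml", camMatrix, distCoeffs)) {
            std::cerr << "could not read camera parameters" << std::endl;
            return -1;
        }
        std::cout << "camera_matrix:\n" << camMatrix << std::endl;
        std::cout << "distortion_coefficients:\n" << distCoeffs << std::endl;
        return 0;
    }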
+#include +// 包含SpireCV SDK头文件 +#include + +using namespace std; + +int main(int argc, char *argv[]) { + // 实例化Aruco检测器类 + sv::ArucoDetector ad; + // 手动导入相机参数,如果使用Amov的G1等吊舱或相机,则可以忽略该步骤,将自动下载相机参数文件 + ad.loadCameraParams(sv::get_home() + "/SpireCV/calib_webcam_640x480.yaml"); + + // 打开摄像头 + sv::Camera cap; + // cap.setWH(640, 480); + // cap.setFps(30); + cap.open(sv::CameraType::WEBCAM, 0); // CameraID 0 + // 实例化OpenCV的Mat类,用于内存单帧图像 + cv::Mat img; + int frame_id = 0; + while (1) + { + // 实例化SpireCV的 单帧检测结果 接口类 TargetsInFrame + sv::TargetsInFrame tgts(frame_id++); + // 读取一帧图像到img + cap.read(img); + cv::resize(img, img, cv::Size(ad.image_width, ad.image_height)); + + // 执行Aruco二维码检测 + ad.detect(img, tgts); + // 可视化检测结果,叠加到img上 + sv::drawTargetsInFrame(img, tgts); + + // 控制台打印Aruco检测结果 + printf("Frame-[%d]\n", frame_id); + // 打印当前检测的FPS + printf(" FPS = %.2f\n", tgts.fps); + // 打印当前相机的视场角(degree) + printf(" FOV (fx, fy) = (%.2f, %.2f)\n", tgts.fov_x, tgts.fov_y); + // 打印当前输入图像的像素宽度和高度 + printf(" Frame Size (width, height) = (%d, %d)\n", tgts.width, tgts.height); + for (int i=0; i +#include +// 包含SpireCV SDK头文件 +#include + +using namespace std; + +int main(int argc, char *argv[]) { + // 打开摄像头 + sv::Camera cap; + // cap.setWH(640, 480); + // cap.setFps(30); + cap.open(sv::CameraType::WEBCAM, 0); // CameraID 0 + // 实例化OpenCV的Mat类,用于内存单帧图像 + cv::Mat img; + while (1) + { + // 读取一帧图像到img + cap.read(img); + + // 显示img + cv::imshow("img", img); + cv::waitKey(10); + } + + return 0; +} diff --git a/samples/demo/common_object_detection.cpp b/samples/demo/common_object_detection.cpp new file mode 100644 index 0000000..f051756 --- /dev/null +++ b/samples/demo/common_object_detection.cpp @@ -0,0 +1,72 @@ +#include +#include +// 包含SpireCV SDK头文件 +#include + +using namespace std; + +int main(int argc, char *argv[]) { + // 实例化 通用目标 检测器类 + sv::CommonObjectDetector cod; + // 手动导入相机参数,如果使用Amov的G1等吊舱或相机,则可以忽略该步骤,将自动下载相机参数文件 + cod.loadCameraParams(sv::get_home() + "/SpireCV/calib_webcam_640x480.yaml"); + + // 打开摄像头 + sv::Camera cap; + // cap.setWH(640, 480); + // cap.setFps(30); + cap.open(sv::CameraType::WEBCAM, 0); // CameraID 0 + // 实例化OpenCV的Mat类,用于内存单帧图像 + cv::Mat img; + int frame_id = 0; + while (1) + { + // 实例化SpireCV的 单帧检测结果 接口类 TargetsInFrame + sv::TargetsInFrame tgts(frame_id++); + // 读取一帧图像到img + cap.read(img); + cv::resize(img, img, cv::Size(cod.image_width, cod.image_height)); + + // 执行通用目标检测 + cod.detect(img, tgts); + // 可视化检测结果,叠加到img上 + sv::drawTargetsInFrame(img, tgts); + + // 控制台打印通用目标检测结果 + printf("Frame-[%d]\n", frame_id); + // 打印当前检测的FPS + printf(" FPS = %.2f\n", tgts.fps); + // 打印当前相机的视场角(degree) + printf(" FOV (fx, fy) = (%.2f, %.2f)\n", tgts.fov_x, tgts.fov_y); + // 打印当前输入图像的像素宽度和高度 + printf(" Frame Size (width, height) = (%d, %d)\n", tgts.width, tgts.height); + for (int i=0; i +#include +// 包含SpireCV SDK头文件 +#include + +using namespace std; + +// 定义窗口名称 +static const std::string RGB_WINDOW = "Image window"; +// 框选到的矩形 +cv::Rect rect_sel; +// 框选起始点 +cv::Point pt_origin; + +// 是否得到一个新的框选区域 +bool b_renew_ROI = false; +// 是否开始跟踪 +bool b_begin_TRACK = false; +// 实现框选逻辑的回调函数 +void onMouse(int event, int x, int y, int, void*); + + +struct node { + double x,y; +}; +node p1,p2,p3,p4; +node p; +double getCross(node p1, node p2, node p) { + return (p2.x-p1.x)*(p.y-p1.y)-(p.x-p1.x)*(p2.y-p1.y); +} +bool b_clicked =false; +bool detect_tracking =true; + + +int main(int argc, char *argv[]) { + // 定义一个新的窗口,可在上面进行框选操作 + cv::namedWindow(RGB_WINDOW); + // 设置窗口操作回调函数,该函数实现整个框选逻辑 + 
cv::setMouseCallback(RGB_WINDOW, onMouse, 0); + // 实例化 框选目标跟踪类 + sv::SingleObjectTracker sot; + // 手动导入相机参数,如果使用Amov的G1等吊舱或相机,则可以忽略该步骤,将自动下载相机参数文件 + sot.loadCameraParams(sv::get_home() + "/SpireCV/calib_webcam_640x480.yaml"); + + + sv::CommonObjectDetector cod; + cod.loadCameraParams(sv::get_home() + "/SpireCV/calib_webcam_640x480.yaml"); + + + // 打开摄像头 + sv::Camera cap; + // cap.setWH(640, 480); + // cap.setFps(30); + cap.open(sv::CameraType::WEBCAM, 0); // CameraID 0 + // cv::VideoCapture cap("/home/amov/SpireCV/test/tracking_1280x720.mp4"); + // 实例化OpenCV的Mat类,用于内存单帧图像 + cv::Mat img; + int frame_id = 0; + while (1) + { + if (detect_tracking == true) { + // 实例化SpireCV的 单帧检测结果 接口类 TargetsInFrame + sv::TargetsInFrame tgts(frame_id++); + // 读取一帧图像到img + cap.read(img); + cv::resize(img, img, cv::Size(cod.image_width, cod.image_height)); + + // 执行通用目标检测 + cod.detect(img, tgts); + + // 可视化检测结果,叠加到img上 + sv::drawTargetsInFrame(img, tgts); + + // 控制台打印通用目标检测结果 + printf("Frame-[%d]\n", frame_id); + // 打印当前检测的FPS + printf(" FPS = %.2f\n", tgts.fps); + // 打印当前相机的视场角(degree) + printf(" FOV (fx, fy) = (%.2f, %.2f)\n", tgts.fov_x, tgts.fov_y); + for (int i=0; i= 0 && getCross(p2, p3, p) * getCross(p4, p1, p) >= 0) { + b_begin_TRACK = false; + detect_tracking = false; + // pt_origin = cv::Point(nor_x, nor_p_y); + // std::cout << "pt_origin " < 0) + { + printf("Frame-[%d]\n", frame_id); + // 打印 跟踪目标 的中心位置,cx,cy的值域为[0, 1] + printf(" Tracking Center (cx, cy) = (%.3f, %.3f)\n", tgts.targets[0].cx, tgts.targets[0].cy); + // 打印 跟踪目标 的外接矩形框的宽度、高度,w,h的值域为(0, 1] + printf(" Tracking Size (w, h) = (%.3f, %.3f)\n", tgts.targets[0].w, tgts.targets[0].h); + // 打印 跟踪目标 的视线角,跟相机视场相关 + printf(" Tracking Line-of-sight (ax, ay) = (%.3f, %.3f)\n", tgts.targets[0].los_ax, tgts.targets[0].los_ay); + } + } + }//end of tracking + // 显示检测结果img + cv::imshow(RGB_WINDOW, img); + cv::waitKey(10); + } + return 0; +} + +void onMouse(int event, int x, int y, int, void*) +{ + if (b_clicked) + { + // 更新框选区域坐标 + pt_origin.x = 0; + pt_origin.y = 0; + } + // 左键按下 + if (event == cv::EVENT_LBUTTONDOWN) + { + detect_tracking = true; + pt_origin = cv::Point(x, y); + } + + else if (event == cv::EVENT_RBUTTONDOWN) + { + detect_tracking = true; + b_renew_ROI = false; + b_begin_TRACK = false; + b_clicked = true; + } +} diff --git a/samples/demo/ellipse_detection.cpp b/samples/demo/ellipse_detection.cpp new file mode 100644 index 0000000..596bf2f --- /dev/null +++ b/samples/demo/ellipse_detection.cpp @@ -0,0 +1,70 @@ +#include +#include +// 包含SpireCV SDK头文件 +#include + +using namespace std; + +int main(int argc, char *argv[]) { + // 实例化 椭圆 检测器类 + sv::EllipseDetector ed; + // 手动导入相机参数,如果使用Amov的G1等吊舱或相机,则可以忽略该步骤,将自动下载相机参数文件 + ed.loadCameraParams(sv::get_home() + "/SpireCV/calib_webcam_640x480.yaml"); + + // 打开摄像头 + sv::Camera cap; + // cap.setWH(640, 480); + // cap.setFps(30); + cap.open(sv::CameraType::WEBCAM, 0); // CameraID 0 + // 实例化OpenCV的Mat类,用于内存单帧图像 + cv::Mat img; + int frame_id = 0; + while (1) + { + // 实例化SpireCV的 单帧检测结果 接口类 TargetsInFrame + sv::TargetsInFrame tgts(frame_id++); + // 读取一帧图像到img + cap.read(img); + cv::resize(img, img, cv::Size(ed.image_width, ed.image_height)); + + // 执行 椭圆 检测 + ed.detect(img, tgts); + // 可视化检测结果,叠加到img上 + sv::drawTargetsInFrame(img, tgts); + + // 控制台打印 椭圆 检测结果 + printf("Frame-[%d]\n", frame_id); + // 打印当前检测的FPS + printf(" FPS = %.2f\n", tgts.fps); + // 打印当前相机的视场角(degree) + printf(" FOV (fx, fy) = (%.2f, %.2f)\n", tgts.fov_x, tgts.fov_y); + // 打印当前输入图像的像素宽度和高度 + printf(" Frame Size (width, height) = (%d, 
%d)\n", tgts.width, tgts.height); + for (int i=0; i +#include +// 包含SpireCV SDK头文件 +#include + +using namespace std; + +int main(int argc, char *argv[]) { + // 实例化 圆形降落标志 检测器类 + sv::LandingMarkerDetector lmd; + // 手动导入相机参数,如果使用Amov的G1等吊舱或相机,则可以忽略该步骤,将自动下载相机参数文件 + lmd.loadCameraParams(sv::get_home() + "/SpireCV/calib_webcam_640x480.yaml"); + + // 打开摄像头 + sv::Camera cap; + // cap.setWH(640, 480); + // cap.setFps(30); + cap.open(sv::CameraType::WEBCAM, 0); // CameraID 0 + // 实例化OpenCV的Mat类,用于内存单帧图像 + cv::Mat img; + int frame_id = 0; + while (1) + { + // 实例化SpireCV的 单帧检测结果 接口类 TargetsInFrame + sv::TargetsInFrame tgts(frame_id++); + // 读取一帧图像到img + cap.read(img); + cv::resize(img, img, cv::Size(lmd.image_width, lmd.image_height)); + + // 执行 降落标志 检测 + lmd.detect(img, tgts); + // 可视化检测结果,叠加到img上 + sv::drawTargetsInFrame(img, tgts); + + // 控制台打印 降落标志 检测结果 + printf("Frame-[%d]\n", frame_id); + // 打印当前检测的FPS + printf(" FPS = %.2f\n", tgts.fps); + // 打印当前相机的视场角(degree) + printf(" FOV (fx, fy) = (%.2f, %.2f)\n", tgts.fov_x, tgts.fov_y); + // 打印当前输入图像的像素宽度和高度 + printf(" Frame Size (width, height) = (%d, %d)\n", tgts.width, tgts.height); + for (int i=0; i +#include +// 包含SpireCV SDK头文件 +#include + +using namespace std; + +// 定义窗口名称 +static const std::string RGB_WINDOW = "Image window"; +// 框选到的矩形 +cv::Rect rect_sel; +// 框选起始点 +cv::Point pt_origin; +// 是否按下左键 +bool b_clicked = false; +// 是否得到一个新的框选区域 +bool b_renew_ROI = false; +// 是否开始跟踪 +bool b_begin_TRACK = false; +// 实现框选逻辑的回调函数 +void onMouse(int event, int x, int y, int, void*); + +int main(int argc, char *argv[]) { + // 定义一个新的窗口,可在上面进行框选操作 + cv::namedWindow(RGB_WINDOW); + // 设置窗口操作回调函数,该函数实现整个框选逻辑 + cv::setMouseCallback(RGB_WINDOW, onMouse, 0); + // 实例化 框选目标跟踪类 + sv::SingleObjectTracker sot; + // 手动导入相机参数,如果使用Amov的G1等吊舱或相机,则可以忽略该步骤,将自动下载相机参数文件 + sot.loadCameraParams(sv::get_home() + "/SpireCV/calib_webcam_640x480.yaml"); + // sot.loadCameraParams(sv::get_home() + "/SpireCV/calib_webcam_1280x720.yaml"); + // sot.loadCameraParams(sv::get_home() + "/SpireCV/calib_webcam_1920x1080.yaml"); + + // 打开摄像头 + sv::Camera cap; + // cap.setWH(640, 480); + // cap.setFps(30); + cap.open(sv::CameraType::WEBCAM, 0); // CameraID 0 + // cv::VideoCapture cap("/home/amov/SpireCV/test/tracking_1280x720.mp4"); + // 实例化OpenCV的Mat类,用于内存单帧图像 + cv::Mat img; + int frame_id = 0; + while (1) + { + // 实例化SpireCV的 单帧检测结果 接口类 TargetsInFrame + sv::TargetsInFrame tgts(frame_id++); + // 读取一帧图像到img + cap.read(img); + cv::resize(img, img, cv::Size(sot.image_width, sot.image_height)); + + // 开始 单目标跟踪 逻辑 + // 是否有新的目标被手动框选 + if (b_renew_ROI) + { + // 拿新的框选区域 来 初始化跟踪器 + sot.init(img, rect_sel); + // std::cout << rect_sel << std::endl; + // 重置框选标志 + b_renew_ROI = false; + // 开始跟踪 + b_begin_TRACK = true; + } + else if (b_begin_TRACK) + { + // 以前一帧的结果继续跟踪 + sot.track(img, tgts); + + // 可视化检测结果,叠加到img上 + sv::drawTargetsInFrame(img, tgts); + + // 控制台打印 单目标跟踪 结果 + printf("Frame-[%d]\n", frame_id); + // 打印当前检测的FPS + printf(" FPS = %.2f\n", tgts.fps); + // 打印当前相机的视场角(degree) + printf(" FOV (fx, fy) = (%.2f, %.2f)\n", tgts.fov_x, tgts.fov_y); + // 打印当前输入图像的像素宽度和高度 + printf(" Frame Size (width, height) = (%d, %d)\n", tgts.width, tgts.height); + if (tgts.targets.size() > 0) + { + printf("Frame-[%d]\n", frame_id); + // 打印 跟踪目标 的中心位置,cx,cy的值域为[0, 1],以及cx,cy的像素值 + printf(" Tracking Center (cx, cy) = (%.3f, %.3f), in Pixels = ((%d, %d))\n", + tgts.targets[0].cx, tgts.targets[0].cy, + int(tgts.targets[0].cx * tgts.width), + int(tgts.targets[0].cy * tgts.height)); + // 打印 跟踪目标 的外接矩形框的宽度、高度,w,h的值域为(0, 
1],以及w,h的像素值 + printf(" Tracking Size (w, h) = (%.3f, %.3f), in Pixels = ((%d, %d))\n", + tgts.targets[0].w, tgts.targets[0].h, + int(tgts.targets[0].w * tgts.width), + int(tgts.targets[0].h * tgts.height)); + // 打印 跟踪目标 的视线角,跟相机视场相关 + printf(" Tracking Line-of-sight (ax, ay) = (%.3f, %.3f)\n", tgts.targets[0].los_ax, tgts.targets[0].los_ay); + } + + } + + // 显示检测结果img + cv::imshow(RGB_WINDOW, img); + cv::waitKey(10); + } + + return 0; +} + +void onMouse(int event, int x, int y, int, void*) +{ + if (b_clicked) + { + // 更新框选区域坐标 + rect_sel.x = MIN(pt_origin.x, x); + rect_sel.y = MIN(pt_origin.y, y); + rect_sel.width = abs(x - pt_origin.x); + rect_sel.height = abs(y - pt_origin.y); + } + // 左键按下 + if (event == cv::EVENT_LBUTTONDOWN) + { + b_begin_TRACK = false; + b_clicked = true; + pt_origin = cv::Point(x, y); + rect_sel = cv::Rect(x, y, 0, 0); + } + // 左键松开 + else if (event == cv::EVENT_LBUTTONUP) + { + // 框选区域需要大于8x8像素 + if (rect_sel.width * rect_sel.height < 64) + { + ; + } + else + { + b_clicked = false; + b_renew_ROI = true; + } + } +} diff --git a/samples/demo/udp_detection_info_receiver.cpp b/samples/demo/udp_detection_info_receiver.cpp new file mode 100644 index 0000000..c2cef11 --- /dev/null +++ b/samples/demo/udp_detection_info_receiver.cpp @@ -0,0 +1,248 @@ +#include +#include +#include +#include +#include +#define SERV_PORT 20166 + +typedef unsigned char byte; +using namespace std; + + +int main(int argc, char *argv[]) { + sockaddr_in servaddr; + int sockfd; + + + sockfd = socket(AF_INET, SOCK_DGRAM, 0); + bzero(&servaddr, sizeof(servaddr)); + servaddr.sin_family = AF_INET; + servaddr.sin_port = htons(SERV_PORT); + servaddr.sin_addr.s_addr = htonl(INADDR_ANY); + bind(sockfd, (struct sockaddr *)&servaddr, sizeof(servaddr)); + + + int upd_msg_len = 1024 * 6; // max_objects = 100 + byte upd_msg[upd_msg_len]; + int msg_queue_len = 1024 * 1024; // 1M + byte msg_queue[msg_queue_len]; + + int addr_len = sizeof(struct sockaddr_in); + int start_index = 0, end_index = 0; + + while (1) + { + int n = recvfrom(sockfd, upd_msg, upd_msg_len, 0, (struct sockaddr *)&servaddr, reinterpret_cast(&addr_len)); + + if (end_index + n > msg_queue_len) + { + int m = end_index - start_index; + memcpy((void*) &msg_queue[0], (const void*) &msg_queue[start_index], (size_t) m); + start_index = 0; + end_index = m; + } + + memcpy((void*) &msg_queue[end_index], (const void*) upd_msg, (size_t) n); + end_index += n; + +cout << n << ", " << start_index << ", " << end_index << endl; + + // processing + while (start_index < end_index) + { + int i = start_index; + if (i > 0 && msg_queue[i-1] == 0xFA && msg_queue[i] == 0xFC) // frame start + { + cout << "FOUND 0xFAFC" << endl; + i++; + if (end_index - i >= 2) // read length + { + unsigned short* len = reinterpret_cast(&msg_queue[i]); + int ilen = (int) (*len); + cout << "LEN: " << ilen << endl; + if (end_index - i >= ilen + 2 && msg_queue[i+ilen] == 0xFB && msg_queue[i+ilen+1] == 0xFD) + { + cout << "FOUND 0xFAFC & 0xFBFD" << endl; + byte* msg_type = reinterpret_cast(&msg_queue[i+2]); + cout << "Type: " << (int) *msg_type << endl; + unsigned short* year = reinterpret_cast(&msg_queue[i+7]); + byte* month = reinterpret_cast(&msg_queue[i+9]); + byte* day = reinterpret_cast(&msg_queue[i+10]); + byte* hour = reinterpret_cast(&msg_queue[i+11]); + byte* minute = reinterpret_cast(&msg_queue[i+12]); + byte* second = reinterpret_cast(&msg_queue[i+13]); + unsigned short* millisecond = reinterpret_cast(&msg_queue[i+14]); + cout << "Time: " << *year << "-" << (int) *month << "-" << 
(int) *day << " " << (int) *hour << ":" << (int) *minute << ":" << (int) *second << " " << *millisecond << endl; + + byte* index_d1 = reinterpret_cast(&msg_queue[i+16]); + byte* index_d2 = reinterpret_cast(&msg_queue[i+17]); + byte* index_d3 = reinterpret_cast(&msg_queue[i+18]); + byte* index_d4 = reinterpret_cast(&msg_queue[i+19]); + int mp = i+20; + if ((*index_d4) & 0x01 == 0x01) + { + int* frame_id = reinterpret_cast(&msg_queue[mp]); + mp += 4; + cout << "FrameID: " << *frame_id << endl; + } + if ((*index_d4) & 0x02 == 0x02 && (*index_d4) & 0x04 == 0x04) + { + int* width = reinterpret_cast(&msg_queue[mp]); + mp += 4; + int* height = reinterpret_cast(&msg_queue[mp]); + mp += 4; + cout << "FrameSize: (" << *width << ", " << *height << ")" << endl; + } + int n_objects = 0; + if ((*index_d4) & 0x08 == 0x08) + { + n_objects = *reinterpret_cast(&msg_queue[mp]); + mp += 4; + cout << "N_Objects: " << n_objects << endl; + } + if ((*index_d4) & 0x10 == 0x10) + { + float* fps = reinterpret_cast(&msg_queue[mp]); + mp += 4; + cout << "FPS: " << *fps << endl; + } + if ((*index_d4) & 0x20 == 0x20 && (*index_d4) & 0x40 == 0x40) + { + float* fov_x = reinterpret_cast(&msg_queue[mp]); + mp += 4; + float* fov_y = reinterpret_cast(&msg_queue[mp]); + mp += 4; + cout << "FOV: (" << *fov_x << ", " << *fov_y << ")" << endl; + } + if ((*index_d4) & 0x80 == 0x80 && (*index_d3) & 0x01 == 0x01 && (*index_d3) & 0x02 == 0x02) + { + float* pod_patch = reinterpret_cast(&msg_queue[mp]); + mp += 4; + float* pod_roll = reinterpret_cast(&msg_queue[mp]); + mp += 4; + float* pod_yaw = reinterpret_cast(&msg_queue[mp]); + mp += 4; + cout << "POD-Angles: (" << *pod_patch << ", " << *pod_roll << ", " << *pod_yaw << ")" << endl; + } + if ((*index_d3) & 0x04 == 0x04 && (*index_d3) & 0x08 == 0x08 && (*index_d3) & 0x10 == 0x10) + { + float* longitude = reinterpret_cast(&msg_queue[mp]); + mp += 4; + float* latitude = reinterpret_cast(&msg_queue[mp]); + mp += 4; + float* altitude = reinterpret_cast(&msg_queue[mp]); + mp += 4; + cout << "UAV-Position: (" << *longitude << ", " << *latitude << ", " << *altitude << ")" << endl; + } + if ((*index_d3) & 0x20 == 0x20 && (*index_d3) & 0x40 == 0x40 && (*index_d3) & 0x80 == 0x80) + { + float* uav_vx = reinterpret_cast(&msg_queue[mp]); + mp += 4; + float* uav_vy = reinterpret_cast(&msg_queue[mp]); + mp += 4; + float* uav_vz = reinterpret_cast(&msg_queue[mp]); + mp += 4; + cout << "UAV-Speed: (" << *uav_vx << ", " << *uav_vy << ", " << *uav_vz << ")" << endl; + } + if ((*index_d2) & 0x01 == 0x01) + { + float* illumination = reinterpret_cast(&msg_queue[mp]); + mp += 4; + cout << "Illumination: " << *illumination << endl; + } + for (int j=0; j(&msg_queue[mp]); + mp++; + byte* index_f2 = reinterpret_cast(&msg_queue[mp]); + mp++; + byte* index_f3 = reinterpret_cast(&msg_queue[mp]); + mp++; + byte* index_f4 = reinterpret_cast(&msg_queue[mp]); + mp++; + if ((*index_f4) & 0x01 == 0x01 && (*index_f4) & 0x02 == 0x02) + { + float* cx = reinterpret_cast(&msg_queue[mp]); + mp += 4; + float* cy = reinterpret_cast(&msg_queue[mp]); + mp += 4; + cout << " Object-[" << j+1 << "]-CXCY: (" << *cx << ", " << *cy << ")" << endl; + } + if ((*index_f4) & 0x04 == 0x04 && (*index_f4) & 0x08 == 0x08) + { + float* w = reinterpret_cast(&msg_queue[mp]); + mp += 4; + float* h = reinterpret_cast(&msg_queue[mp]); + mp += 4; + cout << " Object-[" << j+1 << "]-WH: (" << *w << ", " << *h << ")" << endl; + } + if ((*index_f4) & 0x10 == 0x10) + { + float* score = reinterpret_cast(&msg_queue[mp]); + mp += 4; + cout << " Object-[" << 
j+1 << "]-Score: " << *score << endl; + } + if ((*index_f4) & 0x20 == 0x20) + { + int* category_id = reinterpret_cast(&msg_queue[mp]); + mp += 4; + cout << " Object-[" << j+1 << "]-CateID: " << *category_id << endl; + } + if ((*index_f4) & 0x40 == 0x40) + { + int* tracked_id = reinterpret_cast(&msg_queue[mp]); + mp += 4; + cout << " Object-[" << j+1 << "]-TrackID: " << *tracked_id << endl; + } + if ((*index_f4) & 0x80 == 0x80 && (*index_f3) & 0x01 == 0x01 && (*index_f3) & 0x02 == 0x02) + { + float* px = reinterpret_cast(&msg_queue[mp]); + mp += 4; + float* py = reinterpret_cast(&msg_queue[mp]); + mp += 4; + float* pz = reinterpret_cast(&msg_queue[mp]); + mp += 4; + cout << " Object-[" << j+1 << "]-Position: (" << *px << ", " << *py << ", " << *pz << ")" << endl; + } + if ((*index_f3) & 0x04 == 0x04 && (*index_f3) & 0x08 == 0x08) + { + float* los_ax = reinterpret_cast(&msg_queue[mp]); + mp += 4; + float* los_ay = reinterpret_cast(&msg_queue[mp]); + mp += 4; + cout << " Object-[" << j+1 << "]-LOS: (" << *los_ax << ", " << *los_ay << ")" << endl; + } + if ((*index_f3) & 0x10 == 0x10) + { + float* yaw_a = reinterpret_cast(&msg_queue[mp]); + mp += 4; + cout << " Object-[" << j+1 << "]-YAW: " << *yaw_a << endl; + } + } + + start_index += ilen + 4; + } + else if (end_index - i < ilen + 2) + { + break; + } + else + { + start_index++; + } + } + else + { + break; + } + } + else + { + start_index++; + } + } + + } + + return 0; +} diff --git a/samples/demo/udp_detection_info_sender.cpp b/samples/demo/udp_detection_info_sender.cpp new file mode 100644 index 0000000..cd25812 --- /dev/null +++ b/samples/demo/udp_detection_info_sender.cpp @@ -0,0 +1,60 @@ +#include +#include +// 包含SpireCV SDK头文件 +#include + +using namespace std; + +int main(int argc, char *argv[]) { + // 实例化Aruco检测器类 + sv::ArucoDetector ad; + // 手动导入相机参数,如果使用Amov的G1等吊舱或相机,则可以忽略该步骤,将自动下载相机参数文件 + ad.loadCameraParams("/home/amov/SpireCV/calib_webcam_640x480.yaml"); + + // 打开摄像头 + sv::Camera cap; + cap.setWH(640, 480); + cap.setFps(30); + cap.open(sv::CameraType::WEBCAM, 0); // CameraID 0 + + sv::UDPServer udp; + // 实例化OpenCV的Mat类,用于内存单帧图像 + cv::Mat img; + int frame_id = 0; + while (1) + { + // 实例化SpireCV的 单帧检测结果 接口类 TargetsInFrame + sv::TargetsInFrame tgts(frame_id++); + // 读取一帧图像到img + cap.read(img); + + // 执行Aruco二维码检测 + ad.detect(img, tgts); + + tgts.has_pod_info = true; + tgts.pod_patch = 1; + tgts.pod_roll = 2; + tgts.pod_yaw = 3; + tgts.has_uav_pos = true; + tgts.longitude = 1.1234567; + tgts.latitude = 2.2345678; + tgts.altitude = 3.3456789; + tgts.has_uav_vel = true; + tgts.uav_vx = 4; + tgts.uav_vy = 5; + tgts.uav_vz = 6; + tgts.has_ill = true; + tgts.illumination = 7; + + // www.write(img, tgts); + udp.send(tgts); + // 可视化检测结果,叠加到img上 + sv::drawTargetsInFrame(img, tgts); + + // 显示检测结果img + cv::imshow("img", img); + cv::waitKey(10); + } + + return 0; +} diff --git a/samples/demo/video_saving.cpp b/samples/demo/video_saving.cpp new file mode 100644 index 0000000..d4f7414 --- /dev/null +++ b/samples/demo/video_saving.cpp @@ -0,0 +1,51 @@ +#include +#include +// 包含SpireCV SDK头文件 +#include + +using namespace std; + +int main(int argc, char *argv[]) { + // 实例化 通用目标 检测器类 + sv::CommonObjectDetector cod; + // 手动导入相机参数,如果使用Amov的G1等吊舱或相机,则可以忽略该步骤,将自动下载相机参数文件 + cod.loadCameraParams("/home/amov/SpireCV/calib_webcam_640x480.yaml"); + + // 打开摄像头 + sv::Camera cap; + // cap.setWH(640, 480); + // cap.setFps(30); + cap.open(sv::CameraType::WEBCAM, 0); // CameraID 0 + // 实例化OpenCV的Mat类,用于内存单帧图像 + cv::Mat img; + int frame_id = 0; + + // 实例化视频保存类 + 
sv::VideoWriter vw; + // 设置保存路径"/home/amov/Videos",保存图像尺寸(640,480),帧频25Hz,同步保存检测结果(.svj) + vw.setup("/home/amov/Videos", cv::Size(640, 480), 25, true); + + while (1) + { + // 实例化SpireCV的 单帧检测结果 接口类 TargetsInFrame + sv::TargetsInFrame tgts(frame_id++); + // 读取一帧图像到img + cap.read(img); + cv::resize(img, img, cv::Size(640, 480)); + + // 执行通用目标检测 + cod.detect(img, tgts); + + // 同步保存视频流 和 检测结果信息 + vw.write(img, tgts); + + // 可视化检测结果,叠加到img上 + sv::drawTargetsInFrame(img, tgts); + + // 显示检测结果img + cv::imshow("img", img); + cv::waitKey(10); + } + + return 0; +} diff --git a/samples/demo/video_streaming.cpp b/samples/demo/video_streaming.cpp new file mode 100644 index 0000000..406b5b2 --- /dev/null +++ b/samples/demo/video_streaming.cpp @@ -0,0 +1,35 @@ +#include +#include +// 包含SpireCV SDK头文件 +#include + +using namespace std; + +int main(int argc, char *argv[]) { + // 打开摄像头 + sv::Camera cap; + // cap.setWH(1280, 720); + // cap.setFps(30); + cap.open(sv::CameraType::WEBCAM, 0); // CameraID 0 + + // 实例化视频推流类sv::VideoStreamer + sv::VideoStreamer streamer; + // 初始化 推流分辨率(640, 480),端口号8554,比特率1Mb + streamer.setup(cv::Size(1280, 720), 8554, 1); + // 实例化OpenCV的Mat类,用于内存单帧图像 + cv::Mat img; + while (1) + { + // 读取一帧图像到img + cap.read(img); + cv::resize(img, img, cv::Size(1280, 720)); + // 将img推流到 地址:rtsp://ip:8554/live + streamer.stream(img); + + // 显示检测结果img + cv::imshow("img", img); + cv::waitKey(10); + } + + return 0; +} diff --git a/scripts/common/download_test_videos.sh b/scripts/common/download_test_videos.sh new file mode 100644 index 0000000..f7498f0 --- /dev/null +++ b/scripts/common/download_test_videos.sh @@ -0,0 +1,28 @@ +#!/bin/sh + +wget https://download.amovlab.com/model/install/benchmark/aruco_1280x720.mp4 -P ${HOME}/SpireCV/test +wget https://download.amovlab.com/model/install/benchmark/aruco_640x480.mp4 -P ${HOME}/SpireCV/test +wget https://download.amovlab.com/model/install/benchmark/car_1280x720.mp4 -P ${HOME}/SpireCV/test +wget https://download.amovlab.com/model/install/benchmark/car_640x480.mp4 -P ${HOME}/SpireCV/test +wget https://download.amovlab.com/model/install/benchmark/drone_1280x720.mp4 -P ${HOME}/SpireCV/test +wget https://download.amovlab.com/model/install/benchmark/drone_640x480.mp4 -P ${HOME}/SpireCV/test +wget https://download.amovlab.com/model/install/benchmark/ellipse_1280x720.mp4 -P ${HOME}/SpireCV/test +wget https://download.amovlab.com/model/install/benchmark/ellipse_640x480.mp4 -P ${HOME}/SpireCV/test +wget https://download.amovlab.com/model/install/benchmark/landing_1280x720.mp4 -P ${HOME}/SpireCV/test +wget https://download.amovlab.com/model/install/benchmark/landing_640x480.mp4 -P ${HOME}/SpireCV/test +wget https://download.amovlab.com/model/install/benchmark/tracking_1280x720.mp4 -P ${HOME}/SpireCV/test +wget https://download.amovlab.com/model/install/benchmark/tracking_640x480.mp4 -P ${HOME}/SpireCV/test + + +wget https://download.amovlab.com/model/install/c-params/calib_webcam_1280x720.yaml -P ${HOME}/SpireCV +wget https://download.amovlab.com/model/install/c-params/calib_webcam_1920x1080.yaml -P ${HOME}/SpireCV + + +wget https://download.amovlab.com/model/install/a-params/sv_algorithm_params_1280_wo_mask.json -P ${HOME}/SpireCV +wget https://download.amovlab.com/model/install/a-params/sv_algorithm_params_640_w_mask.json -P ${HOME}/SpireCV +wget https://download.amovlab.com/model/install/a-params/sv_algorithm_params_640_wo_mask.json -P ${HOME}/SpireCV +wget https://download.amovlab.com/model/install/a-params/sv_algorithm_params_csrt.json -P ${HOME}/SpireCV 
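One note on the udp_detection_info_receiver.cpp demo further above: in C++ the equality operator binds tighter than bitwise AND, so an expression written as (*index_d4) & 0x02 == 0x02 is parsed as (*index_d4) & (0x02 == 0x02), which tests bit 0 of the byte no matter which mask is written. The masks other than 0x01 therefore do not test the bit they presumably intend to. A tiny self-contained sketch of the difference, using a hypothetical flags byte rather than the receiver's own variables:

    #include <cstdint>
    #include <cstdio>

    int main() {
        std::uint8_t flags = 0x06;            // hypothetical index byte: bits 1 and 2 set, bit 0 clear
        // Misleading form: '==' binds tighter than '&', so this tests bit 0, not 0x02.
        bool loose = flags & 0x02 == 0x02;    // evaluates as flags & 1, i.e. false here
        // Intended form: parenthesize the masking before comparing.
        bool masked = (flags & 0x02) == 0x02; // true here
        std::printf("loose=%d masked=%d\n", loose, masked);
        return 0;
    }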
+wget https://download.amovlab.com/model/install/a-params/sv_algorithm_params_kcf.json -P ${HOME}/SpireCV +wget https://download.amovlab.com/model/install/a-params/sv_algorithm_params_siamrpn.json -P ${HOME}/SpireCV +wget https://download.amovlab.com/model/install/a-params/sv_algorithm_params_nano.json -P ${HOME}/SpireCV + diff --git a/scripts/common/ffmpeg425-install.sh b/scripts/common/ffmpeg425-install.sh new file mode 100644 index 0000000..acc16db --- /dev/null +++ b/scripts/common/ffmpeg425-install.sh @@ -0,0 +1,63 @@ +#!/bin/sh + + +sudo apt install -y \ +build-essential yasm cmake libtool libc6 libc6-dev unzip wget libfmt-dev \ +libnuma1 libnuma-dev libx264-dev libx265-dev libfaac-dev libssl-dev + +root_dir=${HOME}"/SpireCV" +if [ ! -d ${root_dir} ]; then + echo -e "\033[32m[INFO]: ${root_dir} not exist, creating it ... \033[0m" + mkdir -p ${root_dir} +fi +cd ${root_dir} + + +git clone https://gitee.com/jario-jin/nv-codec-headers.git +cd nv-codec-headers +git checkout n11.1.5.0 +sudo make install +cd .. + +wget https://ffmpeg.org/releases/ffmpeg-4.2.5.tar.bz2 +tar -xjf ffmpeg-4.2.5.tar.bz2 +cd ffmpeg-4.2.5 +export PATH=$PATH:/usr/local/cuda/bin +sed -i 's#_30#_75#' configure; sed -i 's#_30#_75#' configure +./configure \ +--enable-nonfree \ +--enable-gpl \ +--enable-shared \ +--enable-ffmpeg \ +--enable-ffplay \ +--enable-ffprobe \ +--enable-libx264 \ +--enable-libx265 \ +--enable-cuda-nvcc \ +--enable-nvenc \ +--enable-cuda \ +--enable-cuvid \ +--enable-libnpp \ +--extra-libs="-lpthread -lm" \ +--extra-cflags=-I/usr/local/cuda/include \ +--extra-ldflags=-L/usr/local/cuda/lib64 +make -j8 +sudo make install +cd .. + +git clone https://gitee.com/jario-jin/ZLMediaKit.git +cd ZLMediaKit +git submodule update --init +mkdir build +cd build +cmake .. +make -j4 +cd .. +cd .. + +mkdir ZLM +cd ZLM +cp ../ZLMediaKit/release/linux/Debug/MediaServer . +cp ../ZLMediaKit/release/linux/Debug/config.ini . +cd .. + diff --git a/scripts/common/gst-install-orin.sh b/scripts/common/gst-install-orin.sh new file mode 100644 index 0000000..90510f5 --- /dev/null +++ b/scripts/common/gst-install-orin.sh @@ -0,0 +1,23 @@ +#!/bin/sh + +sudo apt install -y libgstreamer1.0-dev libgstreamer-plugins-base1.0-dev +sudo apt install -y libgstreamer-plugins-bad1.0-dev gstreamer1.0-plugins-base +sudo apt install -y gstreamer1.0-plugins-good gstreamer1.0-plugins-bad +sudo apt install -y gstreamer1.0-plugins-ugly gstreamer1.0-libav gstreamer1.0-doc +sudo apt install -y gstreamer1.0-tools gstreamer1.0-x gstreamer1.0-alsa +sudo apt install -y gstreamer1.0-gl gstreamer1.0-gtk3 gstreamer1.0-qt5 +sudo apt install -y gstreamer1.0-pulseaudio +sudo apt install -y gtk-doc-tools + +sudo apt -y install autotools-dev automake m4 perl +sudo apt -y install libtool +autoreconf -ivf + +git clone https://gitee.com/jario-jin/gst-rtsp-server-b18.git +cd gst-rtsp-server-b18 +./autogen.sh +make +sudo make install +cd .. 
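The gst-rtsp-server library built by the script above is presumably what the GStreamer-based VideoStreamer implementation declared in sv_video_output.h relies on, and the video_streaming.cpp demo pushes frames to rtsp://ip:8554/live. A quick client-side check of that stream can be done with plain OpenCV; a minimal sketch, assuming the stream is reachable on localhost and that OpenCV was built with FFmpeg (or GStreamer) support (illustrative only, not part of this patch):

    #include <opencv2/opencv.hpp>
    #include <iostream>

    int main() {
        // URL and port follow the defaults used by video_streaming.cpp; adjust as needed.
        cv::VideoCapture cap("rtsp://127.0.0.1:8554/live", cv::CAP_FFMPEG);
        if (!cap.isOpened()) {
            std::cerr << "could not open RTSP stream" << std::endl;
            return -1;
        }
        cv::Mat frame;
        while (cap.read(frame)) {
            cv::imshow("rtsp_client", frame);
            if (cv::waitKey(10) == 27) break;  // ESC to quit
        }
        return 0;
    }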
+sudo rm -r gst-rtsp-server-b18 + diff --git a/scripts/common/gst-install.sh b/scripts/common/gst-install.sh new file mode 100644 index 0000000..f629213 --- /dev/null +++ b/scripts/common/gst-install.sh @@ -0,0 +1,19 @@ +#!/bin/sh + +sudo apt install -y libgstreamer1.0-dev libgstreamer-plugins-base1.0-dev +sudo apt install -y libgstreamer-plugins-bad1.0-dev gstreamer1.0-plugins-base +sudo apt install -y gstreamer1.0-plugins-good gstreamer1.0-plugins-bad +sudo apt install -y gstreamer1.0-plugins-ugly gstreamer1.0-libav gstreamer1.0-doc +sudo apt install -y gstreamer1.0-tools gstreamer1.0-x gstreamer1.0-alsa +sudo apt install -y gstreamer1.0-gl gstreamer1.0-gtk3 gstreamer1.0-qt5 +sudo apt install -y gstreamer1.0-pulseaudio +sudo apt install -y gtk-doc-tools + +git clone https://gitee.com/jario-jin/gst-rtsp-server-b18.git +cd gst-rtsp-server-b18 +./autogen.sh +make +sudo make install +cd .. +sudo rm -r gst-rtsp-server-b18 + diff --git a/scripts/common/models-converting.sh b/scripts/common/models-converting.sh new file mode 100644 index 0000000..6e08d2b --- /dev/null +++ b/scripts/common/models-converting.sh @@ -0,0 +1,41 @@ +#!/bin/bash -e +export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH + +root_dir=${HOME}"/SpireCV/models" + + +coco_model1="COCO-yolov5s.wts" +coco_model2="COCO-yolov5s6.wts" +coco_model3="COCO-yolov5s-seg.wts" +coco_model1_fn=${root_dir}/${coco_model1} +coco_model2_fn=${root_dir}/${coco_model2} +coco_model3_fn=${root_dir}/${coco_model3} + +drone_model1="Drone-yolov5s-ap935-v230302.wts" +drone_model2="Drone-yolov5s6-ap949-v230302.wts" +drone_model1_fn=${root_dir}/${drone_model1} +drone_model2_fn=${root_dir}/${drone_model2} + +personcar_model1="PersonCar-yolov5s-ap606-v230306.wts" +personcar_model2="PersonCar-yolov5s6-ap702-v230306.wts" +personcar_model1_fn=${root_dir}/${personcar_model1} +personcar_model2_fn=${root_dir}/${personcar_model2} + +landing_model1="LandingMarker-resnet34-v230228.onnx" +landing_model1_fn=${root_dir}/${landing_model1} + +SpireCVDet -s ${coco_model1_fn} ${root_dir}/COCO.engine 80 s +SpireCVDet -s ${coco_model2_fn} ${root_dir}/COCO_HD.engine 80 s6 +SpireCVSeg -s ${coco_model3_fn} ${root_dir}/COCO_SEG.engine 80 s + +SpireCVDet -s ${drone_model1_fn} ${root_dir}/Drone.engine 1 s +SpireCVDet -s ${drone_model2_fn} ${root_dir}/Drone_HD.engine 1 s6 + +SpireCVDet -s ${personcar_model1_fn} ${root_dir}/PersonCar.engine 8 s +SpireCVDet -s ${personcar_model2_fn} ${root_dir}/PersonCar_HD.engine 8 s6 + +cd /usr/src/tensorrt/bin/ +./trtexec --explicitBatch --onnx=${landing_model1_fn} --saveEngine=${root_dir}/LandingMarker.engine --fp16 + +echo "export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH" >> ~/.bashrc + diff --git a/scripts/common/models-downloading.sh b/scripts/common/models-downloading.sh new file mode 100644 index 0000000..f884534 --- /dev/null +++ b/scripts/common/models-downloading.sh @@ -0,0 +1,107 @@ +#!/bin/bash -e + +root_dir=${HOME}"/SpireCV/models" +root_server="https://download.amovlab.com/model" + +sv_params1=${HOME}"/SpireCV/sv_algorithm_params.json" +sv_params2=${HOME}"/SpireCV/sv_algorithm_params_coco_640.json" +sv_params3=${HOME}"/SpireCV/sv_algorithm_params_coco_1280.json" +camera_params1=${HOME}"/SpireCV/calib_webcam_640x480.yaml" +camera_params2=${HOME}"/SpireCV/calib_webcam_1280x720.yaml" + +coco_model1="COCO-yolov5s.wts" +coco_model2="COCO-yolov5s6.wts" +coco_model3="COCO-yolov5s-seg.wts" +coco_model1_fn=${root_dir}/${coco_model1} +coco_model2_fn=${root_dir}/${coco_model2} +coco_model3_fn=${root_dir}/${coco_model3} + 
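The models-converting.sh script above only exercises the serialization path (SpireCVDet and SpireCVSeg are invoked with -s, and their -d inference branch is commented out in the samples). For reference, loading one of the generated .engine files back uses the standard TensorRT runtime API. A minimal, self-contained sketch; the engine path and the logger class here are assumptions for illustration, and the detectors in this patch use their own Logger from yolov7/logging.h instead:

    #include <NvInfer.h>
    #include <fstream>
    #include <iostream>
    #include <iterator>
    #include <vector>

    // Minimal stand-in logger for the sketch only.
    class SketchLogger : public nvinfer1::ILogger {
        void log(Severity severity, const char* msg) noexcept override {
            if (severity <= Severity::kWARNING) std::cerr << msg << std::endl;
        }
    };

    int main() {
        SketchLogger logger;
        // Assumed path: one of the engines produced by models-converting.sh.
        std::ifstream file("/home/amov/SpireCV/models/COCO.engine", std::ios::binary);
        if (!file) { std::cerr << "engine file not found" << std::endl; return -1; }
        std::vector<char> blob((std::istreambuf_iterator<char>(file)),
                               std::istreambuf_iterator<char>());

        nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger);
        nvinfer1::ICudaEngine* engine = runtime->deserializeCudaEngine(blob.data(), blob.size());
        if (!engine) { std::cerr << "failed to deserialize engine" << std::endl; return -1; }
        nvinfer1::IExecutionContext* context = engine->createExecutionContext();

        // ... enqueue preprocessed input buffers here, as the CUDA detector implementation does ...

        context->destroy();  // destroy() matches the style used elsewhere in this patch
        engine->destroy();
        runtime->destroy();
        return 0;
    }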
+drone_model1="Drone-yolov5s-ap935-v230302.wts" +drone_model2="Drone-yolov5s6-ap949-v230302.wts" +drone_model1_fn=${root_dir}/${drone_model1} +drone_model2_fn=${root_dir}/${drone_model2} + +personcar_model1="PersonCar-yolov5s-ap606-v230306.wts" +personcar_model2="PersonCar-yolov5s6-ap702-v230306.wts" +personcar_model1_fn=${root_dir}/${personcar_model1} +personcar_model2_fn=${root_dir}/${personcar_model2} + +track_model1="dasiamrpn_model.onnx" +track_model2="dasiamrpn_kernel_cls1.onnx" +track_model3="dasiamrpn_kernel_r1.onnx" +track_model4="nanotrack_backbone_sim.onnx" +track_model5="nanotrack_head_sim.onnx" +track_model1_fn=${root_dir}/${track_model1} +track_model2_fn=${root_dir}/${track_model2} +track_model3_fn=${root_dir}/${track_model3} +track_model4_fn=${root_dir}/${track_model4} +track_model5_fn=${root_dir}/${track_model5} + +landing_model1="LandingMarker-resnet34-v230228.onnx" +landing_model1_fn=${root_dir}/${landing_model1} + + +if [ ! -d ${root_dir} ]; then + echo -e "\033[32m[INFO]: ${root_dir} not exist, creating it ... \033[0m" + mkdir -p ${root_dir} +fi + +if [ ! -f ${sv_params1} ]; then + echo -e "\033[32m[INFO]: ${sv_params1} not exist, downloading ... \033[0m" + wget -O ${sv_params1} ${root_server}/install/a-params/sv_algorithm_params.json +fi +if [ ! -f ${sv_params2} ]; then + echo -e "\033[32m[INFO]: ${sv_params2} not exist, downloading ... \033[0m" + wget -O ${sv_params2} ${root_server}/install/a-params/sv_algorithm_params_coco_640.json +fi +if [ ! -f ${sv_params3} ]; then + echo -e "\033[32m[INFO]: ${sv_params3} not exist, downloading ... \033[0m" + wget -O ${sv_params3} ${root_server}/install/a-params/sv_algorithm_params_coco_1280.json +fi + +if [ ! -f ${camera_params1} ]; then + echo -e "\033[32m[INFO]: ${camera_params1} not exist, downloading ... \033[0m" + wget -O ${camera_params1} ${root_server}/install/c-params/calib_webcam_640x480.yaml +fi +if [ ! -f ${camera_params2} ]; then + echo -e "\033[32m[INFO]: ${camera_params2} not exist, downloading ... \033[0m" + wget -O ${camera_params2} ${root_server}/install/c-params/calib_webcam_1280x720.yaml +fi + +if [ ! -f ${coco_model1_fn} ]; then + echo -e "\033[32m[INFO]: ${coco_model1_fn} not exist, downloading ... \033[0m" + wget -O ${coco_model1_fn} ${root_server}/install/${coco_model1} + wget -O ${coco_model2_fn} ${root_server}/install/${coco_model2} + wget -O ${coco_model3_fn} ${root_server}/install/${coco_model3} +fi + +if [ ! -f ${drone_model1_fn} ]; then + echo -e "\033[32m[INFO]: ${drone_model1_fn} not exist, downloading ... \033[0m" + wget -O ${drone_model1_fn} ${root_server}/install/${drone_model1} + wget -O ${drone_model2_fn} ${root_server}/install/${drone_model2} +fi + +if [ ! -f ${personcar_model1_fn} ]; then + echo -e "\033[32m[INFO]: ${personcar_model1_fn} not exist, downloading ... \033[0m" + wget -O ${personcar_model1_fn} ${root_server}/install/${personcar_model1} + wget -O ${personcar_model2_fn} ${root_server}/install/${personcar_model2} +fi + +if [ ! -f ${track_model1_fn} ]; then + echo -e "\033[32m[INFO]: ${track_model1_fn} not exist, downloading ... \033[0m" + wget -O ${track_model1_fn} ${root_server}/${track_model1} + wget -O ${track_model2_fn} ${root_server}/${track_model2} + wget -O ${track_model3_fn} ${root_server}/${track_model3} +fi + +if [ ! -f ${track_model4_fn} ]; then + echo -e "\033[32m[INFO]: ${track_model4_fn} not exist, downloading ... \033[0m" + wget -O ${track_model4_fn} ${root_server}/${track_model4} + wget -O ${track_model5_fn} ${root_server}/${track_model5} +fi + +if [ ! 
-f ${landing_model1_fn} ]; then + echo -e "\033[32m[INFO]: ${landing_model1_fn} not exist, downloading ... \033[0m" + wget -O ${landing_model1_fn} ${root_server}/install/${landing_model1} +fi + diff --git a/scripts/common/opencv470-install.sh b/scripts/common/opencv470-install.sh new file mode 100644 index 0000000..69c88e6 --- /dev/null +++ b/scripts/common/opencv470-install.sh @@ -0,0 +1,56 @@ +#!/bin/sh + + +wget https://download.amovlab.com/model/deps/opencv-4.7.0.zip +wget https://download.amovlab.com/model/deps/opencv_contrib-4.7.0.zip +wget https://download.amovlab.com/model/deps/opencv_cache-4.7.0.zip + + +package_dir="." +mkdir ~/opencv_build + + +if [ ! -d ""$package_dir"" ];then + echo "\033[31m[ERROR]: $package_dir not exist!: \033[0m" + exit 1 +fi + +sudo add-apt-repository "deb http://security.ubuntu.com/ubuntu xenial-security main" +sudo add-apt-repository "deb http://mirrors.tuna.tsinghua.edu.cn/ubuntu-ports/ xenial main multiverse restricted universe" +sudo apt update +sudo apt install -y build-essential +sudo apt install -y cmake git libgtk2.0-dev pkg-config libavcodec-dev libavformat-dev libswscale-dev +sudo apt install -y libjasper1 libjasper-dev +sudo apt install -y python3-dev python3-numpy libtbb2 libtbb-dev libjpeg-dev libpng-dev libtiff-dev +sudo apt install -y libdc1394-22-dev + + +echo "\033[32m[INFO]:\033[0m unzip opencv-4.7.0.zip ..." +unzip -q -o $package_dir/opencv-4.7.0.zip -d ~/opencv_build + +echo "\033[32m[INFO]:\033[0m unzip opencv_contrib-4.7.0.zip ..." +unzip -q -o $package_dir/opencv_contrib-4.7.0.zip -d ~/opencv_build + +echo "\033[32m[INFO]:\033[0m unzip opencv_cache-4.7.0.zip ..." +unzip -q -o $package_dir/opencv_cache-4.7.0.zip -d ~/opencv_build + + +sudo rm opencv-4.7.0.zip +sudo rm opencv_contrib-4.7.0.zip +sudo rm opencv_cache-4.7.0.zip + +cd ~/opencv_build/opencv-4.7.0 +mkdir .cache + +cp -r ~/opencv_build/opencv_cache-4.7.0/* ~/opencv_build/opencv-4.7.0/.cache/ + +mkdir build +cd build + +cmake -D CMAKE_BUILD_TYPE=Release -D WITH_CUDA=OFF -D OPENCV_ENABLE_NONFREE=ON -D CMAKE_INSTALL_PREFIX=/usr/local -D OPENCV_EXTRA_MODULES_PATH=../../opencv_contrib-4.7.0/modules .. + +make -j2 +sudo make install + +cd +sudo rm -r ~/opencv_build diff --git a/scripts/jetson/opencv470-jetpack511-install.sh b/scripts/jetson/opencv470-jetpack511-install.sh new file mode 100644 index 0000000..50fb1fb --- /dev/null +++ b/scripts/jetson/opencv470-jetpack511-install.sh @@ -0,0 +1,54 @@ +#!/bin/sh + + +wget https://download.amovlab.com/model/deps/opencv-4.7.0.zip +wget https://download.amovlab.com/model/deps/opencv_contrib-4.7.0.zip +wget https://download.amovlab.com/model/deps/opencv_cache-4.7.0.zip + + +package_dir="." +mkdir ~/opencv_build + + +if [ ! -d ""$package_dir"" ];then + echo "\033[31m[ERROR]: $package_dir not exist!: \033[0m" + exit 1 +fi + +sudo apt update +sudo apt install -y build-essential +sudo apt install -y cmake git libgtk2.0-dev pkg-config libavcodec-dev libavformat-dev libswscale-dev +sudo apt install -y libjasper1 libjasper-dev +sudo apt install -y python3-dev python3-numpy libtbb2 libtbb-dev libjpeg-dev libpng-dev libtiff-dev +sudo apt install -y libdc1394-22-dev + + +echo "\033[32m[INFO]:\033[0m unzip opencv-4.7.0.zip ..." +unzip -q -o $package_dir/opencv-4.7.0.zip -d ~/opencv_build + +echo "\033[32m[INFO]:\033[0m unzip opencv_contrib-4.7.0.zip ..." +unzip -q -o $package_dir/opencv_contrib-4.7.0.zip -d ~/opencv_build + +echo "\033[32m[INFO]:\033[0m unzip opencv_cache-4.7.0.zip ..." 
+unzip -q -o $package_dir/opencv_cache-4.7.0.zip -d ~/opencv_build + + +sudo rm opencv-4.7.0.zip +sudo rm opencv_contrib-4.7.0.zip +sudo rm opencv_cache-4.7.0.zip + +cd ~/opencv_build/opencv-4.7.0 +mkdir .cache + +cp -r ~/opencv_build/opencv_cache-4.7.0/* ~/opencv_build/opencv-4.7.0/.cache/ + +mkdir build +cd build + +cmake -D CMAKE_BUILD_TYPE=Release -D WITH_CUDA=OFF -D OPENCV_ENABLE_NONFREE=ON -D CMAKE_INSTALL_PREFIX=/usr/local -D OPENCV_EXTRA_MODULES_PATH=../../opencv_contrib-4.7.0/modules .. + +make -j2 +sudo make install + +cd +sudo rm -r ~/opencv_build diff --git a/scripts/x86-cuda/x86-gst-install.sh b/scripts/x86-cuda/x86-gst-install.sh new file mode 100644 index 0000000..f629213 --- /dev/null +++ b/scripts/x86-cuda/x86-gst-install.sh @@ -0,0 +1,19 @@ +#!/bin/sh + +sudo apt install -y libgstreamer1.0-dev libgstreamer-plugins-base1.0-dev +sudo apt install -y libgstreamer-plugins-bad1.0-dev gstreamer1.0-plugins-base +sudo apt install -y gstreamer1.0-plugins-good gstreamer1.0-plugins-bad +sudo apt install -y gstreamer1.0-plugins-ugly gstreamer1.0-libav gstreamer1.0-doc +sudo apt install -y gstreamer1.0-tools gstreamer1.0-x gstreamer1.0-alsa +sudo apt install -y gstreamer1.0-gl gstreamer1.0-gtk3 gstreamer1.0-qt5 +sudo apt install -y gstreamer1.0-pulseaudio +sudo apt install -y gtk-doc-tools + +git clone https://gitee.com/jario-jin/gst-rtsp-server-b18.git +cd gst-rtsp-server-b18 +./autogen.sh +make +sudo make install +cd .. +sudo rm -r gst-rtsp-server-b18 + diff --git a/scripts/x86-cuda/x86-opencv470-install.sh b/scripts/x86-cuda/x86-opencv470-install.sh new file mode 100644 index 0000000..89200a9 --- /dev/null +++ b/scripts/x86-cuda/x86-opencv470-install.sh @@ -0,0 +1,56 @@ +#!/bin/sh + + +wget https://download.amovlab.com/model/deps/opencv-4.7.0.zip +wget https://download.amovlab.com/model/deps/opencv_contrib-4.7.0.zip +wget https://download.amovlab.com/model/deps/opencv_cache_x86-4.7.0.zip + + +package_dir="." +mkdir ~/opencv_build + + +if [ ! -d ""$package_dir"" ];then + echo "\033[31m[ERROR]: $package_dir not exist!: \033[0m" + exit 1 +fi + +# sudo add-apt-repository "deb http://security.ubuntu.com/ubuntu xenial-security main" +# sudo add-apt-repository "deb http://mirrors.tuna.tsinghua.edu.cn/ubuntu-ports/ xenial main multiverse restricted universe" +sudo apt update +sudo apt install -y build-essential +sudo apt install -y cmake git libgtk2.0-dev pkg-config libavcodec-dev libavformat-dev libswscale-dev +sudo apt install -y libjasper1 libjasper-dev +sudo apt install -y python3-dev python3-numpy libtbb2 libtbb-dev libjpeg-dev libpng-dev libtiff-dev +sudo apt install -y libdc1394-22-dev + + +echo "\033[32m[INFO]:\033[0m unzip opencv-4.7.0.zip ..." +unzip -q -o $package_dir/opencv-4.7.0.zip -d ~/opencv_build + +echo "\033[32m[INFO]:\033[0m unzip opencv_contrib-4.7.0.zip ..." +unzip -q -o $package_dir/opencv_contrib-4.7.0.zip -d ~/opencv_build + +echo "\033[32m[INFO]:\033[0m unzip opencv_cache_x86-4.7.0.zip ..." +unzip -q -o $package_dir/opencv_cache_x86-4.7.0.zip -d ~/opencv_build + + +sudo rm opencv-4.7.0.zip +sudo rm opencv_contrib-4.7.0.zip +sudo rm opencv_cache_x86-4.7.0.zip + +cd ~/opencv_build/opencv-4.7.0 +mkdir .cache + +cp -r ~/opencv_build/opencv_cache_x86-4.7.0/* ~/opencv_build/opencv-4.7.0/.cache/ + +mkdir build +cd build + +cmake -D CMAKE_BUILD_TYPE=Release -D WITH_CUDA=OFF -D OPENCV_ENABLE_NONFREE=ON -D CMAKE_INSTALL_PREFIX=/usr/local -D OPENCV_EXTRA_MODULES_PATH=../../opencv_contrib-4.7.0/modules .. 
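+# The configure line above builds a Release, CPU-only OpenCV (WITH_CUDA=OFF)
+# with the contrib modules and the non-free algorithms enabled, installing to
+# /usr/local. The .cache directory populated beforehand lets CMake reuse the
+# pre-downloaded third-party archives instead of fetching them at configure
+# time. "make -j2" keeps memory usage modest; a higher job count such as
+# "make -j$(nproc)" may also work on machines with more RAM, but is untested here.
+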
+ +make -j2 +sudo make install + +cd +sudo rm -r ~/opencv_build diff --git a/scripts/x86-cuda/x86-ubuntu2004-cuda-cudnn-11-6.sh b/scripts/x86-cuda/x86-ubuntu2004-cuda-cudnn-11-6.sh new file mode 100644 index 0000000..947fc67 --- /dev/null +++ b/scripts/x86-cuda/x86-ubuntu2004-cuda-cudnn-11-6.sh @@ -0,0 +1,98 @@ +#!/bin/sh + +echo -e "\033[32m[INFO]:\033[0m Please enter the folder path of the installation package: " +# package_dir="/home/jario/Downloads/nv" + +wget https://download.amovlab.com/model/install/x86-nvidia/cuda-repo-ubuntu2004-11-6-local_11.6.2-510.47.03-1_amd64.deb +wget https://download.amovlab.com/model/install/x86-nvidia/cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive.tar.xz +wget https://download.amovlab.com/model/install/x86-nvidia/nv-tensorrt-repo-ubuntu2004-cuda11.6-trt8.4.0.6-ea-20220212_1-1_amd64.deb +wget https://download.amovlab.com/model/install/x86-nvidia/cuda-ubuntu2004.pin + +package_dir="." + +cuda_fn=$package_dir"/cuda-repo-ubuntu2004-11-6-local_11.6.2-510.47.03-1_amd64.deb" +cudnn_fn=$package_dir"/cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive.tar.xz" +tensorrt_fn=$package_dir"/nv-tensorrt-repo-ubuntu2004-cuda11.6-trt8.4.0.6-ea-20220212_1-1_amd64.deb" +tmp_dir="/tmp" + +# https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/8.4.0/local_repos/nv-tensorrt-repo-ubuntu2004-cuda11.6-trt8.4.0.6-ea-20220212_1-1_amd64.deb + +echo -e "\033[32m[INFO]: CUDA_PKG: \033[0m"$cuda_fn +echo -e "\033[32m[INFO]: CUDNN_PKG: \033[0m"$cudnn_fn +echo -e "\033[32m[INFO]: TENSORRT_PKG: \033[0m"$tensorrt_fn + +# 所有文件都存在时,才会继续执行脚本 +if [ ! -f "$cuda_fn" ]; then + echo -e "\033[31m[ERROR]: CUDA_PKG not exist!: \033[0m" + exit 1 +fi + +if [ ! -f "$cudnn_fn" ]; then + echo -e "\033[31m[ERROR]: CUDNN_PKG not exist!: \033[0m" + exit 1 +fi + +if [ ! 
-f "$tensorrt_fn" ]; then + echo -e "\033[31m[ERROR]: TENSORRT_PKG not exist!: \033[0m" + exit 1 +fi + +# 删除显卡驱动 +# sudo apt-get remove nvidia-* + +# 安装显卡驱动 +# echo -e "\033[32m[INFO]: Nvidia Driver installing ...\033[0m" +# sudo add-apt-repository ppa:graphics-drivers/ppa +# ubuntu-drivers devices +# sudo ubuntu-drivers autoinstall +# sudo apt-get install nvidia-driver-465 + +# 删除已安装CUDA +# --purge选项会将配置文件、数据库等删除 +# sudo apt-get autoremove --purge cuda +# sudo apt-get purge nvidia-* +# 查看安装了哪些cuda相关的库,可以用以下指令 +# sudo dpkg -l |grep cuda +# sudo dpkg -P cuda-repo-ubuntu1804-10-2-local-10.2.89-440.33.01 +# sudo dpkg -P cuda-repo-ubuntu1804-11-1-local +# sudo dpkg -P nv-tensorrt-repo-ubuntu1804-cuda10.2-trt8.0.1.6-ga-20210626 +# 这个key值是官网文档查到的,当然通过apt-key list也能查看 +# sudo apt-key list +# sudo apt-key del 7fa2af80 + +# 安装CUDA +echo -e "\033[32m[INFO]: CUDA installing ...\033[0m" +# wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-ubuntu1804.pin +sudo cp $package_dir/cuda-ubuntu2004.pin /etc/apt/preferences.d/cuda-repository-pin-600 +sudo dpkg -i $cuda_fn +sudo apt-key add /var/cuda-repo-ubuntu2004-11-6-local/7fa2af80.pub +sudo apt-get update +sudo apt-get -y install cuda + +# 安装CUDNN +echo -e "\033[32m[INFO]: CUDNN installing ...\033[0m" +tar -xvf $cudnn_fn -C $tmp_dir +sudo cp $tmp_dir/cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive/include/cudnn* /usr/local/cuda/include/ +sudo cp $tmp_dir/cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive/lib/libcudnn* /usr/local/cuda/lib64/ +sudo chmod a+r /usr/local/cuda/include/cudnn* /usr/local/cuda/lib64/libcudnn* + +sudo ln -sf /usr/local/cuda/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.4.1 /usr/local/cuda/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8 +sudo ln -sf /usr/local/cuda/targets/x86_64-linux/lib/libcudnn.so.8.4.1 /usr/local/cuda/targets/x86_64-linux/lib/libcudnn.so.8 +sudo ln -sf /usr/local/cuda/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.4.1 /usr/local/cuda/targets/x86_64-linux/lib/libcudnn_ops_train.so.8 +sudo ln -sf /usr/local/cuda/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.4.1 /usr/local/cuda/targets/x86_64-linux/lib/libcudnn_adv_train.so.8 +sudo ln -sf /usr/local/cuda/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.4.1 /usr/local/cuda/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8 +sudo ln -sf /usr/local/cuda/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.4.1 /usr/local/cuda/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8 +sudo ln -sf /usr/local/cuda/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.4.1 /usr/local/cuda/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8 + +# 安装TensorRT +echo -e "\033[32m[INFO]: TensorRT installing ...\033[0m" +sudo dpkg -i $tensorrt_fn +sudo apt-key add /var/nv-tensorrt-repo-ubuntu2004-cuda11.6-trt8.4.0.6-ea-20220212/7fa2af80.pub +sudo apt-get update +sudo apt-get install tensorrt -y +sudo apt-get install python3-libnvinfer-dev -y + +sudo rm $cuda_fn +sudo rm $cudnn_fn +sudo rm $tensorrt_fn + diff --git a/utils/gason.cpp b/utils/gason.cpp new file mode 100644 index 0000000..3443243 --- /dev/null +++ b/utils/gason.cpp @@ -0,0 +1,396 @@ +// https://github.com/vivkin/gason - pulled January 10, 2016 +#include "gason.h" +#include +#include +#include +#include +#include "sv_util.h" + + +namespace sv { + +#define JSON_ZONE_SIZE 4096 +#define JSON_STACK_SIZE 32 + +const char *jsonStrError(int err) { + switch (err) { +#define XX(no, str) \ + case JSON_##no: \ + return str; + JSON_ERRNO_MAP(XX) +#undef XX + default: + return "unknown"; + } +} + +void 
*JsonAllocator::allocate(size_t size) { + size = (size + 7) & ~7; + + if (head && head->used + size <= JSON_ZONE_SIZE) { + char *p = (char *)head + head->used; + head->used += size; + return p; + } + + size_t allocSize = sizeof(Zone) + size; + Zone *zone = (Zone *)malloc(allocSize <= JSON_ZONE_SIZE ? JSON_ZONE_SIZE : allocSize); + if (zone == nullptr) + return nullptr; + zone->used = allocSize; + if (allocSize <= JSON_ZONE_SIZE || head == nullptr) { + zone->next = head; + head = zone; + } else { + zone->next = head->next; + head->next = zone; + } + return (char *)zone + sizeof(Zone); +} + +void JsonAllocator::deallocate() { + while (head) { + Zone *next = head->next; + free(head); + head = next; + } +} + +static inline bool isspace(char c) { + return c == ' ' || (c >= '\t' && c <= '\r'); +} + +static inline bool isdelim(char c) { + return c == ',' || c == ':' || c == ']' || c == '}' || isspace(c) || !c; +} + +static inline bool isdigit(char c) { + return c >= '0' && c <= '9'; +} + +static inline bool isxdigit(char c) { + return (c >= '0' && c <= '9') || ((c & ~' ') >= 'A' && (c & ~' ') <= 'F'); +} + +static inline int char2int(char c) { + if (c <= '9') + return c - '0'; + return (c & ~' ') - 'A' + 10; +} + +static double string2double(char *s, char **endptr) { + char ch = *s; + if (ch == '-') + ++s; + + double result = 0; + while (isdigit(*s)) + result = (result * 10) + (*s++ - '0'); + + if (*s == '.') { + ++s; + + double fraction = 1; + while (isdigit(*s)) { + fraction *= 0.1; + result += (*s++ - '0') * fraction; + } + } + + if (*s == 'e' || *s == 'E') { + ++s; + + double base = 10; + if (*s == '+') + ++s; + else if (*s == '-') { + ++s; + base = 0.1; + } + + unsigned int exponent = 0; + while (isdigit(*s)) + exponent = (exponent * 10) + (*s++ - '0'); + + double power = 1; + for (; exponent; exponent >>= 1, base *= base) + if (exponent & 1) + power *= base; + + result *= power; + } + + *endptr = s; + return ch == '-' ? 
-result : result; +} + +static inline JsonNode *insertAfter(JsonNode *tail, JsonNode *node) { + if (!tail) + return node->next = node; + node->next = tail->next; + tail->next = node; + return node; +} + +static inline JsonValue listToValue(JsonTag tag, JsonNode *tail) { + if (tail) { + auto head = tail->next; + tail->next = nullptr; + return JsonValue(tag, head); + } + return JsonValue(tag, nullptr); +} + +int jsonParse(char *s, char **endptr, JsonValue *value, JsonAllocator &allocator) { + JsonNode *tails[JSON_STACK_SIZE]; + JsonTag tags[JSON_STACK_SIZE]; + char *keys[JSON_STACK_SIZE]; + JsonValue o; + int pos = -1; + bool separator = true; + JsonNode *node; + *endptr = s; + + while (*s) { + while (isspace(*s)) { + ++s; + if (!*s) break; + } + *endptr = s++; + switch (**endptr) { + case '-': + if (!isdigit(*s) && *s != '.') { + *endptr = s; + return JSON_BAD_NUMBER; + } + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + o = JsonValue(string2double(*endptr, &s)); + if (!isdelim(*s)) { + *endptr = s; + return JSON_BAD_NUMBER; + } + break; + case '"': + o = JsonValue(JSON_STRING, s); + for (char *it = s; *s; ++it, ++s) { + int c = *it = *s; + if (c == '\\') { + c = *++s; + switch (c) { + case '\\': + case '"': + case '/': + *it = c; + break; + case 'b': + *it = '\b'; + break; + case 'f': + *it = '\f'; + break; + case 'n': + *it = '\n'; + break; + case 'r': + *it = '\r'; + break; + case 't': + *it = '\t'; + break; + case 'u': + c = 0; + for (int i = 0; i < 4; ++i) { + if (isxdigit(*++s)) { + c = c * 16 + char2int(*s); + } else { + *endptr = s; + return JSON_BAD_STRING; + } + } + if (c < 0x80) { + *it = c; + } else if (c < 0x800) { + *it++ = 0xC0 | (c >> 6); + *it = 0x80 | (c & 0x3F); + } else { + *it++ = 0xE0 | (c >> 12); + *it++ = 0x80 | ((c >> 6) & 0x3F); + *it = 0x80 | (c & 0x3F); + } + break; + default: + *endptr = s; + return JSON_BAD_STRING; + } + } else if ((unsigned int)c < ' ' || c == '\x7F') { + *endptr = s; + return JSON_BAD_STRING; + } else if (c == '"') { + *it = 0; + ++s; + break; + } + } + if (!isdelim(*s)) { + *endptr = s; + return JSON_BAD_STRING; + } + break; + case 't': + if (!(s[0] == 'r' && s[1] == 'u' && s[2] == 'e' && isdelim(s[3]))) + return JSON_BAD_IDENTIFIER; + o = JsonValue(JSON_TRUE); + s += 3; + break; + case 'f': + if (!(s[0] == 'a' && s[1] == 'l' && s[2] == 's' && s[3] == 'e' && isdelim(s[4]))) + return JSON_BAD_IDENTIFIER; + o = JsonValue(JSON_FALSE); + s += 4; + break; + case 'n': + if (!(s[0] == 'u' && s[1] == 'l' && s[2] == 'l' && isdelim(s[3]))) + return JSON_BAD_IDENTIFIER; + o = JsonValue(JSON_NULL); + s += 3; + break; + case ']': + if (pos == -1) + return JSON_STACK_UNDERFLOW; + if (tags[pos] != JSON_ARRAY) + return JSON_MISMATCH_BRACKET; + o = listToValue(JSON_ARRAY, tails[pos--]); + break; + case '}': + if (pos == -1) + return JSON_STACK_UNDERFLOW; + if (tags[pos] != JSON_OBJECT) + return JSON_MISMATCH_BRACKET; + if (keys[pos] != nullptr) + return JSON_UNEXPECTED_CHARACTER; + o = listToValue(JSON_OBJECT, tails[pos--]); + break; + case '[': + if (++pos == JSON_STACK_SIZE) + return JSON_STACK_OVERFLOW; + tails[pos] = nullptr; + tags[pos] = JSON_ARRAY; + keys[pos] = nullptr; + separator = true; + continue; + case '{': + if (++pos == JSON_STACK_SIZE) + return JSON_STACK_OVERFLOW; + tails[pos] = nullptr; + tags[pos] = JSON_OBJECT; + keys[pos] = nullptr; + separator = true; + continue; + case ':': + if (separator || keys[pos] == nullptr) + return JSON_UNEXPECTED_CHARACTER; + separator = 
true; + continue; + case ',': + if (separator || keys[pos] != nullptr) + return JSON_UNEXPECTED_CHARACTER; + separator = true; + continue; + case '\0': + continue; + default: + return JSON_UNEXPECTED_CHARACTER; + } + + separator = false; + + if (pos == -1) { + *endptr = s; + *value = o; + return JSON_OK; + } + + if (tags[pos] == JSON_OBJECT) { + if (!keys[pos]) { + if (o.getTag() != JSON_STRING) + return JSON_UNQUOTED_KEY; + keys[pos] = o.toString(); + continue; + } + if ((node = (JsonNode *) allocator.allocate(sizeof(JsonNode))) == nullptr) + return JSON_ALLOCATION_FAILURE; + tails[pos] = insertAfter(tails[pos], node); + tails[pos]->key = keys[pos]; + keys[pos] = nullptr; + } else { + if ((node = (JsonNode *) allocator.allocate(sizeof(JsonNode) - sizeof(char *))) == nullptr) + return JSON_ALLOCATION_FAILURE; + tails[pos] = insertAfter(tails[pos], node); + } + tails[pos]->value = o; + } + return JSON_BREAKING_BAD; +} + + +void _parser_algorithm_params(std::string name_, JsonValue& f_value_, JsonValue& params_value_) +{ + bool has_params = false; + if (f_value_.getTag() == JSON_OBJECT) { + for (auto i : f_value_) { + if (name_ == std::string(i->key)) { + // std::cout << i->key << std::endl; + params_value_ = i->value; + has_params = true; + } + } + } + + if (!has_params && params_value_.getTag() == JSON_OBJECT) + { + char msg[256]; + sprintf(msg, "SpireCV (106) %s parameters reading ERROR!", name_.c_str()); + throw std::runtime_error(msg); + } +} + +void _load_all_json(std::string json_fn, JsonValue& value, JsonAllocator& allocator) +{ + std::ifstream fin; + fin.open(json_fn, std::ios::in); + if (!fin.is_open()) + { + throw std::runtime_error("SpireCV (104) Algorithm parameters file NOT exist!"); + } + std::string json_str, buff; + json_str = ""; + while (getline(fin, buff)) + { + json_str = json_str + _trim(buff); + } + fin.close(); + // std::cout << json_str << std::endl; + char source[1024 * 1024]; // 1M + char *endptr; + + strcpy(source, json_str.c_str()); + int status = jsonParse(source, &endptr, &value, allocator); + if (status != JSON_OK) { + char msg[256]; + sprintf(msg, "SpireCV (106) %s at %zd\n", jsonStrError(status), endptr - source); + throw std::runtime_error(msg); + } +} + + +} + diff --git a/utils/gason.h b/utils/gason.h new file mode 100644 index 0000000..58054dd --- /dev/null +++ b/utils/gason.h @@ -0,0 +1,145 @@ +// https://github.com/vivkin/gason - pulled January 10, 2016 +#pragma once + +#include +#include +#include +#include + + +namespace sv { + +enum JsonTag { + JSON_NUMBER = 0, + JSON_STRING, + JSON_ARRAY, + JSON_OBJECT, + JSON_TRUE, + JSON_FALSE, + JSON_NULL = 0xF +}; + +struct JsonNode; + +#define JSON_VALUE_PAYLOAD_MASK 0x0000FFFFFFFFFFFFULL +#define JSON_VALUE_NAN_MASK 0x7FF0000000000000ULL +#define JSON_VALUE_TAG_MASK 0xF +#define JSON_VALUE_TAG_SHIFT 48 + +union JsonValue { + uint64_t ival; + double fval; + + JsonValue(double x) + : fval(x) { + } + JsonValue(JsonTag tag = JSON_NULL, void *payload = nullptr) { + assert((uintptr_t)payload <= JSON_VALUE_PAYLOAD_MASK); + ival = JSON_VALUE_NAN_MASK | ((uint64_t)tag << JSON_VALUE_TAG_SHIFT) | (uintptr_t)payload; + } + bool isDouble() const { + return (int64_t)ival <= (int64_t)JSON_VALUE_NAN_MASK; + } + JsonTag getTag() const { + return isDouble() ? 
JSON_NUMBER : JsonTag((ival >> JSON_VALUE_TAG_SHIFT) & JSON_VALUE_TAG_MASK); + } + uint64_t getPayload() const { + assert(!isDouble()); + return ival & JSON_VALUE_PAYLOAD_MASK; + } + double toNumber() const { + assert(getTag() == JSON_NUMBER); + return fval; + } + char *toString() const { + assert(getTag() == JSON_STRING); + return (char *)getPayload(); + } + JsonNode *toNode() const { + assert(getTag() == JSON_ARRAY || getTag() == JSON_OBJECT); + return (JsonNode *)getPayload(); + } +}; + +struct JsonNode { + JsonValue value; + JsonNode *next; + char *key; +}; + +struct JsonIterator { + JsonNode *p; + + void operator++() { + p = p->next; + } + bool operator!=(const JsonIterator &x) const { + return p != x.p; + } + JsonNode *operator*() const { + return p; + } + JsonNode *operator->() const { + return p; + } +}; + +inline JsonIterator begin(JsonValue o) { + return JsonIterator{o.toNode()}; +} +inline JsonIterator end(JsonValue) { + return JsonIterator{nullptr}; +} + +#define JSON_ERRNO_MAP(XX) \ + XX(OK, "ok") \ + XX(BAD_NUMBER, "bad number") \ + XX(BAD_STRING, "bad string") \ + XX(BAD_IDENTIFIER, "bad identifier") \ + XX(STACK_OVERFLOW, "stack overflow") \ + XX(STACK_UNDERFLOW, "stack underflow") \ + XX(MISMATCH_BRACKET, "mismatch bracket") \ + XX(UNEXPECTED_CHARACTER, "unexpected character") \ + XX(UNQUOTED_KEY, "unquoted key") \ + XX(BREAKING_BAD, "breaking bad") \ + XX(ALLOCATION_FAILURE, "allocation failure") + +enum JsonErrno { +#define XX(no, str) JSON_##no, + JSON_ERRNO_MAP(XX) +#undef XX +}; + +const char *jsonStrError(int err); + +class JsonAllocator { + struct Zone { + Zone *next; + size_t used; + } *head = nullptr; + +public: + JsonAllocator() = default; + JsonAllocator(const JsonAllocator &) = delete; + JsonAllocator &operator=(const JsonAllocator &) = delete; + JsonAllocator(JsonAllocator &&x) : head(x.head) { + x.head = nullptr; + } + JsonAllocator &operator=(JsonAllocator &&x) { + head = x.head; + x.head = nullptr; + return *this; + } + ~JsonAllocator() { + deallocate(); + } + void *allocate(size_t size); + void deallocate(); +}; + +int jsonParse(char *str, char **endptr, JsonValue *value, JsonAllocator &allocator); +void _load_all_json(std::string json_fn, JsonValue& value, JsonAllocator& allocator); +void _parser_algorithm_params(std::string name_, JsonValue& f_value_, JsonValue& params_value_); + +} + diff --git a/utils/sv_crclib.cpp b/utils/sv_crclib.cpp new file mode 100644 index 0000000..577a9be --- /dev/null +++ b/utils/sv_crclib.cpp @@ -0,0 +1,575 @@ +#include "sv_crclib.h" + + +namespace sv { + +/****************************************************************************** + * Name: CRC-4/ITU x4+x+1 + * Poly: 0x03 + * Init: 0x00 + * Refin: True + * Refout: True + * Xorout: 0x00 + * Note: + *****************************************************************************/ +uint8_t crc4_itu(uint8_t *data, uint16_t length) +{ + uint8_t i; + uint8_t crc = 0; // Initial value + while(length--) + { + crc ^= *data++; // crc ^= *data; data++; + for (i = 0; i < 8; ++i) + { + if (crc & 1) + crc = (crc >> 1) ^ 0x0C;// 0x0C = (reverse 0x03)>>(8-4) + else + crc = (crc >> 1); + } + } + return crc; +} + +/****************************************************************************** + * Name: CRC-5/EPC x5+x3+1 + * Poly: 0x09 + * Init: 0x09 + * Refin: False + * Refout: False + * Xorout: 0x00 + * Note: + *****************************************************************************/ +uint8_t crc5_epc(uint8_t *data, uint16_t length) +{ + uint8_t i; + uint8_t crc = 0x48; // Initial 
value: 0x48 = 0x09<<(8-5) + while(length--) + { + crc ^= *data++; // crc ^= *data; data++; + for ( i = 0; i < 8; i++ ) + { + if ( crc & 0x80 ) + crc = (crc << 1) ^ 0x48; // 0x48 = 0x09<<(8-5) + else + crc <<= 1; + } + } + return crc >> 3; +} + +/****************************************************************************** + * Name: CRC-5/ITU x5+x4+x2+1 + * Poly: 0x15 + * Init: 0x00 + * Refin: True + * Refout: True + * Xorout: 0x00 + * Note: + *****************************************************************************/ +uint8_t crc5_itu(uint8_t *data, uint16_t length) +{ + uint8_t i; + uint8_t crc = 0; // Initial value + while(length--) + { + crc ^= *data++; // crc ^= *data; data++; + for (i = 0; i < 8; ++i) + { + if (crc & 1) + crc = (crc >> 1) ^ 0x15;// 0x15 = (reverse 0x15)>>(8-5) + else + crc = (crc >> 1); + } + } + return crc; +} + +/****************************************************************************** + * Name: CRC-5/USB x5+x2+1 + * Poly: 0x05 + * Init: 0x1F + * Refin: True + * Refout: True + * Xorout: 0x1F + * Note: + *****************************************************************************/ +uint8_t crc5_usb(uint8_t *data, uint16_t length) +{ + uint8_t i; + uint8_t crc = 0x1F; // Initial value + while(length--) + { + crc ^= *data++; // crc ^= *data; data++; + for (i = 0; i < 8; ++i) + { + if (crc & 1) + crc = (crc >> 1) ^ 0x14;// 0x14 = (reverse 0x05)>>(8-5) + else + crc = (crc >> 1); + } + } + return crc ^ 0x1F; +} + +/****************************************************************************** + * Name: CRC-6/ITU x6+x+1 + * Poly: 0x03 + * Init: 0x00 + * Refin: True + * Refout: True + * Xorout: 0x00 + * Note: + *****************************************************************************/ +uint8_t crc6_itu(uint8_t *data, uint16_t length) +{ + uint8_t i; + uint8_t crc = 0; // Initial value + while(length--) + { + crc ^= *data++; // crc ^= *data; data++; + for (i = 0; i < 8; ++i) + { + if (crc & 1) + crc = (crc >> 1) ^ 0x30;// 0x30 = (reverse 0x03)>>(8-6) + else + crc = (crc >> 1); + } + } + return crc; +} + +/****************************************************************************** + * Name: CRC-7/MMC x7+x3+1 + * Poly: 0x09 + * Init: 0x00 + * Refin: False + * Refout: False + * Xorout: 0x00 + * Use: MultiMediaCard,SD,ect. 
+ *****************************************************************************/ +uint8_t crc7_mmc(uint8_t *data, uint16_t length) +{ + uint8_t i; + uint8_t crc = 0; // Initial value + while(length--) + { + crc ^= *data++; // crc ^= *data; data++; + for ( i = 0; i < 8; i++ ) + { + if ( crc & 0x80 ) + crc = (crc << 1) ^ 0x12; // 0x12 = 0x09<<(8-7) + else + crc <<= 1; + } + } + return crc >> 1; +} + +/****************************************************************************** + * Name: CRC-8 x8+x2+x+1 + * Poly: 0x07 + * Init: 0x00 + * Refin: False + * Refout: False + * Xorout: 0x00 + * Note: + *****************************************************************************/ +uint8_t crc8(uint8_t *data, uint16_t length) +{ + uint8_t i; + uint8_t crc = 0; // Initial value + while(length--) + { + crc ^= *data++; // crc ^= *data; data++; + for ( i = 0; i < 8; i++ ) + { + if ( crc & 0x80 ) + crc = (crc << 1) ^ 0x07; + else + crc <<= 1; + } + } + return crc; +} + +/****************************************************************************** + * Name: CRC-8/ITU x8+x2+x+1 + * Poly: 0x07 + * Init: 0x00 + * Refin: False + * Refout: False + * Xorout: 0x55 + * Alias: CRC-8/ATM + *****************************************************************************/ +uint8_t crc8_itu(uint8_t *data, uint16_t length) +{ + uint8_t i; + uint8_t crc = 0; // Initial value + while(length--) + { + crc ^= *data++; // crc ^= *data; data++; + for ( i = 0; i < 8; i++ ) + { + if ( crc & 0x80 ) + crc = (crc << 1) ^ 0x07; + else + crc <<= 1; + } + } + return crc ^ 0x55; +} + +/****************************************************************************** + * Name: CRC-8/ROHC x8+x2+x+1 + * Poly: 0x07 + * Init: 0xFF + * Refin: True + * Refout: True + * Xorout: 0x00 + * Note: + *****************************************************************************/ +uint8_t crc8_rohc(uint8_t *data, uint16_t length) +{ + uint8_t i; + uint8_t crc = 0xFF; // Initial value + while(length--) + { + crc ^= *data++; // crc ^= *data; data++; + for (i = 0; i < 8; ++i) + { + if (crc & 1) + crc = (crc >> 1) ^ 0xE0; // 0xE0 = reverse 0x07 + else + crc = (crc >> 1); + } + } + return crc; +} + +/****************************************************************************** + * Name: CRC-8/MAXIM x8+x5+x4+1 + * Poly: 0x31 + * Init: 0x00 + * Refin: True + * Refout: True + * Xorout: 0x00 + * Alias: DOW-CRC,CRC-8/IBUTTON + * Use: Maxim(Dallas)'s some devices,e.g. 
DS18B20 + *****************************************************************************/ +uint8_t crc8_maxim(uint8_t *data, uint16_t length) +{ + uint8_t i; + uint8_t crc = 0; // Initial value + while(length--) + { + crc ^= *data++; // crc ^= *data; data++; + for (i = 0; i < 8; i++) + { + if (crc & 1) + crc = (crc >> 1) ^ 0x8C; // 0x8C = reverse 0x31 + else + crc >>= 1; + } + } + return crc; +} + +/****************************************************************************** + * Name: CRC-16/IBM x16+x15+x2+1 + * Poly: 0x8005 + * Init: 0x0000 + * Refin: True + * Refout: True + * Xorout: 0x0000 + * Alias: CRC-16,CRC-16/ARC,CRC-16/LHA + *****************************************************************************/ +uint16_t crc16_ibm(uint8_t *data, uint16_t length) +{ + uint8_t i; + uint16_t crc = 0; // Initial value + while(length--) + { + crc ^= *data++; // crc ^= *data; data++; + for (i = 0; i < 8; ++i) + { + if (crc & 1) + crc = (crc >> 1) ^ 0xA001; // 0xA001 = reverse 0x8005 + else + crc = (crc >> 1); + } + } + return crc; +} + +/****************************************************************************** + * Name: CRC-16/MAXIM x16+x15+x2+1 + * Poly: 0x8005 + * Init: 0x0000 + * Refin: True + * Refout: True + * Xorout: 0xFFFF + * Note: + *****************************************************************************/ +uint16_t crc16_maxim(uint8_t *data, uint16_t length) +{ + uint8_t i; + uint16_t crc = 0; // Initial value + while(length--) + { + crc ^= *data++; // crc ^= *data; data++; + for (i = 0; i < 8; ++i) + { + if (crc & 1) + crc = (crc >> 1) ^ 0xA001; // 0xA001 = reverse 0x8005 + else + crc = (crc >> 1); + } + } + return ~crc; // crc^0xffff +} + +/****************************************************************************** + * Name: CRC-16/USB x16+x15+x2+1 + * Poly: 0x8005 + * Init: 0xFFFF + * Refin: True + * Refout: True + * Xorout: 0xFFFF + * Note: + *****************************************************************************/ +uint16_t crc16_usb(uint8_t *data, uint16_t length) +{ + uint8_t i; + uint16_t crc = 0xffff; // Initial value + while(length--) + { + crc ^= *data++; // crc ^= *data; data++; + for (i = 0; i < 8; ++i) + { + if (crc & 1) + crc = (crc >> 1) ^ 0xA001; // 0xA001 = reverse 0x8005 + else + crc = (crc >> 1); + } + } + return ~crc; // crc^0xffff +} + +/****************************************************************************** + * Name: CRC-16/MODBUS x16+x15+x2+1 + * Poly: 0x8005 + * Init: 0xFFFF + * Refin: True + * Refout: True + * Xorout: 0x0000 + * Note: + *****************************************************************************/ +uint16_t crc16_modbus(uint8_t *data, uint16_t length) +{ + uint8_t i; + uint16_t crc = 0xffff; // Initial value + while(length--) + { + crc ^= *data++; // crc ^= *data; data++; + for (i = 0; i < 8; ++i) + { + if (crc & 1) + crc = (crc >> 1) ^ 0xA001; // 0xA001 = reverse 0x8005 + else + crc = (crc >> 1); + } + } + return crc; +} + +/****************************************************************************** + * Name: CRC-16/CCITT x16+x12+x5+1 + * Poly: 0x1021 + * Init: 0x0000 + * Refin: True + * Refout: True + * Xorout: 0x0000 + * Alias: CRC-CCITT,CRC-16/CCITT-TRUE,CRC-16/KERMIT + *****************************************************************************/ +uint16_t crc16_ccitt(uint8_t *data, uint16_t length) +{ + uint8_t i; + uint16_t crc = 0; // Initial value + while(length--) + { + crc ^= *data++; // crc ^= *data; data++; + for (i = 0; i < 8; ++i) + { + if (crc & 1) + crc = (crc >> 1) ^ 0x8408; // 0x8408 = reverse 
0x1021 + else + crc = (crc >> 1); + } + } + return crc; +} + +/****************************************************************************** + * Name: CRC-16/CCITT-FALSE x16+x12+x5+1 + * Poly: 0x1021 + * Init: 0xFFFF + * Refin: False + * Refout: False + * Xorout: 0x0000 + * Note: + *****************************************************************************/ +uint16_t crc16_ccitt_false(uint8_t *data, uint16_t length) +{ + uint8_t i; + uint16_t crc = 0xffff; //Initial value + while(length--) + { + crc ^= (uint16_t)(*data++) << 8; // crc ^= (uint6_t)(*data)<<8; data++; + for (i = 0; i < 8; ++i) + { + if ( crc & 0x8000 ) + crc = (crc << 1) ^ 0x1021; + else + crc <<= 1; + } + } + return crc; +} + +/****************************************************************************** + * Name: CRC-16/X25 x16+x12+x5+1 + * Poly: 0x1021 + * Init: 0xFFFF + * Refin: True + * Refout: True + * Xorout: 0XFFFF + * Note: + *****************************************************************************/ +uint16_t crc16_x25(uint8_t *data, uint16_t length) +{ + uint8_t i; + uint16_t crc = 0xffff; // Initial value + while(length--) + { + crc ^= *data++; // crc ^= *data; data++; + for (i = 0; i < 8; ++i) + { + if (crc & 1) + crc = (crc >> 1) ^ 0x8408; // 0x8408 = reverse 0x1021 + else + crc = (crc >> 1); + } + } + return ~crc; // crc^Xorout +} + +/****************************************************************************** + * Name: CRC-16/XMODEM x16+x12+x5+1 + * Poly: 0x1021 + * Init: 0x0000 + * Refin: False + * Refout: False + * Xorout: 0x0000 + * Alias: CRC-16/ZMODEM,CRC-16/ACORN + *****************************************************************************/ +uint16_t crc16_xmodem(uint8_t *data, uint16_t length) +{ + uint8_t i; + uint16_t crc = 0; // Initial value + while(length--) + { + crc ^= (uint16_t)(*data++) << 8; // crc ^= (uint16_t)(*data)<<8; data++; + for (i = 0; i < 8; ++i) + { + if ( crc & 0x8000 ) + crc = (crc << 1) ^ 0x1021; + else + crc <<= 1; + } + } + return crc; +} + +/****************************************************************************** + * Name: CRC-16/DNP x16+x13+x12+x11+x10+x8+x6+x5+x2+1 + * Poly: 0x3D65 + * Init: 0x0000 + * Refin: True + * Refout: True + * Xorout: 0xFFFF + * Use: M-Bus,ect. + *****************************************************************************/ +uint16_t crc16_dnp(uint8_t *data, uint16_t length) +{ + uint8_t i; + uint16_t crc = 0; // Initial value + while(length--) + { + crc ^= *data++; // crc ^= *data; data++; + for (i = 0; i < 8; ++i) + { + if (crc & 1) + crc = (crc >> 1) ^ 0xA6BC; // 0xA6BC = reverse 0x3D65 + else + crc = (crc >> 1); + } + } + return ~crc; // crc^Xorout +} + +/****************************************************************************** + * Name: CRC-32 x32+x26+x23+x22+x16+x12+x11+x10+x8+x7+x5+x4+x2+x+1 + * Poly: 0x4C11DB7 + * Init: 0xFFFFFFF + * Refin: True + * Refout: True + * Xorout: 0xFFFFFFF + * Alias: CRC_32/ADCCP + * Use: WinRAR,ect. 
+ *****************************************************************************/ +uint32_t crc32(uint8_t *data, uint16_t length) +{ + uint8_t i; + uint32_t crc = 0xffffffff; // Initial value + while(length--) + { + crc ^= *data++; // crc ^= *data; data++; + for (i = 0; i < 8; ++i) + { + if (crc & 1) + crc = (crc >> 1) ^ 0xEDB88320;// 0xEDB88320= reverse 0x04C11DB7 + else + crc = (crc >> 1); + } + } + return ~crc; +} + +/****************************************************************************** + * Name: CRC-32/MPEG-2 x32+x26+x23+x22+x16+x12+x11+x10+x8+x7+x5+x4+x2+x+1 + * Poly: 0x4C11DB7 + * Init: 0xFFFFFFF + * Refin: False + * Refout: False + * Xorout: 0x0000000 + * Note: + *****************************************************************************/ +uint32_t crc32_mpeg_2(uint8_t *data, uint16_t length) +{ + uint8_t i; + uint32_t crc = 0xffffffff; // Initial value + while(length--) + { + crc ^= (uint32_t)(*data++) << 24;// crc ^=(uint32_t)(*data)<<24; data++; + for (i = 0; i < 8; ++i) + { + if ( crc & 0x80000000 ) + crc = (crc << 1) ^ 0x04C11DB7; + else + crc <<= 1; + } + } + return crc; +} + +} diff --git a/utils/sv_crclib.h b/utils/sv_crclib.h new file mode 100644 index 0000000..6a7c97d --- /dev/null +++ b/utils/sv_crclib.h @@ -0,0 +1,32 @@ +#ifndef __CRCLIB_H__ +#define __CRCLIB_H__ + +#include + +namespace sv { + +uint8_t crc4_itu(uint8_t *data, uint16_t length); +uint8_t crc5_epc(uint8_t *data, uint16_t length); +uint8_t crc5_itu(uint8_t *data, uint16_t length); +uint8_t crc5_usb(uint8_t *data, uint16_t length); +uint8_t crc6_itu(uint8_t *data, uint16_t length); +uint8_t crc7_mmc(uint8_t *data, uint16_t length); +uint8_t crc8(uint8_t *data, uint16_t length); +uint8_t crc8_itu(uint8_t *data, uint16_t length); +uint8_t crc8_rohc(uint8_t *data, uint16_t length); +uint8_t crc8_maxim(uint8_t *data, uint16_t length);//DS18B20 +uint16_t crc16_ibm(uint8_t *data, uint16_t length); +uint16_t crc16_maxim(uint8_t *data, uint16_t length); +uint16_t crc16_usb(uint8_t *data, uint16_t length); +uint16_t crc16_modbus(uint8_t *data, uint16_t length); +uint16_t crc16_ccitt(uint8_t *data, uint16_t length); +uint16_t crc16_ccitt_false(uint8_t *data, uint16_t length); +uint16_t crc16_x25(uint8_t *data, uint16_t length); +uint16_t crc16_xmodem(uint8_t *data, uint16_t length); +uint16_t crc16_dnp(uint8_t *data, uint16_t length); +uint32_t crc32(uint8_t *data, uint16_t length); +uint32_t crc32_mpeg_2(uint8_t *data, uint16_t length); + +} + +#endif // __CRCLIB_H__ diff --git a/utils/sv_util.cpp b/utils/sv_util.cpp new file mode 100644 index 0000000..b37ff19 --- /dev/null +++ b/utils/sv_util.cpp @@ -0,0 +1,139 @@ +#include "sv_util.h" +#include +#include +#include +#include +#include +#include +#include + +using namespace std; + +namespace sv { + + +std::string _get_home() +{ + struct passwd *pw = getpwuid(getuid()); + const char *homedir = pw->pw_dir; + return std::string(homedir); +} + +bool _is_file_exist(std::string& fn) +{ + std::ifstream f(fn); + return f.good(); +} + + +void _get_sys_time(TimeInfo& t_info) +{ + time_t tt = time(NULL); + tm* t = localtime(&tt); + t_info.year = t->tm_year + 1900; + t_info.mon = t->tm_mon + 1; + t_info.day = t->tm_mday; + t_info.hour = t->tm_hour; + t_info.min = t->tm_min; + t_info.sec = t->tm_sec; +} + +std::string _get_time_str() +{ + std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); + std::chrono::system_clock::duration tp = now.time_since_epoch(); + tp -= std::chrono::duration_cast(tp); + + std::time_t tt = 
std::chrono::system_clock::to_time_t(now); + tm t = *std::localtime(&tt); + + char buf[128]; + sprintf(buf, "%4d-%02d-%02d_%02d-%02d-%02d_%03u", t.tm_year + 1900, t.tm_mon + 1, t.tm_mday, t.tm_hour, t.tm_min, t.tm_sec, static_cast(tp / std::chrono::milliseconds(1))); + + return std::string(buf); +} + +vector _split(const string& srcstr, const string& delimeter) +{ + vector ret(0); //use ret save the spilted reault + if (srcstr.empty()) //judge the arguments + { + return ret; + } + string::size_type pos_begin = srcstr.find_first_not_of(delimeter); //find first element of srcstr + + string::size_type dlm_pos; //the delimeter postion + string temp; //use third-party temp to save splited element + while (pos_begin != string::npos) //if not a next of end, continue spliting + { + dlm_pos = srcstr.find(delimeter, pos_begin); //find the delimeter symbol + if (dlm_pos != string::npos) + { + temp = srcstr.substr(pos_begin, dlm_pos - pos_begin); + pos_begin = dlm_pos + delimeter.length(); + } + else + { + temp = srcstr.substr(pos_begin); + pos_begin = dlm_pos; + } + if (!temp.empty()) + ret.push_back(temp); + } + return ret; +} + +bool _startswith(const std::string& str, const std::string& start) +{ + size_t srclen = str.size(); + size_t startlen = start.size(); + if (srclen >= startlen) + { + string temp = str.substr(0, startlen); + if (temp == start) + return true; + } + + return false; +} + +bool _endswith(const std::string& str, const std::string& end) +{ + size_t srclen = str.size(); + size_t endlen = end.size(); + if (srclen >= endlen) + { + string temp = str.substr(srclen - endlen, endlen); + if (temp == end) + return true; + } + + return false; +} + +string _trim(const std::string& str) +{ + string ret; + // find the first position of not start with space or '\t' + string::size_type pos_begin = str.find_first_not_of(" \t"); + if (pos_begin == string::npos) + return str; + + // find the last position of end with space or '\t' + string::size_type pos_end = str.find_last_not_of(" \t"); + if (pos_end == string::npos) + return str; + + ret = str.substr(pos_begin, pos_end - pos_begin); + + return ret; +} + +int _comp_str_idx(const std::string& in_str, const std::string* str_list, int len) { + for (int i = 0; i < len; ++i) { + if (in_str.compare(str_list[i]) == 0) return i; + } + return -1; +} + +} diff --git a/utils/sv_util.h b/utils/sv_util.h new file mode 100644 index 0000000..d46b9db --- /dev/null +++ b/utils/sv_util.h @@ -0,0 +1,34 @@ +#ifndef __SV_UTIL__ +#define __SV_UTIL__ + +#include +#include +#include + + +namespace sv { + + +struct TimeInfo +{ + int year, mon, day, hour, min, sec; +}; + +/************* time-related functions *************/ +void _get_sys_time(TimeInfo& t_info); +std::string _get_time_str(); + +/************* std::string-related functions *************/ +std::vector _split(const std::string& srcstr, const std::string& delimeter); +bool _startswith(const std::string& str, const std::string& start); +bool _endswith(const std::string& str, const std::string& end); +std::string _trim(const std::string& str); +int _comp_str_idx(const std::string& in_str, const std::string* str_list, int len); + +/************* file-related functions ***************/ +std::string _get_home(); +bool _is_file_exist(std::string& fn); + +} + +#endif // __SV_UTIL__ diff --git a/video_io/ffmpeg/bs_common.h b/video_io/ffmpeg/bs_common.h new file mode 100644 index 0000000..e52666a --- /dev/null +++ b/video_io/ffmpeg/bs_common.h @@ -0,0 +1,49 @@ +#pragma once +#include +#include +#include + +// 
获取当前系统启动以来的毫秒数 +static int64_t getCurTime() +{ + // tv_sec (s) tv_nsec (ns-纳秒) + struct timespec now; + clock_gettime(CLOCK_MONOTONIC, &now); + return (now.tv_sec * 1000 + now.tv_nsec / 1000000); +} + + + +struct VideoFrame +{ +public: + enum VideoFrameType + { + BGR = 0, + YUV420P, + + }; + // VideoFrame(VideoFrameType type, int width, int height,int size) + VideoFrame(VideoFrameType type, int width, int height) + { + this->type = type; + this->width = width; + this->height = height; + this->size = width*height*3; + this->data = new uint8_t[this->size]; + } + ~VideoFrame() + { + delete[] this->data; + this->data = nullptr; + } + + VideoFrameType type; + int size; + int width; + int height; + uint8_t *data; +}; + + + diff --git a/video_io/ffmpeg/bs_push_streamer.cpp b/video_io/ffmpeg/bs_push_streamer.cpp new file mode 100644 index 0000000..8e24090 --- /dev/null +++ b/video_io/ffmpeg/bs_push_streamer.cpp @@ -0,0 +1,377 @@ +#include "bs_push_streamer.h" + +/* +amov_rtsp +2914e3c44737811096c5e1797fe5373d12fcdd39 +*/ + + +// char av_error[AV_ERROR_MAX_STRING_SIZE] = { 0 }; +// #define av_err2str(errnum) av_make_error_string(av_error, AV_ERROR_MAX_STRING_SIZE, errnum) + +BsPushStreamer::BsPushStreamer() +{ + +} + +BsPushStreamer::~BsPushStreamer() +{ + mThread->join(); + mThread = nullptr; +} + + +bool BsPushStreamer::setup(std::string name, int width, int height, int fps, std::string encoder, int bitrate = 4) +{ + if (!connect(name, width, height, fps, encoder, bitrate)) + { + std::cout << "BsPushStreamer::setup error\n"; + return false; + } + + mVideoFrame = new VideoFrame(VideoFrame::BGR, width, height); + // std::cout << "BsStreamer::setup Success!\n"; + start(); + return true; +} + +void BsPushStreamer::start() +{ + mThread = new std::thread(BsPushStreamer::encodeVideoAndWriteStreamThread, this); + mThread->native_handle(); + push_running = true; +} + + +bool BsPushStreamer::connect(std::string name, int width, int height, int fps, std::string encoder, int bitrate) +{ + // 初始化上下文 + if (avformat_alloc_output_context2(&mFmtCtx, NULL, "rtsp", name.c_str()) < 0) + { + std::cout << "avformat_alloc_output_context2 error\n"; + return false; + } + + // 初始化视频编码器 + // AVCodec *videoCodec = avcodec_find_encoder(AV_CODEC_ID_H264); + // AVCodec *videoCodec = avcodec_find_encoder_by_name("h264_nvenc"); + + AVCodec *videoCodec = avcodec_find_encoder_by_name(encoder.c_str()); + if (!videoCodec) + { + std::cout << fmt::format("Using encoder:[{}] error!\n", encoder); + videoCodec = avcodec_find_encoder(AV_CODEC_ID_H264); + + if (!videoCodec) + { + std::cout << "avcodec_alloc_context3 error"; + return false; + } + std::cout << "Using default H264 encoder!\n"; + } + + mVideoCodecCtx = avcodec_alloc_context3(videoCodec); + if (!mVideoCodecCtx) + { + std::cout << "avcodec_alloc_context3 error"; + return false; + } + + // 压缩视频bit位大小 300kB + int bit_rate = bitrate * 1024 * 1024 * 8; + + // CBR:Constant BitRate - 固定比特率 + mVideoCodecCtx->flags |= AV_CODEC_FLAG_QSCALE; + mVideoCodecCtx->bit_rate = bit_rate; + mVideoCodecCtx->rc_min_rate = bit_rate; + mVideoCodecCtx->rc_max_rate = bit_rate; + mVideoCodecCtx->bit_rate_tolerance = bit_rate; + + mVideoCodecCtx->codec_id = videoCodec->id; + // 不支持AV_PIX_FMT_BGR24直接进行编码 + mVideoCodecCtx->pix_fmt = AV_PIX_FMT_YUV420P; + mVideoCodecCtx->codec_type = AVMEDIA_TYPE_VIDEO; + mVideoCodecCtx->width = width; + mVideoCodecCtx->height = height; + mVideoCodecCtx->time_base = {1, fps}; + mVideoCodecCtx->framerate = {fps, 1}; + mVideoCodecCtx->gop_size = 12; + 
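+    // Notes on this encoder configuration: bit_rate, rc_min_rate, rc_max_rate
+    // and bit_rate_tolerance are all set to the same value, which approximates
+    // a constant-bitrate (CBR) stream; pix_fmt is YUV420P because the H.264
+    // encoders used here do not accept BGR24 input directly (frames are
+    // converted in bgr24ToYuv420p() before being sent to the encoder); with
+    // time_base {1, fps}, gop_size = 12 requests a keyframe roughly every
+    // 12 frames.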
mVideoCodecCtx->max_b_frames = 0; + mVideoCodecCtx->thread_count = 1; + + // 手动设置PPS + // unsigned char sps_pps[] = { + // 0x00, 0x00, 0x01, 0x67, 0x42, 0x00, 0x2a, 0x96, 0x35, 0x40, 0xf0, 0x04, + // 0x4f, 0xcb, 0x37, 0x01, 0x01, 0x01, 0x40, 0x00, 0x01, 0xc2, 0x00, 0x00, 0x57, + // 0xe4, 0x01, 0x00, 0x00, 0x00, 0x01, 0x68, 0xce, 0x3c, 0x80, 0x00 + // }; + + AVDictionary *video_codec_options = NULL; + av_dict_set(&video_codec_options, "profile", "main", 0); + // av_dict_set(&video_codec_options, "preset", "superfast", 0); + av_dict_set(&video_codec_options, "tune", "fastdecode", 0); + + if (avcodec_open2(mVideoCodecCtx, videoCodec, &video_codec_options) < 0) + { + std::cout << "avcodec_open2 error\n"; + return false; + } + + mVideoStream = avformat_new_stream(mFmtCtx, videoCodec); + if (!mVideoStream) + { + std::cout << "avformat_new_stream error\n"; + return false; + } + mVideoStream->id = mFmtCtx->nb_streams - 1; + // stream的time_base参数非常重要,它表示将现实中的一秒钟分为多少个时间基, 在下面调用avformat_write_header时自动完成 + avcodec_parameters_from_context(mVideoStream->codecpar, mVideoCodecCtx); + mVideoIndex = mVideoStream->id; + + + // open output url + av_dump_format(mFmtCtx, 0, name.c_str(), 1); + if (!(mFmtCtx->oformat->flags & AVFMT_NOFILE)) + { + int ret = avio_open(&mFmtCtx->pb, name.c_str(), AVIO_FLAG_WRITE); + if ( ret < 0) + { + std::cout << fmt::format("avio_open error url: {}\n", name.c_str()); + // std::cout << fmt::format("ret = {} : {}\n", ret, av_err2str(ret)); + std::cout << fmt::format("ret = {}\n", ret); + return false; + } + } + + AVDictionary *fmt_options = NULL; + av_dict_set(&fmt_options, "bufsize", "1024", 0); + av_dict_set(&fmt_options, "rw_timeout", "30000000", 0); // 设置rtmp/http-flv连接超时(单位 us) + av_dict_set(&fmt_options, "stimeout", "30000000", 0); // 设置rtsp连接超时(单位 us) + av_dict_set(&fmt_options, "rtsp_transport", "tcp", 0); + + mFmtCtx->video_codec_id = mFmtCtx->oformat->video_codec; + mFmtCtx->audio_codec_id = mFmtCtx->oformat->audio_codec; + + // 调用该函数会将所有stream的time_base,自动设置一个值,通常是1/90000或1/1000,这表示一秒钟表示的时间基长度 + if (avformat_write_header(mFmtCtx, &fmt_options) < 0) + { + std::cout << "avformat_write_header error\n"; + return false; + } + + return true; +} + +void BsPushStreamer::encodeVideoAndWriteStreamThread(void* arg) +{ + // PushExecutor *executor = (PushExecutor *)arg; + BsPushStreamer *mBsPushStreamer = (BsPushStreamer *)arg; + int width = mBsPushStreamer->mVideoFrame->width; + int height = mBsPushStreamer->mVideoFrame->height; + + // 未编码的视频帧(bgr格式) + // VideoFrame *videoFrame = NULL; + // 未编码视频帧队列当前长度 + // int videoFrameQSize = 0; + + AVFrame *frame_yuv420p = av_frame_alloc(); + frame_yuv420p->format = mBsPushStreamer->mVideoCodecCtx->pix_fmt; + frame_yuv420p->width = width; + frame_yuv420p->height = height; + + int frame_yuv420p_buff_size = av_image_get_buffer_size(AV_PIX_FMT_YUV420P, width, height, 1); + uint8_t *frame_yuv420p_buff = (uint8_t *)av_malloc(frame_yuv420p_buff_size); + av_image_fill_arrays( + frame_yuv420p->data, frame_yuv420p->linesize, + frame_yuv420p_buff, + AV_PIX_FMT_YUV420P, + width, height, 1); + + // 编码后的视频帧 + AVPacket *pkt = av_packet_alloc(); + int64_t encodeSuccessCount = 0; + int64_t frameCount = 0; + + int64_t t1 = 0; + int64_t t2 = 0; + int ret = -1; + + auto cnt_time = std::chrono::system_clock::now().time_since_epoch(); + auto last_update_time = std::chrono::system_clock::now().time_since_epoch(); + + while (mBsPushStreamer->push_running) + { + // if (mBsPushStreamer->getVideoFrame(videoFrame, videoFrameQSize)) + if (mBsPushStreamer->nd_push_frame) 
+ { + mBsPushStreamer->nd_push_frame = false; + // frame_bgr 转 frame_yuv420p + mBsPushStreamer->bgr24ToYuv420p(mBsPushStreamer->mVideoFrame->data, width, height, frame_yuv420p_buff); + + frame_yuv420p->pts = frame_yuv420p->pkt_dts = av_rescale_q_rnd( + frameCount, + mBsPushStreamer->mVideoCodecCtx->time_base, + mBsPushStreamer->mVideoStream->time_base, + (AVRounding)(AV_ROUND_NEAR_INF | AV_ROUND_PASS_MINMAX)); + + frame_yuv420p->pkt_duration = av_rescale_q_rnd( + 1, + mBsPushStreamer->mVideoCodecCtx->time_base, + mBsPushStreamer->mVideoStream->time_base, + (AVRounding)(AV_ROUND_NEAR_INF | AV_ROUND_PASS_MINMAX)); + + frame_yuv420p->pkt_pos = -1; + + t1 = getCurTime(); + ret = avcodec_send_frame(mBsPushStreamer->mVideoCodecCtx, frame_yuv420p); + if (ret >= 0) + { + ret = avcodec_receive_packet(mBsPushStreamer->mVideoCodecCtx, pkt); + if (ret >= 0) + { + t2 = getCurTime(); + encodeSuccessCount++; + + // 如果实际推流的是flv文件,不会执行里面的fix_packet_pts + if (pkt->pts == AV_NOPTS_VALUE) + { + std::cout << "pkt->pts == AV_NOPTS_VALUE\n"; + } + pkt->stream_index = mBsPushStreamer->mVideoIndex; + + pkt->pos = -1; + pkt->duration = frame_yuv420p->pkt_duration; + + ret = mBsPushStreamer->writePkt(pkt); + if (ret < 0) + { + std::cout << fmt::format("writePkt : ret = {}\n", ret); + } + } + else + { + // std::cout << fmt::format("avcodec_receive_packet error : ret = {}\n", ret); + } + } + else + { + std::cout << fmt::format("avcodec_send_frame error : ret = {}\n", ret); + } + + frameCount++; + } + else + { + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + } + } + // std::cout << fmt::format("push_running is false!\n"); + // std::cout << fmt::format("end stream!\n"); + + // av_write_trailer(mFmtCtx); //写文件尾 + + av_packet_unref(pkt); + pkt = NULL; + + av_free(frame_yuv420p_buff); + frame_yuv420p_buff = NULL; + + av_frame_free(&frame_yuv420p); + // av_frame_unref(frame_yuv420p); + frame_yuv420p = NULL; +} + +int BsPushStreamer::writePkt(AVPacket* pkt) { + mWritePkt_mtx.lock(); + int ret = av_write_frame(mFmtCtx, pkt); + mWritePkt_mtx.unlock(); + + return ret; + +} + +bool BsPushStreamer::getVideoFrame(VideoFrame *&frame, int &frameQSize) +{ + mRGB_VideoFrameQ_mtx.lock(); + + if (!mRGB_VideoFrameQ.empty()) + { + frame = mRGB_VideoFrameQ.front(); + mRGB_VideoFrameQ.pop(); + frameQSize = mRGB_VideoFrameQ.size(); + mRGB_VideoFrameQ_mtx.unlock(); + return true; + } + else + { + frameQSize = 0; + mRGB_VideoFrameQ_mtx.unlock(); + return false; + } +} + +// void BsPushStreamer::pushVideoFrame(unsigned char *data, int width,int height) +void BsPushStreamer::stream(cv::Mat& image) +{ + + int size = image.cols * image.rows * image.channels(); + // VideoFrame* frame = new VideoFrame(VideoFrame::BGR, image.cols, image.rows, size); + memcpy(mVideoFrame->data, image.data, size); + + mRGB_VideoFrameQ_mtx.lock(); + nd_push_frame = true; + // mRGB_VideoFrameQ.push(frame); + mRGB_VideoFrameQ_mtx.unlock(); +} + +bool BsPushStreamer::videoFrameQisEmpty() +{ + return mRGB_VideoFrameQ.empty(); +} + +unsigned char BsPushStreamer::clipValue(unsigned char x, unsigned char min_val, unsigned char max_val) +{ + if (x > max_val) { return max_val; } + else if (x < min_val) { return min_val; } + else { return x; } +} + +bool BsPushStreamer::bgr24ToYuv420p(unsigned char *bgrBuf, int w, int h, unsigned char *yuvBuf) +{ + + unsigned char *ptrY, *ptrU, *ptrV, *ptrRGB; + memset(yuvBuf, 0, w * h * 3 / 2); + ptrY = yuvBuf; + ptrU = yuvBuf + w * h; + ptrV = ptrU + (w * h * 1 / 4); + unsigned char y, u, v, r, g, b; + + for (int j = 0; j < h; 
++j) + { + + ptrRGB = bgrBuf + w * j * 3; + for (int i = 0; i < w; i++) + { + b = *(ptrRGB++); + g = *(ptrRGB++); + r = *(ptrRGB++); + + y = (unsigned char)((66 * r + 129 * g + 25 * b + 128) >> 8) + 16; + u = (unsigned char)((-38 * r - 74 * g + 112 * b + 128) >> 8) + 128; + v = (unsigned char)((112 * r - 94 * g - 18 * b + 128) >> 8) + 128; + *(ptrY++) = clipValue(y, 0, 255); + if (j % 2 == 0 && i % 2 == 0) + { + *(ptrU++) = clipValue(u, 0, 255); + } + else + { + if (i % 2 == 0) + { + *(ptrV++) = clipValue(v, 0, 255); + } + } + } + } + return true; +} diff --git a/video_io/ffmpeg/bs_push_streamer.h b/video_io/ffmpeg/bs_push_streamer.h new file mode 100644 index 0000000..e92268c --- /dev/null +++ b/video_io/ffmpeg/bs_push_streamer.h @@ -0,0 +1,93 @@ +#pragma once + +#include +#include +#include + + +#include +// #include + +#include +#include +extern "C" +{ +#include +#include +#include +#include +// #include +#include +} + +#include + +#include "bs_common.h" + + +class BsPushStreamer +{ +public: + BsPushStreamer(); + ~BsPushStreamer(); + + // 用于初始化视频推流,仅调用一次 + bool setup(std::string name, int width, int height, int fps, std::string encoder, int bitrate); + // 推流一帧图像,在循环中被调用 + void stream(cv::Mat& image); + + + + + // 连接流媒体服务器 + bool connect(std::string name, int width, int height, int fps, std::string encoder, int bitrate); + void start(); + void stop(){push_running = false;}; + + // 编码视频帧并推流 + static void encodeVideoAndWriteStreamThread(void* arg); + + bool videoFrameQisEmpty(); + + int writePkt(AVPacket *pkt); + + + // 上下文 + AVFormatContext *mFmtCtx = nullptr; + // 视频帧 + AVCodecContext *mVideoCodecCtx = NULL; + AVStream *mVideoStream = NULL; + + VideoFrame* mVideoFrame = NULL; + + + int mVideoIndex = -1; + + // YAML::Node yaml_cfg; + +private: + + + // 从mRGB_VideoFrameQ里面获取RGBframe + bool getVideoFrame(VideoFrame *&frame, int &frameQSize); + + + // bgr24转yuv420p + unsigned char clipValue(unsigned char x, unsigned char min_val, unsigned char max_val); + bool bgr24ToYuv420p(unsigned char *bgrBuf, int w, int h, unsigned char *yuvBuf); + + + bool push_running = false; + bool nd_push_frame = false; + + // 视频帧 + std::queue mRGB_VideoFrameQ; + std::mutex mRGB_VideoFrameQ_mtx; + + + // 推流锁 + std::mutex mWritePkt_mtx; + std::thread* mThread; + + +}; \ No newline at end of file diff --git a/video_io/ffmpeg/bs_video_saver.cpp b/video_io/ffmpeg/bs_video_saver.cpp new file mode 100644 index 0000000..37f74fb --- /dev/null +++ b/video_io/ffmpeg/bs_video_saver.cpp @@ -0,0 +1,392 @@ +#include "bs_video_saver.h" + +/* +amov_rtsp +53248e16cc899903cf296df468977c60d7d73aa7 +*/ + +// char av_error[AV_ERROR_MAX_STRING_SIZE] = { 0 }; +// #define av_err2str(errnum) av_make_error_string(av_error, AV_ERROR_MAX_STRING_SIZE, errnum) + +BsVideoSaver::BsVideoSaver() +{ + +} + +BsVideoSaver::~BsVideoSaver() +{ + +} + + +bool BsVideoSaver::setup(std::string name, int width, int height, int fps, std::string encoder, int bitrate = 4) +{ + // 重置状态然后初始化 + this->width = width; + this->height = height; + + // 线程停止 + if(mThread != nullptr) + { + this->stop(); + } + + // 编码器重置 + if (mVideoCodecCtx != NULL) + { + avcodec_free_context(&mVideoCodecCtx); + } + + + if (!this->init(name, width, height, fps, encoder, bitrate)) + { + std::cout << "BsVideoSaver::setup error\n"; + return false; + } + + + std::cout << "BsStreamer::setup Success!\n"; + start(); + return true; +} + +void BsVideoSaver::start() +{ + mThread = new std::thread(BsVideoSaver::encodeVideoAndSaveThread, this); + mThread->native_handle(); + push_running = true; +} + 
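+// Typical use of BsVideoSaver (a minimal sketch; the file name, frame size,
+// encoder string and bitrate below are placeholders, not values required by
+// this class): call setup() once, then write() for every BGR cv::Mat frame,
+// and stop() when recording ends, e.g.
+//   BsVideoSaver saver;
+//   saver.setup("record.avi", 1280, 720, 30, "h264_nvenc", 4);
+//   while (capturing) { saver.write(frame); }
+//   saver.stop();
+// start() launches encodeVideoAndSaveThread(), which drains the queued
+// frames, encodes them and writes the resulting packets through writePkt();
+// stop() clears push_running and joins that worker thread.
+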
+void BsVideoSaver::stop() +{ + if (mThread != nullptr) + { + push_running = false; + mThread->join(); + mThread = nullptr; + } +} + +bool BsVideoSaver::init(std::string name, int width, int height, int fps, std::string encoder, int bitrate) +{ + // 初始化上下文 + if (avformat_alloc_output_context2(&mFmtCtx, NULL, NULL, name.c_str()) < 0) + { + std::cout << "avformat_alloc_output_context2 error\n"; + return false; + } + + // 初始化视频编码器 + // AVCodec *videoCodec = avcodec_find_encoder(AV_CODEC_ID_H264); + // AVCodec *videoCodec = avcodec_find_encoder_by_name("h264_nvenc"); + + AVCodec *videoCodec = avcodec_find_encoder_by_name(encoder.c_str()); + if (!videoCodec) + { + std::cout << fmt::format("Using encoder:[{}] error!\n", encoder); + videoCodec = avcodec_find_encoder(AV_CODEC_ID_H264); + + if (!videoCodec) + { + std::cout << "avcodec_alloc_context3 error"; + return false; + } + + std::cout << "Using default H264 encoder!\n"; + + } + + mVideoCodecCtx = avcodec_alloc_context3(videoCodec); + if (!mVideoCodecCtx) + { + std::cout << "avcodec_alloc_context3 error"; + return false; + } + + // 压缩视频bit位大小 300kB + int bit_rate = bitrate * 1024 * 1024 * 8; + + // CBR:Constant BitRate - 固定比特率 + mVideoCodecCtx->flags |= AV_CODEC_FLAG_QSCALE; + mVideoCodecCtx->bit_rate = bit_rate; + mVideoCodecCtx->rc_min_rate = bit_rate; + mVideoCodecCtx->rc_max_rate = bit_rate; + mVideoCodecCtx->bit_rate_tolerance = bit_rate; + + mVideoCodecCtx->codec_id = videoCodec->id; + // 不支持AV_PIX_FMT_BGR24直接进行编码 + mVideoCodecCtx->pix_fmt = AV_PIX_FMT_YUV420P; + mVideoCodecCtx->codec_type = AVMEDIA_TYPE_VIDEO; + mVideoCodecCtx->width = width; + mVideoCodecCtx->height = height; + mVideoCodecCtx->time_base = {1, fps}; + mVideoCodecCtx->framerate = {fps, 1}; + mVideoCodecCtx->gop_size = 12; + mVideoCodecCtx->max_b_frames = 0; + mVideoCodecCtx->thread_count = 1; + + + AVDictionary *video_codec_options = NULL; + av_dict_set(&video_codec_options, "profile", "main", 0); + // av_dict_set(&video_codec_options, "preset", "superfast", 0); + av_dict_set(&video_codec_options, "tune", "fastdecode", 0); + + if (avcodec_open2(mVideoCodecCtx, videoCodec, &video_codec_options) < 0) + { + std::cout << "avcodec_open2 error\n"; + return false; + } + + mVideoStream = avformat_new_stream(mFmtCtx, videoCodec); + if (!mVideoStream) + { + std::cout << "avformat_new_stream error\n"; + return false; + } + mVideoStream->id = mFmtCtx->nb_streams - 1; + // stream的time_base参数非常重要,它表示将现实中的一秒钟分为多少个时间基, 在下面调用avformat_write_header时自动完成 + avcodec_parameters_from_context(mVideoStream->codecpar, mVideoCodecCtx); + mVideoIndex = mVideoStream->id; + + + // open output url + av_dump_format(mFmtCtx, 0, name.c_str(), 1); + if (!(mFmtCtx->oformat->flags & AVFMT_NOFILE)) + { + int ret = avio_open(&mFmtCtx->pb, name.c_str(), AVIO_FLAG_WRITE); + if ( ret < 0) + { + std::cout << fmt::format("avio_open error url: {}\n", name.c_str()); + // std::cout << fmt::format("ret = {} : {}\n", ret, av_err2str(ret)); + std::cout << fmt::format("ret = {}\n", ret); + return false; + } + } + + AVDictionary *fmt_options = NULL; + av_dict_set(&fmt_options, "bufsize", "1024", 0); + + + mFmtCtx->video_codec_id = mFmtCtx->oformat->video_codec; + mFmtCtx->audio_codec_id = mFmtCtx->oformat->audio_codec; + + // 调用该函数会将所有stream的time_base,自动设置一个值,通常是1/90000或1/1000,这表示一秒钟表示的时间基长度 + if (avformat_write_header(mFmtCtx, &fmt_options) < 0) + { + std::cout << "avformat_write_header error\n"; + return false; + } + + return true; +} + +void BsVideoSaver::encodeVideoAndSaveThread(void* arg) +{ + // PushExecutor 
*executor = (PushExecutor *)arg; + BsVideoSaver *mBsVideoSaver = (BsVideoSaver *)arg; + int width = mBsVideoSaver->width; + int height = mBsVideoSaver->height; + + // 未编码的视频帧(bgr格式) + VideoFrame *videoFrame = NULL; + // 未编码视频帧队列当前长度 + int videoFrameQSize = 0; + + AVFrame *frame_yuv420p = av_frame_alloc(); + frame_yuv420p->format = mBsVideoSaver->mVideoCodecCtx->pix_fmt; + frame_yuv420p->width = width; + frame_yuv420p->height = height; + + int frame_yuv420p_buff_size = av_image_get_buffer_size(AV_PIX_FMT_YUV420P, width, height, 1); + uint8_t *frame_yuv420p_buff = (uint8_t *)av_malloc(frame_yuv420p_buff_size); + av_image_fill_arrays( + frame_yuv420p->data, frame_yuv420p->linesize, + frame_yuv420p_buff, + AV_PIX_FMT_YUV420P, + width, height, 1); + + // 编码后的视频帧 + AVPacket *pkt = av_packet_alloc(); + int64_t encodeSuccessCount = 0; + int64_t frameCount = 0; + + int64_t t1 = 0; + int64_t t2 = 0; + int ret = -1; + + while (mBsVideoSaver->push_running) + { + if (mBsVideoSaver->getVideoFrame(videoFrame, videoFrameQSize)) + { + + // frame_bgr 转 frame_yuv420p + mBsVideoSaver->bgr24ToYuv420p(videoFrame->data, width, height, frame_yuv420p_buff); + + frame_yuv420p->pts = frame_yuv420p->pkt_dts = av_rescale_q_rnd( + frameCount, + mBsVideoSaver->mVideoCodecCtx->time_base, + mBsVideoSaver->mVideoStream->time_base, + (AVRounding)(AV_ROUND_NEAR_INF | AV_ROUND_PASS_MINMAX)); + + frame_yuv420p->pkt_duration = av_rescale_q_rnd( + 1, + mBsVideoSaver->mVideoCodecCtx->time_base, + mBsVideoSaver->mVideoStream->time_base, + (AVRounding)(AV_ROUND_NEAR_INF | AV_ROUND_PASS_MINMAX)); + + frame_yuv420p->pkt_pos = frameCount; + + t1 = getCurTime(); + ret = avcodec_send_frame(mBsVideoSaver->mVideoCodecCtx, frame_yuv420p); + if (ret >= 0) + { + ret = avcodec_receive_packet(mBsVideoSaver->mVideoCodecCtx, pkt); + if (ret >= 0) + { + t2 = getCurTime(); + encodeSuccessCount++; + + pkt->stream_index = mBsVideoSaver->mVideoIndex; + + pkt->pos = frameCount; + pkt->duration = frame_yuv420p->pkt_duration; + + ret = mBsVideoSaver->writePkt(pkt); + + if (ret < 0) + { + std::cout << fmt::format("writePkt : ret = {}\n", ret); + } + } + else + { + // std::cout << fmt::format("avcodec_receive_packet error : ret = {}\n", ret); + } + } + else + { + std::cout << fmt::format("avcodec_send_frame error : ret = {}\n", ret); + } + + frameCount++; + + // 释放资源 + delete videoFrame; + videoFrame = NULL; + } + else + { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + } + // std::cout << fmt::format("push_running is false!\n"); + // std::cout << fmt::format("end stream!\n"); + + //写文件尾 + av_write_trailer(mBsVideoSaver->mFmtCtx); + + av_packet_unref(pkt); + pkt = NULL; + + av_free(frame_yuv420p_buff); + frame_yuv420p_buff = NULL; + + av_frame_free(&frame_yuv420p); + // av_frame_unref(frame_yuv420p); + frame_yuv420p = NULL; + +} + +int BsVideoSaver::writePkt(AVPacket* pkt) { + mWritePkt_mtx.lock(); + int ret = av_write_frame(mFmtCtx, pkt); + mWritePkt_mtx.unlock(); + + return ret; + +} + +bool BsVideoSaver::getVideoFrame(VideoFrame *&frame, int &frameQSize) +{ + + mRGB_VideoFrameQ_mtx.lock(); + + if (!mRGB_VideoFrameQ.empty()) + { + frame = mRGB_VideoFrameQ.front(); + mRGB_VideoFrameQ.pop(); + frameQSize = mRGB_VideoFrameQ.size(); + mRGB_VideoFrameQ_mtx.unlock(); + return true; + } + else + { + frameQSize = 0; + mRGB_VideoFrameQ_mtx.unlock(); + return false; + } +} + +void BsVideoSaver::write(cv::Mat& image) +{ + + int size = image.cols * image.rows * image.channels(); + VideoFrame* frame = new VideoFrame(VideoFrame::BGR, image.cols, 
image.rows); + memcpy(frame->data, image.data, size); + + mRGB_VideoFrameQ_mtx.lock(); + mRGB_VideoFrameQ.push(frame); + mRGB_VideoFrameQ_mtx.unlock(); +} + +bool BsVideoSaver::videoFrameQisEmpty() +{ + return mRGB_VideoFrameQ.empty(); +} + +unsigned char BsVideoSaver::clipValue(unsigned char x, unsigned char min_val, unsigned char max_val) +{ + if (x > max_val) { return max_val; } + else if (x < min_val) { return min_val; } + else { return x; } +} + +bool BsVideoSaver::bgr24ToYuv420p(unsigned char *bgrBuf, int w, int h, unsigned char *yuvBuf) +{ + + unsigned char *ptrY, *ptrU, *ptrV, *ptrRGB; + memset(yuvBuf, 0, w * h * 3 / 2); + ptrY = yuvBuf; + ptrU = yuvBuf + w * h; + ptrV = ptrU + (w * h * 1 / 4); + unsigned char y, u, v, r, g, b; + + for (int j = 0; j < h; ++j) + { + + ptrRGB = bgrBuf + w * j * 3; + for (int i = 0; i < w; i++) + { + b = *(ptrRGB++); + g = *(ptrRGB++); + r = *(ptrRGB++); + + y = (unsigned char)((66 * r + 129 * g + 25 * b + 128) >> 8) + 16; + u = (unsigned char)((-38 * r - 74 * g + 112 * b + 128) >> 8) + 128; + v = (unsigned char)((112 * r - 94 * g - 18 * b + 128) >> 8) + 128; + *(ptrY++) = clipValue(y, 0, 255); + if (j % 2 == 0 && i % 2 == 0) + { + *(ptrU++) = clipValue(u, 0, 255); + } + else + { + if (i % 2 == 0) + { + *(ptrV++) = clipValue(v, 0, 255); + } + } + } + } + return true; +} diff --git a/video_io/ffmpeg/bs_video_saver.h b/video_io/ffmpeg/bs_video_saver.h new file mode 100644 index 0000000..223c22a --- /dev/null +++ b/video_io/ffmpeg/bs_video_saver.h @@ -0,0 +1,90 @@ +#pragma once + +#include +#include +#include + + +#include +// #include + +#include +#include +extern "C" +{ +#include +#include +#include +#include +// #include +#include +} + +#include + +#include "bs_common.h" + + +class BsVideoSaver +{ +public: + BsVideoSaver(); + ~BsVideoSaver(); + + // 用于初始化视频推流,仅调用一次 + bool setup(std::string name, int width, int height, int fps, std::string encoder, int bitrate); + // 推流一帧图像,在循环中被调用 + void write(cv::Mat& image); + + + // 连接流媒体服务器 + bool init(std::string name, int width, int height, int fps, std::string encoder, int bitrate); + void start(); + void stop(); + + // 编码视频帧并推流 + static void encodeVideoAndSaveThread(void* arg); + + bool videoFrameQisEmpty(); + + int writePkt(AVPacket *pkt); + + + // 上下文 + AVFormatContext *mFmtCtx = nullptr; + // 视频帧 + AVCodecContext *mVideoCodecCtx = NULL; + AVStream *mVideoStream = NULL; + + + int mVideoIndex = -1; + + +private: + + // 从mRGB_VideoFrameQ里面获取RGBframe + bool getVideoFrame(VideoFrame *&frame, int &frameQSize); + + + // bgr24转yuv420p + unsigned char clipValue(unsigned char x, unsigned char min_val, unsigned char max_val); + bool bgr24ToYuv420p(unsigned char *bgrBuf, int w, int h, unsigned char *yuvBuf); + + int width = -1; + int height = -1; + + + bool push_running = false; + bool nd_push_frame = false; + + // 视频帧 + std::queue mRGB_VideoFrameQ; + std::mutex mRGB_VideoFrameQ_mtx; + + + // 推流锁 + std::mutex mWritePkt_mtx; + std::thread* mThread = nullptr; + + +}; \ No newline at end of file diff --git a/video_io/gstreamer/streamer_gstreamer_impl.cpp b/video_io/gstreamer/streamer_gstreamer_impl.cpp new file mode 100644 index 0000000..3336338 --- /dev/null +++ b/video_io/gstreamer/streamer_gstreamer_impl.cpp @@ -0,0 +1,102 @@ +#include "streamer_gstreamer_impl.h" +#include +#include + + + +namespace sv { + + +VideoStreamerGstreamerImpl::VideoStreamerGstreamerImpl() +{ +} +VideoStreamerGstreamerImpl::~VideoStreamerGstreamerImpl() +{ +} + +bool VideoStreamerGstreamerImpl::gstreamerSetup(VideoStreamerBase* base_) +{ + 
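+    // Read the stream parameters (RTSP port, mount URL, bitrate in Mbps, frame size)
+    // from the owning VideoStreamerBase, then (when built WITH_GSTREAMER) create a
+    // GstRTSPServer whose mount point re-serves H.264/RTP received on a local UDP
+    // port, fed by the appsrc-based cv::VideoWriter pipeline built further below.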
this->_rtsp_port = base_->getPort(); + this->_url = base_->getUrl(); + this->_bitrate = base_->getBitrate(); + this->_stream_size = base_->getSize(); + +#ifdef WITH_GSTREAMER + int media_port = 5400; + char port_str[8]; + sprintf(port_str, "%d", this->_rtsp_port); + + /* create a server instance */ + this->_server = gst_rtsp_server_new(); + g_object_set(_server, "service", port_str, NULL); + this->_mounts = gst_rtsp_server_get_mount_points(this->_server); + this->_factory = gst_rtsp_media_factory_new(); + + char media_str[512]; + +#ifdef PLATFORM_JETSON + sprintf(media_str, "(udpsrc name=pay0 port=%d caps=\"application/x-rtp, media=(string)video, clock-rate=(int)90000, encoding-name=(string)H264, payload=96 \")", media_port); + gst_rtsp_media_factory_set_launch(this->_factory, media_str); + gst_rtsp_media_factory_set_shared(this->_factory, TRUE); +#else + sprintf(media_str, "(udpsrc name=pay0 port=%d buffer-size=524288 caps=\"application/x-rtp, media=(string)video, clock-rate=(int)90000, encoding-name=(string)H264, payload=96 \")", media_port); + gst_rtsp_media_factory_set_launch(this->_factory, media_str); + gst_rtsp_media_factory_set_shared(this->_factory, TRUE); +#endif + + /* attach the test factory to the /test url */ + gst_rtsp_mount_points_add_factory(this->_mounts, this->_url.c_str(), this->_factory); + /* don't need the ref to the mapper anymore */ + g_object_unref(this->_mounts); + /* attach the server to the default maincontext */ + gst_rtsp_server_attach(this->_server, NULL); + + /* start serving */ + std::cout << "stream ready at rtsp://127.0.0.1:" << this->_rtsp_port << this->_url << std::endl; + + int bitrate = this->_bitrate; + if (bitrate < 1) bitrate = 1; + if (bitrate > 20) bitrate = 20; + + char str_buf[512]; + +#ifdef PLATFORM_JETSON + sprintf(str_buf, "appsrc is-live=true ! videoconvert ! nvvidconv ! video/x-raw(memory:NVMM) ! nvv4l2h264enc insert-sps-pps=true bitrate=%d ! h264parse ! rtph264pay name=pay0 pt=96 ! udpsink host=127.0.0.1 port=%d async=false", bitrate * 1000000, media_port); // omxh264enc +#else + sprintf(str_buf, "appsrc is-live=true ! videoconvert ! x264enc bitrate=%d ! video/x-h264, stream-format=byte-stream ! rtph264pay name=pay0 pt=96 ! 
udpsink host=127.0.0.1 port=%d async=false", bitrate * 1000000, media_port); +#endif + + std::string str(str_buf); + this->_stream_writer = cv::VideoWriter(str, cv::CAP_GSTREAMER, 0, 30, this->_stream_size, true); + + return true; +#endif + return false; +} + +bool VideoStreamerGstreamerImpl::gstreamerIsOpened() +{ +#ifdef WITH_GSTREAMER + return this->_stream_writer.isOpened(); +#endif + return false; +} + +void VideoStreamerGstreamerImpl::gstreamerWrite(cv::Mat img_) +{ +#ifdef WITH_GSTREAMER + this->_stream_writer.write(img_); +#endif +} + +void VideoStreamerGstreamerImpl::gstreamerRelease() +{ +#ifdef WITH_GSTREAMER + if (this->_stream_writer.isOpened()) + this->_stream_writer.release(); +#endif +} + + +} + diff --git a/video_io/gstreamer/streamer_gstreamer_impl.h b/video_io/gstreamer/streamer_gstreamer_impl.h new file mode 100644 index 0000000..1348780 --- /dev/null +++ b/video_io/gstreamer/streamer_gstreamer_impl.h @@ -0,0 +1,48 @@ +#ifndef __SV_STREAM_GSTREAMER_IMPL__ +#define __SV_STREAM_GSTREAMER_IMPL__ + +#include "sv_core.h" +#include +#include +#include +#include +#include + +#ifdef WITH_GSTREAMER +#include +#include +#include +#include // for sockaddr_in +#endif + + +namespace sv { + + +class VideoStreamerGstreamerImpl +{ +public: + VideoStreamerGstreamerImpl(); + ~VideoStreamerGstreamerImpl(); + + bool gstreamerSetup(VideoStreamerBase* base_); + bool gstreamerIsOpened(); + void gstreamerWrite(cv::Mat img_); + void gstreamerRelease(); + + int _rtsp_port; + std::string _url; + int _bitrate; + cv::Size _stream_size; + +#ifdef WITH_GSTREAMER + cv::VideoWriter _stream_writer; + GstRTSPServer *_server; + GstRTSPMountPoints *_mounts; + GstRTSPMediaFactory *_factory; +#endif +}; + + +} +#endif diff --git a/video_io/gstreamer/writer_gstreamer_impl.cpp b/video_io/gstreamer/writer_gstreamer_impl.cpp new file mode 100644 index 0000000..b15cca1 --- /dev/null +++ b/video_io/gstreamer/writer_gstreamer_impl.cpp @@ -0,0 +1,61 @@ +#include "writer_gstreamer_impl.h" +#include +#include + + + +namespace sv { + + +VideoWriterGstreamerImpl::VideoWriterGstreamerImpl() +{ +} +VideoWriterGstreamerImpl::~VideoWriterGstreamerImpl() +{ +} + +bool VideoWriterGstreamerImpl::gstreamerSetup(VideoWriterBase* base_, std::string file_name_) +{ + this->_file_path = base_->getFilePath(); + this->_fps = base_->getFps(); + this->_image_size = base_->getSize(); + +#ifdef WITH_GSTREAMER + bool opend = false; +#ifdef PLATFORM_JETSON + std::string pipeline = "appsrc ! videoconvert ! nvvidconv ! video/x-raw(memory:NVMM) ! nvv4l2h264enc ! h264parse ! matroskamux ! 
filesink location=" + this->_file_path + file_name_ + ".avi"; + opend = this->_writer.open(pipeline, cv::VideoWriter::fourcc('m','p','4','v'), this->_fps, this->_image_size); +#else + opend = this->_writer.open(this->_file_path + file_name_ + ".avi", cv::VideoWriter::fourcc('x','v','i','d'), this->_fps, this->_image_size); +#endif + return opend; +#endif + return false; +} + +bool VideoWriterGstreamerImpl::gstreamerIsOpened() +{ +#ifdef WITH_GSTREAMER + return this->_writer.isOpened(); +#endif + return false; +} + +void VideoWriterGstreamerImpl::gstreamerWrite(cv::Mat img_) +{ +#ifdef WITH_GSTREAMER + this->_writer << img_; +#endif +} + +void VideoWriterGstreamerImpl::gstreamerRelease() +{ +#ifdef WITH_GSTREAMER + if (this->_writer.isOpened()) + this->_writer.release(); +#endif +} + + +} + diff --git a/video_io/gstreamer/writer_gstreamer_impl.h b/video_io/gstreamer/writer_gstreamer_impl.h new file mode 100644 index 0000000..88f18a5 --- /dev/null +++ b/video_io/gstreamer/writer_gstreamer_impl.h @@ -0,0 +1,37 @@ +#ifndef __SV_WRITER_GSTREAMER_IMPL__ +#define __SV_WRITER_GSTREAMER_IMPL__ + +#include "sv_core.h" +#include +#include +#include +#include +#include + + +namespace sv { + + +class VideoWriterGstreamerImpl +{ +public: + VideoWriterGstreamerImpl(); + ~VideoWriterGstreamerImpl(); + + bool gstreamerSetup(VideoWriterBase* base_, std::string file_name_); + bool gstreamerIsOpened(); + void gstreamerWrite(cv::Mat img_); + void gstreamerRelease(); + + std::string _file_path; + double _fps; + cv::Size _image_size; + +#ifdef WITH_GSTREAMER + cv::VideoWriter _writer; +#endif +}; + + +} +#endif diff --git a/video_io/sv_video_base.cpp b/video_io/sv_video_base.cpp new file mode 100644 index 0000000..573f744 --- /dev/null +++ b/video_io/sv_video_base.cpp @@ -0,0 +1,1310 @@ +#include "sv_video_base.h" +#include +#include "ellipse_detector.h" +#include "sv_util.h" +#include "sv_crclib.h" + +#define SV_MAX_FRAMES 52000 +typedef unsigned char byte; + + +namespace sv { + + +cv::Ptr _g_dict = nullptr; + + +std::string get_home() +{ + return _get_home(); +} +bool is_file_exist(std::string& fn) +{ + return _is_file_exist(fn); +} +void list_dir(std::string dir, std::vector& files, std::string suffixs, bool r) +{ + yaed::_list_dir(dir, files, suffixs, r); +} + +cv::Mat& _attach_aruco(int id, cv::Mat& img) +{ + cv::Mat marker_img; + std::vector ch(3); + cv::aruco::Dictionary dict = cv::aruco::getPredefinedDictionary(cv::aruco::DICT_5X5_1000); + ch[0] = cv::Mat::zeros(22, 22, CV_8UC1); + ch[1] = cv::Mat::zeros(22, 22, CV_8UC1); + ch[2] = cv::Mat::zeros(22, 22, CV_8UC1); + + ch[0].setTo(cv::Scalar(255)); + ch[1].setTo(cv::Scalar(255)); + ch[2].setTo(cv::Scalar(255)); + cv::Rect inner_roi = cv::Rect(4, 4, 14, 14); + cv::Rect full_roi = cv::Rect(img.cols - 22, img.rows - 22, 22, 22); + + int id_k = id % 1000; + + // dict.drawMarker(id_k, 14, marker_img, 1); + cv::aruco::generateImageMarker(dict, id_k, 14, marker_img, 1); + marker_img.copyTo(ch[0](inner_roi)); + // dict.drawMarker(id_k, 14, marker_img, 1); + cv::aruco::generateImageMarker(dict, id_k, 14, marker_img, 1); + marker_img.copyTo(ch[1](inner_roi)); + // dict.drawMarker(id_k, 14, marker_img, 1); + cv::aruco::generateImageMarker(dict, id_k, 14, marker_img, 1); + marker_img.copyTo(ch[2](inner_roi)); + + cv::merge(ch, marker_img); + marker_img.copyTo(img(full_roi)); + return img; +} + +int _parse_aruco(cv::Mat& img) +{ + int id; + cv::Mat marker_img; + std::vector ch(3); + if (_g_dict == nullptr) + { + _g_dict = new cv::aruco::Dictionary; + *_g_dict = 
cv::aruco::getPredefinedDictionary(cv::aruco::DICT_5X5_1000); + } + cv::Rect full_roi = cv::Rect(img.cols - 22, img.rows - 22, 22, 22); + img(full_roi).copyTo(marker_img); + cv::split(marker_img, ch); + + std::vector id_i; + std::vector id_k; + std::vector id_m; + std::vector > marker_corners; + cv::aruco::detectMarkers(ch[0], _g_dict, marker_corners, id_i); + cv::aruco::detectMarkers(ch[1], _g_dict, marker_corners, id_k); + cv::aruco::detectMarkers(ch[2], _g_dict, marker_corners, id_m); + if (id_i.size() > 0 || id_k.size() > 0 || id_m.size() > 0) + { + if (id_i.size() > 0) + id = id_i[0]; + else if (id_k.size() > 0) + id = id_k[0]; + else if (id_m.size() > 0) + id = id_m[0]; + } + else + { + // std::cout << "error ch0 & ch1" << std::endl; + id = -1; + } + + return id; +} + + +Target::Target() +{ + this->has_hw = false; + this->has_tid = false; + this->has_position = false; + this->has_los = false; + this->has_seg = false; + this->has_box = false; + this->has_ell = false; + this->has_aruco = false; + this->has_yaw = false; +} + + +UDPServer::UDPServer(std::string dest_ip, int port) +{ + this->_sockfd = socket(AF_INET, SOCK_DGRAM, 0); + bzero(&this->_servaddr, sizeof(this->_servaddr)); + this->_servaddr.sin_family = AF_INET; + inet_pton(AF_INET, dest_ip.c_str(), &this->_servaddr.sin_addr); + this->_servaddr.sin_port = htons(port); +} + +UDPServer::~UDPServer() +{ + +} + +void _floatTobytes(float data, byte bytes[]) +{ + int i; + size_t length = sizeof(float); + byte *pdata = (byte*)&data; + for (i = 0; i < length; i++) + { + bytes[i] = *pdata++; + } +} +void _intTobytes(int data, byte bytes[]) +{ + int i; + size_t length = sizeof(int); + byte *pdata = (byte*)&data; + for (i = 0; i < length; i++) + { + bytes[i] = *pdata++; + } +} +void _uint32Tobytes(uint32_t data, byte bytes[]) +{ + int i; + size_t length = sizeof(uint32_t); + byte *pdata = (byte*)&data; + for (i = 0; i < length; i++) + { + bytes[i] = *pdata++; + } +} +void _shortTobytes(unsigned short data, byte bytes[]) +{ + int i; + size_t length = sizeof(unsigned short); + byte *pdata = (byte*)&data; + for (i = 0; i < length; i++) + { + bytes[i] = *pdata++; + } +} + +void UDPServer::send(const TargetsInFrame& tgts_) +{ + byte upd_msg[1024*6]; // max to 100 objects + + upd_msg[0] = 0xFA; + upd_msg[1] = 0xFC; + + upd_msg[4] = (byte) tgts_.type; + upd_msg[5] = 0xFF; + upd_msg[6] = 0xFF; + upd_msg[7] = 0xFF; + upd_msg[8] = 0xFF; + + std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); + std::chrono::system_clock::duration tp = now.time_since_epoch(); + tp -= std::chrono::duration_cast(tp); + unsigned short milliseconds = static_cast(tp / std::chrono::milliseconds(1)); + std::time_t tt = std::chrono::system_clock::to_time_t(now); + tm t = *std::localtime(&tt); + + _shortTobytes((unsigned short) (t.tm_year + 1900), &upd_msg[9]); // year + upd_msg[11] = (byte) (t.tm_mon + 1); // month + upd_msg[12] = (byte) t.tm_mday; // day + upd_msg[13] = (byte) t.tm_hour; // hour + upd_msg[14] = (byte) t.tm_min; // min + upd_msg[15] = (byte) t.tm_sec; // sec + _shortTobytes(milliseconds, &upd_msg[16]); + int index_d1 = 18; + upd_msg[index_d1] = 0x00; + int index_d2 = 19; + upd_msg[index_d2] = 0x00; + int index_d3 = 20; + upd_msg[index_d3] = 0x00; + int index_d4 = 21; + upd_msg[index_d4] = 0x00; + + int max_objs = 100; + if (tgts_.targets.size() < 100) max_objs = (int) tgts_.targets.size(); + + int mp = 22; + _intTobytes(tgts_.frame_id, &upd_msg[mp]); + upd_msg[index_d4] = upd_msg[index_d4] | 0x01; + mp += 4; + 
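+    // Every field below is appended as a raw 4-byte value (host byte order) and its
+    // presence is recorded by setting one bit in the descriptor bytes index_d1..index_d4.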
_intTobytes(tgts_.width, &upd_msg[mp]); + upd_msg[index_d4] = upd_msg[index_d4] | 0x02; + mp += 4; + _intTobytes(tgts_.height, &upd_msg[mp]); + upd_msg[index_d4] = upd_msg[index_d4] | 0x04; + mp += 4; + _intTobytes(max_objs, &upd_msg[mp]); + upd_msg[index_d4] = upd_msg[index_d4] | 0x08; + if (tgts_.has_fps) + { + mp += 4; + _floatTobytes((float) tgts_.fps, &upd_msg[mp]); + upd_msg[index_d4] = upd_msg[index_d4] | 0x10; + } + if (tgts_.has_fov) + { + mp += 4; + _floatTobytes((float) tgts_.fov_x, &upd_msg[mp]); + upd_msg[index_d4] = upd_msg[index_d4] | 0x20; + mp += 4; + _floatTobytes((float) tgts_.fov_y, &upd_msg[mp]); + upd_msg[index_d4] = upd_msg[index_d4] | 0x40; + } + if (tgts_.has_pod_info) + { + mp += 4; + _floatTobytes((float) tgts_.pod_patch, &upd_msg[mp]); + upd_msg[index_d4] = upd_msg[index_d4] | 0x80; + mp += 4; + _floatTobytes((float) tgts_.pod_roll, &upd_msg[mp]); + upd_msg[index_d3] = upd_msg[index_d3] | 0x01; + mp += 4; + _floatTobytes((float) tgts_.pod_yaw, &upd_msg[mp]); + upd_msg[index_d3] = upd_msg[index_d3] | 0x02; + } + if (tgts_.has_uav_pos) + { + mp += 4; + _floatTobytes((float) tgts_.longitude, &upd_msg[mp]); + upd_msg[index_d3] = upd_msg[index_d3] | 0x04; + mp += 4; + _floatTobytes((float) tgts_.latitude, &upd_msg[mp]); + upd_msg[index_d3] = upd_msg[index_d3] | 0x08; + mp += 4; + _floatTobytes((float) tgts_.altitude, &upd_msg[mp]); + upd_msg[index_d3] = upd_msg[index_d3] | 0x10; + } + if (tgts_.has_uav_vel) + { + mp += 4; + _floatTobytes((float) tgts_.uav_vx, &upd_msg[mp]); + upd_msg[index_d3] = upd_msg[index_d3] | 0x20; + mp += 4; + _floatTobytes((float) tgts_.uav_vy, &upd_msg[mp]); + upd_msg[index_d3] = upd_msg[index_d3] | 0x40; + mp += 4; + _floatTobytes((float) tgts_.uav_vz, &upd_msg[mp]); + upd_msg[index_d3] = upd_msg[index_d3] | 0x80; + } + if (tgts_.has_ill) + { + mp += 4; + _floatTobytes((float) tgts_.illumination, &upd_msg[mp]); + upd_msg[index_d2] = upd_msg[index_d2] | 0x01; + } + mp += 4; + + for (int n=0; n_sockfd, upd_msg, mp, 0, (struct sockaddr *)&this->_servaddr, sizeof(this->_servaddr)); +} + + +void drawTargetsInFrame( + cv::Mat& img_, + const TargetsInFrame& tgts_, + bool with_all, + bool with_category, + bool with_tid, + bool with_seg, + bool with_box, + bool with_ell, + bool with_aruco, + bool with_yaw +) +{ + if (tgts_.rois.size() > 0) + { + cv::Mat image_ret; + cv::addWeighted(img_, 0.5, cv::Mat::zeros(cv::Size(img_.cols, img_.rows), CV_8UC3), 0, 0, image_ret); + cv::Rect roi = cv::Rect(tgts_.rois[0].x1, tgts_.rois[0].y1, tgts_.rois[0].x2 - tgts_.rois[0].x1, tgts_.rois[0].y2 - tgts_.rois[0].y1); + img_(roi).copyTo(image_ret(roi)); + image_ret.copyTo(img_); + } + std::vector > aruco_corners; + std::vector aruco_ids; + std::vector ellipses; + for (Target tgt : tgts_.targets) + { + cv::circle(img_, cv::Point(int(tgt.cx * tgts_.width), int(tgt.cy * tgts_.height)), 4, cv::Scalar(0,255,0), 2); + if ((with_all || with_aruco) && tgt.has_aruco) + { + std::vector a_corners; + int a_id; + if (tgt.getAruco(a_id, a_corners)) { aruco_ids.push_back(a_id); aruco_corners.push_back(a_corners); } + } + if ((with_all || with_box) && tgt.has_box) + { + Box b; + tgt.getBox(b); + cv::rectangle(img_, cv::Rect(b.x1, b.y1, b.x2-b.x1+1, b.y2-b.y1+1), cv::Scalar(0,0,255), 1, 1, 0); + if ((with_all || with_category) && tgt.has_category) + { + cv::putText(img_, tgt.category, cv::Point(b.x1, b.y1-4), cv::FONT_HERSHEY_DUPLEX, 0.4, cv::Scalar(255,0,0)); + } + if ((with_all || with_tid) && tgt.has_tid) + { + char tmp[32]; + sprintf(tmp, "TID: %d", tgt.tracked_id); + 
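+                // tracker-ID label sits one text line above the category label (y1-14 vs. y1-4)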
cv::putText(img_, tmp, cv::Point(b.x1, b.y1-14), cv::FONT_HERSHEY_DUPLEX, 0.4, cv::Scalar(0,0,255)); + } + } + if ((with_all || with_ell) && tgt.has_ell) + { + double xc, yc, a, b, rad; + if (tgt.getEllipse(xc, yc, a, b, rad)) + { + ellipses.push_back(yaed::Ellipse(xc, yc, a, b, rad, tgt.score)); + } + } + if ((with_all || with_seg) && tgt.has_seg) + { + cv::Mat mask = tgt.getMask() * 255; + cv::threshold(mask, mask, 127, 255, cv::THRESH_BINARY); + mask.convertTo(mask, CV_8UC1); + + cv::resize(mask, mask, cv::Size(img_.cols, img_.rows)); + std::vector > contours; + std::vector hierarchy; + + cv::findContours(mask, contours, hierarchy, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE); + cv::Mat mask_disp = img_.clone(); + cv::fillPoly(mask_disp, contours, cv::Scalar(255,255,255), cv::LINE_AA); + cv::polylines(img_, contours, true, cv::Scalar(255,255,255), 2, cv::LINE_AA); + + double alpha = 0.6; + cv::addWeighted(img_, alpha, mask_disp, 1.0-alpha, 0, img_); + } + } + if ((with_all || with_aruco) && aruco_ids.size() > 0) + { + cv::aruco::drawDetectedMarkers(img_, aruco_corners, aruco_ids); + } + if ((with_all || with_ell) && ellipses.size() > 0) + { + yaed::EllipseDetector ed; ed.DrawDetectedEllipses(img_, ellipses); + } +} + +std::string Target::getJsonStr() +{ + std::string json_str = "{"; + char buf[1024]; + if (this->has_box) + { + sprintf(buf, "\"box\":[%d,%d,%d,%d],", _b_box.x1, _b_box.y1, _b_box.x2 - _b_box.x1, _b_box.y2 - _b_box.y1); // xywh + json_str += std::string(buf); + } + if (this->has_ell) + { + sprintf(buf, "\"ell\":[%.3f,%.3f,%.3f,%.3f,%.3f],", this->_e_xc, this->_e_yc, this->_e_a, this->_e_b, this->_e_rad); // xyabr + json_str += std::string(buf); + } + if (this->has_yaw) + { + sprintf(buf, "\"yaw\":%.3f,", this->yaw_a); + json_str += std::string(buf); + } + if (this->has_los) + { + sprintf(buf, "\"los\":[%.3f,%.3f],", this->los_ax, this->los_ay); + json_str += std::string(buf); + } + if (this->has_position) + { + sprintf(buf, "\"pos\":[%.3f,%.3f,%.3f],", this->px, this->py, this->pz); + json_str += std::string(buf); + } + if (this->has_tid) + { + sprintf(buf, "\"tid\":%d,", this->tracked_id); + json_str += std::string(buf); + } + if (this->has_category) + { + sprintf(buf, "\"cat\":\"%s\",", this->category.c_str()); + json_str += std::string(buf); + } + sprintf(buf, "\"sc\":%.3f,\"cet\":[%.3f,%.3f]}", this->score, this->cx, this->cy); + json_str += std::string(buf); + return json_str; +} + +std::string TargetsInFrame::getJsonStr() +{ + std::string json_str = "{"; + char buf[1024]; + + if (this->has_fps) + { + sprintf(buf, "\"fps\":%.3f,", this->fps); + json_str += std::string(buf); + } + if (this->has_fov) + { + sprintf(buf, "\"fov\":[%.3f,%.3f],", this->fov_x, this->fov_y); + json_str += std::string(buf); + } + if (this->has_pod_info) + { + sprintf(buf, "\"pod\":[%.3f,%.3f,%.3f],", this->pod_patch, this->pod_roll, this->pod_yaw); + json_str += std::string(buf); + } + if (this->has_uav_pos) + { + sprintf(buf, "\"uav_pos\":[%.7f,%.7f,%.3f],", this->longitude, this->latitude, this->altitude); + json_str += std::string(buf); + } + if (this->has_uav_vel) + { + sprintf(buf, "\"uav_vel\":[%.3f,%.3f,%.3f],", this->uav_vx, this->uav_vy, this->uav_vz); + json_str += std::string(buf); + } + if (this->has_ill) + { + sprintf(buf, "\"ill\":%.3f,", this->illumination); + json_str += std::string(buf); + } + if (this->date_captured.size() > 0) + { + sprintf(buf, "\"time\":\"%s\",", this->date_captured.c_str()); + json_str += std::string(buf); + } + + json_str += "\"rois\":["; + for (int i=0; 
(int)irois.size(); i++) + { + if (i == (int)this->rois.size() - 1) + { + sprintf(buf, "[%d,%d,%d,%d]", this->rois[i].x1, this->rois[i].y1, this->rois[i].x2-this->rois[i].x1, this->rois[i].y2-this->rois[i].y1); + } + else + { + sprintf(buf, "[%d,%d,%d,%d],", this->rois[i].x1, this->rois[i].y1, this->rois[i].x2-this->rois[i].x1, this->rois[i].y2-this->rois[i].y1); + } + json_str += std::string(buf); + } + json_str += "],"; + json_str += "\"tgts\":["; + for (int i=0; (int)itargets.size(); i++) + { + if (i == (int)this->targets.size() - 1) + { + json_str += this->targets[i].getJsonStr(); + } + else + { + json_str += this->targets[i].getJsonStr() + ","; + } + } + json_str += "],"; + sprintf(buf, "\"h\":%d,\"w\":%d,\"fid\":%d}", this->height, this->width, this->frame_id); + json_str += std::string(buf); + return json_str; +} + +bool Target::getEllipse(double& xc_, double& yc_, double& a_, double& b_, double& rad_) +{ + xc_ = this->_e_xc; + yc_ = this->_e_yc; + a_ = this->_e_a; + b_ = this->_e_b; + rad_ = this->_e_rad; + return this->has_ell; +} + +bool Target::getAruco(int& id, std::vector &corners) +{ + id = this->_a_id; + corners = this->_a_corners; + return this->has_aruco; +} + +bool Target::getBox(Box& b) +{ + b = this->_b_box; + return this->has_box; +} + +void Target::setAruco(int id_, std::vector corners_, cv::Vec3d rvecs_, cv::Vec3d tvecs_, int img_w_, int img_h_, cv::Mat camera_matrix_) +{ + this->_a_id = id_; + this->_a_corners = corners_; + this->_a_rvecs = rvecs_; + this->_a_tvecs = tvecs_; + + double x_mid = (corners_[0].x + corners_[1].x) / 2.; + double y_mid = (corners_[0].y + corners_[1].y) / 2.; + + double left = std::min(std::min(corners_[0].x, corners_[1].x), std::min(corners_[2].x, corners_[3].x)); + double right = std::max(std::max(corners_[0].x, corners_[1].x), std::max(corners_[2].x, corners_[3].x)); + double top = std::min(std::min(corners_[0].y, corners_[1].y), std::min(corners_[2].y, corners_[3].y)); + double bottom = std::max(std::max(corners_[0].y, corners_[1].y), std::max(corners_[2].y, corners_[3].y)); + + double x_vec = x_mid - (left + right) / 2.; + double y_vec = y_mid - (top + bottom) / 2.; + + this->setYaw(x_vec, y_vec); + this->setBox(left, top, right, bottom, img_w_, img_h_); + + this->score = 1.; + char cate[256]; + sprintf(cate, "aruco-%d", id_); + this->setCategory(cate, id_); + this->setTrackID(id_); + this->setLOS(this->cx, this->cy, camera_matrix_, img_w_, img_h_); + this->setPosition(tvecs_[0], tvecs_[1], tvecs_[2]); + + this->has_aruco = true; +} + +void Target::setYaw(double vec_x_, double vec_y_) +{ + if (vec_x_ == 0. && vec_y_ > 0.) + { + this->yaw_a = 180; + } + else if (vec_x_ == 0. && vec_y_ < 0.) + { + this->yaw_a = 0; + } + else if (vec_x_ > 0. && vec_y_ == 0.) + { + this->yaw_a = 90; + } + else if (vec_x_ > 0. && vec_y_ > 0.) + { + this->yaw_a = 180 - atan(vec_x_ / vec_y_) * SV_RAD2DEG; + } + else if (vec_x_ > 0. && vec_y_ < 0.) + { + this->yaw_a = atan(vec_x_ / -vec_y_) * SV_RAD2DEG; + } + else if (vec_x_ < 0. && vec_y_ == 0.) + { + this->yaw_a = -90; + } + else if (vec_x_ < 0. && vec_y_ > 0.) + { + this->yaw_a = atan(-vec_x_ / vec_y_) * SV_RAD2DEG - 180; + } + else if (vec_x_ < 0. && vec_y_ < 0.) 
+ { + this->yaw_a = -atan(-vec_x_ / -vec_y_) * SV_RAD2DEG; + } + this->has_yaw = true; +} + +void Target::setEllipse(double xc_, double yc_, double a_, double b_, double rad_, double score_, int img_w_, int img_h_, cv::Mat camera_matrix_, double radius_in_meter_) +{ + this->_e_xc = xc_; + this->_e_yc = yc_; + this->_e_a = a_; + this->_e_b = b_; + this->_e_rad = rad_; + this->has_ell = true; + + this->score = score_; + cv::Rect rect; + yaed::Ellipse ell(xc_, yc_, a_, b_, rad_); + ell.GetRectangle(rect); + this->setBox(rect.x, rect.y, rect.x + rect.width, rect.y + rect.height, img_w_, img_h_); + this->setCategory("ellipse", 0); + this->setLOS(this->cx, this->cy, camera_matrix_, img_w_, img_h_); + + if (radius_in_meter_ > 0) + { + double z = camera_matrix_.at(0, 0) * radius_in_meter_ / b_; + double x = tan(this->los_ax / SV_RAD2DEG) * z; + double y = tan(this->los_ay / SV_RAD2DEG) * z; + this->setPosition(x, y, z); + } +} + +void Target::setLOS(double cx_, double cy_, cv::Mat camera_matrix_, int img_w_, int img_h_) +{ + this->los_ax = atan((cx_ * img_w_ - img_w_ / 2.) / camera_matrix_.at(0, 0)) * SV_RAD2DEG; + this->los_ay = atan((cy_ * img_h_ - img_h_ / 2.) / camera_matrix_.at(1, 1)) * SV_RAD2DEG; + this->has_los = true; +} + +void Target::setCategory(std::string cate_, int cate_id_) +{ + this->category = cate_; + this->category_id = cate_id_; + this->has_category = true; +} + +void Target::setTrackID(int id_) +{ + this->tracked_id = id_; + this->has_tid = true; +} + +void Target::setPosition(double x_, double y_, double z_) +{ + this->px = x_; + this->py = y_; + this->pz = z_; + this->has_position = true; +} + +void Target::setBox(int x1_, int y1_, int x2_, int y2_, int img_w_, int img_h_) +{ + this->_b_box.setXYXY(x1_, y1_, x2_, y2_); + + this->cx = (double)(x2_ + x1_) / 2 / img_w_; + this->cy = (double)(y2_ + y1_) / 2 / img_h_; + this->w = (double)(x2_ - x1_) / img_w_; + this->h = (double)(y2_ - y1_) / img_h_; + + // std::cout << this->cx << ", " << this->cy << ", " << this->w << "," << this->h << std::endl; + + this->has_box = true; + this->has_hw = true; +} + +void Target::setMask(cv::Mat mask_) +{ + this->_mask = mask_; + this->has_seg = true; +} +cv::Mat Target::getMask() +{ + return this->_mask; +} + +TargetsInFrame::TargetsInFrame(int frame_id_) +{ + this->frame_id = frame_id_; + this->height = -1; + this->width = -1; + this->has_fps = false; + this->has_fov = false; + this->has_roi = false; + this->has_pod_info = false; + this->has_uav_pos = false; + this->has_uav_vel = false; + this->has_ill = false; + this->type = MissionType::NONE; +} +void TargetsInFrame::setTimeNow() +{ + this->date_captured = _get_time_str(); +} +void TargetsInFrame::setSize(int width_, int height_) +{ + this->width = width_; + this->height = height_; +} +void TargetsInFrame::setFPS(double fps_) +{ + this->fps = fps_; + this->has_fps = true; +} +void TargetsInFrame::setFOV(double fov_x_, double fov_y_) +{ + this->fov_x = fov_x_; + this->fov_y = fov_y_; + this->has_fov = true; +} + +Box::Box() +{ + +} + + +void Box::setXYXY(int x1_, int y1_, int x2_, int y2_) +{ + x1 = x1_; + y1 = y1_; + x2 = x2_; + y2 = y2_; +} + +void Box::setXYWH(int x_, int y_, int w_, int h_) +{ + x1 = x_; + y1 = y_; + x2 = x_ + w_ - 1; + y2 = y_ + h_ - 1; +} + + +VideoWriterBase::VideoWriterBase() +{ + this->_is_running = false; + this->_fid = 0; + this->_fcnt = 0; +} +VideoWriterBase::~VideoWriterBase() +{ + this->release(); + this->_tt.join(); +} +cv::Size VideoWriterBase::getSize() +{ + return this->_image_size; +} +double 
VideoWriterBase::getFps() +{ + return this->_fps; +} +std::string VideoWriterBase::getFilePath() +{ + return this->_file_path; +} +bool VideoWriterBase::isRunning() +{ + return this->_is_running; +} +void VideoWriterBase::setup(std::string file_path, cv::Size size, double fps, bool with_targets) +{ + this->_file_path = file_path; + this->_fps = fps; + this->_image_size = size; + this->_with_targets = with_targets; + + this->_init(); + + this->_tt = std::thread(&VideoWriterBase::_run, this); + this->_tt.detach(); +} +void VideoWriterBase::write(cv::Mat image, TargetsInFrame tgts) +{ + if (this->_is_running) + { + cv::Mat image_put; + if (this->_image_size.height == image.rows && this->_image_size.width == image.cols) + { + image.copyTo(image_put); + } + else + { + char msg[256]; + sprintf(msg, "SpireCV (106) Input image SIZE (%d, %d) != Saving SIZE (%d, %d)!", image.cols, image.rows, this->_image_size.width, this->_image_size.height); + throw std::runtime_error(msg); + // cv::resize(image, image_put, this->_image_size); + } + + if (this->_targets_ofs) + { + this->_fid ++; + image_put = _attach_aruco(this->_fid, image_put); + tgts.frame_id = this->_fid; + this->_tgts_to_write.push(tgts); + if (this->_fid >= SV_MAX_FRAMES) + this->_fid = 0; + } + this->_image_to_write.push(image_put); + } +} +void VideoWriterBase::_run() +{ + while (this->_is_running && isOpenedImpl()) + { + while (!this->_image_to_write.empty()) + { + this->_fcnt ++; + + cv::Mat img = _image_to_write.front(); + if (this->_targets_ofs) + { + if (!this->_tgts_to_write.empty()) + { + TargetsInFrame tgts = this->_tgts_to_write.front(); + + std::string json_str = tgts.getJsonStr(); + _targets_ofs << json_str << std::endl; + + this->_tgts_to_write.pop(); + } + } + // this->_writer << img; + writeImpl(img); + this->_image_to_write.pop(); + + if (this->_fcnt >= SV_MAX_FRAMES) + { + _init(); + } + } + std::this_thread::sleep_for(std::chrono::milliseconds(int(1000 / this->_fps))); + } +} + +void VideoWriterBase::_init() +{ + this->release(); + + // get now time + time_t t = time(NULL);; + tm* local = localtime(&t); + + char s_buf[128]; + strftime(s_buf, 64, "/FlyVideo_%Y-%m-%d_%H-%M-%S", local); + std::string name = std::string(s_buf); + + bool opend = false; + opend = setupImpl(name); + + if (!opend) + { + std::cout << "Failed to write video: " << _file_path + name << std::endl; + } + else + { + this->_is_running = true; + if (this->_with_targets) + { + this->_targets_ofs.open(this->_file_path + name + ".svj"); + if (!this->_targets_ofs) + { + std::cout << "Failed to write info file: " << this->_file_path << std::endl; + this->_is_running = false; + } + } + } +} +void VideoWriterBase::release() +{ + this->_is_running = false; + this->_fid = 0; + this->_fcnt = 0; + + if (this->_targets_ofs.is_open()) + this->_targets_ofs.close(); + + while (!this->_image_to_write.empty()) + this->_image_to_write.pop(); + while (!this->_tgts_to_write.empty()) + this->_tgts_to_write.pop(); + + releaseImpl(); +} +bool VideoWriterBase::setupImpl(std::string file_name_) +{ + return false; +} +bool VideoWriterBase::isOpenedImpl() +{ + return false; +} +void VideoWriterBase::writeImpl(cv::Mat img_) +{ + +} +void VideoWriterBase::releaseImpl() +{ + +} + + +CameraBase::CameraBase(CameraType type, int id) +{ + this->_is_running = false; + this->_is_updated = false; + this->_type = type; + + this->_width = -1; + this->_height = -1; + this->_fps = -1; + this->_ip = "192.168.2.64"; + this->_port = -1; + this->_brightness = -1; + this->_contrast = -1; + 
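+    // (-1 means "not set"; openImpl() only applies properties whose value is > 0)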
this->_saturation = -1; + this->_hue = -1; + this->_exposure = -1; + + this->open(type, id); +} +CameraBase::~CameraBase() +{ + this->_is_running = false; + this->_tt.join(); +} +void CameraBase::setWH(int width, int height) +{ + this->_width = width; + this->_height = height; +} +void CameraBase::setFps(int fps) +{ + this->_fps = fps; +} +void CameraBase::setIp(std::string ip) +{ + this->_ip = ip; +} +void CameraBase::setPort(int port) +{ + this->_port = port; +} +void CameraBase::setBrightness(double brightness) +{ + this->_brightness = brightness; +} +void CameraBase::setContrast(double contrast) +{ + this->_contrast = contrast; +} +void CameraBase::setSaturation(double saturation) +{ + this->_saturation = saturation; +} +void CameraBase::setHue(double hue) +{ + this->_hue = hue; +} +void CameraBase::setExposure(double exposure) +{ + this->_exposure = exposure; +} + +int CameraBase::getW() +{ + return this->_width; +} +int CameraBase::getH() +{ + return this->_height; +} +int CameraBase::getFps() +{ + return this->_fps; +} +std::string CameraBase::getIp() +{ + return this->_ip; +} +int CameraBase::getPort() +{ + return this->_port; +} +double CameraBase::getBrightness() +{ + return this->_brightness; +} +double CameraBase::getContrast() +{ + return this->_contrast; +} +double CameraBase::getSaturation() +{ + return this->_saturation; +} +double CameraBase::getHue() +{ + return this->_hue; +} +double CameraBase::getExposure() +{ + return this->_exposure; +} +bool CameraBase::isRunning() +{ + return this->_is_running; +} + +void CameraBase::openImpl() +{ + if (this->_type == CameraType::WEBCAM) + { + this->_cap.open(this->_camera_id); + if (this->_width > 0 && this->_height > 0) + { + this->_cap.set(cv::CAP_PROP_FRAME_WIDTH, this->_width); + this->_cap.set(cv::CAP_PROP_FRAME_HEIGHT, this->_height); + } + if (this->_fps > 0) + { + this->_cap.set(cv::CAP_PROP_FPS, this->_fps); + } + if (this->_brightness > 0) + { + this->_cap.set(cv::CAP_PROP_BRIGHTNESS, this->_brightness); + } + if (this->_contrast > 0) + { + this->_cap.set(cv::CAP_PROP_CONTRAST, this->_contrast); + } + if (this->_saturation > 0) + { + this->_cap.set(cv::CAP_PROP_SATURATION, this->_saturation); + } + if (this->_hue > 0) + { + this->_cap.set(cv::CAP_PROP_HUE, this->_hue); + } + if (this->_exposure > 0) + { + this->_cap.set(cv::CAP_PROP_EXPOSURE, this->_exposure); + } + } + else if (this->_type == CameraType::G1) + { + char pipe[512]; + if (this->_width <= 0 || this->_height <= 0) + { + this->_width = 1280; + this->_height = 720; + } + if (this->_port <= 0) + { + this->_port = 554; + } + if (this->_fps <= 0) + { + this->_fps = 30; + } + // sprintf(pipe, "rtsp://%s:%d/H264?W=%d&H=%d&BR=10000000&FPS=%d", this->_ip.c_str(), this->_port, this->_width, this->_height, this->_fps); + sprintf(pipe, "rtspsrc location=rtsp://%s:%d/H264?W=%d&H=%d&FPS=%d&BR=4000000 latency=100 ! application/x-rtp,media=video ! rtph264depay ! parsebin ! nvv4l2decoder enable-max-performancegst=1 ! nvvidconv ! video/x-raw,format=(string)BGRx ! videoconvert ! appsink sync=false", this->_ip.c_str(), this->_port, this->_width, this->_height, this->_fps); + // std::cout << pipe << std::endl; + // this->_cap.open(pipe); // cv::CAP_GSTREAMER + this->_cap.open(pipe, cv::CAP_GSTREAMER); + } +} +void CameraBase::open(CameraType type, int id) +{ + this->_type = type; + this->_camera_id = id; + + openImpl(); + + if (this->_cap.isOpened()) + { + std::cout << "Camera opened!" 
<< std::endl; + this->_is_running = true; + this->_tt = std::thread(&CameraBase::_run, this); + this->_tt.detach(); + } + else if (type != CameraType::NONE) + { + std::cout << "Camera can NOT open!" << std::endl; + } +} +void CameraBase::_run() +{ + while (this->_is_running && this->_cap.isOpened()) + { + this->_cap >> this->_frame; + this->_is_updated = true; + std::this_thread::sleep_for(std::chrono::milliseconds(2)); + } +} +bool CameraBase::read(cv::Mat& image) +{ + if (this->_type == CameraType::WEBCAM || this->_type == CameraType::G1) + { + int n_try = 0; + while (n_try < 5000) + { + if (this->_is_updated) + { + this->_is_updated = false; + this->_frame.copyTo(image); + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + n_try ++; + } + } + if (image.cols == 0 || image.rows == 0) + { + throw std::runtime_error("SpireCV (101) Camera cannot OPEN, check CAMERA_ID!"); + } + return image.cols > 0 && image.rows > 0; +} +void CameraBase::release() +{ + _cap.release(); +} + + + +VideoStreamerBase::VideoStreamerBase() +{ + this->_is_running = false; +} +VideoStreamerBase::~VideoStreamerBase() +{ + this->release(); +} +cv::Size VideoStreamerBase::getSize() +{ + return this->_stream_size; +} +int VideoStreamerBase::getPort() +{ + return this->_port; +} +std::string VideoStreamerBase::getUrl() +{ + return this->_url; +} +int VideoStreamerBase::getBitrate() +{ + return this->_bitrate; +} +bool VideoStreamerBase::isRunning() +{ + return this->_is_running; +} + +void VideoStreamerBase::_run() +{ + while (this->_is_running) + { + if (isOpenedImpl()) + { + if (!this->_image_to_stream.empty()) + { + cv::Mat img = this->_image_to_stream.top(); + writeImpl(img); + + while (!this->_image_to_stream.empty()) + { + this->_image_to_stream.pop(); + } + } + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + } + else + { + std::cout << "VideoStreamer.isOpened(): FALSE" << std::endl; + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + setupImpl(); + } + } + std::cout << "streaming == FALSE" << std::endl; +} + +void VideoStreamerBase::release() +{ + this->_is_running = false; + releaseImpl(); +} + +void VideoStreamerBase::setup(cv::Size size, int port, int bitrate, std::string url) +{ + this->_bitrate = bitrate; + this->_stream_size = size; + + this->_port = port; + this->_url = url; + + if (setupImpl()) + { + // std::cout << "stream ready at rtsp://127.0.0.1:" << this->_port << this->_url << std::endl; + + this->_is_running = true; + this->_tt = std::thread(&VideoStreamerBase::_run, this); + this->_tt.detach(); + } +} + +void VideoStreamerBase::stream(cv::Mat image) +{ + if (this->_is_running) + { + cv::Mat image_stream; + if (this->_stream_size.height == image.rows && this->_stream_size.width == image.cols) + image.copyTo(image_stream); + else + cv::resize(image, image_stream, this->_stream_size); + + this->_image_to_stream.push(image_stream); + } +} + +bool VideoStreamerBase::setupImpl() +{ + return false; +} +bool VideoStreamerBase::isOpenedImpl() +{ + return false; +} +void VideoStreamerBase::writeImpl(cv::Mat image) +{ + +} + +void VideoStreamerBase::releaseImpl() +{ + +} + + +} + diff --git a/video_io/sv_video_input.cpp b/video_io/sv_video_input.cpp new file mode 100644 index 0000000..352850b --- /dev/null +++ b/video_io/sv_video_input.cpp @@ -0,0 +1,77 @@ +#include "sv_video_input.h" +#include +#include + + + +namespace sv { + + +Camera::Camera() +{ +} +Camera::~Camera() +{ +} + + +void Camera::openImpl() +{ + if (this->_type == CameraType::WEBCAM) + { + 
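+        // Local webcam: open by numeric device index and forward any user-set
+        // resolution/FPS/image properties (values <= 0 keep the driver defaults);
+        // the G1 branch below opens the gimbal camera through an RTSP GStreamer pipeline.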
this->_cap.open(this->_camera_id); + if (this->_width > 0 && this->_height > 0) + { + this->_cap.set(cv::CAP_PROP_FRAME_WIDTH, this->_width); + this->_cap.set(cv::CAP_PROP_FRAME_HEIGHT, this->_height); + } + if (this->_fps > 0) + { + this->_cap.set(cv::CAP_PROP_FPS, this->_fps); + } + if (this->_brightness > 0) + { + this->_cap.set(cv::CAP_PROP_BRIGHTNESS, this->_brightness); + } + if (this->_contrast > 0) + { + this->_cap.set(cv::CAP_PROP_CONTRAST, this->_contrast); + } + if (this->_saturation > 0) + { + this->_cap.set(cv::CAP_PROP_SATURATION, this->_saturation); + } + if (this->_hue > 0) + { + this->_cap.set(cv::CAP_PROP_HUE, this->_hue); + } + if (this->_exposure > 0) + { + this->_cap.set(cv::CAP_PROP_EXPOSURE, this->_exposure); + } + } + else if (this->_type == CameraType::G1) + { + char pipe[512]; + if (this->_width <= 0 || this->_height <= 0) + { + this->_width = 1280; + this->_height = 720; + } + if (this->_port <= 0) + { + this->_port = 554; + } + if (this->_fps <= 0) + { + this->_fps = 30; + } + + sprintf(pipe, "rtspsrc location=rtsp://%s:%d/H264?W=%d&H=%d&FPS=%d&BR=4000000 latency=100 ! application/x-rtp,media=video ! rtph264depay ! parsebin ! nvv4l2decoder enable-max-performancegst=1 ! nvvidconv ! video/x-raw,format=(string)BGRx ! videoconvert ! appsink sync=false", this->_ip.c_str(), this->_port, this->_width, this->_height, this->_fps); + this->_cap.open(pipe, cv::CAP_GSTREAMER); + } +} + + +} + diff --git a/video_io/sv_video_output.cpp b/video_io/sv_video_output.cpp new file mode 100644 index 0000000..4a0e2ad --- /dev/null +++ b/video_io/sv_video_output.cpp @@ -0,0 +1,151 @@ +#include "sv_video_output.h" +#include +#include +#ifdef WITH_GSTREAMER +#include "streamer_gstreamer_impl.h" +#include "writer_gstreamer_impl.h" +#endif +#ifdef WITH_FFMPEG +#include "bs_push_streamer.h" +#include "bs_video_saver.h" +#endif + + +namespace sv { + + +VideoWriter::VideoWriter() +{ +#ifdef WITH_GSTREAMER + this->_gstreamer_impl = new VideoWriterGstreamerImpl; +#endif +#ifdef WITH_FFMPEG + this->_ffmpeg_impl = new BsVideoSaver; +#endif +} +VideoWriter::~VideoWriter() +{ +} + +bool VideoWriter::setupImpl(std::string file_name_) +{ + cv::Size img_sz = this->getSize(); + double fps = this->getFps(); + std::string file_path = this->getFilePath(); + +#ifdef WITH_GSTREAMER + return this->_gstreamer_impl->gstreamerSetup(this, file_name_); +#endif +#ifdef WITH_FFMPEG +#ifdef PLATFORM_X86_CUDA + std::string enc = "h264_nvenc"; +#else + std::string enc = ""; +#endif + return this->_ffmpeg_impl->setup(file_path + file_name_ + ".avi", img_sz.width, img_sz.height, (int)fps, enc, 4); +#endif + return false; +} + +bool VideoWriter::isOpenedImpl() +{ +#ifdef WITH_GSTREAMER + return this->_gstreamer_impl->gstreamerIsOpened(); +#endif +#ifdef WITH_FFMPEG + return this->isRunning(); +#endif + return false; +} + +void VideoWriter::writeImpl(cv::Mat img_) +{ +#ifdef WITH_GSTREAMER + this->_gstreamer_impl->gstreamerWrite(img_); +#endif +#ifdef WITH_FFMPEG + this->_ffmpeg_impl->write(img_); +#endif +} + +void VideoWriter::releaseImpl() +{ +#ifdef WITH_GSTREAMER + this->_gstreamer_impl->gstreamerRelease(); +#endif +#ifdef WITH_FFMPEG + this->_ffmpeg_impl->stop(); +#endif +} + + + +VideoStreamer::VideoStreamer() +{ +#ifdef WITH_GSTREAMER + this->_gstreamer_impl = new VideoStreamerGstreamerImpl; +#endif +#ifdef WITH_FFMPEG + this->_ffmpeg_impl = new BsPushStreamer; +#endif +} +VideoStreamer::~VideoStreamer() +{ +} + +bool VideoStreamer::setupImpl() +{ + cv::Size img_sz = this->getSize(); + int port = this->getPort(); 
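+    // With GStreamer the impl starts its own RTSP server on this port; with FFmpeg the
+    // frames are pushed to an external RTSP server at "rtsp://127.0.0.1/live" + url.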
+ std::string url = this->getUrl(); + int bitrate = this->getBitrate(); + +#ifdef WITH_GSTREAMER + return this->_gstreamer_impl->gstreamerSetup(this); +#endif +#ifdef WITH_FFMPEG + std::string rtsp_url = "rtsp://127.0.0.1/live" + url; +#ifdef PLATFORM_X86_CUDA + std::string enc = "h264_nvenc"; +#else + std::string enc = ""; +#endif + return this->_ffmpeg_impl->setup(rtsp_url, img_sz.width, img_sz.height, 24, enc, bitrate); +#endif + return false; +} + +bool VideoStreamer::isOpenedImpl() +{ +#ifdef WITH_GSTREAMER + return this->_gstreamer_impl->gstreamerIsOpened(); +#endif +#ifdef WITH_FFMPEG + return this->isRunning(); +#endif + return false; +} + +void VideoStreamer::writeImpl(cv::Mat img_) +{ +#ifdef WITH_GSTREAMER + this->_gstreamer_impl->gstreamerWrite(img_); +#endif +#ifdef WITH_FFMPEG + this->_ffmpeg_impl->stream(img_); +#endif +} + +void VideoStreamer::releaseImpl() +{ +#ifdef WITH_GSTREAMER + this->_gstreamer_impl->gstreamerRelease(); +#endif +#ifdef WITH_FFMPEG + this->_ffmpeg_impl->stop(); +#endif +} + + +} +
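Putting the pieces of this patch together, the intended call pattern for the video I/O classes appears to be: open a camera, then feed each captured frame to a VideoWriter (file recording) and/or a VideoStreamer (RTSP push). Below is a minimal sketch under stated assumptions: the webcam index, output directory, port and mount URL are illustrative, and it assumes include/sv_world.h aggregates the public headers added by this commit.

#include "sv_world.h"   // assumption: umbrella header pulling in the video_io classes

int main()
{
    // Camera: open() starts the internal capture thread; read() waits for a fresh frame.
    sv::Camera cam;
    cam.setWH(1280, 720);
    cam.setFps(30);
    cam.open(sv::CameraType::WEBCAM, 0);          // device index is illustrative

    // File recording: a timestamped FlyVideo_* file is created inside this directory.
    sv::VideoWriter writer;
    writer.setup("/home/user/videos", cv::Size(1280, 720), 30, false);

    // RTSP streaming: served at rtsp://127.0.0.1:8554/live (port and mount are illustrative).
    sv::VideoStreamer streamer;
    streamer.setup(cv::Size(1280, 720), 8554, 2, "/live");

    cv::Mat frame;
    int frame_id = 0;
    while (cam.read(frame))
    {
        sv::TargetsInFrame tgts(frame_id++);      // per-frame metadata container
        tgts.setSize(frame.cols, frame.rows);
        writer.write(frame, tgts);                // queued, encoded on the writer thread
        streamer.stream(frame);                   // queued, pushed on the streamer thread
    }
    return 0;
}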