diff --git a/utils/super/error.go b/utils/super/error.go index 226b84c..9d1b2b8 100644 --- a/utils/super/error.go +++ b/utils/super/error.go @@ -2,36 +2,67 @@ package super import ( "errors" + "sync" ) var errorMapper = make(map[error]int) var errorMapperRef = make(map[error]error) +var mutex sync.Mutex // RegError 通过错误码注册错误,返回错误的引用 func RegError(code int, message string) error { if code == 0 { - panic("error code can not be 0") + return errors.New("error code can not be 0") } - err := errors.New(message) + mutex.Lock() + defer mutex.Unlock() + err := &ser{code: code, message: message} errorMapper[err] = code return err } // RegErrorRef 通过错误码注册错误,返回错误的引用 +// - 引用将会被重定向到注册的错误信息 func RegErrorRef(code int, message string, ref error) error { if code == 0 { - panic("error code can not be 0") + return errors.New("error code can not be 0") } - err := errors.New(message) + mutex.Lock() + defer mutex.Unlock() + err := &ser{code: code, message: message} errorMapper[err] = code errorMapperRef[ref] = err return ref } -// GetErrorCode 通过错误引用获取错误码,如果错误不存在则返回 0 -func GetErrorCode(err error) (int, error) { - if ref, exist := errorMapperRef[err]; exist { +// GetError 通过错误引用获取错误码和真实错误信息,如果错误不存在则返回 0,如果错误引用不存在则返回原本的错误 +func GetError(err error) (int, error) { + unw := errors.Unwrap(err) + if unw == nil { + unw = err + } + mutex.Lock() + defer mutex.Unlock() + if ref, exist := errorMapperRef[unw]; exist { + //err = fmt.Errorf("%w : %s", ref, err.Error()) err = ref } - return errorMapper[err], err + unw = errors.Unwrap(err) + if unw == nil { + unw = err + } + code, exist := errorMapper[unw] + if !exist { + return 0, errors.New("error not found") + } + return code, err +} + +type ser struct { + code int + message string +} + +func (slf *ser) Error() string { + return slf.message } diff --git a/utils/super/error_test.go b/utils/super/error_test.go new file mode 100644 index 0000000..32b417f --- /dev/null +++ b/utils/super/error_test.go @@ -0,0 +1,13 @@ +package super_test + +import ( + "errors" + "github.com/kercylan98/minotaur/utils/super" + "testing" +) + +func TestGetError(t *testing.T) { + var ErrNotFound = errors.New("not found") + var _ = super.RegErrorRef(100, "test error", ErrNotFound) + t.Log(super.GetError(ErrNotFound)) +}