From 51fe4bd0c31a870ee13bf51dbf9779cd399e8c91 Mon Sep 17 00:00:00 2001 From: wangjiaming0909 Date: Wed, 12 Feb 2025 14:53:03 +0800 Subject: [PATCH] decimal use int256/int128 --- source/libs/decimal/CMakeLists.txt | 3 +- source/libs/decimal/inc/wideInteger.h | 119 +- source/libs/decimal/src/decimal.c | 159 ++- source/libs/decimal/src/detail/CMakeLists.txt | 15 + .../libs/decimal/src/detail/intx/int128.hpp | 885 ++++++++++++ source/libs/decimal/src/detail/intx/intx.hpp | 1221 +++++++++++++++++ .../libs/decimal/src/detail/wideInteger.cpp | 244 ++++ source/libs/decimal/src/wideInteger.c | 26 - source/libs/decimal/test/decimalTest.cpp | 350 ++++- source/libs/executor/CMakeLists.txt | 2 +- source/libs/executor/src/operator.c | 3 + source/libs/scalar/src/filter.c | 1 + source/libs/scalar/src/sclvector.c | 84 +- 13 files changed, 2963 insertions(+), 149 deletions(-) create mode 100644 source/libs/decimal/src/detail/CMakeLists.txt create mode 100644 source/libs/decimal/src/detail/intx/int128.hpp create mode 100644 source/libs/decimal/src/detail/intx/intx.hpp create mode 100644 source/libs/decimal/src/detail/wideInteger.cpp delete mode 100644 source/libs/decimal/src/wideInteger.c diff --git a/source/libs/decimal/CMakeLists.txt b/source/libs/decimal/CMakeLists.txt index 75c4874461..440953216b 100644 --- a/source/libs/decimal/CMakeLists.txt +++ b/source/libs/decimal/CMakeLists.txt @@ -9,9 +9,10 @@ target_include_directories( ) target_link_libraries( decimal - PRIVATE os common + PRIVATE os common wideInteger ) if(${BUILD_TEST}) ADD_SUBDIRECTORY(test) endif(${BUILD_TEST}) +ADD_SUBDIRECTORY(src/detail) diff --git a/source/libs/decimal/inc/wideInteger.h b/source/libs/decimal/inc/wideInteger.h index 8f1e341f91..1f2970dc09 100644 --- a/source/libs/decimal/inc/wideInteger.h +++ b/source/libs/decimal/inc/wideInteger.h @@ -16,54 +16,118 @@ #ifndef _TD_WIDE_INTEGER_H_ #define _TD_WIDE_INTEGER_H_ +#include #include -#include "tdef.h" #ifdef __cplusplus extern "C" { #endif -typedef 
struct _UInt128 { +typedef struct uint128 { uint64_t low; uint64_t high; -} _UInt128; +} uint128; -typedef struct Int256 { - char data[32]; -} Int256; +struct int128 { + uint64_t low; + int64_t high; +}; -// TODO wjm use cmake to check if the compiler supports __int128_t -#if defined(__GNUC__) || defined(__clang__) -// #if 0 -typedef __uint128_t UInt128; -#else -typedef _UInt128 UInt128; -#define Int128 UInt128 -#endif +typedef struct uint256 { + uint128 low; + uint128 high; +} uint256; // TODO wjm remove typedef + +struct int256 { + uint128 low; + struct int128 high; +}; + +#define UInt128 uint128 +#define UInt256 uint256 +#define Int128 struct int128 +#define Int256 struct int256 #define SAFE_SIGNED_OP(a, b, SIGNED_TYPE, UNSIGNED_TYPE, OP) (SIGNED_TYPE)((UNSIGNED_TYPE)(a)OP(UNSIGNED_TYPE)(b)) #define SAFE_INT64_ADD(a, b) SAFE_SIGNED_OP(a, b, int64_t, uint64_t, +) #define SAFE_INT64_SUBTRACT(a, b) SAFE_SIGNED_OP(a, b, int64_t, uint64_t, -) -void makeUInt128(UInt128* pInt, DecimalWord hi, DecimalWord lo); +void makeUInt128(UInt128* pInt, uint64_t hi, uint64_t lo); uint64_t uInt128Hi(const UInt128* pInt); uint64_t uInt128Lo(const UInt128* pInt); - -void uInt128Abs(UInt128* pInt); -void uInt128Add(UInt128* pLeft, const UInt128* pRight); -void uInt128Subtract(UInt128* pLeft, const UInt128* pRight); -void uInt128Multiply(UInt128* pLeft, const UInt128* pRight); -void uInt128Divide(UInt128* pLeft, const UInt128* pRight); -void uInt128Mod(UInt128* pLeft, const UInt128* pRight); -bool uInt128Lt(const UInt128* pLeft, const UInt128* pRight); -bool uInt128Gt(const UInt128* pLeft, const UInt128* pRight); -bool uInt128Eq(const UInt128* pLeft, const UInt128* pRight); +void uInt128Add(UInt128* pLeft, const UInt128* pRight); +void uInt128Subtract(UInt128* pLeft, const UInt128* pRight); +void uInt128Multiply(UInt128* pLeft, const UInt128* pRight); +void uInt128Divide(UInt128* pLeft, const UInt128* pRight); +void uInt128Mod(UInt128* pLeft, const UInt128* pRight); +bool 
uInt128Lt(const UInt128* pLeft, const UInt128* pRight); +bool uInt128Gt(const UInt128* pLeft, const UInt128* pRight); +bool uInt128Eq(const UInt128* pLeft, const UInt128* pRight); extern const UInt128 uInt128_1e18; extern const UInt128 uInt128Zero; extern const uint64_t k1e18; +extern const UInt128 uInt128One; +extern const UInt128 uInt128Two; + +Int128 makeInt128(int64_t high, uint64_t low); +int64_t int128Hi(const Int128* pUint128); +uint64_t int128Lo(const Int128* pUint128); +Int128 int128Abs(const Int128* pInt128); +Int128 int128Negate(const Int128* pInt128); +Int128 int128Add(const Int128* pLeft, const Int128* pRight); +Int128 int128Subtract(const Int128* pLeft, const Int128* pRight); +Int128 int128Multiply(const Int128* pLeft, const Int128* pRight); +Int128 int128Divide(const Int128* pLeft, const Int128* pRight); +Int128 int128Mod(const Int128* pLeft, const Int128* pRight); +bool int128Lt(const Int128* pLeft, const Int128* pRight); +bool int128Gt(const Int128* pLeft, const Int128* pRight); +bool int128Eq(const Int128* pLeft, const Int128* pRight); +Int128 int128RightShift(const Int128* pLeft, int32_t shift); + +extern const Int128 int128Zero; +extern const Int128 int128One; + +UInt256 makeUint256(UInt128 high, UInt128 low); +UInt128 uInt256Hi(const UInt256* pUint256); +UInt128 uInt256Lo(const UInt256* pUint256); +UInt256 uInt256Add(const UInt256* pLeft, const UInt256* pRight); +UInt256 uInt256Subtract(const UInt256* pLeft, const UInt256* pRight); +UInt256 uInt256Multiply(const UInt256* pLeft, const UInt256* pRight); +UInt256 uInt256Divide(const UInt256* pLeft, const UInt256* pRight); +UInt256 uInt256Mod(const UInt256* pLeft, const UInt256* pRight); +bool uInt256Lt(const UInt256* pLeft, const UInt256* pRight); +bool uInt256Gt(const UInt256* pLeft, const UInt256* pRight); +bool uInt256Eq(const UInt256* pLeft, const UInt256* pRight); +UInt256 uInt256RightShift(const UInt256* pLeft, int32_t shift); + +extern const UInt256 uInt256Zero; +extern const UInt256 
uInt256One; + +Int256 makeInt256(Int128 high, UInt128 low);// TODO wjm all params should be high then low +Int128 int256Hi(const Int256* pUint256); +UInt128 int256Lo(const Int256* pUint256); +Int256 int256Abs(const Int256* pInt256); +Int256 int256Negate(const Int256* pInt256); +Int256 int256Add(const Int256* pLeft, const Int256* pRight); +Int256 int256Subtract(const Int256* pLeft, const Int256* pRight); +Int256 int256Multiply(const Int256* pLeft, const Int256* pRight); +Int256 int256Divide(const Int256* pLeft, const Int256* pRight); +Int256 int256Mod(const Int256* pLeft, const Int256* pRight); +bool int256Lt(const Int256* pLeft, const Int256* pRight); +bool int256Gt(const Int256* pLeft, const Int256* pRight); +bool int256Eq(const Int256* pLeft, const Int256* pRight); +Int256 int256RightShift(const Int256* pLeft, int32_t shift); + +extern const Int256 int256Zero; +extern const Int256 int256One; +extern const Int256 int256Two; + +#ifdef __cplusplus +} +#endif static inline int32_t countLeadingZeros(uint64_t v) { -#if defined(__clang__) || defined(__GUNC__) +#if defined(__clang__) || defined(__GNUC__) if (v == 0) return 64; return __builtin_clzll(v); #else @@ -76,8 +140,5 @@ static inline int32_t countLeadingZeros(uint64_t v) { #endif } -#ifdef __cplusplus -} -#endif #endif /* _TD_WIDE_INTEGER_H_ */ diff --git a/source/libs/decimal/src/decimal.c b/source/libs/decimal/src/decimal.c index 603fd34f0c..28fb9e82c4 100644 --- a/source/libs/decimal/src/decimal.c +++ b/source/libs/decimal/src/decimal.c @@ -63,6 +63,13 @@ static const uint8_t typeConvertDecimalPrec[] = { int32_t decimalGetRetType(const SDataType* pLeftT, const SDataType* pRightT, EOperatorType opType, SDataType* pOutType) { + if (pLeftT->type == TSDB_DATA_TYPE_JSON || pRightT->type == TSDB_DATA_TYPE_JSON || + pLeftT->type == TSDB_DATA_TYPE_VARBINARY || pRightT->type == TSDB_DATA_TYPE_VARBINARY) + return TSDB_CODE_TSC_INVALID_OPERATION; + if ((pLeftT->type >= TSDB_DATA_TYPE_BLOB && pLeftT->type <= 
TSDB_DATA_TYPE_GEOMETRY) || + (pRightT->type >= TSDB_DATA_TYPE_BLOB && pRightT->type <= TSDB_DATA_TYPE_GEOMETRY)) { + return TSDB_CODE_TSC_INVALID_OPERATION; + } if (IS_FLOAT_TYPE(pLeftT->type) || IS_FLOAT_TYPE(pRightT->type) || IS_VAR_DATA_TYPE(pLeftT->type) || IS_VAR_DATA_TYPE(pRightT->type)) { pOutType->type = TSDB_DATA_TYPE_DOUBLE; @@ -75,8 +82,6 @@ int32_t decimalGetRetType(const SDataType* pLeftT, const SDataType* pRightT, EOp pOutType->bytes = tDataTypes[TSDB_DATA_TYPE_NULL].bytes; return 0; } - - // TODO wjm check not supported types uint8_t p1 = pLeftT->precision, s1 = pLeftT->scale, p2 = pRightT->precision, s2 = pRightT->scale; if (!IS_DECIMAL_TYPE(pLeftT->type)) { @@ -466,7 +471,7 @@ static const Decimal128 SCALE_MULTIPLIER_128[TSDB_DECIMAL128_MAX_PRECISION + 1] DEFINE_DECIMAL128(4003012203950112768ULL, 542101086242752LL), DEFINE_DECIMAL128(3136633892082024448ULL, 5421010862427522LL), DEFINE_DECIMAL128(12919594847110692864ULL, 54210108624275221LL), - DEFINE_DECIMAL128(68739955140067328ULL, 542101086242752217LL), + DEFINE_DECIMAL128(68739955140067328ULL, 542101086242752217LL), // TODO wjm TEST it DEFINE_DECIMAL128(687399551400673280ULL, 5421010862427522170LL), }; @@ -518,7 +523,7 @@ static void decimal128Negate(DecimalType* pWord) { Decimal128* pDec = (Decimal128*)pWord; uint64_t lo = ~DECIMAL128_LOW_WORD(pDec) + 1; int64_t hi = ~DECIMAL128_HIGH_WORD(pDec); - if (lo == 0) hi = SAFE_INT64_ADD(hi, 1); + if (lo == 0) hi = SAFE_INT64_ADD(hi, 1); // TODO wjm test if overflow? makeDecimal128(pDec, hi, lo); } @@ -567,7 +572,7 @@ static void decimal128Multiply(DecimalType* pLeft, const DecimalType* pRight, ui bool negate = DECIMAL128_SIGN(pLeftDec) != DECIMAL128_SIGN(pRightDec); Decimal128 x = *pLeftDec, y = *pRightDec; decimal128Abs(&x); - decimal128Abs(&y); + decimal128Abs(&y); // TODO wjm use too much abs, optimize it. 
UInt128 res = {0}, tmp = {0}; makeUInt128(&res, DECIMAL128_HIGH_WORD(&x), DECIMAL128_LOW_WORD(&x)); @@ -615,7 +620,8 @@ static void decimal128Mod(DecimalType* pLeft, const DecimalType* pRight, uint8_t Decimal128 pLeftDec = *(Decimal128*)pLeft, *pRightDec = (Decimal128*)pRight, right = {0}; DECIMAL128_CHECK_RIGHT_WORD_NUM(rightWordNum, pRightDec, right, pRight); - decimal128Divide(&pLeftDec, pRightDec, WORD_NUM(Decimal128), pLeft); + decimal128Divide(&pLeftDec, pRightDec, WORD_NUM(Decimal128), + pLeft); // TODO wjm test it pLeft and pRemainder use the same pointer } static bool decimal128Gt(const DecimalType* pLeft, const DecimalType* pRight, uint8_t rightWordNum) { @@ -642,8 +648,8 @@ static void extractDecimal128Digits(const Decimal128* pDec, uint64_t* digits, in *digitNum = 0; makeUInt128(&a, DECIMAL128_HIGH_WORD(pDec), DECIMAL128_LOW_WORD(pDec)); while (!uInt128Eq(&a, &uInt128Zero)) { - uint64_t hi = a >> 64; // TODO wjm use function, UInt128 may be a struct. - uint64_t lo = a; + uint64_t hi = uInt128Hi(&a); + uint64_t lo = uInt128Lo(&a); uint64_t hiQuotient = hi / k1e18; uint64_t hiRemainder = hi % k1e18; @@ -793,6 +799,66 @@ static void decimalAdd(Decimal* pX, const SDataType* pXT, const Decimal* pY, con } } +static void makeInt256FromDecimal128(Int256* pTarget, const Decimal128* pDec) { + bool negative = DECIMAL128_SIGN(pDec) == -1; + Decimal128 abs = *pDec; + decimal128Abs(&abs); + UInt128 tmp = {DECIMAL128_LOW_WORD(&abs), DECIMAL128_HIGH_WORD(&abs)}; + *pTarget = makeInt256(int128Zero, tmp); + if (negative) { + int256Negate(pTarget); + } +} + +static Int256 int256ScaleBy(const Int256* pX, int32_t scale) { + Int256 result = *pX; + if (scale > 0) { + Int256 multiplier = {0}; + makeInt256FromDecimal128(&multiplier, &SCALE_MULTIPLIER_128[scale]); + result = int256Multiply(pX, &multiplier); + } else if (scale < 0) { + Int256 divisor = {0}; + makeInt256FromDecimal128(&divisor, &SCALE_MULTIPLIER_128[-scale]); + result = int256Divide(pX, &divisor); + Int256 
remainder = int256Mod(pX, &divisor); + Int256 afterShift = int256RightShift(&divisor, 1); + remainder = int256Abs(&remainder); + if (int256Gt(&remainder, &afterShift)) { + if (int256Gt(&result, &int256Zero)) { + result = int256Add(&result, &int256One); + } else { + result = int256Subtract(&result, &int256One); + } + } + } + return result; +} + +static bool convertInt256ToDecimal128(const Int256* pX, Decimal128* pDec) { + bool overflow = false; + Int256 abs = int256Abs(pX); + bool isNegative = int256Lt(pX, &int256Zero); + UInt128 low = int256Lo(&abs); + uint64_t lowLow= uInt128Lo(&low); + uint64_t lowHigh = uInt128Hi(&low); + Int256 afterShift = int256RightShift(&abs, 128); + + if (int256Gt(&afterShift, &int256Zero)) { + overflow = true; + } else if (lowHigh > INT64_MAX) { + overflow = true; + } else { + makeDecimal128(pDec, lowHigh, lowLow); + if (decimal128Gt(pDec, &decimal128Max, WORD_NUM(Decimal128))) { + overflow = true; + } + } + if (isNegative) { + decimal128Negate(pDec); + } + return overflow; +} + static int32_t decimalMultiply(Decimal* pX, const SDataType* pXT, const Decimal* pY, const SDataType* pYT, const SDataType* pOT) { if (pOT->precision < TSDB_DECIMAL_MAX_PRECISION) { @@ -819,7 +885,15 @@ static int32_t decimalMultiply(Decimal* pX, const SDataType* pXT, const Decimal* int32_t leadingZeros = decimal128CountLeadingBinaryZeros(&xAbs) + decimal128CountLeadingBinaryZeros(&yAbs); if (leadingZeros <= 128) { // need to trim scale - return TSDB_CODE_DECIMAL_OVERFLOW; + Int256 x256 = {0}, y256 = {0}; + makeInt256FromDecimal128(&x256, pX); + makeInt256FromDecimal128(&y256, pY); + Int256 res = int256Multiply(&x256, &y256); + if (deltaScale != 0) { + res = int256ScaleBy(&res, -deltaScale); + } + bool overflow = convertInt256ToDecimal128(&res, pX); + if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW; } else { // no need to trim scale if (deltaScale <= 38) { @@ -834,7 +908,7 @@ static int32_t decimalMultiply(Decimal* pX, const SDataType* pXT, const Decimal* return 
0; } -int32_t decimalDivide(Decimal* pX, const SDataType* pXT, const Decimal* pY, const SDataType* pYT, +static int32_t decimalDivide(Decimal* pX, const SDataType* pXT, const Decimal* pY, const SDataType* pYT, const SDataType* pOT) { if (decimal128Eq(pY, &DECIMAL128_ZERO, WORD_NUM(Decimal))) { return TSDB_CODE_DECIMAL_OVERFLOW; // TODO wjm divide zero error @@ -860,10 +934,68 @@ int32_t decimalDivide(Decimal* pX, const SDataType* pXT, const Decimal* pY, cons Decimal64 extra = {(DECIMAL128_SIGN(pX) ^ DECIMAL128_SIGN(pY)) + 1}; decimal128Add(&xTmp, &extra, WORD_NUM(Decimal64)); } + *pX = xTmp; } else { - return TSDB_CODE_DECIMAL_OVERFLOW; + Int256 x256 = {0}, y256 = {0}; + makeInt256FromDecimal128(&x256, pX); + Int256 xScaledUp = int256ScaleBy(&x256, deltaScale); + makeInt256FromDecimal128(&y256, pY); + Int256 res = int256Divide(&xScaledUp, &y256); + Int256 remainder = int256Mod(&xScaledUp, &y256); + + remainder = int256Multiply(&remainder, &int256Two); + remainder = int256Abs(&remainder); + y256 = int256Abs(&y256); + if (!int256Lt(&remainder, &y256)) { + if ((DECIMAL128_SIGN(pX) ^ DECIMAL128_SIGN(pY)) == 0) { + res = int256Add(&res, &int256One); + } else { + res = int256Subtract(&res, &int256One); + } + } + bool overflow = convertInt256ToDecimal128(&res, pX); + if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW; + } + return 0; +} + +static int32_t decimalMod(Decimal* pX, const SDataType* pXT, const Decimal* pY, const SDataType* pYT, + const SDataType* pOT) { + if (decimal128Eq(pY, &DECIMAL128_ZERO, WORD_NUM(Decimal))) { + return TSDB_CODE_DECIMAL_OVERFLOW; // TODO wjm mod zero error + } + Decimal xAbs = *pX, yAbs = *pY; + decimal128Abs(&xAbs); + decimal128Abs(&yAbs); + int32_t xlz = decimal128CountLeadingBinaryZeros(&xAbs), ylz = decimal128CountLeadingBinaryZeros(&yAbs); + if (pXT->scale < pYT->scale) { + // x scale up + xlz = xlz - bitsForNumDigits[pYT->scale - pXT->scale]; + } else if (pXT->scale > pYT->scale) { + // y scale up + ylz = ylz - 
bitsForNumDigits[pXT->scale - pYT->scale]; + } + int32_t lz = TMIN(xlz, ylz); + if (lz >= 2) { + // it's safe to scale up + yAbs = *pY; + decimal128ScaleTo(pX, pXT->scale, TMAX(pXT->scale, pYT->scale)); + decimal128ScaleTo(&yAbs, pYT->scale, TMAX(pXT->scale, pYT->scale)); + decimal128Mod(pX, &yAbs, WORD_NUM(Decimal)); + } else { + Int256 x256 = {0}, y256 = {0}; + makeInt256FromDecimal128(&x256, pX); + makeInt256FromDecimal128(&y256, pY); + if (pXT->scale < pYT->scale) { + x256 = int256ScaleBy(&x256, pYT->scale - pXT->scale); + } else if (pXT->scale > pYT->scale) { + y256 = int256ScaleBy(&y256, pXT->scale - pYT->scale); + } + Int256 res = int256Mod(&x256, &y256); + if (convertInt256ToDecimal128(&res, pX)) { + return TSDB_CODE_DECIMAL_OVERFLOW; + } } - *pX = xTmp; return 0; } @@ -910,6 +1042,9 @@ int32_t decimalOp(EOperatorType op, const SDataType* pLeftT, const SDataType* pR case OP_TYPE_DIV: code = decimalDivide(&left, <, &right, &rt, pOutT); break; + case OP_TYPE_REM: + code = decimalMod(&left, <, &right, &rt, pOutT); + break; default: code = TSDB_CODE_TSC_INVALID_OPERATION; break; diff --git a/source/libs/decimal/src/detail/CMakeLists.txt b/source/libs/decimal/src/detail/CMakeLists.txt new file mode 100644 index 0000000000..fb2b866db1 --- /dev/null +++ b/source/libs/decimal/src/detail/CMakeLists.txt @@ -0,0 +1,15 @@ +MESSAGE(STATUS "Building decimal/src/detail") +aux_source_directory(. 
WIDE_INTEGER_SRC) + +SET(CMAKE_CXX_STANDARD 14) +add_library(wideInteger STATIC ${WIDE_INTEGER_SRC}) + +target_include_directories( + wideInteger + PUBLIC "${TD_SOURCE_DIR}/source/libs/decimal/inc/" + PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/intx/" +) +target_link_libraries( + wideInteger + PUBLIC stdc++ +) \ No newline at end of file diff --git a/source/libs/decimal/src/detail/intx/int128.hpp b/source/libs/decimal/src/detail/intx/int128.hpp new file mode 100644 index 0000000000..b351c1e4cf --- /dev/null +++ b/source/libs/decimal/src/detail/intx/int128.hpp @@ -0,0 +1,885 @@ +// intx: extended precision integer library. +// Copyright 2019-2020 Pawel Bylica. +// Licensed under the Apache License, Version 2.0. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _MSC_VER + #include +#endif + +#ifdef _MSC_VER + #define INTX_UNREACHABLE __assume(0) +#else + #define INTX_UNREACHABLE __builtin_unreachable() +#endif + +#ifdef _MSC_VER + #define INTX_UNLIKELY(EXPR) (bool{EXPR}) +#else + #define INTX_UNLIKELY(EXPR) __builtin_expect(bool{EXPR}, false) +#endif + +#ifdef NDEBUG + #define INTX_REQUIRE(X) (X) ? (void)0 : INTX_UNREACHABLE +#else + #include + #define INTX_REQUIRE assert +#endif + +namespace intx +{ +template +struct uint; + +/// The 128-bit unsigned integer. +/// +/// This type is defined as a specialization of uint<> to easier integration with full intx package, +/// however, uint128 may be used independently. 
+template <> +struct uint<128> +{ + static constexpr unsigned num_bits = 128; + + uint64_t lo = 0; + uint64_t hi = 0; + + constexpr uint() noexcept = default; + + constexpr uint(uint64_t high, uint64_t low) noexcept : lo{low}, hi{high} {} + + template ::value>> + constexpr uint(T x) noexcept : lo(static_cast(x)) // NOLINT + {} + +#ifdef __SIZEOF_INT128__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wpedantic" + constexpr uint(unsigned __int128 x) noexcept // NOLINT + : lo{uint64_t(x)}, hi{uint64_t(x >> 64)} + {} + + constexpr explicit operator unsigned __int128() const noexcept + { + return (static_cast(hi) << 64) | lo; + } + #pragma GCC diagnostic pop +#endif + + constexpr explicit operator bool() const noexcept { return hi | lo; } + + /// Explicit converting operator for all builtin integral types. + template ::value>::type> + constexpr explicit operator Int() const noexcept + { + return static_cast(lo); + } +}; + +using uint128 = uint<128>; + + +/// Contains result of add/sub/etc with a carry flag. +template +struct result_with_carry +{ + T value; + bool carry; + + /// Conversion to tuple of references, to allow usage with std::tie(). + constexpr operator std::tuple() noexcept { return {value, carry}; } +}; + + +/// Linear arithmetic operators. 
+/// @{ + +constexpr inline result_with_carry add_with_carry( + uint64_t x, uint64_t y, bool carry = false) noexcept +{ + const auto s = x + y; + const auto carry1 = s < x; + const auto t = s + carry; + const auto carry2 = t < s; + return {t, carry1 || carry2}; +} + +template +constexpr result_with_carry> add_with_carry( + const uint& a, const uint& b, bool carry = false) noexcept +{ + const auto lo = add_with_carry(a.lo, b.lo, carry); + const auto hi = add_with_carry(a.hi, b.hi, lo.carry); + return {{hi.value, lo.value}, hi.carry}; +} + +constexpr inline uint128 operator+(uint128 x, uint128 y) noexcept +{ + return add_with_carry(x, y).value; +} + +constexpr inline uint128 operator+(uint128 x) noexcept +{ + return x; +} + +constexpr inline result_with_carry sub_with_carry( + uint64_t x, uint64_t y, bool carry = false) noexcept +{ + const auto d = x - y; + const auto carry1 = d > x; + const auto e = d - carry; + const auto carry2 = e > d; + return {e, carry1 || carry2}; +} + +/// Performs subtraction of two unsigned numbers and returns the difference +/// and the carry bit (aka borrow, overflow). +template +constexpr inline result_with_carry> sub_with_carry( + const uint& a, const uint& b, bool carry = false) noexcept +{ + const auto lo = sub_with_carry(a.lo, b.lo, carry); + const auto hi = sub_with_carry(a.hi, b.hi, lo.carry); + return {{hi.value, lo.value}, hi.carry}; +} + +constexpr inline uint128 operator-(uint128 x, uint128 y) noexcept +{ + return sub_with_carry(x, y).value; +} + +constexpr inline uint128 operator-(uint128 x) noexcept +{ + // Implementing as subtraction is better than ~x + 1. + // Clang9: Perfect. + // GCC8: Does something weird. 
+ return 0 - x; +} + +inline uint128& operator++(uint128& x) noexcept +{ + return x = x + 1; +} + +inline uint128& operator--(uint128& x) noexcept +{ + return x = x - 1; +} + +inline uint128 operator++(uint128& x, int) noexcept +{ + auto ret = x; + ++x; + return ret; +} + +inline uint128 operator--(uint128& x, int) noexcept +{ + auto ret = x; + --x; + return ret; +} + +/// Optimized addition. +/// +/// This keeps the multiprecision addition until CodeGen so the pattern is not +/// broken during other optimizations. +constexpr uint128 fast_add(uint128 x, uint128 y) noexcept +{ +#ifdef __SIZEOF_INT128__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wpedantic" + using uint128_native = unsigned __int128; + return uint128_native{x} + uint128_native{y}; + #pragma GCC diagnostic pop +#else + // Fallback to regular addition. + return x + y; +#endif +} + +/// @} + + +/// Comparison operators. +/// +/// In all implementations bitwise operators are used instead of logical ones +/// to avoid branching. +/// +/// @{ + +constexpr bool operator==(uint128 x, uint128 y) noexcept +{ + // Clang7: generates perfect xor based code, + // much better than __int128 where it uses vector instructions. + // GCC8: generates a bit worse cmp based code + // although it generates the xor based one for __int128. + return (x.lo == y.lo) & (x.hi == y.hi); +} + +constexpr bool operator!=(uint128 x, uint128 y) noexcept +{ + // Analogous to ==, but == not used directly, because that confuses GCC 8-9. + return (x.lo != y.lo) | (x.hi != y.hi); +} + +constexpr bool operator<(uint128 x, uint128 y) noexcept +{ + // OPT: This should be implemented by checking the borrow of x - y, + // but compilers (GCC8, Clang7) + // have problem with properly optimizing subtraction. 
+ return (x.hi < y.hi) | ((x.hi == y.hi) & (x.lo < y.lo)); +} + +constexpr bool operator<=(uint128 x, uint128 y) noexcept +{ + return !(y < x); +} + +constexpr bool operator>(uint128 x, uint128 y) noexcept +{ + return y < x; +} + +constexpr bool operator>=(uint128 x, uint128 y) noexcept +{ + return !(x < y); +} + +/// @} + + +/// Bitwise operators. +/// @{ + +constexpr uint128 operator~(uint128 x) noexcept +{ + return {~x.hi, ~x.lo}; +} + +constexpr uint128 operator|(uint128 x, uint128 y) noexcept +{ + // Clang7: perfect. + // GCC8: stupidly uses a vector instruction in all bitwise operators. + return {x.hi | y.hi, x.lo | y.lo}; +} + +constexpr uint128 operator&(uint128 x, uint128 y) noexcept +{ + return {x.hi & y.hi, x.lo & y.lo}; +} + +constexpr uint128 operator^(uint128 x, uint128 y) noexcept +{ + return {x.hi ^ y.hi, x.lo ^ y.lo}; +} + +constexpr uint128 operator<<(uint128 x, unsigned shift) noexcept +{ + return (shift < 64) ? + // Find the part moved from lo to hi. + // For shift == 0 right shift by (64 - shift) is invalid so + // split it into 2 shifts by 1 and (63 - shift). + uint128{(x.hi << shift) | ((x.lo >> 1) >> (63 - shift)), x.lo << shift} : + + // Guarantee "defined" behavior for shifts larger than 128. + (shift < 128) ? uint128{x.lo << (shift - 64), 0} : 0; +} + +constexpr uint128 operator<<(uint128 x, uint128 shift) noexcept +{ + if (shift < 128) + return x << unsigned(shift); + return 0; +} + +constexpr uint128 operator>>(uint128 x, unsigned shift) noexcept +{ + return (shift < 64) ? + // Find the part moved from lo to hi. + // For shift == 0 left shift by (64 - shift) is invalid so + // split it into 2 shifts by 1 and (63 - shift). + uint128{x.hi >> shift, (x.lo >> shift) | ((x.hi << 1) << (63 - shift))} : + + // Guarantee "defined" behavior for shifts larger than 128. + (shift < 128) ? 
uint128{0, x.hi >> (shift - 64)} : 0; +} + +constexpr uint128 operator>>(uint128 x, uint128 shift) noexcept +{ + if (shift < 128) + return x >> unsigned(shift); + return 0; +} + + +/// @} + + +/// Multiplication +/// @{ + +/// Portable full unsigned multiplication 64 x 64 -> 128. +constexpr uint128 constexpr_umul(uint64_t x, uint64_t y) noexcept +{ + uint64_t xl = x & 0xffffffff; + uint64_t xh = x >> 32; + uint64_t yl = y & 0xffffffff; + uint64_t yh = y >> 32; + + uint64_t t0 = xl * yl; + uint64_t t1 = xh * yl; + uint64_t t2 = xl * yh; + uint64_t t3 = xh * yh; + + uint64_t u1 = t1 + (t0 >> 32); + uint64_t u2 = t2 + (u1 & 0xffffffff); + + uint64_t lo = (u2 << 32) | (t0 & 0xffffffff); + uint64_t hi = t3 + (u2 >> 32) + (u1 >> 32); + return {hi, lo}; +} + +/// Full unsigned multiplication 64 x 64 -> 128. +inline uint128 umul(uint64_t x, uint64_t y) noexcept +{ +#if defined(__SIZEOF_INT128__) + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wpedantic" + const auto p = static_cast(x) * y; + return {uint64_t(p >> 64), uint64_t(p)}; + #pragma GCC diagnostic pop +#elif defined(_MSC_VER) + unsigned __int64 hi; + const auto lo = _umul128(x, y, &hi); + return {hi, lo}; +#else + return constexpr_umul(x, y); +#endif +} + +inline uint128 operator*(uint128 x, uint128 y) noexcept +{ + auto p = umul(x.lo, y.lo); + p.hi += (x.lo * y.hi) + (x.hi * y.lo); + return {p.hi, p.lo}; +} + +constexpr uint128 constexpr_mul(uint128 x, uint128 y) noexcept +{ + auto p = constexpr_umul(x.lo, y.lo); + p.hi += (x.lo * y.hi) + (x.hi * y.lo); + return {p.hi, p.lo}; +} + +/// @} + + +/// Assignment operators. 
+/// @{ + +constexpr uint128& operator+=(uint128& x, uint128 y) noexcept +{ + return x = x + y; +} + +constexpr uint128& operator-=(uint128& x, uint128 y) noexcept +{ + return x = x - y; +} + +inline uint128& operator*=(uint128& x, uint128 y) noexcept +{ + return x = x * y; +} + +constexpr uint128& operator|=(uint128& x, uint128 y) noexcept +{ + return x = x | y; +} + +constexpr uint128& operator&=(uint128& x, uint128 y) noexcept +{ + return x = x & y; +} + +constexpr uint128& operator^=(uint128& x, uint128 y) noexcept +{ + return x = x ^ y; +} + +constexpr uint128& operator<<=(uint128& x, unsigned shift) noexcept +{ + return x = x << shift; +} + +constexpr uint128& operator>>=(uint128& x, unsigned shift) noexcept +{ + return x = x >> shift; +} + +/// @} + + +constexpr unsigned clz_generic(uint32_t x) noexcept +{ + unsigned n = 32; + for (int i = 4; i >= 0; --i) + { + const auto s = unsigned{1} << i; + const auto hi = x >> s; + if (hi != 0) + { + n -= s; + x = hi; + } + } + return n - x; +} + +constexpr unsigned clz_generic(uint64_t x) noexcept +{ + unsigned n = 64; + for (int i = 5; i >= 0; --i) + { + const auto s = unsigned{1} << i; + const auto hi = x >> s; + if (hi != 0) + { + n -= s; + x = hi; + } + } + return n - static_cast(x); +} + +constexpr inline unsigned clz(uint32_t x) noexcept +{ +#ifdef _MSC_VER + return clz_generic(x); +#else + return x != 0 ? unsigned(__builtin_clz(x)) : 32; +#endif +} + +constexpr inline unsigned clz(uint64_t x) noexcept +{ +#ifdef _MSC_VER + return clz_generic(x); +#else + return x != 0 ? unsigned(__builtin_clzll(x)) : 64; +#endif +} + +constexpr inline unsigned clz(uint128 x) noexcept +{ + // In this order `h == 0` we get less instructions than in case of `h != 0`. + return x.hi == 0 ? 
clz(x.lo) + 64 : clz(x.hi); +} + + +inline uint64_t bswap(uint64_t x) noexcept +{ +#ifdef _MSC_VER + return _byteswap_uint64(x); +#else + return __builtin_bswap64(x); +#endif +} + +inline uint128 bswap(uint128 x) noexcept +{ + return {bswap(x.lo), bswap(x.hi)}; +} + + +/// Division. +/// @{ + +template +struct div_result +{ + QuotT quot; + RemT rem; + + /// Conversion to tuple of references, to allow usage with std::tie(). + constexpr operator std::tuple() noexcept { return {quot, rem}; } +}; + +namespace internal +{ +constexpr uint16_t reciprocal_table_item(uint8_t d9) noexcept +{ + return uint16_t(0x7fd00 / (0x100 | d9)); +} + +#define REPEAT4(x) \ + reciprocal_table_item((x) + 0), reciprocal_table_item((x) + 1), \ + reciprocal_table_item((x) + 2), reciprocal_table_item((x) + 3) + +#define REPEAT32(x) \ + REPEAT4((x) + 4 * 0), REPEAT4((x) + 4 * 1), REPEAT4((x) + 4 * 2), REPEAT4((x) + 4 * 3), \ + REPEAT4((x) + 4 * 4), REPEAT4((x) + 4 * 5), REPEAT4((x) + 4 * 6), REPEAT4((x) + 4 * 7) + +#define REPEAT256() \ + REPEAT32(32 * 0), REPEAT32(32 * 1), REPEAT32(32 * 2), REPEAT32(32 * 3), REPEAT32(32 * 4), \ + REPEAT32(32 * 5), REPEAT32(32 * 6), REPEAT32(32 * 7) + +/// Reciprocal lookup table. +constexpr uint16_t reciprocal_table[] = {REPEAT256()}; + +#undef REPEAT4 +#undef REPEAT32 +#undef REPEAT256 +} // namespace internal + +/// Computes the reciprocal (2^128 - 1) / d - 2^64 for normalized d. +/// +/// Based on Algorithm 2 from "Improved division by invariant integers". +inline uint64_t reciprocal_2by1(uint64_t d) noexcept +{ + INTX_REQUIRE(d & 0x8000000000000000); // Must be normalized. 
+ + const uint64_t d9 = d >> 55; + const uint32_t v0 = internal::reciprocal_table[d9 - 256]; + + const uint64_t d40 = (d >> 24) + 1; + const uint64_t v1 = (v0 << 11) - uint32_t(v0 * v0 * d40 >> 40) - 1; + + const uint64_t v2 = (v1 << 13) + (v1 * (0x1000000000000000 - v1 * d40) >> 47); + + const uint64_t d0 = d & 1; + const uint64_t d63 = (d >> 1) + d0; // ceil(d/2) + const uint64_t e = ((v2 >> 1) & (0 - d0)) - v2 * d63; + const uint64_t v3 = (umul(v2, e).hi >> 1) + (v2 << 31); + + const uint64_t v4 = v3 - (umul(v3, d) + d).hi - d; + return v4; +} + +inline uint64_t reciprocal_3by2(uint128 d) noexcept +{ + auto v = reciprocal_2by1(d.hi); + auto p = d.hi * v; + p += d.lo; + if (p < d.lo) + { + --v; + if (p >= d.hi) + { + --v; + p -= d.hi; + } + p -= d.hi; + } + + const auto t = umul(v, d.lo); + + p += t.hi; + if (p < t.hi) + { + --v; + if (p >= d.hi) + { + if (p > d.hi || t.lo >= d.lo) + --v; + } + } + return v; +} + +inline div_result udivrem_2by1(uint128 u, uint64_t d, uint64_t v) noexcept +{ + auto q = umul(v, u.hi); + q = fast_add(q, u); + + ++q.hi; + + auto r = u.lo - q.hi * d; + + if (r > q.lo) + { + --q.hi; + r += d; + } + + if (r >= d) + { + ++q.hi; + r -= d; + } + + return {q.hi, r}; +} + +inline div_result udivrem_3by2( + uint64_t u2, uint64_t u1, uint64_t u0, uint128 d, uint64_t v) noexcept +{ + auto q = umul(v, u2); + q = fast_add(q, {u2, u1}); + + auto r1 = u1 - q.hi * d.hi; + + auto t = umul(d.lo, q.hi); + + auto r = uint128{r1, u0} - t - d; + r1 = r.hi; + + ++q.hi; + + if (r1 >= q.lo) + { + --q.hi; + r += d; + } + + if (r >= d) + { + ++q.hi; + r -= d; + } + + return {q.hi, r}; +} + +inline div_result udivrem(uint128 x, uint128 y) noexcept +{ + if (y.hi == 0) + { + INTX_REQUIRE(y.lo != 0); // Division by 0. 
+ + const auto lsh = clz(y.lo); + const auto rsh = (64 - lsh) % 64; + const auto rsh_mask = uint64_t{lsh == 0} - 1; + + const auto yn = y.lo << lsh; + const auto xn_lo = x.lo << lsh; + const auto xn_hi = (x.hi << lsh) | ((x.lo >> rsh) & rsh_mask); + const auto xn_ex = (x.hi >> rsh) & rsh_mask; + + const auto v = reciprocal_2by1(yn); + const auto res1 = udivrem_2by1({xn_ex, xn_hi}, yn, v); + const auto res2 = udivrem_2by1({res1.rem, xn_lo}, yn, v); + return {{res1.quot, res2.quot}, res2.rem >> lsh}; + } + + if (y.hi > x.hi) + return {0, x}; + + const auto lsh = clz(y.hi); + if (lsh == 0) + { + const auto q = unsigned{y.hi < x.hi} | unsigned{y.lo <= x.lo}; + return {q, x - (q ? y : 0)}; + } + + const auto rsh = 64 - lsh; + + const auto yn_lo = y.lo << lsh; + const auto yn_hi = (y.hi << lsh) | (y.lo >> rsh); + const auto xn_lo = x.lo << lsh; + const auto xn_hi = (x.hi << lsh) | (x.lo >> rsh); + const auto xn_ex = x.hi >> rsh; + + const auto v = reciprocal_3by2({yn_hi, yn_lo}); + const auto res = udivrem_3by2(xn_ex, xn_hi, xn_lo, {yn_hi, yn_lo}, v); + + return {res.quot, res.rem >> lsh}; +} + +inline div_result sdivrem(uint128 x, uint128 y) noexcept +{ + constexpr auto sign_mask = uint128{1} << 127; + const auto x_is_neg = (x & sign_mask) != 0; + const auto y_is_neg = (y & sign_mask) != 0; + + const auto x_abs = x_is_neg ? -x : x; + const auto y_abs = y_is_neg ? -y : y; + + const auto q_is_neg = x_is_neg ^ y_is_neg; + + const auto res = udivrem(x_abs, y_abs); + + return {q_is_neg ? -res.quot : res.quot, x_is_neg ? 
-res.rem : res.rem}; +} + +inline uint128 operator/(uint128 x, uint128 y) noexcept +{ + return udivrem(x, y).quot; +} + +inline uint128 operator%(uint128 x, uint128 y) noexcept +{ + return udivrem(x, y).rem; +} + +inline uint128& operator/=(uint128& x, uint128 y) noexcept +{ + return x = x / y; +} + +inline uint128& operator%=(uint128& x, uint128 y) noexcept +{ + return x = x % y; +} + +/// @} + +} // namespace intx + + +namespace std +{ +template +struct numeric_limits> +{ + using type = intx::uint; + + static constexpr bool is_specialized = true; + static constexpr bool is_integer = true; + static constexpr bool is_signed = false; + static constexpr bool is_exact = true; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = false; + static constexpr bool has_signaling_NaN = false; + static constexpr float_denorm_style has_denorm = denorm_absent; + static constexpr bool has_denorm_loss = false; + static constexpr float_round_style round_style = round_toward_zero; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = true; + static constexpr int digits = CHAR_BIT * sizeof(type); + static constexpr int digits10 = int(0.3010299956639812 * digits); + static constexpr int max_digits10 = 0; + static constexpr int radix = 2; + static constexpr int min_exponent = 0; + static constexpr int min_exponent10 = 0; + static constexpr int max_exponent = 0; + static constexpr int max_exponent10 = 0; + static constexpr bool traps = std::numeric_limits::traps; + static constexpr bool tinyness_before = false; + + static constexpr type min() noexcept { return 0; } + static constexpr type lowest() noexcept { return min(); } + static constexpr type max() noexcept { return ~type{0}; } + static constexpr type epsilon() noexcept { return 0; } + static constexpr type round_error() noexcept { return 0; } + static constexpr type infinity() noexcept { return 0; } + static constexpr type 
quiet_NaN() noexcept { return 0; } + static constexpr type signaling_NaN() noexcept { return 0; } + static constexpr type denorm_min() noexcept { return 0; } +}; +} // namespace std + +namespace intx +{ +template +[[noreturn]] inline void throw_(const char* what) +{ +#if __cpp_exceptions + throw T{what}; +#else + std::fputs(what, stderr); + std::abort(); +#endif +} + +constexpr inline int from_dec_digit(char c) +{ + if (c < '0' || c > '9') + throw_("invalid digit"); + return c - '0'; +} + +constexpr inline int from_hex_digit(char c) +{ + if (c >= 'a' && c <= 'f') + return c - ('a' - 10); + if (c >= 'A' && c <= 'F') + return c - ('A' - 10); + return from_dec_digit(c); +} + +template +constexpr Int from_string(const char* str) +{ + auto s = str; + auto x = Int{}; + int num_digits = 0; + + if (s[0] == '0' && s[1] == 'x') + { + s += 2; + while (const auto c = *s++) + { + if (++num_digits > int{sizeof(x) * 2}) + throw_(str); + x = (x << 4) | from_hex_digit(c); + } + return x; + } + + while (const auto c = *s++) + { + if (num_digits++ > std::numeric_limits::digits10) + throw_(str); + + const auto d = from_dec_digit(c); + x = constexpr_mul(x, Int{10}) + d; + if (x < d) + throw_(str); + } + return x; +} + +template +constexpr Int from_string(const std::string& s) +{ + return from_string(s.c_str()); +} + +constexpr uint128 operator""_u128(const char* s) +{ + return from_string(s); +} + +template +inline std::string to_string(uint x, int base = 10) +{ + if (base < 2 || base > 36) + base = 10; + + if (x == 0) + return "0"; + + auto s = std::string{}; + while (x != 0) + { + // TODO: Use constexpr udivrem_1? + const auto res = udivrem(x, uint{base}); + const auto d = int(res.rem); + const auto c = d < 10 ? 
'0' + d : 'a' + d - 10; + s.push_back(char(c)); + x = res.quot; + } + std::reverse(s.begin(), s.end()); + return s; +} + +template +inline std::string hex(uint x) +{ + return to_string(x, 16); +} +} // namespace intx diff --git a/source/libs/decimal/src/detail/intx/intx.hpp b/source/libs/decimal/src/detail/intx/intx.hpp new file mode 100644 index 0000000000..e91809d647 --- /dev/null +++ b/source/libs/decimal/src/detail/intx/intx.hpp @@ -0,0 +1,1221 @@ +// intx: extended precision integer library. +// Copyright 2019-2020 Pawel Bylica. +// Licensed under the Apache License, Version 2.0. + +#pragma once + +#include "int128.hpp" +#include +#include +#include +#include +#include +#include + +namespace intx +{ +template +struct uint +{ + static_assert((N & (N - 1)) == 0, "Number of bits must be power of 2"); + static_assert(N >= 256, "Number of bits must be at lest 256"); + + using word_type = uint64_t; + + /// The 2x smaller type. + using half_type = uint; + + static constexpr auto num_bits = N; + static constexpr auto num_words = N / 8 / sizeof(word_type); + + half_type lo = 0; + half_type hi = 0; + + constexpr uint() noexcept = default; + + constexpr uint(half_type high, half_type low) noexcept : lo(low), hi(high) {} + + /// Implicit converting constructor for the half type. + constexpr uint(half_type x) noexcept : lo(x) {} // NOLINT + + /// Implicit converting constructor for types convertible to the half type. + template ::value>::type> + constexpr uint(T x) noexcept : lo(x) // NOLINT + {} + + constexpr explicit operator bool() const noexcept + { + return static_cast(lo) | static_cast(hi); + } + + /// Explicit converting operator for all builtin integral types. 
+ template ::value>::type> + explicit operator Int() const noexcept + { + return static_cast(lo); + } +}; + +using uint256 = uint<256>; +using uint512 = uint<512>; + +constexpr uint8_t lo_half(uint16_t x) +{ + return static_cast(x); +} + +constexpr uint16_t lo_half(uint32_t x) +{ + return static_cast(x); +} + +constexpr uint32_t lo_half(uint64_t x) +{ + return static_cast(x); +} + +constexpr uint8_t hi_half(uint16_t x) +{ + return static_cast(x >> 8); +} + +constexpr uint16_t hi_half(uint32_t x) +{ + return static_cast(x >> 16); +} + +constexpr uint32_t hi_half(uint64_t x) +{ + return static_cast(x >> 32); +} + +template +inline constexpr auto lo_half(const uint& x) noexcept +{ + return x.lo; +} + +template +inline constexpr auto hi_half(const uint& x) noexcept +{ + return x.hi; +} + +template +constexpr unsigned num_bits(const T&) noexcept +{ + return sizeof(T) * 8; +} + +template +constexpr bool operator==(const uint& a, const uint& b) noexcept +{ + return (a.lo == b.lo) & (a.hi == b.hi); +} + +template >::value>::type> +constexpr bool operator==(const uint& x, const T& y) noexcept +{ + return x == uint(y); +} + +template >::value>::type> +constexpr bool operator==(const T& x, const uint& y) noexcept +{ + return uint(y) == x; +} + + +template +constexpr bool operator!=(const uint& a, const uint& b) noexcept +{ + return !(a == b); +} + +template >::value>::type> +constexpr bool operator!=(const uint& x, const T& y) noexcept +{ + return x != uint(y); +} + +template >::value>::type> +constexpr bool operator!=(const T& x, const uint& y) noexcept +{ + return uint(x) != y; +} + + +template +constexpr bool operator<(const uint& a, const uint& b) noexcept +{ + // Bitwise operators are used to implement logic here to avoid branching. + // It also should make the function smaller, but no proper benchmark has + // been done. 
+ return (a.hi < b.hi) | ((a.hi == b.hi) & (a.lo < b.lo)); +} + +template >::value>::type> +constexpr bool operator<(const uint& x, const T& y) noexcept +{ + return x < uint(y); +} + +template >::value>::type> +constexpr bool operator<(const T& x, const uint& y) noexcept +{ + return uint(x) < y; +} + + +template +constexpr bool operator>(const uint& a, const uint& b) noexcept +{ + return b < a; +} + +template >::value>::type> +constexpr bool operator>(const uint& x, const T& y) noexcept +{ + return x > uint(y); +} + +template >::value>::type> +constexpr bool operator>(const T& x, const uint& y) noexcept +{ + return uint(x) > y; +} + + +template +constexpr bool operator>=(const uint& a, const uint& b) noexcept +{ + return !(a < b); +} + +template >::value>::type> +constexpr bool operator>=(const uint& x, const T& y) noexcept +{ + return x >= uint(y); +} + +template >::value>::type> +constexpr bool operator>=(const T& x, const uint& y) noexcept +{ + return uint(x) >= y; +} + + +template +constexpr bool operator<=(const uint& a, const uint& b) noexcept +{ + return !(b < a); +} + +template >::value>::type> +constexpr bool operator<=(const uint& x, const T& y) noexcept +{ + return x <= uint(y); +} + +template >::value>::type> +constexpr bool operator<=(const T& x, const uint& y) noexcept +{ + return uint(x) <= y; +} + +template +constexpr uint operator|(const uint& x, const uint& y) noexcept +{ + return {x.hi | y.hi, x.lo | y.lo}; +} + +template +constexpr uint operator&(const uint& x, const uint& y) noexcept +{ + return {x.hi & y.hi, x.lo & y.lo}; +} + +template +constexpr uint operator^(const uint& x, const uint& y) noexcept +{ + return {x.hi ^ y.hi, x.lo ^ y.lo}; +} + +template +constexpr uint operator~(const uint& x) noexcept +{ + return {~x.hi, ~x.lo}; +} + +template +constexpr uint operator<<(const uint& x, unsigned shift) noexcept +{ + constexpr auto num_bits = N; + constexpr auto half_bits = num_bits / 2; + + if (shift < half_bits) + { + const auto lo = x.lo << 
shift; + + // Find the part moved from lo to hi. + // The shift right here can be invalid: + // for shift == 0 => lshift == half_bits. + // Split it into 2 valid shifts by (rshift - 1) and 1. + const auto rshift = half_bits - shift; + const auto lo_overflow = (x.lo >> (rshift - 1)) >> 1; + const auto hi = (x.hi << shift) | lo_overflow; + return {hi, lo}; + } + + // This check is only needed if we want "defined" behavior for shifts + // larger than size of the Int. + if (shift < num_bits) + return {x.lo << (shift - half_bits), 0}; + + return 0; +} + +template +inline Target narrow_cast(uint64_t x) noexcept +{ + return static_cast(x); +} + +template +inline Target narrow_cast(const Int& x) noexcept +{ + return narrow_cast(x.lo); +} + +template +constexpr uint operator>>(const uint& x, unsigned shift) noexcept +{ + constexpr auto half_bits = N / 2; + + if (shift < half_bits) + { + auto hi = x.hi >> shift; + + // Find the part moved from hi to lo. + // To avoid invalid shift left, + // split them into 2 valid shifts by (lshift - 1) and 1. 
+ unsigned lshift = half_bits - shift; + auto hi_overflow = (x.hi << (lshift - 1)) << 1; + auto lo_part = x.lo >> shift; + auto lo = lo_part | hi_overflow; + return {hi, lo}; + } + + if (shift < num_bits(x)) + return {0, x.hi >> (shift - half_bits)}; + + return 0; +} + + +template >::value>::type> +constexpr uint operator<<(const uint& x, const T& shift) noexcept +{ + if (shift < T{sizeof(x) * 8}) + return x << static_cast(shift); + return 0; +} + +template >::value>::type> +constexpr uint operator>>(const uint& x, const T& shift) noexcept +{ + if (shift < T{sizeof(x) * 8}) + return x >> static_cast(shift); + return 0; +} + +template +inline uint& operator>>=(uint& x, unsigned shift) noexcept +{ + return x = x >> shift; +} + + +constexpr uint64_t* as_words(uint128& x) noexcept +{ + return &x.lo; +} + +constexpr const uint64_t* as_words(const uint128& x) noexcept +{ + return &x.lo; +} + +template +constexpr uint64_t* as_words(uint& x) noexcept +{ + return as_words(x.lo); +} + +template +constexpr const uint64_t* as_words(const uint& x) noexcept +{ + return as_words(x.lo); +} + +template +inline uint8_t* as_bytes(uint& x) noexcept +{ + return reinterpret_cast(as_words(x)); +} + +template +inline const uint8_t* as_bytes(const uint& x) noexcept +{ + return reinterpret_cast(as_words(x)); +} + +/// Implementation of shift left as a loop. +/// This one is slower than the one using "split" strategy. 
+template +inline uint shl_loop(const uint& x, unsigned shift) +{ + auto r = uint{}; + constexpr unsigned word_bits = sizeof(uint64_t) * 8; + constexpr size_t num_words = sizeof(uint) / sizeof(uint64_t); + auto rw = as_words(r); + auto words = as_words(x); + unsigned s = shift % word_bits; + unsigned skip = shift / word_bits; + + uint64_t carry = 0; + for (size_t i = 0; i < (num_words - skip); ++i) + { + auto w = words[i]; + auto v = (w << s) | carry; + carry = (w >> (word_bits - s - 1)) >> 1; + rw[i + skip] = v; + } + return r; +} + +template +inline uint add_loop(const uint& a, const uint& b) noexcept +{ + static constexpr auto num_words = sizeof(a) / sizeof(uint64_t); + + auto x = as_words(a); + auto y = as_words(b); + + uint s; + auto z = as_words(s); + + bool k = false; + for (size_t i = 0; i < num_words; ++i) + { + z[i] = x[i] + y[i]; + auto k1 = z[i] < x[i]; + z[i] += k; + k = (z[i] < k) || k1; + } + + return s; +} + +template +constexpr uint operator+(const uint& x, const uint& y) noexcept +{ + return add_with_carry(x, y).value; +} + +template +constexpr uint operator-(const uint& x) noexcept +{ + return ~x + uint{1}; +} + +template +constexpr uint operator-(const uint& x, const uint& y) noexcept +{ + return sub_with_carry(x, y).value; +} + +template >::value>::type> +constexpr uint& operator+=(uint& x, const T& y) noexcept +{ + return x = x + y; +} + +template >::value>::type> +constexpr uint& operator-=(uint& x, const T& y) noexcept +{ + return x = x - y; +} + + +template +inline uint<2 * N> umul(const uint& x, const uint& y) noexcept +{ + const auto t0 = umul(x.lo, y.lo); + const auto t1 = umul(x.hi, y.lo); + const auto t2 = umul(x.lo, y.hi); + const auto t3 = umul(x.hi, y.hi); + + const auto u1 = t1 + t0.hi; + const auto u2 = t2 + u1.lo; + + const auto lo = (u2 << (num_bits(x) / 2)) | t0.lo; + const auto hi = t3 + u2.hi + u1.hi; + + return {hi, lo}; +} + +template +constexpr uint<2 * N> constexpr_umul(const uint& x, const uint& y) noexcept +{ + auto t0 
= constexpr_umul(x.lo, y.lo); + auto t1 = constexpr_umul(x.hi, y.lo); + auto t2 = constexpr_umul(x.lo, y.hi); + auto t3 = constexpr_umul(x.hi, y.hi); + + auto u1 = t1 + t0.hi; + auto u2 = t2 + u1.lo; + + auto lo = (u2 << (num_bits(x) / 2)) | t0.lo; + auto hi = t3 + u2.hi + u1.hi; + + return {hi, lo}; +} + +template +inline uint mul(const uint& a, const uint& b) noexcept +{ + // Requires 1 full mul, 2 muls and 2 adds. + // Clang & GCC implements 128-bit multiplication this way. + + const auto t = umul(a.lo, b.lo); + const auto hi = (a.lo * b.hi) + (a.hi * b.lo) + t.hi; + + return {hi, t.lo}; +} + +template +inline uint sqr(const uint& a) noexcept +{ + // Based on mul() implementation. + + const auto t = umul(a.lo, a.lo); + const auto hi = 2 * (a.lo * a.hi) + t.hi; + + return {hi, t.lo}; +} + + +template +constexpr uint constexpr_mul(const uint& a, const uint& b) noexcept +{ + auto t = constexpr_umul(a.lo, b.lo); + auto hi = constexpr_mul(a.lo, b.hi) + constexpr_mul(a.hi, b.lo) + t.hi; + return {hi, t.lo}; +} + + +template +inline uint<2 * N> umul_loop(const uint& x, const uint& y) noexcept +{ + constexpr int num_words = sizeof(uint) / sizeof(uint64_t); + + uint<2 * N> p; + auto pw = as_words(p); + auto uw = as_words(x); + auto vw = as_words(y); + + for (int j = 0; j < num_words; ++j) + { + uint64_t k = 0; + for (int i = 0; i < num_words; ++i) + { + auto t = umul(uw[i], vw[j]) + pw[i + j] + k; + pw[i + j] = t.lo; + k = t.hi; + } + pw[j + num_words] = k; + } + return p; +} + +template +inline uint mul_loop_opt(const uint& u, const uint& v) noexcept +{ + constexpr int num_words = sizeof(uint) / sizeof(uint64_t); + + uint p; + auto pw = as_words(p); + auto uw = as_words(u); + auto vw = as_words(v); + + for (int j = 0; j < num_words; j++) + { + uint64_t k = 0; + for (int i = 0; i < (num_words - j - 1); i++) + { + auto t = umul(uw[i], vw[j]) + pw[i + j] + k; + pw[i + j] = t.lo; + k = t.hi; + } + pw[num_words - 1] += uw[num_words - j - 1] * vw[j] + k; + } + return p; +} + 
+inline uint256 operator*(const uint256& x, const uint256& y) noexcept +{ + return mul(x, y); +} + +template +inline uint operator*(const uint& x, const uint& y) noexcept +{ + return mul_loop_opt(x, y); +} + + +template >::value>::type> +constexpr uint& operator*=(uint& x, const T& y) noexcept +{ + return x = x * y; +} + +template +constexpr uint exp(uint base, uint exponent) noexcept +{ + auto result = uint{1}; + if (base == 2) + return result << exponent; + + while (exponent != 0) + { + if ((exponent & 1) != 0) + result *= base; + base = sqr(base); + exponent >>= 1; + } + return result; +} + +template +constexpr unsigned clz(const uint& x) noexcept +{ + const auto half_bits = num_bits(x) / 2; + + // TODO: Try: + // bool take_hi = h != 0; + // bool take_lo = !take_hi; + // unsigned clz_hi = take_hi * clz(h); + // unsigned clz_lo = take_lo * (clz(l) | half_bits); + // return clz_hi | clz_lo; + + // In this order `h == 0` we get less instructions than in case of `h != 0`. + return x.hi == 0 ? clz(x.lo) + half_bits : clz(x.hi); +} + +template +std::array to_words(Int x) noexcept +{ + std::array words; + std::memcpy(&words, &x, sizeof(x)); + return words; +} + +template +unsigned count_significant_words_loop(uint256 x) noexcept +{ + auto words = to_words(x); + for (size_t i = words.size(); i > 0; --i) + { + if (words[i - 1] != 0) + return static_cast(i); + } + return 0; +} + +template +inline typename std::enable_if::type count_significant_words( + const Int& x) noexcept +{ + return x != 0 ? 1 : 0; +} + +template +inline typename std::enable_if::type count_significant_words( + const Int& x) noexcept +{ + constexpr auto num_words = static_cast(sizeof(x) / sizeof(Word)); + auto h = count_significant_words(hi_half(x)); + auto l = count_significant_words(lo_half(x)); + return h != 0 ? 
h + (num_words / 2) : l; +} + + +namespace internal +{ +template +struct normalized_div_args +{ + uint divisor; + uint numerator; + typename uint::word_type numerator_ex; + int num_divisor_words; + int num_numerator_words; + unsigned shift; +}; + +template +[[gnu::always_inline]] inline normalized_div_args normalize( + const IntT& numerator, const IntT& denominator) noexcept +{ + // FIXME: Make the implementation type independent + static constexpr auto num_words = IntT::num_words; + + auto* u = as_words(numerator); + auto* v = as_words(denominator); + + normalized_div_args na; + auto* un = as_words(na.numerator); + auto* vn = as_words(na.divisor); + + auto& m = na.num_numerator_words; + for (m = num_words; m > 0 && u[m - 1] == 0; --m) + ; + + auto& n = na.num_divisor_words; + for (n = num_words; n > 0 && v[n - 1] == 0; --n) + ; + + na.shift = clz(v[n - 1]); + if (na.shift) + { + for (int i = num_words - 1; i > 0; --i) + vn[i] = (v[i] << na.shift) | (v[i - 1] >> (64 - na.shift)); + vn[0] = v[0] << na.shift; + + un[num_words] = u[num_words - 1] >> (64 - na.shift); + for (int i = num_words - 1; i > 0; --i) + un[i] = (u[i] << na.shift) | (u[i - 1] >> (64 - na.shift)); + un[0] = u[0] << na.shift; + } + else + { + na.numerator_ex = 0; + na.numerator = numerator; + na.divisor = denominator; + } + + // Skip the highest word of numerator if not significant. + if (un[m] != 0 || un[m - 1] >= vn[n - 1]) + ++m; + + return na; +} + +/// Divides arbitrary long unsigned integer by 64-bit unsigned integer (1 word). +/// @param u The array of a normalized numerator words. It will contain +/// the quotient after execution. +/// @param len The number of numerator words. +/// @param d The normalized divisor. +/// @return The remainder. +inline uint64_t udivrem_by1(uint64_t u[], int len, uint64_t d) noexcept +{ + INTX_REQUIRE(len >= 2); + + const auto reciprocal = reciprocal_2by1(d); + + auto rem = u[len - 1]; // Set the top word as remainder. 
+ u[len - 1] = 0; // Reset the word being a part of the result quotient. + + auto it = &u[len - 2]; + do + { + std::tie(*it, rem) = udivrem_2by1({rem, *it}, d, reciprocal); + } while (it-- != &u[0]); + + return rem; +} + +/// Divides arbitrary long unsigned integer by 128-bit unsigned integer (2 words). +/// @param u The array of a normalized numerator words. It will contain the +/// quotient after execution. +/// @param len The number of numerator words. +/// @param d The normalized divisor. +/// @return The remainder. +inline uint128 udivrem_by2(uint64_t u[], int len, uint128 d) noexcept +{ + INTX_REQUIRE(len >= 3); + + const auto reciprocal = reciprocal_3by2(d); + + auto rem = uint128{u[len - 1], u[len - 2]}; // Set the 2 top words as remainder. + u[len - 1] = u[len - 2] = 0; // Reset these words being a part of the result quotient. + + auto it = &u[len - 3]; + do + { + std::tie(*it, rem) = udivrem_3by2(rem.hi, rem.lo, *it, d, reciprocal); + } while (it-- != &u[0]); + + return rem; +} + +/// s = x + y. +inline bool add(uint64_t s[], const uint64_t x[], const uint64_t y[], int len) noexcept +{ + // OPT: Add MinLen template parameter and unroll first loop iterations. + INTX_REQUIRE(len >= 2); + + bool carry = false; + for (int i = 0; i < len; ++i) + std::tie(s[i], carry) = add_with_carry(x[i], y[i], carry); + return carry; +} + +/// r = x - multiplier * y. +inline uint64_t submul( + uint64_t r[], const uint64_t x[], const uint64_t y[], int len, uint64_t multiplier) noexcept +{ + // OPT: Add MinLen template parameter and unroll first loop iterations. 
+ INTX_REQUIRE(len >= 1); + + uint64_t borrow = 0; + for (int i = 0; i < len; ++i) + { + const auto s = sub_with_carry(x[i], borrow); + const auto p = umul(y[i], multiplier); + const auto t = sub_with_carry(s.value, p.lo); + r[i] = t.value; + borrow = p.hi + s.carry + t.carry; + } + return borrow; +} + +inline void udivrem_knuth( + uint64_t q[], uint64_t u[], int ulen, const uint64_t d[], int dlen) noexcept +{ + INTX_REQUIRE(dlen >= 3); + INTX_REQUIRE(ulen >= dlen); + + const auto divisor = uint128{d[dlen - 1], d[dlen - 2]}; + const auto reciprocal = reciprocal_3by2(divisor); + for (int j = ulen - dlen - 1; j >= 0; --j) + { + const auto u2 = u[j + dlen]; + const auto u1 = u[j + dlen - 1]; + const auto u0 = u[j + dlen - 2]; + + uint64_t qhat; + if (INTX_UNLIKELY(uint128(u2, u1) == divisor)) // Division overflows. + { + qhat = ~uint64_t{0}; + + u[j + dlen] = u2 - submul(&u[j], &u[j], d, dlen, qhat); + } + else + { + uint128 rhat; + std::tie(qhat, rhat) = udivrem_3by2(u2, u1, u0, divisor, reciprocal); + + bool carry; + const auto overflow = submul(&u[j], &u[j], d, dlen - 2, qhat); + std::tie(u[j + dlen - 2], carry) = sub_with_carry(rhat.lo, overflow); + std::tie(u[j + dlen - 1], carry) = sub_with_carry(rhat.hi, carry); + + if (INTX_UNLIKELY(carry)) + { + --qhat; + u[j + dlen - 1] += divisor.hi + add(&u[j], &u[j], d, dlen - 1); + } + } + + q[j] = qhat; // Store quotient digit. 
+ } +} + +} // namespace internal + +template +div_result> udivrem(const uint& u, const uint& v) noexcept +{ + auto na = internal::normalize(u, v); + + if (na.num_numerator_words <= na.num_divisor_words) + return {0, u}; + + if (na.num_divisor_words == 1) + { + const auto r = internal::udivrem_by1( + as_words(na.numerator), na.num_numerator_words, as_words(na.divisor)[0]); + return {na.numerator, r >> na.shift}; + } + + if (na.num_divisor_words == 2) + { + const auto d = as_words(na.divisor); + const auto r = + internal::udivrem_by2(as_words(na.numerator), na.num_numerator_words, {d[1], d[0]}); + return {na.numerator, r >> na.shift}; + } + + auto un = as_words(na.numerator); // Will be modified. + + uint q; + internal::udivrem_knuth( + as_words(q), &un[0], na.num_numerator_words, as_words(na.divisor), na.num_divisor_words); + + uint r; + auto rw = as_words(r); + for (int i = 0; i < na.num_divisor_words - 1; ++i) + rw[i] = na.shift ? (un[i] >> na.shift) | (un[i + 1] << (64 - na.shift)) : un[i]; + rw[na.num_divisor_words - 1] = un[na.num_divisor_words - 1] >> na.shift; + + return {q, r}; +} + +template +constexpr div_result> sdivrem(const uint& u, const uint& v) noexcept +{ + const auto sign_mask = uint{1} << (sizeof(u) * 8 - 1); + auto u_is_neg = (u & sign_mask) != 0; + auto v_is_neg = (v & sign_mask) != 0; + + auto u_abs = u_is_neg ? -u : u; + auto v_abs = v_is_neg ? -v : v; + + auto q_is_neg = u_is_neg ^ v_is_neg; + + auto res = udivrem(u_abs, v_abs); + + return {q_is_neg ? -res.quot : res.quot, u_is_neg ? 
-res.rem : res.rem}; +} + +template +constexpr uint operator/(const uint& x, const uint& y) noexcept +{ + return udivrem(x, y).quot; +} + +template +constexpr uint operator%(const uint& x, const uint& y) noexcept +{ + return udivrem(x, y).rem; +} + +template >::value>::type> +constexpr uint& operator/=(uint& x, const T& y) noexcept +{ + return x = x / y; +} + +template >::value>::type> +constexpr uint& operator%=(uint& x, const T& y) noexcept +{ + return x = x % y; +} + +template +inline uint bswap(const uint& x) noexcept +{ + return {bswap(x.lo), bswap(x.hi)}; +} + + +// Support for type conversions for binary operators. + +template >::value>::type> +constexpr uint operator+(const uint& x, const T& y) noexcept +{ + return x + uint(y); +} + +template >::value>::type> +constexpr uint operator+(const T& x, const uint& y) noexcept +{ + return uint(x) + y; +} + +template >::value>::type> +constexpr uint operator-(const uint& x, const T& y) noexcept +{ + return x - uint(y); +} + +template >::value>::type> +constexpr uint operator-(const T& x, const uint& y) noexcept +{ + return uint(x) - y; +} + +template >::value>::type> +constexpr uint operator*(const uint& x, const T& y) noexcept +{ + return x * uint(y); +} + +template >::value>::type> +constexpr uint operator*(const T& x, const uint& y) noexcept +{ + return uint(x) * y; +} + +template >::value>::type> +constexpr uint operator/(const uint& x, const T& y) noexcept +{ + return x / uint(y); +} + +template >::value>::type> +constexpr uint operator/(const T& x, const uint& y) noexcept +{ + return uint(x) / y; +} + +template >::value>::type> +constexpr uint operator%(const uint& x, const T& y) noexcept +{ + return x % uint(y); +} + +template >::value>::type> +constexpr uint operator%(const T& x, const uint& y) noexcept +{ + return uint(x) % y; +} + +template >::value>::type> +constexpr uint operator|(const uint& x, const T& y) noexcept +{ + return x | uint(y); +} + +template >::value>::type> +constexpr uint operator|(const 
T& x, const uint& y) noexcept +{ + return uint(x) | y; +} + +template >::value>::type> +constexpr uint operator&(const uint& x, const T& y) noexcept +{ + return x & uint(y); +} + +template >::value>::type> +constexpr uint operator&(const T& x, const uint& y) noexcept +{ + return uint(x) & y; +} + +template >::value>::type> +constexpr uint operator^(const uint& x, const T& y) noexcept +{ + return x ^ uint(y); +} + +template >::value>::type> +constexpr uint operator^(const T& x, const uint& y) noexcept +{ + return uint(x) ^ y; +} + +template >::value>::type> +constexpr uint& operator|=(uint& x, const T& y) noexcept +{ + return x = x | y; +} + +template >::value>::type> +constexpr uint& operator&=(uint& x, const T& y) noexcept +{ + return x = x & y; +} + +template >::value>::type> +constexpr uint& operator^=(uint& x, const T& y) noexcept +{ + return x = x ^ y; +} + +template >::value>::type> +constexpr uint& operator<<=(uint& x, const T& y) noexcept +{ + return x = x << y; +} + +template >::value>::type> +constexpr uint& operator>>=(uint& x, const T& y) noexcept +{ + return x = x >> y; +} + + +inline uint256 addmod(const uint256& x, const uint256& y, const uint256& mod) noexcept +{ + const auto s = add_with_carry(x, y); + return (uint512{s.carry, s.value} % mod).lo; +} + +inline uint256 mulmod(const uint256& x, const uint256& y, const uint256& mod) noexcept +{ + return (umul(x, y) % mod).lo; +} + + +constexpr uint256 operator"" _u256(const char* s) noexcept +{ + return from_string(s); +} + +constexpr uint512 operator"" _u512(const char* s) noexcept +{ + return from_string(s); +} + +namespace le // Conversions to/from LE bytes. 
+{ +template +inline IntT load(const uint8_t (&bytes)[M]) noexcept +{ + static_assert(M == IntT::num_bits / 8, + "the size of source bytes must match the size of the destination uint"); + auto x = IntT{}; + std::memcpy(&x, bytes, sizeof(x)); + return x; +} + +template +inline void store(uint8_t (&dst)[N / 8], const intx::uint& x) noexcept +{ + std::memcpy(dst, &x, sizeof(x)); +} + +} // namespace le + + +namespace be // Conversions to/from BE bytes. +{ +/// Loads an uint value from bytes of big-endian order. +/// If the size of bytes is smaller than the result uint, the value is zero-extended. +template +inline IntT load(const uint8_t (&bytes)[M]) noexcept +{ + static_assert(M <= IntT::num_bits / 8, + "the size of source bytes must not exceed the size of the destination uint"); + auto x = IntT{}; + std::memcpy(&as_bytes(x)[IntT::num_bits / 8 - M], bytes, M); + return bswap(x); +} + +template +inline IntT load(const T& t) noexcept +{ + return load(t.bytes); +} + +/// Stores an uint value in a bytes array in big-endian order. +template +inline void store(uint8_t (&dst)[N / 8], const intx::uint& x) noexcept +{ + const auto d = bswap(x); + std::memcpy(dst, &d, sizeof(d)); +} + +/// Stores an uint value in .bytes field of type T. The .bytes must be an array of uint8_t +/// of the size matching the size of uint. +template +inline T store(const intx::uint& x) noexcept +{ + T r{}; + store(r.bytes, x); + return r; +} + +/// Stores the truncated value of an uint in a bytes array. +/// Only the least significant bytes from big-endian representation of the uint +/// are stored in the result bytes array up to array's size. +template +inline void trunc(uint8_t (&dst)[M], const intx::uint& x) noexcept +{ + static_assert(M < N / 8, "destination must be smaller than the source value"); + const auto d = bswap(x); + const auto b = as_bytes(d); + std::memcpy(dst, &b[sizeof(d) - M], M); +} + +/// Stores the truncated value of an uint in the .bytes field of an object of type T. 
+template +inline T trunc(const intx::uint& x) noexcept +{ + T r{}; + trunc(r.bytes, x); + return r; +} + +namespace unsafe +{ +/// Loads an uint value from a buffer. The user must make sure +/// that the provided buffer is big enough. Therefore marked "unsafe". +template +inline IntT load(const uint8_t* bytes) noexcept +{ + auto x = IntT{}; + std::memcpy(&x, bytes, sizeof(x)); + return bswap(x); +} + +/// Stores an uint value at the provided pointer in big-endian order. The user must make sure +/// that the provided buffer is big enough to fit the value. Therefore marked "unsafe". +template +inline void store(uint8_t* dst, const intx::uint& x) noexcept +{ + const auto d = bswap(x); + std::memcpy(dst, &d, sizeof(d)); +} +} // namespace unsafe + +} // namespace be + +} // namespace intx diff --git a/source/libs/decimal/src/detail/wideInteger.cpp b/source/libs/decimal/src/detail/wideInteger.cpp new file mode 100644 index 0000000000..91f8d56691 --- /dev/null +++ b/source/libs/decimal/src/detail/wideInteger.cpp @@ -0,0 +1,244 @@ +#include "wideInteger.h" +#include "intx/int128.hpp" +#include "intx/intx.hpp" + + +const UInt128 uInt128Zero = {0, 0}; +const uint64_t k1e18 = 1000000000000000000LL; +const UInt128 uInt128_1e18 = {k1e18, 0}; +const UInt128 uInt128One = {1, 0}; +const UInt128 uInt128Two = {2, 0}; + +void makeUInt128(uint128* pUint128, uint64_t high, uint64_t low) { + intx::uint128* pIntxUint = (intx::uint128*)pUint128; + pIntxUint->hi = high; + pIntxUint->lo = low; +} + +uint64_t uInt128Hi(const UInt128* pInt) { + intx::uint128 *pIntUint = (intx::uint128*)pInt; + return pIntUint->hi; +} + +uint64_t uInt128Lo(const UInt128* pInt) { + intx::uint128 *pIntUint = (intx::uint128*)pInt; + return pIntUint->lo; +} + +void uInt128Add(UInt128* pLeft, const UInt128* pRight) { + intx::uint128 *pX = (intx::uint128*)pLeft; + const intx::uint128 *pY = (const intx::uint128*)pRight; + *pX += *pY; +} +void uInt128Subtract(UInt128* pLeft, const UInt128* pRight) { + intx::uint128 
*pX = (intx::uint128*)pLeft; + const intx::uint128 *pY = (const intx::uint128*)pRight; + *pX -= *pY; +} +void uInt128Multiply(UInt128* pLeft, const UInt128* pRight) { + /* + intx::uint128 *pX = (intx::uint128*)pLeft; + const intx::uint128 *pY = (const intx::uint128*)pRight; + *pX *= *pY; */ + __uint128_t *px = (__uint128_t*)pLeft; + const __uint128_t *py = (__uint128_t*)pRight; + *px = *px * *py; +} +void uInt128Divide(UInt128* pLeft, const UInt128* pRight) { + /* + intx::uint128 *pX = (intx::uint128*)pLeft; + const intx::uint128 *pY = (const intx::uint128*)pRight; + *pX /= *pY;*/ + __uint128_t *px = (__uint128_t*)pLeft; + const __uint128_t *py = (__uint128_t*)pRight; + *px = *px / *py; +} +void uInt128Mod(UInt128* pLeft, const UInt128* pRight) { + /* + intx::uint128 *pX = (intx::uint128*)pLeft; + const intx::uint128 *pY = (const intx::uint128*)pRight; + *pX %= *pY;*/ + __uint128_t *px = (__uint128_t*)pLeft; + const __uint128_t *py = (__uint128_t*)pRight; + *px = *px % *py; +} +bool uInt128Lt(const UInt128* pLeft, const UInt128* pRight) { + const intx::uint128 *pX = (const intx::uint128*)pLeft; + const intx::uint128 *pY = (const intx::uint128*)pRight; + return *pX < *pY; +} +bool uInt128Gt(const UInt128* pLeft, const UInt128* pRight) { + const intx::uint128 *pX = (const intx::uint128*)pLeft; + const intx::uint128 *pY = (const intx::uint128*)pRight; + return *pX > *pY; +} +bool uInt128Eq(const UInt128* pLeft, const UInt128* pRight) { + const intx::uint128 *pX = (const intx::uint128*)pLeft; + const intx::uint128 *pY = (const intx::uint128*)pRight; + return *pX == *pY; +} + +Int128 makeInt128(int64_t high, uint64_t low) { + Int128 int128 = {low, high}; + return int128; +} +int64_t int128Hi(const Int128* pUint128) { + return pUint128->high; +} +uint64_t int128Lo(const Int128* pUint128) { + return pUint128->low; +} +Int128 int128Abs(const Int128* pInt128) { + if (int128Lt(pInt128, &int128Zero)) { + return int128Negate(pInt128); + } + return *pInt128; +} +Int128 
int128Negate(const Int128* pInt128) {
+  uint64_t low = ~pInt128->low + 1;              // two's complement: invert the bits, then add one
+  int64_t high = ~pInt128->high;
+  if (low == 0) high = SAFE_INT64_ADD(high, 1);  // carry out of the low word; unsigned add wraps instead of overflowing
+  return makeInt128(high, low);
+}
+Int128 int128Add(const Int128* pLeft, const Int128* pRight) {  // *pLeft + *pRight on the raw 128-bit pattern
+  intx::uint128 result = *(const intx::uint128*)pLeft + *(const intx::uint128*)pRight;
+  return *(Int128*)&result;
+}
+Int128 int128Subtract(const Int128* pLeft, const Int128* pRight) {  // *pLeft - *pRight on the raw 128-bit pattern
+  intx::uint128 result = *(const intx::uint128*)pLeft - *(const intx::uint128*)pRight;
+  return *(Int128*)&result;
+}
+Int128 int128Multiply(const Int128* pLeft, const Int128* pRight) {  // *pLeft * *pRight on the raw 128-bit pattern
+  intx::uint128 result = *(const intx::uint128*)pLeft * *(const intx::uint128*)pRight;
+  return *(Int128*)&result;
+}
+Int128 int128Divide(const Int128* pLeft, const Int128* pRight) {  // NOTE(review): divides the raw bits as unsigned - confirm operands are never negative
+  intx::uint128 result = *(const intx::uint128*)pLeft / *(const intx::uint128*)pRight;
+  return *(Int128*)&result;
+}
+Int128 int128Mod(const Int128* pLeft, const Int128* pRight) {  // NOTE(review): unsigned remainder of the raw bits - same caveat as int128Divide
+  intx::uint128 result = *(const intx::uint128*)pLeft % *(const intx::uint128*)pRight;
+  return *(Int128*)&result;
+}
+bool int128Lt(const Int128* pLeft, const Int128* pRight) {  // signed compare on the high word, unsigned on the low word
+  return pLeft->high < pRight->high || (pLeft->high == pRight->high && pLeft->low < pRight->low);
+}
+bool int128Gt(const Int128* pLeft, const Int128* pRight) {
+  return int128Lt(pRight, pLeft);
+}
+bool int128Eq(const Int128* pLeft, const Int128* pRight) {
+  return pLeft->high == pRight->high && pLeft->low == pRight->low;
+}
+Int128 int128RightShift(const Int128* pLeft, int32_t shift) {  // NOTE(review): shifts as unsigned, so presumably no sign extension - confirm
+  intx::uint128 result = *(const intx::uint128*)pLeft >> shift;
+  return *(Int128*)&result;
+}
+
+const Int128 int128Zero = {0, 0};  // {low, high}
+const Int128 int128One = {1, 0};
+
+UInt256 makeUint256(UInt128 high, UInt128 low) {  // builds a UInt256 from its two 128-bit halves
+  UInt256 uint256 = {low, high};  // struct uint256 declares low before high, so low must come first
+  return uint256;
+}
+uint128 uInt256Hi(const UInt256* pUint256) {
+  return pUint256->high;
+}
+uint128 uInt256Lo(const UInt256* pUint256) {
+  return pUint256->low;
+}
+UInt256 uInt256Add(const UInt256* pLeft, const UInt256* pRight) {
+  intx::uint256 result =
*(const intx::uint256*)pLeft + *(const intx::uint256*)pRight;
+  return *(UInt256*)&result;
+}
+UInt256 uInt256Subtract(const UInt256* pLeft, const UInt256* pRight) {  // *pLeft - *pRight via intx::uint256
+  intx::uint256 result = *(const intx::uint256*)pLeft - *(const intx::uint256*)pRight;
+  return *(UInt256*)&result;
+}
+UInt256 uInt256Multiply(const UInt256* pLeft, const UInt256* pRight) {  // *pLeft * *pRight via intx::uint256
+  intx::uint256 result = *(const intx::uint256*)pLeft * *(const intx::uint256*)pRight;
+  return *(UInt256*)&result;
+}
+UInt256 uInt256Divide(const UInt256* pLeft, const UInt256* pRight) {  // *pLeft / *pRight via intx::uint256
+  intx::uint256 result = *(const intx::uint256*)pLeft / *(const intx::uint256*)pRight;
+  return *(UInt256*)&result;
+}
+UInt256 uInt256Mod(const UInt256* pLeft, const UInt256* pRight) {  // *pLeft % *pRight via intx::uint256
+  intx::uint256 result = *(const intx::uint256*)pLeft % *(const intx::uint256*)pRight;
+  return *(UInt256*)&result;
+}
+bool uInt256Lt(const UInt256* pLeft, const UInt256* pRight) {  // operands are only read, so the casts stay const
+  return *(const intx::uint256*)pLeft < *(const intx::uint256*)pRight;
+}
+bool uInt256Gt(const UInt256* pLeft, const UInt256* pRight) {
+  return *(const intx::uint256*)pLeft > *(const intx::uint256*)pRight;
+}
+bool uInt256Eq(const UInt256* pLeft, const UInt256* pRight) {
+  return *(const intx::uint256*)pLeft == *(const intx::uint256*)pRight;
+}
+UInt256 uInt256RightShift(const UInt256* pLeft, int32_t shift) {  // *pLeft >> shift via intx::uint256
+  intx::uint256 result = *(const intx::uint256*)pLeft >> shift;
+  return *(UInt256*)&result;
+}
+
+Int256 makeInt256(Int128 high, UInt128 low) {  // builds an Int256 from a signed high half and an unsigned low half
+  Int256 int256 = {low, high};  // struct int256 declares low first
+  return int256;
+}
+Int128 int256Hi(const Int256* pUint256) {
+  return pUint256->high;
+}
+UInt128 int256Lo(const Int256* pUint256) {
+  return pUint256->low;
+}
+Int256 int256Abs(const Int256* pInt256) {  // negation when *pInt256 is below zero, otherwise a copy
+  if (int256Lt(pInt256, &int256Zero)) {
+    return int256Negate(pInt256);
+  }
+  return *pInt256;
+}
+
+Int256 int256Negate(const Int256* pInt256) {
+  return int256Subtract(&int256Zero, pInt256);  // 0 - x
+}
+Int256 int256Add(const Int256* pLeft, const Int256* pRight) {  // *pLeft + *pRight on the raw 256-bit pattern
+  intx::uint256 result = *(const intx::uint256*)pLeft + *(const intx::uint256*)pRight;
+  return *(Int256*)&result;
+}
+Int256
int256Subtract(const Int256* pLeft, const Int256* pRight) {
+  intx::uint256 result = *(const intx::uint256*)pLeft - *(const intx::uint256*)pRight;  // raw 256-bit pattern
+  return *(Int256*)&result;
+}
+Int256 int256Multiply(const Int256* pLeft, const Int256* pRight) {  // *pLeft * *pRight on the raw 256-bit pattern
+  intx::uint256 result = *(const intx::uint256*)pLeft * *(const intx::uint256*)pRight;
+  return *(Int256*)&result;
+}
+Int256 int256Divide(const Int256* pLeft, const Int256* pRight) {  // NOTE(review): divides the raw bits as unsigned - confirm operands are never negative
+  intx::uint256 result = *(const intx::uint256*)pLeft / *(const intx::uint256*)pRight;
+  return *(Int256*)&result;
+}
+Int256 int256Mod(const Int256* pLeft, const Int256* pRight) {  // NOTE(review): unsigned remainder of the raw bits - same caveat as int256Divide
+  intx::uint256 result = *(const intx::uint256*)pLeft % *(const intx::uint256*)pRight;
+  return *(Int256*)&result;
+}
+bool int256Lt(const Int256* pLeft, const Int256* pRight) {  // signed compare on the high halves, unsigned on the low halves
+  Int128 hiLeft = int256Hi(pLeft), hiRight = int256Hi(pRight);
+  UInt128 lowLeft = int256Lo(pLeft), lowRight = int256Lo(pRight);
+  return int128Lt(&hiLeft, &hiRight) || (int128Eq(&hiLeft, &hiRight) && uInt128Lt(&lowLeft, &lowRight));
+}
+bool int256Gt(const Int256* pLeft, const Int256* pRight) {
+  return int256Lt(pRight, pLeft);
+}
+bool int256Eq(const Int256* pLeft, const Int256* pRight) {  // equal only when both halves match
+  Int128 hiLeft = int256Hi(pLeft), hiRight = int256Hi(pRight);
+  UInt128 lowLeft = int256Lo(pLeft), lowRight = int256Lo(pRight);
+  return int128Eq(&hiLeft, &hiRight) && uInt128Eq(&lowLeft, &lowRight);
+}
+Int256 int256RightShift(const Int256* pLeft, int32_t shift) {  // NOTE(review): shifts as unsigned, so presumably no sign extension - confirm
+  intx::uint256 result = *(const intx::uint256*)pLeft >> shift;
+  return *(Int256*)&result;
+}
+
+const Int256 int256One = {{1, 0}, {0, 0}};   // literal {low = {low, high}, high = {low, high}}: no C++20 designators,
+const Int256 int256Zero = {{0, 0}, {0, 0}};  // and no copy from other globals, so these are constant-initialized
+const Int256 int256Two = {{2, 0}, {0, 0}};
diff --git a/source/libs/decimal/src/wideInteger.c b/source/libs/decimal/src/wideInteger.c
deleted file mode 100644
index f3f2f7eff8..0000000000
--- a/source/libs/decimal/src/wideInteger.c
+++ /dev/null
@@ -1,26 +0,0 @@
-#include "wideInteger.h"
-
-#if defined(__GNUC__) || defined(__clang__)
-// #if 0
-void
makeUInt128(UInt128* pInt, DecimalWord hi, DecimalWord lo) { *pInt = ((UInt128)hi) << 64 | lo; } -uint64_t uInt128Hi(const UInt128* pInt) { return *pInt >> 64; } -uint64_t uInt128Lo(const UInt128* pInt) { return *pInt & 0xFFFFFFFFFFFFFFFF; } - -void uInt128Abs(UInt128* pInt); -void uInt128Add(UInt128* pLeft, const UInt128* pRight) { *pLeft += *pRight; } -void uInt128Subtract(UInt128* pLeft, const UInt128* pRight); -void uInt128Multiply(UInt128* pLeft, const UInt128* pRight) { *pLeft *= *pRight; } -void uInt128Divide(UInt128* pLeft, const UInt128* pRight) { *pLeft /= *pRight; } -void uInt128Mod(UInt128* pLeft, const UInt128* pRight) { *pLeft %= *pRight; } -bool uInt128Lt(const UInt128* pLeft, const UInt128* pRight); -bool uInt128Gt(const UInt128* pLeft, const UInt128* pRight); -bool uInt128Eq(const UInt128* pLeft, const UInt128* pRight) { return *pLeft == *pRight; } - -const UInt128 uInt128Zero = 0; -const uint64_t k1e18 = 1000000000000000000LL; -const UInt128 uInt128_1e18 = k1e18; -#else - -void uInt128Multiply(UInt128* pLeft, const UInt128* pRight) {} - -#endif diff --git a/source/libs/decimal/test/decimalTest.cpp b/source/libs/decimal/test/decimalTest.cpp index 60a3086f6d..61fef52630 100644 --- a/source/libs/decimal/test/decimalTest.cpp +++ b/source/libs/decimal/test/decimalTest.cpp @@ -1,8 +1,9 @@ #include -#include #include +#include +#include #include -#include +#define ALLOW_FORBID_FUNC #include "decimal.h" #include "tdatablock.h" @@ -184,11 +185,15 @@ class Numeric { static SDataType getRetType(EOperatorType op, const SDataType& lt, const SDataType& rt) { SDataType ot = {0}; - decimalGetRetType(<, &rt, op, &ot); + int32_t code = decimalGetRetType(<, &rt, op, &ot); + if (code != 0) throw std::runtime_error(tstrerror(code)); return ot; } SDataType type() const { - return {.type = NumericType::dataType, .precision = prec(), .scale = scale(), .bytes = NumericType::bytes}; + return {.type = NumericType::dataType, + .precision = prec(), + .scale = scale(), + 
.bytes = NumericType::bytes}; } uint8_t prec() const { return prec_; } uint8_t scale() const { return scale_; } @@ -202,9 +207,15 @@ class Numeric { template Numeric binaryOp(const Numeric& r, EOperatorType op) { - SDataType lt{.type = NumericType::dataType, .precision = prec_, .scale = scale_, .bytes = NumericType::bytes}; - SDataType rt{.type = NumericType::dataType, .precision = r.prec(), .scale = r.scale(), .bytes = NumericType::bytes}; - SDataType ot = getRetType(op, lt, rt); + SDataType lt{.type = NumericType::dataType, + .precision = prec_, + .scale = scale_, + .bytes = NumericType::bytes}; + SDataType rt{.type = NumericType::dataType, + .precision = r.prec(), + .scale = r.scale(), + .bytes = NumericType::bytes}; + SDataType ot = getRetType(op, lt, rt); Numeric out{ot.precision, ot.scale, "0"}; int32_t code = decimalOp(op, <, &rt, &ot, &dec_, &r.dec(), &out); if (code != 0) throw std::overflow_error(tstrerror(code)); @@ -214,7 +225,10 @@ class Numeric { template Numeric binaryOp(const T& r, EOperatorType op) { using TypeInfo = TrivialTypeInfo; - SDataType lt{.type = NumericType::dataType, .precision = prec_, .scale = scale_, .bytes = NumericType::bytes}; + SDataType lt{.type = NumericType::dataType, + .precision = prec_, + .scale = scale_, + .bytes = NumericType::bytes}; SDataType rt{.type = TypeInfo::dataType, .precision = 0, .scale = 0, .bytes = TypeInfo::bytes}; SDataType ot = getRetType(op, lt, rt); Numeric out{ot.precision, ot.scale, "0"}; @@ -236,16 +250,16 @@ class Numeric { DEFINE_OPERATOR(-, OP_TYPE_SUB); DEFINE_OPERATOR(*, OP_TYPE_MULTI); DEFINE_OPERATOR(/, OP_TYPE_DIV); + DEFINE_OPERATOR(%, OP_TYPE_REM); #define DEFINE_TYPE_OP(op, op_type) \ template \ Numeric operator op(const T & r) { \ cout << *this << " " #op " " << r << "(" << typeid(T).name() << ")" << " = "; \ - Numeric res = {}; \ + Numeric res = {}; \ try { \ res = binaryOp(r, op_type); \ } catch (...) 
{ \ - cout << "Exception caught during binaryOp" << endl; \ throw; \ } \ cout << res << endl; \ @@ -255,6 +269,7 @@ class Numeric { DEFINE_TYPE_OP(-, OP_TYPE_SUB); DEFINE_TYPE_OP(*, OP_TYPE_MULTI); DEFINE_TYPE_OP(/, OP_TYPE_DIV); + DEFINE_TYPE_OP(%, OP_TYPE_REM); #define DEFINE_REAL_OP(op) \ double operator op(double v) { \ @@ -344,7 +359,7 @@ class Numeric { } #define DEFINE_OPERATOR_EQ_T(type) \ - Numeric& operator=(type v) { \ + Numeric& operator=(type v) { \ int32_t code = 0; \ if (BitNum == 64) { \ DEFINE_OPERATOR_FROM_FOR_BITNUM(type, 64); \ @@ -368,17 +383,25 @@ class Numeric { DEFINE_OPERATOR_EQ_T(float); Numeric& operator=(const Decimal128& d) { - SDataType inputDt = {.type = TSDB_DATA_TYPE_DECIMAL, .precision = prec(), .scale = scale(), .bytes = DECIMAL128_BYTES}; - SDataType outputDt = {.type = NumericType::dataType, .precision = prec(), .scale = scale(), .bytes = NumericType::bytes}; + SDataType inputDt = { + .type = TSDB_DATA_TYPE_DECIMAL, .precision = prec(), .scale = scale(), .bytes = DECIMAL128_BYTES}; + SDataType outputDt = {.type = NumericType::dataType, + .precision = prec(), + .scale = scale(), + .bytes = NumericType::bytes}; int32_t code = convertToDecimal(&d, &inputDt, &dec_, &outputDt); if (code == TSDB_CODE_DECIMAL_OVERFLOW) throw std::overflow_error(tstrerror(code)); if (code != 0) throw std::runtime_error(tstrerror(code)); return *this; } Numeric& operator=(const Decimal64& d) { - SDataType inputDt = {.type = TSDB_DATA_TYPE_DECIMAL64, .precision = prec(), .scale = scale(), .bytes = DECIMAL64_BYTES}; - SDataType outputDt = {.type = NumericType::dataType, .precision = prec_, .scale = scale_, .bytes = NumericType::bytes}; - int32_t code = convertToDecimal(&d, &inputDt, &dec_, &outputDt); + SDataType inputDt = { + .type = TSDB_DATA_TYPE_DECIMAL64, .precision = prec(), .scale = scale(), .bytes = DECIMAL64_BYTES}; + SDataType outputDt = {.type = NumericType::dataType, + .precision = prec_, + .scale = scale_, + .bytes = NumericType::bytes}; + 
int32_t code = convertToDecimal(&d, &inputDt, &dec_, &outputDt); if (code == TSDB_CODE_DECIMAL_OVERFLOW) throw std::overflow_error(tstrerror(code)); if (code != 0) throw std::runtime_error(tstrerror(code)); return *this; @@ -583,6 +606,7 @@ TEST(decimal128, divide) { cout << " = "; ops->divide(d.words, d2.words, 2, remainder.words); printDecimal(&d, TSDB_DATA_TYPE_DECIMAL, out_precision, out_scale); + ASSERT_TRUE(1); } TEST(decimal, api_taos_fetch_rows) { @@ -854,7 +878,7 @@ class DecimalStringRandomGenerator { std::mt19937 gen_; std::uniform_int_distribution dis_; static const std::array cornerCases; - static const unsigned int ratio_base = 1000000; + static const unsigned int ratio_base = 1000000; public: DecimalStringRandomGenerator() : gen_(rd_()), dis_(0, ratio_base) {} @@ -890,12 +914,10 @@ class DecimalStringRandomGenerator { } private: - int randomInt(int modulus) { return dis_(gen_) % modulus; } - char generateSign(float positive_ratio) { return possible(positive_ratio) ? '+' : '-'; } - char generateDigit() { return randomInt(10) + '0'; } - bool currentShouldGenerateCornerCase(float corner_case_ratio) { - return possible(corner_case_ratio); - } + int randomInt(int modulus) { return dis_(gen_) % modulus; } + char generateSign(float positive_ratio) { return possible(positive_ratio) ? 
'+' : '-'; } + char generateDigit() { return randomInt(10) + '0'; } + bool currentShouldGenerateCornerCase(float corner_case_ratio) { return possible(corner_case_ratio); } string generateCornerCase(const DecimalStringRandomGeneratorConfig& config) { string res{}; if (possible(0.8)) { @@ -912,14 +934,13 @@ class DecimalStringRandomGenerator { return res; } - bool possible(float ratio) { - return randomInt(ratio_base) <= ratio * ratio_base; - } + bool possible(float ratio) { return randomInt(ratio_base) <= ratio * ratio_base; } }; const std::array DecimalStringRandomGenerator::cornerCases = {"0", "NULL", "0.", ".0", "00000.000000"}; TEST(decimal, randomGenerator) { + GTEST_SKIP(); DecimalStringRandomGeneratorConfig config; DecimalStringRandomGenerator generator; for (int i = 0; i < 1000; ++i) { @@ -932,18 +953,23 @@ TEST(deicmal, decimalFromStr_all) { // TODO test e/E } -#define ASSERT_OVERFLOW(op) \ - try { \ - auto res = op; \ - } catch (std::overflow_error & e) { \ - } catch (std::exception & e) { \ - FAIL(); \ - } +#define ASSERT_OVERFLOW(op) \ + do { \ + try { \ + auto res = op; \ + } catch (std::overflow_error & e) { \ + cout << " overflow" << endl; \ + break; \ + } catch (std::exception & e) { \ + FAIL(); \ + } \ + FAIL(); \ + } while (0) TEST(decimal, op_overflow) { // divide 0 error Numeric<128> dec{38, 2, string(36, '9') + ".99"}; - ASSERT_OVERFLOW(dec / 0); // TODO wjm add divide by 0 error code + ASSERT_OVERFLOW(dec / 0); // TODO wjm add divide by 0 error code // test decimal128Max Numeric<128> max{38, 10, "0"}; @@ -951,12 +977,13 @@ TEST(decimal, op_overflow) { ASSERT_EQ(max.toString(), "9999999999999999999999999999.9999999999"); { - // multiply overflow - ASSERT_OVERFLOW(max * 10); - } - { + // multiply no overflow, trim scale + auto res = max * 10; // scale will be trimed to 6, and round up + ASSERT_EQ(res.scale(), 6); + ASSERT_EQ(res.toString(), "100000000000000000000000000000.000000"); + // multiply not overflow, no trim scale - Numeric<64> dec64{18, 
10, "99999999.9999999999"}; + Numeric<64> dec64{18, 10, "99999999.9999999999"}; Numeric<128> dec128{19, 10, "999999999.9999999999"}; auto rett = Numeric<64>::getRetType(OP_TYPE_MULTI, dec64.type(), dec128.type()); @@ -964,7 +991,7 @@ TEST(decimal, op_overflow) { ASSERT_EQ(rett.type, TSDB_DATA_TYPE_DECIMAL); ASSERT_EQ(rett.scale, dec64.scale() + dec128.scale()); - auto res = dec64 * dec128; + res = dec64 * dec128; ASSERT_EQ(res.toString(), "99999999999999999.89000000000000000001"); // multiply not overflow, trim scale from 20 - 19 @@ -979,16 +1006,247 @@ TEST(decimal, op_overflow) { rett = Numeric<128>::getRetType(OP_TYPE_MULTI, dec64.type(), dec128_2.type()); ASSERT_EQ(rett.scale, 18); res = dec64 * dec128_2; - ASSERT_EQ(res.toString(), "9999999999999999990.000000000000000000"); + ASSERT_EQ(res.toString(), "9999999999999999989.990000000000000000"); + + // trim scale from 20 -> 17 + dec128_2 = {22, 10, "999999999999.9999999999"}; + rett = Numeric<128>::getRetType(OP_TYPE_MULTI, dec64.type(), dec128_2.type()); + ASSERT_EQ(rett.scale, 17); + res = dec64 * dec128_2; + ASSERT_EQ(res.toString(), "99999999999999999899.99000000000000000"); + + // trim scale from 20 -> 6 + dec128_2 = {33, 10, "99999999999999999999999.9999999999"}; + rett = Numeric<128>::getRetType(OP_TYPE_MULTI, dec64.type(), dec128_2.type()); + ASSERT_EQ(rett.scale, 6); + res = dec64 * dec128_2; + ASSERT_EQ(res.toString(), "9999999999999999989999999999999.990000"); + + dec128_2 = {34, 10, "999999999999999999999999.9999999999"}; + rett = Numeric<128>::getRetType(OP_TYPE_MULTI, dec64.type(), dec128_2.type()); + ASSERT_EQ(rett.scale, 6); + res = dec64 * dec128_2; + ASSERT_EQ(res.toString(), "99999999999999999899999999999999.990000"); + + dec128_2 = {35, 10, "9999999999999999999999999.9999999999"}; + rett = Numeric<128>::getRetType(OP_TYPE_MULTI, dec64.type(), dec128_2.type()); + ASSERT_EQ(rett.scale, 6); + ASSERT_OVERFLOW(dec64 * dec128_2); } - - { - // multiply middle res overflow, but final res not overflow 
- // same scale multiply - // different scale multiply + // divide not overflow but trim scale + Numeric<128> dec128{19, 10, "999999999.9999999999"}; + Numeric<64> dec64{10, 10, "0.10000000"}; + auto res = dec128 / dec64; + ASSERT_EQ(res.scale(), 19); + ASSERT_EQ(res.toString(), "9999999999.9999999990000000000"); + + dec64 = {10, 10, "0.1111111111"}; + res = dec128 / dec64; + ASSERT_EQ(res.scale(), 19); + ASSERT_EQ(res.toString(), "9000000000.8999999991899999999"); + + dec64 = {10, 2, "0.01"}; + res = dec128 / dec64; + ASSERT_EQ(res.scale(), 21); + ASSERT_EQ(res.prec(), 32); + ASSERT_EQ(res.toString(), "99999999999.999999990000000000000"); + + dec64 = {10, 2, "7.77"}; + int32_t a = 2; + res = dec64 % a; + ASSERT_EQ(res.toString(), "1.77"); + + dec128 = {38, 10, "999999999999999999999999999.9999999999"}; + res = dec128 % dec64; + ASSERT_EQ(res.toString(), "5.4399999999"); + + dec64 = {18, 10, "99999999.9999999999"}; + res = dec128 % dec64; + ASSERT_EQ(res.toString(), "0.0000000009"); + + Numeric<128> dec128_2 = {38, 10, "9988888888888888888888888.1111111111"}; + res = dec128 % dec128_2; + ASSERT_EQ(res.toString(), "1111111111111111111111188.8888888899"); + + dec128 = {38, 10, "9999999999999999999999999999.9999999999"}; + dec128_2 = {38, 2, "999999999999999999999999999988123123.88"}; + res = dec128 % dec128_2; + ASSERT_EQ(res.toString(), "9999999999999999999999999999.9999999999"); + } +} + +EOperatorType get_op_type(char op) { + switch (op) { + case '+': + return OP_TYPE_ADD; + case '-': + return OP_TYPE_SUB; + case '*': + return OP_TYPE_MULTI; + case '/': + return OP_TYPE_DIV; + case '%': + return OP_TYPE_REM; + default: + return OP_TYPE_IS_UNKNOWN; + } +} + +std::string get_op_str(EOperatorType op) { + switch (op) { + case OP_TYPE_ADD: + return "+"; + case OP_TYPE_SUB: + return "-"; + case OP_TYPE_MULTI: + return "*"; + case OP_TYPE_DIV: + return "/"; + case OP_TYPE_REM: + return "%"; + default: + return "unknown"; + } +} + +struct DecimalRetTypeCheckConfig { + 
bool check_res_type = true; + bool check_bytes = true; + bool log = true; +}; +struct DecimalRetTypeCheckContent { + DecimalRetTypeCheckContent(const SDataType& a, const SDataType& b, const SDataType& out, EOperatorType op) + : type_a(a), type_b(b), type_out(out), op_type(op) {} + SDataType type_a; + SDataType type_b; + SDataType type_out; + EOperatorType op_type; + // (1, 0) / (1, 1) = (8, 6) + // (1, 0) / (1, 1) = (8, 6) + DecimalRetTypeCheckContent(const std::string& s) { + char op = '\0'; + sscanf(s.c_str(), "(%hhu, %hhu) %c (%hhu, %hhu) = (%hhu, %hhu)", &type_a.precision, &type_a.scale, &op, + &type_b.precision, &type_b.scale, &type_out.precision, &type_out.scale); + type_a = getDecimalType(type_a.precision, type_a.scale); + type_b = getDecimalType(type_b.precision, type_b.scale); + type_out = getDecimalType(type_out.precision, type_out.scale); + op_type = get_op_type(op); } + void check(const DecimalRetTypeCheckConfig &config = DecimalRetTypeCheckConfig()) { + SDataType ret = {0}; + try { + if (config.log) + cout << "check ret type for type: (" << (int)type_a.type << " " << (int)type_a.precision << " " + << (int)type_a.scale << ") " << get_op_str(op_type) << " (" << (int)type_b.type << " " + << (int)type_b.precision << " " << (int)type_b.scale << ") = \n"; + ret = Numeric<64>::getRetType(op_type, type_a, type_b); + } catch (std::runtime_error& e) { + ASSERT_EQ(type_out.type, TSDB_DATA_TYPE_MAX); + if (config.log) cout << "not support!" 
<< endl; + return; + } + if (config.log) + cout << "(" << (int)ret.type << " " << (int)ret.precision << " " << (int)ret.scale << ") expect:" << endl + << "(" << (int)type_out.type << " " << (int)type_out.precision << " " << (int)type_out.scale << ")" << endl; + if (config.check_res_type) ASSERT_EQ(ret.type, type_out.type); + ASSERT_EQ(ret.precision, type_out.precision); + ASSERT_EQ(ret.scale, type_out.scale); + if (config.check_bytes) ASSERT_EQ(ret.bytes, type_out.bytes); + } +}; + +TEST(decimal_all, ret_type_load_from_file) { + GTEST_SKIP(); + std::string fname = "/tmp/ret_type.txt"; + std::ifstream ifs(fname, std::ios_base::in); + if (!ifs.is_open()) { + std::cerr << "open file " << fname << " failed" << std::endl; + FAIL(); + } + char buf[64]; + int32_t total_lines = 0; + while (ifs.getline(buf, 64, '\n')) { + DecimalRetTypeCheckContent dcc(buf); + DecimalRetTypeCheckConfig config; + config.check_res_type = false; + config.check_bytes = false; + config.log = false; + dcc.check(config); + ++total_lines; + } + ASSERT_EQ(total_lines, 3034205); +} + +TEST(decimal_all, ret_type_for_non_decimal_types) { + std::vector non_decimal_types; + SDataType decimal_type = {TSDB_DATA_TYPE_DECIMAL64, 10, 2, 8}; + EOperatorType op = OP_TYPE_DIV; + std::vector out_types; + auto count_digits = [](uint64_t v) { + return std::floor(std::log10(v) + 1); + }; + std::vector equivalent_decimal_types; + // #define TSDB_DATA_TYPE_NULL 0 // 1 bytes + equivalent_decimal_types.push_back({TSDB_DATA_TYPE_NULL, 0, 0, tDataTypes[TSDB_DATA_TYPE_NULL].bytes}); + // #define TSDB_DATA_TYPE_BOOL 1 // 1 bytes + equivalent_decimal_types.push_back(getDecimalType(1, 0)); + // #define TSDB_DATA_TYPE_TINYINT 2 // 1 byte + equivalent_decimal_types.push_back(getDecimalType(count_digits(INT8_MAX), 0)); + // #define TSDB_DATA_TYPE_SMALLINT 3 // 2 bytes + equivalent_decimal_types.push_back(getDecimalType(count_digits(INT16_MAX), 0)); + // #define TSDB_DATA_TYPE_INT 4 // 4 bytes + 
equivalent_decimal_types.push_back(getDecimalType(count_digits(INT32_MAX), 0)); + // #define TSDB_DATA_TYPE_BIGINT 5 // 8 bytes + equivalent_decimal_types.push_back(getDecimalType(count_digits(INT64_MAX), 0)); + // #define TSDB_DATA_TYPE_FLOAT 6 // 4 bytes + equivalent_decimal_types.push_back({TSDB_DATA_TYPE_DOUBLE, 0, 0, tDataTypes[TSDB_DATA_TYPE_DOUBLE].bytes}); + // #define TSDB_DATA_TYPE_DOUBLE 7 // 8 bytes + equivalent_decimal_types.push_back({TSDB_DATA_TYPE_DOUBLE, 0, 0, tDataTypes[TSDB_DATA_TYPE_DOUBLE].bytes}); + // #define TSDB_DATA_TYPE_VARCHAR 8 // string, alias for varchar + equivalent_decimal_types.push_back({TSDB_DATA_TYPE_DOUBLE, 0, 0, tDataTypes[TSDB_DATA_TYPE_DOUBLE].bytes}); + // #define TSDB_DATA_TYPE_TIMESTAMP 9 // 8 bytes + equivalent_decimal_types.push_back(getDecimalType(count_digits(INT64_MAX), 0)); + // #define TSDB_DATA_TYPE_NCHAR 10 // unicode string + equivalent_decimal_types.push_back({TSDB_DATA_TYPE_DOUBLE, 0, 0, tDataTypes[TSDB_DATA_TYPE_DOUBLE].bytes}); + // #define TSDB_DATA_TYPE_UTINYINT 11 // 1 byte + equivalent_decimal_types.push_back(getDecimalType(count_digits(UINT8_MAX), 0)); + // #define TSDB_DATA_TYPE_USMALLINT 12 // 2 bytes + equivalent_decimal_types.push_back(getDecimalType(count_digits(UINT16_MAX), 0)); + // #define TSDB_DATA_TYPE_UINT 13 // 4 bytes + equivalent_decimal_types.push_back(getDecimalType(count_digits(UINT32_MAX), 0)); + // #define TSDB_DATA_TYPE_UBIGINT 14 // 8 bytes + equivalent_decimal_types.push_back(getDecimalType(count_digits(UINT64_MAX), 0)); + // #define TSDB_DATA_TYPE_JSON 15 // json string + equivalent_decimal_types.push_back({TSDB_DATA_TYPE_MAX, 0, 0, 0}); + // #define TSDB_DATA_TYPE_VARBINARY 16 // binary + equivalent_decimal_types.push_back({TSDB_DATA_TYPE_MAX, 0, 0, 0}); + // #define TSDB_DATA_TYPE_DECIMAL 17 // decimal + equivalent_decimal_types.push_back({TSDB_DATA_TYPE_MAX, 0, 0, 0}); + // #define TSDB_DATA_TYPE_BLOB 18 // binary + equivalent_decimal_types.push_back({TSDB_DATA_TYPE_MAX, 0, 0, 
0}); + // #define TSDB_DATA_TYPE_MEDIUMBLOB 19 + equivalent_decimal_types.push_back({TSDB_DATA_TYPE_MAX, 0, 0, 0}); + // #define TSDB_DATA_TYPE_GEOMETRY 20 // geometry + equivalent_decimal_types.push_back({TSDB_DATA_TYPE_MAX, 0, 0, 0}); + // #define TSDB_DATA_TYPE_DECIMAL64 21 // decimal64 + equivalent_decimal_types.push_back({TSDB_DATA_TYPE_MAX, 0, 0, 0}); + + for (uint8_t i = 0; i < TSDB_DATA_TYPE_MAX; ++i) { + if (IS_DECIMAL_TYPE(i)) continue; + SDataType equivalent_out_type = equivalent_decimal_types[i]; + if (equivalent_out_type.type != TSDB_DATA_TYPE_MAX) + equivalent_out_type = Numeric<128>::getRetType(op, decimal_type, equivalent_decimal_types[i]); + DecimalRetTypeCheckContent dcc{decimal_type, {i, 0, 0, tDataTypes[i].bytes}, equivalent_out_type, op}; + dcc.check(); + + if (equivalent_out_type.type != TSDB_DATA_TYPE_MAX) { + equivalent_out_type = Numeric<128>::getRetType(op, equivalent_decimal_types[i], decimal_type); + } + DecimalRetTypeCheckContent dcc2{{i, 0, 0, tDataTypes[i].bytes}, decimal_type, equivalent_out_type, op}; + dcc2.check(); + } } int main(int argc, char** argv) { diff --git a/source/libs/executor/CMakeLists.txt b/source/libs/executor/CMakeLists.txt index 9a49076b6b..9a6b52bfb5 100644 --- a/source/libs/executor/CMakeLists.txt +++ b/source/libs/executor/CMakeLists.txt @@ -11,7 +11,7 @@ if(${BUILD_WITH_ANALYSIS}) endif() target_link_libraries(executor - PRIVATE os util common function parser planner qcom scalar nodes index wal tdb geometry + PRIVATE os util common function parser planner qcom scalar nodes index wal tdb geometry profiler ) target_include_directories( diff --git a/source/libs/executor/src/operator.c b/source/libs/executor/src/operator.c index 057deed038..96595a6a7f 100644 --- a/source/libs/executor/src/operator.c +++ b/source/libs/executor/src/operator.c @@ -28,6 +28,7 @@ #include "storageapi.h" #include "tdatablock.h" +#include "gperftools/profiler.h" SOperatorFpSet createOperatorFpSet(__optr_open_fn_t openFn, __optr_fn_t 
nextFn, __optr_fn_t cleanup, __optr_close_fn_t closeFn, __optr_reqBuf_fn_t reqBufFn, __optr_explain_fn_t explain, @@ -282,6 +283,7 @@ int32_t stopTableScanOperator(SOperatorInfo* pOperator, const char* pIdStr, SSto int32_t createOperator(SPhysiNode* pPhyNode, SExecTaskInfo* pTaskInfo, SReadHandle* pHandle, SNode* pTagCond, SNode* pTagIndexCond, const char* pUser, const char* dbname, SOperatorInfo** pOptrInfo) { QRY_PARAM_CHECK(pOptrInfo); + ProfilerStart("/tmp/createOperator.prof"); int32_t code = 0; int32_t type = nodeType(pPhyNode); @@ -654,6 +656,7 @@ int32_t createOperator(SPhysiNode* pPhyNode, SExecTaskInfo* pTaskInfo, SReadHand } void destroyOperator(SOperatorInfo* pOperator) { + ProfilerFlush(); if (pOperator == NULL) { return; } diff --git a/source/libs/scalar/src/filter.c b/source/libs/scalar/src/filter.c index 4c070ef447..879ae187d1 100644 --- a/source/libs/scalar/src/filter.c +++ b/source/libs/scalar/src/filter.c @@ -4204,6 +4204,7 @@ static int32_t fltSclBuildDecimalDatumFromValueNode(SFltSclDatum* datum, SColumn void *pData = NULL; if (datum->type.type == TSDB_DATA_TYPE_DECIMAL64) { pData = &datum->i; // TODO wjm set kind + datum->kind = FLT_SCL_DATUM_KIND_DECIMAL64; } else if (datum->type.type == TSDB_DATA_TYPE_DECIMAL) { pData = taosMemoryCalloc(1, pColNode->node.resType.bytes); if (!pData) FLT_ERR_RET(terrno); diff --git a/source/libs/scalar/src/sclvector.c b/source/libs/scalar/src/sclvector.c index c0246356a6..d54d480c22 100644 --- a/source/libs/scalar/src/sclvector.c +++ b/source/libs/scalar/src/sclvector.c @@ -1655,37 +1655,40 @@ int32_t vectorMathRemainder(SScalarParam *pLeft, SScalarParam *pRight, SScalarPa int32_t leftConvert = 0, rightConvert = 0; SColumnInfoData *pLeftCol = NULL; SColumnInfoData *pRightCol = NULL; - SCL_ERR_JRET(vectorConvertVarToDouble(pLeft, &leftConvert, &pLeftCol)); - SCL_ERR_JRET(vectorConvertVarToDouble(pRight, &rightConvert, &pRightCol)); + if (pOutputCol->info.type == TSDB_DATA_TYPE_DECIMAL) { + 
SCL_ERR_JRET(vectorMathOpForDecimal(pLeft, pRight, pOut, step, i, OP_TYPE_REM)); + } else { + SCL_ERR_JRET(vectorConvertVarToDouble(pLeft, &leftConvert, &pLeftCol)); + SCL_ERR_JRET(vectorConvertVarToDouble(pRight, &rightConvert, &pRightCol)); - _getDoubleValue_fn_t getVectorDoubleValueFnLeft; - _getDoubleValue_fn_t getVectorDoubleValueFnRight; - SCL_ERR_JRET(getVectorDoubleValueFn(pLeftCol->info.type, &getVectorDoubleValueFnLeft)); - SCL_ERR_JRET(getVectorDoubleValueFn(pRightCol->info.type, &getVectorDoubleValueFnRight)); + _getDoubleValue_fn_t getVectorDoubleValueFnLeft; + _getDoubleValue_fn_t getVectorDoubleValueFnRight; + SCL_ERR_JRET(getVectorDoubleValueFn(pLeftCol->info.type, &getVectorDoubleValueFnLeft)); + SCL_ERR_JRET(getVectorDoubleValueFn(pRightCol->info.type, &getVectorDoubleValueFnRight)); - double *output = (double *)pOutputCol->pData; + double *output = (double *)pOutputCol->pData; - int32_t numOfRows = TMAX(pLeft->numOfRows, pRight->numOfRows); - for (; i < numOfRows && i >= 0; i += step, output += 1) { - int32_t leftidx = pLeft->numOfRows == 1 ? 0 : i; - int32_t rightidx = pRight->numOfRows == 1 ? 0 : i; - if (IS_HELPER_NULL(pLeftCol, leftidx) || IS_HELPER_NULL(pRightCol, rightidx)) { - colDataSetNULL(pOutputCol, i); - continue; + int32_t numOfRows = TMAX(pLeft->numOfRows, pRight->numOfRows); + for (; i < numOfRows && i >= 0; i += step, output += 1) { + int32_t leftidx = pLeft->numOfRows == 1 ? 0 : i; + int32_t rightidx = pRight->numOfRows == 1 ? 
0 : i; + if (IS_HELPER_NULL(pLeftCol, leftidx) || IS_HELPER_NULL(pRightCol, rightidx)) { + colDataSetNULL(pOutputCol, i); + continue; + } + + double lx = 0; + double rx = 0; + SCL_ERR_JRET(getVectorDoubleValueFnLeft(LEFT_COL, leftidx, &lx)); + SCL_ERR_JRET(getVectorDoubleValueFnRight(RIGHT_COL, rightidx, &rx)); + if (isnan(lx) || isinf(lx) || isnan(rx) || isinf(rx) || FLT_EQUAL(rx, 0)) { + colDataSetNULL(pOutputCol, i); + continue; + } + + *output = lx - ((int64_t)(lx / rx)) * rx; } - - double lx = 0; - double rx = 0; - SCL_ERR_JRET(getVectorDoubleValueFnLeft(LEFT_COL, leftidx, &lx)); - SCL_ERR_JRET(getVectorDoubleValueFnRight(RIGHT_COL, rightidx, &rx)); - if (isnan(lx) || isinf(lx) || isnan(rx) || isinf(rx) || FLT_EQUAL(rx, 0)) { - colDataSetNULL(pOutputCol, i); - continue; - } - - *output = lx - ((int64_t)(lx / rx)) * rx; } - _return: doReleaseVec(pLeftCol, leftConvert); doReleaseVec(pRightCol, rightConvert); @@ -2276,15 +2279,28 @@ static int32_t vectorMathOpOneRowForDecimal(SScalarParam *pLeft, SScalarParam *p outType = GET_COL_DATA_TYPE(pOut->columnData->info); if (IS_HELPER_NULL(pOneRowParam->columnData, 0)) { colDataSetNNULL(pOut->columnData, 0, pNotOneRowParam->numOfRows); + } + Decimal oneRowData = {0}; + SDataType oneRowType = outType; + if (pLeft == pOneRowParam) { + oneRowType.scale = leftType.scale; + code = convertToDecimal(colDataGetData(pLeft->columnData, 0), &leftType, &oneRowData, &oneRowType); } else { - for (; i < pNotOneRowParam->numOfRows && i >= 0 && TSDB_CODE_SUCCESS == code; i += step, output += 1) { - if (IS_HELPER_NULL(pNotOneRowParam->columnData, i)) { - colDataSetNULL(pOut->columnData, i); - continue; - } - code = decimalOp(op, &leftType, &rightType, &outType, - colDataGetData(pLeft->columnData, pLeft == pOneRowParam ? 0 : i), - colDataGetData(pRight->columnData, pRight == pOneRowParam ? 
0 : i), output); + oneRowType.scale = rightType.scale; + code = convertToDecimal(colDataGetData(pRight->columnData, 0), &rightType, &oneRowData, &oneRowType); + } + if (code != 0) return code; + + for (; i < pNotOneRowParam->numOfRows && i >= 0 && TSDB_CODE_SUCCESS == code; i += step, output += 1) { + if (IS_HELPER_NULL(pNotOneRowParam->columnData, i)) { + colDataSetNULL(pOut->columnData, i); + continue; + } + if (pOneRowParam == pLeft) { + code = + decimalOp(op, &oneRowType, &rightType, &outType, &oneRowData, colDataGetData(pRight->columnData, i), output); + } else { + code = decimalOp(op, &leftType, &oneRowType, &outType, colDataGetData(pLeft->columnData, i), &oneRowData, output); } } return code;