From caccf4add932b447e3cc837d3e9b6964334c2594 Mon Sep 17 00:00:00 2001 From: wangjiaming0909 Date: Wed, 5 Mar 2025 21:51:37 +0800 Subject: [PATCH] test decimal agg funcs --- include/common/tcommon.h | 9 +- include/libs/decimal/decimal.h | 1 + source/common/src/tdataformat.c | 52 +++-- source/dnode/vnode/src/tsdb/tsdbUtil.c | 26 ++- source/libs/decimal/src/decimal.c | 12 + source/libs/decimal/test/decimalTest.cpp | 11 + source/libs/function/inc/functionResInfoInt.h | 20 +- source/libs/function/src/builtinsimpl.c | 37 ++-- .../libs/function/src/detail/tavgfunction.c | 37 ++-- source/libs/function/src/detail/tminmax.c | 2 +- source/libs/scalar/src/filter.c | 2 +- tests/system-test/2-query/decimal.py | 205 ++++++++++++++++-- 12 files changed, 327 insertions(+), 87 deletions(-) diff --git a/include/common/tcommon.h b/include/common/tcommon.h index d3e58d8afe..f17104f818 100644 --- a/include/common/tcommon.h +++ b/include/common/tcommon.h @@ -177,20 +177,21 @@ typedef struct SColumnDataAgg { struct { uint64_t decimal128Sum[2]; uint64_t decimal128Max[2]; - uint64_t decimal128Min[2]; // TODO wjm 1. use deicmal128Sum for decimal64, 2. add overflow flag + uint64_t decimal128Min[2]; // TODO wjm 1. use deicmal128Sum for decimal64, 2. add overflow flag + uint8_t overflow; }; }; } SColumnDataAgg; #pragma pack(pop) #define COL_AGG_GET_SUM_PTR(pAggs, dataType) \ - (dataType != TSDB_DATA_TYPE_DECIMAL ? (void*)&pAggs->sum : (void*)pAggs->decimal128Sum) + (!IS_DECIMAL_TYPE(dataType) ? (void*)&pAggs->sum : (void*)pAggs->decimal128Sum) #define COL_AGG_GET_MAX_PTR(pAggs, dataType) \ - (dataType != TSDB_DATA_TYPE_DECIMAL ? (void*)&pAggs->max : (void*)pAggs->decimal128Max) + (!IS_DECIMAL_TYPE(dataType) ? (void*)&pAggs->max : (void*)pAggs->decimal128Max) #define COL_AGG_GET_MIN_PTR(pAggs, dataType) \ - (dataType != TSDB_DATA_TYPE_DECIMAL ? (void*)&pAggs->min : (void*)pAggs->decimal128Min) + (!IS_DECIMAL_TYPE(dataType) ? (void*)&pAggs->min : (void*)pAggs->decimal128Min) typedef struct SBlockID { // The uid of table, from which current data block comes. And it is always 0, if current block is the diff --git a/include/libs/decimal/decimal.h b/include/libs/decimal/decimal.h index b4df324ee7..5fec585cdb 100644 --- a/include/libs/decimal/decimal.h +++ b/include/libs/decimal/decimal.h @@ -95,6 +95,7 @@ bool decimalCompare(EOperatorType op, const SDecimalCompareCtx* pLeft, const int32_t decimalOp(EOperatorType op, const SDataType* pLeftT, const SDataType* pRightT, const SDataType* pOutT, const void* pLeftData, const void* pRightData, void* pOutputData); int32_t convertToDecimal(const void* pData, const SDataType* pInputType, void* pOut, const SDataType* pOutType); +bool decimal128AddCheckOverflow(const Decimal128* pLeft, const DecimalType* pRight, uint8_t rightWordNum); DEFINE_TYPE_FROM_DECIMAL_FUNCS(, Decimal64); DEFINE_TYPE_FROM_DECIMAL_FUNCS(, Decimal128); diff --git a/source/common/src/tdataformat.c b/source/common/src/tdataformat.c index b5be8123d3..528019c248 100644 --- a/source/common/src/tdataformat.c +++ b/source/common/src/tdataformat.c @@ -4222,29 +4222,36 @@ static FORCE_INLINE void tColDataCalcSMAVarType(SColData *pColData, SColumnDataA } } -#define CALC_DECIMAL_SUM_MAX_MIN(TYPE, pOps, pColData, pSum, pMax, pMin) \ - for (int32_t iVal = 0; iVal < pColData->nVal; ++iVal) { \ - pVal = ((TYPE *)pColData->pData) + iVal; \ - pOps->add(pSum, pVal, WORD_NUM(TYPE)); \ - if (pOps->gt(pVal, pMax, WORD_NUM(TYPE))) { \ - *(pMax) = *pVal; \ - } \ - if (pOps->lt(pVal, pMin, WORD_NUM(TYPE))) { \ - *(pMin) = *pVal; \ - } \ - } +#define CALC_DECIMAL_SUM_MAX_MIN(TYPE, pSumOp, pCompOp, pColData, pSum, pMax, pMin) \ + do { \ + if (decimal128AddCheckOverflow((Decimal *)pSum, pVal, WORD_NUM(TYPE))) *pOverflow = true; \ + pSumOp->add(pSum, pVal, WORD_NUM(TYPE)); \ + if (pCompOp->gt(pVal, pMax, WORD_NUM(TYPE))) { \ + *(pMax) = *pVal; \ + } \ + if (pCompOp->lt(pVal, pMin, WORD_NUM(TYPE))) { \ + *(pMin) = *pVal; \ + } \ + } while (0) static FORCE_INLINE void tColDataCalcSMADecimal64Type(SColData* pColData, SColumnDataAgg* pAggs) { - Decimal64* pSum = (Decimal64*)&pAggs->sum, *pMax = (Decimal64*)&pAggs->max, *pMin = (Decimal64*)&pAggs->min; - *pSum = DECIMAL64_ZERO; + Decimal128 *pSum = (Decimal128 *)pAggs->decimal128Sum; + Decimal64 *pMax = (Decimal64 *)pAggs->decimal128Max, *pMin = (Decimal64 *)pAggs->decimal128Min; + uint8_t *pOverflow = &pAggs->overflow; + *pSum = DECIMAL128_ZERO; *pMax = DECIMAL64_MIN; *pMin = DECIMAL64_MAX; pAggs->numOfNull = 0; + pAggs->colId |= 0x80000000; // TODO wjm define it Decimal64 *pVal = NULL; - SDecimalOps *pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL64); + const SDecimalOps *pSumOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL); + const SDecimalOps *pCompOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL64); if (HAS_VALUE == pColData->flag) { - CALC_DECIMAL_SUM_MAX_MIN(Decimal64, pOps, pColData, pSum, pMax, pMin); + for (int32_t iVal = 0; iVal < pColData->nVal; ++iVal) { + pVal = ((Decimal64*)pColData->pData) + iVal; + CALC_DECIMAL_SUM_MAX_MIN(Decimal64, pSumOps, pCompOps, pColData, pSum, pMax, pMin); + } } else { for (int32_t iVal = 0; iVal < pColData->nVal; ++iVal) { switch (tColDataGetBitValue(pColData, iVal)) { @@ -4253,7 +4260,8 @@ static FORCE_INLINE void tColDataCalcSMADecimal64Type(SColData* pColData, SColum pAggs->numOfNull++; break; case 2: - CALC_DECIMAL_SUM_MAX_MIN(Decimal64, pOps, pColData, pSum, pMax, pMin);// TODO wjm what if overflow + pVal = ((Decimal64 *)pColData->pData) + iVal; + CALC_DECIMAL_SUM_MAX_MIN(Decimal64, pSumOps, pCompOps, pColData, pSum, pMax, pMin); break; default: break; @@ -4263,7 +4271,9 @@ static FORCE_INLINE void tColDataCalcSMADecimal64Type(SColData* pColData, SColum } static FORCE_INLINE void tColDataCalcSMADecimal128Type(SColData* pColData, SColumnDataAgg* pAggs) { - Decimal128* pSum = (Decimal128*)pAggs->decimal128Sum, *pMax = (Decimal128*)pAggs->decimal128Max, *pMin = (Decimal128*)pAggs->decimal128Min; + Decimal128 *pSum = (Decimal128 *)pAggs->decimal128Sum, *pMax = (Decimal128 *)pAggs->decimal128Max, + *pMin = (Decimal128 *)pAggs->decimal128Min; + uint8_t *pOverflow = &pAggs->overflow; *pSum = DECIMAL128_ZERO; *pMax = DECIMAL128_MIN; *pMin = DECIMAL128_MAX; @@ -4273,7 +4283,10 @@ static FORCE_INLINE void tColDataCalcSMADecimal128Type(SColData* pColData, SColu Decimal128 *pVal = NULL; SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL); if (HAS_VALUE == pColData->flag) { - CALC_DECIMAL_SUM_MAX_MIN(Decimal128, pOps, pColData, pSum, pMax, pMin); + for (int32_t iVal = 0; iVal < pColData->nVal; ++iVal) { + pVal = ((Decimal128*)pColData->pData) + iVal; + CALC_DECIMAL_SUM_MAX_MIN(Decimal128, pOps, pOps, pColData, pSum, pMax, pMin); + } } else { for (int32_t iVal = 0; iVal < pColData->nVal; ++iVal) { switch (tColDataGetBitValue(pColData, iVal)) { @@ -4282,7 +4295,8 @@ static FORCE_INLINE void tColDataCalcSMADecimal128Type(SColData* pColData, SColu pAggs->numOfNull++; break; case 2: - CALC_DECIMAL_SUM_MAX_MIN(Decimal128, pOps, pColData, pSum, pMax, pMin); + pVal = ((Decimal128*)pColData->pData) + iVal; + CALC_DECIMAL_SUM_MAX_MIN(Decimal128, pOps, pOps, pColData, pSum, pMax, pMin); break; default: break; diff --git a/source/dnode/vnode/src/tsdb/tsdbUtil.c b/source/dnode/vnode/src/tsdb/tsdbUtil.c index 51c7a5049c..454b0d0e85 100644 --- a/source/dnode/vnode/src/tsdb/tsdbUtil.c +++ b/source/dnode/vnode/src/tsdb/tsdbUtil.c @@ -1577,12 +1577,13 @@ int32_t tPutColumnDataAgg(SBuffer *buffer, SColumnDataAgg *pColAgg) { if (pColAgg->colId & 0x80000000) { if ((code = tBufferPutI32v(buffer, pColAgg->colId))) return code; if ((code = tBufferPutI16v(buffer, pColAgg->numOfNull))) return code; - if ((code = tBufferPutI64(buffer, pColAgg->decimal128Sum[0]))) return code; - if ((code = tBufferPutI64(buffer, pColAgg->decimal128Sum[1]))) return code; - if ((code = tBufferPutI64(buffer, pColAgg->decimal128Max[0]))) return code; - if ((code = tBufferPutI64(buffer, pColAgg->decimal128Max[1]))) return code; - if ((code = tBufferPutI64(buffer, pColAgg->decimal128Min[0]))) return code; - if ((code = tBufferPutI64(buffer, pColAgg->decimal128Min[1]))) return code; + if ((code = tBufferPutU64(buffer, pColAgg->decimal128Sum[0]))) return code; + if ((code = tBufferPutU64(buffer, pColAgg->decimal128Sum[1]))) return code; + if ((code = tBufferPutU64(buffer, pColAgg->decimal128Max[0]))) return code; + if ((code = tBufferPutU64(buffer, pColAgg->decimal128Max[1]))) return code; + if ((code = tBufferPutU64(buffer, pColAgg->decimal128Min[0]))) return code; + if ((code = tBufferPutU64(buffer, pColAgg->decimal128Min[1]))) return code; + if ((code = tBufferPutU8(buffer, pColAgg->overflow))) return code; } else { if ((code = tBufferPutI32v(buffer, pColAgg->colId))) return code; if ((code = tBufferPutI16v(buffer, pColAgg->numOfNull))) return code; @@ -1601,12 +1602,13 @@ int32_t tGetColumnDataAgg(SBufferReader *br, SColumnDataAgg *pColAgg) { if ((code = tBufferGetI16v(br, &pColAgg->numOfNull))) return code; if (pColAgg->colId & 0x80000000) { pColAgg->colId &= 0xFFFF; - if ((code = tBufferGetI64(br, &pColAgg->decimal128Sum[0]))) return code; - if ((code = tBufferGetI64(br, &pColAgg->decimal128Sum[1]))) return code; - if ((code = tBufferGetI64(br, &pColAgg->decimal128Max[0]))) return code; - if ((code = tBufferGetI64(br, &pColAgg->decimal128Max[1]))) return code; - if ((code = tBufferGetI64(br, &pColAgg->decimal128Min[0]))) return code; - if ((code = tBufferGetI64(br, &pColAgg->decimal128Min[1]))) return code; + if ((code = tBufferGetU64(br, &pColAgg->decimal128Sum[0]))) return code; + if ((code = tBufferGetU64(br, &pColAgg->decimal128Sum[1]))) return code; + if ((code = tBufferGetU64(br, &pColAgg->decimal128Max[0]))) return code; + if ((code = tBufferGetU64(br, &pColAgg->decimal128Max[1]))) return code; + if ((code = tBufferGetU64(br, &pColAgg->decimal128Min[0]))) return code; + if ((code = tBufferGetU64(br, &pColAgg->decimal128Min[1]))) return code; + if ((code = tBufferGetU8(br, &pColAgg->overflow))) return code; } else { if ((code = tBufferGetI64(br, &pColAgg->sum))) return code; if ((code = tBufferGetI64(br, &pColAgg->max))) return code; diff --git a/source/libs/decimal/src/decimal.c b/source/libs/decimal/src/decimal.c index 6856cd5a85..887f93446d 100644 --- a/source/libs/decimal/src/decimal.c +++ b/source/libs/decimal/src/decimal.c @@ -1920,6 +1920,18 @@ static int32_t decimal128CountRoundingDelta(const Decimal128* pDec, int8_t scale return res; } +bool decimal128AddCheckOverflow(const Decimal128* pLeft, const DecimalType* pRight, uint8_t rightWordNum) { + if (DECIMAL128_SIGN(pLeft) == 0) { + Decimal128 max = decimal128Max; + decimal128Subtract(&max, pLeft, WORD_NUM(Decimal128)); + return decimal128Lt(&max, pRight, rightWordNum); + } else { + Decimal128 min = decimal128Min; + decimal128Subtract(&min, pLeft, WORD_NUM(Decimal128)); + return decimal128Gt(&min, pRight, rightWordNum); + } +} + int32_t TEST_decimal64From_int64_t(Decimal64* pDec, uint8_t prec, uint8_t scale, int64_t v) { return decimal64FromInt64(pDec, prec, scale, v); } diff --git a/source/libs/decimal/test/decimalTest.cpp b/source/libs/decimal/test/decimalTest.cpp index 0e3788c4c3..17f63b1e75 100644 --- a/source/libs/decimal/test/decimalTest.cpp +++ b/source/libs/decimal/test/decimalTest.cpp @@ -1509,6 +1509,17 @@ TEST_F(DecimalTest, decimalFromStr) { Numeric<128> numeric128 = {38, 10, "0"}; } +TEST(decimal, test_add_check_overflow) { + Numeric<128> dec128 = {38, 10, "9999999999999999999999999999.9999999999"}; + Numeric<64> dec64 = {18, 2, "123.12"}; + bool overflow = decimal128AddCheckOverflow((Decimal128*)&dec128.dec(), &dec64.dec(), WORD_NUM(Decimal64)); + ASSERT_TRUE(overflow); + dec128 = {38, 10, "-9999999999999999999999999999.9999999999"}; + ASSERT_FALSE(decimal128AddCheckOverflow((Decimal128*)&dec128.dec(), &dec64.dec(), WORD_NUM(Decimal64))); + dec64 = {18, 2, "-123.1"}; + ASSERT_TRUE(decimal128AddCheckOverflow((Decimal128*)&dec128.dec(), &dec64.dec(), WORD_NUM(Decimal64))); +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); diff --git a/source/libs/function/inc/functionResInfoInt.h b/source/libs/function/inc/functionResInfoInt.h index 4380c73fa6..c3e03326cf 100644 --- a/source/libs/function/inc/functionResInfoInt.h +++ b/source/libs/function/inc/functionResInfoInt.h @@ -87,13 +87,19 @@ typedef struct SDecimalSumRes { #define SUM_RES_GET_DECIMAL_SUM(pSumRes) ((SDecimalSumRes*)(pSumRes))->sum // TODO wjm check for overflow -#define SUM_RES_INC_DECIMAL_SUM(pSumRes, pVal, type) \ - do { \ - const SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL); \ - if (type == TSDB_DATA_TYPE_DECIMAL64) \ - pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pVal, WORD_NUM(Decimal64)); \ - else \ - pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pVal, WORD_NUM(Decimal)); \ +#define SUM_RES_INC_DECIMAL_SUM(pSumRes, pVal, type) \ + do { \ + const SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL); \ + int32_t wordNum = 0; \ + if (type == TSDB_DATA_TYPE_DECIMAL64) { \ + wordNum = WORD_NUM(Decimal64); \ + overflow = decimal128AddCheckOverflow(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pVal, wordNum); \ + } else { \ + wordNum = WORD_NUM(Decimal); \ + overflow = decimal128AddCheckOverflow(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pVal, wordNum); \ + } \ + if (overflow) break; \ + pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pVal, wordNum); \ } while (0) typedef struct SMinmaxResInfo { diff --git a/source/libs/function/src/builtinsimpl.c b/source/libs/function/src/builtinsimpl.c index 008ed542d0..2fc2735962 100644 --- a/source/libs/function/src/builtinsimpl.c +++ b/source/libs/function/src/builtinsimpl.c @@ -106,17 +106,19 @@ typedef enum { } \ } while (0) -#define LIST_ADD_DECIMAL_N(_res, _col, _start, _rows, _t, numOfElem) \ - do { \ - _t* d = (_t*)(_col->pData); \ - const SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL); \ - for (int32_t i = (_start); i < (_rows) + (_start); ++i) { \ - if (((_col)->hasNull) && colDataIsNull_f((_col)->nullbitmap, i)) { \ - continue; \ - }; \ - pOps->add(_res, d + i, WORD_NUM(_t)); \ - (numOfElem)++; \ - } \ +#define LIST_ADD_DECIMAL_N(_res, _col, _start, _rows, _t, numOfElem) \ + do { \ + _t* d = (_t*)(_col->pData); \ + const SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL); \ + for (int32_t i = (_start); i < (_rows) + (_start); ++i) { \ + if (((_col)->hasNull) && colDataIsNull_f((_col)->nullbitmap, i)) { \ + continue; \ + }; \ + overflow = overflow || decimal128AddCheckOverflow((Decimal*)_res, d + i, WORD_NUM(_t)); \ + if (overflow) break; \ + pOps->add(_res, d + i, WORD_NUM(_t)); \ + (numOfElem)++; \ + } \ } while (0) #define LIST_SUB_N(_res, _col, _start, _rows, _t, numOfElem) \ @@ -652,12 +654,12 @@ int32_t sumFunction(SqlFunctionCtx* pCtx) { SUM_RES_INC_DSUM(pSumRes, GET_DOUBLE_VAL((const char*)&(pAgg->sum))); } else if (IS_DECIMAL_TYPE(type)) { SUM_RES_SET_TYPE(pSumRes, pCtx->inputType, TSDB_DATA_TYPE_DECIMAL); - const SDecimalOps* pOps = getDecimalOps(type); - if (TSDB_DATA_TYPE_DECIMAL64 == type) { - pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), &pAgg->sum, WORD_NUM(Decimal64)); - } else if (TSDB_DATA_TYPE_DECIMAL == type) { - pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), &pAgg->decimal128Sum, WORD_NUM(Decimal)); + const SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL); + if (pAgg->overflow || decimal128AddCheckOverflow((Decimal*)&SUM_RES_GET_DECIMAL_SUM(pSumRes), + &pAgg->decimal128Sum, WORD_NUM(Decimal))) { + return TSDB_CODE_DECIMAL_OVERFLOW; } + pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), &pAgg->decimal128Sum, WORD_NUM(Decimal)); } } else { // computing based on the true data block SColumnInfoData* pCol = pInput->pData[0]; @@ -691,12 +693,13 @@ int32_t sumFunction(SqlFunctionCtx* pCtx) { LIST_ADD_N(SUM_RES_GET_DSUM(pSumRes), pCol, start, numOfRows, float, numOfElem); } else if (IS_DECIMAL_TYPE(type)) { SUM_RES_SET_TYPE(pSumRes, pCtx->inputType, TSDB_DATA_TYPE_DECIMAL); + int32_t overflow = false; if (TSDB_DATA_TYPE_DECIMAL64 == type) { LIST_ADD_DECIMAL_N(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pCol, start, numOfRows, Decimal64, numOfElem); } else if (TSDB_DATA_TYPE_DECIMAL == type) { LIST_ADD_DECIMAL_N(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pCol, start, numOfRows, Decimal128, numOfElem); } - // TODO wjm check overflow + if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW; } } diff --git a/source/libs/function/src/detail/tavgfunction.c b/source/libs/function/src/detail/tavgfunction.c index 28ef2931fb..4e38fb7ce6 100644 --- a/source/libs/function/src/detail/tavgfunction.c +++ b/source/libs/function/src/detail/tavgfunction.c @@ -126,7 +126,7 @@ int32_t avgFunctionSetup(SqlFunctionCtx* pCtx, SResultRowEntryInfo* pResultInfo) return TSDB_CODE_SUCCESS; } -static int32_t calculateAvgBySMAInfo(void* pRes, int32_t numOfRows, int32_t type, const SColumnDataAgg* pAgg) { +static int32_t calculateAvgBySMAInfo(void* pRes, int32_t numOfRows, int32_t type, const SColumnDataAgg* pAgg, int32_t* pNumOfElem) { int32_t numOfElem = numOfRows - pAgg->numOfNull; AVG_RES_INC_COUNT(pRes, type, numOfElem); @@ -137,16 +137,17 @@ static int32_t calculateAvgBySMAInfo(void* pRes, int32_t numOfRows, int32_t type } else if (IS_FLOAT_TYPE(type)) { SUM_RES_INC_DSUM(&AVG_RES_GET_SUM(pRes), GET_DOUBLE_VAL((const char*)&(pAgg->sum))); } else if (IS_DECIMAL_TYPE(type)) { - if (type == TSDB_DATA_TYPE_DECIMAL64) - SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pRes), &pAgg->sum, TSDB_DATA_TYPE_DECIMAL64); - else - SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pRes), &pAgg->decimal128Sum, TSDB_DATA_TYPE_DECIMAL); + bool overflow = pAgg->overflow; + if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW; + SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pRes), &pAgg->decimal128Sum, TSDB_DATA_TYPE_DECIMAL); + if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW; } - return numOfElem; + *pNumOfElem = numOfElem; + return 0; } -static int32_t doAddNumericVector(SColumnInfoData* pCol, int32_t type, SInputColumnInfoData *pInput, void* pRes) { +static int32_t doAddNumericVector(SColumnInfoData* pCol, int32_t type, SInputColumnInfoData *pInput, void* pRes, int32_t* pNumOfElem) { int32_t start = pInput->startRowIndex; int32_t numOfRows = pInput->numOfRows; int32_t numOfElems = 0; @@ -306,14 +307,17 @@ static int32_t doAddNumericVector(SColumnInfoData* pCol, int32_t type, SInputCol numOfElems += 1; AVG_RES_INC_COUNT(pRes, type, 1); + bool overflow = false; SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pRes), (const void*)(pDec + i * tDataTypes[type].bytes), type); + if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW; } } break; default: break; } - return numOfElems; + *pNumOfElem = numOfElems; + return 0; } int32_t avgFunction(SqlFunctionCtx* pCtx) { @@ -341,7 +345,8 @@ int32_t avgFunction(SqlFunctionCtx* pCtx) { if (IS_DECIMAL_TYPE(type)) AVG_RES_SET_INPUT_SCALE(pAvgRes, pInput->pData[0]->info.scale); if (pInput->colDataSMAIsSet) { // try to use SMA if available - numOfElem = calculateAvgBySMAInfo(pAvgRes, numOfRows, type, pAgg); + int32_t code = calculateAvgBySMAInfo(pAvgRes, numOfRows, type, pAgg, &numOfElem); + if (code != 0) return code; } else if (!pCol->hasNull) { // try to employ the simd instructions to speed up the loop numOfElem = pInput->numOfRows; AVG_RES_INC_COUNT(pAvgRes, pCtx->inputType, pInput->numOfRows); @@ -424,6 +429,7 @@ int32_t avgFunction(SqlFunctionCtx* pCtx) { const char* pDec = pCol->pData; // TODO wjm check for overflow for (int32_t i = pInput->startRowIndex; i < pInput->numOfRows + pInput->startRowIndex; ++i) { + bool overflow = false; if (type == TSDB_DATA_TYPE_DECIMAL64) { SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pAvgRes), (const void*)(pDec + i * tDataTypes[type].bytes), TSDB_DATA_TYPE_DECIMAL64); @@ -431,13 +437,14 @@ int32_t avgFunction(SqlFunctionCtx* pCtx) { SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pAvgRes), (const void*)(pDec + i * tDataTypes[type].bytes), TSDB_DATA_TYPE_DECIMAL); } + if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW; } } break; default: return TSDB_CODE_FUNC_FUNTION_PARA_TYPE; } } else { - numOfElem = doAddNumericVector(pCol, type, pInput, pAvgRes); + int32_t code = doAddNumericVector(pCol, type, pInput, pAvgRes, &numOfElem); } _over: @@ -446,12 +453,12 @@ _over: return TSDB_CODE_SUCCESS; } -static void avgTransferInfo(SqlFunctionCtx* pCtx, void* pInput, void* pOutput) { +static int32_t avgTransferInfo(SqlFunctionCtx* pCtx, void* pInput, void* pOutput) { int32_t inputDT = pCtx->pExpr->pExpr->_function.pFunctNode->srcFuncInputType.type; int32_t type = AVG_RES_GET_TYPE(pInput, inputDT); pCtx->inputType = type; if (IS_NULL_TYPE(type)) { - return; + return 0; } @@ -464,12 +471,15 @@ static void avgTransferInfo(SqlFunctionCtx* pCtx, void* pInput, void* pOutput) { CHECK_OVERFLOW_SUM_UNSIGNED_BIG(pOutput, (overflow ? SUM_RES_GET_DSUM(&AVG_RES_GET_SUM(pInput)) : SUM_RES_GET_USUM(&AVG_RES_GET_SUM(pInput))), overflow); } else if (IS_DECIMAL_TYPE(type)) { AVG_RES_SET_INPUT_SCALE(pOutput, AVG_RES_GET_INPUT_SCALE(pInput)); + bool overflow = false; SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pOutput), &AVG_RES_GET_DECIMAL_SUM(pInput), TSDB_DATA_TYPE_DECIMAL); + if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW; } else { SUM_RES_INC_DSUM(&AVG_RES_GET_SUM(pOutput), SUM_RES_GET_DSUM(&AVG_RES_GET_SUM(pInput))); } AVG_RES_INC_COUNT(pOutput, type, AVG_RES_GET_COUNT(pInput, true, type)); + return 0; } int32_t avgFunctionMerge(SqlFunctionCtx* pCtx) { @@ -493,7 +503,8 @@ int32_t avgFunctionMerge(SqlFunctionCtx* pCtx) { if(colDataIsNull_s(pCol, i)) continue; char* data = colDataGetData(pCol, i); void* pInputInfo = varDataVal(data); - avgTransferInfo(pCtx, pInputInfo, pInfo); + int32_t code = avgTransferInfo(pCtx, pInputInfo, pInfo); + if (code != 0) return code; } SET_VAL(GET_RES_INFO(pCtx), 1, 1); diff --git a/source/libs/function/src/detail/tminmax.c b/source/libs/function/src/detail/tminmax.c index c79c6b7553..74244c231d 100644 --- a/source/libs/function/src/detail/tminmax.c +++ b/source/libs/function/src/detail/tminmax.c @@ -619,7 +619,7 @@ int32_t doMinMaxHelper(SqlFunctionCtx* pCtx, int32_t isMinFunc, int32_t* nElems) int16_t index = 0; void* tval = NULL; - if (type == TSDB_DATA_TYPE_DECIMAL) { + if (IS_DECIMAL_TYPE(type)) { tval = isMinFunc ? pInput->pColumnDataAgg[0]->decimal128Min : pInput->pColumnDataAgg[0]->decimal128Max; } else { tval = (isMinFunc) ? &pInput->pColumnDataAgg[0]->min : &pInput->pColumnDataAgg[0]->max; diff --git a/source/libs/scalar/src/filter.c b/source/libs/scalar/src/filter.c index 317ee9131d..8825522021 100644 --- a/source/libs/scalar/src/filter.c +++ b/source/libs/scalar/src/filter.c @@ -4228,7 +4228,7 @@ static int32_t fltSclBuildDecimalDatumFromValueNode(SFltSclDatum* datum, SColumn if (datum->kind == FLT_SCL_DATUM_KIND_DECIMAL64 || datum->kind == FLT_SCL_DATUM_KIND_DECIMAL) { int32_t code = convertToDecimal(pInput, &valDt, pData, &datum->type); if (TSDB_CODE_SUCCESS != code) return code; // TODO wjm handle overflow error - valNode->node.resType = datum->type; + //valNode->node.resType = datum->type; } } FLT_RET(0); diff --git a/tests/system-test/2-query/decimal.py b/tests/system-test/2-query/decimal.py index 272140d002..fe0282f770 100644 --- a/tests/system-test/2-query/decimal.py +++ b/tests/system-test/2-query/decimal.py @@ -7,8 +7,10 @@ import time import threading import secrets import numpy +from paramiko import Agent import query +from tag_lite import datatype from util.log import * from util.sql import * from util.cases import * @@ -43,12 +45,13 @@ scalar_convert_err = -2147470768 decimal_insert_validator_test = False -operator_test_round = 10 +operator_test_round = 1 tb_insert_rows = 1000 binary_op_with_const_test = True binary_op_with_col_test = True unary_op_test = True binary_op_in_where_test = True +test_decimal_funcs = True class DecimalTypeGeneratorConfig: def __init__(self): @@ -156,14 +159,14 @@ class DecimalColumnAggregator: self.null_num: int = 0 self.none_num: int = 0 - def add_value(self, value: str): + def add_value(self, value: str, scale: int): self.count += 1 if value == "NULL": self.null_num += 1 elif value == "None": self.none_num += 1 else: - v: Decimal = Decimal(value) + v: Decimal = get_decimal(value, scale) self.sum += v if v > self.max: self.max = v @@ -245,6 +248,10 @@ class DecimalColumnExpr: def should_skip_for_decimal(self, cols: list): return False + def check_query_results(self, query_col_res: List, tbname: str): + query_len = len(query_col_res) + pass + def check_for_filtering(self, query_col_res: List, tbname: str): j: int = -1 for i in range(len(query_col_res)): @@ -479,6 +486,7 @@ class DataType: return val class DecimalType(DataType): + MAX_PRECISION = 38 def __init__(self, type, precision: int, scale: int): self.precision_ = precision self.scale_ = scale @@ -496,6 +504,14 @@ class DecimalType(DataType): def get_decimal_type_mod(self) -> int: return self.precision_ * 100 + self.scale() + + def set_prec(self, prec: int): + self.precision_ = prec + self.type_mod = self.get_decimal_type_mod() + + def set_scale(self, scale: int): + self.scale_ = scale + self.type_mod = self.get_decimal_type_mod() def prec(self): return self.precision_ @@ -520,7 +536,7 @@ class DecimalType(DataType): def generate_value(self) -> str: val = self.decimal_generator.generate(self.generator_config) - self.aggregator.add_value(val) ## convert to Decimal first + self.aggregator.add_value(val, self.scale()) ## convert to Decimal first # self.values.append(val) ## save it into files maybe return val @@ -605,6 +621,12 @@ class Column: if self.is_constant_col(): return self.get_constant_val_for_execute() return self.get_typed_val_for_execute(self.saved_vals[tbname][idx]) + + def get_cardinality(self, tbname): + if self.is_constant_col(): + return 1 + else: + return len(self.saved_vals[tbname]) ## tbName: for normal table, pass the tbname, for child table, pass the child table name def generate_value(self, tbName: str = '', save: bool = True): @@ -824,6 +846,143 @@ class TableDataValidator: col.check(res[colIdx], row_num * self.tbIdx) colIdx += 1 +class DecimalFunction(DecimalColumnExpr): + def __init__(self, format, executor, name: str): + super().__init__(format, executor) + self.func_name_ = name + + def is_agg_func(self, op: str): + return False + + def get_func_res(self): + return None + + def get_input_types(self): + return [self.query_col] + + @staticmethod + def get_decimal_agg_funcs() -> List: + return [DecimalMinFunction(), DecimalMaxFunction(), DecimalSumFunction(), DecimalAvgFunction()] + + def check_results(self, query_col_res: List) -> bool: + return False + + def check_for_agg_func(self, query_col_res: List, tbname: str): + col_expr = self.query_col + for i in range(col_expr.get_cardinality(tbname)): + col_val = col_expr.get_val_for_execute(tbname, i) + self.execute((col_val,)) + if not self.check_results(query_col_res): + tdLog.exit(f"check failed for {self}, query got: {query_col_res}, expect {self.get_func_res()}") + +class DecimalAggFunction(DecimalFunction): + def __init__(self, format, executor, name: str): + super().__init__(format, executor, name) + + def is_agg_func(self, op): + return True + + def should_skip_for_decimal(self, cols: list): + col: Column = cols[0] + if col.type_.is_decimal_type(): + return False + return True + + def check_results(self, query_col_res): + if len(query_col_res) == 0: + tdLog.info(f"query got no output: {self}, py calc: {self.get_func_res()}") + return True + else: + return self.get_func_res() == Decimal(query_col_res[0]) + +class DecimalMinFunction(DecimalAggFunction): + def __init__(self): + super().__init__("min({0})", DecimalMinFunction.execute_min, "min") + self.min_: Decimal = None + + def get_func_res(self) -> Decimal: + decimal_type: DecimalType = self.query_col.type_ + return decimal_type.aggregator.min + return self.min_ + + def generate_res_type(self) -> DataType: + self.res_type_ = self.query_col.type_ + + def execute_min(self, params): + if params[0] is None: + return + if self.min_ is None: + self.min_ = Decimal(params[0]) + else: + self.min_ = min(self.min_, Decimal(params[0])) + return self.min_ + +class DecimalMaxFunction(DecimalAggFunction): + def __init__(self): + super().__init__("max({0})", DecimalMaxFunction.execute_max, "max") + self.max_: Decimal = None + + def get_func_res(self) -> Decimal: + return self.max_ + + def generate_res_type(self) -> DataType: + self.res_type_ = self.query_col.type_ + + def execute_max(self, params): + if params[0] is None: + return + if self.max_ is None: + self.max_ = Decimal(params[0]) + else: + self.max_ = max(self.max_, Decimal(params[0])) + return self.max_ + +class DecimalSumFunction(DecimalAggFunction): + def __init__(self): + super().__init__("sum({0})", DecimalSumFunction.execute_sum, "sum") + self.sum_:Decimal = None + def get_func_res(self) -> Decimal: + decimal_type: DecimalType = self.query_col.type_ + return decimal_type.aggregator.sum + return self.sum_ + def generate_res_type(self) -> DataType: + self.res_type_ = self.query_col.type_ + self.res_type_.set_prec(DecimalType.MAX_PRECISION) + def execute_sum(self, params): + if params[0] is None: + return + if self.sum_ is None: + self.sum_ = Decimal(params[0]) + else: + self.sum_ += Decimal(params[0]) + return self.sum_ + +class DecimalAvgFunction(DecimalAggFunction): + def __init__(self): + super().__init__("avg({0})", DecimalAvgFunction.execute_avg, "avg") + self.count_: Decimal = 0 + self.sum_: Decimal = None + def get_func_res(self) -> Decimal: + decimal_type: DecimalType = self.query_col.type_ + return get_decimal( + decimal_type.aggregator.sum + / (decimal_type.aggregator.count - decimal_type.aggregator.null_num), + self.res_type_.scale(), + ) + def generate_res_type(self) -> DataType: + sum_type = self.query_col.type_ + sum_type.set_prec(DecimalType.MAX_PRECISION) + count_type = DataType(TypeEnum.BIGINT, 8, 0) + self.res_type_ = DecimalBinaryOperator.calc_decimal_prec_scale(sum_type, count_type, "/") + def execute_avg(self, params): + if params[0] is None: + return + if self.sum_ is None: + self.sum_ = Decimal(params[0]) + else: + self.sum_ += Decimal(params[0]) + self.count_ += 1 + return self.get_func_res() class DecimalBinaryOperator(DecimalColumnExpr): def __init__(self, format, executor, op: str): @@ -845,7 +1004,7 @@ class DecimalBinaryOperator(DecimalColumnExpr): return True if self.op_ != "%": return False - ## why skip decimal % float/double? it's wrong now. + ## TODO wjm why skip decimal % float/double? it's wrong now. left_is_real = left_col.type_.is_real_type() or left_col.type_.is_varchar_type() right_is_real = right_col.type_.is_real_type() or right_col.type_.is_varchar_type() if left_is_real or right_is_real: @@ -1554,7 +1713,7 @@ class TDTestCase: threads.append(t) for t in threads: t.join() - + def run_in_thread2(self, func, params) -> threading.Thread: t = threading.Thread(target=func, args=params) t.start() @@ -1638,10 +1797,6 @@ class TDTestCase: self.test_query_decimal_where_clause() - def test_decimal_functions(self): - self.test_decimal_last_first_func() - funcs = ["max", "min", "sum", "avg", "count", "first", "last", "cast"] - def test_decimal_last_first_func(self): pass @@ -1663,9 +1818,7 @@ class TDTestCase: for col in tb_cols: if col.name_ == '': continue - left_is_decimal = col.type_.is_decimal_type() for const_col in constant_cols: - right_is_decimal = const_col.type_.is_decimal_type() if expr.should_skip_for_decimal([col, const_col]): continue const_col.generate_value() @@ -1740,7 +1893,7 @@ class TDTestCase: self.norm_table_name, self.norm_tb_columns, binary_compare_ops) - + ## TODO wjm ## 3. (dec op const col) op const col ## 4. (dec op dec) op const col @@ -1765,6 +1918,32 @@ class TDTestCase: def test_query_decimal_case_when(self): pass + def test_decimal_agg_funcs(self, dbname, tbname, tb_cols: List[Column], get_agg_funcs_func): + agg_funcs: List[DecimalFunction] = get_agg_funcs_func() + for func in agg_funcs: + for col in tb_cols: + if col.name_ == '' or func.should_skip_for_decimal([col]): + continue + func.query_col = col + select_expr = func.generate([col]) + sql = f"select {select_expr} from {dbname}.{tbname}" + res = TaosShell().query(sql) + if len(res) > 0: + res = res[0] + func.check_for_agg_func(res, tbname) + + def test_decimal_functions(self): + if not test_decimal_funcs: + return + self.test_decimal_agg_funcs( + self.db_name, + self.norm_table_name, + self.norm_tb_columns, + DecimalFunction.get_decimal_agg_funcs, + ) + + self.test_decimal_last_first_func() + def test_query_decimal(self): self.test_decimal_operators() self.test_decimal_functions()