From 9b4414c12eb02006a6fac4af6524494ede744eae Mon Sep 17 00:00:00 2001 From: wangjiaming0909 Date: Fri, 7 Mar 2025 18:33:16 +0800 Subject: [PATCH] test decimal functions --- include/libs/decimal/decimal.h | 2 +- source/libs/command/src/command.c | 2 + source/libs/decimal/src/decimal.c | 29 ++- source/libs/function/src/builtins.c | 11 +- source/libs/nodes/src/nodesUtilFuncs.c | 2 + source/libs/parser/src/parTranslater.c | 3 +- source/libs/scalar/inc/sclInt.h | 1 + source/libs/scalar/src/filter.c | 121 ++++++---- source/libs/scalar/src/scalar.c | 2 +- source/libs/scalar/src/sclfunc.c | 22 +- tests/system-test/2-query/decimal.py | 299 ++++++++++++++++++++++--- 11 files changed, 398 insertions(+), 96 deletions(-) diff --git a/include/libs/decimal/decimal.h b/include/libs/decimal/decimal.h index 5fec585cdb..d13bc56a5f 100644 --- a/include/libs/decimal/decimal.h +++ b/include/libs/decimal/decimal.h @@ -67,7 +67,7 @@ static const Decimal128 decimal128Min = DEFINE_DECIMAL128(17759344522308878337UL #define DECIMAL128_CLONE(pDst, pFrom) makeDecimal128(pDst, DECIMAL128_HIGH_WORD(pFrom), DECIMAL128_LOW_WORD(pFrom)) typedef struct SDecimalCompareCtx { - void* pData; + const void* pData; int8_t type; STypeMod typeMod; } SDecimalCompareCtx; diff --git a/source/libs/command/src/command.c b/source/libs/command/src/command.c index e63457c150..e3002766da 100644 --- a/source/libs/command/src/command.c +++ b/source/libs/command/src/command.c @@ -1040,6 +1040,8 @@ static int32_t createSelectResultDataBlock(SNodeList* pProjects, SSDataBlock** p } else { infoData.info.type = pExpr->resType.type; infoData.info.bytes = pExpr->resType.bytes; + infoData.info.precision = pExpr->resType.precision; + infoData.info.scale = pExpr->resType.scale; } QRY_ERR_RET(blockDataAppendColInfo(pBlock, &infoData)); } diff --git a/source/libs/decimal/src/decimal.c b/source/libs/decimal/src/decimal.c index 887f93446d..136b8e4855 100644 --- a/source/libs/decimal/src/decimal.c +++ b/source/libs/decimal/src/decimal.c @@ -1190,12 +1190,30 @@ bool decimal64Compare(EOperatorType op, const SDecimalCompareCtx* pLeft, const S return ret; } -// There is no need to do type conversions, we assume that pLeftT and pRightT are all decimal128 types. bool decimalCompare(EOperatorType op, const SDecimalCompareCtx* pLeft, const SDecimalCompareCtx* pRight) { + if (pLeft->type == TSDB_DATA_TYPE_DECIMAL64 && pRight->type == TSDB_DATA_TYPE_DECIMAL64) { + return decimal64Compare(op, pLeft, pRight); + } bool ret = false; uint8_t leftPrec = 0, leftScale = 0, rightPrec = 0, rightScale = 0; decimalFromTypeMod(pLeft->typeMod, &leftPrec, &leftScale); decimalFromTypeMod(pRight->typeMod, &rightPrec, &rightScale); + + if (pLeft->type == TSDB_DATA_TYPE_DECIMAL64) { + Decimal128 dec128 = {0}; + makeDecimal128FromDecimal64(&dec128, *(Decimal64*)pLeft->pData); + SDecimalCompareCtx leftCtx = {.pData = &dec128, + .type = TSDB_DATA_TYPE_DECIMAL, + .typeMod = decimalCalcTypeMod(TSDB_DECIMAL128_MAX_PRECISION, leftScale)}; + return decimalCompare(op, &leftCtx, pRight); + } else if (pRight->type == TSDB_DATA_TYPE_DECIMAL64) { + Decimal128 dec128 = {0}; + makeDecimal128FromDecimal64(&dec128, *(Decimal64*)pRight->pData); + SDecimalCompareCtx rightCtx = {.pData = &dec128, + .type = TSDB_DATA_TYPE_DECIMAL, + .typeMod = decimalCalcTypeMod(TSDB_DECIMAL128_MAX_PRECISION, rightScale)}; + return decimalCompare(op, pLeft, &rightCtx); + } int32_t deltaScale = leftScale - rightScale; Decimal pLeftDec = *(Decimal*)pLeft->pData, pRightDec = *(Decimal*)pRight->pData; @@ -1297,6 +1315,9 @@ static int32_t decimal64FromDouble(DecimalType* pDec, uint8_t prec, uint8_t scal uint64_t result = (uint64_t)abs; makeDecimal64(pDec, result); + Decimal64 max = {0}; + DECIMAL64_GET_MAX(prec, &max); + if (decimal64Gt(pDec, &max, WORD_NUM(Decimal64))) goto _OVERFLOW; if (negative) decimal64Negate(pDec); return 0; @@ -1353,7 +1374,7 @@ static int64_t int64FromDecimal128(const DecimalType* pDec, uint8_t prec, uint8_ decimal128FromInt64(&min, TSDB_DECIMAL128_MAX_PRECISION, 0, INT64_MIN); if (decimal128Gt(&rounded, &max, WORD_NUM(Decimal128)) || decimal128Lt(&rounded, &min, WORD_NUM(Decimal128))) { overflow = true; - return 0; + return (int64_t)DECIMAL128_LOW_WORD(&rounded); } return (int64_t)DECIMAL128_LOW_WORD(&rounded); @@ -1370,7 +1391,7 @@ static uint64_t uint64FromDecimal128(const DecimalType* pDec, uint8_t prec, uint if (decimal128Gt(&rounded, &max, WORD_NUM(Decimal128)) || decimal128Lt(&rounded, &decimal128Zero, WORD_NUM(Decimal128))) { overflow = true; - return 0; + return DECIMAL128_LOW_WORD(&rounded); } return DECIMAL128_LOW_WORD(&rounded); } @@ -1437,7 +1458,7 @@ static int32_t decimal128FromDecimal64(DecimalType* pDec, uint8_t prec, uint8_t makeDecimal128(pDec, 0, DECIMAL64_GET_VALUE(&dec64)); Decimal128 max = {0}; DECIMAL128_GET_MAX(prec - scale, &max); - decimal64ScaleTo(&dec64, valScale, 0); + decimal64ScaleDown(&dec64, valScale, false); if (decimal128Lt(&max, &dec64, WORD_NUM(Decimal64))) { return TSDB_CODE_DECIMAL_OVERFLOW; } diff --git a/source/libs/function/src/builtins.c b/source/libs/function/src/builtins.c index 18ca0d2f3a..aa919623a1 100644 --- a/source/libs/function/src/builtins.c +++ b/source/libs/function/src/builtins.c @@ -412,7 +412,6 @@ static bool paramSupportGeometry(uint64_t typeFlag) { static bool paramSupportDecimal(uint64_t typeFlag) { return FUNC_MGT_TEST_MASK(typeFlag, FUNC_PARAM_SUPPORT_DECIMAL_TYPE) || - FUNC_MGT_TEST_MASK(typeFlag, FUNC_PARAM_SUPPORT_NUMERIC_TYPE) || FUNC_MGT_TEST_MASK(typeFlag, FUNC_PARAM_SUPPORT_ALL_TYPE); } @@ -1892,7 +1891,7 @@ const SBuiltinFuncDefinition funcMgtBuiltins[] = { .inputParaInfo[0][0] = {.isLastParam = true, .startParam = 1, .endParam = 1, - .validDataType = FUNC_PARAM_SUPPORT_NUMERIC_TYPE | FUNC_PARAM_SUPPORT_NULL_TYPE, + .validDataType = FUNC_PARAM_SUPPORT_NUMERIC_TYPE | FUNC_PARAM_SUPPORT_NULL_TYPE | FUNC_PARAM_SUPPORT_DECIMAL_TYPE, .validNodeType = FUNC_PARAM_SUPPORT_EXPR_NODE, .paramAttribute = FUNC_PARAM_NO_SPECIFIC_ATTRIBUTE, .valueRangeFlag = FUNC_PARAM_NO_SPECIFIC_VALUE,}, @@ -1922,7 +1921,7 @@ const SBuiltinFuncDefinition funcMgtBuiltins[] = { .inputParaInfo[0][0] = {.isLastParam = true, .startParam = 1, .endParam = 1, - .validDataType = FUNC_PARAM_SUPPORT_NUMERIC_TYPE | FUNC_PARAM_SUPPORT_STRING_TYPE | FUNC_PARAM_SUPPORT_NULL_TYPE, + .validDataType = FUNC_PARAM_SUPPORT_NUMERIC_TYPE | FUNC_PARAM_SUPPORT_STRING_TYPE | FUNC_PARAM_SUPPORT_NULL_TYPE | FUNC_PARAM_SUPPORT_DECIMAL_TYPE, .validNodeType = FUNC_PARAM_SUPPORT_EXPR_NODE, .paramAttribute = FUNC_PARAM_NO_SPECIFIC_ATTRIBUTE, .valueRangeFlag = FUNC_PARAM_NO_SPECIFIC_VALUE,}, @@ -1949,7 +1948,7 @@ const SBuiltinFuncDefinition funcMgtBuiltins[] = { .inputParaInfo[0][0] = {.isLastParam = true, .startParam = 1, .endParam = 1, - .validDataType = FUNC_PARAM_SUPPORT_NUMERIC_TYPE | FUNC_PARAM_SUPPORT_STRING_TYPE | FUNC_PARAM_SUPPORT_NULL_TYPE, + .validDataType = FUNC_PARAM_SUPPORT_NUMERIC_TYPE | FUNC_PARAM_SUPPORT_STRING_TYPE | FUNC_PARAM_SUPPORT_NULL_TYPE | FUNC_PARAM_SUPPORT_DECIMAL_TYPE, .validNodeType = FUNC_PARAM_SUPPORT_EXPR_NODE, .paramAttribute = FUNC_PARAM_NO_SPECIFIC_ATTRIBUTE, .valueRangeFlag = FUNC_PARAM_NO_SPECIFIC_VALUE,}, @@ -2090,7 +2089,7 @@ const SBuiltinFuncDefinition funcMgtBuiltins[] = { .inputParaInfo[0][0] = {.isLastParam = true, .startParam = 1, .endParam = 1, - .validDataType = FUNC_PARAM_SUPPORT_NUMERIC_TYPE | FUNC_PARAM_SUPPORT_NULL_TYPE, + .validDataType = FUNC_PARAM_SUPPORT_NUMERIC_TYPE | FUNC_PARAM_SUPPORT_NULL_TYPE | FUNC_PARAM_SUPPORT_DECIMAL_TYPE, .validNodeType = FUNC_PARAM_SUPPORT_EXPR_NODE, .paramAttribute = FUNC_PARAM_NO_SPECIFIC_ATTRIBUTE, .valueRangeFlag = FUNC_PARAM_NO_SPECIFIC_VALUE,}, @@ -2121,7 +2120,7 @@ const SBuiltinFuncDefinition funcMgtBuiltins[] = { .inputParaInfo[0][0] = {.isLastParam = true, .startParam = 1, .endParam = 1, - .validDataType = FUNC_PARAM_SUPPORT_NUMERIC_TYPE | FUNC_PARAM_SUPPORT_NULL_TYPE, + .validDataType = FUNC_PARAM_SUPPORT_NUMERIC_TYPE | FUNC_PARAM_SUPPORT_NULL_TYPE | FUNC_PARAM_SUPPORT_DECIMAL_TYPE, .validNodeType = FUNC_PARAM_SUPPORT_EXPR_NODE, .paramAttribute = FUNC_PARAM_NO_SPECIFIC_ATTRIBUTE, .valueRangeFlag = FUNC_PARAM_NO_SPECIFIC_VALUE,}, diff --git a/source/libs/nodes/src/nodesUtilFuncs.c b/source/libs/nodes/src/nodesUtilFuncs.c index 46529f00ae..d45d3edd5e 100644 --- a/source/libs/nodes/src/nodesUtilFuncs.c +++ b/source/libs/nodes/src/nodesUtilFuncs.c @@ -2308,12 +2308,14 @@ void* nodesGetValueFromNode(SValueNode* pNode) { case TSDB_DATA_TYPE_UBIGINT: case TSDB_DATA_TYPE_FLOAT: case TSDB_DATA_TYPE_DOUBLE: + case TSDB_DATA_TYPE_DECIMAL64: return (void*)&pNode->typeData; case TSDB_DATA_TYPE_NCHAR: case TSDB_DATA_TYPE_VARCHAR: case TSDB_DATA_TYPE_VARBINARY: case TSDB_DATA_TYPE_JSON: case TSDB_DATA_TYPE_GEOMETRY: + case TSDB_DATA_TYPE_DECIMAL: return (void*)pNode->datum.p; default: break; diff --git a/source/libs/parser/src/parTranslater.c b/source/libs/parser/src/parTranslater.c index 57fa0953b1..95d1398e7e 100755 --- a/source/libs/parser/src/parTranslater.c +++ b/source/libs/parser/src/parTranslater.c @@ -2517,7 +2517,8 @@ static bool hasInvalidFuncNesting(SFunctionNode* pFunc) { static int32_t getFuncInfo(STranslateContext* pCxt, SFunctionNode* pFunc) { // the time precision of the function execution environment pFunc->dual = pCxt->dual; - pFunc->node.resType.precision = getPrecisionFromCurrStmt(pCxt->pCurrStmt, TSDB_TIME_PRECISION_MILLI); + if (!IS_DECIMAL_TYPE(pFunc->node.resType.type)) + pFunc->node.resType.precision = getPrecisionFromCurrStmt(pCxt->pCurrStmt, TSDB_TIME_PRECISION_MILLI); int32_t code = fmGetFuncInfo(pFunc, pCxt->msgBuf.buf, pCxt->msgBuf.len); if (TSDB_CODE_FUNC_NOT_BUILTIN_FUNTION == code) { code = getUdfInfo(pCxt, pFunc); diff --git a/source/libs/scalar/inc/sclInt.h b/source/libs/scalar/inc/sclInt.h index 6e5e4fb6a3..457880c172 100644 --- a/source/libs/scalar/inc/sclInt.h +++ b/source/libs/scalar/inc/sclInt.h @@ -142,6 +142,7 @@ int32_t sclConvertToTsValueNode(int8_t precision, SValueNode* valueNode); #define GET_PARAM_TYPE(_c) ((_c)->columnData ? (_c)->columnData->info.type : (_c)->filterValueType) #define GET_PARAM_BYTES(_c) ((_c)->columnData->info.bytes) #define GET_PARAM_PRECISON(_c) ((_c)->columnData->info.precision) +#define GET_PARAM_SCALE(_c) ((_c)->columnData->info.scale) void sclFreeParam(SScalarParam* param); int32_t doVectorCompare(SScalarParam* pLeft, SScalarParam *pLeftVar, SScalarParam* pRight, SScalarParam *pOut, int32_t startIndex, int32_t numOfRows, diff --git a/source/libs/scalar/src/filter.c b/source/libs/scalar/src/filter.c index 8825522021..34c6bde332 100644 --- a/source/libs/scalar/src/filter.c +++ b/source/libs/scalar/src/filter.c @@ -3955,6 +3955,14 @@ int32_t fltSclCompareWithFloat64(SFltSclDatum *val1, SFltSclDatum *val2) { case FLT_SCL_DATUM_KIND_FLOAT64: { return compareDoubleVal(&val1->d, &val2->d); } + case FLT_SCL_DATUM_KIND_DECIMAL64: { + double d = doubleFromDecimal64(&val1->i, val1->type.precision, val1->type.scale); + return compareDoubleVal(&d, &val2->d); + } + case FLT_SCL_DATUM_KIND_DECIMAL: { + double d = doubleFromDecimal128(val1->pData, val1->type.precision, val1->type.scale); + return compareDoubleVal(&d, &val2->d); + } // TODO: varchar, nchar default: qError("not supported comparsion. kind1 %d, kind2 %d", val1->kind, val2->kind); @@ -3996,12 +4004,25 @@ int32_t fltSclCompareWithUInt64(SFltSclDatum *val1, SFltSclDatum *val2) { } } +int32_t fltSclCompareWithDecimal(void* pData1, const SDataType* pType1, void* pData2, const SDataType* pType2) { + SDecimalCompareCtx ctx1 = {.pData = pData1, .type =pType1->type, .typeMod = decimalCalcTypeMod(pType1->precision, pType1->scale)}, + ctx2 = {.pData = pData2, .type = pType2->type, .typeMod = decimalCalcTypeMod(pType2->precision, pType2->scale)}; + if (decimalCompare(OP_TYPE_GREATER_THAN, &ctx1, &ctx2)) return 1; + if (decimalCompare(OP_TYPE_EQUAL, &ctx1, &ctx2)) return 0; + return -1; +} + int32_t fltSclCompareDatum(SFltSclDatum *val1, SFltSclDatum *val2) { if (val2->kind == FLT_SCL_DATUM_KIND_NULL || val2->kind == FLT_SCL_DATUM_KIND_MIN || val2->kind == FLT_SCL_DATUM_KIND_MAX) { return (val1->kind < val2->kind) ? -1 : ((val1->kind > val2->kind) ? 1 : 0); } + if (val1->kind == FLT_SCL_DATUM_KIND_NULL || val1->kind == FLT_SCL_DATUM_KIND_MIN || + val1->kind == FLT_SCL_DATUM_KIND_MAX) { + return (val1->kind < val2->kind) ? -1 : ((val1->kind > val2->kind) ? 1 : 0); + } + switch (val2->kind) { case FLT_SCL_DATUM_KIND_UINT64: { return fltSclCompareWithUInt64(val1, val2); @@ -4013,20 +4034,13 @@ int32_t fltSclCompareDatum(SFltSclDatum *val1, SFltSclDatum *val2) { return fltSclCompareWithFloat64(val1, val2); } case FLT_SCL_DATUM_KIND_DECIMAL64: { - if (val1->kind == FLT_SCL_DATUM_KIND_NULL || val1->kind == FLT_SCL_DATUM_KIND_MIN || - val1->kind == FLT_SCL_DATUM_KIND_MAX) { - return (val1->kind < val2->kind) ? -1 : ((val1->kind > val2->kind) ? 1 : 0); - } - return compareDecimal64SameScale(&val1->i, &val2->i); + void* pData1 = val1->kind == FLT_SCL_DATUM_KIND_DECIMAL64 ? (void*)&val1->i : (void*)val1->pData; + return fltSclCompareWithDecimal(pData1, &val1->type, &val2->i, &val2->type); } case FLT_SCL_DATUM_KIND_DECIMAL: { - if (val1->kind == FLT_SCL_DATUM_KIND_NULL || val1->kind == FLT_SCL_DATUM_KIND_MIN || - val1->kind == FLT_SCL_DATUM_KIND_MAX) { - return (val1->kind < val2->kind) ? -1 : ((val1->kind > val2->kind) ? 1 : 0); - } - return compareDecimal128SameScale(val1->pData, val2->pData); + void* pData1 = val1->kind == FLT_SCL_DATUM_KIND_DECIMAL64 ? (void*)&val1->i : (void*)val1->pData; + return fltSclCompareWithDecimal(pData1, &val1->type, val2->pData, &val2->type); } - // TODO: varchar/nchar default: qError("not supported kind when compare datum. kind2 : %d", val2->kind); return 0; @@ -4203,13 +4217,26 @@ static int32_t fltSclBuildDecimalDatumFromValueNode(SFltSclDatum* datum, SColumn break; case TSDB_DATA_TYPE_FLOAT: case TSDB_DATA_TYPE_DOUBLE: - pInput = &valNode->datum.d; - valDt.type = TSDB_DATA_TYPE_DOUBLE; - break; + datum->kind = FLT_SCL_DATUM_KIND_FLOAT64; + datum->type = valDt; + datum->d = valNode->datum.d; + FLT_RET(0); case TSDB_DATA_TYPE_VARCHAR: - pInput = valNode->literal; - break; - // TODO wjm test cast to decimal + datum->kind = FLT_SCL_DATUM_KIND_FLOAT64; + datum->type.type = TSDB_DATA_TYPE_DOUBLE; + datum->type.bytes = DOUBLE_BYTES; + datum->d = taosStr2Double(valNode->literal, NULL); + FLT_RET(0); + case TSDB_DATA_TYPE_DECIMAL64: + datum->kind = FLT_SCL_DATUM_KIND_DECIMAL64; + datum->type = valDt; + datum->i = valNode->datum.i; + FLT_RET(0); + case TSDB_DATA_TYPE_DECIMAL: + datum->kind = FLT_SCL_DATUM_KIND_DECIMAL; + datum->type = valDt; + datum->pData = (void*)valNode->datum.p; + FLT_RET(0); default: qError("not supported type %d when build decimal datum from value node", valNode->node.resType.type); return TSDB_CODE_INVALID_PARA; @@ -4217,7 +4244,7 @@ static int32_t fltSclBuildDecimalDatumFromValueNode(SFltSclDatum* datum, SColumn void *pData = NULL; if (datum->type.type == TSDB_DATA_TYPE_DECIMAL64) { - pData = &datum->i; // TODO wjm set kind + pData = &datum->i; datum->kind = FLT_SCL_DATUM_KIND_DECIMAL64; } else if (datum->type.type == TSDB_DATA_TYPE_DECIMAL) { pData = taosMemoryCalloc(1, pColNode->node.resType.bytes); @@ -4313,7 +4340,7 @@ int32_t fltSclBuildDatumFromBlockSmaValue(SFltSclDatum *datum, uint8_t type, voi break; } case TSDB_DATA_TYPE_DECIMAL64: - datum->kind = FLT_SCL_DATUM_KIND_DECIMAL; + datum->kind = FLT_SCL_DATUM_KIND_DECIMAL64; datum->u = *(uint64_t *)val; break; case TSDB_DATA_TYPE_DECIMAL: @@ -5020,35 +5047,39 @@ static int32_t fltSclBuildRangePointsForInOper(SFltSclOperator* oper, SArray* po SValueNode *valueNode = (SValueNode *)nodeItem; SFltSclDatum valDatum; FLT_ERR_RET(fltSclBuildDatumFromValueNode(&valDatum, oper->colNode, valueNode)); + if (valDatum.kind == FLT_SCL_DATUM_KIND_NULL) { + continue; + } + if (IS_DECIMAL_TYPE(oper->colNode->node.resType.type)) { + if (IS_DECIMAL_TYPE(valDatum.type.type)) { + double v = valDatum.type.type == TSDB_DATA_TYPE_DECIMAL64 + ? doubleFromDecimal64(&valDatum.i, valDatum.type.precision, valDatum.type.scale) + : doubleFromDecimal128(valDatum.pData, valDatum.type.precision, valDatum.type.scale); + if (minDatum.kind == FLT_SCL_DATUM_KIND_FLOAT64) { + minDatum.d = TMIN(v, minDatum.d); + maxDatum.d = TMAX(v, maxDatum.d); + } else if (minDatum.kind == FLT_SCL_DATUM_KIND_INT64) { + minDatum.d = v; + maxDatum.d = v; + minDatum.kind = FLT_SCL_DATUM_KIND_FLOAT64; + maxDatum.kind = FLT_SCL_DATUM_KIND_FLOAT64; + } + } else if (valDatum.kind == FLT_SCL_DATUM_KIND_FLOAT64) { + if (minDatum.kind == FLT_SCL_DATUM_KIND_INT64) { + minDatum.kind = FLT_SCL_DATUM_KIND_FLOAT64; + maxDatum.kind = FLT_SCL_DATUM_KIND_FLOAT64; + minDatum.d = TMIN(valDatum.d, minDatum.d); + maxDatum.d = TMAX(valDatum.d, maxDatum.d); + } else { + minDatum.d = TMIN(valDatum.d, minDatum.d); + maxDatum.d = TMAX(valDatum.d, maxDatum.d); + } + } + continue; + } if(valueNode->node.resType.type == TSDB_DATA_TYPE_FLOAT || valueNode->node.resType.type == TSDB_DATA_TYPE_DOUBLE) { minDatum.i = TMIN(minDatum.i, valDatum.d); maxDatum.i = TMAX(maxDatum.i, valDatum.d); - } else if (IS_DECIMAL_TYPE(valueNode->node.resType.type)) { - // TODO wjm test it, looks like we cannot assign double or decimal values to int64, what if in (0, 1.9), and there is a block with all col range in 1.1-1.8. - SDecimalOps* pOps = getDecimalOps(valueNode->node.resType.type); - if (valueNode->node.resType.type == TSDB_DATA_TYPE_DECIMAL64) { - // TODO wjm do i need to convert precision and scale??? - if (pOps->gt(&minDatum.i, &valDatum.i, WORD_NUM(Decimal64))) minDatum.i = valDatum.i; - if (pOps->lt(&maxDatum.i, &valDatum.i, WORD_NUM(Decimal64))) maxDatum.i = valDatum.i; - maxDatum.kind = minDatum.kind = FLT_SCL_DATUM_KIND_DECIMAL64; - } else if (valueNode->node.resType.type == TSDB_DATA_TYPE_DECIMAL) { - if (listNode->pNodeList->pHead->pNode == nodeItem) { - // first node in list, set min/max datum - minDatum.pData = taosMemoryCalloc(1, sizeof(Decimal)); - if (!minDatum.pData) return terrno; - maxDatum.pData = taosMemoryCalloc(1, sizeof(Decimal)); - if (!maxDatum.pData) { - taosMemoryFreeClear(minDatum.pData); - return terrno; - } - DECIMAL128_CLONE((Decimal*)minDatum.pData, &decimal128Max); - DECIMAL128_CLONE((Decimal*)maxDatum.pData, &decimal128Min); - } - if (pOps->gt(minDatum.pData, valDatum.pData, WORD_NUM(Decimal))) DECIMAL128_CLONE((Decimal*)minDatum.pData, (Decimal*)valDatum.pData); - - if (pOps->lt(maxDatum.pData, valDatum.pData, WORD_NUM(Decimal))) DECIMAL128_CLONE((Decimal*)maxDatum.pData, (Decimal*)valDatum.pData); - maxDatum.kind = minDatum.kind = FLT_SCL_DATUM_KIND_DECIMAL; - } } else { minDatum.i = TMIN(minDatum.i, valDatum.i); maxDatum.i = TMAX(maxDatum.i, valDatum.i); diff --git a/source/libs/scalar/src/scalar.c b/source/libs/scalar/src/scalar.c index 42e5aa2639..aaf1cd1a34 100644 --- a/source/libs/scalar/src/scalar.c +++ b/source/libs/scalar/src/scalar.c @@ -168,7 +168,7 @@ int32_t scalarGenerateSetFromList(void **data, void *pNode, uint32_t type, SType if (overflow) { continue; } - // TODO For decimal types, after conversion, check if we lose some scale to ignore values with larger scale + // TODO wjm For decimal types, after conversion, check if we lose some scale to ignore values with larger scale // e.g. convert decimal(18, 4) to decimal(18, 2) with value: // 1.2345 -> 1.23. 1.23 != 1.2345, ignore this value, can't be the same as any decimal(18, 2) // 1.2300 -> 1.23. 1.2300 == 1.23, take this value. diff --git a/source/libs/scalar/src/sclfunc.c b/source/libs/scalar/src/sclfunc.c index 38ab814d87..c7cb2bba27 100644 --- a/source/libs/scalar/src/sclfunc.c +++ b/source/libs/scalar/src/sclfunc.c @@ -2099,7 +2099,14 @@ int32_t castFunction(SScalarParam *pInput, int32_t inputNum, SScalarParam *pOutp varDataSetLen(output, len); } else { int32_t outputSize = (outputLen - VARSTR_HEADER_SIZE) < bufSize ? (outputLen - VARSTR_HEADER_SIZE + 1): bufSize; - NUM_TO_STRING(inputType, input, outputSize, buf); + if (IS_DECIMAL_TYPE(inputType)) { + if (outputType == TSDB_DATA_TYPE_GEOMETRY) return TSDB_CODE_FUNC_FUNTION_PARA_TYPE; + uint8_t inputPrec = GET_PARAM_PRECISON(&pInput[0]), inputScale = GET_PARAM_SCALE(&pInput[0]); + code = decimalToStr(input, inputType, inputPrec, inputScale, buf, outputSize); + if (code != 0) goto _end; + } else { + NUM_TO_STRING(inputType, input, outputSize, buf); + } int32_t len = (int32_t)strlen(buf); len = (outputLen - VARSTR_HEADER_SIZE) > len ? len : (outputLen - VARSTR_HEADER_SIZE); (void)memcpy(varDataVal(output), buf, len); @@ -2145,7 +2152,13 @@ int32_t castFunction(SScalarParam *pInput, int32_t inputNum, SScalarParam *pOutp (void)memcpy(output, input, len + VARSTR_HEADER_SIZE); varDataSetLen(output, len); } else { - NUM_TO_STRING(inputType, input, bufSize, buf); + if (IS_DECIMAL_TYPE(inputType)) { + uint8_t inputPrec = GET_PARAM_PRECISON(&pInput[0]), inputScale = GET_PARAM_SCALE(&pInput[0]); + code = decimalToStr(input, inputType, inputPrec, inputScale, buf, bufSize); + if (code != 0) goto _end; + } else { + NUM_TO_STRING(inputType, input, bufSize, buf); + } len = (int32_t)strlen(buf); len = outputCharLen > len ? len : outputCharLen; bool ret = taosMbsToUcs4(buf, len, (TdUcs4 *)varDataVal(output), outputLen - VARSTR_HEADER_SIZE, &len, pInput->charsetCxt); @@ -2175,7 +2188,10 @@ int32_t castFunction(SScalarParam *pInput, int32_t inputNum, SScalarParam *pOutp convBuf[len] = 0; code = convertToDecimal(convBuf, &iT, output, &oT); } else { - code = convertToDecimal(input, &iT, output, &oT); + if (IS_VAR_DATA_TYPE(iT.type)) + code = convertToDecimal(varDataVal(input), &iT, output, &oT); + else + code = convertToDecimal(input, &iT, output, &oT); } if (code != TSDB_CODE_SUCCESS) { terrno = code; diff --git a/tests/system-test/2-query/decimal.py b/tests/system-test/2-query/decimal.py index fe0282f770..d31b9c1b50 100644 --- a/tests/system-test/2-query/decimal.py +++ b/tests/system-test/2-query/decimal.py @@ -7,10 +7,7 @@ import time import threading import secrets import numpy -from paramiko import Agent -import query -from tag_lite import datatype from util.log import * from util.sql import * from util.cases import * @@ -33,7 +30,7 @@ class AtomicCounter: getcontext().prec = 40 def get_decimal(val, scale: int) -> Decimal: - getcontext().prec = 40 + getcontext().prec = 100 return Decimal(val).quantize(Decimal("1." + "0" * scale), ROUND_HALF_UP) syntax_error = -2147473920 @@ -52,6 +49,7 @@ binary_op_with_col_test = True unary_op_test = True binary_op_in_where_test = True test_decimal_funcs = True +cast_func_test_round = 100 class DecimalTypeGeneratorConfig: def __init__(self): @@ -194,7 +192,7 @@ class TaosShell: if len(self.queryResult) == 0: self.queryResult = [[] for i in range(len(vals))] for val in vals: - self.queryResult[col].append(val.strip()) + self.queryResult[col].append(val.strip().strip('"')) col += 1 def query(self, sql: str): @@ -245,7 +243,7 @@ class DecimalColumnExpr: def get_input_types(self) -> List: pass - def should_skip_for_decimal(self, cols: list): + def should_skip_for_decimal(self, cols: list)->bool: return False def check_query_results(self, query_col_res: List, tbname: str): @@ -394,7 +392,7 @@ class DataType: return TypeEnum.get_type_str(self.type) def __eq__(self, other): - return self.type == other.type and self.length == other.length + return self.type == other.type and self.length == other.length and self.type_mod == other.type_mod def __ne__(self, other): return not self.__eq__(other) @@ -407,10 +405,10 @@ class DataType: def is_decimal_type(self): return self.type == TypeEnum.DECIMAL or self.type == TypeEnum.DECIMAL64 - + def is_varchar_type(self): return self.type == TypeEnum.VARCHAR or self.type == TypeEnum.NCHAR or self.type == TypeEnum.VARBINARY or self.type == TypeEnum.JSON or self.type == TypeEnum.BINARY - + def is_real_type(self): return self.type == TypeEnum.FLOAT or self.type == TypeEnum.DOUBLE @@ -419,7 +417,7 @@ class DataType: def scale(self): return 0 - + ## TODO generate NULL, None def generate_value(self) -> str: if self.type == TypeEnum.BOOL: @@ -458,7 +456,7 @@ class DataType: def check(self, values, offset: int): return True - + def get_typed_val_for_execute(self, val, const_col = False): if self.type == TypeEnum.DOUBLE: return float(val) @@ -477,7 +475,7 @@ class DataType: elif isinstance(val, str): val = val.strip("'") return val - + def get_typed_val(self, val): if self.type == TypeEnum.FLOAT: return float(str(numpy.float32(val))) @@ -485,8 +483,33 @@ class DataType: return float(val) return val + @staticmethod + def get_decimal_types() -> list[int]: + return [TypeEnum.DECIMAL64, TypeEnum.DECIMAL] + + @staticmethod + def get_decimal_op_types()-> list[int]: + return [ + TypeEnum.BOOL, + TypeEnum.TINYINT, + TypeEnum.SMALLINT, + TypeEnum.INT, + TypeEnum.BIGINT, + TypeEnum.FLOAT, + TypeEnum.DOUBLE, + TypeEnum.VARCHAR, + TypeEnum.NCHAR, + TypeEnum.UTINYINT, + TypeEnum.USMALLINT, + TypeEnum.UINT, + TypeEnum.UBIGINT, + TypeEnum.DECIMAL, + TypeEnum.DECIMAL64, + ] + class DecimalType(DataType): - MAX_PRECISION = 38 + DECIMAL_MAX_PRECISION = 38 + DECIMAL64_MAX_PRECISION = 18 def __init__(self, type, precision: int, scale: int): self.precision_ = precision self.scale_ = scale @@ -686,7 +709,7 @@ class Column: + Column.get_decimal_types() + types_unable_to_be_const ) - + @staticmethod def get_decimal_types() -> List: return [TypeEnum.DECIMAL, TypeEnum.DECIMAL64] @@ -823,6 +846,36 @@ class TableInserter: self.conn.execute(f"flush database {self.dbName}", queryTimes=1) self.conn.execute(sql, queryTimes=1) +class DecimalCastTypeGenerator: + def __init__(self, input_type: DataType): + self.input_type_: DataType = input_type + + def get_possible_output_types(self) -> List[int]: + if not self.input_type_.is_decimal_type(): + return DataType.get_decimal_types() + else: + return DataType.get_decimal_op_types() + + def do_generate_type(self, dt: int) ->DataType: + if dt == TypeEnum.DECIMAL: + prec = random.randint(1, DecimalType.DECIMAL_MAX_PRECISION) + return DecimalType(dt, prec, random.randint(0, prec)) + elif dt == TypeEnum.DECIMAL64: + prec = random.randint(1, DecimalType.DECIMAL64_MAX_PRECISION) + return DecimalType(dt, prec, random.randint(0, prec)) + elif dt == TypeEnum.BINARY or dt == TypeEnum.VARCHAR: + return DataType(dt, random.randint(16, 255), 0) + else: + return DataType(dt, 0, 0) + + def generate(self, num: int) -> List[DataType]: + res: list[DataType] = [] + for _ in range(num): + dt = random.choice(self.get_possible_output_types()) + dt = self.do_generate_type(dt) + res.append(dt) + res = list(set(res)) + return res class TableDataValidator: def __init__(self, columns: List[Column], tbName: str, dbName: str, tbIdx: int = 0): @@ -850,39 +903,145 @@ class DecimalFunction(DecimalColumnExpr): def __init__(self, format, executor, name: str): super().__init__(format, executor) self.func_name_ = name - - def is_agg_func(self, op: str): + + def is_agg_func(self, op: str) ->bool: return False def get_func_res(self): return None - + def get_input_types(self): return [self.query_col] - + @staticmethod def get_decimal_agg_funcs() -> List: - return [DecimalMinFunction(), DecimalMaxFunction(), DecimalSumFunction(), DecimalAvgFunction()] + return [ + DecimalMinFunction(), + DecimalMaxFunction(), + DecimalSumFunction(), + DecimalAvgFunction(), + DecimalCountFunction(), + ] def check_results(self, query_col_res: List) -> bool: return False - def check_for_agg_func(self, query_col_res: List, tbname: str): + def check_for_agg_func(self, query_col_res: List, tbname: str, func): col_expr = self.query_col for i in range(col_expr.get_cardinality(tbname)): col_val = col_expr.get_val_for_execute(tbname, i) self.execute((col_val,)) if not self.check_results(query_col_res): tdLog.exit(f"check failed for {self}, query got: {query_col_res}, expect {self.get_func_res()}") + else: + tdLog.info(f"check expr: {func} with val: {col_val} got result: {query_col_res}, expect: {self.get_func_res()}") + +class DecimalCastFunction(DecimalFunction): + def __init__(self): + super().__init__("cast({0} as {1})", DecimalCastFunction.execute_cast, "cast") + + def should_skip_for_decimal(self, cols: list)->bool: + return False + + def check_results(self, query_col_res: List) -> bool: + return False + + def generate_res_type(self)->DataType: + self.query_col = self.params_[0] + self.res_type_ = self.params_[1] + return self.res_type_ + + def check(self, res: list, tbname: str): + calc_res = [] + params = [] + for i in range(self.query_col.get_cardinality(tbname)): + val = self.query_col.get_val_for_execute(tbname, i) + params.append(val) + try: + calc_val = self.execute(val) + except OverflowError as e: + tdLog.info(f"execute {self} overflow for param: {val}") + calc_res = [] + break + calc_res.append(calc_val) + if len(calc_res) != len(res): + tdLog.exit(f"check result for {self} failed len got: {len(res)}, expect: {len(calc_res)}") + if len(calc_res) == 0: + return True + for v, calc_v, p in zip(res, calc_res, params): + query_v = self.execute_cast(v) + if isinstance(calc_v, float): + eq = math.isclose(query_v, calc_v, rel_tol=1e-7) + elif isinstance(calc_v, numpy.float32): + eq = math.isclose(query_v, calc_v, abs_tol=1e-6, rel_tol=1e-6) + elif isinstance(p, float) or isinstance(p, str): + eq = math.isclose(query_v, calc_v, rel_tol=1e-7) + else: + eq = query_v == calc_v + if not eq: + tdLog.exit(f"check result for {self} failed with param: {p} query got: {v}, expect: {calc_v}") + return True + + def execute_cast(self, val): + if val is None or val == 'NULL': + return None + if self.res_type_.type == TypeEnum.BOOL: + return Decimal(val) != 0 + elif self.res_type_.type == TypeEnum.TINYINT: + dec = Decimal(val).quantize(Decimal("1"), ROUND_HALF_UP) + return int(dec) & 0xFF + elif self.res_type_.type == TypeEnum.SMALLINT: + dec = Decimal(val).quantize(Decimal("1"), ROUND_HALF_UP) + return int(dec) & 0xFFFF + elif self.res_type_.type == TypeEnum.INT: + dec = Decimal(val).quantize(Decimal("1"), ROUND_HALF_UP) + return int(dec) & 0xFFFFFFFF + elif self.res_type_.type == TypeEnum.BIGINT or self.res_type_.type == TypeEnum.TIMESTAMP: + dec = Decimal(val).quantize(Decimal("1"), ROUND_HALF_UP) + return int(dec) & 0xFFFFFFFFFFFFFFFF + elif self.res_type_.type == TypeEnum.FLOAT: + return numpy.float32(val) + elif self.res_type_.type == TypeEnum.DOUBLE: + return float(val) + elif self.res_type_.type == TypeEnum.VARCHAR or self.res_type_.type == TypeEnum.NCHAR: + if Decimal(val) == 0: + return "0" + return str(val)[0:self.res_type_.length] + elif self.res_type_.type == TypeEnum.UTINYINT: + dec = Decimal(val).quantize(Decimal("1"), ROUND_HALF_UP) + return int(dec) & 0xFF + elif self.res_type_.type == TypeEnum.USMALLINT: + dec = Decimal(val).quantize(Decimal("1"), ROUND_HALF_UP) + return int(dec) & 0xFFFF + elif self.res_type_.type == TypeEnum.UINT: + dec = Decimal(val).quantize(Decimal("1"), ROUND_HALF_UP) + return int(dec) & 0xFFFFFFFF + elif self.res_type_.type == TypeEnum.UBIGINT: + dec = Decimal(val).quantize(Decimal("1"), ROUND_HALF_UP) + return int(dec) & 0xFFFFFFFFFFFFFFFF + elif self.res_type_.is_decimal_type(): + max: Decimal = Decimal( + "9" * (self.res_type_.prec() - self.res_type_.scale()) + + "." + + "9" * self.res_type_.scale() + ) + if max < get_decimal(val, self.res_type_.scale()): + raise OverflowError() + try: + return get_decimal(val, self.res_type_.scale()) + except Exception as e: + tdLog.exit(f"failed to cast {val} to {self.res_type_}, {e}") + else: + raise Exception(f"cast unsupported type {self.res_type_.type}") class DecimalAggFunction(DecimalFunction): def __init__(self, format, executor, name: str): super().__init__(format, executor, name) - - def is_agg_func(self, op): + + def is_agg_func(self, op: str)-> bool: return True - - def should_skip_for_decimal(self, cols: list): + + def should_skip_for_decimal(self, cols: list)-> bool: col: Column = cols[0] if col.type_.is_decimal_type(): return False @@ -895,6 +1054,62 @@ class DecimalAggFunction(DecimalFunction): else: return self.get_func_res() == Decimal(query_col_res[0]) +class DecimalLastRowFunction(DecimalAggFunction): + def __init__(self): + super().__init__("last_row({0})", DecimalLastRowFunction.execute_last_row, "last_row") + def get_func_res(self): + return 1 + def generate_res_type(self): + self.res_type_ = self.query_col.type_ + def execute_last_row(self, params): + return 1 + +class DecimalCacheLastRowFunction(DecimalAggFunction): + def __init__(self): + super().__init__("_cache_last_row({0})", DecimalCacheLastRowFunction.execute_cache_last_row, "_cache_last_row") + def get_func_res(self): + return 1 + def generate_res_type(self): + self.res_type_ = self.query_col.type_ + def execute_cache_last_row(self, params): + return 1 + +class DecimalCacheLastFunction(DecimalAggFunction): + pass + +class DecimalFirstFunction(DecimalAggFunction): + pass + +class DecimalLastFunction(DecimalAggFunction): + pass + +class DecimalHyperloglogFunction(DecimalAggFunction): + pass + +class DecimalSampleFunction(DecimalAggFunction): + pass + +class DecimalTailFunction(DecimalAggFunction): + pass + +class DecimalUniqueFunction(DecimalAggFunction): + pass + +class DecimalModeFunction(DecimalAggFunction): + pass + +class DecimalCountFunction(DecimalAggFunction): + def __init__(self): + super().__init__("count({0})", DecimalCountFunction.execute_count, "count") + def get_func_res(self): + decimal_type: DecimalType = self.query_col.type_ + return decimal_type.aggregator.count - decimal_type.aggregator.null_num + def generate_res_type(self): + self.res_type_ = DataType(TypeEnum.BIGINT, 8, 0) + + def execute_count(self, params): + return 1 + class DecimalMinFunction(DecimalAggFunction): def __init__(self): super().__init__("min({0})", DecimalMinFunction.execute_min, "min") @@ -947,7 +1162,7 @@ class DecimalSumFunction(DecimalAggFunction): return self.sum_ def generate_res_type(self) -> DataType: self.res_type_ = self.query_col.type_ - self.res_type_.set_prec(DecimalType.MAX_PRECISION) + self.res_type_.set_prec(DecimalType.DECIMAL_MAX_PRECISION) def execute_sum(self, params): if params[0] is None: return @@ -971,7 +1186,7 @@ class DecimalAvgFunction(DecimalAggFunction): ) def generate_res_type(self) -> DataType: sum_type = self.query_col.type_ - sum_type.set_prec(DecimalType.MAX_PRECISION) + sum_type.set_prec(DecimalType.DECIMAL_MAX_PRECISION) count_type = DataType(TypeEnum.BIGINT, 8, 0) self.res_type_ = DecimalBinaryOperator.calc_decimal_prec_scale(sum_type, count_type, "/") def execute_avg(self, params): @@ -993,10 +1208,10 @@ class DecimalBinaryOperator(DecimalColumnExpr): def __str__(self): return super().__str__() - + def generate(self, format_params) -> str: return super().generate(format_params) - + def should_skip_for_decimal(self, cols: list): left_col = cols[0] right_col = cols[1] @@ -1247,10 +1462,10 @@ class DecimalUnaryOperator(DecimalColumnExpr): def get_input_types(self)-> list: return [self.col_type_] - + def generate_res_type(self): self.res_type_ = self.col_type_ = self.params_[0].type_ - + def execute_minus(self, params) -> Decimal: if params[0] is None: return 'NULL' @@ -1595,8 +1810,8 @@ class TDTestCase: self.no_decimal_table_test() self.test_insert_decimal_values() self.test_query_decimal() - ##self.test_decimal_and_stream() - ##self.test_decimal_and_tsma() + self.test_decimal_and_stream() + self.test_decimal_and_tsma() def stop(self): tdSql.close() @@ -1930,7 +2145,22 @@ class TDTestCase: res = TaosShell().query(sql) if len(res) > 0: res = res[0] - func.check_for_agg_func(res, tbname) + func.check_for_agg_func(res, tbname, func) + + def test_decimal_cast_func(self, dbname, tbname, tb_cols: List[Column]): + for col in tb_cols: + if col.name_ == '': + continue + to_types: list[DataType] = DecimalCastTypeGenerator(col.type_).generate(cast_func_test_round) + for t in to_types: + cast_func = DecimalCastFunction() + expr = cast_func.generate([col, t]) + sql = f"select {expr} from {dbname}.{tbname}" + res = TaosShell().query(sql) + if len(res) > 0: + res = res[0] + cast_func.check(res, tbname) + def test_decimal_functions(self): if not test_decimal_funcs: @@ -1941,8 +2171,7 @@ class TDTestCase: self.norm_tb_columns, DecimalFunction.get_decimal_agg_funcs, ) - - self.test_decimal_last_first_func() + self.test_decimal_cast_func(self.db_name, self.norm_table_name, self.norm_tb_columns) def test_query_decimal(self): self.test_decimal_operators()