diff --git a/include/libs/function/function.h b/include/libs/function/function.h index 0f0b89e4ba..5f4d34836b 100644 --- a/include/libs/function/function.h +++ b/include/libs/function/function.h @@ -265,6 +265,7 @@ typedef struct SqlFunctionCtx { bool bInputFinished; bool hasWindowOrGroup; // denote that the function is used with time window or group bool needCleanup; // denote that the function need to be cleaned up + int32_t inputType; // TODO wjm rename it } SqlFunctionCtx; typedef struct tExprNode { diff --git a/include/libs/nodes/querynodes.h b/include/libs/nodes/querynodes.h index 451d1794b1..82ad654232 100644 --- a/include/libs/nodes/querynodes.h +++ b/include/libs/nodes/querynodes.h @@ -194,6 +194,8 @@ typedef struct SFunctionNode { bool dual; // whether select stmt without from stmt, true for without. timezone_t tz; void *charsetCxt; + const struct SFunctionNode* pSrcFuncRef; + SDataType srcFuncInputType; } SFunctionNode; typedef struct STableNode { diff --git a/source/common/src/tdatablock.c b/source/common/src/tdatablock.c index e1f1fa2833..bdc6ad39ed 100644 --- a/source/common/src/tdatablock.c +++ b/source/common/src/tdatablock.c @@ -3339,7 +3339,7 @@ int32_t blockDecode(SSDataBlock* pBlock, const char* pData, const char** pEndPos if (IS_DECIMAL_TYPE(pColInfoData->info.type)) { pColInfoData->info.scale = *(char*)pStart; pColInfoData->info.precision = *((char*)pStart + 2); - pColInfoData->info.bytes >>= *((char*)pStart + 3); + pColInfoData->info.bytes &= 0xFF; } pStart += sizeof(int32_t); @@ -3733,7 +3733,7 @@ int32_t blockDataCheck(const SSDataBlock* pDataBlock) { } else if (TSDB_DATA_TYPE_DOUBLE == pCol->info.type) { double v = 0; GET_TYPED_DATA(v, double, pCol->info.type, colDataGetNumData(pCol, r), typeGetTypeModFromColInfo(&pCol->info)); - } else { + } else {// TODO wjm add decimal type GET_TYPED_DATA(typeValue, int64_t, pCol->info.type, colDataGetNumData(pCol, r), typeGetTypeModFromColInfo(&pCol->info)); } } diff --git a/source/libs/decimal/src/decimal.c b/source/libs/decimal/src/decimal.c index 3a3b632dc2..f96300cf4f 100644 --- a/source/libs/decimal/src/decimal.c +++ b/source/libs/decimal/src/decimal.c @@ -887,7 +887,7 @@ int32_t decimalOp(EOperatorType op, const SDataType* pLeftT, const SDataType* pR .scale = pRightT->scale}; if (TSDB_DATA_TYPE_DECIMAL != pLeftT->type) { code = convertToDecimal(pLeftData, pLeftT, &left, <); - if (TSDB_CODE_SUCCESS != code) return code; + if (TSDB_CODE_SUCCESS != code) return code; // TODO add some logs here } else { left = *(Decimal*)pLeftData; } diff --git a/source/libs/function/inc/builtinsimpl.h b/source/libs/function/inc/builtinsimpl.h index d548ae956d..9d11d78c33 100644 --- a/source/libs/function/inc/builtinsimpl.h +++ b/source/libs/function/inc/builtinsimpl.h @@ -78,7 +78,7 @@ int32_t avgInvertFunction(SqlFunctionCtx* pCtx); #endif int32_t avgCombine(SqlFunctionCtx* pDestCtx, SqlFunctionCtx* pSourceCtx); -int32_t getAvgInfoSize(); +int32_t getAvgInfoSize(SFunctionNode* pFunc); bool getStdFuncEnv(struct SFunctionNode* pFunc, SFuncExecEnv* pEnv); int32_t stdFunctionSetup(SqlFunctionCtx* pCtx, SResultRowEntryInfo* pResultInfo); diff --git a/source/libs/function/inc/functionResInfoInt.h b/source/libs/function/inc/functionResInfoInt.h index a8cef86046..4380c73fa6 100644 --- a/source/libs/function/inc/functionResInfoInt.h +++ b/source/libs/function/inc/functionResInfoInt.h @@ -40,7 +40,6 @@ typedef struct SSumRes { int64_t isum; uint64_t usum; double dsum; - void* pData; // for decimal128 }; int16_t type; int64_t prevTs; @@ -49,15 +48,54 @@ typedef struct SSumRes { } SSumRes; typedef struct SDecimalSumRes { - int64_t flag; // currently not used + Decimal128 sum; // TODO wjm use same struct for the following four fields as SSumRes int16_t type; int64_t prevTs; bool isPrevTsSet; bool overflow; // if overflow is true, dsum to be used for any type; - Decimal128 sum; + uint32_t flag; // currently not used } SDecimalSumRes; +#define SUM_RES_GET_RES(pSumRes) ((SSumRes*)pSumRes) +#define SUM_RES_GET_DECIMAL_RES(pSumRes) ((SDecimalSumRes*)pSumRes) + +#define SUM_RES_GET_SIZE(type) IS_DECIMAL_TYPE(type) ? sizeof(SDecimalSumRes) : sizeof(SSumRes) + +#define SUM_RES_SET_TYPE(pSumRes, inputType, _type) \ + do { \ + if (IS_DECIMAL_TYPE(inputType)) \ + SUM_RES_GET_DECIMAL_RES(pSumRes)->type = _type; \ + else \ + SUM_RES_GET_RES(pSumRes)->type = _type; \ + } while (0) + +#define SUM_RES_GET_TYPE(pSumRes, inputType) \ + (IS_DECIMAL_TYPE(inputType) ? SUM_RES_GET_DECIMAL_RES(pSumRes)->type : SUM_RES_GET_RES(pSumRes)->type) +#define SUM_RES_GET_PREV_TS(pSumRes, inputType) \ + (IS_DECIMAL_TYPE(inputType) ? SUM_RES_GET_DECIMAL_RES(pSumRes)->prevTs : SUM_RES_GET_RES(pSumRes)->prevTs) +#define SUM_RES_GET_OVERFLOW(pSumRes, checkInputType, inputType) \ + (checkInputType && IS_DECIMAL_TYPE(inputType) ? SUM_RES_GET_DECIMAL_RES(pSumRes)->overflow \ + : SUM_RES_GET_RES(pSumRes)->overflow) + +#define SUM_RES_GET_ISUM(pSumRes) (((SSumRes*)(pSumRes))->isum) +#define SUM_RES_GET_USUM(pSumRes) (((SSumRes*)(pSumRes))->usum) +#define SUM_RES_GET_DSUM(pSumRes) (((SSumRes*)(pSumRes))->dsum) +#define SUM_RES_INC_ISUM(pSumRes, val) ((SSumRes*)(pSumRes))->isum += val +#define SUM_RES_INC_USUM(pSumRes, val) ((SSumRes*)(pSumRes))->usum += val +#define SUM_RES_INC_DSUM(pSumRes, val) ((SSumRes*)(pSumRes))->dsum += val + +#define SUM_RES_GET_DECIMAL_SUM(pSumRes) ((SDecimalSumRes*)(pSumRes))->sum +// TODO wjm check for overflow +#define SUM_RES_INC_DECIMAL_SUM(pSumRes, pVal, type) \ + do { \ + const SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL); \ + if (type == TSDB_DATA_TYPE_DECIMAL64) \ + pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pVal, WORD_NUM(Decimal64)); \ + else \ + pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pVal, WORD_NUM(Decimal)); \ + } while (0) + typedef struct SMinmaxResInfo { bool assign; // assign the first value or not int64_t v; @@ -145,6 +183,55 @@ typedef struct SAvgRes { int16_t type; // store the original input type, used in merge function } SAvgRes; +typedef struct SDecimalAvgRes { + Decimal128 avg; + SDecimalSumRes sum; + int64_t count; + int16_t type; // store the original input type and scale, used in merge function + uint8_t scale; +} SDecimalAvgRes; + +#define AVG_RES_GET_RES(pAvgRes) ((SAvgRes*)pAvgRes) +#define AVG_RES_GET_DECIMAL_RES(pAvgRes) ((SDecimalAvgRes*)pAvgRes) +#define AVG_RES_SET_TYPE(pAvgRes, inputType, _type) \ + do { \ + if (IS_DECIMAL_TYPE(inputType)) \ + AVG_RES_GET_DECIMAL_RES(pAvgRes)->type = _type; \ + else \ + AVG_RES_GET_RES(pAvgRes)->type = _type; \ + } while (0) + +#define AVG_RES_SET_INPUT_SCALE(pAvgRes, _scale) \ + do { \ + AVG_RES_GET_DECIMAL_RES(pAvgRes)->scale = _scale; \ + } while (0) + +#define AVG_RES_GET_INPUT_SCALE(pAvgRes) (AVG_RES_GET_DECIMAL_RES(pAvgRes)->scale) + +#define AVG_RES_GET_TYPE(pAvgRes, inputType) \ + (IS_DECIMAL_TYPE(inputType) ? AVG_RES_GET_DECIMAL_RES(pAvgRes)->type : AVG_RES_GET_RES(pAvgRes)->type) + +#define AVG_RES_GET_SIZE(inputType) (IS_DECIMAL_TYPE(inputType) ? sizeof(SDecimalAvgRes) : sizeof(SAvgRes)) +#define AVG_RES_GET_AVG(pAvgRes) (AVG_RES_GET_RES(pAvgRes)->result) +#define AVG_RES_GET_SUM(pAvgRes) (AVG_RES_GET_RES(pAvgRes)->sum) +#define AVG_RES_GET_COUNT(pAvgRes, checkInputType, inputType) \ + (checkInputType && IS_DECIMAL_TYPE(inputType) ? AVG_RES_GET_DECIMAL_RES(pAvgRes)->count \ + : AVG_RES_GET_RES(pAvgRes)->count) +#define AVG_RES_INC_COUNT(pAvgRes, inputType, val) \ + do { \ + if (IS_DECIMAL_TYPE(inputType)) \ + AVG_RES_GET_DECIMAL_RES(pAvgRes)->count += val; \ + else \ + AVG_RES_GET_RES(pAvgRes)->count += val; \ + } while (0) + +#define AVG_RES_GET_DECIMAL_AVG(pAvgRes) (((SDecimalAvgRes*)(pAvgRes))->avg) +#define AVG_RES_GET_DECIMAL_SUM(pAvgRes) (((SDecimalAvgRes*)(pAvgRes))->sum) + +#define AVG_RES_GET_SUM_OVERFLOW(pAvgRes, checkInputType, inputType) \ + checkInputType&& IS_DECIMAL_TYPE(inputType) \ + ? SUM_RES_GET_OVERFLOW(&AVG_RES_GET_DECIMAL_SUM(pAvgRes), true, inputType) \ + : SUM_RES_GET_OVERFLOW(&AVG_RES_GET_SUM(pAvgRes), false, inputType) // structs above are used in stream diff --git a/source/libs/function/src/builtins.c b/source/libs/function/src/builtins.c index 3ba55c850b..18ca0d2f3a 100644 --- a/source/libs/function/src/builtins.c +++ b/source/libs/function/src/builtins.c @@ -979,6 +979,33 @@ static int32_t translateMinMax(SFunctionNode* pFunc, char* pErrBuf, int32_t len) return TSDB_CODE_SUCCESS; } +static int32_t translateAvg(SFunctionNode* pFunc, char* pErrBuf, int32_t len) { + FUNC_ERR_RET(validateParam(pFunc, pErrBuf, len)); + + uint8_t dt = TSDB_DATA_TYPE_DOUBLE, prec = 0, scale = 0; + + bool isMergeFunc = pFunc->funcType == FUNCTION_TYPE_AVG_MERGE || pFunc->funcType == FUNCTION_TYPE_AVG_STATE_MERGE; + SDataType* pInputDt = getSDataTypeFromNode( + nodesListGetNode(isMergeFunc ? pFunc->pSrcFuncRef->pParameterList : pFunc->pParameterList, 0)); + if (IS_DECIMAL_TYPE(pInputDt->type)) { + pFunc->srcFuncInputType = *pInputDt; + SDataType sumDt = {.type = TSDB_DATA_TYPE_DECIMAL, + .bytes = tDataTypes[TSDB_DATA_TYPE_DECIMAL].bytes, + .precision = TSDB_DECIMAL_MAX_PRECISION, + .scale = pInputDt->scale}; + SDataType countDt = { + .type = TSDB_DATA_TYPE_BIGINT, .bytes = tDataTypes[TSDB_DATA_TYPE_BIGINT].bytes, .precision = 0, .scale = 0}; + SDataType avgDt = {0}; + int32_t code = decimalGetRetType(&sumDt, &countDt, OP_TYPE_DIV, &avgDt); + if (code != 0) return code; + dt = TSDB_DATA_TYPE_DECIMAL; + prec = TSDB_DECIMAL_MAX_PRECISION; + scale = avgDt.scale; + } + pFunc->node.resType = (SDataType){.bytes = tDataTypes[dt].bytes, .type = dt, .precision = prec, .scale = scale}; + return TSDB_CODE_SUCCESS; +} + // The return type is DOUBLE type static int32_t translateOutDouble(SFunctionNode* pFunc, char* pErrBuf, int32_t len) { FUNC_ERR_RET(validateParam(pFunc, pErrBuf, len)); @@ -1665,8 +1692,12 @@ static int32_t translateOutVarchar(SFunctionNode* pFunc, char* pErrBuf, int32_t break; case FUNCTION_TYPE_AVG_PARTIAL: case FUNCTION_TYPE_AVG_STATE: + pFunc->srcFuncInputType = ((SExprNode*)pFunc->pParameterList->pHead->pNode)->resType; + bytes = getAvgInfoSize(pFunc) + VARSTR_HEADER_SIZE; + break; case FUNCTION_TYPE_AVG_STATE_MERGE: - bytes = getAvgInfoSize() + VARSTR_HEADER_SIZE; + pFunc->srcFuncInputType = pFunc->pSrcFuncRef->srcFuncInputType; + bytes = getAvgInfoSize(pFunc) + VARSTR_HEADER_SIZE; break; case FUNCTION_TYPE_HISTOGRAM_PARTIAL: bytes = getHistogramInfoSize() + VARSTR_HEADER_SIZE; @@ -2064,7 +2095,7 @@ const SBuiltinFuncDefinition funcMgtBuiltins[] = { .paramAttribute = FUNC_PARAM_NO_SPECIFIC_ATTRIBUTE, .valueRangeFlag = FUNC_PARAM_NO_SPECIFIC_VALUE,}, .outputParaInfo = {.validDataType = FUNC_PARAM_SUPPORT_DOUBLE_TYPE | FUNC_PARAM_SUPPORT_DECIMAL_TYPE}}, - .translateFunc = translateOutDouble, + .translateFunc = translateAvg, .dataRequiredFunc = statisDataRequired, .getEnvFunc = getAvgFuncEnv, .initFunc = avgFunctionSetup, @@ -2121,7 +2152,7 @@ const SBuiltinFuncDefinition funcMgtBuiltins[] = { .paramAttribute = FUNC_PARAM_NO_SPECIFIC_ATTRIBUTE, .valueRangeFlag = FUNC_PARAM_NO_SPECIFIC_VALUE,}, .outputParaInfo = {.validDataType = FUNC_PARAM_SUPPORT_DOUBLE_TYPE}}, - .translateFunc = translateOutDouble, + .translateFunc = translateAvg, .getEnvFunc = getAvgFuncEnv, .initFunc = avgFunctionSetup, .processFunc = avgFunctionMerge, diff --git a/source/libs/function/src/builtinsimpl.c b/source/libs/function/src/builtinsimpl.c index bc375e6a51..008ed542d0 100644 --- a/source/libs/function/src/builtinsimpl.c +++ b/source/libs/function/src/builtinsimpl.c @@ -631,9 +631,10 @@ int32_t sumFunction(SqlFunctionCtx* pCtx) { SInputColumnInfoData* pInput = &pCtx->input; SColumnDataAgg* pAgg = pInput->pColumnDataAgg[0]; int32_t type = pInput->pData[0]->info.type; + pCtx->inputType = type; - SSumRes* pSumRes = GET_ROWCELL_INTERBUF(GET_RES_INFO(pCtx)); - pSumRes->type = type; + void* pSumRes = GET_ROWCELL_INTERBUF(GET_RES_INFO(pCtx)); + SUM_RES_SET_TYPE(pSumRes, pCtx->inputType, type); if (IS_NULL_TYPE(type)) { numOfElem = 0; @@ -644,19 +645,18 @@ int32_t sumFunction(SqlFunctionCtx* pCtx) { numOfElem = pInput->numOfRows - pAgg->numOfNull; if (IS_SIGNED_NUMERIC_TYPE(type)) { - pSumRes->isum += pAgg->sum; + SUM_RES_INC_ISUM(pSumRes, pAgg->sum); } else if (IS_UNSIGNED_NUMERIC_TYPE(type)) { - pSumRes->usum += pAgg->sum; + SUM_RES_INC_USUM(pSumRes, pAgg->sum); } else if (IS_FLOAT_TYPE(type)) { - pSumRes->dsum += GET_DOUBLE_VAL((const char*)&(pAgg->sum)); + SUM_RES_INC_DSUM(pSumRes, GET_DOUBLE_VAL((const char*)&(pAgg->sum))); } else if (IS_DECIMAL_TYPE(type)) { - SDecimalSumRes* pDecimalSum = (SDecimalSumRes*)pSumRes; - pDecimalSum->type = TSDB_DATA_TYPE_DECIMAL; + SUM_RES_SET_TYPE(pSumRes, pCtx->inputType, TSDB_DATA_TYPE_DECIMAL); const SDecimalOps* pOps = getDecimalOps(type); if (TSDB_DATA_TYPE_DECIMAL64 == type) { - pOps->add(&pDecimalSum->sum, &pAgg->sum, WORD_NUM(Decimal64)); + pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), &pAgg->sum, WORD_NUM(Decimal64)); } else if (TSDB_DATA_TYPE_DECIMAL == type) { - pOps->add(&pDecimalSum->sum, pAgg->decimal128Sum, WORD_NUM(Decimal)); + pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), &pAgg->decimal128Sum, WORD_NUM(Decimal)); } } } else { // computing based on the true data block @@ -667,42 +667,41 @@ int32_t sumFunction(SqlFunctionCtx* pCtx) { if (IS_SIGNED_NUMERIC_TYPE(type) || type == TSDB_DATA_TYPE_BOOL) { if (type == TSDB_DATA_TYPE_TINYINT || type == TSDB_DATA_TYPE_BOOL) { - LIST_ADD_N(pSumRes->isum, pCol, start, numOfRows, int8_t, numOfElem); + LIST_ADD_N(SUM_RES_GET_ISUM(pSumRes), pCol, start, numOfRows, int8_t, numOfElem); } else if (type == TSDB_DATA_TYPE_SMALLINT) { - LIST_ADD_N(pSumRes->isum, pCol, start, numOfRows, int16_t, numOfElem); + LIST_ADD_N(SUM_RES_GET_ISUM(pSumRes), pCol, start, numOfRows, int16_t, numOfElem); } else if (type == TSDB_DATA_TYPE_INT) { - LIST_ADD_N(pSumRes->isum, pCol, start, numOfRows, int32_t, numOfElem); + LIST_ADD_N(SUM_RES_GET_ISUM(pSumRes), pCol, start, numOfRows, int32_t, numOfElem); } else if (type == TSDB_DATA_TYPE_BIGINT) { - LIST_ADD_N(pSumRes->isum, pCol, start, numOfRows, int64_t, numOfElem); + LIST_ADD_N(SUM_RES_GET_ISUM(pSumRes), pCol, start, numOfRows, int64_t, numOfElem); } } else if (IS_UNSIGNED_NUMERIC_TYPE(type)) { if (type == TSDB_DATA_TYPE_UTINYINT) { - LIST_ADD_N(pSumRes->usum, pCol, start, numOfRows, uint8_t, numOfElem); + LIST_ADD_N(SUM_RES_GET_USUM(pSumRes), pCol, start, numOfRows, uint8_t, numOfElem); } else if (type == TSDB_DATA_TYPE_USMALLINT) { - LIST_ADD_N(pSumRes->usum, pCol, start, numOfRows, uint16_t, numOfElem); + LIST_ADD_N(SUM_RES_GET_USUM(pSumRes), pCol, start, numOfRows, uint16_t, numOfElem); } else if (type == TSDB_DATA_TYPE_UINT) { - LIST_ADD_N(pSumRes->usum, pCol, start, numOfRows, uint32_t, numOfElem); + LIST_ADD_N(SUM_RES_GET_USUM(pSumRes), pCol, start, numOfRows, uint32_t, numOfElem); } else if (type == TSDB_DATA_TYPE_UBIGINT) { - LIST_ADD_N(pSumRes->usum, pCol, start, numOfRows, uint64_t, numOfElem); + LIST_ADD_N(SUM_RES_GET_USUM(pSumRes), pCol, start, numOfRows, uint64_t, numOfElem); } } else if (type == TSDB_DATA_TYPE_DOUBLE) { - LIST_ADD_N(pSumRes->dsum, pCol, start, numOfRows, double, numOfElem); + LIST_ADD_N(SUM_RES_GET_DSUM(pSumRes), pCol, start, numOfRows, double, numOfElem); } else if (type == TSDB_DATA_TYPE_FLOAT) { - LIST_ADD_N(pSumRes->dsum, pCol, start, numOfRows, float, numOfElem); + LIST_ADD_N(SUM_RES_GET_DSUM(pSumRes), pCol, start, numOfRows, float, numOfElem); } else if (IS_DECIMAL_TYPE(type)) { - SDecimalSumRes* pDecimalSum = (SDecimalSumRes*)pSumRes; - pSumRes->type = TSDB_DATA_TYPE_DECIMAL; + SUM_RES_SET_TYPE(pSumRes, pCtx->inputType, TSDB_DATA_TYPE_DECIMAL); if (TSDB_DATA_TYPE_DECIMAL64 == type) { - LIST_ADD_DECIMAL_N(&pDecimalSum->sum, pCol, start, numOfRows, Decimal64, numOfElem); + LIST_ADD_DECIMAL_N(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pCol, start, numOfRows, Decimal64, numOfElem); } else if (TSDB_DATA_TYPE_DECIMAL == type) { - LIST_ADD_DECIMAL_N(&pDecimalSum->sum, pCol, start, numOfRows, Decimal128, numOfElem); + LIST_ADD_DECIMAL_N(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pCol, start, numOfRows, Decimal128, numOfElem); } // TODO wjm check overflow } } // check for overflow - if (IS_FLOAT_TYPE(type) && (isinf(pSumRes->dsum) || isnan(pSumRes->dsum))) { + if (IS_FLOAT_TYPE(type) && (isinf(SUM_RES_GET_DSUM(pSumRes)) || isnan(SUM_RES_GET_DSUM(pSumRes)))) { numOfElem = 0; } @@ -778,6 +777,7 @@ int32_t sumInvertFunction(SqlFunctionCtx* pCtx) { } #endif +// TODO wjm impl for decimal int32_t sumCombine(SqlFunctionCtx* pDestCtx, SqlFunctionCtx* pSourceCtx) { SResultRowEntryInfo* pDResInfo = GET_RES_INFO(pDestCtx); SSumRes* pDBuf = GET_ROWCELL_INTERBUF(pDResInfo); @@ -799,10 +799,7 @@ int32_t sumCombine(SqlFunctionCtx* pDestCtx, SqlFunctionCtx* pSourceCtx) { } bool getSumFuncEnv(SFunctionNode* pFunc, SFuncExecEnv* pEnv) { - if (pFunc->node.resType.type == TSDB_DATA_TYPE_DECIMAL) - pEnv->calcMemSize = sizeof(SDecimalSumRes); - else - pEnv->calcMemSize = sizeof(SSumRes); + pEnv->calcMemSize = SUM_RES_GET_SIZE(pFunc->node.resType.type); return true; } diff --git a/source/libs/function/src/detail/tavgfunction.c b/source/libs/function/src/detail/tavgfunction.c index 7313fc82f7..28ef2931fb 100644 --- a/source/libs/function/src/detail/tavgfunction.c +++ b/source/libs/function/src/detail/tavgfunction.c @@ -42,60 +42,74 @@ } while (0) // define signed number sum with check overflow -#define CHECK_OVERFLOW_SUM_SIGNED(out, val) \ - if (out->sum.overflow) { \ - out->sum.dsum += val; \ - } else if (out->sum.isum > 0 && val > 0 && INT64_MAX - out->sum.isum <= val || \ - out->sum.isum < 0 && val < 0 && INT64_MIN - out->sum.isum >= val) { \ - double dsum = (double)out->sum.isum; \ - out->sum.overflow = true; \ - out->sum.dsum = dsum + val; \ - } else { \ - out->sum.isum += val; \ - } +#define CHECK_OVERFLOW_SUM_SIGNED(pAvgRes, val) \ + do { \ + SAvgRes* out = pAvgRes; \ + if (out->sum.overflow) { \ + out->sum.dsum += val; \ + } else if (out->sum.isum > 0 && val > 0 && INT64_MAX - out->sum.isum <= val || \ + out->sum.isum < 0 && val < 0 && INT64_MIN - out->sum.isum >= val) { \ + double dsum = (double)out->sum.isum; \ + out->sum.overflow = true; \ + out->sum.dsum = dsum + val; \ + } else { \ + out->sum.isum += val; \ + } \ + } while (0) // val is big than INT64_MAX, val come from merge -#define CHECK_OVERFLOW_SUM_SIGNED_BIG(out, val, big) \ - if (out->sum.overflow) { \ - out->sum.dsum += val; \ - } else if (out->sum.isum > 0 && val > 0 && INT64_MAX - out->sum.isum <= val || \ - out->sum.isum < 0 && val < 0 && INT64_MIN - out->sum.isum >= val || \ - big) { \ - double dsum = (double)out->sum.isum; \ - out->sum.overflow = true; \ - out->sum.dsum = dsum + val; \ - } else { \ - out->sum.isum += val; \ - } +#define CHECK_OVERFLOW_SUM_SIGNED_BIG(pAvgRes, val, big) \ + do { \ + SAvgRes* out = pAvgRes; \ + if (out->sum.overflow) { \ + out->sum.dsum += val; \ + } else if (out->sum.isum > 0 && val > 0 && INT64_MAX - out->sum.isum <= val || \ + out->sum.isum < 0 && val < 0 && INT64_MIN - out->sum.isum >= val || big) { \ + double dsum = (double)out->sum.isum; \ + out->sum.overflow = true; \ + out->sum.dsum = dsum + val; \ + } else { \ + SUM_RES_INC_ISUM(&AVG_RES_GET_SUM(out), val); \ + } \ + } while (0) // define unsigned number sum with check overflow -#define CHECK_OVERFLOW_SUM_UNSIGNED(out, val) \ - if (out->sum.overflow) { \ - out->sum.dsum += val; \ - } else if (UINT64_MAX - out->sum.usum <= val) { \ - double dsum = (double)out->sum.usum; \ - out->sum.overflow = true; \ - out->sum.dsum = dsum + val; \ - } else { \ - out->sum.usum += val; \ - } +#define CHECK_OVERFLOW_SUM_UNSIGNED(pAvgRes, val) \ + do { \ + SAvgRes* out = pAvgRes; \ + if (out->sum.overflow) { \ + out->sum.dsum += val; \ + } else if (UINT64_MAX - out->sum.usum <= val) { \ + double dsum = (double)out->sum.usum; \ + out->sum.overflow = true; \ + out->sum.dsum = dsum + val; \ + } else { \ + out->sum.usum += val; \ + } \ + } while (0) // val is big than UINT64_MAX, val come from merge -#define CHECK_OVERFLOW_SUM_UNSIGNED_BIG(out, val, big) \ - if (out->sum.overflow) { \ - out->sum.dsum += val; \ - } else if (UINT64_MAX - out->sum.usum <= val || big) { \ - double dsum = (double)out->sum.usum; \ - out->sum.overflow = true; \ - out->sum.dsum = dsum + val; \ - } else { \ - out->sum.usum += val; \ - } +#define CHECK_OVERFLOW_SUM_UNSIGNED_BIG(pAvgRes, val, big) \ + do { \ + SAvgRes* out = pAvgRes; \ + if (out->sum.overflow) { \ + out->sum.dsum += val; \ + } else if (UINT64_MAX - out->sum.usum <= val || big) { \ + double dsum = (double)out->sum.usum; \ + out->sum.overflow = true; \ + out->sum.dsum = dsum + val; \ + } else { \ + out->sum.usum += val; \ + } \ + } while (0) -int32_t getAvgInfoSize() { return (int32_t)sizeof(SAvgRes); } +int32_t getAvgInfoSize(SFunctionNode* pFunc) { + if (pFunc->pSrcFuncRef) return AVG_RES_GET_SIZE(pFunc->pSrcFuncRef->srcFuncInputType.type); + return AVG_RES_GET_SIZE(pFunc->srcFuncInputType.type); +} -bool getAvgFuncEnv(SFunctionNode* UNUSED_PARAM(pFunc), SFuncExecEnv* pEnv) { - pEnv->calcMemSize = sizeof(SAvgRes); +bool getAvgFuncEnv(SFunctionNode* pFunc, SFuncExecEnv* pEnv) { + pEnv->calcMemSize =getAvgInfoSize(pFunc); return true; } @@ -107,27 +121,32 @@ int32_t avgFunctionSetup(SqlFunctionCtx* pCtx, SResultRowEntryInfo* pResultInfo) return TSDB_CODE_FUNC_SETUP_ERROR; } - SAvgRes* pRes = GET_ROWCELL_INTERBUF(pResultInfo); - (void)memset(pRes, 0, sizeof(SAvgRes)); + void* pRes = GET_ROWCELL_INTERBUF(pResultInfo); + (void)memset(pRes, 0, pCtx->resDataInfo.interBufSize); return TSDB_CODE_SUCCESS; } -static int32_t calculateAvgBySMAInfo(SAvgRes* pRes, int32_t numOfRows, int32_t type, const SColumnDataAgg* pAgg) { +static int32_t calculateAvgBySMAInfo(void* pRes, int32_t numOfRows, int32_t type, const SColumnDataAgg* pAgg) { int32_t numOfElem = numOfRows - pAgg->numOfNull; - pRes->count += numOfElem; + AVG_RES_INC_COUNT(pRes, type, numOfElem); if (IS_SIGNED_NUMERIC_TYPE(type)) { CHECK_OVERFLOW_SUM_SIGNED(pRes, pAgg->sum); } else if (IS_UNSIGNED_NUMERIC_TYPE(type)) { CHECK_OVERFLOW_SUM_UNSIGNED(pRes, pAgg->sum); } else if (IS_FLOAT_TYPE(type)) { - pRes->sum.dsum += GET_DOUBLE_VAL((const char*)&(pAgg->sum)); + SUM_RES_INC_DSUM(&AVG_RES_GET_SUM(pRes), GET_DOUBLE_VAL((const char*)&(pAgg->sum))); + } else if (IS_DECIMAL_TYPE(type)) { + if (type == TSDB_DATA_TYPE_DECIMAL64) + SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pRes), &pAgg->sum, TSDB_DATA_TYPE_DECIMAL64); + else + SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pRes), &pAgg->decimal128Sum, TSDB_DATA_TYPE_DECIMAL); } return numOfElem; } -static int32_t doAddNumericVector(SColumnInfoData* pCol, int32_t type, SInputColumnInfoData *pInput, SAvgRes* pRes) { +static int32_t doAddNumericVector(SColumnInfoData* pCol, int32_t type, SInputColumnInfoData *pInput, void* pRes) { int32_t start = pInput->startRowIndex; int32_t numOfRows = pInput->numOfRows; int32_t numOfElems = 0; @@ -141,8 +160,8 @@ static int32_t doAddNumericVector(SColumnInfoData* pCol, int32_t type, SInputCol } numOfElems += 1; - pRes->count += 1; - CHECK_OVERFLOW_SUM_SIGNED(pRes, plist[i]) + AVG_RES_INC_COUNT(pRes, TSDB_DATA_TYPE_TINYINT, 1); + CHECK_OVERFLOW_SUM_SIGNED(pRes, plist[i]); } break; @@ -156,8 +175,8 @@ static int32_t doAddNumericVector(SColumnInfoData* pCol, int32_t type, SInputCol } numOfElems += 1; - pRes->count += 1; - CHECK_OVERFLOW_SUM_SIGNED(pRes, plist[i]) + AVG_RES_INC_COUNT(pRes, TSDB_DATA_TYPE_SMALLINT, 1); + CHECK_OVERFLOW_SUM_SIGNED(pRes, plist[i]); } break; } @@ -170,8 +189,8 @@ static int32_t doAddNumericVector(SColumnInfoData* pCol, int32_t type, SInputCol } numOfElems += 1; - pRes->count += 1; - CHECK_OVERFLOW_SUM_SIGNED(pRes, plist[i]) + AVG_RES_INC_COUNT(pRes, TSDB_DATA_TYPE_INT, 1); + CHECK_OVERFLOW_SUM_SIGNED(pRes, plist[i]); } break; @@ -185,8 +204,8 @@ static int32_t doAddNumericVector(SColumnInfoData* pCol, int32_t type, SInputCol } numOfElems += 1; - pRes->count += 1; - CHECK_OVERFLOW_SUM_SIGNED(pRes, plist[i]) + AVG_RES_INC_COUNT(pRes, TSDB_DATA_TYPE_BIGINT, 1); + CHECK_OVERFLOW_SUM_SIGNED(pRes, plist[i]); } break; } @@ -199,8 +218,8 @@ static int32_t doAddNumericVector(SColumnInfoData* pCol, int32_t type, SInputCol } numOfElems += 1; - pRes->count += 1; - CHECK_OVERFLOW_SUM_UNSIGNED(pRes, plist[i]) + AVG_RES_INC_COUNT(pRes, TSDB_DATA_TYPE_UTINYINT, 1); + CHECK_OVERFLOW_SUM_UNSIGNED(pRes, plist[i]); } break; @@ -214,8 +233,8 @@ static int32_t doAddNumericVector(SColumnInfoData* pCol, int32_t type, SInputCol } numOfElems += 1; - pRes->count += 1; - CHECK_OVERFLOW_SUM_UNSIGNED(pRes, plist[i]) + AVG_RES_INC_COUNT(pRes, TSDB_DATA_TYPE_USMALLINT, 1); + CHECK_OVERFLOW_SUM_UNSIGNED(pRes, plist[i]); } break; } @@ -228,8 +247,8 @@ static int32_t doAddNumericVector(SColumnInfoData* pCol, int32_t type, SInputCol } numOfElems += 1; - pRes->count += 1; - CHECK_OVERFLOW_SUM_UNSIGNED(pRes, plist[i]) + AVG_RES_INC_COUNT(pRes, TSDB_DATA_TYPE_UINT, 1); + CHECK_OVERFLOW_SUM_UNSIGNED(pRes, plist[i]); } break; @@ -243,8 +262,8 @@ static int32_t doAddNumericVector(SColumnInfoData* pCol, int32_t type, SInputCol } numOfElems += 1; - pRes->count += 1; - CHECK_OVERFLOW_SUM_UNSIGNED(pRes, plist[i]) + AVG_RES_INC_COUNT(pRes, TSDB_DATA_TYPE_UBIGINT, 1); + CHECK_OVERFLOW_SUM_UNSIGNED(pRes, plist[i]); } break; @@ -258,8 +277,8 @@ static int32_t doAddNumericVector(SColumnInfoData* pCol, int32_t type, SInputCol } numOfElems += 1; - pRes->count += 1; - pRes->sum.dsum += plist[i]; + AVG_RES_INC_COUNT(pRes, TSDB_DATA_TYPE_FLOAT, 1); + SUM_RES_INC_DSUM(&AVG_RES_GET_SUM(pRes), plist[i]); } break; } @@ -272,12 +291,24 @@ static int32_t doAddNumericVector(SColumnInfoData* pCol, int32_t type, SInputCol } numOfElems += 1; - pRes->count += 1; - pRes->sum.dsum += plist[i]; + AVG_RES_INC_COUNT(pRes, TSDB_DATA_TYPE_DOUBLE, 1); + SUM_RES_INC_DSUM(&AVG_RES_GET_SUM(pRes), plist[i]); } break; } + case TSDB_DATA_TYPE_DECIMAL64: + case TSDB_DATA_TYPE_DECIMAL: { + const char* pDec = pCol->pData; + for (int32_t i = start; i < numOfRows + start; ++i) { + if (colDataIsNull_f(pCol->nullbitmap, i)) { + continue; + } + numOfElems += 1; + AVG_RES_INC_COUNT(pRes, type, 1); + SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pRes), (const void*)(pDec + i * tDataTypes[type].bytes), type); + } + } break; default: break; } @@ -292,8 +323,9 @@ int32_t avgFunction(SqlFunctionCtx* pCtx) { SInputColumnInfoData* pInput = &pCtx->input; SColumnDataAgg* pAgg = pInput->pColumnDataAgg[0]; int32_t type = pInput->pData[0]->info.type; + pCtx->inputType = type; - SAvgRes* pAvgRes = GET_ROWCELL_INTERBUF(GET_RES_INFO(pCtx)); + void* pAvgRes = GET_ROWCELL_INTERBUF(GET_RES_INFO(pCtx)); // computing based on the true data block SColumnInfoData* pCol = pInput->pData[0]; @@ -305,13 +337,14 @@ int32_t avgFunction(SqlFunctionCtx* pCtx) { goto _over; } - pAvgRes->type = type; + AVG_RES_SET_TYPE(pAvgRes, pCtx->inputType, type); + if (IS_DECIMAL_TYPE(type)) AVG_RES_SET_INPUT_SCALE(pAvgRes, pInput->pData[0]->info.scale); if (pInput->colDataSMAIsSet) { // try to use SMA if available numOfElem = calculateAvgBySMAInfo(pAvgRes, numOfRows, type, pAgg); } else if (!pCol->hasNull) { // try to employ the simd instructions to speed up the loop numOfElem = pInput->numOfRows; - pAvgRes->count += pInput->numOfRows; + AVG_RES_INC_COUNT(pAvgRes, pCtx->inputType, pInput->numOfRows); switch(type) { case TSDB_DATA_TYPE_UTINYINT: @@ -320,9 +353,9 @@ int32_t avgFunction(SqlFunctionCtx* pCtx) { for (int32_t i = pInput->startRowIndex; i < pInput->numOfRows + pInput->startRowIndex; ++i) { if (type == TSDB_DATA_TYPE_TINYINT) { - CHECK_OVERFLOW_SUM_SIGNED(pAvgRes, plist[i]) + CHECK_OVERFLOW_SUM_SIGNED(pAvgRes, plist[i]); } else { - CHECK_OVERFLOW_SUM_UNSIGNED(pAvgRes, (uint8_t)plist[i]) + CHECK_OVERFLOW_SUM_UNSIGNED(pAvgRes, (uint8_t)plist[i]); } } break; @@ -334,9 +367,9 @@ int32_t avgFunction(SqlFunctionCtx* pCtx) { for (int32_t i = pInput->startRowIndex; i < pInput->numOfRows + pInput->startRowIndex; ++i) { if (type == TSDB_DATA_TYPE_SMALLINT) { - CHECK_OVERFLOW_SUM_SIGNED(pAvgRes, plist[i]) + CHECK_OVERFLOW_SUM_SIGNED(pAvgRes, plist[i]); } else { - CHECK_OVERFLOW_SUM_UNSIGNED(pAvgRes, (uint16_t)plist[i]) + CHECK_OVERFLOW_SUM_UNSIGNED(pAvgRes, (uint16_t)plist[i]); } } break; @@ -348,9 +381,9 @@ int32_t avgFunction(SqlFunctionCtx* pCtx) { for (int32_t i = pInput->startRowIndex; i < pInput->numOfRows + pInput->startRowIndex; ++i) { if (type == TSDB_DATA_TYPE_INT) { - CHECK_OVERFLOW_SUM_SIGNED(pAvgRes, plist[i]) + CHECK_OVERFLOW_SUM_SIGNED(pAvgRes, plist[i]); } else { - CHECK_OVERFLOW_SUM_UNSIGNED(pAvgRes, (uint32_t)plist[i]) + CHECK_OVERFLOW_SUM_UNSIGNED(pAvgRes, (uint32_t)plist[i]); } } break; @@ -362,9 +395,9 @@ int32_t avgFunction(SqlFunctionCtx* pCtx) { for (int32_t i = pInput->startRowIndex; i < pInput->numOfRows + pInput->startRowIndex; ++i) { if (type == TSDB_DATA_TYPE_BIGINT) { - CHECK_OVERFLOW_SUM_SIGNED(pAvgRes, plist[i]) + CHECK_OVERFLOW_SUM_SIGNED(pAvgRes, plist[i]); } else { - CHECK_OVERFLOW_SUM_UNSIGNED(pAvgRes, (uint64_t)plist[i]) + CHECK_OVERFLOW_SUM_UNSIGNED(pAvgRes, (uint64_t)plist[i]); } } break; @@ -374,7 +407,7 @@ int32_t avgFunction(SqlFunctionCtx* pCtx) { const float* plist = (const float*) pCol->pData; for (int32_t i = pInput->startRowIndex; i < pInput->numOfRows + pInput->startRowIndex; ++i) { - pAvgRes->sum.dsum += plist[i]; + SUM_RES_INC_DSUM(&AVG_RES_GET_SUM(pAvgRes), plist[i]); } break; } @@ -382,10 +415,24 @@ int32_t avgFunction(SqlFunctionCtx* pCtx) { const double* plist = (const double*)pCol->pData; for (int32_t i = pInput->startRowIndex; i < pInput->numOfRows + pInput->startRowIndex; ++i) { - pAvgRes->sum.dsum += plist[i]; + SUM_RES_INC_DSUM(&AVG_RES_GET_SUM(pAvgRes), plist[i]); } break; } + case TSDB_DATA_TYPE_DECIMAL: + case TSDB_DATA_TYPE_DECIMAL64: { + const char* pDec = pCol->pData; + // TODO wjm check for overflow + for (int32_t i = pInput->startRowIndex; i < pInput->numOfRows + pInput->startRowIndex; ++i) { + if (type == TSDB_DATA_TYPE_DECIMAL64) { + SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pAvgRes), (const void*)(pDec + i * tDataTypes[type].bytes), + TSDB_DATA_TYPE_DECIMAL64); + } else { + SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pAvgRes), (const void*)(pDec + i * tDataTypes[type].bytes), + TSDB_DATA_TYPE_DECIMAL); + } + } + } break; default: return TSDB_CODE_FUNC_FUNTION_PARA_TYPE; } @@ -399,23 +446,30 @@ _over: return TSDB_CODE_SUCCESS; } -static void avgTransferInfo(SAvgRes* pInput, SAvgRes* pOutput) { - if (IS_NULL_TYPE(pInput->type)) { +static void avgTransferInfo(SqlFunctionCtx* pCtx, void* pInput, void* pOutput) { + int32_t inputDT = pCtx->pExpr->pExpr->_function.pFunctNode->srcFuncInputType.type; + int32_t type = AVG_RES_GET_TYPE(pInput, inputDT); + pCtx->inputType = type; + if (IS_NULL_TYPE(type)) { return; } - pOutput->type = pInput->type; - if (IS_SIGNED_NUMERIC_TYPE(pOutput->type)) { - bool overflow = pInput->sum.overflow; - CHECK_OVERFLOW_SUM_SIGNED_BIG(pOutput, (overflow ? pInput->sum.dsum : pInput->sum.isum), overflow); - } else if (IS_UNSIGNED_NUMERIC_TYPE(pOutput->type)) { - bool overflow = pInput->sum.overflow; - CHECK_OVERFLOW_SUM_UNSIGNED_BIG(pOutput, (overflow ? pInput->sum.dsum : pInput->sum.usum), overflow); + + AVG_RES_SET_TYPE(pOutput, inputDT, type); + if (IS_SIGNED_NUMERIC_TYPE(type)) { + bool overflow = AVG_RES_GET_SUM_OVERFLOW(pInput, false, 0); + CHECK_OVERFLOW_SUM_SIGNED_BIG(pOutput, (overflow ? SUM_RES_GET_DSUM(&AVG_RES_GET_SUM(pInput)) : SUM_RES_GET_ISUM(&AVG_RES_GET_SUM(pInput))), overflow); + } else if (IS_UNSIGNED_NUMERIC_TYPE(type)) { + bool overflow = AVG_RES_GET_SUM_OVERFLOW(pInput, false, 0); + CHECK_OVERFLOW_SUM_UNSIGNED_BIG(pOutput, (overflow ? SUM_RES_GET_DSUM(&AVG_RES_GET_SUM(pInput)) : SUM_RES_GET_USUM(&AVG_RES_GET_SUM(pInput))), overflow); + } else if (IS_DECIMAL_TYPE(type)) { + AVG_RES_SET_INPUT_SCALE(pOutput, AVG_RES_GET_INPUT_SCALE(pInput)); + SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pOutput), &AVG_RES_GET_DECIMAL_SUM(pInput), TSDB_DATA_TYPE_DECIMAL); } else { - pOutput->sum.dsum += pInput->sum.dsum; + SUM_RES_INC_DSUM(&AVG_RES_GET_SUM(pOutput), SUM_RES_GET_DSUM(&AVG_RES_GET_SUM(pInput))); } - pOutput->count += pInput->count; + AVG_RES_INC_COUNT(pOutput, type, AVG_RES_GET_COUNT(pInput, true, type)); } int32_t avgFunctionMerge(SqlFunctionCtx* pCtx) { @@ -431,15 +485,15 @@ int32_t avgFunctionMerge(SqlFunctionCtx* pCtx) { return TSDB_CODE_FUNC_FUNTION_PARA_TYPE; } - SAvgRes* pInfo = GET_ROWCELL_INTERBUF(GET_RES_INFO(pCtx)); + void* pInfo = GET_ROWCELL_INTERBUF(GET_RES_INFO(pCtx)); int32_t start = pInput->startRowIndex; for (int32_t i = start; i < start + pInput->numOfRows; ++i) { if(colDataIsNull_s(pCol, i)) continue; char* data = colDataGetData(pCol, i); - SAvgRes* pInputInfo = (SAvgRes*)varDataVal(data); - avgTransferInfo(pInputInfo, pInfo); + void* pInputInfo = varDataVal(data); + avgTransferInfo(pCtx, pInputInfo, pInfo); } SET_VAL(GET_RES_INFO(pCtx), 1, 1); @@ -521,9 +575,9 @@ int32_t avgCombine(SqlFunctionCtx* pDestCtx, SqlFunctionCtx* pSourceCtx) { int16_t type = pDBuf->type == TSDB_DATA_TYPE_NULL ? pSBuf->type : pDBuf->type; if (IS_SIGNED_NUMERIC_TYPE(type)) { - CHECK_OVERFLOW_SUM_SIGNED(pDBuf, pSBuf->sum.isum) + CHECK_OVERFLOW_SUM_SIGNED(pDBuf, pSBuf->sum.isum); } else if (IS_UNSIGNED_NUMERIC_TYPE(type)) { - CHECK_OVERFLOW_SUM_UNSIGNED(pDBuf, pSBuf->sum.usum) + CHECK_OVERFLOW_SUM_UNSIGNED(pDBuf, pSBuf->sum.usum); } else { pDBuf->sum.dsum += pSBuf->sum.dsum; } @@ -535,23 +589,47 @@ int32_t avgCombine(SqlFunctionCtx* pDestCtx, SqlFunctionCtx* pSourceCtx) { int32_t avgFinalize(SqlFunctionCtx* pCtx, SSDataBlock* pBlock) { SResultRowEntryInfo* pEntryInfo = GET_RES_INFO(pCtx); - SAvgRes* pRes = GET_ROWCELL_INTERBUF(pEntryInfo); - int32_t type = pRes->type; + void* pRes = GET_ROWCELL_INTERBUF(pEntryInfo); + int32_t type = AVG_RES_GET_TYPE(pRes, pCtx->inputType); + int64_t count = AVG_RES_GET_COUNT(pRes, true, type); - if (pRes->count > 0) { - if(pRes->sum.overflow) { - // overflow flag set , use dsum - pRes->result = pRes->sum.dsum / ((double)pRes->count); + if (AVG_RES_GET_COUNT(pRes, true, pCtx->inputType) > 0) { + + if(AVG_RES_GET_SUM_OVERFLOW(pRes, true, pCtx->inputType)) { + // overflow flag set , use dsum TODO wjm check deicmal overflow and return error + AVG_RES_GET_AVG(pRes) = SUM_RES_GET_DSUM(&AVG_RES_GET_SUM(pRes)) / ((double)AVG_RES_GET_COUNT(pRes, false, 0)); }else if (IS_SIGNED_NUMERIC_TYPE(type)) { - pRes->result = pRes->sum.isum / ((double)pRes->count); + AVG_RES_GET_AVG(pRes) = SUM_RES_GET_ISUM(&AVG_RES_GET_SUM(pRes)) / ((double)AVG_RES_GET_COUNT(pRes, false, 0)); } else if (IS_UNSIGNED_NUMERIC_TYPE(type)) { - pRes->result = pRes->sum.usum / ((double)pRes->count); + AVG_RES_GET_AVG(pRes) = SUM_RES_GET_USUM(&AVG_RES_GET_SUM(pRes)) / ((double)AVG_RES_GET_COUNT(pRes, false, 0)); + } else if (IS_DECIMAL_TYPE(type)) { + int32_t slotId = pCtx->pExpr->base.resSchema.slotId; + SColumnInfoData* pCol = taosArrayGet(pBlock->pDataBlock, slotId); + SDataType sumDt = {.type = TSDB_DATA_TYPE_DECIMAL, + .bytes = tDataTypes[TSDB_DATA_TYPE_DECIMAL].bytes, + .precision = pCol->info.precision, + .scale = AVG_RES_GET_INPUT_SCALE(pRes)}; + SDataType countDt = { + .type = TSDB_DATA_TYPE_BIGINT, .bytes = tDataTypes[TSDB_DATA_TYPE_BIGINT].bytes, .precision = 0, .scale = 0}; + SDataType avgDt = {.type = TSDB_DATA_TYPE_DECIMAL, + .bytes = tDataTypes[TSDB_DATA_TYPE_DECIMAL].bytes, + .precision = pCol->info.precision, + .scale = pCol->info.scale}; + int64_t count = AVG_RES_GET_COUNT(pRes, true, type); + int32_t code = + decimalOp(OP_TYPE_DIV, &sumDt, &countDt, &avgDt, &SUM_RES_GET_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pRes)), + &count, &AVG_RES_GET_DECIMAL_AVG(pRes)); + if (code != TSDB_CODE_SUCCESS) { + return code; + } } else { - pRes->result = pRes->sum.dsum / ((double)pRes->count); + AVG_RES_GET_AVG(pRes) = SUM_RES_GET_DSUM(&AVG_RES_GET_SUM(pRes)) / ((double)AVG_RES_GET_COUNT(pRes, false, 0)); } } - - if (pRes->count == 0 || isinf(pRes->result) || isnan(pRes->result)) { + if (!IS_DECIMAL_TYPE(pCtx->inputType)) { + if (isinf(AVG_RES_GET_AVG(pRes)) || isnan(AVG_RES_GET_AVG(pRes))) pEntryInfo->numOfRes = 0; + } + if (AVG_RES_GET_COUNT(pRes, true, pCtx->inputType) == 0) { pEntryInfo->numOfRes = 0; } else { pEntryInfo->numOfRes = 1; @@ -562,8 +640,8 @@ int32_t avgFinalize(SqlFunctionCtx* pCtx, SSDataBlock* pBlock) { int32_t avgPartialFinalize(SqlFunctionCtx* pCtx, SSDataBlock* pBlock) { SResultRowEntryInfo* pResInfo = GET_RES_INFO(pCtx); - SAvgRes* pInfo = GET_ROWCELL_INTERBUF(GET_RES_INFO(pCtx)); - int32_t resultBytes = getAvgInfoSize(); + void* pInfo = GET_ROWCELL_INTERBUF(GET_RES_INFO(pCtx)); + int32_t resultBytes = AVG_RES_GET_SIZE(pCtx->inputType); char* res = taosMemoryCalloc(resultBytes + VARSTR_HEADER_SIZE, sizeof(char)); int32_t code = TSDB_CODE_SUCCESS; if (NULL == res) { diff --git a/source/libs/function/src/functionMgt.c b/source/libs/function/src/functionMgt.c index 03935bf1a0..22bedf7869 100644 --- a/source/libs/function/src/functionMgt.c +++ b/source/libs/function/src/functionMgt.c @@ -432,6 +432,7 @@ int32_t createFunctionWithSrcFunc(const char* pName, const SFunctionNode* pSrcFu (*ppFunc)->hasPk = pSrcFunc->hasPk; (*ppFunc)->pkBytes = pSrcFunc->pkBytes; + (*ppFunc)->pSrcFuncRef = pSrcFunc; (void)snprintf((*ppFunc)->functionName, sizeof((*ppFunc)->functionName), "%s", pName); (*ppFunc)->pParameterList = pParameterList; diff --git a/source/libs/nodes/src/nodesCloneFuncs.c b/source/libs/nodes/src/nodesCloneFuncs.c index 64dc073e42..e94a3b7f15 100644 --- a/source/libs/nodes/src/nodesCloneFuncs.c +++ b/source/libs/nodes/src/nodesCloneFuncs.c @@ -239,6 +239,7 @@ static int32_t functionNodeCopy(const SFunctionNode* pSrc, SFunctionNode* pDst) COPY_SCALAR_FIELD(pkBytes); COPY_SCALAR_FIELD(hasOriginalFunc); COPY_SCALAR_FIELD(originalFuncId); + COPY_OBJECT_FIELD(srcFuncInputType, sizeof(SDataType)); return TSDB_CODE_SUCCESS; } diff --git a/source/libs/nodes/src/nodesCodeFuncs.c b/source/libs/nodes/src/nodesCodeFuncs.c index 012ef41699..cfa9b58f90 100644 --- a/source/libs/nodes/src/nodesCodeFuncs.c +++ b/source/libs/nodes/src/nodesCodeFuncs.c @@ -4573,6 +4573,7 @@ static const char* jkFunctionPkBytes = "PkBytes"; static const char* jkFunctionIsMergeFunc = "IsMergeFunc"; static const char* jkFunctionMergeFuncOf = "MergeFuncOf"; static const char* jkFunctionTrimType = "TrimType"; +static const char* jkFunctionSrcFuncInputDT = "SrcFuncInputDataType"; static int32_t functionNodeToJson(const void* pObj, SJson* pJson) { const SFunctionNode* pNode = (const SFunctionNode*)pObj; @@ -4608,6 +4609,9 @@ static int32_t functionNodeToJson(const void* pObj, SJson* pJson) { if (TSDB_CODE_SUCCESS == code) { code = tjsonAddIntegerToObject(pJson, jkFunctionTrimType, pNode->trimType); } + if (TSDB_CODE_SUCCESS == code) { + code = dataTypeToJson(&pNode->srcFuncInputType, pJson); + } return code; } @@ -4645,6 +4649,9 @@ static int32_t jsonToFunctionNode(const SJson* pJson, void* pObj) { if (TSDB_CODE_SUCCESS == code) { tjsonGetNumberValue(pJson, jkFunctionTrimType, pNode->trimType, code); } + if (TSDB_CODE_SUCCESS == code) { + code = jsonToDataType(pJson, &pNode->srcFuncInputType); + } return code; } diff --git a/source/libs/nodes/src/nodesMsgFuncs.c b/source/libs/nodes/src/nodesMsgFuncs.c index b9aaa916ee..0ad56a8d0f 100644 --- a/source/libs/nodes/src/nodesMsgFuncs.c +++ b/source/libs/nodes/src/nodesMsgFuncs.c @@ -1154,6 +1154,7 @@ enum { FUNCTION_CODE_IS_MERGE_FUNC, FUNCTION_CODE_MERGE_FUNC_OF, FUNCTION_CODE_TRIM_TYPE, + FUNCTION_SRC_FUNC_INPUT_TYPE, }; static int32_t functionNodeToMsg(const void* pObj, STlvEncoder* pEncoder) { @@ -1190,6 +1191,9 @@ static int32_t functionNodeToMsg(const void* pObj, STlvEncoder* pEncoder) { if (TSDB_CODE_SUCCESS == code) { code = tlvEncodeEnum(pEncoder, FUNCTION_CODE_TRIM_TYPE, pNode->trimType); } + if (TSDB_CODE_SUCCESS == code) { + code = tlvEncodeObj(pEncoder, FUNCTION_SRC_FUNC_INPUT_TYPE, dataTypeInlineToMsg, &pNode->srcFuncInputType); + } return code; } @@ -1234,6 +1238,8 @@ static int32_t msgToFunctionNode(STlvDecoder* pDecoder, void* pObj) { case FUNCTION_CODE_TRIM_TYPE: code = tlvDecodeEnum(pTlv, &pNode->trimType, sizeof(pNode->trimType)); break; + case FUNCTION_SRC_FUNC_INPUT_TYPE: + code = tlvDecodeObjFromTlv(pTlv, msgToDataTypeInline, &pNode->srcFuncInputType); default: break; }