From 4a8bbc36bd9d383741e8df30208ba09a294f38eb Mon Sep 17 00:00:00 2001 From: Jing Sima Date: Tue, 3 Sep 2024 17:05:40 +0800 Subject: [PATCH] feat:[TS-4893] refactor of rand function. --- include/libs/function/functionMgt.h | 1 - include/libs/nodes/querynodes.h | 1 + source/libs/function/inc/functionMgtInt.h | 1 - source/libs/function/src/builtins.c | 24 +++++++++++++++++++++- source/libs/function/src/functionMgt.c | 2 -- source/libs/parser/inc/parTranslater.h | 1 + source/libs/parser/src/parTranslater.c | 3 +++ source/libs/planner/src/planOptimizer.c | 7 +++++-- source/libs/scalar/src/scalar.c | 16 ++++----------- source/libs/scalar/src/sclfunc.c | 7 ++++--- tests/army/query/function/test_function.py | 22 ++++++++++++++++++++ 11 files changed, 63 insertions(+), 22 deletions(-) diff --git a/include/libs/function/functionMgt.h b/include/libs/function/functionMgt.h index 5c77b8141d..233e201ac3 100644 --- a/include/libs/function/functionMgt.h +++ b/include/libs/function/functionMgt.h @@ -280,7 +280,6 @@ bool fmIsSkipScanCheckFunc(int32_t funcId); bool fmIsPrimaryKeyFunc(int32_t funcId); bool fmIsProcessByRowFunc(int32_t funcId); bool fmisSelectGroupConstValueFunc(int32_t funcId); -bool fmIsCalcEachRowFunc(int32_t funcId); void getLastCacheDataType(SDataType* pType, int32_t pkBytes); int32_t createFunction(const char* pName, SNodeList* pParameterList, SFunctionNode** pFunc); diff --git a/include/libs/nodes/querynodes.h b/include/libs/nodes/querynodes.h index e5c29dc16a..1b02203c07 100644 --- a/include/libs/nodes/querynodes.h +++ b/include/libs/nodes/querynodes.h @@ -192,6 +192,7 @@ typedef struct SFunctionNode { int32_t originalFuncId; ETrimType trimType; bool hasSMA; + bool dual; // whether select stmt without from stmt, true for without. } SFunctionNode; typedef struct STableNode { diff --git a/source/libs/function/inc/functionMgtInt.h b/source/libs/function/inc/functionMgtInt.h index 76f99ae517..a50562d78d 100644 --- a/source/libs/function/inc/functionMgtInt.h +++ b/source/libs/function/inc/functionMgtInt.h @@ -58,7 +58,6 @@ extern "C" { #define FUNC_MGT_TSMA_FUNC FUNC_MGT_FUNC_CLASSIFICATION_MASK(29) #define FUNC_MGT_COUNT_LIKE_FUNC FUNC_MGT_FUNC_CLASSIFICATION_MASK(30) // funcs that should also return 0 when no rows found #define FUNC_MGT_PROCESS_BY_ROW FUNC_MGT_FUNC_CLASSIFICATION_MASK(31) -#define FUNC_MGT_CALC_EACH_ROW_FUNC FUNC_MGT_FUNC_CLASSIFICATION_MASK(32) #define FUNC_MGT_TEST_MASK(val, mask) (((val) & (mask)) != 0) diff --git a/source/libs/function/src/builtins.c b/source/libs/function/src/builtins.c index 87e8f9473e..3ec63a6d34 100644 --- a/source/libs/function/src/builtins.c +++ b/source/libs/function/src/builtins.c @@ -271,6 +271,21 @@ static int32_t addUint8Param(SNodeList** pList, uint8_t param) { return TSDB_CODE_SUCCESS; } +static int32_t addPseudoParam(SNodeList** pList) { + SNode *pseudoNode = NULL; + int32_t code = nodesMakeNode(QUERY_NODE_LEFT_VALUE, &pseudoNode); + if (pseudoNode == NULL) { + return code; + } + + code = nodesListMakeAppend(pList, pseudoNode); + if (TSDB_CODE_SUCCESS != code) { + nodesDestroyNode(pseudoNode); + return code; + } + return TSDB_CODE_SUCCESS; +} + static SDataType* getSDataTypeFromNode(SNode* pNode) { if (pNode == NULL) return NULL; if (nodesIsExprNode(pNode)) { @@ -610,6 +625,13 @@ static int32_t translateRand(SFunctionNode* pFunc, char* pErrBuf, int32_t len) { } } + if (!pFunc->dual) { + int32_t code = addPseudoParam(&pFunc->pParameterList); + if (code != TSDB_CODE_SUCCESS) { + return code; + } + } + pFunc->node.resType = (SDataType){.bytes = tDataTypes[TSDB_DATA_TYPE_DOUBLE].bytes, .type = TSDB_DATA_TYPE_DOUBLE}; return TSDB_CODE_SUCCESS; @@ -4768,7 +4790,7 @@ const SBuiltinFuncDefinition funcMgtBuiltins[] = { { .name = "rand", .type = FUNCTION_TYPE_RAND, - .classification = FUNC_MGT_SCALAR_FUNC | FUNC_MGT_CALC_EACH_ROW_FUNC, + .classification = FUNC_MGT_SCALAR_FUNC, .translateFunc = translateRand, .getEnvFunc = NULL, .initFunc = NULL, diff --git a/source/libs/function/src/functionMgt.c b/source/libs/function/src/functionMgt.c index e75ca9cea8..1a927a1576 100644 --- a/source/libs/function/src/functionMgt.c +++ b/source/libs/function/src/functionMgt.c @@ -287,8 +287,6 @@ bool fmIsProcessByRowFunc(int32_t funcId) { return isSpecificClassifyFunc(funcId bool fmIsIgnoreNullFunc(int32_t funcId) { return isSpecificClassifyFunc(funcId, FUNC_MGT_IGNORE_NULL_FUNC); } -bool fmIsCalcEachRowFunc(int32_t funcId) { return isSpecificClassifyFunc(funcId, FUNC_MGT_CALC_EACH_ROW_FUNC); } - void fmFuncMgtDestroy() { void* m = gFunMgtService.pFuncNameHashTable; if (m != NULL && atomic_val_compare_exchange_ptr((void**)&gFunMgtService.pFuncNameHashTable, m, 0) == m) { diff --git a/source/libs/parser/inc/parTranslater.h b/source/libs/parser/inc/parTranslater.h index 99e303e51c..93d6645e12 100644 --- a/source/libs/parser/inc/parTranslater.h +++ b/source/libs/parser/inc/parTranslater.h @@ -45,6 +45,7 @@ typedef struct STranslateContext { bool showRewrite; SNode* pPrevRoot; SNode* pPostRoot; + bool dual; // whether select stmt without from stmt, true for without. } STranslateContext; int32_t biRewriteToTbnameFunc(STranslateContext* pCxt, SNode** ppNode, bool* pRet); diff --git a/source/libs/parser/src/parTranslater.c b/source/libs/parser/src/parTranslater.c index 62943cb6d5..a0b8c13e58 100755 --- a/source/libs/parser/src/parTranslater.c +++ b/source/libs/parser/src/parTranslater.c @@ -2307,6 +2307,7 @@ static bool hasInvalidFuncNesting(SNodeList* pParameterList) { static int32_t getFuncInfo(STranslateContext* pCxt, SFunctionNode* pFunc) { // the time precision of the function execution environment + pFunc->dual = pCxt->dual; pFunc->node.resType.precision = getPrecisionFromCurrStmt(pCxt->pCurrStmt, TSDB_TIME_PRECISION_MILLI); int32_t code = fmGetFuncInfo(pFunc, pCxt->msgBuf.buf, pCxt->msgBuf.len); if (TSDB_CODE_FUNC_NOT_BUILTIN_FUNTION == code) { @@ -6696,11 +6697,13 @@ static int32_t replaceOrderByAliasForSelect(STranslateContext* pCxt, SSelectStmt static int32_t translateSelectWithoutFrom(STranslateContext* pCxt, SSelectStmt* pSelect) { pCxt->pCurrStmt = (SNode*)pSelect; pCxt->currClause = SQL_CLAUSE_SELECT; + pCxt->dual = true; return translateExprList(pCxt, pSelect->pProjectionList); } static int32_t translateSelectFrom(STranslateContext* pCxt, SSelectStmt* pSelect) { pCxt->pCurrStmt = (SNode*)pSelect; + pCxt->dual = false; int32_t code = translateFrom(pCxt, &pSelect->pFromTable); if (TSDB_CODE_SUCCESS == code) { pSelect->precision = ((STableNode*)pSelect->pFromTable)->precision; diff --git a/source/libs/planner/src/planOptimizer.c b/source/libs/planner/src/planOptimizer.c index 820d2f0d26..988440227b 100644 --- a/source/libs/planner/src/planOptimizer.c +++ b/source/libs/planner/src/planOptimizer.c @@ -294,6 +294,9 @@ static bool scanPathOptIsSpecifiedFuncType(const SFunctionNode* pFunc, bool (*ty return true; } +static bool isMinMaxFunction(int32_t funcType) { + return funcType == FUNCTION_TYPE_MIN || funcType == FUNCTION_TYPE_MAX; +} static int32_t scanPathOptGetRelatedFuncs(SScanLogicNode* pScan, SNodeList** pSdrFuncs, SNodeList** pDsoFuncs) { SNodeList* pAllFuncs = scanPathOptGetAllFuncs(pScan->node.pParent); SNodeList* pTmpSdrFuncs = NULL; @@ -303,8 +306,8 @@ static int32_t scanPathOptGetRelatedFuncs(SScanLogicNode* pScan, SNodeList** pSd FOREACH(pNode, pAllFuncs) { SFunctionNode* pFunc = (SFunctionNode*)pNode; int32_t code = TSDB_CODE_SUCCESS; - if (scanPathOptIsSpecifiedFuncType(pFunc, fmIsSpecialDataRequiredFunc) && - ((pFunc->funcType == FUNCTION_TYPE_MIN || pFunc->funcType == FUNCTION_TYPE_MAX) && pFunc->hasSMA)) { + if ((!isMinMaxFunction(pFunc->funcType) && scanPathOptIsSpecifiedFuncType(pFunc, fmIsSpecialDataRequiredFunc)) || + (isMinMaxFunction(pFunc->funcType) && pFunc->hasSMA)) { SNode* pNew = NULL; code = nodesCloneNode(pNode, &pNew); if (TSDB_CODE_SUCCESS == code) { diff --git a/source/libs/scalar/src/scalar.c b/source/libs/scalar/src/scalar.c index 4b173b85e9..b9a95a3216 100644 --- a/source/libs/scalar/src/scalar.c +++ b/source/libs/scalar/src/scalar.c @@ -464,7 +464,7 @@ int32_t sclInitParam(SNode *node, SScalarParam *param, SScalarCtx *ctx, int32_t } int32_t sclInitParamList(SScalarParam **pParams, SNodeList *pParamList, SScalarCtx *ctx, int32_t *paramNum, - int32_t *rowNum, bool needCalcForEachRow) { + int32_t *rowNum) { int32_t code = 0; if (NULL == pParamList) { if (ctx->pBlockList) { @@ -505,13 +505,6 @@ int32_t sclInitParamList(SScalarParam **pParams, SNodeList *pParamList, SScalarC } else { FOREACH(tnode, pParamList) { SCL_ERR_JRET(sclInitParam(tnode, ¶mList[i], ctx, rowNum)); - if (needCalcForEachRow) { - SSDataBlock *pBlock = taosArrayGetP(ctx->pBlockList, 0); - if (NULL == pBlock) { - SCL_ERR_RET(TSDB_CODE_OUT_OF_RANGE); - } - paramList[i].numOfRows = pBlock->info.rows; - } ++i; } } @@ -759,7 +752,7 @@ int32_t sclExecFunction(SFunctionNode *node, SScalarCtx *ctx, SScalarParam *outp int32_t rowNum = 0; int32_t paramNum = 0; int32_t code = 0; - SCL_ERR_RET(sclInitParamList(¶ms, node->pParameterList, ctx, ¶mNum, &rowNum, fmIsCalcEachRowFunc(node->funcId))); + SCL_ERR_RET(sclInitParamList(¶ms, node->pParameterList, ctx, ¶mNum, &rowNum)); if (fmIsUserDefinedFunc(node->funcId)) { code = callUdfScalarFunc(node->functionName, params, paramNum, output); @@ -818,7 +811,7 @@ int32_t sclExecLogic(SLogicConditionNode *node, SScalarCtx *ctx, SScalarParam *o int32_t rowNum = 0; int32_t paramNum = 0; int32_t code = 0; - SCL_ERR_RET(sclInitParamList(¶ms, node->pParameterList, ctx, ¶mNum, &rowNum, false)); + SCL_ERR_RET(sclInitParamList(¶ms, node->pParameterList, ctx, ¶mNum, &rowNum)); if (NULL == params) { output->numOfRows = 0; return TSDB_CODE_SUCCESS; @@ -1187,8 +1180,7 @@ EDealRes sclRewriteFunction(SNode **pNode, SScalarCtx *ctx) { SFunctionNode *node = (SFunctionNode *)*pNode; SNode *tnode = NULL; if ((!fmIsScalarFunc(node->funcId) && (!ctx->dual)) || - fmIsUserDefinedFunc(node->funcId) || - (fmIsCalcEachRowFunc(node->funcId) && (!ctx->dual))) { + fmIsUserDefinedFunc(node->funcId)) { return DEAL_RES_CONTINUE; } diff --git a/source/libs/scalar/src/sclfunc.c b/source/libs/scalar/src/sclfunc.c index 11587a1961..5fc7c06d57 100644 --- a/source/libs/scalar/src/sclfunc.c +++ b/source/libs/scalar/src/sclfunc.c @@ -2888,16 +2888,17 @@ int32_t floorFunction(SScalarParam *pInput, int32_t inputNum, SScalarParam *pOut } int32_t randFunction(SScalarParam *pInput, int32_t inputNum, SScalarParam *pOutput) { - if (inputNum == 1 && !IS_NULL_TYPE(GET_PARAM_TYPE(&pInput[0]))) { + if (!IS_NULL_TYPE(GET_PARAM_TYPE(&pInput[0]))) { int32_t seed; GET_TYPED_DATA(seed, int32_t, GET_PARAM_TYPE(&pInput[0]), pInput[0].columnData->pData); taosSeedRand(seed); } - for (int32_t i = 0; i < pInput->numOfRows; ++i) { + int32_t numOfRows = inputNum == 1 ? pInput[0].numOfRows : TMAX(pInput[0].numOfRows, pInput[1].numOfRows); + for (int32_t i = 0; i < numOfRows; ++i) { double random_value = (double)(taosRand() % RAND_MAX) / RAND_MAX; colDataSetDouble(pOutput->columnData, i, &random_value); } - pOutput->numOfRows = pInput->numOfRows; + pOutput->numOfRows = numOfRows; return TSDB_CODE_SUCCESS; } diff --git a/tests/army/query/function/test_function.py b/tests/army/query/function/test_function.py index c1a77343c2..4981e93563 100644 --- a/tests/army/query/function/test_function.py +++ b/tests/army/query/function/test_function.py @@ -511,8 +511,29 @@ class TDTestCase(TBase): expectErrInfo="Not supported timzone format") # TS-5340 def test_min(self): self.test_normal_query("min") + + tdSql.query("select min(var1), min(id) from ts_4893.d0;") + tdSql.checkRows(1) + tdSql.checkData(0, 0, 'abc一二三abc一二三abc') + tdSql.checkData(0, 1, 0) def test_max(self): self.test_normal_query("max") + tdSql.query("select max(var1), max(id) from ts_4893.d0;") + tdSql.checkRows(1) + tdSql.checkData(0, 0, '一二三四五六七八九十') + tdSql.checkData(0, 1, 9999) + def test_rand(self): + tdSql.query("select rand();") + tdSql.checkRows(1) + + tdSql.query("select rand(1);") + tdSql.checkRows(1) + + tdSql.query("select rand(1) from ts_4893.meters limit 10;") + tdSql.checkRows(10) + + tdSql.query("select rand(id) from ts_4893.d0 limit 10;") + tdSql.checkRows(10) # run def run(self): tdLog.debug(f"start to excute {__file__}") @@ -530,6 +551,7 @@ class TDTestCase(TBase): self.test_sign() self.test_degrees() self.test_radians() + self.test_rand() # char function self.test_char_length()