diff --git a/include/libs/function/functionMgt.h b/include/libs/function/functionMgt.h index 8888f6ca8e..0b33fe588a 100644 --- a/include/libs/function/functionMgt.h +++ b/include/libs/function/functionMgt.h @@ -171,6 +171,10 @@ bool fmIsRepeatScanFunc(int32_t funcId); bool fmIsUserDefinedFunc(int32_t funcId); bool fmIsDistExecFunc(int32_t funcId); bool fmIsForbidFillFunc(int32_t funcId); +bool fmIsForbidStreamFunc(int32_t funcId); + +bool fmNeedRewrite(int32_t funcId); +int32_t fmRewriteFunc(SNode** pFunc); int32_t fmGetDistMethod(const SFunctionNode* pFunc, SFunctionNode** pPartialFunc, SFunctionNode** pMergeFunc); diff --git a/source/libs/function/inc/builtins.h b/source/libs/function/inc/builtins.h index bc91875006..792b875529 100644 --- a/source/libs/function/inc/builtins.h +++ b/source/libs/function/inc/builtins.h @@ -23,6 +23,7 @@ extern "C" { #include "functionMgtInt.h" typedef int32_t (*FTranslateFunc)(SFunctionNode* pFunc, char* pErrBuf, int32_t len); +typedef int32_t (*FRewriteFunc)(SNode** pFunc); typedef EFuncDataRequired (*FFuncDataRequired)(SFunctionNode* pFunc, STimeWindow* pTimeWindow); typedef struct SBuiltinFuncDefinition { @@ -30,6 +31,7 @@ typedef struct SBuiltinFuncDefinition { EFunctionType type; uint64_t classification; FTranslateFunc translateFunc; + FRewriteFunc rewriteFunc; FFuncDataRequired dataRequiredFunc; FExecGetEnv getEnvFunc; FExecInit initFunc; diff --git a/source/libs/function/inc/functionMgtInt.h b/source/libs/function/inc/functionMgtInt.h index d1af6b6051..6fefcceb87 100644 --- a/source/libs/function/inc/functionMgtInt.h +++ b/source/libs/function/inc/functionMgtInt.h @@ -42,6 +42,7 @@ extern "C" { #define FUNC_MGT_SELECT_FUNC FUNC_MGT_FUNC_CLASSIFICATION_MASK(13) #define FUNC_MGT_REPEAT_SCAN_FUNC FUNC_MGT_FUNC_CLASSIFICATION_MASK(14) #define FUNC_MGT_FORBID_FILL_FUNC FUNC_MGT_FUNC_CLASSIFICATION_MASK(15) +#define FUNC_MGT_FORBID_STREAM_FUNC FUNC_MGT_FUNC_CLASSIFICATION_MASK(16) #define FUNC_MGT_TEST_MASK(val, mask) (((val) & (mask)) != 0) diff --git a/source/libs/function/src/builtins.c b/source/libs/function/src/builtins.c index 06faa53da4..6cad9c855c 100644 --- a/source/libs/function/src/builtins.c +++ b/source/libs/function/src/builtins.c @@ -1333,6 +1333,49 @@ static bool getBlockDistFuncEnv(SFunctionNode* UNUSED_PARAM(pFunc), SFuncExecEnv return true; } +static int32_t rewriteAvg(SNode** pFunc) { + SOperatorNode* pOper = (SOperatorNode*)nodesMakeNode(QUERY_NODE_OPERATOR); + if (NULL == pOper) { + return TSDB_CODE_OUT_OF_MEMORY; + } + + SFunctionNode* pAvg = (SFunctionNode*)*pFunc; + pOper->node.resType = pAvg->node.resType; + strcpy(pOper->node.aliasName, pAvg->node.aliasName); + pOper->opType = OP_TYPE_DIV; + pOper->pLeft = nodesMakeNode(QUERY_NODE_FUNCTION); + pOper->pRight = nodesMakeNode(QUERY_NODE_FUNCTION); + if (NULL == pOper->pLeft || NULL == pOper->pRight) { + nodesDestroyNode((SNode*)pOper); + return TSDB_CODE_OUT_OF_MEMORY; + } + + SFunctionNode* pSum = (SFunctionNode*)pOper->pLeft; + strcpy(pSum->functionName, "sum"); + pSum->pParameterList = nodesCloneList(pAvg->pParameterList); + if (NULL == pSum->pParameterList) { + nodesDestroyNode((SNode*)pOper); + return TSDB_CODE_OUT_OF_MEMORY; + } + char msgBuf[64] = {0}; + int32_t code = fmGetFuncInfo(pSum, msgBuf, sizeof(msgBuf)); + if (TSDB_CODE_SUCCESS == code) { + SFunctionNode* pCount = (SFunctionNode*)pOper->pRight; + strcpy(pCount->functionName, "count"); + TSWAP(pCount->pParameterList, pAvg->pParameterList); + code = fmGetFuncInfo(pCount, msgBuf, sizeof(msgBuf)); + } + + if (TSDB_CODE_SUCCESS == code) { + nodesDestroyNode((SNode*)pAvg); + *pFunc = (SNode*)pOper; + } else { + nodesDestroyNode((SNode*)pOper); + } + + return code; +} + // clang-format off const SBuiltinFuncDefinition funcMgtBuiltins[] = { { @@ -1422,6 +1465,7 @@ const SBuiltinFuncDefinition funcMgtBuiltins[] = { .type = FUNCTION_TYPE_AVG, .classification = FUNC_MGT_AGG_FUNC, .translateFunc = translateInNumOutDou, + .rewriteFunc = rewriteAvg, .getEnvFunc = getAvgFuncEnv, .initFunc = avgFunctionSetup, .processFunc = avgFunction, diff --git a/source/libs/function/src/functionMgt.c b/source/libs/function/src/functionMgt.c index 81cd29fc54..75f4ce52b9 100644 --- a/source/libs/function/src/functionMgt.c +++ b/source/libs/function/src/functionMgt.c @@ -161,6 +161,8 @@ bool fmIsUserDefinedFunc(int32_t funcId) { return funcId > FUNC_UDF_ID_START; } bool fmIsForbidFillFunc(int32_t funcId) { return isSpecificClassifyFunc(funcId, FUNC_MGT_FORBID_FILL_FUNC); } +bool fmIsForbidStreamFunc(int32_t funcId) { return isSpecificClassifyFunc(funcId, FUNC_MGT_FORBID_STREAM_FUNC); } + void fmFuncMgtDestroy() { void* m = gFunMgtService.pFuncNameHashTable; if (m != NULL && atomic_val_compare_exchange_ptr((void**)&gFunMgtService.pFuncNameHashTable, m, 0) == m) { @@ -297,3 +299,12 @@ int32_t fmGetDistMethod(const SFunctionNode* pFunc, SFunctionNode** pPartialFunc return code; } + +bool fmNeedRewrite(int32_t funcId) { + if (fmIsUserDefinedFunc(funcId)) { + return false; + } + return NULL != funcMgtBuiltins[funcId].rewriteFunc; +} + +int32_t fmRewriteFunc(SNode** pFunc) { return funcMgtBuiltins[((SFunctionNode*)*pFunc)->funcId].rewriteFunc(pFunc); } diff --git a/source/libs/parser/src/parTranslater.c b/source/libs/parser/src/parTranslater.c index fcf2f57dec..65e62c0d38 100644 --- a/source/libs/parser/src/parTranslater.c +++ b/source/libs/parser/src/parTranslater.c @@ -1076,29 +1076,32 @@ static void setFuncClassification(SSelectStmt* pSelect, SFunctionNode* pFunc) { } } -static EDealRes translateFunction(STranslateContext* pCxt, SFunctionNode* pFunc) { +static EDealRes translateFunction(STranslateContext* pCxt, SFunctionNode** pFunc) { SNode* pParam = NULL; - FOREACH(pParam, pFunc->pParameterList) { + FOREACH(pParam, (*pFunc)->pParameterList) { if (isMultiResFunc(pParam)) { return generateDealNodeErrMsg(pCxt, TSDB_CODE_PAR_WRONG_VALUE_TYPE, ((SExprNode*)pParam)->aliasName); } } - pCxt->errCode = getFuncInfo(pCxt, pFunc); + pCxt->errCode = getFuncInfo(pCxt, *pFunc); if (TSDB_CODE_SUCCESS == pCxt->errCode) { - pCxt->errCode = translateAggFunc(pCxt, pFunc); + pCxt->errCode = translateAggFunc(pCxt, *pFunc); } if (TSDB_CODE_SUCCESS == pCxt->errCode) { - pCxt->errCode = translateScanPseudoColumnFunc(pCxt, pFunc); + pCxt->errCode = translateScanPseudoColumnFunc(pCxt, *pFunc); } if (TSDB_CODE_SUCCESS == pCxt->errCode) { - pCxt->errCode = translateIndefiniteRowsFunc(pCxt, pFunc); + pCxt->errCode = translateIndefiniteRowsFunc(pCxt, *pFunc); } if (TSDB_CODE_SUCCESS == pCxt->errCode) { - pCxt->errCode = translateForbidFillFunc(pCxt, pFunc); + pCxt->errCode = translateForbidFillFunc(pCxt, *pFunc); } if (TSDB_CODE_SUCCESS == pCxt->errCode) { - setFuncClassification(pCxt->pCurrSelectStmt, pFunc); + setFuncClassification(pCxt->pCurrSelectStmt, *pFunc); + } + if (TSDB_CODE_SUCCESS == pCxt->errCode && fmNeedRewrite((*pFunc)->funcId)) { + pCxt->errCode = fmRewriteFunc((SNode**)pFunc); } return TSDB_CODE_SUCCESS == pCxt->errCode ? DEAL_RES_CONTINUE : DEAL_RES_ERROR; } @@ -1123,7 +1126,7 @@ static EDealRes doTranslateExpr(SNode** pNode, void* pContext) { case QUERY_NODE_OPERATOR: return translateOperator(pCxt, (SOperatorNode**)pNode); case QUERY_NODE_FUNCTION: - return translateFunction(pCxt, (SFunctionNode*)*pNode); + return translateFunction(pCxt, (SFunctionNode**)pNode); case QUERY_NODE_LOGIC_CONDITION: return translateLogicCond(pCxt, (SLogicConditionNode*)*pNode); case QUERY_NODE_TEMP_TABLE: diff --git a/source/libs/parser/test/parInsertTest.cpp b/source/libs/parser/test/parInsertTest.cpp index 3cea494b92..29edca0d40 100644 --- a/source/libs/parser/test/parInsertTest.cpp +++ b/source/libs/parser/test/parInsertTest.cpp @@ -35,6 +35,7 @@ string toString(int32_t code) { return tstrerror(code); } // [...]; class InsertTest : public Test { protected: + InsertTest() : res_(nullptr) {} ~InsertTest() { reset(); } void setDatabase(const string& acctId, const string& db) { diff --git a/source/libs/planner/test/planGroupByTest.cpp b/source/libs/planner/test/planGroupByTest.cpp index 201df2efde..78d0f7b21f 100644 --- a/source/libs/planner/test/planGroupByTest.cpp +++ b/source/libs/planner/test/planGroupByTest.cpp @@ -53,6 +53,14 @@ TEST_F(PlanGroupByTest, aggFunc) { run("SELECT SUM(10), COUNT(c1) FROM t1 GROUP BY c2"); } +TEST_F(PlanGroupByTest, rewriteFunc) { + useDb("root", "test"); + + run("SELECT AVG(c1) FROM t1"); + + run("SELECT AVG(c1) FROM t1 GROUP BY c2"); +} + TEST_F(PlanGroupByTest, selectFunc) { useDb("root", "test");