feat: avg function rewrite

This commit is contained in:
Xiaoyu Wang 2022-06-13 18:36:23 +08:00
parent 3eda47e159
commit 3954edb366
8 changed files with 83 additions and 9 deletions

View File

@ -171,6 +171,10 @@ bool fmIsRepeatScanFunc(int32_t funcId);
bool fmIsUserDefinedFunc(int32_t funcId); bool fmIsUserDefinedFunc(int32_t funcId);
bool fmIsDistExecFunc(int32_t funcId); bool fmIsDistExecFunc(int32_t funcId);
bool fmIsForbidFillFunc(int32_t funcId); bool fmIsForbidFillFunc(int32_t funcId);
/* True when the function's classification has FUNC_MGT_FORBID_STREAM_FUNC set. */
bool fmIsForbidStreamFunc(int32_t funcId);
/* True when funcId is not a user-defined function and its builtin definition provides a rewriteFunc. */
bool fmNeedRewrite(int32_t funcId);
/* Runs the builtin rewriteFunc registered for the function node at *pFunc and returns its status code.
 * Call only when fmNeedRewrite() is true for that node's funcId: the hook is invoked without a NULL check. */
int32_t fmRewriteFunc(SNode** pFunc);
int32_t fmGetDistMethod(const SFunctionNode* pFunc, SFunctionNode** pPartialFunc, SFunctionNode** pMergeFunc); int32_t fmGetDistMethod(const SFunctionNode* pFunc, SFunctionNode** pPartialFunc, SFunctionNode** pMergeFunc);

View File

@ -23,6 +23,7 @@ extern "C" {
#include "functionMgtInt.h" #include "functionMgtInt.h"
typedef int32_t (*FTranslateFunc)(SFunctionNode* pFunc, char* pErrBuf, int32_t len); typedef int32_t (*FTranslateFunc)(SFunctionNode* pFunc, char* pErrBuf, int32_t len);
/* Rewrite hook of a builtin function. It receives the address of the node pointer so that it can
 * substitute a different node for the function call, and returns an int32_t status code. */
typedef int32_t (*FRewriteFunc)(SNode** pFunc);
typedef EFuncDataRequired (*FFuncDataRequired)(SFunctionNode* pFunc, STimeWindow* pTimeWindow); typedef EFuncDataRequired (*FFuncDataRequired)(SFunctionNode* pFunc, STimeWindow* pTimeWindow);
typedef struct SBuiltinFuncDefinition { typedef struct SBuiltinFuncDefinition {
@ -30,6 +31,7 @@ typedef struct SBuiltinFuncDefinition {
EFunctionType type; EFunctionType type;
uint64_t classification; uint64_t classification;
FTranslateFunc translateFunc; FTranslateFunc translateFunc;
FRewriteFunc rewriteFunc;
FFuncDataRequired dataRequiredFunc; FFuncDataRequired dataRequiredFunc;
FExecGetEnv getEnvFunc; FExecGetEnv getEnvFunc;
FExecInit initFunc; FExecInit initFunc;

View File

@ -42,6 +42,7 @@ extern "C" {
#define FUNC_MGT_SELECT_FUNC FUNC_MGT_FUNC_CLASSIFICATION_MASK(13) #define FUNC_MGT_SELECT_FUNC FUNC_MGT_FUNC_CLASSIFICATION_MASK(13)
#define FUNC_MGT_REPEAT_SCAN_FUNC FUNC_MGT_FUNC_CLASSIFICATION_MASK(14) #define FUNC_MGT_REPEAT_SCAN_FUNC FUNC_MGT_FUNC_CLASSIFICATION_MASK(14)
#define FUNC_MGT_FORBID_FILL_FUNC FUNC_MGT_FUNC_CLASSIFICATION_MASK(15) #define FUNC_MGT_FORBID_FILL_FUNC FUNC_MGT_FUNC_CLASSIFICATION_MASK(15)
/* Classification mask #16; tested by fmIsForbidStreamFunc().
 * NOTE(review): presumably marks functions that may not be used in stream queries - confirm. */
#define FUNC_MGT_FORBID_STREAM_FUNC FUNC_MGT_FUNC_CLASSIFICATION_MASK(16)
#define FUNC_MGT_TEST_MASK(val, mask) (((val) & (mask)) != 0) #define FUNC_MGT_TEST_MASK(val, mask) (((val) & (mask)) != 0)

View File

@ -1333,6 +1333,49 @@ static bool getBlockDistFuncEnv(SFunctionNode* UNUSED_PARAM(pFunc), SFuncExecEnv
return true; return true;
} }
/**
 * Rewrites an AVG function node as the expression SUM(args) / COUNT(args).
 *
 * A division operator node is built that inherits the AVG node's result type
 * and alias. Its left operand is a "sum" call working on a clone of the AVG
 * parameter list; its right operand is a "count" call that takes over the
 * original parameter list.
 *
 * On success the AVG node is destroyed and *pFunc is replaced by the operator
 * node. On failure *pFunc is left pointing at the AVG node, which still owns
 * its parameter list, and every node created here is destroyed.
 *
 * @param pFunc  in/out: address of the pointer to the AVG function node.
 * @return TSDB_CODE_SUCCESS, TSDB_CODE_OUT_OF_MEMORY when a node or the cloned
 *         parameter list could not be created, or the code reported by
 *         fmGetFuncInfo for the generated sum/count calls.
 */
static int32_t rewriteAvg(SNode** pFunc) {
  SOperatorNode* pOper = (SOperatorNode*)nodesMakeNode(QUERY_NODE_OPERATOR);
  if (NULL == pOper) {
    return TSDB_CODE_OUT_OF_MEMORY;
  }

  SFunctionNode* pAvg = (SFunctionNode*)*pFunc;
  pOper->node.resType = pAvg->node.resType;
  strcpy(pOper->node.aliasName, pAvg->node.aliasName);
  pOper->opType = OP_TYPE_DIV;
  pOper->pLeft = nodesMakeNode(QUERY_NODE_FUNCTION);
  pOper->pRight = nodesMakeNode(QUERY_NODE_FUNCTION);
  if (NULL == pOper->pLeft || NULL == pOper->pRight) {
    nodesDestroyNode((SNode*)pOper);
    return TSDB_CODE_OUT_OF_MEMORY;
  }

  // sum() gets its own copy of the arguments; the originals are reserved for count().
  SFunctionNode* pSum = (SFunctionNode*)pOper->pLeft;
  strcpy(pSum->functionName, "sum");
  pSum->pParameterList = nodesCloneList(pAvg->pParameterList);
  if (NULL == pSum->pParameterList) {
    nodesDestroyNode((SNode*)pOper);
    return TSDB_CODE_OUT_OF_MEMORY;
  }

  char    msgBuf[64] = {0};
  int32_t code = fmGetFuncInfo(pSum, msgBuf, sizeof(msgBuf));
  if (TSDB_CODE_SUCCESS == code) {
    SFunctionNode* pCount = (SFunctionNode*)pOper->pRight;
    strcpy(pCount->functionName, "count");
    // Move (not copy) the original arguments into count().
    TSWAP(pCount->pParameterList, pAvg->pParameterList);
    code = fmGetFuncInfo(pCount, msgBuf, sizeof(msgBuf));
    if (TSDB_CODE_SUCCESS != code) {
      // Hand the arguments back before pOper is destroyed below; otherwise the
      // caller's AVG node would survive the failure without its parameter list.
      TSWAP(pCount->pParameterList, pAvg->pParameterList);
    }
  }

  if (TSDB_CODE_SUCCESS == code) {
    nodesDestroyNode((SNode*)pAvg);
    *pFunc = (SNode*)pOper;
  } else {
    nodesDestroyNode((SNode*)pOper);
  }
  return code;
}
// clang-format off // clang-format off
const SBuiltinFuncDefinition funcMgtBuiltins[] = { const SBuiltinFuncDefinition funcMgtBuiltins[] = {
{ {
@ -1422,6 +1465,7 @@ const SBuiltinFuncDefinition funcMgtBuiltins[] = {
.type = FUNCTION_TYPE_AVG, .type = FUNCTION_TYPE_AVG,
.classification = FUNC_MGT_AGG_FUNC, .classification = FUNC_MGT_AGG_FUNC,
.translateFunc = translateInNumOutDou, .translateFunc = translateInNumOutDou,
.rewriteFunc = rewriteAvg,
.getEnvFunc = getAvgFuncEnv, .getEnvFunc = getAvgFuncEnv,
.initFunc = avgFunctionSetup, .initFunc = avgFunctionSetup,
.processFunc = avgFunction, .processFunc = avgFunction,

View File

@ -161,6 +161,8 @@ bool fmIsUserDefinedFunc(int32_t funcId) { return funcId > FUNC_UDF_ID_START; }
bool fmIsForbidFillFunc(int32_t funcId) { return isSpecificClassifyFunc(funcId, FUNC_MGT_FORBID_FILL_FUNC); } bool fmIsForbidFillFunc(int32_t funcId) { return isSpecificClassifyFunc(funcId, FUNC_MGT_FORBID_FILL_FUNC); }
/* Reports whether the function identified by funcId carries the FUNC_MGT_FORBID_STREAM_FUNC classification. */
bool fmIsForbidStreamFunc(int32_t funcId) {
  const bool forbidden = isSpecificClassifyFunc(funcId, FUNC_MGT_FORBID_STREAM_FUNC);
  return forbidden;
}
void fmFuncMgtDestroy() { void fmFuncMgtDestroy() {
void* m = gFunMgtService.pFuncNameHashTable; void* m = gFunMgtService.pFuncNameHashTable;
if (m != NULL && atomic_val_compare_exchange_ptr((void**)&gFunMgtService.pFuncNameHashTable, m, 0) == m) { if (m != NULL && atomic_val_compare_exchange_ptr((void**)&gFunMgtService.pFuncNameHashTable, m, 0) == m) {
@ -297,3 +299,12 @@ int32_t fmGetDistMethod(const SFunctionNode* pFunc, SFunctionNode** pPartialFunc
return code; return code;
} }
/**
 * Reports whether the builtin function identified by funcId registers a rewrite hook.
 *
 * @param funcId  function id; negative ids and user-defined function ids never need a rewrite.
 * @return true when funcMgtBuiltins[funcId].rewriteFunc is set, false otherwise.
 */
bool fmNeedRewrite(int32_t funcId) {
  // A negative id is not classified as user-defined, so without this guard it
  // would be used as an index in front of the builtin table.
  if (funcId < 0 || fmIsUserDefinedFunc(funcId)) {
    return false;
  }
  return NULL != funcMgtBuiltins[funcId].rewriteFunc;
}
/**
 * Dispatches to the rewrite hook of the builtin function referenced by *pFunc.
 *
 * The hook is called unconditionally, so callers are expected to have checked
 * fmNeedRewrite() for this node's funcId beforehand.
 *
 * @param pFunc  in/out: address of the pointer to the function node; handed to the hook as is.
 * @return the status code returned by the hook.
 */
int32_t fmRewriteFunc(SNode** pFunc) {
  const SFunctionNode* pNode = (const SFunctionNode*)*pFunc;
  const int32_t        builtinIdx = pNode->funcId;
  return funcMgtBuiltins[builtinIdx].rewriteFunc(pFunc);
}

View File

@ -1076,29 +1076,32 @@ static void setFuncClassification(SSelectStmt* pSelect, SFunctionNode* pFunc) {
} }
} }
static EDealRes translateFunction(STranslateContext* pCxt, SFunctionNode* pFunc) { static EDealRes translateFunction(STranslateContext* pCxt, SFunctionNode** pFunc) {
SNode* pParam = NULL; SNode* pParam = NULL;
FOREACH(pParam, pFunc->pParameterList) { FOREACH(pParam, (*pFunc)->pParameterList) {
if (isMultiResFunc(pParam)) { if (isMultiResFunc(pParam)) {
return generateDealNodeErrMsg(pCxt, TSDB_CODE_PAR_WRONG_VALUE_TYPE, ((SExprNode*)pParam)->aliasName); return generateDealNodeErrMsg(pCxt, TSDB_CODE_PAR_WRONG_VALUE_TYPE, ((SExprNode*)pParam)->aliasName);
} }
} }
pCxt->errCode = getFuncInfo(pCxt, pFunc); pCxt->errCode = getFuncInfo(pCxt, *pFunc);
if (TSDB_CODE_SUCCESS == pCxt->errCode) { if (TSDB_CODE_SUCCESS == pCxt->errCode) {
pCxt->errCode = translateAggFunc(pCxt, pFunc); pCxt->errCode = translateAggFunc(pCxt, *pFunc);
} }
if (TSDB_CODE_SUCCESS == pCxt->errCode) { if (TSDB_CODE_SUCCESS == pCxt->errCode) {
pCxt->errCode = translateScanPseudoColumnFunc(pCxt, pFunc); pCxt->errCode = translateScanPseudoColumnFunc(pCxt, *pFunc);
} }
if (TSDB_CODE_SUCCESS == pCxt->errCode) { if (TSDB_CODE_SUCCESS == pCxt->errCode) {
pCxt->errCode = translateIndefiniteRowsFunc(pCxt, pFunc); pCxt->errCode = translateIndefiniteRowsFunc(pCxt, *pFunc);
} }
if (TSDB_CODE_SUCCESS == pCxt->errCode) { if (TSDB_CODE_SUCCESS == pCxt->errCode) {
pCxt->errCode = translateForbidFillFunc(pCxt, pFunc); pCxt->errCode = translateForbidFillFunc(pCxt, *pFunc);
} }
if (TSDB_CODE_SUCCESS == pCxt->errCode) { if (TSDB_CODE_SUCCESS == pCxt->errCode) {
setFuncClassification(pCxt->pCurrSelectStmt, pFunc); setFuncClassification(pCxt->pCurrSelectStmt, *pFunc);
}
if (TSDB_CODE_SUCCESS == pCxt->errCode && fmNeedRewrite((*pFunc)->funcId)) {
pCxt->errCode = fmRewriteFunc((SNode**)pFunc);
} }
return TSDB_CODE_SUCCESS == pCxt->errCode ? DEAL_RES_CONTINUE : DEAL_RES_ERROR; return TSDB_CODE_SUCCESS == pCxt->errCode ? DEAL_RES_CONTINUE : DEAL_RES_ERROR;
} }
@ -1123,7 +1126,7 @@ static EDealRes doTranslateExpr(SNode** pNode, void* pContext) {
case QUERY_NODE_OPERATOR: case QUERY_NODE_OPERATOR:
return translateOperator(pCxt, (SOperatorNode**)pNode); return translateOperator(pCxt, (SOperatorNode**)pNode);
case QUERY_NODE_FUNCTION: case QUERY_NODE_FUNCTION:
return translateFunction(pCxt, (SFunctionNode*)*pNode); return translateFunction(pCxt, (SFunctionNode**)pNode);
case QUERY_NODE_LOGIC_CONDITION: case QUERY_NODE_LOGIC_CONDITION:
return translateLogicCond(pCxt, (SLogicConditionNode*)*pNode); return translateLogicCond(pCxt, (SLogicConditionNode*)*pNode);
case QUERY_NODE_TEMP_TABLE: case QUERY_NODE_TEMP_TABLE:

View File

@ -35,6 +35,7 @@ string toString(int32_t code) { return tstrerror(code); }
// [...]; // [...];
class InsertTest : public Test { class InsertTest : public Test {
protected: protected:
InsertTest() : res_(nullptr) {}
~InsertTest() { reset(); } ~InsertTest() { reset(); }
void setDatabase(const string& acctId, const string& db) { void setDatabase(const string& acctId, const string& db) {

View File

@ -53,6 +53,14 @@ TEST_F(PlanGroupByTest, aggFunc) {
run("SELECT SUM(10), COUNT(c1) FROM t1 GROUP BY c2"); run("SELECT SUM(10), COUNT(c1) FROM t1 GROUP BY c2");
} }
// Runs the planner on AVG queries against test.t1, which the parser now rewrites into SUM/COUNT.
TEST_F(PlanGroupByTest, rewriteFunc) {
useDb("root", "test");
// AVG over the whole table, no grouping.
run("SELECT AVG(c1) FROM t1");
// AVG evaluated per c2 group.
run("SELECT AVG(c1) FROM t1 GROUP BY c2");
}
TEST_F(PlanGroupByTest, selectFunc) { TEST_F(PlanGroupByTest, selectFunc) {
useDb("root", "test"); useDb("root", "test");