feat: avg function rewrite
This commit is contained in:
parent
3eda47e159
commit
3954edb366
|
@ -171,6 +171,10 @@ bool fmIsRepeatScanFunc(int32_t funcId);
|
|||
bool fmIsUserDefinedFunc(int32_t funcId);
|
||||
bool fmIsDistExecFunc(int32_t funcId);
|
||||
bool fmIsForbidFillFunc(int32_t funcId);
|
||||
bool fmIsForbidStreamFunc(int32_t funcId);
|
||||
|
||||
bool fmNeedRewrite(int32_t funcId);
|
||||
int32_t fmRewriteFunc(SNode** pFunc);
|
||||
|
||||
int32_t fmGetDistMethod(const SFunctionNode* pFunc, SFunctionNode** pPartialFunc, SFunctionNode** pMergeFunc);
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ extern "C" {
|
|||
#include "functionMgtInt.h"
|
||||
|
||||
typedef int32_t (*FTranslateFunc)(SFunctionNode* pFunc, char* pErrBuf, int32_t len);
|
||||
typedef int32_t (*FRewriteFunc)(SNode** pFunc);
|
||||
typedef EFuncDataRequired (*FFuncDataRequired)(SFunctionNode* pFunc, STimeWindow* pTimeWindow);
|
||||
|
||||
typedef struct SBuiltinFuncDefinition {
|
||||
|
@ -30,6 +31,7 @@ typedef struct SBuiltinFuncDefinition {
|
|||
EFunctionType type;
|
||||
uint64_t classification;
|
||||
FTranslateFunc translateFunc;
|
||||
FRewriteFunc rewriteFunc;
|
||||
FFuncDataRequired dataRequiredFunc;
|
||||
FExecGetEnv getEnvFunc;
|
||||
FExecInit initFunc;
|
||||
|
|
|
@ -42,6 +42,7 @@ extern "C" {
|
|||
#define FUNC_MGT_SELECT_FUNC FUNC_MGT_FUNC_CLASSIFICATION_MASK(13)
|
||||
#define FUNC_MGT_REPEAT_SCAN_FUNC FUNC_MGT_FUNC_CLASSIFICATION_MASK(14)
|
||||
#define FUNC_MGT_FORBID_FILL_FUNC FUNC_MGT_FUNC_CLASSIFICATION_MASK(15)
|
||||
#define FUNC_MGT_FORBID_STREAM_FUNC FUNC_MGT_FUNC_CLASSIFICATION_MASK(16)
|
||||
|
||||
#define FUNC_MGT_TEST_MASK(val, mask) (((val) & (mask)) != 0)
|
||||
|
||||
|
|
|
@ -1333,6 +1333,49 @@ static bool getBlockDistFuncEnv(SFunctionNode* UNUSED_PARAM(pFunc), SFuncExecEnv
|
|||
return true;
|
||||
}
|
||||
|
||||
static int32_t rewriteAvg(SNode** pFunc) {
|
||||
SOperatorNode* pOper = (SOperatorNode*)nodesMakeNode(QUERY_NODE_OPERATOR);
|
||||
if (NULL == pOper) {
|
||||
return TSDB_CODE_OUT_OF_MEMORY;
|
||||
}
|
||||
|
||||
SFunctionNode* pAvg = (SFunctionNode*)*pFunc;
|
||||
pOper->node.resType = pAvg->node.resType;
|
||||
strcpy(pOper->node.aliasName, pAvg->node.aliasName);
|
||||
pOper->opType = OP_TYPE_DIV;
|
||||
pOper->pLeft = nodesMakeNode(QUERY_NODE_FUNCTION);
|
||||
pOper->pRight = nodesMakeNode(QUERY_NODE_FUNCTION);
|
||||
if (NULL == pOper->pLeft || NULL == pOper->pRight) {
|
||||
nodesDestroyNode((SNode*)pOper);
|
||||
return TSDB_CODE_OUT_OF_MEMORY;
|
||||
}
|
||||
|
||||
SFunctionNode* pSum = (SFunctionNode*)pOper->pLeft;
|
||||
strcpy(pSum->functionName, "sum");
|
||||
pSum->pParameterList = nodesCloneList(pAvg->pParameterList);
|
||||
if (NULL == pSum->pParameterList) {
|
||||
nodesDestroyNode((SNode*)pOper);
|
||||
return TSDB_CODE_OUT_OF_MEMORY;
|
||||
}
|
||||
char msgBuf[64] = {0};
|
||||
int32_t code = fmGetFuncInfo(pSum, msgBuf, sizeof(msgBuf));
|
||||
if (TSDB_CODE_SUCCESS == code) {
|
||||
SFunctionNode* pCount = (SFunctionNode*)pOper->pRight;
|
||||
strcpy(pCount->functionName, "count");
|
||||
TSWAP(pCount->pParameterList, pAvg->pParameterList);
|
||||
code = fmGetFuncInfo(pCount, msgBuf, sizeof(msgBuf));
|
||||
}
|
||||
|
||||
if (TSDB_CODE_SUCCESS == code) {
|
||||
nodesDestroyNode((SNode*)pAvg);
|
||||
*pFunc = (SNode*)pOper;
|
||||
} else {
|
||||
nodesDestroyNode((SNode*)pOper);
|
||||
}
|
||||
|
||||
return code;
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
const SBuiltinFuncDefinition funcMgtBuiltins[] = {
|
||||
{
|
||||
|
@ -1422,6 +1465,7 @@ const SBuiltinFuncDefinition funcMgtBuiltins[] = {
|
|||
.type = FUNCTION_TYPE_AVG,
|
||||
.classification = FUNC_MGT_AGG_FUNC,
|
||||
.translateFunc = translateInNumOutDou,
|
||||
.rewriteFunc = rewriteAvg,
|
||||
.getEnvFunc = getAvgFuncEnv,
|
||||
.initFunc = avgFunctionSetup,
|
||||
.processFunc = avgFunction,
|
||||
|
|
|
@ -161,6 +161,8 @@ bool fmIsUserDefinedFunc(int32_t funcId) { return funcId > FUNC_UDF_ID_START; }
|
|||
|
||||
bool fmIsForbidFillFunc(int32_t funcId) { return isSpecificClassifyFunc(funcId, FUNC_MGT_FORBID_FILL_FUNC); }
|
||||
|
||||
bool fmIsForbidStreamFunc(int32_t funcId) { return isSpecificClassifyFunc(funcId, FUNC_MGT_FORBID_STREAM_FUNC); }
|
||||
|
||||
void fmFuncMgtDestroy() {
|
||||
void* m = gFunMgtService.pFuncNameHashTable;
|
||||
if (m != NULL && atomic_val_compare_exchange_ptr((void**)&gFunMgtService.pFuncNameHashTable, m, 0) == m) {
|
||||
|
@ -297,3 +299,12 @@ int32_t fmGetDistMethod(const SFunctionNode* pFunc, SFunctionNode** pPartialFunc
|
|||
|
||||
return code;
|
||||
}
|
||||
|
||||
bool fmNeedRewrite(int32_t funcId) {
|
||||
if (fmIsUserDefinedFunc(funcId)) {
|
||||
return false;
|
||||
}
|
||||
return NULL != funcMgtBuiltins[funcId].rewriteFunc;
|
||||
}
|
||||
|
||||
int32_t fmRewriteFunc(SNode** pFunc) { return funcMgtBuiltins[((SFunctionNode*)*pFunc)->funcId].rewriteFunc(pFunc); }
|
||||
|
|
|
@ -1076,29 +1076,32 @@ static void setFuncClassification(SSelectStmt* pSelect, SFunctionNode* pFunc) {
|
|||
}
|
||||
}
|
||||
|
||||
static EDealRes translateFunction(STranslateContext* pCxt, SFunctionNode* pFunc) {
|
||||
static EDealRes translateFunction(STranslateContext* pCxt, SFunctionNode** pFunc) {
|
||||
SNode* pParam = NULL;
|
||||
FOREACH(pParam, pFunc->pParameterList) {
|
||||
FOREACH(pParam, (*pFunc)->pParameterList) {
|
||||
if (isMultiResFunc(pParam)) {
|
||||
return generateDealNodeErrMsg(pCxt, TSDB_CODE_PAR_WRONG_VALUE_TYPE, ((SExprNode*)pParam)->aliasName);
|
||||
}
|
||||
}
|
||||
|
||||
pCxt->errCode = getFuncInfo(pCxt, pFunc);
|
||||
pCxt->errCode = getFuncInfo(pCxt, *pFunc);
|
||||
if (TSDB_CODE_SUCCESS == pCxt->errCode) {
|
||||
pCxt->errCode = translateAggFunc(pCxt, pFunc);
|
||||
pCxt->errCode = translateAggFunc(pCxt, *pFunc);
|
||||
}
|
||||
if (TSDB_CODE_SUCCESS == pCxt->errCode) {
|
||||
pCxt->errCode = translateScanPseudoColumnFunc(pCxt, pFunc);
|
||||
pCxt->errCode = translateScanPseudoColumnFunc(pCxt, *pFunc);
|
||||
}
|
||||
if (TSDB_CODE_SUCCESS == pCxt->errCode) {
|
||||
pCxt->errCode = translateIndefiniteRowsFunc(pCxt, pFunc);
|
||||
pCxt->errCode = translateIndefiniteRowsFunc(pCxt, *pFunc);
|
||||
}
|
||||
if (TSDB_CODE_SUCCESS == pCxt->errCode) {
|
||||
pCxt->errCode = translateForbidFillFunc(pCxt, pFunc);
|
||||
pCxt->errCode = translateForbidFillFunc(pCxt, *pFunc);
|
||||
}
|
||||
if (TSDB_CODE_SUCCESS == pCxt->errCode) {
|
||||
setFuncClassification(pCxt->pCurrSelectStmt, pFunc);
|
||||
setFuncClassification(pCxt->pCurrSelectStmt, *pFunc);
|
||||
}
|
||||
if (TSDB_CODE_SUCCESS == pCxt->errCode && fmNeedRewrite((*pFunc)->funcId)) {
|
||||
pCxt->errCode = fmRewriteFunc((SNode**)pFunc);
|
||||
}
|
||||
return TSDB_CODE_SUCCESS == pCxt->errCode ? DEAL_RES_CONTINUE : DEAL_RES_ERROR;
|
||||
}
|
||||
|
@ -1123,7 +1126,7 @@ static EDealRes doTranslateExpr(SNode** pNode, void* pContext) {
|
|||
case QUERY_NODE_OPERATOR:
|
||||
return translateOperator(pCxt, (SOperatorNode**)pNode);
|
||||
case QUERY_NODE_FUNCTION:
|
||||
return translateFunction(pCxt, (SFunctionNode*)*pNode);
|
||||
return translateFunction(pCxt, (SFunctionNode**)pNode);
|
||||
case QUERY_NODE_LOGIC_CONDITION:
|
||||
return translateLogicCond(pCxt, (SLogicConditionNode*)*pNode);
|
||||
case QUERY_NODE_TEMP_TABLE:
|
||||
|
|
|
@ -35,6 +35,7 @@ string toString(int32_t code) { return tstrerror(code); }
|
|||
// [...];
|
||||
class InsertTest : public Test {
|
||||
protected:
|
||||
InsertTest() : res_(nullptr) {}
|
||||
~InsertTest() { reset(); }
|
||||
|
||||
void setDatabase(const string& acctId, const string& db) {
|
||||
|
|
|
@ -53,6 +53,14 @@ TEST_F(PlanGroupByTest, aggFunc) {
|
|||
run("SELECT SUM(10), COUNT(c1) FROM t1 GROUP BY c2");
|
||||
}
|
||||
|
||||
TEST_F(PlanGroupByTest, rewriteFunc) {
|
||||
useDb("root", "test");
|
||||
|
||||
run("SELECT AVG(c1) FROM t1");
|
||||
|
||||
run("SELECT AVG(c1) FROM t1 GROUP BY c2");
|
||||
}
|
||||
|
||||
TEST_F(PlanGroupByTest, selectFunc) {
|
||||
useDb("root", "test");
|
||||
|
||||
|
|
Loading…
Reference in New Issue