From 40a7262fcdf44090507b8db7144d1096e0350c65 Mon Sep 17 00:00:00 2001 From: wangjiaming0909 <604227650@qq.com> Date: Tue, 12 Mar 2024 02:09:12 +0000 Subject: [PATCH] support spread --- source/libs/function/src/builtins.c | 29 +++++++++++++++++++++++-- source/libs/function/src/builtinsimpl.c | 1 + source/libs/parser/src/parTranslater.c | 8 +++++++ tests/system-test/2-query/tsma.py | 2 +- 4 files changed, 37 insertions(+), 3 deletions(-) diff --git a/source/libs/function/src/builtins.c b/source/libs/function/src/builtins.c index 4dd88eef2c..e1b2fa0b05 100644 --- a/source/libs/function/src/builtins.c +++ b/source/libs/function/src/builtins.c @@ -853,6 +853,31 @@ static int32_t translateSpreadMerge(SFunctionNode* pFunc, char* pErrBuf, int32_t return translateSpreadImpl(pFunc, pErrBuf, len, false); } +static int32_t translateSpreadState(SFunctionNode* pFunc, char* pErrBuf, int32_t len) { + if (1 != LIST_LENGTH(pFunc->pParameterList)) { + return invaildFuncParaNumErrMsg(pErrBuf, len, pFunc->functionName); + } + + uint8_t paraType = getSDataTypeFromNode(nodesListGetNode(pFunc->pParameterList, 0))->type; + if (!IS_NUMERIC_TYPE(paraType) && !IS_TIMESTAMP_TYPE(paraType)) { + return invaildFuncParaTypeErrMsg(pErrBuf, len, pFunc->functionName); + } + pFunc->node.resType = (SDataType){.bytes = getSpreadInfoSize() + VARSTR_HEADER_SIZE, .type = TSDB_DATA_TYPE_BINARY}; + return TSDB_CODE_SUCCESS; +} + +static int32_t translateSpreadStateMerge(SFunctionNode* pFunc, char* pErrBuf, int32_t len) { + if (1 != LIST_LENGTH(pFunc->pParameterList)) { + return invaildFuncParaNumErrMsg(pErrBuf, len, pFunc->functionName); + } + uint8_t paraType = getSDataTypeFromNode(nodesListGetNode(pFunc->pParameterList, 0))->type; + if (paraType != TSDB_DATA_TYPE_BINARY) { + return invaildFuncParaTypeErrMsg(pErrBuf, len, pFunc->functionName); + } + pFunc->node.resType = (SDataType){.bytes = getSpreadInfoSize() + VARSTR_HEADER_SIZE, .type = TSDB_DATA_TYPE_BINARY}; + return TSDB_CODE_SUCCESS; +} + static int32_t translateElapsed(SFunctionNode* pFunc, char* pErrBuf, int32_t len) { int32_t numOfParams = LIST_LENGTH(pFunc->pParameterList); if (1 != numOfParams && 2 != numOfParams) { @@ -3964,12 +3989,12 @@ const SBuiltinFuncDefinition funcMgtBuiltins[] = { .processFunc = spreadFunction, .finalizeFunc = spreadPartialFinalize, .pPartialFunc = "_spread_partial", - .pMergeFunc = "_last_state_merge" + .pMergeFunc = "_spread_state_merge" }, { .name = "_spread_state_merge", .type = FUNCTION_TYPE_SPREAD_STATE_MERGE, - .classification = FUNC_MGT_AGG_FUNC | FUNC_MGT_SPECIAL_DATA_REQUIRED | FUNC_MGT_TSMA_FUNC, + .classification = FUNC_MGT_AGG_FUNC | FUNC_MGT_TSMA_FUNC, .translateFunc = translateSpreadStateMerge, .getEnvFunc = getSpreadFuncEnv, .initFunc = spreadFunctionSetup, diff --git a/source/libs/function/src/builtinsimpl.c b/source/libs/function/src/builtinsimpl.c index d7ef445f9d..d2db33573b 100644 --- a/source/libs/function/src/builtinsimpl.c +++ b/source/libs/function/src/builtinsimpl.c @@ -3877,6 +3877,7 @@ int32_t spreadFunctionMerge(SqlFunctionCtx* pCtx) { int32_t start = pInput->startRowIndex; for (int32_t i = start; i < start + pInput->numOfRows; ++i) { + if(colDataIsNull_s(pCol, i)) continue; char* data = colDataGetData(pCol, i); SSpreadInfo* pInputInfo = (SSpreadInfo*)varDataVal(data); if (pInputInfo->hasResult) { diff --git a/source/libs/parser/src/parTranslater.c b/source/libs/parser/src/parTranslater.c index 4f1ff56879..6ceb2aa7fc 100644 --- a/source/libs/parser/src/parTranslater.c +++ b/source/libs/parser/src/parTranslater.c @@ -10704,6 +10704,14 @@ static int32_t rewriteTSMAFuncs(STranslateContext* pCxt, SCreateTSMAStmt* pStmt, code = TSDB_CODE_TSMA_INVALID_FUNC_PARAM; break; } + SColumnNode* pCol = (SColumnNode*)pFunc->pParameterList->pHead->pNode; + for (int32_t i = 0; i < columnNum; ++i) { + if (strcmp(pCols[i].name, pCol->colName) == 0) { + pCol->colId = pCols[i].colId; + pCol->node.resType.type = pCols[i].type; + pCol->node.resType.bytes = pCols[i].bytes; + } + } code = fmGetFuncInfo(pFunc, NULL, 0); if (TSDB_CODE_SUCCESS != code) break; if (!fmIsTSMASupportedFunc(pFunc->funcId)) { diff --git a/tests/system-test/2-query/tsma.py b/tests/system-test/2-query/tsma.py index 4fcefb5cb0..9fa4d819d1 100644 --- a/tests/system-test/2-query/tsma.py +++ b/tests/system-test/2-query/tsma.py @@ -771,7 +771,7 @@ class TDTestCase: def test_recursive_tsma(self): tdSql.execute('drop tsma tsma2') - tsma_func_list = ['avg(c2)', 'avg(c3)', 'min(c4)', 'max(c3)', 'sum(c2)', 'count(ts)', 'count(c2)', 'first(c5)', 'last(c5)'] + tsma_func_list = ['avg(c2)', 'avg(c3)', 'min(c4)', 'max(c3)', 'sum(c2)', 'count(ts)', 'count(c2)', 'first(c5)', 'last(c5)', 'spread(c2)'] select_func_list: List[str] = tsma_func_list.copy() select_func_list.append('count(*)') self.create_tsma('tsma3', 'test', 'meters', tsma_func_list, '5m')