diff --git a/source/libs/parser/src/parTranslater.c b/source/libs/parser/src/parTranslater.c index 5abc142584..6535c70c82 100755 --- a/source/libs/parser/src/parTranslater.c +++ b/source/libs/parser/src/parTranslater.c @@ -7438,6 +7438,18 @@ static bool invalidColsAlias(SFunctionNode* pFunc) { return false; } +static int32_t getSelectFuncIndex(SNodeList* FuncNodeList, SNode* pSelectFunc) { + SNode* pNode = NULL; + int32_t selectFuncIndex = 0; + FOREACH(pNode, FuncNodeList) { + ++selectFuncIndex; + if (nodesEqualNode(pNode, pSelectFunc)) { + return selectFuncIndex; + } + } + return 0; +} + static int32_t rewriteColsFunction(STranslateContext* pCxt, SNodeList** nodeList) { int32_t code = hasInvalidColsFunction(pCxt, *nodeList); if (TSDB_CODE_SUCCESS != code) { @@ -7469,26 +7481,30 @@ static int32_t rewriteColsFunction(STranslateContext* pCxt, SNodeList** nodeList } SNode* pNewNode = NULL; int32_t nums = 0; - int32_t selectFuncNum = 0; + int32_t selectFuncCount = 0; FOREACH(pTmpNode, *nodeList) { if (QUERY_NODE_FUNCTION == nodeType(pTmpNode)) { SFunctionNode* pFunc = (SFunctionNode*)pTmpNode; if (strcasecmp(pFunc->functionName, "cols") == 0) { - ++selectFuncNum; SNode* pSelectFunc = nodesListGetNode(pFunc->pParameterList, 0); - if(nodeType(pSelectFunc) != QUERY_NODE_FUNCTION) { + if (nodeType(pSelectFunc) != QUERY_NODE_FUNCTION) { code = TSDB_CODE_PAR_INVALID_COLS_FUNCTION; parserError("Invalid cols function, the first parameter must be a select function"); goto _end; } - nodesListMakeStrictAppend(&tmpFuncNodeList, pSelectFunc); + int32_t selectFuncIndex = getSelectFuncIndex(tmpFuncNodeList, pSelectFunc); + if (selectFuncIndex == 0) { + ++selectFuncCount; + selectFuncIndex = selectFuncCount; + nodesListMakeStrictAppend(&tmpFuncNodeList, pSelectFunc); + } // start from index 1, because the first parameter is select function which needn't to output. for (int i = 1; i < pFunc->pParameterList->length; ++i) { SNode* pExpr = nodesListGetNode(pFunc->pParameterList, i); code = nodesCloneNode(pExpr, &pNewNode); if (nodesIsExprNode(pNewNode)) { - SBindTupleFuncCxt pCxt = {selectFuncNum}; + SBindTupleFuncCxt pCxt = {selectFuncIndex}; nodesRewriteExpr(&pNewNode, pushDownBindSelectFunc, &pCxt); } else { code = TSDB_CODE_PAR_INVALID_COLS_FUNCTION; diff --git a/tests/system-test/2-query/cols_function.py b/tests/system-test/2-query/cols_function.py index 26e2e4a0cb..dba7e346b1 100644 --- a/tests/system-test/2-query/cols_function.py +++ b/tests/system-test/2-query/cols_function.py @@ -85,9 +85,9 @@ class TDTestCase: tdSql.checkRows(1) #tdSql.checkCols(4) tdSql.checkData(0, 0, 1734574930000) - tdSql.checkData(0, 1, 2.2) + tdSql.checkData(0, 1, 'bbbbbbbbb') tdSql.checkData(0, 2, 1734574929000) - tdSql.checkData(0, 3, 1.1) + tdSql.checkData(0, 3, 'a') tdSql.query(f'select cols(last(c0), ts, c1, c2, c3), cols(first(c0), ts, c1, c2, c3) from db.d1') tdSql.checkRows(1) #tdSql.checkCols(6) @@ -103,19 +103,37 @@ class TDTestCase: tdSql.query(f'select cols(last(ts), c1), cols(first(ts), c1) from db.d1') tdSql.checkRows(1) #tdSql.checkCols(6) - tdSql.checkData(0, 0, 2) - tdSql.checkData(0, 1, 1) + tdSql.checkData(0, 0, 2.2) + tdSql.checkData(0, 1, 1.1) - tdSql.query(f'select cols(first(ts), c1), cols(first(ts), c1) from db.d1') + tdSql.query(f'select cols(first(ts), c0, c1), cols(first(ts), c0, c1) from db.d1') tdSql.checkRows(1) #tdSql.checkCols(6) tdSql.checkData(0, 0, 1) - tdSql.checkData(0, 1, 1) + tdSql.checkData(0, 1, 1.1) + tdSql.checkData(0, 2, 1) + tdSql.checkData(0, 3, 1.1) + + tdSql.query(f'select cols(first(ts), c0, c1), cols(first(ts+1), c0, c1) from db.d1') + tdSql.checkRows(1) + #tdSql.checkCols(6) + tdSql.checkData(0, 0, 1) + tdSql.checkData(0, 1, 1.1) + tdSql.checkData(0, 2, 1) + tdSql.checkData(0, 3, 1.1) + + tdSql.query(f'select cols(first(ts), c0, c1), cols(first(ts), c0+1, c1+2) from db.d1') + tdSql.checkRows(1) + #tdSql.checkCols(6) + tdSql.checkData(0, 0, 1) + tdSql.checkData(0, 1, 1.1) + tdSql.checkData(0, 2, 2) + tdSql.checkData(0, 3, 3.1) tdSql.query(f'select cols(first(c0), ts, length(c2)), cols(last(c0), ts, length(c2)) from db.d1') tdSql.checkRows(1) #tdSql.checkCols(6) - tdSql.checkData(0, 0, 1) + tdSql.checkData(0, 0, 1734574929000) tdSql.checkData(0, 1, 1) tdSql.query(f'select cols(first(c0), ts, length(c2)), cols(last(c0), ts, length(c2)) from db.d1') tdSql.checkRows(1)