diff --git a/include/libs/nodes/querynodes.h b/include/libs/nodes/querynodes.h index 678c694d9b..dc312a762e 100644 --- a/include/libs/nodes/querynodes.h +++ b/include/libs/nodes/querynodes.h @@ -52,6 +52,7 @@ typedef struct SExprNode { SArray* pAssociation; bool orderAlias; bool asAlias; + bool asParam; } SExprNode; typedef enum EColumnType { diff --git a/source/libs/nodes/src/nodesTraverseFuncs.c b/source/libs/nodes/src/nodesTraverseFuncs.c index ce575ede8a..b3623a4b0a 100644 --- a/source/libs/nodes/src/nodesTraverseFuncs.c +++ b/source/libs/nodes/src/nodesTraverseFuncs.c @@ -214,6 +214,18 @@ void nodesWalkExprsPostOrder(SNodeList* pList, FNodeWalker walker, void* pContex (void)walkExprs(pList, TRAVERSAL_POSTORDER, walker, pContext); } +static void checkParamIsFunc(SFunctionNode *pFunc) { + int32_t numOfParams = LIST_LENGTH(pFunc->pParameterList); + if (numOfParams > 1) { + for (int32_t i = 0; i < numOfParams; ++i) { + SNode* pPara = nodesListGetNode(pFunc->pParameterList, i); + if (nodeType(pPara) == QUERY_NODE_FUNCTION) { + ((SFunctionNode *)pPara)->node.asParam = true; + } + } + } +} + static EDealRes rewriteExprs(SNodeList* pNodeList, ETraversalOrder order, FNodeRewriter rewriter, void* pContext); static EDealRes rewriteExpr(SNode** pRawNode, ETraversalOrder order, FNodeRewriter rewriter, void* pContext) { @@ -248,9 +260,12 @@ static EDealRes rewriteExpr(SNode** pRawNode, ETraversalOrder order, FNodeRewrit case QUERY_NODE_LOGIC_CONDITION: res = rewriteExprs(((SLogicConditionNode*)pNode)->pParameterList, order, rewriter, pContext); break; - case QUERY_NODE_FUNCTION: - res = rewriteExprs(((SFunctionNode*)pNode)->pParameterList, order, rewriter, pContext); + case QUERY_NODE_FUNCTION: { + SFunctionNode* pFunc = (SFunctionNode*)pNode; + checkParamIsFunc(pFunc); + res = rewriteExprs(pFunc->pParameterList, order, rewriter, pContext); break; + } case QUERY_NODE_REAL_TABLE: case QUERY_NODE_TEMP_TABLE: break; // todo diff --git a/source/libs/parser/src/parTranslater.c b/source/libs/parser/src/parTranslater.c index b3a043fe12..90fe4f3e4c 100644 --- a/source/libs/parser/src/parTranslater.c +++ b/source/libs/parser/src/parTranslater.c @@ -1760,7 +1760,7 @@ static int32_t translateMultiResFunc(STranslateContext* pCxt, SFunctionNode* pFu "%s(*) is only supported in SELECTed list", pFunc->functionName); } } - if (tsKeepColumnName && 1 == LIST_LENGTH(pFunc->pParameterList) && !pFunc->node.asAlias) { + if (tsKeepColumnName && 1 == LIST_LENGTH(pFunc->pParameterList) && !pFunc->node.asAlias && !pFunc->node.asParam) { strcpy(pFunc->node.userAlias, ((SExprNode*)nodesListGetNode(pFunc->pParameterList, 0))->userAlias); strcpy(pFunc->node.aliasName, pFunc->node.userAlias); } diff --git a/tests/system-test/2-query/Timediff.py b/tests/system-test/2-query/Timediff.py index 4e72c07b30..a7366a4007 100644 --- a/tests/system-test/2-query/Timediff.py +++ b/tests/system-test/2-query/Timediff.py @@ -4,6 +4,8 @@ from util.cases import * from util.gettime import * class TDTestCase: + updatecfgDict = {'keepColumnName': 1} + def init(self, conn, logSql, replicaVar=1): self.replicaVar = int(replicaVar) tdLog.debug(f"start to excute {__file__}") @@ -27,14 +29,14 @@ class TDTestCase: self.ctbname = f'{self.dbname}.ctb' self.subtractor = 1 # unit:s def check_tbtype(self,tb_type): - if tb_type.lower() == 'ntb': + if tb_type.lower() == 'ntb': tdSql.query(f'select timediff(ts,{self.subtractor}) from {self.ntbname}') elif tb_type.lower() == 'ctb': tdSql.query(f'select timediff(ts,{self.subtractor}) from {self.ctbname}') elif tb_type.lower() == 'stb': tdSql.query(f'select timediff(ts,{self.subtractor}) from {self.stbname}') def check_tb_type(self,unit,tb_type): - if tb_type.lower() == 'ntb': + if tb_type.lower() == 'ntb': tdSql.query(f'select timediff(ts,{self.subtractor},{unit}) from {self.ntbname}') elif tb_type.lower() == 'ctb': tdSql.query(f'select timediff(ts,{self.subtractor},{unit}) from {self.ctbname}') @@ -43,7 +45,7 @@ class TDTestCase: def data_check(self,date_time,precision,tb_type): for unit in self.time_unit: if (unit.lower() == '1u' and precision.lower() == 'ms') or (unit.lower() == '1b' and precision.lower() == 'us') or (unit.lower() == '1b' and precision.lower() == 'ms'): - if tb_type.lower() == 'ntb': + if tb_type.lower() == 'ntb': tdSql.error(f'select timediff(ts,{self.subtractor},{unit}) from {self.ntbname}') elif tb_type.lower() == 'ctb': tdSql.error(f'select timediff(ts,{self.subtractor},{unit}) from {self.ctbname}') @@ -66,7 +68,7 @@ class TDTestCase: tdSql.checkEqual(tdSql.queryResult[i][0],int(((date_time[i]/1000)-self.subtractor)/60/60)) elif unit.lower() == '1d': for i in range(len(self.ts_str)): - tdSql.checkEqual(tdSql.queryResult[i][0],int(((date_time[i]/1000)-self.subtractor)/60/60/24)) + tdSql.checkEqual(tdSql.queryResult[i][0],int(((date_time[i]/1000)-self.subtractor)/60/60/24)) elif unit.lower() == '1w': for i in range(len(self.ts_str)): tdSql.checkEqual(tdSql.queryResult[i][0],int(((date_time[i]/1000)-self.subtractor)/60/60/24/7)) @@ -97,7 +99,7 @@ class TDTestCase: tdSql.checkEqual(tdSql.queryResult[i][0],int(((date_time[i]/1000)-self.subtractor*1000))) elif unit.lower() == '1u': for i in range(len(self.ts_str)): - tdSql.checkEqual(tdSql.queryResult[i][0],int(((date_time[i])-self.subtractor*1000000))) + tdSql.checkEqual(tdSql.queryResult[i][0],int(((date_time[i])-self.subtractor*1000000))) self.check_tbtype(tb_type) tdSql.checkRows(len(self.ts_str)) for i in range(len(self.ts_str)): @@ -185,8 +187,16 @@ class TDTestCase: elif precision.lower() == 'ns': for i in range(len(self.ts_str)): tdSql.checkEqual(tdSql.queryResult[i][0],int(((date_time[i])-self.subtractor*1000000000))) - + def function_multi_res_param(self): + tdSql.execute(f'drop database if exists {self.dbname}') + tdSql.execute(f'create database {self.dbname}') + tdSql.execute(f'use {self.dbname}') + tdSql.execute(f'create table {self.ntbname} (ts timestamp,c0 int)') + tdSql.execute(f'insert into {self.ntbname} values("2023-01-01 00:00:00",1)') + tdSql.execute(f'insert into {self.ntbname} values("2023-01-01 00:01:00",2)') + tdSql.query(f'select timediff(last(ts), first(ts)) from {self.ntbname}') + tdSql.checkData(0, 0, 60000) @@ -194,7 +204,8 @@ class TDTestCase: self.function_check_ntb() self.function_check_stb() self.function_without_param() - + self.function_multi_res_param() + def stop(self): tdSql.close() tdLog.success(f"{__file__} successfully executed")