remove duplicate group by cols

This commit is contained in:
wangjiaming0909 2024-11-15 17:52:06 +08:00
parent 59c755b7b1
commit 09fe252a31
6 changed files with 61 additions and 5 deletions

View File

@ -174,6 +174,7 @@ char* nodesGetNameFromColumnNode(SNode* pNode);
int32_t nodesGetOutputNumFromSlotList(SNodeList* pSlots);
void nodesSortList(SNodeList** pList, int32_t (*)(SNode* pNode1, SNode* pNode2));
void destroyFuncParam(void* pFuncStruct);
int32_t nodesListDeduplicate(SNodeList** pList);
#ifdef __cplusplus
}

View File

@ -153,6 +153,12 @@ static bool caseWhenNodeEqual(const SCaseWhenNode* a, const SCaseWhenNode* b) {
return true;
}
static bool groupingSetNodeEqual(const SGroupingSetNode* a, const SGroupingSetNode* b) {
COMPARE_SCALAR_FIELD(groupingSetType);
COMPARE_NODE_LIST_FIELD(pParameterList);
return true;
}
bool nodesEqualNode(const SNode* a, const SNode* b) {
if (a == b) {
return true;
@ -181,10 +187,11 @@ bool nodesEqualNode(const SNode* a, const SNode* b) {
return whenThenNodeEqual((const SWhenThenNode*)a, (const SWhenThenNode*)b);
case QUERY_NODE_CASE_WHEN:
return caseWhenNodeEqual((const SCaseWhenNode*)a, (const SCaseWhenNode*)b);
case QUERY_NODE_GROUPING_SET:
return groupingSetNodeEqual((const SGroupingSetNode*)a, (const SGroupingSetNode*)b);
case QUERY_NODE_REAL_TABLE:
case QUERY_NODE_TEMP_TABLE:
case QUERY_NODE_JOIN_TABLE:
case QUERY_NODE_GROUPING_SET:
case QUERY_NODE_ORDER_BY_EXPR:
case QUERY_NODE_LIMIT:
return false;

View File

@ -2948,3 +2948,37 @@ void nodesSortList(SNodeList** pList, int32_t (*comp)(SNode* pNode1, SNode* pNod
inSize *= 2;
}
}
static SNode* nodesListFindNode(SNodeList* pList, SNode* pNode) {
SNode* pFound = NULL;
FOREACH(pFound, pList) {
if (nodesEqualNode(pFound, pNode)) {
break;
}
}
return pFound;
}
int32_t nodesListDeduplicate(SNodeList** ppList) {
if (!ppList || LIST_LENGTH(*ppList) == 0) return TSDB_CODE_SUCCESS;
SNodeList* pTmp = NULL;
int32_t code = nodesMakeList(&pTmp);
if (TSDB_CODE_SUCCESS == code) {
SNode* pNode = NULL;
FOREACH(pNode, *ppList) {
SNode* pFound = nodesListFindNode(pTmp, pNode);
if (NULL == pFound) {
code = nodesCloneNode(pNode, &pFound);
if (TSDB_CODE_SUCCESS == code) code = nodesListStrictAppend(pTmp, pFound);
if (TSDB_CODE_SUCCESS != code) break;
}
}
}
if (TSDB_CODE_SUCCESS == code) {
nodesDestroyList(*ppList);
*ppList = pTmp;
} else {
nodesDestroyList(pTmp);
}
return code;
}

View File

@ -5531,7 +5531,6 @@ static int32_t translateGroupByList(STranslateContext* pCxt, SSelectStmt* pSelec
SReplaceGroupByAliasCxt cxt = {
.pTranslateCxt = pCxt, .pProjectionList = pSelect->pProjectionList};
nodesRewriteExprsPostOrder(pSelect->pGroupByList, translateGroupPartitionByImpl, &cxt);
return pCxt->errCode;
}
@ -5543,7 +5542,6 @@ static int32_t translatePartitionByList(STranslateContext* pCxt, SSelectStmt* pS
SReplaceGroupByAliasCxt cxt = {
.pTranslateCxt = pCxt, .pProjectionList = pSelect->pProjectionList};
nodesRewriteExprsPostOrder(pSelect->pPartitionByList, translateGroupPartitionByImpl, &cxt);
return pCxt->errCode;
}

View File

@ -838,8 +838,11 @@ static int32_t createAggLogicNode(SLogicPlanContext* pCxt, SSelectStmt* pSelect,
}
if (NULL != pSelect->pGroupByList) {
pAgg->pGroupKeys = NULL;
code = nodesCloneList(pSelect->pGroupByList, &pAgg->pGroupKeys);
code = nodesListDeduplicate(&pSelect->pGroupByList);
if (TSDB_CODE_SUCCESS == code) {
pAgg->pGroupKeys = NULL;
code = nodesCloneList(pSelect->pGroupByList, &pAgg->pGroupKeys);
}
}
// rewrite the expression in subsequent clauses

View File

@ -437,6 +437,18 @@ class TDTestCase:
tdSql.checkRows(10)
tdSql.query(f"select const_col from (select 1 as const_col, count(c1) from {self.dbname}.{self.stable} t group by c1) partition by 1")
tdSql.checkRows(10)
def test_TD_32883(self):
sql = "select avg(c1), t9 from stb group by t9,t9, tbname"
tdSql.query(sql, queryTimes=1)
tdSql.checkRows(5)
sql = "select avg(c1), t10 from stb group by t10,t10, tbname"
tdSql.query(sql, queryTimes=1)
tdSql.checkRows(5)
sql = "select avg(c1), t10 from stb partition by t10,t10, tbname"
tdSql.query(sql, queryTimes=1)
tdSql.checkRows(5)
def run(self):
tdSql.prepare()
self.prepare_db()
@ -470,6 +482,7 @@ class TDTestCase:
self.test_event_window(nonempty_tb_num)
self.test_TS5567()
self.test_TD_32883()
## test old version before changed
# self.test_groupby('group', 0, 0)