fix decimal filtering

This commit is contained in:
wangjiaming0909 2025-03-04 17:22:26 +08:00
parent c2b20e6722
commit 7f1eabf905
2 changed files with 28 additions and 12 deletions

View File

@ -4176,6 +4176,7 @@ int32_t fltSclGetOrCreateColumnRange(SColumnNode *colNode, SArray *colRangeList,
static int32_t fltSclBuildDecimalDatumFromValueNode(SFltSclDatum* datum, SColumnNode* pColNode, SValueNode* valNode) {
datum->type = pColNode->node.resType;
SDataType valDt = valNode->node.resType;
if (valNode->isNull) {
datum->kind = FLT_SCL_DATUM_KIND_NULL;
} else {
@ -4203,6 +4204,7 @@ static int32_t fltSclBuildDecimalDatumFromValueNode(SFltSclDatum* datum, SColumn
case TSDB_DATA_TYPE_FLOAT:
case TSDB_DATA_TYPE_DOUBLE:
pInput = &valNode->datum.d;
valDt.type = TSDB_DATA_TYPE_DOUBLE;
break;
case TSDB_DATA_TYPE_VARCHAR:
pInput = valNode->literal;
@ -4224,7 +4226,7 @@ static int32_t fltSclBuildDecimalDatumFromValueNode(SFltSclDatum* datum, SColumn
datum->kind = FLT_SCL_DATUM_KIND_DECIMAL;
}
if (datum->kind == FLT_SCL_DATUM_KIND_DECIMAL64 || datum->kind == FLT_SCL_DATUM_KIND_DECIMAL) {
int32_t code = convertToDecimal(pInput, &valNode->node.resType, pData, &datum->type);
int32_t code = convertToDecimal(pInput, &valDt, pData, &datum->type);
if (TSDB_CODE_SUCCESS != code) return code; // TODO wjm handle overflow error
valNode->node.resType = datum->type;
}

View File

@ -43,9 +43,9 @@ scalar_convert_err = -2147470768
decimal_insert_validator_test = False
operator_test_round = 2
operator_test_round = 1
tb_insert_rows = 1000
binary_op_with_const_test = True
binary_op_with_const_test = False
binary_op_with_col_test = False
unary_op_test = False
binary_op_in_where_test = True
@ -465,6 +465,8 @@ class DataType:
val = float(str(numpy.float32(val)))
else:
val = float(numpy.float32(val))
elif self.type == TypeEnum.DECIMAL or self.type == TypeEnum.DECIMAL64:
return get_decimal(val, self.scale())
elif isinstance(val, str):
val = val.strip("'")
return val
@ -1658,6 +1660,7 @@ class TDTestCase:
if not binary_op_in_where_test:
return
for expr in exprs:
tdLog.info(f"start to test decimal where filtering with const cols for expr: {expr.format_}")
for col in tb_cols:
if col.name_ == '':
continue
@ -1668,14 +1671,17 @@ class TDTestCase:
continue
const_col.generate_value()
select_expr = expr.generate((const_col, col))
expr.query_col = col
sql = f"select {col} from {dbname}.{tbname} where {select_expr}"
if const_col.type_.is_decimal_type():
expr.query_col = col
else:
expr.query_col = col
sql = f"select {expr.query_col} from {dbname}.{tbname} where {select_expr}"
res = TaosShell().query(sql)
##TODO wjm no need to check len(res) for filtering test, cause we need to check for every row in the table to check if the filtering is working
if len(res) > 0:
expr.check_for_filtering(res[0], tbname)
select_expr = expr.generate((col, const_col))
sql = f"select {col} from {dbname}.{tbname} where {select_expr}"
sql = f"select {expr.query_col} from {dbname}.{tbname} where {select_expr}"
res = TaosShell().query(sql)
if len(res) > 0:
expr.check_for_filtering(res[0], tbname)
@ -1688,19 +1694,28 @@ class TDTestCase:
if not binary_op_in_where_test:
return
for expr in exprs:
tdLog.info(f"start to test decimal where filtering with cols for expr{expr.format_}")
for col in tb_cols:
if col.name_ == '':
continue
for col2 in tb_cols:
if col2.name_ == '':
continue
if expr.should_skip_for_decimal([col, col2]):
continue
select_expr = expr.generate((col, col2))
sql = f"select {col} from {dbname}.{tbname} where {select_expr}"
if col.type_.is_decimal_type():
expr.query_col = col
else:
expr.query_col = col2
sql = f"select {expr.query_col} from {dbname}.{tbname} where {select_expr}"
res = TaosShell().query(sql)
if len(res) > 0:
expr.check_for_filtering(res[0], tbname)
else:
tdLog.info(f"sql: {sql} got no output")
select_expr = expr.generate((col2, col))
sql = f"select {col} from {dbname}.{tbname} where {select_expr}"
sql = f"select {expr.query_col} from {dbname}.{tbname} where {select_expr}"
res = TaosShell().query(sql)
if len(res) > 0:
expr.check_for_filtering(res[0], tbname)
@ -1708,6 +1723,7 @@ class TDTestCase:
tdLog.info(f"sql: {sql} got no output")
def test_query_decimal_where_clause(self):
tdLog.info("start to test decimal where filtering")
binary_compare_ops = DecimalBinaryOperator.get_all_filtering_binary_compare_ops()
const_cols = Column.get_decimal_oper_const_cols()
for i in range(operator_test_round):
@ -1725,14 +1741,12 @@ class TDTestCase:
self.norm_table_name,
self.norm_tb_columns,
binary_compare_ops)
## test filtering with decimal exprs
## 1. dec op const col
## 2. dec op dec
## TODO wjm
## 3. (dec op const col) op const col
## 4. (dec op dec) op const col
## 5. (dec op const col) op dec
## 6. (dec op dec) op dec
pass
def test_query_decimal_order_clause(self):
pass