From 9ed049fd0e0c7438b3eb058d2f704d6963a4e914 Mon Sep 17 00:00:00 2001 From: wangjiaming0909 Date: Mon, 3 Mar 2025 18:06:02 +0800 Subject: [PATCH] test decimal col filtering --- source/libs/decimal/src/decimal.c | 37 ++++++-- source/libs/decimal/test/decimalTest.cpp | 56 +++++++++++- source/libs/scalar/src/sclvector.c | 2 + tests/system-test/2-query/decimal.py | 112 +++++++++++++++++++++-- 4 files changed, 188 insertions(+), 19 deletions(-) diff --git a/source/libs/decimal/src/decimal.c b/source/libs/decimal/src/decimal.c index e3657ccbea..acea6e5bf0 100644 --- a/source/libs/decimal/src/decimal.c +++ b/source/libs/decimal/src/decimal.c @@ -1128,17 +1128,40 @@ int32_t decimalOp(EOperatorType op, const SDataType* pLeftT, const SDataType* pR // There is no need to do type conversions, we assume that pLeftT and pRightT are all decimal128 types. bool decimalCompare(EOperatorType op, const SDecimalCompareCtx* pLeft, const SDecimalCompareCtx* pRight) { bool ret = false; - uint8_t pLeftPrec = 0, pLeftScale = 0, pRightPrec = 0, pRightScale = 0; - decimalFromTypeMod(pLeft->typeMod, &pLeftPrec, &pLeftScale); - decimalFromTypeMod(pRight->typeMod, &pRightPrec, &pRightScale); - int32_t deltaScale = pLeftScale - pRightScale; + uint8_t leftPrec = 0, leftScale = 0, rightPrec = 0, rightScale = 0; + decimalFromTypeMod(pLeft->typeMod, &leftPrec, &leftScale); + decimalFromTypeMod(pRight->typeMod, &rightPrec, &rightScale); + int32_t deltaScale = leftScale - rightScale; Decimal pLeftDec = *(Decimal*)pLeft->pData, pRightDec = *(Decimal*)pRight->pData; if (deltaScale != 0) { - bool needInt256 = (deltaScale < 0 && pLeftPrec - deltaScale > TSDB_DECIMAL_MAX_PRECISION) || - (pRightPrec + deltaScale > TSDB_DECIMAL_MAX_PRECISION); + bool needInt256 = (deltaScale < 0 && leftPrec - deltaScale > TSDB_DECIMAL_MAX_PRECISION) || + (rightPrec + deltaScale > TSDB_DECIMAL_MAX_PRECISION); if (needInt256) { - // TODO wjm impl it + Int256 x = {0}, y = {0}; + makeInt256FromDecimal128(&x, &pLeftDec); + makeInt256FromDecimal128(&y, &pRightDec); + if (leftScale < rightScale) { + x = int256ScaleBy(&x, rightScale - leftScale); + } else { + y = int256ScaleBy(&y, leftScale - rightScale); + } + switch (op) { + case OP_TYPE_GREATER_THAN: + return int256Gt(&x, &y); + case OP_TYPE_GREATER_EQUAL: + return !int256Lt(&x, &y); + case OP_TYPE_LOWER_THAN: + return int256Lt(&x, &y); + case OP_TYPE_LOWER_EQUAL: + return !int256Gt(&x, &y); + case OP_TYPE_EQUAL: + return int256Eq(&x, &y); + case OP_TYPE_NOT_EQUAL: + return !int256Eq(&x, &y); + default: + break; + } return false; } else { if (deltaScale < 0) { diff --git a/source/libs/decimal/test/decimalTest.cpp b/source/libs/decimal/test/decimalTest.cpp index 6b51290978..0e3788c4c3 100644 --- a/source/libs/decimal/test/decimalTest.cpp +++ b/source/libs/decimal/test/decimalTest.cpp @@ -198,6 +198,7 @@ class Numeric { uint8_t prec() const { return prec_; } uint8_t scale() const { return scale_; } const Type& dec() const { return dec_; } + STypeMod get_type_mod() const { return decimalCalcTypeMod(prec(), scale()); } template Numeric& binaryOp(const Numeric& r, EOperatorType op) { @@ -289,10 +290,6 @@ class Numeric { return binaryOp(r, OP_TYPE_ADD); } - template - bool operator==(const Numeric& r) { - return binaryOp(r, OP_TYPE_EQUAL); - } std::string toString() const { char buf[64] = {0}; int32_t code = decimalToStr(&dec_, NumericType::dataType, prec(), scale(), buf, 64); @@ -406,6 +403,45 @@ class Numeric { if (code != 0) throw std::runtime_error(tstrerror(code)); return *this; } + + template + Numeric(const Numeric& num2) { + Numeric(); + *this = num2; + } + + template + Numeric& operator=(const Numeric& num2) { + static_assert(BitNum2 == 64 || BitNum2 == 128, "Only support decimal128/decimal64"); + SDataType inputDt = { + .type = num2.type().type, .precision = num2.prec(), .scale = num2.scale(), .bytes = num2.type().bytes}; + SDataType outputDt = {.type = NumericType::dataType, + .precision = NumericType::maxPrec, + .scale = num2.scale(), + .bytes = NumericType::bytes}; + int32_t code = convertToDecimal(&num2.dec(), &inputDt, &dec_, &outputDt); + if (code == TSDB_CODE_DECIMAL_OVERFLOW) throw std::overflow_error(tstrerror(code)); + if (code != 0) throw std::runtime_error(tstrerror(code)); + prec_ = outputDt.precision; + scale_ = outputDt.scale; + return *this; + } +#define DEFINE_COMPARE_OP(op, op_type) \ + template \ + bool operator op(const T& t) { \ + Numeric<128> lDec = *this, rDec = *this; \ + rDec = t; \ + SDecimalCompareCtx l = {(void*)&lDec.dec(), TSDB_DATA_TYPE_DECIMAL, lDec.get_type_mod()}, \ + r = {(void*)&rDec.dec(), TSDB_DATA_TYPE_DECIMAL, rDec.get_type_mod()}; \ + return decimalCompare(op_type, &l, &r); \ + } + + DEFINE_COMPARE_OP(>, OP_TYPE_GREATER_THAN); + DEFINE_COMPARE_OP(>=, OP_TYPE_GREATER_EQUAL); + DEFINE_COMPARE_OP(<, OP_TYPE_LOWER_THAN); + DEFINE_COMPARE_OP(<=, OP_TYPE_LOWER_EQUAL); + DEFINE_COMPARE_OP(==, OP_TYPE_EQUAL); + DEFINE_COMPARE_OP(!=, OP_TYPE_NOT_EQUAL); }; template @@ -1237,6 +1273,18 @@ TEST(decimal_all, ret_type_load_from_file) { ASSERT_EQ(total_lines, 3034205); } +TEST(decimal_all, test_decimal_compare) { + Numeric<64> dec64 = {10, 2, "123.23"}; + Numeric<64> dec64_2 = {11, 10, "1.23"}; + ASSERT_FALSE(dec64_2 > dec64); + dec64 = "10123456.23"; + ASSERT_FALSE(dec64_2 > dec64); + ASSERT_TRUE(dec64 > dec64_2); + ASSERT_TRUE(dec64_2 < 100); + Numeric<128> dec128 = {38, 10, "1.23"}; + ASSERT_TRUE(dec128 == dec64_2); +} + TEST(decimal_all, ret_type_for_non_decimal_types) { std::vector non_decimal_types; SDataType decimal_type = {TSDB_DATA_TYPE_DECIMAL64, 10, 2, 8}; diff --git a/source/libs/scalar/src/sclvector.c b/source/libs/scalar/src/sclvector.c index 0744895424..e2f46d37d2 100644 --- a/source/libs/scalar/src/sclvector.c +++ b/source/libs/scalar/src/sclvector.c @@ -1104,6 +1104,8 @@ STypeMod getConvertTypeMod(int32_t type, const SColumnInfo* pCol1, const SColumn return decimalCalcTypeMod(GET_DEICMAL_MAX_PRECISION(type), pCol1->scale); } else if (pCol2 && IS_DECIMAL_TYPE(pCol2->type) && !IS_DECIMAL_TYPE(pCol1->type)) { return decimalCalcTypeMod(GET_DEICMAL_MAX_PRECISION(type), pCol2->scale); + } else if (IS_DECIMAL_TYPE(pCol1->type) && pCol2 && IS_DECIMAL_TYPE(pCol2->type)) { + return decimalCalcTypeMod(GET_DEICMAL_MAX_PRECISION(type), MAX(pCol1->scale, pCol2->scale)); } else { return 0; } diff --git a/tests/system-test/2-query/decimal.py b/tests/system-test/2-query/decimal.py index 65f22270f4..8637f5d858 100644 --- a/tests/system-test/2-query/decimal.py +++ b/tests/system-test/2-query/decimal.py @@ -1,6 +1,7 @@ import math from random import randrange import random +from re import I import time import threading import secrets @@ -25,11 +26,12 @@ scalar_convert_err = -2147470768 decimal_insert_validator_test = False -operator_test_round = 1 +operator_test_round = 10 tb_insert_rows = 1000 -binary_op_with_const_test = True -binary_op_with_col_test = True -unary_op_test = True +binary_op_with_const_test = False +binary_op_with_col_test = False +unary_op_test = False +binary_op_in_where_test = True class DecimalTypeGeneratorConfig: def __init__(self): @@ -191,12 +193,16 @@ class DecimalColumnExpr: self.executor_ = executor self.params_ = () self.res_type_: DataType = None + self.query_col: Column = None def __str__(self): return f"({self.format_})".format(*self.params_) def execute(self, params): return self.executor_(self, params) + + def get_query_col_val(self, tbname, i): + return self.query_col.get_val_for_execute(tbname, i) def get_val(self, tbname: str, idx: int): params = () @@ -215,6 +221,31 @@ class DecimalColumnExpr: def should_skip_for_decimal(self, cols: list): return False + + def check_for_filtering(self, query_col_res: List, tbname: str): + j: int = -1 + for i in range(len(query_col_res)): + j += 1 + v_from_query = query_col_res[i] + while True: + params = () + for p in self.params_: + if isinstance(p, Column) or isinstance(p, DecimalColumnExpr): + p = p.get_val_for_execute(tbname, j) + params = params + (p,) + v_from_calc_in_py = self.execute(params) + + if not v_from_calc_in_py: + j += 1 + continue + else: + break + dec_from_query = Decimal(v_from_query) + dec_from_calc = self.get_query_col_val(tbname, j) + if dec_from_query != dec_from_calc: + tdLog.exit(f"filter with {self} failed, query got: {dec_from_query}, expect {dec_from_calc}, param: {params}") + else: + tdLog.info(f"filter with {self} succ, query got: {dec_from_query}, expect {dec_from_calc}, param: {params}") def check(self, query_col_res: List, tbname: str): for i in range(len(query_col_res)): @@ -776,10 +807,6 @@ class DecimalBinaryOperator(DecimalColumnExpr): def should_skip_for_decimal(self, cols: list): left_col = cols[0] right_col = cols[1] - if not left_col.is_constant_col() and left_col.name_ == '': - return True - if not right_col.is_constant_col() and right_col.name_ == '': - return True if not left_col.type_.is_decimal_type() and not right_col.type_.is_decimal_type(): return True if self.op_ != "%": @@ -933,6 +960,10 @@ class DecimalBinaryOperator(DecimalColumnExpr): return float(left) == float(right) else: return Decimal(left) == Decimal(right) + + def execute_eq_filtering(self, params): + if self.execute_eq(params): + return True def execute_ne(self, params): if DecimalBinaryOperator.check_null(params): @@ -994,6 +1025,17 @@ class DecimalBinaryOperator(DecimalColumnExpr): DecimalBinaryOperator(" {0} >= {1} ", DecimalBinaryOperator.execute_ge, ">="), DecimalBinaryOperator(" {0} <= {1} ", DecimalBinaryOperator.execute_le, "<="), ] + + @staticmethod + def get_all_filtering_binary_compare_ops() -> List[DecimalColumnExpr]: + return [ + DecimalBinaryOperator(" {0} == {1} ", DecimalBinaryOperator.execute_eq, "=="), + DecimalBinaryOperator(" {0} != {1} ", DecimalBinaryOperator.execute_ne, "!="), + DecimalBinaryOperator(" {0} > {1} ", DecimalBinaryOperator.execute_gt, ">"), + DecimalBinaryOperator(" {0} < {1} ", DecimalBinaryOperator.execute_lt, "<"), + DecimalBinaryOperator(" {0} >= {1} ", DecimalBinaryOperator.execute_ge, ">="), + DecimalBinaryOperator(" {0} <= {1} ", DecimalBinaryOperator.execute_le, "<="), + ] def execute(self, params): return super().execute(params) @@ -1406,6 +1448,41 @@ class TDTestCase: else: tdLog.info(f"sql: {sql} got no output") + def check_decimal_where_with_binary_expr_with_const_col_results( + self, + dbname, + tbname, + tb_cols: List[Column], + constant_cols: List[Column], + exprs: List[DecimalColumnExpr], + ): + if not binary_op_in_where_test: + return + for expr in exprs: + for col in tb_cols: + if col.name_ == '': + continue + left_is_decimal = col.type_.is_decimal_type() + for const_col in constant_cols: + right_is_decimal = const_col.type_.is_decimal_type() + if expr.should_skip_for_decimal([col, const_col]): + continue + const_col.generate_value() + select_expr = expr.generate((const_col, col)) + expr.query_col = col + sql = f"select {col} from {dbname}.{tbname} where {select_expr}" + res = TaosShell().query(sql) + ##TODO wjm no need to check len(res) for filtering test, cause we need to check for every row in the table to check if the filtering is working + if len(res) > 0: + expr.check_for_filtering(res[0], tbname) + select_expr = expr.generate((col, const_col)) + sql = f"select {col} from {dbname}.{tbname} where {select_expr}" + res = TaosShell().query(sql) + if len(res) > 0: + expr.check_for_filtering(res[0], tbname) + else: + tdLog.info(f"sql: {sql} got no output") + def check_decimal_binary_expr_with_const_col_results( self, dbname, @@ -1530,6 +1607,8 @@ class TDTestCase: self.norm_tb_columns, unary_operators,) + self.test_query_decimal_where_clause() + def test_decimal_functions(self): self.test_decimal_last_first_func() funcs = ["max", "min", "sum", "avg", "count", "first", "last", "cast"] @@ -1541,6 +1620,23 @@ class TDTestCase: pass def test_query_decimal_where_clause(self): + binary_compare_ops = DecimalBinaryOperator.get_all_filtering_binary_compare_ops() + const_cols = Column.get_decimal_oper_const_cols() + for i in range(operator_test_round): + self.check_decimal_where_with_binary_expr_with_const_col_results( + self.db_name, + self.norm_table_name, + self.norm_tb_columns, + const_cols, + binary_compare_ops, + ) + ## test filtering with decimal exprs + ## 1. dec op const col + ## 2. dec op dec + ## 3. (dec op const col) op const col + ## 4. (dec op dec) op const col + ## 5. (dec op const col) op dec + ## 6. (dec op dec) op dec pass def test_query_decimal_order_clause(self):