test decimal col filtering

This commit is contained in:
wangjiaming0909 2025-03-03 18:06:02 +08:00
parent 8dc25a53af
commit 9ed049fd0e
4 changed files with 188 additions and 19 deletions

View File

@ -1128,17 +1128,40 @@ int32_t decimalOp(EOperatorType op, const SDataType* pLeftT, const SDataType* pR
// There is no need to do type conversions, we assume that pLeftT and pRightT are all decimal128 types.
bool decimalCompare(EOperatorType op, const SDecimalCompareCtx* pLeft, const SDecimalCompareCtx* pRight) {
bool ret = false;
uint8_t pLeftPrec = 0, pLeftScale = 0, pRightPrec = 0, pRightScale = 0;
decimalFromTypeMod(pLeft->typeMod, &pLeftPrec, &pLeftScale);
decimalFromTypeMod(pRight->typeMod, &pRightPrec, &pRightScale);
int32_t deltaScale = pLeftScale - pRightScale;
uint8_t leftPrec = 0, leftScale = 0, rightPrec = 0, rightScale = 0;
decimalFromTypeMod(pLeft->typeMod, &leftPrec, &leftScale);
decimalFromTypeMod(pRight->typeMod, &rightPrec, &rightScale);
int32_t deltaScale = leftScale - rightScale;
Decimal pLeftDec = *(Decimal*)pLeft->pData, pRightDec = *(Decimal*)pRight->pData;
if (deltaScale != 0) {
bool needInt256 = (deltaScale < 0 && pLeftPrec - deltaScale > TSDB_DECIMAL_MAX_PRECISION) ||
(pRightPrec + deltaScale > TSDB_DECIMAL_MAX_PRECISION);
bool needInt256 = (deltaScale < 0 && leftPrec - deltaScale > TSDB_DECIMAL_MAX_PRECISION) ||
(rightPrec + deltaScale > TSDB_DECIMAL_MAX_PRECISION);
if (needInt256) {
// TODO wjm impl it
Int256 x = {0}, y = {0};
makeInt256FromDecimal128(&x, &pLeftDec);
makeInt256FromDecimal128(&y, &pRightDec);
if (leftScale < rightScale) {
x = int256ScaleBy(&x, rightScale - leftScale);
} else {
y = int256ScaleBy(&y, leftScale - rightScale);
}
switch (op) {
case OP_TYPE_GREATER_THAN:
return int256Gt(&x, &y);
case OP_TYPE_GREATER_EQUAL:
return !int256Lt(&x, &y);
case OP_TYPE_LOWER_THAN:
return int256Lt(&x, &y);
case OP_TYPE_LOWER_EQUAL:
return !int256Gt(&x, &y);
case OP_TYPE_EQUAL:
return int256Eq(&x, &y);
case OP_TYPE_NOT_EQUAL:
return !int256Eq(&x, &y);
default:
break;
}
return false;
} else {
if (deltaScale < 0) {

View File

@ -198,6 +198,7 @@ class Numeric {
uint8_t prec() const { return prec_; }
uint8_t scale() const { return scale_; }
const Type& dec() const { return dec_; }
// Returns the STypeMod that decimalCalcTypeMod() computes from this value's precision and scale.
STypeMod get_type_mod() const { return decimalCalcTypeMod(prec(), scale()); }
template <int BitNum2>
Numeric& binaryOp(const Numeric<BitNum2>& r, EOperatorType op) {
@ -289,10 +290,6 @@ class Numeric {
return binaryOp(r, OP_TYPE_ADD);
}
template <int BitNum2>
bool operator==(const Numeric<BitNum2>& r) {
return binaryOp(r, OP_TYPE_EQUAL);
}
std::string toString() const {
char buf[64] = {0};
int32_t code = decimalToStr(&dec_, NumericType<BitNum>::dataType, prec(), scale(), buf, 64);
@ -406,6 +403,45 @@ class Numeric {
if (code != 0) throw std::runtime_error(tstrerror(code));
return *this;
}
template <int BitNum2>
Numeric(const Numeric<BitNum2>& num2) {
Numeric();
*this = num2;
}
// Converting assignment from a Numeric of another bit width (64 or 128 only).
// The value is converted with convertToDecimal() into this class's decimal type at its
// maximum precision, keeping the source scale. prec_/scale_ are updated only after a
// successful conversion.
// Throws std::overflow_error when the conversion reports TSDB_CODE_DECIMAL_OVERFLOW and
// std::runtime_error for any other non-zero code.
template <int BitNum2>
Numeric& operator=(const Numeric<BitNum2>& num2) {
  static_assert(BitNum2 == 64 || BitNum2 == 128, "Only support decimal128/decimal64");
  // Source descriptor mirrors the right-hand operand.
  SDataType srcType = {
      .type = num2.type().type, .precision = num2.prec(), .scale = num2.scale(), .bytes = num2.type().bytes};
  // Destination descriptor: own data type, max precision, source scale.
  SDataType dstType = {.type = NumericType<BitNum>::dataType,
                       .precision = NumericType<BitNum>::maxPrec,
                       .scale = num2.scale(),
                       .bytes = NumericType<BitNum>::bytes};
  const int32_t rc = convertToDecimal(&num2.dec(), &srcType, &dec_, &dstType);
  if (rc == TSDB_CODE_DECIMAL_OVERFLOW) {
    throw std::overflow_error(tstrerror(rc));
  }
  if (rc != 0) {
    throw std::runtime_error(tstrerror(rc));
  }
  prec_ = dstType.precision;
  scale_ = dstType.scale;
  return *this;
}
#define DEFINE_COMPARE_OP(op, op_type) \
template <typename T> \
bool operator op(const T& t) { \
Numeric<128> lDec = *this, rDec = *this; \
rDec = t; \
SDecimalCompareCtx l = {(void*)&lDec.dec(), TSDB_DATA_TYPE_DECIMAL, lDec.get_type_mod()}, \
r = {(void*)&rDec.dec(), TSDB_DATA_TYPE_DECIMAL, rDec.get_type_mod()}; \
return decimalCompare(op_type, &l, &r); \
}
DEFINE_COMPARE_OP(>, OP_TYPE_GREATER_THAN);
DEFINE_COMPARE_OP(>=, OP_TYPE_GREATER_EQUAL);
DEFINE_COMPARE_OP(<, OP_TYPE_LOWER_THAN);
DEFINE_COMPARE_OP(<=, OP_TYPE_LOWER_EQUAL);
DEFINE_COMPARE_OP(==, OP_TYPE_EQUAL);
DEFINE_COMPARE_OP(!=, OP_TYPE_NOT_EQUAL);
};
template <int BitNum>
@ -1237,6 +1273,18 @@ TEST(decimal_all, ret_type_load_from_file) {
ASSERT_EQ(total_lines, 3034205);
}
// Checks Numeric's comparison operators between values of different scale, against an
// integer literal, and across the 64-bit / 128-bit widths.
TEST(decimal_all, test_decimal_compare) {
  Numeric<64> coarse = {10, 2, "123.23"};  // precision 10, scale 2
  Numeric<64> fine = {11, 10, "1.23"};     // precision 11, scale 10

  // 1.23 is not greater than 123.23.
  ASSERT_FALSE(fine > coarse);

  // Same comparison after enlarging the scale-2 operand.
  coarse = "10123456.23";
  ASSERT_FALSE(fine > coarse);
  ASSERT_TRUE(coarse > fine);

  // Decimal against a plain integer.
  ASSERT_TRUE(fine < 100);

  // Equal values stored at different bit widths compare equal.
  Numeric<128> wide = {38, 10, "1.23"};
  ASSERT_TRUE(wide == fine);
}
TEST(decimal_all, ret_type_for_non_decimal_types) {
std::vector<DecimalRetTypeCheckContent> non_decimal_types;
SDataType decimal_type = {TSDB_DATA_TYPE_DECIMAL64, 10, 2, 8};

View File

@ -1104,6 +1104,8 @@ STypeMod getConvertTypeMod(int32_t type, const SColumnInfo* pCol1, const SColumn
return decimalCalcTypeMod(GET_DEICMAL_MAX_PRECISION(type), pCol1->scale);
} else if (pCol2 && IS_DECIMAL_TYPE(pCol2->type) && !IS_DECIMAL_TYPE(pCol1->type)) {
return decimalCalcTypeMod(GET_DEICMAL_MAX_PRECISION(type), pCol2->scale);
} else if (IS_DECIMAL_TYPE(pCol1->type) && pCol2 && IS_DECIMAL_TYPE(pCol2->type)) {
return decimalCalcTypeMod(GET_DEICMAL_MAX_PRECISION(type), MAX(pCol1->scale, pCol2->scale));
} else {
return 0;
}

View File

@ -1,6 +1,7 @@
import math
from random import randrange
import random
from re import I
import time
import threading
import secrets
@ -25,11 +26,12 @@ scalar_convert_err = -2147470768
decimal_insert_validator_test = False
operator_test_round = 1
operator_test_round = 10
tb_insert_rows = 1000
binary_op_with_const_test = True
binary_op_with_col_test = True
unary_op_test = True
binary_op_with_const_test = False
binary_op_with_col_test = False
unary_op_test = False
binary_op_in_where_test = True
class DecimalTypeGeneratorConfig:
def __init__(self):
@ -191,12 +193,16 @@ class DecimalColumnExpr:
self.executor_ = executor
self.params_ = ()
self.res_type_: DataType = None
self.query_col: Column = None
def __str__(self):
    # Wrap the format template in parentheses, then substitute the current params_ into it.
    return f"({self.format_})".format(*self.params_)
def execute(self, params):
    # Delegate to the executor callable stored on this expression; it receives the
    # expression itself as its first argument, followed by the parameter values.
    return self.executor_(self, params)
def get_query_col_val(self, tbname, i):
    # Return what query_col.get_val_for_execute yields for (tbname, i).
    # query_col must have been assigned beforehand.
    return self.query_col.get_val_for_execute(tbname, i)
def get_val(self, tbname: str, idx: int):
params = ()
@ -215,6 +221,31 @@ class DecimalColumnExpr:
def should_skip_for_decimal(self, cols: list):
    # This implementation never skips: it returns False regardless of cols.
    return False
def check_for_filtering(self, query_col_res: List, tbname: str):
    """Verify a filtered query result against the locally evaluated expression.

    Walks the values returned by the query in order. For each one, advances a row
    cursor until this expression evaluates truthy for that row (rows evaluating
    falsy are taken to have been filtered out), then compares the queried value
    with the query column's value at that row. A mismatch is reported through
    tdLog.exit, a match through tdLog.info.

    :param query_col_res: values of the queried column returned by the filter query.
    :param tbname: table name handed to get_val_for_execute for value lookup.
    """
    row = -1
    for v_from_query in query_col_res:
        row += 1
        # Skip forward to the next row for which the expression holds.
        while True:
            params = tuple(
                p.get_val_for_execute(tbname, row) if isinstance(p, (Column, DecimalColumnExpr)) else p
                for p in self.params_
            )
            if self.execute(params):
                break
            row += 1
        dec_from_query = Decimal(v_from_query)
        dec_from_calc = self.get_query_col_val(tbname, row)
        mismatch = dec_from_query != dec_from_calc
        if mismatch:
            tdLog.exit(f"filter with {self} failed, query got: {dec_from_query}, expect {dec_from_calc}, param: {params}")
        else:
            tdLog.info(f"filter with {self} succ, query got: {dec_from_query}, expect {dec_from_calc}, param: {params}")
def check(self, query_col_res: List, tbname: str):
for i in range(len(query_col_res)):
@ -776,10 +807,6 @@ class DecimalBinaryOperator(DecimalColumnExpr):
def should_skip_for_decimal(self, cols: list):
left_col = cols[0]
right_col = cols[1]
if not left_col.is_constant_col() and left_col.name_ == '':
return True
if not right_col.is_constant_col() and right_col.name_ == '':
return True
if not left_col.type_.is_decimal_type() and not right_col.type_.is_decimal_type():
return True
if self.op_ != "%":
@ -933,6 +960,10 @@ class DecimalBinaryOperator(DecimalColumnExpr):
return float(left) == float(right)
else:
return Decimal(left) == Decimal(right)
def execute_eq_filtering(self, params):
    """Return True when execute_eq(params) is truthy, otherwise False.

    Always returns an explicit bool rather than falling off the end with an
    implicit None when the comparison does not hold.
    """
    return bool(self.execute_eq(params))
def execute_ne(self, params):
if DecimalBinaryOperator.check_null(params):
@ -994,6 +1025,17 @@ class DecimalBinaryOperator(DecimalColumnExpr):
DecimalBinaryOperator(" {0} >= {1} ", DecimalBinaryOperator.execute_ge, ">="),
DecimalBinaryOperator(" {0} <= {1} ", DecimalBinaryOperator.execute_le, "<="),
]
@staticmethod
def get_all_filtering_binary_compare_ops() -> List[DecimalColumnExpr]:
    """Build the comparison operators used by the where-clause filtering tests.

    Returns one DecimalBinaryOperator per operator, in the order
    ==, !=, >, <, >=, <=, each with the format string " {0} <op> {1} ".
    """
    op_table = (
        ("==", DecimalBinaryOperator.execute_eq),
        ("!=", DecimalBinaryOperator.execute_ne),
        (">", DecimalBinaryOperator.execute_gt),
        ("<", DecimalBinaryOperator.execute_lt),
        (">=", DecimalBinaryOperator.execute_ge),
        ("<=", DecimalBinaryOperator.execute_le),
    )
    return [
        DecimalBinaryOperator(f" {{0}} {symbol} {{1}} ", executor, symbol)
        for symbol, executor in op_table
    ]
def execute(self, params):
    # No extra behaviour here: forward to the parent class's execute().
    return super().execute(params)
@ -1406,6 +1448,41 @@ class TDTestCase:
else:
tdLog.info(f"sql: {sql} got no output")
def check_decimal_where_with_binary_expr_with_const_col_results(
    self,
    dbname,
    tbname,
    tb_cols: List[Column],
    constant_cols: List[Column],
    exprs: List[DecimalColumnExpr],
):
    """Run where-clause filtering checks for "<column> op <constant>" expressions.

    For every expression, every named table column and every constant column
    (unless the expression asks to skip the pair), a fresh constant value is
    generated and the column is queried twice: once with the constant as the
    left operand and once as the right operand. Non-empty results are
    validated with expr.check_for_filtering; empty results are logged.

    Does nothing when the module-level switch binary_op_in_where_test is off.

    :param dbname: database name used in the generated SQL.
    :param tbname: table name used in the generated SQL and for value lookup.
    :param tb_cols: table columns to filter on; columns with an empty name are skipped.
    :param constant_cols: constant columns used as the other operand.
    :param exprs: comparison expressions to exercise.
    """
    if not binary_op_in_where_test:
        return
    for expr in exprs:
        for col in tb_cols:
            if col.name_ == '':
                continue
            for const_col in constant_cols:
                if expr.should_skip_for_decimal([col, const_col]):
                    continue
                const_col.generate_value()
                # Exercise both operand orders with the same generated constant.
                for operands in ((const_col, col), (col, const_col)):
                    select_expr = expr.generate(operands)
                    expr.query_col = col
                    sql = f"select {col} from {dbname}.{tbname} where {select_expr}"
                    res = TaosShell().query(sql)
                    ## TODO wjm no need to check len(res) for filtering test, cause we need to check
                    ## for every row in the table to check if the filtering is working
                    if len(res) > 0:
                        expr.check_for_filtering(res[0], tbname)
                    else:
                        tdLog.info(f"sql: {sql} got no output")
def check_decimal_binary_expr_with_const_col_results(
self,
dbname,
@ -1530,6 +1607,8 @@ class TDTestCase:
self.norm_tb_columns,
unary_operators,)
self.test_query_decimal_where_clause()
def test_decimal_functions(self):
self.test_decimal_last_first_func()
funcs = ["max", "min", "sum", "avg", "count", "first", "last", "cast"]
@ -1541,6 +1620,23 @@ class TDTestCase:
pass
def test_query_decimal_where_clause(self):
    """Exercise decimal comparisons in where clauses against constant columns.

    Repeats the "column op constant" filtering check operator_test_round times
    on the normal table.
    """
    compare_ops = DecimalBinaryOperator.get_all_filtering_binary_compare_ops()
    constant_columns = Column.get_decimal_oper_const_cols()
    for _ in range(operator_test_round):
        self.check_decimal_where_with_binary_expr_with_const_col_results(
            self.db_name,
            self.norm_table_name,
            self.norm_tb_columns,
            constant_columns,
            compare_ops,
        )
    ## test filtering with decimal exprs
    ## 1. dec op const col
    ## 2. dec op dec
    ## 3. (dec op const col) op const col
    ## 4. (dec op dec) op const col
    ## 5. (dec op const col) op dec
    ## 6. (dec op dec) op dec
def test_query_decimal_order_clause(self):