test decimal col filtering
This commit is contained in:
parent
8dc25a53af
commit
9ed049fd0e
|
@ -1128,17 +1128,40 @@ int32_t decimalOp(EOperatorType op, const SDataType* pLeftT, const SDataType* pR
|
|||
// There is no need to do type conversions, we assume that pLeftT and pRightT are all decimal128 types.
|
||||
bool decimalCompare(EOperatorType op, const SDecimalCompareCtx* pLeft, const SDecimalCompareCtx* pRight) {
|
||||
bool ret = false;
|
||||
uint8_t pLeftPrec = 0, pLeftScale = 0, pRightPrec = 0, pRightScale = 0;
|
||||
decimalFromTypeMod(pLeft->typeMod, &pLeftPrec, &pLeftScale);
|
||||
decimalFromTypeMod(pRight->typeMod, &pRightPrec, &pRightScale);
|
||||
int32_t deltaScale = pLeftScale - pRightScale;
|
||||
uint8_t leftPrec = 0, leftScale = 0, rightPrec = 0, rightScale = 0;
|
||||
decimalFromTypeMod(pLeft->typeMod, &leftPrec, &leftScale);
|
||||
decimalFromTypeMod(pRight->typeMod, &rightPrec, &rightScale);
|
||||
int32_t deltaScale = leftScale - rightScale;
|
||||
Decimal pLeftDec = *(Decimal*)pLeft->pData, pRightDec = *(Decimal*)pRight->pData;
|
||||
|
||||
if (deltaScale != 0) {
|
||||
bool needInt256 = (deltaScale < 0 && pLeftPrec - deltaScale > TSDB_DECIMAL_MAX_PRECISION) ||
|
||||
(pRightPrec + deltaScale > TSDB_DECIMAL_MAX_PRECISION);
|
||||
bool needInt256 = (deltaScale < 0 && leftPrec - deltaScale > TSDB_DECIMAL_MAX_PRECISION) ||
|
||||
(rightPrec + deltaScale > TSDB_DECIMAL_MAX_PRECISION);
|
||||
if (needInt256) {
|
||||
// TODO wjm impl it
|
||||
Int256 x = {0}, y = {0};
|
||||
makeInt256FromDecimal128(&x, &pLeftDec);
|
||||
makeInt256FromDecimal128(&y, &pRightDec);
|
||||
if (leftScale < rightScale) {
|
||||
x = int256ScaleBy(&x, rightScale - leftScale);
|
||||
} else {
|
||||
y = int256ScaleBy(&y, leftScale - rightScale);
|
||||
}
|
||||
switch (op) {
|
||||
case OP_TYPE_GREATER_THAN:
|
||||
return int256Gt(&x, &y);
|
||||
case OP_TYPE_GREATER_EQUAL:
|
||||
return !int256Lt(&x, &y);
|
||||
case OP_TYPE_LOWER_THAN:
|
||||
return int256Lt(&x, &y);
|
||||
case OP_TYPE_LOWER_EQUAL:
|
||||
return !int256Gt(&x, &y);
|
||||
case OP_TYPE_EQUAL:
|
||||
return int256Eq(&x, &y);
|
||||
case OP_TYPE_NOT_EQUAL:
|
||||
return !int256Eq(&x, &y);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return false;
|
||||
} else {
|
||||
if (deltaScale < 0) {
|
||||
|
|
|
@ -198,6 +198,7 @@ class Numeric {
|
|||
uint8_t prec() const { return prec_; }
|
||||
uint8_t scale() const { return scale_; }
|
||||
const Type& dec() const { return dec_; }
|
||||
STypeMod get_type_mod() const { return decimalCalcTypeMod(prec(), scale()); }
|
||||
|
||||
template <int BitNum2>
|
||||
Numeric& binaryOp(const Numeric<BitNum2>& r, EOperatorType op) {
|
||||
|
@ -289,10 +290,6 @@ class Numeric {
|
|||
return binaryOp(r, OP_TYPE_ADD);
|
||||
}
|
||||
|
||||
template <int BitNum2>
|
||||
bool operator==(const Numeric<BitNum2>& r) {
|
||||
return binaryOp(r, OP_TYPE_EQUAL);
|
||||
}
|
||||
std::string toString() const {
|
||||
char buf[64] = {0};
|
||||
int32_t code = decimalToStr(&dec_, NumericType<BitNum>::dataType, prec(), scale(), buf, 64);
|
||||
|
@ -406,6 +403,45 @@ class Numeric {
|
|||
if (code != 0) throw std::runtime_error(tstrerror(code));
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <int BitNum2>
|
||||
Numeric(const Numeric<BitNum2>& num2) {
|
||||
Numeric();
|
||||
*this = num2;
|
||||
}
|
||||
|
||||
template <int BitNum2>
|
||||
Numeric& operator=(const Numeric<BitNum2>& num2) {
|
||||
static_assert(BitNum2 == 64 || BitNum2 == 128, "Only support decimal128/decimal64");
|
||||
SDataType inputDt = {
|
||||
.type = num2.type().type, .precision = num2.prec(), .scale = num2.scale(), .bytes = num2.type().bytes};
|
||||
SDataType outputDt = {.type = NumericType<BitNum>::dataType,
|
||||
.precision = NumericType<BitNum>::maxPrec,
|
||||
.scale = num2.scale(),
|
||||
.bytes = NumericType<BitNum>::bytes};
|
||||
int32_t code = convertToDecimal(&num2.dec(), &inputDt, &dec_, &outputDt);
|
||||
if (code == TSDB_CODE_DECIMAL_OVERFLOW) throw std::overflow_error(tstrerror(code));
|
||||
if (code != 0) throw std::runtime_error(tstrerror(code));
|
||||
prec_ = outputDt.precision;
|
||||
scale_ = outputDt.scale;
|
||||
return *this;
|
||||
}
|
||||
#define DEFINE_COMPARE_OP(op, op_type) \
|
||||
template <typename T> \
|
||||
bool operator op(const T& t) { \
|
||||
Numeric<128> lDec = *this, rDec = *this; \
|
||||
rDec = t; \
|
||||
SDecimalCompareCtx l = {(void*)&lDec.dec(), TSDB_DATA_TYPE_DECIMAL, lDec.get_type_mod()}, \
|
||||
r = {(void*)&rDec.dec(), TSDB_DATA_TYPE_DECIMAL, rDec.get_type_mod()}; \
|
||||
return decimalCompare(op_type, &l, &r); \
|
||||
}
|
||||
|
||||
DEFINE_COMPARE_OP(>, OP_TYPE_GREATER_THAN);
|
||||
DEFINE_COMPARE_OP(>=, OP_TYPE_GREATER_EQUAL);
|
||||
DEFINE_COMPARE_OP(<, OP_TYPE_LOWER_THAN);
|
||||
DEFINE_COMPARE_OP(<=, OP_TYPE_LOWER_EQUAL);
|
||||
DEFINE_COMPARE_OP(==, OP_TYPE_EQUAL);
|
||||
DEFINE_COMPARE_OP(!=, OP_TYPE_NOT_EQUAL);
|
||||
};
|
||||
|
||||
template <int BitNum>
|
||||
|
@ -1237,6 +1273,18 @@ TEST(decimal_all, ret_type_load_from_file) {
|
|||
ASSERT_EQ(total_lines, 3034205);
|
||||
}
|
||||
|
||||
TEST(decimal_all, test_decimal_compare) {
|
||||
Numeric<64> dec64 = {10, 2, "123.23"};
|
||||
Numeric<64> dec64_2 = {11, 10, "1.23"};
|
||||
ASSERT_FALSE(dec64_2 > dec64);
|
||||
dec64 = "10123456.23";
|
||||
ASSERT_FALSE(dec64_2 > dec64);
|
||||
ASSERT_TRUE(dec64 > dec64_2);
|
||||
ASSERT_TRUE(dec64_2 < 100);
|
||||
Numeric<128> dec128 = {38, 10, "1.23"};
|
||||
ASSERT_TRUE(dec128 == dec64_2);
|
||||
}
|
||||
|
||||
TEST(decimal_all, ret_type_for_non_decimal_types) {
|
||||
std::vector<DecimalRetTypeCheckContent> non_decimal_types;
|
||||
SDataType decimal_type = {TSDB_DATA_TYPE_DECIMAL64, 10, 2, 8};
|
||||
|
|
|
@ -1104,6 +1104,8 @@ STypeMod getConvertTypeMod(int32_t type, const SColumnInfo* pCol1, const SColumn
|
|||
return decimalCalcTypeMod(GET_DEICMAL_MAX_PRECISION(type), pCol1->scale);
|
||||
} else if (pCol2 && IS_DECIMAL_TYPE(pCol2->type) && !IS_DECIMAL_TYPE(pCol1->type)) {
|
||||
return decimalCalcTypeMod(GET_DEICMAL_MAX_PRECISION(type), pCol2->scale);
|
||||
} else if (IS_DECIMAL_TYPE(pCol1->type) && pCol2 && IS_DECIMAL_TYPE(pCol2->type)) {
|
||||
return decimalCalcTypeMod(GET_DEICMAL_MAX_PRECISION(type), MAX(pCol1->scale, pCol2->scale));
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import math
|
||||
from random import randrange
|
||||
import random
|
||||
from re import I
|
||||
import time
|
||||
import threading
|
||||
import secrets
|
||||
|
@ -25,11 +26,12 @@ scalar_convert_err = -2147470768
|
|||
|
||||
|
||||
decimal_insert_validator_test = False
|
||||
operator_test_round = 1
|
||||
operator_test_round = 10
|
||||
tb_insert_rows = 1000
|
||||
binary_op_with_const_test = True
|
||||
binary_op_with_col_test = True
|
||||
unary_op_test = True
|
||||
binary_op_with_const_test = False
|
||||
binary_op_with_col_test = False
|
||||
unary_op_test = False
|
||||
binary_op_in_where_test = True
|
||||
|
||||
class DecimalTypeGeneratorConfig:
|
||||
def __init__(self):
|
||||
|
@ -191,12 +193,16 @@ class DecimalColumnExpr:
|
|||
self.executor_ = executor
|
||||
self.params_ = ()
|
||||
self.res_type_: DataType = None
|
||||
self.query_col: Column = None
|
||||
|
||||
def __str__(self):
|
||||
return f"({self.format_})".format(*self.params_)
|
||||
|
||||
def execute(self, params):
|
||||
return self.executor_(self, params)
|
||||
|
||||
def get_query_col_val(self, tbname, i):
|
||||
return self.query_col.get_val_for_execute(tbname, i)
|
||||
|
||||
def get_val(self, tbname: str, idx: int):
|
||||
params = ()
|
||||
|
@ -215,6 +221,31 @@ class DecimalColumnExpr:
|
|||
|
||||
def should_skip_for_decimal(self, cols: list):
|
||||
return False
|
||||
|
||||
def check_for_filtering(self, query_col_res: List, tbname: str):
|
||||
j: int = -1
|
||||
for i in range(len(query_col_res)):
|
||||
j += 1
|
||||
v_from_query = query_col_res[i]
|
||||
while True:
|
||||
params = ()
|
||||
for p in self.params_:
|
||||
if isinstance(p, Column) or isinstance(p, DecimalColumnExpr):
|
||||
p = p.get_val_for_execute(tbname, j)
|
||||
params = params + (p,)
|
||||
v_from_calc_in_py = self.execute(params)
|
||||
|
||||
if not v_from_calc_in_py:
|
||||
j += 1
|
||||
continue
|
||||
else:
|
||||
break
|
||||
dec_from_query = Decimal(v_from_query)
|
||||
dec_from_calc = self.get_query_col_val(tbname, j)
|
||||
if dec_from_query != dec_from_calc:
|
||||
tdLog.exit(f"filter with {self} failed, query got: {dec_from_query}, expect {dec_from_calc}, param: {params}")
|
||||
else:
|
||||
tdLog.info(f"filter with {self} succ, query got: {dec_from_query}, expect {dec_from_calc}, param: {params}")
|
||||
|
||||
def check(self, query_col_res: List, tbname: str):
|
||||
for i in range(len(query_col_res)):
|
||||
|
@ -776,10 +807,6 @@ class DecimalBinaryOperator(DecimalColumnExpr):
|
|||
def should_skip_for_decimal(self, cols: list):
|
||||
left_col = cols[0]
|
||||
right_col = cols[1]
|
||||
if not left_col.is_constant_col() and left_col.name_ == '':
|
||||
return True
|
||||
if not right_col.is_constant_col() and right_col.name_ == '':
|
||||
return True
|
||||
if not left_col.type_.is_decimal_type() and not right_col.type_.is_decimal_type():
|
||||
return True
|
||||
if self.op_ != "%":
|
||||
|
@ -933,6 +960,10 @@ class DecimalBinaryOperator(DecimalColumnExpr):
|
|||
return float(left) == float(right)
|
||||
else:
|
||||
return Decimal(left) == Decimal(right)
|
||||
|
||||
def execute_eq_filtering(self, params):
|
||||
if self.execute_eq(params):
|
||||
return True
|
||||
|
||||
def execute_ne(self, params):
|
||||
if DecimalBinaryOperator.check_null(params):
|
||||
|
@ -994,6 +1025,17 @@ class DecimalBinaryOperator(DecimalColumnExpr):
|
|||
DecimalBinaryOperator(" {0} >= {1} ", DecimalBinaryOperator.execute_ge, ">="),
|
||||
DecimalBinaryOperator(" {0} <= {1} ", DecimalBinaryOperator.execute_le, "<="),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_all_filtering_binary_compare_ops() -> List[DecimalColumnExpr]:
|
||||
return [
|
||||
DecimalBinaryOperator(" {0} == {1} ", DecimalBinaryOperator.execute_eq, "=="),
|
||||
DecimalBinaryOperator(" {0} != {1} ", DecimalBinaryOperator.execute_ne, "!="),
|
||||
DecimalBinaryOperator(" {0} > {1} ", DecimalBinaryOperator.execute_gt, ">"),
|
||||
DecimalBinaryOperator(" {0} < {1} ", DecimalBinaryOperator.execute_lt, "<"),
|
||||
DecimalBinaryOperator(" {0} >= {1} ", DecimalBinaryOperator.execute_ge, ">="),
|
||||
DecimalBinaryOperator(" {0} <= {1} ", DecimalBinaryOperator.execute_le, "<="),
|
||||
]
|
||||
|
||||
def execute(self, params):
|
||||
return super().execute(params)
|
||||
|
@ -1406,6 +1448,41 @@ class TDTestCase:
|
|||
else:
|
||||
tdLog.info(f"sql: {sql} got no output")
|
||||
|
||||
def check_decimal_where_with_binary_expr_with_const_col_results(
|
||||
self,
|
||||
dbname,
|
||||
tbname,
|
||||
tb_cols: List[Column],
|
||||
constant_cols: List[Column],
|
||||
exprs: List[DecimalColumnExpr],
|
||||
):
|
||||
if not binary_op_in_where_test:
|
||||
return
|
||||
for expr in exprs:
|
||||
for col in tb_cols:
|
||||
if col.name_ == '':
|
||||
continue
|
||||
left_is_decimal = col.type_.is_decimal_type()
|
||||
for const_col in constant_cols:
|
||||
right_is_decimal = const_col.type_.is_decimal_type()
|
||||
if expr.should_skip_for_decimal([col, const_col]):
|
||||
continue
|
||||
const_col.generate_value()
|
||||
select_expr = expr.generate((const_col, col))
|
||||
expr.query_col = col
|
||||
sql = f"select {col} from {dbname}.{tbname} where {select_expr}"
|
||||
res = TaosShell().query(sql)
|
||||
##TODO wjm no need to check len(res) for filtering test, cause we need to check for every row in the table to check if the filtering is working
|
||||
if len(res) > 0:
|
||||
expr.check_for_filtering(res[0], tbname)
|
||||
select_expr = expr.generate((col, const_col))
|
||||
sql = f"select {col} from {dbname}.{tbname} where {select_expr}"
|
||||
res = TaosShell().query(sql)
|
||||
if len(res) > 0:
|
||||
expr.check_for_filtering(res[0], tbname)
|
||||
else:
|
||||
tdLog.info(f"sql: {sql} got no output")
|
||||
|
||||
def check_decimal_binary_expr_with_const_col_results(
|
||||
self,
|
||||
dbname,
|
||||
|
@ -1530,6 +1607,8 @@ class TDTestCase:
|
|||
self.norm_tb_columns,
|
||||
unary_operators,)
|
||||
|
||||
self.test_query_decimal_where_clause()
|
||||
|
||||
def test_decimal_functions(self):
|
||||
self.test_decimal_last_first_func()
|
||||
funcs = ["max", "min", "sum", "avg", "count", "first", "last", "cast"]
|
||||
|
@ -1541,6 +1620,23 @@ class TDTestCase:
|
|||
pass
|
||||
|
||||
def test_query_decimal_where_clause(self):
|
||||
binary_compare_ops = DecimalBinaryOperator.get_all_filtering_binary_compare_ops()
|
||||
const_cols = Column.get_decimal_oper_const_cols()
|
||||
for i in range(operator_test_round):
|
||||
self.check_decimal_where_with_binary_expr_with_const_col_results(
|
||||
self.db_name,
|
||||
self.norm_table_name,
|
||||
self.norm_tb_columns,
|
||||
const_cols,
|
||||
binary_compare_ops,
|
||||
)
|
||||
## test filtering with decimal exprs
|
||||
## 1. dec op const col
|
||||
## 2. dec op dec
|
||||
## 3. (dec op const col) op const col
|
||||
## 4. (dec op dec) op const col
|
||||
## 5. (dec op const col) op dec
|
||||
## 6. (dec op dec) op dec
|
||||
pass
|
||||
|
||||
def test_query_decimal_order_clause(self):
|
||||
|
|
Loading…
Reference in New Issue