test decimal col filtering
This commit is contained in:
parent
8dc25a53af
commit
9ed049fd0e
|
@ -1128,17 +1128,40 @@ int32_t decimalOp(EOperatorType op, const SDataType* pLeftT, const SDataType* pR
|
||||||
// There is no need to do type conversions, we assume that pLeftT and pRightT are all decimal128 types.
|
// There is no need to do type conversions, we assume that pLeftT and pRightT are all decimal128 types.
|
||||||
bool decimalCompare(EOperatorType op, const SDecimalCompareCtx* pLeft, const SDecimalCompareCtx* pRight) {
|
bool decimalCompare(EOperatorType op, const SDecimalCompareCtx* pLeft, const SDecimalCompareCtx* pRight) {
|
||||||
bool ret = false;
|
bool ret = false;
|
||||||
uint8_t pLeftPrec = 0, pLeftScale = 0, pRightPrec = 0, pRightScale = 0;
|
uint8_t leftPrec = 0, leftScale = 0, rightPrec = 0, rightScale = 0;
|
||||||
decimalFromTypeMod(pLeft->typeMod, &pLeftPrec, &pLeftScale);
|
decimalFromTypeMod(pLeft->typeMod, &leftPrec, &leftScale);
|
||||||
decimalFromTypeMod(pRight->typeMod, &pRightPrec, &pRightScale);
|
decimalFromTypeMod(pRight->typeMod, &rightPrec, &rightScale);
|
||||||
int32_t deltaScale = pLeftScale - pRightScale;
|
int32_t deltaScale = leftScale - rightScale;
|
||||||
Decimal pLeftDec = *(Decimal*)pLeft->pData, pRightDec = *(Decimal*)pRight->pData;
|
Decimal pLeftDec = *(Decimal*)pLeft->pData, pRightDec = *(Decimal*)pRight->pData;
|
||||||
|
|
||||||
if (deltaScale != 0) {
|
if (deltaScale != 0) {
|
||||||
bool needInt256 = (deltaScale < 0 && pLeftPrec - deltaScale > TSDB_DECIMAL_MAX_PRECISION) ||
|
bool needInt256 = (deltaScale < 0 && leftPrec - deltaScale > TSDB_DECIMAL_MAX_PRECISION) ||
|
||||||
(pRightPrec + deltaScale > TSDB_DECIMAL_MAX_PRECISION);
|
(rightPrec + deltaScale > TSDB_DECIMAL_MAX_PRECISION);
|
||||||
if (needInt256) {
|
if (needInt256) {
|
||||||
// TODO wjm impl it
|
Int256 x = {0}, y = {0};
|
||||||
|
makeInt256FromDecimal128(&x, &pLeftDec);
|
||||||
|
makeInt256FromDecimal128(&y, &pRightDec);
|
||||||
|
if (leftScale < rightScale) {
|
||||||
|
x = int256ScaleBy(&x, rightScale - leftScale);
|
||||||
|
} else {
|
||||||
|
y = int256ScaleBy(&y, leftScale - rightScale);
|
||||||
|
}
|
||||||
|
switch (op) {
|
||||||
|
case OP_TYPE_GREATER_THAN:
|
||||||
|
return int256Gt(&x, &y);
|
||||||
|
case OP_TYPE_GREATER_EQUAL:
|
||||||
|
return !int256Lt(&x, &y);
|
||||||
|
case OP_TYPE_LOWER_THAN:
|
||||||
|
return int256Lt(&x, &y);
|
||||||
|
case OP_TYPE_LOWER_EQUAL:
|
||||||
|
return !int256Gt(&x, &y);
|
||||||
|
case OP_TYPE_EQUAL:
|
||||||
|
return int256Eq(&x, &y);
|
||||||
|
case OP_TYPE_NOT_EQUAL:
|
||||||
|
return !int256Eq(&x, &y);
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
return false;
|
return false;
|
||||||
} else {
|
} else {
|
||||||
if (deltaScale < 0) {
|
if (deltaScale < 0) {
|
||||||
|
|
|
@ -198,6 +198,7 @@ class Numeric {
|
||||||
uint8_t prec() const { return prec_; }
|
uint8_t prec() const { return prec_; }
|
||||||
uint8_t scale() const { return scale_; }
|
uint8_t scale() const { return scale_; }
|
||||||
const Type& dec() const { return dec_; }
|
const Type& dec() const { return dec_; }
|
||||||
|
STypeMod get_type_mod() const { return decimalCalcTypeMod(prec(), scale()); }
|
||||||
|
|
||||||
template <int BitNum2>
|
template <int BitNum2>
|
||||||
Numeric& binaryOp(const Numeric<BitNum2>& r, EOperatorType op) {
|
Numeric& binaryOp(const Numeric<BitNum2>& r, EOperatorType op) {
|
||||||
|
@ -289,10 +290,6 @@ class Numeric {
|
||||||
return binaryOp(r, OP_TYPE_ADD);
|
return binaryOp(r, OP_TYPE_ADD);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int BitNum2>
|
|
||||||
bool operator==(const Numeric<BitNum2>& r) {
|
|
||||||
return binaryOp(r, OP_TYPE_EQUAL);
|
|
||||||
}
|
|
||||||
std::string toString() const {
|
std::string toString() const {
|
||||||
char buf[64] = {0};
|
char buf[64] = {0};
|
||||||
int32_t code = decimalToStr(&dec_, NumericType<BitNum>::dataType, prec(), scale(), buf, 64);
|
int32_t code = decimalToStr(&dec_, NumericType<BitNum>::dataType, prec(), scale(), buf, 64);
|
||||||
|
@ -406,6 +403,45 @@ class Numeric {
|
||||||
if (code != 0) throw std::runtime_error(tstrerror(code));
|
if (code != 0) throw std::runtime_error(tstrerror(code));
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <int BitNum2>
|
||||||
|
Numeric(const Numeric<BitNum2>& num2) {
|
||||||
|
Numeric();
|
||||||
|
*this = num2;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int BitNum2>
|
||||||
|
Numeric& operator=(const Numeric<BitNum2>& num2) {
|
||||||
|
static_assert(BitNum2 == 64 || BitNum2 == 128, "Only support decimal128/decimal64");
|
||||||
|
SDataType inputDt = {
|
||||||
|
.type = num2.type().type, .precision = num2.prec(), .scale = num2.scale(), .bytes = num2.type().bytes};
|
||||||
|
SDataType outputDt = {.type = NumericType<BitNum>::dataType,
|
||||||
|
.precision = NumericType<BitNum>::maxPrec,
|
||||||
|
.scale = num2.scale(),
|
||||||
|
.bytes = NumericType<BitNum>::bytes};
|
||||||
|
int32_t code = convertToDecimal(&num2.dec(), &inputDt, &dec_, &outputDt);
|
||||||
|
if (code == TSDB_CODE_DECIMAL_OVERFLOW) throw std::overflow_error(tstrerror(code));
|
||||||
|
if (code != 0) throw std::runtime_error(tstrerror(code));
|
||||||
|
prec_ = outputDt.precision;
|
||||||
|
scale_ = outputDt.scale;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
#define DEFINE_COMPARE_OP(op, op_type) \
|
||||||
|
template <typename T> \
|
||||||
|
bool operator op(const T& t) { \
|
||||||
|
Numeric<128> lDec = *this, rDec = *this; \
|
||||||
|
rDec = t; \
|
||||||
|
SDecimalCompareCtx l = {(void*)&lDec.dec(), TSDB_DATA_TYPE_DECIMAL, lDec.get_type_mod()}, \
|
||||||
|
r = {(void*)&rDec.dec(), TSDB_DATA_TYPE_DECIMAL, rDec.get_type_mod()}; \
|
||||||
|
return decimalCompare(op_type, &l, &r); \
|
||||||
|
}
|
||||||
|
|
||||||
|
DEFINE_COMPARE_OP(>, OP_TYPE_GREATER_THAN);
|
||||||
|
DEFINE_COMPARE_OP(>=, OP_TYPE_GREATER_EQUAL);
|
||||||
|
DEFINE_COMPARE_OP(<, OP_TYPE_LOWER_THAN);
|
||||||
|
DEFINE_COMPARE_OP(<=, OP_TYPE_LOWER_EQUAL);
|
||||||
|
DEFINE_COMPARE_OP(==, OP_TYPE_EQUAL);
|
||||||
|
DEFINE_COMPARE_OP(!=, OP_TYPE_NOT_EQUAL);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <int BitNum>
|
template <int BitNum>
|
||||||
|
@ -1237,6 +1273,18 @@ TEST(decimal_all, ret_type_load_from_file) {
|
||||||
ASSERT_EQ(total_lines, 3034205);
|
ASSERT_EQ(total_lines, 3034205);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(decimal_all, test_decimal_compare) {
|
||||||
|
Numeric<64> dec64 = {10, 2, "123.23"};
|
||||||
|
Numeric<64> dec64_2 = {11, 10, "1.23"};
|
||||||
|
ASSERT_FALSE(dec64_2 > dec64);
|
||||||
|
dec64 = "10123456.23";
|
||||||
|
ASSERT_FALSE(dec64_2 > dec64);
|
||||||
|
ASSERT_TRUE(dec64 > dec64_2);
|
||||||
|
ASSERT_TRUE(dec64_2 < 100);
|
||||||
|
Numeric<128> dec128 = {38, 10, "1.23"};
|
||||||
|
ASSERT_TRUE(dec128 == dec64_2);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(decimal_all, ret_type_for_non_decimal_types) {
|
TEST(decimal_all, ret_type_for_non_decimal_types) {
|
||||||
std::vector<DecimalRetTypeCheckContent> non_decimal_types;
|
std::vector<DecimalRetTypeCheckContent> non_decimal_types;
|
||||||
SDataType decimal_type = {TSDB_DATA_TYPE_DECIMAL64, 10, 2, 8};
|
SDataType decimal_type = {TSDB_DATA_TYPE_DECIMAL64, 10, 2, 8};
|
||||||
|
|
|
@ -1104,6 +1104,8 @@ STypeMod getConvertTypeMod(int32_t type, const SColumnInfo* pCol1, const SColumn
|
||||||
return decimalCalcTypeMod(GET_DEICMAL_MAX_PRECISION(type), pCol1->scale);
|
return decimalCalcTypeMod(GET_DEICMAL_MAX_PRECISION(type), pCol1->scale);
|
||||||
} else if (pCol2 && IS_DECIMAL_TYPE(pCol2->type) && !IS_DECIMAL_TYPE(pCol1->type)) {
|
} else if (pCol2 && IS_DECIMAL_TYPE(pCol2->type) && !IS_DECIMAL_TYPE(pCol1->type)) {
|
||||||
return decimalCalcTypeMod(GET_DEICMAL_MAX_PRECISION(type), pCol2->scale);
|
return decimalCalcTypeMod(GET_DEICMAL_MAX_PRECISION(type), pCol2->scale);
|
||||||
|
} else if (IS_DECIMAL_TYPE(pCol1->type) && pCol2 && IS_DECIMAL_TYPE(pCol2->type)) {
|
||||||
|
return decimalCalcTypeMod(GET_DEICMAL_MAX_PRECISION(type), MAX(pCol1->scale, pCol2->scale));
|
||||||
} else {
|
} else {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import math
|
import math
|
||||||
from random import randrange
|
from random import randrange
|
||||||
import random
|
import random
|
||||||
|
from re import I
|
||||||
import time
|
import time
|
||||||
import threading
|
import threading
|
||||||
import secrets
|
import secrets
|
||||||
|
@ -25,11 +26,12 @@ scalar_convert_err = -2147470768
|
||||||
|
|
||||||
|
|
||||||
decimal_insert_validator_test = False
|
decimal_insert_validator_test = False
|
||||||
operator_test_round = 1
|
operator_test_round = 10
|
||||||
tb_insert_rows = 1000
|
tb_insert_rows = 1000
|
||||||
binary_op_with_const_test = True
|
binary_op_with_const_test = False
|
||||||
binary_op_with_col_test = True
|
binary_op_with_col_test = False
|
||||||
unary_op_test = True
|
unary_op_test = False
|
||||||
|
binary_op_in_where_test = True
|
||||||
|
|
||||||
class DecimalTypeGeneratorConfig:
|
class DecimalTypeGeneratorConfig:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -191,6 +193,7 @@ class DecimalColumnExpr:
|
||||||
self.executor_ = executor
|
self.executor_ = executor
|
||||||
self.params_ = ()
|
self.params_ = ()
|
||||||
self.res_type_: DataType = None
|
self.res_type_: DataType = None
|
||||||
|
self.query_col: Column = None
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"({self.format_})".format(*self.params_)
|
return f"({self.format_})".format(*self.params_)
|
||||||
|
@ -198,6 +201,9 @@ class DecimalColumnExpr:
|
||||||
def execute(self, params):
|
def execute(self, params):
|
||||||
return self.executor_(self, params)
|
return self.executor_(self, params)
|
||||||
|
|
||||||
|
def get_query_col_val(self, tbname, i):
|
||||||
|
return self.query_col.get_val_for_execute(tbname, i)
|
||||||
|
|
||||||
def get_val(self, tbname: str, idx: int):
|
def get_val(self, tbname: str, idx: int):
|
||||||
params = ()
|
params = ()
|
||||||
for p in self.params_:
|
for p in self.params_:
|
||||||
|
@ -216,6 +222,31 @@ class DecimalColumnExpr:
|
||||||
def should_skip_for_decimal(self, cols: list):
|
def should_skip_for_decimal(self, cols: list):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def check_for_filtering(self, query_col_res: List, tbname: str):
|
||||||
|
j: int = -1
|
||||||
|
for i in range(len(query_col_res)):
|
||||||
|
j += 1
|
||||||
|
v_from_query = query_col_res[i]
|
||||||
|
while True:
|
||||||
|
params = ()
|
||||||
|
for p in self.params_:
|
||||||
|
if isinstance(p, Column) or isinstance(p, DecimalColumnExpr):
|
||||||
|
p = p.get_val_for_execute(tbname, j)
|
||||||
|
params = params + (p,)
|
||||||
|
v_from_calc_in_py = self.execute(params)
|
||||||
|
|
||||||
|
if not v_from_calc_in_py:
|
||||||
|
j += 1
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
dec_from_query = Decimal(v_from_query)
|
||||||
|
dec_from_calc = self.get_query_col_val(tbname, j)
|
||||||
|
if dec_from_query != dec_from_calc:
|
||||||
|
tdLog.exit(f"filter with {self} failed, query got: {dec_from_query}, expect {dec_from_calc}, param: {params}")
|
||||||
|
else:
|
||||||
|
tdLog.info(f"filter with {self} succ, query got: {dec_from_query}, expect {dec_from_calc}, param: {params}")
|
||||||
|
|
||||||
def check(self, query_col_res: List, tbname: str):
|
def check(self, query_col_res: List, tbname: str):
|
||||||
for i in range(len(query_col_res)):
|
for i in range(len(query_col_res)):
|
||||||
v_from_query = query_col_res[i]
|
v_from_query = query_col_res[i]
|
||||||
|
@ -776,10 +807,6 @@ class DecimalBinaryOperator(DecimalColumnExpr):
|
||||||
def should_skip_for_decimal(self, cols: list):
|
def should_skip_for_decimal(self, cols: list):
|
||||||
left_col = cols[0]
|
left_col = cols[0]
|
||||||
right_col = cols[1]
|
right_col = cols[1]
|
||||||
if not left_col.is_constant_col() and left_col.name_ == '':
|
|
||||||
return True
|
|
||||||
if not right_col.is_constant_col() and right_col.name_ == '':
|
|
||||||
return True
|
|
||||||
if not left_col.type_.is_decimal_type() and not right_col.type_.is_decimal_type():
|
if not left_col.type_.is_decimal_type() and not right_col.type_.is_decimal_type():
|
||||||
return True
|
return True
|
||||||
if self.op_ != "%":
|
if self.op_ != "%":
|
||||||
|
@ -934,6 +961,10 @@ class DecimalBinaryOperator(DecimalColumnExpr):
|
||||||
else:
|
else:
|
||||||
return Decimal(left) == Decimal(right)
|
return Decimal(left) == Decimal(right)
|
||||||
|
|
||||||
|
def execute_eq_filtering(self, params):
|
||||||
|
if self.execute_eq(params):
|
||||||
|
return True
|
||||||
|
|
||||||
def execute_ne(self, params):
|
def execute_ne(self, params):
|
||||||
if DecimalBinaryOperator.check_null(params):
|
if DecimalBinaryOperator.check_null(params):
|
||||||
return False
|
return False
|
||||||
|
@ -995,6 +1026,17 @@ class DecimalBinaryOperator(DecimalColumnExpr):
|
||||||
DecimalBinaryOperator(" {0} <= {1} ", DecimalBinaryOperator.execute_le, "<="),
|
DecimalBinaryOperator(" {0} <= {1} ", DecimalBinaryOperator.execute_le, "<="),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_all_filtering_binary_compare_ops() -> List[DecimalColumnExpr]:
|
||||||
|
return [
|
||||||
|
DecimalBinaryOperator(" {0} == {1} ", DecimalBinaryOperator.execute_eq, "=="),
|
||||||
|
DecimalBinaryOperator(" {0} != {1} ", DecimalBinaryOperator.execute_ne, "!="),
|
||||||
|
DecimalBinaryOperator(" {0} > {1} ", DecimalBinaryOperator.execute_gt, ">"),
|
||||||
|
DecimalBinaryOperator(" {0} < {1} ", DecimalBinaryOperator.execute_lt, "<"),
|
||||||
|
DecimalBinaryOperator(" {0} >= {1} ", DecimalBinaryOperator.execute_ge, ">="),
|
||||||
|
DecimalBinaryOperator(" {0} <= {1} ", DecimalBinaryOperator.execute_le, "<="),
|
||||||
|
]
|
||||||
|
|
||||||
def execute(self, params):
|
def execute(self, params):
|
||||||
return super().execute(params)
|
return super().execute(params)
|
||||||
|
|
||||||
|
@ -1406,6 +1448,41 @@ class TDTestCase:
|
||||||
else:
|
else:
|
||||||
tdLog.info(f"sql: {sql} got no output")
|
tdLog.info(f"sql: {sql} got no output")
|
||||||
|
|
||||||
|
def check_decimal_where_with_binary_expr_with_const_col_results(
|
||||||
|
self,
|
||||||
|
dbname,
|
||||||
|
tbname,
|
||||||
|
tb_cols: List[Column],
|
||||||
|
constant_cols: List[Column],
|
||||||
|
exprs: List[DecimalColumnExpr],
|
||||||
|
):
|
||||||
|
if not binary_op_in_where_test:
|
||||||
|
return
|
||||||
|
for expr in exprs:
|
||||||
|
for col in tb_cols:
|
||||||
|
if col.name_ == '':
|
||||||
|
continue
|
||||||
|
left_is_decimal = col.type_.is_decimal_type()
|
||||||
|
for const_col in constant_cols:
|
||||||
|
right_is_decimal = const_col.type_.is_decimal_type()
|
||||||
|
if expr.should_skip_for_decimal([col, const_col]):
|
||||||
|
continue
|
||||||
|
const_col.generate_value()
|
||||||
|
select_expr = expr.generate((const_col, col))
|
||||||
|
expr.query_col = col
|
||||||
|
sql = f"select {col} from {dbname}.{tbname} where {select_expr}"
|
||||||
|
res = TaosShell().query(sql)
|
||||||
|
##TODO wjm no need to check len(res) for filtering test, cause we need to check for every row in the table to check if the filtering is working
|
||||||
|
if len(res) > 0:
|
||||||
|
expr.check_for_filtering(res[0], tbname)
|
||||||
|
select_expr = expr.generate((col, const_col))
|
||||||
|
sql = f"select {col} from {dbname}.{tbname} where {select_expr}"
|
||||||
|
res = TaosShell().query(sql)
|
||||||
|
if len(res) > 0:
|
||||||
|
expr.check_for_filtering(res[0], tbname)
|
||||||
|
else:
|
||||||
|
tdLog.info(f"sql: {sql} got no output")
|
||||||
|
|
||||||
def check_decimal_binary_expr_with_const_col_results(
|
def check_decimal_binary_expr_with_const_col_results(
|
||||||
self,
|
self,
|
||||||
dbname,
|
dbname,
|
||||||
|
@ -1530,6 +1607,8 @@ class TDTestCase:
|
||||||
self.norm_tb_columns,
|
self.norm_tb_columns,
|
||||||
unary_operators,)
|
unary_operators,)
|
||||||
|
|
||||||
|
self.test_query_decimal_where_clause()
|
||||||
|
|
||||||
def test_decimal_functions(self):
|
def test_decimal_functions(self):
|
||||||
self.test_decimal_last_first_func()
|
self.test_decimal_last_first_func()
|
||||||
funcs = ["max", "min", "sum", "avg", "count", "first", "last", "cast"]
|
funcs = ["max", "min", "sum", "avg", "count", "first", "last", "cast"]
|
||||||
|
@ -1541,6 +1620,23 @@ class TDTestCase:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_query_decimal_where_clause(self):
|
def test_query_decimal_where_clause(self):
|
||||||
|
binary_compare_ops = DecimalBinaryOperator.get_all_filtering_binary_compare_ops()
|
||||||
|
const_cols = Column.get_decimal_oper_const_cols()
|
||||||
|
for i in range(operator_test_round):
|
||||||
|
self.check_decimal_where_with_binary_expr_with_const_col_results(
|
||||||
|
self.db_name,
|
||||||
|
self.norm_table_name,
|
||||||
|
self.norm_tb_columns,
|
||||||
|
const_cols,
|
||||||
|
binary_compare_ops,
|
||||||
|
)
|
||||||
|
## test filtering with decimal exprs
|
||||||
|
## 1. dec op const col
|
||||||
|
## 2. dec op dec
|
||||||
|
## 3. (dec op const col) op const col
|
||||||
|
## 4. (dec op dec) op const col
|
||||||
|
## 5. (dec op const col) op dec
|
||||||
|
## 6. (dec op dec) op dec
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_query_decimal_order_clause(self):
|
def test_query_decimal_order_clause(self):
|
||||||
|
|
Loading…
Reference in New Issue