test decimal compare operators

This commit is contained in:
wangjiaming0909 2025-03-02 17:24:30 +08:00
parent 66e1fddce5
commit cd9942e5c0
6 changed files with 180 additions and 134 deletions

View File

@ -641,8 +641,11 @@ static bool decimal128Lt(const DecimalType* pLeft, const DecimalType* pRight, ui
// TODO wjm pRightDec use const
Decimal128 *pLeftDec = (Decimal128*)pLeft, *pRightDec = (Decimal128*)pRight;
Decimal128 right = {0};
char left_buf[64] = {0}, right_buf[64] = {0}; // TODO wjm remove it
DECIMAL128_CHECK_RIGHT_WORD_NUM(rightWordNum, pRightDec, right, pRight);
decimal128ToStr(pLeftDec, 0, left_buf, 64);
decimal128ToStr(pRightDec, 0, right_buf, 64);
return DECIMAL128_HIGH_WORD(pLeftDec) < DECIMAL128_HIGH_WORD(pRightDec) ||
(DECIMAL128_HIGH_WORD(pLeftDec) == DECIMAL128_HIGH_WORD(pRightDec) &&
DECIMAL128_LOW_WORD(pLeftDec) < DECIMAL128_LOW_WORD(pRightDec));
@ -1377,7 +1380,7 @@ static int32_t decimal128FromDecimal128(DecimalType* pDec, uint8_t prec, uint8_t
Decimal128 max = {0};
DECIMAL128_GET_MAX(prec - scale, &max);
decimal128ScaleTo(&tmpDec, valScale, 0);
decimal128ScaleDown(&tmpDec, valScale, false);
if (decimal128Lt(&max, &tmpDec, WORD_NUM(Decimal128))) {
return TSDB_CODE_DECIMAL_OVERFLOW;
}
@ -1386,78 +1389,79 @@ static int32_t decimal128FromDecimal128(DecimalType* pDec, uint8_t prec, uint8_t
return 0;
}
#define CONVERT_TO_DECIMAL(pData, pInputType, pOut, pOutType, decimal) \
({ \
int32_t code = 0; \
int64_t val = 0; \
uint64_t uval = 0; \
double dval = 0; \
switch (pInputType->type) { \
case TSDB_DATA_TYPE_NULL: \
break; \
case TSDB_DATA_TYPE_BOOL: \
uval = *(const bool*)pData; \
code = decimal##FromUint64(pOut, pOutType->precision, pOutType->scale, uval); \
break; \
case TSDB_DATA_TYPE_TINYINT: \
val = *(const int8_t*)pData; \
code = decimal##FromInt64(pOut, pOutType->precision, pOutType->scale, val); \
break; \
case TSDB_DATA_TYPE_SMALLINT: \
val = *(const int16_t*)pData; \
code = decimal##FromInt64(pOut, pOutType->precision, pOutType->scale, val); \
break; \
case TSDB_DATA_TYPE_INT: \
val = *(const int32_t*)pData; \
code = decimal##FromInt64(pOut, pOutType->precision, pOutType->scale, val); \
break; \
case TSDB_DATA_TYPE_TIMESTAMP: \
case TSDB_DATA_TYPE_BIGINT: \
val = *(const int64_t*)pData; \
code = decimal##FromInt64(pOut, pOutType->precision, pOutType->scale, val); \
break; \
case TSDB_DATA_TYPE_UTINYINT: \
uval = *(const uint8_t*)pData; \
code = decimal##FromUint64(pOut, pOutType->precision, pOutType->scale, uval); \
break; \
case TSDB_DATA_TYPE_USMALLINT: \
uval = *(const uint16_t*)pData; \
code = decimal##FromUint64(pOut, pOutType->precision, pOutType->scale, uval); \
break; \
case TSDB_DATA_TYPE_UINT: \
uval = *(const uint32_t*)pData; \
code = decimal##FromUint64(pOut, pOutType->precision, pOutType->scale, uval); \
break; \
case TSDB_DATA_TYPE_UBIGINT: \
uval = *(const uint64_t*)pData; \
code = decimal##FromUint64(pOut, pOutType->precision, pOutType->scale, uval); \
break; \
case TSDB_DATA_TYPE_FLOAT: { \
dval = *(const float*)pData; \
code = decimal##FromDouble(pOut, pOutType->precision, pOutType->scale, dval); \
} break; \
case TSDB_DATA_TYPE_DOUBLE: { \
dval = *(const double*)pData; \
code = decimal##FromDouble(pOut, pOutType->precision, pOutType->scale, dval); \
} break; \
case TSDB_DATA_TYPE_DECIMAL64: { \
code = decimal##FromDecimal64(pOut, pOutType->precision, pOutType->scale, pData, pInputType->precision, \
pInputType->scale); \
} break; \
case TSDB_DATA_TYPE_DECIMAL: { \
code = decimal##FromDecimal128(pOut, pOutType->precision, pOutType->scale, pData, pInputType->precision, \
pInputType->scale); \
} break; \
case TSDB_DATA_TYPE_VARCHAR: \
case TSDB_DATA_TYPE_VARBINARY: \
case TSDB_DATA_TYPE_NCHAR: { \
code = decimal##FromStr(pData, pInputType->bytes, pOutType->precision, pOutType->scale, pOut); \
} break; \
default: \
code = TSDB_CODE_OPS_NOT_SUPPORT; \
break; \
} \
code; \
#define CONVERT_TO_DECIMAL(pData, pInputType, pOut, pOutType, decimal) \
({ \
int32_t code = 0; \
int64_t val = 0; \
uint64_t uval = 0; \
double dval = 0; \
switch (pInputType->type) { \
case TSDB_DATA_TYPE_NULL: \
break; \
case TSDB_DATA_TYPE_BOOL: \
uval = *(const bool*)pData; \
code = decimal##FromUint64(pOut, pOutType->precision, pOutType->scale, uval); \
break; \
case TSDB_DATA_TYPE_TINYINT: \
val = *(const int8_t*)pData; \
code = decimal##FromInt64(pOut, pOutType->precision, pOutType->scale, val); \
break; \
case TSDB_DATA_TYPE_SMALLINT: \
val = *(const int16_t*)pData; \
code = decimal##FromInt64(pOut, pOutType->precision, pOutType->scale, val); \
break; \
case TSDB_DATA_TYPE_INT: \
val = *(const int32_t*)pData; \
code = decimal##FromInt64(pOut, pOutType->precision, pOutType->scale, val); \
break; \
case TSDB_DATA_TYPE_TIMESTAMP: \
case TSDB_DATA_TYPE_BIGINT: \
val = *(const int64_t*)pData; \
code = decimal##FromInt64(pOut, pOutType->precision, pOutType->scale, val); \
break; \
case TSDB_DATA_TYPE_UTINYINT: \
uval = *(const uint8_t*)pData; \
code = decimal##FromUint64(pOut, pOutType->precision, pOutType->scale, uval); \
break; \
case TSDB_DATA_TYPE_USMALLINT: \
uval = *(const uint16_t*)pData; \
code = decimal##FromUint64(pOut, pOutType->precision, pOutType->scale, uval); \
break; \
case TSDB_DATA_TYPE_UINT: \
uval = *(const uint32_t*)pData; \
code = decimal##FromUint64(pOut, pOutType->precision, pOutType->scale, uval); \
break; \
case TSDB_DATA_TYPE_UBIGINT: \
uval = *(const uint64_t*)pData; \
code = decimal##FromUint64(pOut, pOutType->precision, pOutType->scale, uval); \
break; \
case TSDB_DATA_TYPE_FLOAT: { \
dval = *(const float*)pData; \
code = decimal##FromDouble(pOut, pOutType->precision, pOutType->scale, dval); \
} break; \
case TSDB_DATA_TYPE_DOUBLE: { \
dval = *(const double*)pData; \
code = decimal##FromDouble(pOut, pOutType->precision, pOutType->scale, dval); \
} break; \
case TSDB_DATA_TYPE_DECIMAL64: { \
code = decimal##FromDecimal64(pOut, pOutType->precision, pOutType->scale, pData, pInputType->precision, \
pInputType->scale); \
} break; \
case TSDB_DATA_TYPE_DECIMAL: { \
code = decimal##FromDecimal128(pOut, pOutType->precision, pOutType->scale, pData, pInputType->precision, \
pInputType->scale); \
} break; \
case TSDB_DATA_TYPE_VARCHAR: \
case TSDB_DATA_TYPE_VARBINARY: \
case TSDB_DATA_TYPE_NCHAR: { \
code = decimal##FromStr(pData, pInputType->bytes - VARSTR_HEADER_SIZE, pOutType->precision, pOutType->scale, \
pOut); \
} break; \
default: \
code = TSDB_CODE_OPS_NOT_SUPPORT; \
break; \
} \
code; \
})
int32_t convertToDecimal(const void* pData, const SDataType* pInputType, void* pOut, const SDataType* pOutType) {

View File

@ -230,12 +230,14 @@ Int256 int256Divide(const Int256* pLeft, const Int256* pRight) {
}
Int256 int256Mod(const Int256* pLeft, const Int256* pRight) {
Int256 left = *pLeft;
Int256 left = *pLeft, right = *pRight;
bool leftNegative = int256Lt(pLeft, &int256Zero);
if (leftNegative) {
left = int256Abs(&left);
}
intx::uint256 result = *(intx::uint256*)&left % *(intx::uint256*)pRight;
bool rightNegate = int256Lt(pRight, &int256Zero);
if (rightNegate) right = int256Abs(pRight);
intx::uint256 result = *(intx::uint256*)&left % *(intx::uint256*)&right;
Int256 res = *(Int256*)&result;
if (leftNegative) res = int256Negate(&res);
return res;

View File

@ -436,8 +436,8 @@ TEST(decimal, numeric) {
ASSERT_EQ(os.toString(), "15241399.13627308800000");
os = dec * dec128;
ASSERT_EQ(os.toStringTrimTailingZeros(), "15241481344302520.993184");
ASSERT_EQ(os.toString(), "15241481344302520.993184");
ASSERT_EQ(os.toStringTrimTailingZeros(), "15241481344302520.993185");
ASSERT_EQ(os.toString(), "15241481344302520.993185");
os2 = os / dec128;
ASSERT_EQ(os2.toStringTrimTailingZeros(), "123.456");
@ -1027,7 +1027,7 @@ TEST(decimal, decimalFromStr_all) {
TEST(decimal, op_overflow) {
// divide 0 error
Numeric<128> dec{38, 2, string(36, '9') + ".99"};
ASSERT_OVERFLOW(dec / 0); // TODO wjm add divide by 0 error code
ASSERT_OVERFLOW(dec / 0);
// test decimal128Max
Numeric<128> max{38, 10, "0"};

View File

@ -4205,8 +4205,9 @@ static int32_t fltSclBuildDecimalDatumFromValueNode(SFltSclDatum* datum, SColumn
pInput = &valNode->datum.d;
break;
case TSDB_DATA_TYPE_VARCHAR:
pInput = &valNode->datum.p;
pInput = valNode->literal;
break;
// TODO wjm test cast to decimal
default:
qError("not supported type %d when build decimal datum from value node", valNode->node.resType.type);
return TSDB_CODE_INVALID_PARA;

View File

@ -1143,6 +1143,7 @@ int32_t vectorConvertCols(SScalarParam *pLeft, SScalarParam *pRight, SScalarPara
if (leftType == rightType) {
if (IS_DECIMAL_TYPE(leftType)) {
//TODO wjm force do conversion for decimal type, do not convert any more, do conversion inside decimal.c
//TODO wjm where c1 = "999999999999999.99999"; this str will be converted to double and do comapre, add doc in TS
}
return TSDB_CODE_SUCCESS;
}

View File

@ -1,4 +1,3 @@
from ast import Tuple
import math
from random import randrange
import random
@ -23,6 +22,10 @@ invalid_encode_param = -2147483087
invalid_operation = -2147483136
scalar_convert_err = -2147470768
operator_test_round = 2
tb_insert_rows = 1000
class DecimalTypeGeneratorConfig:
def __init__(self):
self.enable_weight_overflow: bool = False
@ -181,7 +184,7 @@ class DecimalColumnExpr:
def __init__(self, format: str, executor):
self.format_: str = format
self.executor_ = executor
self.params_: Tuple = ()
self.params_ = ()
self.res_type_: DataType = None
def __str__(self):
@ -191,7 +194,7 @@ class DecimalColumnExpr:
return self.executor_(self, params)
def get_val(self, tbname: str, idx: int):
params: Tuple = ()
params = ()
for p in self.params_:
params = params + (p.get_val(tbname, idx),)
return self.execute(params)
@ -202,7 +205,7 @@ class DecimalColumnExpr:
def check(self, query_col_res: List, tbname: str):
for i in range(len(query_col_res)):
v_from_query = query_col_res[i]
params: Tuple = ()
params = ()
for p in self.params_:
if isinstance(p, Column) or isinstance(p, DecimalColumnExpr):
p = p.get_val_for_execute(tbname, i)
@ -215,21 +218,25 @@ class DecimalColumnExpr:
tdLog.debug(f"query with expr: {self} calc got same result: NULL")
continue
failed = False
if isinstance(v_from_calc_in_py, float):
dec_from_query = float(v_from_query)
dec_from_insert = float(v_from_calc_in_py)
failed = not math.isclose(dec_from_query, dec_from_insert, abs_tol=1e-7)
if self.res_type_.type == TypeEnum.BOOL:
query_res = bool(int(v_from_query))
calc_res = bool(int(v_from_calc_in_py))
failed = query_res != calc_res
elif isinstance(v_from_calc_in_py, float):
query_res = float(v_from_query)
calc_res = float(v_from_calc_in_py)
failed = not math.isclose(query_res, calc_res, abs_tol=1e-7)
else:
dec_from_query = Decimal(v_from_query)
dec_from_insert = Decimal(v_from_calc_in_py)
failed = dec_from_query != dec_from_insert
query_res = Decimal(v_from_query)
calc_res = Decimal(v_from_calc_in_py)
failed = query_res != calc_res
if failed:
tdLog.exit(
f"check decimal column failed for expr: {self}, input: {[t.__str__() for t in self.get_input_types()]}, res_type: {self.res_type_}, params: {params}, query: {v_from_query}, expect {dec_from_insert}, but get {dec_from_query}"
f"check decimal column failed for expr: {self}, input: {[t.__str__() for t in self.get_input_types()]}, res_type: {self.res_type_}, params: {params}, query: {v_from_query}, expect {calc_res}, but get {query_res}"
)
else:
tdLog.debug(
f"check decimal succ for expr: {self}, input: {[t.__str__() for t in self.get_input_types()]}, res_type: {self.res_type_}, params: {params}, insert:{v_from_calc_in_py} query:{v_from_query}, py dec: {dec_from_insert}"
f"check decimal succ for expr: {self}, input: {[t.__str__() for t in self.get_input_types()]}, res_type: {self.res_type_}, params: {params}, insert:{v_from_calc_in_py} query:{v_from_query}, py dec: {calc_res}"
)
## format_params are already been set
@ -743,7 +750,7 @@ class DecimalBinaryOperator(DecimalColumnExpr):
def __str__(self):
return super().__str__()
def generate(self, format_params: Tuple) -> str:
def generate(self, format_params) -> str:
return super().generate(format_params)
def should_skip_for_decimal(self, left_col: Column, right_col: Column):
@ -811,18 +818,18 @@ class DecimalBinaryOperator(DecimalColumnExpr):
return DecimalType(TypeEnum.DECIMAL, out_prec, out_scale)
def generate_res_type(self):
if DecimalBinaryOperator.is_compare_op(self.op_):
self.res_type_ = DataType(TypeEnum.BOOL)
return
left_type = self.params_[0].type_
right_type = self.params_[1].type_
self.left_type_ = left_type
self.right_type_ = right_type
ret_double_types = [TypeEnum.VARCHAR, TypeEnum.BINARY, TypeEnum.DOUBLE, TypeEnum.FLOAT, TypeEnum.NCHAR]
if left_type.type in ret_double_types or right_type.type in ret_double_types:
self.res_type_ = DataType(TypeEnum.DOUBLE)
if DecimalBinaryOperator.is_compare_op(self.op_):
self.res_type_ = DataType(TypeEnum.BOOL)
else:
self.res_type_ = DecimalBinaryOperator.calc_decimal_prec_scale(left_type, right_type, self.op_)
ret_double_types = [TypeEnum.VARCHAR, TypeEnum.BINARY, TypeEnum.DOUBLE, TypeEnum.FLOAT, TypeEnum.NCHAR]
if left_type.type in ret_double_types or right_type.type in ret_double_types:
self.res_type_ = DataType(TypeEnum.DOUBLE)
else:
self.res_type_ = DecimalBinaryOperator.calc_decimal_prec_scale(left_type, right_type, self.op_)
def get_input_types(self)-> list:
return [self.left_type_, self.right_type_]
@ -834,7 +841,7 @@ class DecimalBinaryOperator(DecimalColumnExpr):
return float(val)
@staticmethod
def get_ret_type(params) -> Tuple:
def get_convert_type(params):
ret_float = False
if isinstance(params[0], float) or isinstance(params[1], float):
ret_float = True
@ -851,7 +858,7 @@ class DecimalBinaryOperator(DecimalColumnExpr):
def execute_plus(self, params):
if DecimalBinaryOperator.check_null(params):
return 'NULL'
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
(left, right), ret_float = DecimalBinaryOperator.get_convert_type(params)
if self.res_type_.type == TypeEnum.DOUBLE:
return float(left) + float(right)
else:
@ -860,7 +867,7 @@ class DecimalBinaryOperator(DecimalColumnExpr):
def execute_minus(self, params):
if DecimalBinaryOperator.check_null(params):
return 'NULL'
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
(left, right), ret_float = DecimalBinaryOperator.get_convert_type(params)
if self.res_type_.type == TypeEnum.DOUBLE:
return float(left) - float(right)
else:
@ -869,7 +876,7 @@ class DecimalBinaryOperator(DecimalColumnExpr):
def execute_mul(self, params):
if DecimalBinaryOperator.check_null(params):
return 'NULL'
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
(left, right), ret_float = DecimalBinaryOperator.get_convert_type(params)
if self.res_type_.type == TypeEnum.DOUBLE:
return float(left) * float(right)
else:
@ -878,7 +885,7 @@ class DecimalBinaryOperator(DecimalColumnExpr):
def execute_div(self, params):
if DecimalBinaryOperator.check_null(params):
return 'NULL'
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
(left, right), _ = DecimalBinaryOperator.get_convert_type(params)
if self.res_type_.type == TypeEnum.DOUBLE:
return float(left) / float(right)
else:
@ -887,35 +894,65 @@ class DecimalBinaryOperator(DecimalColumnExpr):
def execute_mod(self, params):
if DecimalBinaryOperator.check_null(params):
return 'NULL'
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
(left, right), _ = DecimalBinaryOperator.get_convert_type(params)
if self.res_type_.type == TypeEnum.DOUBLE:
return self.convert_to_res_type(Decimal(left) % Decimal(right))
else:
return self.convert_to_res_type(Decimal(left) % Decimal(right))
def execute_eq(self, params):
if DecimalBinaryOperator.check_null(params):
return False
(left, right), ret_float = DecimalBinaryOperator.get_convert_type(params)
if ret_float:
return float(left) == float(right)
else:
return Decimal(left) == Decimal(right)
@staticmethod
def execute_eq(params):
return params[0] == params[1]
def execute_ne(self, params):
if DecimalBinaryOperator.check_null(params):
return False
(left, right), convert_float = DecimalBinaryOperator.get_convert_type(params)
if convert_float:
return float(left) != float(right)
else:
return Decimal(left) != Decimal(right)
@staticmethod
def execute_ne(params):
return params[0] != params[1]
def execute_gt(self, params):
if DecimalBinaryOperator.check_null(params):
return False
(left, right), convert_float = DecimalBinaryOperator.get_convert_type(params)
if convert_float:
return float(left) > float(right)
else:
return Decimal(left) > Decimal(right)
@staticmethod
def execute_gt(params):
return params[0] > params[1]
def execute_lt(self, params):
if DecimalBinaryOperator.check_null(params):
return False
(left, right), convert_float = DecimalBinaryOperator.get_convert_type(params)
if convert_float:
return float(left) < float(right)
else:
return Decimal(left) < Decimal(right)
@staticmethod
def execute_lt(params):
return params[0] < params[1]
def execute_ge(self, params):
if DecimalBinaryOperator.check_null(params):
return False
(left, right), convert_float = DecimalBinaryOperator.get_convert_type(params)
if convert_float:
return float(left) >= float(right)
else:
return Decimal(left) >= Decimal(right)
@staticmethod
def execute_ge(params):
return params[0] >= params[1]
@staticmethod
def execute_le(params):
return params[0] <= params[1]
def execute_le(self, params):
if DecimalBinaryOperator.check_null(params):
return False
(left, right), convert_float = DecimalBinaryOperator.get_convert_type(params)
if convert_float:
return float(left) <= float(right)
else:
return Decimal(left) <= Decimal(right)
@staticmethod
def get_all_binary_ops() -> List[DecimalColumnExpr]:
@ -1184,7 +1221,7 @@ class TDTestCase:
f"{self.c_table_prefix}{i}",
self.stb_columns,
self.tags,
).insert(1000, 1537146000000, 500)
).insert(tb_insert_rows, 1537146000000, 500)
for i in range(self.c_table_num):
TableDataValidator(
@ -1193,7 +1230,7 @@ class TDTestCase:
TableInserter(
tdSql, self.db_name, self.norm_table_name, self.norm_tb_columns
).insert(1000, 1537146000000, 500, flush_database=True)
).insert(tb_insert_rows, 1537146000000, 500, flush_database=True)
TableDataValidator(
self.norm_tb_columns, self.norm_table_name, self.db_name
).validate()
@ -1381,13 +1418,14 @@ class TDTestCase:
all_type_columns = Column.get_decimal_oper_const_cols()
## decimal operator with constants of all other types
self.check_decimal_binary_expr_results(
self.db_name,
self.norm_table_name,
self.norm_tb_columns,
all_type_columns,
binary_operators,
)
for i in range(operator_test_round):
self.check_decimal_binary_expr_results(
self.db_name,
self.norm_table_name,
self.norm_tb_columns,
all_type_columns,
binary_operators,
)
## decimal operator with columns of all other types