decimal test operators

This commit is contained in:
wangjiaming0909 2025-02-28 13:12:17 +08:00
parent e3e79f014b
commit 79c7ff8bf2
3 changed files with 171 additions and 109 deletions

View File

@ -653,7 +653,7 @@ static void decimal128Divide(DecimalType* pLeft, const DecimalType* pRight, uint
Decimal128 right = {0}; Decimal128 right = {0};
DECIMAL128_CHECK_RIGHT_WORD_NUM(rightWordNum, pRightDec, right, pRight); DECIMAL128_CHECK_RIGHT_WORD_NUM(rightWordNum, pRightDec, right, pRight);
bool negate = DECIMAL128_SIGN(pLeftDec) != DECIMAL128_SIGN(pRightDec); bool leftNegate = DECIMAL128_SIGN(pLeftDec) == -1, rightNegate = DECIMAL128_SIGN(pRightDec) == -1;
UInt128 a = {0}, b = {0}, c = {0}, d = {0}; UInt128 a = {0}, b = {0}, c = {0}, d = {0};
Decimal128 x = *pLeftDec, y = *pRightDec; Decimal128 x = *pLeftDec, y = *pRightDec;
decimal128Abs(&x); decimal128Abs(&x);
@ -666,8 +666,8 @@ static void decimal128Divide(DecimalType* pLeft, const DecimalType* pRight, uint
uInt128Mod(&d, &b); uInt128Mod(&d, &b);
makeDecimal128(pLeftDec, uInt128Hi(&a), uInt128Lo(&a)); makeDecimal128(pLeftDec, uInt128Hi(&a), uInt128Lo(&a));
if (pRemainder) makeDecimal128(pRemainderDec, uInt128Hi(&d), uInt128Lo(&d)); if (pRemainder) makeDecimal128(pRemainderDec, uInt128Hi(&d), uInt128Lo(&d));
if (negate) decimal128Negate(pLeftDec); if (leftNegate != rightNegate) decimal128Negate(pLeftDec);
if (DECIMAL128_SIGN(pLeftDec) == -1 && pRemainder) decimal128Negate(pRemainderDec); if (leftNegate && pRemainder) decimal128Negate(pRemainderDec);
} }
static void decimal128Mod(DecimalType* pLeft, const DecimalType* pRight, uint8_t rightWordNum) { static void decimal128Mod(DecimalType* pLeft, const DecimalType* pRight, uint8_t rightWordNum) {

View File

@ -4013,9 +4013,17 @@ int32_t fltSclCompareDatum(SFltSclDatum *val1, SFltSclDatum *val2) {
return fltSclCompareWithFloat64(val1, val2); return fltSclCompareWithFloat64(val1, val2);
} }
case FLT_SCL_DATUM_KIND_DECIMAL64: { case FLT_SCL_DATUM_KIND_DECIMAL64: {
if (val1->kind == FLT_SCL_DATUM_KIND_NULL || val1->kind == FLT_SCL_DATUM_KIND_MIN ||
val1->kind == FLT_SCL_DATUM_KIND_MAX) {
return (val1->kind < val2->kind) ? -1 : ((val1->kind > val2->kind) ? 1 : 0);
}
return compareDecimal64SameScale(&val1->i, &val2->i); return compareDecimal64SameScale(&val1->i, &val2->i);
} }
case FLT_SCL_DATUM_KIND_DECIMAL: { case FLT_SCL_DATUM_KIND_DECIMAL: {
if (val1->kind == FLT_SCL_DATUM_KIND_NULL || val1->kind == FLT_SCL_DATUM_KIND_MIN ||
val1->kind == FLT_SCL_DATUM_KIND_MAX) {
return (val1->kind < val2->kind) ? -1 : ((val1->kind > val2->kind) ? 1 : 0);
}
return compareDecimal128SameScale(val1->pData, val2->pData); return compareDecimal128SameScale(val1->pData, val2->pData);
} }
// TODO: varchar/nchar // TODO: varchar/nchar

View File

@ -178,6 +178,7 @@ class TaosShell:
except Exception as e: except Exception as e:
tdLog.exit(f"Command '{sql}' failed with error: {e.stderr.decode('utf-8')}") tdLog.exit(f"Command '{sql}' failed with error: {e.stderr.decode('utf-8')}")
self.queryResult = [] self.queryResult = []
raise
return self.queryResult return self.queryResult
class DecimalColumnExpr: class DecimalColumnExpr:
@ -185,6 +186,7 @@ class DecimalColumnExpr:
self.format_: str = format self.format_: str = format
self.executor_ = executor self.executor_ = executor
self.params_: Tuple = () self.params_: Tuple = ()
self.res_type_: DataType = None
def __str__(self): def __str__(self):
return f"({self.format_})".format(*self.params_) return f"({self.format_})".format(*self.params_)
@ -230,9 +232,13 @@ class DecimalColumnExpr:
f"check decimal succ for expr: {self}, params: {params}, insert:{v_from_calc_in_py} query:{v_from_query}, py dec: {dec_from_insert}" f"check decimal succ for expr: {self}, params: {params}, insert:{v_from_calc_in_py} query:{v_from_query}, py dec: {dec_from_insert}"
) )
## format_params are already been set
def generate_res_type(self):
pass
def generate(self, format_params) -> str: def generate(self, format_params) -> str:
self.params_ = format_params self.params_ = format_params
self.generate_res_type()
return self.__str__() return self.__str__()
@ -258,50 +264,39 @@ class TypeEnum:
GEOMETRY = 20 GEOMETRY = 20
DECIMAL64 = 21 DECIMAL64 = 21
@staticmethod
def get_type_prec(type: int):
type_prec = [0, 1, 3, 5, 10, 19, 38, 38, 0, 19, 10, 3, 5, 10, 20, 0, 0, 0, 0, 0, 0, 0]
return type_prec[type]
@staticmethod @staticmethod
def get_type_str(type: int): def get_type_str(type: int):
if type == TypeEnum.BOOL: type_str = [
return "BOOL" "",
elif type == TypeEnum.TINYINT: "BOOL",
return "TINYINT" "TINYINT",
elif type == TypeEnum.SMALLINT: "SMALLINT",
return "SMALLINT" "INT",
elif type == TypeEnum.INT: "BIGINT",
return "INT" "FLOAT",
elif type == TypeEnum.BIGINT: "DOUBLE",
return "BIGINT" "VARCHAR",
elif type == TypeEnum.FLOAT: "TIMESTAMP",
return "FLOAT" "NCHAR",
elif type == TypeEnum.DOUBLE: "TINYINT UNSIGNED",
return "DOUBLE" "SMALLINT UNSIGNED",
elif type == TypeEnum.VARCHAR: "INT UNSIGNED",
return "VARCHAR" "BIGINT UNSIGNED",
elif type == TypeEnum.TIMESTAMP: "JSON",
return "TIMESTAMP" "VARBINARY",
elif type == TypeEnum.NCHAR: "DECIMAL",
return "NCHAR" "",
elif type == TypeEnum.UTINYINT: "",
return "TINYINT UNSIGNED" "GEOMETRY",
elif type == TypeEnum.USMALLINT: "DECIMAL",
return "SMALLINT UNSIGNED" ]
elif type == TypeEnum.UINT: return type_str[type]
return "INT UNSIGNED"
elif type == TypeEnum.UBIGINT:
return "BIGINT UNSIGNED"
elif type == TypeEnum.JSON:
return "JSON"
elif type == TypeEnum.VARBINARY:
return "VARBINARY"
elif type == TypeEnum.DECIMAL:
return "DECIMAL"
elif type == TypeEnum.BINARY:
return "BINARY"
elif type == TypeEnum.GEOMETRY:
return "GEOMETRY"
elif type == TypeEnum.DECIMAL64:
return "DECIMAL"
else:
raise Exception("unknow type")
class DataType: class DataType:
@ -471,6 +466,11 @@ class DecimalType(DataType):
) )
@staticmethod
def decimal_type_from_other_type(other: DataType):
prec = 0
return DecimalType(other.type, other.length, other.type_mod)
class Column: class Column:
def __init__(self, type: DataType): def __init__(self, type: DataType):
self.type_: DataType = type self.type_: DataType = type
@ -697,21 +697,81 @@ class TableDataValidator:
class DecimalBinaryOperator(DecimalColumnExpr): class DecimalBinaryOperator(DecimalColumnExpr):
def __init__(self, op: str): def __init__(self, format, executor, op: str):
super().__init__() super().__init__(format, executor)
self.op_ = op self.op_ = op
def __str__(self): def __str__(self):
return self.op_ return super().__str__()
def generate(self): def generate(self, format_params: Tuple) -> str:
pass return super().generate(format_params)
@staticmethod @staticmethod
def execute_plus(params): def check_null(params):
ret_float = False
if params[0] is None or params[1] is None: if params[0] is None or params[1] is None:
return 'NULL' return True
else:
return False
@staticmethod
def is_compare_op(op: str)-> bool:
return op in ["==", "!=", ">", "<", ">=", "<="]
@staticmethod
def calc_decimal_prec_scale(left: DataType, right: DataType, op: str) -> DecimalType:
left_prec = 0
left_scale = 0
right_prec = 0
right_scale = 0
if not left.is_decimal_type():
left_prec = TypeEnum.get_type_prec(left.type)
else:
left_prec = DecimalType(left).prec()
left_scale = DecimalType(left).scale()
if not right.is_decimal_type():
right_prec = TypeEnum.get_type_prec(right.type)
else:
right_prec = DecimalType(right).prec()
right_scale = DecimalType(right).scale()
out_prec = 0
out_scale = 0
if op in ['+', '-']:
out_scale = max(left_scale, right_scale)
out_prec = max(left_prec - left_scale, right_prec - right_scale) + out_scale + 1
elif op == '*':
out_scale = left_scale + right_scale
out_prec = left_prec + right_prec + 1
elif op == '/':
out_scale = max(left_scale + right_prec + 1, 6)
out_prec = left_prec - left_scale + right_scale + out_scale
elif op == '%':
out_scale = max(left_scale, right_scale)
out_prec = min(left_prec - left_scale, right_prec - right_scale) + out_scale
else:
raise Exception(f"unknown op for binary operators: {op}")
if out_prec > 38:
min_scale = min(6, out_scale)
delta = out_prec - 38
out_prec = 38
out_scale = max(min_scale, out_scale - delta)
return DecimalType(TypeEnum.DECIMAL, out_prec, out_scale)
def generate_res_type(self):
if DecimalBinaryOperator.is_compare_op(self.op_):
self.res_type_ = DataType(TypeEnum.BOOL)
left_type = self.params_[0].type_
right_type = self.params_[1].type_
ret_double_types = [TypeEnum.VARCHAR, TypeEnum.BINARY, TypeEnum.DOUBLE, TypeEnum.FLOAT]
if left_type.type in ret_double_types or right_type.type in ret_double_types:
self.res_type_ = DataType(TypeEnum.DOUBLE)
self.res_type_ = DecimalBinaryOperator.calc_decimal_prec_scale(left_type, right_type, self.op_)
@staticmethod
def get_ret_type(params) -> Tuple:
ret_float = False
if isinstance(params[0], float) or isinstance(params[1], float): if isinstance(params[0], float) or isinstance(params[1], float):
ret_float = True ret_float = True
left = params[0] left = params[0]
@ -722,6 +782,13 @@ class DecimalBinaryOperator(DecimalColumnExpr):
if isinstance(params[1], str): if isinstance(params[1], str):
right = right.strip("'") right = right.strip("'")
ret_float = True ret_float = True
return (left, right), ret_float
@staticmethod
def execute_plus(params):
if DecimalBinaryOperator.check_null(params):
return 'NULL'
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
if ret_float: if ret_float:
return float(left) + float(right) return float(left) + float(right)
else: else:
@ -729,15 +796,33 @@ class DecimalBinaryOperator(DecimalColumnExpr):
@staticmethod @staticmethod
def execute_minus(params): def execute_minus(params):
return params[0] - params[1] if DecimalBinaryOperator.check_null(params):
return 'NULL'
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
if ret_float:
return float(left) - float(right)
else:
return Decimal(left) - Decimal(right)
@staticmethod @staticmethod
def execute_mul(params): def execute_mul(params):
return params[0] * params[1] if DecimalBinaryOperator.check_null(params):
return 'NULL'
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
if ret_float:
return float(left) * float(right)
else:
return Decimal(left) * Decimal(right)
@staticmethod @staticmethod
def execute_div(params): def execute_div(params):
return params[0] / params[1] if DecimalBinaryOperator.check_null(params):
return 'NULL'
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
if ret_float:
return float(left) / float(right)
else:
return Decimal(left) / Decimal(right)
@staticmethod @staticmethod
def execute_mod(params): def execute_mod(params):
@ -770,43 +855,21 @@ class DecimalBinaryOperator(DecimalColumnExpr):
@staticmethod @staticmethod
def get_all_binary_ops() -> List[DecimalColumnExpr]: def get_all_binary_ops() -> List[DecimalColumnExpr]:
return [ return [
DecimalColumnExpr(" {0} + {1} ", DecimalBinaryOperator.execute_plus), DecimalBinaryOperator(" {0} + {1} ", DecimalBinaryOperator.execute_plus, "+"),
DecimalColumnExpr(" {0} - {1} ", DecimalBinaryOperator.execute_minus), DecimalBinaryOperator(" {0} - {1} ", DecimalBinaryOperator.execute_minus, "-"),
DecimalColumnExpr(" {0} * {1} ", DecimalBinaryOperator.execute_mul), DecimalBinaryOperator(" {0} * {1} ", DecimalBinaryOperator.execute_mul, "*"),
DecimalColumnExpr(" {0} / {1} ", DecimalBinaryOperator.execute_div), DecimalBinaryOperator(" {0} / {1} ", DecimalBinaryOperator.execute_div, "/"),
DecimalColumnExpr(" {0} % {1} ", DecimalBinaryOperator.execute_mod), DecimalBinaryOperator(" {0} % {1} ", DecimalBinaryOperator.execute_mod, "%"),
DecimalColumnExpr(" {0} == {1} ", DecimalBinaryOperator.execute_eq), DecimalBinaryOperator(" {0} == {1} ", DecimalBinaryOperator.execute_eq, "=="),
DecimalColumnExpr(" {0} != {1} ", DecimalBinaryOperator.execute_ne), DecimalBinaryOperator(" {0} != {1} ", DecimalBinaryOperator.execute_ne, "!="),
DecimalColumnExpr(" {0} > {1} ", DecimalBinaryOperator.execute_gt), DecimalBinaryOperator(" {0} > {1} ", DecimalBinaryOperator.execute_gt, ">"),
DecimalColumnExpr(" {0} < {1} ", DecimalBinaryOperator.execute_lt), DecimalBinaryOperator(" {0} < {1} ", DecimalBinaryOperator.execute_lt, "<"),
DecimalColumnExpr(" {0} >= {1} ", DecimalBinaryOperator.execute_ge), DecimalBinaryOperator(" {0} >= {1} ", DecimalBinaryOperator.execute_ge, ">="),
DecimalColumnExpr(" {0} <= {1} ", DecimalBinaryOperator.execute_le), DecimalBinaryOperator(" {0} <= {1} ", DecimalBinaryOperator.execute_le, "<="),
] ]
def execute(self, left, right): def execute(self, params):
if self.op_ == "+": return super().execute(params)
return left + right
if self.op_ == "-":
return left - right
if self.op_ == "*":
return left * right
if self.op_ == "/":
return left / right
if self.op_ == "%":
return left % right
if self.op_ == "==":
return left == right
if self.op_ == "!=":
return left != right
if self.op_ == ">":
return left > right
if self.op_ == "<":
return left < right
if self.op_ == ">=":
return left >= right
if self.op_ == "<=":
return left <= right
raise Exception(f"unsupport operator {self.op_}")
class DecimalBinaryOperatorIn(DecimalBinaryOperator): class DecimalBinaryOperatorIn(DecimalBinaryOperator):
@ -1178,6 +1241,8 @@ class TDTestCase:
): ):
for expr in exprs: for expr in exprs:
for col in tb_cols: for col in tb_cols:
if col.name_ == '':
continue
left_is_decimal = col.type_.is_decimal_type() left_is_decimal = col.type_.is_decimal_type()
for const_col in constant_cols: for const_col in constant_cols:
right_is_decimal = const_col.type_.is_decimal_type() right_is_decimal = const_col.type_.is_decimal_type()
@ -1187,7 +1252,10 @@ class TDTestCase:
select_expr = expr.generate((col, const_col)) select_expr = expr.generate((col, const_col))
sql = f"select {select_expr} from {dbname}.{tbname}" sql = f"select {select_expr} from {dbname}.{tbname}"
res = TaosShell().query(sql) res = TaosShell().query(sql)
if len(res) > 0:
expr.check(res[0], tbname) expr.check(res[0], tbname)
else:
tdLog.info(f"sql: {sql} got no output")
## query ## query
## build expr, expr.generate(column) to generate sql expr ## build expr, expr.generate(column) to generate sql expr
## pass this expr into DataValidator. ## pass this expr into DataValidator.
@ -1244,21 +1312,7 @@ class TDTestCase:
self.test_decimal_unsupported_types() self.test_decimal_unsupported_types()
## tables: meters, nt ## tables: meters, nt
## columns: c1, c2, c3, c4, c5, c7, c8, c9, c10, c99, c100 ## columns: c1, c2, c3, c4, c5, c7, c8, c9, c10, c99, c100
binary_operators = [ binary_operators = DecimalBinaryOperator.get_all_binary_ops()
DecimalColumnExpr("{0} + {1}", DecimalBinaryOperator.execute_plus),
# DecimalColumnExpr("-"),
# DecimalColumnExpr("*"),
# DecimalColumnExpr("/"),
# DecimalColumnExpr("%"),
# DecimalColumnExpr(">"),
# DecimalColumnExpr("<"),
# DecimalColumnExpr(">="),
# DecimalColumnExpr("<="),
# DecimalColumnExpr("=="),
# DecimalColumnExpr("!="),
# DecimalBinaryOperatorIn("in"),
# DecimalBinaryOperatorIn("not in"),
]
all_type_columns = Column.get_decimal_oper_const_cols() all_type_columns = Column.get_decimal_oper_const_cols()
## decimal operator with constants of all other types ## decimal operator with constants of all other types