diff --git a/source/libs/decimal/src/decimal.c b/source/libs/decimal/src/decimal.c index 71604b66b2..f71168effa 100644 --- a/source/libs/decimal/src/decimal.c +++ b/source/libs/decimal/src/decimal.c @@ -338,7 +338,7 @@ static bool decimal64Lt(const DecimalType* pLeft, const DecimalType* pRight, static bool decimal64Gt(const DecimalType* pLeft, const DecimalType* pRight, uint8_t rightWordNum); static bool decimal64Eq(const DecimalType* pLeft, const DecimalType* pRight, uint8_t rightWordNum); static int32_t decimal64ToStr(const DecimalType* pInt, uint8_t scale, char* pBuf, int32_t bufLen); -static void decimal64ScaleDown(Decimal64* pDec, uint8_t scaleDown); +static void decimal64ScaleDown(Decimal64* pDec, uint8_t scaleDown, bool round); static void decimal64ScaleUp(Decimal64* pDec, uint8_t scaleUp); static void decimal64ScaleTo(Decimal64* pDec, uint8_t oldScale, uint8_t newScale); @@ -358,7 +358,7 @@ static bool decimal128Gt(const DecimalType* pLeft, const DecimalType* pRight, static bool decimal128Eq(const DecimalType* pLeft, const DecimalType* pRight, uint8_t rightWordNum); static int32_t decimal128ToStr(const DecimalType* pInt, uint8_t scale, char* pBuf, int32_t bufLen); static void decimal128ScaleTo(Decimal128* pDec, uint8_t oldScale, uint8_t newScale); -static void decimal128ScaleDown(Decimal128* pDec, uint8_t scaleDown); +static void decimal128ScaleDown(Decimal128* pDec, uint8_t scaleDown, bool round); static void decimal128ScaleUp(Decimal128* pDec, uint8_t scaleUp); static int32_t decimal128CountLeadingBinaryZeros(const Decimal128* pDec); static int32_t decimal128FromInt64(DecimalType* pDec, uint8_t prec, uint8_t scale, int64_t val); @@ -793,7 +793,7 @@ static void decimalAddLargePositive(Decimal* pX, const SDataType* pXT, const Dec decimal128Add(&right, &fracY, WORD_NUM(Decimal)); } - decimal128ScaleDown(&right, maxScale - pOT->scale); + decimal128ScaleDown(&right, maxScale - pOT->scale, true); decimal128Add(&wholeX, &wholeY, WORD_NUM(Decimal)); decimal128Add(&wholeX, &carry, WORD_NUM(Decimal)); decimal128Multiply(&wholeX, &SCALE_MULTIPLIER_128[pOT->scale], WORD_NUM(Decimal)); @@ -822,7 +822,7 @@ static void decimalAddLargeNegative(Decimal* pX, const SDataType* pXT, const Dec decimal128Add(&fracX, &SCALE_MULTIPLIER_128[maxScale], WORD_NUM(Decimal)); } - decimal128ScaleDown(&fracX, maxScale - pOT->scale); + decimal128ScaleDown(&fracX, maxScale - pOT->scale, true); decimal128Multiply(&wholeX, &SCALE_MULTIPLIER_128[pOT->scale], WORD_NUM(Decimal)); decimal128Add(&wholeX, &fracX, WORD_NUM(Decimal)); *pX = wholeX; @@ -951,7 +951,7 @@ static int32_t decimalMultiply(Decimal* pX, const SDataType* pXT, const Decimal* // no need to trim scale if (deltaScale <= 38) { decimal128Multiply(pX, pY, WORD_NUM(Decimal)); - decimal128ScaleDown(pX, deltaScale); + decimal128ScaleDown(pX, deltaScale, true); } else { makeDecimal128(pX, 0, 0); } @@ -1470,16 +1470,18 @@ int32_t convertToDecimal(const void* pData, const SDataType* pInputType, void* p return code; } -void decimal64ScaleDown(Decimal64* pDec, uint8_t scaleDown) { +void decimal64ScaleDown(Decimal64* pDec, uint8_t scaleDown, bool round) { if (scaleDown > 0) { Decimal64 divisor = SCALE_MULTIPLIER_64[scaleDown], remainder = {0}; decimal64divide(pDec, &divisor, WORD_NUM(Decimal64), &remainder); - decimal64Abs(&remainder); - Decimal64 half = SCALE_MULTIPLIER_64[scaleDown]; - decimal64divide(&half, &decimal64Two, WORD_NUM(Decimal64), NULL); - if (!decimal64Lt(&remainder, &half, WORD_NUM(Decimal64))) { - Decimal64 delta = {DECIMAL64_SIGN(pDec)}; - //decimal64Add(pDec, &delta, WORD_NUM(Decimal64)); + if (round) { + decimal64Abs(&remainder); + Decimal64 half = SCALE_MULTIPLIER_64[scaleDown]; + decimal64divide(&half, &decimal64Two, WORD_NUM(Decimal64), NULL); + if (!decimal64Lt(&remainder, &half, WORD_NUM(Decimal64))) { + Decimal64 delta = {DECIMAL64_SIGN(pDec)}; + decimal64Add(pDec, &delta, WORD_NUM(Decimal64)); + } } } } @@ -1495,7 +1497,7 @@ static void decimal64ScaleTo(Decimal64* pDec, uint8_t oldScale, uint8_t newScale if (newScale > oldScale) decimal64ScaleUp(pDec, newScale - oldScale); else if (newScale < oldScale) - decimal64ScaleDown(pDec, oldScale - newScale); + decimal64ScaleDown(pDec, oldScale - newScale, true); } static void decimal64ScaleAndCheckOverflow(Decimal64* pDec, int8_t scale, uint8_t toPrec, uint8_t toScale, @@ -1513,7 +1515,7 @@ static void decimal64ScaleAndCheckOverflow(Decimal64* pDec, int8_t scale, uint8_ } } else if (deltaScale < 0) { Decimal64 res = *pDec, max = {0}; - decimal64ScaleDown(&res, -deltaScale); + decimal64ScaleDown(&res, -deltaScale, false); DECIMAL64_GET_MAX(toPrec, &max); if (decimal64Gt(&res, &max, WORD_NUM(Decimal64))) { if (overflow) *overflow = true; @@ -1609,10 +1611,19 @@ int32_t decimal64FromStr(const char* str, int32_t len, uint8_t expectPrecision, } // TODO wjm add round param -static void decimal128ScaleDown(Decimal128* pDec, uint8_t scaleDown) { +static void decimal128ScaleDown(Decimal128* pDec, uint8_t scaleDown, bool round) { if (scaleDown > 0) { - Decimal128 divisor = SCALE_MULTIPLIER_128[scaleDown]; - decimal128Divide(pDec, &divisor, 2, NULL); + Decimal128 divisor = SCALE_MULTIPLIER_128[scaleDown], remainder = {0}; + decimal128Divide(pDec, &divisor, 2, &remainder); + if (round) { + decimal128Abs(&remainder); + Decimal128 half = SCALE_MULTIPLIER_128[scaleDown]; + decimal128Divide(&half, &decimal128Two, 2, NULL); + if (!decimal128Lt(&remainder, &half, WORD_NUM(Decimal128))) { + Decimal64 delta = {DECIMAL128_SIGN(pDec)}; + decimal128Add(pDec, &delta, WORD_NUM(Decimal64)); + } + } } } @@ -1627,7 +1638,7 @@ static void decimal128ScaleTo(Decimal128* pDec, uint8_t oldScale, uint8_t newSca if (newScale > oldScale) decimal128ScaleUp(pDec, newScale - oldScale); else if (newScale < oldScale) - decimal128ScaleDown(pDec, oldScale - newScale); + decimal128ScaleDown(pDec, oldScale - newScale, true); } int32_t decimal128FromStr(const char* str, int32_t len, uint8_t expectPrecision, uint8_t expectScale, @@ -1781,7 +1792,7 @@ static void decimal128ModifyScaleAndPrecision(Decimal128* pDec, uint8_t scale, u } } else { Decimal128 res = *pDec, max = {0}; - decimal128ScaleDown(&res, -deltaScale); + decimal128ScaleDown(&res, -deltaScale, false); DECIMAL128_GET_MAX(toPrec, &max); if (decimal128Gt(&res, &max, WORD_NUM(Decimal128))) { if (overflow) *overflow = true; diff --git a/tests/system-test/2-query/decimal.py b/tests/system-test/2-query/decimal.py index 1e02c4bb2a..6616ee8b94 100644 --- a/tests/system-test/2-query/decimal.py +++ b/tests/system-test/2-query/decimal.py @@ -192,14 +192,17 @@ class DecimalColumnExpr: return f"({self.format_})".format(*self.params_) def execute(self, params): - return self.executor_(params) - + return self.executor_(self, params) + def get_val(self, tbname: str, idx: int): params: Tuple = () for p in self.params_: params = params + (p.get_val(tbname, idx),) return self.execute(params) - + + def get_input_types(self) -> List: + pass + def check(self, query_col_res: List, tbname: str): for i in range(len(query_col_res)): v_from_query = query_col_res[i] @@ -226,12 +229,13 @@ class DecimalColumnExpr: failed = dec_from_query != dec_from_insert if failed: tdLog.exit( - f"check decimal column failed for expr: {self} params: {params}, query: {v_from_query}, expect {dec_from_insert}, but get {dec_from_query}") + f"check decimal column failed for expr: {self}, input: {[t.__str__() for t in self.get_input_types()]}, res_type: {self.res_type_}, params: {params}, query: {v_from_query}, expect {dec_from_insert}, but get {dec_from_query}" + ) else: tdLog.debug( - f"check decimal succ for expr: {self}, params: {params}, insert:{v_from_calc_in_py} query:{v_from_query}, py dec: {dec_from_insert}" + f"check decimal succ for expr: {self}, input: {[t.__str__() for t in self.get_input_types()]}, res_type: {self.res_type_}, params: {params}, insert:{v_from_calc_in_py} query:{v_from_query}, py dec: {dec_from_insert}" ) - + ## format_params are already been set def generate_res_type(self): pass @@ -298,7 +302,6 @@ class TypeEnum: return type_str[type] - class DataType: def __init__(self, type: int, length: int = 0, type_mod: int = 0): self.type: int = type @@ -336,7 +339,7 @@ class DataType: ## TODO generate NULL, None def generate_value(self) -> str: if self.type == TypeEnum.BOOL: - return str(secrets.randbelow(2)) + return ['true', 'false'][secrets.randbelow(2)] if self.type == TypeEnum.TINYINT: return str(secrets.randbelow(256) - 128) if self.type == TypeEnum.SMALLINT: @@ -375,6 +378,11 @@ class DataType: def get_typed_val(self, val): if self.type == TypeEnum.FLOAT or self.type == TypeEnum.DOUBLE: return float(val) + if self.type == TypeEnum.BOOL: + if val == "true": + return 1 + else: + return 0 return val class DecimalType(DataType): @@ -479,18 +487,18 @@ class Column: def is_constant_col(self): return '' in self.saved_vals.keys() - + def get_typed_val(self, val): return self.type_.get_typed_val(val) - + def get_constant_val(self): return self.get_typed_val(self.saved_vals[''][0]) - + def __str__(self): if self.is_constant_col(): return str(self.get_constant_val()) return self.name_ - + def get_val(self, tbname: str, idx: int): if self.is_constant_col(): return self.get_constant_val() @@ -513,7 +521,7 @@ class Column: def check(self, values, offset: int): return self.type_.check(values, offset) - + def construct_type_value(self, val: str): if ( self.type_.type == TypeEnum.BINARY @@ -525,7 +533,7 @@ class Column: return f"'{val}'" else: return val - + @staticmethod def get_decimal_unsupported_types() -> list: return [ @@ -533,11 +541,23 @@ class Column: TypeEnum.GEOMETRY, TypeEnum.VARBINARY, ] - + @staticmethod def get_decimal_oper_const_cols() -> list: - return Column.get_all_type_columns(Column.get_decimal_unsupported_types() + Column.get_decimal_types()) - + types_unable_to_be_const = [ + TypeEnum.TINYINT, + TypeEnum.SMALLINT, + TypeEnum.INT, + TypeEnum.UINT, + TypeEnum.USMALLINT, + TypeEnum.UTINYINT, + ] + return Column.get_all_type_columns( + Column.get_decimal_unsupported_types() + + Column.get_decimal_types() + + types_unable_to_be_const + ) + @staticmethod def get_decimal_types() -> List: return [TypeEnum.DECIMAL, TypeEnum.DECIMAL64] @@ -700,6 +720,8 @@ class DecimalBinaryOperator(DecimalColumnExpr): def __init__(self, format, executor, op: str): super().__init__(format, executor) self.op_ = op + self.left_type_: DataType = None + self.right_type_: DataType = None def __str__(self): return super().__str__() @@ -727,13 +749,13 @@ class DecimalBinaryOperator(DecimalColumnExpr): if not left.is_decimal_type(): left_prec = TypeEnum.get_type_prec(left.type) else: - left_prec = DecimalType(left).prec() - left_scale = DecimalType(left).scale() + left_prec = left.prec() + left_scale = left.scale() if not right.is_decimal_type(): right_prec = TypeEnum.get_type_prec(right.type) else: - right_prec = DecimalType(right).prec() - right_scale = DecimalType(right).scale() + right_prec = right.prec() + right_scale = right.scale() out_prec = 0 out_scale = 0 @@ -762,12 +784,22 @@ class DecimalBinaryOperator(DecimalColumnExpr): def generate_res_type(self): if DecimalBinaryOperator.is_compare_op(self.op_): self.res_type_ = DataType(TypeEnum.BOOL) + return left_type = self.params_[0].type_ right_type = self.params_[1].type_ + self.left_type_ = left_type + self.right_type_ = right_type ret_double_types = [TypeEnum.VARCHAR, TypeEnum.BINARY, TypeEnum.DOUBLE, TypeEnum.FLOAT] if left_type.type in ret_double_types or right_type.type in ret_double_types: self.res_type_ = DataType(TypeEnum.DOUBLE) - self.res_type_ = DecimalBinaryOperator.calc_decimal_prec_scale(left_type, right_type, self.op_) + else: + self.res_type_ = DecimalBinaryOperator.calc_decimal_prec_scale(left_type, right_type, self.op_) + + def get_input_types(self)-> list: + return [self.left_type_, self.right_type_] + + def convert_to_res_type(self, val: Decimal) -> Decimal: + return val.quantize(Decimal("1." + "0" * self.res_type_.scale()), ROUND_HALF_UP) @staticmethod def get_ret_type(params) -> Tuple: @@ -784,49 +816,50 @@ class DecimalBinaryOperator(DecimalColumnExpr): ret_float = True return (left, right), ret_float - @staticmethod - def execute_plus(params): + def execute_plus(self, params): if DecimalBinaryOperator.check_null(params): return 'NULL' (left, right), ret_float = DecimalBinaryOperator.get_ret_type(params) if ret_float: return float(left) + float(right) else: - return Decimal(left) + Decimal(right) + return self.convert_to_res_type(Decimal(left) + Decimal(right)) - @staticmethod - def execute_minus(params): + def execute_minus(self, params): if DecimalBinaryOperator.check_null(params): return 'NULL' (left, right), ret_float = DecimalBinaryOperator.get_ret_type(params) if ret_float: return float(left) - float(right) else: - return Decimal(left) - Decimal(right) + return self.convert_to_res_type(Decimal(left) - Decimal(right)) - @staticmethod - def execute_mul(params): + def execute_mul(self, params): if DecimalBinaryOperator.check_null(params): return 'NULL' (left, right), ret_float = DecimalBinaryOperator.get_ret_type(params) if ret_float: return float(left) * float(right) else: - return Decimal(left) * Decimal(right) + return self.convert_to_res_type(Decimal(left) * Decimal(right)) - @staticmethod - def execute_div(params): + def execute_div(self, params): if DecimalBinaryOperator.check_null(params): return 'NULL' (left, right), ret_float = DecimalBinaryOperator.get_ret_type(params) if ret_float: return float(left) / float(right) else: - return Decimal(left) / Decimal(right) + return self.convert_to_res_type(Decimal(left) / Decimal(right)) - @staticmethod - def execute_mod(params): - return params[0] % params[1] + def execute_mod(self, params): + if DecimalBinaryOperator.check_null(params): + return 'NULL' + (left, right), ret_float = DecimalBinaryOperator.get_ret_type(params) + if ret_float: + return float(left) % float(right) + else: + return self.convert_to_res_type(Decimal(left) % Decimal(right)) @staticmethod def execute_eq(params):