decimal operator test

This commit is contained in:
wangjiaming0909 2025-03-01 12:26:38 +08:00
parent 79c7ff8bf2
commit 759c9af591
2 changed files with 100 additions and 56 deletions

View File

@ -338,7 +338,7 @@ static bool decimal64Lt(const DecimalType* pLeft, const DecimalType* pRight,
static bool decimal64Gt(const DecimalType* pLeft, const DecimalType* pRight, uint8_t rightWordNum);
static bool decimal64Eq(const DecimalType* pLeft, const DecimalType* pRight, uint8_t rightWordNum);
static int32_t decimal64ToStr(const DecimalType* pInt, uint8_t scale, char* pBuf, int32_t bufLen);
static void decimal64ScaleDown(Decimal64* pDec, uint8_t scaleDown);
static void decimal64ScaleDown(Decimal64* pDec, uint8_t scaleDown, bool round);
static void decimal64ScaleUp(Decimal64* pDec, uint8_t scaleUp);
static void decimal64ScaleTo(Decimal64* pDec, uint8_t oldScale, uint8_t newScale);
@ -358,7 +358,7 @@ static bool decimal128Gt(const DecimalType* pLeft, const DecimalType* pRight,
static bool decimal128Eq(const DecimalType* pLeft, const DecimalType* pRight, uint8_t rightWordNum);
static int32_t decimal128ToStr(const DecimalType* pInt, uint8_t scale, char* pBuf, int32_t bufLen);
static void decimal128ScaleTo(Decimal128* pDec, uint8_t oldScale, uint8_t newScale);
static void decimal128ScaleDown(Decimal128* pDec, uint8_t scaleDown);
static void decimal128ScaleDown(Decimal128* pDec, uint8_t scaleDown, bool round);
static void decimal128ScaleUp(Decimal128* pDec, uint8_t scaleUp);
static int32_t decimal128CountLeadingBinaryZeros(const Decimal128* pDec);
static int32_t decimal128FromInt64(DecimalType* pDec, uint8_t prec, uint8_t scale, int64_t val);
@ -793,7 +793,7 @@ static void decimalAddLargePositive(Decimal* pX, const SDataType* pXT, const Dec
decimal128Add(&right, &fracY, WORD_NUM(Decimal));
}
decimal128ScaleDown(&right, maxScale - pOT->scale);
decimal128ScaleDown(&right, maxScale - pOT->scale, true);
decimal128Add(&wholeX, &wholeY, WORD_NUM(Decimal));
decimal128Add(&wholeX, &carry, WORD_NUM(Decimal));
decimal128Multiply(&wholeX, &SCALE_MULTIPLIER_128[pOT->scale], WORD_NUM(Decimal));
@ -822,7 +822,7 @@ static void decimalAddLargeNegative(Decimal* pX, const SDataType* pXT, const Dec
decimal128Add(&fracX, &SCALE_MULTIPLIER_128[maxScale], WORD_NUM(Decimal));
}
decimal128ScaleDown(&fracX, maxScale - pOT->scale);
decimal128ScaleDown(&fracX, maxScale - pOT->scale, true);
decimal128Multiply(&wholeX, &SCALE_MULTIPLIER_128[pOT->scale], WORD_NUM(Decimal));
decimal128Add(&wholeX, &fracX, WORD_NUM(Decimal));
*pX = wholeX;
@ -951,7 +951,7 @@ static int32_t decimalMultiply(Decimal* pX, const SDataType* pXT, const Decimal*
// no need to trim scale
if (deltaScale <= 38) {
decimal128Multiply(pX, pY, WORD_NUM(Decimal));
decimal128ScaleDown(pX, deltaScale);
decimal128ScaleDown(pX, deltaScale, true);
} else {
makeDecimal128(pX, 0, 0);
}
@ -1470,16 +1470,18 @@ int32_t convertToDecimal(const void* pData, const SDataType* pInputType, void* p
return code;
}
void decimal64ScaleDown(Decimal64* pDec, uint8_t scaleDown) {
void decimal64ScaleDown(Decimal64* pDec, uint8_t scaleDown, bool round) {
if (scaleDown > 0) {
Decimal64 divisor = SCALE_MULTIPLIER_64[scaleDown], remainder = {0};
decimal64divide(pDec, &divisor, WORD_NUM(Decimal64), &remainder);
if (round) {
decimal64Abs(&remainder);
Decimal64 half = SCALE_MULTIPLIER_64[scaleDown];
decimal64divide(&half, &decimal64Two, WORD_NUM(Decimal64), NULL);
if (!decimal64Lt(&remainder, &half, WORD_NUM(Decimal64))) {
Decimal64 delta = {DECIMAL64_SIGN(pDec)};
//decimal64Add(pDec, &delta, WORD_NUM(Decimal64));
decimal64Add(pDec, &delta, WORD_NUM(Decimal64));
}
}
}
}
@ -1495,7 +1497,7 @@ static void decimal64ScaleTo(Decimal64* pDec, uint8_t oldScale, uint8_t newScale
if (newScale > oldScale)
decimal64ScaleUp(pDec, newScale - oldScale);
else if (newScale < oldScale)
decimal64ScaleDown(pDec, oldScale - newScale);
decimal64ScaleDown(pDec, oldScale - newScale, true);
}
static void decimal64ScaleAndCheckOverflow(Decimal64* pDec, int8_t scale, uint8_t toPrec, uint8_t toScale,
@ -1513,7 +1515,7 @@ static void decimal64ScaleAndCheckOverflow(Decimal64* pDec, int8_t scale, uint8_
}
} else if (deltaScale < 0) {
Decimal64 res = *pDec, max = {0};
decimal64ScaleDown(&res, -deltaScale);
decimal64ScaleDown(&res, -deltaScale, false);
DECIMAL64_GET_MAX(toPrec, &max);
if (decimal64Gt(&res, &max, WORD_NUM(Decimal64))) {
if (overflow) *overflow = true;
@ -1609,10 +1611,19 @@ int32_t decimal64FromStr(const char* str, int32_t len, uint8_t expectPrecision,
}
// TODO wjm add round param
static void decimal128ScaleDown(Decimal128* pDec, uint8_t scaleDown) {
static void decimal128ScaleDown(Decimal128* pDec, uint8_t scaleDown, bool round) {
if (scaleDown > 0) {
Decimal128 divisor = SCALE_MULTIPLIER_128[scaleDown];
decimal128Divide(pDec, &divisor, 2, NULL);
Decimal128 divisor = SCALE_MULTIPLIER_128[scaleDown], remainder = {0};
decimal128Divide(pDec, &divisor, 2, &remainder);
if (round) {
decimal128Abs(&remainder);
Decimal128 half = SCALE_MULTIPLIER_128[scaleDown];
decimal128Divide(&half, &decimal128Two, 2, NULL);
if (!decimal128Lt(&remainder, &half, WORD_NUM(Decimal128))) {
Decimal64 delta = {DECIMAL128_SIGN(pDec)};
decimal128Add(pDec, &delta, WORD_NUM(Decimal64));
}
}
}
}
@ -1627,7 +1638,7 @@ static void decimal128ScaleTo(Decimal128* pDec, uint8_t oldScale, uint8_t newSca
if (newScale > oldScale)
decimal128ScaleUp(pDec, newScale - oldScale);
else if (newScale < oldScale)
decimal128ScaleDown(pDec, oldScale - newScale);
decimal128ScaleDown(pDec, oldScale - newScale, true);
}
int32_t decimal128FromStr(const char* str, int32_t len, uint8_t expectPrecision, uint8_t expectScale,
@ -1781,7 +1792,7 @@ static void decimal128ModifyScaleAndPrecision(Decimal128* pDec, uint8_t scale, u
}
} else {
Decimal128 res = *pDec, max = {0};
decimal128ScaleDown(&res, -deltaScale);
decimal128ScaleDown(&res, -deltaScale, false);
DECIMAL128_GET_MAX(toPrec, &max);
if (decimal128Gt(&res, &max, WORD_NUM(Decimal128))) {
if (overflow) *overflow = true;

View File

@ -192,7 +192,7 @@ class DecimalColumnExpr:
return f"({self.format_})".format(*self.params_)
def execute(self, params):
return self.executor_(params)
return self.executor_(self, params)
def get_val(self, tbname: str, idx: int):
params: Tuple = ()
@ -200,6 +200,9 @@ class DecimalColumnExpr:
params = params + (p.get_val(tbname, idx),)
return self.execute(params)
def get_input_types(self) -> List:
pass
def check(self, query_col_res: List, tbname: str):
for i in range(len(query_col_res)):
v_from_query = query_col_res[i]
@ -226,10 +229,11 @@ class DecimalColumnExpr:
failed = dec_from_query != dec_from_insert
if failed:
tdLog.exit(
f"check decimal column failed for expr: {self} params: {params}, query: {v_from_query}, expect {dec_from_insert}, but get {dec_from_query}")
f"check decimal column failed for expr: {self}, input: {[t.__str__() for t in self.get_input_types()]}, res_type: {self.res_type_}, params: {params}, query: {v_from_query}, expect {dec_from_insert}, but get {dec_from_query}"
)
else:
tdLog.debug(
f"check decimal succ for expr: {self}, params: {params}, insert:{v_from_calc_in_py} query:{v_from_query}, py dec: {dec_from_insert}"
f"check decimal succ for expr: {self}, input: {[t.__str__() for t in self.get_input_types()]}, res_type: {self.res_type_}, params: {params}, insert:{v_from_calc_in_py} query:{v_from_query}, py dec: {dec_from_insert}"
)
## format_params are already been set
@ -298,7 +302,6 @@ class TypeEnum:
return type_str[type]
class DataType:
def __init__(self, type: int, length: int = 0, type_mod: int = 0):
self.type: int = type
@ -336,7 +339,7 @@ class DataType:
## TODO generate NULL, None
def generate_value(self) -> str:
if self.type == TypeEnum.BOOL:
return str(secrets.randbelow(2))
return ['true', 'false'][secrets.randbelow(2)]
if self.type == TypeEnum.TINYINT:
return str(secrets.randbelow(256) - 128)
if self.type == TypeEnum.SMALLINT:
@ -375,6 +378,11 @@ class DataType:
def get_typed_val(self, val):
if self.type == TypeEnum.FLOAT or self.type == TypeEnum.DOUBLE:
return float(val)
if self.type == TypeEnum.BOOL:
if val == "true":
return 1
else:
return 0
return val
class DecimalType(DataType):
@ -536,7 +544,19 @@ class Column:
@staticmethod
def get_decimal_oper_const_cols() -> list:
return Column.get_all_type_columns(Column.get_decimal_unsupported_types() + Column.get_decimal_types())
types_unable_to_be_const = [
TypeEnum.TINYINT,
TypeEnum.SMALLINT,
TypeEnum.INT,
TypeEnum.UINT,
TypeEnum.USMALLINT,
TypeEnum.UTINYINT,
]
return Column.get_all_type_columns(
Column.get_decimal_unsupported_types()
+ Column.get_decimal_types()
+ types_unable_to_be_const
)
@staticmethod
def get_decimal_types() -> List:
@ -700,6 +720,8 @@ class DecimalBinaryOperator(DecimalColumnExpr):
def __init__(self, format, executor, op: str):
super().__init__(format, executor)
self.op_ = op
self.left_type_: DataType = None
self.right_type_: DataType = None
def __str__(self):
return super().__str__()
@ -727,13 +749,13 @@ class DecimalBinaryOperator(DecimalColumnExpr):
if not left.is_decimal_type():
left_prec = TypeEnum.get_type_prec(left.type)
else:
left_prec = DecimalType(left).prec()
left_scale = DecimalType(left).scale()
left_prec = left.prec()
left_scale = left.scale()
if not right.is_decimal_type():
right_prec = TypeEnum.get_type_prec(right.type)
else:
right_prec = DecimalType(right).prec()
right_scale = DecimalType(right).scale()
right_prec = right.prec()
right_scale = right.scale()
out_prec = 0
out_scale = 0
@ -762,13 +784,23 @@ class DecimalBinaryOperator(DecimalColumnExpr):
def generate_res_type(self):
if DecimalBinaryOperator.is_compare_op(self.op_):
self.res_type_ = DataType(TypeEnum.BOOL)
return
left_type = self.params_[0].type_
right_type = self.params_[1].type_
self.left_type_ = left_type
self.right_type_ = right_type
ret_double_types = [TypeEnum.VARCHAR, TypeEnum.BINARY, TypeEnum.DOUBLE, TypeEnum.FLOAT]
if left_type.type in ret_double_types or right_type.type in ret_double_types:
self.res_type_ = DataType(TypeEnum.DOUBLE)
else:
self.res_type_ = DecimalBinaryOperator.calc_decimal_prec_scale(left_type, right_type, self.op_)
def get_input_types(self)-> list:
return [self.left_type_, self.right_type_]
def convert_to_res_type(self, val: Decimal) -> Decimal:
return val.quantize(Decimal("1." + "0" * self.res_type_.scale()), ROUND_HALF_UP)
@staticmethod
def get_ret_type(params) -> Tuple:
ret_float = False
@ -784,49 +816,50 @@ class DecimalBinaryOperator(DecimalColumnExpr):
ret_float = True
return (left, right), ret_float
@staticmethod
def execute_plus(params):
def execute_plus(self, params):
if DecimalBinaryOperator.check_null(params):
return 'NULL'
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
if ret_float:
return float(left) + float(right)
else:
return Decimal(left) + Decimal(right)
return self.convert_to_res_type(Decimal(left) + Decimal(right))
@staticmethod
def execute_minus(params):
def execute_minus(self, params):
if DecimalBinaryOperator.check_null(params):
return 'NULL'
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
if ret_float:
return float(left) - float(right)
else:
return Decimal(left) - Decimal(right)
return self.convert_to_res_type(Decimal(left) - Decimal(right))
@staticmethod
def execute_mul(params):
def execute_mul(self, params):
if DecimalBinaryOperator.check_null(params):
return 'NULL'
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
if ret_float:
return float(left) * float(right)
else:
return Decimal(left) * Decimal(right)
return self.convert_to_res_type(Decimal(left) * Decimal(right))
@staticmethod
def execute_div(params):
def execute_div(self, params):
if DecimalBinaryOperator.check_null(params):
return 'NULL'
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
if ret_float:
return float(left) / float(right)
else:
return Decimal(left) / Decimal(right)
return self.convert_to_res_type(Decimal(left) / Decimal(right))
@staticmethod
def execute_mod(params):
return params[0] % params[1]
def execute_mod(self, params):
if DecimalBinaryOperator.check_null(params):
return 'NULL'
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
if ret_float:
return float(left) % float(right)
else:
return self.convert_to_res_type(Decimal(left) % Decimal(right))
@staticmethod
def execute_eq(params):