decimal operator test
This commit is contained in:
parent
79c7ff8bf2
commit
759c9af591
|
@ -338,7 +338,7 @@ static bool decimal64Lt(const DecimalType* pLeft, const DecimalType* pRight,
|
||||||
static bool decimal64Gt(const DecimalType* pLeft, const DecimalType* pRight, uint8_t rightWordNum);
|
static bool decimal64Gt(const DecimalType* pLeft, const DecimalType* pRight, uint8_t rightWordNum);
|
||||||
static bool decimal64Eq(const DecimalType* pLeft, const DecimalType* pRight, uint8_t rightWordNum);
|
static bool decimal64Eq(const DecimalType* pLeft, const DecimalType* pRight, uint8_t rightWordNum);
|
||||||
static int32_t decimal64ToStr(const DecimalType* pInt, uint8_t scale, char* pBuf, int32_t bufLen);
|
static int32_t decimal64ToStr(const DecimalType* pInt, uint8_t scale, char* pBuf, int32_t bufLen);
|
||||||
static void decimal64ScaleDown(Decimal64* pDec, uint8_t scaleDown);
|
static void decimal64ScaleDown(Decimal64* pDec, uint8_t scaleDown, bool round);
|
||||||
static void decimal64ScaleUp(Decimal64* pDec, uint8_t scaleUp);
|
static void decimal64ScaleUp(Decimal64* pDec, uint8_t scaleUp);
|
||||||
static void decimal64ScaleTo(Decimal64* pDec, uint8_t oldScale, uint8_t newScale);
|
static void decimal64ScaleTo(Decimal64* pDec, uint8_t oldScale, uint8_t newScale);
|
||||||
|
|
||||||
|
@ -358,7 +358,7 @@ static bool decimal128Gt(const DecimalType* pLeft, const DecimalType* pRight,
|
||||||
static bool decimal128Eq(const DecimalType* pLeft, const DecimalType* pRight, uint8_t rightWordNum);
|
static bool decimal128Eq(const DecimalType* pLeft, const DecimalType* pRight, uint8_t rightWordNum);
|
||||||
static int32_t decimal128ToStr(const DecimalType* pInt, uint8_t scale, char* pBuf, int32_t bufLen);
|
static int32_t decimal128ToStr(const DecimalType* pInt, uint8_t scale, char* pBuf, int32_t bufLen);
|
||||||
static void decimal128ScaleTo(Decimal128* pDec, uint8_t oldScale, uint8_t newScale);
|
static void decimal128ScaleTo(Decimal128* pDec, uint8_t oldScale, uint8_t newScale);
|
||||||
static void decimal128ScaleDown(Decimal128* pDec, uint8_t scaleDown);
|
static void decimal128ScaleDown(Decimal128* pDec, uint8_t scaleDown, bool round);
|
||||||
static void decimal128ScaleUp(Decimal128* pDec, uint8_t scaleUp);
|
static void decimal128ScaleUp(Decimal128* pDec, uint8_t scaleUp);
|
||||||
static int32_t decimal128CountLeadingBinaryZeros(const Decimal128* pDec);
|
static int32_t decimal128CountLeadingBinaryZeros(const Decimal128* pDec);
|
||||||
static int32_t decimal128FromInt64(DecimalType* pDec, uint8_t prec, uint8_t scale, int64_t val);
|
static int32_t decimal128FromInt64(DecimalType* pDec, uint8_t prec, uint8_t scale, int64_t val);
|
||||||
|
@ -793,7 +793,7 @@ static void decimalAddLargePositive(Decimal* pX, const SDataType* pXT, const Dec
|
||||||
decimal128Add(&right, &fracY, WORD_NUM(Decimal));
|
decimal128Add(&right, &fracY, WORD_NUM(Decimal));
|
||||||
}
|
}
|
||||||
|
|
||||||
decimal128ScaleDown(&right, maxScale - pOT->scale);
|
decimal128ScaleDown(&right, maxScale - pOT->scale, true);
|
||||||
decimal128Add(&wholeX, &wholeY, WORD_NUM(Decimal));
|
decimal128Add(&wholeX, &wholeY, WORD_NUM(Decimal));
|
||||||
decimal128Add(&wholeX, &carry, WORD_NUM(Decimal));
|
decimal128Add(&wholeX, &carry, WORD_NUM(Decimal));
|
||||||
decimal128Multiply(&wholeX, &SCALE_MULTIPLIER_128[pOT->scale], WORD_NUM(Decimal));
|
decimal128Multiply(&wholeX, &SCALE_MULTIPLIER_128[pOT->scale], WORD_NUM(Decimal));
|
||||||
|
@ -822,7 +822,7 @@ static void decimalAddLargeNegative(Decimal* pX, const SDataType* pXT, const Dec
|
||||||
decimal128Add(&fracX, &SCALE_MULTIPLIER_128[maxScale], WORD_NUM(Decimal));
|
decimal128Add(&fracX, &SCALE_MULTIPLIER_128[maxScale], WORD_NUM(Decimal));
|
||||||
}
|
}
|
||||||
|
|
||||||
decimal128ScaleDown(&fracX, maxScale - pOT->scale);
|
decimal128ScaleDown(&fracX, maxScale - pOT->scale, true);
|
||||||
decimal128Multiply(&wholeX, &SCALE_MULTIPLIER_128[pOT->scale], WORD_NUM(Decimal));
|
decimal128Multiply(&wholeX, &SCALE_MULTIPLIER_128[pOT->scale], WORD_NUM(Decimal));
|
||||||
decimal128Add(&wholeX, &fracX, WORD_NUM(Decimal));
|
decimal128Add(&wholeX, &fracX, WORD_NUM(Decimal));
|
||||||
*pX = wholeX;
|
*pX = wholeX;
|
||||||
|
@ -951,7 +951,7 @@ static int32_t decimalMultiply(Decimal* pX, const SDataType* pXT, const Decimal*
|
||||||
// no need to trim scale
|
// no need to trim scale
|
||||||
if (deltaScale <= 38) {
|
if (deltaScale <= 38) {
|
||||||
decimal128Multiply(pX, pY, WORD_NUM(Decimal));
|
decimal128Multiply(pX, pY, WORD_NUM(Decimal));
|
||||||
decimal128ScaleDown(pX, deltaScale);
|
decimal128ScaleDown(pX, deltaScale, true);
|
||||||
} else {
|
} else {
|
||||||
makeDecimal128(pX, 0, 0);
|
makeDecimal128(pX, 0, 0);
|
||||||
}
|
}
|
||||||
|
@ -1470,16 +1470,18 @@ int32_t convertToDecimal(const void* pData, const SDataType* pInputType, void* p
|
||||||
return code;
|
return code;
|
||||||
}
|
}
|
||||||
|
|
||||||
void decimal64ScaleDown(Decimal64* pDec, uint8_t scaleDown) {
|
void decimal64ScaleDown(Decimal64* pDec, uint8_t scaleDown, bool round) {
|
||||||
if (scaleDown > 0) {
|
if (scaleDown > 0) {
|
||||||
Decimal64 divisor = SCALE_MULTIPLIER_64[scaleDown], remainder = {0};
|
Decimal64 divisor = SCALE_MULTIPLIER_64[scaleDown], remainder = {0};
|
||||||
decimal64divide(pDec, &divisor, WORD_NUM(Decimal64), &remainder);
|
decimal64divide(pDec, &divisor, WORD_NUM(Decimal64), &remainder);
|
||||||
decimal64Abs(&remainder);
|
if (round) {
|
||||||
Decimal64 half = SCALE_MULTIPLIER_64[scaleDown];
|
decimal64Abs(&remainder);
|
||||||
decimal64divide(&half, &decimal64Two, WORD_NUM(Decimal64), NULL);
|
Decimal64 half = SCALE_MULTIPLIER_64[scaleDown];
|
||||||
if (!decimal64Lt(&remainder, &half, WORD_NUM(Decimal64))) {
|
decimal64divide(&half, &decimal64Two, WORD_NUM(Decimal64), NULL);
|
||||||
Decimal64 delta = {DECIMAL64_SIGN(pDec)};
|
if (!decimal64Lt(&remainder, &half, WORD_NUM(Decimal64))) {
|
||||||
//decimal64Add(pDec, &delta, WORD_NUM(Decimal64));
|
Decimal64 delta = {DECIMAL64_SIGN(pDec)};
|
||||||
|
decimal64Add(pDec, &delta, WORD_NUM(Decimal64));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1495,7 +1497,7 @@ static void decimal64ScaleTo(Decimal64* pDec, uint8_t oldScale, uint8_t newScale
|
||||||
if (newScale > oldScale)
|
if (newScale > oldScale)
|
||||||
decimal64ScaleUp(pDec, newScale - oldScale);
|
decimal64ScaleUp(pDec, newScale - oldScale);
|
||||||
else if (newScale < oldScale)
|
else if (newScale < oldScale)
|
||||||
decimal64ScaleDown(pDec, oldScale - newScale);
|
decimal64ScaleDown(pDec, oldScale - newScale, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void decimal64ScaleAndCheckOverflow(Decimal64* pDec, int8_t scale, uint8_t toPrec, uint8_t toScale,
|
static void decimal64ScaleAndCheckOverflow(Decimal64* pDec, int8_t scale, uint8_t toPrec, uint8_t toScale,
|
||||||
|
@ -1513,7 +1515,7 @@ static void decimal64ScaleAndCheckOverflow(Decimal64* pDec, int8_t scale, uint8_
|
||||||
}
|
}
|
||||||
} else if (deltaScale < 0) {
|
} else if (deltaScale < 0) {
|
||||||
Decimal64 res = *pDec, max = {0};
|
Decimal64 res = *pDec, max = {0};
|
||||||
decimal64ScaleDown(&res, -deltaScale);
|
decimal64ScaleDown(&res, -deltaScale, false);
|
||||||
DECIMAL64_GET_MAX(toPrec, &max);
|
DECIMAL64_GET_MAX(toPrec, &max);
|
||||||
if (decimal64Gt(&res, &max, WORD_NUM(Decimal64))) {
|
if (decimal64Gt(&res, &max, WORD_NUM(Decimal64))) {
|
||||||
if (overflow) *overflow = true;
|
if (overflow) *overflow = true;
|
||||||
|
@ -1609,10 +1611,19 @@ int32_t decimal64FromStr(const char* str, int32_t len, uint8_t expectPrecision,
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO wjm add round param
|
// TODO wjm add round param
|
||||||
static void decimal128ScaleDown(Decimal128* pDec, uint8_t scaleDown) {
|
static void decimal128ScaleDown(Decimal128* pDec, uint8_t scaleDown, bool round) {
|
||||||
if (scaleDown > 0) {
|
if (scaleDown > 0) {
|
||||||
Decimal128 divisor = SCALE_MULTIPLIER_128[scaleDown];
|
Decimal128 divisor = SCALE_MULTIPLIER_128[scaleDown], remainder = {0};
|
||||||
decimal128Divide(pDec, &divisor, 2, NULL);
|
decimal128Divide(pDec, &divisor, 2, &remainder);
|
||||||
|
if (round) {
|
||||||
|
decimal128Abs(&remainder);
|
||||||
|
Decimal128 half = SCALE_MULTIPLIER_128[scaleDown];
|
||||||
|
decimal128Divide(&half, &decimal128Two, 2, NULL);
|
||||||
|
if (!decimal128Lt(&remainder, &half, WORD_NUM(Decimal128))) {
|
||||||
|
Decimal64 delta = {DECIMAL128_SIGN(pDec)};
|
||||||
|
decimal128Add(pDec, &delta, WORD_NUM(Decimal64));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1627,7 +1638,7 @@ static void decimal128ScaleTo(Decimal128* pDec, uint8_t oldScale, uint8_t newSca
|
||||||
if (newScale > oldScale)
|
if (newScale > oldScale)
|
||||||
decimal128ScaleUp(pDec, newScale - oldScale);
|
decimal128ScaleUp(pDec, newScale - oldScale);
|
||||||
else if (newScale < oldScale)
|
else if (newScale < oldScale)
|
||||||
decimal128ScaleDown(pDec, oldScale - newScale);
|
decimal128ScaleDown(pDec, oldScale - newScale, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t decimal128FromStr(const char* str, int32_t len, uint8_t expectPrecision, uint8_t expectScale,
|
int32_t decimal128FromStr(const char* str, int32_t len, uint8_t expectPrecision, uint8_t expectScale,
|
||||||
|
@ -1781,7 +1792,7 @@ static void decimal128ModifyScaleAndPrecision(Decimal128* pDec, uint8_t scale, u
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Decimal128 res = *pDec, max = {0};
|
Decimal128 res = *pDec, max = {0};
|
||||||
decimal128ScaleDown(&res, -deltaScale);
|
decimal128ScaleDown(&res, -deltaScale, false);
|
||||||
DECIMAL128_GET_MAX(toPrec, &max);
|
DECIMAL128_GET_MAX(toPrec, &max);
|
||||||
if (decimal128Gt(&res, &max, WORD_NUM(Decimal128))) {
|
if (decimal128Gt(&res, &max, WORD_NUM(Decimal128))) {
|
||||||
if (overflow) *overflow = true;
|
if (overflow) *overflow = true;
|
||||||
|
|
|
@ -192,14 +192,17 @@ class DecimalColumnExpr:
|
||||||
return f"({self.format_})".format(*self.params_)
|
return f"({self.format_})".format(*self.params_)
|
||||||
|
|
||||||
def execute(self, params):
|
def execute(self, params):
|
||||||
return self.executor_(params)
|
return self.executor_(self, params)
|
||||||
|
|
||||||
def get_val(self, tbname: str, idx: int):
|
def get_val(self, tbname: str, idx: int):
|
||||||
params: Tuple = ()
|
params: Tuple = ()
|
||||||
for p in self.params_:
|
for p in self.params_:
|
||||||
params = params + (p.get_val(tbname, idx),)
|
params = params + (p.get_val(tbname, idx),)
|
||||||
return self.execute(params)
|
return self.execute(params)
|
||||||
|
|
||||||
|
def get_input_types(self) -> List:
|
||||||
|
pass
|
||||||
|
|
||||||
def check(self, query_col_res: List, tbname: str):
|
def check(self, query_col_res: List, tbname: str):
|
||||||
for i in range(len(query_col_res)):
|
for i in range(len(query_col_res)):
|
||||||
v_from_query = query_col_res[i]
|
v_from_query = query_col_res[i]
|
||||||
|
@ -226,12 +229,13 @@ class DecimalColumnExpr:
|
||||||
failed = dec_from_query != dec_from_insert
|
failed = dec_from_query != dec_from_insert
|
||||||
if failed:
|
if failed:
|
||||||
tdLog.exit(
|
tdLog.exit(
|
||||||
f"check decimal column failed for expr: {self} params: {params}, query: {v_from_query}, expect {dec_from_insert}, but get {dec_from_query}")
|
f"check decimal column failed for expr: {self}, input: {[t.__str__() for t in self.get_input_types()]}, res_type: {self.res_type_}, params: {params}, query: {v_from_query}, expect {dec_from_insert}, but get {dec_from_query}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
tdLog.debug(
|
tdLog.debug(
|
||||||
f"check decimal succ for expr: {self}, params: {params}, insert:{v_from_calc_in_py} query:{v_from_query}, py dec: {dec_from_insert}"
|
f"check decimal succ for expr: {self}, input: {[t.__str__() for t in self.get_input_types()]}, res_type: {self.res_type_}, params: {params}, insert:{v_from_calc_in_py} query:{v_from_query}, py dec: {dec_from_insert}"
|
||||||
)
|
)
|
||||||
|
|
||||||
## format_params are already been set
|
## format_params are already been set
|
||||||
def generate_res_type(self):
|
def generate_res_type(self):
|
||||||
pass
|
pass
|
||||||
|
@ -298,7 +302,6 @@ class TypeEnum:
|
||||||
return type_str[type]
|
return type_str[type]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class DataType:
|
class DataType:
|
||||||
def __init__(self, type: int, length: int = 0, type_mod: int = 0):
|
def __init__(self, type: int, length: int = 0, type_mod: int = 0):
|
||||||
self.type: int = type
|
self.type: int = type
|
||||||
|
@ -336,7 +339,7 @@ class DataType:
|
||||||
## TODO generate NULL, None
|
## TODO generate NULL, None
|
||||||
def generate_value(self) -> str:
|
def generate_value(self) -> str:
|
||||||
if self.type == TypeEnum.BOOL:
|
if self.type == TypeEnum.BOOL:
|
||||||
return str(secrets.randbelow(2))
|
return ['true', 'false'][secrets.randbelow(2)]
|
||||||
if self.type == TypeEnum.TINYINT:
|
if self.type == TypeEnum.TINYINT:
|
||||||
return str(secrets.randbelow(256) - 128)
|
return str(secrets.randbelow(256) - 128)
|
||||||
if self.type == TypeEnum.SMALLINT:
|
if self.type == TypeEnum.SMALLINT:
|
||||||
|
@ -375,6 +378,11 @@ class DataType:
|
||||||
def get_typed_val(self, val):
|
def get_typed_val(self, val):
|
||||||
if self.type == TypeEnum.FLOAT or self.type == TypeEnum.DOUBLE:
|
if self.type == TypeEnum.FLOAT or self.type == TypeEnum.DOUBLE:
|
||||||
return float(val)
|
return float(val)
|
||||||
|
if self.type == TypeEnum.BOOL:
|
||||||
|
if val == "true":
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
return val
|
return val
|
||||||
|
|
||||||
class DecimalType(DataType):
|
class DecimalType(DataType):
|
||||||
|
@ -479,18 +487,18 @@ class Column:
|
||||||
|
|
||||||
def is_constant_col(self):
|
def is_constant_col(self):
|
||||||
return '' in self.saved_vals.keys()
|
return '' in self.saved_vals.keys()
|
||||||
|
|
||||||
def get_typed_val(self, val):
|
def get_typed_val(self, val):
|
||||||
return self.type_.get_typed_val(val)
|
return self.type_.get_typed_val(val)
|
||||||
|
|
||||||
def get_constant_val(self):
|
def get_constant_val(self):
|
||||||
return self.get_typed_val(self.saved_vals[''][0])
|
return self.get_typed_val(self.saved_vals[''][0])
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
if self.is_constant_col():
|
if self.is_constant_col():
|
||||||
return str(self.get_constant_val())
|
return str(self.get_constant_val())
|
||||||
return self.name_
|
return self.name_
|
||||||
|
|
||||||
def get_val(self, tbname: str, idx: int):
|
def get_val(self, tbname: str, idx: int):
|
||||||
if self.is_constant_col():
|
if self.is_constant_col():
|
||||||
return self.get_constant_val()
|
return self.get_constant_val()
|
||||||
|
@ -513,7 +521,7 @@ class Column:
|
||||||
|
|
||||||
def check(self, values, offset: int):
|
def check(self, values, offset: int):
|
||||||
return self.type_.check(values, offset)
|
return self.type_.check(values, offset)
|
||||||
|
|
||||||
def construct_type_value(self, val: str):
|
def construct_type_value(self, val: str):
|
||||||
if (
|
if (
|
||||||
self.type_.type == TypeEnum.BINARY
|
self.type_.type == TypeEnum.BINARY
|
||||||
|
@ -525,7 +533,7 @@ class Column:
|
||||||
return f"'{val}'"
|
return f"'{val}'"
|
||||||
else:
|
else:
|
||||||
return val
|
return val
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_decimal_unsupported_types() -> list:
|
def get_decimal_unsupported_types() -> list:
|
||||||
return [
|
return [
|
||||||
|
@ -533,11 +541,23 @@ class Column:
|
||||||
TypeEnum.GEOMETRY,
|
TypeEnum.GEOMETRY,
|
||||||
TypeEnum.VARBINARY,
|
TypeEnum.VARBINARY,
|
||||||
]
|
]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_decimal_oper_const_cols() -> list:
|
def get_decimal_oper_const_cols() -> list:
|
||||||
return Column.get_all_type_columns(Column.get_decimal_unsupported_types() + Column.get_decimal_types())
|
types_unable_to_be_const = [
|
||||||
|
TypeEnum.TINYINT,
|
||||||
|
TypeEnum.SMALLINT,
|
||||||
|
TypeEnum.INT,
|
||||||
|
TypeEnum.UINT,
|
||||||
|
TypeEnum.USMALLINT,
|
||||||
|
TypeEnum.UTINYINT,
|
||||||
|
]
|
||||||
|
return Column.get_all_type_columns(
|
||||||
|
Column.get_decimal_unsupported_types()
|
||||||
|
+ Column.get_decimal_types()
|
||||||
|
+ types_unable_to_be_const
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_decimal_types() -> List:
|
def get_decimal_types() -> List:
|
||||||
return [TypeEnum.DECIMAL, TypeEnum.DECIMAL64]
|
return [TypeEnum.DECIMAL, TypeEnum.DECIMAL64]
|
||||||
|
@ -700,6 +720,8 @@ class DecimalBinaryOperator(DecimalColumnExpr):
|
||||||
def __init__(self, format, executor, op: str):
|
def __init__(self, format, executor, op: str):
|
||||||
super().__init__(format, executor)
|
super().__init__(format, executor)
|
||||||
self.op_ = op
|
self.op_ = op
|
||||||
|
self.left_type_: DataType = None
|
||||||
|
self.right_type_: DataType = None
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return super().__str__()
|
return super().__str__()
|
||||||
|
@ -727,13 +749,13 @@ class DecimalBinaryOperator(DecimalColumnExpr):
|
||||||
if not left.is_decimal_type():
|
if not left.is_decimal_type():
|
||||||
left_prec = TypeEnum.get_type_prec(left.type)
|
left_prec = TypeEnum.get_type_prec(left.type)
|
||||||
else:
|
else:
|
||||||
left_prec = DecimalType(left).prec()
|
left_prec = left.prec()
|
||||||
left_scale = DecimalType(left).scale()
|
left_scale = left.scale()
|
||||||
if not right.is_decimal_type():
|
if not right.is_decimal_type():
|
||||||
right_prec = TypeEnum.get_type_prec(right.type)
|
right_prec = TypeEnum.get_type_prec(right.type)
|
||||||
else:
|
else:
|
||||||
right_prec = DecimalType(right).prec()
|
right_prec = right.prec()
|
||||||
right_scale = DecimalType(right).scale()
|
right_scale = right.scale()
|
||||||
|
|
||||||
out_prec = 0
|
out_prec = 0
|
||||||
out_scale = 0
|
out_scale = 0
|
||||||
|
@ -762,12 +784,22 @@ class DecimalBinaryOperator(DecimalColumnExpr):
|
||||||
def generate_res_type(self):
|
def generate_res_type(self):
|
||||||
if DecimalBinaryOperator.is_compare_op(self.op_):
|
if DecimalBinaryOperator.is_compare_op(self.op_):
|
||||||
self.res_type_ = DataType(TypeEnum.BOOL)
|
self.res_type_ = DataType(TypeEnum.BOOL)
|
||||||
|
return
|
||||||
left_type = self.params_[0].type_
|
left_type = self.params_[0].type_
|
||||||
right_type = self.params_[1].type_
|
right_type = self.params_[1].type_
|
||||||
|
self.left_type_ = left_type
|
||||||
|
self.right_type_ = right_type
|
||||||
ret_double_types = [TypeEnum.VARCHAR, TypeEnum.BINARY, TypeEnum.DOUBLE, TypeEnum.FLOAT]
|
ret_double_types = [TypeEnum.VARCHAR, TypeEnum.BINARY, TypeEnum.DOUBLE, TypeEnum.FLOAT]
|
||||||
if left_type.type in ret_double_types or right_type.type in ret_double_types:
|
if left_type.type in ret_double_types or right_type.type in ret_double_types:
|
||||||
self.res_type_ = DataType(TypeEnum.DOUBLE)
|
self.res_type_ = DataType(TypeEnum.DOUBLE)
|
||||||
self.res_type_ = DecimalBinaryOperator.calc_decimal_prec_scale(left_type, right_type, self.op_)
|
else:
|
||||||
|
self.res_type_ = DecimalBinaryOperator.calc_decimal_prec_scale(left_type, right_type, self.op_)
|
||||||
|
|
||||||
|
def get_input_types(self)-> list:
|
||||||
|
return [self.left_type_, self.right_type_]
|
||||||
|
|
||||||
|
def convert_to_res_type(self, val: Decimal) -> Decimal:
|
||||||
|
return val.quantize(Decimal("1." + "0" * self.res_type_.scale()), ROUND_HALF_UP)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_ret_type(params) -> Tuple:
|
def get_ret_type(params) -> Tuple:
|
||||||
|
@ -784,49 +816,50 @@ class DecimalBinaryOperator(DecimalColumnExpr):
|
||||||
ret_float = True
|
ret_float = True
|
||||||
return (left, right), ret_float
|
return (left, right), ret_float
|
||||||
|
|
||||||
@staticmethod
|
def execute_plus(self, params):
|
||||||
def execute_plus(params):
|
|
||||||
if DecimalBinaryOperator.check_null(params):
|
if DecimalBinaryOperator.check_null(params):
|
||||||
return 'NULL'
|
return 'NULL'
|
||||||
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
|
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
|
||||||
if ret_float:
|
if ret_float:
|
||||||
return float(left) + float(right)
|
return float(left) + float(right)
|
||||||
else:
|
else:
|
||||||
return Decimal(left) + Decimal(right)
|
return self.convert_to_res_type(Decimal(left) + Decimal(right))
|
||||||
|
|
||||||
@staticmethod
|
def execute_minus(self, params):
|
||||||
def execute_minus(params):
|
|
||||||
if DecimalBinaryOperator.check_null(params):
|
if DecimalBinaryOperator.check_null(params):
|
||||||
return 'NULL'
|
return 'NULL'
|
||||||
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
|
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
|
||||||
if ret_float:
|
if ret_float:
|
||||||
return float(left) - float(right)
|
return float(left) - float(right)
|
||||||
else:
|
else:
|
||||||
return Decimal(left) - Decimal(right)
|
return self.convert_to_res_type(Decimal(left) - Decimal(right))
|
||||||
|
|
||||||
@staticmethod
|
def execute_mul(self, params):
|
||||||
def execute_mul(params):
|
|
||||||
if DecimalBinaryOperator.check_null(params):
|
if DecimalBinaryOperator.check_null(params):
|
||||||
return 'NULL'
|
return 'NULL'
|
||||||
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
|
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
|
||||||
if ret_float:
|
if ret_float:
|
||||||
return float(left) * float(right)
|
return float(left) * float(right)
|
||||||
else:
|
else:
|
||||||
return Decimal(left) * Decimal(right)
|
return self.convert_to_res_type(Decimal(left) * Decimal(right))
|
||||||
|
|
||||||
@staticmethod
|
def execute_div(self, params):
|
||||||
def execute_div(params):
|
|
||||||
if DecimalBinaryOperator.check_null(params):
|
if DecimalBinaryOperator.check_null(params):
|
||||||
return 'NULL'
|
return 'NULL'
|
||||||
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
|
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
|
||||||
if ret_float:
|
if ret_float:
|
||||||
return float(left) / float(right)
|
return float(left) / float(right)
|
||||||
else:
|
else:
|
||||||
return Decimal(left) / Decimal(right)
|
return self.convert_to_res_type(Decimal(left) / Decimal(right))
|
||||||
|
|
||||||
@staticmethod
|
def execute_mod(self, params):
|
||||||
def execute_mod(params):
|
if DecimalBinaryOperator.check_null(params):
|
||||||
return params[0] % params[1]
|
return 'NULL'
|
||||||
|
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
|
||||||
|
if ret_float:
|
||||||
|
return float(left) % float(right)
|
||||||
|
else:
|
||||||
|
return self.convert_to_res_type(Decimal(left) % Decimal(right))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def execute_eq(params):
|
def execute_eq(params):
|
||||||
|
|
Loading…
Reference in New Issue