test decimal operators

This commit is contained in:
wangjiaming0909 2025-03-01 23:53:38 +08:00
parent 759c9af591
commit 66e1fddce5
5 changed files with 99 additions and 30 deletions

View File

@ -198,6 +198,7 @@ static int32_t decimalVarFromStr(const char* str, int32_t len, DecimalVar* resul
if (isdigit(str[pos2] || str[pos] == '.')) continue; if (isdigit(str[pos2] || str[pos] == '.')) continue;
if (str[pos2] == 'e' || str[pos2] == 'E') { if (str[pos2] == 'e' || str[pos2] == 'E') {
result->exponent = atoi(str + pos2 + 1); result->exponent = atoi(str + pos2 + 1);
break;
} }
pos2++; pos2++;
} }
@ -859,7 +860,7 @@ static void makeInt256FromDecimal128(Int256* pTarget, const Decimal128* pDec) {
UInt128 tmp = {DECIMAL128_LOW_WORD(&abs), DECIMAL128_HIGH_WORD(&abs)}; UInt128 tmp = {DECIMAL128_LOW_WORD(&abs), DECIMAL128_HIGH_WORD(&abs)};
*pTarget = makeInt256(int128Zero, tmp); *pTarget = makeInt256(int128Zero, tmp);
if (negative) { if (negative) {
int256Negate(pTarget); *pTarget = int256Negate(pTarget);
} }
} }
@ -876,8 +877,8 @@ static Int256 int256ScaleBy(const Int256* pX, int32_t scale) {
Int256 remainder = int256Mod(pX, &divisor); Int256 remainder = int256Mod(pX, &divisor);
Int256 afterShift = int256RightShift(&divisor, 1); Int256 afterShift = int256RightShift(&divisor, 1);
remainder = int256Abs(&remainder); remainder = int256Abs(&remainder);
if (int256Gt(&remainder, &afterShift)) { if (!int256Gt(&afterShift, &remainder)) {
if (int256Gt(&result, &int256Zero)) { if (int256Gt(pX, &int256Zero)) {
result = int256Add(&result, &int256One); result = int256Add(&result, &int256One);
} else { } else {
result = int256Subtract(&result, &int256One); result = int256Subtract(&result, &int256One);
@ -1079,6 +1080,11 @@ int32_t decimalOp(EOperatorType op, const SDataType* pLeftT, const SDataType* pR
} else { } else {
right = *(Decimal*)pRightData; right = *(Decimal*)pRightData;
} }
#ifdef DEBUG
char left_var[64] = {0}, right_var[64] = {0};
decimal128ToStr(&left, lt.scale, left_var, 64);
decimal128ToStr(&right, rt.scale, right_var, 64);
#endif
switch (op) { switch (op) {
case OP_TYPE_ADD: case OP_TYPE_ADD:
@ -1444,7 +1450,9 @@ static int32_t decimal128FromDecimal128(DecimalType* pDec, uint8_t prec, uint8_t
} break; \ } break; \
case TSDB_DATA_TYPE_VARCHAR: \ case TSDB_DATA_TYPE_VARCHAR: \
case TSDB_DATA_TYPE_VARBINARY: \ case TSDB_DATA_TYPE_VARBINARY: \
case TSDB_DATA_TYPE_NCHAR: \ case TSDB_DATA_TYPE_NCHAR: { \
code = decimal##FromStr(pData, pInputType->bytes, pOutType->precision, pOutType->scale, pOut); \
} break; \
default: \ default: \
code = TSDB_CODE_OPS_NOT_SUPPORT; \ code = TSDB_CODE_OPS_NOT_SUPPORT; \
break; \ break; \

View File

@ -214,12 +214,31 @@ Int256 int256Multiply(const Int256* pLeft, const Int256* pRight) {
return *(Int256*)&result; return *(Int256*)&result;
} }
Int256 int256Divide(const Int256* pLeft, const Int256* pRight) { Int256 int256Divide(const Int256* pLeft, const Int256* pRight) {
intx::uint256 result = *(intx::uint256*)pLeft / *(intx::uint256*)pRight; Int256 l = *pLeft, r = *pRight;
return *(Int256*)&result; bool leftNegative = int256Lt(pLeft, &int256Zero), rightNegative = int256Lt(pRight, &int256Zero);
if (leftNegative) {
l = int256Abs(pLeft);
}
if (rightNegative) {
r = int256Abs(pRight);
}
intx::uint256 result = *(intx::uint256*)&l / *(intx::uint256*)&r;
Int256 res = *(Int256*)&result;
if (leftNegative != rightNegative)
res = int256Negate(&res);
return res;
} }
Int256 int256Mod(const Int256* pLeft, const Int256* pRight) { Int256 int256Mod(const Int256* pLeft, const Int256* pRight) {
intx::uint256 result = *(intx::uint256*)pLeft % *(intx::uint256*)pRight; Int256 left = *pLeft;
return *(Int256*)&result; bool leftNegative = int256Lt(pLeft, &int256Zero);
if (leftNegative) {
left = int256Abs(&left);
}
intx::uint256 result = *(intx::uint256*)&left % *(intx::uint256*)pRight;
Int256 res = *(Int256*)&result;
if (leftNegative) res = int256Negate(&res);
return res;
} }
bool int256Lt(const Int256* pLeft, const Int256* pRight) { bool int256Lt(const Int256* pLeft, const Int256* pRight) {
Int128 hiLeft = int256Hi(pLeft), hiRight = int256Hi(pRight); Int128 hiLeft = int256Hi(pLeft), hiRight = int256Hi(pRight);

View File

@ -731,6 +731,11 @@ TEST(decimal, decimalOpRetType) {
tExpect.scale = 13; tExpect.scale = 13;
tExpect.bytes = sizeof(Decimal); tExpect.bytes = sizeof(Decimal);
code = decimalGetRetType(&ta, &tb, op, &tc); code = decimalGetRetType(&ta, &tb, op, &tc);
Numeric<64> aNum = {10, 2, "123.99"};
int64_t bInt64 = 317759474393305778;
auto res = aNum / bInt64;
ASSERT_EQ(res.scale(), 22);
} }
TEST(decimal, op) { TEST(decimal, op) {

View File

@ -4204,6 +4204,9 @@ static int32_t fltSclBuildDecimalDatumFromValueNode(SFltSclDatum* datum, SColumn
case TSDB_DATA_TYPE_DOUBLE: case TSDB_DATA_TYPE_DOUBLE:
pInput = &valNode->datum.d; pInput = &valNode->datum.d;
break; break;
case TSDB_DATA_TYPE_VARCHAR:
pInput = &valNode->datum.p;
break;
default: default:
qError("not supported type %d when build decimal datum from value node", valNode->node.resType.type); qError("not supported type %d when build decimal datum from value node", valNode->node.resType.type);
return TSDB_CODE_INVALID_PARA; return TSDB_CODE_INVALID_PARA;
@ -4219,9 +4222,11 @@ static int32_t fltSclBuildDecimalDatumFromValueNode(SFltSclDatum* datum, SColumn
datum->pData = pData; datum->pData = pData;
datum->kind = FLT_SCL_DATUM_KIND_DECIMAL; datum->kind = FLT_SCL_DATUM_KIND_DECIMAL;
} }
int32_t code = convertToDecimal(pInput, &valNode->node.resType, pData, &datum->type); if (datum->kind == FLT_SCL_DATUM_KIND_DECIMAL64 || datum->kind == FLT_SCL_DATUM_KIND_DECIMAL) {
if (TSDB_CODE_SUCCESS != code) return code; // TODO wjm handle overflow error int32_t code = convertToDecimal(pInput, &valNode->node.resType, pData, &datum->type);
valNode->node.resType = datum->type; if (TSDB_CODE_SUCCESS != code) return code; // TODO wjm handle overflow error
valNode->node.resType = datum->type;
}
} }
FLT_RET(0); FLT_RET(0);
} }

View File

@ -1,16 +1,12 @@
from ast import Tuple from ast import Tuple
import math import math
from pydoc import doc
from random import randrange from random import randrange
import random import random
from re import A
import time import time
import threading import threading
import secrets import secrets
from regex import D, F
import query import query
from tag_lite import column
from util.log import * from util.log import *
from util.sql import * from util.sql import *
from util.cases import * from util.cases import *
@ -209,7 +205,7 @@ class DecimalColumnExpr:
params: Tuple = () params: Tuple = ()
for p in self.params_: for p in self.params_:
if isinstance(p, Column) or isinstance(p, DecimalColumnExpr): if isinstance(p, Column) or isinstance(p, DecimalColumnExpr):
p = p.get_val(tbname, i) p = p.get_val_for_execute(tbname, i)
params = params + (p,) params = params + (p,)
v_from_calc_in_py = self.execute(params) v_from_calc_in_py = self.execute(params)
@ -222,7 +218,7 @@ class DecimalColumnExpr:
if isinstance(v_from_calc_in_py, float): if isinstance(v_from_calc_in_py, float):
dec_from_query = float(v_from_query) dec_from_query = float(v_from_query)
dec_from_insert = float(v_from_calc_in_py) dec_from_insert = float(v_from_calc_in_py)
failed = not math.isclose(dec_from_query, dec_from_insert, rel_tol=1e-9, abs_tol=1e-9) failed = not math.isclose(dec_from_query, dec_from_insert, abs_tol=1e-7)
else: else:
dec_from_query = Decimal(v_from_query) dec_from_query = Decimal(v_from_query)
dec_from_insert = Decimal(v_from_calc_in_py) dec_from_insert = Decimal(v_from_calc_in_py)
@ -329,6 +325,12 @@ class DataType:
def is_decimal_type(self): def is_decimal_type(self):
return self.type == TypeEnum.DECIMAL or self.type == TypeEnum.DECIMAL64 return self.type == TypeEnum.DECIMAL or self.type == TypeEnum.DECIMAL64
def is_varchar_type(self):
return self.type == TypeEnum.VARCHAR or self.type == TypeEnum.NCHAR or self.type == TypeEnum.VARBINARY or self.type == TypeEnum.JSON or self.type == TypeEnum.BINARY
def is_real_type(self):
return self.type == TypeEnum.FLOAT or self.type == TypeEnum.DOUBLE
def prec(self): def prec(self):
return 0 return 0
@ -375,7 +377,7 @@ class DataType:
def check(self, values, offset: int): def check(self, values, offset: int):
return True return True
def get_typed_val(self, val): def get_typed_val_for_execute(self, val):
if self.type == TypeEnum.FLOAT or self.type == TypeEnum.DOUBLE: if self.type == TypeEnum.FLOAT or self.type == TypeEnum.DOUBLE:
return float(val) return float(val)
if self.type == TypeEnum.BOOL: if self.type == TypeEnum.BOOL:
@ -384,6 +386,11 @@ class DataType:
else: else:
return 0 return 0
return val return val
def get_typed_val(self, val):
if self.type == TypeEnum.FLOAT or self.type == TypeEnum.DOUBLE:
return float(val)
return val
class DecimalType(DataType): class DecimalType(DataType):
def __init__(self, type, precision: int, scale: int): def __init__(self, type, precision: int, scale: int):
@ -435,6 +442,9 @@ class DecimalType(DataType):
if val == "NULL": if val == "NULL":
return None return None
return Decimal(val).quantize(Decimal("1." + "0" * self.scale()), ROUND_HALF_UP) return Decimal(val).quantize(Decimal("1." + "0" * self.scale()), ROUND_HALF_UP)
def get_typed_val_for_execute(self, val):
return self.get_typed_val(val)
@staticmethod @staticmethod
def default_compression() -> str: def default_compression() -> str:
@ -490,19 +500,25 @@ class Column:
def get_typed_val(self, val): def get_typed_val(self, val):
return self.type_.get_typed_val(val) return self.type_.get_typed_val(val)
def get_typed_val_for_execute(self, val):
return self.type_.get_typed_val_for_execute(val)
def get_constant_val(self): def get_constant_val(self):
return self.get_typed_val(self.saved_vals[''][0]) return self.get_typed_val(self.saved_vals[''][0])
def get_constant_val_for_execute(self):
return self.get_typed_val_for_execute(self.saved_vals[''][0])
def __str__(self): def __str__(self):
if self.is_constant_col(): if self.is_constant_col():
return str(self.get_constant_val()) return str(self.get_constant_val())
return self.name_ return self.name_
def get_val(self, tbname: str, idx: int): def get_val_for_execute(self, tbname: str, idx: int):
if self.is_constant_col(): if self.is_constant_col():
return self.get_constant_val() return self.get_constant_val_for_execute()
return self.get_typed_val(self.saved_vals[tbname][idx]) return self.get_typed_val_for_execute(self.saved_vals[tbname][idx])
## tbName: for normal table, pass the tbname, for child table, pass the child table name ## tbName: for normal table, pass the tbname, for child table, pass the child table name
def generate_value(self, tbName: str = '', save: bool = True): def generate_value(self, tbName: str = '', save: bool = True):
@ -551,6 +567,7 @@ class Column:
TypeEnum.UINT, TypeEnum.UINT,
TypeEnum.USMALLINT, TypeEnum.USMALLINT,
TypeEnum.UTINYINT, TypeEnum.UTINYINT,
TypeEnum.UBIGINT,
] ]
return Column.get_all_type_columns( return Column.get_all_type_columns(
Column.get_decimal_unsupported_types() Column.get_decimal_unsupported_types()
@ -728,6 +745,18 @@ class DecimalBinaryOperator(DecimalColumnExpr):
def generate(self, format_params: Tuple) -> str: def generate(self, format_params: Tuple) -> str:
return super().generate(format_params) return super().generate(format_params)
def should_skip_for_decimal(self, left_col: Column, right_col: Column):
if not left_col.type_.is_decimal_type() and not right_col.type_.is_decimal_type():
return True
if self.op_ != "%":
return False
## why skip decimal % float/double?? it's wrong now.
left_is_real = left_col.type_.is_real_type() or left_col.type_.is_varchar_type()
right_is_real = right_col.type_.is_real_type() or right_col.type_.is_varchar_type()
if left_is_real or right_is_real:
return True
return False
@staticmethod @staticmethod
def check_null(params): def check_null(params):
@ -789,7 +818,7 @@ class DecimalBinaryOperator(DecimalColumnExpr):
right_type = self.params_[1].type_ right_type = self.params_[1].type_
self.left_type_ = left_type self.left_type_ = left_type
self.right_type_ = right_type self.right_type_ = right_type
ret_double_types = [TypeEnum.VARCHAR, TypeEnum.BINARY, TypeEnum.DOUBLE, TypeEnum.FLOAT] ret_double_types = [TypeEnum.VARCHAR, TypeEnum.BINARY, TypeEnum.DOUBLE, TypeEnum.FLOAT, TypeEnum.NCHAR]
if left_type.type in ret_double_types or right_type.type in ret_double_types: if left_type.type in ret_double_types or right_type.type in ret_double_types:
self.res_type_ = DataType(TypeEnum.DOUBLE) self.res_type_ = DataType(TypeEnum.DOUBLE)
else: else:
@ -799,7 +828,10 @@ class DecimalBinaryOperator(DecimalColumnExpr):
return [self.left_type_, self.right_type_] return [self.left_type_, self.right_type_]
def convert_to_res_type(self, val: Decimal) -> Decimal: def convert_to_res_type(self, val: Decimal) -> Decimal:
return val.quantize(Decimal("1." + "0" * self.res_type_.scale()), ROUND_HALF_UP) if self.res_type_.is_decimal_type():
return val.quantize(Decimal("0." + "0" * self.res_type_.scale()), ROUND_HALF_UP)
elif self.res_type_.type == TypeEnum.DOUBLE:
return float(val)
@staticmethod @staticmethod
def get_ret_type(params) -> Tuple: def get_ret_type(params) -> Tuple:
@ -820,7 +852,7 @@ class DecimalBinaryOperator(DecimalColumnExpr):
if DecimalBinaryOperator.check_null(params): if DecimalBinaryOperator.check_null(params):
return 'NULL' return 'NULL'
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params) (left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
if ret_float: if self.res_type_.type == TypeEnum.DOUBLE:
return float(left) + float(right) return float(left) + float(right)
else: else:
return self.convert_to_res_type(Decimal(left) + Decimal(right)) return self.convert_to_res_type(Decimal(left) + Decimal(right))
@ -829,7 +861,7 @@ class DecimalBinaryOperator(DecimalColumnExpr):
if DecimalBinaryOperator.check_null(params): if DecimalBinaryOperator.check_null(params):
return 'NULL' return 'NULL'
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params) (left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
if ret_float: if self.res_type_.type == TypeEnum.DOUBLE:
return float(left) - float(right) return float(left) - float(right)
else: else:
return self.convert_to_res_type(Decimal(left) - Decimal(right)) return self.convert_to_res_type(Decimal(left) - Decimal(right))
@ -838,7 +870,7 @@ class DecimalBinaryOperator(DecimalColumnExpr):
if DecimalBinaryOperator.check_null(params): if DecimalBinaryOperator.check_null(params):
return 'NULL' return 'NULL'
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params) (left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
if ret_float: if self.res_type_.type == TypeEnum.DOUBLE:
return float(left) * float(right) return float(left) * float(right)
else: else:
return self.convert_to_res_type(Decimal(left) * Decimal(right)) return self.convert_to_res_type(Decimal(left) * Decimal(right))
@ -847,7 +879,7 @@ class DecimalBinaryOperator(DecimalColumnExpr):
if DecimalBinaryOperator.check_null(params): if DecimalBinaryOperator.check_null(params):
return 'NULL' return 'NULL'
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params) (left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
if ret_float: if self.res_type_.type == TypeEnum.DOUBLE:
return float(left) / float(right) return float(left) / float(right)
else: else:
return self.convert_to_res_type(Decimal(left) / Decimal(right)) return self.convert_to_res_type(Decimal(left) / Decimal(right))
@ -856,8 +888,8 @@ class DecimalBinaryOperator(DecimalColumnExpr):
if DecimalBinaryOperator.check_null(params): if DecimalBinaryOperator.check_null(params):
return 'NULL' return 'NULL'
(left, right), ret_float = DecimalBinaryOperator.get_ret_type(params) (left, right), ret_float = DecimalBinaryOperator.get_ret_type(params)
if ret_float: if self.res_type_.type == TypeEnum.DOUBLE:
return float(left) % float(right) return self.convert_to_res_type(Decimal(left) % Decimal(right))
else: else:
return self.convert_to_res_type(Decimal(left) % Decimal(right)) return self.convert_to_res_type(Decimal(left) % Decimal(right))
@ -1279,7 +1311,7 @@ class TDTestCase:
left_is_decimal = col.type_.is_decimal_type() left_is_decimal = col.type_.is_decimal_type()
for const_col in constant_cols: for const_col in constant_cols:
right_is_decimal = const_col.type_.is_decimal_type() right_is_decimal = const_col.type_.is_decimal_type()
if not left_is_decimal and not right_is_decimal: if expr.should_skip_for_decimal(col, const_col):
continue continue
const_col.generate_value() const_col.generate_value()
select_expr = expr.generate((col, const_col)) select_expr = expr.generate((col, const_col))