decimal unary operator test

This commit is contained in:
wangjiaming0909 2025-03-03 08:53:27 +08:00
parent cd9942e5c0
commit b36af6ca59
4 changed files with 158 additions and 50 deletions

View File

@ -1069,18 +1069,19 @@ int32_t decimalOp(EOperatorType op, const SDataType* pLeftT, const SDataType* pR
SDataType rt = {.type = TSDB_DATA_TYPE_DECIMAL,
.precision = TSDB_DECIMAL_MAX_PRECISION,
.bytes = tDataTypes[TSDB_DATA_TYPE_DECIMAL].bytes,
.scale = pRightT->scale};
.scale = 0};
if (pRightT) rt.scale = pRightT->scale;
if (TSDB_DATA_TYPE_DECIMAL != pLeftT->type) {
code = convertToDecimal(pLeftData, pLeftT, &left, &lt);
if (TSDB_CODE_SUCCESS != code) return code; // TODO add some logs here
} else {
left = *(Decimal*)pLeftData;
}
if (TSDB_DATA_TYPE_DECIMAL != pRightT->type) {
if (pRightT && TSDB_DATA_TYPE_DECIMAL != pRightT->type) {
code = convertToDecimal(pRightData, pRightT, &right, &rt);
if (TSDB_CODE_SUCCESS != code) return code;
pRightData = &right;
} else {
} else if (pRightData){
right = *(Decimal*)pRightData;
}
#ifdef DEBUG
@ -1107,6 +1108,9 @@ int32_t decimalOp(EOperatorType op, const SDataType* pLeftT, const SDataType* pR
case OP_TYPE_REM:
code = decimalMod(&left, &lt, &right, &rt, pOutT);
break;
case OP_TYPE_MINUS:
decimal128Negate(&left);
break;
default:
code = TSDB_CODE_TSC_INVALID_OPERATION;
break;
@ -1239,7 +1243,7 @@ static int32_t decimal64FromDecimal128(DecimalType* pDec, uint8_t prec, uint8_t
Decimal64 max = {0};
DECIMAL64_GET_MAX(prec - scale, &max);
decimal128ScaleTo(&dec128, valScale, 0);
decimal128ScaleDown(&dec128, valScale, false);
if (decimal128Gt(&dec128, &max, WORD_NUM(Decimal64))) {
return TSDB_CODE_DECIMAL_OVERFLOW;
}
@ -1257,7 +1261,7 @@ static int32_t decimal64FromDecimal64(DecimalType* pDec, uint8_t prec, uint8_t s
*(Decimal64*)pDec = dec64;
DECIMAL64_GET_MAX(prec - scale, &max);
decimal64ScaleTo(&dec64, valScale, 0);
decimal64ScaleDown(&dec64, valScale, false);
if (decimal64Lt(&max, &dec64, WORD_NUM(Decimal64))) {
return TSDB_CODE_DECIMAL_OVERFLOW;
}
@ -1529,7 +1533,9 @@ static void decimal64ScaleAndCheckOverflow(Decimal64* pDec, int8_t scale, uint8_
Decimal64 res = *pDec, max = {0};
decimal64ScaleDown(&res, -deltaScale, false);
DECIMAL64_GET_MAX(toPrec, &max);
if (decimal64Gt(&res, &max, WORD_NUM(Decimal64))) {
Decimal64 abs = res;
decimal64Abs(&abs);
if (decimal64Gt(&abs, &max, WORD_NUM(Decimal64))) {
if (overflow) *overflow = true;
} else {
*pDec = res;

View File

@ -1726,11 +1726,17 @@ _return:
}
static int32_t sclGetMinusOperatorResType(SOperatorNode *pOp) {
if (!IS_MATHABLE_TYPE(((SExprNode *)(pOp->pLeft))->resType.type)) {
const SDataType* pDt = &((SExprNode*)(pOp->pLeft))->resType;
if (!IS_MATHABLE_TYPE(pDt->type)) {
return TSDB_CODE_TSC_INVALID_OPERATION;
}
pOp->node.resType.type = TSDB_DATA_TYPE_DOUBLE;
pOp->node.resType.bytes = tDataTypes[TSDB_DATA_TYPE_DOUBLE].bytes;
if (IS_DECIMAL_TYPE(pDt->type)) {
pOp->node.resType = *pDt;
} else {
pOp->node.resType.type = TSDB_DATA_TYPE_DOUBLE;
pOp->node.resType.bytes = tDataTypes[TSDB_DATA_TYPE_DOUBLE].bytes;
}
return TSDB_CODE_SUCCESS;
}

View File

@ -49,7 +49,11 @@ bool compareForType(__compar_fn_t fp, int32_t optr, SColumnInfoData* pColL, int3
bool compareForTypeWithColAndHash(__compar_fn_t fp, int32_t optr, SColumnInfoData *pColL, int32_t idxL,
const void *hashData, int32_t hashType, STypeMod hashTypeMod);
static int32_t vectorMathOpForDecimal(SScalarParam *pLeft, SScalarParam *pRight, SScalarParam *pOut, int32_t step, int32_t i, EOperatorType op);
static int32_t vectorMathBinaryOpForDecimal(SScalarParam *pLeft, SScalarParam *pRight, SScalarParam *pOut, int32_t step,
int32_t i, EOperatorType op);
static int32_t vectorMathUnaryOpForDecimal(SScalarParam *pCol, SScalarParam *pOut, int32_t step, int32_t i,
EOperatorType op);
int32_t convertNumberToNumber(const void *inData, void *outData, int8_t inType, int8_t outType) {
switch (outType) {
@ -1340,7 +1344,7 @@ int32_t vectorMathAdd(SScalarParam *pLeft, SScalarParam *pRight, SScalarParam *p
SCL_ERR_JRET(vectorMathAddHelper(pLeftCol, pRightCol, pOutputCol, pLeft->numOfRows, step, i));
}
} else if (IS_DECIMAL_TYPE(pOutputCol->info.type)) {
SCL_ERR_JRET(vectorMathOpForDecimal(pLeft, pRight, pOut, step, i, OP_TYPE_ADD));
SCL_ERR_JRET(vectorMathBinaryOpForDecimal(pLeft, pRight, pOut, step, i, OP_TYPE_ADD));
}
_return:
@ -1473,7 +1477,7 @@ int32_t vectorMathSub(SScalarParam *pLeft, SScalarParam *pRight, SScalarParam *p
SCL_ERR_JRET(vectorMathSubHelper(pLeftCol, pRightCol, pOutputCol, pLeft->numOfRows, step, 1, i));
}
} else if (pOutputCol->info.type == TSDB_DATA_TYPE_DECIMAL) {
SCL_ERR_JRET(vectorMathOpForDecimal(pLeft, pRight, pOut, step, i, OP_TYPE_SUB));
SCL_ERR_JRET(vectorMathBinaryOpForDecimal(pLeft, pRight, pOut, step, i, OP_TYPE_SUB));
}
_return:
@ -1522,7 +1526,7 @@ int32_t vectorMathMultiply(SScalarParam *pLeft, SScalarParam *pRight, SScalarPar
SColumnInfoData *pLeftCol = NULL;
SColumnInfoData *pRightCol = NULL;
if (pOutputCol->info.type == TSDB_DATA_TYPE_DECIMAL) {
SCL_ERR_JRET(vectorMathOpForDecimal(pLeft, pRight, pOut, step, i, OP_TYPE_MULTI));
SCL_ERR_JRET(vectorMathBinaryOpForDecimal(pLeft, pRight, pOut, step, i, OP_TYPE_MULTI));
} else {
SCL_ERR_JRET(vectorConvertVarToDouble(pLeft, &leftConvert, &pLeftCol));
SCL_ERR_JRET(vectorConvertVarToDouble(pRight, &rightConvert, &pRightCol));
@ -1570,7 +1574,7 @@ int32_t vectorMathDivide(SScalarParam *pLeft, SScalarParam *pRight, SScalarParam
SColumnInfoData *pLeftCol = NULL;
SColumnInfoData *pRightCol = NULL;
if (pOutputCol->info.type == TSDB_DATA_TYPE_DECIMAL) {
SCL_ERR_JRET(vectorMathOpForDecimal(pLeft, pRight, pOut, step, i, OP_TYPE_DIV));
SCL_ERR_JRET(vectorMathBinaryOpForDecimal(pLeft, pRight, pOut, step, i, OP_TYPE_DIV));
} else {
SCL_ERR_JRET(vectorConvertVarToDouble(pLeft, &leftConvert, &pLeftCol));
SCL_ERR_JRET(vectorConvertVarToDouble(pRight, &rightConvert, &pRightCol));
@ -1658,7 +1662,7 @@ int32_t vectorMathRemainder(SScalarParam *pLeft, SScalarParam *pRight, SScalarPa
SColumnInfoData *pLeftCol = NULL;
SColumnInfoData *pRightCol = NULL;
if (pOutputCol->info.type == TSDB_DATA_TYPE_DECIMAL) {
SCL_ERR_JRET(vectorMathOpForDecimal(pLeft, pRight, pOut, step, i, OP_TYPE_REM));
SCL_ERR_JRET(vectorMathBinaryOpForDecimal(pLeft, pRight, pOut, step, i, OP_TYPE_REM));
} else {
SCL_ERR_JRET(vectorConvertVarToDouble(pLeft, &leftConvert, &pLeftCol));
SCL_ERR_JRET(vectorConvertVarToDouble(pRight, &rightConvert, &pRightCol));
@ -1708,20 +1712,24 @@ int32_t vectorMathMinus(SScalarParam *pLeft, SScalarParam *pRight, SScalarParam
int32_t leftConvert = 0;
SColumnInfoData *pLeftCol = NULL;
SCL_ERR_JRET(vectorConvertVarToDouble(pLeft, &leftConvert, &pLeftCol));
if (IS_DECIMAL_TYPE(pOutputCol->info.type)) {
SCL_ERR_JRET(vectorMathUnaryOpForDecimal(pLeft, pOut, step, i, OP_TYPE_MINUS));
} else {
SCL_ERR_JRET(vectorConvertVarToDouble(pLeft, &leftConvert, &pLeftCol));
_getDoubleValue_fn_t getVectorDoubleValueFnLeft;
SCL_ERR_JRET(getVectorDoubleValueFn(pLeftCol->info.type, &getVectorDoubleValueFnLeft));
_getDoubleValue_fn_t getVectorDoubleValueFnLeft;
SCL_ERR_JRET(getVectorDoubleValueFn(pLeftCol->info.type, &getVectorDoubleValueFnLeft));
double *output = (double *)pOutputCol->pData;
for (; i < pLeft->numOfRows && i >= 0; i += step, output += 1) {
if (IS_HELPER_NULL(pLeftCol, i)) {
colDataSetNULL(pOutputCol, i);
continue;
double *output = (double *)pOutputCol->pData;
for (; i < pLeft->numOfRows && i >= 0; i += step, output += 1) {
if (IS_HELPER_NULL(pLeftCol, i)) {
colDataSetNULL(pOutputCol, i);
continue;
}
double result = 0;
SCL_ERR_JRET(getVectorDoubleValueFnLeft(LEFT_COL, i, &result));
*output = (result == 0) ? 0 : -result;
}
double result = 0;
SCL_ERR_JRET(getVectorDoubleValueFnLeft(LEFT_COL, i, &result));
*output = (result == 0) ? 0 : -result;
}
_return:
@ -2308,7 +2316,25 @@ static int32_t vectorMathOpOneRowForDecimal(SScalarParam *pLeft, SScalarParam *p
return code;
}
static int32_t vectorMathOpForDecimal(SScalarParam *pLeft, SScalarParam *pRight, SScalarParam *pOut, int32_t step, int32_t i, EOperatorType op) {
static int32_t vectorMathUnaryOpForDecimal(SScalarParam *pCol, SScalarParam *pOut, int32_t step, int32_t i,
EOperatorType op) {
int32_t code = 0;
SColumnInfoData *pOutputCol = pOut->columnData;
void *pDec = pOutputCol->pData;
for (; i < pCol->numOfRows && i >= 0; i += step, pDec += tDataTypes[pOutputCol->info.type].bytes) {
if (IS_HELPER_NULL(pCol->columnData, i)) {
colDataSetNULL(pOutputCol, i);
continue;
}
SDataType colDt = GET_COL_DATA_TYPE(pCol->columnData->info), outDt = GET_COL_DATA_TYPE(pOutputCol->info);
code = decimalOp(op, &colDt, NULL, &outDt, colDataGetData(pCol->columnData, i), NULL, pDec);
}
return code;
}
static int32_t vectorMathBinaryOpForDecimal(SScalarParam *pLeft, SScalarParam *pRight, SScalarParam *pOut, int32_t step,
int32_t i, EOperatorType op) {
Decimal *output = (Decimal *)pOut->columnData->pData;
int32_t code = 0;
SDataType leftType = GET_COL_DATA_TYPE(pLeft->columnData->info),

View File

@ -23,8 +23,10 @@ invalid_operation = -2147483136
scalar_convert_err = -2147470768
operator_test_round = 2
operator_test_round = 1
tb_insert_rows = 1000
binary_op_test = True
unary_op_test = True
class DecimalTypeGeneratorConfig:
def __init__(self):
@ -199,9 +201,18 @@ class DecimalColumnExpr:
params = params + (p.get_val(tbname, idx),)
return self.execute(params)
def convert_to_res_type(self, val: Decimal) -> Decimal:
if self.res_type_.is_decimal_type():
return val.quantize(Decimal("0." + "0" * self.res_type_.scale()), ROUND_HALF_UP)
elif self.res_type_.type == TypeEnum.DOUBLE:
return float(val)
def get_input_types(self) -> List:
pass
def should_skip_for_decimal(self, cols: list):
return False
def check(self, query_col_res: List, tbname: str):
for i in range(len(query_col_res)):
v_from_query = query_col_res[i]
@ -235,8 +246,8 @@ class DecimalColumnExpr:
f"check decimal column failed for expr: {self}, input: {[t.__str__() for t in self.get_input_types()]}, res_type: {self.res_type_}, params: {params}, query: {v_from_query}, expect {calc_res}, but get {query_res}"
)
else:
tdLog.debug(
f"check decimal succ for expr: {self}, input: {[t.__str__() for t in self.get_input_types()]}, res_type: {self.res_type_}, params: {params}, insert:{v_from_calc_in_py} query:{v_from_query}, py dec: {calc_res}"
tdLog.info(
f"op succ: {self}, in: {[t.__str__() for t in self.get_input_types()]}, res: {self.res_type_}, params: {params}, insert:{v_from_calc_in_py} query:{v_from_query}, py calc: {calc_res}"
)
## format_params are already been set
@ -533,7 +544,11 @@ class Column:
if save:
if tbName not in self.saved_vals:
self.saved_vals[tbName] = []
self.saved_vals[tbName].append(val)
## for constant columns, always replace the last val
if self.is_constant_col():
self.saved_vals[tbName] = [val]
else:
self.saved_vals[tbName].append(val)
return val
def get_type_str(self) -> str:
@ -753,12 +768,14 @@ class DecimalBinaryOperator(DecimalColumnExpr):
def generate(self, format_params) -> str:
return super().generate(format_params)
def should_skip_for_decimal(self, left_col: Column, right_col: Column):
def should_skip_for_decimal(self, cols: list):
left_col = cols[0]
right_col = cols[1]
if not left_col.type_.is_decimal_type() and not right_col.type_.is_decimal_type():
return True
if self.op_ != "%":
return False
## why skip decimal % float/double?? it's wrong now.
## why skip decimal % float/double? it's wrong now.
left_is_real = left_col.type_.is_real_type() or left_col.type_.is_varchar_type()
right_is_real = right_col.type_.is_real_type() or right_col.type_.is_varchar_type()
if left_is_real or right_is_real:
@ -833,12 +850,6 @@ class DecimalBinaryOperator(DecimalColumnExpr):
def get_input_types(self)-> list:
return [self.left_type_, self.right_type_]
def convert_to_res_type(self, val: Decimal) -> Decimal:
if self.res_type_.is_decimal_type():
return val.quantize(Decimal("0." + "0" * self.res_type_.scale()), ROUND_HALF_UP)
elif self.res_type_.type == TypeEnum.DOUBLE:
return float(val)
@staticmethod
def get_convert_type(params):
@ -972,8 +983,36 @@ class DecimalBinaryOperator(DecimalColumnExpr):
def execute(self, params):
return super().execute(params)
class DecimalUnaryOperator(DecimalColumnExpr):
def __init__(self, format, executor, op: str):
super().__init__(format, executor)
self.op_ = op
self.col_type_: DataType = None
def should_skip_for_decimal(self, cols: list):
col:Column = cols[0]
if not col.type_.is_decimal_type():
return True
return False
@staticmethod
def get_all_unary_ops() -> List[DecimalColumnExpr]:
return [
DecimalUnaryOperator(" -{0} ", DecimalUnaryOperator.execute_minus, "-"),
]
def get_input_types(self)-> list:
return [self.col_type_]
def generate_res_type(self):
self.res_type_ = self.col_type_ = self.params_[0].type_
def execute_minus(self, params) -> Decimal:
if params[0] is None:
return 'NULL'
return -Decimal(params[0])
class DecimalBinaryOperatorIn(DecimalBinaryOperator):
def __init__(self, op: str):
super().__init__(op)
@ -983,7 +1022,7 @@ class DecimalBinaryOperatorIn(DecimalBinaryOperator):
return left in right
if self.op_.lower() == "not in":
return left not in right
class TDTestCase:
updatecfgDict = {
@ -1341,6 +1380,8 @@ class TDTestCase:
constant_cols: List[Column],
exprs: List[DecimalColumnExpr],
):
if not binary_op_test:
return
for expr in exprs:
for col in tb_cols:
if col.name_ == '':
@ -1348,22 +1389,46 @@ class TDTestCase:
left_is_decimal = col.type_.is_decimal_type()
for const_col in constant_cols:
right_is_decimal = const_col.type_.is_decimal_type()
if expr.should_skip_for_decimal(col, const_col):
if expr.should_skip_for_decimal([col, const_col]):
continue
const_col.generate_value()
select_expr2 = expr.generate((const_col, col))
sql = f"select {select_expr2} from {dbname}.{tbname}"
res2 = TaosShell().query(sql)
select_expr = expr.generate((col, const_col))
sql = f"select {select_expr} from {dbname}.{tbname}"
res = TaosShell().query(sql)
if len(res) > 0:
if len(res) != len(res2):
tdLog.exit(
f"sql: {sql} got different row number for {select_expr} and {select_expr2}"
)
for c, c2 in zip(res, res2):
for t, t2 in zip(c, c2):
if t != t2:
tdLog.exit(
f"sql: {sql} got different result for {select_expr} and {select_expr2}, expect {t2}, but got {t}"
)
expr.check(res[0], tbname)
else:
tdLog.info(f"sql: {sql} got no output")
## query
## build expr, expr.generate(column) to generate sql expr
## pass this expr into DataValidator.
# When validating between query results and local values, pass the column data into the Expr, and invoke expr.execute
## get result
## check result
def check_decimal_unary_expr_results(self, dbname, tbname, tb_cols: List[Column], exprs: List[DecimalColumnExpr]):
if not unary_op_test:
return
for expr in exprs:
for col in tb_cols:
if col.name_ == '':
continue
if expr.should_skip_for_decimal([col]):
continue
select_expr = expr.generate([col])
sql = f"select {select_expr} from {dbname}.{tbname}"
res = TaosShell().query(sql)
if len(res) > 0:
expr.check(res[0], tbname)
else:
tdLog.info(f"sql: {sql} got no output")
## test others unsupported types operator with decimal
def test_decimal_unsupported_types(self):
@ -1415,6 +1480,7 @@ class TDTestCase:
## tables: meters, nt
## columns: c1, c2, c3, c4, c5, c7, c8, c9, c10, c99, c100
binary_operators = DecimalBinaryOperator.get_all_binary_ops()
binary_operators = binary_operators[0:1]
all_type_columns = Column.get_decimal_oper_const_cols()
## decimal operator with constants of all other types
@ -1427,9 +1493,13 @@ class TDTestCase:
binary_operators,
)
## decimal operator with columns of all other types
unary_operators = ["-"]
unary_operators = DecimalUnaryOperator.get_all_unary_ops()
self.check_decimal_unary_expr_results(
self.db_name,
self.norm_table_name,
self.norm_tb_columns,
unary_operators,
)
def test_decimal_functions(self):
self.test_decimal_last_first_func()