From b36af6ca5907cfe7073179cce7243b6e427833bc Mon Sep 17 00:00:00 2001 From: wangjiaming0909 Date: Mon, 3 Mar 2025 08:53:27 +0800 Subject: [PATCH] decimal unary operator test --- source/libs/decimal/src/decimal.c | 18 +++-- source/libs/scalar/src/scalar.c | 12 ++- source/libs/scalar/src/sclvector.c | 62 +++++++++----- tests/system-test/2-query/decimal.py | 116 +++++++++++++++++++++------ 4 files changed, 158 insertions(+), 50 deletions(-) diff --git a/source/libs/decimal/src/decimal.c b/source/libs/decimal/src/decimal.c index e8f9995df0..e3657ccbea 100644 --- a/source/libs/decimal/src/decimal.c +++ b/source/libs/decimal/src/decimal.c @@ -1069,18 +1069,19 @@ int32_t decimalOp(EOperatorType op, const SDataType* pLeftT, const SDataType* pR SDataType rt = {.type = TSDB_DATA_TYPE_DECIMAL, .precision = TSDB_DECIMAL_MAX_PRECISION, .bytes = tDataTypes[TSDB_DATA_TYPE_DECIMAL].bytes, - .scale = pRightT->scale}; + .scale = 0}; + if (pRightT) rt.scale = pRightT->scale; if (TSDB_DATA_TYPE_DECIMAL != pLeftT->type) { code = convertToDecimal(pLeftData, pLeftT, &left, &lt); if (TSDB_CODE_SUCCESS != code) return code; // TODO add some logs here } else { left = *(Decimal*)pLeftData; } - if (TSDB_DATA_TYPE_DECIMAL != pRightT->type) { + if (pRightT && TSDB_DATA_TYPE_DECIMAL != pRightT->type) { code = convertToDecimal(pRightData, pRightT, &right, &rt); if (TSDB_CODE_SUCCESS != code) return code; pRightData = &right; - } else { + } else if (pRightData){ right = *(Decimal*)pRightData; } #ifdef DEBUG @@ -1107,6 +1108,9 @@ int32_t decimalOp(EOperatorType op, const SDataType* pLeftT, const SDataType* pR case OP_TYPE_REM: code = decimalMod(&left, &lt, &right, &rt, pOutT); break; + case OP_TYPE_MINUS: + decimal128Negate(&left); + break; default: code = TSDB_CODE_TSC_INVALID_OPERATION; break; @@ -1239,7 +1243,7 @@ static int32_t decimal64FromDecimal128(DecimalType* pDec, uint8_t prec, uint8_t Decimal64 max = {0}; DECIMAL64_GET_MAX(prec - scale, &max); - decimal128ScaleTo(&dec128, valScale, 0); + 
decimal128ScaleDown(&dec128, valScale, false); if (decimal128Gt(&dec128, &max, WORD_NUM(Decimal64))) { return TSDB_CODE_DECIMAL_OVERFLOW; } @@ -1257,7 +1261,7 @@ static int32_t decimal64FromDecimal64(DecimalType* pDec, uint8_t prec, uint8_t s *(Decimal64*)pDec = dec64; DECIMAL64_GET_MAX(prec - scale, &max); - decimal64ScaleTo(&dec64, valScale, 0); + decimal64ScaleDown(&dec64, valScale, false); if (decimal64Lt(&max, &dec64, WORD_NUM(Decimal64))) { return TSDB_CODE_DECIMAL_OVERFLOW; } @@ -1529,7 +1533,9 @@ static void decimal64ScaleAndCheckOverflow(Decimal64* pDec, int8_t scale, uint8_ Decimal64 res = *pDec, max = {0}; decimal64ScaleDown(&res, -deltaScale, false); DECIMAL64_GET_MAX(toPrec, &max); - if (decimal64Gt(&res, &max, WORD_NUM(Decimal64))) { + Decimal64 abs = res; + decimal64Abs(&abs); + if (decimal64Gt(&abs, &max, WORD_NUM(Decimal64))) { if (overflow) *overflow = true; } else { *pDec = res; diff --git a/source/libs/scalar/src/scalar.c b/source/libs/scalar/src/scalar.c index f46376cd5d..08effe9724 100644 --- a/source/libs/scalar/src/scalar.c +++ b/source/libs/scalar/src/scalar.c @@ -1726,11 +1726,17 @@ _return: } static int32_t sclGetMinusOperatorResType(SOperatorNode *pOp) { - if (!IS_MATHABLE_TYPE(((SExprNode *)(pOp->pLeft))->resType.type)) { + const SDataType* pDt = &((SExprNode*)(pOp->pLeft))->resType; + if (!IS_MATHABLE_TYPE(pDt->type)) { return TSDB_CODE_TSC_INVALID_OPERATION; } - pOp->node.resType.type = TSDB_DATA_TYPE_DOUBLE; - pOp->node.resType.bytes = tDataTypes[TSDB_DATA_TYPE_DOUBLE].bytes; + + if (IS_DECIMAL_TYPE(pDt->type)) { + pOp->node.resType = *pDt; + } else { + pOp->node.resType.type = TSDB_DATA_TYPE_DOUBLE; + pOp->node.resType.bytes = tDataTypes[TSDB_DATA_TYPE_DOUBLE].bytes; + } return TSDB_CODE_SUCCESS; } diff --git a/source/libs/scalar/src/sclvector.c b/source/libs/scalar/src/sclvector.c index f0773370e0..0744895424 100644 --- a/source/libs/scalar/src/sclvector.c +++ b/source/libs/scalar/src/sclvector.c @@ -49,7 +49,11 @@ bool 
compareForType(__compar_fn_t fp, int32_t optr, SColumnInfoData* pColL, int3 bool compareForTypeWithColAndHash(__compar_fn_t fp, int32_t optr, SColumnInfoData *pColL, int32_t idxL, const void *hashData, int32_t hashType, STypeMod hashTypeMod); -static int32_t vectorMathOpForDecimal(SScalarParam *pLeft, SScalarParam *pRight, SScalarParam *pOut, int32_t step, int32_t i, EOperatorType op); +static int32_t vectorMathBinaryOpForDecimal(SScalarParam *pLeft, SScalarParam *pRight, SScalarParam *pOut, int32_t step, + int32_t i, EOperatorType op); + +static int32_t vectorMathUnaryOpForDecimal(SScalarParam *pCol, SScalarParam *pOut, int32_t step, int32_t i, + EOperatorType op); int32_t convertNumberToNumber(const void *inData, void *outData, int8_t inType, int8_t outType) { switch (outType) { @@ -1340,7 +1344,7 @@ int32_t vectorMathAdd(SScalarParam *pLeft, SScalarParam *pRight, SScalarParam *p SCL_ERR_JRET(vectorMathAddHelper(pLeftCol, pRightCol, pOutputCol, pLeft->numOfRows, step, i)); } } else if (IS_DECIMAL_TYPE(pOutputCol->info.type)) { - SCL_ERR_JRET(vectorMathOpForDecimal(pLeft, pRight, pOut, step, i, OP_TYPE_ADD)); + SCL_ERR_JRET(vectorMathBinaryOpForDecimal(pLeft, pRight, pOut, step, i, OP_TYPE_ADD)); } _return: @@ -1473,7 +1477,7 @@ int32_t vectorMathSub(SScalarParam *pLeft, SScalarParam *pRight, SScalarParam *p SCL_ERR_JRET(vectorMathSubHelper(pLeftCol, pRightCol, pOutputCol, pLeft->numOfRows, step, 1, i)); } } else if (pOutputCol->info.type == TSDB_DATA_TYPE_DECIMAL) { - SCL_ERR_JRET(vectorMathOpForDecimal(pLeft, pRight, pOut, step, i, OP_TYPE_SUB)); + SCL_ERR_JRET(vectorMathBinaryOpForDecimal(pLeft, pRight, pOut, step, i, OP_TYPE_SUB)); } _return: @@ -1522,7 +1526,7 @@ int32_t vectorMathMultiply(SScalarParam *pLeft, SScalarParam *pRight, SScalarPar SColumnInfoData *pLeftCol = NULL; SColumnInfoData *pRightCol = NULL; if (pOutputCol->info.type == TSDB_DATA_TYPE_DECIMAL) { - SCL_ERR_JRET(vectorMathOpForDecimal(pLeft, pRight, pOut, step, i, OP_TYPE_MULTI)); + 
SCL_ERR_JRET(vectorMathBinaryOpForDecimal(pLeft, pRight, pOut, step, i, OP_TYPE_MULTI)); } else { SCL_ERR_JRET(vectorConvertVarToDouble(pLeft, &leftConvert, &pLeftCol)); SCL_ERR_JRET(vectorConvertVarToDouble(pRight, &rightConvert, &pRightCol)); @@ -1570,7 +1574,7 @@ int32_t vectorMathDivide(SScalarParam *pLeft, SScalarParam *pRight, SScalarParam SColumnInfoData *pLeftCol = NULL; SColumnInfoData *pRightCol = NULL; if (pOutputCol->info.type == TSDB_DATA_TYPE_DECIMAL) { - SCL_ERR_JRET(vectorMathOpForDecimal(pLeft, pRight, pOut, step, i, OP_TYPE_DIV)); + SCL_ERR_JRET(vectorMathBinaryOpForDecimal(pLeft, pRight, pOut, step, i, OP_TYPE_DIV)); } else { SCL_ERR_JRET(vectorConvertVarToDouble(pLeft, &leftConvert, &pLeftCol)); SCL_ERR_JRET(vectorConvertVarToDouble(pRight, &rightConvert, &pRightCol)); @@ -1658,7 +1662,7 @@ int32_t vectorMathRemainder(SScalarParam *pLeft, SScalarParam *pRight, SScalarPa SColumnInfoData *pLeftCol = NULL; SColumnInfoData *pRightCol = NULL; if (pOutputCol->info.type == TSDB_DATA_TYPE_DECIMAL) { - SCL_ERR_JRET(vectorMathOpForDecimal(pLeft, pRight, pOut, step, i, OP_TYPE_REM)); + SCL_ERR_JRET(vectorMathBinaryOpForDecimal(pLeft, pRight, pOut, step, i, OP_TYPE_REM)); } else { SCL_ERR_JRET(vectorConvertVarToDouble(pLeft, &leftConvert, &pLeftCol)); SCL_ERR_JRET(vectorConvertVarToDouble(pRight, &rightConvert, &pRightCol)); @@ -1708,20 +1712,24 @@ int32_t vectorMathMinus(SScalarParam *pLeft, SScalarParam *pRight, SScalarParam int32_t leftConvert = 0; SColumnInfoData *pLeftCol = NULL; - SCL_ERR_JRET(vectorConvertVarToDouble(pLeft, &leftConvert, &pLeftCol)); + if (IS_DECIMAL_TYPE(pOutputCol->info.type)) { + SCL_ERR_JRET(vectorMathUnaryOpForDecimal(pLeft, pOut, step, i, OP_TYPE_MINUS)); + } else { + SCL_ERR_JRET(vectorConvertVarToDouble(pLeft, &leftConvert, &pLeftCol)); - _getDoubleValue_fn_t getVectorDoubleValueFnLeft; - SCL_ERR_JRET(getVectorDoubleValueFn(pLeftCol->info.type, &getVectorDoubleValueFnLeft)); + _getDoubleValue_fn_t getVectorDoubleValueFnLeft; 
+ SCL_ERR_JRET(getVectorDoubleValueFn(pLeftCol->info.type, &getVectorDoubleValueFnLeft)); - double *output = (double *)pOutputCol->pData; - for (; i < pLeft->numOfRows && i >= 0; i += step, output += 1) { - if (IS_HELPER_NULL(pLeftCol, i)) { - colDataSetNULL(pOutputCol, i); - continue; + double *output = (double *)pOutputCol->pData; + for (; i < pLeft->numOfRows && i >= 0; i += step, output += 1) { + if (IS_HELPER_NULL(pLeftCol, i)) { + colDataSetNULL(pOutputCol, i); + continue; + } + double result = 0; + SCL_ERR_JRET(getVectorDoubleValueFnLeft(LEFT_COL, i, &result)); + *output = (result == 0) ? 0 : -result; } - double result = 0; - SCL_ERR_JRET(getVectorDoubleValueFnLeft(LEFT_COL, i, &result)); - *output = (result == 0) ? 0 : -result; } _return: @@ -2308,7 +2316,25 @@ static int32_t vectorMathOpOneRowForDecimal(SScalarParam *pLeft, SScalarParam *p return code; } -static int32_t vectorMathOpForDecimal(SScalarParam *pLeft, SScalarParam *pRight, SScalarParam *pOut, int32_t step, int32_t i, EOperatorType op) { +static int32_t vectorMathUnaryOpForDecimal(SScalarParam *pCol, SScalarParam *pOut, int32_t step, int32_t i, + EOperatorType op) { + int32_t code = 0; + SColumnInfoData *pOutputCol = pOut->columnData; + void *pDec = pOutputCol->pData; + for (; i < pCol->numOfRows && i >= 0; i += step, pDec += tDataTypes[pOutputCol->info.type].bytes) { + if (IS_HELPER_NULL(pCol->columnData, i)) { + colDataSetNULL(pOutputCol, i); + continue; + } + SDataType colDt = GET_COL_DATA_TYPE(pCol->columnData->info), outDt = GET_COL_DATA_TYPE(pOutputCol->info); + + code = decimalOp(op, &colDt, NULL, &outDt, colDataGetData(pCol->columnData, i), NULL, pDec); + } + return code; +} + +static int32_t vectorMathBinaryOpForDecimal(SScalarParam *pLeft, SScalarParam *pRight, SScalarParam *pOut, int32_t step, + int32_t i, EOperatorType op) { Decimal *output = (Decimal *)pOut->columnData->pData; int32_t code = 0; SDataType leftType = GET_COL_DATA_TYPE(pLeft->columnData->info), diff --git 
a/tests/system-test/2-query/decimal.py b/tests/system-test/2-query/decimal.py index 73056e8028..7dfdb07ff6 100644 --- a/tests/system-test/2-query/decimal.py +++ b/tests/system-test/2-query/decimal.py @@ -23,8 +23,10 @@ invalid_operation = -2147483136 scalar_convert_err = -2147470768 -operator_test_round = 2 +operator_test_round = 1 tb_insert_rows = 1000 +binary_op_test = True +unary_op_test = True class DecimalTypeGeneratorConfig: def __init__(self): @@ -199,9 +201,18 @@ class DecimalColumnExpr: params = params + (p.get_val(tbname, idx),) return self.execute(params) + def convert_to_res_type(self, val: Decimal) -> Decimal: + if self.res_type_.is_decimal_type(): + return val.quantize(Decimal("0." + "0" * self.res_type_.scale()), ROUND_HALF_UP) + elif self.res_type_.type == TypeEnum.DOUBLE: + return float(val) + def get_input_types(self) -> List: pass + def should_skip_for_decimal(self, cols: list): + return False + def check(self, query_col_res: List, tbname: str): for i in range(len(query_col_res)): v_from_query = query_col_res[i] @@ -235,8 +246,8 @@ class DecimalColumnExpr: f"check decimal column failed for expr: {self}, input: {[t.__str__() for t in self.get_input_types()]}, res_type: {self.res_type_}, params: {params}, query: {v_from_query}, expect {calc_res}, but get {query_res}" ) else: - tdLog.debug( - f"check decimal succ for expr: {self}, input: {[t.__str__() for t in self.get_input_types()]}, res_type: {self.res_type_}, params: {params}, insert:{v_from_calc_in_py} query:{v_from_query}, py dec: {calc_res}" + tdLog.info( + f"op succ: {self}, in: {[t.__str__() for t in self.get_input_types()]}, res: {self.res_type_}, params: {params}, insert:{v_from_calc_in_py} query:{v_from_query}, py calc: {calc_res}" ) ## format_params are already been set @@ -533,7 +544,11 @@ class Column: if save: if tbName not in self.saved_vals: self.saved_vals[tbName] = [] - self.saved_vals[tbName].append(val) + ## for constant columns, always replace the last val + if 
self.is_constant_col(): + self.saved_vals[tbName] = [val] + else: + self.saved_vals[tbName].append(val) return val def get_type_str(self) -> str: @@ -753,12 +768,14 @@ class DecimalBinaryOperator(DecimalColumnExpr): def generate(self, format_params) -> str: return super().generate(format_params) - def should_skip_for_decimal(self, left_col: Column, right_col: Column): + def should_skip_for_decimal(self, cols: list): + left_col = cols[0] + right_col = cols[1] if not left_col.type_.is_decimal_type() and not right_col.type_.is_decimal_type(): return True if self.op_ != "%": return False - ## why skip decimal % float/double?? it's wrong now. + ## why skip decimal % float/double? it's wrong now. left_is_real = left_col.type_.is_real_type() or left_col.type_.is_varchar_type() right_is_real = right_col.type_.is_real_type() or right_col.type_.is_varchar_type() if left_is_real or right_is_real: @@ -833,12 +850,6 @@ class DecimalBinaryOperator(DecimalColumnExpr): def get_input_types(self)-> list: return [self.left_type_, self.right_type_] - - def convert_to_res_type(self, val: Decimal) -> Decimal: - if self.res_type_.is_decimal_type(): - return val.quantize(Decimal("0." 
+ "0" * self.res_type_.scale()), ROUND_HALF_UP) - elif self.res_type_.type == TypeEnum.DOUBLE: - return float(val) @staticmethod def get_convert_type(params): @@ -972,8 +983,36 @@ class DecimalBinaryOperator(DecimalColumnExpr): def execute(self, params): return super().execute(params) + +class DecimalUnaryOperator(DecimalColumnExpr): + def __init__(self, format, executor, op: str): + super().__init__(format, executor) + self.op_ = op + self.col_type_: DataType = None + def should_skip_for_decimal(self, cols: list): + col:Column = cols[0] + if not col.type_.is_decimal_type(): + return True + return False + @staticmethod + def get_all_unary_ops() -> List[DecimalColumnExpr]: + return [ + DecimalUnaryOperator(" -{0} ", DecimalUnaryOperator.execute_minus, "-"), + ] + + def get_input_types(self)-> list: + return [self.col_type_] + + def generate_res_type(self): + self.res_type_ = self.col_type_ = self.params_[0].type_ + + def execute_minus(self, params) -> Decimal: + if params[0] is None: + return 'NULL' + return -Decimal(params[0]) + class DecimalBinaryOperatorIn(DecimalBinaryOperator): def __init__(self, op: str): super().__init__(op) @@ -983,7 +1022,7 @@ class DecimalBinaryOperatorIn(DecimalBinaryOperator): return left in right if self.op_.lower() == "not in": return left not in right - + class TDTestCase: updatecfgDict = { @@ -1341,6 +1380,8 @@ class TDTestCase: constant_cols: List[Column], exprs: List[DecimalColumnExpr], ): + if not binary_op_test: + return for expr in exprs: for col in tb_cols: if col.name_ == '': @@ -1348,22 +1389,46 @@ class TDTestCase: left_is_decimal = col.type_.is_decimal_type() for const_col in constant_cols: right_is_decimal = const_col.type_.is_decimal_type() - if expr.should_skip_for_decimal(col, const_col): + if expr.should_skip_for_decimal([col, const_col]): continue const_col.generate_value() + select_expr2 = expr.generate((const_col, col)) + sql = f"select {select_expr2} from {dbname}.{tbname}" + res2 = TaosShell().query(sql) 
select_expr = expr.generate((col, const_col)) sql = f"select {select_expr} from {dbname}.{tbname}" res = TaosShell().query(sql) if len(res) > 0: + if len(res) != len(res2): + tdLog.exit( + f"sql: {sql} got different row number for {select_expr} and {select_expr2}" + ) + for c, c2 in zip(res, res2): + for t, t2 in zip(c, c2): + if t != t2: + tdLog.exit( + f"sql: {sql} got different result for {select_expr} and {select_expr2}, expect {t2}, but got {t}" + ) expr.check(res[0], tbname) else: tdLog.info(f"sql: {sql} got no output") - ## query - ## build expr, expr.generate(column) to generate sql expr - ## pass this expr into DataValidator. - # When validating between query results and local values, pass the column data into the Expr, and invoke expr.execute - ## get result - ## check result + + def check_decimal_unary_expr_results(self, dbname, tbname, tb_cols: List[Column], exprs: List[DecimalColumnExpr]): + if not unary_op_test: + return + for expr in exprs: + for col in tb_cols: + if col.name_ == '': + continue + if expr.should_skip_for_decimal([col]): + continue + select_expr = expr.generate([col]) + sql = f"select {select_expr} from {dbname}.{tbname}" + res = TaosShell().query(sql) + if len(res) > 0: + expr.check(res[0], tbname) + else: + tdLog.info(f"sql: {sql} got no output") ## test others unsupported types operator with decimal def test_decimal_unsupported_types(self): @@ -1415,6 +1480,7 @@ class TDTestCase: ## tables: meters, nt ## columns: c1, c2, c3, c4, c5, c7, c8, c9, c10, c99, c100 binary_operators = DecimalBinaryOperator.get_all_binary_ops() + binary_operators = binary_operators[0:1] all_type_columns = Column.get_decimal_oper_const_cols() ## decimal operator with constants of all other types @@ -1427,9 +1493,13 @@ class TDTestCase: binary_operators, ) - ## decimal operator with columns of all other types - - unary_operators = ["-"] + unary_operators = DecimalUnaryOperator.get_all_unary_ops() + self.check_decimal_unary_expr_results( + self.db_name, + 
self.norm_table_name, + self.norm_tb_columns, + unary_operators, + ) def test_decimal_functions(self): self.test_decimal_last_first_func()