feat: add more test cases for decimal

* add decimal tests

* decimal: add more decimal tests

* test(decimal): add more tests for decimal

* fix(decimal): fix decimal.py

* test(decimal): fix decimal test

* fix(decimal): fix decimal3.py test case
This commit is contained in:
wangjiaming 2025-03-19 17:13:56 +08:00 committed by GitHub
parent 0cb0e109a7
commit 77f9707f89
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 2782 additions and 148 deletions

View File

@ -36,6 +36,12 @@ extern "C" {
#define FLT_GREATEREQUAL(_x, _y) (FLT_EQUAL((_x), (_y)) || ((_x) > (_y)))
#define FLT_LESSEQUAL(_x, _y) (FLT_EQUAL((_x), (_y)) || ((_x) < (_y)))
#define DBL_EQUAL(_x, _y) fabs((_x) - (_y)) <= (FLT_COMPAR_TOL_FACTOR * DBL_EPSILON)
#define DBL_GREATER(_x, _y) (!DBL_EQUAL((_x), (_y)) && ((_x) > (_y)))
#define DBL_LESS(_x, _y) (!DBL_EQUAL((_x), (_y)) && ((_x) < (_y)))
#define DBL_GREATEREQUAL(_x, _y) (DBL_EQUAL((_x), (_y)) || ((_x) > (_y)))
#define DBL_LESSEQUAL(_x, _y) (DBL_EQUAL((_x), (_y)) || ((_x) < (_y)))
#define PATTERN_COMPARE_INFO_INITIALIZER { '%', '_', L'%', L'_' }
typedef struct SPatternCompareInfo {

View File

@ -313,7 +313,7 @@ typedef struct SSyncQueryParam {
void* doAsyncFetchRows(SRequestObj* pRequest, bool setupOneRowPtr, bool convertUcs4);
void* doFetchRows(SRequestObj* pRequest, bool setupOneRowPtr, bool convertUcs4);
void doSetOneRowPtr(SReqResultInfo* pResultInfo, bool isStmt);
void doSetOneRowPtr(SReqResultInfo* pResultInfo);
void setResPrecision(SReqResultInfo* pResInfo, int32_t precision);
int32_t setQueryResultFromRsp(SReqResultInfo* pResultInfo, const SRetrieveTableRsp* pRsp, bool convertUcs4, bool isStmt);
int32_t setResultDataPtr(SReqResultInfo* pResultInfo, bool convertUcs4, bool isStmt);

View File

@ -1941,12 +1941,12 @@ TAOS* taos_connect_auth(const char* ip, const char* user, const char* auth, cons
// return taos_connect(ipStr, userStr, passStr, dbStr, port);
// }
void doSetOneRowPtr(SReqResultInfo* pResultInfo, bool isStmt) {
void doSetOneRowPtr(SReqResultInfo* pResultInfo) {
for (int32_t i = 0; i < pResultInfo->numOfCols; ++i) {
SResultColumn* pCol = &pResultInfo->pCol[i];
int32_t type = pResultInfo->fields[i].type;
int32_t schemaBytes = calcSchemaBytesFromTypeBytes(type, pResultInfo->fields[i].bytes, isStmt);
int32_t schemaBytes = calcSchemaBytesFromTypeBytes(type, pResultInfo->userFields[i].bytes, false);
if (IS_VAR_DATA_TYPE(type)) {
if (!IS_VAR_NULL_TYPE(type, schemaBytes) && pCol->offset[pResultInfo->current] != -1) {
@ -2012,7 +2012,7 @@ void* doFetchRows(SRequestObj* pRequest, bool setupOneRowPtr, bool convertUcs4)
}
if (setupOneRowPtr) {
doSetOneRowPtr(pResultInfo, pRequest->isStmtBind);
doSetOneRowPtr(pResultInfo);
pResultInfo->current += 1;
}
@ -2059,7 +2059,7 @@ void* doAsyncFetchRows(SRequestObj* pRequest, bool setupOneRowPtr, bool convertU
return NULL;
} else {
if (setupOneRowPtr) {
doSetOneRowPtr(pResultInfo, pRequest->isStmtBind);
doSetOneRowPtr(pResultInfo);
pResultInfo->current += 1;
}
@ -2135,8 +2135,9 @@ static int32_t doConvertUCS4(SReqResultInfo* pResultInfo, int32_t* colLength, bo
static int32_t convertDecimalType(SReqResultInfo* pResultInfo) {
for (int32_t i = 0; i < pResultInfo->numOfCols; ++i) {
TAOS_FIELD_E* pField = pResultInfo->fields + i;
int32_t type = pField->type;
TAOS_FIELD_E* pFieldE = pResultInfo->fields + i;
TAOS_FIELD* pField = pResultInfo->userFields + i;
int32_t type = pFieldE->type;
int32_t bufLen = 0;
char* p = NULL;
if (!IS_DECIMAL_TYPE(type) || !pResultInfo->pCol[i].pData) {
@ -2144,6 +2145,7 @@ static int32_t convertDecimalType(SReqResultInfo* pResultInfo) {
} else {
bufLen = 64;
p = taosMemoryRealloc(pResultInfo->convertBuf[i], bufLen * pResultInfo->numOfRows);
pFieldE->bytes = bufLen;
pField->bytes = bufLen;
}
if (!p) return terrno;
@ -2151,7 +2153,7 @@ static int32_t convertDecimalType(SReqResultInfo* pResultInfo) {
for (int32_t j = 0; j < pResultInfo->numOfRows; ++j) {
int32_t code = decimalToStr((DecimalWord*)(pResultInfo->pCol[i].pData + j * tDataTypes[type].bytes), type,
pField->precision, pField->scale, p, bufLen);
pFieldE->precision, pFieldE->scale, p, bufLen);
p += bufLen;
if (TSDB_CODE_SUCCESS != code) {
return code;
@ -2395,6 +2397,7 @@ static int32_t doConvertJson(SReqResultInfo* pResultInfo) {
}
int32_t setResultDataPtr(SReqResultInfo* pResultInfo, bool convertUcs4, bool isStmt) {
bool convertForDecimal = convertUcs4;
if (pResultInfo == NULL || pResultInfo->numOfCols <= 0 || pResultInfo->fields == NULL) {
tscError("setResultDataPtr paras error");
return TSDB_CODE_TSC_INTERNAL_ERROR;
@ -2507,7 +2510,7 @@ int32_t setResultDataPtr(SReqResultInfo* pResultInfo, bool convertUcs4, bool isS
code = doConvertUCS4(pResultInfo, colLength, isStmt);
}
#endif
if (TSDB_CODE_SUCCESS == code && convertUcs4) {
if (TSDB_CODE_SUCCESS == code && convertForDecimal) {
code = convertDecimalType(pResultInfo);
}
return code;

View File

@ -647,7 +647,7 @@ TAOS_ROW taos_fetch_row(TAOS_RES *res) {
}
if (pResultInfo->current < pResultInfo->numOfRows) {
doSetOneRowPtr(pResultInfo, false);
doSetOneRowPtr(pResultInfo);
pResultInfo->current += 1;
return pResultInfo->row;
} else {
@ -655,7 +655,7 @@ TAOS_ROW taos_fetch_row(TAOS_RES *res) {
return NULL;
}
doSetOneRowPtr(pResultInfo, false);
doSetOneRowPtr(pResultInfo);
pResultInfo->current += 1;
return pResultInfo->row;
}

View File

@ -335,7 +335,7 @@ void sclDowngradeValueType(SValueNode *valueNode) {
}
case TSDB_DATA_TYPE_DOUBLE: {
float f = valueNode->datum.d;
if (FLT_EQUAL(f, valueNode->datum.d)) {
if (DBL_EQUAL(f, valueNode->datum.d)) {
valueNode->node.resType.type = TSDB_DATA_TYPE_FLOAT;
*(float *)&valueNode->typeData = f;
break;

View File

@ -65,7 +65,8 @@ int32_t setChkInDecimalHash(const void* pLeft, const void* pRight) {
}
int32_t setChkNotInDecimalHash(const void* pLeft, const void* pRight) {
return NULL == taosHashGet((SHashObj *)pRight, pLeft, 16) ? 1 : 0;
const SDecimalCompareCtx *pCtxL = pLeft, *pCtxR = pRight;
return NULL == taosHashGet((SHashObj *)(pCtxR->pData), pCtxL->pData, tDataTypes[pCtxL->type].bytes) ? 1 : 0;
}
int32_t compareChkInString(const void *pLeft, const void *pRight) {
@ -187,7 +188,7 @@ int32_t compareDoubleVal(const void *pLeft, const void *pRight) {
return 1;
}
if (FLT_EQUAL(p1, p2)) {
if (DBL_EQUAL(p1, p2)) {
return 0;
}
return FLT_GREATER(p1, p2) ? 1 : -1;

View File

@ -170,12 +170,13 @@ uint32_t taosDoubleHash(const char *key, uint32_t UNUSED_PARAM(len)) {
return 0x7fc00000;
}
if (FLT_EQUAL(f, 0.0)) {
if (DBL_EQUAL(f, 0.0)) {
return 0;
}
if (fabs(f) < DBL_MAX / BASE - DLT) {
int32_t t = (int32_t)(round(BASE * (f + DLT)));
return (uint32_t)t;
uint64_t bits;
memcpy(&bits, &f, sizeof(double));
return (uint32_t)(bits ^ (bits >> 32));
} else {
return 0x7fc00000;
}

View File

@ -343,6 +343,11 @@
,,y,system-test,./pytest.sh python3 ./test.py -f 2-query/decimal2.py -Q 3
,,y,system-test,./pytest.sh python3 ./test.py -f 2-query/decimal2.py -Q 2
,,y,system-test,./pytest.sh python3 ./test.py -f 2-query/decimal2.py -Q 1
,,y,system-test,./pytest.sh python3 ./test.py -f 2-query/decimal3.py
,,y,system-test,./pytest.sh python3 ./test.py -f 2-query/decimal3.py -Q 4
,,y,system-test,./pytest.sh python3 ./test.py -f 2-query/decimal3.py -Q 3
,,y,system-test,./pytest.sh python3 ./test.py -f 2-query/decimal3.py -Q 2
,,y,system-test,./pytest.sh python3 ./test.py -f 2-query/decimal3.py -Q 1
,,y,system-test,./pytest.sh python3 ./test.py -f 2-query/tbnameIn.py
,,y,system-test,./pytest.sh python3 ./test.py -f 2-query/tbnameIn.py -R
,,y,system-test,./pytest.sh python3 ./test.py -f 2-query/tbnameIn.py -Q 2

View File

@ -28,11 +28,14 @@ class AtomicCounter:
getcontext().prec = 40
def get_decimal(val, scale: int) -> Decimal:
def get_decimal(val, scale: int):
if val == 'NULL':
return None
getcontext().prec = 100
return Decimal(val).quantize(Decimal("1." + "0" * scale), ROUND_HALF_UP)
try:
return Decimal(val).quantize(Decimal("1." + "0" * scale), ROUND_HALF_UP)
except:
tdLog.exit(f"faield to convert to decimal for v: {val} scale: {scale}")
syntax_error = -2147473920
invalid_column = -2147473918
@ -52,6 +55,7 @@ unary_op_test = True
binary_op_in_where_test = True
test_decimal_funcs = False
cast_func_test_round = 10
in_op_test_round = 10
class DecimalTypeGeneratorConfig:
def __init__(self):
@ -308,8 +312,8 @@ class DecimalColumnExpr:
calc_res = float(v_from_calc_in_py)
failed = not math.isclose(query_res, calc_res, abs_tol=1e-7)
else:
query_res = Decimal(v_from_query)
calc_res = Decimal(v_from_calc_in_py)
query_res = get_decimal(v_from_query, self.res_type_.scale())
calc_res = get_decimal(v_from_calc_in_py, self.res_type_.scale())
failed = query_res != calc_res
if failed:
tdLog.exit(
@ -438,13 +442,13 @@ class DataType:
if self.type == TypeEnum.BIGINT:
return str(secrets.randbelow(9223372036854775808) - 4611686018427387904)
if self.type == TypeEnum.FLOAT or self.type == TypeEnum.DOUBLE:
return str(random.random())
return str(random.uniform(-1e10, 1e10))
if (
self.type == TypeEnum.VARCHAR
or self.type == TypeEnum.NCHAR
or self.type == TypeEnum.VARBINARY
):
return f"'{str(random.random())[0:self.length]}'"
return f"'{str(random.uniform(-1e20, 1e20))[0:self.length]}'"
if self.type == TypeEnum.TIMESTAMP:
return str(secrets.randbelow(9223372036854775808))
if self.type == TypeEnum.UTINYINT:
@ -461,6 +465,29 @@ class DataType:
return "'POINT(1.0 1.0)'"
raise Exception(f"unsupport type {self.type}")
def generate_sized_val(self, prec: int, scale: int) -> str:
weight = prec - scale
if self.type == TypeEnum.BOOL:
return ['true', 'false'][secrets.randbelow(2)]
if self.type == TypeEnum.TINYINT or self.type == TypeEnum.SMALLINT or self.type == TypeEnum.INT or self.type == TypeEnum.BIGINT or self.type == TypeEnum.TIMESTAMP:
return str(secrets.randbelow(10 * weight * 2) - 10 * weight)
if self.type == TypeEnum.FLOAT or self.type == TypeEnum.DOUBLE:
return str(random.uniform(-10 * weight, 10 * weight))
if (
self.type == TypeEnum.VARCHAR
or self.type == TypeEnum.NCHAR
or self.type == TypeEnum.VARBINARY
):
return f"'{str(random.uniform(-(10 * weight + 1), 10 * weight - 1))}'"
if self.type == TypeEnum.UTINYINT or self.type == TypeEnum.USMALLINT or self.type == TypeEnum.UINT or self.type == TypeEnum.UBIGINT:
return str(secrets.randbelow(10 * weight * 2))
if self.type == TypeEnum.JSON:
return f'{{"key": "{secrets.token_urlsafe(10)}"}}'
if self.type == TypeEnum.GEOMETRY:
return "'POINT(1.0 1.0)'"
raise Exception(f"unsupport type {self.type}")
def check(self, values, offset: int):
return True
@ -481,6 +508,8 @@ class DataType:
return get_decimal(val, self.scale())
elif isinstance(val, str):
val = val.strip("'")
if len(val) == 0:
return 0
return val
def get_typed_val(self, val):
@ -513,6 +542,19 @@ class DataType:
TypeEnum.DECIMAL,
TypeEnum.DECIMAL64,
]
@staticmethod
def generate_random_type_for(dt: int):
if dt == TypeEnum.DECIMAL:
prec = random.randint(1, DecimalType.DECIMAL_MAX_PRECISION)
return DecimalType(dt, prec, random.randint(0, prec))
elif dt == TypeEnum.DECIMAL64:
prec = random.randint(1, DecimalType.DECIMAL64_MAX_PRECISION)
return DecimalType(dt, prec, random.randint(0, prec))
elif dt == TypeEnum.BINARY or dt == TypeEnum.VARCHAR:
return DataType(dt, random.randint(16, 255), 0)
else:
return DataType(dt, 0, 0)
class DecimalType(DataType):
DECIMAL_MAX_PRECISION = 38
@ -668,6 +710,25 @@ class Column:
else:
return len(self.saved_vals[tbname])
def seq_scan_col(self, tbname: str, idx: int):
if self.is_constant_col():
return self.get_constant_val_for_execute(), False
elif len(self.saved_vals) > 1:
keys = list(self.saved_vals.keys())
for i, key in enumerate(keys):
l = len(self.saved_vals[key])
if idx < l:
return self.get_typed_val_for_execute(self.saved_vals[key][idx]), True
else:
idx -= l
return 1, False
else:
if idx > len(self.saved_vals[tbname]) - 1:
return 1, False
v = self.get_typed_val_for_execute(self.saved_vals[tbname][idx])
return v, True
@staticmethod
def comp_key(key1, key2):
if key1 is None:
@ -921,23 +982,11 @@ class DecimalCastTypeGenerator:
else:
return DataType.get_decimal_op_types()
def do_generate_type(self, dt: int) ->DataType:
if dt == TypeEnum.DECIMAL:
prec = random.randint(1, DecimalType.DECIMAL_MAX_PRECISION)
return DecimalType(dt, prec, random.randint(0, prec))
elif dt == TypeEnum.DECIMAL64:
prec = random.randint(1, DecimalType.DECIMAL64_MAX_PRECISION)
return DecimalType(dt, prec, random.randint(0, prec))
elif dt == TypeEnum.BINARY or dt == TypeEnum.VARCHAR:
return DataType(dt, random.randint(16, 255), 0)
else:
return DataType(dt, 0, 0)
def generate(self, num: int) -> List[DataType]:
res: list[DataType] = []
for _ in range(num):
dt = random.choice(self.get_possible_output_types())
dt = self.do_generate_type(dt)
dt = DataType.generate_random_type_for(dt)
res.append(dt)
res = list(set(res))
return res
@ -1303,10 +1352,12 @@ class DecimalBinaryOperator(DecimalColumnExpr):
return super().generate(format_params)
def should_skip_for_decimal(self, cols: list):
left_col = cols[0]
right_col = cols[1]
left_col: Column = cols[0]
right_col: Column = cols[1]
if not left_col.type_.is_decimal_type() and not right_col.type_.is_decimal_type():
return True
if not right_col.is_constant_col() and (self.op_ == '%' or self.op_ == '/'):
return True
if self.op_ != "%":
return False
## TODO wjm why skip decimal % float/double? it's wrong now.
@ -1556,21 +1607,104 @@ class DecimalUnaryOperator(DecimalColumnExpr):
def generate_res_type(self):
self.res_type_ = self.col_type_ = self.params_[0].type_
def execute_minus(self, params) -> Decimal:
def execute_minus(self, params):
if params[0] is None:
return 'NULL'
return -Decimal(params[0])
class DecimalBinaryOperatorIn(DecimalBinaryOperator):
def __init__(self, op: str):
super().__init__(op)
self.op_ = op
super().__init__("{0} " + self.op_ + " ({1})", self.execute_op, op)
def execute(self, left, right):
if self.op_.lower()() == "in":
return left in right
if self.op_.lower() == "not in":
return left not in right
def generate_res_type(self):
self.query_col = self.params_[0]
self.res_type_ = DataType(TypeEnum.BOOL)
def execute_op(self, params):
list_exprs: DecimalListExpr = self.params_[1]
v, vs = list_exprs.get_converted_vs(params)
if v is None:
return False
b = False
if self.op_.lower() == 'in':
b = v in vs
#if b:
#tdLog.debug(f"eval {v} in {list_exprs} got: {b}")
else:
b = v not in vs
#if not b:
#tdLog.debug(f"eval {v} not in {list_exprs} got: {b}")
return b
def check(self, res, tbname: str):
idx = 0
v, has_next = self.query_col.seq_scan_col(tbname, idx)
calc_res = []
while has_next:
keep: bool = self.execute_op(v)
if keep:
calc_res.append(v)
idx += 1
v, has_next = self.query_col.seq_scan_col(tbname, idx)
calc_res = sorted(calc_res)
res = [get_decimal(e, self.query_col.type_.scale()) for e in res]
res = sorted(res)
for v, calc_v in zip(res, calc_res):
if v != calc_v:
tdLog.exit(f"check failed for {self}, query got: {v}, expect: {calc_v}, len_query: {len(res)}, len_calc: {len(calc_res)}")
return True
class DecimalListExpr:
def __init__(self, num: int, col: Column):
self.elements_ = []
self.num_ = num
self.types_ = []
self.converted_vs_ = None
self.output_float_ = False
self.col_in_: Column = col
@staticmethod
def get_all_possible_types():
types = DataType.get_decimal_op_types()
types.remove(TypeEnum.DECIMAL)
types.remove(TypeEnum.DECIMAL64)
return types
def has_output_double_type(self) -> bool:
for t in self.types_:
if t.is_real_type() or t.is_varchar_type():
return True
return False
def get_converted_vs(self, v):
if self.converted_vs_ is None:
vs = []
for t, saved_v in zip(self.types_, self.elements_):
vs.append(t.get_typed_val_for_execute(saved_v, True))
if self.has_output_double_type():
self.converted_vs_ = [float(e) for e in vs]
self.output_float_= True
else:
self.converted_vs_ = [Decimal(e) for e in vs]
if v is None:
return None, self.converted_vs_
if self.output_float_:
return float(v), self.converted_vs_
else:
return Decimal(v), self.converted_vs_
def generate(self):
types = self.get_all_possible_types()
for _ in range(self.num_):
type = random.choice(types)
dt = DataType.generate_random_type_for(type)
self.types_.append(dt)
v = dt.generate_sized_val(self.col_in_.type_.prec(), self.col_in_.type_.scale())
self.elements_.append(v)
def __str__(self):
return f"{','.join([e for e in self.elements_])}"
class TDTestCase:
updatecfgDict = {
@ -2152,44 +2286,64 @@ class TDTestCase:
if tdSql.errno != invalid_operation and tdSql.errno != scalar_convert_err:
tdLog.exit(f"expected err not occured for sql: {sql}, expect: {invalid_operation} or {scalar_convert_err}, but got {tdSql.errno}")
def check_decimal_in_op(self, tbname: str, tb_cols: list):
for i in range(in_op_test_round):
inOp: DecimalBinaryOperatorIn = DecimalBinaryOperatorIn('in')
notInOp: DecimalBinaryOperatorIn = DecimalBinaryOperatorIn('not in')
for col in tb_cols:
if not col.type_.is_decimal_type():
continue
list_expr: DecimalListExpr = DecimalListExpr(random.randint(1, 10), col)
list_expr.generate()
expr = inOp.generate((col, list_expr))
sql = f'select {col} from {self.db_name}.{tbname} where {expr}'
res = TaosShell().query(sql)
if len(res) > 0:
res = res[0]
inOp.check(res, tbname)
expr = notInOp.generate((col, list_expr))
sql = f'select {col} from {self.db_name}.{tbname} where {expr}'
res = TaosShell().query(sql)
if len(res) > 0:
res = res[0]
notInOp.check(res, tbname)
def test_decimal_operators(self):
tdLog.debug("start to test decimal operators")
self.test_decimal_unsupported_types()
## tables: meters, nt
## columns: c1, c2, c3, c4, c5, c7, c8, c9, c10, c99, c100
binary_operators = DecimalBinaryOperator.get_all_binary_ops()
if True:
self.test_decimal_unsupported_types()
## tables: meters, nt
## columns: c1, c2, c3, c4, c5, c7, c8, c9, c10, c99, c100
binary_operators = DecimalBinaryOperator.get_all_binary_ops()
## decimal operator with constants of all other types
self.run_in_thread(
operator_test_round,
self.check_decimal_binary_expr_with_const_col_results,
(
## decimal operator with constants of all other types
self.run_in_thread(
operator_test_round,
self.check_decimal_binary_expr_with_const_col_results,
(
self.db_name,
self.norm_table_name,
self.norm_tb_columns,
Column.get_decimal_oper_const_cols,
DecimalBinaryOperator.get_all_binary_ops,
),
)
## test decimal column op decimal column
for _ in range(operator_test_round):
self.check_decimal_binary_expr_with_col_results(
self.db_name, self.norm_table_name, self.norm_tb_columns, binary_operators)
unary_operators = DecimalUnaryOperator.get_all_unary_ops()
self.check_decimal_unary_expr_results(
self.db_name,
self.norm_table_name,
self.norm_tb_columns,
Column.get_decimal_oper_const_cols,
DecimalBinaryOperator.get_all_binary_ops,
),
)
unary_operators,)
## test decimal column op decimal column
for i in range(operator_test_round):
self.check_decimal_binary_expr_with_col_results(
self.db_name, self.norm_table_name, self.norm_tb_columns, binary_operators)
unary_operators = DecimalUnaryOperator.get_all_unary_ops()
self.check_decimal_unary_expr_results(
self.db_name,
self.norm_table_name,
self.norm_tb_columns,
unary_operators,)
def test_decimal_last_first_func(self):
pass
def test_query_decimal_with_sma(self):
pass
self.check_decimal_in_op(self.norm_table_name, self.norm_tb_columns)
self.check_decimal_in_op(self.stable_name, self.stb_columns)
def check_decimal_where_with_binary_expr_with_const_col_results(
self,

View File

@ -32,7 +32,10 @@ def get_decimal(val, scale: int) -> Decimal:
if val == 'NULL':
return None
getcontext().prec = 100
return Decimal(val).quantize(Decimal("1." + "0" * scale), ROUND_HALF_UP)
try:
return Decimal(val).quantize(Decimal("1." + "0" * scale), ROUND_HALF_UP)
except Exception as e:
tdLog.exit(f"failed to convert {val} to decimal, {e}")
syntax_error = -2147473920
invalid_column = -2147473918
@ -46,6 +49,7 @@ decimal_test_query = True
decimal_insert_validator_test = False
operator_test_round = 1
tb_insert_rows = 1000
ctb_num = 10
binary_op_with_const_test = False
binary_op_with_col_test = False
unary_op_test = False
@ -160,6 +164,23 @@ class DecimalColumnAggregator:
self.none_num: int = 0
self.first = None
self.last = None
self.firsts = []
self.lasts = []
for i in range(ctb_num):
self.firsts.append(None)
self.lasts.append(None)
def is_stb(self):
return self.firsts[1] is not None
def get_last(self):
if self.is_stb():
return self.lasts
return self.last
def get_first(self):
if self.is_stb():
return self.firsts
return self.first
def add_value(self, value: str, scale: int):
self.count += 1
@ -171,7 +192,10 @@ class DecimalColumnAggregator:
v: Decimal = get_decimal(value, scale)
if self.first is None:
self.first = v
if self.firsts[int((self.count - 1) / tb_insert_rows)] is None:
self.firsts[int((self.count - 1) / tb_insert_rows)] = v
self.last = v
self.lasts[int((self.count - 1) / tb_insert_rows)] = v
self.sum += v
if v > self.max:
self.max = v
@ -275,7 +299,7 @@ class DecimalColumnExpr:
continue
else:
break
dec_from_query = Decimal(v_from_query)
dec_from_query = get_decimal(v_from_query, self.query_col.type_.scale())
dec_from_calc = self.get_query_col_val(tbname, j)
if dec_from_query != dec_from_calc:
tdLog.exit(f"filter with {self} failed, query got: {dec_from_query}, expect {dec_from_calc}, param: {params}")
@ -994,7 +1018,7 @@ class DecimalFunction(DecimalColumnExpr):
def check_results(self, query_col_res: List) -> bool:
return False
def check_for_agg_func(self, query_col_res: List, tbname: str, func):
def check_for_agg_func(self, query_col_res: List, tbname: str, func, is_stb: bool = False):
col_expr = self.query_col
for i in range(col_expr.get_cardinality(tbname)):
col_val = col_expr.get_val_for_execute(tbname, i)
@ -1122,18 +1146,30 @@ class DecimalAggFunction(DecimalFunction):
else:
return self.get_func_res() == Decimal(query_col_res[0])
class DecimalLastRowFunction(DecimalAggFunction):
def __init__(self):
super().__init__("last_row({0})", DecimalLastRowFunction.execute_last_row, "last_row")
self.res_ = None
def get_func_res(self):
decimal_type:DecimalType = self.query_col.type_
return decimal_type.aggregator.last
class DecimalFirstLastFunction(DecimalAggFunction):
def __init__(self, format: str, func, name):
super().__init__(format, func, name)
def generate_res_type(self):
self.res_type_ = self.query_col.type_
def check_results(self, query_col_res):
if len(query_col_res) == 0:
tdLog.exit(f"query got no output: {self}, py calc: {self.get_func_res()}")
else:
v = Decimal(query_col_res[0])
decimal_type: DecimalType = self.query_col.type_
if decimal_type.aggregator.is_stb():
return v in self.get_func_res()
else:
return self.get_func_res() == v
class DecimalLastRowFunction(DecimalFirstLastFunction):
def __init__(self):
super().__init__("last_row({0})", DecimalLastRowFunction.execute_last_row, "last_row")
def get_func_res(self):
decimal_type: DecimalType = self.query_col.type_
return decimal_type.aggregator.get_last()
def execute_last_row(self, params):
if params[0] is not None:
self.res_ = Decimal(params[0])
pass
class DecimalCacheLastRowFunction(DecimalAggFunction):
def __init__(self):
@ -1148,27 +1184,22 @@ class DecimalCacheLastRowFunction(DecimalAggFunction):
class DecimalCacheLastFunction(DecimalAggFunction):
pass
class DecimalFirstFunction(DecimalAggFunction):
class DecimalFirstFunction(DecimalFirstLastFunction):
def __init__(self):
super().__init__("first({0})", DecimalFirstFunction.execute_first, "first")
self.res_ = None
def get_func_res(self):
decimal_type: DecimalType = self.query_col.type_
return decimal_type.aggregator.first
def generate_res_type(self):
self.res_type_ = self.query_col.type_
return decimal_type.aggregator.get_first()
def execute_first(self, params):
pass
class DecimalLastFunction(DecimalAggFunction):
class DecimalLastFunction(DecimalFirstLastFunction):
def __init__(self):
super().__init__("last({0})", DecimalLastFunction.execute_last, "last")
self.res_ = None
def get_func_res(self):
decimal_type:DecimalType = self.query_col.type_
return decimal_type.aggregator.last
def generate_res_type(self):
self.res_type_ = self.query_col.type_
return decimal_type.aggregator.get_last()
def execute_last(self, params):
pass
@ -1207,19 +1238,12 @@ class DecimalMinFunction(DecimalAggFunction):
def get_func_res(self) -> Decimal:
decimal_type: DecimalType = self.query_col.type_
return decimal_type.aggregator.min
return self.min_
def generate_res_type(self) -> DataType:
self.res_type_ = self.query_col.type_
def execute_min(self, params):
if params[0] is None:
return
if self.min_ is None:
self.min_ = Decimal(params[0])
else:
self.min_ = min(self.min_, Decimal(params[0]))
return self.min_
pass
class DecimalMaxFunction(DecimalAggFunction):
def __init__(self):
@ -1234,13 +1258,7 @@ class DecimalMaxFunction(DecimalAggFunction):
self.res_type_ = self.query_col.type_
def execute_max(self, params):
if params[0] is None:
return
if self.max_ is None:
self.max_ = Decimal(params[0])
else:
self.max_ = max(self.max_, Decimal(params[0]))
return self.max_
pass
class DecimalSumFunction(DecimalAggFunction):
def __init__(self):
@ -1249,24 +1267,15 @@ class DecimalSumFunction(DecimalAggFunction):
def get_func_res(self) -> Decimal:
decimal_type: DecimalType = self.query_col.type_
return decimal_type.aggregator.sum
return self.sum_
def generate_res_type(self) -> DataType:
self.res_type_ = self.query_col.type_
self.res_type_.set_prec(DecimalType.DECIMAL_MAX_PRECISION)
def execute_sum(self, params):
if params[0] is None:
return
if self.sum_ is None:
self.sum_ = Decimal(params[0])
else:
self.sum_ += Decimal(params[0])
return self.sum_
pass
class DecimalAvgFunction(DecimalAggFunction):
def __init__(self):
super().__init__("avg({0})", DecimalAvgFunction.execute_avg, "avg")
self.count_: Decimal = 0
self.sum_: Decimal = None
def get_func_res(self) -> Decimal:
decimal_type: DecimalType = self.query_col.type_
return get_decimal(
@ -1280,14 +1289,7 @@ class DecimalAvgFunction(DecimalAggFunction):
count_type = DataType(TypeEnum.BIGINT, 8, 0)
self.res_type_ = DecimalBinaryOperator.calc_decimal_prec_scale(sum_type, count_type, "/")
def execute_avg(self, params):
if params[0] is None:
return
if self.sum_ is None:
self.sum_ = Decimal(params[0])
else:
self.sum_ += Decimal(params[0])
self.count_ += 1
return self.get_func_res()
pass
class DecimalBinaryOperator(DecimalColumnExpr):
def __init__(self, format, executor, op: str):
@ -1594,7 +1596,7 @@ class TDTestCase:
self.c_table_prefix = "t"
self.tag_name_prefix = "t"
self.db_name = "test"
self.c_table_num = 10
self.c_table_num = ctb_num
self.no_decimal_col_tb_name = "tt"
self.stb_columns = []
self.stream_name = "stream1"
@ -1975,9 +1977,9 @@ class TDTestCase:
#self.no_decimal_table_test()
self.test_insert_decimal_values()
self.test_query_decimal()
self.test_decimal_and_tsma()
self.test_decimal_and_view()
self.test_decimal_and_stream()
#self.test_decimal_and_tsma()
#self.test_decimal_and_view()
#self.test_decimal_and_stream()
def stop(self):
tdSql.close()
@ -2396,7 +2398,7 @@ class TDTestCase:
res = TaosShell().query(sql)
if len(res) > 0:
res = res[0]
func.check_for_agg_func(res, tbname, func)
func.check_for_agg_func(res, tbname, func, tbname == self.stable_name)
def test_decimal_cast_func(self, dbname, tbname, tb_cols: List[Column]):
for col in tb_cols:
@ -2416,13 +2418,8 @@ class TDTestCase:
self.log_test("start to test decimal functions")
if not test_decimal_funcs:
return
self.test_decimal_agg_funcs(
self.db_name,
self.norm_table_name,
self.norm_tb_columns,
DecimalFunction.get_decimal_agg_funcs,
)
##self.test_decimal_agg_funcs( self.db_name, self.stable_name, self.stb_columns, DecimalFunction.get_decimal_agg_funcs)
self.test_decimal_agg_funcs( self.db_name, self.norm_table_name, self.norm_tb_columns, DecimalFunction.get_decimal_agg_funcs)
self.test_decimal_agg_funcs( self.db_name, self.stable_name, self.stb_columns, DecimalFunction.get_decimal_agg_funcs)
self.test_decimal_cast_func(self.db_name, self.norm_table_name, self.norm_tb_columns)
def test_query_decimal(self):
@ -2431,11 +2428,11 @@ class TDTestCase:
return
#self.test_decimal_operators()
self.test_query_decimal_where_clause()
self.test_decimal_functions()
self.test_query_decimal_order_clause()
self.test_query_decimal_case_when()
self.test_query_decimal_group_by_clause()
self.test_query_decimal_having_clause()
#self.test_decimal_functions()
#self.test_query_decimal_order_clause()
#self.test_query_decimal_case_when()
#self.test_query_decimal_group_by_clause()
#self.test_query_decimal_having_clause()
event = threading.Event()

File diff suppressed because it is too large Load Diff

View File

@ -70,7 +70,6 @@ def get_local_classes_in_order(file_path):
def dynamicLoadModule(fileName):
moduleName = fileName.replace(".py", "").replace(os.sep, ".")
return importlib.import_module(moduleName, package='..')
#
# run case on previous cluster
#