homework-jianmu/tests/system-test/2-query/decimal3.py

2469 lines
96 KiB
Python

import math
from random import randrange
import random
import time
import threading
import secrets
import numpy
from util.log import *
from util.sql import *
from util.cases import *
from util.dnodes import *
from util.common import *
from decimal import *
from multiprocessing import Value, Lock
from functools import cmp_to_key
class AtomicCounter:
    """Integer counter whose read-and-increment step runs under a lock."""

    def __init__(self, initial_value=0):
        # 'i' is the typecode for a signed C int.
        self._value = Value('i', initial_value)
        self._lock = Lock()

    def fetch_add(self, delta = 1):
        """Add ``delta`` to the counter and return the value held before."""
        with self._lock:
            previous = self._value.value
            self._value.value = previous + delta
        return previous
# Working precision of the current decimal context at import time.
# Note: get_decimal() raises it to 100 on its first non-NULL call.
getcontext().prec = 40
def get_decimal(val, scale: int) -> Decimal:
    """Convert ``val`` to a Decimal rounded half-up to ``scale`` fraction digits.

    Returns None when ``val`` is the literal string 'NULL'. As a side effect
    the current decimal context precision is set to 100.
    """
    if val == 'NULL':
        return None
    getcontext().prec = 100
    number = Decimal(val)
    # "1." followed by `scale` zeros is the quantize exponent template.
    quantum = Decimal("1." + "0" * scale)
    return number.quantize(quantum, ROUND_HALF_UP)
# Expected error codes (presumably TDengine error codes -- confirm against
# the engine's error table; their usages are outside the visible part).
syntax_error = -2147473920
invalid_column = -2147473918
invalid_compress_level = -2147483084
invalid_encode_param = -2147483087
invalid_operation = -2147483136
scalar_convert_err = -2147470768
# Test switches and sizes. Within the visible code:
#   - decimal_insert_validator_test gates TableDataValidator.validate()
#   - tb_insert_rows and ctb_num size the per-child-table slots in
#     DecimalColumnAggregator
# The remaining flags are consumed elsewhere in the file.
decimal_test_query = True
decimal_insert_validator_test = False
operator_test_round = 1
tb_insert_rows = 1000
ctb_num = 10
binary_op_with_const_test = False
binary_op_with_col_test = False
unary_op_test = False
binary_op_in_where_test = True
test_decimal_funcs = True
cast_func_test_round = 10
class DecimalTypeGeneratorConfig:
    """Tunable knobs read by DecimalStringRandomGenerator.generate()."""

    def __init__(self):
        # Target column type: total digits and fraction digits.
        self.prec = 38
        self.scale = 10

        # Sign handling: probability of a "+" sign, and whether "+" is
        # written out explicitly.
        self.positive_ratio = 0.7
        self.enable_positive_sign = False

        # Optional extra integer digits appended to the generated literal.
        self.enable_weight_overflow: bool = False
        self.weightOverflowRatio: float = 0.001

        # Optional extra fraction digits appended to the generated literal.
        self.enable_scale_overflow: bool = True
        self.scale_overflow_ratio = 0.1

        # Corner-case literals ("0", "NULL", ...).
        self.with_corner_case = True
        self.corner_case_ratio = 0.1
class DecimalStringRandomGenerator:
    """Builds random decimal literals (as strings) driven by a config object."""

    def __init__(self):
        self.corner_cases = ["0", "NULL", "0.", ".0", "000.000000"]
        self.ratio_base: int = 1000000

    def possible(self, possibility: float) -> bool:
        """Return True with (approximately) probability ``possibility``."""
        threshold = possibility * self.ratio_base
        return random.randint(0, self.ratio_base) < threshold

    def generate_sign(self, positive_ratio: float) -> str:
        """Return "+" with probability ``positive_ratio``, otherwise "-"."""
        return "+" if self.possible(positive_ratio) else "-"

    def generate_digit(self) -> str:
        """Return one random decimal digit as a string."""
        return str(random.randint(0, 9))

    def current_should_generate_corner_case(self, corner_case_ratio: float) -> bool:
        """Decide whether the next literal should be a corner case."""
        return self.possible(corner_case_ratio)

    def generate_corner_case(self, config: "DecimalTypeGeneratorConfig") -> str:
        """Return a corner-case literal.

        Mostly one of ``self.corner_cases``; otherwise a single digit repeated
        to fill the integer part, optionally followed by "." and the same
        kind of repeated-digit fraction.
        """
        if self.possible(0.8):
            return random.choice(self.corner_cases)
        literal = self.generate_digit() * (config.prec - config.scale)
        if self.possible(0.8):
            literal += "."
            if self.possible(0.8):
                literal += self.generate_digit() * config.scale
        return literal

    ## TODO: example of inserting a big integer such as 10000000000; the scale may be parsed as negative
    ## TODO: generate decimal with E/e
    def generate(self, config: "DecimalTypeGeneratorConfig") -> str:
        """Return one random decimal literal according to ``config``."""
        # The sign is always drawn first, even if a corner case is produced.
        sign = self.generate_sign(config.positive_ratio)
        if config.with_corner_case and self.current_should_generate_corner_case(
            config.corner_case_ratio
        ):
            return "" + self.generate_corner_case(config)

        pieces = []
        if config.enable_positive_sign or sign != "+":
            pieces.append(sign)
        weight = random.randint(1, config.prec - config.scale)
        scale = random.randint(1, config.scale)

        # Integer part, optionally padded with extra digits.
        pieces.extend(self.generate_digit() for _ in range(weight))
        if config.enable_weight_overflow and self.possible(
            config.weightOverflowRatio
        ):
            extra_weight = (
                config.prec
                - weight
                + 1
                + random.randint(1, self.get_max_prec(config.prec))
            )
            pieces.extend(self.generate_digit() for _ in range(extra_weight))

        pieces.append(".")

        # Fraction part, optionally padded with extra digits.
        pieces.extend(self.generate_digit() for _ in range(scale))
        if config.enable_scale_overflow and self.possible(
            config.scale_overflow_ratio
        ):
            extra_scale = (
                config.scale
                - scale
                + 1
                + random.randint(1, self.get_max_prec(config.prec))
            )
            pieces.extend(self.generate_digit() for _ in range(extra_scale))
        return "".join(pieces)

    def get_max_prec(self, prec):
        """Return 18 for precisions up to 18, otherwise 38."""
        return 18 if prec <= 18 else 38
class DecimalColumnAggregator:
    """Tracks expected aggregate results for the values written to a column.

    Per-child-table slots (``firsts``/``lasts``/``last_rows``) are sized by
    the module-level ``ctb_num``; a value's slot is derived from the running
    count and ``tb_insert_rows``.
    """

    def __init__(self):
        # max/min stay at 0 until the first non-NULL value arrives; they are
        # then seeded from that value (see add_value).
        self.max: Decimal = Decimal("0")
        self.min: Decimal = Decimal("0")
        self.count: int = 0
        self.sum: Decimal = Decimal("0")
        self.null_num: int = 0
        self.none_num: int = 0
        self.first = None
        self.last = None
        self.last_row = None
        self.firsts = [None] * ctb_num
        self.lasts = [None] * ctb_num
        self.last_rows = [None] * ctb_num

    def is_stb(self):
        """Return True once the second child-table slot has a first value."""
        return self.firsts[1] is not None

    def get_last(self):
        """Return the per-table list for a super table, else the scalar."""
        if self.is_stb():
            return self.lasts
        return self.last

    def get_last_row(self):
        """Return the per-table list for a super table, else the scalar."""
        if self.is_stb():
            return self.last_rows
        return self.last_row

    def get_first(self):
        """Return the per-table list for a super table, else the scalar."""
        if self.is_stb():
            return self.firsts
        return self.first

    def add_value(self, value: str, scale: int):
        """Record one inserted literal.

        "NULL" and "None" only update the counters and clear the last-row
        trackers; any other literal is converted with ``get_decimal`` and
        folded into first/last/sum/min/max.
        """
        self.count += 1
        # Integer division avoids a float round-trip for the slot index.
        tb_idx = (self.count - 1) // tb_insert_rows
        if value == "NULL":
            self.null_num += 1
            self.last_row = None
            self.last_rows[tb_idx] = None
        elif value == "None":
            self.none_num += 1
            self.last_row = None
            self.last_rows[tb_idx] = None
        else:
            v: Decimal = get_decimal(value, scale)
            self.last_row = v
            if self.first is None:
                self.first = v
                # Seed the extremes from the first real value. Comparing
                # against the initial 0 would report min == 0 for an
                # all-positive column and max == 0 for an all-negative one.
                self.max = v
                self.min = v
            if self.firsts[tb_idx] is None:
                self.firsts[tb_idx] = v
            self.last = v
            self.lasts[tb_idx] = v
            self.last_rows[tb_idx] = v
            self.sum += v
            if v > self.max:
                self.max = v
            if v < self.min:
                self.min = v
# Shared counter; each TaosShell instance takes one value from it to build a
# distinct result-file name.
atomic_counter = AtomicCounter(0)
class TaosShell:
    """Runs SQL through the ``taos`` CLI and parses the redirected result file."""

    def __init__(self):
        # A distinct counter value per instance keeps result files apart.
        self.counter_ = atomic_counter.fetch_add()
        self.queryResult = []
        self.tmp_file_path = "/tmp/taos_shell_result"

    def get_file_path(self):
        """Return the result-file path for this instance."""
        return f"{self.tmp_file_path}_{self.counter_}"

    def read_result(self):
        """Parse the result file into ``self.queryResult`` (column-major).

        The first line is skipped; every other line is split on "," and each
        field is stripped of whitespace and surrounding double quotes.
        """
        with open(self.get_file_path(), "r") as f:
            lines = f.readlines()
        for line in lines[1:]:
            vals: list = line.split(",")
            if len(self.queryResult) == 0:
                self.queryResult = [[] for _ in vals]
            for col, val in enumerate(vals):
                self.queryResult[col].append(val.strip().strip('"'))

    def query(self, sql: str):
        """Execute ``sql`` via ``taos -s`` and return the parsed columns.

        On any failure ``tdLog.exit`` is called with the command's stderr
        when available, otherwise with the exception text.
        """
        # Create the file if needed and empty it.
        with open(self.get_file_path(), "a+") as f:
            f.truncate(0)
        self.queryResult = []
        try:
            # NOTE(review): the SQL is interpolated into a shell command line;
            # statements containing '"', '$' or backticks will be mangled.
            command = f'taos -s "{sql} >> {self.get_file_path()}"'
            subprocess.run(
                command, shell=True, check=True, stderr=subprocess.PIPE, stdout=subprocess.PIPE
            )
            self.read_result()
        except Exception as e:
            # Only CalledProcessError carries ``stderr``; other failures
            # (e.g. while reading the result file) must not be masked by an
            # AttributeError raised inside this handler.
            stderr = getattr(e, "stderr", None)
            if isinstance(stderr, bytes):
                detail = stderr.decode('utf-8')
            else:
                detail = str(e)
            tdLog.exit(f"Command '{sql}' failed with error: {detail}")
        return self.queryResult
class DecimalColumnExpr:
    """An SQL expression template plus a Python callable that evaluates it.

    ``format`` is a ``str.format`` template; ``executor`` is called as
    ``executor(expr, params)`` to compute the expected value in Python.
    """

    def __init__(self, format: str, executor):
        self.format_: str = format
        self.executor_ = executor
        self.params_ = ()
        self.res_type_: DataType = None
        self.query_col: Column = None

    def __str__(self):
        template = f"({self.format_})"
        return template.format(*self.params_)

    def execute(self, params):
        """Evaluate the expression in Python for the given parameters."""
        return self.executor_(self, params)

    def get_query_col_val(self, tbname, i):
        """Return the queried column's typed value at row ``i``."""
        return self.query_col.get_val_for_execute(tbname, i)

    def get_val(self, tbname: str, idx: int):
        """Evaluate the expression using each parameter's ``get_val``."""
        params = tuple(p.get_val(tbname, idx) for p in self.params_)
        return self.execute(params)

    def convert_to_res_type(self, val: Decimal) -> Decimal:
        """Convert ``val`` for decimal or DOUBLE result types; otherwise None."""
        if self.res_type_.is_decimal_type():
            return get_decimal(val, self.res_type_.scale())
        if self.res_type_.type == TypeEnum.DOUBLE:
            return float(val)
        return None

    def get_input_types(self) -> List:
        """Hook for subclasses; the base implementation returns None."""
        return None

    def should_skip_for_decimal(self, cols: list)->bool:
        """Hook for subclasses; the base implementation never skips."""
        return False

    def check_query_results(self, query_col_res: List, tbname: str):
        """Hook for subclasses; the base implementation only measures the input."""
        query_len = len(query_col_res)
        return None

    def _bind_params(self, tbname: str, row: int):
        """Resolve column/expression parameters to their values at ``row``."""
        bound = ()
        for p in self.params_:
            if isinstance(p, Column) or isinstance(p, DecimalColumnExpr):
                p = p.get_val_for_execute(tbname, row)
            bound = bound + (p,)
        return bound

    def check_for_filtering(self, query_col_res: List, tbname: str):
        """Verify rows returned by a query that used this expression as filter.

        For every returned value the saved rows are scanned forward until the
        Python-side predicate is truthy; that row's column value must equal
        the returned value.
        """
        row = -1
        for v_from_query in query_col_res:
            row += 1
            params = self._bind_params(tbname, row)
            # Skip saved rows the filter would have rejected.
            while not self.execute(params):
                row += 1
                params = self._bind_params(tbname, row)
            dec_from_query = Decimal(v_from_query)
            dec_from_calc = self.get_query_col_val(tbname, row)
            if dec_from_query != dec_from_calc:
                tdLog.exit(f"filter with {self} failed, query got: {dec_from_query}, expect {dec_from_calc}, param: {params}")

    def check(self, query_col_res: List, tbname: str):
        """Compare each queried value with the value computed in Python."""
        for row, v_from_query in enumerate(query_col_res):
            params = self._bind_params(tbname, row)
            v_from_calc_in_py = self.execute(params)

            # NULL on either side must be NULL on both.
            if v_from_calc_in_py == 'NULL' or v_from_query == 'NULL':
                if v_from_calc_in_py != v_from_query:
                    tdLog.exit(f"query with expr: {self} calc in py got: {v_from_calc_in_py}, query got: {v_from_query}")
                continue

            if self.res_type_.type == TypeEnum.BOOL:
                query_res = bool(int(v_from_query))
                calc_res = bool(int(v_from_calc_in_py))
                failed = query_res != calc_res
            elif isinstance(v_from_calc_in_py, float):
                query_res = float(v_from_query)
                calc_res = float(v_from_calc_in_py)
                failed = not math.isclose(query_res, calc_res, abs_tol=1e-7)
            else:
                query_res = Decimal(v_from_query)
                calc_res = Decimal(v_from_calc_in_py)
                failed = query_res != calc_res

            if failed:
                tdLog.exit(
                    f"check decimal column failed for expr: {self}, input: {[t.__str__() for t in self.get_input_types()]}, res_type: {self.res_type_}, params: {params}, query: {v_from_query}, expect {calc_res}, but get {query_res}"
                )

    ## params_ has already been set when this is called
    def generate_res_type(self):
        """Hook for subclasses to derive ``res_type_`` from ``params_``."""
        return None

    def generate(self, format_params) -> str:
        """Bind ``format_params``, derive the result type, return the SQL text."""
        self.params_ = format_params
        self.generate_res_type()
        return self.__str__()
class TypeEnum:
    """Integer type codes plus lookup helpers indexed by those codes."""

    BOOL = 1
    TINYINT = 2
    SMALLINT = 3
    INT = 4
    BIGINT = 5
    FLOAT = 6
    DOUBLE = 7
    VARCHAR = 8
    TIMESTAMP = 9
    NCHAR = 10
    UTINYINT = 11
    USMALLINT = 12
    UINT = 13
    UBIGINT = 14
    JSON = 15
    VARBINARY = 16
    DECIMAL = 17
    BINARY = 8  # same code as VARCHAR
    GEOMETRY = 20
    DECIMAL64 = 21

    # Both tables are indexed by type code (0..21).
    _TYPE_PREC = (
        0, 1, 3, 5, 10, 19, 38, 38, 0, 19, 10,
        3, 5, 10, 20, 0, 0, 0, 0, 0, 0, 0,
    )
    _TYPE_STR = (
        "",
        "BOOL",
        "TINYINT",
        "SMALLINT",
        "INT",
        "BIGINT",
        "FLOAT",
        "DOUBLE",
        "VARCHAR",
        "TIMESTAMP",
        "NCHAR",
        "TINYINT UNSIGNED",
        "SMALLINT UNSIGNED",
        "INT UNSIGNED",
        "BIGINT UNSIGNED",
        "JSON",
        "VARBINARY",
        "DECIMAL",
        "",
        "",
        "GEOMETRY",
        "DECIMAL",
    )

    @staticmethod
    def get_type_prec(type: int):
        """Return the precision table entry for the given type code."""
        return TypeEnum._TYPE_PREC[type]

    @staticmethod
    def get_type_str(type: int):
        """Return the SQL type name for the given type code."""
        return TypeEnum._TYPE_STR[type]
class DataType:
    """A column data type: type code, byte/char length and type modifier."""

    def __init__(self, type: int, length: int = 0, type_mod: int = 0):
        self.type: int = type
        self.length = length
        self.type_mod = type_mod

    def __str__(self):
        if self.type_mod != 0:
            return f"{TypeEnum.get_type_str(self.type)}({self.prec()}, {self.scale()})"
        if self.length:
            return f"{TypeEnum.get_type_str(self.type)}({self.length})"
        return TypeEnum.get_type_str(self.type)

    def __eq__(self, other):
        # Comparing with a foreign object (None, str, ...) must yield
        # "not equal" rather than raise AttributeError.
        if not isinstance(other, DataType):
            return NotImplemented
        return self.type == other.type and self.length == other.length and self.type_mod == other.type_mod

    def __ne__(self, other):
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        return hash((self.type, self.length))

    def __repr__(self):
        return f"DataType({self.type}, {self.length}, {self.type_mod})"

    def is_decimal_type(self):
        """Return True for DECIMAL and DECIMAL64."""
        return self.type == TypeEnum.DECIMAL or self.type == TypeEnum.DECIMAL64

    def is_varchar_type(self):
        """Return True for the string-like types (incl. VARBINARY and JSON)."""
        return self.type == TypeEnum.VARCHAR or self.type == TypeEnum.NCHAR or self.type == TypeEnum.VARBINARY or self.type == TypeEnum.JSON or self.type == TypeEnum.BINARY

    def is_real_type(self):
        """Return True for FLOAT and DOUBLE."""
        return self.type == TypeEnum.FLOAT or self.type == TypeEnum.DOUBLE

    def prec(self):
        """Precision; 0 for non-decimal types."""
        return 0

    def scale(self):
        """Scale; 0 for non-decimal types."""
        return 0

    ## TODO generate NULL, None
    def generate_value(self, has_null = False) -> str:
        """Return a random literal of this type as SQL text.

        ``has_null`` is accepted for interface compatibility and is not used
        here. Raises Exception for type codes without a generator.
        """
        if self.type == TypeEnum.BOOL:
            return ['true', 'false'][secrets.randbelow(2)]
        if self.type == TypeEnum.TINYINT:
            return str(secrets.randbelow(256) - 128)
        if self.type == TypeEnum.SMALLINT:
            return str(secrets.randbelow(65536) - 32768)
        if self.type == TypeEnum.INT:
            return str(secrets.randbelow(4294967296) - 2147483648)
        if self.type == TypeEnum.BIGINT:
            return str(secrets.randbelow(9223372036854775808) - 4611686018427387904)
        if self.type == TypeEnum.FLOAT or self.type == TypeEnum.DOUBLE:
            return str(random.random())
        if (
            self.type == TypeEnum.VARCHAR
            or self.type == TypeEnum.NCHAR
            or self.type == TypeEnum.VARBINARY
        ):
            # A quoted float literal, truncated to the column length.
            return f"'{str(random.random())[0:self.length]}'"
        if self.type == TypeEnum.TIMESTAMP:
            return str(secrets.randbelow(9223372036854775808))
        if self.type == TypeEnum.UTINYINT:
            return str(secrets.randbelow(256))
        if self.type == TypeEnum.USMALLINT:
            return str(secrets.randbelow(65536))
        if self.type == TypeEnum.UINT:
            return str(secrets.randbelow(4294967296))
        if self.type == TypeEnum.UBIGINT:
            return str(secrets.randbelow(9223372036854775808))
        if self.type == TypeEnum.JSON:
            return f'{{"key": "{secrets.token_urlsafe(10)}"}}'
        if self.type == TypeEnum.GEOMETRY:
            return "'POINT(1.0 1.0)'"
        raise Exception(f"unsupport type {self.type}")

    def check(self, values, offset: int):
        """Base types perform no validation and always report success."""
        return True

    def get_typed_val_for_execute(self, val, const_col = False):
        """Convert a stored literal into the Python value used for evaluation.

        FLOAT goes through numpy.float32 (via its string form for constant
        columns); BOOL becomes 1/0; decimals use ``get_decimal``; quoted
        strings lose their single quotes.
        """
        if self.type == TypeEnum.DOUBLE:
            return float(val)
        elif self.type == TypeEnum.BOOL:
            if val == "true":
                return 1
            else:
                return 0
        elif self.type == TypeEnum.FLOAT:
            if const_col:
                val = float(str(numpy.float32(val)))
            else:
                val = float(numpy.float32(val))
        elif self.type == TypeEnum.DECIMAL or self.type == TypeEnum.DECIMAL64:
            return get_decimal(val, self.scale())
        elif isinstance(val, str):
            val = val.strip("'")
        return val

    def get_typed_val(self, val):
        """Convert FLOAT/DOUBLE literals to float; return others unchanged."""
        if self.type == TypeEnum.FLOAT:
            return float(str(numpy.float32(val)))
        elif self.type == TypeEnum.DOUBLE:
            return float(val)
        return val

    @staticmethod
    def get_decimal_types() -> list:
        """Return the two decimal type codes."""
        return [TypeEnum.DECIMAL64, TypeEnum.DECIMAL]

    @staticmethod
    def get_decimal_op_types()-> list:
        """Return the type codes listed as operands/targets for decimal ops."""
        return [
            TypeEnum.BOOL,
            TypeEnum.TINYINT,
            TypeEnum.SMALLINT,
            TypeEnum.INT,
            TypeEnum.BIGINT,
            TypeEnum.FLOAT,
            TypeEnum.DOUBLE,
            TypeEnum.VARCHAR,
            TypeEnum.NCHAR,
            TypeEnum.UTINYINT,
            TypeEnum.USMALLINT,
            TypeEnum.UINT,
            TypeEnum.UBIGINT,
            TypeEnum.DECIMAL,
            TypeEnum.DECIMAL64,
        ]
class DecimalType(DataType):
    """DECIMAL / DECIMAL64 type with precision, scale and a value generator."""

    DECIMAL_MAX_PRECISION = 38
    DECIMAL64_MAX_PRECISION = 18

    def __init__(self, type, precision: int, scale: int):
        # precision_/scale_ must exist before the type modifier is computed.
        self.precision_ = precision
        self.scale_ = scale
        byte_len = 8 if type == TypeEnum.DECIMAL64 else 16
        super().__init__(type, byte_len, self.get_decimal_type_mod())
        self.decimal_generator: DecimalStringRandomGenerator = DecimalStringRandomGenerator()
        self.generator_config: DecimalTypeGeneratorConfig = DecimalTypeGeneratorConfig()
        self.generator_config.prec = precision
        self.generator_config.scale = scale
        self.aggregator: DecimalColumnAggregator = DecimalColumnAggregator()
        self.values: List[str] = []

    def get_decimal_type_mod(self) -> int:
        """Encode precision and scale as ``precision * 100 + scale``."""
        return self.precision_ * 100 + self.scale()

    def set_prec(self, prec: int):
        self.precision_ = prec
        self.type_mod = self.get_decimal_type_mod()

    def set_scale(self, scale: int):
        self.scale_ = scale
        self.type_mod = self.get_decimal_type_mod()

    def prec(self):
        return self.precision_

    def scale(self):
        return self.scale_

    def __str__(self):
        return f"DECIMAL({self.precision_}, {self.scale()})"

    def __eq__(self, other: DataType):
        # Equality only looks at precision and scale. Foreign objects compare
        # as "not equal" instead of raising AttributeError.
        if not isinstance(other, DataType):
            return NotImplemented
        return self.precision_ == other.prec() and self.scale() == other.scale()

    def __ne__(self, other):
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        return hash((self.precision_, self.scale()))

    def __repr__(self):
        return f"DecimalType({self.precision_}, {self.scale()})"

    def generate_value(self, has_null = True) -> str:
        """Generate one literal and feed it to the aggregator.

        ``has_null`` temporarily overrides the generator's corner-case switch
        (corner cases include "NULL"); the previous setting is restored even
        if generation fails.
        """
        saved = self.generator_config.with_corner_case
        self.generator_config.with_corner_case = has_null
        try:
            val = self.decimal_generator.generate(self.generator_config)
            self.aggregator.add_value(val, self.scale()) ## convert to Decimal first
        finally:
            self.generator_config.with_corner_case = saved
        # self.values.append(val) ## save it into files maybe
        return val

    def get_typed_val(self, val):
        """Return None for "NULL", else the literal quantized to this scale."""
        if val == "NULL":
            return None
        return get_decimal(val, self.scale())

    def get_typed_val_for_execute(self, val, const_col = False):
        return self.get_typed_val(val)

    @staticmethod
    def default_compression() -> str:
        return "zstd"

    @staticmethod
    def default_encode() -> str:
        return "disabled"

    def check(self, values, offset: int):
        """Compare queried values with ``self.values[offset:]``.

        Returns True when every pair matches and False after reporting a
        mismatch through ``tdLog.exit``. Nothing in this class appends to
        ``self.values``, so unless it is filled elsewhere the comparison is
        empty and the result is True.
        """
        val_from_query = values
        val_insert = self.values[offset:]
        for v_from_query, v_from_insert in zip(val_from_query, val_insert):
            if v_from_insert == "NULL":
                if v_from_query.strip() != "NULL":
                    tdLog.debug(
                        f"val_insert: {val_insert} val_from_query: {val_from_query}"
                    )
                    tdLog.exit(f"insert NULL, query not NULL: {v_from_query}")
                    return False
                continue
            try:
                dec_query: Decimal = Decimal(v_from_query)
                dec_insert: Decimal = Decimal(v_from_insert)
                dec_insert = get_decimal(dec_insert, self.scale())
            except Exception as e:
                tdLog.exit(f"failed to convert {v_from_query} or {v_from_insert} to decimal, {e}")
                return False
            if dec_query != dec_insert:
                tdLog.exit(
                    f"check decimal column failed for insert: {v_from_insert}, query: {v_from_query}, expect {dec_insert}, but get {dec_query}"
                )
                return False
            tdLog.debug(
                f"check decimal succ, insert:{v_from_insert} query:{v_from_query}, py dec: {dec_insert}"
            )
        # Match the base-class contract: success is reported as True.
        return True

    @staticmethod
    def decimal_type_from_other_type(other: DataType):
        """Build a DecimalType from another DataType's fields.

        NOTE(review): ``other.length`` is passed as precision and
        ``other.type_mod`` as scale -- this looks unintended; confirm the
        desired mapping before relying on this helper.
        """
        return DecimalType(other.type, other.length, other.type_mod)
class Column:
    """A table column: its type, its name and the literals generated for it.

    ``saved_vals`` maps a table name to the list of generated literals; the
    empty-string key marks a constant column, which keeps only its latest
    value.
    """

    def __init__(self, type: "DataType"):
        self.type_: DataType = type
        self.name_: str = ""
        self.saved_vals: dict = {}  # table name -> list of literals

    def is_constant_col(self):
        """Return True when values were saved under the '' table name."""
        return '' in self.saved_vals.keys()

    def get_typed_val(self, val):
        return self.type_.get_typed_val(val)

    def get_typed_val_for_execute(self, val, const_col = False):
        return self.type_.get_typed_val_for_execute(val, const_col)

    def get_constant_val(self):
        return self.get_typed_val(self.saved_vals[''][0])

    def get_constant_val_for_execute(self):
        return self.get_typed_val_for_execute(self.saved_vals[''][0], const_col=True)

    def __str__(self):
        if self.is_constant_col():
            return str(self.get_constant_val())
        return self.name_

    def get_val_for_execute(self, tbname: str, idx: int):
        """Return the typed value at ``idx``.

        With values saved for several tables, ``idx`` addresses the
        concatenation of all tables in insertion order; otherwise it indexes
        ``tbname`` directly.
        """
        if self.is_constant_col():
            return self.get_constant_val_for_execute()
        if len(self.saved_vals) > 1:
            for vals in self.saved_vals.values():
                size = len(vals)
                if idx < size:
                    return self.get_typed_val_for_execute(vals[idx])
                idx -= size
        return self.get_typed_val_for_execute(self.saved_vals[tbname][idx])

    def get_cardinality(self, tbname):
        """Return the number of saved values relevant for ``tbname``."""
        if self.is_constant_col():
            return 1
        elif len(self.saved_vals) > 1:
            return len(self.saved_vals['t0'])
        else:
            return len(self.saved_vals[tbname])

    @staticmethod
    def comp_key(key1, key2):
        """Three-way comparator that orders None before every value.

        Two Nones compare equal so the ordering stays consistent.
        """
        if key1 is None and key2 is None:
            return 0
        if key1 is None:
            return -1
        if key2 is None:
            return 1
        return key1 - key2

    def get_ordered_result(self, tbname: str, asc: bool) -> list:
        """Return the saved values as Decimals, sorted; None sorts lowest.

        Uses ``tbname``'s values when present, otherwise all tables' values.
        """
        if tbname in self.saved_vals:
            raw = self.saved_vals[tbname]
        else:
            raw = [v for vals in self.saved_vals.values() for v in vals]
        return sorted(
            [get_decimal(val, self.type_.scale()) for val in raw],
            reverse=not asc,
            key=cmp_to_key(Column.comp_key)
        )

    def get_group_num(self, tbname, ignore_null=False) -> int:
        """Return the number of distinct values (optionally ignoring NULL)."""
        if tbname in self.saved_vals:
            raw = self.saved_vals[tbname]
        else:
            raw = [v for vals in self.saved_vals.values() for v in vals]
        distinct = {get_decimal(val, self.type_.scale()) for val in raw}
        if ignore_null:
            # discard(): a column without any NULL must not raise KeyError.
            distinct.discard(None)
        return len(distinct)

    ## tbName: for normal table, pass the tbname, for child table, pass the child table name
    def generate_value(self, tbName: str = '', save: bool = True, has_null = True):
        """Generate one literal and, when ``save`` is set, remember it."""
        val = self.type_.generate_value(has_null=has_null)
        if save:
            if tbName not in self.saved_vals:
                self.saved_vals[tbName] = []
            ## for constant columns, always replace the last val
            if self.is_constant_col():
                self.saved_vals[tbName] = [val]
            else:
                self.saved_vals[tbName].append(val)
        return val

    def get_type_str(self) -> str:
        return str(self.type_)

    def set_name(self, name: str):
        self.name_ = name

    def check(self, values, offset: int):
        return self.type_.check(values, offset)

    def construct_type_value(self, val: str):
        """Wrap ``val`` in single quotes for string-like types."""
        if (
            self.type_.type == TypeEnum.BINARY
            or self.type_.type == TypeEnum.VARCHAR
            or self.type_.type == TypeEnum.NCHAR
            or self.type_.type == TypeEnum.VARBINARY
            or self.type_.type == TypeEnum.JSON
        ):
            return f"'{val}'"
        else:
            return val

    @staticmethod
    def get_decimal_unsupported_types() -> list:
        return [
            TypeEnum.JSON,
            TypeEnum.GEOMETRY,
            TypeEnum.VARBINARY,
        ]

    @staticmethod
    def get_decimal_oper_const_cols() -> list:
        """Return columns of every type usable as a constant operand."""
        types_unable_to_be_const = [
            TypeEnum.TINYINT,
            TypeEnum.SMALLINT,
            TypeEnum.INT,
            TypeEnum.UINT,
            TypeEnum.USMALLINT,
            TypeEnum.UTINYINT,
            TypeEnum.UBIGINT,
        ]
        return Column.get_all_type_columns(
            Column.get_decimal_unsupported_types()
            + Column.get_decimal_types()
            + types_unable_to_be_const
        )

    @staticmethod
    def get_decimal_types() -> List:
        return [TypeEnum.DECIMAL, TypeEnum.DECIMAL64]

    @staticmethod
    def get_all_type_columns(types_to_exclude: "List[TypeEnum]" = None) -> List:
        """Return one fresh Column per known type, minus the excluded codes."""
        excluded = [] if types_to_exclude is None else types_to_exclude
        all_types = [
            Column(DataType(TypeEnum.BOOL)),
            Column(DataType(TypeEnum.TINYINT)),
            Column(DataType(TypeEnum.SMALLINT)),
            Column(DataType(TypeEnum.INT)),
            Column(DataType(TypeEnum.BIGINT)),
            Column(DataType(TypeEnum.FLOAT)),
            Column(DataType(TypeEnum.DOUBLE)),
            Column(DataType(TypeEnum.VARCHAR, 255)),
            Column(DataType(TypeEnum.TIMESTAMP)),
            Column(DataType(TypeEnum.NCHAR, 255)),
            Column(DataType(TypeEnum.UTINYINT)),
            Column(DataType(TypeEnum.USMALLINT)),
            Column(DataType(TypeEnum.UINT)),
            Column(DataType(TypeEnum.UBIGINT)),
            Column(DataType(TypeEnum.JSON)),
            Column(DataType(TypeEnum.VARBINARY, 255)),
            Column(DecimalType(TypeEnum.DECIMAL, 38, 10)),
            Column(DataType(TypeEnum.BINARY, 255)),
            Column(DataType(TypeEnum.GEOMETRY, 10240)),
            Column(DecimalType(TypeEnum.DECIMAL64, 18, 4)),
        ]
        return [c for c in all_types if c.type_.type not in excluded]
class DecimalColumnTableCreater:
    """Issues CREATE statements for a (super) table and its child tables."""

    def __init__(
        self,
        conn,
        dbName: str,
        tbName: str,
        columns: List[Column],
        tags_cols: List[Column] = [],
        col_prefix: str = "c",
        tag_prefix: str = "t",
    ):
        self.conn = conn
        self.dbName = dbName
        self.tbName = tbName
        self.columns: List[Column] = columns
        self.tags_cols = tags_cols
        self.col_prefix = col_prefix
        self.tag_prefix = tag_prefix

    def create(self):
        """Create the table (a stable when tag columns exist) and name its columns.

        Columns are named ``<col_prefix>1..n`` and tags ``<tag_prefix>1..n``;
        the names are also written back onto the Column objects.
        """
        table = "stable" if len(self.tags_cols) > 0 else "table"
        parts = [f"create {table} {self.dbName}.{self.tbName} (ts timestamp"]
        for pos, column in enumerate(self.columns, start=1):
            col_name = f"{self.col_prefix}{pos}"
            parts.append(f", {col_name} {column.get_type_str()}")
            column.set_name(col_name)
        if self.tags_cols:
            parts.append(") tags(")
            tag_defs = []
            for pos, tag in enumerate(self.tags_cols, start=1):
                tagname = f"{self.tag_prefix}{pos}"
                tag_defs.append(f"{tagname} {tag.get_type_str()}")
                tag.set_name(tagname)
            parts.append(", ".join(tag_defs))
        parts.append(")")
        self.conn.execute("".join(parts), queryTimes=1)

    def create_child_table(
        self,
        ctbPrefix: str,
        ctbNum: int,
        tag_cols: List[Column],
        tag_values: List[str],
    ):
        """Create ``ctbNum`` child tables named ``<ctbPrefix>0..n-1``."""
        for n in range(ctbNum):
            child = f"{ctbPrefix}{n}"
            tag_sql = ", ".join(
                f"{tag.construct_type_value(tag_values[j])}"
                for j, tag in enumerate(tag_cols)
            )
            sql = (
                f"create table {self.dbName}.{child} using {self.dbName}.{self.tbName} tags("
                + tag_sql
                + ")"
            )
            self.conn.execute(sql, queryTimes=1)
class TableInserter:
    """Generates rows for a table and writes them in batched INSERT statements."""

    def __init__(
        self,
        conn,
        dbName: str,
        tbName: str,
        columns: List[Column],
        tags_cols: List[Column] = [],
    ):
        self.conn: TDSql = conn
        self.dbName = dbName
        self.tbName = tbName
        self.tag_cols = tags_cols
        self.columns = columns

    def insert(self, rows: int, start_ts: int, step: int, flush_database: bool = False):
        """Insert ``rows`` generated rows with timestamps ``start_ts + i * step``.

        A batch is sent once the statement text exceeds 1000 characters. With
        ``flush_database`` set, a flush is issued before batches that end on
        a row index divisible by 5000 and before the final batch. A
        ``select last(...)`` runs through the taos shell when more than a
        second has passed since the previous one, and after the final batch.
        """
        target = f"{self.dbName}.{self.tbName}"
        pre_insert = f"insert into {target} values"
        flush_sql = f"flush database {self.dbName}"
        last_sql = f"select last(c1), last(c2) from {target}"

        sql = pre_insert
        last_probe = datetime.now()
        for i in range(rows):
            fields = "".join(
                f", {column.generate_value(self.tbName)}" for column in self.columns
            )
            sql += f"({start_ts + i * step}{fields})"
            if i != rows - 1:
                sql += ", "
            local_flush_database = i % 5000 == 0
            if len(sql) <= 1000:
                continue
            if flush_database and local_flush_database:
                self.conn.execute(flush_sql, queryTimes=1)
            self.conn.execute(sql, queryTimes=1)
            now = datetime.now()
            if (now - last_probe).seconds > 1:
                TaosShell().query(last_sql)
                last_probe = now
            sql = pre_insert

        # Send whatever is left over from the last partial batch.
        if len(sql) > len(pre_insert):
            if flush_database:
                self.conn.execute(flush_sql, queryTimes=1)
            self.conn.execute(sql, queryTimes=1)
            TaosShell().query(last_sql)
class DecimalCastTypeGenerator:
    """Picks random target types for casting from a given input type."""

    def __init__(self, input_type: DataType):
        self.input_type_: DataType = input_type

    def get_possible_output_types(self) -> List[int]:
        """Decimal inputs may target any op type; other inputs only decimals."""
        if self.input_type_.is_decimal_type():
            return DataType.get_decimal_op_types()
        return DataType.get_decimal_types()

    def do_generate_type(self, dt: int) ->DataType:
        """Build a concrete type for type code ``dt`` with random parameters."""
        if dt == TypeEnum.DECIMAL:
            prec = random.randint(1, DecimalType.DECIMAL_MAX_PRECISION)
            scale = random.randint(0, prec)
            return DecimalType(dt, prec, scale)
        if dt == TypeEnum.DECIMAL64:
            prec = random.randint(1, DecimalType.DECIMAL64_MAX_PRECISION)
            scale = random.randint(0, prec)
            return DecimalType(dt, prec, scale)
        if dt == TypeEnum.BINARY or dt == TypeEnum.VARCHAR:
            return DataType(dt, random.randint(16, 255), 0)
        return DataType(dt, 0, 0)

    def generate(self, num: int) -> List[DataType]:
        """Draw ``num`` random types and return them de-duplicated."""
        drawn = [
            self.do_generate_type(random.choice(self.get_possible_output_types()))
            for _ in range(num)
        ]
        return list(set(drawn))
class TableDataValidator:
    """Checks the decimal columns of a table against the generated values."""

    def __init__(self, columns: List[Column], tbName: str, dbName: str, tbIdx: int = 0):
        self.columns = columns
        self.tbName = tbName
        self.dbName = dbName
        self.tbIdx = tbIdx

    def validate(self):
        """Query the whole table and check every decimal column.

        Does nothing unless ``decimal_insert_validator_test`` is enabled.
        Result column 0 is skipped, so column ``k`` of ``self.columns`` is
        compared with result column ``k + 1``.
        """
        if not decimal_insert_validator_test:
            return
        res = TaosShell().query(f"select * from {self.dbName}.{self.tbName}")
        row_num = len(res[1])
        decimal_codes = (TypeEnum.DECIMAL, TypeEnum.DECIMAL64)
        for colIdx, col in enumerate(self.columns, start=1):
            if col.type_.type in decimal_codes:
                col.check(res[colIdx], row_num * self.tbIdx)
class DecimalFunction(DecimalColumnExpr):
    """Base class for SQL functions applied to a decimal column."""

    def __init__(self, format, executor, name: str):
        super().__init__(format, executor)
        self.func_name_ = name

    def is_agg_func(self, op: str) ->bool:
        return False

    def get_func_res(self):
        """Expected result of the function; None in the base class."""
        return None

    def get_input_types(self):
        return [self.query_col]

    @staticmethod
    def get_decimal_agg_funcs() -> List:
        """Return fresh instances of the aggregate functions under test."""
        return [
            DecimalMinFunction(),
            DecimalMaxFunction(),
            DecimalSumFunction(),
            DecimalAvgFunction(),
            DecimalCountFunction(),
            DecimalLastRowFunction(),
            DecimalLastFunction(),
            DecimalFirstFunction(),
        ]

    def check_results(self, query_col_res: List) -> bool:
        """Compare query output with the expected result; base always fails."""
        return False

    def check_for_agg_func(self, query_col_res: List, tbname: str, func, is_stb: bool = False):
        """Feed every saved value to the executor, then verify the query output."""
        col_expr = self.query_col
        # Pre-bind so the success log below works for an empty column too.
        col_val = None
        for i in range(col_expr.get_cardinality(tbname)):
            col_val = col_expr.get_val_for_execute(tbname, i)
            self.execute((col_val,))
        if not self.check_results(query_col_res):
            tdLog.exit(f"check failed for {self}, query got: {query_col_res}, expect {self.get_func_res()}")
        else:
            tdLog.info(f"check expr: {func} with val: {col_val} got result: {query_col_res}, expect: {self.get_func_res()}")
class DecimalCastFunction(DecimalFunction):
    """``cast(<col> as <type>)`` with a Python model of the conversion."""

    def __init__(self):
        super().__init__("cast({0} as {1})", DecimalCastFunction.execute_cast, "cast")

    def should_skip_for_decimal(self, cols: list)->bool:
        return False

    def check_results(self, query_col_res: List) -> bool:
        return False

    def generate_res_type(self)->DataType:
        """params_ is (source column, target type)."""
        self.query_col, self.res_type_ = self.params_[0], self.params_[1]
        return self.res_type_

    def check(self, res: list, tbname: str):
        """Compare queried cast results with the Python-side cast.

        If any input overflows the target decimal type, the expected result
        set is empty and the query is expected to have returned nothing.
        """
        calc_res = []
        params = []
        for i in range(self.query_col.get_cardinality(tbname)):
            val = self.query_col.get_val_for_execute(tbname, i)
            params.append(val)
            try:
                calc_val = self.execute(val)
            except OverflowError as e:
                tdLog.info(f"execute {self} overflow for param: {val}")
                calc_res = []
                break
            calc_res.append(calc_val)

        if len(calc_res) != len(res):
            tdLog.exit(f"check result for {self} failed len got: {len(res)}, expect: {len(calc_res)}")
        if not calc_res:
            return True

        for v, calc_v, p in zip(res, calc_res, params):
            # The queried text goes through the same cast model first.
            query_v = self.execute_cast(v)
            if isinstance(calc_v, float):
                matched = math.isclose(query_v, calc_v, rel_tol=1e-7)
            elif isinstance(calc_v, numpy.float32):
                matched = math.isclose(query_v, calc_v, abs_tol=1e-6, rel_tol=1e-6)
            elif isinstance(p, (float, str)):
                matched = math.isclose(query_v, calc_v, rel_tol=1e-7)
            else:
                matched = query_v == calc_v
            if not matched:
                tdLog.exit(f"check result for {self} failed with param: {p} query got: {v}, expect: {calc_v}")
        return True

    def execute_cast(self, val):
        """Cast ``val`` to ``self.res_type_`` the way the test expects.

        Integer targets round half-up and are masked to the type's bit
        width. Decimal targets raise OverflowError when the value exceeds
        the largest representable number. None / 'NULL' yield None.
        """
        if val is None or val == 'NULL':
            return None

        res_type = self.res_type_
        kind = res_type.type

        if kind == TypeEnum.BOOL:
            return Decimal(val) != 0

        int_masks = {
            TypeEnum.TINYINT: 0xFF,
            TypeEnum.UTINYINT: 0xFF,
            TypeEnum.SMALLINT: 0xFFFF,
            TypeEnum.USMALLINT: 0xFFFF,
            TypeEnum.INT: 0xFFFFFFFF,
            TypeEnum.UINT: 0xFFFFFFFF,
            TypeEnum.BIGINT: 0xFFFFFFFFFFFFFFFF,
            TypeEnum.TIMESTAMP: 0xFFFFFFFFFFFFFFFF,
            TypeEnum.UBIGINT: 0xFFFFFFFFFFFFFFFF,
        }
        if kind in int_masks:
            rounded = Decimal(val).quantize(Decimal("1"), ROUND_HALF_UP)
            return int(rounded) & int_masks[kind]

        if kind == TypeEnum.FLOAT:
            return numpy.float32(val)
        if kind == TypeEnum.DOUBLE:
            return float(val)

        if kind == TypeEnum.VARCHAR or kind == TypeEnum.NCHAR:
            if Decimal(val) == 0:
                return "0"
            return str(val)[0:res_type.length]

        if res_type.is_decimal_type():
            scale = res_type.scale()
            # Largest value of the target type: all nines.
            limit: Decimal = Decimal(
                "9" * (res_type.prec() - scale)
                + "."
                + "9" * scale
            )
            if limit < get_decimal(val, scale):
                raise OverflowError()
            try:
                return get_decimal(val, scale)
            except Exception as e:
                tdLog.exit(f"failed to cast {val} to {self.res_type_}, {e}")
            return None

        raise Exception(f"cast unsupported type {self.res_type_.type}")
class DecimalAggFunction(DecimalFunction):
    """Base class for aggregate functions over a decimal column."""

    def __init__(self, format, executor, name: str):
        super().__init__(format, executor, name)

    def is_agg_func(self, op: str)-> bool:
        return True

    def should_skip_for_decimal(self, cols: list)-> bool:
        """Skip unless the first column is a decimal column."""
        first_col: Column = cols[0]
        return not first_col.type_.is_decimal_type()

    def check_results(self, query_col_res):
        """Accept empty output; otherwise compare the first value as Decimal."""
        if len(query_col_res) != 0:
            return self.get_func_res() == Decimal(query_col_res[0])
        tdLog.info(f"query got no output: {self}, py calc: {self.get_func_res()}")
        return True
class DecimalFirstLastFunction(DecimalAggFunction):
    """Shared logic for first()/last()/last_row() style functions."""

    def __init__(self, format: str, func, name):
        super().__init__(format, func, name)

    def generate_res_type(self):
        # The result has the same type as the queried column.
        self.res_type_ = self.query_col.type_

    def check_results(self, query_col_res):
        """Compare the first queried value with the expected result.

        For a super table the expected result is a per-child-table list and
        the queried value only has to be one of its entries.
        """
        if len(query_col_res) == 0:
            tdLog.exit(f"query got no output: {self}, py calc: {self.get_func_res()}")
            return None
        decimal_type: DecimalType = self.query_col.type_
        got = get_decimal(query_col_res[0], decimal_type.scale())
        if decimal_type.aggregator.is_stb():
            return got in self.get_func_res()
        return self.get_func_res() == got
class DecimalLastRowFunction(DecimalFirstLastFunction):
    """``last_row(<col>)``; expected value comes from the column aggregator."""

    def __init__(self):
        super().__init__("last_row({0})", DecimalLastRowFunction.execute_last_row, "last_row")

    def get_func_res(self):
        column_type: DecimalType = self.query_col.type_
        return column_type.aggregator.get_last_row()

    def execute_last_row(self, params):
        # Nothing to compute per row: the aggregator already tracked it.
        return None
class DecimalCacheLastRowFunction(DecimalAggFunction):
    """``_cache_last_row(<col>)``; expected result and executor are fixed at 1."""

    def __init__(self):
        super().__init__("_cache_last_row({0})", DecimalCacheLastRowFunction.execute_cache_last_row, "_cache_last_row")

    def get_func_res(self):
        return 1

    def generate_res_type(self):
        # The result has the same type as the queried column.
        self.res_type_ = self.query_col.type_

    def execute_cache_last_row(self, params):
        return 1
class DecimalCacheLastFunction(DecimalAggFunction):
    """Placeholder subclass: adds nothing to DecimalAggFunction yet."""
class DecimalFirstFunction(DecimalFirstLastFunction):
    """``first()`` on a decimal column."""

    def __init__(self):
        super().__init__(
            "first({0})",
            DecimalFirstFunction.execute_first,
            "first",
        )

    def get_func_res(self):
        """Return the first value tracked by the column's aggregator."""
        col_type: DecimalType = self.query_col.type_
        return col_type.aggregator.get_first()

    def execute_first(self, params):
        # Nothing is computed per row; the expected value comes from get_func_res().
        return None
class DecimalLastFunction(DecimalFirstLastFunction):
    """``last()`` on a decimal column."""

    def __init__(self):
        super().__init__(
            "last({0})",
            DecimalLastFunction.execute_last,
            "last",
        )
        self.res_ = None

    def get_func_res(self):
        """Return the last value tracked by the column's aggregator."""
        col_type: DecimalType = self.query_col.type_
        return col_type.aggregator.get_last()

    def execute_last(self, params):
        # Nothing is computed per row; the expected value comes from get_func_res().
        return None
class DecimalHyperloglogFunction(DecimalAggFunction):
    """Placeholder subclass: adds nothing to DecimalAggFunction yet."""
class DecimalSampleFunction(DecimalAggFunction):
    """Placeholder subclass: adds nothing to DecimalAggFunction yet."""
class DecimalTailFunction(DecimalAggFunction):
    """Placeholder subclass: adds nothing to DecimalAggFunction yet."""
class DecimalUniqueFunction(DecimalAggFunction):
    """Placeholder subclass: adds nothing to DecimalAggFunction yet."""
class DecimalModeFunction(DecimalAggFunction):
    """Placeholder subclass: adds nothing to DecimalAggFunction yet."""
class DecimalCountFunction(DecimalAggFunction):
    """``count()`` on a decimal column."""

    def __init__(self):
        super().__init__(
            "count({0})",
            DecimalCountFunction.execute_count,
            "count",
        )

    def generate_res_type(self):
        # count() is typed as an 8-byte BIGINT.
        self.res_type_ = DataType(TypeEnum.BIGINT, 8, 0)

    def get_func_res(self):
        """Return the number of non-NULL values recorded by the aggregator."""
        col_type: DecimalType = self.query_col.type_
        total = col_type.aggregator.count
        nulls = col_type.aggregator.null_num
        return total - nulls

    def execute_count(self, params):
        return 1
class DecimalMinFunction(DecimalAggFunction):
    """``min()`` on a decimal column."""

    def __init__(self):
        super().__init__(
            "min({0})",
            DecimalMinFunction.execute_min,
            "min",
        )
        self.min_: Decimal = None

    def generate_res_type(self) -> DataType:
        # min() keeps the queried column's type; nothing is returned.
        self.res_type_ = self.query_col.type_

    def get_func_res(self) -> Decimal:
        """Return the minimum tracked by the column's aggregator."""
        col_type: DecimalType = self.query_col.type_
        return col_type.aggregator.min

    def execute_min(self, params):
        # Nothing is computed per row; the expected value comes from get_func_res().
        return None
class DecimalMaxFunction(DecimalAggFunction):
    """``max()`` on a decimal column."""

    def __init__(self):
        super().__init__(
            "max({0})",
            DecimalMaxFunction.execute_max,
            "max",
        )
        self.max_: Decimal = None

    def generate_res_type(self) -> DataType:
        # max() keeps the queried column's type; nothing is returned.
        self.res_type_ = self.query_col.type_

    def get_func_res(self) -> Decimal:
        """Return the maximum tracked by the column's aggregator."""
        col_type: DecimalType = self.query_col.type_
        return col_type.aggregator.max

    def execute_max(self, params):
        # Nothing is computed per row; the expected value comes from get_func_res().
        return None
class DecimalSumFunction(DecimalAggFunction):
    """``sum()`` on a decimal column."""

    def __init__(self):
        super().__init__("sum({0})", DecimalSumFunction.execute_sum, "sum")
        self.sum_: Decimal = None

    def get_func_res(self) -> Decimal:
        """Return the running sum tracked by the column's aggregator."""
        decimal_type: DecimalType = self.query_col.type_
        return decimal_type.aggregator.sum

    def generate_res_type(self) -> DataType:
        """Set ``res_type_`` to the column type widened to the maximum precision.

        The widening is applied to a shallow copy: the original code aliased
        the column's own type object and called ``set_prec`` on it, which
        silently changed the precision of the table column for every later
        check that reads ``col.type_``.
        """
        import copy

        widened = copy.copy(self.query_col.type_)
        widened.set_prec(DecimalType.DECIMAL_MAX_PRECISION)
        self.res_type_ = widened

    def execute_sum(self, params):
        # Nothing is computed per row; the expected value comes from get_func_res().
        pass
class DecimalAvgFunction(DecimalAggFunction):
    """``avg()`` on a decimal column, modelled as sum / count."""

    def __init__(self):
        super().__init__("avg({0})", DecimalAvgFunction.execute_avg, "avg")

    def get_func_res(self) -> Decimal:
        """Return sum / non-NULL count, quantized to the result type's scale."""
        aggregator = self.query_col.type_.aggregator
        non_null_rows = aggregator.count - aggregator.null_num
        return get_decimal(aggregator.sum / non_null_rows, self.res_type_.scale())

    def generate_res_type(self) -> DataType:
        """Derive the result type of ``sum_type / BIGINT``.

        ``sum_type`` is a shallow copy of the column type widened to the
        maximum precision. The original code widened the column's own type
        object in place, which leaked the new precision into every later use
        of the column.
        """
        import copy

        sum_type = copy.copy(self.query_col.type_)
        sum_type.set_prec(DecimalType.DECIMAL_MAX_PRECISION)
        count_type = DataType(TypeEnum.BIGINT, 8, 0)
        self.res_type_ = DecimalBinaryOperator.calc_decimal_prec_scale(sum_type, count_type, "/")

    def execute_avg(self, params):
        # Nothing is computed per row; the expected value comes from get_func_res().
        pass
class DecimalBinaryOperator(DecimalColumnExpr):
    """Binary arithmetic or comparison expression with a decimal operand.

    ``op_`` is the operator symbol. The executor passed to the constructor is
    one of the ``execute_*`` methods below; it computes in Python the value
    expected for one pair of operands.
    """

    def __init__(self, format, executor, op: str):
        super().__init__(format, executor)
        self.op_ = op
        # Operand types; filled in by generate_res_type().
        self.left_type_: DataType = None
        self.right_type_: DataType = None

    def __str__(self):
        return super().__str__()

    def generate(self, format_params) -> str:
        return super().generate(format_params)

    def should_skip_for_decimal(self, cols: list):
        """Return True when this operator should not be tested on ``cols``.

        A pair without any decimal operand is skipped. ``%`` is additionally
        skipped when either operand is a real-typed or varchar column.
        """
        left_col, right_col = cols[0], cols[1]
        if not left_col.type_.is_decimal_type() and not right_col.type_.is_decimal_type():
            return True
        if self.op_ != "%":
            return False
        ## TODO wjm why skip decimal % float/double? it's wrong now.
        return any(
            col.type_.is_real_type() or col.type_.is_varchar_type()
            for col in (left_col, right_col)
        )

    @staticmethod
    def check_null(params):
        """Return True when either operand is None."""
        return params[0] is None or params[1] is None

    @staticmethod
    def is_compare_op(op: str) -> bool:
        return op in ["==", "!=", ">", "<", ">=", "<="]

    @staticmethod
    def calc_decimal_prec_scale(left: DataType, right: DataType, op: str) -> DecimalType:
        """Compute the DECIMAL result type of ``left op right``.

        Non-decimal operands contribute ``TypeEnum.get_type_prec`` as their
        precision and a scale of 0. When the computed precision exceeds 38 it
        is clamped to 38 and the scale is reduced by the overflow, but never
        below ``min(6, scale)``.

        Raises:
            Exception: if ``op`` is not one of ``+ - * / %``.
        """

        def prec_and_scale(data_type):
            if data_type.is_decimal_type():
                return data_type.prec(), data_type.scale()
            return TypeEnum.get_type_prec(data_type.type), 0

        left_prec, left_scale = prec_and_scale(left)
        right_prec, right_scale = prec_and_scale(right)

        if op in ('+', '-'):
            out_scale = max(left_scale, right_scale)
            out_prec = max(left_prec - left_scale, right_prec - right_scale) + out_scale + 1
        elif op == '*':
            out_scale = left_scale + right_scale
            out_prec = left_prec + right_prec + 1
        elif op == '/':
            out_scale = max(left_scale + right_prec + 1, 6)
            out_prec = left_prec - left_scale + right_scale + out_scale
        elif op == '%':
            out_scale = max(left_scale, right_scale)
            out_prec = min(left_prec - left_scale, right_prec - right_scale) + out_scale
        else:
            raise Exception(f"unknown op for binary operators: {op}")

        if out_prec > 38:
            min_scale = min(6, out_scale)
            delta = out_prec - 38
            out_prec = 38
            out_scale = max(min_scale, out_scale - delta)
        return DecimalType(TypeEnum.DECIMAL, out_prec, out_scale)

    def generate_res_type(self):
        """Record the operand types and derive ``res_type_``.

        Comparisons yield BOOL. Arithmetic yields DOUBLE as soon as one
        operand is varchar/binary/nchar/float/double, otherwise the DECIMAL
        type computed by ``calc_decimal_prec_scale``.
        """
        left_type = self.params_[0].type_
        right_type = self.params_[1].type_
        self.left_type_ = left_type
        self.right_type_ = right_type
        if DecimalBinaryOperator.is_compare_op(self.op_):
            self.res_type_ = DataType(TypeEnum.BOOL)
            return
        ret_double_types = [TypeEnum.VARCHAR, TypeEnum.BINARY, TypeEnum.DOUBLE, TypeEnum.FLOAT, TypeEnum.NCHAR]
        if left_type.type in ret_double_types or right_type.type in ret_double_types:
            self.res_type_ = DataType(TypeEnum.DOUBLE)
        else:
            self.res_type_ = DecimalBinaryOperator.calc_decimal_prec_scale(left_type, right_type, self.op_)

    def get_input_types(self) -> list:
        return [self.left_type_, self.right_type_]

    def get_convert_type(self, params):
        """Return ``((left, right), use_float)``.

        ``use_float`` is True when either operand is a ``float`` or a ``str``.
        """
        left, right = params[0], params[1]
        use_float = isinstance(left, (float, str)) or isinstance(right, (float, str))
        return (left, right), use_float

    def _arithmetic(self, params, apply):
        """Shared body of +, - and *: NULL propagation, then float or Decimal math."""
        if DecimalBinaryOperator.check_null(params):
            return 'NULL'
        (left, right), _ = self.get_convert_type(params)
        if self.res_type_.type == TypeEnum.DOUBLE:
            return apply(float(left), float(right))
        return self.convert_to_res_type(apply(Decimal(left), Decimal(right)))

    def _compare(self, params, apply):
        """Shared body of the comparison executors; a NULL operand compares False."""
        if DecimalBinaryOperator.check_null(params):
            return False
        (left, right), use_float = self.get_convert_type(params)
        if use_float:
            return apply(float(left), float(right))
        return apply(Decimal(left), Decimal(right))

    def execute_plus(self, params):
        return self._arithmetic(params, lambda a, b: a + b)

    def execute_minus(self, params):
        return self._arithmetic(params, lambda a, b: a - b)

    def execute_mul(self, params):
        return self._arithmetic(params, lambda a, b: a * b)

    def execute_div(self, params):
        """Divide; in DOUBLE mode a zero divisor yields 'NULL'."""
        if DecimalBinaryOperator.check_null(params):
            return 'NULL'
        (left, right), _ = self.get_convert_type(params)
        if self.res_type_.type == TypeEnum.DOUBLE:
            # Test the converted divisor: a string operand such as "0" never
            # equals the int 0, so checking the raw value let it through to a
            # ZeroDivisionError.
            divisor = float(right)
            if divisor == 0:
                return 'NULL'
            return float(left) / divisor
        return self.convert_to_res_type(Decimal(left) / Decimal(right))

    def execute_mod(self, params):
        """Modulo; always computed with Decimal whatever the result type is."""
        if DecimalBinaryOperator.check_null(params):
            return 'NULL'
        (left, right), _ = self.get_convert_type(params)
        return self.convert_to_res_type(Decimal(left) % Decimal(right))

    def execute_eq(self, params):
        return self._compare(params, lambda a, b: a == b)

    def execute_eq_filtering(self, params):
        # Returns True on equality and (implicitly) None otherwise.
        if self.execute_eq(params):
            return True

    def execute_ne(self, params):
        return self._compare(params, lambda a, b: a != b)

    def execute_gt(self, params):
        return self._compare(params, lambda a, b: a > b)

    def execute_lt(self, params):
        return self._compare(params, lambda a, b: a < b)

    def execute_ge(self, params):
        return self._compare(params, lambda a, b: a >= b)

    def execute_le(self, params):
        return self._compare(params, lambda a, b: a <= b)

    @staticmethod
    def get_all_binary_ops() -> List[DecimalColumnExpr]:
        """Return fresh operator objects for the five arithmetic and six compare ops."""
        arithmetic_ops = [
            DecimalBinaryOperator(" {0} + {1} ", DecimalBinaryOperator.execute_plus, "+"),
            DecimalBinaryOperator(" {0} - {1} ", DecimalBinaryOperator.execute_minus, "-"),
            DecimalBinaryOperator(" {0} * {1} ", DecimalBinaryOperator.execute_mul, "*"),
            DecimalBinaryOperator(" {0} / {1} ", DecimalBinaryOperator.execute_div, "/"),
            DecimalBinaryOperator(" {0} % {1} ", DecimalBinaryOperator.execute_mod, "%"),
        ]
        return arithmetic_ops + DecimalBinaryOperator.get_all_filtering_binary_compare_ops()

    @staticmethod
    def get_all_filtering_binary_compare_ops() -> List[DecimalColumnExpr]:
        """Return fresh operator objects for the six comparison operators."""
        return [
            DecimalBinaryOperator(" {0} == {1} ", DecimalBinaryOperator.execute_eq, "=="),
            DecimalBinaryOperator(" {0} != {1} ", DecimalBinaryOperator.execute_ne, "!="),
            DecimalBinaryOperator(" {0} > {1} ", DecimalBinaryOperator.execute_gt, ">"),
            DecimalBinaryOperator(" {0} < {1} ", DecimalBinaryOperator.execute_lt, "<"),
            DecimalBinaryOperator(" {0} >= {1} ", DecimalBinaryOperator.execute_ge, ">="),
            DecimalBinaryOperator(" {0} <= {1} ", DecimalBinaryOperator.execute_le, "<="),
        ]

    def execute(self, params):
        return super().execute(params)
class DecimalUnaryOperator(DecimalColumnExpr):
    """Unary expression on a decimal column (currently only negation)."""

    def __init__(self, format, executor, op: str):
        super().__init__(format, executor)
        self.op_ = op
        # Operand type; filled in by generate_res_type().
        self.col_type_: DataType = None

    def should_skip_for_decimal(self, cols: list):
        """Return True unless the operand column is a decimal column."""
        operand: Column = cols[0]
        return not operand.type_.is_decimal_type()

    @staticmethod
    def get_all_unary_ops() -> List[DecimalColumnExpr]:
        negate = DecimalUnaryOperator(" -{0} ", DecimalUnaryOperator.execute_minus, "-")
        return [negate]

    def get_input_types(self) -> list:
        return [self.col_type_]

    def generate_res_type(self):
        # The result keeps the operand's type.
        operand_type = self.params_[0].type_
        self.res_type_ = operand_type
        self.col_type_ = operand_type

    def execute_minus(self, params) -> Decimal:
        """Negate the operand; a None operand yields the string 'NULL'."""
        operand = params[0]
        return 'NULL' if operand is None else -Decimal(operand)
class DecimalBinaryOperatorIn(DecimalBinaryOperator):
    """Membership operator: ``in`` / ``not in`` (matched case-insensitively)."""

    def __init__(self, op: str):
        # The parent constructor takes (format, executor, op); the original
        # call passed only ``op`` and raised TypeError on construction.
        # NOTE(review): the format mirrors the " {0} <op> {1} " pattern of the
        # other binary operators — confirm it is the SQL shape wanted here.
        super().__init__(
            f" {{0}} {op} {{1}} ",
            DecimalBinaryOperatorIn._execute_params,
            op,
        )

    def _execute_params(self, params):
        # Adapter with the (self, params) shape used by the other executors.
        return self.execute(params[0], params[1])

    def execute(self, left, right):
        """Return the membership result, or None for an unrecognised operator."""
        # The original wrote ``self.op_.lower()()``, which calls a str and
        # raises TypeError.
        op = self.op_.lower()
        if op == "in":
            return left in right
        if op == "not in":
            return left not in right
        return None
class TDTestCase:
# Class-level configuration dictionary; presumably read by the test framework
# to override server settings for this case — confirm against the runner.
# NOTE(review): the key "ratioOfVnodeStreamThrea" looks truncated — confirm it
# matches the option name the server actually reads.
updatecfgDict = {
"asynclog": 0,
"ttlUnit": 1,
"ttlPushInterval": 5,
"ratioOfVnodeStreamThrea": 4,
"debugFlag": 143,
}
def __init__(self):
    """Set the default names and sizes used throughout the test case."""
    # Sizes and rounds.
    self.vgroups = 4
    self.ctbNum = 10
    self.c_table_num = ctb_num
    self.rowsPerTbl = 10000
    self.query_test_round = 10000
    self.duraion = "1h"

    # Database object names.
    self.db_name = "test"
    self.stable_name = "meters"
    self.norm_table_name = "nt"
    self.no_decimal_col_tb_name = "tt"
    self.stream_name = "stream1"
    self.stream_out_stb = "stream_out_stb"
    self.tsma_name = "tsma1"

    # Name prefixes for columns, child tables and tags.
    self.col_prefix = "c"
    self.c_table_prefix = "t"
    self.tag_name_prefix = "t"

    # Column/tag definitions; populated by test_decimal_column_ddl().
    self.norm_tb_columns = []
    self.stb_columns = []
    self.tags = []
def init(self, conn, logSql, replicaVar=1):
    """Store the replica count and bind ``tdSql`` to a cursor of ``conn``.

    ``logSql`` is accepted but not used here.
    """
    self.replicaVar = int(replicaVar)
    tdLog.debug(f"start to excute {__file__}")
    cursor = conn.cursor()
    tdSql.init(cursor, False)
def check_desc_for_one_ctb(
    self, ctbPrefix: str, columns: List[Column], tags: List[Column] = []
):
    """Run check_desc() on one randomly chosen child table ``<ctbPrefix><idx>``."""
    picked = randrange(self.c_table_num)
    child_name = f"{ctbPrefix}{picked}"
    return self.check_desc(child_name, columns, tags)
def check_desc(self, tbname: str, columns: List[Column], tags: List[Column] = []):
    """Verify type, encode and compression of DECIMAL columns in ``desc`` output.

    Column ``i`` of ``columns`` is matched against row ``i + 1`` of the
    result (row 0 is skipped — presumably the timestamp column). Any mismatch
    ends the test through ``tdLog.exit``. When ``tbname`` is the super table,
    one random child table is checked as well.
    """
    sql = f"desc {self.db_name}.{tbname}"
    tdSql.query(sql, queryTimes=1)
    results = tdSql.queryResult
    for i, col in enumerate(columns):
        if col.type_.type != TypeEnum.DECIMAL:
            continue
        row = results[i + 1]
        if row[1] != str(col.type_):
            tdLog.info(str(results))
            tdLog.exit(
                f"check desc failed for table: {tbname} column {row[0]} type is {row[1]}, expect {col.type_}"
            )
        # Index 4 is compared with the encode default and index 5 with the
        # compression default; the original messages printed the other
        # column's value.
        if row[4] != DecimalType.default_encode():
            tdLog.exit(
                f"check desc failed for table: {tbname} column {row[0]} encode is {row[4]}, expect {DecimalType.default_encode()}"
            )
        if row[5] != DecimalType.default_compression():
            tdLog.exit(
                f"check desc failed for table: {tbname} column {row[0]} compression is {row[5]}, expect {DecimalType.default_compression()}"
            )
    if tbname == self.stable_name:
        self.check_desc_for_one_ctb(self.c_table_prefix, columns, tags)
def check_show_create_table(
    self, tbname: str, cols: List[Column], tags: List[Column] = []
):
    """Verify the DECIMAL(p,s) declarations printed by ``show create table``.

    The n-th ``DECIMAL(p,s)`` occurrence in the statement is matched against
    the n-th DECIMAL/DECIMAL64 entry of ``cols``. A mismatch ends the test
    through ``tdLog.exit``. ``tags`` is accepted but not inspected.
    """
    sql = f"show create table {self.db_name}.{tbname}"
    tdSql.query(sql, queryTimes=1)
    create_table_sql = tdSql.queryResult[0][1]
    declared = re.findall(r"DECIMAL\((\d+),(\d+)\)", create_table_sql)
    decimal_idx = 0
    for i, col in enumerate(cols):
        if col.type_.type not in (TypeEnum.DECIMAL, TypeEnum.DECIMAL64):
            continue
        prec, scale = declared[decimal_idx]
        result_type = DecimalType(col.type_.type, int(prec), int(scale))
        if result_type != col.type_:
            # The original message read ``col.type``; every other use in this
            # file goes through ``col.type_``.
            tdLog.exit(
                f"check show create table failed for: {tbname} column {i} type is {result_type}, expect {col.type_}"
            )
        decimal_idx += 1
def test_add_drop_columns_with_decimal(self, tbname: str, columns: List[Column]):
    """Add and drop columns on ``tbname``, running check_desc() after each step.

    ``columns`` is updated in place to mirror every ALTER statement; on
    return it ends with the re-created c99 and c100 decimal columns (and no
    longer contains the dropped c6 entry).
    """
    table = f"{self.db_name}.{tbname}"
    create_c99_sql = f"alter table {table} add column c99 decimal(37, 19)"
    create_c100_sql = f'ALTER TABLE {table} ADD COLUMN c100 decimal(36, 18) COMPRESS "zstd"'

    def alter(sql):
        tdSql.execute(sql, queryTimes=1, show=True)

    ## alter table add column
    columns.append(Column(DecimalType(TypeEnum.DECIMAL, 37, 19)))
    alter(create_c99_sql)
    self.check_desc(tbname, columns)

    ## alter table add column with compression
    alter(create_c100_sql)
    columns.append(Column(DecimalType(TypeEnum.DECIMAL, 36, 18)))
    self.check_desc(tbname, columns)

    ## drop non decimal column
    alter(f"alter table {table} drop column c6")
    columns.pop(5)  # assumes c6 is the sixth entry of ``columns``
    self.check_desc(tbname, columns)

    ## drop decimal column and not last column
    alter(f"alter table {table} drop column c99")
    c99 = columns.pop(len(columns) - 2)
    self.check_desc(tbname, columns)

    ## drop decimal column and last column
    alter(f"alter table {table} drop column c100")
    c100 = columns.pop(len(columns) - 1)
    self.check_desc(tbname, columns)

    ## create decimal back
    alter(create_c99_sql)
    alter(create_c100_sql)
    columns.append(c99)
    columns.append(c100)
    self.check_desc(tbname, columns)
def test_decimal_column_ddl(self):
    """Create the decimal test tables and exercise valid and invalid DDL."""

    def make_columns():
        # Fresh Column objects on every call so the two tables never share one.
        return [
            Column(DecimalType(TypeEnum.DECIMAL, 10, 2)),
            Column(DecimalType(TypeEnum.DECIMAL, 20, 4)),
            Column(DecimalType(TypeEnum.DECIMAL, 30, 8)),
            Column(DecimalType(TypeEnum.DECIMAL, 38, 10)),
            Column(DataType(TypeEnum.TINYINT)),
            Column(DataType(TypeEnum.INT)),
            Column(DataType(TypeEnum.BIGINT)),
            Column(DataType(TypeEnum.DOUBLE)),
            Column(DataType(TypeEnum.FLOAT)),
            Column(DataType(TypeEnum.VARCHAR, 255)),
        ]

    ## create decimal type table, normal/super table, decimal64/decimal128
    tdLog.printNoPrefix("-------- test create decimal column")
    self.norm_tb_columns = make_columns()
    self.tags = [
        Column(DataType(TypeEnum.INT)),
        Column(DataType(TypeEnum.VARCHAR, 255)),
    ]
    self.stb_columns = make_columns()

    stb_creater = DecimalColumnTableCreater(
        tdSql,
        self.db_name,
        self.stable_name,
        self.stb_columns,
        self.tags,
        col_prefix=self.col_prefix,
        tag_prefix=self.tag_name_prefix,
    )
    stb_creater.create()
    self.check_show_create_table("meters", self.stb_columns, self.tags)

    norm_creater = DecimalColumnTableCreater(
        tdSql, self.db_name, self.norm_table_name, self.norm_tb_columns
    )
    norm_creater.create()
    self.check_desc(self.norm_table_name, self.norm_tb_columns)
    self.check_show_create_table(self.norm_table_name, self.norm_tb_columns)

    ## TODO add more values for all rows
    tag_values = ["1", "t1"]
    child_creater = DecimalColumnTableCreater(
        tdSql, self.db_name, self.stable_name, self.stb_columns
    )
    child_creater.create_child_table(
        self.c_table_prefix, self.c_table_num, self.tags, tag_values
    )
    self.check_desc("meters", self.stb_columns, self.tags)
    self.check_desc("t1", self.stb_columns, self.tags)

    ## invalid precision/scale
    invalid_precision_scale = [
        ("decimal(-1, 2)", syntax_error),
        ("decimal(39, 2)", invalid_column),
        ("decimal(10, -1)", syntax_error),
        ("decimal(10, 39)", invalid_column),
        ("decimal(10, 2.5)", syntax_error),
        ("decimal(10.5, 2)", syntax_error),
        ("decimal(10.5, 2.5)", syntax_error),
        ("decimal(0, 2)", invalid_column),
        ("decimal(0)", invalid_column),
        ("decimal", syntax_error),
        ("decimal()", syntax_error),
    ]
    for type_decl, expected_errno in invalid_precision_scale:
        sql = f"create table {self.db_name}.invalid_decimal_precision_scale (ts timestamp, c1 {type_decl})"
        tdSql.error(sql, expected_errno)

    ## can't create decimal tag
    sql = (
        "create stable %s.invalid_decimal_tag (ts timestamp) tags (t1 decimal(10, 2))"
        % (self.db_name)
    )
    tdSql.error(sql, invalid_column)

    ## alter table add/drop column
    self.test_add_drop_columns_with_decimal(
        self.norm_table_name, self.norm_tb_columns
    )
    self.test_add_drop_columns_with_decimal(self.stable_name, self.stb_columns)

    ## drop index from stb
    ### These ops will override the previous stbobjs and meta entries, so test it
    ## TODO test encode and compress for decimal type
    # Each of these encodings must be rejected for a decimal column.
    for encode in ("simple8b", "delta-i", "delta-d", "bit-packing"):
        sql = f'ALTER TABLE {self.db_name}.{self.norm_table_name} ADD COLUMN c101 decimal(37, 19) ENCODE "{encode}" COMPRESS "zstd"'
        tdSql.error(sql, invalid_encode_param)
def test_insert_decimal_values(self):
    """Insert rows into every table and validate them before and after a flush."""

    def validate_child_tables():
        for idx in range(self.c_table_num):
            TableDataValidator(
                self.stb_columns, self.c_table_prefix + str(idx), self.db_name, idx
            ).validate()

    def validate_norm_table():
        TableDataValidator(
            self.norm_tb_columns, self.norm_table_name, self.db_name
        ).validate()

    self.log_test("insert decimal values")
    for idx in range(self.c_table_num):
        inserter = TableInserter(
            tdSql,
            self.db_name,
            f"{self.c_table_prefix}{idx}",
            self.stb_columns,
            self.tags,
        )
        inserter.insert(tb_insert_rows, 1537146000000, 500)
    validate_child_tables()

    norm_inserter = TableInserter(
        tdSql, self.db_name, self.norm_table_name, self.norm_tb_columns
    )
    norm_inserter.insert(tb_insert_rows, 1537146000000, 500, flush_database=True)
    validate_norm_table()

    # Validate everything once more after an explicit flush.
    tdSql.execute("flush database %s" % (self.db_name), queryTimes=1)
    validate_child_tables()
    validate_norm_table()
    ## insert null/None for decimal type
    ## insert with column format
def no_decimal_table_test(self):
    """Add/drop a decimal column on a table that was created without one."""
    table = self.no_decimal_col_tb_name
    columns = [
        Column(DataType(type_id))
        for type_id in (
            TypeEnum.TINYINT,
            TypeEnum.INT,
            TypeEnum.BIGINT,
            TypeEnum.DOUBLE,
            TypeEnum.FLOAT,
        )
    ]
    columns.append(Column(DataType(TypeEnum.VARCHAR, 255)))

    DecimalColumnTableCreater(tdSql, self.db_name, table, columns, []).create()
    TableInserter(tdSql, self.db_name, table, columns).insert(
        10000, 1537146000000, 500, flush_database=True
    )
    ## TODO wjm test non support decimal version upgrade to decimal support version, and add decimal column
    ## Test metaentry compatibility problem for decimal type
    ## How to test it?
    ## Create table with no decimal type, the metaentries should not have extschma, and add decimal column, the metaentries should have extschema for all columns.
    add_decimal_sql = f"ALTER TABLE {self.db_name}.{self.no_decimal_col_tb_name} ADD COLUMN c200 decimal(37, 19)"
    tdSql.execute(add_decimal_sql, queryTimes=1)  ## now meta entry has ext schemas
    columns.append(Column(DecimalType(TypeEnum.DECIMAL, 37, 19)))
    self.check_desc(table, columns)

    ## After drop this only decimal column, the metaentries should not have extschema for all columns.
    drop_sql = (
        f"ALTER TABLE {self.db_name}.{self.no_decimal_col_tb_name} DROP COLUMN c200"
    )
    tdSql.execute(drop_sql, queryTimes=1)  ## now meta entry has no ext schemas
    columns.pop(len(columns) - 1)
    self.check_desc(table, columns)

    add_int_sql = f"ALTER TABLE {self.db_name}.{self.no_decimal_col_tb_name} ADD COLUMN c200 int"
    tdSql.execute(add_int_sql, queryTimes=1)  ## meta entry has no ext schemas
    columns.append(Column(DataType(TypeEnum.INT)))
    self.check_desc(table, columns)

    self.test_add_drop_columns_with_decimal(table, columns)
def test_decimal_ddl(self):
    """Create the test database, then run the decimal column DDL checks."""
    self.log_test("test_decimal_ddl")
    create_db_sql = "create database test cachemodel 'both'"
    tdSql.execute(create_db_sql, queryTimes=1)
    self.test_decimal_column_ddl()
def test_decimal_and_stream(self):
    """Check stream computation over decimal columns.

    A FILL_HISTORY stream over the super table must produce 50 result rows.
    A second, session-window stream over ``test.nt`` is then fed three
    generated rows, and its min(c2)/max(c2) output is compared with the
    values computed locally.
    """
    self.log_test("test_decimal_and_stream")
    create_stream = f"CREATE STREAM {self.stream_name} FILL_HISTORY 1 INTO {self.db_name}.{self.stream_out_stb} AS SELECT _wstart, count(c1), avg(c2), sum(c3) FROM {self.db_name}.{self.stable_name} INTERVAL(10s)"
    tdSql.execute(create_stream, queryTimes=1, show=True)
    self.wait_query_result(
        f"select count(*) from {self.db_name}.{self.stream_out_stb}", [(50,)], 30
    )

    ## test combine functions
    create_stream = "CREATE STREAM stream2 trigger at_once watermark 100s INTO test.stream_out_stb_2 AS SELECT _wstart, count(*), min(c2), max(c2), last(c1), last(c3), avg(c2), sum(c3), min(c1), max(c1), avg(c1) FROM test.nt session(ts, 5s)"
    tdSql.execute(create_stream, queryTimes=1, show=True)

    rows = []

    def insert_generated_row(ts):
        # Generate one non-NULL value per column, remember the row, insert it.
        row = [
            col.generate_value(self.norm_table_name, has_null=False)
            for col in self.norm_tb_columns
        ]
        rows.append(row)
        sql = f"insert into test.nt values('{ts}', {' ,'.join(row)})"
        tdSql.execute(sql, queryTimes=1, show=True)

    first_ts = datetime.now()
    insert_generated_row(first_ts)
    insert_generated_row(first_ts + timedelta(seconds=6))
    ## waiting for the last two windows been calculated
    time.sleep(10)
    # The third row lands 2s after the first one, i.e. between the two
    # rows already inserted.
    insert_generated_row(first_ts + timedelta(seconds=2))

    self.wait_query_result("select `count(*)` from test.stream_out_stb_2", [(3,)], 10)
    res = TaosShell().query("select * from test.stream_out_stb_2")
    if len(res) != 12:  ## groupid
        tdLog.exit(f"expect 12 columns but got: {len(res)}")

    c2 = self.norm_tb_columns[1]
    c2_vals = [get_decimal(row[1], c2.type_.scale()) for row in rows]
    min_c2 = Decimal(res[2][0])
    if min_c2 != min(c2_vals):
        tdLog.exit(f"expect min(c2) = {min(c2_vals)} got: {min_c2}")
    max_c2 = Decimal(res[3][0])
    if max_c2 != max(c2_vals):
        tdLog.exit(f"expect max(c2) = {max(c2_vals)} got: {max_c2}")
def test_decimal_and_tsma(self):
    """Create a TSMA over decimal columns and wait for its result rows."""
    self.log_test("test_decimal_and_tsma")
    create_tsma = f"CREATE TSMA {self.tsma_name} ON {self.db_name}.{self.stable_name} FUNCTION(count(c1), min(c2), max(c3), avg(C3)) INTERVAL(1m)"
    tdSql.execute(create_tsma, queryTimes=1, show=True)
    count_sql = f"select count(*) from {self.db_name}.{self.tsma_name}_tsma_res_stb_"
    # Nine result rows are expected per child table.
    expected = [(9 * self.c_table_num,)]
    self.wait_query_result(count_sql, expected, 30)
def test_decimal_and_view(self):
    """Create a view over the first decimal column and compare every row.

    The view's ``c1`` column must return exactly the values that were
    inserted into the normal table, NULLs included.
    """
    self.log_test("test_decimal_and_view")
    c1 = self.norm_tb_columns[0]
    create_view_sql = f'create view {self.db_name}.view1 as select {c1} as c1, cast({c1} as decimal(38, 10)) as c2 from {self.db_name}.{self.norm_table_name}'
    tdSql.execute(create_view_sql)
    res = TaosShell().query(f'select c1 from {self.db_name}.view1')
    queried = res[0]
    expected_rows = c1.get_cardinality(self.norm_table_name)
    if len(queried) != expected_rows:
        # Report the length that was actually compared; the original printed
        # len(res) here.
        tdLog.exit(f"query from view1 got rows: {len(queried)} expect: {expected_rows}")
    for i, v_query in enumerate(queried):
        v_insert = c1.get_val_for_execute(self.norm_table_name, i)
        if v_insert is None or v_query == 'NULL':
            # Never hand 'NULL' to Decimal(): a NULL on only one side is a
            # mismatch to report, not an InvalidOperation to raise.
            matched = v_insert is None and v_query == 'NULL'
        else:
            matched = Decimal(v_query) == v_insert
        if not matched:
            tdLog.exit(f"query from view got different results: {v_query}, expect: {v_insert}")
    #self.check_desc("view1", [c1, Column(DecimalType(TypeEnum.DECIMAL, 38, 10))])
def log_test(self, name: str):
    """Log a timestamped banner announcing the start of test ``name``."""
    started_at = datetime.now()
    tdLog.info(f"{started_at} start to test {name}")
def run(self):
    """Run every stage of the decimal test case in order."""
    stages = (
        self.test_decimal_ddl,
        #self.no_decimal_table_test,
        self.test_insert_decimal_values,
        self.test_query_decimal,
        self.test_decimal_and_tsma,
        self.test_decimal_and_view,
        self.test_decimal_and_stream,
    )
    for stage in stages:
        stage()
def stop(self):
    """Close the SQL helper and report success for this file."""
    tdSql.close()
    success_msg = f"{__file__} successfully executed"
    tdLog.success(success_msg)
def wait_query_result(self, sql: str, expect_result, times):
    """Poll ``sql`` until its result equals ``expect_result``.

    The query is retried up to ``times`` times with a one second pause after
    each mismatch. Returns True on the first match; otherwise the test is
    ended through ``tdLog.exit``.
    """
    # Pre-bind so the timeout message is still valid when ``times`` <= 0 and
    # the loop body never runs.
    results = None
    for i in range(times):
        tdLog.info(f"wait query result for {sql}, times: {i}")
        tdSql.query(sql, queryTimes=1)
        results = tdSql.queryResult
        if results == expect_result:
            return True
        time.sleep(1)
    tdLog.exit(
        f"wait query result timeout for {sql} failed after {times} time, expect {expect_result}, but got {results}"
    )
def wait_query_at_least_rows(self, sql: str, rows, wait_times):
    """Poll ``sql`` until it returns at least ``rows`` rows.

    The query is retried up to ``wait_times`` times with a one second pause
    after each short result. Returns True once enough rows are seen;
    otherwise the test is ended through ``tdLog.exit``.
    """
    # Pre-bind so the timeout message is still valid when ``wait_times`` <= 0
    # and the loop body never runs.
    results = []
    for i in range(wait_times):
        tdLog.info(f"wait query rows at least for {sql}, times: {i}")
        tdSql.query(sql, queryTimes=1, show=True)
        results = tdSql.queryResult
        if len(results) >= rows:
            return True
        time.sleep(1)
    tdLog.exit(
        f"wait query rows at least for {sql} failed after {wait_times} times, expect at least {rows} rows, but got {len(results)} rows"
    )
def check_decimal_binary_expr_with_col_results(
    self, dbname, tbname, tb_cols: List[Column], exprs: List[DecimalColumnExpr]
):
    """Check every expression on every ordered pair of named table columns.

    Does nothing unless the module flag ``binary_op_with_col_test`` is set.
    """
    if not binary_op_with_col_test:
        return
    # Columns with an empty name are not queried.
    named_cols = [c for c in tb_cols if c.name_ != '']
    for expr in exprs:
        for left in named_cols:
            for right in named_cols:
                if expr.should_skip_for_decimal([left, right]):
                    continue
                select_expr = expr.generate((left, right))
                sql = f"select {select_expr} from {dbname}.{tbname}"
                res = TaosShell().query(sql)
                if len(res) == 0:
                    tdLog.info(f"sql: {sql} got no output")
                    continue
                expr.check(res[0], tbname)
def check_decimal_binary_expr_with_const_col_results_for_one_expr(
    self,
    dbname,
    tbname,
    tb_cols: List[Column],
    expr: DecimalColumnExpr,
    get_constant_cols_func,
):
    """Check ``expr`` between every named table column and every constant.

    Each pair is queried in both operand orders (constant first, then column
    first) on one shared shell. An empty result is logged for either order;
    the original code logged it only for the second query.
    """
    constant_cols = get_constant_cols_func()
    for col in tb_cols:
        if col.name_ == '':
            continue
        for const_col in constant_cols:
            if expr.should_skip_for_decimal([col, const_col]):
                continue
            const_col.generate_value()
            shell = TaosShell()
            for operands in ((const_col, col), (col, const_col)):
                select_expr = expr.generate(operands)
                sql = f"select {select_expr} from {dbname}.{tbname}"
                res = shell.query(sql)
                if len(res) > 0:
                    expr.check(res[0], tbname)
                else:
                    tdLog.info(f"sql: {sql} got no output")
def check_decimal_binary_expr_with_const_col_results(
    self,
    dbname,
    tbname,
    tb_cols: List[Column],
    get_constant_cols_func,
    get_exprs_func,
):
    """Check each expression against constant columns, one thread per expression.

    The expressions are built first; nothing more happens unless the module
    flag ``binary_op_with_const_test`` is set. Returns after every worker
    thread has been joined.
    """
    exprs: List[DecimalColumnExpr] = get_exprs_func()
    if not binary_op_with_const_test:
        return
    workers = [
        self.run_in_thread2(
            self.check_decimal_binary_expr_with_const_col_results_for_one_expr,
            (dbname, tbname, tb_cols, expr, get_constant_cols_func),
        )
        for expr in exprs
    ]
    for worker in workers:
        worker.join()
def check_decimal_unary_expr_results(self, dbname, tbname, tb_cols: List[Column], exprs: List[DecimalColumnExpr]):
    """Check every unary expression on every named table column.

    Does nothing unless the module flag ``unary_op_test`` is set.
    """
    if not unary_op_test:
        return
    for expr in exprs:
        for col in tb_cols:
            if col.name_ == '' or expr.should_skip_for_decimal([col]):
                continue
            select_expr = expr.generate([col])
            sql = f"select {select_expr} from {dbname}.{tbname}"
            res = TaosShell().query(sql)
            if len(res) == 0:
                tdLog.info(f"sql: {sql} got no output")
                continue
            expr.check(res[0], tbname)
def run_in_thread(self, times, func, params) -> threading.Thread:
    """Run ``func(*params)`` in ``times`` threads and wait for all of them.

    Nothing is returned, despite the return annotation.
    """
    workers = []
    for _ in range(times):
        worker = threading.Thread(target=func, args=params)
        worker.start()
        workers.append(worker)
    for worker in workers:
        worker.join()
def run_in_thread2(self, func, params) -> threading.Thread:
    """Start ``func(*params)`` in a new thread and return it without joining."""
    worker = threading.Thread(target=func, args=params)
    worker.start()
    return worker
## test others unsupported types operator with decimal
def test_decimal_unsupported_types(self):
    """Check that decimal <op> unsupported-type expressions are rejected.

    A super table with all column types (``TypeEnum.JSON`` is passed to
    ``Column.get_all_type_columns``) and a JSON tag is created and filled.
    Every binary operator is then tried between each decimal column and each
    JSON/GEOMETRY/VARBINARY column, in both operand orders. Each query must
    fail with ``invalid_operation`` or ``scalar_convert_err``.
    """
    unsupported_type = [
        TypeEnum.JSON,
        TypeEnum.GEOMETRY,
        TypeEnum.VARBINARY,
    ]
    all_type_columns: List[Column] = Column.get_all_type_columns([TypeEnum.JSON])
    tbname = "test_decimal_unsupported_types"
    tag_cols = [Column(DataType(TypeEnum.JSON))]
    tb_creater = DecimalColumnTableCreater(tdSql, self.db_name, tbname, all_type_columns, tag_cols)
    tb_creater.create()
    tb_creater.create_child_table(tbname, 10, tag_cols, ['{"k1": "v1"}'])
    for i in range(10):
        TableInserter(tdSql, self.db_name, f'{tbname}{i}', all_type_columns, tag_cols).insert(100, 1537146000000, 500, flush_database=True)

    def expect_rejected(select_expr, sql, left_type, right_type):
        # Run ``sql`` and require one of the two accepted error codes.
        tdLog.info(
            f"select expr: {select_expr} with type: {left_type} and {right_type} should err"
        )
        tdSql.error(sql)
        if tdSql.errno != invalid_operation and tdSql.errno != scalar_convert_err:
            tdLog.exit(f"expected err not occured for sql: {sql}, expect: {invalid_operation} or {scalar_convert_err}, but got {tdSql.errno}")

    for col in all_type_columns:
        ## only test decimal cols
        if not col.type_.is_decimal_type():
            continue
        for unsupported_col in all_type_columns:
            ## only test unsupported cols
            if unsupported_col.type_.type not in unsupported_type:
                continue
            for binary_op in DecimalBinaryOperator.get_all_binary_ops():
                select_expr = binary_op.generate((col, unsupported_col))
                sql = f"select {select_expr} from {self.db_name}.{tbname}"
                select_expr_reverse = binary_op.generate((unsupported_col, col))
                sql_reverse = (
                    f"select {select_expr_reverse} from {self.db_name}.{tbname}"
                )
                expect_rejected(select_expr, sql, col.type_, unsupported_col.type_)
                # The original reported the forward expression and SQL when
                # the reversed query misbehaved.
                expect_rejected(select_expr_reverse, sql_reverse, unsupported_col.type_, col.type_)
def test_decimal_operators(self):
    """Run the decimal operator checks against the normal table.

    Order of work: the unsupported-type check, the decimal-vs-constant binary
    check (handed to ``self.run_in_thread``), the decimal-vs-column binary
    check repeated ``operator_test_round`` times, and finally the unary check.
    """
    self.log_test("decimal operators")
    self.test_decimal_unsupported_types()
    ## tables: meters, nt
    ## columns: c1, c2, c3, c4, c5, c7, c8, c9, c10, c99, c100
    binary_operators = DecimalBinaryOperator.get_all_binary_ops()

    ## decimal operator with constants of all other types
    const_col_check_args = (
        self.db_name,
        self.norm_table_name,
        self.norm_tb_columns,
        Column.get_decimal_oper_const_cols,
        DecimalBinaryOperator.get_all_binary_ops,
    )
    self.run_in_thread(
        operator_test_round,
        self.check_decimal_binary_expr_with_const_col_results,
        const_col_check_args,
    )

    ## test decimal column op decimal column
    for _ in range(operator_test_round):
        self.check_decimal_binary_expr_with_col_results(
            self.db_name, self.norm_table_name, self.norm_tb_columns, binary_operators
        )

    unary_operators = DecimalUnaryOperator.get_all_unary_ops()
    self.check_decimal_unary_expr_results(
        self.db_name,
        self.norm_table_name,
        self.norm_tb_columns,
        unary_operators,
    )
def check_decimal_where_with_binary_expr_with_const_col_results(
    self,
    dbname,
    tbname,
    tb_cols: List[Column],
    constant_cols: List[Column],
    exprs: List[DecimalColumnExpr],
):
    """Check WHERE filtering for ``column <op> constant`` in both operand orders.

    Does nothing when the module flag ``binary_op_in_where_test`` is false.
    For every expression, every named table column and every constant column
    not rejected by ``expr.should_skip_for_decimal``, a fresh constant value
    is generated and two queries are run (constant first, then column first).
    A non-empty result is validated with ``expr.check_for_filtering``; an
    empty one is logged.

    :param dbname: database that holds ``tbname``.
    :param tbname: table to query.
    :param tb_cols: table columns; entries with an empty ``name_`` are skipped.
    :param constant_cols: constant operands combined with each table column.
    :param exprs: comparison expressions used as the WHERE condition.
    """
    if not binary_op_in_where_test:
        return
    for expr in exprs:
        tdLog.info(f"start to test decimal where filtering with const cols for expr: {expr.format_}")
        for col in tb_cols:
            if col.name_ == '':
                continue
            for const_col in constant_cols:
                if expr.should_skip_for_decimal([col, const_col]):
                    continue
                const_col.generate_value()

                ## constant <op> column
                select_expr = expr.generate((const_col, col))
                ## the selected column is always the table column: the other
                ## operand is a constant and cannot be selected from the table
                expr.query_col = col
                sql = f"select {expr.query_col} from {dbname}.{tbname} where {select_expr}"
                res = TaosShell().query(sql)
                ##TODO wjm no need to check len(res) for filtering test, cause we need to check for every row in the table to check if the filtering is working
                if len(res) > 0:
                    expr.check_for_filtering(res[0], tbname)
                else:
                    tdLog.info(f"sql: {sql} got no output")

                ## column <op> constant
                select_expr = expr.generate((col, const_col))
                sql = f"select {expr.query_col} from {dbname}.{tbname} where {select_expr}"
                res = TaosShell().query(sql)
                if len(res) > 0:
                    expr.check_for_filtering(res[0], tbname)
                else:
                    tdLog.info(f"sql: {sql} got no output")
def check_decimal_where_with_binary_expr_with_col_results(
    self, dbname, tbname, tb_cols: List[Column], exprs: List[DecimalColumnExpr]
):
    """Check WHERE filtering for ``column <op> column`` in both operand orders.

    Does nothing when the module flag ``binary_op_in_where_test`` is false.
    Every ordered pair of named columns that ``expr.should_skip_for_decimal``
    does not reject is queried twice (operands swapped for the second query).
    The selected column is the first operand when it is decimal, otherwise
    the second one.
    """
    if not binary_op_in_where_test:
        return
    for expr in exprs:
        tdLog.info(f"start to test decimal where filtering with cols for expr{expr.format_}")
        for lhs in tb_cols:
            if lhs.name_ == '':
                continue
            for rhs in tb_cols:
                if rhs.name_ == '':
                    continue
                if expr.should_skip_for_decimal([lhs, rhs]):
                    continue

                ## lhs <op> rhs
                condition = expr.generate((lhs, rhs))
                expr.query_col = lhs if lhs.type_.is_decimal_type() else rhs
                sql = f"select {expr.query_col} from {dbname}.{tbname} where {condition}"
                rows = TaosShell().query(sql)
                if len(rows) == 0:
                    tdLog.info(f"sql: {sql} got no output")
                else:
                    expr.check_for_filtering(rows[0], tbname)

                ## rhs <op> lhs, selecting the same column as above
                condition = expr.generate((rhs, lhs))
                sql = f"select {expr.query_col} from {dbname}.{tbname} where {condition}"
                rows = TaosShell().query(sql)
                if len(rows) == 0:
                    tdLog.info(f"sql: {sql} got no output")
                else:
                    expr.check_for_filtering(rows[0], tbname)
def test_query_decimal_where_clause(self):
    """Run both WHERE-clause filtering checks on the normal table.

    Each check (decimal vs constant, then decimal vs column) is repeated
    ``operator_test_round`` times with the filtering comparison operators.
    """
    tdLog.info("start to test decimal where filtering")
    compare_ops = DecimalBinaryOperator.get_all_filtering_binary_compare_ops()
    constant_columns = Column.get_decimal_oper_const_cols()

    for _ in range(operator_test_round):
        self.check_decimal_where_with_binary_expr_with_const_col_results(
            dbname=self.db_name,
            tbname=self.norm_table_name,
            tb_cols=self.norm_tb_columns,
            constant_cols=constant_columns,
            exprs=compare_ops,
        )

    for _ in range(operator_test_round):
        self.check_decimal_where_with_binary_expr_with_col_results(
            dbname=self.db_name,
            tbname=self.norm_table_name,
            tb_cols=self.norm_tb_columns,
            exprs=compare_ops,
        )
## TODO wjm
## 3. (dec op const col) op const col
## 4. (dec op dec) op const col
## 5. (dec op const col) op dec
## 6. (dec op dec) op dec
def test_query_with_order_by_for_tb(self, tbname: str, cols: list):
    """Run the ORDER BY check for every named decimal column in ``cols``."""
    for column in cols:
        if not column.type_.is_decimal_type():
            continue
        if column.name_ == '':
            continue
        self.test_query_with_order_by(column, tbname)
def test_query_with_order_by(self, order_col: Column, tbname):
    """Compare an ascending ORDER BY query on ``order_col`` with the expected order.

    The expected sequence comes from ``order_col.get_ordered_result(tbname, True)``.
    An expected ``None`` must show up as the string ``'NULL'`` in the query
    output; any other value is compared as ``Decimal``.  A mismatch ends the
    run via ``tdLog.exit``.
    """
    sql = f"select {order_col} from {self.db_name}.{tbname} order by {order_col} asc"
    actual_values = TaosShell().query(sql)[0]
    expected_values = order_col.get_ordered_result(tbname, True)
    ## NOTE: zip stops at the shorter sequence, so a difference in row count
    ## between the two is not reported here
    for actual, expected in zip(actual_values, expected_values):
        if expected is not None:
            if Decimal(actual) != expected:
                tdLog.exit(f"query result: {actual} not equal to calculated result: {expected}")
        elif actual != 'NULL':
            tdLog.exit(f"query result: {actual} not equal to calculated result: NULL")
def test_query_decimal_order_clause(self):
    """Run the ORDER BY checks on ``norm_table_name`` first, then on ``stable_name``."""
    self.log_test("decimal order by")
    targets = (
        (self.norm_table_name, self.norm_tb_columns),
        (self.stable_name, self.stb_columns),
    )
    for table, columns in targets:
        self.test_query_with_order_by_for_tb(table, columns)
def test_query_decimal_group_by_decimal(self, tbname: str, cols: list):
    """Check the number of groups produced by GROUP BY on each named decimal column.

    The length of the first element of the query output must equal
    ``col.get_group_num(tbname)``; otherwise the run ends via ``tdLog.exit``.
    """
    for column in cols:
        if not (column.type_.is_decimal_type() and column.name_ != ''):
            continue
        sql = f"select count(*) from {self.db_name}.{tbname} group by {column}"
        group_counts = TaosShell().query(sql)[0]
        expected_groups = column.get_group_num(tbname)
        actual_groups = len(group_counts)
        if actual_groups != expected_groups:
            tdLog.exit(f"query result: {actual_groups} not equal to calculated result: {expected_groups}")
def test_query_decimal_group_by_clause(self):
    """Run the GROUP BY checks on ``norm_table_name`` first, then on ``stable_name``."""
    self.log_test("decimal group by")
    targets = (
        (self.norm_table_name, self.norm_tb_columns),
        (self.stable_name, self.stb_columns),
    )
    for table, columns in targets:
        self.test_query_decimal_group_by_decimal(table, columns)
def test_query_decimal_group_by_with_having(self, tbname, cols: list):
    """Check GROUP BY ... HAVING <col> IS NOT NULL on each named decimal column.

    The length of the first element of the query output must equal
    ``col.get_group_num(tbname, ignore_null=True)``; otherwise the run ends
    via ``tdLog.exit``.
    """
    for column in cols:
        if not (column.type_.is_decimal_type() and column.name_ != ''):
            continue
        sql = f"select count(*) from {self.db_name}.{tbname} group by {column} having {column} is not null"
        group_counts = TaosShell().query(sql)[0]
        expected_groups = column.get_group_num(tbname, ignore_null=True)
        actual_groups = len(group_counts)
        if actual_groups != expected_groups:
            tdLog.exit(f"query result: {actual_groups} not equal to calculated result: {expected_groups}")
def test_query_decimal_having_clause(self):
    """Run the HAVING checks on ``norm_table_name`` first, then on ``stable_name``."""
    self.log_test("decimal having")
    targets = (
        (self.norm_table_name, self.norm_tb_columns),
        (self.stable_name, self.stb_columns),
    )
    for table, columns in targets:
        self.test_query_decimal_group_by_with_having(table, columns)
def test_query_decimal_interval_fill(self):
    """Run the interval/fill check on ``norm_table_name`` first, then on ``stable_name``."""
    for table in (self.norm_table_name, self.stable_name):
        self.test_query_decimal_interval_fill_for_tb(table)
def test_query_decimal_interval_fill_for_tb(self, tbname):
    """Check ``avg(c1)`` over 200a windows with ``fill(value)`` and ``fill(prev)``.

    * ``fill(value,0,1.234)``: every window whose count is 0 must report an
      average equal to ``Decimal('1.234')``.
    * ``fill(prev)``: for every row (except the first) whose count was 0 in
      the ``fill(value)`` output, the average must equal the average of the
      preceding row.

    A violation ends the run via ``tdLog.exit``; the message names the query
    whose output was being checked.
    """
    ## both queries share everything up to the fill clause
    query_prefix = (
        f"select _wstart, count(*), avg(c1),tbname from {self.db_name}.{tbname} "
        f"where ts >= 1537146000000 and ts < 1537146005000 partition by tbname interval(200a) "
    )

    sql = query_prefix + "fill(value,0,1.234)"
    tdLog.debug(sql)
    res = TaosShell().query(sql)
    wstart_col = res[0]
    count_col = res[1]
    avg_col = res[2]
    for wstart, count, avg in zip(wstart_col, count_col, avg_col):
        if int(count) == 0 and Decimal(avg) != Decimal('1.234'):
            tdLog.exit(f"sql: {sql} wstart: {wstart} count is 0, but avg is not 1.234")

    fill_prev_sql = query_prefix + "fill(prev)"
    fill_prev_res = TaosShell().query(fill_prev_sql)
    fill_prev_avg_col = fill_prev_res[2]
    ## row 0 has no predecessor, so the comparison starts at row 1
    for i in range(1, len(wstart_col)):
        if int(count_col[i]) == 0 and fill_prev_avg_col[i] != fill_prev_avg_col[i - 1]:
            ## report the fill(prev) query: it is the one whose output failed
            tdLog.exit(f"sql: {fill_prev_sql} count is 0, but avg is not same as previous row")
def test_query_decimal_case_when(self):
    """Check CASE WHEN expressions whose branches mix decimal and other types.

    Each case selects between ``cast(88888888.88 as decimal(10,2))`` and an
    ELSE branch, depending on ``cast(1 as decimal(10, 4)) <cmp> 1``.  The
    first value of the first output column is compared with the expected
    result either as a string or, for the mixed-type cases, as a float.
    A mismatch ends the run via ``tdLog.exit``.
    """
    self.log_test("decimal case when")
    ## (comparison, ELSE branch, expected result, compare as float?)
    cases = (
        (">=", "cast(3.333 as decimal(10,3))", "88888888.88", False),
        (">", "cast(3.333 as decimal(10,3))", "3.33", False),
        (">", "1.23", "1.23", True),
        (">=", "1.23", "88888888.88", True),
        (">=", "'1.23'", "88888888.88", True),
        (">", "'1.23'", "1.23", True),
        (">", "'abcd'", "0", True),
    )
    for comparison, else_branch, expected, as_float in cases:
        sql = (
            f"select case when cast(1 as decimal(10, 4)) {comparison} 1 "
            f"then cast(88888888.88 as decimal(10,2)) else {else_branch} end"
        )
        res = TaosShell().query(sql)[0]
        if as_float:
            mismatch = float(res[0]) != float(expected)
        else:
            mismatch = res[0] != expected
        if mismatch:
            tdLog.exit(f"query result for sql: {sql}: {res[0]} not equal to expected result: {expected}")
def test_decimal_agg_funcs(self, dbname, tbname, tb_cols: List[Column], get_agg_funcs_func):
## Run each aggregate function returned by `get_agg_funcs_func()` against the
## columns of `dbname.tbname` and pass the query output to the function's own
## `check_for_agg_func`.
## `get_agg_funcs_func` is a callable invoked once, with no arguments.
agg_funcs: List[DecimalFunction] = get_agg_funcs_func()
for func in agg_funcs:
for col in tb_cols:
## columns with an empty name, and columns the function rejects through
## should_skip_for_decimal, are not queried
if col.name_ == '' or func.should_skip_for_decimal([col]):
continue
func.query_col = col
select_expr = func.generate([col])
sql = f"select {select_expr} from {dbname}.{tbname}"
res = TaosShell().query(sql)
## a non-empty result is narrowed to its first element
## (presumably the first output column -- TODO confirm against TaosShell.query)
if len(res) > 0:
res = res[0]
## the last argument is True only when `tbname` equals self.stable_name
func.check_for_agg_func(res, tbname, func, tbname == self.stable_name)
def test_decimal_cast_func(self, dbname, tbname, tb_cols: List[Column]):
## For each named column of `dbname.tbname`, generate `cast_func_test_round`
## target types with DecimalCastTypeGenerator and verify a cast query for each.
for col in tb_cols:
## columns with an empty name are not queried
if col.name_ == '':
continue
to_types: list[DataType] = DecimalCastTypeGenerator(col.type_).generate(cast_func_test_round)
for t in to_types:
## a new DecimalCastFunction is created for every (column, target type) pair
cast_func = DecimalCastFunction()
expr = cast_func.generate([col, t])
sql = f"select {expr} from {dbname}.{tbname}"
res = TaosShell().query(sql)
## a non-empty result is narrowed to its first element
## (presumably the first output column -- TODO confirm against TaosShell.query)
if len(res) > 0:
res = res[0]
cast_func.check(res, tbname)
def test_decimal_functions(self):
    """Run the decimal function checks unless ``test_decimal_funcs`` is false.

    Aggregate functions are checked on ``norm_table_name`` and then on
    ``stable_name``; the cast function is checked on ``norm_table_name`` only.
    """
    self.log_test("decimal functions")
    if test_decimal_funcs:
        agg_targets = (
            (self.norm_table_name, self.norm_tb_columns),
            (self.stable_name, self.stb_columns),
        )
        for table, columns in agg_targets:
            self.test_decimal_agg_funcs(self.db_name, table, columns, DecimalFunction.get_decimal_agg_funcs)
        self.test_decimal_cast_func(self.db_name, self.norm_table_name, self.norm_tb_columns)
def test_query_decimal(self):
    """Run the decimal query checks of this file, in a fixed order.

    Nothing beyond the log line happens when ``decimal_test_query`` is false.
    """
    self.log_test("decimal query")
    if not decimal_test_query:
        return
    ## not run here:
    ##   self.test_decimal_operators()          -- tested in decimal.py
    ##   self.test_query_decimal_where_clause() -- tested in decimal2.py
    steps = (
        self.test_decimal_functions,
        self.test_query_decimal_order_clause,
        self.test_query_decimal_case_when,
        self.test_query_decimal_group_by_clause,
        self.test_query_decimal_having_clause,
        self.test_query_decimal_interval_fill,
    )
    for step in steps:
        step()
## Module-level event object; it is not referenced in this part of the file
## (presumably used by threaded helpers elsewhere -- TODO confirm).
event = threading.Event()
## Register this file's test case with tdCases for Linux and for Windows;
## each registration is given its own TDTestCase instance.
tdCases.addLinux(__file__, TDTestCase())
tdCases.addWindows(__file__, TDTestCase())