test decimal operators

This commit is contained in:
wangjiaming0909 2025-02-21 08:51:37 +08:00
parent efe432cff7
commit 1fad80e631
1 changed files with 226 additions and 360 deletions

View File

@ -4,6 +4,7 @@ from re import A
import time import time
import threading import threading
import secrets import secrets
from tkinter.tix import COLUMN
from regex import D from regex import D
from sympy import true from sympy import true
@ -267,23 +268,15 @@ class DataType:
def __repr__(self): def __repr__(self):
return f"DataType({self.type}, {self.length}, {self.type_mod})" return f"DataType({self.type}, {self.length}, {self.type_mod})"
def is_decimal_type(self):
return self.type == TypeEnum.DECIMAL or self.type == TypeEnum.DECIMAL64
def prec(self): def prec(self):
return 0 return 0
def scale(self): def scale(self):
return 0 return 0
def construct_type_value(self, val: str):
if (
self.type == TypeEnum.BINARY
or self.type == TypeEnum.VARCHAR
or self.type == TypeEnum.NCHAR
or self.type == TypeEnum.VARBINARY
):
return f"'{val}'"
else:
return val
## TODO generate NULL, None ## TODO generate NULL, None
def generate_value(self) -> str: def generate_value(self) -> str:
if self.type == TypeEnum.BOOL: if self.type == TypeEnum.BOOL:
@ -321,32 +314,6 @@ class DataType:
def check(self, values, offset: int): def check(self, values, offset: int):
return True return True
@staticmethod
def get_all_type_columns() -> List:
other_types = [
DataType(TypeEnum.BOOL),
DataType(TypeEnum.TINYINT),
DataType(TypeEnum.SMALLINT),
DataType(TypeEnum.INT),
DataType(TypeEnum.BIGINT),
DataType(TypeEnum.FLOAT),
DataType(TypeEnum.DOUBLE),
DataType(TypeEnum.VARCHAR, 255),
DataType(TypeEnum.TIMESTAMP),
DataType(TypeEnum.NCHAR),
DataType(TypeEnum.UTINYINT),
DataType(TypeEnum.USMALLINT),
DataType(TypeEnum.UINT),
DataType(TypeEnum.UBIGINT),
DataType(TypeEnum.JSON),
DataType(TypeEnum.VARBINARY),
DataType(TypeEnum.DECIMAL),
DataType(TypeEnum.BINARY),
DataType(TypeEnum.GEOMETRY),
DataType(TypeEnum.DECIMAL64),
]
return other_types
class DecimalType(DataType): class DecimalType(DataType):
def __init__(self, type, precision: int, scale: int): def __init__(self, type, precision: int, scale: int):
@ -376,8 +343,8 @@ class DecimalType(DataType):
def __str__(self): def __str__(self):
return f"DECIMAL({self.precision_}, {self.scale()})" return f"DECIMAL({self.precision_}, {self.scale()})"
def __eq__(self, other): def __eq__(self, other: DataType):
return self.precision_ == other.precision and self.scale() == other.scale return self.precision_ == other.prec() and self.scale() == other.scale()
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
@ -432,34 +399,97 @@ class DecimalType(DataType):
) )
class Column:
def __init__(self, type: DataType):
self.type_: DataType = type
self.name_: str = ""
def generate_value(self):
return self.type_.generate_value()
def get_type_str(self) -> str:
return str(self.type_)
def set_name(self, name: str):
self.name_ = name
def check(self, values, offset: int):
return self.type_.check(values, offset)
def construct_type_value(self, val: str):
if (
self.type_.type == TypeEnum.BINARY
or self.type_.type == TypeEnum.VARCHAR
or self.type_.type == TypeEnum.NCHAR
or self.type_.type == TypeEnum.VARBINARY
):
return f"'{val}'"
else:
return val
@staticmethod
def get_all_type_columns() -> List:
other_types = [
Column(DataType(TypeEnum.BOOL)),
Column(DataType(TypeEnum.TINYINT)),
Column(DataType(TypeEnum.SMALLINT)),
Column(DataType(TypeEnum.INT)),
Column(DataType(TypeEnum.BIGINT)),
Column(DataType(TypeEnum.FLOAT)),
Column(DataType(TypeEnum.DOUBLE)),
Column(DataType(TypeEnum.VARCHAR, 255)),
Column(DataType(TypeEnum.TIMESTAMP)),
Column(DataType(TypeEnum.NCHAR)),
Column(DataType(TypeEnum.UTINYINT)),
Column(DataType(TypeEnum.USMALLINT)),
Column(DataType(TypeEnum.UINT)),
Column(DataType(TypeEnum.UBIGINT)),
Column(DataType(TypeEnum.JSON)),
Column(DataType(TypeEnum.VARBINARY)),
Column(DataType(TypeEnum.DECIMAL)),
Column(DataType(TypeEnum.BINARY)),
Column(DataType(TypeEnum.GEOMETRY)),
Column(DataType(TypeEnum.DECIMAL64)),
]
return other_types
class DecimalColumnTableCreater: class DecimalColumnTableCreater:
def __init__( def __init__(
self, self,
conn, conn,
dbName: str, dbName: str,
tbName: str, tbName: str,
columns_types: List[DataType], columns: List[Column],
tags_types: List[DataType] = [], tags_cols: List[Column] = [],
col_prefix: str = "c",
tag_prefix: str = "t",
): ):
self.conn = conn self.conn = conn
self.dbName = dbName self.dbName = dbName
self.tbName = tbName self.tbName = tbName
self.tags_types = tags_types self.tags_cols = tags_cols
self.columns_types = columns_types self.columns: List[Column] = columns
self.col_prefix = col_prefix
self.tag_prefix = tag_prefix
def create(self): def create(self):
if len(self.tags_types) > 0: if len(self.tags_cols) > 0:
table = "stable" table = "stable"
else: else:
table = "table" table = "table"
sql = f"create {table} {self.dbName}.{self.tbName} (ts timestamp" sql = f"create {table} {self.dbName}.{self.tbName} (ts timestamp"
for i, column in enumerate(self.columns_types): for i, column in enumerate(self.columns):
sql += f", c{i+1} {column}" tbname = f"{self.col_prefix}{i+1}"
if self.tags_types: sql += f", {tbname} {column.get_type_str()}"
column.set_name(tbname)
if self.tags_cols:
sql += ") tags(" sql += ") tags("
for i, tag in enumerate(self.tags_types): for i, tag in enumerate(self.tags_cols):
sql += f"t{i+1} {tag}" tagname = f"{self.tag_prefix}{i+1}"
if i != len(self.tags_types) - 1: sql += f"{tagname} {tag.get_type_str()}"
tag.set_name(tagname)
if i != len(self.tags_cols) - 1:
sql += ", " sql += ", "
sql += ")" sql += ")"
self.conn.execute(sql, queryTimes=1) self.conn.execute(sql, queryTimes=1)
@ -468,14 +498,15 @@ class DecimalColumnTableCreater:
self, self,
ctbPrefix: str, ctbPrefix: str,
ctbNum: int, ctbNum: int,
tag_types: List[DataType], tag_cols: List[Column],
tag_values: List[str], tag_values: List[str],
): ):
for i in range(ctbNum): for i in range(ctbNum):
sql = f"create table {self.dbName}.{ctbPrefix}{i} using {self.dbName}.{self.tbName} tags(" tbname = f"{ctbPrefix}{i}"
for j, tag in enumerate(tag_types): sql = f"create table {self.dbName}.{tbname} using {self.dbName}.{self.tbName} tags("
for j, tag in enumerate(tag_cols):
sql += f"{tag.construct_type_value(tag_values[j])}" sql += f"{tag.construct_type_value(tag_values[j])}"
if j != len(tag_types) - 1: if j != len(tag_cols) - 1:
sql += ", " sql += ", "
sql += ")" sql += ")"
self.conn.execute(sql, queryTimes=1) self.conn.execute(sql, queryTimes=1)
@ -487,21 +518,21 @@ class TableInserter:
conn, conn,
dbName: str, dbName: str,
tbName: str, tbName: str,
columns_types: List[DataType], columns: List[Column],
tags_types: List[DataType] = [], tags_cols: List[Column] = [],
): ):
self.conn = conn self.conn = conn
self.dbName = dbName self.dbName = dbName
self.tbName = tbName self.tbName = tbName
self.tags_types = tags_types self.tag_cols = tags_cols
self.columns_types = columns_types self.columns = columns
def insert(self, rows: int, start_ts: int, step: int, flush_database: bool = False): def insert(self, rows: int, start_ts: int, step: int, flush_database: bool = False):
pre_insert = f"insert into {self.dbName}.{self.tbName} values" pre_insert = f"insert into {self.dbName}.{self.tbName} values"
sql = pre_insert sql = pre_insert
for i in range(rows): for i in range(rows):
sql += f"({start_ts + i * step}" sql += f"({start_ts + i * step}"
for column in self.columns_types: for column in self.columns:
sql += f", {column.generate_value()}" sql += f", {column.generate_value()}"
sql += ")" sql += ")"
if i != rows - 1: if i != rows - 1:
@ -521,9 +552,7 @@ class TableInserter:
class TableDataValidator: class TableDataValidator:
def __init__( def __init__(self, columns: List[Column], tbName: str, dbName: str, tbIdx: int = 0):
self, columns: List[DataType], tbName: str, dbName: str, tbIdx: int = 0
):
self.columns = columns self.columns = columns
self.tbName = tbName self.tbName = tbName
self.dbName = dbName self.dbName = dbName
@ -535,25 +564,29 @@ class TableDataValidator:
row_num = len(res[1]) row_num = len(res[1])
colIdx = 1 colIdx = 1
for col in self.columns: for col in self.columns:
if col.type == TypeEnum.DECIMAL or col.type == TypeEnum.DECIMAL64: if (
col.type_.type == TypeEnum.DECIMAL
or col.type_.type == TypeEnum.DECIMAL64
):
col.check(res[colIdx], row_num * self.tbIdx) col.check(res[colIdx], row_num * self.tbIdx)
colIdx += 1 colIdx += 1
class DecimalColumnExpr: class DecimalColumnExpr:
def __init__(self, column: DecimalType): def __init__(self, format: str, executor):
self.column_ = column self.format_: str = ""
self.executor_ = None
def execute(self): def execute(self, params: List):
pass return self.executor_(params)
def generate(self): def generate(self, format_params) -> str:
pass return f"({self.format_ % format_params})"
class DecimalBinaryOperator(DecimalColumnExpr): class DecimalBinaryOperator(DecimalColumnExpr):
def __init__(self, op: str, column: DecimalType): def __init__(self, op: str):
super().__init__(column) super().__init__()
self.op_ = op self.op_ = op
def __str__(self): def __str__(self):
@ -562,6 +595,10 @@ class DecimalBinaryOperator(DecimalColumnExpr):
def generate(self): def generate(self):
pass pass
@staticmethod
def execute_plus(params):
return params[0] + params[1]
def execute(self, left, right): def execute(self, left, right):
if self.op_ == "+": if self.op_ == "+":
return left + right return left + right
@ -617,7 +654,9 @@ class TDTestCase:
self.tags = [] self.tags = []
self.stable_name = "meters" self.stable_name = "meters"
self.norm_table_name = "nt" self.norm_table_name = "nt"
self.col_prefix = "c"
self.c_table_prefix = "t" self.c_table_prefix = "t"
self.tag_name_prefix = "t"
self.db_name = "test" self.db_name = "test"
self.c_table_num = 10 self.c_table_num = 10
self.no_decimal_col_tb_name = "tt" self.no_decimal_col_tb_name = "tt"
@ -632,241 +671,19 @@ class TDTestCase:
tdLog.debug(f"start to excute {__file__}") tdLog.debug(f"start to excute {__file__}")
tdSql.init(conn.cursor(), False) tdSql.init(conn.cursor(), False)
def create_database(
self, tsql, dbName, dropFlag=1, vgroups=2, replica=1, duration: str = "1d"
):
if dropFlag == 1:
tsql.execute("drop database if exists %s" % (dbName))
tsql.execute(
"create database if not exists %s vgroups %d replica %d duration %s"
% (dbName, vgroups, replica, duration)
)
tdLog.debug("complete to create database %s" % (dbName))
def create_stable(self, tsql, paraDict):
colString = tdCom.gen_column_type_str(
colname_prefix=paraDict["colPrefix"], column_elm_list=paraDict["colSchema"]
)
tagString = tdCom.gen_tag_type_str(
tagname_prefix=paraDict["tagPrefix"], tag_elm_list=paraDict["tagSchema"]
)
sqlString = f"create table if not exists %s.%s (%s) tags (%s)" % (
paraDict["dbName"],
paraDict["stbName"],
colString,
tagString,
)
tdLog.debug("%s" % (sqlString))
tsql.execute(sqlString)
def create_ctable(
self,
tsql=None,
dbName="dbx",
stbName="stb",
ctbPrefix="ctb",
ctbNum=1,
ctbStartIdx=0,
):
for i in range(ctbNum):
sqlString = (
"create table %s.%s%d using %s.%s tags(%d, 'tb%d', 'tb%d', %d, %d, %d)"
% (
dbName,
ctbPrefix,
i + ctbStartIdx,
dbName,
stbName,
(i + ctbStartIdx) % 5,
i + ctbStartIdx + random.randint(1, 100),
i + ctbStartIdx + random.randint(1, 100),
i + ctbStartIdx + random.randint(1, 100),
i + ctbStartIdx + random.randint(1, 100),
i + ctbStartIdx + random.randint(1, 100),
)
)
tsql.execute(sqlString)
tdLog.debug(
"complete to create %d child tables by %s.%s" % (ctbNum, dbName, stbName)
)
def init_normal_tb(
self, tsql, db_name: str, tb_name: str, rows: int, start_ts: int, ts_step: int
):
sql = (
"CREATE TABLE %s.%s (ts timestamp, c1 INT, c2 INT, c3 INT, c4 double, c5 VARCHAR(255))"
% (db_name, tb_name)
)
tsql.execute(sql)
sql = "INSERT INTO %s.%s values" % (db_name, tb_name)
for j in range(rows):
sql += f'(%d, %d,%d,%d,{random.random()},"varchar_%d"),' % (
start_ts + j * ts_step + randrange(500),
j % 10 + randrange(200),
j % 10,
j % 10,
j % 10 + randrange(100),
)
tsql.execute(sql)
def insert_data(
self, tsql, dbName, ctbPrefix, ctbNum, rowsPerTbl, batchNum, startTs, tsStep
):
tdLog.debug("start to insert data ............")
tsql.execute("use %s" % dbName)
pre_insert = "insert into "
sql = pre_insert
for i in range(ctbNum):
rowsBatched = 0
sql += " %s.%s%d values " % (dbName, ctbPrefix, i)
for j in range(rowsPerTbl):
if i < ctbNum / 2:
sql += "(%d, %d, %d, %d,%d,%d,%d,true,'binary%d', 'nchar%d') " % (
startTs + j * tsStep + randrange(500),
j % 10 + randrange(100),
j % 10 + randrange(200),
j % 10,
j % 10,
j % 10,
j % 10,
j % 10,
j % 10,
)
else:
sql += (
"(%d, %d, NULL, %d,NULL,%d,%d,true,'binary%d', 'nchar%d') "
% (
startTs + j * tsStep + randrange(500),
j % 10,
j % 10,
j % 10,
j % 10,
j % 10,
j % 10,
)
)
rowsBatched += 1
if (rowsBatched == batchNum) or (j == rowsPerTbl - 1):
tsql.execute(sql)
rowsBatched = 0
if j < rowsPerTbl - 1:
sql = "insert into %s.%s%d values " % (dbName, ctbPrefix, i)
else:
sql = "insert into "
if sql != pre_insert:
tsql.execute(sql)
tdLog.debug("insert data ............ [OK]")
return
def init_data(
self,
db: str = "test",
ctb_num: int = 10,
rows_per_ctb: int = 10000,
start_ts: int = 1537146000000,
ts_step: int = 500,
):
tdLog.printNoPrefix(
"======== prepare test env include database, stable, ctables, and insert data: "
)
paraDict = {
"dbName": db,
"dropFlag": 1,
"vgroups": 4,
"stbName": "meters",
"colPrefix": "c",
"tagPrefix": "t",
"colSchema": [
{"type": "INT", "count": 1},
{"type": "BIGINT", "count": 1},
{"type": "FLOAT", "count": 1},
{"type": "DOUBLE", "count": 1},
{"type": "smallint", "count": 1},
{"type": "tinyint", "count": 1},
{"type": "bool", "count": 1},
{"type": "binary", "len": 10, "count": 1},
{"type": "nchar", "len": 10, "count": 1},
],
"tagSchema": [
{"type": "INT", "count": 1},
{"type": "nchar", "len": 20, "count": 1},
{"type": "binary", "len": 20, "count": 1},
{"type": "BIGINT", "count": 1},
{"type": "smallint", "count": 1},
{"type": "DOUBLE", "count": 1},
],
"ctbPrefix": "t",
"ctbStartIdx": 0,
"ctbNum": ctb_num,
"rowsPerTbl": rows_per_ctb,
"batchNum": 3000,
"startTs": start_ts,
"tsStep": ts_step,
}
paraDict["vgroups"] = self.vgroups
paraDict["ctbNum"] = ctb_num
paraDict["rowsPerTbl"] = rows_per_ctb
tdLog.info("create database")
self.create_database(
tsql=tdSql,
dbName=paraDict["dbName"],
dropFlag=paraDict["dropFlag"],
vgroups=paraDict["vgroups"],
replica=self.replicaVar,
duration=self.duraion,
)
tdLog.info("create stb")
self.create_stable(tsql=tdSql, paraDict=paraDict)
tdLog.info("create child tables")
self.create_ctable(
tsql=tdSql,
dbName=paraDict["dbName"],
stbName=paraDict["stbName"],
ctbPrefix=paraDict["ctbPrefix"],
ctbNum=paraDict["ctbNum"],
ctbStartIdx=paraDict["ctbStartIdx"],
)
self.insert_data(
tsql=tdSql,
dbName=paraDict["dbName"],
ctbPrefix=paraDict["ctbPrefix"],
ctbNum=paraDict["ctbNum"],
rowsPerTbl=paraDict["rowsPerTbl"],
batchNum=paraDict["batchNum"],
startTs=paraDict["startTs"],
tsStep=paraDict["tsStep"],
)
self.init_normal_tb(
tdSql,
paraDict["dbName"],
"norm_tb",
paraDict["rowsPerTbl"],
paraDict["startTs"],
paraDict["tsStep"],
)
def check_desc_for_one_ctb( def check_desc_for_one_ctb(
self, ctbPrefix: str, columns: List[DataType], tags: List[DataType] = [] self, ctbPrefix: str, columns: List[Column], tags: List[Column] = []
): ):
ctb_idx = randrange(self.c_table_num) ctb_idx = randrange(self.c_table_num)
return self.check_desc(f"{ctbPrefix}{ctb_idx}", columns, tags) return self.check_desc(f"{ctbPrefix}{ctb_idx}", columns, tags)
def check_desc( def check_desc(self, tbname: str, columns: List[Column], tags: List[Column] = []):
self, tbname: str, column_types: List[DataType], tag_types: List[DataType] = []
):
sql = f"desc {self.db_name}.{tbname}" sql = f"desc {self.db_name}.{tbname}"
tdSql.query(sql, queryTimes=1) tdSql.query(sql, queryTimes=1)
results = tdSql.queryResult results = tdSql.queryResult
for i, column_type in enumerate(column_types): for i, col in enumerate(columns):
if column_type.type == TypeEnum.DECIMAL: if col.type_.type == TypeEnum.DECIMAL:
if results[i + 1][1] != column_type.__str__(): if results[i + 1][1] != col.type_.__str__():
tdLog.info(str(results)) tdLog.info(str(results))
tdLog.exit( tdLog.exit(
f"check desc failed for table: {tbname} column {results[i+1][0]} type is {results[i+1][1]}, expect DECIMAL" f"check desc failed for table: {tbname} column {results[i+1][0]} type is {results[i+1][1]}, expect DECIMAL"
@ -880,45 +697,45 @@ class TDTestCase:
f"check desc failed for table: {tbname} column {results[i+1][0]} compression is {results[i+1][4]}, expect {DecimalType.default_compression()}" f"check desc failed for table: {tbname} column {results[i+1][0]} compression is {results[i+1][4]}, expect {DecimalType.default_compression()}"
) )
if tbname == self.stable_name: if tbname == self.stable_name:
self.check_desc_for_one_ctb(self.c_table_prefix, column_types, tag_types) self.check_desc_for_one_ctb(self.c_table_prefix, columns, tags)
def check_show_create_table( def check_show_create_table(
self, tbname: str, column_types: List[DataType], tag_types: List[DataType] = [] self, tbname: str, cols: List[Column], tags: List[Column] = []
): ):
sql = f"show create table {self.db_name}.{tbname}" sql = f"show create table {self.db_name}.{tbname}"
tdSql.query(sql, queryTimes=1) tdSql.query(sql, queryTimes=1)
create_table_sql = tdSql.queryResult[0][1] create_table_sql = tdSql.queryResult[0][1]
decimal_idx = 0 decimal_idx = 0
results = re.findall(r"DECIMAL\((\d+),(\d+)\)", create_table_sql) results = re.findall(r"DECIMAL\((\d+),(\d+)\)", create_table_sql)
for i, column_type in enumerate(column_types): for i, col in enumerate(cols):
if ( if (
column_type.type == TypeEnum.DECIMAL col.type_.type == TypeEnum.DECIMAL
or column_type.type == TypeEnum.DECIMAL64 or col.type_.type == TypeEnum.DECIMAL64
): ):
result_type = DecimalType( result_type = DecimalType(
column_type.type, col.type_.type,
int(results[decimal_idx][0]), int(results[decimal_idx][0]),
int(results[decimal_idx][1]), int(results[decimal_idx][1]),
) )
if result_type != column_type: if result_type != col.type_:
tdLog.exit( tdLog.exit(
f"check show create table failed for: {tbname} column {i} type is {result_type}, expect {column_type.get_decimal_type()}" f"check show create table failed for: {tbname} column {i} type is {result_type}, expect {col.type}"
) )
decimal_idx += 1 decimal_idx += 1
def test_add_drop_columns_with_decimal(self, tbname: str, columns: List[DataType]): def test_add_drop_columns_with_decimal(self, tbname: str, columns: List[Column]):
is_stb = tbname == self.stable_name is_stb = tbname == self.stable_name
## alter table add column ## alter table add column
create_c99_sql = ( create_c99_sql = (
f"alter table {self.db_name}.{tbname} add column c99 decimal(37, 19)" f"alter table {self.db_name}.{tbname} add column c99 decimal(37, 19)"
) )
columns.append(DecimalType(TypeEnum.DECIMAL, 37, 19)) columns.append(Column(DecimalType(TypeEnum.DECIMAL, 37, 19)))
tdSql.execute(create_c99_sql, queryTimes=1, show=True) tdSql.execute(create_c99_sql, queryTimes=1, show=True)
self.check_desc(tbname, columns) self.check_desc(tbname, columns)
## alter table add column with compression ## alter table add column with compression
create_c100_sql = f'ALTER TABLE {self.db_name}.{tbname} ADD COLUMN c100 decimal(36, 18) COMPRESS "zstd"' create_c100_sql = f'ALTER TABLE {self.db_name}.{tbname} ADD COLUMN c100 decimal(36, 18) COMPRESS "zstd"'
tdSql.execute(create_c100_sql, queryTimes=1, show=True) tdSql.execute(create_c100_sql, queryTimes=1, show=True)
columns.append(DecimalType(TypeEnum.DECIMAL, 36, 18)) columns.append(Column(DecimalType(TypeEnum.DECIMAL, 36, 18)))
self.check_desc(tbname, columns) self.check_desc(tbname, columns)
## drop non decimal column ## drop non decimal column
@ -947,33 +764,42 @@ class TDTestCase:
def test_decimal_column_ddl(self): def test_decimal_column_ddl(self):
## create decimal type table, normal/super table, decimal64/decimal128 ## create decimal type table, normal/super table, decimal64/decimal128
tdLog.printNoPrefix("-------- test create decimal column") tdLog.printNoPrefix("-------- test create decimal column")
self.norm_tb_columns = [ self.norm_tb_columns: List[Column] = [
DecimalType(TypeEnum.DECIMAL, 10, 2), Column(DecimalType(TypeEnum.DECIMAL, 10, 2)),
DecimalType(TypeEnum.DECIMAL, 20, 4), Column(DecimalType(TypeEnum.DECIMAL, 20, 4)),
DecimalType(TypeEnum.DECIMAL, 30, 8), Column(DecimalType(TypeEnum.DECIMAL, 30, 8)),
DecimalType(TypeEnum.DECIMAL, 38, 10), Column(DecimalType(TypeEnum.DECIMAL, 38, 10)),
DataType(TypeEnum.TINYINT), Column(DataType(TypeEnum.TINYINT)),
DataType(TypeEnum.INT), Column(DataType(TypeEnum.INT)),
DataType(TypeEnum.BIGINT), Column(DataType(TypeEnum.BIGINT)),
DataType(TypeEnum.DOUBLE), Column(DataType(TypeEnum.DOUBLE)),
DataType(TypeEnum.FLOAT), Column(DataType(TypeEnum.FLOAT)),
DataType(TypeEnum.VARCHAR, 255), Column(DataType(TypeEnum.VARCHAR, 255)),
] ]
self.tags = [DataType(TypeEnum.INT), DataType(TypeEnum.VARCHAR, 255)] self.tags: List[Column] = [
self.stb_columns = [ Column(DataType(TypeEnum.INT)),
DecimalType(TypeEnum.DECIMAL, 10, 2), Column(DataType(TypeEnum.VARCHAR, 255)),
DecimalType(TypeEnum.DECIMAL, 20, 4), ]
DecimalType(TypeEnum.DECIMAL, 30, 8), self.stb_columns: List[Column] = [
DecimalType(TypeEnum.DECIMAL, 38, 10), Column(DecimalType(TypeEnum.DECIMAL, 10, 2)),
DataType(TypeEnum.TINYINT), Column(DecimalType(TypeEnum.DECIMAL, 20, 4)),
DataType(TypeEnum.INT), Column(DecimalType(TypeEnum.DECIMAL, 30, 8)),
DataType(TypeEnum.BIGINT), Column(DecimalType(TypeEnum.DECIMAL, 38, 10)),
DataType(TypeEnum.DOUBLE), Column(DataType(TypeEnum.TINYINT)),
DataType(TypeEnum.FLOAT), Column(DataType(TypeEnum.INT)),
DataType(TypeEnum.VARCHAR, 255), Column(DataType(TypeEnum.BIGINT)),
Column(DataType(TypeEnum.DOUBLE)),
Column(DataType(TypeEnum.FLOAT)),
Column(DataType(TypeEnum.VARCHAR, 255)),
] ]
DecimalColumnTableCreater( DecimalColumnTableCreater(
tdSql, self.db_name, self.stable_name, self.stb_columns, self.tags tdSql,
self.db_name,
self.stable_name,
self.stb_columns,
self.tags,
col_prefix=self.col_prefix,
tag_prefix=self.tag_name_prefix,
).create() ).create()
self.check_show_create_table("meters", self.stb_columns, self.tags) self.check_show_create_table("meters", self.stb_columns, self.tags)
@ -1075,12 +901,12 @@ class TDTestCase:
def no_decimal_table_test(self): def no_decimal_table_test(self):
columns = [ columns = [
DataType(TypeEnum.TINYINT), Column(DataType(TypeEnum.TINYINT)),
DataType(TypeEnum.INT), Column(DataType(TypeEnum.INT)),
DataType(TypeEnum.BIGINT), Column(DataType(TypeEnum.BIGINT)),
DataType(TypeEnum.DOUBLE), Column(DataType(TypeEnum.DOUBLE)),
DataType(TypeEnum.FLOAT), Column(DataType(TypeEnum.FLOAT)),
DataType(TypeEnum.VARCHAR, 255), Column(DataType(TypeEnum.VARCHAR, 255)),
] ]
DecimalColumnTableCreater( DecimalColumnTableCreater(
tdSql, self.db_name, self.no_decimal_col_tb_name, columns, [] tdSql, self.db_name, self.no_decimal_col_tb_name, columns, []
@ -1095,7 +921,7 @@ class TDTestCase:
## Create table with no decimal type, the metaentries should not have extschma, and add decimal column, the metaentries should have extschema for all columns. ## Create table with no decimal type, the metaentries should not have extschma, and add decimal column, the metaentries should have extschema for all columns.
sql = f"ALTER TABLE {self.db_name}.{self.no_decimal_col_tb_name} ADD COLUMN c200 decimal(37, 19)" sql = f"ALTER TABLE {self.db_name}.{self.no_decimal_col_tb_name} ADD COLUMN c200 decimal(37, 19)"
tdSql.execute(sql, queryTimes=1) ## now meta entry has ext schemas tdSql.execute(sql, queryTimes=1) ## now meta entry has ext schemas
columns.append(DecimalType(TypeEnum.DECIMAL, 37, 19)) columns.append(Column(DecimalType(TypeEnum.DECIMAL, 37, 19)))
self.check_desc(self.no_decimal_col_tb_name, columns) self.check_desc(self.no_decimal_col_tb_name, columns)
## After drop this only decimal column, the metaentries should not have extschema for all columns. ## After drop this only decimal column, the metaentries should not have extschema for all columns.
@ -1107,7 +933,7 @@ class TDTestCase:
self.check_desc(self.no_decimal_col_tb_name, columns) self.check_desc(self.no_decimal_col_tb_name, columns)
sql = f"ALTER TABLE {self.db_name}.{self.no_decimal_col_tb_name} ADD COLUMN c200 int" sql = f"ALTER TABLE {self.db_name}.{self.no_decimal_col_tb_name} ADD COLUMN c200 int"
tdSql.execute(sql, queryTimes=1) ## meta entry has no ext schemas tdSql.execute(sql, queryTimes=1) ## meta entry has no ext schemas
columns.append(DataType(TypeEnum.INT)) columns.append(Column(DataType(TypeEnum.INT)))
self.check_desc(self.no_decimal_col_tb_name, columns) self.check_desc(self.no_decimal_col_tb_name, columns)
self.test_add_drop_columns_with_decimal(self.no_decimal_col_tb_name, columns) self.test_add_drop_columns_with_decimal(self.no_decimal_col_tb_name, columns)
@ -1137,9 +963,9 @@ class TDTestCase:
self.test_decimal_ddl() self.test_decimal_ddl()
self.no_decimal_table_test() self.no_decimal_table_test()
self.test_insert_decimal_values() self.test_insert_decimal_values()
self.test_query_query() self.test_query_decimal()
self.test_decimal_and_stream() ##self.test_decimal_and_stream()
self.test_decimal_and_tsma() ##self.test_decimal_and_tsma()
def stop(self): def stop(self):
tdSql.close() tdSql.close()
@ -1158,36 +984,76 @@ class TDTestCase:
f"wait query result timeout for {sql} failed after {times} time, expect {expect_result}, but got {results}" f"wait query result timeout for {sql} failed after {times} time, expect {expect_result}, but got {results}"
) )
def a(self, dbname, tbname, columns, operator: DecimalColumnExpr): def check_decimal_binary_expr_results(
self,
dbname,
tbname,
tb_cols: List[Column],
constant_cols: List[Column],
exprs: List[DecimalColumnExpr],
):
for expr in exprs:
for col in tb_cols:
left_is_decimal = col.type_.is_decimal_type()
for const_val in constant_cols:
right_is_decimal = const_val.type_.is_decimal_type()
if not left_is_decimal and not right_is_decimal:
continue
select_expr = expr.generate((col.name_, const_val.generate_value()))
sql = f"select {select_expr} from {dbname}.{tbname}"
res = TaosShell().query(sql)
## query ## query
## build expr, operator.generate() ## build expr, expr.generate(column) to generate sql expr
## pass this expr into DataValidator.
# When validating between query results and local values, pass the column data into the Expr, and invoke expr.execute
## get result ## get result
## check result ## check result
##
pass ## test others unsupported types operator with decimal
def test_decimal_unsupported_types(self):
unsupported_type_cols = [
Column(DataType(TypeEnum.JSON)),
Column(DataType(TypeEnum.GEOMETRY)),
Column(DataType(TypeEnum.VARBINARY)),
]
all_type_columns = Column.get_all_type_columns()
tbname = "test_decimal_unsupported_types"
DecimalColumnTableCreater(
tdSql, self.db_name, tbname, all_type_columns
).create()
tdSql.error(
f"select c17 + c15 from {self.db_name}{tbname}", queryTimes=1, show=True
)
def test_decimal_operators(self): def test_decimal_operators(self):
tdLog.debug("start to test decimal operators") tdLog.debug("start to test decimal operators")
## tables: meters, nt ## tables: meters, nt
## columns: c1, c2, c3, c4, c5, c7, c8, c9, c10, c99, c100 ## columns: c1, c2, c3, c4, c5, c7, c8, c9, c10, c99, c100
binary_operators = [ binary_operators = [
DecimalBinaryOperator("+"), DecimalColumnExpr("%s + %s", DecimalBinaryOperator.execute_plus),
DecimalBinaryOperator("-"), DecimalColumnExpr("-"),
DecimalBinaryOperator("*"), DecimalColumnExpr("*"),
DecimalBinaryOperator("/"), DecimalColumnExpr("/"),
DecimalBinaryOperator("%"), DecimalColumnExpr("%"),
DecimalBinaryOperator(">"), DecimalColumnExpr(">"),
DecimalBinaryOperator("<"), DecimalColumnExpr("<"),
DecimalBinaryOperator(">="), DecimalColumnExpr(">="),
DecimalBinaryOperator("<="), DecimalColumnExpr("<="),
DecimalBinaryOperator("=="), DecimalColumnExpr("=="),
DecimalBinaryOperator("!="), DecimalColumnExpr("!="),
DecimalBinaryOperatorIn("in"), DecimalBinaryOperatorIn("in"),
DecimalBinaryOperatorIn("not in"), DecimalBinaryOperatorIn("not in"),
] ]
all_type_columns = DataType.get_all_type_columns() all_type_columns = Column.get_all_type_columns()
## decimal operator with constants of all other types ## decimal operator with constants of all other types
self.check_decimal_binary_expr_results(
self.db_name,
self.norm_table_name,
self.norm_tb_columns,
all_type_columns,
binary_operators,
)
## decimal operator with columns of all other types ## decimal operator with columns of all other types