From 80611e553c05e40a4dbe9ce66696a5a9a26bf531 Mon Sep 17 00:00:00 2001 From: zhipz Date: Mon, 15 Jul 2024 14:33:51 +0800 Subject: [PATCH] Fixing the incorrect version of common.py --- tests/pytest/util/common.py | 150 ++++++++++++++++++++++++++++++++---- 1 file changed, 133 insertions(+), 17 deletions(-) diff --git a/tests/pytest/util/common.py b/tests/pytest/util/common.py index 412ce22545..38d8502572 100644 --- a/tests/pytest/util/common.py +++ b/tests/pytest/util/common.py @@ -139,7 +139,7 @@ class TDCom: self.stream_suffix = "_stream" self.range_count = 5 self.default_interval = 5 - self.stream_timeout = 12 + self.stream_timeout = 60 self.create_stream_sleep = 0.5 self.record_history_ts = str() self.precision = "ms" @@ -201,6 +201,9 @@ class TDCom: self.cast_tag_stb_filter_des_select_elm = "ts, t1, t2, t3, t4, cast(t1 as TINYINT UNSIGNED), t6, t7, t8, t9, t10, cast(t2 as varchar(256)), t12, cast(t3 as bool)" self.tag_count = len(self.tag_filter_des_select_elm.split(",")) self.state_window_range = list() + + self.custom_col_val = 0 + self.part_val_list = [1, 2] # def init(self, conn, logSql): # # tdSql.init(conn.cursor(), logSql) @@ -506,7 +509,22 @@ class TDCom: if ("packaging" not in rootRealPath): buildPath = root[:len(root) - len("/build/bin")] break + if platform.system().lower() == 'windows': + win_sep = "\\" + buildPath = buildPath.replace(win_sep,'/') + return buildPath + + def getTaosdPath(self, dnodeID="dnode1"): + buildPath = self.getBuildPath() + if (buildPath == ""): + tdLog.exit("taosd not found!") + else: + tdLog.info("taosd found in %s" % buildPath) + taosdPath = buildPath + "/../sim/" + dnodeID + tdLog.info("taosdPath: %s" % taosdPath) + return taosdPath + def getClientCfgPath(self): buildPath = self.getBuildPath() @@ -519,21 +537,21 @@ class TDCom: tdLog.info("cfgPath: %s" % cfgPath) return cfgPath - def newcon(self,host='localhost',port=6030,user='root',password='taosdata'): - con=taos.connect(host=host, user=user, password=password, port=port) + def newcon(self,host='localhost',port=6030,user='root',password='taosdata', database='None'): + con=taos.connect(host=host, user=user, password=password, port=port, database=database) # print(con) return con - def newcur(self,host='localhost',port=6030,user='root',password='taosdata', database=None): + def newcur(self,host='localhost',port=6030,user='root',password='taosdata',databse='None'): cfgPath = self.getClientCfgPath() - con=taos.connect(host=host, user=user, password=password, config=cfgPath, port=port, database=None) + con=taos.connect(host=host, user=user, password=password, config=cfgPath, port=port,database='None') cur=con.cursor() # print(cur) return cur - def newTdSql(self, host='localhost',port=6030,user='root',password='taosdata', database=None): + def newTdSql(self, host='localhost',port=6030,user='root',password='taosdata'): newTdSql = TDSql() - cur = self.newcur(host=host,port=port,user=user,password=password, database=None) + cur = self.newcur(host=host,port=port,user=user,password=password) newTdSql.init(cur, False) return newTdSql @@ -1244,7 +1262,7 @@ class TDCom: default_ctbname_index_start_num += 1 tdSql.execute(create_stable_sql) - def sgen_column_value_list(self, column_elm_list, need_null, ts_value=None): + def sgen_column_value_list(self, column_elm_list, need_null, ts_value=None, additional_ts=None, custom_col_index=None, col_value_type=None, force_pk_val=None): """_summary_ Args: @@ -1254,6 +1272,8 @@ class TDCom: """ self.column_value_list = list() self.ts_value = self.genTs()[0] + if 
additional_ts is not None: + self.additional_ts = self.genTs(additional_ts=additional_ts)[2] if ts_value is not None: self.ts_value = ts_value @@ -1277,7 +1297,22 @@ class TDCom: for i in range(int(len(self.column_value_list)/2)): index_num = random.randint(0, len(self.column_value_list)-1) self.column_value_list[index_num] = None - self.column_value_list = [self.ts_value] + self.column_value_list + + if custom_col_index is not None: + if col_value_type == "Random": + pass + elif col_value_type == "Incremental": + self.column_value_list[custom_col_index] = self.custom_col_val + self.custom_col_val += 1 + elif col_value_type == "Part_equal": + self.column_value_list[custom_col_index] = random.choice(self.part_val_list) + + self.column_value_list = [self.ts_value] + [self.additional_ts] + self.column_value_list if additional_ts is not None else [self.ts_value] + self.column_value_list + if col_value_type == "Incremental" and custom_col_index==1: + self.column_value_list[custom_col_index] = self.custom_col_val if force_pk_val is None else force_pk_val + if col_value_type == "Part_equal" and custom_col_index==1: + self.column_value_list[custom_col_index] = random.randint(0, self.custom_col_val) if force_pk_val is None else force_pk_val + def screate_table(self, dbname=None, tbname="tb", use_name="table", column_elm_list=None, count=1, default_tbname_prefix="tb", default_tbname_index_start_num=1, @@ -1318,7 +1353,7 @@ class TDCom: default_tbname_index_start_num += 1 tdSql.execute(create_table_sql) - def sinsert_rows(self, dbname=None, tbname=None, column_ele_list=None, ts_value=None, count=1, need_null=False): + def sinsert_rows(self, dbname=None, tbname=None, column_ele_list=None, ts_value=None, count=1, need_null=False, custom_col_index=None, col_value_type="random"): """insert rows Args: @@ -1338,7 +1373,7 @@ class TDCom: if tbname is not None: self.tbname = tbname - self.sgen_column_value_list(column_ele_list, need_null, ts_value) + self.sgen_column_value_list(column_ele_list, need_null, ts_value, custom_col_index=custom_col_index, col_value_type=col_value_type) # column_value_str = ", ".join(str(v) for v in self.column_value_list) column_value_str = "" for column_value in self.column_value_list: @@ -1355,7 +1390,7 @@ class TDCom: else: for num in range(count): ts_value = self.genTs()[0] - self.sgen_column_value_list(column_ele_list, need_null, f'{ts_value}+{num}s') + self.sgen_column_value_list(column_ele_list, need_null, f'{ts_value}+{num}s', custom_col_index=custom_col_index, col_value_type=col_value_type) column_value_str = "" for column_value in self.column_value_list: if column_value is None: @@ -1673,8 +1708,8 @@ class TDCom: res1 = self.round_handle(res1) res2 = self.round_handle(res2) if latency < self.stream_timeout: - latency += 0.2 - time.sleep(0.2) + latency += 0.5 + time.sleep(0.5) else: if latency == 0: return False @@ -1762,7 +1797,7 @@ class TDCom: self.sdelete_rows(tbname=self.ctb_name, start_ts=self.time_cast(self.record_history_ts, "-")) self.sdelete_rows(tbname=self.tb_name, start_ts=self.time_cast(self.record_history_ts, "-")) - def prepare_data(self, interval=None, watermark=None, session=None, state_window=None, state_window_max=127, interation=3, range_count=None, precision="ms", fill_history_value=0, ext_stb=None): + def prepare_data(self, interval=None, watermark=None, session=None, state_window=None, state_window_max=127, interation=3, range_count=None, precision="ms", fill_history_value=0, ext_stb=None, custom_col_index=None, col_value_type="random"): """prepare 
stream data Args: @@ -1825,11 +1860,89 @@ class TDCom: if fill_history_value == 1: for i in range(self.range_count): ts_value = str(self.date_time)+f'-{self.default_interval*(i+1)}s' - self.sinsert_rows(tbname=self.ctb_name, ts_value=ts_value) - self.sinsert_rows(tbname=self.tb_name, ts_value=ts_value) + self.sinsert_rows(tbname=self.ctb_name, ts_value=ts_value, custom_col_index=custom_col_index, col_value_type=col_value_type) + self.sinsert_rows(tbname=self.tb_name, ts_value=ts_value, custom_col_index=custom_col_index, col_value_type=col_value_type) if i == 1: self.record_history_ts = ts_value + def get_subtable(self, tbname_pre): + tdSql.query(f'show tables') + tbname_list = list(map(lambda x:x[0], tdSql.queryResult)) + for tbname in tbname_list: + if tbname_pre in tbname: + return tbname + + def get_subtable_wait(self, tbname_pre): + tbname = self.get_subtable(tbname_pre) + latency = 0 + while tbname is None: + tbname = self.get_subtable(tbname_pre) + if latency < self.stream_timeout: + latency += 1 + time.sleep(1) + return tbname + + def get_group_id_from_stb(self, stbname): + tdSql.query(f'select distinct group_id from {stbname}') + cnt = 0 + while len(tdSql.queryResult) == 0: + tdSql.query(f'select distinct group_id from {stbname}') + if cnt < self.default_interval: + cnt += 1 + time.sleep(1) + else: + return False + return tdSql.queryResult[0][0] + + def update_json_file_replica(self, json_file_path, new_replica_value, output_file_path=None): + """ + Read a JSON file, update the 'replica' value, and write the result back to a file. + + Parameters: + json_file_path (str): The path to the original JSON file. + new_replica_value (int): The new 'replica' value to be set. + output_file_path (str, optional): The path to the output file where the updated JSON will be saved. + If not provided, the original file will be overwritten. + + Returns: + None + """ + try: + # Read the JSON file and load its content into a Python dictionary + with open(json_file_path, 'r', encoding='utf-8') as file: + data = json.load(file) + + # Iterate over each item in the 'databases' list to find 'dbinfo' and update 'replica' + for db in data['databases']: + if 'dbinfo' in db: + db['dbinfo']['replica'] = new_replica_value + + # Convert the updated dictionary back into a JSON string with indentation for readability + updated_json_str = json.dumps(data, indent=4, ensure_ascii=False) + + # Write the updated JSON string to a file + if output_file_path: + # If an output file path is provided, write to the new file + with open(output_file_path, 'w', encoding='utf-8') as output_file: + output_file.write(updated_json_str) + else: + # Otherwise, overwrite the original file with the updated content + with open(json_file_path, 'w', encoding='utf-8') as file: + file.write(updated_json_str) + + except json.JSONDecodeError as e: + # Handle JSON decoding error (e.g., if the file is not valid JSON) + print(f"JSON decode error: {e}") + except FileNotFoundError: + # Handle the case where the JSON file is not found at the given path + print(f"File not found: {json_file_path}") + except KeyError as e: + # Handle missing key error (e.g., if 'databases' or 'dbinfo' is not present) + print(f"Key error: {e}") + except Exception as e: + # Handle any other exceptions that may occur + print(f"An error occurred: {e}") + def is_json(msg): if isinstance(msg, str): try: @@ -1864,4 +1977,7 @@ def dict2toml(in_dict: dict, file:str): with open(file, 'w') as f: toml.dump(in_dict, f) + + + tdCom = TDCom()
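
Usage note (not part of the patch; a minimal sketch that assumes tests keep using the module-level tdCom instance via the usual util.common import): the new getTaosdPath helper resolves the per-dnode sim directory from the build path, which getBuildPath now normalizes to forward slashes on Windows. "dnode2" and the cfg sub-directory below are illustrative assumptions, not values defined by the patch.

# Illustrative only; getTaosdPath exits via tdLog if the build path cannot be found.
from util.common import tdCom

taosd_path = tdCom.getTaosdPath("dnode2")   # resolves to <buildPath>/../sim/dnode2
cfg_dir = taosd_path + "/cfg"               # hypothetical location a test might hand to taosd -c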
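
Usage note (not part of the patch): a sketch of the new custom_col_index / col_value_type knobs added to sgen_column_value_list, sinsert_rows, and prepare_data. "Incremental" and "Part_equal" are the values the patched code checks for; the default "random" leaves the randomly generated column value in place. The table names and the prepare_data call mirror the helper's existing defaults and are assumptions about the surrounding test setup.

# Assumes prepare_data() creates tdCom's default super/child/normal tables as before.
tdCom.prepare_data(interval=5, custom_col_index=1, col_value_type="Incremental")

# Column index 1 is filled from tdCom.custom_col_val, which increments by one per generated row.
tdCom.sinsert_rows(tbname=tdCom.ctb_name, count=10,
                   custom_col_index=1, col_value_type="Incremental")

# Or restrict the column to the small fixed set tdCom.part_val_list ([1, 2]).
tdCom.sinsert_rows(tbname=tdCom.tb_name, count=10,
                   custom_col_index=1, col_value_type="Part_equal")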
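
Usage note (not part of the patch): a self-contained sketch of the read-modify-write pattern behind the new update_json_file_replica helper. The file name and JSON layout are illustrative assumptions modeled on a taosBenchmark-style config; the helper only requires a top-level "databases" list whose entries may carry a "dbinfo" object.

import json

# Hypothetical config file; only databases[].dbinfo.replica matters to the helper.
sample_cfg = {"databases": [{"dbinfo": {"name": "db", "replica": 1}}]}
with open("insert_cfg.json", "w", encoding="utf-8") as f:
    json.dump(sample_cfg, f, indent=4)

# No output_file_path is given, so the original file is overwritten in place.
tdCom.update_json_file_replica("insert_cfg.json", 3)

with open("insert_cfg.json", "r", encoding="utf-8") as f:
    assert json.load(f)["databases"][0]["dbinfo"]["replica"] == 3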