diff --git a/tests/army/test.py b/tests/army/test.py index d37d08b406..6ac0948b7b 100644 --- a/tests/army/test.py +++ b/tests/army/test.py @@ -24,7 +24,7 @@ import platform import socket import threading import importlib - +import ast import toml from frame.log import * @@ -56,6 +56,17 @@ def checkRunTimeError(): if hwnd: os.system("TASKKILL /F /IM taosd.exe") +def get_local_classes_in_order(file_path): + with open(file_path, "r", encoding="utf-8") as file: + tree = ast.parse(file.read(), filename=file_path) + + classes = [node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)] + return classes + +def dynamicLoadModule(fileName): + moduleName = fileName.replace(".py", "").replace(os.sep, ".") + return importlib.import_module(moduleName, package='..') + # # run case on previous cluster # @@ -66,9 +77,11 @@ def runOnPreviousCluster(host, config, fileName): sep = "/" if platform.system().lower() == 'windows': sep = os.sep - moduleName = fileName.replace(".py", "").replace(sep, ".") - uModule = importlib.import_module(moduleName) - case = uModule.TDTestCase() + + uModule = dynamicLoadModule(fileName) + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) + case = case_class() # create conn conn = taos.connect(host, config) @@ -358,10 +371,11 @@ if __name__ == "__main__": updateCfgDictStr = '' # adapter_cfg_dict_str = '' if is_test_framework: - moduleName = fileName.replace(".py", "").replace(os.sep, ".") - uModule = importlib.import_module(moduleName) + uModule = dynamicLoadModule(fileName) try: - ucase = uModule.TDTestCase() + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) + ucase = case_class() if ((json.dumps(updateCfgDict) == '{}') and hasattr(ucase, 'updatecfgDict')): updateCfgDict = ucase.updatecfgDict updateCfgDictStr = "-d %s"%base64.b64encode(json.dumps(updateCfgDict).encode()).decode() @@ -530,10 +544,11 @@ if __name__ == "__main__": except: pass if is_test_framework: - moduleName = fileName.replace(".py", "").replace("/", ".") - uModule = importlib.import_module(moduleName) + uModule = dynamicLoadModule(fileName) try: - ucase = uModule.TDTestCase() + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) + ucase = case_class() if (json.dumps(updateCfgDict) == '{}'): updateCfgDict = ucase.updatecfgDict if (json.dumps(adapter_cfg_dict) == '{}'): diff --git a/tests/develop-test/test.py b/tests/develop-test/test.py index 6b1c63a1c0..3525fd6332 100644 --- a/tests/develop-test/test.py +++ b/tests/develop-test/test.py @@ -22,6 +22,9 @@ import json import platform import socket import threading +import ast +import importlib +import os import toml @@ -56,6 +59,17 @@ def checkRunTimeError(): if hwnd: os.system("TASKKILL /F /IM taosd.exe") +def get_local_classes_in_order(file_path): + with open(file_path, "r", encoding="utf-8") as file: + tree = ast.parse(file.read(), filename=file_path) + + classes = [node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)] + return classes + +def dynamicLoadModule(fileName): + moduleName = fileName.replace(".py", "").replace(os.sep, ".") + return importlib.import_module(moduleName, package='..') + if __name__ == "__main__": @@ -295,10 +309,11 @@ if __name__ == "__main__": updateCfgDictStr = "" # adapter_cfg_dict_str = '' if is_test_framework: - moduleName = fileName.replace(".py", "").replace(os.sep, ".") - uModule = importlib.import_module(moduleName) + uModule = dynamicLoadModule(fileName) try: - ucase = uModule.TDTestCase() + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) + ucase = case_class() if (json.dumps(updateCfgDict) == "{}") and hasattr( ucase, "updatecfgDict" ): @@ -434,10 +449,11 @@ if __name__ == "__main__": except: pass if is_test_framework: - moduleName = fileName.replace(".py", "").replace("/", ".") - uModule = importlib.import_module(moduleName) + uModule = dynamicLoadModule(fileName) try: - ucase = uModule.TDTestCase() + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) + ucase = case_class() if json.dumps(updateCfgDict) == "{}": updateCfgDict = ucase.updatecfgDict if json.dumps(adapter_cfg_dict) == "{}": diff --git a/tests/pytest/test.py b/tests/pytest/test.py index 1b185ef189..fa91d20c00 100644 --- a/tests/pytest/test.py +++ b/tests/pytest/test.py @@ -19,6 +19,7 @@ import subprocess import time from distutils.log import warn as printf import platform +import ast from util.log import * from util.dnodes import * @@ -26,6 +27,17 @@ from util.cases import * import taos +def get_local_classes_in_order(file_path): + with open(file_path, "r", encoding="utf-8") as file: + tree = ast.parse(file.read(), filename=file_path) + + classes = [node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)] + return classes + +def dynamicLoadModule(fileName): + moduleName = fileName.replace(".py", "").replace(os.sep, ".") + return importlib.import_module(moduleName, package='..') + if __name__ == "__main__": fileName = "all" @@ -136,10 +148,11 @@ if __name__ == "__main__": except: pass if is_test_framework: - moduleName = fileName.replace(".py", "").replace(os.sep, ".") - uModule = importlib.import_module(moduleName) + uModule = dynamicLoadModule(fileName) try: - ucase = uModule.TDTestCase() + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) + ucase = case_class() tdDnodes.deploy(1,ucase.updatecfgDict) except : tdDnodes.deploy(1,{}) @@ -170,10 +183,11 @@ if __name__ == "__main__": except: pass if is_test_framework: - moduleName = fileName.replace(".py", "").replace("/", ".") - uModule = importlib.import_module(moduleName) + uModule = dynamicLoadModule(fileName) try: - ucase = uModule.TDTestCase() + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) + ucase = case_class() tdDnodes.deploy(1,ucase.updatecfgDict) except : tdDnodes.deploy(1,{}) diff --git a/tests/pytest/util/cases.py b/tests/pytest/util/cases.py index eee8809ad0..ac9b97a874 100644 --- a/tests/pytest/util/cases.py +++ b/tests/pytest/util/cases.py @@ -20,7 +20,7 @@ import importlib import traceback from util.log import * import platform - +import ast class TDCase: def __init__(self, name, case): @@ -51,12 +51,22 @@ class TDCases: def addCluster(self, name, case): self.clusterCases.append(TDCase(name, case)) + def get_local_classes_in_order(self, file_path): + with open(file_path, "r", encoding="utf-8") as file: + tree = ast.parse(file.read(), filename=file_path) + + classes = [node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)] + return classes + def runAllLinux(self, conn): # TODO: load all Linux cases here runNum = 0 for tmp in self.linuxCases: if tmp.name.find(fileName) != -1: - case = testModule.TDTestCase() + # get the last class name as the test case class name + class_names = self.get_local_classes_in_order(fileName) + case_class = getattr(testModule, class_names[-1]) + case = case_class() case.init(conn) case.run() case.stop() @@ -71,7 +81,10 @@ class TDCases: runNum = 0 for tmp in self.linuxCases: if tmp.name.find(fileName) != -1: - case = testModule.TDTestCase() + # get the last class name as the test case class name + class_names = self.get_local_classes_in_order(fileName) + case_class = getattr(testModule, class_names[-1]) + case = case_class() case.init(conn, self._logSql, replicaVar) try: case.run() @@ -88,7 +101,10 @@ class TDCases: runNum = 0 for tmp in self.windowsCases: if tmp.name.find(fileName) != -1: - case = testModule.TDTestCase() + # get the last class name as the test case class name + class_names = self.get_local_classes_in_order(fileName) + case_class = getattr(testModule, class_names[-1]) + case = case_class() case.init(conn) case.run() case.stop() @@ -103,7 +119,10 @@ class TDCases: runNum = 0 for tmp in self.windowsCases: if tmp.name.find(fileName) != -1: - case = testModule.TDTestCase() + # get the last class name as the test case class name + class_names = self.get_local_classes_in_order(fileName) + case_class = getattr(testModule, class_names[-1]) + case = case_class() case.init(conn, self._logSql,replicaVar) try: case.run() @@ -117,12 +136,16 @@ class TDCases: def runAllCluster(self): # TODO: load all cluster case module here + testModule = self.__dynamicLoadModule(fileName) runNum = 0 for tmp in self.clusterCases: if tmp.name.find(fileName) != -1: tdLog.notice("run cases like %s" % (fileName)) - case = testModule.TDTestCase() + # get the last class name as the test case class name + class_names = self.get_local_classes_in_order(fileName) + case_class = getattr(testModule, class_names[-1]) + case = case_class() case.init() case.run() case.stop() @@ -138,7 +161,10 @@ class TDCases: for tmp in self.clusterCases: if tmp.name.find(fileName) != -1: tdLog.notice("run cases like %s" % (fileName)) - case = testModule.TDTestCase() + # get the last class name as the test case class name + class_names = self.get_local_classes_in_order(fileName) + case_class = getattr(testModule, class_names[-1]) + case = case_class() case.init() case.run() case.stop() diff --git a/tests/system-test/test.py b/tests/system-test/test.py index ab1bdc21d3..cd0e60160c 100644 --- a/tests/system-test/test.py +++ b/tests/system-test/test.py @@ -24,6 +24,7 @@ import platform import socket import threading import importlib +import ast print(f"Python version: {sys.version}") print(f"Version info: {sys.version_info}") @@ -58,6 +59,18 @@ def checkRunTimeError(): if hwnd: os.system("TASKKILL /F /IM taosd.exe") +def get_local_classes_in_order(file_path): + with open(file_path, "r", encoding="utf-8") as file: + tree = ast.parse(file.read(), filename=file_path) + + classes = [node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)] + return classes + + +def dynamicLoadModule(fileName): + moduleName = fileName.replace(".py", "").replace(os.sep, ".") + return importlib.import_module(moduleName, package='..') + # # run case on previous cluster # @@ -68,9 +81,11 @@ def runOnPreviousCluster(host, config, fileName): sep = "/" if platform.system().lower() == 'windows': sep = os.sep - moduleName = fileName.replace(".py", "").replace(sep, ".") - uModule = importlib.import_module(moduleName) - case = uModule.TDTestCase() + + uModule = dynamicLoadModule(fileName) + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) + case = case_class() # create conn conn = taos.connect(host, config) @@ -350,10 +365,11 @@ if __name__ == "__main__": updateCfgDictStr = '' # adapter_cfg_dict_str = '' if is_test_framework: - moduleName = fileName.replace(".py", "").replace(os.sep, ".") - uModule = importlib.import_module(moduleName) + uModule = dynamicLoadModule(fileName) try: - ucase = uModule.TDTestCase() + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) + ucase = case_class() if ((json.dumps(updateCfgDict) == '{}') and hasattr(ucase, 'updatecfgDict')): updateCfgDict = ucase.updatecfgDict updateCfgDictStr = "-d %s"%base64.b64encode(json.dumps(updateCfgDict).encode()).decode() @@ -522,10 +538,11 @@ if __name__ == "__main__": except: pass if is_test_framework: - moduleName = fileName.replace(".py", "").replace("/", ".") - uModule = importlib.import_module(moduleName) + uModule = dynamicLoadModule(fileName) try: - ucase = uModule.TDTestCase() + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) + ucase = case_class() if (json.dumps(updateCfgDict) == '{}'): updateCfgDict = ucase.updatecfgDict if (json.dumps(adapter_cfg_dict) == '{}'): diff --git a/tests/test_new/test.py b/tests/test_new/test.py index ab1bdc21d3..cd0e60160c 100644 --- a/tests/test_new/test.py +++ b/tests/test_new/test.py @@ -24,6 +24,7 @@ import platform import socket import threading import importlib +import ast print(f"Python version: {sys.version}") print(f"Version info: {sys.version_info}") @@ -58,6 +59,18 @@ def checkRunTimeError(): if hwnd: os.system("TASKKILL /F /IM taosd.exe") +def get_local_classes_in_order(file_path): + with open(file_path, "r", encoding="utf-8") as file: + tree = ast.parse(file.read(), filename=file_path) + + classes = [node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)] + return classes + + +def dynamicLoadModule(fileName): + moduleName = fileName.replace(".py", "").replace(os.sep, ".") + return importlib.import_module(moduleName, package='..') + # # run case on previous cluster # @@ -68,9 +81,11 @@ def runOnPreviousCluster(host, config, fileName): sep = "/" if platform.system().lower() == 'windows': sep = os.sep - moduleName = fileName.replace(".py", "").replace(sep, ".") - uModule = importlib.import_module(moduleName) - case = uModule.TDTestCase() + + uModule = dynamicLoadModule(fileName) + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) + case = case_class() # create conn conn = taos.connect(host, config) @@ -350,10 +365,11 @@ if __name__ == "__main__": updateCfgDictStr = '' # adapter_cfg_dict_str = '' if is_test_framework: - moduleName = fileName.replace(".py", "").replace(os.sep, ".") - uModule = importlib.import_module(moduleName) + uModule = dynamicLoadModule(fileName) try: - ucase = uModule.TDTestCase() + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) + ucase = case_class() if ((json.dumps(updateCfgDict) == '{}') and hasattr(ucase, 'updatecfgDict')): updateCfgDict = ucase.updatecfgDict updateCfgDictStr = "-d %s"%base64.b64encode(json.dumps(updateCfgDict).encode()).decode() @@ -522,10 +538,11 @@ if __name__ == "__main__": except: pass if is_test_framework: - moduleName = fileName.replace(".py", "").replace("/", ".") - uModule = importlib.import_module(moduleName) + uModule = dynamicLoadModule(fileName) try: - ucase = uModule.TDTestCase() + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) + ucase = case_class() if (json.dumps(updateCfgDict) == '{}'): updateCfgDict = ucase.updatecfgDict if (json.dumps(adapter_cfg_dict) == '{}'):