diff --git a/tests/army/test.py b/tests/army/test.py index 5827657106..56b868c7ae 100644 --- a/tests/army/test.py +++ b/tests/army/test.py @@ -369,8 +369,7 @@ if __name__ == "__main__": updateCfgDictStr = '' # adapter_cfg_dict_str = '' if is_test_framework: - moduleName = fileName.replace(".py", "").replace(os.sep, ".") - uModule = importlib.import_module(moduleName) + uModule = dynamicLoadModule(fileName) try: case_class = getattr(uModule, get_local_classes(uModule)[-1]) ucase = case_class() @@ -542,8 +541,7 @@ if __name__ == "__main__": except: pass if is_test_framework: - moduleName = fileName.replace(".py", "").replace("/", ".") - uModule = importlib.import_module(moduleName) + uModule = dynamicLoadModule(fileName) try: case_class = getattr(uModule, get_local_classes(uModule)[-1]) ucase = case_class() diff --git a/tests/develop-test/test.py b/tests/develop-test/test.py index c0a090aaa5..76d5732e80 100644 --- a/tests/develop-test/test.py +++ b/tests/develop-test/test.py @@ -22,6 +22,9 @@ import json import platform import socket import threading +import inspect +import importlib +import os import toml @@ -56,6 +59,17 @@ def checkRunTimeError(): if hwnd: os.system("TASKKILL /F /IM taosd.exe") +def get_local_classes(module): + classes = [] + for name, obj in inspect.getmembers(module, inspect.isclass): + if inspect.getmodule(obj) == module: + classes.append(name) + return classes + +def dynamicLoadModule(fileName): + moduleName = fileName.replace(".py", "").replace(os.sep, ".") + return importlib.import_module(moduleName, package='..') + if __name__ == "__main__": @@ -295,8 +309,7 @@ if __name__ == "__main__": updateCfgDictStr = "" # adapter_cfg_dict_str = '' if is_test_framework: - moduleName = fileName.replace(".py", "").replace(os.sep, ".") - uModule = importlib.import_module(moduleName) + uModule = dynamicLoadModule(fileName) try: case_class = getattr(uModule, get_local_classes(uModule)[-1]) ucase = case_class() @@ -435,8 +448,7 @@ if __name__ == "__main__": except: pass if is_test_framework: - moduleName = fileName.replace(".py", "").replace("/", ".") - uModule = importlib.import_module(moduleName) + uModule = dynamicLoadModule(fileName) try: case_class = getattr(uModule, get_local_classes(uModule)[-1]) ucase = case_class() diff --git a/tests/pytest/test.py b/tests/pytest/test.py index eac9d9ea77..9d0e8651b4 100644 --- a/tests/pytest/test.py +++ b/tests/pytest/test.py @@ -147,8 +147,7 @@ if __name__ == "__main__": except: pass if is_test_framework: - moduleName = fileName.replace(".py", "").replace(os.sep, ".") - uModule = importlib.import_module(moduleName) + uModule = dynamicLoadModule(fileName) try: ucase = uModule.TDTestCase() tdDnodes.deploy(1,ucase.updatecfgDict) @@ -181,8 +180,7 @@ if __name__ == "__main__": except: pass if is_test_framework: - moduleName = fileName.replace(".py", "").replace("/", ".") - uModule = importlib.import_module(moduleName) + uModule = dynamicLoadModule(fileName) try: ucase = uModule.TDTestCase() tdDnodes.deploy(1,ucase.updatecfgDict) diff --git a/tests/pytest/util/cases.py b/tests/pytest/util/cases.py index 9a76e14790..48f73c50f1 100644 --- a/tests/pytest/util/cases.py +++ b/tests/pytest/util/cases.py @@ -63,7 +63,9 @@ class TDCases: runNum = 0 for tmp in self.linuxCases: if tmp.name.find(fileName) != -1: - case = testModule.TDTestCase() + # get the last class name as the test case class name + case_class = getattr(testModule, self.get_local_classes(testModule)[0]) + case = case_class() case.init(conn) case.run() case.stop() @@ -98,7 +100,9 @@ class TDCases: runNum = 0 for tmp in self.windowsCases: if tmp.name.find(fileName) != -1: - case = testModule.TDTestCase() + # get the last class name as the test case class name + case_class = getattr(testModule, self.get_local_classes(testModule)[-1]) + case = case_class() case.init(conn) case.run() case.stop() diff --git a/tests/system-test/test.py b/tests/system-test/test.py index 87febd1bf2..bf66b6a765 100644 --- a/tests/system-test/test.py +++ b/tests/system-test/test.py @@ -24,6 +24,7 @@ import platform import socket import threading import importlib +import inspect print(f"Python version: {sys.version}") print(f"Version info: {sys.version_info}") @@ -58,6 +59,17 @@ def checkRunTimeError(): if hwnd: os.system("TASKKILL /F /IM taosd.exe") +def get_local_classes(module): + classes = [] + for name, obj in inspect.getmembers(module, inspect.isclass): + if inspect.getmodule(obj) == module: + classes.append(name) + return classes + +def dynamicLoadModule(fileName): + moduleName = fileName.replace(".py", "").replace(os.sep, ".") + return importlib.import_module(moduleName, package='..') + # # run case on previous cluster # @@ -69,7 +81,6 @@ def runOnPreviousCluster(host, config, fileName): if platform.system().lower() == 'windows': sep = os.sep uModule = dynamicLoadModule(fileName) - case_class = getattr(uModule, get_local_classes(uModule)[-1]) case = case_class() @@ -351,8 +362,7 @@ if __name__ == "__main__": updateCfgDictStr = '' # adapter_cfg_dict_str = '' if is_test_framework: - moduleName = fileName.replace(".py", "").replace(os.sep, ".") - uModule = importlib.import_module(moduleName) + uModule = dynamicLoadModule(fileName) try: case_class = getattr(uModule, get_local_classes(uModule)[-1]) ucase = case_class() @@ -524,8 +534,7 @@ if __name__ == "__main__": except: pass if is_test_framework: - moduleName = fileName.replace(".py", "").replace("/", ".") - uModule = importlib.import_module(moduleName) + uModule = dynamicLoadModule(fileName) try: case_class = getattr(uModule, get_local_classes(uModule)[-1]) ucase = case_class() diff --git a/tests/test_new/test.py b/tests/test_new/test.py index 87febd1bf2..bf66b6a765 100644 --- a/tests/test_new/test.py +++ b/tests/test_new/test.py @@ -24,6 +24,7 @@ import platform import socket import threading import importlib +import inspect print(f"Python version: {sys.version}") print(f"Version info: {sys.version_info}") @@ -58,6 +59,17 @@ def checkRunTimeError(): if hwnd: os.system("TASKKILL /F /IM taosd.exe") +def get_local_classes(module): + classes = [] + for name, obj in inspect.getmembers(module, inspect.isclass): + if inspect.getmodule(obj) == module: + classes.append(name) + return classes + +def dynamicLoadModule(fileName): + moduleName = fileName.replace(".py", "").replace(os.sep, ".") + return importlib.import_module(moduleName, package='..') + # # run case on previous cluster # @@ -69,7 +81,6 @@ def runOnPreviousCluster(host, config, fileName): if platform.system().lower() == 'windows': sep = os.sep uModule = dynamicLoadModule(fileName) - case_class = getattr(uModule, get_local_classes(uModule)[-1]) case = case_class() @@ -351,8 +362,7 @@ if __name__ == "__main__": updateCfgDictStr = '' # adapter_cfg_dict_str = '' if is_test_framework: - moduleName = fileName.replace(".py", "").replace(os.sep, ".") - uModule = importlib.import_module(moduleName) + uModule = dynamicLoadModule(fileName) try: case_class = getattr(uModule, get_local_classes(uModule)[-1]) ucase = case_class() @@ -524,8 +534,7 @@ if __name__ == "__main__": except: pass if is_test_framework: - moduleName = fileName.replace(".py", "").replace("/", ".") - uModule = importlib.import_module(moduleName) + uModule = dynamicLoadModule(fileName) try: case_class = getattr(uModule, get_local_classes(uModule)[-1]) ucase = case_class()