diff --git a/tests/army/test.py b/tests/army/test.py index d37d08b406..bf14f19e3f 100644 --- a/tests/army/test.py +++ b/tests/army/test.py @@ -24,6 +24,7 @@ import platform import socket import threading import importlib +import inspect import toml @@ -56,6 +57,17 @@ def checkRunTimeError(): if hwnd: os.system("TASKKILL /F /IM taosd.exe") +def get_local_classes(module): + classes = [] + for name, obj in inspect.getmembers(module, inspect.isclass): + if inspect.getmodule(obj) == module: + classes.append(name) + return classes + +def dynamicLoadModule(fileName): + moduleName = fileName.replace(".py", "").replace(os.sep, ".") + return importlib.import_module(moduleName, package='..') + # # run case on previous cluster # @@ -66,9 +78,9 @@ def runOnPreviousCluster(host, config, fileName): sep = "/" if platform.system().lower() == 'windows': sep = os.sep - moduleName = fileName.replace(".py", "").replace(sep, ".") - uModule = importlib.import_module(moduleName) - case = uModule.TDTestCase() + uModule = dynamicLoadModule(fileName) + case_class = getattr(uModule, get_local_classes(uModule)[0]) + case = case_class() # create conn conn = taos.connect(host, config) @@ -358,10 +370,10 @@ if __name__ == "__main__": updateCfgDictStr = '' # adapter_cfg_dict_str = '' if is_test_framework: - moduleName = fileName.replace(".py", "").replace(os.sep, ".") - uModule = importlib.import_module(moduleName) + uModule = dynamicLoadModule(fileName) try: - ucase = uModule.TDTestCase() + case_class = getattr(uModule, get_local_classes(uModule)[0]) + ucase = case_class() if ((json.dumps(updateCfgDict) == '{}') and hasattr(ucase, 'updatecfgDict')): updateCfgDict = ucase.updatecfgDict updateCfgDictStr = "-d %s"%base64.b64encode(json.dumps(updateCfgDict).encode()).decode() @@ -530,10 +542,10 @@ if __name__ == "__main__": except: pass if is_test_framework: - moduleName = fileName.replace(".py", "").replace("/", ".") - uModule = importlib.import_module(moduleName) + uModule = dynamicLoadModule(fileName) try: - ucase = uModule.TDTestCase() + case_class = getattr(uModule, get_local_classes(uModule)[0]) + ucase = case_class() if (json.dumps(updateCfgDict) == '{}'): updateCfgDict = ucase.updatecfgDict if (json.dumps(adapter_cfg_dict) == '{}'): diff --git a/tests/develop-test/test.py b/tests/develop-test/test.py index 6b1c63a1c0..9106e38d9a 100644 --- a/tests/develop-test/test.py +++ b/tests/develop-test/test.py @@ -22,6 +22,9 @@ import json import platform import socket import threading +import inspect +import importlib +import os import toml @@ -57,6 +60,17 @@ def checkRunTimeError(): os.system("TASKKILL /F /IM taosd.exe") +def get_local_classes(module): + classes = [] + for name, obj in inspect.getmembers(module, inspect.isclass): + if inspect.getmodule(obj) == module: + classes.append(name) + return classes + +def dynamicLoadModule(fileName): + moduleName = fileName.replace(".py", "").replace(os.sep, ".") + return importlib.import_module(moduleName, package='..') + if __name__ == "__main__": fileName = "all" @@ -295,10 +309,10 @@ if __name__ == "__main__": updateCfgDictStr = "" # adapter_cfg_dict_str = '' if is_test_framework: - moduleName = fileName.replace(".py", "").replace(os.sep, ".") - uModule = importlib.import_module(moduleName) + uModule = dynamicLoadModule(fileName) try: - ucase = uModule.TDTestCase() + case_class = getattr(uModule, get_local_classes(uModule)[0]) + ucase = case_class() if (json.dumps(updateCfgDict) == "{}") and hasattr( ucase, "updatecfgDict" ): @@ -434,10 +448,10 @@ if __name__ == "__main__": except: pass if is_test_framework: - moduleName = fileName.replace(".py", "").replace("/", ".") - uModule = importlib.import_module(moduleName) + uModule = dynamicLoadModule(fileName) try: - ucase = uModule.TDTestCase() + case_class = getattr(uModule, get_local_classes(uModule)[0]) + ucase = case_class() if json.dumps(updateCfgDict) == "{}": updateCfgDict = ucase.updatecfgDict if json.dumps(adapter_cfg_dict) == "{}":