From a7cbf6f8b2b7589f2f654bdfb54bbff88f3a76ec Mon Sep 17 00:00:00 2001 From: chenhaoran Date: Wed, 26 Feb 2025 17:53:39 +0800 Subject: [PATCH] refactor: replace inspect-based class retrieval with AST parsing for improved class order handling --- tests/army/test.py | 22 +++++++++++++--------- tests/develop-test/test.py | 18 ++++++++++-------- tests/pytest/test.py | 17 ++++++++++------- tests/pytest/util/cases.py | 31 ++++++++++++++++++------------- tests/system-test/test.py | 23 ++++++++++++++--------- tests/test_new/test.py | 23 ++++++++++++++--------- 6 files changed, 79 insertions(+), 55 deletions(-) diff --git a/tests/army/test.py b/tests/army/test.py index 56b868c7ae..6ac0948b7b 100644 --- a/tests/army/test.py +++ b/tests/army/test.py @@ -24,7 +24,7 @@ import platform import socket import threading import importlib -import inspect +import ast import toml from frame.log import * @@ -56,11 +56,11 @@ def checkRunTimeError(): if hwnd: os.system("TASKKILL /F /IM taosd.exe") -def get_local_classes(module): - classes = [] - for name, obj in inspect.getmembers(module, inspect.isclass): - if inspect.getmodule(obj) == module: - classes.append(name) +def get_local_classes_in_order(file_path): + with open(file_path, "r", encoding="utf-8") as file: + tree = ast.parse(file.read(), filename=file_path) + + classes = [node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)] return classes def dynamicLoadModule(fileName): @@ -77,8 +77,10 @@ def runOnPreviousCluster(host, config, fileName): sep = "/" if platform.system().lower() == 'windows': sep = os.sep + uModule = dynamicLoadModule(fileName) - case_class = getattr(uModule, get_local_classes(uModule)[-1]) + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) case = case_class() # create conn @@ -371,7 +373,8 @@ if __name__ == "__main__": if is_test_framework: uModule = dynamicLoadModule(fileName) try: - case_class = getattr(uModule, get_local_classes(uModule)[-1]) + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) ucase = case_class() if ((json.dumps(updateCfgDict) == '{}') and hasattr(ucase, 'updatecfgDict')): updateCfgDict = ucase.updatecfgDict @@ -543,7 +546,8 @@ if __name__ == "__main__": if is_test_framework: uModule = dynamicLoadModule(fileName) try: - case_class = getattr(uModule, get_local_classes(uModule)[-1]) + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) ucase = case_class() if (json.dumps(updateCfgDict) == '{}'): updateCfgDict = ucase.updatecfgDict diff --git a/tests/develop-test/test.py b/tests/develop-test/test.py index 76d5732e80..3525fd6332 100644 --- a/tests/develop-test/test.py +++ b/tests/develop-test/test.py @@ -22,7 +22,7 @@ import json import platform import socket import threading -import inspect +import ast import importlib import os @@ -59,11 +59,11 @@ def checkRunTimeError(): if hwnd: os.system("TASKKILL /F /IM taosd.exe") -def get_local_classes(module): - classes = [] - for name, obj in inspect.getmembers(module, inspect.isclass): - if inspect.getmodule(obj) == module: - classes.append(name) +def get_local_classes_in_order(file_path): + with open(file_path, "r", encoding="utf-8") as file: + tree = ast.parse(file.read(), filename=file_path) + + classes = [node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)] return classes def dynamicLoadModule(fileName): @@ -311,7 +311,8 @@ if __name__ == "__main__": if is_test_framework: uModule = dynamicLoadModule(fileName) try: - case_class = getattr(uModule, get_local_classes(uModule)[-1]) + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) ucase = case_class() if (json.dumps(updateCfgDict) == "{}") and hasattr( ucase, "updatecfgDict" @@ -450,7 +451,8 @@ if __name__ == "__main__": if is_test_framework: uModule = dynamicLoadModule(fileName) try: - case_class = getattr(uModule, get_local_classes(uModule)[-1]) + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) ucase = case_class() if json.dumps(updateCfgDict) == "{}": updateCfgDict = ucase.updatecfgDict diff --git a/tests/pytest/test.py b/tests/pytest/test.py index 5d1a15961a..fa91d20c00 100644 --- a/tests/pytest/test.py +++ b/tests/pytest/test.py @@ -19,6 +19,7 @@ import subprocess import time from distutils.log import warn as printf import platform +import ast from util.log import * from util.dnodes import * @@ -26,11 +27,11 @@ from util.cases import * import taos -def get_local_classes(module): - classes = [] - for name, obj in inspect.getmembers(module, inspect.isclass): - if inspect.getmodule(obj) == module: - classes.append(name) +def get_local_classes_in_order(file_path): + with open(file_path, "r", encoding="utf-8") as file: + tree = ast.parse(file.read(), filename=file_path) + + classes = [node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)] return classes def dynamicLoadModule(fileName): @@ -149,7 +150,8 @@ if __name__ == "__main__": if is_test_framework: uModule = dynamicLoadModule(fileName) try: - case_class = getattr(uModule, get_local_classes(uModule)[-1]) + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) ucase = case_class() tdDnodes.deploy(1,ucase.updatecfgDict) except : @@ -183,7 +185,8 @@ if __name__ == "__main__": if is_test_framework: uModule = dynamicLoadModule(fileName) try: - case_class = getattr(uModule, get_local_classes(uModule)[-1]) + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) ucase = case_class() tdDnodes.deploy(1,ucase.updatecfgDict) except : diff --git a/tests/pytest/util/cases.py b/tests/pytest/util/cases.py index 48f73c50f1..ac9b97a874 100644 --- a/tests/pytest/util/cases.py +++ b/tests/pytest/util/cases.py @@ -20,7 +20,7 @@ import importlib import traceback from util.log import * import platform - +import ast class TDCase: def __init__(self, name, case): @@ -51,11 +51,11 @@ class TDCases: def addCluster(self, name, case): self.clusterCases.append(TDCase(name, case)) - def get_local_classes(self, module): - classes = [] - for name, obj in inspect.getmembers(module, inspect.isclass): - if inspect.getmodule(obj) == module: - classes.append(name) + def get_local_classes_in_order(self, file_path): + with open(file_path, "r", encoding="utf-8") as file: + tree = ast.parse(file.read(), filename=file_path) + + classes = [node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)] return classes def runAllLinux(self, conn): @@ -64,7 +64,8 @@ class TDCases: for tmp in self.linuxCases: if tmp.name.find(fileName) != -1: # get the last class name as the test case class name - case_class = getattr(testModule, self.get_local_classes(testModule)[0]) + class_names = self.get_local_classes_in_order(fileName) + case_class = getattr(testModule, class_names[-1]) case = case_class() case.init(conn) case.run() @@ -81,8 +82,8 @@ class TDCases: for tmp in self.linuxCases: if tmp.name.find(fileName) != -1: # get the last class name as the test case class name - case_class = getattr(testModule, self.get_local_classes(testModule)[-1]) - print(case_class) + class_names = self.get_local_classes_in_order(fileName) + case_class = getattr(testModule, class_names[-1]) case = case_class() case.init(conn, self._logSql, replicaVar) try: @@ -101,7 +102,8 @@ class TDCases: for tmp in self.windowsCases: if tmp.name.find(fileName) != -1: # get the last class name as the test case class name - case_class = getattr(testModule, self.get_local_classes(testModule)[-1]) + class_names = self.get_local_classes_in_order(fileName) + case_class = getattr(testModule, class_names[-1]) case = case_class() case.init(conn) case.run() @@ -118,7 +120,8 @@ class TDCases: for tmp in self.windowsCases: if tmp.name.find(fileName) != -1: # get the last class name as the test case class name - case_class = getattr(testModule, self.get_local_classes(testModule)[-1]) + class_names = self.get_local_classes_in_order(fileName) + case_class = getattr(testModule, class_names[-1]) case = case_class() case.init(conn, self._logSql,replicaVar) try: @@ -140,7 +143,8 @@ class TDCases: if tmp.name.find(fileName) != -1: tdLog.notice("run cases like %s" % (fileName)) # get the last class name as the test case class name - case_class = getattr(testModule, self.get_local_classes(testModule)[-1]) + class_names = self.get_local_classes_in_order(fileName) + case_class = getattr(testModule, class_names[-1]) case = case_class() case.init() case.run() @@ -158,7 +162,8 @@ class TDCases: if tmp.name.find(fileName) != -1: tdLog.notice("run cases like %s" % (fileName)) # get the last class name as the test case class name - case_class = getattr(testModule, self.get_local_classes(testModule)[-1]) + class_names = self.get_local_classes_in_order(fileName) + case_class = getattr(testModule, class_names[-1]) case = case_class() case.init() case.run() diff --git a/tests/system-test/test.py b/tests/system-test/test.py index bf66b6a765..cd0e60160c 100644 --- a/tests/system-test/test.py +++ b/tests/system-test/test.py @@ -24,7 +24,7 @@ import platform import socket import threading import importlib -import inspect +import ast print(f"Python version: {sys.version}") print(f"Version info: {sys.version_info}") @@ -59,13 +59,14 @@ def checkRunTimeError(): if hwnd: os.system("TASKKILL /F /IM taosd.exe") -def get_local_classes(module): - classes = [] - for name, obj in inspect.getmembers(module, inspect.isclass): - if inspect.getmodule(obj) == module: - classes.append(name) +def get_local_classes_in_order(file_path): + with open(file_path, "r", encoding="utf-8") as file: + tree = ast.parse(file.read(), filename=file_path) + + classes = [node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)] return classes + def dynamicLoadModule(fileName): moduleName = fileName.replace(".py", "").replace(os.sep, ".") return importlib.import_module(moduleName, package='..') @@ -80,8 +81,10 @@ def runOnPreviousCluster(host, config, fileName): sep = "/" if platform.system().lower() == 'windows': sep = os.sep + uModule = dynamicLoadModule(fileName) - case_class = getattr(uModule, get_local_classes(uModule)[-1]) + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) case = case_class() # create conn @@ -364,7 +367,8 @@ if __name__ == "__main__": if is_test_framework: uModule = dynamicLoadModule(fileName) try: - case_class = getattr(uModule, get_local_classes(uModule)[-1]) + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) ucase = case_class() if ((json.dumps(updateCfgDict) == '{}') and hasattr(ucase, 'updatecfgDict')): updateCfgDict = ucase.updatecfgDict @@ -536,7 +540,8 @@ if __name__ == "__main__": if is_test_framework: uModule = dynamicLoadModule(fileName) try: - case_class = getattr(uModule, get_local_classes(uModule)[-1]) + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) ucase = case_class() if (json.dumps(updateCfgDict) == '{}'): updateCfgDict = ucase.updatecfgDict diff --git a/tests/test_new/test.py b/tests/test_new/test.py index bf66b6a765..cd0e60160c 100644 --- a/tests/test_new/test.py +++ b/tests/test_new/test.py @@ -24,7 +24,7 @@ import platform import socket import threading import importlib -import inspect +import ast print(f"Python version: {sys.version}") print(f"Version info: {sys.version_info}") @@ -59,13 +59,14 @@ def checkRunTimeError(): if hwnd: os.system("TASKKILL /F /IM taosd.exe") -def get_local_classes(module): - classes = [] - for name, obj in inspect.getmembers(module, inspect.isclass): - if inspect.getmodule(obj) == module: - classes.append(name) +def get_local_classes_in_order(file_path): + with open(file_path, "r", encoding="utf-8") as file: + tree = ast.parse(file.read(), filename=file_path) + + classes = [node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)] return classes + def dynamicLoadModule(fileName): moduleName = fileName.replace(".py", "").replace(os.sep, ".") return importlib.import_module(moduleName, package='..') @@ -80,8 +81,10 @@ def runOnPreviousCluster(host, config, fileName): sep = "/" if platform.system().lower() == 'windows': sep = os.sep + uModule = dynamicLoadModule(fileName) - case_class = getattr(uModule, get_local_classes(uModule)[-1]) + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) case = case_class() # create conn @@ -364,7 +367,8 @@ if __name__ == "__main__": if is_test_framework: uModule = dynamicLoadModule(fileName) try: - case_class = getattr(uModule, get_local_classes(uModule)[-1]) + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) ucase = case_class() if ((json.dumps(updateCfgDict) == '{}') and hasattr(ucase, 'updatecfgDict')): updateCfgDict = ucase.updatecfgDict @@ -536,7 +540,8 @@ if __name__ == "__main__": if is_test_framework: uModule = dynamicLoadModule(fileName) try: - case_class = getattr(uModule, get_local_classes(uModule)[-1]) + class_names = get_local_classes_in_order(fileName) + case_class = getattr(uModule, class_names[-1]) ucase = case_class() if (json.dumps(updateCfgDict) == '{}'): updateCfgDict = ucase.updatecfgDict