From d48006106400044b32442e741f8bf343fa95fd91 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 18 Feb 2026 08:26:37 -0500 Subject: [PATCH] temp --- .../python/context/code_context_extractor.py | 55 ++++++-- tests/test_code_context_extractor.py | 125 ++++++++++++++++++ 2 files changed, 172 insertions(+), 8 deletions(-) diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index 6b101e350..b49fa6d1a 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -752,17 +752,30 @@ def extract_init_stub_from_class(class_name: str, module_source: str, module_tre if class_node is None: return None - init_node = None + lines = module_source.splitlines() + relevant_nodes: list[ast.FunctionDef | ast.AsyncFunctionDef] = [] for item in class_node.body: - if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)) and item.name == "__init__": - init_node = item - break - if init_node is None: + if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)): + if item.name in ("__init__", "__post_init__"): + relevant_nodes.append(item) + elif any( + isinstance(d, ast.Name) and d.id == "property" + or isinstance(d, ast.Attribute) and d.attr == "property" + for d in item.decorator_list + ): + relevant_nodes.append(item) + + if not relevant_nodes: return None - lines = module_source.splitlines() - init_source = "\n".join(lines[init_node.lineno - 1 : init_node.end_lineno]) - return f"class {class_name}:\n{init_source}" + snippets: list[str] = [] + for node in relevant_nodes: + start = node.lineno + if node.decorator_list: + start = min(d.lineno for d in node.decorator_list) + snippets.append("\n".join(lines[start - 1 : node.end_lineno])) + + return f"class {class_name}:\n" + "\n".join(snippets) def extract_parameter_type_constructors( @@ -844,6 +857,27 @@ def extract_parameter_type_constructors( return CodeStringsMarkdown(code_strings=code_strings) +def resolve_instance_class_name(name: str, module_tree: ast.Module) -> str | None: + for node in module_tree.body: + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and target.id == name: + value = node.value + if isinstance(value, ast.Call): + func = value.func + if isinstance(func, ast.Name): + return func.id + if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name): + return func.value.id + elif isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name) and node.target.id == name: + ann = node.annotation + if isinstance(ann, ast.Name): + return ann.id + if isinstance(ann, ast.Subscript) and isinstance(ann.value, ast.Name): + return ann.value.id + return None + + def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown: import jedi @@ -938,6 +972,11 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path: extract_class_and_bases(name, module_path, module_source, module_tree) + if (module_path, name) not in extracted_classes: + resolved_class = resolve_instance_class_name(name, module_tree) + if resolved_class and resolved_class not in existing_classes: + extract_class_and_bases(resolved_class, module_path, module_source, module_tree) + except Exception: logger.debug(f"Error extracting class definition for {name} from {module_name}") continue diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 0e2838554..83a5c3145 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -18,6 +18,7 @@ from codeflash.languages.python.context.code_context_extractor import ( extract_init_stub_from_class, extract_parameter_type_constructors, get_code_optimization_context, + resolve_instance_class_name, ) from codeflash.models.models import CodeString, CodeStringsMarkdown, FunctionParent from codeflash.optimization.optimizer import Optimizer @@ -4339,3 +4340,127 @@ def process(c: Config) -> str: ) result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set()) assert len(result.code_strings) == 0 + + +# --- Tests for resolve_instance_class_name --- + + +def test_resolve_instance_class_name_direct_call() -> None: + source = "config = MyConfig(debug=True)" + tree = ast.parse(source) + assert resolve_instance_class_name("config", tree) == "MyConfig" + + +def test_resolve_instance_class_name_annotated() -> None: + source = "config: MyConfig = load()" + tree = ast.parse(source) + assert resolve_instance_class_name("config", tree) == "MyConfig" + + +def test_resolve_instance_class_name_factory_method() -> None: + source = "config = MyConfig.from_env()" + tree = ast.parse(source) + assert resolve_instance_class_name("config", tree) == "MyConfig" + + +def test_resolve_instance_class_name_no_match() -> None: + source = "x = 42" + tree = ast.parse(source) + assert resolve_instance_class_name("x", tree) is None + + +def test_resolve_instance_class_name_missing_variable() -> None: + source = "config = MyConfig()" + tree = ast.parse(source) + assert resolve_instance_class_name("other", tree) is None + + +# --- Tests for enhanced extract_init_stub_from_class --- + + +def test_extract_init_stub_includes_post_init() -> None: + source = """\ +class MyDataclass: + def __init__(self, x: int): + self.x = x + def __post_init__(self): + self.y = self.x * 2 +""" + tree = ast.parse(source) + stub = extract_init_stub_from_class("MyDataclass", source, tree) + assert stub is not None + assert "class MyDataclass:" in stub + assert "def __init__" in stub + assert "def __post_init__" in stub + assert "self.y = self.x * 2" in stub + + +def test_extract_init_stub_includes_properties() -> None: + source = """\ +class MyClass: + def __init__(self, name: str): + self._name = name + @property + def name(self) -> str: + return self._name +""" + tree = ast.parse(source) + stub = extract_init_stub_from_class("MyClass", source, tree) + assert stub is not None + assert "def __init__" in stub + assert "@property" in stub + assert "def name" in stub + + +def test_extract_init_stub_property_only_class() -> None: + source = """\ +class ReadOnly: + @property + def value(self) -> int: + return 42 +""" + tree = ast.parse(source) + stub = extract_init_stub_from_class("ReadOnly", source, tree) + assert stub is not None + assert "class ReadOnly:" in stub + assert "@property" in stub + assert "def value" in stub + + +# --- Tests for enrich_testgen_context resolving instances --- + + +def test_enrich_testgen_context_resolves_instance_to_class(tmp_path: Path) -> None: + package_dir = tmp_path / "mypkg" + package_dir.mkdir() + (package_dir / "__init__.py").write_text("", encoding="utf-8") + + config_module = """\ +class AppConfig: + def __init__(self, debug: bool = False): + self.debug = debug + + @property + def log_level(self) -> str: + return "DEBUG" if self.debug else "INFO" + +app_config = AppConfig(debug=True) +""" + (package_dir / "config.py").write_text(config_module, encoding="utf-8") + + consumer_code = """\ +from mypkg.config import app_config + +def get_log_level() -> str: + return app_config.log_level +""" + consumer_path = package_dir / "consumer.py" + consumer_path.write_text(consumer_code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=consumer_code, file_path=consumer_path)]) + result = enrich_testgen_context(context, tmp_path) + + assert len(result.code_strings) >= 1 + combined = "\n".join(cs.code for cs in result.code_strings) + assert "class AppConfig:" in combined + assert "@property" in combined