test: add unit tests for get_external_class_inits

Tests cover: extracting __init__ from site-packages classes (click.Option),
skipping project classes, non-classes, already-defined classes, builtins,
classes with trivial object.__init__, and empty import scenarios.
This commit is contained in:
Kevin Turcios 2026-02-13 09:03:09 -05:00
parent 5449a32ade
commit f4c0208f49

View file

@ -15,6 +15,7 @@ from codeflash.context.code_context_extractor import (
extract_imports_for_class, extract_imports_for_class,
get_code_optimization_context, get_code_optimization_context,
get_external_base_class_inits, get_external_base_class_inits,
get_external_class_inits,
get_imported_class_definitions, get_imported_class_definitions,
) )
from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.discovery.functions_to_optimize import FunctionToOptimize
@ -4620,3 +4621,134 @@ class MyClass:
# counter should be in context since __init__ uses it # counter should be in context since __init__ uses it
read_writable = code_ctx.read_writable_code.markdown read_writable = code_ctx.read_writable_code.markdown
assert "counter" in read_writable assert "counter" in read_writable
def test_get_external_class_inits_extracts_click_option(tmp_path: Path) -> None:
"""Extracts __init__ from click.Option when directly imported."""
code = """from click import Option
def my_func(opt: Option) -> None:
pass
"""
code_path = tmp_path / "myfunc.py"
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_class_inits(context, tmp_path)
assert len(result.code_strings) == 1
code_string = result.code_strings[0]
assert "class Option:" in code_string.code
assert "def __init__" in code_string.code
assert "click" in code_string.file_path.as_posix()
def test_get_external_class_inits_skips_project_classes(tmp_path: Path) -> None:
"""Returns empty when imported class is from the project, not external."""
# Create a project module with a class
(tmp_path / "mymodule.py").write_text("class ProjectClass:\n pass\n", encoding="utf-8")
code = """from mymodule import ProjectClass
def my_func(obj: ProjectClass) -> None:
pass
"""
code_path = tmp_path / "myfunc.py"
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_class_inits(context, tmp_path)
assert result.code_strings == []
def test_get_external_class_inits_skips_non_classes(tmp_path: Path) -> None:
"""Returns empty when imported name is a function, not a class."""
code = """from collections import OrderedDict
from os.path import join
def my_func() -> None:
pass
"""
code_path = tmp_path / "myfunc.py"
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_class_inits(context, tmp_path)
# join is a function, not a class — should be skipped
# OrderedDict is a class and should be included
class_names = [cs.code.split("\n")[0] for cs in result.code_strings]
assert not any("join" in name for name in class_names)
def test_get_external_class_inits_skips_already_defined_classes(tmp_path: Path) -> None:
"""Skips classes already defined in the context (e.g., added by get_imported_class_definitions)."""
code = """from collections import UserDict
class UserDict:
def __init__(self):
pass
def my_func(d: UserDict) -> None:
pass
"""
code_path = tmp_path / "myfunc.py"
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_class_inits(context, tmp_path)
# UserDict is already defined in the context, so it should be skipped
assert result.code_strings == []
def test_get_external_class_inits_skips_builtins(tmp_path: Path) -> None:
"""Returns empty for builtin classes like list/dict that have no inspectable source."""
code = """x: list = []
y: dict = {}
def my_func() -> None:
pass
"""
code_path = tmp_path / "myfunc.py"
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_class_inits(context, tmp_path)
assert result.code_strings == []
def test_get_external_class_inits_skips_object_init(tmp_path: Path) -> None:
"""Skips classes whose __init__ is just object.__init__ (trivial)."""
# enum.Enum has a metaclass-based __init__, but individual enum members
# effectively use object.__init__. Use a class we know has object.__init__.
code = """from xml.etree.ElementTree import QName
def my_func(q: QName) -> None:
pass
"""
code_path = tmp_path / "myfunc.py"
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_class_inits(context, tmp_path)
# QName has its own __init__, so it should be included if it's in site-packages.
# But since it's stdlib (not site-packages), it should be skipped.
assert result.code_strings == []
def test_get_external_class_inits_empty_when_no_imports(tmp_path: Path) -> None:
"""Returns empty when there are no from-imports."""
code = """def my_func() -> None:
pass
"""
code_path = tmp_path / "myfunc.py"
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_class_inits(context, tmp_path)
assert result.code_strings == []