test: add unit tests for get_external_class_inits
Tests cover: extracting __init__ from site-packages classes (click.Option), skipping project classes, non-classes, already-defined classes, builtins, classes with trivial object.__init__, and empty import scenarios.
This commit is contained in:
parent
5449a32ade
commit
f4c0208f49
1 changed files with 132 additions and 0 deletions
|
|
@ -15,6 +15,7 @@ from codeflash.context.code_context_extractor import (
|
|||
extract_imports_for_class,
|
||||
get_code_optimization_context,
|
||||
get_external_base_class_inits,
|
||||
get_external_class_inits,
|
||||
get_imported_class_definitions,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
|
@ -4620,3 +4621,134 @@ class MyClass:
|
|||
# counter should be in context since __init__ uses it
|
||||
read_writable = code_ctx.read_writable_code.markdown
|
||||
assert "counter" in read_writable
|
||||
|
||||
|
||||
def test_get_external_class_inits_extracts_click_option(tmp_path: Path) -> None:
|
||||
"""Extracts __init__ from click.Option when directly imported."""
|
||||
code = """from click import Option
|
||||
|
||||
def my_func(opt: Option) -> None:
|
||||
pass
|
||||
"""
|
||||
code_path = tmp_path / "myfunc.py"
|
||||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
|
||||
assert len(result.code_strings) == 1
|
||||
code_string = result.code_strings[0]
|
||||
assert "class Option:" in code_string.code
|
||||
assert "def __init__" in code_string.code
|
||||
assert "click" in code_string.file_path.as_posix()
|
||||
|
||||
|
||||
def test_get_external_class_inits_skips_project_classes(tmp_path: Path) -> None:
|
||||
"""Returns empty when imported class is from the project, not external."""
|
||||
# Create a project module with a class
|
||||
(tmp_path / "mymodule.py").write_text("class ProjectClass:\n pass\n", encoding="utf-8")
|
||||
|
||||
code = """from mymodule import ProjectClass
|
||||
|
||||
def my_func(obj: ProjectClass) -> None:
|
||||
pass
|
||||
"""
|
||||
code_path = tmp_path / "myfunc.py"
|
||||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
|
||||
assert result.code_strings == []
|
||||
|
||||
|
||||
def test_get_external_class_inits_skips_non_classes(tmp_path: Path) -> None:
|
||||
"""Returns empty when imported name is a function, not a class."""
|
||||
code = """from collections import OrderedDict
|
||||
from os.path import join
|
||||
|
||||
def my_func() -> None:
|
||||
pass
|
||||
"""
|
||||
code_path = tmp_path / "myfunc.py"
|
||||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
|
||||
# join is a function, not a class — should be skipped
|
||||
# OrderedDict is a class and should be included
|
||||
class_names = [cs.code.split("\n")[0] for cs in result.code_strings]
|
||||
assert not any("join" in name for name in class_names)
|
||||
|
||||
|
||||
def test_get_external_class_inits_skips_already_defined_classes(tmp_path: Path) -> None:
|
||||
"""Skips classes already defined in the context (e.g., added by get_imported_class_definitions)."""
|
||||
code = """from collections import UserDict
|
||||
|
||||
class UserDict:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def my_func(d: UserDict) -> None:
|
||||
pass
|
||||
"""
|
||||
code_path = tmp_path / "myfunc.py"
|
||||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
|
||||
# UserDict is already defined in the context, so it should be skipped
|
||||
assert result.code_strings == []
|
||||
|
||||
|
||||
def test_get_external_class_inits_skips_builtins(tmp_path: Path) -> None:
|
||||
"""Returns empty for builtin classes like list/dict that have no inspectable source."""
|
||||
code = """x: list = []
|
||||
y: dict = {}
|
||||
|
||||
def my_func() -> None:
|
||||
pass
|
||||
"""
|
||||
code_path = tmp_path / "myfunc.py"
|
||||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
|
||||
assert result.code_strings == []
|
||||
|
||||
|
||||
def test_get_external_class_inits_skips_object_init(tmp_path: Path) -> None:
|
||||
"""Skips classes whose __init__ is just object.__init__ (trivial)."""
|
||||
# enum.Enum has a metaclass-based __init__, but individual enum members
|
||||
# effectively use object.__init__. Use a class we know has object.__init__.
|
||||
code = """from xml.etree.ElementTree import QName
|
||||
|
||||
def my_func(q: QName) -> None:
|
||||
pass
|
||||
"""
|
||||
code_path = tmp_path / "myfunc.py"
|
||||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
|
||||
# QName has its own __init__, so it should be included if it's in site-packages.
|
||||
# But since it's stdlib (not site-packages), it should be skipped.
|
||||
assert result.code_strings == []
|
||||
|
||||
|
||||
def test_get_external_class_inits_empty_when_no_imports(tmp_path: Path) -> None:
|
||||
"""Returns empty when there are no from-imports."""
|
||||
code = """def my_func() -> None:
|
||||
pass
|
||||
"""
|
||||
code_path = tmp_path / "myfunc.py"
|
||||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
|
||||
assert result.code_strings == []
|
||||
|
|
|
|||
Loading…
Reference in a new issue