diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index c5009b898..a85590b28 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -15,6 +15,7 @@ from codeflash.context.code_context_extractor import ( extract_imports_for_class, get_code_optimization_context, get_external_base_class_inits, + get_external_class_inits, get_imported_class_definitions, ) from codeflash.discovery.functions_to_optimize import FunctionToOptimize @@ -4620,3 +4621,134 @@ class MyClass: # counter should be in context since __init__ uses it read_writable = code_ctx.read_writable_code.markdown assert "counter" in read_writable + + +def test_get_external_class_inits_extracts_click_option(tmp_path: Path) -> None: + """Extracts __init__ from click.Option when directly imported.""" + code = """from click import Option + +def my_func(opt: Option) -> None: + pass +""" + code_path = tmp_path / "myfunc.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + result = get_external_class_inits(context, tmp_path) + + assert len(result.code_strings) == 1 + code_string = result.code_strings[0] + assert "class Option:" in code_string.code + assert "def __init__" in code_string.code + assert "click" in code_string.file_path.as_posix() + + +def test_get_external_class_inits_skips_project_classes(tmp_path: Path) -> None: + """Returns empty when imported class is from the project, not external.""" + # Create a project module with a class + (tmp_path / "mymodule.py").write_text("class ProjectClass:\n pass\n", encoding="utf-8") + + code = """from mymodule import ProjectClass + +def my_func(obj: ProjectClass) -> None: + pass +""" + code_path = tmp_path / "myfunc.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + result = get_external_class_inits(context, tmp_path) + + assert result.code_strings == [] + + +def test_get_external_class_inits_skips_non_classes(tmp_path: Path) -> None: + """Returns empty when imported name is a function, not a class.""" + code = """from collections import OrderedDict +from os.path import join + +def my_func() -> None: + pass +""" + code_path = tmp_path / "myfunc.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + result = get_external_class_inits(context, tmp_path) + + # join is a function, not a class — should be skipped + # OrderedDict is a class and should be included + class_names = [cs.code.split("\n")[0] for cs in result.code_strings] + assert not any("join" in name for name in class_names) + + +def test_get_external_class_inits_skips_already_defined_classes(tmp_path: Path) -> None: + """Skips classes already defined in the context (e.g., added by get_imported_class_definitions).""" + code = """from collections import UserDict + +class UserDict: + def __init__(self): + pass + +def my_func(d: UserDict) -> None: + pass +""" + code_path = tmp_path / "myfunc.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + result = get_external_class_inits(context, tmp_path) + + # UserDict is already defined in the context, so it should be skipped + assert result.code_strings == [] + + +def test_get_external_class_inits_skips_builtins(tmp_path: Path) -> None: + """Returns empty for builtin classes like list/dict that have no inspectable source.""" + code = """x: list = [] +y: dict = {} + +def my_func() -> None: + pass +""" + code_path = tmp_path / "myfunc.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + result = get_external_class_inits(context, tmp_path) + + assert result.code_strings == [] + + +def test_get_external_class_inits_skips_object_init(tmp_path: Path) -> None: + """Skips classes whose __init__ is just object.__init__ (trivial).""" + # enum.Enum has a metaclass-based __init__, but individual enum members + # effectively use object.__init__. Use a class we know has object.__init__. + code = """from xml.etree.ElementTree import QName + +def my_func(q: QName) -> None: + pass +""" + code_path = tmp_path / "myfunc.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + result = get_external_class_inits(context, tmp_path) + + # QName has its own __init__, so it should be included if it's in site-packages. + # But since it's stdlib (not site-packages), it should be skipped. + assert result.code_strings == [] + + +def test_get_external_class_inits_empty_when_no_imports(tmp_path: Path) -> None: + """Returns empty when there are no from-imports.""" + code = """def my_func() -> None: + pass +""" + code_path = tmp_path / "myfunc.py" + code_path.write_text(code, encoding="utf-8") + + context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) + result = get_external_class_inits(context, tmp_path) + + assert result.code_strings == []