add tests

This commit is contained in:
Kevin Turcios 2026-01-24 10:14:54 -05:00
parent 69740f0340
commit 65ff392d20

View file

@ -4075,3 +4075,578 @@ def reify_channel_message(data: dict) -> MessageIn:
```
"""
assert code_ctx.read_writable_code.markdown.strip() == expected_read_writable.strip()
def test_testgen_context_includes_external_base_inits(tmp_path: Path) -> None:
    """Test that external base class __init__ methods are included in testgen context.

    This covers line 65 in code_context_extractor.py where external_base_inits.code_strings
    are appended to the testgen context when a class inherits from an external library.
    """
    # Fixture source: target_method is nested under MyCustomDict, matching the
    # FunctionParent declared for it below.
    code = """from collections import UserDict
class MyCustomDict(UserDict):
    def target_method(self):
        return self.data
"""
    file_path = tmp_path / "test_code.py"
    file_path.write_text(code, encoding="utf-8")

    func_to_optimize = FunctionToOptimize(
        function_name="target_method",
        file_path=file_path,
        parents=[FunctionParent(name="MyCustomDict", type="ClassDef")],
    )
    code_ctx = get_code_optimization_context(
        function_to_optimize=func_to_optimize,
        project_root_path=tmp_path,
    )

    # The testgen context should include the UserDict __init__ method
    testgen_context = code_ctx.testgen_context.markdown
    assert "class UserDict:" in testgen_context, "UserDict class should be in testgen context"
    assert "def __init__" in testgen_context, "UserDict __init__ should be in testgen context"
    assert "self.data = {}" in testgen_context, "UserDict __init__ body should be included"
def test_read_only_code_removed_when_exceeds_limit(tmp_path: Path) -> None:
    """Test read-only code is completely removed when it exceeds token limit even without docstrings.

    This covers lines 152-153 in code_context_extractor.py where read_only_context_code is set
    to empty string when it still exceeds the token limit after docstring removal.
    """
    # Create a second-degree helper with large implementation that has no docstrings
    # Second-degree helpers go into read-only context
    # Each generated line carries a function-body indent because long_body is
    # spliced in directly under "def second_helper():" in the fixture.
    long_lines = ["    x = 0"]
    long_lines.extend(f"    x = x + {i}" for i in range(150))
    long_lines.append("    return x")
    long_body = "\n".join(long_lines)

    code = f"""
class MyClass:
    def __init__(self):
        self.x = 1
    def target_method(self):
        return first_helper()
def first_helper():
    # First degree helper - calls second degree
    return second_helper()
def second_helper():
    # Second degree helper - goes into read-only context
{long_body}
"""
    file_path = tmp_path / "test_code.py"
    file_path.write_text(code, encoding="utf-8")

    func_to_optimize = FunctionToOptimize(
        function_name="target_method",
        file_path=file_path,
        parents=[FunctionParent(name="MyClass", type="ClassDef")],
    )
    # Use a small optim_token_limit that allows read-writable but not read-only
    # Read-writable is ~48 tokens, read-only is ~600 tokens
    code_ctx = get_code_optimization_context(
        function_to_optimize=func_to_optimize,
        project_root_path=tmp_path,
        optim_token_limit=100,  # Small limit to trigger read-only removal
    )

    # The read-only context should be empty because it exceeded the limit
    assert code_ctx.read_only_context_code == "", "Read-only code should be removed when exceeding token limit"
def test_testgen_removes_imported_classes_on_overflow(tmp_path: Path) -> None:
    """Test testgen context removes imported class definitions when exceeding token limit.

    This covers lines 176-186 in code_context_extractor.py where:
    - Testgen context exceeds limit (line 175)
    - Removing docstrings still exceeds (line 175 again)
    - Removing imported classes succeeds (line 177-183)
    """
    # Create a package structure with a large type class used only in type annotations
    # This ensures get_imported_class_definitions extracts the full class
    package_dir = tmp_path / "mypackage"
    package_dir.mkdir()
    (package_dir / "__init__.py").write_text("", encoding="utf-8")

    # Create a large class with methods that will be extracted via get_imported_class_definitions
    # Use methods WITHOUT docstrings so removing docstrings won't help much
    # The generated methods are indented one level so they sit inside TypeClass.
    many_methods = "\n".join(f"    def method_{i}(self):\n        return {i}" for i in range(100))
    type_class_code = f'''
class TypeClass:
    """A type class for annotations."""
    def __init__(self, value: int):
        self.value = value
{many_methods}
'''
    type_class_path = package_dir / "types.py"
    type_class_path.write_text(type_class_code, encoding="utf-8")

    # Main module uses TypeClass only in annotation (not instantiated)
    # This triggers get_imported_class_definitions to extract the full class
    main_code = """
from mypackage.types import TypeClass
def target_function(obj: TypeClass) -> int:
    return obj.value
"""
    main_path = package_dir / "main.py"
    main_path.write_text(main_code, encoding="utf-8")

    func_to_optimize = FunctionToOptimize(
        function_name="target_function",
        file_path=main_path,
        parents=[],
    )
    # Use a testgen_token_limit that:
    # - Is exceeded by full context with imported class (~1500 tokens)
    # - Is exceeded even after removing docstrings
    # - But fits when imported class is removed (~40 tokens)
    code_ctx = get_code_optimization_context(
        function_to_optimize=func_to_optimize,
        project_root_path=tmp_path,
        testgen_token_limit=200,  # Small limit to trigger imported class removal
    )

    # The testgen context should exist (didn't raise ValueError)
    testgen_context = code_ctx.testgen_context.markdown
    assert testgen_context, "Testgen context should not be empty"
    # The target function should still be there
    assert "def target_function" in testgen_context, "Target function should be in testgen context"
    # The large imported class should NOT be included (removed due to token limit)
    assert "class TypeClass" not in testgen_context, (
        "TypeClass should be removed from testgen context when exceeding token limit"
    )
def test_testgen_raises_when_all_fallbacks_fail(tmp_path: Path) -> None:
    """Test that ValueError is raised when testgen context exceeds limit even after all fallbacks.

    This covers line 186 in code_context_extractor.py.
    """
    # Create a function with a very long body that exceeds limits even without imports/docstrings
    # Each generated line carries a function-body indent because long_body is
    # spliced in directly under "def target_function():" in the fixture.
    long_lines = ["    x = 0"]
    long_lines.extend(f"    x = x + {i}" for i in range(200))
    long_lines.append("    return x")
    long_body = "\n".join(long_lines)

    code = f"""
def target_function():
{long_body}
"""
    file_path = tmp_path / "test_code.py"
    file_path.write_text(code, encoding="utf-8")

    func_to_optimize = FunctionToOptimize(
        function_name="target_function",
        file_path=file_path,
        parents=[],
    )
    # Use a very small testgen_token_limit that cannot fit even the base function
    with pytest.raises(ValueError, match="Testgen code context has exceeded token limit"):
        get_code_optimization_context(
            function_to_optimize=func_to_optimize,
            project_root_path=tmp_path,
            testgen_token_limit=50,  # Very small limit
        )
def test_get_external_base_class_inits_attribute_base(tmp_path: Path) -> None:
    """Test extraction of an external base class __init__ for a subclass of UserDict.

    This covers line 616 in code_context_extractor.py.

    NOTE(review): the test name refers to a base class accessed as
    module.ClassName (ast.Attribute), but the fixture below imports UserDict
    directly and uses it as a plain-name base. Confirm whether line 616 is
    really exercised by this input.
    """
    # Use the standard import style which the code actually handles
    code = """from collections import UserDict
class MyDict(UserDict):
    def custom_method(self):
        return self.data
"""
    code_path = tmp_path / "mydict.py"
    code_path.write_text(code, encoding="utf-8")

    context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
    result = get_external_base_class_inits(context, tmp_path)

    # Should extract UserDict __init__
    assert len(result.code_strings) == 1
    assert "class UserDict:" in result.code_strings[0].code
    assert "def __init__" in result.code_strings[0].code
def test_get_external_base_class_inits_no_init_method(tmp_path: Path) -> None:
    """Test handling when base class has no __init__ method.

    This covers line 641 in code_context_extractor.py.

    The only assertion is that the call completes and returns a list of code
    strings; the contents of that list are deliberately not checked.
    """
    # Create a class inheriting from a class that doesn't have inspectable __init__
    code = """from typing import Protocol
class MyProtocol(Protocol):
    pass
"""
    code_path = tmp_path / "myproto.py"
    code_path.write_text(code, encoding="utf-8")

    context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
    result = get_external_base_class_inits(context, tmp_path)

    # Protocol's __init__ can't be easily inspected, should handle gracefully
    # Result may be empty or contain Protocol based on implementation
    assert isinstance(result.code_strings, list)
def test_collect_names_from_annotation_attribute(tmp_path: Path) -> None:
    """Check that collect_names_from_annotation copes with an ast.Attribute annotation.

    This covers line 756 in code_context_extractor.py.
    """
    # Loaded via __import__ so the test file's detected imports are not polluted.
    ast_mod = __import__("ast")

    # The annotation "typing.List[int]" is reached through attribute access.
    parsed = ast_mod.parse("x: typing.List[int] = []")
    collected: set[str] = set()

    # Only the first annotated assignment found by the walk is inspected.
    first_annotated = next(
        (
            candidate
            for candidate in ast_mod.walk(parsed)
            if isinstance(candidate, ast_mod.AnnAssign) and candidate.annotation
        ),
        None,
    )
    if first_annotated is not None:
        collect_names_from_annotation(first_annotated.annotation, collected)

    assert "typing" in collected
def test_extract_imports_for_class_decorator_call_attribute(tmp_path: Path) -> None:
    """Test extract_imports_for_class handles decorator calls with attribute access.

    This covers lines 707-708 in code_context_extractor.py.
    """
    # Loaded via __import__, following the convention used by the neighbouring
    # annotation test to keep this file's detected imports unpolluted.
    ast_mod = __import__("ast")
    code = """
import functools
@functools.lru_cache(maxsize=128)
class CachedClass:
    pass
"""
    tree = ast_mod.parse(code)

    # Find the class node (first ClassDef encountered by the walk)
    class_node = next(
        (node for node in ast_mod.walk(tree) if isinstance(node, ast_mod.ClassDef)),
        None,
    )
    assert class_node is not None

    result = extract_imports_for_class(tree, class_node, code)
    # Should include the functools import
    assert "functools" in result
def test_annotated_assignment_in_read_writable(tmp_path: Path) -> None:
    """Test that annotated assignments used by target function are in read-writable context.

    This covers lines 965-969 in code_context_extractor.py.
    """
    # CONFIG_VALUE is a module-level annotated assignment read by MyClass.__init__.
    code = """
CONFIG_VALUE: int = 42
class MyClass:
    def __init__(self):
        self.x = CONFIG_VALUE
    def target_method(self):
        return self.x
"""
    file_path = tmp_path / "test_code.py"
    file_path.write_text(code, encoding="utf-8")

    func_to_optimize = FunctionToOptimize(
        function_name="target_method",
        file_path=file_path,
        parents=[FunctionParent(name="MyClass", type="ClassDef")],
    )
    code_ctx = get_code_optimization_context(
        function_to_optimize=func_to_optimize,
        project_root_path=tmp_path,
    )

    # CONFIG_VALUE should be in read-writable context since it's used by __init__
    read_writable = code_ctx.read_writable_code.markdown
    assert "CONFIG_VALUE" in read_writable
def test_imported_class_definitions_module_path_none(tmp_path: Path) -> None:
    """Test handling when module_path is None in get_imported_class_definitions.

    This covers line 560 in code_context_extractor.py.

    The only assertion is that the call completes and returns a list of code
    strings; the contents of that list are deliberately not checked.
    """
    # Create code that imports from a non-existent or unresolvable module
    code = """
from nonexistent_module_xyz import SomeClass
class MyClass:
    def method(self, obj: SomeClass):
        pass
"""
    code_path = tmp_path / "test.py"
    code_path.write_text(code, encoding="utf-8")

    context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
    result = get_imported_class_definitions(context, tmp_path)

    # Should handle gracefully and return empty or partial results
    assert isinstance(result.code_strings, list)
def test_get_imported_names_import_star(tmp_path: Path) -> None:
    """Check that get_imported_names reports a star import as {"*"}.

    This covers lines 808-809 and 824-825 in code_context_extractor.py.
    """
    import libcst as cst

    from codeflash.context.code_context_extractor import get_imported_names

    # A bare "import *" is not valid Python; only the "from x import *" form exists.
    statement = cst.parse_statement("from os import *")
    assert isinstance(statement, cst.SimpleStatementLine)

    star_import = statement.body[0]
    assert isinstance(star_import, cst.ImportFrom)

    assert get_imported_names(star_import) == {"*"}
def test_get_imported_names_aliased_import(tmp_path: Path) -> None:
    """Check that get_imported_names reports the alias for aliased imports.

    This covers lines 812-813 and 828-829 in code_context_extractor.py.
    """
    import libcst as cst

    from codeflash.context.code_context_extractor import get_imported_names

    # (statement text, expected libcst node type, alias that must be reported)
    scenarios = (
        ("import numpy as np", cst.Import, "np"),
        ("from os import path as ospath", cst.ImportFrom, "ospath"),
    )
    for source_text, node_type, expected_alias in scenarios:
        statement = cst.parse_statement(source_text)
        assert isinstance(statement, cst.SimpleStatementLine)
        inner = statement.body[0]
        assert isinstance(inner, node_type)
        assert expected_alias in get_imported_names(inner)
def test_get_imported_names_dotted_import(tmp_path: Path) -> None:
    """Check that get_imported_names reports "os" for the dotted import "import os.path".

    This covers lines 816-822 in code_context_extractor.py.
    """
    import libcst as cst

    from codeflash.context.code_context_extractor import get_imported_names

    # Dotted import such as "import os.path"
    parsed_line = cst.parse_statement("import os.path")
    assert isinstance(parsed_line, cst.SimpleStatementLine)

    dotted_import = parsed_line.body[0]
    assert isinstance(dotted_import, cst.Import)

    reported = get_imported_names(dotted_import)
    assert "os" in reported
def test_used_name_collector_comprehensive(tmp_path: Path) -> None:
    """Test UsedNameCollector handles various node types.

    This covers lines 767-801 in code_context_extractor.py.
    """
    import libcst as cst

    from codeflash.context.code_context_extractor import UsedNameCollector

    # Fixture mixing imports, an annotated assignment, a plain assignment,
    # a class with a body statement and a function definition.
    code = """
import os
from typing import List
x: int = 1
y = os.path.join("a", "b")
class MyClass:
    z = 10
def my_func():
    pass
"""
    module = cst.parse_module(code)
    collector = UsedNameCollector()
    # In libcst, the walker traverses the module
    cst.MetadataWrapper(module).visit(collector)

    # Check used names
    assert "os" in collector.used_names
    assert "int" in collector.used_names
    assert "List" in collector.used_names

    # Check defined names
    assert "x" in collector.defined_names
    assert "y" in collector.defined_names
    assert "MyClass" in collector.defined_names
    assert "my_func" in collector.defined_names

    # Check external names (used but not defined)
    external = collector.get_external_names()
    assert "os" in external
    assert "x" not in external  # x is defined
def test_imported_class_with_base_in_same_module(tmp_path: Path) -> None:
    """Test that imported classes with bases in the same module are extracted correctly.

    This covers line 528 in code_context_extractor.py - early return for already extracted.

    NOTE(review): the final assertion passes if either BaseClass or
    DerivedClass appears in the result, so it does not prove that the whole
    chain was extracted.
    """
    package_dir = tmp_path / "mypackage"
    package_dir.mkdir()
    (package_dir / "__init__.py").write_text("", encoding="utf-8")

    # Create a module with inheritance chain
    module_code = """
class BaseClass:
    def __init__(self):
        self.base = True
class MiddleClass(BaseClass):
    def __init__(self):
        super().__init__()
        self.middle = True
class DerivedClass(MiddleClass):
    def __init__(self):
        super().__init__()
        self.derived = True
"""
    module_path = package_dir / "classes.py"
    module_path.write_text(module_code, encoding="utf-8")

    # Main module imports and uses the derived class
    main_code = """
from mypackage.classes import DerivedClass
def target_function(obj: DerivedClass) -> bool:
    return obj.derived
"""
    main_path = package_dir / "main.py"
    main_path.write_text(main_code, encoding="utf-8")

    context = CodeStringsMarkdown(code_strings=[CodeString(code=main_code, file_path=main_path)])
    result = get_imported_class_definitions(context, tmp_path)

    # Should extract the inheritance chain
    all_code = "\n".join(cs.code for cs in result.code_strings)
    assert "class BaseClass" in all_code or "class DerivedClass" in all_code
def test_get_imported_names_from_import_without_alias(tmp_path: Path) -> None:
    """Check that get_imported_names reports each name of an un-aliased from-import.

    This covers lines 830-831 in code_context_extractor.py.
    """
    import libcst as cst

    from codeflash.context.code_context_extractor import get_imported_names

    # From-import listing two names, neither of them aliased
    parsed_line = cst.parse_statement("from os import path, getcwd")
    assert isinstance(parsed_line, cst.SimpleStatementLine)

    plain_from_import = parsed_line.body[0]
    assert isinstance(plain_from_import, cst.ImportFrom)

    reported = get_imported_names(plain_from_import)
    for expected_name in ("path", "getcwd"):
        assert expected_name in reported
def test_get_imported_names_regular_import(tmp_path: Path) -> None:
    """Check that get_imported_names reports the module name of a plain import.

    This covers lines 814-815 in code_context_extractor.py.
    """
    import libcst as cst

    from codeflash.context.code_context_extractor import get_imported_names

    # Plain "import <module>" with no alias
    parsed_line = cst.parse_statement("import json")
    assert isinstance(parsed_line, cst.SimpleStatementLine)

    plain_import = parsed_line.body[0]
    assert isinstance(plain_import, cst.Import)

    assert "json" in get_imported_names(plain_import)
def test_augmented_assignment_not_in_context(tmp_path: Path) -> None:
    """Test that augmented assignments are handled but not included unless used.

    This covers line 962-969 in code_context_extractor.py.

    In this fixture the module-level ``counter`` IS used (MyClass.__init__
    augments it via ``global``), so despite the test name the assertion
    expects ``counter`` to be present in the read-writable context.
    """
    code = """
counter = 0
class MyClass:
    def __init__(self):
        global counter
        counter += 1
    def target_method(self):
        return 42
"""
    file_path = tmp_path / "test_code.py"
    file_path.write_text(code, encoding="utf-8")

    func_to_optimize = FunctionToOptimize(
        function_name="target_method",
        file_path=file_path,
        parents=[FunctionParent(name="MyClass", type="ClassDef")],
    )
    code_ctx = get_code_optimization_context(
        function_to_optimize=func_to_optimize,
        project_root_path=tmp_path,
    )

    # counter should be in context since __init__ uses it
    read_writable = code_ctx.read_writable_code.markdown
    assert "counter" in read_writable