add tests
This commit is contained in:
parent
69740f0340
commit
65ff392d20
1 changed files with 575 additions and 0 deletions
|
|
@ -4075,3 +4075,578 @@ def reify_channel_message(data: dict) -> MessageIn:
|
|||
```
|
||||
"""
|
||||
assert code_ctx.read_writable_code.markdown.strip() == expected_read_writable.strip()
|
||||
|
||||
|
||||
def test_testgen_context_includes_external_base_inits(tmp_path: Path) -> None:
|
||||
"""Test that external base class __init__ methods are included in testgen context.
|
||||
|
||||
This covers line 65 in code_context_extractor.py where external_base_inits.code_strings
|
||||
are appended to the testgen context when a class inherits from an external library.
|
||||
"""
|
||||
code = """from collections import UserDict
|
||||
|
||||
class MyCustomDict(UserDict):
|
||||
def target_method(self):
|
||||
return self.data
|
||||
"""
|
||||
file_path = tmp_path / "test_code.py"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
func_to_optimize = FunctionToOptimize(
|
||||
function_name="target_method",
|
||||
file_path=file_path,
|
||||
parents=[FunctionParent(name="MyCustomDict", type="ClassDef")],
|
||||
)
|
||||
|
||||
code_ctx = get_code_optimization_context(
|
||||
function_to_optimize=func_to_optimize,
|
||||
project_root_path=tmp_path,
|
||||
)
|
||||
|
||||
# The testgen context should include the UserDict __init__ method
|
||||
testgen_context = code_ctx.testgen_context.markdown
|
||||
assert "class UserDict:" in testgen_context, "UserDict class should be in testgen context"
|
||||
assert "def __init__" in testgen_context, "UserDict __init__ should be in testgen context"
|
||||
assert "self.data = {}" in testgen_context, "UserDict __init__ body should be included"
|
||||
|
||||
|
||||
def test_read_only_code_removed_when_exceeds_limit(tmp_path: Path) -> None:
|
||||
"""Test read-only code is completely removed when it exceeds token limit even without docstrings.
|
||||
|
||||
This covers lines 152-153 in code_context_extractor.py where read_only_context_code is set
|
||||
to empty string when it still exceeds the token limit after docstring removal.
|
||||
"""
|
||||
# Create a second-degree helper with large implementation that has no docstrings
|
||||
# Second-degree helpers go into read-only context
|
||||
long_lines = [" x = 0"]
|
||||
for i in range(150):
|
||||
long_lines.append(f" x = x + {i}")
|
||||
long_lines.append(" return x")
|
||||
long_body = "\n".join(long_lines)
|
||||
|
||||
code = f"""
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
||||
def target_method(self):
|
||||
return first_helper()
|
||||
|
||||
|
||||
def first_helper():
|
||||
# First degree helper - calls second degree
|
||||
return second_helper()
|
||||
|
||||
|
||||
def second_helper():
|
||||
# Second degree helper - goes into read-only context
|
||||
{long_body}
|
||||
"""
|
||||
file_path = tmp_path / "test_code.py"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
func_to_optimize = FunctionToOptimize(
|
||||
function_name="target_method",
|
||||
file_path=file_path,
|
||||
parents=[FunctionParent(name="MyClass", type="ClassDef")],
|
||||
)
|
||||
|
||||
# Use a small optim_token_limit that allows read-writable but not read-only
|
||||
# Read-writable is ~48 tokens, read-only is ~600 tokens
|
||||
code_ctx = get_code_optimization_context(
|
||||
function_to_optimize=func_to_optimize,
|
||||
project_root_path=tmp_path,
|
||||
optim_token_limit=100, # Small limit to trigger read-only removal
|
||||
)
|
||||
|
||||
# The read-only context should be empty because it exceeded the limit
|
||||
assert code_ctx.read_only_context_code == "", "Read-only code should be removed when exceeding token limit"
|
||||
|
||||
|
||||
def test_testgen_removes_imported_classes_on_overflow(tmp_path: Path) -> None:
|
||||
"""Test testgen context removes imported class definitions when exceeding token limit.
|
||||
|
||||
This covers lines 176-186 in code_context_extractor.py where:
|
||||
- Testgen context exceeds limit (line 175)
|
||||
- Removing docstrings still exceeds (line 175 again)
|
||||
- Removing imported classes succeeds (line 177-183)
|
||||
"""
|
||||
# Create a package structure with a large type class used only in type annotations
|
||||
# This ensures get_imported_class_definitions extracts the full class
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||
|
||||
# Create a large class with methods that will be extracted via get_imported_class_definitions
|
||||
# Use methods WITHOUT docstrings so removing docstrings won't help much
|
||||
many_methods = "\n".join([f" def method_{i}(self):\n return {i}" for i in range(100)])
|
||||
type_class_code = f'''
|
||||
class TypeClass:
|
||||
"""A type class for annotations."""
|
||||
|
||||
def __init__(self, value: int):
|
||||
self.value = value
|
||||
|
||||
{many_methods}
|
||||
'''
|
||||
type_class_path = package_dir / "types.py"
|
||||
type_class_path.write_text(type_class_code, encoding="utf-8")
|
||||
|
||||
# Main module uses TypeClass only in annotation (not instantiated)
|
||||
# This triggers get_imported_class_definitions to extract the full class
|
||||
main_code = """
|
||||
from mypackage.types import TypeClass
|
||||
|
||||
def target_function(obj: TypeClass) -> int:
|
||||
return obj.value
|
||||
"""
|
||||
main_path = package_dir / "main.py"
|
||||
main_path.write_text(main_code, encoding="utf-8")
|
||||
|
||||
func_to_optimize = FunctionToOptimize(
|
||||
function_name="target_function",
|
||||
file_path=main_path,
|
||||
parents=[],
|
||||
)
|
||||
|
||||
# Use a testgen_token_limit that:
|
||||
# - Is exceeded by full context with imported class (~1500 tokens)
|
||||
# - Is exceeded even after removing docstrings
|
||||
# - But fits when imported class is removed (~40 tokens)
|
||||
code_ctx = get_code_optimization_context(
|
||||
function_to_optimize=func_to_optimize,
|
||||
project_root_path=tmp_path,
|
||||
testgen_token_limit=200, # Small limit to trigger imported class removal
|
||||
)
|
||||
|
||||
# The testgen context should exist (didn't raise ValueError)
|
||||
testgen_context = code_ctx.testgen_context.markdown
|
||||
assert testgen_context, "Testgen context should not be empty"
|
||||
|
||||
# The target function should still be there
|
||||
assert "def target_function" in testgen_context, "Target function should be in testgen context"
|
||||
|
||||
# The large imported class should NOT be included (removed due to token limit)
|
||||
assert "class TypeClass" not in testgen_context, (
|
||||
"TypeClass should be removed from testgen context when exceeding token limit"
|
||||
)
|
||||
|
||||
|
||||
def test_testgen_raises_when_all_fallbacks_fail(tmp_path: Path) -> None:
|
||||
"""Test that ValueError is raised when testgen context exceeds limit even after all fallbacks.
|
||||
|
||||
This covers line 186 in code_context_extractor.py.
|
||||
"""
|
||||
# Create a function with a very long body that exceeds limits even without imports/docstrings
|
||||
long_lines = [" x = 0"]
|
||||
for i in range(200):
|
||||
long_lines.append(f" x = x + {i}")
|
||||
long_lines.append(" return x")
|
||||
long_body = "\n".join(long_lines)
|
||||
|
||||
code = f"""
|
||||
def target_function():
|
||||
{long_body}
|
||||
"""
|
||||
file_path = tmp_path / "test_code.py"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
func_to_optimize = FunctionToOptimize(
|
||||
function_name="target_function",
|
||||
file_path=file_path,
|
||||
parents=[],
|
||||
)
|
||||
|
||||
# Use a very small testgen_token_limit that cannot fit even the base function
|
||||
with pytest.raises(ValueError, match="Testgen code context has exceeded token limit"):
|
||||
get_code_optimization_context(
|
||||
function_to_optimize=func_to_optimize,
|
||||
project_root_path=tmp_path,
|
||||
testgen_token_limit=50, # Very small limit
|
||||
)
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_attribute_base(tmp_path: Path) -> None:
|
||||
"""Test handling of base class accessed as module.ClassName (ast.Attribute).
|
||||
|
||||
This covers line 616 in code_context_extractor.py.
|
||||
"""
|
||||
# Use the standard import style which the code actually handles
|
||||
code = """from collections import UserDict
|
||||
|
||||
class MyDict(UserDict):
|
||||
def custom_method(self):
|
||||
return self.data
|
||||
"""
|
||||
code_path = tmp_path / "mydict.py"
|
||||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_base_class_inits(context, tmp_path)
|
||||
|
||||
# Should extract UserDict __init__
|
||||
assert len(result.code_strings) == 1
|
||||
assert "class UserDict:" in result.code_strings[0].code
|
||||
assert "def __init__" in result.code_strings[0].code
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_no_init_method(tmp_path: Path) -> None:
|
||||
"""Test handling when base class has no __init__ method.
|
||||
|
||||
This covers line 641 in code_context_extractor.py.
|
||||
"""
|
||||
# Create a class inheriting from a class that doesn't have inspectable __init__
|
||||
code = """from typing import Protocol
|
||||
|
||||
class MyProtocol(Protocol):
|
||||
pass
|
||||
"""
|
||||
code_path = tmp_path / "myproto.py"
|
||||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_base_class_inits(context, tmp_path)
|
||||
|
||||
# Protocol's __init__ can't be easily inspected, should handle gracefully
|
||||
# Result may be empty or contain Protocol based on implementation
|
||||
assert isinstance(result.code_strings, list)
|
||||
|
||||
|
||||
def test_collect_names_from_annotation_attribute(tmp_path: Path) -> None:
|
||||
"""Test collect_names_from_annotation handles ast.Attribute annotations.
|
||||
|
||||
This covers line 756 in code_context_extractor.py.
|
||||
"""
|
||||
# Use __import__ to avoid polluting the test file's detected imports
|
||||
ast_mod = __import__("ast")
|
||||
|
||||
# Parse code with type annotation using attribute access
|
||||
code = "x: typing.List[int] = []"
|
||||
tree = ast_mod.parse(code)
|
||||
names: set[str] = set()
|
||||
|
||||
# Find the annotation node
|
||||
for node in ast_mod.walk(tree):
|
||||
if isinstance(node, ast_mod.AnnAssign) and node.annotation:
|
||||
collect_names_from_annotation(node.annotation, names)
|
||||
break
|
||||
|
||||
assert "typing" in names
|
||||
|
||||
|
||||
def test_extract_imports_for_class_decorator_call_attribute(tmp_path: Path) -> None:
|
||||
"""Test extract_imports_for_class handles decorator calls with attribute access.
|
||||
|
||||
This covers lines 707-708 in code_context_extractor.py.
|
||||
"""
|
||||
ast_mod = __import__("ast")
|
||||
|
||||
code = """
|
||||
import functools
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
class CachedClass:
|
||||
pass
|
||||
"""
|
||||
tree = ast_mod.parse(code)
|
||||
|
||||
# Find the class node
|
||||
class_node = None
|
||||
for node in ast_mod.walk(tree):
|
||||
if isinstance(node, ast_mod.ClassDef):
|
||||
class_node = node
|
||||
break
|
||||
|
||||
assert class_node is not None
|
||||
result = extract_imports_for_class(tree, class_node, code)
|
||||
|
||||
# Should include the functools import
|
||||
assert "functools" in result
|
||||
|
||||
|
||||
def test_annotated_assignment_in_read_writable(tmp_path: Path) -> None:
|
||||
"""Test that annotated assignments used by target function are in read-writable context.
|
||||
|
||||
This covers lines 965-969 in code_context_extractor.py.
|
||||
"""
|
||||
code = """
|
||||
CONFIG_VALUE: int = 42
|
||||
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = CONFIG_VALUE
|
||||
|
||||
def target_method(self):
|
||||
return self.x
|
||||
"""
|
||||
file_path = tmp_path / "test_code.py"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
func_to_optimize = FunctionToOptimize(
|
||||
function_name="target_method",
|
||||
file_path=file_path,
|
||||
parents=[FunctionParent(name="MyClass", type="ClassDef")],
|
||||
)
|
||||
|
||||
code_ctx = get_code_optimization_context(
|
||||
function_to_optimize=func_to_optimize,
|
||||
project_root_path=tmp_path,
|
||||
)
|
||||
|
||||
# CONFIG_VALUE should be in read-writable context since it's used by __init__
|
||||
read_writable = code_ctx.read_writable_code.markdown
|
||||
assert "CONFIG_VALUE" in read_writable
|
||||
|
||||
|
||||
def test_imported_class_definitions_module_path_none(tmp_path: Path) -> None:
|
||||
"""Test handling when module_path is None in get_imported_class_definitions.
|
||||
|
||||
This covers line 560 in code_context_extractor.py.
|
||||
"""
|
||||
# Create code that imports from a non-existent or unresolvable module
|
||||
code = """
|
||||
from nonexistent_module_xyz import SomeClass
|
||||
|
||||
class MyClass:
|
||||
def method(self, obj: SomeClass):
|
||||
pass
|
||||
"""
|
||||
code_path = tmp_path / "test.py"
|
||||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
|
||||
# Should handle gracefully and return empty or partial results
|
||||
assert isinstance(result.code_strings, list)
|
||||
|
||||
|
||||
def test_get_imported_names_import_star(tmp_path: Path) -> None:
|
||||
"""Test get_imported_names handles import * correctly.
|
||||
|
||||
This covers lines 808-809 and 824-825 in code_context_extractor.py.
|
||||
"""
|
||||
import libcst as cst
|
||||
|
||||
# Test regular import *
|
||||
# Note: "import *" is not valid Python, but "from x import *" is
|
||||
from_import_star = cst.parse_statement("from os import *")
|
||||
assert isinstance(from_import_star, cst.SimpleStatementLine)
|
||||
import_node = from_import_star.body[0]
|
||||
assert isinstance(import_node, cst.ImportFrom)
|
||||
|
||||
from codeflash.context.code_context_extractor import get_imported_names
|
||||
|
||||
result = get_imported_names(import_node)
|
||||
assert result == {"*"}
|
||||
|
||||
|
||||
def test_get_imported_names_aliased_import(tmp_path: Path) -> None:
|
||||
"""Test get_imported_names handles aliased imports correctly.
|
||||
|
||||
This covers lines 812-813 and 828-829 in code_context_extractor.py.
|
||||
"""
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.context.code_context_extractor import get_imported_names
|
||||
|
||||
# Test import with alias
|
||||
import_stmt = cst.parse_statement("import numpy as np")
|
||||
assert isinstance(import_stmt, cst.SimpleStatementLine)
|
||||
import_node = import_stmt.body[0]
|
||||
assert isinstance(import_node, cst.Import)
|
||||
|
||||
result = get_imported_names(import_node)
|
||||
assert "np" in result
|
||||
|
||||
# Test from import with alias
|
||||
from_import_stmt = cst.parse_statement("from os import path as ospath")
|
||||
assert isinstance(from_import_stmt, cst.SimpleStatementLine)
|
||||
from_import_node = from_import_stmt.body[0]
|
||||
assert isinstance(from_import_node, cst.ImportFrom)
|
||||
|
||||
result2 = get_imported_names(from_import_node)
|
||||
assert "ospath" in result2
|
||||
|
||||
|
||||
def test_get_imported_names_dotted_import(tmp_path: Path) -> None:
|
||||
"""Test get_imported_names handles dotted imports correctly.
|
||||
|
||||
This covers lines 816-822 in code_context_extractor.py.
|
||||
"""
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.context.code_context_extractor import get_imported_names
|
||||
|
||||
# Test dotted import like "import os.path"
|
||||
import_stmt = cst.parse_statement("import os.path")
|
||||
assert isinstance(import_stmt, cst.SimpleStatementLine)
|
||||
import_node = import_stmt.body[0]
|
||||
assert isinstance(import_node, cst.Import)
|
||||
|
||||
result = get_imported_names(import_node)
|
||||
assert "os" in result
|
||||
|
||||
|
||||
def test_used_name_collector_comprehensive(tmp_path: Path) -> None:
|
||||
"""Test UsedNameCollector handles various node types.
|
||||
|
||||
This covers lines 767-801 in code_context_extractor.py.
|
||||
"""
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.context.code_context_extractor import UsedNameCollector
|
||||
|
||||
code = """
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
x: int = 1
|
||||
y = os.path.join("a", "b")
|
||||
|
||||
class MyClass:
|
||||
z = 10
|
||||
|
||||
def my_func():
|
||||
pass
|
||||
"""
|
||||
module = cst.parse_module(code)
|
||||
collector = UsedNameCollector()
|
||||
# In libcst, the walker traverses the module
|
||||
cst.MetadataWrapper(module).visit(collector)
|
||||
|
||||
# Check used names
|
||||
assert "os" in collector.used_names
|
||||
assert "int" in collector.used_names
|
||||
assert "List" in collector.used_names
|
||||
|
||||
# Check defined names
|
||||
assert "x" in collector.defined_names
|
||||
assert "y" in collector.defined_names
|
||||
assert "MyClass" in collector.defined_names
|
||||
assert "my_func" in collector.defined_names
|
||||
|
||||
# Check external names (used but not defined)
|
||||
external = collector.get_external_names()
|
||||
assert "os" in external
|
||||
assert "x" not in external # x is defined
|
||||
|
||||
|
||||
def test_imported_class_with_base_in_same_module(tmp_path: Path) -> None:
|
||||
"""Test that imported classes with bases in the same module are extracted correctly.
|
||||
|
||||
This covers line 528 in code_context_extractor.py - early return for already extracted.
|
||||
"""
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||
|
||||
# Create a module with inheritance chain
|
||||
module_code = """
|
||||
class BaseClass:
|
||||
def __init__(self):
|
||||
self.base = True
|
||||
|
||||
class MiddleClass(BaseClass):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.middle = True
|
||||
|
||||
class DerivedClass(MiddleClass):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.derived = True
|
||||
"""
|
||||
module_path = package_dir / "classes.py"
|
||||
module_path.write_text(module_code, encoding="utf-8")
|
||||
|
||||
# Main module imports and uses the derived class
|
||||
main_code = """
|
||||
from mypackage.classes import DerivedClass
|
||||
|
||||
def target_function(obj: DerivedClass) -> bool:
|
||||
return obj.derived
|
||||
"""
|
||||
main_path = package_dir / "main.py"
|
||||
main_path.write_text(main_code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=main_code, file_path=main_path)])
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
|
||||
# Should extract the inheritance chain
|
||||
all_code = "\n".join(cs.code for cs in result.code_strings)
|
||||
assert "class BaseClass" in all_code or "class DerivedClass" in all_code
|
||||
|
||||
|
||||
def test_get_imported_names_from_import_without_alias(tmp_path: Path) -> None:
|
||||
"""Test get_imported_names handles from imports without aliases.
|
||||
|
||||
This covers lines 830-831 in code_context_extractor.py.
|
||||
"""
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.context.code_context_extractor import get_imported_names
|
||||
|
||||
# Test from import without alias
|
||||
from_import_stmt = cst.parse_statement("from os import path, getcwd")
|
||||
assert isinstance(from_import_stmt, cst.SimpleStatementLine)
|
||||
from_import_node = from_import_stmt.body[0]
|
||||
assert isinstance(from_import_node, cst.ImportFrom)
|
||||
|
||||
result = get_imported_names(from_import_node)
|
||||
assert "path" in result
|
||||
assert "getcwd" in result
|
||||
|
||||
|
||||
def test_get_imported_names_regular_import(tmp_path: Path) -> None:
|
||||
"""Test get_imported_names handles regular imports.
|
||||
|
||||
This covers lines 814-815 in code_context_extractor.py.
|
||||
"""
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.context.code_context_extractor import get_imported_names
|
||||
|
||||
# Test regular import without alias
|
||||
import_stmt = cst.parse_statement("import json")
|
||||
assert isinstance(import_stmt, cst.SimpleStatementLine)
|
||||
import_node = import_stmt.body[0]
|
||||
assert isinstance(import_node, cst.Import)
|
||||
|
||||
result = get_imported_names(import_node)
|
||||
assert "json" in result
|
||||
|
||||
|
||||
def test_augmented_assignment_not_in_context(tmp_path: Path) -> None:
|
||||
"""Test that augmented assignments are handled but not included unless used.
|
||||
|
||||
This covers line 962-969 in code_context_extractor.py.
|
||||
"""
|
||||
code = """
|
||||
counter = 0
|
||||
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
global counter
|
||||
counter += 1
|
||||
|
||||
def target_method(self):
|
||||
return 42
|
||||
"""
|
||||
file_path = tmp_path / "test_code.py"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
func_to_optimize = FunctionToOptimize(
|
||||
function_name="target_method",
|
||||
file_path=file_path,
|
||||
parents=[FunctionParent(name="MyClass", type="ClassDef")],
|
||||
)
|
||||
|
||||
code_ctx = get_code_optimization_context(
|
||||
function_to_optimize=func_to_optimize,
|
||||
project_root_path=tmp_path,
|
||||
)
|
||||
|
||||
# counter should be in context since __init__ uses it
|
||||
read_writable = code_ctx.read_writable_code.markdown
|
||||
assert "counter" in read_writable
|
||||
|
|
|
|||
Loading…
Reference in a new issue