mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
parent
2966e15775
commit
c1703a2d71
4 changed files with 31 additions and 628 deletions
1
.codex/skills/.gitignore
vendored
1
.codex/skills/.gitignore
vendored
|
|
@ -1,3 +1,2 @@
|
|||
# Managed by Tessl
|
||||
tessl:*
|
||||
tessl__*
|
||||
|
|
|
|||
1
.gemini/skills/.gitignore
vendored
1
.gemini/skills/.gitignore
vendored
|
|
@ -1,3 +1,2 @@
|
|||
# Managed by Tessl
|
||||
tessl:*
|
||||
tessl__*
|
||||
|
|
|
|||
|
|
@ -68,26 +68,15 @@ def build_testgen_context(
|
|||
if enrichment.code_strings:
|
||||
testgen_context = CodeStringsMarkdown(code_strings=testgen_context.code_strings + enrichment.code_strings)
|
||||
|
||||
type_context_strings: list[CodeString] = []
|
||||
if function_to_optimize is not None:
|
||||
result = _parse_and_collect_imports(testgen_context)
|
||||
existing_classes = collect_existing_class_names(result[0]) if result else set()
|
||||
type_context = extract_type_context_for_testgen(function_to_optimize, project_root_path, existing_classes)
|
||||
if type_context.code_strings:
|
||||
constructor_stubs = extract_parameter_type_constructors(
|
||||
function_to_optimize, project_root_path, existing_classes
|
||||
)
|
||||
if constructor_stubs.code_strings:
|
||||
testgen_context = CodeStringsMarkdown(
|
||||
code_strings=testgen_context.code_strings + type_context.code_strings
|
||||
)
|
||||
type_context_strings = type_context.code_strings
|
||||
|
||||
# Enrich field types from all newly extracted classes (enrichment + type context)
|
||||
new_classes = CodeStringsMarkdown(code_strings=enrichment.code_strings + type_context_strings)
|
||||
if new_classes.code_strings:
|
||||
updated_result = _parse_and_collect_imports(testgen_context)
|
||||
updated_existing = collect_existing_class_names(updated_result[0]) if updated_result else set()
|
||||
field_type_enrichment = enrich_type_context_classes(new_classes, updated_existing, project_root_path)
|
||||
if field_type_enrichment.code_strings:
|
||||
testgen_context = CodeStringsMarkdown(
|
||||
code_strings=testgen_context.code_strings + field_type_enrichment.code_strings
|
||||
code_strings=testgen_context.code_strings + constructor_stubs.code_strings
|
||||
)
|
||||
|
||||
return testgen_context
|
||||
|
|
@ -755,48 +744,6 @@ def collect_type_names_from_annotation(node: ast.expr | None) -> set[str]:
|
|||
return set()
|
||||
|
||||
|
||||
def collect_names_from_function_body(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]:
|
||||
names: set[str] = set()
|
||||
for node in ast.walk(func_node):
|
||||
if isinstance(node, ast.Call):
|
||||
if isinstance(node.func, ast.Name):
|
||||
names.add(node.func.id)
|
||||
elif isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name):
|
||||
names.add(node.func.value.id)
|
||||
if isinstance(node.func, ast.Name) and node.func.id in ("isinstance", "issubclass") and node.args:
|
||||
second_arg = node.args[1] if len(node.args) > 1 else None
|
||||
if isinstance(second_arg, ast.Name):
|
||||
names.add(second_arg.id)
|
||||
elif isinstance(second_arg, ast.Tuple):
|
||||
for elt in second_arg.elts:
|
||||
if isinstance(elt, ast.Name):
|
||||
names.add(elt.id)
|
||||
elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
|
||||
names.add(node.value.id)
|
||||
return names
|
||||
|
||||
|
||||
def extract_full_class_from_module(class_name: str, module_source: str, module_tree: ast.Module) -> str | None:
|
||||
class_node = None
|
||||
# Use a deque-based BFS to find the first matching ClassDef (preserves ast.walk order)
|
||||
q: deque[ast.AST] = deque([module_tree])
|
||||
while q:
|
||||
candidate = q.popleft()
|
||||
if isinstance(candidate, ast.ClassDef) and candidate.name == class_name:
|
||||
class_node = candidate
|
||||
break
|
||||
q.extend(ast.iter_child_nodes(candidate))
|
||||
|
||||
if class_node is None:
|
||||
return None
|
||||
|
||||
lines = module_source.split("\n")
|
||||
start_line = class_node.lineno
|
||||
if class_node.decorator_list:
|
||||
start_line = min(d.lineno for d in class_node.decorator_list)
|
||||
return "\n".join(lines[start_line - 1 : class_node.end_lineno])
|
||||
|
||||
|
||||
def extract_init_stub_from_class(class_name: str, module_source: str, module_tree: ast.Module) -> str | None:
|
||||
class_node = None
|
||||
# Use a deque-based BFS to find the first matching ClassDef (preserves ast.walk order)
|
||||
|
|
@ -846,7 +793,7 @@ def extract_init_stub_from_class(class_name: str, module_source: str, module_tre
|
|||
return f"class {class_name}:\n" + "\n".join(snippets)
|
||||
|
||||
|
||||
def extract_type_context_for_testgen(
|
||||
def extract_parameter_type_constructors(
|
||||
function_to_optimize: FunctionToOptimize, project_root_path: Path, existing_class_names: set[str]
|
||||
) -> CodeStringsMarkdown:
|
||||
import jedi
|
||||
|
|
@ -878,8 +825,6 @@ def extract_type_context_for_testgen(
|
|||
if func_node.args.kwarg:
|
||||
type_names |= collect_type_names_from_annotation(func_node.args.kwarg.annotation)
|
||||
|
||||
type_names |= collect_names_from_function_body(func_node)
|
||||
|
||||
type_names -= BUILTIN_AND_TYPING_NAMES
|
||||
type_names -= existing_class_names
|
||||
if not type_names:
|
||||
|
|
@ -910,13 +855,6 @@ def extract_type_context_for_testgen(
|
|||
if not module_path:
|
||||
continue
|
||||
|
||||
resolved_module = module_path.resolve()
|
||||
module_str = str(resolved_module)
|
||||
is_project = module_str.startswith(str(project_root_path.resolve()))
|
||||
is_third_party = "site-packages" in module_str
|
||||
if not is_project and not is_third_party:
|
||||
continue
|
||||
|
||||
if module_path in module_cache:
|
||||
mod_source, mod_tree = module_cache[module_path]
|
||||
else:
|
||||
|
|
@ -924,15 +862,11 @@ def extract_type_context_for_testgen(
|
|||
mod_tree = ast.parse(mod_source)
|
||||
module_cache[module_path] = (mod_source, mod_tree)
|
||||
|
||||
class_source = extract_full_class_from_module(type_name, mod_source, mod_tree)
|
||||
if class_source is None:
|
||||
resolved_class = resolve_instance_class_name(type_name, mod_tree)
|
||||
if resolved_class and resolved_class not in existing_class_names:
|
||||
class_source = extract_full_class_from_module(resolved_class, mod_source, mod_tree)
|
||||
if class_source:
|
||||
code_strings.append(CodeString(code=class_source, file_path=module_path))
|
||||
stub = extract_init_stub_from_class(type_name, mod_source, mod_tree)
|
||||
if stub:
|
||||
code_strings.append(CodeString(code=stub, file_path=module_path))
|
||||
except Exception:
|
||||
logger.debug(f"Error extracting type context for {type_name} from {module_name}")
|
||||
logger.debug(f"Error extracting constructor stub for {type_name} from {module_name}")
|
||||
continue
|
||||
|
||||
return CodeStringsMarkdown(code_strings=code_strings)
|
||||
|
|
@ -978,7 +912,6 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path:
|
|||
# --- Step 1: Project class definitions (jedi resolution + recursive base extraction) ---
|
||||
extracted_classes: set[tuple[Path, str]] = set()
|
||||
module_cache: dict[Path, tuple[str, ast.Module]] = {}
|
||||
module_import_maps: dict[Path, dict[str, str]] = {}
|
||||
|
||||
def get_module_source_and_tree(module_path: Path) -> tuple[str, ast.Module] | None:
|
||||
if module_path in module_cache:
|
||||
|
|
@ -992,22 +925,10 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path:
|
|||
module_cache[module_path] = (module_source, module_tree)
|
||||
return module_source, module_tree
|
||||
|
||||
def get_module_import_map(module_path: Path, module_tree: ast.Module) -> dict[str, str]:
|
||||
if module_path in module_import_maps:
|
||||
return module_import_maps[module_path]
|
||||
import_map: dict[str, str] = {}
|
||||
for node in ast.walk(module_tree):
|
||||
if isinstance(node, ast.ImportFrom) and node.module:
|
||||
for alias in node.names:
|
||||
name = alias.asname if alias.asname else alias.name
|
||||
import_map[name] = node.module
|
||||
module_import_maps[module_path] = import_map
|
||||
return import_map
|
||||
|
||||
def extract_class_and_bases(
|
||||
class_name: str, module_path: Path, module_source: str, module_tree: ast.Module, depth: int = 0
|
||||
class_name: str, module_path: Path, module_source: str, module_tree: ast.Module
|
||||
) -> None:
|
||||
if depth >= 3 or (module_path, class_name) in extracted_classes:
|
||||
if (module_path, class_name) in extracted_classes:
|
||||
return
|
||||
|
||||
class_node = None
|
||||
|
|
@ -1026,40 +947,8 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path:
|
|||
elif isinstance(base, ast.Attribute):
|
||||
continue
|
||||
|
||||
if not base_name or base_name in existing_classes:
|
||||
continue
|
||||
|
||||
# Check if the base class is defined locally in this module
|
||||
local_base = any(isinstance(n, ast.ClassDef) and n.name == base_name for n in ast.walk(module_tree))
|
||||
if local_base:
|
||||
extract_class_and_bases(base_name, module_path, module_source, module_tree, depth + 1)
|
||||
else:
|
||||
# Resolve cross-module base class via imports
|
||||
import_map = get_module_import_map(module_path, module_tree)
|
||||
base_module_name = import_map.get(base_name)
|
||||
if not base_module_name:
|
||||
continue
|
||||
try:
|
||||
script_code = f"from {base_module_name} import {base_name}"
|
||||
script = jedi.Script(script_code, project=jedi.Project(path=project_root_path))
|
||||
definitions = script.goto(
|
||||
1, len(f"from {base_module_name} import ") + len(base_name), follow_imports=True
|
||||
)
|
||||
if not definitions or not definitions[0].module_path:
|
||||
continue
|
||||
base_module_path = definitions[0].module_path
|
||||
resolved_str = str(base_module_path.resolve())
|
||||
is_project = resolved_str.startswith(str(project_root_path.resolve()) + os.sep)
|
||||
is_third_party = "site-packages" in resolved_str
|
||||
if not is_project and not is_third_party:
|
||||
continue
|
||||
base_mod = get_module_source_and_tree(base_module_path)
|
||||
if base_mod is None:
|
||||
continue
|
||||
extract_class_and_bases(base_name, base_module_path, base_mod[0], base_mod[1], depth + 1)
|
||||
except Exception:
|
||||
logger.debug(f"Error resolving cross-module base class {base_name} from {base_module_name}")
|
||||
continue
|
||||
if base_name and base_name not in existing_classes:
|
||||
extract_class_and_bases(base_name, module_path, module_source, module_tree)
|
||||
|
||||
if (module_path, class_name) in extracted_classes:
|
||||
return
|
||||
|
|
@ -1117,96 +1006,6 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path:
|
|||
return CodeStringsMarkdown(code_strings=code_strings)
|
||||
|
||||
|
||||
def enrich_type_context_classes(
|
||||
type_context: CodeStringsMarkdown, existing_class_names: set[str], project_root_path: Path
|
||||
) -> CodeStringsMarkdown:
|
||||
import jedi
|
||||
|
||||
code_strings: list[CodeString] = []
|
||||
emitted: set[str] = set()
|
||||
module_cache: dict[Path, tuple[str, ast.Module]] = {}
|
||||
|
||||
for cs in type_context.code_strings:
|
||||
try:
|
||||
snippet_tree = ast.parse(cs.code)
|
||||
except SyntaxError:
|
||||
continue
|
||||
|
||||
# Collect type names from field annotations of extracted classes
|
||||
type_names: set[str] = set()
|
||||
for node in ast.walk(snippet_tree):
|
||||
if isinstance(node, ast.ClassDef):
|
||||
for item in node.body:
|
||||
if isinstance(item, ast.AnnAssign) and item.annotation:
|
||||
type_names |= collect_type_names_from_annotation(item.annotation)
|
||||
|
||||
type_names -= BUILTIN_AND_TYPING_NAMES
|
||||
type_names -= existing_class_names
|
||||
type_names -= emitted
|
||||
if not type_names:
|
||||
continue
|
||||
|
||||
# Build import map from the source file, not the snippet
|
||||
source_path = cs.file_path
|
||||
if not source_path:
|
||||
continue
|
||||
import_map: dict[str, str] = {}
|
||||
if source_path in module_cache:
|
||||
source_code, source_tree = module_cache[source_path]
|
||||
else:
|
||||
try:
|
||||
source_code = source_path.read_text(encoding="utf-8")
|
||||
source_tree = ast.parse(source_code)
|
||||
module_cache[source_path] = (source_code, source_tree)
|
||||
except Exception:
|
||||
continue
|
||||
for snode in ast.walk(source_tree):
|
||||
if isinstance(snode, ast.ImportFrom) and snode.module:
|
||||
for alias in snode.names:
|
||||
name = alias.asname if alias.asname else alias.name
|
||||
import_map[name] = snode.module
|
||||
|
||||
for type_name in sorted(type_names):
|
||||
module_name = import_map.get(type_name)
|
||||
if not module_name:
|
||||
continue
|
||||
try:
|
||||
script_code = f"from {module_name} import {type_name}"
|
||||
script = jedi.Script(script_code, project=jedi.Project(path=project_root_path))
|
||||
definitions = script.goto(1, len(f"from {module_name} import ") + len(type_name), follow_imports=True)
|
||||
if not definitions or not definitions[0].module_path:
|
||||
continue
|
||||
|
||||
mod_path = definitions[0].module_path
|
||||
resolved_str = str(mod_path.resolve())
|
||||
is_project = resolved_str.startswith(str(project_root_path.resolve()))
|
||||
is_third_party = "site-packages" in resolved_str
|
||||
if not is_project and not is_third_party:
|
||||
continue
|
||||
|
||||
if mod_path in module_cache:
|
||||
mod_source, mod_tree = module_cache[mod_path]
|
||||
else:
|
||||
mod_source = mod_path.read_text(encoding="utf-8")
|
||||
mod_tree = ast.parse(mod_source)
|
||||
module_cache[mod_path] = (mod_source, mod_tree)
|
||||
|
||||
class_source = extract_full_class_from_module(type_name, mod_source, mod_tree)
|
||||
if class_source is None:
|
||||
resolved_class = resolve_instance_class_name(type_name, mod_tree)
|
||||
if resolved_class and resolved_class not in existing_class_names and resolved_class not in emitted:
|
||||
class_source = extract_full_class_from_module(resolved_class, mod_source, mod_tree)
|
||||
type_name = resolved_class
|
||||
if class_source:
|
||||
code_strings.append(CodeString(code=class_source, file_path=mod_path))
|
||||
emitted.add(type_name)
|
||||
except Exception:
|
||||
logger.debug(f"Error extracting type context class for {type_name} from {module_name}")
|
||||
continue
|
||||
|
||||
return CodeStringsMarkdown(code_strings=code_strings)
|
||||
|
||||
|
||||
def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode:
|
||||
"""Removes the docstring from an indented block if it exists."""
|
||||
if not isinstance(indented_block.body[0], cst.SimpleStatementLine):
|
||||
|
|
|
|||
|
|
@ -13,13 +13,10 @@ from codeflash.code_utils.code_extractor import GlobalAssignmentCollector, add_g
|
|||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.python.context.code_context_extractor import (
|
||||
collect_names_from_function_body,
|
||||
collect_type_names_from_annotation,
|
||||
enrich_testgen_context,
|
||||
enrich_type_context_classes,
|
||||
extract_full_class_from_module,
|
||||
extract_init_stub_from_class,
|
||||
extract_type_context_for_testgen,
|
||||
extract_parameter_type_constructors,
|
||||
get_code_optimization_context,
|
||||
resolve_instance_class_name,
|
||||
)
|
||||
|
|
@ -3556,7 +3553,8 @@ def test_enrich_testgen_context_deduplicates(tmp_path: Path) -> None:
|
|||
package_dir.mkdir()
|
||||
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||
(package_dir / "base.py").write_text(
|
||||
"class Base:\n def __init__(self, x: int):\n self.x = x\n", encoding="utf-8"
|
||||
"class Base:\n def __init__(self, x: int):\n self.x = x\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
code = "from mypkg.base import Base\n\nclass A(Base):\n pass\n\nclass B(Base):\n pass\n"
|
||||
|
|
@ -3699,7 +3697,8 @@ def test_testgen_context_includes_external_base_inits(tmp_path: Path) -> None:
|
|||
package_dir.mkdir()
|
||||
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||
(package_dir / "base.py").write_text(
|
||||
"class BaseDict:\n def __init__(self, data=None):\n self.data = data or {}\n", encoding="utf-8"
|
||||
"class BaseDict:\n def __init__(self, data=None):\n self.data = data or {}\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
code = "from mypkg.base import BaseDict\n\nclass MyCustomDict(BaseDict):\n def target_method(self):\n return self.data\n"
|
||||
|
|
@ -3753,7 +3752,8 @@ def test_enrich_testgen_context_attribute_base(tmp_path: Path) -> None:
|
|||
package_dir.mkdir()
|
||||
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||
(package_dir / "base.py").write_text(
|
||||
"class CustomDict:\n def __init__(self, data=None):\n self.data = data or {}\n", encoding="utf-8"
|
||||
"class CustomDict:\n def __init__(self, data=None):\n self.data = data or {}\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
code = "from mypkg.base import CustomDict\n\nclass MyDict(CustomDict):\n def custom_method(self):\n return self.data\n"
|
||||
|
|
@ -4208,10 +4208,10 @@ class Other:
|
|||
assert stub is None
|
||||
|
||||
|
||||
# --- Tests for extract_type_context_for_testgen ---
|
||||
# --- Tests for extract_parameter_type_constructors ---
|
||||
|
||||
|
||||
def test_extract_type_context_for_testgen_project_type(tmp_path: Path) -> None:
|
||||
def test_extract_parameter_type_constructors_project_type(tmp_path: Path) -> None:
|
||||
# Create a module with a class
|
||||
pkg = tmp_path / "mypkg"
|
||||
pkg.mkdir()
|
||||
|
|
@ -4239,7 +4239,7 @@ def process(w: Widget) -> str:
|
|||
fto = FunctionToOptimize(
|
||||
function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
|
||||
)
|
||||
result = extract_type_context_for_testgen(fto, tmp_path.resolve(), set())
|
||||
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
|
||||
assert len(result.code_strings) == 1
|
||||
code = result.code_strings[0].code
|
||||
assert "class Widget:" in code
|
||||
|
|
@ -4247,7 +4247,7 @@ def process(w: Widget) -> str:
|
|||
assert "size" in code
|
||||
|
||||
|
||||
def test_extract_type_context_for_testgen_excludes_builtins(tmp_path: Path) -> None:
|
||||
def test_extract_parameter_type_constructors_excludes_builtins(tmp_path: Path) -> None:
|
||||
pkg = tmp_path / "mypkg"
|
||||
pkg.mkdir()
|
||||
(pkg / "__init__.py").write_text("", encoding="utf-8")
|
||||
|
|
@ -4262,11 +4262,11 @@ def my_func(x: int, y: str, z: list) -> None:
|
|||
fto = FunctionToOptimize(
|
||||
function_name="my_func", file_path=(pkg / "func.py").resolve(), starting_line=2, ending_line=3
|
||||
)
|
||||
result = extract_type_context_for_testgen(fto, tmp_path.resolve(), set())
|
||||
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
|
||||
assert len(result.code_strings) == 0
|
||||
|
||||
|
||||
def test_extract_type_context_for_testgen_skips_existing_classes(tmp_path: Path) -> None:
|
||||
def test_extract_parameter_type_constructors_skips_existing_classes(tmp_path: Path) -> None:
|
||||
pkg = tmp_path / "mypkg"
|
||||
pkg.mkdir()
|
||||
(pkg / "__init__.py").write_text("", encoding="utf-8")
|
||||
|
|
@ -4291,11 +4291,11 @@ def process(w: Widget) -> str:
|
|||
function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
|
||||
)
|
||||
# Widget is already in the context — should not be duplicated
|
||||
result = extract_type_context_for_testgen(fto, tmp_path.resolve(), {"Widget"})
|
||||
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), {"Widget"})
|
||||
assert len(result.code_strings) == 0
|
||||
|
||||
|
||||
def test_extract_type_context_for_testgen_no_init(tmp_path: Path) -> None:
|
||||
def test_extract_parameter_type_constructors_no_init(tmp_path: Path) -> None:
|
||||
pkg = tmp_path / "mypkg"
|
||||
pkg.mkdir()
|
||||
(pkg / "__init__.py").write_text("", encoding="utf-8")
|
||||
|
|
@ -4318,11 +4318,8 @@ def process(c: Config) -> str:
|
|||
fto = FunctionToOptimize(
|
||||
function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
|
||||
)
|
||||
result = extract_type_context_for_testgen(fto, tmp_path.resolve(), set())
|
||||
assert len(result.code_strings) == 1
|
||||
code = result.code_strings[0].code
|
||||
assert "class Config:" in code
|
||||
assert "x = 10" in code
|
||||
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
|
||||
assert len(result.code_strings) == 0
|
||||
|
||||
|
||||
# --- Tests for resolve_instance_class_name ---
|
||||
|
|
@ -4447,394 +4444,3 @@ def get_log_level() -> str:
|
|||
combined = "\n".join(cs.code for cs in result.code_strings)
|
||||
assert "class AppConfig:" in combined
|
||||
assert "@property" in combined
|
||||
|
||||
|
||||
# --- Tests for collect_names_from_function_body ---
|
||||
|
||||
|
||||
def test_collect_names_from_function_body_constructor_calls() -> None:
|
||||
source = """\
|
||||
def f():
|
||||
x = MyClass(1, 2)
|
||||
y = OtherClass()
|
||||
"""
|
||||
tree = ast.parse(source)
|
||||
func_node = tree.body[0]
|
||||
names = collect_names_from_function_body(func_node)
|
||||
assert "MyClass" in names
|
||||
assert "OtherClass" in names
|
||||
|
||||
|
||||
def test_collect_names_from_function_body_attribute_access() -> None:
|
||||
source = """\
|
||||
def f():
|
||||
val = config.PROP
|
||||
x = MyEnum.VALUE
|
||||
"""
|
||||
tree = ast.parse(source)
|
||||
func_node = tree.body[0]
|
||||
names = collect_names_from_function_body(func_node)
|
||||
assert "config" in names
|
||||
assert "MyEnum" in names
|
||||
|
||||
|
||||
def test_collect_names_from_function_body_isinstance() -> None:
|
||||
source = """\
|
||||
def f(x):
|
||||
if isinstance(x, SomeClass):
|
||||
pass
|
||||
if issubclass(type(x), (A, B)):
|
||||
pass
|
||||
"""
|
||||
tree = ast.parse(source)
|
||||
func_node = tree.body[0]
|
||||
names = collect_names_from_function_body(func_node)
|
||||
assert "SomeClass" in names
|
||||
assert "A" in names
|
||||
assert "B" in names
|
||||
|
||||
|
||||
def test_collect_names_from_function_body_class_method_call() -> None:
|
||||
source = """\
|
||||
def f():
|
||||
obj = Builder.create(42)
|
||||
"""
|
||||
tree = ast.parse(source)
|
||||
func_node = tree.body[0]
|
||||
names = collect_names_from_function_body(func_node)
|
||||
assert "Builder" in names
|
||||
|
||||
|
||||
# --- Tests for extract_full_class_from_module ---
|
||||
|
||||
|
||||
def test_extract_full_class_from_module_basic() -> None:
|
||||
source = """\
|
||||
class Foo:
|
||||
x = 10
|
||||
def bar(self):
|
||||
return self.x
|
||||
"""
|
||||
tree = ast.parse(source)
|
||||
result = extract_full_class_from_module("Foo", source, tree)
|
||||
assert result is not None
|
||||
assert "class Foo:" in result
|
||||
assert "x = 10" in result
|
||||
assert "def bar(self):" in result
|
||||
|
||||
|
||||
def test_extract_full_class_from_module_missing_class() -> None:
|
||||
source = """\
|
||||
class Foo:
|
||||
pass
|
||||
"""
|
||||
tree = ast.parse(source)
|
||||
result = extract_full_class_from_module("Bar", source, tree)
|
||||
assert result is None
|
||||
|
||||
|
||||
# --- Integration tests for extract_type_context_for_testgen with body references ---
|
||||
|
||||
|
||||
def test_extract_type_context_for_testgen_body_enum_reference(tmp_path: Path) -> None:
|
||||
pkg = tmp_path / "mypkg"
|
||||
pkg.mkdir()
|
||||
(pkg / "__init__.py").write_text("", encoding="utf-8")
|
||||
(pkg / "enums.py").write_text(
|
||||
"""\
|
||||
import enum
|
||||
|
||||
class Color(enum.Enum):
|
||||
RED = "red"
|
||||
BLUE = "blue"
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
(pkg / "processor.py").write_text(
|
||||
"""\
|
||||
from mypkg.enums import Color
|
||||
|
||||
def paint(name: str) -> str:
|
||||
return name + Color.RED.value
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
fto = FunctionToOptimize(
|
||||
function_name="paint", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
|
||||
)
|
||||
result = extract_type_context_for_testgen(fto, tmp_path.resolve(), set())
|
||||
assert len(result.code_strings) == 1
|
||||
code = result.code_strings[0].code
|
||||
assert "class Color" in code
|
||||
assert "RED" in code
|
||||
|
||||
|
||||
def test_extract_type_context_for_testgen_body_config_class(tmp_path: Path) -> None:
|
||||
pkg = tmp_path / "mypkg"
|
||||
pkg.mkdir()
|
||||
(pkg / "__init__.py").write_text("", encoding="utf-8")
|
||||
(pkg / "config.py").write_text(
|
||||
"""\
|
||||
class AppConfig:
|
||||
DEBUG = True
|
||||
LOG_LEVEL = "INFO"
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
(pkg / "service.py").write_text(
|
||||
"""\
|
||||
from mypkg.config import AppConfig
|
||||
|
||||
def run() -> str:
|
||||
return AppConfig.LOG_LEVEL
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
fto = FunctionToOptimize(
|
||||
function_name="run", file_path=(pkg / "service.py").resolve(), starting_line=3, ending_line=4
|
||||
)
|
||||
result = extract_type_context_for_testgen(fto, tmp_path.resolve(), set())
|
||||
assert len(result.code_strings) == 1
|
||||
code = result.code_strings[0].code
|
||||
assert "class AppConfig:" in code
|
||||
assert "LOG_LEVEL" in code
|
||||
|
||||
|
||||
def test_extract_type_context_for_testgen_instance_resolution(tmp_path: Path) -> None:
|
||||
pkg = tmp_path / "mypkg"
|
||||
pkg.mkdir()
|
||||
(pkg / "__init__.py").write_text("", encoding="utf-8")
|
||||
(pkg / "config.py").write_text(
|
||||
"""\
|
||||
class AppConfig:
|
||||
DEBUG = True
|
||||
LOG_LEVEL = "INFO"
|
||||
|
||||
app_config = AppConfig()
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
(pkg / "service.py").write_text(
|
||||
"""\
|
||||
from mypkg.config import app_config
|
||||
|
||||
def run() -> str:
|
||||
return app_config.LOG_LEVEL
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
fto = FunctionToOptimize(
|
||||
function_name="run", file_path=(pkg / "service.py").resolve(), starting_line=3, ending_line=4
|
||||
)
|
||||
result = extract_type_context_for_testgen(fto, tmp_path.resolve(), set())
|
||||
assert len(result.code_strings) == 1
|
||||
code = result.code_strings[0].code
|
||||
assert "class AppConfig:" in code
|
||||
|
||||
|
||||
def test_extract_type_context_for_testgen_non_imported_names_filtered(tmp_path: Path) -> None:
|
||||
pkg = tmp_path / "mypkg"
|
||||
pkg.mkdir()
|
||||
(pkg / "__init__.py").write_text("", encoding="utf-8")
|
||||
(pkg / "processor.py").write_text(
|
||||
"""\
|
||||
def process(x: int) -> int:
|
||||
result = SomeLocal(x)
|
||||
return result
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
fto = FunctionToOptimize(
|
||||
function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=1, ending_line=3
|
||||
)
|
||||
result = extract_type_context_for_testgen(fto, tmp_path.resolve(), set())
|
||||
assert len(result.code_strings) == 0
|
||||
|
||||
|
||||
def test_enrich_testgen_context_cross_module_base_class(tmp_path: Path) -> None:
|
||||
"""Base classes from other modules are resolved and extracted."""
|
||||
pkg = tmp_path / "mypkg"
|
||||
pkg.mkdir()
|
||||
(pkg / "__init__.py").write_text("", encoding="utf-8")
|
||||
|
||||
(pkg / "base.py").write_text(
|
||||
"""\
|
||||
class Base:
|
||||
def __init__(self, x: int):
|
||||
self.x = x
|
||||
|
||||
def do_stuff(self) -> int:
|
||||
return self.x * 2
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
(pkg / "child.py").write_text(
|
||||
"""\
|
||||
from mypkg.base import Base
|
||||
|
||||
class Child(Base):
|
||||
def __init__(self, x: int, y: int):
|
||||
super().__init__(x)
|
||||
self.y = y
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
context = CodeStringsMarkdown(
|
||||
code_strings=[
|
||||
CodeString(
|
||||
code="""\
|
||||
from mypkg.child import Child
|
||||
|
||||
def process(c: Child) -> int:
|
||||
return c.do_stuff()
|
||||
""",
|
||||
file_path=pkg / "user.py",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
all_code = "\n".join(cs.code for cs in result.code_strings)
|
||||
assert "class Child(Base)" in all_code
|
||||
assert "class Base:" in all_code
|
||||
|
||||
|
||||
def test_enrich_testgen_context_cross_module_base_depth_limit(tmp_path: Path) -> None:
|
||||
"""Deep inheritance chains across modules stop at the depth limit without crashing."""
|
||||
pkg = tmp_path / "mypkg"
|
||||
pkg.mkdir()
|
||||
(pkg / "__init__.py").write_text("", encoding="utf-8")
|
||||
|
||||
# Create a chain of 6 modules: Level0 -> Level1 -> ... -> Level5
|
||||
for i in range(6):
|
||||
if i == 0:
|
||||
code = f"class Level{i}:\n pass\n"
|
||||
else:
|
||||
code = f"from mypkg.level{i - 1} import Level{i - 1}\n\nclass Level{i}(Level{i - 1}):\n pass\n"
|
||||
(pkg / f"level{i}.py").write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(
|
||||
code_strings=[
|
||||
CodeString(
|
||||
code="from mypkg.level5 import Level5\n\ndef process(obj: Level5) -> None:\n pass\n",
|
||||
file_path=pkg / "user.py",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Should not crash and should extract Level5 plus some ancestors (limited by depth=3)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
all_code = "\n".join(cs.code for cs in result.code_strings)
|
||||
assert "class Level5" in all_code
|
||||
# Depth limit prevents extracting the full chain — Level0 through Level2 may not appear
|
||||
# but Level3+ should be present
|
||||
assert "class Level3" in all_code or "class Level4" in all_code
|
||||
|
||||
|
||||
def test_enrich_type_context_classes_extracts_field_types(tmp_path: Path) -> None:
|
||||
"""Classes referenced as field types in type-context classes are extracted."""
|
||||
pkg = tmp_path / "mypkg"
|
||||
pkg.mkdir()
|
||||
(pkg / "__init__.py").write_text("", encoding="utf-8")
|
||||
|
||||
(pkg / "source.py").write_text(
|
||||
"""\
|
||||
class Source:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
(pkg / "regions.py").write_text(
|
||||
"""\
|
||||
from mypkg.source import Source
|
||||
|
||||
class TextRegions:
|
||||
sources: list[Source]
|
||||
label: str
|
||||
|
||||
def __init__(self, sources: list[Source], label: str):
|
||||
self.sources = sources
|
||||
self.label = label
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
type_context = CodeStringsMarkdown(
|
||||
code_strings=[
|
||||
CodeString(
|
||||
code="""\
|
||||
class TextRegions:
|
||||
sources: list[Source]
|
||||
label: str
|
||||
|
||||
def __init__(self, sources: list[Source], label: str):
|
||||
self.sources = sources
|
||||
self.label = label
|
||||
""",
|
||||
file_path=pkg / "regions.py",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = enrich_type_context_classes(type_context, {"TextRegions"}, tmp_path)
|
||||
assert len(result.code_strings) == 1
|
||||
assert "class Source:" in result.code_strings[0].code
|
||||
|
||||
|
||||
def test_enrich_type_context_classes_skips_existing_and_builtins(tmp_path: Path) -> None:
|
||||
"""Builtins and already-existing class names are not re-extracted."""
|
||||
pkg = tmp_path / "mypkg"
|
||||
pkg.mkdir()
|
||||
(pkg / "__init__.py").write_text("", encoding="utf-8")
|
||||
|
||||
(pkg / "models.py").write_text(
|
||||
"""\
|
||||
from mypkg.other import Other
|
||||
|
||||
class MyClass:
|
||||
x: int
|
||||
y: str
|
||||
z: Other
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
(pkg / "other.py").write_text(
|
||||
"""\
|
||||
class Other:
|
||||
def __init__(self):
|
||||
pass
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
type_context = CodeStringsMarkdown(
|
||||
code_strings=[
|
||||
CodeString(
|
||||
code="""\
|
||||
class MyClass:
|
||||
x: int
|
||||
y: str
|
||||
z: Other
|
||||
""",
|
||||
file_path=pkg / "models.py",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Other is already existing — should not be extracted
|
||||
result = enrich_type_context_classes(type_context, {"MyClass", "Other"}, tmp_path)
|
||||
assert len(result.code_strings) == 0
|
||||
|
||||
# Without Other in existing_class_names, it should be extracted
|
||||
result = enrich_type_context_classes(type_context, {"MyClass"}, tmp_path)
|
||||
assert len(result.code_strings) == 1
|
||||
assert "class Other:" in result.code_strings[0].code
|
||||
|
|
|
|||
Loading…
Reference in a new issue