feat: extract parameter type constructor signatures into testgen context

Add enrichment step that parses FTO parameter type annotations, resolves
types via jedi (following re-exports), and extracts full __init__ source
to give the LLM constructor context for typed parameters.
This commit is contained in:
Kevin Turcios 2026-02-18 07:51:53 -05:00
parent 68c148c876
commit 2367b4c02c
2 changed files with 432 additions and 3 deletions

View file

@ -52,6 +52,7 @@ def build_testgen_context(
*,
remove_docstrings: bool = False,
include_enrichment: bool = True,
function_to_optimize: FunctionToOptimize | None = None,
) -> CodeStringsMarkdown:
testgen_context = extract_code_markdown_context_from_files(
helpers_of_fto_dict,
@ -66,6 +67,17 @@ def build_testgen_context(
if enrichment.code_strings:
testgen_context = CodeStringsMarkdown(code_strings=testgen_context.code_strings + enrichment.code_strings)
if function_to_optimize is not None:
result = _parse_and_collect_imports(testgen_context)
existing_classes = collect_existing_class_names(result[0]) if result else set()
constructor_stubs = extract_parameter_type_constructors(
function_to_optimize, project_root_path, existing_classes
)
if constructor_stubs.code_strings:
testgen_context = CodeStringsMarkdown(
code_strings=testgen_context.code_strings + constructor_stubs.code_strings
)
return testgen_context
@ -156,12 +168,18 @@ def get_code_optimization_context(
read_only_context_code = ""
# Progressive fallback for testgen context token limits
testgen_context = build_testgen_context(helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path)
testgen_context = build_testgen_context(
helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, function_to_optimize=function_to_optimize
)
if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit:
logger.debug("Testgen context exceeded token limit, removing docstrings")
testgen_context = build_testgen_context(
helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=True
helpers_of_fto_dict,
helpers_of_helpers_dict,
project_root_path,
remove_docstrings=True,
function_to_optimize=function_to_optimize,
)
if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit:
@ -627,6 +645,205 @@ def collect_existing_class_names(tree: ast.Module) -> set[str]:
return class_names
BUILTIN_AND_TYPING_NAMES = frozenset(
{
"int",
"str",
"float",
"bool",
"bytes",
"bytearray",
"complex",
"list",
"dict",
"set",
"frozenset",
"tuple",
"type",
"object",
"None",
"NoneType",
"Ellipsis",
"NotImplemented",
"memoryview",
"range",
"slice",
"property",
"classmethod",
"staticmethod",
"super",
"Optional",
"Union",
"Any",
"List",
"Dict",
"Set",
"FrozenSet",
"Tuple",
"Type",
"Callable",
"Iterator",
"Generator",
"Coroutine",
"AsyncGenerator",
"AsyncIterator",
"Iterable",
"AsyncIterable",
"Sequence",
"MutableSequence",
"Mapping",
"MutableMapping",
"Collection",
"Awaitable",
"Literal",
"Final",
"ClassVar",
"TypeVar",
"TypeAlias",
"ParamSpec",
"Concatenate",
"Annotated",
"TypeGuard",
"Self",
"Unpack",
"TypeVarTuple",
"Never",
"NoReturn",
"SupportsInt",
"SupportsFloat",
"SupportsComplex",
"SupportsBytes",
"SupportsAbs",
"SupportsRound",
"IO",
"TextIO",
"BinaryIO",
"Pattern",
"Match",
}
)
def collect_type_names_from_annotation(node: ast.expr | None) -> set[str]:
if node is None:
return set()
if isinstance(node, ast.Name):
return {node.id}
if isinstance(node, ast.Subscript):
names = collect_type_names_from_annotation(node.value)
names |= collect_type_names_from_annotation(node.slice)
return names
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
return collect_type_names_from_annotation(node.left) | collect_type_names_from_annotation(node.right)
if isinstance(node, ast.Tuple):
names: set[str] = set()
for elt in node.elts:
names |= collect_type_names_from_annotation(elt)
return names
return set()
def extract_init_stub_from_class(class_name: str, module_source: str, module_tree: ast.Module) -> str | None:
class_node = None
for node in ast.walk(module_tree):
if isinstance(node, ast.ClassDef) and node.name == class_name:
class_node = node
break
if class_node is None:
return None
init_node = None
for item in class_node.body:
if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)) and item.name == "__init__":
init_node = item
break
if init_node is None:
return None
lines = module_source.splitlines()
init_source = "\n".join(lines[init_node.lineno - 1 : init_node.end_lineno])
return f"class {class_name}:\n{init_source}"
def extract_parameter_type_constructors(
function_to_optimize: FunctionToOptimize, project_root_path: Path, existing_class_names: set[str]
) -> CodeStringsMarkdown:
import jedi
try:
source = function_to_optimize.file_path.read_text(encoding="utf-8")
tree = ast.parse(source)
except Exception:
return CodeStringsMarkdown(code_strings=[])
func_node = None
for node in ast.walk(tree):
if (
isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
and node.name == function_to_optimize.function_name
):
if function_to_optimize.starting_line is not None and node.lineno != function_to_optimize.starting_line:
continue
func_node = node
break
if func_node is None:
return CodeStringsMarkdown(code_strings=[])
type_names: set[str] = set()
for arg in func_node.args.args + func_node.args.posonlyargs + func_node.args.kwonlyargs:
type_names |= collect_type_names_from_annotation(arg.annotation)
if func_node.args.vararg:
type_names |= collect_type_names_from_annotation(func_node.args.vararg.annotation)
if func_node.args.kwarg:
type_names |= collect_type_names_from_annotation(func_node.args.kwarg.annotation)
type_names -= BUILTIN_AND_TYPING_NAMES
type_names -= existing_class_names
if not type_names:
return CodeStringsMarkdown(code_strings=[])
import_map: dict[str, str] = {}
for node in ast.walk(tree):
if isinstance(node, ast.ImportFrom) and node.module:
for alias in node.names:
name = alias.asname if alias.asname else alias.name
import_map[name] = node.module
code_strings: list[CodeString] = []
module_cache: dict[Path, tuple[str, ast.Module]] = {}
for type_name in sorted(type_names):
module_name = import_map.get(type_name)
if not module_name:
continue
try:
script_code = f"from {module_name} import {type_name}"
script = jedi.Script(script_code, project=jedi.Project(path=project_root_path))
definitions = script.goto(1, len(f"from {module_name} import ") + len(type_name), follow_imports=True)
if not definitions:
continue
module_path = definitions[0].module_path
if not module_path:
continue
if module_path in module_cache:
mod_source, mod_tree = module_cache[module_path]
else:
mod_source = module_path.read_text(encoding="utf-8")
mod_tree = ast.parse(mod_source)
module_cache[module_path] = (mod_source, mod_tree)
stub = extract_init_stub_from_class(type_name, mod_source, mod_tree)
if stub:
code_strings.append(CodeString(code=stub, file_path=module_path))
except Exception:
logger.debug(f"Error extracting constructor stub for {type_name} from {module_name}")
continue
return CodeStringsMarkdown(code_strings=code_strings)
def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown:
import jedi
@ -852,7 +1069,12 @@ def prune_cst(
return node, False
# Handle dunder methods for READ_ONLY/TESTGEN modes
if include_dunder_methods and len(node.name.value) > 4 and node.name.value.startswith("__") and node.name.value.endswith("__"):
if (
include_dunder_methods
and len(node.name.value) > 4
and node.name.value.startswith("__")
and node.name.value.endswith("__")
):
if not include_init_dunder and node.name.value == "__init__":
return None, False
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):

View file

@ -1,5 +1,6 @@
from __future__ import annotations
import ast
import sys
import tempfile
from argparse import Namespace
@ -12,7 +13,10 @@ from codeflash.code_utils.code_extractor import GlobalAssignmentCollector, add_g
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.python.context.code_context_extractor import (
collect_type_names_from_annotation,
enrich_testgen_context,
extract_init_stub_from_class,
extract_parameter_type_constructors,
get_code_optimization_context,
)
from codeflash.models.models import CodeString, CodeStringsMarkdown, FunctionParent
@ -4132,3 +4136,206 @@ def my_func(ctx: Context) -> None:
class_names = [cs.code.split("\n")[0].replace("class ", "").rstrip(":") for cs in result.code_strings]
assert len(class_names) == len(set(class_names)), f"Duplicate class stubs found: {class_names}"
# --- Tests for collect_type_names_from_annotation ---
def test_collect_type_names_simple() -> None:
tree = ast.parse("def f(x: Foo): pass")
func = tree.body[0]
ann = func.args.args[0].annotation
assert collect_type_names_from_annotation(ann) == {"Foo"}
def test_collect_type_names_generic() -> None:
tree = ast.parse("def f(x: list[Foo]): pass")
func = tree.body[0]
ann = func.args.args[0].annotation
names = collect_type_names_from_annotation(ann)
assert "Foo" in names
assert "list" in names
def test_collect_type_names_optional() -> None:
tree = ast.parse("def f(x: Optional[Foo]): pass")
func = tree.body[0]
ann = func.args.args[0].annotation
names = collect_type_names_from_annotation(ann)
assert "Optional" in names
assert "Foo" in names
def test_collect_type_names_union_pipe() -> None:
tree = ast.parse("def f(x: Foo | Bar): pass")
func = tree.body[0]
ann = func.args.args[0].annotation
names = collect_type_names_from_annotation(ann)
assert names == {"Foo", "Bar"}
def test_collect_type_names_none_annotation() -> None:
assert collect_type_names_from_annotation(None) == set()
def test_collect_type_names_attribute_skipped() -> None:
tree = ast.parse("def f(x: module.Foo): pass")
func = tree.body[0]
ann = func.args.args[0].annotation
assert collect_type_names_from_annotation(ann) == set()
# --- Tests for extract_init_stub_from_class ---
def test_extract_init_stub_basic() -> None:
source = """
class MyClass:
def __init__(self, name: str, value: int = 0):
self.name = name
self.value = value
"""
tree = ast.parse(source)
stub = extract_init_stub_from_class("MyClass", source, tree)
assert stub is not None
assert "class MyClass:" in stub
assert "def __init__(self, name: str, value: int = 0):" in stub
assert "self.name = name" in stub
assert "self.value = value" in stub
def test_extract_init_stub_no_init() -> None:
source = """
class NoInit:
x = 10
def other(self):
pass
"""
tree = ast.parse(source)
stub = extract_init_stub_from_class("NoInit", source, tree)
assert stub is None
def test_extract_init_stub_class_not_found() -> None:
source = """
class Other:
def __init__(self):
pass
"""
tree = ast.parse(source)
stub = extract_init_stub_from_class("Missing", source, tree)
assert stub is None
# --- Tests for extract_parameter_type_constructors ---
def test_extract_parameter_type_constructors_project_type(tmp_path: Path) -> None:
# Create a module with a class
pkg = tmp_path / "mypkg"
pkg.mkdir()
(pkg / "__init__.py").write_text("", encoding="utf-8")
(pkg / "models.py").write_text(
"""
class Widget:
def __init__(self, size: int, color: str = "red"):
self.size = size
self.color = color
""",
encoding="utf-8",
)
# Create the FTO file that uses Widget
(pkg / "processor.py").write_text(
"""from mypkg.models import Widget
def process(w: Widget) -> str:
return str(w)
""",
encoding="utf-8",
)
fto = FunctionToOptimize(
function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
)
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
assert len(result.code_strings) == 1
code = result.code_strings[0].code
assert "class Widget:" in code
assert "def __init__" in code
assert "size" in code
def test_extract_parameter_type_constructors_excludes_builtins(tmp_path: Path) -> None:
pkg = tmp_path / "mypkg"
pkg.mkdir()
(pkg / "__init__.py").write_text("", encoding="utf-8")
(pkg / "func.py").write_text(
"""
def my_func(x: int, y: str, z: list) -> None:
pass
""",
encoding="utf-8",
)
fto = FunctionToOptimize(
function_name="my_func", file_path=(pkg / "func.py").resolve(), starting_line=2, ending_line=3
)
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
assert len(result.code_strings) == 0
def test_extract_parameter_type_constructors_skips_existing_classes(tmp_path: Path) -> None:
pkg = tmp_path / "mypkg"
pkg.mkdir()
(pkg / "__init__.py").write_text("", encoding="utf-8")
(pkg / "models.py").write_text(
"""
class Widget:
def __init__(self, size: int):
self.size = size
""",
encoding="utf-8",
)
(pkg / "processor.py").write_text(
"""from mypkg.models import Widget
def process(w: Widget) -> str:
return str(w)
""",
encoding="utf-8",
)
fto = FunctionToOptimize(
function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
)
# Widget is already in the context — should not be duplicated
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), {"Widget"})
assert len(result.code_strings) == 0
def test_extract_parameter_type_constructors_no_init(tmp_path: Path) -> None:
pkg = tmp_path / "mypkg"
pkg.mkdir()
(pkg / "__init__.py").write_text("", encoding="utf-8")
(pkg / "models.py").write_text(
"""
class Config:
x = 10
""",
encoding="utf-8",
)
(pkg / "processor.py").write_text(
"""from mypkg.models import Config
def process(c: Config) -> str:
return str(c)
""",
encoding="utf-8",
)
fto = FunctionToOptimize(
function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
)
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
assert len(result.code_strings) == 0