mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
feat: extract parameter type constructor signatures into testgen context
Add enrichment step that parses FTO parameter type annotations, resolves types via jedi (following re-exports), and extracts full __init__ source to give the LLM constructor context for typed parameters.
This commit is contained in:
parent
68c148c876
commit
2367b4c02c
2 changed files with 432 additions and 3 deletions
|
|
@ -52,6 +52,7 @@ def build_testgen_context(
|
|||
*,
|
||||
remove_docstrings: bool = False,
|
||||
include_enrichment: bool = True,
|
||||
function_to_optimize: FunctionToOptimize | None = None,
|
||||
) -> CodeStringsMarkdown:
|
||||
testgen_context = extract_code_markdown_context_from_files(
|
||||
helpers_of_fto_dict,
|
||||
|
|
@ -66,6 +67,17 @@ def build_testgen_context(
|
|||
if enrichment.code_strings:
|
||||
testgen_context = CodeStringsMarkdown(code_strings=testgen_context.code_strings + enrichment.code_strings)
|
||||
|
||||
if function_to_optimize is not None:
|
||||
result = _parse_and_collect_imports(testgen_context)
|
||||
existing_classes = collect_existing_class_names(result[0]) if result else set()
|
||||
constructor_stubs = extract_parameter_type_constructors(
|
||||
function_to_optimize, project_root_path, existing_classes
|
||||
)
|
||||
if constructor_stubs.code_strings:
|
||||
testgen_context = CodeStringsMarkdown(
|
||||
code_strings=testgen_context.code_strings + constructor_stubs.code_strings
|
||||
)
|
||||
|
||||
return testgen_context
|
||||
|
||||
|
||||
|
|
@ -156,12 +168,18 @@ def get_code_optimization_context(
|
|||
read_only_context_code = ""
|
||||
|
||||
# Progressive fallback for testgen context token limits
|
||||
testgen_context = build_testgen_context(helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path)
|
||||
testgen_context = build_testgen_context(
|
||||
helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, function_to_optimize=function_to_optimize
|
||||
)
|
||||
|
||||
if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit:
|
||||
logger.debug("Testgen context exceeded token limit, removing docstrings")
|
||||
testgen_context = build_testgen_context(
|
||||
helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=True
|
||||
helpers_of_fto_dict,
|
||||
helpers_of_helpers_dict,
|
||||
project_root_path,
|
||||
remove_docstrings=True,
|
||||
function_to_optimize=function_to_optimize,
|
||||
)
|
||||
|
||||
if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit:
|
||||
|
|
@ -627,6 +645,205 @@ def collect_existing_class_names(tree: ast.Module) -> set[str]:
|
|||
return class_names
|
||||
|
||||
|
||||
# Identifiers that can appear in a parameter annotation but are builtins or
# typing constructs rather than project classes. They are removed from the
# collected annotation names before any constructor lookup is attempted.
BUILTIN_AND_TYPING_NAMES = frozenset(
    (
        # Builtin types, singletons and descriptors
        "int", "str", "float", "bool", "bytes", "bytearray", "complex",
        "list", "dict", "set", "frozenset", "tuple", "type", "object",
        "None", "NoneType", "Ellipsis", "NotImplemented",
        "memoryview", "range", "slice",
        "property", "classmethod", "staticmethod", "super",
        # typing: basic special forms and legacy generic aliases
        "Optional", "Union", "Any",
        "List", "Dict", "Set", "FrozenSet", "Tuple", "Type",
        # typing: callables, iteration and async protocols
        "Callable", "Iterator", "Generator", "Coroutine",
        "AsyncGenerator", "AsyncIterator", "Iterable", "AsyncIterable",
        # typing: abstract containers
        "Sequence", "MutableSequence", "Mapping", "MutableMapping",
        "Collection", "Awaitable",
        # typing: qualifiers and type-system machinery
        "Literal", "Final", "ClassVar", "TypeVar", "TypeAlias",
        "ParamSpec", "Concatenate", "Annotated", "TypeGuard", "Self",
        "Unpack", "TypeVarTuple", "Never", "NoReturn",
        # typing: Supports* protocols
        "SupportsInt", "SupportsFloat", "SupportsComplex",
        "SupportsBytes", "SupportsAbs", "SupportsRound",
        # typing: I/O and regex aliases
        "IO", "TextIO", "BinaryIO", "Pattern", "Match",
    )
)
|
||||
|
||||
|
||||
def collect_type_names_from_annotation(node: ast.expr | None) -> set[str]:
    """Collect the bare identifiers referenced by a type annotation AST node.

    The annotation is walked recursively through subscripts (``list[Foo]``),
    PEP 604 unions (``Foo | Bar``), tuples (``dict[str, Foo]``) and lists
    (the parameter list in ``Callable[[Foo], Bar]``).

    Args:
        node: The annotation expression, or ``None`` when the parameter is
            not annotated.

    Returns:
        The set of ``ast.Name`` identifiers found. Dotted references such as
        ``module.Foo`` and any other node kind contribute nothing.
    """
    if node is None:
        return set()
    if isinstance(node, ast.Name):
        return {node.id}
    if isinstance(node, ast.Subscript):
        # Both the generic itself and its type arguments may name classes.
        return collect_type_names_from_annotation(node.value) | collect_type_names_from_annotation(node.slice)
    if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
        return collect_type_names_from_annotation(node.left) | collect_type_names_from_annotation(node.right)
    if isinstance(node, (ast.Tuple, ast.List)):
        # ast.List covers the argument list of Callable[[A, B], R]; without it
        # the callable's parameter types would be silently dropped.
        names: set[str] = set()
        for elt in node.elts:
            names |= collect_type_names_from_annotation(elt)
        return names
    return set()
|
||||
|
||||
|
||||
def extract_init_stub_from_class(class_name: str, module_source: str, module_tree: ast.Module) -> str | None:
    """Build a minimal ``class X:`` stub containing only X's ``__init__`` source.

    Args:
        class_name: Name of the class to look for anywhere in ``module_tree``.
        module_source: The text that ``module_tree`` was parsed from; the
            ``__init__`` lines are sliced out of it verbatim.
        module_tree: Parsed module to search.

    Returns:
        ``"class <class_name>:"`` followed by the original ``__init__`` lines,
        or ``None`` when the class is not found or defines no ``__init__``
        directly in its body.
    """
    # First matching ClassDef in ast.walk order wins.
    target_class = next(
        (
            candidate
            for candidate in ast.walk(module_tree)
            if isinstance(candidate, ast.ClassDef) and candidate.name == class_name
        ),
        None,
    )
    if target_class is None:
        return None

    # Only methods declared directly in the class body are considered.
    constructor = next(
        (
            member
            for member in target_class.body
            if isinstance(member, (ast.FunctionDef, ast.AsyncFunctionDef)) and member.name == "__init__"
        ),
        None,
    )
    if constructor is None:
        return None

    source_lines = module_source.splitlines()
    # lineno is 1-based, end_lineno is inclusive, hence the asymmetric slice.
    constructor_text = "\n".join(source_lines[constructor.lineno - 1 : constructor.end_lineno])
    return f"class {class_name}:\n{constructor_text}"
|
||||
|
||||
|
||||
def extract_parameter_type_constructors(
    function_to_optimize: FunctionToOptimize, project_root_path: Path, existing_class_names: set[str]
) -> CodeStringsMarkdown:
    """Collect ``__init__`` stubs for the classes used as parameter types of a function.

    The function's file is parsed, the names in its parameter annotations are
    gathered, and each name brought in by a ``from ... import ...`` statement
    is resolved with jedi (following re-exports) to the module that defines
    it. The class's ``__init__`` source is then extracted from that module.

    Args:
        function_to_optimize: The target function; its ``file_path``,
            ``function_name`` and optional ``starting_line`` are used to
            locate the definition.
        project_root_path: Root directory handed to ``jedi.Project``.
        existing_class_names: Names to skip, e.g. classes already present in
            the context being enriched.

    Returns:
        One code string per resolved class that defines ``__init__``. The
        result is empty when the file cannot be read or parsed, the function
        is not found, or nothing resolves. For an aliased import the stub is
        headed by the class's real name, not the local alias.
    """
    import jedi

    try:
        source = function_to_optimize.file_path.read_text(encoding="utf-8")
        tree = ast.parse(source)
    except Exception:
        # Best effort: an unreadable or unparsable file simply yields no stubs.
        return CodeStringsMarkdown(code_strings=[])

    func_node = None
    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        if node.name != function_to_optimize.function_name:
            continue
        # When a starting line is known, use it to tell same-named functions apart.
        if function_to_optimize.starting_line is not None and node.lineno != function_to_optimize.starting_line:
            continue
        func_node = node
        break
    if func_node is None:
        return CodeStringsMarkdown(code_strings=[])

    arguments = func_node.args
    annotated_args = [*arguments.posonlyargs, *arguments.args, *arguments.kwonlyargs]
    if arguments.vararg:
        annotated_args.append(arguments.vararg)
    if arguments.kwarg:
        annotated_args.append(arguments.kwarg)

    type_names: set[str] = set()
    for arg in annotated_args:
        type_names |= collect_type_names_from_annotation(arg.annotation)

    type_names -= BUILTIN_AND_TYPING_NAMES
    type_names -= existing_class_names
    if not type_names:
        return CodeStringsMarkdown(code_strings=[])

    # Local name -> (module spec, name as defined in that module). The module
    # spec keeps the leading dots of a relative import, and the original name
    # is kept separately so aliased imports resolve to the real class.
    import_map: dict[str, tuple[str, str]] = {}
    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom) and node.module:
            module_spec = "." * node.level + node.module
            for alias in node.names:
                import_map[alias.asname or alias.name] = (module_spec, alias.name)

    try:
        # Loop-invariant, so build it once rather than per type.
        project = jedi.Project(path=project_root_path)
    except Exception:
        logger.debug(f"Error creating jedi project for {project_root_path}", exc_info=True)
        return CodeStringsMarkdown(code_strings=[])

    code_strings: list[CodeString] = []
    module_cache: dict[Path, tuple[str, ast.Module]] = {}

    for type_name in sorted(type_names):
        imported = import_map.get(type_name)
        if imported is None:
            continue
        module_name, original_name = imported
        try:
            script_code = f"from {module_name} import {original_name}"
            # The path anchors relative imports to the function's own package.
            script = jedi.Script(script_code, path=function_to_optimize.file_path, project=project)
            # Cursor at the end of the imported name on the single line.
            definitions = script.goto(1, len(script_code), follow_imports=True)
            if not definitions:
                continue

            module_path = definitions[0].module_path
            if not module_path:
                continue

            if module_path in module_cache:
                mod_source, mod_tree = module_cache[module_path]
            else:
                mod_source = module_path.read_text(encoding="utf-8")
                mod_tree = ast.parse(mod_source)
                module_cache[module_path] = (mod_source, mod_tree)

            stub = extract_init_stub_from_class(original_name, mod_source, mod_tree)
            if stub:
                code_strings.append(CodeString(code=stub, file_path=module_path))
        except Exception:
            # Enrichment is optional; record why it failed and move on.
            logger.debug(f"Error extracting constructor stub for {type_name} from {module_name}", exc_info=True)
            continue

    return CodeStringsMarkdown(code_strings=code_strings)
|
||||
|
||||
|
||||
def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown:
|
||||
import jedi
|
||||
|
||||
|
|
@ -852,7 +1069,12 @@ def prune_cst(
|
|||
return node, False
|
||||
|
||||
# Handle dunder methods for READ_ONLY/TESTGEN modes
|
||||
if include_dunder_methods and len(node.name.value) > 4 and node.name.value.startswith("__") and node.name.value.endswith("__"):
|
||||
if (
|
||||
include_dunder_methods
|
||||
and len(node.name.value) > 4
|
||||
and node.name.value.startswith("__")
|
||||
and node.name.value.endswith("__")
|
||||
):
|
||||
if not include_init_dunder and node.name.value == "__init__":
|
||||
return None, False
|
||||
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import sys
|
||||
import tempfile
|
||||
from argparse import Namespace
|
||||
|
|
@ -12,7 +13,10 @@ from codeflash.code_utils.code_extractor import GlobalAssignmentCollector, add_g
|
|||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.python.context.code_context_extractor import (
|
||||
collect_type_names_from_annotation,
|
||||
enrich_testgen_context,
|
||||
extract_init_stub_from_class,
|
||||
extract_parameter_type_constructors,
|
||||
get_code_optimization_context,
|
||||
)
|
||||
from codeflash.models.models import CodeString, CodeStringsMarkdown, FunctionParent
|
||||
|
|
@ -4132,3 +4136,206 @@ def my_func(ctx: Context) -> None:
|
|||
|
||||
class_names = [cs.code.split("\n")[0].replace("class ", "").rstrip(":") for cs in result.code_strings]
|
||||
assert len(class_names) == len(set(class_names)), f"Duplicate class stubs found: {class_names}"
|
||||
|
||||
|
||||
# --- Tests for collect_type_names_from_annotation ---
|
||||
|
||||
|
||||
def test_collect_type_names_simple() -> None:
    """A bare class-name annotation yields exactly that one name."""
    module = ast.parse("def f(x: Foo): pass")
    annotation = module.body[0].args.args[0].annotation
    assert collect_type_names_from_annotation(annotation) == {"Foo"}
|
||||
|
||||
|
||||
def test_collect_type_names_generic() -> None:
    """``list[Foo]`` contributes both the container and the element type."""
    annotation = ast.parse("def f(x: list[Foo]): pass").body[0].args.args[0].annotation
    collected = collect_type_names_from_annotation(annotation)
    assert "Foo" in collected
    assert "list" in collected
|
||||
|
||||
|
||||
def test_collect_type_names_optional() -> None:
    """``Optional[Foo]`` contributes both ``Optional`` and ``Foo``."""
    annotation = ast.parse("def f(x: Optional[Foo]): pass").body[0].args.args[0].annotation
    collected = collect_type_names_from_annotation(annotation)
    assert "Optional" in collected
    assert "Foo" in collected
|
||||
|
||||
|
||||
def test_collect_type_names_union_pipe() -> None:
    """Both operands of a ``Foo | Bar`` union are collected."""
    annotation = ast.parse("def f(x: Foo | Bar): pass").body[0].args.args[0].annotation
    collected = collect_type_names_from_annotation(annotation)
    assert collected == {"Foo", "Bar"}
|
||||
|
||||
|
||||
def test_collect_type_names_none_annotation() -> None:
    """A missing annotation (``None``) produces an empty set."""
    collected = collect_type_names_from_annotation(None)
    assert collected == set()
|
||||
|
||||
|
||||
def test_collect_type_names_attribute_skipped() -> None:
    """A dotted annotation such as ``module.Foo`` contributes no names."""
    module = ast.parse("def f(x: module.Foo): pass")
    annotation = module.body[0].args.args[0].annotation
    assert collect_type_names_from_annotation(annotation) == set()
|
||||
|
||||
|
||||
# --- Tests for extract_init_stub_from_class ---
|
||||
|
||||
|
||||
def test_extract_init_stub_basic() -> None:
    """The stub holds the class header plus the complete ``__init__`` source."""
    source = """
class MyClass:
    def __init__(self, name: str, value: int = 0):
        self.name = name
        self.value = value
"""
    stub = extract_init_stub_from_class("MyClass", source, ast.parse(source))
    assert stub is not None
    expected_fragments = (
        "class MyClass:",
        "def __init__(self, name: str, value: int = 0):",
        "self.name = name",
        "self.value = value",
    )
    for fragment in expected_fragments:
        assert fragment in stub
|
||||
|
||||
|
||||
def test_extract_init_stub_no_init() -> None:
    """A class without its own ``__init__`` yields no stub."""
    source = """
class NoInit:
    x = 10
    def other(self):
        pass
"""
    parsed = ast.parse(source)
    assert extract_init_stub_from_class("NoInit", source, parsed) is None
|
||||
|
||||
|
||||
def test_extract_init_stub_class_not_found() -> None:
    """Asking for a class that the module does not define yields no stub."""
    source = """
class Other:
    def __init__(self):
        pass
"""
    parsed = ast.parse(source)
    assert extract_init_stub_from_class("Missing", source, parsed) is None
|
||||
|
||||
|
||||
# --- Tests for extract_parameter_type_constructors ---
|
||||
|
||||
|
||||
def test_extract_parameter_type_constructors_project_type(tmp_path: Path) -> None:
    """A project class used as a parameter type gets its ``__init__`` stub extracted."""
    package_dir = tmp_path / "mypkg"
    package_dir.mkdir()
    (package_dir / "__init__.py").write_text("", encoding="utf-8")

    # Module that defines the parameter type.
    models_file = package_dir / "models.py"
    models_file.write_text(
        """
class Widget:
    def __init__(self, size: int, color: str = "red"):
        self.size = size
        self.color = color
""",
        encoding="utf-8",
    )

    # Module holding the function under optimization, which takes a Widget.
    processor_file = package_dir / "processor.py"
    processor_file.write_text(
        """from mypkg.models import Widget

def process(w: Widget) -> str:
    return str(w)
""",
        encoding="utf-8",
    )

    fto = FunctionToOptimize(
        function_name="process", file_path=processor_file.resolve(), starting_line=3, ending_line=4
    )
    stubs = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
    assert len(stubs.code_strings) == 1
    stub_code = stubs.code_strings[0].code
    assert "class Widget:" in stub_code
    assert "def __init__" in stub_code
    assert "size" in stub_code
|
||||
|
||||
|
||||
def test_extract_parameter_type_constructors_excludes_builtins(tmp_path: Path) -> None:
    """Parameters annotated only with builtin types produce no stubs."""
    package_dir = tmp_path / "mypkg"
    package_dir.mkdir()
    (package_dir / "__init__.py").write_text("", encoding="utf-8")

    func_file = package_dir / "func.py"
    func_file.write_text(
        """
def my_func(x: int, y: str, z: list) -> None:
    pass
""",
        encoding="utf-8",
    )

    fto = FunctionToOptimize(
        function_name="my_func", file_path=func_file.resolve(), starting_line=2, ending_line=3
    )
    stubs = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
    assert len(stubs.code_strings) == 0
|
||||
|
||||
|
||||
def test_extract_parameter_type_constructors_skips_existing_classes(tmp_path: Path) -> None:
    """A class named in ``existing_class_names`` is not extracted again."""
    package_dir = tmp_path / "mypkg"
    package_dir.mkdir()
    (package_dir / "__init__.py").write_text("", encoding="utf-8")

    models_file = package_dir / "models.py"
    models_file.write_text(
        """
class Widget:
    def __init__(self, size: int):
        self.size = size
""",
        encoding="utf-8",
    )

    processor_file = package_dir / "processor.py"
    processor_file.write_text(
        """from mypkg.models import Widget

def process(w: Widget) -> str:
    return str(w)
""",
        encoding="utf-8",
    )

    fto = FunctionToOptimize(
        function_name="process", file_path=processor_file.resolve(), starting_line=3, ending_line=4
    )
    # Widget is passed as already known, so no stub should be produced for it.
    stubs = extract_parameter_type_constructors(fto, tmp_path.resolve(), {"Widget"})
    assert len(stubs.code_strings) == 0
|
||||
|
||||
|
||||
def test_extract_parameter_type_constructors_no_init(tmp_path: Path) -> None:
    """A parameter type whose class defines no ``__init__`` produces no stub."""
    package_dir = tmp_path / "mypkg"
    package_dir.mkdir()
    (package_dir / "__init__.py").write_text("", encoding="utf-8")

    models_file = package_dir / "models.py"
    models_file.write_text(
        """
class Config:
    x = 10
""",
        encoding="utf-8",
    )

    processor_file = package_dir / "processor.py"
    processor_file.write_text(
        """from mypkg.models import Config

def process(c: Config) -> str:
    return str(c)
""",
        encoding="utf-8",
    )

    fto = FunctionToOptimize(
        function_name="process", file_path=processor_file.resolve(), starting_line=3, ending_line=4
    )
    stubs = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
    assert len(stubs.code_strings) == 0
|
||||
|
|
|
|||
Loading…
Reference in a new issue