Merge pull request #2437 from codeflash-ai/misc-changes

fix: improve ranker scoring consistency and local-caching bias
This commit is contained in:
Kevin Turcios 2026-02-23 08:55:18 +00:00 committed by GitHub
commit 05aecd6fbd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 54 additions and 1308 deletions

View file

@ -168,6 +168,8 @@ jobs:
env:
ANTHROPIC_FOUNDRY_API_KEY: ${{ secrets.ANTHROPIC_FOUNDRY_API_KEY }}
ANTHROPIC_FOUNDRY_BASE_URL: ${{ secrets.ANTHROPIC_FOUNDRY_BASE_URL }}
DATABASE_URL: ${{ secrets.DATABASE_URL }}
DJANGO_SETTINGS_MODULE: aiservice.settings
# @claude mentions (can edit and push)
claude-mention:

View file

@ -13,8 +13,10 @@ from aiservice.common.llm_output_utils import truncate_pathological_output
# Matches both ```python and ```python:filepath blocks, captures content only
MARKDOWN_CODE_BLOCK_PATTERN = re.compile(r"```python(?::[^\n]*)?\n(.*?)```", re.DOTALL)
# Matches first ```python block (no filepath), captures content
FIRST_CODE_BLOCK_PATTERN = re.compile(r"^```python\s*\n(.*?)\n```", re.MULTILINE | re.DOTALL)
# Matches first ```python block (no filepath), captures content.
# Uses greedy (.*) to handle LLM outputs with nested code fences (e.g. ```python:filepath
# blocks inside the main block). Requires closing ``` to be alone on its line.
FIRST_CODE_BLOCK_PATTERN = re.compile(r"^```python\s*\n(.*)\n```[ \t]*$", re.MULTILINE | re.DOTALL)
# Fallback for incomplete code blocks (missing closing ```)
FIRST_CODE_BLOCK_FALLBACK_PATTERN = re.compile(r"^```python\s*\n(.*)", re.MULTILINE | re.DOTALL)

View file

@ -1,5 +1,6 @@
from __future__ import annotations
import re
import uuid
import isort
@ -29,10 +30,13 @@ def parse_python_version(version: str | None) -> tuple[int, int, int]:
split_version = version.split(".")
if len(split_version) != 3:
raise ValueError("Invalid version format")
major, minor, patch = int(split_version[0]), int(split_version[1]), int(split_version[2])
patch_str = re.match(r"\d+", split_version[2])
if not patch_str:
raise ValueError("Invalid patch version")
major, minor, patch = int(split_version[0]), int(split_version[1]), int(patch_str.group())
assert major == 3, "Only Python 3 is supported"
assert minor >= 9, "Only Python 3.9 and above is supported"
assert minor <= 14, "Unsupported Python version"
assert minor <= 15, "Unsupported Python version"
assert patch >= 0, "Only Python 3.9 and above is supported"
assert patch < 100, "Invalid version format"

View file

@ -1,165 +0,0 @@
"""Preprocessing step to extract constructor signatures from regular (non-dataclass) classes.
When the test context contains class definitions with explicit __init__ methods,
this module extracts the constructor signatures to help the LLM understand how
to properly instantiate them.
"""
from __future__ import annotations
from typing import NamedTuple
import libcst as cst
from libcst.metadata import MetadataWrapper
from aiservice.common.markdown_utils import extract_all_code_from_markdown
from core.languages.python.cst_utils import get_node_source_text, has_decorator, parse_module_to_cst
class ParamInfo(NamedTuple):
    """One ``__init__`` parameter as recorded by ``extract_init_params``."""

    # Parameter name; variadic parameters carry their ``*`` / ``**`` prefix.
    name: str
    # Source text of the annotation, or None when the parameter is unannotated.
    annotation: str | None
    # True when the parameter declares a default value.
    has_default: bool
    # Source text of the default (shortened by the extractor when longer than 50 characters), else None.
    default_repr: str | None
def _find_init_method(class_node: cst.ClassDef) -> cst.FunctionDef | None:
    """Return the first ``__init__`` defined directly in the class body, or None if there is none."""
    init_defs = (
        member
        for member in class_node.body.body
        if isinstance(member, cst.FunctionDef) and member.name.value == "__init__"
    )
    return next(init_defs, None)
def _describe_param(
    param: cst.Param, source_lines: list[str], wrapper: MetadataWrapper, prefix: str = ""
) -> ParamInfo:
    """Build the ParamInfo for a single parameter node, prepending ``prefix`` (``*`` / ``**``) to its name."""
    annotation = None
    if param.annotation is not None:
        annotation = get_node_source_text(param.annotation, source_lines, wrapper)
    default_repr = None
    if param.default is not None:
        default_repr = get_node_source_text(param.default, source_lines, wrapper)
        # Keep notes compact: long defaults are cut to 47 characters plus an ellipsis.
        if len(default_repr) > 50:
            default_repr = default_repr[:47] + "..."
    return ParamInfo(
        f"{prefix}{param.name.value}", annotation, has_default=param.default is not None, default_repr=default_repr
    )


def extract_init_params(
    init_node: cst.FunctionDef, source_lines: list[str], wrapper: MetadataWrapper
) -> list[ParamInfo]:
    """Describe every parameter of an ``__init__`` method except ``self``.

    Parameters are returned in declaration order: positional-only, regular
    positional, ``*args``, keyword-only, then ``**kwargs``. Positional-only
    parameters (those before ``/``) were previously dropped; they are now
    included so the reported signature matches the real constructor.

    Args:
        init_node: The ``__init__`` function definition to inspect.
        source_lines: Source split into lines, passed through to ``get_node_source_text``.
        wrapper: Metadata wrapper for the module containing ``init_node``.

    Returns:
        One ParamInfo per parameter; variadic names carry their ``*`` / ``**`` prefix.
    """
    parameters = init_node.params
    # ``self`` may sit in either group (e.g. ``def __init__(self, /, x)``), so filter both.
    positional = [*parameters.posonly_params, *parameters.params]
    params: list[ParamInfo] = [
        _describe_param(param, source_lines, wrapper) for param in positional if param.name.value != "self"
    ]
    # A bare ``*`` separator is not a cst.Param and contributes no parameter.
    if isinstance(parameters.star_arg, cst.Param):
        params.append(_describe_param(parameters.star_arg, source_lines, wrapper, prefix="*"))
    params.extend(_describe_param(param, source_lines, wrapper) for param in parameters.kwonly_params)
    if parameters.star_kwarg is not None:
        params.append(_describe_param(parameters.star_kwarg, source_lines, wrapper, prefix="**"))
    return params
class ClassInitCollector(cst.CSTVisitor):
    """CST visitor that records non-dataclass classes which define their own ``__init__``."""

    def __init__(self) -> None:
        # Class name -> ClassDef node; a later class with the same name replaces an earlier one.
        self.classes: dict[str, cst.ClassDef] = {}

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        """Record ``node`` when it qualifies; always return True so nested classes are visited."""
        if has_decorator(node, "dataclass"):
            return True
        if _find_init_method(node) is None:
            return True
        self.classes[node.name.value] = node
        return True
def _find_all_classes_with_init(source_code: str) -> tuple[dict[str, cst.ClassDef], list[str], MetadataWrapper | None]:
    """Parse ``source_code`` and collect the classes accepted by ``ClassInitCollector``.

    Returns:
        ``(classes, source_lines, wrapper)``; ``({}, [], None)`` when the source fails to parse.
    """
    try:
        module = parse_module_to_cst(source_code)
    except cst.ParserSyntaxError:
        return {}, [], None
    metadata_wrapper = MetadataWrapper(module)
    visitor = ClassInitCollector()
    metadata_wrapper.visit(visitor)
    return visitor.classes, source_code.split("\n"), metadata_wrapper
def format_init_signature(class_name: str, params: list[ParamInfo]) -> str:
    """Render the parameters of ``class_name``'s constructor as a multi-line note.

    Parameters are grouped into required, optional (has a default) and
    variadic (name starts with ``*``) sections; empty sections are omitted.
    """
    if not params:
        return f"{class_name}() - no parameters"

    def bullet(param: ParamInfo, suffix: str = "") -> str:
        typed = f": {param.annotation}" if param.annotation else ""
        return f" - {param.name}{typed}{suffix}"

    required: list[ParamInfo] = []
    optional: list[ParamInfo] = []
    variadic: list[ParamInfo] = []
    for param in params:
        is_variadic = param.name.startswith("*")
        if param.has_default:
            optional.append(param)
        elif not is_variadic:
            required.append(param)
        if is_variadic:
            variadic.append(param)

    out = [f"Constructor signature for {class_name}:"]
    if required:
        out.append(" Required (positional) arguments:")
        out.extend(bullet(param) for param in required)
    if optional:
        out.append(" Optional (keyword) arguments:")
        for param in optional:
            # Fall back to an ellipsis when the default's text is unavailable.
            shown_default = f" = {param.default_repr}" if param.default_repr else " = ..."
            out.append(bullet(param, shown_default))
    if variadic:
        out.append(" Variadic arguments:")
        out.extend(bullet(param) for param in variadic)
    return "\n".join(out)
def get_class_constructor_notes(test_context: str) -> list[str]:
    """Return one constructor-signature note per collected class found in ``test_context``.

    When the context contains a ```python fence, the code is first pulled out of
    the markdown; otherwise the context is analysed as-is. An empty list is
    returned when nothing is collected or the code cannot be parsed.
    """
    if "```python" in test_context:
        code_to_analyze = extract_all_code_from_markdown(test_context)
    else:
        code_to_analyze = test_context

    classes, source_lines, wrapper = _find_all_classes_with_init(code_to_analyze)
    if wrapper is None or not classes:
        return []

    notes = []
    for class_name, class_node in classes.items():
        init_node = _find_init_method(class_node)
        if init_node is not None:
            init_params = extract_init_params(init_node, source_lines, wrapper)
            notes.append(format_init_signature(class_name, init_params))
    return notes

View file

@ -1,176 +0,0 @@
"""Preprocessing step to extract dataclass constructor signatures.
When the test context contains dataclass definitions, this module extracts
the constructor signatures to help the LLM understand how to properly
instantiate them, including inherited fields from parent dataclasses.
"""
from __future__ import annotations
from typing import NamedTuple
import libcst as cst
from libcst.metadata import MetadataWrapper
from core.languages.python.cst_utils import (
get_base_class_name,
get_node_source_text,
has_decorator,
parse_module_to_cst,
)
from aiservice.common.markdown_utils import extract_all_code_from_markdown
class FieldInfo(NamedTuple):
    """Information about a dataclass field."""

    # Field name as written on the left-hand side of the annotated assignment.
    name: str
    # Source text of the field's type annotation.
    annotation: str
    # True when the field is assigned a value in the class body.
    has_default: bool
    # Source text of that value (shortened by the extractor when longer than 50 characters), else None.
    default_repr: str | None
    source: str  # "inherited" or "own"
def _is_classvar_annotation(annotation: cst.Annotation) -> bool:
    """Return True for ``ClassVar`` / ``ClassVar[...]`` annotations, optionally module-qualified."""
    expr = annotation.annotation
    if isinstance(expr, cst.Subscript):
        expr = expr.value
    if isinstance(expr, cst.Attribute):
        expr = expr.attr
    return isinstance(expr, cst.Name) and expr.value == "ClassVar"


def extract_dataclass_fields(
    class_node: cst.ClassDef, source_lines: list[str], wrapper: MetadataWrapper
) -> list[FieldInfo]:
    """Extract the fields declared directly in a dataclass body.

    Only annotated assignments to a plain name are considered. Attributes
    annotated as ``ClassVar`` are skipped: dataclasses do not treat them as
    fields, so they never appear in the generated constructor.

    Returns:
        FieldInfo tuples ``(name, annotation, has_default, default_repr, source)``
        in declaration order, each with ``source`` set to ``"own"``.
    """
    fields: list[FieldInfo] = []
    for stmt in class_node.body.body:
        if not isinstance(stmt, cst.SimpleStatementLine):
            continue
        for node in stmt.body:
            if not (isinstance(node, cst.AnnAssign) and isinstance(node.target, cst.Name)):
                continue
            if _is_classvar_annotation(node.annotation):
                continue
            annotation = get_node_source_text(node.annotation, source_lines, wrapper)
            default_repr = None
            if node.value is not None:
                default_repr = get_node_source_text(node.value, source_lines, wrapper)
                # Truncate long defaults
                if len(default_repr) > 50:
                    default_repr = default_repr[:47] + "..."
            fields.append(FieldInfo(node.target.value, annotation, node.value is not None, default_repr, "own"))
    return fields
class DataclassCollector(cst.CSTVisitor):
    """Visitor that gathers every class decorated with ``dataclass``."""

    def __init__(self) -> None:
        # Class name -> ClassDef node; a later class with the same name replaces an earlier one.
        self.dataclasses: dict[str, cst.ClassDef] = {}

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        """Record ``node`` when it is a dataclass; always descend into nested classes."""
        is_dataclass = has_decorator(node, "dataclass")
        if is_dataclass:
            self.dataclasses[node.name.value] = node
        return True
def find_all_dataclasses(source_code: str) -> tuple[dict[str, cst.ClassDef], list[str], MetadataWrapper | None]:
    """Locate every dataclass definition in ``source_code``.

    Returns:
        ``(dataclasses_by_name, source_lines, metadata_wrapper)``;
        ``({}, [], None)`` when the source cannot be parsed.
    """
    try:
        module = parse_module_to_cst(source_code)
    except cst.ParserSyntaxError:
        return {}, [], None

    metadata_wrapper = MetadataWrapper(module)
    visitor = DataclassCollector()
    metadata_wrapper.visit(visitor)
    return visitor.dataclasses, source_code.split("\n"), metadata_wrapper
def get_all_fields_with_inheritance(
    class_name: str,
    dataclasses: dict[str, cst.ClassDef],
    source_lines: list[str],
    wrapper: MetadataWrapper,
    _visiting: frozenset[str] = frozenset(),
) -> list[FieldInfo]:
    """Get all fields for a dataclass, including inherited ones.

    Fields follow the order dataclasses use: base-class fields first, then the
    class's own fields. A field redefined in a subclass keeps the position of
    the original definition but takes the subclass's annotation and default.

    Args:
        class_name: Name of the dataclass to resolve.
        dataclasses: Mapping of class name to class node for all known dataclasses.
        source_lines: Source split into lines, passed through to the field extractor.
        wrapper: Metadata wrapper for the module containing the classes.
        _visiting: Internal recursion guard holding the names already being
            resolved; callers should not pass it.

    Returns:
        The merged field list, or an empty list when ``class_name`` is unknown.
    """
    # Classes are keyed by bare name, so ``class Config(pkg.Config)`` would name
    # itself as its own base; the guard stops that (and any cycle) from recursing forever.
    if class_name not in dataclasses or class_name in _visiting:
        return []
    class_node = dataclasses[class_name]
    active = _visiting | {class_name}

    # A dict keeps first-insertion order while letting later definitions override earlier ones.
    merged: dict[str, FieldInfo] = {}
    # Dataclasses collect base fields in reverse MRO order, i.e. the right-most base first.
    for base in reversed(class_node.bases):
        base_name = get_base_class_name(base)
        if not base_name:
            continue
        for field in get_all_fields_with_inheritance(base_name, dataclasses, source_lines, wrapper, active):
            merged[field.name] = field._replace(source="inherited")

    for field in extract_dataclass_fields(class_node, source_lines, wrapper):
        merged[field.name] = field
    return list(merged.values())
def format_constructor_signature(class_name: str, fields: list[FieldInfo]) -> str:
    """Render a dataclass's fields as a multi-line constructor note.

    Fields without a default are listed as required, the rest as optional;
    inherited fields are tagged as coming from the parent class.
    """
    if not fields:
        return f"{class_name}() - no fields"

    def origin(field: FieldInfo) -> str:
        return " (from parent class)" if field.source == "inherited" else ""

    required = [field for field in fields if not field.has_default]
    optional = [field for field in fields if field.has_default]

    out = [f"Constructor signature for {class_name}:"]
    if required:
        out.append(" Required (positional) arguments:")
        out.extend(f" - {field.name}: {field.annotation}{origin(field)}" for field in required)
    if optional:
        out.append(" Optional (keyword) arguments:")
        for field in optional:
            # Fall back to an ellipsis when the default's text is unavailable.
            shown_default = f" = {field.default_repr}" if field.default_repr else " = ..."
            out.append(f" - {field.name}: {field.annotation}{shown_default}{origin(field)}")
    return "\n".join(out)
def get_dataclass_constructor_notes(test_context: str) -> list[str]:
    """Build constructor-signature notes for the dataclasses found in ``test_context``.

    Each note lists the class's fields, inherited ones included. Classes that
    end up with no fields produce no note.

    Args:
        test_context: The source code context to analyze for dataclass definitions.
    """
    # Pull the code out of markdown fences when the context is markdown-wrapped.
    if "```python" in test_context:
        code_to_analyze = extract_all_code_from_markdown(test_context)
    else:
        code_to_analyze = test_context

    dataclasses, source_lines, wrapper = find_all_dataclasses(code_to_analyze)
    if wrapper is None or not dataclasses:
        return []

    fields_by_class = (
        (name, get_all_fields_with_inheritance(name, dataclasses, source_lines, wrapper)) for name in dataclasses
    )
    return [format_constructor_signature(name, fields) for name, fields in fields_by_class if fields]

View file

@ -1,12 +1,10 @@
from itertools import chain
from core.languages.python.testgen.preprocessing.class_constructor_notes import get_class_constructor_notes
from core.languages.python.testgen.preprocessing.dataclass_constructor_notes import get_dataclass_constructor_notes
from core.languages.python.testgen.preprocessing.torch_tensor_limit import get_tensor_size_note
# Preprocessing functions that analyze code context and return notes
# Each function takes test_context (str) and returns a list of notes (list[str])
_PREPROCESSING_FUNCTIONS = [get_tensor_size_note, get_dataclass_constructor_notes, get_class_constructor_notes]
_PREPROCESSING_FUNCTIONS = [get_tensor_size_note]
def preprocessing_testgen_pipeline(test_context: str) -> list[str]:

View file

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast
import sentry_sdk
from asgiref.sync import sync_to_async
@ -17,6 +17,15 @@ if TYPE_CHECKING:
features_api = NinjaAPI(urls_namespace="log_features")
def _positional_list(items: list[str] | None, index: int | None) -> list[str | None] | None:
"""Place a single item at the correct index, padding with None."""
if not items or index is None or len(items) != 1:
return cast("list[str | None] | None", items)
result: list[str | None] = [None] * (index + 1)
result[index] = items[0]
return result
@sync_to_async
@transaction.atomic
def log_features(
@ -89,9 +98,9 @@ def log_features(
"optimized_runtime": optimized_runtime,
"optimized_line_profiler_results": optimized_line_profiler_results,
"is_correct": is_correct,
"generated_test": generated_tests,
"instrumented_generated_test": instrumented_generated_tests,
"instrumented_perf_test": instrumented_perf_tests,
"generated_test": _positional_list(generated_tests, test_index),
"instrumented_generated_test": _positional_list(instrumented_generated_tests, test_index),
"instrumented_perf_test": _positional_list(instrumented_perf_tests, test_index),
"test_framework": test_framework,
"created_at": datetime,
"aiservice_commit_id": aiservice_commit,

View file

@ -85,6 +85,7 @@ You are also provided with the following information.
- Introduction of the `global` and `nonlocal` keywords in optimizations is **HIGHLY DISCOURAGED** as it reduces code clarity and maintainability, introduces hidden dependencies, can cause subtle bugs and breaks modularity. **DO NOT** prefer such optimizations.
- Replacement of `isinstance()` checks with `type()` checks is **HIGHLY DISCOURAGED** as `isinstance()` correctly handles inheritance and subclasses, while `type()` checks are incorrect for subclass instances and represent a micro-optimization that should be avoided. Do not prefer such optimizations.
- If the only optimizations are micro-optimizations like inlining a function call, or localizing variables or methods (attribute lookup optimizations), do not prefer the optimizations. The performance improvements are minimal and come at a substantial cost to readability.
- Local variable caching of globals (e.g., `local_var = GLOBAL_VAR` before a loop) is a micro-optimization that only matters on Python 3.10 and earlier. On Python 3.11+, `LOAD_GLOBAL` uses adaptive specialization and is nearly as fast as `LOAD_FAST`, so the benefit is negligible; prefer the simpler code WITHOUT the local cache.
## Response Format
@ -383,9 +384,10 @@ async def rank_optimizations( # noqa: D417
json_response = _parse_json_response(output.content, num_candidates)
if json_response is not None:
logging.info("Successfully parsed JSON response")
return RankResponseSchema(
ranking=json_response.ranking, explanation=json_response.explanation, scores=json_response.scores
ranking = (
_scores_to_ranking(json_response.scores) if json_response.scores is not None else json_response.ranking
)
return RankResponseSchema(ranking=ranking, explanation=json_response.explanation, scores=json_response.scores)
# Fall back to regex parsing (legacy XML-tag format)
logging.info("JSON parsing failed, falling back to regex")

View file

@ -181,6 +181,23 @@ x ="""
assert result == expected
def test_extract_code_block_nested_code_fence_in_triple_quote() -> None:
    """A fence embedded in a triple-quoted string must not end the outer block early."""
    # The LLM output stores a function definition in a triple-quoted string that itself contains ```
    markdown = '```python\nimport pytest\n_source = """```python:file.py\ndef foo(): pass\n```"""\ndef test_foo():\n assert True\n```'
    expected = 'import pytest\n_source = """```python:file.py\ndef foo(): pass\n```"""\ndef test_foo():\n assert True'
    assert extract_code_block(markdown) == expected
def test_extract_code_block_nested_code_fence_block() -> None:
    """A ```python:filepath block nested inside the main block is kept as part of its content."""
    markdown = "```python\nimport pytest\n```python:src/mod.py\ndef foo(): pass\n```\ndef test_foo():\n assert True\n```"
    expected = "import pytest\n```python:src/mod.py\ndef foo(): pass\n```\ndef test_foo():\n assert True"
    assert extract_code_block(markdown) == expected
def test_extract_all_code_single_block() -> None:
text = "```python\ncode1\n```"
result = extract_all_code_from_markdown(text)

View file

@ -1,418 +0,0 @@
"""Tests for regular class constructor signature extraction."""
import libcst as cst
from core.languages.python.testgen.preprocessing.class_constructor_notes import (
ParamInfo,
_find_all_classes_with_init,
extract_init_params,
format_init_signature,
get_class_constructor_notes,
)
class TestClassInitCollector:
def test_collects_class_with_init(self) -> None:
code = """
class Foo:
def __init__(self, x: int):
self.x = x
"""
classes, _, wrapper = _find_all_classes_with_init(code)
assert "Foo" in classes
assert wrapper is not None
def test_skips_class_without_init(self) -> None:
code = """
class Foo:
def bar(self):
pass
"""
classes, _, _ = _find_all_classes_with_init(code)
assert "Foo" not in classes
def test_skips_dataclass(self) -> None:
code = """
@dataclass
class Foo:
x: int
def __init__(self, x: int):
self.x = x
"""
classes, _, _ = _find_all_classes_with_init(code)
assert "Foo" not in classes
def test_skips_dataclass_with_call(self) -> None:
code = """
@dataclass(frozen=True)
class Foo:
x: int
def __init__(self, x: int):
self.x = x
"""
classes, _, _ = _find_all_classes_with_init(code)
assert "Foo" not in classes
def test_collects_multiple_classes(self) -> None:
code = """
class Foo:
def __init__(self, x: int):
self.x = x
class Bar:
def __init__(self, y: str):
self.y = y
"""
classes, _, _ = _find_all_classes_with_init(code)
assert "Foo" in classes
assert "Bar" in classes
def test_handles_syntax_error(self) -> None:
code = "this is not valid python code {{{{"
classes, _, wrapper = _find_all_classes_with_init(code)
assert classes == {}
assert wrapper is None
class TestExtractInitParams:
def test_typed_params(self) -> None:
code = """
class Foo:
def __init__(self, x: int, y: str):
pass
"""
classes, source_lines, wrapper = _find_all_classes_with_init(code)
assert wrapper is not None
init_node = next(iter(classes["Foo"].body.body))
assert isinstance(init_node, cst.FunctionDef)
params = extract_init_params(init_node, source_lines, wrapper)
assert len(params) == 2
assert params[0].name == "x"
assert params[0].annotation == "int"
assert not params[0].has_default
assert params[1].name == "y"
assert params[1].annotation == "str"
assert not params[1].has_default
def test_params_with_defaults(self) -> None:
code = """
class Foo:
def __init__(self, x: int = 0, y: str = "hello"):
pass
"""
classes, source_lines, wrapper = _find_all_classes_with_init(code)
assert wrapper is not None
init_node = next(iter(classes["Foo"].body.body))
assert isinstance(init_node, cst.FunctionDef)
params = extract_init_params(init_node, source_lines, wrapper)
assert len(params) == 2
assert params[0].name == "x"
assert params[0].has_default
assert params[0].default_repr == "0"
assert params[1].name == "y"
assert params[1].has_default
assert params[1].default_repr == '"hello"'
def test_args_and_kwargs(self) -> None:
code = """
class Foo:
def __init__(self, x: int, *args, **kwargs):
pass
"""
classes, source_lines, wrapper = _find_all_classes_with_init(code)
assert wrapper is not None
init_node = next(iter(classes["Foo"].body.body))
assert isinstance(init_node, cst.FunctionDef)
params = extract_init_params(init_node, source_lines, wrapper)
assert len(params) == 3
assert params[0].name == "x"
assert params[1].name == "*args"
assert params[2].name == "**kwargs"
def test_skips_self(self) -> None:
code = """
class Foo:
def __init__(self):
pass
"""
classes, source_lines, wrapper = _find_all_classes_with_init(code)
assert wrapper is not None
init_node = next(iter(classes["Foo"].body.body))
assert isinstance(init_node, cst.FunctionDef)
params = extract_init_params(init_node, source_lines, wrapper)
assert len(params) == 0
def test_untyped_params(self) -> None:
code = """
class Foo:
def __init__(self, x, y=10):
pass
"""
classes, source_lines, wrapper = _find_all_classes_with_init(code)
assert wrapper is not None
init_node = next(iter(classes["Foo"].body.body))
assert isinstance(init_node, cst.FunctionDef)
params = extract_init_params(init_node, source_lines, wrapper)
assert len(params) == 2
assert params[0].name == "x"
assert params[0].annotation is None
assert params[1].name == "y"
assert params[1].annotation is None
assert params[1].has_default
assert params[1].default_repr == "10"
def test_kwonly_params(self) -> None:
code = """
class Foo:
def __init__(self, x: int, *, key: str = "default"):
pass
"""
classes, source_lines, wrapper = _find_all_classes_with_init(code)
assert wrapper is not None
init_node = next(iter(classes["Foo"].body.body))
assert isinstance(init_node, cst.FunctionDef)
params = extract_init_params(init_node, source_lines, wrapper)
assert len(params) == 2
assert params[0].name == "x"
assert params[1].name == "key"
assert params[1].annotation == "str"
assert params[1].has_default
assert params[1].default_repr == '"default"'
def test_long_default_truncated(self) -> None:
code = """
class Foo:
def __init__(self, x: list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]):
pass
"""
classes, source_lines, wrapper = _find_all_classes_with_init(code)
assert wrapper is not None
init_node = next(iter(classes["Foo"].body.body))
assert isinstance(init_node, cst.FunctionDef)
params = extract_init_params(init_node, source_lines, wrapper)
assert len(params) == 1
assert params[0].has_default
assert params[0].default_repr is not None
assert len(params[0].default_repr) <= 50
assert params[0].default_repr.endswith("...")
class TestFormatInitSignature:
def test_required_only(self) -> None:
params = [
ParamInfo(name="x", annotation="int", has_default=False, default_repr=None),
ParamInfo(name="y", annotation="str", has_default=False, default_repr=None),
]
result = format_init_signature("Foo", params)
assert "Constructor signature for Foo:" in result
assert "Required (positional) arguments:" in result
assert "- x: int" in result
assert "- y: str" in result
assert "Optional" not in result
def test_optional_only(self) -> None:
params = [
ParamInfo(name="x", annotation="int", has_default=True, default_repr="0"),
ParamInfo(name="y", annotation="str", has_default=True, default_repr='"hi"'),
]
result = format_init_signature("Foo", params)
assert "Constructor signature for Foo:" in result
assert "Optional (keyword) arguments:" in result
assert "- x: int = 0" in result
assert '- y: str = "hi"' in result
assert "Required" not in result
def test_mixed_params(self) -> None:
params = [
ParamInfo(name="x", annotation="int", has_default=False, default_repr=None),
ParamInfo(name="y", annotation="str", has_default=True, default_repr='"default"'),
]
result = format_init_signature("Foo", params)
assert "Required (positional) arguments:" in result
assert "Optional (keyword) arguments:" in result
def test_no_params(self) -> None:
result = format_init_signature("Empty", [])
assert "Empty() - no parameters" in result
def test_variadic_params(self) -> None:
params = [
ParamInfo(name="x", annotation="int", has_default=False, default_repr=None),
ParamInfo(name="*args", annotation=None, has_default=False, default_repr=None),
ParamInfo(name="**kwargs", annotation=None, has_default=False, default_repr=None),
]
result = format_init_signature("Foo", params)
assert "Required (positional) arguments:" in result
assert "Variadic arguments:" in result
assert "- *args" in result
assert "- **kwargs" in result
def test_untyped_params(self) -> None:
params = [
ParamInfo(name="x", annotation=None, has_default=False, default_repr=None),
ParamInfo(name="y", annotation=None, has_default=True, default_repr="10"),
]
result = format_init_signature("Foo", params)
assert " - x\n" in result
assert " - y = 10" in result
class TestGetClassConstructorNotes:
def test_class_with_typed_init(self) -> None:
context = """
class LayoutElements:
def __init__(self, width: int, height: int, elements: list[str]):
self.width = width
self.height = height
self.elements = elements
"""
notes = get_class_constructor_notes(context)
assert len(notes) == 1
assert "Constructor signature for LayoutElements:" in notes[0]
assert "- width: int" in notes[0]
assert "- height: int" in notes[0]
assert "- elements: list[str]" in notes[0]
def test_class_with_defaults(self) -> None:
context = """
class TextRegions:
def __init__(self, text: str = "", max_len: int = 100):
self.text = text
self.max_len = max_len
"""
notes = get_class_constructor_notes(context)
assert len(notes) == 1
assert "Optional (keyword) arguments:" in notes[0]
assert '- text: str = ""' in notes[0]
assert "- max_len: int = 100" in notes[0]
def test_class_with_args_kwargs(self) -> None:
context = """
class Flexible:
def __init__(self, name: str, *args, **kwargs):
self.name = name
"""
notes = get_class_constructor_notes(context)
assert len(notes) == 1
assert "- name: str" in notes[0]
assert "- *args" in notes[0]
assert "- **kwargs" in notes[0]
def test_class_without_init_skipped(self) -> None:
context = """
class NoInit:
x = 10
def method(self):
pass
"""
notes = get_class_constructor_notes(context)
assert notes == []
def test_dataclass_skipped(self) -> None:
context = """
@dataclass
class Config:
name: str
value: int
"""
notes = get_class_constructor_notes(context)
assert notes == []
def test_multiple_classes(self) -> None:
context = """
class Foo:
def __init__(self, x: int):
self.x = x
class Bar:
def __init__(self, y: str):
self.y = y
"""
notes = get_class_constructor_notes(context)
assert len(notes) == 2
class_names = " ".join(notes)
assert "Foo" in class_names
assert "Bar" in class_names
def test_markdown_wrapped_code(self) -> None:
context = """
Some description text.
```python:models.py
class LayoutElements:
def __init__(self, width: int, height: int):
self.width = width
self.height = height
```
"""
notes = get_class_constructor_notes(context)
assert len(notes) == 1
assert "LayoutElements" in notes[0]
assert "- width: int" in notes[0]
def test_syntax_error_returns_empty(self) -> None:
context = "this is not valid python {{{{"
notes = get_class_constructor_notes(context)
assert notes == []
def test_init_with_only_self(self) -> None:
context = """
class Empty:
def __init__(self):
pass
"""
notes = get_class_constructor_notes(context)
assert len(notes) == 1
assert "Empty() - no parameters" in notes[0]
def test_mixed_dataclass_and_regular(self) -> None:
context = """
@dataclass
class Config:
name: str
class Service:
def __init__(self, config: Config, debug: bool = False):
self.config = config
self.debug = debug
"""
notes = get_class_constructor_notes(context)
assert len(notes) == 1
assert "Service" in notes[0]
assert "Config" not in notes[0].split("\n")[0]
def test_plain_code_without_markdown(self) -> None:
context = """
class MyClass:
def __init__(self, value: int):
self.value = value
"""
notes = get_class_constructor_notes(context)
assert len(notes) == 1
assert "MyClass" in notes[0]
def test_full_output_format(self) -> None:
context = """
class Server:
def __init__(self, host: str, port: int, debug: bool = False, workers: int = 4):
self.host = host
self.port = port
self.debug = debug
self.workers = workers
"""
notes = get_class_constructor_notes(context)
assert len(notes) == 1
expected = """Constructor signature for Server:
Required (positional) arguments:
- host: str
- port: int
Optional (keyword) arguments:
- debug: bool = False
- workers: int = 4"""
assert notes[0] == expected

View file

@ -1,529 +0,0 @@
"""Tests for dataclass constructor signature extraction."""
import libcst as cst
from core.languages.python.cst_utils import get_base_class_name, has_decorator
from aiservice.common.markdown_utils import extract_all_code_from_markdown
from core.languages.python.testgen.preprocessing.dataclass_constructor_notes import (
FieldInfo,
extract_dataclass_fields,
find_all_dataclasses,
format_constructor_signature,
get_all_fields_with_inheritance,
get_dataclass_constructor_notes,
)
class TestHasDecorator:
def test_simple_dataclass_decorator(self) -> None:
code = """
@dataclass
class Foo:
x: int
"""
tree = cst.parse_module(code)
class_node = tree.body[0]
assert isinstance(class_node, cst.ClassDef)
assert has_decorator(class_node, "dataclass")
def test_dataclass_with_call(self) -> None:
code = """
@dataclass(frozen=True)
class Foo:
x: int
"""
tree = cst.parse_module(code)
class_node = tree.body[0]
assert isinstance(class_node, cst.ClassDef)
assert has_decorator(class_node, "dataclass")
def test_not_a_dataclass(self) -> None:
code = """
class Foo:
x: int
"""
tree = cst.parse_module(code)
class_node = tree.body[0]
assert isinstance(class_node, cst.ClassDef)
assert not has_decorator(class_node, "dataclass")
def test_other_decorator(self) -> None:
code = """
@other_decorator
class Foo:
x: int
"""
tree = cst.parse_module(code)
class_node = tree.body[0]
assert isinstance(class_node, cst.ClassDef)
assert not has_decorator(class_node, "dataclass")
class TestGetBaseClassName:
def test_simple_name(self) -> None:
code = "class Foo(Bar): pass"
tree = cst.parse_module(code)
class_node = tree.body[0]
assert isinstance(class_node, cst.ClassDef)
assert get_base_class_name(class_node.bases[0]) == "Bar"
def test_attribute(self) -> None:
code = "class Foo(module.Bar): pass"
tree = cst.parse_module(code)
class_node = tree.body[0]
assert isinstance(class_node, cst.ClassDef)
assert get_base_class_name(class_node.bases[0]) == "Bar"
class TestExtractDataclassFields:
    """Tests for ``extract_dataclass_fields`` on classes found by ``find_all_dataclasses``."""

    def test_simple_fields(self) -> None:
        """Fields without defaults keep their name and annotation."""
        # Snippets start at column 0 with indented class bodies so they parse.
        code = """
@dataclass
class Foo:
    x: int
    y: str
"""
        dataclasses, source_lines, wrapper = find_all_dataclasses(code)
        assert wrapper is not None
        class_node = dataclasses["Foo"]
        fields = extract_dataclass_fields(class_node, source_lines, wrapper)
        assert len(fields) == 2
        assert fields[0].name == "x"
        assert fields[0].annotation == "int"
        assert not fields[0].has_default
        assert fields[1].name == "y"
        assert fields[1].annotation == "str"
        assert not fields[1].has_default

    def test_fields_with_defaults(self) -> None:
        """A field with a default exposes the default's source text."""
        code = """
@dataclass
class Foo:
    x: int
    y: str = "default"
"""
        dataclasses, source_lines, wrapper = find_all_dataclasses(code)
        assert wrapper is not None
        class_node = dataclasses["Foo"]
        fields = extract_dataclass_fields(class_node, source_lines, wrapper)
        assert len(fields) == 2
        assert fields[0].name == "x"
        assert not fields[0].has_default
        assert fields[1].name == "y"
        assert fields[1].has_default
        assert fields[1].default_repr == '"default"'

    def test_complex_annotation(self) -> None:
        """Subscripted and union annotations are carried through."""
        code = """
@dataclass
class Foo:
    items: list[str]
    mapping: dict[str, int] | None = None
"""
        dataclasses, source_lines, wrapper = find_all_dataclasses(code)
        assert wrapper is not None
        class_node = dataclasses["Foo"]
        fields = extract_dataclass_fields(class_node, source_lines, wrapper)
        assert len(fields) == 2
        assert fields[0].name == "items"
        assert fields[0].annotation == "list[str]"
        assert fields[1].name == "mapping"
        assert "dict[str, int]" in fields[1].annotation
class TestFindAllDataclasses:
    """Tests for ``find_all_dataclasses``, which returns (classes, source lines, wrapper)."""

    def test_finds_single_dataclass(self) -> None:
        """A single decorated class is found and source lines are returned."""
        # Snippets start at column 0 with indented class bodies so they parse.
        code = """
from dataclasses import dataclass
@dataclass
class Foo:
    x: int
"""
        result, source_lines, wrapper = find_all_dataclasses(code)
        assert "Foo" in result
        assert wrapper is not None
        assert len(source_lines) > 0

    def test_finds_multiple_dataclasses(self) -> None:
        """Every decorated class in the module is found."""
        code = """
from dataclasses import dataclass
@dataclass
class Foo:
    x: int
@dataclass
class Bar:
    y: str
"""
        result, _source_lines, wrapper = find_all_dataclasses(code)
        assert "Foo" in result
        assert "Bar" in result
        assert wrapper is not None

    def test_ignores_non_dataclasses(self) -> None:
        """Classes without the decorator are left out of the result."""
        code = """
class NotADataclass:
    x: int
@dataclass
class IsADataclass:
    y: str
"""
        result, _source_lines, wrapper = find_all_dataclasses(code)
        assert "NotADataclass" not in result
        assert "IsADataclass" in result
        assert wrapper is not None

    def test_handles_syntax_error(self) -> None:
        """Unparseable input yields an empty mapping and no wrapper."""
        code = "this is not valid python code {{{{"
        result, _source_lines, wrapper = find_all_dataclasses(code)
        assert result == {}
        assert wrapper is None
class TestGetAllFieldsWithInheritance:
    """Tests for ``get_all_fields_with_inheritance`` across dataclass hierarchies."""

    def test_no_inheritance(self) -> None:
        """Without a base class every field is marked ``own``."""
        # Snippets start at column 0 with indented class bodies so they parse.
        code = """
@dataclass
class Foo:
    x: int
    y: str
"""
        dataclasses, source_lines, wrapper = find_all_dataclasses(code)
        assert wrapper is not None
        fields = get_all_fields_with_inheritance("Foo", dataclasses, source_lines, wrapper)
        assert len(fields) == 2
        assert all(f.source == "own" for f in fields)

    def test_simple_inheritance(self) -> None:
        """Parent fields come first and are marked ``inherited``."""
        code = """
@dataclass
class Base:
    a: int
    b: str
@dataclass
class Child(Base):
    c: float
"""
        dataclasses, source_lines, wrapper = find_all_dataclasses(code)
        assert wrapper is not None
        fields = get_all_fields_with_inheritance("Child", dataclasses, source_lines, wrapper)
        assert len(fields) == 3
        assert fields[0].name == "a"
        assert fields[0].source == "inherited"
        assert fields[1].name == "b"
        assert fields[1].source == "inherited"
        assert fields[2].name == "c"
        assert fields[2].source == "own"

    def test_multi_level_inheritance(self) -> None:
        """Fields from every ancestor level are collected."""
        code = """
@dataclass
class GrandParent:
    a: int
@dataclass
class Parent(GrandParent):
    b: str
@dataclass
class Child(Parent):
    c: float
"""
        dataclasses, source_lines, wrapper = find_all_dataclasses(code)
        assert wrapper is not None
        fields = get_all_fields_with_inheritance("Child", dataclasses, source_lines, wrapper)
        assert len(fields) == 3
        field_names = [f.name for f in fields]
        assert "a" in field_names
        assert "b" in field_names
        assert "c" in field_names

    def test_unknown_class(self) -> None:
        """A class name absent from the mapping yields an empty list."""
        code = """
@dataclass
class Foo:
    x: int
"""
        dataclasses, source_lines, wrapper = find_all_dataclasses(code)
        assert wrapper is not None
        fields = get_all_fields_with_inheritance("Unknown", dataclasses, source_lines, wrapper)
        assert fields == []
class TestFormatConstructorSignature:
    """Tests for ``format_constructor_signature`` on hand-built ``FieldInfo`` lists."""

    def test_required_only(self) -> None:
        """With no defaults, only the required section is emitted."""
        fields = [
            FieldInfo(name="x", annotation="int", has_default=False, default_repr=None, source="own"),
            FieldInfo(name="y", annotation="str", has_default=False, default_repr=None, source="own"),
        ]
        result = format_constructor_signature("Foo", fields)
        assert "Constructor signature for Foo:" in result
        assert "Required (positional) arguments:" in result
        assert "- x: int" in result
        assert "- y: str" in result
        assert "Optional" not in result

    def test_optional_only(self) -> None:
        """With defaults everywhere, only the optional section is emitted."""
        fields = [
            FieldInfo(name="x", annotation="int", has_default=True, default_repr="0", source="own"),
            FieldInfo(name="y", annotation="str", has_default=True, default_repr='"default"', source="own"),
        ]
        result = format_constructor_signature("Foo", fields)
        assert "Constructor signature for Foo:" in result
        assert "Optional (keyword) arguments:" in result
        assert "- x: int = 0" in result
        assert '- y: str = "default"' in result
        assert "Required" not in result

    def test_mixed_fields(self) -> None:
        """Both sections appear and inherited fields are labelled."""
        fields = [
            FieldInfo(name="a", annotation="int", has_default=False, default_repr=None, source="inherited"),
            FieldInfo(name="b", annotation="str", has_default=False, default_repr=None, source="own"),
            FieldInfo(name="c", annotation="float", has_default=True, default_repr="0.0", source="own"),
        ]
        result = format_constructor_signature("Foo", fields)
        assert "Required (positional) arguments:" in result
        assert "Optional (keyword) arguments:" in result
        assert "(from parent class)" in result

    def test_no_fields(self) -> None:
        """An empty field list produces a "no fields" message."""
        result = format_constructor_signature("Empty", [])
        assert "no fields" in result
class TestExtractCodeFromMarkdown:
    """Tests for ``extract_all_code_from_markdown`` on fenced ``python`` blocks."""

    def test_simple_code_block(self) -> None:
        """Code inside a fence is returned; surrounding prose is dropped."""
        # Fences start at column 0; the code inside keeps its own indentation.
        markdown = """
Some text
```python
def foo():
    pass
```
More text
"""
        result = extract_all_code_from_markdown(markdown)
        assert "def foo():" in result
        assert "pass" in result
        assert "Some text" not in result

    def test_code_block_with_filepath(self) -> None:
        """A ``python:<path>`` fence is handled like a plain one."""
        markdown = """
```python:path/to/file.py
class Foo:
    x: int
```
"""
        result = extract_all_code_from_markdown(markdown)
        assert "class Foo:" in result
        assert "x: int" in result

    def test_multiple_code_blocks(self) -> None:
        """Content from every fenced block appears in the result."""
        markdown = """
```python:file1.py
@dataclass
class A:
    x: int
```
```python:file2.py
@dataclass
class B:
    y: str
```
"""
        result = extract_all_code_from_markdown(markdown)
        assert "class A:" in result
        assert "class B:" in result
class TestGetDataclassConstructorNotes:
    """End-to-end tests for ``get_dataclass_constructor_notes`` on context strings."""

    def test_simple_dataclass(self) -> None:
        """One dataclass in a fenced block yields one note listing its fields."""
        # Fences and decorators start at column 0; class bodies are indented so
        # the embedded code is valid Python.
        context = """
```python:models.py
from dataclasses import dataclass
@dataclass
class Config:
    name: str
    value: int
```
"""
        notes = get_dataclass_constructor_notes(context)
        assert len(notes) == 1
        assert "Constructor signature for Config:" in notes[0]
        assert "- name: str" in notes[0]
        assert "- value: int" in notes[0]

    def test_dataclass_with_inheritance(self) -> None:
        """A child's note lists parent fields, labelled as inherited."""
        context = """
```python:models.py
from dataclasses import dataclass
@dataclass
class BaseConfig:
    model_name: str
    required_env_vars: list[str]
@dataclass
class ExtendedConfig(BaseConfig):
    extra_param: int = 0
```
"""
        notes = get_dataclass_constructor_notes(context)
        assert len(notes) == 2
        # Find the note for ExtendedConfig
        extended_note = next(n for n in notes if "ExtendedConfig" in n)
        assert "- model_name: str (from parent class)" in extended_note
        assert "- required_env_vars: list[str] (from parent class)" in extended_note
        assert "- extra_param: int = 0" in extended_note

    def test_no_dataclasses(self) -> None:
        """Context without any dataclass produces no notes."""
        context = """
def regular_function():
    pass
"""
        notes = get_dataclass_constructor_notes(context)
        assert notes == []

    def test_plain_code_without_markdown(self) -> None:
        """Unfenced source code is still analysed."""
        context = """
from dataclasses import dataclass
@dataclass
class SimpleConfig:
    name: str
"""
        notes = get_dataclass_constructor_notes(context)
        assert len(notes) == 1
        assert "SimpleConfig" in notes[0]

    def test_llm_config_like_structure(self) -> None:
        """Test with a structure similar to the skyvern LLMConfig."""
        context = """
```python:skyvern/forge/sdk/api/llm/models.py
from dataclasses import dataclass, field
from typing import Optional
@dataclass(frozen=True)
class LLMConfigBase:
    model_name: str
    required_env_vars: list[str]
    supports_vision: bool
    add_assistant_prefix: bool
@dataclass(frozen=True)
class LLMConfig(LLMConfigBase):
    litellm_params: Optional[dict] = field(default=None)
    max_tokens: int | None = 4096
```
"""
        notes = get_dataclass_constructor_notes(context)
        # Should have notes for both classes
        assert len(notes) == 2
        # Find the LLMConfig note
        llm_config_note = next(n for n in notes if "LLMConfig:" in n)
        # Should show inherited fields from LLMConfigBase
        assert "- model_name: str (from parent class)" in llm_config_note
        assert "- required_env_vars: list[str] (from parent class)" in llm_config_note
        assert "- supports_vision: bool (from parent class)" in llm_config_note
        assert "- add_assistant_prefix: bool (from parent class)" in llm_config_note
        # Should also show own fields
        assert "- litellm_params:" in llm_config_note
        assert "- max_tokens:" in llm_config_note

    def test_full_output_for_complex_inheritance(self) -> None:
        """Test full generated notes output for complex dataclass inheritance."""
        context = """
```python:skyvern/forge/sdk/api/llm/models.py
from dataclasses import dataclass, field
from typing import Optional
@dataclass(frozen=True)
class LLMConfigBase:
    model_name: str
    required_env_vars: list[str]
    supports_vision: bool
    add_assistant_prefix: bool
@dataclass(frozen=True)
class LLMConfig(LLMConfigBase):
    litellm_params: Optional[dict] = field(default=None)
    max_tokens: int | None = 4096
    max_completion_tokens: int | None = None
    temperature: float | None = 0.7
    reasoning_effort: str | None = None
@dataclass(frozen=True)
class LLMRouterConfig(LLMConfigBase):
    model_list: list
    main_model_group: str
    redis_host: str | None = None
    redis_port: int | None = None
    fallback_model_group: str | None = None
    routing_strategy: str = "usage-based-routing"
    num_retries: int = 1
```
"""
        notes = get_dataclass_constructor_notes(context)
        # Should have notes for all 3 classes
        assert len(notes) == 3
        # NOTE(review): the expected strings below are reproduced exactly as
        # found, with no leading whitespace on their continuation lines. If
        # format_constructor_signature indents section headers or bullet lines,
        # that indentation must be restored here - confirm against its output.
        # Verify LLMConfigBase note (no inheritance)
        base_note = next(n for n in notes if "LLMConfigBase:" in n)
        expected_base = """Constructor signature for LLMConfigBase:
Required (positional) arguments:
- model_name: str
- required_env_vars: list[str]
- supports_vision: bool
- add_assistant_prefix: bool"""
        assert base_note == expected_base
        # Verify LLMConfig note (with inheritance)
        config_note = next(n for n in notes if "LLMConfig:" in n)
        expected_config = """Constructor signature for LLMConfig:
Required (positional) arguments:
- model_name: str (from parent class)
- required_env_vars: list[str] (from parent class)
- supports_vision: bool (from parent class)
- add_assistant_prefix: bool (from parent class)
Optional (keyword) arguments:
- litellm_params: Optional[dict] = field(default=None)
- max_tokens: int | None = 4096
- max_completion_tokens: int | None = None
- temperature: float | None = 0.7
- reasoning_effort: str | None = None"""
        assert config_note == expected_config
        # Verify LLMRouterConfig note (with inheritance + own required fields)
        router_note = next(n for n in notes if "LLMRouterConfig:" in n)
        expected_router = """Constructor signature for LLMRouterConfig:
Required (positional) arguments:
- model_name: str (from parent class)
- required_env_vars: list[str] (from parent class)
- supports_vision: bool (from parent class)
- add_assistant_prefix: bool (from parent class)
- model_list: list
- main_model_group: str
Optional (keyword) arguments:
- redis_host: str | None = None
- redis_port: int | None = None
- fallback_model_group: str | None = None
- routing_strategy: str = "usage-based-routing"
- num_retries: int = 1"""
        assert router_note == expected_router

View file

@ -8,5 +8,5 @@ NPM_TOKEN
SCM_DO_BUILD_DURING_DEPLOYMENT
WEBSITE_HEALTHCHECK_MAXPINGFAILURES
WEBSITE_HTTPLOGGING_RETENTION_DAYS
CODEFLASH_INTERNAL_REPO_PATH=
CODEFLASH_CLI_REPO_PATH=
AISERVICE_DIR=
CODEFLASH_DIR=

View file

@ -168,8 +168,8 @@ export function indexTraceData(traceData: TraceData): IndexedTraceData {
// --- Codebase browsing helpers ---
// Resolve a repository identifier to a local path read from the environment.
// Returns null when `repo` is neither "codeflash-internal" nor "codeflash",
// or when the matching environment variable is unset/empty (`|| null`).
// NOTE(review): this text appears to come from a rendered diff, so both the
// CODEFLASH_INTERNAL_REPO_PATH / CODEFLASH_CLI_REPO_PATH checks and the
// AISERVICE_DIR / CODEFLASH_DIR checks are present. As written, the second
// pair can never run because the first pair always returns - confirm which
// pair is intended.
function getRepoRoot(repo: string): string | null {
if (repo === "codeflash-internal") return process.env.CODEFLASH_INTERNAL_REPO_PATH || null
if (repo === "codeflash") return process.env.CODEFLASH_CLI_REPO_PATH || null
if (repo === "codeflash-internal") return process.env.AISERVICE_DIR || null
if (repo === "codeflash") return process.env.CODEFLASH_DIR || null
return null
}
@ -630,7 +630,7 @@ export async function resolveToolCall(
case "read_file": {
const repoRoot = getRepoRoot(args.repo as string)
if (!repoRoot) return `Repository path not configured. Set CODEFLASH_INTERNAL_REPO_PATH or CODEFLASH_CLI_REPO_PATH env var.`
if (!repoRoot) return `Repository path not configured. Set AISERVICE_DIR or CODEFLASH_DIR env var.`
const pathResult = resolveAndValidatePath(repoRoot, args.path as string)
if ("error" in pathResult) return pathResult.error
@ -657,7 +657,7 @@ export async function resolveToolCall(
case "search_code": {
const repoRoot = getRepoRoot(args.repo as string)
if (!repoRoot) return `Repository path not configured. Set CODEFLASH_INTERNAL_REPO_PATH or CODEFLASH_CLI_REPO_PATH env var.`
if (!repoRoot) return `Repository path not configured. Set AISERVICE_DIR or CODEFLASH_DIR env var.`
const maxResults = Math.min(Math.max(1, (args.max_results as number) || 30), 100)
const rgArgs = [
@ -699,7 +699,7 @@ export async function resolveToolCall(
case "list_directory": {
const repoRoot = getRepoRoot(args.repo as string)
if (!repoRoot) return `Repository path not configured. Set CODEFLASH_INTERNAL_REPO_PATH or CODEFLASH_CLI_REPO_PATH env var.`
if (!repoRoot) return `Repository path not configured. Set AISERVICE_DIR or CODEFLASH_DIR env var.`
const relativePath = (args.path as string) || "."
const pathResult = resolveAndValidatePath(repoRoot, relativePath)