mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
Merge branch 'main' into react_omni
This commit is contained in:
commit
ef0fd4b2ab
16 changed files with 167 additions and 1324 deletions
2
.github/workflows/claude.yml
vendored
2
.github/workflows/claude.yml
vendored
|
|
@ -168,6 +168,8 @@ jobs:
|
|||
env:
|
||||
ANTHROPIC_FOUNDRY_API_KEY: ${{ secrets.ANTHROPIC_FOUNDRY_API_KEY }}
|
||||
ANTHROPIC_FOUNDRY_BASE_URL: ${{ secrets.ANTHROPIC_FOUNDRY_BASE_URL }}
|
||||
DATABASE_URL: ${{ secrets.DATABASE_URL }}
|
||||
DJANGO_SETTINGS_MODULE: aiservice.settings
|
||||
|
||||
# @claude mentions (can edit and push)
|
||||
claude-mention:
|
||||
|
|
|
|||
|
|
@ -13,8 +13,10 @@ from aiservice.common.llm_output_utils import truncate_pathological_output
|
|||
# Matches both ```python and ```python:filepath blocks, captures content only
|
||||
MARKDOWN_CODE_BLOCK_PATTERN = re.compile(r"```python(?::[^\n]*)?\n(.*?)```", re.DOTALL)
|
||||
|
||||
# Matches first ```python block (no filepath), captures content
|
||||
FIRST_CODE_BLOCK_PATTERN = re.compile(r"^```python\s*\n(.*?)\n```", re.MULTILINE | re.DOTALL)
|
||||
# Matches first ```python block (no filepath), captures content.
|
||||
# Uses greedy (.*) to handle LLM outputs with nested code fences (e.g. ```python:filepath
|
||||
# blocks inside the main block). Requires closing ``` to be alone on its line.
|
||||
FIRST_CODE_BLOCK_PATTERN = re.compile(r"^```python\s*\n(.*)\n```[ \t]*$", re.MULTILINE | re.DOTALL)
|
||||
|
||||
# Fallback for incomplete code blocks (missing closing ```)
|
||||
FIRST_CODE_BLOCK_FALLBACK_PATTERN = re.compile(r"^```python\s*\n(.*)", re.MULTILINE | re.DOTALL)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import uuid
|
||||
|
||||
import isort
|
||||
|
|
@ -29,10 +30,13 @@ def parse_python_version(version: str | None) -> tuple[int, int, int]:
|
|||
split_version = version.split(".")
|
||||
if len(split_version) != 3:
|
||||
raise ValueError("Invalid version format")
|
||||
major, minor, patch = int(split_version[0]), int(split_version[1]), int(split_version[2])
|
||||
patch_str = re.match(r"\d+", split_version[2])
|
||||
if not patch_str:
|
||||
raise ValueError("Invalid patch version")
|
||||
major, minor, patch = int(split_version[0]), int(split_version[1]), int(patch_str.group())
|
||||
assert major == 3, "Only Python 3 is supported"
|
||||
assert minor >= 9, "Only Python 3.9 and above is supported"
|
||||
assert minor <= 14, "Unsupported Python version"
|
||||
assert minor <= 15, "Unsupported Python version"
|
||||
assert patch >= 0, "Only Python 3.9 and above is supported"
|
||||
assert patch < 100, "Invalid version format"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,165 +0,0 @@
|
|||
"""Preprocessing step to extract constructor signatures from regular (non-dataclass) classes.
|
||||
|
||||
When the test context contains class definitions with explicit __init__ methods,
|
||||
this module extracts the constructor signatures to help the LLM understand how
|
||||
to properly instantiate them.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
import libcst as cst
|
||||
from libcst.metadata import MetadataWrapper
|
||||
|
||||
from aiservice.common.markdown_utils import extract_all_code_from_markdown
|
||||
from core.languages.python.cst_utils import get_node_source_text, has_decorator, parse_module_to_cst
|
||||
|
||||
|
||||
class ParamInfo(NamedTuple):
|
||||
name: str
|
||||
annotation: str | None
|
||||
has_default: bool
|
||||
default_repr: str | None
|
||||
|
||||
|
||||
def _find_init_method(class_node: cst.ClassDef) -> cst.FunctionDef | None:
|
||||
for stmt in class_node.body.body:
|
||||
if isinstance(stmt, cst.FunctionDef) and stmt.name.value == "__init__":
|
||||
return stmt
|
||||
return None
|
||||
|
||||
|
||||
def extract_init_params(
|
||||
init_node: cst.FunctionDef, source_lines: list[str], wrapper: MetadataWrapper
|
||||
) -> list[ParamInfo]:
|
||||
params: list[ParamInfo] = []
|
||||
parameters = init_node.params
|
||||
|
||||
for param in parameters.params:
|
||||
name = param.name.value
|
||||
if name == "self":
|
||||
continue
|
||||
|
||||
annotation = None
|
||||
if param.annotation is not None:
|
||||
annotation = get_node_source_text(param.annotation, source_lines, wrapper)
|
||||
|
||||
default_repr = None
|
||||
if param.default is not None:
|
||||
default_repr = get_node_source_text(param.default, source_lines, wrapper)
|
||||
if len(default_repr) > 50:
|
||||
default_repr = default_repr[:47] + "..."
|
||||
|
||||
params.append(ParamInfo(name, annotation, has_default=param.default is not None, default_repr=default_repr))
|
||||
|
||||
if isinstance(parameters.star_arg, cst.Param):
|
||||
name = f"*{parameters.star_arg.name.value}"
|
||||
annotation = None
|
||||
if parameters.star_arg.annotation is not None:
|
||||
annotation = get_node_source_text(parameters.star_arg.annotation, source_lines, wrapper)
|
||||
params.append(ParamInfo(name, annotation, has_default=False, default_repr=None))
|
||||
|
||||
for param in parameters.kwonly_params:
|
||||
name = param.name.value
|
||||
annotation = None
|
||||
if param.annotation is not None:
|
||||
annotation = get_node_source_text(param.annotation, source_lines, wrapper)
|
||||
|
||||
default_repr = None
|
||||
if param.default is not None:
|
||||
default_repr = get_node_source_text(param.default, source_lines, wrapper)
|
||||
if len(default_repr) > 50:
|
||||
default_repr = default_repr[:47] + "..."
|
||||
|
||||
params.append(ParamInfo(name, annotation, has_default=param.default is not None, default_repr=default_repr))
|
||||
|
||||
if parameters.star_kwarg is not None:
|
||||
name = f"**{parameters.star_kwarg.name.value}"
|
||||
annotation = None
|
||||
if parameters.star_kwarg.annotation is not None:
|
||||
annotation = get_node_source_text(parameters.star_kwarg.annotation, source_lines, wrapper)
|
||||
params.append(ParamInfo(name, annotation, has_default=False, default_repr=None))
|
||||
|
||||
return params
|
||||
|
||||
|
||||
class ClassInitCollector(cst.CSTVisitor):
|
||||
def __init__(self) -> None:
|
||||
self.classes: dict[str, cst.ClassDef] = {}
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
|
||||
if not has_decorator(node, "dataclass") and _find_init_method(node) is not None:
|
||||
self.classes[node.name.value] = node
|
||||
return True
|
||||
|
||||
|
||||
def _find_all_classes_with_init(source_code: str) -> tuple[dict[str, cst.ClassDef], list[str], MetadataWrapper | None]:
|
||||
try:
|
||||
tree = parse_module_to_cst(source_code)
|
||||
except cst.ParserSyntaxError:
|
||||
return {}, [], None
|
||||
|
||||
wrapper = MetadataWrapper(tree)
|
||||
collector = ClassInitCollector()
|
||||
wrapper.visit(collector)
|
||||
source_lines = source_code.split("\n")
|
||||
return collector.classes, source_lines, wrapper
|
||||
|
||||
|
||||
def format_init_signature(class_name: str, params: list[ParamInfo]) -> str:
|
||||
if not params:
|
||||
return f"{class_name}() - no parameters"
|
||||
|
||||
required = [p for p in params if not p.has_default and not p.name.startswith("*")]
|
||||
optional = [p for p in params if p.has_default]
|
||||
variadic = [p for p in params if p.name.startswith("*")]
|
||||
|
||||
lines = [f"Constructor signature for {class_name}:"]
|
||||
|
||||
if required:
|
||||
lines.append(" Required (positional) arguments:")
|
||||
for p in required:
|
||||
if p.annotation:
|
||||
lines.append(f" - {p.name}: {p.annotation}")
|
||||
else:
|
||||
lines.append(f" - {p.name}")
|
||||
|
||||
if optional:
|
||||
lines.append(" Optional (keyword) arguments:")
|
||||
for p in optional:
|
||||
default_note = f" = {p.default_repr}" if p.default_repr else " = ..."
|
||||
if p.annotation:
|
||||
lines.append(f" - {p.name}: {p.annotation}{default_note}")
|
||||
else:
|
||||
lines.append(f" - {p.name}{default_note}")
|
||||
|
||||
if variadic:
|
||||
lines.append(" Variadic arguments:")
|
||||
for p in variadic:
|
||||
if p.annotation:
|
||||
lines.append(f" - {p.name}: {p.annotation}")
|
||||
else:
|
||||
lines.append(f" - {p.name}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def get_class_constructor_notes(test_context: str) -> list[str]:
|
||||
code_to_analyze = extract_all_code_from_markdown(test_context) if "```python" in test_context else test_context
|
||||
|
||||
classes, source_lines, wrapper = _find_all_classes_with_init(code_to_analyze)
|
||||
|
||||
if not classes or wrapper is None:
|
||||
return []
|
||||
|
||||
notes = []
|
||||
for class_name, class_node in classes.items():
|
||||
init_node = _find_init_method(class_node)
|
||||
if init_node is None:
|
||||
continue
|
||||
params = extract_init_params(init_node, source_lines, wrapper)
|
||||
signature = format_init_signature(class_name, params)
|
||||
notes.append(signature)
|
||||
|
||||
return notes
|
||||
|
|
@ -1,176 +0,0 @@
|
|||
"""Preprocessing step to extract dataclass constructor signatures.
|
||||
|
||||
When the test context contains dataclass definitions, this module extracts
|
||||
the constructor signatures to help the LLM understand how to properly
|
||||
instantiate them, including inherited fields from parent dataclasses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
import libcst as cst
|
||||
from libcst.metadata import MetadataWrapper
|
||||
|
||||
from core.languages.python.cst_utils import (
|
||||
get_base_class_name,
|
||||
get_node_source_text,
|
||||
has_decorator,
|
||||
parse_module_to_cst,
|
||||
)
|
||||
from aiservice.common.markdown_utils import extract_all_code_from_markdown
|
||||
|
||||
|
||||
class FieldInfo(NamedTuple):
|
||||
"""Information about a dataclass field."""
|
||||
|
||||
name: str
|
||||
annotation: str
|
||||
has_default: bool
|
||||
default_repr: str | None
|
||||
source: str # "inherited" or "own"
|
||||
|
||||
|
||||
def extract_dataclass_fields(
|
||||
class_node: cst.ClassDef, source_lines: list[str], wrapper: MetadataWrapper
|
||||
) -> list[FieldInfo]:
|
||||
"""Extract fields from a dataclass definition.
|
||||
|
||||
Returns list of FieldInfo with (field_name, type_annotation, has_default, default_value).
|
||||
"""
|
||||
fields: list[FieldInfo] = []
|
||||
for stmt in class_node.body.body:
|
||||
if not isinstance(stmt, cst.SimpleStatementLine):
|
||||
continue
|
||||
for node in stmt.body:
|
||||
if not (isinstance(node, cst.AnnAssign) and isinstance(node.target, cst.Name)):
|
||||
continue
|
||||
|
||||
field_name = node.target.value
|
||||
annotation = get_node_source_text(node.annotation, source_lines, wrapper)
|
||||
|
||||
default_repr = None
|
||||
if node.value is not None:
|
||||
default_repr = get_node_source_text(node.value, source_lines, wrapper)
|
||||
# Truncate long defaults
|
||||
if len(default_repr) > 50:
|
||||
default_repr = default_repr[:47] + "..."
|
||||
|
||||
fields.append(FieldInfo(field_name, annotation, node.value is not None, default_repr, "own"))
|
||||
return fields
|
||||
|
||||
|
||||
class DataclassCollector(cst.CSTVisitor):
|
||||
"""Visitor to collect all dataclass definitions."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.dataclasses: dict[str, cst.ClassDef] = {}
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
|
||||
if has_decorator(node, "dataclass"):
|
||||
self.dataclasses[node.name.value] = node
|
||||
return True # Continue visiting nested classes
|
||||
|
||||
|
||||
def find_all_dataclasses(source_code: str) -> tuple[dict[str, cst.ClassDef], list[str], MetadataWrapper | None]:
|
||||
"""Find all dataclass definitions in source code.
|
||||
|
||||
Returns tuple of (dict mapping class name to class_node, source_lines, metadata_wrapper).
|
||||
"""
|
||||
try:
|
||||
tree = parse_module_to_cst(source_code)
|
||||
except cst.ParserSyntaxError:
|
||||
return {}, [], None
|
||||
|
||||
wrapper = MetadataWrapper(tree)
|
||||
collector = DataclassCollector()
|
||||
wrapper.visit(collector)
|
||||
source_lines = source_code.split("\n")
|
||||
return collector.dataclasses, source_lines, wrapper
|
||||
|
||||
|
||||
def get_all_fields_with_inheritance(
|
||||
class_name: str, dataclasses: dict[str, cst.ClassDef], source_lines: list[str], wrapper: MetadataWrapper
|
||||
) -> list[FieldInfo]:
|
||||
"""Get all fields for a dataclass, including inherited ones.
|
||||
|
||||
Fields are returned in order: inherited first, then own fields.
|
||||
"""
|
||||
if class_name not in dataclasses:
|
||||
return []
|
||||
|
||||
class_node = dataclasses[class_name]
|
||||
|
||||
# First, collect inherited fields from parent dataclasses
|
||||
inherited_fields: list[FieldInfo] = []
|
||||
if class_node.bases:
|
||||
for base in class_node.bases:
|
||||
base_name = get_base_class_name(base)
|
||||
if base_name and base_name in dataclasses:
|
||||
parent_fields = get_all_fields_with_inheritance(base_name, dataclasses, source_lines, wrapper)
|
||||
# Mark these as inherited
|
||||
inherited_fields.extend(
|
||||
FieldInfo(field.name, field.annotation, field.has_default, field.default_repr, "inherited")
|
||||
for field in parent_fields
|
||||
)
|
||||
|
||||
# Then collect own fields
|
||||
own_fields = extract_dataclass_fields(class_node, source_lines, wrapper)
|
||||
|
||||
return inherited_fields + own_fields
|
||||
|
||||
|
||||
def format_constructor_signature(class_name: str, fields: list[FieldInfo]) -> str:
|
||||
"""Format a constructor signature for documentation."""
|
||||
if not fields:
|
||||
return f"{class_name}() - no fields"
|
||||
|
||||
# Separate required and optional fields
|
||||
required = [f for f in fields if not f.has_default]
|
||||
optional = [f for f in fields if f.has_default]
|
||||
|
||||
lines = [f"Constructor signature for {class_name}:"]
|
||||
|
||||
if required:
|
||||
lines.append(" Required (positional) arguments:")
|
||||
for f in required:
|
||||
source_note = " (from parent class)" if f.source == "inherited" else ""
|
||||
lines.append(f" - {f.name}: {f.annotation}{source_note}")
|
||||
|
||||
if optional:
|
||||
lines.append(" Optional (keyword) arguments:")
|
||||
for f in optional:
|
||||
default_note = f" = {f.default_repr}" if f.default_repr else " = ..."
|
||||
source_note = " (from parent class)" if f.source == "inherited" else ""
|
||||
lines.append(f" - {f.name}: {f.annotation}{default_note}{source_note}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def get_dataclass_constructor_notes(test_context: str) -> list[str]:
|
||||
"""Generate notes about dataclass constructor signatures.
|
||||
|
||||
Analyzes the test context for dataclass definitions and generates
|
||||
explicit notes about their constructor signatures, including inherited fields.
|
||||
|
||||
Args:
|
||||
test_context: The source code context to analyze for dataclass definitions.
|
||||
|
||||
"""
|
||||
# Extract actual code from markdown if present
|
||||
code_to_analyze = extract_all_code_from_markdown(test_context) if "```python" in test_context else test_context
|
||||
|
||||
# Find all dataclasses in the context
|
||||
dataclasses, source_lines, wrapper = find_all_dataclasses(code_to_analyze)
|
||||
|
||||
if not dataclasses or wrapper is None:
|
||||
return []
|
||||
|
||||
notes = []
|
||||
for class_name in dataclasses:
|
||||
fields = get_all_fields_with_inheritance(class_name, dataclasses, source_lines, wrapper)
|
||||
if fields:
|
||||
signature = format_constructor_signature(class_name, fields)
|
||||
notes.append(signature)
|
||||
|
||||
return notes
|
||||
|
|
@ -1,12 +1,10 @@
|
|||
from itertools import chain
|
||||
|
||||
from core.languages.python.testgen.preprocessing.class_constructor_notes import get_class_constructor_notes
|
||||
from core.languages.python.testgen.preprocessing.dataclass_constructor_notes import get_dataclass_constructor_notes
|
||||
from core.languages.python.testgen.preprocessing.torch_tensor_limit import get_tensor_size_note
|
||||
|
||||
# Preprocessing functions that analyze code context and return notes
|
||||
# Each function takes test_context (str) and returns a list of notes (list[str])
|
||||
_PREPROCESSING_FUNCTIONS = [get_tensor_size_note, get_dataclass_constructor_notes, get_class_constructor_notes]
|
||||
_PREPROCESSING_FUNCTIONS = [get_tensor_size_note]
|
||||
|
||||
|
||||
def preprocessing_testgen_pipeline(test_context: str) -> list[str]:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import sentry_sdk
|
||||
from asgiref.sync import sync_to_async
|
||||
|
|
@ -17,6 +17,15 @@ if TYPE_CHECKING:
|
|||
features_api = NinjaAPI(urls_namespace="log_features")
|
||||
|
||||
|
||||
def _positional_list(items: list[str] | None, index: int | None) -> list[str | None] | None:
|
||||
"""Place a single item at the correct index, padding with None."""
|
||||
if not items or index is None or len(items) != 1:
|
||||
return cast("list[str | None] | None", items)
|
||||
result: list[str | None] = [None] * (index + 1)
|
||||
result[index] = items[0]
|
||||
return result
|
||||
|
||||
|
||||
@sync_to_async
|
||||
@transaction.atomic
|
||||
def log_features(
|
||||
|
|
@ -89,9 +98,9 @@ def log_features(
|
|||
"optimized_runtime": optimized_runtime,
|
||||
"optimized_line_profiler_results": optimized_line_profiler_results,
|
||||
"is_correct": is_correct,
|
||||
"generated_test": generated_tests,
|
||||
"instrumented_generated_test": instrumented_generated_tests,
|
||||
"instrumented_perf_test": instrumented_perf_tests,
|
||||
"generated_test": _positional_list(generated_tests, test_index),
|
||||
"instrumented_generated_test": _positional_list(instrumented_generated_tests, test_index),
|
||||
"instrumented_perf_test": _positional_list(instrumented_perf_tests, test_index),
|
||||
"test_framework": test_framework,
|
||||
"created_at": datetime,
|
||||
"aiservice_commit_id": aiservice_commit,
|
||||
|
|
|
|||
|
|
@ -85,6 +85,7 @@ You are also provided with the following information.
|
|||
- Introduction of the `global` and `nonlocal` keywords in optimizations is **HIGHLY DISCOURAGED** as it reduces code clarity and maintainability, introduces hidden dependencies, can cause subtle bugs and breaks modularity. **DO NOT** prefer such optimizations.
|
||||
- Replacement of `isinstance()` checks with `type()` checks is **HIGHLY DISCOURAGED** as `isinstance()` correctly handles inheritance and subclasses, while `type()` checks are incorrect for subclass instances and represent a micro-optimization that should be avoided. Do not prefer such optimizations.
|
||||
- If the only optimizations are micro-optimizations like inlining a function call, or localizing variables or methods (attribute lookup optimizations), do not prefer the optimizations. The performance improvements are minimal and come at a substantial cost to readability.
|
||||
- Local variable caching of globals (e.g., `local_var = GLOBAL_VAR` before a loop) is a micro-optimization only relevant on Python ≤3.10. On Python 3.11+ `LOAD_GLOBAL` uses adaptive specialization and is nearly as fast as `LOAD_FAST`, so the benefit is negligible — prefer the simpler code WITHOUT the local cache.
|
||||
|
||||
## Response Format
|
||||
|
||||
|
|
@ -383,9 +384,10 @@ async def rank_optimizations( # noqa: D417
|
|||
json_response = _parse_json_response(output.content, num_candidates)
|
||||
if json_response is not None:
|
||||
logging.info("Successfully parsed JSON response")
|
||||
return RankResponseSchema(
|
||||
ranking=json_response.ranking, explanation=json_response.explanation, scores=json_response.scores
|
||||
ranking = (
|
||||
_scores_to_ranking(json_response.scores) if json_response.scores is not None else json_response.ranking
|
||||
)
|
||||
return RankResponseSchema(ranking=ranking, explanation=json_response.explanation, scores=json_response.scores)
|
||||
|
||||
# Fall back to regex parsing (legacy XML-tag format)
|
||||
logging.info("JSON parsing failed, falling back to regex")
|
||||
|
|
|
|||
|
|
@ -181,6 +181,23 @@ x ="""
|
|||
assert result == expected
|
||||
|
||||
|
||||
def test_extract_code_block_nested_code_fence_in_triple_quote() -> None:
|
||||
# LLM embeds function definition in a triple-quoted string containing ```
|
||||
text = '```python\nimport pytest\n_source = """```python:file.py\ndef foo(): pass\n```"""\ndef test_foo():\n assert True\n```'
|
||||
result = extract_code_block(text)
|
||||
assert (
|
||||
result
|
||||
== 'import pytest\n_source = """```python:file.py\ndef foo(): pass\n```"""\ndef test_foo():\n assert True'
|
||||
)
|
||||
|
||||
|
||||
def test_extract_code_block_nested_code_fence_block() -> None:
|
||||
# LLM nests a ```python:filepath block inside the main block
|
||||
text = "```python\nimport pytest\n```python:src/mod.py\ndef foo(): pass\n```\ndef test_foo():\n assert True\n```"
|
||||
result = extract_code_block(text)
|
||||
assert result == "import pytest\n```python:src/mod.py\ndef foo(): pass\n```\ndef test_foo():\n assert True"
|
||||
|
||||
|
||||
def test_extract_all_code_single_block() -> None:
|
||||
text = "```python\ncode1\n```"
|
||||
result = extract_all_code_from_markdown(text)
|
||||
|
|
|
|||
|
|
@ -1,418 +0,0 @@
|
|||
"""Tests for regular class constructor signature extraction."""
|
||||
|
||||
import libcst as cst
|
||||
|
||||
from core.languages.python.testgen.preprocessing.class_constructor_notes import (
|
||||
ParamInfo,
|
||||
_find_all_classes_with_init,
|
||||
extract_init_params,
|
||||
format_init_signature,
|
||||
get_class_constructor_notes,
|
||||
)
|
||||
|
||||
|
||||
class TestClassInitCollector:
|
||||
def test_collects_class_with_init(self) -> None:
|
||||
code = """
|
||||
class Foo:
|
||||
def __init__(self, x: int):
|
||||
self.x = x
|
||||
"""
|
||||
classes, _, wrapper = _find_all_classes_with_init(code)
|
||||
assert "Foo" in classes
|
||||
assert wrapper is not None
|
||||
|
||||
def test_skips_class_without_init(self) -> None:
|
||||
code = """
|
||||
class Foo:
|
||||
def bar(self):
|
||||
pass
|
||||
"""
|
||||
classes, _, _ = _find_all_classes_with_init(code)
|
||||
assert "Foo" not in classes
|
||||
|
||||
def test_skips_dataclass(self) -> None:
|
||||
code = """
|
||||
@dataclass
|
||||
class Foo:
|
||||
x: int
|
||||
def __init__(self, x: int):
|
||||
self.x = x
|
||||
"""
|
||||
classes, _, _ = _find_all_classes_with_init(code)
|
||||
assert "Foo" not in classes
|
||||
|
||||
def test_skips_dataclass_with_call(self) -> None:
|
||||
code = """
|
||||
@dataclass(frozen=True)
|
||||
class Foo:
|
||||
x: int
|
||||
def __init__(self, x: int):
|
||||
self.x = x
|
||||
"""
|
||||
classes, _, _ = _find_all_classes_with_init(code)
|
||||
assert "Foo" not in classes
|
||||
|
||||
def test_collects_multiple_classes(self) -> None:
|
||||
code = """
|
||||
class Foo:
|
||||
def __init__(self, x: int):
|
||||
self.x = x
|
||||
|
||||
class Bar:
|
||||
def __init__(self, y: str):
|
||||
self.y = y
|
||||
"""
|
||||
classes, _, _ = _find_all_classes_with_init(code)
|
||||
assert "Foo" in classes
|
||||
assert "Bar" in classes
|
||||
|
||||
def test_handles_syntax_error(self) -> None:
|
||||
code = "this is not valid python code {{{{"
|
||||
classes, _, wrapper = _find_all_classes_with_init(code)
|
||||
assert classes == {}
|
||||
assert wrapper is None
|
||||
|
||||
|
||||
class TestExtractInitParams:
|
||||
def test_typed_params(self) -> None:
|
||||
code = """
|
||||
class Foo:
|
||||
def __init__(self, x: int, y: str):
|
||||
pass
|
||||
"""
|
||||
classes, source_lines, wrapper = _find_all_classes_with_init(code)
|
||||
assert wrapper is not None
|
||||
init_node = next(iter(classes["Foo"].body.body))
|
||||
assert isinstance(init_node, cst.FunctionDef)
|
||||
params = extract_init_params(init_node, source_lines, wrapper)
|
||||
|
||||
assert len(params) == 2
|
||||
assert params[0].name == "x"
|
||||
assert params[0].annotation == "int"
|
||||
assert not params[0].has_default
|
||||
assert params[1].name == "y"
|
||||
assert params[1].annotation == "str"
|
||||
assert not params[1].has_default
|
||||
|
||||
def test_params_with_defaults(self) -> None:
|
||||
code = """
|
||||
class Foo:
|
||||
def __init__(self, x: int = 0, y: str = "hello"):
|
||||
pass
|
||||
"""
|
||||
classes, source_lines, wrapper = _find_all_classes_with_init(code)
|
||||
assert wrapper is not None
|
||||
init_node = next(iter(classes["Foo"].body.body))
|
||||
assert isinstance(init_node, cst.FunctionDef)
|
||||
params = extract_init_params(init_node, source_lines, wrapper)
|
||||
|
||||
assert len(params) == 2
|
||||
assert params[0].name == "x"
|
||||
assert params[0].has_default
|
||||
assert params[0].default_repr == "0"
|
||||
assert params[1].name == "y"
|
||||
assert params[1].has_default
|
||||
assert params[1].default_repr == '"hello"'
|
||||
|
||||
def test_args_and_kwargs(self) -> None:
|
||||
code = """
|
||||
class Foo:
|
||||
def __init__(self, x: int, *args, **kwargs):
|
||||
pass
|
||||
"""
|
||||
classes, source_lines, wrapper = _find_all_classes_with_init(code)
|
||||
assert wrapper is not None
|
||||
init_node = next(iter(classes["Foo"].body.body))
|
||||
assert isinstance(init_node, cst.FunctionDef)
|
||||
params = extract_init_params(init_node, source_lines, wrapper)
|
||||
|
||||
assert len(params) == 3
|
||||
assert params[0].name == "x"
|
||||
assert params[1].name == "*args"
|
||||
assert params[2].name == "**kwargs"
|
||||
|
||||
def test_skips_self(self) -> None:
|
||||
code = """
|
||||
class Foo:
|
||||
def __init__(self):
|
||||
pass
|
||||
"""
|
||||
classes, source_lines, wrapper = _find_all_classes_with_init(code)
|
||||
assert wrapper is not None
|
||||
init_node = next(iter(classes["Foo"].body.body))
|
||||
assert isinstance(init_node, cst.FunctionDef)
|
||||
params = extract_init_params(init_node, source_lines, wrapper)
|
||||
|
||||
assert len(params) == 0
|
||||
|
||||
def test_untyped_params(self) -> None:
|
||||
code = """
|
||||
class Foo:
|
||||
def __init__(self, x, y=10):
|
||||
pass
|
||||
"""
|
||||
classes, source_lines, wrapper = _find_all_classes_with_init(code)
|
||||
assert wrapper is not None
|
||||
init_node = next(iter(classes["Foo"].body.body))
|
||||
assert isinstance(init_node, cst.FunctionDef)
|
||||
params = extract_init_params(init_node, source_lines, wrapper)
|
||||
|
||||
assert len(params) == 2
|
||||
assert params[0].name == "x"
|
||||
assert params[0].annotation is None
|
||||
assert params[1].name == "y"
|
||||
assert params[1].annotation is None
|
||||
assert params[1].has_default
|
||||
assert params[1].default_repr == "10"
|
||||
|
||||
def test_kwonly_params(self) -> None:
|
||||
code = """
|
||||
class Foo:
|
||||
def __init__(self, x: int, *, key: str = "default"):
|
||||
pass
|
||||
"""
|
||||
classes, source_lines, wrapper = _find_all_classes_with_init(code)
|
||||
assert wrapper is not None
|
||||
init_node = next(iter(classes["Foo"].body.body))
|
||||
assert isinstance(init_node, cst.FunctionDef)
|
||||
params = extract_init_params(init_node, source_lines, wrapper)
|
||||
|
||||
assert len(params) == 2
|
||||
assert params[0].name == "x"
|
||||
assert params[1].name == "key"
|
||||
assert params[1].annotation == "str"
|
||||
assert params[1].has_default
|
||||
assert params[1].default_repr == '"default"'
|
||||
|
||||
def test_long_default_truncated(self) -> None:
|
||||
code = """
|
||||
class Foo:
|
||||
def __init__(self, x: list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]):
|
||||
pass
|
||||
"""
|
||||
classes, source_lines, wrapper = _find_all_classes_with_init(code)
|
||||
assert wrapper is not None
|
||||
init_node = next(iter(classes["Foo"].body.body))
|
||||
assert isinstance(init_node, cst.FunctionDef)
|
||||
params = extract_init_params(init_node, source_lines, wrapper)
|
||||
|
||||
assert len(params) == 1
|
||||
assert params[0].has_default
|
||||
assert params[0].default_repr is not None
|
||||
assert len(params[0].default_repr) <= 50
|
||||
assert params[0].default_repr.endswith("...")
|
||||
|
||||
|
||||
class TestFormatInitSignature:
|
||||
def test_required_only(self) -> None:
|
||||
params = [
|
||||
ParamInfo(name="x", annotation="int", has_default=False, default_repr=None),
|
||||
ParamInfo(name="y", annotation="str", has_default=False, default_repr=None),
|
||||
]
|
||||
result = format_init_signature("Foo", params)
|
||||
assert "Constructor signature for Foo:" in result
|
||||
assert "Required (positional) arguments:" in result
|
||||
assert "- x: int" in result
|
||||
assert "- y: str" in result
|
||||
assert "Optional" not in result
|
||||
|
||||
def test_optional_only(self) -> None:
|
||||
params = [
|
||||
ParamInfo(name="x", annotation="int", has_default=True, default_repr="0"),
|
||||
ParamInfo(name="y", annotation="str", has_default=True, default_repr='"hi"'),
|
||||
]
|
||||
result = format_init_signature("Foo", params)
|
||||
assert "Constructor signature for Foo:" in result
|
||||
assert "Optional (keyword) arguments:" in result
|
||||
assert "- x: int = 0" in result
|
||||
assert '- y: str = "hi"' in result
|
||||
assert "Required" not in result
|
||||
|
||||
def test_mixed_params(self) -> None:
|
||||
params = [
|
||||
ParamInfo(name="x", annotation="int", has_default=False, default_repr=None),
|
||||
ParamInfo(name="y", annotation="str", has_default=True, default_repr='"default"'),
|
||||
]
|
||||
result = format_init_signature("Foo", params)
|
||||
assert "Required (positional) arguments:" in result
|
||||
assert "Optional (keyword) arguments:" in result
|
||||
|
||||
def test_no_params(self) -> None:
|
||||
result = format_init_signature("Empty", [])
|
||||
assert "Empty() - no parameters" in result
|
||||
|
||||
def test_variadic_params(self) -> None:
|
||||
params = [
|
||||
ParamInfo(name="x", annotation="int", has_default=False, default_repr=None),
|
||||
ParamInfo(name="*args", annotation=None, has_default=False, default_repr=None),
|
||||
ParamInfo(name="**kwargs", annotation=None, has_default=False, default_repr=None),
|
||||
]
|
||||
result = format_init_signature("Foo", params)
|
||||
assert "Required (positional) arguments:" in result
|
||||
assert "Variadic arguments:" in result
|
||||
assert "- *args" in result
|
||||
assert "- **kwargs" in result
|
||||
|
||||
def test_untyped_params(self) -> None:
|
||||
params = [
|
||||
ParamInfo(name="x", annotation=None, has_default=False, default_repr=None),
|
||||
ParamInfo(name="y", annotation=None, has_default=True, default_repr="10"),
|
||||
]
|
||||
result = format_init_signature("Foo", params)
|
||||
assert " - x\n" in result
|
||||
assert " - y = 10" in result
|
||||
|
||||
|
||||
class TestGetClassConstructorNotes:
|
||||
def test_class_with_typed_init(self) -> None:
|
||||
context = """
|
||||
class LayoutElements:
|
||||
def __init__(self, width: int, height: int, elements: list[str]):
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.elements = elements
|
||||
"""
|
||||
notes = get_class_constructor_notes(context)
|
||||
assert len(notes) == 1
|
||||
assert "Constructor signature for LayoutElements:" in notes[0]
|
||||
assert "- width: int" in notes[0]
|
||||
assert "- height: int" in notes[0]
|
||||
assert "- elements: list[str]" in notes[0]
|
||||
|
||||
def test_class_with_defaults(self) -> None:
|
||||
context = """
|
||||
class TextRegions:
|
||||
def __init__(self, text: str = "", max_len: int = 100):
|
||||
self.text = text
|
||||
self.max_len = max_len
|
||||
"""
|
||||
notes = get_class_constructor_notes(context)
|
||||
assert len(notes) == 1
|
||||
assert "Optional (keyword) arguments:" in notes[0]
|
||||
assert '- text: str = ""' in notes[0]
|
||||
assert "- max_len: int = 100" in notes[0]
|
||||
|
||||
def test_class_with_args_kwargs(self) -> None:
|
||||
context = """
|
||||
class Flexible:
|
||||
def __init__(self, name: str, *args, **kwargs):
|
||||
self.name = name
|
||||
"""
|
||||
notes = get_class_constructor_notes(context)
|
||||
assert len(notes) == 1
|
||||
assert "- name: str" in notes[0]
|
||||
assert "- *args" in notes[0]
|
||||
assert "- **kwargs" in notes[0]
|
||||
|
||||
def test_class_without_init_skipped(self) -> None:
|
||||
context = """
|
||||
class NoInit:
|
||||
x = 10
|
||||
def method(self):
|
||||
pass
|
||||
"""
|
||||
notes = get_class_constructor_notes(context)
|
||||
assert notes == []
|
||||
|
||||
def test_dataclass_skipped(self) -> None:
|
||||
context = """
|
||||
@dataclass
|
||||
class Config:
|
||||
name: str
|
||||
value: int
|
||||
"""
|
||||
notes = get_class_constructor_notes(context)
|
||||
assert notes == []
|
||||
|
||||
def test_multiple_classes(self) -> None:
|
||||
context = """
|
||||
class Foo:
|
||||
def __init__(self, x: int):
|
||||
self.x = x
|
||||
|
||||
class Bar:
|
||||
def __init__(self, y: str):
|
||||
self.y = y
|
||||
"""
|
||||
notes = get_class_constructor_notes(context)
|
||||
assert len(notes) == 2
|
||||
class_names = " ".join(notes)
|
||||
assert "Foo" in class_names
|
||||
assert "Bar" in class_names
|
||||
|
||||
def test_markdown_wrapped_code(self) -> None:
|
||||
context = """
|
||||
Some description text.
|
||||
```python:models.py
|
||||
class LayoutElements:
|
||||
def __init__(self, width: int, height: int):
|
||||
self.width = width
|
||||
self.height = height
|
||||
```
|
||||
"""
|
||||
notes = get_class_constructor_notes(context)
|
||||
assert len(notes) == 1
|
||||
assert "LayoutElements" in notes[0]
|
||||
assert "- width: int" in notes[0]
|
||||
|
||||
def test_syntax_error_returns_empty(self) -> None:
|
||||
context = "this is not valid python {{{{"
|
||||
notes = get_class_constructor_notes(context)
|
||||
assert notes == []
|
||||
|
||||
def test_init_with_only_self(self) -> None:
|
||||
context = """
|
||||
class Empty:
|
||||
def __init__(self):
|
||||
pass
|
||||
"""
|
||||
notes = get_class_constructor_notes(context)
|
||||
assert len(notes) == 1
|
||||
assert "Empty() - no parameters" in notes[0]
|
||||
|
||||
def test_mixed_dataclass_and_regular(self) -> None:
|
||||
context = """
|
||||
@dataclass
|
||||
class Config:
|
||||
name: str
|
||||
|
||||
class Service:
|
||||
def __init__(self, config: Config, debug: bool = False):
|
||||
self.config = config
|
||||
self.debug = debug
|
||||
"""
|
||||
notes = get_class_constructor_notes(context)
|
||||
assert len(notes) == 1
|
||||
assert "Service" in notes[0]
|
||||
assert "Config" not in notes[0].split("\n")[0]
|
||||
|
||||
def test_plain_code_without_markdown(self) -> None:
|
||||
context = """
|
||||
class MyClass:
|
||||
def __init__(self, value: int):
|
||||
self.value = value
|
||||
"""
|
||||
notes = get_class_constructor_notes(context)
|
||||
assert len(notes) == 1
|
||||
assert "MyClass" in notes[0]
|
||||
|
||||
def test_full_output_format(self) -> None:
|
||||
context = """
|
||||
class Server:
|
||||
def __init__(self, host: str, port: int, debug: bool = False, workers: int = 4):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.debug = debug
|
||||
self.workers = workers
|
||||
"""
|
||||
notes = get_class_constructor_notes(context)
|
||||
assert len(notes) == 1
|
||||
expected = """Constructor signature for Server:
|
||||
Required (positional) arguments:
|
||||
- host: str
|
||||
- port: int
|
||||
Optional (keyword) arguments:
|
||||
- debug: bool = False
|
||||
- workers: int = 4"""
|
||||
assert notes[0] == expected
|
||||
|
|
@ -1,529 +0,0 @@
|
|||
"""Tests for dataclass constructor signature extraction."""
|
||||
|
||||
import libcst as cst
|
||||
|
||||
from core.languages.python.cst_utils import get_base_class_name, has_decorator
|
||||
from aiservice.common.markdown_utils import extract_all_code_from_markdown
|
||||
from core.languages.python.testgen.preprocessing.dataclass_constructor_notes import (
|
||||
FieldInfo,
|
||||
extract_dataclass_fields,
|
||||
find_all_dataclasses,
|
||||
format_constructor_signature,
|
||||
get_all_fields_with_inheritance,
|
||||
get_dataclass_constructor_notes,
|
||||
)
|
||||
|
||||
|
||||
class TestHasDecorator:
|
||||
def test_simple_dataclass_decorator(self) -> None:
|
||||
code = """
|
||||
@dataclass
|
||||
class Foo:
|
||||
x: int
|
||||
"""
|
||||
tree = cst.parse_module(code)
|
||||
class_node = tree.body[0]
|
||||
assert isinstance(class_node, cst.ClassDef)
|
||||
assert has_decorator(class_node, "dataclass")
|
||||
|
||||
def test_dataclass_with_call(self) -> None:
|
||||
code = """
|
||||
@dataclass(frozen=True)
|
||||
class Foo:
|
||||
x: int
|
||||
"""
|
||||
tree = cst.parse_module(code)
|
||||
class_node = tree.body[0]
|
||||
assert isinstance(class_node, cst.ClassDef)
|
||||
assert has_decorator(class_node, "dataclass")
|
||||
|
||||
def test_not_a_dataclass(self) -> None:
|
||||
code = """
|
||||
class Foo:
|
||||
x: int
|
||||
"""
|
||||
tree = cst.parse_module(code)
|
||||
class_node = tree.body[0]
|
||||
assert isinstance(class_node, cst.ClassDef)
|
||||
assert not has_decorator(class_node, "dataclass")
|
||||
|
||||
def test_other_decorator(self) -> None:
|
||||
code = """
|
||||
@other_decorator
|
||||
class Foo:
|
||||
x: int
|
||||
"""
|
||||
tree = cst.parse_module(code)
|
||||
class_node = tree.body[0]
|
||||
assert isinstance(class_node, cst.ClassDef)
|
||||
assert not has_decorator(class_node, "dataclass")
|
||||
|
||||
|
||||
class TestGetBaseClassName:
|
||||
def test_simple_name(self) -> None:
|
||||
code = "class Foo(Bar): pass"
|
||||
tree = cst.parse_module(code)
|
||||
class_node = tree.body[0]
|
||||
assert isinstance(class_node, cst.ClassDef)
|
||||
assert get_base_class_name(class_node.bases[0]) == "Bar"
|
||||
|
||||
def test_attribute(self) -> None:
|
||||
code = "class Foo(module.Bar): pass"
|
||||
tree = cst.parse_module(code)
|
||||
class_node = tree.body[0]
|
||||
assert isinstance(class_node, cst.ClassDef)
|
||||
assert get_base_class_name(class_node.bases[0]) == "Bar"
|
||||
|
||||
|
||||
class TestExtractDataclassFields:
|
||||
def test_simple_fields(self) -> None:
|
||||
code = """
|
||||
@dataclass
|
||||
class Foo:
|
||||
x: int
|
||||
y: str
|
||||
"""
|
||||
dataclasses, source_lines, wrapper = find_all_dataclasses(code)
|
||||
assert wrapper is not None
|
||||
class_node = dataclasses["Foo"]
|
||||
fields = extract_dataclass_fields(class_node, source_lines, wrapper)
|
||||
|
||||
assert len(fields) == 2
|
||||
assert fields[0].name == "x"
|
||||
assert fields[0].annotation == "int"
|
||||
assert not fields[0].has_default
|
||||
assert fields[1].name == "y"
|
||||
assert fields[1].annotation == "str"
|
||||
assert not fields[1].has_default
|
||||
|
||||
def test_fields_with_defaults(self) -> None:
|
||||
code = """
|
||||
@dataclass
|
||||
class Foo:
|
||||
x: int
|
||||
y: str = "default"
|
||||
"""
|
||||
dataclasses, source_lines, wrapper = find_all_dataclasses(code)
|
||||
assert wrapper is not None
|
||||
class_node = dataclasses["Foo"]
|
||||
fields = extract_dataclass_fields(class_node, source_lines, wrapper)
|
||||
|
||||
assert len(fields) == 2
|
||||
assert fields[0].name == "x"
|
||||
assert not fields[0].has_default
|
||||
assert fields[1].name == "y"
|
||||
assert fields[1].has_default
|
||||
assert fields[1].default_repr == '"default"'
|
||||
|
||||
def test_complex_annotation(self) -> None:
|
||||
code = """
|
||||
@dataclass
|
||||
class Foo:
|
||||
items: list[str]
|
||||
mapping: dict[str, int] | None = None
|
||||
"""
|
||||
dataclasses, source_lines, wrapper = find_all_dataclasses(code)
|
||||
assert wrapper is not None
|
||||
class_node = dataclasses["Foo"]
|
||||
fields = extract_dataclass_fields(class_node, source_lines, wrapper)
|
||||
|
||||
assert len(fields) == 2
|
||||
assert fields[0].name == "items"
|
||||
assert fields[0].annotation == "list[str]"
|
||||
assert fields[1].name == "mapping"
|
||||
assert "dict[str, int]" in fields[1].annotation
|
||||
|
||||
|
||||
class TestFindAllDataclasses:
|
||||
def test_finds_single_dataclass(self) -> None:
|
||||
code = """
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class Foo:
|
||||
x: int
|
||||
"""
|
||||
result, source_lines, wrapper = find_all_dataclasses(code)
|
||||
assert "Foo" in result
|
||||
assert wrapper is not None
|
||||
assert len(source_lines) > 0
|
||||
|
||||
def test_finds_multiple_dataclasses(self) -> None:
|
||||
code = """
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class Foo:
|
||||
x: int
|
||||
|
||||
@dataclass
|
||||
class Bar:
|
||||
y: str
|
||||
"""
|
||||
result, _source_lines, wrapper = find_all_dataclasses(code)
|
||||
assert "Foo" in result
|
||||
assert "Bar" in result
|
||||
assert wrapper is not None
|
||||
|
||||
def test_ignores_non_dataclasses(self) -> None:
|
||||
code = """
|
||||
class NotADataclass:
|
||||
x: int
|
||||
|
||||
@dataclass
|
||||
class IsADataclass:
|
||||
y: str
|
||||
"""
|
||||
result, _source_lines, wrapper = find_all_dataclasses(code)
|
||||
assert "NotADataclass" not in result
|
||||
assert "IsADataclass" in result
|
||||
assert wrapper is not None
|
||||
|
||||
def test_handles_syntax_error(self) -> None:
|
||||
code = "this is not valid python code {{{{"
|
||||
result, _source_lines, wrapper = find_all_dataclasses(code)
|
||||
assert result == {}
|
||||
assert wrapper is None
|
||||
|
||||
|
||||
class TestGetAllFieldsWithInheritance:
|
||||
def test_no_inheritance(self) -> None:
|
||||
code = """
|
||||
@dataclass
|
||||
class Foo:
|
||||
x: int
|
||||
y: str
|
||||
"""
|
||||
dataclasses, source_lines, wrapper = find_all_dataclasses(code)
|
||||
assert wrapper is not None
|
||||
fields = get_all_fields_with_inheritance("Foo", dataclasses, source_lines, wrapper)
|
||||
|
||||
assert len(fields) == 2
|
||||
assert all(f.source == "own" for f in fields)
|
||||
|
||||
def test_simple_inheritance(self) -> None:
|
||||
code = """
|
||||
@dataclass
|
||||
class Base:
|
||||
a: int
|
||||
b: str
|
||||
|
||||
@dataclass
|
||||
class Child(Base):
|
||||
c: float
|
||||
"""
|
||||
dataclasses, source_lines, wrapper = find_all_dataclasses(code)
|
||||
assert wrapper is not None
|
||||
fields = get_all_fields_with_inheritance("Child", dataclasses, source_lines, wrapper)
|
||||
|
||||
assert len(fields) == 3
|
||||
assert fields[0].name == "a"
|
||||
assert fields[0].source == "inherited"
|
||||
assert fields[1].name == "b"
|
||||
assert fields[1].source == "inherited"
|
||||
assert fields[2].name == "c"
|
||||
assert fields[2].source == "own"
|
||||
|
||||
def test_multi_level_inheritance(self) -> None:
|
||||
code = """
|
||||
@dataclass
|
||||
class GrandParent:
|
||||
a: int
|
||||
|
||||
@dataclass
|
||||
class Parent(GrandParent):
|
||||
b: str
|
||||
|
||||
@dataclass
|
||||
class Child(Parent):
|
||||
c: float
|
||||
"""
|
||||
dataclasses, source_lines, wrapper = find_all_dataclasses(code)
|
||||
assert wrapper is not None
|
||||
fields = get_all_fields_with_inheritance("Child", dataclasses, source_lines, wrapper)
|
||||
|
||||
assert len(fields) == 3
|
||||
field_names = [f.name for f in fields]
|
||||
assert "a" in field_names
|
||||
assert "b" in field_names
|
||||
assert "c" in field_names
|
||||
|
||||
def test_unknown_class(self) -> None:
|
||||
code = """
|
||||
@dataclass
|
||||
class Foo:
|
||||
x: int
|
||||
"""
|
||||
dataclasses, source_lines, wrapper = find_all_dataclasses(code)
|
||||
assert wrapper is not None
|
||||
fields = get_all_fields_with_inheritance("Unknown", dataclasses, source_lines, wrapper)
|
||||
assert fields == []
|
||||
|
||||
|
||||
class TestFormatConstructorSignature:
|
||||
def test_required_only(self) -> None:
|
||||
fields = [
|
||||
FieldInfo(name="x", annotation="int", has_default=False, default_repr=None, source="own"),
|
||||
FieldInfo(name="y", annotation="str", has_default=False, default_repr=None, source="own"),
|
||||
]
|
||||
result = format_constructor_signature("Foo", fields)
|
||||
assert "Constructor signature for Foo:" in result
|
||||
assert "Required (positional) arguments:" in result
|
||||
assert "- x: int" in result
|
||||
assert "- y: str" in result
|
||||
assert "Optional" not in result
|
||||
|
||||
def test_optional_only(self) -> None:
|
||||
fields = [
|
||||
FieldInfo(name="x", annotation="int", has_default=True, default_repr="0", source="own"),
|
||||
FieldInfo(name="y", annotation="str", has_default=True, default_repr='"default"', source="own"),
|
||||
]
|
||||
result = format_constructor_signature("Foo", fields)
|
||||
assert "Constructor signature for Foo:" in result
|
||||
assert "Optional (keyword) arguments:" in result
|
||||
assert "- x: int = 0" in result
|
||||
assert '- y: str = "default"' in result
|
||||
assert "Required" not in result
|
||||
|
||||
def test_mixed_fields(self) -> None:
|
||||
fields = [
|
||||
FieldInfo(name="a", annotation="int", has_default=False, default_repr=None, source="inherited"),
|
||||
FieldInfo(name="b", annotation="str", has_default=False, default_repr=None, source="own"),
|
||||
FieldInfo(name="c", annotation="float", has_default=True, default_repr="0.0", source="own"),
|
||||
]
|
||||
result = format_constructor_signature("Foo", fields)
|
||||
assert "Required (positional) arguments:" in result
|
||||
assert "Optional (keyword) arguments:" in result
|
||||
assert "(from parent class)" in result
|
||||
|
||||
def test_no_fields(self) -> None:
|
||||
result = format_constructor_signature("Empty", [])
|
||||
assert "no fields" in result
|
||||
|
||||
|
||||
class TestExtractCodeFromMarkdown:
|
||||
def test_simple_code_block(self) -> None:
|
||||
markdown = """
|
||||
Some text
|
||||
```python
|
||||
def foo():
|
||||
pass
|
||||
```
|
||||
More text
|
||||
"""
|
||||
result = extract_all_code_from_markdown(markdown)
|
||||
assert "def foo():" in result
|
||||
assert "pass" in result
|
||||
assert "Some text" not in result
|
||||
|
||||
def test_code_block_with_filepath(self) -> None:
|
||||
markdown = """
|
||||
```python:path/to/file.py
|
||||
class Foo:
|
||||
x: int
|
||||
```
|
||||
"""
|
||||
result = extract_all_code_from_markdown(markdown)
|
||||
assert "class Foo:" in result
|
||||
assert "x: int" in result
|
||||
|
||||
def test_multiple_code_blocks(self) -> None:
|
||||
markdown = """
|
||||
```python:file1.py
|
||||
@dataclass
|
||||
class A:
|
||||
x: int
|
||||
```
|
||||
|
||||
```python:file2.py
|
||||
@dataclass
|
||||
class B:
|
||||
y: str
|
||||
```
|
||||
"""
|
||||
result = extract_all_code_from_markdown(markdown)
|
||||
assert "class A:" in result
|
||||
assert "class B:" in result
|
||||
|
||||
|
||||
class TestGetDataclassConstructorNotes:
|
||||
def test_simple_dataclass(self) -> None:
|
||||
context = """
|
||||
```python:models.py
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
name: str
|
||||
value: int
|
||||
```
|
||||
"""
|
||||
notes = get_dataclass_constructor_notes(context)
|
||||
assert len(notes) == 1
|
||||
assert "Constructor signature for Config:" in notes[0]
|
||||
assert "- name: str" in notes[0]
|
||||
assert "- value: int" in notes[0]
|
||||
|
||||
def test_dataclass_with_inheritance(self) -> None:
|
||||
context = """
|
||||
```python:models.py
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class BaseConfig:
|
||||
model_name: str
|
||||
required_env_vars: list[str]
|
||||
|
||||
@dataclass
|
||||
class ExtendedConfig(BaseConfig):
|
||||
extra_param: int = 0
|
||||
```
|
||||
"""
|
||||
notes = get_dataclass_constructor_notes(context)
|
||||
assert len(notes) == 2
|
||||
|
||||
# Find the note for ExtendedConfig
|
||||
extended_note = next(n for n in notes if "ExtendedConfig" in n)
|
||||
assert "- model_name: str (from parent class)" in extended_note
|
||||
assert "- required_env_vars: list[str] (from parent class)" in extended_note
|
||||
assert "- extra_param: int = 0" in extended_note
|
||||
|
||||
def test_no_dataclasses(self) -> None:
|
||||
context = """
|
||||
def regular_function():
|
||||
pass
|
||||
"""
|
||||
notes = get_dataclass_constructor_notes(context)
|
||||
assert notes == []
|
||||
|
||||
def test_plain_code_without_markdown(self) -> None:
|
||||
context = """
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class SimpleConfig:
|
||||
name: str
|
||||
"""
|
||||
notes = get_dataclass_constructor_notes(context)
|
||||
assert len(notes) == 1
|
||||
assert "SimpleConfig" in notes[0]
|
||||
|
||||
def test_llm_config_like_structure(self) -> None:
|
||||
"""Test with a structure similar to the skyvern LLMConfig."""
|
||||
context = """
|
||||
```python:skyvern/forge/sdk/api/llm/models.py
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LLMConfigBase:
|
||||
model_name: str
|
||||
required_env_vars: list[str]
|
||||
supports_vision: bool
|
||||
add_assistant_prefix: bool
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LLMConfig(LLMConfigBase):
|
||||
litellm_params: Optional[dict] = field(default=None)
|
||||
max_tokens: int | None = 4096
|
||||
```
|
||||
"""
|
||||
notes = get_dataclass_constructor_notes(context)
|
||||
|
||||
# Should have notes for both classes
|
||||
assert len(notes) == 2
|
||||
|
||||
# Find the LLMConfig note
|
||||
llm_config_note = next(n for n in notes if "LLMConfig:" in n)
|
||||
|
||||
# Should show inherited fields from LLMConfigBase
|
||||
assert "- model_name: str (from parent class)" in llm_config_note
|
||||
assert "- required_env_vars: list[str] (from parent class)" in llm_config_note
|
||||
assert "- supports_vision: bool (from parent class)" in llm_config_note
|
||||
assert "- add_assistant_prefix: bool (from parent class)" in llm_config_note
|
||||
|
||||
# Should also show own fields
|
||||
assert "- litellm_params:" in llm_config_note
|
||||
assert "- max_tokens:" in llm_config_note
|
||||
|
||||
def test_full_output_for_complex_inheritance(self) -> None:
|
||||
"""Test full generated notes output for complex dataclass inheritance."""
|
||||
context = """
|
||||
```python:skyvern/forge/sdk/api/llm/models.py
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LLMConfigBase:
|
||||
model_name: str
|
||||
required_env_vars: list[str]
|
||||
supports_vision: bool
|
||||
add_assistant_prefix: bool
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LLMConfig(LLMConfigBase):
|
||||
litellm_params: Optional[dict] = field(default=None)
|
||||
max_tokens: int | None = 4096
|
||||
max_completion_tokens: int | None = None
|
||||
temperature: float | None = 0.7
|
||||
reasoning_effort: str | None = None
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LLMRouterConfig(LLMConfigBase):
|
||||
model_list: list
|
||||
main_model_group: str
|
||||
redis_host: str | None = None
|
||||
redis_port: int | None = None
|
||||
fallback_model_group: str | None = None
|
||||
routing_strategy: str = "usage-based-routing"
|
||||
num_retries: int = 1
|
||||
```
|
||||
"""
|
||||
notes = get_dataclass_constructor_notes(context)
|
||||
|
||||
# Should have notes for all 3 classes
|
||||
assert len(notes) == 3
|
||||
|
||||
# Verify LLMConfigBase note (no inheritance)
|
||||
base_note = next(n for n in notes if "LLMConfigBase:" in n)
|
||||
expected_base = """Constructor signature for LLMConfigBase:
|
||||
Required (positional) arguments:
|
||||
- model_name: str
|
||||
- required_env_vars: list[str]
|
||||
- supports_vision: bool
|
||||
- add_assistant_prefix: bool"""
|
||||
assert base_note == expected_base
|
||||
|
||||
# Verify LLMConfig note (with inheritance)
|
||||
config_note = next(n for n in notes if "LLMConfig:" in n)
|
||||
expected_config = """Constructor signature for LLMConfig:
|
||||
Required (positional) arguments:
|
||||
- model_name: str (from parent class)
|
||||
- required_env_vars: list[str] (from parent class)
|
||||
- supports_vision: bool (from parent class)
|
||||
- add_assistant_prefix: bool (from parent class)
|
||||
Optional (keyword) arguments:
|
||||
- litellm_params: Optional[dict] = field(default=None)
|
||||
- max_tokens: int | None = 4096
|
||||
- max_completion_tokens: int | None = None
|
||||
- temperature: float | None = 0.7
|
||||
- reasoning_effort: str | None = None"""
|
||||
assert config_note == expected_config
|
||||
|
||||
# Verify LLMRouterConfig note (with inheritance + own required fields)
|
||||
router_note = next(n for n in notes if "LLMRouterConfig:" in n)
|
||||
expected_router = """Constructor signature for LLMRouterConfig:
|
||||
Required (positional) arguments:
|
||||
- model_name: str (from parent class)
|
||||
- required_env_vars: list[str] (from parent class)
|
||||
- supports_vision: bool (from parent class)
|
||||
- add_assistant_prefix: bool (from parent class)
|
||||
- model_list: list
|
||||
- main_model_group: str
|
||||
Optional (keyword) arguments:
|
||||
- redis_host: str | None = None
|
||||
- redis_port: int | None = None
|
||||
- fallback_model_group: str | None = None
|
||||
- routing_strategy: str = "usage-based-routing"
|
||||
- num_retries: int = 1"""
|
||||
assert router_note == expected_router
|
||||
|
|
@ -510,6 +510,40 @@ export async function triggerSuggestPrChanges(
|
|||
repo,
|
||||
pull_number: pullNumber,
|
||||
})
|
||||
|
||||
// Check if the PR is merged or closed - we can't suggest changes on merged/closed PRs
|
||||
if (originalPrData.data.merged) {
|
||||
logger.info(
|
||||
`PR #${pullNumber} is already merged, cannot suggest changes`,
|
||||
{
|
||||
endpoint: "/cfapi/suggest-pr-changes",
|
||||
operation: "pr_merged_check",
|
||||
owner,
|
||||
repo,
|
||||
userId,
|
||||
},
|
||||
)
|
||||
throw unprocessableEntity(
|
||||
`Cannot suggest changes on merged PR #${pullNumber}. The PR was already merged.`,
|
||||
)
|
||||
}
|
||||
|
||||
if (originalPrData.data.state === "closed") {
|
||||
logger.info(
|
||||
`PR #${pullNumber} is closed, cannot suggest changes`,
|
||||
{
|
||||
endpoint: "/cfapi/suggest-pr-changes",
|
||||
operation: "pr_closed_check",
|
||||
owner,
|
||||
repo,
|
||||
userId,
|
||||
},
|
||||
)
|
||||
throw unprocessableEntity(
|
||||
`Cannot suggest changes on closed PR #${pullNumber}. The PR is no longer open.`,
|
||||
)
|
||||
}
|
||||
|
||||
const baseBranch = originalPrData.data.head.ref
|
||||
logger.info(`Attempting to access ref for: ${owner}/${repo}, branch: ${baseBranch}`, {
|
||||
endpoint: "/cfapi/suggest-pr-changes",
|
||||
|
|
|
|||
|
|
@ -133,6 +133,7 @@ export async function verifyExistingOptimizations(req: Request, res: Response) {
|
|||
}
|
||||
|
||||
// Get PR with specific 404 handling
|
||||
// Note: GitHub returns 404 for both non-existent PRs and PRs the installation cannot access
|
||||
let pr
|
||||
try {
|
||||
pr = await octokit.rest.pulls.get({
|
||||
|
|
@ -142,7 +143,45 @@ export async function verifyExistingOptimizations(req: Request, res: Response) {
|
|||
})
|
||||
} catch (error: any) {
|
||||
if (error.status === 404) {
|
||||
throw githubPrNotFound(`#${pr_number} in ${repo_owner}/${repo_name}`)
|
||||
// Log additional context to help diagnose permission vs not-found issues
|
||||
logger.warn(
|
||||
`PR #${pr_number} returned 404 in ${repo_owner}/${repo_name}. This could mean the PR doesn't exist, or the GitHub App installation doesn't have access to it.`,
|
||||
{
|
||||
requestId: (req as any).requestId,
|
||||
userId,
|
||||
endpoint: "/cfapi/verify-existing-optimizations",
|
||||
operation: "pr_not_found_or_no_access",
|
||||
repo_owner,
|
||||
repo_name,
|
||||
pr_number,
|
||||
nickname,
|
||||
errorMessage: error.message,
|
||||
errorResponse: error.response?.data,
|
||||
},
|
||||
)
|
||||
throw githubPrNotFound(
|
||||
`#${pr_number} in ${repo_owner}/${repo_name}. If the PR exists, ensure the GitHub App installation has access to this repository.`,
|
||||
)
|
||||
}
|
||||
// Handle 403 (Forbidden) as a permissions issue
|
||||
if (error.status === 403) {
|
||||
logger.warn(
|
||||
`Access forbidden to PR #${pr_number} in ${repo_owner}/${repo_name}. The GitHub App installation may not have sufficient permissions.`,
|
||||
{
|
||||
requestId: (req as any).requestId,
|
||||
userId,
|
||||
endpoint: "/cfapi/verify-existing-optimizations",
|
||||
operation: "pr_access_forbidden",
|
||||
repo_owner,
|
||||
repo_name,
|
||||
pr_number,
|
||||
nickname,
|
||||
errorMessage: error.message,
|
||||
},
|
||||
)
|
||||
throw githubInstallationError(
|
||||
`Access forbidden to PR #${pr_number} in ${repo_owner}/${repo_name}. Please ensure the GitHub App has the necessary permissions.`,
|
||||
)
|
||||
}
|
||||
throw error // Re-throw to be caught by global handler
|
||||
}
|
||||
|
|
|
|||
|
|
@ -574,28 +574,52 @@ export async function processReaction(event: any): Promise<boolean> {
|
|||
console.error(
|
||||
`Error processing approved request for trace ${optimization.trace_id}: ${err}`,
|
||||
)
|
||||
await sendSlackMessage(
|
||||
|
||||
// Extract helpful error details for Slack notification
|
||||
const errorMessage = err.message || String(err)
|
||||
const errorType = err.constructor?.name || "Error"
|
||||
const isPrMergedOrClosed = errorMessage.includes("merged") || errorMessage.includes("closed")
|
||||
|
||||
const errorBlocks: any[] = [
|
||||
{
|
||||
blocks: [
|
||||
type: "section",
|
||||
text: {
|
||||
type: "mrkdwn",
|
||||
text: `:warning: Error processing approved optimization \`${optimization.trace_id}\`:`,
|
||||
},
|
||||
},
|
||||
{
|
||||
type: "section",
|
||||
text: {
|
||||
type: "mrkdwn",
|
||||
text: `\`\`\`${errorMessage}\`\`\``,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
// Add helpful context if PR is merged/closed
|
||||
if (isPrMergedOrClosed) {
|
||||
errorBlocks.push({
|
||||
type: "context",
|
||||
elements: [
|
||||
{
|
||||
type: "section",
|
||||
text: {
|
||||
type: "mrkdwn",
|
||||
text: `:warning: Error processing approved optimization \`${optimization.trace_id}\`:`,
|
||||
},
|
||||
},
|
||||
{
|
||||
type: "section",
|
||||
text: {
|
||||
type: "mrkdwn",
|
||||
text: `\`\`\`${err.message}\`\`\``,
|
||||
},
|
||||
type: "mrkdwn",
|
||||
text: `ℹ️ The target PR may have been merged or closed since the optimization was queued for approval.`,
|
||||
},
|
||||
],
|
||||
text: `Error processing optimization ${optimization.trace_id}: ${err.message}`,
|
||||
})
|
||||
}
|
||||
|
||||
await sendSlackMessage(
|
||||
{
|
||||
blocks: errorBlocks,
|
||||
text: `Error processing optimization ${optimization.trace_id}: ${errorMessage}`,
|
||||
},
|
||||
channel,
|
||||
)
|
||||
|
||||
// Return false to indicate the reaction processing failed
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -8,5 +8,5 @@ NPM_TOKEN
|
|||
SCM_DO_BUILD_DURING_DEPLOYMENT
|
||||
WEBSITE_HEALTHCHECK_MAXPINGFAILURES
|
||||
WEBSITE_HTTPLOGGING_RETENTION_DAYS
|
||||
CODEFLASH_INTERNAL_REPO_PATH=
|
||||
CODEFLASH_CLI_REPO_PATH=
|
||||
AISERVICE_DIR=
|
||||
CODEFLASH_DIR=
|
||||
|
|
|
|||
|
|
@ -168,8 +168,8 @@ export function indexTraceData(traceData: TraceData): IndexedTraceData {
|
|||
// --- Codebase browsing helpers ---
|
||||
|
||||
function getRepoRoot(repo: string): string | null {
|
||||
if (repo === "codeflash-internal") return process.env.CODEFLASH_INTERNAL_REPO_PATH || null
|
||||
if (repo === "codeflash") return process.env.CODEFLASH_CLI_REPO_PATH || null
|
||||
if (repo === "codeflash-internal") return process.env.AISERVICE_DIR || null
|
||||
if (repo === "codeflash") return process.env.CODEFLASH_DIR || null
|
||||
return null
|
||||
}
|
||||
|
||||
|
|
@ -630,7 +630,7 @@ export async function resolveToolCall(
|
|||
|
||||
case "read_file": {
|
||||
const repoRoot = getRepoRoot(args.repo as string)
|
||||
if (!repoRoot) return `Repository path not configured. Set CODEFLASH_INTERNAL_REPO_PATH or CODEFLASH_CLI_REPO_PATH env var.`
|
||||
if (!repoRoot) return `Repository path not configured. Set AISERVICE_DIR or CODEFLASH_DIR env var.`
|
||||
|
||||
const pathResult = resolveAndValidatePath(repoRoot, args.path as string)
|
||||
if ("error" in pathResult) return pathResult.error
|
||||
|
|
@ -657,7 +657,7 @@ export async function resolveToolCall(
|
|||
|
||||
case "search_code": {
|
||||
const repoRoot = getRepoRoot(args.repo as string)
|
||||
if (!repoRoot) return `Repository path not configured. Set CODEFLASH_INTERNAL_REPO_PATH or CODEFLASH_CLI_REPO_PATH env var.`
|
||||
if (!repoRoot) return `Repository path not configured. Set AISERVICE_DIR or CODEFLASH_DIR env var.`
|
||||
|
||||
const maxResults = Math.min(Math.max(1, (args.max_results as number) || 30), 100)
|
||||
const rgArgs = [
|
||||
|
|
@ -699,7 +699,7 @@ export async function resolveToolCall(
|
|||
|
||||
case "list_directory": {
|
||||
const repoRoot = getRepoRoot(args.repo as string)
|
||||
if (!repoRoot) return `Repository path not configured. Set CODEFLASH_INTERNAL_REPO_PATH or CODEFLASH_CLI_REPO_PATH env var.`
|
||||
if (!repoRoot) return `Repository path not configured. Set AISERVICE_DIR or CODEFLASH_DIR env var.`
|
||||
|
||||
const relativePath = (args.path as string) || "."
|
||||
const pathResult = resolveAndValidatePath(repoRoot, relativePath)
|
||||
|
|
|
|||
Loading…
Reference in a new issue