mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
fix: include dependency classes in read-writable optimization context
Classes used as dependencies (enums, dataclasses, types) were being excluded from the optimization context even when marked as used by the target function. This caused NameError when the LLM used these types in generated optimizations.
This commit is contained in:
parent
34de67681e
commit
6b3b10e7fa
2 changed files with 135 additions and 1 deletions
|
|
@ -1016,10 +1016,29 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
|
|||
# Do not recurse into nested classes
|
||||
if prefix:
|
||||
return None, False
|
||||
|
||||
class_name = node.name.value
|
||||
|
||||
# Assuming always an IndentedBlock
|
||||
if not isinstance(node.body, cst.IndentedBlock):
|
||||
raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004
|
||||
class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
|
||||
class_prefix = f"{prefix}.{class_name}" if prefix else class_name
|
||||
|
||||
# Check if this class contains any target functions
|
||||
has_target_functions = any(
|
||||
isinstance(stmt, cst.FunctionDef) and f"{class_prefix}.{stmt.name.value}" in target_functions
|
||||
for stmt in node.body.body
|
||||
)
|
||||
|
||||
# If the class is used as a dependency (not containing target functions), keep it entirely
|
||||
# This handles cases like enums, dataclasses, and other types used by the target function
|
||||
if (
|
||||
not has_target_functions
|
||||
and class_name in defs_with_usages
|
||||
and defs_with_usages[class_name].used_by_qualified_function
|
||||
):
|
||||
return node, True
|
||||
|
||||
new_body = []
|
||||
found_target = False
|
||||
|
||||
|
|
|
|||
|
|
@ -3820,3 +3820,118 @@ def test_get_external_base_class_inits_empty_when_no_inheritance(tmp_path: Path)
|
|||
result = get_external_base_class_inits(context, tmp_path)
|
||||
|
||||
assert result.code_strings == []
|
||||
|
||||
|
||||
def test_dependency_classes_kept_in_read_writable_context(tmp_path: Path) -> None:
|
||||
"""Tests that classes used as dependencies (enums, dataclasses) are kept in read-writable context.
|
||||
|
||||
This test verifies that when a function uses classes like enums or dataclasses
|
||||
as types or in match statements, those classes are included in the optimization
|
||||
context, even though they don't contain any target functions.
|
||||
"""
|
||||
code = '''
|
||||
import dataclasses
|
||||
import enum
|
||||
import typing as t
|
||||
|
||||
|
||||
class MessageKind(enum.StrEnum):
|
||||
ASK_FOR_CLIPBOARD_RESPONSE = "ask-for-clipboard-response"
|
||||
BEGIN_EXFILTRATION = "begin-exfiltration"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Message:
|
||||
kind: str
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MessageInAskForClipboardResponse(Message):
|
||||
kind: t.Literal[MessageKind.ASK_FOR_CLIPBOARD_RESPONSE] = MessageKind.ASK_FOR_CLIPBOARD_RESPONSE
|
||||
text: str = ""
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MessageInBeginExfiltration(Message):
|
||||
kind: t.Literal[MessageKind.BEGIN_EXFILTRATION] = MessageKind.BEGIN_EXFILTRATION
|
||||
|
||||
|
||||
MessageIn = (
|
||||
MessageInAskForClipboardResponse
|
||||
| MessageInBeginExfiltration
|
||||
)
|
||||
|
||||
|
||||
def reify_channel_message(data: dict) -> MessageIn:
|
||||
kind = data.get("kind", None)
|
||||
|
||||
match kind:
|
||||
case MessageKind.ASK_FOR_CLIPBOARD_RESPONSE:
|
||||
text = data.get("text") or ""
|
||||
return MessageInAskForClipboardResponse(text=text)
|
||||
case MessageKind.BEGIN_EXFILTRATION:
|
||||
return MessageInBeginExfiltration()
|
||||
case _:
|
||||
raise ValueError(f"Unknown message kind: '{kind}'")
|
||||
'''
|
||||
code_path = tmp_path / "message.py"
|
||||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
func_to_optimize = FunctionToOptimize(
|
||||
function_name="reify_channel_message",
|
||||
file_path=code_path,
|
||||
parents=[],
|
||||
)
|
||||
|
||||
code_ctx = get_code_optimization_context(
|
||||
function_to_optimize=func_to_optimize,
|
||||
project_root_path=tmp_path,
|
||||
)
|
||||
|
||||
expected_read_writable = """
|
||||
```python:message.py
|
||||
import dataclasses
|
||||
import enum
|
||||
import typing as t
|
||||
|
||||
class MessageKind(enum.StrEnum):
|
||||
ASK_FOR_CLIPBOARD_RESPONSE = "ask-for-clipboard-response"
|
||||
BEGIN_EXFILTRATION = "begin-exfiltration"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Message:
|
||||
kind: str
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MessageInAskForClipboardResponse(Message):
|
||||
kind: t.Literal[MessageKind.ASK_FOR_CLIPBOARD_RESPONSE] = MessageKind.ASK_FOR_CLIPBOARD_RESPONSE
|
||||
text: str = ""
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MessageInBeginExfiltration(Message):
|
||||
kind: t.Literal[MessageKind.BEGIN_EXFILTRATION] = MessageKind.BEGIN_EXFILTRATION
|
||||
|
||||
|
||||
MessageIn = (
|
||||
MessageInAskForClipboardResponse
|
||||
| MessageInBeginExfiltration
|
||||
)
|
||||
|
||||
|
||||
def reify_channel_message(data: dict) -> MessageIn:
|
||||
kind = data.get("kind", None)
|
||||
|
||||
match kind:
|
||||
case MessageKind.ASK_FOR_CLIPBOARD_RESPONSE:
|
||||
text = data.get("text") or ""
|
||||
return MessageInAskForClipboardResponse(text=text)
|
||||
case MessageKind.BEGIN_EXFILTRATION:
|
||||
return MessageInBeginExfiltration()
|
||||
case _:
|
||||
raise ValueError(f"Unknown message kind: '{kind}'")
|
||||
```
|
||||
"""
|
||||
assert code_ctx.read_writable_code.markdown.strip() == expected_read_writable.strip()
|
||||
|
|
|
|||
Loading…
Reference in a new issue