fix: include dependency classes in read-writable optimization context

Classes used as dependencies (enums, dataclasses, types) were being
excluded from the optimization context even when marked as used by
the target function. This caused NameError when the LLM used these
types in generated optimizations.
This commit is contained in:
Kevin Turcios 2026-01-24 01:19:22 -05:00
parent 34de67681e
commit 6b3b10e7fa
2 changed files with 135 additions and 1 deletions

View file

@ -1016,10 +1016,29 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
# Do not recurse into nested classes
if prefix:
return None, False
class_name = node.name.value
# Assuming always an IndentedBlock
if not isinstance(node.body, cst.IndentedBlock):
raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004
class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
class_prefix = f"{prefix}.{class_name}" if prefix else class_name
# Check if this class contains any target functions
has_target_functions = any(
isinstance(stmt, cst.FunctionDef) and f"{class_prefix}.{stmt.name.value}" in target_functions
for stmt in node.body.body
)
# If the class is used as a dependency (not containing target functions), keep it entirely
# This handles cases like enums, dataclasses, and other types used by the target function
if (
not has_target_functions
and class_name in defs_with_usages
and defs_with_usages[class_name].used_by_qualified_function
):
return node, True
new_body = []
found_target = False

View file

@ -3820,3 +3820,118 @@ def test_get_external_base_class_inits_empty_when_no_inheritance(tmp_path: Path)
result = get_external_base_class_inits(context, tmp_path)
assert result.code_strings == []
def test_dependency_classes_kept_in_read_writable_context(tmp_path: Path) -> None:
"""Tests that classes used as dependencies (enums, dataclasses) are kept in read-writable context.
This test verifies that when a function uses classes like enums or dataclasses
as types or in match statements, those classes are included in the optimization
context, even though they don't contain any target functions.
"""
code = '''
import dataclasses
import enum
import typing as t
class MessageKind(enum.StrEnum):
ASK_FOR_CLIPBOARD_RESPONSE = "ask-for-clipboard-response"
BEGIN_EXFILTRATION = "begin-exfiltration"
@dataclasses.dataclass
class Message:
kind: str
@dataclasses.dataclass
class MessageInAskForClipboardResponse(Message):
kind: t.Literal[MessageKind.ASK_FOR_CLIPBOARD_RESPONSE] = MessageKind.ASK_FOR_CLIPBOARD_RESPONSE
text: str = ""
@dataclasses.dataclass
class MessageInBeginExfiltration(Message):
kind: t.Literal[MessageKind.BEGIN_EXFILTRATION] = MessageKind.BEGIN_EXFILTRATION
MessageIn = (
MessageInAskForClipboardResponse
| MessageInBeginExfiltration
)
def reify_channel_message(data: dict) -> MessageIn:
kind = data.get("kind", None)
match kind:
case MessageKind.ASK_FOR_CLIPBOARD_RESPONSE:
text = data.get("text") or ""
return MessageInAskForClipboardResponse(text=text)
case MessageKind.BEGIN_EXFILTRATION:
return MessageInBeginExfiltration()
case _:
raise ValueError(f"Unknown message kind: '{kind}'")
'''
code_path = tmp_path / "message.py"
code_path.write_text(code, encoding="utf-8")
func_to_optimize = FunctionToOptimize(
function_name="reify_channel_message",
file_path=code_path,
parents=[],
)
code_ctx = get_code_optimization_context(
function_to_optimize=func_to_optimize,
project_root_path=tmp_path,
)
expected_read_writable = """
```python:message.py
import dataclasses
import enum
import typing as t
class MessageKind(enum.StrEnum):
ASK_FOR_CLIPBOARD_RESPONSE = "ask-for-clipboard-response"
BEGIN_EXFILTRATION = "begin-exfiltration"
@dataclasses.dataclass
class Message:
kind: str
@dataclasses.dataclass
class MessageInAskForClipboardResponse(Message):
kind: t.Literal[MessageKind.ASK_FOR_CLIPBOARD_RESPONSE] = MessageKind.ASK_FOR_CLIPBOARD_RESPONSE
text: str = ""
@dataclasses.dataclass
class MessageInBeginExfiltration(Message):
kind: t.Literal[MessageKind.BEGIN_EXFILTRATION] = MessageKind.BEGIN_EXFILTRATION
MessageIn = (
MessageInAskForClipboardResponse
| MessageInBeginExfiltration
)
def reify_channel_message(data: dict) -> MessageIn:
kind = data.get("kind", None)
match kind:
case MessageKind.ASK_FOR_CLIPBOARD_RESPONSE:
text = data.get("text") or ""
return MessageInAskForClipboardResponse(text=text)
case MessageKind.BEGIN_EXFILTRATION:
return MessageInBeginExfiltration()
case _:
raise ValueError(f"Unknown message kind: '{kind}'")
```
"""
assert code_ctx.read_writable_code.markdown.strip() == expected_read_writable.strip()