2024-12-17 19:30:07 +00:00
|
|
|
from textwrap import dedent
|
|
|
|
|
|
|
|
|
|
import pytest
|
2026-01-29 09:39:48 +00:00
|
|
|
|
2026-02-16 19:49:04 +00:00
|
|
|
from codeflash.languages.python.context.code_context_extractor import parse_code_and_prune_cst
|
2025-03-06 00:40:23 +00:00
|
|
|
from codeflash.models.models import CodeContextType
|
2024-12-17 19:30:07 +00:00
|
|
|
|
|
|
|
|
|
2024-12-18 18:31:47 +00:00
|
|
|
def test_simple_function() -> None:
    """A lone top-level target function is kept verbatim."""
    code = """
    def target_function():
        x = 1
        y = 2
        return x + y
    """
    result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"}).code

    expected = dedent("""
    def target_function():
        x = 1
        y = 2
        return x + y
    """)
    assert result.strip() == expected.strip()
|
|
|
|
|
|
|
|
|
|
|
2024-12-18 18:31:47 +00:00
|
|
|
def test_class_method() -> None:
    """A targeted method is kept together with its enclosing class header."""
    code = """
    class MyClass:
        def target_function(self):
            x = 1
            y = 2
            return x + y
    """
    result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_function"}).code

    expected = dedent("""
    class MyClass:
        def target_function(self):
            x = 1
            y = 2
            return x + y
    """)
    assert result.strip() == expected.strip()
|
|
|
|
|
|
|
|
|
|
|
2024-12-18 18:31:47 +00:00
|
|
|
def test_class_with_attributes() -> None:
    """Only the targeted method survives; class attributes and sibling methods are pruned."""
    code = """
    class MyClass:
        x: int = 1
        y: str = "hello"

        def target_method(self):
            return self.x + 42

        def other_method(self):
            print("this should be excluded")
    """
    result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"}).code

    # The blank line before ``def target_method`` is retained in the pruned output.
    expected = dedent("""
    class MyClass:

        def target_method(self):
            return self.x + 42
    """)
    assert result.strip() == expected.strip()
|
|
|
|
|
|
|
|
|
|
|
2024-12-18 18:31:47 +00:00
|
|
|
def test_basic_class_structure() -> None:
    """Targeting an outer-class method prunes the class attribute and the nested ``Inner`` class."""
    code = """
    class Outer:
        x = 1
        def target_method(self):
            return 42

        class Inner:
            y = 2
            def not_findable(self):
                return 42
    """
    result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"Outer.target_method"}).code

    expected = dedent("""
    class Outer:
        def target_method(self):
            return 42
    """)
    assert result.strip() == expected.strip()
|
|
|
|
|
|
|
|
|
|
|
2024-12-18 18:31:47 +00:00
|
|
|
def test_top_level_targets() -> None:
    """Targeting a top-level function drops the unrelated class defined before it."""
    code = """
    class OuterClass:
        x = 1
        def method1(self):
            return self.x

    def target_function():
        return 42
    """
    result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"}).code

    expected = dedent("""
    def target_function():
        return 42
    """)
    assert result.strip() == expected.strip()
|
|
|
|
|
|
|
|
|
|
|
2024-12-18 18:31:47 +00:00
|
|
|
def test_multiple_top_level_classes() -> None:
    """Targets in several classes keep those classes and drop the untargeted ``ClassB``."""
    code = """
    class ClassA:
        def process(self):
            return "A"

    class ClassB:
        def process(self):
            return "B"

    class ClassC:
        def process(self):
            return "C"
    """
    result = parse_code_and_prune_cst(
        dedent(code), CodeContextType.READ_WRITABLE, {"ClassA.process", "ClassC.process"}
    ).code

    expected = dedent("""
    class ClassA:
        def process(self):
            return "A"

    class ClassC:
        def process(self):
            return "C"
    """)
    assert result.strip() == expected.strip()
|
|
|
|
|
|
|
|
|
|
|
2024-12-18 18:31:47 +00:00
|
|
|
def test_try_except_structure() -> None:
    """A target defined inside a ``try`` body keeps the whole try/except statement, except branch included."""
    code = """
    try:
        class TargetClass:
            def target_method(self):
                return 42
    except ValueError:
        class ErrorClass:
            def handle_error(self):
                print("error")
    """
    result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"TargetClass.target_method"}).code

    expected = dedent("""
    try:
        class TargetClass:
            def target_method(self):
                return 42
    except ValueError:
        class ErrorClass:
            def handle_error(self):
                print("error")
    """)
    assert result.strip() == expected.strip()
|
|
|
|
|
|
|
|
|
|
|
2025-01-14 01:01:52 +00:00
|
|
|
def test_init_method() -> None:
    """``__init__`` is kept alongside the target method while other methods are pruned."""
    code = """
    class MyClass:
        def __init__(self):
            self.x = 1

        def other_method(self):
            return "other"

        def target_method(self):
            return f"Value: {self.x}"
    """
    result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"}).code

    expected = dedent("""
    class MyClass:
        def __init__(self):
            self.x = 1

        def target_method(self):
            return f"Value: {self.x}"
    """)
    assert result.strip() == expected.strip()
|
|
|
|
|
|
2026-01-29 09:39:48 +00:00
|
|
|
|
2025-01-14 01:01:52 +00:00
|
|
|
def test_dunder_method() -> None:
    """Non-target methods, including the dunder ``__repr__``, are pruned around the target method."""
    code = """
    class MyClass:
        def __repr__(self):
            return "MyClass"

        def other_method(self):
            return "other"

        def target_method(self):
            return f"Value: {self.x}"
    """
    result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"}).code

    expected = dedent("""
    class MyClass:

        def target_method(self):
            return f"Value: {self.x}"
    """)
    assert result.strip() == expected.strip()
|
2024-12-17 19:30:07 +00:00
|
|
|
|
2026-01-29 09:39:48 +00:00
|
|
|
|
2024-12-18 18:31:47 +00:00
|
|
|
def test_no_targets_found() -> None:
    """A qualified name pointing into a nested class keeps the enclosing class unchanged.

    Despite the test name, no error is expected here: the pruned output for
    ``MyClass.Inner.target`` is asserted to equal the full input class.
    """
    code = """
    class MyClass:
        def method(self):
            pass

        class Inner:
            def target(self):
                pass
    """
    result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"}).code

    expected = dedent("""
    class MyClass:
        def method(self):
            pass

        class Inner:
            def target(self):
                pass
    """)
    assert result.strip() == expected.strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_no_targets_found_raises_for_nonexistent() -> None:
    """Test that ValueError is raised when the target function doesn't exist at all."""
    code = """
    class MyClass:
        def method(self):
            pass
    """
    # ``match`` is a regex searched against the message; this one has no metacharacters.
    with pytest.raises(ValueError, match="No target functions found in the provided code"):
        parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"NonExistent.target"})
|
2024-12-18 02:27:59 +00:00
|
|
|
|
|
|
|
|
|
2024-12-18 18:31:47 +00:00
|
|
|
def test_module_var() -> None:
    """Module-level statements and functions other than the target are pruned."""
    code = """
    def target_function(self) -> None:
        var2 = "test"

    if y:
        x = 5
    else:
        z = 10
        def some_function():
            print("wow")

    def some_function():
        print("wow")
    """

    expected = """
    def target_function(self) -> None:
        var2 = "test"
    """

    output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"}).code
    assert dedent(expected).strip() == output.strip()
|
2026-02-23 06:12:49 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_comment_between_imports_and_variable_preserves_position() -> None:
    """Comments above a module-level constant stay directly above it in the pruned output.

    Imports are absent from the read-writable output, while the module-level names
    the target ``parse`` uses (``SOME_RE`` and ``Item``) are kept.
    """
    code = """
    from __future__ import annotations

    import re
    from dataclasses import dataclass, field

    # NOTE: This comment documents the constant below.
    # It should stay right above SOME_RE, not jump to the top of the file.
    SOME_RE = re.compile(r"^pattern", re.MULTILINE)


    @dataclass(slots=True)
    class Item:
        name: str
        value: int
        children: list[Item] = field(default_factory=list)


    def parse(text: str) -> list[Item]:
        root = Item(name="root", value=0)
        for m in SOME_RE.finditer(text):
            root.children.append(Item(name=m.group(), value=1))
        return root.children
    """

    expected = """
    # NOTE: This comment documents the constant below.
    # It should stay right above SOME_RE, not jump to the top of the file.
    SOME_RE = re.compile(r"^pattern", re.MULTILINE)


    @dataclass(slots=True)
    class Item:
        name: str
        value: int
        children: list[Item] = field(default_factory=list)


    def parse(text: str) -> list[Item]:
        root = Item(name="root", value=0)
        for m in SOME_RE.finditer(text):
            root.children.append(Item(name=m.group(), value=1))
        return root.children
    """

    result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"parse"}).code
    assert result.strip() == dedent(expected).strip()
|