# Conflicts:
#   .claude/rules/architecture.md
#   .claude/rules/code-style.md
#   .github/workflows/claude.yml
#   .github/workflows/duplicate-code-detector.yml
#   codeflash/api/aiservice.py
#   codeflash/cli_cmds/console.py
#   codeflash/cli_cmds/logging_config.py
#   codeflash/code_utils/deduplicate_code.py
#   codeflash/discovery/discover_unit_tests.py
#   codeflash/languages/base.py
#   codeflash/languages/code_replacer.py
#   codeflash/languages/javascript/mocha_runner.py
#   codeflash/languages/javascript/support.py
#   codeflash/languages/python/support.py
#   codeflash/optimization/function_optimizer.py
#   codeflash/verification/parse_test_output.py
#   codeflash/verification/verification_utils.py
#   codeflash/verification/verifier.py
#   packages/codeflash/package-lock.json
#   packages/codeflash/package.json
#   tests/languages/javascript/test_support_dispatch.py
#   tests/test_codeflash_capture.py
#   tests/test_languages/test_javascript_test_runner.py
#   tests/test_multi_file_code_replacement.py
#
# File view metadata: 322 lines, 7.9 KiB, Python
from textwrap import dedent

import pytest

from codeflash.languages.python.context.code_context_extractor import parse_code_and_prune_cst
from codeflash.models.models import CodeContextType


def test_simple_function() -> None:
    """A module containing only the target function is returned unchanged by READ_WRITABLE pruning."""
    code = """
    def target_function():
        x = 1
        y = 2
        return x + y
    """
    result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"}).code

    expected = dedent("""
    def target_function():
        x = 1
        y = 2
        return x + y
    """)
    assert result.strip() == expected.strip()


def test_class_method() -> None:
    """A targeted method is returned together with its enclosing class header."""
    code = """
    class MyClass:
        def target_function(self):
            x = 1
            y = 2
            return x + y
    """
    result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_function"}).code

    expected = dedent("""
    class MyClass:
        def target_function(self):
            x = 1
            y = 2
            return x + y
    """)
    assert result.strip() == expected.strip()


def test_class_with_attributes() -> None:
    """Annotated class attributes and non-target methods are pruned.

    Only the class header and the targeted method are expected in the output.
    """
    code = """
    class MyClass:
        x: int = 1
        y: str = "hello"

        def target_method(self):
            return self.x + 42

        def other_method(self):
            print("this should be excluded")
    """
    result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"}).code

    expected = dedent("""
    class MyClass:

        def target_method(self):
            return self.x + 42
    """)
    assert result.strip() == expected.strip()


def test_basic_class_structure() -> None:
    """Test that nested classes are ignored for target function search.

    Only ``Outer.target_method`` is expected to survive. The class attribute
    ``x`` and the nested ``Inner`` class, including its method, are pruned.
    """
    code = """
    class Outer:
        x = 1
        def target_method(self):
            return 42

        class Inner:
            y = 2
            def not_findable(self):
                return 42
    """
    result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"Outer.target_method"}).code

    expected = dedent("""
    class Outer:
        def target_method(self):
            return 42
    """)
    assert result.strip() == expected.strip()


def test_top_level_targets() -> None:
    """A top-level target function is extracted alone; the unrelated class is dropped."""
    code = """
    class OuterClass:
        x = 1
        def method1(self):
            return self.x

    def target_function():
        return 42
    """
    result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"}).code

    expected = dedent("""
    def target_function():
        return 42
    """)
    assert result.strip() == expected.strip()


def test_multiple_top_level_classes() -> None:
    """Targets spread across several classes keep only those classes.

    ``ClassA`` and ``ClassC`` are targeted and kept; ``ClassB`` is expected to be dropped.
    """
    code = """
    class ClassA:
        def process(self):
            return "A"

    class ClassB:
        def process(self):
            return "B"

    class ClassC:
        def process(self):
            return "C"
    """
    result = parse_code_and_prune_cst(
        dedent(code), CodeContextType.READ_WRITABLE, {"ClassA.process", "ClassC.process"}
    ).code

    expected = dedent("""
    class ClassA:
        def process(self):
            return "A"

    class ClassC:
        def process(self):
            return "C"
    """)
    assert result.strip() == expected.strip()


def test_try_except_structure() -> None:
    """A target class defined inside a ``try`` block is found.

    The expected output keeps the whole try/except statement, including
    ``ErrorClass`` in the ``except`` handler.
    """
    code = """
    try:
        class TargetClass:
            def target_method(self):
                return 42
    except ValueError:
        class ErrorClass:
            def handle_error(self):
                print("error")
    """
    result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"TargetClass.target_method"}).code

    expected = dedent("""
    try:
        class TargetClass:
            def target_method(self):
                return 42
    except ValueError:
        class ErrorClass:
            def handle_error(self):
                print("error")
    """)
    assert result.strip() == expected.strip()


def test_init_method() -> None:
    """``__init__`` is retained next to the target method; ``other_method`` is pruned."""
    code = """
    class MyClass:
        def __init__(self):
            self.x = 1

        def other_method(self):
            return "other"

        def target_method(self):
            return f"Value: {self.x}"
    """
    result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"}).code

    expected = dedent("""
    class MyClass:
        def __init__(self):
            self.x = 1

        def target_method(self):
            return f"Value: {self.x}"
    """)
    assert result.strip() == expected.strip()


def test_dunder_method() -> None:
    """``__repr__`` and ``other_method`` are pruned; only the target method is kept."""
    code = """
    class MyClass:
        def __repr__(self):
            return "MyClass"

        def other_method(self):
            return "other"

        def target_method(self):
            return f"Value: {self.x}"
    """
    result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"}).code

    expected = dedent("""
    class MyClass:

        def target_method(self):
            return f"Value: {self.x}"
    """)
    assert result.strip() == expected.strip()


def test_no_targets_found() -> None:
    """A method of a nested class given as the target does not raise.

    The target is ``MyClass.Inner.target``. The expected output is the input
    unchanged: ``MyClass`` with both ``method`` and the nested ``Inner`` class.
    """
    code = """
    class MyClass:
        def method(self):
            pass

        class Inner:
            def target(self):
                pass
    """
    result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"}).code
    expected = dedent("""
    class MyClass:
        def method(self):
            pass

        class Inner:
            def target(self):
                pass
    """)
    assert result.strip() == expected.strip()


def test_no_targets_found_raises_for_nonexistent() -> None:
    """Test that ValueError is raised when the target function doesn't exist at all."""
    code = """
    class MyClass:
        def method(self):
            pass
    """
    # ``NonExistent`` is not defined anywhere in ``code``, so no target can be matched.
    with pytest.raises(ValueError, match="No target functions found in the provided code"):
        parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"NonExistent.target"})


def test_module_var() -> None:
    """Only the target function remains after READ_WRITABLE pruning.

    The module-level ``if``/``else`` statement (including the function defined
    in its ``else`` branch) and the other top-level ``some_function`` are
    expected to be pruned.
    """
    code = """
    def target_function(self) -> None:
        var2 = "test"

    if y:
        x = 5
    else:
        z = 10
        def some_function():
            print("wow")

    def some_function():
        print("wow")
    """

    expected = """
    def target_function(self) -> None:
        var2 = "test"
    """

    output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"}).code
    assert dedent(expected).strip() == output.strip()


def test_comment_between_imports_and_variable_preserves_position() -> None:
    """A comment block above a module-level constant stays attached to it.

    The expected READ_WRITABLE output for target ``parse`` omits the import
    lines. It keeps the two ``# NOTE`` comment lines directly above
    ``SOME_RE``, followed by ``Item`` and ``parse`` in their original order.
    """
    code = """
    from __future__ import annotations

    import re
    from dataclasses import dataclass, field

    # NOTE: This comment documents the constant below.
    # It should stay right above SOME_RE, not jump to the top of the file.
    SOME_RE = re.compile(r"^pattern", re.MULTILINE)


    @dataclass(slots=True)
    class Item:
        name: str
        value: int
        children: list[Item] = field(default_factory=list)


    def parse(text: str) -> list[Item]:
        root = Item(name="root", value=0)
        for m in SOME_RE.finditer(text):
            root.children.append(Item(name=m.group(), value=1))
        return root.children
    """

    expected = """
    # NOTE: This comment documents the constant below.
    # It should stay right above SOME_RE, not jump to the top of the file.
    SOME_RE = re.compile(r"^pattern", re.MULTILINE)


    @dataclass(slots=True)
    class Item:
        name: str
        value: int
        children: list[Item] = field(default_factory=list)


    def parse(text: str) -> list[Item]:
        root = Item(name="root", value=0)
        for m in SOME_RE.finditer(text):
            root.children.append(Item(name=m.group(), value=1))
        return root.children
    """

    result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"parse"}).code
    assert result.strip() == dedent(expected).strip()