codeflash/tests/test_get_read_writable_code.py
Kevin Turcios eceac13fc3 Merge remote-tracking branch 'origin/main' into omni-java
# Conflicts:
#	.claude/rules/architecture.md
#	.claude/rules/code-style.md
#	.github/workflows/claude.yml
#	.github/workflows/duplicate-code-detector.yml
#	codeflash/api/aiservice.py
#	codeflash/cli_cmds/console.py
#	codeflash/cli_cmds/logging_config.py
#	codeflash/code_utils/deduplicate_code.py
#	codeflash/discovery/discover_unit_tests.py
#	codeflash/languages/base.py
#	codeflash/languages/code_replacer.py
#	codeflash/languages/javascript/mocha_runner.py
#	codeflash/languages/javascript/support.py
#	codeflash/languages/python/support.py
#	codeflash/optimization/function_optimizer.py
#	codeflash/verification/parse_test_output.py
#	codeflash/verification/verification_utils.py
#	codeflash/verification/verifier.py
#	packages/codeflash/package-lock.json
#	packages/codeflash/package.json
#	tests/languages/javascript/test_support_dispatch.py
#	tests/test_codeflash_capture.py
#	tests/test_languages/test_javascript_test_runner.py
#	tests/test_multi_file_code_replacement.py
2026-03-04 01:52:32 -05:00

322 lines
7.9 KiB
Python

from textwrap import dedent
import pytest
from codeflash.languages.python.context.code_context_extractor import parse_code_and_prune_cst
from codeflash.models.models import CodeContextType
def test_simple_function() -> None:
    """A lone top-level target function is returned exactly as written."""
    source = dedent("""
    def target_function():
        x = 1
        y = 2
        return x + y
    """)
    want = dedent("""
    def target_function():
        x = 1
        y = 2
        return x + y
    """)

    pruned = parse_code_and_prune_cst(source, CodeContextType.READ_WRITABLE, {"target_function"})

    assert pruned.code.strip() == want.strip()
def test_class_method() -> None:
    """A ``Class.method`` target keeps its class header and the full method body."""
    targets = {"MyClass.target_function"}
    source = dedent("""
    class MyClass:
        def target_function(self):
            x = 1
            y = 2
            return x + y
    """)

    extracted = parse_code_and_prune_cst(source, CodeContextType.READ_WRITABLE, targets).code

    assert extracted.strip() == dedent("""
    class MyClass:
        def target_function(self):
            x = 1
            y = 2
            return x + y
    """).strip()
def test_class_with_attributes() -> None:
    """Class attributes and non-target methods are pruned; only the target method remains."""
    original_source = """
    class MyClass:
        x: int = 1
        y: str = "hello"
        def target_method(self):
            return self.x + 42
        def other_method(self):
            print("this should be excluded")
    """
    expected_source = """
    class MyClass:
        def target_method(self):
            return self.x + 42
    """

    pruned = parse_code_and_prune_cst(
        dedent(original_source),
        CodeContextType.READ_WRITABLE,
        {"MyClass.target_method"},
    )

    assert pruned.code.strip() == dedent(expected_source).strip()
def test_basic_class_structure() -> None:
    """Inside ``Outer`` only the target method is kept.

    The class attribute ``x`` and the nested ``Inner`` class (whose method is not
    a target) are both pruned from the output.
    """
    outer_source = """
    class Outer:
        x = 1
        def target_method(self):
            return 42
        class Inner:
            y = 2
            def not_findable(self):
                return 42
    """
    pruned_code = parse_code_and_prune_cst(
        dedent(outer_source), CodeContextType.READ_WRITABLE, {"Outer.target_method"}
    ).code

    expected_code = """
    class Outer:
        def target_method(self):
            return 42
    """
    assert pruned_code.strip() == dedent(expected_code).strip()
def test_top_level_targets() -> None:
    """A top-level function target is returned alone; the unrelated class before it is dropped."""
    module_source = dedent("""
    class OuterClass:
        x = 1
        def method1(self):
            return self.x
    def target_function():
        return 42
    """)
    wanted = dedent("""
    def target_function():
        return 42
    """)

    got = parse_code_and_prune_cst(module_source, CodeContextType.READ_WRITABLE, {"target_function"}).code

    assert got.strip() == wanted.strip()
def test_multiple_top_level_classes() -> None:
    """Targets spread over several classes keep exactly the classes that contain them.

    ``ClassA`` and ``ClassC`` hold targets and appear in source order; ``ClassB``
    has no target and is dropped.
    """
    wanted_targets = {"ClassA.process", "ClassC.process"}
    source = dedent("""
    class ClassA:
        def process(self):
            return "A"
    class ClassB:
        def process(self):
            return "B"
    class ClassC:
        def process(self):
            return "C"
    """)
    expected = dedent("""
    class ClassA:
        def process(self):
            return "A"
    class ClassC:
        def process(self):
            return "C"
    """)

    context = parse_code_and_prune_cst(source, CodeContextType.READ_WRITABLE, wanted_targets)

    assert context.code.strip() == expected.strip()
def test_try_except_structure() -> None:
    """A target class defined inside ``try`` keeps the whole ``try``/``except`` statement.

    The handler branch, with the unrelated ``ErrorClass``, is expected back verbatim,
    so the output equals the input.
    """
    source = dedent("""
    try:
        class TargetClass:
            def target_method(self):
                return 42
    except ValueError:
        class ErrorClass:
            def handle_error(self):
                print("error")
    """)

    output = parse_code_and_prune_cst(source, CodeContextType.READ_WRITABLE, {"TargetClass.target_method"}).code

    # Nothing is pruned here, so the expectation is the input itself.
    assert output.strip() == source.strip()
def test_init_method() -> None:
    """``__init__`` is retained next to the target method; ``other_method`` is pruned."""
    before = """
    class MyClass:
        def __init__(self):
            self.x = 1
        def other_method(self):
            return "other"
        def target_method(self):
            return f"Value: {self.x}"
    """
    after = """
    class MyClass:
        def __init__(self):
            self.x = 1
        def target_method(self):
            return f"Value: {self.x}"
    """

    pruned = parse_code_and_prune_cst(dedent(before), CodeContextType.READ_WRITABLE, {"MyClass.target_method"})

    assert pruned.code.strip() == dedent(after).strip()
def test_dunder_method() -> None:
    """``__repr__`` is not retained: only the target method survives in the class."""
    targets = {"MyClass.target_method"}
    source = dedent("""
    class MyClass:
        def __repr__(self):
            return "MyClass"
        def other_method(self):
            return "other"
        def target_method(self):
            return f"Value: {self.x}"
    """)

    result_code = parse_code_and_prune_cst(source, CodeContextType.READ_WRITABLE, targets).code

    assert result_code.strip() == dedent("""
    class MyClass:
        def target_method(self):
            return f"Value: {self.x}"
    """).strip()
def test_no_targets_found() -> None:
    """Target qualified through a nested class (``MyClass.Inner.target``).

    Despite the test name, no error is expected: ``MyClass`` comes back in full,
    including the non-target ``method``, so the output equals the input.
    """
    source = dedent("""
    class MyClass:
        def method(self):
            pass
        class Inner:
            def target(self):
                pass
    """)

    output = parse_code_and_prune_cst(source, CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"}).code

    assert output.strip() == source.strip()
def test_no_targets_found_raises_for_nonexistent() -> None:
    """A qualified name that matches nothing in the module raises ``ValueError``."""
    source = dedent("""
    class MyClass:
        def method(self):
            pass
    """)
    missing = {"NonExistent.target"}

    with pytest.raises(ValueError, match="No target functions found in the provided code"):
        parse_code_and_prune_cst(source, CodeContextType.READ_WRITABLE, missing)
def test_module_var() -> None:
    """Module-level statements and non-target functions are pruned from the READ_WRITABLE context.

    Only ``target_function`` survives; the module-level ``if``/``else`` block and
    both ``some_function`` definitions are absent from the expected output.
    """
    # NOTE(review): `some_function` is defined twice in the input; presumably this
    # checks that duplicate names do not confuse pruning — confirm intent.
    code = """
    def target_function(self) -> None:
        var2 = "test"
    if y:
        x = 5
    else:
        z = 10
    def some_function():
        print("wow")
    def some_function():
        print("wow")
    """
    result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"}).code
    expected = dedent("""
    def target_function(self) -> None:
        var2 = "test"
    """)
    assert result.strip() == expected.strip()
def test_comment_between_imports_and_variable_preserves_position() -> None:
    """A comment block above a module-level constant stays directly above that constant.

    The imports are absent from the READ_WRITABLE output. The ``# NOTE`` comment
    must remain attached to ``SOME_RE`` instead of moving to the top of the file.
    ``SOME_RE`` and the ``Item`` dataclass are expected alongside the ``parse`` target.
    """
    module_source = """
    from __future__ import annotations
    import re
    from dataclasses import dataclass, field
    # NOTE: This comment documents the constant below.
    # It should stay right above SOME_RE, not jump to the top of the file.
    SOME_RE = re.compile(r"^pattern", re.MULTILINE)
    @dataclass(slots=True)
    class Item:
        name: str
        value: int
        children: list[Item] = field(default_factory=list)
    def parse(text: str) -> list[Item]:
        root = Item(name="root", value=0)
        for m in SOME_RE.finditer(text):
            root.children.append(Item(name=m.group(), value=1))
        return root.children
    """
    expected_source = """
    # NOTE: This comment documents the constant below.
    # It should stay right above SOME_RE, not jump to the top of the file.
    SOME_RE = re.compile(r"^pattern", re.MULTILINE)
    @dataclass(slots=True)
    class Item:
        name: str
        value: int
        children: list[Item] = field(default_factory=list)
    def parse(text: str) -> list[Item]:
        root = Item(name="root", value=0)
        for m in SOME_RE.finditer(text):
            root.children.append(Item(name=m.group(), value=1))
        return root.children
    """

    pruned = parse_code_and_prune_cst(dedent(module_source), CodeContextType.READ_WRITABLE, {"parse"})

    assert pruned.code.strip() == dedent(expected_source).strip()