mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
test: sync test files from main (safe, main-only changes)
34 test files updated with main's refactored tests for new language support protocol, JS/TS improvements, and code context extraction.
This commit is contained in:
parent
2299d26ae5
commit
19bd6e4bad
34 changed files with 3105 additions and 1168 deletions
|
|
@ -493,3 +493,37 @@ def my_function():
|
|||
return helper
|
||||
"""
|
||||
assert result == expected_result
|
||||
|
||||
|
||||
def test_module_input_preserves_comment_position_after_imports() -> None:
|
||||
from codeflash.languages.python.context.code_context_extractor import parse_code_and_prune_cst
|
||||
from codeflash.models.models import CodeContextType
|
||||
|
||||
src_code = """from __future__ import annotations
|
||||
import re
|
||||
|
||||
# Comment about PATTERN.
|
||||
PATTERN = re.compile(r"test")
|
||||
|
||||
def parse():
|
||||
return PATTERN.findall("")
|
||||
"""
|
||||
pruned_module = parse_code_and_prune_cst(src_code, CodeContextType.READ_WRITABLE, {"parse"})
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
project_root = Path(tmpdir)
|
||||
file_path = project_root / "mod.py"
|
||||
file_path.write_text(src_code)
|
||||
|
||||
result = add_needed_imports_from_module(src_code, pruned_module, file_path, file_path, project_root)
|
||||
|
||||
expected = """from __future__ import annotations
|
||||
import re
|
||||
|
||||
# Comment about PATTERN.
|
||||
PATTERN = re.compile(r"test")
|
||||
|
||||
def parse():
|
||||
return PATTERN.findall("")
|
||||
"""
|
||||
assert result == expected
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -1,4 +1,4 @@
|
|||
from codeflash.code_utils.deduplicate_code import are_codes_duplicate, normalize_code
|
||||
from codeflash.languages.python.normalizer import normalize_python_code as normalize_code
|
||||
|
||||
|
||||
def test_deduplicate1():
|
||||
|
|
@ -23,7 +23,7 @@ def compute_sum(numbers):
|
|||
"""
|
||||
|
||||
assert normalize_code(code1) == normalize_code(code2)
|
||||
assert are_codes_duplicate(code1, code2)
|
||||
assert normalize_code(code1) == normalize_code(code2)
|
||||
|
||||
# Example 3: Same function and parameter names, different local variables (should match)
|
||||
code3 = """
|
||||
|
|
@ -43,7 +43,7 @@ def calculate_sum(numbers):
|
|||
"""
|
||||
|
||||
assert normalize_code(code3) == normalize_code(code4)
|
||||
assert are_codes_duplicate(code3, code4)
|
||||
assert normalize_code(code3) == normalize_code(code4)
|
||||
|
||||
# Example 4: Nested functions and classes (preserving names)
|
||||
code5 = """
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ from codeflash.languages.python.static_analysis.code_extractor import delete___f
|
|||
from codeflash.languages.python.static_analysis.code_replacer import (
|
||||
AddRequestArgument,
|
||||
AutouseFixtureModifier,
|
||||
OptimFunctionCollector,
|
||||
PytestMarkAdder,
|
||||
is_zero_diff,
|
||||
replace_functions_and_add_imports,
|
||||
|
|
@ -19,7 +18,7 @@ from codeflash.languages.python.static_analysis.code_replacer import (
|
|||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent, FunctionSource
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
|
||||
|
|
@ -55,7 +54,7 @@ def sorter(arr):
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
|
||||
original_helper_code: dict[Path, str] = {}
|
||||
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
||||
|
|
@ -808,6 +807,7 @@ def test_code_replacement10() -> None:
|
|||
get_code_output = """# file: test_code_replacement.py
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class HelperClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
|
@ -834,7 +834,7 @@ class MainClass:
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config)
|
||||
code_context = func_optimizer.get_code_optimization_context().unwrap()
|
||||
assert code_context.testgen_context.flat.rstrip() == get_code_output.rstrip()
|
||||
|
||||
|
|
@ -1745,7 +1745,7 @@ class NewClass:
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
|
||||
original_helper_code: dict[Path, str] = {}
|
||||
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
||||
|
|
@ -1824,7 +1824,7 @@ a=2
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
|
||||
original_helper_code: dict[Path, str] = {}
|
||||
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
||||
|
|
@ -1904,7 +1904,7 @@ class NewClass:
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
|
||||
original_helper_code: dict[Path, str] = {}
|
||||
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
||||
|
|
@ -1983,7 +1983,7 @@ class NewClass:
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
|
||||
original_helper_code: dict[Path, str] = {}
|
||||
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
||||
|
|
@ -2063,7 +2063,7 @@ class NewClass:
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
|
||||
original_helper_code: dict[Path, str] = {}
|
||||
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
||||
|
|
@ -2153,7 +2153,7 @@ class NewClass:
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
|
||||
original_helper_code: dict[Path, str] = {}
|
||||
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
||||
|
|
@ -3453,7 +3453,7 @@ def hydrate_input_text_actions_with_field_names(
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
|
||||
|
||||
original_helper_code: dict[Path, str] = {}
|
||||
|
|
@ -3476,142 +3476,6 @@ def hydrate_input_text_actions_with_field_names(
|
|||
assert new_code == expected
|
||||
|
||||
|
||||
# OptimFunctionCollector async function tests
|
||||
def test_optim_function_collector_with_async_functions():
|
||||
"""Test OptimFunctionCollector correctly collects async functions."""
|
||||
import libcst as cst
|
||||
|
||||
source_code = """
|
||||
def sync_function():
|
||||
return "sync"
|
||||
|
||||
async def async_function():
|
||||
return "async"
|
||||
|
||||
class TestClass:
|
||||
def sync_method(self):
|
||||
return "sync_method"
|
||||
|
||||
async def async_method(self):
|
||||
return "async_method"
|
||||
"""
|
||||
|
||||
tree = cst.parse_module(source_code)
|
||||
collector = OptimFunctionCollector(
|
||||
function_names={
|
||||
(None, "sync_function"),
|
||||
(None, "async_function"),
|
||||
("TestClass", "sync_method"),
|
||||
("TestClass", "async_method"),
|
||||
},
|
||||
preexisting_objects=None,
|
||||
)
|
||||
tree.visit(collector)
|
||||
|
||||
# Should collect both sync and async functions
|
||||
assert len(collector.modified_functions) == 4
|
||||
assert (None, "sync_function") in collector.modified_functions
|
||||
assert (None, "async_function") in collector.modified_functions
|
||||
assert ("TestClass", "sync_method") in collector.modified_functions
|
||||
assert ("TestClass", "async_method") in collector.modified_functions
|
||||
|
||||
|
||||
def test_optim_function_collector_new_async_functions():
|
||||
"""Test OptimFunctionCollector identifies new async functions not in preexisting objects."""
|
||||
import libcst as cst
|
||||
|
||||
source_code = """
|
||||
def existing_function():
|
||||
return "existing"
|
||||
|
||||
async def new_async_function():
|
||||
return "new_async"
|
||||
|
||||
def new_sync_function():
|
||||
return "new_sync"
|
||||
|
||||
class ExistingClass:
|
||||
async def new_class_async_method(self):
|
||||
return "new_class_async"
|
||||
"""
|
||||
|
||||
# Only existing_function is in preexisting objects
|
||||
preexisting_objects = {("existing_function", ())}
|
||||
|
||||
tree = cst.parse_module(source_code)
|
||||
collector = OptimFunctionCollector(
|
||||
function_names=set(), # Not looking for specific functions
|
||||
preexisting_objects=preexisting_objects,
|
||||
)
|
||||
tree.visit(collector)
|
||||
|
||||
# Should identify new functions (both sync and async)
|
||||
assert len(collector.new_functions) == 2
|
||||
function_names = [func.name.value for func in collector.new_functions]
|
||||
assert "new_async_function" in function_names
|
||||
assert "new_sync_function" in function_names
|
||||
|
||||
# Should identify new class methods
|
||||
assert "ExistingClass" in collector.new_class_functions
|
||||
assert len(collector.new_class_functions["ExistingClass"]) == 1
|
||||
assert collector.new_class_functions["ExistingClass"][0].name.value == "new_class_async_method"
|
||||
|
||||
|
||||
def test_optim_function_collector_mixed_scenarios():
|
||||
"""Test OptimFunctionCollector with complex mix of sync/async functions and classes."""
|
||||
import libcst as cst
|
||||
|
||||
source_code = """
|
||||
# Global functions
|
||||
def global_sync():
|
||||
pass
|
||||
|
||||
async def global_async():
|
||||
pass
|
||||
|
||||
class ParentClass:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def sync_method(self):
|
||||
pass
|
||||
|
||||
async def async_method(self):
|
||||
pass
|
||||
|
||||
class ChildClass:
|
||||
async def child_async_method(self):
|
||||
pass
|
||||
|
||||
def child_sync_method(self):
|
||||
pass
|
||||
"""
|
||||
|
||||
# Looking for specific functions
|
||||
function_names = {
|
||||
(None, "global_sync"),
|
||||
(None, "global_async"),
|
||||
("ParentClass", "sync_method"),
|
||||
("ParentClass", "async_method"),
|
||||
("ChildClass", "child_async_method"),
|
||||
}
|
||||
|
||||
tree = cst.parse_module(source_code)
|
||||
collector = OptimFunctionCollector(function_names=function_names, preexisting_objects=None)
|
||||
tree.visit(collector)
|
||||
|
||||
# Should collect all specified functions (mix of sync and async)
|
||||
assert len(collector.modified_functions) == 5
|
||||
assert (None, "global_sync") in collector.modified_functions
|
||||
assert (None, "global_async") in collector.modified_functions
|
||||
assert ("ParentClass", "sync_method") in collector.modified_functions
|
||||
assert ("ParentClass", "async_method") in collector.modified_functions
|
||||
assert ("ChildClass", "child_async_method") in collector.modified_functions
|
||||
|
||||
# Should collect __init__ method
|
||||
assert "ParentClass" in collector.modified_init_functions
|
||||
|
||||
|
||||
def test_is_zero_diff_async_sleep():
|
||||
original_code = """
|
||||
import time
|
||||
|
|
|
|||
|
|
@ -417,6 +417,312 @@ def test_standard_python_library_objects() -> None:
|
|||
assert not comparator(id1, id3)
|
||||
|
||||
|
||||
def test_itertools_count() -> None:
|
||||
import itertools
|
||||
|
||||
# Equal: same start and step (default step=1)
|
||||
assert comparator(itertools.count(0), itertools.count(0))
|
||||
assert comparator(itertools.count(5), itertools.count(5))
|
||||
assert comparator(itertools.count(0, 1), itertools.count(0, 1))
|
||||
assert comparator(itertools.count(10, 3), itertools.count(10, 3))
|
||||
|
||||
# Equal: negative start and step
|
||||
assert comparator(itertools.count(-5, -2), itertools.count(-5, -2))
|
||||
|
||||
# Equal: float start and step
|
||||
assert comparator(itertools.count(0.5, 0.1), itertools.count(0.5, 0.1))
|
||||
|
||||
# Not equal: different start
|
||||
assert not comparator(itertools.count(0), itertools.count(1))
|
||||
assert not comparator(itertools.count(5), itertools.count(10))
|
||||
|
||||
# Not equal: different step
|
||||
assert not comparator(itertools.count(0, 1), itertools.count(0, 2))
|
||||
assert not comparator(itertools.count(0, 1), itertools.count(0, -1))
|
||||
|
||||
# Not equal: different type
|
||||
assert not comparator(itertools.count(0), 0)
|
||||
assert not comparator(itertools.count(0), [0, 1, 2])
|
||||
|
||||
# Equal after partial consumption (both advanced to the same state)
|
||||
a = itertools.count(0)
|
||||
b = itertools.count(0)
|
||||
next(a)
|
||||
next(b)
|
||||
assert comparator(a, b)
|
||||
|
||||
# Not equal after different consumption
|
||||
a = itertools.count(0)
|
||||
b = itertools.count(0)
|
||||
next(a)
|
||||
assert not comparator(a, b)
|
||||
|
||||
# Works inside containers
|
||||
assert comparator([itertools.count(0)], [itertools.count(0)])
|
||||
assert comparator({"key": itertools.count(5, 2)}, {"key": itertools.count(5, 2)})
|
||||
assert not comparator([itertools.count(0)], [itertools.count(1)])
|
||||
|
||||
|
||||
def test_itertools_repeat() -> None:
|
||||
import itertools
|
||||
|
||||
# Equal: infinite repeat
|
||||
assert comparator(itertools.repeat(5), itertools.repeat(5))
|
||||
assert comparator(itertools.repeat("hello"), itertools.repeat("hello"))
|
||||
|
||||
# Equal: bounded repeat
|
||||
assert comparator(itertools.repeat(5, 3), itertools.repeat(5, 3))
|
||||
assert comparator(itertools.repeat(None, 10), itertools.repeat(None, 10))
|
||||
|
||||
# Not equal: different value
|
||||
assert not comparator(itertools.repeat(5), itertools.repeat(6))
|
||||
assert not comparator(itertools.repeat(5, 3), itertools.repeat(6, 3))
|
||||
|
||||
# Not equal: different count
|
||||
assert not comparator(itertools.repeat(5, 3), itertools.repeat(5, 4))
|
||||
|
||||
# Not equal: bounded vs infinite
|
||||
assert not comparator(itertools.repeat(5), itertools.repeat(5, 3))
|
||||
|
||||
# Not equal: different type
|
||||
assert not comparator(itertools.repeat(5), 5)
|
||||
assert not comparator(itertools.repeat(5), [5])
|
||||
|
||||
# Equal after partial consumption
|
||||
a = itertools.repeat(5, 5)
|
||||
b = itertools.repeat(5, 5)
|
||||
next(a)
|
||||
next(b)
|
||||
assert comparator(a, b)
|
||||
|
||||
# Not equal after different consumption
|
||||
a = itertools.repeat(5, 5)
|
||||
b = itertools.repeat(5, 5)
|
||||
next(a)
|
||||
assert not comparator(a, b)
|
||||
|
||||
# Works inside containers
|
||||
assert comparator([itertools.repeat(5, 3)], [itertools.repeat(5, 3)])
|
||||
assert not comparator([itertools.repeat(5, 3)], [itertools.repeat(5, 4)])
|
||||
|
||||
|
||||
def test_itertools_cycle() -> None:
|
||||
import itertools
|
||||
|
||||
# Equal: same sequence
|
||||
assert comparator(itertools.cycle([1, 2, 3]), itertools.cycle([1, 2, 3]))
|
||||
assert comparator(itertools.cycle("abc"), itertools.cycle("abc"))
|
||||
|
||||
# Not equal: different sequence
|
||||
assert not comparator(itertools.cycle([1, 2, 3]), itertools.cycle([1, 2, 4]))
|
||||
assert not comparator(itertools.cycle([1, 2, 3]), itertools.cycle([1, 2]))
|
||||
|
||||
# Not equal: different type
|
||||
assert not comparator(itertools.cycle([1, 2, 3]), [1, 2, 3])
|
||||
|
||||
# Equal after same partial consumption
|
||||
a = itertools.cycle([1, 2, 3])
|
||||
b = itertools.cycle([1, 2, 3])
|
||||
next(a)
|
||||
next(b)
|
||||
assert comparator(a, b)
|
||||
|
||||
# Not equal after different consumption
|
||||
a = itertools.cycle([1, 2, 3])
|
||||
b = itertools.cycle([1, 2, 3])
|
||||
next(a)
|
||||
assert not comparator(a, b)
|
||||
|
||||
# Equal after consuming a full cycle
|
||||
a = itertools.cycle([1, 2, 3])
|
||||
b = itertools.cycle([1, 2, 3])
|
||||
for _ in range(3):
|
||||
next(a)
|
||||
next(b)
|
||||
assert comparator(a, b)
|
||||
|
||||
# Equal at same position across different full-cycle counts
|
||||
a = itertools.cycle([1, 2, 3])
|
||||
b = itertools.cycle([1, 2, 3])
|
||||
for _ in range(4):
|
||||
next(a)
|
||||
for _ in range(7):
|
||||
next(b)
|
||||
# Both at position 1 within the cycle (4%3 == 7%3 == 1)
|
||||
assert comparator(a, b)
|
||||
|
||||
# Works inside containers
|
||||
assert comparator([itertools.cycle([1, 2])], [itertools.cycle([1, 2])])
|
||||
assert not comparator([itertools.cycle([1, 2])], [itertools.cycle([1, 3])])
|
||||
|
||||
|
||||
def test_itertools_chain() -> None:
|
||||
import itertools
|
||||
|
||||
assert comparator(itertools.chain([1, 2], [3, 4]), itertools.chain([1, 2], [3, 4]))
|
||||
assert not comparator(itertools.chain([1, 2], [3, 4]), itertools.chain([1, 2], [3, 5]))
|
||||
assert comparator(itertools.chain.from_iterable([[1, 2], [3]]), itertools.chain.from_iterable([[1, 2], [3]]))
|
||||
assert comparator(itertools.chain(), itertools.chain())
|
||||
assert not comparator(itertools.chain([1]), itertools.chain([1, 2]))
|
||||
|
||||
|
||||
def test_itertools_islice() -> None:
|
||||
import itertools
|
||||
|
||||
assert comparator(itertools.islice(range(10), 5), itertools.islice(range(10), 5))
|
||||
assert not comparator(itertools.islice(range(10), 5), itertools.islice(range(10), 6))
|
||||
assert comparator(itertools.islice(range(10), 2, 5), itertools.islice(range(10), 2, 5))
|
||||
assert not comparator(itertools.islice(range(10), 2, 5), itertools.islice(range(10), 2, 6))
|
||||
|
||||
|
||||
def test_itertools_product() -> None:
|
||||
import itertools
|
||||
|
||||
assert comparator(itertools.product("AB", repeat=2), itertools.product("AB", repeat=2))
|
||||
assert not comparator(itertools.product("AB", repeat=2), itertools.product("AC", repeat=2))
|
||||
assert comparator(itertools.product([1, 2], [3, 4]), itertools.product([1, 2], [3, 4]))
|
||||
assert not comparator(itertools.product([1, 2], [3, 4]), itertools.product([1, 2], [3, 5]))
|
||||
|
||||
|
||||
def test_itertools_permutations_combinations() -> None:
|
||||
import itertools
|
||||
|
||||
assert comparator(itertools.permutations("ABC", 2), itertools.permutations("ABC", 2))
|
||||
assert not comparator(itertools.permutations("ABC", 2), itertools.permutations("ABD", 2))
|
||||
assert comparator(itertools.combinations("ABCD", 2), itertools.combinations("ABCD", 2))
|
||||
assert not comparator(itertools.combinations("ABCD", 2), itertools.combinations("ABCD", 3))
|
||||
assert comparator(
|
||||
itertools.combinations_with_replacement("ABC", 2),
|
||||
itertools.combinations_with_replacement("ABC", 2),
|
||||
)
|
||||
assert not comparator(
|
||||
itertools.combinations_with_replacement("ABC", 2),
|
||||
itertools.combinations_with_replacement("ABD", 2),
|
||||
)
|
||||
|
||||
|
||||
def test_itertools_accumulate() -> None:
|
||||
import itertools
|
||||
|
||||
assert comparator(itertools.accumulate([1, 2, 3, 4]), itertools.accumulate([1, 2, 3, 4]))
|
||||
assert not comparator(itertools.accumulate([1, 2, 3, 4]), itertools.accumulate([1, 2, 3, 5]))
|
||||
assert comparator(itertools.accumulate([1, 2, 3], initial=10), itertools.accumulate([1, 2, 3], initial=10))
|
||||
assert not comparator(itertools.accumulate([1, 2, 3], initial=10), itertools.accumulate([1, 2, 3], initial=0))
|
||||
|
||||
|
||||
def test_itertools_filtering() -> None:
|
||||
import itertools
|
||||
|
||||
# compress
|
||||
assert comparator(
|
||||
itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]),
|
||||
itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]),
|
||||
)
|
||||
assert not comparator(
|
||||
itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]),
|
||||
itertools.compress("ABCDEF", [1, 1, 1, 0, 1, 1]),
|
||||
)
|
||||
|
||||
# dropwhile
|
||||
assert comparator(
|
||||
itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
|
||||
itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
|
||||
)
|
||||
assert not comparator(
|
||||
itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
|
||||
itertools.dropwhile(lambda x: x < 5, [1, 4, 7, 4, 1]),
|
||||
)
|
||||
|
||||
# takewhile
|
||||
assert comparator(
|
||||
itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
|
||||
itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
|
||||
)
|
||||
assert not comparator(
|
||||
itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
|
||||
itertools.takewhile(lambda x: x < 5, [1, 3, 6, 4, 1]),
|
||||
)
|
||||
|
||||
# filterfalse
|
||||
assert comparator(
|
||||
itertools.filterfalse(lambda x: x % 2, range(10)),
|
||||
itertools.filterfalse(lambda x: x % 2, range(10)),
|
||||
)
|
||||
|
||||
|
||||
def test_itertools_starmap() -> None:
|
||||
import itertools
|
||||
|
||||
assert comparator(
|
||||
itertools.starmap(pow, [(2, 3), (3, 2), (10, 0)]),
|
||||
itertools.starmap(pow, [(2, 3), (3, 2), (10, 0)]),
|
||||
)
|
||||
assert not comparator(
|
||||
itertools.starmap(pow, [(2, 3), (3, 2)]),
|
||||
itertools.starmap(pow, [(2, 3), (3, 3)]),
|
||||
)
|
||||
|
||||
|
||||
def test_itertools_zip_longest() -> None:
|
||||
import itertools
|
||||
|
||||
assert comparator(
|
||||
itertools.zip_longest("AB", "xyz", fillvalue="-"),
|
||||
itertools.zip_longest("AB", "xyz", fillvalue="-"),
|
||||
)
|
||||
assert not comparator(
|
||||
itertools.zip_longest("AB", "xyz", fillvalue="-"),
|
||||
itertools.zip_longest("AB", "xyz", fillvalue="*"),
|
||||
)
|
||||
|
||||
|
||||
def test_itertools_groupby() -> None:
|
||||
import itertools
|
||||
|
||||
assert comparator(itertools.groupby("AAABBBCC"), itertools.groupby("AAABBBCC"))
|
||||
assert not comparator(itertools.groupby("AAABBBCC"), itertools.groupby("AAABBCC"))
|
||||
assert comparator(itertools.groupby([]), itertools.groupby([]))
|
||||
|
||||
# With key function
|
||||
assert comparator(
|
||||
itertools.groupby([1, 1, 2, 2, 3], key=lambda x: x),
|
||||
itertools.groupby([1, 1, 2, 2, 3], key=lambda x: x),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.version_info < (3, 10), reason="itertools.pairwise requires Python 3.10+")
|
||||
def test_itertools_pairwise() -> None:
|
||||
import itertools
|
||||
|
||||
assert comparator(itertools.pairwise([1, 2, 3, 4]), itertools.pairwise([1, 2, 3, 4]))
|
||||
assert not comparator(itertools.pairwise([1, 2, 3, 4]), itertools.pairwise([1, 2, 3, 5]))
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.version_info < (3, 12), reason="itertools.batched requires Python 3.12+")
|
||||
def test_itertools_batched() -> None:
|
||||
import itertools
|
||||
|
||||
assert comparator(itertools.batched("ABCDEFG", 3), itertools.batched("ABCDEFG", 3))
|
||||
assert not comparator(itertools.batched("ABCDEFG", 3), itertools.batched("ABCDEFG", 2))
|
||||
|
||||
|
||||
def test_itertools_in_containers() -> None:
|
||||
import itertools
|
||||
|
||||
# Itertools objects nested in dicts/lists
|
||||
assert comparator(
|
||||
{"a": itertools.chain([1], [2]), "b": itertools.islice(range(5), 3)},
|
||||
{"a": itertools.chain([1], [2]), "b": itertools.islice(range(5), 3)},
|
||||
)
|
||||
assert not comparator(
|
||||
[itertools.product("AB", repeat=2)],
|
||||
[itertools.product("AC", repeat=2)],
|
||||
)
|
||||
|
||||
# Different itertools types should not match
|
||||
assert not comparator(itertools.chain([1, 2]), itertools.islice([1, 2], 2))
|
||||
|
||||
|
||||
def test_numpy():
|
||||
try:
|
||||
import numpy as np
|
||||
|
|
@ -5216,3 +5522,67 @@ class TestPythonTempfilePaths:
|
|||
assert PYTHON_TEMPFILE_PATTERN.search("/tmp/tmp123456/")
|
||||
assert not PYTHON_TEMPFILE_PATTERN.search("/tmp/mydir/file.txt")
|
||||
assert not PYTHON_TEMPFILE_PATTERN.search("/home/tmp123/file.txt")
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.version_info < (3, 10), reason="types.UnionType requires Python 3.10+")
|
||||
class TestUnionType:
|
||||
def test_union_type_equal(self):
|
||||
assert comparator(int | str, int | str)
|
||||
|
||||
def test_union_type_not_equal(self):
|
||||
assert not comparator(int | str, int | float)
|
||||
|
||||
def test_union_type_order_independent(self):
|
||||
assert comparator(int | str, str | int)
|
||||
|
||||
def test_union_type_multiple_args(self):
|
||||
assert comparator(int | str | float, int | str | float)
|
||||
|
||||
def test_union_type_in_list(self):
|
||||
assert comparator([int | str, 1], [int | str, 1])
|
||||
|
||||
def test_union_type_in_dict(self):
|
||||
assert comparator({"key": int | str}, {"key": int | str})
|
||||
|
||||
def test_union_type_vs_none(self):
|
||||
assert not comparator(int | str, None)
|
||||
|
||||
|
||||
class SlotsOnly:
|
||||
__slots__ = ("x", "y")
|
||||
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
|
||||
class SlotsInherited(SlotsOnly):
|
||||
__slots__ = ("z",)
|
||||
|
||||
def __init__(self, x, y, z):
|
||||
super().__init__(x, y)
|
||||
self.z = z
|
||||
|
||||
|
||||
class TestSlotsObjects:
|
||||
def test_slots_equal(self):
|
||||
assert comparator(SlotsOnly(1, 2), SlotsOnly(1, 2))
|
||||
|
||||
def test_slots_not_equal(self):
|
||||
assert not comparator(SlotsOnly(1, 2), SlotsOnly(1, 3))
|
||||
|
||||
def test_slots_inherited_equal(self):
|
||||
assert comparator(SlotsInherited(1, 2, 3), SlotsInherited(1, 2, 3))
|
||||
|
||||
def test_slots_inherited_not_equal(self):
|
||||
assert not comparator(SlotsInherited(1, 2, 3), SlotsInherited(1, 2, 4))
|
||||
|
||||
def test_slots_nested(self):
|
||||
a = SlotsOnly(SlotsOnly(1, 2), [3, 4])
|
||||
b = SlotsOnly(SlotsOnly(1, 2), [3, 4])
|
||||
assert comparator(a, b)
|
||||
|
||||
def test_slots_nested_not_equal(self):
|
||||
a = SlotsOnly(SlotsOnly(1, 2), [3, 4])
|
||||
b = SlotsOnly(SlotsOnly(1, 9), [3, 4])
|
||||
assert not comparator(a, b)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import pytest
|
|||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.either import is_successful
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
|
|
@ -132,7 +132,7 @@ def test_class_method_dependencies() -> None:
|
|||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(
|
||||
func_optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=TestConfig(
|
||||
tests_root=file_path,
|
||||
|
|
@ -163,6 +163,7 @@ def test_class_method_dependencies() -> None:
|
|||
== """# file: test_function_dependencies.py
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
class Graph:
|
||||
def __init__(self, vertices):
|
||||
self.graph = defaultdict(list)
|
||||
|
|
@ -201,7 +202,7 @@ def test_recursive_function_context() -> None:
|
|||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(
|
||||
func_optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=TestConfig(
|
||||
tests_root=file_path,
|
||||
|
|
|
|||
|
|
@ -680,8 +680,14 @@ def test_in_dunder_tests():
|
|||
|
||||
# Combine all discovered functions
|
||||
all_functions = {}
|
||||
for discovered in [discovered_source, discovered_test, discovered_test_underscore,
|
||||
discovered_spec, discovered_tests_dir, discovered_dunder_tests]:
|
||||
for discovered in [
|
||||
discovered_source,
|
||||
discovered_test,
|
||||
discovered_test_underscore,
|
||||
discovered_spec,
|
||||
discovered_tests_dir,
|
||||
discovered_dunder_tests,
|
||||
]:
|
||||
all_functions.update(discovered)
|
||||
|
||||
# Test Case 1: tests_root == module_root (overlapping case)
|
||||
|
|
@ -781,9 +787,7 @@ def test_filter_functions_strict_string_matching():
|
|||
|
||||
# Strict check: exactly these 3 files should remain (those with 'test' as substring only)
|
||||
expected_files = {contest_file, latest_file, attestation_file}
|
||||
assert set(filtered.keys()) == expected_files, (
|
||||
f"Expected files {expected_files}, got {set(filtered.keys())}"
|
||||
)
|
||||
assert set(filtered.keys()) == expected_files, f"Expected files {expected_files}, got {set(filtered.keys())}"
|
||||
|
||||
# Strict check: each file should have exactly 1 function with the expected name
|
||||
assert [fn.function_name for fn in filtered[contest_file]] == ["run_contest"], (
|
||||
|
|
@ -871,9 +875,7 @@ def test_filter_functions_test_directory_patterns():
|
|||
|
||||
# Strict check: exactly these 2 files should remain (those in non-test directories)
|
||||
expected_files = {contest_file, latest_file}
|
||||
assert set(filtered.keys()) == expected_files, (
|
||||
f"Expected files {expected_files}, got {set(filtered.keys())}"
|
||||
)
|
||||
assert set(filtered.keys()) == expected_files, f"Expected files {expected_files}, got {set(filtered.keys())}"
|
||||
|
||||
# Strict check: each file should have exactly 1 function with the expected name
|
||||
assert [fn.function_name for fn in filtered[contest_file]] == ["get_scores"], (
|
||||
|
|
@ -936,9 +938,7 @@ def test_filter_functions_non_overlapping_tests_root():
|
|||
|
||||
# Strict check: exactly these 2 files should remain (both in src/, not in tests/)
|
||||
expected_files = {source_file, test_in_src}
|
||||
assert set(filtered.keys()) == expected_files, (
|
||||
f"Expected files {expected_files}, got {set(filtered.keys())}"
|
||||
)
|
||||
assert set(filtered.keys()) == expected_files, f"Expected files {expected_files}, got {set(filtered.keys())}"
|
||||
|
||||
# Strict check: each file should have exactly 1 function with the expected name
|
||||
assert [fn.function_name for fn in filtered[source_file]] == ["process"], (
|
||||
|
|
@ -1047,20 +1047,15 @@ def test_deep_copy():
|
|||
)
|
||||
|
||||
root_functions = [fn.function_name for fn in filtered.get(root_source_file, [])]
|
||||
assert root_functions == ["main"], (
|
||||
f"Expected ['main'], got {root_functions}"
|
||||
)
|
||||
assert root_functions == ["main"], f"Expected ['main'], got {root_functions}"
|
||||
|
||||
# Strict check: exactly 3 functions (2 from utils.py + 1 from main.py)
|
||||
assert count == 3, (
|
||||
f"Expected exactly 3 functions, got {count}. "
|
||||
f"Some source files may have been incorrectly filtered."
|
||||
f"Expected exactly 3 functions, got {count}. Some source files may have been incorrectly filtered."
|
||||
)
|
||||
|
||||
# Verify test file was properly filtered (should not be in results)
|
||||
assert test_file not in filtered, (
|
||||
f"Test file {test_file} should have been filtered but wasn't"
|
||||
)
|
||||
assert test_file not in filtered, f"Test file {test_file} should have been filtered but wasn't"
|
||||
|
||||
|
||||
def test_filter_functions_typescript_project_in_tests_folder():
|
||||
|
|
@ -1214,9 +1209,7 @@ def sample_data():
|
|||
# source_file and file_in_test_dir should remain
|
||||
# test_prefix_file, conftest_file, and test_in_subdir should be filtered
|
||||
expected_files = {source_file, file_in_test_dir}
|
||||
assert set(filtered.keys()) == expected_files, (
|
||||
f"Expected {expected_files}, got {set(filtered.keys())}"
|
||||
)
|
||||
assert set(filtered.keys()) == expected_files, f"Expected {expected_files}, got {set(filtered.keys())}"
|
||||
assert count == 2, f"Expected exactly 2 functions, got {count}"
|
||||
|
||||
|
||||
|
|
@ -1266,7 +1259,8 @@ class TestHelpers:
|
|||
""")
|
||||
|
||||
support = PythonSupport()
|
||||
functions = support.discover_functions(fixture_file)
|
||||
source = fixture_file.read_text(encoding="utf-8")
|
||||
functions = support.discover_functions(source, fixture_file)
|
||||
function_names = [fn.function_name for fn in functions]
|
||||
|
||||
assert "regular_function" in function_names
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import pytest
|
|||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.either import is_successful
|
||||
from codeflash.models.models import FunctionParent, get_code_block_splitter
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
|
@ -233,7 +233,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=function_to_optimize, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=function_to_optimize, test_cfg=test_config)
|
||||
with open(file_path) as f:
|
||||
original_code = f.read()
|
||||
ctx_result = func_optimizer.get_code_optimization_context()
|
||||
|
|
@ -404,7 +404,7 @@ def test_bubble_sort_deps() -> None:
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=function_to_optimize, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=function_to_optimize, test_cfg=test_config)
|
||||
with open(file_path) as f:
|
||||
original_code = f.read()
|
||||
ctx_result = func_optimizer.get_code_optimization_context()
|
||||
|
|
@ -427,6 +427,7 @@ def dep2_swap(arr, j):
|
|||
from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
|
||||
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
|
||||
|
||||
|
||||
def sorter_deps(arr):
|
||||
for i in range(len(arr)):
|
||||
for j in range(len(arr) - 1):
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ def test_basic_class() -> None:
|
|||
class_var = "value"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -47,7 +47,7 @@ def test_dunder_methods() -> None:
|
|||
return f"Value: {self.x}"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -75,7 +75,7 @@ def test_dunder_methods_remove_docstring() -> None:
|
|||
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True
|
||||
)
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -102,7 +102,7 @@ def test_class_remove_docstring() -> None:
|
|||
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True
|
||||
)
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -131,7 +131,7 @@ def test_mixed_remove_docstring() -> None:
|
|||
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True
|
||||
)
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -171,7 +171,7 @@ def test_docstrings() -> None:
|
|||
\"\"\"Class docstring.\"\"\"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -190,7 +190,7 @@ def test_method_signatures() -> None:
|
|||
|
||||
expected = """"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -212,7 +212,7 @@ def test_multiple_top_level_targets() -> None:
|
|||
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.READ_ONLY, {"TestClass.target1", "TestClass.target2"}, set()
|
||||
)
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -232,7 +232,7 @@ def test_class_annotations() -> None:
|
|||
var2: str
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -254,7 +254,7 @@ def test_class_annotations_if() -> None:
|
|||
var2: str
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -280,7 +280,7 @@ def test_class_annotations_try() -> None:
|
|||
continue
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -316,7 +316,7 @@ def test_class_annotations_else() -> None:
|
|||
var2: str
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -331,7 +331,7 @@ def test_top_level_functions() -> None:
|
|||
|
||||
expected = """"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -350,7 +350,7 @@ def test_module_var() -> None:
|
|||
x = 5
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -377,7 +377,7 @@ def test_module_var_if() -> None:
|
|||
z = 10
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -412,7 +412,7 @@ def test_conditional_class_definitions() -> None:
|
|||
platform = "other"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"PlatformClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"PlatformClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -471,7 +471,7 @@ def test_multiple_except_clauses() -> None:
|
|||
error_type = "cleanup"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -524,7 +524,7 @@ def test_with_statement_and_loops() -> None:
|
|||
context = "cleanup"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -573,7 +573,7 @@ def test_async_with_try_except() -> None:
|
|||
status = "cancelled"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -675,7 +675,7 @@ def test_simplified_complete_implementation() -> None:
|
|||
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.READ_ONLY, {"DataProcessor.target_method", "ResultHandler.target_method"}, set()
|
||||
)
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -768,5 +768,5 @@ def test_simplified_complete_implementation_no_docstring() -> None:
|
|||
{"DataProcessor.target_method", "ResultHandler.target_method"},
|
||||
set(),
|
||||
remove_docstrings=True,
|
||||
)
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ def test_simple_function() -> None:
|
|||
y = 2
|
||||
return x + y
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"}).code
|
||||
|
||||
expected = dedent("""
|
||||
def target_function():
|
||||
|
|
@ -32,7 +32,7 @@ def test_class_method() -> None:
|
|||
y = 2
|
||||
return x + y
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_function"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_function"}).code
|
||||
|
||||
expected = dedent("""
|
||||
class MyClass:
|
||||
|
|
@ -56,7 +56,7 @@ def test_class_with_attributes() -> None:
|
|||
def other_method(self):
|
||||
print("this should be excluded")
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"}).code
|
||||
|
||||
expected = dedent("""
|
||||
class MyClass:
|
||||
|
|
@ -80,7 +80,7 @@ def test_basic_class_structure() -> None:
|
|||
def not_findable(self):
|
||||
return 42
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"Outer.target_method"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"Outer.target_method"}).code
|
||||
|
||||
expected = dedent("""
|
||||
class Outer:
|
||||
|
|
@ -100,7 +100,7 @@ def test_top_level_targets() -> None:
|
|||
def target_function():
|
||||
return 42
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"}).code
|
||||
|
||||
expected = dedent("""
|
||||
def target_function():
|
||||
|
|
@ -123,7 +123,7 @@ def test_multiple_top_level_classes() -> None:
|
|||
def process(self):
|
||||
return "C"
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"ClassA.process", "ClassC.process"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"ClassA.process", "ClassC.process"}).code
|
||||
|
||||
expected = dedent("""
|
||||
class ClassA:
|
||||
|
|
@ -148,7 +148,7 @@ def test_try_except_structure() -> None:
|
|||
def handle_error(self):
|
||||
print("error")
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"TargetClass.target_method"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"TargetClass.target_method"}).code
|
||||
|
||||
expected = dedent("""
|
||||
try:
|
||||
|
|
@ -175,7 +175,7 @@ def test_init_method() -> None:
|
|||
def target_method(self):
|
||||
return f"Value: {self.x}"
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"}).code
|
||||
|
||||
expected = dedent("""
|
||||
class MyClass:
|
||||
|
|
@ -200,7 +200,7 @@ def test_dunder_method() -> None:
|
|||
def target_method(self):
|
||||
return f"Value: {self.x}"
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"}).code
|
||||
|
||||
expected = dedent("""
|
||||
class MyClass:
|
||||
|
|
@ -221,7 +221,7 @@ def test_no_targets_found() -> None:
|
|||
def target(self):
|
||||
pass
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"})
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"}).code
|
||||
expected = dedent("""
|
||||
class MyClass:
|
||||
def method(self):
|
||||
|
|
@ -266,5 +266,55 @@ def test_module_var() -> None:
|
|||
var2 = "test"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"})
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"}).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_comment_between_imports_and_variable_preserves_position() -> None:
|
||||
code = """
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
# NOTE: This comment documents the constant below.
|
||||
# It should stay right above SOME_RE, not jump to the top of the file.
|
||||
SOME_RE = re.compile(r"^pattern", re.MULTILINE)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class Item:
|
||||
name: str
|
||||
value: int
|
||||
children: list[Item] = field(default_factory=list)
|
||||
|
||||
|
||||
def parse(text: str) -> list[Item]:
|
||||
root = Item(name="root", value=0)
|
||||
for m in SOME_RE.finditer(text):
|
||||
root.children.append(Item(name=m.group(), value=1))
|
||||
return root.children
|
||||
"""
|
||||
|
||||
expected = """
|
||||
# NOTE: This comment documents the constant below.
|
||||
# It should stay right above SOME_RE, not jump to the top of the file.
|
||||
SOME_RE = re.compile(r"^pattern", re.MULTILINE)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class Item:
|
||||
name: str
|
||||
value: int
|
||||
children: list[Item] = field(default_factory=list)
|
||||
|
||||
|
||||
def parse(text: str) -> list[Item]:
|
||||
root = Item(name="root", value=0)
|
||||
for m in SOME_RE.finditer(text):
|
||||
root.children.append(Item(name=m.group(), value=1))
|
||||
return root.children
|
||||
"""
|
||||
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"parse"}).code
|
||||
assert result.strip() == dedent(expected).strip()
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ def test_simple_function() -> None:
|
|||
y = 2
|
||||
return x + y
|
||||
"""
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set())
|
||||
result = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set()).code
|
||||
|
||||
expected = """
|
||||
def target_function():
|
||||
|
|
@ -44,7 +44,7 @@ def test_basic_class() -> None:
|
|||
print("This should be included")
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -73,7 +73,7 @@ def test_dunder_methods() -> None:
|
|||
print("include me")
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -107,7 +107,7 @@ def test_dunder_methods_remove_docstring() -> None:
|
|||
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set(), remove_docstrings=True
|
||||
)
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -139,7 +139,7 @@ def test_class_remove_docstring() -> None:
|
|||
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set(), remove_docstrings=True
|
||||
)
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -181,7 +181,7 @@ def test_method_signatures() -> None:
|
|||
return "value"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -215,7 +215,7 @@ def test_multiple_top_level_targets() -> None:
|
|||
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.TESTGEN, {"TestClass.target1", "TestClass.target2"}, set()
|
||||
)
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -238,7 +238,7 @@ def test_class_annotations() -> None:
|
|||
self.var2 = "test"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -263,7 +263,7 @@ def test_class_annotations_if() -> None:
|
|||
self.var2 = "test"
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -304,7 +304,7 @@ def test_conditional_class_definitions() -> None:
|
|||
print("other")
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"PlatformClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"PlatformClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -333,7 +333,7 @@ def test_try_except_structure() -> None:
|
|||
print("error")
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TargetClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TargetClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -355,7 +355,7 @@ def test_module_var() -> None:
|
|||
x = 5
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -385,7 +385,7 @@ def test_module_var_if() -> None:
|
|||
z = 10
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -416,7 +416,7 @@ def test_multiple_classes() -> None:
|
|||
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.TESTGEN, {"ClassA.process", "ClassC.process"}, set()
|
||||
)
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -477,7 +477,7 @@ def test_with_statement_and_loops() -> None:
|
|||
print("cleanup")
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -532,7 +532,7 @@ def test_async_with_try_except() -> None:
|
|||
await self.cleanup()
|
||||
"""
|
||||
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
|
||||
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -659,7 +659,7 @@ def test_simplified_complete_implementation() -> None:
|
|||
|
||||
output = parse_code_and_prune_cst(
|
||||
dedent(code), CodeContextType.TESTGEN, {"DataProcessor.target_method", "ResultHandler.target_method"}, set()
|
||||
)
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -765,5 +765,5 @@ def test_simplified_complete_implementation_no_docstring() -> None:
|
|||
{"DataProcessor.target_method", "ResultHandler.target_method"},
|
||||
set(),
|
||||
remove_docstrings=True,
|
||||
)
|
||||
).code
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
"""Tests for JavaScript/TypeScript project initialization and package manager detection."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -8,6 +10,7 @@ from codeflash.cli_cmds.init_javascript import (
|
|||
JsPackageManager,
|
||||
determine_js_package_manager,
|
||||
get_package_install_command,
|
||||
should_modify_package_json_config,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -281,3 +284,67 @@ class TestGetPackageInstallCommand:
|
|||
result = get_package_install_command(tmp_project, "typescript", dev=True)
|
||||
|
||||
assert result == ["pnpm", "add", "typescript", "--save-dev"]
|
||||
|
||||
|
||||
class TestShouldModifySkipConfirm:
|
||||
"""Tests for should_modify_package_json_config with skip_confirm."""
|
||||
|
||||
def test_should_modify_skip_confirm_no_config(self, tmp_project: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""With skip_confirm and no codeflash config, should return (True, None)."""
|
||||
monkeypatch.chdir(tmp_project)
|
||||
(tmp_project / "package.json").write_text(json.dumps({"name": "test"}))
|
||||
|
||||
should_modify, config = should_modify_package_json_config(skip_confirm=True)
|
||||
|
||||
assert should_modify is True
|
||||
assert config is None
|
||||
|
||||
def test_should_modify_skip_confirm_with_valid_config(
|
||||
self, tmp_project: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""With skip_confirm and valid config, should return (False, config) — no reconfigure."""
|
||||
monkeypatch.chdir(tmp_project)
|
||||
codeflash_config = {"moduleRoot": "."}
|
||||
(tmp_project / "package.json").write_text(
|
||||
json.dumps({"name": "test", "codeflash": codeflash_config})
|
||||
)
|
||||
|
||||
should_modify, config = should_modify_package_json_config(skip_confirm=True)
|
||||
|
||||
assert should_modify is False
|
||||
assert config == codeflash_config
|
||||
|
||||
def test_should_modify_skip_confirm_with_invalid_config(
|
||||
self, tmp_project: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""With skip_confirm and invalid config (bad moduleRoot), should return (True, None)."""
|
||||
monkeypatch.chdir(tmp_project)
|
||||
codeflash_config = {"moduleRoot": "/nonexistent/path/that/does/not/exist"}
|
||||
(tmp_project / "package.json").write_text(
|
||||
json.dumps({"name": "test", "codeflash": codeflash_config})
|
||||
)
|
||||
|
||||
should_modify, config = should_modify_package_json_config(skip_confirm=True)
|
||||
|
||||
assert should_modify is True
|
||||
assert config is None
|
||||
|
||||
|
||||
class TestCollectJsSetupInfoSkipConfirm:
|
||||
"""Tests for collect_js_setup_info with skip_confirm."""
|
||||
|
||||
def test_collect_js_setup_info_skip_confirm(self, tmp_project: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""skip_confirm should return defaults without any interactive prompts."""
|
||||
monkeypatch.chdir(tmp_project)
|
||||
(tmp_project / "package.json").write_text(json.dumps({"name": "test"}))
|
||||
|
||||
from codeflash.cli_cmds.init_javascript import ProjectLanguage, collect_js_setup_info
|
||||
|
||||
# Should not call any prompt functions
|
||||
with patch("codeflash.cli_cmds.init_javascript.inquirer") as mock_inquirer:
|
||||
setup_info = collect_js_setup_info(ProjectLanguage.JAVASCRIPT, skip_confirm=True)
|
||||
mock_inquirer.prompt.assert_not_called()
|
||||
|
||||
assert setup_info.module_root_override is None
|
||||
assert setup_info.formatter_override is None
|
||||
assert setup_info.git_remote == "origin"
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from tempfile import TemporaryDirectory
|
|||
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeOptimizationContext
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
|
|
@ -22,7 +22,7 @@ def test_add_decorator_imports_helper_in_class():
|
|||
pytest_cmd="pytest",
|
||||
)
|
||||
func = FunctionToOptimize(function_name="sort_classmethod", parents=[], file_path=code_path)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
os.chdir(run_cwd)
|
||||
# func_optimizer = pass
|
||||
try:
|
||||
|
|
@ -94,7 +94,7 @@ def test_add_decorator_imports_helper_in_nested_class():
|
|||
pytest_cmd="pytest",
|
||||
)
|
||||
func = FunctionToOptimize(function_name="sort_classmethod", parents=[], file_path=code_path)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
os.chdir(run_cwd)
|
||||
# func_optimizer = pass
|
||||
try:
|
||||
|
|
@ -143,7 +143,7 @@ def test_add_decorator_imports_nodeps():
|
|||
pytest_cmd="pytest",
|
||||
)
|
||||
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_path)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
os.chdir(run_cwd)
|
||||
# func_optimizer = pass
|
||||
try:
|
||||
|
|
@ -194,7 +194,7 @@ def test_add_decorator_imports_helper_outside():
|
|||
pytest_cmd="pytest",
|
||||
)
|
||||
func = FunctionToOptimize(function_name="sorter_deps", parents=[], file_path=code_path)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
os.chdir(run_cwd)
|
||||
# func_optimizer = pass
|
||||
try:
|
||||
|
|
@ -271,7 +271,7 @@ class helper:
|
|||
pytest_cmd="pytest",
|
||||
)
|
||||
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_write_path)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
os.chdir(run_cwd)
|
||||
# func_optimizer = pass
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -20,12 +20,15 @@ All assertions use strict string equality to verify exact extraction output.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context_for_language
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -61,7 +64,8 @@ export function add(a, b) {
|
|||
file_path = temp_project / "math.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
assert len(functions) == 1
|
||||
func = functions[0]
|
||||
|
||||
|
|
@ -87,7 +91,8 @@ export const multiply = (a, b) => a * b;
|
|||
file_path = temp_project / "math.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
assert len(functions) == 1
|
||||
func = functions[0]
|
||||
assert func.function_name == "multiply"
|
||||
|
|
@ -121,7 +126,8 @@ export function add(a, b) {
|
|||
file_path = temp_project / "math.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
context = js_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
|
@ -173,7 +179,8 @@ export async function processItems(items, callback, options = {}) {
|
|||
file_path = temp_project / "processor.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
context = js_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
|
@ -243,7 +250,8 @@ export class CacheManager {
|
|||
file_path = temp_project / "cache.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
get_or_compute = next(f for f in functions if f.function_name == "getOrCompute")
|
||||
|
||||
context = js_support.extract_code_context(get_or_compute, temp_project, temp_project)
|
||||
|
|
@ -339,7 +347,8 @@ export function validateUserData(data, validators) {
|
|||
file_path = temp_project / "validator.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = next(f for f in functions if f.function_name == "validateUserData")
|
||||
|
||||
context = js_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
|
@ -429,7 +438,8 @@ export async function fetchWithRetry(endpoint, options = {}) {
|
|||
file_path = temp_project / "api.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = next(f for f in functions if f.function_name == "fetchWithRetry")
|
||||
|
||||
context = js_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
|
@ -515,7 +525,8 @@ export function validateField(value, fieldType) {
|
|||
file_path = temp_project / "validation.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
context = js_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
|
@ -578,7 +589,8 @@ export function processUserInput(rawInput) {
|
|||
file_path = temp_project / "processor.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
process_func = next(f for f in functions if f.function_name == "processUserInput")
|
||||
|
||||
context = js_support.extract_code_context(process_func, temp_project, temp_project)
|
||||
|
|
@ -633,7 +645,8 @@ export function generateReport(data) {
|
|||
file_path = temp_project / "report.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
report_func = next(f for f in functions if f.function_name == "generateReport")
|
||||
|
||||
context = js_support.extract_code_context(report_func, temp_project, temp_project)
|
||||
|
|
@ -731,7 +744,8 @@ export class Graph {
|
|||
file_path = temp_project / "graph.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
topo_sort = next(f for f in functions if f.function_name == "topologicalSort")
|
||||
|
||||
context = js_support.extract_code_context(topo_sort, temp_project, temp_project)
|
||||
|
|
@ -819,7 +833,8 @@ export class MainClass {
|
|||
file_path = temp_project / "classes.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
main_method = next(f for f in functions if f.function_name == "mainMethod" and f.class_name == "MainClass")
|
||||
|
||||
context = js_support.extract_code_context(main_method, temp_project, temp_project)
|
||||
|
|
@ -875,7 +890,8 @@ module.exports = { sortFromAnotherFile };
|
|||
main_path = temp_project / "bubble_sort_imported.js"
|
||||
main_path.write_text(main_code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(main_path)
|
||||
source = main_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, main_path)
|
||||
main_func = next(f for f in functions if f.function_name == "sortFromAnotherFile")
|
||||
|
||||
context = js_support.extract_code_context(main_func, temp_project, temp_project)
|
||||
|
|
@ -926,7 +942,8 @@ export function processNumber(n) {
|
|||
main_path = temp_project / "main.js"
|
||||
main_path.write_text(main_code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(main_path)
|
||||
source = main_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, main_path)
|
||||
process_func = next(f for f in functions if f.function_name == "processNumber")
|
||||
|
||||
context = js_support.extract_code_context(process_func, temp_project, temp_project)
|
||||
|
|
@ -992,7 +1009,8 @@ export function handleUserInput(rawInput) {
|
|||
main_path = temp_project / "main.js"
|
||||
main_path.write_text(main_code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(main_path)
|
||||
source = main_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, main_path)
|
||||
handle_func = next(f for f in functions if f.function_name == "handleUserInput")
|
||||
|
||||
context = js_support.extract_code_context(handle_func, temp_project, temp_project)
|
||||
|
|
@ -1043,7 +1061,8 @@ export function createEntity<T extends object>(data: T): Entity<T> {
|
|||
file_path = temp_project / "entity.ts"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
context = ts_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
|
@ -1133,7 +1152,8 @@ export class TypedCache<T> {
|
|||
file_path = temp_project / "cache.ts"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
get_method = next(f for f in functions if f.function_name == "get")
|
||||
|
||||
context = ts_support.extract_code_context(get_method, temp_project, temp_project)
|
||||
|
|
@ -1217,7 +1237,8 @@ export function createUser(input: CreateUserInput, role: UserRole = DEFAULT_ROLE
|
|||
service_path = temp_project / "service.ts"
|
||||
service_path.write_text(service_code, encoding="utf-8")
|
||||
|
||||
functions = ts_support.discover_functions(service_path)
|
||||
source = service_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, service_path)
|
||||
func = next(f for f in functions if f.function_name == "createUser")
|
||||
|
||||
context = ts_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
|
@ -1271,7 +1292,8 @@ export function factorial(n) {
|
|||
file_path = temp_project / "math.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
context = js_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
|
@ -1301,7 +1323,8 @@ export function isOdd(n) {
|
|||
file_path = temp_project / "parity.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
is_even = next(f for f in functions if f.function_name == "isEven")
|
||||
|
||||
context = js_support.extract_code_context(is_even, temp_project, temp_project)
|
||||
|
|
@ -1319,12 +1342,15 @@ export function isEven(n) {
|
|||
assert helper_names == ["isOdd"]
|
||||
|
||||
# Verify helper source
|
||||
assert context.helper_functions[0].source_code == """\
|
||||
assert (
|
||||
context.helper_functions[0].source_code
|
||||
== """\
|
||||
export function isOdd(n) {
|
||||
if (n === 0) return false;
|
||||
return isEven(n - 1);
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
def test_complex_recursive_tree_traversal(self, js_support, temp_project):
|
||||
"""Test complex recursive tree traversal with multiple recursive calls."""
|
||||
|
|
@ -1363,7 +1389,8 @@ export function collectAllValues(root) {
|
|||
file_path = temp_project / "tree.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
collect_func = next(f for f in functions if f.function_name == "collectAllValues")
|
||||
|
||||
context = js_support.extract_code_context(collect_func, temp_project, temp_project)
|
||||
|
|
@ -1428,7 +1455,8 @@ export async function fetchUserProfile(userId) {
|
|||
file_path = temp_project / "api.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
profile_func = next(f for f in functions if f.function_name == "fetchUserProfile")
|
||||
|
||||
context = js_support.extract_code_context(profile_func, temp_project, temp_project)
|
||||
|
|
@ -1483,7 +1511,8 @@ module.exports = { Counter };
|
|||
file_path = temp_project / "counter.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
increment_func = next(fn for fn in functions if fn.function_name == "increment")
|
||||
|
||||
# Step 1: Extract code context
|
||||
|
|
@ -1563,7 +1592,8 @@ export function processApiResponse({
|
|||
file_path = temp_project / "api.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
context = js_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
|
@ -1605,7 +1635,8 @@ export function* fibonacci(limit) {
|
|||
file_path = temp_project / "generators.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
range_func = next(f for f in functions if f.function_name == "range")
|
||||
|
||||
context = js_support.extract_code_context(range_func, temp_project, temp_project)
|
||||
|
|
@ -1640,7 +1671,8 @@ export function createUserObject(name, email, age) {
|
|||
file_path = temp_project / "user.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
context = js_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
|
@ -1790,7 +1822,8 @@ export const sendSlackMessage = async (
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
target_func = "sendSlackMessage"
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
func_info = next(f for f in functions if f.function_name == target_func)
|
||||
fto = FunctionToOptimize(
|
||||
function_name=target_func,
|
||||
|
|
@ -1804,9 +1837,11 @@ export const sendSlackMessage = async (
|
|||
language="typescript",
|
||||
)
|
||||
|
||||
ctx = get_code_optimization_context_for_language(
|
||||
fto, temp_project
|
||||
test_config = TestConfig(
|
||||
tests_root=temp_project, tests_project_rootdir=temp_project, project_root_path=temp_project
|
||||
)
|
||||
func_optimizer = JavaScriptFunctionOptimizer(function_to_optimize=fto, test_cfg=test_config, aiservice_client=MagicMock())
|
||||
ctx = func_optimizer.get_code_optimization_context().unwrap()
|
||||
|
||||
# The read_writable_code should contain the target function AND helper functions
|
||||
expected_read_writable = """```typescript:slack_util.ts
|
||||
|
|
@ -1899,7 +1934,6 @@ let web: WebClient | null = null"""
|
|||
assert ctx.read_only_context_code == expected_read_only
|
||||
|
||||
|
||||
|
||||
class TestContextProperties:
|
||||
"""Tests for CodeContext object properties."""
|
||||
|
||||
|
|
@ -1913,7 +1947,8 @@ export function test() {
|
|||
file_path = temp_project / "test.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
context = js_support.extract_code_context(functions[0], temp_project, temp_project)
|
||||
|
||||
assert context.language == Language.JAVASCRIPT
|
||||
|
|
@ -1932,7 +1967,8 @@ export function test(): number {
|
|||
file_path = temp_project / "test.ts"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
context = ts_support.extract_code_context(functions[0], temp_project, temp_project)
|
||||
|
||||
# TypeScript uses JavaScript language enum
|
||||
|
|
@ -1974,7 +2010,8 @@ export class Calculator {
|
|||
file_path = temp_project / "calculator.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
|
||||
for func in functions:
|
||||
if func.function_name != "constructor":
|
||||
|
|
|
|||
|
|
@ -107,10 +107,8 @@ class TestJavaScriptCodeContext:
|
|||
"""Test extracting code context for a JavaScript function."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
|
||||
|
||||
lang_current._current_language = Language.JAVASCRIPT
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
|
||||
fib_file = js_project_dir / "fibonacci.js"
|
||||
if not fib_file.exists():
|
||||
|
|
@ -122,7 +120,11 @@ class TestJavaScriptCodeContext:
|
|||
fib_func = next((f for f in func_list if f.function_name == "fibonacci"), None)
|
||||
assert fib_func is not None
|
||||
|
||||
context = get_code_optimization_context(fib_func, js_project_dir)
|
||||
js_support = get_language_support(Language.JAVASCRIPT)
|
||||
code_context = js_support.extract_code_context(fib_func, js_project_dir, js_project_dir)
|
||||
context = JavaScriptFunctionOptimizer._build_optimization_context(
|
||||
code_context, fib_file, "javascript", js_project_dir
|
||||
)
|
||||
|
||||
assert context.read_writable_code is not None
|
||||
assert context.read_writable_code.language == "javascript"
|
||||
|
|
|
|||
|
|
@ -71,10 +71,8 @@ module.exports = { add };
|
|||
"""Verify language is preserved in code context extraction."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
|
||||
|
||||
lang_current._current_language = Language.TYPESCRIPT
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
|
||||
ts_file = tmp_path / "utils.ts"
|
||||
ts_file.write_text("""
|
||||
|
|
@ -86,7 +84,11 @@ export function add(a: number, b: number): number {
|
|||
functions = find_all_functions_in_file(ts_file)
|
||||
func = functions[ts_file][0]
|
||||
|
||||
context = get_code_optimization_context(func, tmp_path)
|
||||
ts_support = get_language_support(Language.TYPESCRIPT)
|
||||
code_context = ts_support.extract_code_context(func, tmp_path, tmp_path)
|
||||
context = JavaScriptFunctionOptimizer._build_optimization_context(
|
||||
code_context, ts_file, "typescript", tmp_path
|
||||
)
|
||||
|
||||
assert context.read_writable_code is not None
|
||||
assert context.read_writable_code.language == "typescript"
|
||||
|
|
@ -373,10 +375,7 @@ describe('fibonacci', () => {
|
|||
"""Test get_code_optimization_context for JavaScript."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
|
||||
lang_current._current_language = Language.JAVASCRIPT
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
|
||||
src_file = js_project / "utils.js"
|
||||
functions = find_all_functions_in_file(src_file)
|
||||
|
|
@ -398,7 +397,7 @@ describe('fibonacci', () => {
|
|||
pytest_cmd="jest",
|
||||
)
|
||||
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = JavaScriptFunctionOptimizer(
|
||||
function_to_optimize=func_to_optimize,
|
||||
test_cfg=test_config,
|
||||
aiservice_client=MagicMock(),
|
||||
|
|
@ -415,10 +414,7 @@ describe('fibonacci', () => {
|
|||
"""Test get_code_optimization_context for TypeScript."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
|
||||
lang_current._current_language = Language.TYPESCRIPT
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
|
||||
src_file = ts_project / "utils.ts"
|
||||
functions = find_all_functions_in_file(src_file)
|
||||
|
|
@ -440,7 +436,7 @@ describe('fibonacci', () => {
|
|||
pytest_cmd="vitest",
|
||||
)
|
||||
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = JavaScriptFunctionOptimizer(
|
||||
function_to_optimize=func_to_optimize,
|
||||
test_cfg=test_config,
|
||||
aiservice_client=MagicMock(),
|
||||
|
|
@ -461,10 +457,7 @@ class TestHelperFunctionLanguageAttribute:
|
|||
"""Verify helper functions have language='javascript' for .js files."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
|
||||
lang_current._current_language = Language.JAVASCRIPT
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
|
||||
# Create a file with helper functions
|
||||
src_file = tmp_path / "main.js"
|
||||
|
|
@ -499,7 +492,7 @@ module.exports = { main };
|
|||
pytest_cmd="jest",
|
||||
)
|
||||
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = JavaScriptFunctionOptimizer(
|
||||
function_to_optimize=func_to_optimize,
|
||||
test_cfg=test_config,
|
||||
aiservice_client=MagicMock(),
|
||||
|
|
@ -515,10 +508,7 @@ module.exports = { main };
|
|||
"""Verify helper functions have language='typescript' for .ts files."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
|
||||
lang_current._current_language = Language.TYPESCRIPT
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
|
||||
# Create a file with helper functions
|
||||
src_file = tmp_path / "main.ts"
|
||||
|
|
@ -551,7 +541,7 @@ export function main(): number {
|
|||
pytest_cmd="vitest",
|
||||
)
|
||||
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = JavaScriptFunctionOptimizer(
|
||||
function_to_optimize=func_to_optimize,
|
||||
test_cfg=test_config,
|
||||
aiservice_client=MagicMock(),
|
||||
|
|
|
|||
|
|
@ -16,8 +16,6 @@ NOTE: These tests require:
|
|||
Tests will be skipped if dependencies are not available.
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
|
@ -26,7 +24,7 @@ import pytest
|
|||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.models.models import FunctionParent, TestFile, TestFiles, TestType, TestingMode
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
|
|
@ -58,13 +56,7 @@ def install_dependencies(project_dir: Path) -> bool:
|
|||
if has_node_modules(project_dir):
|
||||
return True
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["npm", "install"],
|
||||
cwd=project_dir,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120
|
||||
)
|
||||
result = subprocess.run(["npm", "install"], cwd=project_dir, capture_output=True, text=True, timeout=120)
|
||||
return result.returncode == 0
|
||||
except Exception:
|
||||
return False
|
||||
|
|
@ -82,6 +74,7 @@ def skip_if_js_not_supported():
|
|||
"""Skip test if JavaScript/TypeScript languages are not supported."""
|
||||
try:
|
||||
from codeflash.languages import get_language_support
|
||||
|
||||
get_language_support(Language.JAVASCRIPT)
|
||||
except Exception as e:
|
||||
pytest.skip(f"JavaScript/TypeScript language support not available: {e}")
|
||||
|
|
@ -157,8 +150,8 @@ module.exports = {
|
|||
"""Test that JavaScript test instrumentation module can be imported."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.languages import get_language_support
|
||||
|
||||
# Verify the instrumentation module can be imported
|
||||
from codeflash.languages.javascript.instrument import inject_profiling_into_existing_js_test
|
||||
|
||||
# Get JavaScript support
|
||||
js_support = get_language_support(Language.JAVASCRIPT)
|
||||
|
|
@ -272,8 +265,8 @@ export default defineConfig({
|
|||
"""Test that TypeScript test instrumentation module can be imported."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.languages import get_language_support
|
||||
|
||||
# Verify the instrumentation module can be imported
|
||||
from codeflash.languages.javascript.instrument import inject_profiling_into_existing_js_test
|
||||
|
||||
test_file = ts_project_dir / "tests" / "math.test.ts"
|
||||
|
||||
|
|
@ -356,10 +349,7 @@ class TestRunAndParseJavaScriptTests:
|
|||
"""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
|
||||
lang_current._current_language = Language.TYPESCRIPT
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
|
||||
# Find the fibonacci function
|
||||
fib_file = vitest_project / "fibonacci.ts"
|
||||
|
|
@ -389,10 +379,8 @@ class TestRunAndParseJavaScriptTests:
|
|||
)
|
||||
|
||||
# Create optimizer
|
||||
func_optimizer = FunctionOptimizer(
|
||||
function_to_optimize=func,
|
||||
test_cfg=test_config,
|
||||
aiservice_client=MagicMock(),
|
||||
func_optimizer = JavaScriptFunctionOptimizer(
|
||||
function_to_optimize=func, test_cfg=test_config, aiservice_client=MagicMock()
|
||||
)
|
||||
|
||||
# Get code context - this should work
|
||||
|
|
@ -419,8 +407,8 @@ class TestTimingMarkerParsing:
|
|||
# The marker format used by codeflash for JavaScript
|
||||
# Start marker: !$######{tag}######$!
|
||||
# End marker: !######{tag}:{duration}######!
|
||||
start_pattern = r'!\$######(.+?)######\$!'
|
||||
end_pattern = r'!######(.+?):(\d+)######!'
|
||||
start_pattern = r"!\$######(.+?)######\$!"
|
||||
end_pattern = r"!######(.+?):(\d+)######!"
|
||||
|
||||
start_marker = "!$######test/math.test.ts:TestMath.test_add:add:1:0_0######$!"
|
||||
end_marker = "!######test/math.test.ts:TestMath.test_add:add:1:0_0:12345######!"
|
||||
|
|
@ -472,6 +460,7 @@ class TestJavaScriptTestResultParsing:
|
|||
|
||||
# Parse the XML
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
tree = ET.parse(junit_xml)
|
||||
root = tree.getroot()
|
||||
|
||||
|
|
@ -504,6 +493,7 @@ class TestJavaScriptTestResultParsing:
|
|||
|
||||
# Parse the XML
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
tree = ET.parse(junit_xml)
|
||||
root = tree.getroot()
|
||||
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ export function add(a, b) {
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = js_support.discover_functions(Path(f.name))
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "add"
|
||||
|
|
@ -76,7 +76,7 @@ export function multiply(a, b) {
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = js_support.discover_functions(Path(f.name))
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
assert len(functions) == 3
|
||||
names = {func.function_name for func in functions}
|
||||
|
|
@ -94,7 +94,7 @@ export const multiply = (x, y) => x * y;
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = js_support.discover_functions(Path(f.name))
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
assert len(functions) == 2
|
||||
names = {func.function_name for func in functions}
|
||||
|
|
@ -114,7 +114,7 @@ export function withoutReturn() {
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = js_support.discover_functions(Path(f.name))
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
# Only the function with return should be discovered
|
||||
assert len(functions) == 1
|
||||
|
|
@ -136,7 +136,7 @@ export class Calculator {
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = js_support.discover_functions(Path(f.name))
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
assert len(functions) == 2
|
||||
for func in functions:
|
||||
|
|
@ -157,7 +157,7 @@ export function syncFunction() {
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = js_support.discover_functions(Path(f.name))
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
assert len(functions) == 2
|
||||
|
||||
|
|
@ -182,7 +182,7 @@ export function syncFunc() {
|
|||
f.flush()
|
||||
|
||||
criteria = FunctionFilterCriteria(include_async=False)
|
||||
functions = js_support.discover_functions(Path(f.name), criteria)
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name), criteria)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "syncFunc"
|
||||
|
|
@ -204,7 +204,7 @@ export class MyClass {
|
|||
f.flush()
|
||||
|
||||
criteria = FunctionFilterCriteria(include_methods=False)
|
||||
functions = js_support.discover_functions(Path(f.name), criteria)
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name), criteria)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "standalone"
|
||||
|
|
@ -224,7 +224,7 @@ export function func2() {
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = js_support.discover_functions(Path(f.name))
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
func1 = next(f for f in functions if f.function_name == "func1")
|
||||
func2 = next(f for f in functions if f.function_name == "func2")
|
||||
|
|
@ -246,7 +246,7 @@ export function* numberGenerator() {
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = js_support.discover_functions(Path(f.name))
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "numberGenerator"
|
||||
|
|
@ -257,14 +257,14 @@ export function* numberGenerator() {
|
|||
f.write("this is not valid javascript {{{{")
|
||||
f.flush()
|
||||
|
||||
functions = js_support.discover_functions(Path(f.name))
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
# Tree-sitter is lenient, so it may still parse partial code
|
||||
# The important thing is it doesn't crash
|
||||
assert isinstance(functions, list)
|
||||
|
||||
def test_discover_nonexistent_file_returns_empty(self, js_support):
|
||||
"""Test that nonexistent file returns empty list."""
|
||||
functions = js_support.discover_functions(Path("/nonexistent/file.js"))
|
||||
functions = js_support.discover_functions("", Path("/nonexistent/file.js"))
|
||||
assert functions == []
|
||||
|
||||
def test_discover_function_expression(self, js_support):
|
||||
|
|
@ -277,7 +277,7 @@ export const add = function(a, b) {
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = js_support.discover_functions(Path(f.name))
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "add"
|
||||
|
|
@ -296,7 +296,7 @@ export function named() {
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = js_support.discover_functions(Path(f.name))
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
# Only the named function should be discovered
|
||||
assert len(functions) == 1
|
||||
|
|
@ -507,7 +507,7 @@ export function main(a) {
|
|||
file_path = Path(f.name)
|
||||
|
||||
# First discover functions to get accurate line numbers
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
main_func = next(f for f in functions if f.function_name == "main")
|
||||
|
||||
context = js_support.extract_code_context(main_func, file_path.parent, file_path.parent)
|
||||
|
|
@ -535,7 +535,7 @@ class TestIntegration:
|
|||
file_path = Path(f.name)
|
||||
|
||||
# Discover
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
assert len(functions) == 1
|
||||
func = functions[0]
|
||||
assert func.function_name == "fibonacci"
|
||||
|
|
@ -584,7 +584,7 @@ export function standalone() {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
|
||||
# Should find 4 functions
|
||||
assert len(functions) == 4
|
||||
|
|
@ -623,7 +623,7 @@ export default Button;
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
|
||||
# Should find both components
|
||||
names = {f.function_name for f in functions}
|
||||
|
|
@ -653,7 +653,7 @@ describe('Math functions', () => {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -687,7 +687,7 @@ class TestClassMethodExtraction:
|
|||
file_path = Path(f.name)
|
||||
|
||||
# Discover the method
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
add_method = next(f for f in functions if f.function_name == "add")
|
||||
|
||||
# Extract code context
|
||||
|
|
@ -725,7 +725,7 @@ export class Calculator {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
add_method = next(f for f in functions if f.function_name == "add")
|
||||
|
||||
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
|
||||
|
|
@ -763,7 +763,7 @@ export class Calculator {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
fib_method = next(f for f in functions if f.function_name == "fibonacci")
|
||||
|
||||
context = js_support.extract_code_context(fib_method, file_path.parent, file_path.parent)
|
||||
|
|
@ -802,7 +802,7 @@ export class Calculator {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
add_method = next((f for f in functions if f.function_name == "add"), None)
|
||||
|
||||
if add_method:
|
||||
|
|
@ -832,7 +832,7 @@ export class Calculator {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
fetch_method = next(f for f in functions if f.function_name == "fetchData")
|
||||
|
||||
context = js_support.extract_code_context(fetch_method, file_path.parent, file_path.parent)
|
||||
|
|
@ -865,7 +865,7 @@ export class Calculator {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
add_method = next((f for f in functions if f.function_name == "add"), None)
|
||||
|
||||
if add_method:
|
||||
|
|
@ -894,7 +894,7 @@ export class Calculator {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
method = next(f for f in functions if f.function_name == "simpleMethod")
|
||||
|
||||
context = js_support.extract_code_context(method, file_path.parent, file_path.parent)
|
||||
|
|
@ -1079,7 +1079,7 @@ class TestClassMethodEdgeCases:
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
|
||||
# Should find constructor and increment
|
||||
names = {f.function_name for f in functions}
|
||||
|
|
@ -1109,7 +1109,7 @@ class TestClassMethodEdgeCases:
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
|
||||
# Should find at least greet
|
||||
names = {f.function_name for f in functions}
|
||||
|
|
@ -1137,7 +1137,7 @@ export class Dog extends Animal {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
|
||||
# Find Dog's fetch method
|
||||
fetch_method = next((f for f in functions if f.function_name == "fetch" and f.class_name == "Dog"), None)
|
||||
|
|
@ -1172,7 +1172,7 @@ export class Dog extends Animal {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
|
||||
# Should at least find publicMethod
|
||||
names = {f.function_name for f in functions}
|
||||
|
|
@ -1192,7 +1192,7 @@ module.exports = { Calculator };
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
add_method = next(f for f in functions if f.function_name == "add")
|
||||
|
||||
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
|
||||
|
|
@ -1212,7 +1212,7 @@ module.exports = { Calculator };
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
|
||||
# Find the add method
|
||||
add_method = next((f for f in functions if f.function_name == "add"), None)
|
||||
|
|
@ -1265,7 +1265,7 @@ module.exports = { Counter };
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
increment_func = next(fn for fn in functions if fn.function_name == "increment")
|
||||
|
||||
# Step 1: Extract code context (includes constructor for AI context)
|
||||
|
|
@ -1362,7 +1362,7 @@ export class User {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
functions = ts_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
get_name_func = next(fn for fn in functions if fn.function_name == "getName")
|
||||
|
||||
# Step 1: Extract code context (includes fields and constructor)
|
||||
|
|
@ -1462,7 +1462,7 @@ export class Calculator {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
add_func = next(fn for fn in functions if fn.function_name == "add")
|
||||
|
||||
# Extract context for add
|
||||
|
|
@ -1546,7 +1546,7 @@ export class MathUtils {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
add_func = next(fn for fn in functions if fn.function_name == "add")
|
||||
|
||||
# Extract context
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ describe('add function', () => {
|
|||
""")
|
||||
|
||||
# Discover functions first
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
assert len(functions) == 1
|
||||
|
||||
# Discover tests
|
||||
|
|
@ -90,7 +90,7 @@ describe('multiply', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -124,7 +124,7 @@ test('formats date correctly', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -170,7 +170,7 @@ describe('String Utils', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -208,7 +208,7 @@ describe('sum function', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -242,7 +242,7 @@ test('subtract two numbers', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -270,7 +270,7 @@ test('greets by name', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -316,7 +316,7 @@ describe('Calculator class', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# Should find tests for class methods
|
||||
|
|
@ -363,7 +363,7 @@ describe('clamp', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -399,7 +399,7 @@ describe('async utilities', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -436,7 +436,7 @@ describe('Button component', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# JSX tests should be discovered
|
||||
|
|
@ -466,7 +466,7 @@ test('other test', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# Should not find tests for our function
|
||||
|
|
@ -502,7 +502,7 @@ describe('validators', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# Should find tests for isEmail
|
||||
|
|
@ -546,7 +546,7 @@ test('helper2 returns 2', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -574,7 +574,7 @@ test(`formatNumber with decimal`, () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# May or may not find depending on template literal handling
|
||||
|
|
@ -605,7 +605,7 @@ describe('transform', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# Should still find tests since original name is imported
|
||||
|
|
@ -626,7 +626,7 @@ it('third test', () => {});
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -651,7 +651,7 @@ describe('Suite B', () => {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -675,7 +675,7 @@ describe('Outer', () => {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -699,7 +699,7 @@ describe.skip('skipped describe', () => {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -720,7 +720,7 @@ describe.only('only describe', () => {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -738,7 +738,7 @@ describe('describe single', () => {});
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -757,7 +757,7 @@ describe("describe double", () => {});
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -773,7 +773,7 @@ describe("describe double", () => {});
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -806,7 +806,7 @@ test('funcA works', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# funcA should have tests
|
||||
|
|
@ -833,7 +833,7 @@ test('funcX works', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# funcX should have tests
|
||||
|
|
@ -859,7 +859,7 @@ test('mainFunc works', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -896,7 +896,7 @@ test('block commented', () => {
|
|||
*/
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -921,7 +921,7 @@ test('broken test' { // Missing arrow function
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
# Should not crash
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
assert isinstance(tests, dict)
|
||||
|
|
@ -949,7 +949,7 @@ describe('conflict tests', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# Should still work despite naming conflicts
|
||||
|
|
@ -966,7 +966,7 @@ export function lonelyFunc() { return 'alone'; }
|
|||
module.exports = { lonelyFunc };
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# Should return empty dict, not crash
|
||||
|
|
@ -1001,7 +1001,7 @@ test('funcA works', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions_a = js_support.discover_functions(file_a)
|
||||
functions_a = js_support.discover_functions(file_a.read_text(encoding="utf-8"), file_a)
|
||||
tests = js_support.discover_tests(tmpdir, functions_a)
|
||||
|
||||
# Should handle circular imports gracefully
|
||||
|
|
@ -1047,7 +1047,7 @@ test.each([
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -1073,7 +1073,7 @@ describe.each([
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -1098,7 +1098,7 @@ describe('Math operations', () => {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -1174,7 +1174,7 @@ describe('formatName', () => {
|
|||
""")
|
||||
|
||||
# Discover functions
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
assert len(functions) == 3
|
||||
|
||||
# Discover tests
|
||||
|
|
@ -1242,7 +1242,7 @@ describe('Database', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -1280,7 +1280,7 @@ test('funcA works', () => {
|
|||
""")
|
||||
|
||||
# Discover functions from moduleB
|
||||
functions_b = js_support.discover_functions(source_b)
|
||||
functions_b = js_support.discover_functions(source_b.read_text(encoding="utf-8"), source_b)
|
||||
tests = js_support.discover_tests(tmpdir, functions_b)
|
||||
|
||||
# funcB should not have any tests since test file doesn't import it
|
||||
|
|
@ -1312,7 +1312,7 @@ test('funcOne returns 1', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# Check that tests were found
|
||||
|
|
@ -1340,7 +1340,7 @@ test('mentions targetFunc in string', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# Current implementation may still match on string occurrence
|
||||
|
|
@ -1367,7 +1367,7 @@ test('calculate doubles', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# Should find tests since 'calculate' appears in source
|
||||
|
|
@ -1399,7 +1399,7 @@ describe('MyClass', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# Should find tests for class methods
|
||||
|
|
@ -1432,7 +1432,7 @@ test('deepHelper works', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -1456,7 +1456,7 @@ testCases.forEach(name => {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -1484,7 +1484,7 @@ describe('conditional tests', () => {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -1508,7 +1508,7 @@ test('slow test', () => {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -1531,7 +1531,7 @@ test.todo('also needs implementation');
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -1554,7 +1554,7 @@ test.concurrent('concurrent test 2', async () => {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -1597,7 +1597,7 @@ describe('subtractNumbers', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# All three functions should be discovered
|
||||
|
|
@ -1628,7 +1628,7 @@ describe('Unrelated name', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# Should still find tests
|
||||
|
|
@ -1653,7 +1653,7 @@ describe('Array', function() {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -1684,7 +1684,7 @@ describe('User', () => {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -1712,7 +1712,7 @@ export class Calculator {
|
|||
module.exports = { Calculator };
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
|
||||
# Check qualified names include class
|
||||
add_func = next((f for f in functions if f.function_name == "add"), None)
|
||||
|
|
@ -1737,7 +1737,7 @@ export class Outer {
|
|||
module.exports = { Outer };
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
|
||||
# Should find at least the Outer class method
|
||||
assert any(f.class_name == "Outer" for f in functions)
|
||||
|
|
|
|||
|
|
@ -728,3 +728,370 @@ class TestBundlerModuleResolutionFix:
|
|||
# Verify codeflash configs were NOT created
|
||||
assert not (tmpdir_path / "jest.codeflash.config.js").exists()
|
||||
assert not (tmpdir_path / "tsconfig.codeflash.json").exists()
|
||||
|
||||
|
||||
class TestBundledJestReporter:
|
||||
"""Tests for the bundled codeflash/jest-reporter.
|
||||
|
||||
Verifies that:
|
||||
1. The reporter JS file exists in the runtime package
|
||||
2. Jest commands reference 'codeflash/jest-reporter' (not jest-junit)
|
||||
3. The reporter produces valid JUnit XML
|
||||
4. The CODEFLASH_JEST_REPORTER constant is correct
|
||||
"""
|
||||
|
||||
def test_reporter_js_file_exists(self):
|
||||
"""The jest-reporter.js file must exist in the runtime directory."""
|
||||
reporter_path = Path(__file__).resolve().parents[2] / "packages" / "codeflash" / "runtime" / "jest-reporter.js"
|
||||
assert reporter_path.exists(), f"jest-reporter.js not found at {reporter_path}"
|
||||
|
||||
def test_reporter_constant_value(self):
|
||||
"""CODEFLASH_JEST_REPORTER should be 'codeflash/jest-reporter'."""
|
||||
from codeflash.languages.javascript.test_runner import CODEFLASH_JEST_REPORTER
|
||||
|
||||
assert CODEFLASH_JEST_REPORTER == "codeflash/jest-reporter"
|
||||
|
||||
def test_behavioral_command_uses_bundled_reporter(self):
|
||||
"""run_jest_behavioral_tests should use codeflash/jest-reporter in --reporters flag."""
|
||||
from codeflash.languages.javascript.test_runner import run_jest_behavioral_tests
|
||||
from codeflash.models.models import TestFile, TestFiles
|
||||
from codeflash.models.test_type import TestType
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
(tmpdir_path / "package.json").write_text('{"name": "test"}')
|
||||
test_dir = tmpdir_path / "test"
|
||||
test_dir.mkdir()
|
||||
test_file = test_dir / "test_func.test.js"
|
||||
test_file.write_text("// test")
|
||||
|
||||
mock_test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
original_file_path=test_file,
|
||||
instrumented_behavior_file_path=test_file,
|
||||
benchmarking_file_path=test_file,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_result = MagicMock()
|
||||
mock_result.stdout = ""
|
||||
mock_result.stderr = ""
|
||||
mock_result.returncode = 1
|
||||
mock_run.return_value = mock_result
|
||||
|
||||
try:
|
||||
run_jest_behavioral_tests(
|
||||
test_paths=mock_test_files,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
project_root=tmpdir_path,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if mock_run.called:
|
||||
cmd = mock_run.call_args[0][0]
|
||||
reporter_args = [a for a in cmd if "--reporters=" in a and "jest-reporter" in a]
|
||||
assert len(reporter_args) == 1, f"Expected exactly one codeflash/jest-reporter flag, got: {reporter_args}"
|
||||
assert reporter_args[0] == "--reporters=codeflash/jest-reporter"
|
||||
# Must NOT reference jest-junit
|
||||
jest_junit_args = [a for a in cmd if "jest-junit" in a]
|
||||
assert len(jest_junit_args) == 0, f"Should not reference jest-junit: {jest_junit_args}"
|
||||
|
||||
def test_benchmarking_command_uses_bundled_reporter(self):
|
||||
"""run_jest_benchmarking_tests should use codeflash/jest-reporter."""
|
||||
from codeflash.languages.javascript.test_runner import run_jest_benchmarking_tests
|
||||
from codeflash.models.models import TestFile, TestFiles
|
||||
from codeflash.models.test_type import TestType
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
(tmpdir_path / "package.json").write_text('{"name": "test"}')
|
||||
test_dir = tmpdir_path / "test"
|
||||
test_dir.mkdir()
|
||||
test_file = test_dir / "test_func__perf.test.js"
|
||||
test_file.write_text("// test")
|
||||
|
||||
mock_test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
original_file_path=test_file,
|
||||
instrumented_behavior_file_path=test_file,
|
||||
benchmarking_file_path=test_file,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_result = MagicMock()
|
||||
mock_result.stdout = ""
|
||||
mock_result.stderr = ""
|
||||
mock_result.returncode = 1
|
||||
mock_run.return_value = mock_result
|
||||
|
||||
try:
|
||||
run_jest_benchmarking_tests(
|
||||
test_paths=mock_test_files,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
project_root=tmpdir_path,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if mock_run.called:
|
||||
cmd = mock_run.call_args[0][0]
|
||||
reporter_args = [a for a in cmd if "--reporters=codeflash/jest-reporter" in a]
|
||||
assert len(reporter_args) == 1
|
||||
|
||||
def test_line_profile_command_uses_bundled_reporter(self):
|
||||
"""run_jest_line_profile_tests should use codeflash/jest-reporter."""
|
||||
from codeflash.languages.javascript.test_runner import run_jest_line_profile_tests
|
||||
from codeflash.models.models import TestFile, TestFiles
|
||||
from codeflash.models.test_type import TestType
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
(tmpdir_path / "package.json").write_text('{"name": "test"}')
|
||||
test_dir = tmpdir_path / "test"
|
||||
test_dir.mkdir()
|
||||
test_file = test_dir / "test_func__line.test.js"
|
||||
test_file.write_text("// test")
|
||||
|
||||
mock_test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
original_file_path=test_file,
|
||||
instrumented_behavior_file_path=test_file,
|
||||
benchmarking_file_path=test_file,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_result = MagicMock()
|
||||
mock_result.stdout = ""
|
||||
mock_result.stderr = ""
|
||||
mock_result.returncode = 1
|
||||
mock_run.return_value = mock_result
|
||||
|
||||
try:
|
||||
run_jest_line_profile_tests(
|
||||
test_paths=mock_test_files,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
project_root=tmpdir_path,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if mock_run.called:
|
||||
cmd = mock_run.call_args[0][0]
|
||||
reporter_args = [a for a in cmd if "--reporters=codeflash/jest-reporter" in a]
|
||||
assert len(reporter_args) == 1
|
||||
|
||||
def test_reporter_produces_valid_junit_xml(self):
|
||||
"""The reporter JS should produce JUnit XML parseable by junitparser."""
|
||||
import subprocess
|
||||
|
||||
reporter_path = Path(__file__).resolve().parents[2] / "packages" / "codeflash" / "runtime" / "jest-reporter.js"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
output_file = Path(tmpdir) / "results.xml"
|
||||
|
||||
# Create a Node.js script that exercises the reporter with mock data
|
||||
test_script = Path(tmpdir) / "test_reporter.js"
|
||||
test_script.write_text(f"""
|
||||
// Set env vars BEFORE requiring reporter (matches real Jest behavior)
|
||||
process.env.JEST_JUNIT_OUTPUT_FILE = '{output_file}';
|
||||
process.env.JEST_JUNIT_CLASSNAME = '{{filepath}}';
|
||||
process.env.JEST_JUNIT_SUITE_NAME = '{{filepath}}';
|
||||
process.env.JEST_JUNIT_ADD_FILE_ATTRIBUTE = 'true';
|
||||
process.env.JEST_JUNIT_INCLUDE_CONSOLE_OUTPUT = 'true';
|
||||
|
||||
const Reporter = require('{reporter_path}');
|
||||
|
||||
// Mock Jest globalConfig
|
||||
const globalConfig = {{ rootDir: '/tmp/project' }};
|
||||
const reporter = new Reporter(globalConfig, {{}});
|
||||
|
||||
// Mock test results (matches Jest's aggregatedResults structure)
|
||||
const results = {{
|
||||
testResults: [
|
||||
{{
|
||||
testFilePath: '/tmp/project/test/math.test.js',
|
||||
displayName: 'math tests',
|
||||
console: [{{ type: 'log', message: 'CODEFLASH_START test1' }}],
|
||||
testResults: [
|
||||
{{
|
||||
fullName: 'math > adds numbers',
|
||||
title: 'adds numbers',
|
||||
status: 'passed',
|
||||
duration: 12,
|
||||
}},
|
||||
{{
|
||||
fullName: 'math > handles failure',
|
||||
title: 'handles failure',
|
||||
status: 'failed',
|
||||
duration: 5,
|
||||
failureMessages: ['Expected 4 but got 5'],
|
||||
}},
|
||||
{{
|
||||
fullName: 'math > skipped test',
|
||||
title: 'skipped test',
|
||||
status: 'pending',
|
||||
duration: 0,
|
||||
}},
|
||||
],
|
||||
}},
|
||||
],
|
||||
}};
|
||||
|
||||
// Simulate onTestFileResult for console capture
|
||||
reporter.onTestFileResult(null, results.testResults[0], null);
|
||||
|
||||
// Simulate onRunComplete
|
||||
reporter.onRunComplete([], results);
|
||||
|
||||
console.log('OK');
|
||||
""")
|
||||
|
||||
result = subprocess.run(
|
||||
["node", str(test_script)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
assert result.returncode == 0, f"Reporter script failed: {result.stderr}"
|
||||
assert output_file.exists(), "Reporter did not create output file"
|
||||
|
||||
xml_content = output_file.read_text()
|
||||
|
||||
# Verify basic XML structure
|
||||
assert '<?xml version="1.0"' in xml_content
|
||||
assert "<testsuites" in xml_content
|
||||
assert "<testsuite" in xml_content
|
||||
assert "<testcase" in xml_content
|
||||
|
||||
# Verify classname uses filepath template
|
||||
assert 'classname="/tmp/project/test/math.test.js"' in xml_content
|
||||
|
||||
# Verify file attribute is present
|
||||
assert 'file="/tmp/project/test/math.test.js"' in xml_content
|
||||
|
||||
# Verify failure element
|
||||
assert "<failure" in xml_content
|
||||
assert "Expected 4 but got 5" in xml_content
|
||||
|
||||
# Verify skipped element
|
||||
assert "<skipped/>" in xml_content
|
||||
|
||||
# Verify system-out with console output
|
||||
assert "<system-out>" in xml_content
|
||||
assert "CODEFLASH_START" in xml_content
|
||||
|
||||
# Verify it's parseable by junitparser (our actual parser)
|
||||
from junitparser import JUnitXml
|
||||
|
||||
parsed = JUnitXml.fromfile(str(output_file))
|
||||
suites = list(parsed)
|
||||
assert len(suites) == 1
|
||||
testcases = list(suites[0])
|
||||
assert len(testcases) == 3
|
||||
|
||||
def test_reporter_export_in_package_json(self):
|
||||
"""package.json should export codeflash/jest-reporter."""
|
||||
import json
|
||||
|
||||
pkg_path = Path(__file__).resolve().parents[2] / "packages" / "codeflash" / "package.json"
|
||||
with pkg_path.open() as f:
|
||||
pkg = json.load(f)
|
||||
|
||||
exports = pkg.get("exports", {})
|
||||
assert "./jest-reporter" in exports, "Missing ./jest-reporter export in package.json"
|
||||
assert exports["./jest-reporter"]["require"] == "./runtime/jest-reporter.js"
|
||||
|
||||
|
||||
|
||||
class TestUnsupportedFrameworkError:
|
||||
"""Tests for clear error on unsupported test frameworks."""
|
||||
|
||||
def test_unknown_framework_raises_error_behavioral(self):
|
||||
"""run_behavioral_tests should raise NotImplementedError for unknown frameworks."""
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport
|
||||
|
||||
support = JavaScriptSupport()
|
||||
with pytest.raises(NotImplementedError, match="not yet supported"):
|
||||
support.run_behavioral_tests(
|
||||
test_paths=MagicMock(),
|
||||
test_env={},
|
||||
cwd=Path("."),
|
||||
test_framework="tap",
|
||||
)
|
||||
|
||||
def test_unknown_framework_raises_error_benchmarking(self):
|
||||
"""run_benchmarking_tests should raise NotImplementedError for unknown frameworks."""
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport
|
||||
|
||||
support = JavaScriptSupport()
|
||||
with pytest.raises(NotImplementedError, match="not yet supported"):
|
||||
support.run_benchmarking_tests(
|
||||
test_paths=MagicMock(),
|
||||
test_env={},
|
||||
cwd=Path("."),
|
||||
test_framework="tap",
|
||||
)
|
||||
|
||||
def test_unknown_framework_raises_error_line_profile(self):
|
||||
"""run_line_profile_tests should raise NotImplementedError for unknown frameworks."""
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport
|
||||
|
||||
support = JavaScriptSupport()
|
||||
with pytest.raises(NotImplementedError, match="not yet supported"):
|
||||
support.run_line_profile_tests(
|
||||
test_paths=MagicMock(),
|
||||
test_env={},
|
||||
cwd=Path("."),
|
||||
test_framework="tap",
|
||||
)
|
||||
|
||||
def test_jest_framework_does_not_raise_not_implemented(self):
|
||||
"""jest framework should NOT raise NotImplementedError."""
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport
|
||||
|
||||
support = JavaScriptSupport()
|
||||
try:
|
||||
support.run_behavioral_tests(
|
||||
test_paths=MagicMock(),
|
||||
test_env={},
|
||||
cwd=Path("."),
|
||||
test_framework="jest",
|
||||
)
|
||||
except NotImplementedError:
|
||||
pytest.fail("jest framework should not raise NotImplementedError")
|
||||
except Exception:
|
||||
pass # Other exceptions are fine — Jest isn't installed in test env
|
||||
|
||||
def test_mocha_framework_does_not_raise_not_implemented(self):
|
||||
"""mocha framework should NOT raise NotImplementedError."""
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport
|
||||
|
||||
support = JavaScriptSupport()
|
||||
try:
|
||||
support.run_behavioral_tests(
|
||||
test_paths=MagicMock(),
|
||||
test_env={},
|
||||
cwd=Path("."),
|
||||
test_framework="mocha",
|
||||
)
|
||||
except NotImplementedError:
|
||||
pytest.fail("mocha framework should not raise NotImplementedError")
|
||||
except Exception:
|
||||
pass # Other exceptions are fine — Mocha isn't installed in test env
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from codeflash.languages.base import Language
|
|||
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
|
||||
from codeflash.languages.registry import get_language_support
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
FIXTURES_DIR = Path(__file__).parent / "fixtures"
|
||||
|
|
@ -37,7 +37,7 @@ class TestCodeExtractorCJS:
|
|||
def test_discover_class_methods(self, js_support, cjs_project):
|
||||
"""Test that class methods are discovered correctly."""
|
||||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
method_names = {f.function_name for f in functions}
|
||||
|
||||
|
|
@ -47,17 +47,19 @@ class TestCodeExtractorCJS:
|
|||
def test_class_method_has_correct_parent(self, js_support, cjs_project):
|
||||
"""Test parent class information for methods."""
|
||||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
for func in functions:
|
||||
# All methods should belong to Calculator class
|
||||
assert func.is_method is True, f"{func.function_name} should be a method"
|
||||
assert func.class_name == "Calculator", f"{func.function_name} should belong to Calculator, got {func.class_name}"
|
||||
assert func.class_name == "Calculator", (
|
||||
f"{func.function_name} should belong to Calculator, got {func.class_name}"
|
||||
)
|
||||
|
||||
def test_extract_permutation_code(self, js_support, cjs_project):
|
||||
"""Test permutation method code extraction."""
|
||||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
permutation_func = next(f for f in functions if f.function_name == "permutation")
|
||||
|
||||
|
|
@ -93,7 +95,7 @@ class Calculator {
|
|||
def test_extract_context_includes_direct_helpers(self, js_support, cjs_project):
|
||||
"""Test that direct helper functions are included in context."""
|
||||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
permutation_func = next(f for f in functions if f.function_name == "permutation")
|
||||
|
||||
|
|
@ -129,7 +131,7 @@ export function factorial(n) {
|
|||
def test_extract_compound_interest_code(self, js_support, cjs_project):
|
||||
"""Test calculateCompoundInterest code extraction."""
|
||||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
|
||||
|
||||
|
|
@ -175,7 +177,7 @@ class Calculator {
|
|||
def test_extract_compound_interest_helpers(self, js_support, cjs_project):
|
||||
"""Test helper extraction for calculateCompoundInterest."""
|
||||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
|
||||
|
||||
|
|
@ -235,7 +237,7 @@ export function validateInput(value, name) {
|
|||
def test_extract_context_includes_imports(self, js_support, cjs_project):
|
||||
"""Test import statement extraction."""
|
||||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
|
||||
|
||||
|
|
@ -256,7 +258,7 @@ export function validateInput(value, name) {
|
|||
def test_extract_static_method(self, js_support, cjs_project):
|
||||
"""Test static method extraction (quickAdd)."""
|
||||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
quick_add_func = next(f for f in functions if f.function_name == "quickAdd")
|
||||
|
||||
|
|
@ -315,7 +317,7 @@ class TestCodeExtractorESM:
|
|||
def test_discover_esm_methods(self, js_support, esm_project):
|
||||
"""Test method discovery in ESM project."""
|
||||
calculator_file = esm_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
method_names = {f.function_name for f in functions}
|
||||
|
||||
|
|
@ -326,7 +328,7 @@ class TestCodeExtractorESM:
|
|||
def test_esm_permutation_extraction(self, js_support, esm_project):
|
||||
"""Test permutation method extraction in ESM."""
|
||||
calculator_file = esm_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
permutation_func = next(f for f in functions if f.function_name == "permutation")
|
||||
|
||||
|
|
@ -376,7 +378,7 @@ export function factorial(n) {
|
|||
def test_esm_compound_interest_extraction(self, js_support, esm_project):
|
||||
"""Test calculateCompoundInterest extraction in ESM with import syntax."""
|
||||
calculator_file = esm_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
|
||||
|
||||
|
|
@ -502,7 +504,7 @@ class TestCodeExtractorTypeScript:
|
|||
def test_discover_ts_methods(self, ts_support, ts_project):
|
||||
"""Test method discovery in TypeScript."""
|
||||
calculator_file = ts_project / "calculator.ts"
|
||||
functions = ts_support.discover_functions(calculator_file)
|
||||
functions = ts_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
method_names = {f.function_name for f in functions}
|
||||
|
||||
|
|
@ -513,7 +515,7 @@ class TestCodeExtractorTypeScript:
|
|||
def test_ts_permutation_extraction(self, ts_support, ts_project):
|
||||
"""Test permutation method extraction in TypeScript."""
|
||||
calculator_file = ts_project / "calculator.ts"
|
||||
functions = ts_support.discover_functions(calculator_file)
|
||||
functions = ts_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
permutation_func = next(f for f in functions if f.function_name == "permutation")
|
||||
|
||||
|
|
@ -566,7 +568,7 @@ export function factorial(n: number): number {
|
|||
def test_ts_compound_interest_extraction(self, ts_support, ts_project):
|
||||
"""Test calculateCompoundInterest extraction in TypeScript."""
|
||||
calculator_file = ts_project / "calculator.ts"
|
||||
functions = ts_support.discover_functions(calculator_file)
|
||||
functions = ts_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
|
||||
|
||||
|
|
@ -676,7 +678,7 @@ module.exports = { standalone };
|
|||
test_file = tmp_path / "standalone.js"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
func = next(f for f in functions if f.function_name == "standalone")
|
||||
|
||||
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -709,7 +711,7 @@ module.exports = { processArray };
|
|||
test_file = tmp_path / "processor.js"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
func = next(f for f in functions if f.function_name == "processArray")
|
||||
|
||||
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -744,7 +746,7 @@ module.exports = { fibonacci };
|
|||
test_file = tmp_path / "recursive.js"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
func = next(f for f in functions if f.function_name == "fibonacci")
|
||||
|
||||
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -777,7 +779,7 @@ module.exports = { processValue };
|
|||
test_file = tmp_path / "arrow.js"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
func = next(f for f in functions if f.function_name == "processValue")
|
||||
|
||||
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -835,7 +837,7 @@ module.exports = { Counter };
|
|||
test_file = tmp_path / "counter.js"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
increment_func = next(f for f in functions if f.function_name == "increment")
|
||||
|
||||
context = js_support.extract_code_context(function=increment_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -874,7 +876,7 @@ module.exports = { MathUtils };
|
|||
test_file = tmp_path / "math_utils.js"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
add_func = next(f for f in functions if f.function_name == "add")
|
||||
|
||||
context = js_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -910,7 +912,7 @@ export class User {
|
|||
test_file = tmp_path / "user.ts"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
get_name_func = next(f for f in functions if f.function_name == "getName")
|
||||
|
||||
context = ts_support.extract_code_context(function=get_name_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -949,7 +951,7 @@ export class Config {
|
|||
test_file = tmp_path / "config.ts"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
get_url_func = next(f for f in functions if f.function_name == "getUrl")
|
||||
|
||||
context = ts_support.extract_code_context(function=get_url_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -990,7 +992,7 @@ module.exports = { Logger };
|
|||
test_file = tmp_path / "logger.js"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
get_prefix_func = next(f for f in functions if f.function_name == "getPrefix")
|
||||
|
||||
context = js_support.extract_code_context(function=get_prefix_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -1032,7 +1034,7 @@ module.exports = { Factory };
|
|||
test_file = tmp_path / "factory.js"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
create_func = next(f for f in functions if f.function_name == "create")
|
||||
|
||||
context = js_support.extract_code_context(function=create_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -1074,7 +1076,7 @@ class TestCodeExtractorIntegration:
|
|||
js_support = get_language_support("javascript")
|
||||
calculator_file = cjs_project / "calculator.js"
|
||||
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
target = next(f for f in functions if f.function_name == "permutation")
|
||||
|
||||
parents = [FunctionParent(name=p.name, type=p.type) for p in target.parents]
|
||||
|
|
@ -1099,7 +1101,7 @@ class TestCodeExtractorIntegration:
|
|||
pytest_cmd="jest",
|
||||
)
|
||||
|
||||
func_optimizer = FunctionOptimizer(
|
||||
func_optimizer = JavaScriptFunctionOptimizer(
|
||||
function_to_optimize=func, test_cfg=test_config, aiservice_client=MagicMock()
|
||||
)
|
||||
result = func_optimizer.get_code_optimization_context()
|
||||
|
|
@ -1182,7 +1184,7 @@ export function distance(p1: Point, p2: Point): number {
|
|||
test_file = tmp_path / "geometry.ts"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
distance_func = next(f for f in functions if f.function_name == "distance")
|
||||
|
||||
context = ts_support.extract_code_context(function=distance_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -1224,7 +1226,7 @@ export function processStatus(status: Status): string {
|
|||
test_file = tmp_path / "status.ts"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
process_func = next(f for f in functions if f.function_name == "processStatus")
|
||||
|
||||
context = ts_support.extract_code_context(function=process_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -1259,7 +1261,7 @@ export function compute(x: number): Result<number> {
|
|||
test_file = tmp_path / "compute.ts"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
compute_func = next(f for f in functions if f.function_name == "compute")
|
||||
|
||||
context = ts_support.extract_code_context(function=compute_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -1301,7 +1303,7 @@ export class Service {
|
|||
test_file = tmp_path / "service.ts"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
get_timeout_func = next(f for f in functions if f.function_name == "getTimeout")
|
||||
|
||||
context = ts_support.extract_code_context(
|
||||
|
|
@ -1332,7 +1334,7 @@ export function add(a: number, b: number): number {
|
|||
test_file = tmp_path / "add.ts"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
add_func = next(f for f in functions if f.function_name == "add")
|
||||
|
||||
context = ts_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -1363,7 +1365,7 @@ export function createRect(origin: Point, size: Size): { origin: Point; size: Si
|
|||
test_file = tmp_path / "rect.ts"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
create_rect_func = next(f for f in functions if f.function_name == "createRect")
|
||||
|
||||
context = ts_support.extract_code_context(
|
||||
|
|
@ -1409,7 +1411,7 @@ export function calculateDistance(p1: Point, p2: Point, config: CalculationConfi
|
|||
}
|
||||
""")
|
||||
|
||||
functions = ts_support.discover_functions(geometry_file)
|
||||
functions = ts_support.discover_functions(geometry_file.read_text(encoding="utf-8"), geometry_file)
|
||||
calc_distance_func = next(f for f in functions if f.function_name == "calculateDistance")
|
||||
|
||||
context = ts_support.extract_code_context(
|
||||
|
|
@ -1460,7 +1462,7 @@ export function greetUser(user: User): string {
|
|||
test_file = tmp_path / "user.ts"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
greet_func = next(f for f in functions if f.function_name == "greetUser")
|
||||
|
||||
context = ts_support.extract_code_context(function=greet_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ These tests verify that code replacement correctly handles:
|
|||
- ES Modules (import/export) syntax
|
||||
- TypeScript import handling
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
|
|
@ -14,8 +15,8 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.python.static_analysis.code_replacer import replace_function_definitions_for_language
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.code_replacer import replace_function_definitions_for_language
|
||||
from codeflash.languages.current import set_current_language
|
||||
from codeflash.languages.javascript.module_system import (
|
||||
ModuleSystem,
|
||||
|
|
@ -25,7 +26,6 @@ from codeflash.languages.javascript.module_system import (
|
|||
ensure_module_system_compatibility,
|
||||
get_import_statement,
|
||||
)
|
||||
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
|
||||
from codeflash.models.models import CodeStringsMarkdown
|
||||
|
||||
|
|
@ -50,7 +50,6 @@ def temp_project(tmp_path):
|
|||
return project_root
|
||||
|
||||
|
||||
|
||||
FIXTURES_DIR = Path(__file__).parent / "fixtures"
|
||||
|
||||
|
||||
|
|
@ -308,7 +307,9 @@ class TestTsJestSkipsConversion:
|
|||
When ts-jest is installed, it handles module interoperability internally,
|
||||
so we skip conversion to avoid breaking valid imports.
|
||||
"""
|
||||
def __init__(self):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _set_language(self):
|
||||
set_current_language(Language.TYPESCRIPT)
|
||||
|
||||
def test_commonjs_not_converted_when_ts_jest_installed(self, tmp_path):
|
||||
|
|
@ -751,6 +752,7 @@ class TestIntegrationWithFixtures:
|
|||
f"import statements should be converted to require.\nFound import lines: {import_lines}"
|
||||
)
|
||||
|
||||
|
||||
class TestSimpleFunctionReplacement:
|
||||
"""Tests for simple function body replacement with strict assertions."""
|
||||
|
||||
|
|
@ -764,7 +766,8 @@ export function add(a, b) {
|
|||
file_path = temp_project / "math.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
# Optimized version with different body
|
||||
|
|
@ -800,7 +803,8 @@ export function processData(data) {
|
|||
file_path = temp_project / "processor.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
# Optimized version using map
|
||||
|
|
@ -839,7 +843,8 @@ module.exports = { targetFunction, otherFunction };
|
|||
file_path = temp_project / "module.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
target_func = next(f for f in functions if f.function_name == "targetFunction")
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -891,7 +896,8 @@ export class Calculator {
|
|||
file_path = temp_project / "calculator.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
add_method = next(f for f in functions if f.function_name == "add")
|
||||
|
||||
# Optimized version provided in class context
|
||||
|
|
@ -954,7 +960,8 @@ export class DataProcessor {
|
|||
file_path = temp_project / "processor.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
process_method = next(f for f in functions if f.function_name == "process")
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1016,7 +1023,8 @@ export function add(a, b) {
|
|||
file_path = temp_project / "math.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1070,7 +1078,8 @@ export class Cache {
|
|||
file_path = temp_project / "cache.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
get_method = next(f for f in functions if f.function_name == "get")
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1131,7 +1140,8 @@ export async function fetchData(url) {
|
|||
file_path = temp_project / "api.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1172,7 +1182,8 @@ export class ApiClient {
|
|||
file_path = temp_project / "client.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
get_method = next(f for f in functions if f.function_name == "get")
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1223,7 +1234,8 @@ export function* range(start, end) {
|
|||
file_path = temp_project / "generators.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1262,7 +1274,8 @@ export function processArray(items: number[]): number {
|
|||
file_path = temp_project / "processor.ts"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1303,7 +1316,8 @@ export class Container<T> {
|
|||
file_path = temp_project / "container.ts"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
get_all_method = next(f for f in functions if f.function_name == "getAll")
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1356,7 +1370,8 @@ export function createUser(name: string, email: string): User {
|
|||
file_path = temp_project / "user.ts"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
func = next(f for f in functions if f.function_name == "createUser")
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1411,7 +1426,8 @@ export function processItems(items) {
|
|||
file_path = temp_project / "processor.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
process_func = next(f for f in functions if f.function_name == "processItems")
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1458,7 +1474,8 @@ export class MathUtils {
|
|||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
# First replacement: sum method
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
sum_method = next(f for f in functions if f.function_name == "sum")
|
||||
|
||||
optimized_sum = """\
|
||||
|
|
@ -1505,7 +1522,8 @@ export function processConfig({ server: { host, port }, database: { url, poolSiz
|
|||
file_path = temp_project / "config.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1544,7 +1562,8 @@ export function minimal() {
|
|||
file_path = temp_project / "minimal.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1571,7 +1590,8 @@ export function identity(x) { return x; }
|
|||
file_path = temp_project / "utils.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1598,7 +1618,8 @@ export function formatMessage(name) {
|
|||
file_path = temp_project / "formatter.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1633,7 +1654,8 @@ export function validateEmail(email) {
|
|||
file_path = temp_project / "validator.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1676,7 +1698,8 @@ module.exports = { main, helper };
|
|||
file_path = temp_project / "module.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
main_func = next(f for f in functions if f.function_name == "main")
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1719,7 +1742,8 @@ export function main(data) {
|
|||
file_path = temp_project / "module.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
main_func = next(f for f in functions if f.function_name == "main")
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1750,20 +1774,16 @@ class TestSyntaxValidation:
|
|||
"""Test that various replacements all produce valid JavaScript."""
|
||||
test_cases = [
|
||||
# (original, optimized, description)
|
||||
(
|
||||
"export function f(x) { return x + 1; }",
|
||||
"export function f(x) { return ++x; }",
|
||||
"increment replacement"
|
||||
),
|
||||
("export function f(x) { return x + 1; }", "export function f(x) { return ++x; }", "increment replacement"),
|
||||
(
|
||||
"export function f(arr) { return arr.length > 0; }",
|
||||
"export function f(arr) { return !!arr.length; }",
|
||||
"boolean conversion"
|
||||
"boolean conversion",
|
||||
),
|
||||
(
|
||||
"export function f(a, b) { if (a) { return a; } return b; }",
|
||||
"export function f(a, b) { return a || b; }",
|
||||
"logical OR replacement"
|
||||
"logical OR replacement",
|
||||
),
|
||||
]
|
||||
|
||||
|
|
@ -1771,7 +1791,8 @@ class TestSyntaxValidation:
|
|||
file_path = temp_project / f"test_{i}.js"
|
||||
file_path.write_text(original, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
result = js_support.replace_function(original, func, optimized)
|
||||
|
|
@ -1875,7 +1896,8 @@ export class DataProcessor<T> {
|
|||
target_func = "findDuplicates"
|
||||
parent_class = "DataProcessor"
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
# find function
|
||||
target_func_info = None
|
||||
for func in functions:
|
||||
|
|
@ -1920,11 +1942,15 @@ class DataProcessor<T> {
|
|||
```
|
||||
"""
|
||||
code_markdown = CodeStringsMarkdown.parse_markdown_code(new_code)
|
||||
replaced = replace_function_definitions_for_language([f"{parent_class}.{target_func}"], code_markdown, file_path, temp_project)
|
||||
replaced = replace_function_definitions_for_language(
|
||||
[f"{parent_class}.{target_func}"], code_markdown, file_path, temp_project, lang_support=ts_support
|
||||
)
|
||||
assert replaced
|
||||
|
||||
new_code = file_path.read_text()
|
||||
assert new_code == """/**
|
||||
assert (
|
||||
new_code
|
||||
== """/**
|
||||
* DataProcessor class - demonstrates class method optimization in TypeScript.
|
||||
* Contains intentionally inefficient implementations for optimization testing.
|
||||
*/
|
||||
|
|
@ -2015,7 +2041,7 @@ export class DataProcessor<T> {
|
|||
}
|
||||
}
|
||||
"""
|
||||
|
||||
)
|
||||
|
||||
|
||||
class TestNewVariableFromOptimizedCode:
|
||||
|
|
@ -2030,9 +2056,9 @@ class TestNewVariableFromOptimizedCode:
|
|||
1. Add the new variable after the constant it references
|
||||
2. Replace the function with the optimized version
|
||||
"""
|
||||
from codeflash.models.models import CodeStringsMarkdown, CodeString
|
||||
from codeflash.models.models import CodeString, CodeStringsMarkdown
|
||||
|
||||
original_source = '''\
|
||||
original_source = """\
|
||||
const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([
|
||||
"1234",
|
||||
]);
|
||||
|
|
@ -2040,43 +2066,34 @@ const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([
|
|||
export function isCodeflashEmployee(userId: string): boolean {
|
||||
return CODEFLASH_EMPLOYEE_GITHUB_IDS.has(userId);
|
||||
}
|
||||
'''
|
||||
"""
|
||||
file_path = temp_project / "auth.ts"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
# Optimized code introduces a bound method variable for performance
|
||||
optimized_code = '''const _has: (id: string) => boolean = CODEFLASH_EMPLOYEE_GITHUB_IDS.has.bind(
|
||||
optimized_code = """const _has: (id: string) => boolean = CODEFLASH_EMPLOYEE_GITHUB_IDS.has.bind(
|
||||
CODEFLASH_EMPLOYEE_GITHUB_IDS
|
||||
);
|
||||
|
||||
export function isCodeflashEmployee(userId: string): boolean {
|
||||
return _has(userId);
|
||||
}
|
||||
'''
|
||||
"""
|
||||
|
||||
code_markdown = CodeStringsMarkdown(
|
||||
code_strings=[
|
||||
CodeString(
|
||||
code=optimized_code,
|
||||
file_path=Path("auth.ts"),
|
||||
language="typescript"
|
||||
)
|
||||
],
|
||||
language="typescript"
|
||||
code_strings=[CodeString(code=optimized_code, file_path=Path("auth.ts"), language="typescript")],
|
||||
language="typescript",
|
||||
)
|
||||
|
||||
replaced = replace_function_definitions_for_language(
|
||||
["isCodeflashEmployee"],
|
||||
code_markdown,
|
||||
file_path,
|
||||
temp_project,
|
||||
["isCodeflashEmployee"], code_markdown, file_path, temp_project, lang_support=ts_support
|
||||
)
|
||||
|
||||
assert replaced
|
||||
result = file_path.read_text()
|
||||
|
||||
# Expected result for strict equality check
|
||||
expected_result = '''\
|
||||
expected_result = """\
|
||||
const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([
|
||||
"1234",
|
||||
]);
|
||||
|
|
@ -2088,11 +2105,9 @@ const _has: (id: string) => boolean = CODEFLASH_EMPLOYEE_GITHUB_IDS.has.bind(
|
|||
export function isCodeflashEmployee(userId: string): boolean {
|
||||
return _has(userId);
|
||||
}
|
||||
'''
|
||||
"""
|
||||
assert result == expected_result, (
|
||||
f"Result does not match expected output.\n"
|
||||
f"Expected:\n{expected_result}\n\n"
|
||||
f"Got:\n{result}"
|
||||
f"Result does not match expected output.\nExpected:\n{expected_result}\n\nGot:\n{result}"
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -2113,7 +2128,7 @@ class TestImportedTypeNotDuplicated:
|
|||
contains the TreeNode interface definition (from read-only context),
|
||||
the replacement should NOT add the interface to the original file.
|
||||
"""
|
||||
from codeflash.models.models import CodeStringsMarkdown, CodeString
|
||||
from codeflash.models.models import CodeString, CodeStringsMarkdown
|
||||
|
||||
# Original source imports TreeNode
|
||||
original_source = """\
|
||||
|
|
@ -2163,20 +2178,13 @@ export function getNearestAbove(
|
|||
|
||||
code_markdown = CodeStringsMarkdown(
|
||||
code_strings=[
|
||||
CodeString(
|
||||
code=optimized_code_with_interface,
|
||||
file_path=Path("helpers.ts"),
|
||||
language="typescript"
|
||||
)
|
||||
CodeString(code=optimized_code_with_interface, file_path=Path("helpers.ts"), language="typescript")
|
||||
],
|
||||
language="typescript"
|
||||
language="typescript",
|
||||
)
|
||||
|
||||
replace_function_definitions_for_language(
|
||||
["getNearestAbove"],
|
||||
code_markdown,
|
||||
file_path,
|
||||
temp_project,
|
||||
["getNearestAbove"], code_markdown, file_path, temp_project, lang_support=ts_support
|
||||
)
|
||||
|
||||
result = file_path.read_text()
|
||||
|
|
@ -2203,7 +2211,7 @@ export function getNearestAbove(
|
|||
|
||||
def test_multiple_imported_types_not_duplicated(self, ts_support, temp_project):
|
||||
"""Test that multiple imported types are not duplicated."""
|
||||
from codeflash.models.models import CodeStringsMarkdown, CodeString
|
||||
from codeflash.models.models import CodeString, CodeStringsMarkdown
|
||||
|
||||
original_source = """\
|
||||
import type { TreeNode, NodeSpace } from "./constants";
|
||||
|
|
@ -2235,21 +2243,12 @@ export function processNode(node: TreeNode, space: NodeSpace): number {
|
|||
"""
|
||||
|
||||
code_markdown = CodeStringsMarkdown(
|
||||
code_strings=[
|
||||
CodeString(
|
||||
code=optimized_code,
|
||||
file_path=Path("processor.ts"),
|
||||
language="typescript"
|
||||
)
|
||||
],
|
||||
language="typescript"
|
||||
code_strings=[CodeString(code=optimized_code, file_path=Path("processor.ts"), language="typescript")],
|
||||
language="typescript",
|
||||
)
|
||||
|
||||
replace_function_definitions_for_language(
|
||||
["processNode"],
|
||||
code_markdown,
|
||||
file_path,
|
||||
temp_project,
|
||||
["processNode"], code_markdown, file_path, temp_project, lang_support=ts_support
|
||||
)
|
||||
|
||||
result = file_path.read_text()
|
||||
|
|
|
|||
|
|
@ -345,8 +345,8 @@ class TestDiscoverFunctionsParity:
|
|||
py_file = write_temp_file(SIMPLE_FUNCTION.python, ".py")
|
||||
js_file = write_temp_file(SIMPLE_FUNCTION.javascript, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
# Both should find exactly one function
|
||||
assert len(py_funcs) == 1, f"Python found {len(py_funcs)}, expected 1"
|
||||
|
|
@ -365,8 +365,8 @@ class TestDiscoverFunctionsParity:
|
|||
py_file = write_temp_file(MULTIPLE_FUNCTIONS.python, ".py")
|
||||
js_file = write_temp_file(MULTIPLE_FUNCTIONS.javascript, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
# Both should find 3 functions
|
||||
assert len(py_funcs) == 3, f"Python found {len(py_funcs)}, expected 3"
|
||||
|
|
@ -384,8 +384,8 @@ class TestDiscoverFunctionsParity:
|
|||
py_file = write_temp_file(WITH_AND_WITHOUT_RETURN.python, ".py")
|
||||
js_file = write_temp_file(WITH_AND_WITHOUT_RETURN.javascript, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
# Both should find only 1 function (the one with return)
|
||||
assert len(py_funcs) == 1, f"Python found {len(py_funcs)}, expected 1"
|
||||
|
|
@ -400,8 +400,8 @@ class TestDiscoverFunctionsParity:
|
|||
py_file = write_temp_file(CLASS_METHODS.python, ".py")
|
||||
js_file = write_temp_file(CLASS_METHODS.javascript, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
# Both should find 2 methods
|
||||
assert len(py_funcs) == 2, f"Python found {len(py_funcs)}, expected 2"
|
||||
|
|
@ -421,8 +421,8 @@ class TestDiscoverFunctionsParity:
|
|||
py_file = write_temp_file(ASYNC_FUNCTIONS.python, ".py")
|
||||
js_file = write_temp_file(ASYNC_FUNCTIONS.javascript, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
# Both should find 2 functions
|
||||
assert len(py_funcs) == 2, f"Python found {len(py_funcs)}, expected 2"
|
||||
|
|
@ -440,32 +440,23 @@ class TestDiscoverFunctionsParity:
|
|||
assert js_sync.is_async is False, "JavaScript sync function should have is_async=False"
|
||||
|
||||
def test_nested_functions_discovery(self, python_support, js_support):
|
||||
"""Both should discover nested functions with parent info."""
|
||||
"""Python skips nested functions; JavaScript discovers them with parent info."""
|
||||
py_file = write_temp_file(NESTED_FUNCTIONS.python, ".py")
|
||||
js_file = write_temp_file(NESTED_FUNCTIONS.javascript, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
# Both should find 2 functions (outer and inner)
|
||||
assert len(py_funcs) == 2, f"Python found {len(py_funcs)}, expected 2"
|
||||
# Python skips nested functions — only outer is discovered
|
||||
assert len(py_funcs) == 1, f"Python found {len(py_funcs)}, expected 1"
|
||||
assert py_funcs[0].function_name == "outer"
|
||||
|
||||
# JavaScript discovers both
|
||||
assert len(js_funcs) == 2, f"JavaScript found {len(js_funcs)}, expected 2"
|
||||
|
||||
# Check names
|
||||
py_names = {f.function_name for f in py_funcs}
|
||||
js_names = {f.function_name for f in js_funcs}
|
||||
|
||||
assert py_names == {"outer", "inner"}, f"Python found {py_names}"
|
||||
assert js_names == {"outer", "inner"}, f"JavaScript found {js_names}"
|
||||
|
||||
# Check parent info for inner function
|
||||
py_inner = next(f for f in py_funcs if f.function_name == "inner")
|
||||
js_inner = next(f for f in js_funcs if f.function_name == "inner")
|
||||
|
||||
assert len(py_inner.parents) >= 1, "Python inner should have parent info"
|
||||
assert py_inner.parents[0].name == "outer", "Python inner's parent should be outer"
|
||||
|
||||
# JavaScript nested function parent check
|
||||
assert len(js_inner.parents) >= 1, "JavaScript inner should have parent info"
|
||||
assert js_inner.parents[0].name == "outer", "JavaScript inner's parent should be outer"
|
||||
|
||||
|
|
@ -474,8 +465,8 @@ class TestDiscoverFunctionsParity:
|
|||
py_file = write_temp_file(STATIC_METHODS.python, ".py")
|
||||
js_file = write_temp_file(STATIC_METHODS.javascript, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
# Both should find 1 function
|
||||
assert len(py_funcs) == 1, f"Python found {len(py_funcs)}, expected 1"
|
||||
|
|
@ -492,8 +483,8 @@ class TestDiscoverFunctionsParity:
|
|||
py_file = write_temp_file(COMPLEX_FILE.python, ".py")
|
||||
js_file = write_temp_file(COMPLEX_FILE.javascript, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
# Both should find 4 functions
|
||||
assert len(py_funcs) == 4, f"Python found {len(py_funcs)}, expected 4"
|
||||
|
|
@ -524,8 +515,8 @@ class TestDiscoverFunctionsParity:
|
|||
|
||||
criteria = FunctionFilterCriteria(include_async=False)
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file, criteria)
|
||||
js_funcs = js_support.discover_functions(js_file, criteria)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file, criteria)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file, criteria)
|
||||
|
||||
# Both should find only 1 function (the sync one)
|
||||
assert len(py_funcs) == 1, f"Python found {len(py_funcs)}, expected 1"
|
||||
|
|
@ -542,8 +533,8 @@ class TestDiscoverFunctionsParity:
|
|||
|
||||
criteria = FunctionFilterCriteria(include_methods=False)
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file, criteria)
|
||||
js_funcs = js_support.discover_functions(js_file, criteria)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file, criteria)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file, criteria)
|
||||
|
||||
# Both should find only 1 function (standalone)
|
||||
assert len(py_funcs) == 1, f"Python found {len(py_funcs)}, expected 1"
|
||||
|
|
@ -554,11 +545,11 @@ class TestDiscoverFunctionsParity:
|
|||
assert js_funcs[0].function_name == "standalone"
|
||||
|
||||
def test_nonexistent_file_returns_empty(self, python_support, js_support):
|
||||
"""Both should return empty list for nonexistent files."""
|
||||
py_funcs = python_support.discover_functions(Path("/nonexistent/file.py"))
|
||||
js_funcs = js_support.discover_functions(Path("/nonexistent/file.js"))
|
||||
|
||||
"""Both languages return empty list for empty source."""
|
||||
py_funcs = python_support.discover_functions("", Path("/nonexistent/file.py"))
|
||||
assert py_funcs == []
|
||||
|
||||
js_funcs = js_support.discover_functions("", Path("/nonexistent/file.js"))
|
||||
assert js_funcs == []
|
||||
|
||||
def test_line_numbers_captured(self, python_support, js_support):
|
||||
|
|
@ -566,8 +557,8 @@ class TestDiscoverFunctionsParity:
|
|||
py_file = write_temp_file(SIMPLE_FUNCTION.python, ".py")
|
||||
js_file = write_temp_file(SIMPLE_FUNCTION.javascript, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
# Both should have start_line and end_line
|
||||
assert py_funcs[0].starting_line is not None
|
||||
|
|
@ -917,8 +908,8 @@ class TestIntegrationParity:
|
|||
js_file = write_temp_file(js_original, ".js")
|
||||
|
||||
# Discover
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
assert len(py_funcs) == 1
|
||||
assert len(js_funcs) == 1
|
||||
|
|
@ -969,8 +960,8 @@ class TestFeatureGaps:
|
|||
py_file = write_temp_file(CLASS_METHODS.python, ".py")
|
||||
js_file = write_temp_file(CLASS_METHODS.javascript, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
for py_func in py_funcs:
|
||||
# Check all expected fields are populated
|
||||
|
|
@ -1003,7 +994,7 @@ export const multiply = (x, y) => x * y;
|
|||
export const identity = x => x;
|
||||
"""
|
||||
js_file = write_temp_file(js_code, ".js")
|
||||
funcs = js_support.discover_functions(js_file)
|
||||
funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
# Should find all arrow functions
|
||||
names = {f.function_name for f in funcs}
|
||||
|
|
@ -1030,8 +1021,8 @@ export function* numberGenerator() {
|
|||
py_file = write_temp_file(py_code, ".py")
|
||||
js_file = write_temp_file(js_code, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
# Both should find the generator
|
||||
assert len(py_funcs) == 1, f"Python found {len(py_funcs)} generators"
|
||||
|
|
@ -1054,7 +1045,7 @@ def multi_decorated():
|
|||
return 3
|
||||
"""
|
||||
py_file = write_temp_file(py_code, ".py")
|
||||
funcs = python_support.discover_functions(py_file)
|
||||
funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
|
||||
# Should find all functions regardless of decorators
|
||||
names = {f.function_name for f in funcs}
|
||||
|
|
@ -1074,7 +1065,7 @@ export const namedExpr = function myFunc(x) {
|
|||
};
|
||||
"""
|
||||
js_file = write_temp_file(js_code, ".js")
|
||||
funcs = js_support.discover_functions(js_file)
|
||||
funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
# Should find function expressions
|
||||
names = {f.function_name for f in funcs}
|
||||
|
|
@ -1094,8 +1085,8 @@ class TestEdgeCases:
|
|||
py_file = write_temp_file("", ".py")
|
||||
js_file = write_temp_file("", ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
assert py_funcs == []
|
||||
assert js_funcs == []
|
||||
|
|
@ -1119,8 +1110,8 @@ Multiline comment
|
|||
py_file = write_temp_file(py_code, ".py")
|
||||
js_file = write_temp_file(js_code, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
assert py_funcs == []
|
||||
assert js_funcs == []
|
||||
|
|
@ -1139,8 +1130,8 @@ export function greeting() {
|
|||
py_file = write_temp_file(py_code, ".py")
|
||||
js_file = write_temp_file(js_code, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
assert len(py_funcs) == 1
|
||||
assert len(js_funcs) == 1
|
||||
|
|
|
|||
502
tests/test_languages/test_mocha_runner.py
Normal file
502
tests/test_languages/test_mocha_runner.py
Normal file
|
|
@ -0,0 +1,502 @@
|
|||
"""Tests for Mocha test runner functionality."""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from junitparser import JUnitXml
|
||||
|
||||
|
||||
class TestMochaJsonToJunitXml:
|
||||
"""Tests for converting Mocha JSON reporter output to JUnit XML."""
|
||||
|
||||
def test_passing_tests(self):
|
||||
from codeflash.languages.javascript.mocha_runner import mocha_json_to_junit_xml
|
||||
|
||||
mocha_json = json.dumps(
|
||||
{
|
||||
"stats": {"tests": 2, "passes": 2, "failures": 0, "duration": 50},
|
||||
"tests": [
|
||||
{
|
||||
"title": "should add numbers",
|
||||
"fullTitle": "math should add numbers",
|
||||
"duration": 20,
|
||||
"err": {},
|
||||
},
|
||||
{
|
||||
"title": "should subtract numbers",
|
||||
"fullTitle": "math should subtract numbers",
|
||||
"duration": 30,
|
||||
"err": {},
|
||||
},
|
||||
],
|
||||
"passes": [],
|
||||
"failures": [],
|
||||
"pending": [],
|
||||
}
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
output_file = Path(tmpdir) / "results.xml"
|
||||
mocha_json_to_junit_xml(mocha_json, output_file)
|
||||
|
||||
assert output_file.exists()
|
||||
xml = JUnitXml.fromfile(str(output_file))
|
||||
total_tests = sum(suite.tests for suite in xml)
|
||||
assert total_tests == 2
|
||||
|
||||
def test_failing_tests(self):
|
||||
from codeflash.languages.javascript.mocha_runner import mocha_json_to_junit_xml
|
||||
|
||||
mocha_json = json.dumps(
|
||||
{
|
||||
"stats": {"tests": 1, "passes": 0, "failures": 1, "duration": 10},
|
||||
"tests": [
|
||||
{
|
||||
"title": "should fail",
|
||||
"fullTitle": "errors should fail",
|
||||
"duration": 10,
|
||||
"err": {
|
||||
"message": "expected 1 to equal 2",
|
||||
"stack": "AssertionError: expected 1 to equal 2\n at Context.<anonymous>",
|
||||
},
|
||||
},
|
||||
],
|
||||
"passes": [],
|
||||
"failures": [],
|
||||
"pending": [],
|
||||
}
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
output_file = Path(tmpdir) / "results.xml"
|
||||
mocha_json_to_junit_xml(mocha_json, output_file)
|
||||
|
||||
assert output_file.exists()
|
||||
xml = JUnitXml.fromfile(str(output_file))
|
||||
total_failures = sum(suite.failures for suite in xml)
|
||||
assert total_failures == 1
|
||||
|
||||
def test_pending_tests(self):
|
||||
from codeflash.languages.javascript.mocha_runner import mocha_json_to_junit_xml
|
||||
|
||||
mocha_json = json.dumps(
|
||||
{
|
||||
"stats": {"tests": 1, "passes": 0, "failures": 0, "pending": 1, "duration": 0},
|
||||
"tests": [
|
||||
{
|
||||
"title": "should be pending",
|
||||
"fullTitle": "todo should be pending",
|
||||
"duration": 0,
|
||||
"pending": True,
|
||||
"err": {},
|
||||
},
|
||||
],
|
||||
"passes": [],
|
||||
"failures": [],
|
||||
"pending": [],
|
||||
}
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
output_file = Path(tmpdir) / "results.xml"
|
||||
mocha_json_to_junit_xml(mocha_json, output_file)
|
||||
|
||||
assert output_file.exists()
|
||||
xml = JUnitXml.fromfile(str(output_file))
|
||||
# Should parse without error and have the test
|
||||
total_tests = sum(suite.tests for suite in xml)
|
||||
assert total_tests == 1
|
||||
|
||||
def test_invalid_json_writes_empty_xml(self):
|
||||
from codeflash.languages.javascript.mocha_runner import mocha_json_to_junit_xml
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
output_file = Path(tmpdir) / "results.xml"
|
||||
mocha_json_to_junit_xml("not valid json {{{", output_file)
|
||||
|
||||
assert output_file.exists()
|
||||
content = output_file.read_text()
|
||||
assert "<testsuites" in content
|
||||
|
||||
def test_multiple_suites(self):
|
||||
from codeflash.languages.javascript.mocha_runner import mocha_json_to_junit_xml
|
||||
|
||||
mocha_json = json.dumps(
|
||||
{
|
||||
"stats": {"tests": 3, "passes": 3, "failures": 0, "duration": 100},
|
||||
"tests": [
|
||||
{"title": "test1", "fullTitle": "suite A test1", "duration": 10, "err": {}},
|
||||
{"title": "test2", "fullTitle": "suite A test2", "duration": 20, "err": {}},
|
||||
{"title": "test3", "fullTitle": "suite B test3", "duration": 30, "err": {}},
|
||||
],
|
||||
"passes": [],
|
||||
"failures": [],
|
||||
"pending": [],
|
||||
}
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
output_file = Path(tmpdir) / "results.xml"
|
||||
mocha_json_to_junit_xml(mocha_json, output_file)
|
||||
|
||||
xml = JUnitXml.fromfile(str(output_file))
|
||||
suite_names = [suite.name for suite in xml]
|
||||
assert "suite A" in suite_names
|
||||
assert "suite B" in suite_names
|
||||
|
||||
|
||||
class TestExtractMochaJson:
|
||||
"""Tests for extracting Mocha JSON from mixed stdout."""
|
||||
|
||||
def test_clean_json(self):
|
||||
from codeflash.languages.javascript.mocha_runner import _extract_mocha_json
|
||||
|
||||
data = {"stats": {"tests": 1}, "tests": []}
|
||||
result = _extract_mocha_json(json.dumps(data))
|
||||
assert result is not None
|
||||
assert json.loads(result)["stats"]["tests"] == 1
|
||||
|
||||
def test_json_with_leading_output(self):
|
||||
from codeflash.languages.javascript.mocha_runner import _extract_mocha_json
|
||||
|
||||
stdout = 'Some console output\n{"stats": {"tests": 1}, "tests": []}'
|
||||
result = _extract_mocha_json(stdout)
|
||||
assert result is not None
|
||||
assert json.loads(result)["stats"]["tests"] == 1
|
||||
|
||||
def test_json_with_codeflash_markers(self):
|
||||
from codeflash.languages.javascript.mocha_runner import _extract_mocha_json
|
||||
|
||||
data = {"stats": {"tests": 1}, "tests": []}
|
||||
stdout = f"!######START:test:module:0:test_name######!\n{json.dumps(data)}\n!######END######!"
|
||||
result = _extract_mocha_json(stdout)
|
||||
assert result is not None
|
||||
|
||||
def test_no_json_returns_none(self):
|
||||
from codeflash.languages.javascript.mocha_runner import _extract_mocha_json
|
||||
|
||||
result = _extract_mocha_json("no json here at all")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestFindMochaProjectRoot:
|
||||
"""Tests for finding Mocha project root."""
|
||||
|
||||
def test_finds_mocharc_yml(self):
|
||||
from codeflash.languages.javascript.mocha_runner import _find_mocha_project_root
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
root = Path(tmpdir)
|
||||
(root / ".mocharc.yml").write_text("timeout: 5000\n")
|
||||
sub = root / "src" / "lib"
|
||||
sub.mkdir(parents=True)
|
||||
test_file = sub / "test.js"
|
||||
test_file.write_text("// test")
|
||||
|
||||
result = _find_mocha_project_root(test_file)
|
||||
assert result == root
|
||||
|
||||
def test_finds_mocharc_json(self):
|
||||
from codeflash.languages.javascript.mocha_runner import _find_mocha_project_root
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
root = Path(tmpdir)
|
||||
(root / ".mocharc.json").write_text("{}")
|
||||
test_file = root / "test.js"
|
||||
test_file.write_text("// test")
|
||||
|
||||
result = _find_mocha_project_root(test_file)
|
||||
assert result == root
|
||||
|
||||
def test_falls_back_to_package_json(self):
|
||||
from codeflash.languages.javascript.mocha_runner import _find_mocha_project_root
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
root = Path(tmpdir)
|
||||
(root / "package.json").write_text('{"name": "test"}')
|
||||
sub = root / "test"
|
||||
sub.mkdir()
|
||||
test_file = sub / "test.js"
|
||||
test_file.write_text("// test")
|
||||
|
||||
result = _find_mocha_project_root(test_file)
|
||||
assert result == root
|
||||
|
||||
|
||||
class TestMochaBehavioralCommand:
|
||||
"""Tests for building Mocha behavioral commands."""
|
||||
|
||||
def test_basic_command(self):
|
||||
from codeflash.languages.javascript.mocha_runner import _build_mocha_behavioral_command
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.js"
|
||||
test_file.write_text("// test")
|
||||
|
||||
cmd = _build_mocha_behavioral_command(test_files=[test_file])
|
||||
assert "npx" in cmd
|
||||
assert "mocha" in cmd
|
||||
assert "--reporter" in cmd
|
||||
assert "json" in cmd
|
||||
assert "--jobs" in cmd
|
||||
assert "1" in cmd
|
||||
assert "--exit" in cmd
|
||||
|
||||
def test_timeout_flag(self):
|
||||
from codeflash.languages.javascript.mocha_runner import _build_mocha_behavioral_command
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.js"
|
||||
test_file.write_text("// test")
|
||||
|
||||
cmd = _build_mocha_behavioral_command(test_files=[test_file], timeout=30)
|
||||
timeout_idx = cmd.index("--timeout")
|
||||
assert cmd[timeout_idx + 1] == "30000"
|
||||
|
||||
def test_default_timeout(self):
|
||||
from codeflash.languages.javascript.mocha_runner import _build_mocha_behavioral_command
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.js"
|
||||
test_file.write_text("// test")
|
||||
|
||||
cmd = _build_mocha_behavioral_command(test_files=[test_file])
|
||||
timeout_idx = cmd.index("--timeout")
|
||||
assert cmd[timeout_idx + 1] == "60000"
|
||||
|
||||
|
||||
class TestMochaBenchmarkingCommand:
|
||||
"""Tests for building Mocha benchmarking commands."""
|
||||
|
||||
def test_basic_command(self):
|
||||
from codeflash.languages.javascript.mocha_runner import _build_mocha_benchmarking_command
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.js"
|
||||
test_file.write_text("// test")
|
||||
|
||||
cmd = _build_mocha_benchmarking_command(test_files=[test_file])
|
||||
assert "npx" in cmd
|
||||
assert "mocha" in cmd
|
||||
assert "--exit" in cmd
|
||||
|
||||
def test_default_timeout_is_longer(self):
|
||||
from codeflash.languages.javascript.mocha_runner import _build_mocha_benchmarking_command
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.js"
|
||||
test_file.write_text("// test")
|
||||
|
||||
cmd = _build_mocha_benchmarking_command(test_files=[test_file])
|
||||
timeout_idx = cmd.index("--timeout")
|
||||
assert cmd[timeout_idx + 1] == "120000"
|
||||
|
||||
|
||||
class TestMochaLineProfileCommand:
|
||||
"""Tests for building Mocha line profile commands."""
|
||||
|
||||
def test_basic_command(self):
|
||||
from codeflash.languages.javascript.mocha_runner import _build_mocha_line_profile_command
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.js"
|
||||
test_file.write_text("// test")
|
||||
|
||||
cmd = _build_mocha_line_profile_command(test_files=[test_file])
|
||||
assert "npx" in cmd
|
||||
assert "mocha" in cmd
|
||||
assert "--exit" in cmd
|
||||
|
||||
def test_timeout_conversion(self):
|
||||
from codeflash.languages.javascript.mocha_runner import _build_mocha_line_profile_command
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.js"
|
||||
test_file.write_text("// test")
|
||||
|
||||
cmd = _build_mocha_line_profile_command(test_files=[test_file], timeout=45)
|
||||
timeout_idx = cmd.index("--timeout")
|
||||
assert cmd[timeout_idx + 1] == "45000"
|
||||
|
||||
|
||||
class TestRunMochaBehavioralTests:
|
||||
"""Tests for running Mocha behavioral tests with mocked subprocess."""
|
||||
|
||||
@patch("codeflash.languages.javascript.mocha_runner.subprocess.run")
|
||||
@patch("codeflash.languages.javascript.mocha_runner._ensure_runtime_files")
|
||||
def test_sets_codeflash_env_vars(self, mock_ensure, mock_run):
|
||||
from codeflash.languages.javascript.mocha_runner import run_mocha_behavioral_tests
|
||||
from codeflash.models.models import TestFile, TestFiles
|
||||
from codeflash.models.test_type import TestType
|
||||
|
||||
mocha_output = json.dumps(
|
||||
{"stats": {"tests": 1, "passes": 1, "failures": 0, "duration": 10}, "tests": [{"title": "t", "fullTitle": "s t", "duration": 10, "err": {}}], "passes": [], "failures": [], "pending": []}
|
||||
)
|
||||
mock_run.return_value = MagicMock(returncode=0, stdout=mocha_output, stderr="", args=[])
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
(tmpdir_path / "package.json").write_text('{"name": "test"}')
|
||||
test_file = tmpdir_path / "test.test.js"
|
||||
test_file.write_text("// test")
|
||||
|
||||
test_paths = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
original_file_path=test_file,
|
||||
instrumented_behavior_file_path=test_file,
|
||||
benchmarking_file_path=test_file,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result_file, result, cov, _ = run_mocha_behavioral_tests(
|
||||
test_paths=test_paths,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
candidate_index=3,
|
||||
)
|
||||
|
||||
# Verify env vars were passed
|
||||
call_kwargs = mock_run.call_args
|
||||
env = call_kwargs.kwargs.get("env") or call_kwargs[1].get("env", {})
|
||||
assert env.get("CODEFLASH_MODE") == "behavior"
|
||||
assert env.get("CODEFLASH_TEST_ITERATION") == "3"
|
||||
assert env.get("CODEFLASH_RANDOM_SEED") == "42"
|
||||
|
||||
@patch("codeflash.languages.javascript.mocha_runner.subprocess.run")
|
||||
@patch("codeflash.languages.javascript.mocha_runner._ensure_runtime_files")
|
||||
def test_returns_none_coverage(self, mock_ensure, mock_run):
|
||||
from codeflash.languages.javascript.mocha_runner import run_mocha_behavioral_tests
|
||||
from codeflash.models.models import TestFile, TestFiles
|
||||
from codeflash.models.test_type import TestType
|
||||
|
||||
mocha_output = json.dumps(
|
||||
{"stats": {"tests": 0, "passes": 0, "failures": 0, "duration": 0}, "tests": [], "passes": [], "failures": [], "pending": []}
|
||||
)
|
||||
mock_run.return_value = MagicMock(returncode=0, stdout=mocha_output, stderr="", args=[])
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
(tmpdir_path / "package.json").write_text('{"name": "test"}')
|
||||
test_file = tmpdir_path / "test.test.js"
|
||||
test_file.write_text("// test")
|
||||
|
||||
test_paths = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
original_file_path=test_file,
|
||||
instrumented_behavior_file_path=test_file,
|
||||
benchmarking_file_path=test_file,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
_, _, coverage_path, _ = run_mocha_behavioral_tests(
|
||||
test_paths=test_paths,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
)
|
||||
assert coverage_path is None
|
||||
|
||||
|
||||
class TestRunMochaBenchmarkingTests:
|
||||
"""Tests for running Mocha benchmarking tests with mocked subprocess."""
|
||||
|
||||
@patch("codeflash.languages.javascript.mocha_runner.subprocess.run")
|
||||
@patch("codeflash.languages.javascript.mocha_runner._ensure_runtime_files")
|
||||
def test_sets_perf_env_vars(self, mock_ensure, mock_run):
|
||||
from codeflash.languages.javascript.mocha_runner import run_mocha_benchmarking_tests
|
||||
from codeflash.models.models import TestFile, TestFiles
|
||||
from codeflash.models.test_type import TestType
|
||||
|
||||
mocha_output = json.dumps(
|
||||
{"stats": {"tests": 1, "passes": 1, "failures": 0, "duration": 100}, "tests": [{"title": "perf", "fullTitle": "bench perf", "duration": 100, "err": {}}], "passes": [], "failures": [], "pending": []}
|
||||
)
|
||||
mock_run.return_value = MagicMock(returncode=0, stdout=mocha_output, stderr="", args=[])
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
(tmpdir_path / "package.json").write_text('{"name": "test"}')
|
||||
test_file = tmpdir_path / "perf.test.js"
|
||||
test_file.write_text("// perf test")
|
||||
|
||||
test_paths = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
original_file_path=test_file,
|
||||
instrumented_behavior_file_path=test_file,
|
||||
benchmarking_file_path=test_file,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
run_mocha_benchmarking_tests(
|
||||
test_paths=test_paths,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
min_loops=3,
|
||||
max_loops=50,
|
||||
target_duration_ms=5000,
|
||||
stability_check=False,
|
||||
)
|
||||
|
||||
call_kwargs = mock_run.call_args
|
||||
env = call_kwargs.kwargs.get("env") or call_kwargs[1].get("env", {})
|
||||
assert env.get("CODEFLASH_MODE") == "performance"
|
||||
assert env.get("CODEFLASH_PERF_LOOP_COUNT") == "50"
|
||||
assert env.get("CODEFLASH_PERF_MIN_LOOPS") == "3"
|
||||
assert env.get("CODEFLASH_PERF_TARGET_DURATION_MS") == "5000"
|
||||
assert env.get("CODEFLASH_PERF_STABILITY_CHECK") == "false"
|
||||
|
||||
|
||||
class TestRunMochaLineProfileTests:
|
||||
"""Tests for running Mocha line profile tests with mocked subprocess."""
|
||||
|
||||
@patch("codeflash.languages.javascript.mocha_runner.subprocess.run")
|
||||
@patch("codeflash.languages.javascript.mocha_runner._ensure_runtime_files")
|
||||
def test_sets_line_profile_env_vars(self, mock_ensure, mock_run):
|
||||
from codeflash.languages.javascript.mocha_runner import run_mocha_line_profile_tests
|
||||
from codeflash.models.models import TestFile, TestFiles
|
||||
from codeflash.models.test_type import TestType
|
||||
|
||||
mocha_output = json.dumps(
|
||||
{"stats": {"tests": 0, "passes": 0, "failures": 0, "duration": 0}, "tests": [], "passes": [], "failures": [], "pending": []}
|
||||
)
|
||||
mock_run.return_value = MagicMock(returncode=0, stdout=mocha_output, stderr="", args=[])
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
(tmpdir_path / "package.json").write_text('{"name": "test"}')
|
||||
test_file = tmpdir_path / "test.test.js"
|
||||
test_file.write_text("// test")
|
||||
profile_output = tmpdir_path / "profile.json"
|
||||
|
||||
test_paths = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
original_file_path=test_file,
|
||||
instrumented_behavior_file_path=test_file,
|
||||
benchmarking_file_path=test_file,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
run_mocha_line_profile_tests(
|
||||
test_paths=test_paths,
|
||||
test_env={},
|
||||
cwd=tmpdir_path,
|
||||
line_profile_output_file=profile_output,
|
||||
)
|
||||
|
||||
call_kwargs = mock_run.call_args
|
||||
env = call_kwargs.kwargs.get("env") or call_kwargs[1].get("env", {})
|
||||
assert env.get("CODEFLASH_MODE") == "line_profile"
|
||||
assert env.get("CODEFLASH_LINE_PROFILE_OUTPUT") == str(profile_output)
|
||||
|
|
@ -82,9 +82,9 @@ from pathlib import Path
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
from codeflash.languages.registry import get_language_support
|
||||
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
|
|
@ -110,7 +110,7 @@ def test_js_replcement() -> None:
|
|||
original_helper = helper_file.read_text("utf-8")
|
||||
|
||||
js_support = get_language_support("javascript")
|
||||
functions = js_support.discover_functions(main_file)
|
||||
functions = js_support.discover_functions(main_file.read_text(encoding="utf-8"), main_file)
|
||||
target = None
|
||||
for func in functions:
|
||||
if func.function_name == "calculateStats":
|
||||
|
|
@ -135,7 +135,7 @@ def test_js_replcement() -> None:
|
|||
project_root_path=root_dir,
|
||||
pytest_cmd="jest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(
|
||||
func_optimizer = JavaScriptFunctionOptimizer(
|
||||
function_to_optimize=func, test_cfg=test_config, aiservice_client=MagicMock()
|
||||
)
|
||||
result = func_optimizer.get_code_optimization_context()
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ def add(a, b):
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = python_support.discover_functions(Path(f.name))
|
||||
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "add"
|
||||
|
|
@ -70,7 +70,7 @@ def multiply(a, b):
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = python_support.discover_functions(Path(f.name))
|
||||
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
assert len(functions) == 3
|
||||
names = {func.function_name for func in functions}
|
||||
|
|
@ -88,7 +88,7 @@ def without_return():
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = python_support.discover_functions(Path(f.name))
|
||||
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
# Only the function with return should be discovered
|
||||
assert len(functions) == 1
|
||||
|
|
@ -107,7 +107,7 @@ class Calculator:
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = python_support.discover_functions(Path(f.name))
|
||||
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
assert len(functions) == 2
|
||||
for func in functions:
|
||||
|
|
@ -126,7 +126,7 @@ def sync_function():
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = python_support.discover_functions(Path(f.name))
|
||||
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
assert len(functions) == 2
|
||||
|
||||
|
|
@ -137,7 +137,7 @@ def sync_function():
|
|||
assert sync_func.is_async is False
|
||||
|
||||
def test_discover_nested_functions(self, python_support):
|
||||
"""Test discovering nested functions."""
|
||||
"""Test that nested functions are excluded — only top-level and class-level functions are discovered."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".py", mode="w", delete=False) as f:
|
||||
f.write("""
|
||||
def outer():
|
||||
|
|
@ -147,18 +147,11 @@ def outer():
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = python_support.discover_functions(Path(f.name))
|
||||
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
# Both outer and inner should be discovered
|
||||
assert len(functions) == 2
|
||||
names = {func.function_name for func in functions}
|
||||
assert names == {"outer", "inner"}
|
||||
|
||||
# Inner should have outer as parent
|
||||
inner = next(f for f in functions if f.function_name == "inner")
|
||||
assert len(inner.parents) == 1
|
||||
assert inner.parents[0].name == "outer"
|
||||
assert inner.parents[0].type == "FunctionDef"
|
||||
# Only outer should be discovered; inner is nested and skipped
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "outer"
|
||||
|
||||
def test_discover_static_method(self, python_support):
|
||||
"""Test discovering static methods."""
|
||||
|
|
@ -171,7 +164,7 @@ class Utils:
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = python_support.discover_functions(Path(f.name))
|
||||
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "helper"
|
||||
|
|
@ -190,7 +183,9 @@ def sync_func():
|
|||
f.flush()
|
||||
|
||||
criteria = FunctionFilterCriteria(include_async=False)
|
||||
functions = python_support.discover_functions(Path(f.name), criteria)
|
||||
functions = python_support.discover_functions(
|
||||
Path(f.name).read_text(encoding="utf-8"), Path(f.name), criteria
|
||||
)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "sync_func"
|
||||
|
|
@ -209,7 +204,9 @@ class MyClass:
|
|||
f.flush()
|
||||
|
||||
criteria = FunctionFilterCriteria(include_methods=False)
|
||||
functions = python_support.discover_functions(Path(f.name), criteria)
|
||||
functions = python_support.discover_functions(
|
||||
Path(f.name).read_text(encoding="utf-8"), Path(f.name), criteria
|
||||
)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "standalone"
|
||||
|
|
@ -227,7 +224,7 @@ def func2():
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = python_support.discover_functions(Path(f.name))
|
||||
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
func1 = next(f for f in functions if f.function_name == "func1")
|
||||
func2 = next(f for f in functions if f.function_name == "func2")
|
||||
|
|
@ -237,18 +234,20 @@ def func2():
|
|||
assert func2.starting_line == 4
|
||||
assert func2.ending_line == 7
|
||||
|
||||
def test_discover_invalid_file_returns_empty(self, python_support):
|
||||
"""Test that invalid Python file returns empty list."""
|
||||
def test_discover_invalid_file_raises(self, python_support):
|
||||
"""Test that invalid Python file raises a parse error."""
|
||||
from libcst._exceptions import ParserSyntaxError
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".py", mode="w", delete=False) as f:
|
||||
f.write("this is not valid python {{{{")
|
||||
f.flush()
|
||||
|
||||
functions = python_support.discover_functions(Path(f.name))
|
||||
assert functions == []
|
||||
with pytest.raises(ParserSyntaxError):
|
||||
python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
def test_discover_nonexistent_file_returns_empty(self, python_support):
|
||||
"""Test that nonexistent file returns empty list."""
|
||||
functions = python_support.discover_functions(Path("/nonexistent/file.py"))
|
||||
def test_discover_empty_source_returns_empty(self, python_support):
|
||||
"""Test that empty source returns empty list."""
|
||||
functions = python_support.discover_functions("", Path("/nonexistent/file.py"))
|
||||
assert functions == []
|
||||
|
||||
|
||||
|
|
@ -500,7 +499,7 @@ class TestIntegration:
|
|||
file_path = Path(f.name)
|
||||
|
||||
# Discover
|
||||
functions = python_support.discover_functions(file_path)
|
||||
functions = python_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
assert len(functions) == 1
|
||||
func = functions[0]
|
||||
assert func.function_name == "fibonacci"
|
||||
|
|
@ -541,7 +540,7 @@ def standalone():
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = python_support.discover_functions(file_path)
|
||||
functions = python_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
|
||||
# Should find 4 functions
|
||||
assert len(functions) == 4
|
||||
|
|
@ -584,12 +583,7 @@ def process(value):
|
|||
return helper_function(value) + 1
|
||||
""")
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="helper_function",
|
||||
file_path=source_file,
|
||||
starting_line=1,
|
||||
ending_line=2,
|
||||
)
|
||||
func = FunctionToOptimize(function_name="helper_function", file_path=source_file, starting_line=1, ending_line=2)
|
||||
|
||||
refs = python_support.find_references(func, project_root=tmp_path)
|
||||
|
||||
|
|
@ -646,12 +640,7 @@ def test_find_references_no_references(python_support, tmp_path):
|
|||
return 42
|
||||
""")
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="isolated_function",
|
||||
file_path=source_file,
|
||||
starting_line=1,
|
||||
ending_line=2,
|
||||
)
|
||||
func = FunctionToOptimize(function_name="isolated_function", file_path=source_file, starting_line=1, ending_line=2)
|
||||
|
||||
refs = python_support.find_references(func, project_root=tmp_path)
|
||||
|
||||
|
|
@ -668,10 +657,7 @@ def test_find_references_nonexistent_function(python_support, tmp_path):
|
|||
""")
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="nonexistent_function",
|
||||
file_path=source_file,
|
||||
starting_line=1,
|
||||
ending_line=2,
|
||||
function_name="nonexistent_function", file_path=source_file, starting_line=1, ending_line=2
|
||||
)
|
||||
|
||||
refs = python_support.find_references(func, project_root=tmp_path)
|
||||
|
|
|
|||
|
|
@ -821,3 +821,153 @@ export default curry(traverseEntity);"""
|
|||
# createVisitorUtils is NOT wrapped, so not exported via default
|
||||
is_utils_exported, _ = ts_analyzer.is_function_exported(code, "createVisitorUtils")
|
||||
assert is_utils_exported is False
|
||||
|
||||
|
||||
class TestNamedExportConstArrow:
|
||||
"""Tests for const arrow functions exported via named export clause.
|
||||
|
||||
Pattern: const joinBy = () => {}; export { joinBy };
|
||||
This is common in TypeScript codebases like Strapi.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def ts_analyzer(self):
|
||||
return TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT)
|
||||
|
||||
@pytest.fixture
|
||||
def js_analyzer(self):
|
||||
return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
|
||||
|
||||
def test_named_export_const_arrow(self, ts_analyzer):
|
||||
"""const arrow function exported via separate export { } clause."""
|
||||
code = """const joinBy = (arr: string[], separator: string) => {
|
||||
return arr.join(separator);
|
||||
};
|
||||
|
||||
export { joinBy };"""
|
||||
|
||||
functions = ts_analyzer.find_functions(code)
|
||||
joinBy = next((f for f in functions if f.name == "joinBy"), None)
|
||||
assert joinBy is not None
|
||||
assert joinBy.is_exported is True
|
||||
|
||||
def test_named_export_alias(self, ts_analyzer):
|
||||
"""export { foo as bar } — foo should be marked as exported."""
|
||||
code = """const foo = (x: number) => {
|
||||
return x * 2;
|
||||
};
|
||||
|
||||
export { foo as bar };"""
|
||||
|
||||
functions = ts_analyzer.find_functions(code)
|
||||
foo = next((f for f in functions if f.name == "foo"), None)
|
||||
assert foo is not None
|
||||
assert foo.is_exported is True
|
||||
|
||||
def test_named_export_multiple(self, ts_analyzer):
|
||||
"""Multiple functions in a single export clause."""
|
||||
code = """const a = () => { return 1; };
|
||||
const b = () => { return 2; };
|
||||
const c = () => { return 3; };
|
||||
|
||||
export { a, b };"""
|
||||
|
||||
functions = ts_analyzer.find_functions(code)
|
||||
a = next((f for f in functions if f.name == "a"), None)
|
||||
b = next((f for f in functions if f.name == "b"), None)
|
||||
c = next((f for f in functions if f.name == "c"), None)
|
||||
assert a is not None and a.is_exported is True
|
||||
assert b is not None and b.is_exported is True
|
||||
assert c is not None and c.is_exported is False
|
||||
|
||||
def test_named_export_function_declaration(self, js_analyzer):
|
||||
"""Regular function declarations exported via export { }."""
|
||||
code = """function processData(data) {
|
||||
return data;
|
||||
}
|
||||
|
||||
export { processData };"""
|
||||
|
||||
functions = js_analyzer.find_functions(code)
|
||||
f = next((f for f in functions if f.name == "processData"), None)
|
||||
assert f is not None
|
||||
assert f.is_exported is True
|
||||
|
||||
def test_is_function_exported_with_named_export(self, ts_analyzer):
|
||||
"""is_function_exported should detect named export clause."""
|
||||
code = """const joinBy = (arr: string[], separator: string) => {
|
||||
return arr.join(separator);
|
||||
};
|
||||
|
||||
export { joinBy };"""
|
||||
|
||||
is_exported, name = ts_analyzer.is_function_exported(code, "joinBy")
|
||||
assert is_exported is True
|
||||
|
||||
|
||||
class TestCjsReexportObjectMethods:
|
||||
"""Tests for CJS re-export of object containing methods.
|
||||
|
||||
Pattern: const utils = { match() {} }; module.exports = utils;
|
||||
This is common in Node.js libraries like Moleculer.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def js_analyzer(self):
|
||||
return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
|
||||
|
||||
def test_cjs_reexport_object_methods(self, js_analyzer):
|
||||
"""module.exports = varName where varName is object with methods."""
|
||||
code = """const utils = {
|
||||
match(text, pattern) {
|
||||
return text.match(pattern);
|
||||
},
|
||||
slugify(str) {
|
||||
return str.toLowerCase();
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = utils;"""
|
||||
|
||||
is_exported, name = js_analyzer.is_function_exported(code, "match")
|
||||
assert is_exported is True
|
||||
|
||||
is_exported2, _ = js_analyzer.is_function_exported(code, "slugify")
|
||||
assert is_exported2 is True
|
||||
|
||||
def test_cjs_reexport_shorthand_props(self, js_analyzer):
|
||||
"""module.exports = varName where object has shorthand properties."""
|
||||
code = """function match(text, pattern) {
|
||||
return text.match(pattern);
|
||||
}
|
||||
|
||||
const utils = { match };
|
||||
module.exports = utils;"""
|
||||
|
||||
is_exported, _ = js_analyzer.is_function_exported(code, "match")
|
||||
assert is_exported is True
|
||||
|
||||
def test_cjs_reexport_pair_props(self, js_analyzer):
|
||||
"""module.exports = varName where object has key: value pairs."""
|
||||
code = """function myMatch(text, pattern) {
|
||||
return text.match(pattern);
|
||||
}
|
||||
|
||||
const utils = { match: myMatch };
|
||||
module.exports = utils;"""
|
||||
|
||||
is_exported, _ = js_analyzer.is_function_exported(code, "match")
|
||||
assert is_exported is True
|
||||
|
||||
def test_cjs_reexport_nonexistent_prop(self, js_analyzer):
|
||||
"""A function not in the re-exported object should not be exported."""
|
||||
code = """function helper() { return 1; }
|
||||
|
||||
const utils = {
|
||||
match(text) { return text; }
|
||||
};
|
||||
|
||||
module.exports = utils;"""
|
||||
|
||||
is_exported, _ = js_analyzer.is_function_exported(code, "helper")
|
||||
assert is_exported is False
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.base import FunctionInfo, Language, ParentInfo
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.javascript.support import TypeScriptSupport
|
||||
|
||||
|
||||
|
|
@ -126,14 +126,13 @@ export function add(a: number, b: number): number {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "add"
|
||||
|
||||
# Extract code context
|
||||
code_context = ts_support.extract_code_context(
|
||||
functions[0], file_path.parent, file_path.parent
|
||||
)
|
||||
code_context = ts_support.extract_code_context(functions[0], file_path.parent, file_path.parent)
|
||||
|
||||
# Verify extracted code is valid
|
||||
assert ts_support.validate_syntax(code_context.target_code) is True
|
||||
|
|
@ -164,14 +163,13 @@ export async function execMongoEval(queryExpression, appsmithMongoURI) {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "execMongoEval"
|
||||
|
||||
# Extract code context
|
||||
code_context = ts_support.extract_code_context(
|
||||
functions[0], file_path.parent, file_path.parent
|
||||
)
|
||||
code_context = ts_support.extract_code_context(functions[0], file_path.parent, file_path.parent)
|
||||
|
||||
# Verify extracted code is valid
|
||||
assert ts_support.validate_syntax(code_context.target_code) is True
|
||||
|
|
@ -215,14 +213,13 @@ export async function figureOutContentsPath(root: string): Promise<string> {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "figureOutContentsPath"
|
||||
|
||||
# Extract code context
|
||||
code_context = ts_support.extract_code_context(
|
||||
functions[0], file_path.parent, file_path.parent
|
||||
)
|
||||
code_context = ts_support.extract_code_context(functions[0], file_path.parent, file_path.parent)
|
||||
|
||||
# Verify extracted code is valid
|
||||
assert ts_support.validate_syntax(code_context.target_code) is True
|
||||
|
|
@ -246,12 +243,11 @@ export function readConfig(filename: string): string {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
assert len(functions) == 1
|
||||
|
||||
code_context = ts_support.extract_code_context(
|
||||
functions[0], file_path.parent, file_path.parent
|
||||
)
|
||||
code_context = ts_support.extract_code_context(functions[0], file_path.parent, file_path.parent)
|
||||
|
||||
# Check that imports are captured
|
||||
assert len(code_context.imports) > 0
|
||||
|
|
@ -278,12 +274,11 @@ export async function fetchWithRetry(url: string): Promise<any> {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
assert len(functions) == 1
|
||||
|
||||
code_context = ts_support.extract_code_context(
|
||||
functions[0], file_path.parent, file_path.parent
|
||||
)
|
||||
code_context = ts_support.extract_code_context(functions[0], file_path.parent, file_path.parent)
|
||||
|
||||
# Verify extracted code is valid
|
||||
assert ts_support.validate_syntax(code_context.target_code) is True
|
||||
|
|
@ -324,7 +319,8 @@ export class EndpointGroup {
|
|||
file_path = Path(f.name)
|
||||
|
||||
# Discover the 'post' method
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
post_method = None
|
||||
for func in functions:
|
||||
if func.function_name == "post":
|
||||
|
|
@ -334,9 +330,7 @@ export class EndpointGroup {
|
|||
assert post_method is not None, "post method should be discovered"
|
||||
|
||||
# Extract code context
|
||||
code_context = ts_support.extract_code_context(
|
||||
post_method, file_path.parent, file_path.parent
|
||||
)
|
||||
code_context = ts_support.extract_code_context(post_method, file_path.parent, file_path.parent)
|
||||
|
||||
# The extracted code should be syntactically valid
|
||||
assert ts_support.validate_syntax(code_context.target_code) is True, (
|
||||
|
|
@ -352,9 +346,7 @@ export class EndpointGroup {
|
|||
# Check that addEndpoint appears BEFORE the closing brace of the class
|
||||
class_end_index = code_context.target_code.rfind("}")
|
||||
add_endpoint_index = code_context.target_code.find("addEndpoint")
|
||||
assert add_endpoint_index < class_end_index, (
|
||||
"addEndpoint should be inside the class wrapper"
|
||||
)
|
||||
assert add_endpoint_index < class_end_index, "addEndpoint should be inside the class wrapper"
|
||||
|
||||
def test_multiple_private_helpers_inside_class(self, ts_support):
|
||||
"""Test that multiple private helpers are all included inside the class."""
|
||||
|
|
@ -386,7 +378,8 @@ export class Router {
|
|||
file_path = Path(f.name)
|
||||
|
||||
# Discover the 'addRoute' method
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
add_route_method = None
|
||||
for func in functions:
|
||||
if func.function_name == "addRoute":
|
||||
|
|
@ -395,9 +388,7 @@ export class Router {
|
|||
|
||||
assert add_route_method is not None
|
||||
|
||||
code_context = ts_support.extract_code_context(
|
||||
add_route_method, file_path.parent, file_path.parent
|
||||
)
|
||||
code_context = ts_support.extract_code_context(add_route_method, file_path.parent, file_path.parent)
|
||||
|
||||
# Should be valid TypeScript
|
||||
assert ts_support.validate_syntax(code_context.target_code) is True
|
||||
|
|
@ -424,7 +415,8 @@ export class Calculator {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
add_method = None
|
||||
for func in functions:
|
||||
if func.function_name == "add":
|
||||
|
|
@ -433,18 +425,14 @@ export class Calculator {
|
|||
|
||||
assert add_method is not None
|
||||
|
||||
code_context = ts_support.extract_code_context(
|
||||
add_method, file_path.parent, file_path.parent
|
||||
)
|
||||
code_context = ts_support.extract_code_context(add_method, file_path.parent, file_path.parent)
|
||||
|
||||
# 'compute' should be in target_code (inside class)
|
||||
assert "compute" in code_context.target_code
|
||||
|
||||
# 'compute' should NOT be in helper_functions (would be duplicate)
|
||||
helper_names = [h.name for h in code_context.helper_functions]
|
||||
assert "compute" not in helper_names, (
|
||||
"Same-class helper 'compute' should not be in helper_functions list"
|
||||
)
|
||||
assert "compute" not in helper_names, "Same-class helper 'compute' should not be in helper_functions list"
|
||||
|
||||
|
||||
class TestTypeScriptLanguageProperties:
|
||||
|
|
|
|||
|
|
@ -124,10 +124,8 @@ class TestTypeScriptCodeContext:
|
|||
"""Test extracting code context for a TypeScript function."""
|
||||
skip_if_ts_not_supported()
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
|
||||
|
||||
lang_current._current_language = Language.TYPESCRIPT
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
|
||||
fib_file = ts_project_dir / "fibonacci.ts"
|
||||
if not fib_file.exists():
|
||||
|
|
@ -139,7 +137,11 @@ class TestTypeScriptCodeContext:
|
|||
fib_func = next((f for f in func_list if f.function_name == "fibonacci"), None)
|
||||
assert fib_func is not None
|
||||
|
||||
context = get_code_optimization_context(fib_func, ts_project_dir)
|
||||
ts_support = get_language_support(Language.TYPESCRIPT)
|
||||
code_context = ts_support.extract_code_context(fib_func, ts_project_dir, ts_project_dir)
|
||||
context = JavaScriptFunctionOptimizer._build_optimization_context(
|
||||
code_context, fib_file, "typescript", ts_project_dir
|
||||
)
|
||||
|
||||
assert context.read_writable_code is not None
|
||||
# Critical: language should be "typescript", not "javascript"
|
||||
|
|
|
|||
|
|
@ -118,11 +118,9 @@ class TestVitestCodeContext:
|
|||
"""Test extracting code context for a TypeScript function."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
|
||||
|
||||
lang_current._current_language = Language.TYPESCRIPT
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
|
||||
fib_file = vitest_project_dir / "fibonacci.ts"
|
||||
if not fib_file.exists():
|
||||
|
|
@ -134,7 +132,11 @@ class TestVitestCodeContext:
|
|||
fib_func = next((f for f in func_list if f.function_name == "fibonacci"), None)
|
||||
assert fib_func is not None
|
||||
|
||||
context = get_code_optimization_context(fib_func, vitest_project_dir)
|
||||
ts_support = get_language_support(Language.TYPESCRIPT)
|
||||
code_context = ts_support.extract_code_context(fib_func, vitest_project_dir, vitest_project_dir)
|
||||
context = JavaScriptFunctionOptimizer._build_optimization_context(
|
||||
code_context, fib_file, "typescript", vitest_project_dir
|
||||
)
|
||||
|
||||
assert context.read_writable_code is not None
|
||||
assert context.read_writable_code.language == "typescript"
|
||||
|
|
|
|||
|
|
@ -1,10 +1,18 @@
|
|||
import os
|
||||
import sys
|
||||
import types
|
||||
from typing import NoReturn
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from _pytest.config import Config
|
||||
|
||||
from codeflash.verification.pytest_plugin import PytestLoops
|
||||
from codeflash.verification.pytest_plugin import (
|
||||
InvalidTimeParameterError,
|
||||
PytestLoops,
|
||||
get_runtime_from_stdout,
|
||||
should_stop,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -15,39 +23,301 @@ def pytest_loops_instance(pytestconfig: Config) -> PytestLoops:
|
|||
@pytest.fixture
|
||||
def mock_item() -> type:
|
||||
class MockItem:
|
||||
def __init__(self, function: types.FunctionType) -> None:
|
||||
def __init__(self, function: types.FunctionType, name: str = "test_func", cls: type = None, module: types.ModuleType = None) -> None:
|
||||
self.function = function
|
||||
self.name = name
|
||||
self.cls = cls
|
||||
self.module = module
|
||||
|
||||
return MockItem
|
||||
|
||||
|
||||
def create_mock_module(module_name: str, source_code: str) -> types.ModuleType:
|
||||
def create_mock_module(module_name: str, source_code: str, register: bool = False) -> types.ModuleType:
|
||||
module = types.ModuleType(module_name)
|
||||
exec(source_code, module.__dict__) # noqa: S102
|
||||
if register:
|
||||
sys.modules[module_name] = module
|
||||
return module
|
||||
|
||||
|
||||
def test_clear_lru_caches_function(pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
source_code = """
|
||||
def mock_session(**kwargs):
|
||||
"""Create a mock session with config options."""
|
||||
defaults = {
|
||||
"codeflash_hours": 0,
|
||||
"codeflash_minutes": 0,
|
||||
"codeflash_seconds": 10,
|
||||
"codeflash_delay": 0.0,
|
||||
"codeflash_loops": 1,
|
||||
"codeflash_min_loops": 1,
|
||||
"codeflash_max_loops": 100_000,
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
|
||||
class Option:
|
||||
pass
|
||||
|
||||
option = Option()
|
||||
for k, v in defaults.items():
|
||||
setattr(option, k, v)
|
||||
|
||||
class MockConfig:
|
||||
pass
|
||||
|
||||
config = MockConfig()
|
||||
config.option = option
|
||||
|
||||
class MockSession:
|
||||
pass
|
||||
|
||||
session = MockSession()
|
||||
session.config = config
|
||||
return session
|
||||
|
||||
|
||||
# --- get_runtime_from_stdout ---
|
||||
|
||||
|
||||
class TestGetRuntimeFromStdout:
|
||||
def test_valid_payload(self) -> None:
|
||||
assert get_runtime_from_stdout("!######test_func:12345######!") == 12345
|
||||
|
||||
def test_valid_payload_with_surrounding_text(self) -> None:
|
||||
assert get_runtime_from_stdout("some output\n!######mod.func:99999######!\nmore output") == 99999
|
||||
|
||||
def test_empty_string(self) -> None:
|
||||
assert get_runtime_from_stdout("") is None
|
||||
|
||||
def test_no_markers(self) -> None:
|
||||
assert get_runtime_from_stdout("just some output") is None
|
||||
|
||||
def test_missing_end_marker(self) -> None:
|
||||
assert get_runtime_from_stdout("!######test:123") is None
|
||||
|
||||
def test_missing_start_marker(self) -> None:
|
||||
assert get_runtime_from_stdout("test:123######!") is None
|
||||
|
||||
def test_no_colon_in_payload(self) -> None:
|
||||
assert get_runtime_from_stdout("!######nocolon######!") is None
|
||||
|
||||
def test_non_integer_value(self) -> None:
|
||||
assert get_runtime_from_stdout("!######test:notanumber######!") is None
|
||||
|
||||
def test_multiple_markers_uses_last(self) -> None:
|
||||
stdout = "!######first:111######! middle !######second:222######!"
|
||||
assert get_runtime_from_stdout(stdout) == 222
|
||||
|
||||
|
||||
# --- should_stop ---
|
||||
|
||||
|
||||
class TestShouldStop:
|
||||
def test_not_enough_data_for_window(self) -> None:
|
||||
assert should_stop([100, 100], window=5, min_window_size=3) is False
|
||||
|
||||
def test_below_min_window_size(self) -> None:
|
||||
assert should_stop([100, 100], window=2, min_window_size=5) is False
|
||||
|
||||
def test_stable_runtimes_stops(self) -> None:
|
||||
runtimes = [1000000] * 10
|
||||
assert should_stop(runtimes, window=5, min_window_size=3, center_rel_tol=0.01, spread_rel_tol=0.01) is True
|
||||
|
||||
def test_unstable_runtimes_continues(self) -> None:
|
||||
runtimes = [100, 200, 100, 200, 100]
|
||||
assert should_stop(runtimes, window=5, min_window_size=3, center_rel_tol=0.01, spread_rel_tol=0.01) is False
|
||||
|
||||
def test_zero_runtimes_raises(self) -> None:
|
||||
# All-zero runtimes cause ZeroDivisionError in median check.
|
||||
# In practice the caller guards with best_runtime_until_now > 0.
|
||||
runtimes = [0, 0, 0, 0, 0]
|
||||
with pytest.raises(ZeroDivisionError):
|
||||
should_stop(runtimes, window=5, min_window_size=3)
|
||||
|
||||
def test_even_window_median(self) -> None:
|
||||
# Even window: median is average of two middle values
|
||||
runtimes = [1000, 1000, 1001, 1001]
|
||||
assert should_stop(runtimes, window=4, min_window_size=2, center_rel_tol=0.01, spread_rel_tol=0.01) is True
|
||||
|
||||
def test_centered_but_spread_too_large(self) -> None:
|
||||
# All close to median but spread exceeds tolerance
|
||||
runtimes = [1000, 1050, 1000, 1050, 1000]
|
||||
assert should_stop(runtimes, window=5, min_window_size=3, center_rel_tol=0.1, spread_rel_tol=0.001) is False
|
||||
|
||||
|
||||
# --- _set_nodeid ---
|
||||
|
||||
|
||||
class TestSetNodeid:
|
||||
def test_appends_count_to_plain_nodeid(self, pytest_loops_instance: PytestLoops) -> None:
|
||||
result = pytest_loops_instance._set_nodeid("test_module.py::test_func", 3) # noqa: SLF001
|
||||
assert result == "test_module.py::test_func[ 3 ]"
|
||||
assert os.environ["CODEFLASH_LOOP_INDEX"] == "3"
|
||||
|
||||
def test_replaces_existing_count(self, pytest_loops_instance: PytestLoops) -> None:
|
||||
result = pytest_loops_instance._set_nodeid("test_module.py::test_func[ 1 ]", 5) # noqa: SLF001
|
||||
assert result == "test_module.py::test_func[ 5 ]"
|
||||
|
||||
def test_replaces_only_loop_pattern(self, pytest_loops_instance: PytestLoops) -> None:
|
||||
# Parametrize brackets like [param0] should not be replaced
|
||||
result = pytest_loops_instance._set_nodeid("test_mod.py::test_func[param0]", 2) # noqa: SLF001
|
||||
assert result == "test_mod.py::test_func[param0][ 2 ]"
|
||||
|
||||
|
||||
# --- _get_total_time ---
|
||||
|
||||
|
||||
class TestGetTotalTime:
|
||||
def test_seconds_only(self, pytest_loops_instance: PytestLoops) -> None:
|
||||
session = mock_session(codeflash_seconds=30)
|
||||
assert pytest_loops_instance._get_total_time(session) == 30 # noqa: SLF001
|
||||
|
||||
def test_mixed_units(self, pytest_loops_instance: PytestLoops) -> None:
|
||||
session = mock_session(codeflash_hours=1, codeflash_minutes=30, codeflash_seconds=45)
|
||||
assert pytest_loops_instance._get_total_time(session) == 3600 + 1800 + 45 # noqa: SLF001
|
||||
|
||||
def test_zero_time_is_valid(self, pytest_loops_instance: PytestLoops) -> None:
|
||||
session = mock_session(codeflash_hours=0, codeflash_minutes=0, codeflash_seconds=0)
|
||||
assert pytest_loops_instance._get_total_time(session) == 0 # noqa: SLF001
|
||||
|
||||
def test_negative_time_raises(self, pytest_loops_instance: PytestLoops) -> None:
|
||||
session = mock_session(codeflash_hours=0, codeflash_minutes=0, codeflash_seconds=-1)
|
||||
with pytest.raises(InvalidTimeParameterError):
|
||||
pytest_loops_instance._get_total_time(session) # noqa: SLF001
|
||||
|
||||
|
||||
# --- _timed_out ---
|
||||
|
||||
|
||||
class TestTimedOut:
|
||||
def test_exceeds_max_loops(self, pytest_loops_instance: PytestLoops) -> None:
|
||||
session = mock_session(codeflash_max_loops=10, codeflash_min_loops=1, codeflash_seconds=9999)
|
||||
assert pytest_loops_instance._timed_out(session, start_time=0, count=10) is True # noqa: SLF001
|
||||
|
||||
def test_below_min_loops_never_times_out(self, pytest_loops_instance: PytestLoops) -> None:
|
||||
session = mock_session(codeflash_max_loops=100_000, codeflash_min_loops=50, codeflash_seconds=0)
|
||||
# Even with 0 seconds budget, count < min_loops means not timed out
|
||||
assert pytest_loops_instance._timed_out(session, start_time=0, count=5) is False # noqa: SLF001
|
||||
|
||||
def test_above_min_loops_and_time_exceeded(self, pytest_loops_instance: PytestLoops) -> None:
|
||||
session = mock_session(codeflash_max_loops=100_000, codeflash_min_loops=1, codeflash_seconds=1)
|
||||
# start_time far in the past → time exceeded
|
||||
assert pytest_loops_instance._timed_out(session, start_time=0, count=2) is True # noqa: SLF001
|
||||
|
||||
|
||||
# --- _get_delay_time ---
|
||||
|
||||
|
||||
class TestGetDelayTime:
|
||||
def test_returns_configured_delay(self, pytest_loops_instance: PytestLoops) -> None:
|
||||
session = mock_session(codeflash_delay=0.5)
|
||||
assert pytest_loops_instance._get_delay_time(session) == 0.5 # noqa: SLF001
|
||||
|
||||
|
||||
# --- pytest_runtest_logreport ---
|
||||
|
||||
|
||||
class TestRunTestLogReport:
|
||||
def test_skipped_when_stability_check_disabled(self, pytestconfig: Config) -> None:
|
||||
instance = PytestLoops(pytestconfig)
|
||||
instance.enable_stability_check = False
|
||||
|
||||
class MockReport:
|
||||
when = "call"
|
||||
passed = True
|
||||
capstdout = "!######func:12345######!"
|
||||
nodeid = "test::func"
|
||||
|
||||
instance.pytest_runtest_logreport(MockReport())
|
||||
assert instance.runtime_data_by_test_case == {}
|
||||
|
||||
def test_records_runtime_on_passed_call(self, pytestconfig: Config) -> None:
|
||||
instance = PytestLoops(pytestconfig)
|
||||
instance.enable_stability_check = True
|
||||
|
||||
class MockReport:
|
||||
when = "call"
|
||||
passed = True
|
||||
capstdout = "!######func:12345######!"
|
||||
nodeid = "test::func [ 1 ]"
|
||||
|
||||
instance.pytest_runtest_logreport(MockReport())
|
||||
assert "test::func" in instance.runtime_data_by_test_case
|
||||
assert instance.runtime_data_by_test_case["test::func"] == [12345]
|
||||
|
||||
def test_ignores_non_call_phase(self, pytestconfig: Config) -> None:
|
||||
instance = PytestLoops(pytestconfig)
|
||||
instance.enable_stability_check = True
|
||||
|
||||
class MockReport:
|
||||
when = "setup"
|
||||
passed = True
|
||||
capstdout = "!######func:12345######!"
|
||||
nodeid = "test::func"
|
||||
|
||||
instance.pytest_runtest_logreport(MockReport())
|
||||
assert instance.runtime_data_by_test_case == {}
|
||||
|
||||
|
||||
# --- pytest_runtest_setup / teardown ---
|
||||
|
||||
|
||||
class TestRunTestSetupTeardown:
|
||||
def test_setup_sets_env_vars(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
module = types.ModuleType("my_test_module")
|
||||
|
||||
class MyTestClass:
|
||||
pass
|
||||
|
||||
item = mock_item(lambda: None, name="test_something[param1]", cls=MyTestClass, module=module)
|
||||
pytest_loops_instance.pytest_runtest_setup(item)
|
||||
|
||||
assert os.environ["CODEFLASH_TEST_MODULE"] == "my_test_module"
|
||||
assert os.environ["CODEFLASH_TEST_CLASS"] == "MyTestClass"
|
||||
assert os.environ["CODEFLASH_TEST_FUNCTION"] == "test_something"
|
||||
|
||||
def test_setup_no_class(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
module = types.ModuleType("my_test_module")
|
||||
item = mock_item(lambda: None, name="test_plain", cls=None, module=module)
|
||||
pytest_loops_instance.pytest_runtest_setup(item)
|
||||
|
||||
assert os.environ["CODEFLASH_TEST_CLASS"] == ""
|
||||
|
||||
def test_teardown_clears_env_vars(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
os.environ["CODEFLASH_TEST_MODULE"] = "leftover"
|
||||
os.environ["CODEFLASH_TEST_CLASS"] = "leftover"
|
||||
os.environ["CODEFLASH_TEST_FUNCTION"] = "leftover"
|
||||
|
||||
item = mock_item(lambda: None)
|
||||
pytest_loops_instance.pytest_runtest_teardown(item)
|
||||
|
||||
assert "CODEFLASH_TEST_MODULE" not in os.environ
|
||||
assert "CODEFLASH_TEST_CLASS" not in os.environ
|
||||
assert "CODEFLASH_TEST_FUNCTION" not in os.environ
|
||||
|
||||
|
||||
# --- _clear_lru_caches ---
|
||||
|
||||
|
||||
class TestClearLruCaches:
|
||||
def test_clears_lru_cached_function(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
source_code = """
|
||||
import functools
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def my_func(x):
|
||||
return x * 2
|
||||
|
||||
my_func(10) # miss the cache
|
||||
my_func(10) # hit the cache
|
||||
my_func(10)
|
||||
my_func(10)
|
||||
"""
|
||||
mock_module = create_mock_module("test_module_func", source_code)
|
||||
item = mock_item(mock_module.my_func)
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
assert mock_module.my_func.cache_info().hits == 0
|
||||
assert mock_module.my_func.cache_info().misses == 0
|
||||
assert mock_module.my_func.cache_info().currsize == 0
|
||||
mock_module = create_mock_module("test_module_func", source_code)
|
||||
item = mock_item(mock_module.my_func)
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
assert mock_module.my_func.cache_info().hits == 0
|
||||
assert mock_module.my_func.cache_info().misses == 0
|
||||
assert mock_module.my_func.cache_info().currsize == 0
|
||||
|
||||
|
||||
def test_clear_lru_caches_class_method(pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
source_code = """
|
||||
def test_clears_class_method_cache(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
source_code = """
|
||||
import functools
|
||||
|
||||
class MyClass:
|
||||
|
|
@ -56,32 +326,137 @@ class MyClass:
|
|||
return x * 3
|
||||
|
||||
obj = MyClass()
|
||||
obj.my_method(5) # Pre-populate the cache
|
||||
obj.my_method(5) # Hit the cache
|
||||
obj.my_method(5)
|
||||
obj.my_method(5)
|
||||
# """
|
||||
mock_module = create_mock_module("test_module_class", source_code)
|
||||
item = mock_item(mock_module.MyClass.my_method)
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
assert mock_module.MyClass.my_method.cache_info().hits == 0
|
||||
assert mock_module.MyClass.my_method.cache_info().misses == 0
|
||||
assert mock_module.MyClass.my_method.cache_info().currsize == 0
|
||||
mock_module = create_mock_module("test_module_class", source_code)
|
||||
item = mock_item(mock_module.MyClass.my_method)
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
assert mock_module.MyClass.my_method.cache_info().hits == 0
|
||||
assert mock_module.MyClass.my_method.cache_info().misses == 0
|
||||
assert mock_module.MyClass.my_method.cache_info().currsize == 0
|
||||
|
||||
def test_handles_exception_in_cache_clear(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
class BrokenCache:
|
||||
def cache_clear(self) -> NoReturn:
|
||||
msg = "Cache clearing failed!"
|
||||
raise ValueError(msg)
|
||||
|
||||
def test_clear_lru_caches_exception_handling(pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
"""Test that exceptions during clearing are handled."""
|
||||
item = mock_item(BrokenCache())
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
|
||||
class BrokenCache:
|
||||
def cache_clear(self) -> NoReturn:
|
||||
msg = "Cache clearing failed!"
|
||||
raise ValueError(msg)
|
||||
def test_handles_no_cache(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
def no_cache_func(x: int) -> int:
|
||||
return x
|
||||
|
||||
item = mock_item(BrokenCache())
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
item = mock_item(no_cache_func)
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
|
||||
def test_clears_module_level_caches_via_sys_modules(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
module_name = "_cf_test_module_scan"
|
||||
source_code = """
|
||||
import functools
|
||||
|
||||
def test_clear_lru_caches_no_cache(pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
def no_cache_func(x: int) -> int:
|
||||
return x
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def cached_a(x):
|
||||
return x + 1
|
||||
|
||||
item = mock_item(no_cache_func)
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def cached_b(x):
|
||||
return x + 2
|
||||
|
||||
def plain_func(x):
|
||||
return x
|
||||
|
||||
cached_a(1)
|
||||
cached_a(1)
|
||||
cached_b(2)
|
||||
cached_b(2)
|
||||
"""
|
||||
mock_module = create_mock_module(module_name, source_code, register=True)
|
||||
try:
|
||||
item = mock_item(mock_module.plain_func)
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
|
||||
assert mock_module.cached_a.cache_info().currsize == 0
|
||||
assert mock_module.cached_b.cache_info().currsize == 0
|
||||
finally:
|
||||
sys.modules.pop(module_name, None)
|
||||
|
||||
def test_skips_protected_modules(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
module_name = "_cf_test_protected"
|
||||
source_code = """
|
||||
import functools
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def user_func(x):
|
||||
return x
|
||||
"""
|
||||
mock_module = create_mock_module(module_name, source_code, register=True)
|
||||
try:
|
||||
mock_module.os_exists = os.path.exists
|
||||
item = mock_item(mock_module.user_func)
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
finally:
|
||||
sys.modules.pop(module_name, None)
|
||||
|
||||
def test_caches_scan_result(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
module_name = "_cf_test_cache_reuse"
|
||||
source_code = """
|
||||
import functools
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def cached_fn(x):
|
||||
return x
|
||||
"""
|
||||
mock_module = create_mock_module(module_name, source_code, register=True)
|
||||
try:
|
||||
item = mock_item(mock_module.cached_fn)
|
||||
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
assert module_name in pytest_loops_instance._module_clearables # noqa: SLF001
|
||||
|
||||
mock_module.cached_fn(42)
|
||||
assert mock_module.cached_fn.cache_info().currsize == 1
|
||||
|
||||
with patch("codeflash.verification.pytest_plugin.inspect.getmembers") as mock_getmembers:
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
mock_getmembers.assert_not_called()
|
||||
|
||||
assert mock_module.cached_fn.cache_info().currsize == 0
|
||||
finally:
|
||||
sys.modules.pop(module_name, None)
|
||||
|
||||
def test_handles_wrapped_function(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
module_name = "_cf_test_wrapped"
|
||||
source_code = """
|
||||
import functools
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def inner(x):
|
||||
return x
|
||||
|
||||
def wrapper(x):
|
||||
return inner(x)
|
||||
|
||||
wrapper.__wrapped__ = inner
|
||||
wrapper.__module__ = __name__
|
||||
|
||||
inner(1)
|
||||
inner(1)
|
||||
"""
|
||||
mock_module = create_mock_module(module_name, source_code, register=True)
|
||||
try:
|
||||
item = mock_item(mock_module.wrapper)
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
assert mock_module.inner.cache_info().currsize == 0
|
||||
finally:
|
||||
sys.modules.pop(module_name, None)
|
||||
|
||||
def test_handles_function_without_module(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
def func() -> None:
|
||||
pass
|
||||
|
||||
func.__module__ = None # type: ignore[assignment]
|
||||
item = mock_item(func)
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
|
|
|
|||
|
|
@ -3,9 +3,9 @@ import tempfile
|
|||
from pathlib import Path
|
||||
|
||||
from codeflash.code_utils.code_utils import ImportErrorPattern
|
||||
from codeflash.languages import current_language_support
|
||||
from codeflash.models.models import TestFile, TestFiles, TestType
|
||||
from codeflash.verification.parse_test_output import parse_test_xml
|
||||
from codeflash.verification.test_runner import run_behavioral_tests
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
|
|
@ -48,8 +48,8 @@ class TestUnittestRunnerSorter(unittest.TestCase):
|
|||
test_files=[TestFile(instrumented_behavior_file_path=test_file_path, test_type=TestType.EXISTING_UNIT_TEST)]
|
||||
)
|
||||
test_file_path.write_text(code, encoding="utf-8")
|
||||
result_file, process, _, _ = run_behavioral_tests(
|
||||
test_files, test_framework=config.test_framework, cwd=Path(config.project_root_path), test_env=test_env
|
||||
result_file, process, _, _ = current_language_support().run_behavioral_tests(
|
||||
test_paths=test_files, test_env=test_env, cwd=Path(config.project_root_path)
|
||||
)
|
||||
results = parse_test_xml(result_file, test_files, config, process)
|
||||
assert results[0].did_pass, "Test did not pass as expected"
|
||||
|
|
@ -89,13 +89,8 @@ def test_sort():
|
|||
test_files=[TestFile(instrumented_behavior_file_path=test_file_path, test_type=TestType.EXISTING_UNIT_TEST)]
|
||||
)
|
||||
test_file_path.write_text(code, encoding="utf-8")
|
||||
result_file, process, _, _ = run_behavioral_tests(
|
||||
test_files,
|
||||
test_framework=config.test_framework,
|
||||
cwd=Path(config.project_root_path),
|
||||
test_env=test_env,
|
||||
pytest_timeout=1,
|
||||
pytest_target_runtime_seconds=1,
|
||||
result_file, process, _, _ = current_language_support().run_behavioral_tests(
|
||||
test_paths=test_files, test_env=test_env, cwd=Path(config.project_root_path), timeout=1
|
||||
)
|
||||
results = parse_test_xml(
|
||||
test_xml_file_path=result_file, test_files=test_files, test_config=config, run_result=process
|
||||
|
|
@ -136,13 +131,8 @@ def test_sort():
|
|||
test_files=[TestFile(instrumented_behavior_file_path=test_file_path, test_type=TestType.EXISTING_UNIT_TEST)]
|
||||
)
|
||||
test_file_path.write_text(code, encoding="utf-8")
|
||||
result_file, process, _, _ = run_behavioral_tests(
|
||||
test_files,
|
||||
test_framework=config.test_framework,
|
||||
cwd=Path(config.project_root_path),
|
||||
test_env=test_env,
|
||||
pytest_timeout=1,
|
||||
pytest_target_runtime_seconds=1,
|
||||
result_file, process, _, _ = current_language_support().run_behavioral_tests(
|
||||
test_paths=test_files, test_env=test_env, cwd=Path(config.project_root_path), timeout=1
|
||||
)
|
||||
results = parse_test_xml(
|
||||
test_xml_file_path=result_file, test_files=test_files, test_config=config, run_result=process
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@ from codeflash.languages.python.context.unused_definition_remover import (
|
|||
detect_unused_helper_functions,
|
||||
revert_unused_helper_functions,
|
||||
)
|
||||
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
|
||||
from codeflash.models.models import CodeStringsMarkdown
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
|
|
@ -83,7 +83,7 @@ def helper_function_2(x):
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -194,7 +194,7 @@ def helper_function_2(x):
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -269,7 +269,7 @@ def helper_function_2(x):
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -365,7 +365,7 @@ def entrypoint_function(n):
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -559,7 +559,7 @@ class Calculator:
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -710,7 +710,7 @@ class Processor:
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -895,7 +895,7 @@ class OuterClass:
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -1051,7 +1051,7 @@ def entrypoint_function(n):
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -1215,7 +1215,7 @@ def entrypoint_function(n):
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -1442,7 +1442,7 @@ class MathUtils:
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -1576,7 +1576,7 @@ async def async_entrypoint(n):
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -1664,7 +1664,7 @@ def sync_entrypoint(n):
|
|||
function_to_optimize = FunctionToOptimize(file_path=main_file, function_name="sync_entrypoint", parents=[])
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -1773,7 +1773,7 @@ async def mixed_entrypoint(n):
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -1874,7 +1874,7 @@ class AsyncProcessor:
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -1960,7 +1960,7 @@ async def async_entrypoint(n):
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -2039,7 +2039,7 @@ def gcd_recursive(a: int, b: int) -> int:
|
|||
function_to_optimize = FunctionToOptimize(file_path=main_file, function_name="gcd_recursive", parents=[])
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -2152,7 +2152,7 @@ async def async_entrypoint_with_generators(n):
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
|
|||
|
|
@ -61,9 +61,9 @@ def test_mirror_paths_for_worktree_mode(monkeypatch: pytest.MonkeyPatch):
|
|||
assert optimizer.args.test_project_root == worktree_dir
|
||||
assert optimizer.args.module_root == worktree_dir / "codeflash"
|
||||
# tests_root is configured as "codeflash" in pyproject.toml
|
||||
assert optimizer.args.tests_root == worktree_dir / "codeflash"
|
||||
assert optimizer.args.tests_root == worktree_dir / "tests"
|
||||
assert optimizer.args.file == worktree_dir / "codeflash/optimization/optimizer.py"
|
||||
|
||||
assert optimizer.test_cfg.tests_root == worktree_dir / "codeflash"
|
||||
assert optimizer.test_cfg.tests_root == worktree_dir / "tests"
|
||||
assert optimizer.test_cfg.project_root_path == worktree_dir # same as project_root
|
||||
assert optimizer.test_cfg.tests_project_rootdir == worktree_dir # same as test_project_root
|
||||
|
|
|
|||
Loading…
Reference in a new issue