test: sync test files from main (safe, main-only changes)

34 test files updated with main's refactored tests for new language
support protocol, JS/TS improvements, and code context extraction.
This commit is contained in:
Kevin Turcios 2026-03-02 15:25:50 -05:00
parent 2299d26ae5
commit 19bd6e4bad
34 changed files with 3105 additions and 1168 deletions

View file

@ -493,3 +493,37 @@ def my_function():
return helper
"""
assert result == expected_result
def test_module_input_preserves_comment_position_after_imports() -> None:
from codeflash.languages.python.context.code_context_extractor import parse_code_and_prune_cst
from codeflash.models.models import CodeContextType
src_code = """from __future__ import annotations
import re
# Comment about PATTERN.
PATTERN = re.compile(r"test")
def parse():
return PATTERN.findall("")
"""
pruned_module = parse_code_and_prune_cst(src_code, CodeContextType.READ_WRITABLE, {"parse"})
with tempfile.TemporaryDirectory() as tmpdir:
project_root = Path(tmpdir)
file_path = project_root / "mod.py"
file_path.write_text(src_code)
result = add_needed_imports_from_module(src_code, pruned_module, file_path, file_path, project_root)
expected = """from __future__ import annotations
import re
# Comment about PATTERN.
PATTERN = re.compile(r"test")
def parse():
return PATTERN.findall("")
"""
assert result == expected

File diff suppressed because it is too large Load diff

View file

@ -1,4 +1,4 @@
from codeflash.code_utils.deduplicate_code import are_codes_duplicate, normalize_code
from codeflash.languages.python.normalizer import normalize_python_code as normalize_code
def test_deduplicate1():
@ -23,7 +23,7 @@ def compute_sum(numbers):
"""
assert normalize_code(code1) == normalize_code(code2)
assert are_codes_duplicate(code1, code2)
assert normalize_code(code1) == normalize_code(code2)
# Example 3: Same function and parameter names, different local variables (should match)
code3 = """
@ -43,7 +43,7 @@ def calculate_sum(numbers):
"""
assert normalize_code(code3) == normalize_code(code4)
assert are_codes_duplicate(code3, code4)
assert normalize_code(code3) == normalize_code(code4)
# Example 4: Nested functions and classes (preserving names)
code5 = """

View file

@ -11,7 +11,6 @@ from codeflash.languages.python.static_analysis.code_extractor import delete___f
from codeflash.languages.python.static_analysis.code_replacer import (
AddRequestArgument,
AutouseFixtureModifier,
OptimFunctionCollector,
PytestMarkAdder,
is_zero_diff,
replace_functions_and_add_imports,
@ -19,7 +18,7 @@ from codeflash.languages.python.static_analysis.code_replacer import (
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent, FunctionSource
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
@ -55,7 +54,7 @@ def sorter(arr):
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
@ -808,6 +807,7 @@ def test_code_replacement10() -> None:
get_code_output = """# file: test_code_replacement.py
from __future__ import annotations
class HelperClass:
def __init__(self, name):
self.name = name
@ -834,7 +834,7 @@ class MainClass:
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config)
code_context = func_optimizer.get_code_optimization_context().unwrap()
assert code_context.testgen_context.flat.rstrip() == get_code_output.rstrip()
@ -1745,7 +1745,7 @@ class NewClass:
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
@ -1824,7 +1824,7 @@ a=2
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
@ -1904,7 +1904,7 @@ class NewClass:
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
@ -1983,7 +1983,7 @@ class NewClass:
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
@ -2063,7 +2063,7 @@ class NewClass:
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
@ -2153,7 +2153,7 @@ class NewClass:
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
@ -3453,7 +3453,7 @@ def hydrate_input_text_actions_with_field_names(
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
original_helper_code: dict[Path, str] = {}
@ -3476,142 +3476,6 @@ def hydrate_input_text_actions_with_field_names(
assert new_code == expected
# OptimFunctionCollector async function tests
def test_optim_function_collector_with_async_functions():
"""Test OptimFunctionCollector correctly collects async functions."""
import libcst as cst
source_code = """
def sync_function():
return "sync"
async def async_function():
return "async"
class TestClass:
def sync_method(self):
return "sync_method"
async def async_method(self):
return "async_method"
"""
tree = cst.parse_module(source_code)
collector = OptimFunctionCollector(
function_names={
(None, "sync_function"),
(None, "async_function"),
("TestClass", "sync_method"),
("TestClass", "async_method"),
},
preexisting_objects=None,
)
tree.visit(collector)
# Should collect both sync and async functions
assert len(collector.modified_functions) == 4
assert (None, "sync_function") in collector.modified_functions
assert (None, "async_function") in collector.modified_functions
assert ("TestClass", "sync_method") in collector.modified_functions
assert ("TestClass", "async_method") in collector.modified_functions
def test_optim_function_collector_new_async_functions():
"""Test OptimFunctionCollector identifies new async functions not in preexisting objects."""
import libcst as cst
source_code = """
def existing_function():
return "existing"
async def new_async_function():
return "new_async"
def new_sync_function():
return "new_sync"
class ExistingClass:
async def new_class_async_method(self):
return "new_class_async"
"""
# Only existing_function is in preexisting objects
preexisting_objects = {("existing_function", ())}
tree = cst.parse_module(source_code)
collector = OptimFunctionCollector(
function_names=set(), # Not looking for specific functions
preexisting_objects=preexisting_objects,
)
tree.visit(collector)
# Should identify new functions (both sync and async)
assert len(collector.new_functions) == 2
function_names = [func.name.value for func in collector.new_functions]
assert "new_async_function" in function_names
assert "new_sync_function" in function_names
# Should identify new class methods
assert "ExistingClass" in collector.new_class_functions
assert len(collector.new_class_functions["ExistingClass"]) == 1
assert collector.new_class_functions["ExistingClass"][0].name.value == "new_class_async_method"
def test_optim_function_collector_mixed_scenarios():
"""Test OptimFunctionCollector with complex mix of sync/async functions and classes."""
import libcst as cst
source_code = """
# Global functions
def global_sync():
pass
async def global_async():
pass
class ParentClass:
def __init__(self):
pass
def sync_method(self):
pass
async def async_method(self):
pass
class ChildClass:
async def child_async_method(self):
pass
def child_sync_method(self):
pass
"""
# Looking for specific functions
function_names = {
(None, "global_sync"),
(None, "global_async"),
("ParentClass", "sync_method"),
("ParentClass", "async_method"),
("ChildClass", "child_async_method"),
}
tree = cst.parse_module(source_code)
collector = OptimFunctionCollector(function_names=function_names, preexisting_objects=None)
tree.visit(collector)
# Should collect all specified functions (mix of sync and async)
assert len(collector.modified_functions) == 5
assert (None, "global_sync") in collector.modified_functions
assert (None, "global_async") in collector.modified_functions
assert ("ParentClass", "sync_method") in collector.modified_functions
assert ("ParentClass", "async_method") in collector.modified_functions
assert ("ChildClass", "child_async_method") in collector.modified_functions
# Should collect __init__ method
assert "ParentClass" in collector.modified_init_functions
def test_is_zero_diff_async_sleep():
original_code = """
import time

View file

@ -417,6 +417,312 @@ def test_standard_python_library_objects() -> None:
assert not comparator(id1, id3)
def test_itertools_count() -> None:
import itertools
# Equal: same start and step (default step=1)
assert comparator(itertools.count(0), itertools.count(0))
assert comparator(itertools.count(5), itertools.count(5))
assert comparator(itertools.count(0, 1), itertools.count(0, 1))
assert comparator(itertools.count(10, 3), itertools.count(10, 3))
# Equal: negative start and step
assert comparator(itertools.count(-5, -2), itertools.count(-5, -2))
# Equal: float start and step
assert comparator(itertools.count(0.5, 0.1), itertools.count(0.5, 0.1))
# Not equal: different start
assert not comparator(itertools.count(0), itertools.count(1))
assert not comparator(itertools.count(5), itertools.count(10))
# Not equal: different step
assert not comparator(itertools.count(0, 1), itertools.count(0, 2))
assert not comparator(itertools.count(0, 1), itertools.count(0, -1))
# Not equal: different type
assert not comparator(itertools.count(0), 0)
assert not comparator(itertools.count(0), [0, 1, 2])
# Equal after partial consumption (both advanced to the same state)
a = itertools.count(0)
b = itertools.count(0)
next(a)
next(b)
assert comparator(a, b)
# Not equal after different consumption
a = itertools.count(0)
b = itertools.count(0)
next(a)
assert not comparator(a, b)
# Works inside containers
assert comparator([itertools.count(0)], [itertools.count(0)])
assert comparator({"key": itertools.count(5, 2)}, {"key": itertools.count(5, 2)})
assert not comparator([itertools.count(0)], [itertools.count(1)])
def test_itertools_repeat() -> None:
import itertools
# Equal: infinite repeat
assert comparator(itertools.repeat(5), itertools.repeat(5))
assert comparator(itertools.repeat("hello"), itertools.repeat("hello"))
# Equal: bounded repeat
assert comparator(itertools.repeat(5, 3), itertools.repeat(5, 3))
assert comparator(itertools.repeat(None, 10), itertools.repeat(None, 10))
# Not equal: different value
assert not comparator(itertools.repeat(5), itertools.repeat(6))
assert not comparator(itertools.repeat(5, 3), itertools.repeat(6, 3))
# Not equal: different count
assert not comparator(itertools.repeat(5, 3), itertools.repeat(5, 4))
# Not equal: bounded vs infinite
assert not comparator(itertools.repeat(5), itertools.repeat(5, 3))
# Not equal: different type
assert not comparator(itertools.repeat(5), 5)
assert not comparator(itertools.repeat(5), [5])
# Equal after partial consumption
a = itertools.repeat(5, 5)
b = itertools.repeat(5, 5)
next(a)
next(b)
assert comparator(a, b)
# Not equal after different consumption
a = itertools.repeat(5, 5)
b = itertools.repeat(5, 5)
next(a)
assert not comparator(a, b)
# Works inside containers
assert comparator([itertools.repeat(5, 3)], [itertools.repeat(5, 3)])
assert not comparator([itertools.repeat(5, 3)], [itertools.repeat(5, 4)])
def test_itertools_cycle() -> None:
import itertools
# Equal: same sequence
assert comparator(itertools.cycle([1, 2, 3]), itertools.cycle([1, 2, 3]))
assert comparator(itertools.cycle("abc"), itertools.cycle("abc"))
# Not equal: different sequence
assert not comparator(itertools.cycle([1, 2, 3]), itertools.cycle([1, 2, 4]))
assert not comparator(itertools.cycle([1, 2, 3]), itertools.cycle([1, 2]))
# Not equal: different type
assert not comparator(itertools.cycle([1, 2, 3]), [1, 2, 3])
# Equal after same partial consumption
a = itertools.cycle([1, 2, 3])
b = itertools.cycle([1, 2, 3])
next(a)
next(b)
assert comparator(a, b)
# Not equal after different consumption
a = itertools.cycle([1, 2, 3])
b = itertools.cycle([1, 2, 3])
next(a)
assert not comparator(a, b)
# Equal after consuming a full cycle
a = itertools.cycle([1, 2, 3])
b = itertools.cycle([1, 2, 3])
for _ in range(3):
next(a)
next(b)
assert comparator(a, b)
# Equal at same position across different full-cycle counts
a = itertools.cycle([1, 2, 3])
b = itertools.cycle([1, 2, 3])
for _ in range(4):
next(a)
for _ in range(7):
next(b)
# Both at position 1 within the cycle (4%3 == 7%3 == 1)
assert comparator(a, b)
# Works inside containers
assert comparator([itertools.cycle([1, 2])], [itertools.cycle([1, 2])])
assert not comparator([itertools.cycle([1, 2])], [itertools.cycle([1, 3])])
def test_itertools_chain() -> None:
import itertools
assert comparator(itertools.chain([1, 2], [3, 4]), itertools.chain([1, 2], [3, 4]))
assert not comparator(itertools.chain([1, 2], [3, 4]), itertools.chain([1, 2], [3, 5]))
assert comparator(itertools.chain.from_iterable([[1, 2], [3]]), itertools.chain.from_iterable([[1, 2], [3]]))
assert comparator(itertools.chain(), itertools.chain())
assert not comparator(itertools.chain([1]), itertools.chain([1, 2]))
def test_itertools_islice() -> None:
import itertools
assert comparator(itertools.islice(range(10), 5), itertools.islice(range(10), 5))
assert not comparator(itertools.islice(range(10), 5), itertools.islice(range(10), 6))
assert comparator(itertools.islice(range(10), 2, 5), itertools.islice(range(10), 2, 5))
assert not comparator(itertools.islice(range(10), 2, 5), itertools.islice(range(10), 2, 6))
def test_itertools_product() -> None:
import itertools
assert comparator(itertools.product("AB", repeat=2), itertools.product("AB", repeat=2))
assert not comparator(itertools.product("AB", repeat=2), itertools.product("AC", repeat=2))
assert comparator(itertools.product([1, 2], [3, 4]), itertools.product([1, 2], [3, 4]))
assert not comparator(itertools.product([1, 2], [3, 4]), itertools.product([1, 2], [3, 5]))
def test_itertools_permutations_combinations() -> None:
import itertools
assert comparator(itertools.permutations("ABC", 2), itertools.permutations("ABC", 2))
assert not comparator(itertools.permutations("ABC", 2), itertools.permutations("ABD", 2))
assert comparator(itertools.combinations("ABCD", 2), itertools.combinations("ABCD", 2))
assert not comparator(itertools.combinations("ABCD", 2), itertools.combinations("ABCD", 3))
assert comparator(
itertools.combinations_with_replacement("ABC", 2),
itertools.combinations_with_replacement("ABC", 2),
)
assert not comparator(
itertools.combinations_with_replacement("ABC", 2),
itertools.combinations_with_replacement("ABD", 2),
)
def test_itertools_accumulate() -> None:
import itertools
assert comparator(itertools.accumulate([1, 2, 3, 4]), itertools.accumulate([1, 2, 3, 4]))
assert not comparator(itertools.accumulate([1, 2, 3, 4]), itertools.accumulate([1, 2, 3, 5]))
assert comparator(itertools.accumulate([1, 2, 3], initial=10), itertools.accumulate([1, 2, 3], initial=10))
assert not comparator(itertools.accumulate([1, 2, 3], initial=10), itertools.accumulate([1, 2, 3], initial=0))
def test_itertools_filtering() -> None:
import itertools
# compress
assert comparator(
itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]),
itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]),
)
assert not comparator(
itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]),
itertools.compress("ABCDEF", [1, 1, 1, 0, 1, 1]),
)
# dropwhile
assert comparator(
itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
)
assert not comparator(
itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
itertools.dropwhile(lambda x: x < 5, [1, 4, 7, 4, 1]),
)
# takewhile
assert comparator(
itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
)
assert not comparator(
itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
itertools.takewhile(lambda x: x < 5, [1, 3, 6, 4, 1]),
)
# filterfalse
assert comparator(
itertools.filterfalse(lambda x: x % 2, range(10)),
itertools.filterfalse(lambda x: x % 2, range(10)),
)
def test_itertools_starmap() -> None:
import itertools
assert comparator(
itertools.starmap(pow, [(2, 3), (3, 2), (10, 0)]),
itertools.starmap(pow, [(2, 3), (3, 2), (10, 0)]),
)
assert not comparator(
itertools.starmap(pow, [(2, 3), (3, 2)]),
itertools.starmap(pow, [(2, 3), (3, 3)]),
)
def test_itertools_zip_longest() -> None:
import itertools
assert comparator(
itertools.zip_longest("AB", "xyz", fillvalue="-"),
itertools.zip_longest("AB", "xyz", fillvalue="-"),
)
assert not comparator(
itertools.zip_longest("AB", "xyz", fillvalue="-"),
itertools.zip_longest("AB", "xyz", fillvalue="*"),
)
def test_itertools_groupby() -> None:
import itertools
assert comparator(itertools.groupby("AAABBBCC"), itertools.groupby("AAABBBCC"))
assert not comparator(itertools.groupby("AAABBBCC"), itertools.groupby("AAABBCC"))
assert comparator(itertools.groupby([]), itertools.groupby([]))
# With key function
assert comparator(
itertools.groupby([1, 1, 2, 2, 3], key=lambda x: x),
itertools.groupby([1, 1, 2, 2, 3], key=lambda x: x),
)
@pytest.mark.skipif(sys.version_info < (3, 10), reason="itertools.pairwise requires Python 3.10+")
def test_itertools_pairwise() -> None:
import itertools
assert comparator(itertools.pairwise([1, 2, 3, 4]), itertools.pairwise([1, 2, 3, 4]))
assert not comparator(itertools.pairwise([1, 2, 3, 4]), itertools.pairwise([1, 2, 3, 5]))
@pytest.mark.skipif(sys.version_info < (3, 12), reason="itertools.batched requires Python 3.12+")
def test_itertools_batched() -> None:
import itertools
assert comparator(itertools.batched("ABCDEFG", 3), itertools.batched("ABCDEFG", 3))
assert not comparator(itertools.batched("ABCDEFG", 3), itertools.batched("ABCDEFG", 2))
def test_itertools_in_containers() -> None:
import itertools
# Itertools objects nested in dicts/lists
assert comparator(
{"a": itertools.chain([1], [2]), "b": itertools.islice(range(5), 3)},
{"a": itertools.chain([1], [2]), "b": itertools.islice(range(5), 3)},
)
assert not comparator(
[itertools.product("AB", repeat=2)],
[itertools.product("AC", repeat=2)],
)
# Different itertools types should not match
assert not comparator(itertools.chain([1, 2]), itertools.islice([1, 2], 2))
def test_numpy():
try:
import numpy as np
@ -5216,3 +5522,67 @@ class TestPythonTempfilePaths:
assert PYTHON_TEMPFILE_PATTERN.search("/tmp/tmp123456/")
assert not PYTHON_TEMPFILE_PATTERN.search("/tmp/mydir/file.txt")
assert not PYTHON_TEMPFILE_PATTERN.search("/home/tmp123/file.txt")
@pytest.mark.skipif(sys.version_info < (3, 10), reason="types.UnionType requires Python 3.10+")
class TestUnionType:
def test_union_type_equal(self):
assert comparator(int | str, int | str)
def test_union_type_not_equal(self):
assert not comparator(int | str, int | float)
def test_union_type_order_independent(self):
assert comparator(int | str, str | int)
def test_union_type_multiple_args(self):
assert comparator(int | str | float, int | str | float)
def test_union_type_in_list(self):
assert comparator([int | str, 1], [int | str, 1])
def test_union_type_in_dict(self):
assert comparator({"key": int | str}, {"key": int | str})
def test_union_type_vs_none(self):
assert not comparator(int | str, None)
class SlotsOnly:
__slots__ = ("x", "y")
def __init__(self, x, y):
self.x = x
self.y = y
class SlotsInherited(SlotsOnly):
__slots__ = ("z",)
def __init__(self, x, y, z):
super().__init__(x, y)
self.z = z
class TestSlotsObjects:
def test_slots_equal(self):
assert comparator(SlotsOnly(1, 2), SlotsOnly(1, 2))
def test_slots_not_equal(self):
assert not comparator(SlotsOnly(1, 2), SlotsOnly(1, 3))
def test_slots_inherited_equal(self):
assert comparator(SlotsInherited(1, 2, 3), SlotsInherited(1, 2, 3))
def test_slots_inherited_not_equal(self):
assert not comparator(SlotsInherited(1, 2, 3), SlotsInherited(1, 2, 4))
def test_slots_nested(self):
a = SlotsOnly(SlotsOnly(1, 2), [3, 4])
b = SlotsOnly(SlotsOnly(1, 2), [3, 4])
assert comparator(a, b)
def test_slots_nested_not_equal(self):
a = SlotsOnly(SlotsOnly(1, 2), [3, 4])
b = SlotsOnly(SlotsOnly(1, 9), [3, 4])
assert not comparator(a, b)

View file

@ -5,7 +5,7 @@ import pytest
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.either import is_successful
from codeflash.models.models import FunctionParent
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
@ -132,7 +132,7 @@ def test_class_method_dependencies() -> None:
starting_line=None,
ending_line=None,
)
func_optimizer = FunctionOptimizer(
func_optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=TestConfig(
tests_root=file_path,
@ -163,6 +163,7 @@ def test_class_method_dependencies() -> None:
== """# file: test_function_dependencies.py
from collections import defaultdict
class Graph:
def __init__(self, vertices):
self.graph = defaultdict(list)
@ -201,7 +202,7 @@ def test_recursive_function_context() -> None:
starting_line=None,
ending_line=None,
)
func_optimizer = FunctionOptimizer(
func_optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=TestConfig(
tests_root=file_path,

View file

@ -680,8 +680,14 @@ def test_in_dunder_tests():
# Combine all discovered functions
all_functions = {}
for discovered in [discovered_source, discovered_test, discovered_test_underscore,
discovered_spec, discovered_tests_dir, discovered_dunder_tests]:
for discovered in [
discovered_source,
discovered_test,
discovered_test_underscore,
discovered_spec,
discovered_tests_dir,
discovered_dunder_tests,
]:
all_functions.update(discovered)
# Test Case 1: tests_root == module_root (overlapping case)
@ -781,9 +787,7 @@ def test_filter_functions_strict_string_matching():
# Strict check: exactly these 3 files should remain (those with 'test' as substring only)
expected_files = {contest_file, latest_file, attestation_file}
assert set(filtered.keys()) == expected_files, (
f"Expected files {expected_files}, got {set(filtered.keys())}"
)
assert set(filtered.keys()) == expected_files, f"Expected files {expected_files}, got {set(filtered.keys())}"
# Strict check: each file should have exactly 1 function with the expected name
assert [fn.function_name for fn in filtered[contest_file]] == ["run_contest"], (
@ -871,9 +875,7 @@ def test_filter_functions_test_directory_patterns():
# Strict check: exactly these 2 files should remain (those in non-test directories)
expected_files = {contest_file, latest_file}
assert set(filtered.keys()) == expected_files, (
f"Expected files {expected_files}, got {set(filtered.keys())}"
)
assert set(filtered.keys()) == expected_files, f"Expected files {expected_files}, got {set(filtered.keys())}"
# Strict check: each file should have exactly 1 function with the expected name
assert [fn.function_name for fn in filtered[contest_file]] == ["get_scores"], (
@ -936,9 +938,7 @@ def test_filter_functions_non_overlapping_tests_root():
# Strict check: exactly these 2 files should remain (both in src/, not in tests/)
expected_files = {source_file, test_in_src}
assert set(filtered.keys()) == expected_files, (
f"Expected files {expected_files}, got {set(filtered.keys())}"
)
assert set(filtered.keys()) == expected_files, f"Expected files {expected_files}, got {set(filtered.keys())}"
# Strict check: each file should have exactly 1 function with the expected name
assert [fn.function_name for fn in filtered[source_file]] == ["process"], (
@ -1047,20 +1047,15 @@ def test_deep_copy():
)
root_functions = [fn.function_name for fn in filtered.get(root_source_file, [])]
assert root_functions == ["main"], (
f"Expected ['main'], got {root_functions}"
)
assert root_functions == ["main"], f"Expected ['main'], got {root_functions}"
# Strict check: exactly 3 functions (2 from utils.py + 1 from main.py)
assert count == 3, (
f"Expected exactly 3 functions, got {count}. "
f"Some source files may have been incorrectly filtered."
f"Expected exactly 3 functions, got {count}. Some source files may have been incorrectly filtered."
)
# Verify test file was properly filtered (should not be in results)
assert test_file not in filtered, (
f"Test file {test_file} should have been filtered but wasn't"
)
assert test_file not in filtered, f"Test file {test_file} should have been filtered but wasn't"
def test_filter_functions_typescript_project_in_tests_folder():
@ -1214,9 +1209,7 @@ def sample_data():
# source_file and file_in_test_dir should remain
# test_prefix_file, conftest_file, and test_in_subdir should be filtered
expected_files = {source_file, file_in_test_dir}
assert set(filtered.keys()) == expected_files, (
f"Expected {expected_files}, got {set(filtered.keys())}"
)
assert set(filtered.keys()) == expected_files, f"Expected {expected_files}, got {set(filtered.keys())}"
assert count == 2, f"Expected exactly 2 functions, got {count}"
@ -1266,7 +1259,8 @@ class TestHelpers:
""")
support = PythonSupport()
functions = support.discover_functions(fixture_file)
source = fixture_file.read_text(encoding="utf-8")
functions = support.discover_functions(source, fixture_file)
function_names = [fn.function_name for fn in functions]
assert "regular_function" in function_names

View file

@ -7,7 +7,7 @@ import pytest
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.either import is_successful
from codeflash.models.models import FunctionParent, get_code_block_splitter
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
from codeflash.optimization.optimizer import Optimizer
from codeflash.verification.verification_utils import TestConfig
@ -233,7 +233,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=function_to_optimize, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=function_to_optimize, test_cfg=test_config)
with open(file_path) as f:
original_code = f.read()
ctx_result = func_optimizer.get_code_optimization_context()
@ -404,7 +404,7 @@ def test_bubble_sort_deps() -> None:
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=function_to_optimize, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=function_to_optimize, test_cfg=test_config)
with open(file_path) as f:
original_code = f.read()
ctx_result = func_optimizer.get_code_optimization_context()
@ -427,6 +427,7 @@ def dep2_swap(arr, j):
from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
def sorter_deps(arr):
for i in range(len(arr)):
for j in range(len(arr) - 1):

View file

@ -23,7 +23,7 @@ def test_basic_class() -> None:
class_var = "value"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -47,7 +47,7 @@ def test_dunder_methods() -> None:
return f"Value: {self.x}"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -75,7 +75,7 @@ def test_dunder_methods_remove_docstring() -> None:
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True
)
).code
assert dedent(expected).strip() == output.strip()
@ -102,7 +102,7 @@ def test_class_remove_docstring() -> None:
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True
)
).code
assert dedent(expected).strip() == output.strip()
@ -131,7 +131,7 @@ def test_mixed_remove_docstring() -> None:
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set(), remove_docstrings=True
)
).code
assert dedent(expected).strip() == output.strip()
@ -171,7 +171,7 @@ def test_docstrings() -> None:
\"\"\"Class docstring.\"\"\"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -190,7 +190,7 @@ def test_method_signatures() -> None:
expected = """"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -212,7 +212,7 @@ def test_multiple_top_level_targets() -> None:
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.READ_ONLY, {"TestClass.target1", "TestClass.target2"}, set()
)
).code
assert dedent(expected).strip() == output.strip()
@ -232,7 +232,7 @@ def test_class_annotations() -> None:
var2: str
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -254,7 +254,7 @@ def test_class_annotations_if() -> None:
var2: str
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -280,7 +280,7 @@ def test_class_annotations_try() -> None:
continue
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -316,7 +316,7 @@ def test_class_annotations_else() -> None:
var2: str
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -331,7 +331,7 @@ def test_top_level_functions() -> None:
expected = """"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -350,7 +350,7 @@ def test_module_var() -> None:
x = 5
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -377,7 +377,7 @@ def test_module_var_if() -> None:
z = 10
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"target_function"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -412,7 +412,7 @@ def test_conditional_class_definitions() -> None:
platform = "other"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"PlatformClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"PlatformClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -471,7 +471,7 @@ def test_multiple_except_clauses() -> None:
error_type = "cleanup"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -524,7 +524,7 @@ def test_with_statement_and_loops() -> None:
context = "cleanup"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -573,7 +573,7 @@ def test_async_with_try_except() -> None:
status = "cancelled"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_ONLY, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -675,7 +675,7 @@ def test_simplified_complete_implementation() -> None:
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.READ_ONLY, {"DataProcessor.target_method", "ResultHandler.target_method"}, set()
)
).code
assert dedent(expected).strip() == output.strip()
@ -768,5 +768,5 @@ def test_simplified_complete_implementation_no_docstring() -> None:
{"DataProcessor.target_method", "ResultHandler.target_method"},
set(),
remove_docstrings=True,
)
).code
assert dedent(expected).strip() == output.strip()

View file

@ -13,7 +13,7 @@ def test_simple_function() -> None:
y = 2
return x + y
"""
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"}).code
expected = dedent("""
def target_function():
@ -32,7 +32,7 @@ def test_class_method() -> None:
y = 2
return x + y
"""
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_function"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_function"}).code
expected = dedent("""
class MyClass:
@ -56,7 +56,7 @@ def test_class_with_attributes() -> None:
def other_method(self):
print("this should be excluded")
"""
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"}).code
expected = dedent("""
class MyClass:
@ -80,7 +80,7 @@ def test_basic_class_structure() -> None:
def not_findable(self):
return 42
"""
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"Outer.target_method"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"Outer.target_method"}).code
expected = dedent("""
class Outer:
@ -100,7 +100,7 @@ def test_top_level_targets() -> None:
def target_function():
return 42
"""
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"}).code
expected = dedent("""
def target_function():
@ -123,7 +123,7 @@ def test_multiple_top_level_classes() -> None:
def process(self):
return "C"
"""
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"ClassA.process", "ClassC.process"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"ClassA.process", "ClassC.process"}).code
expected = dedent("""
class ClassA:
@ -148,7 +148,7 @@ def test_try_except_structure() -> None:
def handle_error(self):
print("error")
"""
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"TargetClass.target_method"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"TargetClass.target_method"}).code
expected = dedent("""
try:
@ -175,7 +175,7 @@ def test_init_method() -> None:
def target_method(self):
return f"Value: {self.x}"
"""
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"}).code
expected = dedent("""
class MyClass:
@ -200,7 +200,7 @@ def test_dunder_method() -> None:
def target_method(self):
return f"Value: {self.x}"
"""
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.target_method"}).code
expected = dedent("""
class MyClass:
@ -221,7 +221,7 @@ def test_no_targets_found() -> None:
def target(self):
pass
"""
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"})
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"}).code
expected = dedent("""
class MyClass:
def method(self):
@ -266,5 +266,55 @@ def test_module_var() -> None:
var2 = "test"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"})
output = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"target_function"}).code
assert dedent(expected).strip() == output.strip()
def test_comment_between_imports_and_variable_preserves_position() -> None:
code = """
from __future__ import annotations
import re
from dataclasses import dataclass, field
# NOTE: This comment documents the constant below.
# It should stay right above SOME_RE, not jump to the top of the file.
SOME_RE = re.compile(r"^pattern", re.MULTILINE)
@dataclass(slots=True)
class Item:
name: str
value: int
children: list[Item] = field(default_factory=list)
def parse(text: str) -> list[Item]:
root = Item(name="root", value=0)
for m in SOME_RE.finditer(text):
root.children.append(Item(name=m.group(), value=1))
return root.children
"""
expected = """
# NOTE: This comment documents the constant below.
# It should stay right above SOME_RE, not jump to the top of the file.
SOME_RE = re.compile(r"^pattern", re.MULTILINE)
@dataclass(slots=True)
class Item:
name: str
value: int
children: list[Item] = field(default_factory=list)
def parse(text: str) -> list[Item]:
root = Item(name="root", value=0)
for m in SOME_RE.finditer(text):
root.children.append(Item(name=m.group(), value=1))
return root.children
"""
result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"parse"}).code
assert result.strip() == dedent(expected).strip()

View file

@ -13,7 +13,7 @@ def test_simple_function() -> None:
y = 2
return x + y
"""
result = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set())
result = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set()).code
expected = """
def target_function():
@ -44,7 +44,7 @@ def test_basic_class() -> None:
print("This should be included")
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -73,7 +73,7 @@ def test_dunder_methods() -> None:
print("include me")
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -107,7 +107,7 @@ def test_dunder_methods_remove_docstring() -> None:
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set(), remove_docstrings=True
)
).code
assert dedent(expected).strip() == output.strip()
@ -139,7 +139,7 @@ def test_class_remove_docstring() -> None:
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set(), remove_docstrings=True
)
).code
assert dedent(expected).strip() == output.strip()
@ -181,7 +181,7 @@ def test_method_signatures() -> None:
return "value"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -215,7 +215,7 @@ def test_multiple_top_level_targets() -> None:
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.TESTGEN, {"TestClass.target1", "TestClass.target2"}, set()
)
).code
assert dedent(expected).strip() == output.strip()
@ -238,7 +238,7 @@ def test_class_annotations() -> None:
self.var2 = "test"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -263,7 +263,7 @@ def test_class_annotations_if() -> None:
self.var2 = "test"
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -304,7 +304,7 @@ def test_conditional_class_definitions() -> None:
print("other")
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"PlatformClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"PlatformClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -333,7 +333,7 @@ def test_try_except_structure() -> None:
print("error")
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TargetClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TargetClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -355,7 +355,7 @@ def test_module_var() -> None:
x = 5
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -385,7 +385,7 @@ def test_module_var_if() -> None:
z = 10
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"target_function"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -416,7 +416,7 @@ def test_multiple_classes() -> None:
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.TESTGEN, {"ClassA.process", "ClassC.process"}, set()
)
).code
assert dedent(expected).strip() == output.strip()
@ -477,7 +477,7 @@ def test_with_statement_and_loops() -> None:
print("cleanup")
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -532,7 +532,7 @@ def test_async_with_try_except() -> None:
await self.cleanup()
"""
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set())
output = parse_code_and_prune_cst(dedent(code), CodeContextType.TESTGEN, {"TestClass.target_method"}, set()).code
assert dedent(expected).strip() == output.strip()
@ -659,7 +659,7 @@ def test_simplified_complete_implementation() -> None:
output = parse_code_and_prune_cst(
dedent(code), CodeContextType.TESTGEN, {"DataProcessor.target_method", "ResultHandler.target_method"}, set()
)
).code
assert dedent(expected).strip() == output.strip()
@ -765,5 +765,5 @@ def test_simplified_complete_implementation_no_docstring() -> None:
{"DataProcessor.target_method", "ResultHandler.target_method"},
set(),
remove_docstrings=True,
)
).code
assert dedent(expected).strip() == output.strip()

View file

@ -1,6 +1,8 @@
"""Tests for JavaScript/TypeScript project initialization and package manager detection."""
import json
from pathlib import Path
from unittest.mock import patch
import pytest
@ -8,6 +10,7 @@ from codeflash.cli_cmds.init_javascript import (
JsPackageManager,
determine_js_package_manager,
get_package_install_command,
should_modify_package_json_config,
)
@ -281,3 +284,67 @@ class TestGetPackageInstallCommand:
result = get_package_install_command(tmp_project, "typescript", dev=True)
assert result == ["pnpm", "add", "typescript", "--save-dev"]
class TestShouldModifySkipConfirm:
"""Tests for should_modify_package_json_config with skip_confirm."""
def test_should_modify_skip_confirm_no_config(self, tmp_project: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""With skip_confirm and no codeflash config, should return (True, None)."""
monkeypatch.chdir(tmp_project)
(tmp_project / "package.json").write_text(json.dumps({"name": "test"}))
should_modify, config = should_modify_package_json_config(skip_confirm=True)
assert should_modify is True
assert config is None
def test_should_modify_skip_confirm_with_valid_config(
self, tmp_project: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""With skip_confirm and valid config, should return (False, config) — no reconfigure."""
monkeypatch.chdir(tmp_project)
codeflash_config = {"moduleRoot": "."}
(tmp_project / "package.json").write_text(
json.dumps({"name": "test", "codeflash": codeflash_config})
)
should_modify, config = should_modify_package_json_config(skip_confirm=True)
assert should_modify is False
assert config == codeflash_config
def test_should_modify_skip_confirm_with_invalid_config(
self, tmp_project: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""With skip_confirm and invalid config (bad moduleRoot), should return (True, None)."""
monkeypatch.chdir(tmp_project)
codeflash_config = {"moduleRoot": "/nonexistent/path/that/does/not/exist"}
(tmp_project / "package.json").write_text(
json.dumps({"name": "test", "codeflash": codeflash_config})
)
should_modify, config = should_modify_package_json_config(skip_confirm=True)
assert should_modify is True
assert config is None
class TestCollectJsSetupInfoSkipConfirm:
"""Tests for collect_js_setup_info with skip_confirm."""
def test_collect_js_setup_info_skip_confirm(self, tmp_project: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""skip_confirm should return defaults without any interactive prompts."""
monkeypatch.chdir(tmp_project)
(tmp_project / "package.json").write_text(json.dumps({"name": "test"}))
from codeflash.cli_cmds.init_javascript import ProjectLanguage, collect_js_setup_info
# Should not call any prompt functions
with patch("codeflash.cli_cmds.init_javascript.inquirer") as mock_inquirer:
setup_info = collect_js_setup_info(ProjectLanguage.JAVASCRIPT, skip_confirm=True)
mock_inquirer.prompt.assert_not_called()
assert setup_info.module_root_override is None
assert setup_info.formatter_override is None
assert setup_info.git_remote == "origin"

View file

@ -5,7 +5,7 @@ from tempfile import TemporaryDirectory
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
@ -22,7 +22,7 @@ def test_add_decorator_imports_helper_in_class():
pytest_cmd="pytest",
)
func = FunctionToOptimize(function_name="sort_classmethod", parents=[], file_path=code_path)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
os.chdir(run_cwd)
# func_optimizer = pass
try:
@ -94,7 +94,7 @@ def test_add_decorator_imports_helper_in_nested_class():
pytest_cmd="pytest",
)
func = FunctionToOptimize(function_name="sort_classmethod", parents=[], file_path=code_path)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
os.chdir(run_cwd)
# func_optimizer = pass
try:
@ -143,7 +143,7 @@ def test_add_decorator_imports_nodeps():
pytest_cmd="pytest",
)
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_path)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
os.chdir(run_cwd)
# func_optimizer = pass
try:
@ -194,7 +194,7 @@ def test_add_decorator_imports_helper_outside():
pytest_cmd="pytest",
)
func = FunctionToOptimize(function_name="sorter_deps", parents=[], file_path=code_path)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
os.chdir(run_cwd)
# func_optimizer = pass
try:
@ -271,7 +271,7 @@ class helper:
pytest_cmd="pytest",
)
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_write_path)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
os.chdir(run_cwd)
# func_optimizer = pass
try:

View file

@ -20,12 +20,15 @@ All assertions use strict string equality to verify exact extraction output.
from __future__ import annotations
from unittest.mock import MagicMock
import pytest
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import Language
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context_for_language
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
@pytest.fixture
@ -61,7 +64,8 @@ export function add(a, b) {
file_path = temp_project / "math.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
assert len(functions) == 1
func = functions[0]
@ -87,7 +91,8 @@ export const multiply = (a, b) => a * b;
file_path = temp_project / "math.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
assert len(functions) == 1
func = functions[0]
assert func.function_name == "multiply"
@ -121,7 +126,8 @@ export function add(a, b) {
file_path = temp_project / "math.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
context = js_support.extract_code_context(func, temp_project, temp_project)
@ -173,7 +179,8 @@ export async function processItems(items, callback, options = {}) {
file_path = temp_project / "processor.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
context = js_support.extract_code_context(func, temp_project, temp_project)
@ -243,7 +250,8 @@ export class CacheManager {
file_path = temp_project / "cache.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
get_or_compute = next(f for f in functions if f.function_name == "getOrCompute")
context = js_support.extract_code_context(get_or_compute, temp_project, temp_project)
@ -339,7 +347,8 @@ export function validateUserData(data, validators) {
file_path = temp_project / "validator.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = next(f for f in functions if f.function_name == "validateUserData")
context = js_support.extract_code_context(func, temp_project, temp_project)
@ -429,7 +438,8 @@ export async function fetchWithRetry(endpoint, options = {}) {
file_path = temp_project / "api.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = next(f for f in functions if f.function_name == "fetchWithRetry")
context = js_support.extract_code_context(func, temp_project, temp_project)
@ -515,7 +525,8 @@ export function validateField(value, fieldType) {
file_path = temp_project / "validation.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
context = js_support.extract_code_context(func, temp_project, temp_project)
@ -578,7 +589,8 @@ export function processUserInput(rawInput) {
file_path = temp_project / "processor.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
process_func = next(f for f in functions if f.function_name == "processUserInput")
context = js_support.extract_code_context(process_func, temp_project, temp_project)
@ -633,7 +645,8 @@ export function generateReport(data) {
file_path = temp_project / "report.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
report_func = next(f for f in functions if f.function_name == "generateReport")
context = js_support.extract_code_context(report_func, temp_project, temp_project)
@ -731,7 +744,8 @@ export class Graph {
file_path = temp_project / "graph.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
topo_sort = next(f for f in functions if f.function_name == "topologicalSort")
context = js_support.extract_code_context(topo_sort, temp_project, temp_project)
@ -819,7 +833,8 @@ export class MainClass {
file_path = temp_project / "classes.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
main_method = next(f for f in functions if f.function_name == "mainMethod" and f.class_name == "MainClass")
context = js_support.extract_code_context(main_method, temp_project, temp_project)
@ -875,7 +890,8 @@ module.exports = { sortFromAnotherFile };
main_path = temp_project / "bubble_sort_imported.js"
main_path.write_text(main_code, encoding="utf-8")
functions = js_support.discover_functions(main_path)
source = main_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, main_path)
main_func = next(f for f in functions if f.function_name == "sortFromAnotherFile")
context = js_support.extract_code_context(main_func, temp_project, temp_project)
@ -926,7 +942,8 @@ export function processNumber(n) {
main_path = temp_project / "main.js"
main_path.write_text(main_code, encoding="utf-8")
functions = js_support.discover_functions(main_path)
source = main_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, main_path)
process_func = next(f for f in functions if f.function_name == "processNumber")
context = js_support.extract_code_context(process_func, temp_project, temp_project)
@ -992,7 +1009,8 @@ export function handleUserInput(rawInput) {
main_path = temp_project / "main.js"
main_path.write_text(main_code, encoding="utf-8")
functions = js_support.discover_functions(main_path)
source = main_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, main_path)
handle_func = next(f for f in functions if f.function_name == "handleUserInput")
context = js_support.extract_code_context(handle_func, temp_project, temp_project)
@ -1043,7 +1061,8 @@ export function createEntity<T extends object>(data: T): Entity<T> {
file_path = temp_project / "entity.ts"
file_path.write_text(code, encoding="utf-8")
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
func = functions[0]
context = ts_support.extract_code_context(func, temp_project, temp_project)
@ -1133,7 +1152,8 @@ export class TypedCache<T> {
file_path = temp_project / "cache.ts"
file_path.write_text(code, encoding="utf-8")
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
get_method = next(f for f in functions if f.function_name == "get")
context = ts_support.extract_code_context(get_method, temp_project, temp_project)
@ -1217,7 +1237,8 @@ export function createUser(input: CreateUserInput, role: UserRole = DEFAULT_ROLE
service_path = temp_project / "service.ts"
service_path.write_text(service_code, encoding="utf-8")
functions = ts_support.discover_functions(service_path)
source = service_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, service_path)
func = next(f for f in functions if f.function_name == "createUser")
context = ts_support.extract_code_context(func, temp_project, temp_project)
@ -1271,7 +1292,8 @@ export function factorial(n) {
file_path = temp_project / "math.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
context = js_support.extract_code_context(func, temp_project, temp_project)
@ -1301,7 +1323,8 @@ export function isOdd(n) {
file_path = temp_project / "parity.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
is_even = next(f for f in functions if f.function_name == "isEven")
context = js_support.extract_code_context(is_even, temp_project, temp_project)
@ -1319,12 +1342,15 @@ export function isEven(n) {
assert helper_names == ["isOdd"]
# Verify helper source
assert context.helper_functions[0].source_code == """\
assert (
context.helper_functions[0].source_code
== """\
export function isOdd(n) {
if (n === 0) return false;
return isEven(n - 1);
}
"""
)
def test_complex_recursive_tree_traversal(self, js_support, temp_project):
"""Test complex recursive tree traversal with multiple recursive calls."""
@ -1363,7 +1389,8 @@ export function collectAllValues(root) {
file_path = temp_project / "tree.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
collect_func = next(f for f in functions if f.function_name == "collectAllValues")
context = js_support.extract_code_context(collect_func, temp_project, temp_project)
@ -1428,7 +1455,8 @@ export async function fetchUserProfile(userId) {
file_path = temp_project / "api.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
profile_func = next(f for f in functions if f.function_name == "fetchUserProfile")
context = js_support.extract_code_context(profile_func, temp_project, temp_project)
@ -1483,7 +1511,8 @@ module.exports = { Counter };
file_path = temp_project / "counter.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
increment_func = next(fn for fn in functions if fn.function_name == "increment")
# Step 1: Extract code context
@ -1563,7 +1592,8 @@ export function processApiResponse({
file_path = temp_project / "api.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
context = js_support.extract_code_context(func, temp_project, temp_project)
@ -1605,7 +1635,8 @@ export function* fibonacci(limit) {
file_path = temp_project / "generators.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
range_func = next(f for f in functions if f.function_name == "range")
context = js_support.extract_code_context(range_func, temp_project, temp_project)
@ -1640,7 +1671,8 @@ export function createUserObject(name, email, age) {
file_path = temp_project / "user.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
context = js_support.extract_code_context(func, temp_project, temp_project)
@ -1790,7 +1822,8 @@ export const sendSlackMessage = async (
file_path.write_text(code, encoding="utf-8")
target_func = "sendSlackMessage"
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
func_info = next(f for f in functions if f.function_name == target_func)
fto = FunctionToOptimize(
function_name=target_func,
@ -1804,9 +1837,11 @@ export const sendSlackMessage = async (
language="typescript",
)
ctx = get_code_optimization_context_for_language(
fto, temp_project
test_config = TestConfig(
tests_root=temp_project, tests_project_rootdir=temp_project, project_root_path=temp_project
)
func_optimizer = JavaScriptFunctionOptimizer(function_to_optimize=fto, test_cfg=test_config, aiservice_client=MagicMock())
ctx = func_optimizer.get_code_optimization_context().unwrap()
# The read_writable_code should contain the target function AND helper functions
expected_read_writable = """```typescript:slack_util.ts
@ -1899,7 +1934,6 @@ let web: WebClient | null = null"""
assert ctx.read_only_context_code == expected_read_only
class TestContextProperties:
"""Tests for CodeContext object properties."""
@ -1913,7 +1947,8 @@ export function test() {
file_path = temp_project / "test.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
context = js_support.extract_code_context(functions[0], temp_project, temp_project)
assert context.language == Language.JAVASCRIPT
@ -1932,7 +1967,8 @@ export function test(): number {
file_path = temp_project / "test.ts"
file_path.write_text(code, encoding="utf-8")
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
context = ts_support.extract_code_context(functions[0], temp_project, temp_project)
# TypeScript uses JavaScript language enum
@ -1974,7 +2010,8 @@ export class Calculator {
file_path = temp_project / "calculator.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
for func in functions:
if func.function_name != "constructor":

View file

@ -107,10 +107,8 @@ class TestJavaScriptCodeContext:
"""Test extracting code context for a JavaScript function."""
skip_if_js_not_supported()
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.languages import current as lang_current
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
lang_current._current_language = Language.JAVASCRIPT
from codeflash.languages import get_language_support
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
fib_file = js_project_dir / "fibonacci.js"
if not fib_file.exists():
@ -122,7 +120,11 @@ class TestJavaScriptCodeContext:
fib_func = next((f for f in func_list if f.function_name == "fibonacci"), None)
assert fib_func is not None
context = get_code_optimization_context(fib_func, js_project_dir)
js_support = get_language_support(Language.JAVASCRIPT)
code_context = js_support.extract_code_context(fib_func, js_project_dir, js_project_dir)
context = JavaScriptFunctionOptimizer._build_optimization_context(
code_context, fib_file, "javascript", js_project_dir
)
assert context.read_writable_code is not None
assert context.read_writable_code.language == "javascript"

View file

@ -71,10 +71,8 @@ module.exports = { add };
"""Verify language is preserved in code context extraction."""
skip_if_js_not_supported()
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.languages import current as lang_current
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
lang_current._current_language = Language.TYPESCRIPT
from codeflash.languages import get_language_support
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
ts_file = tmp_path / "utils.ts"
ts_file.write_text("""
@ -86,7 +84,11 @@ export function add(a: number, b: number): number {
functions = find_all_functions_in_file(ts_file)
func = functions[ts_file][0]
context = get_code_optimization_context(func, tmp_path)
ts_support = get_language_support(Language.TYPESCRIPT)
code_context = ts_support.extract_code_context(func, tmp_path, tmp_path)
context = JavaScriptFunctionOptimizer._build_optimization_context(
code_context, ts_file, "typescript", tmp_path
)
assert context.read_writable_code is not None
assert context.read_writable_code.language == "typescript"
@ -373,10 +375,7 @@ describe('fibonacci', () => {
"""Test get_code_optimization_context for JavaScript."""
skip_if_js_not_supported()
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.languages import current as lang_current
from codeflash.optimization.function_optimizer import FunctionOptimizer
lang_current._current_language = Language.JAVASCRIPT
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
src_file = js_project / "utils.js"
functions = find_all_functions_in_file(src_file)
@ -398,7 +397,7 @@ describe('fibonacci', () => {
pytest_cmd="jest",
)
optimizer = FunctionOptimizer(
optimizer = JavaScriptFunctionOptimizer(
function_to_optimize=func_to_optimize,
test_cfg=test_config,
aiservice_client=MagicMock(),
@ -415,10 +414,7 @@ describe('fibonacci', () => {
"""Test get_code_optimization_context for TypeScript."""
skip_if_js_not_supported()
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.languages import current as lang_current
from codeflash.optimization.function_optimizer import FunctionOptimizer
lang_current._current_language = Language.TYPESCRIPT
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
src_file = ts_project / "utils.ts"
functions = find_all_functions_in_file(src_file)
@ -440,7 +436,7 @@ describe('fibonacci', () => {
pytest_cmd="vitest",
)
optimizer = FunctionOptimizer(
optimizer = JavaScriptFunctionOptimizer(
function_to_optimize=func_to_optimize,
test_cfg=test_config,
aiservice_client=MagicMock(),
@ -461,10 +457,7 @@ class TestHelperFunctionLanguageAttribute:
"""Verify helper functions have language='javascript' for .js files."""
skip_if_js_not_supported()
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.languages import current as lang_current
from codeflash.optimization.function_optimizer import FunctionOptimizer
lang_current._current_language = Language.JAVASCRIPT
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
# Create a file with helper functions
src_file = tmp_path / "main.js"
@ -499,7 +492,7 @@ module.exports = { main };
pytest_cmd="jest",
)
optimizer = FunctionOptimizer(
optimizer = JavaScriptFunctionOptimizer(
function_to_optimize=func_to_optimize,
test_cfg=test_config,
aiservice_client=MagicMock(),
@ -515,10 +508,7 @@ module.exports = { main };
"""Verify helper functions have language='typescript' for .ts files."""
skip_if_js_not_supported()
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.languages import current as lang_current
from codeflash.optimization.function_optimizer import FunctionOptimizer
lang_current._current_language = Language.TYPESCRIPT
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
# Create a file with helper functions
src_file = tmp_path / "main.ts"
@ -551,7 +541,7 @@ export function main(): number {
pytest_cmd="vitest",
)
optimizer = FunctionOptimizer(
optimizer = JavaScriptFunctionOptimizer(
function_to_optimize=func_to_optimize,
test_cfg=test_config,
aiservice_client=MagicMock(),

View file

@ -16,8 +16,6 @@ NOTE: These tests require:
Tests will be skipped if dependencies are not available.
"""
import os
import shutil
import subprocess
from pathlib import Path
from unittest.mock import MagicMock
@ -26,7 +24,7 @@ import pytest
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import Language
from codeflash.models.models import FunctionParent, TestFile, TestFiles, TestType, TestingMode
from codeflash.models.models import FunctionParent
from codeflash.verification.verification_utils import TestConfig
@ -58,13 +56,7 @@ def install_dependencies(project_dir: Path) -> bool:
if has_node_modules(project_dir):
return True
try:
result = subprocess.run(
["npm", "install"],
cwd=project_dir,
capture_output=True,
text=True,
timeout=120
)
result = subprocess.run(["npm", "install"], cwd=project_dir, capture_output=True, text=True, timeout=120)
return result.returncode == 0
except Exception:
return False
@ -82,6 +74,7 @@ def skip_if_js_not_supported():
"""Skip test if JavaScript/TypeScript languages are not supported."""
try:
from codeflash.languages import get_language_support
get_language_support(Language.JAVASCRIPT)
except Exception as e:
pytest.skip(f"JavaScript/TypeScript language support not available: {e}")
@ -157,8 +150,8 @@ module.exports = {
"""Test that JavaScript test instrumentation module can be imported."""
skip_if_js_not_supported()
from codeflash.languages import get_language_support
# Verify the instrumentation module can be imported
from codeflash.languages.javascript.instrument import inject_profiling_into_existing_js_test
# Get JavaScript support
js_support = get_language_support(Language.JAVASCRIPT)
@ -272,8 +265,8 @@ export default defineConfig({
"""Test that TypeScript test instrumentation module can be imported."""
skip_if_js_not_supported()
from codeflash.languages import get_language_support
# Verify the instrumentation module can be imported
from codeflash.languages.javascript.instrument import inject_profiling_into_existing_js_test
test_file = ts_project_dir / "tests" / "math.test.ts"
@ -356,10 +349,7 @@ class TestRunAndParseJavaScriptTests:
"""
skip_if_js_not_supported()
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.languages import current as lang_current
from codeflash.optimization.function_optimizer import FunctionOptimizer
lang_current._current_language = Language.TYPESCRIPT
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
# Find the fibonacci function
fib_file = vitest_project / "fibonacci.ts"
@ -389,10 +379,8 @@ class TestRunAndParseJavaScriptTests:
)
# Create optimizer
func_optimizer = FunctionOptimizer(
function_to_optimize=func,
test_cfg=test_config,
aiservice_client=MagicMock(),
func_optimizer = JavaScriptFunctionOptimizer(
function_to_optimize=func, test_cfg=test_config, aiservice_client=MagicMock()
)
# Get code context - this should work
@ -419,8 +407,8 @@ class TestTimingMarkerParsing:
# The marker format used by codeflash for JavaScript
# Start marker: !$######{tag}######$!
# End marker: !######{tag}:{duration}######!
start_pattern = r'!\$######(.+?)######\$!'
end_pattern = r'!######(.+?):(\d+)######!'
start_pattern = r"!\$######(.+?)######\$!"
end_pattern = r"!######(.+?):(\d+)######!"
start_marker = "!$######test/math.test.ts:TestMath.test_add:add:1:0_0######$!"
end_marker = "!######test/math.test.ts:TestMath.test_add:add:1:0_0:12345######!"
@ -472,6 +460,7 @@ class TestJavaScriptTestResultParsing:
# Parse the XML
import xml.etree.ElementTree as ET
tree = ET.parse(junit_xml)
root = tree.getroot()
@ -504,6 +493,7 @@ class TestJavaScriptTestResultParsing:
# Parse the XML
import xml.etree.ElementTree as ET
tree = ET.parse(junit_xml)
root = tree.getroot()

View file

@ -52,7 +52,7 @@ export function add(a, b) {
""")
f.flush()
functions = js_support.discover_functions(Path(f.name))
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
assert len(functions) == 1
assert functions[0].function_name == "add"
@ -76,7 +76,7 @@ export function multiply(a, b) {
""")
f.flush()
functions = js_support.discover_functions(Path(f.name))
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
assert len(functions) == 3
names = {func.function_name for func in functions}
@ -94,7 +94,7 @@ export const multiply = (x, y) => x * y;
""")
f.flush()
functions = js_support.discover_functions(Path(f.name))
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
assert len(functions) == 2
names = {func.function_name for func in functions}
@ -114,7 +114,7 @@ export function withoutReturn() {
""")
f.flush()
functions = js_support.discover_functions(Path(f.name))
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
# Only the function with return should be discovered
assert len(functions) == 1
@ -136,7 +136,7 @@ export class Calculator {
""")
f.flush()
functions = js_support.discover_functions(Path(f.name))
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
assert len(functions) == 2
for func in functions:
@ -157,7 +157,7 @@ export function syncFunction() {
""")
f.flush()
functions = js_support.discover_functions(Path(f.name))
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
assert len(functions) == 2
@ -182,7 +182,7 @@ export function syncFunc() {
f.flush()
criteria = FunctionFilterCriteria(include_async=False)
functions = js_support.discover_functions(Path(f.name), criteria)
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name), criteria)
assert len(functions) == 1
assert functions[0].function_name == "syncFunc"
@ -204,7 +204,7 @@ export class MyClass {
f.flush()
criteria = FunctionFilterCriteria(include_methods=False)
functions = js_support.discover_functions(Path(f.name), criteria)
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name), criteria)
assert len(functions) == 1
assert functions[0].function_name == "standalone"
@ -224,7 +224,7 @@ export function func2() {
""")
f.flush()
functions = js_support.discover_functions(Path(f.name))
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
func1 = next(f for f in functions if f.function_name == "func1")
func2 = next(f for f in functions if f.function_name == "func2")
@ -246,7 +246,7 @@ export function* numberGenerator() {
""")
f.flush()
functions = js_support.discover_functions(Path(f.name))
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
assert len(functions) == 1
assert functions[0].function_name == "numberGenerator"
@ -257,14 +257,14 @@ export function* numberGenerator() {
f.write("this is not valid javascript {{{{")
f.flush()
functions = js_support.discover_functions(Path(f.name))
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
# Tree-sitter is lenient, so it may still parse partial code
# The important thing is it doesn't crash
assert isinstance(functions, list)
def test_discover_nonexistent_file_returns_empty(self, js_support):
"""Test that nonexistent file returns empty list."""
functions = js_support.discover_functions(Path("/nonexistent/file.js"))
functions = js_support.discover_functions("", Path("/nonexistent/file.js"))
assert functions == []
def test_discover_function_expression(self, js_support):
@ -277,7 +277,7 @@ export const add = function(a, b) {
""")
f.flush()
functions = js_support.discover_functions(Path(f.name))
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
assert len(functions) == 1
assert functions[0].function_name == "add"
@ -296,7 +296,7 @@ export function named() {
""")
f.flush()
functions = js_support.discover_functions(Path(f.name))
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
# Only the named function should be discovered
assert len(functions) == 1
@ -507,7 +507,7 @@ export function main(a) {
file_path = Path(f.name)
# First discover functions to get accurate line numbers
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
main_func = next(f for f in functions if f.function_name == "main")
context = js_support.extract_code_context(main_func, file_path.parent, file_path.parent)
@ -535,7 +535,7 @@ class TestIntegration:
file_path = Path(f.name)
# Discover
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
assert len(functions) == 1
func = functions[0]
assert func.function_name == "fibonacci"
@ -584,7 +584,7 @@ export function standalone() {
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
# Should find 4 functions
assert len(functions) == 4
@ -623,7 +623,7 @@ export default Button;
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
# Should find both components
names = {f.function_name for f in functions}
@ -653,7 +653,7 @@ describe('Math functions', () => {
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -687,7 +687,7 @@ class TestClassMethodExtraction:
file_path = Path(f.name)
# Discover the method
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
add_method = next(f for f in functions if f.function_name == "add")
# Extract code context
@ -725,7 +725,7 @@ export class Calculator {
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
add_method = next(f for f in functions if f.function_name == "add")
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
@ -763,7 +763,7 @@ export class Calculator {
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
fib_method = next(f for f in functions if f.function_name == "fibonacci")
context = js_support.extract_code_context(fib_method, file_path.parent, file_path.parent)
@ -802,7 +802,7 @@ export class Calculator {
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
add_method = next((f for f in functions if f.function_name == "add"), None)
if add_method:
@ -832,7 +832,7 @@ export class Calculator {
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
fetch_method = next(f for f in functions if f.function_name == "fetchData")
context = js_support.extract_code_context(fetch_method, file_path.parent, file_path.parent)
@ -865,7 +865,7 @@ export class Calculator {
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
add_method = next((f for f in functions if f.function_name == "add"), None)
if add_method:
@ -894,7 +894,7 @@ export class Calculator {
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
method = next(f for f in functions if f.function_name == "simpleMethod")
context = js_support.extract_code_context(method, file_path.parent, file_path.parent)
@ -1079,7 +1079,7 @@ class TestClassMethodEdgeCases:
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
# Should find constructor and increment
names = {f.function_name for f in functions}
@ -1109,7 +1109,7 @@ class TestClassMethodEdgeCases:
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
# Should find at least greet
names = {f.function_name for f in functions}
@ -1137,7 +1137,7 @@ export class Dog extends Animal {
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
# Find Dog's fetch method
fetch_method = next((f for f in functions if f.function_name == "fetch" and f.class_name == "Dog"), None)
@ -1172,7 +1172,7 @@ export class Dog extends Animal {
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
# Should at least find publicMethod
names = {f.function_name for f in functions}
@ -1192,7 +1192,7 @@ module.exports = { Calculator };
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
add_method = next(f for f in functions if f.function_name == "add")
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
@ -1212,7 +1212,7 @@ module.exports = { Calculator };
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
# Find the add method
add_method = next((f for f in functions if f.function_name == "add"), None)
@ -1265,7 +1265,7 @@ module.exports = { Counter };
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
increment_func = next(fn for fn in functions if fn.function_name == "increment")
# Step 1: Extract code context (includes constructor for AI context)
@ -1362,7 +1362,7 @@ export class User {
f.flush()
file_path = Path(f.name)
functions = ts_support.discover_functions(file_path)
functions = ts_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
get_name_func = next(fn for fn in functions if fn.function_name == "getName")
# Step 1: Extract code context (includes fields and constructor)
@ -1462,7 +1462,7 @@ export class Calculator {
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
add_func = next(fn for fn in functions if fn.function_name == "add")
# Extract context for add
@ -1546,7 +1546,7 @@ export class MathUtils {
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
add_func = next(fn for fn in functions if fn.function_name == "add")
# Extract context

View file

@ -53,7 +53,7 @@ describe('add function', () => {
""")
# Discover functions first
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
assert len(functions) == 1
# Discover tests
@ -90,7 +90,7 @@ describe('multiply', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -124,7 +124,7 @@ test('formats date correctly', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -170,7 +170,7 @@ describe('String Utils', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -208,7 +208,7 @@ describe('sum function', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -242,7 +242,7 @@ test('subtract two numbers', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -270,7 +270,7 @@ test('greets by name', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -316,7 +316,7 @@ describe('Calculator class', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# Should find tests for class methods
@ -363,7 +363,7 @@ describe('clamp', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -399,7 +399,7 @@ describe('async utilities', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -436,7 +436,7 @@ describe('Button component', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# JSX tests should be discovered
@ -466,7 +466,7 @@ test('other test', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# Should not find tests for our function
@ -502,7 +502,7 @@ describe('validators', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# Should find tests for isEmail
@ -546,7 +546,7 @@ test('helper2 returns 2', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -574,7 +574,7 @@ test(`formatNumber with decimal`, () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# May or may not find depending on template literal handling
@ -605,7 +605,7 @@ describe('transform', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# Should still find tests since original name is imported
@ -626,7 +626,7 @@ it('third test', () => {});
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -651,7 +651,7 @@ describe('Suite B', () => {
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -675,7 +675,7 @@ describe('Outer', () => {
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -699,7 +699,7 @@ describe.skip('skipped describe', () => {
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -720,7 +720,7 @@ describe.only('only describe', () => {
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -738,7 +738,7 @@ describe('describe single', () => {});
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -757,7 +757,7 @@ describe("describe double", () => {});
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -773,7 +773,7 @@ describe("describe double", () => {});
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -806,7 +806,7 @@ test('funcA works', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# funcA should have tests
@ -833,7 +833,7 @@ test('funcX works', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# funcX should have tests
@ -859,7 +859,7 @@ test('mainFunc works', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -896,7 +896,7 @@ test('block commented', () => {
*/
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -921,7 +921,7 @@ test('broken test' { // Missing arrow function
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
# Should not crash
tests = js_support.discover_tests(tmpdir, functions)
assert isinstance(tests, dict)
@ -949,7 +949,7 @@ describe('conflict tests', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# Should still work despite naming conflicts
@ -966,7 +966,7 @@ export function lonelyFunc() { return 'alone'; }
module.exports = { lonelyFunc };
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# Should return empty dict, not crash
@ -1001,7 +1001,7 @@ test('funcA works', () => {
});
""")
functions_a = js_support.discover_functions(file_a)
functions_a = js_support.discover_functions(file_a.read_text(encoding="utf-8"), file_a)
tests = js_support.discover_tests(tmpdir, functions_a)
# Should handle circular imports gracefully
@ -1047,7 +1047,7 @@ test.each([
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -1073,7 +1073,7 @@ describe.each([
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -1098,7 +1098,7 @@ describe('Math operations', () => {
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -1174,7 +1174,7 @@ describe('formatName', () => {
""")
# Discover functions
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
assert len(functions) == 3
# Discover tests
@ -1242,7 +1242,7 @@ describe('Database', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -1280,7 +1280,7 @@ test('funcA works', () => {
""")
# Discover functions from moduleB
functions_b = js_support.discover_functions(source_b)
functions_b = js_support.discover_functions(source_b.read_text(encoding="utf-8"), source_b)
tests = js_support.discover_tests(tmpdir, functions_b)
# funcB should not have any tests since test file doesn't import it
@ -1312,7 +1312,7 @@ test('funcOne returns 1', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# Check that tests were found
@ -1340,7 +1340,7 @@ test('mentions targetFunc in string', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# Current implementation may still match on string occurrence
@ -1367,7 +1367,7 @@ test('calculate doubles', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# Should find tests since 'calculate' appears in source
@ -1399,7 +1399,7 @@ describe('MyClass', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# Should find tests for class methods
@ -1432,7 +1432,7 @@ test('deepHelper works', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -1456,7 +1456,7 @@ testCases.forEach(name => {
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -1484,7 +1484,7 @@ describe('conditional tests', () => {
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -1508,7 +1508,7 @@ test('slow test', () => {
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -1531,7 +1531,7 @@ test.todo('also needs implementation');
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -1554,7 +1554,7 @@ test.concurrent('concurrent test 2', async () => {
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -1597,7 +1597,7 @@ describe('subtractNumbers', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# All three functions should be discovered
@ -1628,7 +1628,7 @@ describe('Unrelated name', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# Should still find tests
@ -1653,7 +1653,7 @@ describe('Array', function() {
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -1684,7 +1684,7 @@ describe('User', () => {
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -1712,7 +1712,7 @@ export class Calculator {
module.exports = { Calculator };
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
# Check qualified names include class
add_func = next((f for f in functions if f.function_name == "add"), None)
@ -1737,7 +1737,7 @@ export class Outer {
module.exports = { Outer };
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
# Should find at least the Outer class method
assert any(f.class_name == "Outer" for f in functions)

View file

@ -728,3 +728,370 @@ class TestBundlerModuleResolutionFix:
# Verify codeflash configs were NOT created
assert not (tmpdir_path / "jest.codeflash.config.js").exists()
assert not (tmpdir_path / "tsconfig.codeflash.json").exists()
class TestBundledJestReporter:
"""Tests for the bundled codeflash/jest-reporter.
Verifies that:
1. The reporter JS file exists in the runtime package
2. Jest commands reference 'codeflash/jest-reporter' (not jest-junit)
3. The reporter produces valid JUnit XML
4. The CODEFLASH_JEST_REPORTER constant is correct
"""
def test_reporter_js_file_exists(self):
"""The jest-reporter.js file must exist in the runtime directory."""
reporter_path = Path(__file__).resolve().parents[2] / "packages" / "codeflash" / "runtime" / "jest-reporter.js"
assert reporter_path.exists(), f"jest-reporter.js not found at {reporter_path}"
def test_reporter_constant_value(self):
"""CODEFLASH_JEST_REPORTER should be 'codeflash/jest-reporter'."""
from codeflash.languages.javascript.test_runner import CODEFLASH_JEST_REPORTER
assert CODEFLASH_JEST_REPORTER == "codeflash/jest-reporter"
def test_behavioral_command_uses_bundled_reporter(self):
"""run_jest_behavioral_tests should use codeflash/jest-reporter in --reporters flag."""
from codeflash.languages.javascript.test_runner import run_jest_behavioral_tests
from codeflash.models.models import TestFile, TestFiles
from codeflash.models.test_type import TestType
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
(tmpdir_path / "package.json").write_text('{"name": "test"}')
test_dir = tmpdir_path / "test"
test_dir.mkdir()
test_file = test_dir / "test_func.test.js"
test_file.write_text("// test")
mock_test_files = TestFiles(
test_files=[
TestFile(
original_file_path=test_file,
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
),
]
)
with patch("subprocess.run") as mock_run:
mock_result = MagicMock()
mock_result.stdout = ""
mock_result.stderr = ""
mock_result.returncode = 1
mock_run.return_value = mock_result
try:
run_jest_behavioral_tests(
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
project_root=tmpdir_path,
)
except Exception:
pass
if mock_run.called:
cmd = mock_run.call_args[0][0]
reporter_args = [a for a in cmd if "--reporters=" in a and "jest-reporter" in a]
assert len(reporter_args) == 1, f"Expected exactly one codeflash/jest-reporter flag, got: {reporter_args}"
assert reporter_args[0] == "--reporters=codeflash/jest-reporter"
# Must NOT reference jest-junit
jest_junit_args = [a for a in cmd if "jest-junit" in a]
assert len(jest_junit_args) == 0, f"Should not reference jest-junit: {jest_junit_args}"
def test_benchmarking_command_uses_bundled_reporter(self):
"""run_jest_benchmarking_tests should use codeflash/jest-reporter."""
from codeflash.languages.javascript.test_runner import run_jest_benchmarking_tests
from codeflash.models.models import TestFile, TestFiles
from codeflash.models.test_type import TestType
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
(tmpdir_path / "package.json").write_text('{"name": "test"}')
test_dir = tmpdir_path / "test"
test_dir.mkdir()
test_file = test_dir / "test_func__perf.test.js"
test_file.write_text("// test")
mock_test_files = TestFiles(
test_files=[
TestFile(
original_file_path=test_file,
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
),
]
)
with patch("subprocess.run") as mock_run:
mock_result = MagicMock()
mock_result.stdout = ""
mock_result.stderr = ""
mock_result.returncode = 1
mock_run.return_value = mock_result
try:
run_jest_benchmarking_tests(
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
project_root=tmpdir_path,
)
except Exception:
pass
if mock_run.called:
cmd = mock_run.call_args[0][0]
reporter_args = [a for a in cmd if "--reporters=codeflash/jest-reporter" in a]
assert len(reporter_args) == 1
def test_line_profile_command_uses_bundled_reporter(self):
"""run_jest_line_profile_tests should use codeflash/jest-reporter."""
from codeflash.languages.javascript.test_runner import run_jest_line_profile_tests
from codeflash.models.models import TestFile, TestFiles
from codeflash.models.test_type import TestType
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
(tmpdir_path / "package.json").write_text('{"name": "test"}')
test_dir = tmpdir_path / "test"
test_dir.mkdir()
test_file = test_dir / "test_func__line.test.js"
test_file.write_text("// test")
mock_test_files = TestFiles(
test_files=[
TestFile(
original_file_path=test_file,
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
),
]
)
with patch("subprocess.run") as mock_run:
mock_result = MagicMock()
mock_result.stdout = ""
mock_result.stderr = ""
mock_result.returncode = 1
mock_run.return_value = mock_result
try:
run_jest_line_profile_tests(
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
project_root=tmpdir_path,
)
except Exception:
pass
if mock_run.called:
cmd = mock_run.call_args[0][0]
reporter_args = [a for a in cmd if "--reporters=codeflash/jest-reporter" in a]
assert len(reporter_args) == 1
def test_reporter_produces_valid_junit_xml(self):
"""The reporter JS should produce JUnit XML parseable by junitparser."""
import subprocess
reporter_path = Path(__file__).resolve().parents[2] / "packages" / "codeflash" / "runtime" / "jest-reporter.js"
with tempfile.TemporaryDirectory() as tmpdir:
output_file = Path(tmpdir) / "results.xml"
# Create a Node.js script that exercises the reporter with mock data
test_script = Path(tmpdir) / "test_reporter.js"
test_script.write_text(f"""
// Set env vars BEFORE requiring reporter (matches real Jest behavior)
process.env.JEST_JUNIT_OUTPUT_FILE = '{output_file}';
process.env.JEST_JUNIT_CLASSNAME = '{{filepath}}';
process.env.JEST_JUNIT_SUITE_NAME = '{{filepath}}';
process.env.JEST_JUNIT_ADD_FILE_ATTRIBUTE = 'true';
process.env.JEST_JUNIT_INCLUDE_CONSOLE_OUTPUT = 'true';
const Reporter = require('{reporter_path}');
// Mock Jest globalConfig
const globalConfig = {{ rootDir: '/tmp/project' }};
const reporter = new Reporter(globalConfig, {{}});
// Mock test results (matches Jest's aggregatedResults structure)
const results = {{
testResults: [
{{
testFilePath: '/tmp/project/test/math.test.js',
displayName: 'math tests',
console: [{{ type: 'log', message: 'CODEFLASH_START test1' }}],
testResults: [
{{
fullName: 'math > adds numbers',
title: 'adds numbers',
status: 'passed',
duration: 12,
}},
{{
fullName: 'math > handles failure',
title: 'handles failure',
status: 'failed',
duration: 5,
failureMessages: ['Expected 4 but got 5'],
}},
{{
fullName: 'math > skipped test',
title: 'skipped test',
status: 'pending',
duration: 0,
}},
],
}},
],
}};
// Simulate onTestFileResult for console capture
reporter.onTestFileResult(null, results.testResults[0], null);
// Simulate onRunComplete
reporter.onRunComplete([], results);
console.log('OK');
""")
result = subprocess.run(
["node", str(test_script)],
capture_output=True,
text=True,
timeout=10,
)
assert result.returncode == 0, f"Reporter script failed: {result.stderr}"
assert output_file.exists(), "Reporter did not create output file"
xml_content = output_file.read_text()
# Verify basic XML structure
assert '<?xml version="1.0"' in xml_content
assert "<testsuites" in xml_content
assert "<testsuite" in xml_content
assert "<testcase" in xml_content
# Verify classname uses filepath template
assert 'classname="/tmp/project/test/math.test.js"' in xml_content
# Verify file attribute is present
assert 'file="/tmp/project/test/math.test.js"' in xml_content
# Verify failure element
assert "<failure" in xml_content
assert "Expected 4 but got 5" in xml_content
# Verify skipped element
assert "<skipped/>" in xml_content
# Verify system-out with console output
assert "<system-out>" in xml_content
assert "CODEFLASH_START" in xml_content
# Verify it's parseable by junitparser (our actual parser)
from junitparser import JUnitXml
parsed = JUnitXml.fromfile(str(output_file))
suites = list(parsed)
assert len(suites) == 1
testcases = list(suites[0])
assert len(testcases) == 3
def test_reporter_export_in_package_json(self):
"""package.json should export codeflash/jest-reporter."""
import json
pkg_path = Path(__file__).resolve().parents[2] / "packages" / "codeflash" / "package.json"
with pkg_path.open() as f:
pkg = json.load(f)
exports = pkg.get("exports", {})
assert "./jest-reporter" in exports, "Missing ./jest-reporter export in package.json"
assert exports["./jest-reporter"]["require"] == "./runtime/jest-reporter.js"
class TestUnsupportedFrameworkError:
"""Tests for clear error on unsupported test frameworks."""
def test_unknown_framework_raises_error_behavioral(self):
"""run_behavioral_tests should raise NotImplementedError for unknown frameworks."""
from codeflash.languages.javascript.support import JavaScriptSupport
support = JavaScriptSupport()
with pytest.raises(NotImplementedError, match="not yet supported"):
support.run_behavioral_tests(
test_paths=MagicMock(),
test_env={},
cwd=Path("."),
test_framework="tap",
)
def test_unknown_framework_raises_error_benchmarking(self):
"""run_benchmarking_tests should raise NotImplementedError for unknown frameworks."""
from codeflash.languages.javascript.support import JavaScriptSupport
support = JavaScriptSupport()
with pytest.raises(NotImplementedError, match="not yet supported"):
support.run_benchmarking_tests(
test_paths=MagicMock(),
test_env={},
cwd=Path("."),
test_framework="tap",
)
def test_unknown_framework_raises_error_line_profile(self):
"""run_line_profile_tests should raise NotImplementedError for unknown frameworks."""
from codeflash.languages.javascript.support import JavaScriptSupport
support = JavaScriptSupport()
with pytest.raises(NotImplementedError, match="not yet supported"):
support.run_line_profile_tests(
test_paths=MagicMock(),
test_env={},
cwd=Path("."),
test_framework="tap",
)
def test_jest_framework_does_not_raise_not_implemented(self):
"""jest framework should NOT raise NotImplementedError."""
from codeflash.languages.javascript.support import JavaScriptSupport
support = JavaScriptSupport()
try:
support.run_behavioral_tests(
test_paths=MagicMock(),
test_env={},
cwd=Path("."),
test_framework="jest",
)
except NotImplementedError:
pytest.fail("jest framework should not raise NotImplementedError")
except Exception:
pass # Other exceptions are fine — Jest isn't installed in test env
def test_mocha_framework_does_not_raise_not_implemented(self):
"""mocha framework should NOT raise NotImplementedError."""
from codeflash.languages.javascript.support import JavaScriptSupport
support = JavaScriptSupport()
try:
support.run_behavioral_tests(
test_paths=MagicMock(),
test_env={},
cwd=Path("."),
test_framework="mocha",
)
except NotImplementedError:
pytest.fail("mocha framework should not raise NotImplementedError")
except Exception:
pass # Other exceptions are fine — Mocha isn't installed in test env

View file

@ -13,7 +13,7 @@ from codeflash.languages.base import Language
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
from codeflash.languages.registry import get_language_support
from codeflash.models.models import FunctionParent
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
FIXTURES_DIR = Path(__file__).parent / "fixtures"
@ -37,7 +37,7 @@ class TestCodeExtractorCJS:
def test_discover_class_methods(self, js_support, cjs_project):
"""Test that class methods are discovered correctly."""
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
method_names = {f.function_name for f in functions}
@ -47,17 +47,19 @@ class TestCodeExtractorCJS:
def test_class_method_has_correct_parent(self, js_support, cjs_project):
"""Test parent class information for methods."""
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
for func in functions:
# All methods should belong to Calculator class
assert func.is_method is True, f"{func.function_name} should be a method"
assert func.class_name == "Calculator", f"{func.function_name} should belong to Calculator, got {func.class_name}"
assert func.class_name == "Calculator", (
f"{func.function_name} should belong to Calculator, got {func.class_name}"
)
def test_extract_permutation_code(self, js_support, cjs_project):
"""Test permutation method code extraction."""
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
permutation_func = next(f for f in functions if f.function_name == "permutation")
@ -93,7 +95,7 @@ class Calculator {
def test_extract_context_includes_direct_helpers(self, js_support, cjs_project):
"""Test that direct helper functions are included in context."""
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
permutation_func = next(f for f in functions if f.function_name == "permutation")
@ -129,7 +131,7 @@ export function factorial(n) {
def test_extract_compound_interest_code(self, js_support, cjs_project):
"""Test calculateCompoundInterest code extraction."""
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
@ -175,7 +177,7 @@ class Calculator {
def test_extract_compound_interest_helpers(self, js_support, cjs_project):
"""Test helper extraction for calculateCompoundInterest."""
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
@ -235,7 +237,7 @@ export function validateInput(value, name) {
def test_extract_context_includes_imports(self, js_support, cjs_project):
"""Test import statement extraction."""
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
@ -256,7 +258,7 @@ export function validateInput(value, name) {
def test_extract_static_method(self, js_support, cjs_project):
"""Test static method extraction (quickAdd)."""
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
quick_add_func = next(f for f in functions if f.function_name == "quickAdd")
@ -315,7 +317,7 @@ class TestCodeExtractorESM:
def test_discover_esm_methods(self, js_support, esm_project):
"""Test method discovery in ESM project."""
calculator_file = esm_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
method_names = {f.function_name for f in functions}
@ -326,7 +328,7 @@ class TestCodeExtractorESM:
def test_esm_permutation_extraction(self, js_support, esm_project):
"""Test permutation method extraction in ESM."""
calculator_file = esm_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
permutation_func = next(f for f in functions if f.function_name == "permutation")
@ -376,7 +378,7 @@ export function factorial(n) {
def test_esm_compound_interest_extraction(self, js_support, esm_project):
"""Test calculateCompoundInterest extraction in ESM with import syntax."""
calculator_file = esm_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
@ -502,7 +504,7 @@ class TestCodeExtractorTypeScript:
def test_discover_ts_methods(self, ts_support, ts_project):
"""Test method discovery in TypeScript."""
calculator_file = ts_project / "calculator.ts"
functions = ts_support.discover_functions(calculator_file)
functions = ts_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
method_names = {f.function_name for f in functions}
@ -513,7 +515,7 @@ class TestCodeExtractorTypeScript:
def test_ts_permutation_extraction(self, ts_support, ts_project):
"""Test permutation method extraction in TypeScript."""
calculator_file = ts_project / "calculator.ts"
functions = ts_support.discover_functions(calculator_file)
functions = ts_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
permutation_func = next(f for f in functions if f.function_name == "permutation")
@ -566,7 +568,7 @@ export function factorial(n: number): number {
def test_ts_compound_interest_extraction(self, ts_support, ts_project):
"""Test calculateCompoundInterest extraction in TypeScript."""
calculator_file = ts_project / "calculator.ts"
functions = ts_support.discover_functions(calculator_file)
functions = ts_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
@ -676,7 +678,7 @@ module.exports = { standalone };
test_file = tmp_path / "standalone.js"
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
func = next(f for f in functions if f.function_name == "standalone")
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
@ -709,7 +711,7 @@ module.exports = { processArray };
test_file = tmp_path / "processor.js"
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
func = next(f for f in functions if f.function_name == "processArray")
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
@ -744,7 +746,7 @@ module.exports = { fibonacci };
test_file = tmp_path / "recursive.js"
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
func = next(f for f in functions if f.function_name == "fibonacci")
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
@ -777,7 +779,7 @@ module.exports = { processValue };
test_file = tmp_path / "arrow.js"
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
func = next(f for f in functions if f.function_name == "processValue")
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
@ -835,7 +837,7 @@ module.exports = { Counter };
test_file = tmp_path / "counter.js"
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
increment_func = next(f for f in functions if f.function_name == "increment")
context = js_support.extract_code_context(function=increment_func, project_root=tmp_path, module_root=tmp_path)
@ -874,7 +876,7 @@ module.exports = { MathUtils };
test_file = tmp_path / "math_utils.js"
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
add_func = next(f for f in functions if f.function_name == "add")
context = js_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path)
@ -910,7 +912,7 @@ export class User {
test_file = tmp_path / "user.ts"
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
get_name_func = next(f for f in functions if f.function_name == "getName")
context = ts_support.extract_code_context(function=get_name_func, project_root=tmp_path, module_root=tmp_path)
@ -949,7 +951,7 @@ export class Config {
test_file = tmp_path / "config.ts"
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
get_url_func = next(f for f in functions if f.function_name == "getUrl")
context = ts_support.extract_code_context(function=get_url_func, project_root=tmp_path, module_root=tmp_path)
@ -990,7 +992,7 @@ module.exports = { Logger };
test_file = tmp_path / "logger.js"
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
get_prefix_func = next(f for f in functions if f.function_name == "getPrefix")
context = js_support.extract_code_context(function=get_prefix_func, project_root=tmp_path, module_root=tmp_path)
@ -1032,7 +1034,7 @@ module.exports = { Factory };
test_file = tmp_path / "factory.js"
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
create_func = next(f for f in functions if f.function_name == "create")
context = js_support.extract_code_context(function=create_func, project_root=tmp_path, module_root=tmp_path)
@ -1074,7 +1076,7 @@ class TestCodeExtractorIntegration:
js_support = get_language_support("javascript")
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
target = next(f for f in functions if f.function_name == "permutation")
parents = [FunctionParent(name=p.name, type=p.type) for p in target.parents]
@ -1099,7 +1101,7 @@ class TestCodeExtractorIntegration:
pytest_cmd="jest",
)
func_optimizer = FunctionOptimizer(
func_optimizer = JavaScriptFunctionOptimizer(
function_to_optimize=func, test_cfg=test_config, aiservice_client=MagicMock()
)
result = func_optimizer.get_code_optimization_context()
@ -1182,7 +1184,7 @@ export function distance(p1: Point, p2: Point): number {
test_file = tmp_path / "geometry.ts"
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
distance_func = next(f for f in functions if f.function_name == "distance")
context = ts_support.extract_code_context(function=distance_func, project_root=tmp_path, module_root=tmp_path)
@ -1224,7 +1226,7 @@ export function processStatus(status: Status): string {
test_file = tmp_path / "status.ts"
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
process_func = next(f for f in functions if f.function_name == "processStatus")
context = ts_support.extract_code_context(function=process_func, project_root=tmp_path, module_root=tmp_path)
@ -1259,7 +1261,7 @@ export function compute(x: number): Result<number> {
test_file = tmp_path / "compute.ts"
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
compute_func = next(f for f in functions if f.function_name == "compute")
context = ts_support.extract_code_context(function=compute_func, project_root=tmp_path, module_root=tmp_path)
@ -1301,7 +1303,7 @@ export class Service {
test_file = tmp_path / "service.ts"
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
get_timeout_func = next(f for f in functions if f.function_name == "getTimeout")
context = ts_support.extract_code_context(
@ -1332,7 +1334,7 @@ export function add(a: number, b: number): number {
test_file = tmp_path / "add.ts"
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
add_func = next(f for f in functions if f.function_name == "add")
context = ts_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path)
@ -1363,7 +1365,7 @@ export function createRect(origin: Point, size: Size): { origin: Point; size: Si
test_file = tmp_path / "rect.ts"
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
create_rect_func = next(f for f in functions if f.function_name == "createRect")
context = ts_support.extract_code_context(
@ -1409,7 +1411,7 @@ export function calculateDistance(p1: Point, p2: Point, config: CalculationConfi
}
""")
functions = ts_support.discover_functions(geometry_file)
functions = ts_support.discover_functions(geometry_file.read_text(encoding="utf-8"), geometry_file)
calc_distance_func = next(f for f in functions if f.function_name == "calculateDistance")
context = ts_support.extract_code_context(
@ -1460,7 +1462,7 @@ export function greetUser(user: User): string {
test_file = tmp_path / "user.ts"
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
greet_func = next(f for f in functions if f.function_name == "greetUser")
context = ts_support.extract_code_context(function=greet_func, project_root=tmp_path, module_root=tmp_path)

View file

@ -7,6 +7,7 @@ These tests verify that code replacement correctly handles:
- ES Modules (import/export) syntax
- TypeScript import handling
"""
from __future__ import annotations
import shutil
@ -14,8 +15,8 @@ from pathlib import Path
import pytest
from codeflash.languages.python.static_analysis.code_replacer import replace_function_definitions_for_language
from codeflash.languages.base import Language
from codeflash.languages.code_replacer import replace_function_definitions_for_language
from codeflash.languages.current import set_current_language
from codeflash.languages.javascript.module_system import (
ModuleSystem,
@ -25,7 +26,6 @@ from codeflash.languages.javascript.module_system import (
ensure_module_system_compatibility,
get_import_statement,
)
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
from codeflash.models.models import CodeStringsMarkdown
@ -50,7 +50,6 @@ def temp_project(tmp_path):
return project_root
FIXTURES_DIR = Path(__file__).parent / "fixtures"
@ -308,7 +307,9 @@ class TestTsJestSkipsConversion:
When ts-jest is installed, it handles module interoperability internally,
so we skip conversion to avoid breaking valid imports.
"""
def __init__(self):
@pytest.fixture(autouse=True)
def _set_language(self):
set_current_language(Language.TYPESCRIPT)
def test_commonjs_not_converted_when_ts_jest_installed(self, tmp_path):
@ -751,6 +752,7 @@ class TestIntegrationWithFixtures:
f"import statements should be converted to require.\nFound import lines: {import_lines}"
)
class TestSimpleFunctionReplacement:
"""Tests for simple function body replacement with strict assertions."""
@ -764,7 +766,8 @@ export function add(a, b) {
file_path = temp_project / "math.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
# Optimized version with different body
@ -800,7 +803,8 @@ export function processData(data) {
file_path = temp_project / "processor.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
# Optimized version using map
@ -839,7 +843,8 @@ module.exports = { targetFunction, otherFunction };
file_path = temp_project / "module.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
target_func = next(f for f in functions if f.function_name == "targetFunction")
optimized_code = """\
@ -891,7 +896,8 @@ export class Calculator {
file_path = temp_project / "calculator.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
add_method = next(f for f in functions if f.function_name == "add")
# Optimized version provided in class context
@ -954,7 +960,8 @@ export class DataProcessor {
file_path = temp_project / "processor.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
process_method = next(f for f in functions if f.function_name == "process")
optimized_code = """\
@ -1016,7 +1023,8 @@ export function add(a, b) {
file_path = temp_project / "math.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
optimized_code = """\
@ -1070,7 +1078,8 @@ export class Cache {
file_path = temp_project / "cache.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
get_method = next(f for f in functions if f.function_name == "get")
optimized_code = """\
@ -1131,7 +1140,8 @@ export async function fetchData(url) {
file_path = temp_project / "api.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
optimized_code = """\
@ -1172,7 +1182,8 @@ export class ApiClient {
file_path = temp_project / "client.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
get_method = next(f for f in functions if f.function_name == "get")
optimized_code = """\
@ -1223,7 +1234,8 @@ export function* range(start, end) {
file_path = temp_project / "generators.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
optimized_code = """\
@ -1262,7 +1274,8 @@ export function processArray(items: number[]): number {
file_path = temp_project / "processor.ts"
file_path.write_text(original_source, encoding="utf-8")
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
func = functions[0]
optimized_code = """\
@ -1303,7 +1316,8 @@ export class Container<T> {
file_path = temp_project / "container.ts"
file_path.write_text(original_source, encoding="utf-8")
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
get_all_method = next(f for f in functions if f.function_name == "getAll")
optimized_code = """\
@ -1356,7 +1370,8 @@ export function createUser(name: string, email: string): User {
file_path = temp_project / "user.ts"
file_path.write_text(original_source, encoding="utf-8")
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
func = next(f for f in functions if f.function_name == "createUser")
optimized_code = """\
@ -1411,7 +1426,8 @@ export function processItems(items) {
file_path = temp_project / "processor.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
process_func = next(f for f in functions if f.function_name == "processItems")
optimized_code = """\
@ -1458,7 +1474,8 @@ export class MathUtils {
file_path.write_text(original_source, encoding="utf-8")
# First replacement: sum method
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
sum_method = next(f for f in functions if f.function_name == "sum")
optimized_sum = """\
@ -1505,7 +1522,8 @@ export function processConfig({ server: { host, port }, database: { url, poolSiz
file_path = temp_project / "config.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
optimized_code = """\
@ -1544,7 +1562,8 @@ export function minimal() {
file_path = temp_project / "minimal.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
optimized_code = """\
@ -1571,7 +1590,8 @@ export function identity(x) { return x; }
file_path = temp_project / "utils.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
optimized_code = """\
@ -1598,7 +1618,8 @@ export function formatMessage(name) {
file_path = temp_project / "formatter.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
optimized_code = """\
@ -1633,7 +1654,8 @@ export function validateEmail(email) {
file_path = temp_project / "validator.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
optimized_code = """\
@ -1676,7 +1698,8 @@ module.exports = { main, helper };
file_path = temp_project / "module.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
main_func = next(f for f in functions if f.function_name == "main")
optimized_code = """\
@ -1719,7 +1742,8 @@ export function main(data) {
file_path = temp_project / "module.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
main_func = next(f for f in functions if f.function_name == "main")
optimized_code = """\
@ -1750,20 +1774,16 @@ class TestSyntaxValidation:
"""Test that various replacements all produce valid JavaScript."""
test_cases = [
# (original, optimized, description)
(
"export function f(x) { return x + 1; }",
"export function f(x) { return ++x; }",
"increment replacement"
),
("export function f(x) { return x + 1; }", "export function f(x) { return ++x; }", "increment replacement"),
(
"export function f(arr) { return arr.length > 0; }",
"export function f(arr) { return !!arr.length; }",
"boolean conversion"
"boolean conversion",
),
(
"export function f(a, b) { if (a) { return a; } return b; }",
"export function f(a, b) { return a || b; }",
"logical OR replacement"
"logical OR replacement",
),
]
@ -1771,7 +1791,8 @@ class TestSyntaxValidation:
file_path = temp_project / f"test_{i}.js"
file_path.write_text(original, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
result = js_support.replace_function(original, func, optimized)
@ -1875,7 +1896,8 @@ export class DataProcessor<T> {
target_func = "findDuplicates"
parent_class = "DataProcessor"
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
# find function
target_func_info = None
for func in functions:
@ -1920,11 +1942,15 @@ class DataProcessor<T> {
```
"""
code_markdown = CodeStringsMarkdown.parse_markdown_code(new_code)
replaced = replace_function_definitions_for_language([f"{parent_class}.{target_func}"], code_markdown, file_path, temp_project)
replaced = replace_function_definitions_for_language(
[f"{parent_class}.{target_func}"], code_markdown, file_path, temp_project, lang_support=ts_support
)
assert replaced
new_code = file_path.read_text()
assert new_code == """/**
assert (
new_code
== """/**
* DataProcessor class - demonstrates class method optimization in TypeScript.
* Contains intentionally inefficient implementations for optimization testing.
*/
@ -2015,7 +2041,7 @@ export class DataProcessor<T> {
}
}
"""
)
class TestNewVariableFromOptimizedCode:
@ -2030,9 +2056,9 @@ class TestNewVariableFromOptimizedCode:
1. Add the new variable after the constant it references
2. Replace the function with the optimized version
"""
from codeflash.models.models import CodeStringsMarkdown, CodeString
from codeflash.models.models import CodeString, CodeStringsMarkdown
original_source = '''\
original_source = """\
const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([
"1234",
]);
@ -2040,43 +2066,34 @@ const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([
export function isCodeflashEmployee(userId: string): boolean {
return CODEFLASH_EMPLOYEE_GITHUB_IDS.has(userId);
}
'''
"""
file_path = temp_project / "auth.ts"
file_path.write_text(original_source, encoding="utf-8")
# Optimized code introduces a bound method variable for performance
optimized_code = '''const _has: (id: string) => boolean = CODEFLASH_EMPLOYEE_GITHUB_IDS.has.bind(
optimized_code = """const _has: (id: string) => boolean = CODEFLASH_EMPLOYEE_GITHUB_IDS.has.bind(
CODEFLASH_EMPLOYEE_GITHUB_IDS
);
export function isCodeflashEmployee(userId: string): boolean {
return _has(userId);
}
'''
"""
code_markdown = CodeStringsMarkdown(
code_strings=[
CodeString(
code=optimized_code,
file_path=Path("auth.ts"),
language="typescript"
)
],
language="typescript"
code_strings=[CodeString(code=optimized_code, file_path=Path("auth.ts"), language="typescript")],
language="typescript",
)
replaced = replace_function_definitions_for_language(
["isCodeflashEmployee"],
code_markdown,
file_path,
temp_project,
["isCodeflashEmployee"], code_markdown, file_path, temp_project, lang_support=ts_support
)
assert replaced
result = file_path.read_text()
# Expected result for strict equality check
expected_result = '''\
expected_result = """\
const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([
"1234",
]);
@ -2088,11 +2105,9 @@ const _has: (id: string) => boolean = CODEFLASH_EMPLOYEE_GITHUB_IDS.has.bind(
export function isCodeflashEmployee(userId: string): boolean {
return _has(userId);
}
'''
"""
assert result == expected_result, (
f"Result does not match expected output.\n"
f"Expected:\n{expected_result}\n\n"
f"Got:\n{result}"
f"Result does not match expected output.\nExpected:\n{expected_result}\n\nGot:\n{result}"
)
@ -2113,7 +2128,7 @@ class TestImportedTypeNotDuplicated:
contains the TreeNode interface definition (from read-only context),
the replacement should NOT add the interface to the original file.
"""
from codeflash.models.models import CodeStringsMarkdown, CodeString
from codeflash.models.models import CodeString, CodeStringsMarkdown
# Original source imports TreeNode
original_source = """\
@ -2163,20 +2178,13 @@ export function getNearestAbove(
code_markdown = CodeStringsMarkdown(
code_strings=[
CodeString(
code=optimized_code_with_interface,
file_path=Path("helpers.ts"),
language="typescript"
)
CodeString(code=optimized_code_with_interface, file_path=Path("helpers.ts"), language="typescript")
],
language="typescript"
language="typescript",
)
replace_function_definitions_for_language(
["getNearestAbove"],
code_markdown,
file_path,
temp_project,
["getNearestAbove"], code_markdown, file_path, temp_project, lang_support=ts_support
)
result = file_path.read_text()
@ -2203,7 +2211,7 @@ export function getNearestAbove(
def test_multiple_imported_types_not_duplicated(self, ts_support, temp_project):
"""Test that multiple imported types are not duplicated."""
from codeflash.models.models import CodeStringsMarkdown, CodeString
from codeflash.models.models import CodeString, CodeStringsMarkdown
original_source = """\
import type { TreeNode, NodeSpace } from "./constants";
@ -2235,21 +2243,12 @@ export function processNode(node: TreeNode, space: NodeSpace): number {
"""
code_markdown = CodeStringsMarkdown(
code_strings=[
CodeString(
code=optimized_code,
file_path=Path("processor.ts"),
language="typescript"
)
],
language="typescript"
code_strings=[CodeString(code=optimized_code, file_path=Path("processor.ts"), language="typescript")],
language="typescript",
)
replace_function_definitions_for_language(
["processNode"],
code_markdown,
file_path,
temp_project,
["processNode"], code_markdown, file_path, temp_project, lang_support=ts_support
)
result = file_path.read_text()

View file

@ -345,8 +345,8 @@ class TestDiscoverFunctionsParity:
py_file = write_temp_file(SIMPLE_FUNCTION.python, ".py")
js_file = write_temp_file(SIMPLE_FUNCTION.javascript, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
# Both should find exactly one function
assert len(py_funcs) == 1, f"Python found {len(py_funcs)}, expected 1"
@ -365,8 +365,8 @@ class TestDiscoverFunctionsParity:
py_file = write_temp_file(MULTIPLE_FUNCTIONS.python, ".py")
js_file = write_temp_file(MULTIPLE_FUNCTIONS.javascript, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
# Both should find 3 functions
assert len(py_funcs) == 3, f"Python found {len(py_funcs)}, expected 3"
@ -384,8 +384,8 @@ class TestDiscoverFunctionsParity:
py_file = write_temp_file(WITH_AND_WITHOUT_RETURN.python, ".py")
js_file = write_temp_file(WITH_AND_WITHOUT_RETURN.javascript, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
# Both should find only 1 function (the one with return)
assert len(py_funcs) == 1, f"Python found {len(py_funcs)}, expected 1"
@ -400,8 +400,8 @@ class TestDiscoverFunctionsParity:
py_file = write_temp_file(CLASS_METHODS.python, ".py")
js_file = write_temp_file(CLASS_METHODS.javascript, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
# Both should find 2 methods
assert len(py_funcs) == 2, f"Python found {len(py_funcs)}, expected 2"
@ -421,8 +421,8 @@ class TestDiscoverFunctionsParity:
py_file = write_temp_file(ASYNC_FUNCTIONS.python, ".py")
js_file = write_temp_file(ASYNC_FUNCTIONS.javascript, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
# Both should find 2 functions
assert len(py_funcs) == 2, f"Python found {len(py_funcs)}, expected 2"
@ -440,32 +440,23 @@ class TestDiscoverFunctionsParity:
assert js_sync.is_async is False, "JavaScript sync function should have is_async=False"
def test_nested_functions_discovery(self, python_support, js_support):
"""Both should discover nested functions with parent info."""
"""Python skips nested functions; JavaScript discovers them with parent info."""
py_file = write_temp_file(NESTED_FUNCTIONS.python, ".py")
js_file = write_temp_file(NESTED_FUNCTIONS.javascript, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
# Both should find 2 functions (outer and inner)
assert len(py_funcs) == 2, f"Python found {len(py_funcs)}, expected 2"
# Python skips nested functions — only outer is discovered
assert len(py_funcs) == 1, f"Python found {len(py_funcs)}, expected 1"
assert py_funcs[0].function_name == "outer"
# JavaScript discovers both
assert len(js_funcs) == 2, f"JavaScript found {len(js_funcs)}, expected 2"
# Check names
py_names = {f.function_name for f in py_funcs}
js_names = {f.function_name for f in js_funcs}
assert py_names == {"outer", "inner"}, f"Python found {py_names}"
assert js_names == {"outer", "inner"}, f"JavaScript found {js_names}"
# Check parent info for inner function
py_inner = next(f for f in py_funcs if f.function_name == "inner")
js_inner = next(f for f in js_funcs if f.function_name == "inner")
assert len(py_inner.parents) >= 1, "Python inner should have parent info"
assert py_inner.parents[0].name == "outer", "Python inner's parent should be outer"
# JavaScript nested function parent check
assert len(js_inner.parents) >= 1, "JavaScript inner should have parent info"
assert js_inner.parents[0].name == "outer", "JavaScript inner's parent should be outer"
@ -474,8 +465,8 @@ class TestDiscoverFunctionsParity:
py_file = write_temp_file(STATIC_METHODS.python, ".py")
js_file = write_temp_file(STATIC_METHODS.javascript, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
# Both should find 1 function
assert len(py_funcs) == 1, f"Python found {len(py_funcs)}, expected 1"
@ -492,8 +483,8 @@ class TestDiscoverFunctionsParity:
py_file = write_temp_file(COMPLEX_FILE.python, ".py")
js_file = write_temp_file(COMPLEX_FILE.javascript, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
# Both should find 4 functions
assert len(py_funcs) == 4, f"Python found {len(py_funcs)}, expected 4"
@ -524,8 +515,8 @@ class TestDiscoverFunctionsParity:
criteria = FunctionFilterCriteria(include_async=False)
py_funcs = python_support.discover_functions(py_file, criteria)
js_funcs = js_support.discover_functions(js_file, criteria)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file, criteria)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file, criteria)
# Both should find only 1 function (the sync one)
assert len(py_funcs) == 1, f"Python found {len(py_funcs)}, expected 1"
@ -542,8 +533,8 @@ class TestDiscoverFunctionsParity:
criteria = FunctionFilterCriteria(include_methods=False)
py_funcs = python_support.discover_functions(py_file, criteria)
js_funcs = js_support.discover_functions(js_file, criteria)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file, criteria)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file, criteria)
# Both should find only 1 function (standalone)
assert len(py_funcs) == 1, f"Python found {len(py_funcs)}, expected 1"
@ -554,11 +545,11 @@ class TestDiscoverFunctionsParity:
assert js_funcs[0].function_name == "standalone"
def test_nonexistent_file_returns_empty(self, python_support, js_support):
"""Both should return empty list for nonexistent files."""
py_funcs = python_support.discover_functions(Path("/nonexistent/file.py"))
js_funcs = js_support.discover_functions(Path("/nonexistent/file.js"))
"""Both languages return empty list for empty source."""
py_funcs = python_support.discover_functions("", Path("/nonexistent/file.py"))
assert py_funcs == []
js_funcs = js_support.discover_functions("", Path("/nonexistent/file.js"))
assert js_funcs == []
def test_line_numbers_captured(self, python_support, js_support):
@ -566,8 +557,8 @@ class TestDiscoverFunctionsParity:
py_file = write_temp_file(SIMPLE_FUNCTION.python, ".py")
js_file = write_temp_file(SIMPLE_FUNCTION.javascript, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
# Both should have start_line and end_line
assert py_funcs[0].starting_line is not None
@ -917,8 +908,8 @@ class TestIntegrationParity:
js_file = write_temp_file(js_original, ".js")
# Discover
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
assert len(py_funcs) == 1
assert len(js_funcs) == 1
@ -969,8 +960,8 @@ class TestFeatureGaps:
py_file = write_temp_file(CLASS_METHODS.python, ".py")
js_file = write_temp_file(CLASS_METHODS.javascript, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
for py_func in py_funcs:
# Check all expected fields are populated
@ -1003,7 +994,7 @@ export const multiply = (x, y) => x * y;
export const identity = x => x;
"""
js_file = write_temp_file(js_code, ".js")
funcs = js_support.discover_functions(js_file)
funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
# Should find all arrow functions
names = {f.function_name for f in funcs}
@ -1030,8 +1021,8 @@ export function* numberGenerator() {
py_file = write_temp_file(py_code, ".py")
js_file = write_temp_file(js_code, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
# Both should find the generator
assert len(py_funcs) == 1, f"Python found {len(py_funcs)} generators"
@ -1054,7 +1045,7 @@ def multi_decorated():
return 3
"""
py_file = write_temp_file(py_code, ".py")
funcs = python_support.discover_functions(py_file)
funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
# Should find all functions regardless of decorators
names = {f.function_name for f in funcs}
@ -1074,7 +1065,7 @@ export const namedExpr = function myFunc(x) {
};
"""
js_file = write_temp_file(js_code, ".js")
funcs = js_support.discover_functions(js_file)
funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
# Should find function expressions
names = {f.function_name for f in funcs}
@ -1094,8 +1085,8 @@ class TestEdgeCases:
py_file = write_temp_file("", ".py")
js_file = write_temp_file("", ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
assert py_funcs == []
assert js_funcs == []
@ -1119,8 +1110,8 @@ Multiline comment
py_file = write_temp_file(py_code, ".py")
js_file = write_temp_file(js_code, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
assert py_funcs == []
assert js_funcs == []
@ -1139,8 +1130,8 @@ export function greeting() {
py_file = write_temp_file(py_code, ".py")
js_file = write_temp_file(js_code, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
assert len(py_funcs) == 1
assert len(js_funcs) == 1

View file

@ -0,0 +1,502 @@
"""Tests for Mocha test runner functionality."""
import json
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from junitparser import JUnitXml
class TestMochaJsonToJunitXml:
"""Tests for converting Mocha JSON reporter output to JUnit XML."""
def test_passing_tests(self):
from codeflash.languages.javascript.mocha_runner import mocha_json_to_junit_xml
mocha_json = json.dumps(
{
"stats": {"tests": 2, "passes": 2, "failures": 0, "duration": 50},
"tests": [
{
"title": "should add numbers",
"fullTitle": "math should add numbers",
"duration": 20,
"err": {},
},
{
"title": "should subtract numbers",
"fullTitle": "math should subtract numbers",
"duration": 30,
"err": {},
},
],
"passes": [],
"failures": [],
"pending": [],
}
)
with tempfile.TemporaryDirectory() as tmpdir:
output_file = Path(tmpdir) / "results.xml"
mocha_json_to_junit_xml(mocha_json, output_file)
assert output_file.exists()
xml = JUnitXml.fromfile(str(output_file))
total_tests = sum(suite.tests for suite in xml)
assert total_tests == 2
def test_failing_tests(self):
from codeflash.languages.javascript.mocha_runner import mocha_json_to_junit_xml
mocha_json = json.dumps(
{
"stats": {"tests": 1, "passes": 0, "failures": 1, "duration": 10},
"tests": [
{
"title": "should fail",
"fullTitle": "errors should fail",
"duration": 10,
"err": {
"message": "expected 1 to equal 2",
"stack": "AssertionError: expected 1 to equal 2\n at Context.<anonymous>",
},
},
],
"passes": [],
"failures": [],
"pending": [],
}
)
with tempfile.TemporaryDirectory() as tmpdir:
output_file = Path(tmpdir) / "results.xml"
mocha_json_to_junit_xml(mocha_json, output_file)
assert output_file.exists()
xml = JUnitXml.fromfile(str(output_file))
total_failures = sum(suite.failures for suite in xml)
assert total_failures == 1
def test_pending_tests(self):
from codeflash.languages.javascript.mocha_runner import mocha_json_to_junit_xml
mocha_json = json.dumps(
{
"stats": {"tests": 1, "passes": 0, "failures": 0, "pending": 1, "duration": 0},
"tests": [
{
"title": "should be pending",
"fullTitle": "todo should be pending",
"duration": 0,
"pending": True,
"err": {},
},
],
"passes": [],
"failures": [],
"pending": [],
}
)
with tempfile.TemporaryDirectory() as tmpdir:
output_file = Path(tmpdir) / "results.xml"
mocha_json_to_junit_xml(mocha_json, output_file)
assert output_file.exists()
xml = JUnitXml.fromfile(str(output_file))
# Should parse without error and have the test
total_tests = sum(suite.tests for suite in xml)
assert total_tests == 1
def test_invalid_json_writes_empty_xml(self):
from codeflash.languages.javascript.mocha_runner import mocha_json_to_junit_xml
with tempfile.TemporaryDirectory() as tmpdir:
output_file = Path(tmpdir) / "results.xml"
mocha_json_to_junit_xml("not valid json {{{", output_file)
assert output_file.exists()
content = output_file.read_text()
assert "<testsuites" in content
def test_multiple_suites(self):
from codeflash.languages.javascript.mocha_runner import mocha_json_to_junit_xml
mocha_json = json.dumps(
{
"stats": {"tests": 3, "passes": 3, "failures": 0, "duration": 100},
"tests": [
{"title": "test1", "fullTitle": "suite A test1", "duration": 10, "err": {}},
{"title": "test2", "fullTitle": "suite A test2", "duration": 20, "err": {}},
{"title": "test3", "fullTitle": "suite B test3", "duration": 30, "err": {}},
],
"passes": [],
"failures": [],
"pending": [],
}
)
with tempfile.TemporaryDirectory() as tmpdir:
output_file = Path(tmpdir) / "results.xml"
mocha_json_to_junit_xml(mocha_json, output_file)
xml = JUnitXml.fromfile(str(output_file))
suite_names = [suite.name for suite in xml]
assert "suite A" in suite_names
assert "suite B" in suite_names
class TestExtractMochaJson:
"""Tests for extracting Mocha JSON from mixed stdout."""
def test_clean_json(self):
from codeflash.languages.javascript.mocha_runner import _extract_mocha_json
data = {"stats": {"tests": 1}, "tests": []}
result = _extract_mocha_json(json.dumps(data))
assert result is not None
assert json.loads(result)["stats"]["tests"] == 1
def test_json_with_leading_output(self):
from codeflash.languages.javascript.mocha_runner import _extract_mocha_json
stdout = 'Some console output\n{"stats": {"tests": 1}, "tests": []}'
result = _extract_mocha_json(stdout)
assert result is not None
assert json.loads(result)["stats"]["tests"] == 1
def test_json_with_codeflash_markers(self):
from codeflash.languages.javascript.mocha_runner import _extract_mocha_json
data = {"stats": {"tests": 1}, "tests": []}
stdout = f"!######START:test:module:0:test_name######!\n{json.dumps(data)}\n!######END######!"
result = _extract_mocha_json(stdout)
assert result is not None
def test_no_json_returns_none(self):
from codeflash.languages.javascript.mocha_runner import _extract_mocha_json
result = _extract_mocha_json("no json here at all")
assert result is None
class TestFindMochaProjectRoot:
"""Tests for finding Mocha project root."""
def test_finds_mocharc_yml(self):
from codeflash.languages.javascript.mocha_runner import _find_mocha_project_root
with tempfile.TemporaryDirectory() as tmpdir:
root = Path(tmpdir)
(root / ".mocharc.yml").write_text("timeout: 5000\n")
sub = root / "src" / "lib"
sub.mkdir(parents=True)
test_file = sub / "test.js"
test_file.write_text("// test")
result = _find_mocha_project_root(test_file)
assert result == root
def test_finds_mocharc_json(self):
from codeflash.languages.javascript.mocha_runner import _find_mocha_project_root
with tempfile.TemporaryDirectory() as tmpdir:
root = Path(tmpdir)
(root / ".mocharc.json").write_text("{}")
test_file = root / "test.js"
test_file.write_text("// test")
result = _find_mocha_project_root(test_file)
assert result == root
def test_falls_back_to_package_json(self):
from codeflash.languages.javascript.mocha_runner import _find_mocha_project_root
with tempfile.TemporaryDirectory() as tmpdir:
root = Path(tmpdir)
(root / "package.json").write_text('{"name": "test"}')
sub = root / "test"
sub.mkdir()
test_file = sub / "test.js"
test_file.write_text("// test")
result = _find_mocha_project_root(test_file)
assert result == root
class TestMochaBehavioralCommand:
"""Tests for building Mocha behavioral commands."""
def test_basic_command(self):
from codeflash.languages.javascript.mocha_runner import _build_mocha_behavioral_command
with tempfile.TemporaryDirectory() as tmpdir:
test_file = Path(tmpdir) / "test.js"
test_file.write_text("// test")
cmd = _build_mocha_behavioral_command(test_files=[test_file])
assert "npx" in cmd
assert "mocha" in cmd
assert "--reporter" in cmd
assert "json" in cmd
assert "--jobs" in cmd
assert "1" in cmd
assert "--exit" in cmd
def test_timeout_flag(self):
from codeflash.languages.javascript.mocha_runner import _build_mocha_behavioral_command
with tempfile.TemporaryDirectory() as tmpdir:
test_file = Path(tmpdir) / "test.js"
test_file.write_text("// test")
cmd = _build_mocha_behavioral_command(test_files=[test_file], timeout=30)
timeout_idx = cmd.index("--timeout")
assert cmd[timeout_idx + 1] == "30000"
def test_default_timeout(self):
from codeflash.languages.javascript.mocha_runner import _build_mocha_behavioral_command
with tempfile.TemporaryDirectory() as tmpdir:
test_file = Path(tmpdir) / "test.js"
test_file.write_text("// test")
cmd = _build_mocha_behavioral_command(test_files=[test_file])
timeout_idx = cmd.index("--timeout")
assert cmd[timeout_idx + 1] == "60000"
class TestMochaBenchmarkingCommand:
"""Tests for building Mocha benchmarking commands."""
def test_basic_command(self):
from codeflash.languages.javascript.mocha_runner import _build_mocha_benchmarking_command
with tempfile.TemporaryDirectory() as tmpdir:
test_file = Path(tmpdir) / "test.js"
test_file.write_text("// test")
cmd = _build_mocha_benchmarking_command(test_files=[test_file])
assert "npx" in cmd
assert "mocha" in cmd
assert "--exit" in cmd
def test_default_timeout_is_longer(self):
from codeflash.languages.javascript.mocha_runner import _build_mocha_benchmarking_command
with tempfile.TemporaryDirectory() as tmpdir:
test_file = Path(tmpdir) / "test.js"
test_file.write_text("// test")
cmd = _build_mocha_benchmarking_command(test_files=[test_file])
timeout_idx = cmd.index("--timeout")
assert cmd[timeout_idx + 1] == "120000"
class TestMochaLineProfileCommand:
"""Tests for building Mocha line profile commands."""
def test_basic_command(self):
from codeflash.languages.javascript.mocha_runner import _build_mocha_line_profile_command
with tempfile.TemporaryDirectory() as tmpdir:
test_file = Path(tmpdir) / "test.js"
test_file.write_text("// test")
cmd = _build_mocha_line_profile_command(test_files=[test_file])
assert "npx" in cmd
assert "mocha" in cmd
assert "--exit" in cmd
def test_timeout_conversion(self):
from codeflash.languages.javascript.mocha_runner import _build_mocha_line_profile_command
with tempfile.TemporaryDirectory() as tmpdir:
test_file = Path(tmpdir) / "test.js"
test_file.write_text("// test")
cmd = _build_mocha_line_profile_command(test_files=[test_file], timeout=45)
timeout_idx = cmd.index("--timeout")
assert cmd[timeout_idx + 1] == "45000"
class TestRunMochaBehavioralTests:
"""Tests for running Mocha behavioral tests with mocked subprocess."""
@patch("codeflash.languages.javascript.mocha_runner.subprocess.run")
@patch("codeflash.languages.javascript.mocha_runner._ensure_runtime_files")
def test_sets_codeflash_env_vars(self, mock_ensure, mock_run):
from codeflash.languages.javascript.mocha_runner import run_mocha_behavioral_tests
from codeflash.models.models import TestFile, TestFiles
from codeflash.models.test_type import TestType
mocha_output = json.dumps(
{"stats": {"tests": 1, "passes": 1, "failures": 0, "duration": 10}, "tests": [{"title": "t", "fullTitle": "s t", "duration": 10, "err": {}}], "passes": [], "failures": [], "pending": []}
)
mock_run.return_value = MagicMock(returncode=0, stdout=mocha_output, stderr="", args=[])
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
(tmpdir_path / "package.json").write_text('{"name": "test"}')
test_file = tmpdir_path / "test.test.js"
test_file.write_text("// test")
test_paths = TestFiles(
test_files=[
TestFile(
original_file_path=test_file,
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
)
]
)
result_file, result, cov, _ = run_mocha_behavioral_tests(
test_paths=test_paths,
test_env={},
cwd=tmpdir_path,
candidate_index=3,
)
# Verify env vars were passed
call_kwargs = mock_run.call_args
env = call_kwargs.kwargs.get("env") or call_kwargs[1].get("env", {})
assert env.get("CODEFLASH_MODE") == "behavior"
assert env.get("CODEFLASH_TEST_ITERATION") == "3"
assert env.get("CODEFLASH_RANDOM_SEED") == "42"
@patch("codeflash.languages.javascript.mocha_runner.subprocess.run")
@patch("codeflash.languages.javascript.mocha_runner._ensure_runtime_files")
def test_returns_none_coverage(self, mock_ensure, mock_run):
from codeflash.languages.javascript.mocha_runner import run_mocha_behavioral_tests
from codeflash.models.models import TestFile, TestFiles
from codeflash.models.test_type import TestType
mocha_output = json.dumps(
{"stats": {"tests": 0, "passes": 0, "failures": 0, "duration": 0}, "tests": [], "passes": [], "failures": [], "pending": []}
)
mock_run.return_value = MagicMock(returncode=0, stdout=mocha_output, stderr="", args=[])
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
(tmpdir_path / "package.json").write_text('{"name": "test"}')
test_file = tmpdir_path / "test.test.js"
test_file.write_text("// test")
test_paths = TestFiles(
test_files=[
TestFile(
original_file_path=test_file,
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
)
]
)
_, _, coverage_path, _ = run_mocha_behavioral_tests(
test_paths=test_paths,
test_env={},
cwd=tmpdir_path,
)
assert coverage_path is None
class TestRunMochaBenchmarkingTests:
"""Tests for running Mocha benchmarking tests with mocked subprocess."""
@patch("codeflash.languages.javascript.mocha_runner.subprocess.run")
@patch("codeflash.languages.javascript.mocha_runner._ensure_runtime_files")
def test_sets_perf_env_vars(self, mock_ensure, mock_run):
from codeflash.languages.javascript.mocha_runner import run_mocha_benchmarking_tests
from codeflash.models.models import TestFile, TestFiles
from codeflash.models.test_type import TestType
mocha_output = json.dumps(
{"stats": {"tests": 1, "passes": 1, "failures": 0, "duration": 100}, "tests": [{"title": "perf", "fullTitle": "bench perf", "duration": 100, "err": {}}], "passes": [], "failures": [], "pending": []}
)
mock_run.return_value = MagicMock(returncode=0, stdout=mocha_output, stderr="", args=[])
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
(tmpdir_path / "package.json").write_text('{"name": "test"}')
test_file = tmpdir_path / "perf.test.js"
test_file.write_text("// perf test")
test_paths = TestFiles(
test_files=[
TestFile(
original_file_path=test_file,
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
)
]
)
run_mocha_benchmarking_tests(
test_paths=test_paths,
test_env={},
cwd=tmpdir_path,
min_loops=3,
max_loops=50,
target_duration_ms=5000,
stability_check=False,
)
call_kwargs = mock_run.call_args
env = call_kwargs.kwargs.get("env") or call_kwargs[1].get("env", {})
assert env.get("CODEFLASH_MODE") == "performance"
assert env.get("CODEFLASH_PERF_LOOP_COUNT") == "50"
assert env.get("CODEFLASH_PERF_MIN_LOOPS") == "3"
assert env.get("CODEFLASH_PERF_TARGET_DURATION_MS") == "5000"
assert env.get("CODEFLASH_PERF_STABILITY_CHECK") == "false"
class TestRunMochaLineProfileTests:
"""Tests for running Mocha line profile tests with mocked subprocess."""
@patch("codeflash.languages.javascript.mocha_runner.subprocess.run")
@patch("codeflash.languages.javascript.mocha_runner._ensure_runtime_files")
def test_sets_line_profile_env_vars(self, mock_ensure, mock_run):
from codeflash.languages.javascript.mocha_runner import run_mocha_line_profile_tests
from codeflash.models.models import TestFile, TestFiles
from codeflash.models.test_type import TestType
mocha_output = json.dumps(
{"stats": {"tests": 0, "passes": 0, "failures": 0, "duration": 0}, "tests": [], "passes": [], "failures": [], "pending": []}
)
mock_run.return_value = MagicMock(returncode=0, stdout=mocha_output, stderr="", args=[])
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
(tmpdir_path / "package.json").write_text('{"name": "test"}')
test_file = tmpdir_path / "test.test.js"
test_file.write_text("// test")
profile_output = tmpdir_path / "profile.json"
test_paths = TestFiles(
test_files=[
TestFile(
original_file_path=test_file,
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
)
]
)
run_mocha_line_profile_tests(
test_paths=test_paths,
test_env={},
cwd=tmpdir_path,
line_profile_output_file=profile_output,
)
call_kwargs = mock_run.call_args
env = call_kwargs.kwargs.get("env") or call_kwargs[1].get("env", {})
assert env.get("CODEFLASH_MODE") == "line_profile"
assert env.get("CODEFLASH_LINE_PROFILE_OUTPUT") == str(profile_output)

View file

@ -82,9 +82,9 @@ from pathlib import Path
from unittest.mock import MagicMock
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
from codeflash.languages.registry import get_language_support
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
@ -110,7 +110,7 @@ def test_js_replcement() -> None:
original_helper = helper_file.read_text("utf-8")
js_support = get_language_support("javascript")
functions = js_support.discover_functions(main_file)
functions = js_support.discover_functions(main_file.read_text(encoding="utf-8"), main_file)
target = None
for func in functions:
if func.function_name == "calculateStats":
@ -135,7 +135,7 @@ def test_js_replcement() -> None:
project_root_path=root_dir,
pytest_cmd="jest",
)
func_optimizer = FunctionOptimizer(
func_optimizer = JavaScriptFunctionOptimizer(
function_to_optimize=func, test_cfg=test_config, aiservice_client=MagicMock()
)
result = func_optimizer.get_code_optimization_context()

View file

@ -49,7 +49,7 @@ def add(a, b):
""")
f.flush()
functions = python_support.discover_functions(Path(f.name))
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
assert len(functions) == 1
assert functions[0].function_name == "add"
@ -70,7 +70,7 @@ def multiply(a, b):
""")
f.flush()
functions = python_support.discover_functions(Path(f.name))
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
assert len(functions) == 3
names = {func.function_name for func in functions}
@ -88,7 +88,7 @@ def without_return():
""")
f.flush()
functions = python_support.discover_functions(Path(f.name))
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
# Only the function with return should be discovered
assert len(functions) == 1
@ -107,7 +107,7 @@ class Calculator:
""")
f.flush()
functions = python_support.discover_functions(Path(f.name))
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
assert len(functions) == 2
for func in functions:
@ -126,7 +126,7 @@ def sync_function():
""")
f.flush()
functions = python_support.discover_functions(Path(f.name))
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
assert len(functions) == 2
@ -137,7 +137,7 @@ def sync_function():
assert sync_func.is_async is False
def test_discover_nested_functions(self, python_support):
"""Test discovering nested functions."""
"""Test that nested functions are excluded — only top-level and class-level functions are discovered."""
with tempfile.NamedTemporaryFile(suffix=".py", mode="w", delete=False) as f:
f.write("""
def outer():
@ -147,18 +147,11 @@ def outer():
""")
f.flush()
functions = python_support.discover_functions(Path(f.name))
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
# Both outer and inner should be discovered
assert len(functions) == 2
names = {func.function_name for func in functions}
assert names == {"outer", "inner"}
# Inner should have outer as parent
inner = next(f for f in functions if f.function_name == "inner")
assert len(inner.parents) == 1
assert inner.parents[0].name == "outer"
assert inner.parents[0].type == "FunctionDef"
# Only outer should be discovered; inner is nested and skipped
assert len(functions) == 1
assert functions[0].function_name == "outer"
def test_discover_static_method(self, python_support):
"""Test discovering static methods."""
@ -171,7 +164,7 @@ class Utils:
""")
f.flush()
functions = python_support.discover_functions(Path(f.name))
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
assert len(functions) == 1
assert functions[0].function_name == "helper"
@ -190,7 +183,9 @@ def sync_func():
f.flush()
criteria = FunctionFilterCriteria(include_async=False)
functions = python_support.discover_functions(Path(f.name), criteria)
functions = python_support.discover_functions(
Path(f.name).read_text(encoding="utf-8"), Path(f.name), criteria
)
assert len(functions) == 1
assert functions[0].function_name == "sync_func"
@ -209,7 +204,9 @@ class MyClass:
f.flush()
criteria = FunctionFilterCriteria(include_methods=False)
functions = python_support.discover_functions(Path(f.name), criteria)
functions = python_support.discover_functions(
Path(f.name).read_text(encoding="utf-8"), Path(f.name), criteria
)
assert len(functions) == 1
assert functions[0].function_name == "standalone"
@ -227,7 +224,7 @@ def func2():
""")
f.flush()
functions = python_support.discover_functions(Path(f.name))
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
func1 = next(f for f in functions if f.function_name == "func1")
func2 = next(f for f in functions if f.function_name == "func2")
@ -237,18 +234,20 @@ def func2():
assert func2.starting_line == 4
assert func2.ending_line == 7
def test_discover_invalid_file_returns_empty(self, python_support):
"""Test that invalid Python file returns empty list."""
def test_discover_invalid_file_raises(self, python_support):
"""Test that invalid Python file raises a parse error."""
from libcst._exceptions import ParserSyntaxError
with tempfile.NamedTemporaryFile(suffix=".py", mode="w", delete=False) as f:
f.write("this is not valid python {{{{")
f.flush()
functions = python_support.discover_functions(Path(f.name))
assert functions == []
with pytest.raises(ParserSyntaxError):
python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
def test_discover_nonexistent_file_returns_empty(self, python_support):
"""Test that nonexistent file returns empty list."""
functions = python_support.discover_functions(Path("/nonexistent/file.py"))
def test_discover_empty_source_returns_empty(self, python_support):
"""Test that empty source returns empty list."""
functions = python_support.discover_functions("", Path("/nonexistent/file.py"))
assert functions == []
@ -500,7 +499,7 @@ class TestIntegration:
file_path = Path(f.name)
# Discover
functions = python_support.discover_functions(file_path)
functions = python_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
assert len(functions) == 1
func = functions[0]
assert func.function_name == "fibonacci"
@ -541,7 +540,7 @@ def standalone():
f.flush()
file_path = Path(f.name)
functions = python_support.discover_functions(file_path)
functions = python_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
# Should find 4 functions
assert len(functions) == 4
@ -584,12 +583,7 @@ def process(value):
return helper_function(value) + 1
""")
func = FunctionToOptimize(
function_name="helper_function",
file_path=source_file,
starting_line=1,
ending_line=2,
)
func = FunctionToOptimize(function_name="helper_function", file_path=source_file, starting_line=1, ending_line=2)
refs = python_support.find_references(func, project_root=tmp_path)
@ -646,12 +640,7 @@ def test_find_references_no_references(python_support, tmp_path):
return 42
""")
func = FunctionToOptimize(
function_name="isolated_function",
file_path=source_file,
starting_line=1,
ending_line=2,
)
func = FunctionToOptimize(function_name="isolated_function", file_path=source_file, starting_line=1, ending_line=2)
refs = python_support.find_references(func, project_root=tmp_path)
@ -668,10 +657,7 @@ def test_find_references_nonexistent_function(python_support, tmp_path):
""")
func = FunctionToOptimize(
function_name="nonexistent_function",
file_path=source_file,
starting_line=1,
ending_line=2,
function_name="nonexistent_function", file_path=source_file, starting_line=1, ending_line=2
)
refs = python_support.find_references(func, project_root=tmp_path)

View file

@ -821,3 +821,153 @@ export default curry(traverseEntity);"""
# createVisitorUtils is NOT wrapped, so not exported via default
is_utils_exported, _ = ts_analyzer.is_function_exported(code, "createVisitorUtils")
assert is_utils_exported is False
class TestNamedExportConstArrow:
"""Tests for const arrow functions exported via named export clause.
Pattern: const joinBy = () => {}; export { joinBy };
This is common in TypeScript codebases like Strapi.
"""
@pytest.fixture
def ts_analyzer(self):
return TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT)
@pytest.fixture
def js_analyzer(self):
return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
def test_named_export_const_arrow(self, ts_analyzer):
"""const arrow function exported via separate export { } clause."""
code = """const joinBy = (arr: string[], separator: string) => {
return arr.join(separator);
};
export { joinBy };"""
functions = ts_analyzer.find_functions(code)
joinBy = next((f for f in functions if f.name == "joinBy"), None)
assert joinBy is not None
assert joinBy.is_exported is True
def test_named_export_alias(self, ts_analyzer):
"""export { foo as bar } — foo should be marked as exported."""
code = """const foo = (x: number) => {
return x * 2;
};
export { foo as bar };"""
functions = ts_analyzer.find_functions(code)
foo = next((f for f in functions if f.name == "foo"), None)
assert foo is not None
assert foo.is_exported is True
def test_named_export_multiple(self, ts_analyzer):
"""Multiple functions in a single export clause."""
code = """const a = () => { return 1; };
const b = () => { return 2; };
const c = () => { return 3; };
export { a, b };"""
functions = ts_analyzer.find_functions(code)
a = next((f for f in functions if f.name == "a"), None)
b = next((f for f in functions if f.name == "b"), None)
c = next((f for f in functions if f.name == "c"), None)
assert a is not None and a.is_exported is True
assert b is not None and b.is_exported is True
assert c is not None and c.is_exported is False
def test_named_export_function_declaration(self, js_analyzer):
"""Regular function declarations exported via export { }."""
code = """function processData(data) {
return data;
}
export { processData };"""
functions = js_analyzer.find_functions(code)
f = next((f for f in functions if f.name == "processData"), None)
assert f is not None
assert f.is_exported is True
def test_is_function_exported_with_named_export(self, ts_analyzer):
"""is_function_exported should detect named export clause."""
code = """const joinBy = (arr: string[], separator: string) => {
return arr.join(separator);
};
export { joinBy };"""
is_exported, name = ts_analyzer.is_function_exported(code, "joinBy")
assert is_exported is True
class TestCjsReexportObjectMethods:
"""Tests for CJS re-export of object containing methods.
Pattern: const utils = { match() {} }; module.exports = utils;
This is common in Node.js libraries like Moleculer.
"""
@pytest.fixture
def js_analyzer(self):
return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
def test_cjs_reexport_object_methods(self, js_analyzer):
"""module.exports = varName where varName is object with methods."""
code = """const utils = {
match(text, pattern) {
return text.match(pattern);
},
slugify(str) {
return str.toLowerCase();
}
};
module.exports = utils;"""
is_exported, name = js_analyzer.is_function_exported(code, "match")
assert is_exported is True
is_exported2, _ = js_analyzer.is_function_exported(code, "slugify")
assert is_exported2 is True
def test_cjs_reexport_shorthand_props(self, js_analyzer):
"""module.exports = varName where object has shorthand properties."""
code = """function match(text, pattern) {
return text.match(pattern);
}
const utils = { match };
module.exports = utils;"""
is_exported, _ = js_analyzer.is_function_exported(code, "match")
assert is_exported is True
def test_cjs_reexport_pair_props(self, js_analyzer):
"""module.exports = varName where object has key: value pairs."""
code = """function myMatch(text, pattern) {
return text.match(pattern);
}
const utils = { match: myMatch };
module.exports = utils;"""
is_exported, _ = js_analyzer.is_function_exported(code, "match")
assert is_exported is True
def test_cjs_reexport_nonexistent_prop(self, js_analyzer):
"""A function not in the re-exported object should not be exported."""
code = """function helper() { return 1; }
const utils = {
match(text) { return text; }
};
module.exports = utils;"""
is_exported, _ = js_analyzer.is_function_exported(code, "helper")
assert is_exported is False

View file

@ -13,7 +13,7 @@ from pathlib import Path
import pytest
from codeflash.languages.base import FunctionInfo, Language, ParentInfo
from codeflash.languages.base import Language
from codeflash.languages.javascript.support import TypeScriptSupport
@ -126,14 +126,13 @@ export function add(a: number, b: number): number {
f.flush()
file_path = Path(f.name)
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
assert len(functions) == 1
assert functions[0].function_name == "add"
# Extract code context
code_context = ts_support.extract_code_context(
functions[0], file_path.parent, file_path.parent
)
code_context = ts_support.extract_code_context(functions[0], file_path.parent, file_path.parent)
# Verify extracted code is valid
assert ts_support.validate_syntax(code_context.target_code) is True
@ -164,14 +163,13 @@ export async function execMongoEval(queryExpression, appsmithMongoURI) {
f.flush()
file_path = Path(f.name)
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
assert len(functions) == 1
assert functions[0].function_name == "execMongoEval"
# Extract code context
code_context = ts_support.extract_code_context(
functions[0], file_path.parent, file_path.parent
)
code_context = ts_support.extract_code_context(functions[0], file_path.parent, file_path.parent)
# Verify extracted code is valid
assert ts_support.validate_syntax(code_context.target_code) is True
@ -215,14 +213,13 @@ export async function figureOutContentsPath(root: string): Promise<string> {
f.flush()
file_path = Path(f.name)
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
assert len(functions) == 1
assert functions[0].function_name == "figureOutContentsPath"
# Extract code context
code_context = ts_support.extract_code_context(
functions[0], file_path.parent, file_path.parent
)
code_context = ts_support.extract_code_context(functions[0], file_path.parent, file_path.parent)
# Verify extracted code is valid
assert ts_support.validate_syntax(code_context.target_code) is True
@ -246,12 +243,11 @@ export function readConfig(filename: string): string {
f.flush()
file_path = Path(f.name)
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
assert len(functions) == 1
code_context = ts_support.extract_code_context(
functions[0], file_path.parent, file_path.parent
)
code_context = ts_support.extract_code_context(functions[0], file_path.parent, file_path.parent)
# Check that imports are captured
assert len(code_context.imports) > 0
@ -278,12 +274,11 @@ export async function fetchWithRetry(url: string): Promise<any> {
f.flush()
file_path = Path(f.name)
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
assert len(functions) == 1
code_context = ts_support.extract_code_context(
functions[0], file_path.parent, file_path.parent
)
code_context = ts_support.extract_code_context(functions[0], file_path.parent, file_path.parent)
# Verify extracted code is valid
assert ts_support.validate_syntax(code_context.target_code) is True
@ -324,7 +319,8 @@ export class EndpointGroup {
file_path = Path(f.name)
# Discover the 'post' method
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
post_method = None
for func in functions:
if func.function_name == "post":
@ -334,9 +330,7 @@ export class EndpointGroup {
assert post_method is not None, "post method should be discovered"
# Extract code context
code_context = ts_support.extract_code_context(
post_method, file_path.parent, file_path.parent
)
code_context = ts_support.extract_code_context(post_method, file_path.parent, file_path.parent)
# The extracted code should be syntactically valid
assert ts_support.validate_syntax(code_context.target_code) is True, (
@ -352,9 +346,7 @@ export class EndpointGroup {
# Check that addEndpoint appears BEFORE the closing brace of the class
class_end_index = code_context.target_code.rfind("}")
add_endpoint_index = code_context.target_code.find("addEndpoint")
assert add_endpoint_index < class_end_index, (
"addEndpoint should be inside the class wrapper"
)
assert add_endpoint_index < class_end_index, "addEndpoint should be inside the class wrapper"
def test_multiple_private_helpers_inside_class(self, ts_support):
"""Test that multiple private helpers are all included inside the class."""
@ -386,7 +378,8 @@ export class Router {
file_path = Path(f.name)
# Discover the 'addRoute' method
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
add_route_method = None
for func in functions:
if func.function_name == "addRoute":
@ -395,9 +388,7 @@ export class Router {
assert add_route_method is not None
code_context = ts_support.extract_code_context(
add_route_method, file_path.parent, file_path.parent
)
code_context = ts_support.extract_code_context(add_route_method, file_path.parent, file_path.parent)
# Should be valid TypeScript
assert ts_support.validate_syntax(code_context.target_code) is True
@ -424,7 +415,8 @@ export class Calculator {
f.flush()
file_path = Path(f.name)
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
add_method = None
for func in functions:
if func.function_name == "add":
@ -433,18 +425,14 @@ export class Calculator {
assert add_method is not None
code_context = ts_support.extract_code_context(
add_method, file_path.parent, file_path.parent
)
code_context = ts_support.extract_code_context(add_method, file_path.parent, file_path.parent)
# 'compute' should be in target_code (inside class)
assert "compute" in code_context.target_code
# 'compute' should NOT be in helper_functions (would be duplicate)
helper_names = [h.name for h in code_context.helper_functions]
assert "compute" not in helper_names, (
"Same-class helper 'compute' should not be in helper_functions list"
)
assert "compute" not in helper_names, "Same-class helper 'compute' should not be in helper_functions list"
class TestTypeScriptLanguageProperties:

View file

@ -124,10 +124,8 @@ class TestTypeScriptCodeContext:
"""Test extracting code context for a TypeScript function."""
skip_if_ts_not_supported()
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.languages import current as lang_current
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
lang_current._current_language = Language.TYPESCRIPT
from codeflash.languages import get_language_support
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
fib_file = ts_project_dir / "fibonacci.ts"
if not fib_file.exists():
@ -139,7 +137,11 @@ class TestTypeScriptCodeContext:
fib_func = next((f for f in func_list if f.function_name == "fibonacci"), None)
assert fib_func is not None
context = get_code_optimization_context(fib_func, ts_project_dir)
ts_support = get_language_support(Language.TYPESCRIPT)
code_context = ts_support.extract_code_context(fib_func, ts_project_dir, ts_project_dir)
context = JavaScriptFunctionOptimizer._build_optimization_context(
code_context, fib_file, "typescript", ts_project_dir
)
assert context.read_writable_code is not None
# Critical: language should be "typescript", not "javascript"

View file

@ -118,11 +118,9 @@ class TestVitestCodeContext:
"""Test extracting code context for a TypeScript function."""
skip_if_js_not_supported()
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.languages import current as lang_current
from codeflash.languages import get_language_support
from codeflash.languages.base import Language
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
lang_current._current_language = Language.TYPESCRIPT
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
fib_file = vitest_project_dir / "fibonacci.ts"
if not fib_file.exists():
@ -134,7 +132,11 @@ class TestVitestCodeContext:
fib_func = next((f for f in func_list if f.function_name == "fibonacci"), None)
assert fib_func is not None
context = get_code_optimization_context(fib_func, vitest_project_dir)
ts_support = get_language_support(Language.TYPESCRIPT)
code_context = ts_support.extract_code_context(fib_func, vitest_project_dir, vitest_project_dir)
context = JavaScriptFunctionOptimizer._build_optimization_context(
code_context, fib_file, "typescript", vitest_project_dir
)
assert context.read_writable_code is not None
assert context.read_writable_code.language == "typescript"

View file

@ -1,10 +1,18 @@
import os
import sys
import types
from typing import NoReturn
from unittest.mock import patch
import pytest
from _pytest.config import Config
from codeflash.verification.pytest_plugin import PytestLoops
from codeflash.verification.pytest_plugin import (
InvalidTimeParameterError,
PytestLoops,
get_runtime_from_stdout,
should_stop,
)
@pytest.fixture
@ -15,39 +23,301 @@ def pytest_loops_instance(pytestconfig: Config) -> PytestLoops:
@pytest.fixture
def mock_item() -> type:
class MockItem:
def __init__(self, function: types.FunctionType) -> None:
def __init__(self, function: types.FunctionType, name: str = "test_func", cls: type = None, module: types.ModuleType = None) -> None:
self.function = function
self.name = name
self.cls = cls
self.module = module
return MockItem
def create_mock_module(module_name: str, source_code: str) -> types.ModuleType:
def create_mock_module(module_name: str, source_code: str, register: bool = False) -> types.ModuleType:
module = types.ModuleType(module_name)
exec(source_code, module.__dict__) # noqa: S102
if register:
sys.modules[module_name] = module
return module
def test_clear_lru_caches_function(pytest_loops_instance: PytestLoops, mock_item: type) -> None:
source_code = """
def mock_session(**kwargs):
"""Create a mock session with config options."""
defaults = {
"codeflash_hours": 0,
"codeflash_minutes": 0,
"codeflash_seconds": 10,
"codeflash_delay": 0.0,
"codeflash_loops": 1,
"codeflash_min_loops": 1,
"codeflash_max_loops": 100_000,
}
defaults.update(kwargs)
class Option:
pass
option = Option()
for k, v in defaults.items():
setattr(option, k, v)
class MockConfig:
pass
config = MockConfig()
config.option = option
class MockSession:
pass
session = MockSession()
session.config = config
return session
# --- get_runtime_from_stdout ---
class TestGetRuntimeFromStdout:
def test_valid_payload(self) -> None:
assert get_runtime_from_stdout("!######test_func:12345######!") == 12345
def test_valid_payload_with_surrounding_text(self) -> None:
assert get_runtime_from_stdout("some output\n!######mod.func:99999######!\nmore output") == 99999
def test_empty_string(self) -> None:
assert get_runtime_from_stdout("") is None
def test_no_markers(self) -> None:
assert get_runtime_from_stdout("just some output") is None
def test_missing_end_marker(self) -> None:
assert get_runtime_from_stdout("!######test:123") is None
def test_missing_start_marker(self) -> None:
assert get_runtime_from_stdout("test:123######!") is None
def test_no_colon_in_payload(self) -> None:
assert get_runtime_from_stdout("!######nocolon######!") is None
def test_non_integer_value(self) -> None:
assert get_runtime_from_stdout("!######test:notanumber######!") is None
def test_multiple_markers_uses_last(self) -> None:
stdout = "!######first:111######! middle !######second:222######!"
assert get_runtime_from_stdout(stdout) == 222
# --- should_stop ---
class TestShouldStop:
def test_not_enough_data_for_window(self) -> None:
assert should_stop([100, 100], window=5, min_window_size=3) is False
def test_below_min_window_size(self) -> None:
assert should_stop([100, 100], window=2, min_window_size=5) is False
def test_stable_runtimes_stops(self) -> None:
runtimes = [1000000] * 10
assert should_stop(runtimes, window=5, min_window_size=3, center_rel_tol=0.01, spread_rel_tol=0.01) is True
def test_unstable_runtimes_continues(self) -> None:
runtimes = [100, 200, 100, 200, 100]
assert should_stop(runtimes, window=5, min_window_size=3, center_rel_tol=0.01, spread_rel_tol=0.01) is False
def test_zero_runtimes_raises(self) -> None:
# All-zero runtimes cause ZeroDivisionError in median check.
# In practice the caller guards with best_runtime_until_now > 0.
runtimes = [0, 0, 0, 0, 0]
with pytest.raises(ZeroDivisionError):
should_stop(runtimes, window=5, min_window_size=3)
def test_even_window_median(self) -> None:
# Even window: median is average of two middle values
runtimes = [1000, 1000, 1001, 1001]
assert should_stop(runtimes, window=4, min_window_size=2, center_rel_tol=0.01, spread_rel_tol=0.01) is True
def test_centered_but_spread_too_large(self) -> None:
# All close to median but spread exceeds tolerance
runtimes = [1000, 1050, 1000, 1050, 1000]
assert should_stop(runtimes, window=5, min_window_size=3, center_rel_tol=0.1, spread_rel_tol=0.001) is False
# --- _set_nodeid ---
class TestSetNodeid:
def test_appends_count_to_plain_nodeid(self, pytest_loops_instance: PytestLoops) -> None:
result = pytest_loops_instance._set_nodeid("test_module.py::test_func", 3) # noqa: SLF001
assert result == "test_module.py::test_func[ 3 ]"
assert os.environ["CODEFLASH_LOOP_INDEX"] == "3"
def test_replaces_existing_count(self, pytest_loops_instance: PytestLoops) -> None:
result = pytest_loops_instance._set_nodeid("test_module.py::test_func[ 1 ]", 5) # noqa: SLF001
assert result == "test_module.py::test_func[ 5 ]"
def test_replaces_only_loop_pattern(self, pytest_loops_instance: PytestLoops) -> None:
# Parametrize brackets like [param0] should not be replaced
result = pytest_loops_instance._set_nodeid("test_mod.py::test_func[param0]", 2) # noqa: SLF001
assert result == "test_mod.py::test_func[param0][ 2 ]"
# --- _get_total_time ---
class TestGetTotalTime:
def test_seconds_only(self, pytest_loops_instance: PytestLoops) -> None:
session = mock_session(codeflash_seconds=30)
assert pytest_loops_instance._get_total_time(session) == 30 # noqa: SLF001
def test_mixed_units(self, pytest_loops_instance: PytestLoops) -> None:
session = mock_session(codeflash_hours=1, codeflash_minutes=30, codeflash_seconds=45)
assert pytest_loops_instance._get_total_time(session) == 3600 + 1800 + 45 # noqa: SLF001
def test_zero_time_is_valid(self, pytest_loops_instance: PytestLoops) -> None:
session = mock_session(codeflash_hours=0, codeflash_minutes=0, codeflash_seconds=0)
assert pytest_loops_instance._get_total_time(session) == 0 # noqa: SLF001
def test_negative_time_raises(self, pytest_loops_instance: PytestLoops) -> None:
session = mock_session(codeflash_hours=0, codeflash_minutes=0, codeflash_seconds=-1)
with pytest.raises(InvalidTimeParameterError):
pytest_loops_instance._get_total_time(session) # noqa: SLF001
# --- _timed_out ---
class TestTimedOut:
def test_exceeds_max_loops(self, pytest_loops_instance: PytestLoops) -> None:
session = mock_session(codeflash_max_loops=10, codeflash_min_loops=1, codeflash_seconds=9999)
assert pytest_loops_instance._timed_out(session, start_time=0, count=10) is True # noqa: SLF001
def test_below_min_loops_never_times_out(self, pytest_loops_instance: PytestLoops) -> None:
session = mock_session(codeflash_max_loops=100_000, codeflash_min_loops=50, codeflash_seconds=0)
# Even with 0 seconds budget, count < min_loops means not timed out
assert pytest_loops_instance._timed_out(session, start_time=0, count=5) is False # noqa: SLF001
def test_above_min_loops_and_time_exceeded(self, pytest_loops_instance: PytestLoops) -> None:
session = mock_session(codeflash_max_loops=100_000, codeflash_min_loops=1, codeflash_seconds=1)
# start_time far in the past → time exceeded
assert pytest_loops_instance._timed_out(session, start_time=0, count=2) is True # noqa: SLF001
# --- _get_delay_time ---
class TestGetDelayTime:
def test_returns_configured_delay(self, pytest_loops_instance: PytestLoops) -> None:
session = mock_session(codeflash_delay=0.5)
assert pytest_loops_instance._get_delay_time(session) == 0.5 # noqa: SLF001
# --- pytest_runtest_logreport ---
class TestRunTestLogReport:
def test_skipped_when_stability_check_disabled(self, pytestconfig: Config) -> None:
instance = PytestLoops(pytestconfig)
instance.enable_stability_check = False
class MockReport:
when = "call"
passed = True
capstdout = "!######func:12345######!"
nodeid = "test::func"
instance.pytest_runtest_logreport(MockReport())
assert instance.runtime_data_by_test_case == {}
def test_records_runtime_on_passed_call(self, pytestconfig: Config) -> None:
instance = PytestLoops(pytestconfig)
instance.enable_stability_check = True
class MockReport:
when = "call"
passed = True
capstdout = "!######func:12345######!"
nodeid = "test::func [ 1 ]"
instance.pytest_runtest_logreport(MockReport())
assert "test::func" in instance.runtime_data_by_test_case
assert instance.runtime_data_by_test_case["test::func"] == [12345]
def test_ignores_non_call_phase(self, pytestconfig: Config) -> None:
instance = PytestLoops(pytestconfig)
instance.enable_stability_check = True
class MockReport:
when = "setup"
passed = True
capstdout = "!######func:12345######!"
nodeid = "test::func"
instance.pytest_runtest_logreport(MockReport())
assert instance.runtime_data_by_test_case == {}
# --- pytest_runtest_setup / teardown ---
class TestRunTestSetupTeardown:
def test_setup_sets_env_vars(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
module = types.ModuleType("my_test_module")
class MyTestClass:
pass
item = mock_item(lambda: None, name="test_something[param1]", cls=MyTestClass, module=module)
pytest_loops_instance.pytest_runtest_setup(item)
assert os.environ["CODEFLASH_TEST_MODULE"] == "my_test_module"
assert os.environ["CODEFLASH_TEST_CLASS"] == "MyTestClass"
assert os.environ["CODEFLASH_TEST_FUNCTION"] == "test_something"
def test_setup_no_class(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
module = types.ModuleType("my_test_module")
item = mock_item(lambda: None, name="test_plain", cls=None, module=module)
pytest_loops_instance.pytest_runtest_setup(item)
assert os.environ["CODEFLASH_TEST_CLASS"] == ""
def test_teardown_clears_env_vars(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
os.environ["CODEFLASH_TEST_MODULE"] = "leftover"
os.environ["CODEFLASH_TEST_CLASS"] = "leftover"
os.environ["CODEFLASH_TEST_FUNCTION"] = "leftover"
item = mock_item(lambda: None)
pytest_loops_instance.pytest_runtest_teardown(item)
assert "CODEFLASH_TEST_MODULE" not in os.environ
assert "CODEFLASH_TEST_CLASS" not in os.environ
assert "CODEFLASH_TEST_FUNCTION" not in os.environ
# --- _clear_lru_caches ---
class TestClearLruCaches:
def test_clears_lru_cached_function(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
source_code = """
import functools
@functools.lru_cache(maxsize=None)
def my_func(x):
return x * 2
my_func(10) # miss the cache
my_func(10) # hit the cache
my_func(10)
my_func(10)
"""
mock_module = create_mock_module("test_module_func", source_code)
item = mock_item(mock_module.my_func)
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
assert mock_module.my_func.cache_info().hits == 0
assert mock_module.my_func.cache_info().misses == 0
assert mock_module.my_func.cache_info().currsize == 0
mock_module = create_mock_module("test_module_func", source_code)
item = mock_item(mock_module.my_func)
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
assert mock_module.my_func.cache_info().hits == 0
assert mock_module.my_func.cache_info().misses == 0
assert mock_module.my_func.cache_info().currsize == 0
def test_clear_lru_caches_class_method(pytest_loops_instance: PytestLoops, mock_item: type) -> None:
source_code = """
def test_clears_class_method_cache(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
source_code = """
import functools
class MyClass:
@ -56,32 +326,137 @@ class MyClass:
return x * 3
obj = MyClass()
obj.my_method(5) # Pre-populate the cache
obj.my_method(5) # Hit the cache
obj.my_method(5)
obj.my_method(5)
# """
mock_module = create_mock_module("test_module_class", source_code)
item = mock_item(mock_module.MyClass.my_method)
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
assert mock_module.MyClass.my_method.cache_info().hits == 0
assert mock_module.MyClass.my_method.cache_info().misses == 0
assert mock_module.MyClass.my_method.cache_info().currsize == 0
mock_module = create_mock_module("test_module_class", source_code)
item = mock_item(mock_module.MyClass.my_method)
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
assert mock_module.MyClass.my_method.cache_info().hits == 0
assert mock_module.MyClass.my_method.cache_info().misses == 0
assert mock_module.MyClass.my_method.cache_info().currsize == 0
def test_handles_exception_in_cache_clear(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
class BrokenCache:
def cache_clear(self) -> NoReturn:
msg = "Cache clearing failed!"
raise ValueError(msg)
def test_clear_lru_caches_exception_handling(pytest_loops_instance: PytestLoops, mock_item: type) -> None:
"""Test that exceptions during clearing are handled."""
item = mock_item(BrokenCache())
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
class BrokenCache:
def cache_clear(self) -> NoReturn:
msg = "Cache clearing failed!"
raise ValueError(msg)
def test_handles_no_cache(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
def no_cache_func(x: int) -> int:
return x
item = mock_item(BrokenCache())
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
item = mock_item(no_cache_func)
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
def test_clears_module_level_caches_via_sys_modules(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
module_name = "_cf_test_module_scan"
source_code = """
import functools
def test_clear_lru_caches_no_cache(pytest_loops_instance: PytestLoops, mock_item: type) -> None:
def no_cache_func(x: int) -> int:
return x
@functools.lru_cache(maxsize=None)
def cached_a(x):
return x + 1
item = mock_item(no_cache_func)
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
@functools.lru_cache(maxsize=None)
def cached_b(x):
return x + 2
def plain_func(x):
return x
cached_a(1)
cached_a(1)
cached_b(2)
cached_b(2)
"""
mock_module = create_mock_module(module_name, source_code, register=True)
try:
item = mock_item(mock_module.plain_func)
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
assert mock_module.cached_a.cache_info().currsize == 0
assert mock_module.cached_b.cache_info().currsize == 0
finally:
sys.modules.pop(module_name, None)
def test_skips_protected_modules(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
module_name = "_cf_test_protected"
source_code = """
import functools
@functools.lru_cache(maxsize=None)
def user_func(x):
return x
"""
mock_module = create_mock_module(module_name, source_code, register=True)
try:
mock_module.os_exists = os.path.exists
item = mock_item(mock_module.user_func)
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
finally:
sys.modules.pop(module_name, None)
def test_caches_scan_result(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
module_name = "_cf_test_cache_reuse"
source_code = """
import functools
@functools.lru_cache(maxsize=None)
def cached_fn(x):
return x
"""
mock_module = create_mock_module(module_name, source_code, register=True)
try:
item = mock_item(mock_module.cached_fn)
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
assert module_name in pytest_loops_instance._module_clearables # noqa: SLF001
mock_module.cached_fn(42)
assert mock_module.cached_fn.cache_info().currsize == 1
with patch("codeflash.verification.pytest_plugin.inspect.getmembers") as mock_getmembers:
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
mock_getmembers.assert_not_called()
assert mock_module.cached_fn.cache_info().currsize == 0
finally:
sys.modules.pop(module_name, None)
def test_handles_wrapped_function(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
module_name = "_cf_test_wrapped"
source_code = """
import functools
@functools.lru_cache(maxsize=None)
def inner(x):
return x
def wrapper(x):
return inner(x)
wrapper.__wrapped__ = inner
wrapper.__module__ = __name__
inner(1)
inner(1)
"""
mock_module = create_mock_module(module_name, source_code, register=True)
try:
item = mock_item(mock_module.wrapper)
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
assert mock_module.inner.cache_info().currsize == 0
finally:
sys.modules.pop(module_name, None)
def test_handles_function_without_module(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
def func() -> None:
pass
func.__module__ = None # type: ignore[assignment]
item = mock_item(func)
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001

View file

@ -3,9 +3,9 @@ import tempfile
from pathlib import Path
from codeflash.code_utils.code_utils import ImportErrorPattern
from codeflash.languages import current_language_support
from codeflash.models.models import TestFile, TestFiles, TestType
from codeflash.verification.parse_test_output import parse_test_xml
from codeflash.verification.test_runner import run_behavioral_tests
from codeflash.verification.verification_utils import TestConfig
@ -48,8 +48,8 @@ class TestUnittestRunnerSorter(unittest.TestCase):
test_files=[TestFile(instrumented_behavior_file_path=test_file_path, test_type=TestType.EXISTING_UNIT_TEST)]
)
test_file_path.write_text(code, encoding="utf-8")
result_file, process, _, _ = run_behavioral_tests(
test_files, test_framework=config.test_framework, cwd=Path(config.project_root_path), test_env=test_env
result_file, process, _, _ = current_language_support().run_behavioral_tests(
test_paths=test_files, test_env=test_env, cwd=Path(config.project_root_path)
)
results = parse_test_xml(result_file, test_files, config, process)
assert results[0].did_pass, "Test did not pass as expected"
@ -89,13 +89,8 @@ def test_sort():
test_files=[TestFile(instrumented_behavior_file_path=test_file_path, test_type=TestType.EXISTING_UNIT_TEST)]
)
test_file_path.write_text(code, encoding="utf-8")
result_file, process, _, _ = run_behavioral_tests(
test_files,
test_framework=config.test_framework,
cwd=Path(config.project_root_path),
test_env=test_env,
pytest_timeout=1,
pytest_target_runtime_seconds=1,
result_file, process, _, _ = current_language_support().run_behavioral_tests(
test_paths=test_files, test_env=test_env, cwd=Path(config.project_root_path), timeout=1
)
results = parse_test_xml(
test_xml_file_path=result_file, test_files=test_files, test_config=config, run_result=process
@ -136,13 +131,8 @@ def test_sort():
test_files=[TestFile(instrumented_behavior_file_path=test_file_path, test_type=TestType.EXISTING_UNIT_TEST)]
)
test_file_path.write_text(code, encoding="utf-8")
result_file, process, _, _ = run_behavioral_tests(
test_files,
test_framework=config.test_framework,
cwd=Path(config.project_root_path),
test_env=test_env,
pytest_timeout=1,
pytest_target_runtime_seconds=1,
result_file, process, _, _ = current_language_support().run_behavioral_tests(
test_paths=test_files, test_env=test_env, cwd=Path(config.project_root_path), timeout=1
)
results = parse_test_xml(
test_xml_file_path=result_file, test_files=test_files, test_config=config, run_result=process

View file

@ -10,8 +10,8 @@ from codeflash.languages.python.context.unused_definition_remover import (
detect_unused_helper_functions,
revert_unused_helper_functions,
)
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
from codeflash.models.models import CodeStringsMarkdown
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
@ -83,7 +83,7 @@ def helper_function_2(x):
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -194,7 +194,7 @@ def helper_function_2(x):
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -269,7 +269,7 @@ def helper_function_2(x):
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -365,7 +365,7 @@ def entrypoint_function(n):
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -559,7 +559,7 @@ class Calculator:
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -710,7 +710,7 @@ class Processor:
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -895,7 +895,7 @@ class OuterClass:
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -1051,7 +1051,7 @@ def entrypoint_function(n):
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -1215,7 +1215,7 @@ def entrypoint_function(n):
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -1442,7 +1442,7 @@ class MathUtils:
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -1576,7 +1576,7 @@ async def async_entrypoint(n):
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -1664,7 +1664,7 @@ def sync_entrypoint(n):
function_to_optimize = FunctionToOptimize(file_path=main_file, function_name="sync_entrypoint", parents=[])
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -1773,7 +1773,7 @@ async def mixed_entrypoint(n):
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -1874,7 +1874,7 @@ class AsyncProcessor:
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -1960,7 +1960,7 @@ async def async_entrypoint(n):
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -2039,7 +2039,7 @@ def gcd_recursive(a: int, b: int) -> int:
function_to_optimize = FunctionToOptimize(file_path=main_file, function_name="gcd_recursive", parents=[])
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -2152,7 +2152,7 @@ async def async_entrypoint_with_generators(n):
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),

View file

@ -61,9 +61,9 @@ def test_mirror_paths_for_worktree_mode(monkeypatch: pytest.MonkeyPatch):
assert optimizer.args.test_project_root == worktree_dir
assert optimizer.args.module_root == worktree_dir / "codeflash"
# tests_root is configured as "codeflash" in pyproject.toml
assert optimizer.args.tests_root == worktree_dir / "codeflash"
assert optimizer.args.tests_root == worktree_dir / "tests"
assert optimizer.args.file == worktree_dir / "codeflash/optimization/optimizer.py"
assert optimizer.test_cfg.tests_root == worktree_dir / "codeflash"
assert optimizer.test_cfg.tests_root == worktree_dir / "tests"
assert optimizer.test_cfg.project_root_path == worktree_dir # same as project_root
assert optimizer.test_cfg.tests_project_rootdir == worktree_dir # same as test_project_root