diff --git a/codeflash/languages/javascript/normalizer.py b/codeflash/languages/javascript/normalizer.py index 39ae952cb..b42ee8057 100644 --- a/codeflash/languages/javascript/normalizer.py +++ b/codeflash/languages/javascript/normalizer.py @@ -1,7 +1,6 @@ """JavaScript/TypeScript code normalizer using tree-sitter. -Not currently wired into JavaScriptSupport.normalize_code — kept as a -ready-to-use upgrade path when AST-based JS deduplication is needed. +Wired into JavaScriptSupport.normalize_code for AST-based JS deduplication. The old CodeNormalizer ABC (deleted from base.py) is preserved below for reference. """ @@ -236,8 +235,7 @@ def normalize_js_code(code: str, typescript: bool = False) -> str: Uses tree-sitter to parse and normalize variable names. Falls back to basic comment/whitespace stripping if tree-sitter is unavailable or parsing fails. - Not currently wired into JavaScriptSupport.normalize_code — kept as a - ready-to-use upgrade path when AST-based JS deduplication is needed. + Wired into JavaScriptSupport.normalize_code for AST-based JS deduplication. """ try: from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage diff --git a/codeflash/languages/javascript/support.py b/codeflash/languages/javascript/support.py index 69fd5ac96..f1a570740 100644 --- a/codeflash/languages/javascript/support.py +++ b/codeflash/languages/javascript/support.py @@ -1207,20 +1207,29 @@ class JavaScriptSupport: return node # Check function declarations - if node.type in ("function_declaration", "function"): + if node.type in ( + "function_declaration", + "function", + "generator_function_declaration", + "generator_function", + ): name_node = node.child_by_field_name("name") if name_node: name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8") if name == target_name: return node - # Check arrow functions assigned to variables - if node.type == "lexical_declaration": + # Check arrow functions and function expressions assigned to variables + if node.type in ("lexical_declaration", "variable_declaration"): for child in node.children: if child.type == "variable_declarator": name_node = child.child_by_field_name("name") value_node = child.child_by_field_name("value") - if name_node and value_node and value_node.type == "arrow_function": + if ( + name_node + and value_node + and value_node.type in ("arrow_function", "function_expression", "generator_function") + ): name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8") if name == target_name: return value_node @@ -1235,6 +1244,7 @@ class JavaScriptSupport: func_node = find_function_node(tree.root_node, function_name) if not func_node: + logger.debug("Could not find function '%s' in optimized code for body extraction", function_name) return None # Find the body node @@ -1295,14 +1305,21 @@ class JavaScriptSupport: if name == target_name and (node.start_point[0] + 1) == target_line: return node - if node.type == "lexical_declaration": + if node.type in ("lexical_declaration", "variable_declaration"): for child in node.children: if child.type == "variable_declarator": name_node = child.child_by_field_name("name") value_node = child.child_by_field_name("value") - if name_node and value_node and value_node.type == "arrow_function": + if ( + name_node + and value_node + and value_node.type in ("arrow_function", "function_expression", "generator_function") + ): name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8") - if name == target_name and (node.start_point[0] + 1) == target_line: + if name == target_name and ( + (node.start_point[0] + 1) == target_line + or (value_node.start_point[0] + 1) == target_line + ): return value_node for child in node.children: @@ -1686,26 +1703,14 @@ class JavaScriptSupport: return False def normalize_code(self, source: str) -> str: - """Normalize JavaScript code for deduplication. + """Normalize JavaScript code for deduplication using tree-sitter.""" + from codeflash.languages.javascript.normalizer import normalize_js_code - Removes comments and normalizes whitespace. - - Args: - source: Source code to normalize. - - Returns: - Normalized source code. - - """ - # Simple normalization: remove extra whitespace - # A full implementation would use tree-sitter to strip comments - lines = source.splitlines() - normalized_lines = [] - for line in lines: - stripped = line.strip() - if stripped and not stripped.startswith("//"): - normalized_lines.append(stripped) - return "\n".join(normalized_lines) + try: + is_ts = self.treesitter_language == TreeSitterLanguage.TYPESCRIPT + return normalize_js_code(source, typescript=is_ts) + except Exception: + return source def generate_concolic_tests( self, test_cfg: Any, project_root: Any, function_to_optimize: Any, function_to_optimize_ast: Any diff --git a/tests/test_code_deduplication.py b/tests/test_code_deduplication.py index 3cb266785..c4a826330 100644 --- a/tests/test_code_deduplication.py +++ b/tests/test_code_deduplication.py @@ -1,3 +1,4 @@ +from codeflash.languages.javascript.normalizer import normalize_js_code from codeflash.languages.python.normalizer import normalize_python_code as normalize_code @@ -133,3 +134,74 @@ def safe_divide(a, b): assert normalize_code(code9) == normalize_code(code10) assert normalize_code(code9) != normalize_code(code8) + + +# === JavaScript deduplication tests === + + +def test_js_deduplicate_same_logic_different_vars(): + code1 = """ +function process(items) { + const result = []; + for (const item of items) { + result.push(item * 2); + } + return result; +} +""" + code2 = """ +function process(items) { + const output = []; + for (const val of items) { + output.push(val * 2); + } + return output; +} +""" + assert normalize_js_code(code1) == normalize_js_code(code2) + + +def test_js_different_logic_not_deduplicated(): + code1 = """ +function compute(x) { + return x + 1; +} +""" + code2 = """ +function compute(x) { + return x * 2; +} +""" + assert normalize_js_code(code1) != normalize_js_code(code2) + + +def test_js_deduplicate_whitespace_and_comments(): + code1 = """ +function add(a, b) { + // fast path + return a + b; +} +""" + code2 = """ +function add(a, b) { + /* optimized */ + return a + b; +} +""" + assert normalize_js_code(code1) == normalize_js_code(code2) + + +def test_ts_normalize(): + code1 = """ +function greet(name: string): string { + const msg = "hello " + name; + return msg; +} +""" + code2 = """ +function greet(name: string): string { + const result = "hello " + name; + return result; +} +""" + assert normalize_js_code(code1, typescript=True) == normalize_js_code(code2, typescript=True) diff --git a/tests/test_languages/test_javascript_support.py b/tests/test_languages/test_javascript_support.py index 800e01a29..5d5943151 100644 --- a/tests/test_languages/test_javascript_support.py +++ b/tests/test_languages/test_javascript_support.py @@ -443,10 +443,10 @@ function add(a, b { class TestNormalizeCode: - """Tests for normalize_code method.""" + """Tests for normalize_code method using tree-sitter normalizer.""" def test_removes_comments(self, js_support): - """Test that single-line comments are removed.""" + """Test that comments are absent from normalized output.""" code = """ function add(a, b) { // Add two numbers @@ -455,19 +455,43 @@ function add(a, b) { """ normalized = js_support.normalize_code(code) assert "// Add two numbers" not in normalized - assert "return a + b" in normalized + assert "Add two numbers" not in normalized - def test_preserves_functionality(self, js_support): - """Test that code functionality is preserved.""" - code = """ -function add(a, b) { - // Comment - return a + b; + def test_same_logic_different_vars_are_equal(self, js_support): + """Test that two functions with same logic but different variable names normalize identically.""" + code1 = """ +function process(items) { + const result = []; + for (const item of items) { + result.push(item * 2); + } + return result; } """ - normalized = js_support.normalize_code(code) - assert "function add" in normalized - assert "return" in normalized + code2 = """ +function process(items) { + const output = []; + for (const val of items) { + output.push(val * 2); + } + return output; +} +""" + assert js_support.normalize_code(code1) == js_support.normalize_code(code2) + + def test_different_logic_not_equal(self, js_support): + """Test that two functions with different logic produce different normalized forms.""" + code1 = """ +function compute(x) { + return x + 1; +} +""" + code2 = """ +function compute(x) { + return x * 2; +} +""" + assert js_support.normalize_code(code1) != js_support.normalize_code(code2) class TestExtractCodeContext: diff --git a/tests/test_languages/test_js_code_replacer.py b/tests/test_languages/test_js_code_replacer.py index 5ed2a903f..8ff333ad4 100644 --- a/tests/test_languages/test_js_code_replacer.py +++ b/tests/test_languages/test_js_code_replacer.py @@ -15,7 +15,7 @@ from pathlib import Path import pytest -from codeflash.languages.base import Language +from codeflash.languages.base import FunctionFilterCriteria, Language from codeflash.languages.code_replacer import replace_function_definitions_for_language from codeflash.languages.current import set_current_language from codeflash.languages.javascript.module_system import ( @@ -2264,3 +2264,150 @@ export function processNode(node: TreeNode, space: NodeSpace): number { assert "// Optimized" in result assert ts_support.validate_syntax(result) is True + + +class TestVariableAssignedFunctionReplacement: + """Tests for replacing functions assigned to variables (function expressions, var declarations, etc.).""" + + NO_EXPORT_FILTER = FunctionFilterCriteria(require_export=False, require_return=False) + + def test_replace_function_expression_body(self, js_support, temp_project): + """Test replacing an exported const-assigned function expression.""" + original_source = """\ +export const foo = function(x) { + return x + 1; +}; +""" + file_path = temp_project / "funcs.js" + file_path.write_text(original_source, encoding="utf-8") + + source = file_path.read_text(encoding="utf-8") + functions = js_support.discover_functions(source, file_path) + assert len(functions) == 1 + func = functions[0] + assert func.function_name == "foo" + + optimized_code = """\ +export const foo = function(x) { + return (x + 1) | 0; +}; +""" + + result = js_support.replace_function(original_source, func, optimized_code) + + assert "return (x + 1) | 0;" in result + assert js_support.validate_syntax(result) is True + + def test_replace_function_expression_with_var(self, js_support, temp_project): + """Test replacing a var-assigned function expression (non-exported, e.g. CommonJS).""" + original_source = """\ +var foo = function(x) { + return x * 2; +}; +""" + file_path = temp_project / "funcs.js" + file_path.write_text(original_source, encoding="utf-8") + + source = file_path.read_text(encoding="utf-8") + functions = js_support.discover_functions(source, file_path, filter_criteria=self.NO_EXPORT_FILTER) + assert len(functions) == 1 + func = functions[0] + assert func.function_name == "foo" + + optimized_code = """\ +var foo = function(x) { + return x << 1; +}; +""" + + result = js_support.replace_function(original_source, func, optimized_code) + + assert "return x << 1;" in result + assert js_support.validate_syntax(result) is True + + def test_replace_generator_function_expression(self, js_support, temp_project): + """Test replacing an exported const-assigned generator function expression.""" + original_source = """\ +export const gen = function*(n) { + for (let i = 0; i < n; i++) { + yield i; + } +}; +""" + file_path = temp_project / "generators.js" + file_path.write_text(original_source, encoding="utf-8") + + source = file_path.read_text(encoding="utf-8") + functions = js_support.discover_functions(source, file_path, filter_criteria=self.NO_EXPORT_FILTER) + assert len(functions) == 1 + func = functions[0] + assert func.function_name == "gen" + + optimized_code = """\ +export const gen = function*(n) { + let i = 0; + while (i < n) yield i++; +}; +""" + + result = js_support.replace_function(original_source, func, optimized_code) + + assert "while (i < n) yield i++;" in result + assert js_support.validate_syntax(result) is True + + def test_replace_arrow_function_multiline_declaration(self, js_support, temp_project): + """Test replacing an arrow function where the arrow is on a different line than const.""" + original_source = """\ +export const calculate = + (a, b) => { + return a + b; + }; +""" + file_path = temp_project / "calc.js" + file_path.write_text(original_source, encoding="utf-8") + + source = file_path.read_text(encoding="utf-8") + functions = js_support.discover_functions(source, file_path) + assert len(functions) == 1 + func = functions[0] + assert func.function_name == "calculate" + + optimized_code = """\ +export const calculate = + (a, b) => { + return (a + b) | 0; + }; +""" + + result = js_support.replace_function(original_source, func, optimized_code) + + assert "return (a + b) | 0;" in result + assert js_support.validate_syntax(result) is True + + def test_replace_async_arrow_function(self, js_support, temp_project): + """Test replacing an exported const-assigned async arrow function.""" + original_source = """\ +export const fetchData = async (url) => { + const response = await fetch(url); + return response.json(); +}; +""" + file_path = temp_project / "api.js" + file_path.write_text(original_source, encoding="utf-8") + + source = file_path.read_text(encoding="utf-8") + functions = js_support.discover_functions(source, file_path) + assert len(functions) == 1 + func = functions[0] + assert func.function_name == "fetchData" + + optimized_code = """\ +export const fetchData = async (url) => { + return (await fetch(url)).json(); +}; +""" + + result = js_support.replace_function(original_source, func, optimized_code) + + assert "return (await fetch(url)).json();" in result + assert js_support.validate_syntax(result) is True