[Fix] Normalizer and expand its scope

This commit is contained in:
Sarthak Agarwal 2026-03-06 21:31:24 +05:30
parent 5d872e845d
commit 353feab063
5 changed files with 289 additions and 43 deletions

View file

@ -1,7 +1,6 @@
"""JavaScript/TypeScript code normalizer using tree-sitter.
Not currently wired into JavaScriptSupport.normalize_code kept as a
ready-to-use upgrade path when AST-based JS deduplication is needed.
Wired into JavaScriptSupport.normalize_code for AST-based JS deduplication.
The old CodeNormalizer ABC (deleted from base.py) is preserved below for reference.
"""
@ -236,8 +235,7 @@ def normalize_js_code(code: str, typescript: bool = False) -> str:
Uses tree-sitter to parse and normalize variable names. Falls back to
basic comment/whitespace stripping if tree-sitter is unavailable or parsing fails.
Not currently wired into JavaScriptSupport.normalize_code kept as a
ready-to-use upgrade path when AST-based JS deduplication is needed.
Wired into JavaScriptSupport.normalize_code for AST-based JS deduplication.
"""
try:
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage

View file

@ -1207,20 +1207,29 @@ class JavaScriptSupport:
return node
# Check function declarations
if node.type in ("function_declaration", "function"):
if node.type in (
"function_declaration",
"function",
"generator_function_declaration",
"generator_function",
):
name_node = node.child_by_field_name("name")
if name_node:
name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
if name == target_name:
return node
# Check arrow functions assigned to variables
if node.type == "lexical_declaration":
# Check arrow functions and function expressions assigned to variables
if node.type in ("lexical_declaration", "variable_declaration"):
for child in node.children:
if child.type == "variable_declarator":
name_node = child.child_by_field_name("name")
value_node = child.child_by_field_name("value")
if name_node and value_node and value_node.type == "arrow_function":
if (
name_node
and value_node
and value_node.type in ("arrow_function", "function_expression", "generator_function")
):
name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
if name == target_name:
return value_node
@ -1235,6 +1244,7 @@ class JavaScriptSupport:
func_node = find_function_node(tree.root_node, function_name)
if not func_node:
logger.debug("Could not find function '%s' in optimized code for body extraction", function_name)
return None
# Find the body node
@ -1295,14 +1305,21 @@ class JavaScriptSupport:
if name == target_name and (node.start_point[0] + 1) == target_line:
return node
if node.type == "lexical_declaration":
if node.type in ("lexical_declaration", "variable_declaration"):
for child in node.children:
if child.type == "variable_declarator":
name_node = child.child_by_field_name("name")
value_node = child.child_by_field_name("value")
if name_node and value_node and value_node.type == "arrow_function":
if (
name_node
and value_node
and value_node.type in ("arrow_function", "function_expression", "generator_function")
):
name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
if name == target_name and (node.start_point[0] + 1) == target_line:
if name == target_name and (
(node.start_point[0] + 1) == target_line
or (value_node.start_point[0] + 1) == target_line
):
return value_node
for child in node.children:
@ -1686,26 +1703,14 @@ class JavaScriptSupport:
return False
def normalize_code(self, source: str) -> str:
"""Normalize JavaScript code for deduplication.
"""Normalize JavaScript code for deduplication using tree-sitter."""
from codeflash.languages.javascript.normalizer import normalize_js_code
Removes comments and normalizes whitespace.
Args:
source: Source code to normalize.
Returns:
Normalized source code.
"""
# Simple normalization: remove extra whitespace
# A full implementation would use tree-sitter to strip comments
lines = source.splitlines()
normalized_lines = []
for line in lines:
stripped = line.strip()
if stripped and not stripped.startswith("//"):
normalized_lines.append(stripped)
return "\n".join(normalized_lines)
try:
is_ts = self.treesitter_language == TreeSitterLanguage.TYPESCRIPT
return normalize_js_code(source, typescript=is_ts)
except Exception:
return source
def generate_concolic_tests(
self, test_cfg: Any, project_root: Any, function_to_optimize: Any, function_to_optimize_ast: Any

View file

@ -1,3 +1,4 @@
from codeflash.languages.javascript.normalizer import normalize_js_code
from codeflash.languages.python.normalizer import normalize_python_code as normalize_code
@ -133,3 +134,74 @@ def safe_divide(a, b):
assert normalize_code(code9) == normalize_code(code10)
assert normalize_code(code9) != normalize_code(code8)
# === JavaScript deduplication tests ===
def test_js_deduplicate_same_logic_different_vars():
code1 = """
function process(items) {
const result = [];
for (const item of items) {
result.push(item * 2);
}
return result;
}
"""
code2 = """
function process(items) {
const output = [];
for (const val of items) {
output.push(val * 2);
}
return output;
}
"""
assert normalize_js_code(code1) == normalize_js_code(code2)
def test_js_different_logic_not_deduplicated():
code1 = """
function compute(x) {
return x + 1;
}
"""
code2 = """
function compute(x) {
return x * 2;
}
"""
assert normalize_js_code(code1) != normalize_js_code(code2)
def test_js_deduplicate_whitespace_and_comments():
code1 = """
function add(a, b) {
// fast path
return a + b;
}
"""
code2 = """
function add(a, b) {
/* optimized */
return a + b;
}
"""
assert normalize_js_code(code1) == normalize_js_code(code2)
def test_ts_normalize():
code1 = """
function greet(name: string): string {
const msg = "hello " + name;
return msg;
}
"""
code2 = """
function greet(name: string): string {
const result = "hello " + name;
return result;
}
"""
assert normalize_js_code(code1, typescript=True) == normalize_js_code(code2, typescript=True)

View file

@ -443,10 +443,10 @@ function add(a, b {
class TestNormalizeCode:
"""Tests for normalize_code method."""
"""Tests for normalize_code method using tree-sitter normalizer."""
def test_removes_comments(self, js_support):
"""Test that single-line comments are removed."""
"""Test that comments are absent from normalized output."""
code = """
function add(a, b) {
// Add two numbers
@ -455,19 +455,43 @@ function add(a, b) {
"""
normalized = js_support.normalize_code(code)
assert "// Add two numbers" not in normalized
assert "return a + b" in normalized
assert "Add two numbers" not in normalized
def test_preserves_functionality(self, js_support):
"""Test that code functionality is preserved."""
code = """
function add(a, b) {
// Comment
return a + b;
def test_same_logic_different_vars_are_equal(self, js_support):
"""Test that two functions with same logic but different variable names normalize identically."""
code1 = """
function process(items) {
const result = [];
for (const item of items) {
result.push(item * 2);
}
return result;
}
"""
normalized = js_support.normalize_code(code)
assert "function add" in normalized
assert "return" in normalized
code2 = """
function process(items) {
const output = [];
for (const val of items) {
output.push(val * 2);
}
return output;
}
"""
assert js_support.normalize_code(code1) == js_support.normalize_code(code2)
def test_different_logic_not_equal(self, js_support):
"""Test that two functions with different logic produce different normalized forms."""
code1 = """
function compute(x) {
return x + 1;
}
"""
code2 = """
function compute(x) {
return x * 2;
}
"""
assert js_support.normalize_code(code1) != js_support.normalize_code(code2)
class TestExtractCodeContext:

View file

@ -15,7 +15,7 @@ from pathlib import Path
import pytest
from codeflash.languages.base import Language
from codeflash.languages.base import FunctionFilterCriteria, Language
from codeflash.languages.code_replacer import replace_function_definitions_for_language
from codeflash.languages.current import set_current_language
from codeflash.languages.javascript.module_system import (
@ -2264,3 +2264,150 @@ export function processNode(node: TreeNode, space: NodeSpace): number {
assert "// Optimized" in result
assert ts_support.validate_syntax(result) is True
class TestVariableAssignedFunctionReplacement:
"""Tests for replacing functions assigned to variables (function expressions, var declarations, etc.)."""
NO_EXPORT_FILTER = FunctionFilterCriteria(require_export=False, require_return=False)
def test_replace_function_expression_body(self, js_support, temp_project):
"""Test replacing an exported const-assigned function expression."""
original_source = """\
export const foo = function(x) {
return x + 1;
};
"""
file_path = temp_project / "funcs.js"
file_path.write_text(original_source, encoding="utf-8")
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
assert len(functions) == 1
func = functions[0]
assert func.function_name == "foo"
optimized_code = """\
export const foo = function(x) {
return (x + 1) | 0;
};
"""
result = js_support.replace_function(original_source, func, optimized_code)
assert "return (x + 1) | 0;" in result
assert js_support.validate_syntax(result) is True
def test_replace_function_expression_with_var(self, js_support, temp_project):
"""Test replacing a var-assigned function expression (non-exported, e.g. CommonJS)."""
original_source = """\
var foo = function(x) {
return x * 2;
};
"""
file_path = temp_project / "funcs.js"
file_path.write_text(original_source, encoding="utf-8")
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path, filter_criteria=self.NO_EXPORT_FILTER)
assert len(functions) == 1
func = functions[0]
assert func.function_name == "foo"
optimized_code = """\
var foo = function(x) {
return x << 1;
};
"""
result = js_support.replace_function(original_source, func, optimized_code)
assert "return x << 1;" in result
assert js_support.validate_syntax(result) is True
def test_replace_generator_function_expression(self, js_support, temp_project):
"""Test replacing an exported const-assigned generator function expression."""
original_source = """\
export const gen = function*(n) {
for (let i = 0; i < n; i++) {
yield i;
}
};
"""
file_path = temp_project / "generators.js"
file_path.write_text(original_source, encoding="utf-8")
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path, filter_criteria=self.NO_EXPORT_FILTER)
assert len(functions) == 1
func = functions[0]
assert func.function_name == "gen"
optimized_code = """\
export const gen = function*(n) {
let i = 0;
while (i < n) yield i++;
};
"""
result = js_support.replace_function(original_source, func, optimized_code)
assert "while (i < n) yield i++;" in result
assert js_support.validate_syntax(result) is True
def test_replace_arrow_function_multiline_declaration(self, js_support, temp_project):
"""Test replacing an arrow function where the arrow is on a different line than const."""
original_source = """\
export const calculate =
(a, b) => {
return a + b;
};
"""
file_path = temp_project / "calc.js"
file_path.write_text(original_source, encoding="utf-8")
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
assert len(functions) == 1
func = functions[0]
assert func.function_name == "calculate"
optimized_code = """\
export const calculate =
(a, b) => {
return (a + b) | 0;
};
"""
result = js_support.replace_function(original_source, func, optimized_code)
assert "return (a + b) | 0;" in result
assert js_support.validate_syntax(result) is True
def test_replace_async_arrow_function(self, js_support, temp_project):
"""Test replacing an exported const-assigned async arrow function."""
original_source = """\
export const fetchData = async (url) => {
const response = await fetch(url);
return response.json();
};
"""
file_path = temp_project / "api.js"
file_path.write_text(original_source, encoding="utf-8")
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
assert len(functions) == 1
func = functions[0]
assert func.function_name == "fetchData"
optimized_code = """\
export const fetchData = async (url) => {
return (await fetch(url)).json();
};
"""
result = js_support.replace_function(original_source, func, optimized_code)
assert "return (await fetch(url)).json();" in result
assert js_support.validate_syntax(result) is True