mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
[Fix] Normalizer and expand its scope
This commit is contained in:
parent
5d872e845d
commit
353feab063
5 changed files with 289 additions and 43 deletions
|
|
@ -1,7 +1,6 @@
|
|||
"""JavaScript/TypeScript code normalizer using tree-sitter.
|
||||
|
||||
Not currently wired into JavaScriptSupport.normalize_code — kept as a
|
||||
ready-to-use upgrade path when AST-based JS deduplication is needed.
|
||||
Wired into JavaScriptSupport.normalize_code for AST-based JS deduplication.
|
||||
|
||||
The old CodeNormalizer ABC (deleted from base.py) is preserved below for reference.
|
||||
"""
|
||||
|
|
@ -236,8 +235,7 @@ def normalize_js_code(code: str, typescript: bool = False) -> str:
|
|||
Uses tree-sitter to parse and normalize variable names. Falls back to
|
||||
basic comment/whitespace stripping if tree-sitter is unavailable or parsing fails.
|
||||
|
||||
Not currently wired into JavaScriptSupport.normalize_code — kept as a
|
||||
ready-to-use upgrade path when AST-based JS deduplication is needed.
|
||||
Wired into JavaScriptSupport.normalize_code for AST-based JS deduplication.
|
||||
"""
|
||||
try:
|
||||
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage
|
||||
|
|
|
|||
|
|
@ -1207,20 +1207,29 @@ class JavaScriptSupport:
|
|||
return node
|
||||
|
||||
# Check function declarations
|
||||
if node.type in ("function_declaration", "function"):
|
||||
if node.type in (
|
||||
"function_declaration",
|
||||
"function",
|
||||
"generator_function_declaration",
|
||||
"generator_function",
|
||||
):
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
|
||||
if name == target_name:
|
||||
return node
|
||||
|
||||
# Check arrow functions assigned to variables
|
||||
if node.type == "lexical_declaration":
|
||||
# Check arrow functions and function expressions assigned to variables
|
||||
if node.type in ("lexical_declaration", "variable_declaration"):
|
||||
for child in node.children:
|
||||
if child.type == "variable_declarator":
|
||||
name_node = child.child_by_field_name("name")
|
||||
value_node = child.child_by_field_name("value")
|
||||
if name_node and value_node and value_node.type == "arrow_function":
|
||||
if (
|
||||
name_node
|
||||
and value_node
|
||||
and value_node.type in ("arrow_function", "function_expression", "generator_function")
|
||||
):
|
||||
name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
|
||||
if name == target_name:
|
||||
return value_node
|
||||
|
|
@ -1235,6 +1244,7 @@ class JavaScriptSupport:
|
|||
|
||||
func_node = find_function_node(tree.root_node, function_name)
|
||||
if not func_node:
|
||||
logger.debug("Could not find function '%s' in optimized code for body extraction", function_name)
|
||||
return None
|
||||
|
||||
# Find the body node
|
||||
|
|
@ -1295,14 +1305,21 @@ class JavaScriptSupport:
|
|||
if name == target_name and (node.start_point[0] + 1) == target_line:
|
||||
return node
|
||||
|
||||
if node.type == "lexical_declaration":
|
||||
if node.type in ("lexical_declaration", "variable_declaration"):
|
||||
for child in node.children:
|
||||
if child.type == "variable_declarator":
|
||||
name_node = child.child_by_field_name("name")
|
||||
value_node = child.child_by_field_name("value")
|
||||
if name_node and value_node and value_node.type == "arrow_function":
|
||||
if (
|
||||
name_node
|
||||
and value_node
|
||||
and value_node.type in ("arrow_function", "function_expression", "generator_function")
|
||||
):
|
||||
name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
|
||||
if name == target_name and (node.start_point[0] + 1) == target_line:
|
||||
if name == target_name and (
|
||||
(node.start_point[0] + 1) == target_line
|
||||
or (value_node.start_point[0] + 1) == target_line
|
||||
):
|
||||
return value_node
|
||||
|
||||
for child in node.children:
|
||||
|
|
@ -1686,26 +1703,14 @@ class JavaScriptSupport:
|
|||
return False
|
||||
|
||||
def normalize_code(self, source: str) -> str:
|
||||
"""Normalize JavaScript code for deduplication.
|
||||
"""Normalize JavaScript code for deduplication using tree-sitter."""
|
||||
from codeflash.languages.javascript.normalizer import normalize_js_code
|
||||
|
||||
Removes comments and normalizes whitespace.
|
||||
|
||||
Args:
|
||||
source: Source code to normalize.
|
||||
|
||||
Returns:
|
||||
Normalized source code.
|
||||
|
||||
"""
|
||||
# Simple normalization: remove extra whitespace
|
||||
# A full implementation would use tree-sitter to strip comments
|
||||
lines = source.splitlines()
|
||||
normalized_lines = []
|
||||
for line in lines:
|
||||
stripped = line.strip()
|
||||
if stripped and not stripped.startswith("//"):
|
||||
normalized_lines.append(stripped)
|
||||
return "\n".join(normalized_lines)
|
||||
try:
|
||||
is_ts = self.treesitter_language == TreeSitterLanguage.TYPESCRIPT
|
||||
return normalize_js_code(source, typescript=is_ts)
|
||||
except Exception:
|
||||
return source
|
||||
|
||||
def generate_concolic_tests(
|
||||
self, test_cfg: Any, project_root: Any, function_to_optimize: Any, function_to_optimize_ast: Any
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from codeflash.languages.javascript.normalizer import normalize_js_code
|
||||
from codeflash.languages.python.normalizer import normalize_python_code as normalize_code
|
||||
|
||||
|
||||
|
|
@ -133,3 +134,74 @@ def safe_divide(a, b):
|
|||
assert normalize_code(code9) == normalize_code(code10)
|
||||
|
||||
assert normalize_code(code9) != normalize_code(code8)
|
||||
|
||||
|
||||
# === JavaScript deduplication tests ===
|
||||
|
||||
|
||||
def test_js_deduplicate_same_logic_different_vars():
|
||||
code1 = """
|
||||
function process(items) {
|
||||
const result = [];
|
||||
for (const item of items) {
|
||||
result.push(item * 2);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
"""
|
||||
code2 = """
|
||||
function process(items) {
|
||||
const output = [];
|
||||
for (const val of items) {
|
||||
output.push(val * 2);
|
||||
}
|
||||
return output;
|
||||
}
|
||||
"""
|
||||
assert normalize_js_code(code1) == normalize_js_code(code2)
|
||||
|
||||
|
||||
def test_js_different_logic_not_deduplicated():
|
||||
code1 = """
|
||||
function compute(x) {
|
||||
return x + 1;
|
||||
}
|
||||
"""
|
||||
code2 = """
|
||||
function compute(x) {
|
||||
return x * 2;
|
||||
}
|
||||
"""
|
||||
assert normalize_js_code(code1) != normalize_js_code(code2)
|
||||
|
||||
|
||||
def test_js_deduplicate_whitespace_and_comments():
|
||||
code1 = """
|
||||
function add(a, b) {
|
||||
// fast path
|
||||
return a + b;
|
||||
}
|
||||
"""
|
||||
code2 = """
|
||||
function add(a, b) {
|
||||
/* optimized */
|
||||
return a + b;
|
||||
}
|
||||
"""
|
||||
assert normalize_js_code(code1) == normalize_js_code(code2)
|
||||
|
||||
|
||||
def test_ts_normalize():
|
||||
code1 = """
|
||||
function greet(name: string): string {
|
||||
const msg = "hello " + name;
|
||||
return msg;
|
||||
}
|
||||
"""
|
||||
code2 = """
|
||||
function greet(name: string): string {
|
||||
const result = "hello " + name;
|
||||
return result;
|
||||
}
|
||||
"""
|
||||
assert normalize_js_code(code1, typescript=True) == normalize_js_code(code2, typescript=True)
|
||||
|
|
|
|||
|
|
@ -443,10 +443,10 @@ function add(a, b {
|
|||
|
||||
|
||||
class TestNormalizeCode:
|
||||
"""Tests for normalize_code method."""
|
||||
"""Tests for normalize_code method using tree-sitter normalizer."""
|
||||
|
||||
def test_removes_comments(self, js_support):
|
||||
"""Test that single-line comments are removed."""
|
||||
"""Test that comments are absent from normalized output."""
|
||||
code = """
|
||||
function add(a, b) {
|
||||
// Add two numbers
|
||||
|
|
@ -455,19 +455,43 @@ function add(a, b) {
|
|||
"""
|
||||
normalized = js_support.normalize_code(code)
|
||||
assert "// Add two numbers" not in normalized
|
||||
assert "return a + b" in normalized
|
||||
assert "Add two numbers" not in normalized
|
||||
|
||||
def test_preserves_functionality(self, js_support):
|
||||
"""Test that code functionality is preserved."""
|
||||
code = """
|
||||
function add(a, b) {
|
||||
// Comment
|
||||
return a + b;
|
||||
def test_same_logic_different_vars_are_equal(self, js_support):
|
||||
"""Test that two functions with same logic but different variable names normalize identically."""
|
||||
code1 = """
|
||||
function process(items) {
|
||||
const result = [];
|
||||
for (const item of items) {
|
||||
result.push(item * 2);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
"""
|
||||
normalized = js_support.normalize_code(code)
|
||||
assert "function add" in normalized
|
||||
assert "return" in normalized
|
||||
code2 = """
|
||||
function process(items) {
|
||||
const output = [];
|
||||
for (const val of items) {
|
||||
output.push(val * 2);
|
||||
}
|
||||
return output;
|
||||
}
|
||||
"""
|
||||
assert js_support.normalize_code(code1) == js_support.normalize_code(code2)
|
||||
|
||||
def test_different_logic_not_equal(self, js_support):
|
||||
"""Test that two functions with different logic produce different normalized forms."""
|
||||
code1 = """
|
||||
function compute(x) {
|
||||
return x + 1;
|
||||
}
|
||||
"""
|
||||
code2 = """
|
||||
function compute(x) {
|
||||
return x * 2;
|
||||
}
|
||||
"""
|
||||
assert js_support.normalize_code(code1) != js_support.normalize_code(code2)
|
||||
|
||||
|
||||
class TestExtractCodeContext:
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.base import FunctionFilterCriteria, Language
|
||||
from codeflash.languages.code_replacer import replace_function_definitions_for_language
|
||||
from codeflash.languages.current import set_current_language
|
||||
from codeflash.languages.javascript.module_system import (
|
||||
|
|
@ -2264,3 +2264,150 @@ export function processNode(node: TreeNode, space: NodeSpace): number {
|
|||
assert "// Optimized" in result
|
||||
|
||||
assert ts_support.validate_syntax(result) is True
|
||||
|
||||
|
||||
class TestVariableAssignedFunctionReplacement:
|
||||
"""Tests for replacing functions assigned to variables (function expressions, var declarations, etc.)."""
|
||||
|
||||
NO_EXPORT_FILTER = FunctionFilterCriteria(require_export=False, require_return=False)
|
||||
|
||||
def test_replace_function_expression_body(self, js_support, temp_project):
|
||||
"""Test replacing an exported const-assigned function expression."""
|
||||
original_source = """\
|
||||
export const foo = function(x) {
|
||||
return x + 1;
|
||||
};
|
||||
"""
|
||||
file_path = temp_project / "funcs.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
assert len(functions) == 1
|
||||
func = functions[0]
|
||||
assert func.function_name == "foo"
|
||||
|
||||
optimized_code = """\
|
||||
export const foo = function(x) {
|
||||
return (x + 1) | 0;
|
||||
};
|
||||
"""
|
||||
|
||||
result = js_support.replace_function(original_source, func, optimized_code)
|
||||
|
||||
assert "return (x + 1) | 0;" in result
|
||||
assert js_support.validate_syntax(result) is True
|
||||
|
||||
def test_replace_function_expression_with_var(self, js_support, temp_project):
|
||||
"""Test replacing a var-assigned function expression (non-exported, e.g. CommonJS)."""
|
||||
original_source = """\
|
||||
var foo = function(x) {
|
||||
return x * 2;
|
||||
};
|
||||
"""
|
||||
file_path = temp_project / "funcs.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path, filter_criteria=self.NO_EXPORT_FILTER)
|
||||
assert len(functions) == 1
|
||||
func = functions[0]
|
||||
assert func.function_name == "foo"
|
||||
|
||||
optimized_code = """\
|
||||
var foo = function(x) {
|
||||
return x << 1;
|
||||
};
|
||||
"""
|
||||
|
||||
result = js_support.replace_function(original_source, func, optimized_code)
|
||||
|
||||
assert "return x << 1;" in result
|
||||
assert js_support.validate_syntax(result) is True
|
||||
|
||||
def test_replace_generator_function_expression(self, js_support, temp_project):
|
||||
"""Test replacing an exported const-assigned generator function expression."""
|
||||
original_source = """\
|
||||
export const gen = function*(n) {
|
||||
for (let i = 0; i < n; i++) {
|
||||
yield i;
|
||||
}
|
||||
};
|
||||
"""
|
||||
file_path = temp_project / "generators.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path, filter_criteria=self.NO_EXPORT_FILTER)
|
||||
assert len(functions) == 1
|
||||
func = functions[0]
|
||||
assert func.function_name == "gen"
|
||||
|
||||
optimized_code = """\
|
||||
export const gen = function*(n) {
|
||||
let i = 0;
|
||||
while (i < n) yield i++;
|
||||
};
|
||||
"""
|
||||
|
||||
result = js_support.replace_function(original_source, func, optimized_code)
|
||||
|
||||
assert "while (i < n) yield i++;" in result
|
||||
assert js_support.validate_syntax(result) is True
|
||||
|
||||
def test_replace_arrow_function_multiline_declaration(self, js_support, temp_project):
|
||||
"""Test replacing an arrow function where the arrow is on a different line than const."""
|
||||
original_source = """\
|
||||
export const calculate =
|
||||
(a, b) => {
|
||||
return a + b;
|
||||
};
|
||||
"""
|
||||
file_path = temp_project / "calc.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
assert len(functions) == 1
|
||||
func = functions[0]
|
||||
assert func.function_name == "calculate"
|
||||
|
||||
optimized_code = """\
|
||||
export const calculate =
|
||||
(a, b) => {
|
||||
return (a + b) | 0;
|
||||
};
|
||||
"""
|
||||
|
||||
result = js_support.replace_function(original_source, func, optimized_code)
|
||||
|
||||
assert "return (a + b) | 0;" in result
|
||||
assert js_support.validate_syntax(result) is True
|
||||
|
||||
def test_replace_async_arrow_function(self, js_support, temp_project):
|
||||
"""Test replacing an exported const-assigned async arrow function."""
|
||||
original_source = """\
|
||||
export const fetchData = async (url) => {
|
||||
const response = await fetch(url);
|
||||
return response.json();
|
||||
};
|
||||
"""
|
||||
file_path = temp_project / "api.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
assert len(functions) == 1
|
||||
func = functions[0]
|
||||
assert func.function_name == "fetchData"
|
||||
|
||||
optimized_code = """\
|
||||
export const fetchData = async (url) => {
|
||||
return (await fetch(url)).json();
|
||||
};
|
||||
"""
|
||||
|
||||
result = js_support.replace_function(original_source, func, optimized_code)
|
||||
|
||||
assert "return (await fetch(url)).json();" in result
|
||||
assert js_support.validate_syntax(result) is True
|
||||
|
|
|
|||
Loading…
Reference in a new issue