deduplicate optimizations better
Signed-off-by: Saurabh Misra <misra.saurabh1@gmail.com>
This commit is contained in:
parent
2802ae6c25
commit
95a38d3f6e
3 changed files with 372 additions and 1 deletions
235
codeflash/code_utils/deduplicate_code.py
Normal file
235
codeflash/code_utils/deduplicate_code.py
Normal file
|
|
@ -0,0 +1,235 @@
|
|||
import ast
|
||||
import hashlib
|
||||
from typing import Dict, Set
|
||||
|
||||
|
||||
class VariableNormalizer(ast.NodeTransformer):
    """Rename local variables to canonical names ``var_0``, ``var_1``, ...

    Two snippets that differ only in the names of their local variables
    produce the same tree after this transformer has visited them.

    Names that are deliberately left untouched: function and class names,
    function/lambda parameters, builtins, imported names, and names declared
    ``global`` or ``nonlocal`` in the current scope.

    The transformer mutates the tree in place and is single-use: create a new
    instance for every tree that is normalized.
    """

    def __init__(self):
        # In an imported module ``__builtins__`` is a plain dict (CPython
        # implementation detail), so ``dir(__builtins__)`` would yield dict
        # method names instead of builtin names.  Query the module directly.
        import builtins

        self.var_counter = 0  # index used for the next ``var_N`` alias
        self.var_mapping: Dict[str, str] = {}  # original name -> alias
        self.scope_stack = []  # saved state of the enclosing scopes
        self.builtins = set(dir(builtins))
        self.imports: Set[str] = set()
        self.global_vars: Set[str] = set()
        self.nonlocal_vars: Set[str] = set()
        self.parameters: Set[str] = set()  # parameters visible in this scope

    def enter_scope(self):
        """Enter a new scope (function/class/lambda) by snapshotting state."""
        self.scope_stack.append(
            {
                "var_mapping": dict(self.var_mapping),
                "var_counter": self.var_counter,
                "parameters": set(self.parameters),
                # global/nonlocal declarations only apply to the scope that
                # contains them, so they are restored on exit as well.
                "global_vars": set(self.global_vars),
                "nonlocal_vars": set(self.nonlocal_vars),
            }
        )

    def exit_scope(self):
        """Exit the current scope and restore the parent scope's state."""
        if self.scope_stack:
            scope = self.scope_stack.pop()
            self.var_mapping = scope["var_mapping"]
            self.var_counter = scope["var_counter"]
            self.parameters = scope["parameters"]
            self.global_vars = scope["global_vars"]
            self.nonlocal_vars = scope["nonlocal_vars"]

    def _is_protected(self, name: str) -> bool:
        """Return True when ``name`` must keep its original spelling."""
        return (
            name in self.builtins
            or name in self.imports
            or name in self.global_vars
            or name in self.nonlocal_vars
            or name in self.parameters
        )

    def get_normalized_name(self, name: str) -> str:
        """Return the alias for ``name``, creating one on first sight.

        Protected names (builtins, imports, globals, nonlocals, parameters)
        are returned unchanged.
        """
        if self._is_protected(name):
            return name

        if name not in self.var_mapping:
            self.var_mapping[name] = f"var_{self.var_counter}"
            self.var_counter += 1
        return self.var_mapping[name]

    def visit_Import(self, node):
        """Track names bound by ``import a.b`` / ``import a as b``."""
        for alias in node.names:
            name = alias.asname if alias.asname else alias.name
            # ``import a.b`` binds only the top-level package name ``a``.
            self.imports.add(name.split(".")[0])
        return node

    def visit_ImportFrom(self, node):
        """Track names bound by ``from m import a`` / ``from m import a as b``."""
        for alias in node.names:
            name = alias.asname if alias.asname else alias.name
            self.imports.add(name)
        return node

    def visit_Global(self, node):
        """Track names declared ``global`` in the current scope."""
        for name in node.names:
            self.global_vars.add(name)
        return node

    def visit_Nonlocal(self, node):
        """Track names declared ``nonlocal`` in the current scope."""
        for name in node.names:
            self.nonlocal_vars.add(name)
        return node

    def _register_parameters(self, arguments):
        """Record every parameter of an ``ast.arguments`` node as protected."""
        positional_only = getattr(arguments, "posonlyargs", [])
        for arg in (*positional_only, *arguments.args, *arguments.kwonlyargs):
            self.parameters.add(arg.arg)
        for variadic in (arguments.vararg, arguments.kwarg):
            if variadic:
                self.parameters.add(variadic.arg)

    def visit_FunctionDef(self, node):
        """Process a function, keeping its name and parameters unchanged."""
        self.enter_scope()
        self._register_parameters(node.args)
        node = self.generic_visit(node)
        self.exit_scope()
        return node

    def visit_AsyncFunctionDef(self, node):
        """Handle async functions the same way as regular functions."""
        return self.visit_FunctionDef(node)

    def visit_Lambda(self, node):
        """Process a lambda, keeping its parameters unchanged."""
        self.enter_scope()
        self._register_parameters(node.args)
        node = self.generic_visit(node)
        self.exit_scope()
        return node

    def visit_ClassDef(self, node):
        """Process a class body in its own scope, keeping the class name."""
        self.enter_scope()
        node = self.generic_visit(node)
        self.exit_scope()
        return node

    def visit_Name(self, node):
        """Normalize a ``Name`` node according to its context."""
        if isinstance(node.ctx, (ast.Store, ast.Del)):
            # Binding/deleting: allocate (or reuse) an alias unless protected.
            node.id = self.get_normalized_name(node.id)
        elif isinstance(node.ctx, ast.Load):
            # Reading: only rewrite names that already have an alias, and
            # never a name that is protected in the current scope (e.g. a
            # parameter shadowing an outer local of the same name).
            if node.id in self.var_mapping and not self._is_protected(node.id):
                node.id = self.var_mapping[node.id]
        return node

    def visit_ExceptHandler(self, node):
        """Normalize the ``as <name>`` variable of an exception handler."""
        if node.name:
            node.name = self.get_normalized_name(node.name)
        return self.generic_visit(node)

    def _visit_comprehension_expr(self, node):
        """Normalize a list/set/dict comprehension or generator expression.

        The ``for`` clauses are visited *before* the element expression so the
        element sees the aliases of the loop targets; the default field order
        would visit the element first and leave its names untouched.  Aliases
        created here are discarded afterwards because comprehension targets
        do not leak into the enclosing scope.
        """
        saved_mapping = dict(self.var_mapping)
        saved_counter = self.var_counter

        node.generators = [self.visit(clause) for clause in node.generators]
        if isinstance(node, ast.DictComp):
            node.key = self.visit(node.key)
            node.value = self.visit(node.value)
        else:
            node.elt = self.visit(node.elt)

        self.var_mapping = saved_mapping
        self.var_counter = saved_counter
        return node

    visit_ListComp = _visit_comprehension_expr
    visit_SetComp = _visit_comprehension_expr
    visit_GeneratorExp = _visit_comprehension_expr
    visit_DictComp = _visit_comprehension_expr

    def visit_comprehension(self, node):
        """Normalize one ``for ... in ... if ...`` clause of a comprehension.

        Scope save/restore happens around the whole comprehension expression,
        not per clause, so the targets stay visible to later clauses and to
        the element expression.
        """
        return self.generic_visit(node)

    def visit_For(self, node):
        """Handle ``for`` loops; the target is normalized via ``visit_Name``."""
        return self.generic_visit(node)

    def visit_With(self, node):
        """Handle ``with`` statements; ``as`` targets go through ``visit_Name``."""
        return self.generic_visit(node)
|
||||
|
||||
|
||||
def normalize_code(code: str, remove_docstrings: bool = True) -> str:
    """Return a canonical source string for ``code``.

    The code is parsed, optionally stripped of docstrings, has its local
    variable names normalized by ``VariableNormalizer`` and is then unparsed
    again, which also discards comments and formatting differences.
    Function names, class names, and parameters are preserved.

    Args:
        code: Python source code as string
        remove_docstrings: Whether to remove docstrings

    Returns:
        Normalized code as string

    Raises:
        ValueError: if ``code`` is not syntactically valid Python.

    """
    # Parsing is the only step that can report a syntax error.
    try:
        module = ast.parse(code)
    except SyntaxError as e:
        msg = f"Invalid Python syntax: {e}"
        raise ValueError(msg) from e

    if remove_docstrings:
        remove_docstrings_from_ast(module)

    canonical_tree = VariableNormalizer().visit(module)

    # Nodes created or moved during normalization may lack position info.
    ast.fix_missing_locations(canonical_tree)

    return ast.unparse(canonical_tree)
|
||||
|
||||
|
||||
def remove_docstrings_from_ast(node):
    """Remove docstrings from a tree in place.

    Every module, class, and (async) function reachable from ``node`` whose
    first statement is a bare string constant has that statement dropped.

    A function or class whose body consisted *only* of a docstring gets a
    ``pass`` statement instead, because an empty body is not a valid AST and
    would unparse to code that cannot be parsed again.  A module is allowed
    to end up with an empty body.

    Args:
        node: root ``ast.AST`` node to clean; it is mutated.

    Returns:
        None.
    """
    docstring_owners = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)

    # ast.walk is iterative, so deeply nested code cannot hit the recursion limit here.
    for current_node in ast.walk(node):
        if not isinstance(current_node, docstring_owners):
            continue

        body = current_node.body
        if not body:
            continue

        first = body[0]
        is_docstring = (
            isinstance(first, ast.Expr)
            and isinstance(first.value, ast.Constant)
            and isinstance(first.value.value, str)
        )
        if not is_docstring:
            continue

        remaining = body[1:]
        if not remaining and not isinstance(current_node, ast.Module):
            # Keep the definition syntactically valid.
            remaining = [ast.Pass()]
        current_node.body = remaining
|
||||
|
||||
|
||||
def get_code_fingerprint(code: str) -> str:
    """Compute a stable fingerprint of ``code`` after normalization.

    Args:
        code: Python source code

    Returns:
        Hex-encoded SHA-256 digest of the normalized source text

    """
    hasher = hashlib.sha256()
    hasher.update(normalize_code(code).encode())
    return hasher.hexdigest()
|
||||
|
||||
|
||||
def are_codes_duplicate(code1: str, code2: str) -> bool:
    """Tell whether two code segments normalize to the same source text.

    Args:
        code1: First code segment
        code2: Second code segment

    Returns:
        True if codes are structurally identical (ignoring local variable names);
        False if they differ or if either segment fails to normalize

    """
    try:
        return normalize_code(code1) == normalize_code(code2)
    except Exception:
        # Best effort: code that cannot be normalized is never a duplicate.
        return False
|
||||
|
|
@ -48,6 +48,7 @@ from codeflash.code_utils.config_consts import (
|
|||
REPEAT_OPTIMIZATION_PROBABILITY,
|
||||
TOTAL_LOOPING_TIME,
|
||||
)
|
||||
from codeflash.code_utils.deduplicate_code import normalize_code
|
||||
from codeflash.code_utils.edit_generated_tests import (
|
||||
add_runtime_comments_to_generated_tests,
|
||||
remove_functions_from_generated_tests,
|
||||
|
|
@ -519,7 +520,7 @@ class FunctionOptimizer:
|
|||
)
|
||||
continue
|
||||
# check if this code has been evaluated before by checking the ast normalized code string
|
||||
normalized_code = ast.unparse(ast.parse(candidate.source_code.flat.strip()))
|
||||
normalized_code = normalize_code(candidate.source_code.flat.strip())
|
||||
if normalized_code in ast_code_to_id:
|
||||
logger.info(
|
||||
"Current candidate has been encountered before in testing, Skipping optimization candidate."
|
||||
|
|
|
|||
135
tests/test_code_deduplication.py
Normal file
135
tests/test_code_deduplication.py
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
from codeflash.code_utils.deduplicate_code import are_codes_duplicate, normalize_code
|
||||
|
||||
|
||||
def test_deduplicate1():
    """Check that normalization ignores local names but nothing else."""
    # Example 1: same function and parameter names, different local variable
    # names, plus a docstring in one and a comment in the other (should match).
    code1 = """
def compute_sum(numbers):
    '''Calculate sum of numbers'''
    total = 0
    for num in numbers:
        total += num
    return total
"""

    code2 = """
def compute_sum(numbers):
    # This computes the sum
    result = 0
    for value in numbers:
        result += value
    return result
"""

    assert normalize_code(code1) == normalize_code(code2)
    assert are_codes_duplicate(code1, code2)

    # Example 2: same function and parameter names, different local variables (should match)
    code3 = """
def calculate_sum(numbers):
    accumulator = 0
    for item in numbers:
        accumulator += item
    return accumulator
"""

    code4 = """
def calculate_sum(numbers):
    total = 0
    for num in numbers:
        total += num
    return total
"""

    assert normalize_code(code3) == normalize_code(code4)
    assert are_codes_duplicate(code3, code4)

    # Example 3: nested functions and classes (names preserved, locals normalized)
    code5 = """
class DataProcessor:
    def __init__(self, data):
        self.data = data

    def process(self):
        def helper(item):
            temp = item * 2
            return temp

        results = []
        for element in self.data:
            results.append(helper(element))
        return results
"""

    code6 = """
class DataProcessor:
    def __init__(self, data):
        self.data = data

    def process(self):
        def helper(item):
            x = item * 2
            return x

        output = []
        for thing in self.data:
            output.append(helper(thing))
        return output
"""

    assert normalize_code(code5) == normalize_code(code6)

    # Example 4: imports and built-ins are preserved; a real change in the
    # computation (code85 multiplies by 2) must NOT be treated as a duplicate.
    code7 = """
import math

def calculate_circle_area(radius):
    pi_value = math.pi
    area = pi_value * radius ** 2
    return area
"""

    code8 = """
import math

def calculate_circle_area(radius):
    constant = math.pi
    result = constant * radius ** 2
    return result
"""
    code85 = """
import math

def calculate_circle_area(radius):
    constant = math.pi
    result = constant *2 * radius ** 2
    return result
"""

    assert normalize_code(code7) == normalize_code(code8)
    assert normalize_code(code8) != normalize_code(code85)

    # Example 5: exception handling (the ``as`` name is a local and is normalized)
    code9 = """
def safe_divide(a, b):
    try:
        result = a / b
        return result
    except ZeroDivisionError as e:
        error_msg = str(e)
        return None
"""

    code10 = """
def safe_divide(a, b):
    try:
        output = a / b
        return output
    except ZeroDivisionError as exc:
        message = str(exc)
        return None
"""
    assert normalize_code(code9) == normalize_code(code10)

    # Unrelated functions must never normalize to the same text.
    assert normalize_code(code9) != normalize_code(code8)
|
||||
Loading…
Reference in a new issue