deduplicate optimizations better

Signed-off-by: Saurabh Misra <misra.saurabh1@gmail.com>
This commit is contained in:
Saurabh Misra 2025-09-13 16:03:52 -07:00
parent 2802ae6c25
commit 95a38d3f6e
3 changed files with 372 additions and 1 deletions

View file

@ -0,0 +1,235 @@
import ast
import hashlib
from typing import Dict, Set
class VariableNormalizer(ast.NodeTransformer):
    """Rename local variables in an AST to canonical names (``var_0``, ``var_1``, ...).

    Function names, class names, parameters, built-in names, imported names and
    names declared ``global``/``nonlocal`` are left untouched. Canonical names
    are handed out in first-assignment order, so two snippets that differ only
    in their local variable names unparse to the same text.
    """

    def __init__(self):
        # Imported here so the class does not rely on a file-level import.
        import builtins

        self.var_counter = 0
        self.var_mapping: Dict[str, str] = {}
        self.scope_stack = []
        # Use the ``builtins`` module rather than ``__builtins__``: inside an
        # imported module CPython may bind ``__builtins__`` to a plain dict, in
        # which case ``dir(__builtins__)`` lists dict methods ("items", "keys",
        # ...) instead of the built-in names.
        self.builtins = set(dir(builtins))
        self.imports: Set[str] = set()
        # NOTE(review): global/nonlocal declarations are accumulated for the
        # whole tree; they are not restored when a scope is left.
        self.global_vars: Set[str] = set()
        self.nonlocal_vars: Set[str] = set()
        self.parameters: Set[str] = set()  # Track function parameters

    def enter_scope(self):
        """Push a snapshot of the current mapping, counter and parameter set.

        The new scope keeps working on the inherited state; the snapshot is
        only used to undo the scope's changes in ``exit_scope``.
        """
        self.scope_stack.append(
            {"var_mapping": dict(self.var_mapping), "var_counter": self.var_counter, "parameters": set(self.parameters)}
        )

    def exit_scope(self):
        """Restore the state saved by the matching ``enter_scope`` call, if any."""
        if self.scope_stack:
            scope = self.scope_stack.pop()
            self.var_mapping = scope["var_mapping"]
            self.var_counter = scope["var_counter"]
            self.parameters = scope["parameters"]

    def _is_preserved(self, name: str) -> bool:
        """Return True when ``name`` must keep its original spelling."""
        return (
            name in self.builtins
            or name in self.imports
            or name in self.global_vars
            or name in self.nonlocal_vars
            or name in self.parameters
        )

    def _track_parameters(self, args: ast.arguments) -> None:
        """Record every parameter name of a ``def``/``lambda`` signature."""
        # Positional-only parameters live in their own list and must be
        # included, otherwise ``def f(a, /)`` would have ``a`` renamed.
        for arg in (*args.posonlyargs, *args.args, *args.kwonlyargs):
            self.parameters.add(arg.arg)
        for variadic in (args.vararg, args.kwarg):
            if variadic is not None:
                self.parameters.add(variadic.arg)

    def get_normalized_name(self, name: str) -> str:
        """Return the canonical name for ``name``, creating one if needed.

        Preserved names (built-ins, imports, globals, nonlocals, parameters)
        are returned unchanged.
        """
        if self._is_preserved(name):
            return name
        # Only normalize local variables
        if name not in self.var_mapping:
            self.var_mapping[name] = f"var_{self.var_counter}"
            self.var_counter += 1
        return self.var_mapping[name]

    def visit_Import(self, node):
        """Track names bound by ``import a.b`` / ``import a as b``."""
        for alias in node.names:
            name = alias.asname if alias.asname else alias.name
            # ``import a.b`` binds only the top-level package name ``a``.
            self.imports.add(name.split(".")[0])
        return node

    def visit_ImportFrom(self, node):
        """Track names bound by ``from m import a`` / ``from m import a as b``."""
        for alias in node.names:
            name = alias.asname if alias.asname else alias.name
            self.imports.add(name)
        return node

    def visit_Global(self, node):
        """Track names declared ``global`` so they are not renamed."""
        for name in node.names:
            self.global_vars.add(name)
        return node

    def visit_Nonlocal(self, node):
        """Track names declared ``nonlocal`` so they are not renamed."""
        for name in node.names:
            self.nonlocal_vars.add(name)
        return node

    def visit_FunctionDef(self, node):
        """Process a function, keeping its name and parameters unchanged."""
        self.enter_scope()
        self._track_parameters(node.args)
        # Visit signature, decorators and body with the parameters registered.
        node = self.generic_visit(node)
        self.exit_scope()
        return node

    def visit_AsyncFunctionDef(self, node):
        """Handle async functions the same way as regular functions."""
        return self.visit_FunctionDef(node)

    def visit_Lambda(self, node):
        """Process a lambda, keeping its arguments unchanged."""
        # Without this, a lambda argument that shadows an already-mapped local
        # would be rewritten to that local's canonical name inside the body.
        self.enter_scope()
        self._track_parameters(node.args)
        node = self.generic_visit(node)
        self.exit_scope()
        return node

    def visit_ClassDef(self, node):
        """Process a class body in its own scope, keeping the class name."""
        self.enter_scope()
        node = self.generic_visit(node)
        self.exit_scope()
        return node

    def visit_Name(self, node):
        """Normalize variable names in ``Name`` nodes."""
        if isinstance(node.ctx, (ast.Store, ast.Del)):
            # Assignments/deletions create the mapping (preserved names are
            # returned unchanged by get_normalized_name).
            node.id = self.get_normalized_name(node.id)
        elif isinstance(node.ctx, ast.Load):
            # Loads only reuse an existing mapping. A parameter that shadows a
            # mapped outer local must keep its own name, otherwise reading the
            # parameter and reading the outer variable would look identical.
            if node.id in self.var_mapping and node.id not in self.parameters:
                node.id = self.var_mapping[node.id]
        return node

    def visit_ExceptHandler(self, node):
        """Normalize the name bound by ``except ... as name``."""
        if node.name:
            node.name = self.get_normalized_name(node.name)
        return self.generic_visit(node)

    def _visit_comprehension_expr(self, node):
        """Normalize a list/set/dict comprehension or generator expression.

        The ``for`` clauses are visited before the result expression(s) so the
        targets are already mapped when the result reads them. The mapping and
        counter are restored afterwards, after the whole expression rather
        than after each clause, so later clauses still see earlier targets.
        """
        saved_mapping = dict(self.var_mapping)
        saved_counter = self.var_counter
        node.generators = [self.visit(generator) for generator in node.generators]
        if isinstance(node, ast.DictComp):
            node.key = self.visit(node.key)
            node.value = self.visit(node.value)
        else:
            node.elt = self.visit(node.elt)
        self.var_mapping = saved_mapping
        self.var_counter = saved_counter
        return node

    def visit_ListComp(self, node):
        """Normalize a list comprehension."""
        return self._visit_comprehension_expr(node)

    def visit_SetComp(self, node):
        """Normalize a set comprehension."""
        return self._visit_comprehension_expr(node)

    def visit_DictComp(self, node):
        """Normalize a dict comprehension."""
        return self._visit_comprehension_expr(node)

    def visit_GeneratorExp(self, node):
        """Normalize a generator expression."""
        return self._visit_comprehension_expr(node)

    def visit_comprehension(self, node):
        """Normalize one ``for target in iter if ...`` clause."""
        # The iterable is visited before the target is mapped, then the
        # conditions, which may read the target.
        node.iter = self.visit(node.iter)
        node.target = self.visit(node.target)
        node.ifs = [self.visit(condition) for condition in node.ifs]
        return node

    def visit_For(self, node):
        """Handle for loops; the target is normalized through visit_Name."""
        return self.generic_visit(node)

    def visit_With(self, node):
        """Handle with statements; ``as`` targets are normalized through visit_Name."""
        return self.generic_visit(node)
def normalize_code(code: str, remove_docstrings: bool = True) -> str:
    """Return a canonical source text for ``code``.

    The code is parsed, optionally stripped of docstrings, has its local
    variable names replaced by canonical ones, and is unparsed again.
    Function names, class names and parameters are kept.

    Args:
        code: Python source code as a string.
        remove_docstrings: Whether docstrings are dropped before normalizing.

    Returns:
        The normalized source code.

    Raises:
        ValueError: If ``code`` is not valid Python syntax.
    """
    try:
        module_ast = ast.parse(code)
        if remove_docstrings:
            remove_docstrings_from_ast(module_ast)
        renamed_ast = VariableNormalizer().visit(module_ast)
        # Fill in missing line/column info before unparsing the transformed tree.
        ast.fix_missing_locations(renamed_ast)
        return ast.unparse(renamed_ast)
    except SyntaxError as e:
        msg = f"Invalid Python syntax: {e}"
        raise ValueError(msg) from e
def remove_docstrings_from_ast(node):
    """Remove docstrings from a module and every function/class inside it, in place.

    A docstring is a leading expression statement whose value is a string
    constant. When the docstring is the only statement of a function or class,
    it is replaced by ``pass`` so the body never becomes empty (an empty body
    would unparse to syntactically invalid source). A module body may
    legitimately end up empty.

    Args:
        node: Root AST node to process; it is modified in place.
    """
    docstring_owners = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)
    # ast.walk iterates the tree without recursion.
    for current_node in ast.walk(node):
        if not isinstance(current_node, docstring_owners):
            continue
        body = current_node.body
        has_docstring = (
            body
            and isinstance(body[0], ast.Expr)
            and isinstance(body[0].value, ast.Constant)
            and isinstance(body[0].value.value, str)
        )
        if not has_docstring:
            continue
        remaining = body[1:]
        if not remaining and not isinstance(current_node, ast.Module):
            # Keep function/class bodies non-empty so the tree stays valid.
            remaining = [ast.Pass()]
        current_node.body = remaining
def get_code_fingerprint(code: str) -> str:
    """Compute a fingerprint of ``code`` after normalization.

    Args:
        code: Python source code.

    Returns:
        Hex-encoded SHA-256 digest of the normalized code.
    """
    digest = hashlib.sha256()
    digest.update(normalize_code(code).encode())
    return digest.hexdigest()
def are_codes_duplicate(code1: str, code2: str) -> bool:
    """Tell whether two code segments normalize to the same text.

    Args:
        code1: First code segment.
        code2: Second code segment.

    Returns:
        True if both segments are identical after normalization (local
        variable names are ignored); False if they differ or if normalizing
        either segment raises an exception.
    """
    try:
        return normalize_code(code1) == normalize_code(code2)
    except Exception:
        # Best effort: anything that cannot be normalized is "not a duplicate".
        return False

View file

@ -48,6 +48,7 @@ from codeflash.code_utils.config_consts import (
REPEAT_OPTIMIZATION_PROBABILITY,
TOTAL_LOOPING_TIME,
)
from codeflash.code_utils.deduplicate_code import normalize_code
from codeflash.code_utils.edit_generated_tests import (
add_runtime_comments_to_generated_tests,
remove_functions_from_generated_tests,
@ -519,7 +520,7 @@ class FunctionOptimizer:
)
continue
# check if this code has been evaluated before by checking the ast normalized code string
normalized_code = ast.unparse(ast.parse(candidate.source_code.flat.strip()))
normalized_code = normalize_code(candidate.source_code.flat.strip())
if normalized_code in ast_code_to_id:
logger.info(
"Current candidate has been encountered before in testing, Skipping optimization candidate."

View file

@ -0,0 +1,135 @@
from codeflash.code_utils.deduplicate_code import are_codes_duplicate, normalize_code
def test_deduplicate1():
    """Check that normalize_code / are_codes_duplicate ignore local variable names.

    Each pair below keeps the same function, class and parameter names and
    differs only in local variable names (and docstrings/comments).
    """
    # Example usage and tests
    # Example 1: Same function and parameter names, different local variable names;
    # one snippet has a docstring, the other a comment (should match)
    code1 = """
def compute_sum(numbers):
    '''Calculate sum of numbers'''
    total = 0
    for num in numbers:
        total += num
    return total
"""
    code2 = """
def compute_sum(numbers):
    # This computes the sum
    result = 0
    for value in numbers:
        result += value
    return result
"""
    assert normalize_code(code1) == normalize_code(code2)
    assert are_codes_duplicate(code1, code2)
    # Example 2: Same function and parameter names, different local variables (should match)
    code3 = """
def calculate_sum(numbers):
    accumulator = 0
    for item in numbers:
        accumulator += item
    return accumulator
"""
    code4 = """
def calculate_sum(numbers):
    total = 0
    for num in numbers:
        total += num
    return total
"""
    assert normalize_code(code3) == normalize_code(code4)
    assert are_codes_duplicate(code3, code4)
    # Example 3: Nested functions and classes (preserving names)
    code5 = """
class DataProcessor:
    def __init__(self, data):
        self.data = data
    def process(self):
        def helper(item):
            temp = item * 2
            return temp
        results = []
        for element in self.data:
            results.append(helper(element))
        return results
"""
    code6 = """
class DataProcessor:
    def __init__(self, data):
        self.data = data
    def process(self):
        def helper(item):
            x = item * 2
            return x
        output = []
        for thing in self.data:
            output.append(helper(thing))
        return output
"""
    assert normalize_code(code5) == normalize_code(code6)
    # Example 4: With imports and built-ins (these should be preserved)
    code7 = """
import math
def calculate_circle_area(radius):
    pi_value = math.pi
    area = pi_value * radius ** 2
    return area
"""
    code8 = """
import math
def calculate_circle_area(radius):
    constant = math.pi
    result = constant * radius ** 2
    return result
"""
    # code85 has an extra "* 2" factor, so it must not match code8
    code85 = """
import math
def calculate_circle_area(radius):
    constant = math.pi
    result = constant *2 * radius ** 2
    return result
"""
    assert normalize_code(code7) == normalize_code(code8)
    assert normalize_code(code8) != normalize_code(code85)
    # Example 5: Exception handling
    code9 = """
def safe_divide(a, b):
    try:
        result = a / b
        return result
    except ZeroDivisionError as e:
        error_msg = str(e)
        return None
"""
    code10 = """
def safe_divide(a, b):
    try:
        output = a / b
        return output
    except ZeroDivisionError as exc:
        message = str(exc)
        return None
"""
    assert normalize_code(code9) == normalize_code(code10)
    # Unrelated functions must not normalize to the same text
    assert normalize_code(code9) != normalize_code(code8)