diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index fcc9fbb1c..825d9a46e 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -253,8 +253,18 @@ def replace_function_definitions_in_module( contextual_functions, project_root_path, ) - if ast.dump(ast.parse(new_code)) == ast.dump(ast.parse(source_code)): + if is_zero_diff(source_code, new_code): return False with open(module_abspath, "w", encoding="utf8") as file: file.write(new_code) return True + + +def is_zero_diff(original_code: str, new_code: str) -> bool: + def normalize_for_diff(tree: ast.Module) -> ast.Module: + tree.body = [node for node in tree.body if not isinstance(node, (ast.Import, ast.ImportFrom))] + return tree + + original_code_unparsed = ast.unparse(normalize_for_diff(ast.parse(original_code))) + new_code_unparsed = ast.unparse(normalize_for_diff(ast.parse(new_code))) + return original_code_unparsed == new_code_unparsed diff --git a/codeflash/version.py b/codeflash/version.py index 3d296a5e2..0275e0c0d 100644 --- a/codeflash/version.py +++ b/codeflash/version.py @@ -1,3 +1,3 @@ # These version placeholders will be replaced by poetry-dynamic-versioning during `poetry build`. -__version__ = "0.6.11" -__version_tuple__ = (0, 6, 11) +__version__ = "0.6.12" +__version_tuple__ = (0, 6, 12) diff --git a/pyproject.toml b/pyproject.toml index 665caa38f..9eea44e9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ keywords = ["codeflash", "performance", "optimization", "ai", "code", "machine l # poetry self add poetry-dynamic-versioning [tool.poetry.dependencies] -python = ">=3.9,<4.0" +python = ">=3.9,<3.13" unidiff = ">=0.7.4" pytest = ">=7.0.0" gitpython = ">=3.1.31" diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 433391784..d32cf825e 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -8,7 +8,11 @@ from pathlib import Path import libcst as cst from codeflash.code_utils.code_extractor import remove_first_imported_aliased_objects -from codeflash.code_utils.code_replacer import replace_functions_and_add_imports, replace_functions_in_file +from codeflash.code_utils.code_replacer import ( + is_zero_diff, + replace_functions_and_add_imports, + replace_functions_in_file, +) from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize from codeflash.optimization.optimizer import Optimizer @@ -1547,3 +1551,38 @@ print("Hello monde") """ assert remove_first_imported_aliased_objects(module_code5, "__future__")[0] == module_code5 + +def test_0_diff_code_replacement(): + original_code = """from __future__ import annotations + +import numpy as np +def functionA(): + return np.array([1, 2, 3]) +""" + optim_code_a = """from __future__ import annotations +import numpy as np +def functionA(): + return np.array([1, 2, 3])""" + + assert is_zero_diff(original_code, optim_code_a) + + optim_code_b = """ +import numpy as np +def functionA(): + return np.array([1, 2, 3])""" + + assert is_zero_diff(original_code, optim_code_b) + + optim_code_c = """ +def functionA(): + return np.array([1, 2, 3])""" + + assert is_zero_diff(original_code, optim_code_c) + + optim_code_d = """from __future__ import annotations + +import numpy as np +def functionA(): + return np.array([1, 2, 3, 4]) +""" + assert not is_zero_diff(original_code, optim_code_d)