diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index fd489b780..1428630f6 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -1,4 +1,4 @@ -from typing import List, Union, Optional, IO, NoReturn +from typing import List, Union, Optional, IO import libcst as cst from libcst import SimpleStatementLine, FunctionDef @@ -13,9 +13,9 @@ class OptimFunctionCollector(cst.CSTVisitor): preexisting_functions = [] self.function_name = function_name self.optim_body: Union[FunctionDef, None] = None - self.optim_new_class_functions = [] - self.optim_new_functions = [] - self.optim_imports = [] + self.optim_new_class_functions: list[cst.FunctionDef] = [] + self.optim_new_functions: list[cst.FunctionDef] = [] + self.optim_imports: List[Union[cst.Import, cst.ImportFrom]] = [] self.preexisting_functions = preexisting_functions def visit_FunctionDef(self, node: cst.FunctionDef): @@ -74,13 +74,11 @@ class OptimFunctionReplacer(cst.CSTTransformer): self.in_class: bool = False def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: - self.depth += 1 return False def leave_FunctionDef( self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef ) -> cst.FunctionDef: - self.depth -= 1 if original_node.name.value == self.function_name and ( self.depth == 0 or (self.depth == 1 and self.in_class) ): @@ -88,14 +86,17 @@ class OptimFunctionReplacer(cst.CSTTransformer): return updated_node def visit_ClassDef(self, node: cst.ClassDef) -> bool: - self.in_class = (self.depth == 0) and (node.name.value == self.class_name) self.depth += 1 + if self.in_class: + return False + self.in_class = (self.depth == 1) and (node.name.value == self.class_name) return self.in_class def leave_ClassDef( self, original_node: cst.ClassDef, updated_node: cst.ClassDef ) -> cst.ClassDef: - if self.in_class: + self.depth -= 1 + if self.in_class and (self.depth == 0) and (original_node.name.value == self.class_name): self.in_class = False return updated_node.with_changes( body=updated_node.body.with_changes( @@ -207,7 +208,7 @@ def replace_function_definitions_in_module( optimized_code: str, module_abspath: str, preexisting_functions: list[str], -) -> NoReturn: +) -> None: file: IO[str] with open(module_abspath, "r", encoding="utf8") as file: source_code: str = file.read() diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 199c94a63..4b8ad37c0 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -5,7 +5,7 @@ from codeflash.code_utils.code_replacer import replace_functions_in_file os.environ["CODEFLASH_API_KEY"] = "cf-test-key" -def test_test_libcst_code_replacement(): +def test_test_libcst_code_replacement() -> None: optim_code = """import libcst as cst from typing import Optional @@ -56,7 +56,7 @@ print("Hello world") assert new_code == expected -def test_test_libcst_code_replacement2(): +def test_test_libcst_code_replacement2() -> None: optim_code = """import libcst as cst from typing import Optional @@ -111,7 +111,7 @@ print("Hello world") assert new_code == expected -def test_test_libcst_code_replacement3(): +def test_test_libcst_code_replacement3() -> None: optim_code = """import libcst as cst from typing import Optional @@ -167,7 +167,7 @@ print("Salut monde") assert new_code == expected -def test_test_libcst_code_replacement4(): +def test_test_libcst_code_replacement4() -> None: optim_code = """import libcst as cst from typing import Optional @@ -225,7 +225,7 @@ print("Salut monde") assert new_code == expected -def test_test_libcst_code_replacement5(): +def test_test_libcst_code_replacement5() -> None: optim_code = """def sorter_deps(arr): supersort(badsort(arr)) return arr @@ -270,7 +270,7 @@ def supersort(doink): assert new_code == expected -def test_test_libcst_code_replacement6(): +def test_test_libcst_code_replacement6() -> None: optim_code = """import libcst as cst from typing import Optional @@ -352,7 +352,7 @@ print("Not cool") assert new_dependent_code == expected_dependent -def test_test_libcst_code_replacement7(): +def test_test_libcst_code_replacement7() -> None: optim_code = """@register_deserializable class CacheSimilarityEvalConfig(BaseConfig): @@ -532,7 +532,6 @@ class CacheConfig(BaseConfig): init_config=CacheInitConfig.from_config(config.get("init_config", {})), ) """ - function_names: list[str] = ["CacheSimilarityEvalConfig.from_config"] preexisting_functions: list[str] = [ "__init__", @@ -542,3 +541,69 @@ class CacheConfig(BaseConfig): original_code, function_names, optim_code, preexisting_functions ) assert new_code == expected + + +def test_test_libcst_code_replacement8() -> None: + optim_code = '''class _EmbeddingDistanceChainMixin(Chain): + @staticmethod + def _hamming_distance(a: np.ndarray, b: np.ndarray) -> np.floating: + """Compute the Hamming distance between two vectors. + + Args: + a (np.ndarray): The first vector. + b (np.ndarray): The second vector. + + Returns: + np.floating: The Hamming distance. + """ + return np.sum(a != b) / a.size +''' + + original_code = '''class _EmbeddingDistanceChainMixin(Chain): + + class Config: + """Permit embeddings to go unvalidated.""" + + arbitrary_types_allowed: bool = True + + + @staticmethod + def _hamming_distance(a: np.ndarray, b: np.ndarray) -> np.floating: + """Compute the Hamming distance between two vectors. + + Args: + a (np.ndarray): The first vector. + b (np.ndarray): The second vector. + + Returns: + np.floating: The Hamming distance. + """ + return np.mean(a != b) +''' + expected = '''class _EmbeddingDistanceChainMixin(Chain): + + class Config: + """Permit embeddings to go unvalidated.""" + + arbitrary_types_allowed: bool = True + @staticmethod + def _hamming_distance(a: np.ndarray, b: np.ndarray) -> np.floating: + """Compute the Hamming distance between two vectors. + + Args: + a (np.ndarray): The first vector. + b (np.ndarray): The second vector. + + Returns: + np.floating: The Hamming distance. + """ + return np.sum(a != b) / a.size +''' + function_names: list[str] = ["_EmbeddingDistanceChainMixin._hamming_distance"] + preexisting_functions: list[str] = [ + "_hamming_distance", + ] + new_code: str = replace_functions_in_file( + original_code, function_names, optim_code, preexisting_functions + ) + assert new_code == expected