Fing code replacer bug, some type annotations and a parsing exception.

This commit is contained in:
renaud 2024-03-13 02:43:25 -07:00
parent abd9375262
commit bd7f8bf0eb
2 changed files with 83 additions and 17 deletions

View file

@ -1,4 +1,4 @@
from typing import List, Union, Optional, IO, NoReturn
from typing import List, Union, Optional, IO
import libcst as cst
from libcst import SimpleStatementLine, FunctionDef
@ -13,9 +13,9 @@ class OptimFunctionCollector(cst.CSTVisitor):
preexisting_functions = []
self.function_name = function_name
self.optim_body: Union[FunctionDef, None] = None
self.optim_new_class_functions = []
self.optim_new_functions = []
self.optim_imports = []
self.optim_new_class_functions: list[cst.FunctionDef] = []
self.optim_new_functions: list[cst.FunctionDef] = []
self.optim_imports: List[Union[cst.Import, cst.ImportFrom]] = []
self.preexisting_functions = preexisting_functions
def visit_FunctionDef(self, node: cst.FunctionDef):
@ -74,13 +74,11 @@ class OptimFunctionReplacer(cst.CSTTransformer):
self.in_class: bool = False
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
self.depth += 1
return False
def leave_FunctionDef(
self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
) -> cst.FunctionDef:
self.depth -= 1
if original_node.name.value == self.function_name and (
self.depth == 0 or (self.depth == 1 and self.in_class)
):
@ -88,14 +86,17 @@ class OptimFunctionReplacer(cst.CSTTransformer):
return updated_node
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
self.in_class = (self.depth == 0) and (node.name.value == self.class_name)
self.depth += 1
if self.in_class:
return False
self.in_class = (self.depth == 1) and (node.name.value == self.class_name)
return self.in_class
def leave_ClassDef(
self, original_node: cst.ClassDef, updated_node: cst.ClassDef
) -> cst.ClassDef:
if self.in_class:
self.depth -= 1
if self.in_class and (self.depth == 0) and (original_node.name.value == self.class_name):
self.in_class = False
return updated_node.with_changes(
body=updated_node.body.with_changes(
@ -207,7 +208,7 @@ def replace_function_definitions_in_module(
optimized_code: str,
module_abspath: str,
preexisting_functions: list[str],
) -> NoReturn:
) -> None:
file: IO[str]
with open(module_abspath, "r", encoding="utf8") as file:
source_code: str = file.read()

View file

@ -5,7 +5,7 @@ from codeflash.code_utils.code_replacer import replace_functions_in_file
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
def test_test_libcst_code_replacement():
def test_test_libcst_code_replacement() -> None:
optim_code = """import libcst as cst
from typing import Optional
@ -56,7 +56,7 @@ print("Hello world")
assert new_code == expected
def test_test_libcst_code_replacement2():
def test_test_libcst_code_replacement2() -> None:
optim_code = """import libcst as cst
from typing import Optional
@ -111,7 +111,7 @@ print("Hello world")
assert new_code == expected
def test_test_libcst_code_replacement3():
def test_test_libcst_code_replacement3() -> None:
optim_code = """import libcst as cst
from typing import Optional
@ -167,7 +167,7 @@ print("Salut monde")
assert new_code == expected
def test_test_libcst_code_replacement4():
def test_test_libcst_code_replacement4() -> None:
optim_code = """import libcst as cst
from typing import Optional
@ -225,7 +225,7 @@ print("Salut monde")
assert new_code == expected
def test_test_libcst_code_replacement5():
def test_test_libcst_code_replacement5() -> None:
optim_code = """def sorter_deps(arr):
supersort(badsort(arr))
return arr
@ -270,7 +270,7 @@ def supersort(doink):
assert new_code == expected
def test_test_libcst_code_replacement6():
def test_test_libcst_code_replacement6() -> None:
optim_code = """import libcst as cst
from typing import Optional
@ -352,7 +352,7 @@ print("Not cool")
assert new_dependent_code == expected_dependent
def test_test_libcst_code_replacement7():
def test_test_libcst_code_replacement7() -> None:
optim_code = """@register_deserializable
class CacheSimilarityEvalConfig(BaseConfig):
@ -532,7 +532,6 @@ class CacheConfig(BaseConfig):
init_config=CacheInitConfig.from_config(config.get("init_config", {})),
)
"""
function_names: list[str] = ["CacheSimilarityEvalConfig.from_config"]
preexisting_functions: list[str] = [
"__init__",
@ -542,3 +541,69 @@ class CacheConfig(BaseConfig):
original_code, function_names, optim_code, preexisting_functions
)
assert new_code == expected
def test_test_libcst_code_replacement8() -> None:
optim_code = '''class _EmbeddingDistanceChainMixin(Chain):
@staticmethod
def _hamming_distance(a: np.ndarray, b: np.ndarray) -> np.floating:
"""Compute the Hamming distance between two vectors.
Args:
a (np.ndarray): The first vector.
b (np.ndarray): The second vector.
Returns:
np.floating: The Hamming distance.
"""
return np.sum(a != b) / a.size
'''
original_code = '''class _EmbeddingDistanceChainMixin(Chain):
class Config:
"""Permit embeddings to go unvalidated."""
arbitrary_types_allowed: bool = True
@staticmethod
def _hamming_distance(a: np.ndarray, b: np.ndarray) -> np.floating:
"""Compute the Hamming distance between two vectors.
Args:
a (np.ndarray): The first vector.
b (np.ndarray): The second vector.
Returns:
np.floating: The Hamming distance.
"""
return np.mean(a != b)
'''
expected = '''class _EmbeddingDistanceChainMixin(Chain):
class Config:
"""Permit embeddings to go unvalidated."""
arbitrary_types_allowed: bool = True
@staticmethod
def _hamming_distance(a: np.ndarray, b: np.ndarray) -> np.floating:
"""Compute the Hamming distance between two vectors.
Args:
a (np.ndarray): The first vector.
b (np.ndarray): The second vector.
Returns:
np.floating: The Hamming distance.
"""
return np.sum(a != b) / a.size
'''
function_names: list[str] = ["_EmbeddingDistanceChainMixin._hamming_distance"]
preexisting_functions: list[str] = [
"_hamming_distance",
]
new_code: str = replace_functions_in_file(
original_code, function_names, optim_code, preexisting_functions
)
assert new_code == expected