Fing code replacer bug, some type annotations and a parsing exception.
This commit is contained in:
parent
abd9375262
commit
bd7f8bf0eb
2 changed files with 83 additions and 17 deletions
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List, Union, Optional, IO, NoReturn
|
||||
from typing import List, Union, Optional, IO
|
||||
|
||||
import libcst as cst
|
||||
from libcst import SimpleStatementLine, FunctionDef
|
||||
|
|
@ -13,9 +13,9 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
|||
preexisting_functions = []
|
||||
self.function_name = function_name
|
||||
self.optim_body: Union[FunctionDef, None] = None
|
||||
self.optim_new_class_functions = []
|
||||
self.optim_new_functions = []
|
||||
self.optim_imports = []
|
||||
self.optim_new_class_functions: list[cst.FunctionDef] = []
|
||||
self.optim_new_functions: list[cst.FunctionDef] = []
|
||||
self.optim_imports: List[Union[cst.Import, cst.ImportFrom]] = []
|
||||
self.preexisting_functions = preexisting_functions
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef):
|
||||
|
|
@ -74,13 +74,11 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
self.in_class: bool = False
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
|
||||
self.depth += 1
|
||||
return False
|
||||
|
||||
def leave_FunctionDef(
|
||||
self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
|
||||
) -> cst.FunctionDef:
|
||||
self.depth -= 1
|
||||
if original_node.name.value == self.function_name and (
|
||||
self.depth == 0 or (self.depth == 1 and self.in_class)
|
||||
):
|
||||
|
|
@ -88,14 +86,17 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
return updated_node
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
|
||||
self.in_class = (self.depth == 0) and (node.name.value == self.class_name)
|
||||
self.depth += 1
|
||||
if self.in_class:
|
||||
return False
|
||||
self.in_class = (self.depth == 1) and (node.name.value == self.class_name)
|
||||
return self.in_class
|
||||
|
||||
def leave_ClassDef(
|
||||
self, original_node: cst.ClassDef, updated_node: cst.ClassDef
|
||||
) -> cst.ClassDef:
|
||||
if self.in_class:
|
||||
self.depth -= 1
|
||||
if self.in_class and (self.depth == 0) and (original_node.name.value == self.class_name):
|
||||
self.in_class = False
|
||||
return updated_node.with_changes(
|
||||
body=updated_node.body.with_changes(
|
||||
|
|
@ -207,7 +208,7 @@ def replace_function_definitions_in_module(
|
|||
optimized_code: str,
|
||||
module_abspath: str,
|
||||
preexisting_functions: list[str],
|
||||
) -> NoReturn:
|
||||
) -> None:
|
||||
file: IO[str]
|
||||
with open(module_abspath, "r", encoding="utf8") as file:
|
||||
source_code: str = file.read()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from codeflash.code_utils.code_replacer import replace_functions_in_file
|
|||
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
|
||||
|
||||
|
||||
def test_test_libcst_code_replacement():
|
||||
def test_test_libcst_code_replacement() -> None:
|
||||
optim_code = """import libcst as cst
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -56,7 +56,7 @@ print("Hello world")
|
|||
assert new_code == expected
|
||||
|
||||
|
||||
def test_test_libcst_code_replacement2():
|
||||
def test_test_libcst_code_replacement2() -> None:
|
||||
optim_code = """import libcst as cst
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -111,7 +111,7 @@ print("Hello world")
|
|||
assert new_code == expected
|
||||
|
||||
|
||||
def test_test_libcst_code_replacement3():
|
||||
def test_test_libcst_code_replacement3() -> None:
|
||||
optim_code = """import libcst as cst
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -167,7 +167,7 @@ print("Salut monde")
|
|||
assert new_code == expected
|
||||
|
||||
|
||||
def test_test_libcst_code_replacement4():
|
||||
def test_test_libcst_code_replacement4() -> None:
|
||||
optim_code = """import libcst as cst
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -225,7 +225,7 @@ print("Salut monde")
|
|||
assert new_code == expected
|
||||
|
||||
|
||||
def test_test_libcst_code_replacement5():
|
||||
def test_test_libcst_code_replacement5() -> None:
|
||||
optim_code = """def sorter_deps(arr):
|
||||
supersort(badsort(arr))
|
||||
return arr
|
||||
|
|
@ -270,7 +270,7 @@ def supersort(doink):
|
|||
assert new_code == expected
|
||||
|
||||
|
||||
def test_test_libcst_code_replacement6():
|
||||
def test_test_libcst_code_replacement6() -> None:
|
||||
optim_code = """import libcst as cst
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -352,7 +352,7 @@ print("Not cool")
|
|||
assert new_dependent_code == expected_dependent
|
||||
|
||||
|
||||
def test_test_libcst_code_replacement7():
|
||||
def test_test_libcst_code_replacement7() -> None:
|
||||
optim_code = """@register_deserializable
|
||||
class CacheSimilarityEvalConfig(BaseConfig):
|
||||
|
||||
|
|
@ -532,7 +532,6 @@ class CacheConfig(BaseConfig):
|
|||
init_config=CacheInitConfig.from_config(config.get("init_config", {})),
|
||||
)
|
||||
"""
|
||||
|
||||
function_names: list[str] = ["CacheSimilarityEvalConfig.from_config"]
|
||||
preexisting_functions: list[str] = [
|
||||
"__init__",
|
||||
|
|
@ -542,3 +541,69 @@ class CacheConfig(BaseConfig):
|
|||
original_code, function_names, optim_code, preexisting_functions
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
||||
def test_test_libcst_code_replacement8() -> None:
|
||||
optim_code = '''class _EmbeddingDistanceChainMixin(Chain):
|
||||
@staticmethod
|
||||
def _hamming_distance(a: np.ndarray, b: np.ndarray) -> np.floating:
|
||||
"""Compute the Hamming distance between two vectors.
|
||||
|
||||
Args:
|
||||
a (np.ndarray): The first vector.
|
||||
b (np.ndarray): The second vector.
|
||||
|
||||
Returns:
|
||||
np.floating: The Hamming distance.
|
||||
"""
|
||||
return np.sum(a != b) / a.size
|
||||
'''
|
||||
|
||||
original_code = '''class _EmbeddingDistanceChainMixin(Chain):
|
||||
|
||||
class Config:
|
||||
"""Permit embeddings to go unvalidated."""
|
||||
|
||||
arbitrary_types_allowed: bool = True
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _hamming_distance(a: np.ndarray, b: np.ndarray) -> np.floating:
|
||||
"""Compute the Hamming distance between two vectors.
|
||||
|
||||
Args:
|
||||
a (np.ndarray): The first vector.
|
||||
b (np.ndarray): The second vector.
|
||||
|
||||
Returns:
|
||||
np.floating: The Hamming distance.
|
||||
"""
|
||||
return np.mean(a != b)
|
||||
'''
|
||||
expected = '''class _EmbeddingDistanceChainMixin(Chain):
|
||||
|
||||
class Config:
|
||||
"""Permit embeddings to go unvalidated."""
|
||||
|
||||
arbitrary_types_allowed: bool = True
|
||||
@staticmethod
|
||||
def _hamming_distance(a: np.ndarray, b: np.ndarray) -> np.floating:
|
||||
"""Compute the Hamming distance between two vectors.
|
||||
|
||||
Args:
|
||||
a (np.ndarray): The first vector.
|
||||
b (np.ndarray): The second vector.
|
||||
|
||||
Returns:
|
||||
np.floating: The Hamming distance.
|
||||
"""
|
||||
return np.sum(a != b) / a.size
|
||||
'''
|
||||
function_names: list[str] = ["_EmbeddingDistanceChainMixin._hamming_distance"]
|
||||
preexisting_functions: list[str] = [
|
||||
"_hamming_distance",
|
||||
]
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code, function_names, optim_code, preexisting_functions
|
||||
)
|
||||
assert new_code == expected
|
||||
|
|
|
|||
Loading…
Reference in a new issue