Keep optimized decorators.

This commit is contained in:
RD 2024-07-30 16:35:32 -07:00
parent 818756d28b
commit ea5f5c490f
2 changed files with 5 additions and 3 deletions

View file

@ -94,7 +94,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
if original_node.name.value == self.function_name and (
self.depth == 0 or (self.depth == 1 and self.in_class)
):
return updated_node.with_changes(body=self.optim_body.body)
return updated_node.with_changes(body=self.optim_body.body, decorators=self.optim_body.decorators)
return updated_node
def visit_ClassDef(self, node: cst.ClassDef) -> bool:

View file

@ -52,6 +52,7 @@ class NewClass:
original_code = """class NewClass:
def __init__(self, name):
self.name = name
@staticmethod
def new_function(self, value):
return "I am still old"
@ -279,7 +280,8 @@ print("Salut monde")
def test_test_libcst_code_replacement5() -> None:
optim_code = """def sorter_deps(arr):
optim_code = """@lru_cache(17)
def sorter_deps(arr: list[int]) -> list[int]:
supersort(badsort(arr))
return arr
@ -304,6 +306,7 @@ def sorter_deps(arr):
expected = """from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
@lru_cache(17)
def sorter_deps(arr):
supersort(badsort(arr))
return arr
@ -649,7 +652,6 @@ def test_test_libcst_code_replacement8() -> None:
arbitrary_types_allowed: bool = True
@staticmethod
def _hamming_distance(a: np.ndarray, b: np.ndarray) -> np.floating:
"""Compute the Hamming distance between two vectors.