From f65aa08878124231b782cddef6722507b51a70b4 Mon Sep 17 00:00:00 2001 From: RD <92499101+iusedmyimagination@users.noreply.github.com> Date: Thu, 24 Oct 2024 14:34:07 -0700 Subject: [PATCH] Ruff cleanup of code replacer, new stub files added by mypy. --- .idea/ruff.xml | 5 +- cli/codeflash/code_utils/code_replacer.py | 75 +++++------------------ cli/pyproject.toml | 12 +++- 3 files changed, 28 insertions(+), 64 deletions(-) diff --git a/.idea/ruff.xml b/.idea/ruff.xml index f7fb96578..ba1418ba1 100644 --- a/.idea/ruff.xml +++ b/.idea/ruff.xml @@ -2,11 +2,10 @@ \ No newline at end of file diff --git a/cli/codeflash/code_utils/code_replacer.py b/cli/codeflash/code_utils/code_replacer.py index 893af5b15..31362fe0a 100644 --- a/cli/codeflash/code_utils/code_replacer.py +++ b/cli/codeflash/code_utils/code_replacer.py @@ -18,15 +18,10 @@ ASTNodeT = TypeVar("ASTNodeT", bound=ast.AST) def normalize_node(node: ASTNodeT) -> ASTNodeT: - if isinstance( - node, - (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef), - ) and ast.get_docstring(node): + if isinstance(node, (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)) and ast.get_docstring(node): node.body = node.body[1:] if hasattr(node, "body"): - node.body = [ - normalize_node(node) for node in node.body if not isinstance(node, (ast.Import, ast.ImportFrom)) - ] + node.body = [normalize_node(node) for node in node.body if not isinstance(node, (ast.Import, ast.ImportFrom))] return node @@ -54,9 +49,7 @@ class OptimFunctionCollector(cst.CSTVisitor): self.optim_new_class_functions: list[cst.FunctionDef] = [] self.optim_new_functions: list[cst.FunctionDef] = [] self.preexisting_objects = preexisting_objects - self.contextual_functions = contextual_functions.union( - {(self.class_name, self.function_name)}, - ) + self.contextual_functions = contextual_functions.union({(self.class_name, self.function_name)}) def visit_FunctionDef(self, node: cst.FunctionDef) -> None: parent = self.get_metadata(cst.metadata.ParentNodeProvider, node) @@ -71,10 +64,7 @@ class OptimFunctionCollector(cst.CSTVisitor): elif ( self.preexisting_objects and (node.name.value, []) not in self.preexisting_objects - and ( - isinstance(parent, cst.Module) - or (parent2 is not None and not isinstance(parent2, cst.ClassDef)) - ) + and (isinstance(parent, cst.Module) or (parent2 is not None and not isinstance(parent2, cst.ClassDef))) ): self.optim_new_functions.append(node) @@ -111,14 +101,8 @@ class OptimFunctionReplacer(cst.CSTTransformer): def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: return False - def leave_FunctionDef( - self, - original_node: cst.FunctionDef, - updated_node: cst.FunctionDef, - ) -> cst.FunctionDef: - if original_node.name.value == self.function_name and ( - self.depth == 0 or (self.depth == 1 and self.in_class) - ): + def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: + if original_node.name.value == self.function_name and (self.depth == 0 or (self.depth == 1 and self.in_class)): return updated_node.with_changes(body=self.optim_body.body, decorators=self.optim_body.decorators) return updated_node @@ -129,18 +113,14 @@ class OptimFunctionReplacer(cst.CSTTransformer): self.in_class = (self.depth == 1) and (node.name.value == self.class_name) return self.in_class - def leave_ClassDef( - self, - original_node: cst.ClassDef, - updated_node: cst.ClassDef, - ) -> cst.ClassDef: + def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: self.depth -= 1 if self.in_class and (self.depth == 0) and (original_node.name.value == self.class_name): self.in_class = False return updated_node.with_changes( body=updated_node.body.with_changes( - body=(list(updated_node.body.body) + self.optim_new_class_functions), - ), + body=(list(updated_node.body.body) + self.optim_new_class_functions) + ) ) return updated_node @@ -159,15 +139,11 @@ class OptimFunctionReplacer(cst.CSTTransformer): *node.body[: max_function_index + 1], *self.optim_new_functions, *node.body[max_function_index + 1 :], - ), + ) ) elif class_index is not None: node = node.with_changes( - body=( - *node.body[: class_index + 1], - *self.optim_new_functions, - *node.body[class_index + 1 :], - ), + body=(*node.body[: class_index + 1], *self.optim_new_functions, *node.body[class_index + 1 :]) ) else: node = node.with_changes(body=(*self.optim_new_functions, *node.body)) @@ -188,24 +164,21 @@ def replace_functions_in_file( elif original_function_name.count(".") == 1: class_name, function_name = original_function_name.split(".") else: - raise ValueError(f"Don't know how to find {original_function_name} yet!") + msg = f"Don't know how to find {original_function_name} yet!" + raise ValueError(msg) parsed_function_names.append((function_name, class_name)) module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code)) for function_name, class_name in parsed_function_names: - visitor = OptimFunctionCollector( - function_name, - class_name, - contextual_functions, - preexisting_objects, - ) + visitor = OptimFunctionCollector(function_name, class_name, contextual_functions, preexisting_objects) module.visit(visitor) if visitor.optim_body is None and not preexisting_objects: continue if visitor.optim_body is None: - raise ValueError(f"Did not find the function {function_name} in the optimized code") + msg = f"Did not find the function {function_name} in the optimized code" + raise ValueError(msg) transformer = OptimFunctionReplacer( visitor.function_name, @@ -234,11 +207,7 @@ def replace_functions_and_add_imports( return add_needed_imports_from_module( optimized_code, replace_functions_in_file( - source_code, - function_names, - optimized_code, - preexisting_objects, - contextual_functions, + source_code, function_names, optimized_code, preexisting_objects, contextual_functions ), file_path_of_module_with_function_to_optimize, module_abspath, @@ -255,16 +224,6 @@ def replace_function_definitions_in_module( contextual_functions: set[tuple[str, str]], project_root_path: Path, ) -> bool: - """:param function_names: List of qualified (not fully qualified) function names (function_name or - class_name.method_name). - :param optimized_code: - :param file_path_of_module_with_function_to_optimize: - :param module_abspath: - :param preexisting_objects: - :param contextual_functions: - :param project_root_path: - :return: - """ source_code: str = module_abspath.read_text(encoding="utf8") new_code: str = replace_functions_and_add_imports( source_code, diff --git a/cli/pyproject.toml b/cli/pyproject.toml index 0ccf5c01e..6780a7c9b 100644 --- a/cli/pyproject.toml +++ b/cli/pyproject.toml @@ -12,9 +12,6 @@ packages = [ ] keywords = ["codeflash", "performance", "optimization", "ai", "code", "machine learning", "LLM"] -# Don't forget to install the poetry plugins we use too: -# poetry self add poetry-dynamic-versioning - [tool.poetry.dependencies] python = "^3.9" unidiff = ">=0.7.4" @@ -40,6 +37,15 @@ returns = ">=0.23" isort = ">=5.11.0" dill = "^0.3.8" rich = "^13.8.1" +pandas-stubs = "^2.2.3.241009" +types-Pygments = "^2.18.0.20240506" +types-colorama = "^0.4.15.20240311" +types-decorator = "^5.1.8.20240310" +types-jsonschema = "^4.23.0.20240813" +types-requests = "^2.32.0.20241016" +types-six = "^1.16.21.20241009" +types-cffi = "^1.16.0.20240331" +types-openpyxl = "^3.1.5.20241020" [tool.poetry.group.dev] optional = true