Ruff cleanup of code replacer, new stub files added by mypy.

This commit is contained in:
RD 2024-10-24 14:34:07 -07:00
parent 136ce5ee18
commit f65aa08878
3 changed files with 28 additions and 64 deletions

View file

@ -2,11 +2,10 @@
<project version="4"> <project version="4">
<component name="RuffConfigService"> <component name="RuffConfigService">
<option name="alwaysUseGlobalRuff" value="true" /> <option name="alwaysUseGlobalRuff" value="true" />
<option name="projectRuffExecutablePath" value="$USER_HOME$/miniforge3/envs/codeflash312/bin/ruff" /> <option name="globalRuffExecutablePath" value="$USER_HOME$/.local/bin/ruff" />
<option name="projectRuffLspExecutablePath" value="$USER_HOME$/miniforge3/envs/codeflash312/bin/ruff-lsp" />
<option name="ruffConfigPath" value="$PROJECT_DIR$/django/aiservice/pyproject.toml" /> <option name="ruffConfigPath" value="$PROJECT_DIR$/django/aiservice/pyproject.toml" />
<option name="runRuffOnSave" value="true" /> <option name="runRuffOnSave" value="true" />
<option name="useRuffFormat" value="true" /> <option name="useRuffFormat" value="true" />
<option name="useRuffLsp" value="true" /> <option name="useRuffServer" value="true" />
</component> </component>
</project> </project>

View file

@ -18,15 +18,10 @@ ASTNodeT = TypeVar("ASTNodeT", bound=ast.AST)
def normalize_node(node: ASTNodeT) -> ASTNodeT: def normalize_node(node: ASTNodeT) -> ASTNodeT:
if isinstance( if isinstance(node, (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)) and ast.get_docstring(node):
node,
(ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef),
) and ast.get_docstring(node):
node.body = node.body[1:] node.body = node.body[1:]
if hasattr(node, "body"): if hasattr(node, "body"):
node.body = [ node.body = [normalize_node(node) for node in node.body if not isinstance(node, (ast.Import, ast.ImportFrom))]
normalize_node(node) for node in node.body if not isinstance(node, (ast.Import, ast.ImportFrom))
]
return node return node
@ -54,9 +49,7 @@ class OptimFunctionCollector(cst.CSTVisitor):
self.optim_new_class_functions: list[cst.FunctionDef] = [] self.optim_new_class_functions: list[cst.FunctionDef] = []
self.optim_new_functions: list[cst.FunctionDef] = [] self.optim_new_functions: list[cst.FunctionDef] = []
self.preexisting_objects = preexisting_objects self.preexisting_objects = preexisting_objects
self.contextual_functions = contextual_functions.union( self.contextual_functions = contextual_functions.union({(self.class_name, self.function_name)})
{(self.class_name, self.function_name)},
)
def visit_FunctionDef(self, node: cst.FunctionDef) -> None: def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node) parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
@ -71,10 +64,7 @@ class OptimFunctionCollector(cst.CSTVisitor):
elif ( elif (
self.preexisting_objects self.preexisting_objects
and (node.name.value, []) not in self.preexisting_objects and (node.name.value, []) not in self.preexisting_objects
and ( and (isinstance(parent, cst.Module) or (parent2 is not None and not isinstance(parent2, cst.ClassDef)))
isinstance(parent, cst.Module)
or (parent2 is not None and not isinstance(parent2, cst.ClassDef))
)
): ):
self.optim_new_functions.append(node) self.optim_new_functions.append(node)
@ -111,14 +101,8 @@ class OptimFunctionReplacer(cst.CSTTransformer):
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
return False return False
def leave_FunctionDef( def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
self, if original_node.name.value == self.function_name and (self.depth == 0 or (self.depth == 1 and self.in_class)):
original_node: cst.FunctionDef,
updated_node: cst.FunctionDef,
) -> cst.FunctionDef:
if original_node.name.value == self.function_name and (
self.depth == 0 or (self.depth == 1 and self.in_class)
):
return updated_node.with_changes(body=self.optim_body.body, decorators=self.optim_body.decorators) return updated_node.with_changes(body=self.optim_body.body, decorators=self.optim_body.decorators)
return updated_node return updated_node
@ -129,18 +113,14 @@ class OptimFunctionReplacer(cst.CSTTransformer):
self.in_class = (self.depth == 1) and (node.name.value == self.class_name) self.in_class = (self.depth == 1) and (node.name.value == self.class_name)
return self.in_class return self.in_class
def leave_ClassDef( def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
self,
original_node: cst.ClassDef,
updated_node: cst.ClassDef,
) -> cst.ClassDef:
self.depth -= 1 self.depth -= 1
if self.in_class and (self.depth == 0) and (original_node.name.value == self.class_name): if self.in_class and (self.depth == 0) and (original_node.name.value == self.class_name):
self.in_class = False self.in_class = False
return updated_node.with_changes( return updated_node.with_changes(
body=updated_node.body.with_changes( body=updated_node.body.with_changes(
body=(list(updated_node.body.body) + self.optim_new_class_functions), body=(list(updated_node.body.body) + self.optim_new_class_functions)
), )
) )
return updated_node return updated_node
@ -159,15 +139,11 @@ class OptimFunctionReplacer(cst.CSTTransformer):
*node.body[: max_function_index + 1], *node.body[: max_function_index + 1],
*self.optim_new_functions, *self.optim_new_functions,
*node.body[max_function_index + 1 :], *node.body[max_function_index + 1 :],
), )
) )
elif class_index is not None: elif class_index is not None:
node = node.with_changes( node = node.with_changes(
body=( body=(*node.body[: class_index + 1], *self.optim_new_functions, *node.body[class_index + 1 :])
*node.body[: class_index + 1],
*self.optim_new_functions,
*node.body[class_index + 1 :],
),
) )
else: else:
node = node.with_changes(body=(*self.optim_new_functions, *node.body)) node = node.with_changes(body=(*self.optim_new_functions, *node.body))
@ -188,24 +164,21 @@ def replace_functions_in_file(
elif original_function_name.count(".") == 1: elif original_function_name.count(".") == 1:
class_name, function_name = original_function_name.split(".") class_name, function_name = original_function_name.split(".")
else: else:
raise ValueError(f"Don't know how to find {original_function_name} yet!") msg = f"Don't know how to find {original_function_name} yet!"
raise ValueError(msg)
parsed_function_names.append((function_name, class_name)) parsed_function_names.append((function_name, class_name))
module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code)) module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code))
for function_name, class_name in parsed_function_names: for function_name, class_name in parsed_function_names:
visitor = OptimFunctionCollector( visitor = OptimFunctionCollector(function_name, class_name, contextual_functions, preexisting_objects)
function_name,
class_name,
contextual_functions,
preexisting_objects,
)
module.visit(visitor) module.visit(visitor)
if visitor.optim_body is None and not preexisting_objects: if visitor.optim_body is None and not preexisting_objects:
continue continue
if visitor.optim_body is None: if visitor.optim_body is None:
raise ValueError(f"Did not find the function {function_name} in the optimized code") msg = f"Did not find the function {function_name} in the optimized code"
raise ValueError(msg)
transformer = OptimFunctionReplacer( transformer = OptimFunctionReplacer(
visitor.function_name, visitor.function_name,
@ -234,11 +207,7 @@ def replace_functions_and_add_imports(
return add_needed_imports_from_module( return add_needed_imports_from_module(
optimized_code, optimized_code,
replace_functions_in_file( replace_functions_in_file(
source_code, source_code, function_names, optimized_code, preexisting_objects, contextual_functions
function_names,
optimized_code,
preexisting_objects,
contextual_functions,
), ),
file_path_of_module_with_function_to_optimize, file_path_of_module_with_function_to_optimize,
module_abspath, module_abspath,
@ -255,16 +224,6 @@ def replace_function_definitions_in_module(
contextual_functions: set[tuple[str, str]], contextual_functions: set[tuple[str, str]],
project_root_path: Path, project_root_path: Path,
) -> bool: ) -> bool:
""":param function_names: List of qualified (not fully qualified) function names (function_name or
class_name.method_name).
:param optimized_code:
:param file_path_of_module_with_function_to_optimize:
:param module_abspath:
:param preexisting_objects:
:param contextual_functions:
:param project_root_path:
:return:
"""
source_code: str = module_abspath.read_text(encoding="utf8") source_code: str = module_abspath.read_text(encoding="utf8")
new_code: str = replace_functions_and_add_imports( new_code: str = replace_functions_and_add_imports(
source_code, source_code,

View file

@ -12,9 +12,6 @@ packages = [
] ]
keywords = ["codeflash", "performance", "optimization", "ai", "code", "machine learning", "LLM"] keywords = ["codeflash", "performance", "optimization", "ai", "code", "machine learning", "LLM"]
# Don't forget to install the poetry plugins we use too:
# poetry self add poetry-dynamic-versioning
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.9" python = "^3.9"
unidiff = ">=0.7.4" unidiff = ">=0.7.4"
@ -40,6 +37,15 @@ returns = ">=0.23"
isort = ">=5.11.0" isort = ">=5.11.0"
dill = "^0.3.8" dill = "^0.3.8"
rich = "^13.8.1" rich = "^13.8.1"
pandas-stubs = "^2.2.3.241009"
types-Pygments = "^2.18.0.20240506"
types-colorama = "^0.4.15.20240311"
types-decorator = "^5.1.8.20240310"
types-jsonschema = "^4.23.0.20240813"
types-requests = "^2.32.0.20241016"
types-six = "^1.16.21.20241009"
types-cffi = "^1.16.0.20240331"
types-openpyxl = "^3.1.5.20241020"
[tool.poetry.group.dev] [tool.poetry.group.dev]
optional = true optional = true