codeflash-internal/cli/codeflash/code_utils/code_utils.py
2023-12-29 18:37:49 -08:00

74 lines
2.3 KiB
Python

import ast
import os
import site
from tempfile import TemporaryDirectory
from typing import Optional, List
def module_name_from_file_path(file_path: str, project_root: str) -> str:
relative_path = os.path.relpath(file_path, project_root)
module_path = relative_path.replace("/", ".")
if module_path.endswith(".py"):
module_path = module_path[:-3]
return module_path
def file_path_from_module_name(module_name: str, project_root: str) -> str:
"Get file path from module path"
file_path = module_name.replace(".", "/")
if not file_path.endswith(".py"):
file_path += ".py"
return os.path.join(project_root, file_path)
def ellipsis_in_ast(module: ast.AST) -> bool:
for node in ast.walk(module):
if isinstance(node, ast.Constant):
if node.value == ...:
return True
return False
def get_imports_from_file(
file_path: Optional[str] = None,
file_string: Optional[str] = None,
file_ast: Optional[ast.AST] = None,
) -> List[ast.Import | ast.ImportFrom]:
assert (
sum([file_path is not None, file_string is not None, file_ast is not None]) == 1
), "Must provide exactly one of file_path, file_string, or file_ast"
if file_path:
with open(file_path, "r") as file:
file_string = file.read()
if file_ast is None:
file_ast = ast.parse(file_string)
imports = []
for node in ast.walk(file_ast):
if isinstance(node, (ast.Import, ast.ImportFrom)):
imports.append(node)
return imports
def get_all_function_names(code: str) -> List[str]:
module = ast.parse(code)
function_names = []
for node in ast.walk(module):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
function_names.append(node.name)
return function_names
def get_run_tmp_file(file_path: str) -> str:
if not hasattr(get_run_tmp_file, "tmpdir"):
get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
return os.path.join(get_run_tmp_file.tmpdir.name, file_path)
def path_belongs_to_site_packages(file_path: str) -> bool:
return any( # The definition is not part of a site-package Python library
[
file_path.startswith(site_package_path + os.sep)
for site_package_path in site.getsitepackages()
]
)