some refactoring
This commit is contained in:
parent
3f228539f5
commit
693c150262
9 changed files with 152 additions and 129 deletions
|
|
@ -91,3 +91,13 @@ def is_class_defined_in_file(class_name: str, file_path: Path) -> bool:
|
|||
source = file.read()
|
||||
tree = ast.parse(source)
|
||||
return any(isinstance(node, ast.ClassDef) and node.name == class_name for node in ast.walk(tree))
|
||||
|
||||
|
||||
def validate_python_code(code: str) -> str:
    """Return ``code`` unchanged when it compiles as a Python module.

    The source is compiled in ``exec`` mode under the pseudo-filename
    ``<string>``; it is never executed.

    Raises:
        ValueError: if compilation raises ``SyntaxError``. The message carries
            the parser's description together with the line and column, and
            the original ``SyntaxError`` is chained as the cause.
    """
    try:
        compile(code, "<string>", "exec")
    except SyntaxError as syntax_error:
        location = f"(line {syntax_error.lineno}, column {syntax_error.offset})"
        raise ValueError(f"Invalid Python code: {syntax_error.msg} {location}") from syntax_error
    else:
        return code
|
||||
|
|
|
|||
0
codeflash/context/__init__.py
Normal file
0
codeflash/context/__init__.py
Normal file
|
|
@ -1,6 +1,125 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import jedi
|
||||
import libcst as cst
|
||||
from jedi.api.classes import Name
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
|
||||
from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeString, CodeStringsMarkdown
|
||||
from codeflash.optimization.function_context import belongs_to_class, belongs_to_function
|
||||
|
||||
|
||||
def get_code_optimization_context(
    function_to_optimize: FunctionToOptimize, project_root_path: Path, token_limit: int = 8000
) -> tuple[str, str]:
    """Collect the read-writable and read-only code context for ``function_to_optimize``.

    Jedi lists every reference in the target file; references are kept when
    ``belongs_to_function`` places them in the target function (and, when the
    target has parents, ``belongs_to_class`` places them in the last parent).
    Each kept reference is resolved with ``goto``; a definition is recorded as a
    helper when its module path is under ``project_root_path``, is not in
    site-packages, has a full name, is not inside the target function, and its
    full name differs from its module name (presumably filtering out modules).

    For every file holding the target or a helper, the read-writable code is
    appended to one combined string and the read-only code is collected as a
    per-file ``CodeString``; both get their imports filled in by
    ``add_needed_imports_from_module``.

    Args:
        function_to_optimize: The function whose context is being gathered.
        project_root_path: Root of the project; also used as the Jedi project path.
        token_limit: Currently unused -- the token budget is not enforced yet.

    Returns:
        A ``(read_writable_code, read_only_markdown)`` tuple of strings.

    Raises:
        Whatever ``CodeString`` raises when the combined read-writable code
        fails its validation (presumably a pydantic validation error -- confirm).
    """
    function_name = function_to_optimize.function_name
    target_file_path = function_to_optimize.file_path
    script = jedi.Script(path=target_file_path, project=jedi.Project(path=project_root_path))
    file_path_to_qualified_function_names = defaultdict(set)
    file_path_to_qualified_function_names[target_file_path].add(function_to_optimize.qualified_name)
    read_only_code_markdown = CodeStringsMarkdown()
    final_read_writable_code = ""

    names = []
    for ref in script.get_names(all_scopes=True, definitions=False, references=True):
        if not ref.full_name:
            continue
        if function_to_optimize.parents:
            # Check if the reference belongs to the specified class when FunctionParent is provided
            if belongs_to_class(ref, function_to_optimize.parents[-1].name) and belongs_to_function(
                ref, function_name
            ):
                names.append(ref)
        elif belongs_to_function(ref, function_name):
            names.append(ref)

    for name in names:
        try:
            definitions: list[Name] = name.goto(follow_imports=True, follow_builtin_imports=False)
        except Exception as e:
            # name.full_name can also throw exceptions sometimes, so resolve it defensively and
            # always report the original goto failure rather than the secondary one.
            try:
                described_name = f" for {name.full_name}"
            except Exception:
                described_name = ""
            logger.exception(f"Error while getting definition{described_name}: {e}")
            definitions = []
        if not definitions:
            continue
        # TODO: there can be multiple definitions, see how to handle such cases
        definition = definitions[0]
        definition_path = definition.module_path

        # The definition is part of this project and not defined within the original function
        if (
            str(definition_path).startswith(str(project_root_path) + os.sep)
            and not path_belongs_to_site_packages(definition_path)
            and definition.full_name
            and not belongs_to_function(definition, function_name)
            and definition.module_name != definition.full_name
        ):
            file_path_to_qualified_function_names[definition_path].add(
                get_qualified_name(definition.module_name, definition.full_name)
            )

    # A distinct loop variable keeps the target's own path from being shadowed.
    for helper_file_path, qualified_function_names in file_path_to_qualified_function_names.items():
        try:
            og_code_containing_helpers = helper_file_path.read_text("utf8")
        except Exception as e:
            logger.exception(f"Error while parsing {helper_file_path}: {e}")
            continue
        try:
            read_writable_code = get_read_writable_code(og_code_containing_helpers, qualified_function_names)
        except ValueError as e:
            logger.debug(f"Error while getting read-writable code: {e}")
            continue

        if read_writable_code:
            final_read_writable_code += f"\n{read_writable_code}"
            final_read_writable_code = add_needed_imports_from_module(
                src_module_code=og_code_containing_helpers,
                dst_module_code=final_read_writable_code,
                src_path=helper_file_path,
                dst_path=helper_file_path,
                project_root=project_root_path,
                helper_functions_fqn=qualified_function_names,
            )

        try:
            read_only_code = get_read_only_code(og_code_containing_helpers, qualified_function_names)
        except ValueError as e:
            logger.debug(f"Error while getting read-only code: {e}")
            continue

        read_only_code_with_imports = CodeString(
            code=add_needed_imports_from_module(
                src_module_code=og_code_containing_helpers,
                dst_module_code=read_only_code,
                src_path=helper_file_path,
                dst_path=helper_file_path,
                project_root=project_root_path,
                helper_functions_fqn=qualified_function_names,
            ),
            file_path=Path(helper_file_path),
        )
        if read_only_code_with_imports.code:
            read_only_code_markdown.code_strings.append(read_only_code_with_imports)

    # TODO: enforce token_limit -- when the read-writable code exceeds the budget, drop the helper
    # functions and keep only the function to optimize.
    read_only_markdown = read_only_code_markdown.markdown
    # Goes through the logger instead of a bare print so library callers' stdout stays clean.
    logger.debug(read_only_markdown)
    return CodeString(code=final_read_writable_code).code, read_only_markdown
|
||||
|
||||
|
||||
def is_dunder_method(name: str) -> bool:
|
||||
|
|
@ -99,7 +218,7 @@ def get_read_writable_code(code: str, target_functions: set[str]) -> str:
|
|||
|
||||
|
||||
def prune_cst_for_read_only_code(
|
||||
node: cst.CSTNode, target_functions: set[str], prefix: str = ""
|
||||
node: cst.CSTNode, target_functions: set[str], prefix: str = "", remove_docstrings: bool = False
|
||||
) -> tuple[cst.CSTNode | None, bool]:
|
||||
"""Recursively filter the node for read-only context:
|
||||
|
||||
|
|
@ -136,7 +255,9 @@ def prune_cst_for_read_only_code(
|
|||
found_in_class = False
|
||||
new_body = []
|
||||
for stmt in node.body.body:
|
||||
filtered, found_target = prune_cst_for_read_only_code(stmt, target_functions, class_prefix)
|
||||
filtered, found_target = prune_cst_for_read_only_code(
|
||||
stmt, target_functions, class_prefix, remove_docstrings
|
||||
)
|
||||
found_in_class |= found_target
|
||||
|
||||
if isinstance(filtered, cst.FunctionDef):
|
||||
|
|
@ -166,7 +287,9 @@ def prune_cst_for_read_only_code(
|
|||
new_children = []
|
||||
section_found_target = False
|
||||
for child in original_content:
|
||||
filtered, found_target = prune_cst_for_read_only_code(child, target_functions, prefix)
|
||||
filtered, found_target = prune_cst_for_read_only_code(
|
||||
child, target_functions, prefix, remove_docstrings
|
||||
)
|
||||
if filtered:
|
||||
new_children.append(filtered)
|
||||
section_found_target |= found_target
|
||||
|
|
@ -175,7 +298,9 @@ def prune_cst_for_read_only_code(
|
|||
found_any_target |= section_found_target
|
||||
updates[section] = new_children
|
||||
elif original_content is not None:
|
||||
filtered, found_target = prune_cst_for_read_only_code(original_content, target_functions, prefix)
|
||||
filtered, found_target = prune_cst_for_read_only_code(
|
||||
original_content, target_functions, prefix, remove_docstrings
|
||||
)
|
||||
found_any_target |= found_target
|
||||
if filtered:
|
||||
updates[section] = filtered
|
||||
|
|
@ -186,12 +311,12 @@ def prune_cst_for_read_only_code(
|
|||
return node, found_any_target
|
||||
|
||||
|
||||
def get_read_only_code(code: str, target_functions: set[str]) -> str:
|
||||
def get_read_only_code(code: str, target_functions: set[str], remove_docstrings: bool = False) -> str:
|
||||
"""Creates a read-only version of the code by parsing and filtering the code to keep only
|
||||
class contextual information, and other module scoped variables.
|
||||
"""
|
||||
module = cst.parse_module(code)
|
||||
filtered_node, found_target = prune_cst_for_read_only_code(module, target_functions)
|
||||
filtered_node, found_target = prune_cst_for_read_only_code(module, target_functions, remove_docstrings)
|
||||
if not found_target:
|
||||
raise ValueError("No target functions found in the provided code")
|
||||
if filtered_node and isinstance(filtered_node, cst.Module):
|
||||
|
|
@ -16,6 +16,7 @@ from typing_extensions import Annotated
|
|||
from codeflash.cli_cmds.console import console, logger
|
||||
from codeflash.code_utils.coverage_utils import extract_dependent_function, generate_candidates
|
||||
from codeflash.code_utils.env_utils import is_end_to_end
|
||||
from codeflash.code_utils.code_utils import validate_python_code
|
||||
from codeflash.verification.test_results import TestResults, TestType
|
||||
|
||||
# If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully
|
||||
|
|
@ -55,19 +56,9 @@ class BestOptimization(BaseModel):
|
|||
winning_test_results: TestResults
|
||||
|
||||
|
||||
def validate_python_code(code: str) -> str:
    """Check that ``code`` is syntactically valid Python and hand it back.

    The check is a plain ``compile`` in ``exec`` mode; nothing is executed.

    Raises:
        ValueError: when ``compile`` raises ``SyntaxError``; the message includes
            the syntax error's description, line and column, and the
            ``SyntaxError`` is chained as the cause.
    """
    try:
        compile(code, "<string>", "exec")
    except SyntaxError as exc:
        details = f"{exc.msg} (line {exc.lineno}, column {exc.offset})"
        raise ValueError("Invalid Python code: " + details) from exc
    return code
|
||||
|
||||
|
||||
class CodeString(BaseModel):
|
||||
code: Annotated[str, AfterValidator(validate_python_code)]
|
||||
file_path: Path | None = None
|
||||
file_path: Optional[Path] = None
|
||||
|
||||
|
||||
class CodeStringsMarkdown(BaseModel):
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ from codeflash.code_utils.instrument_existing_tests import inject_profiling_into
|
|||
from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests
|
||||
from codeflash.code_utils.static_analysis import analyze_imported_modules, get_first_top_level_function_or_method_ast
|
||||
from codeflash.code_utils.time_utils import humanize_runtime
|
||||
from codeflash.context import code_context_extractor
|
||||
from codeflash.discovery.discover_unit_tests import discover_unit_tests
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize, get_functions_to_optimize
|
||||
from codeflash.either import Failure, Success, is_successful
|
||||
|
|
@ -76,8 +77,6 @@ if TYPE_CHECKING:
|
|||
from codeflash.either import Result
|
||||
from codeflash.models.models import CoverageData, FunctionCalledInTest, FunctionSource, OptimizedCandidate
|
||||
|
||||
from codeflash.optimization import retriever
|
||||
|
||||
|
||||
class Optimizer:
|
||||
def __init__(self, args: Namespace) -> None:
|
||||
|
|
@ -716,9 +715,13 @@ class Optimizer:
|
|||
contextual_dunder_methods.update(helper_dunder_methods)
|
||||
|
||||
# Will eventually refactor to use this function instead of the above
|
||||
read_writable_code, read_only_context_code = retriever.get_code_optimization_context(
|
||||
read_writable_code, read_only_context_code = code_context_extractor.get_code_optimization_context(
|
||||
function_to_optimize, project_root
|
||||
)
|
||||
logger.info("Read-writable code:")
|
||||
code_print(read_writable_code)
|
||||
logger.info("Read-only context code:")
|
||||
# code_print(read_only_context_code)
|
||||
return Success(
|
||||
CodeOptimizationContext(
|
||||
code_to_optimize_with_helpers=code_to_optimize_with_helpers_and_imports,
|
||||
|
|
|
|||
|
|
@ -1,106 +0,0 @@
|
|||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import jedi
|
||||
from jedi.api.classes import Name
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
|
||||
from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeString, CodeStringsMarkdown
|
||||
from codeflash.optimization.cst_manipulator import get_read_only_code, get_read_writable_code
|
||||
from codeflash.optimization.function_context import belongs_to_class, belongs_to_function
|
||||
|
||||
|
||||
def get_code_optimization_context(function_to_optimize: FunctionToOptimize, project_root_path: Path) -> tuple[str, str]:
    """Build the read-writable and read-only code context for a function.

    References that ``belongs_to_function`` places inside the target function
    (and, when the target has parents, that ``belongs_to_class`` places inside
    the last parent) are resolved with Jedi. Definitions under
    ``project_root_path`` that are not in site-packages are grouped by file.
    Each such file contributes to one combined read-writable string and to a
    per-file read-only ``CodeString``.

    Returns:
        A ``(read_writable_code, read_only_markdown)`` tuple of strings.
    """
    target_name = function_to_optimize.function_name
    target_path = function_to_optimize.file_path
    jedi_project = jedi.Project(path=project_root_path)
    script = jedi.Script(path=target_path, project=jedi_project)

    helpers_by_file = defaultdict(set)
    helpers_by_file[target_path].add(function_to_optimize.qualified_name)
    read_only_markdown = CodeStringsMarkdown()
    writable_code = ""

    # Pass 1: keep only the references made from inside the target function.
    references = []
    for candidate in script.get_names(all_scopes=True, definitions=False, references=True):
        if not candidate.full_name:
            continue
        # Check if the reference belongs to the specified class when FunctionParent is provided
        if function_to_optimize.parents and not belongs_to_class(
            candidate, function_to_optimize.parents[-1].name
        ):
            continue
        if belongs_to_function(candidate, target_name):
            references.append(candidate)

    # Pass 2: resolve each reference and record project-local helper definitions.
    project_prefix_of = str  # paths are compared as plain strings
    for reference in references:
        try:
            definitions: list[Name] = reference.goto(follow_imports=True, follow_builtin_imports=False)
        except Exception as e:
            try:
                logger.exception(f"Error while getting definition for {reference.full_name}: {e}")
            except Exception as e:
                # name.full_name can also throw exceptions sometimes
                logger.exception(f"Error while getting definition: {e}")
            continue
        if not definitions:
            continue

        # TODO: there can be multiple definitions, see how to handle such cases
        definition = definitions[0]
        definition_path = definition.module_path

        # The definition is part of this project and not defined within the original function
        if not project_prefix_of(definition_path).startswith(project_prefix_of(project_root_path) + os.sep):
            continue
        if path_belongs_to_site_packages(definition_path):
            continue
        if not definition.full_name:
            continue
        if belongs_to_function(definition, target_name):
            continue
        if definition.module_name == definition.full_name:
            continue
        qualified = get_qualified_name(definition.module_name, definition.full_name)
        helpers_by_file[definition_path].add(qualified)

    # Pass 3: extract both views of every file that holds the target or a helper.
    for source_path, qualified_names in helpers_by_file.items():
        try:
            source_code = source_path.read_text("utf8")
        except Exception as e:
            logger.exception(f"Error while parsing {source_path}: {e}")
            continue

        try:
            writable_part = get_read_writable_code(source_code, qualified_names)
        except ValueError as e:
            logger.debug(f"Error while getting read-writable code: {e}")
            continue

        if writable_part:
            writable_code += f"\n{writable_part}"
            writable_code = add_needed_imports_from_module(
                src_module_code=source_code,
                dst_module_code=writable_code,
                src_path=source_path,
                dst_path=source_path,
                project_root=project_root_path,
                helper_functions_fqn=qualified_names,
            )

        try:
            read_only_part = get_read_only_code(source_code, qualified_names)
        except ValueError as e:
            logger.debug(f"Error while getting read-only code: {e}")
            continue

        read_only_with_imports = add_needed_imports_from_module(
            src_module_code=source_code,
            dst_module_code=read_only_part,
            src_path=source_path,
            dst_path=source_path,
            project_root=project_root_path,
            helper_functions_fqn=qualified_names,
        )
        read_only_entry = CodeString(code=read_only_with_imports, file_path=Path(source_path))
        if read_only_entry.code:
            read_only_markdown.code_strings.append(read_only_entry)

    return CodeString(code=writable_code).code, read_only_markdown.markdown
|
||||
|
|
@ -9,7 +9,7 @@ from textwrap import dedent
|
|||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
from codeflash.optimization.retriever import get_code_optimization_context
|
||||
from codeflash.context.code_context_extractor import get_code_optimization_context
|
||||
|
||||
|
||||
class HelperClass:
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
from textwrap import dedent
|
||||
|
||||
import pytest
|
||||
from codeflash.optimization.cst_manipulator import get_read_only_code
|
||||
from codeflash.context.code_context_extractor import get_read_only_code
|
||||
|
||||
|
||||
def test_basic_class() -> None:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from textwrap import dedent
|
||||
|
||||
import pytest
|
||||
from codeflash.optimization.cst_manipulator import get_read_writable_code
|
||||
from codeflash.context.code_context_extractor import get_read_writable_code
|
||||
|
||||
|
||||
def test_simple_function() -> None:
|
||||
|
|
|
|||
Loading…
Reference in a new issue