some refactoring

This commit is contained in:
Alvin Ryanputra 2024-12-23 12:27:31 -08:00
parent 3f228539f5
commit 693c150262
9 changed files with 152 additions and 129 deletions

View file

@ -91,3 +91,13 @@ def is_class_defined_in_file(class_name: str, file_path: Path) -> bool:
source = file.read()
tree = ast.parse(source)
return any(isinstance(node, ast.ClassDef) and node.name == class_name for node in ast.walk(tree))
def validate_python_code(code: str) -> str:
    """Check that ``code`` compiles as a Python module and hand it back unchanged.

    Args:
        code: Source text to check.

    Returns:
        The same ``code`` string, untouched, when compilation succeeds.

    Raises:
        ValueError: If compiling raises ``SyntaxError``. The message carries the
            parser's message, line and column, and the original error is chained.
    """
    try:
        compile(code, "<string>", "exec")
    except SyntaxError as syntax_error:
        location = f"(line {syntax_error.lineno}, column {syntax_error.offset})"
        raise ValueError(f"Invalid Python code: {syntax_error.msg} {location}") from syntax_error
    else:
        return code

View file

View file

@ -1,6 +1,125 @@
from __future__ import annotations
import os
from collections import defaultdict
from pathlib import Path
import jedi
import libcst as cst
from jedi.api.classes import Name
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeString, CodeStringsMarkdown
from codeflash.optimization.function_context import belongs_to_class, belongs_to_function
def get_code_optimization_context(
    function_to_optimize: FunctionToOptimize, project_root_path: Path, token_limit: int = 8000
) -> tuple[str, str]:
    """Collect the read-writable code and read-only context for one function.

    jedi is used to list every name referenced inside ``function_to_optimize`` and to
    resolve each reference to its definition. Definitions whose module path lies under
    ``project_root_path``, that ``path_belongs_to_site_packages`` does not claim, and that
    are not defined inside the function itself are grouped by file. For every such file
    the read-writable code and the read-only code are extracted and the imports they need
    are added.

    Args:
        function_to_optimize: The function whose context is collected.
        project_root_path: Project root; definitions outside it are ignored.
        token_limit: Accepted for forward compatibility; not enforced by this function yet.

    Returns:
        ``(read_writable_code, read_only_markdown)``: the concatenated read-writable code
        after it has passed through ``CodeString`` validation, and the markdown rendering
        of the per-file read-only code.
    """
    function_name = function_to_optimize.function_name
    target_file_path = function_to_optimize.file_path
    script = jedi.Script(path=target_file_path, project=jedi.Project(path=project_root_path))
    file_path_to_qualified_function_names = defaultdict(set)
    file_path_to_qualified_function_names[target_file_path].add(function_to_optimize.qualified_name)
    read_only_code_markdown = CodeStringsMarkdown()
    final_read_writable_code = ""

    # Loop-invariant values, computed once.
    parents = function_to_optimize.parents
    project_prefix = str(project_root_path) + os.sep

    # Step 1: gather the references that sit inside the target function (and, for
    # methods, inside the innermost parent class).
    names = []
    for ref in script.get_names(all_scopes=True, definitions=False, references=True):
        if not ref.full_name:
            continue
        if parents:
            if belongs_to_class(ref, parents[-1].name) and belongs_to_function(ref, function_name):
                names.append(ref)
        elif belongs_to_function(ref, function_name):
            names.append(ref)

    # Step 2: resolve each reference and record project-local helper definitions per file.
    for name in names:
        try:
            definitions: list[Name] = name.goto(follow_imports=True, follow_builtin_imports=False)
        except Exception as e:
            try:
                logger.exception(f"Error while getting definition for {name.full_name}: {e}")
            except Exception as full_name_error:
                # name.full_name can also throw exceptions sometimes
                logger.exception(f"Error while getting definition: {full_name_error}")
            continue
        if not definitions:
            continue
        # TODO: there can be multiple definitions, see how to handle such cases
        definition = definitions[0]
        definition_path = definition.module_path
        # The definition is part of this project and not defined within the original function
        if (
            str(definition_path).startswith(project_prefix)
            and not path_belongs_to_site_packages(definition_path)
            and definition.full_name
            and not belongs_to_function(definition, function_name)
            and definition.module_name != definition.full_name
        ):
            file_path_to_qualified_function_names[definition_path].add(
                get_qualified_name(definition.module_name, definition.full_name)
            )

    # Step 3: extract the two code views file by file. A file that cannot be read or
    # pruned is logged and skipped rather than aborting the whole context.
    for file_path, qualified_function_names in file_path_to_qualified_function_names.items():
        try:
            og_code_containing_helpers = file_path.read_text("utf8")
        except Exception as e:
            logger.exception(f"Error while parsing {file_path}: {e}")
            continue
        try:
            read_writable_code = get_read_writable_code(og_code_containing_helpers, qualified_function_names)
        except ValueError as e:
            logger.debug(f"Error while getting read-writable code: {e}")
            continue

        # NOTE(review): the import pass is assumed to run only when this file contributed
        # read-writable code - confirm against the original indentation.
        if read_writable_code:
            final_read_writable_code += f"\n{read_writable_code}"
            final_read_writable_code = add_needed_imports_from_module(
                src_module_code=og_code_containing_helpers,
                dst_module_code=final_read_writable_code,
                src_path=file_path,
                dst_path=file_path,
                project_root=project_root_path,
                helper_functions_fqn=qualified_function_names,
            )

        try:
            read_only_code = get_read_only_code(og_code_containing_helpers, qualified_function_names)
        except ValueError as e:
            logger.debug(f"Error while getting read-only code: {e}")
            continue

        read_only_code_with_imports = CodeString(
            code=add_needed_imports_from_module(
                src_module_code=og_code_containing_helpers,
                dst_module_code=read_only_code,
                src_path=file_path,
                dst_path=file_path,
                project_root=project_root_path,
                helper_functions_fqn=qualified_function_names,
            ),
            file_path=Path(file_path),
        )
        if read_only_code_with_imports.code:
            read_only_code_markdown.code_strings.append(read_only_code_with_imports)

    read_only_markdown = read_only_code_markdown.markdown
    # Debug output goes through the logger instead of being printed to stdout.
    logger.debug(f"Read-only context code:\n{read_only_markdown}")
    return CodeString(code=final_read_writable_code).code, read_only_markdown
def is_dunder_method(name: str) -> bool:
@ -99,7 +218,7 @@ def get_read_writable_code(code: str, target_functions: set[str]) -> str:
def prune_cst_for_read_only_code(
node: cst.CSTNode, target_functions: set[str], prefix: str = ""
node: cst.CSTNode, target_functions: set[str], prefix: str = "", remove_docstrings: bool = False
) -> tuple[cst.CSTNode | None, bool]:
"""Recursively filter the node for read-only context:
@ -136,7 +255,9 @@ def prune_cst_for_read_only_code(
found_in_class = False
new_body = []
for stmt in node.body.body:
filtered, found_target = prune_cst_for_read_only_code(stmt, target_functions, class_prefix)
filtered, found_target = prune_cst_for_read_only_code(
stmt, target_functions, class_prefix, remove_docstrings
)
found_in_class |= found_target
if isinstance(filtered, cst.FunctionDef):
@ -166,7 +287,9 @@ def prune_cst_for_read_only_code(
new_children = []
section_found_target = False
for child in original_content:
filtered, found_target = prune_cst_for_read_only_code(child, target_functions, prefix)
filtered, found_target = prune_cst_for_read_only_code(
child, target_functions, prefix, remove_docstrings
)
if filtered:
new_children.append(filtered)
section_found_target |= found_target
@ -175,7 +298,9 @@ def prune_cst_for_read_only_code(
found_any_target |= section_found_target
updates[section] = new_children
elif original_content is not None:
filtered, found_target = prune_cst_for_read_only_code(original_content, target_functions, prefix)
filtered, found_target = prune_cst_for_read_only_code(
original_content, target_functions, prefix, remove_docstrings
)
found_any_target |= found_target
if filtered:
updates[section] = filtered
@ -186,12 +311,12 @@ def prune_cst_for_read_only_code(
return node, found_any_target
def get_read_only_code(code: str, target_functions: set[str]) -> str:
def get_read_only_code(code: str, target_functions: set[str], remove_docstrings: bool = False) -> str:
"""Creates a read-only version of the code by parsing and filtering the code to keep only
class contextual information, and other module scoped variables.
"""
module = cst.parse_module(code)
filtered_node, found_target = prune_cst_for_read_only_code(module, target_functions)
filtered_node, found_target = prune_cst_for_read_only_code(module, target_functions, remove_docstrings)
if not found_target:
raise ValueError("No target functions found in the provided code")
if filtered_node and isinstance(filtered_node, cst.Module):

View file

@ -16,6 +16,7 @@ from typing_extensions import Annotated
from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.coverage_utils import extract_dependent_function, generate_candidates
from codeflash.code_utils.env_utils import is_end_to_end
from codeflash.code_utils.code_utils import validate_python_code
from codeflash.verification.test_results import TestResults, TestType
# If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully
@ -55,19 +56,9 @@ class BestOptimization(BaseModel):
winning_test_results: TestResults
def validate_python_code(code: str) -> str:
    """Return ``code`` as-is once it has been shown to compile in ``exec`` mode.

    Raises:
        ValueError: When compilation raises ``SyntaxError``; the message reports the
            syntax error's text, line number and column, and chains the original error.
    """
    try:
        compile(code, "<string>", "exec")
    except SyntaxError as exc:
        raise ValueError(
            f"Invalid Python code: {exc.msg} (line {exc.lineno}, column {exc.offset})"
        ) from exc
    return code
# BaseModel subclass that pairs a string of Python source with an optional file path.
class CodeString(BaseModel):
# `code` is passed to validate_python_code through an AfterValidator hook, so source
# that does not compile is rejected when the model is built.
code: Annotated[str, AfterValidator(validate_python_code)]
# NOTE(review): file_path is listed twice because this hunk shows both sides of the
# change; presumably the `Path | None` line is the removed one and `Optional[Path]`
# is the replacement - confirm against the committed file.
file_path: Path | None = None
file_path: Optional[Path] = None
class CodeStringsMarkdown(BaseModel):

View file

@ -40,6 +40,7 @@ from codeflash.code_utils.instrument_existing_tests import inject_profiling_into
from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests
from codeflash.code_utils.static_analysis import analyze_imported_modules, get_first_top_level_function_or_method_ast
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.context import code_context_extractor
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.discovery.functions_to_optimize import FunctionToOptimize, get_functions_to_optimize
from codeflash.either import Failure, Success, is_successful
@ -76,8 +77,6 @@ if TYPE_CHECKING:
from codeflash.either import Result
from codeflash.models.models import CoverageData, FunctionCalledInTest, FunctionSource, OptimizedCandidate
from codeflash.optimization import retriever
class Optimizer:
def __init__(self, args: Namespace) -> None:
@ -716,9 +715,13 @@ class Optimizer:
contextual_dunder_methods.update(helper_dunder_methods)
# Will eventually refactor to use this function instead of the above
read_writable_code, read_only_context_code = retriever.get_code_optimization_context(
read_writable_code, read_only_context_code = code_context_extractor.get_code_optimization_context(
function_to_optimize, project_root
)
logger.info("Read-writable code:")
code_print(read_writable_code)
logger.info("Read-only context code:")
# code_print(read_only_context_code)
return Success(
CodeOptimizationContext(
code_to_optimize_with_helpers=code_to_optimize_with_helpers_and_imports,

View file

@ -1,106 +0,0 @@
import os
from collections import defaultdict
from pathlib import Path
import jedi
from jedi.api.classes import Name
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeString, CodeStringsMarkdown
from codeflash.optimization.cst_manipulator import get_read_only_code, get_read_writable_code
from codeflash.optimization.function_context import belongs_to_class, belongs_to_function
def get_code_optimization_context(function_to_optimize: FunctionToOptimize, project_root_path: Path) -> tuple[str, str]:
    """Assemble the editable code and the read-only context for a target function.

    References made inside the target are found with jedi and followed to their
    definitions. Those defined under ``project_root_path`` (and not flagged by
    ``path_belongs_to_site_packages`` or located inside the target itself) are grouped by
    file. Each file then contributes its read-writable and read-only code, with the
    imports they need added.

    Returns:
        A pair ``(read_writable_code, read_only_markdown)``.
    """
    target_name = function_to_optimize.function_name
    source_path = function_to_optimize.file_path
    script = jedi.Script(path=source_path, project=jedi.Project(path=project_root_path))
    targets_per_file = defaultdict(set)
    targets_per_file[source_path].add(function_to_optimize.qualified_name)
    read_only_code_markdown = CodeStringsMarkdown()
    writable_code = ""

    def _inside_target(candidate) -> bool:
        # For methods the reference must also sit in the innermost parent class.
        enclosing = function_to_optimize.parents
        if enclosing and not belongs_to_class(candidate, enclosing[-1].name):
            return False
        return belongs_to_function(candidate, target_name)

    references = [
        ref
        for ref in script.get_names(all_scopes=True, definitions=False, references=True)
        if ref.full_name and _inside_target(ref)
    ]

    root_prefix = str(project_root_path) + os.sep
    for reference in references:
        try:
            resolved: list[Name] = reference.goto(follow_imports=True, follow_builtin_imports=False)
        except Exception as e:
            try:
                logger.exception(f"Error while getting definition for {reference.full_name}: {e}")
            except Exception as nested:
                # reading full_name may itself raise, so fall back to a shorter message
                logger.exception(f"Error while getting definition: {nested}")
            continue
        if not resolved:
            continue
        # TODO: there can be multiple definitions, see how to handle such cases
        first = resolved[0]
        origin = first.module_path
        # Keep only definitions from this project that are not defined within the target function.
        if not str(origin).startswith(root_prefix):
            continue
        if path_belongs_to_site_packages(origin):
            continue
        if not first.full_name:
            continue
        if belongs_to_function(first, target_name):
            continue
        if first.module_name == first.full_name:
            continue
        targets_per_file[origin].add(get_qualified_name(first.module_name, first.full_name))

    for helper_path, wanted_names in targets_per_file.items():
        try:
            module_source = helper_path.read_text("utf8")
        except Exception as e:
            logger.exception(f"Error while parsing {helper_path}: {e}")
            continue

        try:
            writable_part = get_read_writable_code(module_source, wanted_names)
        except ValueError as e:
            logger.debug(f"Error while getting read-writable code: {e}")
            continue

        # Arguments common to both import-resolution calls for this file.
        import_args = {
            "src_module_code": module_source,
            "src_path": helper_path,
            "dst_path": helper_path,
            "project_root": project_root_path,
            "helper_functions_fqn": wanted_names,
        }

        # NOTE(review): the import pass is assumed to sit under this condition, since
        # the indentation of the original was lost - confirm.
        if writable_part:
            writable_code += f"\n{writable_part}"
            writable_code = add_needed_imports_from_module(dst_module_code=writable_code, **import_args)

        try:
            read_only_part = get_read_only_code(module_source, wanted_names)
        except ValueError as e:
            logger.debug(f"Error while getting read-only code: {e}")
            continue

        read_only_entry = CodeString(
            code=add_needed_imports_from_module(dst_module_code=read_only_part, **import_args),
            file_path=Path(helper_path),
        )
        if read_only_entry.code:
            read_only_code_markdown.code_strings.append(read_only_entry)

    return CodeString(code=writable_code).code, read_only_code_markdown.markdown

View file

@ -9,7 +9,7 @@ from textwrap import dedent
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent
from codeflash.optimization.optimizer import Optimizer
from codeflash.optimization.retriever import get_code_optimization_context
from codeflash.context.code_context_extractor import get_code_optimization_context
class HelperClass:

View file

@ -1,7 +1,7 @@
from textwrap import dedent
import pytest
from codeflash.optimization.cst_manipulator import get_read_only_code
from codeflash.context.code_context_extractor import get_read_only_code
def test_basic_class() -> None:

View file

@ -1,7 +1,7 @@
from textwrap import dedent
import pytest
from codeflash.optimization.cst_manipulator import get_read_writable_code
from codeflash.context.code_context_extractor import get_read_writable_code
def test_simple_function() -> None: