added helpers of helpers into readonly context, and refactored code slightly
This commit is contained in:
parent
641088fad5
commit
a10b399dbe
6 changed files with 779 additions and 227 deletions
|
|
@ -0,0 +1,27 @@
|
|||
from code_to_optimize.code_directories.retriever.utils import DataProcessor
|
||||
|
||||
|
||||
class DataTransformer:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
|
||||
def transform(self, data):
|
||||
self.data = data
|
||||
return self.data
|
||||
|
||||
def transform_using_own_method(self, data):
|
||||
return self.transform(data)
|
||||
|
||||
def transform_using_same_file_function(self, data):
|
||||
return update_data(data)
|
||||
|
||||
def transform_data_all_same_file(self, data):
|
||||
new_data = update_data(data)
|
||||
return self.transform_using_own_method(new_data)
|
||||
|
||||
def circular_dependency(self, data):
|
||||
return DataProcessor().circular_dependency(data)
|
||||
|
||||
|
||||
def update_data(data):
|
||||
return data + " updated"
|
||||
|
|
@ -1,5 +1,9 @@
|
|||
import math
|
||||
|
||||
from transform_utils import DataTransformer
|
||||
|
||||
GLOBAL_VAR = 10
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
"""A class for processing data."""
|
||||
|
|
@ -25,3 +29,19 @@ class DataProcessor:
|
|||
|
||||
def do_something(self):
|
||||
print("something")
|
||||
|
||||
def transform_data(self, data: str) -> str:
|
||||
"""Transform the processed data"""
|
||||
return DataTransformer().transform(data)
|
||||
|
||||
def transform_data_own_method(self, data: str) -> str:
|
||||
"""Transform the processed data using own method"""
|
||||
return DataTransformer().transform_using_own_method(data)
|
||||
|
||||
def transform_data_same_file_function(self, data: str) -> str:
|
||||
"""Transform the processed data using a function from the same file"""
|
||||
return DataTransformer().transform_using_same_file_function(data)
|
||||
|
||||
def circular_dependency(self, data: str) -> str:
|
||||
"""Test circular dependency"""
|
||||
return DataTransformer().circular_dependency(data)
|
||||
|
|
|
|||
|
|
@ -15,99 +15,36 @@ from codeflash.code_utils.code_extractor import add_needed_imports_from_module
|
|||
from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeString, CodeStringsMarkdown
|
||||
from codeflash.optimization.function_context import belongs_to_class, belongs_to_function
|
||||
from codeflash.optimization.function_context import belongs_to_function_qualified
|
||||
|
||||
|
||||
def get_code_optimization_context(
|
||||
function_to_optimize: FunctionToOptimize, project_root_path: Path, token_limit: int = 8000
|
||||
) -> tuple[str, str]:
|
||||
function_name = function_to_optimize.function_name
|
||||
file_path = function_to_optimize.file_path
|
||||
script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path))
|
||||
file_path_to_qualified_function_names = defaultdict(set)
|
||||
file_path_to_qualified_function_names[file_path].add(function_to_optimize.qualified_name)
|
||||
read_only_code_markdown = CodeStringsMarkdown()
|
||||
final_read_writable_code = ""
|
||||
names = []
|
||||
for ref in script.get_names(all_scopes=True, definitions=False, references=True):
|
||||
if ref.full_name:
|
||||
if function_to_optimize.parents:
|
||||
# Check if the reference belongs to the specified class when FunctionParent is provided
|
||||
if belongs_to_class(ref, function_to_optimize.parents[-1].name) and belongs_to_function(
|
||||
ref, function_name
|
||||
):
|
||||
names.append(ref)
|
||||
elif belongs_to_function(ref, function_name):
|
||||
names.append(ref)
|
||||
# Get qualified names and fully qualified names(fqn) of helpers
|
||||
helpers_of_fto, helpers_of_fto_fqn = get_file_path_to_helper_functions_dict(
|
||||
{function_to_optimize.file_path: {function_to_optimize.qualified_name}}, project_root_path
|
||||
)
|
||||
helpers_of_helpers, helpers_of_helpers_fqn = get_file_path_to_helper_functions_dict(
|
||||
helpers_of_fto, project_root_path
|
||||
)
|
||||
|
||||
for name in names:
|
||||
try:
|
||||
definitions: list[Name] = name.goto(follow_imports=True, follow_builtin_imports=False)
|
||||
except Exception as e:
|
||||
try:
|
||||
logger.exception(f"Error while getting definition for {name.full_name}: {e}")
|
||||
except Exception as e:
|
||||
# name.full_name can also throw exceptions sometimes
|
||||
logger.exception(f"Error while getting definition: {e}")
|
||||
definitions = []
|
||||
if definitions:
|
||||
# TODO: there can be multiple definitions, see how to handle such cases
|
||||
definition = definitions[0]
|
||||
definition_path = definition.module_path
|
||||
# Add function to optimize
|
||||
helpers_of_fto[function_to_optimize.file_path].add(function_to_optimize.qualified_name)
|
||||
helpers_of_fto_fqn[function_to_optimize.file_path].add(
|
||||
function_to_optimize.qualified_name_with_modules_from_root(project_root_path)
|
||||
)
|
||||
|
||||
# The definition is part of this project and not defined within the original function
|
||||
if (
|
||||
str(definition_path).startswith(str(project_root_path) + os.sep)
|
||||
and not path_belongs_to_site_packages(definition_path)
|
||||
and definition.full_name
|
||||
and not belongs_to_function(definition, function_name)
|
||||
and definition.module_name != definition.full_name
|
||||
):
|
||||
file_path_to_qualified_function_names[definition_path].add(
|
||||
get_qualified_name(definition.module_name, definition.full_name)
|
||||
)
|
||||
for file_path, qualified_function_names in file_path_to_qualified_function_names.items():
|
||||
try:
|
||||
og_code_containing_helpers = file_path.read_text("utf8")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error while parsing {file_path}: {e}")
|
||||
continue
|
||||
try:
|
||||
read_writable_code = get_read_writable_code(og_code_containing_helpers, qualified_function_names)
|
||||
except ValueError as e:
|
||||
logger.debug(f"Error while getting read-writable code: {e}")
|
||||
continue
|
||||
|
||||
if read_writable_code:
|
||||
final_read_writable_code += f"\n{read_writable_code}"
|
||||
final_read_writable_code = add_needed_imports_from_module(
|
||||
src_module_code=og_code_containing_helpers,
|
||||
dst_module_code=final_read_writable_code,
|
||||
src_path=file_path,
|
||||
dst_path=file_path,
|
||||
project_root=project_root_path,
|
||||
helper_functions_fqn=qualified_function_names,
|
||||
)
|
||||
|
||||
try:
|
||||
read_only_code = get_read_only_code(og_code_containing_helpers, qualified_function_names)
|
||||
except ValueError as e:
|
||||
logger.debug(f"Error while getting read-only code: {e}")
|
||||
continue
|
||||
|
||||
read_only_code_with_imports = CodeString(
|
||||
code=add_needed_imports_from_module(
|
||||
src_module_code=og_code_containing_helpers,
|
||||
dst_module_code=read_only_code,
|
||||
src_path=file_path,
|
||||
dst_path=file_path,
|
||||
project_root=project_root_path,
|
||||
helper_functions_fqn=qualified_function_names,
|
||||
),
|
||||
file_path=Path(file_path),
|
||||
)
|
||||
if read_only_code_with_imports.code:
|
||||
read_only_code_markdown.code_strings.append(read_only_code_with_imports)
|
||||
# Extract code
|
||||
final_read_writable_code = get_all_read_writable_code(helpers_of_fto, helpers_of_fto_fqn, project_root_path)
|
||||
read_only_code_markdown = get_all_read_only_code_context(
|
||||
helpers_of_fto,
|
||||
helpers_of_fto_fqn,
|
||||
helpers_of_helpers,
|
||||
helpers_of_helpers_fqn,
|
||||
project_root_path,
|
||||
remove_docstrings=False,
|
||||
)
|
||||
|
||||
# Handle token limits
|
||||
tokenizer = tiktoken.encoding_for_model("gpt-4o")
|
||||
|
|
@ -121,12 +58,85 @@ def get_code_optimization_context(
|
|||
return CodeString(code=final_read_writable_code).code, read_only_code_markdown.markdown
|
||||
logger.debug("Code context has exceeded token limit, removing docstrings from read-only code")
|
||||
|
||||
# Get read-only code context again, this time without docstrings
|
||||
# Extract read only code without docstrings
|
||||
read_only_code_no_docstring_markdown = get_all_read_only_code_context(
|
||||
helpers_of_fto,
|
||||
helpers_of_fto_fqn,
|
||||
helpers_of_helpers,
|
||||
helpers_of_helpers_fqn,
|
||||
project_root_path,
|
||||
remove_docstrings=True,
|
||||
)
|
||||
read_only_code_no_docstring_markdown_tokens = len(tokenizer.encode(read_only_code_no_docstring_markdown.markdown))
|
||||
total_tokens = final_read_writable_tokens + read_only_code_no_docstring_markdown_tokens
|
||||
if total_tokens <= token_limit:
|
||||
return CodeString(code=final_read_writable_code).code, read_only_code_no_docstring_markdown.markdown
|
||||
|
||||
logger.debug("Code context has exceeded token limit, removing read-only code")
|
||||
return CodeString(code=final_read_writable_code).code, ""
|
||||
|
||||
|
||||
def get_all_read_writable_code(
|
||||
helpers_of_fto: dict[Path, set[str]], helpers_of_fto_fqn: dict[Path, set[str]], project_root_path: Path
|
||||
) -> str:
|
||||
final_read_writable_code = ""
|
||||
# Extract code from file paths that contain fto and first degree helpers
|
||||
for file_path, qualified_function_names in helpers_of_fto.items():
|
||||
try:
|
||||
original_code = file_path.read_text("utf8")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error while parsing {file_path}: {e}")
|
||||
continue
|
||||
try:
|
||||
read_writable_code = get_read_writable_code(original_code, qualified_function_names)
|
||||
except ValueError as e:
|
||||
logger.debug(f"Error while getting read-writable code: {e}")
|
||||
continue
|
||||
|
||||
if read_writable_code:
|
||||
final_read_writable_code += f"\n{read_writable_code}"
|
||||
final_read_writable_code = add_needed_imports_from_module(
|
||||
src_module_code=original_code,
|
||||
dst_module_code=final_read_writable_code,
|
||||
src_path=file_path,
|
||||
dst_path=file_path,
|
||||
project_root=project_root_path,
|
||||
helper_functions_fqn=helpers_of_fto_fqn[file_path],
|
||||
)
|
||||
return final_read_writable_code
|
||||
|
||||
|
||||
def get_all_read_only_code_context(
|
||||
helpers_of_fto: dict[Path, set[str]],
|
||||
helpers_of_fto_fqn: dict[Path, set[str]],
|
||||
helpers_of_helpers: dict[Path, set[str]],
|
||||
helpers_of_helpers_fqn: dict[Path, set[str]],
|
||||
project_root_path: Path,
|
||||
remove_docstrings: bool = False,
|
||||
) -> CodeStringsMarkdown:
|
||||
# Rearrange to remove overlaps, so we only access each file path once
|
||||
helpers_of_helpers_no_overlap = defaultdict(set)
|
||||
helpers_of_helpers_no_overlap_fqn = defaultdict(set)
|
||||
for file_path in helpers_of_helpers:
|
||||
if file_path in helpers_of_fto:
|
||||
# Remove duplicates, in case a helper of helper is also a helper of fto
|
||||
helpers_of_helpers[file_path] -= helpers_of_fto[file_path]
|
||||
helpers_of_helpers_fqn[file_path] -= helpers_of_fto_fqn[file_path]
|
||||
else:
|
||||
helpers_of_helpers_no_overlap[file_path] = helpers_of_helpers[file_path]
|
||||
helpers_of_helpers_no_overlap_fqn[file_path] = helpers_of_helpers_fqn[file_path]
|
||||
|
||||
read_only_code_markdown = CodeStringsMarkdown()
|
||||
for file_path, qualified_function_names in file_path_to_qualified_function_names.items():
|
||||
# Extract code from file paths that contain fto and first degree helpers. helpers of helpers may also be included if they are in the same files
|
||||
for file_path, qualified_function_names in helpers_of_fto.items():
|
||||
try:
|
||||
original_code = file_path.read_text("utf8")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error while parsing {file_path}: {e}")
|
||||
continue
|
||||
try:
|
||||
read_only_code = get_read_only_code(
|
||||
og_code_containing_helpers, qualified_function_names, remove_docstrings=True
|
||||
original_code, qualified_function_names, helpers_of_helpers.get(file_path, set()), remove_docstrings
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.debug(f"Error while getting read-only code: {e}")
|
||||
|
|
@ -134,24 +144,93 @@ def get_code_optimization_context(
|
|||
|
||||
read_only_code_with_imports = CodeString(
|
||||
code=add_needed_imports_from_module(
|
||||
src_module_code=og_code_containing_helpers,
|
||||
src_module_code=original_code,
|
||||
dst_module_code=read_only_code,
|
||||
src_path=file_path,
|
||||
dst_path=file_path,
|
||||
project_root=project_root_path,
|
||||
helper_functions_fqn=qualified_function_names,
|
||||
helper_functions_fqn=helpers_of_fto_fqn[file_path] | helpers_of_helpers_fqn[file_path],
|
||||
),
|
||||
file_path=Path(file_path),
|
||||
)
|
||||
if read_only_code_with_imports.code:
|
||||
read_only_code_markdown.code_strings.append(read_only_code_with_imports)
|
||||
read_only_code_markdown_tokens = len(tokenizer.encode(read_only_code_markdown.markdown))
|
||||
total_tokens = final_read_writable_tokens + read_only_code_markdown_tokens
|
||||
if total_tokens <= token_limit:
|
||||
return CodeString(code=final_read_writable_code).code, read_only_code_markdown.markdown
|
||||
if read_only_code_with_imports.code:
|
||||
read_only_code_markdown.code_strings.append(read_only_code_with_imports)
|
||||
|
||||
logger.debug("Code context has exceeded token limit, removing read-only code")
|
||||
return CodeString(code=final_read_writable_code).code, ""
|
||||
# Extract code from file paths containing helpers of helpers
|
||||
for file_path, qualified_helper_function_names in helpers_of_helpers_no_overlap.items():
|
||||
try:
|
||||
original_code = file_path.read_text("utf8")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error while parsing {file_path}: {e}")
|
||||
continue
|
||||
try:
|
||||
read_only_code = get_read_only_code(
|
||||
original_code, set(), qualified_helper_function_names, remove_docstrings
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.debug(f"Error while getting read-only code: {e}")
|
||||
continue
|
||||
|
||||
read_only_code_with_imports = CodeString(
|
||||
code=add_needed_imports_from_module(
|
||||
src_module_code=original_code,
|
||||
dst_module_code=read_only_code,
|
||||
src_path=file_path,
|
||||
dst_path=file_path,
|
||||
project_root=project_root_path,
|
||||
helper_functions_fqn=helpers_of_helpers_no_overlap_fqn[file_path],
|
||||
),
|
||||
file_path=Path(file_path),
|
||||
)
|
||||
if read_only_code_with_imports.code:
|
||||
read_only_code_markdown.code_strings.append(read_only_code_with_imports)
|
||||
return read_only_code_markdown
|
||||
|
||||
|
||||
def get_file_path_to_helper_functions_dict(
|
||||
file_path_to_qualified_function_names: dict[Path, set[str]], project_root_path: Path
|
||||
) -> tuple[dict[Path, set[str]], dict[Path, set[str]]]:
|
||||
file_path_to_helper_function_qualified_names = defaultdict(set)
|
||||
file_path_to_helper_function_fqn = defaultdict(set)
|
||||
for file_path in file_path_to_qualified_function_names:
|
||||
script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path))
|
||||
file_refs = script.get_names(all_scopes=True, definitions=False, references=True)
|
||||
|
||||
for qualified_function_name in file_path_to_qualified_function_names[file_path]:
|
||||
names = [
|
||||
ref
|
||||
for ref in file_refs
|
||||
if ref.full_name and belongs_to_function_qualified(ref, qualified_function_name)
|
||||
]
|
||||
for name in names:
|
||||
try:
|
||||
definitions: list[Name] = name.goto(follow_imports=True, follow_builtin_imports=False)
|
||||
except Exception as e:
|
||||
try:
|
||||
logger.exception(f"Error while getting definition for {name.full_name}: {e}")
|
||||
except Exception as e:
|
||||
# name.full_name can also throw exceptions sometimes
|
||||
logger.exception(f"Error while getting definition: {e}")
|
||||
definitions = []
|
||||
if definitions:
|
||||
# TODO: there can be multiple definitions, see how to handle such cases
|
||||
definition = definitions[0]
|
||||
definition_path = definition.module_path
|
||||
|
||||
# The definition is part of this project and not defined within the original function
|
||||
if (
|
||||
str(definition_path).startswith(str(project_root_path) + os.sep)
|
||||
and not path_belongs_to_site_packages(definition_path)
|
||||
and definition.full_name
|
||||
and definition.type == "function"
|
||||
and not belongs_to_function_qualified(definition, qualified_function_name)
|
||||
):
|
||||
file_path_to_helper_function_qualified_names[definition_path].add(
|
||||
get_qualified_name(definition.module_name, definition.full_name)
|
||||
)
|
||||
file_path_to_helper_function_fqn[definition_path].add(definition.full_name)
|
||||
|
||||
return file_path_to_helper_function_qualified_names, file_path_to_helper_function_fqn
|
||||
|
||||
|
||||
def is_dunder_method(name: str) -> bool:
|
||||
|
|
@ -166,7 +245,6 @@ def get_section_names(node: cst.CSTNode) -> list[str]:
|
|||
|
||||
def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode:
|
||||
"""Removes the docstring from an indented block if it exists"""
|
||||
print(indented_block)
|
||||
if not isinstance(indented_block.body[0], cst.SimpleStatementLine):
|
||||
return indented_block
|
||||
first_stmt = indented_block.body[0].body[0]
|
||||
|
|
@ -268,7 +346,11 @@ def get_read_writable_code(code: str, target_functions: set[str]) -> str:
|
|||
|
||||
|
||||
def prune_cst_for_read_only_code(
|
||||
node: cst.CSTNode, target_functions: set[str], prefix: str = "", remove_docstrings: bool = False
|
||||
node: cst.CSTNode,
|
||||
target_functions: set[str],
|
||||
helpers_of_helper_functions: set[str],
|
||||
prefix: str = "",
|
||||
remove_docstrings: bool = False,
|
||||
) -> tuple[cst.CSTNode | None, bool]:
|
||||
"""Recursively filter the node for read-only context:
|
||||
|
||||
|
|
@ -284,6 +366,8 @@ def prune_cst_for_read_only_code(
|
|||
if isinstance(node, cst.FunctionDef):
|
||||
qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
|
||||
# If it's a target function, remove it but mark found_target = True
|
||||
if qualified_name in helpers_of_helper_functions:
|
||||
return node, True
|
||||
if qualified_name in target_functions:
|
||||
return None, True
|
||||
# Keep only dunder methods
|
||||
|
|
@ -309,15 +393,9 @@ def prune_cst_for_read_only_code(
|
|||
new_class_body: list[CSTNode] = []
|
||||
for stmt in node.body.body:
|
||||
filtered, found_target = prune_cst_for_read_only_code(
|
||||
stmt, target_functions, class_prefix, remove_docstrings=remove_docstrings
|
||||
stmt, target_functions, helpers_of_helper_functions, class_prefix, remove_docstrings=remove_docstrings
|
||||
)
|
||||
found_in_class |= found_target
|
||||
|
||||
if isinstance(filtered, cst.FunctionDef):
|
||||
# Check if it's a target or non-dunder method
|
||||
qname = f"{class_prefix}.{filtered.name.value}"
|
||||
if qname in target_functions or not is_dunder_method(filtered.name.value):
|
||||
continue
|
||||
if filtered:
|
||||
new_class_body.append(filtered)
|
||||
|
||||
|
|
@ -345,7 +423,7 @@ def prune_cst_for_read_only_code(
|
|||
section_found_target = False
|
||||
for child in original_content:
|
||||
filtered, found_target = prune_cst_for_read_only_code(
|
||||
child, target_functions, prefix, remove_docstrings=remove_docstrings
|
||||
child, target_functions, helpers_of_helper_functions, prefix, remove_docstrings=remove_docstrings
|
||||
)
|
||||
if filtered:
|
||||
new_children.append(filtered)
|
||||
|
|
@ -356,25 +434,30 @@ def prune_cst_for_read_only_code(
|
|||
updates[section] = new_children
|
||||
elif original_content is not None:
|
||||
filtered, found_target = prune_cst_for_read_only_code(
|
||||
original_content, target_functions, prefix, remove_docstrings=remove_docstrings
|
||||
original_content,
|
||||
target_functions,
|
||||
helpers_of_helper_functions,
|
||||
prefix,
|
||||
remove_docstrings=remove_docstrings,
|
||||
)
|
||||
found_any_target |= found_target
|
||||
if filtered:
|
||||
updates[section] = filtered
|
||||
|
||||
if updates:
|
||||
return (node.with_changes(**updates), found_any_target)
|
||||
|
||||
return node, found_any_target
|
||||
return None, False
|
||||
|
||||
|
||||
def get_read_only_code(code: str, target_functions: set[str], remove_docstrings: bool = False) -> str:
|
||||
def get_read_only_code(
|
||||
code: str, target_functions: set[str], helpers_of_helper_functions: set[str], remove_docstrings: bool = False
|
||||
) -> str:
|
||||
"""Creates a read-only version of the code by parsing and filtering the code to keep only
|
||||
class contextual information, and other module scoped variables.
|
||||
"""
|
||||
module = cst.parse_module(code)
|
||||
filtered_node, found_target = prune_cst_for_read_only_code(
|
||||
module, target_functions, remove_docstrings=remove_docstrings
|
||||
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
|
||||
)
|
||||
if not found_target:
|
||||
raise ValueError("No target functions found in the provided code")
|
||||
|
|
|
|||
|
|
@ -12,7 +12,11 @@ from jedi.api.classes import Name
|
|||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_extractor import get_code
|
||||
from codeflash.code_utils.code_utils import module_name_from_file_path, path_belongs_to_site_packages
|
||||
from codeflash.code_utils.code_utils import (
|
||||
get_qualified_name,
|
||||
module_name_from_file_path,
|
||||
path_belongs_to_site_packages,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import FunctionParent, FunctionSource
|
||||
|
||||
|
|
@ -26,7 +30,7 @@ def belongs_to_method(name: Name, class_name: str, method_name: str) -> bool:
|
|||
|
||||
|
||||
def belongs_to_function(name: Name, function_name: str) -> bool:
|
||||
"""Check if the given jedi Name is a direct child of the specified function"""
|
||||
"""Check if the given jedi Name is a direct child of the specified function."""
|
||||
if name.name == function_name: # Handles function definition and recursive function calls
|
||||
return False
|
||||
if name := name.parent():
|
||||
|
|
@ -36,13 +40,28 @@ def belongs_to_function(name: Name, function_name: str) -> bool:
|
|||
|
||||
|
||||
def belongs_to_class(name: Name, class_name: str) -> bool:
|
||||
"""Check if given jedi Name is a direct child of the specified class"""
|
||||
"""Check if given jedi Name is a direct child of the specified class."""
|
||||
while name := name.parent():
|
||||
if name.type == "class":
|
||||
return name.name == class_name
|
||||
return False
|
||||
|
||||
|
||||
def belongs_to_function_qualified(name: Name, qualified_function_name: str) -> bool:
|
||||
"""Check if the given jedi Name is a direct child of the specified function, matched by qualified function name."""
|
||||
try:
|
||||
if get_qualified_name(name.module_name, name.full_name) == qualified_function_name:
|
||||
# Handles function definition and recursive function calls
|
||||
return False
|
||||
if name := name.parent():
|
||||
if name.type == "function":
|
||||
return get_qualified_name(name.module_name, name.full_name) == qualified_function_name
|
||||
return False
|
||||
except ValueError as e:
|
||||
logger.exception(f"Error while checking if {name.full_name} belongs to {qualified_function_name}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_type_annotation_context(
|
||||
function: FunctionToOptimize, jedi_script: jedi.Script, project_root_path: Path
|
||||
) -> tuple[list[FunctionSource], set[tuple[str, str]]]:
|
||||
|
|
|
|||
|
|
@ -24,6 +24,10 @@ class HelperClass:
|
|||
return self.name
|
||||
|
||||
|
||||
def main_method():
|
||||
return "hello"
|
||||
|
||||
|
||||
class MainClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
|
@ -67,7 +71,7 @@ def test_code_replacement10() -> None:
|
|||
func_top_optimize = FunctionToOptimize(
|
||||
function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")]
|
||||
)
|
||||
original_code = file_path.read_text()
|
||||
|
||||
read_write_context, read_only_context = get_code_optimization_context(
|
||||
function_to_optimize=func_top_optimize, project_root_path=file_path.parent
|
||||
)
|
||||
|
|
@ -90,7 +94,6 @@ def test_code_replacement10() -> None:
|
|||
```python:{file_path}
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class HelperClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
|
@ -151,7 +154,6 @@ class Graph:
|
|||
from __future__ import annotations
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
class Graph:
|
||||
def __init__(self, vertices):
|
||||
self.graph = defaultdict(list)
|
||||
|
|
@ -184,14 +186,8 @@ def test_bubble_sort_helper() -> None:
|
|||
)
|
||||
|
||||
expected_read_write_context = """
|
||||
from bubble_sort_with_math import sorter
|
||||
import math
|
||||
|
||||
def sort_from_another_file(arr):
|
||||
sorted_arr = sorter(arr)
|
||||
return sorted_arr
|
||||
|
||||
|
||||
from bubble_sort_with_math import sorter
|
||||
|
||||
def sorter(arr):
|
||||
arr.sort()
|
||||
|
|
@ -199,6 +195,12 @@ def sorter(arr):
|
|||
print(x)
|
||||
return arr
|
||||
|
||||
|
||||
|
||||
def sort_from_another_file(arr):
|
||||
sorted_arr = sorter(arr)
|
||||
return sorted_arr
|
||||
|
||||
"""
|
||||
expected_read_only_context = ""
|
||||
|
||||
|
|
@ -206,84 +208,6 @@ def sorter(arr):
|
|||
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
|
||||
|
||||
|
||||
def test_repo_helper() -> None:
|
||||
path_to_file = (
|
||||
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" / "main.py"
|
||||
)
|
||||
path_to_utils = (
|
||||
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" / "utils.py"
|
||||
)
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="fetch_and_process_data",
|
||||
file_path=str(path_to_file),
|
||||
parents=[],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
read_write_context, read_only_context = get_code_optimization_context(
|
||||
function_to_optimize, Path(__file__).resolve().parent.parent
|
||||
)
|
||||
expected_read_write_context = """
|
||||
import requests
|
||||
from globals import API_URL
|
||||
from utils import DataProcessor
|
||||
|
||||
def fetch_and_process_data():
|
||||
# Use the global variable for the request
|
||||
response = requests.get(API_URL)
|
||||
response.raise_for_status()
|
||||
|
||||
raw_data = response.text
|
||||
|
||||
# Use code from another file (utils.py)
|
||||
processor = DataProcessor()
|
||||
processed = processor.process_data(raw_data)
|
||||
processed = processor.add_prefix(processed)
|
||||
|
||||
return processed
|
||||
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
|
||||
def process_data(self, raw_data: str) -> str:
|
||||
\"\"\"Process raw data by converting it to uppercase.\"\"\"
|
||||
return raw_data.upper()
|
||||
|
||||
def add_prefix(self, data: str, prefix: str = "PREFIX_") -> str:
|
||||
\"\"\"Add a prefix to the processed data.\"\"\"
|
||||
return prefix + data
|
||||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:{path_to_file}
|
||||
if __name__ == "__main__":
|
||||
result = fetch_and_process_data()
|
||||
print("Processed data:", result)
|
||||
```
|
||||
```python:{path_to_utils}
|
||||
import math
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
\"\"\"A class for processing data.\"\"\"
|
||||
|
||||
number = 1
|
||||
|
||||
def __init__(self, default_prefix: str = "PREFIX_"):
|
||||
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
|
||||
self.default_prefix = default_prefix
|
||||
self.number += math.log(self.number)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
||||
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
||||
```
|
||||
"""
|
||||
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
|
||||
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
|
||||
|
||||
|
||||
def test_flavio_typed_code_helper() -> None:
|
||||
code = '''
|
||||
|
||||
|
|
@ -569,7 +493,27 @@ class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
|
|||
"""Interface for cache backends used by the persistent cache decorator."""
|
||||
|
||||
def __init__(self) -> None: ...
|
||||
|
||||
def hash_key(
|
||||
self,
|
||||
*,
|
||||
func: Callable[_P, Any],
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
) -> tuple[str, _KEY_T]: ...
|
||||
|
||||
def encode(self, *, data: Any) -> _STORE_T: # noqa: ANN401
|
||||
...
|
||||
|
||||
def decode(self, *, data: _STORE_T) -> Any: # noqa: ANN401
|
||||
...
|
||||
|
||||
def get(self, *, key: tuple[str, _KEY_T]) -> tuple[datetime.datetime, _STORE_T] | None: ...
|
||||
|
||||
def delete(self, *, key: tuple[str, _KEY_T]) -> None: ...
|
||||
|
||||
def put(self, *, key: tuple[str, _KEY_T], data: _STORE_T) -> None: ...
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
_CacheBackendT = TypeVar("_CacheBackendT", bound=CacheBackend)
|
||||
|
|
@ -885,3 +829,462 @@ class HelperClass:
|
|||
read_write_context, read_only_context = get_code_optimization_context(
|
||||
function_to_optimize, opt.args.project_root
|
||||
)
|
||||
|
||||
|
||||
def test_repo_helper() -> None:
|
||||
path_to_file = (
|
||||
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" / "main.py"
|
||||
)
|
||||
path_to_utils = (
|
||||
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" / "utils.py"
|
||||
)
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="fetch_and_process_data",
|
||||
file_path=str(path_to_file),
|
||||
parents=[],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
read_write_context, read_only_context = get_code_optimization_context(
|
||||
function_to_optimize,
|
||||
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever",
|
||||
)
|
||||
expected_read_write_context = """
|
||||
import requests
|
||||
from globals import API_URL
|
||||
from utils import DataProcessor
|
||||
|
||||
class DataProcessor:
|
||||
|
||||
def process_data(self, raw_data: str) -> str:
|
||||
\"\"\"Process raw data by converting it to uppercase.\"\"\"
|
||||
return raw_data.upper()
|
||||
|
||||
def add_prefix(self, data: str, prefix: str = "PREFIX_") -> str:
|
||||
\"\"\"Add a prefix to the processed data.\"\"\"
|
||||
return prefix + data
|
||||
|
||||
|
||||
|
||||
def fetch_and_process_data():
|
||||
# Use the global variable for the request
|
||||
response = requests.get(API_URL)
|
||||
response.raise_for_status()
|
||||
|
||||
raw_data = response.text
|
||||
|
||||
# Use code from another file (utils.py)
|
||||
processor = DataProcessor()
|
||||
processed = processor.process_data(raw_data)
|
||||
processed = processor.add_prefix(processed)
|
||||
|
||||
return processed
|
||||
|
||||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:{path_to_utils}
|
||||
import math
|
||||
|
||||
GLOBAL_VAR = 10
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
\"\"\"A class for processing data.\"\"\"
|
||||
|
||||
number = 1
|
||||
|
||||
def __init__(self, default_prefix: str = "PREFIX_"):
|
||||
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
|
||||
self.default_prefix = default_prefix
|
||||
self.number += math.log(self.number)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
||||
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
||||
```
|
||||
```python:{path_to_file}
|
||||
if __name__ == "__main__":
|
||||
result = fetch_and_process_data()
|
||||
print("Processed data:", result)
|
||||
```
|
||||
"""
|
||||
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
|
||||
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
|
||||
|
||||
|
||||
def test_repo_helper_of_helper() -> None:
|
||||
path_to_file = (
|
||||
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" / "main.py"
|
||||
)
|
||||
path_to_utils = (
|
||||
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" / "utils.py"
|
||||
)
|
||||
path_to_transform_utils = (
|
||||
Path(__file__).resolve().parent.parent
|
||||
/ "code_to_optimize"
|
||||
/ "code_directories"
|
||||
/ "retriever"
|
||||
/ "transform_utils.py"
|
||||
)
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="fetch_and_transform_data",
|
||||
file_path=str(path_to_file),
|
||||
parents=[],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
read_write_context, read_only_context = get_code_optimization_context(
|
||||
function_to_optimize,
|
||||
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever",
|
||||
)
|
||||
expected_read_write_context = """
|
||||
from transform_utils import DataTransformer
|
||||
import requests
|
||||
from globals import API_URL
|
||||
from utils import DataProcessor
|
||||
|
||||
class DataProcessor:
|
||||
|
||||
def process_data(self, raw_data: str) -> str:
|
||||
\"\"\"Process raw data by converting it to uppercase.\"\"\"
|
||||
return raw_data.upper()
|
||||
|
||||
def transform_data(self, data: str) -> str:
|
||||
\"\"\"Transform the processed data\"\"\"
|
||||
return DataTransformer().transform(data)
|
||||
|
||||
|
||||
|
||||
def fetch_and_transform_data():
|
||||
# Use the global variable for the request
|
||||
response = requests.get(API_URL)
|
||||
|
||||
raw_data = response.text
|
||||
|
||||
# Use code from another file (utils.py)
|
||||
processor = DataProcessor()
|
||||
processed = processor.process_data(raw_data)
|
||||
transformed = processor.transform_data(processed)
|
||||
|
||||
return transformed
|
||||
|
||||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:{path_to_utils}
|
||||
import math
|
||||
|
||||
GLOBAL_VAR = 10
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
\"\"\"A class for processing data.\"\"\"
|
||||
|
||||
number = 1
|
||||
|
||||
def __init__(self, default_prefix: str = "PREFIX_"):
|
||||
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
|
||||
self.default_prefix = default_prefix
|
||||
self.number += math.log(self.number)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
||||
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
||||
```
|
||||
```python:{path_to_file}
|
||||
if __name__ == "__main__":
|
||||
result = fetch_and_process_data()
|
||||
print("Processed data:", result)
|
||||
```
|
||||
```python:{path_to_transform_utils}
|
||||
class DataTransformer:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
|
||||
def transform(self, data):
|
||||
self.data = data
|
||||
return self.data
|
||||
```
|
||||
"""
|
||||
|
||||
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
|
||||
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
|
||||
|
||||
|
||||
def test_repo_helper_of_helper_same_class() -> None:
|
||||
path_to_utils = (
|
||||
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" / "utils.py"
|
||||
)
|
||||
path_to_transform_utils = (
|
||||
Path(__file__).resolve().parent.parent
|
||||
/ "code_to_optimize"
|
||||
/ "code_directories"
|
||||
/ "retriever"
|
||||
/ "transform_utils.py"
|
||||
)
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="transform_data_own_method",
|
||||
file_path=str(path_to_utils),
|
||||
parents=[FunctionParent(name="DataProcessor", type="ClassDef")],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
read_write_context, read_only_context = get_code_optimization_context(
|
||||
function_to_optimize,
|
||||
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever",
|
||||
)
|
||||
expected_read_write_context = """
|
||||
from transform_utils import DataTransformer
|
||||
|
||||
class DataTransformer:
|
||||
|
||||
def transform_using_own_method(self, data):
|
||||
return self.transform(data)
|
||||
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
|
||||
def transform_data_own_method(self, data: str) -> str:
|
||||
\"\"\"Transform the processed data using own method\"\"\"
|
||||
return DataTransformer().transform_using_own_method(data)
|
||||
|
||||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:{path_to_transform_utils}
|
||||
class DataTransformer:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
|
||||
def transform(self, data):
|
||||
self.data = data
|
||||
return self.data
|
||||
```
|
||||
```python:{path_to_utils}
|
||||
import math
|
||||
|
||||
GLOBAL_VAR = 10
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
\"\"\"A class for processing data.\"\"\"
|
||||
|
||||
number = 1
|
||||
|
||||
def __init__(self, default_prefix: str = "PREFIX_"):
|
||||
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
|
||||
self.default_prefix = default_prefix
|
||||
self.number += math.log(self.number)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
||||
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
|
||||
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
|
||||
|
||||
|
||||
def test_repo_helper_of_helper_same_file() -> None:
|
||||
path_to_utils = (
|
||||
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" / "utils.py"
|
||||
)
|
||||
path_to_transform_utils = (
|
||||
Path(__file__).resolve().parent.parent
|
||||
/ "code_to_optimize"
|
||||
/ "code_directories"
|
||||
/ "retriever"
|
||||
/ "transform_utils.py"
|
||||
)
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="transform_data_same_file_function",
|
||||
file_path=str(path_to_utils),
|
||||
parents=[FunctionParent(name="DataProcessor", type="ClassDef")],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
read_write_context, read_only_context = get_code_optimization_context(
|
||||
function_to_optimize,
|
||||
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever",
|
||||
)
|
||||
expected_read_write_context = """
|
||||
from transform_utils import DataTransformer
|
||||
|
||||
class DataTransformer:
|
||||
|
||||
def transform_using_same_file_function(self, data):
|
||||
return update_data(data)
|
||||
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
|
||||
def transform_data_same_file_function(self, data: str) -> str:
|
||||
\"\"\"Transform the processed data using a function from the same file\"\"\"
|
||||
return DataTransformer().transform_using_same_file_function(data)
|
||||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:{path_to_transform_utils}
|
||||
class DataTransformer:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
|
||||
|
||||
def update_data(data):
|
||||
return data + " updated"
|
||||
```
|
||||
```python:{path_to_utils}
|
||||
import math
|
||||
|
||||
GLOBAL_VAR = 10
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
\"\"\"A class for processing data.\"\"\"
|
||||
|
||||
number = 1
|
||||
|
||||
def __init__(self, default_prefix: str = "PREFIX_"):
|
||||
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
|
||||
self.default_prefix = default_prefix
|
||||
self.number += math.log(self.number)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
||||
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
||||
```
|
||||
"""
|
||||
|
||||
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
|
||||
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
|
||||
|
||||
|
||||
def test_repo_helper_all_same_file() -> None:
|
||||
path_to_transform_utils = (
|
||||
Path(__file__).resolve().parent.parent
|
||||
/ "code_to_optimize"
|
||||
/ "code_directories"
|
||||
/ "retriever"
|
||||
/ "transform_utils.py"
|
||||
)
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="transform_data_all_same_file",
|
||||
file_path=str(path_to_transform_utils),
|
||||
parents=[FunctionParent(name="DataTransformer", type="ClassDef")],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
read_write_context, read_only_context = get_code_optimization_context(
|
||||
function_to_optimize,
|
||||
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever",
|
||||
)
|
||||
expected_read_write_context = """
|
||||
class DataTransformer:
|
||||
|
||||
def transform_using_own_method(self, data):
|
||||
return self.transform(data)
|
||||
|
||||
def transform_data_all_same_file(self, data):
|
||||
new_data = update_data(data)
|
||||
return self.transform_using_own_method(new_data)
|
||||
|
||||
|
||||
def update_data(data):
|
||||
return data + " updated"
|
||||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:{path_to_transform_utils}
|
||||
class DataTransformer:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
|
||||
def transform(self, data):
|
||||
self.data = data
|
||||
return self.data
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
|
||||
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
|
||||
|
||||
|
||||
def test_repo_helper_circular_dependency() -> None:
|
||||
path_to_transform_utils = (
|
||||
Path(__file__).resolve().parent.parent
|
||||
/ "code_to_optimize"
|
||||
/ "code_directories"
|
||||
/ "retriever"
|
||||
/ "transform_utils.py"
|
||||
)
|
||||
path_to_utils = (
|
||||
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" / "utils.py"
|
||||
)
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="circular_dependency",
|
||||
file_path=str(path_to_transform_utils),
|
||||
parents=[FunctionParent(name="DataTransformer", type="ClassDef")],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
read_write_context, read_only_context = get_code_optimization_context(
|
||||
function_to_optimize,
|
||||
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever",
|
||||
)
|
||||
expected_read_write_context = """
|
||||
from transform_utils import DataTransformer
|
||||
from code_to_optimize.code_directories.retriever.utils import DataProcessor
|
||||
|
||||
class DataProcessor:
|
||||
|
||||
def circular_dependency(self, data: str) -> str:
|
||||
\"\"\"Test circular dependency\"\"\"
|
||||
return DataTransformer().circular_dependency(data)
|
||||
|
||||
|
||||
|
||||
class DataTransformer:
|
||||
|
||||
def circular_dependency(self, data):
|
||||
return DataProcessor().circular_dependency(data)
|
||||
|
||||
|
||||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:{path_to_utils}
|
||||
import math
|
||||
|
||||
GLOBAL_VAR = 10
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
\"\"\"A class for processing data.\"\"\"
|
||||
|
||||
number = 1
|
||||
|
||||
def __init__(self, default_prefix: str = "PREFIX_"):
|
||||
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
|
||||
self.default_prefix = default_prefix
|
||||
self.number += math.log(self.number)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
||||
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
||||
```
|
||||
```python:{path_to_transform_utils}
|
||||
class DataTransformer:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
|
||||
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ def test_basic_class() -> None:
|
|||
class_var = "value"
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"})
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -47,7 +47,7 @@ def test_dunder_methods() -> None:
|
|||
return f"Value: {self.x}"
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"})
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -75,7 +75,7 @@ def test_dunder_methods_remove_docstring() -> None:
|
|||
return f"Value: {self.x}"
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, remove_docstrings=True)
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set(), remove_docstrings=True)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -102,7 +102,7 @@ def test_class_remove_docstring() -> None:
|
|||
return f"Value: {self.x}"
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, remove_docstrings=True)
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set(), remove_docstrings=True)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -131,7 +131,7 @@ def test_mixed_remove_docstring() -> None:
|
|||
return f"Value: {self.x}"
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, remove_docstrings=True)
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set(), remove_docstrings=True)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -149,7 +149,7 @@ def test_target_in_nested_class() -> None:
|
|||
"""
|
||||
|
||||
with pytest.raises(ValueError, match="No target functions found in the provided code"):
|
||||
get_read_only_code(dedent(code), {"Outer.Inner.target_method"})
|
||||
get_read_only_code(dedent(code), {"Outer.Inner.target_method"}, set())
|
||||
|
||||
|
||||
def test_docstrings() -> None:
|
||||
|
|
@ -171,7 +171,7 @@ def test_docstrings() -> None:
|
|||
\"\"\"Class docstring.\"\"\"
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"})
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -190,7 +190,7 @@ def test_method_signatures() -> None:
|
|||
|
||||
expected = """"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"})
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -214,7 +214,7 @@ def test_multiple_top_level_targets() -> None:
|
|||
self.x = 42
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target1", "TestClass.target2"})
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target1", "TestClass.target2"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -234,7 +234,7 @@ def test_class_annotations() -> None:
|
|||
var2: str
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"})
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -256,7 +256,7 @@ def test_class_annotations_if() -> None:
|
|||
var2: str
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"})
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -282,7 +282,7 @@ def test_class_annotations_try() -> None:
|
|||
continue
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"})
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -318,7 +318,7 @@ def test_class_annotations_else() -> None:
|
|||
var2: str
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"})
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -333,7 +333,7 @@ def test_top_level_functions() -> None:
|
|||
|
||||
expected = """"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"target_function"})
|
||||
output = get_read_only_code(dedent(code), {"target_function"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -352,7 +352,7 @@ def test_module_var() -> None:
|
|||
x = 5
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"target_function"})
|
||||
output = get_read_only_code(dedent(code), {"target_function"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -379,7 +379,7 @@ def test_module_var_if() -> None:
|
|||
z = 10
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"target_function"})
|
||||
output = get_read_only_code(dedent(code), {"target_function"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -414,7 +414,7 @@ def test_conditional_class_definitions() -> None:
|
|||
platform = "other"
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"PlatformClass.target_method"})
|
||||
output = get_read_only_code(dedent(code), {"PlatformClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -473,7 +473,7 @@ def test_multiple_except_clauses() -> None:
|
|||
error_type = "cleanup"
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"})
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -526,7 +526,7 @@ def test_with_statement_and_loops() -> None:
|
|||
context = "cleanup"
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"})
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -575,7 +575,7 @@ def test_async_with_try_except() -> None:
|
|||
status = "cancelled"
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"})
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -685,7 +685,7 @@ def test_simplified_complete_implementation() -> None:
|
|||
self.error = str(e)
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"DataProcessor.target_method", "ResultHandler.target_method"})
|
||||
output = get_read_only_code(dedent(code), {"DataProcessor.target_method", "ResultHandler.target_method"}, set())
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
@ -795,6 +795,6 @@ def test_simplified_complete_implementation_no_docstring() -> None:
|
|||
"""
|
||||
|
||||
output = get_read_only_code(
|
||||
dedent(code), {"DataProcessor.target_method", "ResultHandler.target_method"}, remove_docstrings=True
|
||||
dedent(code), {"DataProcessor.target_method", "ResultHandler.target_method"}, set(), remove_docstrings=True
|
||||
)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
|
|
|||
Loading…
Reference in a new issue