codeflash/tests/test_add_needed_imports_from_module.py
Kevin Turcios 5cd94cdf64 round 1
2024-10-12 17:29:15 -05:00

135 lines
4.8 KiB
Python

from pathlib import Path
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
def test_add_needed_imports_from_module0() -> None:
    """Check that no import is added when the destination code needs none.

    The destination module holds only ``heyjude``, which references none of
    the names imported by the source module, so the expected result is the
    destination text unchanged.
    """
    # Source module: a set of imports plus several definitions. Only some of
    # the definitions use the imported names (``Name``, ``dataclass``).
    src_module = '''import ast
import logging
import os
from typing import Union
import jedi
import tiktoken
from jedi.api.classes import Name
from pydantic.dataclasses import dataclass
from codeflash.code_utils.code_extractor import get_code, get_code_no_skeleton
from codeflash.code_utils.code_utils import path_belongs_to_site_packages
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
def belongs_to_class(name: Name, class_name: str) -> bool:
    """Check if the given name belongs to the specified class."""
    if name.full_name and name.full_name.startswith(f"{name.module_name}.{class_name}."):
        return True
    return False
def heyjude() -> None:
    print("Hey Jude, don't make it bad")
def belongs_to_function(name: Name, function_name: str) -> bool:
    """Check if the given name belongs to the specified function"""
    if name.full_name and name.full_name.startswith(name.module_name):
        subname: str = name.full_name.replace(name.module_name, "", 1)
    else:
        return False
    # The name is defined inside the function or is the function itself
    return f".{function_name}." in subname or f".{function_name}" == subname
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
class Source:
    full_name: str
    definition: Name
    source_code: str
'''
    # Destination module: a single function that uses only the builtin ``print``.
    dst_module = """def heyjude() -> None:
    print("Hey Jude, don't make it bad")
"""
    # Expected output is identical to the destination module.
    expected = """def heyjude() -> None:
    print("Hey Jude, don't make it bad")
"""
    # Source and destination deliberately point at the same file path.
    src_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py")
    dst_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py")
    project_root = Path("/home/roger/repos/codeflash")
    new_module = add_needed_imports_from_module(
        src_module,
        dst_module,
        src_path,
        dst_path,
        project_root,
    )
    assert new_module == expected
def test_add_needed_imports_from_module() -> None:
    """Check that an import used by the destination code is carried over.

    The destination module holds only ``belongs_to_function``, whose signature
    is annotated with ``Name``. The source module imports ``Name`` from
    ``jedi.api.classes``, and the expected result is the destination text with
    that single import placed in front of it.
    """
    # Source module: a set of imports plus several definitions. The fixture
    # must be syntactically valid Python for its imports to be usable.
    src_module = '''import ast
import logging
import os
from typing import Union
import jedi
import tiktoken
from jedi.api.classes import Name
from pydantic.dataclasses import dataclass
from codeflash.code_utils.code_extractor import get_code, get_code_no_skeleton
from codeflash.code_utils.code_utils import path_belongs_to_site_packages
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
def belongs_to_class(name: Name, class_name: str) -> bool:
    """Check if the given name belongs to the specified class."""
    if name.full_name and name.full_name.startswith(f"{name.module_name}.{class_name}."):
        return True
    return False
def belongs_to_function(name: Name, function_name: str) -> bool:
    """Check if the given name belongs to the specified function"""
    if name.full_name and name.full_name.startswith(name.module_name):
        subname: str = name.full_name.replace(name.module_name, "", 1)
    else:
        return False
    # The name is defined inside the function or is the function itself
    return f".{function_name}." in subname or f".{function_name}" == subname
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
class Source:
    full_name: str
    definition: Name
    source_code: str
'''
    # Destination module: one function that references ``Name`` without importing it.
    dst_module = '''def belongs_to_function(name: Name, function_name: str) -> bool:
    """Check if the given name belongs to the specified function"""
    if name.full_name and name.full_name.startswith(name.module_name):
        subname: str = name.full_name.replace(name.module_name, "", 1)
    else:
        return False
    # The name is defined inside the function or is the function itself
    return f".{function_name}." in subname or f".{function_name}" == subname
'''
    # Expected output: the destination module preceded by the ``Name`` import.
    # NOTE(review): confirm whether the real output separates the added import
    # from the function with a blank line, and adjust this fixture if it does.
    expected = '''from jedi.api.classes import Name
def belongs_to_function(name: Name, function_name: str) -> bool:
    """Check if the given name belongs to the specified function"""
    if name.full_name and name.full_name.startswith(name.module_name):
        subname: str = name.full_name.replace(name.module_name, "", 1)
    else:
        return False
    # The name is defined inside the function or is the function itself
    return f".{function_name}." in subname or f".{function_name}" == subname
'''
    # Source and destination deliberately point at the same file path.
    src_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py")
    dst_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py")
    project_root = Path("/home/roger/repos/codeflash")
    new_module = add_needed_imports_from_module(
        src_module,
        dst_module,
        src_path,
        dst_path,
        project_root,
    )
    assert new_module == expected