# File: codeflash/tests/test_code_extractor_none_aliases_exact.py
# Snapshot: 2026-02-21 01:49:31 +02:00
#
# 331 lines
# 8.3 KiB
# Python

import tempfile
from pathlib import Path
from codeflash.languages.python.static_analysis.code_extractor import add_needed_imports_from_module
def test_add_needed_imports_with_none_aliases():
    """Unused source imports must not be copied into the target module.

    The source mixes a plain ``import``, an aliased ``from`` import and an
    unaliased ``from`` import (alias ``None``); none of them are referenced by
    the target, so the target must come back unchanged.
    """
    source_code = """
import json
from typing import Dict as MyDict, Optional
from collections import defaultdict
"""
    # The function body must be indented: the target has to be valid Python
    # for the import-merging logic to actually run on it.
    target_code = """
def target_function():
    pass
"""
    # No source import is used by the target, so nothing is added.
    expected_output = """
def target_function():
    pass
"""
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        src_path = temp_path / "source.py"
        dst_path = temp_path / "target.py"
        src_path.write_text(source_code)
        dst_path.write_text(target_code)
        result = add_needed_imports_from_module(
            src_module_code=source_code,
            dst_module_code=target_code,
            src_path=src_path,
            dst_path=dst_path,
            project_root=temp_path,
        )
    # Compare modulo leading/trailing whitespace only.
    assert result.strip() == expected_output.strip()
def test_add_needed_imports_complex_aliases():
    """Aliased module and name imports that the target never uses are not added.

    Covers ``import x as y``, several ``from x import a as b`` entries in one
    statement, and unaliased names alongside them.
    """
    source_code = """
import os
import sys as system
from typing import Dict, List as MyList, Optional as Opt
from collections import defaultdict as dd, Counter
from pathlib import Path
"""
    # Indented body keeps the target snippet valid Python.
    target_code = """
def my_function():
    return "test"
"""
    # The target references none of the source names, so it is unchanged.
    expected_output = """
def my_function():
    return "test"
"""
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        src_path = temp_path / "source.py"
        dst_path = temp_path / "target.py"
        src_path.write_text(source_code)
        dst_path.write_text(target_code)
        result = add_needed_imports_from_module(
            src_module_code=source_code,
            dst_module_code=target_code,
            src_path=src_path,
            dst_path=dst_path,
            project_root=temp_path,
        )
    assert result.strip() == expected_output.strip()
def test_add_needed_imports_with_usage():
    """Every source import the target uses is added, including an aliased one.

    The target references ``json``, ``MyDict`` (alias of ``typing.Dict``),
    ``Optional`` and ``defaultdict``, so all of them must appear in the output.
    """
    source_code = """
import json
from typing import Dict as MyDict, Optional
from collections import defaultdict
"""
    # The body must be indented so the target is parseable Python.
    target_code = """
def target_function():
    data = json.loads('{"key": "value"}')
    my_dict: MyDict[str, str] = {}
    opt_value: Optional[str] = None
    dd = defaultdict(list)
    return data, my_dict, opt_value, dd
"""
    # The exact import order and layout are part of what this test pins down.
    expected_output = """import json
from typing import Dict as MyDict, Optional
from collections import defaultdict
def target_function():
    data = json.loads('{"key": "value"}')
    my_dict: MyDict[str, str] = {}
    opt_value: Optional[str] = None
    dd = defaultdict(list)
    return data, my_dict, opt_value, dd
"""
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        src_path = temp_path / "source.py"
        dst_path = temp_path / "target.py"
        src_path.write_text(source_code)
        dst_path.write_text(target_code)
        result = add_needed_imports_from_module(
            src_module_code=source_code,
            dst_module_code=target_code,
            src_path=src_path,
            dst_path=dst_path,
            project_root=temp_path,
        )
    # Assert exact expected output (modulo surrounding whitespace).
    assert result.strip() == expected_output.strip()
def test_litellm_router_style_imports():
    """A litellm-router-like import block adds nothing to an unrelated target.

    The source imports the same name twice, once plainly and once under an
    alias (``ModelInfo`` / ``ModelInfo as ModelMapInfo``); the target uses
    neither.
    """
    source_code = """
import asyncio
import copy
import json
from collections import defaultdict
from typing import Dict, List, Optional, Union
from litellm.types.utils import ModelInfo
from litellm.types.utils import ModelInfo as ModelMapInfo
"""
    # ''' delimiters because the snippet contains a """ docstring; the body
    # lines must be indented for the snippet to be valid Python.
    target_code = '''
def target_function():
    """Target function for testing."""
    pass
'''
    expected_output = '''
def target_function():
    """Target function for testing."""
    pass
'''
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        src_path = temp_path / "complex_source.py"
        dst_path = temp_path / "target.py"
        src_path.write_text(source_code)
        dst_path.write_text(target_code)
        result = add_needed_imports_from_module(
            src_module_code=source_code,
            dst_module_code=target_code,
            src_path=src_path,
            dst_path=dst_path,
            project_root=temp_path,
        )
    assert result.strip() == expected_output.strip()
def test_edge_case_none_values_in_alias_pairs():
    """Mixed aliased and unaliased names in one ``from`` statement are handled.

    Each ``from`` line interleaves names with and without ``as`` (the latter
    carry a ``None`` alias); none are used by the target, which must come back
    unchanged.
    """
    source_code = """
from typing import Dict as MyDict, List, Optional as Opt
from collections import defaultdict, Counter as cnt
from pathlib import Path
"""
    # Indented body keeps the target snippet valid Python.
    target_code = """
def my_test_function():
    return "test"
"""
    expected_output = """
def my_test_function():
    return "test"
"""
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        src_path = temp_path / "edge_case_source.py"
        dst_path = temp_path / "target.py"
        src_path.write_text(source_code)
        dst_path.write_text(target_code)
        result = add_needed_imports_from_module(
            src_module_code=source_code,
            dst_module_code=target_code,
            src_path=src_path,
            dst_path=dst_path,
            project_root=temp_path,
        )
    assert result.strip() == expected_output.strip()
def test_partial_import_usage():
    """Only the subset of source imports the target references is added.

    The target uses ``os``, ``Dict`` and ``Counter``; ``sys``, ``List``,
    ``Optional`` and ``defaultdict`` must be left out.
    """
    source_code = """
import os
import sys
from typing import Dict, List, Optional
from collections import defaultdict, Counter
"""
    # The body must be indented so the target parses.
    target_code = """
def use_some_imports():
    path = os.path.join("a", "b")
    my_dict: Dict[str, int] = {}
    counter = Counter([1, 2, 3])
    return path, my_dict, counter
"""
    expected_output = """import os
from collections import Counter
from typing import Dict
def use_some_imports():
    path = os.path.join("a", "b")
    my_dict: Dict[str, int] = {}
    counter = Counter([1, 2, 3])
    return path, my_dict, counter
"""
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        src_path = temp_path / "source.py"
        dst_path = temp_path / "target.py"
        src_path.write_text(source_code)
        dst_path.write_text(target_code)
        result = add_needed_imports_from_module(
            src_module_code=source_code,
            dst_module_code=target_code,
            src_path=src_path,
            dst_path=dst_path,
            project_root=temp_path,
        )
    assert result.strip() == expected_output.strip()
def test_alias_handling():
    """Used aliased names are re-imported under their alias; unused ones are not.

    The target uses ``MyDict``, ``MyList`` and ``dd``; the unaliased
    ``Optional`` and ``Counter`` are unused and must not appear.
    """
    source_code = """
from typing import Dict as MyDict, List as MyList, Optional
from collections import defaultdict as dd, Counter
"""
    # The body must be indented so the target is valid Python.
    target_code = """
def test_aliases():
    d: MyDict[str, int] = {}
    lst: MyList[str] = []
    dd_instance = dd(list)
    return d, lst, dd_instance
"""
    expected_output = """from collections import defaultdict as dd
from typing import Dict as MyDict, List as MyList
def test_aliases():
    d: MyDict[str, int] = {}
    lst: MyList[str] = []
    dd_instance = dd(list)
    return d, lst, dd_instance
"""
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        src_path = temp_path / "source.py"
        dst_path = temp_path / "target.py"
        src_path.write_text(source_code)
        dst_path.write_text(target_code)
        result = add_needed_imports_from_module(
            src_module_code=source_code,
            dst_module_code=target_code,
            src_path=src_path,
            dst_path=dst_path,
            project_root=temp_path,
        )
    assert result.strip() == expected_output.strip()
def test_add_needed_imports_with_nonealiases():
    """Mixing aliased and unaliased (alias ``None``) imports must not raise.

    NOTE(review): inputs and expectation are identical to
    ``test_add_needed_imports_with_none_aliases``; consider keeping only one.
    """
    source_code = """
import json
from typing import Dict as MyDict, Optional
from collections import defaultdict
"""
    # The body must be indented so the target is valid Python.
    target_code = """
def target_function():
    pass
"""
    # No source import is used, so the target is returned unchanged.
    expected_output = """
def target_function():
    pass
"""
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        src_path = temp_path / "source.py"
        dst_path = temp_path / "target.py"
        src_path.write_text(source_code)
        dst_path.write_text(target_code)
        # This should not raise a TypeError
        result = add_needed_imports_from_module(
            src_module_code=source_code,
            dst_module_code=target_code,
            src_path=src_path,
            dst_path=dst_path,
            project_root=temp_path,
        )
    assert result.strip() == expected_output.strip()