star imports bug
This commit is contained in:
parent
2d886e87d5
commit
cc0034be5d
2 changed files with 229 additions and 3 deletions
|
|
@ -272,6 +272,8 @@ class DottedImportCollector(cst.CSTVisitor):
|
|||
if child.module is None:
|
||||
continue
|
||||
module = self.get_full_dotted_name(child.module)
|
||||
if isinstance(child.names, cst.ImportStar):
|
||||
continue
|
||||
for alias in child.names:
|
||||
if isinstance(alias, cst.ImportAlias):
|
||||
name = alias.name.value
|
||||
|
|
@ -414,6 +416,73 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
|
|||
return transformed_module.code
|
||||
|
||||
|
||||
def resolve_star_import(module_name: str, project_root: Path) -> set[str]:
|
||||
try:
|
||||
module_path = module_name.replace(".", "/")
|
||||
possible_paths = [project_root / f"{module_path}.py", project_root / f"{module_path}/__init__.py"]
|
||||
|
||||
module_file = None
|
||||
for path in possible_paths:
|
||||
if path.exists():
|
||||
module_file = path
|
||||
break
|
||||
|
||||
if module_file is None:
|
||||
logger.warning(f"Could not find module file for {module_name}, skipping star import resolution")
|
||||
return set()
|
||||
|
||||
with module_file.open(encoding="utf8") as f:
|
||||
module_code = f.read()
|
||||
|
||||
tree = ast.parse(module_code)
|
||||
|
||||
all_names = None
|
||||
for node in ast.walk(tree):
|
||||
if (
|
||||
isinstance(node, ast.Assign)
|
||||
and len(node.targets) == 1
|
||||
and isinstance(node.targets[0], ast.Name)
|
||||
and node.targets[0].id == "__all__"
|
||||
):
|
||||
if isinstance(node.value, (ast.List, ast.Tuple)):
|
||||
all_names = []
|
||||
for elt in node.value.elts:
|
||||
if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
|
||||
all_names.append(elt.value)
|
||||
elif isinstance(elt, ast.Str): # Python < 3.8 compatibility
|
||||
all_names.append(elt.s)
|
||||
break
|
||||
|
||||
if all_names is not None:
|
||||
return set(all_names)
|
||||
|
||||
public_names = set()
|
||||
for node in tree.body:
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
|
||||
if not node.name.startswith("_"):
|
||||
public_names.add(node.name)
|
||||
elif isinstance(node, ast.Assign):
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name) and not target.id.startswith("_"):
|
||||
public_names.add(target.id)
|
||||
elif isinstance(node, ast.AnnAssign):
|
||||
if isinstance(node.target, ast.Name) and not node.target.id.startswith("_"):
|
||||
public_names.add(node.target.id)
|
||||
elif isinstance(node, ast.Import) or (
|
||||
isinstance(node, ast.ImportFrom) and not any(alias.name == "*" for alias in node.names)
|
||||
):
|
||||
for alias in node.names:
|
||||
name = alias.asname or alias.name
|
||||
if not name.startswith("_"):
|
||||
public_names.add(name)
|
||||
|
||||
return public_names # noqa: TRY300
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error resolving star import for {module_name}: {e}")
|
||||
return set()
|
||||
|
||||
|
||||
def add_needed_imports_from_module(
|
||||
src_module_code: str,
|
||||
dst_module_code: str,
|
||||
|
|
@ -468,9 +537,23 @@ def add_needed_imports_from_module(
|
|||
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
|
||||
):
|
||||
continue # Skip adding imports for helper functions already in the context
|
||||
if f"{mod}.{obj}" not in dotted_import_collector.imports:
|
||||
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
|
||||
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
|
||||
|
||||
# Handle star imports by resolving them to actual symbol names
|
||||
if obj == "*":
|
||||
resolved_symbols = resolve_star_import(mod, project_root)
|
||||
logger.debug(f"Resolved star import from {mod}: {resolved_symbols}")
|
||||
|
||||
for symbol in resolved_symbols:
|
||||
if (
|
||||
f"{mod}.{symbol}" not in helper_functions_fqn
|
||||
and f"{mod}.{symbol}" not in dotted_import_collector.imports
|
||||
):
|
||||
AddImportsVisitor.add_needed_import(dst_context, mod, symbol)
|
||||
RemoveImportsVisitor.remove_unused_import(dst_context, mod, symbol)
|
||||
else:
|
||||
if f"{mod}.{obj}" not in dotted_import_collector.imports:
|
||||
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
|
||||
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error adding imports to destination module code: {e}")
|
||||
return dst_module_code
|
||||
|
|
|
|||
|
|
@ -3,6 +3,10 @@ from pathlib import Path
|
|||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
|
||||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
||||
|
||||
import tempfile
|
||||
from codeflash.code_utils.code_extractor import resolve_star_import, DottedImportCollector
|
||||
import libcst as cst
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
||||
def test_add_needed_imports_from_module0() -> None:
|
||||
src_module = '''import ast
|
||||
|
|
@ -349,3 +353,142 @@ class DbtAdapter(BaseAdapter):
|
|||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
||||
|
||||
|
||||
def test_resolve_star_import_with_all_defined():
|
||||
"""Test resolve_star_import when __all__ is explicitly defined."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
project_root = Path(tmpdir)
|
||||
test_module = project_root / 'test_module.py'
|
||||
|
||||
# Create a test module with __all__ definition
|
||||
test_module.write_text('''
|
||||
__all__ = ['public_function', 'PublicClass']
|
||||
|
||||
def public_function():
|
||||
pass
|
||||
|
||||
def _private_function():
|
||||
pass
|
||||
|
||||
class PublicClass:
|
||||
pass
|
||||
|
||||
class AnotherPublicClass:
|
||||
"""Not in __all__ so should be excluded."""
|
||||
pass
|
||||
''')
|
||||
|
||||
symbols = resolve_star_import('test_module', project_root)
|
||||
expected_symbols = {'public_function', 'PublicClass'}
|
||||
assert symbols == expected_symbols
|
||||
|
||||
|
||||
def test_resolve_star_import_without_all_defined():
|
||||
"""Test resolve_star_import when __all__ is not defined - should include all public symbols."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
project_root = Path(tmpdir)
|
||||
test_module = project_root / 'test_module.py'
|
||||
|
||||
# Create a test module without __all__ definition
|
||||
test_module.write_text('''
|
||||
def public_func():
|
||||
pass
|
||||
|
||||
def _private_func():
|
||||
pass
|
||||
|
||||
class PublicClass:
|
||||
pass
|
||||
|
||||
PUBLIC_VAR = 42
|
||||
_private_var = 'secret'
|
||||
''')
|
||||
|
||||
symbols = resolve_star_import('test_module', project_root)
|
||||
expected_symbols = {'public_func', 'PublicClass', 'PUBLIC_VAR'}
|
||||
assert symbols == expected_symbols
|
||||
|
||||
|
||||
def test_resolve_star_import_nonexistent_module():
|
||||
"""Test resolve_star_import with non-existent module - should return empty set."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
project_root = Path(tmpdir)
|
||||
|
||||
symbols = resolve_star_import('nonexistent_module', project_root)
|
||||
assert symbols == set()
|
||||
|
||||
|
||||
def test_dotted_import_collector_skips_star_imports():
|
||||
"""Test that DottedImportCollector correctly skips star imports."""
|
||||
code_with_star_import = '''
|
||||
from typing import *
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
import os
|
||||
'''
|
||||
|
||||
module = cst.parse_module(code_with_star_import)
|
||||
collector = DottedImportCollector()
|
||||
module.visit(collector)
|
||||
|
||||
# Should collect regular imports but skip the star import
|
||||
expected_imports = {
|
||||
'pathlib.Path',
|
||||
'collections.defaultdict',
|
||||
'os'
|
||||
}
|
||||
assert collector.imports == expected_imports
|
||||
# Ensure the star import from typing is not collected
|
||||
assert not any('typing' in imp for imp in collector.imports)
|
||||
|
||||
|
||||
def test_add_needed_imports_with_star_import_resolution():
|
||||
"""Test add_needed_imports_from_module correctly handles star imports by resolving them."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
project_root = Path(tmpdir)
|
||||
|
||||
# Create a source module that exports symbols
|
||||
src_module = project_root / 'source_module.py'
|
||||
src_module.write_text('''
|
||||
__all__ = ['UtilFunction', 'HelperClass']
|
||||
|
||||
def UtilFunction():
|
||||
pass
|
||||
|
||||
class HelperClass:
|
||||
pass
|
||||
''')
|
||||
|
||||
# Create source code that uses star import
|
||||
src_code = '''
|
||||
from source_module import *
|
||||
|
||||
def my_function():
|
||||
helper = HelperClass()
|
||||
UtilFunction()
|
||||
return helper
|
||||
'''
|
||||
|
||||
# Destination code that needs the imports resolved
|
||||
dst_code = '''
|
||||
def my_function():
|
||||
helper = HelperClass()
|
||||
UtilFunction()
|
||||
return helper
|
||||
'''
|
||||
|
||||
src_path = project_root / 'src.py'
|
||||
dst_path = project_root / 'dst.py'
|
||||
src_path.write_text(src_code)
|
||||
|
||||
result = add_needed_imports_from_module(
|
||||
src_code, dst_code, src_path, dst_path, project_root
|
||||
)
|
||||
|
||||
# The result should have individual imports instead of star import
|
||||
assert 'from source_module import' in result
|
||||
assert 'HelperClass' in result and 'UtilFunction' in result
|
||||
assert 'from source_module import *' not in result
|
||||
|
|
|
|||
Loading…
Reference in a new issue