Merge pull request #776 from codeflash-ai/libcst-importstar-bug

don't iterate over star imports
This commit is contained in:
Kevin Turcios 2025-09-29 21:37:48 +00:00 committed by GitHub
commit 565d65be74
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 228 additions and 3 deletions

View file

@ -272,6 +272,8 @@ class DottedImportCollector(cst.CSTVisitor):
if child.module is None:
continue
module = self.get_full_dotted_name(child.module)
if isinstance(child.names, cst.ImportStar):
continue
for alias in child.names:
if isinstance(alias, cst.ImportAlias):
name = alias.name.value
@ -414,6 +416,73 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
return transformed_module.code
def resolve_star_import(module_name: str, project_root: Path) -> set[str]:
    """Resolve ``from <module_name> import *`` to the names it would bind.

    The module is looked up under ``project_root`` (first ``<path>.py``, then
    ``<path>/__init__.py``) and analysed statically with ``ast``; this function
    only parses the file's text.

    Resolution rules:

    * If an ``__all__ = [...]`` / ``__all__ = (...)`` assignment with a list or
      tuple literal is found, only its string-literal elements are returned.
    * Otherwise every top-level name that does not start with an underscore is
      returned: functions, classes, assignment targets, annotated names and
      names bound by ``import`` statements.

    Args:
        module_name: Dotted module name, e.g. ``"pkg.sub"``.
        project_root: Directory the dotted name is resolved against.

    Returns:
        The set of resolved names. An empty set is returned when the module
        file cannot be found or when reading/parsing it fails (a warning is
        logged in both cases).
    """
    try:
        module_path = module_name.replace(".", "/")
        candidates = (project_root / f"{module_path}.py", project_root / f"{module_path}/__init__.py")
        module_file = next((path for path in candidates if path.exists()), None)
        if module_file is None:
            logger.warning(f"Could not find module file for {module_name}, skipping star import resolution")
            return set()

        tree = ast.parse(module_file.read_text(encoding="utf8"))

        # The first single-target ``__all__ = ...`` assignment found decides the
        # outcome. ``ast.walk`` is kept (rather than scanning ``tree.body`` only)
        # so assignments nested in other statements are still seen.
        for node in ast.walk(tree):
            is_all_assignment = (
                isinstance(node, ast.Assign)
                and len(node.targets) == 1
                and isinstance(node.targets[0], ast.Name)
                and node.targets[0].id == "__all__"
            )
            if not is_all_assignment:
                continue
            if isinstance(node.value, (ast.List, ast.Tuple)):
                # Only ``ast.Constant`` is checked: the legacy ``ast.Str`` node is
                # deprecated and absent from newer Python versions, where merely
                # referencing it raised AttributeError for non-string elements.
                return {
                    elt.value
                    for elt in node.value.elts
                    if isinstance(elt, ast.Constant) and isinstance(elt.value, str)
                }
            # ``__all__`` is computed dynamically; fall back to public names.
            break

        public_names = set()
        for node in tree.body:
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                bound = [node.name]
            elif isinstance(node, ast.Assign):
                bound = [target.id for target in node.targets if isinstance(target, ast.Name)]
            elif isinstance(node, ast.AnnAssign):
                bound = [node.target.id] if isinstance(node.target, ast.Name) else []
            elif isinstance(node, ast.Import):
                # ``import a.b`` binds only the top-level name ``a``.
                bound = [alias.asname or alias.name.split(".")[0] for alias in node.names]
            elif isinstance(node, ast.ImportFrom):
                if any(alias.name == "*" for alias in node.names):
                    # Nested star imports are not followed.
                    continue
                bound = [alias.asname or alias.name for alias in node.names]
            else:
                continue
            public_names.update(name for name in bound if not name.startswith("_"))
        return public_names  # noqa: TRY300
    except Exception as e:
        # Deliberately best-effort: any read/parse failure degrades to "no symbols".
        logger.warning(f"Error resolving star import for {module_name}: {e}")
        return set()
def add_needed_imports_from_module(
src_module_code: str,
dst_module_code: str,
@ -468,9 +537,23 @@ def add_needed_imports_from_module(
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
):
continue # Skip adding imports for helper functions already in the context
if f"{mod}.{obj}" not in dotted_import_collector.imports:
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
# Handle star imports by resolving them to actual symbol names
if obj == "*":
resolved_symbols = resolve_star_import(mod, project_root)
logger.debug(f"Resolved star import from {mod}: {resolved_symbols}")
for symbol in resolved_symbols:
if (
f"{mod}.{symbol}" not in helper_functions_fqn
and f"{mod}.{symbol}" not in dotted_import_collector.imports
):
AddImportsVisitor.add_needed_import(dst_context, mod, symbol)
RemoveImportsVisitor.remove_unused_import(dst_context, mod, symbol)
else:
if f"{mod}.{obj}" not in dotted_import_collector.imports:
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
except Exception as e:
logger.exception(f"Error adding imports to destination module code: {e}")
return dst_module_code

View file

@ -3,6 +3,10 @@ from pathlib import Path
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
import tempfile
from codeflash.code_utils.code_extractor import resolve_star_import, DottedImportCollector
import libcst as cst
from codeflash.models.models import FunctionParent
def test_add_needed_imports_from_module0() -> None:
src_module = '''import ast
@ -349,3 +353,141 @@ class DbtAdapter(BaseAdapter):
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == expected
def test_resolve_star_import_with_all_defined():
    """Test resolve_star_import when __all__ is explicitly defined."""
    with tempfile.TemporaryDirectory() as tmpdir:
        project_root = Path(tmpdir)
        test_module = project_root / 'test_module.py'
        # Create a test module with __all__ definition.
        # The function and class bodies must be indented: the fixture has to be
        # parseable Python, otherwise resolve_star_import falls into its error
        # path and returns an empty set.
        test_module.write_text('''
__all__ = ['public_function', 'PublicClass']
def public_function():
    pass
def _private_function():
    pass
class PublicClass:
    pass
class AnotherPublicClass:
    """Not in __all__ so should be excluded."""
    pass
''')
        symbols = resolve_star_import('test_module', project_root)
        # Only the names listed in __all__ are expected, even though
        # AnotherPublicClass is also public by naming convention.
        expected_symbols = {'public_function', 'PublicClass'}
        assert symbols == expected_symbols
def test_resolve_star_import_without_all_defined():
    """Test resolve_star_import when __all__ is not defined - should include all public symbols."""
    with tempfile.TemporaryDirectory() as tmpdir:
        project_root = Path(tmpdir)
        test_module = project_root / 'test_module.py'
        # Create a test module without __all__ definition.
        # The function and class bodies must be indented: the fixture has to be
        # parseable Python, otherwise resolve_star_import falls into its error
        # path and returns an empty set.
        test_module.write_text('''
def public_func():
    pass
def _private_func():
    pass
class PublicClass:
    pass
PUBLIC_VAR = 42
_private_var = 'secret'
''')
        symbols = resolve_star_import('test_module', project_root)
        # Underscore-prefixed names are expected to be left out.
        expected_symbols = {'public_func', 'PublicClass', 'PUBLIC_VAR'}
        assert symbols == expected_symbols
def test_resolve_star_import_nonexistent_module():
    """A module that does not exist under the project root resolves to an empty set."""
    with tempfile.TemporaryDirectory() as scratch_dir:
        # The directory is empty, so no candidate file can exist for the module.
        resolved = resolve_star_import('nonexistent_module', Path(scratch_dir))
        assert resolved == set()
def test_dotted_import_collector_skips_star_imports():
    """DottedImportCollector records ordinary imports and ignores ``from x import *``."""
    source_text = '''
from typing import *
from pathlib import Path
from collections import defaultdict
import os
'''
    import_collector = DottedImportCollector()
    cst.parse_module(source_text).visit(import_collector)
    # The star import from typing contributes nothing; the other three do.
    assert import_collector.imports == {'collections.defaultdict', 'os', 'pathlib.Path'}
def test_add_needed_imports_with_star_import_resolution():
    """Test add_needed_imports_from_module correctly handles star imports by resolving them.

    A module exporting two names through ``__all__`` is written to a temporary
    project root; the source code star-imports it, and the destination code is
    expected to end up with one explicit ``from ... import`` line.
    """
    # NOTE(review): the triple-quoted literals below appear without body
    # indentation or blank lines in this view. The final assertion is an exact
    # string comparison, so confirm their whitespace against the repository
    # before relying on this test.
    with tempfile.TemporaryDirectory() as tmpdir:
        project_root = Path(tmpdir)
        # Create a source module that exports symbols
        src_module = project_root / 'source_module.py'
        src_module.write_text('''
__all__ = ['UtilFunction', 'HelperClass']
def UtilFunction():
pass
class HelperClass:
pass
''')
        # Create source code that uses star import
        src_code = '''
from source_module import *
def my_function():
helper = HelperClass()
UtilFunction()
return helper
'''
        # Destination code that needs the imports resolved
        dst_code = '''
def my_function():
helper = HelperClass()
UtilFunction()
return helper
'''
        src_path = project_root / 'src.py'
        dst_path = project_root / 'dst.py'
        # Only the source file is written to disk; dst_path is passed as a path only.
        src_path.write_text(src_code)
        result = add_needed_imports_from_module(
            src_code, dst_code, src_path, dst_path, project_root
        )
        # The result should have individual imports instead of star import
        expected_result = '''from source_module import HelperClass, UtilFunction
def my_function():
helper = HelperClass()
UtilFunction()
return helper
'''
        assert result == expected_result