fix: exclude loop variables from test import fallback
Make get_referenced_names_from_source scope-aware by reusing UndefinedNameCollector, preventing invalid imports like `i` and `v` from loop variables in AI-generated tests.
This commit is contained in:
parent
53baee3994
commit
48b667062b
1 changed files with 9 additions and 21 deletions
|
|
@ -269,18 +269,17 @@ def get_imports_from_source_code(source_code: str) -> dict[str, str]:
|
|||
|
||||
|
||||
def get_referenced_names_from_source(source_code: str) -> set[str]:
|
||||
"""Get all names referenced (used) in source code.
|
||||
"""Get names that are used but not locally defined in source code.
|
||||
|
||||
This captures names that appear as identifiers in the source, even if they're
|
||||
not explicitly defined or imported in the visible snippet. Useful for finding
|
||||
symbols that the source module uses but whose definition may be missing from
|
||||
the snippet.
|
||||
This returns names that could potentially be module-level exports,
|
||||
excluding locally-scoped names like loop variables, function parameters,
|
||||
and comprehension variables.
|
||||
|
||||
Args:
|
||||
source_code: The source code to parse
|
||||
|
||||
Returns:
|
||||
Set of all names referenced in the source code
|
||||
Set of names that are used but not defined locally
|
||||
|
||||
"""
|
||||
try:
|
||||
|
|
@ -288,21 +287,10 @@ def get_referenced_names_from_source(source_code: str) -> set[str]:
|
|||
except SyntaxError:
|
||||
return set()
|
||||
|
||||
names: set[str] = set()
|
||||
# Use iterative stack-based traversal instead of ast.walk() generator
|
||||
stack = [tree]
|
||||
while stack:
|
||||
node = stack.pop()
|
||||
if node.__class__ is ast.Name:
|
||||
names.add(node.id)
|
||||
# Directly traverse child nodes using _fields
|
||||
for field in node._fields:
|
||||
value = getattr(node, field, None)
|
||||
if isinstance(value, list):
|
||||
stack.extend(reversed([item for item in value if isinstance(item, ast.AST)]))
|
||||
elif isinstance(value, ast.AST):
|
||||
stack.append(value)
|
||||
return names
|
||||
collector = UndefinedNameCollector()
|
||||
collector.visit(tree)
|
||||
# Return names that are used but not defined locally
|
||||
return collector.used_names - collector.defined_names - collector.imported_names
|
||||
|
||||
|
||||
def get_local_definitions(test_code: str) -> set[str]:
|
||||
|
|
|
|||
Loading…
Reference in a new issue