avoid adding new imports existed in top level try/catch or if TYPE_CHECKING

This commit is contained in:
mohammed 2025-07-31 16:52:11 +03:00
parent e806c5dd3d
commit c3b775fc68
No known key found for this signature in database
GPG key ID: 44F9B42770617B9B
2 changed files with 213 additions and 6 deletions

View file

@ -195,6 +195,64 @@ class LastImportFinder(cst.CSTVisitor):
self.last_import_line = self.current_line
class ConditionalImportCollector(cst.CSTVisitor):
"""Collect imports inside top-level conditionals (e.g., if TYPE_CHECKING, try/except)."""
def __init__(self) -> None:
self.imports: set[str] = set()
self.depth = 0 # top-level
def get_full_dotted_name(self, expr: cst.BaseExpression) -> str:
if isinstance(expr, cst.Name):
return expr.value
if isinstance(expr, cst.Attribute):
return f"{self.get_full_dotted_name(expr.value)}.{expr.attr.value}"
return ""
def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None:
for statement in block.body:
if isinstance(statement, cst.SimpleStatementLine):
for child in statement.body:
if isinstance(child, cst.Import):
for alias in child.names:
module = self.get_full_dotted_name(alias.name)
asname = alias.asname.name.value if alias.asname else alias.name.value
self.imports.add(module if module == asname else f"{module}.{asname}")
elif isinstance(child, cst.ImportFrom):
if child.module is None:
continue
module = self.get_full_dotted_name(child.module)
for alias in child.names:
if isinstance(alias, cst.ImportAlias):
name = alias.name.value
asname = alias.asname.name.value if alias.asname else name
self.imports.add(f"{module}.{asname}")
def visit_Module(self, node: cst.Module) -> None:
self.depth = 0
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
self.depth += 1
def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
self.depth -= 1
def visit_ClassDef(self, node: cst.ClassDef) -> None:
self.depth += 1
def leave_ClassDef(self, node: cst.ClassDef) -> None:
self.depth -= 1
def visit_If(self, node: cst.If) -> None:
if self.depth == 0:
self._collect_imports_from_block(node.body)
def visit_Try(self, node: cst.Try) -> None:
if self.depth == 0:
self._collect_imports_from_block(node.body)
class ImportInserter(cst.CSTTransformer):
"""Transformer that inserts global statements after the last import."""
@ -329,8 +387,19 @@ def add_needed_imports_from_module(
except Exception as e:
logger.error(f"Error parsing source module code: {e}")
return dst_module_code
cond_import_collector = ConditionalImportCollector()
try:
parsed_dst_module = cst.parse_module(dst_module_code)
parsed_dst_module.visit(cond_import_collector)
except cst.ParserSyntaxError as e:
logger.exception(f"Syntax error in destination module code: {e}")
return dst_module_code # Return the original code if there's a syntax error
try:
for mod in gatherer.module_imports:
if mod in cond_import_collector.imports:
continue
AddImportsVisitor.add_needed_import(dst_context, mod)
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
for mod, obj_seq in gatherer.object_mapping.items():
@ -339,28 +408,29 @@ def add_needed_imports_from_module(
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
):
continue # Skip adding imports for helper functions already in the context
if f"{mod}.{obj}" in cond_import_collector.imports:
continue
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
except Exception as e:
logger.exception(f"Error adding imports to destination module code: {e}")
return dst_module_code
for mod, asname in gatherer.module_aliases.items():
if f"{mod}.{asname}" in cond_import_collector.imports:
continue
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname)
for mod, alias_pairs in gatherer.alias_mapping.items():
for alias_pair in alias_pairs:
if f"{mod}.{alias_pair[0]}" in helper_functions_fqn:
continue
if f"{mod}.{alias_pair[1]}" in cond_import_collector.imports:
continue
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
try:
parsed_module = cst.parse_module(dst_module_code)
except cst.ParserSyntaxError as e:
logger.exception(f"Syntax error in destination module code: {e}")
return dst_module_code # Return the original code if there's a syntax error
try:
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_module)
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_dst_module)
transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module)
return transformed_module.code.lstrip("\n")
except Exception as e:

View file

@ -1,4 +1,5 @@
from __future__ import annotations
import re
import libcst as cst
from codeflash.code_utils.code_replacer import AutouseFixtureModifier, PytestMarkAdder, AddRequestArgument
import dataclasses
@ -3070,3 +3071,139 @@ def my_fixture(request):
modified_module = module.visit(transformer)
assert modified_module.code.strip() == expected.strip()
def test_type_checking_imports():
optim_code = """from dataclasses import dataclass
from pydantic_ai.providers import Provider, infer_provider
from pydantic_ai_slim.pydantic_ai.models import Model
from pydantic_ai_slim.pydantic_ai.tools import ToolDefinition
from typing import Literal
#### problamatic imports ####
from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool
import requests
import aiohttp as aiohttp_
from math import pi as PI, sin as sine
@dataclass(init=False)
class HuggingFaceModel(Model):
def __init__(
self,
model_name: str,
*,
provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface',
):
print(requests.__name__)
print(aiohttp_.__name__)
print(PI)
print(sine)
# Fast branch: avoid repeating provider assignment
if isinstance(provider, str):
provider_obj = infer_provider(provider)
else:
provider_obj = provider
self._provider = provider
self._model_name = model_name
self.client = provider_obj.client
@staticmethod
def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool:
# Inline dict creation and single pass for possible strict attribute
tool_dict = {
'type': 'function',
'function': {
'name': f.name,
'description': f.description,
'parameters': f.parameters_json_schema,
},
}
if f.strict is not None:
tool_dict['function']['strict'] = f.strict
return ChatCompletionInputTool.parse_obj_as_instance(tool_dict) # type: ignore
"""
original_code = """from dataclasses import dataclass
from pydantic_ai.providers import Provider, infer_provider
from pydantic_ai_slim.pydantic_ai.models import Model
from pydantic_ai_slim.pydantic_ai.tools import ToolDefinition
from typing import Literal
try:
import aiohttp as aiohttp_
from math import pi as PI, sin as sine
from huggingface_hub import (
AsyncInferenceClient,
ChatCompletionInputMessage,
ChatCompletionInputMessageChunk,
ChatCompletionInputTool,
ChatCompletionInputToolCall,
ChatCompletionInputURL,
ChatCompletionOutput,
ChatCompletionOutputMessage,
ChatCompletionStreamOutput,
)
from huggingface_hub.errors import HfHubHTTPError
except ImportError as _import_error:
raise ImportError(
'Please install `huggingface_hub` to use Hugging Face Inference Providers, '
'you can use the `huggingface` optional group — `pip install "pydantic-ai-slim[huggingface]"`'
) from _import_error
if True:
import requests
__all__ = (
'HuggingFaceModel',
'HuggingFaceModelSettings',
)
@dataclass(init=False)
class HuggingFaceModel(Model):
def __init__(
self,
model_name: str,
*,
provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface',
):
self._model_name = model_name
self._provider = provider
if isinstance(provider, str):
provider = infer_provider(provider)
self.client = provider.client
@staticmethod
def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool:
tool_param: ChatCompletionInputTool = ChatCompletionInputTool.parse_obj_as_instance( # type: ignore
{
'type': 'function',
'function': {
'name': f.name,
'description': f.description,
'parameters': f.parameters_json_schema,
},
}
)
if f.strict is not None:
tool_param['function']['strict'] = f.strict
return tool_param
"""
function_name: str = "HuggingFaceModel._map_tool_definition"
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=[function_name],
optimized_code=optim_code,
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert not re.search(r"^import requests\b", new_code, re.MULTILINE) # conditional simple import: import <name>
assert not re.search(r"^import aiohttp as aiohttp_\b", new_code, re.MULTILINE) # conditional alias import: import <name> as <alias>
assert not re.search(r"^from math import pi as PI, sin as sine\b", new_code, re.MULTILINE) # conditional multiple aliases imports
assert "from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool" not in new_code # conditional from import