mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
avoid adding new imports existed in top level try/catch or if TYPE_CHECKING
This commit is contained in:
parent
e806c5dd3d
commit
c3b775fc68
2 changed files with 213 additions and 6 deletions
|
|
@ -195,6 +195,64 @@ class LastImportFinder(cst.CSTVisitor):
|
|||
self.last_import_line = self.current_line
|
||||
|
||||
|
||||
class ConditionalImportCollector(cst.CSTVisitor):
|
||||
"""Collect imports inside top-level conditionals (e.g., if TYPE_CHECKING, try/except)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.imports: set[str] = set()
|
||||
self.depth = 0 # top-level
|
||||
|
||||
def get_full_dotted_name(self, expr: cst.BaseExpression) -> str:
|
||||
if isinstance(expr, cst.Name):
|
||||
return expr.value
|
||||
if isinstance(expr, cst.Attribute):
|
||||
return f"{self.get_full_dotted_name(expr.value)}.{expr.attr.value}"
|
||||
return ""
|
||||
|
||||
def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None:
|
||||
for statement in block.body:
|
||||
if isinstance(statement, cst.SimpleStatementLine):
|
||||
for child in statement.body:
|
||||
if isinstance(child, cst.Import):
|
||||
for alias in child.names:
|
||||
module = self.get_full_dotted_name(alias.name)
|
||||
asname = alias.asname.name.value if alias.asname else alias.name.value
|
||||
self.imports.add(module if module == asname else f"{module}.{asname}")
|
||||
|
||||
elif isinstance(child, cst.ImportFrom):
|
||||
if child.module is None:
|
||||
continue
|
||||
module = self.get_full_dotted_name(child.module)
|
||||
for alias in child.names:
|
||||
if isinstance(alias, cst.ImportAlias):
|
||||
name = alias.name.value
|
||||
asname = alias.asname.name.value if alias.asname else name
|
||||
self.imports.add(f"{module}.{asname}")
|
||||
|
||||
def visit_Module(self, node: cst.Module) -> None:
|
||||
self.depth = 0
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
|
||||
self.depth += 1
|
||||
|
||||
def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
|
||||
self.depth -= 1
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> None:
|
||||
self.depth += 1
|
||||
|
||||
def leave_ClassDef(self, node: cst.ClassDef) -> None:
|
||||
self.depth -= 1
|
||||
|
||||
def visit_If(self, node: cst.If) -> None:
|
||||
if self.depth == 0:
|
||||
self._collect_imports_from_block(node.body)
|
||||
|
||||
def visit_Try(self, node: cst.Try) -> None:
|
||||
if self.depth == 0:
|
||||
self._collect_imports_from_block(node.body)
|
||||
|
||||
|
||||
class ImportInserter(cst.CSTTransformer):
|
||||
"""Transformer that inserts global statements after the last import."""
|
||||
|
||||
|
|
@ -329,8 +387,19 @@ def add_needed_imports_from_module(
|
|||
except Exception as e:
|
||||
logger.error(f"Error parsing source module code: {e}")
|
||||
return dst_module_code
|
||||
|
||||
cond_import_collector = ConditionalImportCollector()
|
||||
try:
|
||||
parsed_dst_module = cst.parse_module(dst_module_code)
|
||||
parsed_dst_module.visit(cond_import_collector)
|
||||
except cst.ParserSyntaxError as e:
|
||||
logger.exception(f"Syntax error in destination module code: {e}")
|
||||
return dst_module_code # Return the original code if there's a syntax error
|
||||
|
||||
try:
|
||||
for mod in gatherer.module_imports:
|
||||
if mod in cond_import_collector.imports:
|
||||
continue
|
||||
AddImportsVisitor.add_needed_import(dst_context, mod)
|
||||
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
|
||||
for mod, obj_seq in gatherer.object_mapping.items():
|
||||
|
|
@ -339,28 +408,29 @@ def add_needed_imports_from_module(
|
|||
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
|
||||
):
|
||||
continue # Skip adding imports for helper functions already in the context
|
||||
if f"{mod}.{obj}" in cond_import_collector.imports:
|
||||
continue
|
||||
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
|
||||
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error adding imports to destination module code: {e}")
|
||||
return dst_module_code
|
||||
for mod, asname in gatherer.module_aliases.items():
|
||||
if f"{mod}.{asname}" in cond_import_collector.imports:
|
||||
continue
|
||||
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
|
||||
RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname)
|
||||
for mod, alias_pairs in gatherer.alias_mapping.items():
|
||||
for alias_pair in alias_pairs:
|
||||
if f"{mod}.{alias_pair[0]}" in helper_functions_fqn:
|
||||
continue
|
||||
if f"{mod}.{alias_pair[1]}" in cond_import_collector.imports:
|
||||
continue
|
||||
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
|
||||
RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
|
||||
|
||||
try:
|
||||
parsed_module = cst.parse_module(dst_module_code)
|
||||
except cst.ParserSyntaxError as e:
|
||||
logger.exception(f"Syntax error in destination module code: {e}")
|
||||
return dst_module_code # Return the original code if there's a syntax error
|
||||
try:
|
||||
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_module)
|
||||
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_dst_module)
|
||||
transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module)
|
||||
return transformed_module.code.lstrip("\n")
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from __future__ import annotations
|
||||
import re
|
||||
import libcst as cst
|
||||
from codeflash.code_utils.code_replacer import AutouseFixtureModifier, PytestMarkAdder, AddRequestArgument
|
||||
import dataclasses
|
||||
|
|
@ -3070,3 +3071,139 @@ def my_fixture(request):
|
|||
modified_module = module.visit(transformer)
|
||||
|
||||
assert modified_module.code.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_type_checking_imports():
|
||||
optim_code = """from dataclasses import dataclass
|
||||
from pydantic_ai.providers import Provider, infer_provider
|
||||
from pydantic_ai_slim.pydantic_ai.models import Model
|
||||
from pydantic_ai_slim.pydantic_ai.tools import ToolDefinition
|
||||
from typing import Literal
|
||||
|
||||
#### problamatic imports ####
|
||||
from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool
|
||||
import requests
|
||||
import aiohttp as aiohttp_
|
||||
from math import pi as PI, sin as sine
|
||||
|
||||
@dataclass(init=False)
|
||||
class HuggingFaceModel(Model):
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
*,
|
||||
provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface',
|
||||
):
|
||||
print(requests.__name__)
|
||||
print(aiohttp_.__name__)
|
||||
print(PI)
|
||||
print(sine)
|
||||
# Fast branch: avoid repeating provider assignment
|
||||
if isinstance(provider, str):
|
||||
provider_obj = infer_provider(provider)
|
||||
else:
|
||||
provider_obj = provider
|
||||
self._provider = provider
|
||||
self._model_name = model_name
|
||||
self.client = provider_obj.client
|
||||
|
||||
@staticmethod
|
||||
def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool:
|
||||
# Inline dict creation and single pass for possible strict attribute
|
||||
tool_dict = {
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': f.name,
|
||||
'description': f.description,
|
||||
'parameters': f.parameters_json_schema,
|
||||
},
|
||||
}
|
||||
if f.strict is not None:
|
||||
tool_dict['function']['strict'] = f.strict
|
||||
return ChatCompletionInputTool.parse_obj_as_instance(tool_dict) # type: ignore
|
||||
"""
|
||||
|
||||
original_code = """from dataclasses import dataclass
|
||||
from pydantic_ai.providers import Provider, infer_provider
|
||||
from pydantic_ai_slim.pydantic_ai.models import Model
|
||||
from pydantic_ai_slim.pydantic_ai.tools import ToolDefinition
|
||||
from typing import Literal
|
||||
|
||||
try:
|
||||
import aiohttp as aiohttp_
|
||||
from math import pi as PI, sin as sine
|
||||
from huggingface_hub import (
|
||||
AsyncInferenceClient,
|
||||
ChatCompletionInputMessage,
|
||||
ChatCompletionInputMessageChunk,
|
||||
ChatCompletionInputTool,
|
||||
ChatCompletionInputToolCall,
|
||||
ChatCompletionInputURL,
|
||||
ChatCompletionOutput,
|
||||
ChatCompletionOutputMessage,
|
||||
ChatCompletionStreamOutput,
|
||||
)
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
except ImportError as _import_error:
|
||||
raise ImportError(
|
||||
'Please install `huggingface_hub` to use Hugging Face Inference Providers, '
|
||||
'you can use the `huggingface` optional group — `pip install "pydantic-ai-slim[huggingface]"`'
|
||||
) from _import_error
|
||||
|
||||
if True:
|
||||
import requests
|
||||
|
||||
__all__ = (
|
||||
'HuggingFaceModel',
|
||||
'HuggingFaceModelSettings',
|
||||
)
|
||||
|
||||
@dataclass(init=False)
|
||||
class HuggingFaceModel(Model):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
*,
|
||||
provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface',
|
||||
):
|
||||
self._model_name = model_name
|
||||
self._provider = provider
|
||||
if isinstance(provider, str):
|
||||
provider = infer_provider(provider)
|
||||
self.client = provider.client
|
||||
|
||||
@staticmethod
|
||||
def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool:
|
||||
tool_param: ChatCompletionInputTool = ChatCompletionInputTool.parse_obj_as_instance( # type: ignore
|
||||
{
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': f.name,
|
||||
'description': f.description,
|
||||
'parameters': f.parameters_json_schema,
|
||||
},
|
||||
}
|
||||
)
|
||||
if f.strict is not None:
|
||||
tool_param['function']['strict'] = f.strict
|
||||
return tool_param
|
||||
"""
|
||||
|
||||
|
||||
function_name: str = "HuggingFaceModel._map_tool_definition"
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=[function_name],
|
||||
optimized_code=optim_code,
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
|
||||
assert not re.search(r"^import requests\b", new_code, re.MULTILINE) # conditional simple import: import <name>
|
||||
assert not re.search(r"^import aiohttp as aiohttp_\b", new_code, re.MULTILINE) # conditional alias import: import <name> as <alias>
|
||||
assert not re.search(r"^from math import pi as PI, sin as sine\b", new_code, re.MULTILINE) # conditional multiple aliases imports
|
||||
assert "from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool" not in new_code # conditional from import
|
||||
Loading…
Reference in a new issue