567 lines
22 KiB
Python
567 lines
22 KiB
Python
import libcst as cst
|
|
import pytest
|
|
|
|
from optimizer.models import CodeExplanationAndID
|
|
from optimizer.postprocess import DocstringTransformer, DocstringVisitor
|
|
|
|
|
|
def test_function_docstring_preservation() -> None:
    """A function docstring missing from the optimized code is carried over from the original."""
    documented_source = """
def example_function():
    \"\"\"This is a docstring for the example function.\"\"\"
    return 42
"""
    stripped_source = """
def example_function():
    return 43
"""
    # Collect the docstrings present in the original module.
    collector = DocstringVisitor()
    cst.parse_module(documented_source).visit(collector)

    # Run the optimized module through a transformer seeded with them.
    restorer = DocstringTransformer(collector.original_docstrings)
    restored_module = cst.parse_module(stripped_source).visit(restorer)

    assert "This is a docstring for the example function" in restored_module.code
|
|
|
|
|
|
def test_function_docstring_in_both_functions() -> None:
    """The original function docstring is present after transforming code that has a different one."""
    documented_source = """
def example_function():
    \"\"\"This is a docstring for the example function.\"\"\"
    return 42
"""
    # The optimized variant carries its own, different docstring.
    rewritten_source = """
def example_function():
    \"\"\"We dont want this docstring, we want the original one\"\"\"
    return 43
"""
    # Collect the docstrings present in the original module.
    collector = DocstringVisitor()
    cst.parse_module(documented_source).visit(collector)

    # Run the optimized module through a transformer seeded with them.
    restorer = DocstringTransformer(collector.original_docstrings)
    restored_module = cst.parse_module(rewritten_source).visit(restorer)

    assert "This is a docstring for the example function" in restored_module.code
|
|
|
|
|
|
def test_class_docstring_preservation() -> None:
    """A class docstring missing from the optimized code is carried over from the original."""
    documented_source = """
class ExampleClass:
    \"\"\"This is a docstring for the example class.\"\"\"
    def method(self):
        return 42
"""
    stripped_source = """
class ExampleClass:
    def method(self):
        return 42
"""
    # Collect the docstrings present in the original module.
    collector = DocstringVisitor()
    cst.parse_module(documented_source).visit(collector)

    # Run the optimized module through a transformer seeded with them.
    restorer = DocstringTransformer(collector.original_docstrings)
    restored_module = cst.parse_module(stripped_source).visit(restorer)

    assert "This is a docstring for the example class" in restored_module.code
|
|
|
|
|
|
def test_class_docstring_in_both_functions() -> None:
    """The original class docstring is present after transforming code that has a different one."""
    documented_source = """
class ExampleClass:
    \"\"\"This is a docstring for the example class.\"\"\"
    def method(self):
        return 42
"""
    # The optimized variant carries its own, different class docstring.
    rewritten_source = """
class ExampleClass:
    \"\"\"This is a docstring we dont want, we want the original one\"\"\"
    def method(self):
        return 42
"""
    # Collect the docstrings present in the original module.
    collector = DocstringVisitor()
    cst.parse_module(documented_source).visit(collector)

    # Run the optimized module through a transformer seeded with them.
    restorer = DocstringTransformer(collector.original_docstrings)
    restored_module = cst.parse_module(rewritten_source).visit(restorer)

    assert "This is a docstring for the example class" in restored_module.code
|
|
|
|
|
|
def test_method_docstring_preservation() -> None:
    """A method docstring missing from the optimized code is carried over from the original."""
    documented_source = """
class ExampleClass:
    def method(self):
        \"\"\"This is a docstring for the method.\"\"\"
        return 42
"""
    stripped_source = """
class ExampleClass:
    def method(self):
        return 42
"""
    # Collect the docstrings present in the original module.
    collector = DocstringVisitor()
    cst.parse_module(documented_source).visit(collector)

    # Run the optimized module through a transformer seeded with them.
    restorer = DocstringTransformer(collector.original_docstrings)
    restored_module = cst.parse_module(stripped_source).visit(restorer)

    assert "This is a docstring for the method" in restored_module.code
|
|
|
|
|
|
def test_fix_missing_docstring_pipeline_function() -> None:
    """Exercise `fix_missing_docstring` end to end on a single optimization candidate."""
    from optimizer.postprocess import fix_missing_docstring

    documented_source = """
def example_function():
    \"\"\"This is a docstring for the example function.\"\"\"
    return 42
"""
    stripped_source = """
def example_function():
    return 42
"""

    # Wrap the docstring-free module as the sole candidate handed to the pipeline.
    candidate = CodeExplanationAndID(
        cst_module=cst.parse_module(stripped_source),
        explanation="Removed unnecessary docstring",
        id="test-id",
    )

    fixed_candidates = fix_missing_docstring(documented_source, [candidate])

    # The first returned candidate must contain the original docstring text.
    assert "This is a docstring for the example function" in fixed_candidates[0].cst_module.code
|
|
|
|
|
|
@pytest.mark.skip(
    reason="This currently results in an exception and should be fixed. This test case reproduces the error"
)
def test_docstring_exception() -> None:
    """Reproducer for an exception raised by `fix_missing_docstring` (currently skipped).

    The same source is used both as the original code and, parsed, as the
    module of the single candidate handed to the pipeline. The test makes no
    assertion; it only needs the call to complete.
    """
    # Imported here because the module only imports DocstringTransformer and
    # DocstringVisitor at the top; without this the body would fail with a
    # NameError instead of reproducing the real error once the skip is removed.
    from optimizer.postprocess import fix_missing_docstring

    # derived from optimizing https://github.com/pydantic/pydantic-ai/blob/39e28771e538a3a4af98222ca565ecfa402d9c08/pydantic_ai_slim/pydantic_ai/agent.py#L1717
    original_code = """import dataclasses
import warnings
from pydantic_ai import _agent_graph, _output, _system_prompt, _utils, models, result
from collections.abc import Sequence
from contextvars import ContextVar
from pydantic_ai._agent_graph import HistoryProcessor
from pydantic_ai.mcp import MCPServer
from pydantic_ai.models.instrumented import InstrumentationSettings
from pydantic_ai.output import OutputDataT, OutputSpec
from pydantic_ai.settings import ModelSettings
from pydantic_ai.tools import AgentDepsT, Tool, ToolFuncEither, ToolsPrepareFunc
from pydantic_graph import End
from typing import Any, Callable, Generic, final, overload
from typing_extensions import TypeIs, deprecated
@final
@dataclasses.dataclass(init=False)
class Agent(Generic[AgentDepsT, OutputDataT]):
    @overload
    def __init__(
        self,
        model: models.Model | models.KnownModelName | str | None = None,
        *,
        output_type: OutputSpec[OutputDataT] = str,
        instructions: str
        | _system_prompt.SystemPromptFunc[AgentDepsT]
        | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
        | None = None,
        system_prompt: str | Sequence[str] = (),
        deps_type: type[AgentDepsT] = NoneType,
        name: str | None = None,
        model_settings: ModelSettings | None = None,
        retries: int = 1,
        output_retries: int | None = None,
        tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
        prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
        mcp_servers: Sequence[MCPServer] = (),
        defer_model_check: bool = False,
        end_strategy: EndStrategy = 'early',
        instrument: InstrumentationSettings | bool | None = None,
        history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
    ) -> None: ...
    @overload
    @deprecated(
        '`result_type`, `result_tool_name`, `result_tool_description` & `result_retries` are deprecated, use `output_type` instead. `result_retries` is deprecated, use `output_retries` instead.'
    )
    def __init__(
        self,
        model: models.Model | models.KnownModelName | str | None = None,
        *,
        result_type: type[OutputDataT] = str,
        instructions: str
        | _system_prompt.SystemPromptFunc[AgentDepsT]
        | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
        | None = None,
        system_prompt: str | Sequence[str] = (),
        deps_type: type[AgentDepsT] = NoneType,
        name: str | None = None,
        model_settings: ModelSettings | None = None,
        retries: int = 1,
        result_tool_name: str = _output.DEFAULT_OUTPUT_TOOL_NAME,
        result_tool_description: str | None = None,
        result_retries: int | None = None,
        tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
        prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
        mcp_servers: Sequence[MCPServer] = (),
        defer_model_check: bool = False,
        end_strategy: EndStrategy = 'early',
        instrument: InstrumentationSettings | bool | None = None,
        history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
    ) -> None: ...
    def __init__(
        self,
        model: models.Model | models.KnownModelName | str | None = None,
        *,
        # TODO change this back to `output_type: _output.OutputType[OutputDataT] = str,` when we remove the overloads
        output_type: Any = str,
        instructions: str
        | _system_prompt.SystemPromptFunc[AgentDepsT]
        | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
        | None = None,
        system_prompt: str | Sequence[str] = (),
        deps_type: type[AgentDepsT] = NoneType,
        name: str | None = None,
        model_settings: ModelSettings | None = None,
        retries: int = 1,
        output_retries: int | None = None,
        tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
        prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
        mcp_servers: Sequence[MCPServer] = (),
        defer_model_check: bool = False,
        end_strategy: EndStrategy = 'early',
        instrument: InstrumentationSettings | bool | None = None,
        history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
        **_deprecated_kwargs: Any,
    ):
        \"\"\"Create an agent.

        Args:
            model: The default model to use for this agent, if not provide,
                you must provide the model when calling it. We allow `str` here since the actual list of allowed models changes frequently.
            output_type: The type of the output data, used to validate the data returned by the model,
                defaults to `str`.
            instructions: Instructions to use for this agent, you can also register instructions via a function with
                [`instructions`][pydantic_ai.Agent.instructions].
            system_prompt: Static system prompts to use for this agent, you can also register system
                prompts via a function with [`system_prompt`][pydantic_ai.Agent.system_prompt].
            deps_type: The type used for dependency injection, this parameter exists solely to allow you to fully
                parameterize the agent, and therefore get the best out of static type checking.
                If you're not using deps, but want type checking to pass, you can set `deps=None` to satisfy Pyright
                or add a type hint `: Agent[None, <return type>]`.
            name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame
                when the agent is first run.
            model_settings: Optional model request settings to use for this agent's runs, by default.
            retries: The default number of retries to allow before raising an error.
            output_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
            tools: Tools to register with the agent, you can also register tools via the decorators
                [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
            prepare_tools: custom method to prepare the tool definition of all tools for each step.
                This is useful if you want to customize the definition of multiple tools or you want to register
                a subset of tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc]
            mcp_servers: MCP servers to register with the agent. You should register a [`MCPServer`][pydantic_ai.mcp.MCPServer]
                for each server you want the agent to connect to.
            defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model,
                it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
                which checks for the necessary environment variables. Set this to `false`
                to defer the evaluation until the first run. Useful if you want to
                [override the model][pydantic_ai.Agent.override] for testing.
            end_strategy: Strategy for handling tool calls that are requested alongside a final result.
                See [`EndStrategy`][pydantic_ai.agent.EndStrategy] for more information.
            instrument: Set to True to automatically instrument with OpenTelemetry,
                which will use Logfire if it's configured.
                Set to an instance of [`InstrumentationSettings`][pydantic_ai.agent.InstrumentationSettings] to customize.
                If this isn't set, then the last value set by
                [`Agent.instrument_all()`][pydantic_ai.Agent.instrument_all]
                will be used, which defaults to False.
                See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info.
            history_processors: Optional list of callables to process the message history before sending it to the model.
                Each processor takes a list of messages and returns a modified list of messages.
                Processors can be sync or async and are applied in sequence.
        \"\"\"
        if model is None or defer_model_check:
            self.model = model
        else:
            self.model = models.infer_model(model)

        self.end_strategy = end_strategy
        self.name = name
        self.model_settings = model_settings

        if 'result_type' in _deprecated_kwargs:
            if output_type is not str:  # pragma: no cover
                raise TypeError('`result_type` and `output_type` cannot be set at the same time.')
            warnings.warn('`result_type` is deprecated, use `output_type` instead', DeprecationWarning)
            output_type = _deprecated_kwargs.pop('result_type')

        self.output_type = output_type

        self.instrument = instrument

        self._deps_type = deps_type

        self._deprecated_result_tool_name = _deprecated_kwargs.pop('result_tool_name', None)
        if self._deprecated_result_tool_name is not None:
            warnings.warn(
                '`result_tool_name` is deprecated, use `output_type` with `ToolOutput` instead',
                DeprecationWarning,
            )

        self._deprecated_result_tool_description = _deprecated_kwargs.pop('result_tool_description', None)
        if self._deprecated_result_tool_description is not None:
            warnings.warn(
                '`result_tool_description` is deprecated, use `output_type` with `ToolOutput` instead',
                DeprecationWarning,
            )
        result_retries = _deprecated_kwargs.pop('result_retries', None)
        if result_retries is not None:
            if output_retries is not None:  # pragma: no cover
                raise TypeError('`output_retries` and `result_retries` cannot be set at the same time.')
            warnings.warn('`result_retries` is deprecated, use `max_result_retries` instead', DeprecationWarning)
            output_retries = result_retries

        default_output_mode = (
            self.model.profile.default_structured_output_mode if isinstance(self.model, models.Model) else None
        )
        _utils.validate_empty_kwargs(_deprecated_kwargs)

        self._output_schema = _output.OutputSchema[OutputDataT].build(
            output_type,
            default_mode=default_output_mode,
            name=self._deprecated_result_tool_name,
            description=self._deprecated_result_tool_description,
        )
        self._output_validators = []

        self._instructions = ''
        self._instructions_functions = []
        if isinstance(instructions, (str, Callable)):
            instructions = [instructions]
        for instruction in instructions or []:
            if isinstance(instruction, str):
                self._instructions += instruction + '\\n'
            else:
                self._instructions_functions.append(_system_prompt.SystemPromptRunner(instruction))
        self._instructions = self._instructions.strip() or None

        self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
        self._system_prompt_functions = []
        self._system_prompt_dynamic_functions = {}

        self._function_tools = {}

        self._default_retries = retries
        self._max_result_retries = output_retries if output_retries is not None else retries
        self._mcp_servers = mcp_servers
        self._prepare_tools = prepare_tools
        self.history_processors = history_processors or []
        for tool in tools:
            if isinstance(tool, Tool):
                self._register_tool(tool)
            else:
                self._register_tool(Tool(tool))

        self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None)
        self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None)
    @staticmethod
    def is_end_node(
        node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
    ) -> TypeIs[End[result.FinalResult[S]]]:
        \"\"\"Check if the node is a `End`, narrowing the type if it is.
        This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
        \"\"\"
        return isinstance(node, End)
"""
    original_cst = cst.parse_module(original_code)
    code_explanations = [
        CodeExplanationAndID(cst_module=original_cst, explanation="Removed unnecessary docstring", id="test-id")
    ]

    # Apply the pipeline function
    fix_missing_docstring(original_code, code_explanations)
    # TODO : This test case fails
|
|
|
|
|
|
def test_benjamin_button() -> None:
    """Smoke test: run the transformer over code with a multi-line original docstring.

    One function lost its docstring in the optimized code and the other had
    its docstring changed.
    """
    documented_source = """
def test_1():
    \"\"\"useful docstring
    has a lot of multi-line details.\"\"\"
    pass
def test_2():
    \"\"\"useful docstring\"\"\"
    pass
"""
    altered_source = """
def test_1():
    pass
def test_2():
    \"\"\"useful docstring v2\"\"\"
    pass
"""
    # Collect the docstrings present in the original module.
    collector = DocstringVisitor()
    cst.parse_module(documented_source).visit(collector)

    # Run the optimized module through a transformer seeded with them.
    restorer = DocstringTransformer(collector.original_docstrings)
    cst.parse_module(altered_source).visit(restorer)

    # NOTE(review): the transformed module is discarded and nothing about its
    # content is asserted, so this only checks that the visit does not raise.
    assert True
|
|
|
|
|
|
def test_multiple_functions_and_classes() -> None:
    """Docstrings of several functions, a class and its methods are all present after the transform."""
    documented_source = """
def function1():
    \"\"\"Docstring for function1.\"\"\"
    return 1

class Class1:
    \"\"\"Docstring for Class1.\"\"\"
    def method1(self):
        \"\"\"Docstring for method1.\"\"\"
        return 2

    def method2(self):
        \"\"\"Docstring for method2.\"\"\"
        return 3

def function2():
    \"\"\"Docstring for function2.\"\"\"
    return 4
"""
    stripped_source = """
def function1():
    return 1

class Class1:
    def method1(self):
        return 2

    def method2(self):
        return 3

def function2():
    return 4
"""
    # Collect the docstrings present in the original module.
    collector = DocstringVisitor()
    cst.parse_module(documented_source).visit(collector)

    # Run the optimized module through a transformer seeded with them.
    restorer = DocstringTransformer(collector.original_docstrings)
    restored_code = cst.parse_module(stripped_source).visit(restorer).code

    # Each definition's docstring must appear in the output, checked in source order.
    for owner in ("function1", "Class1", "method1", "method2", "function2"):
        assert f"Docstring for {owner}" in restored_code
|
|
|
|
|
|
# We can add fixtures to reduce code duplication
|
|
@pytest.fixture
def code_with_docstrings():
    """Provide documented sample sources keyed by "function", "class" and "method"."""
    function_source = """
def example_function():
    \"\"\"This is a docstring for the example function.\"\"\"
    return 42
"""
    class_source = """
class ExampleClass:
    \"\"\"This is a docstring for the example class.\"\"\"
    def method(self):
        return 42
"""
    method_source = """
class ExampleClass:
    def method(self):
        \"\"\"This is a docstring for the method.\"\"\"
        return 42
"""
    return {"function": function_source, "class": class_source, "method": method_source}
|
|
|
|
|
|
@pytest.fixture
def code_without_docstrings():
    """Provide docstring-free sample sources keyed by "function", "class" and "method"."""
    function_source = """
def example_function():
    return 42
"""
    # The "class" and "method" entries hold equal text.
    class_source = """
class ExampleClass:
    def method(self):
        return 42
"""
    method_source = """
class ExampleClass:
    def method(self):
        return 42
"""
    return {"function": function_source, "class": class_source, "method": method_source}
|
|
|
|
|
|
def test_docstring_preservation_with_fixtures(code_with_docstrings, code_without_docstrings) -> None:
    """Check docstring restoration for the function, class and method samples from the fixtures."""
    for sample_kind in ("function", "class", "method"):
        documented_source = code_with_docstrings[sample_kind]
        stripped_source = code_without_docstrings[sample_kind]

        # Collect the docstrings present in the documented sample.
        collector = DocstringVisitor()
        cst.parse_module(documented_source).visit(collector)

        # Run the docstring-free sample through a transformer seeded with them.
        restorer = DocstringTransformer(collector.original_docstrings)
        restored_module = cst.parse_module(stripped_source).visit(restorer)

        # All three sample docstrings share this prefix.
        assert "This is a docstring for the" in restored_module.code
|