mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Merge commit '6020c4fa' into sync-main-batch-3
This commit is contained in:
commit
85d1d4fbf6
41 changed files with 1533 additions and 1974 deletions
26
.github/workflows/claude.yml
vendored
26
.github/workflows/claude.yml
vendored
|
|
@ -42,11 +42,17 @@ jobs:
|
|||
uv venv --seed
|
||||
uv sync
|
||||
|
||||
- name: Configure AWS Credentials
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_ROLE_TO_ASSUME }}
|
||||
aws-region: ${{ secrets.AWS_REGION }}
|
||||
|
||||
- name: Run Claude Code
|
||||
id: claude
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
use_foundry: "true"
|
||||
use_bedrock: "true"
|
||||
use_sticky_comment: true
|
||||
allowed_bots: "claude[bot],codeflash-ai[bot]"
|
||||
prompt: |
|
||||
|
|
@ -173,12 +179,9 @@ jobs:
|
|||
2. For each optimization PR:
|
||||
- Check if CI is passing: `gh pr checks <number>`
|
||||
- If all checks pass, merge it: `gh pr merge <number> --squash --delete-branch`
|
||||
claude_args: '--model claude-opus-4-6 --allowedTools "mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*),Bash(gh pr checks:*),Bash(gh pr merge:*),Bash(gh issue view:*),Bash(gh issue list:*),Bash(gh api:*),Bash(uv run prek *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(uv run pytest *),Bash(git status*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git diff *),Bash(git checkout *),Read,Glob,Grep,Edit"'
|
||||
claude_args: '--model us.anthropic.claude-opus-4-6-v1 --allowedTools "mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*),Bash(gh pr checks:*),Bash(gh pr merge:*),Bash(gh issue view:*),Bash(gh issue list:*),Bash(gh api:*),Bash(uv run prek *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(uv run pytest *),Bash(git status*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git diff *),Bash(git checkout *),Read,Glob,Grep,Edit"'
|
||||
additional_permissions: |
|
||||
actions: read
|
||||
env:
|
||||
ANTHROPIC_FOUNDRY_API_KEY: ${{ secrets.AZURE_ANTHROPIC_API_KEY }}
|
||||
ANTHROPIC_FOUNDRY_BASE_URL: ${{ secrets.AZURE_ANTHROPIC_ENDPOINT }}
|
||||
|
||||
# @claude mentions (can edit and push) - restricted to maintainers only
|
||||
claude-mention:
|
||||
|
|
@ -240,14 +243,17 @@ jobs:
|
|||
uv venv --seed
|
||||
uv sync
|
||||
|
||||
- name: Configure AWS Credentials
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_ROLE_TO_ASSUME }}
|
||||
aws-region: ${{ secrets.AWS_REGION }}
|
||||
|
||||
- name: Run Claude Code
|
||||
id: claude
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
use_foundry: "true"
|
||||
claude_args: '--model claude-opus-4-6 --allowedTools "Read,Edit,Write,Glob,Grep,Bash(git status*),Bash(git diff*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git log*),Bash(git merge*),Bash(git fetch*),Bash(git checkout*),Bash(git branch*),Bash(uv run prek *),Bash(prek *),Bash(uv run ruff *),Bash(uv run pytest *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(gh pr comment*),Bash(gh pr view*),Bash(gh pr diff*),Bash(gh pr merge*),Bash(gh pr close*)"'
|
||||
use_bedrock: "true"
|
||||
claude_args: '--model us.anthropic.claude-opus-4-6-v1 --allowedTools "Read,Edit,Write,Glob,Grep,Bash(git status*),Bash(git diff*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git log*),Bash(git merge*),Bash(git fetch*),Bash(git checkout*),Bash(git branch*),Bash(uv run prek *),Bash(prek *),Bash(uv run ruff *),Bash(uv run pytest *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(gh pr comment*),Bash(gh pr view*),Bash(gh pr diff*),Bash(gh pr merge*),Bash(gh pr close*)"'
|
||||
additional_permissions: |
|
||||
actions: read
|
||||
env:
|
||||
ANTHROPIC_FOUNDRY_API_KEY: ${{ secrets.AZURE_ANTHROPIC_API_KEY }}
|
||||
ANTHROPIC_FOUNDRY_BASE_URL: ${{ secrets.AZURE_ANTHROPIC_ENDPOINT }}
|
||||
|
|
|
|||
12
.github/workflows/duplicate-code-detector.yml
vendored
12
.github/workflows/duplicate-code-detector.yml
vendored
|
|
@ -42,10 +42,16 @@ jobs:
|
|||
}
|
||||
EOF
|
||||
|
||||
- name: Configure AWS Credentials
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_ROLE_TO_ASSUME }}
|
||||
aws-region: ${{ secrets.AWS_REGION }}
|
||||
|
||||
- name: Run Claude Code
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
use_foundry: "true"
|
||||
use_bedrock: "true"
|
||||
use_sticky_comment: true
|
||||
allowed_bots: "claude[bot],codeflash-ai[bot]"
|
||||
claude_args: '--mcp-config /tmp/mcp-config/mcp-servers.json --allowedTools "Read,Glob,Grep,Bash(git diff:*),Bash(git log:*),Bash(git show:*),Bash(wc *),Bash(find *),mcp__serena__*"'
|
||||
|
|
@ -105,10 +111,6 @@ jobs:
|
|||
- Concrete refactoring suggestion
|
||||
|
||||
If no significant duplication is found, say so briefly. Do not create issues — just comment on the PR.
|
||||
env:
|
||||
ANTHROPIC_FOUNDRY_API_KEY: ${{ secrets.AZURE_ANTHROPIC_API_KEY }}
|
||||
ANTHROPIC_FOUNDRY_BASE_URL: ${{ secrets.AZURE_ANTHROPIC_ENDPOINT }}
|
||||
|
||||
- name: Stop Serena
|
||||
if: always()
|
||||
run: docker stop serena && docker rm serena || true
|
||||
|
|
|
|||
50
.github/workflows/js-tests.yml
vendored
50
.github/workflows/js-tests.yml
vendored
|
|
@ -1,50 +0,0 @@
|
|||
name: JavaScript/TypeScript Integration Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
js-integration-tests:
|
||||
name: JS/TS Integration Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
uv venv --seed
|
||||
uv sync
|
||||
|
||||
- name: Install npm dependencies for test projects
|
||||
run: |
|
||||
npm install --prefix code_to_optimize/js/code_to_optimize_js
|
||||
npm install --prefix code_to_optimize/js/code_to_optimize_ts
|
||||
npm install --prefix code_to_optimize/js/code_to_optimize_vitest
|
||||
|
||||
- name: Run JavaScript integration tests
|
||||
run: |
|
||||
uv run pytest tests/languages/javascript/ -v
|
||||
uv run pytest tests/test_languages/test_vitest_e2e.py -v
|
||||
uv run pytest tests/test_languages/test_javascript_e2e.py -v
|
||||
uv run pytest tests/test_languages/test_javascript_support.py -v
|
||||
uv run pytest tests/code_utils/test_config_js.py -v
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -274,3 +274,5 @@ tessl.json
|
|||
|
||||
# Tessl auto-generates AGENTS.md on install; ignore to avoid cluttering git status
|
||||
AGENTS.md
|
||||
.serena/
|
||||
.codeflash/
|
||||
|
|
|
|||
98
LICENSE
Normal file
98
LICENSE
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
Business Source License 1.1
|
||||
|
||||
Parameters
|
||||
|
||||
Licensor: CodeFlash Inc.
|
||||
Licensed Work: Codeflash Client version 0.20.x
|
||||
The Licensed Work is (c) 2024 CodeFlash Inc.
|
||||
|
||||
Additional Use Grant: None. Production use of the Licensed Work is only permitted
|
||||
if you have entered into a separate written agreement
|
||||
with CodeFlash Inc. for production use in connection
|
||||
with a subscription to CodeFlash's Code Optimization
|
||||
Platform. Please visit codeflash.ai for further
|
||||
information.
|
||||
|
||||
Change Date: 2030-01-26
|
||||
|
||||
Change License: MIT
|
||||
|
||||
Notice
|
||||
|
||||
The Business Source License (this document, or the “License”) is not an Open
|
||||
Source license. However, the Licensed Work will eventually be made available
|
||||
under an Open Source License, as stated in this License.
|
||||
|
||||
License text copyright (c) 2017 MariaDB Corporation Ab, All Rights Reserved.
|
||||
“Business Source License” is a trademark of MariaDB Corporation Ab.
|
||||
|
||||
-----------------------------------------------------------------------------
|
||||
|
||||
Business Source License 1.1
|
||||
|
||||
Terms
|
||||
|
||||
The Licensor hereby grants you the right to copy, modify, create derivative
|
||||
works, redistribute, and make non-production use of the Licensed Work. The
|
||||
Licensor may make an Additional Use Grant, above, permitting limited
|
||||
production use.
|
||||
|
||||
Effective on the Change Date, or the fourth anniversary of the first publicly
|
||||
available distribution of a specific version of the Licensed Work under this
|
||||
License, whichever comes first, the Licensor hereby grants you rights under
|
||||
the terms of the Change License, and the rights granted in the paragraph
|
||||
above terminate.
|
||||
|
||||
If your use of the Licensed Work does not comply with the requirements
|
||||
currently in effect as described in this License, you must purchase a
|
||||
commercial license from the Licensor, its affiliated entities, or authorized
|
||||
resellers, or you must refrain from using the Licensed Work.
|
||||
|
||||
All copies of the original and modified Licensed Work, and derivative works
|
||||
of the Licensed Work, are subject to this License. This License applies
|
||||
separately for each version of the Licensed Work and the Change Date may vary
|
||||
for each version of the Licensed Work released by Licensor.
|
||||
|
||||
You must conspicuously display this License on each original or modified copy
|
||||
of the Licensed Work. If you receive the Licensed Work in original or
|
||||
modified form from a third party, the terms and conditions set forth in this
|
||||
License apply to your use of that work.
|
||||
|
||||
Any use of the Licensed Work in violation of this License will automatically
|
||||
terminate your rights under this License for the current and all other
|
||||
versions of the Licensed Work.
|
||||
|
||||
This License does not grant you any right in any trademark or logo of
|
||||
Licensor or its affiliates (provided that you may use a trademark or logo of
|
||||
Licensor as expressly required by this License).
|
||||
|
||||
TO THE EXTENT PERMITTED BY APPLICABLE LAW, THE LICENSED WORK IS PROVIDED ON
|
||||
AN “AS IS” BASIS. LICENSOR HEREBY DISCLAIMS ALL WARRANTIES AND CONDITIONS,
|
||||
EXPRESS OR IMPLIED, INCLUDING (WITHOUT LIMITATION) WARRANTIES OF
|
||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, AND
|
||||
TITLE.
|
||||
|
||||
MariaDB hereby grants you permission to use this License’s text to license
|
||||
your works, and to refer to it using the trademark “Business Source License”,
|
||||
as long as you comply with the Covenants of Licensor below.
|
||||
|
||||
Covenants of Licensor
|
||||
|
||||
In consideration of the right to use this License’s text and the “Business
|
||||
Source License” name and trademark, Licensor covenants to MariaDB, and to all
|
||||
other recipients of the licensed work to be provided by Licensor:
|
||||
|
||||
1. To specify as the Change License the GPL Version 2.0 or any later version,
|
||||
or a license that is compatible with GPL Version 2.0 or a later version,
|
||||
where “compatible” means that software provided under the Change License can
|
||||
be included in a program with software provided under GPL Version 2.0 or a
|
||||
later version. Licensor may specify additional Change Licenses without
|
||||
limitation.
|
||||
|
||||
2. To either: (a) specify an additional grant of rights to use that does not
|
||||
impose any additional restriction on the right granted in this License, as
|
||||
the Additional Use Grant; or (b) insert the text “None”.
|
||||
|
||||
3. To specify a Change Date.
|
||||
|
||||
4. Not to modify this License in any other way.
|
||||
98
codeflash-benchmark/LICENSE
Normal file
98
codeflash-benchmark/LICENSE
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
Business Source License 1.1
|
||||
|
||||
Parameters
|
||||
|
||||
Licensor: CodeFlash Inc.
|
||||
Licensed Work: Codeflash Client version 0.20.x
|
||||
The Licensed Work is (c) 2024 CodeFlash Inc.
|
||||
|
||||
Additional Use Grant: None. Production use of the Licensed Work is only permitted
|
||||
if you have entered into a separate written agreement
|
||||
with CodeFlash Inc. for production use in connection
|
||||
with a subscription to CodeFlash's Code Optimization
|
||||
Platform. Please visit codeflash.ai for further
|
||||
information.
|
||||
|
||||
Change Date: 2030-01-26
|
||||
|
||||
Change License: MIT
|
||||
|
||||
Notice
|
||||
|
||||
The Business Source License (this document, or the “License”) is not an Open
|
||||
Source license. However, the Licensed Work will eventually be made available
|
||||
under an Open Source License, as stated in this License.
|
||||
|
||||
License text copyright (c) 2017 MariaDB Corporation Ab, All Rights Reserved.
|
||||
“Business Source License” is a trademark of MariaDB Corporation Ab.
|
||||
|
||||
-----------------------------------------------------------------------------
|
||||
|
||||
Business Source License 1.1
|
||||
|
||||
Terms
|
||||
|
||||
The Licensor hereby grants you the right to copy, modify, create derivative
|
||||
works, redistribute, and make non-production use of the Licensed Work. The
|
||||
Licensor may make an Additional Use Grant, above, permitting limited
|
||||
production use.
|
||||
|
||||
Effective on the Change Date, or the fourth anniversary of the first publicly
|
||||
available distribution of a specific version of the Licensed Work under this
|
||||
License, whichever comes first, the Licensor hereby grants you rights under
|
||||
the terms of the Change License, and the rights granted in the paragraph
|
||||
above terminate.
|
||||
|
||||
If your use of the Licensed Work does not comply with the requirements
|
||||
currently in effect as described in this License, you must purchase a
|
||||
commercial license from the Licensor, its affiliated entities, or authorized
|
||||
resellers, or you must refrain from using the Licensed Work.
|
||||
|
||||
All copies of the original and modified Licensed Work, and derivative works
|
||||
of the Licensed Work, are subject to this License. This License applies
|
||||
separately for each version of the Licensed Work and the Change Date may vary
|
||||
for each version of the Licensed Work released by Licensor.
|
||||
|
||||
You must conspicuously display this License on each original or modified copy
|
||||
of the Licensed Work. If you receive the Licensed Work in original or
|
||||
modified form from a third party, the terms and conditions set forth in this
|
||||
License apply to your use of that work.
|
||||
|
||||
Any use of the Licensed Work in violation of this License will automatically
|
||||
terminate your rights under this License for the current and all other
|
||||
versions of the Licensed Work.
|
||||
|
||||
This License does not grant you any right in any trademark or logo of
|
||||
Licensor or its affiliates (provided that you may use a trademark or logo of
|
||||
Licensor as expressly required by this License).
|
||||
|
||||
TO THE EXTENT PERMITTED BY APPLICABLE LAW, THE LICENSED WORK IS PROVIDED ON
|
||||
AN “AS IS” BASIS. LICENSOR HEREBY DISCLAIMS ALL WARRANTIES AND CONDITIONS,
|
||||
EXPRESS OR IMPLIED, INCLUDING (WITHOUT LIMITATION) WARRANTIES OF
|
||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, AND
|
||||
TITLE.
|
||||
|
||||
MariaDB hereby grants you permission to use this License’s text to license
|
||||
your works, and to refer to it using the trademark “Business Source License”,
|
||||
as long as you comply with the Covenants of Licensor below.
|
||||
|
||||
Covenants of Licensor
|
||||
|
||||
In consideration of the right to use this License’s text and the “Business
|
||||
Source License” name and trademark, Licensor covenants to MariaDB, and to all
|
||||
other recipients of the licensed work to be provided by Licensor:
|
||||
|
||||
1. To specify as the Change License the GPL Version 2.0 or any later version,
|
||||
or a license that is compatible with GPL Version 2.0 or a later version,
|
||||
where “compatible” means that software provided under the Change License can
|
||||
be included in a program with software provided under GPL Version 2.0 or a
|
||||
later version. Licensor may specify additional Change Licenses without
|
||||
limitation.
|
||||
|
||||
2. To either: (a) specify an additional grant of rights to use that does not
|
||||
impose any additional restriction on the right granted in this License, as
|
||||
the Additional Use Grant; or (b) insert the text “None”.
|
||||
|
||||
3. To specify a Change Date.
|
||||
|
||||
4. Not to modify this License in any other way.
|
||||
15
codeflash-benchmark/README.md
Normal file
15
codeflash-benchmark/README.md
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
# CodeFlash Benchmark
|
||||
|
||||
A pytest benchmarking plugin for [CodeFlash](https://codeflash.ai) - automatic code performance optimization.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install codeflash-benchmark
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
This plugin provides benchmarking capabilities for pytest tests used by CodeFlash's optimization pipeline.
|
||||
|
||||
For more information, visit [codeflash.ai](https://codeflash.ai).
|
||||
|
|
@ -1,32 +1,32 @@
|
|||
[project]
|
||||
name = "codeflash-benchmark"
|
||||
version = "0.2.0"
|
||||
description = "Pytest benchmarking plugin for codeflash.ai - automatic code performance optimization"
|
||||
authors = [{ name = "CodeFlash Inc.", email = "contact@codeflash.ai" }]
|
||||
requires-python = ">=3.9"
|
||||
readme = "README.md"
|
||||
license = {text = "BSL-1.1"}
|
||||
keywords = [
|
||||
"codeflash",
|
||||
"benchmark",
|
||||
"pytest",
|
||||
"performance",
|
||||
"testing",
|
||||
]
|
||||
dependencies = [
|
||||
"pytest>=7.0.0,!=8.3.4",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://codeflash.ai"
|
||||
Repository = "https://github.com/codeflash-ai/codeflash-benchmark"
|
||||
|
||||
[project.entry-points.pytest11]
|
||||
codeflash-benchmark = "codeflash_benchmark.plugin"
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=45", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.setuptools]
|
||||
packages = ["codeflash_benchmark"]
|
||||
[project]
|
||||
name = "codeflash-benchmark"
|
||||
version = "0.2.0"
|
||||
description = "Pytest benchmarking plugin for codeflash.ai - automatic code performance optimization"
|
||||
authors = [{ name = "CodeFlash Inc.", email = "contact@codeflash.ai" }]
|
||||
requires-python = ">=3.9"
|
||||
readme = "README.md"
|
||||
license-files = ["LICENSE"]
|
||||
keywords = [
|
||||
"codeflash",
|
||||
"benchmark",
|
||||
"pytest",
|
||||
"performance",
|
||||
"testing",
|
||||
]
|
||||
dependencies = [
|
||||
"pytest>=7.0.0,!=8.3.4",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://codeflash.ai"
|
||||
Repository = "https://github.com/codeflash-ai/codeflash-benchmark"
|
||||
|
||||
[project.entry-points.pytest11]
|
||||
codeflash-benchmark = "codeflash_benchmark.plugin"
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=45", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.setuptools]
|
||||
packages = ["codeflash_benchmark"]
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ from enum import Enum
|
|||
from typing import Any, Union
|
||||
|
||||
MAX_TEST_RUN_ITERATIONS = 5
|
||||
OPTIMIZATION_CONTEXT_TOKEN_LIMIT = 16000
|
||||
TESTGEN_CONTEXT_TOKEN_LIMIT = 16000
|
||||
OPTIMIZATION_CONTEXT_TOKEN_LIMIT = 48000
|
||||
TESTGEN_CONTEXT_TOKEN_LIMIT = 48000
|
||||
INDIVIDUAL_TESTCASE_TIMEOUT = 15 # For Python pytest
|
||||
JAVA_TESTCASE_TIMEOUT = 120 # Java Maven tests need more time due to startup overhead
|
||||
MAX_FUNCTION_TEST_SECONDS = 60
|
||||
|
|
|
|||
|
|
@ -1518,73 +1518,207 @@ class AsyncDecoratorAdder(cst.CSTTransformer):
|
|||
return False
|
||||
|
||||
|
||||
class AsyncDecoratorImportAdder(cst.CSTTransformer):
|
||||
"""Transformer that adds the import for async decorators."""
|
||||
ASYNC_HELPER_INLINE_CODE = """import asyncio
|
||||
import gc
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
def __init__(self, mode: TestingMode = TestingMode.BEHAVIOR) -> None:
|
||||
self.mode = mode
|
||||
self.has_import = False
|
||||
import dill as pickle
|
||||
|
||||
def _get_decorator_name(self) -> str:
|
||||
"""Get the decorator name based on the testing mode."""
|
||||
if self.mode == TestingMode.BEHAVIOR:
|
||||
return "codeflash_behavior_async"
|
||||
if self.mode == TestingMode.CONCURRENCY:
|
||||
return "codeflash_concurrency_async"
|
||||
return "codeflash_performance_async"
|
||||
|
||||
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
|
||||
# Check if the async decorator import is already present
|
||||
if (
|
||||
isinstance(node.module, cst.Attribute)
|
||||
and isinstance(node.module.value, cst.Attribute)
|
||||
and isinstance(node.module.value.value, cst.Name)
|
||||
and node.module.value.value.value == "codeflash"
|
||||
and node.module.value.attr.value == "code_utils"
|
||||
and node.module.attr.value == "codeflash_wrap_decorator"
|
||||
and not isinstance(node.names, cst.ImportStar)
|
||||
):
|
||||
decorator_name = self._get_decorator_name()
|
||||
for import_alias in node.names:
|
||||
if import_alias.name.value == decorator_name:
|
||||
self.has_import = True
|
||||
def get_run_tmp_file(file_path):
|
||||
if not hasattr(get_run_tmp_file, "tmpdir"):
|
||||
get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
|
||||
return Path(get_run_tmp_file.tmpdir.name) / file_path
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||
# If the import is already there, don't add it again
|
||||
if self.has_import:
|
||||
return updated_node
|
||||
|
||||
# Choose import based on mode
|
||||
decorator_name = self._get_decorator_name()
|
||||
def extract_test_context_from_env():
|
||||
test_module = os.environ["CODEFLASH_TEST_MODULE"]
|
||||
test_class = os.environ.get("CODEFLASH_TEST_CLASS", None)
|
||||
test_function = os.environ["CODEFLASH_TEST_FUNCTION"]
|
||||
if test_module and test_function:
|
||||
return (test_module, test_class if test_class else None, test_function)
|
||||
raise RuntimeError(
|
||||
"Test context environment variables not set - ensure tests are run through codeflash test runner"
|
||||
)
|
||||
|
||||
# Parse the import statement into a CST node
|
||||
import_node = cst.parse_statement(f"from codeflash.code_utils.codeflash_wrap_decorator import {decorator_name}")
|
||||
|
||||
# Add the import to the module's body
|
||||
return updated_node.with_changes(body=[import_node, *list(updated_node.body)])
|
||||
def codeflash_behavior_async(func):
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
loop = asyncio.get_running_loop()
|
||||
function_name = func.__name__
|
||||
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
|
||||
loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
|
||||
test_module_name, test_class_name, test_name = extract_test_context_from_env()
|
||||
test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
|
||||
if not hasattr(async_wrapper, "index"):
|
||||
async_wrapper.index = {}
|
||||
if test_id in async_wrapper.index:
|
||||
async_wrapper.index[test_id] += 1
|
||||
else:
|
||||
async_wrapper.index[test_id] = 0
|
||||
codeflash_test_index = async_wrapper.index[test_id]
|
||||
invocation_id = f"{line_id}_{codeflash_test_index}"
|
||||
class_prefix = (test_class_name + ".") if test_class_name else ""
|
||||
test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}"
|
||||
print(f"!$######{test_stdout_tag}######$!")
|
||||
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
|
||||
db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite"))
|
||||
codeflash_con = sqlite3.connect(db_path)
|
||||
codeflash_cur = codeflash_con.cursor()
|
||||
codeflash_cur.execute(
|
||||
"CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, "
|
||||
"test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, "
|
||||
"runtime INTEGER, return_value BLOB, verification_type TEXT)"
|
||||
)
|
||||
exception = None
|
||||
counter = loop.time()
|
||||
gc.disable()
|
||||
try:
|
||||
ret = func(*args, **kwargs)
|
||||
counter = loop.time()
|
||||
return_value = await ret
|
||||
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
|
||||
except Exception as e:
|
||||
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
|
||||
exception = e
|
||||
finally:
|
||||
gc.enable()
|
||||
print(f"!######{test_stdout_tag}######!")
|
||||
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps((args, kwargs, return_value))
|
||||
codeflash_cur.execute(
|
||||
"INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(
|
||||
test_module_name,
|
||||
test_class_name,
|
||||
test_name,
|
||||
function_name,
|
||||
loop_index,
|
||||
invocation_id,
|
||||
codeflash_duration,
|
||||
pickled_return_value,
|
||||
"function_call",
|
||||
),
|
||||
)
|
||||
codeflash_con.commit()
|
||||
codeflash_con.close()
|
||||
if exception:
|
||||
raise exception
|
||||
return return_value
|
||||
return async_wrapper
|
||||
|
||||
|
||||
def codeflash_performance_async(func):
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
loop = asyncio.get_running_loop()
|
||||
function_name = func.__name__
|
||||
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
|
||||
loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
|
||||
test_module_name, test_class_name, test_name = extract_test_context_from_env()
|
||||
test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
|
||||
if not hasattr(async_wrapper, "index"):
|
||||
async_wrapper.index = {}
|
||||
if test_id in async_wrapper.index:
|
||||
async_wrapper.index[test_id] += 1
|
||||
else:
|
||||
async_wrapper.index[test_id] = 0
|
||||
codeflash_test_index = async_wrapper.index[test_id]
|
||||
invocation_id = f"{line_id}_{codeflash_test_index}"
|
||||
class_prefix = (test_class_name + ".") if test_class_name else ""
|
||||
test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}"
|
||||
print(f"!$######{test_stdout_tag}######$!")
|
||||
exception = None
|
||||
counter = loop.time()
|
||||
gc.disable()
|
||||
try:
|
||||
ret = func(*args, **kwargs)
|
||||
counter = loop.time()
|
||||
return_value = await ret
|
||||
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
|
||||
except Exception as e:
|
||||
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
|
||||
exception = e
|
||||
finally:
|
||||
gc.enable()
|
||||
print(f"!######{test_stdout_tag}:{codeflash_duration}######!")
|
||||
if exception:
|
||||
raise exception
|
||||
return return_value
|
||||
return async_wrapper
|
||||
|
||||
|
||||
def codeflash_concurrency_async(func):
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
function_name = func.__name__
|
||||
concurrency_factor = int(os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10"))
|
||||
test_module_name = os.environ.get("CODEFLASH_TEST_MODULE", "")
|
||||
test_class_name = os.environ.get("CODEFLASH_TEST_CLASS", "")
|
||||
test_function = os.environ.get("CODEFLASH_TEST_FUNCTION", "")
|
||||
loop_index = os.environ.get("CODEFLASH_LOOP_INDEX", "0")
|
||||
gc.disable()
|
||||
try:
|
||||
seq_start = time.perf_counter_ns()
|
||||
for _ in range(concurrency_factor):
|
||||
result = await func(*args, **kwargs)
|
||||
sequential_time = time.perf_counter_ns() - seq_start
|
||||
finally:
|
||||
gc.enable()
|
||||
gc.disable()
|
||||
try:
|
||||
conc_start = time.perf_counter_ns()
|
||||
tasks = [func(*args, **kwargs) for _ in range(concurrency_factor)]
|
||||
await asyncio.gather(*tasks)
|
||||
concurrent_time = time.perf_counter_ns() - conc_start
|
||||
finally:
|
||||
gc.enable()
|
||||
tag = f"{test_module_name}:{test_class_name}:{test_function}:{function_name}:{loop_index}"
|
||||
print(f"!@######CONC:{tag}:{sequential_time}:{concurrent_time}:{concurrency_factor}######@!")
|
||||
return result
|
||||
return async_wrapper
|
||||
"""
|
||||
|
||||
ASYNC_HELPER_FILENAME = "codeflash_async_wrapper.py"
|
||||
|
||||
|
||||
def get_decorator_name_for_mode(mode: TestingMode) -> str:
|
||||
if mode == TestingMode.BEHAVIOR:
|
||||
return "codeflash_behavior_async"
|
||||
if mode == TestingMode.CONCURRENCY:
|
||||
return "codeflash_concurrency_async"
|
||||
return "codeflash_performance_async"
|
||||
|
||||
|
||||
def write_async_helper_file(target_dir: Path) -> Path:
|
||||
"""Write the async decorator helper file to the target directory."""
|
||||
helper_path = target_dir / ASYNC_HELPER_FILENAME
|
||||
if not helper_path.exists():
|
||||
helper_path.write_text(ASYNC_HELPER_INLINE_CODE, "utf-8")
|
||||
return helper_path
|
||||
|
||||
|
||||
def add_async_decorator_to_function(
|
||||
source_path: Path, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
|
||||
source_path: Path,
|
||||
function: FunctionToOptimize,
|
||||
mode: TestingMode = TestingMode.BEHAVIOR,
|
||||
project_root: Path | None = None,
|
||||
) -> bool:
|
||||
"""Add async decorator to an async function definition and write back to file.
|
||||
|
||||
Args:
|
||||
----
|
||||
source_path: Path to the source file to modify in-place.
|
||||
function: The FunctionToOptimize object representing the target async function.
|
||||
mode: The testing mode to determine which decorator to apply.
|
||||
|
||||
Returns:
|
||||
-------
|
||||
Boolean indicating whether the decorator was successfully added.
|
||||
Writes a helper file containing the decorator implementation to project_root (or source directory
|
||||
as fallback) and adds a standard import + decorator to the source file.
|
||||
|
||||
"""
|
||||
if not function.is_async:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Read source code
|
||||
with source_path.open(encoding="utf8") as f:
|
||||
source_code = f.read()
|
||||
|
||||
|
|
@ -1594,10 +1728,14 @@ def add_async_decorator_to_function(
|
|||
decorator_transformer = AsyncDecoratorAdder(function, mode)
|
||||
module = module.visit(decorator_transformer)
|
||||
|
||||
# Add the import if decorator was added
|
||||
if decorator_transformer.added_decorator:
|
||||
import_transformer = AsyncDecoratorImportAdder(mode)
|
||||
module = module.visit(import_transformer)
|
||||
# Write the helper file to project_root (on sys.path) or source dir as fallback
|
||||
helper_dir = project_root if project_root is not None else source_path.parent
|
||||
write_async_helper_file(helper_dir)
|
||||
# Add the import via CST so sort_imports can place it correctly
|
||||
decorator_name = get_decorator_name_for_mode(mode)
|
||||
import_node = cst.parse_statement(f"from codeflash_async_wrapper import {decorator_name}")
|
||||
module = module.with_changes(body=[import_node, *list(module.body)])
|
||||
|
||||
modified_code = sort_imports(code=module.code, float_to_top=True)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -520,15 +520,6 @@ class LanguageSupport(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
def get_comment_prefix(self) -> str:
|
||||
"""Get the comment prefix for this language.
|
||||
|
||||
Returns:
|
||||
Comment prefix (e.g., "//" for JS, "#" for Python).
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
def find_test_root(self, project_root: Path) -> Path | None:
|
||||
"""Find the test root directory for a project.
|
||||
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ if TYPE_CHECKING:
|
|||
from codeflash.languages.base import LanguageSupport
|
||||
|
||||
# Module-level singleton for the current language
|
||||
_current_language: Language | None = None
|
||||
_current_language: Language = Language.PYTHON
|
||||
|
||||
|
||||
def current_language() -> Language:
|
||||
|
|
|
|||
|
|
@ -527,10 +527,5 @@ def parse_jest_test_xml(
|
|||
f"[LOOP-SUMMARY] Results loop_index: min={min_idx}, max={max_idx}, "
|
||||
f"unique_count={len(unique_loop_indices)}, total_results={len(loop_indices)}"
|
||||
)
|
||||
if max_idx == 1 and len(loop_indices) > 1:
|
||||
logger.warning(
|
||||
f"[LOOP-WARNING] All {len(loop_indices)} results have loop_index=1. "
|
||||
"Perf test markers may not have been parsed correctly."
|
||||
)
|
||||
|
||||
return test_results
|
||||
|
|
|
|||
|
|
@ -1805,15 +1805,6 @@ class JavaScriptSupport:
|
|||
"""
|
||||
return ".test.js"
|
||||
|
||||
def get_comment_prefix(self) -> str:
|
||||
"""Get the comment prefix for JavaScript.
|
||||
|
||||
Returns:
|
||||
JavaScript single-line comment prefix.
|
||||
|
||||
"""
|
||||
return "//"
|
||||
|
||||
def find_test_root(self, project_root: Path) -> Path | None:
|
||||
"""Find the test root directory for a JavaScript project.
|
||||
|
||||
|
|
|
|||
|
|
@ -803,8 +803,6 @@ def run_jest_behavioral_tests(
|
|||
wall_clock_ns = time.perf_counter_ns() - start_time_ns
|
||||
logger.debug(f"Jest behavioral tests completed in {wall_clock_ns / 1e9:.2f}s")
|
||||
|
||||
print(result.stdout)
|
||||
|
||||
return result_file_path, result, coverage_json_path, None
|
||||
|
||||
|
||||
|
|
@ -1046,6 +1044,10 @@ def run_jest_benchmarking_tests(
|
|||
|
||||
# Create result with combined stdout
|
||||
result = subprocess.CompletedProcess(args=result.args, returncode=result.returncode, stdout=stdout, stderr="")
|
||||
if result.returncode != 0:
|
||||
logger.info(f"Jest benchmarking failed with return code {result.returncode}")
|
||||
logger.info(f"Jest benchmarking stdout: {result.stdout}")
|
||||
logger.info(f"Jest benchmarking stderr: {result.stderr}")
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning(f"Jest benchmarking timed out after {total_timeout}s")
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -15,6 +15,8 @@ from codeflash.languages import is_java, is_javascript
|
|||
from codeflash.models.models import CodeString, CodeStringsMarkdown
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeOptimizationContext, FunctionSource
|
||||
|
||||
|
|
@ -49,6 +51,69 @@ def extract_names_from_targets(target: cst.CSTNode) -> list[str]:
|
|||
return names
|
||||
|
||||
|
||||
def is_assignment_used(node: cst.CSTNode, definitions: dict[str, UsageInfo], name_prefix: str = "") -> bool:
    """Report whether an assignment binds any name that is marked as used.

    Args:
        node: The node to inspect. Only ``cst.Assign``, ``cst.AnnAssign`` and
            ``cst.AugAssign`` are examined; any other node yields ``False``.
        definitions: Mapping from (optionally prefixed) name to its usage record.
        name_prefix: Text prepended to every extracted name before it is looked up
            in ``definitions``. An empty prefix means the bare name is used.

    Returns:
        ``True`` as soon as one extracted name is present in ``definitions`` with a
        truthy ``used_by_qualified_function``; ``False`` otherwise.

    """
    # Normalise the node kinds into one list of target expressions so that a
    # single lookup loop serves all of them.
    if isinstance(node, cst.Assign):
        target_nodes = [assign_target.target for assign_target in node.targets]
    elif isinstance(node, (cst.AnnAssign, cst.AugAssign)):
        target_nodes = [node.target]
    else:
        return False

    for target_node in target_nodes:
        for bound_name in extract_names_from_targets(target_node):
            key = f"{name_prefix}{bound_name}" if name_prefix else bound_name
            if key in definitions and definitions[key].used_by_qualified_function:
                return True
    return False
|
||||
|
||||
|
||||
def recurse_sections(
    node: cst.CSTNode,
    section_names: list[str],
    prune_fn: Callable[[cst.CSTNode], tuple[cst.CSTNode | None, bool]],
    keep_non_target_children: bool = False,
) -> tuple[cst.CSTNode | None, bool]:
    """Apply ``prune_fn`` to the children held in the named attributes of ``node``.

    Args:
        node: Node whose attributes are visited; it must provide ``with_changes``.
        section_names: Attribute names to visit. Attributes that are missing or
            ``None`` are ignored.
        prune_fn: Called once per child (or once for a non-sequence attribute).
            Returns ``(replacement_or_None, found_target)``; a falsy replacement
            drops the child.
        keep_non_target_children: Selects which sections are rewritten.
            When ``False``, a section is rewritten only if ``prune_fn`` reported a
            target inside it. When ``True``, a section is also rewritten if any
            child survived pruning, even without a target.

    Returns:
        ``(new_node_or_None, found_target)``.

        With ``keep_non_target_children`` false: ``(None, False)`` when no target
        was reported; otherwise the node (rebuilt via ``with_changes`` if any
        section was rewritten, else the original object) and ``True``.

        With ``keep_non_target_children`` true: ``(None, False)`` when nothing was
        rewritten; otherwise the rebuilt node and the accumulated target flag.

    """
    changes: dict[str, list[cst.CSTNode] | cst.CSTNode] = {}
    any_target = False

    for attr in section_names:
        content = getattr(node, attr, None)
        if content is None:
            continue

        if isinstance(content, (list, tuple)):
            kept = []
            section_hit = False
            for item in content:
                pruned_item, item_hit = prune_fn(item)
                if pruned_item:
                    kept.append(pruned_item)
                section_hit |= item_hit

            # In keep mode surviving children are enough to record the section;
            # otherwise a reported target is required.
            record = (section_hit or bool(kept)) if keep_non_target_children else section_hit
            if record:
                any_target |= section_hit
                changes[attr] = kept
            continue

        # Single (non-sequence) child.
        pruned, hit = prune_fn(content)
        if keep_non_target_children:
            any_target |= hit
        elif hit:
            any_target = True
        else:
            continue
        if pruned:
            changes[attr] = pruned

    if keep_non_target_children:
        if not changes:
            return None, False
        return node.with_changes(**changes), any_target

    if not any_target:
        return None, False
    rebuilt = node.with_changes(**changes) if changes else node
    return rebuilt, True
|
||||
|
||||
|
||||
def collect_top_level_definitions(
|
||||
node: cst.CSTNode, definitions: Optional[dict[str, UsageInfo]] = None
|
||||
) -> dict[str, UsageInfo]:
|
||||
|
|
@ -423,27 +488,9 @@ def remove_unused_definitions_recursively(
|
|||
elif isinstance(statement, (cst.Assign, cst.AnnAssign, cst.AugAssign)):
|
||||
var_used = False
|
||||
|
||||
# Check if any variable in this assignment is used
|
||||
if isinstance(statement, cst.Assign):
|
||||
for target in statement.targets:
|
||||
names = extract_names_from_targets(target.target)
|
||||
for name in names:
|
||||
class_var_name = f"{class_name}.{name}"
|
||||
if (
|
||||
class_var_name in definitions
|
||||
and definitions[class_var_name].used_by_qualified_function
|
||||
):
|
||||
var_used = True
|
||||
method_or_var_used = True
|
||||
break
|
||||
elif isinstance(statement, (cst.AnnAssign, cst.AugAssign)):
|
||||
names = extract_names_from_targets(statement.target)
|
||||
for name in names:
|
||||
class_var_name = f"{class_name}.{name}"
|
||||
if class_var_name in definitions and definitions[class_var_name].used_by_qualified_function:
|
||||
var_used = True
|
||||
method_or_var_used = True
|
||||
break
|
||||
if is_assignment_used(statement, definitions, name_prefix=f"{class_name}."):
|
||||
var_used = True
|
||||
method_or_var_used = True
|
||||
|
||||
if var_used or class_has_dependencies:
|
||||
new_statements.append(statement)
|
||||
|
|
@ -459,56 +506,19 @@ def remove_unused_definitions_recursively(
|
|||
|
||||
return node, method_or_var_used or class_has_dependencies
|
||||
|
||||
# Handle assignments (Assign and AnnAssign)
|
||||
if isinstance(node, cst.Assign):
|
||||
for target in node.targets:
|
||||
names = extract_names_from_targets(target.target)
|
||||
for name in names:
|
||||
if name in definitions and definitions[name].used_by_qualified_function:
|
||||
return node, True
|
||||
return None, False
|
||||
|
||||
if isinstance(node, (cst.AnnAssign, cst.AugAssign)):
|
||||
names = extract_names_from_targets(node.target)
|
||||
for name in names:
|
||||
if name in definitions and definitions[name].used_by_qualified_function:
|
||||
return node, True
|
||||
# Handle assignments (Assign, AnnAssign, AugAssign)
|
||||
if isinstance(node, (cst.Assign, cst.AnnAssign, cst.AugAssign)):
|
||||
if is_assignment_used(node, definitions):
|
||||
return node, True
|
||||
return None, False
|
||||
|
||||
# For other nodes, recursively process children
|
||||
section_names = get_section_names(node)
|
||||
if not section_names:
|
||||
return node, False
|
||||
|
||||
updates = {}
|
||||
found_used = False
|
||||
|
||||
for section in section_names:
|
||||
original_content = getattr(node, section, None)
|
||||
if isinstance(original_content, (list, tuple)):
|
||||
new_children = []
|
||||
section_found_used = False
|
||||
|
||||
for child in original_content:
|
||||
filtered, used = remove_unused_definitions_recursively(child, definitions)
|
||||
if filtered:
|
||||
new_children.append(filtered)
|
||||
section_found_used |= used
|
||||
|
||||
if new_children or section_found_used:
|
||||
found_used |= section_found_used
|
||||
updates[section] = new_children
|
||||
elif original_content is not None:
|
||||
filtered, used = remove_unused_definitions_recursively(original_content, definitions)
|
||||
found_used |= used
|
||||
if filtered:
|
||||
updates[section] = filtered
|
||||
if not found_used:
|
||||
return None, False
|
||||
if updates:
|
||||
return node.with_changes(**updates), found_used
|
||||
|
||||
return node, False
|
||||
return recurse_sections(
|
||||
node, section_names, lambda child: remove_unused_definitions_recursively(child, definitions)
|
||||
)
|
||||
|
||||
|
||||
def collect_top_level_defs_with_usages(
|
||||
|
|
@ -21,9 +21,25 @@ from codeflash.languages.registry import register_language
|
|||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from codeflash.models.models import FunctionSource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def function_sources_to_helpers(sources: list[FunctionSource]) -> list[HelperFunction]:
    """Convert ``FunctionSource`` records into ``HelperFunction`` records.

    Args:
        sources: Function sources to convert.

    Returns:
        One ``HelperFunction`` per input, in the same order. ``start_line`` and
        ``end_line`` are both set to the line of the source's jedi definition, and
        fall back to 1 when there is no definition or the definition has no line.

    """
    helpers: list[HelperFunction] = []
    for fs in sources:
        # A jedi definition can exist while its ``line`` is None (no source
        # position), so apply ``or 1`` after the attribute access to keep the line
        # fields integers in every case.
        line = (fs.jedi_definition.line if fs.jedi_definition else None) or 1
        helpers.append(
            HelperFunction(
                name=fs.only_function_name,
                qualified_name=fs.qualified_name,
                file_path=fs.file_path,
                source_code=fs.source_code,
                start_line=line,
                end_line=line,
            )
        )
    return helpers
|
||||
|
||||
|
||||
@register_language
|
||||
class PythonSupport:
|
||||
"""Python language support implementation.
|
||||
|
|
@ -171,127 +187,39 @@ class PythonSupport:
|
|||
# === Code Analysis ===
|
||||
|
||||
def extract_code_context(self, function: FunctionToOptimize, project_root: Path, module_root: Path) -> CodeContext:
|
||||
"""Extract function code and its dependencies.
|
||||
"""Extract function code and its dependencies via the canonical context pipeline."""
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
|
||||
|
||||
Uses jedi and libcst for Python code analysis.
|
||||
|
||||
Args:
|
||||
function: The function to extract context for.
|
||||
project_root: Root of the project.
|
||||
module_root: Root of the module containing the function.
|
||||
|
||||
Returns:
|
||||
CodeContext with target code and dependencies.
|
||||
|
||||
"""
|
||||
try:
|
||||
source = function.file_path.read_text()
|
||||
result = get_code_optimization_context(function, project_root)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to read %s: %s", function.file_path, e)
|
||||
logger.warning("Failed to extract code context for %s: %s", function.function_name, e)
|
||||
return CodeContext(target_code="", target_file=function.file_path, language=Language.PYTHON)
|
||||
|
||||
# Extract the function source
|
||||
lines = source.splitlines(keepends=True)
|
||||
if function.starting_line and function.ending_line:
|
||||
target_lines = lines[function.starting_line - 1 : function.ending_line]
|
||||
target_code = "".join(target_lines)
|
||||
else:
|
||||
target_code = ""
|
||||
|
||||
# Find helper functions
|
||||
helpers = self.find_helper_functions(function, project_root)
|
||||
|
||||
# Extract imports
|
||||
import_lines = []
|
||||
for line in lines:
|
||||
stripped = line.strip()
|
||||
if stripped.startswith(("import ", "from ")):
|
||||
import_lines.append(stripped)
|
||||
elif stripped and not stripped.startswith("#"):
|
||||
# Stop at first non-import, non-comment line
|
||||
break
|
||||
helpers = function_sources_to_helpers(result.helper_functions)
|
||||
|
||||
return CodeContext(
|
||||
target_code=target_code,
|
||||
target_code=result.read_writable_code.markdown,
|
||||
target_file=function.file_path,
|
||||
helper_functions=helpers,
|
||||
read_only_context="",
|
||||
imports=import_lines,
|
||||
read_only_context=result.read_only_context_code,
|
||||
imports=[],
|
||||
language=Language.PYTHON,
|
||||
)
|
||||
|
||||
def find_helper_functions(self, function: FunctionToOptimize, project_root: Path) -> list[HelperFunction]:
|
||||
"""Find helper functions called by the target function.
|
||||
|
||||
Uses jedi for Python code analysis.
|
||||
|
||||
Args:
|
||||
function: The target function to analyze.
|
||||
project_root: Root of the project.
|
||||
|
||||
Returns:
|
||||
List of HelperFunction objects.
|
||||
|
||||
"""
|
||||
helpers: list[HelperFunction] = []
|
||||
"""Find helper functions called by the target function via the canonical jedi pipeline."""
|
||||
from codeflash.languages.python.context.code_context_extractor import get_function_sources_from_jedi
|
||||
|
||||
try:
|
||||
import jedi
|
||||
|
||||
from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
|
||||
from codeflash.optimization.function_context import belongs_to_function_qualified
|
||||
|
||||
script = jedi.Script(path=function.file_path, project=jedi.Project(path=project_root))
|
||||
file_refs = script.get_names(all_scopes=True, definitions=False, references=True)
|
||||
|
||||
qualified_name = function.qualified_name
|
||||
|
||||
for ref in file_refs:
|
||||
if not ref.full_name or not belongs_to_function_qualified(ref, qualified_name):
|
||||
continue
|
||||
|
||||
try:
|
||||
definitions = ref.goto(follow_imports=True, follow_builtin_imports=False)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
for definition in definitions:
|
||||
definition_path = definition.module_path
|
||||
if definition_path is None:
|
||||
continue
|
||||
|
||||
# Check if it's a valid helper (in project, not in target function)
|
||||
is_valid = (
|
||||
str(definition_path).startswith(str(project_root))
|
||||
and not path_belongs_to_site_packages(definition_path)
|
||||
and definition.full_name
|
||||
and not belongs_to_function_qualified(definition, qualified_name)
|
||||
and definition.type == "function"
|
||||
)
|
||||
|
||||
if is_valid:
|
||||
helper_qualified_name = get_qualified_name(definition.module_name, definition.full_name)
|
||||
# Get source code
|
||||
try:
|
||||
helper_source = definition.get_line_code()
|
||||
except Exception:
|
||||
helper_source = ""
|
||||
|
||||
helpers.append(
|
||||
HelperFunction(
|
||||
name=definition.name,
|
||||
qualified_name=helper_qualified_name,
|
||||
file_path=definition_path,
|
||||
source_code=helper_source,
|
||||
start_line=definition.line or 1,
|
||||
end_line=definition.line or 1,
|
||||
)
|
||||
)
|
||||
|
||||
_dict, sources = get_function_sources_from_jedi(
|
||||
{function.file_path: {function.qualified_name}}, project_root
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to find helpers for %s: %s", function.function_name, e)
|
||||
return []
|
||||
|
||||
return helpers
|
||||
return function_sources_to_helpers(sources)
|
||||
|
||||
def find_references(
|
||||
self, function: FunctionToOptimize, project_root: Path, tests_root: Path | None = None, max_files: int = 500
|
||||
|
|
@ -728,15 +656,6 @@ class PythonSupport:
|
|||
"""
|
||||
return ".py"
|
||||
|
||||
def get_comment_prefix(self) -> str:
|
||||
"""Get the comment prefix for Python.
|
||||
|
||||
Returns:
|
||||
Python single-line comment prefix.
|
||||
|
||||
"""
|
||||
return "#"
|
||||
|
||||
def find_test_root(self, project_root: Path) -> Path | None:
|
||||
"""Find the test root directory for a Python project.
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from pathlib import Path
|
|||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
import libcst as cst
|
||||
from git import Repo as GitRepo
|
||||
from rich.console import Group
|
||||
from rich.panel import Panel
|
||||
from rich.syntax import Syntax
|
||||
|
|
@ -71,8 +72,6 @@ from codeflash.code_utils.line_profile_utils import add_decorator_imports, conta
|
|||
from codeflash.code_utils.shell_utils import make_env_with_project_root
|
||||
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
|
||||
from codeflash.code_utils.time_utils import humanize_runtime
|
||||
from codeflash.context import code_context_extractor
|
||||
from codeflash.context.unused_definition_remover import detect_unused_helper_functions, revert_unused_helper_functions
|
||||
from codeflash.discovery.functions_to_optimize import was_function_previously_optimized
|
||||
from codeflash.either import Failure, Success, is_successful
|
||||
from codeflash.languages import is_java, is_javascript, is_python
|
||||
|
|
@ -80,6 +79,11 @@ from codeflash.languages.base import Language
|
|||
from codeflash.languages.current import current_language_support, is_typescript
|
||||
from codeflash.languages.javascript.module_system import detect_module_system
|
||||
from codeflash.languages.javascript.test_runner import clear_created_config_files, get_created_config_files
|
||||
from codeflash.languages.python.context import code_context_extractor
|
||||
from codeflash.languages.python.context.unused_definition_remover import (
|
||||
detect_unused_helper_functions,
|
||||
revert_unused_helper_functions,
|
||||
)
|
||||
from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table, tree_to_markdown
|
||||
from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId
|
||||
from codeflash.models.ExperimentMetadata import ExperimentMetadata
|
||||
|
|
@ -2231,6 +2235,7 @@ class FunctionOptimizer:
|
|||
if self.args.override_fixtures:
|
||||
restore_conftest(original_conftest_content)
|
||||
cleanup_paths(paths_to_cleanup)
|
||||
self.cleanup_async_helper_file()
|
||||
return Failure(baseline_result.failure())
|
||||
|
||||
original_code_baseline, test_functions_to_remove = baseline_result.unwrap()
|
||||
|
|
@ -2242,6 +2247,7 @@ class FunctionOptimizer:
|
|||
if self.args.override_fixtures:
|
||||
restore_conftest(original_conftest_content)
|
||||
cleanup_paths(paths_to_cleanup)
|
||||
self.cleanup_async_helper_file()
|
||||
return Failure("The threshold for test confidence was not met.")
|
||||
|
||||
return Success(
|
||||
|
|
@ -2411,7 +2417,10 @@ class FunctionOptimizer:
|
|||
generated_tests_str = ""
|
||||
code_lang = self.function_to_optimize.language
|
||||
for test in generated_tests.generated_tests:
|
||||
if map_gen_test_file_to_no_of_tests[test.behavior_file_path] > 0:
|
||||
if any(
|
||||
test_file.name == test.behavior_file_path.name and count > 0
|
||||
for test_file, count in map_gen_test_file_to_no_of_tests.items()
|
||||
):
|
||||
formatted_generated_test = format_generated_code(
|
||||
test.generated_original_test_source, self.args.formatter_cmds
|
||||
)
|
||||
|
|
@ -2551,11 +2560,11 @@ class FunctionOptimizer:
|
|||
console.print(Panel(panel_content, title="Optimization Review", border_style=display_info[1]))
|
||||
|
||||
if raise_pr or staging_review:
|
||||
data["root_dir"] = git_root_dir()
|
||||
data["root_dir"] = git_root_dir(GitRepo(str(self.args.module_root), search_parent_directories=True))
|
||||
if raise_pr and not staging_review and opt_review_result.review != "low":
|
||||
# Ensure root_dir is set for PR creation (needed for async functions that skip opt_review)
|
||||
if "root_dir" not in data:
|
||||
data["root_dir"] = git_root_dir()
|
||||
data["root_dir"] = git_root_dir(GitRepo(str(self.args.module_root), search_parent_directories=True))
|
||||
data["git_remote"] = self.args.git_remote
|
||||
# Remove language from data dict as check_create_pr doesn't accept it
|
||||
pr_data = {k: v for k, v in data.items() if k != "language"}
|
||||
|
|
@ -2610,6 +2619,13 @@ class FunctionOptimizer:
|
|||
self.write_code_and_helpers(
|
||||
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
|
||||
)
|
||||
self.cleanup_async_helper_file()
|
||||
|
||||
def cleanup_async_helper_file(self) -> None:
    """Remove the file named ``ASYNC_HELPER_FILENAME`` from ``self.project_root``.

    A missing file is not an error (``missing_ok=True``).
    """
    from codeflash.code_utils.instrument_existing_tests import ASYNC_HELPER_FILENAME

    (self.project_root / ASYNC_HELPER_FILENAME).unlink(missing_ok=True)
|
||||
|
||||
def establish_original_code_baseline(
|
||||
self,
|
||||
|
|
@ -2627,7 +2643,10 @@ class FunctionOptimizer:
|
|||
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
|
||||
|
||||
success = add_async_decorator_to_function(
|
||||
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
|
||||
self.function_to_optimize.file_path,
|
||||
self.function_to_optimize,
|
||||
TestingMode.BEHAVIOR,
|
||||
project_root=self.project_root,
|
||||
)
|
||||
|
||||
# Instrument codeflash capture
|
||||
|
|
@ -2692,7 +2711,10 @@ class FunctionOptimizer:
|
|||
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
|
||||
|
||||
add_async_decorator_to_function(
|
||||
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
|
||||
self.function_to_optimize.file_path,
|
||||
self.function_to_optimize,
|
||||
TestingMode.PERFORMANCE,
|
||||
project_root=self.project_root,
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -2866,7 +2888,10 @@ class FunctionOptimizer:
|
|||
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
|
||||
|
||||
add_async_decorator_to_function(
|
||||
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
|
||||
self.function_to_optimize.file_path,
|
||||
self.function_to_optimize,
|
||||
TestingMode.BEHAVIOR,
|
||||
project_root=self.project_root,
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -2961,7 +2986,10 @@ class FunctionOptimizer:
|
|||
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
|
||||
|
||||
add_async_decorator_to_function(
|
||||
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
|
||||
self.function_to_optimize.file_path,
|
||||
self.function_to_optimize,
|
||||
TestingMode.PERFORMANCE,
|
||||
project_root=self.project_root,
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -3330,7 +3358,10 @@ class FunctionOptimizer:
|
|||
try:
|
||||
# Add concurrency decorator to the source function
|
||||
add_async_decorator_to_function(
|
||||
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.CONCURRENCY
|
||||
self.function_to_optimize.file_path,
|
||||
self.function_to_optimize,
|
||||
TestingMode.CONCURRENCY,
|
||||
project_root=self.project_root,
|
||||
)
|
||||
|
||||
# Run the concurrency benchmark tests
|
||||
|
|
|
|||
|
|
@ -183,18 +183,54 @@ class Optimizer:
|
|||
"""Discover functions to optimize."""
|
||||
from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
|
||||
|
||||
return get_functions_to_optimize(
|
||||
# In worktree mode for git-diff discovery, file paths come from the original repo
|
||||
# (via get_git_diff using cwd), but module_root/project_root have been mirrored to
|
||||
# the worktree. Use the original roots for filtering so path comparisons match,
|
||||
# then remap the discovered file paths to the worktree.
|
||||
project_root = self.args.project_root
|
||||
module_root = self.args.module_root
|
||||
use_original_roots = (
|
||||
self.current_worktree and self.original_args_and_test_cfg and not self.args.all and not self.args.file
|
||||
)
|
||||
if use_original_roots:
|
||||
assert self.original_args_and_test_cfg is not None
|
||||
original_args, _ = self.original_args_and_test_cfg
|
||||
project_root = original_args.project_root
|
||||
module_root = original_args.module_root
|
||||
|
||||
result = get_functions_to_optimize(
|
||||
optimize_all=self.args.all,
|
||||
replay_test=self.args.replay_test,
|
||||
file=self.args.file,
|
||||
only_get_this_function=self.args.function,
|
||||
test_cfg=self.test_cfg,
|
||||
ignore_paths=self.args.ignore_paths,
|
||||
project_root=self.args.project_root,
|
||||
module_root=self.args.module_root,
|
||||
project_root=project_root,
|
||||
module_root=module_root,
|
||||
previous_checkpoint_functions=self.args.previous_checkpoint_functions,
|
||||
)
|
||||
|
||||
# Remap discovered file paths from the original repo to the worktree so
|
||||
# downstream optimization reads/writes happen in the worktree.
|
||||
if use_original_roots:
|
||||
import dataclasses
|
||||
|
||||
assert self.current_worktree is not None
|
||||
original_git_root = git_root_dir()
|
||||
file_to_funcs, count, trace = result
|
||||
remapped: dict[Path, list[FunctionToOptimize]] = {}
|
||||
for file_path, funcs in file_to_funcs.items():
|
||||
new_path = mirror_path(Path(file_path), original_git_root, self.current_worktree)
|
||||
remapped[new_path] = [
|
||||
dataclasses.replace(
|
||||
func, file_path=mirror_path(func.file_path, original_git_root, self.current_worktree)
|
||||
)
|
||||
for func in funcs
|
||||
]
|
||||
return remapped, count, trace
|
||||
|
||||
return result
|
||||
|
||||
def create_function_optimizer(
|
||||
self,
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
# These version placeholders will be replaced by uv-dynamic-versioning during build.
|
||||
__version__ = "0.20.0.post510.dev0+b8932209"
|
||||
__version__ = "0.20.1"
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ codeflash/result/explanation.py
|
|||
codeflash/result/critic.py
|
||||
codeflash/version.py
|
||||
codeflash/optimization/__init__.py
|
||||
codeflash/context/__init__.py
|
||||
codeflash/context/code_context_extractor.py
|
||||
codeflash/languages/python/context/__init__.py
|
||||
codeflash/languages/python/context/code_context_extractor.py
|
||||
codeflash/discovery/__init__.py
|
||||
codeflash/__init__.py
|
||||
codeflash/models/ExperimentMetadata.py
|
||||
|
|
|
|||
|
|
@ -113,21 +113,26 @@ function checkSharedTimeLimit() {
|
|||
|
||||
/**
|
||||
* Get the current loop index for a specific invocation.
|
||||
* The loop index represents how many times ALL test files have been run through.
|
||||
* This is the batch count from the loop-runner.
|
||||
* When using external loop-runner (Jest), returns the batch number directly.
|
||||
* When using internal looping (Vitest), tracks and returns the invocation count.
|
||||
*
|
||||
* @param {string} invocationKey - Unique key for this test invocation
|
||||
* @returns {number} The current batch number (loop index)
|
||||
* @returns {number} The loop index for timing markers (1-based)
|
||||
*/
|
||||
function getInvocationLoopIndex(invocationKey) {
|
||||
// Track local loop count for stopping logic (increments on each call)
|
||||
// When using external loop-runner, use the batch number directly
|
||||
// This is reliable because Jest resets module state between batches
|
||||
const currentBatch = process.env.CODEFLASH_PERF_CURRENT_BATCH;
|
||||
if (currentBatch !== undefined) {
|
||||
return parseInt(currentBatch, 10);
|
||||
}
|
||||
|
||||
// For internal looping (Vitest), track the count locally
|
||||
if (!sharedPerfState.invocationLoopCounts[invocationKey]) {
|
||||
sharedPerfState.invocationLoopCounts[invocationKey] = 0;
|
||||
}
|
||||
++sharedPerfState.invocationLoopCounts[invocationKey];
|
||||
|
||||
// Return the batch number as the loop index for timing markers
|
||||
// This represents how many times all test files have been run through
|
||||
return parseInt(process.env.CODEFLASH_PERF_CURRENT_BATCH || '1', 10);
|
||||
return sharedPerfState.invocationLoopCounts[invocationKey];
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -693,11 +698,9 @@ function capturePerf(funcName, lineId, fn, ...args) {
|
|||
// If not set, we're in Vitest mode and need to do all loops internally
|
||||
const hasExternalLoopRunner = process.env.CODEFLASH_PERF_CURRENT_BATCH !== undefined;
|
||||
|
||||
// Batched looping: run BATCH_SIZE loops per capturePerf call when using loop-runner
|
||||
// When using external loop-runner (Jest), execute only once per call - the loop-runner handles batching
|
||||
// For Vitest (no loop-runner), do all loops internally in a single call
|
||||
const batchSize = shouldLoop
|
||||
? (hasExternalLoopRunner ? getPerfBatchSize() : getPerfLoopCount())
|
||||
: 1;
|
||||
const batchSize = hasExternalLoopRunner ? 1 : (shouldLoop ? getPerfLoopCount() : 1);
|
||||
|
||||
// Initialize runtime tracking for this invocation if needed
|
||||
if (!sharedPerfState.invocationRuntimes[invocationKey]) {
|
||||
|
|
@ -710,21 +713,21 @@ function capturePerf(funcName, lineId, fn, ...args) {
|
|||
|
||||
for (let batchIndex = 0; batchIndex < batchSize; batchIndex++) {
|
||||
// Check shared time limit BEFORE each iteration
|
||||
if (shouldLoop && checkSharedTimeLimit()) {
|
||||
if (!hasExternalLoopRunner && shouldLoop && checkSharedTimeLimit()) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Check if this invocation has already reached stability
|
||||
if (getPerfStabilityCheck() && sharedPerfState.stableInvocations[invocationKey]) {
|
||||
if (!hasExternalLoopRunner && getPerfStabilityCheck() && sharedPerfState.stableInvocations[invocationKey]) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Get the loop index (batch number) for timing markers
|
||||
// Get the loop index for timing markers
|
||||
const loopIndex = getInvocationLoopIndex(invocationKey);
|
||||
|
||||
// Check if we've exceeded max loops for this invocation
|
||||
const totalIterations = getTotalIterations(invocationKey);
|
||||
if (totalIterations > getPerfLoopCount()) {
|
||||
if (!hasExternalLoopRunner && totalIterations > getPerfLoopCount()) {
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
@ -776,7 +779,7 @@ function capturePerf(funcName, lineId, fn, ...args) {
|
|||
}
|
||||
|
||||
// Check stability after accumulating enough samples
|
||||
if (getPerfStabilityCheck() && runtimes.length >= getPerfMinLoops()) {
|
||||
if (!hasExternalLoopRunner && getPerfStabilityCheck() && runtimes.length >= getPerfMinLoops()) {
|
||||
const window = getStabilityWindow();
|
||||
if (shouldStopStability(runtimes, window, getPerfMinLoops())) {
|
||||
sharedPerfState.stableInvocations[invocationKey] = true;
|
||||
|
|
@ -785,7 +788,7 @@ function capturePerf(funcName, lineId, fn, ...args) {
|
|||
}
|
||||
|
||||
// If we had an error, stop looping
|
||||
if (lastError) {
|
||||
if (!hasExternalLoopRunner && lastError) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -35,69 +35,113 @@ const path = require('path');
|
|||
const fs = require('fs');
|
||||
|
||||
/**
|
||||
* Validates that a jest-runner path is valid by checking for package.json.
|
||||
* @param {string} jestRunnerPath - Path to check
|
||||
* @returns {boolean} True if valid jest-runner package
|
||||
* Recursively find jest-runner package in node_modules.
|
||||
* Works with any package manager (npm, yarn, pnpm) by searching for
|
||||
* jest-runner/package.json anywhere in the tree.
|
||||
*
|
||||
* @param {string} nodeModulesPath - Path to node_modules directory
|
||||
* @param {number} maxDepth - Maximum recursion depth (default: 5)
|
||||
* @returns {string|null} Path to jest-runner or null if not found
|
||||
*/
|
||||
function isValidJestRunnerPath(jestRunnerPath) {
|
||||
if (!fs.existsSync(jestRunnerPath)) {
|
||||
return false;
|
||||
function findJestRunnerRecursive(nodeModulesPath, maxDepth = 5) {
|
||||
function search(dir, depth) {
|
||||
if (depth > maxDepth || !fs.existsSync(dir)) return null;
|
||||
|
||||
try {
|
||||
let entries = fs.readdirSync(dir, { withFileTypes: true });
|
||||
|
||||
// Sort entries: prefer higher versions for jest-runner@X.Y.Z directories
|
||||
entries = entries.slice().sort((a, b) => {
|
||||
const aMatch = a.name.match(/^jest-runner@(\d+)/);
|
||||
const bMatch = b.name.match(/^jest-runner@(\d+)/);
|
||||
if (aMatch && bMatch) {
|
||||
return parseInt(bMatch[1], 10) - parseInt(aMatch[1], 10);
|
||||
}
|
||||
return a.name.localeCompare(b.name);
|
||||
});
|
||||
|
||||
for (const entry of entries) {
|
||||
if (!entry.isDirectory()) continue;
|
||||
|
||||
const entryPath = path.join(dir, entry.name);
|
||||
|
||||
// Found jest-runner directory - check if it's a valid package
|
||||
if (entry.name === 'jest-runner') {
|
||||
const pkgJsonPath = path.join(entryPath, 'package.json');
|
||||
if (fs.existsSync(pkgJsonPath)) {
|
||||
try {
|
||||
const pkgJson = JSON.parse(fs.readFileSync(pkgJsonPath, 'utf8'));
|
||||
if (pkgJson.name === 'jest-runner') {
|
||||
return entryPath;
|
||||
}
|
||||
} catch (e) {
|
||||
// Ignore JSON parse errors
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Recurse into:
|
||||
// - node_modules subdirectories
|
||||
// - scoped packages (@org/pkg)
|
||||
// - hidden directories (.pnpm, .yarn, etc.)
|
||||
// - pnpm versioned directories (jest-runner@30.0.5)
|
||||
const shouldRecurse = entry.name === 'node_modules' ||
|
||||
entry.name.startsWith('@') ||
|
||||
entry.name === '.pnpm' || entry.name === '.yarn' ||
|
||||
entry.name.startsWith('jest-runner@');
|
||||
|
||||
if (shouldRecurse) {
|
||||
const result = search(entryPath, depth + 1);
|
||||
if (result) return result;
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
// Ignore permission errors
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
const packageJsonPath = path.join(jestRunnerPath, 'package.json');
|
||||
return fs.existsSync(packageJsonPath);
|
||||
|
||||
return search(nodeModulesPath, 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve jest-runner with monorepo support.
|
||||
* Uses CODEFLASH_MONOREPO_ROOT environment variable if available,
|
||||
* otherwise walks up the directory tree looking for node_modules/jest-runner.
|
||||
* Resolve jest-runner from the PROJECT's node_modules (not codeflash's).
|
||||
*
|
||||
* Uses recursive search to find jest-runner anywhere in node_modules,
|
||||
* working with any package manager (npm, yarn, pnpm).
|
||||
*
|
||||
* @returns {string} Path to jest-runner package
|
||||
* @throws {Error} If jest-runner cannot be found
|
||||
*/
|
||||
function resolveJestRunner() {
|
||||
// Try standard resolution first (works in simple projects)
|
||||
try {
|
||||
return require.resolve('jest-runner');
|
||||
} catch (e) {
|
||||
// Standard resolution failed - try monorepo-aware resolution
|
||||
}
|
||||
|
||||
// If Python detected a monorepo root, check there first
|
||||
const monorepoRoot = process.env.CODEFLASH_MONOREPO_ROOT;
|
||||
if (monorepoRoot) {
|
||||
const jestRunnerPath = path.join(monorepoRoot, 'node_modules', 'jest-runner');
|
||||
if (isValidJestRunnerPath(jestRunnerPath)) {
|
||||
return jestRunnerPath;
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: Walk up from cwd looking for node_modules/jest-runner
|
||||
const monorepoMarkers = ['yarn.lock', 'pnpm-workspace.yaml', 'lerna.json', 'package-lock.json'];
|
||||
|
||||
// Walk up from cwd to find all potential node_modules locations
|
||||
let currentDir = process.cwd();
|
||||
const visitedDirs = new Set();
|
||||
|
||||
// If Python detected a monorepo root, check there first
|
||||
const monorepoRoot = process.env.CODEFLASH_MONOREPO_ROOT;
|
||||
if (monorepoRoot && !visitedDirs.has(monorepoRoot)) {
|
||||
visitedDirs.add(monorepoRoot);
|
||||
const result = findJestRunnerRecursive(path.join(monorepoRoot, 'node_modules'));
|
||||
if (result) return result;
|
||||
}
|
||||
|
||||
while (currentDir !== path.dirname(currentDir)) {
|
||||
// Avoid infinite loops
|
||||
if (visitedDirs.has(currentDir)) break;
|
||||
visitedDirs.add(currentDir);
|
||||
|
||||
// Try node_modules/jest-runner at this level
|
||||
const jestRunnerPath = path.join(currentDir, 'node_modules', 'jest-runner');
|
||||
if (isValidJestRunnerPath(jestRunnerPath)) {
|
||||
return jestRunnerPath;
|
||||
}
|
||||
const result = findJestRunnerRecursive(path.join(currentDir, 'node_modules'));
|
||||
if (result) return result;
|
||||
|
||||
// Check if this is a workspace root (has monorepo markers)
|
||||
// Check if this is a workspace root - stop after this
|
||||
const isWorkspaceRoot = monorepoMarkers.some(marker =>
|
||||
fs.existsSync(path.join(currentDir, marker))
|
||||
);
|
||||
|
||||
if (isWorkspaceRoot) {
|
||||
// Found workspace root but no jest-runner - stop searching
|
||||
break;
|
||||
}
|
||||
|
||||
if (isWorkspaceRoot) break;
|
||||
currentDir = path.dirname(currentDir);
|
||||
}
|
||||
|
||||
|
|
@ -120,10 +164,15 @@ let jestVersion = 0;
|
|||
|
||||
try {
|
||||
const jestRunnerPath = resolveJestRunner();
|
||||
const internalRequire = createRequire(jestRunnerPath);
|
||||
|
||||
// Try to get the TestRunner class (Jest 30+)
|
||||
const jestRunner = internalRequire(jestRunnerPath);
|
||||
// Read the package.json to find the actual entry point and version
|
||||
const pkgJsonPath = path.join(jestRunnerPath, 'package.json');
|
||||
const pkgJson = JSON.parse(fs.readFileSync(pkgJsonPath, 'utf8'));
|
||||
|
||||
// Require using the full path to the entry point
|
||||
const entryPoint = path.join(jestRunnerPath, pkgJson.main || 'build/index.js');
|
||||
const jestRunner = require(entryPoint);
|
||||
|
||||
TestRunner = jestRunner.default || jestRunner.TestRunner;
|
||||
|
||||
if (TestRunner && TestRunner.prototype && typeof TestRunner.prototype.runTests === 'function') {
|
||||
|
|
@ -131,9 +180,11 @@ try {
|
|||
jestVersion = 30;
|
||||
jestRunnerAvailable = true;
|
||||
} else {
|
||||
// Try Jest 29 style import
|
||||
// Try Jest 29 style import - runTest is in build/runTest.js
|
||||
try {
|
||||
runTest = internalRequire('./runTest').default;
|
||||
const runTestPath = path.join(jestRunnerPath, 'build', 'runTest.js');
|
||||
const runTestModule = require(runTestPath);
|
||||
runTest = runTestModule.default;
|
||||
if (typeof runTest === 'function') {
|
||||
// Jest 29 - use direct runTest function
|
||||
jestVersion = 29;
|
||||
|
|
@ -141,17 +192,23 @@ try {
|
|||
}
|
||||
} catch (e29) {
|
||||
// Neither Jest 29 nor 30 style import worked
|
||||
const errorMsg = `Found jest-runner at ${jestRunnerPath} but could not load it. ` +
|
||||
`This may indicate an unsupported Jest version. ` +
|
||||
`Supported versions: Jest 29.x and Jest 30.x`;
|
||||
console.error(errorMsg);
|
||||
jestRunnerAvailable = false;
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
// jest-runner not installed - this is expected for Vitest projects
|
||||
// The runner will throw a helpful error if someone tries to use it without jest-runner
|
||||
jestRunnerAvailable = false;
|
||||
// try to directly import jest-runner
|
||||
try {
|
||||
const jestRunner = require('jest-runner');
|
||||
TestRunner = jestRunner.default || jestRunner.TestRunner;
|
||||
if (TestRunner && TestRunner.prototype && typeof TestRunner.prototype.runTests === 'function') {
|
||||
jestVersion = 30;
|
||||
jestRunnerAvailable = true;
|
||||
} else {
|
||||
jestRunnerAvailable = false;
|
||||
}
|
||||
} catch (e2) {
|
||||
jestRunnerAvailable = false;
|
||||
}
|
||||
}
|
||||
|
||||
// Configuration
|
||||
|
|
@ -233,15 +290,12 @@ class CodeflashLoopRunner {
|
|||
this._context = context || {};
|
||||
this._eventEmitter = new SimpleEventEmitter();
|
||||
|
||||
// For Jest 30+, create an instance of the base TestRunner for delegation
|
||||
if (jestVersion >= 30) {
|
||||
if (!TestRunner) {
|
||||
throw new Error(
|
||||
`Jest ${jestVersion} detected but TestRunner class not available. ` +
|
||||
`This indicates an internal error in loop-runner initialization.`
|
||||
);
|
||||
}
|
||||
this._baseRunner = new TestRunner(globalConfig, context);
|
||||
// For Jest 30+, verify TestRunner is available (we create fresh instances per batch)
|
||||
if (jestVersion >= 30 && !TestRunner) {
|
||||
throw new Error(
|
||||
`Jest ${jestVersion} detected but TestRunner class not available. ` +
|
||||
`This indicates an internal error in loop-runner initialization.`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -270,7 +324,7 @@ class CodeflashLoopRunner {
|
|||
* @param {Object} options - Jest runner options
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async runTests(tests, watcher, options) {
|
||||
async runTests(tests, watcher, ...rest) {
|
||||
const startTime = Date.now();
|
||||
let batchCount = 0;
|
||||
let hasFailure = false;
|
||||
|
|
@ -289,13 +343,11 @@ class CodeflashLoopRunner {
|
|||
|
||||
// Check time limit BEFORE each batch
|
||||
if (batchCount > MIN_BATCHES && checkTimeLimit()) {
|
||||
console.log(`[codeflash] Time limit reached after ${batchCount - 1} batches (${Date.now() - startTime}ms elapsed)`);
|
||||
break;
|
||||
}
|
||||
|
||||
// Check if interrupted
|
||||
if (watcher.isInterrupted()) {
|
||||
console.log(`[codeflash] Watcher is interrupted`)
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
@ -303,57 +355,54 @@ class CodeflashLoopRunner {
|
|||
process.env.CODEFLASH_PERF_CURRENT_BATCH = String(batchCount);
|
||||
|
||||
// Run all test files in this batch
|
||||
const batchResult = await this._runAllTestsOnce(tests, watcher, options);
|
||||
const batchResult = await this._runAllTestsOnce(tests, watcher, ...rest);
|
||||
allConsoleOutput += batchResult.consoleOutput;
|
||||
|
||||
// if (batchResult.hasFailure) {
|
||||
// hasFailure = true;
|
||||
// break;
|
||||
// }
|
||||
|
||||
// Check time limit AFTER each batch
|
||||
if (checkTimeLimit()) {
|
||||
console.log(`[codeflash] Time limit reached after ${batchCount} batches (${Date.now() - startTime}ms elapsed)`);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
const totalTimeMs = Date.now() - startTime;
|
||||
|
||||
console.log(`[codeflash] now: ${Date.now()}`)
|
||||
// Output all collected console logs - this is critical for timing marker extraction
|
||||
// The console output contains the !######...######! timing markers from capturePerf
|
||||
if (allConsoleOutput) {
|
||||
process.stdout.write(allConsoleOutput);
|
||||
}
|
||||
|
||||
console.log(`[codeflash] Batched runner completed: ${batchCount} batches, ${tests.length} test files, ${totalTimeMs}ms total`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Run all test files once (one batch).
|
||||
* Uses different approaches for Jest 29 vs Jest 30.
|
||||
*/
|
||||
async _runAllTestsOnce(tests, watcher, options) {
|
||||
async _runAllTestsOnce(tests, watcher, ...args) {
|
||||
if (jestVersion >= 30) {
|
||||
return this._runAllTestsOnceJest30(tests, watcher, options);
|
||||
return this._runAllTestsOnceJest30(tests, watcher, ...args);
|
||||
} else {
|
||||
return this._runAllTestsOnceJest29(tests, watcher);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Jest 30+ implementation - delegates to base TestRunner and collects results.
|
||||
* Jest 30+ implementation - creates a fresh TestRunner for each batch to avoid
|
||||
* state corruption issues that occur when reusing runners across batches.
|
||||
*/
|
||||
async _runAllTestsOnceJest30(tests, watcher, options) {
|
||||
async _runAllTestsOnceJest30(tests, watcher, ...args) {
|
||||
let hasFailure = false;
|
||||
let allConsoleOutput = '';
|
||||
|
||||
// For Jest 30, we need to collect results through event listeners
|
||||
const resultsCollector = [];
|
||||
|
||||
// Subscribe to events from the base runner
|
||||
const unsubscribeSuccess = this._baseRunner.on('test-file-success', (testData) => {
|
||||
// Create a FRESH TestRunner instance for each batch
|
||||
// Jest 30's TestRunner corrupts its internal state after running tests,
|
||||
// so we cannot reuse the same instance across multiple batches
|
||||
const batchRunner = new TestRunner(this._globalConfig, this._context);
|
||||
|
||||
// Subscribe to events from the batch runner
|
||||
const unsubscribeSuccess = batchRunner.on('test-file-success', (testData) => {
|
||||
const [test, result] = testData;
|
||||
resultsCollector.push({ test, result, success: true });
|
||||
|
||||
|
|
@ -369,7 +418,7 @@ class CodeflashLoopRunner {
|
|||
this._eventEmitter.emit('test-file-success', testData);
|
||||
});
|
||||
|
||||
const unsubscribeFailure = this._baseRunner.on('test-file-failure', (testData) => {
|
||||
const unsubscribeFailure = batchRunner.on('test-file-failure', (testData) => {
|
||||
const [test, error] = testData;
|
||||
resultsCollector.push({ test, error, success: false });
|
||||
hasFailure = true;
|
||||
|
|
@ -378,14 +427,14 @@ class CodeflashLoopRunner {
|
|||
this._eventEmitter.emit('test-file-failure', testData);
|
||||
});
|
||||
|
||||
const unsubscribeStart = this._baseRunner.on('test-file-start', (testData) => {
|
||||
const unsubscribeStart = batchRunner.on('test-file-start', (testData) => {
|
||||
// Forward to our event emitter
|
||||
this._eventEmitter.emit('test-file-start', testData);
|
||||
});
|
||||
|
||||
try {
|
||||
// Run tests using the base runner (always serial for benchmarking)
|
||||
await this._baseRunner.runTests(tests, watcher, { ...options, serial: true });
|
||||
// Run tests using the fresh batch runner (always serial for benchmarking)
|
||||
await batchRunner.runTests(tests, watcher, ...args);
|
||||
} finally {
|
||||
// Cleanup subscriptions
|
||||
if (typeof unsubscribeSuccess === 'function') unsubscribeSuccess();
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ description = "Client for codeflash.ai - automatic code performance optimization
|
|||
authors = [{ name = "CodeFlash Inc.", email = "contact@codeflash.ai" }]
|
||||
requires-python = ">=3.9"
|
||||
readme = "README.md"
|
||||
license = {text = "BSL-1.1"}
|
||||
license-files = ["LICENSE"]
|
||||
keywords = [
|
||||
"codeflash",
|
||||
"performance",
|
||||
|
|
@ -356,4 +356,3 @@ markers = [
|
|||
[build-system]
|
||||
requires = ["hatchling", "uv-dynamic-versioning"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.context.code_context_extractor import get_code_optimization_context
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ def run_test(expected_improvement_pct: int) -> bool:
|
|||
CoverageExpectation(
|
||||
function_name="retry_with_backoff",
|
||||
expected_coverage=100.0,
|
||||
expected_lines=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
|
||||
expected_lines=[9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
|
||||
)
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,9 @@ from pathlib import Path
|
|||
import pytest
|
||||
|
||||
from codeflash.code_utils.instrument_existing_tests import (
|
||||
ASYNC_HELPER_FILENAME,
|
||||
add_async_decorator_to_function,
|
||||
get_decorator_name_for_mode,
|
||||
inject_profiling_into_existing_test,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
|
@ -55,16 +57,23 @@ async def test_async_sort():
|
|||
func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True)
|
||||
|
||||
# For async functions, instrument the source module directly with decorators
|
||||
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
|
||||
source_success = add_async_decorator_to_function(
|
||||
fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path
|
||||
)
|
||||
|
||||
assert source_success
|
||||
|
||||
# Verify the file was modified
|
||||
# Verify the file was modified with exact expected output
|
||||
instrumented_source = fto_path.read_text("utf-8")
|
||||
assert (
|
||||
'''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_behavior_async\n\n\n@codeflash_behavior_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n'''
|
||||
in instrumented_source
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
||||
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
|
||||
decorated_original = original_code.replace(
|
||||
"async def async_sorter", f"@{decorator_name}\nasync def async_sorter"
|
||||
)
|
||||
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{decorated_original}"
|
||||
expected = sort_imports(code=code_with_import, float_to_top=True)
|
||||
assert instrumented_source.strip() == expected.strip()
|
||||
|
||||
# Add codeflash capture
|
||||
instrument_codeflash_capture(func, {}, tests_root)
|
||||
|
|
@ -142,6 +151,9 @@ async def test_async_sort():
|
|||
test_path.unlink()
|
||||
if test_path_perf.exists():
|
||||
test_path_perf.unlink()
|
||||
helper_path = project_root_path / ASYNC_HELPER_FILENAME
|
||||
if helper_path.exists():
|
||||
helper_path.unlink()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
|
|
@ -182,7 +194,9 @@ async def test_async_class_sort():
|
|||
is_async=True,
|
||||
)
|
||||
|
||||
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
|
||||
source_success = add_async_decorator_to_function(
|
||||
fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path
|
||||
)
|
||||
|
||||
assert source_success
|
||||
|
||||
|
|
@ -264,6 +278,9 @@ async def test_async_class_sort():
|
|||
test_path.unlink()
|
||||
if test_path_perf.exists():
|
||||
test_path_perf.unlink()
|
||||
helper_path = project_root_path / ASYNC_HELPER_FILENAME
|
||||
if helper_path.exists():
|
||||
helper_path.unlink()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
|
|
@ -294,16 +311,23 @@ async def test_async_perf():
|
|||
func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True)
|
||||
|
||||
# Instrument the source module with async performance decorators
|
||||
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.PERFORMANCE)
|
||||
source_success = add_async_decorator_to_function(
|
||||
fto_path, func, TestingMode.PERFORMANCE, project_root=project_root_path
|
||||
)
|
||||
|
||||
assert source_success
|
||||
|
||||
# Verify the file was modified
|
||||
instrumented_source = fto_path.read_text("utf-8")
|
||||
assert (
|
||||
instrumented_source
|
||||
== '''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_performance_async\n\n\n@codeflash_performance_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n'''
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
||||
decorator_name = get_decorator_name_for_mode(TestingMode.PERFORMANCE)
|
||||
decorated_original = original_code.replace(
|
||||
"async def async_sorter", f"@{decorator_name}\nasync def async_sorter"
|
||||
)
|
||||
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{decorated_original}"
|
||||
expected = sort_imports(code=code_with_import, float_to_top=True)
|
||||
assert instrumented_source.strip() == expected.strip()
|
||||
|
||||
instrument_codeflash_capture(func, {}, tests_root)
|
||||
|
||||
|
|
@ -359,6 +383,9 @@ async def test_async_perf():
|
|||
# Clean up test files
|
||||
if test_path.exists():
|
||||
test_path.unlink()
|
||||
helper_path = project_root_path / ASYNC_HELPER_FILENAME
|
||||
if helper_path.exists():
|
||||
helper_path.unlink()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
|
|
@ -404,68 +431,24 @@ async def async_error_function(lst):
|
|||
function_name="async_error_function", parents=[], file_path=Path(fto_path), is_async=True
|
||||
)
|
||||
|
||||
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
|
||||
source_success = add_async_decorator_to_function(
|
||||
fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path
|
||||
)
|
||||
|
||||
assert source_success
|
||||
|
||||
# Verify the file was modified
|
||||
instrumented_source = fto_path.read_text("utf-8")
|
||||
|
||||
expected_instrumented_source = """import asyncio
|
||||
from typing import List, Union
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import \\
|
||||
codeflash_behavior_async
|
||||
|
||||
|
||||
async def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:
|
||||
\"\"\"
|
||||
Async bubble sort implementation for testing.
|
||||
\"\"\"
|
||||
print("codeflash stdout: Async sorting list")
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
n = len(lst)
|
||||
for i in range(n):
|
||||
for j in range(0, n - i - 1):
|
||||
if lst[j] > lst[j + 1]:
|
||||
lst[j], lst[j + 1] = lst[j + 1], lst[j]
|
||||
|
||||
result = lst.copy()
|
||||
print(f"result: {result}")
|
||||
return result
|
||||
|
||||
|
||||
class AsyncBubbleSorter:
|
||||
\"\"\"Class with async sorting method for testing.\"\"\"
|
||||
|
||||
async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:
|
||||
\"\"\"
|
||||
Async bubble sort implementation within a class.
|
||||
\"\"\"
|
||||
print("codeflash stdout: AsyncBubbleSorter.sorter() called")
|
||||
|
||||
# Add some async delay
|
||||
await asyncio.sleep(0.005)
|
||||
|
||||
n = len(lst)
|
||||
for i in range(n):
|
||||
for j in range(0, n - i - 1):
|
||||
if lst[j] > lst[j + 1]:
|
||||
lst[j], lst[j + 1] = lst[j + 1], lst[j]
|
||||
|
||||
result = lst.copy()
|
||||
return result
|
||||
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def async_error_function(lst):
|
||||
\"\"\"Async function that raises an error for testing.\"\"\"
|
||||
await asyncio.sleep(0.001) # Small delay
|
||||
raise ValueError("Test error")
|
||||
"""
|
||||
assert expected_instrumented_source == instrumented_source
|
||||
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
|
||||
decorated_modified = modified_code.replace(
|
||||
"async def async_error_function", f"@{decorator_name}\nasync def async_error_function"
|
||||
)
|
||||
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{decorated_modified}"
|
||||
expected = sort_imports(code=code_with_import, float_to_top=True)
|
||||
assert instrumented_source.strip() == expected.strip()
|
||||
instrument_codeflash_capture(func, {}, tests_root)
|
||||
|
||||
opt = Optimizer(
|
||||
|
|
@ -526,6 +509,9 @@ async def async_error_function(lst):
|
|||
test_path.unlink()
|
||||
if test_path_perf.exists():
|
||||
test_path_perf.unlink()
|
||||
helper_path = project_root_path / ASYNC_HELPER_FILENAME
|
||||
if helper_path.exists():
|
||||
helper_path.unlink()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
|
|
@ -563,7 +549,9 @@ async def test_async_multi():
|
|||
|
||||
func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True)
|
||||
|
||||
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
|
||||
source_success = add_async_decorator_to_function(
|
||||
fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path
|
||||
)
|
||||
|
||||
assert source_success
|
||||
instrument_codeflash_capture(func, {}, tests_root)
|
||||
|
|
@ -636,6 +624,9 @@ async def test_async_multi():
|
|||
test_path.unlink()
|
||||
if test_path_perf.exists():
|
||||
test_path_perf.unlink()
|
||||
helper_path = project_root_path / ASYNC_HELPER_FILENAME
|
||||
if helper_path.exists():
|
||||
helper_path.unlink()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
|
|
@ -678,7 +669,9 @@ async def test_async_edge_cases():
|
|||
|
||||
func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True)
|
||||
|
||||
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
|
||||
source_success = add_async_decorator_to_function(
|
||||
fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path
|
||||
)
|
||||
|
||||
assert source_success
|
||||
instrument_codeflash_capture(func, {}, tests_root)
|
||||
|
|
@ -753,6 +746,9 @@ async def test_async_edge_cases():
|
|||
test_path.unlink()
|
||||
if test_path_perf.exists():
|
||||
test_path_perf.unlink()
|
||||
helper_path = project_root_path / ASYNC_HELPER_FILENAME
|
||||
if helper_path.exists():
|
||||
helper_path.unlink()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
|
|
@ -988,7 +984,9 @@ async def test_mixed_sorting():
|
|||
function_name="async_merge_sort", parents=[], file_path=Path(mixed_fto_path), is_async=True
|
||||
)
|
||||
|
||||
source_success = add_async_decorator_to_function(mixed_fto_path, async_func, TestingMode.BEHAVIOR)
|
||||
source_success = add_async_decorator_to_function(
|
||||
mixed_fto_path, async_func, TestingMode.BEHAVIOR, project_root=project_root_path
|
||||
)
|
||||
|
||||
assert source_success
|
||||
|
||||
|
|
@ -1061,3 +1059,6 @@ async def test_mixed_sorting():
|
|||
test_path.unlink()
|
||||
if test_path_perf.exists():
|
||||
test_path_perf.unlink()
|
||||
helper_path = project_root_path / ASYNC_HELPER_FILENAME
|
||||
if helper_path.exists():
|
||||
helper_path.unlink()
|
||||
|
|
|
|||
|
|
@ -10,17 +10,15 @@ import pytest
|
|||
|
||||
from codeflash.code_utils.code_extractor import GlobalAssignmentCollector, add_global_assignments
|
||||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
||||
from codeflash.context.code_context_extractor import (
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.python.context.code_context_extractor import (
|
||||
collect_names_from_annotation,
|
||||
enrich_testgen_context,
|
||||
extract_classes_from_type_hint,
|
||||
extract_imports_for_class,
|
||||
get_code_optimization_context,
|
||||
get_external_base_class_inits,
|
||||
get_external_class_inits,
|
||||
get_imported_class_definitions,
|
||||
resolve_transitive_type_deps,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeString, CodeStringsMarkdown, FunctionParent
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
|
||||
|
|
@ -769,199 +767,6 @@ class HelperClass:
|
|||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
||||
def test_example_class_token_limit_1(tmp_path: Path) -> None:
|
||||
docstring_filler = " ".join(
|
||||
["This is a long docstring that will be used to fill up the token limit." for _ in range(1000)]
|
||||
)
|
||||
code = f"""
|
||||
class MyClass:
|
||||
\"\"\"A class with a helper method.
|
||||
{docstring_filler}\"\"\"
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
def target_method(self):
|
||||
\"\"\"Docstring for target method\"\"\"
|
||||
y = HelperClass().helper_method()
|
||||
|
||||
class HelperClass:
|
||||
\"\"\"A helper class for MyClass.\"\"\"
|
||||
def __init__(self):
|
||||
\"\"\"Initialize the HelperClass.\"\"\"
|
||||
self.x = 1
|
||||
def __repr__(self):
|
||||
\"\"\"Return a string representation of the HelperClass.\"\"\"
|
||||
return "HelperClass" + str(self.x)
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
"""
|
||||
# Create a temporary Python file using pytest's tmp_path fixture
|
||||
file_path = tmp_path / "test_code.py"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
opt = Optimizer(
|
||||
Namespace(
|
||||
project_root=file_path.parent.resolve(),
|
||||
disable_telemetry=True,
|
||||
tests_root="tests",
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
experiment_id=None,
|
||||
test_project_root=Path().resolve(),
|
||||
)
|
||||
)
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="target_method",
|
||||
file_path=file_path,
|
||||
parents=[FunctionParent(name="MyClass", type="ClassDef")],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
# In this scenario, the read-only code context is too long, so the read-only docstrings are removed.
|
||||
expected_read_write_context = f"""
|
||||
```python:{file_path.relative_to(opt.args.project_root)}
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
def target_method(self):
|
||||
\"\"\"Docstring for target method\"\"\"
|
||||
y = HelperClass().helper_method()
|
||||
|
||||
class HelperClass:
|
||||
def __init__(self):
|
||||
\"\"\"Initialize the HelperClass.\"\"\"
|
||||
self.x = 1
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
```
|
||||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:{file_path.relative_to(opt.args.project_root)}
|
||||
class MyClass:
|
||||
pass
|
||||
|
||||
class HelperClass:
|
||||
def __repr__(self):
|
||||
return "HelperClass" + str(self.x)
|
||||
```
|
||||
"""
|
||||
expected_hashing_context = f"""
|
||||
```python:{file_path.relative_to(opt.args.project_root)}
|
||||
class MyClass:
|
||||
|
||||
def target_method(self):
|
||||
y = HelperClass().helper_method()
|
||||
|
||||
class HelperClass:
|
||||
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
```
|
||||
"""
|
||||
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
||||
def test_example_class_token_limit_2(tmp_path: Path) -> None:
|
||||
string_filler = " ".join(
|
||||
["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
|
||||
)
|
||||
code = f"""
|
||||
class MyClass:
|
||||
\"\"\"A class with a helper method. \"\"\"
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
def target_method(self):
|
||||
\"\"\"Docstring for target method\"\"\"
|
||||
y = HelperClass().helper_method()
|
||||
x = '{string_filler}'
|
||||
|
||||
class HelperClass:
|
||||
\"\"\"A helper class for MyClass.\"\"\"
|
||||
def __init__(self):
|
||||
\"\"\"Initialize the HelperClass.\"\"\"
|
||||
self.x = 1
|
||||
def __repr__(self):
|
||||
\"\"\"Return a string representation of the HelperClass.\"\"\"
|
||||
return "HelperClass" + str(self.x)
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
"""
|
||||
# Create a temporary Python file using pytest's tmp_path fixture
|
||||
file_path = tmp_path / "test_code.py"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
opt = Optimizer(
|
||||
Namespace(
|
||||
project_root=file_path.parent.resolve(),
|
||||
disable_telemetry=True,
|
||||
tests_root="tests",
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
experiment_id=None,
|
||||
test_project_root=Path().resolve(),
|
||||
)
|
||||
)
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="target_method",
|
||||
file_path=file_path,
|
||||
parents=[FunctionParent(name="MyClass", type="ClassDef")],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root, 8000, 100000)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
# In this scenario, explicit token limits (8000 and 100000) are passed and the read-only code context is expected to be kept in full, docstrings included.
|
||||
expected_read_write_context = f"""
|
||||
```python:{file_path.relative_to(opt.args.project_root)}
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
def target_method(self):
|
||||
\"\"\"Docstring for target method\"\"\"
|
||||
y = HelperClass().helper_method()
|
||||
|
||||
class HelperClass:
|
||||
def __init__(self):
|
||||
\"\"\"Initialize the HelperClass.\"\"\"
|
||||
self.x = 1
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
```
|
||||
"""
|
||||
expected_read_only_context = f'''```python:{file_path.relative_to(opt.args.project_root)}
|
||||
class MyClass:
|
||||
"""A class with a helper method. """
|
||||
|
||||
class HelperClass:
|
||||
"""A helper class for MyClass."""
|
||||
def __repr__(self):
|
||||
"""Return a string representation of the HelperClass."""
|
||||
return "HelperClass" + str(self.x)
|
||||
```
|
||||
'''
|
||||
expected_hashing_context = f"""
|
||||
```python:{file_path.relative_to(opt.args.project_root)}
|
||||
class MyClass:
|
||||
|
||||
def target_method(self):
|
||||
y = HelperClass().helper_method()
|
||||
|
||||
class HelperClass:
|
||||
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
```
|
||||
"""
|
||||
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
||||
def test_example_class_token_limit_3(tmp_path: Path) -> None:
|
||||
string_filler = " ".join(
|
||||
["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
|
||||
|
|
@ -1009,7 +814,7 @@ class HelperClass:
|
|||
)
|
||||
# In this scenario, the read-writable code is too long, so we abort.
|
||||
with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"):
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
get_code_optimization_context(function_to_optimize, opt.args.project_root, optim_token_limit=8000)
|
||||
|
||||
|
||||
def test_example_class_token_limit_4(tmp_path: Path) -> None:
|
||||
|
|
@ -1062,7 +867,7 @@ class HelperClass:
|
|||
|
||||
# In this scenario, the read-writable code context becomes too large because the __init__ function is referencing the global x variable instead of the class attribute self.x, so we abort.
|
||||
with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"):
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
get_code_optimization_context(function_to_optimize, opt.args.project_root, optim_token_limit=8000)
|
||||
|
||||
|
||||
def test_example_class_token_limit_5(tmp_path: Path) -> None:
|
||||
|
|
@ -2422,7 +2227,7 @@ class OuterClass:
|
|||
assert "__init__" not in hashing_context # Should not contain __init__ methods
|
||||
|
||||
# Verify nested classes are excluded from the hashing context
|
||||
# The prune_cst_for_code_hashing function should not recurse into nested classes
|
||||
# The prune_cst function in hashing mode should not recurse into nested classes
|
||||
assert "class NestedClass:" not in hashing_context # Nested class definition should not be present
|
||||
|
||||
# The target method will reference NestedClass, but the actual nested class definition should not be included
|
||||
|
|
@ -3275,8 +3080,8 @@ def dump_layout(layout_type, layout):
|
|||
assert testgen_context.count("def __init__") >= 2, "Both __init__ methods should be in testgen context"
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_extracts_project_classes(tmp_path: Path) -> None:
|
||||
"""Test that get_imported_class_definitions extracts class definitions from project modules."""
|
||||
def test_enrich_testgen_context_extracts_project_classes(tmp_path: Path) -> None:
|
||||
"""Test that enrich_testgen_context extracts class definitions from project modules."""
|
||||
# Create a package structure with two modules
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
|
|
@ -3325,8 +3130,8 @@ class Accumulator:
|
|||
# Create CodeStringsMarkdown from the chunking module (simulating testgen context)
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=chunking_code, file_path=chunking_path)])
|
||||
|
||||
# Call get_imported_class_definitions
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
# Call enrich_testgen_context
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Verify Element class was extracted
|
||||
assert len(result.code_strings) == 1, "Should extract exactly one class (Element)"
|
||||
|
|
@ -3339,8 +3144,8 @@ class Accumulator:
|
|||
assert "import abc" in extracted_code, "Should include necessary imports for base class"
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_skips_existing_definitions(tmp_path: Path) -> None:
|
||||
"""Test that get_imported_class_definitions skips classes already defined in context."""
|
||||
def test_enrich_testgen_context_skips_existing_definitions(tmp_path: Path) -> None:
|
||||
"""Test that enrich_testgen_context skips classes already defined in context."""
|
||||
# Create a package structure
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
|
|
@ -3373,15 +3178,15 @@ class User:
|
|||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code_with_local_def, file_path=code_path)])
|
||||
|
||||
# Call get_imported_class_definitions
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
# Call enrich_testgen_context
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should NOT extract Element since it's already defined locally
|
||||
assert len(result.code_strings) == 0, "Should not extract classes already defined in context"
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_skips_third_party(tmp_path: Path) -> None:
|
||||
"""Test that get_imported_class_definitions skips third-party/stdlib imports."""
|
||||
def test_enrich_testgen_context_skips_third_party(tmp_path: Path) -> None:
|
||||
"""Test that enrich_testgen_context skips third-party/stdlib imports."""
|
||||
# Create a simple package
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
|
|
@ -3402,15 +3207,15 @@ class MyClass:
|
|||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
|
||||
# Call get_imported_class_definitions
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
# Call enrich_testgen_context
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should not extract any classes (Path, Optional, dataclass are stdlib/third-party)
|
||||
assert len(result.code_strings) == 0, "Should not extract stdlib/third-party classes"
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_handles_multiple_imports(tmp_path: Path) -> None:
|
||||
"""Test that get_imported_class_definitions handles multiple class imports."""
|
||||
def test_enrich_testgen_context_handles_multiple_imports(tmp_path: Path) -> None:
|
||||
"""Test that enrich_testgen_context handles multiple class imports."""
|
||||
# Create a package structure
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
|
|
@ -3446,8 +3251,8 @@ class Processor:
|
|||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
|
||||
# Call get_imported_class_definitions
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
# Call enrich_testgen_context
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should extract both TypeA and TypeB (but not TypeC since it's not imported)
|
||||
assert len(result.code_strings) == 2, "Should extract exactly two classes (TypeA, TypeB)"
|
||||
|
|
@ -3458,8 +3263,8 @@ class Processor:
|
|||
assert "class TypeC" not in all_extracted_code, "Should NOT contain TypeC (not imported)"
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_includes_dataclass_decorators(tmp_path: Path) -> None:
|
||||
"""Test that get_imported_class_definitions includes decorators when extracting dataclasses."""
|
||||
def test_enrich_testgen_context_includes_dataclass_decorators(tmp_path: Path) -> None:
|
||||
"""Test that enrich_testgen_context includes decorators when extracting dataclasses."""
|
||||
# Create a package structure
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
|
|
@ -3496,8 +3301,8 @@ class ConfigRegistry:
|
|||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
|
||||
# Call get_imported_class_definitions
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
# Call enrich_testgen_context
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should extract both LLMConfigBase (base class) and LLMConfig
|
||||
assert len(result.code_strings) == 2, "Should extract both LLMConfig and its base class LLMConfigBase"
|
||||
|
|
@ -3521,7 +3326,7 @@ class ConfigRegistry:
|
|||
assert "from dataclasses import" in all_extracted_code, "Should include dataclasses import"
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_extracts_imports_for_decorated_classes(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_extracts_imports_for_decorated_classes(tmp_path: Path) -> None:
|
||||
"""Test that extract_imports_for_class includes decorator and type annotation imports."""
|
||||
# Create a package structure
|
||||
package_dir = tmp_path / "mypackage"
|
||||
|
|
@ -3552,7 +3357,7 @@ def create_config() -> Config:
|
|||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert len(result.code_strings) == 1, "Should extract Config class"
|
||||
extracted_code = result.code_strings[0].code
|
||||
|
|
@ -3724,7 +3529,7 @@ class MyClass:
|
|||
assert result.count("from typing import Optional") == 1
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_multiple_decorators(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_multiple_decorators(tmp_path: Path) -> None:
|
||||
"""Test that classes with multiple decorators are extracted correctly."""
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
|
|
@ -3755,7 +3560,7 @@ def sort_configs(configs: list[OrderedConfig]) -> list[OrderedConfig]:
|
|||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert len(result.code_strings) == 1
|
||||
extracted_code = result.code_strings[0].code
|
||||
|
|
@ -3766,7 +3571,7 @@ def sort_configs(configs: list[OrderedConfig]) -> list[OrderedConfig]:
|
|||
assert "class OrderedConfig" in extracted_code
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_extracts_multilevel_inheritance(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_extracts_multilevel_inheritance(tmp_path: Path) -> None:
|
||||
"""Test that base classes are recursively extracted for multi-level inheritance.
|
||||
|
||||
This is critical for understanding dataclass constructor signatures, as fields
|
||||
|
|
@ -3826,8 +3631,8 @@ class ConfigRegistry:
|
|||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
|
||||
# Call get_imported_class_definitions
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
# Call enrich_testgen_context
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should extract 4 classes: GrandParentConfig, ParentConfig, ChildConfig, RouterConfig
|
||||
# (all classes needed to understand the full inheritance hierarchy)
|
||||
|
|
@ -3862,7 +3667,7 @@ class ConfigRegistry:
|
|||
assert "model_list: list" in all_extracted_code, "Should include model_list field from Router"
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_extracts_userdict(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_extracts_userdict(tmp_path: Path) -> None:
|
||||
"""Extracts __init__ from collections.UserDict when a class inherits from it."""
|
||||
code = """from collections import UserDict
|
||||
|
||||
|
|
@ -3873,7 +3678,7 @@ class MyCustomDict(UserDict):
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_base_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert len(result.code_strings) == 1
|
||||
code_string = result.code_strings[0]
|
||||
|
|
@ -3891,8 +3696,8 @@ class UserDict:
|
|||
assert code_string.file_path.as_posix().endswith("collections/__init__.py")
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_skips_project_classes(tmp_path: Path) -> None:
|
||||
"""Returns empty when base class is from the project, not external."""
|
||||
def test_enrich_testgen_context_skips_unresolvable_base_classes(tmp_path: Path) -> None:
|
||||
"""Returns empty when base class module cannot be resolved."""
|
||||
child_code = """from base import ProjectBase
|
||||
|
||||
class Child(ProjectBase):
|
||||
|
|
@ -3902,12 +3707,12 @@ class Child(ProjectBase):
|
|||
child_path.write_text(child_code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=child_code, file_path=child_path)])
|
||||
result = get_external_base_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert result.code_strings == []
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_skips_builtins(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_skips_builtin_base_classes(tmp_path: Path) -> None:
|
||||
"""Returns empty for builtin classes like list that have no inspectable source."""
|
||||
code = """class MyList(list):
|
||||
pass
|
||||
|
|
@ -3916,12 +3721,12 @@ def test_get_external_base_class_inits_skips_builtins(tmp_path: Path) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_base_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert result.code_strings == []
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_deduplicates(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_deduplicates(tmp_path: Path) -> None:
|
||||
"""Extracts the same external base class only once even when inherited multiple times."""
|
||||
code = """from collections import UserDict
|
||||
|
||||
|
|
@ -3935,7 +3740,7 @@ class MyDict2(UserDict):
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_base_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert len(result.code_strings) == 1
|
||||
expected_code = """\
|
||||
|
|
@ -3950,7 +3755,7 @@ class UserDict:
|
|||
assert result.code_strings[0].code == expected_code
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_empty_when_no_inheritance(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_empty_when_no_inheritance(tmp_path: Path) -> None:
|
||||
"""Returns empty when there are no external base classes."""
|
||||
code = """class SimpleClass:
|
||||
pass
|
||||
|
|
@ -3959,7 +3764,7 @@ def test_get_external_base_class_inits_empty_when_no_inheritance(tmp_path: Path)
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_base_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert result.code_strings == []
|
||||
|
||||
|
|
@ -4103,127 +3908,8 @@ class MyCustomDict(UserDict):
|
|||
assert "self.data = {}" in testgen_context, "UserDict __init__ body should be included"
|
||||
|
||||
|
||||
def test_read_only_code_removed_when_exceeds_limit(tmp_path: Path) -> None:
|
||||
"""Test read-only code is completely removed when it exceeds token limit even without docstrings.
|
||||
|
||||
This covers lines 152-153 in code_context_extractor.py where read_only_context_code is set
|
||||
to empty string when it still exceeds the token limit after docstring removal.
|
||||
"""
|
||||
# Create a second-degree helper with large implementation that has no docstrings
|
||||
# Second-degree helpers go into read-only context
|
||||
long_lines = [" x = 0"]
|
||||
for i in range(150):
|
||||
long_lines.append(f" x = x + {i}")
|
||||
long_lines.append(" return x")
|
||||
long_body = "\n".join(long_lines)
|
||||
|
||||
code = f"""
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
||||
def target_method(self):
|
||||
return first_helper()
|
||||
|
||||
|
||||
def first_helper():
|
||||
# First degree helper - calls second degree
|
||||
return second_helper()
|
||||
|
||||
|
||||
def second_helper():
|
||||
# Second degree helper - goes into read-only context
|
||||
{long_body}
|
||||
"""
|
||||
file_path = tmp_path / "test_code.py"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
func_to_optimize = FunctionToOptimize(
|
||||
function_name="target_method", file_path=file_path, parents=[FunctionParent(name="MyClass", type="ClassDef")]
|
||||
)
|
||||
|
||||
# Use a small optim_token_limit that allows read-writable but not read-only
|
||||
# Read-writable is ~48 tokens, read-only is ~600 tokens
|
||||
code_ctx = get_code_optimization_context(
|
||||
function_to_optimize=func_to_optimize,
|
||||
project_root_path=tmp_path,
|
||||
optim_token_limit=100, # Small limit to trigger read-only removal
|
||||
)
|
||||
|
||||
# The read-only context should be empty because it exceeded the limit
|
||||
assert code_ctx.read_only_context_code == "", "Read-only code should be removed when exceeding token limit"
|
||||
|
||||
|
||||
def test_testgen_removes_imported_classes_on_overflow(tmp_path: Path) -> None:
|
||||
"""Test testgen context removes imported class definitions when exceeding token limit.
|
||||
|
||||
This covers lines 176-186 in code_context_extractor.py where:
|
||||
- Testgen context exceeds limit (line 175)
|
||||
- Removing docstrings still exceeds (line 175 again)
|
||||
- Removing imported classes succeeds (line 177-183)
|
||||
"""
|
||||
# Create a package structure with a large type class used only in type annotations
|
||||
# This ensures get_imported_class_definitions extracts the full class
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||
|
||||
# Create a large class with methods that will be extracted via get_imported_class_definitions
|
||||
# Use methods WITHOUT docstrings so removing docstrings won't help much
|
||||
many_methods = "\n".join([f" def method_{i}(self):\n return {i}" for i in range(100)])
|
||||
type_class_code = f'''
|
||||
class TypeClass:
|
||||
"""A type class for annotations."""
|
||||
|
||||
def __init__(self, value: int):
|
||||
self.value = value
|
||||
|
||||
{many_methods}
|
||||
'''
|
||||
type_class_path = package_dir / "types.py"
|
||||
type_class_path.write_text(type_class_code, encoding="utf-8")
|
||||
|
||||
# Main module uses TypeClass only in annotation (not instantiated)
|
||||
# This triggers get_imported_class_definitions to extract the full class
|
||||
main_code = """
|
||||
from mypackage.types import TypeClass
|
||||
|
||||
def target_function(obj: TypeClass) -> int:
|
||||
return obj.value
|
||||
"""
|
||||
main_path = package_dir / "main.py"
|
||||
main_path.write_text(main_code, encoding="utf-8")
|
||||
|
||||
func_to_optimize = FunctionToOptimize(function_name="target_function", file_path=main_path, parents=[])
|
||||
|
||||
# Use a testgen_token_limit that:
|
||||
# - Is exceeded by full context with imported class (~1500 tokens)
|
||||
# - Is exceeded even after removing docstrings
|
||||
# - But fits when imported class is removed (~40 tokens)
|
||||
code_ctx = get_code_optimization_context(
|
||||
function_to_optimize=func_to_optimize,
|
||||
project_root_path=tmp_path,
|
||||
testgen_token_limit=200, # Small limit to trigger imported class removal
|
||||
)
|
||||
|
||||
# The testgen context should exist (didn't raise ValueError)
|
||||
testgen_context = code_ctx.testgen_context.markdown
|
||||
assert testgen_context, "Testgen context should not be empty"
|
||||
|
||||
# The target function should still be there
|
||||
assert "def target_function" in testgen_context, "Target function should be in testgen context"
|
||||
|
||||
# The large imported class should NOT be included (removed due to token limit)
|
||||
assert "class TypeClass" not in testgen_context, (
|
||||
"TypeClass should be removed from testgen context when exceeding token limit"
|
||||
)
|
||||
|
||||
|
||||
def test_testgen_raises_when_all_fallbacks_fail(tmp_path: Path) -> None:
|
||||
"""Test that ValueError is raised when testgen context exceeds limit even after all fallbacks.
|
||||
|
||||
This covers line 186 in code_context_extractor.py.
|
||||
"""
|
||||
def test_testgen_raises_when_exceeds_limit(tmp_path: Path) -> None:
|
||||
"""Test that ValueError is raised when testgen context exceeds token limit."""
|
||||
# Create a function with a very long body that exceeds limits even without imports/docstrings
|
||||
long_lines = [" x = 0"]
|
||||
for i in range(200):
|
||||
|
|
@ -4249,7 +3935,7 @@ def target_function():
|
|||
)
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_attribute_base(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_attribute_base(tmp_path: Path) -> None:
|
||||
"""Test handling of base class accessed as module.ClassName (ast.Attribute).
|
||||
|
||||
This covers line 616 in code_context_extractor.py.
|
||||
|
|
@ -4265,7 +3951,7 @@ class MyDict(UserDict):
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_base_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should extract UserDict __init__
|
||||
assert len(result.code_strings) == 1
|
||||
|
|
@ -4273,7 +3959,7 @@ class MyDict(UserDict):
|
|||
assert "def __init__" in result.code_strings[0].code
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_no_init_method(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_no_init_method(tmp_path: Path) -> None:
|
||||
"""Test handling when base class has no __init__ method.
|
||||
|
||||
This covers line 641 in code_context_extractor.py.
|
||||
|
|
@ -4288,7 +3974,7 @@ class MyProtocol(Protocol):
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_base_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Protocol's __init__ can't be easily inspected, should handle gracefully
|
||||
# Result may be empty or contain Protocol based on implementation
|
||||
|
|
@ -4377,7 +4063,7 @@ class MyClass:
|
|||
|
||||
|
||||
def test_imported_class_definitions_module_path_none(tmp_path: Path) -> None:
|
||||
"""Test handling when module_path is None in get_imported_class_definitions.
|
||||
"""Test handling when module_path is None in enrich_testgen_context.
|
||||
|
||||
This covers line 560 in code_context_extractor.py.
|
||||
"""
|
||||
|
|
@ -4393,123 +4079,12 @@ class MyClass:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should handle gracefully and return empty or partial results
|
||||
assert isinstance(result.code_strings, list)
|
||||
|
||||
|
||||
def test_get_imported_names_import_star(tmp_path: Path) -> None:
    """Check that a wildcard ``from ... import *`` is reported as the single name ``"*"``."""
    import libcst as cst

    from codeflash.context.code_context_extractor import get_imported_names

    # A bare "import *" is not valid Python, so the wildcard is exercised via a from-import.
    statement = cst.parse_statement("from os import *")
    assert isinstance(statement, cst.SimpleStatementLine)
    wildcard_import = statement.body[0]
    assert isinstance(wildcard_import, cst.ImportFrom)

    assert get_imported_names(wildcard_import) == {"*"}
|
||||
|
||||
|
||||
def test_get_imported_names_aliased_import(tmp_path: Path) -> None:
    """Check that aliased imports are reported under their alias.

    Both the plain ``import x as y`` form and the ``from m import x as y`` form are exercised.
    """
    import libcst as cst

    from codeflash.context.code_context_extractor import get_imported_names

    # (statement source, expected libcst node type, alias expected among the returned names)
    scenarios = (
        ("import numpy as np", cst.Import, "np"),
        ("from os import path as ospath", cst.ImportFrom, "ospath"),
    )
    for source, node_type, alias in scenarios:
        parsed = cst.parse_statement(source)
        assert isinstance(parsed, cst.SimpleStatementLine)
        node = parsed.body[0]
        assert isinstance(node, node_type)
        assert alias in get_imported_names(node)
|
||||
|
||||
|
||||
def test_get_imported_names_dotted_import(tmp_path: Path) -> None:
    """Check that a dotted ``import os.path`` includes the top-level name ``os``."""
    import libcst as cst

    from codeflash.context.code_context_extractor import get_imported_names

    parsed = cst.parse_statement("import os.path")
    assert isinstance(parsed, cst.SimpleStatementLine)
    dotted_import = parsed.body[0]
    assert isinstance(dotted_import, cst.Import)

    names = get_imported_names(dotted_import)
    assert "os" in names
|
||||
|
||||
|
||||
def test_used_name_collector_comprehensive(tmp_path: Path) -> None:
|
||||
"""Test UsedNameCollector handles various node types.
|
||||
|
||||
This covers lines 767-801 in code_context_extractor.py.
|
||||
"""
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.context.code_context_extractor import UsedNameCollector
|
||||
|
||||
code = """
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
x: int = 1
|
||||
y = os.path.join("a", "b")
|
||||
|
||||
class MyClass:
|
||||
z = 10
|
||||
|
||||
def my_func():
|
||||
pass
|
||||
"""
|
||||
module = cst.parse_module(code)
|
||||
collector = UsedNameCollector()
|
||||
# In libcst, the walker traverses the module
|
||||
cst.MetadataWrapper(module).visit(collector)
|
||||
|
||||
# Check used names
|
||||
assert "os" in collector.used_names
|
||||
assert "int" in collector.used_names
|
||||
assert "List" in collector.used_names
|
||||
|
||||
# Check defined names
|
||||
assert "x" in collector.defined_names
|
||||
assert "y" in collector.defined_names
|
||||
assert "MyClass" in collector.defined_names
|
||||
assert "my_func" in collector.defined_names
|
||||
|
||||
# Check external names (used but not defined)
|
||||
external = collector.get_external_names()
|
||||
assert "os" in external
|
||||
assert "x" not in external # x is defined
|
||||
|
||||
|
||||
def test_imported_class_with_base_in_same_module(tmp_path: Path) -> None:
|
||||
"""Test that imported classes with bases in the same module are extracted correctly.
|
||||
|
||||
|
|
@ -4549,52 +4124,13 @@ def target_function(obj: DerivedClass) -> bool:
|
|||
main_path.write_text(main_code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=main_code, file_path=main_path)])
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should extract the inheritance chain
|
||||
all_code = "\n".join(cs.code for cs in result.code_strings)
|
||||
assert "class BaseClass" in all_code or "class DerivedClass" in all_code
|
||||
|
||||
|
||||
def test_get_imported_names_from_import_without_alias(tmp_path: Path) -> None:
    """Check that an unaliased ``from m import a, b`` reports each imported name."""
    import libcst as cst

    from codeflash.context.code_context_extractor import get_imported_names

    parsed = cst.parse_statement("from os import path, getcwd")
    assert isinstance(parsed, cst.SimpleStatementLine)
    from_import = parsed.body[0]
    assert isinstance(from_import, cst.ImportFrom)

    names = get_imported_names(from_import)
    for expected in ("path", "getcwd"):
        assert expected in names
|
||||
|
||||
|
||||
def test_get_imported_names_regular_import(tmp_path: Path) -> None:
|
||||
"""Test get_imported_names handles regular imports.
|
||||
|
||||
This covers lines 814-815 in code_context_extractor.py.
|
||||
"""
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.context.code_context_extractor import get_imported_names
|
||||
|
||||
# Test regular import without alias
|
||||
import_stmt = cst.parse_statement("import json")
|
||||
assert isinstance(import_stmt, cst.SimpleStatementLine)
|
||||
import_node = import_stmt.body[0]
|
||||
assert isinstance(import_node, cst.Import)
|
||||
|
||||
result = get_imported_names(import_node)
|
||||
assert "json" in result
|
||||
|
||||
|
||||
def test_augmented_assignment_not_in_context(tmp_path: Path) -> None:
|
||||
"""Test that augmented assignments are handled but not included unless used.
|
||||
|
||||
|
|
@ -4625,7 +4161,7 @@ class MyClass:
|
|||
assert "counter" in read_writable
|
||||
|
||||
|
||||
def test_get_external_class_inits_extracts_click_option(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_extracts_click_option(tmp_path: Path) -> None:
|
||||
"""Extracts __init__ from click.Option when directly imported."""
|
||||
code = """from click import Option
|
||||
|
||||
|
|
@ -4636,7 +4172,7 @@ def my_func(opt: Option) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert len(result.code_strings) == 1
|
||||
code_string = result.code_strings[0]
|
||||
|
|
@ -4645,8 +4181,8 @@ def my_func(opt: Option) -> None:
|
|||
assert code_string.file_path is not None and "click" in code_string.file_path.as_posix()
|
||||
|
||||
|
||||
def test_get_external_class_inits_skips_project_classes(tmp_path: Path) -> None:
|
||||
"""Returns empty when imported class is from the project, not external."""
|
||||
def test_enrich_testgen_context_extracts_project_class_defs(tmp_path: Path) -> None:
|
||||
"""Extracts project class definitions via jedi resolution."""
|
||||
# Create a project module with a class
|
||||
(tmp_path / "mymodule.py").write_text("class ProjectClass:\n pass\n", encoding="utf-8")
|
||||
|
||||
|
|
@ -4659,12 +4195,13 @@ def my_func(obj: ProjectClass) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert result.code_strings == []
|
||||
assert len(result.code_strings) == 1
|
||||
assert "class ProjectClass" in result.code_strings[0].code
|
||||
|
||||
|
||||
def test_get_external_class_inits_skips_non_classes(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_skips_non_classes(tmp_path: Path) -> None:
|
||||
"""Returns empty when imported name is a function, not a class."""
|
||||
code = """from collections import OrderedDict
|
||||
from os.path import join
|
||||
|
|
@ -4676,7 +4213,7 @@ def my_func() -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# join is a function, not a class — should be skipped
|
||||
# OrderedDict is a class and should be included
|
||||
|
|
@ -4684,8 +4221,8 @@ def my_func() -> None:
|
|||
assert not any("join" in name for name in class_names)
|
||||
|
||||
|
||||
def test_get_external_class_inits_skips_already_defined_classes(tmp_path: Path) -> None:
|
||||
"""Skips classes already defined in the context (e.g., added by get_imported_class_definitions)."""
|
||||
def test_enrich_testgen_context_skips_already_defined_classes(tmp_path: Path) -> None:
|
||||
"""Skips classes already defined in the context (e.g., added by enrich_testgen_context)."""
|
||||
code = """from collections import UserDict
|
||||
|
||||
class UserDict:
|
||||
|
|
@ -4699,14 +4236,14 @@ def my_func(d: UserDict) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# UserDict is already defined in the context, so it should be skipped
|
||||
assert result.code_strings == []
|
||||
|
||||
|
||||
def test_get_external_class_inits_skips_builtins(tmp_path: Path) -> None:
|
||||
"""Returns empty for builtin classes like list/dict that have no inspectable source."""
|
||||
def test_enrich_testgen_context_skips_builtin_annotations(tmp_path: Path) -> None:
|
||||
"""Returns empty for builtin type annotations like list/dict that are not imported."""
|
||||
code = """x: list = []
|
||||
y: dict = {}
|
||||
|
||||
|
|
@ -4717,12 +4254,12 @@ def my_func() -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert result.code_strings == []
|
||||
|
||||
|
||||
def test_get_external_class_inits_skips_object_init(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_skips_object_init(tmp_path: Path) -> None:
|
||||
"""Skips classes whose __init__ is just object.__init__ (trivial)."""
|
||||
# enum.Enum has a metaclass-based __init__, but individual enum members
|
||||
# effectively use object.__init__. Use a class we know has object.__init__.
|
||||
|
|
@ -4735,14 +4272,14 @@ def my_func(q: QName) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# QName has its own __init__, so it should be included if it's in site-packages.
|
||||
# But since it's stdlib (not site-packages), it should be skipped.
|
||||
assert result.code_strings == []
|
||||
|
||||
|
||||
def test_get_external_class_inits_empty_when_no_imports(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_empty_when_no_imports(tmp_path: Path) -> None:
|
||||
"""Returns empty when there are no from-imports."""
|
||||
code = """def my_func() -> None:
|
||||
pass
|
||||
|
|
@ -4751,7 +4288,7 @@ def test_get_external_class_inits_empty_when_no_imports(tmp_path: Path) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert result.code_strings == []
|
||||
|
||||
|
|
@ -4840,17 +4377,17 @@ def test_resolve_transitive_type_deps_handles_failure_gracefully() -> None:
|
|||
"""Returns empty list for a class where get_type_hints fails."""
|
||||
|
||||
class BadClass:
|
||||
def __init__(self, x: "NonexistentType") -> None: # type: ignore[name-defined] # noqa: F821
|
||||
def __init__(self, x: NonexistentType) -> None: # type: ignore[name-defined] # noqa: F821
|
||||
pass
|
||||
|
||||
result = resolve_transitive_type_deps(BadClass)
|
||||
assert result == []
|
||||
|
||||
|
||||
# --- Integration tests for transitive resolution in get_external_class_inits ---
|
||||
# --- Integration tests for transitive resolution in enrich_testgen_context ---
|
||||
|
||||
|
||||
def test_get_external_class_inits_transitive_deps(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_transitive_deps(tmp_path: Path) -> None:
|
||||
"""Extracts transitive type dependencies from __init__ annotations."""
|
||||
code = """from click import Context
|
||||
|
||||
|
|
@ -4861,7 +4398,7 @@ def my_func(ctx: Context) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
class_names = {cs.code.split("\n")[0].replace("class ", "").rstrip(":") for cs in result.code_strings}
|
||||
assert "Context" in class_names
|
||||
|
|
@ -4869,7 +4406,7 @@ def my_func(ctx: Context) -> None:
|
|||
assert "Command" in class_names
|
||||
|
||||
|
||||
def test_get_external_class_inits_no_infinite_loops(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_no_infinite_loops(tmp_path: Path) -> None:
|
||||
"""Handles classes with circular type references without infinite loops."""
|
||||
# click.Context references Command, and Command references Context back
|
||||
# This should terminate without issues due to the processed_classes set
|
||||
|
|
@ -4882,13 +4419,13 @@ def my_func(ctx: Context) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should complete without hanging; just verify we got results
|
||||
assert len(result.code_strings) >= 1
|
||||
|
||||
|
||||
def test_get_external_class_inits_no_duplicate_stubs(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_no_duplicate_stubs(tmp_path: Path) -> None:
|
||||
"""Does not emit duplicate stubs for the same class name."""
|
||||
code = """from click import Context
|
||||
|
||||
|
|
@ -4899,7 +4436,7 @@ def my_func(ctx: Context) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
class_names = [cs.code.split("\n")[0].replace("class ", "").rstrip(":") for cs in result.code_strings]
|
||||
assert len(class_names) == len(set(class_names)), f"Duplicate class stubs found: {class_names}"
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from textwrap import dedent
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.context.code_context_extractor import parse_code_and_prune_cst
|
||||
from codeflash.languages.python.context.code_context_extractor import parse_code_and_prune_cst
|
||||
from codeflash.models.models import CodeContextType
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from textwrap import dedent
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.context.code_context_extractor import parse_code_and_prune_cst
|
||||
from codeflash.languages.python.context.code_context_extractor import parse_code_and_prune_cst
|
||||
from codeflash.models.models import CodeContextType
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from textwrap import dedent
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.context.code_context_extractor import parse_code_and_prune_cst
|
||||
from codeflash.languages.python.context.code_context_extractor import parse_code_and_prune_cst
|
||||
from codeflash.models.models import CodeContextType
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,9 @@ from pathlib import Path
|
|||
import pytest
|
||||
|
||||
from codeflash.code_utils.instrument_existing_tests import (
|
||||
ASYNC_HELPER_FILENAME,
|
||||
add_async_decorator_to_function,
|
||||
get_decorator_name_for_mode,
|
||||
inject_profiling_into_existing_test,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
|
@ -57,20 +59,6 @@ def test_async_decorator_application_behavior_mode(temp_dir):
|
|||
async_function_code = '''
|
||||
import asyncio
|
||||
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Simple async function for testing."""
|
||||
await asyncio.sleep(0.01)
|
||||
return x * y
|
||||
'''
|
||||
|
||||
expected_decorated_code = '''
|
||||
import asyncio
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import \\
|
||||
codeflash_behavior_async
|
||||
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Simple async function for testing."""
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
@ -86,7 +74,16 @@ async def async_function(x: int, y: int) -> int:
|
|||
|
||||
assert decorator_added
|
||||
modified_code = test_file.read_text()
|
||||
assert modified_code.strip() == expected_decorated_code.strip()
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
||||
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
|
||||
code_with_decorator = async_function_code.replace(
|
||||
"async def async_function", f"@{decorator_name}\nasync def async_function"
|
||||
)
|
||||
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
|
||||
expected = sort_imports(code=code_with_import, float_to_top=True)
|
||||
assert modified_code.strip() == expected.strip()
|
||||
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
|
|
@ -94,20 +91,6 @@ def test_async_decorator_application_performance_mode(temp_dir):
|
|||
async_function_code = '''
|
||||
import asyncio
|
||||
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Simple async function for testing."""
|
||||
await asyncio.sleep(0.01)
|
||||
return x * y
|
||||
'''
|
||||
|
||||
expected_decorated_code = '''
|
||||
import asyncio
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import \\
|
||||
codeflash_performance_async
|
||||
|
||||
|
||||
@codeflash_performance_async
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Simple async function for testing."""
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
@ -123,7 +106,16 @@ async def async_function(x: int, y: int) -> int:
|
|||
|
||||
assert decorator_added
|
||||
modified_code = test_file.read_text()
|
||||
assert modified_code.strip() == expected_decorated_code.strip()
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
||||
decorator_name = get_decorator_name_for_mode(TestingMode.PERFORMANCE)
|
||||
code_with_decorator = async_function_code.replace(
|
||||
"async def async_function", f"@{decorator_name}\nasync def async_function"
|
||||
)
|
||||
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
|
||||
expected = sort_imports(code=code_with_import, float_to_top=True)
|
||||
assert modified_code.strip() == expected.strip()
|
||||
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
|
|
@ -132,20 +124,6 @@ def test_async_decorator_application_concurrency_mode(temp_dir):
|
|||
async_function_code = '''
|
||||
import asyncio
|
||||
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Simple async function for testing."""
|
||||
await asyncio.sleep(0.01)
|
||||
return x * y
|
||||
'''
|
||||
|
||||
expected_decorated_code = '''
|
||||
import asyncio
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import \\
|
||||
codeflash_concurrency_async
|
||||
|
||||
|
||||
@codeflash_concurrency_async
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Simple async function for testing."""
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
@ -161,7 +139,16 @@ async def async_function(x: int, y: int) -> int:
|
|||
|
||||
assert decorator_added
|
||||
modified_code = test_file.read_text()
|
||||
assert modified_code.strip() == expected_decorated_code.strip()
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
||||
decorator_name = get_decorator_name_for_mode(TestingMode.CONCURRENCY)
|
||||
code_with_decorator = async_function_code.replace(
|
||||
"async def async_function", f"@{decorator_name}\nasync def async_function"
|
||||
)
|
||||
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
|
||||
expected = sort_imports(code=code_with_import, float_to_top=True)
|
||||
assert modified_code.strip() == expected.strip()
|
||||
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
|
|
@ -182,27 +169,6 @@ class Calculator:
|
|||
return a - b
|
||||
'''
|
||||
|
||||
expected_decorated_code = '''
|
||||
import asyncio
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import \\
|
||||
codeflash_behavior_async
|
||||
|
||||
|
||||
class Calculator:
|
||||
"""Test class with async methods."""
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def async_method(self, a: int, b: int) -> int:
|
||||
"""Async method in class."""
|
||||
await asyncio.sleep(0.005)
|
||||
return a ** b
|
||||
|
||||
def sync_method(self, a: int, b: int) -> int:
|
||||
"""Sync method in class."""
|
||||
return a - b
|
||||
'''
|
||||
|
||||
test_file = temp_dir / "test_async.py"
|
||||
test_file.write_text(async_class_code)
|
||||
|
||||
|
|
@ -217,11 +183,21 @@ class Calculator:
|
|||
|
||||
assert decorator_added
|
||||
modified_code = test_file.read_text()
|
||||
assert modified_code.strip() == expected_decorated_code.strip()
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
||||
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
|
||||
code_with_decorator = async_class_code.replace(
|
||||
" async def async_method", f" @{decorator_name}\n async def async_method"
|
||||
)
|
||||
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
|
||||
expected = sort_imports(code=code_with_import, float_to_top=True)
|
||||
assert modified_code.strip() == expected.strip()
|
||||
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
def test_async_decorator_no_duplicate_application(temp_dir):
|
||||
# Case 1: Old-style import already present — injector should detect and skip
|
||||
already_decorated_code = '''
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async
|
||||
import asyncio
|
||||
|
|
@ -243,6 +219,30 @@ async def async_function(x: int, y: int) -> int:
|
|||
# Should not add duplicate decorator
|
||||
assert not decorator_added
|
||||
|
||||
# Case 2: Inline definition already present — injector should detect and skip
|
||||
already_inline_code = '''
|
||||
import asyncio
|
||||
|
||||
def codeflash_behavior_async(func):
|
||||
return func
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Already decorated async function."""
|
||||
await asyncio.sleep(0.01)
|
||||
return x * y
|
||||
'''
|
||||
|
||||
test_file2 = temp_dir / "test_async2.py"
|
||||
test_file2.write_text(already_inline_code)
|
||||
|
||||
func2 = FunctionToOptimize(function_name="async_function", file_path=test_file2, parents=[], is_async=True)
|
||||
|
||||
decorator_added2 = add_async_decorator_to_function(test_file2, func2, TestingMode.BEHAVIOR)
|
||||
|
||||
# Should not add duplicate decorator
|
||||
assert not decorator_added2
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
def test_inject_profiling_async_function_behavior_mode(temp_dir):
|
||||
|
|
@ -285,11 +285,18 @@ async def test_async_function():
|
|||
|
||||
assert source_success is True
|
||||
|
||||
# Verify the file was modified
|
||||
# Verify the file was modified with exact expected output
|
||||
instrumented_source = source_file.read_text()
|
||||
assert "@codeflash_behavior_async" in instrumented_source
|
||||
assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
|
||||
assert "codeflash_behavior_async" in instrumented_source
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
||||
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
|
||||
code_with_decorator = source_module_code.replace(
|
||||
"async def async_function", f"@{decorator_name}\nasync def async_function"
|
||||
)
|
||||
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
|
||||
expected = sort_imports(code=code_with_import, float_to_top=True)
|
||||
assert instrumented_source.strip() == expected.strip()
|
||||
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
|
||||
|
||||
success, instrumented_test_code = inject_profiling_into_existing_test(
|
||||
async_test_code, test_file, [CodePosition(8, 18), CodePosition(11, 19)], func, temp_dir, mode=TestingMode.BEHAVIOR
|
||||
|
|
@ -340,12 +347,18 @@ async def test_async_function():
|
|||
|
||||
assert source_success is True
|
||||
|
||||
# Verify the file was modified
|
||||
# Verify the file was modified with exact expected output
|
||||
instrumented_source = source_file.read_text()
|
||||
assert "@codeflash_performance_async" in instrumented_source
|
||||
# Check for the import with line continuation formatting
|
||||
assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
|
||||
assert "codeflash_performance_async" in instrumented_source
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
||||
decorator_name = get_decorator_name_for_mode(TestingMode.PERFORMANCE)
|
||||
code_with_decorator = source_module_code.replace(
|
||||
"async def async_function", f"@{decorator_name}\nasync def async_function"
|
||||
)
|
||||
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
|
||||
expected = sort_imports(code=code_with_import, float_to_top=True)
|
||||
assert instrumented_source.strip() == expected.strip()
|
||||
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
|
||||
|
||||
# Now test the full pipeline with source module path
|
||||
success, instrumented_test_code = inject_profiling_into_existing_test(
|
||||
|
|
@ -406,11 +419,16 @@ async def test_mixed_functions():
|
|||
|
||||
# Verify the file was modified
|
||||
instrumented_source = source_file.read_text()
|
||||
assert "@codeflash_behavior_async" in instrumented_source
|
||||
assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
|
||||
assert "codeflash_behavior_async" in instrumented_source
|
||||
# Sync function should remain unchanged
|
||||
assert "def sync_function(x: int, y: int) -> int:" in instrumented_source
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
||||
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
|
||||
code_with_decorator = source_module_code.replace(
|
||||
"async def async_function", f"@{decorator_name}\nasync def async_function"
|
||||
)
|
||||
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
|
||||
expected = sort_imports(code=code_with_import, float_to_top=True)
|
||||
assert instrumented_source.strip() == expected.strip()
|
||||
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
|
||||
|
||||
success, instrumented_test_code = inject_profiling_into_existing_test(
|
||||
mixed_test_code, test_file, [CodePosition(8, 18), CodePosition(11, 19)], async_func, temp_dir, mode=TestingMode.BEHAVIOR
|
||||
|
|
@ -446,24 +464,19 @@ class OuterClass:
|
|||
|
||||
decorator_added = add_async_decorator_to_function(test_file, func, TestingMode.BEHAVIOR)
|
||||
|
||||
expected_output = """import asyncio
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import \\
|
||||
codeflash_behavior_async
|
||||
|
||||
|
||||
class OuterClass:
|
||||
class InnerClass:
|
||||
@codeflash_behavior_async
|
||||
async def nested_async_method(self, x: int) -> int:
|
||||
\"\"\"Nested async method.\"\"\"
|
||||
await asyncio.sleep(0.001)
|
||||
return x * 2
|
||||
"""
|
||||
|
||||
assert decorator_added
|
||||
modified_code = test_file.read_text()
|
||||
assert modified_code.strip() == expected_output.strip()
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
||||
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
|
||||
code_with_decorator = nested_async_code.replace(
|
||||
" async def nested_async_method",
|
||||
f" @{decorator_name}\n async def nested_async_method",
|
||||
)
|
||||
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
|
||||
expected = sort_imports(code=code_with_import, float_to_top=True)
|
||||
assert modified_code.strip() == expected.strip()
|
||||
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
|
|
|
|||
|
|
@ -20,14 +20,12 @@ All assertions use strict string equality to verify exact extraction output.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.context.code_context_extractor import get_code_optimization_context_for_language
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context_for_language
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
|
|
@ -106,9 +106,9 @@ class TestJavaScriptCodeContext:
|
|||
def test_extract_code_context_for_javascript(self, js_project_dir):
|
||||
"""Test extracting code context for a JavaScript function."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.context.code_context_extractor import get_code_optimization_context
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
|
||||
|
||||
lang_current._current_language = Language.JAVASCRIPT
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ These tests verify the full optimization pipeline including:
|
|||
This is the JavaScript equivalent of test_instrument_tests.py for Python.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
|
@ -71,9 +70,9 @@ module.exports = { add };
|
|||
def test_code_context_preserves_language(self, tmp_path):
|
||||
"""Verify language is preserved in code context extraction."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.context.code_context_extractor import get_code_optimization_context
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
|
||||
|
||||
lang_current._current_language = Language.TYPESCRIPT
|
||||
|
||||
|
|
@ -164,7 +163,7 @@ export function add(a: number, b: number): number {
|
|||
|
||||
# Mock the AI service request
|
||||
ai_client = AiServiceClient()
|
||||
with patch.object(ai_client, 'make_ai_service_request') as mock_request:
|
||||
with patch.object(ai_client, "make_ai_service_request") as mock_request:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
|
|
@ -191,8 +190,8 @@ export function add(a: number, b: number): number {
|
|||
# Verify the request was made with correct language
|
||||
assert mock_request.called, "API request should have been made"
|
||||
call_args = mock_request.call_args
|
||||
payload = call_args[1].get('payload', call_args[0][1] if len(call_args[0]) > 1 else {})
|
||||
assert payload.get('language') == 'typescript', \
|
||||
payload = call_args[1].get("payload", call_args[0][1] if len(call_args[0]) > 1 else {})
|
||||
assert payload.get("language") == "typescript", \
|
||||
f"Expected language='typescript', got language='{payload.get('language')}'"
|
||||
|
||||
|
||||
|
|
@ -462,7 +461,7 @@ class TestHelperFunctionLanguageAttribute:
|
|||
"""Verify helper functions have language='javascript' for .js files."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current, get_language_support
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
|
||||
lang_current._current_language = Language.JAVASCRIPT
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ class TestTypeScriptFunctionDiscovery:
|
|||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".ts", mode="w", delete=False) as f:
|
||||
f.write("""
|
||||
f.write(r"""
|
||||
export function add(a: number, b: number): number {
|
||||
return a + b;
|
||||
}
|
||||
|
|
@ -123,9 +123,9 @@ class TestTypeScriptCodeContext:
|
|||
def test_extract_code_context_for_typescript(self, ts_project_dir):
|
||||
"""Test extracting code context for a TypeScript function."""
|
||||
skip_if_ts_not_supported()
|
||||
from codeflash.context.code_context_extractor import get_code_optimization_context
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
|
||||
|
||||
lang_current._current_language = Language.TYPESCRIPT
|
||||
|
||||
|
|
@ -201,7 +201,7 @@ function multiply(a: number, b: number): number {
|
|||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.base import FunctionInfo
|
||||
|
||||
original_source = """
|
||||
original_source = r"""
|
||||
interface Config {
|
||||
timeout: number;
|
||||
retries: number;
|
||||
|
|
@ -212,7 +212,7 @@ function processConfig(config: Config): string {
|
|||
}
|
||||
"""
|
||||
|
||||
new_function = """function processConfig(config: Config): string {
|
||||
new_function = r"""function processConfig(config: Config): string {
|
||||
// Optimized with template caching
|
||||
const { timeout, retries } = config;
|
||||
return `timeout=\${timeout}, retries=\${retries}`;
|
||||
|
|
|
|||
|
|
@ -117,10 +117,10 @@ class TestVitestCodeContext:
|
|||
def test_extract_code_context_for_typescript(self, vitest_project_dir):
|
||||
"""Test extracting code context for a TypeScript function."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.context.code_context_extractor import get_code_optimization_context
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
|
||||
|
||||
lang_current._current_language = Language.TYPESCRIPT
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
|
||||
|
||||
from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names
|
||||
from codeflash.languages.python.context.unused_definition_remover import remove_unused_definitions_by_function_names
|
||||
|
||||
|
||||
def test_variable_removal_only() -> None:
|
||||
|
|
|
|||
|
|
@ -5,8 +5,11 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.context.unused_definition_remover import detect_unused_helper_functions, revert_unused_helper_functions
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.python.context.unused_definition_remover import (
|
||||
detect_unused_helper_functions,
|
||||
revert_unused_helper_functions,
|
||||
)
|
||||
from codeflash.models.models import CodeStringsMarkdown
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
|
|
|||
Loading…
Reference in a new issue