Merge commit '6020c4fa' into sync-main-batch-3

This commit is contained in:
Kevin Turcios 2026-02-19 20:33:09 -05:00
commit 85d1d4fbf6
41 changed files with 1533 additions and 1974 deletions

View file

@ -42,11 +42,17 @@ jobs:
uv venv --seed
uv sync
- name: Configure AWS Credentials
uses: aws-actions/configure-aws-credentials@v4
with:
role-to-assume: ${{ secrets.AWS_ROLE_TO_ASSUME }}
aws-region: ${{ secrets.AWS_REGION }}
- name: Run Claude Code
id: claude
uses: anthropics/claude-code-action@v1
with:
use_foundry: "true"
use_bedrock: "true"
use_sticky_comment: true
allowed_bots: "claude[bot],codeflash-ai[bot]"
prompt: |
@ -173,12 +179,9 @@ jobs:
2. For each optimization PR:
- Check if CI is passing: `gh pr checks <number>`
- If all checks pass, merge it: `gh pr merge <number> --squash --delete-branch`
claude_args: '--model claude-opus-4-6 --allowedTools "mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*),Bash(gh pr checks:*),Bash(gh pr merge:*),Bash(gh issue view:*),Bash(gh issue list:*),Bash(gh api:*),Bash(uv run prek *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(uv run pytest *),Bash(git status*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git diff *),Bash(git checkout *),Read,Glob,Grep,Edit"'
claude_args: '--model us.anthropic.claude-opus-4-6-v1 --allowedTools "mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*),Bash(gh pr checks:*),Bash(gh pr merge:*),Bash(gh issue view:*),Bash(gh issue list:*),Bash(gh api:*),Bash(uv run prek *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(uv run pytest *),Bash(git status*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git diff *),Bash(git checkout *),Read,Glob,Grep,Edit"'
additional_permissions: |
actions: read
env:
ANTHROPIC_FOUNDRY_API_KEY: ${{ secrets.AZURE_ANTHROPIC_API_KEY }}
ANTHROPIC_FOUNDRY_BASE_URL: ${{ secrets.AZURE_ANTHROPIC_ENDPOINT }}
# @claude mentions (can edit and push) - restricted to maintainers only
claude-mention:
@ -240,14 +243,17 @@ jobs:
uv venv --seed
uv sync
- name: Configure AWS Credentials
uses: aws-actions/configure-aws-credentials@v4
with:
role-to-assume: ${{ secrets.AWS_ROLE_TO_ASSUME }}
aws-region: ${{ secrets.AWS_REGION }}
- name: Run Claude Code
id: claude
uses: anthropics/claude-code-action@v1
with:
use_foundry: "true"
claude_args: '--model claude-opus-4-6 --allowedTools "Read,Edit,Write,Glob,Grep,Bash(git status*),Bash(git diff*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git log*),Bash(git merge*),Bash(git fetch*),Bash(git checkout*),Bash(git branch*),Bash(uv run prek *),Bash(prek *),Bash(uv run ruff *),Bash(uv run pytest *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(gh pr comment*),Bash(gh pr view*),Bash(gh pr diff*),Bash(gh pr merge*),Bash(gh pr close*)"'
use_bedrock: "true"
claude_args: '--model us.anthropic.claude-opus-4-6-v1 --allowedTools "Read,Edit,Write,Glob,Grep,Bash(git status*),Bash(git diff*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git log*),Bash(git merge*),Bash(git fetch*),Bash(git checkout*),Bash(git branch*),Bash(uv run prek *),Bash(prek *),Bash(uv run ruff *),Bash(uv run pytest *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(gh pr comment*),Bash(gh pr view*),Bash(gh pr diff*),Bash(gh pr merge*),Bash(gh pr close*)"'
additional_permissions: |
actions: read
env:
ANTHROPIC_FOUNDRY_API_KEY: ${{ secrets.AZURE_ANTHROPIC_API_KEY }}
ANTHROPIC_FOUNDRY_BASE_URL: ${{ secrets.AZURE_ANTHROPIC_ENDPOINT }}

View file

@ -42,10 +42,16 @@ jobs:
}
EOF
- name: Configure AWS Credentials
uses: aws-actions/configure-aws-credentials@v4
with:
role-to-assume: ${{ secrets.AWS_ROLE_TO_ASSUME }}
aws-region: ${{ secrets.AWS_REGION }}
- name: Run Claude Code
uses: anthropics/claude-code-action@v1
with:
use_foundry: "true"
use_bedrock: "true"
use_sticky_comment: true
allowed_bots: "claude[bot],codeflash-ai[bot]"
claude_args: '--mcp-config /tmp/mcp-config/mcp-servers.json --allowedTools "Read,Glob,Grep,Bash(git diff:*),Bash(git log:*),Bash(git show:*),Bash(wc *),Bash(find *),mcp__serena__*"'
@ -105,10 +111,6 @@ jobs:
- Concrete refactoring suggestion
If no significant duplication is found, say so briefly. Do not create issues — just comment on the PR.
env:
ANTHROPIC_FOUNDRY_API_KEY: ${{ secrets.AZURE_ANTHROPIC_API_KEY }}
ANTHROPIC_FOUNDRY_BASE_URL: ${{ secrets.AZURE_ANTHROPIC_ENDPOINT }}
- name: Stop Serena
if: always()
run: docker stop serena && docker rm serena || true

View file

@ -1,50 +0,0 @@
name: JavaScript/TypeScript Integration Tests
on:
push:
branches:
- main
pull_request:
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.ref_name }}
cancel-in-progress: true
jobs:
js-integration-tests:
name: JS/TS Integration Tests
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '20'
- name: Install uv
uses: astral-sh/setup-uv@v6
- name: Install Python dependencies
run: |
uv venv --seed
uv sync
- name: Install npm dependencies for test projects
run: |
npm install --prefix code_to_optimize/js/code_to_optimize_js
npm install --prefix code_to_optimize/js/code_to_optimize_ts
npm install --prefix code_to_optimize/js/code_to_optimize_vitest
- name: Run JavaScript integration tests
run: |
uv run pytest tests/languages/javascript/ -v
uv run pytest tests/test_languages/test_vitest_e2e.py -v
uv run pytest tests/test_languages/test_javascript_e2e.py -v
uv run pytest tests/test_languages/test_javascript_support.py -v
uv run pytest tests/code_utils/test_config_js.py -v

2
.gitignore vendored
View file

@ -274,3 +274,5 @@ tessl.json
# Tessl auto-generates AGENTS.md on install; ignore to avoid cluttering git status
AGENTS.md
.serena/
.codeflash/

98
LICENSE Normal file
View file

@ -0,0 +1,98 @@
Business Source License 1.1
Parameters
Licensor: CodeFlash Inc.
Licensed Work: Codeflash Client version 0.20.x
The Licensed Work is (c) 2024 CodeFlash Inc.
Additional Use Grant: None. Production use of the Licensed Work is only permitted
if you have entered into a separate written agreement
with CodeFlash Inc. for production use in connection
with a subscription to CodeFlash's Code Optimization
Platform. Please visit codeflash.ai for further
information.
Change Date: 2030-01-26
Change License: MIT
Notice
The Business Source License (this document, or the “License”) is not an Open
Source license. However, the Licensed Work will eventually be made available
under an Open Source License, as stated in this License.
License text copyright (c) 2017 MariaDB Corporation Ab, All Rights Reserved.
“Business Source License” is a trademark of MariaDB Corporation Ab.
-----------------------------------------------------------------------------
Business Source License 1.1
Terms
The Licensor hereby grants you the right to copy, modify, create derivative
works, redistribute, and make non-production use of the Licensed Work. The
Licensor may make an Additional Use Grant, above, permitting limited
production use.
Effective on the Change Date, or the fourth anniversary of the first publicly
available distribution of a specific version of the Licensed Work under this
License, whichever comes first, the Licensor hereby grants you rights under
the terms of the Change License, and the rights granted in the paragraph
above terminate.
If your use of the Licensed Work does not comply with the requirements
currently in effect as described in this License, you must purchase a
commercial license from the Licensor, its affiliated entities, or authorized
resellers, or you must refrain from using the Licensed Work.
All copies of the original and modified Licensed Work, and derivative works
of the Licensed Work, are subject to this License. This License applies
separately for each version of the Licensed Work and the Change Date may vary
for each version of the Licensed Work released by Licensor.
You must conspicuously display this License on each original or modified copy
of the Licensed Work. If you receive the Licensed Work in original or
modified form from a third party, the terms and conditions set forth in this
License apply to your use of that work.
Any use of the Licensed Work in violation of this License will automatically
terminate your rights under this License for the current and all other
versions of the Licensed Work.
This License does not grant you any right in any trademark or logo of
Licensor or its affiliates (provided that you may use a trademark or logo of
Licensor as expressly required by this License).
TO THE EXTENT PERMITTED BY APPLICABLE LAW, THE LICENSED WORK IS PROVIDED ON
AN “AS IS” BASIS. LICENSOR HEREBY DISCLAIMS ALL WARRANTIES AND CONDITIONS,
EXPRESS OR IMPLIED, INCLUDING (WITHOUT LIMITATION) WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, AND
TITLE.
MariaDB hereby grants you permission to use this License's text to license
your works, and to refer to it using the trademark “Business Source License”,
as long as you comply with the Covenants of Licensor below.
Covenants of Licensor
In consideration of the right to use this License's text and the “Business
Source License” name and trademark, Licensor covenants to MariaDB, and to all
other recipients of the licensed work to be provided by Licensor:
1. To specify as the Change License the GPL Version 2.0 or any later version,
or a license that is compatible with GPL Version 2.0 or a later version,
where “compatible” means that software provided under the Change License can
be included in a program with software provided under GPL Version 2.0 or a
later version. Licensor may specify additional Change Licenses without
limitation.
2. To either: (a) specify an additional grant of rights to use that does not
impose any additional restriction on the right granted in this License, as
the Additional Use Grant; or (b) insert the text “None”.
3. To specify a Change Date.
4. Not to modify this License in any other way.

View file

@ -0,0 +1,98 @@
Business Source License 1.1
Parameters
Licensor: CodeFlash Inc.
Licensed Work: Codeflash Client version 0.20.x
The Licensed Work is (c) 2024 CodeFlash Inc.
Additional Use Grant: None. Production use of the Licensed Work is only permitted
if you have entered into a separate written agreement
with CodeFlash Inc. for production use in connection
with a subscription to CodeFlash's Code Optimization
Platform. Please visit codeflash.ai for further
information.
Change Date: 2030-01-26
Change License: MIT
Notice
The Business Source License (this document, or the “License”) is not an Open
Source license. However, the Licensed Work will eventually be made available
under an Open Source License, as stated in this License.
License text copyright (c) 2017 MariaDB Corporation Ab, All Rights Reserved.
“Business Source License” is a trademark of MariaDB Corporation Ab.
-----------------------------------------------------------------------------
Business Source License 1.1
Terms
The Licensor hereby grants you the right to copy, modify, create derivative
works, redistribute, and make non-production use of the Licensed Work. The
Licensor may make an Additional Use Grant, above, permitting limited
production use.
Effective on the Change Date, or the fourth anniversary of the first publicly
available distribution of a specific version of the Licensed Work under this
License, whichever comes first, the Licensor hereby grants you rights under
the terms of the Change License, and the rights granted in the paragraph
above terminate.
If your use of the Licensed Work does not comply with the requirements
currently in effect as described in this License, you must purchase a
commercial license from the Licensor, its affiliated entities, or authorized
resellers, or you must refrain from using the Licensed Work.
All copies of the original and modified Licensed Work, and derivative works
of the Licensed Work, are subject to this License. This License applies
separately for each version of the Licensed Work and the Change Date may vary
for each version of the Licensed Work released by Licensor.
You must conspicuously display this License on each original or modified copy
of the Licensed Work. If you receive the Licensed Work in original or
modified form from a third party, the terms and conditions set forth in this
License apply to your use of that work.
Any use of the Licensed Work in violation of this License will automatically
terminate your rights under this License for the current and all other
versions of the Licensed Work.
This License does not grant you any right in any trademark or logo of
Licensor or its affiliates (provided that you may use a trademark or logo of
Licensor as expressly required by this License).
TO THE EXTENT PERMITTED BY APPLICABLE LAW, THE LICENSED WORK IS PROVIDED ON
AN “AS IS” BASIS. LICENSOR HEREBY DISCLAIMS ALL WARRANTIES AND CONDITIONS,
EXPRESS OR IMPLIED, INCLUDING (WITHOUT LIMITATION) WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, AND
TITLE.
MariaDB hereby grants you permission to use this License's text to license
your works, and to refer to it using the trademark “Business Source License”,
as long as you comply with the Covenants of Licensor below.
Covenants of Licensor
In consideration of the right to use this License's text and the “Business
Source License” name and trademark, Licensor covenants to MariaDB, and to all
other recipients of the licensed work to be provided by Licensor:
1. To specify as the Change License the GPL Version 2.0 or any later version,
or a license that is compatible with GPL Version 2.0 or a later version,
where “compatible” means that software provided under the Change License can
be included in a program with software provided under GPL Version 2.0 or a
later version. Licensor may specify additional Change Licenses without
limitation.
2. To either: (a) specify an additional grant of rights to use that does not
impose any additional restriction on the right granted in this License, as
the Additional Use Grant; or (b) insert the text “None”.
3. To specify a Change Date.
4. Not to modify this License in any other way.

View file

@ -0,0 +1,15 @@
# CodeFlash Benchmark
A pytest benchmarking plugin for [CodeFlash](https://codeflash.ai) - automatic code performance optimization.
## Installation
```bash
pip install codeflash-benchmark
```
## Usage
This plugin provides benchmarking capabilities for pytest tests used by CodeFlash's optimization pipeline.
For more information, visit [codeflash.ai](https://codeflash.ai).

View file

@ -1,32 +1,32 @@
[project]
name = "codeflash-benchmark"
version = "0.2.0"
description = "Pytest benchmarking plugin for codeflash.ai - automatic code performance optimization"
authors = [{ name = "CodeFlash Inc.", email = "contact@codeflash.ai" }]
requires-python = ">=3.9"
readme = "README.md"
license = {text = "BSL-1.1"}
keywords = [
"codeflash",
"benchmark",
"pytest",
"performance",
"testing",
]
dependencies = [
"pytest>=7.0.0,!=8.3.4",
]
[project.urls]
Homepage = "https://codeflash.ai"
Repository = "https://github.com/codeflash-ai/codeflash-benchmark"
[project.entry-points.pytest11]
codeflash-benchmark = "codeflash_benchmark.plugin"
[build-system]
requires = ["setuptools>=45", "wheel"]
build-backend = "setuptools.build_meta"
[tool.setuptools]
packages = ["codeflash_benchmark"]
[project]
name = "codeflash-benchmark"
version = "0.2.0"
description = "Pytest benchmarking plugin for codeflash.ai - automatic code performance optimization"
authors = [{ name = "CodeFlash Inc.", email = "contact@codeflash.ai" }]
requires-python = ">=3.9"
readme = "README.md"
license-files = ["LICENSE"]
keywords = [
"codeflash",
"benchmark",
"pytest",
"performance",
"testing",
]
dependencies = [
"pytest>=7.0.0,!=8.3.4",
]
[project.urls]
Homepage = "https://codeflash.ai"
Repository = "https://github.com/codeflash-ai/codeflash-benchmark"
[project.entry-points.pytest11]
codeflash-benchmark = "codeflash_benchmark.plugin"
[build-system]
requires = ["setuptools>=45", "wheel"]
build-backend = "setuptools.build_meta"
[tool.setuptools]
packages = ["codeflash_benchmark"]

View file

@ -4,8 +4,8 @@ from enum import Enum
from typing import Any, Union
MAX_TEST_RUN_ITERATIONS = 5
OPTIMIZATION_CONTEXT_TOKEN_LIMIT = 16000
TESTGEN_CONTEXT_TOKEN_LIMIT = 16000
OPTIMIZATION_CONTEXT_TOKEN_LIMIT = 48000
TESTGEN_CONTEXT_TOKEN_LIMIT = 48000
INDIVIDUAL_TESTCASE_TIMEOUT = 15 # For Python pytest
JAVA_TESTCASE_TIMEOUT = 120 # Java Maven tests need more time due to startup overhead
MAX_FUNCTION_TEST_SECONDS = 60

View file

@ -1518,73 +1518,207 @@ class AsyncDecoratorAdder(cst.CSTTransformer):
return False
class AsyncDecoratorImportAdder(cst.CSTTransformer):
"""Transformer that adds the import for async decorators."""
ASYNC_HELPER_INLINE_CODE = """import asyncio
import gc
import os
import sqlite3
import time
from functools import wraps
from pathlib import Path
from tempfile import TemporaryDirectory
def __init__(self, mode: TestingMode = TestingMode.BEHAVIOR) -> None:
self.mode = mode
self.has_import = False
import dill as pickle
def _get_decorator_name(self) -> str:
"""Get the decorator name based on the testing mode."""
if self.mode == TestingMode.BEHAVIOR:
return "codeflash_behavior_async"
if self.mode == TestingMode.CONCURRENCY:
return "codeflash_concurrency_async"
return "codeflash_performance_async"
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
# Check if the async decorator import is already present
if (
isinstance(node.module, cst.Attribute)
and isinstance(node.module.value, cst.Attribute)
and isinstance(node.module.value.value, cst.Name)
and node.module.value.value.value == "codeflash"
and node.module.value.attr.value == "code_utils"
and node.module.attr.value == "codeflash_wrap_decorator"
and not isinstance(node.names, cst.ImportStar)
):
decorator_name = self._get_decorator_name()
for import_alias in node.names:
if import_alias.name.value == decorator_name:
self.has_import = True
def get_run_tmp_file(file_path):
if not hasattr(get_run_tmp_file, "tmpdir"):
get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
return Path(get_run_tmp_file.tmpdir.name) / file_path
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
# If the import is already there, don't add it again
if self.has_import:
return updated_node
# Choose import based on mode
decorator_name = self._get_decorator_name()
def extract_test_context_from_env():
test_module = os.environ["CODEFLASH_TEST_MODULE"]
test_class = os.environ.get("CODEFLASH_TEST_CLASS", None)
test_function = os.environ["CODEFLASH_TEST_FUNCTION"]
if test_module and test_function:
return (test_module, test_class if test_class else None, test_function)
raise RuntimeError(
"Test context environment variables not set - ensure tests are run through codeflash test runner"
)
# Parse the import statement into a CST node
import_node = cst.parse_statement(f"from codeflash.code_utils.codeflash_wrap_decorator import {decorator_name}")
# Add the import to the module's body
return updated_node.with_changes(body=[import_node, *list(updated_node.body)])
def codeflash_behavior_async(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
loop = asyncio.get_running_loop()
function_name = func.__name__
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
test_module_name, test_class_name, test_name = extract_test_context_from_env()
test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
if not hasattr(async_wrapper, "index"):
async_wrapper.index = {}
if test_id in async_wrapper.index:
async_wrapper.index[test_id] += 1
else:
async_wrapper.index[test_id] = 0
codeflash_test_index = async_wrapper.index[test_id]
invocation_id = f"{line_id}_{codeflash_test_index}"
class_prefix = (test_class_name + ".") if test_class_name else ""
test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}"
print(f"!$######{test_stdout_tag}######$!")
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite"))
codeflash_con = sqlite3.connect(db_path)
codeflash_cur = codeflash_con.cursor()
codeflash_cur.execute(
"CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, "
"test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, "
"runtime INTEGER, return_value BLOB, verification_type TEXT)"
)
exception = None
counter = loop.time()
gc.disable()
try:
ret = func(*args, **kwargs)
counter = loop.time()
return_value = await ret
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
except Exception as e:
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
exception = e
finally:
gc.enable()
print(f"!######{test_stdout_tag}######!")
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps((args, kwargs, return_value))
codeflash_cur.execute(
"INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
test_module_name,
test_class_name,
test_name,
function_name,
loop_index,
invocation_id,
codeflash_duration,
pickled_return_value,
"function_call",
),
)
codeflash_con.commit()
codeflash_con.close()
if exception:
raise exception
return return_value
return async_wrapper
def codeflash_performance_async(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
loop = asyncio.get_running_loop()
function_name = func.__name__
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
test_module_name, test_class_name, test_name = extract_test_context_from_env()
test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
if not hasattr(async_wrapper, "index"):
async_wrapper.index = {}
if test_id in async_wrapper.index:
async_wrapper.index[test_id] += 1
else:
async_wrapper.index[test_id] = 0
codeflash_test_index = async_wrapper.index[test_id]
invocation_id = f"{line_id}_{codeflash_test_index}"
class_prefix = (test_class_name + ".") if test_class_name else ""
test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}"
print(f"!$######{test_stdout_tag}######$!")
exception = None
counter = loop.time()
gc.disable()
try:
ret = func(*args, **kwargs)
counter = loop.time()
return_value = await ret
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
except Exception as e:
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
exception = e
finally:
gc.enable()
print(f"!######{test_stdout_tag}:{codeflash_duration}######!")
if exception:
raise exception
return return_value
return async_wrapper
def codeflash_concurrency_async(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
function_name = func.__name__
concurrency_factor = int(os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10"))
test_module_name = os.environ.get("CODEFLASH_TEST_MODULE", "")
test_class_name = os.environ.get("CODEFLASH_TEST_CLASS", "")
test_function = os.environ.get("CODEFLASH_TEST_FUNCTION", "")
loop_index = os.environ.get("CODEFLASH_LOOP_INDEX", "0")
gc.disable()
try:
seq_start = time.perf_counter_ns()
for _ in range(concurrency_factor):
result = await func(*args, **kwargs)
sequential_time = time.perf_counter_ns() - seq_start
finally:
gc.enable()
gc.disable()
try:
conc_start = time.perf_counter_ns()
tasks = [func(*args, **kwargs) for _ in range(concurrency_factor)]
await asyncio.gather(*tasks)
concurrent_time = time.perf_counter_ns() - conc_start
finally:
gc.enable()
tag = f"{test_module_name}:{test_class_name}:{test_function}:{function_name}:{loop_index}"
print(f"!@######CONC:{tag}:{sequential_time}:{concurrent_time}:{concurrency_factor}######@!")
return result
return async_wrapper
"""
ASYNC_HELPER_FILENAME = "codeflash_async_wrapper.py"
def get_decorator_name_for_mode(mode: TestingMode) -> str:
if mode == TestingMode.BEHAVIOR:
return "codeflash_behavior_async"
if mode == TestingMode.CONCURRENCY:
return "codeflash_concurrency_async"
return "codeflash_performance_async"
def write_async_helper_file(target_dir: Path) -> Path:
"""Write the async decorator helper file to the target directory."""
helper_path = target_dir / ASYNC_HELPER_FILENAME
if not helper_path.exists():
helper_path.write_text(ASYNC_HELPER_INLINE_CODE, "utf-8")
return helper_path
def add_async_decorator_to_function(
source_path: Path, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
source_path: Path,
function: FunctionToOptimize,
mode: TestingMode = TestingMode.BEHAVIOR,
project_root: Path | None = None,
) -> bool:
"""Add async decorator to an async function definition and write back to file.
Args:
----
source_path: Path to the source file to modify in-place.
function: The FunctionToOptimize object representing the target async function.
mode: The testing mode to determine which decorator to apply.
Returns:
-------
Boolean indicating whether the decorator was successfully added.
Writes a helper file containing the decorator implementation to project_root (or source directory
as fallback) and adds a standard import + decorator to the source file.
"""
if not function.is_async:
return False
try:
# Read source code
with source_path.open(encoding="utf8") as f:
source_code = f.read()
@ -1594,10 +1728,14 @@ def add_async_decorator_to_function(
decorator_transformer = AsyncDecoratorAdder(function, mode)
module = module.visit(decorator_transformer)
# Add the import if decorator was added
if decorator_transformer.added_decorator:
import_transformer = AsyncDecoratorImportAdder(mode)
module = module.visit(import_transformer)
# Write the helper file to project_root (on sys.path) or source dir as fallback
helper_dir = project_root if project_root is not None else source_path.parent
write_async_helper_file(helper_dir)
# Add the import via CST so sort_imports can place it correctly
decorator_name = get_decorator_name_for_mode(mode)
import_node = cst.parse_statement(f"from codeflash_async_wrapper import {decorator_name}")
module = module.with_changes(body=[import_node, *list(module.body)])
modified_code = sort_imports(code=module.code, float_to_top=True)
except Exception as e:

View file

@ -520,15 +520,6 @@ class LanguageSupport(Protocol):
"""
...
def get_comment_prefix(self) -> str:
"""Get the comment prefix for this language.
Returns:
Comment prefix (e.g., "//" for JS, "#" for Python).
"""
...
def find_test_root(self, project_root: Path) -> Path | None:
"""Find the test root directory for a project.

View file

@ -34,7 +34,7 @@ if TYPE_CHECKING:
from codeflash.languages.base import LanguageSupport
# Module-level singleton for the current language
_current_language: Language | None = None
_current_language: Language = Language.PYTHON
def current_language() -> Language:

View file

@ -527,10 +527,5 @@ def parse_jest_test_xml(
f"[LOOP-SUMMARY] Results loop_index: min={min_idx}, max={max_idx}, "
f"unique_count={len(unique_loop_indices)}, total_results={len(loop_indices)}"
)
if max_idx == 1 and len(loop_indices) > 1:
logger.warning(
f"[LOOP-WARNING] All {len(loop_indices)} results have loop_index=1. "
"Perf test markers may not have been parsed correctly."
)
return test_results

View file

@ -1805,15 +1805,6 @@ class JavaScriptSupport:
"""
return ".test.js"
def get_comment_prefix(self) -> str:
"""Get the comment prefix for JavaScript.
Returns:
JavaScript single-line comment prefix.
"""
return "//"
def find_test_root(self, project_root: Path) -> Path | None:
"""Find the test root directory for a JavaScript project.

View file

@ -803,8 +803,6 @@ def run_jest_behavioral_tests(
wall_clock_ns = time.perf_counter_ns() - start_time_ns
logger.debug(f"Jest behavioral tests completed in {wall_clock_ns / 1e9:.2f}s")
print(result.stdout)
return result_file_path, result, coverage_json_path, None
@ -1046,6 +1044,10 @@ def run_jest_benchmarking_tests(
# Create result with combined stdout
result = subprocess.CompletedProcess(args=result.args, returncode=result.returncode, stdout=stdout, stderr="")
if result.returncode != 0:
logger.info(f"Jest benchmarking failed with return code {result.returncode}")
logger.info(f"Jest benchmarking stdout: {result.stdout}")
logger.info(f"Jest benchmarking stderr: {result.stderr}")
except subprocess.TimeoutExpired:
logger.warning(f"Jest benchmarking timed out after {total_timeout}s")

View file

@ -15,6 +15,8 @@ from codeflash.languages import is_java, is_javascript
from codeflash.models.models import CodeString, CodeStringsMarkdown
if TYPE_CHECKING:
from collections.abc import Callable
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext, FunctionSource
@ -49,6 +51,69 @@ def extract_names_from_targets(target: cst.CSTNode) -> list[str]:
return names
def is_assignment_used(node: cst.CSTNode, definitions: dict[str, UsageInfo], name_prefix: str = "") -> bool:
if isinstance(node, cst.Assign):
for target in node.targets:
names = extract_names_from_targets(target.target)
for name in names:
lookup = f"{name_prefix}{name}" if name_prefix else name
if lookup in definitions and definitions[lookup].used_by_qualified_function:
return True
return False
if isinstance(node, (cst.AnnAssign, cst.AugAssign)):
names = extract_names_from_targets(node.target)
for name in names:
lookup = f"{name_prefix}{name}" if name_prefix else name
if lookup in definitions and definitions[lookup].used_by_qualified_function:
return True
return False
return False
def recurse_sections(
node: cst.CSTNode,
section_names: list[str],
prune_fn: Callable[[cst.CSTNode], tuple[cst.CSTNode | None, bool]],
keep_non_target_children: bool = False,
) -> tuple[cst.CSTNode | None, bool]:
updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {}
found_any_target = False
for section in section_names:
original_content = getattr(node, section, None)
if isinstance(original_content, (list, tuple)):
new_children = []
section_found_target = False
for child in original_content:
filtered, found_target = prune_fn(child)
if filtered:
new_children.append(filtered)
section_found_target |= found_target
if keep_non_target_children:
if section_found_target or new_children:
found_any_target |= section_found_target
updates[section] = new_children
elif section_found_target:
found_any_target = True
updates[section] = new_children
elif original_content is not None:
filtered, found_target = prune_fn(original_content)
if keep_non_target_children:
found_any_target |= found_target
if filtered:
updates[section] = filtered
elif found_target:
found_any_target = True
if filtered:
updates[section] = filtered
if keep_non_target_children:
if updates:
return node.with_changes(**updates), found_any_target
return None, False
if not found_any_target:
return None, False
return (node.with_changes(**updates) if updates else node), True
def collect_top_level_definitions(
node: cst.CSTNode, definitions: Optional[dict[str, UsageInfo]] = None
) -> dict[str, UsageInfo]:
@ -423,27 +488,9 @@ def remove_unused_definitions_recursively(
elif isinstance(statement, (cst.Assign, cst.AnnAssign, cst.AugAssign)):
var_used = False
# Check if any variable in this assignment is used
if isinstance(statement, cst.Assign):
for target in statement.targets:
names = extract_names_from_targets(target.target)
for name in names:
class_var_name = f"{class_name}.{name}"
if (
class_var_name in definitions
and definitions[class_var_name].used_by_qualified_function
):
var_used = True
method_or_var_used = True
break
elif isinstance(statement, (cst.AnnAssign, cst.AugAssign)):
names = extract_names_from_targets(statement.target)
for name in names:
class_var_name = f"{class_name}.{name}"
if class_var_name in definitions and definitions[class_var_name].used_by_qualified_function:
var_used = True
method_or_var_used = True
break
if is_assignment_used(statement, definitions, name_prefix=f"{class_name}."):
var_used = True
method_or_var_used = True
if var_used or class_has_dependencies:
new_statements.append(statement)
@ -459,56 +506,19 @@ def remove_unused_definitions_recursively(
return node, method_or_var_used or class_has_dependencies
# Handle assignments (Assign and AnnAssign)
if isinstance(node, cst.Assign):
for target in node.targets:
names = extract_names_from_targets(target.target)
for name in names:
if name in definitions and definitions[name].used_by_qualified_function:
return node, True
return None, False
if isinstance(node, (cst.AnnAssign, cst.AugAssign)):
names = extract_names_from_targets(node.target)
for name in names:
if name in definitions and definitions[name].used_by_qualified_function:
return node, True
# Handle assignments (Assign, AnnAssign, AugAssign)
if isinstance(node, (cst.Assign, cst.AnnAssign, cst.AugAssign)):
if is_assignment_used(node, definitions):
return node, True
return None, False
# For other nodes, recursively process children
section_names = get_section_names(node)
if not section_names:
return node, False
updates = {}
found_used = False
for section in section_names:
original_content = getattr(node, section, None)
if isinstance(original_content, (list, tuple)):
new_children = []
section_found_used = False
for child in original_content:
filtered, used = remove_unused_definitions_recursively(child, definitions)
if filtered:
new_children.append(filtered)
section_found_used |= used
if new_children or section_found_used:
found_used |= section_found_used
updates[section] = new_children
elif original_content is not None:
filtered, used = remove_unused_definitions_recursively(original_content, definitions)
found_used |= used
if filtered:
updates[section] = filtered
if not found_used:
return None, False
if updates:
return node.with_changes(**updates), found_used
return node, False
return recurse_sections(
node, section_names, lambda child: remove_unused_definitions_recursively(child, definitions)
)
def collect_top_level_defs_with_usages(

View file

@ -21,9 +21,25 @@ from codeflash.languages.registry import register_language
if TYPE_CHECKING:
from collections.abc import Sequence
from codeflash.models.models import FunctionSource
logger = logging.getLogger(__name__)
def function_sources_to_helpers(sources: list[FunctionSource]) -> list[HelperFunction]:
    """Convert ``FunctionSource`` records into ``HelperFunction`` models.

    Args:
        sources: Function sources produced by the jedi-based context pipeline.

    Returns:
        One ``HelperFunction`` per input source, preserving order.
    """
    helpers: list[HelperFunction] = []
    for source in sources:
        # jedi may fail to resolve a definition; fall back to line 1 in that case.
        # start_line and end_line are both set to the definition line — only the
        # starting line is available here. TODO confirm downstream consumers
        # tolerate a zero-length span.
        line_number = source.jedi_definition.line if source.jedi_definition else 1
        helpers.append(
            HelperFunction(
                name=source.only_function_name,
                qualified_name=source.qualified_name,
                file_path=source.file_path,
                source_code=source.source_code,
                start_line=line_number,
                end_line=line_number,
            )
        )
    return helpers
@register_language
class PythonSupport:
"""Python language support implementation.
@ -171,127 +187,39 @@ class PythonSupport:
# === Code Analysis ===
def extract_code_context(self, function: FunctionToOptimize, project_root: Path, module_root: Path) -> CodeContext:
"""Extract function code and its dependencies.
"""Extract function code and its dependencies via the canonical context pipeline."""
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
Uses jedi and libcst for Python code analysis.
Args:
function: The function to extract context for.
project_root: Root of the project.
module_root: Root of the module containing the function.
Returns:
CodeContext with target code and dependencies.
"""
try:
source = function.file_path.read_text()
result = get_code_optimization_context(function, project_root)
except Exception as e:
logger.exception("Failed to read %s: %s", function.file_path, e)
logger.warning("Failed to extract code context for %s: %s", function.function_name, e)
return CodeContext(target_code="", target_file=function.file_path, language=Language.PYTHON)
# Extract the function source
lines = source.splitlines(keepends=True)
if function.starting_line and function.ending_line:
target_lines = lines[function.starting_line - 1 : function.ending_line]
target_code = "".join(target_lines)
else:
target_code = ""
# Find helper functions
helpers = self.find_helper_functions(function, project_root)
# Extract imports
import_lines = []
for line in lines:
stripped = line.strip()
if stripped.startswith(("import ", "from ")):
import_lines.append(stripped)
elif stripped and not stripped.startswith("#"):
# Stop at first non-import, non-comment line
break
helpers = function_sources_to_helpers(result.helper_functions)
return CodeContext(
target_code=target_code,
target_code=result.read_writable_code.markdown,
target_file=function.file_path,
helper_functions=helpers,
read_only_context="",
imports=import_lines,
read_only_context=result.read_only_context_code,
imports=[],
language=Language.PYTHON,
)
def find_helper_functions(self, function: FunctionToOptimize, project_root: Path) -> list[HelperFunction]:
"""Find helper functions called by the target function.
Uses jedi for Python code analysis.
Args:
function: The target function to analyze.
project_root: Root of the project.
Returns:
List of HelperFunction objects.
"""
helpers: list[HelperFunction] = []
"""Find helper functions called by the target function via the canonical jedi pipeline."""
from codeflash.languages.python.context.code_context_extractor import get_function_sources_from_jedi
try:
import jedi
from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
from codeflash.optimization.function_context import belongs_to_function_qualified
script = jedi.Script(path=function.file_path, project=jedi.Project(path=project_root))
file_refs = script.get_names(all_scopes=True, definitions=False, references=True)
qualified_name = function.qualified_name
for ref in file_refs:
if not ref.full_name or not belongs_to_function_qualified(ref, qualified_name):
continue
try:
definitions = ref.goto(follow_imports=True, follow_builtin_imports=False)
except Exception:
continue
for definition in definitions:
definition_path = definition.module_path
if definition_path is None:
continue
# Check if it's a valid helper (in project, not in target function)
is_valid = (
str(definition_path).startswith(str(project_root))
and not path_belongs_to_site_packages(definition_path)
and definition.full_name
and not belongs_to_function_qualified(definition, qualified_name)
and definition.type == "function"
)
if is_valid:
helper_qualified_name = get_qualified_name(definition.module_name, definition.full_name)
# Get source code
try:
helper_source = definition.get_line_code()
except Exception:
helper_source = ""
helpers.append(
HelperFunction(
name=definition.name,
qualified_name=helper_qualified_name,
file_path=definition_path,
source_code=helper_source,
start_line=definition.line or 1,
end_line=definition.line or 1,
)
)
_dict, sources = get_function_sources_from_jedi(
{function.file_path: {function.qualified_name}}, project_root
)
except Exception as e:
logger.warning("Failed to find helpers for %s: %s", function.function_name, e)
return []
return helpers
return function_sources_to_helpers(sources)
def find_references(
self, function: FunctionToOptimize, project_root: Path, tests_root: Path | None = None, max_files: int = 500
@ -728,15 +656,6 @@ class PythonSupport:
"""
return ".py"
def get_comment_prefix(self) -> str:
"""Get the comment prefix for Python.
Returns:
Python single-line comment prefix.
"""
return "#"
def find_test_root(self, project_root: Path) -> Path | None:
"""Find the test root directory for a Python project.

View file

@ -12,6 +12,7 @@ from pathlib import Path
from typing import TYPE_CHECKING, Callable
import libcst as cst
from git import Repo as GitRepo
from rich.console import Group
from rich.panel import Panel
from rich.syntax import Syntax
@ -71,8 +72,6 @@ from codeflash.code_utils.line_profile_utils import add_decorator_imports, conta
from codeflash.code_utils.shell_utils import make_env_with_project_root
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.context import code_context_extractor
from codeflash.context.unused_definition_remover import detect_unused_helper_functions, revert_unused_helper_functions
from codeflash.discovery.functions_to_optimize import was_function_previously_optimized
from codeflash.either import Failure, Success, is_successful
from codeflash.languages import is_java, is_javascript, is_python
@ -80,6 +79,11 @@ from codeflash.languages.base import Language
from codeflash.languages.current import current_language_support, is_typescript
from codeflash.languages.javascript.module_system import detect_module_system
from codeflash.languages.javascript.test_runner import clear_created_config_files, get_created_config_files
from codeflash.languages.python.context import code_context_extractor
from codeflash.languages.python.context.unused_definition_remover import (
detect_unused_helper_functions,
revert_unused_helper_functions,
)
from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table, tree_to_markdown
from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId
from codeflash.models.ExperimentMetadata import ExperimentMetadata
@ -2231,6 +2235,7 @@ class FunctionOptimizer:
if self.args.override_fixtures:
restore_conftest(original_conftest_content)
cleanup_paths(paths_to_cleanup)
self.cleanup_async_helper_file()
return Failure(baseline_result.failure())
original_code_baseline, test_functions_to_remove = baseline_result.unwrap()
@ -2242,6 +2247,7 @@ class FunctionOptimizer:
if self.args.override_fixtures:
restore_conftest(original_conftest_content)
cleanup_paths(paths_to_cleanup)
self.cleanup_async_helper_file()
return Failure("The threshold for test confidence was not met.")
return Success(
@ -2411,7 +2417,10 @@ class FunctionOptimizer:
generated_tests_str = ""
code_lang = self.function_to_optimize.language
for test in generated_tests.generated_tests:
if map_gen_test_file_to_no_of_tests[test.behavior_file_path] > 0:
if any(
test_file.name == test.behavior_file_path.name and count > 0
for test_file, count in map_gen_test_file_to_no_of_tests.items()
):
formatted_generated_test = format_generated_code(
test.generated_original_test_source, self.args.formatter_cmds
)
@ -2551,11 +2560,11 @@ class FunctionOptimizer:
console.print(Panel(panel_content, title="Optimization Review", border_style=display_info[1]))
if raise_pr or staging_review:
data["root_dir"] = git_root_dir()
data["root_dir"] = git_root_dir(GitRepo(str(self.args.module_root), search_parent_directories=True))
if raise_pr and not staging_review and opt_review_result.review != "low":
# Ensure root_dir is set for PR creation (needed for async functions that skip opt_review)
if "root_dir" not in data:
data["root_dir"] = git_root_dir()
data["root_dir"] = git_root_dir(GitRepo(str(self.args.module_root), search_parent_directories=True))
data["git_remote"] = self.args.git_remote
# Remove language from data dict as check_create_pr doesn't accept it
pr_data = {k: v for k, v in data.items() if k != "language"}
@ -2610,6 +2619,13 @@ class FunctionOptimizer:
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
self.cleanup_async_helper_file()
def cleanup_async_helper_file(self) -> None:
    """Delete the generated async instrumentation helper file, if it exists.

    The helper module is only written for async optimization targets, so
    ``missing_ok=True`` makes this a no-op when no helper was created.
    """
    # Imported lazily — presumably to avoid a circular import at module load;
    # TODO confirm.
    from codeflash.code_utils.instrument_existing_tests import ASYNC_HELPER_FILENAME

    (self.project_root / ASYNC_HELPER_FILENAME).unlink(missing_ok=True)
def establish_original_code_baseline(
self,
@ -2627,7 +2643,10 @@ class FunctionOptimizer:
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
success = add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
self.function_to_optimize.file_path,
self.function_to_optimize,
TestingMode.BEHAVIOR,
project_root=self.project_root,
)
# Instrument codeflash capture
@ -2692,7 +2711,10 @@ class FunctionOptimizer:
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
self.function_to_optimize.file_path,
self.function_to_optimize,
TestingMode.PERFORMANCE,
project_root=self.project_root,
)
try:
@ -2866,7 +2888,10 @@ class FunctionOptimizer:
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
self.function_to_optimize.file_path,
self.function_to_optimize,
TestingMode.BEHAVIOR,
project_root=self.project_root,
)
try:
@ -2961,7 +2986,10 @@ class FunctionOptimizer:
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
self.function_to_optimize.file_path,
self.function_to_optimize,
TestingMode.PERFORMANCE,
project_root=self.project_root,
)
try:
@ -3330,7 +3358,10 @@ class FunctionOptimizer:
try:
# Add concurrency decorator to the source function
add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.CONCURRENCY
self.function_to_optimize.file_path,
self.function_to_optimize,
TestingMode.CONCURRENCY,
project_root=self.project_root,
)
# Run the concurrency benchmark tests

View file

@ -183,18 +183,54 @@ class Optimizer:
"""Discover functions to optimize."""
from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
return get_functions_to_optimize(
# In worktree mode for git-diff discovery, file paths come from the original repo
# (via get_git_diff using cwd), but module_root/project_root have been mirrored to
# the worktree. Use the original roots for filtering so path comparisons match,
# then remap the discovered file paths to the worktree.
project_root = self.args.project_root
module_root = self.args.module_root
use_original_roots = (
self.current_worktree and self.original_args_and_test_cfg and not self.args.all and not self.args.file
)
if use_original_roots:
assert self.original_args_and_test_cfg is not None
original_args, _ = self.original_args_and_test_cfg
project_root = original_args.project_root
module_root = original_args.module_root
result = get_functions_to_optimize(
optimize_all=self.args.all,
replay_test=self.args.replay_test,
file=self.args.file,
only_get_this_function=self.args.function,
test_cfg=self.test_cfg,
ignore_paths=self.args.ignore_paths,
project_root=self.args.project_root,
module_root=self.args.module_root,
project_root=project_root,
module_root=module_root,
previous_checkpoint_functions=self.args.previous_checkpoint_functions,
)
# Remap discovered file paths from the original repo to the worktree so
# downstream optimization reads/writes happen in the worktree.
if use_original_roots:
import dataclasses
assert self.current_worktree is not None
original_git_root = git_root_dir()
file_to_funcs, count, trace = result
remapped: dict[Path, list[FunctionToOptimize]] = {}
for file_path, funcs in file_to_funcs.items():
new_path = mirror_path(Path(file_path), original_git_root, self.current_worktree)
remapped[new_path] = [
dataclasses.replace(
func, file_path=mirror_path(func.file_path, original_git_root, self.current_worktree)
)
for func in funcs
]
return remapped, count, trace
return result
def create_function_optimizer(
self,
function_to_optimize: FunctionToOptimize,

View file

@ -1,2 +1,2 @@
# These version placeholders will be replaced by uv-dynamic-versioning during build.
__version__ = "0.20.0.post510.dev0+b8932209"
__version__ = "0.20.1"

View file

@ -6,8 +6,8 @@ codeflash/result/explanation.py
codeflash/result/critic.py
codeflash/version.py
codeflash/optimization/__init__.py
codeflash/context/__init__.py
codeflash/context/code_context_extractor.py
codeflash/languages/python/context/__init__.py
codeflash/languages/python/context/code_context_extractor.py
codeflash/discovery/__init__.py
codeflash/__init__.py
codeflash/models/ExperimentMetadata.py

View file

@ -113,21 +113,26 @@ function checkSharedTimeLimit() {
/**
* Get the current loop index for a specific invocation.
* The loop index represents how many times ALL test files have been run through.
* This is the batch count from the loop-runner.
* When using external loop-runner (Jest), returns the batch number directly.
* When using internal looping (Vitest), tracks and returns the invocation count.
*
* @param {string} invocationKey - Unique key for this test invocation
* @returns {number} The current batch number (loop index)
* @returns {number} The loop index for timing markers (1-based)
*/
function getInvocationLoopIndex(invocationKey) {
// Track local loop count for stopping logic (increments on each call)
// When using external loop-runner, use the batch number directly
// This is reliable because Jest resets module state between batches
const currentBatch = process.env.CODEFLASH_PERF_CURRENT_BATCH;
if (currentBatch !== undefined) {
return parseInt(currentBatch, 10);
}
// For internal looping (Vitest), track the count locally
if (!sharedPerfState.invocationLoopCounts[invocationKey]) {
sharedPerfState.invocationLoopCounts[invocationKey] = 0;
}
++sharedPerfState.invocationLoopCounts[invocationKey];
// Return the batch number as the loop index for timing markers
// This represents how many times all test files have been run through
return parseInt(process.env.CODEFLASH_PERF_CURRENT_BATCH || '1', 10);
return sharedPerfState.invocationLoopCounts[invocationKey];
}
/**
@ -693,11 +698,9 @@ function capturePerf(funcName, lineId, fn, ...args) {
// If not set, we're in Vitest mode and need to do all loops internally
const hasExternalLoopRunner = process.env.CODEFLASH_PERF_CURRENT_BATCH !== undefined;
// Batched looping: run BATCH_SIZE loops per capturePerf call when using loop-runner
// When using external loop-runner (Jest), execute only once per call - the loop-runner handles batching
// For Vitest (no loop-runner), do all loops internally in a single call
const batchSize = shouldLoop
? (hasExternalLoopRunner ? getPerfBatchSize() : getPerfLoopCount())
: 1;
const batchSize = hasExternalLoopRunner ? 1 : (shouldLoop ? getPerfLoopCount() : 1);
// Initialize runtime tracking for this invocation if needed
if (!sharedPerfState.invocationRuntimes[invocationKey]) {
@ -710,21 +713,21 @@ function capturePerf(funcName, lineId, fn, ...args) {
for (let batchIndex = 0; batchIndex < batchSize; batchIndex++) {
// Check shared time limit BEFORE each iteration
if (shouldLoop && checkSharedTimeLimit()) {
if (!hasExternalLoopRunner && shouldLoop && checkSharedTimeLimit()) {
break;
}
// Check if this invocation has already reached stability
if (getPerfStabilityCheck() && sharedPerfState.stableInvocations[invocationKey]) {
if (!hasExternalLoopRunner && getPerfStabilityCheck() && sharedPerfState.stableInvocations[invocationKey]) {
break;
}
// Get the loop index (batch number) for timing markers
// Get the loop index for timing markers
const loopIndex = getInvocationLoopIndex(invocationKey);
// Check if we've exceeded max loops for this invocation
const totalIterations = getTotalIterations(invocationKey);
if (totalIterations > getPerfLoopCount()) {
if (!hasExternalLoopRunner && totalIterations > getPerfLoopCount()) {
break;
}
@ -776,7 +779,7 @@ function capturePerf(funcName, lineId, fn, ...args) {
}
// Check stability after accumulating enough samples
if (getPerfStabilityCheck() && runtimes.length >= getPerfMinLoops()) {
if (!hasExternalLoopRunner && getPerfStabilityCheck() && runtimes.length >= getPerfMinLoops()) {
const window = getStabilityWindow();
if (shouldStopStability(runtimes, window, getPerfMinLoops())) {
sharedPerfState.stableInvocations[invocationKey] = true;
@ -785,7 +788,7 @@ function capturePerf(funcName, lineId, fn, ...args) {
}
// If we had an error, stop looping
if (lastError) {
if (!hasExternalLoopRunner && lastError) {
break;
}
}

View file

@ -35,69 +35,113 @@ const path = require('path');
const fs = require('fs');
/**
* Validates that a jest-runner path is valid by checking for package.json.
* @param {string} jestRunnerPath - Path to check
* @returns {boolean} True if valid jest-runner package
* Recursively find jest-runner package in node_modules.
* Works with any package manager (npm, yarn, pnpm) by searching for
* jest-runner/package.json anywhere in the tree.
*
* @param {string} nodeModulesPath - Path to node_modules directory
* @param {number} maxDepth - Maximum recursion depth (default: 5)
* @returns {string|null} Path to jest-runner or null if not found
*/
function isValidJestRunnerPath(jestRunnerPath) {
if (!fs.existsSync(jestRunnerPath)) {
return false;
function findJestRunnerRecursive(nodeModulesPath, maxDepth = 5) {
function search(dir, depth) {
if (depth > maxDepth || !fs.existsSync(dir)) return null;
try {
let entries = fs.readdirSync(dir, { withFileTypes: true });
// Sort entries: prefer higher versions for jest-runner@X.Y.Z directories
entries = entries.slice().sort((a, b) => {
const aMatch = a.name.match(/^jest-runner@(\d+)/);
const bMatch = b.name.match(/^jest-runner@(\d+)/);
if (aMatch && bMatch) {
return parseInt(bMatch[1], 10) - parseInt(aMatch[1], 10);
}
return a.name.localeCompare(b.name);
});
for (const entry of entries) {
if (!entry.isDirectory()) continue;
const entryPath = path.join(dir, entry.name);
// Found jest-runner directory - check if it's a valid package
if (entry.name === 'jest-runner') {
const pkgJsonPath = path.join(entryPath, 'package.json');
if (fs.existsSync(pkgJsonPath)) {
try {
const pkgJson = JSON.parse(fs.readFileSync(pkgJsonPath, 'utf8'));
if (pkgJson.name === 'jest-runner') {
return entryPath;
}
} catch (e) {
// Ignore JSON parse errors
}
}
}
// Recurse into:
// - node_modules subdirectories
// - scoped packages (@org/pkg)
// - hidden directories (.pnpm, .yarn, etc.)
// - pnpm versioned directories (jest-runner@30.0.5)
const shouldRecurse = entry.name === 'node_modules' ||
entry.name.startsWith('@') ||
entry.name === '.pnpm' || entry.name === '.yarn' ||
entry.name.startsWith('jest-runner@');
if (shouldRecurse) {
const result = search(entryPath, depth + 1);
if (result) return result;
}
}
} catch (e) {
// Ignore permission errors
}
return null;
}
const packageJsonPath = path.join(jestRunnerPath, 'package.json');
return fs.existsSync(packageJsonPath);
return search(nodeModulesPath, 0);
}
/**
* Resolve jest-runner with monorepo support.
* Uses CODEFLASH_MONOREPO_ROOT environment variable if available,
* otherwise walks up the directory tree looking for node_modules/jest-runner.
* Resolve jest-runner from the PROJECT's node_modules (not codeflash's).
*
* Uses recursive search to find jest-runner anywhere in node_modules,
* working with any package manager (npm, yarn, pnpm).
*
* @returns {string} Path to jest-runner package
* @throws {Error} If jest-runner cannot be found
*/
function resolveJestRunner() {
// Try standard resolution first (works in simple projects)
try {
return require.resolve('jest-runner');
} catch (e) {
// Standard resolution failed - try monorepo-aware resolution
}
// If Python detected a monorepo root, check there first
const monorepoRoot = process.env.CODEFLASH_MONOREPO_ROOT;
if (monorepoRoot) {
const jestRunnerPath = path.join(monorepoRoot, 'node_modules', 'jest-runner');
if (isValidJestRunnerPath(jestRunnerPath)) {
return jestRunnerPath;
}
}
// Fallback: Walk up from cwd looking for node_modules/jest-runner
const monorepoMarkers = ['yarn.lock', 'pnpm-workspace.yaml', 'lerna.json', 'package-lock.json'];
// Walk up from cwd to find all potential node_modules locations
let currentDir = process.cwd();
const visitedDirs = new Set();
// If Python detected a monorepo root, check there first
const monorepoRoot = process.env.CODEFLASH_MONOREPO_ROOT;
if (monorepoRoot && !visitedDirs.has(monorepoRoot)) {
visitedDirs.add(monorepoRoot);
const result = findJestRunnerRecursive(path.join(monorepoRoot, 'node_modules'));
if (result) return result;
}
while (currentDir !== path.dirname(currentDir)) {
// Avoid infinite loops
if (visitedDirs.has(currentDir)) break;
visitedDirs.add(currentDir);
// Try node_modules/jest-runner at this level
const jestRunnerPath = path.join(currentDir, 'node_modules', 'jest-runner');
if (isValidJestRunnerPath(jestRunnerPath)) {
return jestRunnerPath;
}
const result = findJestRunnerRecursive(path.join(currentDir, 'node_modules'));
if (result) return result;
// Check if this is a workspace root (has monorepo markers)
// Check if this is a workspace root - stop after this
const isWorkspaceRoot = monorepoMarkers.some(marker =>
fs.existsSync(path.join(currentDir, marker))
);
if (isWorkspaceRoot) {
// Found workspace root but no jest-runner - stop searching
break;
}
if (isWorkspaceRoot) break;
currentDir = path.dirname(currentDir);
}
@ -120,10 +164,15 @@ let jestVersion = 0;
try {
const jestRunnerPath = resolveJestRunner();
const internalRequire = createRequire(jestRunnerPath);
// Try to get the TestRunner class (Jest 30+)
const jestRunner = internalRequire(jestRunnerPath);
// Read the package.json to find the actual entry point and version
const pkgJsonPath = path.join(jestRunnerPath, 'package.json');
const pkgJson = JSON.parse(fs.readFileSync(pkgJsonPath, 'utf8'));
// Require using the full path to the entry point
const entryPoint = path.join(jestRunnerPath, pkgJson.main || 'build/index.js');
const jestRunner = require(entryPoint);
TestRunner = jestRunner.default || jestRunner.TestRunner;
if (TestRunner && TestRunner.prototype && typeof TestRunner.prototype.runTests === 'function') {
@ -131,9 +180,11 @@ try {
jestVersion = 30;
jestRunnerAvailable = true;
} else {
// Try Jest 29 style import
// Try Jest 29 style import - runTest is in build/runTest.js
try {
runTest = internalRequire('./runTest').default;
const runTestPath = path.join(jestRunnerPath, 'build', 'runTest.js');
const runTestModule = require(runTestPath);
runTest = runTestModule.default;
if (typeof runTest === 'function') {
// Jest 29 - use direct runTest function
jestVersion = 29;
@ -141,17 +192,23 @@ try {
}
} catch (e29) {
// Neither Jest 29 nor 30 style import worked
const errorMsg = `Found jest-runner at ${jestRunnerPath} but could not load it. ` +
`This may indicate an unsupported Jest version. ` +
`Supported versions: Jest 29.x and Jest 30.x`;
console.error(errorMsg);
jestRunnerAvailable = false;
}
}
} catch (e) {
// jest-runner not installed - this is expected for Vitest projects
// The runner will throw a helpful error if someone tries to use it without jest-runner
jestRunnerAvailable = false;
// try to directly import jest-runner
try {
const jestRunner = require('jest-runner');
TestRunner = jestRunner.default || jestRunner.TestRunner;
if (TestRunner && TestRunner.prototype && typeof TestRunner.prototype.runTests === 'function') {
jestVersion = 30;
jestRunnerAvailable = true;
} else {
jestRunnerAvailable = false;
}
} catch (e2) {
jestRunnerAvailable = false;
}
}
// Configuration
@ -233,15 +290,12 @@ class CodeflashLoopRunner {
this._context = context || {};
this._eventEmitter = new SimpleEventEmitter();
// For Jest 30+, create an instance of the base TestRunner for delegation
if (jestVersion >= 30) {
if (!TestRunner) {
throw new Error(
`Jest ${jestVersion} detected but TestRunner class not available. ` +
`This indicates an internal error in loop-runner initialization.`
);
}
this._baseRunner = new TestRunner(globalConfig, context);
// For Jest 30+, verify TestRunner is available (we create fresh instances per batch)
if (jestVersion >= 30 && !TestRunner) {
throw new Error(
`Jest ${jestVersion} detected but TestRunner class not available. ` +
`This indicates an internal error in loop-runner initialization.`
);
}
}
@ -270,7 +324,7 @@ class CodeflashLoopRunner {
* @param {Object} options - Jest runner options
* @returns {Promise<void>}
*/
async runTests(tests, watcher, options) {
async runTests(tests, watcher, ...rest) {
const startTime = Date.now();
let batchCount = 0;
let hasFailure = false;
@ -289,13 +343,11 @@ class CodeflashLoopRunner {
// Check time limit BEFORE each batch
if (batchCount > MIN_BATCHES && checkTimeLimit()) {
console.log(`[codeflash] Time limit reached after ${batchCount - 1} batches (${Date.now() - startTime}ms elapsed)`);
break;
}
// Check if interrupted
if (watcher.isInterrupted()) {
console.log(`[codeflash] Watcher is interrupted`)
break;
}
@ -303,57 +355,54 @@ class CodeflashLoopRunner {
process.env.CODEFLASH_PERF_CURRENT_BATCH = String(batchCount);
// Run all test files in this batch
const batchResult = await this._runAllTestsOnce(tests, watcher, options);
const batchResult = await this._runAllTestsOnce(tests, watcher, ...rest);
allConsoleOutput += batchResult.consoleOutput;
// if (batchResult.hasFailure) {
// hasFailure = true;
// break;
// }
// Check time limit AFTER each batch
if (checkTimeLimit()) {
console.log(`[codeflash] Time limit reached after ${batchCount} batches (${Date.now() - startTime}ms elapsed)`);
break;
}
}
const totalTimeMs = Date.now() - startTime;
console.log(`[codeflash] now: ${Date.now()}`)
// Output all collected console logs - this is critical for timing marker extraction
// The console output contains the !######...######! timing markers from capturePerf
if (allConsoleOutput) {
process.stdout.write(allConsoleOutput);
}
console.log(`[codeflash] Batched runner completed: ${batchCount} batches, ${tests.length} test files, ${totalTimeMs}ms total`);
}
/**
* Run all test files once (one batch).
* Uses different approaches for Jest 29 vs Jest 30.
*/
async _runAllTestsOnce(tests, watcher, options) {
async _runAllTestsOnce(tests, watcher, ...args) {
if (jestVersion >= 30) {
return this._runAllTestsOnceJest30(tests, watcher, options);
return this._runAllTestsOnceJest30(tests, watcher, ...args);
} else {
return this._runAllTestsOnceJest29(tests, watcher);
}
}
/**
* Jest 30+ implementation - delegates to base TestRunner and collects results.
* Jest 30+ implementation - creates a fresh TestRunner for each batch to avoid
* state corruption issues that occur when reusing runners across batches.
*/
async _runAllTestsOnceJest30(tests, watcher, options) {
async _runAllTestsOnceJest30(tests, watcher, ...args) {
let hasFailure = false;
let allConsoleOutput = '';
// For Jest 30, we need to collect results through event listeners
const resultsCollector = [];
// Subscribe to events from the base runner
const unsubscribeSuccess = this._baseRunner.on('test-file-success', (testData) => {
// Create a FRESH TestRunner instance for each batch
// Jest 30's TestRunner corrupts its internal state after running tests,
// so we cannot reuse the same instance across multiple batches
const batchRunner = new TestRunner(this._globalConfig, this._context);
// Subscribe to events from the batch runner
const unsubscribeSuccess = batchRunner.on('test-file-success', (testData) => {
const [test, result] = testData;
resultsCollector.push({ test, result, success: true });
@ -369,7 +418,7 @@ class CodeflashLoopRunner {
this._eventEmitter.emit('test-file-success', testData);
});
const unsubscribeFailure = this._baseRunner.on('test-file-failure', (testData) => {
const unsubscribeFailure = batchRunner.on('test-file-failure', (testData) => {
const [test, error] = testData;
resultsCollector.push({ test, error, success: false });
hasFailure = true;
@ -378,14 +427,14 @@ class CodeflashLoopRunner {
this._eventEmitter.emit('test-file-failure', testData);
});
const unsubscribeStart = this._baseRunner.on('test-file-start', (testData) => {
const unsubscribeStart = batchRunner.on('test-file-start', (testData) => {
// Forward to our event emitter
this._eventEmitter.emit('test-file-start', testData);
});
try {
// Run tests using the base runner (always serial for benchmarking)
await this._baseRunner.runTests(tests, watcher, { ...options, serial: true });
// Run tests using the fresh batch runner (always serial for benchmarking)
await batchRunner.runTests(tests, watcher, ...args);
} finally {
// Cleanup subscriptions
if (typeof unsubscribeSuccess === 'function') unsubscribeSuccess();

View file

@ -5,7 +5,7 @@ description = "Client for codeflash.ai - automatic code performance optimization
authors = [{ name = "CodeFlash Inc.", email = "contact@codeflash.ai" }]
requires-python = ">=3.9"
readme = "README.md"
license = {text = "BSL-1.1"}
license-files = ["LICENSE"]
keywords = [
"codeflash",
"performance",
@ -356,4 +356,3 @@ markers = [
[build-system]
requires = ["hatchling", "uv-dynamic-versioning"]
build-backend = "hatchling.build"

View file

@ -1,8 +1,8 @@
from argparse import Namespace
from pathlib import Path
from codeflash.context.code_context_extractor import get_code_optimization_context
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
from codeflash.models.models import FunctionParent
from codeflash.optimization.optimizer import Optimizer

View file

@ -13,7 +13,7 @@ def run_test(expected_improvement_pct: int) -> bool:
CoverageExpectation(
function_name="retry_with_backoff",
expected_coverage=100.0,
expected_lines=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
expected_lines=[9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
)
],
)

View file

@ -8,7 +8,9 @@ from pathlib import Path
import pytest
from codeflash.code_utils.instrument_existing_tests import (
ASYNC_HELPER_FILENAME,
add_async_decorator_to_function,
get_decorator_name_for_mode,
inject_profiling_into_existing_test,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
@ -55,16 +57,23 @@ async def test_async_sort():
func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True)
# For async functions, instrument the source module directly with decorators
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
source_success = add_async_decorator_to_function(
fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path
)
assert source_success
# Verify the file was modified
# Verify the file was modified with exact expected output
instrumented_source = fto_path.read_text("utf-8")
assert (
'''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_behavior_async\n\n\n@codeflash_behavior_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n'''
in instrumented_source
from codeflash.code_utils.formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
decorated_original = original_code.replace(
"async def async_sorter", f"@{decorator_name}\nasync def async_sorter"
)
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{decorated_original}"
expected = sort_imports(code=code_with_import, float_to_top=True)
assert instrumented_source.strip() == expected.strip()
# Add codeflash capture
instrument_codeflash_capture(func, {}, tests_root)
@ -142,6 +151,9 @@ async def test_async_sort():
test_path.unlink()
if test_path_perf.exists():
test_path_perf.unlink()
helper_path = project_root_path / ASYNC_HELPER_FILENAME
if helper_path.exists():
helper_path.unlink()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
@ -182,7 +194,9 @@ async def test_async_class_sort():
is_async=True,
)
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
source_success = add_async_decorator_to_function(
fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path
)
assert source_success
@ -264,6 +278,9 @@ async def test_async_class_sort():
test_path.unlink()
if test_path_perf.exists():
test_path_perf.unlink()
helper_path = project_root_path / ASYNC_HELPER_FILENAME
if helper_path.exists():
helper_path.unlink()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
@ -294,16 +311,23 @@ async def test_async_perf():
func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True)
# Instrument the source module with async performance decorators
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.PERFORMANCE)
source_success = add_async_decorator_to_function(
fto_path, func, TestingMode.PERFORMANCE, project_root=project_root_path
)
assert source_success
# Verify the file was modified
instrumented_source = fto_path.read_text("utf-8")
assert (
instrumented_source
== '''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_performance_async\n\n\n@codeflash_performance_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n'''
from codeflash.code_utils.formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.PERFORMANCE)
decorated_original = original_code.replace(
"async def async_sorter", f"@{decorator_name}\nasync def async_sorter"
)
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{decorated_original}"
expected = sort_imports(code=code_with_import, float_to_top=True)
assert instrumented_source.strip() == expected.strip()
instrument_codeflash_capture(func, {}, tests_root)
@ -359,6 +383,9 @@ async def test_async_perf():
# Clean up test files
if test_path.exists():
test_path.unlink()
helper_path = project_root_path / ASYNC_HELPER_FILENAME
if helper_path.exists():
helper_path.unlink()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
@ -404,68 +431,24 @@ async def async_error_function(lst):
function_name="async_error_function", parents=[], file_path=Path(fto_path), is_async=True
)
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
source_success = add_async_decorator_to_function(
fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path
)
assert source_success
# Verify the file was modified
instrumented_source = fto_path.read_text("utf-8")
expected_instrumented_source = """import asyncio
from typing import List, Union
from codeflash.code_utils.formatter import sort_imports
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_behavior_async
async def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:
\"\"\"
Async bubble sort implementation for testing.
\"\"\"
print("codeflash stdout: Async sorting list")
await asyncio.sleep(0.01)
n = len(lst)
for i in range(n):
for j in range(0, n - i - 1):
if lst[j] > lst[j + 1]:
lst[j], lst[j + 1] = lst[j + 1], lst[j]
result = lst.copy()
print(f"result: {result}")
return result
class AsyncBubbleSorter:
\"\"\"Class with async sorting method for testing.\"\"\"
async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:
\"\"\"
Async bubble sort implementation within a class.
\"\"\"
print("codeflash stdout: AsyncBubbleSorter.sorter() called")
# Add some async delay
await asyncio.sleep(0.005)
n = len(lst)
for i in range(n):
for j in range(0, n - i - 1):
if lst[j] > lst[j + 1]:
lst[j], lst[j + 1] = lst[j + 1], lst[j]
result = lst.copy()
return result
@codeflash_behavior_async
async def async_error_function(lst):
\"\"\"Async function that raises an error for testing.\"\"\"
await asyncio.sleep(0.001) # Small delay
raise ValueError("Test error")
"""
assert expected_instrumented_source == instrumented_source
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
decorated_modified = modified_code.replace(
"async def async_error_function", f"@{decorator_name}\nasync def async_error_function"
)
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{decorated_modified}"
expected = sort_imports(code=code_with_import, float_to_top=True)
assert instrumented_source.strip() == expected.strip()
instrument_codeflash_capture(func, {}, tests_root)
opt = Optimizer(
@ -526,6 +509,9 @@ async def async_error_function(lst):
test_path.unlink()
if test_path_perf.exists():
test_path_perf.unlink()
helper_path = project_root_path / ASYNC_HELPER_FILENAME
if helper_path.exists():
helper_path.unlink()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
@ -563,7 +549,9 @@ async def test_async_multi():
func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True)
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
source_success = add_async_decorator_to_function(
fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path
)
assert source_success
instrument_codeflash_capture(func, {}, tests_root)
@ -636,6 +624,9 @@ async def test_async_multi():
test_path.unlink()
if test_path_perf.exists():
test_path_perf.unlink()
helper_path = project_root_path / ASYNC_HELPER_FILENAME
if helper_path.exists():
helper_path.unlink()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
@ -678,7 +669,9 @@ async def test_async_edge_cases():
func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True)
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
source_success = add_async_decorator_to_function(
fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path
)
assert source_success
instrument_codeflash_capture(func, {}, tests_root)
@ -753,6 +746,9 @@ async def test_async_edge_cases():
test_path.unlink()
if test_path_perf.exists():
test_path_perf.unlink()
helper_path = project_root_path / ASYNC_HELPER_FILENAME
if helper_path.exists():
helper_path.unlink()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
@ -988,7 +984,9 @@ async def test_mixed_sorting():
function_name="async_merge_sort", parents=[], file_path=Path(mixed_fto_path), is_async=True
)
source_success = add_async_decorator_to_function(mixed_fto_path, async_func, TestingMode.BEHAVIOR)
source_success = add_async_decorator_to_function(
mixed_fto_path, async_func, TestingMode.BEHAVIOR, project_root=project_root_path
)
assert source_success
@ -1061,3 +1059,6 @@ async def test_mixed_sorting():
test_path.unlink()
if test_path_perf.exists():
test_path_perf.unlink()
helper_path = project_root_path / ASYNC_HELPER_FILENAME
if helper_path.exists():
helper_path.unlink()

View file

@ -10,17 +10,15 @@ import pytest
from codeflash.code_utils.code_extractor import GlobalAssignmentCollector, add_global_assignments
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
from codeflash.context.code_context_extractor import (
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.python.context.code_context_extractor import (
collect_names_from_annotation,
enrich_testgen_context,
extract_classes_from_type_hint,
extract_imports_for_class,
get_code_optimization_context,
get_external_base_class_inits,
get_external_class_inits,
get_imported_class_definitions,
resolve_transitive_type_deps,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeString, CodeStringsMarkdown, FunctionParent
from codeflash.optimization.optimizer import Optimizer
@ -769,199 +767,6 @@ class HelperClass:
assert hashing_context.strip() == expected_hashing_context.strip()
def test_example_class_token_limit_1(tmp_path: Path) -> None:
docstring_filler = " ".join(
["This is a long docstring that will be used to fill up the token limit." for _ in range(1000)]
)
code = f"""
class MyClass:
\"\"\"A class with a helper method.
{docstring_filler}\"\"\"
def __init__(self):
self.x = 1
def target_method(self):
\"\"\"Docstring for target method\"\"\"
y = HelperClass().helper_method()
class HelperClass:
\"\"\"A helper class for MyClass.\"\"\"
def __init__(self):
\"\"\"Initialize the HelperClass.\"\"\"
self.x = 1
def __repr__(self):
\"\"\"Return a string representation of the HelperClass.\"\"\"
return "HelperClass" + str(self.x)
def helper_method(self):
return self.x
"""
# Create a temporary Python file using pytest's tmp_path fixture
file_path = tmp_path / "test_code.py"
file_path.write_text(code, encoding="utf-8")
opt = Optimizer(
Namespace(
project_root=file_path.parent.resolve(),
disable_telemetry=True,
tests_root="tests",
test_framework="pytest",
pytest_cmd="pytest",
experiment_id=None,
test_project_root=Path().resolve(),
)
)
function_to_optimize = FunctionToOptimize(
function_name="target_method",
file_path=file_path,
parents=[FunctionParent(name="MyClass", type="ClassDef")],
starting_line=None,
ending_line=None,
)
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
# In this scenario, the read-only code context is too long, so the read-only docstrings are removed.
expected_read_write_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
def __init__(self):
self.x = 1
def target_method(self):
\"\"\"Docstring for target method\"\"\"
y = HelperClass().helper_method()
class HelperClass:
def __init__(self):
\"\"\"Initialize the HelperClass.\"\"\"
self.x = 1
def helper_method(self):
return self.x
```
"""
expected_read_only_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
pass
class HelperClass:
def __repr__(self):
return "HelperClass" + str(self.x)
```
"""
expected_hashing_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
def target_method(self):
y = HelperClass().helper_method()
class HelperClass:
def helper_method(self):
return self.x
```
"""
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
def test_example_class_token_limit_2(tmp_path: Path) -> None:
string_filler = " ".join(
["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
)
code = f"""
class MyClass:
\"\"\"A class with a helper method. \"\"\"
def __init__(self):
self.x = 1
def target_method(self):
\"\"\"Docstring for target method\"\"\"
y = HelperClass().helper_method()
x = '{string_filler}'
class HelperClass:
\"\"\"A helper class for MyClass.\"\"\"
def __init__(self):
\"\"\"Initialize the HelperClass.\"\"\"
self.x = 1
def __repr__(self):
\"\"\"Return a string representation of the HelperClass.\"\"\"
return "HelperClass" + str(self.x)
def helper_method(self):
return self.x
"""
# Create a temporary Python file using pytest's tmp_path fixture
file_path = tmp_path / "test_code.py"
file_path.write_text(code, encoding="utf-8")
opt = Optimizer(
Namespace(
project_root=file_path.parent.resolve(),
disable_telemetry=True,
tests_root="tests",
test_framework="pytest",
pytest_cmd="pytest",
experiment_id=None,
test_project_root=Path().resolve(),
)
)
function_to_optimize = FunctionToOptimize(
function_name="target_method",
file_path=file_path,
parents=[FunctionParent(name="MyClass", type="ClassDef")],
starting_line=None,
ending_line=None,
)
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root, 8000, 100000)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
# In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely.
expected_read_write_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
def __init__(self):
self.x = 1
def target_method(self):
\"\"\"Docstring for target method\"\"\"
y = HelperClass().helper_method()
class HelperClass:
def __init__(self):
\"\"\"Initialize the HelperClass.\"\"\"
self.x = 1
def helper_method(self):
return self.x
```
"""
expected_read_only_context = f'''```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
"""A class with a helper method. """
class HelperClass:
"""A helper class for MyClass."""
def __repr__(self):
"""Return a string representation of the HelperClass."""
return "HelperClass" + str(self.x)
```
'''
expected_hashing_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
def target_method(self):
y = HelperClass().helper_method()
class HelperClass:
def helper_method(self):
return self.x
```
"""
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
def test_example_class_token_limit_3(tmp_path: Path) -> None:
string_filler = " ".join(
["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
@ -1009,7 +814,7 @@ class HelperClass:
)
# In this scenario, the read-writable code is too long, so we abort.
with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"):
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
get_code_optimization_context(function_to_optimize, opt.args.project_root, optim_token_limit=8000)
def test_example_class_token_limit_4(tmp_path: Path) -> None:
@ -1062,7 +867,7 @@ class HelperClass:
# In this scenario, the read-writable code context becomes too large because the __init__ function is referencing the global x variable instead of the class attribute self.x, so we abort.
with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"):
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
get_code_optimization_context(function_to_optimize, opt.args.project_root, optim_token_limit=8000)
def test_example_class_token_limit_5(tmp_path: Path) -> None:
@ -2422,7 +2227,7 @@ class OuterClass:
assert "__init__" not in hashing_context # Should not contain __init__ methods
# Verify nested classes are excluded from the hashing context
# The prune_cst_for_code_hashing function should not recurse into nested classes
# The prune_cst function in hashing mode should not recurse into nested classes
assert "class NestedClass:" not in hashing_context # Nested class definition should not be present
# The target method will reference NestedClass, but the actual nested class definition should not be included
@ -3275,8 +3080,8 @@ def dump_layout(layout_type, layout):
assert testgen_context.count("def __init__") >= 2, "Both __init__ methods should be in testgen context"
def test_get_imported_class_definitions_extracts_project_classes(tmp_path: Path) -> None:
"""Test that get_imported_class_definitions extracts class definitions from project modules."""
def test_enrich_testgen_context_extracts_project_classes(tmp_path: Path) -> None:
"""Test that enrich_testgen_context extracts class definitions from project modules."""
# Create a package structure with two modules
package_dir = tmp_path / "mypackage"
package_dir.mkdir()
@ -3325,8 +3130,8 @@ class Accumulator:
# Create CodeStringsMarkdown from the chunking module (simulating testgen context)
context = CodeStringsMarkdown(code_strings=[CodeString(code=chunking_code, file_path=chunking_path)])
# Call get_imported_class_definitions
result = get_imported_class_definitions(context, tmp_path)
# Call enrich_testgen_context
result = enrich_testgen_context(context, tmp_path)
# Verify Element class was extracted
assert len(result.code_strings) == 1, "Should extract exactly one class (Element)"
@ -3339,8 +3144,8 @@ class Accumulator:
assert "import abc" in extracted_code, "Should include necessary imports for base class"
def test_get_imported_class_definitions_skips_existing_definitions(tmp_path: Path) -> None:
"""Test that get_imported_class_definitions skips classes already defined in context."""
def test_enrich_testgen_context_skips_existing_definitions(tmp_path: Path) -> None:
"""Test that enrich_testgen_context skips classes already defined in context."""
# Create a package structure
package_dir = tmp_path / "mypackage"
package_dir.mkdir()
@ -3373,15 +3178,15 @@ class User:
context = CodeStringsMarkdown(code_strings=[CodeString(code=code_with_local_def, file_path=code_path)])
# Call get_imported_class_definitions
result = get_imported_class_definitions(context, tmp_path)
# Call enrich_testgen_context
result = enrich_testgen_context(context, tmp_path)
# Should NOT extract Element since it's already defined locally
assert len(result.code_strings) == 0, "Should not extract classes already defined in context"
def test_get_imported_class_definitions_skips_third_party(tmp_path: Path) -> None:
"""Test that get_imported_class_definitions skips third-party/stdlib imports."""
def test_enrich_testgen_context_skips_third_party(tmp_path: Path) -> None:
"""Test that enrich_testgen_context skips third-party/stdlib imports."""
# Create a simple package
package_dir = tmp_path / "mypackage"
package_dir.mkdir()
@ -3402,15 +3207,15 @@ class MyClass:
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
# Call get_imported_class_definitions
result = get_imported_class_definitions(context, tmp_path)
# Call enrich_testgen_context
result = enrich_testgen_context(context, tmp_path)
# Should not extract any classes (Path, Optional, dataclass are stdlib/third-party)
assert len(result.code_strings) == 0, "Should not extract stdlib/third-party classes"
def test_get_imported_class_definitions_handles_multiple_imports(tmp_path: Path) -> None:
"""Test that get_imported_class_definitions handles multiple class imports."""
def test_enrich_testgen_context_handles_multiple_imports(tmp_path: Path) -> None:
"""Test that enrich_testgen_context handles multiple class imports."""
# Create a package structure
package_dir = tmp_path / "mypackage"
package_dir.mkdir()
@ -3446,8 +3251,8 @@ class Processor:
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
# Call get_imported_class_definitions
result = get_imported_class_definitions(context, tmp_path)
# Call enrich_testgen_context
result = enrich_testgen_context(context, tmp_path)
# Should extract both TypeA and TypeB (but not TypeC since it's not imported)
assert len(result.code_strings) == 2, "Should extract exactly two classes (TypeA, TypeB)"
@ -3458,8 +3263,8 @@ class Processor:
assert "class TypeC" not in all_extracted_code, "Should NOT contain TypeC (not imported)"
def test_get_imported_class_definitions_includes_dataclass_decorators(tmp_path: Path) -> None:
"""Test that get_imported_class_definitions includes decorators when extracting dataclasses."""
def test_enrich_testgen_context_includes_dataclass_decorators(tmp_path: Path) -> None:
"""Test that enrich_testgen_context includes decorators when extracting dataclasses."""
# Create a package structure
package_dir = tmp_path / "mypackage"
package_dir.mkdir()
@ -3496,8 +3301,8 @@ class ConfigRegistry:
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
# Call get_imported_class_definitions
result = get_imported_class_definitions(context, tmp_path)
# Call enrich_testgen_context
result = enrich_testgen_context(context, tmp_path)
# Should extract both LLMConfigBase (base class) and LLMConfig
assert len(result.code_strings) == 2, "Should extract both LLMConfig and its base class LLMConfigBase"
@ -3521,7 +3326,7 @@ class ConfigRegistry:
assert "from dataclasses import" in all_extracted_code, "Should include dataclasses import"
def test_get_imported_class_definitions_extracts_imports_for_decorated_classes(tmp_path: Path) -> None:
def test_enrich_testgen_context_extracts_imports_for_decorated_classes(tmp_path: Path) -> None:
"""Test that extract_imports_for_class includes decorator and type annotation imports."""
# Create a package structure
package_dir = tmp_path / "mypackage"
@ -3552,7 +3357,7 @@ def create_config() -> Config:
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_imported_class_definitions(context, tmp_path)
result = enrich_testgen_context(context, tmp_path)
assert len(result.code_strings) == 1, "Should extract Config class"
extracted_code = result.code_strings[0].code
@ -3724,7 +3529,7 @@ class MyClass:
assert result.count("from typing import Optional") == 1
def test_get_imported_class_definitions_multiple_decorators(tmp_path: Path) -> None:
def test_enrich_testgen_context_multiple_decorators(tmp_path: Path) -> None:
"""Test that classes with multiple decorators are extracted correctly."""
package_dir = tmp_path / "mypackage"
package_dir.mkdir()
@ -3755,7 +3560,7 @@ def sort_configs(configs: list[OrderedConfig]) -> list[OrderedConfig]:
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_imported_class_definitions(context, tmp_path)
result = enrich_testgen_context(context, tmp_path)
assert len(result.code_strings) == 1
extracted_code = result.code_strings[0].code
@ -3766,7 +3571,7 @@ def sort_configs(configs: list[OrderedConfig]) -> list[OrderedConfig]:
assert "class OrderedConfig" in extracted_code
def test_get_imported_class_definitions_extracts_multilevel_inheritance(tmp_path: Path) -> None:
def test_enrich_testgen_context_extracts_multilevel_inheritance(tmp_path: Path) -> None:
"""Test that base classes are recursively extracted for multi-level inheritance.
This is critical for understanding dataclass constructor signatures, as fields
@ -3826,8 +3631,8 @@ class ConfigRegistry:
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
# Call get_imported_class_definitions
result = get_imported_class_definitions(context, tmp_path)
# Call enrich_testgen_context
result = enrich_testgen_context(context, tmp_path)
# Should extract 4 classes: GrandParentConfig, ParentConfig, ChildConfig, RouterConfig
# (all classes needed to understand the full inheritance hierarchy)
@ -3862,7 +3667,7 @@ class ConfigRegistry:
assert "model_list: list" in all_extracted_code, "Should include model_list field from Router"
def test_get_external_base_class_inits_extracts_userdict(tmp_path: Path) -> None:
def test_enrich_testgen_context_extracts_userdict(tmp_path: Path) -> None:
"""Extracts __init__ from collections.UserDict when a class inherits from it."""
code = """from collections import UserDict
@ -3873,7 +3678,7 @@ class MyCustomDict(UserDict):
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_base_class_inits(context, tmp_path)
result = enrich_testgen_context(context, tmp_path)
assert len(result.code_strings) == 1
code_string = result.code_strings[0]
@ -3891,8 +3696,8 @@ class UserDict:
assert code_string.file_path.as_posix().endswith("collections/__init__.py")
def test_get_external_base_class_inits_skips_project_classes(tmp_path: Path) -> None:
"""Returns empty when base class is from the project, not external."""
def test_enrich_testgen_context_skips_unresolvable_base_classes(tmp_path: Path) -> None:
"""Returns empty when base class module cannot be resolved."""
child_code = """from base import ProjectBase
class Child(ProjectBase):
@ -3902,12 +3707,12 @@ class Child(ProjectBase):
child_path.write_text(child_code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=child_code, file_path=child_path)])
result = get_external_base_class_inits(context, tmp_path)
result = enrich_testgen_context(context, tmp_path)
assert result.code_strings == []
def test_get_external_base_class_inits_skips_builtins(tmp_path: Path) -> None:
def test_enrich_testgen_context_skips_builtin_base_classes(tmp_path: Path) -> None:
"""Returns empty for builtin classes like list that have no inspectable source."""
code = """class MyList(list):
pass
@ -3916,12 +3721,12 @@ def test_get_external_base_class_inits_skips_builtins(tmp_path: Path) -> None:
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_base_class_inits(context, tmp_path)
result = enrich_testgen_context(context, tmp_path)
assert result.code_strings == []
def test_get_external_base_class_inits_deduplicates(tmp_path: Path) -> None:
def test_enrich_testgen_context_deduplicates(tmp_path: Path) -> None:
"""Extracts the same external base class only once even when inherited multiple times."""
code = """from collections import UserDict
@ -3935,7 +3740,7 @@ class MyDict2(UserDict):
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_base_class_inits(context, tmp_path)
result = enrich_testgen_context(context, tmp_path)
assert len(result.code_strings) == 1
expected_code = """\
@ -3950,7 +3755,7 @@ class UserDict:
assert result.code_strings[0].code == expected_code
def test_get_external_base_class_inits_empty_when_no_inheritance(tmp_path: Path) -> None:
def test_enrich_testgen_context_empty_when_no_inheritance(tmp_path: Path) -> None:
"""Returns empty when there are no external base classes."""
code = """class SimpleClass:
pass
@ -3959,7 +3764,7 @@ def test_get_external_base_class_inits_empty_when_no_inheritance(tmp_path: Path)
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_base_class_inits(context, tmp_path)
result = enrich_testgen_context(context, tmp_path)
assert result.code_strings == []
@ -4103,127 +3908,8 @@ class MyCustomDict(UserDict):
assert "self.data = {}" in testgen_context, "UserDict __init__ body should be included"
def test_read_only_code_removed_when_exceeds_limit(tmp_path: Path) -> None:
"""Test read-only code is completely removed when it exceeds token limit even without docstrings.
This covers lines 152-153 in code_context_extractor.py where read_only_context_code is set
to empty string when it still exceeds the token limit after docstring removal.
"""
# Create a second-degree helper with large implementation that has no docstrings
# Second-degree helpers go into read-only context
long_lines = [" x = 0"]
for i in range(150):
long_lines.append(f" x = x + {i}")
long_lines.append(" return x")
long_body = "\n".join(long_lines)
code = f"""
class MyClass:
def __init__(self):
self.x = 1
def target_method(self):
return first_helper()
def first_helper():
# First degree helper - calls second degree
return second_helper()
def second_helper():
# Second degree helper - goes into read-only context
{long_body}
"""
file_path = tmp_path / "test_code.py"
file_path.write_text(code, encoding="utf-8")
func_to_optimize = FunctionToOptimize(
function_name="target_method", file_path=file_path, parents=[FunctionParent(name="MyClass", type="ClassDef")]
)
# Use a small optim_token_limit that allows read-writable but not read-only
# Read-writable is ~48 tokens, read-only is ~600 tokens
code_ctx = get_code_optimization_context(
function_to_optimize=func_to_optimize,
project_root_path=tmp_path,
optim_token_limit=100, # Small limit to trigger read-only removal
)
# The read-only context should be empty because it exceeded the limit
assert code_ctx.read_only_context_code == "", "Read-only code should be removed when exceeding token limit"
def test_testgen_removes_imported_classes_on_overflow(tmp_path: Path) -> None:
"""Test testgen context removes imported class definitions when exceeding token limit.
This covers lines 176-186 in code_context_extractor.py where:
- Testgen context exceeds limit (line 175)
- Removing docstrings still exceeds (line 175 again)
- Removing imported classes succeeds (line 177-183)
"""
# Create a package structure with a large type class used only in type annotations
# This ensures get_imported_class_definitions extracts the full class
package_dir = tmp_path / "mypackage"
package_dir.mkdir()
(package_dir / "__init__.py").write_text("", encoding="utf-8")
# Create a large class with methods that will be extracted via get_imported_class_definitions
# Use methods WITHOUT docstrings so removing docstrings won't help much
many_methods = "\n".join([f" def method_{i}(self):\n return {i}" for i in range(100)])
type_class_code = f'''
class TypeClass:
"""A type class for annotations."""
def __init__(self, value: int):
self.value = value
{many_methods}
'''
type_class_path = package_dir / "types.py"
type_class_path.write_text(type_class_code, encoding="utf-8")
# Main module uses TypeClass only in annotation (not instantiated)
# This triggers get_imported_class_definitions to extract the full class
main_code = """
from mypackage.types import TypeClass
def target_function(obj: TypeClass) -> int:
return obj.value
"""
main_path = package_dir / "main.py"
main_path.write_text(main_code, encoding="utf-8")
func_to_optimize = FunctionToOptimize(function_name="target_function", file_path=main_path, parents=[])
# Use a testgen_token_limit that:
# - Is exceeded by full context with imported class (~1500 tokens)
# - Is exceeded even after removing docstrings
# - But fits when imported class is removed (~40 tokens)
code_ctx = get_code_optimization_context(
function_to_optimize=func_to_optimize,
project_root_path=tmp_path,
testgen_token_limit=200, # Small limit to trigger imported class removal
)
# The testgen context should exist (didn't raise ValueError)
testgen_context = code_ctx.testgen_context.markdown
assert testgen_context, "Testgen context should not be empty"
# The target function should still be there
assert "def target_function" in testgen_context, "Target function should be in testgen context"
# The large imported class should NOT be included (removed due to token limit)
assert "class TypeClass" not in testgen_context, (
"TypeClass should be removed from testgen context when exceeding token limit"
)
def test_testgen_raises_when_all_fallbacks_fail(tmp_path: Path) -> None:
"""Test that ValueError is raised when testgen context exceeds limit even after all fallbacks.
This covers line 186 in code_context_extractor.py.
"""
def test_testgen_raises_when_exceeds_limit(tmp_path: Path) -> None:
"""Test that ValueError is raised when testgen context exceeds token limit."""
# Create a function with a very long body that exceeds limits even without imports/docstrings
long_lines = [" x = 0"]
for i in range(200):
@ -4249,7 +3935,7 @@ def target_function():
)
def test_get_external_base_class_inits_attribute_base(tmp_path: Path) -> None:
def test_enrich_testgen_context_attribute_base(tmp_path: Path) -> None:
"""Test handling of base class accessed as module.ClassName (ast.Attribute).
This covers line 616 in code_context_extractor.py.
@ -4265,7 +3951,7 @@ class MyDict(UserDict):
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_base_class_inits(context, tmp_path)
result = enrich_testgen_context(context, tmp_path)
# Should extract UserDict __init__
assert len(result.code_strings) == 1
@ -4273,7 +3959,7 @@ class MyDict(UserDict):
assert "def __init__" in result.code_strings[0].code
def test_get_external_base_class_inits_no_init_method(tmp_path: Path) -> None:
def test_enrich_testgen_context_no_init_method(tmp_path: Path) -> None:
"""Test handling when base class has no __init__ method.
This covers line 641 in code_context_extractor.py.
@ -4288,7 +3974,7 @@ class MyProtocol(Protocol):
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_base_class_inits(context, tmp_path)
result = enrich_testgen_context(context, tmp_path)
# Protocol's __init__ can't be easily inspected, should handle gracefully
# Result may be empty or contain Protocol based on implementation
@ -4377,7 +4063,7 @@ class MyClass:
def test_imported_class_definitions_module_path_none(tmp_path: Path) -> None:
"""Test handling when module_path is None in get_imported_class_definitions.
"""Test handling when module_path is None in enrich_testgen_context.
This covers line 560 in code_context_extractor.py.
"""
@ -4393,123 +4079,12 @@ class MyClass:
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_imported_class_definitions(context, tmp_path)
result = enrich_testgen_context(context, tmp_path)
# Should handle gracefully and return empty or partial results
assert isinstance(result.code_strings, list)
def test_get_imported_names_import_star(tmp_path: Path) -> None:
"""Test get_imported_names handles import * correctly.
This covers lines 808-809 and 824-825 in code_context_extractor.py.
"""
import libcst as cst
# Test regular import *
# Note: "import *" is not valid Python, but "from x import *" is
from_import_star = cst.parse_statement("from os import *")
assert isinstance(from_import_star, cst.SimpleStatementLine)
import_node = from_import_star.body[0]
assert isinstance(import_node, cst.ImportFrom)
from codeflash.context.code_context_extractor import get_imported_names
result = get_imported_names(import_node)
assert result == {"*"}
def test_get_imported_names_aliased_import(tmp_path: Path) -> None:
"""Test get_imported_names handles aliased imports correctly.
This covers lines 812-813 and 828-829 in code_context_extractor.py.
"""
import libcst as cst
from codeflash.context.code_context_extractor import get_imported_names
# Test import with alias
import_stmt = cst.parse_statement("import numpy as np")
assert isinstance(import_stmt, cst.SimpleStatementLine)
import_node = import_stmt.body[0]
assert isinstance(import_node, cst.Import)
result = get_imported_names(import_node)
assert "np" in result
# Test from import with alias
from_import_stmt = cst.parse_statement("from os import path as ospath")
assert isinstance(from_import_stmt, cst.SimpleStatementLine)
from_import_node = from_import_stmt.body[0]
assert isinstance(from_import_node, cst.ImportFrom)
result2 = get_imported_names(from_import_node)
assert "ospath" in result2
def test_get_imported_names_dotted_import(tmp_path: Path) -> None:
"""Test get_imported_names handles dotted imports correctly.
This covers lines 816-822 in code_context_extractor.py.
"""
import libcst as cst
from codeflash.context.code_context_extractor import get_imported_names
# Test dotted import like "import os.path"
import_stmt = cst.parse_statement("import os.path")
assert isinstance(import_stmt, cst.SimpleStatementLine)
import_node = import_stmt.body[0]
assert isinstance(import_node, cst.Import)
result = get_imported_names(import_node)
assert "os" in result
def test_used_name_collector_comprehensive(tmp_path: Path) -> None:
"""Test UsedNameCollector handles various node types.
This covers lines 767-801 in code_context_extractor.py.
"""
import libcst as cst
from codeflash.context.code_context_extractor import UsedNameCollector
code = """
import os
from typing import List
x: int = 1
y = os.path.join("a", "b")
class MyClass:
z = 10
def my_func():
pass
"""
module = cst.parse_module(code)
collector = UsedNameCollector()
# In libcst, the walker traverses the module
cst.MetadataWrapper(module).visit(collector)
# Check used names
assert "os" in collector.used_names
assert "int" in collector.used_names
assert "List" in collector.used_names
# Check defined names
assert "x" in collector.defined_names
assert "y" in collector.defined_names
assert "MyClass" in collector.defined_names
assert "my_func" in collector.defined_names
# Check external names (used but not defined)
external = collector.get_external_names()
assert "os" in external
assert "x" not in external # x is defined
def test_imported_class_with_base_in_same_module(tmp_path: Path) -> None:
"""Test that imported classes with bases in the same module are extracted correctly.
@ -4549,52 +4124,13 @@ def target_function(obj: DerivedClass) -> bool:
main_path.write_text(main_code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=main_code, file_path=main_path)])
result = get_imported_class_definitions(context, tmp_path)
result = enrich_testgen_context(context, tmp_path)
# Should extract the inheritance chain
all_code = "\n".join(cs.code for cs in result.code_strings)
assert "class BaseClass" in all_code or "class DerivedClass" in all_code
def test_get_imported_names_from_import_without_alias(tmp_path: Path) -> None:
"""Test get_imported_names handles from imports without aliases.
This covers lines 830-831 in code_context_extractor.py.
"""
import libcst as cst
from codeflash.context.code_context_extractor import get_imported_names
# Test from import without alias
from_import_stmt = cst.parse_statement("from os import path, getcwd")
assert isinstance(from_import_stmt, cst.SimpleStatementLine)
from_import_node = from_import_stmt.body[0]
assert isinstance(from_import_node, cst.ImportFrom)
result = get_imported_names(from_import_node)
assert "path" in result
assert "getcwd" in result
def test_get_imported_names_regular_import(tmp_path: Path) -> None:
"""Test get_imported_names handles regular imports.
This covers lines 814-815 in code_context_extractor.py.
"""
import libcst as cst
from codeflash.context.code_context_extractor import get_imported_names
# Test regular import without alias
import_stmt = cst.parse_statement("import json")
assert isinstance(import_stmt, cst.SimpleStatementLine)
import_node = import_stmt.body[0]
assert isinstance(import_node, cst.Import)
result = get_imported_names(import_node)
assert "json" in result
def test_augmented_assignment_not_in_context(tmp_path: Path) -> None:
"""Test that augmented assignments are handled but not included unless used.
@ -4625,7 +4161,7 @@ class MyClass:
assert "counter" in read_writable
def test_get_external_class_inits_extracts_click_option(tmp_path: Path) -> None:
def test_enrich_testgen_context_extracts_click_option(tmp_path: Path) -> None:
"""Extracts __init__ from click.Option when directly imported."""
code = """from click import Option
@ -4636,7 +4172,7 @@ def my_func(opt: Option) -> None:
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_class_inits(context, tmp_path)
result = enrich_testgen_context(context, tmp_path)
assert len(result.code_strings) == 1
code_string = result.code_strings[0]
@ -4645,8 +4181,8 @@ def my_func(opt: Option) -> None:
assert code_string.file_path is not None and "click" in code_string.file_path.as_posix()
def test_get_external_class_inits_skips_project_classes(tmp_path: Path) -> None:
"""Returns empty when imported class is from the project, not external."""
def test_enrich_testgen_context_extracts_project_class_defs(tmp_path: Path) -> None:
"""Extracts project class definitions via jedi resolution."""
# Create a project module with a class
(tmp_path / "mymodule.py").write_text("class ProjectClass:\n pass\n", encoding="utf-8")
@ -4659,12 +4195,13 @@ def my_func(obj: ProjectClass) -> None:
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_class_inits(context, tmp_path)
result = enrich_testgen_context(context, tmp_path)
assert result.code_strings == []
assert len(result.code_strings) == 1
assert "class ProjectClass" in result.code_strings[0].code
def test_get_external_class_inits_skips_non_classes(tmp_path: Path) -> None:
def test_enrich_testgen_context_skips_non_classes(tmp_path: Path) -> None:
"""Returns empty when imported name is a function, not a class."""
code = """from collections import OrderedDict
from os.path import join
@ -4676,7 +4213,7 @@ def my_func() -> None:
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_class_inits(context, tmp_path)
result = enrich_testgen_context(context, tmp_path)
# join is a function, not a class — should be skipped
# OrderedDict is a class and should be included
@ -4684,8 +4221,8 @@ def my_func() -> None:
assert not any("join" in name for name in class_names)
def test_get_external_class_inits_skips_already_defined_classes(tmp_path: Path) -> None:
"""Skips classes already defined in the context (e.g., added by get_imported_class_definitions)."""
def test_enrich_testgen_context_skips_already_defined_classes(tmp_path: Path) -> None:
"""Skips classes already defined in the context (e.g., added by enrich_testgen_context)."""
code = """from collections import UserDict
class UserDict:
@ -4699,14 +4236,14 @@ def my_func(d: UserDict) -> None:
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_class_inits(context, tmp_path)
result = enrich_testgen_context(context, tmp_path)
# UserDict is already defined in the context, so it should be skipped
assert result.code_strings == []
def test_get_external_class_inits_skips_builtins(tmp_path: Path) -> None:
"""Returns empty for builtin classes like list/dict that have no inspectable source."""
def test_enrich_testgen_context_skips_builtin_annotations(tmp_path: Path) -> None:
"""Returns empty for builtin type annotations like list/dict that are not imported."""
code = """x: list = []
y: dict = {}
@ -4717,12 +4254,12 @@ def my_func() -> None:
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_class_inits(context, tmp_path)
result = enrich_testgen_context(context, tmp_path)
assert result.code_strings == []
def test_get_external_class_inits_skips_object_init(tmp_path: Path) -> None:
def test_enrich_testgen_context_skips_object_init(tmp_path: Path) -> None:
"""Skips classes whose __init__ is just object.__init__ (trivial)."""
# enum.Enum has a metaclass-based __init__, but individual enum members
# effectively use object.__init__. Use a class we know has object.__init__.
@ -4735,14 +4272,14 @@ def my_func(q: QName) -> None:
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_class_inits(context, tmp_path)
result = enrich_testgen_context(context, tmp_path)
# QName has its own __init__, so it should be included if it's in site-packages.
# But since it's stdlib (not site-packages), it should be skipped.
assert result.code_strings == []
def test_get_external_class_inits_empty_when_no_imports(tmp_path: Path) -> None:
def test_enrich_testgen_context_empty_when_no_imports(tmp_path: Path) -> None:
"""Returns empty when there are no from-imports."""
code = """def my_func() -> None:
pass
@ -4751,7 +4288,7 @@ def test_get_external_class_inits_empty_when_no_imports(tmp_path: Path) -> None:
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_class_inits(context, tmp_path)
result = enrich_testgen_context(context, tmp_path)
assert result.code_strings == []
@ -4840,17 +4377,17 @@ def test_resolve_transitive_type_deps_handles_failure_gracefully() -> None:
"""Returns empty list for a class where get_type_hints fails."""
class BadClass:
def __init__(self, x: "NonexistentType") -> None: # type: ignore[name-defined] # noqa: F821
def __init__(self, x: NonexistentType) -> None: # type: ignore[name-defined] # noqa: F821
pass
result = resolve_transitive_type_deps(BadClass)
assert result == []
# --- Integration tests for transitive resolution in get_external_class_inits ---
# --- Integration tests for transitive resolution in enrich_testgen_context ---
def test_get_external_class_inits_transitive_deps(tmp_path: Path) -> None:
def test_enrich_testgen_context_transitive_deps(tmp_path: Path) -> None:
"""Extracts transitive type dependencies from __init__ annotations."""
code = """from click import Context
@ -4861,7 +4398,7 @@ def my_func(ctx: Context) -> None:
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_class_inits(context, tmp_path)
result = enrich_testgen_context(context, tmp_path)
class_names = {cs.code.split("\n")[0].replace("class ", "").rstrip(":") for cs in result.code_strings}
assert "Context" in class_names
@ -4869,7 +4406,7 @@ def my_func(ctx: Context) -> None:
assert "Command" in class_names
def test_get_external_class_inits_no_infinite_loops(tmp_path: Path) -> None:
def test_enrich_testgen_context_no_infinite_loops(tmp_path: Path) -> None:
"""Handles classes with circular type references without infinite loops."""
# click.Context references Command, and Command references Context back
# This should terminate without issues due to the processed_classes set
@ -4882,13 +4419,13 @@ def my_func(ctx: Context) -> None:
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_class_inits(context, tmp_path)
result = enrich_testgen_context(context, tmp_path)
# Should complete without hanging; just verify we got results
assert len(result.code_strings) >= 1
def test_get_external_class_inits_no_duplicate_stubs(tmp_path: Path) -> None:
def test_enrich_testgen_context_no_duplicate_stubs(tmp_path: Path) -> None:
"""Does not emit duplicate stubs for the same class name."""
code = """from click import Context
@ -4899,7 +4436,7 @@ def my_func(ctx: Context) -> None:
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_class_inits(context, tmp_path)
result = enrich_testgen_context(context, tmp_path)
class_names = [cs.code.split("\n")[0].replace("class ", "").rstrip(":") for cs in result.code_strings]
assert len(class_names) == len(set(class_names)), f"Duplicate class stubs found: {class_names}"

View file

@ -2,7 +2,7 @@ from textwrap import dedent
import pytest
from codeflash.context.code_context_extractor import parse_code_and_prune_cst
from codeflash.languages.python.context.code_context_extractor import parse_code_and_prune_cst
from codeflash.models.models import CodeContextType

View file

@ -2,7 +2,7 @@ from textwrap import dedent
import pytest
from codeflash.context.code_context_extractor import parse_code_and_prune_cst
from codeflash.languages.python.context.code_context_extractor import parse_code_and_prune_cst
from codeflash.models.models import CodeContextType

View file

@ -2,7 +2,7 @@ from textwrap import dedent
import pytest
from codeflash.context.code_context_extractor import parse_code_and_prune_cst
from codeflash.languages.python.context.code_context_extractor import parse_code_and_prune_cst
from codeflash.models.models import CodeContextType

View file

@ -6,7 +6,9 @@ from pathlib import Path
import pytest
from codeflash.code_utils.instrument_existing_tests import (
ASYNC_HELPER_FILENAME,
add_async_decorator_to_function,
get_decorator_name_for_mode,
inject_profiling_into_existing_test,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
@ -57,20 +59,6 @@ def test_async_decorator_application_behavior_mode(temp_dir):
async_function_code = '''
import asyncio
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
expected_decorated_code = '''
import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_behavior_async
@codeflash_behavior_async
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
@ -86,7 +74,16 @@ async def async_function(x: int, y: int) -> int:
assert decorator_added
modified_code = test_file.read_text()
assert modified_code.strip() == expected_decorated_code.strip()
from codeflash.code_utils.formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
code_with_decorator = async_function_code.replace(
"async def async_function", f"@{decorator_name}\nasync def async_function"
)
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
expected = sort_imports(code=code_with_import, float_to_top=True)
assert modified_code.strip() == expected.strip()
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
@ -94,20 +91,6 @@ def test_async_decorator_application_performance_mode(temp_dir):
async_function_code = '''
import asyncio
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
expected_decorated_code = '''
import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_performance_async
@codeflash_performance_async
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
@ -123,7 +106,16 @@ async def async_function(x: int, y: int) -> int:
assert decorator_added
modified_code = test_file.read_text()
assert modified_code.strip() == expected_decorated_code.strip()
from codeflash.code_utils.formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.PERFORMANCE)
code_with_decorator = async_function_code.replace(
"async def async_function", f"@{decorator_name}\nasync def async_function"
)
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
expected = sort_imports(code=code_with_import, float_to_top=True)
assert modified_code.strip() == expected.strip()
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
@ -132,20 +124,6 @@ def test_async_decorator_application_concurrency_mode(temp_dir):
async_function_code = '''
import asyncio
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
expected_decorated_code = '''
import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_concurrency_async
@codeflash_concurrency_async
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
@ -161,7 +139,16 @@ async def async_function(x: int, y: int) -> int:
assert decorator_added
modified_code = test_file.read_text()
assert modified_code.strip() == expected_decorated_code.strip()
from codeflash.code_utils.formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.CONCURRENCY)
code_with_decorator = async_function_code.replace(
"async def async_function", f"@{decorator_name}\nasync def async_function"
)
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
expected = sort_imports(code=code_with_import, float_to_top=True)
assert modified_code.strip() == expected.strip()
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
@ -182,27 +169,6 @@ class Calculator:
return a - b
'''
expected_decorated_code = '''
import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_behavior_async
class Calculator:
"""Test class with async methods."""
@codeflash_behavior_async
async def async_method(self, a: int, b: int) -> int:
"""Async method in class."""
await asyncio.sleep(0.005)
return a ** b
def sync_method(self, a: int, b: int) -> int:
"""Sync method in class."""
return a - b
'''
test_file = temp_dir / "test_async.py"
test_file.write_text(async_class_code)
@ -217,11 +183,21 @@ class Calculator:
assert decorator_added
modified_code = test_file.read_text()
assert modified_code.strip() == expected_decorated_code.strip()
from codeflash.code_utils.formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
code_with_decorator = async_class_code.replace(
" async def async_method", f" @{decorator_name}\n async def async_method"
)
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
expected = sort_imports(code=code_with_import, float_to_top=True)
assert modified_code.strip() == expected.strip()
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_decorator_no_duplicate_application(temp_dir):
# Case 1: Old-style import already present — injector should detect and skip
already_decorated_code = '''
from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async
import asyncio
@ -243,6 +219,30 @@ async def async_function(x: int, y: int) -> int:
# Should not add duplicate decorator
assert not decorator_added
# Case 2: Inline definition already present — injector should detect and skip
already_inline_code = '''
import asyncio
def codeflash_behavior_async(func):
return func
@codeflash_behavior_async
async def async_function(x: int, y: int) -> int:
"""Already decorated async function."""
await asyncio.sleep(0.01)
return x * y
'''
test_file2 = temp_dir / "test_async2.py"
test_file2.write_text(already_inline_code)
func2 = FunctionToOptimize(function_name="async_function", file_path=test_file2, parents=[], is_async=True)
decorator_added2 = add_async_decorator_to_function(test_file2, func2, TestingMode.BEHAVIOR)
# Should not add duplicate decorator
assert not decorator_added2
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_inject_profiling_async_function_behavior_mode(temp_dir):
@ -285,11 +285,18 @@ async def test_async_function():
assert source_success is True
# Verify the file was modified
# Verify the file was modified with exact expected output
instrumented_source = source_file.read_text()
assert "@codeflash_behavior_async" in instrumented_source
assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
assert "codeflash_behavior_async" in instrumented_source
from codeflash.code_utils.formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
code_with_decorator = source_module_code.replace(
"async def async_function", f"@{decorator_name}\nasync def async_function"
)
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
expected = sort_imports(code=code_with_import, float_to_top=True)
assert instrumented_source.strip() == expected.strip()
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
success, instrumented_test_code = inject_profiling_into_existing_test(
async_test_code, test_file, [CodePosition(8, 18), CodePosition(11, 19)], func, temp_dir, mode=TestingMode.BEHAVIOR
@ -340,12 +347,18 @@ async def test_async_function():
assert source_success is True
# Verify the file was modified
# Verify the file was modified with exact expected output
instrumented_source = source_file.read_text()
assert "@codeflash_performance_async" in instrumented_source
# Check for the import with line continuation formatting
assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
assert "codeflash_performance_async" in instrumented_source
from codeflash.code_utils.formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.PERFORMANCE)
code_with_decorator = source_module_code.replace(
"async def async_function", f"@{decorator_name}\nasync def async_function"
)
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
expected = sort_imports(code=code_with_import, float_to_top=True)
assert instrumented_source.strip() == expected.strip()
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
# Now test the full pipeline with source module path
success, instrumented_test_code = inject_profiling_into_existing_test(
@ -406,11 +419,16 @@ async def test_mixed_functions():
# Verify the file was modified
instrumented_source = source_file.read_text()
assert "@codeflash_behavior_async" in instrumented_source
assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
assert "codeflash_behavior_async" in instrumented_source
# Sync function should remain unchanged
assert "def sync_function(x: int, y: int) -> int:" in instrumented_source
from codeflash.code_utils.formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
code_with_decorator = source_module_code.replace(
"async def async_function", f"@{decorator_name}\nasync def async_function"
)
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
expected = sort_imports(code=code_with_import, float_to_top=True)
assert instrumented_source.strip() == expected.strip()
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
success, instrumented_test_code = inject_profiling_into_existing_test(
mixed_test_code, test_file, [CodePosition(8, 18), CodePosition(11, 19)], async_func, temp_dir, mode=TestingMode.BEHAVIOR
@ -446,24 +464,19 @@ class OuterClass:
decorator_added = add_async_decorator_to_function(test_file, func, TestingMode.BEHAVIOR)
expected_output = """import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_behavior_async
class OuterClass:
class InnerClass:
@codeflash_behavior_async
async def nested_async_method(self, x: int) -> int:
\"\"\"Nested async method.\"\"\"
await asyncio.sleep(0.001)
return x * 2
"""
assert decorator_added
modified_code = test_file.read_text()
assert modified_code.strip() == expected_output.strip()
from codeflash.code_utils.formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
code_with_decorator = nested_async_code.replace(
" async def nested_async_method",
f" @{decorator_name}\n async def nested_async_method",
)
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
expected = sort_imports(code=code_with_import, float_to_top=True)
assert modified_code.strip() == expected.strip()
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")

View file

@ -20,14 +20,12 @@ All assertions use strict string equality to verify exact extraction output.
from __future__ import annotations
from pathlib import Path
import pytest
from codeflash.context.code_context_extractor import get_code_optimization_context_for_language
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import Language
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context_for_language
@pytest.fixture

View file

@ -106,9 +106,9 @@ class TestJavaScriptCodeContext:
def test_extract_code_context_for_javascript(self, js_project_dir):
"""Test extracting code context for a JavaScript function."""
skip_if_js_not_supported()
from codeflash.context.code_context_extractor import get_code_optimization_context
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.languages import current as lang_current
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
lang_current._current_language = Language.JAVASCRIPT

View file

@ -9,7 +9,6 @@ These tests verify the full optimization pipeline including:
This is the JavaScript equivalent of test_instrument_tests.py for Python.
"""
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
@ -71,9 +70,9 @@ module.exports = { add };
def test_code_context_preserves_language(self, tmp_path):
"""Verify language is preserved in code context extraction."""
skip_if_js_not_supported()
from codeflash.context.code_context_extractor import get_code_optimization_context
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.languages import current as lang_current
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
lang_current._current_language = Language.TYPESCRIPT
@ -164,7 +163,7 @@ export function add(a: number, b: number): number {
# Mock the AI service request
ai_client = AiServiceClient()
with patch.object(ai_client, 'make_ai_service_request') as mock_request:
with patch.object(ai_client, "make_ai_service_request") as mock_request:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
@ -191,8 +190,8 @@ export function add(a: number, b: number): number {
# Verify the request was made with correct language
assert mock_request.called, "API request should have been made"
call_args = mock_request.call_args
payload = call_args[1].get('payload', call_args[0][1] if len(call_args[0]) > 1 else {})
assert payload.get('language') == 'typescript', \
payload = call_args[1].get("payload", call_args[0][1] if len(call_args[0]) > 1 else {})
assert payload.get("language") == "typescript", \
f"Expected language='typescript', got language='{payload.get('language')}'"
@ -462,7 +461,7 @@ class TestHelperFunctionLanguageAttribute:
"""Verify helper functions have language='javascript' for .js files."""
skip_if_js_not_supported()
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.languages import current as lang_current, get_language_support
from codeflash.languages import current as lang_current
from codeflash.optimization.function_optimizer import FunctionOptimizer
lang_current._current_language = Language.JAVASCRIPT

View file

@ -69,7 +69,7 @@ class TestTypeScriptFunctionDiscovery:
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
with tempfile.NamedTemporaryFile(suffix=".ts", mode="w", delete=False) as f:
f.write("""
f.write(r"""
export function add(a: number, b: number): number {
return a + b;
}
@ -123,9 +123,9 @@ class TestTypeScriptCodeContext:
def test_extract_code_context_for_typescript(self, ts_project_dir):
"""Test extracting code context for a TypeScript function."""
skip_if_ts_not_supported()
from codeflash.context.code_context_extractor import get_code_optimization_context
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.languages import current as lang_current
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
lang_current._current_language = Language.TYPESCRIPT
@ -201,7 +201,7 @@ function multiply(a: number, b: number): number {
from codeflash.languages import get_language_support
from codeflash.languages.base import FunctionInfo
original_source = """
original_source = r"""
interface Config {
timeout: number;
retries: number;
@ -212,7 +212,7 @@ function processConfig(config: Config): string {
}
"""
new_function = """function processConfig(config: Config): string {
new_function = r"""function processConfig(config: Config): string {
// Optimized with template caching
const { timeout, retries } = config;
return `timeout=\${timeout}, retries=\${retries}`;

View file

@ -117,10 +117,10 @@ class TestVitestCodeContext:
def test_extract_code_context_for_typescript(self, vitest_project_dir):
"""Test extracting code context for a TypeScript function."""
skip_if_js_not_supported()
from codeflash.context.code_context_extractor import get_code_optimization_context
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.languages import current as lang_current
from codeflash.languages.base import Language
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
lang_current._current_language = Language.TYPESCRIPT

View file

@ -1,6 +1,6 @@
from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names
from codeflash.languages.python.context.unused_definition_remover import remove_unused_definitions_by_function_names
def test_variable_removal_only() -> None:

View file

@ -5,8 +5,11 @@ from pathlib import Path
import pytest
from codeflash.context.unused_definition_remover import detect_unused_helper_functions, revert_unused_helper_functions
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.python.context.unused_definition_remover import (
detect_unused_helper_functions,
revert_unused_helper_functions,
)
from codeflash.models.models import CodeStringsMarkdown
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig