Merge branch 'main' into omni-java

Resolve 7 merge conflicts from main's modular refactoring + JS improvements:

- aiservice.py: combine multi-language metadata (omni-java) with main's structure
- cmd_init.py: adopt main's modular split (init_config, init_auth, github_workflow) + add Java import
- code_replacer.py: main's clean early-return style + omni-java's non-Python single-block fallback
- version.py, test_support_dispatch.py, test_javascript_test_runner.py: take main's versions
- uv.lock: regenerated
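The preserved fallback behavior in code_replacer.py can be illustrated with a minimal sketch; the function name and the `finder` callback below are hypothetical stand-ins, not the actual module API:

```python
from typing import Callable, Optional, Tuple

Span = Tuple[int, int]


def replace_function(
    source: str,
    new_code: str,
    finder: Optional[Callable[[str], Optional[Span]]] = None,
) -> str:
    """Early-return style: try a precise language-specific match first,
    then fall back to treating the candidate as one replaceable block."""
    if finder is not None:
        span = finder(source)
        if span is not None:  # precise match found: splice in place
            start, end = span
            return source[:start] + new_code + source[end:]
    # Non-Python single-block fallback: replace the whole source wholesale.
    return new_code
```

The early returns keep the precise path flat while the last line carries the single-block fallback.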

Port Java into main's modular structure:
- Fix init_java.py lazy imports to point to new modules (init_config, init_auth, github_workflow)
- Add Java workflow support to github_workflow.py (detection, template, customization)
- Fix broken Java imports (function_optimizer, line_profiler) after main's module moves
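The lazy-import repair follows a standard Python pattern; the sketch below uses a stdlib module as a stand-in for the moved codeflash modules, since the real symbol names inside init_config / init_auth / github_workflow are not shown here:

```python
def generate_workflow_stub(language: str) -> str:
    # Lazy import: resolved at call time rather than module import time,
    # so moving the target module (as main's refactor did) only requires
    # updating this one call site, and CLI startup stays fast.
    import json  # stand-in for codeflash.cli_cmds.github_workflow

    return json.dumps({"language": language, "workflow": "codeflash.yaml"})
```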

Add safety tests for merge-critical functions:
- test_add_language_metadata.py: 10 tests covering per-language payload correctness
- test_code_replacer_matching.py: 8 tests covering fallback chain
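A plausible shape for the per-language payload tests; the helper mirrors the metadata logic visible in the aiservice.py hunk further down, with the language passed in explicitly so the sketch stays self-contained (the test names are illustrative, not the shipped ones):

```python
def add_language_metadata(payload: dict, language: str, language_version: str,
                          module_system: "str | None" = None) -> dict:
    payload["language_version"] = language_version  # canonical for all languages
    # Backward compat: the Python backend still expects python_version.
    payload["python_version"] = language_version if language == "python" else None
    if language != "python" and module_system:
        payload["module_system"] = module_system
    return payload


def test_python_payload() -> None:
    p = add_language_metadata({}, "python", "3.12")
    assert p == {"language_version": "3.12", "python_version": "3.12"}


def test_java_payload_omits_module_system() -> None:
    p = add_language_metadata({}, "java", "21")
    assert p["python_version"] is None
    assert "module_system" not in p
```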

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Author: Mohamed Ashraf
Date: 2026-03-13 00:15:19 +00:00
Commit: fa9d32f1c4
89 changed files with 5887 additions and 5942 deletions


@@ -6,10 +6,15 @@ When adding, moving, or deleting source files, update this doc to match.
codeflash/
├── main.py # CLI entry point
├── cli_cmds/ # Command handling, console output (Rich)
│ ├── cmd_init.py # Init orchestrator + Python-specific setup
│ ├── init_config.py # Config types, validation, writing, shared UI helpers
│ ├── init_auth.py # API key management + GitHub app installation
│ ├── github_workflow.py # GitHub Actions workflow generation
│ ├── init_javascript.py # JavaScript/TypeScript project initialization
│ └── oauth_handler.py # OAuth PKCE flow for CodeFlash authentication
├── discovery/ # Find optimizable functions
├── optimization/ # Generate optimized code via AI
│ ├── optimizer.py # Main optimization orchestration
│ └── function_optimizer.py # Per-function optimization logic
│ └── optimizer.py # Main optimization orchestration
├── verification/ # Run deterministic tests (pytest plugin)
├── benchmarking/ # Performance measurement
├── github/ # PR creation
@@ -20,12 +25,16 @@ codeflash/
│ ├── base.py # LanguageSupport protocol and shared data types
│ ├── registry.py # Language registration and lookup by extension/enum
│ ├── current.py # Current language singleton (set_current_language / current_language_support)
│ ├── function_optimizer.py # FunctionOptimizer base class for per-function optimization
│ ├── code_replacer.py # Language-agnostic code replacement
│ ├── python/
│ │ ├── support.py # PythonSupport (LanguageSupport implementation)
│ │ ├── function_optimizer.py # PythonFunctionOptimizer subclass
│ │ ├── optimizer.py # Python module preparation & AST resolution
│ │ └── normalizer.py # Python code normalization for deduplication
│ │ ├── normalizer.py # Python code normalization for deduplication
│ │ ├── test_runner.py # Test subprocess execution for Python
│ │ ├── instrument_codeflash_capture.py # Instrument __init__ with capture decorators
│ │ └── parse_line_profile_test_output.py # Parse line profiler output
│ └── javascript/
│ ├── support.py # JavaScriptSupport (LanguageSupport implementation)
│ ├── function_optimizer.py # JavaScriptFunctionOptimizer subclass
@@ -46,9 +55,9 @@ codeflash/
| Task | Start here |
|------|------------|
| CLI arguments & commands | `cli_cmds/cli.py` |
| CLI arguments & commands | `cli_cmds/cli.py` (parsing), `main.py` (subcommand dispatch) |
| Optimization orchestration | `optimization/optimizer.py` `run()` |
| Per-function optimization | `optimization/function_optimizer.py` (base), `languages/python/function_optimizer.py`, `languages/javascript/function_optimizer.py` |
| Per-function optimization | `languages/function_optimizer.py` (base), `languages/python/function_optimizer.py`, `languages/javascript/function_optimizer.py` |
| Function discovery | `discovery/functions_to_optimize.py` |
| Context extraction | `languages/<lang>/context/code_context_extractor.py` |
| Test execution | `languages/<lang>/support.py` (`run_behavioral_tests`, etc.), `verification/pytest_plugin.py` |


@@ -198,12 +198,20 @@ jobs:
For each PR:
- If CI passes and the PR is mergeable → merge with `--squash --delete-branch`
- Close the PR as stale if ANY of these apply:
- If CI is failing:
1. Check out the PR branch and inspect the failing tests
2. Attempt to fix the failures (the optimization may have broken tests or introduced issues)
3. If fixed: commit, push, and leave a comment explaining what was fixed
4. If unfixable: close with `gh pr close <number> --comment "Closing: CI checks are failing — <describe the specific failures and why they can't be auto-fixed>." --delete-branch`
- Close the PR (without attempting fixes) if ANY of these apply:
- Older than 7 days
- Has merge conflicts (mergeable state is "CONFLICTING")
- CI is failing
- The optimized function no longer exists in the target file (check the diff)
Close with: `gh pr close <number> --comment "Closing stale optimization PR." --delete-branch`
Close with: `gh pr close <number> --comment "<reason>" --delete-branch`
where <reason> explains WHY the PR is being closed. Examples:
- "Closing: PR is older than 7 days without being merged."
- "Closing: merge conflicts with the target branch."
- "Closing: the optimized function no longer exists in the target file."
</step>
<verification>


@@ -39,4 +39,4 @@ jobs:
- name: Codeflash Optimization
id: optimize_code
run: |
uv run codeflash --benchmark
uv run codeflash --benchmark --testgen-review


@@ -14,7 +14,7 @@
}
},
"../../../packages/codeflash": {
"version": "0.3.1",
"version": "0.10.1",
"dev": true,
"hasInstallScript": true,
"license": "MIT",
@@ -2892,9 +2892,9 @@
}
},
"node_modules/minimatch": {
"version": "3.1.2",
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz",
"integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==",
"version": "3.1.5",
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz",
"integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==",
"dev": true,
"license": "ISC",
"dependencies": {


@@ -14,7 +14,7 @@
}
},
"../../../packages/codeflash": {
"version": "0.3.1",
"version": "0.10.1",
"dev": true,
"hasInstallScript": true,
"license": "MIT",
@@ -2892,9 +2892,9 @@
}
},
"node_modules/minimatch": {
"version": "3.1.2",
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz",
"integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==",
"version": "3.1.5",
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz",
"integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==",
"dev": true,
"license": "ISC",
"dependencies": {


@@ -17,7 +17,7 @@
}
},
"../../../packages/codeflash": {
"version": "0.3.1",
"version": "0.10.1",
"dev": true,
"hasInstallScript": true,
"license": "MIT",
@@ -3474,12 +3474,6 @@
"@types/yargs-parser": "*"
}
},
"node_modules/@types/babel__generator": {
"dev": true
},
"node_modules/@types/babel__template": {
"dev": true
},
"node_modules/@types/estree": {
"version": "1.0.8",
"resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.8.tgz",
@@ -3487,19 +3481,10 @@
"dev": true,
"license": "MIT"
},
"node_modules/@types/istanbul-lib-report": {
"dev": true
},
"node_modules/@types/node": {
"dev": true
},
"node_modules/@types/yargs-parser": {
"dev": true
},
"node_modules/ajv": {
"version": "6.12.6",
"resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz",
"integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==",
"version": "6.14.0",
"resolved": "https://registry.npmjs.org/ajv/-/ajv-6.14.0.tgz",
"integrity": "sha512-IWrosm/yrn43eiKqkfkHis7QioDleaXQHdDVPKg0FSwwd/DuvyX79TZnFOnYpB7dcsFAMmtFztZuXPDvSePkFw==",
"dev": true,
"license": "MIT",
"dependencies": {
@@ -7229,9 +7214,9 @@
"license": "MIT"
},
"node_modules/minimatch": {
"version": "3.1.2",
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz",
"integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==",
"version": "3.1.5",
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz",
"integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==",
"dev": true,
"license": "ISC",
"dependencies": {


@@ -9,7 +9,7 @@
"version": "1.0.0",
"devDependencies": {
"codeflash": "file:../../../packages/codeflash",
"mocha": "^10.7.0"
"mocha": "^10.8.2"
}
},
"../../../packages/codeflash": {


@@ -8,6 +8,6 @@
},
"devDependencies": {
"codeflash": "file:../../../packages/codeflash",
"mocha": "^10.7.0"
"mocha": "^10.8.2"
}
}


@@ -20,7 +20,7 @@
}
},
"../../../packages/codeflash": {
"version": "0.4.0",
"version": "0.10.1",
"dev": true,
"hasInstallScript": true,
"license": "MIT",
@@ -3060,9 +3060,9 @@
}
},
"node_modules/minimatch": {
"version": "3.1.2",
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz",
"integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==",
"version": "3.1.5",
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz",
"integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==",
"dev": true,
"license": "ISC",
"dependencies": {


@@ -15,7 +15,7 @@
}
},
"../../../packages/codeflash": {
"version": "0.8.0",
"version": "0.10.1",
"dev": true,
"hasInstallScript": true,
"license": "MIT",
@@ -497,9 +497,9 @@
"license": "MIT"
},
"node_modules/@rollup/rollup-android-arm-eabi": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.57.1.tgz",
"integrity": "sha512-A6ehUVSiSaaliTxai040ZpZ2zTevHYbvu/lDoeAteHI8QnaosIzm4qwtezfRg1jOYaUmnzLX1AOD6Z+UJjtifg==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.59.0.tgz",
"integrity": "sha512-upnNBkA6ZH2VKGcBj9Fyl9IGNPULcjXRlg0LLeaioQWueH30p6IXtJEbKAgvyv+mJaMxSm1l6xwDXYjpEMiLMg==",
"cpu": [
"arm"
],
@@ -511,9 +511,9 @@
]
},
"node_modules/@rollup/rollup-android-arm64": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.57.1.tgz",
"integrity": "sha512-dQaAddCY9YgkFHZcFNS/606Exo8vcLHwArFZ7vxXq4rigo2bb494/xKMMwRRQW6ug7Js6yXmBZhSBRuBvCCQ3w==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.59.0.tgz",
"integrity": "sha512-hZ+Zxj3SySm4A/DylsDKZAeVg0mvi++0PYVceVyX7hemkw7OreKdCvW2oQ3T1FMZvCaQXqOTHb8qmBShoqk69Q==",
"cpu": [
"arm64"
],
@@ -525,9 +525,9 @@
]
},
"node_modules/@rollup/rollup-darwin-arm64": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.57.1.tgz",
"integrity": "sha512-crNPrwJOrRxagUYeMn/DZwqN88SDmwaJ8Cvi/TN1HnWBU7GwknckyosC2gd0IqYRsHDEnXf328o9/HC6OkPgOg==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.59.0.tgz",
"integrity": "sha512-W2Psnbh1J8ZJw0xKAd8zdNgF9HRLkdWwwdWqubSVk0pUuQkoHnv7rx4GiF9rT4t5DIZGAsConRE3AxCdJ4m8rg==",
"cpu": [
"arm64"
],
@@ -539,9 +539,9 @@
]
},
"node_modules/@rollup/rollup-darwin-x64": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.57.1.tgz",
"integrity": "sha512-Ji8g8ChVbKrhFtig5QBV7iMaJrGtpHelkB3lsaKzadFBe58gmjfGXAOfI5FV0lYMH8wiqsxKQ1C9B0YTRXVy4w==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.59.0.tgz",
"integrity": "sha512-ZW2KkwlS4lwTv7ZVsYDiARfFCnSGhzYPdiOU4IM2fDbL+QGlyAbjgSFuqNRbSthybLbIJ915UtZBtmuLrQAT/w==",
"cpu": [
"x64"
],
@@ -553,9 +553,9 @@
]
},
"node_modules/@rollup/rollup-freebsd-arm64": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-arm64/-/rollup-freebsd-arm64-4.57.1.tgz",
"integrity": "sha512-R+/WwhsjmwodAcz65guCGFRkMb4gKWTcIeLy60JJQbXrJ97BOXHxnkPFrP+YwFlaS0m+uWJTstrUA9o+UchFug==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-arm64/-/rollup-freebsd-arm64-4.59.0.tgz",
"integrity": "sha512-EsKaJ5ytAu9jI3lonzn3BgG8iRBjV4LxZexygcQbpiU0wU0ATxhNVEpXKfUa0pS05gTcSDMKpn3Sx+QB9RlTTA==",
"cpu": [
"arm64"
],
@@ -567,9 +567,9 @@
]
},
"node_modules/@rollup/rollup-freebsd-x64": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-x64/-/rollup-freebsd-x64-4.57.1.tgz",
"integrity": "sha512-IEQTCHeiTOnAUC3IDQdzRAGj3jOAYNr9kBguI7MQAAZK3caezRrg0GxAb6Hchg4lxdZEI5Oq3iov/w/hnFWY9Q==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-x64/-/rollup-freebsd-x64-4.59.0.tgz",
"integrity": "sha512-d3DuZi2KzTMjImrxoHIAODUZYoUUMsuUiY4SRRcJy6NJoZ6iIqWnJu9IScV9jXysyGMVuW+KNzZvBLOcpdl3Vg==",
"cpu": [
"x64"
],
@@ -581,13 +581,16 @@
]
},
"node_modules/@rollup/rollup-linux-arm-gnueabihf": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.57.1.tgz",
"integrity": "sha512-F8sWbhZ7tyuEfsmOxwc2giKDQzN3+kuBLPwwZGyVkLlKGdV1nvnNwYD0fKQ8+XS6hp9nY7B+ZeK01EBUE7aHaw==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.59.0.tgz",
"integrity": "sha512-t4ONHboXi/3E0rT6OZl1pKbl2Vgxf9vJfWgmUoCEVQVxhW6Cw/c8I6hbbu7DAvgp82RKiH7TpLwxnJeKv2pbsw==",
"cpu": [
"arm"
],
"dev": true,
"libc": [
"glibc"
],
"license": "MIT",
"optional": true,
"os": [
@@ -595,13 +598,16 @@
]
},
"node_modules/@rollup/rollup-linux-arm-musleabihf": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-musleabihf/-/rollup-linux-arm-musleabihf-4.57.1.tgz",
"integrity": "sha512-rGfNUfn0GIeXtBP1wL5MnzSj98+PZe/AXaGBCRmT0ts80lU5CATYGxXukeTX39XBKsxzFpEeK+Mrp9faXOlmrw==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-musleabihf/-/rollup-linux-arm-musleabihf-4.59.0.tgz",
"integrity": "sha512-CikFT7aYPA2ufMD086cVORBYGHffBo4K8MQ4uPS/ZnY54GKj36i196u8U+aDVT2LX4eSMbyHtyOh7D7Zvk2VvA==",
"cpu": [
"arm"
],
"dev": true,
"libc": [
"musl"
],
"license": "MIT",
"optional": true,
"os": [
@@ -609,13 +615,16 @@
]
},
"node_modules/@rollup/rollup-linux-arm64-gnu": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.57.1.tgz",
"integrity": "sha512-MMtej3YHWeg/0klK2Qodf3yrNzz6CGjo2UntLvk2RSPlhzgLvYEB3frRvbEF2wRKh1Z2fDIg9KRPe1fawv7C+g==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.59.0.tgz",
"integrity": "sha512-jYgUGk5aLd1nUb1CtQ8E+t5JhLc9x5WdBKew9ZgAXg7DBk0ZHErLHdXM24rfX+bKrFe+Xp5YuJo54I5HFjGDAA==",
"cpu": [
"arm64"
],
"dev": true,
"libc": [
"glibc"
],
"license": "MIT",
"optional": true,
"os": [
@@ -623,13 +632,16 @@
]
},
"node_modules/@rollup/rollup-linux-arm64-musl": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.57.1.tgz",
"integrity": "sha512-1a/qhaaOXhqXGpMFMET9VqwZakkljWHLmZOX48R0I/YLbhdxr1m4gtG1Hq7++VhVUmf+L3sTAf9op4JlhQ5u1Q==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.59.0.tgz",
"integrity": "sha512-peZRVEdnFWZ5Bh2KeumKG9ty7aCXzzEsHShOZEFiCQlDEepP1dpUl/SrUNXNg13UmZl+gzVDPsiCwnV1uI0RUA==",
"cpu": [
"arm64"
],
"dev": true,
"libc": [
"musl"
],
"license": "MIT",
"optional": true,
"os": [
@@ -637,13 +649,16 @@
]
},
"node_modules/@rollup/rollup-linux-loong64-gnu": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-gnu/-/rollup-linux-loong64-gnu-4.57.1.tgz",
"integrity": "sha512-QWO6RQTZ/cqYtJMtxhkRkidoNGXc7ERPbZN7dVW5SdURuLeVU7lwKMpo18XdcmpWYd0qsP1bwKPf7DNSUinhvA==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-gnu/-/rollup-linux-loong64-gnu-4.59.0.tgz",
"integrity": "sha512-gbUSW/97f7+r4gHy3Jlup8zDG190AuodsWnNiXErp9mT90iCy9NKKU0Xwx5k8VlRAIV2uU9CsMnEFg/xXaOfXg==",
"cpu": [
"loong64"
],
"dev": true,
"libc": [
"glibc"
],
"license": "MIT",
"optional": true,
"os": [
@@ -651,13 +666,16 @@
]
},
"node_modules/@rollup/rollup-linux-loong64-musl": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-musl/-/rollup-linux-loong64-musl-4.57.1.tgz",
"integrity": "sha512-xpObYIf+8gprgWaPP32xiN5RVTi/s5FCR+XMXSKmhfoJjrpRAjCuuqQXyxUa/eJTdAE6eJ+KDKaoEqjZQxh3Gw==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-musl/-/rollup-linux-loong64-musl-4.59.0.tgz",
"integrity": "sha512-yTRONe79E+o0FWFijasoTjtzG9EBedFXJMl888NBEDCDV9I2wGbFFfJQQe63OijbFCUZqxpHz1GzpbtSFikJ4Q==",
"cpu": [
"loong64"
],
"dev": true,
"libc": [
"musl"
],
"license": "MIT",
"optional": true,
"os": [
@@ -665,13 +683,16 @@
]
},
"node_modules/@rollup/rollup-linux-ppc64-gnu": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-gnu/-/rollup-linux-ppc64-gnu-4.57.1.tgz",
"integrity": "sha512-4BrCgrpZo4hvzMDKRqEaW1zeecScDCR+2nZ86ATLhAoJ5FQ+lbHVD3ttKe74/c7tNT9c6F2viwB3ufwp01Oh2w==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-gnu/-/rollup-linux-ppc64-gnu-4.59.0.tgz",
"integrity": "sha512-sw1o3tfyk12k3OEpRddF68a1unZ5VCN7zoTNtSn2KndUE+ea3m3ROOKRCZxEpmT9nsGnogpFP9x6mnLTCaoLkA==",
"cpu": [
"ppc64"
],
"dev": true,
"libc": [
"glibc"
],
"license": "MIT",
"optional": true,
"os": [
@@ -679,13 +700,16 @@
]
},
"node_modules/@rollup/rollup-linux-ppc64-musl": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-musl/-/rollup-linux-ppc64-musl-4.57.1.tgz",
"integrity": "sha512-NOlUuzesGauESAyEYFSe3QTUguL+lvrN1HtwEEsU2rOwdUDeTMJdO5dUYl/2hKf9jWydJrO9OL/XSSf65R5+Xw==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-musl/-/rollup-linux-ppc64-musl-4.59.0.tgz",
"integrity": "sha512-+2kLtQ4xT3AiIxkzFVFXfsmlZiG5FXYW7ZyIIvGA7Bdeuh9Z0aN4hVyXS/G1E9bTP/vqszNIN/pUKCk/BTHsKA==",
"cpu": [
"ppc64"
],
"dev": true,
"libc": [
"musl"
],
"license": "MIT",
"optional": true,
"os": [
@@ -693,13 +717,16 @@
]
},
"node_modules/@rollup/rollup-linux-riscv64-gnu": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.57.1.tgz",
"integrity": "sha512-ptA88htVp0AwUUqhVghwDIKlvJMD/fmL/wrQj99PRHFRAG6Z5nbWoWG4o81Nt9FT+IuqUQi+L31ZKAFeJ5Is+A==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.59.0.tgz",
"integrity": "sha512-NDYMpsXYJJaj+I7UdwIuHHNxXZ/b/N2hR15NyH3m2qAtb/hHPA4g4SuuvrdxetTdndfj9b1WOmy73kcPRoERUg==",
"cpu": [
"riscv64"
],
"dev": true,
"libc": [
"glibc"
],
"license": "MIT",
"optional": true,
"os": [
@@ -707,13 +734,16 @@
]
},
"node_modules/@rollup/rollup-linux-riscv64-musl": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-musl/-/rollup-linux-riscv64-musl-4.57.1.tgz",
"integrity": "sha512-S51t7aMMTNdmAMPpBg7OOsTdn4tySRQvklmL3RpDRyknk87+Sp3xaumlatU+ppQ+5raY7sSTcC2beGgvhENfuw==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-musl/-/rollup-linux-riscv64-musl-4.59.0.tgz",
"integrity": "sha512-nLckB8WOqHIf1bhymk+oHxvM9D3tyPndZH8i8+35p/1YiVoVswPid2yLzgX7ZJP0KQvnkhM4H6QZ5m0LzbyIAg==",
"cpu": [
"riscv64"
],
"dev": true,
"libc": [
"musl"
],
"license": "MIT",
"optional": true,
"os": [
@@ -721,13 +751,16 @@
]
},
"node_modules/@rollup/rollup-linux-s390x-gnu": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-s390x-gnu/-/rollup-linux-s390x-gnu-4.57.1.tgz",
"integrity": "sha512-Bl00OFnVFkL82FHbEqy3k5CUCKH6OEJL54KCyx2oqsmZnFTR8IoNqBF+mjQVcRCT5sB6yOvK8A37LNm/kPJiZg==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-s390x-gnu/-/rollup-linux-s390x-gnu-4.59.0.tgz",
"integrity": "sha512-oF87Ie3uAIvORFBpwnCvUzdeYUqi2wY6jRFWJAy1qus/udHFYIkplYRW+wo+GRUP4sKzYdmE1Y3+rY5Gc4ZO+w==",
"cpu": [
"s390x"
],
"dev": true,
"libc": [
"glibc"
],
"license": "MIT",
"optional": true,
"os": [
@@ -735,13 +768,16 @@
]
},
"node_modules/@rollup/rollup-linux-x64-gnu": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.57.1.tgz",
"integrity": "sha512-ABca4ceT4N+Tv/GtotnWAeXZUZuM/9AQyCyKYyKnpk4yoA7QIAuBt6Hkgpw8kActYlew2mvckXkvx0FfoInnLg==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.59.0.tgz",
"integrity": "sha512-3AHmtQq/ppNuUspKAlvA8HtLybkDflkMuLK4DPo77DfthRb71V84/c4MlWJXixZz4uruIH4uaa07IqoAkG64fg==",
"cpu": [
"x64"
],
"dev": true,
"libc": [
"glibc"
],
"license": "MIT",
"optional": true,
"os": [
@@ -749,13 +785,16 @@
]
},
"node_modules/@rollup/rollup-linux-x64-musl": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.57.1.tgz",
"integrity": "sha512-HFps0JeGtuOR2convgRRkHCekD7j+gdAuXM+/i6kGzQtFhlCtQkpwtNzkNj6QhCDp7DRJ7+qC/1Vg2jt5iSOFw==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.59.0.tgz",
"integrity": "sha512-2UdiwS/9cTAx7qIUZB/fWtToJwvt0Vbo0zmnYt7ED35KPg13Q0ym1g442THLC7VyI6JfYTP4PiSOWyoMdV2/xg==",
"cpu": [
"x64"
],
"dev": true,
"libc": [
"musl"
],
"license": "MIT",
"optional": true,
"os": [
@@ -763,9 +802,9 @@
]
},
"node_modules/@rollup/rollup-openbsd-x64": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/@rollup/rollup-openbsd-x64/-/rollup-openbsd-x64-4.57.1.tgz",
"integrity": "sha512-H+hXEv9gdVQuDTgnqD+SQffoWoc0Of59AStSzTEj/feWTBAnSfSD3+Dql1ZruJQxmykT/JVY0dE8Ka7z0DH1hw==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/@rollup/rollup-openbsd-x64/-/rollup-openbsd-x64-4.59.0.tgz",
"integrity": "sha512-M3bLRAVk6GOwFlPTIxVBSYKUaqfLrn8l0psKinkCFxl4lQvOSz8ZrKDz2gxcBwHFpci0B6rttydI4IpS4IS/jQ==",
"cpu": [
"x64"
],
@@ -777,9 +816,9 @@
]
},
"node_modules/@rollup/rollup-openharmony-arm64": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/@rollup/rollup-openharmony-arm64/-/rollup-openharmony-arm64-4.57.1.tgz",
"integrity": "sha512-4wYoDpNg6o/oPximyc/NG+mYUejZrCU2q+2w6YZqrAs2UcNUChIZXjtafAiiZSUc7On8v5NyNj34Kzj/Ltk6dQ==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/@rollup/rollup-openharmony-arm64/-/rollup-openharmony-arm64-4.59.0.tgz",
"integrity": "sha512-tt9KBJqaqp5i5HUZzoafHZX8b5Q2Fe7UjYERADll83O4fGqJ49O1FsL6LpdzVFQcpwvnyd0i+K/VSwu/o/nWlA==",
"cpu": [
"arm64"
],
@@ -791,9 +830,9 @@
]
},
"node_modules/@rollup/rollup-win32-arm64-msvc": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.57.1.tgz",
"integrity": "sha512-O54mtsV/6LW3P8qdTcamQmuC990HDfR71lo44oZMZlXU4tzLrbvTii87Ni9opq60ds0YzuAlEr/GNwuNluZyMQ==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.59.0.tgz",
"integrity": "sha512-V5B6mG7OrGTwnxaNUzZTDTjDS7F75PO1ae6MJYdiMu60sq0CqN5CVeVsbhPxalupvTX8gXVSU9gq+Rx1/hvu6A==",
"cpu": [
"arm64"
],
@@ -805,9 +844,9 @@
]
},
"node_modules/@rollup/rollup-win32-ia32-msvc": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.57.1.tgz",
"integrity": "sha512-P3dLS+IerxCT/7D2q2FYcRdWRl22dNbrbBEtxdWhXrfIMPP9lQhb5h4Du04mdl5Woq05jVCDPCMF7Ub0NAjIew==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.59.0.tgz",
"integrity": "sha512-UKFMHPuM9R0iBegwzKF4y0C4J9u8C6MEJgFuXTBerMk7EJ92GFVFYBfOZaSGLu6COf7FxpQNqhNS4c4icUPqxA==",
"cpu": [
"ia32"
],
@@ -819,9 +858,9 @@
]
},
"node_modules/@rollup/rollup-win32-x64-gnu": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-gnu/-/rollup-win32-x64-gnu-4.57.1.tgz",
"integrity": "sha512-VMBH2eOOaKGtIJYleXsi2B8CPVADrh+TyNxJ4mWPnKfLB/DBUmzW+5m1xUrcwWoMfSLagIRpjUFeW5CO5hyciQ==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-gnu/-/rollup-win32-x64-gnu-4.59.0.tgz",
"integrity": "sha512-laBkYlSS1n2L8fSo1thDNGrCTQMmxjYY5G0WFWjFFYZkKPjsMBsgJfGf4TLxXrF6RyhI60L8TMOjBMvXiTcxeA==",
"cpu": [
"x64"
],
@@ -833,9 +872,9 @@
]
},
"node_modules/@rollup/rollup-win32-x64-msvc": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.57.1.tgz",
"integrity": "sha512-mxRFDdHIWRxg3UfIIAwCm6NzvxG0jDX/wBN6KsQFTvKFqqg9vTrWUE68qEjHt19A5wwx5X5aUi2zuZT7YR0jrA==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.59.0.tgz",
"integrity": "sha512-2HRCml6OztYXyJXAvdDXPKcawukWY2GpR5/nxKp4iBgiO3wcoEGkAaqctIbZcNB6KlUQBIqt8VYkNSj2397EfA==",
"cpu": [
"x64"
],
@@ -1222,9 +1261,9 @@
}
},
"node_modules/rollup": {
"version": "4.57.1",
"resolved": "https://registry.npmjs.org/rollup/-/rollup-4.57.1.tgz",
"integrity": "sha512-oQL6lgK3e2QZeQ7gcgIkS2YZPg5slw37hYufJ3edKlfQSGGm8ICoxswK15ntSzF/a8+h7ekRy7k7oWc3BQ7y8A==",
"version": "4.59.0",
"resolved": "https://registry.npmjs.org/rollup/-/rollup-4.59.0.tgz",
"integrity": "sha512-2oMpl67a3zCH9H79LeMcbDhXW/UmWG/y2zuqnF2jQq5uq9TbM9TVyXvA4+t+ne2IIkBdrLpAaRQAvo7YI/Yyeg==",
"dev": true,
"license": "MIT",
"dependencies": {
@@ -1238,31 +1277,31 @@
"npm": ">=8.0.0"
},
"optionalDependencies": {
"@rollup/rollup-android-arm-eabi": "4.57.1",
"@rollup/rollup-android-arm64": "4.57.1",
"@rollup/rollup-darwin-arm64": "4.57.1",
"@rollup/rollup-darwin-x64": "4.57.1",
"@rollup/rollup-freebsd-arm64": "4.57.1",
"@rollup/rollup-freebsd-x64": "4.57.1",
"@rollup/rollup-linux-arm-gnueabihf": "4.57.1",
"@rollup/rollup-linux-arm-musleabihf": "4.57.1",
"@rollup/rollup-linux-arm64-gnu": "4.57.1",
"@rollup/rollup-linux-arm64-musl": "4.57.1",
"@rollup/rollup-linux-loong64-gnu": "4.57.1",
"@rollup/rollup-linux-loong64-musl": "4.57.1",
"@rollup/rollup-linux-ppc64-gnu": "4.57.1",
"@rollup/rollup-linux-ppc64-musl": "4.57.1",
"@rollup/rollup-linux-riscv64-gnu": "4.57.1",
"@rollup/rollup-linux-riscv64-musl": "4.57.1",
"@rollup/rollup-linux-s390x-gnu": "4.57.1",
"@rollup/rollup-linux-x64-gnu": "4.57.1",
"@rollup/rollup-linux-x64-musl": "4.57.1",
"@rollup/rollup-openbsd-x64": "4.57.1",
"@rollup/rollup-openharmony-arm64": "4.57.1",
"@rollup/rollup-win32-arm64-msvc": "4.57.1",
"@rollup/rollup-win32-ia32-msvc": "4.57.1",
"@rollup/rollup-win32-x64-gnu": "4.57.1",
"@rollup/rollup-win32-x64-msvc": "4.57.1",
"@rollup/rollup-android-arm-eabi": "4.59.0",
"@rollup/rollup-android-arm64": "4.59.0",
"@rollup/rollup-darwin-arm64": "4.59.0",
"@rollup/rollup-darwin-x64": "4.59.0",
"@rollup/rollup-freebsd-arm64": "4.59.0",
"@rollup/rollup-freebsd-x64": "4.59.0",
"@rollup/rollup-linux-arm-gnueabihf": "4.59.0",
"@rollup/rollup-linux-arm-musleabihf": "4.59.0",
"@rollup/rollup-linux-arm64-gnu": "4.59.0",
"@rollup/rollup-linux-arm64-musl": "4.59.0",
"@rollup/rollup-linux-loong64-gnu": "4.59.0",
"@rollup/rollup-linux-loong64-musl": "4.59.0",
"@rollup/rollup-linux-ppc64-gnu": "4.59.0",
"@rollup/rollup-linux-ppc64-musl": "4.59.0",
"@rollup/rollup-linux-riscv64-gnu": "4.59.0",
"@rollup/rollup-linux-riscv64-musl": "4.59.0",
"@rollup/rollup-linux-s390x-gnu": "4.59.0",
"@rollup/rollup-linux-x64-gnu": "4.59.0",
"@rollup/rollup-linux-x64-musl": "4.59.0",
"@rollup/rollup-openbsd-x64": "4.59.0",
"@rollup/rollup-openharmony-arm64": "4.59.0",
"@rollup/rollup-win32-arm64-msvc": "4.59.0",
"@rollup/rollup-win32-ia32-msvc": "4.59.0",
"@rollup/rollup-win32-x64-gnu": "4.59.0",
"@rollup/rollup-win32-x64-msvc": "4.59.0",
"fsevents": "~2.3.2"
}
},


@@ -15,14 +15,15 @@ from codeflash.code_utils.env_utils import get_codeflash_api_key
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.languages import Language, current_language
from codeflash.languages.current import current_language_support
from codeflash.models.ExperimentMetadata import ExperimentMetadata
from codeflash.models.models import (
AIServiceRefinerRequest,
CodeStringsMarkdown,
FunctionRepairInfo,
OptimizationReviewResult,
OptimizedCandidate,
OptimizedCandidateSource,
TestFileReview,
)
from codeflash.telemetry.posthog_cf import ph
from codeflash.version import __version__ as codeflash_version
@@ -57,15 +58,23 @@ class AiServiceClient:
payload: dict[str, Any], language_version: str | None = None, module_system: str | None = None
) -> None:
"""Add language version and module system metadata to an API payload."""
# Canonical for all languages
payload["language_version"] = language_version
# Backward compat: Python backend still expects python_version
payload["python_version"] = language_version if current_language() == Language.PYTHON else None
if current_language() != Language.PYTHON:
if module_system:
payload["module_system"] = module_system
@staticmethod
def log_error_response(response: requests.Response, action: str, ph_event: str) -> None:
"""Log and report an API error response."""
try:
error = response.json()["error"]
except Exception:
error = response.text
logger.error(f"Error {action}: {response.status_code} - {error}")
ph(ph_event, {"response_status_code": response.status_code, "error": error})
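The new `log_error_response` helper folds together the identical try/except blocks removed in the later hunks. The pattern in isolation, with a stub standing in for `requests.Response` so the sketch runs on its own:

```python
class StubResponse:
    """Minimal stand-in for requests.Response."""

    def __init__(self, status_code: int, text: str, json_body: "dict | None" = None):
        self.status_code = status_code
        self.text = text
        self._json = json_body

    def json(self) -> dict:
        if self._json is None:
            raise ValueError("body is not JSON")
        return self._json


def extract_error(response) -> str:
    # Prefer the structured "error" field; fall back to the raw body when
    # the response is not JSON or lacks that key.
    try:
        return response.json()["error"]
    except Exception:
        return response.text
```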
def get_aiservice_base_url(self) -> str:
if os.environ.get("CODEFLASH_AIS_SERVER", default="prod").lower() == "local":
logger.info("Using local AI Service at http://localhost:8000")
@@ -97,14 +106,6 @@ class AiServiceClient:
------
requests.exceptions.RequestException: If the request fails
"""
"""Make an API request to the given endpoint on the AI service.
:param endpoint: The endpoint to call, e.g., "/optimize".
:param method: The HTTP method to use ('GET' or 'POST').
:param payload: Optional JSON payload to include in the POST request body.
:param timeout: The timeout for the request.
:return: The response object from the API.
"""
url = f"{self.base_url}/ai{endpoint}"
if method.upper() == "POST":
@@ -214,12 +215,7 @@ class AiServiceClient:
logger.info(f"!lsp|Received {len(optimizations_json)} optimization candidates.")
console.rule()
return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE, language)
try:
error = response.json()["error"]
except Exception:
error = response.text
logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
self.log_error_response(response, "generating optimized candidates", "cli-optimize-error-response")
console.rule()
return []
@@ -286,12 +282,7 @@ class AiServiceClient:
end_time = time.perf_counter()
logger.debug(f"!lsp|Generating jit rewritten code took {end_time - start_time:.2f} seconds.")
return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.JIT_REWRITE)
try:
error = response.json()["error"]
except Exception:
error = response.text
logger.error(f"Error generating jit rewritten candidate: {response.status_code} - {error}")
ph("cli-jit-rewrite-error-response", {"response_status_code": response.status_code, "error": error})
self.log_error_response(response, "generating jit rewritten candidate", "cli-jit-rewrite-error-response")
console.rule()
return []
@@ -360,12 +351,7 @@ class AiServiceClient:
logger.info(f"!lsp|Received {len(optimizations_json)} line profiler optimization candidates.")
console.rule()
return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE_LP)
try:
error = response.json()["error"]
except Exception:
error = response.text
logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
self.log_error_response(response, "generating optimized candidates", "cli-optimize-error-response")
console.rule()
return []
@ -393,12 +379,7 @@ class AiServiceClient:
return valid_candidates[0]
try:
error = response.json()["error"]
except Exception:
error = response.text
logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
self.log_error_response(response, "generating optimized candidates", "cli-optimize-error-response")
return None
def optimize_code_refinement(self, request: list[AIServiceRefinerRequest]) -> list[OptimizedCandidate]:
@ -454,12 +435,7 @@ class AiServiceClient:
return self._get_valid_candidates(refined_optimizations, OptimizedCandidateSource.REFINE)
try:
error = response.json()["error"]
except Exception:
error = response.text
logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
self.log_error_response(response, "generating optimized candidates", "cli-optimize-error-response")
console.rule()
return []
@ -506,12 +482,7 @@ class AiServiceClient:
return valid_candidates[0]
try:
error = response.json()["error"]
except Exception:
error = response.text
logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
self.log_error_response(response, "generating optimized candidates", "cli-optimize-error-response")
console.rule()
return None
@ -606,12 +577,7 @@ class AiServiceClient:
explanation: str = response.json()["explanation"]
console.rule()
return explanation
try:
error = response.json()["error"]
except Exception:
error = response.text
logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
self.log_error_response(response, "generating optimized candidates", "cli-optimize-error-response")
console.rule()
return ""
@ -658,12 +624,7 @@ class AiServiceClient:
ranking: list[int] = response.json()["ranking"]
console.rule()
return ranking
try:
error = response.json()["error"]
except Exception:
error = response.text
logger.error(f"Error generating ranking: {response.status_code} - {error}")
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
self.log_error_response(response, "generating ranking", "cli-optimize-error-response")
console.rule()
return None
@ -724,7 +685,7 @@ class AiServiceClient:
language_version: str | None = None,
module_system: str | None = None,
is_numerical_code: bool | None = None,
) -> tuple[str, str, str] | None:
) -> tuple[str, str, str, str | None] | None:
"""Generate regression tests for the given function by making a request to the Django endpoint.
Parameters
@ -747,6 +708,8 @@ class AiServiceClient:
"""
# Validate test framework based on language
from codeflash.languages.current import current_language_support
lang_support = current_language_support()
valid_frameworks = lang_support.valid_test_frameworks
assert test_framework in valid_frameworks, (
@ -779,6 +742,8 @@ class AiServiceClient:
try:
response = self.make_ai_service_request("/testgen", payload=payload, timeout=self.timeout)
except requests.exceptions.RequestException as e:
from codeflash.telemetry.posthog_cf import ph
logger.exception(f"Error generating tests: {e}")
ph("cli-testgen-error-caught", {"error": str(e)})
return None
@ -792,17 +757,111 @@ class AiServiceClient:
response_json["generated_tests"],
response_json["instrumented_behavior_tests"],
response_json["instrumented_perf_tests"],
response_json.get("raw_generated_tests"),
)
self.log_error_response(response, "generating tests", "cli-testgen-error-response")
return None
def review_generated_tests(
self,
tests: list[dict[str, Any]],
function_source_code: str,
function_name: str,
trace_id: str,
coverage_summary: str = "",
coverage_details: dict[str, Any] | None = None,
language: str = "python",
) -> list[TestFileReview]:
payload: dict[str, Any] = {
"tests": tests,
"function_source_code": function_source_code,
"function_name": function_name,
"trace_id": trace_id,
"language": language,
"codeflash_version": codeflash_version,
"call_sequence": self.get_next_sequence(),
}
if coverage_summary:
payload["coverage_summary"] = coverage_summary
if coverage_details:
payload["coverage_details"] = coverage_details
self.add_language_metadata(payload)
try:
response = self.make_ai_service_request("/testgen_review", payload=payload, timeout=self.timeout)
except requests.exceptions.RequestException as e:
logger.exception(f"Error reviewing generated tests: {e}")
ph("cli-testgen-review-error-caught", {"error": str(e)})
return []
if response.status_code == 200:
data = response.json()
return [
TestFileReview(
test_index=r["test_index"],
functions_to_repair=[
FunctionRepairInfo(function_name=f["function_name"], reason=f.get("reason", ""))
for f in r.get("functions", [])
],
)
for r in data.get("reviews", [])
]
self.log_error_response(response, "reviewing generated tests", "cli-testgen-review-error-response")
return []
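The new `review_generated_tests` method builds `TestFileReview` and `FunctionRepairInfo` objects from the JSON response. Their real definitions live elsewhere in the codebase; a minimal sketch consistent with how the diff uses them:

```python
from dataclasses import dataclass, field

@dataclass
class FunctionRepairInfo:
    function_name: str
    reason: str = ""

@dataclass
class TestFileReview:
    test_index: int
    functions_to_repair: list = field(default_factory=list)

# Parsing mirrors the 200-response handling in the diff above.
data = {"reviews": [{"test_index": 0, "functions": [{"function_name": "sorter"}]}]}
reviews = [
    TestFileReview(
        test_index=r["test_index"],
        functions_to_repair=[
            FunctionRepairInfo(function_name=f["function_name"], reason=f.get("reason", ""))
            for f in r.get("functions", [])
        ],
    )
    for r in data.get("reviews", [])
]
```

Using `.get(...)` with defaults keeps the client tolerant of responses that omit `reviews`, `functions`, or `reason`.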
def repair_generated_tests(
self,
test_source: str,
functions_to_repair: list[FunctionRepairInfo],
function_source_code: str,
function_to_optimize: FunctionToOptimize,
helper_function_names: list[str],
module_path: Path,
test_module_path: Path,
test_framework: str,
test_timeout: int,
trace_id: str,
language: str = "python",
coverage_details: dict[str, Any] | None = None,
previous_repair_errors: dict[str, str] | None = None,
module_source_code: str = "",
) -> tuple[str, str, str] | None:
payload: dict[str, Any] = {
"test_source": test_source,
"functions_to_repair": [
{"function_name": f.function_name, "reason": f.reason} for f in functions_to_repair
],
"function_source_code": function_source_code,
"function_to_optimize": function_to_optimize,
"helper_function_names": helper_function_names,
"module_path": module_path,
"test_module_path": test_module_path,
"test_framework": test_framework,
"test_timeout": test_timeout,
"trace_id": trace_id,
"language": language,
"codeflash_version": codeflash_version,
"call_sequence": self.get_next_sequence(),
}
if module_source_code:
payload["module_source_code"] = module_source_code
if coverage_details:
payload["coverage_details"] = coverage_details
if previous_repair_errors:
payload["previous_repair_errors"] = previous_repair_errors
self.add_language_metadata(payload)
try:
response = self.make_ai_service_request("/testgen_repair", payload=payload, timeout=self.timeout)
except requests.exceptions.RequestException as e:
logger.exception(f"Error repairing generated tests: {e}")
ph("cli-testgen-repair-error-caught", {"error": str(e)})
return None
if response.status_code == 200:
data = response.json()
return (data["generated_tests"], data["instrumented_behavior_tests"], data["instrumented_perf_tests"])
self.log_error_response(response, "repairing generated tests", "cli-testgen-repair-error-response")
return None
def get_optimization_review(
self,
original_code: dict[Path, str],
@ -875,12 +934,7 @@ class AiServiceClient:
return OptimizationReviewResult(
review=cast("str", data["review"]), explanation=cast("str", data.get("review_explanation", ""))
)
try:
error = cast("str", response.json()["error"])
except Exception:
error = response.text
logger.error(f"Error generating optimization review: {response.status_code} - {error}")
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
self.log_error_response(response, "generating optimization review", "cli-optimize-error-response")
console.rule()
return OptimizationReviewResult(review="", explanation="")

View file

@ -2,13 +2,11 @@ import logging
import os
import sys
from argparse import SUPPRESS, ArgumentParser, Namespace
from functools import lru_cache
from pathlib import Path
from codeflash.cli_cmds import logging_config
from codeflash.cli_cmds.cli_common import apologize_and_exit
from codeflash.cli_cmds.cmd_init import init_codeflash, install_github_actions
from codeflash.cli_cmds.console import logger
from codeflash.cli_cmds.extension import install_vscode_extension
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_utils import exit_with_message, normalize_ignore_paths
from codeflash.code_utils.config_parser import parse_config_file
@ -18,130 +16,14 @@ from codeflash.version import __version__ as version
def parse_args() -> Namespace:
parser = ArgumentParser()
subparsers = parser.add_subparsers(dest="command", help="Sub-commands")
init_parser = subparsers.add_parser("init", help="Initialize Codeflash for your project.")
init_parser.set_defaults(func=init_codeflash)
subparsers.add_parser("vscode-install", help="Install the Codeflash VSCode extension")
init_actions_parser = subparsers.add_parser("init-actions", help="Initialize GitHub Actions workflow")
init_actions_parser.set_defaults(func=install_github_actions)
trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize your project.")
from codeflash.tracer import main as tracer_main
trace_optimize.set_defaults(func=tracer_main)
trace_optimize.add_argument(
"--max-function-count",
type=int,
default=100,
help="The maximum number of times to trace a single function. More calls to a function will not be traced. Default is 100.",
)
trace_optimize.add_argument(
"--timeout",
type=int,
help="The maximum time in seconds to trace the entire workflow. Default is indefinite. This is useful while tracing really long workflows, to not wait indefinitely.",
)
trace_optimize.add_argument(
"--output",
type=str,
default="codeflash.trace",
help="The file to save the trace to. Default is codeflash.trace.",
)
trace_optimize.add_argument(
"--config-file-path",
type=str,
help="The path to the pyproject.toml file which stores the Codeflash config. This is auto-discovered by default.",
)
parser.add_argument("--file", help="Try to optimize only this file")
parser.add_argument("--function", help="Try to optimize only this function within the given file path")
parser.add_argument(
"--all",
help="Try to optimize all functions. Can take a really long time. Can pass an optional starting directory to"
" optimize code from. If no args specified (just --all), will optimize all code in the project.",
nargs="?",
const="",
default=SUPPRESS,
)
parser.add_argument(
"--module-root",
type=str,
help="Path to the project's module that you want to optimize."
" This is the top-level root directory where all the source code is located.",
)
parser.add_argument(
"--tests-root", type=str, help="Path to the test directory of the project, where all the tests are located."
)
parser.add_argument("--config-file", type=str, help="Path to the pyproject.toml with codeflash configs.")
parser.add_argument("--replay-test", type=str, nargs="+", help="Paths to replay test to optimize functions from")
parser.add_argument(
"--no-pr", action="store_true", help="Do not create a PR for the optimization, only update the code locally."
)
parser.add_argument(
"--no-gen-tests", action="store_true", help="Do not generate tests, use only existing tests for optimization."
)
parser.add_argument(
"--no-jit-opts", action="store_true", help="Do not generate JIT-compiled optimizations for numerical code."
)
parser.add_argument("--staging-review", action="store_true", help="Upload optimizations to staging for review")
parser.add_argument(
"--verify-setup",
action="store_true",
help="Verify that codeflash is set up correctly by optimizing bubble sort as a test.",
)
parser.add_argument("-v", "--verbose", action="store_true", help="Print verbose debug logs")
parser.add_argument("--version", action="store_true", help="Print the version of codeflash")
parser.add_argument(
"--benchmark", action="store_true", help="Trace benchmark tests and calculate optimization impact on benchmarks"
)
parser.add_argument(
"--benchmarks-root",
type=str,
help="Path to the directory of the project, where all the pytest-benchmark tests are located.",
)
parser.add_argument("--no-draft", default=False, action="store_true", help="Skip optimization for draft PRs")
parser.add_argument("--worktree", default=False, action="store_true", help="Use worktree for optimization")
parser.add_argument(
"--async",
default=False,
action="store_true",
help="(Deprecated) Async function optimization is now enabled by default. This flag is ignored.",
)
parser.add_argument(
"--server",
type=str,
choices=["local", "prod"],
help="AI service server to use: 'local' for localhost:8000, 'prod' for app.codeflash.ai",
)
parser.add_argument(
"--effort", type=str, help="Effort level for optimization", choices=["low", "medium", "high"], default="medium"
)
# Config management flags
parser.add_argument(
"--show-config", action="store_true", help="Show current or auto-detected configuration and exit."
)
parser.add_argument(
"--reset-config", action="store_true", help="Remove codeflash configuration from project config file."
)
parser.add_argument("-y", "--yes", action="store_true", help="Skip confirmation prompts (useful for CI/scripts).")
parser.add_argument(
"--subagent",
action="store_true",
help="Subagent mode: skip all interactive prompts with sensible defaults. Designed for AI agent integrations.",
)
parser = _build_parser()
args, unknown_args = parser.parse_known_args()
sys.argv[:] = [sys.argv[0], *unknown_args]
if args.subagent:
args.yes = True
args.no_pr = True
args.worktree = True
args.effort = "low"
return process_and_validate_cmd_args(args)
@ -176,10 +58,6 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:
_handle_reset_config(confirm=not getattr(args, "yes", False))
sys.exit()
if args.command == "vscode-install":
install_vscode_extension()
sys.exit()
if not check_running_in_git_repo(module_root=args.module_root):
if not confirm_proceeding_with_no_git_repo():
exit_with_message("No git repository detected and user aborted run. Exiting...", error_on_exit=True)
@ -201,12 +79,6 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:
if env_utils.is_ci():
args.no_pr = True
if getattr(args, "async", False):
logger.warning(
"The --async flag is deprecated and will be removed in a future version. "
"Async function optimization is now enabled by default."
)
return args
@ -362,6 +234,8 @@ def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
f"I couldn't find a git repository in the current directory. "
f"I need a git repository to run {mode} and open PRs for optimizations. Exiting..."
)
from codeflash.cli_cmds.cli_common import apologize_and_exit
apologize_and_exit()
git_remote = getattr(args, "git_remote", None)
if not check_and_push_branch(git_repo, git_remote=git_remote):
@ -487,3 +361,124 @@ def _handle_reset_config(confirm: bool = True) -> None:
console.print(f"[green]✓[/green] {escaped_message}")
else:
console.print(f"[red]✗[/red] {escaped_message}")
@lru_cache(maxsize=1)
def _build_parser() -> ArgumentParser:
parser = ArgumentParser()
subparsers = parser.add_subparsers(dest="command", help="Sub-commands")
subparsers.add_parser("init", help="Initialize Codeflash for your project.")
subparsers.add_parser("vscode-install", help="Install the Codeflash VSCode extension")
subparsers.add_parser("init-actions", help="Initialize GitHub Actions workflow")
trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize your project.")
trace_optimize.add_argument(
"--max-function-count",
type=int,
default=100,
help="The maximum number of times to trace a single function. More calls to a function will not be traced. Default is 100.",
)
trace_optimize.add_argument(
"--timeout",
type=int,
help="The maximum time in seconds to trace the entire workflow. Default is indefinite. This is useful while tracing really long workflows, to not wait indefinitely.",
)
trace_optimize.add_argument(
"--output",
type=str,
default="codeflash.trace",
help="The file to save the trace to. Default is codeflash.trace.",
)
trace_optimize.add_argument(
"--config-file-path",
type=str,
help="The path to the pyproject.toml file which stores the Codeflash config. This is auto-discovered by default.",
)
parser.add_argument("--file", help="Try to optimize only this file")
parser.add_argument("--function", help="Try to optimize only this function within the given file path")
parser.add_argument(
"--all",
help="Try to optimize all functions. Can take a really long time. Can pass an optional starting directory to"
" optimize code from. If no args specified (just --all), will optimize all code in the project.",
nargs="?",
const="",
default=SUPPRESS,
)
parser.add_argument(
"--module-root",
type=str,
help="Path to the project's module that you want to optimize."
" This is the top-level root directory where all the source code is located.",
)
parser.add_argument(
"--tests-root", type=str, help="Path to the test directory of the project, where all the tests are located."
)
parser.add_argument("--config-file", type=str, help="Path to the pyproject.toml with codeflash configs.")
parser.add_argument("--replay-test", type=str, nargs="+", help="Paths to replay test to optimize functions from")
parser.add_argument(
"--no-pr", action="store_true", help="Do not create a PR for the optimization, only update the code locally."
)
parser.add_argument(
"--no-gen-tests", action="store_true", help="Do not generate tests, use only existing tests for optimization."
)
parser.add_argument(
"--no-jit-opts", action="store_true", help="Do not generate JIT-compiled optimizations for numerical code."
)
parser.add_argument("--staging-review", action="store_true", help="Upload optimizations to staging for review")
parser.add_argument(
"--verify-setup",
action="store_true",
help="Verify that codeflash is set up correctly by optimizing bubble sort as a test.",
)
parser.add_argument("-v", "--verbose", action="store_true", help="Print verbose debug logs")
parser.add_argument("--version", action="store_true", help="Print the version of codeflash")
parser.add_argument(
"--benchmark", action="store_true", help="Trace benchmark tests and calculate optimization impact on benchmarks"
)
parser.add_argument(
"--benchmarks-root",
type=str,
help="Path to the directory of the project, where all the pytest-benchmark tests are located.",
)
parser.add_argument("--no-draft", default=False, action="store_true", help="Skip optimization for draft PRs")
parser.add_argument("--worktree", default=False, action="store_true", help="Use worktree for optimization")
parser.add_argument(
"--testgen-review", default=False, action="store_true", help="Enable AI review and repair of generated tests"
)
parser.add_argument(
"--testgen-review-turns", type=int, default=None, help="Number of review/repair cycles (default: 2)"
)
parser.add_argument(
"--async",
default=False,
action="store_true",
help="(Deprecated) Async function optimization is now enabled by default. This flag is ignored.",
)
parser.add_argument(
"--server",
type=str,
choices=["local", "prod"],
help="AI service server to use: 'local' for localhost:8000, 'prod' for app.codeflash.ai",
)
parser.add_argument(
"--effort", type=str, help="Effort level for optimization", choices=["low", "medium", "high"], default="medium"
)
# Config management flags
parser.add_argument(
"--show-config", action="store_true", help="Show current or auto-detected configuration and exit."
)
parser.add_argument(
"--reset-config", action="store_true", help="Remove codeflash configuration from project config file."
)
parser.add_argument("-y", "--yes", action="store_true", help="Skip confirmation prompts (useful for CI/scripts).")
parser.add_argument(
"--subagent",
action="store_true",
help="Subagent mode: skip all interactive prompts with sensible defaults. Designed for AI agent integrations.",
)
return parser
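Extracting the parser construction into `_build_parser` behind `@lru_cache(maxsize=1)` means repeated calls reuse one `ArgumentParser` instead of re-registering every flag. The pattern in isolation:

```python
from argparse import ArgumentParser
from functools import lru_cache

@lru_cache(maxsize=1)
def build_parser() -> ArgumentParser:
    # Built once; every later call returns the same cached instance.
    parser = ArgumentParser()
    parser.add_argument("--verbose", action="store_true")
    return parser
```

Since the cached parser is a shared mutable object, callers must not add arguments to it after the first call; all registration belongs inside the cached builder.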

View file

@ -2,7 +2,7 @@ from __future__ import annotations
import shutil
import sys
from typing import Callable, cast
from typing import Callable, NoReturn, cast
import click
import inquirer
@ -10,7 +10,7 @@ import inquirer
from codeflash.cli_cmds.console import console, logger
def apologize_and_exit() -> None:
def apologize_and_exit() -> NoReturn:
console.rule()
logger.info(
"💡 If you're having trouble, see https://docs.codeflash.ai/getting-started/local-installation for further help getting started with Codeflash!"
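Changing the return annotation to `NoReturn` lets type checkers treat any code after a call to `apologize_and_exit` as unreachable, improving narrowing at call sites. A small self-contained illustration (the demo function is hypothetical, not the real implementation):

```python
from __future__ import annotations

import sys
from typing import NoReturn

def apologize_and_exit_demo() -> NoReturn:
    # Hypothetical stand-in: the real function also prints documentation links.
    print("Sorry! Exiting.")
    sys.exit(1)

def guarded(value: int | None) -> int:
    if value is None:
        apologize_and_exit_demo()
    # Because apologize_and_exit_demo never returns, type checkers
    # narrow `value` to int on this line.
    return value
```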

File diff suppressed because it is too large

View file

@ -1,5 +1,6 @@
from __future__ import annotations
import contextlib
import logging
from collections import deque
from contextlib import contextmanager
@ -418,6 +419,7 @@ def subagent_log_optimization_result(
new_code: dict[Path, str],
review: str,
test_results: TestResults,
project_root: Path | None = None,
) -> None:
import sys
from xml.sax.saxutils import escape
@ -431,7 +433,11 @@ def subagent_log_optimization_result(
old = original_code.get(path, "")
new = new_code.get(path, "")
if old != new:
diff = unified_diff_strings(old, new, fromfile=str(path), tofile=str(path))
display_path = str(path)
if project_root is not None:
with contextlib.suppress(ValueError):
display_path = str(path.relative_to(project_root))
diff = unified_diff_strings(old, new, fromfile=display_path, tofile=display_path)
if diff:
diff_parts.append(diff)
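The new `project_root` parameter exists so diff headers show repo-relative paths, with `contextlib.suppress(ValueError)` covering files outside the root (where `Path.relative_to` raises). The same logic in isolation:

```python
from __future__ import annotations

import contextlib
from pathlib import Path

def display_path(path: Path, project_root: Path | None) -> str:
    # Show paths relative to the project root when possible; otherwise keep them as-is.
    result = str(path)
    if project_root is not None:
        with contextlib.suppress(ValueError):
            # relative_to raises ValueError when path is outside project_root.
            result = str(path.relative_to(project_root))
    return result
```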
@ -478,18 +484,17 @@ def subagent_log_optimization_result(
if new_code[path] != original_code.get(path, ""):
xml.append(f' <optimized-code file="{escape(str(path))}">{escape(new_code[path])}</optimized-code>')
xml.append(" <action>")
xml.append(" First, present a short summary: function name, performance improvement, and explanation.")
xml.append(" If you are running inside an IDE (VS Code, Cursor, Windsurf, etc.):")
xml.append(" For each optimized-code element, use the Write tool to write the optimized code to the file.")
xml.append(" This will surface the diff in the IDE's native diff view for the user to accept or reject.")
xml.append(" If you are running in a terminal without an IDE connection:")
xml.append(" Present this optimization to the user using AskUserQuestion.")
xml.append(" Show the function name, performance improvement, and explanation as the question text.")
xml.append(" Provide two options: 'Apply' (write the optimized code to the file) and 'Reject' (do nothing).")
xml.append(" 1. Review the diff and optimized code yourself. Write a brief assessment (2-3 sentences) covering:")
xml.append(" - Whether the optimization is correct and preserves behavior")
xml.append(" - What technique was used (e.g. algorithmic improvement, caching, vectorization)")
xml.append(" - Any risks or trade-offs (e.g. increased memory usage, reduced readability)")
xml.append(" 2. Present this optimization to the user using AskUserQuestion.")
xml.append(" Show the function name, performance improvement, and your review as the question text.")
xml.append(" Provide two options: 'Apply' (write the optimized code to the file) and 'Reject' (do nothing).")
xml.append(
" Put the full diff in the 'Apply' option's markdown preview so the user can review the exact changes."
" Put the full diff in the 'Apply' option's markdown preview so the user can review the exact changes."
)
xml.append(" If the user chooses 'Apply', write the content from optimized-code to the corresponding file.")
xml.append(" 3. If the user chooses 'Apply', write the content from optimized-code to the corresponding file.")
xml.append(" </action>")
xml.append("</codeflash-optimization>")

View file

@ -1,3 +1,18 @@
from codeflash.code_utils.compat import LF
from codeflash.version import __version__ as version
CODEFLASH_LOGO: str = (
f"{LF}"
r" _ ___ _ _ " + f"{LF}"
r" | | / __)| | | | " + f"{LF}"
r" ____ ___ _ | | ____ | |__ | | ____ ___ | | _ " + f"{LF}"
r" / ___) / _ \ / || | / _ )| __)| | / _ | /___)| || \ " + f"{LF}"
r"( (___ | |_| |( (_| |( (/ / | | | |( ( | ||___ || | | |" + f"{LF}"
r" \____) \___/ \____| \____)|_| |_| \_||_|(___/ |_| |_|" + f"{LF}"
f"{('v' + version).rjust(66)}{LF}"
f"{LF}"
)
SPINNER_TYPES = {
"point",
"simpleDots",

View file

@ -0,0 +1,951 @@
from __future__ import annotations
import sys
from enum import Enum, auto
from pathlib import Path
from typing import Any
import click
import git
import inquirer
import tomlkit
from git import Repo
from rich.panel import Panel
from rich.text import Text
from codeflash.api.aiservice import AiServiceClient
from codeflash.api.cfapi import setup_github_actions
from codeflash.cli_cmds.cli_common import apologize_and_exit
from codeflash.cli_cmds.console import console, logger
from codeflash.cli_cmds.init_config import CodeflashTheme
from codeflash.code_utils.compat import LF
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.code_utils.env_utils import get_codeflash_api_key
from codeflash.code_utils.git_utils import get_current_branch, get_repo_owner_and_name
from codeflash.code_utils.github_utils import get_github_secrets_page_url
from codeflash.telemetry.posthog_cf import ph
class DependencyManager(Enum):
"""Python dependency managers."""
PIP = auto()
POETRY = auto()
UV = auto()
UNKNOWN = auto()
def install_github_actions(override_formatter_check: bool = False) -> None:
try:
config, _config_file_path = parse_config_file(override_formatter_check=override_formatter_check)
ph("cli-github-actions-install-started")
try:
repo = Repo(config["module_root"], search_parent_directories=True)
except git.InvalidGitRepositoryError:
click.echo(
"Skipping GitHub action installation for continuous optimization because you're not in a git repository."
)
return
git_root = Path(repo.git.rev_parse("--show-toplevel"))
workflows_path = git_root / ".github" / "workflows"
optimize_yaml_path = workflows_path / "codeflash.yaml"
# Check if workflow file already exists locally BEFORE showing prompt
if optimize_yaml_path.exists():
# Workflow file already exists locally - skip prompt and setup
already_exists_message = "✅ GitHub Actions workflow file already exists.\n\n"
already_exists_message += "No changes needed - your repository is already configured!"
already_exists_panel = Panel(
Text(already_exists_message, style="green", justify="center"),
title="✅ Already Configured",
border_style="bright_green",
)
console.print(already_exists_panel)
console.print()
logger.info(
"[github_workflow.py:install_github_actions] Workflow file already exists locally, skipping setup"
)
return
# Get repository information for API call
git_remote = config.get("git_remote", "origin")
# get_current_branch handles detached HEAD and other edge cases internally
try:
base_branch = get_current_branch(repo)
except Exception as e:
logger.warning(
f"[github_workflow.py:install_github_actions] Could not determine current branch: {e}. Falling back to 'main'."
)
base_branch = "main"
# Generate workflow content
from importlib.resources import files
benchmark_mode = False
benchmarks_root = config.get("benchmarks_root", "").strip()
if benchmarks_root and benchmarks_root != "":
benchmark_panel = Panel(
Text(
"📊 Benchmark Mode Available\n\n"
"I noticed you've configured a benchmarks_root in your config. "
"Benchmark mode will show the performance impact of Codeflash's optimizations on your benchmarks.",
style="cyan",
),
title="📊 Benchmark Mode",
border_style="bright_cyan",
)
console.print(benchmark_panel)
console.print()
benchmark_questions = [
inquirer.Confirm("benchmark_mode", message="Run GitHub Actions in benchmark mode?", default=True)
]
benchmark_answers = inquirer.prompt(benchmark_questions, theme=CodeflashTheme())
benchmark_mode = benchmark_answers["benchmark_mode"] if benchmark_answers else False
# Show prompt only if workflow doesn't exist locally
actions_panel = Panel(
Text(
"🤖 GitHub Actions Setup\n\n"
"GitHub Actions will automatically optimize your code in every pull request. "
"This is the recommended way to use Codeflash for continuous optimization.",
style="blue",
),
title="🤖 Continuous Optimization",
border_style="bright_blue",
)
console.print(actions_panel)
console.print()
creation_questions = [
inquirer.Confirm(
"confirm_creation",
message="Set up GitHub Actions for continuous optimization? We'll open a pull request with the workflow file.",
default=True,
)
]
creation_answers = inquirer.prompt(creation_questions, theme=CodeflashTheme())
if not creation_answers or not creation_answers["confirm_creation"]:
skip_panel = Panel(
Text("⏩️ Skipping GitHub Actions setup.", style="yellow"), title="⏩️ Skipped", border_style="yellow"
)
console.print(skip_panel)
ph("cli-github-workflow-skipped")
return
ph(
"cli-github-optimization-confirm-workflow-creation",
{"confirm_creation": creation_answers["confirm_creation"]},
)
# Generate workflow content AFTER user confirmation
logger.info("[github_workflow.py:install_github_actions] User confirmed, generating workflow content...")
# Select the appropriate workflow template based on project language
project_language = detect_project_language_for_workflow(Path.cwd())
if project_language == "java":
workflow_template = "codeflash-optimize-java.yaml"
elif project_language in ("javascript", "typescript"):
workflow_template = "codeflash-optimize-js.yaml"
else:
workflow_template = "codeflash-optimize.yaml"
optimize_yml_content = (
files("codeflash").joinpath("cli_cmds", "workflows", workflow_template).read_text(encoding="utf-8")
)
materialized_optimize_yml_content = generate_dynamic_workflow_content(
optimize_yml_content, config, git_root, benchmark_mode
)
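The template selection above can be read as a pure mapping from detected language to workflow file; a sketch (the real `detect_project_language_for_workflow` inspects the project tree and is not shown in this hunk):

```python
def workflow_template_for(project_language: str) -> str:
    # Mirrors the branch above: Java and JS/TS projects get dedicated
    # workflow templates; everything else falls back to the Python one.
    if project_language == "java":
        return "codeflash-optimize-java.yaml"
    if project_language in ("javascript", "typescript"):
        return "codeflash-optimize-js.yaml"
    return "codeflash-optimize.yaml"
```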
workflows_path.mkdir(parents=True, exist_ok=True)
pr_created_via_api = False
pr_url = None
try:
owner, repo_name = get_repo_owner_and_name(repo, git_remote)
except Exception as e:
logger.error(f"[github_workflow.py:install_github_actions] Failed to get repository owner and name: {e}")
# Fall back to local file creation
workflows_path.mkdir(parents=True, exist_ok=True)
with optimize_yaml_path.open("w", encoding="utf8") as optimize_yml_file:
optimize_yml_file.write(materialized_optimize_yml_content)
workflow_success_panel = Panel(
Text(
f"✅ Created GitHub action workflow at {optimize_yaml_path}\n\n"
"Your repository is now configured for continuous optimization!",
style="green",
justify="center",
),
title="🎉 Workflow Created!",
border_style="bright_green",
)
console.print(workflow_success_panel)
console.print()
else:
# Try to create PR via API
try:
# Workflow file doesn't exist on remote or content differs - proceed with PR creation
console.print("Creating PR with GitHub Actions workflow...")
logger.info(
f"[github_workflow.py:install_github_actions] Calling setup_github_actions API for {owner}/{repo_name} on branch {base_branch}"
)
response = setup_github_actions(
owner=owner,
repo=repo_name,
base_branch=base_branch,
workflow_content=materialized_optimize_yml_content,
)
if response.status_code == 200:
response_data = response.json()
if response_data.get("success"):
pr_url = response_data.get("pr_url")
if pr_url:
pr_created_via_api = True
success_message = f"✅ PR created: {pr_url}\n\n"
success_message += "Your repository is now configured for continuous optimization!"
workflow_success_panel = Panel(
Text(success_message, style="green", justify="center"),
title="🎉 Workflow PR Created!",
border_style="bright_green",
)
console.print(workflow_success_panel)
console.print()
logger.info(
f"[github_workflow.py:install_github_actions] Successfully created PR #{response_data.get('pr_number')} for {owner}/{repo_name}"
)
else:
# File already exists with same content
pr_created_via_api = True # Mark as handled (no PR needed)
already_exists_message = "✅ Workflow file already exists with the same content.\n\n"
already_exists_message += "No changes needed - your repository is already configured!"
already_exists_panel = Panel(
Text(already_exists_message, style="green", justify="center"),
title="✅ Already Configured",
border_style="bright_green",
)
console.print(already_exists_panel)
console.print()
else:
# API returned success=false, extract error details
error_data = response_data
error_msg = error_data.get("error", "Unknown error")
error_message = error_data.get("message", error_msg)
error_help = error_data.get("help", "")
installation_url = error_data.get("installation_url")
# For permission errors, don't fall back - show a focused message and abort early
if response.status_code == 403:
logger.error(
f"[github_workflow.py:install_github_actions] Permission denied for {owner}/{repo_name}"
)
# Extract installation_url if available, otherwise use default
installation_url_403 = error_data.get(
"installation_url", "https://github.com/apps/codeflash-ai/installations/select_target"
)
permission_error_panel = Panel(
Text(
"❌ Access Denied\n\n"
f"The GitHub App may not be installed on {owner}/{repo_name}, or it doesn't have the required permissions.\n\n"
"💡 To fix this:\n"
"1. Install the CodeFlash GitHub App on your repository\n"
"2. Ensure the app has 'Contents: write', 'Workflows: write', and 'Pull requests: write' permissions\n"
"3. Make sure you have write access to the repository\n\n"
f"🔗 Install GitHub App: {installation_url_403}",
style="red",
),
title="❌ Setup Failed",
border_style="red",
)
console.print(permission_error_panel)
console.print()
click.echo(
f"Please install the CodeFlash GitHub App and ensure it has the required permissions.{LF}"
f"Visit: {installation_url_403}{LF}"
)
apologize_and_exit()
# Show detailed error panel for all other errors
error_panel_text = f"{error_msg}\n\n{error_message}\n"
if error_help:
error_panel_text += f"\n💡 {error_help}\n"
if installation_url:
error_panel_text += f"\n🔗 Install GitHub App: {installation_url}"
error_panel = Panel(
Text(error_panel_text, style="red"), title="❌ Setup Failed", border_style="red"
)
console.print(error_panel)
console.print()
# For GitHub App not installed, don't fall back - show clear instructions
if response.status_code == 404 and installation_url:
logger.error(
f"[github_workflow.py:install_github_actions] GitHub App not installed on {owner}/{repo_name}"
)
click.echo(
f"Please install the CodeFlash GitHub App on your repository to continue.{LF}"
f"Visit: {installation_url}{LF}"
)
return
# For other errors, fall back to local file creation
raise Exception(error_message) # noqa: TRY002, TRY301
else:
# API call returned non-200 status, try to parse error response
try:
error_data = response.json()
error_msg = error_data.get("error", "API request failed")
error_message = error_data.get("message", f"API returned status {response.status_code}")
error_help = error_data.get("help", "")
installation_url = error_data.get("installation_url")
# For permission errors, don't fall back - show a focused message and abort early
if response.status_code == 403:
logger.error(
f"[github_workflow.py:install_github_actions] Permission denied for {owner}/{repo_name}"
)
# Extract installation_url if available, otherwise use default
installation_url_403 = error_data.get(
"installation_url", "https://github.com/apps/codeflash-ai/installations/select_target"
)
permission_error_panel = Panel(
Text(
"❌ Access Denied\n\n"
f"The GitHub App may not be installed on {owner}/{repo_name}, or it doesn't have the required permissions.\n\n"
"💡 To fix this:\n"
"1. Install the CodeFlash GitHub App on your repository\n"
"2. Ensure the app has 'Contents: write', 'Workflows: write', and 'Pull requests: write' permissions\n"
"3. Make sure you have write access to the repository\n\n"
f"🔗 Install GitHub App: {installation_url_403}",
style="red",
),
title="❌ Setup Failed",
border_style="red",
)
console.print(permission_error_panel)
console.print()
click.echo(
f"Please install the CodeFlash GitHub App and ensure it has the required permissions.{LF}"
f"Visit: {installation_url_403}{LF}"
)
apologize_and_exit()
# Show detailed error panel for all other errors
error_panel_text = f"{error_msg}\n\n{error_message}\n"
if error_help:
error_panel_text += f"\n💡 {error_help}\n"
if installation_url:
error_panel_text += f"\n🔗 Install GitHub App: {installation_url}"
error_panel = Panel(
Text(error_panel_text, style="red"), title="❌ Setup Failed", border_style="red"
)
console.print(error_panel)
console.print()
# For GitHub App not installed, don't fall back - show clear instructions
if response.status_code == 404 and installation_url:
logger.error(
f"[github_workflow.py:install_github_actions] GitHub App not installed on {owner}/{repo_name}"
)
click.echo(
f"Please install the CodeFlash GitHub App on your repository to continue.{LF}"
f"Visit: {installation_url}{LF}"
)
return
# For authentication errors, don't fall back
if response.status_code == 401:
logger.error(
f"[github_workflow.py:install_github_actions] Authentication failed for {owner}/{repo_name}"
)
click.echo(f"Authentication failed. Please check your API key and try again.{LF}")
return
# For other errors, fall back to local file creation
raise Exception(error_message) # noqa: TRY002
except (ValueError, KeyError) as parse_error:
# Couldn't parse error response, use generic message
status_msg = f"API returned status {response.status_code}"
raise Exception(status_msg) from parse_error # noqa: TRY002
except Exception as api_error:
# Fall back to local file creation if API call fails (for non-critical errors)
logger.warning(
f"[github_workflow.py:install_github_actions] API call failed, falling back to local file creation: {api_error}"
)
workflows_path.mkdir(parents=True, exist_ok=True)
with optimize_yaml_path.open("w", encoding="utf8") as optimize_yml_file:
optimize_yml_file.write(materialized_optimize_yml_content)
workflow_success_panel = Panel(
Text(
f"✅ Created GitHub action workflow at {optimize_yaml_path}\n\n"
"Your repository is now configured for continuous optimization!",
style="green",
justify="center",
),
title="🎉 Workflow Created!",
border_style="bright_green",
)
console.print(workflow_success_panel)
console.print()
# Show appropriate message based on whether PR was created via API
if pr_created_via_api:
if pr_url:
click.echo(
f"🚀 Codeflash is now configured to automatically optimize new GitHub PRs!{LF}"
f"Once you merge the PR, the workflow will be active.{LF}"
)
else:
# File already exists
click.echo(
f"🚀 Codeflash is now configured to automatically optimize new GitHub PRs!{LF}"
f"The workflow is ready to use.{LF}"
)
else:
# Fell back to local file creation
click.echo(
f"Please edit, commit and push this GitHub actions file to your repo, and you're all set!{LF}"
f"🚀 Codeflash is now configured to automatically optimize new GitHub PRs!{LF}"
)
# Show GitHub secrets setup panel (needed in both cases - PR created via API or local file)
try:
existing_api_key = get_codeflash_api_key()
except OSError:
existing_api_key = None
# GitHub secrets setup panel - always shown since secrets are required for the workflow to work
secrets_message = (
"🔐 Next Step: Add API Key as GitHub Secret\n\n"
"You'll need to add your CODEFLASH_API_KEY as a secret to your GitHub repository.\n\n"
"📋 Steps:\n"
"1. Press Enter to open your repo's secrets page\n"
"2. Click 'New repository secret'\n"
"3. Add your API key with the variable name CODEFLASH_API_KEY"
)
if existing_api_key:
secrets_message += f"\n\n🔑 Your API Key: {existing_api_key}"
secrets_panel = Panel(
Text(secrets_message, style="blue"), title="🔐 GitHub Secrets Setup", border_style="bright_blue"
)
console.print(secrets_panel)
console.print(f"\n📍 Press Enter to open: {get_github_secrets_page_url(repo)}")
console.input()
click.launch(get_github_secrets_page_url(repo))
# Post-launch message panel
launch_panel = Panel(
Text(
"🐙 I opened your GitHub secrets page!\n\n"
"Note: If you see a 404, you probably don't have access to this repo's secrets. "
"Ask a repo admin to add it for you, or (not recommended) you can temporarily "
"hard-code your API key into the workflow file.",
style="cyan",
),
title="🌐 Browser Opened",
border_style="bright_cyan",
)
console.print(launch_panel)
click.pause()
console.print()
ph("cli-github-workflow-created")
except KeyboardInterrupt:
apologize_and_exit()
def determine_dependency_manager(pyproject_data: dict[str, Any]) -> DependencyManager:
"""Determine which dependency manager is being used based on pyproject.toml contents."""
cwd = Path.cwd()
if (cwd / "poetry.lock").exists():
return DependencyManager.POETRY
if (cwd / "uv.lock").exists():
return DependencyManager.UV
if "tool" not in pyproject_data:
return DependencyManager.PIP
tool_section = pyproject_data["tool"]
# Check for poetry
if "poetry" in tool_section:
return DependencyManager.POETRY
# Check for uv
if any(key.startswith("uv") for key in tool_section):
return DependencyManager.UV
# Look for pip-specific markers
if "pip" in tool_section or "setuptools" in tool_section:
return DependencyManager.PIP
return DependencyManager.UNKNOWN
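The detection order above matters: a lock file on disk wins over anything declared in `pyproject.toml`, and `poetry.lock` is checked before `uv.lock`. A minimal, self-contained sketch of that priority (plain strings instead of the `DependencyManager` enum; the function name is illustrative):

```python
from pathlib import Path
from typing import Any

def sketch_detect_dep_manager(project_dir: Path, pyproject_data: dict[str, Any]) -> str:
    # Lock files on disk take priority over pyproject.toml contents.
    if (project_dir / "poetry.lock").exists():
        return "poetry"
    if (project_dir / "uv.lock").exists():
        return "uv"
    if "tool" not in pyproject_data:
        return "pip"
    tool = pyproject_data["tool"]
    if "poetry" in tool:
        return "poetry"
    if any(key.startswith("uv") for key in tool):
        return "uv"
    if "pip" in tool or "setuptools" in tool:
        return "pip"
    return "unknown"
```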
def get_codeflash_github_action_command(dep_manager: DependencyManager) -> str:
"""Generate the appropriate codeflash command based on the dependency manager."""
if dep_manager == DependencyManager.POETRY:
return """|
poetry env use python
poetry run codeflash"""
if dep_manager == DependencyManager.UV:
return "uv run codeflash"
# PIP or UNKNOWN
return "codeflash"
def get_dependency_installation_commands(dep_manager: DependencyManager) -> str:
"""Generate commands to install the dependency manager and project dependencies."""
if dep_manager == DependencyManager.POETRY:
return """|
python -m pip install --upgrade pip
pip install poetry
poetry install --all-extras"""
if dep_manager == DependencyManager.UV:
return """|
uv sync --all-extras
uv pip install --upgrade codeflash"""
# PIP or UNKNOWN
return """|
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install codeflash"""
def get_dependency_manager_installation_string(dep_manager: DependencyManager) -> str:
py_version = sys.version_info
python_version_string = f"'{py_version.major}.{py_version.minor}'"
if dep_manager == DependencyManager.UV:
return """name: 🐍 Setup UV
uses: astral-sh/setup-uv@v6
with:
enable-cache: true"""
return f"""name: 🐍 Set up Python
uses: actions/setup-python@v5
with:
python-version: {python_version_string}"""
def get_github_action_working_directory(toml_path: Path, git_root: Path) -> str:
if toml_path.parent == git_root:
return ""
working_dir = str(toml_path.parent.relative_to(git_root))
return f"""defaults:
run:
working-directory: ./{working_dir}"""
def detect_project_language_for_workflow(project_root: Path) -> str:
"""Detect the primary language of the project for workflow generation.
Returns: 'python', 'javascript', 'typescript', or 'java'
"""
# Check for Java build tools first (pom.xml or build.gradle)
if (
(project_root / "pom.xml").exists()
or (project_root / "build.gradle").exists()
or (project_root / "build.gradle.kts").exists()
):
return "java"
# Check for TypeScript config
if (project_root / "tsconfig.json").exists():
return "typescript"
# Check for JavaScript/TypeScript indicators
has_package_json = (project_root / "package.json").exists()
has_pyproject = (project_root / "pyproject.toml").exists()
if has_package_json and not has_pyproject:
# Pure JS/TS project
return "javascript"
if has_pyproject and not has_package_json:
# Pure Python project
return "python"
# Both exist - count files to determine primary language
js_count = 0
py_count = 0
for file in project_root.rglob("*"):
if file.is_file():
suffix = file.suffix.lower()
if suffix in {".js", ".jsx", ".ts", ".tsx", ".mjs", ".cjs"}:
js_count += 1
elif suffix == ".py":
py_count += 1
if js_count > py_count:
return "javascript"
return "python"
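When both `package.json` and `pyproject.toml` exist, the tie-break above is a raw file count by extension, with Python winning ties. A sketch of just that heuristic in isolation:

```python
from pathlib import Path

JS_SUFFIXES = {".js", ".jsx", ".ts", ".tsx", ".mjs", ".cjs"}

def sketch_tie_break(project_root: Path) -> str:
    # Count source files by suffix; JS wins only on a strict majority.
    js_count = py_count = 0
    for f in project_root.rglob("*"):
        if f.is_file():
            if f.suffix.lower() in JS_SUFFIXES:
                js_count += 1
            elif f.suffix.lower() == ".py":
                py_count += 1
    return "javascript" if js_count > py_count else "python"
```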
def collect_repo_files_for_workflow(git_root: Path) -> dict[str, Any]:
"""Collect important repository files and directory structure for workflow generation.
:param git_root: Root directory of the git repository
:return: Dictionary with 'files' (path -> content) and 'directory_structure' (nested dict)
"""
# Important files to collect with contents
important_files = [
"pyproject.toml",
"requirements.txt",
"requirements-dev.txt",
"requirements/requirements.txt",
"requirements/dev.txt",
"Pipfile",
"Pipfile.lock",
"poetry.lock",
"uv.lock",
"setup.py",
"setup.cfg",
"Dockerfile",
"docker-compose.yml",
"docker-compose.yaml",
"Makefile",
"README.md",
"README.rst",
]
# Also collect GitHub workflows
workflows_path = git_root / ".github" / "workflows"
if workflows_path.exists():
important_files.extend(
str(workflow_file.relative_to(git_root)) for workflow_file in workflows_path.glob("*.yml")
)
important_files.extend(
str(workflow_file.relative_to(git_root)) for workflow_file in workflows_path.glob("*.yaml")
)
files_dict: dict[str, str] = {}
max_file_size = 8 * 1024 # 8KB limit per file
for file_path_str in important_files:
file_path = git_root / file_path_str
if file_path.exists() and file_path.is_file():
try:
content = file_path.read_text(encoding="utf-8", errors="ignore")
# Limit file size
if len(content) > max_file_size:
content = content[:max_file_size] + "\n... (truncated)"
files_dict[file_path_str] = content
except Exception as e:
logger.warning(
f"[github_workflow.py:collect_repo_files_for_workflow] Failed to read {file_path_str}: {e}"
)
# Collect 2-level directory structure
directory_structure: dict[str, Any] = {}
try:
for item in sorted(git_root.iterdir()):
if item.name.startswith(".") and item.name not in [".github", ".git"]:
continue # Skip hidden files/folders except .github and .git
if item.is_dir():
# Level 1: directory
dir_dict: dict[str, Any] = {"type": "directory", "contents": {}}
try:
# Level 2: contents of directory
for subitem in sorted(item.iterdir()):
if subitem.name.startswith("."):
continue
if subitem.is_dir():
dir_dict["contents"][subitem.name] = {"type": "directory"}
else:
dir_dict["contents"][subitem.name] = {"type": "file"}
except PermissionError:
pass # Skip directories we can't read
directory_structure[item.name] = dir_dict
elif item.is_file():
directory_structure[item.name] = {"type": "file"}
except Exception as e:
logger.warning(
f"[github_workflow.py:collect_repo_files_for_workflow] Error collecting directory structure: {e}"
)
return {"files": files_dict, "directory_structure": directory_structure}
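Each collected file is capped at 8 KiB before being sent to the AI service. A minimal sketch of that truncation rule on its own:

```python
MAX_FILE_SIZE = 8 * 1024  # 8 KiB per file, matching the limit above

def cap_content(content: str) -> str:
    # Truncate oversized file contents and mark the cut explicitly.
    if len(content) > MAX_FILE_SIZE:
        return content[:MAX_FILE_SIZE] + "\n... (truncated)"
    return content
```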
def generate_dynamic_workflow_content(
optimize_yml_content: str, config: dict[str, Any], git_root: Path, benchmark_mode: bool = False
) -> str:
"""Generate workflow content with dynamic steps from AI service, falling back to static template."""
# First, do the basic replacements that are always needed
module_path = str(Path(config["module_root"]).relative_to(git_root) / "**")
optimize_yml_content = optimize_yml_content.replace("{{ codeflash_module_path }}", module_path)
# Detect project language
project_language = detect_project_language_for_workflow(Path.cwd())
# For JavaScript/TypeScript and Java projects, use static template customization
# (AI-generated steps are currently Python-only)
if project_language in ("javascript", "typescript", "java"):
return customize_codeflash_yaml_content(optimize_yml_content, config, git_root, benchmark_mode)
# Python project - try AI-generated steps
toml_path = Path.cwd() / "pyproject.toml"
try:
with toml_path.open(encoding="utf8") as pyproject_file:
pyproject_data = tomlkit.parse(pyproject_file.read())
except FileNotFoundError:
click.echo(
f"I couldn't find a pyproject.toml in the current directory.{LF}"
f"Please create a new empty pyproject.toml file here, OR if you use poetry then run `poetry init`, OR run `codeflash init` again from a directory with an existing pyproject.toml file."
)
apologize_and_exit()
working_dir = get_github_action_working_directory(toml_path, git_root)
optimize_yml_content = optimize_yml_content.replace("{{ working_directory }}", working_dir)
# Try to generate dynamic steps using AI service
try:
repo_data = collect_repo_files_for_workflow(git_root)
# Prepare codeflash config for AI
codeflash_config = {
"module_root": config["module_root"],
"tests_root": config.get("tests_root", ""),
"benchmark_mode": benchmark_mode,
}
aiservice_client = AiServiceClient()
dynamic_steps = aiservice_client.generate_workflow_steps(
repo_files=repo_data["files"],
directory_structure=repo_data["directory_structure"],
codeflash_config=codeflash_config,
)
if dynamic_steps:
# Replace the entire steps section with AI-generated steps
# Find the steps section in the template
steps_start = optimize_yml_content.find(" steps:")
if steps_start != -1:
# Find the end of the steps section (next line at same or less indentation)
lines = optimize_yml_content.split("\n")
steps_start_line = optimize_yml_content[:steps_start].count("\n")
steps_end_line = len(lines)
# Find where steps section ends (next job or end of file)
for i in range(steps_start_line + 1, len(lines)):
line = lines[i]
# Stop if we hit a line that's not indented (new job or end of jobs)
if line and not line.startswith(" ") and not line.startswith("\t"):
steps_end_line = i
break
# Extract steps content from AI response (remove "steps:" prefix if present)
steps_content = dynamic_steps
if steps_content.startswith("steps:"):
# Remove "steps:" and leading newline
steps_content = steps_content[6:].lstrip("\n")
# Ensure proper indentation (8 spaces for steps section in YAML)
indented_steps = []
for line in steps_content.split("\n"):
if line.strip():
# If line doesn't start with enough spaces, add them
if not line.startswith(" "):
indented_steps.append(" " + line)
else:
# Preserve existing indentation but ensure minimum 8 spaces
current_indent = len(line) - len(line.lstrip())
if current_indent < 8:
indented_steps.append(" " * 8 + line.lstrip())
else:
indented_steps.append(line)
else:
indented_steps.append("")
# Add codeflash command step at the end
dep_manager = determine_dependency_manager(pyproject_data)
codeflash_cmd = get_codeflash_github_action_command(dep_manager)
if benchmark_mode:
codeflash_cmd += " --benchmark"
# Format codeflash command properly
if "|" in codeflash_cmd:
# Multi-line command
cmd_lines = codeflash_cmd.split("\n")
codeflash_step = f" - name: ⚡Codeflash Optimization\n run: {cmd_lines[0].strip()}"
for cmd_line in cmd_lines[1:]:
codeflash_step += f"\n {cmd_line.strip()}"
else:
codeflash_step = f" - name: ⚡Codeflash Optimization\n run: {codeflash_cmd}"
indented_steps.append(codeflash_step)
# Reconstruct the workflow
return "\n".join([*lines[:steps_start_line], " steps:", *indented_steps, *lines[steps_end_line:]])
logger.warning(
"[github_workflow.py:generate_dynamic_workflow_content] Could not find steps section in template"
)
else:
logger.debug(
"[github_workflow.py:generate_dynamic_workflow_content] AI service returned no steps, falling back to static"
)
except Exception as e:
logger.warning(
f"[github_workflow.py:generate_dynamic_workflow_content] Error generating dynamic workflow, falling back to static: {e}"
)
# Fallback to static template
return customize_codeflash_yaml_content(optimize_yml_content, config, git_root, benchmark_mode)
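The trickiest part of the splice above is indentation: every non-empty AI-generated line must end up with at least eight leading spaces so it nests under the job's `steps:` key, while deeper indentation the AI already emitted is preserved. A self-contained sketch of that normalization:

```python
def normalize_step_indent(steps_text: str) -> str:
    # Force a minimum of 8 leading spaces on non-empty lines,
    # preserving any indentation deeper than that.
    out = []
    for line in steps_text.split("\n"):
        if not line.strip():
            out.append("")
        else:
            indent = len(line) - len(line.lstrip())
            out.append(line if indent >= 8 else " " * 8 + line.lstrip())
    return "\n".join(out)
```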
def customize_codeflash_yaml_content(
optimize_yml_content: str, config: dict[str, Any], git_root: Path, benchmark_mode: bool = False
) -> str:
module_path = str(Path(config["module_root"]).relative_to(git_root) / "**")
optimize_yml_content = optimize_yml_content.replace("{{ codeflash_module_path }}", module_path)
# Detect project language
project_language = detect_project_language_for_workflow(Path.cwd())
if project_language == "java":
return _customize_java_workflow_content(optimize_yml_content, git_root)
if project_language in ("javascript", "typescript"):
return _customize_js_workflow_content(optimize_yml_content, git_root, benchmark_mode)
# Python project (default)
return _customize_python_workflow_content(optimize_yml_content, git_root, benchmark_mode)
def _customize_python_workflow_content(optimize_yml_content: str, git_root: Path, benchmark_mode: bool = False) -> str:
"""Customize workflow content for Python projects."""
# Get dependency installation commands
toml_path = Path.cwd() / "pyproject.toml"
try:
with toml_path.open(encoding="utf8") as pyproject_file:
pyproject_data = tomlkit.parse(pyproject_file.read())
except FileNotFoundError:
click.echo(
f"I couldn't find a pyproject.toml in the current directory.{LF}"
f"Please create a new empty pyproject.toml file here, OR if you use poetry then run `poetry init`, OR run `codeflash init` again from a directory with an existing pyproject.toml file."
)
apologize_and_exit()
working_dir = get_github_action_working_directory(toml_path, git_root)
optimize_yml_content = optimize_yml_content.replace("{{ working_directory }}", working_dir)
dep_manager = determine_dependency_manager(pyproject_data)
python_depmanager_installation = get_dependency_manager_installation_string(dep_manager)
optimize_yml_content = optimize_yml_content.replace(
"{{ setup_runtime_environment }}", python_depmanager_installation
)
install_deps_cmd = get_dependency_installation_commands(dep_manager)
optimize_yml_content = optimize_yml_content.replace("{{ install_dependencies_command }}", install_deps_cmd)
# Add codeflash command
codeflash_cmd = get_codeflash_github_action_command(dep_manager)
if benchmark_mode:
codeflash_cmd += " --benchmark"
return optimize_yml_content.replace("{{ codeflash_command }}", codeflash_cmd)
def _customize_js_workflow_content(optimize_yml_content: str, git_root: Path, benchmark_mode: bool = False) -> str:
"""Customize workflow content for JavaScript/TypeScript projects."""
from codeflash.cli_cmds.init_javascript import (
determine_js_package_manager,
get_js_codeflash_install_step,
get_js_codeflash_run_command,
get_js_dependency_installation_commands,
get_js_runtime_setup_steps,
is_codeflash_dependency,
)
project_root = Path.cwd()
package_json_path = project_root / "package.json"
if not package_json_path.exists():
click.echo(
f"I couldn't find a package.json in the current directory.{LF}"
f"Please run `npm init` or create a package.json file first."
)
apologize_and_exit()
# Determine working directory relative to git root
if project_root == git_root:
working_dir = ""
else:
rel_path = str(project_root.relative_to(git_root))
working_dir = f"""defaults:
run:
working-directory: ./{rel_path}"""
optimize_yml_content = optimize_yml_content.replace("{{ working_directory }}", working_dir)
# Determine package manager and codeflash dependency status
pkg_manager = determine_js_package_manager(project_root)
codeflash_is_dep = is_codeflash_dependency(project_root)
# Setup runtime environment (Node.js/Bun)
runtime_setup = get_js_runtime_setup_steps(pkg_manager)
optimize_yml_content = optimize_yml_content.replace("{{ setup_runtime_steps }}", runtime_setup)
# Install dependencies
install_deps_cmd = get_js_dependency_installation_commands(pkg_manager)
optimize_yml_content = optimize_yml_content.replace("{{ install_dependencies_command }}", install_deps_cmd)
# Install codeflash step (only if not a dependency)
install_codeflash = get_js_codeflash_install_step(pkg_manager, is_dependency=codeflash_is_dep)
optimize_yml_content = optimize_yml_content.replace("{{ install_codeflash_step }}", install_codeflash)
# Codeflash run command
codeflash_cmd = get_js_codeflash_run_command(pkg_manager, is_dependency=codeflash_is_dep)
if benchmark_mode:
codeflash_cmd += " --benchmark"
return optimize_yml_content.replace("{{ codeflash_command }}", codeflash_cmd)
def _customize_java_workflow_content(optimize_yml_content: str, git_root: Path) -> str:
"""Customize workflow content for Java projects."""
from codeflash.cli_cmds.init_java import (
JavaBuildTool,
detect_java_build_tool,
get_java_dependency_installation_commands,
)
project_root = Path.cwd()
build_tool = detect_java_build_tool(project_root)
# Working directory
if project_root == git_root:
working_dir = ""
else:
rel_path = str(project_root.relative_to(git_root))
working_dir = f"""defaults:
run:
working-directory: ./{rel_path}"""
optimize_yml_content = optimize_yml_content.replace("{{ working_directory }}", working_dir)
# Build tool cache
if build_tool == JavaBuildTool.GRADLE:
optimize_yml_content = optimize_yml_content.replace("{{ java_build_tool }}", "gradle")
else:
optimize_yml_content = optimize_yml_content.replace("{{ java_build_tool }}", "maven")
# Install dependencies command
install_deps = get_java_dependency_installation_commands(build_tool)
return optimize_yml_content.replace("{{ install_dependencies_command }}", install_deps)
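All three `_customize_*` helpers share the same mechanism: the workflow template is plain text with `{{ ... }}` markers filled in via `str.replace`, with no templating engine involved. A minimal illustration (this template fragment and the substituted values are made up, not the shipped template):

```python
# Hypothetical template fragment with the same placeholder style as above.
template = """\
jobs:
  optimize:
    {{ working_directory }}
    steps:
      - run: {{ codeflash_command }}
"""

rendered = (
    template
    .replace("{{ working_directory }}", "defaults:\n      run:\n        working-directory: ./backend")
    .replace("{{ codeflash_command }}", "uv run codeflash")
)
```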

@@ -0,0 +1,210 @@
from __future__ import annotations
import os
import click
import git
import inquirer
from rich.panel import Panel
from rich.text import Text
from codeflash.api.cfapi import get_user_id, is_github_app_installed_on_repo
from codeflash.cli_cmds.cli_common import apologize_and_exit
from codeflash.cli_cmds.console import console
from codeflash.cli_cmds.init_config import CodeflashTheme
from codeflash.cli_cmds.oauth_handler import perform_oauth_signin
from codeflash.code_utils.compat import LF
from codeflash.code_utils.env_utils import get_codeflash_api_key
from codeflash.code_utils.git_utils import get_git_remotes, get_repo_owner_and_name
from codeflash.code_utils.shell_utils import get_shell_rc_path, save_api_key_to_rc
from codeflash.either import is_successful
from codeflash.telemetry.posthog_cf import ph
CF_THEME = CodeflashTheme()
class CFAPIKeyType(click.ParamType):
name = "cfapi-key"
def convert(self, value: str, param: click.Parameter | None, ctx: click.Context | None) -> str | None:
value = value.strip()
if not value.startswith("cf-") and value != "":
self.fail(
f"That key [{value}] seems to be invalid. It should start with a 'cf-' prefix. Please try again.",
param,
ctx,
)
return value
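`CFAPIKeyType` accepts an empty string (so the prompt loop can offer to open the API-key page) but rejects any non-empty value that lacks the `cf-` prefix. The rule in isolation (function name is illustrative):

```python
def is_acceptable_key_input(value: str) -> bool:
    # Empty input falls through to the browser-launch branch of the prompt
    # loop; anything else must carry the cf- prefix.
    value = value.strip()
    return value == "" or value.startswith("cf-")
```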
# Returns True if the user entered a new API key, False if they used an existing one
def prompt_api_key() -> bool:
"""Prompt user for API key via OAuth or manual entry."""
# Check for existing API key
try:
existing_api_key = get_codeflash_api_key()
except OSError:
existing_api_key = None
if existing_api_key:
display_key = f"{existing_api_key[:3]}****{existing_api_key[-4:]}"
api_key_panel = Panel(
Text(
f"🔑 I found a CODEFLASH_API_KEY in your environment [{display_key}]!\n\n"
"✅ You're all set with API authentication!",
style="green",
justify="center",
),
title="🔑 API Key Found",
border_style="bright_green",
)
console.print(api_key_panel)
console.print()
return False
# Prompt for authentication method
auth_choices = ["🔐 Log in with Codeflash", "🔑 Use Codeflash API key"]
questions = [
inquirer.List(
"auth_method",
message="How would you like to authenticate?",
choices=auth_choices,
default=auth_choices[0],
carousel=True,
)
]
answers = inquirer.prompt(questions, theme=CF_THEME)
if not answers:
apologize_and_exit()
method = answers["auth_method"]
if method == auth_choices[1]:
enter_api_key_and_save_to_rc()
ph("cli-new-api-key-entered")
return True
# Perform OAuth sign-in
api_key = perform_oauth_signin()
if not api_key:
apologize_and_exit()
# Save API key
shell_rc_path = get_shell_rc_path()
if not shell_rc_path.exists() and os.name == "nt":
shell_rc_path.touch()
click.echo(f"✅ Created {shell_rc_path}")
result = save_api_key_to_rc(api_key)
if is_successful(result):
click.echo(result.unwrap())
click.echo("✅ Signed in successfully and API key saved!")
else:
click.echo(result.failure())
click.pause()
os.environ["CODEFLASH_API_KEY"] = api_key
ph("cli-oauth-signin-completed")
return True
def enter_api_key_and_save_to_rc() -> None:
browser_launched = False
api_key = ""
while api_key == "":
api_key = click.prompt(
f"Enter your Codeflash API key{' [or press Enter to open your API key page]' if not browser_launched else ''}",
hide_input=False,
default="",
type=CFAPIKeyType(),
show_default=False,
).strip()
if api_key:
break
if not browser_launched:
click.echo(
f"Opening your Codeflash API key page. Grab a key from there!{LF}"
"You can also open this link manually: https://app.codeflash.ai/app/apikeys"
)
click.launch("https://app.codeflash.ai/app/apikeys")
browser_launched = True # This does not work on remote consoles
shell_rc_path = get_shell_rc_path()
if not shell_rc_path.exists() and os.name == "nt":
# On Windows, create the appropriate file (PowerShell .ps1 or CMD .bat) in the user's home directory
shell_rc_path.parent.mkdir(parents=True, exist_ok=True)
shell_rc_path.touch()
click.echo(f"✅ Created {shell_rc_path}")
get_user_id(api_key=api_key) # Used to verify whether the API key is valid.
result = save_api_key_to_rc(api_key)
if is_successful(result):
click.echo(result.unwrap())
else:
click.echo(result.failure())
click.pause()
os.environ["CODEFLASH_API_KEY"] = api_key
def install_github_app(git_remote: str) -> None:
try:
git_repo = git.Repo(search_parent_directories=True)
except git.InvalidGitRepositoryError:
click.echo("Skipping GitHub app installation because you're not in a git repository.")
return
if git_remote not in get_git_remotes(git_repo):
click.echo(f"Skipping GitHub app installation, remote ({git_remote}) does not exist in this repository.")
return
owner, repo = get_repo_owner_and_name(git_repo, git_remote)
if is_github_app_installed_on_repo(owner, repo, suppress_errors=True):
click.echo(
f"🐙 Looks like you've already installed the Codeflash GitHub app on this repository ({owner}/{repo})! Continuing…"
)
else:
try:
click.prompt(
f"Finally, you'll need to install the Codeflash GitHub app by choosing the repository you want to install Codeflash on.{LF}"
f"I will attempt to open the GitHub app page - https://github.com/apps/codeflash-ai/installations/select_target {LF}"
f"Please, press ENTER to open the app installation page{LF}",
default="",
type=click.STRING,
prompt_suffix=">>> ",
show_default=False,
)
click.launch("https://github.com/apps/codeflash-ai/installations/select_target")
click.prompt(
f"Please, press ENTER once you've finished installing the GitHub app from https://github.com/apps/codeflash-ai/installations/select_target{LF}",
default="",
type=click.STRING,
prompt_suffix=">>> ",
show_default=False,
)
count = 2
while not is_github_app_installed_on_repo(owner, repo, suppress_errors=True):
if count == 0:
click.echo(
f"❌ It looks like the Codeflash GitHub App is not installed on the repository {owner}/{repo}.{LF}"
f"You won't be able to create PRs with Codeflash until you install the app.{LF}"
f"In the meantime you can make local only optimizations by using the '--no-pr' flag with codeflash.{LF}"
)
break
click.prompt(
f"❌ It looks like the Codeflash GitHub App is not installed on the repository {owner}/{repo}.{LF}"
f"Please install it from https://github.com/apps/codeflash-ai/installations/select_target {LF}"
f"Please, press ENTER to continue once you've finished installing the GitHub app…{LF}",
default="",
type=click.STRING,
prompt_suffix=">>> ",
show_default=False,
)
count -= 1
except (KeyboardInterrupt, EOFError, click.exceptions.Abort):
# leave empty line for the next prompt to be properly rendered
click.echo()

@@ -0,0 +1,289 @@
from __future__ import annotations
import os
from enum import Enum
from functools import lru_cache
from pathlib import Path
from typing import Any, Optional, Union
import click
import inquirer
import inquirer.themes
import tomlkit
from pydantic.dataclasses import dataclass
from codeflash.cli_cmds.cli_common import apologize_and_exit
from codeflash.cli_cmds.console import console
from codeflash.code_utils.compat import LF
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.code_utils.env_utils import check_formatter_installed
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.telemetry.posthog_cf import ph
@dataclass(frozen=True)
class CLISetupInfo:
"""Setup info for Python projects."""
module_root: str
tests_root: str
benchmarks_root: Union[str, None]
ignore_paths: list[str]
formatter: Union[str, list[str]]
git_remote: str
enable_telemetry: bool
@dataclass(frozen=True)
class VsCodeSetupInfo:
"""Setup info for VSCode extension initialization."""
module_root: str
tests_root: str
formatter: Union[str, list[str]]
# Custom theme for better UX
class CodeflashTheme(inquirer.themes.Default): # type: ignore[misc]
def __init__(self) -> None:
super().__init__()
self.Question.mark_color = inquirer.themes.term.yellow
self.Question.brackets_color = inquirer.themes.term.bright_blue
self.Question.default_color = inquirer.themes.term.bright_cyan
self.List.selection_color = inquirer.themes.term.bright_blue
self.Checkbox.selection_color = inquirer.themes.term.bright_blue
self.Checkbox.selected_icon = ""
self.Checkbox.unselected_icon = ""
# Common sections shared between normal mode and LSP mode
class CommonSections(Enum):
module_root = "module_root"
tests_root = "tests_root"
formatter_cmds = "formatter_cmds"
def get_toml_key(self) -> str:
return self.value.replace("_", "-")
ignore_subdirs = {
"venv",
"node_modules",
"dist",
"build",
"build_temp",
"build_scripts",
"env",
"logs",
"tmp",
"__pycache__",
}
@lru_cache(maxsize=1)
def get_valid_subdirs(current_dir: Optional[Path] = None) -> list[str]:
path_str = str(current_dir) if current_dir else "."
return [
entry.name
for entry in os.scandir(path_str)
if entry.is_dir() and not entry.name.startswith((".", "__")) and entry.name not in ignore_subdirs
]
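The filter in `get_valid_subdirs` can be exercised in isolation. A minimal sketch (the directory names here are hypothetical, and `ignore_subdirs` is a subset of the set defined above):

```python
import os
import tempfile

# Illustrative subset of the ignore list defined above.
ignore_subdirs = {"venv", "node_modules", "__pycache__"}

with tempfile.TemporaryDirectory() as root:
    for name in ("src", "tests", "venv", ".git", "__pycache__"):
        os.mkdir(os.path.join(root, name))
    # Same filter as get_valid_subdirs: skip hidden dirs, dunder dirs,
    # and anything on the ignore list.
    subdirs = sorted(
        entry.name
        for entry in os.scandir(root)
        if entry.is_dir()
        and not entry.name.startswith((".", "__"))
        and entry.name not in ignore_subdirs
    )

print(subdirs)  # ['src', 'tests']
```

Note that `os.scandir` yields entries in arbitrary order, which is why the sketch sorts before comparing.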
def get_suggestions(section: CommonSections) -> tuple[list[str], Optional[str]]:
valid_subdirs = get_valid_subdirs()
if section == CommonSections.module_root:
return [d for d in valid_subdirs if d != "tests"], None
if section == CommonSections.tests_root:
default = "tests" if "tests" in valid_subdirs else None
return valid_subdirs, default
if section == CommonSections.formatter_cmds:
return ["disabled", "ruff", "black"], "disabled"
msg = f"Unknown section: {section}"
raise ValueError(msg)
def config_found(pyproject_toml_path: Union[str, Path]) -> tuple[bool, str]:
pyproject_toml_path = Path(pyproject_toml_path)
if not pyproject_toml_path.exists():
return False, f"Configuration file not found: {pyproject_toml_path}"
if not pyproject_toml_path.is_file():
return False, f"Configuration file is not a file: {pyproject_toml_path}"
if pyproject_toml_path.suffix != ".toml":
return False, f"Configuration file is not a .toml file: {pyproject_toml_path}"
return True, ""
def is_valid_pyproject_toml(pyproject_toml_path: Union[str, Path]) -> tuple[bool, dict[str, Any] | None, str]:
pyproject_toml_path = Path(pyproject_toml_path)
try:
config, _ = parse_config_file(pyproject_toml_path)
except Exception as e:
return False, None, f"Failed to parse configuration: {e}"
module_root = config.get("module_root")
if not module_root:
return False, config, "Missing required field: 'module_root'"
if not Path(module_root).is_dir():
return False, config, f"Invalid 'module_root': directory does not exist at {module_root}"
tests_root = config.get("tests_root")
if not tests_root:
return False, config, "Missing required field: 'tests_root'"
if not Path(tests_root).is_dir():
return False, config, f"Invalid 'tests_root': directory does not exist at {tests_root}"
return True, config, ""
def should_modify_pyproject_toml() -> tuple[bool, dict[str, Any] | None]:
"""Check if the current directory contains a valid pyproject.toml file with codeflash config.
If it does, ask the user if they want to re-configure it.
"""
from rich.prompt import Confirm
pyproject_toml_path = Path.cwd() / "pyproject.toml"
found, _ = config_found(pyproject_toml_path)
if not found:
return True, None
valid, config, _message = is_valid_pyproject_toml(pyproject_toml_path)
if not valid:
# needs to be re-configured
return True, None
return Confirm.ask(
"✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?",
default=False,
show_default=True,
), config
def get_formatter_cmds(formatter: str) -> list[str]:
if formatter == "black":
return ["black $file"]
if formatter == "ruff":
return ["ruff check --exit-zero --fix $file", "ruff format $file"]
if formatter == "other":
click.echo(
"🔧 In pyproject.toml, please replace 'your-formatter' with the command you use to format your code."
)
return ["your-formatter $file"]
if formatter in {"don't use a formatter", "disabled"}:
return ["disabled"]
if " && " in formatter:
return formatter.split(" && ")
return [formatter]
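The dispatch above can be sanity-checked standalone. A sketch mirroring the same mapping (the interactive "other" branch is omitted; names are illustrative):

```python
def formatter_cmds(formatter: str) -> list[str]:
    # Mirrors get_formatter_cmds above: known formatter names map to
    # canonical commands, "disabled" turns formatting off, and a compound
    # command string is split into sequential steps on " && ".
    known = {
        "black": ["black $file"],
        "ruff": ["ruff check --exit-zero --fix $file", "ruff format $file"],
        "disabled": ["disabled"],
    }
    if formatter in known:
        return known[formatter]
    if " && " in formatter:
        return formatter.split(" && ")
    return [formatter]

print(formatter_cmds("ruff"))
print(formatter_cmds("isort $file && black $file"))
```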
# Create or update the pyproject.toml file with the Codeflash dependency & configuration
def configure_pyproject_toml(
setup_info: Union[VsCodeSetupInfo, CLISetupInfo], config_file: Optional[Path] = None
) -> bool:
for_vscode = isinstance(setup_info, VsCodeSetupInfo)
toml_path = config_file or Path.cwd() / "pyproject.toml"
try:
with toml_path.open(encoding="utf8") as pyproject_file:
pyproject_data = tomlkit.parse(pyproject_file.read())
except FileNotFoundError:
click.echo(
f"I couldn't find a pyproject.toml in the current directory.{LF}"
f"Please create a new empty pyproject.toml file here, OR if you use poetry then run `poetry init`, OR run `codeflash init` again from a directory with an existing pyproject.toml file."
)
return False
codeflash_section = tomlkit.table()
codeflash_section.add(tomlkit.comment("All paths are relative to this pyproject.toml's directory."))
if for_vscode:
for section in CommonSections:
if hasattr(setup_info, section.value):
codeflash_section[section.get_toml_key()] = getattr(setup_info, section.value)
elif isinstance(setup_info, CLISetupInfo):
codeflash_section["module-root"] = setup_info.module_root
codeflash_section["tests-root"] = setup_info.tests_root
codeflash_section["ignore-paths"] = setup_info.ignore_paths
if not setup_info.enable_telemetry:
codeflash_section["disable-telemetry"] = not setup_info.enable_telemetry
if setup_info.git_remote not in ["", "origin"]:
codeflash_section["git-remote"] = setup_info.git_remote
formatter = setup_info.formatter
formatter_cmds = formatter if isinstance(formatter, list) else get_formatter_cmds(formatter)
check_formatter_installed(formatter_cmds, exit_on_failure=False)
codeflash_section["formatter-cmds"] = formatter_cmds
# Add the 'codeflash' section, ensuring 'tool' section exists
tool_section = pyproject_data.get("tool", tomlkit.table())
if for_vscode:
# merge the existing codeflash section, instead of overwriting it
existing_codeflash = tool_section.get("codeflash", tomlkit.table())
for key, value in codeflash_section.items():
existing_codeflash[key] = value
tool_section["codeflash"] = existing_codeflash
else:
tool_section["codeflash"] = codeflash_section
pyproject_data["tool"] = tool_section
with toml_path.open("w", encoding="utf8") as pyproject_file:
pyproject_file.write(tomlkit.dumps(pyproject_data))
click.echo(f"Added Codeflash configuration to {toml_path}")
click.echo()
return True
def create_empty_pyproject_toml(pyproject_toml_path: Path) -> None:
ph("cli-create-pyproject-toml")
lsp_mode = is_LSP_enabled()
# Define a minimal pyproject.toml content
new_pyproject_toml = tomlkit.document()
new_pyproject_toml["tool"] = {"codeflash": {}}
try:
pyproject_toml_path.write_text(tomlkit.dumps(new_pyproject_toml), encoding="utf8")
# Check if the pyproject.toml file was created
if pyproject_toml_path.exists() and not lsp_mode:
from rich.panel import Panel
from rich.text import Text
success_panel = Panel(
Text(
f"✅ Created a pyproject.toml file at {pyproject_toml_path}\n\n"
"Your project is now ready for Codeflash configuration!",
style="green",
justify="center",
),
title="🎉 Success!",
border_style="bright_green",
)
console.print(success_panel)
console.print("\n📍 Press any key to continue...")
console.input()
ph("cli-created-pyproject-toml")
except OSError:
click.echo("❌ Failed to create pyproject.toml. Please check your disk permissions and available space.")
apologize_and_exit()
def ask_for_telemetry() -> bool:
"""Prompt the user to enable or disable telemetry."""
from rich.prompt import Confirm
return Confirm.ask(
"⚡️ Help us improve Codeflash by sharing anonymous usage data (e.g. errors encountered)?",
default=True,
show_default=True,
)

View file

@ -57,7 +57,7 @@ class JavaSetupInfo:
def _get_theme():
"""Get the CodeflashTheme - imported lazily to avoid circular imports."""
from codeflash.cli_cmds.cmd_init import CodeflashTheme
from codeflash.cli_cmds.init_config import CodeflashTheme
return CodeflashTheme()
@ -161,7 +161,8 @@ def detect_java_test_framework(project_root: Path) -> str:
def init_java_project() -> None:
"""Initialize Codeflash for a Java project."""
from codeflash.cli_cmds.cmd_init import install_github_actions, install_github_app, prompt_api_key
from codeflash.cli_cmds.github_workflow import install_github_actions
from codeflash.cli_cmds.init_auth import install_github_app, prompt_api_key
lang_panel = Panel(
Text(
@ -244,7 +245,7 @@ def collect_java_setup_info() -> JavaSetupInfo:
"""Collect setup information for Java projects."""
from rich.prompt import Confirm
from codeflash.cli_cmds.cmd_init import ask_for_telemetry
from codeflash.cli_cmds.init_config import ask_for_telemetry
curdir = Path.cwd()

View file

@ -9,7 +9,10 @@ import sys
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
from typing import Any, Union
from typing import TYPE_CHECKING, Any, Union
if TYPE_CHECKING:
from codeflash.cli_cmds.init_config import CodeflashTheme
import click
import inquirer
@ -68,9 +71,9 @@ class JSSetupInfo:
# Import theme from init_config to avoid duplication
def _get_theme():
def _get_theme() -> CodeflashTheme:
"""Get the CodeflashTheme - imported lazily to avoid circular imports."""
from codeflash.cli_cmds.cmd_init import CodeflashTheme
from codeflash.cli_cmds.init_config import CodeflashTheme
return CodeflashTheme()
@ -217,7 +220,8 @@ def get_package_install_command(project_root: Path, package: str, dev: bool = Tr
def init_js_project(language: ProjectLanguage, *, skip_confirm: bool = False, skip_api_key: bool = False) -> None:
"""Initialize Codeflash for a JavaScript/TypeScript project."""
from codeflash.cli_cmds.cmd_init import install_github_actions, install_github_app, prompt_api_key
from codeflash.cli_cmds.github_workflow import install_github_actions
from codeflash.cli_cmds.init_auth import install_github_app, prompt_api_key
lang_name = "TypeScript" if language == ProjectLanguage.TYPESCRIPT else "JavaScript"
@ -332,7 +336,7 @@ def collect_js_setup_info(language: ProjectLanguage, *, skip_confirm: bool = Fal
Uses auto-detection for most settings and only asks for overrides if needed.
When skip_confirm is True, uses all auto-detected defaults without prompting.
"""
from codeflash.cli_cmds.cmd_init import ask_for_telemetry, get_valid_subdirs
from codeflash.cli_cmds.init_config import ask_for_telemetry, get_valid_subdirs
from codeflash.code_utils.config_js import (
detect_formatter,
detect_module_root,
@ -697,22 +701,9 @@ def get_js_codeflash_install_step(pkg_manager: JsPackageManager, *, is_dependenc
# Codeflash will be installed with other dependencies
return ""
# Need to install codeflash separately
if pkg_manager == JsPackageManager.BUN:
return """- name: 📥 Install Codeflash
run: bun add -g codeflash"""
if pkg_manager == JsPackageManager.PNPM:
return """- name: 📥 Install Codeflash
run: pnpm add -g codeflash"""
if pkg_manager == JsPackageManager.YARN:
return """- name: 📥 Install Codeflash
run: yarn global add codeflash"""
# NPM or UNKNOWN
# Install codeflash via uv (Python + uv are set up in the workflow)
return """- name: 📥 Install Codeflash
run: npm install -g codeflash"""
run: uv tool install codeflash"""
def get_js_codeflash_run_command(pkg_manager: JsPackageManager, *, is_dependency: bool) -> str:

View file

@ -27,6 +27,12 @@ jobs:
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: 🐍 Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: 📦 Setup uv
uses: astral-sh/setup-uv@v4
{{ setup_runtime_steps }}
- name: 📦 Install Dependencies
run: {{ install_dependencies_command }}

View file

@ -27,7 +27,7 @@ jobs:
uses: actions/checkout@v4
with:
fetch-depth: 0
- {{ setup_python_dependency_manager }}
- {{ setup_runtime_environment }}
- name: 📦 Install Dependencies
run: {{ install_dependencies_command }}
- name: Codeflash Optimization

View file

@ -21,6 +21,7 @@ TOTAL_LOOPING_TIME = 10.0 # 10 second candidate benchmarking budget
COVERAGE_THRESHOLD = 60.0
MIN_TESTCASE_PASSED_THRESHOLD = 6
REPEAT_OPTIMIZATION_PROBABILITY = 0.1
MAX_TEST_REPAIR_CYCLES = 2
DEFAULT_IMPORTANCE_THRESHOLD = 0.001
# pytest loop stability

View file

@ -59,6 +59,10 @@ def get_git_diff(
logger.debug(f"deleted lines: {del_line_no}")
if not add_line_no and del_line_no:
# Deletion-only changes: use hunk target start lines so we can still
# match the surrounding function in the current (target) file.
add_line_no = [hunk.target_start for hunk in patched_file]
change_list[file_path] = add_line_no
return change_list
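The deletion-only fallback above can be illustrated with stand-in objects (a minimal sketch; `Hunk` here is a stand-in for the real diff-library hunk type, keeping only `target_start`):

```python
from dataclasses import dataclass


@dataclass
class Hunk:
    # Stand-in for a diff hunk: the start line of the hunk in the
    # target (post-change) file.
    target_start: int


def changed_lines(
    add_line_no: list[int], del_line_no: list[int], hunks: list[Hunk]
) -> list[int]:
    # Deletion-only change: there are no added lines to anchor on, so
    # fall back to each hunk's target start line so the surrounding
    # function can still be located in the current file.
    if not add_line_no and del_line_no:
        return [h.target_start for h in hunks]
    return add_line_no


print(changed_lines([], [10, 11], [Hunk(target_start=9)]))  # [9]
```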

View file

@ -682,7 +682,7 @@ class LanguageSupport(Protocol):
def compare_test_results(
self, original_results_path: Path, candidate_results_path: Path, project_root: Path | None = None
) -> tuple[bool, list]:
) -> tuple[bool, list[Any]]:
"""Compare test results between original and candidate code.
Args:
@ -699,7 +699,7 @@ class LanguageSupport(Protocol):
@property
def function_optimizer_class(self) -> type:
"""Return the FunctionOptimizer subclass for this language."""
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.languages.function_optimizer import FunctionOptimizer
return FunctionOptimizer
@ -876,7 +876,7 @@ class LanguageSupport(Protocol):
"""Instrument source code before line profiling."""
...
def parse_line_profile_results(self, line_profiler_output_file: Path) -> dict:
def parse_line_profile_results(self, line_profiler_output_file: Path) -> dict[str, Any]:
"""Parse line profiler output."""
...
@ -888,7 +888,7 @@ class LanguageSupport(Protocol):
project_root: Path,
function_to_optimize: FunctionToOptimize,
function_to_optimize_ast: Any,
) -> tuple[dict, str]:
) -> tuple[dict[str, Any], str]:
"""Generate concolic tests for a function.
Default implementation returns empty results. Override for languages
@ -980,7 +980,7 @@ class LanguageSupport(Protocol):
...
def convert_parents_to_tuple(parents: list | tuple) -> tuple[FunctionParent, ...]:
def convert_parents_to_tuple(parents: list[Any] | tuple[Any, ...]) -> tuple[FunctionParent, ...]:
"""Convert a list of parent objects to a tuple of FunctionParent.
Args:

View file

@ -6,14 +6,13 @@ via the LanguageSupport protocol.
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.cli_cmds.console import logger
from codeflash.languages.base import FunctionFilterCriteria, Language
if TYPE_CHECKING:
from pathlib import Path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import LanguageSupport
from codeflash.models.models import CodeStringsMarkdown
@ -26,45 +25,35 @@ def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStrin
from codeflash.languages.current import is_python
file_to_code_context = optimized_code.file_to_path()
relative_path_str = str(relative_path)
module_optimized_code = file_to_code_context.get(relative_path_str)
if module_optimized_code is None:
# Fallback: if there's only one code block with None file path,
# use it regardless of the expected path (the AI server doesn't always include file paths)
if "None" in file_to_code_context and len(file_to_code_context) == 1:
module_optimized_code = file_to_code_context["None"]
logger.debug(f"Using code block with None file_path for {relative_path}")
else:
# Fallback: try to match by just the filename (for Java/JS where the AI
# might return just the class name like "Algorithms.java" instead of
# the full path like "src/main/java/com/example/Algorithms.java")
target_filename = relative_path.name
for file_path_str, code in file_to_code_context.items():
if file_path_str:
if file_path_str.endswith(target_filename) and (
len(file_path_str) == len(target_filename)
or file_path_str[-len(target_filename) - 1] in ("/", "\\")
):
module_optimized_code = code
logger.debug(f"Matched {file_path_str} to {relative_path} by filename")
break
module_optimized_code = file_to_code_context.get(str(relative_path))
if module_optimized_code is not None:
return module_optimized_code
if module_optimized_code is None:
# Also try matching if there's only one code file, but ONLY for non-Python
# languages where path matching is less strict.
if len(file_to_code_context) == 1 and not is_python():
only_key = next(iter(file_to_code_context.keys()))
module_optimized_code = file_to_code_context[only_key]
logger.debug(f"Using only code block {only_key} for {relative_path}")
else:
if logger.isEnabledFor(logger.level):
logger.warning(
f"Optimized code not found for {relative_path} In the context\n-------\n{optimized_code}\n-------\n"
"re-check your 'markdown code structure'"
f"existing files are {file_to_code_context.keys()}"
)
module_optimized_code = ""
return module_optimized_code
# Fallback 1: single code block with no file path
if "None" in file_to_code_context and len(file_to_code_context) == 1:
logger.debug(f"Using code block with None file_path for {relative_path}")
return file_to_code_context["None"]
# Fallback 2: match by filename (basename) — the LLM sometimes returns a different
# directory prefix but the correct filename
target_name = relative_path.name
basename_matches = [
code for path, code in file_to_code_context.items() if path != "None" and Path(path).name == target_name
]
if len(basename_matches) == 1:
logger.debug(f"Using basename-matched code block for {relative_path}")
return basename_matches[0]
# Fallback 3: single code block for non-Python (AI often returns one block with wrong path)
if len(file_to_code_context) == 1 and not is_python():
only_key = next(iter(file_to_code_context.keys()))
logger.debug(f"Using only code block {only_key} for {relative_path}")
return file_to_code_context[only_key]
logger.warning(
f"Optimized code not found for {relative_path}, existing files are {list(file_to_code_context.keys())}"
)
return ""
def replace_function_definitions_for_language(

View file

@ -42,6 +42,7 @@ from codeflash.code_utils.code_utils import (
extract_unique_errors,
file_name_from_test_module_name,
get_run_tmp_file,
module_name_from_file_path,
normalize_by_max,
restore_conftest,
unified_diff_strings,
@ -49,6 +50,7 @@ from codeflash.code_utils.code_utils import (
from codeflash.code_utils.config_consts import (
COVERAGE_THRESHOLD,
INDIVIDUAL_TESTCASE_TIMEOUT,
MAX_TEST_REPAIR_CYCLES,
MIN_CORRECT_CANDIDATES,
OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
REFINED_CANDIDATE_RANKING_WEIGHTS,
@ -126,6 +128,7 @@ if TYPE_CHECKING:
FunctionCalledInTest,
FunctionSource,
TestDiff,
TestFileReview,
)
from codeflash.verification.verification_utils import TestConfig
@ -682,11 +685,9 @@ class FunctionOptimizer:
logger.info(f"Generated test {i + 1}/{count_tests}:")
# Use correct extension based on language
test_ext = self.language_support.get_test_file_suffix()
code_print(
generated_test.generated_original_test_source,
file_name=f"test_{i + 1}{test_ext}",
language=self.function_to_optimize.language,
)
# Show the raw LLM output when available, otherwise the post-processed source
display_source = generated_test.raw_generated_test_source or generated_test.generated_original_test_source
code_print(display_source, file_name=f"test_{i + 1}{test_ext}", language=self.function_to_optimize.language)
if concolic_test_str:
logger.info(f"Generated test {count_tests}/{count_tests}:")
code_print(concolic_test_str, language=self.function_to_optimize.language)
@ -768,6 +769,18 @@ class FunctionOptimizer:
optimizations_set, function_references = optimization_result.unwrap()
precomputed_behavioral: tuple[TestResults, CoverageData | None] | None = None
if generated_tests.generated_tests and self.args.testgen_review:
review_result = self.review_and_repair_tests(
generated_tests=generated_tests, code_context=code_context, original_helper_code=original_helper_code
)
if not is_successful(review_result):
return Failure(review_result.failure())
generated_tests, review_behavioral, review_coverage = review_result.unwrap()
if review_behavioral is not None:
precomputed_behavioral = (review_behavioral, review_coverage)
# Full baseline (behavioral + benchmarking) runs once on the final approved tests
baseline_setup_result = self.setup_and_establish_baseline(
code_context=code_context,
original_helper_code=original_helper_code,
@ -776,6 +789,7 @@ class FunctionOptimizer:
generated_perf_test_paths=generated_perf_test_paths,
instrumented_unittests_created_for_function=instrumented_unittests_created_for_function,
original_conftest_content=original_conftest_content,
precomputed_behavioral=precomputed_behavioral,
)
if not is_successful(baseline_setup_result):
@ -1009,16 +1023,18 @@ class FunctionOptimizer:
runtimes_list.append(new_best_opt.runtime)
if len(optimization_ids) > 1:
future_ranking = self.executor.submit(
ai_service_client.generate_ranking,
diffs=diff_strs,
optimization_ids=optimization_ids,
speedups=speedups_list,
trace_id=self.get_trace_id(exp_type),
function_references=function_references,
)
concurrent.futures.wait([future_ranking])
ranking = future_ranking.result()
ranking = None
if not is_subagent_mode():
future_ranking = self.executor.submit(
ai_service_client.generate_ranking,
diffs=diff_strs,
optimization_ids=optimization_ids,
speedups=speedups_list,
trace_id=self.get_trace_id(exp_type),
function_references=function_references,
)
concurrent.futures.wait([future_ranking])
ranking = future_ranking.result()
if ranking:
min_key = ranking[0]
else:
@ -1541,7 +1557,7 @@ class FunctionOptimizer:
functions_by_file: dict[Path, set[str]] = defaultdict(set)
functions_by_file[self.function_to_optimize.file_path].add(self.function_to_optimize.qualified_name)
for helper in code_context.helper_functions:
if helper.definition_type != "class":
if helper.definition_type in ("function", None):
functions_by_file[helper.file_path].add(helper.qualified_name)
return functions_by_file
@ -1730,6 +1746,7 @@ class FunctionOptimizer:
generated_test_source,
instrumented_behavior_test_source,
instrumented_perf_test_source,
raw_generated_test_source,
test_behavior_path,
test_perf_path,
) = res
@ -1738,6 +1755,7 @@ class FunctionOptimizer:
generated_original_test_source=generated_test_source,
instrumented_behavior_test_source=instrumented_behavior_test_source,
instrumented_perf_test_source=instrumented_perf_test_source,
raw_generated_test_source=raw_generated_test_source,
behavior_file_path=test_behavior_path,
perf_file_path=test_perf_path,
)
@ -1836,6 +1854,7 @@ class FunctionOptimizer:
generated_perf_test_paths: list[Path],
instrumented_unittests_created_for_function: set[Path],
original_conftest_content: str | None,
precomputed_behavioral: tuple[TestResults, CoverageData | None] | None = None,
) -> Result[
tuple[str, dict[str, set[FunctionCalledInTest]], OriginalCodeBaseline, list[str], dict[Path, set[str]]], str
]:
@ -1846,19 +1865,13 @@ class FunctionOptimizer:
for key in set(self.function_to_tests) | set(function_to_concolic_tests)
}
# Get a dict of file_path_to_classes of fto and helpers_of_fto
file_path_to_helper_classes = defaultdict(set)
for function_source in code_context.helper_functions:
if (
function_source.qualified_name != self.function_to_optimize.qualified_name
and "." in function_source.qualified_name
):
file_path_to_helper_classes[function_source.file_path].add(function_source.qualified_name.split(".")[0])
file_path_to_helper_classes = self.build_helper_classes_map(code_context)
baseline_result = self.establish_original_code_baseline(
code_context=code_context,
original_helper_code=original_helper_code,
file_path_to_helper_classes=file_path_to_helper_classes,
precomputed_behavioral=precomputed_behavioral,
)
console.rule()
@ -1894,6 +1907,368 @@ class FunctionOptimizer:
)
)
def display_repaired_functions(
self, generated_tests: GeneratedTestsList, reviews: list[TestFileReview], original_sources: dict[int, str]
) -> None:
"""Display repaired functions. Override in language subclasses for richer diff output."""
for review in reviews:
for f in review.functions_to_repair:
logger.info(f"Repaired: {f.function_name}")
def build_helper_classes_map(self, code_context: CodeOptimizationContext) -> dict[Path, set[str]]:
"""Build a mapping of file paths to helper class names from code context."""
file_path_to_helper_classes: dict[Path, set[str]] = defaultdict(set)
for function_source in code_context.helper_functions:
if (
function_source.qualified_name != self.function_to_optimize.qualified_name
and "." in function_source.qualified_name
):
file_path_to_helper_classes[function_source.file_path].add(function_source.qualified_name.split(".")[0])
return file_path_to_helper_classes
def run_behavioral_validation(
self,
code_context: CodeOptimizationContext,
original_helper_code: dict[Path, str],
file_path_to_helper_classes: dict[Path, set[str]],
) -> tuple[TestResults, CoverageData | None] | None:
"""Run behavioral tests only. Returns (results, coverage) or None if no tests ran."""
test_env = self.get_test_env(codeflash_loop_index=0, codeflash_test_iteration=0, codeflash_tracer_disable=1)
if self.function_to_optimize.is_async:
self.instrument_async_for_mode(TestingMode.BEHAVIOR)
try:
self.instrument_capture(file_path_to_helper_classes)
behavioral_results, coverage_results = self.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=0,
testing_time=TOTAL_LOOPING_TIME_EFFECTIVE,
enable_coverage=True,
code_context=code_context,
)
finally:
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
if isinstance(behavioral_results, TestResults) and behavioral_results:
return behavioral_results, coverage_results
return None
def review_and_repair_tests(
self,
generated_tests: GeneratedTestsList,
code_context: CodeOptimizationContext,
original_helper_code: dict[Path, str],
) -> Result[tuple[GeneratedTestsList, TestResults | None, CoverageData | None], str]:
"""Run behavioral tests, review quality per-function, repair flagged functions.
Flow (up to MAX_TEST_REPAIR_CYCLES):
run behavioral tests -> collect failures -> AI review passing functions -> repair flagged -> loop
No benchmarking runs here; only behavioral validation.
Returns (generated_tests, behavioral_results, coverage) where behavioral/coverage are
non-None when the last cycle passed with no repairs (results can be reused by baseline).
"""
file_path_to_helper_classes = self.build_helper_classes_map(code_context)
behavioral_results: TestResults | None = None
coverage_results: CoverageData | None = None
previous_repair_errors: dict[int, dict[str, str]] = {}
# Apply token limit to function source (same progressive fallback as optimization/testgen context)
function_source_for_prompt = self.function_to_optimize_source_code
if encoded_tokens_len(function_source_for_prompt) > OPTIMIZATION_CONTEXT_TOKEN_LIMIT:
logger.debug("Function source exceeds token limit for review, extracting function only")
func = self.function_to_optimize
source_lines = self.function_to_optimize_source_code.splitlines(keepends=True)
func_start = (func.doc_start_line or func.starting_line or 1) - 1
func_end = func.ending_line or len(source_lines)
function_source_for_prompt = "".join(source_lines[func_start:func_end])
max_cycles = getattr(self.args, "testgen_review_turns", None) or MAX_TEST_REPAIR_CYCLES
for cycle in range(max_cycles):
with progress_bar("Running generated tests to validate quality..."):
validation = self.run_behavioral_validation(
code_context, original_helper_code, file_path_to_helper_classes
)
if validation is None:
return Failure("Generated tests failed behavioral validation.")
behavioral_results, coverage_results = validation
failed_by_file: dict[Path, list[str]] = defaultdict(list)
for result in behavioral_results.test_results:
if (
result.test_type == TestType.GENERATED_REGRESSION
and not result.did_pass
and result.id.test_function_name
):
failed_by_file[result.file_name].append(result.id.test_fn_qualified_name())
test_failure_messages = behavioral_results.test_failures or {}
tests_for_review = []
for i, gt in enumerate(generated_tests.generated_tests):
failed_fns = failed_by_file.get(gt.behavior_file_path, [])
failure_details = {fn: test_failure_messages[fn] for fn in failed_fns if fn in test_failure_messages}
tests_for_review.append(
{
"test_source": gt.raw_generated_test_source or gt.generated_original_test_source,
"test_index": i,
"failed_test_functions": failed_fns,
"failure_messages": failure_details,
}
)
coverage_summary = ""
coverage_details: dict[str, Any] | None = None
if coverage_results and coverage_results.coverage is not None:
coverage_summary = f"{coverage_results.coverage:.1f}%"
mc = coverage_results.main_func_coverage
coverage_details = {
"coverage_percentage": coverage_results.coverage,
"threshold_percentage": COVERAGE_THRESHOLD,
"function_start_line": self.function_to_optimize.starting_line,
"main_function": {
"name": mc.name,
"coverage": mc.coverage,
"executed_lines": sorted(mc.executed_lines),
"unexecuted_lines": sorted(mc.unexecuted_lines),
"executed_branches": mc.executed_branches,
"unexecuted_branches": mc.unexecuted_branches,
},
}
dc = coverage_results.dependent_func_coverage
if dc:
coverage_details["dependent_function"] = {
"name": dc.name,
"coverage": dc.coverage,
"executed_lines": sorted(dc.executed_lines),
"unexecuted_lines": sorted(dc.unexecuted_lines),
"executed_branches": dc.executed_branches,
"unexecuted_branches": dc.unexecuted_branches,
}
console.rule()
with progress_bar("Reviewing generated tests for quality issues..."):
review_results = self.aiservice_client.review_generated_tests(
tests=tests_for_review,
function_source_code=function_source_for_prompt,
function_name=self.function_to_optimize.function_name,
trace_id=self.function_trace_id,
coverage_summary=coverage_summary,
coverage_details=coverage_details,
language=self.function_to_optimize.language,
)
all_to_repair = [r for r in review_results if r.functions_to_repair]
if not all_to_repair:
console.print(Panel("[green]All generated tests passed quality review[/green]", border_style="green"))
console.rule()
return Success((generated_tests, behavioral_results, coverage_results))
issues_tree = Tree("[bold]Quality issues found[/bold]")
total_issues = 0
for review in all_to_repair:
for f in review.functions_to_repair:
reason_str = f" — [dim]{f.reason}[/dim]" if f.reason else ""
issues_tree.add(f"[yellow]{f.function_name}[/yellow]{reason_str}")
total_issues += 1
console.print(Panel(issues_tree, title=f"Test Review (cycle {cycle + 1})", border_style="yellow"))
any_repaired = False
repaired_files = 0
# Snapshot all sources before repair so we can show diffs and revert on failure
original_sources: dict[int, str] = {
r.test_index: generated_tests.generated_tests[r.test_index].generated_original_test_source
for r in all_to_repair
}
pre_repair_snapshots: dict[int, tuple[str, str, str, str | None]] = {
r.test_index: (
generated_tests.generated_tests[r.test_index].generated_original_test_source,
generated_tests.generated_tests[r.test_index].instrumented_behavior_test_source,
generated_tests.generated_tests[r.test_index].instrumented_perf_test_source,
generated_tests.generated_tests[r.test_index].raw_generated_test_source,
)
for r in all_to_repair
}
repaired_indices: set[int] = set()
with progress_bar(f"Repairing {total_issues} flagged test function(s)..."):
for review in all_to_repair:
gt = generated_tests.generated_tests[review.test_index]
ph(
"cli-testgen-repair",
{
"test_index": review.test_index,
"cycle": cycle + 1,
"functions": [f.function_name for f in review.functions_to_repair],
},
)
test_module_path = Path(
module_name_from_file_path(gt.behavior_file_path, self.test_cfg.tests_project_rootdir)
)
repair_result = self.aiservice_client.repair_generated_tests(
test_source=gt.generated_original_test_source,
functions_to_repair=review.functions_to_repair,
function_source_code=function_source_for_prompt,
module_source_code=self.function_to_optimize_source_code,
function_to_optimize=self.function_to_optimize,
helper_function_names=[f.fully_qualified_name for f in code_context.helper_functions],
module_path=Path(self.original_module_path),
test_module_path=test_module_path,
test_framework=self.test_cfg.test_framework,
test_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
trace_id=self.function_trace_id,
language=self.function_to_optimize.language,
coverage_details=coverage_details,
previous_repair_errors=previous_repair_errors.get(review.test_index),
)
if repair_result is None:
logger.debug(f"Repair failed for test {review.test_index}, keeping original")
continue
repaired_source, behavior_source, perf_source = repair_result
raw_repaired_source = repaired_source
repaired_source, behavior_source, perf_source = (
self.language_support.process_generated_test_strings(
generated_test_source=repaired_source,
instrumented_behavior_test_source=behavior_source,
instrumented_perf_test_source=perf_source,
function_to_optimize=self.function_to_optimize,
test_path=gt.behavior_file_path,
test_cfg=self.test_cfg,
project_module_system=None,
)
)
gt.generated_original_test_source = repaired_source
gt.instrumented_behavior_test_source = behavior_source
gt.instrumented_perf_test_source = perf_source
gt.raw_generated_test_source = raw_repaired_source
gt.behavior_file_path.write_text(behavior_source, encoding="utf8")
gt.perf_file_path.write_text(perf_source, encoding="utf8")
any_repaired = True
repaired_files += 1
repaired_indices.add(review.test_index)
if not any_repaired:
logger.warning("All repair API calls failed; proceeding with unrepaired tests")
break
generated_tests = self.language_support.postprocess_generated_tests(
generated_tests,
test_framework=self.test_cfg.test_framework,
project_root=self.project_root,
source_file_path=self.function_to_optimize.file_path,
)
console.print(f" [green]Repaired {repaired_files} test file(s)[/green]")
with progress_bar("Re-validating repaired tests..."):
validation = self.run_behavioral_validation(
code_context, original_helper_code, file_path_to_helper_classes
)
if validation is None:
for idx in repaired_indices:
gt = generated_tests.generated_tests[idx]
orig_source, orig_behavior, orig_perf, orig_raw = pre_repair_snapshots[idx]
gt.generated_original_test_source = orig_source
gt.instrumented_behavior_test_source = orig_behavior
gt.instrumented_perf_test_source = orig_perf
gt.raw_generated_test_source = orig_raw
gt.behavior_file_path.write_text(orig_behavior, encoding="utf8")
gt.perf_file_path.write_text(orig_perf, encoding="utf8")
return Failure("Repaired tests failed behavioral validation.")
behavioral_results, coverage_results = validation
# Collect failing and all test function names per file
still_failing_by_file: dict[Path, set[str]] = defaultdict(set)
all_fns_by_file: dict[Path, set[str]] = defaultdict(set)
for result in behavioral_results.test_results:
if result.test_type == TestType.GENERATED_REGRESSION and result.id.test_function_name:
fn_name = result.id.test_fn_qualified_name()
all_fns_by_file[result.file_name].add(fn_name)
if not result.did_pass:
still_failing_by_file[result.file_name].add(fn_name)
reverted_indices = set()
partially_fixed_indices = set()
removed_fns_by_index: dict[int, set[str]] = {}
for idx in repaired_indices:
gt = generated_tests.generated_tests[idx]
failing_fns = still_failing_by_file.get(gt.behavior_file_path)
if not failing_fns:
continue
all_fns_in_file = all_fns_by_file.get(gt.behavior_file_path, set())
if failing_fns >= all_fns_in_file and all_fns_in_file:
# ALL functions fail → full revert to pre-repair state
orig_source, orig_behavior, orig_perf, orig_raw = pre_repair_snapshots[idx]
gt.generated_original_test_source = orig_source
gt.instrumented_behavior_test_source = orig_behavior
gt.instrumented_perf_test_source = orig_perf
gt.raw_generated_test_source = orig_raw
gt.behavior_file_path.write_text(orig_behavior, encoding="utf8")
gt.perf_file_path.write_text(orig_perf, encoding="utf8")
reverted_indices.add(idx)
else:
# Partial failure → remove only failing functions, keep passing ones
fns_to_remove = list(failing_fns)
removed_fns_by_index[idx] = set(fns_to_remove)
gt.generated_original_test_source = self.language_support.remove_test_functions(
gt.generated_original_test_source, fns_to_remove
)
gt.instrumented_behavior_test_source = self.language_support.remove_test_functions(
gt.instrumented_behavior_test_source, fns_to_remove
)
gt.instrumented_perf_test_source = self.language_support.remove_test_functions(
gt.instrumented_perf_test_source, fns_to_remove
)
if gt.raw_generated_test_source is not None:
gt.raw_generated_test_source = self.language_support.remove_test_functions(
gt.raw_generated_test_source, fns_to_remove
)
gt.behavior_file_path.write_text(gt.instrumented_behavior_test_source, encoding="utf8")
gt.perf_file_path.write_text(gt.instrumented_perf_test_source, encoding="utf8")
partially_fixed_indices.add(idx)
# Show diffs only for repairs that survived re-validation
successful_repairs = [r for r in all_to_repair if r.test_index not in reverted_indices]
if successful_repairs:
self.display_repaired_functions(generated_tests, successful_repairs, original_sources)
modified_indices = reverted_indices | partially_fixed_indices
if modified_indices:
messages = []
if reverted_indices:
messages.append(f"reverted {len(reverted_indices)} test file(s)")
if partially_fixed_indices:
messages.append(f"removed failing functions from {len(partially_fixed_indices)} test file(s)")
console.print(f" [yellow]{', '.join(messages).capitalize()} after repair[/yellow]")
# Collect error messages from failed functions so the next cycle can learn
revalidation_failures = behavioral_results.test_failures or {}
for idx in modified_indices:
gt = generated_tests.generated_tests[idx]
removed_fns = removed_fns_by_index.get(idx, set())
errors_for_file: dict[str, str] = {}
for result in behavioral_results.test_results:
if (
result.file_name == gt.behavior_file_path
and result.test_type == TestType.GENERATED_REGRESSION
and not result.did_pass
and result.id.test_function_name
):
fn_name = result.id.test_fn_qualified_name()
if fn_name not in removed_fns:
errors_for_file[fn_name] = revalidation_failures.get(fn_name, "Test failed")
if errors_for_file:
previous_repair_errors[idx] = errors_for_file
# Invalidate behavioral results since files were modified
behavioral_results = None
coverage_results = None
console.rule()
return Success((generated_tests, behavioral_results, coverage_results))
def find_and_process_best_optimization(
self,
optimizations_set: OptimizationSet,
@@ -2026,6 +2401,25 @@ class FunctionOptimizer:
code_context: CodeOptimizationContext,
function_references: str,
) -> None:
if is_subagent_mode():
subagent_log_optimization_result(
function_name=explanation.function_name,
file_path=explanation.file_path,
perf_improvement_line=explanation.perf_improvement_line,
original_runtime_ns=explanation.original_runtime_ns,
best_runtime_ns=explanation.best_runtime_ns,
raw_explanation=explanation.raw_explanation_message,
original_code=original_code_combined,
new_code=new_code_combined,
review="",
test_results=explanation.winning_behavior_test_results,
project_root=self.project_root,
)
mark_optimization_success(
trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None
)
return
coverage_message = (
original_code_baseline.coverage_results.build_message()
if original_code_baseline.coverage_results
@@ -2173,20 +2567,7 @@ class FunctionOptimizer:
self.optimization_review = opt_review_result.review
# Display the reviewer result to the user
if is_subagent_mode():
subagent_log_optimization_result(
function_name=new_explanation.function_name,
file_path=new_explanation.file_path,
perf_improvement_line=new_explanation.perf_improvement_line,
original_runtime_ns=new_explanation.original_runtime_ns,
best_runtime_ns=new_explanation.best_runtime_ns,
raw_explanation=new_explanation.raw_explanation_message,
original_code=original_code_combined,
new_code=new_code_combined,
review=opt_review_result.review,
test_results=new_explanation.winning_behavior_test_results,
)
elif opt_review_result.review:
if opt_review_result.review:
review_display = {
"high": ("[bold green]High[/bold green]", "green", "Recommended to merge"),
"medium": ("[bold yellow]Medium[/bold yellow]", "yellow", "Review recommended before merging"),
@@ -2278,6 +2659,7 @@ class FunctionOptimizer:
code_context: CodeOptimizationContext,
original_helper_code: dict[Path, str],
file_path_to_helper_classes: dict[Path, set[str]],
precomputed_behavioral: tuple[TestResults, CoverageData | None] | None = None,
) -> Result[tuple[OriginalCodeBaseline, list[str]], str]:
line_profile_results = {"timings": {}, "unit": 0, "str_out": ""}
# For the original function - run the tests and get the runtime, plus coverage
@@ -2285,34 +2667,40 @@ class FunctionOptimizer:
test_env = self.get_test_env(codeflash_loop_index=0, codeflash_test_iteration=0, codeflash_tracer_disable=1)
if self.function_to_optimize.is_async:
self.instrument_async_for_mode(TestingMode.BEHAVIOR)
if precomputed_behavioral is not None:
# Reuse behavioral results from the review cycle (no repairs were needed)
behavioral_results, coverage_results = precomputed_behavioral
logger.debug("[PIPELINE] Reusing behavioral results from test review cycle (no repairs were made)")
else:
if self.function_to_optimize.is_async:
self.instrument_async_for_mode(TestingMode.BEHAVIOR)
# Instrument codeflash capture
with progress_bar("Running tests to establish original code behavior..."):
try:
self.instrument_capture(file_path_to_helper_classes)
total_looping_time = TOTAL_LOOPING_TIME_EFFECTIVE
logger.debug(f"[PIPELINE] Establishing baseline with {len(self.test_files)} test files")
for idx, tf in enumerate(self.test_files):
logger.debug(
f"[PIPELINE] Test file {idx}: behavior={tf.instrumented_behavior_file_path}, perf={tf.benchmarking_file_path}"
# Instrument codeflash capture
with progress_bar("Running tests to establish original code behavior..."):
try:
self.instrument_capture(file_path_to_helper_classes)
logger.debug(f"[PIPELINE] Establishing baseline with {len(self.test_files)} test files")
for idx, tf in enumerate(self.test_files):
logger.debug(
f"[PIPELINE] Test file {idx}: behavior={tf.instrumented_behavior_file_path}, perf={tf.benchmarking_file_path}"
)
total_looping_time = (
TOTAL_LOOPING_TIME_EFFECTIVE / 2 if is_subagent_mode() else TOTAL_LOOPING_TIME_EFFECTIVE
)
behavioral_results, coverage_results = self.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=0,
testing_time=total_looping_time,
enable_coverage=True,
code_context=code_context,
)
finally:
# Remove codeflash capture
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
behavioral_results, coverage_results = self.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=0,
testing_time=total_looping_time,
enable_coverage=True,
code_context=code_context,
)
finally:
# Remove codeflash capture
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
if not behavioral_results:
logger.warning(
f"force_lsp|Couldn't run any tests for original function {self.function_to_optimize.function_name}. Skipping optimization."
@@ -2345,14 +2733,16 @@ class FunctionOptimizer:
self.instrument_async_for_mode(TestingMode.PERFORMANCE)
try:
subagent = is_subagent_mode()
benchmarking_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.PERFORMANCE,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=0,
testing_time=total_looping_time,
testing_time=TOTAL_LOOPING_TIME_EFFECTIVE,
enable_coverage=False,
code_context=code_context,
**({"pytest_min_loops": 3, "pytest_max_loops": 100} if subagent else {}),
)
logger.debug(f"[BENCHMARK-DONE] Got {len(benchmarking_results.test_results)} benchmark results")
finally:
@@ -2504,13 +2894,15 @@ class FunctionOptimizer:
try:
self.instrument_capture(file_path_to_helper_classes)
total_looping_time = TOTAL_LOOPING_TIME_EFFECTIVE
total_looping_time = (
TOTAL_LOOPING_TIME_EFFECTIVE / 2 if is_subagent_mode() else TOTAL_LOOPING_TIME_EFFECTIVE
)
candidate_behavior_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=optimization_candidate_index,
testing_time=total_looping_time,
testing_time=TOTAL_LOOPING_TIME_EFFECTIVE,
enable_coverage=False,
)
finally:
@@ -2545,13 +2937,15 @@ class FunctionOptimizer:
self.instrument_async_for_mode(TestingMode.PERFORMANCE)
try:
subagent = is_subagent_mode()
candidate_benchmarking_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.PERFORMANCE,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=optimization_candidate_index,
testing_time=total_looping_time,
testing_time=TOTAL_LOOPING_TIME_EFFECTIVE,
enable_coverage=False,
**({"pytest_min_loops": 3, "pytest_max_loops": 100} if subagent else {}),
)
finally:
if self.function_to_optimize.is_async:


@@ -16,6 +16,7 @@ from codeflash.code_utils.config_consts import (
TOTAL_LOOPING_TIME_EFFECTIVE,
)
from codeflash.either import Failure, Success
from codeflash.languages.function_optimizer import FunctionOptimizer
from codeflash.models.models import (
CodeOptimizationContext,
CodeString,
@@ -24,7 +25,6 @@ from codeflash.models.models import (
TestingMode,
TestResults,
)
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.equivalence import compare_test_results
if TYPE_CHECKING:


@@ -632,6 +632,6 @@ def format_line_profile_results(
if line_contents is None:
line_contents = results.get("line_contents", {})
from codeflash.verification.parse_line_profile_test_output import show_text_non_python
from codeflash.languages.python.parse_line_profile_test_output import show_text_non_python
return show_text_non_python(results, line_contents)


@@ -15,6 +15,7 @@ from codeflash.code_utils.config_consts import (
TOTAL_LOOPING_TIME_EFFECTIVE,
)
from codeflash.either import Failure, Success
from codeflash.languages.function_optimizer import FunctionOptimizer
from codeflash.models.models import (
CodeOptimizationContext,
CodeString,
@@ -23,7 +24,6 @@ from codeflash.models.models import (
TestingMode,
TestResults,
)
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.equivalence import compare_test_results
if TYPE_CHECKING:


@@ -1,7 +1,6 @@
"""JavaScript/TypeScript code normalizer using tree-sitter.
Not currently wired into JavaScriptSupport.normalize_code; kept as a
ready-to-use upgrade path when AST-based JS deduplication is needed.
Wired into JavaScriptSupport.normalize_code for AST-based JS deduplication.
The old CodeNormalizer ABC (deleted from base.py) is preserved below for reference.
"""
@@ -236,8 +235,7 @@ def normalize_js_code(code: str, typescript: bool = False) -> str:
Uses tree-sitter to parse and normalize variable names. Falls back to
basic comment/whitespace stripping if tree-sitter is unavailable or parsing fails.
Not currently wired into JavaScriptSupport.normalize_code; kept as a
ready-to-use upgrade path when AST-based JS deduplication is needed.
Wired into JavaScriptSupport.normalize_code for AST-based JS deduplication.
"""
try:
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage


@@ -1225,20 +1225,29 @@ class JavaScriptSupport:
return node
# Check function declarations
if node.type in ("function_declaration", "function"):
if node.type in (
"function_declaration",
"function",
"generator_function_declaration",
"generator_function",
):
name_node = node.child_by_field_name("name")
if name_node:
name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
if name == target_name:
return node
# Check arrow functions assigned to variables
if node.type == "lexical_declaration":
# Check arrow functions and function expressions assigned to variables
if node.type in ("lexical_declaration", "variable_declaration"):
for child in node.children:
if child.type == "variable_declarator":
name_node = child.child_by_field_name("name")
value_node = child.child_by_field_name("value")
if name_node and value_node and value_node.type == "arrow_function":
if (
name_node
and value_node
and value_node.type in ("arrow_function", "function_expression", "generator_function")
):
name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
if name == target_name:
return value_node
@@ -1253,6 +1262,7 @@ class JavaScriptSupport:
func_node = find_function_node(tree.root_node, function_name)
if not func_node:
logger.debug("Could not find function '%s' in optimized code for body extraction", function_name)
return None
# Find the body node
@@ -1313,14 +1323,21 @@ class JavaScriptSupport:
if name == target_name and (node.start_point[0] + 1) == target_line:
return node
if node.type == "lexical_declaration":
if node.type in ("lexical_declaration", "variable_declaration"):
for child in node.children:
if child.type == "variable_declarator":
name_node = child.child_by_field_name("name")
value_node = child.child_by_field_name("value")
if name_node and value_node and value_node.type == "arrow_function":
if (
name_node
and value_node
and value_node.type in ("arrow_function", "function_expression", "generator_function")
):
name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
if name == target_name and (node.start_point[0] + 1) == target_line:
if name == target_name and (
(node.start_point[0] + 1) == target_line
or (value_node.start_point[0] + 1) == target_line
):
return value_node
for child in node.children:
@@ -1704,26 +1721,14 @@ class JavaScriptSupport:
return False
def normalize_code(self, source: str) -> str:
"""Normalize JavaScript code for deduplication.
"""Normalize JavaScript code for deduplication using tree-sitter."""
from codeflash.languages.javascript.normalizer import normalize_js_code
Removes comments and normalizes whitespace.
Args:
source: Source code to normalize.
Returns:
Normalized source code.
"""
# Simple normalization: remove extra whitespace
# A full implementation would use tree-sitter to strip comments
lines = source.splitlines()
normalized_lines = []
for line in lines:
stripped = line.strip()
if stripped and not stripped.startswith("//"):
normalized_lines.append(stripped)
return "\n".join(normalized_lines)
try:
is_ts = self.treesitter_language == TreeSitterLanguage.TYPESCRIPT
return normalize_js_code(source, typescript=is_ts)
except Exception:
return source
def generate_concolic_tests(
self, test_cfg: Any, project_root: Any, function_to_optimize: Any, function_to_optimize_ast: Any


@@ -339,6 +339,73 @@ module.exports = {{
return None
def _create_runtime_jest_config(base_config_path: Path | None, project_root: Path, test_dirs: set[str]) -> Path | None:
"""Create a runtime Jest config that includes test directories in roots and testMatch.
This is needed because test files generated by codeflash may be placed
outside the project root (e.g., in a monorepo where the source file lives
in a subpackage but tests are generated at the repo root). Jest requires
test files to be within configured ``roots`` and to match ``testMatch``
patterns (which typically use ``<rootDir>``). Since ``roots`` set via CLI
can be overridden by config, and ``testMatch`` patterns using ``<rootDir>``
won't match files outside the project root, we must create a wrapper config.
Args:
base_config_path: Path to the base Jest config to extend, or None.
project_root: The project root directory (where package.json lives).
test_dirs: Set of absolute directory paths containing test files.
Returns:
Path to the created runtime config file.
"""
is_esm = _is_esm_project(project_root)
config_ext = ".cjs" if is_esm else ".js"
if base_config_path:
config_dir = base_config_path.parent
else:
config_dir = project_root
runtime_config_path = config_dir / f"jest.codeflash.runtime.config{config_ext}"
test_dirs_js = ", ".join(f"'{d}'" for d in sorted(test_dirs))
if base_config_path:
require_path = f"./{base_config_path.name}"
config_content = f"""// Auto-generated by codeflash - runtime config with test roots
const baseConfig = require('{require_path}');
module.exports = {{
...baseConfig,
roots: [
...(baseConfig.roots || [__dirname]),
{test_dirs_js},
],
testMatch: ['**/*.test.ts', '**/*.test.js', '**/*.test.tsx', '**/*.test.jsx'],
}};
"""
else:
config_content = f"""// Auto-generated by codeflash - runtime config with test roots
module.exports = {{
roots: ['{project_root}', {test_dirs_js}],
testMatch: ['**/*.test.ts', '**/*.test.js', '**/*.test.tsx', '**/*.test.jsx'],
}};
"""
try:
runtime_config_path.write_text(config_content, encoding="utf-8")
_created_config_files.add(runtime_config_path)
logger.debug(f"Created runtime Jest config with test roots: {runtime_config_path}")
except Exception as e:
logger.warning(f"Failed to create runtime Jest config: {e}")
# Fall back to base config
if base_config_path:
return base_config_path
return None
return runtime_config_path
def _get_jest_config_for_project(project_root: Path) -> Path | None:
"""Get the appropriate Jest config for the project.
@@ -712,6 +779,17 @@ def run_jest_behavioral_tests(
# Add Jest config if found - needed for TypeScript transformation
# Uses codeflash-compatible config if project has bundler moduleResolution
jest_config = _get_jest_config_for_project(effective_cwd)
# If test files are outside the project root, create a runtime wrapper config
# that adds their directories to Jest's `roots` and overrides `testMatch`.
# This is necessary because Jest's testMatch patterns use <rootDir> which
# resolves to the config file's directory, excluding external test files.
if test_files:
resolved_root = effective_cwd.resolve()
test_dirs = {str(Path(f).resolve().parent) for f in test_files}
if any(not Path(d).is_relative_to(resolved_root) for d in test_dirs):
jest_config = _create_runtime_jest_config(jest_config, effective_cwd, test_dirs)
if jest_config:
jest_cmd.append(f"--config={jest_config}")
@@ -723,12 +801,6 @@ def run_jest_behavioral_tests(
jest_cmd.append("--runTestsByPath")
resolved_test_files = [str(Path(f).resolve()) for f in test_files]
jest_cmd.extend(resolved_test_files)
# Add --roots to include directories containing test files
# This is needed because some projects configure Jest with restricted roots
# (e.g., roots: ["<rootDir>/src"]) which excludes the test directory
test_dirs = {str(Path(f).resolve().parent) for f in test_files}
for test_dir in sorted(test_dirs):
jest_cmd.extend(["--roots", test_dir])
if timeout:
jest_cmd.append(f"--testTimeout={timeout * 1000}") # Jest uses milliseconds
@@ -962,6 +1034,14 @@ def run_jest_benchmarking_tests(
# Add Jest config if found - needed for TypeScript transformation
# Uses codeflash-compatible config if project has bundler moduleResolution
jest_config = _get_jest_config_for_project(effective_cwd)
# If test files are outside the project root, create a runtime wrapper config
if test_files:
resolved_root = effective_cwd.resolve()
test_dirs = {str(Path(f).resolve().parent) for f in test_files}
if any(not Path(d).is_relative_to(resolved_root) for d in test_dirs):
jest_config = _create_runtime_jest_config(jest_config, effective_cwd, test_dirs)
if jest_config:
jest_cmd.append(f"--config={jest_config}")
@@ -969,10 +1049,6 @@ def run_jest_benchmarking_tests(
jest_cmd.append("--runTestsByPath")
resolved_test_files = [str(Path(f).resolve()) for f in test_files]
jest_cmd.extend(resolved_test_files)
# Add --roots to include directories containing test files
test_dirs = {str(Path(f).resolve().parent) for f in test_files}
for test_dir in sorted(test_dirs):
jest_cmd.extend(["--roots", test_dir])
if timeout:
jest_cmd.append(f"--testTimeout={timeout * 1000}")
@@ -1127,6 +1203,14 @@ def run_jest_line_profile_tests(
# Add Jest config if found - needed for TypeScript transformation
# Uses codeflash-compatible config if project has bundler moduleResolution
jest_config = _get_jest_config_for_project(effective_cwd)
# If test files are outside the project root, create a runtime wrapper config
if test_files:
resolved_root = effective_cwd.resolve()
test_dirs = {str(Path(f).resolve().parent) for f in test_files}
if any(not Path(d).is_relative_to(resolved_root) for d in test_dirs):
jest_config = _create_runtime_jest_config(jest_config, effective_cwd, test_dirs)
if jest_config:
jest_cmd.append(f"--config={jest_config}")
@@ -1134,10 +1218,6 @@ def run_jest_line_profile_tests(
jest_cmd.append("--runTestsByPath")
resolved_test_files = [str(Path(f).resolve()) for f in test_files]
jest_cmd.extend(resolved_test_files)
# Add --roots to include directories containing test files
test_dirs = {str(Path(f).resolve().parent) for f in test_files}
for test_dir in sorted(test_dirs):
jest_cmd.extend(["--roots", test_dir])
if timeout:
jest_cmd.append(f"--testTimeout={timeout * 1000}")


@@ -37,7 +37,6 @@ from codeflash.models.models import (
CodeStringsMarkdown,
FunctionSource,
)
from codeflash.optimization.function_context import belongs_to_function_qualified
if TYPE_CHECKING:
from jedi.api.classes import Name
@@ -123,6 +122,13 @@ def get_code_optimization_context(
code_context_type=CodeContextType.READ_WRITABLE,
)
# Ensure the target file is first in the code blocks so the LLM knows which file to optimize
target_relative = function_to_optimize.file_path.resolve().relative_to(project_root_path.resolve())
target_blocks = [cs for cs in final_read_writable_code.code_strings if cs.file_path == target_relative]
other_blocks = [cs for cs in final_read_writable_code.code_strings if cs.file_path != target_relative]
if target_blocks:
final_read_writable_code.code_strings = target_blocks + other_blocks
read_only_code_markdown = extract_code_markdown_context_from_files(
helpers_of_fto_dict,
helpers_of_helpers_dict,
@@ -432,6 +438,7 @@ def get_function_sources_from_jedi(
fully_qualified_name=fqn,
only_function_name=func_name,
source_code=definition.get_line_code(),
definition_type=definition.type,
)
file_path_to_function_source[definition_path].add(function_source)
function_source_list.append(function_source)
@@ -1462,3 +1469,41 @@ def prune_cst(
include_init_dunder=include_init_dunder,
),
)
def belongs_to_method(name: Name, class_name: str, method_name: str) -> bool:
"""Check if the given name belongs to the specified method."""
return belongs_to_function(name, method_name) and belongs_to_class(name, class_name)
def belongs_to_function(name: Name, function_name: str) -> bool:
"""Check if the given jedi Name is a direct child of the specified function."""
if name.name == function_name: # Handles function definition and recursive function calls
return False
if (name := name.parent()) and name.type == "function":
return bool(name.name == function_name)
return False
def belongs_to_class(name: Name, class_name: str) -> bool:
"""Check if given jedi Name is a direct child of the specified class."""
while name := name.parent():
if name.type == "class":
return bool(name.name == class_name)
return False
def belongs_to_function_qualified(name: Name, qualified_function_name: str) -> bool:
"""Check if the given jedi Name is a direct child of the specified function, matched by qualified function name."""
try:
if (
name.full_name.startswith(name.module_name)
and get_qualified_name(name.module_name, name.full_name) == qualified_function_name
):
# Handles function definition and recursive function calls
return False
if (name := name.parent()) and name.type == "function":
return get_qualified_name(name.module_name, name.full_name) == qualified_function_name
return False
except ValueError:
return False


@@ -774,7 +774,7 @@ def detect_unused_helper_functions(
# First, analyze imports to build a mapping of imported names to their original qualified names
imported_names_map = _analyze_imports_in_optimized_code(optimized_ast, code_context)
# Extract all function calls in the entrypoint function
# Extract all function calls and attribute references in the entrypoint function
called_function_names = {function_to_optimize.function_name}
for node in ast.walk(entrypoint_function_ast):
if isinstance(node, ast.Call):
@@ -795,7 +796,6 @@ def detect_unused_helper_functions(
# self.method_name() -> add both method_name and ClassName.method_name
called_function_names.add(attr_name)
# For class methods, also add the qualified name
# For class methods, also add the qualified name
if hasattr(function_to_optimize, "parents") and function_to_optimize.parents:
class_name = function_to_optimize.parents[0].name
called_function_names.add(f"{class_name}.{attr_name}")
@@ -808,9 +807,25 @@ def detect_unused_helper_functions(
if mapped_names:
called_function_names.update(mapped_names)
# Handle nested attribute access like obj.attr.method()
# Handle nested attribute access like obj.attr.method()
else:
called_function_names.add(node.func.attr)
elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
# Attribute reference without call: e.g. self._parse1 = self._parse_literal
# This covers methods used as callbacks, stored in variables, passed as arguments, etc.
attr_name = node.attr
value_id = node.value.id
if value_id == "self":
called_function_names.add(attr_name)
if hasattr(function_to_optimize, "parents") and function_to_optimize.parents:
class_name = function_to_optimize.parents[0].name
called_function_names.add(f"{class_name}.{attr_name}")
else:
called_function_names.add(attr_name)
full_ref = f"{value_id}.{attr_name}"
called_function_names.add(full_ref)
mapped_names = imported_names_map.get(full_ref)
if mapped_names:
called_function_names.update(mapped_names)
logger.debug(f"Functions called in optimized entrypoint: {called_function_names}")
logger.debug(f"Imported names mapping: {imported_names_map}")


@@ -4,9 +4,13 @@ import ast
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.cli_cmds.console import console, logger
from rich.syntax import Syntax
from codeflash.cli_cmds.console import code_print, console, logger
from codeflash.code_utils.code_utils import unified_diff_strings
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE
from codeflash.either import Failure, Success
from codeflash.languages.function_optimizer import FunctionOptimizer
from codeflash.languages.python.context.unused_definition_remover import (
detect_unused_helper_functions,
revert_unused_helper_functions,
@@ -19,7 +23,6 @@ from codeflash.languages.python.static_analysis.code_replacer import (
)
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator
from codeflash.models.models import TestingMode, TestResults
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.parse_test_output import calculate_function_throughput_from_test_results
if TYPE_CHECKING:
@@ -33,8 +36,10 @@ if TYPE_CHECKING:
CodeStringsMarkdown,
ConcurrencyMetrics,
CoverageData,
GeneratedTestsList,
OriginalCodeBaseline,
TestDiff,
TestFileReview,
)
@@ -82,10 +87,59 @@ class PythonFunctionOptimizer(FunctionOptimizer):
return original_conftest_content
def instrument_capture(self, file_path_to_helper_classes: dict[Path, set[str]]) -> None:
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
from codeflash.languages.python.instrument_codeflash_capture import instrument_codeflash_capture
instrument_codeflash_capture(self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root)
def display_repaired_functions(
self, generated_tests: GeneratedTestsList, reviews: list[TestFileReview], original_sources: dict[int, str]
) -> None:
"""Display per-function diffs of repaired tests using libcst."""
import libcst as cst
def extract_functions(source: str, names: set[str]) -> dict[str, str]:
"""Extract functions by name from top-level and class bodies."""
try:
tree = cst.parse_module(source)
except cst.ParserSyntaxError:
logger.debug("Failed to parse source for diff display", exc_info=True)
return {}
result: dict[str, str] = {}
for node in tree.body:
if isinstance(node, cst.FunctionDef) and node.name.value in names:
result[node.name.value] = tree.code_for_node(node)
elif isinstance(node, cst.ClassDef):
for child in node.body.body:
if isinstance(child, cst.FunctionDef) and child.name.value in names:
result[child.name.value] = tree.code_for_node(child)
return result
for review in reviews:
gt = generated_tests.generated_tests[review.test_index]
repaired_names = {f.function_name for f in review.functions_to_repair}
new_source = gt.generated_original_test_source
old_source = original_sources.get(review.test_index, "")
old_funcs = extract_functions(old_source, repaired_names)
new_funcs = extract_functions(new_source, repaired_names)
for name in repaired_names:
old_func = old_funcs.get(name, "")
new_func = new_funcs.get(name, "")
if not new_func:
continue
console.rule()
if old_func and old_func != new_func:
diff = unified_diff_strings(
old_func, new_func, fromfile=f"{name} (before)", tofile=f"{name} (after)"
)
if diff:
logger.info(f"Repaired: {name}")
console.print(Syntax(diff, "diff", theme="monokai"))
continue
logger.info(f"Repaired: {name}")
code_print(new_func, language=self.function_to_optimize.language)
def should_check_coverage(self) -> bool:
return True
@ -127,7 +181,7 @@ class PythonFunctionOptimizer(FunctionOptimizer):
def parse_line_profile_test_results(
self, line_profiler_output_file: Path | None
) -> tuple[TestResults | dict, CoverageData | None]:
from codeflash.verification.parse_line_profile_test_output import parse_line_profile_results
from codeflash.languages.python.parse_line_profile_test_output import parse_line_profile_results
return parse_line_profile_results(line_profiler_output_file=line_profiler_output_file)

View file

@ -102,6 +102,22 @@ class InitDecorator(ast.NodeTransformer):
self._init_kwarg = ast.arg(arg="kwargs")
self._init_self_arg = ast.arg(arg="self", annotation=None)
# Precreate commonly reused AST fragments for classes that lack __init__
# Create the super().__init__(*args, **kwargs) Expr (reuse prebuilt pieces)
self._super_call_expr = ast.Expr(
value=ast.Call(func=self._super_func, args=[self._super_starred], keywords=[self._super_kwarg])
)
# Create function arguments: self, *args, **kwargs (reuse arg nodes)
self._init_arguments = ast.arguments(
posonlyargs=[],
args=[self._init_self_arg],
vararg=self._init_vararg,
kwonlyargs=[],
kw_defaults=[],
kwarg=self._init_kwarg,
defaults=[],
)
def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom:
# Check if our import already exists
if node.module == "codeflash.verification.codeflash_capture" and any(
@ -156,27 +172,42 @@ class InitDecorator(ast.NodeTransformer):
self.inserted_decorator = True
if not has_init:
# Skip dataclasses — their __init__ is auto-generated at class creation time and isn't in the AST.
# The synthetic __init__ with super().__init__(*args, **kwargs) overrides it and fails because
# object.__init__() doesn't accept the dataclass field kwargs.
# TODO: support by saving a reference to the generated __init__ before overriding, e.g.
# _orig_init = ClassName.__init__; then calling _orig_init(self, *args, **kwargs) in the wrapper
for dec in node.decorator_list:
dec_name = self._expr_name(dec)
if dec_name is not None and dec_name.endswith("dataclass"):
return node
# Skip NamedTuples — their __init__ is synthesized and cannot be overwritten.
for base in node.bases:
base_name = self._expr_name(base)
if base_name is not None and base_name.endswith("NamedTuple"):
return node
# Create super().__init__(*args, **kwargs) call (use prebuilt AST fragments)
super_call = ast.Expr(
value=ast.Call(func=self._super_func, args=[self._super_starred], keywords=[self._super_kwarg])
)
# Create function arguments: self, *args, **kwargs (reuse arg nodes)
arguments = ast.arguments(
posonlyargs=[],
args=[self._init_self_arg],
vararg=self._init_vararg,
kwonlyargs=[],
kw_defaults=[],
kwarg=self._init_kwarg,
defaults=[],
)
super_call = self._super_call_expr
# Create the complete function using prebuilt arguments/body but attach the class-specific decorator
# Create the complete function
init_func = ast.FunctionDef(
name="__init__", args=arguments, body=[super_call], decorator_list=[decorator], returns=None
name="__init__", args=self._init_arguments, body=[super_call], decorator_list=[decorator], returns=None
)
node.body.insert(0, init_func)
self.inserted_decorator = True
return node
def _expr_name(self, node: ast.AST) -> str | None:
if isinstance(node, ast.Name):
return node.id
if isinstance(node, ast.Call):
return self._expr_name(node.func)
if isinstance(node, ast.Attribute):
parent = self._expr_name(node.value)
return f"{parent}.{node.attr}" if parent else node.attr
return None
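The dataclass skip above guards against a concrete failure mode. A minimal standalone demonstration (the `Point` class is invented for illustration) of why a synthetic delegating `__init__` breaks a dataclass:

```python
from dataclasses import dataclass


@dataclass
class Point:
    x: int


def synthetic_init(self, *args, **kwargs):
    # what the injected __init__ would do: delegate everything upward
    super(Point, self).__init__(*args, **kwargs)


# overwrite the dataclass-generated __init__, as the transformer would
Point.__init__ = synthetic_init

try:
    Point(x=1)
except TypeError:
    # object.__init__ rejects the field kwargs, so construction fails
    print("TypeError: dataclass fields reach object.__init__")
```

This is why the transformer returns the node untouched for `@dataclass` (and, for the analogous synthesized-`__init__` reason, `NamedTuple`) classes rather than inserting a wrapper.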

View file

@ -79,7 +79,7 @@ def _is_valid_definition(definition: Name, caller_qualified_name: str, project_r
return False
try:
from codeflash.optimization.function_context import belongs_to_function_qualified
from codeflash.languages.python.context.code_context_extractor import belongs_to_function_qualified
if belongs_to_function_qualified(definition, caller_qualified_name):
return False

View file

@ -714,7 +714,14 @@ def add_needed_imports_from_module(
)
)
try:
cst.parse_module(src_module_code).visit(gatherer)
src_module = cst.parse_module(src_module_code)
# Exclude function/class bodies so GatherImportsVisitor only sees module-level imports.
# Nested imports (inside functions) are part of function logic and must not be
# scheduled for add/remove — RemoveImportsVisitor would strip them as "unused".
module_level_only = src_module.with_changes(
body=[stmt for stmt in src_module.body if not isinstance(stmt, (cst.FunctionDef, cst.ClassDef))]
)
module_level_only.visit(gatherer)
except Exception as e:
logger.error(f"Error parsing source module code: {e}")
return dst_code_fallback

View file

@ -32,7 +32,7 @@ def is_valid_concolic_test(test_code: str, project_root: Optional[str] = None) -
try:
result = subprocess.run(
[SAFE_SYS_EXECUTABLE, "-m", "pytest", "--collect-only", "-q", temp_path.as_posix()],
[SAFE_SYS_EXECUTABLE, "-m", "pytest", "-x", "-q", temp_path.as_posix()],
check=False,
capture_output=True,
text=True,

View file

@ -731,20 +731,63 @@ class PythonSupport:
"""
import libcst as cst
bare_names: set[str] = set()
qualified_names: set[str] = set()
for name in functions_to_remove:
if "." in name:
qualified_names.add(name)
else:
bare_names.add(name)
class TestFunctionRemover(cst.CSTTransformer):
def __init__(self, names_to_remove: list[str]) -> None:
self.names_to_remove = set(names_to_remove)
def __init__(self) -> None:
self.class_stack: list[str] = []
self.emptied_classes: set[str] = set()
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
self.class_stack.append(node.name.value)
return True
def leave_ClassDef(
self, original_node: cst.ClassDef, updated_node: cst.ClassDef
) -> cst.ClassDef | cst.RemovalSentinel:
class_name = self.class_stack.pop()
if class_name in self.emptied_classes:
self.emptied_classes.discard(class_name)
body = updated_node.body
if isinstance(body, cst.IndentedBlock):
has_meaningful_body = any(
not (
isinstance(s, cst.SimpleStatementLine)
and len(s.body) == 1
and isinstance(s.body[0], (cst.Pass, cst.Expr))
and (
isinstance(s.body[0], cst.Pass)
or (isinstance(s.body[0].value, (cst.SimpleString, cst.ConcatenatedString)))
)
)
for s in body.body
)
if not has_meaningful_body:
return cst.RemovalSentinel.REMOVE
return updated_node
def leave_FunctionDef(
self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
) -> cst.FunctionDef | cst.RemovalSentinel:
if original_node.name.value in self.names_to_remove:
fn_name = original_node.name.value
if fn_name in bare_names and not self.class_stack:
return cst.RemovalSentinel.REMOVE
if self.class_stack:
qualified = f"{self.class_stack[-1]}.{fn_name}"
if qualified in qualified_names:
self.emptied_classes.add(self.class_stack[-1])
return cst.RemovalSentinel.REMOVE
return updated_node
try:
tree = cst.parse_module(test_source)
modified = tree.visit(TestFunctionRemover(functions_to_remove))
modified = tree.visit(TestFunctionRemover())
return modified.code
except Exception:
return test_source
@ -1027,8 +1070,8 @@ class PythonSupport:
from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE
from codeflash.languages.python.static_analysis.coverage_utils import prepare_coverage_files
from codeflash.languages.python.test_runner import execute_test_subprocess
from codeflash.models.models import TestType
from codeflash.verification.test_runner import execute_test_subprocess
blocklisted_plugins = ["benchmark", "codspeed", "xdist", "sugar"]
@ -1132,7 +1175,7 @@ class PythonSupport:
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
from codeflash.verification.test_runner import execute_test_subprocess
from codeflash.languages.python.test_runner import execute_test_subprocess
blocklisted_plugins = ["codspeed", "cov", "benchmark", "profiling", "xdist", "sugar"]
@ -1176,7 +1219,7 @@ class PythonSupport:
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE
from codeflash.verification.test_runner import execute_test_subprocess
from codeflash.languages.python.test_runner import execute_test_subprocess
blocklisted_plugins = ["codspeed", "cov", "benchmark", "profiling", "xdist", "sugar"]

View file

@ -11,13 +11,13 @@ from typing import TYPE_CHECKING, Optional, Union
from codeflash.api.cfapi import get_codeflash_api_key, get_user_id
from codeflash.cli_cmds.cli import process_pyproject_config
from codeflash.cli_cmds.cmd_init import (
from codeflash.cli_cmds.cmd_init import create_find_common_tags_file
from codeflash.cli_cmds.init_config import (
CommonSections,
VsCodeSetupInfo,
config_found,
configure_pyproject_toml,
create_empty_pyproject_toml,
create_find_common_tags_file,
get_formatter_cmds,
get_suggestions,
get_valid_subdirs,

View file

@ -18,14 +18,11 @@ if "--subagent" in sys.argv:
warnings.filterwarnings("ignore")
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO, ask_run_end_to_end_test
from codeflash.cli_cmds.console import paneled_text
from codeflash.code_utils import env_utils
from codeflash.code_utils.checkpoint import ask_should_use_checkpoint_get_functions
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.code_utils.version_check import check_for_newer_minor_version
from codeflash.telemetry import posthog_cf
from codeflash.telemetry.sentry import init_sentry
if TYPE_CHECKING:
from argparse import Namespace
@ -33,6 +30,9 @@ if TYPE_CHECKING:
def main() -> None:
"""Entry point for the codeflash command-line interface."""
from codeflash.telemetry import posthog_cf
from codeflash.telemetry.sentry import init_sentry
args = parse_args()
print_codeflash_banner()
@ -46,11 +46,30 @@ def main() -> None:
disable_telemetry = pyproject_config.get("disable_telemetry", False)
init_sentry(enabled=not disable_telemetry, exclude_errors=True)
posthog_cf.initialize_posthog(enabled=not disable_telemetry)
args.func()
if args.command == "init":
from codeflash.cli_cmds.cmd_init import init_codeflash
init_codeflash()
elif args.command == "init-actions":
from codeflash.cli_cmds.github_workflow import install_github_actions
install_github_actions()
elif args.command == "vscode-install":
from codeflash.cli_cmds.extension import install_vscode_extension
install_vscode_extension()
elif args.command == "optimize":
from codeflash.tracer import main as tracer_main
tracer_main(args)
elif args.verify_setup:
args = process_pyproject_config(args)
init_sentry(enabled=not args.disable_telemetry, exclude_errors=True)
posthog_cf.initialize_posthog(enabled=not args.disable_telemetry)
from codeflash.cli_cmds.cmd_init import ask_run_end_to_end_test
ask_run_end_to_end_test(args)
else:
# Check for first-run experience (no config exists)
@ -117,6 +136,13 @@ def _handle_config_loading(args: Namespace) -> Namespace | None:
def print_codeflash_banner() -> None:
"""Print the Codeflash banner with the branded styling.
Renders the Codeflash ASCII logo inside a non-expanding panel titled with
https://codeflash.ai, using bold gold text for visual emphasis.
"""
from codeflash.cli_cmds.console_constants import CODEFLASH_LOGO
paneled_text(
CODEFLASH_LOGO, panel_args={"title": "https://codeflash.ai", "expand": False}, text_args={"style": "bold gold3"}
)

View file

@ -115,6 +115,16 @@ class OptimizationReviewResult(NamedTuple):
explanation: str
class FunctionRepairInfo(NamedTuple):
function_name: str
reason: str
class TestFileReview(NamedTuple):
test_index: int
functions_to_repair: list[FunctionRepairInfo]
# If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully
# qualified name of the method is foo.eggs.Ham.spam, its qualified name is Ham.spam, and its name is spam. The full name
# of the module is foo.eggs.
@ -405,6 +415,7 @@ class GeneratedTests(BaseModel):
generated_original_test_source: str
instrumented_behavior_test_source: str
instrumented_perf_test_source: str
raw_generated_test_source: str | None = None
behavior_file_path: Path
perf_file_path: Path

View file

@ -1,46 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from codeflash.code_utils.code_utils import get_qualified_name
if TYPE_CHECKING:
from jedi.api.classes import Name
def belongs_to_method(name: Name, class_name: str, method_name: str) -> bool:
"""Check if the given name belongs to the specified method."""
return belongs_to_function(name, method_name) and belongs_to_class(name, class_name)
def belongs_to_function(name: Name, function_name: str) -> bool:
"""Check if the given jedi Name is a direct child of the specified function."""
if name.name == function_name: # Handles function definition and recursive function calls
return False
if (name := name.parent()) and name.type == "function":
return name.name == function_name
return False
def belongs_to_class(name: Name, class_name: str) -> bool:
"""Check if given jedi Name is a direct child of the specified class."""
while name := name.parent():
if name.type == "class":
return name.name == class_name
return False
def belongs_to_function_qualified(name: Name, qualified_function_name: str) -> bool:
"""Check if the given jedi Name is a direct child of the specified function, matched by qualified function name."""
try:
if (
name.full_name.startswith(name.module_name)
and get_qualified_name(name.module_name, name.full_name) == qualified_function_name
):
# Handles function definition and recursive function calls
return False
if (name := name.parent()) and name.type == "function":
return get_qualified_name(name.module_name, name.full_name) == qualified_function_name
return False
except ValueError:
return False

View file

@ -42,8 +42,8 @@ if TYPE_CHECKING:
from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import DependencyResolver
from codeflash.languages.function_optimizer import FunctionOptimizer
from codeflash.models.models import BenchmarkKey, FunctionCalledInTest, ValidCode
from codeflash.optimization.function_optimizer import FunctionOptimizer
class Optimizer:
@ -519,9 +519,10 @@ class Optimizer:
validated_original_code, _original_module_ast = prepared_modules[original_module_path]
function_iterator_count = i + 1
line_suffix = f":{function_to_optimize.starting_line}" if function_to_optimize.starting_line else ""
logger.info(
f"Optimizing function {function_iterator_count} of {len(globally_ranked_functions)}: "
f"{function_to_optimize.qualified_name} (in {original_module_path.name})"
f"{function_to_optimize.qualified_name} (in {original_module_path}{line_suffix})"
)
console.rule()
function_optimizer = None

View file

@ -235,8 +235,8 @@ def main(args: Namespace | None = None) -> ArgumentParser:
result_pickle_file_path.unlink(missing_ok=True)
if not parsed_args.trace_only and replay_test_paths:
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO
from codeflash.cli_cmds.console import paneled_text
from codeflash.cli_cmds.console_constants import CODEFLASH_LOGO
from codeflash.languages import set_current_language
from codeflash.languages.base import Language
from codeflash.telemetry import posthog_cf

View file

@ -1,14 +1,19 @@
import _thread
import array
import ast
import datetime
import decimal
import enum
import io
import itertools
import math
import re
import sqlite3
import threading
import types
import warnings
import weakref
import xml.etree.ElementTree as ET
from collections import ChainMap, OrderedDict, deque
from importlib.util import find_spec
from typing import Any, Optional
@ -629,6 +634,21 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
if type(orig) in {types.BuiltinFunctionType, types.BuiltinMethodType}:
return new == orig
if isinstance(orig, ET.Element):
return isinstance(new, ET.Element) and ET.tostring(orig) == ET.tostring(new)
if isinstance(
orig,
(
_thread.LockType,
_thread.RLock,
threading.Event,
threading.Condition,
sqlite3.Connection,
sqlite3.Cursor,
io.IOBase,
),
):
return type(orig) is type(new)
if str(type(orig)) == "<class 'object'>":
return True
# TODO : Add other types here
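The type-only fallback for these unpicklable runtime objects can be exercised in isolation (a sketch mirroring the check, not the comparator itself):

```python
import sqlite3
import threading


def same_runtime_kind(orig, new):
    # locks, events, DB handles, and file objects carry no comparable value,
    # so equality degrades to an exact type match
    return type(orig) is type(new)


assert same_runtime_kind(threading.Event(), threading.Event())
assert same_runtime_kind(sqlite3.connect(":memory:"), sqlite3.connect(":memory:"))
# Lock and RLock are distinct types, so they do not compare as equivalent
assert not same_runtime_kind(threading.Lock(), threading.RLock())
```

This keeps verification from rejecting an optimization merely because two otherwise-identical runs produced fresh lock or connection objects.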

View file

@ -29,7 +29,7 @@ def generate_tests(
test_path: Path,
test_perf_path: Path,
is_numerical_code: bool | None = None,
) -> tuple[str, str, str, Path, Path] | None:
) -> tuple[str, str, str, str | None, Path, Path] | None:
# TODO: Sometimes this recreates the original Class definition. This overrides and messes up the original
# class import. Remove the recreation of the class definition
start_time = time.perf_counter()
@ -74,8 +74,10 @@ def generate_tests(
module_system=project_module_system,
is_numerical_code=is_numerical_code,
)
if response and isinstance(response, tuple) and len(response) == 3:
generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source = response
if response and isinstance(response, tuple) and len(response) == 4:
generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source, raw_generated_tests = (
response
)
generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source = (
lang_support.process_generated_test_strings(
@ -97,6 +99,7 @@ def generate_tests(
generated_test_source,
instrumented_behavior_test_source,
instrumented_perf_test_source,
raw_generated_tests,
test_path,
test_perf_path,
)

View file

@ -1,2 +1,2 @@
# These version placeholders will be replaced by uv-dynamic-versioning during build.
__version__ = "0.20.1.post570.dev0+0fb5e36b"
__version__ = "0.20.1.post872.dev0+d7ab5a98"

View file

@ -2,7 +2,7 @@
title: "How Codeflash Measures Code Runtime on GPUs"
description: "Learn how Codeflash accurately measures code performance on GPUs"
icon: "microchip"
sidebarTitle: "GPU Benchmarking"
sidebarTitle: "GPU Benchmarking (Python)"
keywords: ["benchmarking", "performance", "timing", "measurement", "runtime", "noise reduction", "GPU", "MPS"]
---

View file

@ -2,7 +2,7 @@
title: "JavaScript / TypeScript Configuration"
description: "Configure Codeflash for JavaScript and TypeScript projects using package.json"
icon: "js"
sidebarTitle: "JavaScript / TypeScript"
sidebarTitle: "JS / TS (package.json)"
keywords:
[
"configuration",
@ -66,7 +66,7 @@ You can always override any auto-detected value in the `"codeflash"` section.
## Optional Options
- `testRunner`: Test framework to use. Auto-detected from your dependencies. Supported values: `"jest"`, `"vitest"`.
- `testRunner`: Test framework to use. Auto-detected from your dependencies. Supported values: `"jest"`, `"vitest"`, `"mocha"`.
- `formatterCmds`: Formatter commands. `$file` refers to the file being optimized. Disable with `["disabled"]`.
- **Prettier**: `["prettier --write $file"]`
- **ESLint + Prettier**: `["eslint --fix $file", "prettier --write $file"]`
@ -102,6 +102,7 @@ No separate configuration is needed for TypeScript vs JavaScript.
|-----------|-------------------|-------|
| **Jest** | `jest` in dependencies | Default for most projects |
| **Vitest** | `vitest` in dependencies | ESM-native support |
| **Mocha** | `mocha` in dependencies | Uses `node:assert/strict`, zero extra deps |
<Info>
**Functions must be exported** to be optimizable. Codeflash uses tree-sitter AST analysis to discover functions and check export status. Supported export patterns:
@ -198,6 +199,44 @@ my-app/
}
```
### Project with scattered test folders
If your tests are spread across multiple directories (e.g., `test/` at root and `__tests__/` inside `src/`), set `testsRoot` to the common ancestor:
```text
my-app/
|- src/
| |- utils/
| | |- __tests__/
| | | |- utils.test.js
| | |- helpers.js
| |- components/
| | |- __tests__/
| | | |- Button.test.jsx
| | |- Button.jsx
|- test/
| |- integration.test.js
|- package.json
```
```json
{
"name": "my-app",
"codeflash": {
"moduleRoot": "src",
"testsRoot": "."
}
}
```
<Info>
**`testsRoot` is a single path.** Codeflash recursively searches for `*.test.js`, `*.spec.js`, and `__tests__/**/*.js` files under this directory. Setting it to `"."` (project root) discovers tests everywhere, including co-located `__tests__/` folders inside `src/`.
</Info>
<Warning>
**Monorepos with per-package tests:** Don't set `testsRoot` to the monorepo root. Instead, run codeflash from each package directory with its own config. Each package's `testsRoot` is relative to that package.
</Warning>
### CommonJS library with no separate test directory
```text
@ -218,3 +257,115 @@ my-lib/
}
}
```
## Manual Configuration (without `codeflash init`)
If you prefer to configure manually or `codeflash init` doesn't detect your project correctly, add the `"codeflash"` key directly to your `package.json`:
```json
{
"name": "my-project",
"codeflash": {
"moduleRoot": "src",
"testsRoot": "tests",
"testRunner": "vitest",
"formatterCmds": ["prettier --write $file"],
"ignorePaths": ["src/generated/", "src/vendor/"]
}
}
```
### Step-by-step
1. **Set `moduleRoot`** — the directory containing your source code. Only files under this path are discovered for optimization.
2. **Set `testsRoot`** — the directory containing your tests. Codeflash searches recursively for `*.test.js`, `*.spec.js`, and `__tests__/**/*.js` files.
3. **Set `testRunner`** (optional) — `"jest"`, `"vitest"`, or `"mocha"`. Auto-detected from `devDependencies` if omitted.
4. **Set `formatterCmds`** (optional) — commands to format optimized code. `$file` is replaced with the file path. Use `["disabled"]` to skip formatting.
5. **Set `ignorePaths`** (optional) — directories to exclude from optimization (relative to `moduleRoot`).
### CLI flag overrides
All config values can be overridden via CLI flags:
```bash
# Override moduleRoot and testsRoot for a single run
codeflash --file src/utils.ts --function myFunc \
--module-root src \
--tests-root test \
--no-pr
# Specify a directory to optimize (within moduleRoot)
codeflash --all src/utils/
```
## FAQ
<AccordionGroup>
<Accordion title="Will codeflash handle test files scattered throughout a package?">
Yes, as long as `testsRoot` is set to a common ancestor directory. Codeflash recursively searches for `*.test.js`, `*.spec.js`, and `__tests__/**/*.js` under `testsRoot`.
For co-located tests (test files next to source files), set `testsRoot` to the same value as `moduleRoot`:
```json
{
"codeflash": {
"moduleRoot": "src",
"testsRoot": "src"
}
}
```
For tests in multiple directories, set `testsRoot` to `"."` (project root) to discover them all.
**Note:** `testsRoot` accepts a single path. For monorepos, run codeflash from each package directory with its own config rather than trying to cover all packages from the root.
</Accordion>
<Accordion title="Do I need to add one test folder at a time?">
No. Codeflash recursively searches `testsRoot`, so a single path covers all nested test directories. For example, `"testsRoot": "."` discovers tests anywhere in the project:
- `test/unit/*.test.js`
- `src/components/__tests__/Button.test.tsx`
- `lib/utils.spec.js`
All of these are found with a single `testsRoot` setting.
</Accordion>
<Accordion title="How does codeflash find existing tests for a function?">
Codeflash matches tests to functions by:
1. Scanning all test files under `testsRoot` (matching `*.test.*`, `*.spec.*`, `__tests__/**/*`)
2. Parsing imports in each test file using tree-sitter
3. Matching imported function names to the target function
If a test file imports `myFunction` from `./utils`, it's considered a test for `myFunction`.
</Accordion>
<Accordion title="What if codeflash init detects the wrong values?">
Override any value in `package.json` under the `"codeflash"` key. The most common overrides:
```json
{
"codeflash": {
"moduleRoot": "src/lib",
"testsRoot": ".",
"testRunner": "mocha"
}
}
```
Or use CLI flags for one-off overrides: `--module-root`, `--tests-root`.
</Accordion>
<Accordion title="Can I use codeflash with a standalone codeflash.yaml?">
For internal testing and development, codeflash also reads `codeflash.yaml` files. This is useful when you don't want to modify `package.json`:
```yaml
module_root: "src"
tests_root: "test"
test_framework: "vitest"
formatter_cmds: []
```
Place this in the project root. The `package.json` config takes precedence if both exist.
</Accordion>
</AccordionGroup>

View file

@ -2,7 +2,7 @@
title: "Python Configuration"
description: "Configure Codeflash for Python projects using pyproject.toml"
icon: "python"
sidebarTitle: "Python"
sidebarTitle: "Python (pyproject.toml)"
keywords:
[
"configuration",

View file

@ -19,43 +19,37 @@
"tab": "Documentation",
"groups": [
{
"group": "🏠 Overview",
"group": "Overview",
"pages": ["index"]
},
{
"group": "🚀 Quickstart",
"group": "Getting Started",
"pages": [
"getting-started/local-installation",
"getting-started/javascript-installation"
]
},
{
"group": "⚡ Optimizing with Codeflash",
"group": "Using Codeflash",
"pages": [
"optimizing-with-codeflash/one-function",
"optimizing-with-codeflash/codeflash-all",
"optimizing-with-codeflash/trace-and-optimize",
"optimizing-with-codeflash/codeflash-all"
]
},
{
"group": "✨ Continuous Optimization",
"pages": [
"optimizing-with-codeflash/codeflash-github-actions",
"optimizing-with-codeflash/benchmarking",
"optimizing-with-codeflash/review-optimizations"
]
},
{
"group": "🛠 IDE Extension",
"group": "Configuration",
"pages": [
"editor-plugins/vscode/index",
"editor-plugins/vscode/features",
"editor-plugins/vscode/configuration",
"editor-plugins/vscode/troubleshooting"
"configuration/python",
"configuration/javascript",
"getting-the-best-out-of-codeflash"
]
},
{
"group": "🧠 Core Concepts",
"group": "Core Concepts",
"pages": [
"codeflash-concepts/how-codeflash-works",
"codeflash-concepts/benchmarking",
@ -64,8 +58,13 @@
]
},
{
"group": "⚙️ Configuration & Best Practices",
"pages": ["configuration", "getting-the-best-out-of-codeflash"]
"group": "IDE Extension",
"pages": [
"editor-plugins/vscode/index",
"editor-plugins/vscode/features",
"editor-plugins/vscode/configuration",
"editor-plugins/vscode/troubleshooting"
]
}
]
}

View file

@ -1,7 +1,8 @@
---
title: "JavaScript / TypeScript Installation"
description: "Install and configure Codeflash for your JavaScript/TypeScript project"
icon: "node-js"
icon: "js"
sidebarTitle: "JS / TS Setup"
keywords:
[
"installation",
@ -214,7 +215,11 @@ my-monorepo/
|-----------|--------|-------------------|
| **Jest** | Supported | `jest` in dependencies |
| **Vitest** | Supported | `vitest` in dependencies |
| **Mocha** | Coming soon | — |
| **Mocha** | Supported | `mocha` in dependencies |
<Info>
**Mocha projects** use `node:assert/strict` for generated tests (zero extra dependencies). Mocha's `describe`/`it` globals are used automatically — no imports needed.
</Info>
<Info>
**Functions must be exported** to be optimizable. Codeflash can only discover and optimize functions that are exported from their module (via `export`, `export default`, or `module.exports`).
@ -237,8 +242,8 @@ codeflash --file src/utils.ts --function processData --no-pr
codeflash --all
```
```bash Trace and optimize
codeflash optimize --jest
```bash Dry run (see what would be optimized)
codeflash --all --dry-run
```
</CodeGroup>

View file

@ -1,7 +1,8 @@
---
title: "Local Installation"
title: "Python Installation"
description: "Install and configure Codeflash for your Python project in minutes"
icon: "download"
icon: "python"
sidebarTitle: "Python Setup"
---
Codeflash is installed and configured on a per-project basis.

View file

@ -13,46 +13,56 @@ does not modify the system architecture of your code, but it tries to find the m
### Get Started
Pick your language to install and configure Codeflash:
<CardGroup cols={2}>
<Card title="Python Setup" icon="python" href="/getting-started/local-installation">
Install via pip, uv, or poetry
<Card title="Python" icon="python" href="/getting-started/local-installation">
Install via pip, uv, or poetry. Configure in `pyproject.toml`.
</Card>
<Card title="JavaScript / TypeScript Setup" icon="js" href="/getting-started/javascript-installation">
Install via npm, yarn, pnpm, or bun
<Card title="JavaScript / TypeScript" icon="js" href="/getting-started/javascript-installation">
Install via npm, yarn, pnpm, or bun. Configure in `package.json`. Supports Jest, Vitest, and Mocha.
</Card>
</CardGroup>
### How to use Codeflash
<CardGroup cols={1}>
<Card title="Optimize a Single Function" icon="bullseye" href="/optimizing-with-codeflash/one-function">
Target and optimize individual functions for maximum performance gains.
These commands work for both Python and JS/TS projects:
<CardGroup cols={2}>
<Card title="Optimize a Function" icon="bullseye" href="/optimizing-with-codeflash/one-function">
```bash
codeflash --file path/to/file --function my_function
```
</Card>
<Card title="Optimize Pull Requests" icon="code-pull-request" href="/optimizing-with-codeflash/codeflash-github-actions">
Automatically find optimizations for Pull Requests with GitHub Actions integration.
```bash
codeflash init-actions
```
</Card>
<Card title="Optimize Workflows with Tracing" icon="route" href="/optimizing-with-codeflash/trace-and-optimize">
End-to-end optimization of entire workflows with execution tracing.
```bash
codeflash optimize myscript.py
```
</Card>
<Card title="Optimize Your Entire Codebase" icon="globe" href="/optimizing-with-codeflash/codeflash-all">
Automatically optimize all functions in your project with comprehensive analysis.
<Card title="Optimize Entire Codebase" icon="globe" href="/optimizing-with-codeflash/codeflash-all">
```bash
codeflash --all
```
</Card>
<Card title="Trace & Optimize Workflows" icon="route" href="/optimizing-with-codeflash/trace-and-optimize">
```bash
codeflash optimize myscript.py
```
</Card>
<Card title="Auto-Optimize Pull Requests" icon="code-pull-request" href="/optimizing-with-codeflash/codeflash-github-actions">
```bash
codeflash init-actions
```
</Card>
</CardGroup>
### Configuration Reference
<CardGroup cols={2}>
<Card title="Python Config" icon="python" href="/configuration/python">
`pyproject.toml` reference
</Card>
<Card title="JS / TS Config" icon="js" href="/configuration/javascript">
`package.json` reference — includes monorepo, scattered tests, manual setup
</Card>
</CardGroup>
### How does Codeflash verify correctness?

View file

@@ -1,8 +1,8 @@
---
title: "Optimize Performance Benchmarks with every Pull Request"
description: "Configure and use benchmark integration for performance-critical code optimization"
description: "Configure and use benchmark integration for performance-critical code optimization (Python only)"
icon: "chart-line"
sidebarTitle: Setup Benchmarks to Optimize
sidebarTitle: "Benchmarks (Python)"
keywords:
[
"benchmarks",

File diff suppressed because it is too large

View file

@@ -1,6 +1,6 @@
{
"name": "codeflash",
"version": "0.10.1",
"version": "0.10.2",
"description": "Codeflash - AI-powered code optimization for JavaScript and TypeScript",
"main": "runtime/index.js",
"types": "runtime/index.d.ts",

View file

@@ -82,11 +82,13 @@ function findJestRunnerRecursive(nodeModulesPath, maxDepth = 5) {
// Recurse into:
// - node_modules subdirectories
// - scoped packages (@org/pkg)
// - scoped packages (@org/pkg) and their children (e.g. @jest/core)
// - hidden directories (.pnpm, .yarn, etc.)
// - pnpm versioned directories (jest-runner@30.0.5)
const isInsideScopedDir = path.basename(dir).startsWith('@');
const shouldRecurse = entry.name === 'node_modules' ||
entry.name.startsWith('@') ||
isInsideScopedDir ||
entry.name === '.pnpm' || entry.name === '.yarn' ||
entry.name.startsWith('jest-runner@');
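
The recursion rules listed in the comment above can be condensed into a single predicate. Here is an illustrative Python stand-in sketch (the real helper is JavaScript; the function name is hypothetical):

```python
import os

def should_recurse(parent_dir: str, entry_name: str) -> bool:
    """Sketch of the recursion filter used when searching node_modules
    trees for jest-runner (Python stand-in for the JS helper)."""
    # inside a scoped dir (e.g. @jest/), descend into its children (@jest/core)
    is_inside_scoped_dir = os.path.basename(parent_dir).startswith("@")
    return (
        entry_name == "node_modules"
        or entry_name.startswith("@")             # scoped packages (@org/pkg)
        or is_inside_scoped_dir
        or entry_name in (".pnpm", ".yarn")       # hidden package-manager stores
        or entry_name.startswith("jest-runner@")  # pnpm versioned directories
    )
```

Without the `is_inside_scoped_dir` clause, plain children of a scoped directory (like `core` under `@jest/`) would never be visited, which is the gap this hunk closes.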

View file

@@ -22,7 +22,7 @@ def test_benchmark_extract(benchmark) -> None:
)
function_to_optimize = FunctionToOptimize(
function_name="replace_function_and_helpers_with_optimized_code",
file_path=file_path / "optimization" / "function_optimizer.py",
file_path=file_path / "languages" / "function_optimizer.py",
parents=[FunctionParent(name="FunctionOptimizer", type="ClassDef")],
starting_line=None,
ending_line=None,

View file

@@ -131,30 +131,13 @@ class TestGetJsCodeflashInstallStep:
assert result == ""
def test_npm_global_install(self) -> None:
"""Should generate npm global install when not a dependency."""
result = get_js_codeflash_install_step(JsPackageManager.NPM, is_dependency=False)
def test_uv_tool_install_when_not_dependency(self) -> None:
"""Should generate uv tool install when not a dependency, regardless of package manager."""
for pkg_manager in (JsPackageManager.NPM, JsPackageManager.YARN, JsPackageManager.PNPM, JsPackageManager.BUN):
result = get_js_codeflash_install_step(pkg_manager, is_dependency=False)
assert "Install Codeflash" in result
assert "npm install -g codeflash" in result
def test_yarn_global_install(self) -> None:
"""Should generate yarn global install when not a dependency."""
result = get_js_codeflash_install_step(JsPackageManager.YARN, is_dependency=False)
assert "yarn global add codeflash" in result
def test_pnpm_global_install(self) -> None:
"""Should generate pnpm global install when not a dependency."""
result = get_js_codeflash_install_step(JsPackageManager.PNPM, is_dependency=False)
assert "pnpm add -g codeflash" in result
def test_bun_global_install(self) -> None:
"""Should generate bun global install when not a dependency."""
result = get_js_codeflash_install_step(JsPackageManager.BUN, is_dependency=False)
assert "bun add -g codeflash" in result
assert "Install Codeflash" in result
assert "uv tool install codeflash" in result
class TestGetJsCodeflashRunCommand:

View file

@@ -182,8 +182,8 @@ class TestBenchmarkingTestsDispatch:
call_kwargs = mock_vitest_runner.call_args.kwargs
assert call_kwargs["min_loops"] == 10
# JS/TS caps max_loops at JS_BENCHMARKING_MAX_LOOPS (1_000) regardless of passed value
# Actual loop count is limited by target_duration, not max_loops
# JS/TS uses JS_BENCHMARKING_MAX_LOOPS (1_000) regardless of passed value
# Actual loop count is limited by target_duration (10s), not max_loops
assert call_kwargs["max_loops"] == 1_000
assert call_kwargs["target_duration_ms"] == 5000

View file

@@ -0,0 +1,96 @@
"""Safety tests for AiServiceClient.add_language_metadata().
These tests verify the correct payload structure for each language,
ensuring that merge resolution doesn't silently break the multi-language metadata logic.
"""
from __future__ import annotations
from unittest.mock import patch
import pytest
from codeflash.api.aiservice import AiServiceClient
from codeflash.languages import Language
class TestAddLanguageMetadata:
"""Test add_language_metadata sets correct payload fields per language."""
@patch("codeflash.api.aiservice.current_language", return_value=Language.PYTHON)
def test_python_sets_language_version_and_python_version(self, _mock_lang: object) -> None:
"""For Python, both language_version and python_version should be set to the same value."""
payload: dict = {}
AiServiceClient.add_language_metadata(payload, language_version="3.11.5")
assert payload["language_version"] == "3.11.5"
assert payload["python_version"] == "3.11.5"
assert "module_system" not in payload
@patch("codeflash.api.aiservice.current_language", return_value=Language.PYTHON)
def test_python_no_module_system(self, _mock_lang: object) -> None:
"""For Python, module_system should never be set even if provided."""
payload: dict = {}
AiServiceClient.add_language_metadata(payload, language_version="3.11.5", module_system="commonjs")
assert "module_system" not in payload
@patch("codeflash.api.aiservice.current_language", return_value=Language.JAVA)
def test_java_sets_language_version_not_python_version(self, _mock_lang: object) -> None:
"""For Java, language_version should be set, python_version should be None."""
payload: dict = {}
AiServiceClient.add_language_metadata(payload, language_version="17")
assert payload["language_version"] == "17"
assert payload["python_version"] is None
@patch("codeflash.api.aiservice.current_language", return_value=Language.JAVA)
def test_java_includes_module_system(self, _mock_lang: object) -> None:
"""For Java, module_system should be set when provided."""
payload: dict = {}
AiServiceClient.add_language_metadata(payload, language_version="17", module_system="maven")
assert payload["module_system"] == "maven"
@patch("codeflash.api.aiservice.current_language", return_value=Language.JAVA)
def test_java_no_module_system_when_none(self, _mock_lang: object) -> None:
"""For Java, module_system should not be set when None."""
payload: dict = {}
AiServiceClient.add_language_metadata(payload, language_version="17", module_system=None)
assert "module_system" not in payload
@patch("codeflash.api.aiservice.current_language", return_value=Language.JAVASCRIPT)
def test_javascript_sets_language_version_not_python_version(self, _mock_lang: object) -> None:
"""For JavaScript, language_version should be set, python_version should be None."""
payload: dict = {}
AiServiceClient.add_language_metadata(payload, language_version="20.11.0")
assert payload["language_version"] == "20.11.0"
assert payload["python_version"] is None
@patch("codeflash.api.aiservice.current_language", return_value=Language.JAVASCRIPT)
def test_javascript_includes_module_system(self, _mock_lang: object) -> None:
"""For JavaScript, module_system should be set when provided."""
payload: dict = {}
AiServiceClient.add_language_metadata(payload, language_version="20.11.0", module_system="esm")
assert payload["module_system"] == "esm"
@patch("codeflash.api.aiservice.current_language", return_value=Language.TYPESCRIPT)
def test_typescript_same_as_javascript(self, _mock_lang: object) -> None:
"""TypeScript should behave the same as JavaScript."""
payload: dict = {}
AiServiceClient.add_language_metadata(payload, language_version="20.11.0", module_system="commonjs")
assert payload["language_version"] == "20.11.0"
assert payload["python_version"] is None
assert payload["module_system"] == "commonjs"
@patch("codeflash.api.aiservice.current_language", return_value=Language.PYTHON)
def test_none_language_version_python(self, _mock_lang: object) -> None:
"""When language_version is None for Python, payload should still have the keys."""
payload: dict = {}
AiServiceClient.add_language_metadata(payload, language_version=None)
assert payload["language_version"] is None
assert payload["python_version"] is None
@patch("codeflash.api.aiservice.current_language", return_value=Language.JAVA)
def test_none_language_version_java(self, _mock_lang: object) -> None:
"""When language_version is None for Java, payload should still have the keys."""
payload: dict = {}
AiServiceClient.add_language_metadata(payload, language_version=None)
assert payload["language_version"] is None
assert payload["python_version"] is None
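
For reference, the per-language behavior these tests pin down can be sketched as follows. This is an illustrative stand-in, not the real `AiServiceClient` method; the `Language` enum values and the function signature are inferred from the tests above:

```python
from __future__ import annotations
from enum import Enum

class Language(Enum):
    PYTHON = "python"
    JAVA = "java"
    JAVASCRIPT = "javascript"
    TYPESCRIPT = "typescript"

def add_language_metadata(payload: dict, language: Language,
                          language_version: str | None = None,
                          module_system: str | None = None) -> None:
    """Sketch of the payload rules the tests assert (not the real client)."""
    payload["language_version"] = language_version
    # python_version mirrors language_version only for Python payloads
    payload["python_version"] = language_version if language is Language.PYTHON else None
    # module_system applies only to non-Python languages, and only when provided
    if language is not Language.PYTHON and module_system is not None:
        payload["module_system"] = module_system
```

Note that `language_version` and `python_version` keys are always written (possibly as `None`), while `module_system` is written conditionally; the merge resolution has to preserve all three behaviors.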

View file

@@ -16,7 +16,7 @@ from codeflash.code_utils.instrument_existing_tests import (
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodePosition, FunctionParent, TestFile, TestFiles, TestingMode, TestType
from codeflash.optimization.optimizer import Optimizer
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
from codeflash.languages.python.instrument_codeflash_capture import instrument_codeflash_capture
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")

View file

@@ -4,7 +4,7 @@ from pathlib import Path
import pytest
from codeflash.cli_cmds.cmd_init import (
from codeflash.cli_cmds.init_config import (
CLISetupInfo,
VsCodeSetupInfo,
configure_pyproject_toml,

View file

@@ -243,6 +243,14 @@ def test_bubble_sort_helper() -> None:
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = """
```python:code_to_optimize/code_directories/retriever/bubble_sort_imported.py
from bubble_sort_with_math import sorter
def sort_from_another_file(arr):
sorted_arr = sorter(arr)
return sorted_arr
```
```python:code_to_optimize/code_directories/retriever/bubble_sort_with_math.py
import math
@@ -253,14 +261,6 @@ def sorter(arr):
print(x)
return arr
```
```python:code_to_optimize/code_directories/retriever/bubble_sort_imported.py
from bubble_sort_with_math import sorter
def sort_from_another_file(arr):
sorted_arr = sorter(arr)
return sorted_arr
```
"""
expected_read_only_context = ""
@@ -1178,6 +1178,26 @@ def test_repo_helper() -> None:
hashing_context = code_ctx.hashing_code_context
path_to_globals = project_root / "globals.py"
expected_read_write_context = f"""
```python:{path_to_file.relative_to(project_root)}
import requests
from globals import API_URL
from utils import DataProcessor
def fetch_and_process_data():
# Use the global variable for the request
response = requests.get(API_URL)
response.raise_for_status()
raw_data = response.text
# Use code from another file (utils.py)
processor = DataProcessor()
processed = processor.process_data(raw_data)
processed = processor.add_prefix(processed)
return processed
```
```python:{path_to_globals.relative_to(project_root)}
# Define a global variable
API_URL = "https://api.example.com/data"
@@ -1201,26 +1221,6 @@ class DataProcessor:
\"\"\"Add a prefix to the processed data.\"\"\"
return prefix + data
```
```python:{path_to_file.relative_to(project_root)}
import requests
from globals import API_URL
from utils import DataProcessor
def fetch_and_process_data():
# Use the global variable for the request
response = requests.get(API_URL)
response.raise_for_status()
raw_data = response.text
# Use code from another file (utils.py)
processor = DataProcessor()
processed = processor.process_data(raw_data)
processed = processor.add_prefix(processed)
return processed
```
"""
expected_read_only_context = f"""
```python:{path_to_utils.relative_to(project_root)}
@@ -1278,6 +1278,25 @@ def test_repo_helper_of_helper() -> None:
hashing_context = code_ctx.hashing_code_context
path_to_globals = project_root / "globals.py"
expected_read_write_context = f"""
```python:{path_to_file.relative_to(project_root)}
import requests
from globals import API_URL
from utils import DataProcessor
def fetch_and_transform_data():
# Use the global variable for the request
response = requests.get(API_URL)
raw_data = response.text
# Use code from another file (utils.py)
processor = DataProcessor()
processed = processor.process_data(raw_data)
transformed = processor.transform_data(processed)
return transformed
```
```python:{path_to_globals.relative_to(project_root)}
# Define a global variable
API_URL = "https://api.example.com/data"
@@ -1302,25 +1321,6 @@ class DataProcessor:
\"\"\"Transform the processed data\"\"\"
return DataTransformer().transform(data)
```
```python:{path_to_file.relative_to(project_root)}
import requests
from globals import API_URL
from utils import DataProcessor
def fetch_and_transform_data():
# Use the global variable for the request
response = requests.get(API_URL)
raw_data = response.text
# Use code from another file (utils.py)
processor = DataProcessor()
processed = processor.process_data(raw_data)
transformed = processor.transform_data(processed)
return transformed
```
"""
expected_read_only_context = f"""
```python:{path_to_utils.relative_to(project_root)}
@@ -1384,14 +1384,6 @@ def test_repo_helper_of_helper_same_class() -> None:
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = f"""
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:
def __init__(self):
self.data = None
def transform_using_own_method(self, data):
return self.transform(data)
```
```python:{path_to_utils.relative_to(project_root)}
import math
from transform_utils import DataTransformer
@@ -1408,6 +1400,14 @@ class DataProcessor:
\"\"\"Transform the processed data using own method\"\"\"
return DataTransformer().transform_using_own_method(data)
```
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:
def __init__(self):
self.data = None
def transform_using_own_method(self, data):
return self.transform(data)
```
"""
expected_read_only_context = f"""
```python:{path_to_transform_utils.relative_to(project_root)}
@@ -1465,14 +1465,6 @@ def test_repo_helper_of_helper_same_file() -> None:
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = f"""
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:
def __init__(self):
self.data = None
def transform_using_same_file_function(self, data):
return update_data(data)
```
```python:{path_to_utils.relative_to(project_root)}
import math
from transform_utils import DataTransformer
@@ -1489,6 +1481,14 @@ class DataProcessor:
\"\"\"Transform the processed data using a function from the same file\"\"\"
return DataTransformer().transform_using_same_file_function(data)
```
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:
def __init__(self):
self.data = None
def transform_using_same_file_function(self, data):
return update_data(data)
```
"""
expected_read_only_context = f"""
```python:{path_to_transform_utils.relative_to(project_root)}
@@ -1605,6 +1605,17 @@ def test_repo_helper_circular_dependency() -> None:
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = f"""
```python:{path_to_transform_utils.relative_to(project_root)}
from code_to_optimize.code_directories.retriever.utils import DataProcessor
class DataTransformer:
def __init__(self):
self.data = None
def circular_dependency(self, data):
return DataProcessor().circular_dependency(data)
```
```python:{path_to_utils.relative_to(project_root)}
import math
from transform_utils import DataTransformer
@@ -1621,17 +1632,6 @@ class DataProcessor:
\"\"\"Test circular dependency\"\"\"
return DataTransformer().circular_dependency(data)
```
```python:{path_to_transform_utils.relative_to(project_root)}
from code_to_optimize.code_directories.retriever.utils import DataProcessor
class DataTransformer:
def __init__(self):
self.data = None
def circular_dependency(self, data):
return DataProcessor().circular_dependency(data)
```
"""
expected_read_only_context = f"""
```python:{path_to_utils.relative_to(project_root)}
@@ -1796,6 +1796,12 @@ def function_to_optimize():
```
"""
expected_read_write_context = f"""
```python:{path_to_fto.relative_to(project_root)}
import code_to_optimize.code_directories.retriever.main
def function_to_optimize():
return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data()
```
```python:{path_to_main.relative_to(project_root)}
import requests
from globals import API_URL
@@ -1815,12 +1821,6 @@ def fetch_and_transform_data():
return transformed
```
```python:{path_to_fto.relative_to(project_root)}
import code_to_optimize.code_directories.retriever.main
def function_to_optimize():
return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data()
```
"""
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
@@ -2244,6 +2244,20 @@ def get_system_details():
# The expected contexts
relative_path = file_path.relative_to(project_root)
expected_read_write_context = f"""
```python:{main_file_path.resolve().relative_to(opt.args.project_root.resolve())}
import utility_module
class Calculator:
def __init__(self, precision="high", fallback_precision=None, mode="standard"):
# This is where we use the imported module
self.precision = utility_module.select_precision(precision, fallback_precision)
self.mode = mode
# Using variables from the utility module
self.backend = utility_module.CALCULATION_BACKEND
self.system = utility_module.SYSTEM_TYPE
self.default_precision = utility_module.DEFAULT_PRECISION
```
```python:utility_module.py
import sys
@@ -2291,20 +2305,6 @@ def select_precision(precision, fallback_precision):
else:
return DEFAULT_PRECISION
```
```python:{main_file_path.resolve().relative_to(opt.args.project_root.resolve())}
import utility_module
class Calculator:
def __init__(self, precision="high", fallback_precision=None, mode="standard"):
# This is where we use the imported module
self.precision = utility_module.select_precision(precision, fallback_precision)
self.mode = mode
# Using variables from the utility module
self.backend = utility_module.CALCULATION_BACKEND
self.system = utility_module.SYSTEM_TYPE
self.default_precision = utility_module.DEFAULT_PRECISION
```
"""
expected_read_only_context = """
```python:utility_module.py

View file

@@ -1,3 +1,4 @@
from codeflash.languages.javascript.normalizer import normalize_js_code
from codeflash.languages.python.normalizer import normalize_python_code as normalize_code
@@ -133,3 +134,74 @@ def safe_divide(a, b):
assert normalize_code(code9) == normalize_code(code10)
assert normalize_code(code9) != normalize_code(code8)
# === JavaScript deduplication tests ===
def test_js_deduplicate_same_logic_different_vars():
code1 = """
function process(items) {
const result = [];
for (const item of items) {
result.push(item * 2);
}
return result;
}
"""
code2 = """
function process(items) {
const output = [];
for (const val of items) {
output.push(val * 2);
}
return output;
}
"""
assert normalize_js_code(code1) == normalize_js_code(code2)
def test_js_different_logic_not_deduplicated():
code1 = """
function compute(x) {
return x + 1;
}
"""
code2 = """
function compute(x) {
return x * 2;
}
"""
assert normalize_js_code(code1) != normalize_js_code(code2)
def test_js_deduplicate_whitespace_and_comments():
code1 = """
function add(a, b) {
// fast path
return a + b;
}
"""
code2 = """
function add(a, b) {
/* optimized */
return a + b;
}
"""
assert normalize_js_code(code1) == normalize_js_code(code2)
def test_ts_normalize():
code1 = """
function greet(name: string): string {
const msg = "hello " + name;
return msg;
}
"""
code2 = """
function greet(name: string): string {
const result = "hello " + name;
return result;
}
"""
assert normalize_js_code(code1, typescript=True) == normalize_js_code(code2, typescript=True)
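
The deduplication these tests support can be sketched with a toy normalizer. Note that the real normalizer also canonicalizes identifier names (the variable-renaming tests above depend on that); this illustrative stand-in only strips comments and whitespace:

```python
import re

def toy_normalize(code: str) -> str:
    """Toy stand-in for the real normalizer: strips // and /* */ comments
    and collapses whitespace. (The real one also alpha-renames variables,
    which this sketch does not attempt.)"""
    code = re.sub(r"//[^\n]*|/\*.*?\*/", "", code, flags=re.DOTALL)
    return " ".join(code.split())

def deduplicate(candidates: list[str]) -> list[str]:
    """Keep the first candidate seen for each normalized form."""
    seen: dict[str, str] = {}
    for cand in candidates:
        seen.setdefault(toy_normalize(cand), cand)
    return list(seen.values())
```

Keyed deduplication like this is why normalization must be stable: two behaviorally identical optimization candidates should hash to the same key so only one is benchmarked.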

View file

@@ -0,0 +1,84 @@
"""Safety tests for get_optimized_code_for_module() fallback chain.
These tests verify the matching logic that maps AI-generated code blocks
to the correct source file, including all fallback strategies.
"""
from __future__ import annotations
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from codeflash.languages.code_replacer import get_optimized_code_for_module
def _make_optimized_code(file_to_code: dict[str, str]) -> MagicMock:
"""Create a mock CodeStringsMarkdown with a given file_to_path mapping."""
mock = MagicMock()
mock.file_to_path.return_value = file_to_code
return mock
class TestGetOptimizedCodeForModule:
"""Test the fallback chain in get_optimized_code_for_module."""
def test_exact_path_match(self) -> None:
"""When the relative path matches exactly, return that code."""
code = _make_optimized_code({"src/main/java/com/example/Foo.java": "class Foo {}"})
result = get_optimized_code_for_module(Path("src/main/java/com/example/Foo.java"), code)
assert result == "class Foo {}"
def test_none_key_fallback(self) -> None:
"""When there's a single code block with 'None' key, use it."""
code = _make_optimized_code({"None": "class Foo { optimized }"})
result = get_optimized_code_for_module(Path("src/main/java/com/example/Foo.java"), code)
assert result == "class Foo { optimized }"
def test_basename_match(self) -> None:
"""When the AI returns just 'Algorithms.java', match by basename."""
code = _make_optimized_code({"Algorithms.java": "class Algorithms { fast }"})
result = get_optimized_code_for_module(
Path("src/main/java/com/example/Algorithms.java"), code
)
assert result == "class Algorithms { fast }"
def test_basename_match_with_different_prefix(self) -> None:
"""Basename match should work even with a different directory prefix."""
code = _make_optimized_code({"com/other/Foo.java": "class Foo { v2 }"})
result = get_optimized_code_for_module(Path("src/main/java/com/example/Foo.java"), code)
assert result == "class Foo { v2 }"
@patch("codeflash.languages.current.is_python", return_value=False)
def test_single_block_fallback_non_python(self, _mock: object) -> None:
"""For non-Python, a single code block with wrong path should still match."""
code = _make_optimized_code({"wrong/path/Bar.java": "class Bar { fast }"})
result = get_optimized_code_for_module(Path("src/main/java/com/example/Foo.java"), code)
assert result == "class Bar { fast }"
@patch("codeflash.languages.current.is_python", return_value=True)
def test_single_block_fallback_python_does_not_match(self, _mock: object) -> None:
"""For Python, a single code block with wrong path should NOT match."""
code = _make_optimized_code({"wrong/path/bar.py": "def bar(): pass"})
result = get_optimized_code_for_module(Path("src/foo.py"), code)
assert result == ""
def test_no_match_returns_empty(self) -> None:
"""When multiple blocks exist and none match, return empty string."""
code = _make_optimized_code({
"other/File1.java": "class File1 {}",
"other/File2.java": "class File2 {}",
})
result = get_optimized_code_for_module(Path("src/main/java/com/example/Foo.java"), code)
assert result == ""
def test_none_key_with_multiple_blocks_no_match(self) -> None:
"""When there are multiple blocks including 'None', don't use None fallback."""
code = _make_optimized_code({
"None": "class Default {}",
"other/File.java": "class File {}",
})
result = get_optimized_code_for_module(Path("src/main/java/com/example/Foo.java"), code)
# With multiple blocks, the None-key fallback should NOT trigger
assert result == ""
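
The fallback chain these tests exercise can be sketched as a plain function over a path-to-code dict. This is an illustrative stand-in for `get_optimized_code_for_module`, with the ordering inferred from the tests above:

```python
from __future__ import annotations
from pathlib import Path

def resolve_code_block(module_path: Path, blocks: dict[str, str],
                       is_python: bool) -> str:
    """Sketch of the assumed matching order: exact relative path ->
    lone 'None' block -> basename match -> lone-block fallback
    (non-Python only) -> empty string."""
    key = str(module_path)
    if key in blocks:                        # 1. exact path match
        return blocks[key]
    if len(blocks) == 1:
        only_key, only_code = next(iter(blocks.items()))
        if only_key == "None":               # 2. single unnamed block
            return only_code
    for block_key, code in blocks.items():   # 3. basename match
        if Path(block_key).name == module_path.name:
            return code
    if len(blocks) == 1 and not is_python:   # 4. non-Python single-block fallback
        return next(iter(blocks.values()))
    return ""
```

The non-Python guard on step 4 matters: Python optimizations can legitimately span several files, so a lone block with a wrong path is ambiguous there, while for Java/JS a single returned block almost certainly targets the one file being optimized.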

View file

@@ -10,8 +10,8 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
from codeflash.models.models import FunctionParent, TestFile, TestFiles, TestingMode, TestType, VerificationType
from codeflash.verification.equivalence import compare_test_results
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
from codeflash.verification.test_runner import execute_test_subprocess
from codeflash.languages.python.instrument_codeflash_capture import instrument_codeflash_capture
from codeflash.languages.python.test_runner import execute_test_subprocess
from codeflash.verification.verification_utils import TestConfig

View file

@@ -9,7 +9,7 @@ from codeflash.code_utils.config_parser import parse_config_file
from codeflash.code_utils.formatter import format_code, format_generated_code, sort_imports
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeString, CodeStringsMarkdown
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.languages.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig

View file

@@ -3,7 +3,12 @@ from unittest.mock import patch
import git
from codeflash.code_utils.git_utils import check_and_push_branch, check_running_in_git_repo, get_repo_owner_and_name
from codeflash.code_utils.git_utils import (
check_and_push_branch,
check_running_in_git_repo,
get_git_diff,
get_repo_owner_and_name,
)
class TestGitUtils(unittest.TestCase):
@@ -115,5 +120,136 @@ class TestGitUtils(unittest.TestCase):
mock_origin.push.assert_not_called()
DELETION_ONLY_DIFF = """\
--- a/example.py
+++ b/example.py
@@ -5,7 +5,5 @@ def foo():
a = 1
b = 2
- c = 3
- d = 4
e = 5
return a + b + e
"""
ADDITION_ONLY_DIFF = """\
--- a/example.py
+++ b/example.py
@@ -5,5 +5,7 @@ def foo():
a = 1
b = 2
+ c = 3
+ d = 4
e = 5
return a + b + e
"""
MIXED_DIFF = """\
--- a/example.py
+++ b/example.py
@@ -5,6 +5,6 @@ def foo():
a = 1
b = 2
- c = 3
+ c = 30
e = 5
return a + b + e
"""
MULTI_HUNK_DELETION_ONLY_DIFF = """\
--- a/example.py
+++ b/example.py
@@ -5,7 +5,5 @@ def foo():
a = 1
b = 2
- c = 3
- d = 4
e = 5
return a + b + e
@@ -20,6 +18,4 @@ def bar():
x = 1
y = 2
- z = 3
- w = 4
return x + y
"""
class TestGetGitDiffDeletionOnly(unittest.TestCase):
@patch("codeflash.code_utils.git_utils.git.Repo")
def test_deletion_only_diff_returns_hunk_target_starts(self, mock_repo_cls):
repo = mock_repo_cls.return_value
repo.head.commit.hexsha = "abc123"
repo.working_dir = "/repo"
repo.git.diff.return_value = DELETION_ONLY_DIFF
result = get_git_diff(repo_directory=None, uncommitted_changes=True)
assert len(result) == 1
key = list(result.keys())[0]
assert str(key).endswith("example.py")
# The hunk target_start is 5 — this is the fix: deletion-only diffs
# should still report line numbers so the surrounding function is found.
assert result[key] == [5]
@patch("codeflash.code_utils.git_utils.git.Repo")
def test_addition_only_diff_returns_added_lines(self, mock_repo_cls):
repo = mock_repo_cls.return_value
repo.head.commit.hexsha = "abc123"
repo.working_dir = "/repo"
repo.git.diff.return_value = ADDITION_ONLY_DIFF
result = get_git_diff(repo_directory=None, uncommitted_changes=True)
key = list(result.keys())[0]
# Added lines are at target line numbers 7 and 8
assert result[key] == [7, 8]
@patch("codeflash.code_utils.git_utils.git.Repo")
def test_mixed_diff_returns_only_added_lines(self, mock_repo_cls):
repo = mock_repo_cls.return_value
repo.head.commit.hexsha = "abc123"
repo.working_dir = "/repo"
repo.git.diff.return_value = MIXED_DIFF
result = get_git_diff(repo_directory=None, uncommitted_changes=True)
key = list(result.keys())[0]
# Only the added line (c = 30) at target line 7
assert result[key] == [7]
@patch("codeflash.code_utils.git_utils.git.Repo")
def test_multi_hunk_deletion_only_returns_all_hunk_starts(self, mock_repo_cls):
repo = mock_repo_cls.return_value
repo.head.commit.hexsha = "abc123"
repo.working_dir = "/repo"
repo.git.diff.return_value = MULTI_HUNK_DELETION_ONLY_DIFF
result = get_git_diff(repo_directory=None, uncommitted_changes=True)
key = list(result.keys())[0]
# Two hunks with target_start 5 and 18
assert result[key] == [5, 18]
@patch("codeflash.code_utils.git_utils.git.Repo")
def test_deletion_only_diff_does_not_return_empty_list(self, mock_repo_cls):
repo = mock_repo_cls.return_value
repo.head.commit.hexsha = "abc123"
repo.working_dir = "/repo"
repo.git.diff.return_value = DELETION_ONLY_DIFF
result = get_git_diff(repo_directory=None, uncommitted_changes=True)
key = list(result.keys())[0]
# Without the fix, this would be an empty list, causing the function
# to be missed during discovery.
assert len(result[key]) > 0
if __name__ == "__main__":
unittest.main()
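
The diff-parsing behavior under test (collect target-file line numbers of added lines, falling back to the hunk's target start for deletion-only hunks) can be sketched as follows. This is an illustrative reimplementation, not the actual `get_git_diff` internals:

```python
import re

HUNK_RE = re.compile(r"^@@ -\d+(?:,\d+)? \+(\d+)(?:,\d+)? @@")

def changed_target_lines(diff_text: str) -> list[int]:
    """Sketch: target-file line numbers of '+' lines; a hunk with no
    additions contributes its target start line instead."""
    result: list[int] = []
    target_line = 0
    hunk_start = None
    hunk_had_addition = False

    def close_hunk() -> None:
        nonlocal hunk_start, hunk_had_addition
        if hunk_start is not None and not hunk_had_addition:
            result.append(hunk_start)  # deletion-only hunk: report its start
        hunk_start, hunk_had_addition = None, False

    for line in diff_text.splitlines():
        m = HUNK_RE.match(line)
        if m:
            close_hunk()
            hunk_start = int(m.group(1))
            target_line = hunk_start
            continue
        if line.startswith(("---", "+++")):
            continue                   # file headers, not content lines
        if line.startswith("+"):
            result.append(target_line)
            hunk_had_addition = True
            target_line += 1
        elif line.startswith("-"):
            pass                       # removed line: target file does not advance
        else:
            target_line += 1           # context line
    close_hunk()
    return result
```

This mirrors the fix the tests above lock in: without the deletion-only fallback, a function that only lost lines would report no changed line numbers and be skipped during discovery.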

View file

@@ -12,7 +12,7 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodePosition, FunctionParent, TestFile, TestFiles, TestingMode, TestType
from codeflash.optimization.optimizer import Optimizer
from codeflash.verification.equivalence import compare_test_results
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
from codeflash.languages.python.instrument_codeflash_capture import instrument_codeflash_capture
# Used by cli instrumentation
codeflash_wrap_string = """def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):

View file

@@ -3,7 +3,7 @@ from pathlib import Path
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
from codeflash.languages.python.instrument_codeflash_capture import instrument_codeflash_capture
def test_add_codeflash_capture():
@@ -354,3 +354,179 @@ class AnotherHelperClass:
test_path.unlink(missing_ok=True)
helper1_path.unlink(missing_ok=True)
helper2_path.unlink(missing_ok=True)
def test_dataclass_no_init_skipped():
"""Dataclasses have auto-generated __init__ not visible in AST. Instrumentation should skip them."""
original_code = """
from dataclasses import dataclass
@dataclass
class MyDataClass:
x: int
y: str
def target_function(self):
return self.x + len(self.y)
"""
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
test_path.write_text(original_code)
function = FunctionToOptimize(
function_name="target_function", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="MyDataClass")]
)
try:
instrument_codeflash_capture(function, {}, test_path.parent)
modified_code = test_path.read_text()
# Dataclass should NOT get a synthetic __init__ injected
assert "super().__init__" not in modified_code
assert "codeflash_capture" not in modified_code
finally:
test_path.unlink(missing_ok=True)
def test_dataclass_with_call_syntax_skipped():
"""@dataclass(frozen=True) should also be skipped."""
original_code = """
from dataclasses import dataclass
@dataclass(frozen=True)
class FrozenData:
value: int
def compute(self):
return self.value * 2
"""
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
test_path.write_text(original_code)
function = FunctionToOptimize(
function_name="compute", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="FrozenData")]
)
try:
instrument_codeflash_capture(function, {}, test_path.parent)
modified_code = test_path.read_text()
assert "super().__init__" not in modified_code
assert "codeflash_capture" not in modified_code
finally:
test_path.unlink(missing_ok=True)
def test_namedtuple_no_init_skipped():
"""NamedTuples have synthesized __init__ that cannot be overwritten. Instrumentation should skip them."""
original_code = """
from typing import NamedTuple
class MyTuple(NamedTuple):
x: int
y: str
def display(self):
return f"{self.x}: {self.y}"
"""
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
test_path.write_text(original_code)
function = FunctionToOptimize(
function_name="display", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="MyTuple")]
)
try:
instrument_codeflash_capture(function, {}, test_path.parent)
modified_code = test_path.read_text()
assert "super().__init__" not in modified_code
assert "codeflash_capture" not in modified_code
finally:
test_path.unlink(missing_ok=True)
def test_module_qualified_dataclass_with_call_syntax_skipped():
"""@dataclasses.dataclass(frozen=True) — module-qualified call-style decorator — should be skipped."""
original_code = """
import dataclasses
@dataclasses.dataclass(frozen=True)
class FrozenPoint:
x: int
y: int
def magnitude(self):
return (self.x ** 2 + self.y ** 2) ** 0.5
"""
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
test_path.write_text(original_code)
function = FunctionToOptimize(
function_name="magnitude", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="FrozenPoint")]
)
try:
instrument_codeflash_capture(function, {}, test_path.parent)
modified_code = test_path.read_text()
assert "super().__init__" not in modified_code
assert "codeflash_capture" not in modified_code
finally:
test_path.unlink(missing_ok=True)
def test_module_qualified_namedtuple_skipped():
"""typing.NamedTuple — module-qualified base class — should be skipped."""
original_code = """
import typing
class MyTuple(typing.NamedTuple):
x: int
y: str
def display(self):
return f"{self.x}: {self.y}"
"""
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
test_path.write_text(original_code)
function = FunctionToOptimize(
function_name="display", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="MyTuple")]
)
try:
instrument_codeflash_capture(function, {}, test_path.parent)
modified_code = test_path.read_text()
assert "super().__init__" not in modified_code
assert "codeflash_capture" not in modified_code
finally:
test_path.unlink(missing_ok=True)
def test_dataclass_with_explicit_init_still_instrumented():
"""A dataclass that defines its own __init__ should still be instrumented normally."""
original_code = """
from dataclasses import dataclass
@dataclass
class CustomInit:
x: int
def __init__(self, x: int):
self.x = x * 2
def target(self):
return self.x
"""
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
test_path.write_text(original_code)
function = FunctionToOptimize(
function_name="target", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="CustomInit")]
)
try:
instrument_codeflash_capture(function, {}, test_path.parent)
modified_code = test_path.read_text()
# Should be instrumented because it has an explicit __init__
assert "codeflash_capture" in modified_code
# Should NOT have super().__init__ injected (it has its own __init__)
assert "super().__init__" not in modified_code
finally:
test_path.unlink(missing_ok=True)
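The skip logic these tests pin down rests on a simple AST fact: a `@dataclass`'s synthesized `__init__` never appears in the parsed class body, so the decorator list is the only signal available to the instrumenter. A minimal, self-contained illustration:

```python
import ast

source = """
from dataclasses import dataclass

@dataclass
class Point:
    x: int
    y: int
"""

tree = ast.parse(source)
cls = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))

# The synthesized __init__ is generated at class-creation time, not in source,
# so it is invisible to any AST walk over the class body.
has_init = any(
    isinstance(n, ast.FunctionDef) and n.name == "__init__" for n in cls.body
)
decorators = [ast.unparse(d) for d in cls.decorator_list]
print(has_init, decorators)  # False ['dataclass']
```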


@ -14,7 +14,7 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent, TestFile, TestFiles, TestingMode, TestType, VerificationType
from codeflash.optimization.optimizer import Optimizer
from codeflash.verification.equivalence import compare_test_results
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
from codeflash.languages.python.instrument_codeflash_capture import instrument_codeflash_capture
# Used by aiservice instrumentation
behavior_logging_code = """


@ -0,0 +1,96 @@
from __future__ import annotations
from pathlib import Path
from codeflash.languages.code_replacer import get_optimized_code_for_module
from codeflash.models.models import CodeString, CodeStringsMarkdown
def _make_markdown(*code_strings: tuple[str, str | None]) -> CodeStringsMarkdown:
return CodeStringsMarkdown(
code_strings=[
CodeString(code=code, file_path=Path(path) if path else None) for code, path in code_strings
]
)
# --- Exact path match ---
def test_exact_path_match_single_file() -> None:
md = _make_markdown(("def foo(): pass", "src/utils.py"))
assert get_optimized_code_for_module(Path("src/utils.py"), md) == "def foo(): pass"
def test_exact_path_match_picks_correct_file() -> None:
md = _make_markdown(
("def foo(): pass", "src/utils.py"),
("def bar(): pass", "src/helpers.py"),
)
assert get_optimized_code_for_module(Path("src/helpers.py"), md) == "def bar(): pass"
def test_exact_match_preferred_over_basename() -> None:
md = _make_markdown(
("def wrong(): pass", "other/utils.py"),
("def correct(): pass", "src/utils.py"),
)
assert get_optimized_code_for_module(Path("src/utils.py"), md) == "def correct(): pass"
# --- Fallback 1: single None-path block ---
def test_none_path_fallback_single_block() -> None:
md = _make_markdown(("def foo(): pass", None))
assert get_optimized_code_for_module(Path("src/utils.py"), md) == "def foo(): pass"
def test_none_path_fallback_ignored_when_named_blocks_exist() -> None:
md = _make_markdown(("def foo(): pass", None), ("def bar(): pass", "src/other.py"))
# The None-path fallback applies only when None is the sole key in the block dict (no named blocks)
assert get_optimized_code_for_module(Path("src/utils.py"), md) == ""
# --- Fallback 2: basename match ---
def test_basename_fallback_different_directory() -> None:
md = _make_markdown(("def optimized(): pass", "wrong/dir/utils.py"))
assert get_optimized_code_for_module(Path("src/utils.py"), md) == "def optimized(): pass"
def test_basename_fallback_skips_non_matching_context_files() -> None:
"""Target file returned alongside unrelated context files — basename picks the right one."""
md = _make_markdown(
("import logging", "codeflash/cli_cmds/console.py"),
("def optimized(): pass", "other/version.py"),
)
assert get_optimized_code_for_module(Path("codeflash/version.py"), md) == "def optimized(): pass"
def test_basename_fallback_ambiguous_returns_empty() -> None:
md = _make_markdown(
("def foo(): pass", "a/utils.py"),
("def bar(): pass", "b/utils.py"),
)
assert get_optimized_code_for_module(Path("src/utils.py"), md) == ""
def test_no_match_returns_empty() -> None:
md = _make_markdown(("def foo(): pass", "src/helpers.py"))
assert get_optimized_code_for_module(Path("src/utils.py"), md) == ""
def test_empty_markdown_returns_empty() -> None:
md = CodeStringsMarkdown(code_strings=[])
assert get_optimized_code_for_module(Path("src/utils.py"), md) == ""
def test_context_files_only_returns_empty() -> None:
"""Reproduces the CI issue: LLM returns only context files, not the target."""
md = _make_markdown(
("import logging\nlogger = logging.getLogger()", "codeflash/cli_cmds/console.py"),
("class AiServiceClient: ...", "codeflash/api/aiservice.py"),
)
assert get_optimized_code_for_module(Path("codeflash/version.py"), md) == ""
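The matching chain these eight tests cover can be sketched as a plain function over a path-to-code mapping (a simplification for illustration; the real helper is `get_optimized_code_for_module` in `codeflash.languages.code_replacer`):

```python
from __future__ import annotations

from pathlib import Path


def pick_code(target: Path, blocks: dict[str | None, str]) -> str:
    """Sketch of the fallback chain: exact path -> lone pathless block -> unique basename."""
    # 1. Exact path match always wins.
    if str(target) in blocks:
        return blocks[str(target)]
    # 2. A single block with no path at all is assumed to be the target.
    if set(blocks) == {None}:
        return blocks[None]
    # 3. Exactly one block whose basename matches; ambiguity returns nothing.
    matches = [code for path, code in blocks.items() if path and Path(path).name == target.name]
    if len(matches) == 1:
        return matches[0]
    return ""
```

The deliberately conservative ambiguous case (two `utils.py` candidates yield `""`) is what keeps the optimizer from silently patching the wrong file.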


@ -304,7 +304,7 @@ describe('fibonacci', () => {
"""Test FunctionOptimizer can be instantiated for JavaScript."""
skip_if_js_not_supported()
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.languages.function_optimizer import FunctionOptimizer
src_file = js_project / "utils.js"
functions = find_all_functions_in_file(src_file)
@ -337,7 +337,7 @@ describe('fibonacci', () => {
"""Test FunctionOptimizer can be instantiated for TypeScript."""
skip_if_js_not_supported()
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.languages.function_optimizer import FunctionOptimizer
src_file = ts_project / "utils.ts"
functions = find_all_functions_in_file(src_file)


@ -443,10 +443,10 @@ function add(a, b {
class TestNormalizeCode:
"""Tests for normalize_code method."""
"""Tests for normalize_code method using tree-sitter normalizer."""
def test_removes_comments(self, js_support):
"""Test that single-line comments are removed."""
"""Test that comments are absent from normalized output."""
code = """
function add(a, b) {
// Add two numbers
@ -455,19 +455,43 @@ function add(a, b) {
"""
normalized = js_support.normalize_code(code)
assert "// Add two numbers" not in normalized
assert "return a + b" in normalized
assert "Add two numbers" not in normalized
def test_preserves_functionality(self, js_support):
"""Test that code functionality is preserved."""
code = """
function add(a, b) {
// Comment
return a + b;
def test_same_logic_different_vars_are_equal(self, js_support):
"""Test that two functions with same logic but different variable names normalize identically."""
code1 = """
function process(items) {
const result = [];
for (const item of items) {
result.push(item * 2);
}
return result;
}
"""
normalized = js_support.normalize_code(code)
assert "function add" in normalized
assert "return" in normalized
code2 = """
function process(items) {
const output = [];
for (const val of items) {
output.push(val * 2);
}
return output;
}
"""
assert js_support.normalize_code(code1) == js_support.normalize_code(code2)
def test_different_logic_not_equal(self, js_support):
"""Test that two functions with different logic produce different normalized forms."""
code1 = """
function compute(x) {
return x + 1;
}
"""
code2 = """
function compute(x) {
return x * 2;
}
"""
assert js_support.normalize_code(code1) != js_support.normalize_code(code2)
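The equality these tests rely on (same logic, different variable names) can be approximated in any language by renaming identifiers to positional placeholders in traversal order. A toy Python analogue of the idea follows; the actual JS normalizer is tree-sitter based and scope-aware, which this sketch is not:

```python
import ast


def normalize(source: str) -> str:
    """Rename every identifier to a positional placeholder, then unparse.

    Toy version: not scope-aware, renames builtins too. Good enough to show
    why alpha-equivalent functions collapse to the same string.
    """
    tree = ast.parse(source)
    mapping: dict[str, str] = {}
    for node in ast.walk(tree):
        if isinstance(node, ast.Name):
            node.id = mapping.setdefault(node.id, f"v{len(mapping)}")
        elif isinstance(node, ast.arg):
            node.arg = mapping.setdefault(node.arg, f"v{len(mapping)}")
        elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            node.name = mapping.setdefault(node.name, f"v{len(mapping)}")
    return ast.unparse(tree)


a = normalize("def process(items):\n    result = [i * 2 for i in items]\n    return result\n")
b = normalize("def process(items):\n    output = [x * 2 for x in items]\n    return output\n")
print(a == b)  # True
```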
class TestExtractCodeContext:


@ -1,35 +1,31 @@
"""Tests for JavaScript/Jest test runner functionality."""
import sys
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
from unittest.mock import patch, MagicMock
import pytest
class TestJestRootsConfiguration:
"""Tests for Jest --roots flag handling."""
"""Tests for Jest runtime config creation when test files are outside the project root."""
def test_behavioral_tests_adds_roots_for_test_directories(self):
"""Test that run_jest_behavioral_tests adds --roots for test directories."""
from codeflash.languages.javascript.test_runner import run_jest_behavioral_tests
def test_no_runtime_config_when_tests_inside_project_root(self):
"""Test that no runtime config is created when test files are inside the project root."""
from codeflash.languages.javascript.test_runner import clear_created_config_files, get_created_config_files, run_jest_behavioral_tests
from codeflash.models.models import TestFile, TestFiles
from codeflash.models.test_type import TestType
# Create mock test files in a test directory
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir).resolve()
test_dir = tmpdir_path / "test"
test_dir.mkdir()
# Create package.json to simulate a Node project
(tmpdir_path / "package.json").write_text('{"name": "test"}')
# Create mock test files
test_file1 = test_dir / "test_func__unit_test_0.test.ts"
test_file2 = test_dir / "test_func__unit_test_1.test.ts"
test_file1.write_text("// test 1")
test_file2.write_text("// test 2")
mock_test_files = TestFiles(
test_files=[
@ -39,16 +35,11 @@ class TestJestRootsConfiguration:
benchmarking_file_path=test_file1,
test_type=TestType.GENERATED_REGRESSION,
),
TestFile(
original_file_path=test_file2,
instrumented_behavior_file_path=test_file2,
benchmarking_file_path=test_file2,
test_type=TestType.GENERATED_REGRESSION,
),
]
)
# Mock subprocess.run to capture the command
clear_created_config_files()
with patch("subprocess.run") as mock_run:
mock_result = MagicMock()
mock_result.stdout = ""
@ -58,42 +49,96 @@ class TestJestRootsConfiguration:
try:
run_jest_behavioral_tests(
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, project_root=tmpdir_path
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
project_root=tmpdir_path,
)
except Exception:
pass # Expected to fail since no real Jest
pass
# Verify the command included --roots
if mock_run.called:
call_args = mock_run.call_args
cmd = call_args[0][0]
cmd = mock_run.call_args[0][0]
# No --roots flags should be present
assert "--roots" not in cmd, "Should not have --roots flags when tests are inside project root"
# No runtime config should have been created
runtime_configs = [f for f in get_created_config_files() if "codeflash.runtime" in f.name]
assert len(runtime_configs) == 0, "Should not create runtime config when tests are inside project root"
# Find --roots flags in the command
roots_flags = []
for i, arg in enumerate(cmd):
if arg == "--roots" and i + 1 < len(cmd):
roots_flags.append(cmd[i + 1])
clear_created_config_files()
# Should have added the test directory as a root
assert len(roots_flags) > 0, "Expected --roots flag in Jest command"
assert str(test_dir) in roots_flags or any(str(test_dir) in root for root in roots_flags), (
f"Expected test directory {test_dir} in --roots flags: {roots_flags}"
)
def test_benchmarking_tests_adds_roots_for_test_directories(self):
"""Test that run_jest_benchmarking_tests adds --roots for test directories."""
from codeflash.languages.javascript.test_runner import run_jest_benchmarking_tests
def test_behavioral_tests_creates_runtime_config_for_external_tests(self):
"""Test that run_jest_behavioral_tests creates a runtime config when tests are outside the project root."""
from codeflash.languages.javascript.test_runner import clear_created_config_files, get_created_config_files, run_jest_behavioral_tests
from codeflash.models.models import TestFile, TestFiles
from codeflash.models.test_type import TestType
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir).resolve()
test_dir = tmpdir_path / "test"
test_dir.mkdir()
with tempfile.TemporaryDirectory() as project_dir, tempfile.TemporaryDirectory() as external_dir:
project_path = Path(project_dir).resolve()
external_path = Path(external_dir).resolve()
(tmpdir_path / "package.json").write_text('{"name": "test"}')
(project_path / "package.json").write_text('{"name": "test"}')
test_file = test_dir / "test_func__perf_test_0.test.ts"
test_file = external_path / "test_func__unit_test_0.test.ts"
test_file.write_text("// test 1")
mock_test_files = TestFiles(
test_files=[
TestFile(
original_file_path=test_file,
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
),
]
)
clear_created_config_files()
with patch("subprocess.run") as mock_run:
mock_result = MagicMock()
mock_result.stdout = ""
mock_result.stderr = ""
mock_result.returncode = 1
mock_run.return_value = mock_result
try:
run_jest_behavioral_tests(
test_paths=mock_test_files,
test_env={},
cwd=project_path,
project_root=project_path,
)
except Exception:
pass
if mock_run.called:
cmd = mock_run.call_args[0][0]
config_args = [arg for arg in cmd if arg.startswith("--config=")]
assert any("codeflash.runtime" in arg for arg in config_args), (
f"Expected runtime config in --config flag, got: {config_args}"
)
runtime_configs = [f for f in get_created_config_files() if "codeflash.runtime" in f.name]
assert len(runtime_configs) == 1, f"Expected 1 runtime config, got {len(runtime_configs)}"
config_content = runtime_configs[0].read_text(encoding="utf-8")
assert str(external_path) in config_content, "Runtime config should contain external test directory"
clear_created_config_files()
def test_benchmarking_tests_creates_runtime_config_for_external_tests(self):
"""Test that run_jest_benchmarking_tests creates a runtime config when tests are outside the project root."""
from codeflash.languages.javascript.test_runner import clear_created_config_files, get_created_config_files, run_jest_benchmarking_tests
from codeflash.models.models import TestFile, TestFiles
from codeflash.models.test_type import TestType
with tempfile.TemporaryDirectory() as project_dir, tempfile.TemporaryDirectory() as external_dir:
project_path = Path(project_dir).resolve()
external_path = Path(external_dir).resolve()
(project_path / "package.json").write_text('{"name": "test"}')
test_file = external_path / "test_func__perf_test_0.test.ts"
test_file.write_text("// perf test")
mock_test_files = TestFiles(
@ -103,10 +148,12 @@ class TestJestRootsConfiguration:
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
)
),
]
)
clear_created_config_files()
with patch("subprocess.run") as mock_run:
mock_result = MagicMock()
mock_result.stdout = ""
@ -116,36 +163,32 @@ class TestJestRootsConfiguration:
try:
run_jest_benchmarking_tests(
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, project_root=tmpdir_path
test_paths=mock_test_files,
test_env={},
cwd=project_path,
project_root=project_path,
)
except Exception:
pass
if mock_run.called:
call_args = mock_run.call_args
cmd = call_args[0][0]
runtime_configs = [f for f in get_created_config_files() if "codeflash.runtime" in f.name]
assert len(runtime_configs) == 1, "Expected runtime config for external test files"
roots_flags = []
for i, arg in enumerate(cmd):
if arg == "--roots" and i + 1 < len(cmd):
roots_flags.append(cmd[i + 1])
clear_created_config_files()
assert len(roots_flags) > 0, "Expected --roots flag in Jest command"
def test_line_profile_tests_adds_roots_for_test_directories(self):
"""Test that run_jest_line_profile_tests adds --roots for test directories."""
from codeflash.languages.javascript.test_runner import run_jest_line_profile_tests
def test_line_profile_tests_creates_runtime_config_for_external_tests(self):
"""Test that run_jest_line_profile_tests creates a runtime config when tests are outside the project root."""
from codeflash.languages.javascript.test_runner import clear_created_config_files, get_created_config_files, run_jest_line_profile_tests
from codeflash.models.models import TestFile, TestFiles
from codeflash.models.test_type import TestType
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
test_dir = tmpdir_path / "test"
test_dir.mkdir()
with tempfile.TemporaryDirectory() as project_dir, tempfile.TemporaryDirectory() as external_dir:
project_path = Path(project_dir).resolve()
external_path = Path(external_dir).resolve()
(tmpdir_path / "package.json").write_text('{"name": "test"}')
(project_path / "package.json").write_text('{"name": "test"}')
test_file = test_dir / "test_func__line_profile.test.ts"
test_file = external_path / "test_func__line_profile.test.ts"
test_file.write_text("// line profile test")
mock_test_files = TestFiles(
@ -155,10 +198,12 @@ class TestJestRootsConfiguration:
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
)
),
]
)
clear_created_config_files()
with patch("subprocess.run") as mock_run:
mock_result = MagicMock()
mock_result.stdout = ""
@ -168,84 +213,18 @@ class TestJestRootsConfiguration:
try:
run_jest_line_profile_tests(
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, project_root=tmpdir_path
test_paths=mock_test_files,
test_env={},
cwd=project_path,
project_root=project_path,
)
except Exception:
pass
if mock_run.called:
call_args = mock_run.call_args
cmd = call_args[0][0]
runtime_configs = [f for f in get_created_config_files() if "codeflash.runtime" in f.name]
assert len(runtime_configs) == 1, "Expected runtime config for external test files"
roots_flags = []
for i, arg in enumerate(cmd):
if arg == "--roots" and i + 1 < len(cmd):
roots_flags.append(cmd[i + 1])
assert len(roots_flags) > 0, "Expected --roots flag in Jest command"
def test_multiple_test_directories_all_added_to_roots(self):
"""Test that multiple test directories are all added as --roots."""
from codeflash.languages.javascript.test_runner import run_jest_behavioral_tests
from codeflash.models.models import TestFile, TestFiles
from codeflash.models.test_type import TestType
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
test_dir1 = tmpdir_path / "test"
test_dir2 = tmpdir_path / "spec"
test_dir1.mkdir()
test_dir2.mkdir()
(tmpdir_path / "package.json").write_text('{"name": "test"}')
test_file1 = test_dir1 / "test_func__unit_test_0.test.ts"
test_file2 = test_dir2 / "test_func__unit_test_1.test.ts"
test_file1.write_text("// test 1")
test_file2.write_text("// test 2")
mock_test_files = TestFiles(
test_files=[
TestFile(
original_file_path=test_file1,
instrumented_behavior_file_path=test_file1,
benchmarking_file_path=test_file1,
test_type=TestType.GENERATED_REGRESSION,
),
TestFile(
original_file_path=test_file2,
instrumented_behavior_file_path=test_file2,
benchmarking_file_path=test_file2,
test_type=TestType.GENERATED_REGRESSION,
),
]
)
with patch("subprocess.run") as mock_run:
mock_result = MagicMock()
mock_result.stdout = ""
mock_result.stderr = ""
mock_result.returncode = 1
mock_run.return_value = mock_result
try:
run_jest_behavioral_tests(
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, project_root=tmpdir_path
)
except Exception:
pass
if mock_run.called:
call_args = mock_run.call_args
cmd = call_args[0][0]
roots_flags = []
for i, arg in enumerate(cmd):
if arg == "--roots" and i + 1 < len(cmd):
roots_flags.append(cmd[i + 1])
# Should have two --roots flags (one for each directory)
assert len(roots_flags) == 2, f"Expected 2 --roots flags, got {len(roots_flags)}"
clear_created_config_files()
class TestVitestTimeoutConfiguration:
@ -274,7 +253,7 @@ class TestVitestTimeoutConfiguration:
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
)
),
]
)
@ -302,9 +281,7 @@ class TestVitestTimeoutConfiguration:
# Subprocess timeout should be at least 120 seconds (minimum)
# or 10x the per-test timeout (150 seconds)
assert subprocess_timeout >= 120, f"Expected subprocess timeout >= 120s, got {subprocess_timeout}s"
assert subprocess_timeout >= 15 * 10, (
f"Expected subprocess timeout >= 150s (10x per-test), got {subprocess_timeout}s"
)
assert subprocess_timeout >= 15 * 10, f"Expected subprocess timeout >= 150s (10x per-test), got {subprocess_timeout}s"
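The two assertions above encode a single rule, worth stating directly: the subprocess timeout is floored at 120 seconds and must also cover at least ten per-test timeouts. A one-line sketch (the function name is illustrative):

```python
def subprocess_timeout(per_test_timeout_s: float) -> float:
    """Floor of 120 s, and at least 10x the per-test timeout, per the assertions above."""
    return max(120.0, 10.0 * per_test_timeout_s)


print(subprocess_timeout(15))  # 150.0
```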
def test_vitest_line_profile_subprocess_timeout_larger_than_test_timeout(self):
"""Test that subprocess timeout is larger than per-test timeout for Vitest line profile tests."""
@ -329,7 +306,7 @@ class TestVitestTimeoutConfiguration:
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
)
),
]
)
@ -341,7 +318,11 @@ class TestVitestTimeoutConfiguration:
mock_run.return_value = mock_result
run_vitest_line_profile_tests(
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, timeout=15, project_root=tmpdir_path
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
timeout=15,
project_root=tmpdir_path,
)
assert mock_run.called
@ -373,7 +354,7 @@ class TestVitestTimeoutConfiguration:
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
)
),
]
)
@ -386,7 +367,10 @@ class TestVitestTimeoutConfiguration:
# Run without specifying a timeout
run_vitest_behavioral_tests(
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, project_root=tmpdir_path
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
project_root=tmpdir_path,
)
assert mock_run.called
@ -428,7 +412,7 @@ class TestVitestInternalLoopingConfiguration:
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
)
),
]
)
@ -486,7 +470,7 @@ class TestVitestInternalLoopingConfiguration:
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
)
),
]
)
@ -533,7 +517,13 @@ class TestBundlerModuleResolutionFix:
tmpdir_path = Path(tmpdir)
# Create tsconfig with bundler moduleResolution
tsconfig = {"compilerOptions": {"moduleResolution": "bundler", "module": "preserve", "target": "ES2022"}}
tsconfig = {
"compilerOptions": {
"moduleResolution": "bundler",
"module": "preserve",
"target": "ES2022",
}
}
(tmpdir_path / "tsconfig.json").write_text(json.dumps(tsconfig))
assert _detect_bundler_module_resolution(tmpdir_path) is True
@ -548,7 +538,12 @@ class TestBundlerModuleResolutionFix:
tmpdir_path = Path(tmpdir)
# Create tsconfig with Node moduleResolution
tsconfig = {"compilerOptions": {"moduleResolution": "Node", "module": "ESNext"}}
tsconfig = {
"compilerOptions": {
"moduleResolution": "Node",
"module": "ESNext",
}
}
(tmpdir_path / "tsconfig.json").write_text(json.dumps(tsconfig))
assert _detect_bundler_module_resolution(tmpdir_path) is False
@ -573,11 +568,21 @@ class TestBundlerModuleResolutionFix:
# Create a base config with bundler in a subdirectory (simulating node_modules)
node_modules = tmpdir_path / "node_modules" / "@myorg" / "tsconfig"
node_modules.mkdir(parents=True)
base_tsconfig = {"compilerOptions": {"moduleResolution": "bundler", "module": "preserve"}}
base_tsconfig = {
"compilerOptions": {
"moduleResolution": "bundler",
"module": "preserve",
}
}
(node_modules / "tsconfig.json").write_text(json.dumps(base_tsconfig))
# Create a project tsconfig that extends the base
project_tsconfig = {"extends": "@myorg/tsconfig/tsconfig.json", "compilerOptions": {"target": "ES2022"}}
project_tsconfig = {
"extends": "@myorg/tsconfig/tsconfig.json",
"compilerOptions": {
"target": "ES2022",
}
}
(tmpdir_path / "tsconfig.json").write_text(json.dumps(project_tsconfig))
# Should detect bundler from extended config
@ -594,7 +599,11 @@ class TestBundlerModuleResolutionFix:
# Create original tsconfig
original_tsconfig = {
"compilerOptions": {"moduleResolution": "bundler", "module": "preserve", "target": "ES2022"},
"compilerOptions": {
"moduleResolution": "bundler",
"module": "preserve",
"target": "ES2022",
},
"include": ["src/**/*.ts"],
"exclude": ["node_modules"],
}
@ -641,7 +650,12 @@ class TestBundlerModuleResolutionFix:
tmpdir_path = Path(tmpdir)
# Create tsconfig with bundler
tsconfig = {"compilerOptions": {"moduleResolution": "bundler", "module": "preserve"}}
tsconfig = {
"compilerOptions": {
"moduleResolution": "bundler",
"module": "preserve",
}
}
(tmpdir_path / "tsconfig.json").write_text(json.dumps(tsconfig))
(tmpdir_path / "package.json").write_text('{"name": "test"}')
@ -662,7 +676,12 @@ class TestBundlerModuleResolutionFix:
tmpdir_path = Path(tmpdir)
# Create tsconfig with Node moduleResolution
tsconfig = {"compilerOptions": {"moduleResolution": "Node", "module": "ESNext"}}
tsconfig = {
"compilerOptions": {
"moduleResolution": "Node",
"module": "ESNext",
}
}
(tmpdir_path / "tsconfig.json").write_text(json.dumps(tsconfig))
(tmpdir_path / "package.json").write_text('{"name": "test"}')
@ -720,7 +739,7 @@ class TestBundledJestReporter:
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
)
),
]
)
@ -733,7 +752,10 @@ class TestBundledJestReporter:
try:
run_jest_behavioral_tests(
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, project_root=tmpdir_path
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
project_root=tmpdir_path,
)
except Exception:
pass
@ -741,9 +763,7 @@ class TestBundledJestReporter:
if mock_run.called:
cmd = mock_run.call_args[0][0]
reporter_args = [a for a in cmd if "--reporters=" in a and "jest-reporter" in a]
assert len(reporter_args) == 1, (
f"Expected exactly one codeflash/jest-reporter flag, got: {reporter_args}"
)
assert len(reporter_args) == 1, f"Expected exactly one codeflash/jest-reporter flag, got: {reporter_args}"
assert reporter_args[0] == "--reporters=codeflash/jest-reporter"
# Must NOT reference jest-junit
jest_junit_args = [a for a in cmd if "jest-junit" in a]
@ -770,7 +790,7 @@ class TestBundledJestReporter:
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
)
),
]
)
@ -783,7 +803,10 @@ class TestBundledJestReporter:
try:
run_jest_benchmarking_tests(
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, project_root=tmpdir_path
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
project_root=tmpdir_path,
)
except Exception:
pass
@ -814,7 +837,7 @@ class TestBundledJestReporter:
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
)
),
]
)
@ -827,7 +850,10 @@ class TestBundledJestReporter:
try:
run_jest_line_profile_tests(
test_paths=mock_test_files, test_env={}, cwd=tmpdir_path, project_root=tmpdir_path
test_paths=mock_test_files,
test_env={},
cwd=tmpdir_path,
project_root=tmpdir_path,
)
except Exception:
pass
@ -837,6 +863,7 @@ class TestBundledJestReporter:
reporter_args = [a for a in cmd if "--reporters=codeflash/jest-reporter" in a]
assert len(reporter_args) == 1
@pytest.mark.skipif(sys.platform == "win32", reason="Node.js subprocess pipe behavior unreliable on Windows CI")
def test_reporter_produces_valid_junit_xml(self):
"""The reporter JS should produce JUnit XML parseable by junitparser."""
import subprocess
@ -848,16 +875,17 @@ class TestBundledJestReporter:
# Create a Node.js script that exercises the reporter with mock data
test_script = Path(tmpdir) / "test_reporter.js"
# Use forward slashes to avoid Windows backslash escape issues in JS strings
reporter_path_js = reporter_path.as_posix()
output_file_js = output_file.as_posix()
test_script.write_text(f"""
// Set env vars BEFORE requiring reporter (matches real Jest behavior)
process.env.JEST_JUNIT_OUTPUT_FILE = '{output_file.as_posix()}';
process.env.JEST_JUNIT_OUTPUT_FILE = '{output_file_js}';
process.env.JEST_JUNIT_CLASSNAME = '{{filepath}}';
process.env.JEST_JUNIT_SUITE_NAME = '{{filepath}}';
process.env.JEST_JUNIT_ADD_FILE_ATTRIBUTE = 'true';
process.env.JEST_JUNIT_INCLUDE_CONSOLE_OUTPUT = 'true';
const Reporter = require('{reporter_path.as_posix()}');
const Reporter = require('{reporter_path_js}');
// Mock Jest globalConfig
const globalConfig = {{ rootDir: '/tmp/project' }};
@ -902,9 +930,14 @@ reporter.onTestFileResult(null, results.testResults[0], null);
reporter.onRunComplete([], results);
console.log('OK');
""")
""", encoding="utf-8")
result = subprocess.run(["node", str(test_script)], capture_output=True, text=True, timeout=10)
result = subprocess.run(
["node", str(test_script)],
capture_output=True,
text=True,
timeout=10,
)
assert result.returncode == 0, f"Reporter script failed: {result.stderr}"
assert output_file.exists(), "Reporter did not create output file"
@ -956,6 +989,7 @@ console.log('OK');
assert exports["./jest-reporter"]["require"] == "./runtime/jest-reporter.js"
class TestUnsupportedFrameworkError:
"""Tests for clear error on unsupported test frameworks."""
@ -965,7 +999,12 @@ class TestUnsupportedFrameworkError:
support = JavaScriptSupport()
with pytest.raises(NotImplementedError, match="not yet supported"):
support.run_behavioral_tests(test_paths=MagicMock(), test_env={}, cwd=Path(), test_framework="tap")
support.run_behavioral_tests(
test_paths=MagicMock(),
test_env={},
cwd=Path("."),
test_framework="tap",
)
def test_unknown_framework_raises_error_benchmarking(self):
"""run_benchmarking_tests should raise NotImplementedError for unknown frameworks."""
@ -973,7 +1012,12 @@ class TestUnsupportedFrameworkError:
support = JavaScriptSupport()
with pytest.raises(NotImplementedError, match="not yet supported"):
support.run_benchmarking_tests(test_paths=MagicMock(), test_env={}, cwd=Path(), test_framework="tap")
support.run_benchmarking_tests(
test_paths=MagicMock(),
test_env={},
cwd=Path("."),
test_framework="tap",
)
def test_unknown_framework_raises_error_line_profile(self):
"""run_line_profile_tests should raise NotImplementedError for unknown frameworks."""
@ -981,27 +1025,42 @@ class TestUnsupportedFrameworkError:
support = JavaScriptSupport()
with pytest.raises(NotImplementedError, match="not yet supported"):
support.run_line_profile_tests(test_paths=MagicMock(), test_env={}, cwd=Path(), test_framework="tap")
support.run_line_profile_tests(
test_paths=MagicMock(),
test_env={},
cwd=Path("."),
test_framework="tap",
)
def test_jest_framework_does_not_raise_not_implemented(self):
"""Jest framework should NOT raise NotImplementedError."""
"""jest framework should NOT raise NotImplementedError."""
from codeflash.languages.javascript.support import JavaScriptSupport
support = JavaScriptSupport()
try:
support.run_behavioral_tests(test_paths=MagicMock(), test_env={}, cwd=Path(), test_framework="jest")
support.run_behavioral_tests(
test_paths=MagicMock(),
test_env={},
cwd=Path("."),
test_framework="jest",
)
except NotImplementedError:
pytest.fail("jest framework should not raise NotImplementedError")
except Exception:
pass # Other exceptions are fine — Jest isn't installed in test env
def test_mocha_framework_does_not_raise_not_implemented(self):
"""Mocha framework should NOT raise NotImplementedError."""
"""mocha framework should NOT raise NotImplementedError."""
from codeflash.languages.javascript.support import JavaScriptSupport
support = JavaScriptSupport()
try:
support.run_behavioral_tests(test_paths=MagicMock(), test_env={}, cwd=Path(), test_framework="mocha")
support.run_behavioral_tests(
test_paths=MagicMock(),
test_env={},
cwd=Path("."),
test_framework="mocha",
)
except NotImplementedError:
pytest.fail("mocha framework should not raise NotImplementedError")
except Exception:
pass # Other exceptions are fine; mocha isn't installed in test env


@@ -15,7 +15,7 @@ from pathlib import Path
import pytest
from codeflash.languages.base import Language
from codeflash.languages.base import FunctionFilterCriteria, Language
from codeflash.languages.code_replacer import replace_function_definitions_for_language
from codeflash.languages.current import set_current_language
from codeflash.languages.javascript.module_system import (
@@ -2264,3 +2264,150 @@ export function processNode(node: TreeNode, space: NodeSpace): number {
assert "// Optimized" in result
assert ts_support.validate_syntax(result) is True
class TestVariableAssignedFunctionReplacement:
"""Tests for replacing functions assigned to variables (function expressions, var declarations, etc.)."""
NO_EXPORT_FILTER = FunctionFilterCriteria(require_export=False, require_return=False)
def test_replace_function_expression_body(self, js_support, temp_project):
"""Test replacing an exported const-assigned function expression."""
original_source = """\
export const foo = function(x) {
return x + 1;
};
"""
file_path = temp_project / "funcs.js"
file_path.write_text(original_source, encoding="utf-8")
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
assert len(functions) == 1
func = functions[0]
assert func.function_name == "foo"
optimized_code = """\
export const foo = function(x) {
return (x + 1) | 0;
};
"""
result = js_support.replace_function(original_source, func, optimized_code)
assert "return (x + 1) | 0;" in result
assert js_support.validate_syntax(result) is True
def test_replace_function_expression_with_var(self, js_support, temp_project):
"""Test replacing a var-assigned function expression (non-exported, e.g. CommonJS)."""
original_source = """\
var foo = function(x) {
return x * 2;
};
"""
file_path = temp_project / "funcs.js"
file_path.write_text(original_source, encoding="utf-8")
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path, filter_criteria=self.NO_EXPORT_FILTER)
assert len(functions) == 1
func = functions[0]
assert func.function_name == "foo"
optimized_code = """\
var foo = function(x) {
return x << 1;
};
"""
result = js_support.replace_function(original_source, func, optimized_code)
assert "return x << 1;" in result
assert js_support.validate_syntax(result) is True
def test_replace_generator_function_expression(self, js_support, temp_project):
"""Test replacing an exported const-assigned generator function expression."""
original_source = """\
export const gen = function*(n) {
for (let i = 0; i < n; i++) {
yield i;
}
};
"""
file_path = temp_project / "generators.js"
file_path.write_text(original_source, encoding="utf-8")
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path, filter_criteria=self.NO_EXPORT_FILTER)
assert len(functions) == 1
func = functions[0]
assert func.function_name == "gen"
optimized_code = """\
export const gen = function*(n) {
let i = 0;
while (i < n) yield i++;
};
"""
result = js_support.replace_function(original_source, func, optimized_code)
assert "while (i < n) yield i++;" in result
assert js_support.validate_syntax(result) is True
def test_replace_arrow_function_multiline_declaration(self, js_support, temp_project):
"""Test replacing an arrow function where the arrow is on a different line than const."""
original_source = """\
export const calculate =
(a, b) => {
return a + b;
};
"""
file_path = temp_project / "calc.js"
file_path.write_text(original_source, encoding="utf-8")
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
assert len(functions) == 1
func = functions[0]
assert func.function_name == "calculate"
optimized_code = """\
export const calculate =
(a, b) => {
return (a + b) | 0;
};
"""
result = js_support.replace_function(original_source, func, optimized_code)
assert "return (a + b) | 0;" in result
assert js_support.validate_syntax(result) is True
def test_replace_async_arrow_function(self, js_support, temp_project):
"""Test replacing an exported const-assigned async arrow function."""
original_source = """\
export const fetchData = async (url) => {
const response = await fetch(url);
return response.json();
};
"""
file_path = temp_project / "api.js"
file_path.write_text(original_source, encoding="utf-8")
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
assert len(functions) == 1
func = functions[0]
assert func.function_name == "fetchData"
optimized_code = """\
export const fetchData = async (url) => {
return (await fetch(url)).json();
};
"""
result = js_support.replace_function(original_source, func, optimized_code)
assert "return (await fetch(url)).json();" in result
assert js_support.validate_syntax(result) is True
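Every case in this class (const-assigned expressions, `var` declarations, generators, multiline and async arrows) flows through the same replacement step: once discovery has located the function, the full variable-declarator span is swapped for the optimized source. A minimal sketch of that span replacement, assuming the discovered function carries 1-based inclusive start/end line numbers (`replace_span` is a hypothetical helper, not the codeflash API):

```python
def replace_span(source: str, start_line: int, end_line: int, new_code: str) -> str:
    """Replace lines [start_line, end_line] (1-based, inclusive) with new_code."""
    lines = source.splitlines(keepends=True)
    if not new_code.endswith("\n"):
        new_code += "\n"
    return "".join(lines[: start_line - 1]) + new_code + "".join(lines[end_line:])

# The var-assigned case from the test above: the whole declarator is the span.
src = "var foo = function(x) {\n  return x * 2;\n};\n"
out = replace_span(src, 1, 3, "var foo = function(x) {\n  return x << 1;\n};")
```

The key point the tests guard is that the span starts at the `const`/`var` keyword, not at the `function`/arrow token, so the assignment wrapper survives the swap intact.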


@@ -0,0 +1,745 @@
"""Test replace_function_and_helpers_with_optimized_code with a mock candidate (originally from mock_candidate.txt, inlined below as MOCK_CANDIDATE_MARKDOWN)."""
import tempfile
from pathlib import Path
import pytest
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.python.context.unused_definition_remover import detect_unused_helper_functions
from codeflash.models.function_types import FunctionParent
from codeflash.models.models import CodeStringsMarkdown
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
ORIGINAL_SOURCE = '''\
import contextlib
from typing import BinaryIO, TypeVar, Union
_SymbolT = TypeVar("_SymbolT", PSLiteral, PSKeyword)
PSLiteralTable = PSSymbolTable(PSLiteral)
PSKeywordTable = PSSymbolTable(PSKeyword)
LIT = PSLiteralTable.intern
KWD = PSKeywordTable.intern
KEYWORD_DICT_BEGIN = KWD(b"<<")
KEYWORD_DICT_END = KWD(b">>")
PSBaseParserToken = Union[float, bool, PSLiteral, PSKeyword, bytes]
class PSBaseParser:
def __init__(self, fp: BinaryIO) -> None:
self.fp = fp
self.eof = False
self.seek(0)
def _parse_main(self, s: bytes, i: int) -> int:
m = NONSPC.search(s, i)
if not m:
return len(s)
j = m.start(0)
c = s[j : j + 1]
self._curtokenpos = self.bufpos + j
if c == b"%":
self._curtoken = b"%"
self._parse1 = self._parse_comment
return j + 1
elif c == b"/":
self._curtoken = b""
self._parse1 = self._parse_literal
return j + 1
elif c in b"-+" or c.isdigit():
self._curtoken = c
self._parse1 = self._parse_number
return j + 1
elif c == b".":
self._curtoken = c
self._parse1 = self._parse_float
return j + 1
elif c.isalpha():
self._curtoken = c
self._parse1 = self._parse_keyword
return j + 1
elif c == b"(":
self._curtoken = b""
self.paren = 1
self._parse1 = self._parse_string
return j + 1
elif c == b"<":
self._curtoken = b""
self._parse1 = self._parse_wopen
return j + 1
elif c == b">":
self._curtoken = b""
self._parse1 = self._parse_wclose
return j + 1
elif c == b"\\x00":
return j + 1
else:
self._add_token(KWD(c))
return j + 1
def _add_token(self, obj: PSBaseParserToken) -> None:
self._tokens.append((self._curtokenpos, obj))
def _parse_comment(self, s: bytes, i: int) -> int:
m = EOL.search(s, i)
if not m:
self._curtoken += s[i:]
return len(s)
j = m.start(0)
self._curtoken += s[i:j]
self._parse1 = self._parse_main
return j
def _parse_literal(self, s: bytes, i: int) -> int:
m = END_LITERAL.search(s, i)
if not m:
self._curtoken += s[i:]
return len(s)
j = m.start(0)
self._curtoken += s[i:j]
c = s[j : j + 1]
if c == b"#":
self.hex = b""
self._parse1 = self._parse_literal_hex
return j + 1
try:
name: str | bytes = str(self._curtoken, "utf-8")
except Exception:
name = self._curtoken
self._add_token(LIT(name))
self._parse1 = self._parse_main
return j
def _parse_number(self, s: bytes, i: int) -> int:
m = END_NUMBER.search(s, i)
if not m:
self._curtoken += s[i:]
return len(s)
j = m.start(0)
self._curtoken += s[i:j]
c = s[j : j + 1]
if c == b".":
self._curtoken += b"."
self._parse1 = self._parse_float
return j + 1
with contextlib.suppress(ValueError):
self._add_token(int(self._curtoken))
self._parse1 = self._parse_main
return j
def _parse_float(self, s: bytes, i: int) -> int:
m = END_NUMBER.search(s, i)
if not m:
self._curtoken += s[i:]
return len(s)
j = m.start(0)
self._curtoken += s[i:j]
with contextlib.suppress(ValueError):
self._add_token(float(self._curtoken))
self._parse1 = self._parse_main
return j
def _parse_keyword(self, s: bytes, i: int) -> int:
m = END_KEYWORD.search(s, i)
if m:
j = m.start(0)
self._curtoken += s[i:j]
else:
self._curtoken += s[i:]
return len(s)
if self._curtoken == b"true":
token: bool | PSKeyword = True
elif self._curtoken == b"false":
token = False
else:
token = KWD(self._curtoken)
self._add_token(token)
self._parse1 = self._parse_main
return j
def _parse_string(self, s: bytes, i: int) -> int:
m = END_STRING.search(s, i)
if not m:
self._curtoken += s[i:]
return len(s)
j = m.start(0)
self._curtoken += s[i:j]
c = s[j : j + 1]
if c == b"\\\\":
self.oct = b""
self._parse1 = self._parse_string_1
return j + 1
if c == b"(":
self.paren += 1
self._curtoken += c
return j + 1
if c == b")":
self.paren -= 1
if self.paren:
self._curtoken += c
return j + 1
self._add_token(self._curtoken)
self._parse1 = self._parse_main
return j + 1
def _parse_wopen(self, s: bytes, i: int) -> int:
c = s[i : i + 1]
if c == b"<":
self._add_token(KEYWORD_DICT_BEGIN)
self._parse1 = self._parse_main
i += 1
else:
self._parse1 = self._parse_hexstring
return i
def _parse_wclose(self, s: bytes, i: int) -> int:
c = s[i : i + 1]
if c == b">":
self._add_token(KEYWORD_DICT_END)
i += 1
self._parse1 = self._parse_main
return i
'''
MOCK_CANDIDATE_MARKDOWN = '''\
```python
#!/usr/bin/env python3
import contextlib
from typing import BinaryIO, TypeVar, Union
_SymbolT = TypeVar("_SymbolT", PSLiteral, PSKeyword)
PSLiteralTable = PSSymbolTable(PSLiteral)
PSKeywordTable = PSSymbolTable(PSKeyword)
LIT = PSLiteralTable.intern
KWD = PSKeywordTable.intern
KEYWORD_DICT_BEGIN = KWD(b"<<")
KEYWORD_DICT_END = KWD(b">>")
PSBaseParserToken = Union[float, bool, PSLiteral, PSKeyword, bytes]
class PSBaseParser:
def __init__(self, fp: BinaryIO) -> None:
self.fp = fp
self.eof = False
self.seek(0)
def _parse_main(self, s: bytes, i: int) -> int:
m = NONSPC.search(s, i)
if not m:
return len(s)
j = m.start(0)
# Use integer byte access to avoid creating a new one-byte bytes object.
c_int = s[j]
c_byte = bytes((c_int,))
self._curtokenpos = self.bufpos + j
if c_int == 37: # b"%"
self._curtoken = b"%"
self._parse1 = self._parse_comment
return j + 1
elif c_int == 47: # b"/"
self._curtoken = b""
self._parse1 = self._parse_literal
return j + 1
# b"-" is 45, b"+" is 43
elif c_int == 45 or c_int == 43 or (48 <= c_int <= 57):
self._curtoken = c_byte
self._parse1 = self._parse_number
return j + 1
elif c_int == 46: # b"."
self._curtoken = c_byte
self._parse1 = self._parse_float
return j + 1
# ASCII alphabetic check
elif (65 <= c_int <= 90) or (97 <= c_int <= 122):
self._curtoken = c_byte
self._parse1 = self._parse_keyword
return j + 1
elif c_int == 40: # b"("
self._curtoken = b""
self.paren = 1
self._parse1 = self._parse_string
return j + 1
elif c_int == 60: # b"<"
self._curtoken = b""
self._parse1 = self._parse_wopen
return j + 1
elif c_int == 62: # b">"
self._curtoken = b""
self._parse1 = self._parse_wclose
return j + 1
elif c_int == 0: # b"\\x00"
return j + 1
else:
self._add_token(KWD(c_byte))
return j + 1
def _add_token(self, obj: PSBaseParserToken) -> None:
self._tokens.append((self._curtokenpos, obj))
def _parse_comment(self, s: bytes, i: int) -> int:
m = EOL.search(s, i)
if not m:
self._curtoken += s[i:]
return len(s)
j = m.start(0)
self._curtoken += s[i:j]
self._parse1 = self._parse_main
# We ignore comments.
# self._tokens.append(self._curtoken)
return j
def _parse_literal(self, s: bytes, i: int) -> int:
m = END_LITERAL.search(s, i)
if not m:
self._curtoken += s[i:]
return len(s)
j = m.start(0)
self._curtoken += s[i:j]
c_int = s[j]
if c_int == 35: # b"#"
self.hex = b""
self._parse1 = self._parse_literal_hex
return j + 1
try:
name: str | bytes = str(self._curtoken, "utf-8")
except Exception:
name = self._curtoken
self._add_token(LIT(name))
self._parse1 = self._parse_main
return j
def _parse_number(self, s: bytes, i: int) -> int:
m = END_NUMBER.search(s, i)
if not m:
self._curtoken += s[i:]
return len(s)
j = m.start(0)
self._curtoken += s[i:j]
c_int = s[j]
if c_int == 46: # b"."
self._curtoken += b"."
self._parse1 = self._parse_float
return j + 1
with contextlib.suppress(ValueError):
self._add_token(int(self._curtoken))
self._parse1 = self._parse_main
return j
def _parse_float(self, s: bytes, i: int) -> int:
m = END_NUMBER.search(s, i)
if not m:
self._curtoken += s[i:]
return len(s)
j = m.start(0)
self._curtoken += s[i:j]
with contextlib.suppress(ValueError):
self._add_token(float(self._curtoken))
self._parse1 = self._parse_main
return j
def _parse_keyword(self, s: bytes, i: int) -> int:
m = END_KEYWORD.search(s, i)
if m:
j = m.start(0)
self._curtoken += s[i:j]
else:
self._curtoken += s[i:]
return len(s)
if self._curtoken == b"true":
token: bool | PSKeyword = True
elif self._curtoken == b"false":
token = False
else:
token = KWD(self._curtoken)
self._add_token(token)
self._parse1 = self._parse_main
return j
def _parse_string(self, s: bytes, i: int) -> int:
m = END_STRING.search(s, i)
if not m:
self._curtoken += s[i:]
return len(s)
j = m.start(0)
self._curtoken += s[i:j]
c_int = s[j]
if c_int == 92: # b"\\\\"
self.oct = b""
self._parse1 = self._parse_string_1
return j + 1
if c_int == 40: # b"("
self.paren += 1
# append the literal "(" byte
self._curtoken += b"("
return j + 1
if c_int == 41: # b")"
self.paren -= 1
if self.paren:
# WTF, they said balanced parens need no special treatment.
self._curtoken += b")"
return j + 1
self._add_token(self._curtoken)
self._parse1 = self._parse_main
return j + 1
def _parse_wopen(self, s: bytes, i: int) -> int:
c_int = s[i]
if c_int == 60: # b"<"
self._add_token(KEYWORD_DICT_BEGIN)
self._parse1 = self._parse_main
i += 1
else:
self._parse1 = self._parse_hexstring
return i
def _parse_wclose(self, s: bytes, i: int) -> int:
c_int = s[i]
if c_int == 62: # b">"
self._add_token(KEYWORD_DICT_END)
i += 1
self._parse1 = self._parse_main
return i
```
'''
EXPECTED_OUTPUT = '''\
import contextlib
from typing import BinaryIO, TypeVar, Union
_SymbolT = TypeVar("_SymbolT", PSLiteral, PSKeyword)
PSLiteralTable = PSSymbolTable(PSLiteral)
PSKeywordTable = PSSymbolTable(PSKeyword)
LIT = PSLiteralTable.intern
KWD = PSKeywordTable.intern
KEYWORD_DICT_BEGIN = KWD(b"<<")
KEYWORD_DICT_END = KWD(b">>")
PSBaseParserToken = Union[float, bool, PSLiteral, PSKeyword, bytes]
class PSBaseParser:
def __init__(self, fp: BinaryIO) -> None:
self.fp = fp
self.eof = False
self.seek(0)
def _parse_main(self, s: bytes, i: int) -> int:
m = NONSPC.search(s, i)
if not m:
return len(s)
j = m.start(0)
# Use integer byte access to avoid creating a new one-byte bytes object.
c_int = s[j]
c_byte = bytes((c_int,))
self._curtokenpos = self.bufpos + j
if c_int == 37: # b"%"
self._curtoken = b"%"
self._parse1 = self._parse_comment
return j + 1
elif c_int == 47: # b"/"
self._curtoken = b""
self._parse1 = self._parse_literal
return j + 1
# b"-" is 45, b"+" is 43
elif c_int == 45 or c_int == 43 or (48 <= c_int <= 57):
self._curtoken = c_byte
self._parse1 = self._parse_number
return j + 1
elif c_int == 46: # b"."
self._curtoken = c_byte
self._parse1 = self._parse_float
return j + 1
# ASCII alphabetic check
elif (65 <= c_int <= 90) or (97 <= c_int <= 122):
self._curtoken = c_byte
self._parse1 = self._parse_keyword
return j + 1
elif c_int == 40: # b"("
self._curtoken = b""
self.paren = 1
self._parse1 = self._parse_string
return j + 1
elif c_int == 60: # b"<"
self._curtoken = b""
self._parse1 = self._parse_wopen
return j + 1
elif c_int == 62: # b">"
self._curtoken = b""
self._parse1 = self._parse_wclose
return j + 1
elif c_int == 0: # b"\\x00"
return j + 1
else:
self._add_token(KWD(c_byte))
return j + 1
def _add_token(self, obj: PSBaseParserToken) -> None:
self._tokens.append((self._curtokenpos, obj))
def _parse_comment(self, s: bytes, i: int) -> int:
m = EOL.search(s, i)
if not m:
self._curtoken += s[i:]
return len(s)
j = m.start(0)
self._curtoken += s[i:j]
self._parse1 = self._parse_main
# We ignore comments.
# self._tokens.append(self._curtoken)
return j
def _parse_literal(self, s: bytes, i: int) -> int:
m = END_LITERAL.search(s, i)
if not m:
self._curtoken += s[i:]
return len(s)
j = m.start(0)
self._curtoken += s[i:j]
c_int = s[j]
if c_int == 35: # b"#"
self.hex = b""
self._parse1 = self._parse_literal_hex
return j + 1
try:
name: str | bytes = str(self._curtoken, "utf-8")
except Exception:
name = self._curtoken
self._add_token(LIT(name))
self._parse1 = self._parse_main
return j
def _parse_number(self, s: bytes, i: int) -> int:
m = END_NUMBER.search(s, i)
if not m:
self._curtoken += s[i:]
return len(s)
j = m.start(0)
self._curtoken += s[i:j]
c_int = s[j]
if c_int == 46: # b"."
self._curtoken += b"."
self._parse1 = self._parse_float
return j + 1
with contextlib.suppress(ValueError):
self._add_token(int(self._curtoken))
self._parse1 = self._parse_main
return j
def _parse_float(self, s: bytes, i: int) -> int:
m = END_NUMBER.search(s, i)
if not m:
self._curtoken += s[i:]
return len(s)
j = m.start(0)
self._curtoken += s[i:j]
with contextlib.suppress(ValueError):
self._add_token(float(self._curtoken))
self._parse1 = self._parse_main
return j
def _parse_keyword(self, s: bytes, i: int) -> int:
m = END_KEYWORD.search(s, i)
if m:
j = m.start(0)
self._curtoken += s[i:j]
else:
self._curtoken += s[i:]
return len(s)
if self._curtoken == b"true":
token: bool | PSKeyword = True
elif self._curtoken == b"false":
token = False
else:
token = KWD(self._curtoken)
self._add_token(token)
self._parse1 = self._parse_main
return j
def _parse_string(self, s: bytes, i: int) -> int:
m = END_STRING.search(s, i)
if not m:
self._curtoken += s[i:]
return len(s)
j = m.start(0)
self._curtoken += s[i:j]
c_int = s[j]
if c_int == 92: # b"\\\\"
self.oct = b""
self._parse1 = self._parse_string_1
return j + 1
if c_int == 40: # b"("
self.paren += 1
# append the literal "(" byte
self._curtoken += b"("
return j + 1
if c_int == 41: # b")"
self.paren -= 1
if self.paren:
# WTF, they said balanced parens need no special treatment.
self._curtoken += b")"
return j + 1
self._add_token(self._curtoken)
self._parse1 = self._parse_main
return j + 1
def _parse_wopen(self, s: bytes, i: int) -> int:
c_int = s[i]
if c_int == 60: # b"<"
self._add_token(KEYWORD_DICT_BEGIN)
self._parse1 = self._parse_main
i += 1
else:
self._parse1 = self._parse_hexstring
return i
def _parse_wclose(self, s: bytes, i: int) -> int:
c_int = s[i]
if c_int == 62: # b">"
self._add_token(KEYWORD_DICT_END)
i += 1
self._parse1 = self._parse_main
return i
'''
@pytest.fixture
def temp_project():
temp_dir = Path(tempfile.mkdtemp())
source_file = temp_dir / "psparser.py"
source_file.write_text(ORIGINAL_SOURCE, encoding="utf-8")
test_cfg = TestConfig(
tests_root=temp_dir / "tests",
tests_project_rootdir=temp_dir,
project_root_path=temp_dir,
test_framework="pytest",
pytest_cmd="pytest",
)
yield temp_dir, source_file, test_cfg
import shutil
shutil.rmtree(temp_dir, ignore_errors=True)
def run_replacement(temp_project):
"""Helper: run the full replacement pipeline and return (optimizer, code_context, final_content)."""
temp_dir, source_file, test_cfg = temp_project
function_to_optimize = FunctionToOptimize(
file_path=source_file,
function_name="_parse_main",
parents=[FunctionParent(name="PSBaseParser", type="ClassDef")],
)
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=source_file.read_text(encoding="utf-8"),
)
ctx_result = optimizer.get_code_optimization_context()
assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"
code_context = ctx_result.unwrap()
original_content = source_file.read_text(encoding="utf-8")
original_helper_code = {source_file: original_content}
optimized_code = CodeStringsMarkdown.parse_markdown_code(MOCK_CANDIDATE_MARKDOWN)
did_update = optimizer.replace_function_and_helpers_with_optimized_code(
code_context, optimized_code, original_helper_code
)
assert did_update, "Expected the code to be updated"
final_content = source_file.read_text(encoding="utf-8")
return optimizer, code_context, final_content
def test_replace_with_mock_candidate(temp_project):
"""Verify replace_function_and_helpers_with_optimized_code produces the exact expected output.
The code context detects ALL sibling methods as helpers of _parse_main.
replace_function_definitions_in_module replaces ALL method bodies.
detect_unused_helper_functions correctly recognizes methods referenced via attribute
assignment (self._parse1 = self._parse_literal) as used, so they are NOT reverted.
"""
_, code_context, final_content = run_replacement(temp_project)
# Code context correctly detects ALL methods as helpers
helper_names = {h.qualified_name for h in code_context.helper_functions}
assert helper_names == {
"PSBaseParser._parse_comment",
"PSBaseParser._parse_literal",
"PSBaseParser._parse_number",
"PSBaseParser._parse_float",
"PSBaseParser._parse_keyword",
"PSBaseParser._parse_string",
"PSBaseParser._parse_wopen",
"PSBaseParser._parse_wclose",
"PSBaseParser._add_token",
"KWD",
}
# The final content should match the expected output exactly
assert final_content == EXPECTED_OUTPUT
def test_detect_unused_helpers_handles_attribute_refs(temp_project):
"""Verify detect_unused_helper_functions recognizes methods referenced via attribute assignment.
When _parse_main does `self._parse1 = self._parse_literal`, the method is referenced as
an ast.Attribute value (not an ast.Call). The detection should recognize these as used.
"""
temp_dir, source_file, test_cfg = temp_project
function_to_optimize = FunctionToOptimize(
file_path=source_file,
function_name="_parse_main",
parents=[FunctionParent(name="PSBaseParser", type="ClassDef")],
)
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=source_file.read_text(encoding="utf-8"),
)
ctx_result = optimizer.get_code_optimization_context()
assert ctx_result.is_successful()
code_context = ctx_result.unwrap()
optimized_code = CodeStringsMarkdown.parse_markdown_code(MOCK_CANDIDATE_MARKDOWN)
unused_helpers = detect_unused_helper_functions(
optimizer.function_to_optimize, code_context, optimized_code
)
unused_names = {h.qualified_name for h in unused_helpers}
# No helpers should be detected as unused — all are either directly called or
# referenced via attribute assignment (self._parse1 = self._parse_X)
assert unused_names == set(), f"Expected no unused helpers, got: {unused_names}"
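The attribute-reference case this test guards can be sketched independently of codeflash: when walking the optimized function's AST, a helper counts as used if its name appears as a read, whether in a call or as a bare `ast.Attribute` value on the right-hand side of `self._parse1 = self._parse_literal`. A minimal, hypothetical checker (not the detect_unused_helper_functions implementation):

```python
import ast

def referenced_attribute_names(source: str) -> set[str]:
    """Names of attributes that are read anywhere in `source`, including
    bare references such as `self._parse1 = self._parse_literal`."""
    names: set[str] = set()
    for node in ast.walk(ast.parse(source)):
        # ast.Load context means the attribute is read, not assigned to.
        if isinstance(node, ast.Attribute) and isinstance(node.ctx, ast.Load):
            names.add(node.attr)
    return names
```

Checking `ctx` distinguishes the assignment target (`_parse1`, a Store) from the referenced method (`_parse_literal`, a Load), which is exactly why the helpers in this suite are never flagged as unused.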
def test_replace_produces_valid_python(temp_project):
"""Verify the final output is valid, parseable Python."""
_, _, final_content = run_replacement(temp_project)
import ast
ast.parse(final_content)


@@ -0,0 +1,162 @@
from codeflash.languages.python.support import PythonSupport
def test_remove_bare_function():
src = """
def test_foo():
pass
def test_bar():
pass
def test_baz():
pass
"""
result = PythonSupport().remove_test_functions(src, ["test_bar"])
assert result == """
def test_foo():
pass
def test_baz():
pass
"""
def test_remove_qualified_method():
src = """
class TestSuite:
def test_alpha(self):
pass
def test_beta(self):
pass
def test_gamma(self):
pass
"""
result = PythonSupport().remove_test_functions(src, ["TestSuite.test_beta"])
assert result == """
class TestSuite:
def test_alpha(self):
pass
def test_gamma(self):
pass
"""
def test_remove_all_methods_removes_class():
src = """
class TestSuite:
def test_alpha(self):
pass
def test_beta(self):
pass
"""
result = PythonSupport().remove_test_functions(
src, ["TestSuite.test_alpha", "TestSuite.test_beta"]
)
assert result == "\n"
def test_remove_all_methods_from_class_with_docstring():
src = """
class TestSuite:
\"\"\"Suite docstring.\"\"\"
def test_only(self):
pass
"""
result = PythonSupport().remove_test_functions(src, ["TestSuite.test_only"])
assert result == "\n"
def test_mixed_bare_and_qualified():
src = """
def test_standalone():
pass
class TestSuite:
def test_method(self):
pass
"""
result = PythonSupport().remove_test_functions(
src, ["test_standalone", "TestSuite.test_method"]
)
assert result == "\n"
def test_bare_name_does_not_match_class_method():
src = """
class TestSuite:
def test_method(self):
pass
def test_method():
pass
"""
result = PythonSupport().remove_test_functions(src, ["test_method"])
assert result == """
class TestSuite:
def test_method(self):
pass
"""
def test_class_kept_when_non_test_methods_remain():
src = """
class TestSuite:
def setUp(self):
self.x = 1
def test_alpha(self):
pass
def test_beta(self):
pass
"""
result = PythonSupport().remove_test_functions(
src, ["TestSuite.test_alpha", "TestSuite.test_beta"]
)
assert result == """
class TestSuite:
def setUp(self):
self.x = 1
"""
def test_qualified_name_wrong_class_no_removal():
src = """
class TestA:
def test_method(self):
pass
class TestB:
def test_method(self):
pass
"""
result = PythonSupport().remove_test_functions(src, ["TestA.test_method"])
assert result == """
class TestB:
def test_method(self):
pass
"""
def test_no_functions_to_remove_returns_unchanged():
src = """
def test_foo():
pass
"""
result = PythonSupport().remove_test_functions(src, [])
assert result == """
def test_foo():
pass
"""
def test_invalid_syntax_returns_original():
src = "def test_foo(:\n pass"
result = PythonSupport().remove_test_functions(src, ["test_foo"])
assert result == src


@@ -1,6 +1,5 @@
import shutil
import sqlite3
import time
from pathlib import Path
import pytest
@@ -18,6 +17,7 @@ def test_trace_benchmarks() -> None:
replay_tests_dir = benchmarks_root / "codeflash_replay_tests"
tests_root = project_root / "tests"
output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
conn: sqlite3.Connection | None = None
trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
assert output_file.exists()
try:
@@ -121,8 +121,8 @@ def test_trace_benchmarks() -> None:
assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
# Close connection
conn.close()
conn = None
generate_replay_test(output_file, replay_tests_dir)
test_class_sort_path = replay_tests_dir / Path(
"test_tests_pytest_benchmarks_test_test_benchmark_bubble_sort_example__replay_test_0.py"
@@ -217,7 +217,8 @@ def test_code_to_optimize_bubble_sort_codeflash_trace_sorter_test_no_func():
"""
assert test_sort_path.read_text("utf-8").strip() == test_sort_code.strip()
finally:
# cleanup
if conn is not None:
conn.close()
output_file.unlink(missing_ok=True)
shutil.rmtree(replay_tests_dir)
@@ -231,6 +232,7 @@ def test_trace_multithreaded_benchmark() -> None:
output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
assert output_file.exists()
conn: sqlite3.Connection | None = None
try:
# check contents of trace file
# connect to database
@@ -244,8 +246,6 @@
)
function_calls = cursor.fetchall()
conn.close()
# Assert the length of function calls
assert len(function_calls) == 10, f"Expected 10 function calls, but got {len(function_calls)}"
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
@@ -281,11 +281,9 @@
assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
# Close connection
conn.close()
finally:
# cleanup
if conn is not None:
conn.close()
output_file.unlink(missing_ok=True)
@@ -296,6 +294,7 @@ def test_trace_benchmark_decorator() -> None:
output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
assert output_file.exists()
conn: sqlite3.Connection | None = None
try:
# check contents of trace file
# connect to database
@@ -352,11 +351,7 @@
assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path"
assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
# Close connection
cursor.close()
conn.close()
time.sleep(2)
finally:
# cleanup
if conn is not None:
conn.close()
output_file.unlink(missing_ok=True)
time.sleep(1)

uv.lock (1357 lines changed)
File diff suppressed because it is too large