fix the tracer (#884)

* parse args correctly

When there were fewer than 4 test files, the pytest_split() function returned a flat list of strings instead of a list of lists

* update python path correctly

* improve messaging here

* Revert "improve messaging here"

This reverts commit b6ab255135.

* improve error slightly
This commit is contained in:
Kevin Turcios 2025-11-07 21:14:01 -08:00 committed by GitHub
parent ed065b855c
commit 6696f079e9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 29 additions and 5 deletions

View file

@ -38,7 +38,7 @@ def find_pyproject_toml(config_file: Path | None = None) -> Path:
dir_path = dir_path.parent
msg = f"Could not find pyproject.toml in the current directory {Path.cwd()} or any of the parent directories. Please create it by running `codeflash init`, or pass the path to pyproject.toml with the --config-file argument."
raise ValueError(msg)
raise ValueError(msg) from None
def get_all_closest_config_files() -> list[Path]:
@ -93,7 +93,7 @@ def parse_config_file(
data = tomlkit.parse(f.read())
except tomlkit.exceptions.ParseError as e:
msg = f"Error while parsing the config file {config_file_path}. Please recheck the file for syntax errors. Error: {e}"
raise ValueError(msg) from e
raise ValueError(msg) from None
lsp_mode = is_LSP_enabled()

View file

@ -12,6 +12,7 @@
from __future__ import annotations
import json
import os
import pickle
import subprocess
import sys
@ -64,13 +65,15 @@ def main(args: Namespace | None = None) -> ArgumentParser:
parsed_args.tracer_timeout = getattr(args, "timeout", None)
parsed_args.codeflash_config = getattr(args, "config_file_path", None)
parsed_args.trace_only = getattr(args, "trace_only", False)
parsed_args.module = False
temp_parsed, unknown_args = parser.parse_known_args()
parsed_args.module = temp_parsed.module
sys.argv[:] = unknown_args
if getattr(args, "disable", False):
console.rule("Codeflash: Tracer disabled by --disable option", style="bold red")
return parser
unknown_args = []
else:
if not sys.argv[1:]:
parser.print_usage()
@ -127,6 +130,13 @@ def main(args: Namespace | None = None) -> ArgumentParser:
else:
updated_sys_argv.append(elem)
args_dict["command"] = " ".join(updated_sys_argv)
env = os.environ.copy()
pythonpath = env.get("PYTHONPATH", "")
project_root_str = str(project_root)
if pythonpath:
env["PYTHONPATH"] = f"{project_root_str}{os.pathsep}{pythonpath}"
else:
env["PYTHONPATH"] = project_root_str
processes.append(
subprocess.Popen(
[
@ -136,6 +146,7 @@ def main(args: Namespace | None = None) -> ArgumentParser:
json.dumps(args_dict),
],
cwd=Path.cwd(),
env=env,
)
)
for process in processes:
@ -156,6 +167,15 @@ def main(args: Namespace | None = None) -> ArgumentParser:
args_dict["output"] = str(parsed_args.outfile)
args_dict["command"] = " ".join(sys.argv)
env = os.environ.copy()
# Add project root to PYTHONPATH so imports work correctly
pythonpath = env.get("PYTHONPATH", "")
project_root_str = str(project_root)
if pythonpath:
env["PYTHONPATH"] = f"{project_root_str}{os.pathsep}{pythonpath}"
else:
env["PYTHONPATH"] = project_root_str
subprocess.run(
[
SAFE_SYS_EXECUTABLE,
@ -164,6 +184,7 @@ def main(args: Namespace | None = None) -> ArgumentParser:
json.dumps(args_dict),
],
cwd=Path.cwd(),
env=env,
check=False,
)
try:

View file

@ -62,7 +62,7 @@ def pytest_split(
# If we have fewer test files than 4 * num_splits, reduce num_splits
max_possible_splits = len(test_files) // 4
if max_possible_splits == 0:
return test_files, test_paths
return [test_files], test_paths
num_splits = min(num_splits, max_possible_splits)

View file

@ -9,6 +9,9 @@ def main() -> None:
# Use the version tuple from version.py
version = __version__
if ".dev" in version or "+" in version or "post" in version:
return
# Use the major and minor version components from the version tuple
major_minor_version = ".".join(map(str, version.split(".")[:2]))