Added a new function to generate a test file path based on function name and iteration.
This commit is contained in:
parent
21fade5a34
commit
068bf4d153
2 changed files with 18 additions and 3 deletions
|
|
@ -8,6 +8,7 @@ from codeflash.discovery.discover_unit_tests import discover_unit_tests
|
|||
from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
|
||||
from codeflash.optimization.function_context import get_function_variables_definition
|
||||
from codeflash.verification.verifier import generate_tests
|
||||
from codeflash.verification.verification_utils import get_test_file_path
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -16,6 +17,7 @@ class args:
|
|||
all: bool = False
|
||||
file: Optional[str] = None
|
||||
function: Optional[str] = None
|
||||
tests_root: Optional[str] = "tests"
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
|
@ -23,6 +25,7 @@ def main() -> None:
|
|||
modified_functions = get_functions_to_optimize(
|
||||
args.root, optimize_all=args.all, file=args.file, function=args.function
|
||||
)
|
||||
test_files_created = set()
|
||||
try:
|
||||
functions_to_tests_map = discover_unit_tests(args.tests_root, args.root)
|
||||
print(functions_to_tests_map)
|
||||
|
|
@ -48,9 +51,9 @@ def main() -> None:
|
|||
print(tests_in_file)
|
||||
new_tests = generate_tests(code_to_optimize)
|
||||
print(new_tests)
|
||||
additional_tests_path = get_test_file_path(args.tests_root, function_name, 0)
|
||||
test_files_created.add(additional_tests_path)
|
||||
with open(additional_tests_path, "w") as file:
|
||||
generated_tests_path = get_test_file_path(args.tests_root, function_name, 0)
|
||||
test_files_created.add(generated_tests_path)
|
||||
with open(generated_tests_path, "w") as file:
|
||||
file.write(new_tests)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
|
|
|||
12
codeflash/verification/verification_utils.py
Normal file
12
codeflash/verification/verification_utils.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
import os
|
||||
|
||||
|
||||
def get_test_file_path(test_dir, function_name, iteration=0, test_type="unit"):
|
||||
assert test_type in ["unit", "inspired", "replay"]
|
||||
function_name = function_name.replace(".", "_")
|
||||
path = os.path.join(
|
||||
test_dir, f"test_{function_name}__{test_type}_test_{iteration}.py"
|
||||
)
|
||||
if os.path.exists(path):
|
||||
return get_test_file_path(test_dir, function_name, iteration + 1)
|
||||
return path
|
||||
Loading…
Reference in a new issue