Remove Python 3.8 hacks. Also fix all the tests.

This commit is contained in:
Saurabh Misra 2024-01-13 22:17:25 -08:00
parent 202a0f1e1a
commit f05db78477
9 changed files with 166 additions and 359 deletions

View file

@ -1,13 +0,0 @@
import sys
if sys.version_info < (3, 9, 0):
import astunparse
else:
import ast
def ast_unparse(ast_obj) -> str:
if sys.version_info < (3, 9, 0):
return astunparse.ast_unparse(ast_obj)
else:
return ast.unparse(ast_obj)

View file

@ -1,13 +1,9 @@
import ast
import sys
from _ast import ClassDef
from typing import Any, Optional
from codeflash.code_utils.code_utils import module_name_from_file_path, get_run_tmp_file
if sys.version_info < (3, 9, 0):
from astunparse import unparse as ast_unparse
class ReplaceCallNodeWithName(ast.NodeTransformer):
def __init__(self, only_function_name, new_variable_name="return_value"):
@ -365,7 +361,4 @@ def inject_profiling_into_existing_test(test_path, function_name, root_path):
]
tree.body = new_imports + tree.body
if sys.version_info < (3, 9, 0):
return ast_unparse(tree)
else:
return ast.unparse(tree)
return ast.unparse(tree)

View file

@ -3,7 +3,6 @@ import logging
from typing import Tuple, Optional
from codeflash.api.aiservice import generate_regression_tests
from codeflash.code_utils.ast_unparser import ast_unparse
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.code_utils import module_name_from_file_path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
@ -84,4 +83,4 @@ def merge_unit_tests(unit_test_source: str, inspired_unit_tests: str, test_frame
unit_test_source_ast.body = import_list + unit_test_source_ast.body
if test_framework == "unittest":
unit_test_source_ast = delete_multiple_if_name_main(unit_test_source_ast)
return ast_unparse(unit_test_source_ast)
return ast.unparse(unit_test_source_ast)

View file

@ -0,0 +1,140 @@
import os.path
import tempfile
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
def test_perfinjector_bubble_sort():
code = """import unittest
from code_to_optimize.bubble_sort import sorter
class TestPigLatin(unittest.TestCase):
def test_sort(self):
input = [5, 4, 3, 2, 1, 0]
output = sorter(input)
self.assertEqual(output, [0, 1, 2, 3, 4, 5])
input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
output = sorter(input)
self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
input = list(reversed(range(5000)))
output = sorter(input)
self.assertEqual(output, list(range(5000)))
"""
expected = """import time
import gc
import os
import sqlite3
import pickle
import unittest
from code_to_optimize.bubble_sort import sorter
class TestPigLatin(unittest.TestCase):
def test_sort(self):
codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite')
codeflash_cur = codeflash_con.cursor()
codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, iteration_id TEXT, runtime INTEGER, return_value BLOB)')
input = [5, 4, 3, 2, 1, 0]
gc.disable()
counter = time.perf_counter_ns()
return_value = sorter(input)
codeflash_duration = time.perf_counter_ns() - counter
gc.enable()
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?)', ('{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '5', codeflash_duration, pickle.dumps(return_value)))
codeflash_con.commit()
print(f'#####{module_path}:TestPigLatin.test_sort:sorter:5#####{{codeflash_duration}}^^^^^')
output = return_value
self.assertEqual(output, [0, 1, 2, 3, 4, 5])
input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
gc.disable()
counter = time.perf_counter_ns()
return_value = sorter(input)
codeflash_duration = time.perf_counter_ns() - counter
gc.enable()
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?)', ('{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '8', codeflash_duration, pickle.dumps(return_value)))
codeflash_con.commit()
print(f'#####{module_path}:TestPigLatin.test_sort:sorter:8#####{{codeflash_duration}}^^^^^')
output = return_value
self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
input = list(reversed(range(5000)))
gc.disable()
counter = time.perf_counter_ns()
return_value = sorter(input)
codeflash_duration = time.perf_counter_ns() - counter
gc.enable()
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?)', ('{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '11', codeflash_duration, pickle.dumps(return_value)))
codeflash_con.commit()
print(f'#####{module_path}:TestPigLatin.test_sort:sorter:11#####{{codeflash_duration}}^^^^^')
output = return_value
self.assertEqual(output, list(range(5000)))
codeflash_con.close()"""
with tempfile.NamedTemporaryFile(mode="w") as f:
f.write(code)
f.flush()
new_test = inject_profiling_into_existing_test(f.name, "sorter", os.path.dirname(f.name))
assert new_test == expected.format(
module_path=os.path.basename(f.name),
tmp_dir_path=get_run_tmp_file("test_return_values"),
)
def test_perfinjector_only_replay_test():
code = """import pickle
import pytest
from codeflash.tracing.replay_test import get_next_arg_and_return
from codeflash.validation.equivalence import compare_results
from packagename.ml.yolo.image_reshaping_utils import prepare_image_for_yolo as packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo
def test_prepare_image_for_yolo():
for arg_val_pkl, return_val_pkl in get_next_arg_and_return('/home/saurabh/packagename/traces/first.trace', 3):
args = pickle.loads(arg_val_pkl)
return_val_1= pickle.loads(return_val_pkl)
ret = packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo(**args)
assert compare_results(return_val_1, ret)
"""
expected = """import time
import gc
import os
import sqlite3
import pickle
import pickle
import pytest
from codeflash.tracing.replay_test import get_next_arg_and_return
from codeflash.validation.equivalence import compare_results
from packagename.ml.yolo.image_reshaping_utils import prepare_image_for_yolo as packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo
def test_prepare_image_for_yolo():
codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite')
codeflash_cur = codeflash_con.cursor()
codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, iteration_id TEXT, runtime INTEGER, return_value BLOB)')
for arg_val_pkl, return_val_pkl in get_next_arg_and_return('/home/saurabh/packagename/traces/first.trace', 3):
args = pickle.loads(arg_val_pkl)
return_val_1 = pickle.loads(return_val_pkl)
gc.disable()
counter = time.perf_counter_ns()
return_value = packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo(**args)
codeflash_duration = time.perf_counter_ns() - counter
gc.enable()
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?)', ('{module_path}', None, 'test_prepare_image_for_yolo', 'packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo', '4_2', codeflash_duration, pickle.dumps(return_value)))
codeflash_con.commit()
print(f'#####{module_path}:test_prepare_image_for_yolo:packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo:4_2#####{{codeflash_duration}}^^^^^')
ret = return_value
assert compare_results(return_val_1, ret)
codeflash_con.close()"""
with tempfile.NamedTemporaryFile(mode="w") as f:
f.write(code)
f.flush()
new_test = inject_profiling_into_existing_test(
f.name, "prepare_image_for_yolo", os.path.dirname(f.name)
)
assert new_test == expected.format(
module_path=os.path.basename(f.name),
tmp_dir_path=get_run_tmp_file("test_return_values"),
)

View file

@ -1,5 +1,4 @@
import os
import sys
os.environ["CODEFLASH_API_KEY"] = "test-key"
from codeflash.verification.verifier import merge_unit_tests
@ -113,113 +112,7 @@ def test_tsp_negative_coordinates():
gc.enable()
_log__test__values(return_value, duration, 'tsp_test_tsp_negative_coordinates__inspired_1')
"""
if sys.version_info < (3, 9, 0):
expected = """
import pytest
import math
import sys
import itertools
import time
import gc
from code_to_optimize.tsp import tsp
import pytest
import math
import sys
import itertools
def distance_between(city1: tuple, city2: tuple) -> float:
return math.hypot((city1[0] - city2[0]), (city1[1] - city2[1]))
def test_tsp_decimal_coordinates():
gc.disable()
counter = time.perf_counter_ns()
return_value = tsp([(0.5, 0.5), (1.5, 1.5), (2.5, 2.5)])
duration = (time.perf_counter_ns() - counter)
gc.enable()
_log__test__values(return_value, duration, 'tsp_test_tsp_decimal_coordinates_0')
def test_tsp_large_coordinate_values():
cities = [(1000000, 1000000), (2000000, 2000000), (3000000, 3000000)]
gc.disable()
counter = time.perf_counter_ns()
return_value = tsp(cities)
duration = (time.perf_counter_ns() - counter)
gc.enable()
_log__test__values(return_value, duration, 'tsp_test_tsp_large_coordinate_values_1')
def distance_between(city1: tuple, city2: tuple) -> float:
return math.hypot((city1[0] - city2[0]), (city1[1] - city2[1]))
def tsp(cities: list[list[int]]):
permutations = itertools.permutations(cities)
min_distance = sys.maxsize
optimal_route = []
for permutation in permutations:
distance = 0
for i in range((len(permutation) - 1)):
distance += distance_between(permutation[i], permutation[(i + 1)])
distance += distance_between(permutation[(- 1)], permutation[0])
if (distance < min_distance):
min_distance = distance
optimal_route = permutation
return (optimal_route, min_distance)
def test_tsp_more_cities__inspired():
cities = [[1, 2], [3, 4], [5, 6], [(- 3), 4], [0, 0]]
gc.disable()
counter = time.perf_counter_ns()
return_value = tsp(cities)
duration = (time.perf_counter_ns() - counter)
gc.enable()
_log__test__values(return_value, duration, 'tsp_test_tsp_more_cities__inspired_1')
def test_tsp_three_cities__inspired():
cities = [[1, 2], [3, 4], [5, 6]]
gc.disable()
counter = time.perf_counter_ns()
return_value = tsp(cities)
duration = (time.perf_counter_ns() - counter)
gc.enable()
_log__test__values(return_value, duration, 'tsp_test_tsp_three_cities__inspired_1')
def test_tsp_single_city__inspired():
cities = [[1, 2]]
gc.disable()
counter = time.perf_counter_ns()
return_value = tsp(cities)
duration = (time.perf_counter_ns() - counter)
gc.enable()
_log__test__values(return_value, duration, 'tsp_test_tsp_single_city__inspired_1')
def test_tsp_empty_cities__inspired():
cities = []
gc.disable()
counter = time.perf_counter_ns()
return_value = tsp(cities)
duration = (time.perf_counter_ns() - counter)
gc.enable()
_log__test__values(return_value, duration, 'tsp_test_tsp_empty_cities__inspired_1')
def test_tsp_duplicate_cities__inspired():
cities = [[1, 2], [3, 4], [1, 2], [3, 4]]
gc.disable()
counter = time.perf_counter_ns()
return_value = tsp(cities)
duration = (time.perf_counter_ns() - counter)
gc.enable()
_log__test__values(return_value, duration, 'tsp_test_tsp_duplicate_cities__inspired_1')
def test_tsp_negative_coordinates__inspired():
cities = [[(- 1), (- 2)], [(- 3), (- 4)], [(- 5), (- 6)]]
gc.disable()
counter = time.perf_counter_ns()
return_value = tsp(cities)
duration = (time.perf_counter_ns() - counter)
gc.enable()
_log__test__values(return_value, duration, 'tsp_test_tsp_negative_coordinates__inspired_1')
"""
else:
expected = """import pytest
expected = """import pytest
import math
import sys
import itertools
@ -432,20 +325,12 @@ class TestGetFilteredClusters(unittest.TestCase):
counter = time.perf_counter_ns()
return_value = get_filtered_clusters(self.cluster_tree, filters)
"""
expected += (
""" duration = (time.perf_counter_ns() - counter)\n"""
if sys.version_info < (3, 9, 0)
else """ duration = time.perf_counter_ns() - counter\n"""
)
expected += """ duration = time.perf_counter_ns() - counter\n"""
expected += """ gc.enable()
_log__test__values(return_value, duration, 'get_filtered_clusters_test_get_filtered_clusters_scenario3_2')
"""
expected += (
"""class MockClusterTree():\n"""
if sys.version_info < (3, 9, 0)
else """class MockClusterTree:\n"""
)
expected += """class MockClusterTree:\n"""
expected += """
def __init__(self, clusters_dict, field_indices, stability_column, ordered_ids):
self.clusters_dict = clusters_dict
@ -477,11 +362,7 @@ class TestGetFilteredClustersInspired(unittest.TestCase):
counter = time.perf_counter_ns()
return_value = get_filtered_clusters(self.cluster_tree, filters)
"""
expected += (
""" duration = (time.perf_counter_ns() - counter)\n"""
if sys.version_info < (3, 9, 0)
else """ duration = time.perf_counter_ns() - counter\n"""
)
expected += """ duration = time.perf_counter_ns() - counter\n"""
expected += """ gc.enable()
_log__test__values(return_value, duration, 'get_filtered_clusters_test_get_filtered_clusters_1')
@ -496,24 +377,12 @@ class TestGetFilteredClustersInspired(unittest.TestCase):
counter = time.perf_counter_ns()
return_value = get_filtered_clusters(self.cluster_tree, filters)
"""
expected += (
""" duration = (time.perf_counter_ns() - counter)\n"""
if sys.version_info < (3, 9, 0)
else """ duration = time.perf_counter_ns() - counter\n"""
)
expected += """ duration = time.perf_counter_ns() - counter\n"""
expected += """ gc.enable()
_log__test__values(return_value, duration, 'get_filtered_clusters_test_get_filtered_clusters_with_clusters_5')
"""
if sys.version_info < (3, 9, 0):
expected += """if (__name__ == '__main__'):
unittest.main()"""
else:
expected += """if __name__ == '__main__':
expected += """if __name__ == '__main__':
unittest.main()"""
modified_file = merge_unit_tests(unit_tests, inspired_test, "unittest")
if sys.version_info < (3, 9, 0):
# assert modified_file.strip("\n") == modified_file
assert modified_file.strip("\n") == expected
else:
assert modified_file == expected
assert modified_file == expected

View file

@ -1,13 +0,0 @@
import sys
if sys.version_info < (3, 9, 0):
import astunparse
else:
import ast
def ast_unparse(ast_obj) -> str:
if sys.version_info < (3, 9, 0):
return astunparse.ast_unparse(ast_obj)
else:
return ast.unparse(ast_obj)

View file

@ -1,6 +1,5 @@
import ast
from injectperf.ast_unparser import ast_unparse
from injectperf.instrument_new_tests import InjectPerfAndLogging, inject_logging_code
from models.functions_to_optimize import FunctionToOptimize
@ -37,6 +36,6 @@ def instrument_test_source(
if test_framework == "unittest":
new_imports += [ast.Import(names=[ast.alias(name="timeout_decorator")])]
new_module_node.body = new_imports + new_module_node.body
new_tests = ast_unparse(new_module_node)
new_tests = ast.unparse(new_module_node)
modified_new_tests = inject_logging_code(new_tests)
return modified_new_tests

View file

@ -3,25 +3,19 @@
# TODO: This is only here as a temporary reference implementaion of how an early version of LLM inspired tests was written.
# It didn't work very well. This should be improved significantly.
import ast # used for detecting whether generated Python code is valid
import sys
from typing import List, Tuple
from codeflash.verification import EXPLAIN_MODEL, PLAN_MODEL, EXECUTE_MODEL, LLM
if sys.version_info < (3, 9, 0):
from astunparse import unparse as ast_unparse
import openai # used for calling the OpenAI API
from codeflash.code_utils.code_extractor import get_code
from codeflash.code_utils.code_utils import ellipsis_in_ast, get_imports_from_file
from codeflash.discovery.discover_unit_tests import TestsInFile
from codeflash.verification import EXPLAIN_MODEL, PLAN_MODEL, EXECUTE_MODEL, LLM
from codeflash.verification.gen_regression_tests import (
print_messages,
print_message_delta,
)
from codeflash.code_utils.code_extractor import get_code
from codeflash.code_utils.code_utils import ellipsis_in_ast, get_imports_from_file
from codeflash.discovery.discover_unit_tests import TestsInFile
def regression_tests_from_function_with_inspiration(
function_to_test: str, # Python function to test, as a string
@ -188,10 +182,7 @@ import {unit_test_package} # used for our unit tests
code = execution.split("```python")[1].split("```")[0].strip()
# TODO: This adds a bunch of redundant imports, clean them up
tests_list = [imp for sublist in inspired_test_imports for imp in sublist]
if sys.version_info >= (3, 9, 0):
code = ast.unparse(tests_list) + "\n" + code
else:
code = ast_unparse(tests_list) + "\n" + code
code = ast.unparse(tests_list) + "\n" + code
try:
module = ast.parse(code)
if ellipsis_in_ast(module):

View file

@ -1,12 +1,7 @@
import ast
import os
import sys
import tempfile
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.discovery.functions_to_optimize import FunctionToOptimize, FunctionParent
from injectperf.ast_unparser import ast_unparse
from injectperf.instrument_existing_tests import inject_profiling_into_existing_test
from injectperf.instrument_new_tests import InjectPerfAndLogging
os.environ["CODEFLASH_API_KEY"] = "test-key"
@ -35,14 +30,10 @@ def test_InjectPerfAndLogging_with():
gc.disable()
counter = time.perf_counter_ns()
return_value = hdbscan.relative_validity_()\n"""
expected += (
""" duration = (time.perf_counter_ns() - counter)\n"""
if sys.version_info < (3, 9, 0)
else """ duration = time.perf_counter_ns() - counter\n"""
)
expected += """ duration = time.perf_counter_ns() - counter\n"""
expected += """ gc.enable()
_log__test__values(return_value, duration, 'code_to_optimize_path:test_relative_validity_no_tree:relative_validity_:1_0')"""
assert ast_unparse(new_module_node).strip("\n") == expected
assert ast.unparse(new_module_node).strip("\n") == expected
def test_InjectPerfAndLogging():
@ -65,69 +56,10 @@ def test_InjectPerfAndLogging():
gc.disable()
counter = time.perf_counter_ns()
return_value = hdbscan.relative_validity_()\n"""
expected += (
""" duration = (time.perf_counter_ns() - counter)\n"""
if sys.version_info < (3, 9, 0)
else """ duration = time.perf_counter_ns() - counter\n"""
)
expected += """ duration = time.perf_counter_ns() - counter\n"""
expected += """ gc.enable()
_log__test__values(return_value, duration, 'code_to_optimize_path:test_relative_validity_no_tree:relative_validity_:1')"""
assert ast_unparse(new_module_node).strip("\n") == expected
def test_perfinjector_only_replay_test():
code = """import pickle
import pytest
from codeflash.tracing.replay_test import get_next_arg_and_return
from codeflash.validation.equivalence import compare_results
from velo.ml.yolo.image_reshaping_utils import prepare_image_for_yolo as velo_ml_yolo_image_reshaping_utils_prepare_image_for_yolo
def test_prepare_image_for_yolo():
for arg_val_pkl, return_val_pkl in get_next_arg_and_return('/home/saurabh/velo/traces/first.trace', 3):
args = pickle.loads(arg_val_pkl)
return_val_1= pickle.loads(return_val_pkl)
ret = velo_ml_yolo_image_reshaping_utils_prepare_image_for_yolo(**args)
assert compare_results(return_val_1, ret)
"""
expected = """import time
import gc
import os
import sqlite3
import pickle
import pickle
import pytest
from codeflash.tracing.replay_test import get_next_arg_and_return
from codeflash.validation.equivalence import compare_results
from velo.ml.yolo.image_reshaping_utils import prepare_image_for_yolo as velo_ml_yolo_image_reshaping_utils_prepare_image_for_yolo
def test_prepare_image_for_yolo():
codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite')
codeflash_cur = codeflash_con.cursor()
codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, iteration_id TEXT, runtime INTEGER, return_value BLOB)')
for arg_val_pkl, return_val_pkl in get_next_arg_and_return('/home/saurabh/velo/traces/first.trace', 3):
args = pickle.loads(arg_val_pkl)
return_val_1 = pickle.loads(return_val_pkl)
gc.disable()
counter = time.perf_counter_ns()
return_value = velo_ml_yolo_image_reshaping_utils_prepare_image_for_yolo(**args)
codeflash_duration = time.perf_counter_ns() - counter
gc.enable()
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?)', ('{module_path}', None, 'test_prepare_image_for_yolo', 'velo_ml_yolo_image_reshaping_utils_prepare_image_for_yolo', '4_2', codeflash_duration, pickle.dumps(return_value)))
codeflash_con.commit()
print(f'#####{module_path}:test_prepare_image_for_yolo:velo_ml_yolo_image_reshaping_utils_prepare_image_for_yolo:4_2#####{{codeflash_duration}}^^^^^')
ret = return_value
assert compare_results(return_val_1, ret)
codeflash_con.close()"""
with tempfile.NamedTemporaryFile(mode="w") as f:
f.write(code)
f.flush()
new_test = inject_profiling_into_existing_test(
f.name, "prepare_image_for_yolo", os.path.dirname(f.name)
)
assert new_test == expected.format(
module_path=os.path.basename(f.name),
tmp_dir_path=get_run_tmp_file("test_return_values"),
)
assert ast.unparse(new_module_node).strip("\n") == expected
def test_remove_bad_assert():
@ -153,14 +85,11 @@ def test_remove_bad_assert():
counter = time.perf_counter_ns()
return_value = hdbscan.relative_validity_()
"""
if sys.version_info < (3, 9, 0):
expected += """ duration = (time.perf_counter_ns() - counter)\n"""
else:
expected += """ duration = time.perf_counter_ns() - counter\n"""
expected += """ duration = time.perf_counter_ns() - counter\n"""
expected += """ gc.enable()
_log__test__values(return_value, duration, 'code_to_optimize_path:test_relative_validity_no_tree:relative_validity_:1')
result = 5"""
assert ast_unparse(new_module_node).strip("\n") == expected
assert ast.unparse(new_module_node).strip("\n") == expected
code = """def test_translate_word_starting_with_vowel():
assert 1 == True
@ -182,11 +111,7 @@ def test_translate_word_starting_with_single_consonant():
counter = time.perf_counter_ns()
return_value = translate('apple')
"""
expected += (
""" duration = (time.perf_counter_ns() - counter)\n"""
if sys.version_info < (3, 9, 0)
else """ duration = time.perf_counter_ns() - counter\n"""
)
expected += """ duration = time.perf_counter_ns() - counter\n"""
# duration = time.perf_counter_ns() - counter
expected += """ gc.enable()
@ -196,93 +121,10 @@ def test_translate_word_starting_with_single_consonant():
gc.disable()
counter = time.perf_counter_ns()
return_value = translate('banana')\n"""
expected += (
""" duration = (time.perf_counter_ns() - counter)\n"""
if sys.version_info < (3, 9, 0)
else """ duration = time.perf_counter_ns() - counter\n"""
)
expected += """ duration = time.perf_counter_ns() - counter\n"""
expected += """ gc.enable()
_log__test__values(return_value, duration, 'code_to_optimize_path:test_translate_word_starting_with_single_consonant:translate:0')"""
assert ast_unparse(new_module_node).strip("\n") == expected
def test_perfinjector_bubble_sort():
code = """import unittest
from code_to_optimize.bubble_sort import sorter
class TestPigLatin(unittest.TestCase):
def test_sort(self):
input = [5, 4, 3, 2, 1, 0]
output = sorter(input)
self.assertEqual(output, [0, 1, 2, 3, 4, 5])
input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
output = sorter(input)
self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
input = list(reversed(range(5000)))
output = sorter(input)
self.assertEqual(output, list(range(5000)))
"""
expected = """import time
import gc
import os
import sqlite3
import pickle
import unittest
from code_to_optimize.bubble_sort import sorter
class TestPigLatin(unittest.TestCase):
def test_sort(self):
codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite')
codeflash_cur = codeflash_con.cursor()
codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, iteration_id TEXT, runtime INTEGER, return_value BLOB)')
input = [5, 4, 3, 2, 1, 0]
gc.disable()
counter = time.perf_counter_ns()
return_value = sorter(input)
codeflash_duration = time.perf_counter_ns() - counter
gc.enable()
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?)', ('{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '5', codeflash_duration, pickle.dumps(return_value)))
codeflash_con.commit()
print(f'#####{module_path}:TestPigLatin.test_sort:sorter:5#####{{codeflash_duration}}^^^^^')
output = return_value
self.assertEqual(output, [0, 1, 2, 3, 4, 5])
input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
gc.disable()
counter = time.perf_counter_ns()
return_value = sorter(input)
codeflash_duration = time.perf_counter_ns() - counter
gc.enable()
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?)', ('{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '8', codeflash_duration, pickle.dumps(return_value)))
codeflash_con.commit()
print(f'#####{module_path}:TestPigLatin.test_sort:sorter:8#####{{codeflash_duration}}^^^^^')
output = return_value
self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
input = list(reversed(range(5000)))
gc.disable()
counter = time.perf_counter_ns()
return_value = sorter(input)
codeflash_duration = time.perf_counter_ns() - counter
gc.enable()
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?)', ('{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '11', codeflash_duration, pickle.dumps(return_value)))
codeflash_con.commit()
print(f'#####{module_path}:TestPigLatin.test_sort:sorter:11#####{{codeflash_duration}}^^^^^')
output = return_value
self.assertEqual(output, list(range(5000)))
codeflash_con.close()"""
with tempfile.NamedTemporaryFile(mode="w") as f:
f.write(code)
f.flush()
new_test = inject_profiling_into_existing_test(f.name, "sorter", os.path.dirname(f.name))
assert new_test == expected.format(
module_path=os.path.basename(f.name),
tmp_dir_path=get_run_tmp_file("test_return_values"),
)
assert ast.unparse(new_module_node).strip("\n") == expected
def test_unittest_generated_tests_bubble_sort():
@ -490,4 +332,4 @@ if __name__ == '__main__':
new_module_node = InjectPerfAndLogging(
function_to_optimize, auxillary_functions, test_module_path
).visit(module_node)
assert ast_unparse(new_module_node).strip("\n") == expected
assert ast.unparse(new_module_node).strip("\n") == expected