Remove Python 3.8 hacks. Also fix all the tests.
This commit is contained in:
parent
202a0f1e1a
commit
f05db78477
9 changed files with 166 additions and 359 deletions
|
|
@ -1,13 +0,0 @@
|
|||
import sys
|
||||
|
||||
if sys.version_info < (3, 9, 0):
|
||||
import astunparse
|
||||
else:
|
||||
import ast
|
||||
|
||||
|
||||
def ast_unparse(ast_obj) -> str:
|
||||
if sys.version_info < (3, 9, 0):
|
||||
return astunparse.ast_unparse(ast_obj)
|
||||
else:
|
||||
return ast.unparse(ast_obj)
|
||||
|
|
@ -1,13 +1,9 @@
|
|||
import ast
|
||||
import sys
|
||||
from _ast import ClassDef
|
||||
from typing import Any, Optional
|
||||
|
||||
from codeflash.code_utils.code_utils import module_name_from_file_path, get_run_tmp_file
|
||||
|
||||
if sys.version_info < (3, 9, 0):
|
||||
from astunparse import unparse as ast_unparse
|
||||
|
||||
|
||||
class ReplaceCallNodeWithName(ast.NodeTransformer):
|
||||
def __init__(self, only_function_name, new_variable_name="return_value"):
|
||||
|
|
@ -365,7 +361,4 @@ def inject_profiling_into_existing_test(test_path, function_name, root_path):
|
|||
]
|
||||
tree.body = new_imports + tree.body
|
||||
|
||||
if sys.version_info < (3, 9, 0):
|
||||
return ast_unparse(tree)
|
||||
else:
|
||||
return ast.unparse(tree)
|
||||
return ast.unparse(tree)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import logging
|
|||
from typing import Tuple, Optional
|
||||
|
||||
from codeflash.api.aiservice import generate_regression_tests
|
||||
from codeflash.code_utils.ast_unparser import ast_unparse
|
||||
from codeflash.code_utils.code_utils import get_run_tmp_file
|
||||
from codeflash.code_utils.code_utils import module_name_from_file_path
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
|
@ -84,4 +83,4 @@ def merge_unit_tests(unit_test_source: str, inspired_unit_tests: str, test_frame
|
|||
unit_test_source_ast.body = import_list + unit_test_source_ast.body
|
||||
if test_framework == "unittest":
|
||||
unit_test_source_ast = delete_multiple_if_name_main(unit_test_source_ast)
|
||||
return ast_unparse(unit_test_source_ast)
|
||||
return ast.unparse(unit_test_source_ast)
|
||||
|
|
|
|||
140
cli/tests/test_instrumentation.py
Normal file
140
cli/tests/test_instrumentation.py
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
import os.path
|
||||
import tempfile
|
||||
|
||||
from codeflash.code_utils.code_utils import get_run_tmp_file
|
||||
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
|
||||
|
||||
|
||||
def test_perfinjector_bubble_sort():
|
||||
code = """import unittest
|
||||
|
||||
from code_to_optimize.bubble_sort import sorter
|
||||
|
||||
|
||||
class TestPigLatin(unittest.TestCase):
|
||||
def test_sort(self):
|
||||
input = [5, 4, 3, 2, 1, 0]
|
||||
output = sorter(input)
|
||||
self.assertEqual(output, [0, 1, 2, 3, 4, 5])
|
||||
|
||||
input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
|
||||
output = sorter(input)
|
||||
self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
|
||||
|
||||
input = list(reversed(range(5000)))
|
||||
output = sorter(input)
|
||||
self.assertEqual(output, list(range(5000)))
|
||||
"""
|
||||
expected = """import time
|
||||
import gc
|
||||
import os
|
||||
import sqlite3
|
||||
import pickle
|
||||
import unittest
|
||||
from code_to_optimize.bubble_sort import sorter
|
||||
|
||||
class TestPigLatin(unittest.TestCase):
|
||||
|
||||
def test_sort(self):
|
||||
codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
|
||||
codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite')
|
||||
codeflash_cur = codeflash_con.cursor()
|
||||
codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, iteration_id TEXT, runtime INTEGER, return_value BLOB)')
|
||||
input = [5, 4, 3, 2, 1, 0]
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = sorter(input)
|
||||
codeflash_duration = time.perf_counter_ns() - counter
|
||||
gc.enable()
|
||||
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?)', ('{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '5', codeflash_duration, pickle.dumps(return_value)))
|
||||
codeflash_con.commit()
|
||||
print(f'#####{module_path}:TestPigLatin.test_sort:sorter:5#####{{codeflash_duration}}^^^^^')
|
||||
output = return_value
|
||||
self.assertEqual(output, [0, 1, 2, 3, 4, 5])
|
||||
input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = sorter(input)
|
||||
codeflash_duration = time.perf_counter_ns() - counter
|
||||
gc.enable()
|
||||
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?)', ('{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '8', codeflash_duration, pickle.dumps(return_value)))
|
||||
codeflash_con.commit()
|
||||
print(f'#####{module_path}:TestPigLatin.test_sort:sorter:8#####{{codeflash_duration}}^^^^^')
|
||||
output = return_value
|
||||
self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
|
||||
input = list(reversed(range(5000)))
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = sorter(input)
|
||||
codeflash_duration = time.perf_counter_ns() - counter
|
||||
gc.enable()
|
||||
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?)', ('{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '11', codeflash_duration, pickle.dumps(return_value)))
|
||||
codeflash_con.commit()
|
||||
print(f'#####{module_path}:TestPigLatin.test_sort:sorter:11#####{{codeflash_duration}}^^^^^')
|
||||
output = return_value
|
||||
self.assertEqual(output, list(range(5000)))
|
||||
codeflash_con.close()"""
|
||||
with tempfile.NamedTemporaryFile(mode="w") as f:
|
||||
f.write(code)
|
||||
f.flush()
|
||||
new_test = inject_profiling_into_existing_test(f.name, "sorter", os.path.dirname(f.name))
|
||||
assert new_test == expected.format(
|
||||
module_path=os.path.basename(f.name),
|
||||
tmp_dir_path=get_run_tmp_file("test_return_values"),
|
||||
)
|
||||
|
||||
|
||||
def test_perfinjector_only_replay_test():
|
||||
code = """import pickle
|
||||
import pytest
|
||||
from codeflash.tracing.replay_test import get_next_arg_and_return
|
||||
from codeflash.validation.equivalence import compare_results
|
||||
from packagename.ml.yolo.image_reshaping_utils import prepare_image_for_yolo as packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo
|
||||
def test_prepare_image_for_yolo():
|
||||
for arg_val_pkl, return_val_pkl in get_next_arg_and_return('/home/saurabh/packagename/traces/first.trace', 3):
|
||||
args = pickle.loads(arg_val_pkl)
|
||||
return_val_1= pickle.loads(return_val_pkl)
|
||||
ret = packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo(**args)
|
||||
assert compare_results(return_val_1, ret)
|
||||
"""
|
||||
expected = """import time
|
||||
import gc
|
||||
import os
|
||||
import sqlite3
|
||||
import pickle
|
||||
import pickle
|
||||
import pytest
|
||||
from codeflash.tracing.replay_test import get_next_arg_and_return
|
||||
from codeflash.validation.equivalence import compare_results
|
||||
from packagename.ml.yolo.image_reshaping_utils import prepare_image_for_yolo as packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo
|
||||
|
||||
def test_prepare_image_for_yolo():
|
||||
codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
|
||||
codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite')
|
||||
codeflash_cur = codeflash_con.cursor()
|
||||
codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, iteration_id TEXT, runtime INTEGER, return_value BLOB)')
|
||||
for arg_val_pkl, return_val_pkl in get_next_arg_and_return('/home/saurabh/packagename/traces/first.trace', 3):
|
||||
args = pickle.loads(arg_val_pkl)
|
||||
return_val_1 = pickle.loads(return_val_pkl)
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo(**args)
|
||||
codeflash_duration = time.perf_counter_ns() - counter
|
||||
gc.enable()
|
||||
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?)', ('{module_path}', None, 'test_prepare_image_for_yolo', 'packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo', '4_2', codeflash_duration, pickle.dumps(return_value)))
|
||||
codeflash_con.commit()
|
||||
print(f'#####{module_path}:test_prepare_image_for_yolo:packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo:4_2#####{{codeflash_duration}}^^^^^')
|
||||
ret = return_value
|
||||
assert compare_results(return_val_1, ret)
|
||||
codeflash_con.close()"""
|
||||
with tempfile.NamedTemporaryFile(mode="w") as f:
|
||||
f.write(code)
|
||||
f.flush()
|
||||
|
||||
new_test = inject_profiling_into_existing_test(
|
||||
f.name, "prepare_image_for_yolo", os.path.dirname(f.name)
|
||||
)
|
||||
assert new_test == expected.format(
|
||||
module_path=os.path.basename(f.name),
|
||||
tmp_dir_path=get_run_tmp_file("test_return_values"),
|
||||
)
|
||||
|
|
@ -1,5 +1,4 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
os.environ["CODEFLASH_API_KEY"] = "test-key"
|
||||
from codeflash.verification.verifier import merge_unit_tests
|
||||
|
|
@ -113,113 +112,7 @@ def test_tsp_negative_coordinates():
|
|||
gc.enable()
|
||||
_log__test__values(return_value, duration, 'tsp_test_tsp_negative_coordinates__inspired_1')
|
||||
"""
|
||||
if sys.version_info < (3, 9, 0):
|
||||
expected = """
|
||||
import pytest
|
||||
import math
|
||||
import sys
|
||||
import itertools
|
||||
import time
|
||||
import gc
|
||||
from code_to_optimize.tsp import tsp
|
||||
import pytest
|
||||
import math
|
||||
import sys
|
||||
import itertools
|
||||
|
||||
def distance_between(city1: tuple, city2: tuple) -> float:
|
||||
return math.hypot((city1[0] - city2[0]), (city1[1] - city2[1]))
|
||||
|
||||
def test_tsp_decimal_coordinates():
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = tsp([(0.5, 0.5), (1.5, 1.5), (2.5, 2.5)])
|
||||
duration = (time.perf_counter_ns() - counter)
|
||||
gc.enable()
|
||||
_log__test__values(return_value, duration, 'tsp_test_tsp_decimal_coordinates_0')
|
||||
|
||||
def test_tsp_large_coordinate_values():
|
||||
cities = [(1000000, 1000000), (2000000, 2000000), (3000000, 3000000)]
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = tsp(cities)
|
||||
duration = (time.perf_counter_ns() - counter)
|
||||
gc.enable()
|
||||
_log__test__values(return_value, duration, 'tsp_test_tsp_large_coordinate_values_1')
|
||||
|
||||
def distance_between(city1: tuple, city2: tuple) -> float:
|
||||
return math.hypot((city1[0] - city2[0]), (city1[1] - city2[1]))
|
||||
|
||||
def tsp(cities: list[list[int]]):
|
||||
permutations = itertools.permutations(cities)
|
||||
min_distance = sys.maxsize
|
||||
optimal_route = []
|
||||
for permutation in permutations:
|
||||
distance = 0
|
||||
for i in range((len(permutation) - 1)):
|
||||
distance += distance_between(permutation[i], permutation[(i + 1)])
|
||||
distance += distance_between(permutation[(- 1)], permutation[0])
|
||||
if (distance < min_distance):
|
||||
min_distance = distance
|
||||
optimal_route = permutation
|
||||
return (optimal_route, min_distance)
|
||||
|
||||
def test_tsp_more_cities__inspired():
|
||||
cities = [[1, 2], [3, 4], [5, 6], [(- 3), 4], [0, 0]]
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = tsp(cities)
|
||||
duration = (time.perf_counter_ns() - counter)
|
||||
gc.enable()
|
||||
_log__test__values(return_value, duration, 'tsp_test_tsp_more_cities__inspired_1')
|
||||
|
||||
def test_tsp_three_cities__inspired():
|
||||
cities = [[1, 2], [3, 4], [5, 6]]
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = tsp(cities)
|
||||
duration = (time.perf_counter_ns() - counter)
|
||||
gc.enable()
|
||||
_log__test__values(return_value, duration, 'tsp_test_tsp_three_cities__inspired_1')
|
||||
|
||||
def test_tsp_single_city__inspired():
|
||||
cities = [[1, 2]]
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = tsp(cities)
|
||||
duration = (time.perf_counter_ns() - counter)
|
||||
gc.enable()
|
||||
_log__test__values(return_value, duration, 'tsp_test_tsp_single_city__inspired_1')
|
||||
|
||||
def test_tsp_empty_cities__inspired():
|
||||
cities = []
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = tsp(cities)
|
||||
duration = (time.perf_counter_ns() - counter)
|
||||
gc.enable()
|
||||
_log__test__values(return_value, duration, 'tsp_test_tsp_empty_cities__inspired_1')
|
||||
|
||||
def test_tsp_duplicate_cities__inspired():
|
||||
cities = [[1, 2], [3, 4], [1, 2], [3, 4]]
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = tsp(cities)
|
||||
duration = (time.perf_counter_ns() - counter)
|
||||
gc.enable()
|
||||
_log__test__values(return_value, duration, 'tsp_test_tsp_duplicate_cities__inspired_1')
|
||||
|
||||
def test_tsp_negative_coordinates__inspired():
|
||||
cities = [[(- 1), (- 2)], [(- 3), (- 4)], [(- 5), (- 6)]]
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = tsp(cities)
|
||||
duration = (time.perf_counter_ns() - counter)
|
||||
gc.enable()
|
||||
_log__test__values(return_value, duration, 'tsp_test_tsp_negative_coordinates__inspired_1')
|
||||
"""
|
||||
else:
|
||||
expected = """import pytest
|
||||
expected = """import pytest
|
||||
import math
|
||||
import sys
|
||||
import itertools
|
||||
|
|
@ -432,20 +325,12 @@ class TestGetFilteredClusters(unittest.TestCase):
|
|||
counter = time.perf_counter_ns()
|
||||
return_value = get_filtered_clusters(self.cluster_tree, filters)
|
||||
"""
|
||||
expected += (
|
||||
""" duration = (time.perf_counter_ns() - counter)\n"""
|
||||
if sys.version_info < (3, 9, 0)
|
||||
else """ duration = time.perf_counter_ns() - counter\n"""
|
||||
)
|
||||
expected += """ duration = time.perf_counter_ns() - counter\n"""
|
||||
expected += """ gc.enable()
|
||||
_log__test__values(return_value, duration, 'get_filtered_clusters_test_get_filtered_clusters_scenario3_2')
|
||||
|
||||
"""
|
||||
expected += (
|
||||
"""class MockClusterTree():\n"""
|
||||
if sys.version_info < (3, 9, 0)
|
||||
else """class MockClusterTree:\n"""
|
||||
)
|
||||
expected += """class MockClusterTree:\n"""
|
||||
expected += """
|
||||
def __init__(self, clusters_dict, field_indices, stability_column, ordered_ids):
|
||||
self.clusters_dict = clusters_dict
|
||||
|
|
@ -477,11 +362,7 @@ class TestGetFilteredClustersInspired(unittest.TestCase):
|
|||
counter = time.perf_counter_ns()
|
||||
return_value = get_filtered_clusters(self.cluster_tree, filters)
|
||||
"""
|
||||
expected += (
|
||||
""" duration = (time.perf_counter_ns() - counter)\n"""
|
||||
if sys.version_info < (3, 9, 0)
|
||||
else """ duration = time.perf_counter_ns() - counter\n"""
|
||||
)
|
||||
expected += """ duration = time.perf_counter_ns() - counter\n"""
|
||||
expected += """ gc.enable()
|
||||
_log__test__values(return_value, duration, 'get_filtered_clusters_test_get_filtered_clusters_1')
|
||||
|
||||
|
|
@ -496,24 +377,12 @@ class TestGetFilteredClustersInspired(unittest.TestCase):
|
|||
counter = time.perf_counter_ns()
|
||||
return_value = get_filtered_clusters(self.cluster_tree, filters)
|
||||
"""
|
||||
expected += (
|
||||
""" duration = (time.perf_counter_ns() - counter)\n"""
|
||||
if sys.version_info < (3, 9, 0)
|
||||
else """ duration = time.perf_counter_ns() - counter\n"""
|
||||
)
|
||||
expected += """ duration = time.perf_counter_ns() - counter\n"""
|
||||
expected += """ gc.enable()
|
||||
_log__test__values(return_value, duration, 'get_filtered_clusters_test_get_filtered_clusters_with_clusters_5')
|
||||
"""
|
||||
if sys.version_info < (3, 9, 0):
|
||||
expected += """if (__name__ == '__main__'):
|
||||
unittest.main()"""
|
||||
else:
|
||||
expected += """if __name__ == '__main__':
|
||||
expected += """if __name__ == '__main__':
|
||||
unittest.main()"""
|
||||
|
||||
modified_file = merge_unit_tests(unit_tests, inspired_test, "unittest")
|
||||
if sys.version_info < (3, 9, 0):
|
||||
# assert modified_file.strip("\n") == modified_file
|
||||
assert modified_file.strip("\n") == expected
|
||||
else:
|
||||
assert modified_file == expected
|
||||
assert modified_file == expected
|
||||
|
|
|
|||
|
|
@ -1,13 +0,0 @@
|
|||
import sys
|
||||
|
||||
if sys.version_info < (3, 9, 0):
|
||||
import astunparse
|
||||
else:
|
||||
import ast
|
||||
|
||||
|
||||
def ast_unparse(ast_obj) -> str:
|
||||
if sys.version_info < (3, 9, 0):
|
||||
return astunparse.ast_unparse(ast_obj)
|
||||
else:
|
||||
return ast.unparse(ast_obj)
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
import ast
|
||||
|
||||
from injectperf.ast_unparser import ast_unparse
|
||||
from injectperf.instrument_new_tests import InjectPerfAndLogging, inject_logging_code
|
||||
from models.functions_to_optimize import FunctionToOptimize
|
||||
|
||||
|
|
@ -37,6 +36,6 @@ def instrument_test_source(
|
|||
if test_framework == "unittest":
|
||||
new_imports += [ast.Import(names=[ast.alias(name="timeout_decorator")])]
|
||||
new_module_node.body = new_imports + new_module_node.body
|
||||
new_tests = ast_unparse(new_module_node)
|
||||
new_tests = ast.unparse(new_module_node)
|
||||
modified_new_tests = inject_logging_code(new_tests)
|
||||
return modified_new_tests
|
||||
|
|
|
|||
|
|
@ -3,25 +3,19 @@
|
|||
# TODO: This is only here as a temporary reference implementaion of how an early version of LLM inspired tests was written.
|
||||
# It didn't work very well. This should be improved significantly.
|
||||
import ast # used for detecting whether generated Python code is valid
|
||||
import sys
|
||||
from typing import List, Tuple
|
||||
|
||||
from codeflash.verification import EXPLAIN_MODEL, PLAN_MODEL, EXECUTE_MODEL, LLM
|
||||
|
||||
if sys.version_info < (3, 9, 0):
|
||||
from astunparse import unparse as ast_unparse
|
||||
|
||||
import openai # used for calling the OpenAI API
|
||||
|
||||
from codeflash.code_utils.code_extractor import get_code
|
||||
from codeflash.code_utils.code_utils import ellipsis_in_ast, get_imports_from_file
|
||||
from codeflash.discovery.discover_unit_tests import TestsInFile
|
||||
|
||||
from codeflash.verification import EXPLAIN_MODEL, PLAN_MODEL, EXECUTE_MODEL, LLM
|
||||
from codeflash.verification.gen_regression_tests import (
|
||||
print_messages,
|
||||
print_message_delta,
|
||||
)
|
||||
|
||||
from codeflash.code_utils.code_extractor import get_code
|
||||
from codeflash.code_utils.code_utils import ellipsis_in_ast, get_imports_from_file
|
||||
from codeflash.discovery.discover_unit_tests import TestsInFile
|
||||
|
||||
|
||||
def regression_tests_from_function_with_inspiration(
|
||||
function_to_test: str, # Python function to test, as a string
|
||||
|
|
@ -188,10 +182,7 @@ import {unit_test_package} # used for our unit tests
|
|||
code = execution.split("```python")[1].split("```")[0].strip()
|
||||
# TODO: This adds a bunch of redundant imports, clean them up
|
||||
tests_list = [imp for sublist in inspired_test_imports for imp in sublist]
|
||||
if sys.version_info >= (3, 9, 0):
|
||||
code = ast.unparse(tests_list) + "\n" + code
|
||||
else:
|
||||
code = ast_unparse(tests_list) + "\n" + code
|
||||
code = ast.unparse(tests_list) + "\n" + code
|
||||
try:
|
||||
module = ast.parse(code)
|
||||
if ellipsis_in_ast(module):
|
||||
|
|
|
|||
|
|
@ -1,12 +1,7 @@
|
|||
import ast
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
from codeflash.code_utils.code_utils import get_run_tmp_file
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize, FunctionParent
|
||||
from injectperf.ast_unparser import ast_unparse
|
||||
from injectperf.instrument_existing_tests import inject_profiling_into_existing_test
|
||||
from injectperf.instrument_new_tests import InjectPerfAndLogging
|
||||
|
||||
os.environ["CODEFLASH_API_KEY"] = "test-key"
|
||||
|
|
@ -35,14 +30,10 @@ def test_InjectPerfAndLogging_with():
|
|||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = hdbscan.relative_validity_()\n"""
|
||||
expected += (
|
||||
""" duration = (time.perf_counter_ns() - counter)\n"""
|
||||
if sys.version_info < (3, 9, 0)
|
||||
else """ duration = time.perf_counter_ns() - counter\n"""
|
||||
)
|
||||
expected += """ duration = time.perf_counter_ns() - counter\n"""
|
||||
expected += """ gc.enable()
|
||||
_log__test__values(return_value, duration, 'code_to_optimize_path:test_relative_validity_no_tree:relative_validity_:1_0')"""
|
||||
assert ast_unparse(new_module_node).strip("\n") == expected
|
||||
assert ast.unparse(new_module_node).strip("\n") == expected
|
||||
|
||||
|
||||
def test_InjectPerfAndLogging():
|
||||
|
|
@ -65,69 +56,10 @@ def test_InjectPerfAndLogging():
|
|||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = hdbscan.relative_validity_()\n"""
|
||||
expected += (
|
||||
""" duration = (time.perf_counter_ns() - counter)\n"""
|
||||
if sys.version_info < (3, 9, 0)
|
||||
else """ duration = time.perf_counter_ns() - counter\n"""
|
||||
)
|
||||
expected += """ duration = time.perf_counter_ns() - counter\n"""
|
||||
expected += """ gc.enable()
|
||||
_log__test__values(return_value, duration, 'code_to_optimize_path:test_relative_validity_no_tree:relative_validity_:1')"""
|
||||
assert ast_unparse(new_module_node).strip("\n") == expected
|
||||
|
||||
|
||||
def test_perfinjector_only_replay_test():
|
||||
code = """import pickle
|
||||
import pytest
|
||||
from codeflash.tracing.replay_test import get_next_arg_and_return
|
||||
from codeflash.validation.equivalence import compare_results
|
||||
from velo.ml.yolo.image_reshaping_utils import prepare_image_for_yolo as velo_ml_yolo_image_reshaping_utils_prepare_image_for_yolo
|
||||
def test_prepare_image_for_yolo():
|
||||
for arg_val_pkl, return_val_pkl in get_next_arg_and_return('/home/saurabh/velo/traces/first.trace', 3):
|
||||
args = pickle.loads(arg_val_pkl)
|
||||
return_val_1= pickle.loads(return_val_pkl)
|
||||
ret = velo_ml_yolo_image_reshaping_utils_prepare_image_for_yolo(**args)
|
||||
assert compare_results(return_val_1, ret)
|
||||
"""
|
||||
expected = """import time
|
||||
import gc
|
||||
import os
|
||||
import sqlite3
|
||||
import pickle
|
||||
import pickle
|
||||
import pytest
|
||||
from codeflash.tracing.replay_test import get_next_arg_and_return
|
||||
from codeflash.validation.equivalence import compare_results
|
||||
from velo.ml.yolo.image_reshaping_utils import prepare_image_for_yolo as velo_ml_yolo_image_reshaping_utils_prepare_image_for_yolo
|
||||
|
||||
def test_prepare_image_for_yolo():
|
||||
codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
|
||||
codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite')
|
||||
codeflash_cur = codeflash_con.cursor()
|
||||
codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, iteration_id TEXT, runtime INTEGER, return_value BLOB)')
|
||||
for arg_val_pkl, return_val_pkl in get_next_arg_and_return('/home/saurabh/velo/traces/first.trace', 3):
|
||||
args = pickle.loads(arg_val_pkl)
|
||||
return_val_1 = pickle.loads(return_val_pkl)
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = velo_ml_yolo_image_reshaping_utils_prepare_image_for_yolo(**args)
|
||||
codeflash_duration = time.perf_counter_ns() - counter
|
||||
gc.enable()
|
||||
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?)', ('{module_path}', None, 'test_prepare_image_for_yolo', 'velo_ml_yolo_image_reshaping_utils_prepare_image_for_yolo', '4_2', codeflash_duration, pickle.dumps(return_value)))
|
||||
codeflash_con.commit()
|
||||
print(f'#####{module_path}:test_prepare_image_for_yolo:velo_ml_yolo_image_reshaping_utils_prepare_image_for_yolo:4_2#####{{codeflash_duration}}^^^^^')
|
||||
ret = return_value
|
||||
assert compare_results(return_val_1, ret)
|
||||
codeflash_con.close()"""
|
||||
with tempfile.NamedTemporaryFile(mode="w") as f:
|
||||
f.write(code)
|
||||
f.flush()
|
||||
new_test = inject_profiling_into_existing_test(
|
||||
f.name, "prepare_image_for_yolo", os.path.dirname(f.name)
|
||||
)
|
||||
assert new_test == expected.format(
|
||||
module_path=os.path.basename(f.name),
|
||||
tmp_dir_path=get_run_tmp_file("test_return_values"),
|
||||
)
|
||||
assert ast.unparse(new_module_node).strip("\n") == expected
|
||||
|
||||
|
||||
def test_remove_bad_assert():
|
||||
|
|
@ -153,14 +85,11 @@ def test_remove_bad_assert():
|
|||
counter = time.perf_counter_ns()
|
||||
return_value = hdbscan.relative_validity_()
|
||||
"""
|
||||
if sys.version_info < (3, 9, 0):
|
||||
expected += """ duration = (time.perf_counter_ns() - counter)\n"""
|
||||
else:
|
||||
expected += """ duration = time.perf_counter_ns() - counter\n"""
|
||||
expected += """ duration = time.perf_counter_ns() - counter\n"""
|
||||
expected += """ gc.enable()
|
||||
_log__test__values(return_value, duration, 'code_to_optimize_path:test_relative_validity_no_tree:relative_validity_:1')
|
||||
result = 5"""
|
||||
assert ast_unparse(new_module_node).strip("\n") == expected
|
||||
assert ast.unparse(new_module_node).strip("\n") == expected
|
||||
|
||||
code = """def test_translate_word_starting_with_vowel():
|
||||
assert 1 == True
|
||||
|
|
@ -182,11 +111,7 @@ def test_translate_word_starting_with_single_consonant():
|
|||
counter = time.perf_counter_ns()
|
||||
return_value = translate('apple')
|
||||
"""
|
||||
expected += (
|
||||
""" duration = (time.perf_counter_ns() - counter)\n"""
|
||||
if sys.version_info < (3, 9, 0)
|
||||
else """ duration = time.perf_counter_ns() - counter\n"""
|
||||
)
|
||||
expected += """ duration = time.perf_counter_ns() - counter\n"""
|
||||
# duration = time.perf_counter_ns() - counter
|
||||
|
||||
expected += """ gc.enable()
|
||||
|
|
@ -196,93 +121,10 @@ def test_translate_word_starting_with_single_consonant():
|
|||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = translate('banana')\n"""
|
||||
expected += (
|
||||
""" duration = (time.perf_counter_ns() - counter)\n"""
|
||||
if sys.version_info < (3, 9, 0)
|
||||
else """ duration = time.perf_counter_ns() - counter\n"""
|
||||
)
|
||||
expected += """ duration = time.perf_counter_ns() - counter\n"""
|
||||
expected += """ gc.enable()
|
||||
_log__test__values(return_value, duration, 'code_to_optimize_path:test_translate_word_starting_with_single_consonant:translate:0')"""
|
||||
assert ast_unparse(new_module_node).strip("\n") == expected
|
||||
|
||||
|
||||
def test_perfinjector_bubble_sort():
|
||||
code = """import unittest
|
||||
|
||||
from code_to_optimize.bubble_sort import sorter
|
||||
|
||||
|
||||
class TestPigLatin(unittest.TestCase):
|
||||
def test_sort(self):
|
||||
input = [5, 4, 3, 2, 1, 0]
|
||||
output = sorter(input)
|
||||
self.assertEqual(output, [0, 1, 2, 3, 4, 5])
|
||||
|
||||
input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
|
||||
output = sorter(input)
|
||||
self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
|
||||
|
||||
input = list(reversed(range(5000)))
|
||||
output = sorter(input)
|
||||
self.assertEqual(output, list(range(5000)))
|
||||
"""
|
||||
expected = """import time
|
||||
import gc
|
||||
import os
|
||||
import sqlite3
|
||||
import pickle
|
||||
import unittest
|
||||
from code_to_optimize.bubble_sort import sorter
|
||||
|
||||
class TestPigLatin(unittest.TestCase):
|
||||
|
||||
def test_sort(self):
|
||||
codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
|
||||
codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite')
|
||||
codeflash_cur = codeflash_con.cursor()
|
||||
codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, iteration_id TEXT, runtime INTEGER, return_value BLOB)')
|
||||
input = [5, 4, 3, 2, 1, 0]
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = sorter(input)
|
||||
codeflash_duration = time.perf_counter_ns() - counter
|
||||
gc.enable()
|
||||
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?)', ('{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '5', codeflash_duration, pickle.dumps(return_value)))
|
||||
codeflash_con.commit()
|
||||
print(f'#####{module_path}:TestPigLatin.test_sort:sorter:5#####{{codeflash_duration}}^^^^^')
|
||||
output = return_value
|
||||
self.assertEqual(output, [0, 1, 2, 3, 4, 5])
|
||||
input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = sorter(input)
|
||||
codeflash_duration = time.perf_counter_ns() - counter
|
||||
gc.enable()
|
||||
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?)', ('{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '8', codeflash_duration, pickle.dumps(return_value)))
|
||||
codeflash_con.commit()
|
||||
print(f'#####{module_path}:TestPigLatin.test_sort:sorter:8#####{{codeflash_duration}}^^^^^')
|
||||
output = return_value
|
||||
self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
|
||||
input = list(reversed(range(5000)))
|
||||
gc.disable()
|
||||
counter = time.perf_counter_ns()
|
||||
return_value = sorter(input)
|
||||
codeflash_duration = time.perf_counter_ns() - counter
|
||||
gc.enable()
|
||||
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?)', ('{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '11', codeflash_duration, pickle.dumps(return_value)))
|
||||
codeflash_con.commit()
|
||||
print(f'#####{module_path}:TestPigLatin.test_sort:sorter:11#####{{codeflash_duration}}^^^^^')
|
||||
output = return_value
|
||||
self.assertEqual(output, list(range(5000)))
|
||||
codeflash_con.close()"""
|
||||
with tempfile.NamedTemporaryFile(mode="w") as f:
|
||||
f.write(code)
|
||||
f.flush()
|
||||
new_test = inject_profiling_into_existing_test(f.name, "sorter", os.path.dirname(f.name))
|
||||
assert new_test == expected.format(
|
||||
module_path=os.path.basename(f.name),
|
||||
tmp_dir_path=get_run_tmp_file("test_return_values"),
|
||||
)
|
||||
assert ast.unparse(new_module_node).strip("\n") == expected
|
||||
|
||||
|
||||
def test_unittest_generated_tests_bubble_sort():
|
||||
|
|
@ -490,4 +332,4 @@ if __name__ == '__main__':
|
|||
new_module_node = InjectPerfAndLogging(
|
||||
function_to_optimize, auxillary_functions, test_module_path
|
||||
).visit(module_node)
|
||||
assert ast_unparse(new_module_node).strip("\n") == expected
|
||||
assert ast.unparse(new_module_node).strip("\n") == expected
|
||||
|
|
|
|||
Loading…
Reference in a new issue