fix tests for context extractor
This commit is contained in:
parent
e504c879c5
commit
99cd9dc706
3 changed files with 73 additions and 63 deletions
|
|
@ -154,7 +154,7 @@ class CodeStringsMarkdown(BaseModel):
|
|||
def __str__(self) -> str:
|
||||
if self.cached_code is not None:
|
||||
return self.cached_code
|
||||
self.cached_code = "\n\n".join(
|
||||
self.cached_code = "\n".join(
|
||||
get_code_block_splitter(block.file_path) + "\n" + block.code for block in self.code_strings
|
||||
)
|
||||
return self.cached_code
|
||||
|
|
|
|||
|
|
@ -376,7 +376,7 @@ class FunctionOptimizer:
|
|||
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
|
||||
future_line_profile_results = executor.submit(
|
||||
ai_service_client.optimize_python_code_line_profiler,
|
||||
source_code=code_context.read_writable_code,
|
||||
source_code=code_context.read_writable_code.__str__,
|
||||
dependency_code=code_context.read_only_context_code,
|
||||
trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
|
||||
line_profiler_results=original_code_baseline.line_profile_results["str_out"],
|
||||
|
|
|
|||
|
|
@ -146,7 +146,8 @@ def test_class_method_dependencies() -> None:
|
|||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
|
||||
expected_read_write_context = """
|
||||
expected_read_write_context = f"""
|
||||
{get_code_block_splitter(file_path.relative_to(file_path.parent))}
|
||||
from __future__ import annotations
|
||||
from collections import defaultdict
|
||||
|
||||
|
|
@ -199,7 +200,7 @@ class Graph:
|
|||
```
|
||||
"""
|
||||
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_write_context.__str__.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
|
@ -225,9 +226,9 @@ def test_bubble_sort_helper() -> None:
|
|||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
|
||||
expected_read_write_context = """
|
||||
expected_read_write_context = f"""
|
||||
{get_code_block_splitter("code_to_optimize/code_directories/retriever/bubble_sort_with_math.py")}
|
||||
import math
|
||||
from bubble_sort_with_math import sorter
|
||||
|
||||
def sorter(arr):
|
||||
arr.sort()
|
||||
|
|
@ -235,7 +236,8 @@ def sorter(arr):
|
|||
print(x)
|
||||
return arr
|
||||
|
||||
|
||||
{get_code_block_splitter("code_to_optimize/code_directories/retriever/bubble_sort_imported.py")}
|
||||
from bubble_sort_with_math import sorter
|
||||
|
||||
def sort_from_another_file(arr):
|
||||
sorted_arr = sorter(arr)
|
||||
|
|
@ -258,8 +260,7 @@ def sort_from_another_file(arr):
|
|||
return sorted_arr
|
||||
```
|
||||
"""
|
||||
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_write_context.__str__.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
|
@ -456,7 +457,8 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
|||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
expected_read_write_context = """
|
||||
expected_read_write_context = f"""
|
||||
{get_code_block_splitter(file_path.relative_to(opt.args.project_root))}
|
||||
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
|
||||
|
||||
def __init__(self) -> None: ...
|
||||
|
|
@ -645,7 +647,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
|||
return self.__backend__.get_cache_or_call(func=self.__wrapped__, args=args, kwargs=kwargs, lifespan=self.__duration__)
|
||||
```
|
||||
"""
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_write_context.__str__.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
|
@ -697,7 +699,8 @@ class HelperClass:
|
|||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
|
||||
expected_read_write_context = """
|
||||
expected_read_write_context = f"""
|
||||
{get_code_block_splitter(file_path.relative_to(opt.args.project_root))}
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
|
@ -737,7 +740,7 @@ class HelperClass:
|
|||
```
|
||||
"""
|
||||
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_write_context.__str__.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
|
@ -794,7 +797,8 @@ class HelperClass:
|
|||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
# In this scenario, the read-only code context is too long, so the read-only docstrings are removed.
|
||||
expected_read_write_context = """
|
||||
expected_read_write_context = f"""
|
||||
{get_code_block_splitter(file_path.relative_to(opt.args.project_root))}
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
|
@ -832,7 +836,7 @@ class HelperClass:
|
|||
return self.x
|
||||
```
|
||||
"""
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_write_context.__str__.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
|
@ -889,7 +893,8 @@ class HelperClass:
|
|||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
# In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely.
|
||||
expected_read_write_context = """
|
||||
expected_read_write_context = f"""
|
||||
{get_code_block_splitter(file_path.relative_to(opt.args.project_root))}
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
|
@ -918,7 +923,7 @@ class HelperClass:
|
|||
return self.x
|
||||
```
|
||||
"""
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_write_context.__str__.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
|
@ -1042,11 +1047,9 @@ def test_repo_helper() -> None:
|
|||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
expected_read_write_context = """
|
||||
expected_read_write_context = f"""
|
||||
{get_code_block_splitter(path_to_utils.relative_to(project_root))}
|
||||
import math
|
||||
import requests
|
||||
from globals import API_URL
|
||||
from utils import DataProcessor
|
||||
|
||||
class DataProcessor:
|
||||
|
||||
|
|
@ -1063,7 +1066,10 @@ class DataProcessor:
|
|||
\"\"\"Add a prefix to the processed data.\"\"\"
|
||||
return prefix + data
|
||||
|
||||
|
||||
{get_code_block_splitter(path_to_file.relative_to(project_root))}
|
||||
import requests
|
||||
from globals import API_URL
|
||||
from utils import DataProcessor
|
||||
|
||||
def fetch_and_process_data():
|
||||
# Use the global variable for the request
|
||||
|
|
@ -1078,8 +1084,7 @@ def fetch_and_process_data():
|
|||
processed = processor.add_prefix(processed)
|
||||
|
||||
return processed
|
||||
|
||||
"""
|
||||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:{path_to_utils.relative_to(project_root)}
|
||||
class DataProcessor:
|
||||
|
|
@ -1113,7 +1118,7 @@ def fetch_and_process_data():
|
|||
return processed
|
||||
```
|
||||
"""
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_write_context.__str__.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
|
@ -1134,12 +1139,10 @@ def test_repo_helper_of_helper() -> None:
|
|||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
expected_read_write_context = """
|
||||
expected_read_write_context = f"""
|
||||
{get_code_block_splitter(path_to_utils.relative_to(project_root))}
|
||||
import math
|
||||
from transform_utils import DataTransformer
|
||||
import requests
|
||||
from globals import API_URL
|
||||
from utils import DataProcessor
|
||||
|
||||
class DataProcessor:
|
||||
|
||||
|
|
@ -1156,7 +1159,10 @@ class DataProcessor:
|
|||
\"\"\"Transform the processed data\"\"\"
|
||||
return DataTransformer().transform(data)
|
||||
|
||||
|
||||
{get_code_block_splitter(path_to_file.relative_to(project_root))}
|
||||
import requests
|
||||
from globals import API_URL
|
||||
from utils import DataProcessor
|
||||
|
||||
def fetch_and_transform_data():
|
||||
# Use the global variable for the request
|
||||
|
|
@ -1211,8 +1217,7 @@ def fetch_and_transform_data():
|
|||
return transformed
|
||||
```
|
||||
"""
|
||||
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_write_context.__str__.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
|
@ -1232,10 +1237,8 @@ def test_repo_helper_of_helper_same_class() -> None:
|
|||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
expected_read_write_context = """
|
||||
import math
|
||||
from transform_utils import DataTransformer
|
||||
|
||||
expected_read_write_context = f"""
|
||||
{get_code_block_splitter(path_to_transform_utils.relative_to(project_root))}
|
||||
class DataTransformer:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
|
|
@ -1243,7 +1246,9 @@ class DataTransformer:
|
|||
def transform_using_own_method(self, data):
|
||||
return self.transform(data)
|
||||
|
||||
|
||||
{get_code_block_splitter(path_to_utils.relative_to(project_root))}
|
||||
import math
|
||||
from transform_utils import DataTransformer
|
||||
|
||||
class DataProcessor:
|
||||
|
||||
|
|
@ -1292,7 +1297,7 @@ class DataProcessor:
|
|||
```
|
||||
"""
|
||||
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_write_context.__str__.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
|
@ -1312,10 +1317,8 @@ def test_repo_helper_of_helper_same_file() -> None:
|
|||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
expected_read_write_context = """
|
||||
import math
|
||||
from transform_utils import DataTransformer
|
||||
|
||||
expected_read_write_context = f"""
|
||||
{get_code_block_splitter(path_to_transform_utils.relative_to(project_root))}
|
||||
class DataTransformer:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
|
|
@ -1323,7 +1326,9 @@ class DataTransformer:
|
|||
def transform_using_same_file_function(self, data):
|
||||
return update_data(data)
|
||||
|
||||
|
||||
{get_code_block_splitter(path_to_utils.relative_to(project_root))}
|
||||
import math
|
||||
from transform_utils import DataTransformer
|
||||
|
||||
class DataProcessor:
|
||||
|
||||
|
|
@ -1367,7 +1372,7 @@ class DataProcessor:
|
|||
```
|
||||
"""
|
||||
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_write_context.__str__.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
|
@ -1386,7 +1391,8 @@ def test_repo_helper_all_same_file() -> None:
|
|||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
expected_read_write_context = """
|
||||
expected_read_write_context = f"""
|
||||
{get_code_block_splitter(path_to_transform_utils.relative_to(project_root))}
|
||||
class DataTransformer:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
|
|
@ -1428,7 +1434,7 @@ def update_data(data):
|
|||
```
|
||||
"""
|
||||
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_write_context.__str__.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
|
@ -1448,10 +1454,10 @@ def test_repo_helper_circular_dependency() -> None:
|
|||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
expected_read_write_context = """
|
||||
expected_read_write_context = f"""
|
||||
{get_code_block_splitter(path_to_utils.relative_to(project_root))}
|
||||
import math
|
||||
from transform_utils import DataTransformer
|
||||
from code_to_optimize.code_directories.retriever.utils import DataProcessor
|
||||
|
||||
class DataProcessor:
|
||||
|
||||
|
|
@ -1464,7 +1470,8 @@ class DataProcessor:
|
|||
\"\"\"Test circular dependency\"\"\"
|
||||
return DataTransformer().circular_dependency(data)
|
||||
|
||||
|
||||
{get_code_block_splitter(path_to_transform_utils.relative_to(project_root))}
|
||||
from code_to_optimize.code_directories.retriever.utils import DataProcessor
|
||||
|
||||
class DataTransformer:
|
||||
def __init__(self):
|
||||
|
|
@ -1503,7 +1510,7 @@ class DataTransformer:
|
|||
```
|
||||
"""
|
||||
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_write_context.__str__.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
|
@ -1546,7 +1553,8 @@ def outside_method():
|
|||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
expected_read_write_context = """
|
||||
expected_read_write_context = f"""
|
||||
{get_code_block_splitter(file_path.relative_to(opt.args.project_root))}
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
|
@ -1568,7 +1576,7 @@ class MyClass:
|
|||
return self.x + self.y
|
||||
```
|
||||
"""
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_write_context.__str__.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
|
@ -1625,11 +1633,11 @@ def function_to_optimize():
|
|||
return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data()
|
||||
```
|
||||
"""
|
||||
expected_read_write_context = """
|
||||
expected_read_write_context = f"""
|
||||
{get_code_block_splitter(path_to_main.relative_to(project_root))}
|
||||
import requests
|
||||
from globals import API_URL
|
||||
from utils import DataProcessor
|
||||
import code_to_optimize.code_directories.retriever.main
|
||||
|
||||
def fetch_and_transform_data():
|
||||
# Use the global variable for the request
|
||||
|
|
@ -1644,12 +1652,13 @@ def fetch_and_transform_data():
|
|||
|
||||
return transformed
|
||||
|
||||
|
||||
{get_code_block_splitter(path_to_fto.relative_to(project_root))}
|
||||
import code_to_optimize.code_directories.retriever.main
|
||||
|
||||
def function_to_optimize():
|
||||
return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data()
|
||||
"""
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_write_context.__str__.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
|
@ -1808,7 +1817,8 @@ def get_system_details():
|
|||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
# The expected contexts
|
||||
expected_read_write_context = """
|
||||
expected_read_write_context = f"""
|
||||
{get_code_block_splitter(main_file_path.relative_to(opt.args.project_root))}
|
||||
import utility_module
|
||||
|
||||
class Calculator:
|
||||
|
|
@ -1892,7 +1902,7 @@ class Calculator:
|
|||
```
|
||||
"""
|
||||
# Verify the contexts match the expected values
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_write_context.__str__.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
|
@ -2050,11 +2060,10 @@ def get_system_details():
|
|||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
# The expected contexts
|
||||
expected_read_write_context = """
|
||||
expected_read_write_context = f"""
|
||||
{get_code_block_splitter("utility_module.py")}
|
||||
# Function that will be used in the main code
|
||||
|
||||
import utility_module
|
||||
|
||||
def select_precision(precision, fallback_precision):
|
||||
if precision is None:
|
||||
return fallback_precision or DEFAULT_PRECISION
|
||||
|
|
@ -2077,7 +2086,8 @@ def select_precision(precision, fallback_precision):
|
|||
else:
|
||||
return DEFAULT_PRECISION
|
||||
|
||||
|
||||
{get_code_block_splitter(main_file_path.relative_to(opt.args.project_root))}
|
||||
import utility_module
|
||||
|
||||
class Calculator:
|
||||
def __init__(self, precision="high", fallback_precision=None, mode="standard"):
|
||||
|
|
@ -2103,7 +2113,7 @@ except ImportError:
|
|||
CALCULATION_BACKEND = "python"
|
||||
```
|
||||
"""
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_write_context.__str__.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue