fix tests for context extractor

This commit is contained in:
mohammed 2025-07-25 15:13:10 +03:00
parent e504c879c5
commit 99cd9dc706
No known key found for this signature in database
GPG key ID: 44F9B42770617B9B
3 changed files with 73 additions and 63 deletions

View file

@ -154,7 +154,7 @@ class CodeStringsMarkdown(BaseModel):
def __str__(self) -> str:
if self.cached_code is not None:
return self.cached_code
self.cached_code = "\n\n".join(
self.cached_code = "\n".join(
get_code_block_splitter(block.file_path) + "\n" + block.code for block in self.code_strings
)
return self.cached_code

View file

@ -376,7 +376,7 @@ class FunctionOptimizer:
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
future_line_profile_results = executor.submit(
ai_service_client.optimize_python_code_line_profiler,
source_code=code_context.read_writable_code,
source_code=str(code_context.read_writable_code),
dependency_code=code_context.read_only_context_code,
trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
line_profiler_results=original_code_baseline.line_profile_results["str_out"],

View file

@ -146,7 +146,8 @@ def test_class_method_dependencies() -> None:
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = """
expected_read_write_context = f"""
{get_code_block_splitter(file_path.relative_to(file_path.parent))}
from __future__ import annotations
from collections import defaultdict
@ -199,7 +200,7 @@ class Graph:
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert str(read_write_context).strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -225,9 +226,9 @@ def test_bubble_sort_helper() -> None:
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = """
expected_read_write_context = f"""
{get_code_block_splitter("code_to_optimize/code_directories/retriever/bubble_sort_with_math.py")}
import math
from bubble_sort_with_math import sorter
def sorter(arr):
arr.sort()
@ -235,7 +236,8 @@ def sorter(arr):
print(x)
return arr
{get_code_block_splitter("code_to_optimize/code_directories/retriever/bubble_sort_imported.py")}
from bubble_sort_with_math import sorter
def sort_from_another_file(arr):
sorted_arr = sorter(arr)
@ -258,8 +260,7 @@ def sort_from_another_file(arr):
return sorted_arr
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert str(read_write_context).strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -456,7 +457,8 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = """
expected_read_write_context = f"""
{get_code_block_splitter(file_path.relative_to(opt.args.project_root))}
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
def __init__(self) -> None: ...
@ -645,7 +647,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
return self.__backend__.get_cache_or_call(func=self.__wrapped__, args=args, kwargs=kwargs, lifespan=self.__duration__)
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert str(read_write_context).strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -697,7 +699,8 @@ class HelperClass:
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = """
expected_read_write_context = f"""
{get_code_block_splitter(file_path.relative_to(opt.args.project_root))}
class MyClass:
def __init__(self):
self.x = 1
@ -737,7 +740,7 @@ class HelperClass:
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert str(read_write_context).strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -794,7 +797,8 @@ class HelperClass:
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
# In this scenario, the read-only code context is too long, so the read-only docstrings are removed.
expected_read_write_context = """
expected_read_write_context = f"""
{get_code_block_splitter(file_path.relative_to(opt.args.project_root))}
class MyClass:
def __init__(self):
self.x = 1
@ -832,7 +836,7 @@ class HelperClass:
return self.x
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert str(read_write_context).strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -889,7 +893,8 @@ class HelperClass:
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
# In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely.
expected_read_write_context = """
expected_read_write_context = f"""
{get_code_block_splitter(file_path.relative_to(opt.args.project_root))}
class MyClass:
def __init__(self):
self.x = 1
@ -918,7 +923,7 @@ class HelperClass:
return self.x
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert str(read_write_context).strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -1042,11 +1047,9 @@ def test_repo_helper() -> None:
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = """
expected_read_write_context = f"""
{get_code_block_splitter(path_to_utils.relative_to(project_root))}
import math
import requests
from globals import API_URL
from utils import DataProcessor
class DataProcessor:
@ -1063,7 +1066,10 @@ class DataProcessor:
\"\"\"Add a prefix to the processed data.\"\"\"
return prefix + data
{get_code_block_splitter(path_to_file.relative_to(project_root))}
import requests
from globals import API_URL
from utils import DataProcessor
def fetch_and_process_data():
# Use the global variable for the request
@ -1078,8 +1084,7 @@ def fetch_and_process_data():
processed = processor.add_prefix(processed)
return processed
"""
"""
expected_read_only_context = f"""
```python:{path_to_utils.relative_to(project_root)}
class DataProcessor:
@ -1113,7 +1118,7 @@ def fetch_and_process_data():
return processed
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert str(read_write_context).strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -1134,12 +1139,10 @@ def test_repo_helper_of_helper() -> None:
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = """
expected_read_write_context = f"""
{get_code_block_splitter(path_to_utils.relative_to(project_root))}
import math
from transform_utils import DataTransformer
import requests
from globals import API_URL
from utils import DataProcessor
class DataProcessor:
@ -1156,7 +1159,10 @@ class DataProcessor:
\"\"\"Transform the processed data\"\"\"
return DataTransformer().transform(data)
{get_code_block_splitter(path_to_file.relative_to(project_root))}
import requests
from globals import API_URL
from utils import DataProcessor
def fetch_and_transform_data():
# Use the global variable for the request
@ -1211,8 +1217,7 @@ def fetch_and_transform_data():
return transformed
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert str(read_write_context).strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -1232,10 +1237,8 @@ def test_repo_helper_of_helper_same_class() -> None:
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = """
import math
from transform_utils import DataTransformer
expected_read_write_context = f"""
{get_code_block_splitter(path_to_transform_utils.relative_to(project_root))}
class DataTransformer:
def __init__(self):
self.data = None
@ -1243,7 +1246,9 @@ class DataTransformer:
def transform_using_own_method(self, data):
return self.transform(data)
{get_code_block_splitter(path_to_utils.relative_to(project_root))}
import math
from transform_utils import DataTransformer
class DataProcessor:
@ -1292,7 +1297,7 @@ class DataProcessor:
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert str(read_write_context).strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -1312,10 +1317,8 @@ def test_repo_helper_of_helper_same_file() -> None:
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = """
import math
from transform_utils import DataTransformer
expected_read_write_context = f"""
{get_code_block_splitter(path_to_transform_utils.relative_to(project_root))}
class DataTransformer:
def __init__(self):
self.data = None
@ -1323,7 +1326,9 @@ class DataTransformer:
def transform_using_same_file_function(self, data):
return update_data(data)
{get_code_block_splitter(path_to_utils.relative_to(project_root))}
import math
from transform_utils import DataTransformer
class DataProcessor:
@ -1367,7 +1372,7 @@ class DataProcessor:
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert str(read_write_context).strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -1386,7 +1391,8 @@ def test_repo_helper_all_same_file() -> None:
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = """
expected_read_write_context = f"""
{get_code_block_splitter(path_to_transform_utils.relative_to(project_root))}
class DataTransformer:
def __init__(self):
self.data = None
@ -1428,7 +1434,7 @@ def update_data(data):
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert str(read_write_context).strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -1448,10 +1454,10 @@ def test_repo_helper_circular_dependency() -> None:
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = """
expected_read_write_context = f"""
{get_code_block_splitter(path_to_utils.relative_to(project_root))}
import math
from transform_utils import DataTransformer
from code_to_optimize.code_directories.retriever.utils import DataProcessor
class DataProcessor:
@ -1464,7 +1470,8 @@ class DataProcessor:
\"\"\"Test circular dependency\"\"\"
return DataTransformer().circular_dependency(data)
{get_code_block_splitter(path_to_transform_utils.relative_to(project_root))}
from code_to_optimize.code_directories.retriever.utils import DataProcessor
class DataTransformer:
def __init__(self):
@ -1503,7 +1510,7 @@ class DataTransformer:
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert str(read_write_context).strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -1546,7 +1553,8 @@ def outside_method():
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = """
expected_read_write_context = f"""
{get_code_block_splitter(file_path.relative_to(opt.args.project_root))}
class MyClass:
def __init__(self):
self.x = 1
@ -1568,7 +1576,7 @@ class MyClass:
return self.x + self.y
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert str(read_write_context).strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -1625,11 +1633,11 @@ def function_to_optimize():
return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data()
```
"""
expected_read_write_context = """
expected_read_write_context = f"""
{get_code_block_splitter(path_to_main.relative_to(project_root))}
import requests
from globals import API_URL
from utils import DataProcessor
import code_to_optimize.code_directories.retriever.main
def fetch_and_transform_data():
# Use the global variable for the request
@ -1644,12 +1652,13 @@ def fetch_and_transform_data():
return transformed
{get_code_block_splitter(path_to_fto.relative_to(project_root))}
import code_to_optimize.code_directories.retriever.main
def function_to_optimize():
return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data()
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert str(read_write_context).strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -1808,7 +1817,8 @@ def get_system_details():
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
# The expected contexts
expected_read_write_context = """
expected_read_write_context = f"""
{get_code_block_splitter(main_file_path.relative_to(opt.args.project_root))}
import utility_module
class Calculator:
@ -1892,7 +1902,7 @@ class Calculator:
```
"""
# Verify the contexts match the expected values
assert read_write_context.strip() == expected_read_write_context.strip()
assert str(read_write_context).strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -2050,11 +2060,10 @@ def get_system_details():
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
# The expected contexts
expected_read_write_context = """
expected_read_write_context = f"""
{get_code_block_splitter("utility_module.py")}
# Function that will be used in the main code
import utility_module
def select_precision(precision, fallback_precision):
if precision is None:
return fallback_precision or DEFAULT_PRECISION
@ -2077,7 +2086,8 @@ def select_precision(precision, fallback_precision):
else:
return DEFAULT_PRECISION
{get_code_block_splitter(main_file_path.relative_to(opt.args.project_root))}
import utility_module
class Calculator:
def __init__(self, precision="high", fallback_precision=None, mode="standard"):
@ -2103,7 +2113,7 @@ except ImportError:
CALCULATION_BACKEND = "python"
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert str(read_write_context).strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()