flat method rename

This commit is contained in:
mohammed 2025-07-25 17:50:13 +03:00
parent d3e5e6f49e
commit f48c77df38
No known key found for this signature in database
GPG key ID: 44F9B42770617B9B
4 changed files with 28 additions and 28 deletions

View file

@ -85,14 +85,14 @@ def get_code_optimization_context(
) )
# Handle token limits # Handle token limits
final_read_writable_tokens = encoded_tokens_len(final_read_writable_code.__str__) final_read_writable_tokens = encoded_tokens_len(final_read_writable_code.flat)
if final_read_writable_tokens > optim_token_limit: if final_read_writable_tokens > optim_token_limit:
raise ValueError("Read-writable code has exceeded token limit, cannot proceed") raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
# Setup preexisting objects for code replacer # Setup preexisting objects for code replacer
preexisting_objects = set( preexisting_objects = set(
chain( chain(
find_preexisting_objects(final_read_writable_code.__str__), find_preexisting_objects(final_read_writable_code.flat),
*(find_preexisting_objects(codestring.code) for codestring in read_only_code_markdown.code_strings), *(find_preexisting_objects(codestring.code) for codestring in read_only_code_markdown.code_strings),
) )
) )

View file

@ -139,11 +139,11 @@ class CodeString(BaseModel):
file_path: Optional[Path] = None file_path: Optional[Path] = None
SPLITTER_MARKER = "# codeflash-splitter__" LINE_SPLITTER_MARKER_PREFIX = "# codeflash-splitter__"
def get_code_block_splitter(file_path: Path) -> str: def get_code_block_splitter(file_path: Path) -> str:
return f"{SPLITTER_MARKER}{file_path}" return f"{LINE_SPLITTER_MARKER_PREFIX}{file_path}"
class CodeStringsMarkdown(BaseModel): class CodeStringsMarkdown(BaseModel):
@ -151,7 +151,7 @@ class CodeStringsMarkdown(BaseModel):
cached_code: Optional[str] = None cached_code: Optional[str] = None
@property @property
def __str__(self) -> str: def flat(self) -> str:
if self.cached_code is not None: if self.cached_code is not None:
return self.cached_code return self.cached_code
self.cached_code = "\n".join( self.cached_code = "\n".join(
@ -174,7 +174,7 @@ class CodeStringsMarkdown(BaseModel):
@staticmethod @staticmethod
def from_str_with_markers(code_with_markers: str) -> CodeStringsMarkdown: def from_str_with_markers(code_with_markers: str) -> CodeStringsMarkdown:
pattern = rf"{SPLITTER_MARKER}([^\n]+)\n" pattern = rf"{LINE_SPLITTER_MARKER_PREFIX}([^\n]+)\n"
matches = list(re.finditer(pattern, code_with_markers)) matches = list(re.finditer(pattern, code_with_markers))
results = [] results = []

View file

@ -165,7 +165,7 @@ class FunctionOptimizer:
helper_code = f.read() helper_code = f.read()
original_helper_code[helper_function_path] = helper_code original_helper_code[helper_function_path] = helper_code
if has_any_async_functions(code_context.read_writable_code.__str__): if has_any_async_functions(code_context.read_writable_code.flat):
return Failure("Codeflash does not support async functions in the code to optimize.") return Failure("Codeflash does not support async functions in the code to optimize.")
# Random here means that we still attempt optimization with a fractional chance to see if # Random here means that we still attempt optimization with a fractional chance to see if
# last time we could not find an optimization, maybe this time we do. # last time we could not find an optimization, maybe this time we do.
@ -284,7 +284,7 @@ class FunctionOptimizer:
should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() should_run_experiment, code_context, original_helper_code = initialization_result.unwrap()
code_print(code_context.read_writable_code.__str__) code_print(code_context.read_writable_code.flat)
test_setup_result = self.generate_and_instrument_tests( # also generates optimizations test_setup_result = self.generate_and_instrument_tests( # also generates optimizations
code_context, should_run_experiment=should_run_experiment code_context, should_run_experiment=should_run_experiment
@ -376,7 +376,7 @@ class FunctionOptimizer:
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
future_line_profile_results = executor.submit( future_line_profile_results = executor.submit(
ai_service_client.optimize_python_code_line_profiler, ai_service_client.optimize_python_code_line_profiler,
source_code=code_context.read_writable_code.__str__, source_code=code_context.read_writable_code.flat,
dependency_code=code_context.read_only_context_code, dependency_code=code_context.read_only_context_code,
trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
line_profiler_results=original_code_baseline.line_profile_results["str_out"], line_profiler_results=original_code_baseline.line_profile_results["str_out"],
@ -790,7 +790,7 @@ class FunctionOptimizer:
) )
future_optimization_candidates = executor.submit( future_optimization_candidates = executor.submit(
self.aiservice_client.optimize_python_code, self.aiservice_client.optimize_python_code,
read_writable_code.__str__, read_writable_code.flat,
read_only_context_code, read_only_context_code,
self.function_trace_id[:-4] + "EXP0" if run_experiment else self.function_trace_id, self.function_trace_id[:-4] + "EXP0" if run_experiment else self.function_trace_id,
N_CANDIDATES, N_CANDIDATES,
@ -809,7 +809,7 @@ class FunctionOptimizer:
if run_experiment: if run_experiment:
future_candidates_exp = executor.submit( future_candidates_exp = executor.submit(
self.local_aiservice_client.optimize_python_code, self.local_aiservice_client.optimize_python_code,
read_writable_code.__str__, read_writable_code.flat,
read_only_context_code, read_only_context_code,
self.function_trace_id[:-4] + "EXP1", self.function_trace_id[:-4] + "EXP1",
N_CANDIDATES, N_CANDIDATES,

View file

@ -126,7 +126,7 @@ class MainClass:
``` ```
""" """
assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip() assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip() assert hashing_context.strip() == expected_hashing_context.strip()
@ -200,7 +200,7 @@ class Graph:
``` ```
""" """
assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip() assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip() assert hashing_context.strip() == expected_hashing_context.strip()
@ -260,7 +260,7 @@ def sort_from_another_file(arr):
return sorted_arr return sorted_arr
``` ```
""" """
assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip() assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip() assert hashing_context.strip() == expected_hashing_context.strip()
@ -647,7 +647,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
return self.__backend__.get_cache_or_call(func=self.__wrapped__, args=args, kwargs=kwargs, lifespan=self.__duration__) return self.__backend__.get_cache_or_call(func=self.__wrapped__, args=args, kwargs=kwargs, lifespan=self.__duration__)
``` ```
""" """
assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip() assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip() assert hashing_context.strip() == expected_hashing_context.strip()
@ -740,7 +740,7 @@ class HelperClass:
``` ```
""" """
assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip() assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip() assert hashing_context.strip() == expected_hashing_context.strip()
@ -836,7 +836,7 @@ class HelperClass:
return self.x return self.x
``` ```
""" """
assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip() assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip() assert hashing_context.strip() == expected_hashing_context.strip()
@ -923,7 +923,7 @@ class HelperClass:
return self.x return self.x
``` ```
""" """
assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip() assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip() assert hashing_context.strip() == expected_hashing_context.strip()
@ -1118,7 +1118,7 @@ def fetch_and_process_data():
return processed return processed
``` ```
""" """
assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip() assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip() assert hashing_context.strip() == expected_hashing_context.strip()
@ -1217,7 +1217,7 @@ def fetch_and_transform_data():
return transformed return transformed
``` ```
""" """
assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip() assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip() assert hashing_context.strip() == expected_hashing_context.strip()
@ -1297,7 +1297,7 @@ class DataProcessor:
``` ```
""" """
assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip() assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip() assert hashing_context.strip() == expected_hashing_context.strip()
@ -1372,7 +1372,7 @@ class DataProcessor:
``` ```
""" """
assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip() assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip() assert hashing_context.strip() == expected_hashing_context.strip()
@ -1434,7 +1434,7 @@ def update_data(data):
``` ```
""" """
assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip() assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip() assert hashing_context.strip() == expected_hashing_context.strip()
@ -1510,7 +1510,7 @@ class DataTransformer:
``` ```
""" """
assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip() assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip() assert hashing_context.strip() == expected_hashing_context.strip()
@ -1576,7 +1576,7 @@ class MyClass:
return self.x + self.y return self.x + self.y
``` ```
""" """
assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip() assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip() assert hashing_context.strip() == expected_hashing_context.strip()
@ -1658,7 +1658,7 @@ import code_to_optimize.code_directories.retriever.main
def function_to_optimize(): def function_to_optimize():
return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data() return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data()
""" """
assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip() assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip() assert hashing_context.strip() == expected_hashing_context.strip()
@ -1902,7 +1902,7 @@ class Calculator:
``` ```
""" """
# Verify the contexts match the expected values # Verify the contexts match the expected values
assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip() assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip() assert hashing_context.strip() == expected_hashing_context.strip()
@ -2113,7 +2113,7 @@ except ImportError:
CALCULATION_BACKEND = "python" CALCULATION_BACKEND = "python"
``` ```
""" """
assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip() assert read_only_context.strip() == expected_read_only_context.strip()