mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
End-to-end done.
This commit is contained in:
parent
cafc1ecbaa
commit
5c3a2e18d0
6 changed files with 228 additions and 210 deletions
|
|
@ -11,6 +11,7 @@
|
|||
</inspection_tool>
|
||||
<inspection_tool class="Mypy" enabled="true" level="SERVER PROBLEM" enabled_by_default="true" editorAttributes="GENERIC_SERVER_ERROR_OR_WARNING" />
|
||||
<inspection_tool class="PyMissingTypeHintsInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true" />
|
||||
<inspection_tool class="PyNestedDecoratorsInspection" enabled="false" level="WEAK WARNING" enabled_by_default="false" />
|
||||
<inspection_tool class="PyPep8Inspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
||||
<option name="ignoredErrors">
|
||||
<list>
|
||||
|
|
|
|||
|
|
@ -1,154 +1,104 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import importlib.machinery
|
||||
import importlib.util
|
||||
import site
|
||||
import sys
|
||||
import sysconfig
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class ImportedModuleAnalysis(BaseModel, frozen=True):
|
||||
name: str # TODO Crosshair: Validate that this is the basename of full_name.
|
||||
origin: str # TODO Crosshair: Make this an enum. Validate what the prefix of file_path must be depending on origin.
|
||||
full_name: str # TODO Crosshair: Validate that if file_path exists, it can transform into its suffix.
|
||||
file_path: Path | None # TODO Crosshair: Validate that it transforms into full_name, can be None only for std lib.
|
||||
|
||||
# TODO CROSSHAIR Add clone of libcst path to qualified name and package function, add package info to model.
|
||||
class ImportedInternalModuleAnalysis(BaseModel, frozen=True):
|
||||
name: str
|
||||
full_name: str
|
||||
file_path: Path
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def name_is_identifier(cls, v: str) -> str:
|
||||
if not v.isidentifier():
|
||||
raise ValueError("must be an identifier")
|
||||
msg = "must be an identifier"
|
||||
raise ValueError(msg)
|
||||
return v
|
||||
|
||||
# TODO CROSSHAIR Make this into a standalone function.
|
||||
@field_validator("full_name")
|
||||
@classmethod
|
||||
def full_name_is_dotted_identifier(cls, v: str) -> str:
|
||||
if any(not s or not s.isidentifier() for s in v.split(".")):
|
||||
raise ValueError("must be a dotted identifier")
|
||||
msg = "must be a dotted identifier"
|
||||
raise ValueError(msg)
|
||||
return v
|
||||
|
||||
@field_validator("file_path")
|
||||
@classmethod
|
||||
def file_path_exists(cls, v: Path | None) -> Path | None:
|
||||
if v and not v.exists():
|
||||
raise ValueError("must be an existing path")
|
||||
msg = "must be an existing path"
|
||||
raise ValueError(msg)
|
||||
return v
|
||||
|
||||
|
||||
def resolve_module_name(module: str | None, level: int, name: str, current_module_name: str) -> str | None:
|
||||
def parse_imports(code: str) -> Iterator[ast.AST]:
|
||||
return (node for node in ast.walk(ast.parse(code)) if isinstance(node, (ast.Import, ast.ImportFrom)))
|
||||
|
||||
|
||||
def resolve_relative_name(module: str | None, level: int, current_module: str) -> str | None:
|
||||
if level == 0:
|
||||
if module:
|
||||
return module # Absolute import, return module name
|
||||
return name # Edge case
|
||||
current_module_parts = current_module_name.split(".")
|
||||
if level > len(current_module_parts):
|
||||
return None # Invalid relative import
|
||||
base_parts = current_module_parts[:-level]
|
||||
return module
|
||||
current_parts = current_module.split(".")
|
||||
if level > len(current_parts):
|
||||
return None
|
||||
base_parts = current_parts[:-level]
|
||||
if module:
|
||||
base_parts.extend(module.split("."))
|
||||
if name:
|
||||
base_parts.append(name)
|
||||
return ".".join(base_parts)
|
||||
|
||||
|
||||
def collect_imports(node: ast.AST, module_name: str) -> set[str]:
|
||||
imports: set[str] = set()
|
||||
|
||||
def get_module_full_name(node: ast.Import | ast.ImportFrom, current_module: str) -> Iterator[str]:
|
||||
if isinstance(node, ast.Import):
|
||||
imports.update(alias.name for alias in node.names)
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
module = node.module
|
||||
level = node.level
|
||||
if module is None and level > 0:
|
||||
# Relative import with names
|
||||
for alias in node.names:
|
||||
name = alias.name
|
||||
resolved_module = resolve_module_name(module, level, name, module_name)
|
||||
if resolved_module:
|
||||
imports.add(resolved_module)
|
||||
else:
|
||||
# Absolute import, collect the module without imported names
|
||||
resolved_module = resolve_module_name(module, level, "", module_name)
|
||||
if resolved_module:
|
||||
imports.add(resolved_module)
|
||||
else:
|
||||
for child in ast.iter_child_nodes(node):
|
||||
imports.update(collect_imports(child, module_name))
|
||||
|
||||
return imports
|
||||
return (alias.name for alias in node.names)
|
||||
if isinstance(node, ast.ImportFrom):
|
||||
base_module = resolve_relative_name(node.module, node.level, current_module)
|
||||
if base_module is None:
|
||||
return iter(())
|
||||
if node.module is None and node.level > 0:
|
||||
# Relative import with no module specified, e.g., from . import mymodule
|
||||
return (f"{base_module}.{alias.name}" for alias in node.names)
|
||||
# For absolute imports or relative imports with module specified
|
||||
return iter([base_module])
|
||||
return iter(())
|
||||
|
||||
|
||||
def categorize_module(module_name: str, project_root: Path, module_file_path: Path) -> ImportedModuleAnalysis | None:
|
||||
if not module_name:
|
||||
return None
|
||||
|
||||
# Default module search paths
|
||||
try:
|
||||
spec = importlib.util.find_spec(module_name)
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
spec = None
|
||||
|
||||
if spec is None:
|
||||
# Internal modules
|
||||
custom_path_list = [str(module_file_path.parent), str(project_root), *sys.path.copy()]
|
||||
|
||||
try:
|
||||
spec = importlib.machinery.PathFinder.find_spec(module_name, path=custom_path_list)
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
spec = None
|
||||
|
||||
if spec is None:
|
||||
# Module not found
|
||||
origin = "unknown"
|
||||
file_path = None
|
||||
else:
|
||||
spec_origin = spec.origin
|
||||
if spec_origin is None or spec_origin in ("built-in", "frozen"):
|
||||
origin = "standard library"
|
||||
file_path = None
|
||||
else:
|
||||
file_path_path: Path = Path(spec_origin).resolve()
|
||||
stdlib_paths: list[Path] = [
|
||||
Path(p).resolve() for key, p in sysconfig.get_paths().items() if "stdlib" in key or key == "purelib"
|
||||
]
|
||||
site_packages_paths: list[Path] = [Path(p).resolve() for p in site.getsitepackages()]
|
||||
|
||||
if user_site_packages := site.getusersitepackages():
|
||||
site_packages_paths.append(Path(user_site_packages).resolve())
|
||||
|
||||
# Determine the origin based on the file path
|
||||
if any(file_path_path.is_relative_to(p) for p in stdlib_paths):
|
||||
origin = "standard library"
|
||||
elif any(file_path_path.is_relative_to(p) for p in site_packages_paths):
|
||||
origin = "third party"
|
||||
elif file_path_path.is_relative_to(project_root.resolve()):
|
||||
origin = "internal"
|
||||
else:
|
||||
origin = "unknown"
|
||||
|
||||
file_path = file_path_path
|
||||
|
||||
return ImportedModuleAnalysis(
|
||||
name=module_name.split(".")[-1], origin=origin, full_name=module_name, file_path=file_path
|
||||
)
|
||||
def is_internal_module(module_name: str, project_root: Path) -> bool:
|
||||
module_path = module_name.replace(".", "/")
|
||||
possible_paths = [project_root / f"{module_path}.py", project_root / module_path / "__init__.py"]
|
||||
return any(path.exists() for path in possible_paths)
|
||||
|
||||
|
||||
def analyze_imported_modules(code_str: str, module_file_path: Path, project_root: Path) -> list[ImportedModuleAnalysis]:
|
||||
rel_parts: list[str] = list(module_file_path.resolve().relative_to(project_root.resolve()).with_suffix("").parts)
|
||||
if rel_parts and rel_parts[-1] == "__init__":
|
||||
rel_parts = rel_parts[:-1]
|
||||
def get_module_file_path(module_name: str, project_root: Path) -> Path | None:
|
||||
module_path = module_name.replace(".", "/")
|
||||
possible_paths = [project_root / f"{module_path}.py", project_root / module_path / "__init__.py"]
|
||||
for path in possible_paths:
|
||||
if path.exists():
|
||||
return path.resolve()
|
||||
return None
|
||||
|
||||
imported_modules = []
|
||||
for module in collect_imports(ast.parse(code_str), ".".join(rel_parts)):
|
||||
info = categorize_module(module, project_root, module_file_path)
|
||||
if info and info not in imported_modules:
|
||||
imported_modules.append(info)
|
||||
|
||||
return imported_modules
|
||||
def analyze_imported_internal_modules(
|
||||
code_str: str, module_file_path: Path, project_root: Path
|
||||
) -> list[ImportedInternalModuleAnalysis]:
|
||||
"""Analyzes a Python module's code to find all imported internal modules."""
|
||||
module_rel_path = module_file_path.relative_to(project_root).with_suffix("")
|
||||
current_module = ".".join(module_rel_path.parts)
|
||||
|
||||
imports = parse_imports(code_str)
|
||||
module_names = set()
|
||||
for node in imports:
|
||||
module_names.update(get_module_full_name(node, current_module))
|
||||
|
||||
internal_modules = filter(is_internal_module, (module_names, project_root))
|
||||
|
||||
return [
|
||||
ImportedInternalModuleAnalysis(name=mod_name.split(".")[-1], full_name=mod_name, file_path=file_path)
|
||||
for mod_name in internal_modules
|
||||
if (file_path := get_module_file_path(mod_name, project_root)) is not None
|
||||
]
|
||||
|
|
|
|||
|
|
@ -236,7 +236,6 @@ def get_functions_within_git_diff() -> dict[str, list[FunctionToOptimize]]:
|
|||
|
||||
def get_all_files_and_functions(module_root_path: Path) -> dict[str, list[FunctionToOptimize]]:
|
||||
functions: dict[str, list[FunctionToOptimize]] = {}
|
||||
module_root_path = Path(module_root_path)
|
||||
for file_path in module_root_path.rglob("*.py"):
|
||||
# Find all the functions in the file
|
||||
functions.update(find_all_functions_in_file(file_path).items())
|
||||
|
|
|
|||
|
|
@ -18,12 +18,17 @@ from rich.console import Group
|
|||
from rich.panel import Panel
|
||||
from rich.syntax import Syntax
|
||||
from rich.tree import Tree
|
||||
from sqlalchemy import false
|
||||
|
||||
from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
|
||||
from codeflash.cli_cmds.console import code_print, console, logger, progress_bar
|
||||
from codeflash.code_utils import env_utils
|
||||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, extract_code, find_preexisting_objects
|
||||
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module, replace_functions_and_add_imports
|
||||
from codeflash.code_utils.code_replacer import (
|
||||
normalize_code,
|
||||
replace_function_definitions_in_module,
|
||||
replace_functions_and_add_imports,
|
||||
)
|
||||
from codeflash.code_utils.code_utils import (
|
||||
file_name_from_test_module_name,
|
||||
get_run_tmp_file,
|
||||
|
|
@ -155,17 +160,19 @@ class Optimizer:
|
|||
# TODO CROSSHAIR Parse/validate ast, norm. ast, libcst. New pydantic model for code.
|
||||
|
||||
for function_to_optimize in file_to_funcs_to_optimize[path]:
|
||||
# TODO CROSSHAIR Factor out in git tools.
|
||||
# TODO CROSSHAIR Factor out in git tools. Handle no git case.
|
||||
# TODO CROSSHAIR Pydantic model for worktree info: work_tree_root, work_trees. work_tree_branches
|
||||
worktree_root = Path(tempfile.mkdtemp()) # TODO CROSSHAIR Check for randomness side effect issues
|
||||
worktree_root: Path = Path(
|
||||
tempfile.mkdtemp()
|
||||
) # TODO CROSSHAIR Check for randomness side effect issues
|
||||
# TODO CROSSHAIR Make more race condition collision proof. Handle 11 magic number.
|
||||
worktrees = [Path(tempfile.mkdtemp(dir=worktree_root)) for _ in range(11)]
|
||||
worktrees: list[Path] = [Path(tempfile.mkdtemp(dir=worktree_root)) for _ in range(N_CANDIDATES + 1)]
|
||||
for worktree in worktrees:
|
||||
# TODO CROSSHAIR Check for IO errors, collisions (race conditions).
|
||||
worktree.mkdir()
|
||||
# TODO CROSSHAIR Check for IO errors, collisions, add more chaos.
|
||||
subprocess.run(
|
||||
["git", "worktree", "add", "-d", worktree], cwd=self.args.project_root, check=True
|
||||
["git", "worktree", "add", "-d", worktree], cwd=self.args.module_root, check=True
|
||||
)
|
||||
|
||||
function_iterator_count += 1
|
||||
|
|
@ -175,7 +182,12 @@ class Optimizer:
|
|||
)
|
||||
|
||||
best_optimization = self.optimize_function(
|
||||
function_to_optimize, function_to_tests, original_code, imported_internal_module_information
|
||||
function_to_optimize,
|
||||
function_to_tests,
|
||||
original_code,
|
||||
imported_internal_module_information,
|
||||
worktree_root,
|
||||
worktrees,
|
||||
)
|
||||
self.test_files = TestFiles(test_files=[])
|
||||
|
||||
|
|
@ -213,6 +225,8 @@ class Optimizer:
|
|||
function_to_tests: dict[str, list[FunctionCalledInTest]],
|
||||
original_code: str,
|
||||
imported_internal_module_information: dict[Path, dict[str, str]],
|
||||
worktree_root: Path,
|
||||
worktrees: list[Path],
|
||||
) -> Result[BestOptimization, str]:
|
||||
should_run_experiment = self.experiment_id is not None
|
||||
function_trace_id: str = str(uuid.uuid4())
|
||||
|
|
@ -306,8 +320,8 @@ class Optimizer:
|
|||
if candidates is None:
|
||||
continue
|
||||
|
||||
optimized_module_code_strings = [
|
||||
replace_functions_and_add_imports(
|
||||
optimized_module_code_strings = {
|
||||
candidate.optimization_id: replace_functions_and_add_imports(
|
||||
original_code,
|
||||
[function_to_optimize_qualified_name],
|
||||
candidate.source_code,
|
||||
|
|
@ -318,18 +332,17 @@ class Optimizer:
|
|||
self.args.project_root,
|
||||
)
|
||||
for candidate in candidates
|
||||
]
|
||||
|
||||
}
|
||||
callee_module_paths = {callee.file_path for callee in code_context.helper_functions}
|
||||
optimized_callee_modules_code_strings = [
|
||||
{
|
||||
optimized_callee_modules_code_strings = {
|
||||
candidate.optimization_id: {
|
||||
callee_module_path: replace_functions_and_add_imports(
|
||||
imported_internal_module_information[callee_module_path]["code"],
|
||||
(
|
||||
[
|
||||
callee.qualified_name
|
||||
for callee in code_context.helper_functions
|
||||
if callee.file_path == callee_module_path
|
||||
if callee.file_path == callee_module_path and callee.jedi_definition.type != "class"
|
||||
]
|
||||
),
|
||||
candidate.source_code,
|
||||
|
|
@ -342,18 +355,108 @@ class Optimizer:
|
|||
for callee_module_path in callee_module_paths
|
||||
}
|
||||
for candidate in candidates
|
||||
}
|
||||
|
||||
normalized_original_code = normalize_code(original_code)
|
||||
are_optimized_module_code_strings_zero_diff = {
|
||||
candidate.optimization_id: normalize_code(optimized_module_code_strings[candidate.optimization_id])
|
||||
!= normalized_original_code
|
||||
for candidate in candidates
|
||||
}
|
||||
normalized_callee_module_code_strings = {
|
||||
callee_module_path: normalize_code(imported_internal_module_information[callee_module_path]["code"])
|
||||
for callee_module_path in callee_module_paths
|
||||
}
|
||||
are_optimized_callee_module_code_strings_zero_diff = {
|
||||
candidate.optimization_id: {
|
||||
callee_module_path: normalize_code(
|
||||
optimized_callee_modules_code_strings[candidate.optimization_id][callee_module_path]
|
||||
)
|
||||
!= normalized_callee_module_code_strings[callee_module_path]
|
||||
for callee_module_path in callee_module_paths
|
||||
}
|
||||
for candidate in candidates
|
||||
}
|
||||
|
||||
candidates_with_diffs = [
|
||||
candidate
|
||||
for candidate in candidates
|
||||
if not (
|
||||
are_optimized_module_code_strings_zero_diff[candidate.optimization_id]
|
||||
and all(are_optimized_callee_module_code_strings_zero_diff[candidate.optimization_id].values())
|
||||
)
|
||||
]
|
||||
|
||||
# TODO Crosshair: Write optimized modules into 10 git work trees, filter out identical normalized ast.
|
||||
# TODO Crosshair: It's concolic testing time, using Python API of Crosshair.
|
||||
# TODO Crosshair: Collect passes and failures, to compare later with regression verification results.
|
||||
# TODO CROSSHAIR: refactor filtering + write as single function. Put under try/except block.
|
||||
# TODO Crosshair: Precalculate or factor out repeated function_to_optimize.file_path.relative_to(self.args.module_root)
|
||||
for candidate, worktree in zip(candidates_with_diffs, worktrees[1:]):
|
||||
if are_optimized_module_code_strings_zero_diff[candidate.optimization_id]:
|
||||
(worktree / function_to_optimize.file_path.relative_to(self.args.module_root)).write_text(
|
||||
optimized_module_code_strings[i], encoding="utf8"
|
||||
)
|
||||
for callee_module_path in optimized_callee_modules_code_strings[candidate.optimization_id]:
|
||||
if are_optimized_callee_module_code_strings_zero_diff[candidate.optimization_id][
|
||||
callee_module_path
|
||||
]:
|
||||
(worktree / callee_module_path.relative_to(self.args.module_root)).write_text(
|
||||
optimized_callee_modules_code_strings[i][callee_module_path], encoding="utf8"
|
||||
)
|
||||
|
||||
# TODO Crosshair: Factor out relative path munging code, repeated.
|
||||
function_to_optimize_original_worktree_fqn = str(
|
||||
worktrees[0].name / function_to_optimize.file_path.relative_to(self.args.module_root).with_suffix("")
|
||||
).replace("/", ".")
|
||||
|
||||
# TODO CROSSHAIR: Turn into enum.
|
||||
diffbehavior_results: dict[str, int] = {}
|
||||
for candidate_index, candidate in enumerate(candidates_with_diffs, start=1):
|
||||
logger.info(f"Optimization candidate {candidate_index}/{len(candidates_with_diffs)}:")
|
||||
code_print(candidate.source_code)
|
||||
result = subprocess.run(
|
||||
[
|
||||
"crosshair",
|
||||
"diffbehavior",
|
||||
"--max_uninteresting_iterations",
|
||||
"64",
|
||||
str(
|
||||
worktrees[candidate_index].name
|
||||
/ function_to_optimize.file_path.relative_to(self.args.module_root).with_suffix("")
|
||||
).replace("/", "."),
|
||||
function_to_optimize_original_worktree_fqn,
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=worktree_root,
|
||||
check=False,
|
||||
)
|
||||
diffbehavior_results[candidate.optimization_id] = result.returncode
|
||||
|
||||
if result.returncode == 2:
|
||||
logger.info("Inconclusive results from concolic behavior correctness check.")
|
||||
logger.warning(
|
||||
f"Error running crosshair diffbehavior{': '+ result.stderr if result.stderr else '.'}"
|
||||
)
|
||||
elif result.returncode == 1:
|
||||
logger.info(f"Optimization candidate failed concolic behavior correctness check:\n{result.stdout}")
|
||||
elif result.returncode == 0:
|
||||
logger.info(
|
||||
f"Optimization candidate passed concolic behavior correctness check"
|
||||
f"{': \n' + result.stdout.split('\n', 1)[0] if '\n' in result.stdout else '.'}"
|
||||
)
|
||||
if result.stdout.endswith("All paths exhausted, functions are likely the same!\n"):
|
||||
logger.info("All paths exhausted, functions are likely the same!")
|
||||
else:
|
||||
logger.warning("Consider increasing the --max_uninteresting_iterations option.")
|
||||
else:
|
||||
logger.info("Inconclusive results from concolic behavior correctness check.")
|
||||
logger.error("Unknown return code running crosshair diffbehavior.")
|
||||
|
||||
tests_in_file: list[FunctionCalledInTest] = function_to_tests.get(
|
||||
function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root), []
|
||||
)
|
||||
|
||||
best_optimization = self.determine_best_candidate(
|
||||
candidates=candidates,
|
||||
candidates=candidates_with_diffs,
|
||||
code_context=code_context,
|
||||
function_to_optimize=function_to_optimize,
|
||||
original_code=original_code,
|
||||
|
|
@ -361,6 +464,7 @@ class Optimizer:
|
|||
original_helper_code=original_helper_code,
|
||||
function_trace_id=function_trace_id[:-4] + f"EXP{u}" if should_run_experiment else function_trace_id,
|
||||
only_run_this_test_function=tests_in_file,
|
||||
diffbehavior_results=diffbehavior_results,
|
||||
)
|
||||
ph("cli-optimize-function-finished", {"function_trace_id": function_trace_id})
|
||||
|
||||
|
|
@ -420,10 +524,6 @@ class Optimizer:
|
|||
function_trace_id=function_trace_id,
|
||||
)
|
||||
if self.args.all or env_utils.get_pr_number():
|
||||
# Reverting to original code, because optimizing functions in a sequence can lead to
|
||||
# a) Error propagation, where error in one function can cause the next optimization to fail
|
||||
# b) Performance estimates become unstable, as the runtime of an optimization might be
|
||||
# dependent on the runtime of the previous optimization
|
||||
self.write_code_and_helpers(original_code, original_helper_code, function_to_optimize.file_path)
|
||||
for generated_test_path in generated_tests_paths:
|
||||
generated_test_path.unlink(missing_ok=True)
|
||||
|
|
@ -445,6 +545,7 @@ class Optimizer:
|
|||
original_helper_code: dict[Path, str],
|
||||
function_trace_id: str,
|
||||
only_run_this_test_function: list[FunctionCalledInTest] | None = None,
|
||||
diffbehavior_results: dict[str, int],
|
||||
) -> BestOptimization | None:
|
||||
best_optimization: BestOptimization | None = None
|
||||
best_runtime_until_now = original_code_baseline.runtime # The fastest code runtime until now
|
||||
|
|
@ -454,14 +555,14 @@ class Optimizer:
|
|||
is_correct = {}
|
||||
|
||||
logger.info(
|
||||
f"Determining best optimized candidate (out of {len(candidates)}) for {function_to_optimize.qualified_name} ..."
|
||||
f"Determining best optimization candidate (out of {len(candidates)}) for {function_to_optimize.qualified_name} ..."
|
||||
)
|
||||
console.rule()
|
||||
try:
|
||||
for candidate_index, candidate in enumerate(candidates, start=1):
|
||||
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
|
||||
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
|
||||
logger.info(f"Optimized candidate {candidate_index}/{len(candidates)}:")
|
||||
logger.info(f"Optimization candidate {candidate_index}/{len(candidates)}:")
|
||||
code_print(candidate.source_code)
|
||||
try:
|
||||
did_update = self.replace_function_and_helpers_with_optimized_code(
|
||||
|
|
@ -480,18 +581,11 @@ class Optimizer:
|
|||
self.write_code_and_helpers(original_code, original_helper_code, function_to_optimize.file_path)
|
||||
continue
|
||||
|
||||
# Run generated tests if at least one of them passed
|
||||
run_generated_tests = False
|
||||
if original_code_baseline.generated_test_results:
|
||||
for test_result in original_code_baseline.generated_test_results.test_results:
|
||||
if test_result.did_pass:
|
||||
run_generated_tests = True
|
||||
break
|
||||
|
||||
run_results = self.run_optimized_candidate(
|
||||
optimization_candidate_index=candidate_index,
|
||||
original_test_results=original_code_baseline.overall_test_results,
|
||||
tests_in_file=only_run_this_test_function,
|
||||
diffbehavior_result=diffbehavior_results[candidate_index],
|
||||
)
|
||||
if not is_successful(run_results):
|
||||
optimized_runtimes[candidate.optimization_id] = None
|
||||
|
|
@ -968,6 +1062,7 @@ class Optimizer:
|
|||
optimization_candidate_index: int,
|
||||
original_test_results: TestResults | None,
|
||||
tests_in_file: list[FunctionCalledInTest] | None,
|
||||
diffbehavior_result: int,
|
||||
) -> Result[OptimizedCandidateResult, str]:
|
||||
assert (test_framework := self.args.test_framework) in ["pytest", "unittest"]
|
||||
|
||||
|
|
@ -1054,17 +1149,40 @@ class Optimizer:
|
|||
test_results=[result for result in original_test_results.test_results if result.loop_index == 1]
|
||||
)
|
||||
|
||||
# TODO Crosshair: Compare to Concolic Testing results (false negative, match). Collect all comparisons.
|
||||
# TODO Crosshair: Report failure if regression OR concolic failure.
|
||||
if compare_test_results(initial_loop_original_test_results, initial_loop_candidate_results):
|
||||
logger.info("Test results matched!")
|
||||
console.rule()
|
||||
equal_results = True
|
||||
else:
|
||||
logger.info("Test results did not match the test results of the original code.")
|
||||
console.rule()
|
||||
success = False
|
||||
equal_results = False
|
||||
console.rule()
|
||||
|
||||
if diffbehavior_result == 0:
|
||||
logger.info("Concolic behavior correctness check successful!")
|
||||
console.rule()
|
||||
if equal_results:
|
||||
logger.info("True negative: Concolic behavior correctness check successful and test results matched.")
|
||||
else:
|
||||
logger.warning(
|
||||
"False negative for concolic testing: Concolic behavior correctness check successful but test results did not match."
|
||||
)
|
||||
console.rule()
|
||||
elif diffbehavior_result == 1:
|
||||
logger.warning("Concolic behavior correctness check failed.")
|
||||
console.rule()
|
||||
if equal_results:
|
||||
logger.warning(
|
||||
"False negative for regression testing: Concolic behavior correctness check failed but test results matched."
|
||||
)
|
||||
success = false()
|
||||
equal_results = False
|
||||
else:
|
||||
logger.info("True positive: Concolic behavior correctness check failed and test results did not match.")
|
||||
console.rule()
|
||||
else:
|
||||
logger.warning("Concolic behavior correctness check inconclusive.")
|
||||
console.rule()
|
||||
|
||||
if (total_candidate_timing := candidate_results.total_passed_runtime()) == 0:
|
||||
logger.warning("The overall test runtime of the optimized function is 0, couldn't run tests.")
|
||||
|
|
@ -1076,7 +1194,7 @@ class Optimizer:
|
|||
success = False
|
||||
|
||||
if not success:
|
||||
return Failure("Failed to run the optimized candidate.")
|
||||
return Failure("Failed to run the optimization candidate.")
|
||||
logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}")
|
||||
return Success(
|
||||
OptimizedCandidateResult(
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ isort = ">=5.11.0"
|
|||
dill = "^0.3.8"
|
||||
rich = "^13.8.1"
|
||||
lxml = "^5.3.0"
|
||||
crosshair-tool = ">=0.0.76"
|
||||
|
||||
|
||||
[tool.poetry.group.dev]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from pathlib import Path
|
||||
|
||||
from codeflash.code_utils.static_analysis import ImportedModuleAnalysis, analyze_imported_modules
|
||||
from codeflash.code_utils.static_analysis import ImportedInternalModuleAnalysis, analyze_imported_internal_modules
|
||||
|
||||
|
||||
def test_analyze_imported_modules() -> None:
|
||||
|
|
@ -12,14 +12,13 @@ from . import mymodule
|
|||
from datetime import datetime
|
||||
from pandas import DataFrame
|
||||
from pathlib import *
|
||||
from codeflash.code_utils.static_analysis import analyze_imported_modules
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
from codeflash.code_utils.static_analysis import ImportedInternalModuleAnalysis
|
||||
|
||||
def afunction():
|
||||
import datetime
|
||||
def a_function():
|
||||
from codeflash.code_utils.static_analysis import analyze_imported_modules
|
||||
from returns.result import Failure, Success
|
||||
pass
|
||||
"""
|
||||
|
|
@ -28,64 +27,14 @@ def afunction():
|
|||
project_root = (Path(__file__).parent.resolve() / "../").resolve()
|
||||
|
||||
expected_imported_module_analysis = [
|
||||
ImportedModuleAnalysis(
|
||||
name="result",
|
||||
origin="standard library",
|
||||
full_name="returns.result",
|
||||
file_path=Path("/Users/renaud/miniforge3/envs/codeflash312/lib/python3.12/site-packages/returns/result.py"),
|
||||
),
|
||||
ImportedModuleAnalysis(
|
||||
name="pandas",
|
||||
origin="standard library",
|
||||
full_name="pandas",
|
||||
file_path=Path(
|
||||
"/Users/renaud/miniforge3/envs/codeflash312/lib/python3.12/site-packages/pandas/__init__.py"
|
||||
),
|
||||
),
|
||||
ImportedModuleAnalysis(name="sys", origin="standard library", full_name="sys", file_path=None),
|
||||
ImportedModuleAnalysis(
|
||||
name="typing",
|
||||
origin="standard library",
|
||||
full_name="typing",
|
||||
file_path=Path("/Users/renaud/miniforge3/envs/codeflash312/lib/python3.12/typing.py"),
|
||||
),
|
||||
ImportedModuleAnalysis(
|
||||
name="numpy",
|
||||
origin="standard library",
|
||||
full_name="numpy",
|
||||
file_path=Path("/Users/renaud/miniforge3/envs/codeflash312/lib/python3.12/site-packages/numpy/__init__.py"),
|
||||
),
|
||||
ImportedModuleAnalysis(
|
||||
name="datetime",
|
||||
origin="standard library",
|
||||
full_name="datetime",
|
||||
file_path=Path("/Users/renaud/miniforge3/envs/codeflash312/lib/python3.12/datetime.py"),
|
||||
),
|
||||
ImportedModuleAnalysis(
|
||||
name="argparse",
|
||||
origin="standard library",
|
||||
full_name="argparse",
|
||||
file_path=Path("/Users/renaud/miniforge3/envs/codeflash312/lib/python3.12/argparse.py"),
|
||||
),
|
||||
ImportedModuleAnalysis(name="os", origin="standard library", full_name="os", file_path=None),
|
||||
ImportedModuleAnalysis(
|
||||
ImportedInternalModuleAnalysis(
|
||||
name="static_analysis",
|
||||
origin="internal",
|
||||
full_name="codeflash.code_utils.static_analysis",
|
||||
file_path=Path("/Users/renaud/repos/codeflash/cli/codeflash/code_utils/static_analysis.py"),
|
||||
file_path=project_root / Path("codeflash/code_utils/static_analysis.py"),
|
||||
),
|
||||
ImportedModuleAnalysis(
|
||||
name="mymodule",
|
||||
origin="internal",
|
||||
full_name="tests.mymodule",
|
||||
file_path=Path("/Users/renaud/repos/codeflash/cli/tests/mymodule.py"),
|
||||
),
|
||||
ImportedModuleAnalysis(
|
||||
name="pathlib",
|
||||
origin="standard library",
|
||||
full_name="pathlib",
|
||||
file_path=Path("/Users/renaud/miniforge3/envs/codeflash312/lib/python3.12/pathlib.py"),
|
||||
ImportedInternalModuleAnalysis(
|
||||
name="mymodule", full_name="tests.mymodule", file_path=project_root / Path("tests/mymodule.py")
|
||||
),
|
||||
]
|
||||
actual_imported_module_analysis = analyze_imported_modules(code_str, module_file_path, project_root)
|
||||
actual_imported_module_analysis = analyze_imported_internal_modules(code_str, module_file_path, project_root)
|
||||
assert set(actual_imported_module_analysis) == set(expected_imported_module_analysis)
|
||||
|
|
|
|||
Loading…
Reference in a new issue