Ensure loop_count is measured correctly and test_results is not an empty object; clarify the scope of the run_optimized_candidate function

Saurabh Misra 2024-10-28 23:03:02 -07:00
parent e25155f17b
commit 099f555f95
5 changed files with 64 additions and 68 deletions
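The core of the change is how loop_count is derived: rather than carrying a separate times_run counter, the optimizer now takes the highest loop_index recorded across a candidate's test results, falling back to 1 when there are none. A minimal sketch of that derivation, assuming only that each test result exposes a loop_index attribute as the hunks below show:

    # Sketch of the loop_count derivation used in this commit; `test_results` is any
    # iterable of result objects carrying a `loop_index` attribute.
    def derive_loop_count(test_results) -> int:
        loop_indices = {result.loop_index for result in test_results}
        return max(loop_indices) if loop_indices else 1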

View file

@@ -11,6 +11,7 @@ from codeflash.api.aiservice import OptimizedCandidate
from codeflash.discovery.functions_to_optimize import FunctionParent
from codeflash.verification.test_results import TestResults, TestType
# If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully
# qualified name of the method is foo.eggs.Ham.spam, its qualified name is Ham.spam, and its name is spam. The full name
# of the module is foo.eggs.
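The comment above echoes the Python glossary; a quick illustration of the three naming terms using a throwaway class:

    # name vs. qualified name vs. module, per the comment above.
    class Ham:
        def spam(self) -> None:
            pass

    print(Ham.spam.__name__)      # "spam"
    print(Ham.spam.__qualname__)  # "Ham.spam"
    print(Ham.spam.__module__)    # full module name, e.g. "foo.eggs" if Ham lives in foo/eggs.py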
@@ -41,9 +42,9 @@ class CodeOptimizationContext(BaseModel):
class OptimizedCandidateResult(BaseModel):
times_run: int
max_loop_count: int
best_test_runtime: int
best_test_results: TestResults
test_results: TestResults
optimization_candidate_index: int
total_candidate_timing: int

View file

@@ -405,7 +405,6 @@ class Optimizer:
run_results = self.run_optimized_candidate(
optimization_candidate_index=candidate_index,
original_test_results=original_code_baseline.overall_test_results,
best_runtime_until_now=best_runtime_until_now,
tests_in_file=only_run_this_test_function,
)
if not is_successful(run_results):
@@ -421,32 +420,26 @@
original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_test_runtime
)
speedup_ratios[candidate.optimization_id] = perf_gain
loop_count = (
max(all_loop_indices)
if (all_loop_indices := {result.loop_index for result in candidate_result.best_test_results})
else 1
)
tree = Tree(f"Candidate #{candidate_index} - Runtime Information")
if speedup_critic(
candidate_result, original_code_baseline.runtime, best_runtime_until_now
) and quantity_of_tests_critic(candidate_result):
tree.add("This candidate is faster than the previous best candidate. 🚀")
tree.add("Original runtime:").add(f"{humanize_runtime(original_code_baseline.runtime)}")
tree.add("Best test runtime:").add(f"{humanize_runtime(candidate_result.best_test_runtime)}")
tree.add("Speedup ratio:").add(f"{perf_gain:.3f}")
tree.add(f"Original runtime: {humanize_runtime(original_code_baseline.runtime)}")
tree.add(f"Best test runtime: {humanize_runtime(candidate_result.best_test_runtime)} (measured over {candidate_result.max_loop_count} loop{'s' if candidate_result.max_loop_count > 1 else ''})")
tree.add(f"Speedup ratio: {perf_gain:.3f}")
best_optimization = BestOptimization(
candidate=candidate,
helper_functions=code_context.helper_functions,
runtime=best_test_runtime,
winning_test_results=candidate_result.best_test_results,
winning_test_results=candidate_result.test_results,
)
best_runtime_until_now = best_test_runtime
tree.add("runtime").add(f"{candidate_result.total_candidate_timing} (ns)")
tree.add(f"Runtime measured over {loop_count} loop{'s' if loop_count > 1 else ''}").add(
f"Total runtime: {humanize_runtime(best_test_runtime)}"
)
tree.add("Speedup ratio:").add(f"{perf_gain:.3f}")
else:
tree.add(f"Runtime: {humanize_runtime(best_test_runtime)} (measured over {candidate_result.max_loop_count} loop{'s' if candidate_result.max_loop_count > 1 else ''})")
tree.add(f"Speedup ratio: {perf_gain:.3f}")
console.print(tree)
console.rule()
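The per-candidate report is now emitted as single formatted lines on a Rich tree instead of nested child nodes. A minimal sketch of the new reporting shape, assuming rich is available; the humanized formatting is dropped here and the speedup formula is an assumption, not taken from this diff:

    # Sketch only: the real code uses humanize_runtime and performance_gain from the codebase.
    from rich.console import Console
    from rich.tree import Tree

    def report_candidate(index: int, original_ns: int, candidate_ns: int, loop_count: int) -> None:
        perf_gain = (original_ns - candidate_ns) / candidate_ns  # assumed definition of the speedup ratio
        tree = Tree(f"Candidate #{index} - Runtime Information")
        tree.add(f"Best test runtime: {candidate_ns} ns (measured over {loop_count} loop{'s' if loop_count > 1 else ''})")
        tree.add(f"Speedup ratio: {perf_gain:.3f}")
        Console().print(tree)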
@@ -883,7 +876,6 @@ class Optimizer:
*,
optimization_candidate_index: int,
original_test_results: TestResults | None,
best_runtime_until_now: int,
tests_in_file: list[FunctionCalledInTest] | None,
) -> Result[OptimizedCandidateResult, str]:
assert (test_framework := self.args.test_framework) in ["pytest", "unittest"]
@@ -892,9 +884,7 @@
generated_tests_paths = self.test_files.get_by_type(TestType.GENERATED_REGRESSION)
success = True
best_test_results = TestResults()
times_run = 0
test_env = os.environ.copy()
test_env["CODEFLASH_TEST_ITERATION"] = str(optimization_candidate_index)
test_env["CODEFLASH_TRACER_DISABLE"] = "1"
@@ -934,9 +924,15 @@
test_functions=first_test_functions,
testing_time=TOTAL_LOOPING_TIME,
)
loop_count = (
max(all_loop_indices)
if (all_loop_indices := {result.loop_index for result in candidate_results.test_results})
else 1
)
else:
candidate_results = TestResults()
start_time: float = time.time()
loop_count = 0
for i in range(100):
if i >= 5 and time.time() - start_time >= TOTAL_LOOPING_TIME:
break
@@ -948,6 +944,7 @@
test_functions=first_test_functions,
testing_time=TOTAL_LOOPING_TIME,
)
loop_count = i + 1
candidate_results.merge(candidate_loop_results)
initial_loop_candidate_results = TestResults(
@@ -979,8 +976,6 @@
if (total_candidate_timing := candidate_results.total_passed_runtime()) == 0:
logger.warning("The overall test runtime of the optimized function is 0, couldn't run tests.")
console.rule()
if best_runtime_until_now is None or total_candidate_timing < best_runtime_until_now:
best_test_results = candidate_results
get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.bin")).unlink(missing_ok=True)
get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(missing_ok=True)
@@ -992,9 +987,9 @@
return Success(
OptimizedCandidateResult(
times_run=times_run,
max_loop_count=loop_count,
best_test_runtime=total_candidate_timing,
best_test_results=best_test_results,
test_results=candidate_results,
optimization_candidate_index=optimization_candidate_index,
total_candidate_timing=total_candidate_timing,
)
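Downstream code keeps receiving the candidate result through the same Success wrapper; a small sketch of how a caller might read the renamed fields, assuming the returns library that the Success/is_successful calls in this diff appear to come from:

    # Sketch: consuming the Result returned by run_optimized_candidate.
    from returns.pipeline import is_successful
    from returns.result import Result

    def summarize(run_results: Result) -> str:
        if not is_successful(run_results):
            return f"candidate failed: {run_results.failure()}"
        candidate_result = run_results.unwrap()
        return (
            f"{candidate_result.best_test_runtime} ns over "
            f"{candidate_result.max_loop_count} loop(s)"
        )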

View file

@@ -38,7 +38,7 @@ def speedup_critic(
def quantity_of_tests_critic(candidate_result: OptimizedCandidateResult) -> bool:
test_results = candidate_result.best_test_results
test_results = candidate_result.test_results
in_github_actions_mode = bool(env_utils.get_pr_number())
report = test_results.get_test_pass_fail_report_by_type()
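quantity_of_tests_critic now reads the renamed test_results field before building its pass/fail report. The report's shape is not visible in this diff, so the following is only a hypothetical illustration of the kind of decision such a critic makes; the dict layout and threshold are invented for the example:

    # Hypothetical sketch; not the project's real report format.
    def enough_tests(report: dict) -> bool:
        # e.g. report = {"GENERATED_REGRESSION": {"passed": 4, "failed": 0},
        #                "EXISTING_UNIT_TEST": {"passed": 1, "failed": 0}}
        total_passed = sum(counts.get("passed", 0) for counts in report.values())
        return total_passed >= 2  # assumed minimum number of passing tests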

View file

@@ -45,21 +45,6 @@ except ImportError:
def comparator(orig: Any, new: Any) -> bool:
try:
if HAS_SQLALCHEMY:
try:
insp = sqlalchemy.inspection.inspect(orig)
insp = sqlalchemy.inspection.inspect(new)
orig_keys = orig.__dict__
new_keys = new.__dict__
for key in list(orig_keys.keys()):
if key.startswith("_"):
continue
if key not in new_keys or not comparator(orig_keys[key], new_keys[key]):
return False
return True
except sqlalchemy.exc.NoInspectionAvailable:
pass
if type(orig) != type(new):
return False
if isinstance(orig, (list, tuple)):
@@ -93,6 +78,21 @@ def comparator(orig: Any, new: Any) -> bool:
if math.isnan(orig) and math.isnan(new):
return True
return math.isclose(orig, new)
if HAS_SQLALCHEMY:
try:
insp = sqlalchemy.inspection.inspect(orig)
insp = sqlalchemy.inspection.inspect(new)
orig_keys = orig.__dict__
new_keys = new.__dict__
for key in list(orig_keys.keys()):
if key.startswith("_"):
continue
if key not in new_keys or not comparator(orig_keys[key], new_keys[key]):
return False
return True
except sqlalchemy.exc.NoInspectionAvailable:
pass
# scipy condition because dok_matrix type is also an instance of dict, but dict comparison doesn't work for it
if isinstance(orig, dict) and not (HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix)):
if len(orig) != len(new):
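This hunk moves the SQLAlchemy branch after the cheap type and numeric checks, so plain values are compared first and only objects that SQLAlchemy can inspect fall back to attribute-by-attribute comparison. A condensed sketch of the resulting ordering (imports guarded the same way the module guards them; the scipy special case is omitted):

    # Condensed sketch of the comparator ordering after this change; not the full implementation.
    import math
    from typing import Any

    try:
        import sqlalchemy
        import sqlalchemy.exc
        import sqlalchemy.inspection
        HAS_SQLALCHEMY = True
    except ImportError:
        HAS_SQLALCHEMY = False

    def comparator_sketch(orig: Any, new: Any) -> bool:
        if type(orig) != type(new):
            return False
        if isinstance(orig, float):
            if math.isnan(orig) and math.isnan(new):
                return True
            return math.isclose(orig, new)
        if isinstance(orig, (list, tuple)):
            return len(orig) == len(new) and all(comparator_sketch(o, n) for o, n in zip(orig, new))
        if HAS_SQLALCHEMY:
            try:
                sqlalchemy.inspection.inspect(orig)  # raises NoInspectionAvailable for non-ORM objects
                sqlalchemy.inspection.inspect(new)
                return all(
                    key.startswith("_")
                    or (key in new.__dict__ and comparator_sketch(orig.__dict__[key], new.__dict__[key]))
                    for key in orig.__dict__
                )
            except sqlalchemy.exc.NoInspectionAvailable:
                pass
        if isinstance(orig, dict):
            return len(orig) == len(new) and all(
                k in new and comparator_sketch(v, new[k]) for k, v in orig.items()
            )
        return orig == new

The real module's recursion and its dok_matrix handling are richer than this; the point of the sketch is only the new check order.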

View file

@@ -22,9 +22,9 @@ def test_speedup_critic():
original_code_runtime = 1000
best_runtime_until_now = 1000
candidate_result = OptimizedCandidateResult(
times_run=5,
max_loop_count=5,
best_test_runtime=800,
best_test_results=TestResults(),
test_results=TestResults(),
optimization_candidate_index=0,
total_candidate_timing=12,
)
@@ -32,9 +32,9 @@ def test_speedup_critic():
assert speedup_critic(candidate_result, original_code_runtime, best_runtime_until_now) # 20% improvement
candidate_result = OptimizedCandidateResult(
times_run=5,
max_loop_count=5,
best_test_runtime=940,
best_test_results=TestResults(),
test_results=TestResults(),
total_candidate_timing=12,
optimization_candidate_index=0,
)
@@ -45,9 +45,9 @@ def test_speedup_critic():
best_runtime_until_now = 100000
candidate_result = OptimizedCandidateResult(
times_run=5,
max_loop_count=5,
best_test_runtime=94000,
best_test_results=TestResults(),
test_results=TestResults(),
total_candidate_timing=12,
optimization_candidate_index=0,
)
@@ -167,9 +167,9 @@ def test_generated_test_critic():
test_results = [test_1, test_2, test_3]
candidate_result = OptimizedCandidateResult(
times_run=5,
max_loop_count=5,
best_test_runtime=100,
best_test_results=TestResults(test_results=test_results),
test_results=TestResults(test_results=test_results),
total_candidate_timing=12,
optimization_candidate_index=0,
)
@@ -179,9 +179,9 @@ def test_generated_test_critic():
test_results = [test_1, test_3, test_6]
candidate_result = OptimizedCandidateResult(
times_run=5,
max_loop_count=5,
best_test_runtime=100,
best_test_results=TestResults(test_results=test_results),
test_results=TestResults(test_results=test_results),
total_candidate_timing=12,
optimization_candidate_index=0,
)
@@ -191,9 +191,9 @@ def test_generated_test_critic():
test_results = [test_1, test_3, test_4]
candidate_result = OptimizedCandidateResult(
times_run=5,
max_loop_count=5,
best_test_runtime=100,
best_test_results=TestResults(test_results=test_results),
test_results=TestResults(test_results=test_results),
total_candidate_timing=12,
optimization_candidate_index=0,
)
@@ -203,9 +203,9 @@ def test_generated_test_critic():
test_results = [test_1]
candidate_result = OptimizedCandidateResult(
times_run=5,
max_loop_count=5,
best_test_runtime=100,
best_test_results=TestResults(test_results=test_results),
test_results=TestResults(test_results=test_results),
total_candidate_timing=12,
optimization_candidate_index=0,
)
@@ -215,9 +215,9 @@ def test_generated_test_critic():
test_results = [test_1, test_2]
candidate_result = OptimizedCandidateResult(
times_run=5,
max_loop_count=5,
best_test_runtime=100,
best_test_results=TestResults(test_results=test_results),
test_results=TestResults(test_results=test_results),
total_candidate_timing=12,
optimization_candidate_index=0,
)
@@ -227,9 +227,9 @@ def test_generated_test_critic():
test_results = [test_1, test_4, test_6]
candidate_result = OptimizedCandidateResult(
times_run=5,
max_loop_count=5,
best_test_runtime=100,
best_test_results=TestResults(test_results=test_results),
test_results=TestResults(test_results=test_results),
total_candidate_timing=12,
optimization_candidate_index=0,
)
@@ -239,9 +239,9 @@ def test_generated_test_critic():
test_results = [test_4, test_5]
candidate_result = OptimizedCandidateResult(
times_run=5,
max_loop_count=5,
best_test_runtime=100,
best_test_results=TestResults(test_results=test_results),
test_results=TestResults(test_results=test_results),
total_candidate_timing=12,
optimization_candidate_index=0,
)
@@ -251,9 +251,9 @@ def test_generated_test_critic():
test_results = [test_1, test_2, test_3, test_4, test_5]
candidate_result = OptimizedCandidateResult(
times_run=5,
max_loop_count=5,
best_test_runtime=100,
best_test_results=TestResults(test_results=test_results),
test_results=TestResults(test_results=test_results),
total_candidate_timing=12,
optimization_candidate_index=0,
)
@@ -265,9 +265,9 @@ def test_generated_test_critic():
test_results = [test_1, test_2, test_3, test_6]
candidate_result = OptimizedCandidateResult(
times_run=5,
max_loop_count=5,
best_test_runtime=100,
best_test_results=TestResults(test_results=test_results),
test_results=TestResults(test_results=test_results),
total_candidate_timing=12,
optimization_candidate_index=0,
)
@@ -277,9 +277,9 @@ def test_generated_test_critic():
test_results = [test_1, test_2, test_3, test_4]
candidate_result = OptimizedCandidateResult(
times_run=5,
max_loop_count=5,
best_test_runtime=100,
best_test_results=TestResults(test_results=test_results),
test_results=TestResults(test_results=test_results),
total_candidate_timing=12,
optimization_candidate_index=0,
)
@@ -289,9 +289,9 @@ def test_generated_test_critic():
test_results = [test_1, test_2, test_3, test_5]
candidate_result = OptimizedCandidateResult(
times_run=5,
max_loop_count=5,
best_test_runtime=100,
best_test_results=TestResults(test_results=test_results),
test_results=TestResults(test_results=test_results),
total_candidate_timing=12,
optimization_candidate_index=0,
)