Merge branch 'refs/heads/main' into codeflash-trace-decorator

# Conflicts:
#	codeflash/code_utils/config_parser.py
#	codeflash/optimization/function_optimizer.py
This commit is contained in:
Alvin Ryanputra 2025-04-10 21:28:48 -04:00
commit 40e416e0d0
21 changed files with 231 additions and 120 deletions

View file

@ -29,18 +29,25 @@ jobs:
fetch-depth: 0
- name: Validate PR
run: |
# Checking for any workflow changes for security risks
if git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -q "^.github/workflows/"; then
echo "Workflow changes detected."
# Check for any workflow changes
if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.sha }}" | grep -q "^.github/workflows/"; then
echo "⚠️ Workflow changes detected."
# Check if the PR author is allowed
# Get the PR author
AUTHOR="${{ github.event.pull_request.user.login }}"
if [[ "$AUTHOR" != "misrasaurabh1" && "$AUTHOR" != "KRRT7" ]]; then
echo "Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
exit 1
echo "PR Author: $AUTHOR"
# Allowlist check
if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
echo "✅ Authorized user ($AUTHOR). Proceeding."
elif [[ "${{ github.event_name }}" == "pull_request_target" && "${{ github.event.pull_request.state }}" == "open" ]]; then
echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
else
echo "Authorized user ($AUTHOR). Proceeding."
echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
exit 1
fi
else
echo "✅ No workflow file changes detected. Proceeding."
fi
- name: 🐍 Set up Python 3.11 for CLI

View file

@ -24,18 +24,25 @@ jobs:
token: ${{ secrets.GITHUB_TOKEN }}
- name: Validate PR
run: |
# Checking for any workflow changes for security risks
if git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -q "^.github/workflows/"; then
echo "Workflow changes detected."
# Check for any workflow changes
if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.sha }}" | grep -q "^.github/workflows/"; then
echo "⚠️ Workflow changes detected."
# Check if the PR author is allowed
# Get the PR author
AUTHOR="${{ github.event.pull_request.user.login }}"
if [[ "$AUTHOR" != "misrasaurabh1" && "$AUTHOR" != "KRRT7" ]]; then
echo "Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
exit 1
echo "PR Author: $AUTHOR"
# Allowlist check
if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
echo "✅ Authorized user ($AUTHOR). Proceeding."
elif [[ "${{ github.event_name }}" == "pull_request_target" && "${{ github.event.pull_request.state }}" == "open" ]]; then
echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
else
echo "Authorized user ($AUTHOR). Proceeding."
echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
exit 1
fi
else
echo "✅ No workflow file changes detected. Proceeding."
fi
- name: Set up Python 3.11 for CLI

View file

@ -24,18 +24,25 @@ jobs:
token: ${{ secrets.GITHUB_TOKEN }}
- name: Validate PR
run: |
# Checking for any workflow changes for security risks
if git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -q "^.github/workflows/"; then
echo "Workflow changes detected."
# Check for any workflow changes
if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.sha }}" | grep -q "^.github/workflows/"; then
echo "⚠️ Workflow changes detected."
# Check if the PR author is allowed
# Get the PR author
AUTHOR="${{ github.event.pull_request.user.login }}"
if [[ "$AUTHOR" != "misrasaurabh1" && "$AUTHOR" != "KRRT7" ]]; then
echo "Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
exit 1
echo "PR Author: $AUTHOR"
# Allowlist check
if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
echo "✅ Authorized user ($AUTHOR). Proceeding."
elif [[ "${{ github.event_name }}" == "pull_request_target" && "${{ github.event.pull_request.state }}" == "open" ]]; then
echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
else
echo "Authorized user ($AUTHOR). Proceeding."
echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
exit 1
fi
else
echo "✅ No workflow file changes detected. Proceeding."
fi
- name: Set up Python 3.11 for CLI

View file

@ -22,18 +22,25 @@ jobs:
token: ${{ secrets.GITHUB_TOKEN }}
- name: Validate PR
run: |
# Checking for any workflow changes for security risks
if git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -q "^.github/workflows/"; then
echo "Workflow changes detected."
# Check for any workflow changes
if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.sha }}" | grep -q "^.github/workflows/"; then
echo "⚠️ Workflow changes detected."
# Check if the PR author is allowed
# Get the PR author
AUTHOR="${{ github.event.pull_request.user.login }}"
if [[ "$AUTHOR" != "misrasaurabh1" && "$AUTHOR" != "KRRT7" ]]; then
echo "Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
exit 1
echo "PR Author: $AUTHOR"
# Allowlist check
if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
echo "✅ Authorized user ($AUTHOR). Proceeding."
elif [[ "${{ github.event_name }}" == "pull_request_target" && "${{ github.event.pull_request.state }}" == "open" ]]; then
echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
else
echo "Authorized user ($AUTHOR). Proceeding."
echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
exit 1
fi
else
echo "✅ No workflow file changes detected. Proceeding."
fi
- name: Set up Python 3.11 for CLI

View file

@ -24,18 +24,25 @@ jobs:
token: ${{ secrets.GITHUB_TOKEN }}
- name: Validate PR
run: |
# Checking for any workflow changes for security risks
if git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -q "^.github/workflows/"; then
echo "Workflow changes detected."
# Check for any workflow changes
if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.sha }}" | grep -q "^.github/workflows/"; then
echo "⚠️ Workflow changes detected."
# Check if the PR author is allowed
# Get the PR author
AUTHOR="${{ github.event.pull_request.user.login }}"
if [[ "$AUTHOR" != "misrasaurabh1" && "$AUTHOR" != "KRRT7" ]]; then
echo "Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
exit 1
echo "PR Author: $AUTHOR"
# Allowlist check
if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
echo "✅ Authorized user ($AUTHOR). Proceeding."
elif [[ "${{ github.event_name }}" == "pull_request_target" && "${{ github.event.pull_request.state }}" == "open" ]]; then
echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
else
echo "Authorized user ($AUTHOR). Proceeding."
echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
exit 1
fi
else
echo "✅ No workflow file changes detected. Proceeding."
fi
- name: Set up Python 3.11 for CLI

View file

@ -24,18 +24,25 @@ jobs:
token: ${{ secrets.GITHUB_TOKEN }}
- name: Validate PR
run: |
# Checking for any workflow changes for security risks
if git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -q "^.github/workflows/"; then
echo "Workflow changes detected."
# Check for any workflow changes
if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.sha }}" | grep -q "^.github/workflows/"; then
echo "⚠️ Workflow changes detected."
# Check if the PR author is allowed
# Get the PR author
AUTHOR="${{ github.event.pull_request.user.login }}"
if [[ "$AUTHOR" != "misrasaurabh1" && "$AUTHOR" != "KRRT7" ]]; then
echo "Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
exit 1
echo "PR Author: $AUTHOR"
# Allowlist check
if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
echo "✅ Authorized user ($AUTHOR). Proceeding."
elif [[ "${{ github.event_name }}" == "pull_request_target" && "${{ github.event.pull_request.state }}" == "open" ]]; then
echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
else
echo "Authorized user ($AUTHOR). Proceeding."
echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
exit 1
fi
else
echo "✅ No workflow file changes detected. Proceeding."
fi
- name: Set up Python 3.11 for CLI

View file

@ -24,18 +24,25 @@ jobs:
token: ${{ secrets.GITHUB_TOKEN }}
- name: Validate PR
run: |
# Checking for any workflow changes for security risks
if git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -q "^.github/workflows/"; then
echo "Workflow changes detected."
# Check for any workflow changes
if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.sha }}" | grep -q "^.github/workflows/"; then
echo "⚠️ Workflow changes detected."
# Check if the PR author is allowed
# Get the PR author
AUTHOR="${{ github.event.pull_request.user.login }}"
if [[ "$AUTHOR" != "misrasaurabh1" && "$AUTHOR" != "KRRT7" ]]; then
echo "Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
exit 1
echo "PR Author: $AUTHOR"
# Allowlist check
if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
echo "✅ Authorized user ($AUTHOR). Proceeding."
elif [[ "${{ github.event_name }}" == "pull_request_target" && "${{ github.event.pull_request.state }}" == "open" ]]; then
echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
else
echo "Authorized user ($AUTHOR). Proceeding."
echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
exit 1
fi
else
echo "✅ No workflow file changes detected. Proceeding."
fi

View file

@ -22,23 +22,26 @@ jobs:
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Validate PR
run: |
# Checking for any workflow changes for security risks
if git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -q "^.github/workflows/"; then
echo "Workflow changes detected."
# Check if the PR author is allowed
# Check for any workflow changes
if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.sha }}" | grep -q "^.github/workflows/"; then
echo "⚠️ Workflow changes detected."
# Get the PR author
AUTHOR="${{ github.event.pull_request.user.login }}"
if [[ "$AUTHOR" != "misrasaurabh1" && "$AUTHOR" != "KRRT7" ]]; then
echo "Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
exit 1
echo "PR Author: $AUTHOR"
# Allowlist check
if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
echo "✅ Authorized user ($AUTHOR). Proceeding."
elif [[ "${{ github.event_name }}" == "pull_request_target" && "${{ github.event.pull_request.state }}" == "open" ]]; then
echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
else
echo "Authorized user ($AUTHOR). Proceeding."
echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
exit 1
fi
fiif git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -q "end-to-end-topological-sort-test.yaml"; then
echo "This workflow file has been modified. Exiting for security."
exit 1
else
echo "✅ No workflow file changes detected. Proceeding."
fi
- name: Set up Python 3.11 for CLI

View file

@ -0,0 +1,55 @@
name: PR Labeler
on:
pull_request:
paths:
- ".github/workflows/**"
types: [opened, synchronize, reopened]
jobs:
label-workflow-changes:
runs-on: ubuntu-latest
permissions:
pull-requests: write
steps:
- name: Label PR with workflow changes
uses: actions/github-script@v6
with:
script: |
const labelName = 'workflow-modified';
// Check if the label exists
try {
const labels = await github.rest.issues.listLabelsForRepo({
owner: context.repo.owner,
repo: context.repo.repo
});
const labelExists = labels.data.some(label => label.name === labelName);
if (!labelExists) {
// Create the label if it doesn't exist
await github.rest.issues.createLabel({
owner: context.repo.owner,
repo: context.repo.repo,
name: labelName,
color: 'f9d0c4',
description: 'This PR modifies GitHub Actions workflows'
});
console.log(`Label "${labelName}" created`);
} else {
console.log(`Label "${labelName}" already exists`);
}
} catch (error) {
console.error(`Failed to check or create label: ${error.message}`);
throw error;
}
// Add the label to the PR
await github.rest.issues.addLabels({
issue_number: context.issue.number,
owner: context.repo.owner,
repo: context.repo.repo,
labels: [labelName]
});
console.log(`Label "${labelName}" added to the PR.`);

View file

@ -83,7 +83,7 @@ def parse_config_file(
else: # Default to empty list
config[key] = []
assert config["test-framework"] in ["pytest", "unittest"], (
assert config["test-framework"] in {"pytest", "unittest"}, (
"In pyproject.toml, Codeflash only supports the 'test-framework' as pytest and unittest."
)
if len(config["formatter-cmds"]) > 0:

View file

@ -70,7 +70,7 @@ def _pipe_segment_with_colons(align, colwidth):
"""Return a segment of a horizontal line with optional colons which
indicate column's alignment (as in `pipe` output format)."""
w = colwidth
if align in ["right", "decimal"]:
if align in {"right", "decimal"}:
return ("-" * (w - 1)) + ":"
elif align == "center":
return ":" + ("-" * (w - 2)) + ":"
@ -176,7 +176,7 @@ def _isconvertible(conv, string):
def _isnumber(string):
return (
# fast path
type(string) in (float, int)
type(string) in {float, int}
# covers 'NaN', +/- 'inf', and eg. '1e2', as well as any type
# convertible to int/float.
or (
@ -188,7 +188,7 @@ def _isnumber(string):
# just an over/underflow
or (
not (math.isinf(float(string)) or math.isnan(float(string)))
or string.lower() in ["inf", "-inf", "nan"]
or string.lower() in {"inf", "-inf", "nan"}
)
)
)
@ -210,7 +210,7 @@ def _isint(string, inttype=int):
def _isbool(string):
return type(string) is bool or (
isinstance(string, (bytes, str)) and string in ("True", "False")
isinstance(string, (bytes, str)) and string in {"True", "False"}
)
@ -570,7 +570,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"):
# values is a property, has .index => it's likely a pandas.DataFrame (pandas 0.11.0)
keys = list(tabular_data)
if (
showindex in ["default", "always", True]
showindex in {"default", "always", True}
and tabular_data.index.name is not None
):
if isinstance(tabular_data.index.name, list):
@ -686,7 +686,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"):
rows = list(map(lambda r: r if _is_separating_line(r) else list(r), rows))
# add or remove an index column
showindex_is_a_str = type(showindex) in [str, bytes]
showindex_is_a_str = type(showindex) in {str, bytes}
if showindex == "never" or (not _bool(showindex) and not showindex_is_a_str):
pass
@ -820,7 +820,7 @@ def tabulate(
if colglobalalign is not None: # if global alignment provided
aligns = [colglobalalign] * len(cols)
else: # default
aligns = [numalign if ct in [int, float] else stralign for ct in coltypes]
aligns = [numalign if ct in {int, float} else stralign for ct in coltypes]
# then specific alignments
if colalign is not None:
assert isinstance(colalign, Iterable)
@ -1044,4 +1044,4 @@ def _format_table(
output = "\n".join(lines)
return output
else: # a completely empty table
return ""
return ""

View file

@ -16,13 +16,13 @@ def humanize_runtime(time_in_ns: int) -> str:
units = re.split(r",|\s", runtime_human)[1]
if units in ("microseconds", "microsecond"):
if units in {"microseconds", "microsecond"}:
runtime_human = f"{time_micro:.3g}"
elif units in ("milliseconds", "millisecond"):
elif units in {"milliseconds", "millisecond"}:
runtime_human = "%.3g" % (time_micro / 1000)
elif units in ("seconds", "second"):
elif units in {"seconds", "second"}:
runtime_human = "%.3g" % (time_micro / (1000**2))
elif units in ("minutes", "minute"):
elif units in {"minutes", "minute"}:
runtime_human = "%.3g" % (time_micro / (60 * 1000**2))
else: # hours
runtime_human = "%.3g" % (time_micro / (3600 * 1000**2))

View file

@ -830,7 +830,7 @@ class FunctionOptimizer:
line_profile_results = {"timings": {}, "unit": 0, "str_out": ""}
# For the original function - run the tests and get the runtime, plus coverage
with progress_bar(f"Establishing original code baseline for {self.function_to_optimize.function_name}"):
assert (test_framework := self.args.test_framework) in ["pytest", "unittest"]
assert (test_framework := self.args.test_framework) in {"pytest", "unittest"}
success = True
test_env = os.environ.copy()
@ -981,7 +981,8 @@ class FunctionOptimizer:
original_helper_code: dict[Path, str],
file_path_to_helper_classes: dict[Path, set[str]],
) -> Result[OptimizedCandidateResult, str]:
assert (test_framework := self.args.test_framework) in ["pytest", "unittest"]
assert (test_framework := self.args.test_framework) in {"pytest", "unittest"}
with progress_bar("Testing optimization candidate"):
test_env = os.environ.copy()
test_env["CODEFLASH_LOOP_INDEX"] = "0"
@ -1159,7 +1160,7 @@ class FunctionOptimizer:
f"stdout: {run_result.stdout}\n"
f"stderr: {run_result.stderr}\n"
)
if testing_type in [TestingMode.BEHAVIOR, TestingMode.PERFORMANCE]:
if testing_type in {TestingMode.BEHAVIOR, TestingMode.PERFORMANCE}:
results, coverage_results = parse_test_results(
test_xml_path=result_file_path,
test_files=test_files,

View file

@ -10,7 +10,7 @@ from codeflash.cli_cmds.console import logger
class ProfileStats(pstats.Stats):
def __init__(self, trace_file_path: str, time_unit: str = "ns") -> None:
assert Path(trace_file_path).is_file(), f"Trace file {trace_file_path} does not exist"
assert time_unit in ["ns", "us", "ms", "s"], f"Invalid time unit {time_unit}"
assert time_unit in {"ns", "us", "ms", "s"}, f"Invalid time unit {time_unit}"
self.trace_file_path = trace_file_path
self.time_unit = time_unit
logger.debug(hasattr(self, "create_stats"))
@ -59,10 +59,10 @@ class ProfileStats(pstats.Stats):
time_unit = {"ns": "nanoseconds", "us": "microseconds", "ms": "milliseconds", "s": "seconds"}[self.time_unit]
print(f"in {self.total_tt:.3f} {time_unit}", file=self.stream)
print(file=self.stream)
width, list = self.get_print_list(amount)
if list:
width, list_ = self.get_print_list(amount)
if list_:
self.print_title()
for func in list:
for func in list_:
self.print_line(func)
print(file=self.stream)
print(file=self.stream)

View file

@ -42,7 +42,7 @@ def get_function_alias(module: str, function_name: str) -> str:
def create_trace_replay_test(
trace_file: str, functions: list[FunctionModules], test_framework: str = "pytest", max_run_count=100
) -> str:
assert test_framework in ["pytest", "unittest"]
assert test_framework in {"pytest", "unittest"}
imports = f"""import dill as pickle
{"import unittest" if test_framework == "unittest" else ""}

View file

@ -233,7 +233,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
new_keys = {k: v for k, v in new.__dict__.items() if k != "parent"}
return comparator(orig_keys, new_keys, superset_obj)
if type(orig) in [types.BuiltinFunctionType, types.BuiltinMethodType]:
if type(orig) in {types.BuiltinFunctionType, types.BuiltinMethodType}:
return new == orig
if str(type(orig)) == "<class 'object'>":
return True

View file

@ -40,26 +40,29 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
superset_obj = False
if original_test_result.verification_type and (
original_test_result.verification_type
in (VerificationType.INIT_STATE_HELPER, VerificationType.INIT_STATE_FTO)
in {VerificationType.INIT_STATE_HELPER, VerificationType.INIT_STATE_FTO}
):
superset_obj = True
if not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj):
are_equal = False
logger.debug(
"File Name: %s\n"
"Test Type: %s\n"
"Verification Type: %s\n"
"Invocation ID: %s\n"
"Original return value: %s\n"
"Candidate return value: %s\n"
"-------------------",
original_test_result.file_name,
original_test_result.test_type,
original_test_result.verification_type,
original_test_result.id,
original_test_result.return_value,
cdd_test_result.return_value,
)
try:
logger.debug(
"File Name: %s\n"
"Test Type: %s\n"
"Verification Type: %s\n"
"Invocation ID: %s\n"
"Original return value: %s\n"
"Candidate return value: %s\n"
"-------------------",
original_test_result.file_name,
original_test_result.test_type,
original_test_result.verification_type,
original_test_result.id,
original_test_result.return_value,
cdd_test_result.return_value,
)
except Exception as e:
logger.error(e)
break
if (original_test_result.stdout and cdd_test_result.stdout) and not comparator(
original_test_result.stdout, cdd_test_result.stdout
@ -67,7 +70,7 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
are_equal = False
break
if original_test_result.test_type in [TestType.EXISTING_UNIT_TEST, TestType.CONCOLIC_COVERAGE_TEST] and (
if original_test_result.test_type in {TestType.EXISTING_UNIT_TEST, TestType.CONCOLIC_COVERAGE_TEST} and (
cdd_test_result.did_pass != original_test_result.did_pass
):
are_equal = False

View file

@ -16,12 +16,12 @@ def show_func(filename, start_lineno, func_name, timings, unit):
return ''
scalar = 1
if os.path.exists(filename):
out_table+=f'## Function: {func_name}\n'
out_table += f'## Function: {func_name}\n'
# Clear the cache to ensure that we get up-to-date results.
linecache.clearcache()
all_lines = linecache.getlines(filename)
sublines = inspect.getblock(all_lines[start_lineno - 1:])
out_table+='## Total time: %g s\n' % (total_time * unit)
out_table += '## Total time: %g s\n' % (total_time * unit)
# Define minimum column sizes so text fits and usually looks consistent
default_column_sizes = {
'hits': 9,
@ -57,7 +57,7 @@ def show_func(filename, start_lineno, func_name, timings, unit):
if 'def' in line_ or nhits!='':
table_rows.append((nhits, time, per_hit, percent, line_))
pass
out_table+= tabulate(headers=table_cols,tabular_data=table_rows,tablefmt="pipe",colglobalalign=None, preserve_whitespace=True)
out_table += tabulate(headers=table_cols,tabular_data=table_rows,tablefmt="pipe",colglobalalign=None, preserve_whitespace=True)
out_table+='\n'
return out_table
@ -65,12 +65,12 @@ def show_text(stats: dict) -> str:
""" Show text for the given timings.
"""
out_table = ""
out_table+='# Timer unit: %g s\n' % stats['unit']
out_table += '# Timer unit: %g s\n' % stats['unit']
stats_order = sorted(stats['timings'].items())
# Show detailed per-line information for each function.
for (fn, lineno, name), timings in stats_order:
table_md =show_func(fn, lineno, name, stats['timings'][fn, lineno, name], stats['unit'])
out_table+=table_md
table_md = show_func(fn, lineno, name, stats['timings'][fn, lineno, name], stats['unit'])
out_table += table_md
return out_table
def parse_line_profile_results(line_profiler_output_file: Optional[Path]) -> dict:
@ -83,6 +83,6 @@ def parse_line_profile_results(line_profiler_output_file: Optional[Path]) -> dic
stats = pickle.load(f)
stats_dict['timings'] = stats.timings
stats_dict['unit'] = stats.unit
str_out=show_text(stats_dict)
stats_dict['str_out']=str_out
return stats_dict, None
str_out = show_text(stats_dict)
stats_dict['str_out'] = str_out
return stats_dict, None

View file

@ -127,7 +127,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
iteration_id = val[5]
runtime = val[6]
verification_type = val[8]
if verification_type in (VerificationType.INIT_STATE_FTO, VerificationType.INIT_STATE_HELPER):
if verification_type in {VerificationType.INIT_STATE_FTO, VerificationType.INIT_STATE_HELPER}:
test_type = TestType.INIT_STATE_TEST
else:
# TODO : this is because sqlite writes original file module path. Should make it consistent

View file

@ -166,7 +166,7 @@ def run_line_profile_tests(
)
test_files: list[str] = []
for file in test_paths.test_files:
if file.test_type in [TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST] and file.tests_in_file:
if file.test_type in {TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST} and file.tests_in_file:
test_files.extend(
[
str(file.benchmarking_file_path)
@ -226,7 +226,7 @@ def run_benchmarking_tests(
)
test_files: list[str] = []
for file in test_paths.test_files:
if file.test_type in [TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST] and file.tests_in_file:
if file.test_type in {TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST} and file.tests_in_file:
test_files.extend(
[
str(file.benchmarking_file_path)

View file

@ -10,7 +10,7 @@ from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
def get_test_file_path(test_dir: Path, function_name: str, iteration: int = 0, test_type: str = "unit") -> Path:
assert test_type in ["unit", "inspired", "replay", "perf"]
assert test_type in {"unit", "inspired", "replay", "perf"}
function_name = function_name.replace(".", "_")
path = test_dir / f"test_{function_name}__{test_type}_test_{iteration}.py"
if path.exists():