mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Merge branch 'refs/heads/main' into codeflash-trace-decorator
# Conflicts: # codeflash/code_utils/config_parser.py # codeflash/optimization/function_optimizer.py
This commit is contained in:
commit
40e416e0d0
21 changed files with 231 additions and 120 deletions
23
.github/workflows/codeflash-optimize.yaml
vendored
23
.github/workflows/codeflash-optimize.yaml
vendored
|
|
@ -29,18 +29,25 @@ jobs:
|
|||
fetch-depth: 0
|
||||
- name: Validate PR
|
||||
run: |
|
||||
# Checking for any workflow changes for security risks
|
||||
if git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -q "^.github/workflows/"; then
|
||||
echo "Workflow changes detected."
|
||||
# Check for any workflow changes
|
||||
if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.sha }}" | grep -q "^.github/workflows/"; then
|
||||
echo "⚠️ Workflow changes detected."
|
||||
|
||||
# Check if the PR author is allowed
|
||||
# Get the PR author
|
||||
AUTHOR="${{ github.event.pull_request.user.login }}"
|
||||
if [[ "$AUTHOR" != "misrasaurabh1" && "$AUTHOR" != "KRRT7" ]]; then
|
||||
echo "Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
|
||||
exit 1
|
||||
echo "PR Author: $AUTHOR"
|
||||
|
||||
# Allowlist check
|
||||
if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
|
||||
echo "✅ Authorized user ($AUTHOR). Proceeding."
|
||||
elif [[ "${{ github.event_name }}" == "pull_request_target" && "${{ github.event.pull_request.state }}" == "open" ]]; then
|
||||
echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
|
||||
else
|
||||
echo "Authorized user ($AUTHOR). Proceeding."
|
||||
echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo "✅ No workflow file changes detected. Proceeding."
|
||||
fi
|
||||
|
||||
- name: 🐍 Set up Python 3.11 for CLI
|
||||
|
|
|
|||
|
|
@ -24,18 +24,25 @@ jobs:
|
|||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Validate PR
|
||||
run: |
|
||||
# Checking for any workflow changes for security risks
|
||||
if git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -q "^.github/workflows/"; then
|
||||
echo "Workflow changes detected."
|
||||
# Check for any workflow changes
|
||||
if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.sha }}" | grep -q "^.github/workflows/"; then
|
||||
echo "⚠️ Workflow changes detected."
|
||||
|
||||
# Check if the PR author is allowed
|
||||
# Get the PR author
|
||||
AUTHOR="${{ github.event.pull_request.user.login }}"
|
||||
if [[ "$AUTHOR" != "misrasaurabh1" && "$AUTHOR" != "KRRT7" ]]; then
|
||||
echo "Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
|
||||
exit 1
|
||||
echo "PR Author: $AUTHOR"
|
||||
|
||||
# Allowlist check
|
||||
if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
|
||||
echo "✅ Authorized user ($AUTHOR). Proceeding."
|
||||
elif [[ "${{ github.event_name }}" == "pull_request_target" && "${{ github.event.pull_request.state }}" == "open" ]]; then
|
||||
echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
|
||||
else
|
||||
echo "Authorized user ($AUTHOR). Proceeding."
|
||||
echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo "✅ No workflow file changes detected. Proceeding."
|
||||
fi
|
||||
|
||||
- name: Set up Python 3.11 for CLI
|
||||
|
|
|
|||
|
|
@ -24,18 +24,25 @@ jobs:
|
|||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Validate PR
|
||||
run: |
|
||||
# Checking for any workflow changes for security risks
|
||||
if git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -q "^.github/workflows/"; then
|
||||
echo "Workflow changes detected."
|
||||
# Check for any workflow changes
|
||||
if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.sha }}" | grep -q "^.github/workflows/"; then
|
||||
echo "⚠️ Workflow changes detected."
|
||||
|
||||
# Check if the PR author is allowed
|
||||
# Get the PR author
|
||||
AUTHOR="${{ github.event.pull_request.user.login }}"
|
||||
if [[ "$AUTHOR" != "misrasaurabh1" && "$AUTHOR" != "KRRT7" ]]; then
|
||||
echo "Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
|
||||
exit 1
|
||||
echo "PR Author: $AUTHOR"
|
||||
|
||||
# Allowlist check
|
||||
if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
|
||||
echo "✅ Authorized user ($AUTHOR). Proceeding."
|
||||
elif [[ "${{ github.event_name }}" == "pull_request_target" && "${{ github.event.pull_request.state }}" == "open" ]]; then
|
||||
echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
|
||||
else
|
||||
echo "Authorized user ($AUTHOR). Proceeding."
|
||||
echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo "✅ No workflow file changes detected. Proceeding."
|
||||
fi
|
||||
|
||||
- name: Set up Python 3.11 for CLI
|
||||
|
|
|
|||
23
.github/workflows/end-to-end-test-coverage.yaml
vendored
23
.github/workflows/end-to-end-test-coverage.yaml
vendored
|
|
@ -22,18 +22,25 @@ jobs:
|
|||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Validate PR
|
||||
run: |
|
||||
# Checking for any workflow changes for security risks
|
||||
if git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -q "^.github/workflows/"; then
|
||||
echo "Workflow changes detected."
|
||||
# Check for any workflow changes
|
||||
if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.sha }}" | grep -q "^.github/workflows/"; then
|
||||
echo "⚠️ Workflow changes detected."
|
||||
|
||||
# Check if the PR author is allowed
|
||||
# Get the PR author
|
||||
AUTHOR="${{ github.event.pull_request.user.login }}"
|
||||
if [[ "$AUTHOR" != "misrasaurabh1" && "$AUTHOR" != "KRRT7" ]]; then
|
||||
echo "Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
|
||||
exit 1
|
||||
echo "PR Author: $AUTHOR"
|
||||
|
||||
# Allowlist check
|
||||
if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
|
||||
echo "✅ Authorized user ($AUTHOR). Proceeding."
|
||||
elif [[ "${{ github.event_name }}" == "pull_request_target" && "${{ github.event.pull_request.state }}" == "open" ]]; then
|
||||
echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
|
||||
else
|
||||
echo "Authorized user ($AUTHOR). Proceeding."
|
||||
echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo "✅ No workflow file changes detected. Proceeding."
|
||||
fi
|
||||
|
||||
- name: Set up Python 3.11 for CLI
|
||||
|
|
|
|||
|
|
@ -24,18 +24,25 @@ jobs:
|
|||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Validate PR
|
||||
run: |
|
||||
# Checking for any workflow changes for security risks
|
||||
if git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -q "^.github/workflows/"; then
|
||||
echo "Workflow changes detected."
|
||||
# Check for any workflow changes
|
||||
if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.sha }}" | grep -q "^.github/workflows/"; then
|
||||
echo "⚠️ Workflow changes detected."
|
||||
|
||||
# Check if the PR author is allowed
|
||||
# Get the PR author
|
||||
AUTHOR="${{ github.event.pull_request.user.login }}"
|
||||
if [[ "$AUTHOR" != "misrasaurabh1" && "$AUTHOR" != "KRRT7" ]]; then
|
||||
echo "Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
|
||||
exit 1
|
||||
echo "PR Author: $AUTHOR"
|
||||
|
||||
# Allowlist check
|
||||
if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
|
||||
echo "✅ Authorized user ($AUTHOR). Proceeding."
|
||||
elif [[ "${{ github.event_name }}" == "pull_request_target" && "${{ github.event.pull_request.state }}" == "open" ]]; then
|
||||
echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
|
||||
else
|
||||
echo "Authorized user ($AUTHOR). Proceeding."
|
||||
echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo "✅ No workflow file changes detected. Proceeding."
|
||||
fi
|
||||
|
||||
- name: Set up Python 3.11 for CLI
|
||||
|
|
|
|||
|
|
@ -24,18 +24,25 @@ jobs:
|
|||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Validate PR
|
||||
run: |
|
||||
# Checking for any workflow changes for security risks
|
||||
if git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -q "^.github/workflows/"; then
|
||||
echo "Workflow changes detected."
|
||||
# Check for any workflow changes
|
||||
if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.sha }}" | grep -q "^.github/workflows/"; then
|
||||
echo "⚠️ Workflow changes detected."
|
||||
|
||||
# Check if the PR author is allowed
|
||||
# Get the PR author
|
||||
AUTHOR="${{ github.event.pull_request.user.login }}"
|
||||
if [[ "$AUTHOR" != "misrasaurabh1" && "$AUTHOR" != "KRRT7" ]]; then
|
||||
echo "Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
|
||||
exit 1
|
||||
echo "PR Author: $AUTHOR"
|
||||
|
||||
# Allowlist check
|
||||
if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
|
||||
echo "✅ Authorized user ($AUTHOR). Proceeding."
|
||||
elif [[ "${{ github.event_name }}" == "pull_request_target" && "${{ github.event.pull_request.state }}" == "open" ]]; then
|
||||
echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
|
||||
else
|
||||
echo "Authorized user ($AUTHOR). Proceeding."
|
||||
echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo "✅ No workflow file changes detected. Proceeding."
|
||||
fi
|
||||
|
||||
- name: Set up Python 3.11 for CLI
|
||||
|
|
|
|||
|
|
@ -24,18 +24,25 @@ jobs:
|
|||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Validate PR
|
||||
run: |
|
||||
# Checking for any workflow changes for security risks
|
||||
if git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -q "^.github/workflows/"; then
|
||||
echo "Workflow changes detected."
|
||||
# Check for any workflow changes
|
||||
if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.sha }}" | grep -q "^.github/workflows/"; then
|
||||
echo "⚠️ Workflow changes detected."
|
||||
|
||||
# Check if the PR author is allowed
|
||||
# Get the PR author
|
||||
AUTHOR="${{ github.event.pull_request.user.login }}"
|
||||
if [[ "$AUTHOR" != "misrasaurabh1" && "$AUTHOR" != "KRRT7" ]]; then
|
||||
echo "Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
|
||||
exit 1
|
||||
echo "PR Author: $AUTHOR"
|
||||
|
||||
# Allowlist check
|
||||
if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
|
||||
echo "✅ Authorized user ($AUTHOR). Proceeding."
|
||||
elif [[ "${{ github.event_name }}" == "pull_request_target" && "${{ github.event.pull_request.state }}" == "open" ]]; then
|
||||
echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
|
||||
else
|
||||
echo "Authorized user ($AUTHOR). Proceeding."
|
||||
echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo "✅ No workflow file changes detected. Proceeding."
|
||||
fi
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -22,23 +22,26 @@ jobs:
|
|||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Validate PR
|
||||
run: |
|
||||
# Checking for any workflow changes for security risks
|
||||
if git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -q "^.github/workflows/"; then
|
||||
echo "Workflow changes detected."
|
||||
|
||||
# Check if the PR author is allowed
|
||||
# Check for any workflow changes
|
||||
if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.sha }}" | grep -q "^.github/workflows/"; then
|
||||
echo "⚠️ Workflow changes detected."
|
||||
# Get the PR author
|
||||
AUTHOR="${{ github.event.pull_request.user.login }}"
|
||||
if [[ "$AUTHOR" != "misrasaurabh1" && "$AUTHOR" != "KRRT7" ]]; then
|
||||
echo "Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
|
||||
exit 1
|
||||
echo "PR Author: $AUTHOR"
|
||||
# Allowlist check
|
||||
if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
|
||||
echo "✅ Authorized user ($AUTHOR). Proceeding."
|
||||
elif [[ "${{ github.event_name }}" == "pull_request_target" && "${{ github.event.pull_request.state }}" == "open" ]]; then
|
||||
echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
|
||||
else
|
||||
echo "Authorized user ($AUTHOR). Proceeding."
|
||||
echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
|
||||
exit 1
|
||||
fi
|
||||
fiif git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -q "end-to-end-topological-sort-test.yaml"; then
|
||||
echo "This workflow file has been modified. Exiting for security."
|
||||
exit 1
|
||||
else
|
||||
echo "✅ No workflow file changes detected. Proceeding."
|
||||
fi
|
||||
|
||||
- name: Set up Python 3.11 for CLI
|
||||
|
|
|
|||
55
.github/workflows/label-workflow-changes.yml
vendored
Normal file
55
.github/workflows/label-workflow-changes.yml
vendored
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
name: PR Labeler
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- ".github/workflows/**"
|
||||
types: [opened, synchronize, reopened]
|
||||
|
||||
jobs:
|
||||
label-workflow-changes:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Label PR with workflow changes
|
||||
uses: actions/github-script@v6
|
||||
with:
|
||||
script: |
|
||||
const labelName = 'workflow-modified';
|
||||
|
||||
// Check if the label exists
|
||||
try {
|
||||
const labels = await github.rest.issues.listLabelsForRepo({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo
|
||||
});
|
||||
|
||||
const labelExists = labels.data.some(label => label.name === labelName);
|
||||
|
||||
if (!labelExists) {
|
||||
// Create the label if it doesn't exist
|
||||
await github.rest.issues.createLabel({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
name: labelName,
|
||||
color: 'f9d0c4',
|
||||
description: 'This PR modifies GitHub Actions workflows'
|
||||
});
|
||||
console.log(`Label "${labelName}" created`);
|
||||
} else {
|
||||
console.log(`Label "${labelName}" already exists`);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`Failed to check or create label: ${error.message}`);
|
||||
throw error;
|
||||
}
|
||||
|
||||
// Add the label to the PR
|
||||
await github.rest.issues.addLabels({
|
||||
issue_number: context.issue.number,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
labels: [labelName]
|
||||
});
|
||||
console.log(`Label "${labelName}" added to the PR.`);
|
||||
|
|
@ -83,7 +83,7 @@ def parse_config_file(
|
|||
else: # Default to empty list
|
||||
config[key] = []
|
||||
|
||||
assert config["test-framework"] in ["pytest", "unittest"], (
|
||||
assert config["test-framework"] in {"pytest", "unittest"}, (
|
||||
"In pyproject.toml, Codeflash only supports the 'test-framework' as pytest and unittest."
|
||||
)
|
||||
if len(config["formatter-cmds"]) > 0:
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ def _pipe_segment_with_colons(align, colwidth):
|
|||
"""Return a segment of a horizontal line with optional colons which
|
||||
indicate column's alignment (as in `pipe` output format)."""
|
||||
w = colwidth
|
||||
if align in ["right", "decimal"]:
|
||||
if align in {"right", "decimal"}:
|
||||
return ("-" * (w - 1)) + ":"
|
||||
elif align == "center":
|
||||
return ":" + ("-" * (w - 2)) + ":"
|
||||
|
|
@ -176,7 +176,7 @@ def _isconvertible(conv, string):
|
|||
def _isnumber(string):
|
||||
return (
|
||||
# fast path
|
||||
type(string) in (float, int)
|
||||
type(string) in {float, int}
|
||||
# covers 'NaN', +/- 'inf', and eg. '1e2', as well as any type
|
||||
# convertible to int/float.
|
||||
or (
|
||||
|
|
@ -188,7 +188,7 @@ def _isnumber(string):
|
|||
# just an over/underflow
|
||||
or (
|
||||
not (math.isinf(float(string)) or math.isnan(float(string)))
|
||||
or string.lower() in ["inf", "-inf", "nan"]
|
||||
or string.lower() in {"inf", "-inf", "nan"}
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
@ -210,7 +210,7 @@ def _isint(string, inttype=int):
|
|||
|
||||
def _isbool(string):
|
||||
return type(string) is bool or (
|
||||
isinstance(string, (bytes, str)) and string in ("True", "False")
|
||||
isinstance(string, (bytes, str)) and string in {"True", "False"}
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -570,7 +570,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"):
|
|||
# values is a property, has .index => it's likely a pandas.DataFrame (pandas 0.11.0)
|
||||
keys = list(tabular_data)
|
||||
if (
|
||||
showindex in ["default", "always", True]
|
||||
showindex in {"default", "always", True}
|
||||
and tabular_data.index.name is not None
|
||||
):
|
||||
if isinstance(tabular_data.index.name, list):
|
||||
|
|
@ -686,7 +686,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"):
|
|||
rows = list(map(lambda r: r if _is_separating_line(r) else list(r), rows))
|
||||
|
||||
# add or remove an index column
|
||||
showindex_is_a_str = type(showindex) in [str, bytes]
|
||||
showindex_is_a_str = type(showindex) in {str, bytes}
|
||||
if showindex == "never" or (not _bool(showindex) and not showindex_is_a_str):
|
||||
pass
|
||||
|
||||
|
|
@ -820,7 +820,7 @@ def tabulate(
|
|||
if colglobalalign is not None: # if global alignment provided
|
||||
aligns = [colglobalalign] * len(cols)
|
||||
else: # default
|
||||
aligns = [numalign if ct in [int, float] else stralign for ct in coltypes]
|
||||
aligns = [numalign if ct in {int, float} else stralign for ct in coltypes]
|
||||
# then specific alignments
|
||||
if colalign is not None:
|
||||
assert isinstance(colalign, Iterable)
|
||||
|
|
@ -1044,4 +1044,4 @@ def _format_table(
|
|||
output = "\n".join(lines)
|
||||
return output
|
||||
else: # a completely empty table
|
||||
return ""
|
||||
return ""
|
||||
|
|
|
|||
|
|
@ -16,13 +16,13 @@ def humanize_runtime(time_in_ns: int) -> str:
|
|||
|
||||
units = re.split(r",|\s", runtime_human)[1]
|
||||
|
||||
if units in ("microseconds", "microsecond"):
|
||||
if units in {"microseconds", "microsecond"}:
|
||||
runtime_human = f"{time_micro:.3g}"
|
||||
elif units in ("milliseconds", "millisecond"):
|
||||
elif units in {"milliseconds", "millisecond"}:
|
||||
runtime_human = "%.3g" % (time_micro / 1000)
|
||||
elif units in ("seconds", "second"):
|
||||
elif units in {"seconds", "second"}:
|
||||
runtime_human = "%.3g" % (time_micro / (1000**2))
|
||||
elif units in ("minutes", "minute"):
|
||||
elif units in {"minutes", "minute"}:
|
||||
runtime_human = "%.3g" % (time_micro / (60 * 1000**2))
|
||||
else: # hours
|
||||
runtime_human = "%.3g" % (time_micro / (3600 * 1000**2))
|
||||
|
|
|
|||
|
|
@ -830,7 +830,7 @@ class FunctionOptimizer:
|
|||
line_profile_results = {"timings": {}, "unit": 0, "str_out": ""}
|
||||
# For the original function - run the tests and get the runtime, plus coverage
|
||||
with progress_bar(f"Establishing original code baseline for {self.function_to_optimize.function_name}"):
|
||||
assert (test_framework := self.args.test_framework) in ["pytest", "unittest"]
|
||||
assert (test_framework := self.args.test_framework) in {"pytest", "unittest"}
|
||||
success = True
|
||||
|
||||
test_env = os.environ.copy()
|
||||
|
|
@ -981,7 +981,8 @@ class FunctionOptimizer:
|
|||
original_helper_code: dict[Path, str],
|
||||
file_path_to_helper_classes: dict[Path, set[str]],
|
||||
) -> Result[OptimizedCandidateResult, str]:
|
||||
assert (test_framework := self.args.test_framework) in ["pytest", "unittest"]
|
||||
assert (test_framework := self.args.test_framework) in {"pytest", "unittest"}
|
||||
|
||||
with progress_bar("Testing optimization candidate"):
|
||||
test_env = os.environ.copy()
|
||||
test_env["CODEFLASH_LOOP_INDEX"] = "0"
|
||||
|
|
@ -1159,7 +1160,7 @@ class FunctionOptimizer:
|
|||
f"stdout: {run_result.stdout}\n"
|
||||
f"stderr: {run_result.stderr}\n"
|
||||
)
|
||||
if testing_type in [TestingMode.BEHAVIOR, TestingMode.PERFORMANCE]:
|
||||
if testing_type in {TestingMode.BEHAVIOR, TestingMode.PERFORMANCE}:
|
||||
results, coverage_results = parse_test_results(
|
||||
test_xml_path=result_file_path,
|
||||
test_files=test_files,
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from codeflash.cli_cmds.console import logger
|
|||
class ProfileStats(pstats.Stats):
|
||||
def __init__(self, trace_file_path: str, time_unit: str = "ns") -> None:
|
||||
assert Path(trace_file_path).is_file(), f"Trace file {trace_file_path} does not exist"
|
||||
assert time_unit in ["ns", "us", "ms", "s"], f"Invalid time unit {time_unit}"
|
||||
assert time_unit in {"ns", "us", "ms", "s"}, f"Invalid time unit {time_unit}"
|
||||
self.trace_file_path = trace_file_path
|
||||
self.time_unit = time_unit
|
||||
logger.debug(hasattr(self, "create_stats"))
|
||||
|
|
@ -59,10 +59,10 @@ class ProfileStats(pstats.Stats):
|
|||
time_unit = {"ns": "nanoseconds", "us": "microseconds", "ms": "milliseconds", "s": "seconds"}[self.time_unit]
|
||||
print(f"in {self.total_tt:.3f} {time_unit}", file=self.stream)
|
||||
print(file=self.stream)
|
||||
width, list = self.get_print_list(amount)
|
||||
if list:
|
||||
width, list_ = self.get_print_list(amount)
|
||||
if list_:
|
||||
self.print_title()
|
||||
for func in list:
|
||||
for func in list_:
|
||||
self.print_line(func)
|
||||
print(file=self.stream)
|
||||
print(file=self.stream)
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ def get_function_alias(module: str, function_name: str) -> str:
|
|||
def create_trace_replay_test(
|
||||
trace_file: str, functions: list[FunctionModules], test_framework: str = "pytest", max_run_count=100
|
||||
) -> str:
|
||||
assert test_framework in ["pytest", "unittest"]
|
||||
assert test_framework in {"pytest", "unittest"}
|
||||
|
||||
imports = f"""import dill as pickle
|
||||
{"import unittest" if test_framework == "unittest" else ""}
|
||||
|
|
|
|||
|
|
@ -233,7 +233,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
|
|||
new_keys = {k: v for k, v in new.__dict__.items() if k != "parent"}
|
||||
return comparator(orig_keys, new_keys, superset_obj)
|
||||
|
||||
if type(orig) in [types.BuiltinFunctionType, types.BuiltinMethodType]:
|
||||
if type(orig) in {types.BuiltinFunctionType, types.BuiltinMethodType}:
|
||||
return new == orig
|
||||
if str(type(orig)) == "<class 'object'>":
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -40,26 +40,29 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
|
|||
superset_obj = False
|
||||
if original_test_result.verification_type and (
|
||||
original_test_result.verification_type
|
||||
in (VerificationType.INIT_STATE_HELPER, VerificationType.INIT_STATE_FTO)
|
||||
in {VerificationType.INIT_STATE_HELPER, VerificationType.INIT_STATE_FTO}
|
||||
):
|
||||
superset_obj = True
|
||||
if not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj):
|
||||
are_equal = False
|
||||
logger.debug(
|
||||
"File Name: %s\n"
|
||||
"Test Type: %s\n"
|
||||
"Verification Type: %s\n"
|
||||
"Invocation ID: %s\n"
|
||||
"Original return value: %s\n"
|
||||
"Candidate return value: %s\n"
|
||||
"-------------------",
|
||||
original_test_result.file_name,
|
||||
original_test_result.test_type,
|
||||
original_test_result.verification_type,
|
||||
original_test_result.id,
|
||||
original_test_result.return_value,
|
||||
cdd_test_result.return_value,
|
||||
)
|
||||
try:
|
||||
logger.debug(
|
||||
"File Name: %s\n"
|
||||
"Test Type: %s\n"
|
||||
"Verification Type: %s\n"
|
||||
"Invocation ID: %s\n"
|
||||
"Original return value: %s\n"
|
||||
"Candidate return value: %s\n"
|
||||
"-------------------",
|
||||
original_test_result.file_name,
|
||||
original_test_result.test_type,
|
||||
original_test_result.verification_type,
|
||||
original_test_result.id,
|
||||
original_test_result.return_value,
|
||||
cdd_test_result.return_value,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
break
|
||||
if (original_test_result.stdout and cdd_test_result.stdout) and not comparator(
|
||||
original_test_result.stdout, cdd_test_result.stdout
|
||||
|
|
@ -67,7 +70,7 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
|
|||
are_equal = False
|
||||
break
|
||||
|
||||
if original_test_result.test_type in [TestType.EXISTING_UNIT_TEST, TestType.CONCOLIC_COVERAGE_TEST] and (
|
||||
if original_test_result.test_type in {TestType.EXISTING_UNIT_TEST, TestType.CONCOLIC_COVERAGE_TEST} and (
|
||||
cdd_test_result.did_pass != original_test_result.did_pass
|
||||
):
|
||||
are_equal = False
|
||||
|
|
|
|||
|
|
@ -16,12 +16,12 @@ def show_func(filename, start_lineno, func_name, timings, unit):
|
|||
return ''
|
||||
scalar = 1
|
||||
if os.path.exists(filename):
|
||||
out_table+=f'## Function: {func_name}\n'
|
||||
out_table += f'## Function: {func_name}\n'
|
||||
# Clear the cache to ensure that we get up-to-date results.
|
||||
linecache.clearcache()
|
||||
all_lines = linecache.getlines(filename)
|
||||
sublines = inspect.getblock(all_lines[start_lineno - 1:])
|
||||
out_table+='## Total time: %g s\n' % (total_time * unit)
|
||||
out_table += '## Total time: %g s\n' % (total_time * unit)
|
||||
# Define minimum column sizes so text fits and usually looks consistent
|
||||
default_column_sizes = {
|
||||
'hits': 9,
|
||||
|
|
@ -57,7 +57,7 @@ def show_func(filename, start_lineno, func_name, timings, unit):
|
|||
if 'def' in line_ or nhits!='':
|
||||
table_rows.append((nhits, time, per_hit, percent, line_))
|
||||
pass
|
||||
out_table+= tabulate(headers=table_cols,tabular_data=table_rows,tablefmt="pipe",colglobalalign=None, preserve_whitespace=True)
|
||||
out_table += tabulate(headers=table_cols,tabular_data=table_rows,tablefmt="pipe",colglobalalign=None, preserve_whitespace=True)
|
||||
out_table+='\n'
|
||||
return out_table
|
||||
|
||||
|
|
@ -65,12 +65,12 @@ def show_text(stats: dict) -> str:
|
|||
""" Show text for the given timings.
|
||||
"""
|
||||
out_table = ""
|
||||
out_table+='# Timer unit: %g s\n' % stats['unit']
|
||||
out_table += '# Timer unit: %g s\n' % stats['unit']
|
||||
stats_order = sorted(stats['timings'].items())
|
||||
# Show detailed per-line information for each function.
|
||||
for (fn, lineno, name), timings in stats_order:
|
||||
table_md =show_func(fn, lineno, name, stats['timings'][fn, lineno, name], stats['unit'])
|
||||
out_table+=table_md
|
||||
table_md = show_func(fn, lineno, name, stats['timings'][fn, lineno, name], stats['unit'])
|
||||
out_table += table_md
|
||||
return out_table
|
||||
|
||||
def parse_line_profile_results(line_profiler_output_file: Optional[Path]) -> dict:
|
||||
|
|
@ -83,6 +83,6 @@ def parse_line_profile_results(line_profiler_output_file: Optional[Path]) -> dic
|
|||
stats = pickle.load(f)
|
||||
stats_dict['timings'] = stats.timings
|
||||
stats_dict['unit'] = stats.unit
|
||||
str_out=show_text(stats_dict)
|
||||
stats_dict['str_out']=str_out
|
||||
return stats_dict, None
|
||||
str_out = show_text(stats_dict)
|
||||
stats_dict['str_out'] = str_out
|
||||
return stats_dict, None
|
||||
|
|
|
|||
|
|
@ -127,7 +127,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
|
|||
iteration_id = val[5]
|
||||
runtime = val[6]
|
||||
verification_type = val[8]
|
||||
if verification_type in (VerificationType.INIT_STATE_FTO, VerificationType.INIT_STATE_HELPER):
|
||||
if verification_type in {VerificationType.INIT_STATE_FTO, VerificationType.INIT_STATE_HELPER}:
|
||||
test_type = TestType.INIT_STATE_TEST
|
||||
else:
|
||||
# TODO : this is because sqlite writes original file module path. Should make it consistent
|
||||
|
|
|
|||
|
|
@ -166,7 +166,7 @@ def run_line_profile_tests(
|
|||
)
|
||||
test_files: list[str] = []
|
||||
for file in test_paths.test_files:
|
||||
if file.test_type in [TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST] and file.tests_in_file:
|
||||
if file.test_type in {TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST} and file.tests_in_file:
|
||||
test_files.extend(
|
||||
[
|
||||
str(file.benchmarking_file_path)
|
||||
|
|
@ -226,7 +226,7 @@ def run_benchmarking_tests(
|
|||
)
|
||||
test_files: list[str] = []
|
||||
for file in test_paths.test_files:
|
||||
if file.test_type in [TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST] and file.tests_in_file:
|
||||
if file.test_type in {TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST} and file.tests_in_file:
|
||||
test_files.extend(
|
||||
[
|
||||
str(file.benchmarking_file_path)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
|
|||
|
||||
|
||||
def get_test_file_path(test_dir: Path, function_name: str, iteration: int = 0, test_type: str = "unit") -> Path:
|
||||
assert test_type in ["unit", "inspired", "replay", "perf"]
|
||||
assert test_type in {"unit", "inspired", "replay", "perf"}
|
||||
function_name = function_name.replace(".", "_")
|
||||
path = test_dir / f"test_{function_name}__{test_type}_test_{iteration}.py"
|
||||
if path.exists():
|
||||
|
|
|
|||
Loading…
Reference in a new issue