Ruff reformat and fix all the python files
Set the minimum libcst version to 1.0.1; move the stub files to dev dependencies
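The pyproject.toml diff itself is not part of this excerpt. As a minimal sketch of what the description implies, assuming a Poetry-style layout (the group name and the stub package name are illustrative stand-ins, not taken from the repository):

[tool.poetry.dependencies]
libcst = ">=1.0.1"  # floor raised to 1.0.1 per the commit message

[tool.poetry.group.dev.dependencies]
# Stub-only packages live in the dev group now; "types-example" is a
# hypothetical stand-in for whichever stub distributions were moved.
types-example = "*"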
parent 0a06160a57 · commit b42c270f9a
114 changed files with 1085 additions and 3157 deletions
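For context, a reformat-and-fix pass like the one below is typically produced by running `ruff check --fix` followed by `ruff format`. A sketch of the kind of configuration that would yield these hunks might look like the following; the values are inferred, not taken from the repository (the long collapsed lines suggest a line length around 120, and the `elif`-after-`return` rewrites match ruff's RET505 rule):

[tool.ruff]
line-length = 120  # inferred from the ~120-column collapsed lines below

[tool.ruff.lint]
# "RET" includes RET505, which removes a superfluous `else`/`elif` after `return`;
# "F" includes F401/F541, matching the unused-import and f-string fixes below.
select = ["E", "F", "RET"]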
@@ -2,15 +2,7 @@ from __future__ import annotations
 from sqlalchemy import ForeignKey, Integer, String, create_engine
 from sqlalchemy.engine.base import Engine
-from sqlalchemy.orm import (
-    DeclarativeBase,
-    Mapped,
-    Relationship,
-    Session,
-    mapped_column,
-    relationship,
-    sessionmaker,
-)
+from sqlalchemy.orm import DeclarativeBase, Mapped, Relationship, Session, mapped_column, relationship, sessionmaker


 # Custom base class
@@ -18,24 +10,24 @@ class Base(DeclarativeBase):
     pass


-engine: Engine = create_engine('sqlite:///example.db')
+engine: Engine = create_engine("sqlite:///example.db")

 session_factory = sessionmaker(bind=engine)
 session: Session = session_factory()


 class User(Base):
-    __tablename__: str = 'users'
+    __tablename__: str = "users"
     id: Mapped[int] = mapped_column(Integer, primary_key=True)
     name: Mapped[str] = mapped_column(String)
     posts: Relationship[list[Post]] = relationship("Post", order_by="Post.id", back_populates="user")


 class Post(Base):
-    __tablename__: str = 'posts'
+    __tablename__: str = "posts"
     id: Mapped[int] = mapped_column(Integer, primary_key=True)
     title: Mapped[str] = mapped_column(String)
-    user_id: Mapped[int] = mapped_column(Integer, ForeignKey('users.id'))
+    user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"))
     user: Relationship[User] = relationship("User", back_populates="posts")
@@ -6,8 +6,10 @@ from sqlalchemy.engine import Engine, create_engine
 from sqlalchemy.orm import DeclarativeBase, Session, relationship, sessionmaker
 from sqlalchemy.orm.relationships import Relationship

-POSTGRES_CONNECTION_STRING: str = ("postgresql://cf_developer:XJcbU37MBYeh4dDK6PTV5n@sqlalchemy-experiments.postgres"
-                                   ".database.azure.com:5432/postgres")
+POSTGRES_CONNECTION_STRING: str = (
+    "postgresql://cf_developer:XJcbU37MBYeh4dDK6PTV5n@sqlalchemy-experiments.postgres"
+    ".database.azure.com:5432/postgres"
+)


 class Base(DeclarativeBase):
@@ -52,10 +54,8 @@ def get_authors(books: list[Book]) -> list[Author]:
     book: Book
     for book in books:
         _authors.append(book.author)
-    return sorted(
-        list(set(_authors)),
-        key=lambda x: x.id,
-    )
+    return sorted(list(set(_authors)), key=lambda x: x.id)


 def get_authors2(num_authors) -> list[Author]:
     engine: Engine = create_engine(POSTGRES_CONNECTION_STRING, echo=True)
@@ -66,10 +66,7 @@ def get_authors2(num_authors) -> list[Author]:
     book: Book
     for book in books:
         _authors.append(book.author)
-    return sorted(
-        list(set(_authors)),
-        key=lambda x: x.id,
-    )[:num_authors]
+    return sorted(list(set(_authors)), key=lambda x: x.id)[:num_authors]


 def get_top_author(authors: List[Author]) -> Author:
@@ -84,9 +81,7 @@ def get_top_author(authors: List[Author]) -> Author:
     # Step 2: Iterate over each author to count their bestsellers
     for author in authors:
         bestseller_count = (
-            session.query(func.count(Book.id))
-            .filter(Book.author_id == author.id, Book.is_bestseller == True)
-            .scalar()
+            session.query(func.count(Book.id)).filter(Book.author_id == author.id, Book.is_bestseller == True).scalar()
         )

         # Step 3: Update the author with the maximum bestsellers
@@ -5,19 +5,7 @@ from typing import Any, cast
 from _typeshed import SupportsDunderGT, SupportsDunderLT
 from sqlalchemy.orm import Session

-from code_to_optimize.book_catalog import (
-    POSTGRES_CONNECTION_STRING,
-    Author,
-    Base,
-    Book,
-    _session,
-    _t,
-    authors,
-    authors_name,
-    engine,
-    init_table,
-    session_factory,
-)
+from code_to_optimize.book_catalog import Author, Book


 def get_authors(session: Session) -> list[Author]:
@@ -26,7 +14,4 @@ def get_authors(session: Session) -> list[Author]:
     book: Book
     for book in books:
         _authors.append(book.author)
-    return sorted(
-        list(set(_authors)),
-        key=lambda x: cast(SupportsDunderLT[Any] | SupportsDunderGT[Any], x.id),
-    )
+    return sorted(list(set(_authors)), key=lambda x: cast(SupportsDunderLT[Any] | SupportsDunderGT[Any], x.id))
@@ -1,18 +1,6 @@
 from __future__ import annotations

-from code_to_optimize.book_catalog import (
-    POSTGRES_CONNECTION_STRING,
-    Author,
-    Base,
-    Book,
-    _session,
-    _t,
-    authors,
-    authors_name,
-    engine,
-    init_table,
-    session_factory,
-)
+from code_to_optimize.book_catalog import Book


 def get_authors(session):
@@ -9,7 +9,7 @@ class BubbleSortClass:
     def sorter(self, arr):
         n = len(arr)
         for i in range(n):
-            for j in range(0, n - i - 1):
+            for j in range(n - i - 1):
                 if arr[j] > arr[j + 1]:
                     arr[j], arr[j + 1] = arr[j + 1], arr[j]
         return arr
@@ -1,4 +1,4 @@
-from __future__ import annotations
+from __future__ import annotations

 from typing import Any, Callable, Iterable, NewType, Optional, Protocol, TypeVar
@@ -84,16 +84,11 @@ def is_new_type2(type_: type[Any]) -> bool:


 def _to_str(
-    size: int,
-    suffixes: Iterable[str],
-    base: int,
-    *,
-    precision: Optional[int] = 1,
-    separator: Optional[str] = " ",
+    size: int, suffixes: Iterable[str], base: int, *, precision: Optional[int] = 1, separator: Optional[str] = " "
 ) -> str:
     if size == 1:
         return "1 byte"
-    elif size < base:
+    if size < base:
         return f"{size:,} bytes"

     for i, suffix in enumerate(suffixes, 2):  # noqa: B007
@@ -101,10 +96,7 @@ def _to_str(
         if size < unit:
             break
     return "{:,.{precision}f}{separator}{}".format(
-        (base * size / unit),
-        suffix,
-        precision=precision,
-        separator=separator,
+        (base * size / unit), suffix, precision=precision, separator=separator
     )
@@ -114,16 +106,11 @@ def _to_str(


 def _to_str2(
-    size: int,
-    suffixes: Iterable[str],
-    base: int,
-    *,
-    precision: Optional[int] = 1,
-    separator: Optional[str] = " ",
+    size: int, suffixes: Iterable[str], base: int, *, precision: Optional[int] = 1, separator: Optional[str] = " "
 ) -> str:
     if size == 1:
         return "1 byte"
-    elif size < base:
+    if size < base:
         return f"{size:,} bytes"

     unit = base
@@ -376,9 +363,7 @@ def with_pattern(pattern: str, regex_group_count: int | None = None) -> Callable

 def with_pattern2(pattern: str, regex_group_count: int | None = None) -> Callable:
     return (
-        lambda func: setattr(func, "pattern", pattern)
-        or setattr(func, "regex_group_count", regex_group_count)
-        or func
+        lambda func: setattr(func, "pattern", pattern) or setattr(func, "regex_group_count", regex_group_count) or func
     )
@@ -4,10 +4,7 @@ from typing import List, Set, Tuple


 def compare_lists(
-    li1: List[int],
-    li2: List[int],
-    value_func1=None,
-    value_func2=None,
+    li1: List[int], li2: List[int], value_func1=None, value_func2=None
 ) -> Tuple[Set[int], Set[int], Set[int]]:
     """Compare *li1* and *li2*, return the results as a list in the following form:
@@ -2,14 +2,13 @@ def translate(word):
     vowels = "aeiou"
     if word[0] in vowels:
         return word + "way"
-    else:
-        consonants = ""
-        for letter in word:
-            if letter not in vowels:
-                consonants += letter
-            else:
-                break
-        return word[len(consonants) :] + consonants + "ay"
+    consonants = ""
+    for letter in word:
+        if letter not in vowels:
+            consonants += letter
+        else:
+            break
+    return word[len(consonants) :] + consonants + "ay"


 def pig_latin(text):
@@ -1,11 +1,8 @@
-def single_name_to_first_last_names(
-    name: str,
-) -> list[tuple[str, str]]:
+def single_name_to_first_last_names(name: str) -> list[tuple[str, str]]:
     parts = name.upper().split()
     if len(parts) == 2:
         return [tuple(parts)]
-    elif len(parts) == 3:
+    if len(parts) == 3:
         a, b, c = parts
         return [(a, c), (a, f"{b} {c}"), (f"{a} {b}", c)]
-    else:
-        return []
+    return []
@@ -1,6 +1,4 @@
-from code_to_optimize.final_test_set.encode_python_string_to_c import (
-    _encodePythonStringToC,
-)
+from code_to_optimize.final_test_set.encode_python_string_to_c import _encodePythonStringToC


 def test_empty_string():
@@ -25,9 +25,7 @@ def test_common_tags_1():
 def test_empty_article_list():
     articles = []
     expected = set()
-    assert (
-        find_common_tags(articles) == expected
-    ), "Test failed for empty list of articles."
+    assert find_common_tags(articles) == expected, "Test failed for empty list of articles."


 def test_no_common_tags():
@@ -37,9 +35,7 @@ def test_no_common_tags():
         {"tags": ["javascript", "development", "web"]},
     ]
     expected = set()
-    assert (
-        find_common_tags(articles) == expected
-    ), "Test failed when no tags are common."
+    assert find_common_tags(articles) == expected, "Test failed when no tags are common."


 def test_all_common_tags():
@@ -49,14 +45,10 @@ def test_all_common_tags():
         {"tags": ["tech", "startups", "innovation"]},
     ]
     expected = {"tech", "startups", "innovation"}
-    assert (
-        find_common_tags(articles) == expected
-    ), "Test failed when all tags are common."
+    assert find_common_tags(articles) == expected, "Test failed when all tags are common."


 def test_single_article():
     articles = [{"tags": ["single", "article", "test"]}]
     expected = {"single", "article", "test"}
-    assert (
-        find_common_tags(articles) == expected
-    ), "Test failed for a single article input."
+    assert find_common_tags(articles) == expected, "Test failed for a single article input."
@@ -2,11 +2,7 @@ from code_to_optimize.final_test_set.find_duplicates import find_duplicates


 def test_basic_case():
-    assert find_duplicates([1, 2, 3, 2, 1, 5, 6, 5]) == [
-        1,
-        2,
-        5,
-    ], "Failed on basic case"
+    assert find_duplicates([1, 2, 3, 2, 1, 5, 6, 5]) == [1, 2, 5], "Failed on basic case"


 def test_no_duplicates():
@@ -14,10 +10,7 @@ def test_no_duplicates():


 def test_multiple_duplicates():
-    assert find_duplicates([1, 2, 2, 3, 3, 3, 4]) == [
-        2,
-        3,
-    ], "Failed on multiple duplicates of the same item"
+    assert find_duplicates([1, 2, 2, 3, 3, 3, 4]) == [2, 3], "Failed on multiple duplicates of the same item"


 def test_empty_list():
@@ -29,7 +22,4 @@ def test_all_elements_same():


 def test_mixed_data_types():
-    assert find_duplicates(["apple", "banana", "apple", 42, 42]) == [
-        "apple",
-        42,
-    ], "Failed on mixed data types"
+    assert find_duplicates(["apple", "banana", "apple", 42, 42]) == ["apple", 42], "Failed on mixed data types"
@@ -17,28 +17,18 @@ def test_prime_number():


 def test_perfect_square():
-    assert find_factors(16) == [
-        (1, 16),
-        (2, 8),
-        (4, 4),
-        (8, 2),
-        (16, 1),
-    ], "Failed on perfect square number"
+    assert find_factors(16) == [(1, 16), (2, 8), (4, 4), (8, 2), (16, 1)], "Failed on perfect square number"


 def test_large_number():
     # 120 has factors: 1, 2, 3, 4, 5, 6, 8, 10, 12, 15, 20, 24, 30, 40, 60, 120
     result = find_factors(120)
     expected_factors = 16  # There should be 16 pairs
-    assert (
-        len(result) == expected_factors
-    ), "Failed on large number with multiple factors"
+    assert len(result) == expected_factors, "Failed on large number with multiple factors"


 def test_one():
-    assert find_factors(1) == [
-        (1, 1)
-    ], "Failed on one, which should only have one factor pair"
+    assert find_factors(1) == [(1, 1)], "Failed on one, which should only have one factor pair"


 def test_zero():
@@ -4,8 +4,7 @@ from code_to_optimize.final_test_set.integration import integrate_f


 def isclose(a, b, rel_tol=1e-5, abs_tol=0.0):
-    """
-    Helper function to compare two floating points for 'closeness'.
+    """Helper function to compare two floating points for 'closeness'.

     Uses a combination of relative and absolute tolerances.
     """
     return abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)
@@ -30,6 +29,4 @@ def test_with_pytest_approx():
     a, b, N = 0, 1, 1000
     result = integrate_f(a, b, N)
     expected = -1 / 6
-    assert result == pytest.approx(
-        expected, rel=1e-5
-    ), "Test failed with pytest's approx."
+    assert result == pytest.approx(expected, rel=1e-5), "Test failed with pytest's approx."
@@ -4,7 +4,7 @@ from code_to_optimize.final_test_set.pig_latin import pig_latin


 def log_test_values(values, test_name):
-    with open(f"/tmp/test_return_values.bin", "ab") as f:
+    with open("/tmp/test_return_values.bin", "ab") as f:
         return_bytes = pickle.dumps(values)
         _test_name = f"{test_name}".encode("ascii")
         f.write(len(_test_name).to_bytes(4, byteorder="big"))
@@ -25,12 +25,8 @@ def test_pig_latin_single_consonant():


 def test_pig_latin_multiple_consonants():
-    log_test_values(
-        pig_latin("string"), "pig_latin_test_pig_latin_multiple_consonants_0"
-    )
-    log_test_values(
-        pig_latin("glove"), "pig_latin_test_pig_latin_multiple_consonants_1"
-    )
+    log_test_values(pig_latin("string"), "pig_latin_test_pig_latin_multiple_consonants_0")
+    log_test_values(pig_latin("glove"), "pig_latin_test_pig_latin_multiple_consonants_1")


 def test_pig_latin_capital_letters():
@@ -39,13 +35,8 @@ def test_pig_latin_capital_letters():


 def test_pig_latin_multiple_words():
-    log_test_values(
-        pig_latin("The quick brown fox"), "pig_latin_test_pig_latin_multiple_words_0"
-    )
-    log_test_values(
-        pig_latin("Python is a fun language"),
-        "pig_latin_test_pig_latin_multiple_words_1",
-    )
+    log_test_values(pig_latin("The quick brown fox"), "pig_latin_test_pig_latin_multiple_words_0")
+    log_test_values(pig_latin("Python is a fun language"), "pig_latin_test_pig_latin_multiple_words_1")


 def test_pig_latin_empty_input():
@@ -58,9 +49,7 @@ def test_pig_latin_spaces_input():

 def test_pig_latin_non_alphabetic():
     log_test_values(pig_latin("123"), "pig_latin_test_pig_latin_non_alphabetic_0")
-    log_test_values(
-        pig_latin("Hello, world!"), "pig_latin_test_pig_latin_non_alphabetic_1"
-    )
+    log_test_values(pig_latin("Hello, world!"), "pig_latin_test_pig_latin_non_alphabetic_1")


 def test_pig_latin_non_ascii():
@@ -69,12 +58,8 @@ def test_pig_latin_non_ascii():


 def test_pig_latin_hyphenated_words():
-    log_test_values(
-        pig_latin("sister-in-law"), "pig_latin_test_pig_latin_hyphenated_words_0"
-    )
-    log_test_values(
-        pig_latin("self-driving car"), "pig_latin_test_pig_latin_hyphenated_words_1"
-    )
+    log_test_values(pig_latin("sister-in-law"), "pig_latin_test_pig_latin_hyphenated_words_0")
+    log_test_values(pig_latin("self-driving car"), "pig_latin_test_pig_latin_hyphenated_words_1")


 def test_pig_latin_contractions():
@@ -84,9 +69,7 @@ def test_pig_latin_contractions():

 def test_pig_latin_apostrophes():
     log_test_values(pig_latin("don't"), "pig_latin_test_pig_latin_apostrophes_0")
-    log_test_values(
-        pig_latin("rock 'n' roll"), "pig_latin_test_pig_latin_apostrophes_1"
-    )
+    log_test_values(pig_latin("rock 'n' roll"), "pig_latin_test_pig_latin_apostrophes_1")


 def test_pig_latin_non_letter():
@@ -1,6 +1,4 @@
-from code_to_optimize.final_test_set.single_name_to_first_last_names import (
-    single_name_to_first_last_names,
-)
+from code_to_optimize.final_test_set.single_name_to_first_last_names import single_name_to_first_last_names


 def test_two_part_name():
@@ -6,16 +6,10 @@ def test_concatenate_strings_zero():


 def test_concatenate_strings_positive():
-    assert (
-        concatenate_strings(5) == "0, 1, 2, 3, 4, "
-    ), "Failed: Incorrect string for input 5"
+    assert concatenate_strings(5) == "0, 1, 2, 3, 4, ", "Failed: Incorrect string for input 5"


 def test_concatenate_strings_large_number():
     result = concatenate_strings(1000)
-    expected_length = sum(
-        len(str(i)) + 2 for i in range(1000)
-    )  # Each number i + len(", ")
-    assert (
-        len(result) == expected_length
-    ), f"Failed: Incorrect length for large input 1000"
+    expected_length = sum(len(str(i)) + 2 for i in range(1000))  # Each number i + len(", ")
+    assert len(result) == expected_length, "Failed: Incorrect length for large input 1000"
@@ -13,9 +13,7 @@ def test_k_greater_than_array_length():
     array = [4, 1, 5, 6, 2]
     k = 10
     expected = sorted(array, reverse=True)
-    assert (
-        find_top_k_elements(array, k) == expected
-    ), "Failed when k is greater than array length"
+    assert find_top_k_elements(array, k) == expected, "Failed when k is greater than array length"


 def test_normal_case():
@@ -29,9 +27,7 @@ def test_array_with_duplicate_values():
     array = [5, 5, 5, 5]
     k = 2
     expected = [5, 5]
-    assert (
-        find_top_k_elements(array, k) == expected
-    ), "Failed when array contains duplicates"
+    assert find_top_k_elements(array, k) == expected, "Failed when array contains duplicates"


 def test_empty_array():
@@ -42,6 +38,4 @@ def test_single_element_array():
     array = [42]
     k = 1
     expected = [42]
-    assert (
-        find_top_k_elements(array, k) == expected
-    ), "Failed when array contains a single element"
+    assert find_top_k_elements(array, k) == expected, "Failed when array contains a single element"
@@ -33,10 +33,7 @@ def test_complex_dag():
     g.addEdge(2, 3)
     g.addEdge(3, 1)
     result = g.topologicalSort()
-    assert all(
-        result.index(u) < result.index(v)
-        for u, v in [(5, 2), (5, 0), (4, 0), (4, 1), (2, 3), (3, 1)]
-    )
+    assert all(result.index(u) < result.index(v) for u, v in [(5, 2), (5, 0), (4, 0), (4, 1), (2, 3), (3, 1)])


 def test_single_node_graph():
@@ -15,8 +15,7 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
     Y = np.array(Y)
     if X.shape[1] != Y.shape[1]:
         raise ValueError(
-            f"Number of columns in X and Y must be the same. X has shape {X.shape} "
-            f"and Y has shape {Y.shape}.",
+            f"Number of columns in X and Y must be the same. X has shape {X.shape} " f"and Y has shape {Y.shape}."
         )

     X_norm = np.linalg.norm(X, axis=1)
@@ -27,10 +26,7 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:


 def cosine_similarity_top_k(
-    X: Matrix,
-    Y: Matrix,
-    top_k: Optional[int] = 5,
-    score_threshold: Optional[float] = None,
+    X: Matrix, Y: Matrix, top_k: Optional[int] = 5, score_threshold: Optional[float] = None
 ) -> Tuple[List[Tuple[int, int]], List[float]]:
     """Row-wise cosine similarity with optional top-k and score threshold filtering.
@@ -2,14 +2,13 @@ def translate(word):
     vowels = "aeiou"
     if word[0] in vowels:
         return word + "way"
-    else:
-        consonants = ""
-        for letter in word:
-            if letter not in vowels:
-                consonants += letter
-            else:
-                break
-        return word[len(consonants) :] + consonants + "ay"
+    consonants = ""
+    for letter in word:
+        if letter not in vowels:
+            consonants += letter
+        else:
+            break
+    return word[len(consonants) :] + consonants + "ay"


 def pig_latin(text):
@@ -1,10 +1,10 @@
 from typing import Generator

 import pytest
-from sqlalchemy import Engine, create_engine, delete, update
+from sqlalchemy import Engine, create_engine
 from sqlalchemy.orm import Session, sessionmaker

-from code_to_optimize.book_catalog import Author, Book, get_authors
+from code_to_optimize.book_catalog import Book, get_authors

 POSTGRES_CONNECTION_STRING = (
     "postgresql://cf_developer:XJcbU37MBYeh4dDK6PTV5n@sqlalchemy-experiments.postgres"
@@ -1,10 +1,7 @@
 from typing import Generator

 import pytest
-from sqlalchemy import Engine, create_engine, delete, update
+from sqlalchemy import Engine, create_engine
 from sqlalchemy.orm import Session, sessionmaker

-from code_to_optimize.book_catalog import Author, Book, get_top_author
+from code_to_optimize.book_catalog import Author, get_top_author

 POSTGRES_CONNECTION_STRING = (
     "postgresql://cf_developer:XJcbU37MBYeh4dDK6PTV5n@sqlalchemy-experiments.postgres"
@@ -4,7 +4,7 @@ from code_to_optimize.pig_latin import pig_latin


 def log_test_values(values, test_name):
-    with open(f"/tmp/test_return_values.bin", "ab") as f:
+    with open("/tmp/test_return_values.bin", "ab") as f:
         return_bytes = pickle.dumps(values)
         _test_name = f"{test_name}".encode("ascii")
         f.write(len(_test_name).to_bytes(4, byteorder="big"))
@@ -25,12 +25,8 @@ def test_pig_latin_single_consonant():


 def test_pig_latin_multiple_consonants():
-    log_test_values(
-        pig_latin("string"), "pig_latin_test_pig_latin_multiple_consonants_0"
-    )
-    log_test_values(
-        pig_latin("glove"), "pig_latin_test_pig_latin_multiple_consonants_1"
-    )
+    log_test_values(pig_latin("string"), "pig_latin_test_pig_latin_multiple_consonants_0")
+    log_test_values(pig_latin("glove"), "pig_latin_test_pig_latin_multiple_consonants_1")


 def test_pig_latin_capital_letters():
@@ -39,13 +35,8 @@ def test_pig_latin_capital_letters():


 def test_pig_latin_multiple_words():
-    log_test_values(
-        pig_latin("The quick brown fox"), "pig_latin_test_pig_latin_multiple_words_0"
-    )
-    log_test_values(
-        pig_latin("Python is a fun language"),
-        "pig_latin_test_pig_latin_multiple_words_1",
-    )
+    log_test_values(pig_latin("The quick brown fox"), "pig_latin_test_pig_latin_multiple_words_0")
+    log_test_values(pig_latin("Python is a fun language"), "pig_latin_test_pig_latin_multiple_words_1")


 def test_pig_latin_empty_input():
@@ -58,9 +49,7 @@ def test_pig_latin_spaces_input():

 def test_pig_latin_non_alphabetic():
     log_test_values(pig_latin("123"), "pig_latin_test_pig_latin_non_alphabetic_0")
-    log_test_values(
-        pig_latin("Hello, world!"), "pig_latin_test_pig_latin_non_alphabetic_1"
-    )
+    log_test_values(pig_latin("Hello, world!"), "pig_latin_test_pig_latin_non_alphabetic_1")


 def test_pig_latin_non_ascii():
@@ -69,12 +58,8 @@ def test_pig_latin_non_ascii():


 def test_pig_latin_hyphenated_words():
-    log_test_values(
-        pig_latin("sister-in-law"), "pig_latin_test_pig_latin_hyphenated_words_0"
-    )
-    log_test_values(
-        pig_latin("self-driving car"), "pig_latin_test_pig_latin_hyphenated_words_1"
-    )
+    log_test_values(pig_latin("sister-in-law"), "pig_latin_test_pig_latin_hyphenated_words_0")
+    log_test_values(pig_latin("self-driving car"), "pig_latin_test_pig_latin_hyphenated_words_1")


 def test_pig_latin_contractions():
@@ -84,9 +69,7 @@ def test_pig_latin_contractions():

 def test_pig_latin_apostrophes():
     log_test_values(pig_latin("don't"), "pig_latin_test_pig_latin_apostrophes_0")
-    log_test_values(
-        pig_latin("rock 'n' roll"), "pig_latin_test_pig_latin_apostrophes_1"
-    )
+    log_test_values(pig_latin("rock 'n' roll"), "pig_latin_test_pig_latin_apostrophes_1")


 def test_pig_latin_non_letter():
@@ -4,10 +4,7 @@ from code_to_optimize.math_utils import Matrix, cosine_similarity_top_k


 def use_cosine_similarity(
-    X: Matrix,
-    Y: Matrix,
-    top_k: Optional[int] = 5,
-    score_threshold: Optional[float] = None,
+    X: Matrix, Y: Matrix, top_k: Optional[int] = 5, score_threshold: Optional[float] = None
 ) -> Tuple[List[Tuple[int, int]], List[float]]:
     return cosine_similarity_top_k(X, Y, top_k, score_threshold)
@@ -33,10 +33,7 @@ class OptimizedCandidate:
 class AiServiceClient:
     def __init__(self) -> None:
         self.base_url = self.get_aiservice_base_url()
-        self.headers = {
-            "Authorization": f"Bearer {get_codeflash_api_key()}",
-            "Connection": "close",
-        }
+        self.headers = {"Authorization": f"Bearer {get_codeflash_api_key()}", "Connection": "close"}

     def get_aiservice_base_url(self) -> str:
         if os.environ.get("CODEFLASH_AIS_SERVER", default="prod").lower() == "local":
@@ -45,11 +42,7 @@ class AiServiceClient:
         return "https://app.codeflash.ai"

     def make_ai_service_request(
-        self,
-        endpoint: str,
-        method: str = "POST",
-        payload: dict[str, Any] | None = None,
-        timeout: float | None = None,
+        self, endpoint: str, method: str = "POST", payload: dict[str, Any] | None = None, timeout: float | None = None
     ) -> requests.Response:
         """Make an API request to the given endpoint on the AI service.
@@ -98,11 +91,7 @@ class AiServiceClient:
         }
         logger.info("Generating optimized candidates ...")
         try:
-            response = self.make_ai_service_request(
-                "/optimize",
-                payload=payload,
-                timeout=600,
-            )
+            response = self.make_ai_service_request("/optimize", payload=payload, timeout=600)
         except requests.exceptions.RequestException as e:
             logger.exception(f"Error generating optimized candidates: {e}")
             ph("cli-optimize-error-caught", {"error": str(e)})
@@ -124,10 +113,7 @@ class AiServiceClient:
         except Exception:
             error = response.text
         logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
-        ph(
-            "cli-optimize-error-response",
-            {"response_status_code": response.status_code, "error": error},
-        )
+        ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
         return []

     def log_results(
@@ -225,17 +211,11 @@ class AiServiceClient:
         try:
             error = response.json()["error"]
             logger.error(f"Error generating tests: {response.status_code} - {error}")
-            ph(
-                "cli-testgen-error-response",
-                {"response_status_code": response.status_code, "error": error},
-            )
+            ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": error})
             return None
         except Exception:
             logger.error(f"Error generating tests: {response.status_code} - {response.text}")
-            ph(
-                "cli-testgen-error-response",
-                {"response_status_code": response.status_code, "error": response.text},
-            )
+            ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": response.text})
             return None
@@ -22,11 +22,7 @@ else:
     CFAPI_BASE_URL = "https://app.codeflash.ai"


-def make_cfapi_request(
-    endpoint: str,
-    method: str,
-    payload: Optional[Dict[str, Any]] = None,
-) -> requests.Response:
+def make_cfapi_request(endpoint: str, method: str, payload: Optional[Dict[str, Any]] = None) -> requests.Response:
     """Make an HTTP request using the specified method, URL, headers, and JSON payload.
     :param endpoint: The endpoint URL to send the request to.
     :param method: The HTTP method to use ('GET', 'POST', etc.).
@@ -55,11 +51,8 @@ def get_user_id() -> Optional[str]:
     response = make_cfapi_request(endpoint="/cli-get-user", method="GET")
     if response.status_code == 200:
         return response.text
-    else:
-        logger.error(
-            f"Failed to look up your userid; is your CF API key valid? ({response.reason})",
-        )
-        return None
+    logger.error(f"Failed to look up your userid; is your CF API key valid? ({response.reason})")
+    return None


 def suggest_changes(
@@ -137,10 +130,7 @@ def is_github_app_installed_on_repo(owner: str, repo: str) -> bool:
     :param repo: The name of the repository.
     :return: The response object.
     """
-    response = make_cfapi_request(
-        endpoint=f"/is-github-app-installed?repo={repo}&owner={owner}",
-        method="GET",
-    )
+    response = make_cfapi_request(endpoint=f"/is-github-app-installed?repo={repo}&owner={owner}", method="GET")
     if not response.ok or response.text != "true":
         logger.error(f"Error: {response.text}")
         return False
@@ -153,17 +143,9 @@ def get_blocklisted_functions() -> dict[str, str]:
         return {}

    owner, repo = get_repo_owner_and_name()
-    information = {
-        "pr_number": pr_number,
-        "repo_owner": owner,
-        "repo_name": repo,
-    }
+    information = {"pr_number": pr_number, "repo_owner": owner, "repo_name": repo}
    try:
-        req = make_cfapi_request(
-            endpoint="/verify-existing-optimizations",
-            method="POST",
-            payload=information,
-        )
+        req = make_cfapi_request(endpoint="/verify-existing-optimizations", method="POST", payload=information)
         content: dict[str, list[str]] = req.json()
     except Exception as e:
         logger.error(f"Error getting blocklisted functions: {e}")
@@ -31,10 +31,7 @@ def parse_args() -> Namespace:
     init_actions_parser = subparsers.add_parser("init-actions", help="Initialize GitHub Actions workflow")
     init_actions_parser.set_defaults(func=install_github_actions)
     parser.add_argument("--file", help="Try to optimize only this file")
-    parser.add_argument(
-        "--function",
-        help="Try to optimize only this function within the given file path",
-    )
+    parser.add_argument("--function", help="Try to optimize only this function within the given file path")
     parser.add_argument(
         "--all",
         help="Try to optimize all functions. Can take a really long time. Can pass an optional starting directory to"
@@ -50,30 +47,16 @@ def parse_args() -> Namespace:
         " This is the top-level root directory where all the Python source code is located.",
     )
     parser.add_argument(
-        "--tests-root",
-        type=str,
-        help="Path to the test directory of the project, where all the tests are located.",
+        "--tests-root", type=str, help="Path to the test directory of the project, where all the tests are located."
     )
     parser.add_argument("--test-framework", choices=["pytest", "unittest"], default="pytest")
-    parser.add_argument(
-        "--config-file",
-        type=str,
-        help="Path to the pyproject.toml with codeflash configs.",
-    )
-    parser.add_argument(
-        "--use-cached-tests",
-        action="store_true",
-        help="Use cached tests from a specified file for debugging.",
-    )
-    parser.add_argument(
-        "--replay-test",
-        type=str,
-        help="Path to replay test to optimize functions from",
-    )
+    parser.add_argument("--config-file", type=str, help="Path to the pyproject.toml with codeflash configs.")
+    parser.add_argument(
+        "--use-cached-tests", action="store_true", help="Use cached tests from a specified file for debugging."
+    )
+    parser.add_argument("--replay-test", type=str, help="Path to replay test to optimize functions from")
     parser.add_argument(
-        "--no-pr",
-        action="store_true",
-        help="Do not create a PR for the optimization, only update the code locally.",
+        "--no-pr", action="store_true", help="Do not create a PR for the optimization, only update the code locally."
     )
     parser.add_argument(
         "--verify-setup",
@@ -188,7 +171,7 @@ def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
     except git.exc.InvalidGitRepositoryError:
         logger.exception(
             "I couldn't find a git repository in the current directory. "
-            "I need a git repository to run --all and open PRs for optimizations. Exiting...",
+            "I need a git repository to run --all and open PRs for optimizations. Exiting..."
         )
         apologize_and_exit()
     if not args.no_pr and not check_and_push_branch(git_repo):
@@ -10,17 +10,13 @@ import inquirer

 def apologize_and_exit() -> None:
     click.echo(
-        "💡 If you're having trouble, see https://docs.codeflash.ai/getting-started/local-installation for further help getting started with Codeflash!",
+        "💡 If you're having trouble, see https://docs.codeflash.ai/getting-started/local-installation for further help getting started with Codeflash!"
     )
     click.echo("👋 Exiting...")
     sys.exit(1)


-def inquirer_wrapper(
-    func: Callable[..., str | bool],
-    *args: str | bool,
-    **kwargs: str | bool,
-) -> str | bool:
+def inquirer_wrapper(func: Callable[..., str | bool], *args: str | bool, **kwargs: str | bool) -> str | bool:
     new_args = []
     new_kwargs = {}
@@ -29,10 +25,7 @@ def inquirer_wrapper(
     else:
         message = kwargs["message"]
         new_kwargs = kwargs.copy()
-        split_messages = split_string_to_cli_width(
-            message,
-            is_confirm=func == inquirer.confirm,
-        )
+        split_messages = split_string_to_cli_width(message, is_confirm=func == inquirer.confirm)
         for split_message in split_messages[:-1]:
             click.echo(split_message)
@@ -77,14 +70,7 @@ def inquirer_wrapper_path(*args: str, **kwargs: str) -> dict[str, str]:
     new_kwargs["message"] = last_message
     new_args.append(args[0])

-    return cast(
-        dict[str, str],
-        inquirer.prompt(
-            [
-                inquirer.Path(*new_args, **new_kwargs),
-            ],
-        ),
-    )
+    return cast(dict[str, str], inquirer.prompt([inquirer.Path(*new_args, **new_kwargs)]))


 def split_string_to_fit_width(string: str, width: int) -> list[str]:
@@ -21,18 +21,10 @@ from codeflash.api.cfapi import is_github_app_installed_on_repo
 from codeflash.cli_cmds.cli_common import apologize_and_exit, inquirer_wrapper, inquirer_wrapper_path
 from codeflash.code_utils.compat import LF
 from codeflash.code_utils.config_parser import parse_config_file
-from codeflash.code_utils.env_utils import (
-    get_codeflash_api_key,
-)
+from codeflash.code_utils.env_utils import get_codeflash_api_key
 from codeflash.code_utils.git_utils import get_repo_owner_and_name
-from codeflash.code_utils.github_utils import (
-    get_github_secrets_page_url,
-    require_github_app_or_exit,
-)
-from codeflash.code_utils.shell_utils import (
-    get_shell_rc_path,
-    save_api_key_to_rc,
-)
+from codeflash.code_utils.github_utils import get_github_secrets_page_url, require_github_app_or_exit
+from codeflash.code_utils.shell_utils import get_shell_rc_path, save_api_key_to_rc
 from codeflash.telemetry.posthog_cf import ph
 from codeflash.version import __version__ as version
@@ -78,12 +70,10 @@ def init_codeflash() -> None:
         f" codeflash --file <path-to-file> to optimize all functions in a file{LF}"
         f" codeflash --all to optimize all functions in all files in the module you selected ({setup_info.module_root}){LF}"
         f"-or-{LF}"
-        f" codeflash --help to see all options{LF}",
+        f" codeflash --help to see all options{LF}"
     )
     if did_add_new_key:
-        click.echo(
-            "🐚 Don't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!",
-        )
+        click.echo("🐚 Don't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!")
         click.echo("Or run the following command to reload:")
         if os.name == "nt":
             click.echo(f" call {get_shell_rc_path()}")
@@ -111,30 +101,16 @@ def collect_setup_info() -> SetupInfo:
     curdir = Path.cwd()
     # Check if the cwd is writable
     if not os.access(curdir, os.W_OK):
-        click.echo(
-            f"❌ The current directory isn't writable, please check your folder permissions and try again.{LF}",
-        )
+        click.echo(f"❌ The current directory isn't writable, please check your folder permissions and try again.{LF}")
         click.echo("It's likely you don't have write permissions for this folder.")
         sys.exit(1)

     # Check for the existence of pyproject.toml or setup.py
     project_name = check_for_toml_or_setup_file()

-    ignore_subdirs = [
-        "venv",
-        "node_modules",
-        "dist",
-        "build",
-        "build_temp",
-        "build_scripts",
-        "env",
-        "logs",
-        "tmp",
-    ]
+    ignore_subdirs = ["venv", "node_modules", "dist", "build", "build_temp", "build_scripts", "env", "logs", "tmp"]
     valid_subdirs = [
-        d
-        for d in next(os.walk("."))[1]
-        if not d.startswith(".") and not d.startswith("__") and d not in ignore_subdirs
+        d for d in next(os.walk("."))[1] if not d.startswith(".") and not d.startswith("__") and d not in ignore_subdirs
     ]

     valid_module_subdirs = [d for d in valid_subdirs if d != "tests"]
@@ -165,9 +141,7 @@ def collect_setup_info() -> SetupInfo:
         message="Where are your tests located? "
         f"(If you don't have any tests yet, I can create an empty tests{os.pathsep} directory for you)",
         choices=test_subdir_options,
-        default=(
-            default_tests_subdir if default_tests_subdir in test_subdir_options else test_subdir_options[0]
-        ),
+        default=(default_tests_subdir if default_tests_subdir in test_subdir_options else test_subdir_options[0]),
     )

     if tests_root_answer == create_for_me_option:
@@ -285,11 +259,7 @@ def check_for_toml_or_setup_file() -> str | None:
     else:
         if setup_py_path.exists():
             setup_py_content = setup_py_path.read_text(encoding="utf8")
-            project_name_match = re.search(
-                r"setup\s*\([^)]*?name\s*=\s*['\"](.*?)['\"]",
-                setup_py_content,
-                re.DOTALL,
-            )
+            project_name_match = re.search(r"setup\s*\([^)]*?name\s*=\s*['\"](.*?)['\"]", setup_py_content, re.DOTALL)
             if project_name_match:
                 project_name = project_name_match.group(1)
                 click.echo(f"✅ Found setup.py for your project {project_name}")
@@ -300,7 +270,7 @@ def check_for_toml_or_setup_file() -> str | None:
         click.echo(
             f"💡 I couldn't find a pyproject.toml in the current directory ({curdir}).{LF}"
             f"(make sure you're running `codeflash init` from your project's root directory!){LF}"
-            f"I need this file to store my configuration settings.",
+            f"I need this file to store my configuration settings."
         )
         ph("cli-no-pyproject-toml-or-setup-py")
@@ -321,14 +291,12 @@ def check_for_toml_or_setup_file() -> str | None:

         # Check if the pyproject.toml file was created
         if pyproject_toml_path.exists():
-            click.echo(
-                f"✅ Created a pyproject.toml file at {pyproject_toml_path}",
-            )
+            click.echo(f"✅ Created a pyproject.toml file at {pyproject_toml_path}")
             click.pause()
             ph("cli-created-pyproject-toml")
     except OSError:
         click.echo(
-            "❌ Failed to create pyproject.toml. Please check your disk permissions and available space.",
+            "❌ Failed to create pyproject.toml. Please check your disk permissions and available space."
         )
         apologize_and_exit()
     else:
@@ -341,7 +309,7 @@ def check_for_toml_or_setup_file() -> str | None:
 def install_github_actions() -> None:
     try:
         click.echo(
-            "⚡️ Codeflash can automatically optimize new Github PRs for you when they're opened. Let's get that set up!",
+            "⚡️ Codeflash can automatically optimize new Github PRs for you when they're opened. Let's get that set up!"
         )
         config, config_file_path = parse_config_file()
@@ -360,10 +328,7 @@ def install_github_actions() -> None:
             message=f"I'm going to create a new GitHub actions workflow file at {optimize_yaml_path} ... is this OK?",
             default=True,
         )
-        ph(
-            "cli-github-optimization-confirm-workflow-creation",
-            {"confirm_creation": confirm_creation_yes},
-        )
+        ph("cli-github-optimization-confirm-workflow-creation", {"confirm_creation": confirm_creation_yes})
         if not confirm_creation_yes:
             click.echo("⏩️ Exiting workflow creation.")
             ph("cli-github-workflow-skipped")
@@ -374,14 +339,9 @@ def install_github_actions() -> None:
         py_version = sys.version_info
         python_version_string = f"'{py_version.major}.{py_version.minor}'"
         optimize_yml_content = (
-            files("codeflash")
-            .joinpath("cli_cmds", "workflows", "codeflash-optimize.yaml")
-            .read_text(encoding="utf-8")
+            files("codeflash").joinpath("cli_cmds", "workflows", "codeflash-optimize.yaml").read_text(encoding="utf-8")
         )
-        optimize_yml_content = optimize_yml_content.replace(
-            "{{ python_version }}",
-            python_version_string,
-        )
+        optimize_yml_content = optimize_yml_content.replace("{{ python_version }}", python_version_string)
         with optimize_yaml_path.open("w", encoding="utf8") as optimize_yml_file:
             optimize_yml_file.write(optimize_yml_content)
         click.echo(f"✅ Created {optimize_yaml_path}{LF}")
@@ -398,7 +358,7 @@ def install_github_actions() -> None:
         click.echo(
             "🐙 I opened your Github secrets page! Note: if you see a 404, you probably don't have access to this "
             "repo's secrets; ask a repo admin to add it for you, or (not super recommended) you can temporarily "
-            f"hard-code your api key into the workflow file.{LF}",
+            f"hard-code your api key into the workflow file.{LF}"
         )
         click.pause()
         click.echo()
@@ -414,13 +374,13 @@ def install_github_actions() -> None:
         click.launch(optimize_yaml_path.as_posix())
         click.echo(
             "📝 I opened the workflow file in your editor! You'll need to edit the steps that install the right Python "
-            f"version and any project dependencies. See the comments in the file for more details.{LF}",
+            f"version and any project dependencies. See the comments in the file for more details.{LF}"
        )
         click.pause()
         click.echo()
         click.echo(
             f"Please commit and push this GitHub actions file to your repo, and you're all set!{LF}"
-            f"🚀 Codeflash is now configured to automatically optimize new Github PRs!{LF}",
+            f"🚀 Codeflash is now configured to automatically optimize new Github PRs!{LF}"
         )
         ph("cli-github-workflow-created")
     except KeyboardInterrupt:
@@ -436,14 +396,12 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
     except FileNotFoundError:
         click.echo(
             f"I couldn't find a pyproject.toml in the current directory.{LF}"
-            f"Please create a new empty pyproject.toml file here, OR if you use poetry then run `poetry init`, OR run `codeflash init` again from a directory with an existing pyproject.toml file.",
+            f"Please create a new empty pyproject.toml file here, OR if you use poetry then run `poetry init`, OR run `codeflash init` again from a directory with an existing pyproject.toml file."
         )
         apologize_and_exit()

     codeflash_section = tomlkit.table()
-    codeflash_section.add(
-        tomlkit.comment("All paths are relative to this pyproject.toml's directory."),
-    )
+    codeflash_section.add(tomlkit.comment("All paths are relative to this pyproject.toml's directory."))
     codeflash_section["module-root"] = setup_info.module_root
     codeflash_section["tests-root"] = setup_info.tests_root
     codeflash_section["test-framework"] = setup_info.test_framework
@@ -457,7 +415,7 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
     elif formatter == "other":
         formatter_cmds.append("your-formatter $file")
         click.echo(
-            "🔧 In pyproject.toml, please replace 'your-formatter' with the command you use to format your code.",
+            "🔧 In pyproject.toml, please replace 'your-formatter' with the command you use to format your code."
         )
     elif formatter == "don't use a formatter":
         formatter_cmds.append("disabled")
@@ -483,9 +441,7 @@ def install_github_app() -> None:
     owner, repo = get_repo_owner_and_name(git_repo)

     if is_github_app_installed_on_repo(owner, repo):
-        click.echo(
-            "🐙 Looks like you've already installed the Codeflash GitHub app on this repository! Continuing…",
-        )
+        click.echo("🐙 Looks like you've already installed the Codeflash GitHub app on this repository! Continuing…")

     else:
         click.prompt(
@@ -512,7 +468,7 @@ def install_github_app() -> None:
         click.echo(
             f"❌ It looks like the Codeflash GitHub App is not installed on the repository {owner}/{repo}.{LF}"
             f"You won't be able to create PRs with Codeflash until you install the app.{LF}"
-            f"In the meantime you can make local only optimizations by using the '--no-pr' flag with codeflash.{LF}",
+            f"In the meantime you can make local only optimizations by using the '--no-pr' flag with codeflash.{LF}"
         )
         break
     click.prompt(
@@ -549,15 +505,10 @@ def prompt_api_key() -> bool:
         existing_api_key = None
     if existing_api_key:
         display_key = f"{existing_api_key[:3]}****{existing_api_key[-4:]}"
-        click.echo(
-            f"🔑 I found a CODEFLASH_API_KEY in your environment [{display_key}]!",
-        )
+        click.echo(f"🔑 I found a CODEFLASH_API_KEY in your environment [{display_key}]!")

         use_existing_key = inquirer_wrapper(
-            inquirer.confirm,
-            message="Do you want to use this key?",
-            default=True,
-            show_default=False,
+            inquirer.confirm, message="Do you want to use this key?", default=True, show_default=False
         )
         if use_existing_key:
             ph("cli-existing-api-key-used")
@@ -584,7 +535,7 @@ def enter_api_key_and_save_to_rc() -> None:
     if not browser_launched:
         click.echo(
             f"Opening your Codeflash API key page. Grab a key from there!{LF}"
-            "You can also open this link manually: https://app.codeflash.ai/app/apikeys",
+            "You can also open this link manually: https://app.codeflash.ai/app/apikeys"
         )
         click.launch("https://app.codeflash.ai/app/apikeys")
         browser_launched = True  # This does not work on remote consoles
@@ -664,22 +615,11 @@ def test_sort():


 def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test_path: str) -> None:
-    command = [
-        "codeflash",
-        "--file",
-        "bubble_sort.py",
-        "--function",
-        "sorter",
-    ]
+    command = ["codeflash", "--file", "bubble_sort.py", "--function", "sorter"]
     sys.stdout.write("Running sample optimization... ")
     sys.stdout.flush()
     try:
-        process = subprocess.run(
-            command,
-            text=True,
-            cwd=args.module_root,
-            check=False,
-        )
+        process = subprocess.run(command, text=True, cwd=args.module_root, check=False)
     finally:
         # Delete the bubble_sort.py file after the test
         Path(bubble_sort_path).unlink(missing_ok=True)
@@ -688,10 +628,8 @@ def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test
     click.echo(f"{LF}🗑️ Deleted {bubble_sort_test_path}")

     if process.returncode == 0:
-        click.echo(
-            f"{LF}✅ End-to-end test passed. Codeflash has been correctly set up!",
-        )
+        click.echo(f"{LF}✅ End-to-end test passed. Codeflash has been correctly set up!")
     else:
         click.echo(
-            f"{LF}❌ End-to-end test failed. Please check the logs above, and take a look at https://docs.codeflash.ai/getting-started/local-installation for help and troubleshooting.",
+            f"{LF}❌ End-to-end test failed. Please check the logs above, and take a look at https://docs.codeflash.ai/getting-started/local-installation for help and troubleshooting."
         )
@@ -11,9 +11,7 @@ console = Console(record=True)

 logging.basicConfig(
     level=logging.INFO,
-    handlers=[
-        RichHandler(rich_tracebacks=True, markup=False, console=console, show_path=False, show_time=False),
-    ],
+    handlers=[RichHandler(rich_tracebacks=True, markup=False, console=console, show_path=False, show_time=False)],
     format=BARE_LOGGING_FORMAT,
 )
@@ -21,9 +19,7 @@ logger = logging.getLogger("rich")


 def paneled_text(
-    text: str,
-    panel_args: dict[str, str | bool] | None = None,
-    text_args: dict[str, str] | None = None,
+    text: str, panel_args: dict[str, str | bool] | None = None, text_args: dict[str, str] | None = None
 ) -> None:
     from rich.panel import Panel
     from rich.text import Text
@@ -13,15 +13,7 @@ def set_level(level: int, *, echo_setting: bool = True) -> None:

     logging.basicConfig(
         level=level,
-        handlers=[
-            RichHandler(
-                rich_tracebacks=True,
-                markup=False,
-                console=console,
-                show_path=False,
-                show_time=False,
-            ),
-        ],
+        handlers=[RichHandler(rich_tracebacks=True, markup=False, console=console, show_path=False, show_time=False)],
         format=BARE_LOGGING_FORMAT,
     )
     logging.getLogger().setLevel(level)
@@ -31,13 +23,7 @@ def set_level(level: int, *, echo_setting: bool = True) -> None:
     logging.basicConfig(
         format=VERBOSE_LOGGING_FORMAT,
         handlers=[
-            RichHandler(
-                rich_tracebacks=True,
-                markup=False,
-                console=console,
-                show_path=False,
-                show_time=False,
-            ),
+            RichHandler(rich_tracebacks=True, markup=False, console=console, show_path=False, show_time=False)
         ],
         force=True,
     )
@@ -22,9 +22,7 @@ if TYPE_CHECKING:

 class FutureAliasedImportTransformer(cst.CSTTransformer):
     def leave_ImportFrom(
-        self,
-        original_node: cst.ImportFrom,
-        updated_node: cst.ImportFrom,
+        self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
     ) -> cst.BaseSmallStatement | cst.FlattenSentinel[cst.BaseSmallStatement] | cst.RemovalSentinel:
         if (
             (updated_node_module := updated_node.module)
@@ -67,7 +65,7 @@ def add_needed_imports_from_module(
             filename=src_path.name,
             full_module_name=src_module_and_package.name,
             full_package_name=src_module_and_package.package,
-        ),
+        )
     )
     cst.parse_module(src_module_code).visit(gatherer)
     try:
@@ -91,12 +89,7 @@ def add_needed_imports_from_module(
         if f"{mod}.{alias_pair[0]}" in helper_functions_fqn:
             continue
         AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
-        RemoveImportsVisitor.remove_unused_import(
-            dst_context,
-            mod,
-            alias_pair[0],
-            asname=alias_pair[1],
-        )
+        RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])

     try:
         parsed_module = cst.parse_module(dst_module_code)
@@ -112,9 +105,7 @@ def add_needed_imports_from_module(
     return dst_module_code


-def get_code(
-    functions_to_optimize: list[FunctionToOptimize],
-) -> tuple[str | None, set[tuple[str, str]]]:
+def get_code(functions_to_optimize: list[FunctionToOptimize]) -> tuple[str | None, set[tuple[str, str]]]:
     """Return the code for a function or methods in a Python module. functions_to_optimize is either a singleton
     FunctionToOptimize instance, which represents either a function at the module level or a method of a class at the
     module level, or it represents a list of methods of the same class.
@ -135,21 +126,15 @@ def get_code(
|
|||
contextual_dunder_methods: set[tuple[str, str]] = set()
|
||||
target_code: str = ""
|
||||
|
||||
def find_target(
|
||||
node_list: list[ast.stmt],
|
||||
name_parts: tuple[str, str] | tuple[str],
|
||||
) -> ast.AST | None:
|
||||
target: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Assign | ast.AnnAssign | None = (
|
||||
None
|
||||
)
|
||||
def find_target(node_list: list[ast.stmt], name_parts: tuple[str, str] | tuple[str]) -> ast.AST | None:
|
||||
target: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Assign | ast.AnnAssign | None = None
|
||||
node: ast.stmt
|
||||
for node in node_list:
|
||||
if (
|
||||
# The many mypy issues will be fixed once this code moves to the backend,
|
||||
# using Type Guards as we move to 3.10+.
|
||||
# We will cover the Type Alias case on the backend since it's a 3.12 feature.
|
||||
isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef))
|
||||
and node.name == name_parts[0]
|
||||
isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)) and node.name == name_parts[0]
|
||||
):
|
||||
target = node
|
||||
break
|
||||
|
|
@ -159,11 +144,7 @@ def get_code(
|
|||
and len(node.targets) == 1
|
||||
and isinstance(node.targets[0], ast.Name)
|
||||
and node.targets[0].id == name_parts[0]
|
||||
) or (
|
||||
isinstance(node, ast.AnnAssign)
|
||||
and hasattr(node.target, "id")
|
||||
and node.target.id == name_parts[0]
|
||||
):
|
||||
) or (isinstance(node, ast.AnnAssign) and hasattr(node.target, "id") and node.target.id == name_parts[0]):
|
||||
if class_skeleton:
|
||||
break
|
||||
target = node
|
||||
|
|
@ -214,16 +195,14 @@ def get_code(
|
|||
]
|
||||
|
||||
else:
|
||||
logger.error(
|
||||
f"Error: get_code does not support inner functions: {functions_to_optimize[0].parents}",
|
||||
)
|
||||
logger.error(f"Error: get_code does not support inner functions: {functions_to_optimize[0].parents}")
|
||||
return None, set()
|
||||
elif len(functions_to_optimize[0].parents) == 0:
|
||||
qualified_name_parts_list = [(functions_to_optimize[0].function_name,)]
|
||||
else:
|
||||
logger.error(
|
||||
"Error: get_code does not support more than one level of nesting for now. "
|
||||
f"Parents: {functions_to_optimize[0].parents}",
|
||||
f"Parents: {functions_to_optimize[0].parents}"
|
||||
)
|
||||
return None, set()
|
||||
for qualified_name_parts in qualified_name_parts_list:
|
||||
|
|
@ -235,32 +214,24 @@ def get_code(
|
|||
isinstance(target_node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef))
|
||||
and target_node.decorator_list
|
||||
):
|
||||
target_code += "".join(
|
||||
lines[target_node.decorator_list[0].lineno - 1 : target_node.end_lineno],
|
||||
)
|
||||
target_code += "".join(lines[target_node.decorator_list[0].lineno - 1 : target_node.end_lineno])
|
||||
else:
|
||||
target_code += "".join(lines[target_node.lineno - 1 : target_node.end_lineno])
|
||||
if not target_code:
|
||||
return None, set()
|
||||
class_list: list[tuple[int, int | None]] = sorted(class_skeleton)
|
||||
class_code = "".join(
|
||||
["".join(lines[s_lineno - 1 : e_lineno]) for (s_lineno, e_lineno) in class_list],
|
||||
)
|
||||
class_code = "".join(["".join(lines[s_lineno - 1 : e_lineno]) for (s_lineno, e_lineno) in class_list])
|
||||
return class_code + target_code, contextual_dunder_methods
|
||||
|
||||
|
||||
def extract_code(
|
||||
functions_to_optimize: list[FunctionToOptimize],
|
||||
) -> tuple[str | None, set[tuple[str, str]]]:
|
||||
def extract_code(functions_to_optimize: list[FunctionToOptimize]) -> tuple[str | None, set[tuple[str, str]]]:
|
||||
edited_code, contextual_dunder_methods = get_code(functions_to_optimize)
|
||||
if edited_code is None:
|
||||
return None, set()
|
||||
try:
|
||||
compile(edited_code, "edited_code", "exec")
|
||||
except SyntaxError as e:
|
||||
logger.exception(
|
||||
f"extract_code - Syntax error in extracted optimization candidate code: {e}",
|
||||
)
|
||||
logger.exception(f"extract_code - Syntax error in extracted optimization candidate code: {e}")
|
||||
return None, set()
|
||||
return edited_code, contextual_dunder_methods
|
||||
|
||||
|
|
|
|||
|
|
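Note: the extract_code hunk above leans on compile() as a pure syntax check before returning the extracted source. A minimal standalone sketch of that pattern (the candidate string is made up):

candidate = "def fast_add(a, b):\n    return a + b\n"
try:
    # mode="exec" parses and byte-compiles without executing anything,
    # so a SyntaxError here means the extracted code is not valid Python.
    compile(candidate, "edited_code", "exec")
except SyntaxError as e:
    print(f"rejected candidate: {e}")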
@@ -20,9 +20,7 @@ def file_path_from_module_name(module_name: str, project_root_path: Path) -> Pat

def get_imports_from_file(
file_path: Path | None = None,
file_string: str | None = None,
file_ast: ast.AST | None = None,
file_path: Path | None = None, file_string: str | None = None, file_ast: ast.AST | None = None
) -> list[ast.Import | ast.ImportFrom]:
assert (
sum([file_path is not None, file_string is not None, file_ast is not None]) == 1

@@ -52,16 +52,9 @@ def parse_config_file(config_file_path: Path | None = None) -> tuple[dict[str, A
# default values:
path_keys = ["module-root", "tests-root"]
path_list_keys = ["ignore-paths"]
str_keys = {
"pytest-cmd": "pytest",
}
bool_keys = {
"disable-telemetry": False,
"disable-imports-sorting": False,
}
list_str_keys = {
"formatter-cmds": ["black $file"],
}
str_keys = {"pytest-cmd": "pytest"}
bool_keys = {"disable-telemetry": False, "disable-imports-sorting": False}
list_str_keys = {"formatter-cmds": ["black $file"]}

for key in str_keys:
if key in config:
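Note: parse_config_file above keeps per-type default tables and then fills in whatever the user's config omits. A standalone sketch of that defaulting pattern (the parsed config dict here is hypothetical):

config = {"pytest-cmd": "pytest -q"}  # hypothetical result of parsing a config file
str_keys = {"pytest-cmd": "pytest"}
bool_keys = {"disable-telemetry": False, "disable-imports-sorting": False}
for table in (str_keys, bool_keys):
    for key, default in table.items():
        config.setdefault(key, default)  # keep user values, fill the gaps
print(config["disable-telemetry"])  # False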
@@ -12,13 +12,13 @@ def get_codeflash_api_key() -> Optional[str]:
if not api_key:
raise OSError(
"I didn't find a Codeflash API key in your environment.\nYou can generate one at "
"https://app.codeflash.ai/app/apikeys,\nthen set it as a CODEFLASH_API_KEY environment variable.",
"https://app.codeflash.ai/app/apikeys,\nthen set it as a CODEFLASH_API_KEY environment variable."
)
if not api_key.startswith("cf-"):
raise OSError(
f"Your Codeflash API key seems to be invalid. It should start with a 'cf-' prefix; I found '{api_key}' "
f"instead.\nYou can generate one at https://app.codeflash.ai/app/apikeys,\nthen set it as a "
f"CODEFLASH_API_KEY environment variable.",
f"CODEFLASH_API_KEY environment variable."
)
return api_key

@@ -27,9 +27,9 @@ def ensure_codeflash_api_key() -> bool:
try:
get_codeflash_api_key()
except OSError:
logger.error(  # noqa: TRY400
logger.error(
"Codeflash API key not found in your environment.\nYou can generate one at "
"https://app.codeflash.ai/app/apikeys,\nthen set it as a CODEFLASH_API_KEY environment variable.",
"https://app.codeflash.ai/app/apikeys,\nthen set it as a CODEFLASH_API_KEY environment variable."
)
return False
return True

@@ -53,6 +53,6 @@ def ensure_pr_number() -> bool:
if not get_pr_number():
raise OSError(
"CODEFLASH_PR_NUMBER not found in environment variables; make sure the Github Action is setting this so "
"Codeflash can comment on the right PR",
"Codeflash can comment on the right PR"
)
return True
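Note: a standalone sketch of the key lookup and "cf-" prefix validation shown above (only the environment variable name comes from the diff; the rest is illustrative):

import os

api_key = os.environ.get("CODEFLASH_API_KEY")
if not api_key:
    raise OSError("CODEFLASH_API_KEY is not set")
if not api_key.startswith("cf-"):
    raise OSError("CODEFLASH_API_KEY should start with the 'cf-' prefix")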
@@ -10,10 +10,7 @@ import isort
from codeflash.cli_cmds.console import logger

def format_code(
formatter_cmds: list[str],
path: Path,
) -> str | None:
def format_code(formatter_cmds: list[str], path: Path) -> str | None:
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
if not path.exists():
logger.error(f"File {path} does not exist. Cannot format the file.")

@@ -29,12 +26,7 @@ def format_code(
logger.info(f"Formatting code with {' '.join(formatter_cmd_list)} ...")

try:
result = subprocess.run(
formatter_cmd_list,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
check=False,
)
result = subprocess.run(formatter_cmd_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False)
except Exception as e:
logger.exception(f"Failed to format code with {' '.join(formatter_cmd_list)}: {e}")
return None
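Note: the collapsed subprocess.run call above captures both streams and opts out of check=True. A standalone sketch (the black command mirrors the formatter-cmds default; example.py is made up):

import subprocess

result = subprocess.run(["black", "example.py"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False)
if result.returncode != 0:
    # with check=False the caller inspects the exit code instead of catching CalledProcessError
    print(result.stderr.decode())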
@@ -17,25 +17,14 @@ if TYPE_CHECKING:
from git import Repo

def get_git_diff(
repo_directory: Path = Path.cwd(),
uncommitted_changes: bool = False,
) -> dict[str, list[int]]:
def get_git_diff(repo_directory: Path = Path.cwd(), uncommitted_changes: bool = False) -> dict[str, list[int]]:
repository = git.Repo(repo_directory, search_parent_directories=True)
commit = repository.head.commit
if uncommitted_changes:
uni_diff_text = repository.git.diff(
None,
"HEAD",
ignore_blank_lines=True,
ignore_space_at_eol=True,
)
uni_diff_text = repository.git.diff(None, "HEAD", ignore_blank_lines=True, ignore_space_at_eol=True)
else:
uni_diff_text = repository.git.diff(
commit.hexsha + "^1",
commit.hexsha,
ignore_blank_lines=True,
ignore_space_at_eol=True,
commit.hexsha + "^1", commit.hexsha, ignore_blank_lines=True, ignore_space_at_eol=True
)
patch_set = PatchSet(StringIO(uni_diff_text))
change_list: dict[str, list[int]] = {}  # list of changes

@@ -47,10 +36,7 @@ def get_git_diff(
logger.debug(f"file name: {file_path}")

add_line_no: list[int] = [
line.target_line_no
for hunk in patched_file
for line in hunk
if line.is_added and line.value.strip() != ""
line.target_line_no for hunk in patched_file for line in hunk if line.is_added and line.value.strip() != ""
]  # the row number of deleted lines

logger.debug(f"added lines: {add_line_no}")

@@ -89,9 +75,7 @@ def get_repo_owner_and_name(repo: Repo | None = None) -> tuple[str, str]:
remote_url = get_remote_url(repo).removesuffix(".git") if remote_url.endswith(".git") else remote_url
split_url = remote_url.split("/")
repo_owner_with_github, repo_name = split_url[-2], split_url[-1]
repo_owner = (
repo_owner_with_github.split(":")[1] if ":" in repo_owner_with_github else repo_owner_with_github
)
repo_owner = repo_owner_with_github.split(":")[1] if ":" in repo_owner_with_github else repo_owner_with_github
return repo_owner, repo_name
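Note: get_git_diff above feeds git's unified diff text into unidiff's PatchSet and keeps the target line numbers of added lines. A standalone sketch with a made-up diff:

from io import StringIO

from unidiff import PatchSet

diff_text = """\
--- a/f.py
+++ b/f.py
@@ -1,2 +1,3 @@
 x = 1
+y = 2
 z = 3
"""
patch = PatchSet(StringIO(diff_text))
# iterate files -> hunks -> lines, exactly as the comprehension in the hunk above does
added = [line.target_line_no for pf in patch for hunk in pf for line in hunk if line.is_added]
print(added)  # [2]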
@@ -20,9 +20,9 @@ def require_github_app_or_exit(owner: str, repo: str) -> None:
f"It looks like the Codeflash GitHub App is not installed on the repository {owner}/{repo} or the GitHub"
f" account linked to your CODEFLASH_API_KEY does not have access to the repository {owner}/{repo}.{LF}"
"Before continuing, please install the Codeflash GitHub App on your repository by visiting "
f"https://github.com/apps/codeflash-ai{LF}",
f"https://github.com/apps/codeflash-ai{LF}"
)
logger.error(
f"Note: if you want to find optimizations without opening PRs, you can run Codeflash with the --no-pr flag.{LF}",
f"Note: if you want to find optimizations without opening PRs, you can run Codeflash with the --no-pr flag.{LF}"
)
apologize_and_exit()
@@ -37,11 +37,7 @@ def is_argument_name(name: str, arguments_node: ast.arguments) -> bool:

class InjectPerfOnly(ast.NodeTransformer):
def __init__(
self,
function: FunctionToOptimize,
module_path: str,
test_framework: str,
call_positions: list[CodePosition],
self, function: FunctionToOptimize, module_path: str, test_framework: str, call_positions: list[CodePosition]
) -> None:
self.function_object = function
self.class_name = None

@@ -53,11 +49,7 @@ class InjectPerfOnly(ast.NodeTransformer):
self.class_name = function.top_level_parent_name

def find_and_update_line_node(
self,
test_node: ast.stmt,
node_name: str,
index: str,
test_class_name: str | None = None,
self, test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None
) -> Iterable[ast.stmt] | None:
call_node = None
for node in ast.walk(test_node):

@@ -123,7 +115,7 @@ class InjectPerfOnly(ast.NodeTransformer):
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
args=[ast.Constant(value=15)],
keywords=[],
),
)
)
i = len(node.body) - 1
while i >= 0:

@@ -136,15 +128,9 @@ class InjectPerfOnly(ast.NodeTransformer):
compound_line_node: ast.stmt = line_node.body[j]
internal_node: ast.AST
for internal_node in ast.walk(compound_line_node):
if isinstance(
internal_node,
(ast.stmt, ast.Assign),
):
if isinstance(internal_node, (ast.stmt, ast.Assign)):
updated_node = self.find_and_update_line_node(
internal_node,
node.name,
str(i) + "_" + str(j),
test_class_name,
internal_node, node.name, str(i) + "_" + str(j), test_class_name
)
if updated_node is not None:
line_node.body[j : j + 1] = updated_node

@@ -152,12 +138,7 @@ class InjectPerfOnly(ast.NodeTransformer):
break
j -= 1
else:
updated_node = self.find_and_update_line_node(
line_node,
node.name,
str(i),
test_class_name,
)
updated_node = self.find_and_update_line_node(line_node, node.name, str(i), test_class_name)
if updated_node is not None:
node.body[i : i + 1] = updated_node
did_update = True

@@ -169,9 +150,7 @@ class InjectPerfOnly(ast.NodeTransformer):
targets=[ast.Name(id="codeflash_iteration", ctx=ast.Store())],
value=ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="os", ctx=ast.Load()),
attr="environ",
ctx=ast.Load(),
value=ast.Name(id="os", ctx=ast.Load()), attr="environ", ctx=ast.Load()
),
slice=ast.Constant(value="CODEFLASH_TEST_ITERATION"),
ctx=ast.Load(),

@@ -186,13 +165,11 @@ class InjectPerfOnly(ast.NodeTransformer):
args=[
ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="os", ctx=ast.Load()),
attr="environ",
ctx=ast.Load(),
value=ast.Name(id="os", ctx=ast.Load()), attr="environ", ctx=ast.Load()
),
slice=ast.Constant(value="CODEFLASH_LOOP_INDEX"),
ctx=ast.Load(),
),
)
],
keywords=[],
),

@@ -203,26 +180,18 @@ class InjectPerfOnly(ast.NodeTransformer):
targets=[ast.Name(id="codeflash_con", ctx=ast.Store())],
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="sqlite3", ctx=ast.Load()),
attr="connect",
ctx=ast.Load(),
value=ast.Name(id="sqlite3", ctx=ast.Load()), attr="connect", ctx=ast.Load()
),
args=[
ast.JoinedStr(
values=[
ast.Constant(
value=f"{get_run_tmp_file('test_return_values_')}",
),
ast.Constant(value=f"{get_run_tmp_file('test_return_values_')}"),
ast.FormattedValue(
value=ast.Name(
id="codeflash_iteration",
ctx=ast.Load(),
),
conversion=-1,
value=ast.Name(id="codeflash_iteration", ctx=ast.Load()), conversion=-1
),
ast.Constant(value=".sqlite"),
],
),
]
)
],
keywords=[],
),

@@ -233,9 +202,7 @@ class InjectPerfOnly(ast.NodeTransformer):
targets=[ast.Name(id="codeflash_cur", ctx=ast.Store())],
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="codeflash_con", ctx=ast.Load()),
attr="cursor",
ctx=ast.Load(),
value=ast.Name(id="codeflash_con", ctx=ast.Load()), attr="cursor", ctx=ast.Load()
),
args=[],
keywords=[],

@@ -246,16 +213,14 @@ class InjectPerfOnly(ast.NodeTransformer):
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="codeflash_cur", ctx=ast.Load()),
attr="execute",
ctx=ast.Load(),
value=ast.Name(id="codeflash_cur", ctx=ast.Load()), attr="execute", ctx=ast.Load()
),
args=[
ast.Constant(
value="CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT,"
" test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT,"
" loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB)",
),
" loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB)"
)
],
keywords=[],
),

@@ -268,14 +233,12 @@ class InjectPerfOnly(ast.NodeTransformer):
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="codeflash_con", ctx=ast.Load()),
attr="close",
ctx=ast.Load(),
value=ast.Name(id="codeflash_con", ctx=ast.Load()), attr="close", ctx=ast.Load()
),
args=[],
keywords=[],
),
),
)
)
]
)
return node

@@ -381,31 +344,16 @@ def create_wrapper_function() -> ast.FunctionDef:
targets=[ast.Name(id="test_id", ctx=ast.Store())],
value=ast.JoinedStr(
values=[
ast.FormattedValue(
value=ast.Name(id="test_module_name", ctx=ast.Load()),
conversion=-1,
),
ast.FormattedValue(value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.Name(id="test_class_name", ctx=ast.Load()),
conversion=-1,
),
ast.FormattedValue(value=ast.Name(id="test_class_name", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.Name(id="test_name", ctx=ast.Load()),
conversion=-1,
),
ast.FormattedValue(value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.Name(id="line_id", ctx=ast.Load()),
conversion=-1,
),
ast.FormattedValue(value=ast.Name(id="line_id", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.Name(id="loop_index", ctx=ast.Load()),
conversion=-1,
),
],
ast.FormattedValue(value=ast.Name(id="loop_index", ctx=ast.Load()), conversion=-1),
]
),
lineno=lineno + 1,
),

@@ -414,10 +362,7 @@ def create_wrapper_function() -> ast.FunctionDef:
op=ast.Not(),
operand=ast.Call(
func=ast.Name(id="hasattr", ctx=ast.Load()),
args=[
ast.Name(id="codeflash_wrap", ctx=ast.Load()),
ast.Constant(value="index"),
],
args=[ast.Name(id="codeflash_wrap", ctx=ast.Load()), ast.Constant(value="index")],
keywords=[],
),
),

@@ -425,14 +370,12 @@ def create_wrapper_function() -> ast.FunctionDef:
ast.Assign(
targets=[
ast.Attribute(
value=ast.Name(id="codeflash_wrap", ctx=ast.Load()),
attr="index",
ctx=ast.Store(),
),
value=ast.Name(id="codeflash_wrap", ctx=ast.Load()), attr="index", ctx=ast.Store()
)
],
value=ast.Dict(keys=[], values=[]),
lineno=lineno + 3,
),
)
],
orelse=[],
lineno=lineno + 2,

@@ -442,20 +385,14 @@ def create_wrapper_function() -> ast.FunctionDef:
left=ast.Name(id="test_id", ctx=ast.Load()),
ops=[ast.In()],
comparators=[
ast.Attribute(
value=ast.Name(id="codeflash_wrap", ctx=ast.Load()),
attr="index",
ctx=ast.Load(),
),
ast.Attribute(value=ast.Name(id="codeflash_wrap", ctx=ast.Load()), attr="index", ctx=ast.Load())
],
),
body=[
ast.AugAssign(
target=ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="codeflash_wrap", ctx=ast.Load()),
attr="index",
ctx=ast.Load(),
value=ast.Name(id="codeflash_wrap", ctx=ast.Load()), attr="index", ctx=ast.Load()
),
slice=ast.Name(id="test_id", ctx=ast.Load()),
ctx=ast.Store(),

@@ -463,36 +400,30 @@ def create_wrapper_function() -> ast.FunctionDef:
op=ast.Add(),
value=ast.Constant(value=1),
lineno=lineno + 5,
),
)
],
orelse=[
ast.Assign(
targets=[
ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="codeflash_wrap", ctx=ast.Load()),
attr="index",
ctx=ast.Load(),
value=ast.Name(id="codeflash_wrap", ctx=ast.Load()), attr="index", ctx=ast.Load()
),
slice=ast.Name(id="test_id", ctx=ast.Load()),
ctx=ast.Store(),
),
)
],
value=ast.Constant(value=0),
lineno=lineno + 6,
),
)
],
lineno=lineno + 4,
),
ast.Assign(
targets=[
ast.Name(id="codeflash_test_index", ctx=ast.Store()),
],
targets=[ast.Name(id="codeflash_test_index", ctx=ast.Store())],
value=ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="codeflash_wrap", ctx=ast.Load()),
attr="index",
ctx=ast.Load(),
value=ast.Name(id="codeflash_wrap", ctx=ast.Load()), attr="index", ctx=ast.Load()
),
slice=ast.Name(id="test_id", ctx=ast.Load()),
ctx=ast.Load(),

@@ -503,16 +434,10 @@ def create_wrapper_function() -> ast.FunctionDef:
targets=[ast.Name(id="invocation_id", ctx=ast.Store())],
value=ast.JoinedStr(
values=[
ast.FormattedValue(
value=ast.Name(id="line_id", ctx=ast.Load()),
conversion=-1,
),
ast.FormattedValue(value=ast.Name(id="line_id", ctx=ast.Load()), conversion=-1),
ast.Constant(value="_"),
ast.FormattedValue(
value=ast.Name(id="codeflash_test_index", ctx=ast.Load()),
conversion=-1,
),
],
ast.FormattedValue(value=ast.Name(id="codeflash_test_index", ctx=ast.Load()), conversion=-1),
]
),
lineno=lineno + 8,
),

@@ -524,8 +449,7 @@ def create_wrapper_function() -> ast.FunctionDef:
values=[
ast.Constant(value="!######"),
ast.FormattedValue(
value=ast.Name(id="test_module_name", ctx=ast.Load()),
conversion=-1,
value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1
),
ast.Constant(value=":"),
ast.FormattedValue(

@@ -540,39 +464,23 @@ def create_wrapper_function() -> ast.FunctionDef:
),
conversion=-1,
),
ast.FormattedValue(
value=ast.Name(id="test_name", ctx=ast.Load()),
conversion=-1,
),
ast.FormattedValue(value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.Name(id="function_name", ctx=ast.Load()),
conversion=-1,
),
ast.FormattedValue(value=ast.Name(id="function_name", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.Name(id="loop_index", ctx=ast.Load()),
conversion=-1,
),
ast.FormattedValue(value=ast.Name(id="loop_index", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.Name(id="invocation_id", ctx=ast.Load()),
conversion=-1,
),
ast.FormattedValue(value=ast.Name(id="invocation_id", ctx=ast.Load()), conversion=-1),
ast.Constant(value="######!"),
],
),
]
)
],
keywords=[],
),
)
),
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="gc", ctx=ast.Load()),
attr="disable",
ctx=ast.Load(),
),
func=ast.Attribute(value=ast.Name(id="gc", ctx=ast.Load()), attr="disable", ctx=ast.Load()),
args=[],
keywords=[],
),

@@ -582,9 +490,7 @@ def create_wrapper_function() -> ast.FunctionDef:
targets=[ast.Name(id="counter", ctx=ast.Store())],
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="time", ctx=ast.Load()),
attr="perf_counter_ns",
ctx=ast.Load(),
value=ast.Name(id="time", ctx=ast.Load()), attr="perf_counter_ns", ctx=ast.Load()
),
args=[],
keywords=[],

@@ -605,9 +511,7 @@ def create_wrapper_function() -> ast.FunctionDef:
value=ast.BinOp(
left=ast.Call(
func=ast.Attribute(
value=ast.Name(id="time", ctx=ast.Load()),
attr="perf_counter_ns",
ctx=ast.Load(),
value=ast.Name(id="time", ctx=ast.Load()), attr="perf_counter_ns", ctx=ast.Load()
),
args=[],
keywords=[],

@@ -619,11 +523,7 @@ def create_wrapper_function() -> ast.FunctionDef:
),
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="gc", ctx=ast.Load()),
attr="enable",
ctx=ast.Load(),
),
func=ast.Attribute(value=ast.Name(id="gc", ctx=ast.Load()), attr="enable", ctx=ast.Load()),
args=[],
keywords=[],
),

@@ -641,31 +541,27 @@ def create_wrapper_function() -> ast.FunctionDef:
targets=[ast.Name(id="pickled_return_value", ctx=ast.Store())],
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="pickle", ctx=ast.Load()),
attr="dumps",
ctx=ast.Load(),
value=ast.Name(id="pickle", ctx=ast.Load()), attr="dumps", ctx=ast.Load()
),
args=[ast.Name(id="return_value", ctx=ast.Load())],
keywords=[],
),
lineno=lineno + 15,
),
)
],
orelse=[
ast.Assign(
targets=[ast.Name(id="pickled_return_value", ctx=ast.Store())],
value=ast.Constant(value=None),
lineno=lineno + 16,
),
)
],
# lineno=lineno + 16,
),
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="codeflash_cur", ctx=ast.Load()),
attr="execute",
ctx=ast.Load(),
value=ast.Name(id="codeflash_cur", ctx=ast.Load()), attr="execute", ctx=ast.Load()
),
args=[
ast.Constant(value="INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)"),

@@ -690,9 +586,7 @@ def create_wrapper_function() -> ast.FunctionDef:
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="codeflash_con", ctx=ast.Load()),
attr="commit",
ctx=ast.Load(),
value=ast.Name(id="codeflash_con", ctx=ast.Load()), attr="commit", ctx=ast.Load()
),
args=[],
keywords=[],
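Note: the nested ast.Call/ast.Attribute pyramids above hand-build expressions such as codeflash_cur.execute(...). Parsing the equivalent source and dumping it shows the same tree shape the transformer constructs by hand (standalone sketch, not this commit's code):

import ast

tree = ast.parse("codeflash_cur.execute('SELECT 1')", mode="eval")
# Call(func=Attribute(value=Name(id='codeflash_cur'), attr='execute'), args=[Constant(value='SELECT 1')], ...)
print(ast.dump(tree.body, indent=2))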
@@ -8,10 +8,10 @@ from returns.result import Failure, Result, Success
from codeflash.code_utils.compat import LF

if os.name == "nt":  # Windows
SHELL_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(cf-.*)$", re.M)
SHELL_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(cf-.*)$", re.MULTILINE)
SHELL_RC_EXPORT_PREFIX = "set CODEFLASH_API_KEY="
else:
SHELL_RC_EXPORT_PATTERN = re.compile(r'^export CODEFLASH_API_KEY="?(cf-[^\s"]+)"?$', re.M)
SHELL_RC_EXPORT_PATTERN = re.compile(r'^export CODEFLASH_API_KEY="?(cf-[^\s"]+)"?$', re.MULTILINE)
SHELL_RC_EXPORT_PREFIX = "export CODEFLASH_API_KEY="

@@ -30,18 +30,11 @@ def get_shell_rc_path() -> Path:
"""Get the path to the user's shell configuration file."""
if os.name == "nt":  # on Windows, we use a batch file in the user's home directory
return Path.home() / "codeflash_env.bat"
else:
shell = os.environ.get("SHELL", "/bin/bash").split("/")[-1]
shell_rc_filename = {
"zsh": ".zshrc",
"ksh": ".kshrc",
"csh": ".cshrc",
"tcsh": ".cshrc",
"dash": ".profile",
}.get(
shell, ".bashrc",
)  # map each shell to its config file and default to .bashrc
return Path.home() / shell_rc_filename
shell = os.environ.get("SHELL", "/bin/bash").split("/")[-1]
shell_rc_filename = {"zsh": ".zshrc", "ksh": ".kshrc", "csh": ".cshrc", "tcsh": ".cshrc", "dash": ".profile"}.get(
shell, ".bashrc"
)  # map each shell to its config file and default to .bashrc
return Path.home() / shell_rc_filename

def get_api_key_export_line(api_key: str) -> str:

@@ -61,9 +54,7 @@ def save_api_key_to_rc(api_key) -> Result[str, str]:

if existing_api_key:
# Replace the existing API key line
updated_shell_contents = re.sub(
SHELL_RC_EXPORT_PATTERN, api_key_line, shell_contents,
)
updated_shell_contents = re.sub(SHELL_RC_EXPORT_PATTERN, api_key_line, shell_contents)
action = "Updated CODEFLASH_API_KEY in"
else:
# Append the new API key line

@@ -77,11 +68,11 @@ def save_api_key_to_rc(api_key) -> Result[str, str]:
except PermissionError:
return Failure(
f"💡 I tried adding your Codeflash API key to {shell_rc_path} - but seems like I don't have permissions to do so.{LF}"
f"You'll need to open it yourself and add the following line:{LF}{LF}{api_key_line}{LF}",
f"You'll need to open it yourself and add the following line:{LF}{LF}{api_key_line}{LF}"
)
except FileNotFoundError:
return Failure(
f"💡 I went to save your Codeflash API key to {shell_rc_path}, but noticed that it doesn't exist.{LF}"
f"To ensure your Codeflash API key is automatically loaded into your environment at startup, you can create {shell_rc_path} and add the following line:{LF}"
f"{LF}{api_key_line}{LF}",
f"{LF}{api_key_line}{LF}"
)
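Note: re.M is just an alias for re.MULTILINE, so spelling the flag out above changes readability only; both make ^ and $ match at every line boundary. A quick check:

import re

print(re.M is re.MULTILINE)  # True
pattern = re.compile(r"^export CODEFLASH_API_KEY=(cf-\S+)$", re.MULTILINE)
print(pattern.findall("export CODEFLASH_API_KEY=cf-abc123\n"))  # ['cf-abc123']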
@@ -12,10 +12,7 @@ def humanize_runtime(time_in_ns: int) -> str:

if time_in_ns / 1000 >= 1:
time_micro = float(time_in_ns) / 1000
runtime_human = humanize.precisedelta(
dt.timedelta(microseconds=time_micro),
minimum_unit="microseconds",
)
runtime_human = humanize.precisedelta(dt.timedelta(microseconds=time_micro), minimum_unit="microseconds")

units = re.split(r",|\s", runtime_human)[1]
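Note: a standalone illustration of the humanize.precisedelta call reflowed above; the exact wording of the output comes from the humanize library, so treat the printed text as indicative:

import datetime as dt

import humanize

delta = dt.timedelta(microseconds=1500)
print(humanize.precisedelta(delta, minimum_unit="microseconds"))  # e.g. "1 millisecond and 500 microseconds"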
@@ -50,8 +50,7 @@ class TestFunction:

def discover_unit_tests(
cfg: TestConfig,
discover_only_these_tests: list[str] | None = None,
cfg: TestConfig, discover_only_these_tests: list[str] | None = None
) -> dict[str, list[FunctionCalledInTest]]:
if cfg.test_framework == "pytest":
return discover_tests_pytest(cfg, discover_only_these_tests)

@@ -78,8 +77,7 @@ def run_pytest_discovery_new_process(queue: Queue, cwd: str, tests_root: str) ->

try:
exitcode = pytest.main(
[tests_root, "--collect-only", "-pno:terminal", "-m", "not skip"],
plugins=[PytestCollectionPlugin()],
[tests_root, "--collect-only", "-pno:terminal", "-m", "not skip"], plugins=[PytestCollectionPlugin()]
)
except Exception as e:
logger.exception(f"Failed to collect tests: {e!s}")

@@ -89,9 +87,7 @@ def run_pytest_discovery_new_process(queue: Queue, cwd: str, tests_root: str) ->
queue.put((exitcode, tests, pytest_rootdir))

def parse_pytest_collection_results(
pytest_tests: str,
) -> list[TestsInFile]:
def parse_pytest_collection_results(pytest_tests: str) -> list[TestsInFile]:
test_results: list[TestsInFile] = []
for test in pytest_tests:
test_class = None

@@ -106,14 +102,13 @@ def parse_pytest_collection_results(
test_function=test.name,
test_suite=None,  # not used in pytest until now
test_type=test_type,
),
)
)
return test_results

def discover_tests_pytest(
cfg: TestConfig,
discover_only_these_tests: list[str] | None = None,
cfg: TestConfig, discover_only_these_tests: list[str] | None = None
) -> dict[str, list[FunctionCalledInTest]]:
tests_root = cfg.tests_root
project_root = cfg.project_root_path

@@ -140,8 +135,7 @@ def discover_tests_pytest(

def discover_tests_unittest(
cfg: TestConfig,
discover_only_these_tests: list[str] | None = None,
cfg: TestConfig, discover_only_these_tests: list[str] | None = None
) -> dict[str, list[FunctionCalledInTest]]:
tests_root: Path = cfg.tests_root
loader: unittest.TestLoader = unittest.TestLoader()

@@ -184,9 +178,7 @@ def discover_tests_unittest(
if not hasattr(test, "_testMethodName") and hasattr(test, "_tests"):
for test_2 in test._tests:
if not hasattr(test_2, "_testMethodName"):
logger.warning(
f"Didn't find tests for {test_2}",
)  # it goes deeper?
logger.warning(f"Didn't find tests for {test_2}")  # it goes deeper?
continue
details = get_test_details(test_2)
if details is not None:

@@ -207,8 +199,7 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N

def process_test_files(
file_to_test_map: dict[str, list[TestsInFile]],
cfg: TestConfig,
file_to_test_map: dict[str, list[TestsInFile]], cfg: TestConfig
) -> dict[str, list[FunctionCalledInTest]]:
project_root_path = cfg.project_root_path
test_framework = cfg.test_framework

@@ -230,9 +221,7 @@ def process_test_files(
function_name = re.split(r"[\[\]]", function)[0]
parameters = re.split(r"[\[\]]", function)[1]
if name.name == function_name and name.type == "function":
test_functions.add(
TestFunction(name.name, None, parameters, functions[i].test_type),
)
test_functions.add(TestFunction(name.name, None, parameters, functions[i].test_type))
elif name.name == function and name.type == "function":
test_functions.add(TestFunction(name.name, None, None, functions[i].test_type))
break

@@ -248,24 +237,17 @@ def process_test_files(
and f".{name.name}." in def_name.full_name
):
for function in functions_to_search:
(
is_parameterized,
new_function,
parameters,
) = discover_parameters_unittest(function)
(is_parameterized, new_function, parameters) = discover_parameters_unittest(function)

if is_parameterized and new_function == def_name.name:
test_functions.add(
TestFunction(
def_name.name,
name.name,
parameters,
functions[0].test_type,
),  # A test file must not have more than one test type
def_name.name, name.name, parameters, functions[0].test_type
)  # A test file must not have more than one test type
)
elif function == def_name.name:
test_functions.add(
TestFunction(def_name.name, name.name, None, functions[0].test_type),
TestFunction(def_name.name, name.name, None, functions[0].test_type)
)

test_functions_list = list(test_functions)

@@ -285,10 +267,7 @@ def process_test_files(
scope_parameters = test_functions_list[index].parameters
test_type = test_functions_list[index].test_type
try:
definition = name.goto(
follow_imports=True,
follow_builtin_imports=False,
)
definition = name.goto(follow_imports=True, follow_builtin_imports=False)
except Exception as e:
logger.exception(str(e))
continue

@@ -306,9 +285,7 @@ def process_test_files(
if test_framework == "unittest":
scope_test_function += "_" + scope_parameters
full_name_without_module_prefix = definition[0].full_name.replace(
definition[0].module_name + ".",
"",
1,
definition[0].module_name + ".", "", 1
)
qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}"
function_to_test_map[qualified_name_with_modules_from_root].append(

@@ -320,11 +297,8 @@ def process_test_files(
test_suite=scope_test_suite,
test_type=test_type,
),
position=CodePosition(
line_no=name.line,
col_no=name.column,
),
),
position=CodePosition(line_no=name.line, col_no=name.column),
)
)
deduped_function_to_test_map = {}
for function, tests in function_to_test_map.items():
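Note: run_pytest_discovery_new_process above runs pytest in collect-only mode with a custom plugin. A minimal standalone sketch of that collection pattern (CollectOnlyPlugin here is illustrative, not the project's PytestCollectionPlugin):

import pytest

class CollectOnlyPlugin:
    def __init__(self) -> None:
        self.nodeids: list[str] = []

    def pytest_collection_finish(self, session) -> None:
        # standard pytest hook: session.items holds every collected test item
        self.nodeids.extend(item.nodeid for item in session.items)

plugin = CollectOnlyPlugin()
exitcode = pytest.main(["tests", "--collect-only", "-m", "not skip"], plugins=[plugin])
print(len(plugin.nodeids), "tests collected, exit code", exitcode)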
@@ -75,7 +75,7 @@ class FunctionVisitor(cst.CSTVisitor):
parents=list(reversed(ast_parents)),
starting_line=pos.start.line,
ending_line=pos.end.line,
),
)
)

@@ -89,11 +89,7 @@ class FunctionWithReturnStatement(ast.NodeVisitor):
# Check if the function has a return statement and add it to the list
if function_has_return_statement(node):
self.functions.append(
FunctionToOptimize(
function_name=node.name,
file_path=self.file_path,
parents=self.ast_path[:],
),
FunctionToOptimize(function_name=node.name, file_path=self.file_path, parents=self.ast_path[:])
)
# Continue visiting the body of the function to find nested functions
self.generic_visit(node)

@@ -170,7 +166,7 @@ def get_functions_to_optimize(
bool(optimize_all),
bool(replay_test),
bool(file),
],
]
)
<= 1
), "Only one of optimize_all, replay_test, or file should be provided"

@@ -180,9 +176,7 @@ def get_functions_to_optimize(
functions = get_all_files_and_functions(Path(optimize_all))
elif replay_test is not None:
functions = get_all_replay_test_functions(
replay_test=replay_test,
test_cfg=test_cfg,
project_root_path=project_root,
replay_test=replay_test, test_cfg=test_cfg, project_root_path=project_root
)
elif file is not None:

@@ -213,11 +207,7 @@ def get_functions_to_optimize(
ph("cli-optimizing-git-diff")
functions = get_functions_within_git_diff()
filtered_modified_functions, functions_count = filter_functions(
functions,
test_cfg.tests_root,
ignore_paths,
project_root,
module_root,
functions, test_cfg.tests_root, ignore_paths, project_root, module_root
)
logger.info(f"Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize")
return filtered_modified_functions, functions_count

@@ -277,9 +267,7 @@ def find_all_functions_in_file(file_path: Path) -> dict[str, list[FunctionToOpti

def get_all_replay_test_functions(
replay_test: str,
test_cfg: TestConfig,
project_root_path: Path,
replay_test: str, test_cfg: TestConfig, project_root_path: Path
) -> dict[str, list[FunctionToOptimize]]:
function_tests = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test])
# Get the absolute file paths for each function, excluding class name if present

@@ -295,8 +283,7 @@ def get_all_replay_test_functions(
module_path_parts[-1]
if module_path_parts
and is_class_defined_in_file(
module_path_parts[-1],
Path(project_root_path, *module_path_parts[:-1]).with_suffix(".py"),
module_path_parts[-1], Path(project_root_path, *module_path_parts[:-1]).with_suffix(".py")
)
else None
)

@@ -310,9 +297,7 @@ def get_all_replay_test_functions(
file_path = Path(project_root_path, *file_path_parts).with_suffix(".py")
file_to_functions_map[file_path].append((function, function_name, class_name))
for file_path, functions in file_to_functions_map.items():
all_valid_functions: dict[str, list[FunctionToOptimize]] = find_all_functions_in_file(
file_path=file_path,
)
all_valid_functions: dict[str, list[FunctionToOptimize]] = find_all_functions_in_file(file_path=file_path)
filtered_list = []
for function in functions:
function_name, function_name_only, class_name = function

@@ -321,7 +306,7 @@ def get_all_replay_test_functions(
valid_function
for valid_function in all_valid_functions[file_path]
if valid_function.qualified_name == function_name
],
]
)
if len(filtered_list):
filtered_valid_functions[file_path] = filtered_list

@@ -341,19 +326,13 @@ def is_git_repo(file_path: str) -> bool:
def ignored_submodule_paths(module_root: str) -> list[str]:
if is_git_repo(module_root):
git_repo = git.Repo(module_root, search_parent_directories=True)
return [
Path(git_repo.working_tree_dir, submodule.path).resolve() for submodule in git_repo.submodules
]
return [Path(git_repo.working_tree_dir, submodule.path).resolve() for submodule in git_repo.submodules]
return []

class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
def __init__(
self,
file_name: Path,
function_or_method_name: str,
class_name: str | None = None,
line_no: int | None = None,
self, file_name: Path, function_or_method_name: str, class_name: str | None = None, line_no: int | None = None
) -> None:
self.file_name = file_name
self.class_name = class_name

@@ -374,7 +353,7 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
bool(node.args.kwarg),
bool(node.args.posonlyargs),
bool(node.args.vararg),
),
)
)

def visit_ClassDef(self, node: ast.ClassDef) -> None:

@@ -410,10 +389,7 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):

def inspect_top_level_functions_or_methods(
file_name: Path,
function_or_method_name: str,
class_name: str | None = None,
line_no: int | None = None,
file_name: Path, function_or_method_name: str, class_name: str | None = None, line_no: int | None = None
) -> FunctionProperties:
with open(file_name, encoding="utf8") as file:
try:

@@ -422,10 +398,7 @@ def inspect_top_level_functions_or_methods(
logger.exception(e)
return False
visitor = TopLevelFunctionOrMethodVisitor(
file_name=file_name,
function_or_method_name=function_or_method_name,
class_name=class_name,
line_no=line_no,
file_name=file_name, function_or_method_name=function_or_method_name, class_name=class_name, line_no=line_no
)
visitor.visit(ast_module)
staticmethod_class_name = visitor.class_name if visitor.is_staticmethod else None

@@ -495,9 +468,7 @@ def filter_functions(
path = Path(function.file_path).name
if path in blocklist_funcs and function.function_name in blocklist_funcs[path]:
functions.remove(function)
logger.debug(
f"Skipping {function.function_name} in {path} as it has already been optimized",
)
logger.debug(f"Skipping {function.function_name} in {path} as it has already been optimized")
continue

filtered_modified_functions[file_path] = functions

@@ -517,12 +488,7 @@ def filter_functions(
return {Path(k): v for k, v in filtered_modified_functions.items() if v}, functions_count

def filter_files_optimized(
file_path: Path,
tests_root: Path,
ignore_paths: list[Path],
module_root: Path,
) -> bool:
def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list[Path], module_root: Path) -> bool:
"""Optimized version of the filter_functions function above.

Takes in file paths and returns the count of files that are to be optimized.

@@ -530,9 +496,7 @@ def filter_files_optimized(
submodule_paths = None
if file_path.is_relative_to(tests_root):
return False
if file_path in ignore_paths or any(
file_path.is_relative_to(ignore_path) for ignore_path in ignore_paths
):
if file_path in ignore_paths or any(file_path.is_relative_to(ignore_path) for ignore_path in ignore_paths):
return False
if path_belongs_to_site_packages(file_path):
return False
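Note: filter_files_optimized above leans on Path.is_relative_to (Python 3.9+) to exclude tests and ignored directories; the paths in this standalone sketch are made up:

from pathlib import Path

file_path = Path("/repo/tests/test_example.py")
print(file_path.is_relative_to(Path("/repo/tests")))  # True -> filtered out
print(file_path.is_relative_to(Path("/repo/src")))  # False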
@@ -16,9 +16,7 @@ from codeflash.telemetry.sentry import init_sentry
def main() -> None:
"""Entry point for the codeflash command-line interface."""
paneled_text(
CODEFLASH_LOGO,
panel_args={"title": "https://codeflash.ai", "expand": False},
text_args={"style": "bold gold3"},
CODEFLASH_LOGO, panel_args={"title": "https://codeflash.ai", "expand": False}, text_args={"style": "bold gold3"}
)
args = parse_args()
if args.command:
@@ -1,7 +1,7 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, Generator, Iterator, Optional
from typing import Iterator, Optional

from jedi.api.classes import Name
from pydantic import BaseModel

@@ -66,9 +66,7 @@ class TestFiles(BaseModel):
test_files: list[TestFile]

def get_by_type(self, test_type: TestType) -> TestFiles:
return TestFiles(
test_files=[test_file for test_file in self.test_files if test_file.test_type == test_type],
)
return TestFiles(test_files=[test_file for test_file in self.test_files if test_file.test_type == test_type])

def add(self, test_file: TestFile) -> None:
if test_file not in self.test_files:

@@ -77,29 +75,17 @@ class TestFiles(BaseModel):
raise ValueError("Test file already exists in the list")

def get_by_original_file_path(self, file_path: Path) -> TestFile | None:
return next(
(test_file for test_file in self.test_files if test_file.original_file_path == file_path),
None,
)
return next((test_file for test_file in self.test_files if test_file.original_file_path == file_path), None)

def get_test_type_by_instrumented_file_path(self, file_path: Path) -> TestType | None:
return next(
(
test_file.test_type
for test_file in self.test_files
if test_file.instrumented_file_path == file_path
),
(test_file.test_type for test_file in self.test_files if test_file.instrumented_file_path == file_path),
None,
)

def get_test_type_by_original_file_path(self, file_path: Path) -> TestType | None:
return next(
(
test_file.test_type
for test_file in self.test_files
if test_file.original_file_path == file_path
),
None,
(test_file.test_type for test_file in self.test_files if test_file.original_file_path == file_path), None
)

def __iter__(self) -> Iterator[TestFile]:
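Note: the TestFiles lookups above all use next() over a generator with a None default, so the first match wins and a miss returns None instead of raising StopIteration. A standalone sketch:

paths = ["tests/test_a.py", "tests/test_b.py"]
print(next((p for p in paths if p.endswith("test_b.py")), None))  # tests/test_b.py
print(next((p for p in paths if p.endswith("test_c.py")), None))  # None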
@@ -36,9 +36,7 @@ def belongs_to_function(name: Name, function_name: str) -> bool:

def get_type_annotation_context(
function: FunctionToOptimize,
jedi_script: jedi.Script,
project_root_path: Path,
function: FunctionToOptimize, jedi_script: jedi.Script, project_root_path: Path
) -> tuple[list[FunctionSource], set[tuple[str, str]]]:
function_name: str = function.function_name
file_path: Path = function.file_path

@@ -53,18 +51,11 @@ def get_type_annotation_context(
contextual_dunder_methods = set()

def get_annotation_source(
j_script: jedi.Script,
name: str,
node_parents: list[FunctionParent],
line_no: int,
col_no: str,
j_script: jedi.Script, name: str, node_parents: list[FunctionParent], line_no: int, col_no: str
) -> None:
try:
definition: list[Name] = j_script.goto(
line=line_no,
column=col_no,
follow_imports=True,
follow_builtin_imports=False,
line=line_no, column=col_no, follow_imports=True, follow_builtin_imports=False
)
except Exception as ex:
if hasattr(name, "full_name"):

@@ -82,15 +73,7 @@ def get_type_annotation_context(
and not path_belongs_to_site_packages(definition_path)
and not belongs_to_function(definition[0], function_name)
):
source_code = get_code(
[
FunctionToOptimize(
definition[0].name,
definition_path,
node_parents[:-1],
),
],
)
source_code = get_code([FunctionToOptimize(definition[0].name, definition_path, node_parents[:-1])])
if source_code[0]:
sources.append(
FunctionSource(

@@ -98,25 +81,21 @@ def get_type_annotation_context(
jedi_definition=definition[0],
source_code=source_code[0],
file_path=definition_path,
qualified_name=definition[0].full_name.removeprefix(
definition[0].module_name + ".",
),
qualified_name=definition[0].full_name.removeprefix(definition[0].module_name + "."),
only_function_name=definition[0].name,
),
)
)
contextual_dunder_methods.update(source_code[1])

def visit_children(
node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module,
node_parents: list[FunctionParent],
node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module, node_parents: list[FunctionParent]
) -> None:
child: ast.AST | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module
for child in ast.iter_child_nodes(node):
visit(child, node_parents)

def visit_all_annotation_children(
node: ast.Subscript | ast.Name | ast.BinOp,
node_parents: list[FunctionParent],
node: ast.Subscript | ast.Name | ast.BinOp, node_parents: list[FunctionParent]
) -> None:
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
visit_all_annotation_children(node.left, node_parents)

@@ -165,8 +144,7 @@ def get_type_annotation_context(

def get_function_variables_definitions(
function_to_optimize: FunctionToOptimize,
project_root_path: Path,
function_to_optimize: FunctionToOptimize, project_root_path: Path
) -> tuple[list[FunctionSource], set[tuple[str, str]]]:
function_name = function_to_optimize.function_name
file_path = function_to_optimize.file_path

@@ -181,20 +159,16 @@ def get_function_variables_definitions(
if ref.full_name:
if function_to_optimize.parents:
# Check if the reference belongs to the specified class when FunctionParent is provided
if belongs_to_class(
ref,
function_to_optimize.parents[-1].name,
) and belongs_to_function(ref, function_name):
if belongs_to_class(ref, function_to_optimize.parents[-1].name) and belongs_to_function(
ref, function_name
):
names.append(ref)
elif belongs_to_function(ref, function_name):
names.append(ref)

for name in names:
try:
definitions: list[Name] = name.goto(
follow_imports=True,
follow_builtin_imports=False,
)
definitions: list[Name] = name.goto(follow_imports=True, follow_builtin_imports=False)
except Exception as e:
try:
logger.exception(f"Error while getting definition for {name.full_name}: {e}")

@@ -221,13 +195,7 @@ def get_function_variables_definitions(
parents = [FunctionParent(m.group(1), "ClassDef")]

source_code = get_code(
[
FunctionToOptimize(
function_name=definitions[0].name,
file_path=definition_path,
parents=parents,
),
],
[FunctionToOptimize(function_name=definitions[0].name, file_path=definition_path, parents=parents)]
)
if source_code[0]:
sources.append(

@@ -236,24 +204,18 @@ def get_function_variables_definitions(
jedi_definition=definition,
source_code=source_code[0],
file_path=definition_path,
qualified_name=definition.full_name.removeprefix(
definition.module_name + ".",
),
qualified_name=definition.full_name.removeprefix(definition.module_name + "."),
only_function_name=definition.name,
),
)
)
contextual_dunder_methods.update(source_code[1])
annotation_sources, annotation_dunder_methods = get_type_annotation_context(
function_to_optimize,
script,
project_root_path,
function_to_optimize, script, project_root_path
)
sources[:0] = annotation_sources  # prepend the annotation sources
contextual_dunder_methods.update(annotation_dunder_methods)
existing_fully_qualified_names = set()
no_parent_sources: dict[Path, dict[str, set[FunctionSource]]] = defaultdict(
lambda: defaultdict(set),
)
no_parent_sources: dict[Path, dict[str, set[FunctionSource]]] = defaultdict(lambda: defaultdict(set))
parent_sources = set()
for source in sources:
if (fully_qualified_name := source.fully_qualified_name) not in existing_fully_qualified_names:

@@ -269,10 +231,7 @@ def get_function_variables_definitions(
or source.qualified_name.rpartition(".")[0] not in no_parent_sources[source.file_path]
]
deduped_no_parent_sources = [
source
for k1 in no_parent_sources
for k2 in no_parent_sources[k1]
for source in no_parent_sources[k1][k2]
source for k1 in no_parent_sources for k2 in no_parent_sources[k1] for source in no_parent_sources[k1][k2]
]
return deduped_no_parent_sources + deduped_parent_sources, contextual_dunder_methods

@@ -288,10 +247,7 @@ def get_constrained_function_context_and_helper_functions(
) -> tuple[str, list[FunctionSource], set[tuple[str, str]]]:
# TODO: Not just do static analysis, but also find the datatypes of function arguments by running the existing
# unittests and inspecting the arguments to resolve the real definitions and dependencies.
helper_functions, dunder_methods = get_function_variables_definitions(
function_to_optimize,
project_root_path,
)
helper_functions, dunder_methods = get_function_variables_definitions(function_to_optimize, project_root_path)
tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
code_to_optimize_tokens = tokenizer.encode(code_to_optimize)
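Note: the goto calls collapsed above resolve a name to its definition with jedi. A standalone sketch of the same API (the source and positions are made up; the resolved module name varies by platform):

import jedi

source = "from os.path import join\nresult = join('a', 'b')\n"
script = jedi.Script(source, path="example.py")
# resolve "join" on line 2, column 9, following imports to the real definition
definitions = script.goto(line=2, column=9, follow_imports=True, follow_builtin_imports=False)
print(definitions[0].module_name)  # e.g. "posixpath" on Linux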
@@ -14,22 +14,12 @@ import libcst as cst
from returns.pipeline import is_successful
from returns.result import Failure, Success

from codeflash.api.aiservice import (
AiServiceClient,
LocalAiServiceClient,
)
from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
from codeflash.cli_cmds.console import code_print, logger
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_extractor import (
add_needed_imports_from_module,
extract_code,
find_preexisting_objects,
)
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, extract_code, find_preexisting_objects
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
from codeflash.code_utils.code_utils import (
get_run_tmp_file,
module_name_from_file_path,
)
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
from codeflash.code_utils.config_consts import (
INDIVIDUAL_TESTCASE_TIMEOUT,
N_CANDIDATES,

@@ -37,21 +27,11 @@ from codeflash.code_utils.config_consts import (
TOTAL_LOOPING_TIME,
)
from codeflash.code_utils.formatter import format_code, sort_imports
from codeflash.code_utils.instrument_existing_tests import (
inject_profiling_into_existing_test,
)
from codeflash.code_utils.remove_generated_tests import (
remove_functions_from_generated_tests,
)
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.discovery.discover_unit_tests import (
discover_unit_tests,
)
from codeflash.discovery.functions_to_optimize import (
FunctionParent,
FunctionToOptimize,
get_functions_to_optimize,
)
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize, get_functions_to_optimize
from codeflash.models.ExperimentMetadata import ExperimentMetadata
from codeflash.models.models import (
BestOptimization,

@@ -64,9 +44,7 @@ from codeflash.models.models import (
TestFile,
TestFiles,
)
from codeflash.optimization.function_context import (
get_constrained_function_context_and_helper_functions,
)
from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions
from codeflash.result.create_pr import check_create_pr, existing_tests_source_for
from codeflash.result.critic import performance_gain, quantity_of_tests_critic, speedup_critic
from codeflash.result.explanation import Explanation

@@ -83,15 +61,9 @@ if TYPE_CHECKING:

from returns.result import Result

from codeflash.api.aiservice import (
OptimizedCandidate,
)
from codeflash.discovery.discover_unit_tests import (
FunctionCalledInTest,
)
from codeflash.models.models import (
FunctionSource,
)
from codeflash.api.aiservice import OptimizedCandidate
from codeflash.discovery.discover_unit_tests import FunctionCalledInTest
from codeflash.models.models import FunctionSource

class Optimizer:

@@ -121,10 +93,7 @@ class Optimizer:
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]
num_optimizable_functions: int

(
file_to_funcs_to_optimize,
num_optimizable_functions,
) = get_functions_to_optimize(
(file_to_funcs_to_optimize, num_optimizable_functions) = get_functions_to_optimize(
optimize_all=self.args.all,
replay_test=self.args.replay_test,
file=self.args.file,

@@ -140,24 +109,15 @@ class Optimizer:
function_iterator_count: int = 0

try:
ph(
"cli-optimize-functions-to-optimize",
{"num_functions": num_optimizable_functions},
)
ph("cli-optimize-functions-to-optimize", {"num_functions": num_optimizable_functions})
if num_optimizable_functions == 0:
logger.info("No functions found to optimize. Exiting...")
return

logger.info(f"Discovering existing unit tests in {self.test_cfg.tests_root} ...")
function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(
self.test_cfg,
)
num_discovered_tests: int = sum(
[len(value) for value in function_to_tests.values()],
)
logger.info(
f"Discovered {num_discovered_tests} existing unit tests in {self.test_cfg.tests_root}",
)
function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg)
num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()])
logger.info(f"Discovered {num_discovered_tests} existing unit tests in {self.test_cfg.tests_root}")
ph("cli-optimize-discovered-tests", {"num_tests": num_discovered_tests})
for path in file_to_funcs_to_optimize:
logger.info(f"Examining file {path} ...")

@@ -168,14 +128,10 @@ class Optimizer:
function_iterator_count += 1
logger.info(
f"Optimizing function {function_iterator_count} of {num_optimizable_functions} - "
f"{function_to_optimize.qualified_name}",
f"{function_to_optimize.qualified_name}"
)

best_optimization = self.optimize_function(
function_to_optimize,
function_to_tests,
original_code,
)
best_optimization = self.optimize_function(function_to_optimize, function_to_tests, original_code)
self.test_files = TestFiles(test_files=[])
if is_successful(best_optimization):
optimizations_found += 1

@@ -207,11 +163,7 @@ class Optimizer:
logger.debug(f"Function Trace ID: {function_trace_id}")
ph("cli-optimize-function-start", {"function_trace_id": function_trace_id})
self.cleanup_leftover_test_return_values()
ctx_result = self.get_code_optimization_context(
function_to_optimize,
self.args.project_root,
original_code,
)
ctx_result = self.get_code_optimization_context(function_to_optimize, self.args.project_root, original_code)
if not is_successful(ctx_result):
return Failure(ctx_result.failure())
code_context: CodeOptimizationContext = ctx_result.unwrap()

@@ -236,8 +188,7 @@ class Optimizer:
)

instrumented_unittests_created_for_function = self.instrument_existing_tests(
function_to_optimize=function_to_optimize,
function_to_tests=function_to_tests,
function_to_optimize=function_to_optimize, function_to_tests=function_to_tests
|
||||
)
|
||||
|
||||
logger.info(f"Generating new tests for function {function_to_optimize.function_name} ...")
|
||||
|
|
@ -257,12 +208,7 @@ class Optimizer:
|
|||
|
||||
count_tests = len(generated_tests.generated_tests)
|
||||
generated_tests_paths = [
|
||||
get_test_file_path(
|
||||
self.args.tests_root,
|
||||
function_to_optimize.function_name,
|
||||
i,
|
||||
)
|
||||
for i in range(count_tests)
|
||||
get_test_file_path(self.args.tests_root, function_to_optimize.function_name, i) for i in range(count_tests)
|
||||
]
|
||||
|
||||
for i, generated_test in enumerate(generated_tests.generated_tests):
|
||||
|
|
@ -275,7 +221,7 @@ class Optimizer:
|
|||
original_file_path=None,
|
||||
original_source=generated_test.generated_original_test_source,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
)
|
||||
)
|
||||
logger.info(f"Generated test {i + 1}/{count_tests}:")
|
||||
code_print(generated_test.generated_original_test_source)
|
||||
|
|
@ -297,15 +243,12 @@ class Optimizer:
|
|||
# TODO: Postprocess the optimized function to include the original docstring and such
|
||||
|
||||
best_optimization = None
|
||||
for u, candidates in enumerate(
|
||||
[optimizations_set.control, optimizations_set.experiment],
|
||||
):
|
||||
for u, candidates in enumerate([optimizations_set.control, optimizations_set.experiment]):
|
||||
if candidates is None:
|
||||
continue
|
||||
|
||||
tests_in_file: list[FunctionCalledInTest] = function_to_tests.get(
|
||||
function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root),
|
||||
[],
|
||||
function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root), []
|
||||
)
|
||||
|
||||
best_optimization = self.determine_best_candidate(
|
||||
|
|
@ -315,16 +258,13 @@ class Optimizer:
|
|||
original_code=original_code,
|
||||
original_code_baseline=original_code_baseline,
|
||||
original_helper_code=original_helper_code,
|
||||
function_trace_id=function_trace_id[:-4] + f"EXP{u}"
|
||||
if should_run_experiment
|
||||
else function_trace_id,
|
||||
function_trace_id=function_trace_id[:-4] + f"EXP{u}" if should_run_experiment else function_trace_id,
|
||||
only_run_this_test_function=tests_in_file,
|
||||
)
|
||||
ph("cli-optimize-function-finished", {"function_trace_id": function_trace_id})
|
||||
|
||||
generated_tests = remove_functions_from_generated_tests(
|
||||
generated_tests=generated_tests,
|
||||
test_functions_to_remove=test_functions_to_remove,
|
||||
generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
|
||||
)
|
||||
|
||||
if best_optimization:
|
||||
|
|
@ -340,12 +280,7 @@ class Optimizer:
|
|||
file_path=function_to_optimize.file_path,
|
||||
)
|
||||
|
||||
self.log_successful_optimization(
|
||||
explanation,
|
||||
function_to_optimize,
|
||||
function_trace_id,
|
||||
generated_tests,
|
||||
)
|
||||
self.log_successful_optimization(explanation, function_to_optimize, function_trace_id, generated_tests)
|
||||
|
||||
self.replace_function_and_helpers_with_optimized_code(
|
||||
code_context=code_context,
|
||||
|
|
@ -355,9 +290,7 @@ class Optimizer:
|
|||
)
|
||||
|
||||
new_code, new_helper_code = self.reformat_code_and_helpers(
|
||||
code_context.helper_functions,
|
||||
explanation.file_path,
|
||||
original_code,
|
||||
code_context.helper_functions, explanation.file_path, original_code
|
||||
)
|
||||
|
||||
existing_tests = existing_tests_source_for(
|
||||
|
|
@ -377,7 +310,7 @@ class Optimizer:
|
|||
explanation=explanation,
|
||||
existing_tests_source=existing_tests,
|
||||
generated_original_test_source="\n".join(
|
||||
[test.generated_original_test_source for test in generated_tests.generated_tests],
|
||||
[test.generated_original_test_source for test in generated_tests.generated_tests]
|
||||
),
|
||||
function_trace_id=function_trace_id,
|
||||
)
|
||||
|
|
@ -386,17 +319,14 @@ class Optimizer:
|
|||
# a) Error propagation, where error in one function can cause the next optimization to fail
|
||||
# b) Performance estimates become unstable, as the runtime of an optimization might be
|
||||
# dependent on the runtime of the previous optimization
|
||||
self.write_code_and_helpers(
|
||||
original_code,
|
||||
original_helper_code,
|
||||
function_to_optimize.file_path,
|
||||
)
|
||||
self.write_code_and_helpers(original_code, original_helper_code, function_to_optimize.file_path)
|
||||
for generated_test_path in generated_tests_paths:
|
||||
generated_test_path.unlink(missing_ok=True)
|
||||
for test_paths in instrumented_unittests_created_for_function:
|
||||
test_paths.unlink(missing_ok=True)
|
||||
if not best_optimization:
|
||||
return Failure(f"No best optimizations found for function {function_to_optimize.qualified_name}")
|
||||
logger.info("----------------")
|
||||
return Success(best_optimization)
|
||||
|
||||
def determine_best_candidate(
|
||||
|
|
@ -419,7 +349,7 @@ class Optimizer:
|
|||
is_correct = {}
|
||||
|
||||
logger.info(
|
||||
f"Determining best optimized candidate (out of {len(candidates)}) for {function_to_optimize.qualified_name} ...",
|
||||
f"Determining best optimized candidate (out of {len(candidates)}) for {function_to_optimize.qualified_name} ..."
|
||||
)
|
||||
try:
|
||||
for candidate_index, candidate in enumerate(candidates, start=1):
|
||||
|
|
@ -436,21 +366,12 @@ class Optimizer:
|
|||
)
|
||||
if not did_update:
|
||||
logger.warning(
|
||||
"No functions were replaced in the optimized code. Skipping optimization candidate.",
|
||||
"No functions were replaced in the optimized code. Skipping optimization candidate."
|
||||
)
|
||||
continue
|
||||
except (
|
||||
ValueError,
|
||||
SyntaxError,
|
||||
cst.ParserSyntaxError,
|
||||
AttributeError,
|
||||
) as e:
|
||||
except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e:
|
||||
logger.error(e)
|
||||
self.write_code_and_helpers(
|
||||
original_code,
|
||||
original_helper_code,
|
||||
function_to_optimize.file_path,
|
||||
)
|
||||
self.write_code_and_helpers(original_code, original_helper_code, function_to_optimize.file_path)
|
||||
continue
|
||||
|
||||
# Run generated tests if at least one of them passed
|
||||
|
|
@ -477,34 +398,27 @@ class Optimizer:
|
|||
optimized_runtimes[candidate.optimization_id] = best_test_runtime
|
||||
is_correct[candidate.optimization_id] = True
|
||||
perf_gain = performance_gain(
|
||||
original_runtime_ns=original_code_baseline.runtime,
|
||||
optimized_runtime_ns=best_test_runtime,
|
||||
original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_test_runtime
|
||||
)
|
||||
speedup_ratios[candidate.optimization_id] = perf_gain
|
||||
loop_count = (
|
||||
max(all_loop_indices)
|
||||
if (
|
||||
all_loop_indices := {
|
||||
result.loop_index for result in candidate_result.best_test_results
|
||||
}
|
||||
)
|
||||
if (all_loop_indices := {result.loop_index for result in candidate_result.best_test_results})
|
||||
else 1
|
||||
)
|
||||
logger.info(
|
||||
f"Candidate code runtime measured over {loop_count} loop{'s' if loop_count > 1 else ''}: {humanize_runtime(best_test_runtime)} per full loop.\n"
|
||||
f"Speedup ratio: {perf_gain:.3f}",
|
||||
f"Speedup ratio: {perf_gain:.3f}"
|
||||
)
|
||||
|
||||
if speedup_critic(
|
||||
candidate_result,
|
||||
original_code_baseline.runtime,
|
||||
best_runtime_until_now,
|
||||
candidate_result, original_code_baseline.runtime, best_runtime_until_now
|
||||
) and quantity_of_tests_critic(candidate_result):
|
||||
logger.info("This candidate is faster than the previous best candidate.")
|
||||
logger.info(
|
||||
f"Original runtime: {humanize_runtime(original_code_baseline.runtime)}\n"
|
||||
f"Best test runtime: {humanize_runtime(candidate_result.best_test_runtime)}\n"
|
||||
f"Speedup ratio: {perf_gain:.3f}",
|
||||
f"Speedup ratio: {perf_gain:.3f}"
|
||||
)
|
||||
best_optimization = BestOptimization(
|
||||
candidate=candidate,
|
||||
|
|
@ -514,18 +428,9 @@ class Optimizer:
|
|||
)
|
||||
best_runtime_until_now = best_test_runtime
|
||||
|
||||
self.write_code_and_helpers(
|
||||
original_code,
|
||||
original_helper_code,
|
||||
function_to_optimize.file_path,
|
||||
)
|
||||
logger.info("----------------")
|
||||
self.write_code_and_helpers(original_code, original_helper_code, function_to_optimize.file_path)
|
||||
except KeyboardInterrupt as e:
|
||||
self.write_code_and_helpers(
|
||||
original_code,
|
||||
original_helper_code,
|
||||
function_to_optimize.file_path,
|
||||
)
|
||||
self.write_code_and_helpers(original_code, original_helper_code, function_to_optimize.file_path)
|
||||
logger.exception(f"Optimization interrupted: {e}")
|
||||
raise e
|
||||
|
||||
|
|
@ -545,9 +450,7 @@ class Optimizer:
|
|||
function_trace_id: str,
|
||||
generated_tests: GeneratedTestsList,
|
||||
) -> None:
|
||||
logger.info(
|
||||
f"⚡️ Optimization successful! 📄 {function_to_optimize.qualified_name} in {explanation.file_path}",
|
||||
)
|
||||
logger.info(f"⚡️ Optimization successful! 📄 {function_to_optimize.qualified_name} in {explanation.file_path}")
|
||||
logger.info(f"📈 {explanation.perf_improvement_line}")
|
||||
logger.info(f"Explanation: \n{explanation.to_console_string()}")
|
||||
|
||||
|
|
@ -572,11 +475,7 @@ class Optimizer:
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
def write_code_and_helpers(
|
||||
original_code: str,
|
||||
original_helper_code: dict[Path, str],
|
||||
path: Path,
|
||||
) -> None:
|
||||
def write_code_and_helpers(original_code: str, original_helper_code: dict[Path, str], path: Path) -> None:
|
||||
with path.open("w", encoding="utf8") as f:
|
||||
f.write(original_code)
|
||||
for module_abspath in original_helper_code:
|
||||
|
|
@ -584,29 +483,20 @@ class Optimizer:
|
|||
f.write(original_helper_code[module_abspath])
|
||||
|
||||
def reformat_code_and_helpers(
|
||||
self,
|
||||
helper_functions: list[FunctionSource],
|
||||
path: Path,
|
||||
original_code: str,
|
||||
self, helper_functions: list[FunctionSource], path: Path, original_code: str
|
||||
) -> tuple[str, dict[Path, str]]:
|
||||
should_sort_imports = not self.args.disable_imports_sorting
|
||||
if should_sort_imports and isort.code(original_code) != original_code:
|
||||
should_sort_imports = False
|
||||
|
||||
new_code = format_code(
|
||||
self.args.formatter_cmds,
|
||||
path,
|
||||
)
|
||||
new_code = format_code(self.args.formatter_cmds, path)
|
||||
if should_sort_imports and new_code is not None:
|
||||
new_code = sort_imports(new_code)
|
||||
|
||||
new_helper_code: dict[Path, str] = {}
|
||||
helper_functions_paths = {hf.file_path for hf in helper_functions}
|
||||
for module_abspath in helper_functions_paths:
|
||||
formatted_helper_code = format_code(
|
||||
self.args.formatter_cmds,
|
||||
module_abspath,
|
||||
)
|
||||
formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath)
|
||||
if should_sort_imports and formatted_helper_code is not None:
|
||||
formatted_helper_code = sort_imports(formatted_helper_code)
|
||||
if formatted_helper_code is not None:
|
||||
|
|
@ -633,13 +523,8 @@ class Optimizer:
|
|||
helper_functions_by_module_abspath = defaultdict(set)
|
||||
for helper_function in code_context.helper_functions:
|
||||
if helper_function.jedi_definition.type != "class":
|
||||
helper_functions_by_module_abspath[helper_function.file_path].add(
|
||||
helper_function.qualified_name,
|
||||
)
|
||||
for (
|
||||
module_abspath,
|
||||
qualified_names,
|
||||
) in helper_functions_by_module_abspath.items():
|
||||
helper_functions_by_module_abspath[helper_function.file_path].add(helper_function.qualified_name)
|
||||
for module_abspath, qualified_names in helper_functions_by_module_abspath.items():
|
||||
did_update |= replace_function_definitions_in_module(
|
||||
function_names=list(qualified_names),
|
||||
optimized_code=optimized_code,
|
||||
|
|
@ -652,24 +537,13 @@ class Optimizer:
|
|||
return did_update
|
||||
|
||||
def get_code_optimization_context(
|
||||
self,
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
project_root: Path,
|
||||
original_source_code: str,
|
||||
self, function_to_optimize: FunctionToOptimize, project_root: Path, original_source_code: str
|
||||
) -> Result[CodeOptimizationContext, str]:
|
||||
code_to_optimize, contextual_dunder_methods = extract_code(
|
||||
[function_to_optimize],
|
||||
)
|
||||
code_to_optimize, contextual_dunder_methods = extract_code([function_to_optimize])
|
||||
if code_to_optimize is None:
|
||||
return Failure("Could not find function to optimize.")
|
||||
(
|
||||
helper_code,
|
||||
helper_functions,
|
||||
helper_dunder_methods,
|
||||
) = get_constrained_function_context_and_helper_functions(
|
||||
function_to_optimize,
|
||||
self.args.project_root,
|
||||
code_to_optimize,
|
||||
(helper_code, helper_functions, helper_dunder_methods) = get_constrained_function_context_and_helper_functions(
|
||||
function_to_optimize, self.args.project_root, code_to_optimize
|
||||
)
|
||||
if function_to_optimize.parents:
|
||||
function_class = function_to_optimize.parents[0].name
|
||||
|
|
@ -695,9 +569,7 @@ class Optimizer:
|
|||
dedup_optimizable_methods.append(method)
|
||||
added_methods.add(f"{method.file_path}.{method.qualified_name}")
|
||||
if len(dedup_optimizable_methods) > 1:
|
||||
code_to_optimize, contextual_dunder_methods = extract_code(
|
||||
list(reversed(dedup_optimizable_methods)),
|
||||
)
|
||||
code_to_optimize, contextual_dunder_methods = extract_code(list(reversed(dedup_optimizable_methods)))
|
||||
if code_to_optimize is None:
|
||||
return Failure("Could not find function to optimize.")
|
||||
code_to_optimize_with_helpers = helper_code + "\n" + code_to_optimize
|
||||
|
|
@ -718,7 +590,7 @@ class Optimizer:
|
|||
contextual_dunder_methods=contextual_dunder_methods,
|
||||
helper_functions=helper_functions,
|
||||
preexisting_objects=preexisting_objects,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -728,26 +600,18 @@ class Optimizer:
|
|||
get_run_tmp_file(Path("test_return_values_0.sqlite")).unlink(missing_ok=True)
|
||||
|
||||
def instrument_existing_tests(
|
||||
self,
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
function_to_tests: dict[str, list[FunctionCalledInTest]],
|
||||
self, function_to_optimize: FunctionToOptimize, function_to_tests: dict[str, list[FunctionCalledInTest]]
|
||||
) -> set[Path]:
|
||||
relevant_test_files_count = 0
|
||||
unique_instrumented_test_files = set()
|
||||
|
||||
func_qualname = function_to_optimize.qualified_name_with_modules_from_root(
|
||||
self.args.project_root,
|
||||
)
|
||||
func_qualname = function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root)
|
||||
if func_qualname not in function_to_tests:
|
||||
logger.info(
|
||||
f"Did not find any pre-existing tests for '{func_qualname}', will only use generated tests.",
|
||||
)
|
||||
logger.info(f"Did not find any pre-existing tests for '{func_qualname}', will only use generated tests.")
|
||||
else:
|
||||
test_file_invocation_positions = defaultdict(list)
|
||||
for tests_in_file in function_to_tests.get(func_qualname):
|
||||
test_file_invocation_positions[tests_in_file.tests_in_file.test_file].append(
|
||||
tests_in_file.position,
|
||||
)
|
||||
test_file_invocation_positions[tests_in_file.tests_in_file.test_file].append(tests_in_file.position)
|
||||
for test_file, positions in test_file_invocation_positions.items():
|
||||
path_obj_test_file = Path(test_file)
|
||||
relevant_test_files_count += 1
|
||||
|
|
@ -762,7 +626,7 @@ class Optimizer:
|
|||
continue
|
||||
|
||||
new_test_path = Path(
|
||||
f"{os.path.splitext(test_file)[0]}__perfinstrumented{os.path.splitext(test_file)[1]}",
|
||||
f"{os.path.splitext(test_file)[0]}__perfinstrumented{os.path.splitext(test_file)[1]}"
|
||||
)
|
||||
if injected_test is not None:
|
||||
with new_test_path.open("w", encoding="utf8") as _f:
|
||||
|
|
@ -778,11 +642,11 @@ class Optimizer:
|
|||
original_source=None,
|
||||
original_file_path=Path(test_file),
|
||||
test_type=TestType.EXISTING_UNIT_TEST,
|
||||
),
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
f"Discovered {relevant_test_files_count} existing unit test file"
|
||||
f"{'s' if relevant_test_files_count != 1 else ''} for {func_qualname}",
|
||||
f"{'s' if relevant_test_files_count != 1 else ''} for {func_qualname}"
|
||||
)
|
||||
return unique_instrumented_test_files
|
||||
|
||||
|
|
@ -849,31 +713,18 @@ class Optimizer:
|
|||
GeneratedTests(
|
||||
generated_original_test_source=generated_test_source,
|
||||
instrumented_test_source=instrumented_test_source,
|
||||
),
|
||||
)
|
||||
)
|
||||
if not tests:
|
||||
logger.warning(
|
||||
f"Failed to generate and instrument tests for {function_to_optimize.function_name}",
|
||||
)
|
||||
logger.warning(f"Failed to generate and instrument tests for {function_to_optimize.function_name}")
|
||||
return Failure(f"/!\\ NO TESTS GENERATED for {function_to_optimize.function_name}")
|
||||
logger.info(f"Generated {len(tests)} tests for {function_to_optimize.function_name}")
|
||||
generated_tests = GeneratedTestsList(generated_tests=tests)
|
||||
|
||||
return Success(
|
||||
(
|
||||
generated_tests,
|
||||
OptimizationSet(
|
||||
control=candidates,
|
||||
experiment=candidates_experiment,
|
||||
),
|
||||
),
|
||||
)
|
||||
return Success((generated_tests, OptimizationSet(control=candidates, experiment=candidates_experiment)))
|
||||
|
||||
def establish_original_code_baseline(
|
||||
self,
|
||||
function_name: str,
|
||||
generated_tests_paths: list[Path],
|
||||
tests_in_file: list[FunctionCalledInTest],
|
||||
self, function_name: str, generated_tests_paths: list[Path], tests_in_file: list[FunctionCalledInTest]
|
||||
) -> Result[tuple[OriginalCodeBaseline, list[str]], str]:
|
||||
assert (test_framework := self.args.test_framework) in ["pytest", "unittest"]
|
||||
success = True
|
||||
|
|
@ -902,12 +753,10 @@ class Optimizer:
|
|||
) == TestType.REPLAY_TEST
|
||||
first_test_types.append(first_test_type)
|
||||
first_test_functions.append(
|
||||
relevant_tests_in_file[0].tests_in_file.test_function if is_replay_test else None,
|
||||
relevant_tests_in_file[0].tests_in_file.test_function if is_replay_test else None
|
||||
)
|
||||
if is_replay_test and len(relevant_tests_in_file) > 1:
|
||||
logger.warning(
|
||||
f"Multiple tests found for the replay test {test_file}. Should not happen",
|
||||
)
|
||||
logger.warning(f"Multiple tests found for the replay test {test_file}. Should not happen")
|
||||
first_test_functions.extend([None] * len(generated_tests_paths))
|
||||
|
||||
if test_framework == "pytest":
|
||||
|
|
@ -935,51 +784,41 @@ class Optimizer:
|
|||
unittest_results.merge(unittest_loop_results)
|
||||
|
||||
initial_loop_unittest_results = TestResults(
|
||||
test_results=[result for result in unittest_results.test_results if result.loop_index == 1],
|
||||
test_results=[result for result in unittest_results.test_results if result.loop_index == 1]
|
||||
)
|
||||
logger.info(
|
||||
f"Overall initial loop test results for original code: {TestResults.report_to_string(initial_loop_unittest_results.get_test_pass_fail_report_by_type())}",
|
||||
f"Overall initial loop test results for original code: {TestResults.report_to_string(initial_loop_unittest_results.get_test_pass_fail_report_by_type())}"
|
||||
)
|
||||
existing_test_results = TestResults(
|
||||
test_results=[
|
||||
result for result in unittest_results if result.test_type == TestType.EXISTING_UNIT_TEST
|
||||
],
|
||||
test_results=[result for result in unittest_results if result.test_type == TestType.EXISTING_UNIT_TEST]
|
||||
)
|
||||
generated_test_results = TestResults(
|
||||
test_results=[
|
||||
result for result in unittest_results if result.test_type == TestType.GENERATED_REGRESSION
|
||||
],
|
||||
test_results=[result for result in unittest_results if result.test_type == TestType.GENERATED_REGRESSION]
|
||||
)
|
||||
|
||||
total_timing = unittest_results.total_passed_runtime()
|
||||
|
||||
functions_to_remove = [
|
||||
result.id.test_function_name
|
||||
for result in generated_test_results.test_results
|
||||
if not result.did_pass
|
||||
result.id.test_function_name for result in generated_test_results.test_results if not result.did_pass
|
||||
]
|
||||
|
||||
if not initial_loop_unittest_results:
|
||||
logger.warning(
|
||||
f"Couldn't run any tests for original function {function_name}. SKIPPING OPTIMIZING THIS FUNCTION.",
|
||||
f"Couldn't run any tests for original function {function_name}. SKIPPING OPTIMIZING THIS FUNCTION."
|
||||
)
|
||||
success = False
|
||||
if total_timing == 0:
|
||||
logger.warning(
|
||||
"The overall test runtime of the original function is 0, couldn't run tests.",
|
||||
)
|
||||
logger.warning("The overall test runtime of the original function is 0, couldn't run tests.")
|
||||
success = False
|
||||
if not total_timing:
|
||||
logger.warning(
|
||||
"Failed to run the tests for the original function, skipping optimization",
|
||||
)
|
||||
logger.warning("Failed to run the tests for the original function, skipping optimization")
|
||||
success = False
|
||||
if not success:
|
||||
return Failure("Failed to establish a baseline for the original code.")
|
||||
|
||||
loop_count = max([int(result.loop_index) for result in unittest_results.test_results])
|
||||
logger.info(
|
||||
f"Original code runtime measured over {loop_count} loop{'s' if loop_count > 1 else ''}: {humanize_runtime(total_timing)} per full loop",
|
||||
f"Original code runtime measured over {loop_count} loop{'s' if loop_count > 1 else ''}: {humanize_runtime(total_timing)} per full loop"
|
||||
)
|
||||
logger.debug(f"Total original code runtime (ns): {total_timing}")
|
||||
return Success(
|
||||
|
|
@ -991,7 +830,7 @@ class Optimizer:
|
|||
runtime=total_timing,
|
||||
),
|
||||
functions_to_remove,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
def run_optimized_candidate(
|
||||
|
|
@ -1021,12 +860,8 @@ class Optimizer:
|
|||
|
||||
first_test_types = []
|
||||
first_test_functions = []
|
||||
get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(
|
||||
missing_ok=True,
|
||||
)
|
||||
get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(
|
||||
missing_ok=True,
|
||||
)
|
||||
get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(missing_ok=True)
|
||||
get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(missing_ok=True)
|
||||
|
||||
for test_file in instrumented_unittests_created_for_function:
|
||||
relevant_tests_in_file = [
|
||||
|
|
@ -1039,11 +874,11 @@ class Optimizer:
|
|||
) == TestType.REPLAY_TEST
|
||||
first_test_types.append(first_test_type)
|
||||
first_test_functions.append(
|
||||
relevant_tests_in_file[0].tests_in_file.test_function if is_replay_test else None,
|
||||
relevant_tests_in_file[0].tests_in_file.test_function if is_replay_test else None
|
||||
)
|
||||
if is_replay_test and len(relevant_tests_in_file) > 1:
|
||||
logger.warning(
|
||||
f"Multiple tests found for the replay test {test_file.original_file_path}. Should not happen",
|
||||
f"Multiple tests found for the replay test {test_file.original_file_path}. Should not happen"
|
||||
)
|
||||
first_test_functions.extend([None] * len(generated_tests_paths))
|
||||
if test_framework == "pytest":
|
||||
|
|
@ -1071,19 +906,16 @@ class Optimizer:
|
|||
candidate_results.merge(candidate_loop_results)
|
||||
|
||||
initial_loop_candidate_results = TestResults(
|
||||
test_results=[result for result in candidate_results.test_results if result.loop_index == 1],
|
||||
test_results=[result for result in candidate_results.test_results if result.loop_index == 1]
|
||||
)
|
||||
logger.info(
|
||||
f"Overall initial loop test results for candidate code: {TestResults.report_to_string(initial_loop_candidate_results.get_test_pass_fail_report_by_type())}",
|
||||
f"Overall initial loop test results for candidate code: {TestResults.report_to_string(initial_loop_candidate_results.get_test_pass_fail_report_by_type())}"
|
||||
)
|
||||
initial_loop_original_test_results = TestResults(
|
||||
test_results=[result for result in original_test_results.test_results if result.loop_index == 1],
|
||||
test_results=[result for result in original_test_results.test_results if result.loop_index == 1]
|
||||
)
|
||||
|
||||
if compare_test_results(
|
||||
initial_loop_original_test_results,
|
||||
initial_loop_candidate_results,
|
||||
):
|
||||
if compare_test_results(initial_loop_original_test_results, initial_loop_candidate_results):
|
||||
logger.info("Test results matched!")
|
||||
equal_results = True
|
||||
else:
|
||||
|
|
@ -1092,32 +924,22 @@ class Optimizer:
|
|||
equal_results = False
|
||||
|
||||
if (total_candidate_timing := candidate_results.total_passed_runtime()) == 0:
|
||||
logger.warning(
|
||||
"The overall test runtime of the optimized function is 0, couldn't run tests.",
|
||||
)
|
||||
logger.warning("The overall test runtime of the optimized function is 0, couldn't run tests.")
|
||||
if best_runtime_until_now is None or total_candidate_timing < best_runtime_until_now:
|
||||
best_test_results = candidate_results
|
||||
get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.bin")).unlink(
|
||||
missing_ok=True,
|
||||
)
|
||||
get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.bin")).unlink(missing_ok=True)
|
||||
|
||||
get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(
|
||||
missing_ok=True,
|
||||
)
|
||||
get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(missing_ok=True)
|
||||
if not equal_results:
|
||||
success = False
|
||||
|
||||
if not success:
|
||||
return Failure("Failed to run the optimized candidate.")
|
||||
logger.debug(
|
||||
f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}",
|
||||
)
|
||||
logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}")
|
||||
return Success(
|
||||
OptimizedCandidateResult(
|
||||
times_run=times_run,
|
||||
best_test_runtime=total_candidate_timing,
|
||||
best_test_results=best_test_results,
|
||||
),
|
||||
times_run=times_run, best_test_runtime=total_candidate_timing, best_test_results=best_test_results
|
||||
)
|
||||
)
|
||||
|
||||
def run_and_parse_tests(
|
||||
|
|
@ -1146,14 +968,14 @@ class Optimizer:
|
|||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.exception(
|
||||
f'Error running tests in {", ".join(str(f) for f in test_files.test_files)}.\nTimeout Error',
|
||||
f'Error running tests in {", ".join(str(f) for f in test_files.test_files)}.\nTimeout Error'
|
||||
)
|
||||
return TestResults()
|
||||
if run_result.returncode != 0:
|
||||
logger.debug(
|
||||
f'Nonzero return code {run_result.returncode} when running tests in {", ".join([str(f.instrumented_file_path) for f in test_files.test_files])}.\n'
|
||||
f"stdout: {run_result.stdout}\n"
|
||||
f"stderr: {run_result.stderr}\n",
|
||||
f"stderr: {run_result.stderr}\n"
|
||||
)
|
||||
return parse_test_results(
|
||||
test_xml_path=result_file_path,
|
||||
|
|
|
|||
|
|
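Aside on the condensed `loop_count` hunk above: ruff folds the assignment expression onto a single line, which can read oddly at first. A minimal standalone sketch of the same construct, with hypothetical toy data, shows the evaluation order:

class Result:
    def __init__(self, loop_index: int) -> None:
        self.loop_index = loop_index

results = [Result(1), Result(2), Result(2)]
# The condition of the conditional expression is evaluated first, so the
# walrus operator binds all_loop_indices before max() uses it; an empty
# set is falsy and falls through to the default of 1.
loop_count = max(all_loop_indices) if (all_loop_indices := {r.loop_index for r in results}) else 1
print(loop_count)  # 2
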
@ -29,9 +29,7 @@ def existing_tests_source_for(
    existing_tests_unique = set()
    if test_files:
        for test_file in test_files:
            existing_tests_unique.add(
                "- " + str(Path(test_file.tests_in_file.test_file).relative_to(tests_root)),
            )
            existing_tests_unique.add("- " + str(Path(test_file.tests_in_file.test_file).relative_to(tests_root)))
    return "\n".join(sorted(existing_tests_unique))

@ -52,8 +50,7 @@ def check_create_pr(
        relative_path = explanation.file_path.relative_to(git_root_dir()).as_posix()
        build_file_changes = {
            Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent(
                oldContent=original_code[p],
                newContent=new_code[p],
                oldContent=original_code[p], newContent=new_code[p]
            )
            for p in original_code
            if not is_zero_diff(original_code[p], new_code[p])

@ -85,7 +82,7 @@ def check_create_pr(
        else:
            logger.error(
                f"Optimization was successful, but I failed to suggest changes to PR #{pr_number}."
                f" Response from server was: {response.text}",
                f" Response from server was: {response.text}"
            )
    else:
        logger.info("Creating a new PR with the optimized code...")

@ -98,8 +95,7 @@ def check_create_pr(
        base_branch = get_current_branch()
        build_file_changes = {
            Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent(
                oldContent=original_code[p],
                newContent=new_code[p],
                oldContent=original_code[p], newContent=new_code[p]
            )
            for p in original_code
        }

@ -128,5 +124,5 @@ def check_create_pr(
        else:
            logger.error(
                f"Optimization was successful, but I failed to create a PR with the optimized code."
                f" Response from server was: {response.text}",
                f" Response from server was: {response.text}"
            )

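The `build_file_changes` comprehensions above filter out no-op edits before building the PR payload. A toy sketch of the same shape, with stand-in data in place of `FileDiffContent` and `is_zero_diff` (both stand-ins are assumptions, not code from this commit):

original_code = {"a.py": "x = 1\n", "b.py": "y = 2\n"}
new_code = {"a.py": "x = 1\n", "b.py": "y = 3\n"}
# Keep only files whose contents actually changed, keyed by relative path.
build_file_changes = {
    p: {"oldContent": original_code[p], "newContent": new_code[p]}
    for p in original_code
    if original_code[p] != new_code[p]
}
print(build_file_changes)  # only b.py survives the filter
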
@ -13,9 +13,7 @@ def performance_gain(*, original_runtime_ns: int, optimized_runtime_ns: int) ->


def speedup_critic(
    candidate_result: OptimizedCandidateResult,
    original_code_runtime: int,
    best_runtime_until_now: int,
    candidate_result: OptimizedCandidateResult, original_code_runtime: int, best_runtime_until_now: int
) -> bool:
    """Takes in a correct optimized Test Result and decides if the optimization should actually
    be surfaced to the user.

@ -33,8 +31,7 @@ def speedup_critic(
        noise_floor = noise_floor * 2  # Increase the noise floor in GitHub Actions mode

    perf_gain = performance_gain(
        original_runtime_ns=original_code_runtime,
        optimized_runtime_ns=candidate_result.best_test_runtime,
        original_runtime_ns=original_code_runtime, optimized_runtime_ns=candidate_result.best_test_runtime
    )
    if (perf_gain > noise_floor) and candidate_result.best_test_runtime < best_runtime_until_now:
        return True

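For context on the `performance_gain` and `speedup_critic` calls above, whose bodies sit mostly outside this diff, here is a rough sketch of how such a noise-floor gate typically works. The gain formula and the threshold value are assumptions for illustration, not taken from this commit:

def performance_gain(*, original_runtime_ns: int, optimized_runtime_ns: int) -> float:
    # Assumed definition: fractional improvement relative to the optimized
    # runtime, so 0.10 means roughly 10% faster.
    return (original_runtime_ns - optimized_runtime_ns) / optimized_runtime_ns

def speedup_critic_sketch(best_test_runtime: int, original_code_runtime: int, best_runtime_until_now: int) -> bool:
    noise_floor = 0.05  # hypothetical threshold; the real value lives in config
    perf_gain = performance_gain(
        original_runtime_ns=original_code_runtime, optimized_runtime_ns=best_test_runtime
    )
    # Surface the candidate only if it beats both the measurement noise and
    # the best candidate seen so far.
    return perf_gain > noise_floor and best_test_runtime < best_runtime_until_now
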
@ -18,10 +18,7 @@ def initialize_posthog(enabled: bool) -> None:
        return

    global _posthog
    _posthog = Posthog(
        project_api_key="phc_aUO790jHd7z1SXwsYCz8dRApxueplZlZWeDSpKc5hol",
        host="https://us.posthog.com",
    )
    _posthog = Posthog(project_api_key="phc_aUO790jHd7z1SXwsYCz8dRApxueplZlZWeDSpKc5hol", host="https://us.posthog.com")
    _posthog.log.setLevel(logging.CRITICAL)  # Suppress PostHog logging
    ph("cli-telemetry-enabled")

@ -40,10 +37,6 @@ def ph(event: str, properties: Optional[Dict[str, Any]] = None) -> None:
    user_id = get_user_id()

    if user_id:
        _posthog.capture(
            distinct_id=user_id,
            event=event,
            properties=properties,
        )
        _posthog.capture(distinct_id=user_id, event=event, properties=properties)
    else:
        logger.debug("Failed to log event to PostHog: User ID could not be retrieved.")

@ -75,7 +75,7 @@ class Tracer:
        if sys.getprofile() is not None or sys.gettrace() is not None:
            console.print(
                "WARNING - Codeflash: Another profiler, debugger or coverage tool is already running. "
                "Please disable it before starting the Codeflash Tracer, both can't run. Codeflash Tracer is DISABLED.",
                "Please disable it before starting the Codeflash Tracer, both can't run. Codeflash Tracer is DISABLED."
            )
            self.disable = True
            return

@ -91,22 +91,10 @@ class Tracer:
        }
        self.max_function_count = max_function_count
        self.config, found_config_path = parse_config_file(config_file_path)
        self.project_root = project_root_from_module_root(
            self.config["module_root"],
            found_config_path,
        )
        self.ignored_functions = {
            "<listcomp>",
            "<genexpr>",
            "<dictcomp>",
            "<setcomp>",
            "<lambda>",
            "<module>",
        }
        self.project_root = project_root_from_module_root(self.config["module_root"], found_config_path)
        self.ignored_functions = {"<listcomp>", "<genexpr>", "<dictcomp>", "<setcomp>", "<lambda>", "<module>"}

        self.file_being_called_from: str = str(Path(sys._getframe().f_back.f_code.co_filename).name).replace(
            ".", "_"
        )
        self.file_being_called_from: str = str(Path(sys._getframe().f_back.f_code.co_filename).name).replace(".", "_")

        assert timeout is None or timeout > 0, "Timeout should be greater than 0"
        self.timeout = timeout

@ -121,9 +109,7 @@ class Tracer:
        self.timer = time.process_time_ns
        self.total_tt = 0
        self.simulate_call("profiler")
        assert (
            "test_framework" in self.config
        ), "Please specify 'test-framework' in pyproject.toml config file"
        assert "test_framework" in self.config, "Please specify 'test-framework' in pyproject.toml config file"
        self.t = self.timer()

    def __enter__(self) -> None:

@ -132,7 +118,7 @@ class Tracer:
        if getattr(Tracer, "used_once", False):
            console.print(
                "Codeflash: Tracer can only be used once per program run. "
                "Please only enable the Tracer once. Skipping tracing this section.",
                "Please only enable the Tracer once. Skipping tracing this section."
            )
            self.disable = True
            return

@ -148,7 +134,7 @@ class Tracer:
        # TODO: Check out if we need to export the function test name as well
        cur.execute(
            "CREATE TABLE function_calls(type TEXT, function TEXT, classname TEXT, filename TEXT, "
            "line_number INTEGER, last_frame_address INTEGER, time_ns INTEGER, args BLOB)",
            "line_number INTEGER, last_frame_address INTEGER, time_ns INTEGER, args BLOB)"
        )
        console.print("Codeflash: Tracing started!")
        frame = sys._getframe(0)  # Get this frame and simulate a call to it

@ -168,23 +154,13 @@ class Tracer:
        cur.execute(
            "CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, "
            "call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, "
            "cumulative_time_ns INTEGER, callers BLOB)",
            "cumulative_time_ns INTEGER, callers BLOB)"
        )
        for func, (cc, nc, tt, ct, callers) in self.stats.items():
            remapped_callers = [{"key": k, "value": v} for k, v in callers.items()]
            cur.execute(
                "INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
                (
                    Path(func[0]).resolve(),
                    func[1],
                    func[2],
                    func[3],
                    cc,
                    nc,
                    tt,
                    ct,
                    json.dumps(remapped_callers),
                ),
                (Path(func[0]).resolve(), func[1], func[2], func[3], cc, nc, tt, ct, json.dumps(remapped_callers)),
            )
        self.con.commit()

@ -217,16 +193,14 @@ class Tracer:
        )
        function_path = "_".join(self.functions) if self.functions else self.file_being_called_from
        test_file_path = get_test_file_path(
            test_dir=self.config["tests_root"],
            function_name=function_path,
            test_type="replay",
            test_dir=self.config["tests_root"], function_name=function_path, test_type="replay"
        )
        replay_test = isort.code(replay_test)
        with open(test_file_path, "w", encoding="utf8") as file:
            file.write(replay_test)

        console.print(
            f"Codeflash: Traced {self.trace_count} function calls successfully and replay test created at - {test_file_path}",
            f"Codeflash: Traced {self.trace_count} function calls successfully and replay test created at - {test_file_path}"
        )

    def tracer_logic(self, frame: FrameType, event: str):

@ -235,9 +209,7 @@ class Tracer:
        if self.timeout is not None:
            if (time.time() - self.start_time) > self.timeout:
                sys.setprofile(None)
                console.print(
                    f"Codeflash: Timeout reached! Stopping tracing at {self.timeout} seconds.",
                )
                console.print(f"Codeflash: Timeout reached! Stopping tracing at {self.timeout} seconds.")
                return
        code = frame.f_code
        file_name = code.co_filename

@ -290,13 +262,10 @@ class Tracer:
            FunctionModules(
                function_name=code.co_name,
                file_name=file_name,
                module_name=module_name_from_file_path(
                    file_name,
                    project_root_path=self.project_root,
                ),
                module_name=module_name_from_file_path(file_name, project_root_path=self.project_root),
                class_name=class_name,
                line_no=code.co_firstlineno,
            ),
            )
        )

        # TODO: Also check if this function arguments are unique from the values logged earlier

@ -314,18 +283,12 @@ class Tracer:
            arguments = dict(arguments.items())
            if class_name and code.co_name == "__init__":
                del arguments["self"]
            local_vars = pickle.dumps(
                arguments,
                protocol=pickle.HIGHEST_PROTOCOL,
            )
            local_vars = pickle.dumps(arguments, protocol=pickle.HIGHEST_PROTOCOL)
            sys.setrecursionlimit(original_recursion_limit)
        except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
            # we retry with dill if pickle fails. It's slower but more comprehensive
            try:
                local_vars = dill.dumps(
                    arguments,
                    protocol=dill.HIGHEST_PROTOCOL,
                )
                local_vars = dill.dumps(arguments, protocol=dill.HIGHEST_PROTOCOL)
                sys.setrecursionlimit(original_recursion_limit)

            except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError):

@ -334,16 +297,7 @@ class Tracer:
                return
        cur.execute(
            "INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)",
            (
                event,
                code.co_name,
                class_name,
                file_name,
                frame.f_lineno,
                frame.f_back.__hash__(),
                t_ns,
                local_vars,
            ),
            (event, code.co_name, class_name, file_name, frame.f_lineno, frame.f_back.__hash__(), t_ns, local_vars),
        )
        self.trace_count += 1
        self.next_insert -= 1

@ -374,14 +328,7 @@ class Tracer:
        if self.cur and frame.f_back is not self.cur[-2]:
            rpt, rit, ret, rfn, rframe, rcur = self.cur
            if not isinstance(rframe, Tracer.fake_frame):
                assert rframe.f_back is frame.f_back, (
                    "Bad call",
                    rfn,
                    rframe,
                    rframe.f_back,
                    frame,
                    frame.f_back,
                )
                assert rframe.f_back is frame.f_back, ("Bad call", rfn, rframe, rframe.f_back, frame, frame.f_back)
                self.trace_dispatch_return(rframe, 0)
                assert self.cur is None or frame.f_back is self.cur[-2], ("Bad call", self.cur[-3])
        fcode = frame.f_code

@ -528,9 +475,7 @@ class Tracer:
            console.print("Failed to get total time from stats")
        total_time_ms = total_time / 1e6
        raw_stats = re.sub(
            r"(function calls?.*)in (\d+)\.\d+ (seconds?)",
            rf"\1 in {total_time_ms:.3f} milliseconds",
            raw_stats,
            r"(function calls?.*)in (\d+)\.\d+ (seconds?)", rf"\1 in {total_time_ms:.3f} milliseconds", raw_stats
        )
        match_pattern = r"^ *[\d\/]+ +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +"
        m = re.findall(match_pattern, raw_stats, re.MULTILINE)

@ -605,7 +550,6 @@ class Tracer:


def main():
    import os
    from argparse import ArgumentParser

    parser = ArgumentParser(allow_abbrev=False)

@ -624,13 +568,7 @@ def main():
        type=float,
        default=None,
    )
    parser.add_argument(
        "-m",
        action="store_true",
        dest="module",
        help="Trace a library module",
        default=False,
    )
    parser.add_argument("-m", action="store_true", dest="module", help="Trace a library module", default=False)
    parser.add_argument(
        "--codeflash-config",
        help="Optional path to the project's pyproject.toml file "

@ -655,10 +593,7 @@ def main():
        import runpy

        code = "run_module(modname, run_name='__main__')"
        globs = {
            "run_module": runpy.run_module,
            "modname": unknown_args[0],
        }
        globs = {"run_module": runpy.run_module, "modname": unknown_args[0]}
    else:
        progname = unknown_args[0]
        sys.path.insert(0, str(Path(progname).parent))

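The serialization hunks above keep the tracer's two-step strategy intact: try stdlib pickle, then fall back to dill. Extracted into a standalone helper (a sketch, not code from this commit), the pattern is:

import pickle

import dill

def serialize_arguments(arguments: dict) -> bytes:
    # Try the fast stdlib pickle first; fall back to dill, which can
    # serialize more object types (lambdas, locally defined classes)
    # at the cost of speed.
    try:
        return pickle.dumps(arguments, protocol=pickle.HIGHEST_PROTOCOL)
    except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
        return dill.dumps(arguments, protocol=dill.HIGHEST_PROTOCOL)
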
@ -39,9 +39,7 @@ class ProfileStats(pstats.Stats):
                call_count_nonrecursive,
                num_callers,
                total_time_ns / time_conversion_factor if time_conversion_factor != 1 else total_time_ns,
                cumulative_time_ns / time_conversion_factor
                if time_conversion_factor != 1
                else cumulative_time_ns,
                cumulative_time_ns / time_conversion_factor if time_conversion_factor != 1 else cumulative_time_ns,
                unmapped_callers,
            )

@ -58,9 +56,7 @@ class ProfileStats(pstats.Stats):
        print(indent, self.total_calls, "function calls", end=" ", file=self.stream)
        if self.total_calls != self.prim_calls:
            print("(%d primitive calls)" % self.prim_calls, end=" ", file=self.stream)
        time_unit = {"ns": "nanoseconds", "us": "microseconds", "ms": "milliseconds", "s": "seconds"}[
            self.time_unit
        ]
        time_unit = {"ns": "nanoseconds", "us": "microseconds", "ms": "milliseconds", "s": "seconds"}[self.time_unit]
        print("in %.3f %s" % (self.total_tt, time_unit), file=self.stream)
        print(file=self.stream)
        width, list = self.get_print_list(amount)

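The `ProfileStats` hunks above keep totals in nanoseconds and divide by a conversion factor only for display. A small sketch of that display-time conversion (the factor values are assumed, not shown in this diff):

time_conversion_factor = {"ns": 1, "us": 1e3, "ms": 1e6, "s": 1e9}["ms"]
total_time_ns = 2_500_000
# Stored values stay in nanoseconds; only the printed figure is scaled.
print(f"in {total_time_ns / time_conversion_factor:.3f} milliseconds")  # in 2.500 milliseconds
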
@ -4,19 +4,12 @@ import sqlite3
import textwrap
from typing import Any, Generator, List, Optional

from codeflash.discovery.functions_to_optimize import (
    FunctionProperties,
    inspect_top_level_functions_or_methods,
)
from codeflash.discovery.functions_to_optimize import FunctionProperties, inspect_top_level_functions_or_methods
from codeflash.tracing.tracing_utils import FunctionModules


def get_next_arg_and_return(
    trace_file: str,
    function_name: str,
    file_name: str,
    class_name: Optional[str] = None,
    num_to_get: int = 25,
    trace_file: str, function_name: str, file_name: str, class_name: Optional[str] = None, num_to_get: int = 25
) -> Generator[Any]:
    db = sqlite3.connect(trace_file)
    cur = db.cursor()

@ -45,10 +38,7 @@ def get_function_alias(module: str, function_name: str) -> str:


def create_trace_replay_test(
    trace_file: str,
    functions: List[FunctionModules],
    test_framework: str = "pytest",
    max_run_count=100,
    trace_file: str, functions: List[FunctionModules], test_framework: str = "pytest", max_run_count=100
) -> str:
    assert test_framework in ["pytest", "unittest"]

@ -74,21 +64,19 @@ from codeflash.tracing.replay_test import get_next_arg_and_return
            continue
        if function_property.is_staticmethod:
            function_imports.append(
                f"from {function.module_name} import {function_property.staticmethod_class_name} as {get_function_alias(function.module_name, function_property.staticmethod_class_name)}",
                f"from {function.module_name} import {function_property.staticmethod_class_name} as {get_function_alias(function.module_name, function_property.staticmethod_class_name)}"
            )
        elif function.class_name:
            function_imports.append(
                f"from {function.module_name} import {function.class_name} as {get_function_alias(function.module_name, function.class_name)}",
                f"from {function.module_name} import {function.class_name} as {get_function_alias(function.module_name, function.class_name)}"
            )
        else:
            function_imports.append(
                f"from {function.module_name} import {function.function_name} as {get_function_alias(function.module_name, function.function_name)}",
                f"from {function.module_name} import {function.function_name} as {get_function_alias(function.module_name, function.function_name)}"
            )

    imports += "\n".join(function_imports)
    functions_to_optimize = [
        function.function_name for function in functions if function.function_name != "__init__"
    ]
    functions_to_optimize = [function.function_name for function in functions if function.function_name != "__init__"]
    metadata = f"""functions = {functions_to_optimize}
trace_file_path = r"{trace_file}"
"""  # trace_file_path path is parsed with regex later, format is important

@ -97,21 +85,21 @@ trace_file_path = r"{trace_file}"
        for arg_val_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_name=r"{file_name}", num_to_get={max_run_count}):
            args = pickle.loads(arg_val_pkl)
            ret = {function_name}({args})
        """,
        """
    )
    test_class_method_body = textwrap.dedent(
        """\
        for arg_val_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_name=r"{file_name}", class_name="{class_name}", num_to_get={max_run_count}):
            args = pickle.loads(arg_val_pkl){filter_variables}
            ret = {class_name_alias}{method_name}(**args)
        """,
        """
    )
    test_class_staticmethod_body = textwrap.dedent(
        """\
        for arg_val_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_name=r"{file_name}", num_to_get={max_run_count}):
            args = pickle.loads(arg_val_pkl){filter_variables}
            ret = {class_name_alias}{method_name}(**args)
        """,
        """
    )
    if test_framework == "unittest":
        self = "self"

@ -135,8 +123,7 @@ trace_file_path = r"{trace_file}"
        elif func_property.is_staticmethod:
            class_name_alias = get_function_alias(func.module_name, func_property.staticmethod_class_name)
            alias = get_function_alias(
                func.module_name,
                func_property.staticmethod_class_name + "_" + func.function_name,
                func.module_name, func_property.staticmethod_class_name + "_" + func.function_name
            )
            method_name = "." + func.function_name if func.function_name != "__init__" else ""
            test_body = test_class_staticmethod_body.format(

@ -149,10 +136,7 @@ trace_file_path = r"{trace_file}"
            )
        else:
            class_name_alias = get_function_alias(func.module_name, func.class_name)
            alias = get_function_alias(
                func.module_name,
                func.class_name + "_" + func.function_name,
            )
            alias = get_function_alias(func.module_name, func.class_name + "_" + func.function_name)

            if func_property.is_classmethod:
                filter_variables = '\n    args.pop("cls", None)'

@ -170,10 +154,7 @@ trace_file_path = r"{trace_file}"
            max_run_count=max_run_count,
            filter_variables=filter_variables,
        )
        formatted_test_body = textwrap.indent(
            test_body,
            "        " if test_framework == "unittest" else "    ",
        )
        formatted_test_body = textwrap.indent(test_body, "        " if test_framework == "unittest" else "    ")

        test_template += "    " if test_framework == "unittest" else ""
        test_template += f"def test_{alias}({self}):\n{formatted_test_body}\n"

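The replay-test builder above assembles each generated test from a dedented template that is then re-indented under its `def` line (eight spaces for unittest methods, four for pytest functions). A self-contained sketch of that mechanism, with a hypothetical `target` function:

import textwrap

test_body = textwrap.dedent(
    """\
    args = {"x": 1}
    ret = target(**args)
    """
)
# dedent normalizes the template; indent nests it under the def line.
formatted_test_body = textwrap.indent(test_body, "    ")  # pytest-style indent
print("def test_target():\n" + formatted_test_body)
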
@ -129,21 +129,11 @@ def comparator(orig: Any, new: Any) -> bool:
        return (orig != new).nnz == 0

    if HAS_PANDAS and isinstance(
        orig,
        (
            pandas.DataFrame,
            pandas.Series,
            pandas.Index,
            pandas.Categorical,
            pandas.arrays.SparseArray,
        ),
        orig, (pandas.DataFrame, pandas.Series, pandas.Index, pandas.Categorical, pandas.arrays.SparseArray)
    ):
        return orig.equals(new)

    if HAS_PANDAS and isinstance(
        orig,
        (pandas.CategoricalDtype, pandas.Interval, pandas.Period),
    ):
    if HAS_PANDAS and isinstance(orig, (pandas.CategoricalDtype, pandas.Interval, pandas.Period)):
        return orig == new

    # This should be at the end of all numpy checking

@ -173,10 +163,7 @@ def comparator(orig: Any, new: Any) -> bool:
    ):
        return orig == new

    if isinstance(
        orig,
        (datetime.datetime, datetime.date, datetime.timedelta, datetime.time, datetime.timezone),
    ):
    if isinstance(orig, (datetime.datetime, datetime.date, datetime.timedelta, datetime.time, datetime.timezone)):
        return orig == new

    # If the object passed has a user defined __eq__ method, use that

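Why the comparator hunks above route pandas types to `.equals()` rather than `==`: equality on pandas objects is elementwise and treats NaN as unequal to NaN, so it cannot serve as a single verdict. A short illustration:

import pandas as pd

s1 = pd.Series([1.0, float("nan")])
s2 = pd.Series([1.0, float("nan")])
print(s1 == s2)       # elementwise boolean Series: [True, False], since NaN != NaN
print(s1.equals(s2))  # True: same shape, same values, NaNs in the same places
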
@ -11,17 +11,9 @@ import dill as pickle
|
|||
from junitparser.xunit2 import JUnitXml
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_utils import (
|
||||
file_path_from_module_name,
|
||||
get_run_tmp_file,
|
||||
module_name_from_file_path,
|
||||
)
|
||||
from codeflash.code_utils.code_utils import file_path_from_module_name, get_run_tmp_file, module_name_from_file_path
|
||||
from codeflash.discovery.discover_unit_tests import discover_parameters_unittest
|
||||
from codeflash.verification.test_results import (
|
||||
FunctionTestInvocation,
|
||||
InvocationId,
|
||||
TestResults,
|
||||
)
|
||||
from codeflash.verification.test_results import FunctionTestInvocation, InvocationId, TestResults
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import subprocess
|
||||
|
|
@ -30,11 +22,7 @@ if TYPE_CHECKING:
|
|||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
def parse_test_return_values_bin(
|
||||
file_location: Path,
|
||||
test_files: TestFiles,
|
||||
test_config: TestConfig,
|
||||
) -> TestResults:
|
||||
def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults:
|
||||
test_results = TestResults()
|
||||
if not file_location.exists():
|
||||
logger.warning(f"No test results for {file_location} found.")
|
||||
|
|
@ -66,8 +54,7 @@ def parse_test_return_values_bin(
|
|||
|
||||
invocation_id_object = InvocationId.from_str_id(encoded_test_name, invocation_id)
|
||||
test_file_path = file_path_from_module_name(
|
||||
invocation_id_object.test_module_path,
|
||||
test_config.tests_project_rootdir,
|
||||
invocation_id_object.test_module_path, test_config.tests_project_rootdir
|
||||
)
|
||||
|
||||
test_type = test_files.get_test_type_by_instrumented_file_path(test_file_path)
|
||||
|
|
@ -85,16 +72,12 @@ def parse_test_return_values_bin(
|
|||
test_type=test_type,
|
||||
return_value=test_pickle,
|
||||
timed_out=False,
|
||||
),
|
||||
)
|
||||
)
|
||||
return test_results
|
||||
|
||||
|
||||
def parse_sqlite_test_results(
|
||||
sqlite_file_path: Path,
|
||||
test_files: TestFiles,
|
||||
test_config: TestConfig,
|
||||
) -> TestResults:
|
||||
def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults:
|
||||
test_results = TestResults()
|
||||
if not sqlite_file_path.exists():
|
||||
logger.warning(f"No test results for {sqlite_file_path} found.")
|
||||
|
|
@ -104,7 +87,7 @@ def parse_sqlite_test_results(
|
|||
cur = db.cursor()
|
||||
data = cur.execute(
|
||||
"SELECT test_module_path, test_class_name, test_function_name, "
|
||||
"function_getting_tested, loop_index, iteration_id, runtime, return_value FROM test_results",
|
||||
"function_getting_tested, loop_index, iteration_id, runtime, return_value FROM test_results"
|
||||
).fetchall()
|
||||
finally:
|
||||
db.close()
|
||||
|
|
@ -132,7 +115,7 @@ def parse_sqlite_test_results(
|
|||
test_type=test_type,
|
||||
return_value=pickle.loads(val[7]) if loop_index == 1 else None,
|
||||
timed_out=False,
|
||||
),
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to load pickle file.")
|
||||
|
|
@@ -155,14 +138,10 @@ def parse_test_xml(
    try:
        xml = JUnitXml.fromfile(str(test_xml_file_path))
    except Exception as e:
        logger.warning(
            f"Failed to parse {test_xml_file_path} as JUnitXml. Exception: {e}",
        )
        logger.warning(f"Failed to parse {test_xml_file_path} as JUnitXml. Exception: {e}")
        return test_results
    base_dir = (
        test_config.tests_project_rootdir
        if test_config.test_framework == "pytest"
        else test_config.project_root_path
        test_config.tests_project_rootdir if test_config.test_framework == "pytest" else test_config.project_root_path
    )
    for suite in xml:
        for testcase in suite:

@@ -178,12 +157,10 @@ def parse_test_xml(
            logger.info("Test failed to load, skipping it.")
            if run_result is not None:
                if isinstance(run_result.stdout, str) and isinstance(run_result.stderr, str):
                    logger.info(
                        f"Test log - STDOUT : {run_result.stdout} \n STDERR : {run_result.stderr}",
                    )
                    logger.info(f"Test log - STDOUT : {run_result.stdout} \n STDERR : {run_result.stderr}")
                else:
                    logger.info(
                        f"Test log - STDOUT : {run_result.stdout.decode()} \n STDERR : {run_result.stderr.decode()}",
                        f"Test log - STDOUT : {run_result.stdout.decode()} \n STDERR : {run_result.stderr.decode()}"
                    )
            return test_results
@@ -192,15 +169,9 @@ def parse_test_xml(
            if test_file_name is None:
                if test_class_path:
                    # TODO : This might not be true if the test is organized under a class
                    test_file_path = file_path_from_module_name(
                        test_class_path,
                        base_dir,
                    )
                    test_file_path = file_path_from_module_name(test_class_path, base_dir)
                else:
                    test_file_path = file_path_from_module_name(
                        test_function,
                        base_dir,
                    )
                    test_file_path = file_path_from_module_name(test_function, base_dir)
            else:
                # TODO: not sure which root path fits better here
                test_file_path = base_dir / test_file_name

@@ -213,9 +184,7 @@ def parse_test_xml(
            result = testcase.is_passed  # TODO: See for the cases of ERROR and SKIPPED
            test_class = None
            if class_name is not None and class_name.startswith(test_module_path):
                test_class = class_name[
                    len(test_module_path) + 1 :
                ]  # +1 for the dot, gets Unittest class name
                test_class = class_name[len(test_module_path) + 1 :]  # +1 for the dot, gets Unittest class name

            loop_index = 1
            if test_function is None:
@@ -226,26 +195,19 @@ def parse_test_xml(
            if test_config.test_framework == "pytest":
                loop_index = int(testcase.name.split("[ ", 1)[1][:-2]) if "[" in testcase.name else 1
                if len(testcase.result) > 1:
                    logger.warning(
                        f"!!!!!Multiple results for {testcase.name} in {test_xml_file_path}!!!",
                    )
                    logger.warning(f"!!!!!Multiple results for {testcase.name} in {test_xml_file_path}!!!")
                if len(testcase.result) == 1:
                    message = testcase.result[0].message.lower()
                    if "failed: timeout >" in message:
                        timed_out = True
            else:
                if len(testcase.result) > 1:
                    logger.warning(
                        f"!!!!!Multiple results for {testcase.name} in {test_xml_file_path}!!!",
                    )
                    logger.warning(f"!!!!!Multiple results for {testcase.name} in {test_xml_file_path}!!!")
                if len(testcase.result) == 1:
                    message = testcase.result[0].message.lower()
                    if "timed out" in message:
                        timed_out = True
            matches = re.findall(
                r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!",
                testcase.system_out or "",
            )
            matches = re.findall(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!", testcase.system_out or "")
            if not matches or not len(matches):
                test_results.add(
                    FunctionTestInvocation(

@@ -264,7 +226,7 @@ def parse_test_xml(
                        test_type=test_type,
                        return_value=None,
                        timed_out=timed_out,
                    ),
                )
            )

            else:
@@ -286,12 +248,12 @@ def parse_test_xml(
                        test_type=test_type,
                        return_value=None,
                        timed_out=timed_out,
                    ),
                )
            )

    if not test_results:
        logger.info(
            f"Tests '{[test_file.original_file_path for test_file in test_files.test_files]}' failed to run, skipping",
            f"Tests '{[test_file.original_file_path for test_file in test_files.test_files]}' failed to run, skipping"
        )
        if run_result is not None:
            stdout, stderr = "", ""

@@ -300,16 +262,12 @@ def parse_test_xml(
                stderr = run_result.stderr.decode()
            except AttributeError:
                stdout = run_result.stderr
            logger.debug(
                f"Test log - STDOUT : {stdout} \n STDERR : {stderr}",
            )
            logger.debug(f"Test log - STDOUT : {stdout} \n STDERR : {stderr}")
    return test_results


def merge_test_results(
    xml_test_results: TestResults,
    bin_test_results: TestResults,
    test_framework: str,
    xml_test_results: TestResults, bin_test_results: TestResults, test_framework: str
) -> TestResults:
    merged_test_results = TestResults()
@@ -319,18 +277,14 @@ def merge_test_results(
    # This is done to match the right iteration_id which might not be available in the xml
    for result in xml_test_results:
        if test_framework == "pytest":
            if (
                result.id.test_function_name.endswith("]") and "[" in result.id.test_function_name
            ):  # parameterized test
            if result.id.test_function_name.endswith("]") and "[" in result.id.test_function_name:  # parameterized test
                test_function_name = result.id.test_function_name[: result.id.test_function_name.index("[")]
            else:
                test_function_name = result.id.test_function_name

        if test_framework == "unittest":
            test_function_name = result.id.test_function_name
            is_parameterized, new_test_function_name, _ = discover_parameters_unittest(
                test_function_name,
            )
            is_parameterized, new_test_function_name, _ = discover_parameters_unittest(test_function_name)
            if is_parameterized:  # handle parameterized test
                test_function_name = new_test_function_name
@@ -378,7 +332,7 @@ def merge_test_results(
                    test_type=xml_result.test_type,
                    return_value=result_bin.return_value,
                    timed_out=xml_result.timed_out,
                ),
            )
        )
        elif xml_results.test_results[0].id.iteration_id is not None:
            # This means that we have multiple iterations of the same test function

@@ -404,7 +358,7 @@ def merge_test_results(
                        timed_out=xml_result.timed_out
                        if bin_result.runtime is None
                        else False,  # If runtime was measured in the bin file, then the testcase did not time out
                    ),
                )
            )
        else:
            # Should happen only if the xml did not have any test invocation id info

@@ -427,7 +381,7 @@ def merge_test_results(
                    test_type=bin_result.test_type,
                    return_value=bin_result.return_value,
                    timed_out=xml_result.timed_out,  # only the xml gets the timed_out flag
                ),
            )
        )

    return merged_test_results
@@ -441,10 +395,7 @@ def parse_test_results(
    run_result: subprocess.CompletedProcess | None = None,
) -> TestResults:
    test_results_xml = parse_test_xml(
        test_xml_path,
        test_files=test_files,
        test_config=test_config,
        run_result=run_result,
        test_xml_path, test_files=test_files, test_config=test_config, run_result=run_result
    )

    try:

@@ -463,9 +414,7 @@ def parse_test_results(
        sql_results_file = get_run_tmp_file(Path(f"test_return_values_{optimization_iteration}.sqlite"))
        if sql_results_file.exists():
            test_results_sqlite_file = parse_sqlite_test_results(
                sqlite_file_path=sql_results_file,
                test_files=test_files,
                test_config=test_config,
                sqlite_file_path=sql_results_file, test_files=test_files, test_config=test_config
            )
            test_results_bin_file.merge(test_results_sqlite_file)
    except AttributeError as e:

@@ -475,8 +424,4 @@ def parse_test_results(
    get_run_tmp_file(Path(f"test_return_values_{optimization_iteration}.sqlite")).unlink(missing_ok=True)

    return merge_test_results(
        test_results_xml,
        test_results_bin_file,
        test_config.test_framework,
    )
    return merge_test_results(test_results_xml, test_results_bin_file, test_config.test_framework)
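A note on how these pieces fit together: the XML results carry pass/fail and timeout state, while the bin/sqlite files carry pickled return values and runtimes, and merge_test_results joins them per invocation id. A minimal driver sketch (the bin-file path and the way TestFiles/TestConfig are constructed here are illustrative assumptions, not the repository's exact wiring):

from pathlib import Path

def collect_results(test_xml_path: Path, test_files, test_config, run_result=None):
    # XML side: pass/fail and timed_out flags per test case.
    xml_results = parse_test_xml(test_xml_path, test_files=test_files, test_config=test_config, run_result=run_result)
    # Binary side: pickled return values and measured runtimes (file name is made up).
    bin_results = parse_test_return_values_bin(Path("test_return_values_0.bin"), test_files, test_config)
    # Join on invocation id; runtime comes from the bin side, timed_out from the xml side.
    return merge_test_results(xml_results, bin_results, test_config.test_framework)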
@@ -44,11 +44,7 @@ def pytest_addoption(parser: Parser) -> None:
        help="The amount of time to wait between each test loop.",
    )
    pytest_loops.addoption(
        "--codeflash_hours",
        action="store",
        default=0,
        type=float,
        help="The number of hours to loop the tests for.",
        "--codeflash_hours", action="store", default=0, type=float, help="The number of hours to loop the tests for."
    )
    pytest_loops.addoption(
        "--codeflash_minutes",

@@ -66,11 +62,7 @@ def pytest_addoption(parser: Parser) -> None:
    )

    pytest_loops.addoption(
        "--codeflash_loops",
        action="store",
        default=1,
        type=int,
        help="The number of times to loop each test",
        "--codeflash_loops", action="store", default=1, type=int, help="The number of times to loop each test"
    )

    pytest_loops.addoption(
@@ -101,10 +93,7 @@ def pytest_addoption(parser: Parser) -> None:

@pytest.hookimpl(trylast=True)
def pytest_configure(config: Config) -> None:
    config.addinivalue_line(
        "markers",
        "loops(n): run the given test function `n` times.",
    )
    config.addinivalue_line("markers", "loops(n): run the given test function `n` times.")
    config.pluginmanager.register(PyTest_Loops(config), PyTest_Loops.name)
@@ -122,8 +111,7 @@ class PyTest_Loops:
        """Reimplement the test loop but loop for the user defined amount of time."""
        if session.testsfailed and not session.config.option.continue_on_collection_errors:
            raise session.Interrupted(
                "%d error%s during collection"
                % (session.testsfailed, "s" if session.testsfailed != 1 else ""),
                "%d error%s during collection" % (session.testsfailed, "s" if session.testsfailed != 1 else "")
            )

        if session.config.option.collectonly:

@@ -188,9 +176,7 @@ class PyTest_Loops:
        seconds = session.config.option.codeflash_seconds
        total_time = hours_in_seconds + minutes_in_seconds + seconds
        if total_time < SHORTEST_AMOUNT_OF_TIME:
            raise InvalidTimeParameterError(
                f"Total time cannot be less than: {SHORTEST_AMOUNT_OF_TIME}!",
            )
            raise InvalidTimeParameterError(f"Total time cannot be less than: {SHORTEST_AMOUNT_OF_TIME}!")
        return total_time

    def _timed_out(self, session: Session, start_time: float, count: int) -> bool:
@@ -233,7 +219,7 @@ class PyTest_Loops:
        else:
            raise UnexpectedError(
                "This call couldn't work with pytest-loops. "
                "Please consider raising an issue with your usage.",
                "Please consider raising an issue with your usage."
            )
        return count

@@ -257,9 +243,5 @@ class PyTest_Loops:

        scope = metafunc.config.option.codeflash_loops_scope
        metafunc.parametrize(
            "__pytest_loop_step_number",
            range(count),
            indirect=True,
            ids=make_progress_id,
            scope=scope,
            "__pytest_loop_step_number", range(count), indirect=True, ids=make_progress_id, scope=scope
        )
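For context, the options being reformatted here belong to the looping plugin. A hedged sketch of driving them programmatically, assuming the plugin is already registered with pytest (the target path and the loop count are made up):

import pytest

# Loop every collected test 5 times; flag names come from pytest_addoption above.
pytest.main(["tests/", "--codeflash_loops", "5"])

# Or loop for a wall-clock budget instead of a fixed count.
pytest.main(["tests/", "--codeflash_seconds", "30"])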
@@ -85,10 +85,7 @@ class TestResults(BaseModel):
    def merge(self, other: TestResults) -> None:
        self.test_results.extend(other.test_results)

    def get_by_id(
        self,
        invocation_id: InvocationId,
    ) -> FunctionTestInvocation | None:
    def get_by_id(self, invocation_id: InvocationId) -> FunctionTestInvocation | None:
        return next((r for r in self.test_results if r.id == invocation_id), None)

    def get_all_ids(self) -> set[InvocationId]:

@@ -129,7 +126,7 @@ class TestResults(BaseModel):
            [
                f"{test_type.to_name()}- (Passed: {report[test_type]['passed']}, Failed: {report[test_type]['failed']})"
                for test_type in TestType
            ],
            ]
        )

    def total_passed_runtime(self) -> int:
@@ -141,14 +138,14 @@ class TestResults(BaseModel):
        for result in self.test_results:
            if result.did_pass and not result.runtime:
                logger.debug(
                    f"Ignoring test case that passed but had no runtime -> {result.id}, Loop # {result.loop_index}",
                    f"Ignoring test case that passed but had no runtime -> {result.id}, Loop # {result.loop_index}"
                )
        usable_results = [result for result in self.test_results if result.did_pass and result.runtime]
        return sum(
            [
                min([result.runtime for result in usable_results if result.id == invocation_id])
                for invocation_id in {result.id for result in usable_results}
            ],
            ]
        )

    def __iter__(self) -> Iterator[FunctionTestInvocation]:

@@ -192,10 +189,7 @@ class TestResults(BaseModel):
                or test_result.runtime != other_test_result.runtime
                or test_result.test_framework != other_test_result.test_framework
                or test_result.test_type != other_test_result.test_type
                or not comparator(
                    test_result.return_value,
                    other_test_result.return_value,
                )
                or not comparator(test_result.return_value, other_test_result.return_value)
            ):
                sys.setrecursionlimit(original_recursion_limit)
                return False
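Worth keeping in mind while reading total_passed_runtime: for each invocation id it takes the minimum runtime across loops, then sums those minima. The same arithmetic on toy numbers (the values are illustrative, not real measurements):

# Two invocations, each timed over three loops; the fastest loop wins.
runtimes = {
    "test_a": [1200, 950, 1010],
    "test_b": [400, 420, 390],
}
total = sum(min(loops) for loops in runtimes.values())
assert total == 950 + 390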
@@ -77,7 +77,5 @@ def run_tests(
            check=False,
        )
    else:
        raise ValueError(
            "Invalid test framework -- I only support Pytest and Unittest currently.",
        )
        raise ValueError("Invalid test framework -- I only support Pytest and Unittest currently.")
    return result_file_path, results
@@ -4,12 +4,7 @@ from pathlib import Path
from pydantic.dataclasses import dataclass


def get_test_file_path(
    test_dir: Path,
    function_name: str,
    iteration: int = 0,
    test_type: str = "unit",
) -> Path:
def get_test_file_path(test_dir: Path, function_name: str, iteration: int = 0, test_type: str = "unit") -> Path:
    assert test_type in ["unit", "inspired", "replay"]
    function_name = function_name.replace(".", "_")
    path = test_dir / f"test_{function_name}__{test_type}_test_{iteration}.py"
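The collapsed signature does not change the naming scheme, which is easy to sanity-check (the directory and function name below are made up, and this assumes the computed path is returned unchanged):

from pathlib import Path

p = get_test_file_path(Path("tests"), "my_module.my_func", iteration=0, test_type="unit")
# Dots in the qualified function name become underscores in the file name.
assert p.name == "test_my_module_my_func__unit_test_0.py"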
@@ -40,10 +40,7 @@ def generate_tests(
        instrumented_test_source = module.CACHED_INSTRUMENTED_TESTS
        temp_run_dir = get_run_tmp_file(Path())
        path = str(temp_run_dir).replace("\\", "\\\\")  # Escape backslash for windows paths
        instrumented_test_source = instrumented_test_source.replace(
            "{codeflash_run_tmp_dir_client_side}",
            path,
        )
        instrumented_test_source = instrumented_test_source.replace("{codeflash_run_tmp_dir_client_side}", path)
        logger.info(f"Using cached tests from {module_path}.CACHED_TESTS")
    else:
        test_file_path = get_test_file_path(test_cfg.tests_root, function_to_optimize.function_name, 0)

@@ -63,14 +60,9 @@ def generate_tests(
            generated_test_source, instrumented_test_source = response
            temp_run_dir = get_run_tmp_file(Path())
            path = str(temp_run_dir).replace("\\", "\\\\")
            instrumented_test_source = instrumented_test_source.replace(
                "{codeflash_run_tmp_dir_client_side}",
                path,
            )
            instrumented_test_source = instrumented_test_source.replace("{codeflash_run_tmp_dir_client_side}", path)
        else:
            logger.warning(
                f"Failed to generate and instrument tests for {function_to_optimize.function_name}",
            )
            logger.warning(f"Failed to generate and instrument tests for {function_to_optimize.function_name}")
            return None

    return generated_test_source, instrumented_test_source
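The placeholder substitution above, in isolation, since the double-escaping is easy to misread (the path is a made-up example):

source = 'output_dir = "{codeflash_run_tmp_dir_client_side}"'
path = "C:\\tmp\\run".replace("\\", "\\\\")  # double each backslash before splicing into source text
print(source.replace("{codeflash_run_tmp_dir_client_side}", path))
# prints: output_dir = "C:\\tmp\\run"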
@@ -15,11 +15,7 @@ def load_data(experiment_id: str, database_uri: str = os.environ.get("DATABASE_U
    WHERE (trace_id LIKE %s OR trace_id LIKE %s)
    AND experiment_metadata->>'id' = %s
    """
    return pd.read_sql_query(
        query,
        connection,
        params=("%EXP0", "%EXP1", experiment_id),
    )
    return pd.read_sql_query(query, connection, params=("%EXP0", "%EXP1", experiment_id))


def process_column_pairs(df: DataFrame, column_name: str) -> DataFrame:
@@ -44,22 +40,15 @@ def process_column_pairs(df: DataFrame, column_name: str) -> DataFrame:
def calculate_validity(df: DataFrame, perf_threshold: float = 0.05) -> Dict[str, Any]:
    # Calculate the percentage of valid PRs given that the original function run succeeded
    successful_runs = df[(~df["original_runtime"].isna())]
    successful_runs_above_thres = successful_runs[
        successful_runs["best_correct_speedup_ratio"] >= perf_threshold
    ]
    successful_runs_above_thres = successful_runs[successful_runs["best_correct_speedup_ratio"] >= perf_threshold]
    valid_prs = len(successful_runs_above_thres)
    percent_valid_pr = (valid_prs / len(df)) * 100

    # Calculate the percentage of valid candidates generated given that original function run succeeded
    valid_candidates = successful_runs[successful_runs["is_correct"].apply(lambda x: any(x.values()))]
    percent_valid_candidates = (
        len(valid_candidates) / len(successful_runs) * 100 if len(successful_runs) > 0 else 0
    )
    percent_valid_candidates = len(valid_candidates) / len(successful_runs) * 100 if len(successful_runs) > 0 else 0

    return {
        "percent_valid_pr": percent_valid_pr,
        "percent_valid_candidates": percent_valid_candidates,
    }
    return {"percent_valid_pr": percent_valid_pr, "percent_valid_candidates": percent_valid_candidates}


def calculate_performance(df: DataFrame, perf_threshold: float = 0.05) -> Dict[str, Any]:
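One subtlety in calculate_validity that the reformat leaves intact: percent_valid_pr divides by all rows, while percent_valid_candidates divides by successful runs only. Toy numbers (made up) to make the two denominators concrete:

# 5 PRs, 4 with a measured original runtime, 2 at or above the threshold,
# 3 with at least one correct candidate.
percent_valid_pr = (2 / 5) * 100          # denominator: every row
percent_valid_candidates = (3 / 4) * 100  # denominator: successful runs only
assert (percent_valid_pr, percent_valid_candidates) == (40.0, 75.0)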
@@ -81,9 +70,7 @@ def calculate_performance(df: DataFrame, perf_threshold: float = 0.05) -> Dict[str, Any]:
    def calculate_time_saved_for_row(row: pd.Series):
        if row["optimized_runtime"] is not None and row["is_correct"] is not None:
            correct_runtimes = [
                runtime
                for opt_id, runtime in row["optimized_runtime"].items()
                if row["is_correct"].get(opt_id)
                runtime for opt_id, runtime in row["optimized_runtime"].items() if row["is_correct"].get(opt_id)
            ]
        else:
            correct_runtimes = []

@@ -93,23 +80,11 @@ def calculate_performance(df: DataFrame, perf_threshold: float = 0.05) -> Dict[str, Any]:

    # (3) The average time saved in a PR given that a valid candidate was found above the perf threshold.
    pr_time_saved = (
        valid_candidates_above_thres.apply(
            lambda row: calculate_time_saved_for_row(row),
            axis=1,
        )
        .dropna()
        .mean()
        valid_candidates_above_thres.apply(lambda row: calculate_time_saved_for_row(row), axis=1).dropna().mean()
    )

    # (4) Calculate the mean average time saved for all the valid candidates
    all_candidates_time_saved = (
        df.apply(
            lambda row: calculate_time_saved_for_row(row),
            axis=1,
        )
        .dropna()
        .mean()
    )
    all_candidates_time_saved = df.apply(lambda row: calculate_time_saved_for_row(row), axis=1).dropna().mean()

    return {
        "average_percentage_gain_pr": average_percentage_gain_pr,
@@ -132,7 +107,7 @@ def calculate_coverage(df: DataFrame) -> Dict[str, Any]:
        return successful_optimization_runs / total_runs * 100 if total_runs > 0 else 0.0

    df["percent_successful_optimization_runs"] = df["optimized_runtime"].apply(
        calculate_percent_optimization_successful_runs,
        calculate_percent_optimization_successful_runs
    )

    total_optimizations = sum(len(runs) for runs in successful_runs["optimized_runtime"] if runs is not None)
@@ -158,15 +133,9 @@ def calculate_coverage(df: DataFrame) -> Dict[str, Any]:


def paired_comparison_coverage(
    df: DataFrame,
    model_a_suffix: str = "EXP0",
    model_b_suffix: str = "EXP1",
    df: DataFrame, model_a_suffix: str = "EXP0", model_b_suffix: str = "EXP1"
) -> Dict[str, Any]:
    paired_coverage_results = {
        "model_a_more_successful": 0,
        "equal_successful": 0,
        "model_b_more_successful": 0,
    }
    paired_coverage_results = {"model_a_more_successful": 0, "equal_successful": 0, "model_b_more_successful": 0}
    grouped = df.groupby(df["trace_id"].str[:-4])
    for _, group in grouped:
        if len(group) == 2:

@@ -176,17 +145,13 @@ def paired_comparison_coverage(
                model_a_success_count = 0
            else:
                model_a_success_count = sum(
                    1
                    for runtime in model_a_row["optimized_runtime"].values[0].values()
                    if runtime is not None
                    1 for runtime in model_a_row["optimized_runtime"].values[0].values() if runtime is not None
                )
            if model_b_row["optimized_runtime"].values[0] is None:
                model_b_success_count = 0
            else:
                model_b_success_count = sum(
                    1
                    for runtime in model_b_row["optimized_runtime"].values[0].values()
                    if runtime is not None
                    1 for runtime in model_b_row["optimized_runtime"].values[0].values() if runtime is not None
                )

            if model_a_success_count > model_b_success_count:
@@ -201,16 +166,10 @@ def paired_comparison_coverage(


def paired_comparison_validity(
    df: DataFrame,
    model_a_suffix: str = "EXP0",
    model_b_suffix: str = "EXP1",
    df: DataFrame, model_a_suffix: str = "EXP0", model_b_suffix: str = "EXP1"
) -> Dict[str, Any]:
    # Paired - Calculate the percentage of runs where model A generated more, equal, or less valid candidates than model B
    paired_validity_results = {
        "model_a_more_valid": 0,
        "equal_valid": 0,
        "model_b_more_valid": 0,
    }
    paired_validity_results = {"model_a_more_valid": 0, "equal_valid": 0, "model_b_more_valid": 0}
    grouped = df.groupby(df["trace_id"].str[:-4])
    for _, group in grouped:
        if len(group) == 2:
@@ -235,15 +194,9 @@ def paired_comparison_validity(


def paired_comparison_performance(
    df: DataFrame,
    model_a_suffix: str = "EXP0",
    model_b_suffix: str = "EXP1",
    df: DataFrame, model_a_suffix: str = "EXP0", model_b_suffix: str = "EXP1"
) -> Dict[str, Any]:
    paired_results = {
        "model_a_better": 0,
        "equal": 0,
        "model_b_better": 0,
    }
    paired_results = {"model_a_better": 0, "equal": 0, "model_b_better": 0}

    # Group by the trace_id without the suffix
    grouped = df.groupby(df["trace_id"].str[:-4])
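All three paired comparisons rely on the same grouping trick: trace ids end in "EXP0"/"EXP1", so stripping the last four characters pairs the two runs of one trace. A toy demonstration (the data is made up):

import pandas as pd

df = pd.DataFrame({"trace_id": ["abc123EXP0", "abc123EXP1", "xyz789EXP0"]})
for key, group in df.groupby(df["trace_id"].str[:-4]):
    print(key, len(group))  # prints "abc123 2" then "xyz789 1"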
@@ -274,8 +227,7 @@ def paired_comparison_performance(
def augment_with_best_correct_speedup_ratio(df: DataFrame) -> DataFrame:
    # Extract the best speedup ratio from the speedup_ratio dictionary, accounting for empty dictionaries
    def get_best_correct_speedup_ratio(
        speedup_ratios: Dict[str, float],
        is_correct: Dict[str, bool],
        speedup_ratios: Dict[str, float], is_correct: Dict[str, bool]
    ) -> Optional[float]:
        correct_speedup_ratios = (
            {uuid: ratio for uuid, ratio in speedup_ratios.items() if is_correct.get(uuid)}

@@ -288,10 +240,7 @@ def augment_with_best_correct_speedup_ratio(df: DataFrame) -> DataFrame:
            return None

    df["best_correct_speedup_ratio"] = df.apply(
        lambda row: get_best_correct_speedup_ratio(
            row["speedup_ratio"],
            row["is_correct"],
        )
        lambda row: get_best_correct_speedup_ratio(row["speedup_ratio"], row["is_correct"])
        if row["speedup_ratio"] is not None
        else None,
        axis=1,
@@ -324,17 +273,9 @@ def main() -> None:
    # Combine metrics into a DataFrame
    metrics_df = pd.DataFrame(
        {
            "EXP0": {
                **exp0_performance_metrics,
                **exp0_validity_metrics,
                **exp0_coverage_metrics,
            },
            "EXP1": {
                **exp1_performance_metrics,
                **exp1_validity_metrics,
                **exp1_coverage_metrics,
            },
        },
            "EXP0": {**exp0_performance_metrics, **exp0_validity_metrics, **exp0_coverage_metrics},
            "EXP1": {**exp1_performance_metrics, **exp1_validity_metrics, **exp1_coverage_metrics},
        }
    )  # Transpose to have experiments as rows and metrics as columns

    # Output the combined metrics DataFrame
@@ -45,10 +45,7 @@ def main(trace_id: str) -> None:
    write_to_file("original_code.py", original_code)

    # Write each optimization candidate to its own file
    for idx, (opt_id, optimization) in enumerate(
        extract_json_values(optimizations_post).items(),
        start=1,
    ):
    for idx, (opt_id, optimization) in enumerate(extract_json_values(optimizations_post).items(), start=1):
        filename = f"optimization_candidate_{idx}.py"
        explanation = explanations_post.get(opt_id, "")
        speedup = speedup_ratio.get(opt_id)
@@ -61,21 +58,13 @@ def main(trace_id: str) -> None:
    valid_speedup_values = [v for v in speedup_ratio.values() if v is not None]
    best_speedup = max(valid_speedup_values, default=None) if valid_speedup_values else None
    if best_speedup is not None:
        best_optimization_id = next(
            (id for id, speedup in speedup_ratio.items() if speedup == best_speedup),
            None,
        )
        best_optimization_id = next((id for id, speedup in speedup_ratio.items() if speedup == best_speedup), None)
        best_optimization = optimizations_post.get(best_optimization_id)
        best_explanation = explanations_post.get(best_optimization_id, "")
        best_speedup_comment = f"Speedup: {best_speedup}"
        best_content_with_comment = (
            f'"""{best_explanation}\n\n{best_speedup_comment}"""\n\n{best_optimization}'
        )
        best_content_with_comment = f'"""{best_explanation}\n\n{best_speedup_comment}"""\n\n{best_optimization}'
        if best_optimization and best_explanation is not None:
            write_to_file(
                "best_optimization_candidate.py",
                best_content_with_comment,
            )
            write_to_file("best_optimization_candidate.py", best_content_with_comment)
        else:
            print("No speedup ratio found")
@@ -1,11 +1,5 @@
mock_dataframe = {
    "original_runtime": [
        1.0,
        2.0,
        3.0,
        4.0,
        8333.00000,
    ],
    "original_runtime": [1.0, 2.0, 3.0, 4.0, 8333.00000],
    "optimized_runtime": [
        {
            "0a631960-6c33-4de4-8d38-b31e888918c7": 59209,

@@ -51,7 +45,7 @@ mock_dataframe = {
            "573c9f22-ae68-4e62-929d-b70d22da0ec0": -0.10714668381013608,
            "6fa4faaf-5d67-4ec6-bf55-5a84dfa23a11": 0.1695438596491228,
            "a06d1ad8-cc41-4b3b-a1a6-4566697d7b10": 0.16285235835891712,
        },
    }
    ],
    "best_correct_speedup_ratio": [0.02112, None, 1.81043, None],
    "is_correct": [
@@ -398,22 +398,22 @@ pie4perf_sample_dataframe_dict = {
    },
    "generated_test": {
        0: [
'# imports\nimport pytest # used for our unit tests\nfrom pie_test_set.p02847 import problem_p02847\n\n# unit tests\n\n# Test normal cases with valid inputs for each day of the week\n@pytest.mark.parametrize("input_data, expected", [\n ("\'MON\'", 6),\n ("\'TUE\'", 5),\n ("\'WED\'", 4),\n ("\'THU\'", 3),\n ("\'FRI\'", 2),\n ("\'SAT\'", 1),\n ("\'SUN\'", 7),\n])\ndef test_normal_cases(input_data, expected):\n assert problem_p02847(input_data) == expected\n\n# Test edge cases with invalid day names\n@pytest.mark.parametrize("input_data", [\n "\'FUNDAY\'",\n "\'MOON\'",\n "123",\n "[\'MON\']",\n "None",\n])\ndef test_edge_cases_with_invalid_day_names(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test misuse cases with potentially malicious code\n@pytest.mark.parametrize("input_data", [\n "__import__(\'os\').system(\'echo hello\')",\n "exit()",\n])\ndef test_misuse_cases_with_malicious_code(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with lowercase and mixed case inputs\n@pytest.mark.parametrize("input_data", [\n "\'mon\'",\n "\'Sun\'",\n "\'Mon.\'",\n "\'Monday\'",\n])\ndef test_edge_cases_with_case_sensitivity(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with additional text\n@pytest.mark.parametrize("input_data", [\n "\'Today is MON\'",\n "\'SUNday\'",\n])\ndef test_edge_cases_with_additional_text(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with non-standard capitalization and characters\n@pytest.mark.parametrize("input_data", [\n "\'monDAY\'",\n "\'M@N\'",\n])\ndef test_edge_cases_with_non_standard_characters(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with whitespace and control characters\n@pytest.mark.parametrize("input_data", [\n "\'\\t\\nMON\\n\\t\'",\n "\'MON\\u200b\'",\n])\ndef test_edge_cases_with_whitespace_and_control_characters(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with similar sounding or looking names\n@pytest.mark.parametrize("input_data", [\n "\'MOAN\'",\n "\'SUNN\'",\n])\ndef test_edge_cases_with_similar_names(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with empty and null inputs\n@pytest.mark.parametrize("input_data", [\n "\'\'",\n "None",\n])\ndef test_edge_cases_with_empty_and_null_inputs(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with valid Python expressions\n@pytest.mark.parametrize("input_data", [\n "\'6-1\'",\n "\'[x for x in range(5)]\'",\n])\ndef test_edge_cases_with_valid_python_expressions(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with input as Python code\n@pytest.mark.parametrize("input_data", [\n "\'os.system(\\\'echo hello\\\')\'",\n "\'(__import__(\\\'sys\\\').exit())\'",\n])\ndef test_edge_cases_with_python_code(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)',
'# imports\nimport pytest # used for our unit tests\nfrom pie_test_set.p02847 import problem_p02847\n\n# unit tests\n\n# Test normal cases with valid inputs for each day of the week\n@pytest.mark.parametrize("input_data, expected", [\n ("\'MON\'", 6),\n ("\'TUE\'", 5),\n ("\'WED\'", 4),\n ("\'THU\'", 3),\n ("\'FRI\'", 2),\n ("\'SAT\'", 1),\n ("\'SUN\'", 7),\n])\ndef test_normal_cases(input_data, expected):\n assert problem_p02847(input_data) == expected\n\n# Test edge cases with invalid day names\n@pytest.mark.parametrize("input_data", [\n "\'FUNDAY\'",\n "\'MOON\'",\n "123",\n "[\'MON\']",\n "None",\n])\ndef test_edge_cases_with_invalid_day_names(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test misuse cases with potentially malicious code\n@pytest.mark.parametrize("input_data", [\n "__import__(\'os\').system(\'echo hello\')",\n "exit()",\n])\ndef test_misuse_cases_with_malicious_code(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with lowercase and mixed case inputs\n@pytest.mark.parametrize("input_data", [\n "\'mon\'",\n "\'Sun\'",\n "\'Mon.\'",\n "\'Monday\'",\n])\ndef test_edge_cases_with_case_sensitivity(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with additional text\n@pytest.mark.parametrize("input_data", [\n "\'Today is MON\'",\n "\'SUNday\'",\n])\ndef test_edge_cases_with_additional_text(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with non-standard capitalization and characters\n@pytest.mark.parametrize("input_data", [\n "\'monDAY\'",\n "\'M@N\'",\n])\ndef test_edge_cases_with_non_standard_characters(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with whitespace and control characters\n@pytest.mark.parametrize("input_data", [\n "\'\\t\\nMON\\n\\t\'",\n "\'MON\\u200b\'",\n])\ndef test_edge_cases_with_whitespace_and_control_characters(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with similar sounding or looking names\n@pytest.mark.parametrize("input_data", [\n "\'MOAN\'",\n "\'SUNN\'",\n])\ndef test_edge_cases_with_similar_names(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with empty and null inputs\n@pytest.mark.parametrize("input_data", [\n "\'\'",\n "None",\n])\ndef test_edge_cases_with_empty_and_null_inputs(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with valid Python expressions\n@pytest.mark.parametrize("input_data", [\n "\'6-1\'",\n "\'[x for x in range(5)]\'",\n])\ndef test_edge_cases_with_valid_python_expressions(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with input as Python code\n@pytest.mark.parametrize("input_data", [\n "\'os.system(\\\'echo hello\\\')\'",\n "\'(__import__(\\\'sys\\\').exit())\'",\n])\ndef test_edge_cases_with_python_code(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)'
        ],
        1: [
'# imports\nimport pytest # used for our unit tests\nfrom pie_test_set.p03196 import problem_p03196\n\n# unit tests\n\n# Test normal cases\n@pytest.mark.parametrize("input_data, expected", [\n ("2 16", 4),\n ("3 27", 3),\n ("2 20", 1),\n])\ndef test_normal_cases(input_data, expected):\n assert problem_p03196(input_data) == expected\n\n# Test edge cases\n@pytest.mark.parametrize("input_data, expected", [\n ("1 100", 100),\n ("1 1", 1),\n ("41 1000000", 1),\n ("100 2", 1),\n ("2 17", 1),\n ("3 19", 1),\n])\ndef test_edge_cases(input_data, expected):\n assert problem_p03196(input_data) == expected\n\n# Test large numbers\n@pytest.mark.parametrize("input_data, expected", [\n ("2 1000000000000", 1000000),\n ("3 1000000000000", 10000),\n])\ndef test_large_numbers(input_data, expected):\n assert problem_p03196(input_data) == expected\n\n# Test invalid input format\n@pytest.mark.parametrize("input_data", [\n "2.5 10",\n "2 -10",\n "2",\n "two 10",\n "2, 10",\n])\ndef test_invalid_input(input_data):\n with pytest.raises(ValueError):\n problem_p03196(input_data)\n\n# Test boundary cases\n@pytest.mark.parametrize("input_data, expected", [\n ("40 1024", 2),\n ("2 1024", 32),\n])\ndef test_boundary_cases(input_data, expected):\n assert problem_p03196(input_data) == expected\n\n# Test rare or unexpected edge cases\n@pytest.mark.parametrize("input_data, expected", [\n ("2 " + str(2**63 - 1), 1),\n ("3 " + str(2**63 - 1), 1),\n ("2 1", 1),\n ("10 1", 1),\n ("37 137**37", 137),\n ("37 138", 1),\n ("3 8", 2),\n ("5 32", 2),\n ("2 6", 2),\n ("2 120", 4),\n ("2 49", 7),\n ("2 77", 1),\n])\ndef test_unexpected_edge_cases(input_data, expected):\n assert problem_p03196(input_data) == expected\n\n# Test zero or negative n values\n@pytest.mark.parametrize("input_data", [\n "0 10",\n "-2 16",\n])\ndef test_zero_or_negative_n(input_data):\n with pytest.raises(ValueError):\n problem_p03196(input_data)',
'# imports\nimport pytest # used for our unit tests\nfrom pie_test_set.p03196 import problem_p03196\n\n# unit tests\n\n# Test normal cases\n@pytest.mark.parametrize("input_data, expected", [\n ("2 16", 4),\n ("3 27", 3),\n ("2 20", 1),\n])\ndef test_normal_cases(input_data, expected):\n assert problem_p03196(input_data) == expected\n\n# Test edge cases\n@pytest.mark.parametrize("input_data, expected", [\n ("1 100", 100),\n ("1 1", 1),\n ("41 1000000", 1),\n ("100 2", 1),\n ("2 17", 1),\n ("3 19", 1),\n])\ndef test_edge_cases(input_data, expected):\n assert problem_p03196(input_data) == expected\n\n# Test large numbers\n@pytest.mark.parametrize("input_data, expected", [\n ("2 1000000000000", 1000000),\n ("3 1000000000000", 10000),\n])\ndef test_large_numbers(input_data, expected):\n assert problem_p03196(input_data) == expected\n\n# Test invalid input format\n@pytest.mark.parametrize("input_data", [\n "2.5 10",\n "2 -10",\n "2",\n "two 10",\n "2, 10",\n])\ndef test_invalid_input(input_data):\n with pytest.raises(ValueError):\n problem_p03196(input_data)\n\n# Test boundary cases\n@pytest.mark.parametrize("input_data, expected", [\n ("40 1024", 2),\n ("2 1024", 32),\n])\ndef test_boundary_cases(input_data, expected):\n assert problem_p03196(input_data) == expected\n\n# Test rare or unexpected edge cases\n@pytest.mark.parametrize("input_data, expected", [\n ("2 " + str(2**63 - 1), 1),\n ("3 " + str(2**63 - 1), 1),\n ("2 1", 1),\n ("10 1", 1),\n ("37 137**37", 137),\n ("37 138", 1),\n ("3 8", 2),\n ("5 32", 2),\n ("2 6", 2),\n ("2 120", 4),\n ("2 49", 7),\n ("2 77", 1),\n])\ndef test_unexpected_edge_cases(input_data, expected):\n assert problem_p03196(input_data) == expected\n\n# Test zero or negative n values\n@pytest.mark.parametrize("input_data", [\n "0 10",\n "-2 16",\n])\ndef test_zero_or_negative_n(input_data):\n with pytest.raises(ValueError):\n problem_p03196(input_data)'
        ],
        2: [
'# imports\nimport pytest # used for our unit tests\nfrom pie_test_set.p03861 import problem_p03861\n\n# unit tests\n\n# Test typical cases with different ranges and divisors\ndef test_typical_cases():\n assert problem_p03861("1 10 2") == 5\n assert problem_p03861("3 15 3") == 5\n assert problem_p03861("10 100 10") == 10\n\n# Test edge cases where a and b are the same\ndef test_edge_cases_same_a_b():\n assert problem_p03861("5 5 1") == 1\n assert problem_p03861("7 7 2") == 0\n assert problem_p03861("10 10 10") == 1\n\n# Test edge cases where a is a multiple of x\ndef test_edge_cases_a_multiple_of_x():\n assert problem_p03861("6 12 6") == 2\n assert problem_p03861("10 20 10") == 2\n\n# Test edge cases where b is a multiple of x\ndef test_edge_cases_b_multiple_of_x():\n assert problem_p03861("1 10 5") == 2\n assert problem_p03861("3 21 7") == 3\n\n# Test cases where a is greater than b\ndef test_cases_a_greater_than_b():\n assert problem_p03861("10 5 1") == 0\n assert problem_p03861("20 10 5") == 0\n\n# Test cases where x is larger than both a and b\ndef test_cases_x_larger_than_a_b():\n assert problem_p03861("1 5 10") == 0\n assert problem_p03861("2 8 20") == 0\n\n# Test cases with negative numbers and zero\ndef test_cases_negative_numbers_and_zero():\n assert problem_p03861("0 0 1") == 1\n assert problem_p03861("-5 5 5") == 3\n assert problem_p03861("-10 -1 2") == 5\n\n# Test cases with large numbers\ndef test_cases_large_numbers():\n assert problem_p03861("1000000 2000000 100000") == 20\n # This test may need to be adjusted for the expected number of multiples\n assert problem_p03861("123456789 987654321 12345") == 987654321 // 12345 - (123456789 - 1) // 12345\n\n# Test invalid input cases\ndef test_invalid_input_cases():\n with pytest.raises(ValueError):\n problem_p03861("")\n with pytest.raises(ValueError):\n problem_p03861("1 10")\n with pytest.raises(ValueError):\n problem_p03861("1 10 two")\n\n# Test rare or unexpected edge cases\ndef test_rare_edge_cases():\n with pytest.raises(ZeroDivisionError):\n problem_p03861("1 10 0")\n with pytest.raises(ZeroDivisionError):\n problem_p03861("5 5 0")\n assert problem_p03861("0 0 5") == 1\n assert problem_p03861("-10 -1 2") == 5\n assert problem_p03861("-15 -5 3") == 4\n assert problem_p03861("-10 -1 -2") == 5\n assert problem_p03861("-20 -10 -5") == 3\n assert problem_p03861("1 10 -1") == 10\n assert problem_p03861("5 15 -3") == 4\n assert problem_p03861("1000000000 2000000000 100000000") == 20\n assert problem_p03861("2147483647 2147483647 1") == 1\n assert problem_p03861("-2147483648 -2147483648 1") == 1\n assert problem_p03861("2147483647 2147483647 2147483647") == 1\n assert problem_p03861("7 7 7") == 1\n assert problem_p03861("123456 123456 123456") == 1\n assert problem_p03861("0 0 -5") == 1',
'# imports\nimport pytest # used for our unit tests\nfrom pie_test_set.p03861 import problem_p03861\n\n# unit tests\n\n# Test typical cases with different ranges and divisors\ndef test_typical_cases():\n assert problem_p03861("1 10 2") == 5\n assert problem_p03861("3 15 3") == 5\n assert problem_p03861("10 100 10") == 10\n\n# Test edge cases where a and b are the same\ndef test_edge_cases_same_a_b():\n assert problem_p03861("5 5 1") == 1\n assert problem_p03861("7 7 2") == 0\n assert problem_p03861("10 10 10") == 1\n\n# Test edge cases where a is a multiple of x\ndef test_edge_cases_a_multiple_of_x():\n assert problem_p03861("6 12 6") == 2\n assert problem_p03861("10 20 10") == 2\n\n# Test edge cases where b is a multiple of x\ndef test_edge_cases_b_multiple_of_x():\n assert problem_p03861("1 10 5") == 2\n assert problem_p03861("3 21 7") == 3\n\n# Test cases where a is greater than b\ndef test_cases_a_greater_than_b():\n assert problem_p03861("10 5 1") == 0\n assert problem_p03861("20 10 5") == 0\n\n# Test cases where x is larger than both a and b\ndef test_cases_x_larger_than_a_b():\n assert problem_p03861("1 5 10") == 0\n assert problem_p03861("2 8 20") == 0\n\n# Test cases with negative numbers and zero\ndef test_cases_negative_numbers_and_zero():\n assert problem_p03861("0 0 1") == 1\n assert problem_p03861("-5 5 5") == 3\n assert problem_p03861("-10 -1 2") == 5\n\n# Test cases with large numbers\ndef test_cases_large_numbers():\n assert problem_p03861("1000000 2000000 100000") == 20\n # This test may need to be adjusted for the expected number of multiples\n assert problem_p03861("123456789 987654321 12345") == 987654321 // 12345 - (123456789 - 1) // 12345\n\n# Test invalid input cases\ndef test_invalid_input_cases():\n with pytest.raises(ValueError):\n problem_p03861("")\n with pytest.raises(ValueError):\n problem_p03861("1 10")\n with pytest.raises(ValueError):\n problem_p03861("1 10 two")\n\n# Test rare or unexpected edge cases\ndef test_rare_edge_cases():\n with pytest.raises(ZeroDivisionError):\n problem_p03861("1 10 0")\n with pytest.raises(ZeroDivisionError):\n problem_p03861("5 5 0")\n assert problem_p03861("0 0 5") == 1\n assert problem_p03861("-10 -1 2") == 5\n assert problem_p03861("-15 -5 3") == 4\n assert problem_p03861("-10 -1 -2") == 5\n assert problem_p03861("-20 -10 -5") == 3\n assert problem_p03861("1 10 -1") == 10\n assert problem_p03861("5 15 -3") == 4\n assert problem_p03861("1000000000 2000000000 100000000") == 20\n assert problem_p03861("2147483647 2147483647 1") == 1\n assert problem_p03861("-2147483648 -2147483648 1") == 1\n assert problem_p03861("2147483647 2147483647 2147483647") == 1\n assert problem_p03861("7 7 7") == 1\n assert problem_p03861("123456 123456 123456") == 1\n assert problem_p03861("0 0 -5") == 1'
        ],
        3: [
"# imports\nimport pytest # used for our unit tests\nfrom pie_test_set.p02910 import problem_p02910\n\n# unit tests\n\n# Test normal cases\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"['S', 'A', 'M', 'P', 'L', 'E']\", \"Yes\"),\n (\"['T', 'E', 'S', 'T', 'I', 'N', 'G']\", \"Yes\"),\n (\"['L', 'A', 'M', 'P', 'L', 'E']\", \"No\"),\n (\"['T', 'E', 'S', 'R', 'I', 'N', 'G']\", \"No\")\n])\ndef test_normal_cases(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test edge cases\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"[]\", \"Yes\"),\n (\"['L']\", \"No\"),\n (\"['R']\", \"Yes\"),\n (\"['A', 'A', 'A', 'A']\", \"Yes\"),\n (\"['L', 'R', 'L', 'R']\", \"No\")\n])\ndef test_edge_cases(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test special characters and non-string elements\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"['$', '2', '#', '4', '@', '6']\", \"Yes\"),\n (\"[1, 'R', None, 'L', True, 'F']\", \"Yes\")\n])\ndef test_special_characters_and_non_string_elements(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test large inputs\ndef test_large_input():\n large_input = \"['N', 'E', 'V', 'E', 'R', 'L', 'O', 'O', 'K', 'B', 'A', 'C', 'K'] * 1000\"\n assert problem_p02910(large_input) == \"Yes\"\n\n# Test invalid inputs\n@pytest.mark.parametrize(\"input_data\", [\n \"'This is not a list'\",\n \"[['A', 'B'], ['C', 'D']]\",\n \"__import__('os').system('rm -rf /')\"\n])\ndef test_invalid_inputs(input_data):\n with pytest.raises(SyntaxError):\n problem_p02910(input_data)\n\n# Test unicode characters\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"['😀', 'L', '🎉', 'R', '❤️', '🚀']\", \"No\"),\n (\"['こんにちは', '世界', '안녕하세요', '세계']\", \"Yes\")\n])\ndef test_unicode_characters(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test escaped characters\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"['\\\\n', 'R', '\\\\t', 'L']\", \"Yes\"),\n (\"['A', '\\\\n', 'B', '\\\\r']\", \"Yes\")\n])\ndef test_escaped_characters(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test boolean values\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"[True, False, 'T', 'F']\", \"Yes\"),\n (\"[False, True, 'L', 'R']\", \"No\")\n])\ndef test_boolean_values(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test case sensitivity\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"['l', 'R', 'm', 'r', 'n', 'L']\", \"Yes\"),\n (\"['L', 'r', 'M', 'R', 'N', 'l']\", \"No\")\n])\ndef test_case_sensitivity(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test input with only one character type\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"['L', 'L', 'L', 'L']\", \"No\"),\n (\"['R', 'R', 'R', 'R']\", \"Yes\")\n])\ndef test_input_with_one_character_type(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test input with data structures other than lists\n@pytest.mark.parametrize(\"input_data\", [\n \"('A', 'B', 'C', 'D')\",\n \"{1, 2, 3, 4}\"\n])\ndef test_input_with_other_data_structures(input_data):\n with pytest.raises(SyntaxError):\n problem_p02910(input_data)\n\n# Test input with nested quotes\ndef test_input_with_nested_quotes():\n input_data = \"'[\\\\'A\\\\', \\\\'B\\\\', \\\\'C\\\\']'\"\n assert problem_p02910(input_data) == \"Yes\"\n\n# Test input with dictionary\ndef test_input_with_dictionary():\n input_data = \"{'key1': 'value1', 
'key2': 'value2'}\"\n with pytest.raises(SyntaxError):\n problem_p02910(input_data)",
"# imports\nimport pytest # used for our unit tests\nfrom pie_test_set.p02910 import problem_p02910\n\n# unit tests\n\n# Test normal cases\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"['S', 'A', 'M', 'P', 'L', 'E']\", \"Yes\"),\n (\"['T', 'E', 'S', 'T', 'I', 'N', 'G']\", \"Yes\"),\n (\"['L', 'A', 'M', 'P', 'L', 'E']\", \"No\"),\n (\"['T', 'E', 'S', 'R', 'I', 'N', 'G']\", \"No\")\n])\ndef test_normal_cases(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test edge cases\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"[]\", \"Yes\"),\n (\"['L']\", \"No\"),\n (\"['R']\", \"Yes\"),\n (\"['A', 'A', 'A', 'A']\", \"Yes\"),\n (\"['L', 'R', 'L', 'R']\", \"No\")\n])\ndef test_edge_cases(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test special characters and non-string elements\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"['$', '2', '#', '4', '@', '6']\", \"Yes\"),\n (\"[1, 'R', None, 'L', True, 'F']\", \"Yes\")\n])\ndef test_special_characters_and_non_string_elements(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test large inputs\ndef test_large_input():\n large_input = \"['N', 'E', 'V', 'E', 'R', 'L', 'O', 'O', 'K', 'B', 'A', 'C', 'K'] * 1000\"\n assert problem_p02910(large_input) == \"Yes\"\n\n# Test invalid inputs\n@pytest.mark.parametrize(\"input_data\", [\n \"'This is not a list'\",\n \"[['A', 'B'], ['C', 'D']]\",\n \"__import__('os').system('rm -rf /')\"\n])\ndef test_invalid_inputs(input_data):\n with pytest.raises(SyntaxError):\n problem_p02910(input_data)\n\n# Test unicode characters\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"['😀', 'L', '🎉', 'R', '❤️', '🚀']\", \"No\"),\n (\"['こんにちは', '世界', '안녕하세요', '세계']\", \"Yes\")\n])\ndef test_unicode_characters(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test escaped characters\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"['\\\\n', 'R', '\\\\t', 'L']\", \"Yes\"),\n (\"['A', '\\\\n', 'B', '\\\\r']\", \"Yes\")\n])\ndef test_escaped_characters(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test boolean values\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"[True, False, 'T', 'F']\", \"Yes\"),\n (\"[False, True, 'L', 'R']\", \"No\")\n])\ndef test_boolean_values(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test case sensitivity\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"['l', 'R', 'm', 'r', 'n', 'L']\", \"Yes\"),\n (\"['L', 'r', 'M', 'R', 'N', 'l']\", \"No\")\n])\ndef test_case_sensitivity(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test input with only one character type\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"['L', 'L', 'L', 'L']\", \"No\"),\n (\"['R', 'R', 'R', 'R']\", \"Yes\")\n])\ndef test_input_with_one_character_type(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test input with data structures other than lists\n@pytest.mark.parametrize(\"input_data\", [\n \"('A', 'B', 'C', 'D')\",\n \"{1, 2, 3, 4}\"\n])\ndef test_input_with_other_data_structures(input_data):\n with pytest.raises(SyntaxError):\n problem_p02910(input_data)\n\n# Test input with nested quotes\ndef test_input_with_nested_quotes():\n input_data = \"'[\\\\'A\\\\', \\\\'B\\\\', \\\\'C\\\\']'\"\n assert problem_p02910(input_data) == \"Yes\"\n\n# Test input with dictionary\ndef test_input_with_dictionary():\n input_data = \"{'key1': 'value1', 
'key2': 'value2'}\"\n with pytest.raises(SyntaxError):\n problem_p02910(input_data)"
        ],
        4: [
'# imports\nimport pytest # used for our unit tests\nfrom pie_test_set.p03965 import problem_p03965\n\n# unit tests\n\n# Test empty input\ndef test_empty_input():\n assert problem_p03965("") == 0\n\n# Test input with only "g" characters\n@pytest.mark.parametrize("input_data", ["g", "ggggg"])\ndef test_only_g_characters(input_data):\n assert problem_p03965(input_data) == 0\n\n# Test input with no "g" characters\n@pytest.mark.parametrize("input_data", ["abcdef", "pppp"])\ndef test_no_g_characters(input_data):\n assert problem_p03965(input_data) == -len(input_data)\n\n# Test input with alternating "g" and other characters\n@pytest.mark.parametrize("input_data", ["gpgpg", "gagbgc"])\ndef test_alternating_g_characters(input_data):\n assert problem_p03965(input_data) == 0\n\n# Test input with consecutive "g" characters surrounded by other characters\n@pytest.mark.parametrize("input_data", ["agggbc", "xggggy"])\ndef test_consecutive_g_characters(input_data):\n assert problem_p03965(input_data) == 1\n\n# Test input with leading and trailing whitespace\n@pytest.mark.parametrize("input_data, expected", [(" gggg ", 0), ("\\tgggg\\n", 0)])\ndef test_whitespace(input_data, expected):\n assert problem_p03965(input_data) == expected\n\n# Test input with only whitespace characters\n@pytest.mark.parametrize("input_data", [" ", "\\t\\n\\r"])\ndef test_only_whitespace(input_data):\n assert problem_p03965(input_data) == -len(input_data.rstrip())\n\n# Test input with "g" at the start, middle, and end\n@pytest.mark.parametrize("input_data", ["gabcgdefg", "gxyzg"])\ndef test_g_start_middle_end(input_data):\n assert problem_p03965(input_data) == 1\n\n# Test input with repeated non-"g" characters\n@pytest.mark.parametrize("input_data", ["ppppp", "xxxxx"])\ndef test_repeated_non_g_characters(input_data):\n assert problem_p03965(input_data) == -len(input_data)\n\n# Test input with mixed "g" and non-"g" characters\n@pytest.mark.parametrize("input_data", ["gpapbpcg", "ggppggpp"])\ndef test_mixed_g_non_g_characters(input_data):\n assert problem_p03965(input_data) == 0\n\n# Test long strings\ndef test_long_string():\n long_string = "g" * 10000\n assert problem_p03965(long_string) == 0\n\n# Test special characters and numbers\n@pytest.mark.parametrize("input_data", ["g#%&g123", "12g34g!@"])\ndef test_special_characters_and_numbers(input_data):\n assert problem_p03965(input_data) == 0\n\n# Test Unicode characters\n@pytest.mark.parametrize("input_data", ["g日本語g", "🙂g🙃g"])\ndef test_unicode_characters(input_data):\n assert problem_p03965(input_data) == 0\n\n# Test case sensitivity\n@pytest.mark.parametrize("input_data", ["GgGg", "gGgG"])\ndef test_case_sensitivity(input_data):\n assert problem_p03965(input_data) == 0',
'# imports\nimport pytest # used for our unit tests\nfrom pie_test_set.p03965 import problem_p03965\n\n# unit tests\n\n# Test empty input\ndef test_empty_input():\n assert problem_p03965("") == 0\n\n# Test input with only "g" characters\n@pytest.mark.parametrize("input_data", ["g", "ggggg"])\ndef test_only_g_characters(input_data):\n assert problem_p03965(input_data) == 0\n\n# Test input with no "g" characters\n@pytest.mark.parametrize("input_data", ["abcdef", "pppp"])\ndef test_no_g_characters(input_data):\n assert problem_p03965(input_data) == -len(input_data)\n\n# Test input with alternating "g" and other characters\n@pytest.mark.parametrize("input_data", ["gpgpg", "gagbgc"])\ndef test_alternating_g_characters(input_data):\n assert problem_p03965(input_data) == 0\n\n# Test input with consecutive "g" characters surrounded by other characters\n@pytest.mark.parametrize("input_data", ["agggbc", "xggggy"])\ndef test_consecutive_g_characters(input_data):\n assert problem_p03965(input_data) == 1\n\n# Test input with leading and trailing whitespace\n@pytest.mark.parametrize("input_data, expected", [(" gggg ", 0), ("\\tgggg\\n", 0)])\ndef test_whitespace(input_data, expected):\n assert problem_p03965(input_data) == expected\n\n# Test input with only whitespace characters\n@pytest.mark.parametrize("input_data", [" ", "\\t\\n\\r"])\ndef test_only_whitespace(input_data):\n assert problem_p03965(input_data) == -len(input_data.rstrip())\n\n# Test input with "g" at the start, middle, and end\n@pytest.mark.parametrize("input_data", ["gabcgdefg", "gxyzg"])\ndef test_g_start_middle_end(input_data):\n assert problem_p03965(input_data) == 1\n\n# Test input with repeated non-"g" characters\n@pytest.mark.parametrize("input_data", ["ppppp", "xxxxx"])\ndef test_repeated_non_g_characters(input_data):\n assert problem_p03965(input_data) == -len(input_data)\n\n# Test input with mixed "g" and non-"g" characters\n@pytest.mark.parametrize("input_data", ["gpapbpcg", "ggppggpp"])\ndef test_mixed_g_non_g_characters(input_data):\n assert problem_p03965(input_data) == 0\n\n# Test long strings\ndef test_long_string():\n long_string = "g" * 10000\n assert problem_p03965(long_string) == 0\n\n# Test special characters and numbers\n@pytest.mark.parametrize("input_data", ["g#%&g123", "12g34g!@"])\ndef test_special_characters_and_numbers(input_data):\n assert problem_p03965(input_data) == 0\n\n# Test Unicode characters\n@pytest.mark.parametrize("input_data", ["g日本語g", "🙂g🙃g"])\ndef test_unicode_characters(input_data):\n assert problem_p03965(input_data) == 0\n\n# Test case sensitivity\n@pytest.mark.parametrize("input_data", ["GgGg", "gGgG"])\ndef test_case_sensitivity(input_data):\n assert problem_p03965(input_data) == 0'
        ],
        5: [
'# imports\nimport pytest # used for our unit tests\nfrom pie_test_set.p02553 import problem_p02553\n\n# unit tests\n\n@pytest.mark.parametrize("input_data, expected", [\n ("1 2 3 4", 8),\n ("-1 -2 -3 -4", -3),\n ("1 -2 3 -4", 12),\n ("0 2 3 4", 8),\n ("1 0 3 4", 4),\n ("0 0 0 0", 0),\n ("-1 2 -3 4", 0),\n ("1 -2 3 -4", 0),\n ("-1 -2 3 4", -3),\n ("1 2 -3 -4", -6),\n ("1000000000 1000000000 1000000000 1000000000", 1000000000000000000),\n ("-1000000000 -1000000000 -1000000000 -1000000000", -1000000000000000000),\n ("2147483647 -2147483648 -2147483648 2147483647", 4611686014132420609),\n ("1000000000 1 1000000000 1", 1000000000),\n ("-1000000000 1 -1000000000 1", 1),\n ("-1 1 -1 1", 0),\n ("1 -1 1 -1", 0),\n ("1e3 2e3 3e3 4e3", 8000000),\n ("-1e3 -2e3 3e3 4e3", 3000000),\n])\ndef test_problem_p02553_normal_and_edge_cases(input_data, expected):\n # Test normal cases and various edge cases\n assert problem_p02553(input_data) == expected\n\n@pytest.mark.parametrize("input_data", [\n "1,2,3,4",\n "1 2 3",\n "one two three four",\n "\\t1\\t2\\t3\\t4",\n "1\\n2\\n3\\n4"\n])\ndef test_problem_p02553_invalid_input(input_data):\n # Test invalid inputs that should raise a ValueError\n with pytest.raises(ValueError):\n problem_p02553(input_data)',
|
||||
'# imports\nimport pytest # used for our unit tests\nfrom pie_test_set.p02553 import problem_p02553\n\n# unit tests\n\n@pytest.mark.parametrize("input_data, expected", [\n ("1 2 3 4", 8),\n ("-1 -2 -3 -4", -3),\n ("1 -2 3 -4", 12),\n ("0 2 3 4", 8),\n ("1 0 3 4", 4),\n ("0 0 0 0", 0),\n ("-1 2 -3 4", 0),\n ("1 -2 3 -4", 0),\n ("-1 -2 3 4", -3),\n ("1 2 -3 -4", -6),\n ("1000000000 1000000000 1000000000 1000000000", 1000000000000000000),\n ("-1000000000 -1000000000 -1000000000 -1000000000", -1000000000000000000),\n ("2147483647 -2147483648 -2147483648 2147483647", 4611686014132420609),\n ("1000000000 1 1000000000 1", 1000000000),\n ("-1000000000 1 -1000000000 1", 1),\n ("-1 1 -1 1", 0),\n ("1 -1 1 -1", 0),\n ("1e3 2e3 3e3 4e3", 8000000),\n ("-1e3 -2e3 3e3 4e3", 3000000),\n])\ndef test_problem_p02553_normal_and_edge_cases(input_data, expected):\n # Test normal cases and various edge cases\n assert problem_p02553(input_data) == expected\n\n@pytest.mark.parametrize("input_data", [\n "1,2,3,4",\n "1 2 3",\n "one two three four",\n "\\t1\\t2\\t3\\t4",\n "1\\n2\\n3\\n4"\n])\ndef test_problem_p02553_invalid_input(input_data):\n # Test invalid inputs that should raise a ValueError\n with pytest.raises(ValueError):\n problem_p02553(input_data)'
|
||||
],
|
||||
},
|
||||
"test_framework": {0: "pytest", 1: "pytest", 2: "pytest", 3: "pytest", 4: "pytest", 5: "pytest"},
|
||||
|
|
|
|||
|
|
@@ -14,12 +14,12 @@ def pickled_dataframe():
    return df


@pytest.fixture()
@pytest.fixture
def pie4perf_sample_dataframe():
    return DataFrame.from_dict(pie4perf_sample_dataframe_dict)


@pytest.fixture()
@pytest.fixture
def sample_dataframe():
    return DataFrame(mock_dataframe)

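A note on the fixture hunks above: `@pytest.fixture` and `@pytest.fixture()` are equivalent when no arguments are passed, so this is a purely stylistic cleanup (most likely ruff's flake8-pytest-style rule PT001, which prefers the bare decorator by default). A minimal sketch:

import pytest


@pytest.fixture
def bare_fixture():  # bare form, preferred when there are no arguments
    return 1


@pytest.fixture(scope="session")  # parentheses are still required to pass arguments
def configured_fixture():
    return 2
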
@@ -43,14 +43,7 @@ def test_calculate_validity_some_successful_runs(pie4perf_sample_dataframe):
def test_calculate_validity_all_successful_runs(pie4perf_sample_dataframe):
    df = pie4perf_sample_dataframe
    df["original_runtime"] = [1.0, 2.0, 2.0, 2.0, 3.0, 18.0]  # All successful runs
    df["best_correct_speedup_ratio"] = [
        0.05,
        0.06,
        0.08,
        0.09,
        0.56,
        0.17,
    ]  # All above threshold
    df["best_correct_speedup_ratio"] = [0.05, 0.06, 0.08, 0.09, 0.56, 0.17]  # All above threshold
    df["is_correct"] = [
        {"1f39ef86-5eff-4760-a262-43011492906e": True},
        {"aeeeca3b-4ccf-46eb-8dbf-c526c05fca27": True},

@@ -68,14 +61,7 @@ def test_calculate_validity_with_valid_candidates(pie4perf_sample_dataframe):
    df = pie4perf_sample_dataframe
    # Assuming some runs are successful and some candidates are valid
    df["original_runtime"] = [1.0, None, 2.0, None, 3.0, None]  # Some successful runs
    df["best_correct_speedup_ratio"] = [
        0.04,
        None,
        0.06,
        None,
        0.07,
        None,
    ]  # One below and two above the threshold
    df["best_correct_speedup_ratio"] = [0.04, None, 0.06, None, 0.07, None]  # One below and two above the threshold
    df["is_correct"] = [
        {"1f39ef86-5eff-4760-a262-43011492906e": True},
        {},

@@ -7,19 +7,20 @@
   "metadata": {
    "collapsed": true
   },
   "source": [
    ""
   ],
   "outputs": []
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "43aad670ab761737",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-13T01:20:22.380492Z",
     "start_time": "2024-04-13T01:20:17.939300Z"
    }
   },
   "cell_type": "code",
   "outputs": [],
   "source": [
    "import json\n",
    "\n",

@@ -27,43 +28,38 @@
    "from openai import OpenAI\n",
    "\n",
    "\n",
    "@weave.op() # 🐝\n",
    "@weave.op()  # 🐝\n",
    "def extract_fruit(sentence: str) -> dict:\n",
    "    client = OpenAI()\n",
    "\n",
    "    response = client.chat.completions.create(\n",
    "        model=\"gpt-3.5-turbo-1106\",\n",
    "        messages=[\n",
    "            {\n",
    "                \"role\": \"system\",\n",
    "                \"content\": \"You will be provided with unstructured data, and your task is to parse it one JSON dictionary with fruit, color and flavor as keys.\"\n",
    "            },\n",
    "            {\n",
    "                \"role\": \"user\",\n",
    "                \"content\": sentence\n",
    "            }\n",
    "        model=\"gpt-3.5-turbo-1106\",\n",
    "        messages=[\n",
    "            {\n",
    "                \"role\": \"system\",\n",
    "                \"content\": \"You will be provided with unstructured data, and your task is to parse it one JSON dictionary with fruit, color and flavor as keys.\",\n",
    "            },\n",
    "            {\"role\": \"user\", \"content\": sentence},\n",
    "        ],\n",
    "        temperature=0.7,\n",
    "        response_format={ \"type\": \"json_object\" }\n",
    "        response_format={\"type\": \"json_object\"},\n",
    "    )\n",
    "    extracted = response.choices[0].message.content\n",
    "    return json.loads(extracted)\n",
    "\n",
    "weave.init('intro-example') # 🐝\n",
    "\n",
    "weave.init(\"intro-example\")  # 🐝\n",
    "sentence = \"There are many fruits that were found on the recently discovered planet Goocrux. There are neoskizzles that grow there, which are purple and taste like candy.\"\n",
    "extract_fruit(sentence)"
   ],
   "id": "43aad670ab761737",
   "execution_count": 2,
   "outputs": []
   ]
  },
  {
   "metadata": {},
   "cell_type": "code",
   "execution_count": null,
   "source": "",
   "id": "d929002a32073a23",
   "outputs": []
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {

@@ -12,12 +12,13 @@ packages = [
]
keywords = ["codeflash", "performance", "optimization", "ai", "code", "machine learning", "LLM"]

# Versions here are the minimum required versions for the project. These should be as loose as possible.
[tool.poetry.dependencies]
python = "^3.9"
unidiff = ">=0.7.4"
pytest = ">=7.0.0"
gitpython = ">=3.1.31"
libcst = ">=1.5.0"
libcst = ">=1.0.1"
jedi = ">=0.19.1"
tiktoken = ">=0.3.2"
timeout-decorator = ">=0.5.0"

@@ -37,6 +38,15 @@ returns = ">=0.23"
isort = ">=5.11.0"
dill = "^0.3.8"
rich = "^13.8.1"


[tool.poetry.group.dev]
optional = true

[tool.poetry.group.dev.dependencies]
ipython = "^8.12.0"
mypy = ">=1.13"
ruff = ">=0.7.0"
pandas-stubs = ">=2.2.2.240807, <2.2.3.241009"
types-Pygments = "^2.18.0.20240506"
types-colorama = "^0.4.15.20240311"

@@ -47,14 +57,6 @@ types-six = "^1.16.21.20241009"
types-cffi = "^1.16.0.20240331"
types-openpyxl = "^3.1.5.20241020"

[tool.poetry.group.dev]
optional = true

[tool.poetry.group.dev.dependencies]
ipython = "^8.12.0"
mypy = ">=1.13"
ruff = ">=0.7.0"

[tool.poetry.build]
script = "codeflash/update_license_version.py"

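Note on the dependency-group move above: because the dev group is declared with `optional = true`, the type-stub packages (pandas-stubs, types-*) are no longer installed by default; with Poetry 1.2+ group syntax they are pulled in explicitly via `poetry install --with dev`.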
@@ -45,16 +45,11 @@ def main():
    improvement_pct = int(re.search(r"📈 ([\d,]+)% improvement", stdout).group(1).replace(",", ""))
    improvement_x = float(improvement_pct) / 100

    assert (
        improvement_pct > 300
    ), f"Performance improvement percentage was {improvement_pct}, which was not above 300%"
    assert improvement_pct > 300, f"Performance improvement percentage was {improvement_pct}, which was not above 300%"
    assert improvement_x > 3, f"Performance improvement rate was {improvement_x}x, which was not above 3x"

    # Check for the line indicating the number of discovered existing unit tests
    unit_test_search = re.search(
        r"Discovered (\d+) existing unit tests",
        stdout,
    )
    unit_test_search = re.search(r"Discovered (\d+) existing unit tests", stdout)
    num_unit_tests = int(unit_test_search.group(1))
    assert num_unit_tests > 0, "Could not find existing unit tests"

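The end-to-end scripts above scrape codeflash's console output with regular expressions. A minimal, self-contained sketch of the same parsing pattern (the log text here is made up for illustration):

import re

# Hypothetical output, shaped like the lines the assertions above search for.
stdout = "📈 1,250% improvement\nDiscovered 12 existing unit tests"

match = re.search(r"📈 ([\d,]+)% improvement", stdout)
improvement_pct = int(match.group(1).replace(",", ""))  # strip thousands separators before int()
improvement_x = improvement_pct / 100  # 1250% -> 12.5x

unit_test_search = re.search(r"Discovered (\d+) existing unit tests", stdout)
num_unit_tests = int(unit_test_search.group(1))

assert improvement_pct == 1250 and improvement_x == 12.5 and num_unit_tests == 12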
@@ -45,16 +45,11 @@ def main():
    improvement_pct = int(re.search(r"📈 ([\d,]+)% improvement", stdout).group(1).replace(",", ""))
    improvement_x = float(improvement_pct) / 100

    assert (
        improvement_pct > 300
    ), f"Performance improvement percentage was {improvement_pct}, which was not above 300%"
    assert improvement_pct > 300, f"Performance improvement percentage was {improvement_pct}, which was not above 300%"
    assert improvement_x > 3, f"Performance improvement rate was {improvement_x}x, which was not above 3x"

    # Check for the line indicating the number of discovered existing unit tests
    unit_test_search = re.search(
        r"Discovered (\d+) existing unit tests",
        stdout,
    )
    unit_test_search = re.search(r"Discovered (\d+) existing unit tests", stdout)
    num_unit_tests = int(unit_test_search.group(1))
    assert num_unit_tests > 0, "Could not find existing unit tests"

@@ -6,26 +6,12 @@ import subprocess

def main():
    cwd = (
        pathlib.Path(__file__).parent.parent.parent
        / "code_to_optimize"
        / "code_directories"
        / "futurehouse_structure"
        pathlib.Path(__file__).parent.parent.parent / "code_to_optimize" / "code_directories" / "futurehouse_structure"
    ).resolve()
    print("cwd", cwd)
    command = [
        "python",
        "../../../codeflash/main.py",
        "--file",
        "src/aviary/common_tags.py",
        "--no-pr",
    ]
    command = ["python", "../../../codeflash/main.py", "--file", "src/aviary/common_tags.py", "--no-pr"]
    process = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        cwd=str(cwd),
        env=os.environ.copy(),
        command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=os.environ.copy()
    )
    output = []

@@ -41,16 +27,11 @@ def main():
    improvement_pct = int(re.search(r"📈 ([\d,]+)% improvement", stdout).group(1).replace(",", ""))
    improvement_x = float(improvement_pct) / 100

    assert (
        improvement_pct > 5
    ), f"Performance improvement percentage was {improvement_pct}, which was not above 10%"
    assert improvement_pct > 5, f"Performance improvement percentage was {improvement_pct}, which was not above 10%"
    assert improvement_x > 0.1, f"Performance improvement rate was {improvement_x}x, which was not above 0.1x"

    # Check for the line indicating the number of discovered existing unit tests
    unit_test_search = re.search(
        r"Discovered (\d+) existing unit tests",
        stdout,
    )
    unit_test_search = re.search(r"Discovered (\d+) existing unit tests", stdout)
    num_unit_tests = int(unit_test_search.group(1))
    assert num_unit_tests == 2, "Could not find existing unit tests"

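These scripts collect the child process output into a list (`output = []` above) while streaming it. A minimal sketch of that pattern, with a stand-in command in place of codeflash's main.py:

import subprocess

# Stand-in command; the real scripts launch codeflash's main.py instead.
process = subprocess.Popen(
    ["python", "-c", "print('hello'); print('world')"],
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,  # interleave stderr with stdout, as the scripts do
    text=True,
)
output = []
for line in process.stdout:  # stream line by line instead of blocking on communicate()
    print(line, end="")
    output.append(line)
process.wait()
stdout = "".join(output)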
@@ -45,18 +45,11 @@ def main():
    improvement_pct = int(re.search(r"📈 ([\d,]+)% improvement", stdout).group(1).replace(",", ""))
    improvement_x = float(improvement_pct) / 100

    assert (
        improvement_pct > 5
    ), f"Performance improvement percentage was {improvement_pct}, which was not above 5%"
    assert (
        improvement_x > 0.05
    ), f"Performance improvement rate was {improvement_x}x, which was not above 0.05x"
    assert improvement_pct > 5, f"Performance improvement percentage was {improvement_pct}, which was not above 5%"
    assert improvement_x > 0.05, f"Performance improvement rate was {improvement_x}x, which was not above 0.05x"

    # Check for the line indicating the number of discovered existing unit tests
    unit_test_search = re.search(
        r"Discovered (\d+) existing unit tests",
        stdout,
    )
    unit_test_search = re.search(r"Discovered (\d+) existing unit tests", stdout)
    num_unit_tests = int(unit_test_search.group(1))
    assert num_unit_tests > 0, "Could not find existing unit tests"

@@ -51,13 +51,7 @@ class Source:
    src_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py")
    dst_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py")
    project_root = Path("/home/roger/repos/codeflash")
    new_module = add_needed_imports_from_module(
        src_module,
        dst_module,
        src_path,
        dst_path,
        project_root,
    )
    new_module = add_needed_imports_from_module(src_module, dst_module, src_path, dst_path, project_root)
    assert new_module == expected


@@ -125,11 +119,5 @@ def belongs_to_function(name: Name, function_name: str) -> bool:
    src_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py")
    dst_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py")
    project_root = Path("/home/roger/repos/codeflash")
    new_module = add_needed_imports_from_module(
        src_module,
        dst_module,
        src_path,
        dst_path,
        project_root,
    )
    new_module = add_needed_imports_from_module(src_module, dst_module, src_path, dst_path, project_root)
    assert new_module == expected

@@ -74,7 +74,7 @@ print("Hello world")

    function_name: str = "NewClass.new_function"
    preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
        ("new_function", [FunctionParent(name="NewClass", type="ClassDef")]),
        ("new_function", [FunctionParent(name="NewClass", type="ClassDef")])
    ]
    contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
    new_code: str = replace_functions_and_add_imports(

@@ -136,10 +136,7 @@ print("Hello world")
    """

    function_name: str = "NewClass.new_function"
    preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
        ("new_function", []),
        ("other_function", []),
    ]
    preexisting_objects: list[tuple[str, list[FunctionParent]]] = [("new_function", []), ("other_function", [])]
    contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,

@@ -605,10 +602,7 @@ class CacheConfig(BaseConfig):
    """
    function_names: list[str] = ["CacheSimilarityEvalConfig.from_config"]
    parents = [FunctionParent(name="CacheConfig", type="ClassDef")]
    preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
        ("__init__", parents),
        ("from_config", parents),
    ]
    preexisting_objects: list[tuple[str, list[FunctionParent]]] = [("__init__", parents), ("from_config", parents)]

    contextual_functions: set[tuple[str, str]] = {
        ("CacheSimilarityEvalConfig", "__init__"),

@@ -687,7 +681,7 @@ def test_test_libcst_code_replacement8() -> None:
    '''
    function_names: list[str] = ["_EmbeddingDistanceChainMixin._hamming_distance"]
    preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
        ("_hamming_distance", [FunctionParent("_EmbeddingDistanceChainMixin", "ClassDef")]),
        ("_hamming_distance", [FunctionParent("_EmbeddingDistanceChainMixin", "ClassDef")])
    ]
    contextual_functions: set[tuple[str, str]] = set()
    new_code: str = replace_functions_and_add_imports(

@@ -745,14 +739,8 @@ print("Hello world")
    """
    parents = [FunctionParent(name="NewClass", type="ClassDef")]
    function_name: str = "NewClass.__init__"
    preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
        ("__init__", parents),
        ("__call__", parents),
    ]
    contextual_functions: set[tuple[str, str]] = {
        ("NewClass", "__init__"),
        ("NewClass", "__call__"),
    }
    preexisting_objects: list[tuple[str, list[FunctionParent]]] = [("__init__", parents), ("__call__", parents)]
    contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__"), ("NewClass", "__call__")}
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=[function_name],

@@ -814,18 +802,14 @@ class MainClass:
            pytest_cmd="pytest",
            experiment_id=None,
            test_project_root=file_path.parent.resolve(),
        ),
    )
    )
    func_top_optimize = FunctionToOptimize(
        function_name="main_method",
        file_path=file_path,
        parents=[FunctionParent("MainClass", "ClassDef")],
        function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")]
    )
    original_code = file_path.read_text()
    code_context = opt.get_code_optimization_context(
        function_to_optimize=func_top_optimize,
        project_root=file_path.parent,
        original_source_code=original_code,
        function_to_optimize=func_top_optimize, project_root=file_path.parent, original_source_code=original_code
    ).unwrap()
    assert code_context.code_to_optimize_with_helpers == get_code_output

@@ -1144,14 +1128,14 @@ class TestResults(BaseModel):
    helper_functions = [
        FakeFunctionSource(
            file_path=Path(
                "/Users/saurabh/Library/CloudStorage/Dropbox/codeflash/cli/codeflash/verification/test_results.py",
                "/Users/saurabh/Library/CloudStorage/Dropbox/codeflash/cli/codeflash/verification/test_results.py"
            ),
            qualified_name="TestType",
            fully_qualified_name="codeflash.verification.test_results.TestType",
            only_function_name="TestType",
            source_code="",
            jedi_definition=JediDefinition(type="class"),
        ),
        )
    ]

    new_code: str = replace_functions_and_add_imports(

@@ -1168,13 +1152,8 @@ class TestResults(BaseModel):
    helper_functions_by_module_abspath = defaultdict(set)
    for helper_function in helper_functions:
        if helper_function.jedi_definition.type != "class":
            helper_functions_by_module_abspath[helper_function.file_path].add(
                helper_function.qualified_name,
            )
    for (
        module_abspath,
        qualified_names,
    ) in helper_functions_by_module_abspath.items():
            helper_functions_by_module_abspath[helper_function.file_path].add(helper_function.qualified_name)
    for module_abspath, qualified_names in helper_functions_by_module_abspath.items():
        new_code: str = replace_functions_and_add_imports(
            source_code=new_code,
            function_names=list(qualified_names),

@@ -1447,13 +1426,8 @@ def cosine_similarity_top_k(
    helper_functions_by_module_abspath = defaultdict(set)
    for helper_function in helper_functions:
        if helper_function.jedi_definition.type != "class":
            helper_functions_by_module_abspath[helper_function.file_path].add(
                helper_function.qualified_name,
            )
    for (
        module_abspath,
        qualified_names,
    ) in helper_functions_by_module_abspath.items():
            helper_functions_by_module_abspath[helper_function.file_path].add(helper_function.qualified_name)
    for module_abspath, qualified_names in helper_functions_by_module_abspath.items():
        new_helper_code: str = replace_functions_and_add_imports(
            source_code=new_code,
            function_names=list(qualified_names),

@@ -96,10 +96,7 @@ def test_get_imports_from_file_with_syntax_error(caplog) -> None:


def test_get_imports_from_file_with_no_input() -> None:
    with pytest.raises(
        AssertionError,
        match="Must provide exactly one of file_path, file_string, or file_ast",
    ):
    with pytest.raises(AssertionError, match="Must provide exactly one of file_path, file_string, or file_ast"):
        get_imports_from_file()

@@ -9,12 +9,7 @@ from returns.result import Failure, Success

from codeflash.verification.comparator import comparator
from codeflash.verification.equivalence import compare_test_results
from codeflash.verification.test_results import (
    FunctionTestInvocation,
    InvocationId,
    TestResults,
    TestType,
)
from codeflash.verification.test_results import FunctionTestInvocation, InvocationId, TestResults, TestType


def test_basic_python_objects():

@@ -377,22 +372,13 @@ def test_pandas():
    assert not comparator(c, ca)

    ak = pd.DataFrame(
        {
            "a": [datetime.datetime(2020, 2, 2, 2, 2, 2), datetime.datetime(2020, 2, 2, 2, 2, 2)],
            "b": [4, 5],
        },
        {"a": [datetime.datetime(2020, 2, 2, 2, 2, 2), datetime.datetime(2020, 2, 2, 2, 2, 2)], "b": [4, 5]}
    )
    al = pd.DataFrame(
        {
            "a": [datetime.datetime(2020, 2, 2, 2, 2, 2), datetime.datetime(2020, 2, 2, 2, 2, 2)],
            "b": [4, 5],
        },
        {"a": [datetime.datetime(2020, 2, 2, 2, 2, 2), datetime.datetime(2020, 2, 2, 2, 2, 2)], "b": [4, 5]}
    )
    am = pd.DataFrame(
        {
            "a": [datetime.datetime(2020, 2, 2, 2, 2, 2), datetime.datetime(2020, 2, 2, 2, 2, 3)],
            "b": [4, 5],
        },
        {"a": [datetime.datetime(2020, 2, 2, 2, 2, 2), datetime.datetime(2020, 2, 2, 2, 2, 3)], "b": [4, 5]}
    )
    assert comparator(ak, al)
    assert not comparator(ak, am)

@@ -693,8 +679,8 @@ def test_compare_results_fn():
                return_value=5,
                timed_out=False,
                loop_index=1,
            ),
        ],
    )
            )
        ]
    )

    new_results_1 = TestResults(

@@ -715,8 +701,8 @@ def test_compare_results_fn():
                return_value=5,
                timed_out=False,
                loop_index=1,
            ),
        ],
    )
            )
        ]
    )

    assert compare_test_results(original_results, new_results_1)

@@ -739,8 +725,8 @@ def test_compare_results_fn():
                return_value=[5],
                timed_out=False,
                loop_index=1,
            ),
        ],
    )
            )
        ]
    )

    assert not compare_test_results(original_results, new_results_2)

@@ -781,7 +767,7 @@ def test_compare_results_fn():
                timed_out=False,
                loop_index=1,
            ),
        ],
        ]
    )

    assert compare_test_results(original_results, new_results_3)

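One detail worth noting from the pandas tests above: `==` on two DataFrames is elementwise (and raises on mismatched shapes), so a comparator has to reach for something like `DataFrame.equals`. A sketch of the distinction; the comparator's actual internals are not shown in this diff:

import pandas as pd

a = pd.DataFrame({"x": [1, 2]})
b = pd.DataFrame({"x": [1, 2]})

print(a == b)       # elementwise DataFrame of booleans, not a single answer
print(a.equals(b))  # True: one boolean for the whole frame, NaNs compare equal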
@@ -3,45 +3,24 @@ import os
from codeflash.code_utils.env_utils import get_pr_number
from codeflash.models.models import OptimizedCandidateResult
from codeflash.result.critic import quantity_of_tests_critic, speedup_critic
from codeflash.verification.test_results import (
    FunctionTestInvocation,
    InvocationId,
    TestResults,
    TestType,
)
from codeflash.verification.test_results import FunctionTestInvocation, InvocationId, TestResults, TestType


def test_speedup_critic():
    original_code_runtime = 1000
    best_runtime_until_now = 1000
    candidate_result = OptimizedCandidateResult(
        times_run=5,
        best_test_runtime=800,
        best_test_results=TestResults(),
    )
    candidate_result = OptimizedCandidateResult(times_run=5, best_test_runtime=800, best_test_results=TestResults())

    assert speedup_critic(candidate_result, original_code_runtime, best_runtime_until_now)  # 20% improvement

    candidate_result = OptimizedCandidateResult(
        times_run=5,
        best_test_runtime=940,
        best_test_results=TestResults(),
    )
    candidate_result = OptimizedCandidateResult(times_run=5, best_test_runtime=940, best_test_results=TestResults())

    assert not speedup_critic(
        candidate_result,
        original_code_runtime,
        best_runtime_until_now,
    )  # 6% improvement
    assert not speedup_critic(candidate_result, original_code_runtime, best_runtime_until_now)  # 6% improvement

    original_code_runtime = 100000
    best_runtime_until_now = 100000

    candidate_result = OptimizedCandidateResult(
        times_run=5,
        best_test_runtime=94000,
        best_test_results=TestResults(),
    )
    candidate_result = OptimizedCandidateResult(times_run=5, best_test_runtime=94000, best_test_results=TestResults())

    assert speedup_critic(candidate_result, original_code_runtime, best_runtime_until_now)  # 6% improvement

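The assertions above imply that the minimum speedup required to count as a win is larger for very fast, noise-prone runtimes: a 6% improvement fails at a 1000ns baseline but passes at 100000ns. A hypothetical sketch consistent with those assertions (names, threshold values, and the noise-floor cutoff are assumptions, not the real implementation):

def speedup_critic_sketch(best_test_runtime: int, original_runtime: int, noise_floor: int = 10_000) -> bool:
    # Assumed: a stricter 10% threshold below the noise floor, 5% above it.
    threshold = 0.10 if original_runtime < noise_floor else 0.05
    speedup = (original_runtime - best_test_runtime) / original_runtime
    return speedup > threshold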
@@ -140,9 +119,7 @@ def test_generated_test_critic():
    test_results = [test_1, test_2, test_3]

    candidate_result = OptimizedCandidateResult(
        times_run=5,
        best_test_runtime=100,
        best_test_results=TestResults(test_results=test_results),
        times_run=5, best_test_runtime=100, best_test_results=TestResults(test_results=test_results)
    )

    assert quantity_of_tests_critic(candidate_result)

@@ -150,9 +127,7 @@ def test_generated_test_critic():
    test_results = [test_1, test_3]

    candidate_result = OptimizedCandidateResult(
        times_run=5,
        best_test_runtime=100,
        best_test_results=TestResults(test_results=test_results),
        times_run=5, best_test_runtime=100, best_test_results=TestResults(test_results=test_results)
    )

    assert quantity_of_tests_critic(candidate_result)

@@ -160,9 +135,7 @@ def test_generated_test_critic():
    test_results = [test_1, test_3, test_4]

    candidate_result = OptimizedCandidateResult(
        times_run=5,
        best_test_runtime=100,
        best_test_results=TestResults(test_results=test_results),
        times_run=5, best_test_runtime=100, best_test_results=TestResults(test_results=test_results)
    )

    assert quantity_of_tests_critic(candidate_result)

@@ -170,9 +143,7 @@ def test_generated_test_critic():
    test_results = [test_1]

    candidate_result = OptimizedCandidateResult(
        times_run=5,
        best_test_runtime=100,
        best_test_results=TestResults(test_results=test_results),
        times_run=5, best_test_runtime=100, best_test_results=TestResults(test_results=test_results)
    )

    assert not quantity_of_tests_critic(candidate_result)

@@ -180,9 +151,7 @@ def test_generated_test_critic():
    test_results = [test_1, test_2]

    candidate_result = OptimizedCandidateResult(
        times_run=5,
        best_test_runtime=100,
        best_test_results=TestResults(test_results=test_results),
        times_run=5, best_test_runtime=100, best_test_results=TestResults(test_results=test_results)
    )

    assert quantity_of_tests_critic(candidate_result)

@@ -190,9 +159,7 @@ def test_generated_test_critic():
    test_results = [test_1, test_4]

    candidate_result = OptimizedCandidateResult(
        times_run=5,
        best_test_runtime=100,
        best_test_results=TestResults(test_results=test_results),
        times_run=5, best_test_runtime=100, best_test_results=TestResults(test_results=test_results)
    )

    assert not quantity_of_tests_critic(candidate_result)

@@ -200,9 +167,7 @@ def test_generated_test_critic():
    test_results = [test_4, test_5]

    candidate_result = OptimizedCandidateResult(
        times_run=5,
        best_test_runtime=100,
        best_test_results=TestResults(test_results=test_results),
        times_run=5, best_test_runtime=100, best_test_results=TestResults(test_results=test_results)
    )

    assert quantity_of_tests_critic(candidate_result)

@@ -210,9 +175,7 @@ def test_generated_test_critic():
    test_results = [test_1, test_2, test_3, test_4, test_5]

    candidate_result = OptimizedCandidateResult(
        times_run=5,
        best_test_runtime=100,
        best_test_results=TestResults(test_results=test_results),
        times_run=5, best_test_runtime=100, best_test_results=TestResults(test_results=test_results)
    )

    assert quantity_of_tests_critic(candidate_result)

@@ -222,9 +185,7 @@ def test_generated_test_critic():
    test_results = [test_1, test_2, test_3]

    candidate_result = OptimizedCandidateResult(
        times_run=5,
        best_test_runtime=100,
        best_test_results=TestResults(test_results=test_results),
        times_run=5, best_test_runtime=100, best_test_results=TestResults(test_results=test_results)
    )

    assert not quantity_of_tests_critic(candidate_result)

@@ -232,9 +193,7 @@ def test_generated_test_critic():
    test_results = [test_1, test_2, test_3, test_4]

    candidate_result = OptimizedCandidateResult(
        times_run=5,
        best_test_runtime=100,
        best_test_results=TestResults(test_results=test_results),
        times_run=5, best_test_runtime=100, best_test_results=TestResults(test_results=test_results)
    )

    assert not quantity_of_tests_critic(candidate_result)

@@ -242,9 +201,7 @@ def test_generated_test_critic():
    test_results = [test_1, test_2, test_3, test_5]

    candidate_result = OptimizedCandidateResult(
        times_run=5,
        best_test_runtime=100,
        best_test_results=TestResults(test_results=test_results),
        times_run=5, best_test_runtime=100, best_test_results=TestResults(test_results=test_results)
    )

    assert quantity_of_tests_critic(candidate_result)

@@ -38,10 +38,7 @@ def test_sort_imports_without_formatting():
        tmp.flush()
        tmp_path = Path(tmp.name)

        new_code = format_code(
            formatter_cmds=["disabled"],
            path=tmp_path,
        )
        new_code = format_code(formatter_cmds=["disabled"], path=tmp_path)
        assert new_code is not None
        new_code = sort_imports(new_code)
        assert new_code == "import os\nimport sys\nimport unittest\n"

@@ -136,10 +133,7 @@ def foo():
        tmp.flush()
        tmp_path = tmp.name

        actual = format_code(
            formatter_cmds=["black $file"],
            path=Path(tmp_path),
        )
        actual = format_code(formatter_cmds=["black $file"], path=Path(tmp_path))
        assert actual == expected


@@ -161,10 +155,7 @@ def foo():
        tmp.flush()
        tmp_path = tmp.name

        actual = format_code(
            formatter_cmds=["black $file"],
            path=Path(tmp_path),
        )
        actual = format_code(formatter_cmds=["black $file"], path=Path(tmp_path))
        assert actual == expected


@@ -191,7 +182,6 @@ def foo():
        tmp_path = tmp.name

        actual = format_code(
            formatter_cmds=["ruff check --exit-zero --fix $file", "ruff format $file"],
            path=Path(tmp_path),
            formatter_cmds=["ruff check --exit-zero --fix $file", "ruff format $file"], path=Path(tmp_path)
        )
        assert actual == expected

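The formatter tests above write to a NamedTemporaryFile and call `flush()` before handing the path to an external formatter. A minimal sketch of why the flush matters (buffered writes are not on disk until flushed):

import tempfile
from pathlib import Path

with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as tmp:
    tmp.write("import sys\nimport os\n")
    tmp.flush()  # without this, another process reading tmp.name may see an empty file
    content = Path(tmp.name).read_text()
    assert "import os" in content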
@@ -3,10 +3,11 @@ from argparse import Namespace
from dataclasses import dataclass

import pytest
from returns.pipeline import is_successful

from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
from codeflash.optimization.function_context import get_function_variables_definitions
from codeflash.optimization.optimizer import Optimizer
from returns.pipeline import is_successful


def calculate_something(data):

@@ -20,8 +21,7 @@ def simple_function_with_one_dep(data):
def test_simple_dependencies() -> None:
    file_path = pathlib.Path(__file__).resolve()
    helper_functions = get_function_variables_definitions(
        FunctionToOptimize("simple_function_with_one_dep", str(file_path), []),
        str(file_path.parent.resolve()),
        FunctionToOptimize("simple_function_with_one_dep", str(file_path), []), str(file_path.parent.resolve())
    )[0]
    assert len(helper_functions) == 1
    assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.calculate_something"

@@ -82,8 +82,7 @@ def test_multiple_classes_dependencies() -> None:
    # TODO: Check if C.run only gets calculate_something_3 as dependency and likewise for other classes
    file_path = pathlib.Path(__file__).resolve()
    helper_functions = get_function_variables_definitions(
        FunctionToOptimize("run", str(file_path), [FunctionParent("C", "ClassDef")]),
        str(file_path.parent.resolve()),
        FunctionToOptimize("run", str(file_path), [FunctionParent("C", "ClassDef")]), str(file_path.parent.resolve())
    )

    # assert len(helper_functions) == 2

@@ -103,8 +102,7 @@ def recursive_dependency_1(num):
def test_recursive_dependency() -> None:
    file_path = pathlib.Path(__file__).resolve()
    helper_functions = get_function_variables_definitions(
        FunctionToOptimize("recursive_dependency_1", str(file_path), []),
        str(file_path.parent.resolve()),
        FunctionToOptimize("recursive_dependency_1", str(file_path), []), str(file_path.parent.resolve())
    )[0]
    assert len(helper_functions) == 1
    assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.calculate_something"

@@ -127,14 +125,11 @@ def simple_function_with_one_dep_ann(data: MyData):
def test_simple_dependencies_ann() -> None:
    file_path = pathlib.Path(__file__).resolve()
    helper_functions = get_function_variables_definitions(
        FunctionToOptimize("simple_function_with_one_dep_ann", str(file_path), []),
        str(file_path.parent.resolve()),
        FunctionToOptimize("simple_function_with_one_dep_ann", str(file_path), []), str(file_path.parent.resolve())
    )[0]
    assert len(helper_functions) == 2
    assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.MyData"
    assert (
        helper_functions[1].jedi_definition.full_name == "test_function_dependencies.calculate_something_ann"
    )
    assert helper_functions[1].jedi_definition.full_name == "test_function_dependencies.calculate_something_ann"


from collections import defaultdict

@@ -180,7 +175,7 @@ def test_class_method_dependencies() -> None:
            pytest_cmd="pytest",
            experiment_id=None,
            test_project_root=file_path.parent.resolve(),
        ),
    )
    )
    function_to_optimize = FunctionToOptimize(
        function_name="topologicalSort",

@@ -191,11 +186,7 @@ def test_class_method_dependencies() -> None:
    )
    with open(file_path) as f:
        original_code = f.read()
    ctx_result = opt.get_code_optimization_context(
        function_to_optimize,
        opt.args.project_root,
        original_code,
    )
    ctx_result = opt.get_code_optimization_context(function_to_optimize, opt.args.project_root, original_code)
    if not is_successful(ctx_result):
        pytest.fail()
    code_context = ctx_result.unwrap()

@@ -207,8 +198,7 @@ def test_class_method_dependencies() -> None:
    )
    assert code_context.helper_functions[0].jedi_definition.name == "topologicalSortUtil"
    assert (
        code_context.helper_functions[0].fully_qualified_name
        == "test_function_dependencies.Graph.topologicalSortUtil"
        code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.Graph.topologicalSortUtil"
    )
    assert code_context.helper_functions[0].qualified_name == "Graph.topologicalSortUtil"
    assert code_context.contextual_dunder_methods == {("Graph", "__init__")}

@@ -262,8 +252,7 @@ def simple_function_with_decorator_dep(data):
def test_decorator_dependencies() -> None:
    file_path = pathlib.Path(__file__).resolve()
    helper_functions = get_function_variables_definitions(
        FunctionToOptimize("simple_function_with_decorator_dep", str(file_path), []),
        str(file_path.parent.resolve()),
        FunctionToOptimize("simple_function_with_decorator_dep", str(file_path), []), str(file_path.parent.resolve())
    )[0]
    assert len(helper_functions) == 2
    assert {helper_functions[0][0].definition.full_name, helper_functions[1][0].definition.full_name} == {

@@ -283,7 +272,7 @@ def test_recursive_function_context() -> None:
            pytest_cmd="pytest",
            experiment_id=None,
            test_project_root=file_path.parent.resolve(),
        ),
    )
    )
    function_to_optimize = FunctionToOptimize(
        function_name="recursive",

@@ -294,21 +283,14 @@ def test_recursive_function_context() -> None:
    )
    with open(file_path) as f:
        original_code = f.read()
    ctx_result = opt.get_code_optimization_context(
        function_to_optimize,
        opt.args.project_root,
        original_code,
    )
    ctx_result = opt.get_code_optimization_context(function_to_optimize, opt.args.project_root, original_code)
    if not is_successful(ctx_result):
        pytest.fail()
    code_context = ctx_result.unwrap()
    # The code_context above should have the topologicalSortUtil function in it
    assert len(code_context.helper_functions) == 2
    assert set(
        [
            code_context.helper_functions[1].fully_qualified_name,
            code_context.helper_functions[0].fully_qualified_name,
        ],
        [code_context.helper_functions[1].fully_qualified_name, code_context.helper_functions[0].fully_qualified_name]
    ) == set(["test_function_dependencies.C.calculate_something_3", "test_function_dependencies.C.recursive"])
    assert (
        code_context.code_to_optimize_with_helpers

@@ -58,41 +58,26 @@ class AirbyteEntrypoint(object):
        return AirbyteEntrypoint.handle_record_counts(num)
    def non_classmethod_function(cls, name):
        return cls.name
""",
"""
        )
        f.flush()
        path_obj_name = Path(f.name)
        assert inspect_top_level_functions_or_methods(path_obj_name, "functionA").is_top_level
        assert not inspect_top_level_functions_or_methods(path_obj_name, "functionB").is_top_level
        assert inspect_top_level_functions_or_methods(path_obj_name, "functionC", class_name="A").is_top_level
        assert not inspect_top_level_functions_or_methods(
            path_obj_name,
            "functionD",
            class_name="A",
        ).is_top_level
        assert not inspect_top_level_functions_or_methods(
            path_obj_name,
            "functionF",
            class_name="E",
        ).is_top_level
        assert not inspect_top_level_functions_or_methods(path_obj_name, "functionD", class_name="A").is_top_level
        assert not inspect_top_level_functions_or_methods(path_obj_name, "functionF", class_name="E").is_top_level
        assert not inspect_top_level_functions_or_methods(path_obj_name, "functionA").has_args
        staticmethod_func = inspect_top_level_functions_or_methods(
            path_obj_name,
            "handle_record_counts",
            class_name=None,
            line_no=15,
            path_obj_name, "handle_record_counts", class_name=None, line_no=15
        )
        assert staticmethod_func.is_staticmethod
        assert staticmethod_func.staticmethod_class_name == "AirbyteEntrypoint"
        assert inspect_top_level_functions_or_methods(
            path_obj_name,
            "functionE",
            class_name="AirbyteEntrypoint",
            path_obj_name, "functionE", class_name="AirbyteEntrypoint"
        ).is_classmethod
        assert not inspect_top_level_functions_or_methods(
            path_obj_name,
            "non_classmethod_function",
            class_name="AirbyteEntrypoint",
            path_obj_name, "non_classmethod_function", class_name="AirbyteEntrypoint"
        ).is_top_level
        # needed because this will be traced with a class_name being passed

@@ -111,7 +96,7 @@ class X:
    def functionB():
        return False
def functionA():
    return True""",
    return True"""
        )
        f.flush()
        test_config = TestConfig(

@@ -29,9 +29,7 @@ def test_get_code_property() -> None:
        f.flush()

        new_code, contextual_dunder_methods = get_code(
            [
                FunctionToOptimize("test", f.name, [FunctionParent("TestClass", "ClassDef")]),
            ],
            [FunctionToOptimize("test", f.name, [FunctionParent("TestClass", "ClassDef")])]
        )
        assert new_code == code
        assert contextual_dunder_methods == {("TestClass", "__init__")}

@@ -60,7 +58,7 @@ class TestClass:
        f.flush()

        new_code, contextual_dunder_methods = get_code(
            [FunctionToOptimize("test", f.name, [FunctionParent("TestClass", "ClassDef")])],
            [FunctionToOptimize("test", f.name, [FunctionParent("TestClass", "ClassDef")])]
        )
        assert new_code == expected
        assert contextual_dunder_methods == {("TestClass", "__init__")}

@@ -111,13 +109,10 @@ class BubbleSortClass:
        f.flush()

        new_code, contextual_dunder_methods = get_code(
            [FunctionToOptimize("sorter", f.name, [FunctionParent("BubbleSortClass", "ClassDef")])],
            [FunctionToOptimize("sorter", f.name, [FunctionParent("BubbleSortClass", "ClassDef")])]
        )
        assert new_code == expected
        assert contextual_dunder_methods == {
            ("BubbleSortClass", "__init__"),
            ("BubbleSortClass", "__call__"),
        }
        assert contextual_dunder_methods == {("BubbleSortClass", "__init__"), ("BubbleSortClass", "__call__")}


def test_get_code_indent() -> None:

@@ -177,23 +172,12 @@ def non():
        f.flush()
        new_code, contextual_dunder_methods = get_code(
            [
                FunctionToOptimize(
                    "sorter",
                    f.name,
                    [FunctionParent("BubbleSortClass", "ClassDef")],
                ),
                FunctionToOptimize(
                    "helper",
                    f.name,
                    [FunctionParent("BubbleSortClass", "ClassDef")],
                ),
            ],
                FunctionToOptimize("sorter", f.name, [FunctionParent("BubbleSortClass", "ClassDef")]),
                FunctionToOptimize("helper", f.name, [FunctionParent("BubbleSortClass", "ClassDef")]),
            ]
        )
        assert new_code == expected
        assert contextual_dunder_methods == {
            ("BubbleSortClass", "__init__"),
            ("BubbleSortClass", "__call__"),
        }
        assert contextual_dunder_methods == {("BubbleSortClass", "__init__"), ("BubbleSortClass", "__call__")}

        expected2 = """class BubbleSortClass:
    def __init__(self):

@@ -218,28 +202,13 @@ def non():
        f.flush()
        new_code, contextual_dunder_methods = get_code(
            [
                FunctionToOptimize(
                    "sorter",
                    f.name,
                    [FunctionParent("BubbleSortClass", "ClassDef")],
                ),
                FunctionToOptimize(
                    "helper",
                    f.name,
                    [FunctionParent("BubbleSortClass", "ClassDef")],
                ),
                FunctionToOptimize(
                    "unsorter",
                    f.name,
                    [FunctionParent("BubbleSortClass", "ClassDef")],
                ),
            ],
                FunctionToOptimize("sorter", f.name, [FunctionParent("BubbleSortClass", "ClassDef")]),
                FunctionToOptimize("helper", f.name, [FunctionParent("BubbleSortClass", "ClassDef")]),
                FunctionToOptimize("unsorter", f.name, [FunctionParent("BubbleSortClass", "ClassDef")]),
            ]
        )
        assert new_code == expected2
        assert contextual_dunder_methods == {
            ("BubbleSortClass", "__init__"),
            ("BubbleSortClass", "__call__"),
        }
        assert contextual_dunder_methods == {("BubbleSortClass", "__init__"), ("BubbleSortClass", "__call__")}


def test_get_code_multiline_class_def() -> None:

@@ -275,8 +244,8 @@ def test_get_code_multiline_class_def() -> None:
                "computeStatement",
                f.name,
                [FunctionParent("StatementAssignmentVariableConstantMutable", "ClassDef")],
            ),
        ],
            )
        ]
    )
        assert new_code == expected
        assert contextual_dunder_methods == set()

@@ -296,13 +265,7 @@ class CustomDataClass:
    # single FunctionToOptimize instance, in the case where that instance has been filtered to represent a function
    # (with a definition).
    new_code, contextual_dunder_methods = get_code(
        [
            FunctionToOptimize(
                "name",
                f.name,
                [FunctionParent("CustomDataClass", "ClassDef")],
            ),
        ],
        [FunctionToOptimize("name", f.name, [FunctionParent("CustomDataClass", "ClassDef")])]
    )
    assert new_code is None
    assert contextual_dunder_methods == set()

@@ -3,9 +3,10 @@ from argparse import Namespace
from pathlib import Path

import pytest
from returns.pipeline import is_successful

from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
from codeflash.optimization.optimizer import Optimizer
from returns.pipeline import is_successful


class HelperClass:

@@ -28,22 +29,14 @@ def test_get_outside_method_helper() -> None:
            test_framework="pytest",
            pytest_cmd="pytest",
            experiment_id=None,
        ),
    )
    )
    function_to_optimize = FunctionToOptimize(
        function_name="OptimizeMe",
        file_path=file_path,
        parents=[],
        starting_line=None,
        ending_line=None,
        function_name="OptimizeMe", file_path=file_path, parents=[], starting_line=None, ending_line=None
    )
    with open(file_path) as f:
        original_code = f.read()
    ctx_result = opt.get_code_optimization_context(
        function_to_optimize,
        opt.args.project_root,
        original_code,
    )
    ctx_result = opt.get_code_optimization_context(function_to_optimize, opt.args.project_root, original_code)
    if not is_successful(ctx_result):
        pytest.fail()
    code_context = ctx_result.unwrap()

@@ -229,7 +222,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
            pytest_cmd="pytest",
            experiment_id=None,
            test_project_root=Path().resolve(),
        ),
    )
    )
    function_to_optimize = FunctionToOptimize(
        function_name="__call__",

@@ -240,11 +233,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
    )
    with open(file_path) as f:
        original_code = f.read()
    ctx_result = opt.get_code_optimization_context(
        function_to_optimize,
        opt.args.project_root,
        original_code,
    )
    ctx_result = opt.get_code_optimization_context(function_to_optimize, opt.args.project_root, original_code)
    if not is_successful(ctx_result):
        pytest.fail()
    code_context = ctx_result.unwrap()

@@ -360,22 +349,14 @@ def test_bubble_sort_deps() -> None:
            pytest_cmd="pytest",
            experiment_id=None,
            test_project_root=file_path.parent.resolve(),
        ),
    )
    )
    function_to_optimize = FunctionToOptimize(
        function_name="sorter_deps",
        file_path=file_path,
        parents=[],
        starting_line=None,
        ending_line=None,
        function_name="sorter_deps", file_path=file_path, parents=[], starting_line=None, ending_line=None
    )
    with open(file_path) as f:
        original_code = f.read()
    ctx_result = opt.get_code_optimization_context(
        function_to_optimize,
        opt.args.project_root,
        original_code,
    )
    ctx_result = opt.get_code_optimization_context(function_to_optimize, opt.args.project_root, original_code)
    if not is_successful(ctx_result):
        pytest.fail()
    code_context = ctx_result.unwrap()

@@ -402,7 +383,4 @@ def sorter_deps(arr):
        code_context.helper_functions[0].fully_qualified_name
        == "code_to_optimize.bubble_sort_dep1_helper.dep1_comparer"
    )
    assert (
        code_context.helper_functions[1].fully_qualified_name
        == "code_to_optimize.bubble_sort_dep2_swap.dep2_swap"
    )
    assert code_context.helper_functions[1].fully_qualified_name == "code_to_optimize.bubble_sort_dep2_swap.dep2_swap"

@@ -2,11 +2,8 @@ import unittest
from unittest.mock import patch

import git
from codeflash.code_utils.git_utils import (
    check_and_push_branch,
    check_running_in_git_repo,
    get_repo_owner_and_name,
)

from codeflash.code_utils.git_utils import check_and_push_branch, check_running_in_git_repo, get_repo_owner_and_name


class TestGitUtils(unittest.TestCase):

@@ -44,12 +41,7 @@ class TestGitUtils(unittest.TestCase):
    @patch("codeflash.code_utils.git_utils.git.Repo")
    @patch("codeflash.code_utils.git_utils.sys.__stdin__.isatty", return_value=True)
    @patch("codeflash.code_utils.git_utils.confirm_proceeding_with_no_git_repo", return_value=True)
    def test_check_running_in_git_repo_not_in_git_repo_interactive(
        self,
        mock_confirm,
        mock_isatty,
        mock_repo,
    ):
    def test_check_running_in_git_repo_not_in_git_repo_interactive(self, mock_confirm, mock_isatty, mock_repo):
        mock_repo.side_effect = git.InvalidGitRepositoryError  # type: ignore
        assert check_running_in_git_repo("/path/to/non-repo") == False

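The stacked `@patch` decorators above deserve a note: decorators apply bottom-up, so the mocks arrive as arguments in the reverse of their written order (here `mock_confirm` for the bottom patch first, `mock_repo` for the top patch last). A minimal sketch patching standard-library functions:

import os
import unittest
from unittest.mock import patch


class PatchOrderExample(unittest.TestCase):
    @patch("os.listdir")  # applied last, injected as the last mock argument
    @patch("os.getcwd")   # closest to the function: applied first, injected first
    def test_order(self, mock_getcwd, mock_listdir):
        mock_getcwd.return_value = "/tmp"
        mock_listdir.return_value = []
        assert os.getcwd() == "/tmp" and os.listdir() == []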
@@ -102,11 +102,7 @@ class TestPigLatin(unittest.TestCase):
        with tempfile.NamedTemporaryFile(mode="w") as f:
            f.write(code)
            f.flush()
            func = FunctionToOptimize(
                function_name="sorter",
                parents=[],
                file_path=Path("module.py"),
            )
            func = FunctionToOptimize(function_name="sorter", parents=[], file_path=Path("module.py"))
            original_cwd = Path.cwd()
            run_cwd = Path(__file__).parent.parent.resolve()
            os.chdir(run_cwd)

@@ -120,8 +116,7 @@ class TestPigLatin(unittest.TestCase):
            os.chdir(original_cwd)
            assert success
            assert new_test == expected.format(
                module_path=Path(f.name).name,
                tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
                module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values"))
            )

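Many of the tests below repeat the same chdir-run-restore dance (`os.chdir(run_cwd)` ... `os.chdir(original_cwd)`). A small context manager is a common way to make that pattern exception-safe; a sketch, not part of this diff:

import contextlib
import os
from pathlib import Path


@contextlib.contextmanager
def working_directory(path: Path):
    previous = Path.cwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(previous)  # restored even if the body raises


# with working_directory(Path(__file__).parent.parent):
#     ...run the instrumented tests...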
@ -204,26 +199,17 @@ def test_prepare_image_for_yolo():
|
|||
with tempfile.NamedTemporaryFile(mode="w") as f:
|
||||
f.write(code)
|
||||
f.flush()
|
||||
func = FunctionToOptimize(
|
||||
function_name="prepare_image_for_yolo",
|
||||
parents=[],
|
||||
file_path=Path("module.py"),
|
||||
)
|
||||
func = FunctionToOptimize(function_name="prepare_image_for_yolo", parents=[], file_path=Path("module.py"))
|
||||
original_cwd = Path.cwd()
|
||||
run_cwd = Path(__file__).parent.parent.resolve()
|
||||
os.chdir(run_cwd)
|
||||
success, new_test = inject_profiling_into_existing_test(
|
||||
Path(f.name),
|
||||
[CodePosition(10, 14)],
|
||||
func,
|
||||
Path(f.name).parent,
|
||||
"pytest",
|
||||
Path(f.name), [CodePosition(10, 14)], func, Path(f.name).parent, "pytest"
|
||||
)
|
||||
os.chdir(original_cwd)
|
||||
assert success
|
||||
assert new_test == expected.format(
|
||||
module_path=Path(f.name).name,
|
||||
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
|
||||
module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values"))
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -301,18 +287,10 @@ def test_sort():
|
|||
project_root_path = (Path(__file__).parent / "..").resolve()
|
||||
original_cwd = Path.cwd()
|
||||
run_cwd = Path(__file__).parent.parent.resolve()
|
||||
func = FunctionToOptimize(
|
||||
function_name="sorter",
|
||||
parents=[],
|
||||
file_path=Path("module.py"),
|
||||
)
|
||||
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=Path("module.py"))
|
||||
os.chdir(run_cwd)
|
||||
success, new_test = inject_profiling_into_existing_test(
|
||||
test_path,
|
||||
[CodePosition(6, 13), CodePosition(10, 13)],
|
||||
func,
|
||||
project_root_path,
|
||||
"pytest",
|
||||
test_path, [CodePosition(6, 13), CodePosition(10, 13)], func, project_root_path, "pytest"
|
||||
)
|
||||
os.chdir(original_cwd)
|
||||
assert success
|
||||
|
|
@ -335,7 +313,7 @@ def test_sort():
|
|||
pytest_cmd="pytest",
|
||||
experiment_id=None,
|
||||
test_project_root=project_root_path,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
test_env = os.environ.copy()
|
||||
|
|
@ -343,13 +321,7 @@ def test_sort():
|
|||
test_env["CODEFLASH_LOOP_INDEX"] = "1"
|
||||
test_type = TestType.EXISTING_UNIT_TEST
|
||||
test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
instrumented_file_path=test_path,
|
||||
test_type=test_type,
|
||||
original_file_path=test_path,
|
||||
),
|
||||
],
|
||||
test_files=[TestFile(instrumented_file_path=test_path, test_type=test_type, original_file_path=test_path)]
|
||||
)
|
||||
test_results = opt.run_and_parse_tests(
|
||||
test_env=test_env,
|
||||
|
|
@ -461,18 +433,10 @@ def test_sort_parametrized(input, expected_output):
|
|||
original_cwd = Path.cwd()
|
||||
run_cwd = Path(__file__).parent.parent.resolve()
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="sorter",
|
||||
parents=[],
|
||||
file_path=Path("module.py"),
|
||||
)
|
||||
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=Path("module.py"))
|
||||
os.chdir(run_cwd)
|
||||
success, new_test = inject_profiling_into_existing_test(
|
||||
test_path,
|
||||
[CodePosition(14, 13)],
|
||||
func,
|
||||
project_root_path,
|
||||
"pytest",
|
||||
test_path, [CodePosition(14, 13)], func, project_root_path, "pytest"
|
||||
)
|
||||
os.chdir(original_cwd)
|
||||
assert success
|
||||
|
|
@ -490,13 +454,7 @@ def test_sort_parametrized(input, expected_output):
|
|||
test_env["CODEFLASH_TEST_ITERATION"] = "0"
|
||||
test_type = TestType.EXISTING_UNIT_TEST
|
||||
test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
instrumented_file_path=test_path,
|
||||
test_type=test_type,
|
||||
original_file_path=test_path,
|
||||
),
|
||||
],
|
||||
test_files=[TestFile(instrumented_file_path=test_path, test_type=test_type, original_file_path=test_path)]
|
||||
)
|
||||
|
||||
opt = Optimizer(
|
||||
|
|
@ -508,7 +466,7 @@ def test_sort_parametrized(input, expected_output):
|
|||
pytest_cmd="pytest",
|
||||
experiment_id=None,
|
||||
test_project_root=project_root_path,
|
||||
),
|
||||
)
|
||||
)
|
||||
test_results = opt.run_and_parse_tests(
|
||||
test_env=test_env,
|
||||
|
|
@ -635,18 +593,10 @@ def test_sort_parametrized_loop(input, expected_output):
|
|||
original_cwd = Path.cwd()
|
||||
run_cwd = Path(__file__).parent.parent.resolve()
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="sorter",
|
||||
parents=[],
|
||||
file_path=Path("module.py"),
|
||||
)
|
||||
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=Path("module.py"))
|
||||
os.chdir(run_cwd)
|
||||
success, new_test = inject_profiling_into_existing_test(
|
||||
test_path,
|
||||
[CodePosition(15, 17)],
|
||||
func,
|
||||
project_root_path,
|
||||
"pytest",
|
||||
test_path, [CodePosition(15, 17)], func, project_root_path, "pytest"
|
||||
)
|
||||
os.chdir(original_cwd)
|
||||
assert success
|
||||
|
|
@ -664,13 +614,7 @@ def test_sort_parametrized_loop(input, expected_output):
|
|||
test_env["CODEFLASH_TEST_ITERATION"] = "0"
|
||||
test_type = TestType.EXISTING_UNIT_TEST
|
||||
test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
instrumented_file_path=test_path,
|
||||
test_type=test_type,
|
||||
original_file_path=test_path,
|
||||
),
|
||||
],
|
||||
test_files=[TestFile(instrumented_file_path=test_path, test_type=test_type, original_file_path=test_path)]
|
||||
)
|
||||
opt = Optimizer(
|
||||
Namespace(
|
||||
|
|
@ -681,7 +625,7 @@ def test_sort_parametrized_loop(input, expected_output):
|
|||
pytest_cmd="pytest",
|
||||
experiment_id=None,
|
||||
test_project_root=project_root_path,
|
||||
),
|
||||
)
|
||||
)
|
||||
test_results = opt.run_and_parse_tests(
|
||||
test_env=test_env,
|
||||
|
|
@@ -838,18 +782,10 @@ def test_sort():
run_cwd = Path(__file__).parent.parent.resolve()
original_cwd = Path.cwd()

- func = FunctionToOptimize(
- function_name="sorter",
- parents=[],
- file_path=Path("module.py"),
- )
+ func = FunctionToOptimize(function_name="sorter", parents=[], file_path=Path("module.py"))
os.chdir(str(run_cwd))
success, new_test = inject_profiling_into_existing_test(
- test_path,
- [CodePosition(11, 17)],
- func,
- project_root_path,
- "pytest",
+ test_path, [CodePosition(11, 17)], func, project_root_path, "pytest"
)
os.chdir(original_cwd)
assert success
@@ -867,13 +803,7 @@ def test_sort():
test_env["CODEFLASH_TEST_ITERATION"] = "0"
test_type = TestType.EXISTING_UNIT_TEST
test_files = TestFiles(
- test_files=[
- TestFile(
- instrumented_file_path=test_path,
- test_type=test_type,
- original_file_path=test_path,
- ),
- ],
+ test_files=[TestFile(instrumented_file_path=test_path, test_type=test_type, original_file_path=test_path)]
)

opt = Optimizer(
@@ -885,7 +815,7 @@ def test_sort():
pytest_cmd="pytest",
experiment_id=None,
test_project_root=project_root_path,
- ),
+ )
)
test_results = opt.run_and_parse_tests(
test_env=test_env,
@@ -1023,11 +953,7 @@ class TestPigLatin(unittest.TestCase):
run_cwd = Path(__file__).parent.parent.resolve()
original_cwd = Path.cwd()

- func = FunctionToOptimize(
- function_name="sorter",
- parents=[],
- file_path=Path("module.py"),
- )
+ func = FunctionToOptimize(function_name="sorter", parents=[], file_path=Path("module.py"))
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
test_path,
@@ -1054,13 +980,7 @@ class TestPigLatin(unittest.TestCase):
test_env["CODEFLASH_LOOP_INDEX"] = "1"
test_type = TestType.EXISTING_UNIT_TEST
test_files = TestFiles(
- test_files=[
- TestFile(
- instrumented_file_path=test_path,
- test_type=test_type,
- original_file_path=test_path,
- ),
- ],
+ test_files=[TestFile(instrumented_file_path=test_path, test_type=test_type, original_file_path=test_path)]
)
opt = Optimizer(
Namespace(
@@ -1071,7 +991,7 @@ class TestPigLatin(unittest.TestCase):
pytest_cmd="pytest",
experiment_id=None,
test_project_root=project_root_path,
- ),
+ )
)
test_results = opt.run_and_parse_tests(
test_env=test_env,
@@ -1202,18 +1122,10 @@ class TestPigLatin(unittest.TestCase):
run_cwd = Path(__file__).parent.parent.resolve()
original_cwd = Path.cwd()

- func = FunctionToOptimize(
- function_name="sorter",
- parents=[],
- file_path=Path("module.py"),
- )
+ func = FunctionToOptimize(function_name="sorter", parents=[], file_path=Path("module.py"))
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
- test_path,
- [CodePosition(16, 17)],
- func,
- project_root_path,
- "unittest",
+ test_path, [CodePosition(16, 17)], func, project_root_path, "unittest"
)
os.chdir(original_cwd)
assert success
@@ -1231,13 +1143,7 @@ class TestPigLatin(unittest.TestCase):
test_env["CODEFLASH_LOOP_INDEX"] = "1"
test_type = TestType.EXISTING_UNIT_TEST
test_files = TestFiles(
- test_files=[
- TestFile(
- instrumented_file_path=test_path,
- test_type=test_type,
- original_file_path=test_path,
- ),
- ],
+ test_files=[TestFile(instrumented_file_path=test_path, test_type=test_type, original_file_path=test_path)]
)
opt = Optimizer(
Namespace(
@@ -1248,7 +1154,7 @@ class TestPigLatin(unittest.TestCase):
pytest_cmd="pytest",
experiment_id=None,
test_project_root=project_root_path,
- ),
+ )
)
test_results = opt.run_and_parse_tests(
test_env=test_env,
@@ -1384,11 +1290,7 @@ class TestPigLatin(unittest.TestCase):
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=Path("module.py"))
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
- test_path,
- [CodePosition(14, 21)],
- func,
- project_root_path,
- "unittest",
+ test_path, [CodePosition(14, 21)], func, project_root_path, "unittest"
)
os.chdir(original_cwd)
assert success
@@ -1406,13 +1308,7 @@ class TestPigLatin(unittest.TestCase):
test_env["CODEFLASH_LOOP_INDEX"] = "1"
test_type = TestType.EXISTING_UNIT_TEST
test_files = TestFiles(
- test_files=[
- TestFile(
- instrumented_file_path=test_path,
- test_type=test_type,
- original_file_path=test_path,
- ),
- ],
+ test_files=[TestFile(instrumented_file_path=test_path, test_type=test_type, original_file_path=test_path)]
)

opt = Optimizer(
@@ -1424,7 +1320,7 @@ class TestPigLatin(unittest.TestCase):
pytest_cmd="pytest",
experiment_id=None,
test_project_root=project_root_path,
- ),
+ )
)
test_results = opt.run_and_parse_tests(
test_env=test_env,
@@ -1557,18 +1453,10 @@ class TestPigLatin(unittest.TestCase):
run_cwd = Path(__file__).parent.parent.resolve()
original_cwd = Path.cwd()

- f = FunctionToOptimize(
- function_name="sorter",
- file_path=Path("module.py"),
- parents=[],
- )
+ f = FunctionToOptimize(function_name="sorter", file_path=Path("module.py"), parents=[])
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
- test_path,
- [CodePosition(17, 21)],
- f,
- project_root_path,
- "unittest",
+ test_path, [CodePosition(17, 21)], f, project_root_path, "unittest"
)
os.chdir(original_cwd)
assert success
@@ -1587,13 +1475,7 @@ class TestPigLatin(unittest.TestCase):
test_env["CODEFLASH_LOOP_INDEX"] = "1"
test_type = TestType.EXISTING_UNIT_TEST
test_files = TestFiles(
- test_files=[
- TestFile(
- instrumented_file_path=test_path,
- test_type=test_type,
- original_file_path=test_path,
- ),
- ],
+ test_files=[TestFile(instrumented_file_path=test_path, test_type=test_type, original_file_path=test_path)]
)
opt = Optimizer(
Namespace(
@@ -1604,7 +1486,7 @@ class TestPigLatin(unittest.TestCase):
pytest_cmd="pytest",
experiment_id=None,
test_project_root=project_root_path,
- ),
+ )
)
test_results = opt.run_and_parse_tests(
test_env=test_env,
@@ -1691,21 +1573,13 @@ from module import functionB as function_B
import class_name_B
from nuitka.nodes.ImportNodes import ExpressionBuiltinImport as nuitka_nodes_ImportNodes_ExpressionBuiltinImport
"""
- f = FunctionToOptimize(
- function_name="functionA",
- file_path=Path("module.py"),
- parents=[],
- )
+ f = FunctionToOptimize(function_name="functionA", file_path=Path("module.py"), parents=[])
tree = ast.parse(code)
visitor = FunctionImportedAsVisitor(f)
visitor.visit(tree)
assert visitor.imported_as.function_name == "functionA"

- f = FunctionToOptimize(
- function_name="functionB",
- file_path=Path("module.py"),
- parents=[],
- )
+ f = FunctionToOptimize(function_name="functionB", file_path=Path("module.py"), parents=[])
visitor = FunctionImportedAsVisitor(f)
visitor.visit(tree)
assert visitor.imported_as.function_name == "function_B"
@@ -1717,15 +1591,9 @@ from nuitka.nodes.ImportNodes import ExpressionBuiltinImport as nuitka_nodes_Imp
)
visitor = FunctionImportedAsVisitor(f)
visitor.visit(tree)
- assert (
- visitor.imported_as.qualified_name == "nuitka_nodes_ImportNodes_ExpressionBuiltinImport.method_name"
- )
+ assert visitor.imported_as.qualified_name == "nuitka_nodes_ImportNodes_ExpressionBuiltinImport.method_name"

- f = FunctionToOptimize(
- function_name="class_name_B",
- file_path=Path("module.py"),
- parents=[],
- )
+ f = FunctionToOptimize(function_name="class_name_B", file_path=Path("module.py"), parents=[])
visitor = FunctionImportedAsVisitor(f)
visitor.visit(tree)
assert visitor.imported_as.qualified_name == "class_name_B"
@@ -1782,8 +1650,7 @@ def test_class_name_A_function_name():
"""

test_path = (
- Path(__file__).parent.resolve()
- / "../code_to_optimize/tests/pytest/test_class_function_instrumentation_temp.py"
+ Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_class_function_instrumentation_temp.py"
)
try:
with open(test_path, "w") as f:
@@ -1799,11 +1666,7 @@ def test_class_name_A_function_name():
)
os.chdir(str(run_cwd))
success, new_test = inject_profiling_into_existing_test(
- test_path,
- [CodePosition(4, 23)],
- func,
- project_root_path,
- "pytest",
+ test_path, [CodePosition(4, 23)], func, project_root_path, "pytest"
)
os.chdir(original_cwd)
finally:
@@ -1878,8 +1741,7 @@ def test_common_tags_1():
"""

test_path = (
- Path(__file__).parent.resolve()
- / "../code_to_optimize/tests/pytest/test_wrong_function_instrumentation_temp.py"
+ Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_wrong_function_instrumentation_temp.py"
)
try:
with test_path.open("w") as f:
@@ -1890,18 +1752,12 @@ def test_common_tags_1():
run_cwd = Path(__file__).parent.parent.resolve()
original_cwd = Path.cwd()
func = FunctionToOptimize(
- function_name="find_common_tags",
- file_path=project_root_path / "module.py",
- parents=[],
+ function_name="find_common_tags", file_path=project_root_path / "module.py", parents=[]
)

os.chdir(str(run_cwd))
success, new_test = inject_profiling_into_existing_test(
- test_path,
- [CodePosition(7, 11), CodePosition(11, 11)],
- func,
- project_root_path,
- "pytest",
+ test_path, [CodePosition(7, 11), CodePosition(11, 11)], func, project_root_path, "pytest"
)
os.chdir(original_cwd)
assert success
@@ -1969,8 +1825,7 @@ def test_sort():
codeflash_con.close()
"""
test_path = (
- Path(__file__).parent.resolve()
- / "../code_to_optimize/tests/pytest/test_conditional_instrumentation_temp.py"
+ Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_conditional_instrumentation_temp.py"
)
try:
with open(test_path, "w") as f:
@@ -1980,19 +1835,11 @@ def test_sort():
project_root_path = Path(__file__).parent.resolve() / "../code_to_optimize/"
run_cwd = Path(__file__).parent.parent.resolve()
original_cwd = Path.cwd()
- func = FunctionToOptimize(
- function_name="sorter",
- file_path=project_root_path / "module.py",
- parents=[],
- )
+ func = FunctionToOptimize(function_name="sorter", file_path=project_root_path / "module.py", parents=[])

os.chdir(str(run_cwd))
success, new_test = inject_profiling_into_existing_test(
- test_path,
- [CodePosition(7, 15)],
- func,
- project_root_path,
- "pytest",
+ test_path, [CodePosition(7, 15)], func, project_root_path, "pytest"
)
os.chdir(original_cwd)
assert success
@@ -2089,11 +1936,7 @@ def test_sort():

os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
- test_path,
- [CodePosition(6, 26), CodePosition(10, 26)],
- function_to_optimize,
- project_root_path,
- "pytest",
+ test_path, [CodePosition(6, 26), CodePosition(10, 26)], function_to_optimize, project_root_path, "pytest"
)
os.chdir(original_cwd)
assert success
@@ -2213,17 +2056,12 @@ def test_code_replacement10() -> None:
run_cwd = Path(__file__).parent.parent.resolve()
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
- Path(f.name),
- [CodePosition(22, 28), CodePosition(28, 28)],
- func,
- Path(f.name).parent,
- "pytest",
+ Path(f.name), [CodePosition(22, 28), CodePosition(28, 28)], func, Path(f.name).parent, "pytest"
)
os.chdir(original_cwd)
assert success
assert new_test == expected.format(
- module_path=Path(f.name).name,
- tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
+ module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values"))
)
@@ -2299,18 +2137,10 @@ def test_sleepfunc_sequence_short(n, expected_total_sleep_time):
project_root_path = (Path(__file__).parent.resolve() / "../").resolve()
original_cwd = Path.cwd()
run_cwd = Path(__file__).parent.parent.resolve()
- func = FunctionToOptimize(
- function_name="accurate_sleepfunc",
- parents=[],
- file_path=Path("module.py"),
- )
+ func = FunctionToOptimize(function_name="accurate_sleepfunc", parents=[], file_path=Path("module.py"))
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
- test_path,
- [CodePosition(8, 13)],
- func,
- project_root_path,
- "pytest",
+ test_path, [CodePosition(8, 13)], func, project_root_path, "pytest"
)
os.chdir(original_cwd)

@@ -2337,16 +2167,10 @@ def test_sleepfunc_sequence_short(n, expected_total_sleep_time):
pytest_cmd="pytest",
experiment_id=None,
test_project_root=project_root_path,
- ),
+ )
)
test_files = TestFiles(
- test_files=[
- TestFile(
- instrumented_file_path=test_path,
- test_type=test_type,
- original_file_path=test_path,
- ),
- ],
+ test_files=[TestFile(instrumented_file_path=test_path, test_type=test_type, original_file_path=test_path)]
)
test_results = opt.run_and_parse_tests(
test_env=test_env,
@@ -2454,18 +2278,10 @@ class TestPigLatin(unittest.TestCase):
project_root_path = (Path(__file__).parent.resolve() / "../").resolve()
original_cwd = Path.cwd()
run_cwd = Path(__file__).parent.parent.resolve()
- func = FunctionToOptimize(
- function_name="accurate_sleepfunc",
- parents=[],
- file_path=Path("module.py"),
- )
+ func = FunctionToOptimize(function_name="accurate_sleepfunc", parents=[], file_path=Path("module.py"))
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
- test_path,
- [CodePosition(12, 17)],
- func,
- project_root_path,
- "unittest",
+ test_path, [CodePosition(12, 17)], func, project_root_path, "unittest"
)
os.chdir(original_cwd)

@@ -2492,23 +2308,13 @@ class TestPigLatin(unittest.TestCase):
pytest_cmd="pytest",
experiment_id=None,
test_project_root=project_root_path,
- ),
+ )
)
test_files = TestFiles(
- test_files=[
- TestFile(
- instrumented_file_path=test_path,
- test_type=test_type,
- original_file_path=test_path,
- ),
- ],
+ test_files=[TestFile(instrumented_file_path=test_path, test_type=test_type, original_file_path=test_path)]
)
test_results = opt.run_and_parse_tests(
- test_env=test_env,
- test_files=test_files,
- optimization_iteration=0,
- test_functions=None,
- testing_time=0.1,
+ test_env=test_env, test_files=test_files, optimization_iteration=0, test_functions=None, testing_time=0.1
)

assert test_results[0].id.function_getting_tested == "accurate_sleepfunc"
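Every hunk in the test file above is the same mechanical rewrite: a call that had been spread over one argument per line is collapsed onto a single line once it fits within the formatter's line-length limit. A minimal, self-contained sketch of the before/after shape follows; the helper is a hypothetical stand-in for codeflash's FunctionToOptimize so the snippet runs on its own, and the roughly 120-character limit is inferred from the width of the collapsed lines rather than from any config shown in this diff:

from pathlib import Path

# Hypothetical stand-in so the sketch runs without codeflash installed.
def make_function_to_optimize(function_name: str, parents: list, file_path: Path) -> dict:
    return {"function_name": function_name, "parents": parents, "file_path": file_path}

# Before: one argument per line, held open by a trailing comma.
func = make_function_to_optimize(
    function_name="sorter",
    parents=[],
    file_path=Path("module.py"),
)

# After ruff format: the call fits within the line limit, so it collapses.
func = make_function_to_optimize(function_name="sorter", parents=[], file_path=Path("module.py"))

Both forms build the same object; only the layout changes, which is why the hunks touch no assertions.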
@@ -1,10 +1,5 @@
from codeflash.verification.parse_test_output import merge_test_results
- from codeflash.verification.test_results import (
- FunctionTestInvocation,
- InvocationId,
- TestResults,
- TestType,
- )
+ from codeflash.verification.test_results import FunctionTestInvocation, InvocationId, TestResults, TestType


def test_merge_test_results_1():
@@ -61,7 +56,7 @@ def test_merge_test_results_1():
timed_out=False,
loop_index=1,
),
- ],
+ ]
)

test_results_bin = TestResults(
@@ -117,7 +112,7 @@ def test_merge_test_results_1():
timed_out=False,
loop_index=1,
),
- ],
+ ]
)

expected_merged_results = TestResults(
@@ -173,12 +168,10 @@ def test_merge_test_results_1():
timed_out=False,
loop_index=1,
),
- ],
+ ]
)
merged_results = merge_test_results(
- xml_test_results=test_results_xml,
- bin_test_results=test_results_bin,
- test_framework="unittest",
+ xml_test_results=test_results_xml, bin_test_results=test_results_bin, test_framework="unittest"
)
assert merged_results == expected_merged_results

@@ -200,30 +193,24 @@ def test_merge_test_results_1():
return_value=None,
timed_out=False,
loop_index=1,
),
- ],
- )
+ ]
+ )

merged_results = merge_test_results(
- xml_test_results=test_results_xml_single,
- bin_test_results=test_results_bin,
- test_framework="unittest",
+ xml_test_results=test_results_xml_single, bin_test_results=test_results_bin, test_framework="unittest"
)

assert merged_results == expected_merged_results

merged_results = merge_test_results(
- xml_test_results=test_results_xml_single,
- bin_test_results=TestResults(),
- test_framework="unittest",
+ xml_test_results=test_results_xml_single, bin_test_results=TestResults(), test_framework="unittest"
)

assert merged_results == test_results_xml_single

merged_results = merge_test_results(
- xml_test_results=TestResults(),
- bin_test_results=test_results_bin,
- test_framework="unittest",
+ xml_test_results=TestResults(), bin_test_results=test_results_bin, test_framework="unittest"
)

assert merged_results == TestResults()  # XML Results should always have better coverage than bin results
@@ -246,8 +233,8 @@ def test_merge_test_results_1():
return_value=None,
timed_out=False,
loop_index=1,
),
- ],
- )
+ ]
+ )

test_results_bin_pytest = TestResults(
@@ -286,13 +273,11 @@ def test_merge_test_results_1():
timed_out=False,
loop_index=1,
),
- ],
+ ]
)

merged_results = merge_test_results(
- xml_test_results=test_results_xml_pytest,
- bin_test_results=test_results_bin_pytest,
- test_framework="unittest",
+ xml_test_results=test_results_xml_pytest, bin_test_results=test_results_bin_pytest, test_framework="unittest"
)

assert merged_results == test_results_bin_pytest
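The one-character "],” to "]" hunks above are the tell for how these collapses can happen at all: with ruff format's default settings, a trailing comma inside brackets is "magic" and pins the construct to one element per line. Dropping the comma and collapsing anyway suggests this repo's ruff config sets skip-magic-trailing-comma = true (or an equivalent option); that config file is not part of this excerpt, so treat the setting as an assumption. A small sketch of the two behaviours:

# With ruff format defaults, the trailing comma below is "magic": it keeps
# the list expanded at one element per line even though it would fit on one.
results = [
    "FunctionTestInvocation",
    "InvocationId",
]

# With skip-magic-trailing-comma = true (assumed here), the comma carries no
# layout meaning, so the same list is collapsed once it fits the line limit.
results = ["FunctionTestInvocation", "InvocationId"]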
@@ -1,8 +1,6 @@
import pytest

- from codeflash.code_utils.remove_generated_tests import (
- remove_functions_from_generated_tests,
- )
+ from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests
from codeflash.models.models import GeneratedTests, GeneratedTestsList


@@ -21,10 +19,7 @@ def test_sorted_list():
# Test sorting an already sorted list
codeflash_output = sorter([1, 2, 3, 4, 5])
# Outputs were verified to be equal to the original implementation"""
- generated_tests = GeneratedTests(
- generated_original_test_source=generated_test_source,
- instrumented_test_source="",
- )
+ generated_tests = GeneratedTests(generated_original_test_source=generated_test_source, instrumented_test_source="")
generated_tests_list = GeneratedTestsList(generated_tests=[generated_tests])
functions_to_remove = ["test_single_element"]

@@ -59,10 +54,7 @@ def test_sorted_list():
# Test sorting an already sorted list
codeflash_output = sorter([1, 2, 3, 4, 5])
# Outputs were verified to be equal to the original implementation"""
- generated_tests = GeneratedTests(
- generated_original_test_source=generated_test_source,
- instrumented_test_source="",
- )
+ generated_tests = GeneratedTests(generated_original_test_source=generated_test_source, instrumented_test_source="")
generated_tests_list_1 = GeneratedTestsList(generated_tests=[generated_tests])
functions_to_remove = ["test_single_element", "test_sorted_list"]

@@ -85,8 +77,7 @@ def test_sorted_list():
# Outputs were verified to be equal to the original implementation"""

generated_tests_2 = GeneratedTests(
- generated_original_test_source=generated_test_source,
- instrumented_test_source="",
+ generated_original_test_source=generated_test_source, instrumented_test_source=""
)

generated_tests_list_2 = GeneratedTestsList(generated_tests=[generated_tests_2])
@@ -133,10 +124,7 @@ def test_list_with_mixed_orderable_and_non_orderable_types():
sorter([True, 1, "string", [1, 2]])
# Outputs were verified to be equal to the original implementation"""

- generated_tests = GeneratedTests(
- generated_original_test_source=generated_test_source,
- instrumented_test_source="",
- )
+ generated_tests = GeneratedTests(generated_original_test_source=generated_test_source, instrumented_test_source="")
generated_tests_list = GeneratedTestsList(generated_tests=[generated_tests])
functions_to_remove = ["test_list_with_custom_objects"]

@@ -189,10 +177,7 @@ def test_sorted_list():
# Test sorting an already sorted list
codeflash_output = sorter([1, 2, 3, 4, 5])
# Outputs were verified to be equal to the original implementation"""
- generated_tests = GeneratedTests(
- generated_original_test_source=generated_test_source,
- instrumented_test_source="",
- )
+ generated_tests = GeneratedTests(generated_original_test_source=generated_test_source, instrumented_test_source="")
generated_tests_list = GeneratedTestsList(generated_tests=[generated_tests])
functions_to_remove = ["test_empty_list", "test_sort_parametrized"]

@@ -223,9 +208,7 @@ def test_sorted_list():
assert generated_tests_list.generated_tests[0].generated_original_test_source == expected


- @pytest.mark.skip(
- "We don't handle the edge case where the parametrized test appears right after the test to remove",
- )
+ @pytest.mark.skip("We don't handle the edge case where the parametrized test appears right after the test to remove")
def test_keep_parametrized_test2():
generated_test_source = """def test_empty_list():
# Test sorting an empty list
@@ -253,10 +236,7 @@ def test_sorted_list():
# Test sorting an already sorted list
codeflash_output = sorter([1, 2, 3, 4, 5])
# Outputs were verified to be equal to the original implementation"""
- generated_tests = GeneratedTests(
- generated_original_test_source=generated_test_source,
- instrumented_test_source="",
- )
+ generated_tests = GeneratedTests(generated_original_test_source=generated_test_source, instrumented_test_source="")
generated_tests_list = GeneratedTestsList(generated_tests=[generated_tests])
functions_to_remove = ["test_empty_list", "test_sort_parametrized"]
@@ -9,11 +9,7 @@ from codeflash.code_utils.shell_utils import read_api_key_from_shell_config, sav


class TestShellUtils(unittest.TestCase):
- @patch(
- "codeflash.code_utils.shell_utils.open",
- new_callable=mock_open,
- read_data="existing content",
- )
+ @patch("codeflash.code_utils.shell_utils.open", new_callable=mock_open, read_data="existing content")
@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")
def test_save_api_key_to_rc_success(self, mock_get_shell_rc_path, mock_file):
mock_get_shell_rc_path.return_value = "/fake/path/.bashrc"
@@ -25,11 +21,7 @@ class TestShellUtils(unittest.TestCase):
handle.write.assert_called_once()
handle.truncate.assert_called_once()

- @patch(
- "codeflash.code_utils.shell_utils.open",
- new_callable=mock_open,
- read_data="existing content",
- )
+ @patch("codeflash.code_utils.shell_utils.open", new_callable=mock_open, read_data="existing content")
@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")
def test_save_api_key_to_rc_failure(self, mock_get_shell_rc_path, mock_file):
mock_get_shell_rc_path.return_value = "/fake/path/.bashrc"
@@ -59,8 +51,7 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
with patch("codeflash.code_utils.shell_utils.get_shell_rc_path") as mock_get_shell_rc_path:
mock_get_shell_rc_path.return_value = self.test_rc_path
with patch(
- "builtins.open",
- mock_open(read_data=f'export CODEFLASH_API_KEY="{self.api_key}"\n'),
+ "builtins.open", mock_open(read_data=f'export CODEFLASH_API_KEY="{self.api_key}"\n')
) as mock_file:
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
mock_file.assert_called_once_with(self.test_rc_path, encoding="utf8")
@@ -83,10 +74,7 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
with patch("builtins.open", mock_open(read_data=f"CODEFLASH_API_KEY={self.api_key}\n")):
result = read_api_key_from_shell_config()
self.assertIsNone(result)
- with patch(
- "builtins.open",
- mock_open(read_data=f"export CODEFLASH_API_KEY=sk-{self.api_key}\n"),
- ):
+ with patch("builtins.open", mock_open(read_data=f"export CODEFLASH_API_KEY=sk-{self.api_key}\n")):
result = read_api_key_from_shell_config()
self.assertIsNone(result)

@@ -96,9 +84,7 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
mock_get_shell_rc_path.return_value = self.test_rc_path
with patch(
"builtins.open",
- mock_open(
- read_data=f'export CODEFLASH_API_KEY="cf-firstkey"\nexport CODEFLASH_API_KEY="{self.api_key}"\n',
- ),
+ mock_open(read_data=f'export CODEFLASH_API_KEY="cf-firstkey"\nexport CODEFLASH_API_KEY="{self.api_key}"\n'),
):
self.assertEqual(read_api_key_from_shell_config(), self.api_key)

@@ -108,9 +94,7 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
mock_get_shell_rc_path.return_value = self.test_rc_path
with patch(
"builtins.open",
- mock_open(
- read_data=f'# Setting API Key\nexport CODEFLASH_API_KEY="{self.api_key}"\n# Done\n',
- ),
+ mock_open(read_data=f'# Setting API Key\nexport CODEFLASH_API_KEY="{self.api_key}"\n# Done\n'),
):
self.assertEqual(read_api_key_from_shell_config(), self.api_key)

@@ -118,10 +102,7 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
def test_api_key_in_comment(self, mock_get_shell_rc_path):
"""Test with API key export in a comment."""
mock_get_shell_rc_path.return_value = self.test_rc_path
- with patch(
- "builtins.open",
- mock_open(read_data=f'# export CODEFLASH_API_KEY="{self.api_key}"\n'),
- ):
+ with patch("builtins.open", mock_open(read_data=f'# export CODEFLASH_API_KEY="{self.api_key}"\n')):
self.assertIsNone(read_api_key_from_shell_config())

@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")
@@ -36,16 +36,12 @@ class TestUnittestRunnerSorter(unittest.TestCase):

with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp:
test_files = TestFiles(
- test_files=[
- TestFile(instrumented_file_path=Path(fp.name), test_type=TestType.EXISTING_UNIT_TEST),
- ],
+ test_files=[TestFile(instrumented_file_path=Path(fp.name), test_type=TestType.EXISTING_UNIT_TEST)]
)
fp.write(code.encode("utf-8"))
fp.flush()
result_file, process = run_tests(
- test_files,
- test_framework=config.test_framework,
- cwd=Path(config.project_root_path),
+ test_files, test_framework=config.test_framework, cwd=Path(config.project_root_path)
)
results = parse_test_xml(result_file, test_files, config, process)
assert results[0].did_pass, "Test did not pass as expected"
@@ -72,9 +68,7 @@ def test_sort():
)
with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp:
test_files = TestFiles(
- test_files=[
- TestFile(instrumented_file_path=Path(fp.name), test_type=TestType.EXISTING_UNIT_TEST),
- ],
+ test_files=[TestFile(instrumented_file_path=Path(fp.name), test_type=TestType.EXISTING_UNIT_TEST)]
)
fp.write(code.encode("utf-8"))
fp.flush()
@@ -90,10 +84,7 @@ def test_sort():
pytest_target_runtime_seconds=1,
)
results = parse_test_xml(
- test_xml_file_path=result_file,
- test_files=test_files,
- test_config=config,
- run_result=process,
+ test_xml_file_path=result_file, test_files=test_files, test_config=config, run_result=process
)
assert results[0].did_pass, "Test did not pass as expected"
result_file.unlink(missing_ok=True)
@@ -6,9 +6,7 @@ from typing import List

from codeflash.code_utils.code_extractor import get_code
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
- from codeflash.optimization.function_context import (
- get_constrained_function_context_and_helper_functions,
- )
+ from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions


class CustomType:
@@ -94,16 +92,11 @@ def test_function_context_works_for_composite_types() -> None:
def test_function_context_custom_datatype() -> None:
project_path = pathlib.Path(__file__).parent.parent.resolve() / "code_to_optimize"
file_path = project_path / "math_utils.py"
- code, contextual_dunder_methods = get_code(
- [FunctionToOptimize("cosine_similarity", str(file_path), [])],
- )
+ code, contextual_dunder_methods = get_code([FunctionToOptimize("cosine_similarity", str(file_path), [])])
assert code is not None
assert contextual_dunder_methods == set()
a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions(
- FunctionToOptimize("cosine_similarity", str(file_path), []),
- str(project_path),
- code,
- 1000,
+ FunctionToOptimize("cosine_similarity", str(file_path), []), str(project_path), code, 1000
)

assert len(helper_functions) == 1
@@ -75,10 +75,7 @@ def test_discover_tests_pytest_with_temp_dir_root():
assert {
discovered_tests["dummy_code.dummy_function"][0].tests_in_file.test_function,
discovered_tests["dummy_code.dummy_function"][1].tests_in_file.test_function,
- } == {
- "test_dummy_parametrized_function[True]",
- "test_dummy_function",
- }
+ } == {"test_dummy_parametrized_function[True]", "test_dummy_function"}


def test_discover_tests_pytest_with_multi_level_dirs():
@@ -147,8 +144,7 @@ def test_discover_tests_pytest_with_multi_level_dirs():
assert len(discovered_tests) == 3
assert discovered_tests["root_code.root_function"][0].tests_in_file.test_file == root_test_file_path
assert (
- discovered_tests["level1.level1_code.level1_function"][0].tests_in_file.test_file
- == level1_test_file_path
+ discovered_tests["level1.level1_code.level1_function"][0].tests_in_file.test_file == level1_test_file_path
)

assert (
@@ -238,8 +234,7 @@ def test_discover_tests_pytest_dirs():
assert len(discovered_tests) == 4
assert discovered_tests["root_code.root_function"][0].tests_in_file.test_file == root_test_file_path
assert (
- discovered_tests["level1.level1_code.level1_function"][0].tests_in_file.test_file
- == level1_test_file_path
+ discovered_tests["level1.level1_code.level1_function"][0].tests_in_file.test_file == level1_test_file_path
)
assert (
discovered_tests["level1.level2.level2_code.level2_function"][0].tests_in_file.test_file
@@ -283,10 +278,7 @@ def test_discover_tests_pytest_with_class():

# Check if the test class and method are discovered
assert len(discovered_tests) == 1
- assert (
- discovered_tests["some_class_code.SomeClass.some_method"][0].tests_in_file.test_file
- == test_file_path
- )
+ assert discovered_tests["some_class_code.SomeClass.some_method"][0].tests_in_file.test_file == test_file_path


def test_discover_tests_pytest_with_double_nested_directories():
@@ -414,9 +406,7 @@ def test_discover_tests_pytest_with_nested_class():
# Check if the test for the nested class method is discovered
assert len(discovered_tests) == 1
assert (
- discovered_tests["nested_class_code.OuterClass.InnerClass.inner_method"][
- 0
- ].tests_in_file.test_file
+ discovered_tests["nested_class_code.OuterClass.InnerClass.inner_method"][0].tests_in_file.test_file
== test_file_path
)

@@ -455,6 +445,4 @@ def test_discover_tests_pytest_separate_moduledir():

# Check if the test for the nested class method is discovered
assert len(discovered_tests) == 1
- assert (
- discovered_tests["mypackage.code.find_common_tags"][0].tests_in_file.test_file == test_file_path
- )
+ assert discovered_tests["mypackage.code.find_common_tags"][0].tests_in_file.test_file == test_file_path
@@ -4,7 +4,7 @@ from typing import Any, Dict, Optional
from posthog import Posthog

_posthog: Posthog = Posthog(
- project_api_key="phc_aUO790jHd7z1SXwsYCz8dRApxueplZlZWeDSpKc5hol", host="https://us.posthog.com",
+ project_api_key="phc_aUO790jHd7z1SXwsYCz8dRApxueplZlZWeDSpKc5hol", host="https://us.posthog.com"
)


@@ -20,8 +20,4 @@ def ph(user_id: str, event: str, properties: Optional[Dict[str, Any]] = None) ->
properties["environment"] = os.environ.get("ENVIRONMENT", default="None")
properties["openai_api_type"] = os.environ.get("OPENAI_API_TYPE", default="openai")

- _posthog.capture(
- distinct_id=user_id,
- event=event,
- properties=properties,
- )
+ _posthog.capture(distinct_id=user_id, event=event, properties=properties)
@@ -26,8 +26,7 @@ class FunctionToOptimize:
def top_level_parent_name(self) -> str:
if self.parents:
return self.parents[0].name
- else:
- return self.function_name
+ return self.function_name

def __str__(self) -> str:
return f"{self.file_path}:{'.'.join([p.name for p in self.parents]) + '.' if self.parents else ''}{self.function_name}"
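The two-line hunk above is the classic no-else-return cleanup (RET505 in the flake8-return rule set that ruff implements, assuming those rules are enabled for this repo): once the if branch returns, the else contributes nothing but indentation. A self-contained sketch of the rewrite, with a hypothetical stand-in class in place of the real dataclass:

class FunctionToOptimizeSketch:
    """Hypothetical stand-in for the dataclass in the hunk above."""

    def __init__(self, function_name, parents):
        self.function_name = function_name
        self.parents = parents

    # Before: the else branch only indents the fallback return.
    def top_level_parent_name_before(self):
        if self.parents:
            return self.parents[0].name
        else:
            return self.function_name

    # After: the early return makes the else redundant.
    def top_level_parent_name_after(self):
        if self.parents:
            return self.parents[0].name
        return self.function_name

Both methods return the same value for every input; only the nesting changes.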
@@ -76,30 +76,17 @@ WSGI_APPLICATION: str = "aiservice.wsgi.application"
# Requires DATABASE_URL environment variable to be set

assert "DATABASE_URL" in os.environ, "DATABASE_URL environment variable not set"
- DATABASES = {
- "default": dj_database_url.config(
- conn_max_age=600,
- conn_health_checks=True,
- ),
- }
+ DATABASES = {"default": dj_database_url.config(conn_max_age=600, conn_health_checks=True)}


# Password validation
# https://docs.djangoproject.com/en/5.0/ref/settings/#auth-password-validators

AUTH_PASSWORD_VALIDATORS: list[dict[str, str]] = [
- {
- "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator",
- },
- {
- "NAME": "django.contrib.auth.password_validation.MinimumLengthValidator",
- },
- {
- "NAME": "django.contrib.auth.password_validation.CommonPasswordValidator",
- },
- {
- "NAME": "django.contrib.auth.password_validation.NumericPasswordValidator",
- },
+ {"NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator"},
+ {"NAME": "django.contrib.auth.password_validation.MinimumLengthValidator"},
+ {"NAME": "django.contrib.auth.password_validation.CommonPasswordValidator"},
+ {"NAME": "django.contrib.auth.password_validation.NumericPasswordValidator"},
]
@@ -18,12 +18,12 @@ Including another URLconf
"""

# from django.contrib import admin
- from django.urls import path

from log_features.log_features import features_api
from optimizer.optimizer import optimize_api
from testgen.testgen import testgen_api

+ from django.urls import path

urlpatterns = [
path("ai/optimize", optimize_api.urls),
path("ai/testgen", testgen_api.urls),
@@ -13,7 +13,7 @@ class AuthBearer(HttpBearer):
num_users = await CFAPIKeys.objects.filter(key=hashed_token).aupdate(last_used=Now())
if num_users == 0:
raise HttpError(403, "Invalid API key")
- elif num_users == 1:
+ if num_users == 1:
api_key_instance = await instance_for_api_key(hashed_token)
if not api_key_instance:
print(f"Instance not found for api key {token}. Returning 403")
@@ -21,10 +21,7 @@ class AuthBearer(HttpBearer):
request.user = api_key_instance.user_id
request.tier = api_key_instance.tier
return token
- else:
- print(
- "THIS SHOULD NOT HAPPEN! More than one users found in the db with the same api key!",
- )
- raise HttpError(403, "Invalid API key")
+ print("THIS SHOULD NOT HAPPEN! More than one users found in the db with the same api key!")
+ raise HttpError(403, "Invalid API key")
except CFAPIKeys.DoesNotExist:
raise HttpError(403, "Invalid API key")
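The auth hunks combine two sibling cleanups: an elif that follows an unconditional raise becomes a plain if (the raise already terminated the other path), and the trailing else around the final error case is flattened for the same reason. A runnable sketch of the control flow, with simplified names standing in for the view internals (the real code is async and talks to the Django ORM):

class SketchHttpError(Exception):
    """Simplified stand-in for ninja's HttpError."""

    def __init__(self, status: int, message: str) -> None:
        super().__init__(message)
        self.status = status


def authenticate_before(num_users: int) -> str:
    if num_users == 0:
        raise SketchHttpError(403, "Invalid API key")
    elif num_users == 1:  # the raise above already exits, so elif is redundant
        return "token"
    else:
        print("THIS SHOULD NOT HAPPEN! More than one users found in the db with the same api key!")
        raise SketchHttpError(403, "Invalid API key")


def authenticate_after(num_users: int) -> str:
    if num_users == 0:
        raise SketchHttpError(403, "Invalid API key")
    if num_users == 1:  # plain if: only reachable when num_users != 0
        return "token"
    print("THIS SHOULD NOT HAPPEN! More than one users found in the db with the same api key!")
    raise SketchHttpError(403, "Invalid API key")

The two functions raise and return identically for every input; the rewrite only removes dead branching structure.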
@@ -24,10 +24,7 @@ class CFAPIKeys(models.Model):
suffix = models.CharField(max_length=4)
name = models.CharField(max_length=255)
created_at = models.DateTimeField(auto_now_add=True)
- last_used = models.DateTimeField(
- null=True,
- blank=True,
- )
+ last_used = models.DateTimeField(null=True, blank=True)
user_id = models.TextField(null=True, blank=True)
tier = models.TextField(null=True, blank=True)
@@ -2,12 +2,12 @@ from __future__ import annotations

import datetime as dt
import logging
- from asyncio import Lock, Semaphore
+ from asyncio import Semaphore
from typing import Dict, List, Optional

- from authapp.auth import AuthBearer
from ninja import NinjaAPI, Schema

+ from authapp.auth import AuthBearer
from log_features.models import OptimizationFeatures

features_api = NinjaAPI(auth=AuthBearer(), urls_namespace="log_features")
@@ -79,36 +79,24 @@ async def log_features(
if generated_tests:
f.generated_test = (f.generated_test or []) + generated_tests
if instrumented_generated_tests:
- f.instrumented_generated_test = (
- f.instrumented_generated_test or []
- ) + instrumented_generated_tests
+ f.instrumented_generated_test = (f.instrumented_generated_test or []) + instrumented_generated_tests

# Update fields on the existing instance
f.user_id = user_id if user_id is not None else f.user_id
f.original_code = original_code if original_code is not None else f.original_code
- f.optimizations_raw = (
- optimizations_raw if optimizations_raw is not None else f.optimizations_raw
- )
- f.optimizations_post = (
- optimizations_post if optimizations_post is not None else f.optimizations_post
- )
+ f.optimizations_raw = optimizations_raw if optimizations_raw is not None else f.optimizations_raw
+ f.optimizations_post = optimizations_post if optimizations_post is not None else f.optimizations_post
f.explanations_raw = explanations_raw if explanations_raw is not None else f.explanations_raw
- f.explanations_post = (
- explanations_post if explanations_post is not None else f.explanations_post
- )
+ f.explanations_post = explanations_post if explanations_post is not None else f.explanations_post
f.speedup_ratio = speedup_ratio if speedup_ratio is not None else f.speedup_ratio
f.original_runtime = original_runtime if original_runtime is not None else f.original_runtime
- f.optimized_runtime = (
- optimized_runtime if optimized_runtime is not None else f.optimized_runtime
- )
+ f.optimized_runtime = optimized_runtime if optimized_runtime is not None else f.optimized_runtime
f.is_correct = is_correct if is_correct is not None else f.is_correct
f.generated_test = generated_tests if generated_tests is not None else f.generated_test

f.test_framework = test_framework if test_framework is not None else f.test_framework
f.created_at = datetime if datetime is not None else f.created_at
- f.aiservice_commit_id = (
- aiservice_commit if aiservice_commit is not None else f.aiservice_commit_id
- )
+ f.aiservice_commit_id = aiservice_commit if aiservice_commit is not None else f.aiservice_commit_id
f.metadata = updated_metadata
f.experiment_metadata = updated_experiment_metadata
await f.asave()
@@ -154,13 +142,7 @@ class LoggingErrorResponseSchema(Schema):
error: str


- @features_api.post(
- "/",
- response={
- 200: None,
- 500: LoggingErrorResponseSchema,
- },
- )
+ @features_api.post("/", response={200: None, 500: LoggingErrorResponseSchema})
async def log_features_cli(request, data: LoggingSchema):
try:
if request.tier is None:
@@ -4,11 +4,7 @@ from django.db import models


class OptimizationFeatures(models.Model):
- trace_id = models.CharField(
- max_length=36,
- primary_key=True,
- validators=[MinLengthValidator(36)],
- )
+ trace_id = models.CharField(max_length=36, primary_key=True, validators=[MinLengthValidator(36)])
original_code = models.TextField(null=True, blank=True)
user_id = models.TextField(null=True, blank=True)
optimizations_raw = models.JSONField(null=True, blank=True)
@@ -20,7 +20,7 @@ def main():
raise ImportError(
"Couldn't import Django. Are you sure it's installed and "
"available on your PYTHONPATH environment variable? Did you "
- "forget to activate a virtual environment?",
+ "forget to activate a virtual environment?"
) from exc

load_env()
File diff suppressed because one or more lines are too long
Some files were not shown because too many files have changed in this diff