Reformat and fix all the Python files with Ruff

Set the minimum libcst version to 1.0.1
Move the stub files to dev dependencies
Saurabh Misra 2024-10-25 15:45:44 -07:00
parent 0a06160a57
commit b42c270f9a
114 changed files with 1085 additions and 3157 deletions
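The pyproject.toml hunk that pins libcst and moves the stub packages is not included in the diffs below. As a rough sketch only (the Poetry table names, the stub packages involved, and the Ruff line length are assumptions, not taken from this commit), the settings the message describes would look roughly like:

[tool.poetry.dependencies]
libcst = ">=1.0.1"  # raised minimum version, per the commit message

[tool.poetry.group.dev.dependencies]
# type-stub packages moved here from the main dependency group (names not shown in this diff)

[tool.ruff]
line-length = 120  # assumed; the reformatted hunks below wrap near 120 columns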


@ -2,15 +2,7 @@ from __future__ import annotations
from sqlalchemy import ForeignKey, Integer, String, create_engine
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
Relationship,
Session,
mapped_column,
relationship,
sessionmaker,
)
from sqlalchemy.orm import DeclarativeBase, Mapped, Relationship, Session, mapped_column, relationship, sessionmaker
# Custom base class
@ -18,24 +10,24 @@ class Base(DeclarativeBase):
pass
engine: Engine = create_engine('sqlite:///example.db')
engine: Engine = create_engine("sqlite:///example.db")
session_factory = sessionmaker(bind=engine)
session: Session = session_factory()
class User(Base):
__tablename__: str = 'users'
__tablename__: str = "users"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String)
posts: Relationship[list[Post]] = relationship("Post", order_by="Post.id", back_populates="user")
class Post(Base):
__tablename__: str = 'posts'
__tablename__: str = "posts"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
title: Mapped[str] = mapped_column(String)
user_id: Mapped[int] = mapped_column(Integer, ForeignKey('users.id'))
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"))
user: Relationship[User] = relationship("User", back_populates="posts")


@ -6,8 +6,10 @@ from sqlalchemy.engine import Engine, create_engine
from sqlalchemy.orm import DeclarativeBase, Session, relationship, sessionmaker
from sqlalchemy.orm.relationships import Relationship
POSTGRES_CONNECTION_STRING: str = ("postgresql://cf_developer:XJcbU37MBYeh4dDK6PTV5n@sqlalchemy-experiments.postgres"
".database.azure.com:5432/postgres")
POSTGRES_CONNECTION_STRING: str = (
"postgresql://cf_developer:XJcbU37MBYeh4dDK6PTV5n@sqlalchemy-experiments.postgres"
".database.azure.com:5432/postgres"
)
class Base(DeclarativeBase):
@ -52,10 +54,8 @@ def get_authors(books: list[Book]) -> list[Author]:
book: Book
for book in books:
_authors.append(book.author)
return sorted(
list(set(_authors)),
key=lambda x: x.id,
)
return sorted(list(set(_authors)), key=lambda x: x.id)
def get_authors2(num_authors) -> list[Author]:
engine: Engine = create_engine(POSTGRES_CONNECTION_STRING, echo=True)
@ -66,10 +66,7 @@ def get_authors2(num_authors) -> list[Author]:
book: Book
for book in books:
_authors.append(book.author)
return sorted(
list(set(_authors)),
key=lambda x: x.id,
)[:num_authors]
return sorted(list(set(_authors)), key=lambda x: x.id)[:num_authors]
def get_top_author(authors: List[Author]) -> Author:
@ -84,9 +81,7 @@ def get_top_author(authors: List[Author]) -> Author:
# Step 2: Iterate over each author to count their bestsellers
for author in authors:
bestseller_count = (
session.query(func.count(Book.id))
.filter(Book.author_id == author.id, Book.is_bestseller == True)
.scalar()
session.query(func.count(Book.id)).filter(Book.author_id == author.id, Book.is_bestseller == True).scalar()
)
# Step 3: Update the author with the maximum bestsellers


@ -5,19 +5,7 @@ from typing import Any, cast
from _typeshed import SupportsDunderGT, SupportsDunderLT
from sqlalchemy.orm import Session
from code_to_optimize.book_catalog import (
POSTGRES_CONNECTION_STRING,
Author,
Base,
Book,
_session,
_t,
authors,
authors_name,
engine,
init_table,
session_factory,
)
from code_to_optimize.book_catalog import Author, Book
def get_authors(session: Session) -> list[Author]:
@ -26,7 +14,4 @@ def get_authors(session: Session) -> list[Author]:
book: Book
for book in books:
_authors.append(book.author)
return sorted(
list(set(_authors)),
key=lambda x: cast(SupportsDunderLT[Any] | SupportsDunderGT[Any], x.id),
)
return sorted(list(set(_authors)), key=lambda x: cast(SupportsDunderLT[Any] | SupportsDunderGT[Any], x.id))


@ -1,18 +1,6 @@
from __future__ import annotations
from code_to_optimize.book_catalog import (
POSTGRES_CONNECTION_STRING,
Author,
Base,
Book,
_session,
_t,
authors,
authors_name,
engine,
init_table,
session_factory,
)
from code_to_optimize.book_catalog import Book
def get_authors(session):


@ -9,7 +9,7 @@ class BubbleSortClass:
def sorter(self, arr):
n = len(arr)
for i in range(n):
for j in range(0, n - i - 1):
for j in range(n - i - 1):
if arr[j] > arr[j + 1]:
arr[j], arr[j + 1] = arr[j + 1], arr[j]
return arr


@ -1,4 +1,4 @@
from __future__ import annotations
from __future__ import annotations
from typing import Any, Callable, Iterable, NewType, Optional, Protocol, TypeVar
@ -84,16 +84,11 @@ def is_new_type2(type_: type[Any]) -> bool:
def _to_str(
size: int,
suffixes: Iterable[str],
base: int,
*,
precision: Optional[int] = 1,
separator: Optional[str] = " ",
size: int, suffixes: Iterable[str], base: int, *, precision: Optional[int] = 1, separator: Optional[str] = " "
) -> str:
if size == 1:
return "1 byte"
elif size < base:
if size < base:
return f"{size:,} bytes"
for i, suffix in enumerate(suffixes, 2): # noqa: B007
@ -101,10 +96,7 @@ def _to_str(
if size < unit:
break
return "{:,.{precision}f}{separator}{}".format(
(base * size / unit),
suffix,
precision=precision,
separator=separator,
(base * size / unit), suffix, precision=precision, separator=separator
)
@ -114,16 +106,11 @@ def _to_str(
def _to_str2(
size: int,
suffixes: Iterable[str],
base: int,
*,
precision: Optional[int] = 1,
separator: Optional[str] = " ",
size: int, suffixes: Iterable[str], base: int, *, precision: Optional[int] = 1, separator: Optional[str] = " "
) -> str:
if size == 1:
return "1 byte"
elif size < base:
if size < base:
return f"{size:,} bytes"
unit = base
@ -376,9 +363,7 @@ def with_pattern(pattern: str, regex_group_count: int | None = None) -> Callable
def with_pattern2(pattern: str, regex_group_count: int | None = None) -> Callable:
return (
lambda func: setattr(func, "pattern", pattern)
or setattr(func, "regex_group_count", regex_group_count)
or func
lambda func: setattr(func, "pattern", pattern) or setattr(func, "regex_group_count", regex_group_count) or func
)


@ -4,10 +4,7 @@ from typing import List, Set, Tuple
def compare_lists(
li1: List[int],
li2: List[int],
value_func1=None,
value_func2=None,
li1: List[int], li2: List[int], value_func1=None, value_func2=None
) -> Tuple[Set[int], Set[int], Set[int]]:
"""Compare *li1* and *li2*, return the results as a list in the following form:


@ -2,14 +2,13 @@ def translate(word):
vowels = "aeiou"
if word[0] in vowels:
return word + "way"
else:
consonants = ""
for letter in word:
if letter not in vowels:
consonants += letter
else:
break
return word[len(consonants) :] + consonants + "ay"
consonants = ""
for letter in word:
if letter not in vowels:
consonants += letter
else:
break
return word[len(consonants) :] + consonants + "ay"
def pig_latin(text):


@ -1,11 +1,8 @@
def single_name_to_first_last_names(
name: str,
) -> list[tuple[str, str]]:
def single_name_to_first_last_names(name: str) -> list[tuple[str, str]]:
parts = name.upper().split()
if len(parts) == 2:
return [tuple(parts)]
elif len(parts) == 3:
if len(parts) == 3:
a, b, c = parts
return [(a, c), (a, f"{b} {c}"), (f"{a} {b}", c)]
else:
return []
return []


@ -1,6 +1,4 @@
from code_to_optimize.final_test_set.encode_python_string_to_c import (
_encodePythonStringToC,
)
from code_to_optimize.final_test_set.encode_python_string_to_c import _encodePythonStringToC
def test_empty_string():


@ -25,9 +25,7 @@ def test_common_tags_1():
def test_empty_article_list():
articles = []
expected = set()
assert (
find_common_tags(articles) == expected
), "Test failed for empty list of articles."
assert find_common_tags(articles) == expected, "Test failed for empty list of articles."
def test_no_common_tags():
@ -37,9 +35,7 @@ def test_no_common_tags():
{"tags": ["javascript", "development", "web"]},
]
expected = set()
assert (
find_common_tags(articles) == expected
), "Test failed when no tags are common."
assert find_common_tags(articles) == expected, "Test failed when no tags are common."
def test_all_common_tags():
@ -49,14 +45,10 @@ def test_all_common_tags():
{"tags": ["tech", "startups", "innovation"]},
]
expected = {"tech", "startups", "innovation"}
assert (
find_common_tags(articles) == expected
), "Test failed when all tags are common."
assert find_common_tags(articles) == expected, "Test failed when all tags are common."
def test_single_article():
articles = [{"tags": ["single", "article", "test"]}]
expected = {"single", "article", "test"}
assert (
find_common_tags(articles) == expected
), "Test failed for a single article input."
assert find_common_tags(articles) == expected, "Test failed for a single article input."


@ -2,11 +2,7 @@ from code_to_optimize.final_test_set.find_duplicates import find_duplicates
def test_basic_case():
assert find_duplicates([1, 2, 3, 2, 1, 5, 6, 5]) == [
1,
2,
5,
], "Failed on basic case"
assert find_duplicates([1, 2, 3, 2, 1, 5, 6, 5]) == [1, 2, 5], "Failed on basic case"
def test_no_duplicates():
@ -14,10 +10,7 @@ def test_no_duplicates():
def test_multiple_duplicates():
assert find_duplicates([1, 2, 2, 3, 3, 3, 4]) == [
2,
3,
], "Failed on multiple duplicates of the same item"
assert find_duplicates([1, 2, 2, 3, 3, 3, 4]) == [2, 3], "Failed on multiple duplicates of the same item"
def test_empty_list():
@ -29,7 +22,4 @@ def test_all_elements_same():
def test_mixed_data_types():
assert find_duplicates(["apple", "banana", "apple", 42, 42]) == [
"apple",
42,
], "Failed on mixed data types"
assert find_duplicates(["apple", "banana", "apple", 42, 42]) == ["apple", 42], "Failed on mixed data types"


@ -17,28 +17,18 @@ def test_prime_number():
def test_perfect_square():
assert find_factors(16) == [
(1, 16),
(2, 8),
(4, 4),
(8, 2),
(16, 1),
], "Failed on perfect square number"
assert find_factors(16) == [(1, 16), (2, 8), (4, 4), (8, 2), (16, 1)], "Failed on perfect square number"
def test_large_number():
# 120 has factors: 1, 2, 3, 4, 5, 6, 8, 10, 12, 15, 20, 24, 30, 40, 60, 120
result = find_factors(120)
expected_factors = 16 # There should be 16 pairs
assert (
len(result) == expected_factors
), "Failed on large number with multiple factors"
assert len(result) == expected_factors, "Failed on large number with multiple factors"
def test_one():
assert find_factors(1) == [
(1, 1)
], "Failed on one, which should only have one factor pair"
assert find_factors(1) == [(1, 1)], "Failed on one, which should only have one factor pair"
def test_zero():


@ -4,8 +4,7 @@ from code_to_optimize.final_test_set.integration import integrate_f
def isclose(a, b, rel_tol=1e-5, abs_tol=0.0):
"""
Helper function to compare two floating points for 'closeness'.
"""Helper function to compare two floating points for 'closeness'.
Uses a combination of relative and absolute tolerances.
"""
return abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)
@ -30,6 +29,4 @@ def test_with_pytest_approx():
a, b, N = 0, 1, 1000
result = integrate_f(a, b, N)
expected = -1 / 6
assert result == pytest.approx(
expected, rel=1e-5
), "Test failed with pytest's approx."
assert result == pytest.approx(expected, rel=1e-5), "Test failed with pytest's approx."


@ -4,7 +4,7 @@ from code_to_optimize.final_test_set.pig_latin import pig_latin
def log_test_values(values, test_name):
with open(f"/tmp/test_return_values.bin", "ab") as f:
with open("/tmp/test_return_values.bin", "ab") as f:
return_bytes = pickle.dumps(values)
_test_name = f"{test_name}".encode("ascii")
f.write(len(_test_name).to_bytes(4, byteorder="big"))
@ -25,12 +25,8 @@ def test_pig_latin_single_consonant():
def test_pig_latin_multiple_consonants():
log_test_values(
pig_latin("string"), "pig_latin_test_pig_latin_multiple_consonants_0"
)
log_test_values(
pig_latin("glove"), "pig_latin_test_pig_latin_multiple_consonants_1"
)
log_test_values(pig_latin("string"), "pig_latin_test_pig_latin_multiple_consonants_0")
log_test_values(pig_latin("glove"), "pig_latin_test_pig_latin_multiple_consonants_1")
def test_pig_latin_capital_letters():
@ -39,13 +35,8 @@ def test_pig_latin_capital_letters():
def test_pig_latin_multiple_words():
log_test_values(
pig_latin("The quick brown fox"), "pig_latin_test_pig_latin_multiple_words_0"
)
log_test_values(
pig_latin("Python is a fun language"),
"pig_latin_test_pig_latin_multiple_words_1",
)
log_test_values(pig_latin("The quick brown fox"), "pig_latin_test_pig_latin_multiple_words_0")
log_test_values(pig_latin("Python is a fun language"), "pig_latin_test_pig_latin_multiple_words_1")
def test_pig_latin_empty_input():
@ -58,9 +49,7 @@ def test_pig_latin_spaces_input():
def test_pig_latin_non_alphabetic():
log_test_values(pig_latin("123"), "pig_latin_test_pig_latin_non_alphabetic_0")
log_test_values(
pig_latin("Hello, world!"), "pig_latin_test_pig_latin_non_alphabetic_1"
)
log_test_values(pig_latin("Hello, world!"), "pig_latin_test_pig_latin_non_alphabetic_1")
def test_pig_latin_non_ascii():
@ -69,12 +58,8 @@ def test_pig_latin_non_ascii():
def test_pig_latin_hyphenated_words():
log_test_values(
pig_latin("sister-in-law"), "pig_latin_test_pig_latin_hyphenated_words_0"
)
log_test_values(
pig_latin("self-driving car"), "pig_latin_test_pig_latin_hyphenated_words_1"
)
log_test_values(pig_latin("sister-in-law"), "pig_latin_test_pig_latin_hyphenated_words_0")
log_test_values(pig_latin("self-driving car"), "pig_latin_test_pig_latin_hyphenated_words_1")
def test_pig_latin_contractions():
@ -84,9 +69,7 @@ def test_pig_latin_contractions():
def test_pig_latin_apostrophes():
log_test_values(pig_latin("don't"), "pig_latin_test_pig_latin_apostrophes_0")
log_test_values(
pig_latin("rock 'n' roll"), "pig_latin_test_pig_latin_apostrophes_1"
)
log_test_values(pig_latin("rock 'n' roll"), "pig_latin_test_pig_latin_apostrophes_1")
def test_pig_latin_non_letter():


@ -1,6 +1,4 @@
from code_to_optimize.final_test_set.single_name_to_first_last_names import (
single_name_to_first_last_names,
)
from code_to_optimize.final_test_set.single_name_to_first_last_names import single_name_to_first_last_names
def test_two_part_name():


@ -6,16 +6,10 @@ def test_concatenate_strings_zero():
def test_concatenate_strings_positive():
assert (
concatenate_strings(5) == "0, 1, 2, 3, 4, "
), "Failed: Incorrect string for input 5"
assert concatenate_strings(5) == "0, 1, 2, 3, 4, ", "Failed: Incorrect string for input 5"
def test_concatenate_strings_large_number():
result = concatenate_strings(1000)
expected_length = sum(
len(str(i)) + 2 for i in range(1000)
) # Each number i + len(", ")
assert (
len(result) == expected_length
), f"Failed: Incorrect length for large input 1000"
expected_length = sum(len(str(i)) + 2 for i in range(1000)) # Each number i + len(", ")
assert len(result) == expected_length, "Failed: Incorrect length for large input 1000"


@ -13,9 +13,7 @@ def test_k_greater_than_array_length():
array = [4, 1, 5, 6, 2]
k = 10
expected = sorted(array, reverse=True)
assert (
find_top_k_elements(array, k) == expected
), "Failed when k is greater than array length"
assert find_top_k_elements(array, k) == expected, "Failed when k is greater than array length"
def test_normal_case():
@ -29,9 +27,7 @@ def test_array_with_duplicate_values():
array = [5, 5, 5, 5]
k = 2
expected = [5, 5]
assert (
find_top_k_elements(array, k) == expected
), "Failed when array contains duplicates"
assert find_top_k_elements(array, k) == expected, "Failed when array contains duplicates"
def test_empty_array():
@ -42,6 +38,4 @@ def test_single_element_array():
array = [42]
k = 1
expected = [42]
assert (
find_top_k_elements(array, k) == expected
), "Failed when array contains a single element"
assert find_top_k_elements(array, k) == expected, "Failed when array contains a single element"


@ -33,10 +33,7 @@ def test_complex_dag():
g.addEdge(2, 3)
g.addEdge(3, 1)
result = g.topologicalSort()
assert all(
result.index(u) < result.index(v)
for u, v in [(5, 2), (5, 0), (4, 0), (4, 1), (2, 3), (3, 1)]
)
assert all(result.index(u) < result.index(v) for u, v in [(5, 2), (5, 0), (4, 0), (4, 1), (2, 3), (3, 1)])
def test_single_node_graph():


@ -15,8 +15,7 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
Y = np.array(Y)
if X.shape[1] != Y.shape[1]:
raise ValueError(
f"Number of columns in X and Y must be the same. X has shape {X.shape} "
f"and Y has shape {Y.shape}.",
f"Number of columns in X and Y must be the same. X has shape {X.shape} " f"and Y has shape {Y.shape}."
)
X_norm = np.linalg.norm(X, axis=1)
@ -27,10 +26,7 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
def cosine_similarity_top_k(
X: Matrix,
Y: Matrix,
top_k: Optional[int] = 5,
score_threshold: Optional[float] = None,
X: Matrix, Y: Matrix, top_k: Optional[int] = 5, score_threshold: Optional[float] = None
) -> Tuple[List[Tuple[int, int]], List[float]]:
"""Row-wise cosine similarity with optional top-k and score threshold filtering.


@ -2,14 +2,13 @@ def translate(word):
vowels = "aeiou"
if word[0] in vowels:
return word + "way"
else:
consonants = ""
for letter in word:
if letter not in vowels:
consonants += letter
else:
break
return word[len(consonants) :] + consonants + "ay"
consonants = ""
for letter in word:
if letter not in vowels:
consonants += letter
else:
break
return word[len(consonants) :] + consonants + "ay"
def pig_latin(text):


@ -1,10 +1,10 @@
from typing import Generator
import pytest
from sqlalchemy import Engine, create_engine, delete, update
from sqlalchemy import Engine, create_engine
from sqlalchemy.orm import Session, sessionmaker
from code_to_optimize.book_catalog import Author, Book, get_authors
from code_to_optimize.book_catalog import Book, get_authors
POSTGRES_CONNECTION_STRING = (
"postgresql://cf_developer:XJcbU37MBYeh4dDK6PTV5n@sqlalchemy-experiments.postgres"


@ -1,10 +1,7 @@
from typing import Generator
import pytest
from sqlalchemy import Engine, create_engine, delete, update
from sqlalchemy import Engine, create_engine
from sqlalchemy.orm import Session, sessionmaker
from code_to_optimize.book_catalog import Author, Book, get_top_author
from code_to_optimize.book_catalog import Author, get_top_author
POSTGRES_CONNECTION_STRING = (
"postgresql://cf_developer:XJcbU37MBYeh4dDK6PTV5n@sqlalchemy-experiments.postgres"


@ -4,7 +4,7 @@ from code_to_optimize.pig_latin import pig_latin
def log_test_values(values, test_name):
with open(f"/tmp/test_return_values.bin", "ab") as f:
with open("/tmp/test_return_values.bin", "ab") as f:
return_bytes = pickle.dumps(values)
_test_name = f"{test_name}".encode("ascii")
f.write(len(_test_name).to_bytes(4, byteorder="big"))
@ -25,12 +25,8 @@ def test_pig_latin_single_consonant():
def test_pig_latin_multiple_consonants():
log_test_values(
pig_latin("string"), "pig_latin_test_pig_latin_multiple_consonants_0"
)
log_test_values(
pig_latin("glove"), "pig_latin_test_pig_latin_multiple_consonants_1"
)
log_test_values(pig_latin("string"), "pig_latin_test_pig_latin_multiple_consonants_0")
log_test_values(pig_latin("glove"), "pig_latin_test_pig_latin_multiple_consonants_1")
def test_pig_latin_capital_letters():
@ -39,13 +35,8 @@ def test_pig_latin_capital_letters():
def test_pig_latin_multiple_words():
log_test_values(
pig_latin("The quick brown fox"), "pig_latin_test_pig_latin_multiple_words_0"
)
log_test_values(
pig_latin("Python is a fun language"),
"pig_latin_test_pig_latin_multiple_words_1",
)
log_test_values(pig_latin("The quick brown fox"), "pig_latin_test_pig_latin_multiple_words_0")
log_test_values(pig_latin("Python is a fun language"), "pig_latin_test_pig_latin_multiple_words_1")
def test_pig_latin_empty_input():
@ -58,9 +49,7 @@ def test_pig_latin_spaces_input():
def test_pig_latin_non_alphabetic():
log_test_values(pig_latin("123"), "pig_latin_test_pig_latin_non_alphabetic_0")
log_test_values(
pig_latin("Hello, world!"), "pig_latin_test_pig_latin_non_alphabetic_1"
)
log_test_values(pig_latin("Hello, world!"), "pig_latin_test_pig_latin_non_alphabetic_1")
def test_pig_latin_non_ascii():
@ -69,12 +58,8 @@ def test_pig_latin_non_ascii():
def test_pig_latin_hyphenated_words():
log_test_values(
pig_latin("sister-in-law"), "pig_latin_test_pig_latin_hyphenated_words_0"
)
log_test_values(
pig_latin("self-driving car"), "pig_latin_test_pig_latin_hyphenated_words_1"
)
log_test_values(pig_latin("sister-in-law"), "pig_latin_test_pig_latin_hyphenated_words_0")
log_test_values(pig_latin("self-driving car"), "pig_latin_test_pig_latin_hyphenated_words_1")
def test_pig_latin_contractions():
@ -84,9 +69,7 @@ def test_pig_latin_contractions():
def test_pig_latin_apostrophes():
log_test_values(pig_latin("don't"), "pig_latin_test_pig_latin_apostrophes_0")
log_test_values(
pig_latin("rock 'n' roll"), "pig_latin_test_pig_latin_apostrophes_1"
)
log_test_values(pig_latin("rock 'n' roll"), "pig_latin_test_pig_latin_apostrophes_1")
def test_pig_latin_non_letter():


@ -4,10 +4,7 @@ from code_to_optimize.math_utils import Matrix, cosine_similarity_top_k
def use_cosine_similarity(
X: Matrix,
Y: Matrix,
top_k: Optional[int] = 5,
score_threshold: Optional[float] = None,
X: Matrix, Y: Matrix, top_k: Optional[int] = 5, score_threshold: Optional[float] = None
) -> Tuple[List[Tuple[int, int]], List[float]]:
return cosine_similarity_top_k(X, Y, top_k, score_threshold)


@ -33,10 +33,7 @@ class OptimizedCandidate:
class AiServiceClient:
def __init__(self) -> None:
self.base_url = self.get_aiservice_base_url()
self.headers = {
"Authorization": f"Bearer {get_codeflash_api_key()}",
"Connection": "close",
}
self.headers = {"Authorization": f"Bearer {get_codeflash_api_key()}", "Connection": "close"}
def get_aiservice_base_url(self) -> str:
if os.environ.get("CODEFLASH_AIS_SERVER", default="prod").lower() == "local":
@ -45,11 +42,7 @@ class AiServiceClient:
return "https://app.codeflash.ai"
def make_ai_service_request(
self,
endpoint: str,
method: str = "POST",
payload: dict[str, Any] | None = None,
timeout: float | None = None,
self, endpoint: str, method: str = "POST", payload: dict[str, Any] | None = None, timeout: float | None = None
) -> requests.Response:
"""Make an API request to the given endpoint on the AI service.
@ -98,11 +91,7 @@ class AiServiceClient:
}
logger.info("Generating optimized candidates ...")
try:
response = self.make_ai_service_request(
"/optimize",
payload=payload,
timeout=600,
)
response = self.make_ai_service_request("/optimize", payload=payload, timeout=600)
except requests.exceptions.RequestException as e:
logger.exception(f"Error generating optimized candidates: {e}")
ph("cli-optimize-error-caught", {"error": str(e)})
@ -124,10 +113,7 @@ class AiServiceClient:
except Exception:
error = response.text
logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
ph(
"cli-optimize-error-response",
{"response_status_code": response.status_code, "error": error},
)
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
return []
def log_results(
@ -225,17 +211,11 @@ class AiServiceClient:
try:
error = response.json()["error"]
logger.error(f"Error generating tests: {response.status_code} - {error}")
ph(
"cli-testgen-error-response",
{"response_status_code": response.status_code, "error": error},
)
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": error})
return None
except Exception:
logger.error(f"Error generating tests: {response.status_code} - {response.text}")
ph(
"cli-testgen-error-response",
{"response_status_code": response.status_code, "error": response.text},
)
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": response.text})
return None


@ -22,11 +22,7 @@ else:
CFAPI_BASE_URL = "https://app.codeflash.ai"
def make_cfapi_request(
endpoint: str,
method: str,
payload: Optional[Dict[str, Any]] = None,
) -> requests.Response:
def make_cfapi_request(endpoint: str, method: str, payload: Optional[Dict[str, Any]] = None) -> requests.Response:
"""Make an HTTP request using the specified method, URL, headers, and JSON payload.
:param endpoint: The endpoint URL to send the request to.
:param method: The HTTP method to use ('GET', 'POST', etc.).
@ -55,11 +51,8 @@ def get_user_id() -> Optional[str]:
response = make_cfapi_request(endpoint="/cli-get-user", method="GET")
if response.status_code == 200:
return response.text
else:
logger.error(
f"Failed to look up your userid; is your CF API key valid? ({response.reason})",
)
return None
logger.error(f"Failed to look up your userid; is your CF API key valid? ({response.reason})")
return None
def suggest_changes(
@ -137,10 +130,7 @@ def is_github_app_installed_on_repo(owner: str, repo: str) -> bool:
:param repo: The name of the repository.
:return: The response object.
"""
response = make_cfapi_request(
endpoint=f"/is-github-app-installed?repo={repo}&owner={owner}",
method="GET",
)
response = make_cfapi_request(endpoint=f"/is-github-app-installed?repo={repo}&owner={owner}", method="GET")
if not response.ok or response.text != "true":
logger.error(f"Error: {response.text}")
return False
@ -153,17 +143,9 @@ def get_blocklisted_functions() -> dict[str, str]:
return {}
owner, repo = get_repo_owner_and_name()
information = {
"pr_number": pr_number,
"repo_owner": owner,
"repo_name": repo,
}
information = {"pr_number": pr_number, "repo_owner": owner, "repo_name": repo}
try:
req = make_cfapi_request(
endpoint="/verify-existing-optimizations",
method="POST",
payload=information,
)
req = make_cfapi_request(endpoint="/verify-existing-optimizations", method="POST", payload=information)
content: dict[str, list[str]] = req.json()
except Exception as e:
logger.error(f"Error getting blocklisted functions: {e}")


@ -31,10 +31,7 @@ def parse_args() -> Namespace:
init_actions_parser = subparsers.add_parser("init-actions", help="Initialize GitHub Actions workflow")
init_actions_parser.set_defaults(func=install_github_actions)
parser.add_argument("--file", help="Try to optimize only this file")
parser.add_argument(
"--function",
help="Try to optimize only this function within the given file path",
)
parser.add_argument("--function", help="Try to optimize only this function within the given file path")
parser.add_argument(
"--all",
help="Try to optimize all functions. Can take a really long time. Can pass an optional starting directory to"
@ -50,30 +47,16 @@ def parse_args() -> Namespace:
" This is the top-level root directory where all the Python source code is located.",
)
parser.add_argument(
"--tests-root",
type=str,
help="Path to the test directory of the project, where all the tests are located.",
"--tests-root", type=str, help="Path to the test directory of the project, where all the tests are located."
)
parser.add_argument("--test-framework", choices=["pytest", "unittest"], default="pytest")
parser.add_argument("--config-file", type=str, help="Path to the pyproject.toml with codeflash configs.")
parser.add_argument(
"--config-file",
type=str,
help="Path to the pyproject.toml with codeflash configs.",
"--use-cached-tests", action="store_true", help="Use cached tests from a specified file for debugging."
)
parser.add_argument("--replay-test", type=str, help="Path to replay test to optimize functions from")
parser.add_argument(
"--use-cached-tests",
action="store_true",
help="Use cached tests from a specified file for debugging.",
)
parser.add_argument(
"--replay-test",
type=str,
help="Path to replay test to optimize functions from",
)
parser.add_argument(
"--no-pr",
action="store_true",
help="Do not create a PR for the optimization, only update the code locally.",
"--no-pr", action="store_true", help="Do not create a PR for the optimization, only update the code locally."
)
parser.add_argument(
"--verify-setup",
@ -188,7 +171,7 @@ def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
except git.exc.InvalidGitRepositoryError:
logger.exception(
"I couldn't find a git repository in the current directory. "
"I need a git repository to run --all and open PRs for optimizations. Exiting...",
"I need a git repository to run --all and open PRs for optimizations. Exiting..."
)
apologize_and_exit()
if not args.no_pr and not check_and_push_branch(git_repo):


@ -10,17 +10,13 @@ import inquirer
def apologize_and_exit() -> None:
click.echo(
"💡 If you're having trouble, see https://docs.codeflash.ai/getting-started/local-installation for further help getting started with Codeflash!",
"💡 If you're having trouble, see https://docs.codeflash.ai/getting-started/local-installation for further help getting started with Codeflash!"
)
click.echo("👋 Exiting...")
sys.exit(1)
def inquirer_wrapper(
func: Callable[..., str | bool],
*args: str | bool,
**kwargs: str | bool,
) -> str | bool:
def inquirer_wrapper(func: Callable[..., str | bool], *args: str | bool, **kwargs: str | bool) -> str | bool:
new_args = []
new_kwargs = {}
@ -29,10 +25,7 @@ def inquirer_wrapper(
else:
message = kwargs["message"]
new_kwargs = kwargs.copy()
split_messages = split_string_to_cli_width(
message,
is_confirm=func == inquirer.confirm,
)
split_messages = split_string_to_cli_width(message, is_confirm=func == inquirer.confirm)
for split_message in split_messages[:-1]:
click.echo(split_message)
@ -77,14 +70,7 @@ def inquirer_wrapper_path(*args: str, **kwargs: str) -> dict[str, str]:
new_kwargs["message"] = last_message
new_args.append(args[0])
return cast(
dict[str, str],
inquirer.prompt(
[
inquirer.Path(*new_args, **new_kwargs),
],
),
)
return cast(dict[str, str], inquirer.prompt([inquirer.Path(*new_args, **new_kwargs)]))
def split_string_to_fit_width(string: str, width: int) -> list[str]:


@ -21,18 +21,10 @@ from codeflash.api.cfapi import is_github_app_installed_on_repo
from codeflash.cli_cmds.cli_common import apologize_and_exit, inquirer_wrapper, inquirer_wrapper_path
from codeflash.code_utils.compat import LF
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.code_utils.env_utils import (
get_codeflash_api_key,
)
from codeflash.code_utils.env_utils import get_codeflash_api_key
from codeflash.code_utils.git_utils import get_repo_owner_and_name
from codeflash.code_utils.github_utils import (
get_github_secrets_page_url,
require_github_app_or_exit,
)
from codeflash.code_utils.shell_utils import (
get_shell_rc_path,
save_api_key_to_rc,
)
from codeflash.code_utils.github_utils import get_github_secrets_page_url, require_github_app_or_exit
from codeflash.code_utils.shell_utils import get_shell_rc_path, save_api_key_to_rc
from codeflash.telemetry.posthog_cf import ph
from codeflash.version import __version__ as version
@ -78,12 +70,10 @@ def init_codeflash() -> None:
f" codeflash --file <path-to-file> to optimize all functions in a file{LF}"
f" codeflash --all to optimize all functions in all files in the module you selected ({setup_info.module_root}){LF}"
f"-or-{LF}"
f" codeflash --help to see all options{LF}",
f" codeflash --help to see all options{LF}"
)
if did_add_new_key:
click.echo(
"🐚 Don't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!",
)
click.echo("🐚 Don't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!")
click.echo("Or run the following command to reload:")
if os.name == "nt":
click.echo(f" call {get_shell_rc_path()}")
@ -111,30 +101,16 @@ def collect_setup_info() -> SetupInfo:
curdir = Path.cwd()
# Check if the cwd is writable
if not os.access(curdir, os.W_OK):
click.echo(
f"❌ The current directory isn't writable, please check your folder permissions and try again.{LF}",
)
click.echo(f"❌ The current directory isn't writable, please check your folder permissions and try again.{LF}")
click.echo("It's likely you don't have write permissions for this folder.")
sys.exit(1)
# Check for the existence of pyproject.toml or setup.py
project_name = check_for_toml_or_setup_file()
ignore_subdirs = [
"venv",
"node_modules",
"dist",
"build",
"build_temp",
"build_scripts",
"env",
"logs",
"tmp",
]
ignore_subdirs = ["venv", "node_modules", "dist", "build", "build_temp", "build_scripts", "env", "logs", "tmp"]
valid_subdirs = [
d
for d in next(os.walk("."))[1]
if not d.startswith(".") and not d.startswith("__") and d not in ignore_subdirs
d for d in next(os.walk("."))[1] if not d.startswith(".") and not d.startswith("__") and d not in ignore_subdirs
]
valid_module_subdirs = [d for d in valid_subdirs if d != "tests"]
@ -165,9 +141,7 @@ def collect_setup_info() -> SetupInfo:
message="Where are your tests located? "
f"(If you don't have any tests yet, I can create an empty tests{os.pathsep} directory for you)",
choices=test_subdir_options,
default=(
default_tests_subdir if default_tests_subdir in test_subdir_options else test_subdir_options[0]
),
default=(default_tests_subdir if default_tests_subdir in test_subdir_options else test_subdir_options[0]),
)
if tests_root_answer == create_for_me_option:
@ -285,11 +259,7 @@ def check_for_toml_or_setup_file() -> str | None:
else:
if setup_py_path.exists():
setup_py_content = setup_py_path.read_text(encoding="utf8")
project_name_match = re.search(
r"setup\s*\([^)]*?name\s*=\s*['\"](.*?)['\"]",
setup_py_content,
re.DOTALL,
)
project_name_match = re.search(r"setup\s*\([^)]*?name\s*=\s*['\"](.*?)['\"]", setup_py_content, re.DOTALL)
if project_name_match:
project_name = project_name_match.group(1)
click.echo(f"✅ Found setup.py for your project {project_name}")
@ -300,7 +270,7 @@ def check_for_toml_or_setup_file() -> str | None:
click.echo(
f"💡 I couldn't find a pyproject.toml in the current directory ({curdir}).{LF}"
f"(make sure you're running `codeflash init` from your project's root directory!){LF}"
f"I need this file to store my configuration settings.",
f"I need this file to store my configuration settings."
)
ph("cli-no-pyproject-toml-or-setup-py")
@ -321,14 +291,12 @@ def check_for_toml_or_setup_file() -> str | None:
# Check if the pyproject.toml file was created
if pyproject_toml_path.exists():
click.echo(
f"✅ Created a pyproject.toml file at {pyproject_toml_path}",
)
click.echo(f"✅ Created a pyproject.toml file at {pyproject_toml_path}")
click.pause()
ph("cli-created-pyproject-toml")
except OSError:
click.echo(
"❌ Failed to create pyproject.toml. Please check your disk permissions and available space.",
"❌ Failed to create pyproject.toml. Please check your disk permissions and available space."
)
apologize_and_exit()
else:
@ -341,7 +309,7 @@ def check_for_toml_or_setup_file() -> str | None:
def install_github_actions() -> None:
try:
click.echo(
"⚡️ Codeflash can automatically optimize new Github PRs for you when they're opened. Let's get that set up!",
"⚡️ Codeflash can automatically optimize new Github PRs for you when they're opened. Let's get that set up!"
)
config, config_file_path = parse_config_file()
@ -360,10 +328,7 @@ def install_github_actions() -> None:
message=f"I'm going to create a new GitHub actions workflow file at {optimize_yaml_path} ... is this OK?",
default=True,
)
ph(
"cli-github-optimization-confirm-workflow-creation",
{"confirm_creation": confirm_creation_yes},
)
ph("cli-github-optimization-confirm-workflow-creation", {"confirm_creation": confirm_creation_yes})
if not confirm_creation_yes:
click.echo("⏩️ Exiting workflow creation.")
ph("cli-github-workflow-skipped")
@ -374,14 +339,9 @@ def install_github_actions() -> None:
py_version = sys.version_info
python_version_string = f"'{py_version.major}.{py_version.minor}'"
optimize_yml_content = (
files("codeflash")
.joinpath("cli_cmds", "workflows", "codeflash-optimize.yaml")
.read_text(encoding="utf-8")
)
optimize_yml_content = optimize_yml_content.replace(
"{{ python_version }}",
python_version_string,
files("codeflash").joinpath("cli_cmds", "workflows", "codeflash-optimize.yaml").read_text(encoding="utf-8")
)
optimize_yml_content = optimize_yml_content.replace("{{ python_version }}", python_version_string)
with optimize_yaml_path.open("w", encoding="utf8") as optimize_yml_file:
optimize_yml_file.write(optimize_yml_content)
click.echo(f"✅ Created {optimize_yaml_path}{LF}")
@ -398,7 +358,7 @@ def install_github_actions() -> None:
click.echo(
"🐙 I opened your Github secrets page! Note: if you see a 404, you probably don't have access to this "
"repo's secrets; ask a repo admin to add it for you, or (not super recommended) you can temporarily "
f"hard-code your api key into the workflow file.{LF}",
f"hard-code your api key into the workflow file.{LF}"
)
click.pause()
click.echo()
@ -414,13 +374,13 @@ def install_github_actions() -> None:
click.launch(optimize_yaml_path.as_posix())
click.echo(
"📝 I opened the workflow file in your editor! You'll need to edit the steps that install the right Python "
f"version and any project dependencies. See the comments in the file for more details.{LF}",
f"version and any project dependencies. See the comments in the file for more details.{LF}"
)
click.pause()
click.echo()
click.echo(
f"Please commit and push this GitHub actions file to your repo, and you're all set!{LF}"
f"🚀 Codeflash is now configured to automatically optimize new Github PRs!{LF}",
f"🚀 Codeflash is now configured to automatically optimize new Github PRs!{LF}"
)
ph("cli-github-workflow-created")
except KeyboardInterrupt:
@ -436,14 +396,12 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
except FileNotFoundError:
click.echo(
f"I couldn't find a pyproject.toml in the current directory.{LF}"
f"Please create a new empty pyproject.toml file here, OR if you use poetry then run `poetry init`, OR run `codeflash init` again from a directory with an existing pyproject.toml file.",
f"Please create a new empty pyproject.toml file here, OR if you use poetry then run `poetry init`, OR run `codeflash init` again from a directory with an existing pyproject.toml file."
)
apologize_and_exit()
codeflash_section = tomlkit.table()
codeflash_section.add(
tomlkit.comment("All paths are relative to this pyproject.toml's directory."),
)
codeflash_section.add(tomlkit.comment("All paths are relative to this pyproject.toml's directory."))
codeflash_section["module-root"] = setup_info.module_root
codeflash_section["tests-root"] = setup_info.tests_root
codeflash_section["test-framework"] = setup_info.test_framework
@ -457,7 +415,7 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
elif formatter == "other":
formatter_cmds.append("your-formatter $file")
click.echo(
"🔧 In pyproject.toml, please replace 'your-formatter' with the command you use to format your code.",
"🔧 In pyproject.toml, please replace 'your-formatter' with the command you use to format your code."
)
elif formatter == "don't use a formatter":
formatter_cmds.append("disabled")
@ -483,9 +441,7 @@ def install_github_app() -> None:
owner, repo = get_repo_owner_and_name(git_repo)
if is_github_app_installed_on_repo(owner, repo):
click.echo(
"🐙 Looks like you've already installed the Codeflash GitHub app on this repository! Continuing…",
)
click.echo("🐙 Looks like you've already installed the Codeflash GitHub app on this repository! Continuing…")
else:
click.prompt(
@ -512,7 +468,7 @@ def install_github_app() -> None:
click.echo(
f"❌ It looks like the Codeflash GitHub App is not installed on the repository {owner}/{repo}.{LF}"
f"You won't be able to create PRs with Codeflash until you install the app.{LF}"
f"In the meantime you can make local only optimizations by using the '--no-pr' flag with codeflash.{LF}",
f"In the meantime you can make local only optimizations by using the '--no-pr' flag with codeflash.{LF}"
)
break
click.prompt(
@ -549,15 +505,10 @@ def prompt_api_key() -> bool:
existing_api_key = None
if existing_api_key:
display_key = f"{existing_api_key[:3]}****{existing_api_key[-4:]}"
click.echo(
f"🔑 I found a CODEFLASH_API_KEY in your environment [{display_key}]!",
)
click.echo(f"🔑 I found a CODEFLASH_API_KEY in your environment [{display_key}]!")
use_existing_key = inquirer_wrapper(
inquirer.confirm,
message="Do you want to use this key?",
default=True,
show_default=False,
inquirer.confirm, message="Do you want to use this key?", default=True, show_default=False
)
if use_existing_key:
ph("cli-existing-api-key-used")
@ -584,7 +535,7 @@ def enter_api_key_and_save_to_rc() -> None:
if not browser_launched:
click.echo(
f"Opening your Codeflash API key page. Grab a key from there!{LF}"
"You can also open this link manually: https://app.codeflash.ai/app/apikeys",
"You can also open this link manually: https://app.codeflash.ai/app/apikeys"
)
click.launch("https://app.codeflash.ai/app/apikeys")
browser_launched = True # This does not work on remote consoles
@ -664,22 +615,11 @@ def test_sort():
def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test_path: str) -> None:
command = [
"codeflash",
"--file",
"bubble_sort.py",
"--function",
"sorter",
]
command = ["codeflash", "--file", "bubble_sort.py", "--function", "sorter"]
sys.stdout.write("Running sample optimization... ")
sys.stdout.flush()
try:
process = subprocess.run(
command,
text=True,
cwd=args.module_root,
check=False,
)
process = subprocess.run(command, text=True, cwd=args.module_root, check=False)
finally:
# Delete the bubble_sort.py file after the test
Path(bubble_sort_path).unlink(missing_ok=True)
@ -688,10 +628,8 @@ def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test
click.echo(f"{LF}🗑️ Deleted {bubble_sort_test_path}")
if process.returncode == 0:
click.echo(
f"{LF}✅ End-to-end test passed. Codeflash has been correctly set up!",
)
click.echo(f"{LF}✅ End-to-end test passed. Codeflash has been correctly set up!")
else:
click.echo(
f"{LF}❌ End-to-end test failed. Please check the logs above, and take a look at https://docs.codeflash.ai/getting-started/local-installation for help and troubleshooting.",
f"{LF}❌ End-to-end test failed. Please check the logs above, and take a look at https://docs.codeflash.ai/getting-started/local-installation for help and troubleshooting."
)


@ -11,9 +11,7 @@ console = Console(record=True)
logging.basicConfig(
level=logging.INFO,
handlers=[
RichHandler(rich_tracebacks=True, markup=False, console=console, show_path=False, show_time=False),
],
handlers=[RichHandler(rich_tracebacks=True, markup=False, console=console, show_path=False, show_time=False)],
format=BARE_LOGGING_FORMAT,
)
@ -21,9 +19,7 @@ logger = logging.getLogger("rich")
def paneled_text(
text: str,
panel_args: dict[str, str | bool] | None = None,
text_args: dict[str, str] | None = None,
text: str, panel_args: dict[str, str | bool] | None = None, text_args: dict[str, str] | None = None
) -> None:
from rich.panel import Panel
from rich.text import Text


@ -13,15 +13,7 @@ def set_level(level: int, *, echo_setting: bool = True) -> None:
logging.basicConfig(
level=level,
handlers=[
RichHandler(
rich_tracebacks=True,
markup=False,
console=console,
show_path=False,
show_time=False,
),
],
handlers=[RichHandler(rich_tracebacks=True, markup=False, console=console, show_path=False, show_time=False)],
format=BARE_LOGGING_FORMAT,
)
logging.getLogger().setLevel(level)
@ -31,13 +23,7 @@ def set_level(level: int, *, echo_setting: bool = True) -> None:
logging.basicConfig(
format=VERBOSE_LOGGING_FORMAT,
handlers=[
RichHandler(
rich_tracebacks=True,
markup=False,
console=console,
show_path=False,
show_time=False,
),
RichHandler(rich_tracebacks=True, markup=False, console=console, show_path=False, show_time=False)
],
force=True,
)


@ -22,9 +22,7 @@ if TYPE_CHECKING:
class FutureAliasedImportTransformer(cst.CSTTransformer):
def leave_ImportFrom(
self,
original_node: cst.ImportFrom,
updated_node: cst.ImportFrom,
self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
) -> cst.BaseSmallStatement | cst.FlattenSentinel[cst.BaseSmallStatement] | cst.RemovalSentinel:
if (
(updated_node_module := updated_node.module)
@ -67,7 +65,7 @@ def add_needed_imports_from_module(
filename=src_path.name,
full_module_name=src_module_and_package.name,
full_package_name=src_module_and_package.package,
),
)
)
cst.parse_module(src_module_code).visit(gatherer)
try:
@ -91,12 +89,7 @@ def add_needed_imports_from_module(
if f"{mod}.{alias_pair[0]}" in helper_functions_fqn:
continue
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
RemoveImportsVisitor.remove_unused_import(
dst_context,
mod,
alias_pair[0],
asname=alias_pair[1],
)
RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
try:
parsed_module = cst.parse_module(dst_module_code)
@ -112,9 +105,7 @@ def add_needed_imports_from_module(
return dst_module_code
def get_code(
functions_to_optimize: list[FunctionToOptimize],
) -> tuple[str | None, set[tuple[str, str]]]:
def get_code(functions_to_optimize: list[FunctionToOptimize]) -> tuple[str | None, set[tuple[str, str]]]:
"""Return the code for a function or methods in a Python module. functions_to_optimize is either a singleton
FunctionToOptimize instance, which represents either a function at the module level or a method of a class at the
module level, or it represents a list of methods of the same class.
@ -135,21 +126,15 @@ def get_code(
contextual_dunder_methods: set[tuple[str, str]] = set()
target_code: str = ""
def find_target(
node_list: list[ast.stmt],
name_parts: tuple[str, str] | tuple[str],
) -> ast.AST | None:
target: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Assign | ast.AnnAssign | None = (
None
)
def find_target(node_list: list[ast.stmt], name_parts: tuple[str, str] | tuple[str]) -> ast.AST | None:
target: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Assign | ast.AnnAssign | None = None
node: ast.stmt
for node in node_list:
if (
# The many mypy issues will be fixed once this code moves to the backend,
# using Type Guards as we move to 3.10+.
# We will cover the Type Alias case on the backend since it's a 3.12 feature.
isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef))
and node.name == name_parts[0]
isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)) and node.name == name_parts[0]
):
target = node
break
@ -159,11 +144,7 @@ def get_code(
and len(node.targets) == 1
and isinstance(node.targets[0], ast.Name)
and node.targets[0].id == name_parts[0]
) or (
isinstance(node, ast.AnnAssign)
and hasattr(node.target, "id")
and node.target.id == name_parts[0]
):
) or (isinstance(node, ast.AnnAssign) and hasattr(node.target, "id") and node.target.id == name_parts[0]):
if class_skeleton:
break
target = node
@ -214,16 +195,14 @@ def get_code(
]
else:
logger.error(
f"Error: get_code does not support inner functions: {functions_to_optimize[0].parents}",
)
logger.error(f"Error: get_code does not support inner functions: {functions_to_optimize[0].parents}")
return None, set()
elif len(functions_to_optimize[0].parents) == 0:
qualified_name_parts_list = [(functions_to_optimize[0].function_name,)]
else:
logger.error(
"Error: get_code does not support more than one level of nesting for now. "
f"Parents: {functions_to_optimize[0].parents}",
f"Parents: {functions_to_optimize[0].parents}"
)
return None, set()
for qualified_name_parts in qualified_name_parts_list:
@ -235,32 +214,24 @@ def get_code(
isinstance(target_node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef))
and target_node.decorator_list
):
target_code += "".join(
lines[target_node.decorator_list[0].lineno - 1 : target_node.end_lineno],
)
target_code += "".join(lines[target_node.decorator_list[0].lineno - 1 : target_node.end_lineno])
else:
target_code += "".join(lines[target_node.lineno - 1 : target_node.end_lineno])
if not target_code:
return None, set()
class_list: list[tuple[int, int | None]] = sorted(class_skeleton)
class_code = "".join(
["".join(lines[s_lineno - 1 : e_lineno]) for (s_lineno, e_lineno) in class_list],
)
class_code = "".join(["".join(lines[s_lineno - 1 : e_lineno]) for (s_lineno, e_lineno) in class_list])
return class_code + target_code, contextual_dunder_methods
def extract_code(
functions_to_optimize: list[FunctionToOptimize],
) -> tuple[str | None, set[tuple[str, str]]]:
def extract_code(functions_to_optimize: list[FunctionToOptimize]) -> tuple[str | None, set[tuple[str, str]]]:
edited_code, contextual_dunder_methods = get_code(functions_to_optimize)
if edited_code is None:
return None, set()
try:
compile(edited_code, "edited_code", "exec")
except SyntaxError as e:
logger.exception(
f"extract_code - Syntax error in extracted optimization candidate code: {e}",
)
logger.exception(f"extract_code - Syntax error in extracted optimization candidate code: {e}")
return None, set()
return edited_code, contextual_dunder_methods


@ -20,9 +20,7 @@ def file_path_from_module_name(module_name: str, project_root_path: Path) -> Pat
def get_imports_from_file(
file_path: Path | None = None,
file_string: str | None = None,
file_ast: ast.AST | None = None,
file_path: Path | None = None, file_string: str | None = None, file_ast: ast.AST | None = None
) -> list[ast.Import | ast.ImportFrom]:
assert (
sum([file_path is not None, file_string is not None, file_ast is not None]) == 1


@ -52,16 +52,9 @@ def parse_config_file(config_file_path: Path | None = None) -> tuple[dict[str, A
# default values:
path_keys = ["module-root", "tests-root"]
path_list_keys = ["ignore-paths"]
str_keys = {
"pytest-cmd": "pytest",
}
bool_keys = {
"disable-telemetry": False,
"disable-imports-sorting": False,
}
list_str_keys = {
"formatter-cmds": ["black $file"],
}
str_keys = {"pytest-cmd": "pytest"}
bool_keys = {"disable-telemetry": False, "disable-imports-sorting": False}
list_str_keys = {"formatter-cmds": ["black $file"]}
for key in str_keys:
if key in config:


@ -12,13 +12,13 @@ def get_codeflash_api_key() -> Optional[str]:
if not api_key:
raise OSError(
"I didn't find a Codeflash API key in your environment.\nYou can generate one at "
"https://app.codeflash.ai/app/apikeys,\nthen set it as a CODEFLASH_API_KEY environment variable.",
"https://app.codeflash.ai/app/apikeys,\nthen set it as a CODEFLASH_API_KEY environment variable."
)
if not api_key.startswith("cf-"):
raise OSError(
f"Your Codeflash API key seems to be invalid. It should start with a 'cf-' prefix; I found '{api_key}' "
f"instead.\nYou can generate one at https://app.codeflash.ai/app/apikeys,\nthen set it as a "
f"CODEFLASH_API_KEY environment variable.",
f"CODEFLASH_API_KEY environment variable."
)
return api_key
@ -27,9 +27,9 @@ def ensure_codeflash_api_key() -> bool:
try:
get_codeflash_api_key()
except OSError:
logger.error( # noqa: TRY400
logger.error(
"Codeflash API key not found in your environment.\nYou can generate one at "
"https://app.codeflash.ai/app/apikeys,\nthen set it as a CODEFLASH_API_KEY environment variable.",
"https://app.codeflash.ai/app/apikeys,\nthen set it as a CODEFLASH_API_KEY environment variable."
)
return False
return True
@ -53,6 +53,6 @@ def ensure_pr_number() -> bool:
if not get_pr_number():
raise OSError(
"CODEFLASH_PR_NUMBER not found in environment variables; make sure the Github Action is setting this so "
"Codeflash can comment on the right PR",
"Codeflash can comment on the right PR"
)
return True


@ -10,10 +10,7 @@ import isort
from codeflash.cli_cmds.console import logger
def format_code(
formatter_cmds: list[str],
path: Path,
) -> str | None:
def format_code(formatter_cmds: list[str], path: Path) -> str | None:
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
if not path.exists():
logger.error(f"File {path} does not exist. Cannot format the file.")
@ -29,12 +26,7 @@ def format_code(
logger.info(f"Formatting code with {' '.join(formatter_cmd_list)} ...")
try:
result = subprocess.run(
formatter_cmd_list,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
check=False,
)
result = subprocess.run(formatter_cmd_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False)
except Exception as e:
logger.exception(f"Failed to format code with {' '.join(formatter_cmd_list)}: {e}")
return None


@ -17,25 +17,14 @@ if TYPE_CHECKING:
from git import Repo
def get_git_diff(
repo_directory: Path = Path.cwd(),
uncommitted_changes: bool = False,
) -> dict[str, list[int]]:
def get_git_diff(repo_directory: Path = Path.cwd(), uncommitted_changes: bool = False) -> dict[str, list[int]]:
repository = git.Repo(repo_directory, search_parent_directories=True)
commit = repository.head.commit
if uncommitted_changes:
uni_diff_text = repository.git.diff(
None,
"HEAD",
ignore_blank_lines=True,
ignore_space_at_eol=True,
)
uni_diff_text = repository.git.diff(None, "HEAD", ignore_blank_lines=True, ignore_space_at_eol=True)
else:
uni_diff_text = repository.git.diff(
commit.hexsha + "^1",
commit.hexsha,
ignore_blank_lines=True,
ignore_space_at_eol=True,
commit.hexsha + "^1", commit.hexsha, ignore_blank_lines=True, ignore_space_at_eol=True
)
patch_set = PatchSet(StringIO(uni_diff_text))
change_list: dict[str, list[int]] = {} # list of changes
@ -47,10 +36,7 @@ def get_git_diff(
logger.debug(f"file name: {file_path}")
add_line_no: list[int] = [
line.target_line_no
for hunk in patched_file
for line in hunk
if line.is_added and line.value.strip() != ""
line.target_line_no for hunk in patched_file for line in hunk if line.is_added and line.value.strip() != ""
] # the row number of deleted lines
logger.debug(f"added lines: {add_line_no}")
@ -89,9 +75,7 @@ def get_repo_owner_and_name(repo: Repo | None = None) -> tuple[str, str]:
remote_url = get_remote_url(repo).removesuffix(".git") if remote_url.endswith(".git") else remote_url
split_url = remote_url.split("/")
repo_owner_with_github, repo_name = split_url[-2], split_url[-1]
repo_owner = (
repo_owner_with_github.split(":")[1] if ":" in repo_owner_with_github else repo_owner_with_github
)
repo_owner = repo_owner_with_github.split(":")[1] if ":" in repo_owner_with_github else repo_owner_with_github
return repo_owner, repo_name


@ -20,9 +20,9 @@ def require_github_app_or_exit(owner: str, repo: str) -> None:
f"It looks like the Codeflash GitHub App is not installed on the repository {owner}/{repo} or the GitHub"
f" account linked to your CODEFLASH_API_KEY does not have access to the repository {owner}/{repo}.{LF}"
"Before continuing, please install the Codeflash GitHub App on your repository by visiting "
f"https://github.com/apps/codeflash-ai{LF}",
f"https://github.com/apps/codeflash-ai{LF}"
)
logger.error(
f"Note: if you want to find optimizations without opening PRs, you can run Codeflash with the --no-pr flag.{LF}",
f"Note: if you want to find optimizations without opening PRs, you can run Codeflash with the --no-pr flag.{LF}"
)
apologize_and_exit()


@ -37,11 +37,7 @@ def is_argument_name(name: str, arguments_node: ast.arguments) -> bool:
class InjectPerfOnly(ast.NodeTransformer):
def __init__(
self,
function: FunctionToOptimize,
module_path: str,
test_framework: str,
call_positions: list[CodePosition],
self, function: FunctionToOptimize, module_path: str, test_framework: str, call_positions: list[CodePosition]
) -> None:
self.function_object = function
self.class_name = None
@ -53,11 +49,7 @@ class InjectPerfOnly(ast.NodeTransformer):
self.class_name = function.top_level_parent_name
def find_and_update_line_node(
self,
test_node: ast.stmt,
node_name: str,
index: str,
test_class_name: str | None = None,
self, test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None
) -> Iterable[ast.stmt] | None:
call_node = None
for node in ast.walk(test_node):
@ -123,7 +115,7 @@ class InjectPerfOnly(ast.NodeTransformer):
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
args=[ast.Constant(value=15)],
keywords=[],
),
)
)
i = len(node.body) - 1
while i >= 0:
@ -136,15 +128,9 @@ class InjectPerfOnly(ast.NodeTransformer):
compound_line_node: ast.stmt = line_node.body[j]
internal_node: ast.AST
for internal_node in ast.walk(compound_line_node):
if isinstance(
internal_node,
(ast.stmt, ast.Assign),
):
if isinstance(internal_node, (ast.stmt, ast.Assign)):
updated_node = self.find_and_update_line_node(
internal_node,
node.name,
str(i) + "_" + str(j),
test_class_name,
internal_node, node.name, str(i) + "_" + str(j), test_class_name
)
if updated_node is not None:
line_node.body[j : j + 1] = updated_node
@ -152,12 +138,7 @@ class InjectPerfOnly(ast.NodeTransformer):
break
j -= 1
else:
updated_node = self.find_and_update_line_node(
line_node,
node.name,
str(i),
test_class_name,
)
updated_node = self.find_and_update_line_node(line_node, node.name, str(i), test_class_name)
if updated_node is not None:
node.body[i : i + 1] = updated_node
did_update = True
@ -169,9 +150,7 @@ class InjectPerfOnly(ast.NodeTransformer):
targets=[ast.Name(id="codeflash_iteration", ctx=ast.Store())],
value=ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="os", ctx=ast.Load()),
attr="environ",
ctx=ast.Load(),
value=ast.Name(id="os", ctx=ast.Load()), attr="environ", ctx=ast.Load()
),
slice=ast.Constant(value="CODEFLASH_TEST_ITERATION"),
ctx=ast.Load(),
@ -186,13 +165,11 @@ class InjectPerfOnly(ast.NodeTransformer):
args=[
ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="os", ctx=ast.Load()),
attr="environ",
ctx=ast.Load(),
value=ast.Name(id="os", ctx=ast.Load()), attr="environ", ctx=ast.Load()
),
slice=ast.Constant(value="CODEFLASH_LOOP_INDEX"),
ctx=ast.Load(),
),
)
],
keywords=[],
),
@ -203,26 +180,18 @@ class InjectPerfOnly(ast.NodeTransformer):
targets=[ast.Name(id="codeflash_con", ctx=ast.Store())],
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="sqlite3", ctx=ast.Load()),
attr="connect",
ctx=ast.Load(),
value=ast.Name(id="sqlite3", ctx=ast.Load()), attr="connect", ctx=ast.Load()
),
args=[
ast.JoinedStr(
values=[
ast.Constant(
value=f"{get_run_tmp_file('test_return_values_')}",
),
ast.Constant(value=f"{get_run_tmp_file('test_return_values_')}"),
ast.FormattedValue(
value=ast.Name(
id="codeflash_iteration",
ctx=ast.Load(),
),
conversion=-1,
value=ast.Name(id="codeflash_iteration", ctx=ast.Load()), conversion=-1
),
ast.Constant(value=".sqlite"),
],
),
]
)
],
keywords=[],
),
@ -233,9 +202,7 @@ class InjectPerfOnly(ast.NodeTransformer):
targets=[ast.Name(id="codeflash_cur", ctx=ast.Store())],
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="codeflash_con", ctx=ast.Load()),
attr="cursor",
ctx=ast.Load(),
value=ast.Name(id="codeflash_con", ctx=ast.Load()), attr="cursor", ctx=ast.Load()
),
args=[],
keywords=[],
@ -246,16 +213,14 @@ class InjectPerfOnly(ast.NodeTransformer):
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="codeflash_cur", ctx=ast.Load()),
attr="execute",
ctx=ast.Load(),
value=ast.Name(id="codeflash_cur", ctx=ast.Load()), attr="execute", ctx=ast.Load()
),
args=[
ast.Constant(
value="CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT,"
" test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT,"
" loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB)",
),
" loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB)"
)
],
keywords=[],
),
@ -268,14 +233,12 @@ class InjectPerfOnly(ast.NodeTransformer):
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="codeflash_con", ctx=ast.Load()),
attr="close",
ctx=ast.Load(),
value=ast.Name(id="codeflash_con", ctx=ast.Load()), attr="close", ctx=ast.Load()
),
args=[],
keywords=[],
),
),
)
)
]
)
return node
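Taken together, the nodes in the hunks above splice a small setup-and-teardown scaffold into each instrumented test. A rough Python equivalent, reconstructed from the ast.Name ids and constants visible here, is sketched below; the exact temp-file path and the target of the loop-index assignment are not shown in this excerpt, so those parts are assumptions rather than the transformer's literal output.

# Hypothetical reconstruction of the injected scaffold, not the project's literal output.
import os
import sqlite3

codeflash_iteration = os.environ["CODEFLASH_TEST_ITERATION"]
# Only the os.environ["CODEFLASH_LOOP_INDEX"] read is visible above; the int() call and target name are assumed.
codeflash_loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
# The real prefix comes from get_run_tmp_file("test_return_values_"); a literal path stands in here.
codeflash_con = sqlite3.connect(f"/tmp/test_return_values_{codeflash_iteration}.sqlite")
codeflash_cur = codeflash_con.cursor()
codeflash_cur.execute(
    "CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT,"
    " test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT,"
    " loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB)"
)
# ... the rewritten test body runs here ...
codeflash_con.close()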
@ -381,31 +344,16 @@ def create_wrapper_function() -> ast.FunctionDef:
targets=[ast.Name(id="test_id", ctx=ast.Store())],
value=ast.JoinedStr(
values=[
ast.FormattedValue(
value=ast.Name(id="test_module_name", ctx=ast.Load()),
conversion=-1,
),
ast.FormattedValue(value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.Name(id="test_class_name", ctx=ast.Load()),
conversion=-1,
),
ast.FormattedValue(value=ast.Name(id="test_class_name", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.Name(id="test_name", ctx=ast.Load()),
conversion=-1,
),
ast.FormattedValue(value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.Name(id="line_id", ctx=ast.Load()),
conversion=-1,
),
ast.FormattedValue(value=ast.Name(id="line_id", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.Name(id="loop_index", ctx=ast.Load()),
conversion=-1,
),
],
ast.FormattedValue(value=ast.Name(id="loop_index", ctx=ast.Load()), conversion=-1),
]
),
lineno=lineno + 1,
),
@ -414,10 +362,7 @@ def create_wrapper_function() -> ast.FunctionDef:
op=ast.Not(),
operand=ast.Call(
func=ast.Name(id="hasattr", ctx=ast.Load()),
args=[
ast.Name(id="codeflash_wrap", ctx=ast.Load()),
ast.Constant(value="index"),
],
args=[ast.Name(id="codeflash_wrap", ctx=ast.Load()), ast.Constant(value="index")],
keywords=[],
),
),
@ -425,14 +370,12 @@ def create_wrapper_function() -> ast.FunctionDef:
ast.Assign(
targets=[
ast.Attribute(
value=ast.Name(id="codeflash_wrap", ctx=ast.Load()),
attr="index",
ctx=ast.Store(),
),
value=ast.Name(id="codeflash_wrap", ctx=ast.Load()), attr="index", ctx=ast.Store()
)
],
value=ast.Dict(keys=[], values=[]),
lineno=lineno + 3,
),
)
],
orelse=[],
lineno=lineno + 2,
@ -442,20 +385,14 @@ def create_wrapper_function() -> ast.FunctionDef:
left=ast.Name(id="test_id", ctx=ast.Load()),
ops=[ast.In()],
comparators=[
ast.Attribute(
value=ast.Name(id="codeflash_wrap", ctx=ast.Load()),
attr="index",
ctx=ast.Load(),
),
ast.Attribute(value=ast.Name(id="codeflash_wrap", ctx=ast.Load()), attr="index", ctx=ast.Load())
],
),
body=[
ast.AugAssign(
target=ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="codeflash_wrap", ctx=ast.Load()),
attr="index",
ctx=ast.Load(),
value=ast.Name(id="codeflash_wrap", ctx=ast.Load()), attr="index", ctx=ast.Load()
),
slice=ast.Name(id="test_id", ctx=ast.Load()),
ctx=ast.Store(),
@ -463,36 +400,30 @@ def create_wrapper_function() -> ast.FunctionDef:
op=ast.Add(),
value=ast.Constant(value=1),
lineno=lineno + 5,
),
)
],
orelse=[
ast.Assign(
targets=[
ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="codeflash_wrap", ctx=ast.Load()),
attr="index",
ctx=ast.Load(),
value=ast.Name(id="codeflash_wrap", ctx=ast.Load()), attr="index", ctx=ast.Load()
),
slice=ast.Name(id="test_id", ctx=ast.Load()),
ctx=ast.Store(),
),
)
],
value=ast.Constant(value=0),
lineno=lineno + 6,
),
)
],
lineno=lineno + 4,
),
ast.Assign(
targets=[
ast.Name(id="codeflash_test_index", ctx=ast.Store()),
],
targets=[ast.Name(id="codeflash_test_index", ctx=ast.Store())],
value=ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="codeflash_wrap", ctx=ast.Load()),
attr="index",
ctx=ast.Load(),
value=ast.Name(id="codeflash_wrap", ctx=ast.Load()), attr="index", ctx=ast.Load()
),
slice=ast.Name(id="test_id", ctx=ast.Load()),
ctx=ast.Load(),
@ -503,16 +434,10 @@ def create_wrapper_function() -> ast.FunctionDef:
targets=[ast.Name(id="invocation_id", ctx=ast.Store())],
value=ast.JoinedStr(
values=[
ast.FormattedValue(
value=ast.Name(id="line_id", ctx=ast.Load()),
conversion=-1,
),
ast.FormattedValue(value=ast.Name(id="line_id", ctx=ast.Load()), conversion=-1),
ast.Constant(value="_"),
ast.FormattedValue(
value=ast.Name(id="codeflash_test_index", ctx=ast.Load()),
conversion=-1,
),
],
ast.FormattedValue(value=ast.Name(id="codeflash_test_index", ctx=ast.Load()), conversion=-1),
]
),
lineno=lineno + 8,
),
@ -524,8 +449,7 @@ def create_wrapper_function() -> ast.FunctionDef:
values=[
ast.Constant(value="!######"),
ast.FormattedValue(
value=ast.Name(id="test_module_name", ctx=ast.Load()),
conversion=-1,
value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1
),
ast.Constant(value=":"),
ast.FormattedValue(
@ -540,39 +464,23 @@ def create_wrapper_function() -> ast.FunctionDef:
),
conversion=-1,
),
ast.FormattedValue(
value=ast.Name(id="test_name", ctx=ast.Load()),
conversion=-1,
),
ast.FormattedValue(value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.Name(id="function_name", ctx=ast.Load()),
conversion=-1,
),
ast.FormattedValue(value=ast.Name(id="function_name", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.Name(id="loop_index", ctx=ast.Load()),
conversion=-1,
),
ast.FormattedValue(value=ast.Name(id="loop_index", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.Name(id="invocation_id", ctx=ast.Load()),
conversion=-1,
),
ast.FormattedValue(value=ast.Name(id="invocation_id", ctx=ast.Load()), conversion=-1),
ast.Constant(value="######!"),
],
),
]
)
],
keywords=[],
),
)
),
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="gc", ctx=ast.Load()),
attr="disable",
ctx=ast.Load(),
),
func=ast.Attribute(value=ast.Name(id="gc", ctx=ast.Load()), attr="disable", ctx=ast.Load()),
args=[],
keywords=[],
),
@ -582,9 +490,7 @@ def create_wrapper_function() -> ast.FunctionDef:
targets=[ast.Name(id="counter", ctx=ast.Store())],
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="time", ctx=ast.Load()),
attr="perf_counter_ns",
ctx=ast.Load(),
value=ast.Name(id="time", ctx=ast.Load()), attr="perf_counter_ns", ctx=ast.Load()
),
args=[],
keywords=[],
@ -605,9 +511,7 @@ def create_wrapper_function() -> ast.FunctionDef:
value=ast.BinOp(
left=ast.Call(
func=ast.Attribute(
value=ast.Name(id="time", ctx=ast.Load()),
attr="perf_counter_ns",
ctx=ast.Load(),
value=ast.Name(id="time", ctx=ast.Load()), attr="perf_counter_ns", ctx=ast.Load()
),
args=[],
keywords=[],
@ -619,11 +523,7 @@ def create_wrapper_function() -> ast.FunctionDef:
),
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="gc", ctx=ast.Load()),
attr="enable",
ctx=ast.Load(),
),
func=ast.Attribute(value=ast.Name(id="gc", ctx=ast.Load()), attr="enable", ctx=ast.Load()),
args=[],
keywords=[],
),
@ -641,31 +541,27 @@ def create_wrapper_function() -> ast.FunctionDef:
targets=[ast.Name(id="pickled_return_value", ctx=ast.Store())],
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="pickle", ctx=ast.Load()),
attr="dumps",
ctx=ast.Load(),
value=ast.Name(id="pickle", ctx=ast.Load()), attr="dumps", ctx=ast.Load()
),
args=[ast.Name(id="return_value", ctx=ast.Load())],
keywords=[],
),
lineno=lineno + 15,
),
)
],
orelse=[
ast.Assign(
targets=[ast.Name(id="pickled_return_value", ctx=ast.Store())],
value=ast.Constant(value=None),
lineno=lineno + 16,
),
)
],
# lineno=lineno + 16,
),
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="codeflash_cur", ctx=ast.Load()),
attr="execute",
ctx=ast.Load(),
value=ast.Name(id="codeflash_cur", ctx=ast.Load()), attr="execute", ctx=ast.Load()
),
args=[
ast.Constant(value="INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)"),
@ -690,9 +586,7 @@ def create_wrapper_function() -> ast.FunctionDef:
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="codeflash_con", ctx=ast.Load()),
attr="commit",
ctx=ast.Load(),
value=ast.Name(id="codeflash_con", ctx=ast.Load()), attr="commit", ctx=ast.Load()
),
args=[],
keywords=[],
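The create_wrapper_function hunks above assemble the timing wrapper node by node. The Python it roughly corresponds to is sketched below; the wrapper's signature, the call into the function under test, the name of the duration variable, and the condition guarding pickling all fall outside this excerpt, so those parts are assumptions rather than the generator's exact output.

# Hypothetical reconstruction of the generated wrapper; names follow the ast.Name ids above.
import gc
import pickle
import time


def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name,
                   line_id, loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
    # Assumed signature; only the body statements appear in the hunks above.
    test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
    if not hasattr(codeflash_wrap, "index"):
        codeflash_wrap.index = {}
    if test_id in codeflash_wrap.index:
        codeflash_wrap.index[test_id] += 1
    else:
        codeflash_wrap.index[test_id] = 0
    codeflash_test_index = codeflash_wrap.index[test_id]
    invocation_id = f"{line_id}_{codeflash_test_index}"
    print(f"!######{test_module_name}:{test_class_name}:{test_name}:{function_name}:{loop_index}:{invocation_id}######!")
    gc.disable()
    counter = time.perf_counter_ns()
    return_value = wrapped(*args, **kwargs)  # the actual call is not shown in this excerpt
    codeflash_duration = time.perf_counter_ns() - counter
    gc.enable()
    # Guarded by an If node whose test is not shown above; the else branch stores None.
    pickled_return_value = pickle.dumps(return_value)
    codeflash_cur.execute(
        "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
        (test_module_name, test_class_name, test_name, function_name,
         loop_index, invocation_id, codeflash_duration, pickled_return_value),
    )
    codeflash_con.commit()
    return return_value  # assumed; the return statement is outside this excerpt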

View file

@ -8,10 +8,10 @@ from returns.result import Failure, Result, Success
from codeflash.code_utils.compat import LF
if os.name == "nt": # Windows
SHELL_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(cf-.*)$", re.M)
SHELL_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(cf-.*)$", re.MULTILINE)
SHELL_RC_EXPORT_PREFIX = "set CODEFLASH_API_KEY="
else:
SHELL_RC_EXPORT_PATTERN = re.compile(r'^export CODEFLASH_API_KEY="?(cf-[^\s"]+)"?$', re.M)
SHELL_RC_EXPORT_PATTERN = re.compile(r'^export CODEFLASH_API_KEY="?(cf-[^\s"]+)"?$', re.MULTILINE)
SHELL_RC_EXPORT_PREFIX = "export CODEFLASH_API_KEY="
@ -30,18 +30,11 @@ def get_shell_rc_path() -> Path:
"""Get the path to the user's shell configuration file."""
if os.name == "nt": # on Windows, we use a batch file in the user's home directory
return Path.home() / "codeflash_env.bat"
else:
shell = os.environ.get("SHELL", "/bin/bash").split("/")[-1]
shell_rc_filename = {
"zsh": ".zshrc",
"ksh": ".kshrc",
"csh": ".cshrc",
"tcsh": ".cshrc",
"dash": ".profile",
}.get(
shell, ".bashrc",
) # map each shell to its config file and default to .bashrc
return Path.home() / shell_rc_filename
shell = os.environ.get("SHELL", "/bin/bash").split("/")[-1]
shell_rc_filename = {"zsh": ".zshrc", "ksh": ".kshrc", "csh": ".cshrc", "tcsh": ".cshrc", "dash": ".profile"}.get(
shell, ".bashrc"
) # map each shell to its config file and default to .bashrc
return Path.home() / shell_rc_filename
def get_api_key_export_line(api_key: str) -> str:
@ -61,9 +54,7 @@ def save_api_key_to_rc(api_key) -> Result[str, str]:
if existing_api_key:
# Replace the existing API key line
updated_shell_contents = re.sub(
SHELL_RC_EXPORT_PATTERN, api_key_line, shell_contents,
)
updated_shell_contents = re.sub(SHELL_RC_EXPORT_PATTERN, api_key_line, shell_contents)
action = "Updated CODEFLASH_API_KEY in"
else:
# Append the new API key line
@ -77,11 +68,11 @@ def save_api_key_to_rc(api_key) -> Result[str, str]:
except PermissionError:
return Failure(
f"💡 I tried adding your Codeflash API key to {shell_rc_path} - but seems like I don't have permissions to do so.{LF}"
f"You'll need to open it yourself and add the following line:{LF}{LF}{api_key_line}{LF}",
f"You'll need to open it yourself and add the following line:{LF}{LF}{api_key_line}{LF}"
)
except FileNotFoundError:
return Failure(
f"💡 I went to save your Codeflash API key to {shell_rc_path}, but noticed that it doesn't exist.{LF}"
f"To ensure your Codeflash API key is automatically loaded into your environment at startup, you can create {shell_rc_path} and add the following line:{LF}"
f"{LF}{api_key_line}{LF}",
f"{LF}{api_key_line}{LF}"
)
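The first hunk of this file swaps the re.M alias for its spelled-out form: re.M is simply an alias for re.MULTILINE, so behaviour is unchanged and ^ and $ still anchor at each line of the rc file. A minimal sketch of how the pattern and substitution in save_api_key_to_rc behave, using made-up rc-file contents purely for illustration:

import re

SHELL_RC_EXPORT_PATTERN = re.compile(r'^export CODEFLASH_API_KEY="?(cf-[^\s"]+)"?$', re.MULTILINE)

# Hypothetical rc-file contents, for illustration only.
shell_contents = (
    "export PATH=$HOME/bin:$PATH\n"
    'export CODEFLASH_API_KEY="cf-old-key"\n'
    "alias ll='ls -la'\n"
)
api_key_line = 'export CODEFLASH_API_KEY="cf-new-key"'

if SHELL_RC_EXPORT_PATTERN.search(shell_contents):
    # An export line already exists, so replace it in place, as the code above does.
    updated_shell_contents = SHELL_RC_EXPORT_PATTERN.sub(api_key_line, shell_contents)
else:
    # Otherwise append a fresh export line.
    updated_shell_contents = shell_contents + api_key_line + "\n"
print(updated_shell_contents)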

View file

@ -12,10 +12,7 @@ def humanize_runtime(time_in_ns: int) -> str:
if time_in_ns / 1000 >= 1:
time_micro = float(time_in_ns) / 1000
runtime_human = humanize.precisedelta(
dt.timedelta(microseconds=time_micro),
minimum_unit="microseconds",
)
runtime_human = humanize.precisedelta(dt.timedelta(microseconds=time_micro), minimum_unit="microseconds")
units = re.split(r",|\s", runtime_human)[1]

View file

@ -50,8 +50,7 @@ class TestFunction:
def discover_unit_tests(
cfg: TestConfig,
discover_only_these_tests: list[str] | None = None,
cfg: TestConfig, discover_only_these_tests: list[str] | None = None
) -> dict[str, list[FunctionCalledInTest]]:
if cfg.test_framework == "pytest":
return discover_tests_pytest(cfg, discover_only_these_tests)
@ -78,8 +77,7 @@ def run_pytest_discovery_new_process(queue: Queue, cwd: str, tests_root: str) ->
try:
exitcode = pytest.main(
[tests_root, "--collect-only", "-pno:terminal", "-m", "not skip"],
plugins=[PytestCollectionPlugin()],
[tests_root, "--collect-only", "-pno:terminal", "-m", "not skip"], plugins=[PytestCollectionPlugin()]
)
except Exception as e:
logger.exception(f"Failed to collect tests: {e!s}")
@ -89,9 +87,7 @@ def run_pytest_discovery_new_process(queue: Queue, cwd: str, tests_root: str) ->
queue.put((exitcode, tests, pytest_rootdir))
def parse_pytest_collection_results(
pytest_tests: str,
) -> list[TestsInFile]:
def parse_pytest_collection_results(pytest_tests: str) -> list[TestsInFile]:
test_results: list[TestsInFile] = []
for test in pytest_tests:
test_class = None
@ -106,14 +102,13 @@ def parse_pytest_collection_results(
test_function=test.name,
test_suite=None, # not used in pytest until now
test_type=test_type,
),
)
)
return test_results
def discover_tests_pytest(
cfg: TestConfig,
discover_only_these_tests: list[str] | None = None,
cfg: TestConfig, discover_only_these_tests: list[str] | None = None
) -> dict[str, list[FunctionCalledInTest]]:
tests_root = cfg.tests_root
project_root = cfg.project_root_path
@ -140,8 +135,7 @@ def discover_tests_pytest(
def discover_tests_unittest(
cfg: TestConfig,
discover_only_these_tests: list[str] | None = None,
cfg: TestConfig, discover_only_these_tests: list[str] | None = None
) -> dict[str, list[FunctionCalledInTest]]:
tests_root: Path = cfg.tests_root
loader: unittest.TestLoader = unittest.TestLoader()
@ -184,9 +178,7 @@ def discover_tests_unittest(
if not hasattr(test, "_testMethodName") and hasattr(test, "_tests"):
for test_2 in test._tests:
if not hasattr(test_2, "_testMethodName"):
logger.warning(
f"Didn't find tests for {test_2}",
) # it goes deeper?
logger.warning(f"Didn't find tests for {test_2}") # it goes deeper?
continue
details = get_test_details(test_2)
if details is not None:
@ -207,8 +199,7 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N
def process_test_files(
file_to_test_map: dict[str, list[TestsInFile]],
cfg: TestConfig,
file_to_test_map: dict[str, list[TestsInFile]], cfg: TestConfig
) -> dict[str, list[FunctionCalledInTest]]:
project_root_path = cfg.project_root_path
test_framework = cfg.test_framework
@ -230,9 +221,7 @@ def process_test_files(
function_name = re.split(r"[\[\]]", function)[0]
parameters = re.split(r"[\[\]]", function)[1]
if name.name == function_name and name.type == "function":
test_functions.add(
TestFunction(name.name, None, parameters, functions[i].test_type),
)
test_functions.add(TestFunction(name.name, None, parameters, functions[i].test_type))
elif name.name == function and name.type == "function":
test_functions.add(TestFunction(name.name, None, None, functions[i].test_type))
break
@ -248,24 +237,17 @@ def process_test_files(
and f".{name.name}." in def_name.full_name
):
for function in functions_to_search:
(
is_parameterized,
new_function,
parameters,
) = discover_parameters_unittest(function)
(is_parameterized, new_function, parameters) = discover_parameters_unittest(function)
if is_parameterized and new_function == def_name.name:
test_functions.add(
TestFunction(
def_name.name,
name.name,
parameters,
functions[0].test_type,
), # A test file must not have more than one test type
def_name.name, name.name, parameters, functions[0].test_type
) # A test file must not have more than one test type
)
elif function == def_name.name:
test_functions.add(
TestFunction(def_name.name, name.name, None, functions[0].test_type),
TestFunction(def_name.name, name.name, None, functions[0].test_type)
)
test_functions_list = list(test_functions)
@ -285,10 +267,7 @@ def process_test_files(
scope_parameters = test_functions_list[index].parameters
test_type = test_functions_list[index].test_type
try:
definition = name.goto(
follow_imports=True,
follow_builtin_imports=False,
)
definition = name.goto(follow_imports=True, follow_builtin_imports=False)
except Exception as e:
logger.exception(str(e))
continue
@ -306,9 +285,7 @@ def process_test_files(
if test_framework == "unittest":
scope_test_function += "_" + scope_parameters
full_name_without_module_prefix = definition[0].full_name.replace(
definition[0].module_name + ".",
"",
1,
definition[0].module_name + ".", "", 1
)
qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}"
function_to_test_map[qualified_name_with_modules_from_root].append(
@ -320,11 +297,8 @@ def process_test_files(
test_suite=scope_test_suite,
test_type=test_type,
),
position=CodePosition(
line_no=name.line,
col_no=name.column,
),
),
position=CodePosition(line_no=name.line, col_no=name.column),
)
)
deduped_function_to_test_map = {}
for function, tests in function_to_test_map.items():
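run_pytest_discovery_new_process above drives pytest in --collect-only mode and hands the results to a PytestCollectionPlugin defined elsewhere in this codebase. A hypothetical stand-in for such a plugin (not the project's actual implementation) can capture the collected items through pytest's standard pytest_collection_finish hook:

import pytest


class CollectOnlyPlugin:
    """Hypothetical stand-in for PytestCollectionPlugin, shown only to illustrate the hook."""

    def __init__(self) -> None:
        self.collected = []

    def pytest_collection_finish(self, session) -> None:
        # After collection ends, session.items holds every collected test item.
        self.collected = session.items


plugin = CollectOnlyPlugin()
exitcode = pytest.main(
    ["tests/", "--collect-only", "-pno:terminal", "-m", "not skip"],  # "tests/" is an illustrative path
    plugins=[plugin],
)
for item in plugin.collected:
    print(item.nodeid)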

View file

@ -75,7 +75,7 @@ class FunctionVisitor(cst.CSTVisitor):
parents=list(reversed(ast_parents)),
starting_line=pos.start.line,
ending_line=pos.end.line,
),
)
)
@ -89,11 +89,7 @@ class FunctionWithReturnStatement(ast.NodeVisitor):
# Check if the function has a return statement and add it to the list
if function_has_return_statement(node):
self.functions.append(
FunctionToOptimize(
function_name=node.name,
file_path=self.file_path,
parents=self.ast_path[:],
),
FunctionToOptimize(function_name=node.name, file_path=self.file_path, parents=self.ast_path[:])
)
# Continue visiting the body of the function to find nested functions
self.generic_visit(node)
@ -170,7 +166,7 @@ def get_functions_to_optimize(
bool(optimize_all),
bool(replay_test),
bool(file),
],
]
)
<= 1
), "Only one of optimize_all, replay_test, or file should be provided"
@ -180,9 +176,7 @@ def get_functions_to_optimize(
functions = get_all_files_and_functions(Path(optimize_all))
elif replay_test is not None:
functions = get_all_replay_test_functions(
replay_test=replay_test,
test_cfg=test_cfg,
project_root_path=project_root,
replay_test=replay_test, test_cfg=test_cfg, project_root_path=project_root
)
elif file is not None:
@ -213,11 +207,7 @@ def get_functions_to_optimize(
ph("cli-optimizing-git-diff")
functions = get_functions_within_git_diff()
filtered_modified_functions, functions_count = filter_functions(
functions,
test_cfg.tests_root,
ignore_paths,
project_root,
module_root,
functions, test_cfg.tests_root, ignore_paths, project_root, module_root
)
logger.info(f"Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize")
return filtered_modified_functions, functions_count
@ -277,9 +267,7 @@ def find_all_functions_in_file(file_path: Path) -> dict[str, list[FunctionToOpti
def get_all_replay_test_functions(
replay_test: str,
test_cfg: TestConfig,
project_root_path: Path,
replay_test: str, test_cfg: TestConfig, project_root_path: Path
) -> dict[str, list[FunctionToOptimize]]:
function_tests = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test])
# Get the absolute file paths for each function, excluding class name if present
@ -295,8 +283,7 @@ def get_all_replay_test_functions(
module_path_parts[-1]
if module_path_parts
and is_class_defined_in_file(
module_path_parts[-1],
Path(project_root_path, *module_path_parts[:-1]).with_suffix(".py"),
module_path_parts[-1], Path(project_root_path, *module_path_parts[:-1]).with_suffix(".py")
)
else None
)
@ -310,9 +297,7 @@ def get_all_replay_test_functions(
file_path = Path(project_root_path, *file_path_parts).with_suffix(".py")
file_to_functions_map[file_path].append((function, function_name, class_name))
for file_path, functions in file_to_functions_map.items():
all_valid_functions: dict[str, list[FunctionToOptimize]] = find_all_functions_in_file(
file_path=file_path,
)
all_valid_functions: dict[str, list[FunctionToOptimize]] = find_all_functions_in_file(file_path=file_path)
filtered_list = []
for function in functions:
function_name, function_name_only, class_name = function
@ -321,7 +306,7 @@ def get_all_replay_test_functions(
valid_function
for valid_function in all_valid_functions[file_path]
if valid_function.qualified_name == function_name
],
]
)
if len(filtered_list):
filtered_valid_functions[file_path] = filtered_list
@ -341,19 +326,13 @@ def is_git_repo(file_path: str) -> bool:
def ignored_submodule_paths(module_root: str) -> list[str]:
if is_git_repo(module_root):
git_repo = git.Repo(module_root, search_parent_directories=True)
return [
Path(git_repo.working_tree_dir, submodule.path).resolve() for submodule in git_repo.submodules
]
return [Path(git_repo.working_tree_dir, submodule.path).resolve() for submodule in git_repo.submodules]
return []
class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
def __init__(
self,
file_name: Path,
function_or_method_name: str,
class_name: str | None = None,
line_no: int | None = None,
self, file_name: Path, function_or_method_name: str, class_name: str | None = None, line_no: int | None = None
) -> None:
self.file_name = file_name
self.class_name = class_name
@ -374,7 +353,7 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
bool(node.args.kwarg),
bool(node.args.posonlyargs),
bool(node.args.vararg),
),
)
)
def visit_ClassDef(self, node: ast.ClassDef) -> None:
@ -410,10 +389,7 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
def inspect_top_level_functions_or_methods(
file_name: Path,
function_or_method_name: str,
class_name: str | None = None,
line_no: int | None = None,
file_name: Path, function_or_method_name: str, class_name: str | None = None, line_no: int | None = None
) -> FunctionProperties:
with open(file_name, encoding="utf8") as file:
try:
@ -422,10 +398,7 @@ def inspect_top_level_functions_or_methods(
logger.exception(e)
return False
visitor = TopLevelFunctionOrMethodVisitor(
file_name=file_name,
function_or_method_name=function_or_method_name,
class_name=class_name,
line_no=line_no,
file_name=file_name, function_or_method_name=function_or_method_name, class_name=class_name, line_no=line_no
)
visitor.visit(ast_module)
staticmethod_class_name = visitor.class_name if visitor.is_staticmethod else None
@ -495,9 +468,7 @@ def filter_functions(
path = Path(function.file_path).name
if path in blocklist_funcs and function.function_name in blocklist_funcs[path]:
functions.remove(function)
logger.debug(
f"Skipping {function.function_name} in {path} as it has already been optimized",
)
logger.debug(f"Skipping {function.function_name} in {path} as it has already been optimized")
continue
filtered_modified_functions[file_path] = functions
@ -517,12 +488,7 @@ def filter_functions(
return {Path(k): v for k, v in filtered_modified_functions.items() if v}, functions_count
def filter_files_optimized(
file_path: Path,
tests_root: Path,
ignore_paths: list[Path],
module_root: Path,
) -> bool:
def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list[Path], module_root: Path) -> bool:
"""Optimized version of the filter_functions function above.
Takes in file paths and returns the count of files that are to be optimized.
@ -530,9 +496,7 @@ def filter_files_optimized(
submodule_paths = None
if file_path.is_relative_to(tests_root):
return False
if file_path in ignore_paths or any(
file_path.is_relative_to(ignore_path) for ignore_path in ignore_paths
):
if file_path in ignore_paths or any(file_path.is_relative_to(ignore_path) for ignore_path in ignore_paths):
return False
if path_belongs_to_site_packages(file_path):
return False

View file

@ -16,9 +16,7 @@ from codeflash.telemetry.sentry import init_sentry
def main() -> None:
"""Entry point for the codeflash command-line interface."""
paneled_text(
CODEFLASH_LOGO,
panel_args={"title": "https://codeflash.ai", "expand": False},
text_args={"style": "bold gold3"},
CODEFLASH_LOGO, panel_args={"title": "https://codeflash.ai", "expand": False}, text_args={"style": "bold gold3"}
)
args = parse_args()
if args.command:

View file

@ -1,7 +1,7 @@
from __future__ import annotations
from pathlib import Path
from typing import Any, Generator, Iterator, Optional
from typing import Iterator, Optional
from jedi.api.classes import Name
from pydantic import BaseModel
@ -66,9 +66,7 @@ class TestFiles(BaseModel):
test_files: list[TestFile]
def get_by_type(self, test_type: TestType) -> TestFiles:
return TestFiles(
test_files=[test_file for test_file in self.test_files if test_file.test_type == test_type],
)
return TestFiles(test_files=[test_file for test_file in self.test_files if test_file.test_type == test_type])
def add(self, test_file: TestFile) -> None:
if test_file not in self.test_files:
@ -77,29 +75,17 @@ class TestFiles(BaseModel):
raise ValueError("Test file already exists in the list")
def get_by_original_file_path(self, file_path: Path) -> TestFile | None:
return next(
(test_file for test_file in self.test_files if test_file.original_file_path == file_path),
None,
)
return next((test_file for test_file in self.test_files if test_file.original_file_path == file_path), None)
def get_test_type_by_instrumented_file_path(self, file_path: Path) -> TestType | None:
return next(
(
test_file.test_type
for test_file in self.test_files
if test_file.instrumented_file_path == file_path
),
(test_file.test_type for test_file in self.test_files if test_file.instrumented_file_path == file_path),
None,
)
def get_test_type_by_original_file_path(self, file_path: Path) -> TestType | None:
return next(
(
test_file.test_type
for test_file in self.test_files
if test_file.original_file_path == file_path
),
None,
(test_file.test_type for test_file in self.test_files if test_file.original_file_path == file_path), None
)
def __iter__(self) -> Iterator[TestFile]:

View file

@ -36,9 +36,7 @@ def belongs_to_function(name: Name, function_name: str) -> bool:
def get_type_annotation_context(
function: FunctionToOptimize,
jedi_script: jedi.Script,
project_root_path: Path,
function: FunctionToOptimize, jedi_script: jedi.Script, project_root_path: Path
) -> tuple[list[FunctionSource], set[tuple[str, str]]]:
function_name: str = function.function_name
file_path: Path = function.file_path
@ -53,18 +51,11 @@ def get_type_annotation_context(
contextual_dunder_methods = set()
def get_annotation_source(
j_script: jedi.Script,
name: str,
node_parents: list[FunctionParent],
line_no: int,
col_no: str,
j_script: jedi.Script, name: str, node_parents: list[FunctionParent], line_no: int, col_no: str
) -> None:
try:
definition: list[Name] = j_script.goto(
line=line_no,
column=col_no,
follow_imports=True,
follow_builtin_imports=False,
line=line_no, column=col_no, follow_imports=True, follow_builtin_imports=False
)
except Exception as ex:
if hasattr(name, "full_name"):
@ -82,15 +73,7 @@ def get_type_annotation_context(
and not path_belongs_to_site_packages(definition_path)
and not belongs_to_function(definition[0], function_name)
):
source_code = get_code(
[
FunctionToOptimize(
definition[0].name,
definition_path,
node_parents[:-1],
),
],
)
source_code = get_code([FunctionToOptimize(definition[0].name, definition_path, node_parents[:-1])])
if source_code[0]:
sources.append(
FunctionSource(
@ -98,25 +81,21 @@ def get_type_annotation_context(
jedi_definition=definition[0],
source_code=source_code[0],
file_path=definition_path,
qualified_name=definition[0].full_name.removeprefix(
definition[0].module_name + ".",
),
qualified_name=definition[0].full_name.removeprefix(definition[0].module_name + "."),
only_function_name=definition[0].name,
),
)
)
contextual_dunder_methods.update(source_code[1])
def visit_children(
node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module,
node_parents: list[FunctionParent],
node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module, node_parents: list[FunctionParent]
) -> None:
child: ast.AST | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module
for child in ast.iter_child_nodes(node):
visit(child, node_parents)
def visit_all_annotation_children(
node: ast.Subscript | ast.Name | ast.BinOp,
node_parents: list[FunctionParent],
node: ast.Subscript | ast.Name | ast.BinOp, node_parents: list[FunctionParent]
) -> None:
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
visit_all_annotation_children(node.left, node_parents)
@ -165,8 +144,7 @@ def get_type_annotation_context(
def get_function_variables_definitions(
function_to_optimize: FunctionToOptimize,
project_root_path: Path,
function_to_optimize: FunctionToOptimize, project_root_path: Path
) -> tuple[list[FunctionSource], set[tuple[str, str]]]:
function_name = function_to_optimize.function_name
file_path = function_to_optimize.file_path
@ -181,20 +159,16 @@ def get_function_variables_definitions(
if ref.full_name:
if function_to_optimize.parents:
# Check if the reference belongs to the specified class when FunctionParent is provided
if belongs_to_class(
ref,
function_to_optimize.parents[-1].name,
) and belongs_to_function(ref, function_name):
if belongs_to_class(ref, function_to_optimize.parents[-1].name) and belongs_to_function(
ref, function_name
):
names.append(ref)
elif belongs_to_function(ref, function_name):
names.append(ref)
for name in names:
try:
definitions: list[Name] = name.goto(
follow_imports=True,
follow_builtin_imports=False,
)
definitions: list[Name] = name.goto(follow_imports=True, follow_builtin_imports=False)
except Exception as e:
try:
logger.exception(f"Error while getting definition for {name.full_name}: {e}")
@ -221,13 +195,7 @@ def get_function_variables_definitions(
parents = [FunctionParent(m.group(1), "ClassDef")]
source_code = get_code(
[
FunctionToOptimize(
function_name=definitions[0].name,
file_path=definition_path,
parents=parents,
),
],
[FunctionToOptimize(function_name=definitions[0].name, file_path=definition_path, parents=parents)]
)
if source_code[0]:
sources.append(
@ -236,24 +204,18 @@ def get_function_variables_definitions(
jedi_definition=definition,
source_code=source_code[0],
file_path=definition_path,
qualified_name=definition.full_name.removeprefix(
definition.module_name + ".",
),
qualified_name=definition.full_name.removeprefix(definition.module_name + "."),
only_function_name=definition.name,
),
)
)
contextual_dunder_methods.update(source_code[1])
annotation_sources, annotation_dunder_methods = get_type_annotation_context(
function_to_optimize,
script,
project_root_path,
function_to_optimize, script, project_root_path
)
sources[:0] = annotation_sources # prepend the annotation sources
contextual_dunder_methods.update(annotation_dunder_methods)
existing_fully_qualified_names = set()
no_parent_sources: dict[Path, dict[str, set[FunctionSource]]] = defaultdict(
lambda: defaultdict(set),
)
no_parent_sources: dict[Path, dict[str, set[FunctionSource]]] = defaultdict(lambda: defaultdict(set))
parent_sources = set()
for source in sources:
if (fully_qualified_name := source.fully_qualified_name) not in existing_fully_qualified_names:
@ -269,10 +231,7 @@ def get_function_variables_definitions(
or source.qualified_name.rpartition(".")[0] not in no_parent_sources[source.file_path]
]
deduped_no_parent_sources = [
source
for k1 in no_parent_sources
for k2 in no_parent_sources[k1]
for source in no_parent_sources[k1][k2]
source for k1 in no_parent_sources for k2 in no_parent_sources[k1] for source in no_parent_sources[k1][k2]
]
return deduped_no_parent_sources + deduped_parent_sources, contextual_dunder_methods
@ -288,10 +247,7 @@ def get_constrained_function_context_and_helper_functions(
) -> tuple[str, list[FunctionSource], set[tuple[str, str]]]:
# TODO: Not just do static analysis, but also find the datatypes of function arguments by running the existing
# unittests and inspecting the arguments to resolve the real definitions and dependencies.
helper_functions, dunder_methods = get_function_variables_definitions(
function_to_optimize,
project_root_path,
)
helper_functions, dunder_methods = get_function_variables_definitions(function_to_optimize, project_root_path)
tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
code_to_optimize_tokens = tokenizer.encode(code_to_optimize)
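The definition lookups in this file go through jedi's goto, following imports but not builtin modules, before tiktoken is used to budget how much helper context fits. A small, self-contained sketch of that jedi call pattern, with an invented snippet and file name:

import jedi

source = (
    "from math import sqrt\n"
    "def hypot(a, b):\n"
    "    return sqrt(a * a + b * b)\n"
)
script = jedi.Script(code=source, path="example.py")

# Resolve the `sqrt` reference on line 3, column 11 back to its definition,
# following imports the same way the code above does.
definitions = script.goto(line=3, column=11, follow_imports=True, follow_builtin_imports=False)
for definition in definitions:
    print(definition.full_name, definition.module_name, definition.module_path)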

View file

@ -14,22 +14,12 @@ import libcst as cst
from returns.pipeline import is_successful
from returns.result import Failure, Success
from codeflash.api.aiservice import (
AiServiceClient,
LocalAiServiceClient,
)
from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
from codeflash.cli_cmds.console import code_print, logger
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_extractor import (
add_needed_imports_from_module,
extract_code,
find_preexisting_objects,
)
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, extract_code, find_preexisting_objects
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
from codeflash.code_utils.code_utils import (
get_run_tmp_file,
module_name_from_file_path,
)
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
from codeflash.code_utils.config_consts import (
INDIVIDUAL_TESTCASE_TIMEOUT,
N_CANDIDATES,
@ -37,21 +27,11 @@ from codeflash.code_utils.config_consts import (
TOTAL_LOOPING_TIME,
)
from codeflash.code_utils.formatter import format_code, sort_imports
from codeflash.code_utils.instrument_existing_tests import (
inject_profiling_into_existing_test,
)
from codeflash.code_utils.remove_generated_tests import (
remove_functions_from_generated_tests,
)
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.discovery.discover_unit_tests import (
discover_unit_tests,
)
from codeflash.discovery.functions_to_optimize import (
FunctionParent,
FunctionToOptimize,
get_functions_to_optimize,
)
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize, get_functions_to_optimize
from codeflash.models.ExperimentMetadata import ExperimentMetadata
from codeflash.models.models import (
BestOptimization,
@ -64,9 +44,7 @@ from codeflash.models.models import (
TestFile,
TestFiles,
)
from codeflash.optimization.function_context import (
get_constrained_function_context_and_helper_functions,
)
from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions
from codeflash.result.create_pr import check_create_pr, existing_tests_source_for
from codeflash.result.critic import performance_gain, quantity_of_tests_critic, speedup_critic
from codeflash.result.explanation import Explanation
@ -83,15 +61,9 @@ if TYPE_CHECKING:
from returns.result import Result
from codeflash.api.aiservice import (
OptimizedCandidate,
)
from codeflash.discovery.discover_unit_tests import (
FunctionCalledInTest,
)
from codeflash.models.models import (
FunctionSource,
)
from codeflash.api.aiservice import OptimizedCandidate
from codeflash.discovery.discover_unit_tests import FunctionCalledInTest
from codeflash.models.models import FunctionSource
class Optimizer:
@ -121,10 +93,7 @@ class Optimizer:
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]
num_optimizable_functions: int
(
file_to_funcs_to_optimize,
num_optimizable_functions,
) = get_functions_to_optimize(
(file_to_funcs_to_optimize, num_optimizable_functions) = get_functions_to_optimize(
optimize_all=self.args.all,
replay_test=self.args.replay_test,
file=self.args.file,
@ -140,24 +109,15 @@ class Optimizer:
function_iterator_count: int = 0
try:
ph(
"cli-optimize-functions-to-optimize",
{"num_functions": num_optimizable_functions},
)
ph("cli-optimize-functions-to-optimize", {"num_functions": num_optimizable_functions})
if num_optimizable_functions == 0:
logger.info("No functions found to optimize. Exiting...")
return
logger.info(f"Discovering existing unit tests in {self.test_cfg.tests_root} ...")
function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(
self.test_cfg,
)
num_discovered_tests: int = sum(
[len(value) for value in function_to_tests.values()],
)
logger.info(
f"Discovered {num_discovered_tests} existing unit tests in {self.test_cfg.tests_root}",
)
function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg)
num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()])
logger.info(f"Discovered {num_discovered_tests} existing unit tests in {self.test_cfg.tests_root}")
ph("cli-optimize-discovered-tests", {"num_tests": num_discovered_tests})
for path in file_to_funcs_to_optimize:
logger.info(f"Examining file {path} ...")
@ -168,14 +128,10 @@ class Optimizer:
function_iterator_count += 1
logger.info(
f"Optimizing function {function_iterator_count} of {num_optimizable_functions} - "
f"{function_to_optimize.qualified_name}",
f"{function_to_optimize.qualified_name}"
)
best_optimization = self.optimize_function(
function_to_optimize,
function_to_tests,
original_code,
)
best_optimization = self.optimize_function(function_to_optimize, function_to_tests, original_code)
self.test_files = TestFiles(test_files=[])
if is_successful(best_optimization):
optimizations_found += 1
@ -207,11 +163,7 @@ class Optimizer:
logger.debug(f"Function Trace ID: {function_trace_id}")
ph("cli-optimize-function-start", {"function_trace_id": function_trace_id})
self.cleanup_leftover_test_return_values()
ctx_result = self.get_code_optimization_context(
function_to_optimize,
self.args.project_root,
original_code,
)
ctx_result = self.get_code_optimization_context(function_to_optimize, self.args.project_root, original_code)
if not is_successful(ctx_result):
return Failure(ctx_result.failure())
code_context: CodeOptimizationContext = ctx_result.unwrap()
@ -236,8 +188,7 @@ class Optimizer:
)
instrumented_unittests_created_for_function = self.instrument_existing_tests(
function_to_optimize=function_to_optimize,
function_to_tests=function_to_tests,
function_to_optimize=function_to_optimize, function_to_tests=function_to_tests
)
logger.info(f"Generating new tests for function {function_to_optimize.function_name} ...")
@ -257,12 +208,7 @@ class Optimizer:
count_tests = len(generated_tests.generated_tests)
generated_tests_paths = [
get_test_file_path(
self.args.tests_root,
function_to_optimize.function_name,
i,
)
for i in range(count_tests)
get_test_file_path(self.args.tests_root, function_to_optimize.function_name, i) for i in range(count_tests)
]
for i, generated_test in enumerate(generated_tests.generated_tests):
@ -275,7 +221,7 @@ class Optimizer:
original_file_path=None,
original_source=generated_test.generated_original_test_source,
test_type=TestType.GENERATED_REGRESSION,
),
)
)
logger.info(f"Generated test {i + 1}/{count_tests}:")
code_print(generated_test.generated_original_test_source)
@ -297,15 +243,12 @@ class Optimizer:
# TODO: Postprocess the optimized function to include the original docstring and such
best_optimization = None
for u, candidates in enumerate(
[optimizations_set.control, optimizations_set.experiment],
):
for u, candidates in enumerate([optimizations_set.control, optimizations_set.experiment]):
if candidates is None:
continue
tests_in_file: list[FunctionCalledInTest] = function_to_tests.get(
function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root),
[],
function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root), []
)
best_optimization = self.determine_best_candidate(
@ -315,16 +258,13 @@ class Optimizer:
original_code=original_code,
original_code_baseline=original_code_baseline,
original_helper_code=original_helper_code,
function_trace_id=function_trace_id[:-4] + f"EXP{u}"
if should_run_experiment
else function_trace_id,
function_trace_id=function_trace_id[:-4] + f"EXP{u}" if should_run_experiment else function_trace_id,
only_run_this_test_function=tests_in_file,
)
ph("cli-optimize-function-finished", {"function_trace_id": function_trace_id})
generated_tests = remove_functions_from_generated_tests(
generated_tests=generated_tests,
test_functions_to_remove=test_functions_to_remove,
generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
)
if best_optimization:
@ -340,12 +280,7 @@ class Optimizer:
file_path=function_to_optimize.file_path,
)
self.log_successful_optimization(
explanation,
function_to_optimize,
function_trace_id,
generated_tests,
)
self.log_successful_optimization(explanation, function_to_optimize, function_trace_id, generated_tests)
self.replace_function_and_helpers_with_optimized_code(
code_context=code_context,
@ -355,9 +290,7 @@ class Optimizer:
)
new_code, new_helper_code = self.reformat_code_and_helpers(
code_context.helper_functions,
explanation.file_path,
original_code,
code_context.helper_functions, explanation.file_path, original_code
)
existing_tests = existing_tests_source_for(
@ -377,7 +310,7 @@ class Optimizer:
explanation=explanation,
existing_tests_source=existing_tests,
generated_original_test_source="\n".join(
[test.generated_original_test_source for test in generated_tests.generated_tests],
[test.generated_original_test_source for test in generated_tests.generated_tests]
),
function_trace_id=function_trace_id,
)
@ -386,17 +319,14 @@ class Optimizer:
# a) Error propagation, where error in one function can cause the next optimization to fail
# b) Performance estimates become unstable, as the runtime of an optimization might be
# dependent on the runtime of the previous optimization
self.write_code_and_helpers(
original_code,
original_helper_code,
function_to_optimize.file_path,
)
self.write_code_and_helpers(original_code, original_helper_code, function_to_optimize.file_path)
for generated_test_path in generated_tests_paths:
generated_test_path.unlink(missing_ok=True)
for test_paths in instrumented_unittests_created_for_function:
test_paths.unlink(missing_ok=True)
if not best_optimization:
return Failure(f"No best optimizations found for function {function_to_optimize.qualified_name}")
logger.info("----------------")
return Success(best_optimization)
def determine_best_candidate(
@ -419,7 +349,7 @@ class Optimizer:
is_correct = {}
logger.info(
f"Determining best optimized candidate (out of {len(candidates)}) for {function_to_optimize.qualified_name} ...",
f"Determining best optimized candidate (out of {len(candidates)}) for {function_to_optimize.qualified_name} ..."
)
try:
for candidate_index, candidate in enumerate(candidates, start=1):
@ -436,21 +366,12 @@ class Optimizer:
)
if not did_update:
logger.warning(
"No functions were replaced in the optimized code. Skipping optimization candidate.",
"No functions were replaced in the optimized code. Skipping optimization candidate."
)
continue
except (
ValueError,
SyntaxError,
cst.ParserSyntaxError,
AttributeError,
) as e:
except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e:
logger.error(e)
self.write_code_and_helpers(
original_code,
original_helper_code,
function_to_optimize.file_path,
)
self.write_code_and_helpers(original_code, original_helper_code, function_to_optimize.file_path)
continue
# Run generated tests if at least one of them passed
@ -477,34 +398,27 @@ class Optimizer:
optimized_runtimes[candidate.optimization_id] = best_test_runtime
is_correct[candidate.optimization_id] = True
perf_gain = performance_gain(
original_runtime_ns=original_code_baseline.runtime,
optimized_runtime_ns=best_test_runtime,
original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_test_runtime
)
speedup_ratios[candidate.optimization_id] = perf_gain
loop_count = (
max(all_loop_indices)
if (
all_loop_indices := {
result.loop_index for result in candidate_result.best_test_results
}
)
if (all_loop_indices := {result.loop_index for result in candidate_result.best_test_results})
else 1
)
logger.info(
f"Candidate code runtime measured over {loop_count} loop{'s' if loop_count > 1 else ''}: {humanize_runtime(best_test_runtime)} per full loop.\n"
f"Speedup ratio: {perf_gain:.3f}",
f"Speedup ratio: {perf_gain:.3f}"
)
if speedup_critic(
candidate_result,
original_code_baseline.runtime,
best_runtime_until_now,
candidate_result, original_code_baseline.runtime, best_runtime_until_now
) and quantity_of_tests_critic(candidate_result):
logger.info("This candidate is faster than the previous best candidate.")
logger.info(
f"Original runtime: {humanize_runtime(original_code_baseline.runtime)}\n"
f"Best test runtime: {humanize_runtime(candidate_result.best_test_runtime)}\n"
f"Speedup ratio: {perf_gain:.3f}",
f"Speedup ratio: {perf_gain:.3f}"
)
best_optimization = BestOptimization(
candidate=candidate,
@ -514,18 +428,9 @@ class Optimizer:
)
best_runtime_until_now = best_test_runtime
self.write_code_and_helpers(
original_code,
original_helper_code,
function_to_optimize.file_path,
)
logger.info("----------------")
self.write_code_and_helpers(original_code, original_helper_code, function_to_optimize.file_path)
except KeyboardInterrupt as e:
self.write_code_and_helpers(
original_code,
original_helper_code,
function_to_optimize.file_path,
)
self.write_code_and_helpers(original_code, original_helper_code, function_to_optimize.file_path)
logger.exception(f"Optimization interrupted: {e}")
raise e
@ -545,9 +450,7 @@ class Optimizer:
function_trace_id: str,
generated_tests: GeneratedTestsList,
) -> None:
logger.info(
f"⚡️ Optimization successful! 📄 {function_to_optimize.qualified_name} in {explanation.file_path}",
)
logger.info(f"⚡️ Optimization successful! 📄 {function_to_optimize.qualified_name} in {explanation.file_path}")
logger.info(f"📈 {explanation.perf_improvement_line}")
logger.info(f"Explanation: \n{explanation.to_console_string()}")
@ -572,11 +475,7 @@ class Optimizer:
)
@staticmethod
def write_code_and_helpers(
original_code: str,
original_helper_code: dict[Path, str],
path: Path,
) -> None:
def write_code_and_helpers(original_code: str, original_helper_code: dict[Path, str], path: Path) -> None:
with path.open("w", encoding="utf8") as f:
f.write(original_code)
for module_abspath in original_helper_code:
@ -584,29 +483,20 @@ class Optimizer:
f.write(original_helper_code[module_abspath])
def reformat_code_and_helpers(
self,
helper_functions: list[FunctionSource],
path: Path,
original_code: str,
self, helper_functions: list[FunctionSource], path: Path, original_code: str
) -> tuple[str, dict[Path, str]]:
should_sort_imports = not self.args.disable_imports_sorting
if should_sort_imports and isort.code(original_code) != original_code:
should_sort_imports = False
new_code = format_code(
self.args.formatter_cmds,
path,
)
new_code = format_code(self.args.formatter_cmds, path)
if should_sort_imports and new_code is not None:
new_code = sort_imports(new_code)
new_helper_code: dict[Path, str] = {}
helper_functions_paths = {hf.file_path for hf in helper_functions}
for module_abspath in helper_functions_paths:
formatted_helper_code = format_code(
self.args.formatter_cmds,
module_abspath,
)
formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath)
if should_sort_imports and formatted_helper_code is not None:
formatted_helper_code = sort_imports(formatted_helper_code)
if formatted_helper_code is not None:
@ -633,13 +523,8 @@ class Optimizer:
helper_functions_by_module_abspath = defaultdict(set)
for helper_function in code_context.helper_functions:
if helper_function.jedi_definition.type != "class":
helper_functions_by_module_abspath[helper_function.file_path].add(
helper_function.qualified_name,
)
for (
module_abspath,
qualified_names,
) in helper_functions_by_module_abspath.items():
helper_functions_by_module_abspath[helper_function.file_path].add(helper_function.qualified_name)
for module_abspath, qualified_names in helper_functions_by_module_abspath.items():
did_update |= replace_function_definitions_in_module(
function_names=list(qualified_names),
optimized_code=optimized_code,
@ -652,24 +537,13 @@ class Optimizer:
return did_update
def get_code_optimization_context(
self,
function_to_optimize: FunctionToOptimize,
project_root: Path,
original_source_code: str,
self, function_to_optimize: FunctionToOptimize, project_root: Path, original_source_code: str
) -> Result[CodeOptimizationContext, str]:
code_to_optimize, contextual_dunder_methods = extract_code(
[function_to_optimize],
)
code_to_optimize, contextual_dunder_methods = extract_code([function_to_optimize])
if code_to_optimize is None:
return Failure("Could not find function to optimize.")
(
helper_code,
helper_functions,
helper_dunder_methods,
) = get_constrained_function_context_and_helper_functions(
function_to_optimize,
self.args.project_root,
code_to_optimize,
(helper_code, helper_functions, helper_dunder_methods) = get_constrained_function_context_and_helper_functions(
function_to_optimize, self.args.project_root, code_to_optimize
)
if function_to_optimize.parents:
function_class = function_to_optimize.parents[0].name
@ -695,9 +569,7 @@ class Optimizer:
dedup_optimizable_methods.append(method)
added_methods.add(f"{method.file_path}.{method.qualified_name}")
if len(dedup_optimizable_methods) > 1:
code_to_optimize, contextual_dunder_methods = extract_code(
list(reversed(dedup_optimizable_methods)),
)
code_to_optimize, contextual_dunder_methods = extract_code(list(reversed(dedup_optimizable_methods)))
if code_to_optimize is None:
return Failure("Could not find function to optimize.")
code_to_optimize_with_helpers = helper_code + "\n" + code_to_optimize
@ -718,7 +590,7 @@ class Optimizer:
contextual_dunder_methods=contextual_dunder_methods,
helper_functions=helper_functions,
preexisting_objects=preexisting_objects,
),
)
)
@staticmethod
@ -728,26 +600,18 @@ class Optimizer:
get_run_tmp_file(Path("test_return_values_0.sqlite")).unlink(missing_ok=True)
def instrument_existing_tests(
self,
function_to_optimize: FunctionToOptimize,
function_to_tests: dict[str, list[FunctionCalledInTest]],
self, function_to_optimize: FunctionToOptimize, function_to_tests: dict[str, list[FunctionCalledInTest]]
) -> set[Path]:
relevant_test_files_count = 0
unique_instrumented_test_files = set()
func_qualname = function_to_optimize.qualified_name_with_modules_from_root(
self.args.project_root,
)
func_qualname = function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root)
if func_qualname not in function_to_tests:
logger.info(
f"Did not find any pre-existing tests for '{func_qualname}', will only use generated tests.",
)
logger.info(f"Did not find any pre-existing tests for '{func_qualname}', will only use generated tests.")
else:
test_file_invocation_positions = defaultdict(list)
for tests_in_file in function_to_tests.get(func_qualname):
test_file_invocation_positions[tests_in_file.tests_in_file.test_file].append(
tests_in_file.position,
)
test_file_invocation_positions[tests_in_file.tests_in_file.test_file].append(tests_in_file.position)
for test_file, positions in test_file_invocation_positions.items():
path_obj_test_file = Path(test_file)
relevant_test_files_count += 1
@ -762,7 +626,7 @@ class Optimizer:
continue
new_test_path = Path(
f"{os.path.splitext(test_file)[0]}__perfinstrumented{os.path.splitext(test_file)[1]}",
f"{os.path.splitext(test_file)[0]}__perfinstrumented{os.path.splitext(test_file)[1]}"
)
if injected_test is not None:
with new_test_path.open("w", encoding="utf8") as _f:
@ -778,11 +642,11 @@ class Optimizer:
original_source=None,
original_file_path=Path(test_file),
test_type=TestType.EXISTING_UNIT_TEST,
),
)
)
logger.info(
f"Discovered {relevant_test_files_count} existing unit test file"
f"{'s' if relevant_test_files_count != 1 else ''} for {func_qualname}",
f"{'s' if relevant_test_files_count != 1 else ''} for {func_qualname}"
)
return unique_instrumented_test_files
@ -849,31 +713,18 @@ class Optimizer:
GeneratedTests(
generated_original_test_source=generated_test_source,
instrumented_test_source=instrumented_test_source,
),
)
)
if not tests:
logger.warning(
f"Failed to generate and instrument tests for {function_to_optimize.function_name}",
)
logger.warning(f"Failed to generate and instrument tests for {function_to_optimize.function_name}")
return Failure(f"/!\\ NO TESTS GENERATED for {function_to_optimize.function_name}")
logger.info(f"Generated {len(tests)} tests for {function_to_optimize.function_name}")
generated_tests = GeneratedTestsList(generated_tests=tests)
return Success(
(
generated_tests,
OptimizationSet(
control=candidates,
experiment=candidates_experiment,
),
),
)
return Success((generated_tests, OptimizationSet(control=candidates, experiment=candidates_experiment)))
def establish_original_code_baseline(
self,
function_name: str,
generated_tests_paths: list[Path],
tests_in_file: list[FunctionCalledInTest],
self, function_name: str, generated_tests_paths: list[Path], tests_in_file: list[FunctionCalledInTest]
) -> Result[tuple[OriginalCodeBaseline, list[str]], str]:
assert (test_framework := self.args.test_framework) in ["pytest", "unittest"]
success = True
@ -902,12 +753,10 @@ class Optimizer:
) == TestType.REPLAY_TEST
first_test_types.append(first_test_type)
first_test_functions.append(
relevant_tests_in_file[0].tests_in_file.test_function if is_replay_test else None,
relevant_tests_in_file[0].tests_in_file.test_function if is_replay_test else None
)
if is_replay_test and len(relevant_tests_in_file) > 1:
logger.warning(
f"Multiple tests found for the replay test {test_file}. Should not happen",
)
logger.warning(f"Multiple tests found for the replay test {test_file}. Should not happen")
first_test_functions.extend([None] * len(generated_tests_paths))
if test_framework == "pytest":
@ -935,51 +784,41 @@ class Optimizer:
unittest_results.merge(unittest_loop_results)
initial_loop_unittest_results = TestResults(
test_results=[result for result in unittest_results.test_results if result.loop_index == 1],
test_results=[result for result in unittest_results.test_results if result.loop_index == 1]
)
logger.info(
f"Overall initial loop test results for original code: {TestResults.report_to_string(initial_loop_unittest_results.get_test_pass_fail_report_by_type())}",
f"Overall initial loop test results for original code: {TestResults.report_to_string(initial_loop_unittest_results.get_test_pass_fail_report_by_type())}"
)
existing_test_results = TestResults(
test_results=[
result for result in unittest_results if result.test_type == TestType.EXISTING_UNIT_TEST
],
test_results=[result for result in unittest_results if result.test_type == TestType.EXISTING_UNIT_TEST]
)
generated_test_results = TestResults(
test_results=[
result for result in unittest_results if result.test_type == TestType.GENERATED_REGRESSION
],
test_results=[result for result in unittest_results if result.test_type == TestType.GENERATED_REGRESSION]
)
total_timing = unittest_results.total_passed_runtime()
functions_to_remove = [
result.id.test_function_name
for result in generated_test_results.test_results
if not result.did_pass
result.id.test_function_name for result in generated_test_results.test_results if not result.did_pass
]
if not initial_loop_unittest_results:
logger.warning(
f"Couldn't run any tests for original function {function_name}. SKIPPING OPTIMIZING THIS FUNCTION.",
f"Couldn't run any tests for original function {function_name}. SKIPPING OPTIMIZING THIS FUNCTION."
)
success = False
if total_timing == 0:
logger.warning(
"The overall test runtime of the original function is 0, couldn't run tests.",
)
logger.warning("The overall test runtime of the original function is 0, couldn't run tests.")
success = False
if not total_timing:
logger.warning(
"Failed to run the tests for the original function, skipping optimization",
)
logger.warning("Failed to run the tests for the original function, skipping optimization")
success = False
if not success:
return Failure("Failed to establish a baseline for the original code.")
loop_count = max([int(result.loop_index) for result in unittest_results.test_results])
logger.info(
f"Original code runtime measured over {loop_count} loop{'s' if loop_count > 1 else ''}: {humanize_runtime(total_timing)} per full loop",
f"Original code runtime measured over {loop_count} loop{'s' if loop_count > 1 else ''}: {humanize_runtime(total_timing)} per full loop"
)
logger.debug(f"Total original code runtime (ns): {total_timing}")
return Success(
@ -991,7 +830,7 @@ class Optimizer:
runtime=total_timing,
),
functions_to_remove,
),
)
)
def run_optimized_candidate(
@ -1021,12 +860,8 @@ class Optimizer:
first_test_types = []
first_test_functions = []
get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(
missing_ok=True,
)
get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(
missing_ok=True,
)
get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(missing_ok=True)
get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(missing_ok=True)
for test_file in instrumented_unittests_created_for_function:
relevant_tests_in_file = [
@ -1039,11 +874,11 @@ class Optimizer:
) == TestType.REPLAY_TEST
first_test_types.append(first_test_type)
first_test_functions.append(
relevant_tests_in_file[0].tests_in_file.test_function if is_replay_test else None,
relevant_tests_in_file[0].tests_in_file.test_function if is_replay_test else None
)
if is_replay_test and len(relevant_tests_in_file) > 1:
logger.warning(
f"Multiple tests found for the replay test {test_file.original_file_path}. Should not happen",
f"Multiple tests found for the replay test {test_file.original_file_path}. Should not happen"
)
first_test_functions.extend([None] * len(generated_tests_paths))
if test_framework == "pytest":
@ -1071,19 +906,16 @@ class Optimizer:
candidate_results.merge(candidate_loop_results)
initial_loop_candidate_results = TestResults(
test_results=[result for result in candidate_results.test_results if result.loop_index == 1],
test_results=[result for result in candidate_results.test_results if result.loop_index == 1]
)
logger.info(
f"Overall initial loop test results for candidate code: {TestResults.report_to_string(initial_loop_candidate_results.get_test_pass_fail_report_by_type())}",
f"Overall initial loop test results for candidate code: {TestResults.report_to_string(initial_loop_candidate_results.get_test_pass_fail_report_by_type())}"
)
initial_loop_original_test_results = TestResults(
test_results=[result for result in original_test_results.test_results if result.loop_index == 1],
test_results=[result for result in original_test_results.test_results if result.loop_index == 1]
)
if compare_test_results(
initial_loop_original_test_results,
initial_loop_candidate_results,
):
if compare_test_results(initial_loop_original_test_results, initial_loop_candidate_results):
logger.info("Test results matched!")
equal_results = True
else:
@ -1092,32 +924,22 @@ class Optimizer:
equal_results = False
if (total_candidate_timing := candidate_results.total_passed_runtime()) == 0:
logger.warning(
"The overall test runtime of the optimized function is 0, couldn't run tests.",
)
logger.warning("The overall test runtime of the optimized function is 0, couldn't run tests.")
if best_runtime_until_now is None or total_candidate_timing < best_runtime_until_now:
best_test_results = candidate_results
get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.bin")).unlink(
missing_ok=True,
)
get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.bin")).unlink(missing_ok=True)
get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(
missing_ok=True,
)
get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(missing_ok=True)
if not equal_results:
success = False
if not success:
return Failure("Failed to run the optimized candidate.")
logger.debug(
f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}",
)
logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}")
return Success(
OptimizedCandidateResult(
times_run=times_run,
best_test_runtime=total_candidate_timing,
best_test_results=best_test_results,
),
times_run=times_run, best_test_runtime=total_candidate_timing, best_test_results=best_test_results
)
)
def run_and_parse_tests(
@ -1146,14 +968,14 @@ class Optimizer:
)
except subprocess.TimeoutExpired:
logger.exception(
f'Error running tests in {", ".join(str(f) for f in test_files.test_files)}.\nTimeout Error',
f'Error running tests in {", ".join(str(f) for f in test_files.test_files)}.\nTimeout Error'
)
return TestResults()
if run_result.returncode != 0:
logger.debug(
f'Nonzero return code {run_result.returncode} when running tests in {", ".join([str(f.instrumented_file_path) for f in test_files.test_files])}.\n'
f"stdout: {run_result.stdout}\n"
f"stderr: {run_result.stderr}\n",
f"stderr: {run_result.stderr}\n"
)
return parse_test_results(
test_xml_path=result_file_path,

View file

@ -29,9 +29,7 @@ def existing_tests_source_for(
existing_tests_unique = set()
if test_files:
for test_file in test_files:
existing_tests_unique.add(
"- " + str(Path(test_file.tests_in_file.test_file).relative_to(tests_root)),
)
existing_tests_unique.add("- " + str(Path(test_file.tests_in_file.test_file).relative_to(tests_root)))
return "\n".join(sorted(existing_tests_unique))
@ -52,8 +50,7 @@ def check_create_pr(
relative_path = explanation.file_path.relative_to(git_root_dir()).as_posix()
build_file_changes = {
Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent(
oldContent=original_code[p],
newContent=new_code[p],
oldContent=original_code[p], newContent=new_code[p]
)
for p in original_code
if not is_zero_diff(original_code[p], new_code[p])
@ -85,7 +82,7 @@ def check_create_pr(
else:
logger.error(
f"Optimization was successful, but I failed to suggest changes to PR #{pr_number}."
f" Response from server was: {response.text}",
f" Response from server was: {response.text}"
)
else:
logger.info("Creating a new PR with the optimized code...")
@ -98,8 +95,7 @@ def check_create_pr(
base_branch = get_current_branch()
build_file_changes = {
Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent(
oldContent=original_code[p],
newContent=new_code[p],
oldContent=original_code[p], newContent=new_code[p]
)
for p in original_code
}
@ -128,5 +124,5 @@ def check_create_pr(
else:
logger.error(
f"Optimization was successful, but I failed to create a PR with the optimized code."
f" Response from server was: {response.text}",
f" Response from server was: {response.text}"
)

View file

@ -13,9 +13,7 @@ def performance_gain(*, original_runtime_ns: int, optimized_runtime_ns: int) ->
def speedup_critic(
candidate_result: OptimizedCandidateResult,
original_code_runtime: int,
best_runtime_until_now: int,
candidate_result: OptimizedCandidateResult, original_code_runtime: int, best_runtime_until_now: int
) -> bool:
"""Takes in a correct optimized Test Result and decides if the optimization should actually
be surfaced to the user.
@ -33,8 +31,7 @@ def speedup_critic(
noise_floor = noise_floor * 2 # Increase the noise floor in GitHub Actions mode
perf_gain = performance_gain(
original_runtime_ns=original_code_runtime,
optimized_runtime_ns=candidate_result.best_test_runtime,
original_runtime_ns=original_code_runtime, optimized_runtime_ns=candidate_result.best_test_runtime
)
if (perf_gain > noise_floor) and candidate_result.best_test_runtime < best_runtime_until_now:
return True

View file

@ -18,10 +18,7 @@ def initialize_posthog(enabled: bool) -> None:
return
global _posthog
_posthog = Posthog(
project_api_key="phc_aUO790jHd7z1SXwsYCz8dRApxueplZlZWeDSpKc5hol",
host="https://us.posthog.com",
)
_posthog = Posthog(project_api_key="phc_aUO790jHd7z1SXwsYCz8dRApxueplZlZWeDSpKc5hol", host="https://us.posthog.com")
_posthog.log.setLevel(logging.CRITICAL) # Suppress PostHog logging
ph("cli-telemetry-enabled")
@ -40,10 +37,6 @@ def ph(event: str, properties: Optional[Dict[str, Any]] = None) -> None:
user_id = get_user_id()
if user_id:
_posthog.capture(
distinct_id=user_id,
event=event,
properties=properties,
)
_posthog.capture(distinct_id=user_id, event=event, properties=properties)
else:
logger.debug("Failed to log event to PostHog: User ID could not be retrieved.")

View file

@ -75,7 +75,7 @@ class Tracer:
if sys.getprofile() is not None or sys.gettrace() is not None:
console.print(
"WARNING - Codeflash: Another profiler, debugger or coverage tool is already running. "
"Please disable it before starting the Codeflash Tracer, both can't run. Codeflash Tracer is DISABLED.",
"Please disable it before starting the Codeflash Tracer, both can't run. Codeflash Tracer is DISABLED."
)
self.disable = True
return
@ -91,22 +91,10 @@ class Tracer:
}
self.max_function_count = max_function_count
self.config, found_config_path = parse_config_file(config_file_path)
self.project_root = project_root_from_module_root(
self.config["module_root"],
found_config_path,
)
self.ignored_functions = {
"<listcomp>",
"<genexpr>",
"<dictcomp>",
"<setcomp>",
"<lambda>",
"<module>",
}
self.project_root = project_root_from_module_root(self.config["module_root"], found_config_path)
self.ignored_functions = {"<listcomp>", "<genexpr>", "<dictcomp>", "<setcomp>", "<lambda>", "<module>"}
self.file_being_called_from: str = str(Path(sys._getframe().f_back.f_code.co_filename).name).replace(
".", "_"
)
self.file_being_called_from: str = str(Path(sys._getframe().f_back.f_code.co_filename).name).replace(".", "_")
assert timeout is None or timeout > 0, "Timeout should be greater than 0"
self.timeout = timeout
@ -121,9 +109,7 @@ class Tracer:
self.timer = time.process_time_ns
self.total_tt = 0
self.simulate_call("profiler")
assert (
"test_framework" in self.config
), "Please specify 'test-framework' in pyproject.toml config file"
assert "test_framework" in self.config, "Please specify 'test-framework' in pyproject.toml config file"
self.t = self.timer()
def __enter__(self) -> None:
@ -132,7 +118,7 @@ class Tracer:
if getattr(Tracer, "used_once", False):
console.print(
"Codeflash: Tracer can only be used once per program run. "
"Please only enable the Tracer once. Skipping tracing this section.",
"Please only enable the Tracer once. Skipping tracing this section."
)
self.disable = True
return
@ -148,7 +134,7 @@ class Tracer:
# TODO: Check out if we need to export the function test name as well
cur.execute(
"CREATE TABLE function_calls(type TEXT, function TEXT, classname TEXT, filename TEXT, "
"line_number INTEGER, last_frame_address INTEGER, time_ns INTEGER, args BLOB)",
"line_number INTEGER, last_frame_address INTEGER, time_ns INTEGER, args BLOB)"
)
console.print("Codeflash: Tracing started!")
frame = sys._getframe(0) # Get this frame and simulate a call to it
@ -168,23 +154,13 @@ class Tracer:
cur.execute(
"CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, "
"call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, "
"cumulative_time_ns INTEGER, callers BLOB)",
"cumulative_time_ns INTEGER, callers BLOB)"
)
for func, (cc, nc, tt, ct, callers) in self.stats.items():
remapped_callers = [{"key": k, "value": v} for k, v in callers.items()]
cur.execute(
"INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
Path(func[0]).resolve(),
func[1],
func[2],
func[3],
cc,
nc,
tt,
ct,
json.dumps(remapped_callers),
),
(Path(func[0]).resolve(), func[1], func[2], func[3], cc, nc, tt, ct, json.dumps(remapped_callers)),
)
self.con.commit()
@ -217,16 +193,14 @@ class Tracer:
)
function_path = "_".join(self.functions) if self.functions else self.file_being_called_from
test_file_path = get_test_file_path(
test_dir=self.config["tests_root"],
function_name=function_path,
test_type="replay",
test_dir=self.config["tests_root"], function_name=function_path, test_type="replay"
)
replay_test = isort.code(replay_test)
with open(test_file_path, "w", encoding="utf8") as file:
file.write(replay_test)
console.print(
f"Codeflash: Traced {self.trace_count} function calls successfully and replay test created at - {test_file_path}",
f"Codeflash: Traced {self.trace_count} function calls successfully and replay test created at - {test_file_path}"
)
def tracer_logic(self, frame: FrameType, event: str):
@ -235,9 +209,7 @@ class Tracer:
if self.timeout is not None:
if (time.time() - self.start_time) > self.timeout:
sys.setprofile(None)
console.print(
f"Codeflash: Timeout reached! Stopping tracing at {self.timeout} seconds.",
)
console.print(f"Codeflash: Timeout reached! Stopping tracing at {self.timeout} seconds.")
return
code = frame.f_code
file_name = code.co_filename
@ -290,13 +262,10 @@ class Tracer:
FunctionModules(
function_name=code.co_name,
file_name=file_name,
module_name=module_name_from_file_path(
file_name,
project_root_path=self.project_root,
),
module_name=module_name_from_file_path(file_name, project_root_path=self.project_root),
class_name=class_name,
line_no=code.co_firstlineno,
),
)
)
# TODO: Also check if this function arguments are unique from the values logged earlier
@ -314,18 +283,12 @@ class Tracer:
arguments = dict(arguments.items())
if class_name and code.co_name == "__init__":
del arguments["self"]
local_vars = pickle.dumps(
arguments,
protocol=pickle.HIGHEST_PROTOCOL,
)
local_vars = pickle.dumps(arguments, protocol=pickle.HIGHEST_PROTOCOL)
sys.setrecursionlimit(original_recursion_limit)
except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
# we retry with dill if pickle fails. It's slower but more comprehensive
try:
local_vars = dill.dumps(
arguments,
protocol=dill.HIGHEST_PROTOCOL,
)
local_vars = dill.dumps(arguments, protocol=dill.HIGHEST_PROTOCOL)
sys.setrecursionlimit(original_recursion_limit)
except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError):
@ -334,16 +297,7 @@ class Tracer:
return
cur.execute(
"INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)",
(
event,
code.co_name,
class_name,
file_name,
frame.f_lineno,
frame.f_back.__hash__(),
t_ns,
local_vars,
),
(event, code.co_name, class_name, file_name, frame.f_lineno, frame.f_back.__hash__(), t_ns, local_vars),
)
self.trace_count += 1
self.next_insert -= 1
@ -374,14 +328,7 @@ class Tracer:
if self.cur and frame.f_back is not self.cur[-2]:
rpt, rit, ret, rfn, rframe, rcur = self.cur
if not isinstance(rframe, Tracer.fake_frame):
assert rframe.f_back is frame.f_back, (
"Bad call",
rfn,
rframe,
rframe.f_back,
frame,
frame.f_back,
)
assert rframe.f_back is frame.f_back, ("Bad call", rfn, rframe, rframe.f_back, frame, frame.f_back)
self.trace_dispatch_return(rframe, 0)
assert self.cur is None or frame.f_back is self.cur[-2], ("Bad call", self.cur[-3])
fcode = frame.f_code
@ -528,9 +475,7 @@ class Tracer:
console.print("Failed to get total time from stats")
total_time_ms = total_time / 1e6
raw_stats = re.sub(
r"(function calls?.*)in (\d+)\.\d+ (seconds?)",
rf"\1 in {total_time_ms:.3f} milliseconds",
raw_stats,
r"(function calls?.*)in (\d+)\.\d+ (seconds?)", rf"\1 in {total_time_ms:.3f} milliseconds", raw_stats
)
match_pattern = r"^ *[\d\/]+ +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +"
m = re.findall(match_pattern, raw_stats, re.MULTILINE)
@ -605,7 +550,6 @@ class Tracer:
def main():
import os
from argparse import ArgumentParser
parser = ArgumentParser(allow_abbrev=False)
@ -624,13 +568,7 @@ def main():
type=float,
default=None,
)
parser.add_argument(
"-m",
action="store_true",
dest="module",
help="Trace a library module",
default=False,
)
parser.add_argument("-m", action="store_true", dest="module", help="Trace a library module", default=False)
parser.add_argument(
"--codeflash-config",
help="Optional path to the project's pyproject.toml file "
@ -655,10 +593,7 @@ def main():
import runpy
code = "run_module(modname, run_name='__main__')"
globs = {
"run_module": runpy.run_module,
"modname": unknown_args[0],
}
globs = {"run_module": runpy.run_module, "modname": unknown_args[0]}
else:
progname = unknown_args[0]
sys.path.insert(0, str(Path(progname).parent))

View file

@ -39,9 +39,7 @@ class ProfileStats(pstats.Stats):
call_count_nonrecursive,
num_callers,
total_time_ns / time_conversion_factor if time_conversion_factor != 1 else total_time_ns,
cumulative_time_ns / time_conversion_factor
if time_conversion_factor != 1
else cumulative_time_ns,
cumulative_time_ns / time_conversion_factor if time_conversion_factor != 1 else cumulative_time_ns,
unmapped_callers,
)
@ -58,9 +56,7 @@ class ProfileStats(pstats.Stats):
print(indent, self.total_calls, "function calls", end=" ", file=self.stream)
if self.total_calls != self.prim_calls:
print("(%d primitive calls)" % self.prim_calls, end=" ", file=self.stream)
time_unit = {"ns": "nanoseconds", "us": "microseconds", "ms": "milliseconds", "s": "seconds"}[
self.time_unit
]
time_unit = {"ns": "nanoseconds", "us": "microseconds", "ms": "milliseconds", "s": "seconds"}[self.time_unit]
print("in %.3f %s" % (self.total_tt, time_unit), file=self.stream)
print(file=self.stream)
width, list = self.get_print_list(amount)

View file

@ -4,19 +4,12 @@ import sqlite3
import textwrap
from typing import Any, Generator, List, Optional
from codeflash.discovery.functions_to_optimize import (
FunctionProperties,
inspect_top_level_functions_or_methods,
)
from codeflash.discovery.functions_to_optimize import FunctionProperties, inspect_top_level_functions_or_methods
from codeflash.tracing.tracing_utils import FunctionModules
def get_next_arg_and_return(
trace_file: str,
function_name: str,
file_name: str,
class_name: Optional[str] = None,
num_to_get: int = 25,
trace_file: str, function_name: str, file_name: str, class_name: Optional[str] = None, num_to_get: int = 25
) -> Generator[Any]:
db = sqlite3.connect(trace_file)
cur = db.cursor()
@ -45,10 +38,7 @@ def get_function_alias(module: str, function_name: str) -> str:
def create_trace_replay_test(
trace_file: str,
functions: List[FunctionModules],
test_framework: str = "pytest",
max_run_count=100,
trace_file: str, functions: List[FunctionModules], test_framework: str = "pytest", max_run_count=100
) -> str:
assert test_framework in ["pytest", "unittest"]
@ -74,21 +64,19 @@ from codeflash.tracing.replay_test import get_next_arg_and_return
continue
if function_property.is_staticmethod:
function_imports.append(
f"from {function.module_name} import {function_property.staticmethod_class_name} as {get_function_alias(function.module_name, function_property.staticmethod_class_name)}",
f"from {function.module_name} import {function_property.staticmethod_class_name} as {get_function_alias(function.module_name, function_property.staticmethod_class_name)}"
)
elif function.class_name:
function_imports.append(
f"from {function.module_name} import {function.class_name} as {get_function_alias(function.module_name, function.class_name)}",
f"from {function.module_name} import {function.class_name} as {get_function_alias(function.module_name, function.class_name)}"
)
else:
function_imports.append(
f"from {function.module_name} import {function.function_name} as {get_function_alias(function.module_name, function.function_name)}",
f"from {function.module_name} import {function.function_name} as {get_function_alias(function.module_name, function.function_name)}"
)
imports += "\n".join(function_imports)
functions_to_optimize = [
function.function_name for function in functions if function.function_name != "__init__"
]
functions_to_optimize = [function.function_name for function in functions if function.function_name != "__init__"]
metadata = f"""functions = {functions_to_optimize}
trace_file_path = r"{trace_file}"
""" # trace_file_path path is parsed with regex later, format is important
@ -97,21 +85,21 @@ trace_file_path = r"{trace_file}"
for arg_val_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_name=r"{file_name}", num_to_get={max_run_count}):
args = pickle.loads(arg_val_pkl)
ret = {function_name}({args})
""",
"""
)
test_class_method_body = textwrap.dedent(
"""\
for arg_val_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_name=r"{file_name}", class_name="{class_name}", num_to_get={max_run_count}):
args = pickle.loads(arg_val_pkl){filter_variables}
ret = {class_name_alias}{method_name}(**args)
""",
"""
)
test_class_staticmethod_body = textwrap.dedent(
"""\
for arg_val_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_name=r"{file_name}", num_to_get={max_run_count}):
args = pickle.loads(arg_val_pkl){filter_variables}
ret = {class_name_alias}{method_name}(**args)
""",
"""
)
if test_framework == "unittest":
self = "self"
@ -135,8 +123,7 @@ trace_file_path = r"{trace_file}"
elif func_property.is_staticmethod:
class_name_alias = get_function_alias(func.module_name, func_property.staticmethod_class_name)
alias = get_function_alias(
func.module_name,
func_property.staticmethod_class_name + "_" + func.function_name,
func.module_name, func_property.staticmethod_class_name + "_" + func.function_name
)
method_name = "." + func.function_name if func.function_name != "__init__" else ""
test_body = test_class_staticmethod_body.format(
@ -149,10 +136,7 @@ trace_file_path = r"{trace_file}"
)
else:
class_name_alias = get_function_alias(func.module_name, func.class_name)
alias = get_function_alias(
func.module_name,
func.class_name + "_" + func.function_name,
)
alias = get_function_alias(func.module_name, func.class_name + "_" + func.function_name)
if func_property.is_classmethod:
filter_variables = '\n args.pop("cls", None)'
@ -170,10 +154,7 @@ trace_file_path = r"{trace_file}"
max_run_count=max_run_count,
filter_variables=filter_variables,
)
formatted_test_body = textwrap.indent(
test_body,
" " if test_framework == "unittest" else " ",
)
formatted_test_body = textwrap.indent(test_body, " " if test_framework == "unittest" else " ")
test_template += " " if test_framework == "unittest" else ""
test_template += f"def test_{alias}({self}):\n{formatted_test_body}\n"

View file

@ -129,21 +129,11 @@ def comparator(orig: Any, new: Any) -> bool:
return (orig != new).nnz == 0
if HAS_PANDAS and isinstance(
orig,
(
pandas.DataFrame,
pandas.Series,
pandas.Index,
pandas.Categorical,
pandas.arrays.SparseArray,
),
orig, (pandas.DataFrame, pandas.Series, pandas.Index, pandas.Categorical, pandas.arrays.SparseArray)
):
return orig.equals(new)
if HAS_PANDAS and isinstance(
orig,
(pandas.CategoricalDtype, pandas.Interval, pandas.Period),
):
if HAS_PANDAS and isinstance(orig, (pandas.CategoricalDtype, pandas.Interval, pandas.Period)):
return orig == new
# This should be at the end of all numpy checking
@ -173,10 +163,7 @@ def comparator(orig: Any, new: Any) -> bool:
):
return orig == new
if isinstance(
orig,
(datetime.datetime, datetime.date, datetime.timedelta, datetime.time, datetime.timezone),
):
if isinstance(orig, (datetime.datetime, datetime.date, datetime.timedelta, datetime.time, datetime.timezone)):
return orig == new
# If the object passed has a user defined __eq__ method, use that

View file

@ -11,17 +11,9 @@ import dill as pickle
from junitparser.xunit2 import JUnitXml
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import (
file_path_from_module_name,
get_run_tmp_file,
module_name_from_file_path,
)
from codeflash.code_utils.code_utils import file_path_from_module_name, get_run_tmp_file, module_name_from_file_path
from codeflash.discovery.discover_unit_tests import discover_parameters_unittest
from codeflash.verification.test_results import (
FunctionTestInvocation,
InvocationId,
TestResults,
)
from codeflash.verification.test_results import FunctionTestInvocation, InvocationId, TestResults
if TYPE_CHECKING:
import subprocess
@ -30,11 +22,7 @@ if TYPE_CHECKING:
from codeflash.verification.verification_utils import TestConfig
def parse_test_return_values_bin(
file_location: Path,
test_files: TestFiles,
test_config: TestConfig,
) -> TestResults:
def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults:
test_results = TestResults()
if not file_location.exists():
logger.warning(f"No test results for {file_location} found.")
@ -66,8 +54,7 @@ def parse_test_return_values_bin(
invocation_id_object = InvocationId.from_str_id(encoded_test_name, invocation_id)
test_file_path = file_path_from_module_name(
invocation_id_object.test_module_path,
test_config.tests_project_rootdir,
invocation_id_object.test_module_path, test_config.tests_project_rootdir
)
test_type = test_files.get_test_type_by_instrumented_file_path(test_file_path)
@ -85,16 +72,12 @@ def parse_test_return_values_bin(
test_type=test_type,
return_value=test_pickle,
timed_out=False,
),
)
)
return test_results
def parse_sqlite_test_results(
sqlite_file_path: Path,
test_files: TestFiles,
test_config: TestConfig,
) -> TestResults:
def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults:
test_results = TestResults()
if not sqlite_file_path.exists():
logger.warning(f"No test results for {sqlite_file_path} found.")
@ -104,7 +87,7 @@ def parse_sqlite_test_results(
cur = db.cursor()
data = cur.execute(
"SELECT test_module_path, test_class_name, test_function_name, "
"function_getting_tested, loop_index, iteration_id, runtime, return_value FROM test_results",
"function_getting_tested, loop_index, iteration_id, runtime, return_value FROM test_results"
).fetchall()
finally:
db.close()
@ -132,7 +115,7 @@ def parse_sqlite_test_results(
test_type=test_type,
return_value=pickle.loads(val[7]) if loop_index == 1 else None,
timed_out=False,
),
)
)
except Exception:
logger.exception("Failed to load pickle file.")
@ -155,14 +138,10 @@ def parse_test_xml(
try:
xml = JUnitXml.fromfile(str(test_xml_file_path))
except Exception as e:
logger.warning(
f"Failed to parse {test_xml_file_path} as JUnitXml. Exception: {e}",
)
logger.warning(f"Failed to parse {test_xml_file_path} as JUnitXml. Exception: {e}")
return test_results
base_dir = (
test_config.tests_project_rootdir
if test_config.test_framework == "pytest"
else test_config.project_root_path
test_config.tests_project_rootdir if test_config.test_framework == "pytest" else test_config.project_root_path
)
for suite in xml:
for testcase in suite:
@ -178,12 +157,10 @@ def parse_test_xml(
logger.info("Test failed to load, skipping it.")
if run_result is not None:
if isinstance(run_result.stdout, str) and isinstance(run_result.stderr, str):
logger.info(
f"Test log - STDOUT : {run_result.stdout} \n STDERR : {run_result.stderr}",
)
logger.info(f"Test log - STDOUT : {run_result.stdout} \n STDERR : {run_result.stderr}")
else:
logger.info(
f"Test log - STDOUT : {run_result.stdout.decode()} \n STDERR : {run_result.stderr.decode()}",
f"Test log - STDOUT : {run_result.stdout.decode()} \n STDERR : {run_result.stderr.decode()}"
)
return test_results
@ -192,15 +169,9 @@ def parse_test_xml(
if test_file_name is None:
if test_class_path:
# TODO : This might not be true if the test is organized under a class
test_file_path = file_path_from_module_name(
test_class_path,
base_dir,
)
test_file_path = file_path_from_module_name(test_class_path, base_dir)
else:
test_file_path = file_path_from_module_name(
test_function,
base_dir,
)
test_file_path = file_path_from_module_name(test_function, base_dir)
else:
# TODO: not sure which root path fits better here
test_file_path = base_dir / test_file_name
@ -213,9 +184,7 @@ def parse_test_xml(
result = testcase.is_passed # TODO: See for the cases of ERROR and SKIPPED
test_class = None
if class_name is not None and class_name.startswith(test_module_path):
test_class = class_name[
len(test_module_path) + 1 :
] # +1 for the dot, gets Unittest class name
test_class = class_name[len(test_module_path) + 1 :] # +1 for the dot, gets Unittest class name
loop_index = 1
if test_function is None:
@ -226,26 +195,19 @@ def parse_test_xml(
if test_config.test_framework == "pytest":
loop_index = int(testcase.name.split("[ ", 1)[1][:-2]) if "[" in testcase.name else 1
if len(testcase.result) > 1:
logger.warning(
f"!!!!!Multiple results for {testcase.name} in {test_xml_file_path}!!!",
)
logger.warning(f"!!!!!Multiple results for {testcase.name} in {test_xml_file_path}!!!")
if len(testcase.result) == 1:
message = testcase.result[0].message.lower()
if "failed: timeout >" in message:
timed_out = True
else:
if len(testcase.result) > 1:
logger.warning(
f"!!!!!Multiple results for {testcase.name} in {test_xml_file_path}!!!",
)
logger.warning(f"!!!!!Multiple results for {testcase.name} in {test_xml_file_path}!!!")
if len(testcase.result) == 1:
message = testcase.result[0].message.lower()
if "timed out" in message:
timed_out = True
matches = re.findall(
r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!",
testcase.system_out or "",
)
matches = re.findall(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!", testcase.system_out or "")
if not matches or not len(matches):
test_results.add(
FunctionTestInvocation(
@ -264,7 +226,7 @@ def parse_test_xml(
test_type=test_type,
return_value=None,
timed_out=timed_out,
),
)
)
else:
@ -286,12 +248,12 @@ def parse_test_xml(
test_type=test_type,
return_value=None,
timed_out=timed_out,
),
)
)
if not test_results:
logger.info(
f"Tests '{[test_file.original_file_path for test_file in test_files.test_files]}' failed to run, skipping",
f"Tests '{[test_file.original_file_path for test_file in test_files.test_files]}' failed to run, skipping"
)
if run_result is not None:
stdout, stderr = "", ""
@ -300,16 +262,12 @@ def parse_test_xml(
stderr = run_result.stderr.decode()
except AttributeError:
stdout = run_result.stderr
logger.debug(
f"Test log - STDOUT : {stdout} \n STDERR : {stderr}",
)
logger.debug(f"Test log - STDOUT : {stdout} \n STDERR : {stderr}")
return test_results
def merge_test_results(
xml_test_results: TestResults,
bin_test_results: TestResults,
test_framework: str,
xml_test_results: TestResults, bin_test_results: TestResults, test_framework: str
) -> TestResults:
merged_test_results = TestResults()
@ -319,18 +277,14 @@ def merge_test_results(
# This is done to match the right iteration_id which might not be available in the xml
for result in xml_test_results:
if test_framework == "pytest":
if (
result.id.test_function_name.endswith("]") and "[" in result.id.test_function_name
): # parameterized test
if result.id.test_function_name.endswith("]") and "[" in result.id.test_function_name: # parameterized test
test_function_name = result.id.test_function_name[: result.id.test_function_name.index("[")]
else:
test_function_name = result.id.test_function_name
if test_framework == "unittest":
test_function_name = result.id.test_function_name
is_parameterized, new_test_function_name, _ = discover_parameters_unittest(
test_function_name,
)
is_parameterized, new_test_function_name, _ = discover_parameters_unittest(test_function_name)
if is_parameterized: # handle parameterized test
test_function_name = new_test_function_name
@ -378,7 +332,7 @@ def merge_test_results(
test_type=xml_result.test_type,
return_value=result_bin.return_value,
timed_out=xml_result.timed_out,
),
)
)
elif xml_results.test_results[0].id.iteration_id is not None:
# This means that we have multiple iterations of the same test function
@ -404,7 +358,7 @@ def merge_test_results(
timed_out=xml_result.timed_out
if bin_result.runtime is None
else False, # If runtime was measured in the bin file, then the testcase did not time out
),
)
)
else:
# Should happen only if the xml did not have any test invocation id info
@ -427,7 +381,7 @@ def merge_test_results(
test_type=bin_result.test_type,
return_value=bin_result.return_value,
timed_out=xml_result.timed_out, # only the xml gets the timed_out flag
),
)
)
return merged_test_results
@ -441,10 +395,7 @@ def parse_test_results(
run_result: subprocess.CompletedProcess | None = None,
) -> TestResults:
test_results_xml = parse_test_xml(
test_xml_path,
test_files=test_files,
test_config=test_config,
run_result=run_result,
test_xml_path, test_files=test_files, test_config=test_config, run_result=run_result
)
try:
@ -463,9 +414,7 @@ def parse_test_results(
sql_results_file = get_run_tmp_file(Path(f"test_return_values_{optimization_iteration}.sqlite"))
if sql_results_file.exists():
test_results_sqlite_file = parse_sqlite_test_results(
sqlite_file_path=sql_results_file,
test_files=test_files,
test_config=test_config,
sqlite_file_path=sql_results_file, test_files=test_files, test_config=test_config
)
test_results_bin_file.merge(test_results_sqlite_file)
except AttributeError as e:
@ -475,8 +424,4 @@ def parse_test_results(
get_run_tmp_file(Path(f"test_return_values_{optimization_iteration}.sqlite")).unlink(missing_ok=True)
return merge_test_results(
test_results_xml,
test_results_bin_file,
test_config.test_framework,
)
return merge_test_results(test_results_xml, test_results_bin_file, test_config.test_framework)

View file

@ -44,11 +44,7 @@ def pytest_addoption(parser: Parser) -> None:
help="The amount of time to wait between each test loop.",
)
pytest_loops.addoption(
"--codeflash_hours",
action="store",
default=0,
type=float,
help="The number of hours to loop the tests for.",
"--codeflash_hours", action="store", default=0, type=float, help="The number of hours to loop the tests for."
)
pytest_loops.addoption(
"--codeflash_minutes",
@ -66,11 +62,7 @@ def pytest_addoption(parser: Parser) -> None:
)
pytest_loops.addoption(
"--codeflash_loops",
action="store",
default=1,
type=int,
help="The number of times to loop each test",
"--codeflash_loops", action="store", default=1, type=int, help="The number of times to loop each test"
)
pytest_loops.addoption(
@ -101,10 +93,7 @@ def pytest_addoption(parser: Parser) -> None:
@pytest.hookimpl(trylast=True)
def pytest_configure(config: Config) -> None:
config.addinivalue_line(
"markers",
"loops(n): run the given test function `n` times.",
)
config.addinivalue_line("markers", "loops(n): run the given test function `n` times.")
config.pluginmanager.register(PyTest_Loops(config), PyTest_Loops.name)
@ -122,8 +111,7 @@ class PyTest_Loops:
"""Reimplement the test loop but loop for the user defined amount of time."""
if session.testsfailed and not session.config.option.continue_on_collection_errors:
raise session.Interrupted(
"%d error%s during collection"
% (session.testsfailed, "s" if session.testsfailed != 1 else ""),
"%d error%s during collection" % (session.testsfailed, "s" if session.testsfailed != 1 else "")
)
if session.config.option.collectonly:
@ -188,9 +176,7 @@ class PyTest_Loops:
seconds = session.config.option.codeflash_seconds
total_time = hours_in_seconds + minutes_in_seconds + seconds
if total_time < SHORTEST_AMOUNT_OF_TIME:
raise InvalidTimeParameterError(
f"Total time cannot be less than: {SHORTEST_AMOUNT_OF_TIME}!",
)
raise InvalidTimeParameterError(f"Total time cannot be less than: {SHORTEST_AMOUNT_OF_TIME}!")
return total_time
def _timed_out(self, session: Session, start_time: float, count: int) -> bool:
@ -233,7 +219,7 @@ class PyTest_Loops:
else:
raise UnexpectedError(
"This call couldn't work with pytest-loops. "
"Please consider raising an issue with your usage.",
"Please consider raising an issue with your usage."
)
return count
@ -257,9 +243,5 @@ class PyTest_Loops:
scope = metafunc.config.option.codeflash_loops_scope
metafunc.parametrize(
"__pytest_loop_step_number",
range(count),
indirect=True,
ids=make_progress_id,
scope=scope,
"__pytest_loop_step_number", range(count), indirect=True, ids=make_progress_id, scope=scope
)

View file

@ -85,10 +85,7 @@ class TestResults(BaseModel):
def merge(self, other: TestResults) -> None:
self.test_results.extend(other.test_results)
def get_by_id(
self,
invocation_id: InvocationId,
) -> FunctionTestInvocation | None:
def get_by_id(self, invocation_id: InvocationId) -> FunctionTestInvocation | None:
return next((r for r in self.test_results if r.id == invocation_id), None)
def get_all_ids(self) -> set[InvocationId]:
@ -129,7 +126,7 @@ class TestResults(BaseModel):
[
f"{test_type.to_name()}- (Passed: {report[test_type]['passed']}, Failed: {report[test_type]['failed']})"
for test_type in TestType
],
]
)
def total_passed_runtime(self) -> int:
@ -141,14 +138,14 @@ class TestResults(BaseModel):
for result in self.test_results:
if result.did_pass and not result.runtime:
logger.debug(
f"Ignoring test case that passed but had no runtime -> {result.id}, Loop # {result.loop_index}",
f"Ignoring test case that passed but had no runtime -> {result.id}, Loop # {result.loop_index}"
)
usable_results = [result for result in self.test_results if result.did_pass and result.runtime]
return sum(
[
min([result.runtime for result in usable_results if result.id == invocation_id])
for invocation_id in {result.id for result in usable_results}
],
]
)
def __iter__(self) -> Iterator[FunctionTestInvocation]:
@ -192,10 +189,7 @@ class TestResults(BaseModel):
or test_result.runtime != other_test_result.runtime
or test_result.test_framework != other_test_result.test_framework
or test_result.test_type != other_test_result.test_type
or not comparator(
test_result.return_value,
other_test_result.return_value,
)
or not comparator(test_result.return_value, other_test_result.return_value)
):
sys.setrecursionlimit(original_recursion_limit)
return False

View file

@ -77,7 +77,5 @@ def run_tests(
check=False,
)
else:
raise ValueError(
"Invalid test framework -- I only support Pytest and Unittest currently.",
)
raise ValueError("Invalid test framework -- I only support Pytest and Unittest currently.")
return result_file_path, results

View file

@ -4,12 +4,7 @@ from pathlib import Path
from pydantic.dataclasses import dataclass
def get_test_file_path(
test_dir: Path,
function_name: str,
iteration: int = 0,
test_type: str = "unit",
) -> Path:
def get_test_file_path(test_dir: Path, function_name: str, iteration: int = 0, test_type: str = "unit") -> Path:
assert test_type in ["unit", "inspired", "replay"]
function_name = function_name.replace(".", "_")
path = test_dir / f"test_{function_name}__{test_type}_test_{iteration}.py"

View file

@ -40,10 +40,7 @@ def generate_tests(
instrumented_test_source = module.CACHED_INSTRUMENTED_TESTS
temp_run_dir = get_run_tmp_file(Path())
path = str(temp_run_dir).replace("\\", "\\\\") # Escape backslash for windows paths
instrumented_test_source = instrumented_test_source.replace(
"{codeflash_run_tmp_dir_client_side}",
path,
)
instrumented_test_source = instrumented_test_source.replace("{codeflash_run_tmp_dir_client_side}", path)
logger.info(f"Using cached tests from {module_path}.CACHED_TESTS")
else:
test_file_path = get_test_file_path(test_cfg.tests_root, function_to_optimize.function_name, 0)
@ -63,14 +60,9 @@ def generate_tests(
generated_test_source, instrumented_test_source = response
temp_run_dir = get_run_tmp_file(Path())
path = str(temp_run_dir).replace("\\", "\\\\")
instrumented_test_source = instrumented_test_source.replace(
"{codeflash_run_tmp_dir_client_side}",
path,
)
instrumented_test_source = instrumented_test_source.replace("{codeflash_run_tmp_dir_client_side}", path)
else:
logger.warning(
f"Failed to generate and instrument tests for {function_to_optimize.function_name}",
)
logger.warning(f"Failed to generate and instrument tests for {function_to_optimize.function_name}")
return None
return generated_test_source, instrumented_test_source

View file

@ -15,11 +15,7 @@ def load_data(experiment_id: str, database_uri: str = os.environ.get("DATABASE_U
WHERE (trace_id LIKE %s OR trace_id LIKE %s)
AND experiment_metadata->>'id' = %s
"""
return pd.read_sql_query(
query,
connection,
params=("%EXP0", "%EXP1", experiment_id),
)
return pd.read_sql_query(query, connection, params=("%EXP0", "%EXP1", experiment_id))
def process_column_pairs(df: DataFrame, column_name: str) -> DataFrame:
@ -44,22 +40,15 @@ def process_column_pairs(df: DataFrame, column_name: str) -> DataFrame:
def calculate_validity(df: DataFrame, perf_threshold: float = 0.05) -> Dict[str, Any]:
# Calculate the percentage of valid PRs given that the original function run succeeded
successful_runs = df[(~df["original_runtime"].isna())]
successful_runs_above_thres = successful_runs[
successful_runs["best_correct_speedup_ratio"] >= perf_threshold
]
successful_runs_above_thres = successful_runs[successful_runs["best_correct_speedup_ratio"] >= perf_threshold]
valid_prs = len(successful_runs_above_thres)
percent_valid_pr = (valid_prs / len(df)) * 100
# Calculate the percentage of valid candidates generated given that original function run succeeded
valid_candidates = successful_runs[successful_runs["is_correct"].apply(lambda x: any(x.values()))]
percent_valid_candidates = (
len(valid_candidates) / len(successful_runs) * 100 if len(successful_runs) > 0 else 0
)
percent_valid_candidates = len(valid_candidates) / len(successful_runs) * 100 if len(successful_runs) > 0 else 0
return {
"percent_valid_pr": percent_valid_pr,
"percent_valid_candidates": percent_valid_candidates,
}
return {"percent_valid_pr": percent_valid_pr, "percent_valid_candidates": percent_valid_candidates}
def calculate_performance(df: DataFrame, perf_threshold: float = 0.05) -> Dict[str, Any]:
@ -81,9 +70,7 @@ def calculate_performance(df: DataFrame, perf_threshold: float = 0.05) -> Dict[s
def calculate_time_saved_for_row(row: pd.Series):
if row["optimized_runtime"] is not None and row["is_correct"] is not None:
correct_runtimes = [
runtime
for opt_id, runtime in row["optimized_runtime"].items()
if row["is_correct"].get(opt_id)
runtime for opt_id, runtime in row["optimized_runtime"].items() if row["is_correct"].get(opt_id)
]
else:
correct_runtimes = []
@ -93,23 +80,11 @@ def calculate_performance(df: DataFrame, perf_threshold: float = 0.05) -> Dict[s
# (3) The average time saved in a PR given that a valid candidate was found above the perf threshold.
pr_time_saved = (
valid_candidates_above_thres.apply(
lambda row: calculate_time_saved_for_row(row),
axis=1,
)
.dropna()
.mean()
valid_candidates_above_thres.apply(lambda row: calculate_time_saved_for_row(row), axis=1).dropna().mean()
)
# (4) Calculate the mean average time saved for all the valid candidates
all_candidates_time_saved = (
df.apply(
lambda row: calculate_time_saved_for_row(row),
axis=1,
)
.dropna()
.mean()
)
all_candidates_time_saved = df.apply(lambda row: calculate_time_saved_for_row(row), axis=1).dropna().mean()
return {
"average_percentage_gain_pr": average_percentage_gain_pr,
@ -132,7 +107,7 @@ def calculate_coverage(df: DataFrame) -> Dict[str, Any]:
return successful_optimization_runs / total_runs * 100 if total_runs > 0 else 0.0
df["percent_successful_optimization_runs"] = df["optimized_runtime"].apply(
calculate_percent_optimization_successful_runs,
calculate_percent_optimization_successful_runs
)
total_optimizations = sum(len(runs) for runs in successful_runs["optimized_runtime"] if runs is not None)
@ -158,15 +133,9 @@ def calculate_coverage(df: DataFrame) -> Dict[str, Any]:
def paired_comparison_coverage(
df: DataFrame,
model_a_suffix: str = "EXP0",
model_b_suffix: str = "EXP1",
df: DataFrame, model_a_suffix: str = "EXP0", model_b_suffix: str = "EXP1"
) -> Dict[str, Any]:
paired_coverage_results = {
"model_a_more_successful": 0,
"equal_successful": 0,
"model_b_more_successful": 0,
}
paired_coverage_results = {"model_a_more_successful": 0, "equal_successful": 0, "model_b_more_successful": 0}
grouped = df.groupby(df["trace_id"].str[:-4])
for _, group in grouped:
if len(group) == 2:
@ -176,17 +145,13 @@ def paired_comparison_coverage(
model_a_success_count = 0
else:
model_a_success_count = sum(
1
for runtime in model_a_row["optimized_runtime"].values[0].values()
if runtime is not None
1 for runtime in model_a_row["optimized_runtime"].values[0].values() if runtime is not None
)
if model_b_row["optimized_runtime"].values[0] is None:
model_b_success_count = 0
else:
model_b_success_count = sum(
1
for runtime in model_b_row["optimized_runtime"].values[0].values()
if runtime is not None
1 for runtime in model_b_row["optimized_runtime"].values[0].values() if runtime is not None
)
if model_a_success_count > model_b_success_count:
@ -201,16 +166,10 @@ def paired_comparison_coverage(
def paired_comparison_validity(
df: DataFrame,
model_a_suffix: str = "EXP0",
model_b_suffix: str = "EXP1",
df: DataFrame, model_a_suffix: str = "EXP0", model_b_suffix: str = "EXP1"
) -> Dict[str, Any]:
# Paired - Calculate the percentage of runs where model A generated more, equal, or less valid candidates than model B
paired_validity_results = {
"model_a_more_valid": 0,
"equal_valid": 0,
"model_b_more_valid": 0,
}
paired_validity_results = {"model_a_more_valid": 0, "equal_valid": 0, "model_b_more_valid": 0}
grouped = df.groupby(df["trace_id"].str[:-4])
for _, group in grouped:
if len(group) == 2:
@ -235,15 +194,9 @@ def paired_comparison_validity(
def paired_comparison_performance(
df: DataFrame,
model_a_suffix: str = "EXP0",
model_b_suffix: str = "EXP1",
df: DataFrame, model_a_suffix: str = "EXP0", model_b_suffix: str = "EXP1"
) -> Dict[str, Any]:
paired_results = {
"model_a_better": 0,
"equal": 0,
"model_b_better": 0,
}
paired_results = {"model_a_better": 0, "equal": 0, "model_b_better": 0}
# Group by the trace_id without the suffix
grouped = df.groupby(df["trace_id"].str[:-4])
@ -274,8 +227,7 @@ def paired_comparison_performance(
def augment_with_best_correct_speedup_ratio(df: DataFrame) -> DataFrame:
# Extract the best speedup ratio from the speedup_ratio dictionary, accounting for empty dictionaries
def get_best_correct_speedup_ratio(
speedup_ratios: Dict[str, float],
is_correct: Dict[str, bool],
speedup_ratios: Dict[str, float], is_correct: Dict[str, bool]
) -> Optional[float]:
correct_speedup_ratios = (
{uuid: ratio for uuid, ratio in speedup_ratios.items() if is_correct.get(uuid)}
@ -288,10 +240,7 @@ def augment_with_best_correct_speedup_ratio(df: DataFrame) -> DataFrame:
return None
df["best_correct_speedup_ratio"] = df.apply(
lambda row: get_best_correct_speedup_ratio(
row["speedup_ratio"],
row["is_correct"],
)
lambda row: get_best_correct_speedup_ratio(row["speedup_ratio"], row["is_correct"])
if row["speedup_ratio"] is not None
else None,
axis=1,
@ -324,17 +273,9 @@ def main() -> None:
# Combine metrics into a DataFrame
metrics_df = pd.DataFrame(
{
"EXP0": {
**exp0_performance_metrics,
**exp0_validity_metrics,
**exp0_coverage_metrics,
},
"EXP1": {
**exp1_performance_metrics,
**exp1_validity_metrics,
**exp1_coverage_metrics,
},
},
"EXP0": {**exp0_performance_metrics, **exp0_validity_metrics, **exp0_coverage_metrics},
"EXP1": {**exp1_performance_metrics, **exp1_validity_metrics, **exp1_coverage_metrics},
}
) # Transpose to have experiments as rows and metrics as columns
# Output the combined metrics DataFrame

View file

@ -45,10 +45,7 @@ def main(trace_id: str) -> None:
write_to_file("original_code.py", original_code)
# Write each optimization candidate to its own file
for idx, (opt_id, optimization) in enumerate(
extract_json_values(optimizations_post).items(),
start=1,
):
for idx, (opt_id, optimization) in enumerate(extract_json_values(optimizations_post).items(), start=1):
filename = f"optimization_candidate_{idx}.py"
explanation = explanations_post.get(opt_id, "")
speedup = speedup_ratio.get(opt_id)
@ -61,21 +58,13 @@ def main(trace_id: str) -> None:
valid_speedup_values = [v for v in speedup_ratio.values() if v is not None]
best_speedup = max(valid_speedup_values, default=None) if valid_speedup_values else None
if best_speedup is not None:
best_optimization_id = next(
(id for id, speedup in speedup_ratio.items() if speedup == best_speedup),
None,
)
best_optimization_id = next((id for id, speedup in speedup_ratio.items() if speedup == best_speedup), None)
best_optimization = optimizations_post.get(best_optimization_id)
best_explanation = explanations_post.get(best_optimization_id, "")
best_speedup_comment = f"Speedup: {best_speedup}"
best_content_with_comment = (
f'"""{best_explanation}\n\n{best_speedup_comment}"""\n\n{best_optimization}'
)
best_content_with_comment = f'"""{best_explanation}\n\n{best_speedup_comment}"""\n\n{best_optimization}'
if best_optimization and best_explanation is not None:
write_to_file(
"best_optimization_candidate.py",
best_content_with_comment,
)
write_to_file("best_optimization_candidate.py", best_content_with_comment)
else:
print("No speedup ratio found")

View file

@ -1,11 +1,5 @@
mock_dataframe = {
"original_runtime": [
1.0,
2.0,
3.0,
4.0,
8333.00000,
],
"original_runtime": [1.0, 2.0, 3.0, 4.0, 8333.00000],
"optimized_runtime": [
{
"0a631960-6c33-4de4-8d38-b31e888918c7": 59209,
@ -51,7 +45,7 @@ mock_dataframe = {
"573c9f22-ae68-4e62-929d-b70d22da0ec0": -0.10714668381013608,
"6fa4faaf-5d67-4ec6-bf55-5a84dfa23a11": 0.1695438596491228,
"a06d1ad8-cc41-4b3b-a1a6-4566697d7b10": 0.16285235835891712,
},
}
],
"best_correct_speedup_ratio": [0.02112, None, 1.81043, None],
"is_correct": [

View file

@ -398,22 +398,22 @@ pie4perf_sample_dataframe_dict = {
},
"generated_test": {
0: [
'# imports\nimport pytest # used for our unit tests\nfrom pie_test_set.p02847 import problem_p02847\n\n# unit tests\n\n# Test normal cases with valid inputs for each day of the week\n@pytest.mark.parametrize("input_data, expected", [\n ("\'MON\'", 6),\n ("\'TUE\'", 5),\n ("\'WED\'", 4),\n ("\'THU\'", 3),\n ("\'FRI\'", 2),\n ("\'SAT\'", 1),\n ("\'SUN\'", 7),\n])\ndef test_normal_cases(input_data, expected):\n assert problem_p02847(input_data) == expected\n\n# Test edge cases with invalid day names\n@pytest.mark.parametrize("input_data", [\n "\'FUNDAY\'",\n "\'MOON\'",\n "123",\n "[\'MON\']",\n "None",\n])\ndef test_edge_cases_with_invalid_day_names(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test misuse cases with potentially malicious code\n@pytest.mark.parametrize("input_data", [\n "__import__(\'os\').system(\'echo hello\')",\n "exit()",\n])\ndef test_misuse_cases_with_malicious_code(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with lowercase and mixed case inputs\n@pytest.mark.parametrize("input_data", [\n "\'mon\'",\n "\'Sun\'",\n "\'Mon.\'",\n "\'Monday\'",\n])\ndef test_edge_cases_with_case_sensitivity(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with additional text\n@pytest.mark.parametrize("input_data", [\n "\'Today is MON\'",\n "\'SUNday\'",\n])\ndef test_edge_cases_with_additional_text(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with non-standard capitalization and characters\n@pytest.mark.parametrize("input_data", [\n "\'monDAY\'",\n "\'M@N\'",\n])\ndef test_edge_cases_with_non_standard_characters(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with whitespace and control characters\n@pytest.mark.parametrize("input_data", [\n "\'\\t\\nMON\\n\\t\'",\n "\'MON\\u200b\'",\n])\ndef test_edge_cases_with_whitespace_and_control_characters(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with similar sounding or looking names\n@pytest.mark.parametrize("input_data", [\n "\'MOAN\'",\n "\'SUNN\'",\n])\ndef test_edge_cases_with_similar_names(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with empty and null inputs\n@pytest.mark.parametrize("input_data", [\n "\'\'",\n "None",\n])\ndef test_edge_cases_with_empty_and_null_inputs(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with valid Python expressions\n@pytest.mark.parametrize("input_data", [\n "\'6-1\'",\n "\'[x for x in range(5)]\'",\n])\ndef test_edge_cases_with_valid_python_expressions(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with input as Python code\n@pytest.mark.parametrize("input_data", [\n "\'os.system(\\\'echo hello\\\')\'",\n "\'(__import__(\\\'sys\\\').exit())\'",\n])\ndef test_edge_cases_with_python_code(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)',
'# imports\nimport pytest # used for our unit tests\nfrom pie_test_set.p02847 import problem_p02847\n\n# unit tests\n\n# Test normal cases with valid inputs for each day of the week\n@pytest.mark.parametrize("input_data, expected", [\n ("\'MON\'", 6),\n ("\'TUE\'", 5),\n ("\'WED\'", 4),\n ("\'THU\'", 3),\n ("\'FRI\'", 2),\n ("\'SAT\'", 1),\n ("\'SUN\'", 7),\n])\ndef test_normal_cases(input_data, expected):\n assert problem_p02847(input_data) == expected\n\n# Test edge cases with invalid day names\n@pytest.mark.parametrize("input_data", [\n "\'FUNDAY\'",\n "\'MOON\'",\n "123",\n "[\'MON\']",\n "None",\n])\ndef test_edge_cases_with_invalid_day_names(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test misuse cases with potentially malicious code\n@pytest.mark.parametrize("input_data", [\n "__import__(\'os\').system(\'echo hello\')",\n "exit()",\n])\ndef test_misuse_cases_with_malicious_code(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with lowercase and mixed case inputs\n@pytest.mark.parametrize("input_data", [\n "\'mon\'",\n "\'Sun\'",\n "\'Mon.\'",\n "\'Monday\'",\n])\ndef test_edge_cases_with_case_sensitivity(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with additional text\n@pytest.mark.parametrize("input_data", [\n "\'Today is MON\'",\n "\'SUNday\'",\n])\ndef test_edge_cases_with_additional_text(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with non-standard capitalization and characters\n@pytest.mark.parametrize("input_data", [\n "\'monDAY\'",\n "\'M@N\'",\n])\ndef test_edge_cases_with_non_standard_characters(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with whitespace and control characters\n@pytest.mark.parametrize("input_data", [\n "\'\\t\\nMON\\n\\t\'",\n "\'MON\\u200b\'",\n])\ndef test_edge_cases_with_whitespace_and_control_characters(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with similar sounding or looking names\n@pytest.mark.parametrize("input_data", [\n "\'MOAN\'",\n "\'SUNN\'",\n])\ndef test_edge_cases_with_similar_names(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with empty and null inputs\n@pytest.mark.parametrize("input_data", [\n "\'\'",\n "None",\n])\ndef test_edge_cases_with_empty_and_null_inputs(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with valid Python expressions\n@pytest.mark.parametrize("input_data", [\n "\'6-1\'",\n "\'[x for x in range(5)]\'",\n])\ndef test_edge_cases_with_valid_python_expressions(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)\n\n# Test edge cases with input as Python code\n@pytest.mark.parametrize("input_data", [\n "\'os.system(\\\'echo hello\\\')\'",\n "\'(__import__(\\\'sys\\\').exit())\'",\n])\ndef test_edge_cases_with_python_code(input_data):\n with pytest.raises(Exception):\n problem_p02847(input_data)'
],
1: [
'# imports\nimport pytest # used for our unit tests\nfrom pie_test_set.p03196 import problem_p03196\n\n# unit tests\n\n# Test normal cases\n@pytest.mark.parametrize("input_data, expected", [\n ("2 16", 4),\n ("3 27", 3),\n ("2 20", 1),\n])\ndef test_normal_cases(input_data, expected):\n assert problem_p03196(input_data) == expected\n\n# Test edge cases\n@pytest.mark.parametrize("input_data, expected", [\n ("1 100", 100),\n ("1 1", 1),\n ("41 1000000", 1),\n ("100 2", 1),\n ("2 17", 1),\n ("3 19", 1),\n])\ndef test_edge_cases(input_data, expected):\n assert problem_p03196(input_data) == expected\n\n# Test large numbers\n@pytest.mark.parametrize("input_data, expected", [\n ("2 1000000000000", 1000000),\n ("3 1000000000000", 10000),\n])\ndef test_large_numbers(input_data, expected):\n assert problem_p03196(input_data) == expected\n\n# Test invalid input format\n@pytest.mark.parametrize("input_data", [\n "2.5 10",\n "2 -10",\n "2",\n "two 10",\n "2, 10",\n])\ndef test_invalid_input(input_data):\n with pytest.raises(ValueError):\n problem_p03196(input_data)\n\n# Test boundary cases\n@pytest.mark.parametrize("input_data, expected", [\n ("40 1024", 2),\n ("2 1024", 32),\n])\ndef test_boundary_cases(input_data, expected):\n assert problem_p03196(input_data) == expected\n\n# Test rare or unexpected edge cases\n@pytest.mark.parametrize("input_data, expected", [\n ("2 " + str(2**63 - 1), 1),\n ("3 " + str(2**63 - 1), 1),\n ("2 1", 1),\n ("10 1", 1),\n ("37 137**37", 137),\n ("37 138", 1),\n ("3 8", 2),\n ("5 32", 2),\n ("2 6", 2),\n ("2 120", 4),\n ("2 49", 7),\n ("2 77", 1),\n])\ndef test_unexpected_edge_cases(input_data, expected):\n assert problem_p03196(input_data) == expected\n\n# Test zero or negative n values\n@pytest.mark.parametrize("input_data", [\n "0 10",\n "-2 16",\n])\ndef test_zero_or_negative_n(input_data):\n with pytest.raises(ValueError):\n problem_p03196(input_data)',
'# imports\nimport pytest # used for our unit tests\nfrom pie_test_set.p03196 import problem_p03196\n\n# unit tests\n\n# Test normal cases\n@pytest.mark.parametrize("input_data, expected", [\n ("2 16", 4),\n ("3 27", 3),\n ("2 20", 1),\n])\ndef test_normal_cases(input_data, expected):\n assert problem_p03196(input_data) == expected\n\n# Test edge cases\n@pytest.mark.parametrize("input_data, expected", [\n ("1 100", 100),\n ("1 1", 1),\n ("41 1000000", 1),\n ("100 2", 1),\n ("2 17", 1),\n ("3 19", 1),\n])\ndef test_edge_cases(input_data, expected):\n assert problem_p03196(input_data) == expected\n\n# Test large numbers\n@pytest.mark.parametrize("input_data, expected", [\n ("2 1000000000000", 1000000),\n ("3 1000000000000", 10000),\n])\ndef test_large_numbers(input_data, expected):\n assert problem_p03196(input_data) == expected\n\n# Test invalid input format\n@pytest.mark.parametrize("input_data", [\n "2.5 10",\n "2 -10",\n "2",\n "two 10",\n "2, 10",\n])\ndef test_invalid_input(input_data):\n with pytest.raises(ValueError):\n problem_p03196(input_data)\n\n# Test boundary cases\n@pytest.mark.parametrize("input_data, expected", [\n ("40 1024", 2),\n ("2 1024", 32),\n])\ndef test_boundary_cases(input_data, expected):\n assert problem_p03196(input_data) == expected\n\n# Test rare or unexpected edge cases\n@pytest.mark.parametrize("input_data, expected", [\n ("2 " + str(2**63 - 1), 1),\n ("3 " + str(2**63 - 1), 1),\n ("2 1", 1),\n ("10 1", 1),\n ("37 137**37", 137),\n ("37 138", 1),\n ("3 8", 2),\n ("5 32", 2),\n ("2 6", 2),\n ("2 120", 4),\n ("2 49", 7),\n ("2 77", 1),\n])\ndef test_unexpected_edge_cases(input_data, expected):\n assert problem_p03196(input_data) == expected\n\n# Test zero or negative n values\n@pytest.mark.parametrize("input_data", [\n "0 10",\n "-2 16",\n])\ndef test_zero_or_negative_n(input_data):\n with pytest.raises(ValueError):\n problem_p03196(input_data)'
],
2: [
'# imports\nimport pytest # used for our unit tests\nfrom pie_test_set.p03861 import problem_p03861\n\n# unit tests\n\n# Test typical cases with different ranges and divisors\ndef test_typical_cases():\n assert problem_p03861("1 10 2") == 5\n assert problem_p03861("3 15 3") == 5\n assert problem_p03861("10 100 10") == 10\n\n# Test edge cases where a and b are the same\ndef test_edge_cases_same_a_b():\n assert problem_p03861("5 5 1") == 1\n assert problem_p03861("7 7 2") == 0\n assert problem_p03861("10 10 10") == 1\n\n# Test edge cases where a is a multiple of x\ndef test_edge_cases_a_multiple_of_x():\n assert problem_p03861("6 12 6") == 2\n assert problem_p03861("10 20 10") == 2\n\n# Test edge cases where b is a multiple of x\ndef test_edge_cases_b_multiple_of_x():\n assert problem_p03861("1 10 5") == 2\n assert problem_p03861("3 21 7") == 3\n\n# Test cases where a is greater than b\ndef test_cases_a_greater_than_b():\n assert problem_p03861("10 5 1") == 0\n assert problem_p03861("20 10 5") == 0\n\n# Test cases where x is larger than both a and b\ndef test_cases_x_larger_than_a_b():\n assert problem_p03861("1 5 10") == 0\n assert problem_p03861("2 8 20") == 0\n\n# Test cases with negative numbers and zero\ndef test_cases_negative_numbers_and_zero():\n assert problem_p03861("0 0 1") == 1\n assert problem_p03861("-5 5 5") == 3\n assert problem_p03861("-10 -1 2") == 5\n\n# Test cases with large numbers\ndef test_cases_large_numbers():\n assert problem_p03861("1000000 2000000 100000") == 20\n # This test may need to be adjusted for the expected number of multiples\n assert problem_p03861("123456789 987654321 12345") == 987654321 // 12345 - (123456789 - 1) // 12345\n\n# Test invalid input cases\ndef test_invalid_input_cases():\n with pytest.raises(ValueError):\n problem_p03861("")\n with pytest.raises(ValueError):\n problem_p03861("1 10")\n with pytest.raises(ValueError):\n problem_p03861("1 10 two")\n\n# Test rare or unexpected edge cases\ndef test_rare_edge_cases():\n with pytest.raises(ZeroDivisionError):\n problem_p03861("1 10 0")\n with pytest.raises(ZeroDivisionError):\n problem_p03861("5 5 0")\n assert problem_p03861("0 0 5") == 1\n assert problem_p03861("-10 -1 2") == 5\n assert problem_p03861("-15 -5 3") == 4\n assert problem_p03861("-10 -1 -2") == 5\n assert problem_p03861("-20 -10 -5") == 3\n assert problem_p03861("1 10 -1") == 10\n assert problem_p03861("5 15 -3") == 4\n assert problem_p03861("1000000000 2000000000 100000000") == 20\n assert problem_p03861("2147483647 2147483647 1") == 1\n assert problem_p03861("-2147483648 -2147483648 1") == 1\n assert problem_p03861("2147483647 2147483647 2147483647") == 1\n assert problem_p03861("7 7 7") == 1\n assert problem_p03861("123456 123456 123456") == 1\n assert problem_p03861("0 0 -5") == 1',
'# imports\nimport pytest # used for our unit tests\nfrom pie_test_set.p03861 import problem_p03861\n\n# unit tests\n\n# Test typical cases with different ranges and divisors\ndef test_typical_cases():\n assert problem_p03861("1 10 2") == 5\n assert problem_p03861("3 15 3") == 5\n assert problem_p03861("10 100 10") == 10\n\n# Test edge cases where a and b are the same\ndef test_edge_cases_same_a_b():\n assert problem_p03861("5 5 1") == 1\n assert problem_p03861("7 7 2") == 0\n assert problem_p03861("10 10 10") == 1\n\n# Test edge cases where a is a multiple of x\ndef test_edge_cases_a_multiple_of_x():\n assert problem_p03861("6 12 6") == 2\n assert problem_p03861("10 20 10") == 2\n\n# Test edge cases where b is a multiple of x\ndef test_edge_cases_b_multiple_of_x():\n assert problem_p03861("1 10 5") == 2\n assert problem_p03861("3 21 7") == 3\n\n# Test cases where a is greater than b\ndef test_cases_a_greater_than_b():\n assert problem_p03861("10 5 1") == 0\n assert problem_p03861("20 10 5") == 0\n\n# Test cases where x is larger than both a and b\ndef test_cases_x_larger_than_a_b():\n assert problem_p03861("1 5 10") == 0\n assert problem_p03861("2 8 20") == 0\n\n# Test cases with negative numbers and zero\ndef test_cases_negative_numbers_and_zero():\n assert problem_p03861("0 0 1") == 1\n assert problem_p03861("-5 5 5") == 3\n assert problem_p03861("-10 -1 2") == 5\n\n# Test cases with large numbers\ndef test_cases_large_numbers():\n assert problem_p03861("1000000 2000000 100000") == 20\n # This test may need to be adjusted for the expected number of multiples\n assert problem_p03861("123456789 987654321 12345") == 987654321 // 12345 - (123456789 - 1) // 12345\n\n# Test invalid input cases\ndef test_invalid_input_cases():\n with pytest.raises(ValueError):\n problem_p03861("")\n with pytest.raises(ValueError):\n problem_p03861("1 10")\n with pytest.raises(ValueError):\n problem_p03861("1 10 two")\n\n# Test rare or unexpected edge cases\ndef test_rare_edge_cases():\n with pytest.raises(ZeroDivisionError):\n problem_p03861("1 10 0")\n with pytest.raises(ZeroDivisionError):\n problem_p03861("5 5 0")\n assert problem_p03861("0 0 5") == 1\n assert problem_p03861("-10 -1 2") == 5\n assert problem_p03861("-15 -5 3") == 4\n assert problem_p03861("-10 -1 -2") == 5\n assert problem_p03861("-20 -10 -5") == 3\n assert problem_p03861("1 10 -1") == 10\n assert problem_p03861("5 15 -3") == 4\n assert problem_p03861("1000000000 2000000000 100000000") == 20\n assert problem_p03861("2147483647 2147483647 1") == 1\n assert problem_p03861("-2147483648 -2147483648 1") == 1\n assert problem_p03861("2147483647 2147483647 2147483647") == 1\n assert problem_p03861("7 7 7") == 1\n assert problem_p03861("123456 123456 123456") == 1\n assert problem_p03861("0 0 -5") == 1'
],
3: [
"# imports\nimport pytest # used for our unit tests\nfrom pie_test_set.p02910 import problem_p02910\n\n# unit tests\n\n# Test normal cases\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"['S', 'A', 'M', 'P', 'L', 'E']\", \"Yes\"),\n (\"['T', 'E', 'S', 'T', 'I', 'N', 'G']\", \"Yes\"),\n (\"['L', 'A', 'M', 'P', 'L', 'E']\", \"No\"),\n (\"['T', 'E', 'S', 'R', 'I', 'N', 'G']\", \"No\")\n])\ndef test_normal_cases(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test edge cases\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"[]\", \"Yes\"),\n (\"['L']\", \"No\"),\n (\"['R']\", \"Yes\"),\n (\"['A', 'A', 'A', 'A']\", \"Yes\"),\n (\"['L', 'R', 'L', 'R']\", \"No\")\n])\ndef test_edge_cases(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test special characters and non-string elements\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"['$', '2', '#', '4', '@', '6']\", \"Yes\"),\n (\"[1, 'R', None, 'L', True, 'F']\", \"Yes\")\n])\ndef test_special_characters_and_non_string_elements(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test large inputs\ndef test_large_input():\n large_input = \"['N', 'E', 'V', 'E', 'R', 'L', 'O', 'O', 'K', 'B', 'A', 'C', 'K'] * 1000\"\n assert problem_p02910(large_input) == \"Yes\"\n\n# Test invalid inputs\n@pytest.mark.parametrize(\"input_data\", [\n \"'This is not a list'\",\n \"[['A', 'B'], ['C', 'D']]\",\n \"__import__('os').system('rm -rf /')\"\n])\ndef test_invalid_inputs(input_data):\n with pytest.raises(SyntaxError):\n problem_p02910(input_data)\n\n# Test unicode characters\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"['😀', 'L', '🎉', 'R', '❤️', '🚀']\", \"No\"),\n (\"['こんにちは', '世界', '안녕하세요', '세계']\", \"Yes\")\n])\ndef test_unicode_characters(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test escaped characters\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"['\\\\n', 'R', '\\\\t', 'L']\", \"Yes\"),\n (\"['A', '\\\\n', 'B', '\\\\r']\", \"Yes\")\n])\ndef test_escaped_characters(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test boolean values\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"[True, False, 'T', 'F']\", \"Yes\"),\n (\"[False, True, 'L', 'R']\", \"No\")\n])\ndef test_boolean_values(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test case sensitivity\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"['l', 'R', 'm', 'r', 'n', 'L']\", \"Yes\"),\n (\"['L', 'r', 'M', 'R', 'N', 'l']\", \"No\")\n])\ndef test_case_sensitivity(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test input with only one character type\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"['L', 'L', 'L', 'L']\", \"No\"),\n (\"['R', 'R', 'R', 'R']\", \"Yes\")\n])\ndef test_input_with_one_character_type(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test input with data structures other than lists\n@pytest.mark.parametrize(\"input_data\", [\n \"('A', 'B', 'C', 'D')\",\n \"{1, 2, 3, 4}\"\n])\ndef test_input_with_other_data_structures(input_data):\n with pytest.raises(SyntaxError):\n problem_p02910(input_data)\n\n# Test input with nested quotes\ndef test_input_with_nested_quotes():\n input_data = \"'[\\\\'A\\\\', \\\\'B\\\\', \\\\'C\\\\']'\"\n assert problem_p02910(input_data) == \"Yes\"\n\n# Test input with dictionary\ndef test_input_with_dictionary():\n input_data = \"{'key1': 'value1', 
'key2': 'value2'}\"\n with pytest.raises(SyntaxError):\n problem_p02910(input_data)",
"# imports\nimport pytest # used for our unit tests\nfrom pie_test_set.p02910 import problem_p02910\n\n# unit tests\n\n# Test normal cases\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"['S', 'A', 'M', 'P', 'L', 'E']\", \"Yes\"),\n (\"['T', 'E', 'S', 'T', 'I', 'N', 'G']\", \"Yes\"),\n (\"['L', 'A', 'M', 'P', 'L', 'E']\", \"No\"),\n (\"['T', 'E', 'S', 'R', 'I', 'N', 'G']\", \"No\")\n])\ndef test_normal_cases(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test edge cases\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"[]\", \"Yes\"),\n (\"['L']\", \"No\"),\n (\"['R']\", \"Yes\"),\n (\"['A', 'A', 'A', 'A']\", \"Yes\"),\n (\"['L', 'R', 'L', 'R']\", \"No\")\n])\ndef test_edge_cases(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test special characters and non-string elements\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"['$', '2', '#', '4', '@', '6']\", \"Yes\"),\n (\"[1, 'R', None, 'L', True, 'F']\", \"Yes\")\n])\ndef test_special_characters_and_non_string_elements(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test large inputs\ndef test_large_input():\n large_input = \"['N', 'E', 'V', 'E', 'R', 'L', 'O', 'O', 'K', 'B', 'A', 'C', 'K'] * 1000\"\n assert problem_p02910(large_input) == \"Yes\"\n\n# Test invalid inputs\n@pytest.mark.parametrize(\"input_data\", [\n \"'This is not a list'\",\n \"[['A', 'B'], ['C', 'D']]\",\n \"__import__('os').system('rm -rf /')\"\n])\ndef test_invalid_inputs(input_data):\n with pytest.raises(SyntaxError):\n problem_p02910(input_data)\n\n# Test unicode characters\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"['😀', 'L', '🎉', 'R', '❤️', '🚀']\", \"No\"),\n (\"['こんにちは', '世界', '안녕하세요', '세계']\", \"Yes\")\n])\ndef test_unicode_characters(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test escaped characters\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"['\\\\n', 'R', '\\\\t', 'L']\", \"Yes\"),\n (\"['A', '\\\\n', 'B', '\\\\r']\", \"Yes\")\n])\ndef test_escaped_characters(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test boolean values\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"[True, False, 'T', 'F']\", \"Yes\"),\n (\"[False, True, 'L', 'R']\", \"No\")\n])\ndef test_boolean_values(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test case sensitivity\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"['l', 'R', 'm', 'r', 'n', 'L']\", \"Yes\"),\n (\"['L', 'r', 'M', 'R', 'N', 'l']\", \"No\")\n])\ndef test_case_sensitivity(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test input with only one character type\n@pytest.mark.parametrize(\"input_data, expected\", [\n (\"['L', 'L', 'L', 'L']\", \"No\"),\n (\"['R', 'R', 'R', 'R']\", \"Yes\")\n])\ndef test_input_with_one_character_type(input_data, expected):\n assert problem_p02910(input_data) == expected\n\n# Test input with data structures other than lists\n@pytest.mark.parametrize(\"input_data\", [\n \"('A', 'B', 'C', 'D')\",\n \"{1, 2, 3, 4}\"\n])\ndef test_input_with_other_data_structures(input_data):\n with pytest.raises(SyntaxError):\n problem_p02910(input_data)\n\n# Test input with nested quotes\ndef test_input_with_nested_quotes():\n input_data = \"'[\\\\'A\\\\', \\\\'B\\\\', \\\\'C\\\\']'\"\n assert problem_p02910(input_data) == \"Yes\"\n\n# Test input with dictionary\ndef test_input_with_dictionary():\n input_data = \"{'key1': 'value1', 
'key2': 'value2'}\"\n with pytest.raises(SyntaxError):\n problem_p02910(input_data)"
],
4: [
'# imports\nimport pytest # used for our unit tests\nfrom pie_test_set.p03965 import problem_p03965\n\n# unit tests\n\n# Test empty input\ndef test_empty_input():\n assert problem_p03965("") == 0\n\n# Test input with only "g" characters\n@pytest.mark.parametrize("input_data", ["g", "ggggg"])\ndef test_only_g_characters(input_data):\n assert problem_p03965(input_data) == 0\n\n# Test input with no "g" characters\n@pytest.mark.parametrize("input_data", ["abcdef", "pppp"])\ndef test_no_g_characters(input_data):\n assert problem_p03965(input_data) == -len(input_data)\n\n# Test input with alternating "g" and other characters\n@pytest.mark.parametrize("input_data", ["gpgpg", "gagbgc"])\ndef test_alternating_g_characters(input_data):\n assert problem_p03965(input_data) == 0\n\n# Test input with consecutive "g" characters surrounded by other characters\n@pytest.mark.parametrize("input_data", ["agggbc", "xggggy"])\ndef test_consecutive_g_characters(input_data):\n assert problem_p03965(input_data) == 1\n\n# Test input with leading and trailing whitespace\n@pytest.mark.parametrize("input_data, expected", [(" gggg ", 0), ("\\tgggg\\n", 0)])\ndef test_whitespace(input_data, expected):\n assert problem_p03965(input_data) == expected\n\n# Test input with only whitespace characters\n@pytest.mark.parametrize("input_data", [" ", "\\t\\n\\r"])\ndef test_only_whitespace(input_data):\n assert problem_p03965(input_data) == -len(input_data.rstrip())\n\n# Test input with "g" at the start, middle, and end\n@pytest.mark.parametrize("input_data", ["gabcgdefg", "gxyzg"])\ndef test_g_start_middle_end(input_data):\n assert problem_p03965(input_data) == 1\n\n# Test input with repeated non-"g" characters\n@pytest.mark.parametrize("input_data", ["ppppp", "xxxxx"])\ndef test_repeated_non_g_characters(input_data):\n assert problem_p03965(input_data) == -len(input_data)\n\n# Test input with mixed "g" and non-"g" characters\n@pytest.mark.parametrize("input_data", ["gpapbpcg", "ggppggpp"])\ndef test_mixed_g_non_g_characters(input_data):\n assert problem_p03965(input_data) == 0\n\n# Test long strings\ndef test_long_string():\n long_string = "g" * 10000\n assert problem_p03965(long_string) == 0\n\n# Test special characters and numbers\n@pytest.mark.parametrize("input_data", ["g#%&g123", "12g34g!@"])\ndef test_special_characters_and_numbers(input_data):\n assert problem_p03965(input_data) == 0\n\n# Test Unicode characters\n@pytest.mark.parametrize("input_data", ["g日本語g", "🙂g🙃g"])\ndef test_unicode_characters(input_data):\n assert problem_p03965(input_data) == 0\n\n# Test case sensitivity\n@pytest.mark.parametrize("input_data", ["GgGg", "gGgG"])\ndef test_case_sensitivity(input_data):\n assert problem_p03965(input_data) == 0',
'# imports\nimport pytest # used for our unit tests\nfrom pie_test_set.p03965 import problem_p03965\n\n# unit tests\n\n# Test empty input\ndef test_empty_input():\n assert problem_p03965("") == 0\n\n# Test input with only "g" characters\n@pytest.mark.parametrize("input_data", ["g", "ggggg"])\ndef test_only_g_characters(input_data):\n assert problem_p03965(input_data) == 0\n\n# Test input with no "g" characters\n@pytest.mark.parametrize("input_data", ["abcdef", "pppp"])\ndef test_no_g_characters(input_data):\n assert problem_p03965(input_data) == -len(input_data)\n\n# Test input with alternating "g" and other characters\n@pytest.mark.parametrize("input_data", ["gpgpg", "gagbgc"])\ndef test_alternating_g_characters(input_data):\n assert problem_p03965(input_data) == 0\n\n# Test input with consecutive "g" characters surrounded by other characters\n@pytest.mark.parametrize("input_data", ["agggbc", "xggggy"])\ndef test_consecutive_g_characters(input_data):\n assert problem_p03965(input_data) == 1\n\n# Test input with leading and trailing whitespace\n@pytest.mark.parametrize("input_data, expected", [(" gggg ", 0), ("\\tgggg\\n", 0)])\ndef test_whitespace(input_data, expected):\n assert problem_p03965(input_data) == expected\n\n# Test input with only whitespace characters\n@pytest.mark.parametrize("input_data", [" ", "\\t\\n\\r"])\ndef test_only_whitespace(input_data):\n assert problem_p03965(input_data) == -len(input_data.rstrip())\n\n# Test input with "g" at the start, middle, and end\n@pytest.mark.parametrize("input_data", ["gabcgdefg", "gxyzg"])\ndef test_g_start_middle_end(input_data):\n assert problem_p03965(input_data) == 1\n\n# Test input with repeated non-"g" characters\n@pytest.mark.parametrize("input_data", ["ppppp", "xxxxx"])\ndef test_repeated_non_g_characters(input_data):\n assert problem_p03965(input_data) == -len(input_data)\n\n# Test input with mixed "g" and non-"g" characters\n@pytest.mark.parametrize("input_data", ["gpapbpcg", "ggppggpp"])\ndef test_mixed_g_non_g_characters(input_data):\n assert problem_p03965(input_data) == 0\n\n# Test long strings\ndef test_long_string():\n long_string = "g" * 10000\n assert problem_p03965(long_string) == 0\n\n# Test special characters and numbers\n@pytest.mark.parametrize("input_data", ["g#%&g123", "12g34g!@"])\ndef test_special_characters_and_numbers(input_data):\n assert problem_p03965(input_data) == 0\n\n# Test Unicode characters\n@pytest.mark.parametrize("input_data", ["g日本語g", "🙂g🙃g"])\ndef test_unicode_characters(input_data):\n assert problem_p03965(input_data) == 0\n\n# Test case sensitivity\n@pytest.mark.parametrize("input_data", ["GgGg", "gGgG"])\ndef test_case_sensitivity(input_data):\n assert problem_p03965(input_data) == 0'
],
5: [
'# imports\nimport pytest # used for our unit tests\nfrom pie_test_set.p02553 import problem_p02553\n\n# unit tests\n\n@pytest.mark.parametrize("input_data, expected", [\n ("1 2 3 4", 8),\n ("-1 -2 -3 -4", -3),\n ("1 -2 3 -4", 12),\n ("0 2 3 4", 8),\n ("1 0 3 4", 4),\n ("0 0 0 0", 0),\n ("-1 2 -3 4", 0),\n ("1 -2 3 -4", 0),\n ("-1 -2 3 4", -3),\n ("1 2 -3 -4", -6),\n ("1000000000 1000000000 1000000000 1000000000", 1000000000000000000),\n ("-1000000000 -1000000000 -1000000000 -1000000000", -1000000000000000000),\n ("2147483647 -2147483648 -2147483648 2147483647", 4611686014132420609),\n ("1000000000 1 1000000000 1", 1000000000),\n ("-1000000000 1 -1000000000 1", 1),\n ("-1 1 -1 1", 0),\n ("1 -1 1 -1", 0),\n ("1e3 2e3 3e3 4e3", 8000000),\n ("-1e3 -2e3 3e3 4e3", 3000000),\n])\ndef test_problem_p02553_normal_and_edge_cases(input_data, expected):\n # Test normal cases and various edge cases\n assert problem_p02553(input_data) == expected\n\n@pytest.mark.parametrize("input_data", [\n "1,2,3,4",\n "1 2 3",\n "one two three four",\n "\\t1\\t2\\t3\\t4",\n "1\\n2\\n3\\n4"\n])\ndef test_problem_p02553_invalid_input(input_data):\n # Test invalid inputs that should raise a ValueError\n with pytest.raises(ValueError):\n problem_p02553(input_data)',
'# imports\nimport pytest # used for our unit tests\nfrom pie_test_set.p02553 import problem_p02553\n\n# unit tests\n\n@pytest.mark.parametrize("input_data, expected", [\n ("1 2 3 4", 8),\n ("-1 -2 -3 -4", -3),\n ("1 -2 3 -4", 12),\n ("0 2 3 4", 8),\n ("1 0 3 4", 4),\n ("0 0 0 0", 0),\n ("-1 2 -3 4", 0),\n ("1 -2 3 -4", 0),\n ("-1 -2 3 4", -3),\n ("1 2 -3 -4", -6),\n ("1000000000 1000000000 1000000000 1000000000", 1000000000000000000),\n ("-1000000000 -1000000000 -1000000000 -1000000000", -1000000000000000000),\n ("2147483647 -2147483648 -2147483648 2147483647", 4611686014132420609),\n ("1000000000 1 1000000000 1", 1000000000),\n ("-1000000000 1 -1000000000 1", 1),\n ("-1 1 -1 1", 0),\n ("1 -1 1 -1", 0),\n ("1e3 2e3 3e3 4e3", 8000000),\n ("-1e3 -2e3 3e3 4e3", 3000000),\n])\ndef test_problem_p02553_normal_and_edge_cases(input_data, expected):\n # Test normal cases and various edge cases\n assert problem_p02553(input_data) == expected\n\n@pytest.mark.parametrize("input_data", [\n "1,2,3,4",\n "1 2 3",\n "one two three four",\n "\\t1\\t2\\t3\\t4",\n "1\\n2\\n3\\n4"\n])\ndef test_problem_p02553_invalid_input(input_data):\n # Test invalid inputs that should raise a ValueError\n with pytest.raises(ValueError):\n problem_p02553(input_data)'
],
},
"test_framework": {0: "pytest", 1: "pytest", 2: "pytest", 3: "pytest", 4: "pytest", 5: "pytest"},

View file

@ -14,12 +14,12 @@ def pickled_dataframe():
return df
@pytest.fixture()
@pytest.fixture
def pie4perf_sample_dataframe():
return DataFrame.from_dict(pie4perf_sample_dataframe_dict)
@pytest.fixture()
@pytest.fixture
def sample_dataframe():
return DataFrame(mock_dataframe)
@ -43,14 +43,7 @@ def test_calculate_validity_some_successful_runs(pie4perf_sample_dataframe):
def test_calculate_validity_all_successful_runs(pie4perf_sample_dataframe):
df = pie4perf_sample_dataframe
df["original_runtime"] = [1.0, 2.0, 2.0, 2.0, 3.0, 18.0] # All successful runs
df["best_correct_speedup_ratio"] = [
0.05,
0.06,
0.08,
0.09,
0.56,
0.17,
] # All above threshold
df["best_correct_speedup_ratio"] = [0.05, 0.06, 0.08, 0.09, 0.56, 0.17] # All above threshold
df["is_correct"] = [
{"1f39ef86-5eff-4760-a262-43011492906e": True},
{"aeeeca3b-4ccf-46eb-8dbf-c526c05fca27": True},
@ -68,14 +61,7 @@ def test_calculate_validity_with_valid_candidates(pie4perf_sample_dataframe):
df = pie4perf_sample_dataframe
# Assuming some runs are successful and some candidates are valid
df["original_runtime"] = [1.0, None, 2.0, None, 3.0, None] # Some successful runs
df["best_correct_speedup_ratio"] = [
0.04,
None,
0.06,
None,
0.07,
None,
] # One below and two above the threshold
df["best_correct_speedup_ratio"] = [0.04, None, 0.06, None, 0.07, None] # One below and two above the threshold
df["is_correct"] = [
{"1f39ef86-5eff-4760-a262-43011492906e": True},
{},

View file

@ -7,19 +7,20 @@
"metadata": {
"collapsed": true
},
"source": [
""
],
"outputs": []
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 2,
"id": "43aad670ab761737",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-13T01:20:22.380492Z",
"start_time": "2024-04-13T01:20:17.939300Z"
}
},
"cell_type": "code",
"outputs": [],
"source": [
"import json\n",
"\n",
@ -27,43 +28,38 @@
"from openai import OpenAI\n",
"\n",
"\n",
"@weave.op() # 🐝\n",
"@weave.op() # 🐝\n",
"def extract_fruit(sentence: str) -> dict:\n",
" client = OpenAI()\n",
"\n",
" response = client.chat.completions.create(\n",
" model=\"gpt-3.5-turbo-1106\",\n",
" messages=[\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": \"You will be provided with unstructured data, and your task is to parse it one JSON dictionary with fruit, color and flavor as keys.\"\n",
" },\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": sentence\n",
" }\n",
" model=\"gpt-3.5-turbo-1106\",\n",
" messages=[\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": \"You will be provided with unstructured data, and your task is to parse it one JSON dictionary with fruit, color and flavor as keys.\",\n",
" },\n",
" {\"role\": \"user\", \"content\": sentence},\n",
" ],\n",
" temperature=0.7,\n",
" response_format={ \"type\": \"json_object\" }\n",
" response_format={\"type\": \"json_object\"},\n",
" )\n",
" extracted = response.choices[0].message.content\n",
" return json.loads(extracted)\n",
"\n",
"weave.init('intro-example') # 🐝\n",
"\n",
"weave.init(\"intro-example\") # 🐝\n",
"sentence = \"There are many fruits that were found on the recently discovered planet Goocrux. There are neoskizzles that grow there, which are purple and taste like candy.\"\n",
"extract_fruit(sentence)"
],
"id": "43aad670ab761737",
"execution_count": 2,
"outputs": []
]
},
{
"metadata": {},
"cell_type": "code",
"execution_count": null,
"source": "",
"id": "d929002a32073a23",
"outputs": []
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

View file

@ -12,12 +12,13 @@ packages = [
]
keywords = ["codeflash", "performance", "optimization", "ai", "code", "machine learning", "LLM"]
# Versions here are the minimum required versions for the project. These should be as loose as possible.
[tool.poetry.dependencies]
python = "^3.9"
unidiff = ">=0.7.4"
pytest = ">=7.0.0"
gitpython = ">=3.1.31"
libcst = ">=1.5.0"
libcst = ">=1.0.1"
jedi = ">=0.19.1"
tiktoken = ">=0.3.2"
timeout-decorator = ">=0.5.0"
@ -37,6 +38,15 @@ returns = ">=0.23"
isort = ">=5.11.0"
dill = "^0.3.8"
rich = "^13.8.1"
[tool.poetry.group.dev]
optional = true
[tool.poetry.group.dev.dependencies]
ipython = "^8.12.0"
mypy = ">=1.13"
ruff = ">=0.7.0"
pandas-stubs = ">=2.2.2.240807, <2.2.3.241009"
types-Pygments = "^2.18.0.20240506"
types-colorama = "^0.4.15.20240311"
@ -47,14 +57,6 @@ types-six = "^1.16.21.20241009"
types-cffi = "^1.16.0.20240331"
types-openpyxl = "^3.1.5.20241020"
[tool.poetry.group.dev]
optional = true
[tool.poetry.group.dev.dependencies]
ipython = "^8.12.0"
mypy = ">=1.13"
ruff = ">=0.7.0"
[tool.poetry.build]
script = "codeflash/update_license_version.py"
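With the [tool.poetry.group.dev] group marked optional = true above, the stub packages, mypy and ruff are no longer pulled in by a plain poetry install; they are installed only on request (for example with poetry install --with dev), while the runtime dependencies keep their loose floors such as libcst >= 1.0.1.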

View file

@ -45,16 +45,11 @@ def main():
improvement_pct = int(re.search(r"📈 ([\d,]+)% improvement", stdout).group(1).replace(",", ""))
improvement_x = float(improvement_pct) / 100
assert (
improvement_pct > 300
), f"Performance improvement percentage was {improvement_pct}, which was not above 300%"
assert improvement_pct > 300, f"Performance improvement percentage was {improvement_pct}, which was not above 300%"
assert improvement_x > 3, f"Performance improvement rate was {improvement_x}x, which was not above 3x"
# Check for the line indicating the number of discovered existing unit tests
unit_test_search = re.search(
r"Discovered (\d+) existing unit tests",
stdout,
)
unit_test_search = re.search(r"Discovered (\d+) existing unit tests", stdout)
num_unit_tests = int(unit_test_search.group(1))
assert num_unit_tests > 0, "Could not find existing unit tests"

View file

@ -45,16 +45,11 @@ def main():
improvement_pct = int(re.search(r"📈 ([\d,]+)% improvement", stdout).group(1).replace(",", ""))
improvement_x = float(improvement_pct) / 100
assert (
improvement_pct > 300
), f"Performance improvement percentage was {improvement_pct}, which was not above 300%"
assert improvement_pct > 300, f"Performance improvement percentage was {improvement_pct}, which was not above 300%"
assert improvement_x > 3, f"Performance improvement rate was {improvement_x}x, which was not above 3x"
# Check for the line indicating the number of discovered existing unit tests
unit_test_search = re.search(
r"Discovered (\d+) existing unit tests",
stdout,
)
unit_test_search = re.search(r"Discovered (\d+) existing unit tests", stdout)
num_unit_tests = int(unit_test_search.group(1))
assert num_unit_tests > 0, "Could not find existing unit tests"

View file

@ -6,26 +6,12 @@ import subprocess
def main():
cwd = (
pathlib.Path(__file__).parent.parent.parent
/ "code_to_optimize"
/ "code_directories"
/ "futurehouse_structure"
pathlib.Path(__file__).parent.parent.parent / "code_to_optimize" / "code_directories" / "futurehouse_structure"
).resolve()
print("cwd", cwd)
command = [
"python",
"../../../codeflash/main.py",
"--file",
"src/aviary/common_tags.py",
"--no-pr",
]
command = ["python", "../../../codeflash/main.py", "--file", "src/aviary/common_tags.py", "--no-pr"]
process = subprocess.Popen(
command,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
cwd=str(cwd),
env=os.environ.copy(),
command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=os.environ.copy()
)
output = []
@ -41,16 +27,11 @@ def main():
improvement_pct = int(re.search(r"📈 ([\d,]+)% improvement", stdout).group(1).replace(",", ""))
improvement_x = float(improvement_pct) / 100
assert (
improvement_pct > 5
), f"Performance improvement percentage was {improvement_pct}, which was not above 10%"
assert improvement_pct > 5, f"Performance improvement percentage was {improvement_pct}, which was not above 10%"
assert improvement_x > 0.1, f"Performance improvement rate was {improvement_x}x, which was not above 0.1x"
# Check for the line indicating the number of discovered existing unit tests
unit_test_search = re.search(
r"Discovered (\d+) existing unit tests",
stdout,
)
unit_test_search = re.search(r"Discovered (\d+) existing unit tests", stdout)
num_unit_tests = int(unit_test_search.group(1))
assert num_unit_tests == 2, "Could not find existing unit tests"

View file

@ -45,18 +45,11 @@ def main():
improvement_pct = int(re.search(r"📈 ([\d,]+)% improvement", stdout).group(1).replace(",", ""))
improvement_x = float(improvement_pct) / 100
assert (
improvement_pct > 5
), f"Performance improvement percentage was {improvement_pct}, which was not above 5%"
assert (
improvement_x > 0.05
), f"Performance improvement rate was {improvement_x}x, which was not above 0.05x"
assert improvement_pct > 5, f"Performance improvement percentage was {improvement_pct}, which was not above 5%"
assert improvement_x > 0.05, f"Performance improvement rate was {improvement_x}x, which was not above 0.05x"
# Check for the line indicating the number of discovered existing unit tests
unit_test_search = re.search(
r"Discovered (\d+) existing unit tests",
stdout,
)
unit_test_search = re.search(r"Discovered (\d+) existing unit tests", stdout)
num_unit_tests = int(unit_test_search.group(1))
assert num_unit_tests > 0, "Could not find existing unit tests"

View file

@ -51,13 +51,7 @@ class Source:
src_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py")
dst_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py")
project_root = Path("/home/roger/repos/codeflash")
new_module = add_needed_imports_from_module(
src_module,
dst_module,
src_path,
dst_path,
project_root,
)
new_module = add_needed_imports_from_module(src_module, dst_module, src_path, dst_path, project_root)
assert new_module == expected
@ -125,11 +119,5 @@ def belongs_to_function(name: Name, function_name: str) -> bool:
src_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py")
dst_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py")
project_root = Path("/home/roger/repos/codeflash")
new_module = add_needed_imports_from_module(
src_module,
dst_module,
src_path,
dst_path,
project_root,
)
new_module = add_needed_imports_from_module(src_module, dst_module, src_path, dst_path, project_root)
assert new_module == expected

View file

@ -74,7 +74,7 @@ print("Hello world")
function_name: str = "NewClass.new_function"
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
("new_function", [FunctionParent(name="NewClass", type="ClassDef")]),
("new_function", [FunctionParent(name="NewClass", type="ClassDef")])
]
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
new_code: str = replace_functions_and_add_imports(
@ -136,10 +136,7 @@ print("Hello world")
"""
function_name: str = "NewClass.new_function"
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
("new_function", []),
("other_function", []),
]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [("new_function", []), ("other_function", [])]
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
@ -605,10 +602,7 @@ class CacheConfig(BaseConfig):
"""
function_names: list[str] = ["CacheSimilarityEvalConfig.from_config"]
parents = [FunctionParent(name="CacheConfig", type="ClassDef")]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
("__init__", parents),
("from_config", parents),
]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [("__init__", parents), ("from_config", parents)]
contextual_functions: set[tuple[str, str]] = {
("CacheSimilarityEvalConfig", "__init__"),
@ -687,7 +681,7 @@ def test_test_libcst_code_replacement8() -> None:
'''
function_names: list[str] = ["_EmbeddingDistanceChainMixin._hamming_distance"]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
("_hamming_distance", [FunctionParent("_EmbeddingDistanceChainMixin", "ClassDef")]),
("_hamming_distance", [FunctionParent("_EmbeddingDistanceChainMixin", "ClassDef")])
]
contextual_functions: set[tuple[str, str]] = set()
new_code: str = replace_functions_and_add_imports(
@ -745,14 +739,8 @@ print("Hello world")
"""
parents = [FunctionParent(name="NewClass", type="ClassDef")]
function_name: str = "NewClass.__init__"
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
("__init__", parents),
("__call__", parents),
]
contextual_functions: set[tuple[str, str]] = {
("NewClass", "__init__"),
("NewClass", "__call__"),
}
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [("__init__", parents), ("__call__", parents)]
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__"), ("NewClass", "__call__")}
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=[function_name],
@ -814,18 +802,14 @@ class MainClass:
pytest_cmd="pytest",
experiment_id=None,
test_project_root=file_path.parent.resolve(),
),
)
)
func_top_optimize = FunctionToOptimize(
function_name="main_method",
file_path=file_path,
parents=[FunctionParent("MainClass", "ClassDef")],
function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")]
)
original_code = file_path.read_text()
code_context = opt.get_code_optimization_context(
function_to_optimize=func_top_optimize,
project_root=file_path.parent,
original_source_code=original_code,
function_to_optimize=func_top_optimize, project_root=file_path.parent, original_source_code=original_code
).unwrap()
assert code_context.code_to_optimize_with_helpers == get_code_output
@ -1144,14 +1128,14 @@ class TestResults(BaseModel):
helper_functions = [
FakeFunctionSource(
file_path=Path(
"/Users/saurabh/Library/CloudStorage/Dropbox/codeflash/cli/codeflash/verification/test_results.py",
"/Users/saurabh/Library/CloudStorage/Dropbox/codeflash/cli/codeflash/verification/test_results.py"
),
qualified_name="TestType",
fully_qualified_name="codeflash.verification.test_results.TestType",
only_function_name="TestType",
source_code="",
jedi_definition=JediDefinition(type="class"),
),
)
]
new_code: str = replace_functions_and_add_imports(
@ -1168,13 +1152,8 @@ class TestResults(BaseModel):
helper_functions_by_module_abspath = defaultdict(set)
for helper_function in helper_functions:
if helper_function.jedi_definition.type != "class":
helper_functions_by_module_abspath[helper_function.file_path].add(
helper_function.qualified_name,
)
for (
module_abspath,
qualified_names,
) in helper_functions_by_module_abspath.items():
helper_functions_by_module_abspath[helper_function.file_path].add(helper_function.qualified_name)
for module_abspath, qualified_names in helper_functions_by_module_abspath.items():
new_code: str = replace_functions_and_add_imports(
source_code=new_code,
function_names=list(qualified_names),
@ -1447,13 +1426,8 @@ def cosine_similarity_top_k(
helper_functions_by_module_abspath = defaultdict(set)
for helper_function in helper_functions:
if helper_function.jedi_definition.type != "class":
helper_functions_by_module_abspath[helper_function.file_path].add(
helper_function.qualified_name,
)
for (
module_abspath,
qualified_names,
) in helper_functions_by_module_abspath.items():
helper_functions_by_module_abspath[helper_function.file_path].add(helper_function.qualified_name)
for module_abspath, qualified_names in helper_functions_by_module_abspath.items():
new_helper_code: str = replace_functions_and_add_imports(
source_code=new_code,
function_names=list(qualified_names),

View file

@ -96,10 +96,7 @@ def test_get_imports_from_file_with_syntax_error(caplog) -> None:
def test_get_imports_from_file_with_no_input() -> None:
with pytest.raises(
AssertionError,
match="Must provide exactly one of file_path, file_string, or file_ast",
):
with pytest.raises(AssertionError, match="Must provide exactly one of file_path, file_string, or file_ast"):
get_imports_from_file()

View file

@ -9,12 +9,7 @@ from returns.result import Failure, Success
from codeflash.verification.comparator import comparator
from codeflash.verification.equivalence import compare_test_results
from codeflash.verification.test_results import (
FunctionTestInvocation,
InvocationId,
TestResults,
TestType,
)
from codeflash.verification.test_results import FunctionTestInvocation, InvocationId, TestResults, TestType
def test_basic_python_objects():
@ -377,22 +372,13 @@ def test_pandas():
assert not comparator(c, ca)
ak = pd.DataFrame(
{
"a": [datetime.datetime(2020, 2, 2, 2, 2, 2), datetime.datetime(2020, 2, 2, 2, 2, 2)],
"b": [4, 5],
},
{"a": [datetime.datetime(2020, 2, 2, 2, 2, 2), datetime.datetime(2020, 2, 2, 2, 2, 2)], "b": [4, 5]}
)
al = pd.DataFrame(
{
"a": [datetime.datetime(2020, 2, 2, 2, 2, 2), datetime.datetime(2020, 2, 2, 2, 2, 2)],
"b": [4, 5],
},
{"a": [datetime.datetime(2020, 2, 2, 2, 2, 2), datetime.datetime(2020, 2, 2, 2, 2, 2)], "b": [4, 5]}
)
am = pd.DataFrame(
{
"a": [datetime.datetime(2020, 2, 2, 2, 2, 2), datetime.datetime(2020, 2, 2, 2, 2, 3)],
"b": [4, 5],
},
{"a": [datetime.datetime(2020, 2, 2, 2, 2, 2), datetime.datetime(2020, 2, 2, 2, 2, 3)], "b": [4, 5]}
)
assert comparator(ak, al)
assert not comparator(ak, am)
@ -693,8 +679,8 @@ def test_compare_results_fn():
return_value=5,
timed_out=False,
loop_index=1,
),
],
)
]
)
new_results_1 = TestResults(
@ -715,8 +701,8 @@ def test_compare_results_fn():
return_value=5,
timed_out=False,
loop_index=1,
),
],
)
]
)
assert compare_test_results(original_results, new_results_1)
@ -739,8 +725,8 @@ def test_compare_results_fn():
return_value=[5],
timed_out=False,
loop_index=1,
),
],
)
]
)
assert not compare_test_results(original_results, new_results_2)
@ -781,7 +767,7 @@ def test_compare_results_fn():
timed_out=False,
loop_index=1,
),
],
]
)
assert compare_test_results(original_results, new_results_3)

View file

@ -3,45 +3,24 @@ import os
from codeflash.code_utils.env_utils import get_pr_number
from codeflash.models.models import OptimizedCandidateResult
from codeflash.result.critic import quantity_of_tests_critic, speedup_critic
from codeflash.verification.test_results import (
FunctionTestInvocation,
InvocationId,
TestResults,
TestType,
)
from codeflash.verification.test_results import FunctionTestInvocation, InvocationId, TestResults, TestType
def test_speedup_critic():
original_code_runtime = 1000
best_runtime_until_now = 1000
candidate_result = OptimizedCandidateResult(
times_run=5,
best_test_runtime=800,
best_test_results=TestResults(),
)
candidate_result = OptimizedCandidateResult(times_run=5, best_test_runtime=800, best_test_results=TestResults())
assert speedup_critic(candidate_result, original_code_runtime, best_runtime_until_now) # 20% improvement
candidate_result = OptimizedCandidateResult(
times_run=5,
best_test_runtime=940,
best_test_results=TestResults(),
)
candidate_result = OptimizedCandidateResult(times_run=5, best_test_runtime=940, best_test_results=TestResults())
assert not speedup_critic(
candidate_result,
original_code_runtime,
best_runtime_until_now,
) # 6% improvement
assert not speedup_critic(candidate_result, original_code_runtime, best_runtime_until_now) # 6% improvement
original_code_runtime = 100000
best_runtime_until_now = 100000
candidate_result = OptimizedCandidateResult(
times_run=5,
best_test_runtime=94000,
best_test_results=TestResults(),
)
candidate_result = OptimizedCandidateResult(times_run=5, best_test_runtime=94000, best_test_results=TestResults())
assert speedup_critic(candidate_result, original_code_runtime, best_runtime_until_now) # 6% improvement
@ -140,9 +119,7 @@ def test_generated_test_critic():
test_results = [test_1, test_2, test_3]
candidate_result = OptimizedCandidateResult(
times_run=5,
best_test_runtime=100,
best_test_results=TestResults(test_results=test_results),
times_run=5, best_test_runtime=100, best_test_results=TestResults(test_results=test_results)
)
assert quantity_of_tests_critic(candidate_result)
@ -150,9 +127,7 @@ def test_generated_test_critic():
test_results = [test_1, test_3]
candidate_result = OptimizedCandidateResult(
times_run=5,
best_test_runtime=100,
best_test_results=TestResults(test_results=test_results),
times_run=5, best_test_runtime=100, best_test_results=TestResults(test_results=test_results)
)
assert quantity_of_tests_critic(candidate_result)
@ -160,9 +135,7 @@ def test_generated_test_critic():
test_results = [test_1, test_3, test_4]
candidate_result = OptimizedCandidateResult(
times_run=5,
best_test_runtime=100,
best_test_results=TestResults(test_results=test_results),
times_run=5, best_test_runtime=100, best_test_results=TestResults(test_results=test_results)
)
assert quantity_of_tests_critic(candidate_result)
@ -170,9 +143,7 @@ def test_generated_test_critic():
test_results = [test_1]
candidate_result = OptimizedCandidateResult(
times_run=5,
best_test_runtime=100,
best_test_results=TestResults(test_results=test_results),
times_run=5, best_test_runtime=100, best_test_results=TestResults(test_results=test_results)
)
assert not quantity_of_tests_critic(candidate_result)
@ -180,9 +151,7 @@ def test_generated_test_critic():
test_results = [test_1, test_2]
candidate_result = OptimizedCandidateResult(
times_run=5,
best_test_runtime=100,
best_test_results=TestResults(test_results=test_results),
times_run=5, best_test_runtime=100, best_test_results=TestResults(test_results=test_results)
)
assert quantity_of_tests_critic(candidate_result)
@ -190,9 +159,7 @@ def test_generated_test_critic():
test_results = [test_1, test_4]
candidate_result = OptimizedCandidateResult(
times_run=5,
best_test_runtime=100,
best_test_results=TestResults(test_results=test_results),
times_run=5, best_test_runtime=100, best_test_results=TestResults(test_results=test_results)
)
assert not quantity_of_tests_critic(candidate_result)
@ -200,9 +167,7 @@ def test_generated_test_critic():
test_results = [test_4, test_5]
candidate_result = OptimizedCandidateResult(
times_run=5,
best_test_runtime=100,
best_test_results=TestResults(test_results=test_results),
times_run=5, best_test_runtime=100, best_test_results=TestResults(test_results=test_results)
)
assert quantity_of_tests_critic(candidate_result)
@ -210,9 +175,7 @@ def test_generated_test_critic():
test_results = [test_1, test_2, test_3, test_4, test_5]
candidate_result = OptimizedCandidateResult(
times_run=5,
best_test_runtime=100,
best_test_results=TestResults(test_results=test_results),
times_run=5, best_test_runtime=100, best_test_results=TestResults(test_results=test_results)
)
assert quantity_of_tests_critic(candidate_result)
@ -222,9 +185,7 @@ def test_generated_test_critic():
test_results = [test_1, test_2, test_3]
candidate_result = OptimizedCandidateResult(
times_run=5,
best_test_runtime=100,
best_test_results=TestResults(test_results=test_results),
times_run=5, best_test_runtime=100, best_test_results=TestResults(test_results=test_results)
)
assert not quantity_of_tests_critic(candidate_result)
@ -232,9 +193,7 @@ def test_generated_test_critic():
test_results = [test_1, test_2, test_3, test_4]
candidate_result = OptimizedCandidateResult(
times_run=5,
best_test_runtime=100,
best_test_results=TestResults(test_results=test_results),
times_run=5, best_test_runtime=100, best_test_results=TestResults(test_results=test_results)
)
assert not quantity_of_tests_critic(candidate_result)
@ -242,9 +201,7 @@ def test_generated_test_critic():
test_results = [test_1, test_2, test_3, test_5]
candidate_result = OptimizedCandidateResult(
times_run=5,
best_test_runtime=100,
best_test_results=TestResults(test_results=test_results),
times_run=5, best_test_runtime=100, best_test_results=TestResults(test_results=test_results)
)
assert quantity_of_tests_critic(candidate_result)

View file

@ -38,10 +38,7 @@ def test_sort_imports_without_formatting():
tmp.flush()
tmp_path = Path(tmp.name)
new_code = format_code(
formatter_cmds=["disabled"],
path=tmp_path,
)
new_code = format_code(formatter_cmds=["disabled"], path=tmp_path)
assert new_code is not None
new_code = sort_imports(new_code)
assert new_code == "import os\nimport sys\nimport unittest\n"
@ -136,10 +133,7 @@ def foo():
tmp.flush()
tmp_path = tmp.name
actual = format_code(
formatter_cmds=["black $file"],
path=Path(tmp_path),
)
actual = format_code(formatter_cmds=["black $file"], path=Path(tmp_path))
assert actual == expected
@ -161,10 +155,7 @@ def foo():
tmp.flush()
tmp_path = tmp.name
actual = format_code(
formatter_cmds=["black $file"],
path=Path(tmp_path),
)
actual = format_code(formatter_cmds=["black $file"], path=Path(tmp_path))
assert actual == expected
@ -191,7 +182,6 @@ def foo():
tmp_path = tmp.name
actual = format_code(
formatter_cmds=["ruff check --exit-zero --fix $file", "ruff format $file"],
path=Path(tmp_path),
formatter_cmds=["ruff check --exit-zero --fix $file", "ruff format $file"], path=Path(tmp_path)
)
assert actual == expected
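The formatter tests above exercise format_code with command templates that contain a $file placeholder ("black $file", "ruff check --exit-zero --fix $file", "ruff format $file") plus a "disabled" sentinel. The sketch below shows one way such templates could be expanded and run; the helper name run_formatter_cmds, the in-place rewrite assumption, and the subprocess handling are illustrative guesses, not codeflash's actual implementation.

import shlex
import subprocess
from pathlib import Path

def run_formatter_cmds(formatter_cmds: list[str], path: Path) -> str:
    # Hypothetical sketch: expand $file to the target path and run each tool in order.
    for cmd in formatter_cmds:
        if cmd == "disabled":  # mirrors the sentinel used by the tests above: skip formatting
            break
        args = [str(path) if token == "$file" else token for token in shlex.split(cmd)]
        subprocess.run(args, check=False)  # assumes each tool rewrites the file in place
    return path.read_text()

Under these assumptions, formatter_cmds=["disabled"] returns the file content unchanged, matching how the first test above pairs it with sort_imports.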

View file

@ -3,10 +3,11 @@ from argparse import Namespace
from dataclasses import dataclass
import pytest
from returns.pipeline import is_successful
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
from codeflash.optimization.function_context import get_function_variables_definitions
from codeflash.optimization.optimizer import Optimizer
from returns.pipeline import is_successful
def calculate_something(data):
@ -20,8 +21,7 @@ def simple_function_with_one_dep(data):
def test_simple_dependencies() -> None:
file_path = pathlib.Path(__file__).resolve()
helper_functions = get_function_variables_definitions(
FunctionToOptimize("simple_function_with_one_dep", str(file_path), []),
str(file_path.parent.resolve()),
FunctionToOptimize("simple_function_with_one_dep", str(file_path), []), str(file_path.parent.resolve())
)[0]
assert len(helper_functions) == 1
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.calculate_something"
@ -82,8 +82,7 @@ def test_multiple_classes_dependencies() -> None:
# TODO: Check if C.run only gets calculate_something_3 as dependency and likewise for other classes
file_path = pathlib.Path(__file__).resolve()
helper_functions = get_function_variables_definitions(
FunctionToOptimize("run", str(file_path), [FunctionParent("C", "ClassDef")]),
str(file_path.parent.resolve()),
FunctionToOptimize("run", str(file_path), [FunctionParent("C", "ClassDef")]), str(file_path.parent.resolve())
)
# assert len(helper_functions) == 2
@ -103,8 +102,7 @@ def recursive_dependency_1(num):
def test_recursive_dependency() -> None:
file_path = pathlib.Path(__file__).resolve()
helper_functions = get_function_variables_definitions(
FunctionToOptimize("recursive_dependency_1", str(file_path), []),
str(file_path.parent.resolve()),
FunctionToOptimize("recursive_dependency_1", str(file_path), []), str(file_path.parent.resolve())
)[0]
assert len(helper_functions) == 1
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.calculate_something"
@ -127,14 +125,11 @@ def simple_function_with_one_dep_ann(data: MyData):
def test_simple_dependencies_ann() -> None:
file_path = pathlib.Path(__file__).resolve()
helper_functions = get_function_variables_definitions(
FunctionToOptimize("simple_function_with_one_dep_ann", str(file_path), []),
str(file_path.parent.resolve()),
FunctionToOptimize("simple_function_with_one_dep_ann", str(file_path), []), str(file_path.parent.resolve())
)[0]
assert len(helper_functions) == 2
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.MyData"
assert (
helper_functions[1].jedi_definition.full_name == "test_function_dependencies.calculate_something_ann"
)
assert helper_functions[1].jedi_definition.full_name == "test_function_dependencies.calculate_something_ann"
from collections import defaultdict
@ -180,7 +175,7 @@ def test_class_method_dependencies() -> None:
pytest_cmd="pytest",
experiment_id=None,
test_project_root=file_path.parent.resolve(),
),
)
)
function_to_optimize = FunctionToOptimize(
function_name="topologicalSort",
@ -191,11 +186,7 @@ def test_class_method_dependencies() -> None:
)
with open(file_path) as f:
original_code = f.read()
ctx_result = opt.get_code_optimization_context(
function_to_optimize,
opt.args.project_root,
original_code,
)
ctx_result = opt.get_code_optimization_context(function_to_optimize, opt.args.project_root, original_code)
if not is_successful(ctx_result):
pytest.fail()
code_context = ctx_result.unwrap()
@ -207,8 +198,7 @@ def test_class_method_dependencies() -> None:
)
assert code_context.helper_functions[0].jedi_definition.name == "topologicalSortUtil"
assert (
code_context.helper_functions[0].fully_qualified_name
== "test_function_dependencies.Graph.topologicalSortUtil"
code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.Graph.topologicalSortUtil"
)
assert code_context.helper_functions[0].qualified_name == "Graph.topologicalSortUtil"
assert code_context.contextual_dunder_methods == {("Graph", "__init__")}
@ -262,8 +252,7 @@ def simple_function_with_decorator_dep(data):
def test_decorator_dependencies() -> None:
file_path = pathlib.Path(__file__).resolve()
helper_functions = get_function_variables_definitions(
FunctionToOptimize("simple_function_with_decorator_dep", str(file_path), []),
str(file_path.parent.resolve()),
FunctionToOptimize("simple_function_with_decorator_dep", str(file_path), []), str(file_path.parent.resolve())
)[0]
assert len(helper_functions) == 2
assert {helper_functions[0][0].definition.full_name, helper_functions[1][0].definition.full_name} == {
@ -283,7 +272,7 @@ def test_recursive_function_context() -> None:
pytest_cmd="pytest",
experiment_id=None,
test_project_root=file_path.parent.resolve(),
),
)
)
function_to_optimize = FunctionToOptimize(
function_name="recursive",
@ -294,21 +283,14 @@ def test_recursive_function_context() -> None:
)
with open(file_path) as f:
original_code = f.read()
ctx_result = opt.get_code_optimization_context(
function_to_optimize,
opt.args.project_root,
original_code,
)
ctx_result = opt.get_code_optimization_context(function_to_optimize, opt.args.project_root, original_code)
if not is_successful(ctx_result):
pytest.fail()
code_context = ctx_result.unwrap()
# The code_context above should have the topologicalSortUtil function in it
assert len(code_context.helper_functions) == 2
assert set(
[
code_context.helper_functions[1].fully_qualified_name,
code_context.helper_functions[0].fully_qualified_name,
],
[code_context.helper_functions[1].fully_qualified_name, code_context.helper_functions[0].fully_qualified_name]
) == set(["test_function_dependencies.C.calculate_something_3", "test_function_dependencies.C.recursive"])
assert (
code_context.code_to_optimize_with_helpers

View file

@ -58,41 +58,26 @@ class AirbyteEntrypoint(object):
return AirbyteEntrypoint.handle_record_counts(num)
def non_classmethod_function(cls, name):
return cls.name
""",
"""
)
f.flush()
path_obj_name = Path(f.name)
assert inspect_top_level_functions_or_methods(path_obj_name, "functionA").is_top_level
assert not inspect_top_level_functions_or_methods(path_obj_name, "functionB").is_top_level
assert inspect_top_level_functions_or_methods(path_obj_name, "functionC", class_name="A").is_top_level
assert not inspect_top_level_functions_or_methods(
path_obj_name,
"functionD",
class_name="A",
).is_top_level
assert not inspect_top_level_functions_or_methods(
path_obj_name,
"functionF",
class_name="E",
).is_top_level
assert not inspect_top_level_functions_or_methods(path_obj_name, "functionD", class_name="A").is_top_level
assert not inspect_top_level_functions_or_methods(path_obj_name, "functionF", class_name="E").is_top_level
assert not inspect_top_level_functions_or_methods(path_obj_name, "functionA").has_args
staticmethod_func = inspect_top_level_functions_or_methods(
path_obj_name,
"handle_record_counts",
class_name=None,
line_no=15,
path_obj_name, "handle_record_counts", class_name=None, line_no=15
)
assert staticmethod_func.is_staticmethod
assert staticmethod_func.staticmethod_class_name == "AirbyteEntrypoint"
assert inspect_top_level_functions_or_methods(
path_obj_name,
"functionE",
class_name="AirbyteEntrypoint",
path_obj_name, "functionE", class_name="AirbyteEntrypoint"
).is_classmethod
assert not inspect_top_level_functions_or_methods(
path_obj_name,
"non_classmethod_function",
class_name="AirbyteEntrypoint",
path_obj_name, "non_classmethod_function", class_name="AirbyteEntrypoint"
).is_top_level
# needed because this will be traced with a class_name being passed
@ -111,7 +96,7 @@ class X:
def functionB():
return False
def functionA():
return True""",
return True"""
)
f.flush()
test_config = TestConfig(

View file

@ -29,9 +29,7 @@ def test_get_code_property() -> None:
f.flush()
new_code, contextual_dunder_methods = get_code(
[
FunctionToOptimize("test", f.name, [FunctionParent("TestClass", "ClassDef")]),
],
[FunctionToOptimize("test", f.name, [FunctionParent("TestClass", "ClassDef")])]
)
assert new_code == code
assert contextual_dunder_methods == {("TestClass", "__init__")}
@ -60,7 +58,7 @@ class TestClass:
f.flush()
new_code, contextual_dunder_methods = get_code(
[FunctionToOptimize("test", f.name, [FunctionParent("TestClass", "ClassDef")])],
[FunctionToOptimize("test", f.name, [FunctionParent("TestClass", "ClassDef")])]
)
assert new_code == expected
assert contextual_dunder_methods == {("TestClass", "__init__")}
@ -111,13 +109,10 @@ class BubbleSortClass:
f.flush()
new_code, contextual_dunder_methods = get_code(
[FunctionToOptimize("sorter", f.name, [FunctionParent("BubbleSortClass", "ClassDef")])],
[FunctionToOptimize("sorter", f.name, [FunctionParent("BubbleSortClass", "ClassDef")])]
)
assert new_code == expected
assert contextual_dunder_methods == {
("BubbleSortClass", "__init__"),
("BubbleSortClass", "__call__"),
}
assert contextual_dunder_methods == {("BubbleSortClass", "__init__"), ("BubbleSortClass", "__call__")}
def test_get_code_indent() -> None:
@ -177,23 +172,12 @@ def non():
f.flush()
new_code, contextual_dunder_methods = get_code(
[
FunctionToOptimize(
"sorter",
f.name,
[FunctionParent("BubbleSortClass", "ClassDef")],
),
FunctionToOptimize(
"helper",
f.name,
[FunctionParent("BubbleSortClass", "ClassDef")],
),
],
FunctionToOptimize("sorter", f.name, [FunctionParent("BubbleSortClass", "ClassDef")]),
FunctionToOptimize("helper", f.name, [FunctionParent("BubbleSortClass", "ClassDef")]),
]
)
assert new_code == expected
assert contextual_dunder_methods == {
("BubbleSortClass", "__init__"),
("BubbleSortClass", "__call__"),
}
assert contextual_dunder_methods == {("BubbleSortClass", "__init__"), ("BubbleSortClass", "__call__")}
expected2 = """class BubbleSortClass:
def __init__(self):
@@ -218,28 +202,13 @@ def non():
f.flush()
new_code, contextual_dunder_methods = get_code(
[
FunctionToOptimize(
"sorter",
f.name,
[FunctionParent("BubbleSortClass", "ClassDef")],
),
FunctionToOptimize(
"helper",
f.name,
[FunctionParent("BubbleSortClass", "ClassDef")],
),
FunctionToOptimize(
"unsorter",
f.name,
[FunctionParent("BubbleSortClass", "ClassDef")],
),
],
FunctionToOptimize("sorter", f.name, [FunctionParent("BubbleSortClass", "ClassDef")]),
FunctionToOptimize("helper", f.name, [FunctionParent("BubbleSortClass", "ClassDef")]),
FunctionToOptimize("unsorter", f.name, [FunctionParent("BubbleSortClass", "ClassDef")]),
]
)
assert new_code == expected2
assert contextual_dunder_methods == {
("BubbleSortClass", "__init__"),
("BubbleSortClass", "__call__"),
}
assert contextual_dunder_methods == {("BubbleSortClass", "__init__"), ("BubbleSortClass", "__call__")}
def test_get_code_multiline_class_def() -> None:
@@ -275,8 +244,8 @@ def test_get_code_multiline_class_def() -> None:
"computeStatement",
f.name,
[FunctionParent("StatementAssignmentVariableConstantMutable", "ClassDef")],
),
],
)
]
)
assert new_code == expected
assert contextual_dunder_methods == set()
@@ -296,13 +265,7 @@ class CustomDataClass:
# single FunctionToOptimize instance, in the case where that instance has been filtered to represent a function
# (with a definition).
new_code, contextual_dunder_methods = get_code(
[
FunctionToOptimize(
"name",
f.name,
[FunctionParent("CustomDataClass", "ClassDef")],
),
],
[FunctionToOptimize("name", f.name, [FunctionParent("CustomDataClass", "ClassDef")])]
)
assert new_code is None
assert contextual_dunder_methods == set()

View file

@@ -3,9 +3,10 @@ from argparse import Namespace
from pathlib import Path
import pytest
from returns.pipeline import is_successful
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
from codeflash.optimization.optimizer import Optimizer
from returns.pipeline import is_successful
class HelperClass:
@@ -28,22 +29,14 @@ def test_get_outside_method_helper() -> None:
test_framework="pytest",
pytest_cmd="pytest",
experiment_id=None,
),
)
)
function_to_optimize = FunctionToOptimize(
function_name="OptimizeMe",
file_path=file_path,
parents=[],
starting_line=None,
ending_line=None,
function_name="OptimizeMe", file_path=file_path, parents=[], starting_line=None, ending_line=None
)
with open(file_path) as f:
original_code = f.read()
ctx_result = opt.get_code_optimization_context(
function_to_optimize,
opt.args.project_root,
original_code,
)
ctx_result = opt.get_code_optimization_context(function_to_optimize, opt.args.project_root, original_code)
if not is_successful(ctx_result):
pytest.fail()
code_context = ctx_result.unwrap()
@@ -229,7 +222,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
pytest_cmd="pytest",
experiment_id=None,
test_project_root=Path().resolve(),
),
)
)
function_to_optimize = FunctionToOptimize(
function_name="__call__",
@@ -240,11 +233,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
)
with open(file_path) as f:
original_code = f.read()
ctx_result = opt.get_code_optimization_context(
function_to_optimize,
opt.args.project_root,
original_code,
)
ctx_result = opt.get_code_optimization_context(function_to_optimize, opt.args.project_root, original_code)
if not is_successful(ctx_result):
pytest.fail()
code_context = ctx_result.unwrap()
@@ -360,22 +349,14 @@ def test_bubble_sort_deps() -> None:
pytest_cmd="pytest",
experiment_id=None,
test_project_root=file_path.parent.resolve(),
),
)
)
function_to_optimize = FunctionToOptimize(
function_name="sorter_deps",
file_path=file_path,
parents=[],
starting_line=None,
ending_line=None,
function_name="sorter_deps", file_path=file_path, parents=[], starting_line=None, ending_line=None
)
with open(file_path) as f:
original_code = f.read()
ctx_result = opt.get_code_optimization_context(
function_to_optimize,
opt.args.project_root,
original_code,
)
ctx_result = opt.get_code_optimization_context(function_to_optimize, opt.args.project_root, original_code)
if not is_successful(ctx_result):
pytest.fail()
code_context = ctx_result.unwrap()
@@ -402,7 +383,4 @@ def sorter_deps(arr):
code_context.helper_functions[0].fully_qualified_name
== "code_to_optimize.bubble_sort_dep1_helper.dep1_comparer"
)
assert (
code_context.helper_functions[1].fully_qualified_name
== "code_to_optimize.bubble_sort_dep2_swap.dep2_swap"
)
assert code_context.helper_functions[1].fully_qualified_name == "code_to_optimize.bubble_sort_dep2_swap.dep2_swap"

View file

@@ -2,11 +2,8 @@ import unittest
from unittest.mock import patch
import git
from codeflash.code_utils.git_utils import (
check_and_push_branch,
check_running_in_git_repo,
get_repo_owner_and_name,
)
from codeflash.code_utils.git_utils import check_and_push_branch, check_running_in_git_repo, get_repo_owner_and_name
class TestGitUtils(unittest.TestCase):
@@ -44,12 +41,7 @@ class TestGitUtils(unittest.TestCase):
@patch("codeflash.code_utils.git_utils.git.Repo")
@patch("codeflash.code_utils.git_utils.sys.__stdin__.isatty", return_value=True)
@patch("codeflash.code_utils.git_utils.confirm_proceeding_with_no_git_repo", return_value=True)
def test_check_running_in_git_repo_not_in_git_repo_interactive(
self,
mock_confirm,
mock_isatty,
mock_repo,
):
def test_check_running_in_git_repo_not_in_git_repo_interactive(self, mock_confirm, mock_isatty, mock_repo):
mock_repo.side_effect = git.InvalidGitRepositoryError # type: ignore
assert check_running_in_git_repo("/path/to/non-repo") == False

View file

@@ -102,11 +102,7 @@ class TestPigLatin(unittest.TestCase):
with tempfile.NamedTemporaryFile(mode="w") as f:
f.write(code)
f.flush()
func = FunctionToOptimize(
function_name="sorter",
parents=[],
file_path=Path("module.py"),
)
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=Path("module.py"))
original_cwd = Path.cwd()
run_cwd = Path(__file__).parent.parent.resolve()
os.chdir(run_cwd)
@@ -120,8 +116,7 @@ class TestPigLatin(unittest.TestCase):
os.chdir(original_cwd)
assert success
assert new_test == expected.format(
module_path=Path(f.name).name,
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values"))
)
@@ -204,26 +199,17 @@ def test_prepare_image_for_yolo():
with tempfile.NamedTemporaryFile(mode="w") as f:
f.write(code)
f.flush()
func = FunctionToOptimize(
function_name="prepare_image_for_yolo",
parents=[],
file_path=Path("module.py"),
)
func = FunctionToOptimize(function_name="prepare_image_for_yolo", parents=[], file_path=Path("module.py"))
original_cwd = Path.cwd()
run_cwd = Path(__file__).parent.parent.resolve()
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
Path(f.name),
[CodePosition(10, 14)],
func,
Path(f.name).parent,
"pytest",
Path(f.name), [CodePosition(10, 14)], func, Path(f.name).parent, "pytest"
)
os.chdir(original_cwd)
assert success
assert new_test == expected.format(
module_path=Path(f.name).name,
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values"))
)
@@ -301,18 +287,10 @@ def test_sort():
project_root_path = (Path(__file__).parent / "..").resolve()
original_cwd = Path.cwd()
run_cwd = Path(__file__).parent.parent.resolve()
func = FunctionToOptimize(
function_name="sorter",
parents=[],
file_path=Path("module.py"),
)
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=Path("module.py"))
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
test_path,
[CodePosition(6, 13), CodePosition(10, 13)],
func,
project_root_path,
"pytest",
test_path, [CodePosition(6, 13), CodePosition(10, 13)], func, project_root_path, "pytest"
)
os.chdir(original_cwd)
assert success
@@ -335,7 +313,7 @@ def test_sort():
pytest_cmd="pytest",
experiment_id=None,
test_project_root=project_root_path,
),
)
)
test_env = os.environ.copy()
@@ -343,13 +321,7 @@ def test_sort():
test_env["CODEFLASH_LOOP_INDEX"] = "1"
test_type = TestType.EXISTING_UNIT_TEST
test_files = TestFiles(
test_files=[
TestFile(
instrumented_file_path=test_path,
test_type=test_type,
original_file_path=test_path,
),
],
test_files=[TestFile(instrumented_file_path=test_path, test_type=test_type, original_file_path=test_path)]
)
test_results = opt.run_and_parse_tests(
test_env=test_env,
@@ -461,18 +433,10 @@ def test_sort_parametrized(input, expected_output):
original_cwd = Path.cwd()
run_cwd = Path(__file__).parent.parent.resolve()
func = FunctionToOptimize(
function_name="sorter",
parents=[],
file_path=Path("module.py"),
)
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=Path("module.py"))
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
test_path,
[CodePosition(14, 13)],
func,
project_root_path,
"pytest",
test_path, [CodePosition(14, 13)], func, project_root_path, "pytest"
)
os.chdir(original_cwd)
assert success
@@ -490,13 +454,7 @@ def test_sort_parametrized(input, expected_output):
test_env["CODEFLASH_TEST_ITERATION"] = "0"
test_type = TestType.EXISTING_UNIT_TEST
test_files = TestFiles(
test_files=[
TestFile(
instrumented_file_path=test_path,
test_type=test_type,
original_file_path=test_path,
),
],
test_files=[TestFile(instrumented_file_path=test_path, test_type=test_type, original_file_path=test_path)]
)
opt = Optimizer(
@@ -508,7 +466,7 @@ def test_sort_parametrized(input, expected_output):
pytest_cmd="pytest",
experiment_id=None,
test_project_root=project_root_path,
),
)
)
test_results = opt.run_and_parse_tests(
test_env=test_env,
@@ -635,18 +593,10 @@ def test_sort_parametrized_loop(input, expected_output):
original_cwd = Path.cwd()
run_cwd = Path(__file__).parent.parent.resolve()
func = FunctionToOptimize(
function_name="sorter",
parents=[],
file_path=Path("module.py"),
)
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=Path("module.py"))
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
test_path,
[CodePosition(15, 17)],
func,
project_root_path,
"pytest",
test_path, [CodePosition(15, 17)], func, project_root_path, "pytest"
)
os.chdir(original_cwd)
assert success
@@ -664,13 +614,7 @@ def test_sort_parametrized_loop(input, expected_output):
test_env["CODEFLASH_TEST_ITERATION"] = "0"
test_type = TestType.EXISTING_UNIT_TEST
test_files = TestFiles(
test_files=[
TestFile(
instrumented_file_path=test_path,
test_type=test_type,
original_file_path=test_path,
),
],
test_files=[TestFile(instrumented_file_path=test_path, test_type=test_type, original_file_path=test_path)]
)
opt = Optimizer(
Namespace(
@@ -681,7 +625,7 @@ def test_sort_parametrized_loop(input, expected_output):
pytest_cmd="pytest",
experiment_id=None,
test_project_root=project_root_path,
),
)
)
test_results = opt.run_and_parse_tests(
test_env=test_env,
@@ -838,18 +782,10 @@ def test_sort():
run_cwd = Path(__file__).parent.parent.resolve()
original_cwd = Path.cwd()
func = FunctionToOptimize(
function_name="sorter",
parents=[],
file_path=Path("module.py"),
)
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=Path("module.py"))
os.chdir(str(run_cwd))
success, new_test = inject_profiling_into_existing_test(
test_path,
[CodePosition(11, 17)],
func,
project_root_path,
"pytest",
test_path, [CodePosition(11, 17)], func, project_root_path, "pytest"
)
os.chdir(original_cwd)
assert success
@@ -867,13 +803,7 @@ def test_sort():
test_env["CODEFLASH_TEST_ITERATION"] = "0"
test_type = TestType.EXISTING_UNIT_TEST
test_files = TestFiles(
test_files=[
TestFile(
instrumented_file_path=test_path,
test_type=test_type,
original_file_path=test_path,
),
],
test_files=[TestFile(instrumented_file_path=test_path, test_type=test_type, original_file_path=test_path)]
)
opt = Optimizer(
@@ -885,7 +815,7 @@ def test_sort():
pytest_cmd="pytest",
experiment_id=None,
test_project_root=project_root_path,
),
)
)
test_results = opt.run_and_parse_tests(
test_env=test_env,
@@ -1023,11 +953,7 @@ class TestPigLatin(unittest.TestCase):
run_cwd = Path(__file__).parent.parent.resolve()
original_cwd = Path.cwd()
func = FunctionToOptimize(
function_name="sorter",
parents=[],
file_path=Path("module.py"),
)
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=Path("module.py"))
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
test_path,
@@ -1054,13 +980,7 @@ class TestPigLatin(unittest.TestCase):
test_env["CODEFLASH_LOOP_INDEX"] = "1"
test_type = TestType.EXISTING_UNIT_TEST
test_files = TestFiles(
test_files=[
TestFile(
instrumented_file_path=test_path,
test_type=test_type,
original_file_path=test_path,
),
],
test_files=[TestFile(instrumented_file_path=test_path, test_type=test_type, original_file_path=test_path)]
)
opt = Optimizer(
Namespace(
@@ -1071,7 +991,7 @@ class TestPigLatin(unittest.TestCase):
pytest_cmd="pytest",
experiment_id=None,
test_project_root=project_root_path,
),
)
)
test_results = opt.run_and_parse_tests(
test_env=test_env,
@@ -1202,18 +1122,10 @@ class TestPigLatin(unittest.TestCase):
run_cwd = Path(__file__).parent.parent.resolve()
original_cwd = Path.cwd()
func = FunctionToOptimize(
function_name="sorter",
parents=[],
file_path=Path("module.py"),
)
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=Path("module.py"))
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
test_path,
[CodePosition(16, 17)],
func,
project_root_path,
"unittest",
test_path, [CodePosition(16, 17)], func, project_root_path, "unittest"
)
os.chdir(original_cwd)
assert success
@@ -1231,13 +1143,7 @@ class TestPigLatin(unittest.TestCase):
test_env["CODEFLASH_LOOP_INDEX"] = "1"
test_type = TestType.EXISTING_UNIT_TEST
test_files = TestFiles(
test_files=[
TestFile(
instrumented_file_path=test_path,
test_type=test_type,
original_file_path=test_path,
),
],
test_files=[TestFile(instrumented_file_path=test_path, test_type=test_type, original_file_path=test_path)]
)
opt = Optimizer(
Namespace(
@@ -1248,7 +1154,7 @@ class TestPigLatin(unittest.TestCase):
pytest_cmd="pytest",
experiment_id=None,
test_project_root=project_root_path,
),
)
)
test_results = opt.run_and_parse_tests(
test_env=test_env,
@@ -1384,11 +1290,7 @@ class TestPigLatin(unittest.TestCase):
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=Path("module.py"))
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
test_path,
[CodePosition(14, 21)],
func,
project_root_path,
"unittest",
test_path, [CodePosition(14, 21)], func, project_root_path, "unittest"
)
os.chdir(original_cwd)
assert success
@@ -1406,13 +1308,7 @@ class TestPigLatin(unittest.TestCase):
test_env["CODEFLASH_LOOP_INDEX"] = "1"
test_type = TestType.EXISTING_UNIT_TEST
test_files = TestFiles(
test_files=[
TestFile(
instrumented_file_path=test_path,
test_type=test_type,
original_file_path=test_path,
),
],
test_files=[TestFile(instrumented_file_path=test_path, test_type=test_type, original_file_path=test_path)]
)
opt = Optimizer(
@@ -1424,7 +1320,7 @@ class TestPigLatin(unittest.TestCase):
pytest_cmd="pytest",
experiment_id=None,
test_project_root=project_root_path,
),
)
)
test_results = opt.run_and_parse_tests(
test_env=test_env,
@@ -1557,18 +1453,10 @@ class TestPigLatin(unittest.TestCase):
run_cwd = Path(__file__).parent.parent.resolve()
original_cwd = Path.cwd()
f = FunctionToOptimize(
function_name="sorter",
file_path=Path("module.py"),
parents=[],
)
f = FunctionToOptimize(function_name="sorter", file_path=Path("module.py"), parents=[])
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
test_path,
[CodePosition(17, 21)],
f,
project_root_path,
"unittest",
test_path, [CodePosition(17, 21)], f, project_root_path, "unittest"
)
os.chdir(original_cwd)
assert success
@@ -1587,13 +1475,7 @@ class TestPigLatin(unittest.TestCase):
test_env["CODEFLASH_LOOP_INDEX"] = "1"
test_type = TestType.EXISTING_UNIT_TEST
test_files = TestFiles(
test_files=[
TestFile(
instrumented_file_path=test_path,
test_type=test_type,
original_file_path=test_path,
),
],
test_files=[TestFile(instrumented_file_path=test_path, test_type=test_type, original_file_path=test_path)]
)
opt = Optimizer(
Namespace(
@@ -1604,7 +1486,7 @@ class TestPigLatin(unittest.TestCase):
pytest_cmd="pytest",
experiment_id=None,
test_project_root=project_root_path,
),
)
)
test_results = opt.run_and_parse_tests(
test_env=test_env,
@@ -1691,21 +1573,13 @@ from module import functionB as function_B
import class_name_B
from nuitka.nodes.ImportNodes import ExpressionBuiltinImport as nuitka_nodes_ImportNodes_ExpressionBuiltinImport
"""
f = FunctionToOptimize(
function_name="functionA",
file_path=Path("module.py"),
parents=[],
)
f = FunctionToOptimize(function_name="functionA", file_path=Path("module.py"), parents=[])
tree = ast.parse(code)
visitor = FunctionImportedAsVisitor(f)
visitor.visit(tree)
assert visitor.imported_as.function_name == "functionA"
f = FunctionToOptimize(
function_name="functionB",
file_path=Path("module.py"),
parents=[],
)
f = FunctionToOptimize(function_name="functionB", file_path=Path("module.py"), parents=[])
visitor = FunctionImportedAsVisitor(f)
visitor.visit(tree)
assert visitor.imported_as.function_name == "function_B"
@@ -1717,15 +1591,9 @@ from nuitka.nodes.ImportNodes import ExpressionBuiltinImport as nuitka_nodes_Imp
)
visitor = FunctionImportedAsVisitor(f)
visitor.visit(tree)
assert (
visitor.imported_as.qualified_name == "nuitka_nodes_ImportNodes_ExpressionBuiltinImport.method_name"
)
assert visitor.imported_as.qualified_name == "nuitka_nodes_ImportNodes_ExpressionBuiltinImport.method_name"
f = FunctionToOptimize(
function_name="class_name_B",
file_path=Path("module.py"),
parents=[],
)
f = FunctionToOptimize(function_name="class_name_B", file_path=Path("module.py"), parents=[])
visitor = FunctionImportedAsVisitor(f)
visitor.visit(tree)
assert visitor.imported_as.qualified_name == "class_name_B"
@@ -1782,8 +1650,7 @@ def test_class_name_A_function_name():
"""
test_path = (
Path(__file__).parent.resolve()
/ "../code_to_optimize/tests/pytest/test_class_function_instrumentation_temp.py"
Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_class_function_instrumentation_temp.py"
)
try:
with open(test_path, "w") as f:
@@ -1799,11 +1666,7 @@ def test_class_name_A_function_name():
)
os.chdir(str(run_cwd))
success, new_test = inject_profiling_into_existing_test(
test_path,
[CodePosition(4, 23)],
func,
project_root_path,
"pytest",
test_path, [CodePosition(4, 23)], func, project_root_path, "pytest"
)
os.chdir(original_cwd)
finally:
@@ -1878,8 +1741,7 @@ def test_common_tags_1():
"""
test_path = (
Path(__file__).parent.resolve()
/ "../code_to_optimize/tests/pytest/test_wrong_function_instrumentation_temp.py"
Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_wrong_function_instrumentation_temp.py"
)
try:
with test_path.open("w") as f:
@@ -1890,18 +1752,12 @@ def test_common_tags_1():
run_cwd = Path(__file__).parent.parent.resolve()
original_cwd = Path.cwd()
func = FunctionToOptimize(
function_name="find_common_tags",
file_path=project_root_path / "module.py",
parents=[],
function_name="find_common_tags", file_path=project_root_path / "module.py", parents=[]
)
os.chdir(str(run_cwd))
success, new_test = inject_profiling_into_existing_test(
test_path,
[CodePosition(7, 11), CodePosition(11, 11)],
func,
project_root_path,
"pytest",
test_path, [CodePosition(7, 11), CodePosition(11, 11)], func, project_root_path, "pytest"
)
os.chdir(original_cwd)
assert success
@@ -1969,8 +1825,7 @@ def test_sort():
codeflash_con.close()
"""
test_path = (
Path(__file__).parent.resolve()
/ "../code_to_optimize/tests/pytest/test_conditional_instrumentation_temp.py"
Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_conditional_instrumentation_temp.py"
)
try:
with open(test_path, "w") as f:
@@ -1980,19 +1835,11 @@ def test_sort():
project_root_path = Path(__file__).parent.resolve() / "../code_to_optimize/"
run_cwd = Path(__file__).parent.parent.resolve()
original_cwd = Path.cwd()
func = FunctionToOptimize(
function_name="sorter",
file_path=project_root_path / "module.py",
parents=[],
)
func = FunctionToOptimize(function_name="sorter", file_path=project_root_path / "module.py", parents=[])
os.chdir(str(run_cwd))
success, new_test = inject_profiling_into_existing_test(
test_path,
[CodePosition(7, 15)],
func,
project_root_path,
"pytest",
test_path, [CodePosition(7, 15)], func, project_root_path, "pytest"
)
os.chdir(original_cwd)
assert success
@@ -2089,11 +1936,7 @@ def test_sort():
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
test_path,
[CodePosition(6, 26), CodePosition(10, 26)],
function_to_optimize,
project_root_path,
"pytest",
test_path, [CodePosition(6, 26), CodePosition(10, 26)], function_to_optimize, project_root_path, "pytest"
)
os.chdir(original_cwd)
assert success
@@ -2213,17 +2056,12 @@ def test_code_replacement10() -> None:
run_cwd = Path(__file__).parent.parent.resolve()
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
Path(f.name),
[CodePosition(22, 28), CodePosition(28, 28)],
func,
Path(f.name).parent,
"pytest",
Path(f.name), [CodePosition(22, 28), CodePosition(28, 28)], func, Path(f.name).parent, "pytest"
)
os.chdir(original_cwd)
assert success
assert new_test == expected.format(
module_path=Path(f.name).name,
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values"))
)
@@ -2299,18 +2137,10 @@ def test_sleepfunc_sequence_short(n, expected_total_sleep_time):
project_root_path = (Path(__file__).parent.resolve() / "../").resolve()
original_cwd = Path.cwd()
run_cwd = Path(__file__).parent.parent.resolve()
func = FunctionToOptimize(
function_name="accurate_sleepfunc",
parents=[],
file_path=Path("module.py"),
)
func = FunctionToOptimize(function_name="accurate_sleepfunc", parents=[], file_path=Path("module.py"))
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
test_path,
[CodePosition(8, 13)],
func,
project_root_path,
"pytest",
test_path, [CodePosition(8, 13)], func, project_root_path, "pytest"
)
os.chdir(original_cwd)
@@ -2337,16 +2167,10 @@ def test_sleepfunc_sequence_short(n, expected_total_sleep_time):
pytest_cmd="pytest",
experiment_id=None,
test_project_root=project_root_path,
),
)
)
test_files = TestFiles(
test_files=[
TestFile(
instrumented_file_path=test_path,
test_type=test_type,
original_file_path=test_path,
),
],
test_files=[TestFile(instrumented_file_path=test_path, test_type=test_type, original_file_path=test_path)]
)
test_results = opt.run_and_parse_tests(
test_env=test_env,
@@ -2454,18 +2278,10 @@ class TestPigLatin(unittest.TestCase):
project_root_path = (Path(__file__).parent.resolve() / "../").resolve()
original_cwd = Path.cwd()
run_cwd = Path(__file__).parent.parent.resolve()
func = FunctionToOptimize(
function_name="accurate_sleepfunc",
parents=[],
file_path=Path("module.py"),
)
func = FunctionToOptimize(function_name="accurate_sleepfunc", parents=[], file_path=Path("module.py"))
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
test_path,
[CodePosition(12, 17)],
func,
project_root_path,
"unittest",
test_path, [CodePosition(12, 17)], func, project_root_path, "unittest"
)
os.chdir(original_cwd)
@@ -2492,23 +2308,13 @@ class TestPigLatin(unittest.TestCase):
pytest_cmd="pytest",
experiment_id=None,
test_project_root=project_root_path,
),
)
)
test_files = TestFiles(
test_files=[
TestFile(
instrumented_file_path=test_path,
test_type=test_type,
original_file_path=test_path,
),
],
test_files=[TestFile(instrumented_file_path=test_path, test_type=test_type, original_file_path=test_path)]
)
test_results = opt.run_and_parse_tests(
test_env=test_env,
test_files=test_files,
optimization_iteration=0,
test_functions=None,
testing_time=0.1,
test_env=test_env, test_files=test_files, optimization_iteration=0, test_functions=None, testing_time=0.1
)
assert test_results[0].id.function_getting_tested == "accurate_sleepfunc"

View file

@@ -1,10 +1,5 @@
from codeflash.verification.parse_test_output import merge_test_results
from codeflash.verification.test_results import (
FunctionTestInvocation,
InvocationId,
TestResults,
TestType,
)
from codeflash.verification.test_results import FunctionTestInvocation, InvocationId, TestResults, TestType
def test_merge_test_results_1():
@@ -61,7 +56,7 @@ def test_merge_test_results_1():
timed_out=False,
loop_index=1,
),
],
]
)
test_results_bin = TestResults(
@@ -117,7 +112,7 @@ def test_merge_test_results_1():
timed_out=False,
loop_index=1,
),
],
]
)
expected_merged_results = TestResults(
@@ -173,12 +168,10 @@ def test_merge_test_results_1():
timed_out=False,
loop_index=1,
),
],
]
)
merged_results = merge_test_results(
xml_test_results=test_results_xml,
bin_test_results=test_results_bin,
test_framework="unittest",
xml_test_results=test_results_xml, bin_test_results=test_results_bin, test_framework="unittest"
)
assert merged_results == expected_merged_results
@@ -200,30 +193,24 @@ def test_merge_test_results_1():
return_value=None,
timed_out=False,
loop_index=1,
),
],
)
]
)
merged_results = merge_test_results(
xml_test_results=test_results_xml_single,
bin_test_results=test_results_bin,
test_framework="unittest",
xml_test_results=test_results_xml_single, bin_test_results=test_results_bin, test_framework="unittest"
)
assert merged_results == expected_merged_results
merged_results = merge_test_results(
xml_test_results=test_results_xml_single,
bin_test_results=TestResults(),
test_framework="unittest",
xml_test_results=test_results_xml_single, bin_test_results=TestResults(), test_framework="unittest"
)
assert merged_results == test_results_xml_single
merged_results = merge_test_results(
xml_test_results=TestResults(),
bin_test_results=test_results_bin,
test_framework="unittest",
xml_test_results=TestResults(), bin_test_results=test_results_bin, test_framework="unittest"
)
assert merged_results == TestResults() # XML Results should always have better coverage than bin results
@@ -246,8 +233,8 @@ def test_merge_test_results_1():
return_value=None,
timed_out=False,
loop_index=1,
),
],
)
]
)
test_results_bin_pytest = TestResults(
@@ -286,13 +273,11 @@ def test_merge_test_results_1():
timed_out=False,
loop_index=1,
),
],
]
)
merged_results = merge_test_results(
xml_test_results=test_results_xml_pytest,
bin_test_results=test_results_bin_pytest,
test_framework="unittest",
xml_test_results=test_results_xml_pytest, bin_test_results=test_results_bin_pytest, test_framework="unittest"
)
assert merged_results == test_results_bin_pytest

View file

@@ -1,8 +1,6 @@
import pytest
from codeflash.code_utils.remove_generated_tests import (
remove_functions_from_generated_tests,
)
from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests
from codeflash.models.models import GeneratedTests, GeneratedTestsList
@@ -21,10 +19,7 @@ def test_sorted_list():
# Test sorting an already sorted list
codeflash_output = sorter([1, 2, 3, 4, 5])
# Outputs were verified to be equal to the original implementation"""
generated_tests = GeneratedTests(
generated_original_test_source=generated_test_source,
instrumented_test_source="",
)
generated_tests = GeneratedTests(generated_original_test_source=generated_test_source, instrumented_test_source="")
generated_tests_list = GeneratedTestsList(generated_tests=[generated_tests])
functions_to_remove = ["test_single_element"]
@@ -59,10 +54,7 @@ def test_sorted_list():
# Test sorting an already sorted list
codeflash_output = sorter([1, 2, 3, 4, 5])
# Outputs were verified to be equal to the original implementation"""
generated_tests = GeneratedTests(
generated_original_test_source=generated_test_source,
instrumented_test_source="",
)
generated_tests = GeneratedTests(generated_original_test_source=generated_test_source, instrumented_test_source="")
generated_tests_list_1 = GeneratedTestsList(generated_tests=[generated_tests])
functions_to_remove = ["test_single_element", "test_sorted_list"]
@@ -85,8 +77,7 @@ def test_sorted_list():
# Outputs were verified to be equal to the original implementation"""
generated_tests_2 = GeneratedTests(
generated_original_test_source=generated_test_source,
instrumented_test_source="",
generated_original_test_source=generated_test_source, instrumented_test_source=""
)
generated_tests_list_2 = GeneratedTestsList(generated_tests=[generated_tests_2])
@@ -133,10 +124,7 @@ def test_list_with_mixed_orderable_and_non_orderable_types():
sorter([True, 1, "string", [1, 2]])
# Outputs were verified to be equal to the original implementation"""
generated_tests = GeneratedTests(
generated_original_test_source=generated_test_source,
instrumented_test_source="",
)
generated_tests = GeneratedTests(generated_original_test_source=generated_test_source, instrumented_test_source="")
generated_tests_list = GeneratedTestsList(generated_tests=[generated_tests])
functions_to_remove = ["test_list_with_custom_objects"]
@@ -189,10 +177,7 @@ def test_sorted_list():
# Test sorting an already sorted list
codeflash_output = sorter([1, 2, 3, 4, 5])
# Outputs were verified to be equal to the original implementation"""
generated_tests = GeneratedTests(
generated_original_test_source=generated_test_source,
instrumented_test_source="",
)
generated_tests = GeneratedTests(generated_original_test_source=generated_test_source, instrumented_test_source="")
generated_tests_list = GeneratedTestsList(generated_tests=[generated_tests])
functions_to_remove = ["test_empty_list", "test_sort_parametrized"]
@@ -223,9 +208,7 @@ def test_sorted_list():
assert generated_tests_list.generated_tests[0].generated_original_test_source == expected
@pytest.mark.skip(
"We don't handle the edge case where the parametrized test appears right after the test to remove",
)
@pytest.mark.skip("We don't handle the edge case where the parametrized test appears right after the test to remove")
def test_keep_parametrized_test2():
generated_test_source = """def test_empty_list():
# Test sorting an empty list
@@ -253,10 +236,7 @@ def test_sorted_list():
# Test sorting an already sorted list
codeflash_output = sorter([1, 2, 3, 4, 5])
# Outputs were verified to be equal to the original implementation"""
generated_tests = GeneratedTests(
generated_original_test_source=generated_test_source,
instrumented_test_source="",
)
generated_tests = GeneratedTests(generated_original_test_source=generated_test_source, instrumented_test_source="")
generated_tests_list = GeneratedTestsList(generated_tests=[generated_tests])
functions_to_remove = ["test_empty_list", "test_sort_parametrized"]

View file

@@ -9,11 +9,7 @@ from codeflash.code_utils.shell_utils import read_api_key_from_shell_config, sav
class TestShellUtils(unittest.TestCase):
@patch(
"codeflash.code_utils.shell_utils.open",
new_callable=mock_open,
read_data="existing content",
)
@patch("codeflash.code_utils.shell_utils.open", new_callable=mock_open, read_data="existing content")
@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")
def test_save_api_key_to_rc_success(self, mock_get_shell_rc_path, mock_file):
mock_get_shell_rc_path.return_value = "/fake/path/.bashrc"
@@ -25,11 +21,7 @@ class TestShellUtils(unittest.TestCase):
handle.write.assert_called_once()
handle.truncate.assert_called_once()
@patch(
"codeflash.code_utils.shell_utils.open",
new_callable=mock_open,
read_data="existing content",
)
@patch("codeflash.code_utils.shell_utils.open", new_callable=mock_open, read_data="existing content")
@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")
def test_save_api_key_to_rc_failure(self, mock_get_shell_rc_path, mock_file):
mock_get_shell_rc_path.return_value = "/fake/path/.bashrc"
@@ -59,8 +51,7 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
with patch("codeflash.code_utils.shell_utils.get_shell_rc_path") as mock_get_shell_rc_path:
mock_get_shell_rc_path.return_value = self.test_rc_path
with patch(
"builtins.open",
mock_open(read_data=f'export CODEFLASH_API_KEY="{self.api_key}"\n'),
"builtins.open", mock_open(read_data=f'export CODEFLASH_API_KEY="{self.api_key}"\n')
) as mock_file:
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
mock_file.assert_called_once_with(self.test_rc_path, encoding="utf8")
@@ -83,10 +74,7 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
with patch("builtins.open", mock_open(read_data=f"CODEFLASH_API_KEY={self.api_key}\n")):
result = read_api_key_from_shell_config()
self.assertIsNone(result)
with patch(
"builtins.open",
mock_open(read_data=f"export CODEFLASH_API_KEY=sk-{self.api_key}\n"),
):
with patch("builtins.open", mock_open(read_data=f"export CODEFLASH_API_KEY=sk-{self.api_key}\n")):
result = read_api_key_from_shell_config()
self.assertIsNone(result)
@@ -96,9 +84,7 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
mock_get_shell_rc_path.return_value = self.test_rc_path
with patch(
"builtins.open",
mock_open(
read_data=f'export CODEFLASH_API_KEY="cf-firstkey"\nexport CODEFLASH_API_KEY="{self.api_key}"\n',
),
mock_open(read_data=f'export CODEFLASH_API_KEY="cf-firstkey"\nexport CODEFLASH_API_KEY="{self.api_key}"\n'),
):
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
@@ -108,9 +94,7 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
mock_get_shell_rc_path.return_value = self.test_rc_path
with patch(
"builtins.open",
mock_open(
read_data=f'# Setting API Key\nexport CODEFLASH_API_KEY="{self.api_key}"\n# Done\n',
),
mock_open(read_data=f'# Setting API Key\nexport CODEFLASH_API_KEY="{self.api_key}"\n# Done\n'),
):
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
@@ -118,10 +102,7 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
def test_api_key_in_comment(self, mock_get_shell_rc_path):
"""Test with API key export in a comment."""
mock_get_shell_rc_path.return_value = self.test_rc_path
with patch(
"builtins.open",
mock_open(read_data=f'# export CODEFLASH_API_KEY="{self.api_key}"\n'),
):
with patch("builtins.open", mock_open(read_data=f'# export CODEFLASH_API_KEY="{self.api_key}"\n')):
self.assertIsNone(read_api_key_from_shell_config())
@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")

View file

@@ -36,16 +36,12 @@ class TestUnittestRunnerSorter(unittest.TestCase):
with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp:
test_files = TestFiles(
test_files=[
TestFile(instrumented_file_path=Path(fp.name), test_type=TestType.EXISTING_UNIT_TEST),
],
test_files=[TestFile(instrumented_file_path=Path(fp.name), test_type=TestType.EXISTING_UNIT_TEST)]
)
fp.write(code.encode("utf-8"))
fp.flush()
result_file, process = run_tests(
test_files,
test_framework=config.test_framework,
cwd=Path(config.project_root_path),
test_files, test_framework=config.test_framework, cwd=Path(config.project_root_path)
)
results = parse_test_xml(result_file, test_files, config, process)
assert results[0].did_pass, "Test did not pass as expected"
@@ -72,9 +68,7 @@ def test_sort():
)
with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp:
test_files = TestFiles(
test_files=[
TestFile(instrumented_file_path=Path(fp.name), test_type=TestType.EXISTING_UNIT_TEST),
],
test_files=[TestFile(instrumented_file_path=Path(fp.name), test_type=TestType.EXISTING_UNIT_TEST)]
)
fp.write(code.encode("utf-8"))
fp.flush()
@@ -90,10 +84,7 @@ def test_sort():
pytest_target_runtime_seconds=1,
)
results = parse_test_xml(
test_xml_file_path=result_file,
test_files=test_files,
test_config=config,
run_result=process,
test_xml_file_path=result_file, test_files=test_files, test_config=config, run_result=process
)
assert results[0].did_pass, "Test did not pass as expected"
result_file.unlink(missing_ok=True)

View file

@@ -6,9 +6,7 @@ from typing import List
from codeflash.code_utils.code_extractor import get_code
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.optimization.function_context import (
get_constrained_function_context_and_helper_functions,
)
from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions
class CustomType:
@@ -94,16 +92,11 @@ def test_function_context_works_for_composite_types() -> None:
def test_function_context_custom_datatype() -> None:
project_path = pathlib.Path(__file__).parent.parent.resolve() / "code_to_optimize"
file_path = project_path / "math_utils.py"
code, contextual_dunder_methods = get_code(
[FunctionToOptimize("cosine_similarity", str(file_path), [])],
)
code, contextual_dunder_methods = get_code([FunctionToOptimize("cosine_similarity", str(file_path), [])])
assert code is not None
assert contextual_dunder_methods == set()
a, helper_functions, dunder_methods = get_constrained_function_context_and_helper_functions(
FunctionToOptimize("cosine_similarity", str(file_path), []),
str(project_path),
code,
1000,
FunctionToOptimize("cosine_similarity", str(file_path), []), str(project_path), code, 1000
)
assert len(helper_functions) == 1

View file

@@ -75,10 +75,7 @@ def test_discover_tests_pytest_with_temp_dir_root():
assert {
discovered_tests["dummy_code.dummy_function"][0].tests_in_file.test_function,
discovered_tests["dummy_code.dummy_function"][1].tests_in_file.test_function,
} == {
"test_dummy_parametrized_function[True]",
"test_dummy_function",
}
} == {"test_dummy_parametrized_function[True]", "test_dummy_function"}
def test_discover_tests_pytest_with_multi_level_dirs():
@@ -147,8 +144,7 @@ def test_discover_tests_pytest_with_multi_level_dirs():
assert len(discovered_tests) == 3
assert discovered_tests["root_code.root_function"][0].tests_in_file.test_file == root_test_file_path
assert (
discovered_tests["level1.level1_code.level1_function"][0].tests_in_file.test_file
== level1_test_file_path
discovered_tests["level1.level1_code.level1_function"][0].tests_in_file.test_file == level1_test_file_path
)
assert (
@@ -238,8 +234,7 @@ def test_discover_tests_pytest_dirs():
assert len(discovered_tests) == 4
assert discovered_tests["root_code.root_function"][0].tests_in_file.test_file == root_test_file_path
assert (
discovered_tests["level1.level1_code.level1_function"][0].tests_in_file.test_file
== level1_test_file_path
discovered_tests["level1.level1_code.level1_function"][0].tests_in_file.test_file == level1_test_file_path
)
assert (
discovered_tests["level1.level2.level2_code.level2_function"][0].tests_in_file.test_file
@@ -283,10 +278,7 @@ def test_discover_tests_pytest_with_class():
# Check if the test class and method are discovered
assert len(discovered_tests) == 1
assert (
discovered_tests["some_class_code.SomeClass.some_method"][0].tests_in_file.test_file
== test_file_path
)
assert discovered_tests["some_class_code.SomeClass.some_method"][0].tests_in_file.test_file == test_file_path
def test_discover_tests_pytest_with_double_nested_directories():
@@ -414,9 +406,7 @@ def test_discover_tests_pytest_with_nested_class():
# Check if the test for the nested class method is discovered
assert len(discovered_tests) == 1
assert (
discovered_tests["nested_class_code.OuterClass.InnerClass.inner_method"][
0
].tests_in_file.test_file
discovered_tests["nested_class_code.OuterClass.InnerClass.inner_method"][0].tests_in_file.test_file
== test_file_path
)
@@ -455,6 +445,4 @@ def test_discover_tests_pytest_separate_moduledir():
# Check if the test for the nested class method is discovered
assert len(discovered_tests) == 1
assert (
discovered_tests["mypackage.code.find_common_tags"][0].tests_in_file.test_file == test_file_path
)
assert discovered_tests["mypackage.code.find_common_tags"][0].tests_in_file.test_file == test_file_path

View file

@@ -4,7 +4,7 @@ from typing import Any, Dict, Optional
from posthog import Posthog
_posthog: Posthog = Posthog(
project_api_key="phc_aUO790jHd7z1SXwsYCz8dRApxueplZlZWeDSpKc5hol", host="https://us.posthog.com",
project_api_key="phc_aUO790jHd7z1SXwsYCz8dRApxueplZlZWeDSpKc5hol", host="https://us.posthog.com"
)
@@ -20,8 +20,4 @@ def ph(user_id: str, event: str, properties: Optional[Dict[str, Any]] = None) ->
properties["environment"] = os.environ.get("ENVIRONMENT", default="None")
properties["openai_api_type"] = os.environ.get("OPENAI_API_TYPE", default="openai")
_posthog.capture(
distinct_id=user_id,
event=event,
properties=properties,
)
_posthog.capture(distinct_id=user_id, event=event, properties=properties)

View file

@@ -26,8 +26,7 @@ class FunctionToOptimize:
def top_level_parent_name(self) -> str:
if self.parents:
return self.parents[0].name
else:
return self.function_name
return self.function_name
def __str__(self) -> str:
return f"{self.file_path}:{'.'.join([p.name for p in self.parents]) + '.' if self.parents else ''}{self.function_name}"

View file

@@ -76,30 +76,17 @@ WSGI_APPLICATION: str = "aiservice.wsgi.application"
# Requires DATABASE_URL environment variable to be set
assert "DATABASE_URL" in os.environ, "DATABASE_URL environment variable not set"
DATABASES = {
"default": dj_database_url.config(
conn_max_age=600,
conn_health_checks=True,
),
}
DATABASES = {"default": dj_database_url.config(conn_max_age=600, conn_health_checks=True)}
# Password validation
# https://docs.djangoproject.com/en/5.0/ref/settings/#auth-password-validators
AUTH_PASSWORD_VALIDATORS: list[dict[str, str]] = [
{
"NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator",
},
{
"NAME": "django.contrib.auth.password_validation.MinimumLengthValidator",
},
{
"NAME": "django.contrib.auth.password_validation.CommonPasswordValidator",
},
{
"NAME": "django.contrib.auth.password_validation.NumericPasswordValidator",
},
{"NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator"},
{"NAME": "django.contrib.auth.password_validation.MinimumLengthValidator"},
{"NAME": "django.contrib.auth.password_validation.CommonPasswordValidator"},
{"NAME": "django.contrib.auth.password_validation.NumericPasswordValidator"},
]

View file

@@ -18,12 +18,12 @@ Including another URLconf
"""
# from django.contrib import admin
from django.urls import path
from log_features.log_features import features_api
from optimizer.optimizer import optimize_api
from testgen.testgen import testgen_api
from django.urls import path
urlpatterns = [
path("ai/optimize", optimize_api.urls),
path("ai/testgen", testgen_api.urls),

View file

@@ -13,7 +13,7 @@ class AuthBearer(HttpBearer):
num_users = await CFAPIKeys.objects.filter(key=hashed_token).aupdate(last_used=Now())
if num_users == 0:
raise HttpError(403, "Invalid API key")
elif num_users == 1:
if num_users == 1:
api_key_instance = await instance_for_api_key(hashed_token)
if not api_key_instance:
print(f"Instance not found for api key {token}. Returning 403")
@@ -21,10 +21,7 @@ class AuthBearer(HttpBearer):
request.user = api_key_instance.user_id
request.tier = api_key_instance.tier
return token
else:
print(
"THIS SHOULD NOT HAPPEN! More than one users found in the db with the same api key!",
)
raise HttpError(403, "Invalid API key")
print("THIS SHOULD NOT HAPPEN! More than one users found in the db with the same api key!")
raise HttpError(403, "Invalid API key")
except CFAPIKeys.DoesNotExist:
raise HttpError(403, "Invalid API key")

View file

@@ -24,10 +24,7 @@ class CFAPIKeys(models.Model):
suffix = models.CharField(max_length=4)
name = models.CharField(max_length=255)
created_at = models.DateTimeField(auto_now_add=True)
last_used = models.DateTimeField(
null=True,
blank=True,
)
last_used = models.DateTimeField(null=True, blank=True)
user_id = models.TextField(null=True, blank=True)
tier = models.TextField(null=True, blank=True)

View file

@@ -2,12 +2,12 @@ from __future__ import annotations
import datetime as dt
import logging
from asyncio import Lock, Semaphore
from asyncio import Semaphore
from typing import Dict, List, Optional
from authapp.auth import AuthBearer
from ninja import NinjaAPI, Schema
from authapp.auth import AuthBearer
from log_features.models import OptimizationFeatures
features_api = NinjaAPI(auth=AuthBearer(), urls_namespace="log_features")
@@ -79,36 +79,24 @@ async def log_features(
if generated_tests:
f.generated_test = (f.generated_test or []) + generated_tests
if instrumented_generated_tests:
f.instrumented_generated_test = (
f.instrumented_generated_test or []
) + instrumented_generated_tests
f.instrumented_generated_test = (f.instrumented_generated_test or []) + instrumented_generated_tests
# Update fields on the existing instance
f.user_id = user_id if user_id is not None else f.user_id
f.original_code = original_code if original_code is not None else f.original_code
f.optimizations_raw = (
optimizations_raw if optimizations_raw is not None else f.optimizations_raw
)
f.optimizations_post = (
optimizations_post if optimizations_post is not None else f.optimizations_post
)
f.optimizations_raw = optimizations_raw if optimizations_raw is not None else f.optimizations_raw
f.optimizations_post = optimizations_post if optimizations_post is not None else f.optimizations_post
f.explanations_raw = explanations_raw if explanations_raw is not None else f.explanations_raw
f.explanations_post = (
explanations_post if explanations_post is not None else f.explanations_post
)
f.explanations_post = explanations_post if explanations_post is not None else f.explanations_post
f.speedup_ratio = speedup_ratio if speedup_ratio is not None else f.speedup_ratio
f.original_runtime = original_runtime if original_runtime is not None else f.original_runtime
f.optimized_runtime = (
optimized_runtime if optimized_runtime is not None else f.optimized_runtime
)
f.optimized_runtime = optimized_runtime if optimized_runtime is not None else f.optimized_runtime
f.is_correct = is_correct if is_correct is not None else f.is_correct
f.generated_test = generated_tests if generated_tests is not None else f.generated_test
f.test_framework = test_framework if test_framework is not None else f.test_framework
f.created_at = datetime if datetime is not None else f.created_at
f.aiservice_commit_id = (
aiservice_commit if aiservice_commit is not None else f.aiservice_commit_id
)
f.aiservice_commit_id = aiservice_commit if aiservice_commit is not None else f.aiservice_commit_id
f.metadata = updated_metadata
f.experiment_metadata = updated_experiment_metadata
await f.asave()
@@ -154,13 +142,7 @@ class LoggingErrorResponseSchema(Schema):
error: str
@features_api.post(
"/",
response={
200: None,
500: LoggingErrorResponseSchema,
},
)
@features_api.post("/", response={200: None, 500: LoggingErrorResponseSchema})
async def log_features_cli(request, data: LoggingSchema):
try:
if request.tier is None:

View file

@@ -4,11 +4,7 @@ from django.db import models
class OptimizationFeatures(models.Model):
trace_id = models.CharField(
max_length=36,
primary_key=True,
validators=[MinLengthValidator(36)],
)
trace_id = models.CharField(max_length=36, primary_key=True, validators=[MinLengthValidator(36)])
original_code = models.TextField(null=True, blank=True)
user_id = models.TextField(null=True, blank=True)
optimizations_raw = models.JSONField(null=True, blank=True)

View file

@@ -20,7 +20,7 @@ def main():
raise ImportError(
"Couldn't import Django. Are you sure it's installed and "
"available on your PYTHONPATH environment variable? Did you "
"forget to activate a virtual environment?",
"forget to activate a virtual environment?"
) from exc
load_env()

File diff suppressed because one or more lines are too long

Some files were not shown because too many files have changed in this diff.