mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Fix helper class code replacement
refactor code
This commit is contained in:
parent
644ee6ba39
commit
91eef8337e
4 changed files with 574 additions and 99 deletions
|
|
@ -18,17 +18,17 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
|||
function_name: str,
|
||||
class_name: str | None,
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] | None = None,
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if preexisting_functions is None:
|
||||
preexisting_functions = []
|
||||
if preexisting_objects is None:
|
||||
preexisting_objects = []
|
||||
self.function_name = function_name
|
||||
self.class_name = class_name
|
||||
self.optim_body: FunctionDef | None = None
|
||||
self.optim_new_class_functions: list[cst.FunctionDef] = []
|
||||
self.optim_new_functions: list[cst.FunctionDef] = []
|
||||
self.preexisting_functions = preexisting_functions
|
||||
self.preexisting_objects = preexisting_objects
|
||||
self.contextual_functions = contextual_functions.union(
|
||||
{(self.class_name, self.function_name)},
|
||||
)
|
||||
|
|
@ -44,8 +44,8 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
|||
if node.name.value == self.function_name:
|
||||
self.optim_body = node
|
||||
elif (
|
||||
self.preexisting_functions
|
||||
and (node.name.value, []) not in self.preexisting_functions
|
||||
self.preexisting_objects
|
||||
and (node.name.value, []) not in self.preexisting_objects
|
||||
and (
|
||||
isinstance(parent, cst.Module)
|
||||
or (parent2 is not None and not isinstance(parent2, cst.ClassDef))
|
||||
|
|
@ -57,10 +57,10 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
|||
parents = [FunctionParent(name=node.name.value, type="ClassDef")]
|
||||
for child_node in node.body.body:
|
||||
if (
|
||||
self.preexisting_functions
|
||||
self.preexisting_objects
|
||||
and isinstance(child_node, cst.FunctionDef)
|
||||
and (node.name.value, child_node.name.value) not in self.contextual_functions
|
||||
and (child_node.name.value, parents) not in self.preexisting_functions
|
||||
and (child_node.name.value, parents) not in self.preexisting_objects
|
||||
):
|
||||
self.optim_new_class_functions.append(child_node)
|
||||
|
||||
|
|
@ -153,7 +153,7 @@ def replace_functions_in_file(
|
|||
source_code: str,
|
||||
original_function_names: list[str],
|
||||
optimized_code: str,
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]],
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]],
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
) -> str:
|
||||
parsed_function_names = []
|
||||
|
|
@ -173,11 +173,11 @@ def replace_functions_in_file(
|
|||
function_name,
|
||||
class_name,
|
||||
contextual_functions,
|
||||
preexisting_functions,
|
||||
preexisting_objects,
|
||||
)
|
||||
module.visit(visitor)
|
||||
|
||||
if visitor.optim_body is None and not preexisting_functions:
|
||||
if visitor.optim_body is None and not preexisting_objects:
|
||||
continue
|
||||
if visitor.optim_body is None:
|
||||
raise ValueError(f"Did not find the function {function_name} in the optimized code")
|
||||
|
|
@ -202,7 +202,7 @@ def replace_functions_and_add_imports(
|
|||
optimized_code: str,
|
||||
file_path_of_module_with_function_to_optimize: str,
|
||||
module_abspath: str,
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]],
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]],
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
project_root_path: str,
|
||||
) -> str:
|
||||
|
|
@ -212,7 +212,7 @@ def replace_functions_and_add_imports(
|
|||
source_code,
|
||||
function_names,
|
||||
optimized_code,
|
||||
preexisting_functions,
|
||||
preexisting_objects,
|
||||
contextual_functions,
|
||||
),
|
||||
file_path_of_module_with_function_to_optimize,
|
||||
|
|
@ -226,7 +226,7 @@ def replace_function_definitions_in_module(
|
|||
optimized_code: str,
|
||||
file_path_of_module_with_function_to_optimize: str,
|
||||
module_abspath: str,
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]],
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]],
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
project_root_path: str,
|
||||
) -> bool:
|
||||
|
|
@ -235,7 +235,7 @@ def replace_function_definitions_in_module(
|
|||
:param optimized_code:
|
||||
:param file_path_of_module_with_function_to_optimize:
|
||||
:param module_abspath:
|
||||
:param preexisting_functions:
|
||||
:param preexisting_objects:
|
||||
:param contextual_functions:
|
||||
:param project_root_path:
|
||||
:return:
|
||||
|
|
@ -249,7 +249,7 @@ def replace_function_definitions_in_module(
|
|||
optimized_code,
|
||||
file_path_of_module_with_function_to_optimize,
|
||||
module_abspath,
|
||||
preexisting_functions,
|
||||
preexisting_objects,
|
||||
contextual_functions,
|
||||
project_root_path,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ class CodeOptimizationContext(BaseModel):
|
|||
code_to_optimize_with_helpers: str
|
||||
contextual_dunder_methods: set[tuple[str, str]]
|
||||
helper_functions: list[FunctionSource]
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]]
|
||||
|
||||
|
||||
class OptimizedCandidateResult(BaseModel):
|
||||
|
|
|
|||
|
|
@ -194,16 +194,11 @@ class Optimizer:
|
|||
if not is_successful(ctx_result):
|
||||
return Failure(ctx_result.failure())
|
||||
code_context: CodeOptimizationContext = ctx_result.unwrap()
|
||||
helper_functions_by_module_abspath = defaultdict(set)
|
||||
for helper_function in code_context.helper_functions:
|
||||
helper_functions_by_module_abspath[helper_function.file_path].add(
|
||||
helper_function.qualified_name,
|
||||
)
|
||||
original_helper_code = {}
|
||||
for module_abspath in helper_functions_by_module_abspath:
|
||||
with pathlib.Path(module_abspath).open(encoding="utf8") as f:
|
||||
for helper_function in code_context.helper_functions:
|
||||
with pathlib.Path(helper_function.file_path).open(encoding="utf8") as f:
|
||||
helper_code = f.read()
|
||||
original_helper_code[module_abspath] = helper_code
|
||||
original_helper_code[helper_function.file_path] = helper_code
|
||||
logging.info(f"Code to be optimized:\n{code_context.code_to_optimize_with_helpers}")
|
||||
module_path = module_name_from_file_path(function_to_optimize.file_path, self.args.project_root)
|
||||
|
||||
|
|
@ -272,7 +267,6 @@ class Optimizer:
|
|||
best_optimization = self.determine_best_candidate(
|
||||
candidates,
|
||||
code_context,
|
||||
helper_functions_by_module_abspath,
|
||||
function_to_optimize,
|
||||
generated_tests_path,
|
||||
instrumented_unittests_created_for_function,
|
||||
|
|
@ -306,15 +300,14 @@ class Optimizer:
|
|||
)
|
||||
|
||||
self.replace_function_and_helpers_with_optimized_code(
|
||||
code_context,
|
||||
helper_functions_by_module_abspath,
|
||||
explanation,
|
||||
best_optimization.candidate.source_code,
|
||||
function_to_optimize.qualified_name,
|
||||
code_context=code_context,
|
||||
function_to_optimize_file_path=explanation.file_path,
|
||||
optimized_code=best_optimization.candidate.source_code,
|
||||
qualified_function_name=function_to_optimize.qualified_name,
|
||||
)
|
||||
|
||||
new_code, new_helper_code = self.reformat_code_and_helpers(
|
||||
helper_functions_by_module_abspath,
|
||||
code_context.helper_functions,
|
||||
explanation.file_path,
|
||||
original_code,
|
||||
)
|
||||
|
|
@ -346,7 +339,6 @@ class Optimizer:
|
|||
original_code,
|
||||
original_helper_code,
|
||||
function_to_optimize.file_path,
|
||||
helper_functions_by_module_abspath,
|
||||
)
|
||||
# Delete all the generated tests to not cause any clutter.
|
||||
pathlib.Path(generated_tests_path).unlink(missing_ok=True)
|
||||
|
|
@ -360,7 +352,6 @@ class Optimizer:
|
|||
self,
|
||||
candidates: list[OptimizedCandidate],
|
||||
code_context: CodeOptimizationContext,
|
||||
helper_functions_by_module_abspath: dict[str, set[str]],
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
generated_tests_path: str,
|
||||
instrumented_unittests_created_for_function: set[str],
|
||||
|
|
@ -394,28 +385,12 @@ class Optimizer:
|
|||
logging.info(f"Optimized candidate {j}/{len(candidates)}:")
|
||||
logging.info(candidate.source_code)
|
||||
try:
|
||||
did_update = replace_function_definitions_in_module(
|
||||
function_names=[function_to_optimize.qualified_name],
|
||||
did_update = self.replace_function_and_helpers_with_optimized_code(
|
||||
code_context=code_context,
|
||||
function_to_optimize_file_path=function_to_optimize.file_path,
|
||||
optimized_code=candidate.source_code,
|
||||
file_path_of_module_with_function_to_optimize=function_to_optimize.file_path,
|
||||
module_abspath=function_to_optimize.file_path,
|
||||
preexisting_functions=code_context.preexisting_functions,
|
||||
contextual_functions=code_context.contextual_dunder_methods,
|
||||
project_root_path=self.args.project_root,
|
||||
qualified_function_name=function_to_optimize.qualified_name,
|
||||
)
|
||||
for (
|
||||
module_abspath,
|
||||
qualified_names,
|
||||
) in helper_functions_by_module_abspath.items():
|
||||
did_update |= replace_function_definitions_in_module(
|
||||
function_names=list(qualified_names),
|
||||
optimized_code=candidate.source_code,
|
||||
file_path_of_module_with_function_to_optimize=function_to_optimize.file_path,
|
||||
module_abspath=module_abspath,
|
||||
preexisting_functions=[],
|
||||
contextual_functions=code_context.contextual_dunder_methods,
|
||||
project_root_path=self.args.project_root,
|
||||
)
|
||||
if not did_update:
|
||||
logging.warning(
|
||||
"No functions were replaced in the optimized code. Skipping optimization candidate.",
|
||||
|
|
@ -432,7 +407,6 @@ class Optimizer:
|
|||
original_code,
|
||||
original_helper_code,
|
||||
function_to_optimize.file_path,
|
||||
helper_functions_by_module_abspath,
|
||||
)
|
||||
continue
|
||||
|
||||
|
|
@ -490,7 +464,6 @@ class Optimizer:
|
|||
original_code,
|
||||
original_helper_code,
|
||||
function_to_optimize.file_path,
|
||||
helper_functions_by_module_abspath,
|
||||
)
|
||||
logging.info("----------------")
|
||||
except KeyboardInterrupt as e:
|
||||
|
|
@ -498,7 +471,6 @@ class Optimizer:
|
|||
original_code,
|
||||
original_helper_code,
|
||||
function_to_optimize.file_path,
|
||||
helper_functions_by_module_abspath,
|
||||
)
|
||||
logging.exception(f"Optimization interrupted: {e}")
|
||||
raise e
|
||||
|
|
@ -548,17 +520,16 @@ class Optimizer:
|
|||
original_code: str,
|
||||
original_helper_code: dict[str, str],
|
||||
path: str,
|
||||
helper_functions_by_module_abspath: dict[str, set[str]],
|
||||
) -> None:
|
||||
with pathlib.Path(path).open("w", encoding="utf8") as f:
|
||||
f.write(original_code)
|
||||
for module_abspath in helper_functions_by_module_abspath:
|
||||
for module_abspath in original_helper_code:
|
||||
with pathlib.Path(module_abspath).open("w", encoding="utf8") as f:
|
||||
f.write(original_helper_code[module_abspath])
|
||||
|
||||
def reformat_code_and_helpers(
|
||||
self,
|
||||
helper_functions_by_module_abspath: dict[str, set[str]],
|
||||
helper_functions: list[FunctionSource],
|
||||
path: str,
|
||||
original_code: str,
|
||||
) -> tuple[str, dict[str, str]]:
|
||||
|
|
@ -574,7 +545,8 @@ class Optimizer:
|
|||
new_code = sort_imports(new_code)
|
||||
|
||||
new_helper_code: dict[str, str] = {}
|
||||
for module_abspath in helper_functions_by_module_abspath:
|
||||
helper_functions_paths = {hf.file_path for hf in helper_functions}
|
||||
for module_abspath in helper_functions_paths:
|
||||
formatted_helper_code = format_code(
|
||||
self.args.formatter_cmds,
|
||||
module_abspath,
|
||||
|
|
@ -588,33 +560,40 @@ class Optimizer:
|
|||
def replace_function_and_helpers_with_optimized_code(
|
||||
self,
|
||||
code_context: CodeOptimizationContext,
|
||||
helper_functions_by_module_abspath: dict[str, set[str]],
|
||||
explanation: Explanation,
|
||||
function_to_optimize_file_path: str,
|
||||
optimized_code: str,
|
||||
qualified_function_name: str,
|
||||
) -> None:
|
||||
replace_function_definitions_in_module(
|
||||
) -> bool:
|
||||
"""Raises many exceptions if the code is not valid. Catch them where using"""
|
||||
did_update = replace_function_definitions_in_module(
|
||||
function_names=[qualified_function_name],
|
||||
optimized_code=optimized_code,
|
||||
file_path_of_module_with_function_to_optimize=explanation.file_path,
|
||||
module_abspath=explanation.file_path,
|
||||
preexisting_functions=code_context.preexisting_functions,
|
||||
file_path_of_module_with_function_to_optimize=function_to_optimize_file_path,
|
||||
module_abspath=function_to_optimize_file_path,
|
||||
preexisting_objects=code_context.preexisting_objects,
|
||||
contextual_functions=code_context.contextual_dunder_methods,
|
||||
project_root_path=self.args.project_root,
|
||||
)
|
||||
helper_functions_by_module_abspath = defaultdict(set)
|
||||
for helper_function in code_context.helper_functions:
|
||||
if helper_function.jedi_definition.type != "class":
|
||||
helper_functions_by_module_abspath[helper_function.file_path].add(
|
||||
helper_function.qualified_name,
|
||||
)
|
||||
for (
|
||||
module_abspath,
|
||||
qualified_names,
|
||||
) in helper_functions_by_module_abspath.items():
|
||||
replace_function_definitions_in_module(
|
||||
did_update |= replace_function_definitions_in_module(
|
||||
function_names=list(qualified_names),
|
||||
optimized_code=optimized_code,
|
||||
file_path_of_module_with_function_to_optimize=explanation.file_path,
|
||||
file_path_of_module_with_function_to_optimize=function_to_optimize_file_path,
|
||||
module_abspath=module_abspath,
|
||||
preexisting_functions=[],
|
||||
preexisting_objects=[],
|
||||
contextual_functions=code_context.contextual_dunder_methods,
|
||||
project_root_path=self.args.project_root,
|
||||
)
|
||||
return did_update
|
||||
|
||||
def get_code_optimization_context(
|
||||
self,
|
||||
|
|
@ -627,11 +606,11 @@ class Optimizer:
|
|||
)
|
||||
if code_to_optimize is None:
|
||||
return Failure("Could not find function to optimize.")
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
|
||||
(name, [FunctionParent(name=class_name, type="ClassDef")])
|
||||
for class_name, name in contextual_dunder_methods
|
||||
]
|
||||
preexisting_functions.append((function_to_optimize.function_name, function_to_optimize.parents))
|
||||
preexisting_objects.append((function_to_optimize.function_name, function_to_optimize.parents))
|
||||
(
|
||||
helper_code,
|
||||
helper_functions,
|
||||
|
|
@ -674,7 +653,7 @@ class Optimizer:
|
|||
project_root,
|
||||
helper_functions,
|
||||
)
|
||||
preexisting_functions.extend(
|
||||
preexisting_objects.extend(
|
||||
[
|
||||
(qualified_name_list[-1], ([FunctionParent(name=qualified_name_list[-2], type="ClassDef")]))
|
||||
if len(qualified_name_list := fn.qualified_name.split(".")) > 1
|
||||
|
|
@ -688,7 +667,7 @@ class Optimizer:
|
|||
code_to_optimize_with_helpers=code_to_optimize_with_helpers_and_imports,
|
||||
contextual_dunder_methods=contextual_dunder_methods,
|
||||
helper_functions=helper_functions,
|
||||
preexisting_functions=preexisting_functions,
|
||||
preexisting_objects=preexisting_objects,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports, replace_functions_in_file
|
||||
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
|
|
@ -11,6 +14,21 @@ from codeflash.optimization.optimizer import Optimizer
|
|||
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class JediDefinition:
|
||||
type: str
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class FakeFunctionSource:
|
||||
file_path: str
|
||||
qualified_name: str
|
||||
fully_qualified_name: str
|
||||
only_function_name: str
|
||||
source_code: str
|
||||
jedi_definition: JediDefinition
|
||||
|
||||
|
||||
def test_test_libcst_code_replacement() -> None:
|
||||
optim_code = """import libcst as cst
|
||||
from typing import Optional
|
||||
|
|
@ -50,7 +68,7 @@ print("Hello world")
|
|||
"""
|
||||
|
||||
function_name: str = "NewClass.new_function"
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
|
||||
("new_function", [FunctionParent(name="NewClass", type="ClassDef")]),
|
||||
]
|
||||
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
|
||||
|
|
@ -60,7 +78,7 @@ print("Hello world")
|
|||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_functions=preexisting_functions,
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
)
|
||||
|
|
@ -113,7 +131,7 @@ print("Hello world")
|
|||
"""
|
||||
|
||||
function_name: str = "NewClass.new_function"
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
|
||||
("new_function", []),
|
||||
("other_function", []),
|
||||
]
|
||||
|
|
@ -124,7 +142,7 @@ print("Hello world")
|
|||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_functions=preexisting_functions,
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
)
|
||||
|
|
@ -177,7 +195,7 @@ print("Salut monde")
|
|||
"""
|
||||
|
||||
function_names: list[str] = ["module.other_function"]
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = []
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
|
|
@ -185,7 +203,7 @@ print("Salut monde")
|
|||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_functions=preexisting_functions,
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
)
|
||||
|
|
@ -241,7 +259,7 @@ print("Salut monde")
|
|||
"""
|
||||
|
||||
function_names: list[str] = ["module.yet_another_function", "module.other_function"]
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = []
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
|
|
@ -249,7 +267,7 @@ print("Salut monde")
|
|||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_functions=preexisting_functions,
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
)
|
||||
|
|
@ -294,7 +312,7 @@ def supersort(doink):
|
|||
"""
|
||||
|
||||
function_names: list[str] = ["sorter_deps"]
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [("sorter_deps", [])]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [("sorter_deps", [])]
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
|
|
@ -302,7 +320,7 @@ def supersort(doink):
|
|||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_functions=preexisting_functions,
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
)
|
||||
|
|
@ -378,7 +396,7 @@ print("Not cool")
|
|||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_functions=[("other_function", []), ("yet_another_function", []), ("blob", [])],
|
||||
preexisting_objects=[("other_function", []), ("yet_another_function", []), ("blob", [])],
|
||||
contextual_functions=set(),
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
)
|
||||
|
|
@ -390,7 +408,7 @@ print("Not cool")
|
|||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_functions=[],
|
||||
preexisting_objects=[],
|
||||
contextual_functions=set(),
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
)
|
||||
|
|
@ -579,7 +597,7 @@ class CacheConfig(BaseConfig):
|
|||
"""
|
||||
function_names: list[str] = ["CacheSimilarityEvalConfig.from_config"]
|
||||
parents = [FunctionParent(name="CacheConfig", type="ClassDef")]
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
|
||||
("__init__", parents),
|
||||
("from_config", parents),
|
||||
]
|
||||
|
|
@ -595,7 +613,7 @@ class CacheConfig(BaseConfig):
|
|||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_functions=preexisting_functions,
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
)
|
||||
|
|
@ -659,7 +677,7 @@ def test_test_libcst_code_replacement8() -> None:
|
|||
return np.sum(a != b) / a.size
|
||||
'''
|
||||
function_names: list[str] = ["_EmbeddingDistanceChainMixin._hamming_distance"]
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
|
||||
("_hamming_distance", [FunctionParent("_EmbeddingDistanceChainMixin", "ClassDef")]),
|
||||
]
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
|
|
@ -669,7 +687,7 @@ def test_test_libcst_code_replacement8() -> None:
|
|||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_functions=preexisting_functions,
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
)
|
||||
|
|
@ -718,7 +736,7 @@ print("Hello world")
|
|||
"""
|
||||
parents = [FunctionParent(name="NewClass", type="ClassDef")]
|
||||
function_name: str = "NewClass.__init__"
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
|
||||
("__init__", parents),
|
||||
("__call__", parents),
|
||||
]
|
||||
|
|
@ -732,7 +750,7 @@ print("Hello world")
|
|||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_functions=preexisting_functions,
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
)
|
||||
|
|
@ -834,13 +852,13 @@ def test_code_replacement11() -> None:
|
|||
|
||||
function_name: str = "Fu.foo"
|
||||
parents = [FunctionParent("Fu", "ClassDef")]
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = [("foo", parents), ("real_bar", parents)]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [("foo", parents), ("real_bar", parents)]
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_in_file(
|
||||
source_code=original_code,
|
||||
original_function_names=[function_name],
|
||||
optimized_code=optim_code,
|
||||
preexisting_functions=preexisting_functions,
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
)
|
||||
assert new_code == expected_code
|
||||
|
|
@ -875,13 +893,13 @@ def test_code_replacement12() -> None:
|
|||
pass
|
||||
'''
|
||||
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = []
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_in_file(
|
||||
source_code=original_code,
|
||||
original_function_names=["Fu.real_bar"],
|
||||
optimized_code=optim_code,
|
||||
preexisting_functions=preexisting_functions,
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
)
|
||||
assert new_code == expected_code
|
||||
|
|
@ -913,7 +931,7 @@ def test_test_libcst_code_replacement13() -> None:
|
|||
"""
|
||||
|
||||
function_names: list[str] = ["module.yet_another_function", "module.other_function"]
|
||||
preexisting_functions: list[tuple[str, list[FunctionParent]]] = []
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
|
|
@ -921,8 +939,486 @@ def test_test_libcst_code_replacement13() -> None:
|
|||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_functions=preexisting_functions,
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
)
|
||||
assert new_code == original_code
|
||||
|
||||
|
||||
@pytest.mark.skip()
|
||||
def test_different_class_code_replacement():
|
||||
original_code = """from __future__ import annotations
|
||||
import sys
|
||||
from codeflash.verification.comparator import comparator
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel
|
||||
from typing import Iterator
|
||||
|
||||
class TestType(Enum):
|
||||
EXISTING_UNIT_TEST = 1
|
||||
INSPIRED_REGRESSION = 2
|
||||
GENERATED_REGRESSION = 3
|
||||
REPLAY_TEST = 4
|
||||
|
||||
def to_name(self) -> str:
|
||||
names = {
|
||||
TestType.EXISTING_UNIT_TEST: "⚙️ Existing Unit Tests",
|
||||
TestType.INSPIRED_REGRESSION: "🎨 Inspired Regression Tests",
|
||||
TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests",
|
||||
TestType.REPLAY_TEST: "⏪ Replay Tests",
|
||||
}
|
||||
return names[self]
|
||||
|
||||
class TestResults(BaseModel):
|
||||
def __iter__(self) -> Iterator[FunctionTestInvocation]:
|
||||
return iter(self.test_results)
|
||||
def __len__(self) -> int:
|
||||
return len(self.test_results)
|
||||
def __getitem__(self, index: int) -> FunctionTestInvocation:
|
||||
return self.test_results[index]
|
||||
def __setitem__(self, index: int, value: FunctionTestInvocation) -> None:
|
||||
self.test_results[index] = value
|
||||
def __delitem__(self, index: int) -> None:
|
||||
del self.test_results[index]
|
||||
def __contains__(self, value: FunctionTestInvocation) -> bool:
|
||||
return value in self.test_results
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.test_results)
|
||||
def __eq__(self, other: object) -> bool:
|
||||
# Unordered comparison
|
||||
if type(self) != type(other):
|
||||
return False
|
||||
if len(self) != len(other):
|
||||
return False
|
||||
original_recursion_limit = sys.getrecursionlimit()
|
||||
for test_result in self:
|
||||
other_test_result = other.get_by_id(test_result.id)
|
||||
if other_test_result is None:
|
||||
return False
|
||||
|
||||
if original_recursion_limit < 5000:
|
||||
sys.setrecursionlimit(5000)
|
||||
if (
|
||||
test_result.file_name != other_test_result.file_name
|
||||
or test_result.did_pass != other_test_result.did_pass
|
||||
or test_result.runtime != other_test_result.runtime
|
||||
or test_result.test_framework != other_test_result.test_framework
|
||||
or test_result.test_type != other_test_result.test_type
|
||||
or not comparator(
|
||||
test_result.return_value,
|
||||
other_test_result.return_value,
|
||||
)
|
||||
):
|
||||
sys.setrecursionlimit(original_recursion_limit)
|
||||
return False
|
||||
sys.setrecursionlimit(original_recursion_limit)
|
||||
return True
|
||||
def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]:
|
||||
report = {}
|
||||
for test_type in TestType:
|
||||
report[test_type] = {"passed": 0, "failed": 0}
|
||||
for test_result in self.test_results:
|
||||
if test_result.test_type != TestType.EXISTING_UNIT_TEST or test_result.id.function_getting_tested:
|
||||
if test_result.did_pass:
|
||||
report[test_result.test_type]["passed"] += 1
|
||||
else:
|
||||
report[test_result.test_type]["failed"] += 1
|
||||
return report"""
|
||||
optim_code = """from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from enum import Enum
|
||||
from typing import Iterator
|
||||
|
||||
from codeflash.verification.comparator import comparator
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TestType(Enum):
|
||||
EXISTING_UNIT_TEST = 1
|
||||
INSPIRED_REGRESSION = 2
|
||||
GENERATED_REGRESSION = 3
|
||||
REPLAY_TEST = 4
|
||||
|
||||
def to_name(self) -> str:
|
||||
if self == TestType.EXISTING_UNIT_TEST:
|
||||
return "⚙️ Existing Unit Tests"
|
||||
elif self == TestType.INSPIRED_REGRESSION:
|
||||
return "🎨 Inspired Regression Tests"
|
||||
elif self == TestType.GENERATED_REGRESSION:
|
||||
return "🌀 Generated Regression Tests"
|
||||
elif self == TestType.REPLAY_TEST:
|
||||
return "⏪ Replay Tests"
|
||||
|
||||
class TestResults(BaseModel):
|
||||
def __iter__(self) -> Iterator[FunctionTestInvocation]:
|
||||
return iter(self.test_results)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.test_results)
|
||||
|
||||
def __getitem__(self, index: int) -> FunctionTestInvocation:
|
||||
return self.test_results[index]
|
||||
|
||||
def __setitem__(self, index: int, value: FunctionTestInvocation) -> None:
|
||||
self.test_results[index] = value
|
||||
|
||||
def __delitem__(self, index: int) -> None:
|
||||
del self.test_results[index]
|
||||
|
||||
def __contains__(self, value: FunctionTestInvocation) -> bool:
|
||||
return value in self.test_results
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.test_results)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
# Unordered comparison
|
||||
if not isinstance(other, TestResults) or len(self) != len(other):
|
||||
return False
|
||||
|
||||
# Increase recursion limit only if necessary
|
||||
original_recursion_limit = sys.getrecursionlimit()
|
||||
if original_recursion_limit < 5000:
|
||||
sys.setrecursionlimit(5000)
|
||||
|
||||
for test_result in self:
|
||||
other_test_result = other.get_by_id(test_result.id)
|
||||
if other_test_result is None or not (
|
||||
test_result.file_name == other_test_result.file_name and
|
||||
test_result.did_pass == other_test_result.did_pass and
|
||||
test_result.runtime == other_test_result.runtime and
|
||||
test_result.test_framework == other_test_result.test_framework and
|
||||
test_result.test_type == other_test_result.test_type and
|
||||
comparator(test_result.return_value, other_test_result.return_value)
|
||||
):
|
||||
sys.setrecursionlimit(original_recursion_limit)
|
||||
return False
|
||||
|
||||
sys.setrecursionlimit(original_recursion_limit)
|
||||
return True
|
||||
|
||||
def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]:
|
||||
report = {test_type: {"passed": 0, "failed": 0} for test_type in TestType}
|
||||
for test_result in self.test_results:
|
||||
if test_result.test_type != TestType.EXISTING_UNIT_TEST or test_result.id.function_getting_tested:
|
||||
key = "passed" if test_result.did_pass else "failed"
|
||||
report[test_result.test_type][key] += 1
|
||||
return report"""
|
||||
|
||||
preexisting_objects = [
|
||||
("__contains__", [FunctionParent(name="TestResults", type="ClassDef")]),
|
||||
("__len__", [FunctionParent(name="TestResults", type="ClassDef")]),
|
||||
("__bool__", [FunctionParent(name="TestResults", type="ClassDef")]),
|
||||
("__eq__", [FunctionParent(name="TestResults", type="ClassDef")]),
|
||||
("__delitem__", [FunctionParent(name="TestResults", type="ClassDef")]),
|
||||
("__iter__", [FunctionParent(name="TestResults", type="ClassDef")]),
|
||||
("__setitem__", [FunctionParent(name="TestResults", type="ClassDef")]),
|
||||
("__getitem__", [FunctionParent(name="TestResults", type="ClassDef")]),
|
||||
("get_test_pass_fail_report_by_type", [FunctionParent(name="TestResults", type="ClassDef")]),
|
||||
("TestType", []),
|
||||
]
|
||||
|
||||
contextual_functions = {
|
||||
("TestResults", "__bool__"),
|
||||
("TestResults", "__contains__"),
|
||||
("TestResults", "__delitem__"),
|
||||
("TestResults", "__eq__"),
|
||||
("TestResults", "__getitem__"),
|
||||
("TestResults", "__iter__"),
|
||||
("TestResults", "__len__"),
|
||||
("TestResults", "__setitem__"),
|
||||
}
|
||||
|
||||
helper_functions = [
|
||||
FakeFunctionSource(
|
||||
file_path="/Users/saurabh/Library/CloudStorage/Dropbox/codeflash/cli/codeflash/verification/test_results.py",
|
||||
qualified_name="TestType",
|
||||
fully_qualified_name="codeflash.verification.test_results.TestType",
|
||||
only_function_name="TestType",
|
||||
source_code="",
|
||||
jedi_definition=JediDefinition(type="class"),
|
||||
),
|
||||
]
|
||||
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=["TestResults.get_test_pass_fail_report_by_type"],
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).parent.resolve()),
|
||||
)
|
||||
helper_functions_by_module_abspath = defaultdict(set)
|
||||
for helper_function in helper_functions:
|
||||
if helper_function.jedi_definition.type != "class":
|
||||
helper_functions_by_module_abspath[helper_function.file_path].add(
|
||||
helper_function.qualified_name,
|
||||
)
|
||||
for (
|
||||
module_abspath,
|
||||
qualified_names,
|
||||
) in helper_functions_by_module_abspath.items():
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=new_code,
|
||||
function_names=list(qualified_names),
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=module_abspath,
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).parent.resolve()),
|
||||
)
|
||||
|
||||
print("hi")
|
||||
|
||||
|
||||
def test_code_replacement_type_annotation():
    """End-to-end check of ``replace_functions_and_add_imports`` on annotated code.

    Two passes are exercised and the exact resulting source is asserted after each:

    1. Replace only the target function ``cosine_similarity_top_k`` in the
       original module; the helper function and the ``Matrix`` class must stay
       byte-identical to the original.
    2. Replace the non-class helper functions (here ``cosine_similarity``),
       grouped by the module they live in; class helpers (``Matrix``) are
       excluded from replacement.
    """
    # Module under optimization: a pydantic dataclass plus two functions.
    # NOTE: the exact whitespace of this literal matters — it is compared
    # byte-for-byte against the replacement output below.
    original_code = '''import numpy as np
from pydantic.dataclasses import dataclass
from typing import List, Optional, Tuple, Union
@dataclass(config=dict(arbitrary_types_allowed=True))
class Matrix:
    data: Union[List[List[float]], List[np.ndarray], np.ndarray]
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
    """Row-wise cosine similarity between two equal-width matrices."""
    if len(X.data) == 0 or len(Y.data) == 0:
        return np.array([])
    X = np.array(X.data)
    Y = np.array(Y.data)
    if X.shape[1] != Y.shape[1]:
        raise ValueError(
            f"Number of columns in X and Y must be the same. X has shape {X.shape} "
            f"and Y has shape {Y.shape}.",
        )
    X_norm = np.linalg.norm(X, axis=1)
    Y_norm = np.linalg.norm(Y, axis=1)
    similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
    similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
    return similarity
def cosine_similarity_top_k(
    X: Matrix,
    Y: Matrix,
    top_k: Optional[int] = 5,
    score_threshold: Optional[float] = None,
) -> Tuple[List[Tuple[int, int]], List[float]]:
    """Row-wise cosine similarity with optional top-k and score threshold filtering.
    Args:
    ----
        X: Matrix.
        Y: Matrix, same width as X.
        top_k: Max number of results to return.
        score_threshold: Minimum cosine similarity of results.
    Returns:
    -------
        Tuple of two lists. First contains two-tuples of indices (X_idx, Y_idx),
        second contains corresponding cosine similarities.
    """
    if len(X.data) == 0 or len(Y.data) == 0:
        return [], []
    score_array = cosine_similarity(X, Y)
    sorted_idxs = score_array.flatten().argsort()[::-1]
    top_k = top_k or len(sorted_idxs)
    top_idxs = sorted_idxs[:top_k]
    score_threshold = score_threshold or -1.0
    top_idxs = top_idxs[score_array.flatten()[top_idxs] > score_threshold]
    ret_idxs = [(x // score_array.shape[1], x % score_array.shape[1]) for x in top_idxs]
    scores = score_array.flatten()[top_idxs].tolist()
    return ret_idxs, scores
'''
    # Candidate optimized module as an optimizer would return it: both
    # functions rewritten; Matrix's annotation subtly changed (list vs List)
    # to confirm the class body is NOT copied over in either pass.
    optim_code = '''from typing import List, Optional, Tuple, Union
import numpy as np
from pydantic.dataclasses import dataclass
@dataclass(config=dict(arbitrary_types_allowed=True))
class Matrix:
    data: Union[list[list[float]], List[np.ndarray], np.ndarray]
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
    """Row-wise cosine similarity between two equal-width matrices."""
    if len(X.data) == 0 or len(Y.data) == 0:
        return np.array([])

    X_np, Y_np = np.asarray(X.data), np.asarray(Y.data)
    if X_np.shape[1] != Y_np.shape[1]:
        raise ValueError(f"Number of columns in X and Y must be the same. X has shape {X_np.shape} and Y has shape {Y_np.shape}.")
    X_norm = np.linalg.norm(X_np, axis=1, keepdims=True)
    Y_norm = np.linalg.norm(Y_np, axis=1, keepdims=True)

    norm_product = X_norm * Y_norm.T
    norm_product[norm_product == 0] = np.inf # Prevent division by zero
    dot_product = np.dot(X_np, Y_np.T)
    similarity = dot_product / norm_product

    # Any NaN or Inf values are set to 0.0
    np.nan_to_num(similarity, copy=False)

    return similarity
def cosine_similarity_top_k(
    X: Matrix,
    Y: Matrix,
    top_k: Optional[int] = 5,
    score_threshold: Optional[float] = None,
) -> Tuple[List[Tuple[int, int]], List[float]]:
    """Row-wise cosine similarity with optional top-k and score threshold filtering."""
    if len(X.data) == 0 or len(Y.data) == 0:
        return [], []

    score_array = cosine_similarity(X, Y)

    sorted_idxs = np.argpartition(-score_array.flatten(), range(top_k or len(score_array.flatten())))[:(top_k or len(score_array.flatten()))]
    sorted_idxs = sorted_idxs[score_array.flatten()[sorted_idxs] > (score_threshold if score_threshold is not None else -1)]

    ret_idxs = [(x // score_array.shape[1], x % score_array.shape[1]) for x in sorted_idxs]
    scores = score_array.flatten()[sorted_idxs].tolist()

    return ret_idxs, scores
'''
    # Objects already present in the original module (name, parent chain);
    # these must not be duplicated by the replacement.
    preexisting_objects = [("cosine_similarity_top_k", []), ("Matrix", []), ("cosine_similarity", [])]

    contextual_functions = set()
    # Helper definitions discovered for the target function: one class and one
    # plain function. Only the non-class helper participates in pass two.
    helper_functions = [
        FakeFunctionSource(
            file_path="/Users/saurabh/Library/CloudStorage/Dropbox/codeflash/cli/code_to_optimize/math_utils.py",
            qualified_name="Matrix",
            fully_qualified_name="code_to_optimize.math_utils.Matrix",
            only_function_name="Matrix",
            source_code="",
            jedi_definition=JediDefinition(type="class"),
        ),
        FakeFunctionSource(
            file_path="/Users/saurabh/Library/CloudStorage/Dropbox/codeflash/cli/code_to_optimize/math_utils.py",
            qualified_name="cosine_similarity",
            fully_qualified_name="code_to_optimize.math_utils.cosine_similarity",
            only_function_name="cosine_similarity",
            source_code="",
            jedi_definition=JediDefinition(type="function"),
        ),
    ]

    # Pass 1: replace only the entry point; helper + class stay as-is.
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=["cosine_similarity_top_k"],
        optimized_code=optim_code,
        file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
        module_abspath=str((Path(__file__).parent / "code_to_optimize").resolve()),
        preexisting_objects=preexisting_objects,
        contextual_functions=contextual_functions,
        project_root_path=str(Path(__file__).parent.parent.resolve()),
    )
    # Expected: original Matrix and cosine_similarity verbatim, optimized
    # cosine_similarity_top_k spliced in.
    assert (
        new_code
        == '''import numpy as np
from pydantic.dataclasses import dataclass
from typing import List, Optional, Tuple, Union
@dataclass(config=dict(arbitrary_types_allowed=True))
class Matrix:
    data: Union[List[List[float]], List[np.ndarray], np.ndarray]
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
    """Row-wise cosine similarity between two equal-width matrices."""
    if len(X.data) == 0 or len(Y.data) == 0:
        return np.array([])
    X = np.array(X.data)
    Y = np.array(Y.data)
    if X.shape[1] != Y.shape[1]:
        raise ValueError(
            f"Number of columns in X and Y must be the same. X has shape {X.shape} "
            f"and Y has shape {Y.shape}.",
        )
    X_norm = np.linalg.norm(X, axis=1)
    Y_norm = np.linalg.norm(Y, axis=1)
    similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
    similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
    return similarity
def cosine_similarity_top_k(
    X: Matrix,
    Y: Matrix,
    top_k: Optional[int] = 5,
    score_threshold: Optional[float] = None,
) -> Tuple[List[Tuple[int, int]], List[float]]:
    """Row-wise cosine similarity with optional top-k and score threshold filtering."""
    if len(X.data) == 0 or len(Y.data) == 0:
        return [], []

    score_array = cosine_similarity(X, Y)

    sorted_idxs = np.argpartition(-score_array.flatten(), range(top_k or len(score_array.flatten())))[:(top_k or len(score_array.flatten()))]
    sorted_idxs = sorted_idxs[score_array.flatten()[sorted_idxs] > (score_threshold if score_threshold is not None else -1)]

    ret_idxs = [(x // score_array.shape[1], x % score_array.shape[1]) for x in sorted_idxs]
    scores = score_array.flatten()[sorted_idxs].tolist()

    return ret_idxs, scores
'''
    )
    # Pass 2: group replaceable helpers by their module path, skipping classes.
    helper_functions_by_module_abspath = defaultdict(set)
    for helper_function in helper_functions:
        if helper_function.jedi_definition.type != "class":
            helper_functions_by_module_abspath[helper_function.file_path].add(
                helper_function.qualified_name,
            )
    for (
        module_abspath,
        qualified_names,
    ) in helper_functions_by_module_abspath.items():
        new_helper_code: str = replace_functions_and_add_imports(
            source_code=new_code,
            function_names=list(qualified_names),
            optimized_code=optim_code,
            file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
            module_abspath=module_abspath,
            preexisting_objects=preexisting_objects,
            contextual_functions=contextual_functions,
            project_root_path=str(Path(__file__).parent.parent.resolve()),
        )

        # Expected: cosine_similarity now optimized as well, while the Matrix
        # class body still keeps its original (capitalized List) annotation.
        assert (
            new_helper_code
            == '''import numpy as np
from pydantic.dataclasses import dataclass
from typing import List, Optional, Tuple, Union
@dataclass(config=dict(arbitrary_types_allowed=True))
class Matrix:
    data: Union[List[List[float]], List[np.ndarray], np.ndarray]
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
    """Row-wise cosine similarity between two equal-width matrices."""
    if len(X.data) == 0 or len(Y.data) == 0:
        return np.array([])

    X_np, Y_np = np.asarray(X.data), np.asarray(Y.data)
    if X_np.shape[1] != Y_np.shape[1]:
        raise ValueError(f"Number of columns in X and Y must be the same. X has shape {X_np.shape} and Y has shape {Y_np.shape}.")
    X_norm = np.linalg.norm(X_np, axis=1, keepdims=True)
    Y_norm = np.linalg.norm(Y_np, axis=1, keepdims=True)

    norm_product = X_norm * Y_norm.T
    norm_product[norm_product == 0] = np.inf # Prevent division by zero
    dot_product = np.dot(X_np, Y_np.T)
    similarity = dot_product / norm_product

    # Any NaN or Inf values are set to 0.0
    np.nan_to_num(similarity, copy=False)

    return similarity
def cosine_similarity_top_k(
    X: Matrix,
    Y: Matrix,
    top_k: Optional[int] = 5,
    score_threshold: Optional[float] = None,
) -> Tuple[List[Tuple[int, int]], List[float]]:
    """Row-wise cosine similarity with optional top-k and score threshold filtering."""
    if len(X.data) == 0 or len(Y.data) == 0:
        return [], []

    score_array = cosine_similarity(X, Y)

    sorted_idxs = np.argpartition(-score_array.flatten(), range(top_k or len(score_array.flatten())))[:(top_k or len(score_array.flatten()))]
    sorted_idxs = sorted_idxs[score_array.flatten()[sorted_idxs] > (score_threshold if score_threshold is not None else -1)]

    ret_idxs = [(x // score_array.shape[1], x % score_array.shape[1]) for x in sorted_idxs]
    scores = score_array.flatten()[sorted_idxs].tolist()

    return ret_idxs, scores
'''
        )
|
||||
|
|
|
|||
Loading…
Reference in a new issue