mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
changed preexisting objects to be a set. removes duplicates naturally and makes it easier to search for matches when replacing code.
This commit is contained in:
parent
a651fbfc65
commit
274f98b209
4 changed files with 45 additions and 27 deletions
|
|
@ -235,7 +235,7 @@ def extract_code(functions_to_optimize: list[FunctionToOptimize]) -> tuple[str |
|
|||
return edited_code, contextual_dunder_methods
|
||||
|
||||
|
||||
def find_preexisting_objects(source_code: str) -> list[tuple[str, list[FunctionParent]]]:
|
||||
def find_preexisting_object_old(source_code: str) -> list[tuple[str, list[FunctionParent]]]:
|
||||
"""Find all preexisting functions, classes or class methods in the source code"""
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
|
||||
try:
|
||||
|
|
@ -252,3 +252,21 @@ def find_preexisting_objects(source_code: str) -> list[tuple[str, list[FunctionP
|
|||
if isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
preexisting_objects.append((cnode.name, [FunctionParent(node.name, "ClassDef")]))
|
||||
return preexisting_objects
|
||||
|
||||
def find_preexisting_objects(source_code: str) -> set[tuple[str, tuple[FunctionParent, ...]]]:
|
||||
"""Find all preexisting functions, classes or class methods in the source code"""
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = set()
|
||||
try:
|
||||
module_node: ast.Module = ast.parse(source_code)
|
||||
except SyntaxError:
|
||||
logger.exception("find_preexisting_objects - Syntax error while parsing code")
|
||||
return preexisting_objects
|
||||
for node in module_node.body:
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
preexisting_objects.add((node.name, ()))
|
||||
elif isinstance(node, ast.ClassDef):
|
||||
preexisting_objects.add((node.name, ()))
|
||||
for cnode in node.body:
|
||||
if isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
preexisting_objects.add((cnode.name, (FunctionParent(node.name, "ClassDef"),)))
|
||||
return preexisting_objects
|
||||
|
|
|
|||
|
|
@ -38,11 +38,11 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] | None = None,
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] | None = None,
|
||||
function_names: set[tuple[str | None, str]] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.preexisting_objects = preexisting_objects if preexisting_objects is not None else []
|
||||
self.preexisting_objects = preexisting_objects if preexisting_objects is not None else set()
|
||||
|
||||
self.function_names = function_names # set of (class_name, function_name)
|
||||
self.modified_functions: dict[
|
||||
|
|
@ -60,7 +60,7 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
|||
self.modified_init_functions[self.current_class] = node
|
||||
elif (
|
||||
self.preexisting_objects
|
||||
and (node.name.value, []) not in self.preexisting_objects
|
||||
and (node.name.value, ()) not in self.preexisting_objects
|
||||
and self.current_class is None
|
||||
):
|
||||
self.new_functions.append(node)
|
||||
|
|
@ -71,7 +71,7 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
|||
return False # If already in a class, do not recurse deeper
|
||||
self.current_class = node.name.value
|
||||
|
||||
parents = [FunctionParent(name=node.name.value, type="ClassDef")]
|
||||
parents = (FunctionParent(name=node.name.value, type="ClassDef"),)
|
||||
for child_node in node.body.body:
|
||||
if (
|
||||
self.preexisting_objects
|
||||
|
|
@ -159,7 +159,7 @@ def replace_functions_in_file(
|
|||
source_code: str,
|
||||
original_function_names: list[str],
|
||||
optimized_code: str,
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]],
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]],
|
||||
) -> str:
|
||||
parsed_function_names = []
|
||||
for original_function_name in original_function_names:
|
||||
|
|
@ -195,7 +195,7 @@ def replace_functions_and_add_imports(
|
|||
function_names: list[str],
|
||||
optimized_code: str,
|
||||
module_abspath: Path,
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]],
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]],
|
||||
project_root_path: Path,
|
||||
) -> str:
|
||||
return add_needed_imports_from_module(
|
||||
|
|
@ -211,7 +211,7 @@ def replace_function_definitions_in_module(
|
|||
function_names: list[str],
|
||||
optimized_code: str,
|
||||
module_abspath: Path,
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]],
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]],
|
||||
project_root_path: Path,
|
||||
) -> bool:
|
||||
source_code: str = module_abspath.read_text(encoding="utf8")
|
||||
|
|
|
|||
|
|
@ -65,13 +65,13 @@ def get_code_optimization_context(
|
|||
if final_read_writable_tokens > optim_token_limit:
|
||||
raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
|
||||
|
||||
# Setup preexisting objects for code replacer TODO: should remove duplicates
|
||||
preexisting_objects = list(
|
||||
# Setup preexisting objects for code replacer
|
||||
preexisting_objects = list(set(
|
||||
chain(
|
||||
find_preexisting_objects(final_read_writable_code),
|
||||
*(find_preexisting_objects(codestring.code) for codestring in read_only_code_markdown.code_strings),
|
||||
)
|
||||
)
|
||||
))
|
||||
read_only_context_code = read_only_code_markdown.markdown
|
||||
|
||||
read_only_code_markdown_tokens = len(tokenizer.encode(read_only_context_code))
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ print("Hello world")
|
|||
"""
|
||||
|
||||
function_name: str = "NewClass.new_function"
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=[function_name],
|
||||
|
|
@ -135,7 +135,7 @@ print("Hello world")
|
|||
"""
|
||||
|
||||
function_name: str = "NewClass.new_function"
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=[function_name],
|
||||
|
|
@ -196,7 +196,7 @@ print("Salut monde")
|
|||
"""
|
||||
|
||||
function_names: list[str] = ["other_function"]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
|
|
@ -260,7 +260,7 @@ print("Salut monde")
|
|||
"""
|
||||
|
||||
function_names: list[str] = ["yet_another_function", "other_function"]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
|
|
@ -313,7 +313,7 @@ def supersort(doink):
|
|||
"""
|
||||
|
||||
function_names: list[str] = ["sorter_deps"]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
|
|
@ -388,7 +388,7 @@ def blab(st):
|
|||
|
||||
print("Not cool")
|
||||
"""
|
||||
preexisting_objects = find_preexisting_objects(original_code_main) + find_preexisting_objects(original_code_helper)
|
||||
preexisting_objects = find_preexisting_objects(original_code_main) | find_preexisting_objects(original_code_helper)
|
||||
new_main_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code_main,
|
||||
function_names=["other_function"],
|
||||
|
|
@ -591,7 +591,7 @@ class CacheConfig(BaseConfig):
|
|||
)
|
||||
"""
|
||||
function_names: list[str] = ["CacheSimilarityEvalConfig.from_config"]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
|
||||
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
|
|
@ -662,7 +662,7 @@ def test_test_libcst_code_replacement8() -> None:
|
|||
return np.sum(a != b) / a.size
|
||||
'''
|
||||
function_names: list[str] = ["_EmbeddingDistanceChainMixin._hamming_distance"]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
|
|
@ -715,7 +715,7 @@ def totally_new_function(value: Optional[str]):
|
|||
print("Hello world")
|
||||
"""
|
||||
function_name: str = "NewClass.__init__"
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=[function_name],
|
||||
|
|
@ -814,8 +814,8 @@ def test_code_replacement11() -> None:
|
|||
'''
|
||||
|
||||
function_name: str = "Fu.foo"
|
||||
parents = [FunctionParent("Fu", "ClassDef")]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [("foo", parents), ("real_bar", parents)]
|
||||
parents = (FunctionParent("Fu", "ClassDef"),)
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = {("foo", parents), ("real_bar", parents)}
|
||||
new_code: str = replace_functions_in_file(
|
||||
source_code=original_code,
|
||||
original_function_names=[function_name],
|
||||
|
|
@ -854,7 +854,7 @@ def test_code_replacement12() -> None:
|
|||
pass
|
||||
'''
|
||||
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = []
|
||||
new_code: str = replace_functions_in_file(
|
||||
source_code=original_code,
|
||||
original_function_names=["Fu.real_bar"],
|
||||
|
|
@ -891,7 +891,7 @@ def test_test_libcst_code_replacement13() -> None:
|
|||
"""
|
||||
|
||||
function_names: list[str] = ["yet_another_function", "other_function"]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = []
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
|
|
@ -1278,7 +1278,7 @@ def cosine_similarity_top_k(
|
|||
|
||||
return ret_idxs, scores
|
||||
'''
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
|
||||
|
||||
helper_functions = [
|
||||
FakeFunctionSource(
|
||||
|
|
@ -1579,7 +1579,7 @@ print("Hello world")
|
|||
"NewClass.new_function2",
|
||||
"NestedClass.nested_function",
|
||||
] # Nested classes should be ignored, even if provided as target
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
|
|
@ -1615,7 +1615,7 @@ print("Hello world")
|
|||
"""
|
||||
|
||||
function_names: list[str] = ["NewClass.__init__", "NewClass.__call__", "NewClass.new_function2"]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
|
|
|
|||
Loading…
Reference in a new issue