changed preexisting objects to be a set. removes duplicates naturally and makes it easier to search for matches when replacing code.

This commit is contained in:
Alvin Ryanputra 2025-03-13 18:52:11 -07:00
parent a651fbfc65
commit 274f98b209
4 changed files with 45 additions and 27 deletions

View file

@ -235,7 +235,7 @@ def extract_code(functions_to_optimize: list[FunctionToOptimize]) -> tuple[str |
return edited_code, contextual_dunder_methods
def find_preexisting_objects(source_code: str) -> list[tuple[str, list[FunctionParent]]]:
def find_preexisting_object_old(source_code: str) -> list[tuple[str, list[FunctionParent]]]:
"""Find all preexisting functions, classes or class methods in the source code"""
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
try:
@ -252,3 +252,21 @@ def find_preexisting_objects(source_code: str) -> list[tuple[str, list[FunctionP
if isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef)):
preexisting_objects.append((cnode.name, [FunctionParent(node.name, "ClassDef")]))
return preexisting_objects
def find_preexisting_objects(source_code: str) -> set[tuple[str, tuple[FunctionParent, ...]]]:
"""Find all preexisting functions, classes or class methods in the source code"""
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = set()
try:
module_node: ast.Module = ast.parse(source_code)
except SyntaxError:
logger.exception("find_preexisting_objects - Syntax error while parsing code")
return preexisting_objects
for node in module_node.body:
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
preexisting_objects.add((node.name, ()))
elif isinstance(node, ast.ClassDef):
preexisting_objects.add((node.name, ()))
for cnode in node.body:
if isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef)):
preexisting_objects.add((cnode.name, (FunctionParent(node.name, "ClassDef"),)))
return preexisting_objects

View file

@ -38,11 +38,11 @@ class OptimFunctionCollector(cst.CSTVisitor):
def __init__(
self,
preexisting_objects: list[tuple[str, list[FunctionParent]]] | None = None,
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] | None = None,
function_names: set[tuple[str | None, str]] | None = None,
) -> None:
super().__init__()
self.preexisting_objects = preexisting_objects if preexisting_objects is not None else []
self.preexisting_objects = preexisting_objects if preexisting_objects is not None else set()
self.function_names = function_names # set of (class_name, function_name)
self.modified_functions: dict[
@ -60,7 +60,7 @@ class OptimFunctionCollector(cst.CSTVisitor):
self.modified_init_functions[self.current_class] = node
elif (
self.preexisting_objects
and (node.name.value, []) not in self.preexisting_objects
and (node.name.value, ()) not in self.preexisting_objects
and self.current_class is None
):
self.new_functions.append(node)
@ -71,7 +71,7 @@ class OptimFunctionCollector(cst.CSTVisitor):
return False # If already in a class, do not recurse deeper
self.current_class = node.name.value
parents = [FunctionParent(name=node.name.value, type="ClassDef")]
parents = (FunctionParent(name=node.name.value, type="ClassDef"),)
for child_node in node.body.body:
if (
self.preexisting_objects
@ -159,7 +159,7 @@ def replace_functions_in_file(
source_code: str,
original_function_names: list[str],
optimized_code: str,
preexisting_objects: list[tuple[str, list[FunctionParent]]],
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]],
) -> str:
parsed_function_names = []
for original_function_name in original_function_names:
@ -195,7 +195,7 @@ def replace_functions_and_add_imports(
function_names: list[str],
optimized_code: str,
module_abspath: Path,
preexisting_objects: list[tuple[str, list[FunctionParent]]],
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]],
project_root_path: Path,
) -> str:
return add_needed_imports_from_module(
@ -211,7 +211,7 @@ def replace_function_definitions_in_module(
function_names: list[str],
optimized_code: str,
module_abspath: Path,
preexisting_objects: list[tuple[str, list[FunctionParent]]],
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]],
project_root_path: Path,
) -> bool:
source_code: str = module_abspath.read_text(encoding="utf8")

View file

@ -65,13 +65,13 @@ def get_code_optimization_context(
if final_read_writable_tokens > optim_token_limit:
raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
# Setup preexisting objects for code replacer TODO: should remove duplicates
preexisting_objects = list(
# Setup preexisting objects for code replacer
preexisting_objects = list(set(
chain(
find_preexisting_objects(final_read_writable_code),
*(find_preexisting_objects(codestring.code) for codestring in read_only_code_markdown.code_strings),
)
)
))
read_only_context_code = read_only_code_markdown.markdown
read_only_code_markdown_tokens = len(tokenizer.encode(read_only_context_code))

View file

@ -74,7 +74,7 @@ print("Hello world")
"""
function_name: str = "NewClass.new_function"
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=[function_name],
@ -135,7 +135,7 @@ print("Hello world")
"""
function_name: str = "NewClass.new_function"
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=[function_name],
@ -196,7 +196,7 @@ print("Salut monde")
"""
function_names: list[str] = ["other_function"]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
@ -260,7 +260,7 @@ print("Salut monde")
"""
function_names: list[str] = ["yet_another_function", "other_function"]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
@ -313,7 +313,7 @@ def supersort(doink):
"""
function_names: list[str] = ["sorter_deps"]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
@ -388,7 +388,7 @@ def blab(st):
print("Not cool")
"""
preexisting_objects = find_preexisting_objects(original_code_main) + find_preexisting_objects(original_code_helper)
preexisting_objects = find_preexisting_objects(original_code_main) | find_preexisting_objects(original_code_helper)
new_main_code: str = replace_functions_and_add_imports(
source_code=original_code_main,
function_names=["other_function"],
@ -591,7 +591,7 @@ class CacheConfig(BaseConfig):
)
"""
function_names: list[str] = ["CacheSimilarityEvalConfig.from_config"]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
@ -662,7 +662,7 @@ def test_test_libcst_code_replacement8() -> None:
return np.sum(a != b) / a.size
'''
function_names: list[str] = ["_EmbeddingDistanceChainMixin._hamming_distance"]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
@ -715,7 +715,7 @@ def totally_new_function(value: Optional[str]):
print("Hello world")
"""
function_name: str = "NewClass.__init__"
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=[function_name],
@ -814,8 +814,8 @@ def test_code_replacement11() -> None:
'''
function_name: str = "Fu.foo"
parents = [FunctionParent("Fu", "ClassDef")]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [("foo", parents), ("real_bar", parents)]
parents = (FunctionParent("Fu", "ClassDef"),)
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = {("foo", parents), ("real_bar", parents)}
new_code: str = replace_functions_in_file(
source_code=original_code,
original_function_names=[function_name],
@ -854,7 +854,7 @@ def test_code_replacement12() -> None:
pass
'''
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = []
new_code: str = replace_functions_in_file(
source_code=original_code,
original_function_names=["Fu.real_bar"],
@ -891,7 +891,7 @@ def test_test_libcst_code_replacement13() -> None:
"""
function_names: list[str] = ["yet_another_function", "other_function"]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = []
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
@ -1278,7 +1278,7 @@ def cosine_similarity_top_k(
return ret_idxs, scores
'''
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
helper_functions = [
FakeFunctionSource(
@ -1579,7 +1579,7 @@ print("Hello world")
"NewClass.new_function2",
"NestedClass.nested_function",
] # Nested classes should be ignored, even if provided as target
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
@ -1615,7 +1615,7 @@ print("Hello world")
"""
function_names: list[str] = ["NewClass.__init__", "NewClass.__call__", "NewClass.new_function2"]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,