Mirror of https://github.com/codeflash-ai/codeflash.git, synced 2026-05-04 18:25:17 +00:00

Merge pull request #1650 from codeflash-ai/fix/unused-helper-attribute-refs

fix: detect attribute-referenced methods as used in unused helper detection

Commit 5d872e845d
2 changed files with 763 additions and 3 deletions
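The bug in miniature: unused-helper detection previously collected only ast.Call nodes, so a helper method referenced purely through attribute assignment (the callback-dispatch pattern sketched below, with illustrative names rather than code from this repository) looked unused and was wrongly reverted after optimization.

class Parser:
    def parse(self, data: bytes) -> int:
        # _on_comment is never called by name here; it is only stored as a
        # bound-method callback, so a Call-only AST scan would miss it.
        self._state = self._on_comment
        return self._state(data)

    def _on_comment(self, data: bytes) -> int:
        return len(data)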
@@ -774,7 +774,7 @@ def detect_unused_helper_functions(
     # First, analyze imports to build a mapping of imported names to their original qualified names
     imported_names_map = _analyze_imports_in_optimized_code(optimized_ast, code_context)
 
-    # Extract all function calls in the entrypoint function
+    # Extract all function calls and attribute references in the entrypoint function
     called_function_names = {function_to_optimize.function_name}
     for node in ast.walk(entrypoint_function_ast):
         if isinstance(node, ast.Call):
@@ -795,7 +795,6 @@ def detect_unused_helper_functions(
                 # self.method_name() -> add both method_name and ClassName.method_name
                 called_function_names.add(attr_name)
                 # For class methods, also add the qualified name
-                # For class methods, also add the qualified name
                 if hasattr(function_to_optimize, "parents") and function_to_optimize.parents:
                     class_name = function_to_optimize.parents[0].name
                     called_function_names.add(f"{class_name}.{attr_name}")
@@ -808,9 +807,25 @@ def detect_unused_helper_functions(
                 if mapped_names:
                     called_function_names.update(mapped_names)
-            # Handle nested attribute access like obj.attr.method()
             # Handle nested attribute access like obj.attr.method()
             else:
                 called_function_names.add(node.func.attr)
+        elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
+            # Attribute reference without call: e.g. self._parse1 = self._parse_literal
+            # This covers methods used as callbacks, stored in variables, passed as arguments, etc.
+            attr_name = node.attr
+            value_id = node.value.id
+            if value_id == "self":
+                called_function_names.add(attr_name)
+                if hasattr(function_to_optimize, "parents") and function_to_optimize.parents:
+                    class_name = function_to_optimize.parents[0].name
+                    called_function_names.add(f"{class_name}.{attr_name}")
+            else:
+                called_function_names.add(attr_name)
+                full_ref = f"{value_id}.{attr_name}"
+                called_function_names.add(full_ref)
+                mapped_names = imported_names_map.get(full_ref)
+                if mapped_names:
+                    called_function_names.update(mapped_names)
 
     logger.debug(f"Functions called in optimized entrypoint: {called_function_names}")
     logger.debug(f"Imported names mapping: {imported_names_map}")
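Distilled to a runnable sketch (not the codeflash function itself, which additionally consults the import map and the function's class parents), the change amounts to recording ast.Attribute references alongside ast.Call nodes while walking the entrypoint:

import ast

def referenced_names(func_source: str) -> set[str]:
    # Collect method names the function calls OR merely references.
    names: set[str] = set()
    for node in ast.walk(ast.parse(func_source)):
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
            names.add(node.func.attr)  # e.g. self._add_token(...)
        elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
            names.add(node.attr)  # e.g. self._parse1 = self._parse_literal
    return names

# The right-hand side of the assignment is an ast.Attribute, not an ast.Call,
# so _parse_literal now counts as used:
src = "def _parse_main(self):\n    self._parse1 = self._parse_literal\n"
assert "_parse_literal" in referenced_names(src)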

tests/test_mock_candidate_replacement.py (new file, 745 lines)

@@ -0,0 +1,745 @@
"""Test replace_function_and_helpers_with_optimized_code with mock candidate from mock_candidate.txt."""

import tempfile
from pathlib import Path

import pytest

from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.python.context.unused_definition_remover import detect_unused_helper_functions
from codeflash.models.function_types import FunctionParent
from codeflash.models.models import CodeStringsMarkdown
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig

ORIGINAL_SOURCE = '''\
import contextlib
from typing import BinaryIO, TypeVar, Union

_SymbolT = TypeVar("_SymbolT", PSLiteral, PSKeyword)


PSLiteralTable = PSSymbolTable(PSLiteral)
PSKeywordTable = PSSymbolTable(PSKeyword)
LIT = PSLiteralTable.intern
KWD = PSKeywordTable.intern
KEYWORD_DICT_BEGIN = KWD(b"<<")
KEYWORD_DICT_END = KWD(b">>")


PSBaseParserToken = Union[float, bool, PSLiteral, PSKeyword, bytes]


class PSBaseParser:

    def __init__(self, fp: BinaryIO) -> None:
        self.fp = fp
        self.eof = False
        self.seek(0)

    def _parse_main(self, s: bytes, i: int) -> int:
        m = NONSPC.search(s, i)
        if not m:
            return len(s)
        j = m.start(0)
        c = s[j : j + 1]
        self._curtokenpos = self.bufpos + j
        if c == b"%":
            self._curtoken = b"%"
            self._parse1 = self._parse_comment
            return j + 1
        elif c == b"/":
            self._curtoken = b""
            self._parse1 = self._parse_literal
            return j + 1
        elif c in b"-+" or c.isdigit():
            self._curtoken = c
            self._parse1 = self._parse_number
            return j + 1
        elif c == b".":
            self._curtoken = c
            self._parse1 = self._parse_float
            return j + 1
        elif c.isalpha():
            self._curtoken = c
            self._parse1 = self._parse_keyword
            return j + 1
        elif c == b"(":
            self._curtoken = b""
            self.paren = 1
            self._parse1 = self._parse_string
            return j + 1
        elif c == b"<":
            self._curtoken = b""
            self._parse1 = self._parse_wopen
            return j + 1
        elif c == b">":
            self._curtoken = b""
            self._parse1 = self._parse_wclose
            return j + 1
        elif c == b"\\x00":
            return j + 1
        else:
            self._add_token(KWD(c))
            return j + 1

    def _add_token(self, obj: PSBaseParserToken) -> None:
        self._tokens.append((self._curtokenpos, obj))

    def _parse_comment(self, s: bytes, i: int) -> int:
        m = EOL.search(s, i)
        if not m:
            self._curtoken += s[i:]
            return len(s)
        j = m.start(0)
        self._curtoken += s[i:j]
        self._parse1 = self._parse_main
        return j

    def _parse_literal(self, s: bytes, i: int) -> int:
        m = END_LITERAL.search(s, i)
        if not m:
            self._curtoken += s[i:]
            return len(s)
        j = m.start(0)
        self._curtoken += s[i:j]
        c = s[j : j + 1]
        if c == b"#":
            self.hex = b""
            self._parse1 = self._parse_literal_hex
            return j + 1
        try:
            name: str | bytes = str(self._curtoken, "utf-8")
        except Exception:
            name = self._curtoken
        self._add_token(LIT(name))
        self._parse1 = self._parse_main
        return j

    def _parse_number(self, s: bytes, i: int) -> int:
        m = END_NUMBER.search(s, i)
        if not m:
            self._curtoken += s[i:]
            return len(s)
        j = m.start(0)
        self._curtoken += s[i:j]
        c = s[j : j + 1]
        if c == b".":
            self._curtoken += b"."
            self._parse1 = self._parse_float
            return j + 1
        with contextlib.suppress(ValueError):
            self._add_token(int(self._curtoken))
        self._parse1 = self._parse_main
        return j

    def _parse_float(self, s: bytes, i: int) -> int:
        m = END_NUMBER.search(s, i)
        if not m:
            self._curtoken += s[i:]
            return len(s)
        j = m.start(0)
        self._curtoken += s[i:j]
        with contextlib.suppress(ValueError):
            self._add_token(float(self._curtoken))
        self._parse1 = self._parse_main
        return j

    def _parse_keyword(self, s: bytes, i: int) -> int:
        m = END_KEYWORD.search(s, i)
        if m:
            j = m.start(0)
            self._curtoken += s[i:j]
        else:
            self._curtoken += s[i:]
            return len(s)
        if self._curtoken == b"true":
            token: bool | PSKeyword = True
        elif self._curtoken == b"false":
            token = False
        else:
            token = KWD(self._curtoken)
        self._add_token(token)
        self._parse1 = self._parse_main
        return j

    def _parse_string(self, s: bytes, i: int) -> int:
        m = END_STRING.search(s, i)
        if not m:
            self._curtoken += s[i:]
            return len(s)
        j = m.start(0)
        self._curtoken += s[i:j]
        c = s[j : j + 1]
        if c == b"\\\\":
            self.oct = b""
            self._parse1 = self._parse_string_1
            return j + 1
        if c == b"(":
            self.paren += 1
            self._curtoken += c
            return j + 1
        if c == b")":
            self.paren -= 1
            if self.paren:
                self._curtoken += c
                return j + 1
        self._add_token(self._curtoken)
        self._parse1 = self._parse_main
        return j + 1

    def _parse_wopen(self, s: bytes, i: int) -> int:
        c = s[i : i + 1]
        if c == b"<":
            self._add_token(KEYWORD_DICT_BEGIN)
            self._parse1 = self._parse_main
            i += 1
        else:
            self._parse1 = self._parse_hexstring
        return i

    def _parse_wclose(self, s: bytes, i: int) -> int:
        c = s[i : i + 1]
        if c == b">":
            self._add_token(KEYWORD_DICT_END)
            i += 1
        self._parse1 = self._parse_main
        return i
'''

MOCK_CANDIDATE_MARKDOWN = '''\
```python
#!/usr/bin/env python3


import contextlib
from typing import BinaryIO, TypeVar, Union

_SymbolT = TypeVar("_SymbolT", PSLiteral, PSKeyword)


PSLiteralTable = PSSymbolTable(PSLiteral)
PSKeywordTable = PSSymbolTable(PSKeyword)
LIT = PSLiteralTable.intern
KWD = PSKeywordTable.intern
KEYWORD_DICT_BEGIN = KWD(b"<<")
KEYWORD_DICT_END = KWD(b">>")


PSBaseParserToken = Union[float, bool, PSLiteral, PSKeyword, bytes]


class PSBaseParser:

    def __init__(self, fp: BinaryIO) -> None:
        self.fp = fp
        self.eof = False
        self.seek(0)

    def _parse_main(self, s: bytes, i: int) -> int:
        m = NONSPC.search(s, i)
        if not m:
            return len(s)
        j = m.start(0)
        # Use integer byte access to avoid creating a new one-byte bytes object.
        c_int = s[j]
        c_byte = bytes((c_int,))
        self._curtokenpos = self.bufpos + j
        if c_int == 37:  # b"%"
            self._curtoken = b"%"
            self._parse1 = self._parse_comment
            return j + 1
        elif c_int == 47:  # b"/"
            self._curtoken = b""
            self._parse1 = self._parse_literal
            return j + 1
        # b"-" is 45, b"+" is 43
        elif c_int == 45 or c_int == 43 or (48 <= c_int <= 57):
            self._curtoken = c_byte
            self._parse1 = self._parse_number
            return j + 1
        elif c_int == 46:  # b"."
            self._curtoken = c_byte
            self._parse1 = self._parse_float
            return j + 1
        # ASCII alphabetic check
        elif (65 <= c_int <= 90) or (97 <= c_int <= 122):
            self._curtoken = c_byte
            self._parse1 = self._parse_keyword
            return j + 1
        elif c_int == 40:  # b"("
            self._curtoken = b""
            self.paren = 1
            self._parse1 = self._parse_string
            return j + 1
        elif c_int == 60:  # b"<"
            self._curtoken = b""
            self._parse1 = self._parse_wopen
            return j + 1
        elif c_int == 62:  # b">"
            self._curtoken = b""
            self._parse1 = self._parse_wclose
            return j + 1
        elif c_int == 0:  # b"\\x00"
            return j + 1
        else:
            self._add_token(KWD(c_byte))
            return j + 1

    def _add_token(self, obj: PSBaseParserToken) -> None:
        self._tokens.append((self._curtokenpos, obj))

    def _parse_comment(self, s: bytes, i: int) -> int:
        m = EOL.search(s, i)
        if not m:
            self._curtoken += s[i:]
            return len(s)
        j = m.start(0)
        self._curtoken += s[i:j]
        self._parse1 = self._parse_main
        # We ignore comments.
        # self._tokens.append(self._curtoken)
        return j

    def _parse_literal(self, s: bytes, i: int) -> int:
        m = END_LITERAL.search(s, i)
        if not m:
            self._curtoken += s[i:]
            return len(s)
        j = m.start(0)
        self._curtoken += s[i:j]
        c_int = s[j]
        if c_int == 35:  # b"#"
            self.hex = b""
            self._parse1 = self._parse_literal_hex
            return j + 1
        try:
            name: str | bytes = str(self._curtoken, "utf-8")
        except Exception:
            name = self._curtoken
        self._add_token(LIT(name))
        self._parse1 = self._parse_main
        return j

    def _parse_number(self, s: bytes, i: int) -> int:
        m = END_NUMBER.search(s, i)
        if not m:
            self._curtoken += s[i:]
            return len(s)
        j = m.start(0)
        self._curtoken += s[i:j]
        c_int = s[j]
        if c_int == 46:  # b"."
            self._curtoken += b"."
            self._parse1 = self._parse_float
            return j + 1
        with contextlib.suppress(ValueError):
            self._add_token(int(self._curtoken))
        self._parse1 = self._parse_main
        return j

    def _parse_float(self, s: bytes, i: int) -> int:
        m = END_NUMBER.search(s, i)
        if not m:
            self._curtoken += s[i:]
            return len(s)
        j = m.start(0)
        self._curtoken += s[i:j]
        with contextlib.suppress(ValueError):
            self._add_token(float(self._curtoken))
        self._parse1 = self._parse_main
        return j

    def _parse_keyword(self, s: bytes, i: int) -> int:
        m = END_KEYWORD.search(s, i)
        if m:
            j = m.start(0)
            self._curtoken += s[i:j]
        else:
            self._curtoken += s[i:]
            return len(s)
        if self._curtoken == b"true":
            token: bool | PSKeyword = True
        elif self._curtoken == b"false":
            token = False
        else:
            token = KWD(self._curtoken)
        self._add_token(token)
        self._parse1 = self._parse_main
        return j

    def _parse_string(self, s: bytes, i: int) -> int:
        m = END_STRING.search(s, i)
        if not m:
            self._curtoken += s[i:]
            return len(s)
        j = m.start(0)
        self._curtoken += s[i:j]
        c_int = s[j]
        if c_int == 92:  # b"\\\\"
            self.oct = b""
            self._parse1 = self._parse_string_1
            return j + 1
        if c_int == 40:  # b"("
            self.paren += 1
            # append the literal "(" byte
            self._curtoken += b"("
            return j + 1
        if c_int == 41:  # b")"
            self.paren -= 1
            if self.paren:
                # WTF, they said balanced parens need no special treatment.
                self._curtoken += b")"
                return j + 1
        self._add_token(self._curtoken)
        self._parse1 = self._parse_main
        return j + 1

    def _parse_wopen(self, s: bytes, i: int) -> int:
        c_int = s[i]
        if c_int == 60:  # b"<"
            self._add_token(KEYWORD_DICT_BEGIN)
            self._parse1 = self._parse_main
            i += 1
        else:
            self._parse1 = self._parse_hexstring
        return i

    def _parse_wclose(self, s: bytes, i: int) -> int:
        c_int = s[i]
        if c_int == 62:  # b">"
            self._add_token(KEYWORD_DICT_END)
            i += 1
        self._parse1 = self._parse_main
        return i
```
'''

EXPECTED_OUTPUT = '''\
import contextlib
from typing import BinaryIO, TypeVar, Union

_SymbolT = TypeVar("_SymbolT", PSLiteral, PSKeyword)


PSLiteralTable = PSSymbolTable(PSLiteral)
PSKeywordTable = PSSymbolTable(PSKeyword)
LIT = PSLiteralTable.intern
KWD = PSKeywordTable.intern
KEYWORD_DICT_BEGIN = KWD(b"<<")
KEYWORD_DICT_END = KWD(b">>")


PSBaseParserToken = Union[float, bool, PSLiteral, PSKeyword, bytes]


class PSBaseParser:

    def __init__(self, fp: BinaryIO) -> None:
        self.fp = fp
        self.eof = False
        self.seek(0)

    def _parse_main(self, s: bytes, i: int) -> int:
        m = NONSPC.search(s, i)
        if not m:
            return len(s)
        j = m.start(0)
        # Use integer byte access to avoid creating a new one-byte bytes object.
        c_int = s[j]
        c_byte = bytes((c_int,))
        self._curtokenpos = self.bufpos + j
        if c_int == 37:  # b"%"
            self._curtoken = b"%"
            self._parse1 = self._parse_comment
            return j + 1
        elif c_int == 47:  # b"/"
            self._curtoken = b""
            self._parse1 = self._parse_literal
            return j + 1
        # b"-" is 45, b"+" is 43
        elif c_int == 45 or c_int == 43 or (48 <= c_int <= 57):
            self._curtoken = c_byte
            self._parse1 = self._parse_number
            return j + 1
        elif c_int == 46:  # b"."
            self._curtoken = c_byte
            self._parse1 = self._parse_float
            return j + 1
        # ASCII alphabetic check
        elif (65 <= c_int <= 90) or (97 <= c_int <= 122):
            self._curtoken = c_byte
            self._parse1 = self._parse_keyword
            return j + 1
        elif c_int == 40:  # b"("
            self._curtoken = b""
            self.paren = 1
            self._parse1 = self._parse_string
            return j + 1
        elif c_int == 60:  # b"<"
            self._curtoken = b""
            self._parse1 = self._parse_wopen
            return j + 1
        elif c_int == 62:  # b">"
            self._curtoken = b""
            self._parse1 = self._parse_wclose
            return j + 1
        elif c_int == 0:  # b"\\x00"
            return j + 1
        else:
            self._add_token(KWD(c_byte))
            return j + 1

    def _add_token(self, obj: PSBaseParserToken) -> None:
        self._tokens.append((self._curtokenpos, obj))

    def _parse_comment(self, s: bytes, i: int) -> int:
        m = EOL.search(s, i)
        if not m:
            self._curtoken += s[i:]
            return len(s)
        j = m.start(0)
        self._curtoken += s[i:j]
        self._parse1 = self._parse_main
        # We ignore comments.
        # self._tokens.append(self._curtoken)
        return j

    def _parse_literal(self, s: bytes, i: int) -> int:
        m = END_LITERAL.search(s, i)
        if not m:
            self._curtoken += s[i:]
            return len(s)
        j = m.start(0)
        self._curtoken += s[i:j]
        c_int = s[j]
        if c_int == 35:  # b"#"
            self.hex = b""
            self._parse1 = self._parse_literal_hex
            return j + 1
        try:
            name: str | bytes = str(self._curtoken, "utf-8")
        except Exception:
            name = self._curtoken
        self._add_token(LIT(name))
        self._parse1 = self._parse_main
        return j

    def _parse_number(self, s: bytes, i: int) -> int:
        m = END_NUMBER.search(s, i)
        if not m:
            self._curtoken += s[i:]
            return len(s)
        j = m.start(0)
        self._curtoken += s[i:j]
        c_int = s[j]
        if c_int == 46:  # b"."
            self._curtoken += b"."
            self._parse1 = self._parse_float
            return j + 1
        with contextlib.suppress(ValueError):
            self._add_token(int(self._curtoken))
        self._parse1 = self._parse_main
        return j

    def _parse_float(self, s: bytes, i: int) -> int:
        m = END_NUMBER.search(s, i)
        if not m:
            self._curtoken += s[i:]
            return len(s)
        j = m.start(0)
        self._curtoken += s[i:j]
        with contextlib.suppress(ValueError):
            self._add_token(float(self._curtoken))
        self._parse1 = self._parse_main
        return j

    def _parse_keyword(self, s: bytes, i: int) -> int:
        m = END_KEYWORD.search(s, i)
        if m:
            j = m.start(0)
            self._curtoken += s[i:j]
        else:
            self._curtoken += s[i:]
            return len(s)
        if self._curtoken == b"true":
            token: bool | PSKeyword = True
        elif self._curtoken == b"false":
            token = False
        else:
            token = KWD(self._curtoken)
        self._add_token(token)
        self._parse1 = self._parse_main
        return j

    def _parse_string(self, s: bytes, i: int) -> int:
        m = END_STRING.search(s, i)
        if not m:
            self._curtoken += s[i:]
            return len(s)
        j = m.start(0)
        self._curtoken += s[i:j]
        c_int = s[j]
        if c_int == 92:  # b"\\\\"
            self.oct = b""
            self._parse1 = self._parse_string_1
            return j + 1
        if c_int == 40:  # b"("
            self.paren += 1
            # append the literal "(" byte
            self._curtoken += b"("
            return j + 1
        if c_int == 41:  # b")"
            self.paren -= 1
            if self.paren:
                # WTF, they said balanced parens need no special treatment.
                self._curtoken += b")"
                return j + 1
        self._add_token(self._curtoken)
        self._parse1 = self._parse_main
        return j + 1

    def _parse_wopen(self, s: bytes, i: int) -> int:
        c_int = s[i]
        if c_int == 60:  # b"<"
            self._add_token(KEYWORD_DICT_BEGIN)
            self._parse1 = self._parse_main
            i += 1
        else:
            self._parse1 = self._parse_hexstring
        return i

    def _parse_wclose(self, s: bytes, i: int) -> int:
        c_int = s[i]
        if c_int == 62:  # b">"
            self._add_token(KEYWORD_DICT_END)
            i += 1
        self._parse1 = self._parse_main
        return i
'''


@pytest.fixture
def temp_project():
    temp_dir = Path(tempfile.mkdtemp())
    source_file = temp_dir / "psparser.py"
    source_file.write_text(ORIGINAL_SOURCE, encoding="utf-8")

    test_cfg = TestConfig(
        tests_root=temp_dir / "tests",
        tests_project_rootdir=temp_dir,
        project_root_path=temp_dir,
        test_framework="pytest",
        pytest_cmd="pytest",
    )

    yield temp_dir, source_file, test_cfg

    import shutil
    shutil.rmtree(temp_dir, ignore_errors=True)


def run_replacement(temp_project):
    """Helper: run the full replacement pipeline and return (optimizer, code_context, final_content)."""
    temp_dir, source_file, test_cfg = temp_project

    function_to_optimize = FunctionToOptimize(
        file_path=source_file,
        function_name="_parse_main",
        parents=[FunctionParent(name="PSBaseParser", type="ClassDef")],
    )

    optimizer = FunctionOptimizer(
        function_to_optimize=function_to_optimize,
        test_cfg=test_cfg,
        function_to_optimize_source_code=source_file.read_text(encoding="utf-8"),
    )

    ctx_result = optimizer.get_code_optimization_context()
    assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"
    code_context = ctx_result.unwrap()

    original_content = source_file.read_text(encoding="utf-8")
    original_helper_code = {source_file: original_content}
    optimized_code = CodeStringsMarkdown.parse_markdown_code(MOCK_CANDIDATE_MARKDOWN)

    did_update = optimizer.replace_function_and_helpers_with_optimized_code(
        code_context, optimized_code, original_helper_code
    )
    assert did_update, "Expected the code to be updated"

    final_content = source_file.read_text(encoding="utf-8")
    return optimizer, code_context, final_content


def test_replace_with_mock_candidate(temp_project):
    """Verify replace_function_and_helpers_with_optimized_code produces the exact expected output.

    The code context detects ALL sibling methods as helpers of _parse_main.
    replace_function_definitions_in_module replaces ALL method bodies.
    detect_unused_helper_functions correctly recognizes methods referenced via attribute
    assignment (self._parse1 = self._parse_literal) as used, so they are NOT reverted.
    """
    _, code_context, final_content = run_replacement(temp_project)

    # Code context correctly detects ALL methods as helpers
    helper_names = {h.qualified_name for h in code_context.helper_functions}
    assert helper_names == {
        "PSBaseParser._parse_comment",
        "PSBaseParser._parse_literal",
        "PSBaseParser._parse_number",
        "PSBaseParser._parse_float",
        "PSBaseParser._parse_keyword",
        "PSBaseParser._parse_string",
        "PSBaseParser._parse_wopen",
        "PSBaseParser._parse_wclose",
        "PSBaseParser._add_token",
        "KWD",
    }

    # The final content should match the expected output exactly
    assert final_content == EXPECTED_OUTPUT


def test_detect_unused_helpers_handles_attribute_refs(temp_project):
    """Verify detect_unused_helper_functions recognizes methods referenced via attribute assignment.

    When _parse_main does `self._parse1 = self._parse_literal`, the method is referenced as
    an ast.Attribute value (not an ast.Call). The detection should recognize these as used.
    """
    temp_dir, source_file, test_cfg = temp_project

    function_to_optimize = FunctionToOptimize(
        file_path=source_file,
        function_name="_parse_main",
        parents=[FunctionParent(name="PSBaseParser", type="ClassDef")],
    )

    optimizer = FunctionOptimizer(
        function_to_optimize=function_to_optimize,
        test_cfg=test_cfg,
        function_to_optimize_source_code=source_file.read_text(encoding="utf-8"),
    )

    ctx_result = optimizer.get_code_optimization_context()
    assert ctx_result.is_successful()
    code_context = ctx_result.unwrap()

    optimized_code = CodeStringsMarkdown.parse_markdown_code(MOCK_CANDIDATE_MARKDOWN)

    unused_helpers = detect_unused_helper_functions(
        optimizer.function_to_optimize, code_context, optimized_code
    )
    unused_names = {h.qualified_name for h in unused_helpers}

    # No helpers should be detected as unused — all are either directly called or
    # referenced via attribute assignment (self._parse1 = self._parse_X)
    assert unused_names == set(), f"Expected no unused helpers, got: {unused_names}"


def test_replace_produces_valid_python(temp_project):
    """Verify the final output is valid, parseable Python."""
    _, _, final_content = run_replacement(temp_project)

    import ast
    ast.parse(final_content)
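The tests above assert the full expected file; the class-qualification half of the fix (the ClassName.method entries added for self references) can also be shown in isolation. A hedged sketch, where qualified_self_refs is a hypothetical helper and not an API of this repository:

import ast

def qualified_self_refs(func_source: str, class_name: str) -> set[str]:
    # For every `self.<name>` reference, record both the bare name and the
    # class-qualified name, mirroring the diff's handling of class methods.
    refs: set[str] = set()
    for node in ast.walk(ast.parse(func_source)):
        if (
            isinstance(node, ast.Attribute)
            and isinstance(node.value, ast.Name)
            and node.value.id == "self"
        ):
            refs.add(node.attr)
            refs.add(f"{class_name}.{node.attr}")
    return refs

src = "def _parse_main(self):\n    self._parse1 = self._parse_comment\n"
# Matches the qualified helper names the tests assert, e.g. "PSBaseParser._parse_comment".
assert "PSBaseParser._parse_comment" in qualified_self_refs(src, "PSBaseParser")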