745 lines
23 KiB
Python
745 lines
23 KiB
Python
"""Test replace_function_and_helpers_with_optimized_code with mock candidate from mock_candidate.txt."""
|
|
|
|
import tempfile
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
|
from codeflash.languages.python.context.unused_definition_remover import detect_unused_helper_functions
|
|
from codeflash.models.function_types import FunctionParent
|
|
from codeflash.models.models import CodeStringsMarkdown
|
|
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
|
|
from codeflash.verification.verification_utils import TestConfig
|
|
|
|
ORIGINAL_SOURCE = '''\
|
|
import contextlib
|
|
from typing import BinaryIO, TypeVar, Union
|
|
|
|
_SymbolT = TypeVar("_SymbolT", PSLiteral, PSKeyword)
|
|
|
|
|
|
PSLiteralTable = PSSymbolTable(PSLiteral)
|
|
PSKeywordTable = PSSymbolTable(PSKeyword)
|
|
LIT = PSLiteralTable.intern
|
|
KWD = PSKeywordTable.intern
|
|
KEYWORD_DICT_BEGIN = KWD(b"<<")
|
|
KEYWORD_DICT_END = KWD(b">>")
|
|
|
|
|
|
PSBaseParserToken = Union[float, bool, PSLiteral, PSKeyword, bytes]
|
|
|
|
|
|
class PSBaseParser:
|
|
|
|
def __init__(self, fp: BinaryIO) -> None:
|
|
self.fp = fp
|
|
self.eof = False
|
|
self.seek(0)
|
|
|
|
def _parse_main(self, s: bytes, i: int) -> int:
|
|
m = NONSPC.search(s, i)
|
|
if not m:
|
|
return len(s)
|
|
j = m.start(0)
|
|
c = s[j : j + 1]
|
|
self._curtokenpos = self.bufpos + j
|
|
if c == b"%":
|
|
self._curtoken = b"%"
|
|
self._parse1 = self._parse_comment
|
|
return j + 1
|
|
elif c == b"/":
|
|
self._curtoken = b""
|
|
self._parse1 = self._parse_literal
|
|
return j + 1
|
|
elif c in b"-+" or c.isdigit():
|
|
self._curtoken = c
|
|
self._parse1 = self._parse_number
|
|
return j + 1
|
|
elif c == b".":
|
|
self._curtoken = c
|
|
self._parse1 = self._parse_float
|
|
return j + 1
|
|
elif c.isalpha():
|
|
self._curtoken = c
|
|
self._parse1 = self._parse_keyword
|
|
return j + 1
|
|
elif c == b"(":
|
|
self._curtoken = b""
|
|
self.paren = 1
|
|
self._parse1 = self._parse_string
|
|
return j + 1
|
|
elif c == b"<":
|
|
self._curtoken = b""
|
|
self._parse1 = self._parse_wopen
|
|
return j + 1
|
|
elif c == b">":
|
|
self._curtoken = b""
|
|
self._parse1 = self._parse_wclose
|
|
return j + 1
|
|
elif c == b"\\x00":
|
|
return j + 1
|
|
else:
|
|
self._add_token(KWD(c))
|
|
return j + 1
|
|
|
|
def _add_token(self, obj: PSBaseParserToken) -> None:
|
|
self._tokens.append((self._curtokenpos, obj))
|
|
|
|
def _parse_comment(self, s: bytes, i: int) -> int:
|
|
m = EOL.search(s, i)
|
|
if not m:
|
|
self._curtoken += s[i:]
|
|
return len(s)
|
|
j = m.start(0)
|
|
self._curtoken += s[i:j]
|
|
self._parse1 = self._parse_main
|
|
return j
|
|
|
|
def _parse_literal(self, s: bytes, i: int) -> int:
|
|
m = END_LITERAL.search(s, i)
|
|
if not m:
|
|
self._curtoken += s[i:]
|
|
return len(s)
|
|
j = m.start(0)
|
|
self._curtoken += s[i:j]
|
|
c = s[j : j + 1]
|
|
if c == b"#":
|
|
self.hex = b""
|
|
self._parse1 = self._parse_literal_hex
|
|
return j + 1
|
|
try:
|
|
name: str | bytes = str(self._curtoken, "utf-8")
|
|
except Exception:
|
|
name = self._curtoken
|
|
self._add_token(LIT(name))
|
|
self._parse1 = self._parse_main
|
|
return j
|
|
|
|
def _parse_number(self, s: bytes, i: int) -> int:
|
|
m = END_NUMBER.search(s, i)
|
|
if not m:
|
|
self._curtoken += s[i:]
|
|
return len(s)
|
|
j = m.start(0)
|
|
self._curtoken += s[i:j]
|
|
c = s[j : j + 1]
|
|
if c == b".":
|
|
self._curtoken += b"."
|
|
self._parse1 = self._parse_float
|
|
return j + 1
|
|
with contextlib.suppress(ValueError):
|
|
self._add_token(int(self._curtoken))
|
|
self._parse1 = self._parse_main
|
|
return j
|
|
|
|
def _parse_float(self, s: bytes, i: int) -> int:
|
|
m = END_NUMBER.search(s, i)
|
|
if not m:
|
|
self._curtoken += s[i:]
|
|
return len(s)
|
|
j = m.start(0)
|
|
self._curtoken += s[i:j]
|
|
with contextlib.suppress(ValueError):
|
|
self._add_token(float(self._curtoken))
|
|
self._parse1 = self._parse_main
|
|
return j
|
|
|
|
def _parse_keyword(self, s: bytes, i: int) -> int:
|
|
m = END_KEYWORD.search(s, i)
|
|
if m:
|
|
j = m.start(0)
|
|
self._curtoken += s[i:j]
|
|
else:
|
|
self._curtoken += s[i:]
|
|
return len(s)
|
|
if self._curtoken == b"true":
|
|
token: bool | PSKeyword = True
|
|
elif self._curtoken == b"false":
|
|
token = False
|
|
else:
|
|
token = KWD(self._curtoken)
|
|
self._add_token(token)
|
|
self._parse1 = self._parse_main
|
|
return j
|
|
|
|
def _parse_string(self, s: bytes, i: int) -> int:
|
|
m = END_STRING.search(s, i)
|
|
if not m:
|
|
self._curtoken += s[i:]
|
|
return len(s)
|
|
j = m.start(0)
|
|
self._curtoken += s[i:j]
|
|
c = s[j : j + 1]
|
|
if c == b"\\\\":
|
|
self.oct = b""
|
|
self._parse1 = self._parse_string_1
|
|
return j + 1
|
|
if c == b"(":
|
|
self.paren += 1
|
|
self._curtoken += c
|
|
return j + 1
|
|
if c == b")":
|
|
self.paren -= 1
|
|
if self.paren:
|
|
self._curtoken += c
|
|
return j + 1
|
|
self._add_token(self._curtoken)
|
|
self._parse1 = self._parse_main
|
|
return j + 1
|
|
|
|
def _parse_wopen(self, s: bytes, i: int) -> int:
|
|
c = s[i : i + 1]
|
|
if c == b"<":
|
|
self._add_token(KEYWORD_DICT_BEGIN)
|
|
self._parse1 = self._parse_main
|
|
i += 1
|
|
else:
|
|
self._parse1 = self._parse_hexstring
|
|
return i
|
|
|
|
def _parse_wclose(self, s: bytes, i: int) -> int:
|
|
c = s[i : i + 1]
|
|
if c == b">":
|
|
self._add_token(KEYWORD_DICT_END)
|
|
i += 1
|
|
self._parse1 = self._parse_main
|
|
return i
|
|
'''
|
|
|
|
MOCK_CANDIDATE_MARKDOWN = '''\
|
|
```python
|
|
#!/usr/bin/env python3
|
|
|
|
|
|
import contextlib
|
|
from typing import BinaryIO, TypeVar, Union
|
|
|
|
_SymbolT = TypeVar("_SymbolT", PSLiteral, PSKeyword)
|
|
|
|
|
|
PSLiteralTable = PSSymbolTable(PSLiteral)
|
|
PSKeywordTable = PSSymbolTable(PSKeyword)
|
|
LIT = PSLiteralTable.intern
|
|
KWD = PSKeywordTable.intern
|
|
KEYWORD_DICT_BEGIN = KWD(b"<<")
|
|
KEYWORD_DICT_END = KWD(b">>")
|
|
|
|
|
|
PSBaseParserToken = Union[float, bool, PSLiteral, PSKeyword, bytes]
|
|
|
|
|
|
class PSBaseParser:
|
|
|
|
def __init__(self, fp: BinaryIO) -> None:
|
|
self.fp = fp
|
|
self.eof = False
|
|
self.seek(0)
|
|
|
|
def _parse_main(self, s: bytes, i: int) -> int:
|
|
m = NONSPC.search(s, i)
|
|
if not m:
|
|
return len(s)
|
|
j = m.start(0)
|
|
# Use integer byte access to avoid creating a new one-byte bytes object.
|
|
c_int = s[j]
|
|
c_byte = bytes((c_int,))
|
|
self._curtokenpos = self.bufpos + j
|
|
if c_int == 37: # b"%"
|
|
self._curtoken = b"%"
|
|
self._parse1 = self._parse_comment
|
|
return j + 1
|
|
elif c_int == 47: # b"/"
|
|
self._curtoken = b""
|
|
self._parse1 = self._parse_literal
|
|
return j + 1
|
|
# b"-" is 45, b"+" is 43
|
|
elif c_int == 45 or c_int == 43 or (48 <= c_int <= 57):
|
|
self._curtoken = c_byte
|
|
self._parse1 = self._parse_number
|
|
return j + 1
|
|
elif c_int == 46: # b"."
|
|
self._curtoken = c_byte
|
|
self._parse1 = self._parse_float
|
|
return j + 1
|
|
# ASCII alphabetic check
|
|
elif (65 <= c_int <= 90) or (97 <= c_int <= 122):
|
|
self._curtoken = c_byte
|
|
self._parse1 = self._parse_keyword
|
|
return j + 1
|
|
elif c_int == 40: # b"("
|
|
self._curtoken = b""
|
|
self.paren = 1
|
|
self._parse1 = self._parse_string
|
|
return j + 1
|
|
elif c_int == 60: # b"<"
|
|
self._curtoken = b""
|
|
self._parse1 = self._parse_wopen
|
|
return j + 1
|
|
elif c_int == 62: # b">"
|
|
self._curtoken = b""
|
|
self._parse1 = self._parse_wclose
|
|
return j + 1
|
|
elif c_int == 0: # b"\\x00"
|
|
return j + 1
|
|
else:
|
|
self._add_token(KWD(c_byte))
|
|
return j + 1
|
|
|
|
def _add_token(self, obj: PSBaseParserToken) -> None:
|
|
self._tokens.append((self._curtokenpos, obj))
|
|
|
|
def _parse_comment(self, s: bytes, i: int) -> int:
|
|
m = EOL.search(s, i)
|
|
if not m:
|
|
self._curtoken += s[i:]
|
|
return len(s)
|
|
j = m.start(0)
|
|
self._curtoken += s[i:j]
|
|
self._parse1 = self._parse_main
|
|
# We ignore comments.
|
|
# self._tokens.append(self._curtoken)
|
|
return j
|
|
|
|
def _parse_literal(self, s: bytes, i: int) -> int:
|
|
m = END_LITERAL.search(s, i)
|
|
if not m:
|
|
self._curtoken += s[i:]
|
|
return len(s)
|
|
j = m.start(0)
|
|
self._curtoken += s[i:j]
|
|
c_int = s[j]
|
|
if c_int == 35: # b"#"
|
|
self.hex = b""
|
|
self._parse1 = self._parse_literal_hex
|
|
return j + 1
|
|
try:
|
|
name: str | bytes = str(self._curtoken, "utf-8")
|
|
except Exception:
|
|
name = self._curtoken
|
|
self._add_token(LIT(name))
|
|
self._parse1 = self._parse_main
|
|
return j
|
|
|
|
def _parse_number(self, s: bytes, i: int) -> int:
|
|
m = END_NUMBER.search(s, i)
|
|
if not m:
|
|
self._curtoken += s[i:]
|
|
return len(s)
|
|
j = m.start(0)
|
|
self._curtoken += s[i:j]
|
|
c_int = s[j]
|
|
if c_int == 46: # b"."
|
|
self._curtoken += b"."
|
|
self._parse1 = self._parse_float
|
|
return j + 1
|
|
with contextlib.suppress(ValueError):
|
|
self._add_token(int(self._curtoken))
|
|
self._parse1 = self._parse_main
|
|
return j
|
|
|
|
def _parse_float(self, s: bytes, i: int) -> int:
|
|
m = END_NUMBER.search(s, i)
|
|
if not m:
|
|
self._curtoken += s[i:]
|
|
return len(s)
|
|
j = m.start(0)
|
|
self._curtoken += s[i:j]
|
|
with contextlib.suppress(ValueError):
|
|
self._add_token(float(self._curtoken))
|
|
self._parse1 = self._parse_main
|
|
return j
|
|
|
|
def _parse_keyword(self, s: bytes, i: int) -> int:
|
|
m = END_KEYWORD.search(s, i)
|
|
if m:
|
|
j = m.start(0)
|
|
self._curtoken += s[i:j]
|
|
else:
|
|
self._curtoken += s[i:]
|
|
return len(s)
|
|
if self._curtoken == b"true":
|
|
token: bool | PSKeyword = True
|
|
elif self._curtoken == b"false":
|
|
token = False
|
|
else:
|
|
token = KWD(self._curtoken)
|
|
self._add_token(token)
|
|
self._parse1 = self._parse_main
|
|
return j
|
|
|
|
def _parse_string(self, s: bytes, i: int) -> int:
|
|
m = END_STRING.search(s, i)
|
|
if not m:
|
|
self._curtoken += s[i:]
|
|
return len(s)
|
|
j = m.start(0)
|
|
self._curtoken += s[i:j]
|
|
c_int = s[j]
|
|
if c_int == 92: # b"\\\\"
|
|
self.oct = b""
|
|
self._parse1 = self._parse_string_1
|
|
return j + 1
|
|
if c_int == 40: # b"("
|
|
self.paren += 1
|
|
# append the literal "(" byte
|
|
self._curtoken += b"("
|
|
return j + 1
|
|
if c_int == 41: # b")"
|
|
self.paren -= 1
|
|
if self.paren:
|
|
# WTF, they said balanced parens need no special treatment.
|
|
self._curtoken += b")"
|
|
return j + 1
|
|
self._add_token(self._curtoken)
|
|
self._parse1 = self._parse_main
|
|
return j + 1
|
|
|
|
def _parse_wopen(self, s: bytes, i: int) -> int:
|
|
c_int = s[i]
|
|
if c_int == 60: # b"<"
|
|
self._add_token(KEYWORD_DICT_BEGIN)
|
|
self._parse1 = self._parse_main
|
|
i += 1
|
|
else:
|
|
self._parse1 = self._parse_hexstring
|
|
return i
|
|
|
|
def _parse_wclose(self, s: bytes, i: int) -> int:
|
|
c_int = s[i]
|
|
if c_int == 62: # b">"
|
|
self._add_token(KEYWORD_DICT_END)
|
|
i += 1
|
|
self._parse1 = self._parse_main
|
|
return i
|
|
```
|
|
'''
|
|
|
|
EXPECTED_OUTPUT = '''\
|
|
import contextlib
|
|
from typing import BinaryIO, TypeVar, Union
|
|
|
|
_SymbolT = TypeVar("_SymbolT", PSLiteral, PSKeyword)
|
|
|
|
|
|
PSLiteralTable = PSSymbolTable(PSLiteral)
|
|
PSKeywordTable = PSSymbolTable(PSKeyword)
|
|
LIT = PSLiteralTable.intern
|
|
KWD = PSKeywordTable.intern
|
|
KEYWORD_DICT_BEGIN = KWD(b"<<")
|
|
KEYWORD_DICT_END = KWD(b">>")
|
|
|
|
|
|
PSBaseParserToken = Union[float, bool, PSLiteral, PSKeyword, bytes]
|
|
|
|
|
|
class PSBaseParser:
|
|
|
|
def __init__(self, fp: BinaryIO) -> None:
|
|
self.fp = fp
|
|
self.eof = False
|
|
self.seek(0)
|
|
|
|
def _parse_main(self, s: bytes, i: int) -> int:
|
|
m = NONSPC.search(s, i)
|
|
if not m:
|
|
return len(s)
|
|
j = m.start(0)
|
|
# Use integer byte access to avoid creating a new one-byte bytes object.
|
|
c_int = s[j]
|
|
c_byte = bytes((c_int,))
|
|
self._curtokenpos = self.bufpos + j
|
|
if c_int == 37: # b"%"
|
|
self._curtoken = b"%"
|
|
self._parse1 = self._parse_comment
|
|
return j + 1
|
|
elif c_int == 47: # b"/"
|
|
self._curtoken = b""
|
|
self._parse1 = self._parse_literal
|
|
return j + 1
|
|
# b"-" is 45, b"+" is 43
|
|
elif c_int == 45 or c_int == 43 or (48 <= c_int <= 57):
|
|
self._curtoken = c_byte
|
|
self._parse1 = self._parse_number
|
|
return j + 1
|
|
elif c_int == 46: # b"."
|
|
self._curtoken = c_byte
|
|
self._parse1 = self._parse_float
|
|
return j + 1
|
|
# ASCII alphabetic check
|
|
elif (65 <= c_int <= 90) or (97 <= c_int <= 122):
|
|
self._curtoken = c_byte
|
|
self._parse1 = self._parse_keyword
|
|
return j + 1
|
|
elif c_int == 40: # b"("
|
|
self._curtoken = b""
|
|
self.paren = 1
|
|
self._parse1 = self._parse_string
|
|
return j + 1
|
|
elif c_int == 60: # b"<"
|
|
self._curtoken = b""
|
|
self._parse1 = self._parse_wopen
|
|
return j + 1
|
|
elif c_int == 62: # b">"
|
|
self._curtoken = b""
|
|
self._parse1 = self._parse_wclose
|
|
return j + 1
|
|
elif c_int == 0: # b"\\x00"
|
|
return j + 1
|
|
else:
|
|
self._add_token(KWD(c_byte))
|
|
return j + 1
|
|
|
|
def _add_token(self, obj: PSBaseParserToken) -> None:
|
|
self._tokens.append((self._curtokenpos, obj))
|
|
|
|
def _parse_comment(self, s: bytes, i: int) -> int:
|
|
m = EOL.search(s, i)
|
|
if not m:
|
|
self._curtoken += s[i:]
|
|
return len(s)
|
|
j = m.start(0)
|
|
self._curtoken += s[i:j]
|
|
self._parse1 = self._parse_main
|
|
# We ignore comments.
|
|
# self._tokens.append(self._curtoken)
|
|
return j
|
|
|
|
def _parse_literal(self, s: bytes, i: int) -> int:
|
|
m = END_LITERAL.search(s, i)
|
|
if not m:
|
|
self._curtoken += s[i:]
|
|
return len(s)
|
|
j = m.start(0)
|
|
self._curtoken += s[i:j]
|
|
c_int = s[j]
|
|
if c_int == 35: # b"#"
|
|
self.hex = b""
|
|
self._parse1 = self._parse_literal_hex
|
|
return j + 1
|
|
try:
|
|
name: str | bytes = str(self._curtoken, "utf-8")
|
|
except Exception:
|
|
name = self._curtoken
|
|
self._add_token(LIT(name))
|
|
self._parse1 = self._parse_main
|
|
return j
|
|
|
|
def _parse_number(self, s: bytes, i: int) -> int:
|
|
m = END_NUMBER.search(s, i)
|
|
if not m:
|
|
self._curtoken += s[i:]
|
|
return len(s)
|
|
j = m.start(0)
|
|
self._curtoken += s[i:j]
|
|
c_int = s[j]
|
|
if c_int == 46: # b"."
|
|
self._curtoken += b"."
|
|
self._parse1 = self._parse_float
|
|
return j + 1
|
|
with contextlib.suppress(ValueError):
|
|
self._add_token(int(self._curtoken))
|
|
self._parse1 = self._parse_main
|
|
return j
|
|
|
|
def _parse_float(self, s: bytes, i: int) -> int:
|
|
m = END_NUMBER.search(s, i)
|
|
if not m:
|
|
self._curtoken += s[i:]
|
|
return len(s)
|
|
j = m.start(0)
|
|
self._curtoken += s[i:j]
|
|
with contextlib.suppress(ValueError):
|
|
self._add_token(float(self._curtoken))
|
|
self._parse1 = self._parse_main
|
|
return j
|
|
|
|
def _parse_keyword(self, s: bytes, i: int) -> int:
|
|
m = END_KEYWORD.search(s, i)
|
|
if m:
|
|
j = m.start(0)
|
|
self._curtoken += s[i:j]
|
|
else:
|
|
self._curtoken += s[i:]
|
|
return len(s)
|
|
if self._curtoken == b"true":
|
|
token: bool | PSKeyword = True
|
|
elif self._curtoken == b"false":
|
|
token = False
|
|
else:
|
|
token = KWD(self._curtoken)
|
|
self._add_token(token)
|
|
self._parse1 = self._parse_main
|
|
return j
|
|
|
|
def _parse_string(self, s: bytes, i: int) -> int:
|
|
m = END_STRING.search(s, i)
|
|
if not m:
|
|
self._curtoken += s[i:]
|
|
return len(s)
|
|
j = m.start(0)
|
|
self._curtoken += s[i:j]
|
|
c_int = s[j]
|
|
if c_int == 92: # b"\\\\"
|
|
self.oct = b""
|
|
self._parse1 = self._parse_string_1
|
|
return j + 1
|
|
if c_int == 40: # b"("
|
|
self.paren += 1
|
|
# append the literal "(" byte
|
|
self._curtoken += b"("
|
|
return j + 1
|
|
if c_int == 41: # b")"
|
|
self.paren -= 1
|
|
if self.paren:
|
|
# WTF, they said balanced parens need no special treatment.
|
|
self._curtoken += b")"
|
|
return j + 1
|
|
self._add_token(self._curtoken)
|
|
self._parse1 = self._parse_main
|
|
return j + 1
|
|
|
|
def _parse_wopen(self, s: bytes, i: int) -> int:
|
|
c_int = s[i]
|
|
if c_int == 60: # b"<"
|
|
self._add_token(KEYWORD_DICT_BEGIN)
|
|
self._parse1 = self._parse_main
|
|
i += 1
|
|
else:
|
|
self._parse1 = self._parse_hexstring
|
|
return i
|
|
|
|
def _parse_wclose(self, s: bytes, i: int) -> int:
|
|
c_int = s[i]
|
|
if c_int == 62: # b">"
|
|
self._add_token(KEYWORD_DICT_END)
|
|
i += 1
|
|
self._parse1 = self._parse_main
|
|
return i
|
|
'''
|
|
|
|
|
|
@pytest.fixture
def temp_project():
    """Provide a throwaway project directory seeded with ``psparser.py``.

    The file is written from ``ORIGINAL_SOURCE`` and a ``TestConfig`` is built
    whose root paths all point at the temporary directory.

    Yields:
        tuple: ``(temp_dir, source_file, test_cfg)`` -- the temporary project
        root, the path of the source file written into it, and the
        ``TestConfig`` for that project.

    The directory is removed when the fixture is finalized. Setup and ``yield``
    sit inside ``try``/``finally`` so the directory is also removed when setup
    fails after ``mkdtemp``; pytest does not run the code following ``yield``
    for a fixture that raised before yielding.
    """
    import shutil

    temp_dir = Path(tempfile.mkdtemp())
    try:
        source_file = temp_dir / "psparser.py"
        source_file.write_text(ORIGINAL_SOURCE, encoding="utf-8")

        test_cfg = TestConfig(
            tests_root=temp_dir / "tests",
            tests_project_rootdir=temp_dir,
            project_root_path=temp_dir,
            test_framework="pytest",
            pytest_cmd="pytest",
        )

        yield temp_dir, source_file, test_cfg
    finally:
        # Best-effort cleanup: failures while deleting are deliberately ignored.
        shutil.rmtree(temp_dir, ignore_errors=True)
|
|
|
|
|
|
def run_replacement(temp_project):
    """Run the replacement pipeline once for ``PSBaseParser._parse_main``.

    Builds a ``PythonFunctionOptimizer`` for the method, obtains its code
    optimization context, applies the candidate parsed from
    ``MOCK_CANDIDATE_MARKDOWN`` and re-reads the source file afterwards.

    Args:
        temp_project: the ``(temp_dir, source_file, test_cfg)`` tuple produced
            by the ``temp_project`` fixture.

    Returns:
        tuple: ``(optimizer, code_context, final_content)`` where
        ``final_content`` is the source file's text after the replacement call.

    Raises:
        AssertionError: if the context lookup is unsuccessful or the
            replacement call reports that nothing was updated.
    """
    _project_root, source_file, test_cfg = temp_project

    parser_class = FunctionParent(name="PSBaseParser", type="ClassDef")
    target = FunctionToOptimize(
        file_path=source_file,
        function_name="_parse_main",
        parents=[parser_class],
    )
    source_text = source_file.read_text(encoding="utf-8")
    optimizer = PythonFunctionOptimizer(
        function_to_optimize=target,
        test_cfg=test_cfg,
        function_to_optimize_source_code=source_text,
    )

    ctx_result = optimizer.get_code_optimization_context()
    assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"
    code_context = ctx_result.unwrap()

    # Pre-replacement file text keyed by path; handed to the optimizer as the
    # "original helper code" argument.
    pre_replacement_sources = {source_file: source_file.read_text(encoding="utf-8")}
    candidate = CodeStringsMarkdown.parse_markdown_code(MOCK_CANDIDATE_MARKDOWN)

    did_update = optimizer.replace_function_and_helpers_with_optimized_code(
        code_context, candidate, pre_replacement_sources
    )
    assert did_update, "Expected the code to be updated"

    return optimizer, code_context, source_file.read_text(encoding="utf-8")
|
|
|
|
|
|
def test_replace_with_mock_candidate(temp_project):
    """Check the detected helper set and the exact file text after replacement.

    Expectations encoded by this scenario:

    * the code context reports every sibling method of ``_parse_main`` (plus
      ``KWD``) as a helper;
    * ``replace_function_definitions_in_module`` replaces all of those method
      bodies;
    * ``detect_unused_helper_functions`` treats methods that are only referenced
      through attribute assignment (``self._parse1 = self._parse_literal``) as
      used, so none of them is reverted.

    The combined effect is asserted by comparing the final file text with
    ``EXPECTED_OUTPUT`` verbatim.
    """
    _optimizer, code_context, final_content = run_replacement(temp_project)

    # Every helper the context is expected to report for _parse_main.
    expected_helpers = {
        "KWD",
        "PSBaseParser._add_token",
        "PSBaseParser._parse_comment",
        "PSBaseParser._parse_literal",
        "PSBaseParser._parse_number",
        "PSBaseParser._parse_float",
        "PSBaseParser._parse_keyword",
        "PSBaseParser._parse_string",
        "PSBaseParser._parse_wopen",
        "PSBaseParser._parse_wclose",
    }
    found_helpers = {helper.qualified_name for helper in code_context.helper_functions}
    assert found_helpers == expected_helpers

    # The rewritten file must match the expected text character for character.
    assert final_content == EXPECTED_OUTPUT
|
|
|
|
|
|
def test_detect_unused_helpers_handles_attribute_refs(temp_project):
    """Check that attribute-referenced methods are not reported as unused helpers.

    ``_parse_main`` hands control to other methods with statements such as
    ``self._parse1 = self._parse_literal``. There the method appears as the
    value of an ``ast.Attribute`` rather than inside an ``ast.Call``; the
    detection is expected to count such references as uses.
    """
    _, source_file, test_cfg = temp_project

    parser_class = FunctionParent(name="PSBaseParser", type="ClassDef")
    target = FunctionToOptimize(
        file_path=source_file,
        function_name="_parse_main",
        parents=[parser_class],
    )
    source_text = source_file.read_text(encoding="utf-8")
    optimizer = PythonFunctionOptimizer(
        function_to_optimize=target,
        test_cfg=test_cfg,
        function_to_optimize_source_code=source_text,
    )

    ctx_result = optimizer.get_code_optimization_context()
    assert ctx_result.is_successful()
    code_context = ctx_result.unwrap()

    candidate = CodeStringsMarkdown.parse_markdown_code(MOCK_CANDIDATE_MARKDOWN)
    flagged = detect_unused_helper_functions(
        optimizer.function_to_optimize, code_context, candidate
    )
    unused_names = {helper.qualified_name for helper in flagged}

    # Every helper is either called directly or referenced via attribute
    # assignment (self._parse1 = self._parse_X), so nothing should be flagged.
    assert unused_names == set(), f"Expected no unused helpers, got: {unused_names}"
|
|
|
|
|
|
def test_replace_produces_valid_python(temp_project):
    """Check that the file text left by the replacement pipeline parses as Python.

    There is no explicit assertion: ``ast.parse`` raises ``SyntaxError`` for
    unparseable source, and that exception is what fails the test.
    """
    *_, final_content = run_replacement(temp_project)

    from ast import parse as parse_module

    parse_module(final_content)
|