go function optimizer and handle global vars context correctly
This commit is contained in:
parent
ba560308bc
commit
ac478de753
9 changed files with 1701 additions and 5 deletions
|
|
@ -36,9 +36,15 @@ def extract_code_context(
|
|||
imports = analyzer.find_imports(source)
|
||||
import_lines = [_import_to_line(imp) for imp in imports]
|
||||
|
||||
read_only_context = ""
|
||||
read_only_parts: list[str] = []
|
||||
if receiver_type:
|
||||
read_only_context = _extract_struct_context(source, receiver_type, analyzer)
|
||||
struct_ctx = _extract_struct_context(source, receiver_type, analyzer)
|
||||
if struct_ctx:
|
||||
read_only_parts.append(struct_ctx)
|
||||
|
||||
init_ctx = _extract_init_context(source, analyzer)
|
||||
if init_ctx:
|
||||
read_only_parts.append(init_ctx)
|
||||
|
||||
helpers = find_helper_functions(source, function, analyzer)
|
||||
|
||||
|
|
@ -46,7 +52,7 @@ def extract_code_context(
|
|||
target_code=target_code,
|
||||
target_file=function.file_path,
|
||||
helper_functions=helpers,
|
||||
read_only_context=read_only_context,
|
||||
read_only_context="\n\n".join(read_only_parts),
|
||||
imports=import_lines,
|
||||
language=Language.GO,
|
||||
)
|
||||
|
|
@ -125,3 +131,27 @@ def _extract_struct_context(source: str, struct_name: str, analyzer: GoAnalyzer)
|
|||
lines = source.splitlines()
|
||||
return "\n".join(lines[s.starting_line - 1 : s.ending_line])
|
||||
return ""
|
||||
|
||||
|
||||
def _extract_init_context(source: str, analyzer: GoAnalyzer) -> str:
|
||||
init_source = analyzer.extract_function_source(source, "init")
|
||||
if init_source is None:
|
||||
return ""
|
||||
|
||||
init_ids = analyzer.collect_body_identifiers(source, "init")
|
||||
if not init_ids:
|
||||
return init_source
|
||||
|
||||
parts: list[str] = []
|
||||
|
||||
for decl in analyzer.find_global_declarations(source):
|
||||
if init_ids & set(decl.names):
|
||||
parts.append(decl.source_code)
|
||||
|
||||
for struct in analyzer.find_structs(source):
|
||||
if struct.name in init_ids:
|
||||
lines = source.splitlines()
|
||||
parts.append("\n".join(lines[struct.starting_line - 1 : struct.ending_line]))
|
||||
|
||||
parts.append(init_source)
|
||||
return "\n\n".join(parts)
|
||||
|
|
|
|||
160
codeflash/languages/golang/function_optimizer.py
Normal file
160
codeflash/languages/golang/function_optimizer.py
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.code_utils.code_utils import encoded_tokens_len
|
||||
from codeflash.code_utils.config_consts import (
|
||||
OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
|
||||
READ_WRITABLE_LIMIT_ERROR,
|
||||
TESTGEN_CONTEXT_TOKEN_LIMIT,
|
||||
TESTGEN_LIMIT_ERROR,
|
||||
)
|
||||
from codeflash.either import Failure, Success
|
||||
from codeflash.languages.function_optimizer import FunctionOptimizer
|
||||
from codeflash.models.models import CodeOptimizationContext, CodeString, CodeStringsMarkdown, FunctionSource
|
||||
from codeflash.verification.equivalence import compare_test_results
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.either import Result
|
||||
from codeflash.languages.base import CodeContext, HelperFunction
|
||||
from codeflash.models.models import OriginalCodeBaseline, TestDiff, TestResults
|
||||
|
||||
|
||||
class GoFunctionOptimizer(FunctionOptimizer):
|
||||
def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.base import Language
|
||||
|
||||
language = Language(self.function_to_optimize.language)
|
||||
lang_support = get_language_support(language)
|
||||
|
||||
try:
|
||||
code_context = lang_support.extract_code_context(
|
||||
self.function_to_optimize, self.project_root, self.project_root
|
||||
)
|
||||
return Success(
|
||||
_build_optimization_context(
|
||||
code_context,
|
||||
self.function_to_optimize.file_path,
|
||||
self.function_to_optimize.language,
|
||||
self.project_root,
|
||||
)
|
||||
)
|
||||
except ValueError as e:
|
||||
return Failure(str(e))
|
||||
|
||||
def compare_candidate_results(
|
||||
self,
|
||||
baseline_results: OriginalCodeBaseline,
|
||||
candidate_behavior_results: TestResults,
|
||||
optimization_candidate_index: int,
|
||||
) -> tuple[bool, list[TestDiff]]:
|
||||
return compare_test_results(
|
||||
baseline_results.behavior_test_results, candidate_behavior_results, pass_fail_only=True
|
||||
)
|
||||
|
||||
def replace_function_and_helpers_with_optimized_code(
|
||||
self,
|
||||
code_context: CodeOptimizationContext,
|
||||
optimized_code: CodeStringsMarkdown,
|
||||
original_helper_code: dict[Path, str],
|
||||
) -> bool:
|
||||
from codeflash.languages.code_replacer import replace_function_definitions_for_language
|
||||
|
||||
did_update = False
|
||||
for module_abspath, qualified_names in self.group_functions_by_file(code_context).items():
|
||||
did_update |= replace_function_definitions_for_language(
|
||||
function_names=list(qualified_names),
|
||||
optimized_code=optimized_code,
|
||||
module_abspath=module_abspath,
|
||||
project_root_path=self.project_root,
|
||||
lang_support=self.language_support,
|
||||
function_to_optimize=self.function_to_optimize,
|
||||
)
|
||||
return did_update
|
||||
|
||||
|
||||
def _build_optimization_context(
|
||||
code_context: CodeContext,
|
||||
file_path: Path,
|
||||
language: str,
|
||||
project_root: Path,
|
||||
optim_token_limit: int = OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
|
||||
testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT,
|
||||
) -> CodeOptimizationContext:
|
||||
if code_context.imports:
|
||||
inner = "\n".join(f"\t{imp}" for imp in code_context.imports)
|
||||
imports_code = f"import (\n{inner}\n)"
|
||||
else:
|
||||
imports_code = ""
|
||||
|
||||
try:
|
||||
target_relative_path = file_path.resolve().relative_to(project_root.resolve())
|
||||
except ValueError:
|
||||
target_relative_path = file_path
|
||||
|
||||
helpers_by_file: dict[Path, list[HelperFunction]] = defaultdict(list)
|
||||
helper_function_sources = []
|
||||
|
||||
for helper in code_context.helper_functions:
|
||||
helpers_by_file[helper.file_path].append(helper)
|
||||
helper_function_sources.append(
|
||||
FunctionSource(
|
||||
file_path=helper.file_path,
|
||||
qualified_name=helper.qualified_name,
|
||||
fully_qualified_name=helper.qualified_name,
|
||||
only_function_name=helper.name,
|
||||
source_code=helper.source_code,
|
||||
)
|
||||
)
|
||||
|
||||
target_file_code = code_context.target_code
|
||||
same_file_helpers = helpers_by_file.get(file_path, [])
|
||||
if same_file_helpers:
|
||||
helper_code = "\n\n".join(h.source_code for h in same_file_helpers)
|
||||
target_file_code = target_file_code + "\n\n" + helper_code
|
||||
|
||||
if imports_code:
|
||||
target_file_code = imports_code + "\n\n" + target_file_code
|
||||
|
||||
read_writable_code_strings = [CodeString(code=target_file_code, file_path=target_relative_path, language=language)]
|
||||
|
||||
for helper_file_path, file_helpers in helpers_by_file.items():
|
||||
if helper_file_path == file_path:
|
||||
continue
|
||||
try:
|
||||
helper_relative_path = helper_file_path.resolve().relative_to(project_root.resolve())
|
||||
except ValueError:
|
||||
helper_relative_path = helper_file_path
|
||||
combined_helper_code = "\n\n".join(h.source_code for h in file_helpers)
|
||||
read_writable_code_strings.append(
|
||||
CodeString(code=combined_helper_code, file_path=helper_relative_path, language=language)
|
||||
)
|
||||
|
||||
read_writable_code = CodeStringsMarkdown(code_strings=read_writable_code_strings, language=language)
|
||||
testgen_context = CodeStringsMarkdown(code_strings=read_writable_code_strings.copy(), language=language)
|
||||
|
||||
read_writable_tokens = encoded_tokens_len(read_writable_code.markdown)
|
||||
if read_writable_tokens > optim_token_limit:
|
||||
raise ValueError(READ_WRITABLE_LIMIT_ERROR)
|
||||
|
||||
testgen_tokens = encoded_tokens_len(testgen_context.markdown)
|
||||
if testgen_tokens > testgen_token_limit:
|
||||
raise ValueError(TESTGEN_LIMIT_ERROR)
|
||||
|
||||
code_hash = hashlib.sha256(read_writable_code.flat.encode("utf-8")).hexdigest()
|
||||
|
||||
return CodeOptimizationContext(
|
||||
testgen_context=testgen_context,
|
||||
read_writable_code=read_writable_code,
|
||||
read_only_context_code=code_context.read_only_context,
|
||||
hashing_code_context=read_writable_code.flat,
|
||||
hashing_code_context_hash=code_hash,
|
||||
helper_functions=helper_function_sources,
|
||||
testgen_helper_fqns=[fs.fully_qualified_name for fs in helper_function_sources],
|
||||
preexisting_objects=set(),
|
||||
)
|
||||
|
|
@ -81,6 +81,15 @@ class GoImportInfo:
|
|||
ending_line: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GoGlobalDeclaration:
|
||||
names: tuple[str, ...]
|
||||
kind: str
|
||||
source_code: str
|
||||
starting_line: int
|
||||
ending_line: int
|
||||
|
||||
|
||||
class GoAnalyzer:
|
||||
def __init__(self) -> None:
|
||||
self._parser = _get_go_parser()
|
||||
|
|
@ -189,6 +198,45 @@ class GoAnalyzer:
|
|||
)
|
||||
return results
|
||||
|
||||
def find_global_declarations(self, source: str) -> list[GoGlobalDeclaration]:
|
||||
tree = self.parse(source)
|
||||
results: list[GoGlobalDeclaration] = []
|
||||
for node in tree.root_node.children:
|
||||
if node.type in ("var_declaration", "const_declaration"):
|
||||
kind = "var" if node.type == "var_declaration" else "const"
|
||||
names = _extract_declaration_names(node, self)
|
||||
if names:
|
||||
results.append(
|
||||
GoGlobalDeclaration(
|
||||
names=tuple(names),
|
||||
kind=kind,
|
||||
source_code=self.get_node_text(node),
|
||||
starting_line=node.start_point.row + 1,
|
||||
ending_line=node.end_point.row + 1,
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
def collect_body_identifiers(self, source: str, func_name: str, receiver_type: str | None = None) -> set[str]:
|
||||
tree = self.parse(source)
|
||||
for node in tree.root_node.children:
|
||||
if receiver_type is None and node.type == "function_declaration":
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is not None and self.get_node_text(name_node) == func_name:
|
||||
body = node.child_by_field_name("body")
|
||||
return _collect_identifiers(body) if body else set()
|
||||
if receiver_type is not None and node.type == "method_declaration":
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is None or self.get_node_text(name_node) != func_name:
|
||||
continue
|
||||
recv_node = node.child_by_field_name("receiver")
|
||||
if recv_node is not None:
|
||||
recv_name, _ = self.parse_receiver(recv_node)
|
||||
if recv_name == receiver_type:
|
||||
body = node.child_by_field_name("body")
|
||||
return _collect_identifiers(body) if body else set()
|
||||
return set()
|
||||
|
||||
def find_package_name(self, source: str) -> str | None:
|
||||
tree = self.parse(source)
|
||||
for node in tree.root_node.children:
|
||||
|
|
@ -317,6 +365,42 @@ def _iter_import_specs(import_node: Node) -> list[Node]:
|
|||
return results
|
||||
|
||||
|
||||
def _extract_declaration_names(node: Node, analyzer: GoAnalyzer) -> list[str]:
|
||||
names: list[str] = []
|
||||
for child in node.children:
|
||||
if child.type in ("var_spec", "const_spec"):
|
||||
name_node = child.child_by_field_name("name")
|
||||
if name_node is not None:
|
||||
names.append(analyzer.get_node_text(name_node))
|
||||
elif child.type in ("var_spec_list", "const_spec_list"):
|
||||
for spec in child.children:
|
||||
if spec.type in ("var_spec", "const_spec"):
|
||||
name_node = spec.child_by_field_name("name")
|
||||
if name_node is not None:
|
||||
names.append(analyzer.get_node_text(name_node))
|
||||
return names
|
||||
|
||||
|
||||
def _collect_identifiers(node: Node | None) -> set[str]:
|
||||
if node is None:
|
||||
return set()
|
||||
ids: set[str] = set()
|
||||
stack = [node]
|
||||
while stack:
|
||||
n = stack.pop()
|
||||
if n.type in ("identifier", "type_identifier"):
|
||||
text = n.parent
|
||||
if text is not None and text.type not in ("parameter_declaration", "short_var_declaration"):
|
||||
ids.add(n.text.decode("utf-8") if n.text else "")
|
||||
elif text is not None and text.type == "short_var_declaration":
|
||||
name_node = text.child_by_field_name("left")
|
||||
if name_node is not n and (name_node is None or n not in (name_node, *tuple(name_node.children))):
|
||||
ids.add(n.text.decode("utf-8") if n.text else "")
|
||||
stack.extend(n.children)
|
||||
ids.discard("")
|
||||
return ids
|
||||
|
||||
|
||||
def _find_preceding_comment_line(node: Node) -> int | None:
|
||||
prev = node.prev_named_sibling
|
||||
if prev is None:
|
||||
|
|
|
|||
|
|
@ -55,6 +55,11 @@ def replace_function(
|
|||
def add_global_declarations(optimized_code: str, original_source: str, analyzer: GoAnalyzer | None = None) -> str:
|
||||
analyzer = analyzer or GoAnalyzer()
|
||||
|
||||
merged = _merge_imports(optimized_code, original_source, analyzer)
|
||||
return _merge_global_var_const(optimized_code, merged, analyzer)
|
||||
|
||||
|
||||
def _merge_imports(optimized_code: str, original_source: str, analyzer: GoAnalyzer) -> str:
|
||||
opt_imports = analyzer.find_imports(optimized_code)
|
||||
orig_imports = analyzer.find_imports(original_source)
|
||||
orig_paths = {imp.path for imp in orig_imports}
|
||||
|
|
@ -91,6 +96,74 @@ def add_global_declarations(optimized_code: str, original_source: str, analyzer:
|
|||
return "".join([*lines[:insert_at], import_block, *lines[insert_at:]])
|
||||
|
||||
|
||||
def _merge_global_var_const(optimized_code: str, original_source: str, analyzer: GoAnalyzer) -> str:
|
||||
opt_decls = analyzer.find_global_declarations(optimized_code)
|
||||
if not opt_decls:
|
||||
return original_source
|
||||
|
||||
orig_decls = analyzer.find_global_declarations(original_source)
|
||||
orig_names_to_decl: dict[str, object] = {}
|
||||
for decl in orig_decls:
|
||||
for name in decl.names:
|
||||
orig_names_to_decl[name] = decl
|
||||
|
||||
new_decls: list[str] = []
|
||||
replaced_decls: set[int] = set()
|
||||
|
||||
for opt_decl in opt_decls:
|
||||
overlapping_orig = None
|
||||
for name in opt_decl.names:
|
||||
if name in orig_names_to_decl:
|
||||
overlapping_orig = orig_names_to_decl[name]
|
||||
break
|
||||
|
||||
if overlapping_orig is None:
|
||||
new_decls.append(opt_decl.source_code)
|
||||
elif overlapping_orig.source_code.strip() != opt_decl.source_code.strip():
|
||||
orig_id = id(overlapping_orig)
|
||||
if orig_id not in replaced_decls:
|
||||
replaced_decls.add(orig_id)
|
||||
original_source = _replace_declaration_block(original_source, overlapping_orig, opt_decl.source_code)
|
||||
|
||||
if new_decls:
|
||||
original_source = _insert_new_declarations(original_source, new_decls, analyzer)
|
||||
|
||||
return original_source
|
||||
|
||||
|
||||
def _replace_declaration_block(source: str, orig_decl: object, new_source_code: str) -> str:
|
||||
lines = source.splitlines(keepends=True)
|
||||
start = orig_decl.starting_line - 1
|
||||
end = orig_decl.ending_line
|
||||
replacement = new_source_code.rstrip("\n") + "\n"
|
||||
return "".join([*lines[:start], replacement, *lines[end:]])
|
||||
|
||||
|
||||
def _insert_new_declarations(source: str, new_decls: list[str], analyzer: GoAnalyzer) -> str:
|
||||
lines = source.splitlines(keepends=True)
|
||||
|
||||
insert_at = _find_declarations_insert_point(source, analyzer)
|
||||
|
||||
block = "\n".join(new_decls) + "\n\n"
|
||||
return "".join([*lines[:insert_at], block, *lines[insert_at:]])
|
||||
|
||||
|
||||
def _find_declarations_insert_point(source: str, analyzer: GoAnalyzer) -> int:
|
||||
tree = analyzer.parse(source)
|
||||
last_line = 0
|
||||
for node in tree.root_node.children:
|
||||
if node.type in ("import_declaration", "var_declaration", "const_declaration"):
|
||||
candidate = node.end_point.row + 1
|
||||
last_line = max(last_line, candidate)
|
||||
if last_line > 0:
|
||||
return last_line
|
||||
|
||||
for node in tree.root_node.children:
|
||||
if node.type == "package_clause":
|
||||
return node.end_point.row + 1
|
||||
return 0
|
||||
|
||||
|
||||
def remove_test_functions(test_source: str, functions_to_remove: list[str], analyzer: GoAnalyzer | None = None) -> str:
|
||||
analyzer = analyzer or GoAnalyzer()
|
||||
tree = analyzer.parse(test_source)
|
||||
|
|
|
|||
|
|
@ -85,7 +85,9 @@ class GoSupport:
|
|||
|
||||
@property
|
||||
def function_optimizer_class(self) -> type:
|
||||
raise NotImplementedError
|
||||
from codeflash.languages.golang.function_optimizer import GoFunctionOptimizer
|
||||
|
||||
return GoFunctionOptimizer
|
||||
|
||||
def discover_functions(
|
||||
self, source: str, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
|
||||
|
|
|
|||
|
|
@ -195,6 +195,120 @@ class TestExtractCodeContextEdgeCases:
|
|||
assert ctx.imports == ['"fmt"', '"os"', 'str "strings"']
|
||||
|
||||
|
||||
GO_SOURCE_WITH_INIT = """\
|
||||
package server
|
||||
|
||||
import "sync"
|
||||
|
||||
var (
|
||||
\tglobalCache map[string]int
|
||||
\tmu sync.Mutex
|
||||
)
|
||||
|
||||
const MaxRetries = 5
|
||||
|
||||
type Config struct {
|
||||
\tName string
|
||||
\tMax int
|
||||
}
|
||||
|
||||
func init() {
|
||||
\tglobalCache = make(map[string]int)
|
||||
\tglobalCache["default"] = 0
|
||||
\tmu.Lock()
|
||||
\tmu.Unlock()
|
||||
}
|
||||
|
||||
func Process() int {
|
||||
\treturn MaxRetries
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
class TestExtractCodeContextWithInit:
|
||||
def test_init_in_read_only_context(self, tmp_path: Path) -> None:
|
||||
source_file = (tmp_path / "server.go").resolve()
|
||||
source_file.write_text(GO_SOURCE_WITH_INIT, encoding="utf-8")
|
||||
func = FunctionToOptimize(function_name="Process", file_path=source_file, language="go")
|
||||
ctx = extract_code_context(func, tmp_path.resolve())
|
||||
assert "func init()" in ctx.read_only_context
|
||||
|
||||
def test_init_referenced_globals_in_read_only_context(self, tmp_path: Path) -> None:
|
||||
source_file = (tmp_path / "server.go").resolve()
|
||||
source_file.write_text(GO_SOURCE_WITH_INIT, encoding="utf-8")
|
||||
func = FunctionToOptimize(function_name="Process", file_path=source_file, language="go")
|
||||
ctx = extract_code_context(func, tmp_path.resolve())
|
||||
assert "globalCache" in ctx.read_only_context
|
||||
assert "mu" in ctx.read_only_context
|
||||
|
||||
def test_init_not_in_helpers(self, tmp_path: Path) -> None:
|
||||
source_file = (tmp_path / "server.go").resolve()
|
||||
source_file.write_text(GO_SOURCE_WITH_INIT, encoding="utf-8")
|
||||
func = FunctionToOptimize(function_name="Process", file_path=source_file, language="go")
|
||||
ctx = extract_code_context(func, tmp_path.resolve())
|
||||
helper_names = [h.name for h in ctx.helper_functions]
|
||||
assert "init" not in helper_names
|
||||
|
||||
def test_no_init_no_extra_context(self, tmp_path: Path) -> None:
|
||||
source_file = (tmp_path / "calc.go").resolve()
|
||||
source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8")
|
||||
func = FunctionToOptimize(function_name="Add", file_path=source_file, language="go")
|
||||
ctx = extract_code_context(func, tmp_path.resolve())
|
||||
assert "func init()" not in ctx.read_only_context
|
||||
|
||||
def test_full_init_read_only_context(self, tmp_path: Path) -> None:
|
||||
source_file = (tmp_path / "server.go").resolve()
|
||||
source_file.write_text(GO_SOURCE_WITH_INIT, encoding="utf-8")
|
||||
func = FunctionToOptimize(function_name="Process", file_path=source_file, language="go")
|
||||
ctx = extract_code_context(func, tmp_path.resolve())
|
||||
expected = (
|
||||
"var (\n"
|
||||
"\tglobalCache map[string]int\n"
|
||||
"\tmu sync.Mutex\n"
|
||||
")\n"
|
||||
"\n"
|
||||
"func init() {\n"
|
||||
"\tglobalCache = make(map[string]int)\n"
|
||||
"\tglobalCache[\"default\"] = 0\n"
|
||||
"\tmu.Lock()\n"
|
||||
"\tmu.Unlock()\n"
|
||||
"}"
|
||||
)
|
||||
assert ctx.read_only_context == expected
|
||||
|
||||
def test_method_with_init_combines_struct_and_init_context(self, tmp_path: Path) -> None:
|
||||
source = """\
|
||||
package server
|
||||
|
||||
var globalOffset = 10
|
||||
|
||||
type Calc struct {
|
||||
\tVal int
|
||||
}
|
||||
|
||||
func init() {
|
||||
\tglobalOffset = 42
|
||||
}
|
||||
|
||||
func (c *Calc) Compute() int {
|
||||
\treturn c.Val + globalOffset
|
||||
}
|
||||
"""
|
||||
source_file = (tmp_path / "server.go").resolve()
|
||||
source_file.write_text(source, encoding="utf-8")
|
||||
func = FunctionToOptimize(
|
||||
function_name="Compute",
|
||||
file_path=source_file,
|
||||
parents=[FunctionParent(name="Calc", type="StructDef")],
|
||||
language="go",
|
||||
is_method=True,
|
||||
)
|
||||
ctx = extract_code_context(func, tmp_path.resolve())
|
||||
assert "type Calc struct" in ctx.read_only_context
|
||||
assert "func init()" in ctx.read_only_context
|
||||
assert "var globalOffset = 10" in ctx.read_only_context
|
||||
|
||||
|
||||
class TestFindHelperFunctions:
|
||||
def test_skips_init_and_main(self, tmp_path: Path) -> None:
|
||||
source = "package main\n\nfunc init() { println() }\n\nfunc main() { println() }\n\nfunc Target() int { return 1 }\n"
|
||||
|
|
|
|||
653
tests/test_languages/test_golang/test_function_optimizer.py
Normal file
653
tests/test_languages/test_golang/test_function_optimizer.py
Normal file
|
|
@ -0,0 +1,653 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.golang.context import extract_code_context
|
||||
from codeflash.languages.golang.function_optimizer import _build_optimization_context
|
||||
from codeflash.models.function_types import FunctionParent, FunctionToOptimize
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.models.models import CodeOptimizationContext
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Realistic Go sources used across test classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
CALCULATOR_SOURCE = dedent("""\
|
||||
package calc
|
||||
|
||||
import (
|
||||
\t"fmt"
|
||||
\t"math"
|
||||
\tstr "strings"
|
||||
)
|
||||
|
||||
// Calculator holds running computation state.
|
||||
type Calculator struct {
|
||||
\tResult float64
|
||||
\tHistory []float64
|
||||
}
|
||||
|
||||
// Formatter controls output rendering.
|
||||
type Formatter interface {
|
||||
\tFormat(val float64) string
|
||||
}
|
||||
|
||||
// Add returns the sum of two integers.
|
||||
func Add(a, b int) int {
|
||||
\treturn a + b
|
||||
}
|
||||
|
||||
func subtract(a, b int) int {
|
||||
\treturn a - b
|
||||
}
|
||||
|
||||
func multiply(a, b int) int {
|
||||
\treturn a * b
|
||||
}
|
||||
|
||||
// Greet builds a greeting message.
|
||||
func Greet(name string) string {
|
||||
\treturn fmt.Sprintf("Hello, %s", str.TrimSpace(name))
|
||||
}
|
||||
|
||||
// AddFloat adds a float value and records history.
|
||||
func (c *Calculator) AddFloat(val float64) float64 {
|
||||
\tc.Result += val
|
||||
\tc.History = append(c.History, c.Result)
|
||||
\treturn c.Result
|
||||
}
|
||||
|
||||
// Sqrt computes the square root of the current result.
|
||||
func (c *Calculator) Sqrt() float64 {
|
||||
\tc.Result = math.Sqrt(c.Result)
|
||||
\tc.History = append(c.History, c.Result)
|
||||
\treturn c.Result
|
||||
}
|
||||
|
||||
// Reset zeroes out the calculator.
|
||||
func (c Calculator) Reset() Calculator {
|
||||
\tc.Result = 0
|
||||
\tc.History = nil
|
||||
\treturn c
|
||||
}
|
||||
""")
|
||||
|
||||
SIMPLE_SOURCE = dedent("""\
|
||||
package simple
|
||||
|
||||
func Double(x int) int {
|
||||
\treturn x * 2
|
||||
}
|
||||
""")
|
||||
|
||||
INIT_SOURCE = dedent("""\
|
||||
package server
|
||||
|
||||
import (
|
||||
\t"fmt"
|
||||
\t"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
\tglobalCache map[string]int
|
||||
\tmu sync.Mutex
|
||||
)
|
||||
|
||||
var singleVar = 42
|
||||
|
||||
const MaxRetries = 5
|
||||
|
||||
type Config struct {
|
||||
\tName string
|
||||
\tMax int
|
||||
}
|
||||
|
||||
func init() {
|
||||
\tglobalCache = make(map[string]int)
|
||||
\tglobalCache["default"] = 0
|
||||
\tdefaultCfg := Config{Name: "prod", Max: MaxRetries}
|
||||
\t_ = defaultCfg
|
||||
\tmu.Lock()
|
||||
\tmu.Unlock()
|
||||
}
|
||||
|
||||
func Process() int {
|
||||
\tfmt.Println("processing")
|
||||
\treturn singleVar + MaxRetries
|
||||
}
|
||||
""")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers to drive the full extract → build pipeline
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_context_for_function(
|
||||
source: str,
|
||||
filename: str,
|
||||
function_name: str,
|
||||
tmp_path: Path,
|
||||
parents: list[FunctionParent] | None = None,
|
||||
is_method: bool = False,
|
||||
) -> CodeOptimizationContext:
|
||||
root = tmp_path.resolve()
|
||||
source_file = (root / filename).resolve()
|
||||
source_file.write_text(source, encoding="utf-8")
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name=function_name, file_path=source_file, parents=parents or [], language="go", is_method=is_method
|
||||
)
|
||||
code_context = extract_code_context(func, root)
|
||||
return _build_optimization_context(code_context, source_file, "go", root)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: targeting a plain exported function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildContextExportedFunction:
|
||||
"""Target: Add(a, b int) int — a plain exported function with a doc comment."""
|
||||
|
||||
def test_full_assembled_code_string(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
|
||||
code = result.read_writable_code.code_strings[0].code
|
||||
|
||||
expected = dedent("""\
|
||||
import (
|
||||
\t"fmt"
|
||||
\t"math"
|
||||
\tstr "strings"
|
||||
)
|
||||
|
||||
// Add returns the sum of two integers.
|
||||
func Add(a, b int) int {
|
||||
\treturn a + b
|
||||
}
|
||||
|
||||
|
||||
func subtract(a, b int) int {
|
||||
\treturn a - b
|
||||
}
|
||||
|
||||
func multiply(a, b int) int {
|
||||
\treturn a * b
|
||||
}
|
||||
|
||||
// Greet builds a greeting message.
|
||||
func Greet(name string) string {
|
||||
\treturn fmt.Sprintf("Hello, %s", str.TrimSpace(name))
|
||||
}
|
||||
|
||||
|
||||
// AddFloat adds a float value and records history.
|
||||
func (c *Calculator) AddFloat(val float64) float64 {
|
||||
\tc.Result += val
|
||||
\tc.History = append(c.History, c.Result)
|
||||
\treturn c.Result
|
||||
}
|
||||
|
||||
|
||||
// Sqrt computes the square root of the current result.
|
||||
func (c *Calculator) Sqrt() float64 {
|
||||
\tc.Result = math.Sqrt(c.Result)
|
||||
\tc.History = append(c.History, c.Result)
|
||||
\treturn c.Result
|
||||
}
|
||||
|
||||
|
||||
// Reset zeroes out the calculator.
|
||||
func (c Calculator) Reset() Calculator {
|
||||
\tc.Result = 0
|
||||
\tc.History = nil
|
||||
\treturn c
|
||||
}
|
||||
""")
|
||||
assert code == expected
|
||||
|
||||
def test_code_excludes_package_clause(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
|
||||
code = result.read_writable_code.code_strings[0].code
|
||||
assert "package calc" not in code
|
||||
|
||||
def test_code_excludes_struct_definition(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
|
||||
code = result.read_writable_code.code_strings[0].code
|
||||
assert "type Calculator struct" not in code
|
||||
|
||||
def test_code_excludes_interface_definition(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
|
||||
code = result.read_writable_code.code_strings[0].code
|
||||
assert "type Formatter interface" not in code
|
||||
|
||||
def test_helpers_include_other_functions_and_methods(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
|
||||
helper_names = sorted(h.only_function_name for h in result.helper_functions)
|
||||
assert "subtract" in helper_names
|
||||
assert "multiply" in helper_names
|
||||
assert "Greet" in helper_names
|
||||
assert "AddFloat" in helper_names
|
||||
assert "Sqrt" in helper_names
|
||||
assert "Reset" in helper_names
|
||||
assert "Add" not in helper_names
|
||||
|
||||
def test_helper_sources_are_full_functions(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
|
||||
by_name = {h.only_function_name: h for h in result.helper_functions}
|
||||
|
||||
assert by_name["subtract"].source_code == dedent("""\
|
||||
func subtract(a, b int) int {
|
||||
\treturn a - b
|
||||
}""")
|
||||
|
||||
assert by_name["Greet"].source_code == dedent("""\
|
||||
// Greet builds a greeting message.
|
||||
func Greet(name string) string {
|
||||
\treturn fmt.Sprintf("Hello, %s", str.TrimSpace(name))
|
||||
}
|
||||
""")
|
||||
|
||||
def test_method_helpers_have_qualified_names(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
|
||||
by_name = {h.only_function_name: h for h in result.helper_functions}
|
||||
assert by_name["AddFloat"].qualified_name == "Calculator.AddFloat"
|
||||
assert by_name["AddFloat"].fully_qualified_name == "Calculator.AddFloat"
|
||||
assert by_name["subtract"].qualified_name == "subtract"
|
||||
|
||||
def test_no_read_only_context_for_plain_function(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
|
||||
assert result.read_only_context_code == ""
|
||||
|
||||
def test_relative_path(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
|
||||
assert result.read_writable_code.code_strings[0].file_path == Path("calc.go")
|
||||
|
||||
def test_language_tag(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
|
||||
assert result.read_writable_code.code_strings[0].language == "go"
|
||||
|
||||
def test_testgen_fqns_match_helpers(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
|
||||
fqns = set(result.testgen_helper_fqns)
|
||||
helper_fqns = {h.fully_qualified_name for h in result.helper_functions}
|
||||
assert fqns == helper_fqns
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: targeting a method with a pointer receiver
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildContextPointerReceiverMethod:
|
||||
"""Target: (c *Calculator) AddFloat(val float64) — pointer receiver method."""
|
||||
|
||||
def _build(self, tmp_path: Path) -> CodeOptimizationContext:
|
||||
return _build_context_for_function(
|
||||
CALCULATOR_SOURCE,
|
||||
"calc.go",
|
||||
"AddFloat",
|
||||
tmp_path,
|
||||
parents=[FunctionParent(name="Calculator", type="StructDef")],
|
||||
is_method=True,
|
||||
)
|
||||
|
||||
def test_full_assembled_code_string(self, tmp_path: Path) -> None:
|
||||
result = self._build(tmp_path)
|
||||
code = result.read_writable_code.code_strings[0].code
|
||||
|
||||
expected = dedent("""\
|
||||
import (
|
||||
\t"fmt"
|
||||
\t"math"
|
||||
\tstr "strings"
|
||||
)
|
||||
|
||||
// AddFloat adds a float value and records history.
|
||||
func (c *Calculator) AddFloat(val float64) float64 {
|
||||
\tc.Result += val
|
||||
\tc.History = append(c.History, c.Result)
|
||||
\treturn c.Result
|
||||
}
|
||||
|
||||
|
||||
// Add returns the sum of two integers.
|
||||
func Add(a, b int) int {
|
||||
\treturn a + b
|
||||
}
|
||||
|
||||
|
||||
func subtract(a, b int) int {
|
||||
\treturn a - b
|
||||
}
|
||||
|
||||
func multiply(a, b int) int {
|
||||
\treturn a * b
|
||||
}
|
||||
|
||||
// Greet builds a greeting message.
|
||||
func Greet(name string) string {
|
||||
\treturn fmt.Sprintf("Hello, %s", str.TrimSpace(name))
|
||||
}
|
||||
|
||||
|
||||
// Sqrt computes the square root of the current result.
|
||||
func (c *Calculator) Sqrt() float64 {
|
||||
\tc.Result = math.Sqrt(c.Result)
|
||||
\tc.History = append(c.History, c.Result)
|
||||
\treturn c.Result
|
||||
}
|
||||
|
||||
|
||||
// Reset zeroes out the calculator.
|
||||
func (c Calculator) Reset() Calculator {
|
||||
\tc.Result = 0
|
||||
\tc.History = nil
|
||||
\treturn c
|
||||
}
|
||||
""")
|
||||
assert code == expected
|
||||
|
||||
def test_code_excludes_package_and_type_defs(self, tmp_path: Path) -> None:
|
||||
result = self._build(tmp_path)
|
||||
code = result.read_writable_code.code_strings[0].code
|
||||
assert "package calc" not in code
|
||||
assert "type Calculator struct" not in code
|
||||
assert "type Formatter interface" not in code
|
||||
|
||||
def test_read_only_context_is_struct_definition(self, tmp_path: Path) -> None:
|
||||
result = self._build(tmp_path)
|
||||
assert result.read_only_context_code == dedent("""\
|
||||
type Calculator struct {
|
||||
\tResult float64
|
||||
\tHistory []float64
|
||||
}""")
|
||||
|
||||
def test_helpers_exclude_self_include_others(self, tmp_path: Path) -> None:
|
||||
result = self._build(tmp_path)
|
||||
helper_names = sorted(h.only_function_name for h in result.helper_functions)
|
||||
assert "AddFloat" not in helper_names
|
||||
assert "Add" in helper_names
|
||||
assert "subtract" in helper_names
|
||||
assert "multiply" in helper_names
|
||||
assert "Greet" in helper_names
|
||||
assert "Sqrt" in helper_names
|
||||
assert "Reset" in helper_names
|
||||
|
||||
def test_target_not_duplicated_in_code_string(self, tmp_path: Path) -> None:
|
||||
result = self._build(tmp_path)
|
||||
code = result.read_writable_code.code_strings[0].code
|
||||
assert code.count("func (c *Calculator) AddFloat") == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: targeting a value receiver method
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildContextValueReceiverMethod:
|
||||
"""Target: (c Calculator) Reset() — value receiver method."""
|
||||
|
||||
def _build(self, tmp_path: Path) -> CodeOptimizationContext:
|
||||
return _build_context_for_function(
|
||||
CALCULATOR_SOURCE,
|
||||
"calc.go",
|
||||
"Reset",
|
||||
tmp_path,
|
||||
parents=[FunctionParent(name="Calculator", type="StructDef")],
|
||||
is_method=True,
|
||||
)
|
||||
|
||||
def test_target_in_code_string(self, tmp_path: Path) -> None:
|
||||
result = self._build(tmp_path)
|
||||
code = result.read_writable_code.code_strings[0].code
|
||||
|
||||
expected_target = dedent("""\
|
||||
// Reset zeroes out the calculator.
|
||||
func (c Calculator) Reset() Calculator {
|
||||
\tc.Result = 0
|
||||
\tc.History = nil
|
||||
\treturn c
|
||||
}""")
|
||||
assert code.count("func (c Calculator) Reset()") == 1
|
||||
assert expected_target in code
|
||||
|
||||
def test_helpers_include_other_methods_on_same_struct(self, tmp_path: Path) -> None:
|
||||
result = self._build(tmp_path)
|
||||
helper_names = sorted(h.only_function_name for h in result.helper_functions)
|
||||
assert "Reset" not in helper_names
|
||||
assert "AddFloat" in helper_names
|
||||
assert "Sqrt" in helper_names
|
||||
assert "Add" in helper_names
|
||||
|
||||
def test_helper_code_in_assembled_string(self, tmp_path: Path) -> None:
|
||||
result = self._build(tmp_path)
|
||||
code = result.read_writable_code.code_strings[0].code
|
||||
assert "func (c *Calculator) AddFloat" in code
|
||||
assert "func (c *Calculator) Sqrt()" in code
|
||||
assert "func Add(a, b int) int" in code
|
||||
assert "func subtract(a, b int) int" in code
|
||||
|
||||
def test_struct_in_read_only_context(self, tmp_path: Path) -> None:
|
||||
result = self._build(tmp_path)
|
||||
assert result.read_only_context_code == dedent("""\
|
||||
type Calculator struct {
|
||||
\tResult float64
|
||||
\tHistory []float64
|
||||
}""")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: simple source with no imports, no methods, one function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildContextMinimalSource:
|
||||
"""Target: Double(x int) — minimal file with no imports or structs."""
|
||||
|
||||
def test_no_imports_no_prefix(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(SIMPLE_SOURCE, "simple.go", "Double", tmp_path)
|
||||
code = result.read_writable_code.code_strings[0].code
|
||||
assert code == dedent("""\
|
||||
func Double(x int) int {
|
||||
\treturn x * 2
|
||||
}""")
|
||||
|
||||
def test_no_helpers(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(SIMPLE_SOURCE, "simple.go", "Double", tmp_path)
|
||||
assert result.helper_functions == []
|
||||
assert result.testgen_helper_fqns == []
|
||||
|
||||
def test_empty_read_only_context(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(SIMPLE_SOURCE, "simple.go", "Double", tmp_path)
|
||||
assert result.read_only_context_code == ""
|
||||
|
||||
def test_preexisting_objects_empty(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(SIMPLE_SOURCE, "simple.go", "Double", tmp_path)
|
||||
assert result.preexisting_objects == set()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: init function and globals in context
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildContextWithInit:
|
||||
"""Target: Process() — source has init(), global vars, consts, struct."""
|
||||
|
||||
def test_init_in_read_only_context(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path)
|
||||
assert "func init()" in result.read_only_context_code
|
||||
|
||||
def test_referenced_globals_in_read_only_context(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path)
|
||||
assert "globalCache" in result.read_only_context_code
|
||||
assert "mu" in result.read_only_context_code
|
||||
|
||||
def test_referenced_const_in_read_only_context(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path)
|
||||
assert "MaxRetries" in result.read_only_context_code
|
||||
|
||||
def test_referenced_struct_in_read_only_context(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path)
|
||||
assert "type Config struct" in result.read_only_context_code
|
||||
|
||||
def test_init_not_in_helpers(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path)
|
||||
helper_names = [h.only_function_name for h in result.helper_functions]
|
||||
assert "init" not in helper_names
|
||||
|
||||
def test_init_not_in_read_writable_code(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path)
|
||||
code = result.read_writable_code.code_strings[0].code
|
||||
assert "func init()" not in code
|
||||
|
||||
def test_full_read_only_context_string(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path)
|
||||
expected = dedent("""\
|
||||
var (
|
||||
\tglobalCache map[string]int
|
||||
\tmu sync.Mutex
|
||||
)
|
||||
|
||||
const MaxRetries = 5
|
||||
|
||||
type Config struct {
|
||||
\tName string
|
||||
\tMax int
|
||||
}
|
||||
|
||||
func init() {
|
||||
\tglobalCache = make(map[string]int)
|
||||
\tglobalCache["default"] = 0
|
||||
\tdefaultCfg := Config{Name: "prod", Max: MaxRetries}
|
||||
\t_ = defaultCfg
|
||||
\tmu.Lock()
|
||||
\tmu.Unlock()
|
||||
}""")
|
||||
assert result.read_only_context_code == expected
|
||||
|
||||
|
||||
class TestBuildContextNoInit:
|
||||
"""Source without init — verify no init context is added."""
|
||||
|
||||
def test_no_init_no_extra_read_only(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
|
||||
assert "func init()" not in result.read_only_context_code
|
||||
|
||||
def test_no_init_read_only_empty_for_function(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
|
||||
assert result.read_only_context_code == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: subdirectory / relative path handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildContextSubdirectory:
|
||||
"""Source file in a pkg/ subdirectory."""
|
||||
|
||||
def test_relative_path_includes_subdir(self, tmp_path: Path) -> None:
|
||||
root = tmp_path.resolve()
|
||||
pkg = root / "pkg"
|
||||
pkg.mkdir()
|
||||
source_file = (pkg / "calc.go").resolve()
|
||||
source_file.write_text(SIMPLE_SOURCE, encoding="utf-8")
|
||||
|
||||
func = FunctionToOptimize(function_name="Double", file_path=source_file, language="go")
|
||||
ctx = extract_code_context(func, root)
|
||||
result = _build_optimization_context(ctx, source_file, "go", root)
|
||||
|
||||
assert result.read_writable_code.code_strings[0].file_path == Path("pkg/calc.go")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: hashing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildContextHashing:
|
||||
def test_hash_is_sha256_of_flat(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
|
||||
expected_hash = hashlib.sha256(result.read_writable_code.flat.encode("utf-8")).hexdigest()
|
||||
assert result.hashing_code_context_hash == expected_hash
|
||||
|
||||
def test_hashing_code_equals_flat(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
|
||||
assert result.hashing_code_context == result.read_writable_code.flat
|
||||
|
||||
def test_different_targets_different_hashes(self, tmp_path: Path) -> None:
|
||||
dir_a = tmp_path / "a"
|
||||
dir_a.mkdir()
|
||||
dir_b = tmp_path / "b"
|
||||
dir_b.mkdir()
|
||||
|
||||
r1 = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", dir_a)
|
||||
r2 = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Greet", dir_b)
|
||||
|
||||
assert r1.hashing_code_context_hash != r2.hashing_code_context_hash
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: testgen context
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildContextTestgen:
|
||||
def test_testgen_matches_read_writable(self, tmp_path: Path) -> None:
|
||||
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
|
||||
assert result.testgen_context.markdown == result.read_writable_code.markdown
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: token limit enforcement
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildContextTokenLimits:
|
||||
def test_exceeds_optim_token_limit(self, tmp_path: Path) -> None:
|
||||
root = tmp_path.resolve()
|
||||
source_file = (root / "big.go").resolve()
|
||||
huge_code = "package big\n\nfunc Big() string {\n\treturn " + '"x" + ' * 100000 + '"x"\n}\n'
|
||||
source_file.write_text(huge_code, encoding="utf-8")
|
||||
|
||||
func = FunctionToOptimize(function_name="Big", file_path=source_file, language="go")
|
||||
ctx = extract_code_context(func, root)
|
||||
|
||||
with pytest.raises(ValueError, match="Read-writable code has exceeded token limit"):
|
||||
_build_optimization_context(ctx, source_file, "go", root, optim_token_limit=10)
|
||||
|
||||
def test_exceeds_testgen_token_limit(self, tmp_path: Path) -> None:
|
||||
root = tmp_path.resolve()
|
||||
source_file = (root / "big.go").resolve()
|
||||
huge_code = "package big\n\nfunc Big() string {\n\treturn " + '"x" + ' * 100000 + '"x"\n}\n'
|
||||
source_file.write_text(huge_code, encoding="utf-8")
|
||||
|
||||
func = FunctionToOptimize(function_name="Big", file_path=source_file, language="go")
|
||||
ctx = extract_code_context(func, root)
|
||||
|
||||
with pytest.raises(ValueError, match="Testgen code context has exceeded token limit"):
|
||||
_build_optimization_context(
|
||||
ctx, source_file, "go", root, optim_token_limit=1_000_000, testgen_token_limit=10
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: GoSupport wiring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGoSupportFunctionOptimizerClass:
|
||||
def test_returns_go_function_optimizer(self) -> None:
|
||||
from codeflash.languages.golang.function_optimizer import GoFunctionOptimizer
|
||||
from codeflash.languages.golang.support import GoSupport
|
||||
|
||||
support = GoSupport()
|
||||
assert support.function_optimizer_class is GoFunctionOptimizer
|
||||
|
|
@ -216,3 +216,125 @@ class TestGoAnalyzerExtract:
|
|||
def test_extract_nonexistent(self) -> None:
|
||||
analyzer = GoAnalyzer()
|
||||
assert analyzer.extract_function_source(GO_SOURCE, "DoesNotExist") is None
|
||||
|
||||
|
||||
GLOBALS_SOURCE = """\
|
||||
package server
|
||||
|
||||
import "sync"
|
||||
|
||||
var (
|
||||
\tglobalCache map[string]int
|
||||
\tmu sync.Mutex
|
||||
)
|
||||
|
||||
var singleVar = 42
|
||||
|
||||
const MaxRetries = 5
|
||||
|
||||
const (
|
||||
\tDefaultName = "prod"
|
||||
\tTimeout = 30
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
\tName string
|
||||
\tMax int
|
||||
}
|
||||
|
||||
func init() {
|
||||
\tglobalCache = make(map[string]int)
|
||||
\tglobalCache["default"] = 0
|
||||
\tdefaultCfg := Config{Name: DefaultName, Max: MaxRetries}
|
||||
\t_ = defaultCfg
|
||||
\tmu.Lock()
|
||||
\tmu.Unlock()
|
||||
}
|
||||
|
||||
func Process() int {
|
||||
\treturn singleVar + MaxRetries
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
class TestGoAnalyzerGlobalDeclarations:
|
||||
def test_find_var_group(self) -> None:
|
||||
analyzer = GoAnalyzer()
|
||||
decls = analyzer.find_global_declarations(GLOBALS_SOURCE)
|
||||
var_decls = [d for d in decls if d.kind == "var"]
|
||||
all_names = [name for d in var_decls for name in d.names]
|
||||
assert "globalCache" in all_names
|
||||
assert "mu" in all_names
|
||||
assert "singleVar" in all_names
|
||||
|
||||
def test_find_const_group(self) -> None:
|
||||
analyzer = GoAnalyzer()
|
||||
decls = analyzer.find_global_declarations(GLOBALS_SOURCE)
|
||||
const_decls = [d for d in decls if d.kind == "const"]
|
||||
all_names = [name for d in const_decls for name in d.names]
|
||||
assert "MaxRetries" in all_names
|
||||
assert "DefaultName" in all_names
|
||||
assert "Timeout" in all_names
|
||||
|
||||
def test_grouped_var_names_together(self) -> None:
|
||||
analyzer = GoAnalyzer()
|
||||
decls = analyzer.find_global_declarations(GLOBALS_SOURCE)
|
||||
var_group = next(d for d in decls if "globalCache" in d.names)
|
||||
assert var_group.names == ("globalCache", "mu")
|
||||
|
||||
def test_single_var(self) -> None:
|
||||
analyzer = GoAnalyzer()
|
||||
decls = analyzer.find_global_declarations(GLOBALS_SOURCE)
|
||||
single = next(d for d in decls if "singleVar" in d.names)
|
||||
assert single.kind == "var"
|
||||
assert single.source_code == "var singleVar = 42"
|
||||
|
||||
def test_const_group_source_code(self) -> None:
|
||||
analyzer = GoAnalyzer()
|
||||
decls = analyzer.find_global_declarations(GLOBALS_SOURCE)
|
||||
group = next(d for d in decls if "DefaultName" in d.names)
|
||||
assert "DefaultName" in group.source_code
|
||||
assert "Timeout" in group.source_code
|
||||
|
||||
def test_no_globals_in_clean_source(self) -> None:
|
||||
analyzer = GoAnalyzer()
|
||||
decls = analyzer.find_global_declarations("package main\n\nfunc main() {}\n")
|
||||
assert decls == []
|
||||
|
||||
|
||||
class TestGoAnalyzerCollectBodyIdentifiers:
|
||||
def test_init_body_identifiers(self) -> None:
|
||||
analyzer = GoAnalyzer()
|
||||
ids = analyzer.collect_body_identifiers(GLOBALS_SOURCE, "init")
|
||||
assert "globalCache" in ids
|
||||
assert "Config" in ids
|
||||
assert "DefaultName" in ids
|
||||
assert "MaxRetries" in ids
|
||||
assert "mu" in ids
|
||||
|
||||
def test_process_body_identifiers(self) -> None:
|
||||
analyzer = GoAnalyzer()
|
||||
ids = analyzer.collect_body_identifiers(GLOBALS_SOURCE, "Process")
|
||||
assert "singleVar" in ids
|
||||
assert "MaxRetries" in ids
|
||||
|
||||
def test_nonexistent_function_returns_empty(self) -> None:
|
||||
analyzer = GoAnalyzer()
|
||||
ids = analyzer.collect_body_identifiers(GLOBALS_SOURCE, "DoesNotExist")
|
||||
assert ids == set()
|
||||
|
||||
def test_method_body_identifiers(self) -> None:
|
||||
source = """\
|
||||
package calc
|
||||
|
||||
type Calc struct{ val int }
|
||||
|
||||
var offset = 10
|
||||
|
||||
func (c *Calc) Compute() int {
|
||||
\treturn c.val + offset
|
||||
}
|
||||
"""
|
||||
analyzer = GoAnalyzer()
|
||||
ids = analyzer.collect_body_identifiers(source, "Compute", receiver_type="Calc")
|
||||
assert "offset" in ids
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ class TestReplaceFunction:
|
|||
assert result == expected
|
||||
|
||||
|
||||
class TestAddGlobalDeclarations:
|
||||
class TestAddGlobalDeclarationsImports:
|
||||
def test_add_import_to_existing_block(self) -> None:
|
||||
original = 'package calc\n\nimport (\n\t"fmt"\n)\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n'
|
||||
optimized = 'package calc\n\nimport (\n\t"fmt"\n\t"math"\n)\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n'
|
||||
|
|
@ -102,6 +102,464 @@ class TestAddGlobalDeclarations:
|
|||
assert result == source
|
||||
|
||||
|
||||
class TestAddGlobalDeclarationsNewVar:
|
||||
def test_add_single_new_var(self) -> None:
|
||||
original = (
|
||||
"package calc\n\n"
|
||||
"func Add(a, b int) int {\n"
|
||||
"\treturn a + b\n"
|
||||
"}\n"
|
||||
)
|
||||
optimized = (
|
||||
"package calc\n\n"
|
||||
"var cache = make(map[int]int)\n\n"
|
||||
"func Add(a, b int) int {\n"
|
||||
"\treturn a + b\n"
|
||||
"}\n"
|
||||
)
|
||||
result = add_global_declarations(optimized, original)
|
||||
expected = (
|
||||
"package calc\n"
|
||||
"var cache = make(map[int]int)\n\n"
|
||||
"\n"
|
||||
"func Add(a, b int) int {\n"
|
||||
"\treturn a + b\n"
|
||||
"}\n"
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
def test_add_grouped_var_block(self) -> None:
|
||||
original = (
|
||||
"package server\n\n"
|
||||
'import "fmt"\n\n'
|
||||
"func Process() {\n"
|
||||
"\tfmt.Println()\n"
|
||||
"}\n"
|
||||
)
|
||||
optimized = (
|
||||
"package server\n\n"
|
||||
'import "fmt"\n\n'
|
||||
"var (\n"
|
||||
"\tcache map[string]int\n"
|
||||
"\tbuffer []byte\n"
|
||||
")\n\n"
|
||||
"func Process() {\n"
|
||||
"\tfmt.Println()\n"
|
||||
"}\n"
|
||||
)
|
||||
result = add_global_declarations(optimized, original)
|
||||
expected = (
|
||||
"package server\n\n"
|
||||
'import "fmt"\n'
|
||||
"var (\n"
|
||||
"\tcache map[string]int\n"
|
||||
"\tbuffer []byte\n"
|
||||
")\n\n"
|
||||
"\n"
|
||||
"func Process() {\n"
|
||||
"\tfmt.Println()\n"
|
||||
"}\n"
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
def test_add_new_var_preserves_existing_var(self) -> None:
|
||||
original = (
|
||||
"package calc\n\n"
|
||||
"var version = 1\n\n"
|
||||
"func Add(a, b int) int {\n"
|
||||
"\treturn a + b\n"
|
||||
"}\n"
|
||||
)
|
||||
optimized = (
|
||||
"package calc\n\n"
|
||||
"var version = 1\n\n"
|
||||
"var cache = make(map[int]int)\n\n"
|
||||
"func Add(a, b int) int {\n"
|
||||
"\treturn a + b\n"
|
||||
"}\n"
|
||||
)
|
||||
result = add_global_declarations(optimized, original)
|
||||
expected = (
|
||||
"package calc\n\n"
|
||||
"var version = 1\n"
|
||||
"var cache = make(map[int]int)\n\n"
|
||||
"\n"
|
||||
"func Add(a, b int) int {\n"
|
||||
"\treturn a + b\n"
|
||||
"}\n"
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestAddGlobalDeclarationsNewConst:
|
||||
def test_add_single_new_const(self) -> None:
|
||||
original = (
|
||||
"package calc\n\n"
|
||||
"func Add(a, b int) int {\n"
|
||||
"\treturn a + b\n"
|
||||
"}\n"
|
||||
)
|
||||
optimized = (
|
||||
"package calc\n\n"
|
||||
"const maxSize = 1024\n\n"
|
||||
"func Add(a, b int) int {\n"
|
||||
"\treturn a + b\n"
|
||||
"}\n"
|
||||
)
|
||||
result = add_global_declarations(optimized, original)
|
||||
expected = (
|
||||
"package calc\n"
|
||||
"const maxSize = 1024\n\n"
|
||||
"\n"
|
||||
"func Add(a, b int) int {\n"
|
||||
"\treturn a + b\n"
|
||||
"}\n"
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
def test_add_grouped_const_block(self) -> None:
|
||||
original = (
|
||||
"package calc\n\n"
|
||||
"func Add(a, b int) int {\n"
|
||||
"\treturn a + b\n"
|
||||
"}\n"
|
||||
)
|
||||
optimized = (
|
||||
"package calc\n\n"
|
||||
"const (\n"
|
||||
"\tMaxRetries = 5\n"
|
||||
"\tTimeout = 30\n"
|
||||
")\n\n"
|
||||
"func Add(a, b int) int {\n"
|
||||
"\treturn a + b\n"
|
||||
"}\n"
|
||||
)
|
||||
result = add_global_declarations(optimized, original)
|
||||
expected = (
|
||||
"package calc\n"
|
||||
"const (\n"
|
||||
"\tMaxRetries = 5\n"
|
||||
"\tTimeout = 30\n"
|
||||
")\n\n"
|
||||
"\n"
|
||||
"func Add(a, b int) int {\n"
|
||||
"\treturn a + b\n"
|
||||
"}\n"
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
def test_add_new_const_preserves_existing_const(self) -> None:
|
||||
original = (
|
||||
"package calc\n\n"
|
||||
"const Pi = 3.14\n\n"
|
||||
"func Area(r float64) float64 {\n"
|
||||
"\treturn Pi * r * r\n"
|
||||
"}\n"
|
||||
)
|
||||
optimized = (
|
||||
"package calc\n\n"
|
||||
"const Pi = 3.14\n\n"
|
||||
"const TwoPi = 6.28\n\n"
|
||||
"func Area(r float64) float64 {\n"
|
||||
"\treturn Pi * r * r\n"
|
||||
"}\n"
|
||||
)
|
||||
result = add_global_declarations(optimized, original)
|
||||
expected = (
|
||||
"package calc\n\n"
|
||||
"const Pi = 3.14\n"
|
||||
"const TwoPi = 6.28\n\n"
|
||||
"\n"
|
||||
"func Area(r float64) float64 {\n"
|
||||
"\treturn Pi * r * r\n"
|
||||
"}\n"
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestAddGlobalDeclarationsModifyVar:
|
||||
def test_modify_single_var_value(self) -> None:
|
||||
original = (
|
||||
"package calc\n\n"
|
||||
"var bufferSize = 256\n\n"
|
||||
"func Process() int {\n"
|
||||
"\treturn bufferSize\n"
|
||||
"}\n"
|
||||
)
|
||||
optimized = (
|
||||
"package calc\n\n"
|
||||
"var bufferSize = 1024\n\n"
|
||||
"func Process() int {\n"
|
||||
"\treturn bufferSize\n"
|
||||
"}\n"
|
||||
)
|
||||
result = add_global_declarations(optimized, original)
|
||||
expected = (
|
||||
"package calc\n\n"
|
||||
"var bufferSize = 1024\n"
|
||||
"\n"
|
||||
"func Process() int {\n"
|
||||
"\treturn bufferSize\n"
|
||||
"}\n"
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
def test_modify_grouped_var_block(self) -> None:
|
||||
original = (
|
||||
"package server\n\n"
|
||||
"var (\n"
|
||||
'\thost = "localhost"\n'
|
||||
"\tport = 8080\n"
|
||||
")\n\n"
|
||||
"func Addr() string {\n"
|
||||
"\treturn host\n"
|
||||
"}\n"
|
||||
)
|
||||
optimized = (
|
||||
"package server\n\n"
|
||||
"var (\n"
|
||||
'\thost = "0.0.0.0"\n'
|
||||
"\tport = 9090\n"
|
||||
")\n\n"
|
||||
"func Addr() string {\n"
|
||||
"\treturn host\n"
|
||||
"}\n"
|
||||
)
|
||||
result = add_global_declarations(optimized, original)
|
||||
expected = (
|
||||
"package server\n\n"
|
||||
"var (\n"
|
||||
'\thost = "0.0.0.0"\n'
|
||||
"\tport = 9090\n"
|
||||
")\n"
|
||||
"\n"
|
||||
"func Addr() string {\n"
|
||||
"\treturn host\n"
|
||||
"}\n"
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
def test_modify_var_type(self) -> None:
|
||||
original = (
|
||||
"package calc\n\n"
|
||||
"var counter int\n\n"
|
||||
"func Inc() {\n"
|
||||
"\tcounter++\n"
|
||||
"}\n"
|
||||
)
|
||||
optimized = (
|
||||
"package calc\n\n"
|
||||
"var counter int64\n\n"
|
||||
"func Inc() {\n"
|
||||
"\tcounter++\n"
|
||||
"}\n"
|
||||
)
|
||||
result = add_global_declarations(optimized, original)
|
||||
expected = (
|
||||
"package calc\n\n"
|
||||
"var counter int64\n"
|
||||
"\n"
|
||||
"func Inc() {\n"
|
||||
"\tcounter++\n"
|
||||
"}\n"
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestAddGlobalDeclarationsModifyConst:
|
||||
def test_modify_single_const_value(self) -> None:
|
||||
original = (
|
||||
"package calc\n\n"
|
||||
"const MaxRetries = 3\n\n"
|
||||
"func Retries() int {\n"
|
||||
"\treturn MaxRetries\n"
|
||||
"}\n"
|
||||
)
|
||||
optimized = (
|
||||
"package calc\n\n"
|
||||
"const MaxRetries = 10\n\n"
|
||||
"func Retries() int {\n"
|
||||
"\treturn MaxRetries\n"
|
||||
"}\n"
|
||||
)
|
||||
result = add_global_declarations(optimized, original)
|
||||
expected = (
|
||||
"package calc\n\n"
|
||||
"const MaxRetries = 10\n"
|
||||
"\n"
|
||||
"func Retries() int {\n"
|
||||
"\treturn MaxRetries\n"
|
||||
"}\n"
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
def test_modify_const_group(self) -> None:
|
||||
original = (
|
||||
"package server\n\n"
|
||||
"const (\n"
|
||||
"\tDefaultTimeout = 30\n"
|
||||
"\tMaxConnections = 100\n"
|
||||
")\n\n"
|
||||
"func Config() int {\n"
|
||||
"\treturn DefaultTimeout\n"
|
||||
"}\n"
|
||||
)
|
||||
optimized = (
|
||||
"package server\n\n"
|
||||
"const (\n"
|
||||
"\tDefaultTimeout = 60\n"
|
||||
"\tMaxConnections = 500\n"
|
||||
")\n\n"
|
||||
"func Config() int {\n"
|
||||
"\treturn DefaultTimeout\n"
|
||||
"}\n"
|
||||
)
|
||||
result = add_global_declarations(optimized, original)
|
||||
expected = (
|
||||
"package server\n\n"
|
||||
"const (\n"
|
||||
"\tDefaultTimeout = 60\n"
|
||||
"\tMaxConnections = 500\n"
|
||||
")\n"
|
||||
"\n"
|
||||
"func Config() int {\n"
|
||||
"\treturn DefaultTimeout\n"
|
||||
"}\n"
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestAddGlobalDeclarationsMixed:
|
||||
def test_new_import_and_new_var(self) -> None:
|
||||
original = (
|
||||
"package calc\n\n"
|
||||
"func Add(a, b int) int {\n"
|
||||
"\treturn a + b\n"
|
||||
"}\n"
|
||||
)
|
||||
optimized = (
|
||||
"package calc\n\n"
|
||||
'import "sync"\n\n'
|
||||
"var mu sync.Mutex\n\n"
|
||||
"func Add(a, b int) int {\n"
|
||||
"\treturn a + b\n"
|
||||
"}\n"
|
||||
)
|
||||
result = add_global_declarations(optimized, original)
|
||||
expected = (
|
||||
"package calc\n"
|
||||
"import (\n"
|
||||
'\t"sync"\n'
|
||||
")\n"
|
||||
"var mu sync.Mutex\n\n"
|
||||
"\n\n"
|
||||
"func Add(a, b int) int {\n"
|
||||
"\treturn a + b\n"
|
||||
"}\n"
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
def test_new_and_modified_globals_together(self) -> None:
|
||||
original = (
|
||||
"package server\n\n"
|
||||
"var bufferSize = 256\n\n"
|
||||
"func Process() int {\n"
|
||||
"\treturn bufferSize\n"
|
||||
"}\n"
|
||||
)
|
||||
optimized = (
|
||||
"package server\n\n"
|
||||
"var bufferSize = 1024\n\n"
|
||||
"var cache = make(map[string]int)\n\n"
|
||||
"func Process() int {\n"
|
||||
"\treturn bufferSize\n"
|
||||
"}\n"
|
||||
)
|
||||
result = add_global_declarations(optimized, original)
|
||||
expected = (
|
||||
"package server\n\n"
|
||||
"var bufferSize = 1024\n"
|
||||
"var cache = make(map[string]int)\n\n"
|
||||
"\n"
|
||||
"func Process() int {\n"
|
||||
"\treturn bufferSize\n"
|
||||
"}\n"
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
def test_no_globals_in_optimized_returns_unchanged(self) -> None:
|
||||
original = (
|
||||
"package calc\n\n"
|
||||
"var version = 1\n\n"
|
||||
"func Add(a, b int) int {\n"
|
||||
"\treturn a + b\n"
|
||||
"}\n"
|
||||
)
|
||||
optimized = (
|
||||
"package calc\n\n"
|
||||
"func Add(a, b int) int {\n"
|
||||
"\treturn a + b\n"
|
||||
"}\n"
|
||||
)
|
||||
result = add_global_declarations(optimized, original)
|
||||
assert result == original
|
||||
|
||||
def test_identical_globals_returns_unchanged(self) -> None:
|
||||
source = (
|
||||
"package calc\n\n"
|
||||
"var version = 1\n\n"
|
||||
"const MaxSize = 100\n\n"
|
||||
"func Add(a, b int) int {\n"
|
||||
"\treturn a + b\n"
|
||||
"}\n"
|
||||
)
|
||||
result = add_global_declarations(source, source)
|
||||
assert result == source
|
||||
|
||||
def test_full_round_trip_new_import_var_const(self) -> None:
|
||||
original = (
|
||||
"package server\n\n"
|
||||
"import (\n"
|
||||
'\t"fmt"\n'
|
||||
")\n\n"
|
||||
"const Version = 1\n\n"
|
||||
"func Handle() {\n"
|
||||
"\tfmt.Println()\n"
|
||||
"}\n"
|
||||
)
|
||||
optimized = (
|
||||
"package server\n\n"
|
||||
"import (\n"
|
||||
'\t"fmt"\n'
|
||||
'\t"sync"\n'
|
||||
")\n\n"
|
||||
"const Version = 1\n\n"
|
||||
"var mu sync.Mutex\n\n"
|
||||
"const MaxConns = 100\n\n"
|
||||
"func Handle() {\n"
|
||||
"\tmu.Lock()\n"
|
||||
"\tdefer mu.Unlock()\n"
|
||||
"\tfmt.Println()\n"
|
||||
"}\n"
|
||||
)
|
||||
result = add_global_declarations(optimized, original)
|
||||
expected = (
|
||||
"package server\n\n"
|
||||
"import (\n"
|
||||
'\t"fmt"\n'
|
||||
'\t"sync"\n'
|
||||
")\n\n"
|
||||
"const Version = 1\n"
|
||||
"var mu sync.Mutex\n"
|
||||
"const MaxConns = 100\n\n"
|
||||
"\n"
|
||||
"func Handle() {\n"
|
||||
"\tfmt.Println()\n"
|
||||
"}\n"
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestRemoveTestFunctions:
|
||||
def test_remove_single_function(self) -> None:
|
||||
test_source = (
|
||||
|
|
|
|||
Loading…
Reference in a new issue