go function optimizer and handle global vars context correctly

This commit is contained in:
ali 2026-04-23 17:40:20 +02:00
parent ba560308bc
commit ac478de753
No known key found for this signature in database
GPG key ID: 44F9B42770617B9B
9 changed files with 1701 additions and 5 deletions

View file

@ -36,9 +36,15 @@ def extract_code_context(
imports = analyzer.find_imports(source)
import_lines = [_import_to_line(imp) for imp in imports]
read_only_context = ""
read_only_parts: list[str] = []
if receiver_type:
read_only_context = _extract_struct_context(source, receiver_type, analyzer)
struct_ctx = _extract_struct_context(source, receiver_type, analyzer)
if struct_ctx:
read_only_parts.append(struct_ctx)
init_ctx = _extract_init_context(source, analyzer)
if init_ctx:
read_only_parts.append(init_ctx)
helpers = find_helper_functions(source, function, analyzer)
@ -46,7 +52,7 @@ def extract_code_context(
target_code=target_code,
target_file=function.file_path,
helper_functions=helpers,
read_only_context=read_only_context,
read_only_context="\n\n".join(read_only_parts),
imports=import_lines,
language=Language.GO,
)
@ -125,3 +131,27 @@ def _extract_struct_context(source: str, struct_name: str, analyzer: GoAnalyzer)
lines = source.splitlines()
return "\n".join(lines[s.starting_line - 1 : s.ending_line])
return ""
def _extract_init_context(source: str, analyzer: GoAnalyzer) -> str:
init_source = analyzer.extract_function_source(source, "init")
if init_source is None:
return ""
init_ids = analyzer.collect_body_identifiers(source, "init")
if not init_ids:
return init_source
parts: list[str] = []
for decl in analyzer.find_global_declarations(source):
if init_ids & set(decl.names):
parts.append(decl.source_code)
for struct in analyzer.find_structs(source):
if struct.name in init_ids:
lines = source.splitlines()
parts.append("\n".join(lines[struct.starting_line - 1 : struct.ending_line]))
parts.append(init_source)
return "\n\n".join(parts)

View file

@ -0,0 +1,160 @@
from __future__ import annotations
import hashlib
from collections import defaultdict
from typing import TYPE_CHECKING
from codeflash.code_utils.code_utils import encoded_tokens_len
from codeflash.code_utils.config_consts import (
OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
READ_WRITABLE_LIMIT_ERROR,
TESTGEN_CONTEXT_TOKEN_LIMIT,
TESTGEN_LIMIT_ERROR,
)
from codeflash.either import Failure, Success
from codeflash.languages.function_optimizer import FunctionOptimizer
from codeflash.models.models import CodeOptimizationContext, CodeString, CodeStringsMarkdown, FunctionSource
from codeflash.verification.equivalence import compare_test_results
if TYPE_CHECKING:
from pathlib import Path
from codeflash.either import Result
from codeflash.languages.base import CodeContext, HelperFunction
from codeflash.models.models import OriginalCodeBaseline, TestDiff, TestResults
class GoFunctionOptimizer(FunctionOptimizer):
def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
from codeflash.languages import get_language_support
from codeflash.languages.base import Language
language = Language(self.function_to_optimize.language)
lang_support = get_language_support(language)
try:
code_context = lang_support.extract_code_context(
self.function_to_optimize, self.project_root, self.project_root
)
return Success(
_build_optimization_context(
code_context,
self.function_to_optimize.file_path,
self.function_to_optimize.language,
self.project_root,
)
)
except ValueError as e:
return Failure(str(e))
def compare_candidate_results(
self,
baseline_results: OriginalCodeBaseline,
candidate_behavior_results: TestResults,
optimization_candidate_index: int,
) -> tuple[bool, list[TestDiff]]:
return compare_test_results(
baseline_results.behavior_test_results, candidate_behavior_results, pass_fail_only=True
)
def replace_function_and_helpers_with_optimized_code(
self,
code_context: CodeOptimizationContext,
optimized_code: CodeStringsMarkdown,
original_helper_code: dict[Path, str],
) -> bool:
from codeflash.languages.code_replacer import replace_function_definitions_for_language
did_update = False
for module_abspath, qualified_names in self.group_functions_by_file(code_context).items():
did_update |= replace_function_definitions_for_language(
function_names=list(qualified_names),
optimized_code=optimized_code,
module_abspath=module_abspath,
project_root_path=self.project_root,
lang_support=self.language_support,
function_to_optimize=self.function_to_optimize,
)
return did_update
def _build_optimization_context(
code_context: CodeContext,
file_path: Path,
language: str,
project_root: Path,
optim_token_limit: int = OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT,
) -> CodeOptimizationContext:
if code_context.imports:
inner = "\n".join(f"\t{imp}" for imp in code_context.imports)
imports_code = f"import (\n{inner}\n)"
else:
imports_code = ""
try:
target_relative_path = file_path.resolve().relative_to(project_root.resolve())
except ValueError:
target_relative_path = file_path
helpers_by_file: dict[Path, list[HelperFunction]] = defaultdict(list)
helper_function_sources = []
for helper in code_context.helper_functions:
helpers_by_file[helper.file_path].append(helper)
helper_function_sources.append(
FunctionSource(
file_path=helper.file_path,
qualified_name=helper.qualified_name,
fully_qualified_name=helper.qualified_name,
only_function_name=helper.name,
source_code=helper.source_code,
)
)
target_file_code = code_context.target_code
same_file_helpers = helpers_by_file.get(file_path, [])
if same_file_helpers:
helper_code = "\n\n".join(h.source_code for h in same_file_helpers)
target_file_code = target_file_code + "\n\n" + helper_code
if imports_code:
target_file_code = imports_code + "\n\n" + target_file_code
read_writable_code_strings = [CodeString(code=target_file_code, file_path=target_relative_path, language=language)]
for helper_file_path, file_helpers in helpers_by_file.items():
if helper_file_path == file_path:
continue
try:
helper_relative_path = helper_file_path.resolve().relative_to(project_root.resolve())
except ValueError:
helper_relative_path = helper_file_path
combined_helper_code = "\n\n".join(h.source_code for h in file_helpers)
read_writable_code_strings.append(
CodeString(code=combined_helper_code, file_path=helper_relative_path, language=language)
)
read_writable_code = CodeStringsMarkdown(code_strings=read_writable_code_strings, language=language)
testgen_context = CodeStringsMarkdown(code_strings=read_writable_code_strings.copy(), language=language)
read_writable_tokens = encoded_tokens_len(read_writable_code.markdown)
if read_writable_tokens > optim_token_limit:
raise ValueError(READ_WRITABLE_LIMIT_ERROR)
testgen_tokens = encoded_tokens_len(testgen_context.markdown)
if testgen_tokens > testgen_token_limit:
raise ValueError(TESTGEN_LIMIT_ERROR)
code_hash = hashlib.sha256(read_writable_code.flat.encode("utf-8")).hexdigest()
return CodeOptimizationContext(
testgen_context=testgen_context,
read_writable_code=read_writable_code,
read_only_context_code=code_context.read_only_context,
hashing_code_context=read_writable_code.flat,
hashing_code_context_hash=code_hash,
helper_functions=helper_function_sources,
testgen_helper_fqns=[fs.fully_qualified_name for fs in helper_function_sources],
preexisting_objects=set(),
)

View file

@ -81,6 +81,15 @@ class GoImportInfo:
ending_line: int
@dataclass(frozen=True)
class GoGlobalDeclaration:
names: tuple[str, ...]
kind: str
source_code: str
starting_line: int
ending_line: int
class GoAnalyzer:
def __init__(self) -> None:
self._parser = _get_go_parser()
@ -189,6 +198,45 @@ class GoAnalyzer:
)
return results
def find_global_declarations(self, source: str) -> list[GoGlobalDeclaration]:
tree = self.parse(source)
results: list[GoGlobalDeclaration] = []
for node in tree.root_node.children:
if node.type in ("var_declaration", "const_declaration"):
kind = "var" if node.type == "var_declaration" else "const"
names = _extract_declaration_names(node, self)
if names:
results.append(
GoGlobalDeclaration(
names=tuple(names),
kind=kind,
source_code=self.get_node_text(node),
starting_line=node.start_point.row + 1,
ending_line=node.end_point.row + 1,
)
)
return results
def collect_body_identifiers(self, source: str, func_name: str, receiver_type: str | None = None) -> set[str]:
tree = self.parse(source)
for node in tree.root_node.children:
if receiver_type is None and node.type == "function_declaration":
name_node = node.child_by_field_name("name")
if name_node is not None and self.get_node_text(name_node) == func_name:
body = node.child_by_field_name("body")
return _collect_identifiers(body) if body else set()
if receiver_type is not None and node.type == "method_declaration":
name_node = node.child_by_field_name("name")
if name_node is None or self.get_node_text(name_node) != func_name:
continue
recv_node = node.child_by_field_name("receiver")
if recv_node is not None:
recv_name, _ = self.parse_receiver(recv_node)
if recv_name == receiver_type:
body = node.child_by_field_name("body")
return _collect_identifiers(body) if body else set()
return set()
def find_package_name(self, source: str) -> str | None:
tree = self.parse(source)
for node in tree.root_node.children:
@ -317,6 +365,42 @@ def _iter_import_specs(import_node: Node) -> list[Node]:
return results
def _extract_declaration_names(node: Node, analyzer: GoAnalyzer) -> list[str]:
names: list[str] = []
for child in node.children:
if child.type in ("var_spec", "const_spec"):
name_node = child.child_by_field_name("name")
if name_node is not None:
names.append(analyzer.get_node_text(name_node))
elif child.type in ("var_spec_list", "const_spec_list"):
for spec in child.children:
if spec.type in ("var_spec", "const_spec"):
name_node = spec.child_by_field_name("name")
if name_node is not None:
names.append(analyzer.get_node_text(name_node))
return names
def _collect_identifiers(node: Node | None) -> set[str]:
if node is None:
return set()
ids: set[str] = set()
stack = [node]
while stack:
n = stack.pop()
if n.type in ("identifier", "type_identifier"):
text = n.parent
if text is not None and text.type not in ("parameter_declaration", "short_var_declaration"):
ids.add(n.text.decode("utf-8") if n.text else "")
elif text is not None and text.type == "short_var_declaration":
name_node = text.child_by_field_name("left")
if name_node is not n and (name_node is None or n not in (name_node, *tuple(name_node.children))):
ids.add(n.text.decode("utf-8") if n.text else "")
stack.extend(n.children)
ids.discard("")
return ids
def _find_preceding_comment_line(node: Node) -> int | None:
prev = node.prev_named_sibling
if prev is None:

View file

@ -55,6 +55,11 @@ def replace_function(
def add_global_declarations(optimized_code: str, original_source: str, analyzer: GoAnalyzer | None = None) -> str:
analyzer = analyzer or GoAnalyzer()
merged = _merge_imports(optimized_code, original_source, analyzer)
return _merge_global_var_const(optimized_code, merged, analyzer)
def _merge_imports(optimized_code: str, original_source: str, analyzer: GoAnalyzer) -> str:
opt_imports = analyzer.find_imports(optimized_code)
orig_imports = analyzer.find_imports(original_source)
orig_paths = {imp.path for imp in orig_imports}
@ -91,6 +96,74 @@ def add_global_declarations(optimized_code: str, original_source: str, analyzer:
return "".join([*lines[:insert_at], import_block, *lines[insert_at:]])
def _merge_global_var_const(optimized_code: str, original_source: str, analyzer: GoAnalyzer) -> str:
opt_decls = analyzer.find_global_declarations(optimized_code)
if not opt_decls:
return original_source
orig_decls = analyzer.find_global_declarations(original_source)
orig_names_to_decl: dict[str, object] = {}
for decl in orig_decls:
for name in decl.names:
orig_names_to_decl[name] = decl
new_decls: list[str] = []
replaced_decls: set[int] = set()
for opt_decl in opt_decls:
overlapping_orig = None
for name in opt_decl.names:
if name in orig_names_to_decl:
overlapping_orig = orig_names_to_decl[name]
break
if overlapping_orig is None:
new_decls.append(opt_decl.source_code)
elif overlapping_orig.source_code.strip() != opt_decl.source_code.strip():
orig_id = id(overlapping_orig)
if orig_id not in replaced_decls:
replaced_decls.add(orig_id)
original_source = _replace_declaration_block(original_source, overlapping_orig, opt_decl.source_code)
if new_decls:
original_source = _insert_new_declarations(original_source, new_decls, analyzer)
return original_source
def _replace_declaration_block(source: str, orig_decl: object, new_source_code: str) -> str:
lines = source.splitlines(keepends=True)
start = orig_decl.starting_line - 1
end = orig_decl.ending_line
replacement = new_source_code.rstrip("\n") + "\n"
return "".join([*lines[:start], replacement, *lines[end:]])
def _insert_new_declarations(source: str, new_decls: list[str], analyzer: GoAnalyzer) -> str:
lines = source.splitlines(keepends=True)
insert_at = _find_declarations_insert_point(source, analyzer)
block = "\n".join(new_decls) + "\n\n"
return "".join([*lines[:insert_at], block, *lines[insert_at:]])
def _find_declarations_insert_point(source: str, analyzer: GoAnalyzer) -> int:
tree = analyzer.parse(source)
last_line = 0
for node in tree.root_node.children:
if node.type in ("import_declaration", "var_declaration", "const_declaration"):
candidate = node.end_point.row + 1
last_line = max(last_line, candidate)
if last_line > 0:
return last_line
for node in tree.root_node.children:
if node.type == "package_clause":
return node.end_point.row + 1
return 0
def remove_test_functions(test_source: str, functions_to_remove: list[str], analyzer: GoAnalyzer | None = None) -> str:
analyzer = analyzer or GoAnalyzer()
tree = analyzer.parse(test_source)

View file

@ -85,7 +85,9 @@ class GoSupport:
@property
def function_optimizer_class(self) -> type:
raise NotImplementedError
from codeflash.languages.golang.function_optimizer import GoFunctionOptimizer
return GoFunctionOptimizer
def discover_functions(
self, source: str, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None

View file

@ -195,6 +195,120 @@ class TestExtractCodeContextEdgeCases:
assert ctx.imports == ['"fmt"', '"os"', 'str "strings"']
GO_SOURCE_WITH_INIT = """\
package server
import "sync"
var (
\tglobalCache map[string]int
\tmu sync.Mutex
)
const MaxRetries = 5
type Config struct {
\tName string
\tMax int
}
func init() {
\tglobalCache = make(map[string]int)
\tglobalCache["default"] = 0
\tmu.Lock()
\tmu.Unlock()
}
func Process() int {
\treturn MaxRetries
}
"""
class TestExtractCodeContextWithInit:
def test_init_in_read_only_context(self, tmp_path: Path) -> None:
source_file = (tmp_path / "server.go").resolve()
source_file.write_text(GO_SOURCE_WITH_INIT, encoding="utf-8")
func = FunctionToOptimize(function_name="Process", file_path=source_file, language="go")
ctx = extract_code_context(func, tmp_path.resolve())
assert "func init()" in ctx.read_only_context
def test_init_referenced_globals_in_read_only_context(self, tmp_path: Path) -> None:
source_file = (tmp_path / "server.go").resolve()
source_file.write_text(GO_SOURCE_WITH_INIT, encoding="utf-8")
func = FunctionToOptimize(function_name="Process", file_path=source_file, language="go")
ctx = extract_code_context(func, tmp_path.resolve())
assert "globalCache" in ctx.read_only_context
assert "mu" in ctx.read_only_context
def test_init_not_in_helpers(self, tmp_path: Path) -> None:
source_file = (tmp_path / "server.go").resolve()
source_file.write_text(GO_SOURCE_WITH_INIT, encoding="utf-8")
func = FunctionToOptimize(function_name="Process", file_path=source_file, language="go")
ctx = extract_code_context(func, tmp_path.resolve())
helper_names = [h.name for h in ctx.helper_functions]
assert "init" not in helper_names
def test_no_init_no_extra_context(self, tmp_path: Path) -> None:
source_file = (tmp_path / "calc.go").resolve()
source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8")
func = FunctionToOptimize(function_name="Add", file_path=source_file, language="go")
ctx = extract_code_context(func, tmp_path.resolve())
assert "func init()" not in ctx.read_only_context
def test_full_init_read_only_context(self, tmp_path: Path) -> None:
source_file = (tmp_path / "server.go").resolve()
source_file.write_text(GO_SOURCE_WITH_INIT, encoding="utf-8")
func = FunctionToOptimize(function_name="Process", file_path=source_file, language="go")
ctx = extract_code_context(func, tmp_path.resolve())
expected = (
"var (\n"
"\tglobalCache map[string]int\n"
"\tmu sync.Mutex\n"
")\n"
"\n"
"func init() {\n"
"\tglobalCache = make(map[string]int)\n"
"\tglobalCache[\"default\"] = 0\n"
"\tmu.Lock()\n"
"\tmu.Unlock()\n"
"}"
)
assert ctx.read_only_context == expected
def test_method_with_init_combines_struct_and_init_context(self, tmp_path: Path) -> None:
source = """\
package server
var globalOffset = 10
type Calc struct {
\tVal int
}
func init() {
\tglobalOffset = 42
}
func (c *Calc) Compute() int {
\treturn c.Val + globalOffset
}
"""
source_file = (tmp_path / "server.go").resolve()
source_file.write_text(source, encoding="utf-8")
func = FunctionToOptimize(
function_name="Compute",
file_path=source_file,
parents=[FunctionParent(name="Calc", type="StructDef")],
language="go",
is_method=True,
)
ctx = extract_code_context(func, tmp_path.resolve())
assert "type Calc struct" in ctx.read_only_context
assert "func init()" in ctx.read_only_context
assert "var globalOffset = 10" in ctx.read_only_context
class TestFindHelperFunctions:
def test_skips_init_and_main(self, tmp_path: Path) -> None:
source = "package main\n\nfunc init() { println() }\n\nfunc main() { println() }\n\nfunc Target() int { return 1 }\n"

View file

@ -0,0 +1,653 @@
from __future__ import annotations
import hashlib
from pathlib import Path
from textwrap import dedent
from typing import TYPE_CHECKING
import pytest
from codeflash.languages.golang.context import extract_code_context
from codeflash.languages.golang.function_optimizer import _build_optimization_context
from codeflash.models.function_types import FunctionParent, FunctionToOptimize
if TYPE_CHECKING:
from codeflash.models.models import CodeOptimizationContext
# ---------------------------------------------------------------------------
# Realistic Go sources used across test classes
# ---------------------------------------------------------------------------
CALCULATOR_SOURCE = dedent("""\
package calc
import (
\t"fmt"
\t"math"
\tstr "strings"
)
// Calculator holds running computation state.
type Calculator struct {
\tResult float64
\tHistory []float64
}
// Formatter controls output rendering.
type Formatter interface {
\tFormat(val float64) string
}
// Add returns the sum of two integers.
func Add(a, b int) int {
\treturn a + b
}
func subtract(a, b int) int {
\treturn a - b
}
func multiply(a, b int) int {
\treturn a * b
}
// Greet builds a greeting message.
func Greet(name string) string {
\treturn fmt.Sprintf("Hello, %s", str.TrimSpace(name))
}
// AddFloat adds a float value and records history.
func (c *Calculator) AddFloat(val float64) float64 {
\tc.Result += val
\tc.History = append(c.History, c.Result)
\treturn c.Result
}
// Sqrt computes the square root of the current result.
func (c *Calculator) Sqrt() float64 {
\tc.Result = math.Sqrt(c.Result)
\tc.History = append(c.History, c.Result)
\treturn c.Result
}
// Reset zeroes out the calculator.
func (c Calculator) Reset() Calculator {
\tc.Result = 0
\tc.History = nil
\treturn c
}
""")
SIMPLE_SOURCE = dedent("""\
package simple
func Double(x int) int {
\treturn x * 2
}
""")
INIT_SOURCE = dedent("""\
package server
import (
\t"fmt"
\t"sync"
)
var (
\tglobalCache map[string]int
\tmu sync.Mutex
)
var singleVar = 42
const MaxRetries = 5
type Config struct {
\tName string
\tMax int
}
func init() {
\tglobalCache = make(map[string]int)
\tglobalCache["default"] = 0
\tdefaultCfg := Config{Name: "prod", Max: MaxRetries}
\t_ = defaultCfg
\tmu.Lock()
\tmu.Unlock()
}
func Process() int {
\tfmt.Println("processing")
\treturn singleVar + MaxRetries
}
""")
# ---------------------------------------------------------------------------
# Helpers to drive the full extract → build pipeline
# ---------------------------------------------------------------------------
def _build_context_for_function(
source: str,
filename: str,
function_name: str,
tmp_path: Path,
parents: list[FunctionParent] | None = None,
is_method: bool = False,
) -> CodeOptimizationContext:
root = tmp_path.resolve()
source_file = (root / filename).resolve()
source_file.write_text(source, encoding="utf-8")
func = FunctionToOptimize(
function_name=function_name, file_path=source_file, parents=parents or [], language="go", is_method=is_method
)
code_context = extract_code_context(func, root)
return _build_optimization_context(code_context, source_file, "go", root)
# ---------------------------------------------------------------------------
# Tests: targeting a plain exported function
# ---------------------------------------------------------------------------
class TestBuildContextExportedFunction:
"""Target: Add(a, b int) int — a plain exported function with a doc comment."""
def test_full_assembled_code_string(self, tmp_path: Path) -> None:
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
code = result.read_writable_code.code_strings[0].code
expected = dedent("""\
import (
\t"fmt"
\t"math"
\tstr "strings"
)
// Add returns the sum of two integers.
func Add(a, b int) int {
\treturn a + b
}
func subtract(a, b int) int {
\treturn a - b
}
func multiply(a, b int) int {
\treturn a * b
}
// Greet builds a greeting message.
func Greet(name string) string {
\treturn fmt.Sprintf("Hello, %s", str.TrimSpace(name))
}
// AddFloat adds a float value and records history.
func (c *Calculator) AddFloat(val float64) float64 {
\tc.Result += val
\tc.History = append(c.History, c.Result)
\treturn c.Result
}
// Sqrt computes the square root of the current result.
func (c *Calculator) Sqrt() float64 {
\tc.Result = math.Sqrt(c.Result)
\tc.History = append(c.History, c.Result)
\treturn c.Result
}
// Reset zeroes out the calculator.
func (c Calculator) Reset() Calculator {
\tc.Result = 0
\tc.History = nil
\treturn c
}
""")
assert code == expected
def test_code_excludes_package_clause(self, tmp_path: Path) -> None:
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
code = result.read_writable_code.code_strings[0].code
assert "package calc" not in code
def test_code_excludes_struct_definition(self, tmp_path: Path) -> None:
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
code = result.read_writable_code.code_strings[0].code
assert "type Calculator struct" not in code
def test_code_excludes_interface_definition(self, tmp_path: Path) -> None:
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
code = result.read_writable_code.code_strings[0].code
assert "type Formatter interface" not in code
def test_helpers_include_other_functions_and_methods(self, tmp_path: Path) -> None:
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
helper_names = sorted(h.only_function_name for h in result.helper_functions)
assert "subtract" in helper_names
assert "multiply" in helper_names
assert "Greet" in helper_names
assert "AddFloat" in helper_names
assert "Sqrt" in helper_names
assert "Reset" in helper_names
assert "Add" not in helper_names
def test_helper_sources_are_full_functions(self, tmp_path: Path) -> None:
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
by_name = {h.only_function_name: h for h in result.helper_functions}
assert by_name["subtract"].source_code == dedent("""\
func subtract(a, b int) int {
\treturn a - b
}""")
assert by_name["Greet"].source_code == dedent("""\
// Greet builds a greeting message.
func Greet(name string) string {
\treturn fmt.Sprintf("Hello, %s", str.TrimSpace(name))
}
""")
def test_method_helpers_have_qualified_names(self, tmp_path: Path) -> None:
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
by_name = {h.only_function_name: h for h in result.helper_functions}
assert by_name["AddFloat"].qualified_name == "Calculator.AddFloat"
assert by_name["AddFloat"].fully_qualified_name == "Calculator.AddFloat"
assert by_name["subtract"].qualified_name == "subtract"
def test_no_read_only_context_for_plain_function(self, tmp_path: Path) -> None:
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
assert result.read_only_context_code == ""
def test_relative_path(self, tmp_path: Path) -> None:
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
assert result.read_writable_code.code_strings[0].file_path == Path("calc.go")
def test_language_tag(self, tmp_path: Path) -> None:
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
assert result.read_writable_code.code_strings[0].language == "go"
def test_testgen_fqns_match_helpers(self, tmp_path: Path) -> None:
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
fqns = set(result.testgen_helper_fqns)
helper_fqns = {h.fully_qualified_name for h in result.helper_functions}
assert fqns == helper_fqns
# ---------------------------------------------------------------------------
# Tests: targeting a method with a pointer receiver
# ---------------------------------------------------------------------------
class TestBuildContextPointerReceiverMethod:
"""Target: (c *Calculator) AddFloat(val float64) — pointer receiver method."""
def _build(self, tmp_path: Path) -> CodeOptimizationContext:
return _build_context_for_function(
CALCULATOR_SOURCE,
"calc.go",
"AddFloat",
tmp_path,
parents=[FunctionParent(name="Calculator", type="StructDef")],
is_method=True,
)
def test_full_assembled_code_string(self, tmp_path: Path) -> None:
result = self._build(tmp_path)
code = result.read_writable_code.code_strings[0].code
expected = dedent("""\
import (
\t"fmt"
\t"math"
\tstr "strings"
)
// AddFloat adds a float value and records history.
func (c *Calculator) AddFloat(val float64) float64 {
\tc.Result += val
\tc.History = append(c.History, c.Result)
\treturn c.Result
}
// Add returns the sum of two integers.
func Add(a, b int) int {
\treturn a + b
}
func subtract(a, b int) int {
\treturn a - b
}
func multiply(a, b int) int {
\treturn a * b
}
// Greet builds a greeting message.
func Greet(name string) string {
\treturn fmt.Sprintf("Hello, %s", str.TrimSpace(name))
}
// Sqrt computes the square root of the current result.
func (c *Calculator) Sqrt() float64 {
\tc.Result = math.Sqrt(c.Result)
\tc.History = append(c.History, c.Result)
\treturn c.Result
}
// Reset zeroes out the calculator.
func (c Calculator) Reset() Calculator {
\tc.Result = 0
\tc.History = nil
\treturn c
}
""")
assert code == expected
def test_code_excludes_package_and_type_defs(self, tmp_path: Path) -> None:
result = self._build(tmp_path)
code = result.read_writable_code.code_strings[0].code
assert "package calc" not in code
assert "type Calculator struct" not in code
assert "type Formatter interface" not in code
def test_read_only_context_is_struct_definition(self, tmp_path: Path) -> None:
result = self._build(tmp_path)
assert result.read_only_context_code == dedent("""\
type Calculator struct {
\tResult float64
\tHistory []float64
}""")
def test_helpers_exclude_self_include_others(self, tmp_path: Path) -> None:
result = self._build(tmp_path)
helper_names = sorted(h.only_function_name for h in result.helper_functions)
assert "AddFloat" not in helper_names
assert "Add" in helper_names
assert "subtract" in helper_names
assert "multiply" in helper_names
assert "Greet" in helper_names
assert "Sqrt" in helper_names
assert "Reset" in helper_names
def test_target_not_duplicated_in_code_string(self, tmp_path: Path) -> None:
result = self._build(tmp_path)
code = result.read_writable_code.code_strings[0].code
assert code.count("func (c *Calculator) AddFloat") == 1
# ---------------------------------------------------------------------------
# Tests: targeting a value receiver method
# ---------------------------------------------------------------------------
class TestBuildContextValueReceiverMethod:
"""Target: (c Calculator) Reset() — value receiver method."""
def _build(self, tmp_path: Path) -> CodeOptimizationContext:
return _build_context_for_function(
CALCULATOR_SOURCE,
"calc.go",
"Reset",
tmp_path,
parents=[FunctionParent(name="Calculator", type="StructDef")],
is_method=True,
)
def test_target_in_code_string(self, tmp_path: Path) -> None:
result = self._build(tmp_path)
code = result.read_writable_code.code_strings[0].code
expected_target = dedent("""\
// Reset zeroes out the calculator.
func (c Calculator) Reset() Calculator {
\tc.Result = 0
\tc.History = nil
\treturn c
}""")
assert code.count("func (c Calculator) Reset()") == 1
assert expected_target in code
def test_helpers_include_other_methods_on_same_struct(self, tmp_path: Path) -> None:
result = self._build(tmp_path)
helper_names = sorted(h.only_function_name for h in result.helper_functions)
assert "Reset" not in helper_names
assert "AddFloat" in helper_names
assert "Sqrt" in helper_names
assert "Add" in helper_names
def test_helper_code_in_assembled_string(self, tmp_path: Path) -> None:
result = self._build(tmp_path)
code = result.read_writable_code.code_strings[0].code
assert "func (c *Calculator) AddFloat" in code
assert "func (c *Calculator) Sqrt()" in code
assert "func Add(a, b int) int" in code
assert "func subtract(a, b int) int" in code
def test_struct_in_read_only_context(self, tmp_path: Path) -> None:
result = self._build(tmp_path)
assert result.read_only_context_code == dedent("""\
type Calculator struct {
\tResult float64
\tHistory []float64
}""")
# ---------------------------------------------------------------------------
# Tests: simple source with no imports, no methods, one function
# ---------------------------------------------------------------------------
class TestBuildContextMinimalSource:
"""Target: Double(x int) — minimal file with no imports or structs."""
def test_no_imports_no_prefix(self, tmp_path: Path) -> None:
result = _build_context_for_function(SIMPLE_SOURCE, "simple.go", "Double", tmp_path)
code = result.read_writable_code.code_strings[0].code
assert code == dedent("""\
func Double(x int) int {
\treturn x * 2
}""")
def test_no_helpers(self, tmp_path: Path) -> None:
result = _build_context_for_function(SIMPLE_SOURCE, "simple.go", "Double", tmp_path)
assert result.helper_functions == []
assert result.testgen_helper_fqns == []
def test_empty_read_only_context(self, tmp_path: Path) -> None:
result = _build_context_for_function(SIMPLE_SOURCE, "simple.go", "Double", tmp_path)
assert result.read_only_context_code == ""
def test_preexisting_objects_empty(self, tmp_path: Path) -> None:
result = _build_context_for_function(SIMPLE_SOURCE, "simple.go", "Double", tmp_path)
assert result.preexisting_objects == set()
# ---------------------------------------------------------------------------
# Tests: init function and globals in context
# ---------------------------------------------------------------------------
class TestBuildContextWithInit:
"""Target: Process() — source has init(), global vars, consts, struct."""
def test_init_in_read_only_context(self, tmp_path: Path) -> None:
result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path)
assert "func init()" in result.read_only_context_code
def test_referenced_globals_in_read_only_context(self, tmp_path: Path) -> None:
result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path)
assert "globalCache" in result.read_only_context_code
assert "mu" in result.read_only_context_code
def test_referenced_const_in_read_only_context(self, tmp_path: Path) -> None:
result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path)
assert "MaxRetries" in result.read_only_context_code
def test_referenced_struct_in_read_only_context(self, tmp_path: Path) -> None:
result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path)
assert "type Config struct" in result.read_only_context_code
def test_init_not_in_helpers(self, tmp_path: Path) -> None:
result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path)
helper_names = [h.only_function_name for h in result.helper_functions]
assert "init" not in helper_names
def test_init_not_in_read_writable_code(self, tmp_path: Path) -> None:
result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path)
code = result.read_writable_code.code_strings[0].code
assert "func init()" not in code
def test_full_read_only_context_string(self, tmp_path: Path) -> None:
result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path)
expected = dedent("""\
var (
\tglobalCache map[string]int
\tmu sync.Mutex
)
const MaxRetries = 5
type Config struct {
\tName string
\tMax int
}
func init() {
\tglobalCache = make(map[string]int)
\tglobalCache["default"] = 0
\tdefaultCfg := Config{Name: "prod", Max: MaxRetries}
\t_ = defaultCfg
\tmu.Lock()
\tmu.Unlock()
}""")
assert result.read_only_context_code == expected
class TestBuildContextNoInit:
"""Source without init — verify no init context is added."""
def test_no_init_no_extra_read_only(self, tmp_path: Path) -> None:
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
assert "func init()" not in result.read_only_context_code
def test_no_init_read_only_empty_for_function(self, tmp_path: Path) -> None:
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
assert result.read_only_context_code == ""
# ---------------------------------------------------------------------------
# Tests: subdirectory / relative path handling
# ---------------------------------------------------------------------------
class TestBuildContextSubdirectory:
"""Source file in a pkg/ subdirectory."""
def test_relative_path_includes_subdir(self, tmp_path: Path) -> None:
root = tmp_path.resolve()
pkg = root / "pkg"
pkg.mkdir()
source_file = (pkg / "calc.go").resolve()
source_file.write_text(SIMPLE_SOURCE, encoding="utf-8")
func = FunctionToOptimize(function_name="Double", file_path=source_file, language="go")
ctx = extract_code_context(func, root)
result = _build_optimization_context(ctx, source_file, "go", root)
assert result.read_writable_code.code_strings[0].file_path == Path("pkg/calc.go")
# ---------------------------------------------------------------------------
# Tests: hashing
# ---------------------------------------------------------------------------
class TestBuildContextHashing:
def test_hash_is_sha256_of_flat(self, tmp_path: Path) -> None:
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
expected_hash = hashlib.sha256(result.read_writable_code.flat.encode("utf-8")).hexdigest()
assert result.hashing_code_context_hash == expected_hash
def test_hashing_code_equals_flat(self, tmp_path: Path) -> None:
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
assert result.hashing_code_context == result.read_writable_code.flat
def test_different_targets_different_hashes(self, tmp_path: Path) -> None:
dir_a = tmp_path / "a"
dir_a.mkdir()
dir_b = tmp_path / "b"
dir_b.mkdir()
r1 = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", dir_a)
r2 = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Greet", dir_b)
assert r1.hashing_code_context_hash != r2.hashing_code_context_hash
# ---------------------------------------------------------------------------
# Tests: testgen context
# ---------------------------------------------------------------------------
class TestBuildContextTestgen:
def test_testgen_matches_read_writable(self, tmp_path: Path) -> None:
result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path)
assert result.testgen_context.markdown == result.read_writable_code.markdown
# ---------------------------------------------------------------------------
# Tests: token limit enforcement
# ---------------------------------------------------------------------------
class TestBuildContextTokenLimits:
def test_exceeds_optim_token_limit(self, tmp_path: Path) -> None:
root = tmp_path.resolve()
source_file = (root / "big.go").resolve()
huge_code = "package big\n\nfunc Big() string {\n\treturn " + '"x" + ' * 100000 + '"x"\n}\n'
source_file.write_text(huge_code, encoding="utf-8")
func = FunctionToOptimize(function_name="Big", file_path=source_file, language="go")
ctx = extract_code_context(func, root)
with pytest.raises(ValueError, match="Read-writable code has exceeded token limit"):
_build_optimization_context(ctx, source_file, "go", root, optim_token_limit=10)
def test_exceeds_testgen_token_limit(self, tmp_path: Path) -> None:
root = tmp_path.resolve()
source_file = (root / "big.go").resolve()
huge_code = "package big\n\nfunc Big() string {\n\treturn " + '"x" + ' * 100000 + '"x"\n}\n'
source_file.write_text(huge_code, encoding="utf-8")
func = FunctionToOptimize(function_name="Big", file_path=source_file, language="go")
ctx = extract_code_context(func, root)
with pytest.raises(ValueError, match="Testgen code context has exceeded token limit"):
_build_optimization_context(
ctx, source_file, "go", root, optim_token_limit=1_000_000, testgen_token_limit=10
)
# ---------------------------------------------------------------------------
# Tests: GoSupport wiring
# ---------------------------------------------------------------------------
class TestGoSupportFunctionOptimizerClass:
def test_returns_go_function_optimizer(self) -> None:
from codeflash.languages.golang.function_optimizer import GoFunctionOptimizer
from codeflash.languages.golang.support import GoSupport
support = GoSupport()
assert support.function_optimizer_class is GoFunctionOptimizer

View file

@ -216,3 +216,125 @@ class TestGoAnalyzerExtract:
def test_extract_nonexistent(self) -> None:
analyzer = GoAnalyzer()
assert analyzer.extract_function_source(GO_SOURCE, "DoesNotExist") is None
GLOBALS_SOURCE = """\
package server
import "sync"
var (
\tglobalCache map[string]int
\tmu sync.Mutex
)
var singleVar = 42
const MaxRetries = 5
const (
\tDefaultName = "prod"
\tTimeout = 30
)
type Config struct {
\tName string
\tMax int
}
func init() {
\tglobalCache = make(map[string]int)
\tglobalCache["default"] = 0
\tdefaultCfg := Config{Name: DefaultName, Max: MaxRetries}
\t_ = defaultCfg
\tmu.Lock()
\tmu.Unlock()
}
func Process() int {
\treturn singleVar + MaxRetries
}
"""
class TestGoAnalyzerGlobalDeclarations:
def test_find_var_group(self) -> None:
analyzer = GoAnalyzer()
decls = analyzer.find_global_declarations(GLOBALS_SOURCE)
var_decls = [d for d in decls if d.kind == "var"]
all_names = [name for d in var_decls for name in d.names]
assert "globalCache" in all_names
assert "mu" in all_names
assert "singleVar" in all_names
def test_find_const_group(self) -> None:
analyzer = GoAnalyzer()
decls = analyzer.find_global_declarations(GLOBALS_SOURCE)
const_decls = [d for d in decls if d.kind == "const"]
all_names = [name for d in const_decls for name in d.names]
assert "MaxRetries" in all_names
assert "DefaultName" in all_names
assert "Timeout" in all_names
def test_grouped_var_names_together(self) -> None:
analyzer = GoAnalyzer()
decls = analyzer.find_global_declarations(GLOBALS_SOURCE)
var_group = next(d for d in decls if "globalCache" in d.names)
assert var_group.names == ("globalCache", "mu")
def test_single_var(self) -> None:
analyzer = GoAnalyzer()
decls = analyzer.find_global_declarations(GLOBALS_SOURCE)
single = next(d for d in decls if "singleVar" in d.names)
assert single.kind == "var"
assert single.source_code == "var singleVar = 42"
def test_const_group_source_code(self) -> None:
analyzer = GoAnalyzer()
decls = analyzer.find_global_declarations(GLOBALS_SOURCE)
group = next(d for d in decls if "DefaultName" in d.names)
assert "DefaultName" in group.source_code
assert "Timeout" in group.source_code
def test_no_globals_in_clean_source(self) -> None:
analyzer = GoAnalyzer()
decls = analyzer.find_global_declarations("package main\n\nfunc main() {}\n")
assert decls == []
class TestGoAnalyzerCollectBodyIdentifiers:
def test_init_body_identifiers(self) -> None:
analyzer = GoAnalyzer()
ids = analyzer.collect_body_identifiers(GLOBALS_SOURCE, "init")
assert "globalCache" in ids
assert "Config" in ids
assert "DefaultName" in ids
assert "MaxRetries" in ids
assert "mu" in ids
def test_process_body_identifiers(self) -> None:
analyzer = GoAnalyzer()
ids = analyzer.collect_body_identifiers(GLOBALS_SOURCE, "Process")
assert "singleVar" in ids
assert "MaxRetries" in ids
def test_nonexistent_function_returns_empty(self) -> None:
analyzer = GoAnalyzer()
ids = analyzer.collect_body_identifiers(GLOBALS_SOURCE, "DoesNotExist")
assert ids == set()
def test_method_body_identifiers(self) -> None:
source = """\
package calc
type Calc struct{ val int }
var offset = 10
func (c *Calc) Compute() int {
\treturn c.val + offset
}
"""
analyzer = GoAnalyzer()
ids = analyzer.collect_body_identifiers(source, "Compute", receiver_type="Calc")
assert "offset" in ids

View file

@ -74,7 +74,7 @@ class TestReplaceFunction:
assert result == expected
class TestAddGlobalDeclarations:
class TestAddGlobalDeclarationsImports:
def test_add_import_to_existing_block(self) -> None:
original = 'package calc\n\nimport (\n\t"fmt"\n)\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n'
optimized = 'package calc\n\nimport (\n\t"fmt"\n\t"math"\n)\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n'
@ -102,6 +102,464 @@ class TestAddGlobalDeclarations:
assert result == source
class TestAddGlobalDeclarationsNewVar:
def test_add_single_new_var(self) -> None:
original = (
"package calc\n\n"
"func Add(a, b int) int {\n"
"\treturn a + b\n"
"}\n"
)
optimized = (
"package calc\n\n"
"var cache = make(map[int]int)\n\n"
"func Add(a, b int) int {\n"
"\treturn a + b\n"
"}\n"
)
result = add_global_declarations(optimized, original)
expected = (
"package calc\n"
"var cache = make(map[int]int)\n\n"
"\n"
"func Add(a, b int) int {\n"
"\treturn a + b\n"
"}\n"
)
assert result == expected
def test_add_grouped_var_block(self) -> None:
original = (
"package server\n\n"
'import "fmt"\n\n'
"func Process() {\n"
"\tfmt.Println()\n"
"}\n"
)
optimized = (
"package server\n\n"
'import "fmt"\n\n'
"var (\n"
"\tcache map[string]int\n"
"\tbuffer []byte\n"
")\n\n"
"func Process() {\n"
"\tfmt.Println()\n"
"}\n"
)
result = add_global_declarations(optimized, original)
expected = (
"package server\n\n"
'import "fmt"\n'
"var (\n"
"\tcache map[string]int\n"
"\tbuffer []byte\n"
")\n\n"
"\n"
"func Process() {\n"
"\tfmt.Println()\n"
"}\n"
)
assert result == expected
def test_add_new_var_preserves_existing_var(self) -> None:
original = (
"package calc\n\n"
"var version = 1\n\n"
"func Add(a, b int) int {\n"
"\treturn a + b\n"
"}\n"
)
optimized = (
"package calc\n\n"
"var version = 1\n\n"
"var cache = make(map[int]int)\n\n"
"func Add(a, b int) int {\n"
"\treturn a + b\n"
"}\n"
)
result = add_global_declarations(optimized, original)
expected = (
"package calc\n\n"
"var version = 1\n"
"var cache = make(map[int]int)\n\n"
"\n"
"func Add(a, b int) int {\n"
"\treturn a + b\n"
"}\n"
)
assert result == expected
class TestAddGlobalDeclarationsNewConst:
def test_add_single_new_const(self) -> None:
original = (
"package calc\n\n"
"func Add(a, b int) int {\n"
"\treturn a + b\n"
"}\n"
)
optimized = (
"package calc\n\n"
"const maxSize = 1024\n\n"
"func Add(a, b int) int {\n"
"\treturn a + b\n"
"}\n"
)
result = add_global_declarations(optimized, original)
expected = (
"package calc\n"
"const maxSize = 1024\n\n"
"\n"
"func Add(a, b int) int {\n"
"\treturn a + b\n"
"}\n"
)
assert result == expected
def test_add_grouped_const_block(self) -> None:
original = (
"package calc\n\n"
"func Add(a, b int) int {\n"
"\treturn a + b\n"
"}\n"
)
optimized = (
"package calc\n\n"
"const (\n"
"\tMaxRetries = 5\n"
"\tTimeout = 30\n"
")\n\n"
"func Add(a, b int) int {\n"
"\treturn a + b\n"
"}\n"
)
result = add_global_declarations(optimized, original)
expected = (
"package calc\n"
"const (\n"
"\tMaxRetries = 5\n"
"\tTimeout = 30\n"
")\n\n"
"\n"
"func Add(a, b int) int {\n"
"\treturn a + b\n"
"}\n"
)
assert result == expected
def test_add_new_const_preserves_existing_const(self) -> None:
original = (
"package calc\n\n"
"const Pi = 3.14\n\n"
"func Area(r float64) float64 {\n"
"\treturn Pi * r * r\n"
"}\n"
)
optimized = (
"package calc\n\n"
"const Pi = 3.14\n\n"
"const TwoPi = 6.28\n\n"
"func Area(r float64) float64 {\n"
"\treturn Pi * r * r\n"
"}\n"
)
result = add_global_declarations(optimized, original)
expected = (
"package calc\n\n"
"const Pi = 3.14\n"
"const TwoPi = 6.28\n\n"
"\n"
"func Area(r float64) float64 {\n"
"\treturn Pi * r * r\n"
"}\n"
)
assert result == expected
class TestAddGlobalDeclarationsModifyVar:
def test_modify_single_var_value(self) -> None:
original = (
"package calc\n\n"
"var bufferSize = 256\n\n"
"func Process() int {\n"
"\treturn bufferSize\n"
"}\n"
)
optimized = (
"package calc\n\n"
"var bufferSize = 1024\n\n"
"func Process() int {\n"
"\treturn bufferSize\n"
"}\n"
)
result = add_global_declarations(optimized, original)
expected = (
"package calc\n\n"
"var bufferSize = 1024\n"
"\n"
"func Process() int {\n"
"\treturn bufferSize\n"
"}\n"
)
assert result == expected
def test_modify_grouped_var_block(self) -> None:
original = (
"package server\n\n"
"var (\n"
'\thost = "localhost"\n'
"\tport = 8080\n"
")\n\n"
"func Addr() string {\n"
"\treturn host\n"
"}\n"
)
optimized = (
"package server\n\n"
"var (\n"
'\thost = "0.0.0.0"\n'
"\tport = 9090\n"
")\n\n"
"func Addr() string {\n"
"\treturn host\n"
"}\n"
)
result = add_global_declarations(optimized, original)
expected = (
"package server\n\n"
"var (\n"
'\thost = "0.0.0.0"\n'
"\tport = 9090\n"
")\n"
"\n"
"func Addr() string {\n"
"\treturn host\n"
"}\n"
)
assert result == expected
def test_modify_var_type(self) -> None:
original = (
"package calc\n\n"
"var counter int\n\n"
"func Inc() {\n"
"\tcounter++\n"
"}\n"
)
optimized = (
"package calc\n\n"
"var counter int64\n\n"
"func Inc() {\n"
"\tcounter++\n"
"}\n"
)
result = add_global_declarations(optimized, original)
expected = (
"package calc\n\n"
"var counter int64\n"
"\n"
"func Inc() {\n"
"\tcounter++\n"
"}\n"
)
assert result == expected
class TestAddGlobalDeclarationsModifyConst:
def test_modify_single_const_value(self) -> None:
original = (
"package calc\n\n"
"const MaxRetries = 3\n\n"
"func Retries() int {\n"
"\treturn MaxRetries\n"
"}\n"
)
optimized = (
"package calc\n\n"
"const MaxRetries = 10\n\n"
"func Retries() int {\n"
"\treturn MaxRetries\n"
"}\n"
)
result = add_global_declarations(optimized, original)
expected = (
"package calc\n\n"
"const MaxRetries = 10\n"
"\n"
"func Retries() int {\n"
"\treturn MaxRetries\n"
"}\n"
)
assert result == expected
def test_modify_const_group(self) -> None:
original = (
"package server\n\n"
"const (\n"
"\tDefaultTimeout = 30\n"
"\tMaxConnections = 100\n"
")\n\n"
"func Config() int {\n"
"\treturn DefaultTimeout\n"
"}\n"
)
optimized = (
"package server\n\n"
"const (\n"
"\tDefaultTimeout = 60\n"
"\tMaxConnections = 500\n"
")\n\n"
"func Config() int {\n"
"\treturn DefaultTimeout\n"
"}\n"
)
result = add_global_declarations(optimized, original)
expected = (
"package server\n\n"
"const (\n"
"\tDefaultTimeout = 60\n"
"\tMaxConnections = 500\n"
")\n"
"\n"
"func Config() int {\n"
"\treturn DefaultTimeout\n"
"}\n"
)
assert result == expected
class TestAddGlobalDeclarationsMixed:
def test_new_import_and_new_var(self) -> None:
original = (
"package calc\n\n"
"func Add(a, b int) int {\n"
"\treturn a + b\n"
"}\n"
)
optimized = (
"package calc\n\n"
'import "sync"\n\n'
"var mu sync.Mutex\n\n"
"func Add(a, b int) int {\n"
"\treturn a + b\n"
"}\n"
)
result = add_global_declarations(optimized, original)
expected = (
"package calc\n"
"import (\n"
'\t"sync"\n'
")\n"
"var mu sync.Mutex\n\n"
"\n\n"
"func Add(a, b int) int {\n"
"\treturn a + b\n"
"}\n"
)
assert result == expected
def test_new_and_modified_globals_together(self) -> None:
original = (
"package server\n\n"
"var bufferSize = 256\n\n"
"func Process() int {\n"
"\treturn bufferSize\n"
"}\n"
)
optimized = (
"package server\n\n"
"var bufferSize = 1024\n\n"
"var cache = make(map[string]int)\n\n"
"func Process() int {\n"
"\treturn bufferSize\n"
"}\n"
)
result = add_global_declarations(optimized, original)
expected = (
"package server\n\n"
"var bufferSize = 1024\n"
"var cache = make(map[string]int)\n\n"
"\n"
"func Process() int {\n"
"\treturn bufferSize\n"
"}\n"
)
assert result == expected
def test_no_globals_in_optimized_returns_unchanged(self) -> None:
original = (
"package calc\n\n"
"var version = 1\n\n"
"func Add(a, b int) int {\n"
"\treturn a + b\n"
"}\n"
)
optimized = (
"package calc\n\n"
"func Add(a, b int) int {\n"
"\treturn a + b\n"
"}\n"
)
result = add_global_declarations(optimized, original)
assert result == original
def test_identical_globals_returns_unchanged(self) -> None:
source = (
"package calc\n\n"
"var version = 1\n\n"
"const MaxSize = 100\n\n"
"func Add(a, b int) int {\n"
"\treturn a + b\n"
"}\n"
)
result = add_global_declarations(source, source)
assert result == source
def test_full_round_trip_new_import_var_const(self) -> None:
original = (
"package server\n\n"
"import (\n"
'\t"fmt"\n'
")\n\n"
"const Version = 1\n\n"
"func Handle() {\n"
"\tfmt.Println()\n"
"}\n"
)
optimized = (
"package server\n\n"
"import (\n"
'\t"fmt"\n'
'\t"sync"\n'
")\n\n"
"const Version = 1\n\n"
"var mu sync.Mutex\n\n"
"const MaxConns = 100\n\n"
"func Handle() {\n"
"\tmu.Lock()\n"
"\tdefer mu.Unlock()\n"
"\tfmt.Println()\n"
"}\n"
)
result = add_global_declarations(optimized, original)
expected = (
"package server\n\n"
"import (\n"
'\t"fmt"\n'
'\t"sync"\n'
")\n\n"
"const Version = 1\n"
"var mu sync.Mutex\n"
"const MaxConns = 100\n\n"
"\n"
"func Handle() {\n"
"\tfmt.Println()\n"
"}\n"
)
assert result == expected
class TestRemoveTestFunctions:
def test_remove_single_function(self) -> None:
test_source = (