extract class skeleton for optimization context

This commit is contained in:
misrasaurabh1 2026-01-28 23:28:59 -08:00
parent b30a52b8eb
commit 325534dbc2
7 changed files with 1265 additions and 14 deletions

View file

@ -0,0 +1,61 @@
/**
* Fibonacci Calculator Class - CommonJS module
* Intentionally inefficient for optimization testing.
*/
class FibonacciCalculator {
constructor() {
// No initialization needed
}
/**
* Calculate the nth Fibonacci number using naive recursion.
* This is intentionally slow to demonstrate optimization potential.
* @param {number} n - The index of the Fibonacci number to calculate
* @returns {number} The nth Fibonacci number
*/
fibonacci(n) {
if (n <= 1) {
return n;
}
return this.fibonacci(n - 1) + this.fibonacci(n - 2);
}
/**
* Check if a number is a Fibonacci number.
* @param {number} num - The number to check
* @returns {boolean} True if num is a Fibonacci number
*/
isFibonacci(num) {
// A number is Fibonacci if one of (5*n*n + 4) or (5*n*n - 4) is a perfect square
const check1 = 5 * num * num + 4;
const check2 = 5 * num * num - 4;
return this.isPerfectSquare(check1) || this.isPerfectSquare(check2);
}
/**
* Check if a number is a perfect square.
* @param {number} n - The number to check
* @returns {boolean} True if n is a perfect square
*/
isPerfectSquare(n) {
const sqrt = Math.sqrt(n);
return sqrt === Math.floor(sqrt);
}
/**
* Generate an array of Fibonacci numbers up to n.
* @param {number} n - The number of Fibonacci numbers to generate
* @returns {number[]} Array of Fibonacci numbers
*/
fibonacciSequence(n) {
const result = [];
for (let i = 0; i < n; i++) {
result.push(this.fibonacci(i));
}
return result;
}
}
// CommonJS exports
module.exports = { FibonacciCalculator };

View file

@ -14,7 +14,7 @@
}
},
"../../../packages/codeflash": {
"version": "0.1.0",
"version": "0.2.0",
"dev": true,
"hasInstallScript": true,
"license": "MIT",

View file

@ -0,0 +1,105 @@
const { FibonacciCalculator } = require('../fibonacci_class');
describe('FibonacciCalculator', () => {
let calc;
beforeEach(() => {
calc = new FibonacciCalculator();
});
describe('fibonacci', () => {
test('returns 0 for n=0', () => {
expect(calc.fibonacci(0)).toBe(0);
});
test('returns 1 for n=1', () => {
expect(calc.fibonacci(1)).toBe(1);
});
test('returns 1 for n=2', () => {
expect(calc.fibonacci(2)).toBe(1);
});
test('returns 5 for n=5', () => {
expect(calc.fibonacci(5)).toBe(5);
});
test('returns 55 for n=10', () => {
expect(calc.fibonacci(10)).toBe(55);
});
test('returns 233 for n=13', () => {
expect(calc.fibonacci(13)).toBe(233);
});
});
describe('isFibonacci', () => {
test('returns true for 0', () => {
expect(calc.isFibonacci(0)).toBe(true);
});
test('returns true for 1', () => {
expect(calc.isFibonacci(1)).toBe(true);
});
test('returns true for 8', () => {
expect(calc.isFibonacci(8)).toBe(true);
});
test('returns true for 13', () => {
expect(calc.isFibonacci(13)).toBe(true);
});
test('returns false for 4', () => {
expect(calc.isFibonacci(4)).toBe(false);
});
test('returns false for 6', () => {
expect(calc.isFibonacci(6)).toBe(false);
});
});
describe('isPerfectSquare', () => {
test('returns true for 0', () => {
expect(calc.isPerfectSquare(0)).toBe(true);
});
test('returns true for 1', () => {
expect(calc.isPerfectSquare(1)).toBe(true);
});
test('returns true for 4', () => {
expect(calc.isPerfectSquare(4)).toBe(true);
});
test('returns true for 16', () => {
expect(calc.isPerfectSquare(16)).toBe(true);
});
test('returns false for 2', () => {
expect(calc.isPerfectSquare(2)).toBe(false);
});
test('returns false for 3', () => {
expect(calc.isPerfectSquare(3)).toBe(false);
});
});
describe('fibonacciSequence', () => {
test('returns empty array for n=0', () => {
expect(calc.fibonacciSequence(0)).toEqual([]);
});
test('returns [0] for n=1', () => {
expect(calc.fibonacciSequence(1)).toEqual([0]);
});
test('returns first 5 Fibonacci numbers', () => {
expect(calc.fibonacciSequence(5)).toEqual([0, 1, 1, 2, 3]);
});
test('returns first 10 Fibonacci numbers', () => {
expect(calc.fibonacciSequence(10)).toEqual([0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);
});
});
});

View file

@ -35,6 +35,7 @@ class ExpectCallMatch:
func_args: str
assertion_chain: str
has_trailing_semicolon: bool
object_prefix: str = "" # Object prefix like "calc." or "this." or ""
@dataclass
@ -45,7 +46,8 @@ class StandaloneCallMatch:
end_pos: int
leading_whitespace: str
func_args: str
prefix: str # Everything between whitespace and func name (e.g., "await ", "")
prefix: str # "await " or ""
object_prefix: str # Object prefix like "calc." or "this." or ""
has_trailing_semicolon: bool
@ -61,6 +63,7 @@ class StandaloneCallTransformer:
- func(args) -> codeflash.capturePerf('name', 'id', func, args)
- const result = func(args) -> const result = codeflash.capturePerf(...)
- arr.map(() => func(args)) -> arr.map(() => codeflash.capturePerf(..., func, args))
- calc.fibonacci(n) -> codeflash.capturePerf('...', 'id', calc.fibonacci.bind(calc), n)
"""
@ -69,9 +72,12 @@ class StandaloneCallTransformer:
self.qualified_name = qualified_name
self.capture_func = capture_func
self.invocation_counter = 0
# Pattern to match func_name( with optional leading await
# Pattern to match func_name( with optional leading await and optional object prefix
# Captures: (whitespace)(await )?(object.)*func_name(
# We'll filter out expect() and codeflash. cases in the transform loop
self._call_pattern = re.compile(rf"(\s*)(await\s+)?{re.escape(func_name)}\s*\(")
self._call_pattern = re.compile(
rf"(\s*)(await\s+)?((?:\w+\.)*){re.escape(func_name)}\s*\("
)
def transform(self, code: str) -> str:
"""Transform all standalone calls in the code."""
@ -121,6 +127,42 @@ class StandaloneCallTransformer:
if f"codeflash.{self.capture_func}(" in lookback[-60:]:
return True
# Skip if this is a function/method definition, not a call
# Patterns to skip:
# - ClassName.prototype.funcName = function(
# - funcName = function(
# - funcName: function(
# - function funcName(
# - funcName() { (method definition in class)
near_context = lookback[-80:] if len(lookback) >= 80 else lookback
# Skip prototype assignment: ClassName.prototype.funcName = function(
if re.search(r"\.prototype\.\w+\s*=\s*function\s*$", near_context):
return True
# Skip function assignment: funcName = function(
if re.search(rf"{re.escape(self.func_name)}\s*=\s*function\s*$", near_context):
return True
# Skip function declaration: function funcName(
if re.search(rf"function\s+{re.escape(self.func_name)}\s*$", near_context):
return True
# Skip method definition in class body: funcName(params) { or async funcName(params) {
# Check by looking at what comes after the closing paren
# The match ends at the opening paren, so find the closing paren and check what follows
close_paren_pos = self._find_matching_paren(code, match.end() - 1)
if close_paren_pos != -1:
# Check if followed by { (method definition) after optional whitespace
after_close = code[close_paren_pos : close_paren_pos + 20].lstrip()
if after_close.startswith("{"):
# This is a method definition like "fibonacci(n) {"
# But we still want to capture certain patterns like arrow functions
# Check if there's no => before the {
between = code[close_paren_pos : close_paren_pos + 20].strip()
if not between.startswith("=>"):
return True
# Skip if inside expect() - look for 'expect(' with unmatched parens
# Find the last 'expect(' and check if it's still open
expect_search_start = max(0, start - 100)
@ -138,10 +180,28 @@ class StandaloneCallTransformer:
return False
def _find_matching_paren(self, code: str, open_paren_pos: int) -> int:
"""Find the position of the closing paren for the given opening paren."""
if open_paren_pos >= len(code) or code[open_paren_pos] != "(":
return -1
depth = 1
pos = open_paren_pos + 1
while pos < len(code) and depth > 0:
if code[pos] == "(":
depth += 1
elif code[pos] == ")":
depth -= 1
pos += 1
return pos if depth == 0 else -1
def _parse_standalone_call(self, code: str, match: re.Match) -> StandaloneCallMatch | None:
"""Parse a complete standalone func(...) call."""
leading_ws = match.group(1)
prefix = match.group(2) or "" # "await " or ""
object_prefix = match.group(3) or "" # Object prefix like "calc." or ""
# Find the opening paren position
match_text = match.group(0)
@ -169,6 +229,7 @@ class StandaloneCallTransformer:
leading_whitespace=leading_ws,
func_args=func_args,
prefix=prefix,
object_prefix=object_prefix,
has_trailing_semicolon=has_trailing_semicolon,
)
@ -212,6 +273,23 @@ class StandaloneCallTransformer:
args_str = match.func_args.strip()
semicolon = ";" if match.has_trailing_semicolon else ""
# Handle method calls on objects (e.g., calc.fibonacci, this.method)
if match.object_prefix:
# Remove trailing dot from object prefix for the bind call
obj = match.object_prefix.rstrip(".")
full_method = f"{obj}.{self.func_name}"
if args_str:
return (
f"{match.leading_whitespace}{match.prefix}codeflash.{self.capture_func}('{self.qualified_name}', "
f"'{line_id}', {full_method}.bind({obj}), {args_str}){semicolon}"
)
return (
f"{match.leading_whitespace}{match.prefix}codeflash.{self.capture_func}('{self.qualified_name}', "
f"'{line_id}', {full_method}.bind({obj})){semicolon}"
)
# Handle standalone function calls
if args_str:
return (
f"{match.leading_whitespace}{match.prefix}codeflash.{self.capture_func}('{self.qualified_name}', "
@ -268,8 +346,11 @@ class ExpectCallTransformer:
self.capture_func = capture_func
self.remove_assertions = remove_assertions
self.invocation_counter = 0
# Pattern to match start of expect(func_name(
self._expect_pattern = re.compile(rf"(\s*)expect\s*\(\s*{re.escape(func_name)}\s*\(")
# Pattern to match start of expect((object.)*func_name(
# Captures: (whitespace), (object prefix like calc. or this.)
self._expect_pattern = re.compile(
rf"(\s*)expect\s*\(\s*((?:\w+\.)*){re.escape(func_name)}\s*\("
)
def transform(self, code: str) -> str:
"""Transform all expect calls in the code."""
@ -307,6 +388,7 @@ class ExpectCallTransformer:
Returns None if the pattern doesn't match expected structure.
"""
leading_ws = match.group(1)
object_prefix = match.group(2) or "" # Object prefix like "calc." or ""
# Find the arguments of the function call (handling nested parens)
args_start = match.end()
@ -341,6 +423,7 @@ class ExpectCallTransformer:
func_args=func_args,
assertion_chain=assertion_chain,
has_trailing_semicolon=has_trailing_semicolon,
object_prefix=object_prefix,
)
def _find_balanced_parens(self, code: str, open_paren_pos: int) -> tuple[str | None, int]:
@ -473,16 +556,24 @@ class ExpectCallTransformer:
line_id = str(self.invocation_counter)
args_str = match.func_args.strip()
# Determine the function reference to use
if match.object_prefix:
# Method call on object: calc.fibonacci -> calc.fibonacci.bind(calc)
obj = match.object_prefix.rstrip(".")
func_ref = f"{obj}.{self.func_name}.bind({obj})"
else:
func_ref = self.func_name
if self.remove_assertions:
# For generated/regression tests: remove expect wrapper and assertion
if args_str:
return (
f"{match.leading_whitespace}codeflash.{self.capture_func}('{self.qualified_name}', "
f"'{line_id}', {self.func_name}, {args_str});"
f"'{line_id}', {func_ref}, {args_str});"
)
return (
f"{match.leading_whitespace}codeflash.{self.capture_func}('{self.qualified_name}', "
f"'{line_id}', {self.func_name});"
f"'{line_id}', {func_ref});"
)
# For existing tests: keep the expect wrapper
@ -490,11 +581,11 @@ class ExpectCallTransformer:
if args_str:
return (
f"{match.leading_whitespace}expect(codeflash.{self.capture_func}('{self.qualified_name}', "
f"'{line_id}', {self.func_name}, {args_str})){match.assertion_chain}{semicolon}"
f"'{line_id}', {func_ref}, {args_str})){match.assertion_chain}{semicolon}"
)
return (
f"{match.leading_whitespace}expect(codeflash.{self.capture_func}('{self.qualified_name}', "
f"'{line_id}', {self.func_name})){match.assertion_chain}{semicolon}"
f"'{line_id}', {func_ref})){match.assertion_chain}{semicolon}"
)
@ -582,13 +673,18 @@ def inject_profiling_into_existing_js_test(
def _is_function_used_in_test(code: str, func_name: str) -> bool:
"""Check if a function is imported or used in the test code."""
# Check for CommonJS require
"""Check if a function is imported or used in the test code.
This function handles both standalone functions and class methods.
For class methods, it checks if the method is called on any object
(e.g., calc.fibonacci, this.fibonacci).
"""
# Check for CommonJS require with named export
require_pattern = rf"(?:const|let|var)\s+\{{\s*[^}}]*\b{re.escape(func_name)}\b[^}}]*\}}\s*=\s*require\s*\("
if re.search(require_pattern, code):
return True
# Check for ES6 import
# Check for ES6 import with named export
import_pattern = rf"import\s+\{{\s*[^}}]*\b{re.escape(func_name)}\b[^}}]*\}}\s+from"
if re.search(import_pattern, code):
return True
@ -602,6 +698,12 @@ def _is_function_used_in_test(code: str, func_name: str) -> bool:
if re.search(default_import, code):
return True
# Check for method calls: obj.funcName( or this.funcName(
# This handles class methods called on instances
method_call_pattern = rf"\w+\.{re.escape(func_name)}\s*\("
if re.search(method_call_pattern, code):
return True
return False

View file

@ -321,6 +321,29 @@ class JavaScriptSupport:
else:
target_code = ""
# For class methods, wrap the method in its class definition
# This is necessary because method definition syntax is only valid inside a class body
if function.is_method and function.parents:
class_name = None
for parent in function.parents:
if parent.type == "ClassDef":
class_name = parent.name
break
if class_name:
# Find the class definition in the source to get proper indentation and any class JSDoc
class_info = self._find_class_definition(source, class_name, analyzer)
if class_info:
class_jsdoc, class_indent = class_info
# Wrap the method in a minimal class definition
if class_jsdoc:
target_code = f"{class_jsdoc}\n{class_indent}class {class_name} {{\n{target_code}{class_indent}}}\n"
else:
target_code = f"{class_indent}class {class_name} {{\n{target_code}{class_indent}}}\n"
else:
# Fallback: wrap with no indentation
target_code = f"class {class_name} {{\n{target_code}}}\n"
imports = analyzer.find_imports(source)
# Find helper functions called by target
@ -337,6 +360,16 @@ class JavaScriptSupport:
target_code=target_code, helpers=helpers, source=source, analyzer=analyzer, imports=imports
)
# Validate that the extracted code is syntactically valid
# If not, raise an error to fail the optimization early
if target_code and not self.validate_syntax(target_code):
error_msg = (
f"Extracted code for {function.name} is not syntactically valid JavaScript. "
f"Cannot proceed with optimization."
)
logger.error(error_msg)
raise ValueError(error_msg)
return CodeContext(
target_code=target_code,
target_file=function.file_path,
@ -346,6 +379,61 @@ class JavaScriptSupport:
language=Language.JAVASCRIPT,
)
def _find_class_definition(
self, source: str, class_name: str, analyzer: TreeSitterAnalyzer
) -> tuple[str, str] | None:
"""Find a class definition and extract its JSDoc comment and indentation.
Args:
source: The source code to search.
class_name: The name of the class to find.
analyzer: TreeSitterAnalyzer for parsing.
Returns:
Tuple of (jsdoc_comment, indentation) or None if not found.
"""
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
def find_class_node(node):
"""Recursively find a class declaration with the given name."""
if node.type in ("class_declaration", "class"):
name_node = node.child_by_field_name("name")
if name_node:
node_name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
if node_name == class_name:
return node
for child in node.children:
result = find_class_node(child)
if result:
return result
return None
class_node = find_class_node(tree.root_node)
if not class_node:
return None
# Get indentation from the class line
lines = source.splitlines(keepends=True)
class_line_idx = class_node.start_point[0]
if class_line_idx < len(lines):
class_line = lines[class_line_idx]
indent = len(class_line) - len(class_line.lstrip())
indentation = " " * indent
else:
indentation = ""
# Look for preceding JSDoc comment
jsdoc = ""
prev_sibling = class_node.prev_named_sibling
if prev_sibling and prev_sibling.type == "comment":
comment_text = source_bytes[prev_sibling.start_byte : prev_sibling.end_byte].decode("utf8")
if comment_text.strip().startswith("/**"):
jsdoc = comment_text
return (jsdoc, indentation)
def _find_helper_functions(
self, function: FunctionInfo, source: str, analyzer: TreeSitterAnalyzer, imports: list[Any], module_root: Path
) -> list[HelperFunction]:

View file

@ -357,4 +357,356 @@ test('decrypt works', () => {{
assert fixed_code == test_code
# Clean up
source_path.unlink()
source_path.unlink()
class TestClassMethodInstrumentation:
"""Tests for class method instrumentation."""
def test_instrument_method_call_on_instance(self):
"""Test that method calls on instances are correctly instrumented."""
from codeflash.languages.javascript.instrument import transform_standalone_calls
code = """
const calc = new Calculator();
const result = calc.fibonacci(10);
console.log(result);
"""
transformed, counter = transform_standalone_calls(
code=code,
func_name="fibonacci",
qualified_name="Calculator.fibonacci",
capture_func="capture",
)
# Should transform calc.fibonacci(10) to codeflash.capture(..., calc.fibonacci.bind(calc), 10)
assert "codeflash.capture('Calculator.fibonacci'" in transformed
assert "calc.fibonacci.bind(calc)" in transformed
assert counter == 1
def test_instrument_expect_with_method_call(self):
"""Test that expect() with method calls are correctly instrumented."""
from codeflash.languages.javascript.instrument import transform_expect_calls
code = """
test('fibonacci works', () => {
const calc = new FibonacciCalculator();
expect(calc.fibonacci(10)).toBe(55);
});
"""
transformed, counter = transform_expect_calls(
code=code,
func_name="fibonacci",
qualified_name="FibonacciCalculator.fibonacci",
capture_func="capture",
)
# Should transform expect(calc.fibonacci(10)) to
# expect(codeflash.capture(..., calc.fibonacci.bind(calc), 10))
assert "codeflash.capture('FibonacciCalculator.fibonacci'" in transformed
assert "calc.fibonacci.bind(calc)" in transformed
assert ".toBe(55)" in transformed
assert counter == 1
def test_instrument_expect_with_method_removes_assertion(self):
"""Test that expect() with method calls are correctly instrumented with assertion removal."""
from codeflash.languages.javascript.instrument import transform_expect_calls
code = """
test('fibonacci works', () => {
const calc = new FibonacciCalculator();
expect(calc.fibonacci(10)).toBe(55);
});
"""
transformed, counter = transform_expect_calls(
code=code,
func_name="fibonacci",
qualified_name="FibonacciCalculator.fibonacci",
capture_func="capture",
remove_assertions=True,
)
# Should remove expect wrapper and assertion
assert "codeflash.capture('FibonacciCalculator.fibonacci'" in transformed
assert "calc.fibonacci.bind(calc)" in transformed
assert ".toBe(55)" not in transformed # Assertion removed
assert "expect(" not in transformed # expect wrapper removed
assert counter == 1
def test_does_not_instrument_function_definition(self):
"""Test that function definitions are NOT transformed."""
from codeflash.languages.javascript.instrument import transform_standalone_calls
code = """
class FibonacciCalculator {
fibonacci(n) {
if (n <= 1) return n;
return this.fibonacci(n - 1) + this.fibonacci(n - 2);
}
}
"""
transformed, counter = transform_standalone_calls(
code=code,
func_name="fibonacci",
qualified_name="FibonacciCalculator.fibonacci",
capture_func="capture",
)
# The method definition should NOT be transformed
# Only the recursive calls this.fibonacci(...) should potentially be transformed
assert "fibonacci(n) {" in transformed # Method definition unchanged
assert counter >= 0 # May or may not transform the recursive calls
def test_does_not_instrument_prototype_assignment(self):
"""Test that prototype assignments are NOT transformed."""
from codeflash.languages.javascript.instrument import transform_standalone_calls
code = """
FibonacciCalculator.prototype.fibonacci = function(n) {
if (n <= 1) return n;
return this.fibonacci(n - 1) + this.fibonacci(n - 2);
};
"""
transformed, counter = transform_standalone_calls(
code=code,
func_name="fibonacci",
qualified_name="FibonacciCalculator.fibonacci",
capture_func="capture",
)
# The prototype assignment should NOT be transformed
# It should still have the original pattern
assert "FibonacciCalculator.prototype.fibonacci = function(n)" in transformed
def test_instrument_multiple_method_calls(self):
"""Test that multiple method calls are correctly instrumented."""
from codeflash.languages.javascript.instrument import transform_standalone_calls
code = """
const calc = new Calculator();
const a = calc.fibonacci(5);
const b = calc.fibonacci(10);
const sum = a + b;
"""
transformed, counter = transform_standalone_calls(
code=code,
func_name="fibonacci",
qualified_name="Calculator.fibonacci",
capture_func="capture",
)
# Should transform both calls
assert transformed.count("codeflash.capture") == 2
assert counter == 2
def test_instrument_this_method_call(self):
"""Test that this.method() calls are correctly instrumented."""
from codeflash.languages.javascript.instrument import transform_standalone_calls
code = """
class Wrapper {
callFibonacci(n) {
return this.fibonacci(n);
}
}
"""
transformed, counter = transform_standalone_calls(
code=code,
func_name="fibonacci",
qualified_name="Wrapper.fibonacci",
capture_func="capture",
)
# Should transform this.fibonacci(n)
assert "codeflash.capture('Wrapper.fibonacci'" in transformed
assert "this.fibonacci.bind(this)" in transformed
assert counter == 1
def test_full_instrumentation_produces_valid_syntax(self):
"""Test that full instrumentation produces syntactically valid JavaScript."""
from codeflash.languages.javascript.instrument import _instrument_js_test_code
from codeflash.languages import get_language_support
from codeflash.languages.base import Language
js_support = get_language_support(Language.JAVASCRIPT)
test_code = """
const { FibonacciCalculator } = require('../fibonacci_class');
describe('FibonacciCalculator', () => {
let calc;
beforeEach(() => {
calc = new FibonacciCalculator();
});
test('fibonacci returns correct values', () => {
expect(calc.fibonacci(0)).toBe(0);
expect(calc.fibonacci(1)).toBe(1);
expect(calc.fibonacci(10)).toBe(55);
});
test('standalone call', () => {
const result = calc.fibonacci(5);
expect(result).toBe(5);
});
});
"""
instrumented = _instrument_js_test_code(
code=test_code,
func_name="fibonacci",
test_file_path="test.js",
mode="behavior",
qualified_name="FibonacciCalculator.fibonacci",
)
# Check that codeflash import was added
assert "codeflash" in instrumented
# Check that method calls were instrumented
assert "codeflash.capture" in instrumented
# Check that the instrumented code is valid JavaScript
assert js_support.validate_syntax(instrumented) is True, f"Invalid syntax:\n{instrumented}"
def test_instrumentation_preserves_test_structure(self):
"""Test that instrumentation preserves the test structure."""
from codeflash.languages.javascript.instrument import _instrument_js_test_code
test_code = """
const { Calculator } = require('../calculator');
describe('Calculator', () => {
test('add works', () => {
const calc = new Calculator();
expect(calc.add(1, 2)).toBe(3);
});
});
"""
instrumented = _instrument_js_test_code(
code=test_code,
func_name="add",
test_file_path="test.js",
mode="behavior",
qualified_name="Calculator.add",
)
# describe and test structure should be preserved
assert "describe('Calculator'" in instrumented
assert "test('add works'" in instrumented
assert "beforeEach" in instrumented or "beforeEach" not in test_code # Only if it was there
# Method call should be instrumented
assert "codeflash.capture('Calculator.add'" in instrumented
assert "calc.add.bind(calc)" in instrumented
def test_instrumentation_with_async_methods(self):
"""Test instrumentation with async method calls."""
from codeflash.languages.javascript.instrument import transform_standalone_calls
code = """
const api = new ApiClient();
const data = await api.fetchData('http://example.com');
console.log(data);
"""
transformed, counter = transform_standalone_calls(
code=code,
func_name="fetchData",
qualified_name="ApiClient.fetchData",
capture_func="capture",
)
# Should preserve await
assert "await codeflash.capture" in transformed
assert "api.fetchData.bind(api)" in transformed
assert counter == 1
class TestInstrumentationFullStringEquality:
"""Tests with full string equality for precise verification."""
def test_standalone_method_call_exact_output(self):
"""Test exact output of standalone method call instrumentation."""
from codeflash.languages.javascript.instrument import transform_standalone_calls
code = " calc.fibonacci(10);"
transformed, counter = transform_standalone_calls(
code=code,
func_name="fibonacci",
qualified_name="Calculator.fibonacci",
capture_func="capture",
)
expected = " codeflash.capture('Calculator.fibonacci', '1', calc.fibonacci.bind(calc), 10);"
assert transformed == expected, f"Expected:\n{expected}\nGot:\n{transformed}"
assert counter == 1
def test_expect_method_call_exact_output(self):
"""Test exact output of expect() method call instrumentation."""
from codeflash.languages.javascript.instrument import transform_expect_calls
code = " expect(calc.fibonacci(10)).toBe(55);"
transformed, counter = transform_expect_calls(
code=code,
func_name="fibonacci",
qualified_name="Calculator.fibonacci",
capture_func="capture",
)
expected = " expect(codeflash.capture('Calculator.fibonacci', '1', calc.fibonacci.bind(calc), 10)).toBe(55);"
assert transformed == expected, f"Expected:\n{expected}\nGot:\n{transformed}"
assert counter == 1
def test_expect_method_call_remove_assertions_exact_output(self):
"""Test exact output when removing assertions."""
from codeflash.languages.javascript.instrument import transform_expect_calls
code = " expect(calc.fibonacci(10)).toBe(55);"
transformed, counter = transform_expect_calls(
code=code,
func_name="fibonacci",
qualified_name="Calculator.fibonacci",
capture_func="capture",
remove_assertions=True,
)
expected = " codeflash.capture('Calculator.fibonacci', '1', calc.fibonacci.bind(calc), 10);"
assert transformed == expected, f"Expected:\n{expected}\nGot:\n{transformed}"
assert counter == 1
def test_standalone_function_call_no_object_prefix(self):
"""Test that standalone function calls (no object) work correctly."""
from codeflash.languages.javascript.instrument import transform_standalone_calls
code = " fibonacci(10);"
transformed, counter = transform_standalone_calls(
code=code,
func_name="fibonacci",
qualified_name="fibonacci",
capture_func="capture",
)
expected = " codeflash.capture('fibonacci', '1', fibonacci, 10);"
assert transformed == expected, f"Expected:\n{expected}\nGot:\n{transformed}"
assert counter == 1
def test_this_method_call_exact_output(self):
"""Test exact output for this.method() call."""
from codeflash.languages.javascript.instrument import transform_standalone_calls
code = " return this.fibonacci(n - 1);"
transformed, counter = transform_standalone_calls(
code=code,
func_name="fibonacci",
qualified_name="Class.fibonacci",
capture_func="capture",
)
expected = " return codeflash.capture('Class.fibonacci', '1', this.fibonacci.bind(this), n - 1);"
assert transformed == expected, f"Expected:\n{expected}\nGot:\n{transformed}"
assert counter == 1

View file

@ -695,3 +695,546 @@ describe('Math functions', () => {
assert "Math functions" in test_names
assert "add returns sum" in test_names
assert "handles negative numbers" in test_names
class TestClassMethodExtraction:
"""Tests for class method extraction and code context.
These tests use full string equality to verify exact extraction output.
"""
def test_extract_class_method_wraps_in_class(self, js_support):
"""Test that extracting a class method wraps it in a class definition."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""class Calculator {
add(a, b) {
return a + b;
}
multiply(a, b) {
return a * b;
}
}
""")
f.flush()
file_path = Path(f.name)
# Discover the method
functions = js_support.discover_functions(file_path)
add_method = next(f for f in functions if f.name == "add")
# Extract code context
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
# Full string equality check for exact extraction output
expected_code = """class Calculator {
add(a, b) {
return a + b;
}
}
"""
assert context.target_code == expected_code, f"Expected:\n{expected_code}\nGot:\n{context.target_code}"
assert js_support.validate_syntax(context.target_code) is True
def test_extract_class_method_with_jsdoc(self, js_support):
"""Test extracting a class method with JSDoc comments."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""/**
* A simple calculator class.
*/
class Calculator {
/**
* Adds two numbers.
* @param {number} a - First number
* @param {number} b - Second number
* @returns {number} The sum
*/
add(a, b) {
return a + b;
}
}
""")
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
add_method = next(f for f in functions if f.name == "add")
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
# Full string equality check - includes class JSDoc, class definition, method JSDoc, and method
expected_code = """/**
* A simple calculator class.
*/
class Calculator {
/**
* Adds two numbers.
* @param {number} a - First number
* @param {number} b - Second number
* @returns {number} The sum
*/
add(a, b) {
return a + b;
}
}
"""
assert context.target_code == expected_code, f"Expected:\n{expected_code}\nGot:\n{context.target_code}"
assert js_support.validate_syntax(context.target_code) is True
def test_extract_class_method_syntax_valid(self, js_support):
"""Test that extracted class method code is always syntactically valid."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""class FibonacciCalculator {
fibonacci(n) {
if (n <= 1) {
return n;
}
return this.fibonacci(n - 1) + this.fibonacci(n - 2);
}
}
""")
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
fib_method = next(f for f in functions if f.name == "fibonacci")
context = js_support.extract_code_context(fib_method, file_path.parent, file_path.parent)
# Full string equality check
expected_code = """class FibonacciCalculator {
fibonacci(n) {
if (n <= 1) {
return n;
}
return this.fibonacci(n - 1) + this.fibonacci(n - 2);
}
}
"""
assert context.target_code == expected_code, f"Expected:\n{expected_code}\nGot:\n{context.target_code}"
assert js_support.validate_syntax(context.target_code) is True
def test_extract_nested_class_method(self, js_support):
"""Test extracting a method from a nested class structure."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""class Outer {
createInner() {
return class Inner {
getValue() {
return 42;
}
};
}
add(a, b) {
return a + b;
}
}
""")
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
add_method = next((f for f in functions if f.name == "add"), None)
if add_method:
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
# Full string equality check
expected_code = """class Outer {
add(a, b) {
return a + b;
}
}
"""
assert context.target_code == expected_code, f"Expected:\n{expected_code}\nGot:\n{context.target_code}"
assert js_support.validate_syntax(context.target_code) is True
def test_extract_async_class_method(self, js_support):
"""Test extracting an async class method."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""class ApiClient {
async fetchData(url) {
const response = await fetch(url);
return response.json();
}
}
""")
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
fetch_method = next(f for f in functions if f.name == "fetchData")
context = js_support.extract_code_context(fetch_method, file_path.parent, file_path.parent)
# Full string equality check
expected_code = """class ApiClient {
async fetchData(url) {
const response = await fetch(url);
return response.json();
}
}
"""
assert context.target_code == expected_code, f"Expected:\n{expected_code}\nGot:\n{context.target_code}"
assert js_support.validate_syntax(context.target_code) is True
def test_extract_static_class_method(self, js_support):
"""Test extracting a static class method."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""class MathUtils {
static add(a, b) {
return a + b;
}
static multiply(a, b) {
return a * b;
}
}
""")
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
add_method = next((f for f in functions if f.name == "add"), None)
if add_method:
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
# Full string equality check
expected_code = """class MathUtils {
static add(a, b) {
return a + b;
}
}
"""
assert context.target_code == expected_code, f"Expected:\n{expected_code}\nGot:\n{context.target_code}"
assert js_support.validate_syntax(context.target_code) is True
def test_extract_class_method_without_class_jsdoc(self, js_support):
"""Test extracting a method from a class without JSDoc."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""class SimpleClass {
simpleMethod() {
return "hello";
}
}
""")
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
method = next(f for f in functions if f.name == "simpleMethod")
context = js_support.extract_code_context(method, file_path.parent, file_path.parent)
# Full string equality check
expected_code = """class SimpleClass {
simpleMethod() {
return "hello";
}
}
"""
assert context.target_code == expected_code, f"Expected:\n{expected_code}\nGot:\n{context.target_code}"
assert js_support.validate_syntax(context.target_code) is True
class TestClassMethodReplacement:
"""Tests for replacing class methods."""
def test_replace_class_method_preserves_class_structure(self, js_support):
"""Test that replacing a class method preserves the class structure."""
source = """class Calculator {
add(a, b) {
return a + b;
}
multiply(a, b) {
return a * b;
}
}
"""
func = FunctionInfo(
name="add",
file_path=Path("/test.js"),
start_line=2,
end_line=4,
parents=(ParentInfo(name="Calculator", type="ClassDef"),),
is_method=True,
)
new_code = """ add(a, b) {
// Optimized bitwise addition
return (a + b) | 0;
}
"""
result = js_support.replace_function(source, func, new_code)
# Check class structure is preserved
assert "class Calculator" in result
assert "multiply(a, b)" in result
assert "return a * b" in result
# Check new code is inserted
assert "Optimized bitwise addition" in result
assert "(a + b) | 0" in result
# Check result is valid JavaScript
assert js_support.validate_syntax(result) is True
def test_replace_class_method_with_jsdoc(self, js_support):
"""Test replacing a class method that has JSDoc."""
source = """class Calculator {
/**
* Adds two numbers.
*/
add(a, b) {
return a + b;
}
}
"""
func = FunctionInfo(
name="add",
file_path=Path("/test.js"),
start_line=5, # Method starts here
end_line=7,
doc_start_line=2, # JSDoc starts here
parents=(ParentInfo(name="Calculator", type="ClassDef"),),
is_method=True,
)
new_code = """ /**
* Adds two numbers (optimized).
*/
add(a, b) {
return (a + b) | 0;
}
"""
result = js_support.replace_function(source, func, new_code)
assert "optimized" in result
assert "(a + b) | 0" in result
assert js_support.validate_syntax(result) is True
def test_replace_multiple_class_methods_sequentially(self, js_support):
"""Test replacing multiple methods in sequence."""
source = """class Math {
add(a, b) {
return a + b;
}
subtract(a, b) {
return a - b;
}
}
"""
# Replace add first
add_func = FunctionInfo(
name="add",
file_path=Path("/test.js"),
start_line=2,
end_line=4,
parents=(ParentInfo(name="Math", type="ClassDef"),),
is_method=True,
)
source = js_support.replace_function(source, add_func, """ add(a, b) {
return (a + b) | 0;
}
""")
assert js_support.validate_syntax(source) is True
# Now need to re-discover to get updated line numbers
# In practice, codeflash handles this, but for test we just check validity
assert "return (a + b) | 0" in source
assert "return a - b" in source
def test_replace_class_method_indentation_adjustment(self, js_support):
"""Test that indentation is correctly adjusted when replacing."""
source = """ class Indented {
innerMethod() {
return 1;
}
}
"""
func = FunctionInfo(
name="innerMethod",
file_path=Path("/test.js"),
start_line=2,
end_line=4,
parents=(ParentInfo(name="Indented", type="ClassDef"),),
is_method=True,
)
# New code with no indentation
new_code = """innerMethod() {
return 42;
}
"""
result = js_support.replace_function(source, func, new_code)
# Check that indentation was adjusted
lines = result.splitlines()
method_line = next(l for l in lines if "innerMethod" in l)
# Should have 8 spaces (original indentation)
assert method_line.startswith(" ")
assert js_support.validate_syntax(result) is True
class TestClassMethodEdgeCases:
"""Edge case tests for class method handling."""
def test_class_with_constructor(self, js_support):
"""Test handling classes with constructors."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""class Counter {
constructor(start = 0) {
this.value = start;
}
increment() {
return ++this.value;
}
}
""")
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
# Should find constructor and increment
names = {f.name for f in functions}
assert "constructor" in names or "increment" in names
def test_class_with_getters_setters(self, js_support):
"""Test handling classes with getters and setters."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""class Person {
constructor(name) {
this._name = name;
}
get name() {
return this._name;
}
set name(value) {
this._name = value;
}
greet() {
return 'Hello, ' + this._name;
}
}
""")
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
# Should find at least greet
names = {f.name for f in functions}
assert "greet" in names
def test_class_extending_another(self, js_support):
"""Test handling classes that extend another class."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""class Animal {
speak() {
return 'sound';
}
}
class Dog extends Animal {
speak() {
return 'bark';
}
fetch() {
return 'ball';
}
}
""")
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
# Find Dog's fetch method
fetch_method = next((f for f in functions if f.name == "fetch" and f.class_name == "Dog"), None)
if fetch_method:
context = js_support.extract_code_context(fetch_method, file_path.parent, file_path.parent)
# Full string equality check
expected_code = """class Dog {
fetch() {
return 'ball';
}
}
"""
assert context.target_code == expected_code, f"Expected:\n{expected_code}\nGot:\n{context.target_code}"
assert js_support.validate_syntax(context.target_code) is True
def test_class_with_private_method(self, js_support):
"""Test handling classes with private methods (ES2022+)."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""class SecureClass {
#privateMethod() {
return 'secret';
}
publicMethod() {
return this.#privateMethod();
}
}
""")
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
# Should at least find publicMethod
names = {f.name for f in functions}
assert "publicMethod" in names
def test_commonjs_class_export(self, js_support):
"""Test handling CommonJS exported classes."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""class Calculator {
add(a, b) {
return a + b;
}
}
module.exports = { Calculator };
""")
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
add_method = next(f for f in functions if f.name == "add")
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
assert "class Calculator" in context.target_code
assert js_support.validate_syntax(context.target_code) is True
def test_es_module_class_export(self, js_support):
"""Test handling ES module exported classes."""
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
f.write("""export class Calculator {
add(a, b) {
return a + b;
}
}
""")
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
# Find the add method
add_method = next((f for f in functions if f.name == "add"), None)
if add_method:
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
assert js_support.validate_syntax(context.target_code) is True