mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
extract class skeleton for optimization context
This commit is contained in:
parent
b30a52b8eb
commit
325534dbc2
7 changed files with 1265 additions and 14 deletions
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Fibonacci Calculator Class - CommonJS module
|
||||
* Intentionally inefficient for optimization testing.
|
||||
*/
|
||||
|
||||
class FibonacciCalculator {
|
||||
constructor() {
|
||||
// No initialization needed
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate the nth Fibonacci number using naive recursion.
|
||||
* This is intentionally slow to demonstrate optimization potential.
|
||||
* @param {number} n - The index of the Fibonacci number to calculate
|
||||
* @returns {number} The nth Fibonacci number
|
||||
*/
|
||||
fibonacci(n) {
|
||||
if (n <= 1) {
|
||||
return n;
|
||||
}
|
||||
return this.fibonacci(n - 1) + this.fibonacci(n - 2);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a number is a Fibonacci number.
|
||||
* @param {number} num - The number to check
|
||||
* @returns {boolean} True if num is a Fibonacci number
|
||||
*/
|
||||
isFibonacci(num) {
|
||||
// A number is Fibonacci if one of (5*n*n + 4) or (5*n*n - 4) is a perfect square
|
||||
const check1 = 5 * num * num + 4;
|
||||
const check2 = 5 * num * num - 4;
|
||||
return this.isPerfectSquare(check1) || this.isPerfectSquare(check2);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a number is a perfect square.
|
||||
* @param {number} n - The number to check
|
||||
* @returns {boolean} True if n is a perfect square
|
||||
*/
|
||||
isPerfectSquare(n) {
|
||||
const sqrt = Math.sqrt(n);
|
||||
return sqrt === Math.floor(sqrt);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate an array of Fibonacci numbers up to n.
|
||||
* @param {number} n - The number of Fibonacci numbers to generate
|
||||
* @returns {number[]} Array of Fibonacci numbers
|
||||
*/
|
||||
fibonacciSequence(n) {
|
||||
const result = [];
|
||||
for (let i = 0; i < n; i++) {
|
||||
result.push(this.fibonacci(i));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// CommonJS exports
|
||||
module.exports = { FibonacciCalculator };
|
||||
|
|
@ -14,7 +14,7 @@
|
|||
}
|
||||
},
|
||||
"../../../packages/codeflash": {
|
||||
"version": "0.1.0",
|
||||
"version": "0.2.0",
|
||||
"dev": true,
|
||||
"hasInstallScript": true,
|
||||
"license": "MIT",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,105 @@
|
|||
const { FibonacciCalculator } = require('../fibonacci_class');
|
||||
|
||||
describe('FibonacciCalculator', () => {
|
||||
let calc;
|
||||
|
||||
beforeEach(() => {
|
||||
calc = new FibonacciCalculator();
|
||||
});
|
||||
|
||||
describe('fibonacci', () => {
|
||||
test('returns 0 for n=0', () => {
|
||||
expect(calc.fibonacci(0)).toBe(0);
|
||||
});
|
||||
|
||||
test('returns 1 for n=1', () => {
|
||||
expect(calc.fibonacci(1)).toBe(1);
|
||||
});
|
||||
|
||||
test('returns 1 for n=2', () => {
|
||||
expect(calc.fibonacci(2)).toBe(1);
|
||||
});
|
||||
|
||||
test('returns 5 for n=5', () => {
|
||||
expect(calc.fibonacci(5)).toBe(5);
|
||||
});
|
||||
|
||||
test('returns 55 for n=10', () => {
|
||||
expect(calc.fibonacci(10)).toBe(55);
|
||||
});
|
||||
|
||||
test('returns 233 for n=13', () => {
|
||||
expect(calc.fibonacci(13)).toBe(233);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isFibonacci', () => {
|
||||
test('returns true for 0', () => {
|
||||
expect(calc.isFibonacci(0)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 1', () => {
|
||||
expect(calc.isFibonacci(1)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 8', () => {
|
||||
expect(calc.isFibonacci(8)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 13', () => {
|
||||
expect(calc.isFibonacci(13)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns false for 4', () => {
|
||||
expect(calc.isFibonacci(4)).toBe(false);
|
||||
});
|
||||
|
||||
test('returns false for 6', () => {
|
||||
expect(calc.isFibonacci(6)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isPerfectSquare', () => {
|
||||
test('returns true for 0', () => {
|
||||
expect(calc.isPerfectSquare(0)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 1', () => {
|
||||
expect(calc.isPerfectSquare(1)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 4', () => {
|
||||
expect(calc.isPerfectSquare(4)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns true for 16', () => {
|
||||
expect(calc.isPerfectSquare(16)).toBe(true);
|
||||
});
|
||||
|
||||
test('returns false for 2', () => {
|
||||
expect(calc.isPerfectSquare(2)).toBe(false);
|
||||
});
|
||||
|
||||
test('returns false for 3', () => {
|
||||
expect(calc.isPerfectSquare(3)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('fibonacciSequence', () => {
|
||||
test('returns empty array for n=0', () => {
|
||||
expect(calc.fibonacciSequence(0)).toEqual([]);
|
||||
});
|
||||
|
||||
test('returns [0] for n=1', () => {
|
||||
expect(calc.fibonacciSequence(1)).toEqual([0]);
|
||||
});
|
||||
|
||||
test('returns first 5 Fibonacci numbers', () => {
|
||||
expect(calc.fibonacciSequence(5)).toEqual([0, 1, 1, 2, 3]);
|
||||
});
|
||||
|
||||
test('returns first 10 Fibonacci numbers', () => {
|
||||
expect(calc.fibonacciSequence(10)).toEqual([0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -35,6 +35,7 @@ class ExpectCallMatch:
|
|||
func_args: str
|
||||
assertion_chain: str
|
||||
has_trailing_semicolon: bool
|
||||
object_prefix: str = "" # Object prefix like "calc." or "this." or ""
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -45,7 +46,8 @@ class StandaloneCallMatch:
|
|||
end_pos: int
|
||||
leading_whitespace: str
|
||||
func_args: str
|
||||
prefix: str # Everything between whitespace and func name (e.g., "await ", "")
|
||||
prefix: str # "await " or ""
|
||||
object_prefix: str # Object prefix like "calc." or "this." or ""
|
||||
has_trailing_semicolon: bool
|
||||
|
||||
|
||||
|
|
@ -61,6 +63,7 @@ class StandaloneCallTransformer:
|
|||
- func(args) -> codeflash.capturePerf('name', 'id', func, args)
|
||||
- const result = func(args) -> const result = codeflash.capturePerf(...)
|
||||
- arr.map(() => func(args)) -> arr.map(() => codeflash.capturePerf(..., func, args))
|
||||
- calc.fibonacci(n) -> codeflash.capturePerf('...', 'id', calc.fibonacci.bind(calc), n)
|
||||
|
||||
"""
|
||||
|
||||
|
|
@ -69,9 +72,12 @@ class StandaloneCallTransformer:
|
|||
self.qualified_name = qualified_name
|
||||
self.capture_func = capture_func
|
||||
self.invocation_counter = 0
|
||||
# Pattern to match func_name( with optional leading await
|
||||
# Pattern to match func_name( with optional leading await and optional object prefix
|
||||
# Captures: (whitespace)(await )?(object.)*func_name(
|
||||
# We'll filter out expect() and codeflash. cases in the transform loop
|
||||
self._call_pattern = re.compile(rf"(\s*)(await\s+)?{re.escape(func_name)}\s*\(")
|
||||
self._call_pattern = re.compile(
|
||||
rf"(\s*)(await\s+)?((?:\w+\.)*){re.escape(func_name)}\s*\("
|
||||
)
|
||||
|
||||
def transform(self, code: str) -> str:
|
||||
"""Transform all standalone calls in the code."""
|
||||
|
|
@ -121,6 +127,42 @@ class StandaloneCallTransformer:
|
|||
if f"codeflash.{self.capture_func}(" in lookback[-60:]:
|
||||
return True
|
||||
|
||||
# Skip if this is a function/method definition, not a call
|
||||
# Patterns to skip:
|
||||
# - ClassName.prototype.funcName = function(
|
||||
# - funcName = function(
|
||||
# - funcName: function(
|
||||
# - function funcName(
|
||||
# - funcName() { (method definition in class)
|
||||
near_context = lookback[-80:] if len(lookback) >= 80 else lookback
|
||||
|
||||
# Skip prototype assignment: ClassName.prototype.funcName = function(
|
||||
if re.search(r"\.prototype\.\w+\s*=\s*function\s*$", near_context):
|
||||
return True
|
||||
|
||||
# Skip function assignment: funcName = function(
|
||||
if re.search(rf"{re.escape(self.func_name)}\s*=\s*function\s*$", near_context):
|
||||
return True
|
||||
|
||||
# Skip function declaration: function funcName(
|
||||
if re.search(rf"function\s+{re.escape(self.func_name)}\s*$", near_context):
|
||||
return True
|
||||
|
||||
# Skip method definition in class body: funcName(params) { or async funcName(params) {
|
||||
# Check by looking at what comes after the closing paren
|
||||
# The match ends at the opening paren, so find the closing paren and check what follows
|
||||
close_paren_pos = self._find_matching_paren(code, match.end() - 1)
|
||||
if close_paren_pos != -1:
|
||||
# Check if followed by { (method definition) after optional whitespace
|
||||
after_close = code[close_paren_pos : close_paren_pos + 20].lstrip()
|
||||
if after_close.startswith("{"):
|
||||
# This is a method definition like "fibonacci(n) {"
|
||||
# But we still want to capture certain patterns like arrow functions
|
||||
# Check if there's no => before the {
|
||||
between = code[close_paren_pos : close_paren_pos + 20].strip()
|
||||
if not between.startswith("=>"):
|
||||
return True
|
||||
|
||||
# Skip if inside expect() - look for 'expect(' with unmatched parens
|
||||
# Find the last 'expect(' and check if it's still open
|
||||
expect_search_start = max(0, start - 100)
|
||||
|
|
@ -138,10 +180,28 @@ class StandaloneCallTransformer:
|
|||
|
||||
return False
|
||||
|
||||
def _find_matching_paren(self, code: str, open_paren_pos: int) -> int:
|
||||
"""Find the position of the closing paren for the given opening paren."""
|
||||
if open_paren_pos >= len(code) or code[open_paren_pos] != "(":
|
||||
return -1
|
||||
|
||||
depth = 1
|
||||
pos = open_paren_pos + 1
|
||||
|
||||
while pos < len(code) and depth > 0:
|
||||
if code[pos] == "(":
|
||||
depth += 1
|
||||
elif code[pos] == ")":
|
||||
depth -= 1
|
||||
pos += 1
|
||||
|
||||
return pos if depth == 0 else -1
|
||||
|
||||
def _parse_standalone_call(self, code: str, match: re.Match) -> StandaloneCallMatch | None:
|
||||
"""Parse a complete standalone func(...) call."""
|
||||
leading_ws = match.group(1)
|
||||
prefix = match.group(2) or "" # "await " or ""
|
||||
object_prefix = match.group(3) or "" # Object prefix like "calc." or ""
|
||||
|
||||
# Find the opening paren position
|
||||
match_text = match.group(0)
|
||||
|
|
@ -169,6 +229,7 @@ class StandaloneCallTransformer:
|
|||
leading_whitespace=leading_ws,
|
||||
func_args=func_args,
|
||||
prefix=prefix,
|
||||
object_prefix=object_prefix,
|
||||
has_trailing_semicolon=has_trailing_semicolon,
|
||||
)
|
||||
|
||||
|
|
@ -212,6 +273,23 @@ class StandaloneCallTransformer:
|
|||
args_str = match.func_args.strip()
|
||||
semicolon = ";" if match.has_trailing_semicolon else ""
|
||||
|
||||
# Handle method calls on objects (e.g., calc.fibonacci, this.method)
|
||||
if match.object_prefix:
|
||||
# Remove trailing dot from object prefix for the bind call
|
||||
obj = match.object_prefix.rstrip(".")
|
||||
full_method = f"{obj}.{self.func_name}"
|
||||
|
||||
if args_str:
|
||||
return (
|
||||
f"{match.leading_whitespace}{match.prefix}codeflash.{self.capture_func}('{self.qualified_name}', "
|
||||
f"'{line_id}', {full_method}.bind({obj}), {args_str}){semicolon}"
|
||||
)
|
||||
return (
|
||||
f"{match.leading_whitespace}{match.prefix}codeflash.{self.capture_func}('{self.qualified_name}', "
|
||||
f"'{line_id}', {full_method}.bind({obj})){semicolon}"
|
||||
)
|
||||
|
||||
# Handle standalone function calls
|
||||
if args_str:
|
||||
return (
|
||||
f"{match.leading_whitespace}{match.prefix}codeflash.{self.capture_func}('{self.qualified_name}', "
|
||||
|
|
@ -268,8 +346,11 @@ class ExpectCallTransformer:
|
|||
self.capture_func = capture_func
|
||||
self.remove_assertions = remove_assertions
|
||||
self.invocation_counter = 0
|
||||
# Pattern to match start of expect(func_name(
|
||||
self._expect_pattern = re.compile(rf"(\s*)expect\s*\(\s*{re.escape(func_name)}\s*\(")
|
||||
# Pattern to match start of expect((object.)*func_name(
|
||||
# Captures: (whitespace), (object prefix like calc. or this.)
|
||||
self._expect_pattern = re.compile(
|
||||
rf"(\s*)expect\s*\(\s*((?:\w+\.)*){re.escape(func_name)}\s*\("
|
||||
)
|
||||
|
||||
def transform(self, code: str) -> str:
|
||||
"""Transform all expect calls in the code."""
|
||||
|
|
@ -307,6 +388,7 @@ class ExpectCallTransformer:
|
|||
Returns None if the pattern doesn't match expected structure.
|
||||
"""
|
||||
leading_ws = match.group(1)
|
||||
object_prefix = match.group(2) or "" # Object prefix like "calc." or ""
|
||||
|
||||
# Find the arguments of the function call (handling nested parens)
|
||||
args_start = match.end()
|
||||
|
|
@ -341,6 +423,7 @@ class ExpectCallTransformer:
|
|||
func_args=func_args,
|
||||
assertion_chain=assertion_chain,
|
||||
has_trailing_semicolon=has_trailing_semicolon,
|
||||
object_prefix=object_prefix,
|
||||
)
|
||||
|
||||
def _find_balanced_parens(self, code: str, open_paren_pos: int) -> tuple[str | None, int]:
|
||||
|
|
@ -473,16 +556,24 @@ class ExpectCallTransformer:
|
|||
line_id = str(self.invocation_counter)
|
||||
args_str = match.func_args.strip()
|
||||
|
||||
# Determine the function reference to use
|
||||
if match.object_prefix:
|
||||
# Method call on object: calc.fibonacci -> calc.fibonacci.bind(calc)
|
||||
obj = match.object_prefix.rstrip(".")
|
||||
func_ref = f"{obj}.{self.func_name}.bind({obj})"
|
||||
else:
|
||||
func_ref = self.func_name
|
||||
|
||||
if self.remove_assertions:
|
||||
# For generated/regression tests: remove expect wrapper and assertion
|
||||
if args_str:
|
||||
return (
|
||||
f"{match.leading_whitespace}codeflash.{self.capture_func}('{self.qualified_name}', "
|
||||
f"'{line_id}', {self.func_name}, {args_str});"
|
||||
f"'{line_id}', {func_ref}, {args_str});"
|
||||
)
|
||||
return (
|
||||
f"{match.leading_whitespace}codeflash.{self.capture_func}('{self.qualified_name}', "
|
||||
f"'{line_id}', {self.func_name});"
|
||||
f"'{line_id}', {func_ref});"
|
||||
)
|
||||
|
||||
# For existing tests: keep the expect wrapper
|
||||
|
|
@ -490,11 +581,11 @@ class ExpectCallTransformer:
|
|||
if args_str:
|
||||
return (
|
||||
f"{match.leading_whitespace}expect(codeflash.{self.capture_func}('{self.qualified_name}', "
|
||||
f"'{line_id}', {self.func_name}, {args_str})){match.assertion_chain}{semicolon}"
|
||||
f"'{line_id}', {func_ref}, {args_str})){match.assertion_chain}{semicolon}"
|
||||
)
|
||||
return (
|
||||
f"{match.leading_whitespace}expect(codeflash.{self.capture_func}('{self.qualified_name}', "
|
||||
f"'{line_id}', {self.func_name})){match.assertion_chain}{semicolon}"
|
||||
f"'{line_id}', {func_ref})){match.assertion_chain}{semicolon}"
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -582,13 +673,18 @@ def inject_profiling_into_existing_js_test(
|
|||
|
||||
|
||||
def _is_function_used_in_test(code: str, func_name: str) -> bool:
|
||||
"""Check if a function is imported or used in the test code."""
|
||||
# Check for CommonJS require
|
||||
"""Check if a function is imported or used in the test code.
|
||||
|
||||
This function handles both standalone functions and class methods.
|
||||
For class methods, it checks if the method is called on any object
|
||||
(e.g., calc.fibonacci, this.fibonacci).
|
||||
"""
|
||||
# Check for CommonJS require with named export
|
||||
require_pattern = rf"(?:const|let|var)\s+\{{\s*[^}}]*\b{re.escape(func_name)}\b[^}}]*\}}\s*=\s*require\s*\("
|
||||
if re.search(require_pattern, code):
|
||||
return True
|
||||
|
||||
# Check for ES6 import
|
||||
# Check for ES6 import with named export
|
||||
import_pattern = rf"import\s+\{{\s*[^}}]*\b{re.escape(func_name)}\b[^}}]*\}}\s+from"
|
||||
if re.search(import_pattern, code):
|
||||
return True
|
||||
|
|
@ -602,6 +698,12 @@ def _is_function_used_in_test(code: str, func_name: str) -> bool:
|
|||
if re.search(default_import, code):
|
||||
return True
|
||||
|
||||
# Check for method calls: obj.funcName( or this.funcName(
|
||||
# This handles class methods called on instances
|
||||
method_call_pattern = rf"\w+\.{re.escape(func_name)}\s*\("
|
||||
if re.search(method_call_pattern, code):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -321,6 +321,29 @@ class JavaScriptSupport:
|
|||
else:
|
||||
target_code = ""
|
||||
|
||||
# For class methods, wrap the method in its class definition
|
||||
# This is necessary because method definition syntax is only valid inside a class body
|
||||
if function.is_method and function.parents:
|
||||
class_name = None
|
||||
for parent in function.parents:
|
||||
if parent.type == "ClassDef":
|
||||
class_name = parent.name
|
||||
break
|
||||
|
||||
if class_name:
|
||||
# Find the class definition in the source to get proper indentation and any class JSDoc
|
||||
class_info = self._find_class_definition(source, class_name, analyzer)
|
||||
if class_info:
|
||||
class_jsdoc, class_indent = class_info
|
||||
# Wrap the method in a minimal class definition
|
||||
if class_jsdoc:
|
||||
target_code = f"{class_jsdoc}\n{class_indent}class {class_name} {{\n{target_code}{class_indent}}}\n"
|
||||
else:
|
||||
target_code = f"{class_indent}class {class_name} {{\n{target_code}{class_indent}}}\n"
|
||||
else:
|
||||
# Fallback: wrap with no indentation
|
||||
target_code = f"class {class_name} {{\n{target_code}}}\n"
|
||||
|
||||
imports = analyzer.find_imports(source)
|
||||
|
||||
# Find helper functions called by target
|
||||
|
|
@ -337,6 +360,16 @@ class JavaScriptSupport:
|
|||
target_code=target_code, helpers=helpers, source=source, analyzer=analyzer, imports=imports
|
||||
)
|
||||
|
||||
# Validate that the extracted code is syntactically valid
|
||||
# If not, raise an error to fail the optimization early
|
||||
if target_code and not self.validate_syntax(target_code):
|
||||
error_msg = (
|
||||
f"Extracted code for {function.name} is not syntactically valid JavaScript. "
|
||||
f"Cannot proceed with optimization."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
return CodeContext(
|
||||
target_code=target_code,
|
||||
target_file=function.file_path,
|
||||
|
|
@ -346,6 +379,61 @@ class JavaScriptSupport:
|
|||
language=Language.JAVASCRIPT,
|
||||
)
|
||||
|
||||
def _find_class_definition(
|
||||
self, source: str, class_name: str, analyzer: TreeSitterAnalyzer
|
||||
) -> tuple[str, str] | None:
|
||||
"""Find a class definition and extract its JSDoc comment and indentation.
|
||||
|
||||
Args:
|
||||
source: The source code to search.
|
||||
class_name: The name of the class to find.
|
||||
analyzer: TreeSitterAnalyzer for parsing.
|
||||
|
||||
Returns:
|
||||
Tuple of (jsdoc_comment, indentation) or None if not found.
|
||||
|
||||
"""
|
||||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
|
||||
def find_class_node(node):
|
||||
"""Recursively find a class declaration with the given name."""
|
||||
if node.type in ("class_declaration", "class"):
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
node_name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
|
||||
if node_name == class_name:
|
||||
return node
|
||||
for child in node.children:
|
||||
result = find_class_node(child)
|
||||
if result:
|
||||
return result
|
||||
return None
|
||||
|
||||
class_node = find_class_node(tree.root_node)
|
||||
if not class_node:
|
||||
return None
|
||||
|
||||
# Get indentation from the class line
|
||||
lines = source.splitlines(keepends=True)
|
||||
class_line_idx = class_node.start_point[0]
|
||||
if class_line_idx < len(lines):
|
||||
class_line = lines[class_line_idx]
|
||||
indent = len(class_line) - len(class_line.lstrip())
|
||||
indentation = " " * indent
|
||||
else:
|
||||
indentation = ""
|
||||
|
||||
# Look for preceding JSDoc comment
|
||||
jsdoc = ""
|
||||
prev_sibling = class_node.prev_named_sibling
|
||||
if prev_sibling and prev_sibling.type == "comment":
|
||||
comment_text = source_bytes[prev_sibling.start_byte : prev_sibling.end_byte].decode("utf8")
|
||||
if comment_text.strip().startswith("/**"):
|
||||
jsdoc = comment_text
|
||||
|
||||
return (jsdoc, indentation)
|
||||
|
||||
def _find_helper_functions(
|
||||
self, function: FunctionInfo, source: str, analyzer: TreeSitterAnalyzer, imports: list[Any], module_root: Path
|
||||
) -> list[HelperFunction]:
|
||||
|
|
|
|||
|
|
@ -357,4 +357,356 @@ test('decrypt works', () => {{
|
|||
assert fixed_code == test_code
|
||||
|
||||
# Clean up
|
||||
source_path.unlink()
|
||||
source_path.unlink()
|
||||
|
||||
|
||||
class TestClassMethodInstrumentation:
|
||||
"""Tests for class method instrumentation."""
|
||||
|
||||
def test_instrument_method_call_on_instance(self):
|
||||
"""Test that method calls on instances are correctly instrumented."""
|
||||
from codeflash.languages.javascript.instrument import transform_standalone_calls
|
||||
|
||||
code = """
|
||||
const calc = new Calculator();
|
||||
const result = calc.fibonacci(10);
|
||||
console.log(result);
|
||||
"""
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code,
|
||||
func_name="fibonacci",
|
||||
qualified_name="Calculator.fibonacci",
|
||||
capture_func="capture",
|
||||
)
|
||||
|
||||
# Should transform calc.fibonacci(10) to codeflash.capture(..., calc.fibonacci.bind(calc), 10)
|
||||
assert "codeflash.capture('Calculator.fibonacci'" in transformed
|
||||
assert "calc.fibonacci.bind(calc)" in transformed
|
||||
assert counter == 1
|
||||
|
||||
def test_instrument_expect_with_method_call(self):
|
||||
"""Test that expect() with method calls are correctly instrumented."""
|
||||
from codeflash.languages.javascript.instrument import transform_expect_calls
|
||||
|
||||
code = """
|
||||
test('fibonacci works', () => {
|
||||
const calc = new FibonacciCalculator();
|
||||
expect(calc.fibonacci(10)).toBe(55);
|
||||
});
|
||||
"""
|
||||
transformed, counter = transform_expect_calls(
|
||||
code=code,
|
||||
func_name="fibonacci",
|
||||
qualified_name="FibonacciCalculator.fibonacci",
|
||||
capture_func="capture",
|
||||
)
|
||||
|
||||
# Should transform expect(calc.fibonacci(10)) to
|
||||
# expect(codeflash.capture(..., calc.fibonacci.bind(calc), 10))
|
||||
assert "codeflash.capture('FibonacciCalculator.fibonacci'" in transformed
|
||||
assert "calc.fibonacci.bind(calc)" in transformed
|
||||
assert ".toBe(55)" in transformed
|
||||
assert counter == 1
|
||||
|
||||
def test_instrument_expect_with_method_removes_assertion(self):
|
||||
"""Test that expect() with method calls are correctly instrumented with assertion removal."""
|
||||
from codeflash.languages.javascript.instrument import transform_expect_calls
|
||||
|
||||
code = """
|
||||
test('fibonacci works', () => {
|
||||
const calc = new FibonacciCalculator();
|
||||
expect(calc.fibonacci(10)).toBe(55);
|
||||
});
|
||||
"""
|
||||
transformed, counter = transform_expect_calls(
|
||||
code=code,
|
||||
func_name="fibonacci",
|
||||
qualified_name="FibonacciCalculator.fibonacci",
|
||||
capture_func="capture",
|
||||
remove_assertions=True,
|
||||
)
|
||||
|
||||
# Should remove expect wrapper and assertion
|
||||
assert "codeflash.capture('FibonacciCalculator.fibonacci'" in transformed
|
||||
assert "calc.fibonacci.bind(calc)" in transformed
|
||||
assert ".toBe(55)" not in transformed # Assertion removed
|
||||
assert "expect(" not in transformed # expect wrapper removed
|
||||
assert counter == 1
|
||||
|
||||
def test_does_not_instrument_function_definition(self):
|
||||
"""Test that function definitions are NOT transformed."""
|
||||
from codeflash.languages.javascript.instrument import transform_standalone_calls
|
||||
|
||||
code = """
|
||||
class FibonacciCalculator {
|
||||
fibonacci(n) {
|
||||
if (n <= 1) return n;
|
||||
return this.fibonacci(n - 1) + this.fibonacci(n - 2);
|
||||
}
|
||||
}
|
||||
"""
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code,
|
||||
func_name="fibonacci",
|
||||
qualified_name="FibonacciCalculator.fibonacci",
|
||||
capture_func="capture",
|
||||
)
|
||||
|
||||
# The method definition should NOT be transformed
|
||||
# Only the recursive calls this.fibonacci(...) should potentially be transformed
|
||||
assert "fibonacci(n) {" in transformed # Method definition unchanged
|
||||
assert counter >= 0 # May or may not transform the recursive calls
|
||||
|
||||
def test_does_not_instrument_prototype_assignment(self):
|
||||
"""Test that prototype assignments are NOT transformed."""
|
||||
from codeflash.languages.javascript.instrument import transform_standalone_calls
|
||||
|
||||
code = """
|
||||
FibonacciCalculator.prototype.fibonacci = function(n) {
|
||||
if (n <= 1) return n;
|
||||
return this.fibonacci(n - 1) + this.fibonacci(n - 2);
|
||||
};
|
||||
"""
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code,
|
||||
func_name="fibonacci",
|
||||
qualified_name="FibonacciCalculator.fibonacci",
|
||||
capture_func="capture",
|
||||
)
|
||||
|
||||
# The prototype assignment should NOT be transformed
|
||||
# It should still have the original pattern
|
||||
assert "FibonacciCalculator.prototype.fibonacci = function(n)" in transformed
|
||||
|
||||
def test_instrument_multiple_method_calls(self):
|
||||
"""Test that multiple method calls are correctly instrumented."""
|
||||
from codeflash.languages.javascript.instrument import transform_standalone_calls
|
||||
|
||||
code = """
|
||||
const calc = new Calculator();
|
||||
const a = calc.fibonacci(5);
|
||||
const b = calc.fibonacci(10);
|
||||
const sum = a + b;
|
||||
"""
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code,
|
||||
func_name="fibonacci",
|
||||
qualified_name="Calculator.fibonacci",
|
||||
capture_func="capture",
|
||||
)
|
||||
|
||||
# Should transform both calls
|
||||
assert transformed.count("codeflash.capture") == 2
|
||||
assert counter == 2
|
||||
|
||||
def test_instrument_this_method_call(self):
|
||||
"""Test that this.method() calls are correctly instrumented."""
|
||||
from codeflash.languages.javascript.instrument import transform_standalone_calls
|
||||
|
||||
code = """
|
||||
class Wrapper {
|
||||
callFibonacci(n) {
|
||||
return this.fibonacci(n);
|
||||
}
|
||||
}
|
||||
"""
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code,
|
||||
func_name="fibonacci",
|
||||
qualified_name="Wrapper.fibonacci",
|
||||
capture_func="capture",
|
||||
)
|
||||
|
||||
# Should transform this.fibonacci(n)
|
||||
assert "codeflash.capture('Wrapper.fibonacci'" in transformed
|
||||
assert "this.fibonacci.bind(this)" in transformed
|
||||
assert counter == 1
|
||||
|
||||
def test_full_instrumentation_produces_valid_syntax(self):
|
||||
"""Test that full instrumentation produces syntactically valid JavaScript."""
|
||||
from codeflash.languages.javascript.instrument import _instrument_js_test_code
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.base import Language
|
||||
|
||||
js_support = get_language_support(Language.JAVASCRIPT)
|
||||
|
||||
test_code = """
|
||||
const { FibonacciCalculator } = require('../fibonacci_class');
|
||||
|
||||
describe('FibonacciCalculator', () => {
|
||||
let calc;
|
||||
|
||||
beforeEach(() => {
|
||||
calc = new FibonacciCalculator();
|
||||
});
|
||||
|
||||
test('fibonacci returns correct values', () => {
|
||||
expect(calc.fibonacci(0)).toBe(0);
|
||||
expect(calc.fibonacci(1)).toBe(1);
|
||||
expect(calc.fibonacci(10)).toBe(55);
|
||||
});
|
||||
|
||||
test('standalone call', () => {
|
||||
const result = calc.fibonacci(5);
|
||||
expect(result).toBe(5);
|
||||
});
|
||||
});
|
||||
"""
|
||||
instrumented = _instrument_js_test_code(
|
||||
code=test_code,
|
||||
func_name="fibonacci",
|
||||
test_file_path="test.js",
|
||||
mode="behavior",
|
||||
qualified_name="FibonacciCalculator.fibonacci",
|
||||
)
|
||||
|
||||
# Check that codeflash import was added
|
||||
assert "codeflash" in instrumented
|
||||
|
||||
# Check that method calls were instrumented
|
||||
assert "codeflash.capture" in instrumented
|
||||
|
||||
# Check that the instrumented code is valid JavaScript
|
||||
assert js_support.validate_syntax(instrumented) is True, f"Invalid syntax:\n{instrumented}"
|
||||
|
||||
def test_instrumentation_preserves_test_structure(self):
|
||||
"""Test that instrumentation preserves the test structure."""
|
||||
from codeflash.languages.javascript.instrument import _instrument_js_test_code
|
||||
|
||||
test_code = """
|
||||
const { Calculator } = require('../calculator');
|
||||
|
||||
describe('Calculator', () => {
|
||||
test('add works', () => {
|
||||
const calc = new Calculator();
|
||||
expect(calc.add(1, 2)).toBe(3);
|
||||
});
|
||||
});
|
||||
"""
|
||||
instrumented = _instrument_js_test_code(
|
||||
code=test_code,
|
||||
func_name="add",
|
||||
test_file_path="test.js",
|
||||
mode="behavior",
|
||||
qualified_name="Calculator.add",
|
||||
)
|
||||
|
||||
# describe and test structure should be preserved
|
||||
assert "describe('Calculator'" in instrumented
|
||||
assert "test('add works'" in instrumented
|
||||
assert "beforeEach" in instrumented or "beforeEach" not in test_code # Only if it was there
|
||||
|
||||
# Method call should be instrumented
|
||||
assert "codeflash.capture('Calculator.add'" in instrumented
|
||||
assert "calc.add.bind(calc)" in instrumented
|
||||
|
||||
def test_instrumentation_with_async_methods(self):
|
||||
"""Test instrumentation with async method calls."""
|
||||
from codeflash.languages.javascript.instrument import transform_standalone_calls
|
||||
|
||||
code = """
|
||||
const api = new ApiClient();
|
||||
const data = await api.fetchData('http://example.com');
|
||||
console.log(data);
|
||||
"""
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code,
|
||||
func_name="fetchData",
|
||||
qualified_name="ApiClient.fetchData",
|
||||
capture_func="capture",
|
||||
)
|
||||
|
||||
# Should preserve await
|
||||
assert "await codeflash.capture" in transformed
|
||||
assert "api.fetchData.bind(api)" in transformed
|
||||
assert counter == 1
|
||||
|
||||
|
||||
class TestInstrumentationFullStringEquality:
|
||||
"""Tests with full string equality for precise verification."""
|
||||
|
||||
def test_standalone_method_call_exact_output(self):
|
||||
"""Test exact output of standalone method call instrumentation."""
|
||||
from codeflash.languages.javascript.instrument import transform_standalone_calls
|
||||
|
||||
code = " calc.fibonacci(10);"
|
||||
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code,
|
||||
func_name="fibonacci",
|
||||
qualified_name="Calculator.fibonacci",
|
||||
capture_func="capture",
|
||||
)
|
||||
|
||||
expected = " codeflash.capture('Calculator.fibonacci', '1', calc.fibonacci.bind(calc), 10);"
|
||||
assert transformed == expected, f"Expected:\n{expected}\nGot:\n{transformed}"
|
||||
assert counter == 1
|
||||
|
||||
def test_expect_method_call_exact_output(self):
|
||||
"""Test exact output of expect() method call instrumentation."""
|
||||
from codeflash.languages.javascript.instrument import transform_expect_calls
|
||||
|
||||
code = " expect(calc.fibonacci(10)).toBe(55);"
|
||||
|
||||
transformed, counter = transform_expect_calls(
|
||||
code=code,
|
||||
func_name="fibonacci",
|
||||
qualified_name="Calculator.fibonacci",
|
||||
capture_func="capture",
|
||||
)
|
||||
|
||||
expected = " expect(codeflash.capture('Calculator.fibonacci', '1', calc.fibonacci.bind(calc), 10)).toBe(55);"
|
||||
assert transformed == expected, f"Expected:\n{expected}\nGot:\n{transformed}"
|
||||
assert counter == 1
|
||||
|
||||
def test_expect_method_call_remove_assertions_exact_output(self):
|
||||
"""Test exact output when removing assertions."""
|
||||
from codeflash.languages.javascript.instrument import transform_expect_calls
|
||||
|
||||
code = " expect(calc.fibonacci(10)).toBe(55);"
|
||||
|
||||
transformed, counter = transform_expect_calls(
|
||||
code=code,
|
||||
func_name="fibonacci",
|
||||
qualified_name="Calculator.fibonacci",
|
||||
capture_func="capture",
|
||||
remove_assertions=True,
|
||||
)
|
||||
|
||||
expected = " codeflash.capture('Calculator.fibonacci', '1', calc.fibonacci.bind(calc), 10);"
|
||||
assert transformed == expected, f"Expected:\n{expected}\nGot:\n{transformed}"
|
||||
assert counter == 1
|
||||
|
||||
def test_standalone_function_call_no_object_prefix(self):
|
||||
"""Test that standalone function calls (no object) work correctly."""
|
||||
from codeflash.languages.javascript.instrument import transform_standalone_calls
|
||||
|
||||
code = " fibonacci(10);"
|
||||
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code,
|
||||
func_name="fibonacci",
|
||||
qualified_name="fibonacci",
|
||||
capture_func="capture",
|
||||
)
|
||||
|
||||
expected = " codeflash.capture('fibonacci', '1', fibonacci, 10);"
|
||||
assert transformed == expected, f"Expected:\n{expected}\nGot:\n{transformed}"
|
||||
assert counter == 1
|
||||
|
||||
def test_this_method_call_exact_output(self):
|
||||
"""Test exact output for this.method() call."""
|
||||
from codeflash.languages.javascript.instrument import transform_standalone_calls
|
||||
|
||||
code = " return this.fibonacci(n - 1);"
|
||||
|
||||
transformed, counter = transform_standalone_calls(
|
||||
code=code,
|
||||
func_name="fibonacci",
|
||||
qualified_name="Class.fibonacci",
|
||||
capture_func="capture",
|
||||
)
|
||||
|
||||
expected = " return codeflash.capture('Class.fibonacci', '1', this.fibonacci.bind(this), n - 1);"
|
||||
assert transformed == expected, f"Expected:\n{expected}\nGot:\n{transformed}"
|
||||
assert counter == 1
|
||||
|
|
@ -695,3 +695,546 @@ describe('Math functions', () => {
|
|||
assert "Math functions" in test_names
|
||||
assert "add returns sum" in test_names
|
||||
assert "handles negative numbers" in test_names
|
||||
|
||||
|
||||
class TestClassMethodExtraction:
|
||||
"""Tests for class method extraction and code context.
|
||||
|
||||
These tests use full string equality to verify exact extraction output.
|
||||
"""
|
||||
|
||||
def test_extract_class_method_wraps_in_class(self, js_support):
|
||||
"""Test that extracting a class method wraps it in a class definition."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
|
||||
f.write("""class Calculator {
|
||||
add(a, b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
multiply(a, b) {
|
||||
return a * b;
|
||||
}
|
||||
}
|
||||
""")
|
||||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
# Discover the method
|
||||
functions = js_support.discover_functions(file_path)
|
||||
add_method = next(f for f in functions if f.name == "add")
|
||||
|
||||
# Extract code context
|
||||
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
|
||||
|
||||
# Full string equality check for exact extraction output
|
||||
expected_code = """class Calculator {
|
||||
add(a, b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
assert context.target_code == expected_code, f"Expected:\n{expected_code}\nGot:\n{context.target_code}"
|
||||
assert js_support.validate_syntax(context.target_code) is True
|
||||
|
||||
def test_extract_class_method_with_jsdoc(self, js_support):
|
||||
"""Test extracting a class method with JSDoc comments."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
|
||||
f.write("""/**
|
||||
* A simple calculator class.
|
||||
*/
|
||||
class Calculator {
|
||||
/**
|
||||
* Adds two numbers.
|
||||
* @param {number} a - First number
|
||||
* @param {number} b - Second number
|
||||
* @returns {number} The sum
|
||||
*/
|
||||
add(a, b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
""")
|
||||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
add_method = next(f for f in functions if f.name == "add")
|
||||
|
||||
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
|
||||
|
||||
# Full string equality check - includes class JSDoc, class definition, method JSDoc, and method
|
||||
expected_code = """/**
|
||||
* A simple calculator class.
|
||||
*/
|
||||
class Calculator {
|
||||
/**
|
||||
* Adds two numbers.
|
||||
* @param {number} a - First number
|
||||
* @param {number} b - Second number
|
||||
* @returns {number} The sum
|
||||
*/
|
||||
add(a, b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
assert context.target_code == expected_code, f"Expected:\n{expected_code}\nGot:\n{context.target_code}"
|
||||
assert js_support.validate_syntax(context.target_code) is True
|
||||
|
||||
def test_extract_class_method_syntax_valid(self, js_support):
|
||||
"""Test that extracted class method code is always syntactically valid."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
|
||||
f.write("""class FibonacciCalculator {
|
||||
fibonacci(n) {
|
||||
if (n <= 1) {
|
||||
return n;
|
||||
}
|
||||
return this.fibonacci(n - 1) + this.fibonacci(n - 2);
|
||||
}
|
||||
}
|
||||
""")
|
||||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
fib_method = next(f for f in functions if f.name == "fibonacci")
|
||||
|
||||
context = js_support.extract_code_context(fib_method, file_path.parent, file_path.parent)
|
||||
|
||||
# Full string equality check
|
||||
expected_code = """class FibonacciCalculator {
|
||||
fibonacci(n) {
|
||||
if (n <= 1) {
|
||||
return n;
|
||||
}
|
||||
return this.fibonacci(n - 1) + this.fibonacci(n - 2);
|
||||
}
|
||||
}
|
||||
"""
|
||||
assert context.target_code == expected_code, f"Expected:\n{expected_code}\nGot:\n{context.target_code}"
|
||||
assert js_support.validate_syntax(context.target_code) is True
|
||||
|
||||
def test_extract_nested_class_method(self, js_support):
|
||||
"""Test extracting a method from a nested class structure."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
|
||||
f.write("""class Outer {
|
||||
createInner() {
|
||||
return class Inner {
|
||||
getValue() {
|
||||
return 42;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
add(a, b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
""")
|
||||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
add_method = next((f for f in functions if f.name == "add"), None)
|
||||
|
||||
if add_method:
|
||||
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
|
||||
|
||||
# Full string equality check
|
||||
expected_code = """class Outer {
|
||||
add(a, b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
assert context.target_code == expected_code, f"Expected:\n{expected_code}\nGot:\n{context.target_code}"
|
||||
assert js_support.validate_syntax(context.target_code) is True
|
||||
|
||||
def test_extract_async_class_method(self, js_support):
|
||||
"""Test extracting an async class method."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
|
||||
f.write("""class ApiClient {
|
||||
async fetchData(url) {
|
||||
const response = await fetch(url);
|
||||
return response.json();
|
||||
}
|
||||
}
|
||||
""")
|
||||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
fetch_method = next(f for f in functions if f.name == "fetchData")
|
||||
|
||||
context = js_support.extract_code_context(fetch_method, file_path.parent, file_path.parent)
|
||||
|
||||
# Full string equality check
|
||||
expected_code = """class ApiClient {
|
||||
async fetchData(url) {
|
||||
const response = await fetch(url);
|
||||
return response.json();
|
||||
}
|
||||
}
|
||||
"""
|
||||
assert context.target_code == expected_code, f"Expected:\n{expected_code}\nGot:\n{context.target_code}"
|
||||
assert js_support.validate_syntax(context.target_code) is True
|
||||
|
||||
def test_extract_static_class_method(self, js_support):
|
||||
"""Test extracting a static class method."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
|
||||
f.write("""class MathUtils {
|
||||
static add(a, b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
static multiply(a, b) {
|
||||
return a * b;
|
||||
}
|
||||
}
|
||||
""")
|
||||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
add_method = next((f for f in functions if f.name == "add"), None)
|
||||
|
||||
if add_method:
|
||||
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
|
||||
|
||||
# Full string equality check
|
||||
expected_code = """class MathUtils {
|
||||
static add(a, b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
assert context.target_code == expected_code, f"Expected:\n{expected_code}\nGot:\n{context.target_code}"
|
||||
assert js_support.validate_syntax(context.target_code) is True
|
||||
|
||||
def test_extract_class_method_without_class_jsdoc(self, js_support):
|
||||
"""Test extracting a method from a class without JSDoc."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
|
||||
f.write("""class SimpleClass {
|
||||
simpleMethod() {
|
||||
return "hello";
|
||||
}
|
||||
}
|
||||
""")
|
||||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
method = next(f for f in functions if f.name == "simpleMethod")
|
||||
|
||||
context = js_support.extract_code_context(method, file_path.parent, file_path.parent)
|
||||
|
||||
# Full string equality check
|
||||
expected_code = """class SimpleClass {
|
||||
simpleMethod() {
|
||||
return "hello";
|
||||
}
|
||||
}
|
||||
"""
|
||||
assert context.target_code == expected_code, f"Expected:\n{expected_code}\nGot:\n{context.target_code}"
|
||||
assert js_support.validate_syntax(context.target_code) is True
|
||||
|
||||
|
||||
class TestClassMethodReplacement:
|
||||
"""Tests for replacing class methods."""
|
||||
|
||||
def test_replace_class_method_preserves_class_structure(self, js_support):
|
||||
"""Test that replacing a class method preserves the class structure."""
|
||||
source = """class Calculator {
|
||||
add(a, b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
multiply(a, b) {
|
||||
return a * b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
func = FunctionInfo(
|
||||
name="add",
|
||||
file_path=Path("/test.js"),
|
||||
start_line=2,
|
||||
end_line=4,
|
||||
parents=(ParentInfo(name="Calculator", type="ClassDef"),),
|
||||
is_method=True,
|
||||
)
|
||||
new_code = """ add(a, b) {
|
||||
// Optimized bitwise addition
|
||||
return (a + b) | 0;
|
||||
}
|
||||
"""
|
||||
result = js_support.replace_function(source, func, new_code)
|
||||
|
||||
# Check class structure is preserved
|
||||
assert "class Calculator" in result
|
||||
assert "multiply(a, b)" in result
|
||||
assert "return a * b" in result
|
||||
|
||||
# Check new code is inserted
|
||||
assert "Optimized bitwise addition" in result
|
||||
assert "(a + b) | 0" in result
|
||||
|
||||
# Check result is valid JavaScript
|
||||
assert js_support.validate_syntax(result) is True
|
||||
|
||||
def test_replace_class_method_with_jsdoc(self, js_support):
|
||||
"""Test replacing a class method that has JSDoc."""
|
||||
source = """class Calculator {
|
||||
/**
|
||||
* Adds two numbers.
|
||||
*/
|
||||
add(a, b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
func = FunctionInfo(
|
||||
name="add",
|
||||
file_path=Path("/test.js"),
|
||||
start_line=5, # Method starts here
|
||||
end_line=7,
|
||||
doc_start_line=2, # JSDoc starts here
|
||||
parents=(ParentInfo(name="Calculator", type="ClassDef"),),
|
||||
is_method=True,
|
||||
)
|
||||
new_code = """ /**
|
||||
* Adds two numbers (optimized).
|
||||
*/
|
||||
add(a, b) {
|
||||
return (a + b) | 0;
|
||||
}
|
||||
"""
|
||||
result = js_support.replace_function(source, func, new_code)
|
||||
|
||||
assert "optimized" in result
|
||||
assert "(a + b) | 0" in result
|
||||
assert js_support.validate_syntax(result) is True
|
||||
|
||||
def test_replace_multiple_class_methods_sequentially(self, js_support):
|
||||
"""Test replacing multiple methods in sequence."""
|
||||
source = """class Math {
|
||||
add(a, b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
subtract(a, b) {
|
||||
return a - b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
# Replace add first
|
||||
add_func = FunctionInfo(
|
||||
name="add",
|
||||
file_path=Path("/test.js"),
|
||||
start_line=2,
|
||||
end_line=4,
|
||||
parents=(ParentInfo(name="Math", type="ClassDef"),),
|
||||
is_method=True,
|
||||
)
|
||||
source = js_support.replace_function(source, add_func, """ add(a, b) {
|
||||
return (a + b) | 0;
|
||||
}
|
||||
""")
|
||||
|
||||
assert js_support.validate_syntax(source) is True
|
||||
|
||||
# Now need to re-discover to get updated line numbers
|
||||
# In practice, codeflash handles this, but for test we just check validity
|
||||
assert "return (a + b) | 0" in source
|
||||
assert "return a - b" in source
|
||||
|
||||
def test_replace_class_method_indentation_adjustment(self, js_support):
|
||||
"""Test that indentation is correctly adjusted when replacing."""
|
||||
source = """ class Indented {
|
||||
innerMethod() {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
"""
|
||||
func = FunctionInfo(
|
||||
name="innerMethod",
|
||||
file_path=Path("/test.js"),
|
||||
start_line=2,
|
||||
end_line=4,
|
||||
parents=(ParentInfo(name="Indented", type="ClassDef"),),
|
||||
is_method=True,
|
||||
)
|
||||
# New code with no indentation
|
||||
new_code = """innerMethod() {
|
||||
return 42;
|
||||
}
|
||||
"""
|
||||
result = js_support.replace_function(source, func, new_code)
|
||||
|
||||
# Check that indentation was adjusted
|
||||
lines = result.splitlines()
|
||||
method_line = next(l for l in lines if "innerMethod" in l)
|
||||
# Should have 8 spaces (original indentation)
|
||||
assert method_line.startswith(" ")
|
||||
|
||||
assert js_support.validate_syntax(result) is True
|
||||
|
||||
|
||||
class TestClassMethodEdgeCases:
|
||||
"""Edge case tests for class method handling."""
|
||||
|
||||
def test_class_with_constructor(self, js_support):
|
||||
"""Test handling classes with constructors."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
|
||||
f.write("""class Counter {
|
||||
constructor(start = 0) {
|
||||
this.value = start;
|
||||
}
|
||||
|
||||
increment() {
|
||||
return ++this.value;
|
||||
}
|
||||
}
|
||||
""")
|
||||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
|
||||
# Should find constructor and increment
|
||||
names = {f.name for f in functions}
|
||||
assert "constructor" in names or "increment" in names
|
||||
|
||||
def test_class_with_getters_setters(self, js_support):
|
||||
"""Test handling classes with getters and setters."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
|
||||
f.write("""class Person {
|
||||
constructor(name) {
|
||||
this._name = name;
|
||||
}
|
||||
|
||||
get name() {
|
||||
return this._name;
|
||||
}
|
||||
|
||||
set name(value) {
|
||||
this._name = value;
|
||||
}
|
||||
|
||||
greet() {
|
||||
return 'Hello, ' + this._name;
|
||||
}
|
||||
}
|
||||
""")
|
||||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
|
||||
# Should find at least greet
|
||||
names = {f.name for f in functions}
|
||||
assert "greet" in names
|
||||
|
||||
def test_class_extending_another(self, js_support):
|
||||
"""Test handling classes that extend another class."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
|
||||
f.write("""class Animal {
|
||||
speak() {
|
||||
return 'sound';
|
||||
}
|
||||
}
|
||||
|
||||
class Dog extends Animal {
|
||||
speak() {
|
||||
return 'bark';
|
||||
}
|
||||
|
||||
fetch() {
|
||||
return 'ball';
|
||||
}
|
||||
}
|
||||
""")
|
||||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
|
||||
# Find Dog's fetch method
|
||||
fetch_method = next((f for f in functions if f.name == "fetch" and f.class_name == "Dog"), None)
|
||||
|
||||
if fetch_method:
|
||||
context = js_support.extract_code_context(fetch_method, file_path.parent, file_path.parent)
|
||||
|
||||
# Full string equality check
|
||||
expected_code = """class Dog {
|
||||
fetch() {
|
||||
return 'ball';
|
||||
}
|
||||
}
|
||||
"""
|
||||
assert context.target_code == expected_code, f"Expected:\n{expected_code}\nGot:\n{context.target_code}"
|
||||
assert js_support.validate_syntax(context.target_code) is True
|
||||
|
||||
def test_class_with_private_method(self, js_support):
|
||||
"""Test handling classes with private methods (ES2022+)."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
|
||||
f.write("""class SecureClass {
|
||||
#privateMethod() {
|
||||
return 'secret';
|
||||
}
|
||||
|
||||
publicMethod() {
|
||||
return this.#privateMethod();
|
||||
}
|
||||
}
|
||||
""")
|
||||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
|
||||
# Should at least find publicMethod
|
||||
names = {f.name for f in functions}
|
||||
assert "publicMethod" in names
|
||||
|
||||
def test_commonjs_class_export(self, js_support):
|
||||
"""Test handling CommonJS exported classes."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
|
||||
f.write("""class Calculator {
|
||||
add(a, b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = { Calculator };
|
||||
""")
|
||||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
add_method = next(f for f in functions if f.name == "add")
|
||||
|
||||
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
|
||||
|
||||
assert "class Calculator" in context.target_code
|
||||
assert js_support.validate_syntax(context.target_code) is True
|
||||
|
||||
def test_es_module_class_export(self, js_support):
|
||||
"""Test handling ES module exported classes."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f:
|
||||
f.write("""export class Calculator {
|
||||
add(a, b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
""")
|
||||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
|
||||
# Find the add method
|
||||
add_method = next((f for f in functions if f.name == "add"), None)
|
||||
|
||||
if add_method:
|
||||
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
|
||||
assert js_support.validate_syntax(context.target_code) is True
|
||||
|
|
|
|||
Loading…
Reference in a new issue