From 325534dbc27e6b393b70997e5675e54c420d0239 Mon Sep 17 00:00:00 2001 From: misrasaurabh1 Date: Wed, 28 Jan 2026 23:28:59 -0800 Subject: [PATCH] extract class skeleton for optimization context --- .../fibonacci_class.js | 61 ++ .../code_to_optimize_js_cjs/package-lock.json | 2 +- .../tests/fibonacci_class.test.js | 105 ++++ codeflash/languages/javascript/instrument.py | 126 +++- codeflash/languages/javascript/support.py | 88 +++ .../test_javascript_instrumentation.py | 354 +++++++++++- .../test_languages/test_javascript_support.py | 543 ++++++++++++++++++ 7 files changed, 1265 insertions(+), 14 deletions(-) create mode 100644 code_to_optimize/js/code_to_optimize_js_cjs/fibonacci_class.js create mode 100644 code_to_optimize/js/code_to_optimize_js_cjs/tests/fibonacci_class.test.js diff --git a/code_to_optimize/js/code_to_optimize_js_cjs/fibonacci_class.js b/code_to_optimize/js/code_to_optimize_js_cjs/fibonacci_class.js new file mode 100644 index 000000000..24621ee7f --- /dev/null +++ b/code_to_optimize/js/code_to_optimize_js_cjs/fibonacci_class.js @@ -0,0 +1,61 @@ +/** + * Fibonacci Calculator Class - CommonJS module + * Intentionally inefficient for optimization testing. + */ + +class FibonacciCalculator { + constructor() { + // No initialization needed + } + + /** + * Calculate the nth Fibonacci number using naive recursion. + * This is intentionally slow to demonstrate optimization potential. + * @param {number} n - The index of the Fibonacci number to calculate + * @returns {number} The nth Fibonacci number + */ + fibonacci(n) { + if (n <= 1) { + return n; + } + return this.fibonacci(n - 1) + this.fibonacci(n - 2); + } + + /** + * Check if a number is a Fibonacci number. + * @param {number} num - The number to check + * @returns {boolean} True if num is a Fibonacci number + */ + isFibonacci(num) { + // A number is Fibonacci if one of (5*n*n + 4) or (5*n*n - 4) is a perfect square + const check1 = 5 * num * num + 4; + const check2 = 5 * num * num - 4; + return this.isPerfectSquare(check1) || this.isPerfectSquare(check2); + } + + /** + * Check if a number is a perfect square. + * @param {number} n - The number to check + * @returns {boolean} True if n is a perfect square + */ + isPerfectSquare(n) { + const sqrt = Math.sqrt(n); + return sqrt === Math.floor(sqrt); + } + + /** + * Generate an array of Fibonacci numbers up to n. + * @param {number} n - The number of Fibonacci numbers to generate + * @returns {number[]} Array of Fibonacci numbers + */ + fibonacciSequence(n) { + const result = []; + for (let i = 0; i < n; i++) { + result.push(this.fibonacci(i)); + } + return result; + } +} + +// CommonJS exports +module.exports = { FibonacciCalculator }; diff --git a/code_to_optimize/js/code_to_optimize_js_cjs/package-lock.json b/code_to_optimize/js/code_to_optimize_js_cjs/package-lock.json index b81c1140b..71ef4f5c5 100644 --- a/code_to_optimize/js/code_to_optimize_js_cjs/package-lock.json +++ b/code_to_optimize/js/code_to_optimize_js_cjs/package-lock.json @@ -14,7 +14,7 @@ } }, "../../../packages/codeflash": { - "version": "0.1.0", + "version": "0.2.0", "dev": true, "hasInstallScript": true, "license": "MIT", diff --git a/code_to_optimize/js/code_to_optimize_js_cjs/tests/fibonacci_class.test.js b/code_to_optimize/js/code_to_optimize_js_cjs/tests/fibonacci_class.test.js new file mode 100644 index 000000000..8d1859991 --- /dev/null +++ b/code_to_optimize/js/code_to_optimize_js_cjs/tests/fibonacci_class.test.js @@ -0,0 +1,105 @@ +const { FibonacciCalculator } = require('../fibonacci_class'); + +describe('FibonacciCalculator', () => { + let calc; + + beforeEach(() => { + calc = new FibonacciCalculator(); + }); + + describe('fibonacci', () => { + test('returns 0 for n=0', () => { + expect(calc.fibonacci(0)).toBe(0); + }); + + test('returns 1 for n=1', () => { + expect(calc.fibonacci(1)).toBe(1); + }); + + test('returns 1 for n=2', () => { + expect(calc.fibonacci(2)).toBe(1); + }); + + test('returns 5 for n=5', () => { + expect(calc.fibonacci(5)).toBe(5); + }); + + test('returns 55 for n=10', () => { + expect(calc.fibonacci(10)).toBe(55); + }); + + test('returns 233 for n=13', () => { + expect(calc.fibonacci(13)).toBe(233); + }); + }); + + describe('isFibonacci', () => { + test('returns true for 0', () => { + expect(calc.isFibonacci(0)).toBe(true); + }); + + test('returns true for 1', () => { + expect(calc.isFibonacci(1)).toBe(true); + }); + + test('returns true for 8', () => { + expect(calc.isFibonacci(8)).toBe(true); + }); + + test('returns true for 13', () => { + expect(calc.isFibonacci(13)).toBe(true); + }); + + test('returns false for 4', () => { + expect(calc.isFibonacci(4)).toBe(false); + }); + + test('returns false for 6', () => { + expect(calc.isFibonacci(6)).toBe(false); + }); + }); + + describe('isPerfectSquare', () => { + test('returns true for 0', () => { + expect(calc.isPerfectSquare(0)).toBe(true); + }); + + test('returns true for 1', () => { + expect(calc.isPerfectSquare(1)).toBe(true); + }); + + test('returns true for 4', () => { + expect(calc.isPerfectSquare(4)).toBe(true); + }); + + test('returns true for 16', () => { + expect(calc.isPerfectSquare(16)).toBe(true); + }); + + test('returns false for 2', () => { + expect(calc.isPerfectSquare(2)).toBe(false); + }); + + test('returns false for 3', () => { + expect(calc.isPerfectSquare(3)).toBe(false); + }); + }); + + describe('fibonacciSequence', () => { + test('returns empty array for n=0', () => { + expect(calc.fibonacciSequence(0)).toEqual([]); + }); + + test('returns [0] for n=1', () => { + expect(calc.fibonacciSequence(1)).toEqual([0]); + }); + + test('returns first 5 Fibonacci numbers', () => { + expect(calc.fibonacciSequence(5)).toEqual([0, 1, 1, 2, 3]); + }); + + test('returns first 10 Fibonacci numbers', () => { + expect(calc.fibonacciSequence(10)).toEqual([0, 1, 1, 2, 3, 5, 8, 13, 21, 34]); + }); + }); +}); diff --git a/codeflash/languages/javascript/instrument.py b/codeflash/languages/javascript/instrument.py index 47ed9ea51..2cf45c7a5 100644 --- a/codeflash/languages/javascript/instrument.py +++ b/codeflash/languages/javascript/instrument.py @@ -35,6 +35,7 @@ class ExpectCallMatch: func_args: str assertion_chain: str has_trailing_semicolon: bool + object_prefix: str = "" # Object prefix like "calc." or "this." or "" @dataclass @@ -45,7 +46,8 @@ class StandaloneCallMatch: end_pos: int leading_whitespace: str func_args: str - prefix: str # Everything between whitespace and func name (e.g., "await ", "") + prefix: str # "await " or "" + object_prefix: str # Object prefix like "calc." or "this." or "" has_trailing_semicolon: bool @@ -61,6 +63,7 @@ class StandaloneCallTransformer: - func(args) -> codeflash.capturePerf('name', 'id', func, args) - const result = func(args) -> const result = codeflash.capturePerf(...) - arr.map(() => func(args)) -> arr.map(() => codeflash.capturePerf(..., func, args)) + - calc.fibonacci(n) -> codeflash.capturePerf('...', 'id', calc.fibonacci.bind(calc), n) """ @@ -69,9 +72,12 @@ class StandaloneCallTransformer: self.qualified_name = qualified_name self.capture_func = capture_func self.invocation_counter = 0 - # Pattern to match func_name( with optional leading await + # Pattern to match func_name( with optional leading await and optional object prefix + # Captures: (whitespace)(await )?(object.)*func_name( # We'll filter out expect() and codeflash. cases in the transform loop - self._call_pattern = re.compile(rf"(\s*)(await\s+)?{re.escape(func_name)}\s*\(") + self._call_pattern = re.compile( + rf"(\s*)(await\s+)?((?:\w+\.)*){re.escape(func_name)}\s*\(" + ) def transform(self, code: str) -> str: """Transform all standalone calls in the code.""" @@ -121,6 +127,42 @@ class StandaloneCallTransformer: if f"codeflash.{self.capture_func}(" in lookback[-60:]: return True + # Skip if this is a function/method definition, not a call + # Patterns to skip: + # - ClassName.prototype.funcName = function( + # - funcName = function( + # - funcName: function( + # - function funcName( + # - funcName() { (method definition in class) + near_context = lookback[-80:] if len(lookback) >= 80 else lookback + + # Skip prototype assignment: ClassName.prototype.funcName = function( + if re.search(r"\.prototype\.\w+\s*=\s*function\s*$", near_context): + return True + + # Skip function assignment: funcName = function( + if re.search(rf"{re.escape(self.func_name)}\s*=\s*function\s*$", near_context): + return True + + # Skip function declaration: function funcName( + if re.search(rf"function\s+{re.escape(self.func_name)}\s*$", near_context): + return True + + # Skip method definition in class body: funcName(params) { or async funcName(params) { + # Check by looking at what comes after the closing paren + # The match ends at the opening paren, so find the closing paren and check what follows + close_paren_pos = self._find_matching_paren(code, match.end() - 1) + if close_paren_pos != -1: + # Check if followed by { (method definition) after optional whitespace + after_close = code[close_paren_pos : close_paren_pos + 20].lstrip() + if after_close.startswith("{"): + # This is a method definition like "fibonacci(n) {" + # But we still want to capture certain patterns like arrow functions + # Check if there's no => before the { + between = code[close_paren_pos : close_paren_pos + 20].strip() + if not between.startswith("=>"): + return True + # Skip if inside expect() - look for 'expect(' with unmatched parens # Find the last 'expect(' and check if it's still open expect_search_start = max(0, start - 100) @@ -138,10 +180,28 @@ class StandaloneCallTransformer: return False + def _find_matching_paren(self, code: str, open_paren_pos: int) -> int: + """Find the position of the closing paren for the given opening paren.""" + if open_paren_pos >= len(code) or code[open_paren_pos] != "(": + return -1 + + depth = 1 + pos = open_paren_pos + 1 + + while pos < len(code) and depth > 0: + if code[pos] == "(": + depth += 1 + elif code[pos] == ")": + depth -= 1 + pos += 1 + + return pos if depth == 0 else -1 + def _parse_standalone_call(self, code: str, match: re.Match) -> StandaloneCallMatch | None: """Parse a complete standalone func(...) call.""" leading_ws = match.group(1) prefix = match.group(2) or "" # "await " or "" + object_prefix = match.group(3) or "" # Object prefix like "calc." or "" # Find the opening paren position match_text = match.group(0) @@ -169,6 +229,7 @@ class StandaloneCallTransformer: leading_whitespace=leading_ws, func_args=func_args, prefix=prefix, + object_prefix=object_prefix, has_trailing_semicolon=has_trailing_semicolon, ) @@ -212,6 +273,23 @@ class StandaloneCallTransformer: args_str = match.func_args.strip() semicolon = ";" if match.has_trailing_semicolon else "" + # Handle method calls on objects (e.g., calc.fibonacci, this.method) + if match.object_prefix: + # Remove trailing dot from object prefix for the bind call + obj = match.object_prefix.rstrip(".") + full_method = f"{obj}.{self.func_name}" + + if args_str: + return ( + f"{match.leading_whitespace}{match.prefix}codeflash.{self.capture_func}('{self.qualified_name}', " + f"'{line_id}', {full_method}.bind({obj}), {args_str}){semicolon}" + ) + return ( + f"{match.leading_whitespace}{match.prefix}codeflash.{self.capture_func}('{self.qualified_name}', " + f"'{line_id}', {full_method}.bind({obj})){semicolon}" + ) + + # Handle standalone function calls if args_str: return ( f"{match.leading_whitespace}{match.prefix}codeflash.{self.capture_func}('{self.qualified_name}', " @@ -268,8 +346,11 @@ class ExpectCallTransformer: self.capture_func = capture_func self.remove_assertions = remove_assertions self.invocation_counter = 0 - # Pattern to match start of expect(func_name( - self._expect_pattern = re.compile(rf"(\s*)expect\s*\(\s*{re.escape(func_name)}\s*\(") + # Pattern to match start of expect((object.)*func_name( + # Captures: (whitespace), (object prefix like calc. or this.) + self._expect_pattern = re.compile( + rf"(\s*)expect\s*\(\s*((?:\w+\.)*){re.escape(func_name)}\s*\(" + ) def transform(self, code: str) -> str: """Transform all expect calls in the code.""" @@ -307,6 +388,7 @@ class ExpectCallTransformer: Returns None if the pattern doesn't match expected structure. """ leading_ws = match.group(1) + object_prefix = match.group(2) or "" # Object prefix like "calc." or "" # Find the arguments of the function call (handling nested parens) args_start = match.end() @@ -341,6 +423,7 @@ class ExpectCallTransformer: func_args=func_args, assertion_chain=assertion_chain, has_trailing_semicolon=has_trailing_semicolon, + object_prefix=object_prefix, ) def _find_balanced_parens(self, code: str, open_paren_pos: int) -> tuple[str | None, int]: @@ -473,16 +556,24 @@ class ExpectCallTransformer: line_id = str(self.invocation_counter) args_str = match.func_args.strip() + # Determine the function reference to use + if match.object_prefix: + # Method call on object: calc.fibonacci -> calc.fibonacci.bind(calc) + obj = match.object_prefix.rstrip(".") + func_ref = f"{obj}.{self.func_name}.bind({obj})" + else: + func_ref = self.func_name + if self.remove_assertions: # For generated/regression tests: remove expect wrapper and assertion if args_str: return ( f"{match.leading_whitespace}codeflash.{self.capture_func}('{self.qualified_name}', " - f"'{line_id}', {self.func_name}, {args_str});" + f"'{line_id}', {func_ref}, {args_str});" ) return ( f"{match.leading_whitespace}codeflash.{self.capture_func}('{self.qualified_name}', " - f"'{line_id}', {self.func_name});" + f"'{line_id}', {func_ref});" ) # For existing tests: keep the expect wrapper @@ -490,11 +581,11 @@ class ExpectCallTransformer: if args_str: return ( f"{match.leading_whitespace}expect(codeflash.{self.capture_func}('{self.qualified_name}', " - f"'{line_id}', {self.func_name}, {args_str})){match.assertion_chain}{semicolon}" + f"'{line_id}', {func_ref}, {args_str})){match.assertion_chain}{semicolon}" ) return ( f"{match.leading_whitespace}expect(codeflash.{self.capture_func}('{self.qualified_name}', " - f"'{line_id}', {self.func_name})){match.assertion_chain}{semicolon}" + f"'{line_id}', {func_ref})){match.assertion_chain}{semicolon}" ) @@ -582,13 +673,18 @@ def inject_profiling_into_existing_js_test( def _is_function_used_in_test(code: str, func_name: str) -> bool: - """Check if a function is imported or used in the test code.""" - # Check for CommonJS require + """Check if a function is imported or used in the test code. + + This function handles both standalone functions and class methods. + For class methods, it checks if the method is called on any object + (e.g., calc.fibonacci, this.fibonacci). + """ + # Check for CommonJS require with named export require_pattern = rf"(?:const|let|var)\s+\{{\s*[^}}]*\b{re.escape(func_name)}\b[^}}]*\}}\s*=\s*require\s*\(" if re.search(require_pattern, code): return True - # Check for ES6 import + # Check for ES6 import with named export import_pattern = rf"import\s+\{{\s*[^}}]*\b{re.escape(func_name)}\b[^}}]*\}}\s+from" if re.search(import_pattern, code): return True @@ -602,6 +698,12 @@ def _is_function_used_in_test(code: str, func_name: str) -> bool: if re.search(default_import, code): return True + # Check for method calls: obj.funcName( or this.funcName( + # This handles class methods called on instances + method_call_pattern = rf"\w+\.{re.escape(func_name)}\s*\(" + if re.search(method_call_pattern, code): + return True + return False diff --git a/codeflash/languages/javascript/support.py b/codeflash/languages/javascript/support.py index b30df5196..d4742cf8f 100644 --- a/codeflash/languages/javascript/support.py +++ b/codeflash/languages/javascript/support.py @@ -321,6 +321,29 @@ class JavaScriptSupport: else: target_code = "" + # For class methods, wrap the method in its class definition + # This is necessary because method definition syntax is only valid inside a class body + if function.is_method and function.parents: + class_name = None + for parent in function.parents: + if parent.type == "ClassDef": + class_name = parent.name + break + + if class_name: + # Find the class definition in the source to get proper indentation and any class JSDoc + class_info = self._find_class_definition(source, class_name, analyzer) + if class_info: + class_jsdoc, class_indent = class_info + # Wrap the method in a minimal class definition + if class_jsdoc: + target_code = f"{class_jsdoc}\n{class_indent}class {class_name} {{\n{target_code}{class_indent}}}\n" + else: + target_code = f"{class_indent}class {class_name} {{\n{target_code}{class_indent}}}\n" + else: + # Fallback: wrap with no indentation + target_code = f"class {class_name} {{\n{target_code}}}\n" + imports = analyzer.find_imports(source) # Find helper functions called by target @@ -337,6 +360,16 @@ class JavaScriptSupport: target_code=target_code, helpers=helpers, source=source, analyzer=analyzer, imports=imports ) + # Validate that the extracted code is syntactically valid + # If not, raise an error to fail the optimization early + if target_code and not self.validate_syntax(target_code): + error_msg = ( + f"Extracted code for {function.name} is not syntactically valid JavaScript. " + f"Cannot proceed with optimization." + ) + logger.error(error_msg) + raise ValueError(error_msg) + return CodeContext( target_code=target_code, target_file=function.file_path, @@ -346,6 +379,61 @@ class JavaScriptSupport: language=Language.JAVASCRIPT, ) + def _find_class_definition( + self, source: str, class_name: str, analyzer: TreeSitterAnalyzer + ) -> tuple[str, str] | None: + """Find a class definition and extract its JSDoc comment and indentation. + + Args: + source: The source code to search. + class_name: The name of the class to find. + analyzer: TreeSitterAnalyzer for parsing. + + Returns: + Tuple of (jsdoc_comment, indentation) or None if not found. + + """ + source_bytes = source.encode("utf8") + tree = analyzer.parse(source_bytes) + + def find_class_node(node): + """Recursively find a class declaration with the given name.""" + if node.type in ("class_declaration", "class"): + name_node = node.child_by_field_name("name") + if name_node: + node_name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8") + if node_name == class_name: + return node + for child in node.children: + result = find_class_node(child) + if result: + return result + return None + + class_node = find_class_node(tree.root_node) + if not class_node: + return None + + # Get indentation from the class line + lines = source.splitlines(keepends=True) + class_line_idx = class_node.start_point[0] + if class_line_idx < len(lines): + class_line = lines[class_line_idx] + indent = len(class_line) - len(class_line.lstrip()) + indentation = " " * indent + else: + indentation = "" + + # Look for preceding JSDoc comment + jsdoc = "" + prev_sibling = class_node.prev_named_sibling + if prev_sibling and prev_sibling.type == "comment": + comment_text = source_bytes[prev_sibling.start_byte : prev_sibling.end_byte].decode("utf8") + if comment_text.strip().startswith("/**"): + jsdoc = comment_text + + return (jsdoc, indentation) + def _find_helper_functions( self, function: FunctionInfo, source: str, analyzer: TreeSitterAnalyzer, imports: list[Any], module_root: Path ) -> list[HelperFunction]: diff --git a/tests/test_languages/test_javascript_instrumentation.py b/tests/test_languages/test_javascript_instrumentation.py index 02b77e033..cc9287fb2 100644 --- a/tests/test_languages/test_javascript_instrumentation.py +++ b/tests/test_languages/test_javascript_instrumentation.py @@ -357,4 +357,356 @@ test('decrypt works', () => {{ assert fixed_code == test_code # Clean up - source_path.unlink() \ No newline at end of file + source_path.unlink() + + +class TestClassMethodInstrumentation: + """Tests for class method instrumentation.""" + + def test_instrument_method_call_on_instance(self): + """Test that method calls on instances are correctly instrumented.""" + from codeflash.languages.javascript.instrument import transform_standalone_calls + + code = """ +const calc = new Calculator(); +const result = calc.fibonacci(10); +console.log(result); +""" + transformed, counter = transform_standalone_calls( + code=code, + func_name="fibonacci", + qualified_name="Calculator.fibonacci", + capture_func="capture", + ) + + # Should transform calc.fibonacci(10) to codeflash.capture(..., calc.fibonacci.bind(calc), 10) + assert "codeflash.capture('Calculator.fibonacci'" in transformed + assert "calc.fibonacci.bind(calc)" in transformed + assert counter == 1 + + def test_instrument_expect_with_method_call(self): + """Test that expect() with method calls are correctly instrumented.""" + from codeflash.languages.javascript.instrument import transform_expect_calls + + code = """ +test('fibonacci works', () => { + const calc = new FibonacciCalculator(); + expect(calc.fibonacci(10)).toBe(55); +}); +""" + transformed, counter = transform_expect_calls( + code=code, + func_name="fibonacci", + qualified_name="FibonacciCalculator.fibonacci", + capture_func="capture", + ) + + # Should transform expect(calc.fibonacci(10)) to + # expect(codeflash.capture(..., calc.fibonacci.bind(calc), 10)) + assert "codeflash.capture('FibonacciCalculator.fibonacci'" in transformed + assert "calc.fibonacci.bind(calc)" in transformed + assert ".toBe(55)" in transformed + assert counter == 1 + + def test_instrument_expect_with_method_removes_assertion(self): + """Test that expect() with method calls are correctly instrumented with assertion removal.""" + from codeflash.languages.javascript.instrument import transform_expect_calls + + code = """ +test('fibonacci works', () => { + const calc = new FibonacciCalculator(); + expect(calc.fibonacci(10)).toBe(55); +}); +""" + transformed, counter = transform_expect_calls( + code=code, + func_name="fibonacci", + qualified_name="FibonacciCalculator.fibonacci", + capture_func="capture", + remove_assertions=True, + ) + + # Should remove expect wrapper and assertion + assert "codeflash.capture('FibonacciCalculator.fibonacci'" in transformed + assert "calc.fibonacci.bind(calc)" in transformed + assert ".toBe(55)" not in transformed # Assertion removed + assert "expect(" not in transformed # expect wrapper removed + assert counter == 1 + + def test_does_not_instrument_function_definition(self): + """Test that function definitions are NOT transformed.""" + from codeflash.languages.javascript.instrument import transform_standalone_calls + + code = """ +class FibonacciCalculator { + fibonacci(n) { + if (n <= 1) return n; + return this.fibonacci(n - 1) + this.fibonacci(n - 2); + } +} +""" + transformed, counter = transform_standalone_calls( + code=code, + func_name="fibonacci", + qualified_name="FibonacciCalculator.fibonacci", + capture_func="capture", + ) + + # The method definition should NOT be transformed + # Only the recursive calls this.fibonacci(...) should potentially be transformed + assert "fibonacci(n) {" in transformed # Method definition unchanged + assert counter >= 0 # May or may not transform the recursive calls + + def test_does_not_instrument_prototype_assignment(self): + """Test that prototype assignments are NOT transformed.""" + from codeflash.languages.javascript.instrument import transform_standalone_calls + + code = """ +FibonacciCalculator.prototype.fibonacci = function(n) { + if (n <= 1) return n; + return this.fibonacci(n - 1) + this.fibonacci(n - 2); +}; +""" + transformed, counter = transform_standalone_calls( + code=code, + func_name="fibonacci", + qualified_name="FibonacciCalculator.fibonacci", + capture_func="capture", + ) + + # The prototype assignment should NOT be transformed + # It should still have the original pattern + assert "FibonacciCalculator.prototype.fibonacci = function(n)" in transformed + + def test_instrument_multiple_method_calls(self): + """Test that multiple method calls are correctly instrumented.""" + from codeflash.languages.javascript.instrument import transform_standalone_calls + + code = """ +const calc = new Calculator(); +const a = calc.fibonacci(5); +const b = calc.fibonacci(10); +const sum = a + b; +""" + transformed, counter = transform_standalone_calls( + code=code, + func_name="fibonacci", + qualified_name="Calculator.fibonacci", + capture_func="capture", + ) + + # Should transform both calls + assert transformed.count("codeflash.capture") == 2 + assert counter == 2 + + def test_instrument_this_method_call(self): + """Test that this.method() calls are correctly instrumented.""" + from codeflash.languages.javascript.instrument import transform_standalone_calls + + code = """ +class Wrapper { + callFibonacci(n) { + return this.fibonacci(n); + } +} +""" + transformed, counter = transform_standalone_calls( + code=code, + func_name="fibonacci", + qualified_name="Wrapper.fibonacci", + capture_func="capture", + ) + + # Should transform this.fibonacci(n) + assert "codeflash.capture('Wrapper.fibonacci'" in transformed + assert "this.fibonacci.bind(this)" in transformed + assert counter == 1 + + def test_full_instrumentation_produces_valid_syntax(self): + """Test that full instrumentation produces syntactically valid JavaScript.""" + from codeflash.languages.javascript.instrument import _instrument_js_test_code + from codeflash.languages import get_language_support + from codeflash.languages.base import Language + + js_support = get_language_support(Language.JAVASCRIPT) + + test_code = """ +const { FibonacciCalculator } = require('../fibonacci_class'); + +describe('FibonacciCalculator', () => { + let calc; + + beforeEach(() => { + calc = new FibonacciCalculator(); + }); + + test('fibonacci returns correct values', () => { + expect(calc.fibonacci(0)).toBe(0); + expect(calc.fibonacci(1)).toBe(1); + expect(calc.fibonacci(10)).toBe(55); + }); + + test('standalone call', () => { + const result = calc.fibonacci(5); + expect(result).toBe(5); + }); +}); +""" + instrumented = _instrument_js_test_code( + code=test_code, + func_name="fibonacci", + test_file_path="test.js", + mode="behavior", + qualified_name="FibonacciCalculator.fibonacci", + ) + + # Check that codeflash import was added + assert "codeflash" in instrumented + + # Check that method calls were instrumented + assert "codeflash.capture" in instrumented + + # Check that the instrumented code is valid JavaScript + assert js_support.validate_syntax(instrumented) is True, f"Invalid syntax:\n{instrumented}" + + def test_instrumentation_preserves_test_structure(self): + """Test that instrumentation preserves the test structure.""" + from codeflash.languages.javascript.instrument import _instrument_js_test_code + + test_code = """ +const { Calculator } = require('../calculator'); + +describe('Calculator', () => { + test('add works', () => { + const calc = new Calculator(); + expect(calc.add(1, 2)).toBe(3); + }); +}); +""" + instrumented = _instrument_js_test_code( + code=test_code, + func_name="add", + test_file_path="test.js", + mode="behavior", + qualified_name="Calculator.add", + ) + + # describe and test structure should be preserved + assert "describe('Calculator'" in instrumented + assert "test('add works'" in instrumented + assert "beforeEach" in instrumented or "beforeEach" not in test_code # Only if it was there + + # Method call should be instrumented + assert "codeflash.capture('Calculator.add'" in instrumented + assert "calc.add.bind(calc)" in instrumented + + def test_instrumentation_with_async_methods(self): + """Test instrumentation with async method calls.""" + from codeflash.languages.javascript.instrument import transform_standalone_calls + + code = """ +const api = new ApiClient(); +const data = await api.fetchData('http://example.com'); +console.log(data); +""" + transformed, counter = transform_standalone_calls( + code=code, + func_name="fetchData", + qualified_name="ApiClient.fetchData", + capture_func="capture", + ) + + # Should preserve await + assert "await codeflash.capture" in transformed + assert "api.fetchData.bind(api)" in transformed + assert counter == 1 + + +class TestInstrumentationFullStringEquality: + """Tests with full string equality for precise verification.""" + + def test_standalone_method_call_exact_output(self): + """Test exact output of standalone method call instrumentation.""" + from codeflash.languages.javascript.instrument import transform_standalone_calls + + code = " calc.fibonacci(10);" + + transformed, counter = transform_standalone_calls( + code=code, + func_name="fibonacci", + qualified_name="Calculator.fibonacci", + capture_func="capture", + ) + + expected = " codeflash.capture('Calculator.fibonacci', '1', calc.fibonacci.bind(calc), 10);" + assert transformed == expected, f"Expected:\n{expected}\nGot:\n{transformed}" + assert counter == 1 + + def test_expect_method_call_exact_output(self): + """Test exact output of expect() method call instrumentation.""" + from codeflash.languages.javascript.instrument import transform_expect_calls + + code = " expect(calc.fibonacci(10)).toBe(55);" + + transformed, counter = transform_expect_calls( + code=code, + func_name="fibonacci", + qualified_name="Calculator.fibonacci", + capture_func="capture", + ) + + expected = " expect(codeflash.capture('Calculator.fibonacci', '1', calc.fibonacci.bind(calc), 10)).toBe(55);" + assert transformed == expected, f"Expected:\n{expected}\nGot:\n{transformed}" + assert counter == 1 + + def test_expect_method_call_remove_assertions_exact_output(self): + """Test exact output when removing assertions.""" + from codeflash.languages.javascript.instrument import transform_expect_calls + + code = " expect(calc.fibonacci(10)).toBe(55);" + + transformed, counter = transform_expect_calls( + code=code, + func_name="fibonacci", + qualified_name="Calculator.fibonacci", + capture_func="capture", + remove_assertions=True, + ) + + expected = " codeflash.capture('Calculator.fibonacci', '1', calc.fibonacci.bind(calc), 10);" + assert transformed == expected, f"Expected:\n{expected}\nGot:\n{transformed}" + assert counter == 1 + + def test_standalone_function_call_no_object_prefix(self): + """Test that standalone function calls (no object) work correctly.""" + from codeflash.languages.javascript.instrument import transform_standalone_calls + + code = " fibonacci(10);" + + transformed, counter = transform_standalone_calls( + code=code, + func_name="fibonacci", + qualified_name="fibonacci", + capture_func="capture", + ) + + expected = " codeflash.capture('fibonacci', '1', fibonacci, 10);" + assert transformed == expected, f"Expected:\n{expected}\nGot:\n{transformed}" + assert counter == 1 + + def test_this_method_call_exact_output(self): + """Test exact output for this.method() call.""" + from codeflash.languages.javascript.instrument import transform_standalone_calls + + code = " return this.fibonacci(n - 1);" + + transformed, counter = transform_standalone_calls( + code=code, + func_name="fibonacci", + qualified_name="Class.fibonacci", + capture_func="capture", + ) + + expected = " return codeflash.capture('Class.fibonacci', '1', this.fibonacci.bind(this), n - 1);" + assert transformed == expected, f"Expected:\n{expected}\nGot:\n{transformed}" + assert counter == 1 \ No newline at end of file diff --git a/tests/test_languages/test_javascript_support.py b/tests/test_languages/test_javascript_support.py index 3014fc247..bcca59443 100644 --- a/tests/test_languages/test_javascript_support.py +++ b/tests/test_languages/test_javascript_support.py @@ -695,3 +695,546 @@ describe('Math functions', () => { assert "Math functions" in test_names assert "add returns sum" in test_names assert "handles negative numbers" in test_names + + +class TestClassMethodExtraction: + """Tests for class method extraction and code context. + + These tests use full string equality to verify exact extraction output. + """ + + def test_extract_class_method_wraps_in_class(self, js_support): + """Test that extracting a class method wraps it in a class definition.""" + with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: + f.write("""class Calculator { + add(a, b) { + return a + b; + } + + multiply(a, b) { + return a * b; + } +} +""") + f.flush() + file_path = Path(f.name) + + # Discover the method + functions = js_support.discover_functions(file_path) + add_method = next(f for f in functions if f.name == "add") + + # Extract code context + context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent) + + # Full string equality check for exact extraction output + expected_code = """class Calculator { + add(a, b) { + return a + b; + } +} +""" + assert context.target_code == expected_code, f"Expected:\n{expected_code}\nGot:\n{context.target_code}" + assert js_support.validate_syntax(context.target_code) is True + + def test_extract_class_method_with_jsdoc(self, js_support): + """Test extracting a class method with JSDoc comments.""" + with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: + f.write("""/** + * A simple calculator class. + */ +class Calculator { + /** + * Adds two numbers. + * @param {number} a - First number + * @param {number} b - Second number + * @returns {number} The sum + */ + add(a, b) { + return a + b; + } +} +""") + f.flush() + file_path = Path(f.name) + + functions = js_support.discover_functions(file_path) + add_method = next(f for f in functions if f.name == "add") + + context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent) + + # Full string equality check - includes class JSDoc, class definition, method JSDoc, and method + expected_code = """/** + * A simple calculator class. + */ +class Calculator { + /** + * Adds two numbers. + * @param {number} a - First number + * @param {number} b - Second number + * @returns {number} The sum + */ + add(a, b) { + return a + b; + } +} +""" + assert context.target_code == expected_code, f"Expected:\n{expected_code}\nGot:\n{context.target_code}" + assert js_support.validate_syntax(context.target_code) is True + + def test_extract_class_method_syntax_valid(self, js_support): + """Test that extracted class method code is always syntactically valid.""" + with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: + f.write("""class FibonacciCalculator { + fibonacci(n) { + if (n <= 1) { + return n; + } + return this.fibonacci(n - 1) + this.fibonacci(n - 2); + } +} +""") + f.flush() + file_path = Path(f.name) + + functions = js_support.discover_functions(file_path) + fib_method = next(f for f in functions if f.name == "fibonacci") + + context = js_support.extract_code_context(fib_method, file_path.parent, file_path.parent) + + # Full string equality check + expected_code = """class FibonacciCalculator { + fibonacci(n) { + if (n <= 1) { + return n; + } + return this.fibonacci(n - 1) + this.fibonacci(n - 2); + } +} +""" + assert context.target_code == expected_code, f"Expected:\n{expected_code}\nGot:\n{context.target_code}" + assert js_support.validate_syntax(context.target_code) is True + + def test_extract_nested_class_method(self, js_support): + """Test extracting a method from a nested class structure.""" + with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: + f.write("""class Outer { + createInner() { + return class Inner { + getValue() { + return 42; + } + }; + } + + add(a, b) { + return a + b; + } +} +""") + f.flush() + file_path = Path(f.name) + + functions = js_support.discover_functions(file_path) + add_method = next((f for f in functions if f.name == "add"), None) + + if add_method: + context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent) + + # Full string equality check + expected_code = """class Outer { + add(a, b) { + return a + b; + } +} +""" + assert context.target_code == expected_code, f"Expected:\n{expected_code}\nGot:\n{context.target_code}" + assert js_support.validate_syntax(context.target_code) is True + + def test_extract_async_class_method(self, js_support): + """Test extracting an async class method.""" + with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: + f.write("""class ApiClient { + async fetchData(url) { + const response = await fetch(url); + return response.json(); + } +} +""") + f.flush() + file_path = Path(f.name) + + functions = js_support.discover_functions(file_path) + fetch_method = next(f for f in functions if f.name == "fetchData") + + context = js_support.extract_code_context(fetch_method, file_path.parent, file_path.parent) + + # Full string equality check + expected_code = """class ApiClient { + async fetchData(url) { + const response = await fetch(url); + return response.json(); + } +} +""" + assert context.target_code == expected_code, f"Expected:\n{expected_code}\nGot:\n{context.target_code}" + assert js_support.validate_syntax(context.target_code) is True + + def test_extract_static_class_method(self, js_support): + """Test extracting a static class method.""" + with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: + f.write("""class MathUtils { + static add(a, b) { + return a + b; + } + + static multiply(a, b) { + return a * b; + } +} +""") + f.flush() + file_path = Path(f.name) + + functions = js_support.discover_functions(file_path) + add_method = next((f for f in functions if f.name == "add"), None) + + if add_method: + context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent) + + # Full string equality check + expected_code = """class MathUtils { + static add(a, b) { + return a + b; + } +} +""" + assert context.target_code == expected_code, f"Expected:\n{expected_code}\nGot:\n{context.target_code}" + assert js_support.validate_syntax(context.target_code) is True + + def test_extract_class_method_without_class_jsdoc(self, js_support): + """Test extracting a method from a class without JSDoc.""" + with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: + f.write("""class SimpleClass { + simpleMethod() { + return "hello"; + } +} +""") + f.flush() + file_path = Path(f.name) + + functions = js_support.discover_functions(file_path) + method = next(f for f in functions if f.name == "simpleMethod") + + context = js_support.extract_code_context(method, file_path.parent, file_path.parent) + + # Full string equality check + expected_code = """class SimpleClass { + simpleMethod() { + return "hello"; + } +} +""" + assert context.target_code == expected_code, f"Expected:\n{expected_code}\nGot:\n{context.target_code}" + assert js_support.validate_syntax(context.target_code) is True + + +class TestClassMethodReplacement: + """Tests for replacing class methods.""" + + def test_replace_class_method_preserves_class_structure(self, js_support): + """Test that replacing a class method preserves the class structure.""" + source = """class Calculator { + add(a, b) { + return a + b; + } + + multiply(a, b) { + return a * b; + } +} +""" + func = FunctionInfo( + name="add", + file_path=Path("/test.js"), + start_line=2, + end_line=4, + parents=(ParentInfo(name="Calculator", type="ClassDef"),), + is_method=True, + ) + new_code = """ add(a, b) { + // Optimized bitwise addition + return (a + b) | 0; + } +""" + result = js_support.replace_function(source, func, new_code) + + # Check class structure is preserved + assert "class Calculator" in result + assert "multiply(a, b)" in result + assert "return a * b" in result + + # Check new code is inserted + assert "Optimized bitwise addition" in result + assert "(a + b) | 0" in result + + # Check result is valid JavaScript + assert js_support.validate_syntax(result) is True + + def test_replace_class_method_with_jsdoc(self, js_support): + """Test replacing a class method that has JSDoc.""" + source = """class Calculator { + /** + * Adds two numbers. + */ + add(a, b) { + return a + b; + } +} +""" + func = FunctionInfo( + name="add", + file_path=Path("/test.js"), + start_line=5, # Method starts here + end_line=7, + doc_start_line=2, # JSDoc starts here + parents=(ParentInfo(name="Calculator", type="ClassDef"),), + is_method=True, + ) + new_code = """ /** + * Adds two numbers (optimized). + */ + add(a, b) { + return (a + b) | 0; + } +""" + result = js_support.replace_function(source, func, new_code) + + assert "optimized" in result + assert "(a + b) | 0" in result + assert js_support.validate_syntax(result) is True + + def test_replace_multiple_class_methods_sequentially(self, js_support): + """Test replacing multiple methods in sequence.""" + source = """class Math { + add(a, b) { + return a + b; + } + + subtract(a, b) { + return a - b; + } +} +""" + # Replace add first + add_func = FunctionInfo( + name="add", + file_path=Path("/test.js"), + start_line=2, + end_line=4, + parents=(ParentInfo(name="Math", type="ClassDef"),), + is_method=True, + ) + source = js_support.replace_function(source, add_func, """ add(a, b) { + return (a + b) | 0; + } +""") + + assert js_support.validate_syntax(source) is True + + # Now need to re-discover to get updated line numbers + # In practice, codeflash handles this, but for test we just check validity + assert "return (a + b) | 0" in source + assert "return a - b" in source + + def test_replace_class_method_indentation_adjustment(self, js_support): + """Test that indentation is correctly adjusted when replacing.""" + source = """ class Indented { + innerMethod() { + return 1; + } + } +""" + func = FunctionInfo( + name="innerMethod", + file_path=Path("/test.js"), + start_line=2, + end_line=4, + parents=(ParentInfo(name="Indented", type="ClassDef"),), + is_method=True, + ) + # New code with no indentation + new_code = """innerMethod() { + return 42; +} +""" + result = js_support.replace_function(source, func, new_code) + + # Check that indentation was adjusted + lines = result.splitlines() + method_line = next(l for l in lines if "innerMethod" in l) + # Should have 8 spaces (original indentation) + assert method_line.startswith(" ") + + assert js_support.validate_syntax(result) is True + + +class TestClassMethodEdgeCases: + """Edge case tests for class method handling.""" + + def test_class_with_constructor(self, js_support): + """Test handling classes with constructors.""" + with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: + f.write("""class Counter { + constructor(start = 0) { + this.value = start; + } + + increment() { + return ++this.value; + } +} +""") + f.flush() + file_path = Path(f.name) + + functions = js_support.discover_functions(file_path) + + # Should find constructor and increment + names = {f.name for f in functions} + assert "constructor" in names or "increment" in names + + def test_class_with_getters_setters(self, js_support): + """Test handling classes with getters and setters.""" + with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: + f.write("""class Person { + constructor(name) { + this._name = name; + } + + get name() { + return this._name; + } + + set name(value) { + this._name = value; + } + + greet() { + return 'Hello, ' + this._name; + } +} +""") + f.flush() + file_path = Path(f.name) + + functions = js_support.discover_functions(file_path) + + # Should find at least greet + names = {f.name for f in functions} + assert "greet" in names + + def test_class_extending_another(self, js_support): + """Test handling classes that extend another class.""" + with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: + f.write("""class Animal { + speak() { + return 'sound'; + } +} + +class Dog extends Animal { + speak() { + return 'bark'; + } + + fetch() { + return 'ball'; + } +} +""") + f.flush() + file_path = Path(f.name) + + functions = js_support.discover_functions(file_path) + + # Find Dog's fetch method + fetch_method = next((f for f in functions if f.name == "fetch" and f.class_name == "Dog"), None) + + if fetch_method: + context = js_support.extract_code_context(fetch_method, file_path.parent, file_path.parent) + + # Full string equality check + expected_code = """class Dog { + fetch() { + return 'ball'; + } +} +""" + assert context.target_code == expected_code, f"Expected:\n{expected_code}\nGot:\n{context.target_code}" + assert js_support.validate_syntax(context.target_code) is True + + def test_class_with_private_method(self, js_support): + """Test handling classes with private methods (ES2022+).""" + with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: + f.write("""class SecureClass { + #privateMethod() { + return 'secret'; + } + + publicMethod() { + return this.#privateMethod(); + } +} +""") + f.flush() + file_path = Path(f.name) + + functions = js_support.discover_functions(file_path) + + # Should at least find publicMethod + names = {f.name for f in functions} + assert "publicMethod" in names + + def test_commonjs_class_export(self, js_support): + """Test handling CommonJS exported classes.""" + with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: + f.write("""class Calculator { + add(a, b) { + return a + b; + } +} + +module.exports = { Calculator }; +""") + f.flush() + file_path = Path(f.name) + + functions = js_support.discover_functions(file_path) + add_method = next(f for f in functions if f.name == "add") + + context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent) + + assert "class Calculator" in context.target_code + assert js_support.validate_syntax(context.target_code) is True + + def test_es_module_class_export(self, js_support): + """Test handling ES module exported classes.""" + with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: + f.write("""export class Calculator { + add(a, b) { + return a + b; + } +} +""") + f.flush() + file_path = Path(f.name) + + functions = js_support.discover_functions(file_path) + + # Find the add method + add_method = next((f for f in functions if f.name == "add"), None) + + if add_method: + context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent) + assert js_support.validate_syntax(context.target_code) is True