diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index bed07a7f4..abbcc5119 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -46,6 +46,8 @@ class AiServiceClient: self.llm_call_counter = count(1) self.is_local = self.base_url == "http://localhost:8000" self.timeout: float | None = 300 if self.is_local else 90 + # React components are larger (300+ lines of JSX) and need more LLM processing time + self.react_timeout: float | None = 300 if self.is_local else 180 def get_next_sequence(self) -> int: """Get the next LLM call sequence number.""" @@ -203,7 +205,8 @@ class AiServiceClient: logger.debug(f"Sending optimize request: trace_id={trace_id}, n_candidates={payload['n_candidates']}") try: - response = self.make_ai_service_request("/optimize", payload=payload, timeout=self.timeout) + timeout = self.react_timeout if is_react_component else self.timeout + response = self.make_ai_service_request("/optimize", payload=payload, timeout=timeout) except requests.exceptions.RequestException as e: logger.exception(f"Error generating optimized candidates: {e}") ph("cli-optimize-error-caught", {"error": str(e)}) @@ -806,7 +809,8 @@ class AiServiceClient: # DEBUG: Print payload language field logger.debug(f"Sending testgen request with language='{payload['language']}', framework='{test_framework}'") try: - response = self.make_ai_service_request("/testgen", payload=payload, timeout=self.timeout) + timeout = self.react_timeout if is_react_component else self.timeout + response = self.make_ai_service_request("/testgen", payload=payload, timeout=timeout) except requests.exceptions.RequestException as e: logger.exception(f"Error generating tests: {e}") ph("cli-testgen-error-caught", {"error": str(e)}) diff --git a/codeflash/languages/javascript/comparator.py b/codeflash/languages/javascript/comparator.py index 05f34f839..3a1afa611 100644 --- a/codeflash/languages/javascript/comparator.py +++ b/codeflash/languages/javascript/comparator.py 
@@ -143,6 +143,8 @@ def compare_test_results( scope = TestDiffScope.STDOUT elif scope_str == "did_pass": scope = TestDiffScope.DID_PASS + elif scope_str == "dom_snapshot": + scope = TestDiffScope.DOM_SNAPSHOT test_info = diff.get("test_info", {}) # Build a test identifier string for JavaScript tests diff --git a/codeflash/languages/javascript/edit_tests.py b/codeflash/languages/javascript/edit_tests.py index 00ba04f9c..2be03f76e 100644 --- a/codeflash/languages/javascript/edit_tests.py +++ b/codeflash/languages/javascript/edit_tests.py @@ -210,7 +210,10 @@ def normalize_codeflash_imports(source: str) -> str: # Replace CommonJS require source = _CODEFLASH_REQUIRE_PATTERN.sub(r"\1 \2 = require('codeflash')", source) # Replace ES module import - return _CODEFLASH_IMPORT_PATTERN.sub(r"import \1 from 'codeflash'", source) + source = _CODEFLASH_IMPORT_PATTERN.sub(r"import \1 from 'codeflash'", source) + # Fix outdated @testing-library/jest-dom import paths (v6+ removed /extend-expect subpath) + source = source.replace("@testing-library/jest-dom/extend-expect", "@testing-library/jest-dom") + return source # Author: ali diff --git a/codeflash/languages/javascript/frameworks/react/profiler.py b/codeflash/languages/javascript/frameworks/react/profiler.py index 89bed13ec..d734ec55c 100644 --- a/codeflash/languages/javascript/frameworks/react/profiler.py +++ b/codeflash/languages/javascript/frameworks/react/profiler.py @@ -72,42 +72,37 @@ def instrument_component_with_profiler(source: str, component_name: str, analyze if computed is not None: replacements.append(computed) + # Work entirely in bytes to avoid byte-offset vs char-index mismatches + # (tree-sitter returns byte offsets; non-ASCII chars cause divergence from Python string indices) if replacements: - # Reconstruct result in a single pass - parts: list[str] = [] + byte_parts: list[bytes] = [] prev = 0 for start, end, wrapped in replacements: - # Use original source slices (string indices expected by original logic) - 
parts.append(source[prev:start]) - parts.append(wrapped) + byte_parts.append(source_bytes[prev:start]) + byte_parts.append(wrapped.encode("utf-8")) prev = end - parts.append(source[prev:]) - result = "".join(parts) + byte_parts.append(source_bytes[prev:]) + result_bytes = b"".join(byte_parts) else: - result = source - - # Add render counter code at the top (after imports) using the already-parsed tree + result_bytes = source_bytes # Add render counter code at the top (after imports) counter_code = generate_render_counter_code(component_name) - # Inline logic similar to _insert_after_imports but reuse existing `tree` to avoid re-parsing last_import_end = 0 for child in tree.root_node.children: if child.type == "import_statement": last_import_end = child.end_byte insert_pos = last_import_end - while insert_pos < len(result) and result[insert_pos] != "\n": + while insert_pos < len(result_bytes) and result_bytes[insert_pos : insert_pos + 1] != b"\n": insert_pos += 1 - if insert_pos < len(result): + if insert_pos < len(result_bytes): insert_pos += 1 # skip the newline - result = result[:insert_pos] + "\n" + counter_code + "\n\n" + result[insert_pos:] + counter_bytes = ("\n" + counter_code + "\n\n").encode("utf-8") + result = (result_bytes[:insert_pos] + counter_bytes + result_bytes[insert_pos:]).decode("utf-8") - # Ensure React is imported - - # Ensure React is imported return _ensure_react_import(result) @@ -317,7 +312,8 @@ def _wrap_return_with_profiler(source: str, return_node: Node, profiler_id: str, f"" ) - return source[:jsx_start] + wrapped + source[jsx_end:] + # Use byte-level splicing to avoid byte-offset vs char-index mismatches + return (source_bytes[:jsx_start] + wrapped.encode("utf-8") + source_bytes[jsx_end:]).decode("utf-8") def _insert_after_imports(source: str, code: str, analyzer: TreeSitterAnalyzer) -> str: @@ -421,7 +417,7 @@ let _codeflash_render_count_{safe_name} = 0; if (typeof beforeEach !== 'undefined') {{ beforeEach(() => {{ 
_codeflash_render_count_{safe_name} = 0; }}); }} -function _codeflashOnRender_{safe_name}(id, phase, actualDuration, baseDuration) {{ +function _codeflashOnRender_{safe_name}(id: any, phase: any, actualDuration: any, baseDuration: any) {{ _codeflash_render_count_{safe_name}++; console.log(`!######{marker_prefix}:${{id}}:${{phase}}:${{actualDuration}}:${{baseDuration}}:${{_codeflash_render_count_{safe_name}}}######!`); }}""" diff --git a/codeflash/languages/javascript/frameworks/react/testgen.py b/codeflash/languages/javascript/frameworks/react/testgen.py index 9359068d9..66b7e19a1 100644 --- a/codeflash/languages/javascript/frameworks/react/testgen.py +++ b/codeflash/languages/javascript/frameworks/react/testgen.py @@ -98,6 +98,15 @@ def post_process_react_tests(test_source: str, component_info: ReactComponentInf """ result = test_source + # Fix outdated @testing-library/jest-dom import paths (v6+ removed /extend-expect subpath) + result = result.replace("@testing-library/jest-dom/extend-expect", "@testing-library/jest-dom") + # Also fix require() variant + result = re.sub( + r"""require\s*\(\s*['"]@testing-library/jest-dom/extend-expect['"]\s*\)""", + "require('@testing-library/jest-dom')", + result, + ) + # Ensure testing-library import if "@testing-library/react" not in result: result = "import { render, screen, act } from '@testing-library/react';\n" + result @@ -168,14 +177,14 @@ def post_process_react_tests(test_source: str, component_info: ReactComponentInf # This gives per-interaction A/B signal without the LLM needing to know about it. result = inject_interaction_markers(result) - # Warn if no tests contain interaction calls — mount-phase only markers are - # not useful for measuring optimization effectiveness. + # If no tests contain interaction calls, auto-inject a rerender fallback so + # that EVERY React perf test produces at least one update-phase marker. 
# Matches one complete single-line interaction statement: fireEvent.*(...), userEvent.*(...)
# or rerender(...) (optionally awaited, optionally accessed as a member, e.g. view.rerender).
# Group 1 is the indentation, group 2 the call text. The greedy ``.*`` lets the argument
# list contain nested parentheses such as getByText('Add').
_SNAPSHOT_INTERACTION_PATTERN = re.compile(
    r"^(\s*)((?:await\s+)?(?:fireEvent\.\w+|userEvent\.\w+|(?:\w+\.)?rerender)\s*\(.*\))\s*;?\s*$"
)


def inject_dom_snapshot_calls(test_source: str) -> str:
    """Inject codeflash.snapshotDOM() calls after each user interaction in behavior mode.

    Only active when `captureRender` (not `captureRenderPerf`) is present,
    meaning the test is running in behavioral verification mode.

    After each fireEvent.*, userEvent.*, or rerender() call that fits on a single
    line, inserts:
        codeflash.snapshotDOM('after_{label}_{n}');
    on the next line with matching indentation. ``n`` counts occurrences of the
    same label across the whole file, so every injected label is unique.

    Args:
        test_source: Instrumented JavaScript/TypeScript test source.

    Returns:
        The source with snapshot calls added, or ``test_source`` unchanged when the
        file is not a behavior-mode React test.
    """
    # "captureRenderPerf" contains "captureRender", so the perf check must be explicit.
    if "captureRender" not in test_source or "captureRenderPerf" in test_source:
        return test_source

    interaction_counter: dict[str, int] = {}
    new_lines: list[str] = []

    for line in test_source.split("\n"):
        new_lines.append(line)
        m = _SNAPSHOT_INTERACTION_PATTERN.match(line)
        if m is None:
            continue
        indent, call_text = m.group(1), m.group(2)
        label = _extract_interaction_label(call_text)
        if label == "interaction":
            # The label helper found no event name, which is what happens for rerender().
            label = "rerender"
        interaction_counter[label] = interaction_counter.get(label, 0) + 1
        unique_label = f"{label}_{interaction_counter[label]}"
        new_lines.append(f"{indent}codeflash.snapshotDOM('after_{unique_label}');")

    return "\n".join(new_lines)


# Legacy single-regex props matcher. It stops at the first ``>`` or ``/`` and therefore
# cannot handle arrow functions or slashes inside props; it is kept only so that any
# existing reference to the name keeps working. _extract_render_props no longer uses it.
_RENDER_CALL_PROPS_PATTERN = re.compile(
    r"render\s*\(\s*<(\w+)\s+([^/]*?)\s*/?\s*>",
)

# Matches the start of ``render(<Component`` and captures the component name. The
# lookahead rejects member expressions such as ``<Foo.Bar`` being read as ``Foo``.
_RENDER_CALL_OPEN_PATTERN = re.compile(r"render\s*\(\s*<(\w+)(?=[\s/>])")


def _scan_jsx_attributes(source: str, start: int) -> str | None:
    """Return the raw attribute text of a JSX opening tag, scanning from *start*.

    Walks forward to the ``>`` that closes the opening tag while skipping over
    ``{...}`` expressions and quoted strings, so characters such as ``>`` in
    ``() => x`` or ``/`` in ``"/home"`` do not end the scan early.

    Returns:
        The text between *start* and the closing ``>`` (a trailing ``/`` of a
        self-closing tag is still included), or None if the tag is never closed
        or the braces are unbalanced.
    """
    depth = 0
    quote = ""
    pos = start
    end = len(source)
    while pos < end:
        ch = source[pos]
        if quote:
            # Backslash escapes only exist in JS strings (inside braces), not in JSX strings.
            if ch == "\\" and depth > 0:
                pos += 2
                continue
            if ch == quote:
                quote = ""
        elif ch in "\"'`":
            quote = ch
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth < 0:
                return None
        elif ch == ">" and depth == 0:
            return source[start:pos]
        pos += 1
    return None


def _extract_render_props(test_source: str, component_name: str) -> str | None:
    """Extract the props expression from the first render() call of *component_name*.

    Args:
        test_source: Test source to search.
        component_name: JSX tag name whose props should be extracted.

    Returns:
        The props text (e.g. ``items={items} onPick={(x) => pick(x)}``) of the first
        matching render call that has props, or None when there is none.
    """
    for m in _RENDER_CALL_OPEN_PATTERN.finditer(test_source):
        if m.group(1) != component_name:
            continue
        attributes = _scan_jsx_attributes(test_source, m.end())
        if attributes is None:
            continue
        props_text = attributes.strip()
        if props_text.endswith("/"):
            # Drop the slash of a self-closing tag.
            props_text = props_text[:-1].rstrip()
        if props_text:
            return props_text
    return None


def _inject_rerender_fallback(test_source: str, component_name: str) -> str:
    """Inject a rerender efficiency test block when the test has no interactions.

    This ensures every React perf test produces at least one update-phase marker.
    The block calls ``render`` and then ``rerender`` ten times with the same JSX.
    ``rerender`` comes from the value returned by ``render``, so no extra import
    is added.

    Args:
        test_source: Generated test source.
        component_name: Name of the component to render in the fallback block.

    Returns:
        ``test_source`` with the auto-generated ``describe`` block appended.
    """
    # Reuse the props of an existing render call so required props are supplied.
    # NOTE(review): the extracted props may reference variables that are local to the
    # test they were copied from; the appended block does not redeclare them - confirm
    # this is acceptable for the generated tests. Children of a non-self-closing tag
    # are not carried over.
    props_expr = _extract_render_props(test_source, component_name)
    if props_expr:
        jsx_open = f"<{component_name} {props_expr} />"
    else:
        jsx_open = f"<{component_name} />"

    rerender_block = f"""
describe('{component_name} rerender efficiency (auto-generated)', () => {{
  it('should handle same-props rerenders', () => {{
    const {{ rerender }} = render({jsx_open});
    for (let i = 0; i < 10; i++) {{
      rerender({jsx_open});
    }}
  }});
}});
"""
    return test_source.rstrip() + "\n" + rerender_block
@staticmethod
def _strip_imports(source: str, analyzer: TreeSitterAnalyzer) -> str:
    """Strip import/require statements from source code.

    Used to remove duplicate imports from optimizer output before inserting
    into the original file (which already has its own imports).

    Only top-level statements are considered. A ``const``/``let``/``var``
    declaration is removed only when every declarator is initialised directly
    from a ``require(...)`` call (optionally followed by member access or a
    further call, e.g. ``require('x').y`` or ``require('debug')('app')``).
    Declarations that merely contain ``require(`` somewhere inside a function
    body are kept, so the optimized function itself is never stripped.

    Args:
        source: JavaScript/TypeScript source produced by the optimizer.
        analyzer: Parser used to build the syntax tree of ``source``.

    Returns:
        ``source`` without its top-level import statements and require
        declarations, with leading blank lines removed. Returned unchanged
        when there is nothing to strip.
    """
    source_bytes = source.encode("utf8")
    tree = analyzer.parse(source_bytes)

    # Node types that may wrap the require() call without changing what it is,
    # e.g. `(require('x'))` or TypeScript's `require('x') as Foo`.
    transparent_wrappers = {
        "parenthesized_expression",
        "as_expression",
        "satisfies_expression",
        "non_null_expression",
    }

    def is_require_initializer(node) -> bool:
        """Return True when *node* is a require() call, possibly chained or wrapped."""
        while node is not None:
            kind = node.type
            if kind == "call_expression":
                callee = node.child_by_field_name("function")
                if (
                    callee is not None
                    and callee.type == "identifier"
                    and source_bytes[callee.start_byte : callee.end_byte] == b"require"
                ):
                    return True
                node = callee
            elif kind == "member_expression":
                node = node.child_by_field_name("object")
            elif kind in transparent_wrappers:
                inner = node.named_children
                node = inner[0] if inner else None
            else:
                return False
        return False

    def is_require_declaration(node) -> bool:
        """Return True when every declarator of *node* is initialised from require()."""
        declarators = [c for c in node.named_children if c.type == "variable_declarator"]
        return bool(declarators) and all(
            is_require_initializer(d.child_by_field_name("value")) for d in declarators
        )

    # Collect byte ranges of import nodes to remove
    ranges_to_remove: list[tuple[int, int]] = []
    for child in tree.root_node.children:
        if child.type == "import_statement":
            ranges_to_remove.append((child.start_byte, child.end_byte))
        # CJS require: const x = require('...')
        elif child.type in ("lexical_declaration", "variable_declaration") and is_require_declaration(child):
            ranges_to_remove.append((child.start_byte, child.end_byte))

    if not ranges_to_remove:
        return source

    # Build result in bytes (tree-sitter offsets are byte offsets), skipping removed ranges
    parts: list[bytes] = []
    prev_end = 0
    for start, end in sorted(ranges_to_remove):
        parts.append(source_bytes[prev_end:start])
        # Also consume the line terminator that followed the removed statement
        if source_bytes[end : end + 2] == b"\r\n":
            end += 2
        elif source_bytes[end : end + 1] == b"\n":
            end += 1
        prev_end = end
    parts.append(source_bytes[prev_end:])

    # Remove leading blank lines left by stripping imports
    return b"".join(parts).decode("utf8").lstrip("\r\n")
@@ -1247,6 +1311,15 @@ class JavaScriptSupport: else: analyzer = TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT) + # Strip imports from optimizer output — the original file already has them, + # and import merging is handled by _add_global_declarations_for_language() + new_source = self._strip_imports(new_source, analyzer) + + # If stripping imports left nothing, return original + if not new_source.strip(): + logger.warning("new_source was only imports for %s, returning original", function.function_name) + return source + # Check if new_source contains a JSDoc comment - if so, use full replacement # to include the updated JSDoc along with the function body stripped_new_source = new_source.strip() @@ -1263,16 +1336,64 @@ class JavaScriptSupport: logger.warning( "Could not extract body for %s from optimized code, using full replacement", function.function_name ) - # Verify that new_source contains actual code before falling back to text replacement - # This prevents deletion of the original function when new_source is invalid - if not self._contains_function_declaration(new_source, function.function_name, analyzer): - logger.warning("new_source does not contain function %s, returning original", function.function_name) - return source - return self._replace_function_text_based(source, function, new_source, analyzer) + if self._contains_function_declaration(new_source, function.function_name, analyzer): + return self._replace_function_text_based(source, function, new_source, analyzer) + # Final fallback: line-range replacement using the function's known line boundaries. + # This handles cases where tree-sitter can't parse the optimized output. 
def _replace_function_by_line_range(
    self, source: str, function: FunctionToOptimize, new_source: str
) -> str:
    """Last-resort replacement: cut the original function by line range and paste new_source.

    The replacement is re-indented so that its first non-blank line starts at the
    same column as the original function's first line. Padding reuses the original
    line's own leading whitespace, so tab-indented files stay tab-indented.

    Args:
        source: Full text of the file being edited.
        function: Function descriptor; only ``starting_line`` and ``ending_line``
            (1-based, inclusive) are read.
        new_source: Replacement text for that line span.

    Returns:
        The edited source. ``source`` is returned unchanged when the line span is
        missing or invalid (start before line 1, start past the end of the file, or
        end before start). An ``ending_line`` past the end of the file is clamped.
        A final newline is not added to a file that did not already have one.
    """
    start, end = function.starting_line, function.ending_line
    if start is None or end is None:
        return source

    lines = source.splitlines(keepends=True)
    # A bad span would silently drop or duplicate lines (e.g. start=0 slices lines[:-1]).
    if not 1 <= start <= len(lines) or end < start:
        return source
    end = min(end, len(lines))

    before = lines[: start - 1]
    after = lines[end:]

    # Leading whitespace of the original function's first line
    first_line = lines[start - 1].rstrip("\r\n")
    original_prefix = first_line[: len(first_line) - len(first_line.lstrip())]

    new_lines = new_source.splitlines(keepends=True)
    if new_lines:
        anchor = next((ln for ln in new_lines if ln.strip()), new_lines[0])
        new_indent = len(anchor) - len(anchor.lstrip())
        shift = len(original_prefix) - new_indent
        if shift > 0:
            # Indent further, using the same whitespace characters as the original
            pad = original_prefix[:shift]
            new_lines = [pad + ln if ln.strip() else ln for ln in new_lines]
        elif shift < 0:
            # Dedent, but never cut into a line's non-whitespace content
            drop = -shift
            new_lines = [
                ln[min(len(ln) - len(ln.lstrip()), drop) :] if ln.strip() else ln
                for ln in new_lines
            ]

        # Terminate the pasted text only where a separator is actually required, so a
        # file without a final newline does not gain one.
        if (after or source.endswith("\n")) and not new_lines[-1].endswith("\n"):
            new_lines[-1] += "\n"

    return "".join(before + new_lines + after)
def _is_react_source(self, source_file_path: Path) -> bool:
    """Heuristically decide whether the file at *source_file_path* is React code.

    The file counts as React when its text mentions "react" (case-insensitive)
    and either mentions "jsx" or has a ``.tsx``/``.jsx`` extension. An unreadable
    file is treated as not React.
    """
    try:
        lowered = source_file_path.read_text("utf-8", errors="replace").lower()
    except OSError:
        return False
    if "react" not in lowered:
        return False
    return "jsx" in lowered or source_file_path.suffix in (".tsx", ".jsx")


def _apply_react_postprocessing(self, generated_tests: GeneratedTestsList, source_file_path: Path) -> GeneratedTestsList:
    """Run React test post-processing over every generated test.

    Both the behavior and the perf instrumented sources are passed through
    ``post_process_react_tests``; the original test source and the file paths are
    copied across untouched. The component name is taken from the source file's
    stem.

    Args:
        generated_tests: Tests produced for the source file.
        source_file_path: Path of the component's source file.

    Returns:
        A new ``GeneratedTestsList`` holding the post-processed tests.
    """
    from codeflash.languages.javascript.frameworks.react.testgen import (  # noqa: PLC0415
        post_process_react_tests,
    )
    from codeflash.languages.javascript.frameworks.react.discovery import ComponentType, ReactComponentInfo  # noqa: PLC0415
    from codeflash.models.models import GeneratedTests, GeneratedTestsList  # noqa: PLC0415

    # Only the name and kind are known here, so a minimal component description is used.
    component = ReactComponentInfo(
        function_name=source_file_path.stem,
        component_type=ComponentType.FUNCTION,
    )

    def rebuild(original):
        behavior_source = post_process_react_tests(original.instrumented_behavior_test_source, component)
        perf_source = post_process_react_tests(original.instrumented_perf_test_source, component)
        return GeneratedTests(
            generated_original_test_source=original.generated_original_test_source,
            instrumented_behavior_test_source=behavior_source,
            instrumented_perf_test_source=perf_source,
            behavior_file_path=original.behavior_file_path,
            perf_file_path=original.perf_file_path,
        )

    return GeneratedTestsList(generated_tests=[rebuild(item) for item in generated_tests.generated_tests])
"memo", "React.forwardRef", "React.memo"}) + def _get_name_from_assignment(self, node: Node, source_bytes: bytes) -> str: """Try to extract function name from parent variable declaration or assignment. @@ -617,6 +642,8 @@ class TreeSitterAnalyzer: - const foo = function() {} - let bar = function() {} - obj.method = () => {} + - const Tooltip = forwardRef(function TooltipInner(...) {...}) + - const MemoComp = React.memo((props) => {...}) """ parent = node.parent if parent is None: @@ -628,6 +655,30 @@ class TreeSitterAnalyzer: if name_node: return self.get_node_text(name_node, source_bytes) + # Check for HOC wrapper: const Name = forwardRef/memo(function/arrow) + # Tree structure: function → arguments → call_expression → variable_declarator + if parent.type == "arguments": + call_expr = parent.parent + if call_expr is not None and call_expr.type == "call_expression": + callee = call_expr.child_by_field_name("function") + if callee is not None: + callee_text = self.get_node_text(callee, source_bytes) + if callee_text in self._HOC_WRAPPER_NAMES: + declarator = call_expr.parent + if declarator is not None and declarator.type == "variable_declarator": + name_node = declarator.child_by_field_name("name") + if name_node: + return self.get_node_text(name_node, source_bytes) + # Nested HOC: memo(forwardRef(fn)) — call_expr parent is arguments of outer call + if declarator is not None and declarator.type == "arguments": + outer_call = declarator.parent + if outer_call is not None and outer_call.type == "call_expression": + outer_declarator = outer_call.parent + if outer_declarator is not None and outer_declarator.type == "variable_declarator": + name_node = outer_declarator.child_by_field_name("name") + if name_node: + return self.get_node_text(name_node, source_bytes) + # Check for assignment expression: foo = ... 
if parent.type == "assignment_expression": left_node = parent.child_by_field_name("left") diff --git a/codeflash/models/models.py b/codeflash/models/models.py index dac39246d..413444a31 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -83,6 +83,7 @@ class TestDiffScope(str, Enum): RETURN_VALUE = "return_value" STDOUT = "stdout" DID_PASS = "did_pass" # noqa: S105 + DOM_SNAPSHOT = "dom_snapshot" @dataclass diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 6946acc6d..37116ed0b 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -96,6 +96,7 @@ from codeflash.models.models import ( OptimizedCandidateResult, OptimizedCandidateSource, OriginalCodeBaseline, + TestDiffScope, TestFile, TestFiles, TestingMode, @@ -2955,6 +2956,13 @@ class FunctionOptimizer: logger.info("h3|Test results matched ✅") console.rule() else: + dom_snapshot_diffs = [d for d in diffs if d.scope == TestDiffScope.DOM_SNAPSHOT] + if dom_snapshot_diffs: + logger.warning( + "[REACT] DOM snapshot divergence detected after %d interaction(s). 
" + "The optimized component produces different DOM output.", + len(dom_snapshot_diffs), + ) self.repair_if_possible( candidate, diffs, eval_ctx, code_context, len(candidate_behavior_results), exp_type ) @@ -3178,8 +3186,7 @@ class FunctionOptimizer: coverage_config_file=coverage_config_file, skip_sqlite_cleanup=skip_cleanup, ) - if testing_type == TestingMode.PERFORMANCE: - results.perf_stdout = run_result.stdout + results.perf_stdout = run_result.stdout return results, coverage_results # For LINE_PROFILE mode, Python uses .lprof files while JavaScript uses JSON # Return TestResults for JavaScript so _line_profiler_step_javascript can parse the JSON diff --git a/packages/codeflash/runtime/capture.js b/packages/codeflash/runtime/capture.js index daffbb33e..e57728dcc 100644 --- a/packages/codeflash/runtime/capture.js +++ b/packages/codeflash/runtime/capture.js @@ -46,11 +46,11 @@ function _getReact() { } } -// Try to load better-sqlite3, fall back to JSON if not available -let useSqlite = false; - // Configuration from environment const OUTPUT_FILE = process.env.CODEFLASH_OUTPUT_FILE; + +// Enable SQLite when better-sqlite3 is available and an output file is configured +let useSqlite = !!(Database && OUTPUT_FILE); const LOOP_INDEX = parseInt(process.env.CODEFLASH_LOOP_INDEX || '1', 10); const TEST_ITERATION = process.env.CODEFLASH_TEST_ITERATION; const TEST_MODULE = process.env.CODEFLASH_TEST_MODULE; @@ -1091,6 +1091,35 @@ function captureRender(funcName, lineId, renderFn, Component, ...createElementAr return renderResult; } +/** + * Capture a DOM snapshot after a user interaction for behavioral verification. + * + * Writes normalized `document.body.innerHTML` to the SQLite database as an + * additional row with function name `__dom_snapshot__`. The existing + * `compare-results.js` comparator compares all rows by invocation ID, so + * snapshot rows are compared automatically. + * + * @param {string} label - Unique label for this snapshot (e.g. 
/**
 * Capture a DOM snapshot after a user interaction for behavioral verification.
 *
 * Records the normalized `document.body.innerHTML` through `recordResult` under the
 * function name `__dom_snapshot__` and the invocation id `snapshot_<label>`, so that
 * the result comparator can report DOM differences between two runs.
 *
 * Does nothing when SQLite capture is inactive (`useSqlite` or `db` unset) or when
 * there is no `document.body` (non-DOM test environment).
 *
 * @param {string} label - Unique label for this snapshot (e.g. 'after_click_1')
 */
function snapshotDOM(label) {
    if (!useSqlite || !db) return;
    if (typeof document === 'undefined' || !document.body) return;

    // Normalize HTML so that only meaningful differences remain:
    //  1. collapse every whitespace run to a single space and trim the ends;
    //  2. remove the React-internal `data-reactroot` attribute, including an optional
    //     `="..."` value. Only the whitespace *before* the attribute is consumed, so
    //     `<div data-reactroot class="a">` becomes `<div class="a">` rather than
    //     `<divclass="a">`, and `<div data-reactroot="">` becomes `<div>`.
    const html = document.body.innerHTML
        .replace(/\s+/g, ' ')
        .trim()
        .replace(/\s+data-reactroot(?:="[^"]*")?/g, '');

    const { testModulePath, testClassName, testFunctionName } = _getTestContext();
    const invocationId = `snapshot_${label}`;

    recordResult(
        testModulePath, testClassName, testFunctionName,
        '__dom_snapshot__', invocationId,
        [label], html, null, 0
    );
}
'dom_snapshot' + : 'return_value'; diffs.push({ invocation_id: invocationId, - scope: 'return_value', + scope, original: summarizeValue(originalValue), candidate: summarizeValue(candidateValue), test_info: { diff --git a/packages/codeflash/runtime/index.d.ts b/packages/codeflash/runtime/index.d.ts index 2e7b904eb..94cb359f4 100644 --- a/packages/codeflash/runtime/index.d.ts +++ b/packages/codeflash/runtime/index.d.ts @@ -36,6 +36,15 @@ export function capturePerf any>( ...args: Parameters ): ReturnType; +/** + * Capture a DOM snapshot after a user interaction for behavioral verification. + * Writes normalized document.body.innerHTML to the SQLite database with + * function name '__dom_snapshot__' so the comparator detects DOM divergence. + * + * @param label - Unique label for this snapshot (e.g. 'after_click_1') + */ +export function snapshotDOM(label: string): void; + /** * Capture multiple invocations for benchmarking. * @@ -126,6 +135,7 @@ export const TEST_ITERATION: string; declare const codeflash: { capture: typeof capture; capturePerf: typeof capturePerf; + snapshotDOM: typeof snapshotDOM; captureMultiple: typeof captureMultiple; writeResults: typeof writeResults; clearResults: typeof clearResults; diff --git a/packages/codeflash/runtime/index.js b/packages/codeflash/runtime/index.js index bf2fdbbf5..eded9b933 100644 --- a/packages/codeflash/runtime/index.js +++ b/packages/codeflash/runtime/index.js @@ -38,6 +38,7 @@ module.exports = { captureRender: capture.captureRender, captureRenderPerf: capture.captureRenderPerf, + snapshotDOM: capture.snapshotDOM, captureMultiple: capture.captureMultiple, diff --git a/tests/react/test_benchmarking.py b/tests/react/test_benchmarking.py index 778ea83da..2fef4f621 100644 --- a/tests/react/test_benchmarking.py +++ b/tests/react/test_benchmarking.py @@ -173,6 +173,8 @@ class TestRenderEfficiencyCritic: optimized_render_duration=10.0, original_update_render_count=8, optimized_update_render_count=8, + 
"""Tests for inject_dom_snapshot_calls in React behavioral verification."""

from __future__ import annotations

from codeflash.languages.javascript.frameworks.react.testgen import inject_dom_snapshot_calls

# Lines shared by most fixtures: the codeflash import and a behavior-mode render capture.
_IMPORT_LINE = "import codeflash from 'codeflash';"
_CAPTURE_LINE = "const { container } = codeflash.captureRender('Comp', '1', render, Comp);"


def _source(*lines: str) -> str:
    """Join *lines* into a test source that ends with a newline."""
    return "\n".join(lines) + "\n"


def test_injects_after_fireEvent():
    # Each distinct event type gets its own counter starting at 1.
    result = inject_dom_snapshot_calls(
        _source(
            _IMPORT_LINE,
            _CAPTURE_LINE,
            "fireEvent.click(screen.getByText('Add'));",
            "fireEvent.change(input, { target: { value: 'hi' } });",
        )
    )
    assert "codeflash.snapshotDOM('after_click_1');" in result
    assert "codeflash.snapshotDOM('after_change_1');" in result


def test_skips_perf_mode():
    # captureRenderPerf marks a performance run; no snapshots may be injected.
    result = inject_dom_snapshot_calls(
        _source(
            _IMPORT_LINE,
            "const result = await codeflash.captureRenderPerf('Comp', '1', render, Comp);",
            "fireEvent.click(screen.getByText('Add'));",
        )
    )
    assert "snapshotDOM" not in result


def test_skips_without_captureRender():
    # Without captureRender the file is not a behavior-mode React test.
    result = inject_dom_snapshot_calls(
        _source(
            _IMPORT_LINE,
            "fireEvent.click(screen.getByText('Add'));",
        )
    )
    assert "snapshotDOM" not in result


def test_preserves_indentation():
    result = inject_dom_snapshot_calls(
        _source(
            _IMPORT_LINE,
            _CAPTURE_LINE,
            "    fireEvent.click(btn);",
            "    fireEvent.change(input, { target: { value: 'x' } });",
        )
    )
    # The injected calls must carry the indentation of the interaction they follow.
    injected = [entry for entry in result.split("\n") if "snapshotDOM" in entry]
    assert len(injected) == 2
    assert injected[0].startswith("    codeflash.snapshotDOM")
    assert injected[1].startswith("    codeflash.snapshotDOM")


def test_no_semicolons():
    """Real-world projects (e.g. Zustand) don't use semicolons."""
    result = inject_dom_snapshot_calls(
        _source(
            "const { container } = codeflash.captureRender('Comp', '1', render, Comp)",
            "    fireEvent.click(screen.getByText('button'))",
            "    fireEvent.click(screen.getByTestId('test-shallow'))",
        )
    )
    assert "after_click_1" in result
    assert "after_click_2" in result


def test_sequential_counter():
    # Repeating the same event increments the per-label counter.
    result = inject_dom_snapshot_calls(
        _source(
            _IMPORT_LINE,
            _CAPTURE_LINE,
            "fireEvent.click(btn);",
            "fireEvent.click(btn);",
            "fireEvent.click(btn);",
        )
    )
    assert "after_click_1" in result
    assert "after_click_2" in result
    assert "after_click_3" in result