snapshot comparison in verification

This commit is contained in:
Sarthak Agarwal 2026-04-07 20:11:54 +05:30
parent 4bc89f2b9d
commit b5ba6df690
16 changed files with 505 additions and 39 deletions

View file

@ -46,6 +46,8 @@ class AiServiceClient:
self.llm_call_counter = count(1) self.llm_call_counter = count(1)
self.is_local = self.base_url == "http://localhost:8000" self.is_local = self.base_url == "http://localhost:8000"
self.timeout: float | None = 300 if self.is_local else 90 self.timeout: float | None = 300 if self.is_local else 90
# React components are larger (300+ lines of JSX) and need more LLM processing time
self.react_timeout: float | None = 300 if self.is_local else 180
def get_next_sequence(self) -> int: def get_next_sequence(self) -> int:
"""Get the next LLM call sequence number.""" """Get the next LLM call sequence number."""
@ -203,7 +205,8 @@ class AiServiceClient:
logger.debug(f"Sending optimize request: trace_id={trace_id}, n_candidates={payload['n_candidates']}") logger.debug(f"Sending optimize request: trace_id={trace_id}, n_candidates={payload['n_candidates']}")
try: try:
response = self.make_ai_service_request("/optimize", payload=payload, timeout=self.timeout) timeout = self.react_timeout if is_react_component else self.timeout
response = self.make_ai_service_request("/optimize", payload=payload, timeout=timeout)
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
logger.exception(f"Error generating optimized candidates: {e}") logger.exception(f"Error generating optimized candidates: {e}")
ph("cli-optimize-error-caught", {"error": str(e)}) ph("cli-optimize-error-caught", {"error": str(e)})
@ -806,7 +809,8 @@ class AiServiceClient:
# DEBUG: Print payload language field # DEBUG: Print payload language field
logger.debug(f"Sending testgen request with language='{payload['language']}', framework='{test_framework}'") logger.debug(f"Sending testgen request with language='{payload['language']}', framework='{test_framework}'")
try: try:
response = self.make_ai_service_request("/testgen", payload=payload, timeout=self.timeout) timeout = self.react_timeout if is_react_component else self.timeout
response = self.make_ai_service_request("/testgen", payload=payload, timeout=timeout)
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
logger.exception(f"Error generating tests: {e}") logger.exception(f"Error generating tests: {e}")
ph("cli-testgen-error-caught", {"error": str(e)}) ph("cli-testgen-error-caught", {"error": str(e)})

View file

@ -143,6 +143,8 @@ def compare_test_results(
scope = TestDiffScope.STDOUT scope = TestDiffScope.STDOUT
elif scope_str == "did_pass": elif scope_str == "did_pass":
scope = TestDiffScope.DID_PASS scope = TestDiffScope.DID_PASS
elif scope_str == "dom_snapshot":
scope = TestDiffScope.DOM_SNAPSHOT
test_info = diff.get("test_info", {}) test_info = diff.get("test_info", {})
# Build a test identifier string for JavaScript tests # Build a test identifier string for JavaScript tests

View file

@ -210,7 +210,10 @@ def normalize_codeflash_imports(source: str) -> str:
# Replace CommonJS require # Replace CommonJS require
source = _CODEFLASH_REQUIRE_PATTERN.sub(r"\1 \2 = require('codeflash')", source) source = _CODEFLASH_REQUIRE_PATTERN.sub(r"\1 \2 = require('codeflash')", source)
# Replace ES module import # Replace ES module import
return _CODEFLASH_IMPORT_PATTERN.sub(r"import \1 from 'codeflash'", source) source = _CODEFLASH_IMPORT_PATTERN.sub(r"import \1 from 'codeflash'", source)
# Fix outdated @testing-library/jest-dom import paths (v6+ removed /extend-expect subpath)
source = source.replace("@testing-library/jest-dom/extend-expect", "@testing-library/jest-dom")
return source
# Author: ali <mohammed18200118@gmail.com> # Author: ali <mohammed18200118@gmail.com>

View file

@ -72,42 +72,37 @@ def instrument_component_with_profiler(source: str, component_name: str, analyze
if computed is not None: if computed is not None:
replacements.append(computed) replacements.append(computed)
# Work entirely in bytes to avoid byte-offset vs char-index mismatches
# (tree-sitter returns byte offsets; non-ASCII chars cause divergence from Python string indices)
if replacements: if replacements:
# Reconstruct result in a single pass byte_parts: list[bytes] = []
parts: list[str] = []
prev = 0 prev = 0
for start, end, wrapped in replacements: for start, end, wrapped in replacements:
# Use original source slices (string indices expected by original logic) byte_parts.append(source_bytes[prev:start])
parts.append(source[prev:start]) byte_parts.append(wrapped.encode("utf-8"))
parts.append(wrapped)
prev = end prev = end
parts.append(source[prev:]) byte_parts.append(source_bytes[prev:])
result = "".join(parts) result_bytes = b"".join(byte_parts)
else: else:
result = source result_bytes = source_bytes
# Add render counter code at the top (after imports) using the already-parsed tree
# Add render counter code at the top (after imports) # Add render counter code at the top (after imports)
counter_code = generate_render_counter_code(component_name) counter_code = generate_render_counter_code(component_name)
# Inline logic similar to _insert_after_imports but reuse existing `tree` to avoid re-parsing
last_import_end = 0 last_import_end = 0
for child in tree.root_node.children: for child in tree.root_node.children:
if child.type == "import_statement": if child.type == "import_statement":
last_import_end = child.end_byte last_import_end = child.end_byte
insert_pos = last_import_end insert_pos = last_import_end
while insert_pos < len(result) and result[insert_pos] != "\n": while insert_pos < len(result_bytes) and result_bytes[insert_pos : insert_pos + 1] != b"\n":
insert_pos += 1 insert_pos += 1
if insert_pos < len(result): if insert_pos < len(result_bytes):
insert_pos += 1 # skip the newline insert_pos += 1 # skip the newline
result = result[:insert_pos] + "\n" + counter_code + "\n\n" + result[insert_pos:] counter_bytes = ("\n" + counter_code + "\n\n").encode("utf-8")
result = (result_bytes[:insert_pos] + counter_bytes + result_bytes[insert_pos:]).decode("utf-8")
# Ensure React is imported
# Ensure React is imported
return _ensure_react_import(result) return _ensure_react_import(result)
@ -317,7 +312,8 @@ def _wrap_return_with_profiler(source: str, return_node: Node, profiler_id: str,
f"</React.Profiler>" f"</React.Profiler>"
) )
return source[:jsx_start] + wrapped + source[jsx_end:] # Use byte-level splicing to avoid byte-offset vs char-index mismatches
return (source_bytes[:jsx_start] + wrapped.encode("utf-8") + source_bytes[jsx_end:]).decode("utf-8")
def _insert_after_imports(source: str, code: str, analyzer: TreeSitterAnalyzer) -> str: def _insert_after_imports(source: str, code: str, analyzer: TreeSitterAnalyzer) -> str:
@ -421,7 +417,7 @@ let _codeflash_render_count_{safe_name} = 0;
if (typeof beforeEach !== 'undefined') {{ if (typeof beforeEach !== 'undefined') {{
beforeEach(() => {{ _codeflash_render_count_{safe_name} = 0; }}); beforeEach(() => {{ _codeflash_render_count_{safe_name} = 0; }});
}} }}
function _codeflashOnRender_{safe_name}(id, phase, actualDuration, baseDuration) {{ function _codeflashOnRender_{safe_name}(id: any, phase: any, actualDuration: any, baseDuration: any) {{
_codeflash_render_count_{safe_name}++; _codeflash_render_count_{safe_name}++;
console.log(`!######{marker_prefix}:${{id}}:${{phase}}:${{actualDuration}}:${{baseDuration}}:${{_codeflash_render_count_{safe_name}}}######!`); console.log(`!######{marker_prefix}:${{id}}:${{phase}}:${{actualDuration}}:${{baseDuration}}:${{_codeflash_render_count_{safe_name}}}######!`);
}}""" }}"""

View file

@ -98,6 +98,15 @@ def post_process_react_tests(test_source: str, component_info: ReactComponentInf
""" """
result = test_source result = test_source
# Fix outdated @testing-library/jest-dom import paths (v6+ removed /extend-expect subpath)
result = result.replace("@testing-library/jest-dom/extend-expect", "@testing-library/jest-dom")
# Also fix require() variant
result = re.sub(
r"""require\s*\(\s*['"]@testing-library/jest-dom/extend-expect['"]\s*\)""",
"require('@testing-library/jest-dom')",
result,
)
# Ensure testing-library import # Ensure testing-library import
if "@testing-library/react" not in result: if "@testing-library/react" not in result:
result = "import { render, screen, act } from '@testing-library/react';\n" + result result = "import { render, screen, act } from '@testing-library/react';\n" + result
@ -168,14 +177,14 @@ def post_process_react_tests(test_source: str, component_info: ReactComponentInf
# This gives per-interaction A/B signal without the LLM needing to know about it. # This gives per-interaction A/B signal without the LLM needing to know about it.
result = inject_interaction_markers(result) result = inject_interaction_markers(result)
# Warn if no tests contain interaction calls — mount-phase only markers are # If no tests contain interaction calls, auto-inject a rerender fallback so
# not useful for measuring optimization effectiveness. # that EVERY React perf test produces at least one update-phase marker.
if not has_react_test_interactions(result): if not has_react_test_interactions(result):
logger.warning( logger.warning(
"[REACT] Generated tests for %s contain no interactions (fireEvent, userEvent, rerender). " "[REACT] Generated tests for %s contain no interactions — auto-injecting rerender fallback.",
"Tests will produce only mount-phase markers which cannot measure optimization improvements.",
component_info.function_name, component_info.function_name,
) )
result = _inject_rerender_fallback(result, component_info.function_name)
# Check interaction density — fewer than MIN_INTERACTION_CALLS total interactions # Check interaction density — fewer than MIN_INTERACTION_CALLS total interactions
# means the test is unlikely to produce enough update-phase renders for reliable measurement. # means the test is unlikely to produce enough update-phase renders for reliable measurement.
@ -221,6 +230,48 @@ def _extract_interaction_label(call_text: str) -> str:
return m.group(1) if m else "interaction" return m.group(1) if m else "interaction"
def inject_dom_snapshot_calls(test_source: str) -> str:
"""Inject codeflash.snapshotDOM() calls after each user interaction in behavior mode.
Only active when `captureRender` (not `captureRenderPerf`) is present,
meaning the test is running in behavioral verification mode.
After each fireEvent.*, userEvent.*, or rerender() call, inserts:
codeflash.snapshotDOM('after_{label}_{n}');
on the next line with matching indentation.
"""
if "captureRender" not in test_source:
return test_source
if "captureRenderPerf" in test_source:
return test_source
interaction_counter: dict[str, int] = {}
lines = test_source.split("\n")
new_lines: list[str] = []
# Also match rerender() calls — use .* to handle nested parens like getByText('Add')
snapshot_interaction_pattern = re.compile(
r"^(\s*)((?:await\s+)?(?:fireEvent\.\w+|userEvent\.\w+|(?:\w+\.)?rerender)\s*\(.*\))\s*;?\s*$",
re.MULTILINE,
)
for line in lines:
new_lines.append(line)
m = snapshot_interaction_pattern.match(line)
if m:
indent = m.group(1)
call_text = m.group(2)
label = _extract_interaction_label(call_text)
if label == "interaction":
# rerender() call
label = "rerender"
interaction_counter[label] = interaction_counter.get(label, 0) + 1
unique_label = f"{label}_{interaction_counter[label]}"
new_lines.append(f"{indent}codeflash.snapshotDOM('after_{unique_label}');")
return "\n".join(new_lines)
def inject_interaction_markers(test_source: str) -> str: def inject_interaction_markers(test_source: str) -> str:
"""Inject _codeflashMarkInteraction() calls before each fireEvent/userEvent call. """Inject _codeflashMarkInteraction() calls before each fireEvent/userEvent call.
@ -325,3 +376,54 @@ def has_high_density_interactions(test_source: str) -> bool:
interaction_calls = _INTERACTION_PATTERNS.findall(test_source) interaction_calls = _INTERACTION_PATTERNS.findall(test_source)
return len(interaction_calls) >= _MIN_SEQUENTIAL_INTERACTIONS return len(interaction_calls) >= _MIN_SEQUENTIAL_INTERACTIONS
# Pattern to extract props from the first render(<Component ...props... />) call
_RENDER_CALL_PROPS_PATTERN = re.compile(
r"render\s*\(\s*<(\w+)\s+([^/]*?)\s*/?\s*>",
)
def _extract_render_props(test_source: str, component_name: str) -> str | None:
"""Extract the props expression from the first render(<Component ...>) call."""
for m in _RENDER_CALL_PROPS_PATTERN.finditer(test_source):
if m.group(1) == component_name:
props_text = m.group(2).strip()
if props_text:
return props_text
return None
def _inject_rerender_fallback(test_source: str, component_name: str) -> str:
"""Inject a rerender efficiency test block when the test has no interactions.
This ensures every React perf test produces at least one update-phase marker.
"""
# Try to extract props from existing render call
props_expr = _extract_render_props(test_source, component_name)
if props_expr:
jsx_open = f"<{component_name} {props_expr} />"
else:
jsx_open = f"<{component_name} />"
rerender_block = f"""
describe('{component_name} rerender efficiency (auto-generated)', () => {{
it('should handle same-props rerenders', () => {{
const {{ rerender }} = render({jsx_open});
for (let i = 0; i < 10; i++) {{
rerender({jsx_open});
}}
}});
}});
"""
# Ensure {rerender} is in the @testing-library/react import
rtl_import_match = re.search(
r"import\s*\{([^}]+)\}\s*from\s*['\"]@testing-library/react['\"]", test_source
)
if rtl_import_match:
imports = rtl_import_match.group(1)
if "rerender" not in imports:
# Don't add 'rerender' to import — it comes from render() return value, not an import
pass
return test_source.rstrip() + "\n" + rerender_block

View file

@ -1346,6 +1346,13 @@ def _instrument_js_test_code(
code=code, function_to_optimize=function_to_optimize, capture_func=capture_func, start_counter=final_counter code=code, function_to_optimize=function_to_optimize, capture_func=capture_func, start_counter=final_counter
) )
# In behavior mode, inject DOM snapshot calls after each interaction so the
# comparator can detect post-interaction DOM divergence between original and candidate.
if mode == TestingMode.BEHAVIOR:
from codeflash.languages.javascript.frameworks.react.testgen import inject_dom_snapshot_calls # noqa: PLC0415
code = inject_dom_snapshot_calls(code)
return code return code

View file

@ -174,6 +174,14 @@ class JavaScriptSupport:
logger.debug(f"Skipping non-exported function: {func.name}") # noqa: G004 logger.debug(f"Skipping non-exported function: {func.name}") # noqa: G004
continue continue
# Skip closures inside React components — they capture hook state
# and cannot be tested independently
if func.parent_function and func.parent_function in react_component_map:
logger.debug(
f"Skipping closure {func.name} inside React component {func.parent_function}" # noqa: G004
)
continue
# Build parents list # Build parents list
parents: list[FunctionParent] = [] parents: list[FunctionParent] = []
if func.class_name: if func.class_name:
@ -231,8 +239,23 @@ class JavaScriptSupport:
source, include_methods=True, include_arrow_functions=True, require_name=True source, include_methods=True, include_arrow_functions=True, require_name=True
) )
# Build React component lookup to filter out closures
react_component_names: set[str] = set()
try:
from codeflash.languages.javascript.frameworks.react.discovery import classify_component # noqa: PLC0415
for func in tree_functions:
if classify_component(func, source, analyzer) is not None:
react_component_names.add(func.name)
except Exception:
pass
functions: list[FunctionToOptimize] = [] functions: list[FunctionToOptimize] = []
for func in tree_functions: for func in tree_functions:
# Skip closures inside React components
if func.parent_function and func.parent_function in react_component_names:
continue
# Build parents list # Build parents list
parents: list[FunctionParent] = [] parents: list[FunctionParent] = []
if func.class_name: if func.class_name:
@ -1212,6 +1235,47 @@ class JavaScriptSupport:
# === Code Transformation === # === Code Transformation ===
@staticmethod
def _strip_imports(source: str, analyzer: TreeSitterAnalyzer) -> str:
"""Strip import/require statements from source code.
Used to remove duplicate imports from optimizer output before inserting
into the original file (which already has its own imports).
"""
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
# Collect byte ranges of import nodes to remove
ranges_to_remove: list[tuple[int, int]] = []
for child in tree.root_node.children:
if child.type == "import_statement":
ranges_to_remove.append((child.start_byte, child.end_byte))
# CJS require: const x = require('...')
elif child.type in ("lexical_declaration", "variable_declaration"):
text = source_bytes[child.start_byte : child.end_byte].decode("utf8")
if "require(" in text:
ranges_to_remove.append((child.start_byte, child.end_byte))
if not ranges_to_remove:
return source
# Build result, skipping removed ranges
parts: list[bytes] = []
prev_end = 0
for start, end in sorted(ranges_to_remove):
parts.append(source_bytes[prev_end:start])
# Skip trailing newline after import
if end < len(source_bytes) and source_bytes[end : end + 1] == b"\n":
end += 1
prev_end = end
parts.append(source_bytes[prev_end:])
result = b"".join(parts).decode("utf8")
# Remove leading blank lines left by stripping imports
while result.startswith("\n"):
result = result[1:]
return result
def replace_function(self, source: str, function: FunctionToOptimize, new_source: str) -> str: def replace_function(self, source: str, function: FunctionToOptimize, new_source: str) -> str:
"""Replace a function in source code with new implementation. """Replace a function in source code with new implementation.
@ -1247,6 +1311,15 @@ class JavaScriptSupport:
else: else:
analyzer = TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT) analyzer = TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
# Strip imports from optimizer output — the original file already has them,
# and import merging is handled by _add_global_declarations_for_language()
new_source = self._strip_imports(new_source, analyzer)
# If stripping imports left nothing, return original
if not new_source.strip():
logger.warning("new_source was only imports for %s, returning original", function.function_name)
return source
# Check if new_source contains a JSDoc comment - if so, use full replacement # Check if new_source contains a JSDoc comment - if so, use full replacement
# to include the updated JSDoc along with the function body # to include the updated JSDoc along with the function body
stripped_new_source = new_source.strip() stripped_new_source = new_source.strip()
@ -1263,16 +1336,64 @@ class JavaScriptSupport:
logger.warning( logger.warning(
"Could not extract body for %s from optimized code, using full replacement", function.function_name "Could not extract body for %s from optimized code, using full replacement", function.function_name
) )
# Verify that new_source contains actual code before falling back to text replacement if self._contains_function_declaration(new_source, function.function_name, analyzer):
# This prevents deletion of the original function when new_source is invalid return self._replace_function_text_based(source, function, new_source, analyzer)
if not self._contains_function_declaration(new_source, function.function_name, analyzer): # Final fallback: line-range replacement using the function's known line boundaries.
logger.warning("new_source does not contain function %s, returning original", function.function_name) # This handles cases where tree-sitter can't parse the optimized output.
return source logger.warning(
return self._replace_function_text_based(source, function, new_source, analyzer) "Falling back to line-range replacement for %s (lines %d-%d)",
function.function_name,
function.starting_line,
function.ending_line,
)
return self._replace_function_by_line_range(source, function, new_source)
# Find the original function and replace its body # Find the original function and replace its body
return self._replace_function_body(source, function, new_body, analyzer) return self._replace_function_body(source, function, new_body, analyzer)
def _replace_function_by_line_range(
self, source: str, function: FunctionToOptimize, new_source: str
) -> str:
"""Last-resort replacement: cut the original function by line range and paste new_source."""
if function.starting_line is None or function.ending_line is None:
return source
lines = source.splitlines(keepends=True)
if lines and not lines[-1].endswith("\n"):
lines[-1] += "\n"
# Get indentation from original function's first line
if function.starting_line <= len(lines):
original_first = lines[function.starting_line - 1]
original_indent = len(original_first) - len(original_first.lstrip())
else:
original_indent = 0
new_lines = new_source.splitlines(keepends=True)
if new_lines:
first_non_blank = next((l for l in new_lines if l.strip()), new_lines[0])
new_indent = len(first_non_blank) - len(first_non_blank.lstrip())
indent_diff = original_indent - new_indent
if indent_diff != 0:
adjusted = []
for line in new_lines:
if line.strip():
if indent_diff > 0:
adjusted.append(" " * indent_diff + line)
else:
cur = len(line) - len(line.lstrip())
adjusted.append(line[min(cur, abs(indent_diff)) :])
else:
adjusted.append(line)
new_lines = adjusted
if new_lines and not new_lines[-1].endswith("\n"):
new_lines[-1] += "\n"
before = lines[: function.starting_line - 1]
after = lines[function.ending_line :]
return "".join(before + new_lines + after)
def _contains_function_declaration(self, source: str, function_name: str, analyzer: TreeSitterAnalyzer) -> bool: def _contains_function_declaration(self, source: str, function_name: str, analyzer: TreeSitterAnalyzer) -> bool:
"""Check if source contains a function declaration with the given name. """Check if source contains a function declaration with the given name.
@ -1912,7 +2033,52 @@ class JavaScriptSupport:
generated_tests = inject_test_globals(generated_tests, test_framework) generated_tests = inject_test_globals(generated_tests, test_framework)
if is_typescript(): if is_typescript():
generated_tests = disable_ts_check(generated_tests) generated_tests = disable_ts_check(generated_tests)
return normalize_generated_tests_imports(generated_tests) generated_tests = normalize_generated_tests_imports(generated_tests)
# Apply React-specific post-processing (act wrapping, interaction markers, etc.)
if self._is_react_source(source_file_path):
generated_tests = self._apply_react_postprocessing(generated_tests, source_file_path)
return generated_tests
def _is_react_source(self, source_file_path: Path) -> bool:
"""Check if the source file is a React component."""
try:
content = source_file_path.read_text("utf-8", errors="replace")
return "react" in content.lower() and ("jsx" in content.lower() or source_file_path.suffix in (".tsx", ".jsx"))
except OSError:
return False
def _apply_react_postprocessing(self, generated_tests: GeneratedTestsList, source_file_path: Path) -> GeneratedTestsList:
"""Apply React-specific post-processing to generated tests."""
from codeflash.languages.javascript.frameworks.react.testgen import ( # noqa: PLC0415
post_process_react_tests,
)
from codeflash.languages.javascript.frameworks.react.discovery import ComponentType, ReactComponentInfo # noqa: PLC0415
from codeflash.models.models import GeneratedTests, GeneratedTestsList # noqa: PLC0415
component_name = source_file_path.stem
dummy_info = ReactComponentInfo(
function_name=component_name,
component_type=ComponentType.FUNCTION,
)
updated_tests = []
for test in generated_tests.generated_tests:
updated_tests.append(
GeneratedTests(
generated_original_test_source=test.generated_original_test_source,
instrumented_behavior_test_source=post_process_react_tests(
test.instrumented_behavior_test_source, dummy_info
),
instrumented_perf_test_source=post_process_react_tests(
test.instrumented_perf_test_source, dummy_info
),
behavior_file_path=test.behavior_file_path,
perf_file_path=test.perf_file_path,
)
)
return GeneratedTestsList(generated_tests=updated_tests)
def remove_test_functions_from_generated_tests( def remove_test_functions_from_generated_tests(
self, generated_tests: GeneratedTestsList, functions_to_remove: list[str] self, generated_tests: GeneratedTestsList, functions_to_remove: list[str]

View file

@ -401,6 +401,29 @@ class TreeSitterAnalyzer:
great_grandparent = grandparent.parent great_grandparent = grandparent.parent
if great_grandparent and great_grandparent.type == "export_statement": if great_grandparent and great_grandparent.type == "export_statement":
return True return True
# HOC wrappers: export const Comp = forwardRef/memo(function/arrow)
# Tree: function → arguments → call_expression → variable_declarator → lexical_declaration → export_statement
if parent and parent.type == "arguments":
call_expr = parent.parent
if call_expr and call_expr.type == "call_expression":
declarator = call_expr.parent
if declarator and declarator.type == "variable_declarator":
lex_decl = declarator.parent
if lex_decl and lex_decl.type in ("lexical_declaration", "variable_declaration"):
export_stmt = lex_decl.parent
if export_stmt and export_stmt.type == "export_statement":
return True
# Nested HOC: export const X = memo(forwardRef(fn))
if declarator and declarator.type == "arguments":
outer_call = declarator.parent
if outer_call and outer_call.type == "call_expression":
outer_decl = outer_call.parent
if outer_decl and outer_decl.type == "variable_declarator":
lex_decl = outer_decl.parent
if lex_decl and lex_decl.type in ("lexical_declaration", "variable_declaration"):
export_stmt = lex_decl.parent
if export_stmt and export_stmt.type == "export_statement":
return True
# For methods in exported classes # For methods in exported classes
if node.type == "method_definition": if node.type == "method_definition":
@ -609,6 +632,8 @@ class TreeSitterAnalyzer:
return None return None
_HOC_WRAPPER_NAMES = frozenset({"forwardRef", "memo", "React.forwardRef", "React.memo"})
def _get_name_from_assignment(self, node: Node, source_bytes: bytes) -> str: def _get_name_from_assignment(self, node: Node, source_bytes: bytes) -> str:
"""Try to extract function name from parent variable declaration or assignment. """Try to extract function name from parent variable declaration or assignment.
@ -617,6 +642,8 @@ class TreeSitterAnalyzer:
- const foo = function() {} - const foo = function() {}
- let bar = function() {} - let bar = function() {}
- obj.method = () => {} - obj.method = () => {}
- const Tooltip = forwardRef(function TooltipInner(...) {...})
- const MemoComp = React.memo((props) => {...})
""" """
parent = node.parent parent = node.parent
if parent is None: if parent is None:
@ -628,6 +655,30 @@ class TreeSitterAnalyzer:
if name_node: if name_node:
return self.get_node_text(name_node, source_bytes) return self.get_node_text(name_node, source_bytes)
# Check for HOC wrapper: const Name = forwardRef/memo(function/arrow)
# Tree structure: function → arguments → call_expression → variable_declarator
if parent.type == "arguments":
call_expr = parent.parent
if call_expr is not None and call_expr.type == "call_expression":
callee = call_expr.child_by_field_name("function")
if callee is not None:
callee_text = self.get_node_text(callee, source_bytes)
if callee_text in self._HOC_WRAPPER_NAMES:
declarator = call_expr.parent
if declarator is not None and declarator.type == "variable_declarator":
name_node = declarator.child_by_field_name("name")
if name_node:
return self.get_node_text(name_node, source_bytes)
# Nested HOC: memo(forwardRef(fn)) — call_expr parent is arguments of outer call
if declarator is not None and declarator.type == "arguments":
outer_call = declarator.parent
if outer_call is not None and outer_call.type == "call_expression":
outer_declarator = outer_call.parent
if outer_declarator is not None and outer_declarator.type == "variable_declarator":
name_node = outer_declarator.child_by_field_name("name")
if name_node:
return self.get_node_text(name_node, source_bytes)
# Check for assignment expression: foo = ... # Check for assignment expression: foo = ...
if parent.type == "assignment_expression": if parent.type == "assignment_expression":
left_node = parent.child_by_field_name("left") left_node = parent.child_by_field_name("left")

View file

@ -83,6 +83,7 @@ class TestDiffScope(str, Enum):
RETURN_VALUE = "return_value" RETURN_VALUE = "return_value"
STDOUT = "stdout" STDOUT = "stdout"
DID_PASS = "did_pass" # noqa: S105 DID_PASS = "did_pass" # noqa: S105
DOM_SNAPSHOT = "dom_snapshot"
@dataclass @dataclass

View file

@ -96,6 +96,7 @@ from codeflash.models.models import (
OptimizedCandidateResult, OptimizedCandidateResult,
OptimizedCandidateSource, OptimizedCandidateSource,
OriginalCodeBaseline, OriginalCodeBaseline,
TestDiffScope,
TestFile, TestFile,
TestFiles, TestFiles,
TestingMode, TestingMode,
@ -2955,6 +2956,13 @@ class FunctionOptimizer:
logger.info("h3|Test results matched ✅") logger.info("h3|Test results matched ✅")
console.rule() console.rule()
else: else:
dom_snapshot_diffs = [d for d in diffs if d.scope == TestDiffScope.DOM_SNAPSHOT]
if dom_snapshot_diffs:
logger.warning(
"[REACT] DOM snapshot divergence detected after %d interaction(s). "
"The optimized component produces different DOM output.",
len(dom_snapshot_diffs),
)
self.repair_if_possible( self.repair_if_possible(
candidate, diffs, eval_ctx, code_context, len(candidate_behavior_results), exp_type candidate, diffs, eval_ctx, code_context, len(candidate_behavior_results), exp_type
) )
@ -3178,8 +3186,7 @@ class FunctionOptimizer:
coverage_config_file=coverage_config_file, coverage_config_file=coverage_config_file,
skip_sqlite_cleanup=skip_cleanup, skip_sqlite_cleanup=skip_cleanup,
) )
if testing_type == TestingMode.PERFORMANCE: results.perf_stdout = run_result.stdout
results.perf_stdout = run_result.stdout
return results, coverage_results return results, coverage_results
# For LINE_PROFILE mode, Python uses .lprof files while JavaScript uses JSON # For LINE_PROFILE mode, Python uses .lprof files while JavaScript uses JSON
# Return TestResults for JavaScript so _line_profiler_step_javascript can parse the JSON # Return TestResults for JavaScript so _line_profiler_step_javascript can parse the JSON

View file

@ -46,11 +46,11 @@ function _getReact() {
} }
} }
// Try to load better-sqlite3, fall back to JSON if not available
let useSqlite = false;
// Configuration from environment // Configuration from environment
const OUTPUT_FILE = process.env.CODEFLASH_OUTPUT_FILE; const OUTPUT_FILE = process.env.CODEFLASH_OUTPUT_FILE;
// Enable SQLite when better-sqlite3 is available and an output file is configured
let useSqlite = !!(Database && OUTPUT_FILE);
const LOOP_INDEX = parseInt(process.env.CODEFLASH_LOOP_INDEX || '1', 10); const LOOP_INDEX = parseInt(process.env.CODEFLASH_LOOP_INDEX || '1', 10);
const TEST_ITERATION = process.env.CODEFLASH_TEST_ITERATION; const TEST_ITERATION = process.env.CODEFLASH_TEST_ITERATION;
const TEST_MODULE = process.env.CODEFLASH_TEST_MODULE; const TEST_MODULE = process.env.CODEFLASH_TEST_MODULE;
@ -1091,6 +1091,35 @@ function captureRender(funcName, lineId, renderFn, Component, ...createElementAr
return renderResult; return renderResult;
} }
/**
* Capture a DOM snapshot after a user interaction for behavioral verification.
*
* Writes normalized `document.body.innerHTML` to the SQLite database as an
* additional row with function name `__dom_snapshot__`. The existing
* `compare-results.js` comparator compares all rows by invocation ID, so
* snapshot rows are compared automatically.
*
* @param {string} label - Unique label for this snapshot (e.g. 'after_click_1')
*/
function snapshotDOM(label) {
if (!useSqlite || !db) return;
if (typeof document === 'undefined' || !document.body) return;
// Normalize HTML: collapse whitespace, strip React-internal attributes
let html = document.body.innerHTML;
html = html.replace(/\s+/g, ' ').trim();
html = html.replace(/\s*data-reactroot\s*/g, '');
const { testModulePath, testClassName, testFunctionName } = _getTestContext();
const invocationId = `snapshot_${label}`;
recordResult(
testModulePath, testClassName, testFunctionName,
'__dom_snapshot__', invocationId,
[label], html, null, 0
);
}
/** /**
* Capture a React component render call for PERFORMANCE benchmarking only. * Capture a React component render call for PERFORMANCE benchmarking only.
* *
@ -1334,6 +1363,7 @@ module.exports = {
capturePerf, // Performance benchmarking (prints to stdout only) capturePerf, // Performance benchmarking (prints to stdout only)
captureRender, // React render behavior verification (writes to SQLite) captureRender, // React render behavior verification (writes to SQLite)
captureRenderPerf, // React render performance benchmarking (prints to stdout only) captureRenderPerf, // React render performance benchmarking (prints to stdout only)
snapshotDOM, // DOM snapshot after interaction (behavior verification)
captureMultiple, captureMultiple,
writeResults, writeResults,
clearResults, clearResults,

View file

@ -167,9 +167,13 @@ function compareResults(originalResults, candidateResults) {
if (!isEqual) { if (!isEqual) {
allEquivalent = false; allEquivalent = false;
// Use dom_snapshot scope for __dom_snapshot__ rows
const scope = original.functionGettingTested === '__dom_snapshot__'
? 'dom_snapshot'
: 'return_value';
diffs.push({ diffs.push({
invocation_id: invocationId, invocation_id: invocationId,
scope: 'return_value', scope,
original: summarizeValue(originalValue), original: summarizeValue(originalValue),
candidate: summarizeValue(candidateValue), candidate: summarizeValue(candidateValue),
test_info: { test_info: {

View file

@ -36,6 +36,15 @@ export function capturePerf<T extends (...args: any[]) => any>(
...args: Parameters<T> ...args: Parameters<T>
): ReturnType<T>; ): ReturnType<T>;
/**
* Capture a DOM snapshot after a user interaction for behavioral verification.
* Writes normalized document.body.innerHTML to the SQLite database with
* function name '__dom_snapshot__' so the comparator detects DOM divergence.
*
* @param label - Unique label for this snapshot (e.g. 'after_click_1')
*/
export function snapshotDOM(label: string): void;
/** /**
* Capture multiple invocations for benchmarking. * Capture multiple invocations for benchmarking.
* *
@ -126,6 +135,7 @@ export const TEST_ITERATION: string;
declare const codeflash: { declare const codeflash: {
capture: typeof capture; capture: typeof capture;
capturePerf: typeof capturePerf; capturePerf: typeof capturePerf;
snapshotDOM: typeof snapshotDOM;
captureMultiple: typeof captureMultiple; captureMultiple: typeof captureMultiple;
writeResults: typeof writeResults; writeResults: typeof writeResults;
clearResults: typeof clearResults; clearResults: typeof clearResults;

View file

@ -38,6 +38,7 @@ module.exports = {
captureRender: capture.captureRender, captureRender: capture.captureRender,
captureRenderPerf: capture.captureRenderPerf, captureRenderPerf: capture.captureRenderPerf,
snapshotDOM: capture.snapshotDOM,
captureMultiple: capture.captureMultiple, captureMultiple: capture.captureMultiple,

View file

@ -173,6 +173,8 @@ class TestRenderEfficiencyCritic:
optimized_render_duration=10.0, optimized_render_duration=10.0,
original_update_render_count=8, original_update_render_count=8,
optimized_update_render_count=8, optimized_update_render_count=8,
original_update_duration=80.0,
optimized_update_duration=5.0,
) is True ) is True
def test_rejects_worse_than_best(self): def test_rejects_worse_than_best(self):
@ -448,6 +450,8 @@ class TestRenderEfficiencyCriticTrustDuration:
optimized_render_duration=10.0, optimized_render_duration=10.0,
original_update_render_count=8, original_update_render_count=8,
optimized_update_render_count=8, optimized_update_render_count=8,
original_update_duration=100.0,
optimized_update_duration=10.0,
trust_duration=True, trust_duration=True,
) is True ) is True

View file

@ -0,0 +1,78 @@
"""Tests for inject_dom_snapshot_calls in React behavioral verification."""
from __future__ import annotations
from codeflash.languages.javascript.frameworks.react.testgen import inject_dom_snapshot_calls
def test_injects_after_fireEvent():
source = """\
import codeflash from 'codeflash';
const { container } = codeflash.captureRender('Comp', '1', render, Comp);
fireEvent.click(screen.getByText('Add'));
fireEvent.change(input, { target: { value: 'hi' } });
"""
result = inject_dom_snapshot_calls(source)
assert "codeflash.snapshotDOM('after_click_1');" in result
assert "codeflash.snapshotDOM('after_change_1');" in result
def test_skips_perf_mode():
source = """\
import codeflash from 'codeflash';
const result = await codeflash.captureRenderPerf('Comp', '1', render, Comp);
fireEvent.click(screen.getByText('Add'));
"""
result = inject_dom_snapshot_calls(source)
assert "snapshotDOM" not in result
def test_skips_without_captureRender():
source = """\
import codeflash from 'codeflash';
fireEvent.click(screen.getByText('Add'));
"""
result = inject_dom_snapshot_calls(source)
assert "snapshotDOM" not in result
def test_preserves_indentation():
source = """\
import codeflash from 'codeflash';
const { container } = codeflash.captureRender('Comp', '1', render, Comp);
fireEvent.click(btn);
fireEvent.change(input, { target: { value: 'x' } });
"""
result = inject_dom_snapshot_calls(source)
lines = result.split("\n")
# Find the snapshot lines and check their indentation
snapshot_lines = [l for l in lines if "snapshotDOM" in l]
assert len(snapshot_lines) == 2
assert snapshot_lines[0].startswith(" codeflash.snapshotDOM")
assert snapshot_lines[1].startswith(" codeflash.snapshotDOM")
def test_no_semicolons():
"""Real-world projects (e.g. Zustand) don't use semicolons."""
source = """\
const { container } = codeflash.captureRender('Comp', '1', render, Comp)
fireEvent.click(screen.getByText('button'))
fireEvent.click(screen.getByTestId('test-shallow'))
"""
result = inject_dom_snapshot_calls(source)
assert "after_click_1" in result
assert "after_click_2" in result
def test_sequential_counter():
source = """\
import codeflash from 'codeflash';
const { container } = codeflash.captureRender('Comp', '1', render, Comp);
fireEvent.click(btn);
fireEvent.click(btn);
fireEvent.click(btn);
"""
result = inject_dom_snapshot_calls(source)
assert "after_click_1" in result
assert "after_click_2" in result
assert "after_click_3" in result