snapshot comparison in verification

This commit is contained in:
Sarthak Agarwal 2026-04-07 20:11:54 +05:30
parent 4bc89f2b9d
commit b5ba6df690
16 changed files with 505 additions and 39 deletions

View file

@ -46,6 +46,8 @@ class AiServiceClient:
self.llm_call_counter = count(1)
self.is_local = self.base_url == "http://localhost:8000"
self.timeout: float | None = 300 if self.is_local else 90
# React components are larger (300+ lines of JSX) and need more LLM processing time
self.react_timeout: float | None = 300 if self.is_local else 180
def get_next_sequence(self) -> int:
"""Get the next LLM call sequence number."""
@ -203,7 +205,8 @@ class AiServiceClient:
logger.debug(f"Sending optimize request: trace_id={trace_id}, n_candidates={payload['n_candidates']}")
try:
response = self.make_ai_service_request("/optimize", payload=payload, timeout=self.timeout)
timeout = self.react_timeout if is_react_component else self.timeout
response = self.make_ai_service_request("/optimize", payload=payload, timeout=timeout)
except requests.exceptions.RequestException as e:
logger.exception(f"Error generating optimized candidates: {e}")
ph("cli-optimize-error-caught", {"error": str(e)})
@ -806,7 +809,8 @@ class AiServiceClient:
# DEBUG: Print payload language field
logger.debug(f"Sending testgen request with language='{payload['language']}', framework='{test_framework}'")
try:
response = self.make_ai_service_request("/testgen", payload=payload, timeout=self.timeout)
timeout = self.react_timeout if is_react_component else self.timeout
response = self.make_ai_service_request("/testgen", payload=payload, timeout=timeout)
except requests.exceptions.RequestException as e:
logger.exception(f"Error generating tests: {e}")
ph("cli-testgen-error-caught", {"error": str(e)})

View file

@ -143,6 +143,8 @@ def compare_test_results(
scope = TestDiffScope.STDOUT
elif scope_str == "did_pass":
scope = TestDiffScope.DID_PASS
elif scope_str == "dom_snapshot":
scope = TestDiffScope.DOM_SNAPSHOT
test_info = diff.get("test_info", {})
# Build a test identifier string for JavaScript tests

View file

@ -210,7 +210,10 @@ def normalize_codeflash_imports(source: str) -> str:
# Replace CommonJS require
source = _CODEFLASH_REQUIRE_PATTERN.sub(r"\1 \2 = require('codeflash')", source)
# Replace ES module import
return _CODEFLASH_IMPORT_PATTERN.sub(r"import \1 from 'codeflash'", source)
source = _CODEFLASH_IMPORT_PATTERN.sub(r"import \1 from 'codeflash'", source)
# Fix outdated @testing-library/jest-dom import paths (v6+ removed /extend-expect subpath)
source = source.replace("@testing-library/jest-dom/extend-expect", "@testing-library/jest-dom")
return source
# Author: ali <mohammed18200118@gmail.com>

View file

@ -72,42 +72,37 @@ def instrument_component_with_profiler(source: str, component_name: str, analyze
if computed is not None:
replacements.append(computed)
# Work entirely in bytes to avoid byte-offset vs char-index mismatches
# (tree-sitter returns byte offsets; non-ASCII chars cause divergence from Python string indices)
if replacements:
# Reconstruct result in a single pass
parts: list[str] = []
byte_parts: list[bytes] = []
prev = 0
for start, end, wrapped in replacements:
# Use original source slices (string indices expected by original logic)
parts.append(source[prev:start])
parts.append(wrapped)
byte_parts.append(source_bytes[prev:start])
byte_parts.append(wrapped.encode("utf-8"))
prev = end
parts.append(source[prev:])
result = "".join(parts)
byte_parts.append(source_bytes[prev:])
result_bytes = b"".join(byte_parts)
else:
result = source
# Add render counter code at the top (after imports) using the already-parsed tree
result_bytes = source_bytes
# Add render counter code at the top (after imports)
counter_code = generate_render_counter_code(component_name)
# Inline logic similar to _insert_after_imports but reuse existing `tree` to avoid re-parsing
last_import_end = 0
for child in tree.root_node.children:
if child.type == "import_statement":
last_import_end = child.end_byte
insert_pos = last_import_end
while insert_pos < len(result) and result[insert_pos] != "\n":
while insert_pos < len(result_bytes) and result_bytes[insert_pos : insert_pos + 1] != b"\n":
insert_pos += 1
if insert_pos < len(result):
if insert_pos < len(result_bytes):
insert_pos += 1 # skip the newline
result = result[:insert_pos] + "\n" + counter_code + "\n\n" + result[insert_pos:]
counter_bytes = ("\n" + counter_code + "\n\n").encode("utf-8")
result = (result_bytes[:insert_pos] + counter_bytes + result_bytes[insert_pos:]).decode("utf-8")
# Ensure React is imported
# Ensure React is imported
return _ensure_react_import(result)
@ -317,7 +312,8 @@ def _wrap_return_with_profiler(source: str, return_node: Node, profiler_id: str,
f"</React.Profiler>"
)
return source[:jsx_start] + wrapped + source[jsx_end:]
# Use byte-level splicing to avoid byte-offset vs char-index mismatches
return (source_bytes[:jsx_start] + wrapped.encode("utf-8") + source_bytes[jsx_end:]).decode("utf-8")
def _insert_after_imports(source: str, code: str, analyzer: TreeSitterAnalyzer) -> str:
@ -421,7 +417,7 @@ let _codeflash_render_count_{safe_name} = 0;
if (typeof beforeEach !== 'undefined') {{
beforeEach(() => {{ _codeflash_render_count_{safe_name} = 0; }});
}}
function _codeflashOnRender_{safe_name}(id, phase, actualDuration, baseDuration) {{
function _codeflashOnRender_{safe_name}(id: any, phase: any, actualDuration: any, baseDuration: any) {{
_codeflash_render_count_{safe_name}++;
console.log(`!######{marker_prefix}:${{id}}:${{phase}}:${{actualDuration}}:${{baseDuration}}:${{_codeflash_render_count_{safe_name}}}######!`);
}}"""

View file

@ -98,6 +98,15 @@ def post_process_react_tests(test_source: str, component_info: ReactComponentInf
"""
result = test_source
# Fix outdated @testing-library/jest-dom import paths (v6+ removed /extend-expect subpath)
result = result.replace("@testing-library/jest-dom/extend-expect", "@testing-library/jest-dom")
# Also fix require() variant
result = re.sub(
r"""require\s*\(\s*['"]@testing-library/jest-dom/extend-expect['"]\s*\)""",
"require('@testing-library/jest-dom')",
result,
)
# Ensure testing-library import
if "@testing-library/react" not in result:
result = "import { render, screen, act } from '@testing-library/react';\n" + result
@ -168,14 +177,14 @@ def post_process_react_tests(test_source: str, component_info: ReactComponentInf
# This gives per-interaction A/B signal without the LLM needing to know about it.
result = inject_interaction_markers(result)
# Warn if no tests contain interaction calls — mount-phase only markers are
# not useful for measuring optimization effectiveness.
# If no tests contain interaction calls, auto-inject a rerender fallback so
# that EVERY React perf test produces at least one update-phase marker.
if not has_react_test_interactions(result):
logger.warning(
"[REACT] Generated tests for %s contain no interactions (fireEvent, userEvent, rerender). "
"Tests will produce only mount-phase markers which cannot measure optimization improvements.",
"[REACT] Generated tests for %s contain no interactions — auto-injecting rerender fallback.",
component_info.function_name,
)
result = _inject_rerender_fallback(result, component_info.function_name)
# Check interaction density — fewer than MIN_INTERACTION_CALLS total interactions
# means the test is unlikely to produce enough update-phase renders for reliable measurement.
@ -221,6 +230,48 @@ def _extract_interaction_label(call_text: str) -> str:
return m.group(1) if m else "interaction"
def inject_dom_snapshot_calls(test_source: str) -> str:
"""Inject codeflash.snapshotDOM() calls after each user interaction in behavior mode.
Only active when `captureRender` (not `captureRenderPerf`) is present,
meaning the test is running in behavioral verification mode.
After each fireEvent.*, userEvent.*, or rerender() call, inserts:
codeflash.snapshotDOM('after_{label}_{n}');
on the next line with matching indentation.
"""
if "captureRender" not in test_source:
return test_source
if "captureRenderPerf" in test_source:
return test_source
interaction_counter: dict[str, int] = {}
lines = test_source.split("\n")
new_lines: list[str] = []
# Also match rerender() calls — use .* to handle nested parens like getByText('Add')
snapshot_interaction_pattern = re.compile(
r"^(\s*)((?:await\s+)?(?:fireEvent\.\w+|userEvent\.\w+|(?:\w+\.)?rerender)\s*\(.*\))\s*;?\s*$",
re.MULTILINE,
)
for line in lines:
new_lines.append(line)
m = snapshot_interaction_pattern.match(line)
if m:
indent = m.group(1)
call_text = m.group(2)
label = _extract_interaction_label(call_text)
if label == "interaction":
# rerender() call
label = "rerender"
interaction_counter[label] = interaction_counter.get(label, 0) + 1
unique_label = f"{label}_{interaction_counter[label]}"
new_lines.append(f"{indent}codeflash.snapshotDOM('after_{unique_label}');")
return "\n".join(new_lines)
def inject_interaction_markers(test_source: str) -> str:
"""Inject _codeflashMarkInteraction() calls before each fireEvent/userEvent call.
@ -325,3 +376,54 @@ def has_high_density_interactions(test_source: str) -> bool:
interaction_calls = _INTERACTION_PATTERNS.findall(test_source)
return len(interaction_calls) >= _MIN_SEQUENTIAL_INTERACTIONS
# Pattern to extract props from the first render(<Component ...props... />) call
_RENDER_CALL_PROPS_PATTERN = re.compile(
r"render\s*\(\s*<(\w+)\s+([^/]*?)\s*/?\s*>",
)
def _extract_render_props(test_source: str, component_name: str) -> str | None:
"""Extract the props expression from the first render(<Component ...>) call."""
for m in _RENDER_CALL_PROPS_PATTERN.finditer(test_source):
if m.group(1) == component_name:
props_text = m.group(2).strip()
if props_text:
return props_text
return None
def _inject_rerender_fallback(test_source: str, component_name: str) -> str:
"""Inject a rerender efficiency test block when the test has no interactions.
This ensures every React perf test produces at least one update-phase marker.
"""
# Try to extract props from existing render call
props_expr = _extract_render_props(test_source, component_name)
if props_expr:
jsx_open = f"<{component_name} {props_expr} />"
else:
jsx_open = f"<{component_name} />"
rerender_block = f"""
describe('{component_name} rerender efficiency (auto-generated)', () => {{
it('should handle same-props rerenders', () => {{
const {{ rerender }} = render({jsx_open});
for (let i = 0; i < 10; i++) {{
rerender({jsx_open});
}}
}});
}});
"""
# Ensure {rerender} is in the @testing-library/react import
rtl_import_match = re.search(
r"import\s*\{([^}]+)\}\s*from\s*['\"]@testing-library/react['\"]", test_source
)
if rtl_import_match:
imports = rtl_import_match.group(1)
if "rerender" not in imports:
# Don't add 'rerender' to import — it comes from render() return value, not an import
pass
return test_source.rstrip() + "\n" + rerender_block

View file

@ -1346,6 +1346,13 @@ def _instrument_js_test_code(
code=code, function_to_optimize=function_to_optimize, capture_func=capture_func, start_counter=final_counter
)
# In behavior mode, inject DOM snapshot calls after each interaction so the
# comparator can detect post-interaction DOM divergence between original and candidate.
if mode == TestingMode.BEHAVIOR:
from codeflash.languages.javascript.frameworks.react.testgen import inject_dom_snapshot_calls # noqa: PLC0415
code = inject_dom_snapshot_calls(code)
return code

View file

@ -174,6 +174,14 @@ class JavaScriptSupport:
logger.debug(f"Skipping non-exported function: {func.name}") # noqa: G004
continue
# Skip closures inside React components — they capture hook state
# and cannot be tested independently
if func.parent_function and func.parent_function in react_component_map:
logger.debug(
f"Skipping closure {func.name} inside React component {func.parent_function}" # noqa: G004
)
continue
# Build parents list
parents: list[FunctionParent] = []
if func.class_name:
@ -231,8 +239,23 @@ class JavaScriptSupport:
source, include_methods=True, include_arrow_functions=True, require_name=True
)
# Build React component lookup to filter out closures
react_component_names: set[str] = set()
try:
from codeflash.languages.javascript.frameworks.react.discovery import classify_component # noqa: PLC0415
for func in tree_functions:
if classify_component(func, source, analyzer) is not None:
react_component_names.add(func.name)
except Exception:
pass
functions: list[FunctionToOptimize] = []
for func in tree_functions:
# Skip closures inside React components
if func.parent_function and func.parent_function in react_component_names:
continue
# Build parents list
parents: list[FunctionParent] = []
if func.class_name:
@ -1212,6 +1235,47 @@ class JavaScriptSupport:
# === Code Transformation ===
@staticmethod
def _strip_imports(source: str, analyzer: TreeSitterAnalyzer) -> str:
"""Strip import/require statements from source code.
Used to remove duplicate imports from optimizer output before inserting
into the original file (which already has its own imports).
"""
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
# Collect byte ranges of import nodes to remove
ranges_to_remove: list[tuple[int, int]] = []
for child in tree.root_node.children:
if child.type == "import_statement":
ranges_to_remove.append((child.start_byte, child.end_byte))
# CJS require: const x = require('...')
elif child.type in ("lexical_declaration", "variable_declaration"):
text = source_bytes[child.start_byte : child.end_byte].decode("utf8")
if "require(" in text:
ranges_to_remove.append((child.start_byte, child.end_byte))
if not ranges_to_remove:
return source
# Build result, skipping removed ranges
parts: list[bytes] = []
prev_end = 0
for start, end in sorted(ranges_to_remove):
parts.append(source_bytes[prev_end:start])
# Skip trailing newline after import
if end < len(source_bytes) and source_bytes[end : end + 1] == b"\n":
end += 1
prev_end = end
parts.append(source_bytes[prev_end:])
result = b"".join(parts).decode("utf8")
# Remove leading blank lines left by stripping imports
while result.startswith("\n"):
result = result[1:]
return result
def replace_function(self, source: str, function: FunctionToOptimize, new_source: str) -> str:
"""Replace a function in source code with new implementation.
@ -1247,6 +1311,15 @@ class JavaScriptSupport:
else:
analyzer = TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
# Strip imports from optimizer output — the original file already has them,
# and import merging is handled by _add_global_declarations_for_language()
new_source = self._strip_imports(new_source, analyzer)
# If stripping imports left nothing, return original
if not new_source.strip():
logger.warning("new_source was only imports for %s, returning original", function.function_name)
return source
# Check if new_source contains a JSDoc comment - if so, use full replacement
# to include the updated JSDoc along with the function body
stripped_new_source = new_source.strip()
@ -1263,16 +1336,64 @@ class JavaScriptSupport:
logger.warning(
"Could not extract body for %s from optimized code, using full replacement", function.function_name
)
# Verify that new_source contains actual code before falling back to text replacement
# This prevents deletion of the original function when new_source is invalid
if not self._contains_function_declaration(new_source, function.function_name, analyzer):
logger.warning("new_source does not contain function %s, returning original", function.function_name)
return source
if self._contains_function_declaration(new_source, function.function_name, analyzer):
return self._replace_function_text_based(source, function, new_source, analyzer)
# Final fallback: line-range replacement using the function's known line boundaries.
# This handles cases where tree-sitter can't parse the optimized output.
logger.warning(
"Falling back to line-range replacement for %s (lines %d-%d)",
function.function_name,
function.starting_line,
function.ending_line,
)
return self._replace_function_by_line_range(source, function, new_source)
# Find the original function and replace its body
return self._replace_function_body(source, function, new_body, analyzer)
def _replace_function_by_line_range(
self, source: str, function: FunctionToOptimize, new_source: str
) -> str:
"""Last-resort replacement: cut the original function by line range and paste new_source."""
if function.starting_line is None or function.ending_line is None:
return source
lines = source.splitlines(keepends=True)
if lines and not lines[-1].endswith("\n"):
lines[-1] += "\n"
# Get indentation from original function's first line
if function.starting_line <= len(lines):
original_first = lines[function.starting_line - 1]
original_indent = len(original_first) - len(original_first.lstrip())
else:
original_indent = 0
new_lines = new_source.splitlines(keepends=True)
if new_lines:
first_non_blank = next((l for l in new_lines if l.strip()), new_lines[0])
new_indent = len(first_non_blank) - len(first_non_blank.lstrip())
indent_diff = original_indent - new_indent
if indent_diff != 0:
adjusted = []
for line in new_lines:
if line.strip():
if indent_diff > 0:
adjusted.append(" " * indent_diff + line)
else:
cur = len(line) - len(line.lstrip())
adjusted.append(line[min(cur, abs(indent_diff)) :])
else:
adjusted.append(line)
new_lines = adjusted
if new_lines and not new_lines[-1].endswith("\n"):
new_lines[-1] += "\n"
before = lines[: function.starting_line - 1]
after = lines[function.ending_line :]
return "".join(before + new_lines + after)
def _contains_function_declaration(self, source: str, function_name: str, analyzer: TreeSitterAnalyzer) -> bool:
"""Check if source contains a function declaration with the given name.
@ -1912,7 +2033,52 @@ class JavaScriptSupport:
generated_tests = inject_test_globals(generated_tests, test_framework)
if is_typescript():
generated_tests = disable_ts_check(generated_tests)
return normalize_generated_tests_imports(generated_tests)
generated_tests = normalize_generated_tests_imports(generated_tests)
# Apply React-specific post-processing (act wrapping, interaction markers, etc.)
if self._is_react_source(source_file_path):
generated_tests = self._apply_react_postprocessing(generated_tests, source_file_path)
return generated_tests
def _is_react_source(self, source_file_path: Path) -> bool:
"""Check if the source file is a React component."""
try:
content = source_file_path.read_text("utf-8", errors="replace")
return "react" in content.lower() and ("jsx" in content.lower() or source_file_path.suffix in (".tsx", ".jsx"))
except OSError:
return False
def _apply_react_postprocessing(self, generated_tests: GeneratedTestsList, source_file_path: Path) -> GeneratedTestsList:
"""Apply React-specific post-processing to generated tests."""
from codeflash.languages.javascript.frameworks.react.testgen import ( # noqa: PLC0415
post_process_react_tests,
)
from codeflash.languages.javascript.frameworks.react.discovery import ComponentType, ReactComponentInfo # noqa: PLC0415
from codeflash.models.models import GeneratedTests, GeneratedTestsList # noqa: PLC0415
component_name = source_file_path.stem
dummy_info = ReactComponentInfo(
function_name=component_name,
component_type=ComponentType.FUNCTION,
)
updated_tests = []
for test in generated_tests.generated_tests:
updated_tests.append(
GeneratedTests(
generated_original_test_source=test.generated_original_test_source,
instrumented_behavior_test_source=post_process_react_tests(
test.instrumented_behavior_test_source, dummy_info
),
instrumented_perf_test_source=post_process_react_tests(
test.instrumented_perf_test_source, dummy_info
),
behavior_file_path=test.behavior_file_path,
perf_file_path=test.perf_file_path,
)
)
return GeneratedTestsList(generated_tests=updated_tests)
def remove_test_functions_from_generated_tests(
self, generated_tests: GeneratedTestsList, functions_to_remove: list[str]

View file

@ -401,6 +401,29 @@ class TreeSitterAnalyzer:
great_grandparent = grandparent.parent
if great_grandparent and great_grandparent.type == "export_statement":
return True
# HOC wrappers: export const Comp = forwardRef/memo(function/arrow)
# Tree: function → arguments → call_expression → variable_declarator → lexical_declaration → export_statement
if parent and parent.type == "arguments":
call_expr = parent.parent
if call_expr and call_expr.type == "call_expression":
declarator = call_expr.parent
if declarator and declarator.type == "variable_declarator":
lex_decl = declarator.parent
if lex_decl and lex_decl.type in ("lexical_declaration", "variable_declaration"):
export_stmt = lex_decl.parent
if export_stmt and export_stmt.type == "export_statement":
return True
# Nested HOC: export const X = memo(forwardRef(fn))
if declarator and declarator.type == "arguments":
outer_call = declarator.parent
if outer_call and outer_call.type == "call_expression":
outer_decl = outer_call.parent
if outer_decl and outer_decl.type == "variable_declarator":
lex_decl = outer_decl.parent
if lex_decl and lex_decl.type in ("lexical_declaration", "variable_declaration"):
export_stmt = lex_decl.parent
if export_stmt and export_stmt.type == "export_statement":
return True
# For methods in exported classes
if node.type == "method_definition":
@ -609,6 +632,8 @@ class TreeSitterAnalyzer:
return None
_HOC_WRAPPER_NAMES = frozenset({"forwardRef", "memo", "React.forwardRef", "React.memo"})
def _get_name_from_assignment(self, node: Node, source_bytes: bytes) -> str:
"""Try to extract function name from parent variable declaration or assignment.
@ -617,6 +642,8 @@ class TreeSitterAnalyzer:
- const foo = function() {}
- let bar = function() {}
- obj.method = () => {}
- const Tooltip = forwardRef(function TooltipInner(...) {...})
- const MemoComp = React.memo((props) => {...})
"""
parent = node.parent
if parent is None:
@ -628,6 +655,30 @@ class TreeSitterAnalyzer:
if name_node:
return self.get_node_text(name_node, source_bytes)
# Check for HOC wrapper: const Name = forwardRef/memo(function/arrow)
# Tree structure: function → arguments → call_expression → variable_declarator
if parent.type == "arguments":
call_expr = parent.parent
if call_expr is not None and call_expr.type == "call_expression":
callee = call_expr.child_by_field_name("function")
if callee is not None:
callee_text = self.get_node_text(callee, source_bytes)
if callee_text in self._HOC_WRAPPER_NAMES:
declarator = call_expr.parent
if declarator is not None and declarator.type == "variable_declarator":
name_node = declarator.child_by_field_name("name")
if name_node:
return self.get_node_text(name_node, source_bytes)
# Nested HOC: memo(forwardRef(fn)) — call_expr parent is arguments of outer call
if declarator is not None and declarator.type == "arguments":
outer_call = declarator.parent
if outer_call is not None and outer_call.type == "call_expression":
outer_declarator = outer_call.parent
if outer_declarator is not None and outer_declarator.type == "variable_declarator":
name_node = outer_declarator.child_by_field_name("name")
if name_node:
return self.get_node_text(name_node, source_bytes)
# Check for assignment expression: foo = ...
if parent.type == "assignment_expression":
left_node = parent.child_by_field_name("left")

View file

@ -83,6 +83,7 @@ class TestDiffScope(str, Enum):
RETURN_VALUE = "return_value"
STDOUT = "stdout"
DID_PASS = "did_pass" # noqa: S105
DOM_SNAPSHOT = "dom_snapshot"
@dataclass

View file

@ -96,6 +96,7 @@ from codeflash.models.models import (
OptimizedCandidateResult,
OptimizedCandidateSource,
OriginalCodeBaseline,
TestDiffScope,
TestFile,
TestFiles,
TestingMode,
@ -2955,6 +2956,13 @@ class FunctionOptimizer:
logger.info("h3|Test results matched ✅")
console.rule()
else:
dom_snapshot_diffs = [d for d in diffs if d.scope == TestDiffScope.DOM_SNAPSHOT]
if dom_snapshot_diffs:
logger.warning(
"[REACT] DOM snapshot divergence detected after %d interaction(s). "
"The optimized component produces different DOM output.",
len(dom_snapshot_diffs),
)
self.repair_if_possible(
candidate, diffs, eval_ctx, code_context, len(candidate_behavior_results), exp_type
)
@ -3178,7 +3186,6 @@ class FunctionOptimizer:
coverage_config_file=coverage_config_file,
skip_sqlite_cleanup=skip_cleanup,
)
if testing_type == TestingMode.PERFORMANCE:
results.perf_stdout = run_result.stdout
return results, coverage_results
# For LINE_PROFILE mode, Python uses .lprof files while JavaScript uses JSON

View file

@ -46,11 +46,11 @@ function _getReact() {
}
}
// Try to load better-sqlite3, fall back to JSON if not available
let useSqlite = false;
// Configuration from environment
const OUTPUT_FILE = process.env.CODEFLASH_OUTPUT_FILE;
// Enable SQLite when better-sqlite3 is available and an output file is configured
let useSqlite = !!(Database && OUTPUT_FILE);
const LOOP_INDEX = parseInt(process.env.CODEFLASH_LOOP_INDEX || '1', 10);
const TEST_ITERATION = process.env.CODEFLASH_TEST_ITERATION;
const TEST_MODULE = process.env.CODEFLASH_TEST_MODULE;
@ -1091,6 +1091,35 @@ function captureRender(funcName, lineId, renderFn, Component, ...createElementAr
return renderResult;
}
/**
* Capture a DOM snapshot after a user interaction for behavioral verification.
*
* Writes normalized `document.body.innerHTML` to the SQLite database as an
* additional row with function name `__dom_snapshot__`. The existing
* `compare-results.js` comparator compares all rows by invocation ID, so
* snapshot rows are compared automatically.
*
* @param {string} label - Unique label for this snapshot (e.g. 'after_click_1')
*/
function snapshotDOM(label) {
if (!useSqlite || !db) return;
if (typeof document === 'undefined' || !document.body) return;
// Normalize HTML: collapse whitespace, strip React-internal attributes
let html = document.body.innerHTML;
html = html.replace(/\s+/g, ' ').trim();
html = html.replace(/\s*data-reactroot\s*/g, '');
const { testModulePath, testClassName, testFunctionName } = _getTestContext();
const invocationId = `snapshot_${label}`;
recordResult(
testModulePath, testClassName, testFunctionName,
'__dom_snapshot__', invocationId,
[label], html, null, 0
);
}
/**
* Capture a React component render call for PERFORMANCE benchmarking only.
*
@ -1334,6 +1363,7 @@ module.exports = {
capturePerf, // Performance benchmarking (prints to stdout only)
captureRender, // React render behavior verification (writes to SQLite)
captureRenderPerf, // React render performance benchmarking (prints to stdout only)
snapshotDOM, // DOM snapshot after interaction (behavior verification)
captureMultiple,
writeResults,
clearResults,

View file

@ -167,9 +167,13 @@ function compareResults(originalResults, candidateResults) {
if (!isEqual) {
allEquivalent = false;
// Use dom_snapshot scope for __dom_snapshot__ rows
const scope = original.functionGettingTested === '__dom_snapshot__'
? 'dom_snapshot'
: 'return_value';
diffs.push({
invocation_id: invocationId,
scope: 'return_value',
scope,
original: summarizeValue(originalValue),
candidate: summarizeValue(candidateValue),
test_info: {

View file

@ -36,6 +36,15 @@ export function capturePerf<T extends (...args: any[]) => any>(
...args: Parameters<T>
): ReturnType<T>;
/**
* Capture a DOM snapshot after a user interaction for behavioral verification.
* Writes normalized document.body.innerHTML to the SQLite database with
* function name '__dom_snapshot__' so the comparator detects DOM divergence.
*
* @param label - Unique label for this snapshot (e.g. 'after_click_1')
*/
export function snapshotDOM(label: string): void;
/**
* Capture multiple invocations for benchmarking.
*
@ -126,6 +135,7 @@ export const TEST_ITERATION: string;
declare const codeflash: {
capture: typeof capture;
capturePerf: typeof capturePerf;
snapshotDOM: typeof snapshotDOM;
captureMultiple: typeof captureMultiple;
writeResults: typeof writeResults;
clearResults: typeof clearResults;

View file

@ -38,6 +38,7 @@ module.exports = {
captureRender: capture.captureRender,
captureRenderPerf: capture.captureRenderPerf,
snapshotDOM: capture.snapshotDOM,
captureMultiple: capture.captureMultiple,

View file

@ -173,6 +173,8 @@ class TestRenderEfficiencyCritic:
optimized_render_duration=10.0,
original_update_render_count=8,
optimized_update_render_count=8,
original_update_duration=80.0,
optimized_update_duration=5.0,
) is True
def test_rejects_worse_than_best(self):
@ -448,6 +450,8 @@ class TestRenderEfficiencyCriticTrustDuration:
optimized_render_duration=10.0,
original_update_render_count=8,
optimized_update_render_count=8,
original_update_duration=100.0,
optimized_update_duration=10.0,
trust_duration=True,
) is True

View file

@ -0,0 +1,78 @@
"""Tests for inject_dom_snapshot_calls in React behavioral verification."""
from __future__ import annotations
from codeflash.languages.javascript.frameworks.react.testgen import inject_dom_snapshot_calls
def test_injects_after_fireEvent():
source = """\
import codeflash from 'codeflash';
const { container } = codeflash.captureRender('Comp', '1', render, Comp);
fireEvent.click(screen.getByText('Add'));
fireEvent.change(input, { target: { value: 'hi' } });
"""
result = inject_dom_snapshot_calls(source)
assert "codeflash.snapshotDOM('after_click_1');" in result
assert "codeflash.snapshotDOM('after_change_1');" in result
def test_skips_perf_mode():
source = """\
import codeflash from 'codeflash';
const result = await codeflash.captureRenderPerf('Comp', '1', render, Comp);
fireEvent.click(screen.getByText('Add'));
"""
result = inject_dom_snapshot_calls(source)
assert "snapshotDOM" not in result
def test_skips_without_captureRender():
source = """\
import codeflash from 'codeflash';
fireEvent.click(screen.getByText('Add'));
"""
result = inject_dom_snapshot_calls(source)
assert "snapshotDOM" not in result
def test_preserves_indentation():
source = """\
import codeflash from 'codeflash';
const { container } = codeflash.captureRender('Comp', '1', render, Comp);
fireEvent.click(btn);
fireEvent.change(input, { target: { value: 'x' } });
"""
result = inject_dom_snapshot_calls(source)
lines = result.split("\n")
# Find the snapshot lines and check their indentation
snapshot_lines = [l for l in lines if "snapshotDOM" in l]
assert len(snapshot_lines) == 2
assert snapshot_lines[0].startswith(" codeflash.snapshotDOM")
assert snapshot_lines[1].startswith(" codeflash.snapshotDOM")
def test_no_semicolons():
"""Real-world projects (e.g. Zustand) don't use semicolons."""
source = """\
const { container } = codeflash.captureRender('Comp', '1', render, Comp)
fireEvent.click(screen.getByText('button'))
fireEvent.click(screen.getByTestId('test-shallow'))
"""
result = inject_dom_snapshot_calls(source)
assert "after_click_1" in result
assert "after_click_2" in result
def test_sequential_counter():
source = """\
import codeflash from 'codeflash';
const { container } = codeflash.captureRender('Comp', '1', render, Comp);
fireEvent.click(btn);
fireEvent.click(btn);
fireEvent.click(btn);
"""
result = inject_dom_snapshot_calls(source)
assert "after_click_1" in result
assert "after_click_2" in result
assert "after_click_3" in result