mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
snapshot comparison in verification
This commit is contained in:
parent
4bc89f2b9d
commit
b5ba6df690
16 changed files with 505 additions and 39 deletions
|
|
@ -46,6 +46,8 @@ class AiServiceClient:
|
|||
self.llm_call_counter = count(1)
|
||||
self.is_local = self.base_url == "http://localhost:8000"
|
||||
self.timeout: float | None = 300 if self.is_local else 90
|
||||
# React components are larger (300+ lines of JSX) and need more LLM processing time
|
||||
self.react_timeout: float | None = 300 if self.is_local else 180
|
||||
|
||||
def get_next_sequence(self) -> int:
|
||||
"""Get the next LLM call sequence number."""
|
||||
|
|
@ -203,7 +205,8 @@ class AiServiceClient:
|
|||
logger.debug(f"Sending optimize request: trace_id={trace_id}, n_candidates={payload['n_candidates']}")
|
||||
|
||||
try:
|
||||
response = self.make_ai_service_request("/optimize", payload=payload, timeout=self.timeout)
|
||||
timeout = self.react_timeout if is_react_component else self.timeout
|
||||
response = self.make_ai_service_request("/optimize", payload=payload, timeout=timeout)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.exception(f"Error generating optimized candidates: {e}")
|
||||
ph("cli-optimize-error-caught", {"error": str(e)})
|
||||
|
|
@ -806,7 +809,8 @@ class AiServiceClient:
|
|||
# DEBUG: Print payload language field
|
||||
logger.debug(f"Sending testgen request with language='{payload['language']}', framework='{test_framework}'")
|
||||
try:
|
||||
response = self.make_ai_service_request("/testgen", payload=payload, timeout=self.timeout)
|
||||
timeout = self.react_timeout if is_react_component else self.timeout
|
||||
response = self.make_ai_service_request("/testgen", payload=payload, timeout=timeout)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.exception(f"Error generating tests: {e}")
|
||||
ph("cli-testgen-error-caught", {"error": str(e)})
|
||||
|
|
|
|||
|
|
@ -143,6 +143,8 @@ def compare_test_results(
|
|||
scope = TestDiffScope.STDOUT
|
||||
elif scope_str == "did_pass":
|
||||
scope = TestDiffScope.DID_PASS
|
||||
elif scope_str == "dom_snapshot":
|
||||
scope = TestDiffScope.DOM_SNAPSHOT
|
||||
|
||||
test_info = diff.get("test_info", {})
|
||||
# Build a test identifier string for JavaScript tests
|
||||
|
|
|
|||
|
|
@ -210,7 +210,10 @@ def normalize_codeflash_imports(source: str) -> str:
|
|||
# Replace CommonJS require
|
||||
source = _CODEFLASH_REQUIRE_PATTERN.sub(r"\1 \2 = require('codeflash')", source)
|
||||
# Replace ES module import
|
||||
return _CODEFLASH_IMPORT_PATTERN.sub(r"import \1 from 'codeflash'", source)
|
||||
source = _CODEFLASH_IMPORT_PATTERN.sub(r"import \1 from 'codeflash'", source)
|
||||
# Fix outdated @testing-library/jest-dom import paths (v6+ removed /extend-expect subpath)
|
||||
source = source.replace("@testing-library/jest-dom/extend-expect", "@testing-library/jest-dom")
|
||||
return source
|
||||
|
||||
|
||||
# Author: ali <mohammed18200118@gmail.com>
|
||||
|
|
|
|||
|
|
@ -72,42 +72,37 @@ def instrument_component_with_profiler(source: str, component_name: str, analyze
|
|||
if computed is not None:
|
||||
replacements.append(computed)
|
||||
|
||||
# Work entirely in bytes to avoid byte-offset vs char-index mismatches
|
||||
# (tree-sitter returns byte offsets; non-ASCII chars cause divergence from Python string indices)
|
||||
if replacements:
|
||||
# Reconstruct result in a single pass
|
||||
parts: list[str] = []
|
||||
byte_parts: list[bytes] = []
|
||||
prev = 0
|
||||
for start, end, wrapped in replacements:
|
||||
# Use original source slices (string indices expected by original logic)
|
||||
parts.append(source[prev:start])
|
||||
parts.append(wrapped)
|
||||
byte_parts.append(source_bytes[prev:start])
|
||||
byte_parts.append(wrapped.encode("utf-8"))
|
||||
prev = end
|
||||
parts.append(source[prev:])
|
||||
result = "".join(parts)
|
||||
byte_parts.append(source_bytes[prev:])
|
||||
result_bytes = b"".join(byte_parts)
|
||||
else:
|
||||
result = source
|
||||
|
||||
# Add render counter code at the top (after imports) using the already-parsed tree
|
||||
result_bytes = source_bytes
|
||||
|
||||
# Add render counter code at the top (after imports)
|
||||
counter_code = generate_render_counter_code(component_name)
|
||||
|
||||
# Inline logic similar to _insert_after_imports but reuse existing `tree` to avoid re-parsing
|
||||
last_import_end = 0
|
||||
for child in tree.root_node.children:
|
||||
if child.type == "import_statement":
|
||||
last_import_end = child.end_byte
|
||||
|
||||
insert_pos = last_import_end
|
||||
while insert_pos < len(result) and result[insert_pos] != "\n":
|
||||
while insert_pos < len(result_bytes) and result_bytes[insert_pos : insert_pos + 1] != b"\n":
|
||||
insert_pos += 1
|
||||
if insert_pos < len(result):
|
||||
if insert_pos < len(result_bytes):
|
||||
insert_pos += 1 # skip the newline
|
||||
|
||||
result = result[:insert_pos] + "\n" + counter_code + "\n\n" + result[insert_pos:]
|
||||
counter_bytes = ("\n" + counter_code + "\n\n").encode("utf-8")
|
||||
result = (result_bytes[:insert_pos] + counter_bytes + result_bytes[insert_pos:]).decode("utf-8")
|
||||
|
||||
# Ensure React is imported
|
||||
|
||||
# Ensure React is imported
|
||||
return _ensure_react_import(result)
|
||||
|
||||
|
||||
|
|
@ -317,7 +312,8 @@ def _wrap_return_with_profiler(source: str, return_node: Node, profiler_id: str,
|
|||
f"</React.Profiler>"
|
||||
)
|
||||
|
||||
return source[:jsx_start] + wrapped + source[jsx_end:]
|
||||
# Use byte-level splicing to avoid byte-offset vs char-index mismatches
|
||||
return (source_bytes[:jsx_start] + wrapped.encode("utf-8") + source_bytes[jsx_end:]).decode("utf-8")
|
||||
|
||||
|
||||
def _insert_after_imports(source: str, code: str, analyzer: TreeSitterAnalyzer) -> str:
|
||||
|
|
@ -421,7 +417,7 @@ let _codeflash_render_count_{safe_name} = 0;
|
|||
if (typeof beforeEach !== 'undefined') {{
|
||||
beforeEach(() => {{ _codeflash_render_count_{safe_name} = 0; }});
|
||||
}}
|
||||
function _codeflashOnRender_{safe_name}(id, phase, actualDuration, baseDuration) {{
|
||||
function _codeflashOnRender_{safe_name}(id: any, phase: any, actualDuration: any, baseDuration: any) {{
|
||||
_codeflash_render_count_{safe_name}++;
|
||||
console.log(`!######{marker_prefix}:${{id}}:${{phase}}:${{actualDuration}}:${{baseDuration}}:${{_codeflash_render_count_{safe_name}}}######!`);
|
||||
}}"""
|
||||
|
|
|
|||
|
|
@ -98,6 +98,15 @@ def post_process_react_tests(test_source: str, component_info: ReactComponentInf
|
|||
"""
|
||||
result = test_source
|
||||
|
||||
# Fix outdated @testing-library/jest-dom import paths (v6+ removed /extend-expect subpath)
|
||||
result = result.replace("@testing-library/jest-dom/extend-expect", "@testing-library/jest-dom")
|
||||
# Also fix require() variant
|
||||
result = re.sub(
|
||||
r"""require\s*\(\s*['"]@testing-library/jest-dom/extend-expect['"]\s*\)""",
|
||||
"require('@testing-library/jest-dom')",
|
||||
result,
|
||||
)
|
||||
|
||||
# Ensure testing-library import
|
||||
if "@testing-library/react" not in result:
|
||||
result = "import { render, screen, act } from '@testing-library/react';\n" + result
|
||||
|
|
@ -168,14 +177,14 @@ def post_process_react_tests(test_source: str, component_info: ReactComponentInf
|
|||
# This gives per-interaction A/B signal without the LLM needing to know about it.
|
||||
result = inject_interaction_markers(result)
|
||||
|
||||
# Warn if no tests contain interaction calls — mount-phase only markers are
|
||||
# not useful for measuring optimization effectiveness.
|
||||
# If no tests contain interaction calls, auto-inject a rerender fallback so
|
||||
# that EVERY React perf test produces at least one update-phase marker.
|
||||
if not has_react_test_interactions(result):
|
||||
logger.warning(
|
||||
"[REACT] Generated tests for %s contain no interactions (fireEvent, userEvent, rerender). "
|
||||
"Tests will produce only mount-phase markers which cannot measure optimization improvements.",
|
||||
"[REACT] Generated tests for %s contain no interactions — auto-injecting rerender fallback.",
|
||||
component_info.function_name,
|
||||
)
|
||||
result = _inject_rerender_fallback(result, component_info.function_name)
|
||||
|
||||
# Check interaction density — fewer than MIN_INTERACTION_CALLS total interactions
|
||||
# means the test is unlikely to produce enough update-phase renders for reliable measurement.
|
||||
|
|
@ -221,6 +230,48 @@ def _extract_interaction_label(call_text: str) -> str:
|
|||
return m.group(1) if m else "interaction"
|
||||
|
||||
|
||||
def inject_dom_snapshot_calls(test_source: str) -> str:
|
||||
"""Inject codeflash.snapshotDOM() calls after each user interaction in behavior mode.
|
||||
|
||||
Only active when `captureRender` (not `captureRenderPerf`) is present,
|
||||
meaning the test is running in behavioral verification mode.
|
||||
|
||||
After each fireEvent.*, userEvent.*, or rerender() call, inserts:
|
||||
codeflash.snapshotDOM('after_{label}_{n}');
|
||||
on the next line with matching indentation.
|
||||
"""
|
||||
if "captureRender" not in test_source:
|
||||
return test_source
|
||||
if "captureRenderPerf" in test_source:
|
||||
return test_source
|
||||
|
||||
interaction_counter: dict[str, int] = {}
|
||||
lines = test_source.split("\n")
|
||||
new_lines: list[str] = []
|
||||
|
||||
# Also match rerender() calls — use .* to handle nested parens like getByText('Add')
|
||||
snapshot_interaction_pattern = re.compile(
|
||||
r"^(\s*)((?:await\s+)?(?:fireEvent\.\w+|userEvent\.\w+|(?:\w+\.)?rerender)\s*\(.*\))\s*;?\s*$",
|
||||
re.MULTILINE,
|
||||
)
|
||||
|
||||
for line in lines:
|
||||
new_lines.append(line)
|
||||
m = snapshot_interaction_pattern.match(line)
|
||||
if m:
|
||||
indent = m.group(1)
|
||||
call_text = m.group(2)
|
||||
label = _extract_interaction_label(call_text)
|
||||
if label == "interaction":
|
||||
# rerender() call
|
||||
label = "rerender"
|
||||
interaction_counter[label] = interaction_counter.get(label, 0) + 1
|
||||
unique_label = f"{label}_{interaction_counter[label]}"
|
||||
new_lines.append(f"{indent}codeflash.snapshotDOM('after_{unique_label}');")
|
||||
|
||||
return "\n".join(new_lines)
|
||||
|
||||
|
||||
def inject_interaction_markers(test_source: str) -> str:
|
||||
"""Inject _codeflashMarkInteraction() calls before each fireEvent/userEvent call.
|
||||
|
||||
|
|
@ -325,3 +376,54 @@ def has_high_density_interactions(test_source: str) -> bool:
|
|||
|
||||
interaction_calls = _INTERACTION_PATTERNS.findall(test_source)
|
||||
return len(interaction_calls) >= _MIN_SEQUENTIAL_INTERACTIONS
|
||||
|
||||
|
||||
# Pattern to extract props from the first render(<Component ...props... />) call
|
||||
_RENDER_CALL_PROPS_PATTERN = re.compile(
|
||||
r"render\s*\(\s*<(\w+)\s+([^/]*?)\s*/?\s*>",
|
||||
)
|
||||
|
||||
|
||||
def _extract_render_props(test_source: str, component_name: str) -> str | None:
|
||||
"""Extract the props expression from the first render(<Component ...>) call."""
|
||||
for m in _RENDER_CALL_PROPS_PATTERN.finditer(test_source):
|
||||
if m.group(1) == component_name:
|
||||
props_text = m.group(2).strip()
|
||||
if props_text:
|
||||
return props_text
|
||||
return None
|
||||
|
||||
|
||||
def _inject_rerender_fallback(test_source: str, component_name: str) -> str:
|
||||
"""Inject a rerender efficiency test block when the test has no interactions.
|
||||
|
||||
This ensures every React perf test produces at least one update-phase marker.
|
||||
"""
|
||||
# Try to extract props from existing render call
|
||||
props_expr = _extract_render_props(test_source, component_name)
|
||||
if props_expr:
|
||||
jsx_open = f"<{component_name} {props_expr} />"
|
||||
else:
|
||||
jsx_open = f"<{component_name} />"
|
||||
|
||||
rerender_block = f"""
|
||||
describe('{component_name} rerender efficiency (auto-generated)', () => {{
|
||||
it('should handle same-props rerenders', () => {{
|
||||
const {{ rerender }} = render({jsx_open});
|
||||
for (let i = 0; i < 10; i++) {{
|
||||
rerender({jsx_open});
|
||||
}}
|
||||
}});
|
||||
}});
|
||||
"""
|
||||
# Ensure {rerender} is in the @testing-library/react import
|
||||
rtl_import_match = re.search(
|
||||
r"import\s*\{([^}]+)\}\s*from\s*['\"]@testing-library/react['\"]", test_source
|
||||
)
|
||||
if rtl_import_match:
|
||||
imports = rtl_import_match.group(1)
|
||||
if "rerender" not in imports:
|
||||
# Don't add 'rerender' to import — it comes from render() return value, not an import
|
||||
pass
|
||||
|
||||
return test_source.rstrip() + "\n" + rerender_block
|
||||
|
|
|
|||
|
|
@ -1346,6 +1346,13 @@ def _instrument_js_test_code(
|
|||
code=code, function_to_optimize=function_to_optimize, capture_func=capture_func, start_counter=final_counter
|
||||
)
|
||||
|
||||
# In behavior mode, inject DOM snapshot calls after each interaction so the
|
||||
# comparator can detect post-interaction DOM divergence between original and candidate.
|
||||
if mode == TestingMode.BEHAVIOR:
|
||||
from codeflash.languages.javascript.frameworks.react.testgen import inject_dom_snapshot_calls # noqa: PLC0415
|
||||
|
||||
code = inject_dom_snapshot_calls(code)
|
||||
|
||||
return code
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -174,6 +174,14 @@ class JavaScriptSupport:
|
|||
logger.debug(f"Skipping non-exported function: {func.name}") # noqa: G004
|
||||
continue
|
||||
|
||||
# Skip closures inside React components — they capture hook state
|
||||
# and cannot be tested independently
|
||||
if func.parent_function and func.parent_function in react_component_map:
|
||||
logger.debug(
|
||||
f"Skipping closure {func.name} inside React component {func.parent_function}" # noqa: G004
|
||||
)
|
||||
continue
|
||||
|
||||
# Build parents list
|
||||
parents: list[FunctionParent] = []
|
||||
if func.class_name:
|
||||
|
|
@ -231,8 +239,23 @@ class JavaScriptSupport:
|
|||
source, include_methods=True, include_arrow_functions=True, require_name=True
|
||||
)
|
||||
|
||||
# Build React component lookup to filter out closures
|
||||
react_component_names: set[str] = set()
|
||||
try:
|
||||
from codeflash.languages.javascript.frameworks.react.discovery import classify_component # noqa: PLC0415
|
||||
|
||||
for func in tree_functions:
|
||||
if classify_component(func, source, analyzer) is not None:
|
||||
react_component_names.add(func.name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
functions: list[FunctionToOptimize] = []
|
||||
for func in tree_functions:
|
||||
# Skip closures inside React components
|
||||
if func.parent_function and func.parent_function in react_component_names:
|
||||
continue
|
||||
|
||||
# Build parents list
|
||||
parents: list[FunctionParent] = []
|
||||
if func.class_name:
|
||||
|
|
@ -1212,6 +1235,47 @@ class JavaScriptSupport:
|
|||
|
||||
# === Code Transformation ===
|
||||
|
||||
@staticmethod
|
||||
def _strip_imports(source: str, analyzer: TreeSitterAnalyzer) -> str:
|
||||
"""Strip import/require statements from source code.
|
||||
|
||||
Used to remove duplicate imports from optimizer output before inserting
|
||||
into the original file (which already has its own imports).
|
||||
"""
|
||||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
|
||||
# Collect byte ranges of import nodes to remove
|
||||
ranges_to_remove: list[tuple[int, int]] = []
|
||||
for child in tree.root_node.children:
|
||||
if child.type == "import_statement":
|
||||
ranges_to_remove.append((child.start_byte, child.end_byte))
|
||||
# CJS require: const x = require('...')
|
||||
elif child.type in ("lexical_declaration", "variable_declaration"):
|
||||
text = source_bytes[child.start_byte : child.end_byte].decode("utf8")
|
||||
if "require(" in text:
|
||||
ranges_to_remove.append((child.start_byte, child.end_byte))
|
||||
|
||||
if not ranges_to_remove:
|
||||
return source
|
||||
|
||||
# Build result, skipping removed ranges
|
||||
parts: list[bytes] = []
|
||||
prev_end = 0
|
||||
for start, end in sorted(ranges_to_remove):
|
||||
parts.append(source_bytes[prev_end:start])
|
||||
# Skip trailing newline after import
|
||||
if end < len(source_bytes) and source_bytes[end : end + 1] == b"\n":
|
||||
end += 1
|
||||
prev_end = end
|
||||
parts.append(source_bytes[prev_end:])
|
||||
|
||||
result = b"".join(parts).decode("utf8")
|
||||
# Remove leading blank lines left by stripping imports
|
||||
while result.startswith("\n"):
|
||||
result = result[1:]
|
||||
return result
|
||||
|
||||
def replace_function(self, source: str, function: FunctionToOptimize, new_source: str) -> str:
|
||||
"""Replace a function in source code with new implementation.
|
||||
|
||||
|
|
@ -1247,6 +1311,15 @@ class JavaScriptSupport:
|
|||
else:
|
||||
analyzer = TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
|
||||
|
||||
# Strip imports from optimizer output — the original file already has them,
|
||||
# and import merging is handled by _add_global_declarations_for_language()
|
||||
new_source = self._strip_imports(new_source, analyzer)
|
||||
|
||||
# If stripping imports left nothing, return original
|
||||
if not new_source.strip():
|
||||
logger.warning("new_source was only imports for %s, returning original", function.function_name)
|
||||
return source
|
||||
|
||||
# Check if new_source contains a JSDoc comment - if so, use full replacement
|
||||
# to include the updated JSDoc along with the function body
|
||||
stripped_new_source = new_source.strip()
|
||||
|
|
@ -1263,16 +1336,64 @@ class JavaScriptSupport:
|
|||
logger.warning(
|
||||
"Could not extract body for %s from optimized code, using full replacement", function.function_name
|
||||
)
|
||||
# Verify that new_source contains actual code before falling back to text replacement
|
||||
# This prevents deletion of the original function when new_source is invalid
|
||||
if not self._contains_function_declaration(new_source, function.function_name, analyzer):
|
||||
logger.warning("new_source does not contain function %s, returning original", function.function_name)
|
||||
return source
|
||||
return self._replace_function_text_based(source, function, new_source, analyzer)
|
||||
if self._contains_function_declaration(new_source, function.function_name, analyzer):
|
||||
return self._replace_function_text_based(source, function, new_source, analyzer)
|
||||
# Final fallback: line-range replacement using the function's known line boundaries.
|
||||
# This handles cases where tree-sitter can't parse the optimized output.
|
||||
logger.warning(
|
||||
"Falling back to line-range replacement for %s (lines %d-%d)",
|
||||
function.function_name,
|
||||
function.starting_line,
|
||||
function.ending_line,
|
||||
)
|
||||
return self._replace_function_by_line_range(source, function, new_source)
|
||||
|
||||
# Find the original function and replace its body
|
||||
return self._replace_function_body(source, function, new_body, analyzer)
|
||||
|
||||
def _replace_function_by_line_range(
|
||||
self, source: str, function: FunctionToOptimize, new_source: str
|
||||
) -> str:
|
||||
"""Last-resort replacement: cut the original function by line range and paste new_source."""
|
||||
if function.starting_line is None or function.ending_line is None:
|
||||
return source
|
||||
|
||||
lines = source.splitlines(keepends=True)
|
||||
if lines and not lines[-1].endswith("\n"):
|
||||
lines[-1] += "\n"
|
||||
|
||||
# Get indentation from original function's first line
|
||||
if function.starting_line <= len(lines):
|
||||
original_first = lines[function.starting_line - 1]
|
||||
original_indent = len(original_first) - len(original_first.lstrip())
|
||||
else:
|
||||
original_indent = 0
|
||||
|
||||
new_lines = new_source.splitlines(keepends=True)
|
||||
if new_lines:
|
||||
first_non_blank = next((l for l in new_lines if l.strip()), new_lines[0])
|
||||
new_indent = len(first_non_blank) - len(first_non_blank.lstrip())
|
||||
indent_diff = original_indent - new_indent
|
||||
if indent_diff != 0:
|
||||
adjusted = []
|
||||
for line in new_lines:
|
||||
if line.strip():
|
||||
if indent_diff > 0:
|
||||
adjusted.append(" " * indent_diff + line)
|
||||
else:
|
||||
cur = len(line) - len(line.lstrip())
|
||||
adjusted.append(line[min(cur, abs(indent_diff)) :])
|
||||
else:
|
||||
adjusted.append(line)
|
||||
new_lines = adjusted
|
||||
|
||||
if new_lines and not new_lines[-1].endswith("\n"):
|
||||
new_lines[-1] += "\n"
|
||||
|
||||
before = lines[: function.starting_line - 1]
|
||||
after = lines[function.ending_line :]
|
||||
return "".join(before + new_lines + after)
|
||||
|
||||
def _contains_function_declaration(self, source: str, function_name: str, analyzer: TreeSitterAnalyzer) -> bool:
|
||||
"""Check if source contains a function declaration with the given name.
|
||||
|
||||
|
|
@ -1912,7 +2033,52 @@ class JavaScriptSupport:
|
|||
generated_tests = inject_test_globals(generated_tests, test_framework)
|
||||
if is_typescript():
|
||||
generated_tests = disable_ts_check(generated_tests)
|
||||
return normalize_generated_tests_imports(generated_tests)
|
||||
generated_tests = normalize_generated_tests_imports(generated_tests)
|
||||
|
||||
# Apply React-specific post-processing (act wrapping, interaction markers, etc.)
|
||||
if self._is_react_source(source_file_path):
|
||||
generated_tests = self._apply_react_postprocessing(generated_tests, source_file_path)
|
||||
|
||||
return generated_tests
|
||||
|
||||
def _is_react_source(self, source_file_path: Path) -> bool:
|
||||
"""Check if the source file is a React component."""
|
||||
try:
|
||||
content = source_file_path.read_text("utf-8", errors="replace")
|
||||
return "react" in content.lower() and ("jsx" in content.lower() or source_file_path.suffix in (".tsx", ".jsx"))
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
def _apply_react_postprocessing(self, generated_tests: GeneratedTestsList, source_file_path: Path) -> GeneratedTestsList:
|
||||
"""Apply React-specific post-processing to generated tests."""
|
||||
from codeflash.languages.javascript.frameworks.react.testgen import ( # noqa: PLC0415
|
||||
post_process_react_tests,
|
||||
)
|
||||
from codeflash.languages.javascript.frameworks.react.discovery import ComponentType, ReactComponentInfo # noqa: PLC0415
|
||||
from codeflash.models.models import GeneratedTests, GeneratedTestsList # noqa: PLC0415
|
||||
|
||||
component_name = source_file_path.stem
|
||||
dummy_info = ReactComponentInfo(
|
||||
function_name=component_name,
|
||||
component_type=ComponentType.FUNCTION,
|
||||
)
|
||||
|
||||
updated_tests = []
|
||||
for test in generated_tests.generated_tests:
|
||||
updated_tests.append(
|
||||
GeneratedTests(
|
||||
generated_original_test_source=test.generated_original_test_source,
|
||||
instrumented_behavior_test_source=post_process_react_tests(
|
||||
test.instrumented_behavior_test_source, dummy_info
|
||||
),
|
||||
instrumented_perf_test_source=post_process_react_tests(
|
||||
test.instrumented_perf_test_source, dummy_info
|
||||
),
|
||||
behavior_file_path=test.behavior_file_path,
|
||||
perf_file_path=test.perf_file_path,
|
||||
)
|
||||
)
|
||||
return GeneratedTestsList(generated_tests=updated_tests)
|
||||
|
||||
def remove_test_functions_from_generated_tests(
|
||||
self, generated_tests: GeneratedTestsList, functions_to_remove: list[str]
|
||||
|
|
|
|||
|
|
@ -401,6 +401,29 @@ class TreeSitterAnalyzer:
|
|||
great_grandparent = grandparent.parent
|
||||
if great_grandparent and great_grandparent.type == "export_statement":
|
||||
return True
|
||||
# HOC wrappers: export const Comp = forwardRef/memo(function/arrow)
|
||||
# Tree: function → arguments → call_expression → variable_declarator → lexical_declaration → export_statement
|
||||
if parent and parent.type == "arguments":
|
||||
call_expr = parent.parent
|
||||
if call_expr and call_expr.type == "call_expression":
|
||||
declarator = call_expr.parent
|
||||
if declarator and declarator.type == "variable_declarator":
|
||||
lex_decl = declarator.parent
|
||||
if lex_decl and lex_decl.type in ("lexical_declaration", "variable_declaration"):
|
||||
export_stmt = lex_decl.parent
|
||||
if export_stmt and export_stmt.type == "export_statement":
|
||||
return True
|
||||
# Nested HOC: export const X = memo(forwardRef(fn))
|
||||
if declarator and declarator.type == "arguments":
|
||||
outer_call = declarator.parent
|
||||
if outer_call and outer_call.type == "call_expression":
|
||||
outer_decl = outer_call.parent
|
||||
if outer_decl and outer_decl.type == "variable_declarator":
|
||||
lex_decl = outer_decl.parent
|
||||
if lex_decl and lex_decl.type in ("lexical_declaration", "variable_declaration"):
|
||||
export_stmt = lex_decl.parent
|
||||
if export_stmt and export_stmt.type == "export_statement":
|
||||
return True
|
||||
|
||||
# For methods in exported classes
|
||||
if node.type == "method_definition":
|
||||
|
|
@ -609,6 +632,8 @@ class TreeSitterAnalyzer:
|
|||
|
||||
return None
|
||||
|
||||
_HOC_WRAPPER_NAMES = frozenset({"forwardRef", "memo", "React.forwardRef", "React.memo"})
|
||||
|
||||
def _get_name_from_assignment(self, node: Node, source_bytes: bytes) -> str:
|
||||
"""Try to extract function name from parent variable declaration or assignment.
|
||||
|
||||
|
|
@ -617,6 +642,8 @@ class TreeSitterAnalyzer:
|
|||
- const foo = function() {}
|
||||
- let bar = function() {}
|
||||
- obj.method = () => {}
|
||||
- const Tooltip = forwardRef(function TooltipInner(...) {...})
|
||||
- const MemoComp = React.memo((props) => {...})
|
||||
"""
|
||||
parent = node.parent
|
||||
if parent is None:
|
||||
|
|
@ -628,6 +655,30 @@ class TreeSitterAnalyzer:
|
|||
if name_node:
|
||||
return self.get_node_text(name_node, source_bytes)
|
||||
|
||||
# Check for HOC wrapper: const Name = forwardRef/memo(function/arrow)
|
||||
# Tree structure: function → arguments → call_expression → variable_declarator
|
||||
if parent.type == "arguments":
|
||||
call_expr = parent.parent
|
||||
if call_expr is not None and call_expr.type == "call_expression":
|
||||
callee = call_expr.child_by_field_name("function")
|
||||
if callee is not None:
|
||||
callee_text = self.get_node_text(callee, source_bytes)
|
||||
if callee_text in self._HOC_WRAPPER_NAMES:
|
||||
declarator = call_expr.parent
|
||||
if declarator is not None and declarator.type == "variable_declarator":
|
||||
name_node = declarator.child_by_field_name("name")
|
||||
if name_node:
|
||||
return self.get_node_text(name_node, source_bytes)
|
||||
# Nested HOC: memo(forwardRef(fn)) — call_expr parent is arguments of outer call
|
||||
if declarator is not None and declarator.type == "arguments":
|
||||
outer_call = declarator.parent
|
||||
if outer_call is not None and outer_call.type == "call_expression":
|
||||
outer_declarator = outer_call.parent
|
||||
if outer_declarator is not None and outer_declarator.type == "variable_declarator":
|
||||
name_node = outer_declarator.child_by_field_name("name")
|
||||
if name_node:
|
||||
return self.get_node_text(name_node, source_bytes)
|
||||
|
||||
# Check for assignment expression: foo = ...
|
||||
if parent.type == "assignment_expression":
|
||||
left_node = parent.child_by_field_name("left")
|
||||
|
|
|
|||
|
|
@ -83,6 +83,7 @@ class TestDiffScope(str, Enum):
|
|||
RETURN_VALUE = "return_value"
|
||||
STDOUT = "stdout"
|
||||
DID_PASS = "did_pass" # noqa: S105
|
||||
DOM_SNAPSHOT = "dom_snapshot"
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
|
|
@ -96,6 +96,7 @@ from codeflash.models.models import (
|
|||
OptimizedCandidateResult,
|
||||
OptimizedCandidateSource,
|
||||
OriginalCodeBaseline,
|
||||
TestDiffScope,
|
||||
TestFile,
|
||||
TestFiles,
|
||||
TestingMode,
|
||||
|
|
@ -2955,6 +2956,13 @@ class FunctionOptimizer:
|
|||
logger.info("h3|Test results matched ✅")
|
||||
console.rule()
|
||||
else:
|
||||
dom_snapshot_diffs = [d for d in diffs if d.scope == TestDiffScope.DOM_SNAPSHOT]
|
||||
if dom_snapshot_diffs:
|
||||
logger.warning(
|
||||
"[REACT] DOM snapshot divergence detected after %d interaction(s). "
|
||||
"The optimized component produces different DOM output.",
|
||||
len(dom_snapshot_diffs),
|
||||
)
|
||||
self.repair_if_possible(
|
||||
candidate, diffs, eval_ctx, code_context, len(candidate_behavior_results), exp_type
|
||||
)
|
||||
|
|
@ -3178,8 +3186,7 @@ class FunctionOptimizer:
|
|||
coverage_config_file=coverage_config_file,
|
||||
skip_sqlite_cleanup=skip_cleanup,
|
||||
)
|
||||
if testing_type == TestingMode.PERFORMANCE:
|
||||
results.perf_stdout = run_result.stdout
|
||||
results.perf_stdout = run_result.stdout
|
||||
return results, coverage_results
|
||||
# For LINE_PROFILE mode, Python uses .lprof files while JavaScript uses JSON
|
||||
# Return TestResults for JavaScript so _line_profiler_step_javascript can parse the JSON
|
||||
|
|
|
|||
|
|
@ -46,11 +46,11 @@ function _getReact() {
|
|||
}
|
||||
}
|
||||
|
||||
// Try to load better-sqlite3, fall back to JSON if not available
|
||||
let useSqlite = false;
|
||||
|
||||
// Configuration from environment
|
||||
const OUTPUT_FILE = process.env.CODEFLASH_OUTPUT_FILE;
|
||||
|
||||
// Enable SQLite when better-sqlite3 is available and an output file is configured
|
||||
let useSqlite = !!(Database && OUTPUT_FILE);
|
||||
const LOOP_INDEX = parseInt(process.env.CODEFLASH_LOOP_INDEX || '1', 10);
|
||||
const TEST_ITERATION = process.env.CODEFLASH_TEST_ITERATION;
|
||||
const TEST_MODULE = process.env.CODEFLASH_TEST_MODULE;
|
||||
|
|
@ -1091,6 +1091,35 @@ function captureRender(funcName, lineId, renderFn, Component, ...createElementAr
|
|||
return renderResult;
|
||||
}
|
||||
|
||||
/**
|
||||
* Capture a DOM snapshot after a user interaction for behavioral verification.
|
||||
*
|
||||
* Writes normalized `document.body.innerHTML` to the SQLite database as an
|
||||
* additional row with function name `__dom_snapshot__`. The existing
|
||||
* `compare-results.js` comparator compares all rows by invocation ID, so
|
||||
* snapshot rows are compared automatically.
|
||||
*
|
||||
* @param {string} label - Unique label for this snapshot (e.g. 'after_click_1')
|
||||
*/
|
||||
function snapshotDOM(label) {
|
||||
if (!useSqlite || !db) return;
|
||||
if (typeof document === 'undefined' || !document.body) return;
|
||||
|
||||
// Normalize HTML: collapse whitespace, strip React-internal attributes
|
||||
let html = document.body.innerHTML;
|
||||
html = html.replace(/\s+/g, ' ').trim();
|
||||
html = html.replace(/\s*data-reactroot\s*/g, '');
|
||||
|
||||
const { testModulePath, testClassName, testFunctionName } = _getTestContext();
|
||||
const invocationId = `snapshot_${label}`;
|
||||
|
||||
recordResult(
|
||||
testModulePath, testClassName, testFunctionName,
|
||||
'__dom_snapshot__', invocationId,
|
||||
[label], html, null, 0
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Capture a React component render call for PERFORMANCE benchmarking only.
|
||||
*
|
||||
|
|
@ -1334,6 +1363,7 @@ module.exports = {
|
|||
capturePerf, // Performance benchmarking (prints to stdout only)
|
||||
captureRender, // React render behavior verification (writes to SQLite)
|
||||
captureRenderPerf, // React render performance benchmarking (prints to stdout only)
|
||||
snapshotDOM, // DOM snapshot after interaction (behavior verification)
|
||||
captureMultiple,
|
||||
writeResults,
|
||||
clearResults,
|
||||
|
|
|
|||
|
|
@ -167,9 +167,13 @@ function compareResults(originalResults, candidateResults) {
|
|||
|
||||
if (!isEqual) {
|
||||
allEquivalent = false;
|
||||
// Use dom_snapshot scope for __dom_snapshot__ rows
|
||||
const scope = original.functionGettingTested === '__dom_snapshot__'
|
||||
? 'dom_snapshot'
|
||||
: 'return_value';
|
||||
diffs.push({
|
||||
invocation_id: invocationId,
|
||||
scope: 'return_value',
|
||||
scope,
|
||||
original: summarizeValue(originalValue),
|
||||
candidate: summarizeValue(candidateValue),
|
||||
test_info: {
|
||||
|
|
|
|||
10
packages/codeflash/runtime/index.d.ts
vendored
10
packages/codeflash/runtime/index.d.ts
vendored
|
|
@ -36,6 +36,15 @@ export function capturePerf<T extends (...args: any[]) => any>(
|
|||
...args: Parameters<T>
|
||||
): ReturnType<T>;
|
||||
|
||||
/**
|
||||
* Capture a DOM snapshot after a user interaction for behavioral verification.
|
||||
* Writes normalized document.body.innerHTML to the SQLite database with
|
||||
* function name '__dom_snapshot__' so the comparator detects DOM divergence.
|
||||
*
|
||||
* @param label - Unique label for this snapshot (e.g. 'after_click_1')
|
||||
*/
|
||||
export function snapshotDOM(label: string): void;
|
||||
|
||||
/**
|
||||
* Capture multiple invocations for benchmarking.
|
||||
*
|
||||
|
|
@ -126,6 +135,7 @@ export const TEST_ITERATION: string;
|
|||
declare const codeflash: {
|
||||
capture: typeof capture;
|
||||
capturePerf: typeof capturePerf;
|
||||
snapshotDOM: typeof snapshotDOM;
|
||||
captureMultiple: typeof captureMultiple;
|
||||
writeResults: typeof writeResults;
|
||||
clearResults: typeof clearResults;
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ module.exports = {
|
|||
|
||||
captureRender: capture.captureRender,
|
||||
captureRenderPerf: capture.captureRenderPerf,
|
||||
snapshotDOM: capture.snapshotDOM,
|
||||
|
||||
captureMultiple: capture.captureMultiple,
|
||||
|
||||
|
|
|
|||
|
|
@ -173,6 +173,8 @@ class TestRenderEfficiencyCritic:
|
|||
optimized_render_duration=10.0,
|
||||
original_update_render_count=8,
|
||||
optimized_update_render_count=8,
|
||||
original_update_duration=80.0,
|
||||
optimized_update_duration=5.0,
|
||||
) is True
|
||||
|
||||
def test_rejects_worse_than_best(self):
|
||||
|
|
@ -448,6 +450,8 @@ class TestRenderEfficiencyCriticTrustDuration:
|
|||
optimized_render_duration=10.0,
|
||||
original_update_render_count=8,
|
||||
optimized_update_render_count=8,
|
||||
original_update_duration=100.0,
|
||||
optimized_update_duration=10.0,
|
||||
trust_duration=True,
|
||||
) is True
|
||||
|
||||
|
|
|
|||
78
tests/react/test_dom_snapshot.py
Normal file
78
tests/react/test_dom_snapshot.py
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
"""Tests for inject_dom_snapshot_calls in React behavioral verification."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from codeflash.languages.javascript.frameworks.react.testgen import inject_dom_snapshot_calls
|
||||
|
||||
|
||||
def test_injects_after_fireEvent():
|
||||
source = """\
|
||||
import codeflash from 'codeflash';
|
||||
const { container } = codeflash.captureRender('Comp', '1', render, Comp);
|
||||
fireEvent.click(screen.getByText('Add'));
|
||||
fireEvent.change(input, { target: { value: 'hi' } });
|
||||
"""
|
||||
result = inject_dom_snapshot_calls(source)
|
||||
assert "codeflash.snapshotDOM('after_click_1');" in result
|
||||
assert "codeflash.snapshotDOM('after_change_1');" in result
|
||||
|
||||
|
||||
def test_skips_perf_mode():
|
||||
source = """\
|
||||
import codeflash from 'codeflash';
|
||||
const result = await codeflash.captureRenderPerf('Comp', '1', render, Comp);
|
||||
fireEvent.click(screen.getByText('Add'));
|
||||
"""
|
||||
result = inject_dom_snapshot_calls(source)
|
||||
assert "snapshotDOM" not in result
|
||||
|
||||
|
||||
def test_skips_without_captureRender():
|
||||
source = """\
|
||||
import codeflash from 'codeflash';
|
||||
fireEvent.click(screen.getByText('Add'));
|
||||
"""
|
||||
result = inject_dom_snapshot_calls(source)
|
||||
assert "snapshotDOM" not in result
|
||||
|
||||
|
||||
def test_preserves_indentation():
|
||||
source = """\
|
||||
import codeflash from 'codeflash';
|
||||
const { container } = codeflash.captureRender('Comp', '1', render, Comp);
|
||||
fireEvent.click(btn);
|
||||
fireEvent.change(input, { target: { value: 'x' } });
|
||||
"""
|
||||
result = inject_dom_snapshot_calls(source)
|
||||
lines = result.split("\n")
|
||||
# Find the snapshot lines and check their indentation
|
||||
snapshot_lines = [l for l in lines if "snapshotDOM" in l]
|
||||
assert len(snapshot_lines) == 2
|
||||
assert snapshot_lines[0].startswith(" codeflash.snapshotDOM")
|
||||
assert snapshot_lines[1].startswith(" codeflash.snapshotDOM")
|
||||
|
||||
|
||||
def test_no_semicolons():
|
||||
"""Real-world projects (e.g. Zustand) don't use semicolons."""
|
||||
source = """\
|
||||
const { container } = codeflash.captureRender('Comp', '1', render, Comp)
|
||||
fireEvent.click(screen.getByText('button'))
|
||||
fireEvent.click(screen.getByTestId('test-shallow'))
|
||||
"""
|
||||
result = inject_dom_snapshot_calls(source)
|
||||
assert "after_click_1" in result
|
||||
assert "after_click_2" in result
|
||||
|
||||
|
||||
def test_sequential_counter():
|
||||
source = """\
|
||||
import codeflash from 'codeflash';
|
||||
const { container } = codeflash.captureRender('Comp', '1', render, Comp);
|
||||
fireEvent.click(btn);
|
||||
fireEvent.click(btn);
|
||||
fireEvent.click(btn);
|
||||
"""
|
||||
result = inject_dom_snapshot_calls(source)
|
||||
assert "after_click_1" in result
|
||||
assert "after_click_2" in result
|
||||
assert "after_click_3" in result
|
||||
Loading…
Reference in a new issue