mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
snapshot comparison in verification
This commit is contained in:
parent
4bc89f2b9d
commit
b5ba6df690
16 changed files with 505 additions and 39 deletions
|
|
@ -46,6 +46,8 @@ class AiServiceClient:
|
||||||
self.llm_call_counter = count(1)
|
self.llm_call_counter = count(1)
|
||||||
self.is_local = self.base_url == "http://localhost:8000"
|
self.is_local = self.base_url == "http://localhost:8000"
|
||||||
self.timeout: float | None = 300 if self.is_local else 90
|
self.timeout: float | None = 300 if self.is_local else 90
|
||||||
|
# React components are larger (300+ lines of JSX) and need more LLM processing time
|
||||||
|
self.react_timeout: float | None = 300 if self.is_local else 180
|
||||||
|
|
||||||
def get_next_sequence(self) -> int:
|
def get_next_sequence(self) -> int:
|
||||||
"""Get the next LLM call sequence number."""
|
"""Get the next LLM call sequence number."""
|
||||||
|
|
@ -203,7 +205,8 @@ class AiServiceClient:
|
||||||
logger.debug(f"Sending optimize request: trace_id={trace_id}, n_candidates={payload['n_candidates']}")
|
logger.debug(f"Sending optimize request: trace_id={trace_id}, n_candidates={payload['n_candidates']}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = self.make_ai_service_request("/optimize", payload=payload, timeout=self.timeout)
|
timeout = self.react_timeout if is_react_component else self.timeout
|
||||||
|
response = self.make_ai_service_request("/optimize", payload=payload, timeout=timeout)
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
logger.exception(f"Error generating optimized candidates: {e}")
|
logger.exception(f"Error generating optimized candidates: {e}")
|
||||||
ph("cli-optimize-error-caught", {"error": str(e)})
|
ph("cli-optimize-error-caught", {"error": str(e)})
|
||||||
|
|
@ -806,7 +809,8 @@ class AiServiceClient:
|
||||||
# DEBUG: Print payload language field
|
# DEBUG: Print payload language field
|
||||||
logger.debug(f"Sending testgen request with language='{payload['language']}', framework='{test_framework}'")
|
logger.debug(f"Sending testgen request with language='{payload['language']}', framework='{test_framework}'")
|
||||||
try:
|
try:
|
||||||
response = self.make_ai_service_request("/testgen", payload=payload, timeout=self.timeout)
|
timeout = self.react_timeout if is_react_component else self.timeout
|
||||||
|
response = self.make_ai_service_request("/testgen", payload=payload, timeout=timeout)
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
logger.exception(f"Error generating tests: {e}")
|
logger.exception(f"Error generating tests: {e}")
|
||||||
ph("cli-testgen-error-caught", {"error": str(e)})
|
ph("cli-testgen-error-caught", {"error": str(e)})
|
||||||
|
|
|
||||||
|
|
@ -143,6 +143,8 @@ def compare_test_results(
|
||||||
scope = TestDiffScope.STDOUT
|
scope = TestDiffScope.STDOUT
|
||||||
elif scope_str == "did_pass":
|
elif scope_str == "did_pass":
|
||||||
scope = TestDiffScope.DID_PASS
|
scope = TestDiffScope.DID_PASS
|
||||||
|
elif scope_str == "dom_snapshot":
|
||||||
|
scope = TestDiffScope.DOM_SNAPSHOT
|
||||||
|
|
||||||
test_info = diff.get("test_info", {})
|
test_info = diff.get("test_info", {})
|
||||||
# Build a test identifier string for JavaScript tests
|
# Build a test identifier string for JavaScript tests
|
||||||
|
|
|
||||||
|
|
@ -210,7 +210,10 @@ def normalize_codeflash_imports(source: str) -> str:
|
||||||
# Replace CommonJS require
|
# Replace CommonJS require
|
||||||
source = _CODEFLASH_REQUIRE_PATTERN.sub(r"\1 \2 = require('codeflash')", source)
|
source = _CODEFLASH_REQUIRE_PATTERN.sub(r"\1 \2 = require('codeflash')", source)
|
||||||
# Replace ES module import
|
# Replace ES module import
|
||||||
return _CODEFLASH_IMPORT_PATTERN.sub(r"import \1 from 'codeflash'", source)
|
source = _CODEFLASH_IMPORT_PATTERN.sub(r"import \1 from 'codeflash'", source)
|
||||||
|
# Fix outdated @testing-library/jest-dom import paths (v6+ removed /extend-expect subpath)
|
||||||
|
source = source.replace("@testing-library/jest-dom/extend-expect", "@testing-library/jest-dom")
|
||||||
|
return source
|
||||||
|
|
||||||
|
|
||||||
# Author: ali <mohammed18200118@gmail.com>
|
# Author: ali <mohammed18200118@gmail.com>
|
||||||
|
|
|
||||||
|
|
@ -72,42 +72,37 @@ def instrument_component_with_profiler(source: str, component_name: str, analyze
|
||||||
if computed is not None:
|
if computed is not None:
|
||||||
replacements.append(computed)
|
replacements.append(computed)
|
||||||
|
|
||||||
|
# Work entirely in bytes to avoid byte-offset vs char-index mismatches
|
||||||
|
# (tree-sitter returns byte offsets; non-ASCII chars cause divergence from Python string indices)
|
||||||
if replacements:
|
if replacements:
|
||||||
# Reconstruct result in a single pass
|
byte_parts: list[bytes] = []
|
||||||
parts: list[str] = []
|
|
||||||
prev = 0
|
prev = 0
|
||||||
for start, end, wrapped in replacements:
|
for start, end, wrapped in replacements:
|
||||||
# Use original source slices (string indices expected by original logic)
|
byte_parts.append(source_bytes[prev:start])
|
||||||
parts.append(source[prev:start])
|
byte_parts.append(wrapped.encode("utf-8"))
|
||||||
parts.append(wrapped)
|
|
||||||
prev = end
|
prev = end
|
||||||
parts.append(source[prev:])
|
byte_parts.append(source_bytes[prev:])
|
||||||
result = "".join(parts)
|
result_bytes = b"".join(byte_parts)
|
||||||
else:
|
else:
|
||||||
result = source
|
result_bytes = source_bytes
|
||||||
|
|
||||||
# Add render counter code at the top (after imports) using the already-parsed tree
|
|
||||||
|
|
||||||
# Add render counter code at the top (after imports)
|
# Add render counter code at the top (after imports)
|
||||||
counter_code = generate_render_counter_code(component_name)
|
counter_code = generate_render_counter_code(component_name)
|
||||||
|
|
||||||
# Inline logic similar to _insert_after_imports but reuse existing `tree` to avoid re-parsing
|
|
||||||
last_import_end = 0
|
last_import_end = 0
|
||||||
for child in tree.root_node.children:
|
for child in tree.root_node.children:
|
||||||
if child.type == "import_statement":
|
if child.type == "import_statement":
|
||||||
last_import_end = child.end_byte
|
last_import_end = child.end_byte
|
||||||
|
|
||||||
insert_pos = last_import_end
|
insert_pos = last_import_end
|
||||||
while insert_pos < len(result) and result[insert_pos] != "\n":
|
while insert_pos < len(result_bytes) and result_bytes[insert_pos : insert_pos + 1] != b"\n":
|
||||||
insert_pos += 1
|
insert_pos += 1
|
||||||
if insert_pos < len(result):
|
if insert_pos < len(result_bytes):
|
||||||
insert_pos += 1 # skip the newline
|
insert_pos += 1 # skip the newline
|
||||||
|
|
||||||
result = result[:insert_pos] + "\n" + counter_code + "\n\n" + result[insert_pos:]
|
counter_bytes = ("\n" + counter_code + "\n\n").encode("utf-8")
|
||||||
|
result = (result_bytes[:insert_pos] + counter_bytes + result_bytes[insert_pos:]).decode("utf-8")
|
||||||
|
|
||||||
# Ensure React is imported
|
|
||||||
|
|
||||||
# Ensure React is imported
|
|
||||||
return _ensure_react_import(result)
|
return _ensure_react_import(result)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -317,7 +312,8 @@ def _wrap_return_with_profiler(source: str, return_node: Node, profiler_id: str,
|
||||||
f"</React.Profiler>"
|
f"</React.Profiler>"
|
||||||
)
|
)
|
||||||
|
|
||||||
return source[:jsx_start] + wrapped + source[jsx_end:]
|
# Use byte-level splicing to avoid byte-offset vs char-index mismatches
|
||||||
|
return (source_bytes[:jsx_start] + wrapped.encode("utf-8") + source_bytes[jsx_end:]).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
def _insert_after_imports(source: str, code: str, analyzer: TreeSitterAnalyzer) -> str:
|
def _insert_after_imports(source: str, code: str, analyzer: TreeSitterAnalyzer) -> str:
|
||||||
|
|
@ -421,7 +417,7 @@ let _codeflash_render_count_{safe_name} = 0;
|
||||||
if (typeof beforeEach !== 'undefined') {{
|
if (typeof beforeEach !== 'undefined') {{
|
||||||
beforeEach(() => {{ _codeflash_render_count_{safe_name} = 0; }});
|
beforeEach(() => {{ _codeflash_render_count_{safe_name} = 0; }});
|
||||||
}}
|
}}
|
||||||
function _codeflashOnRender_{safe_name}(id, phase, actualDuration, baseDuration) {{
|
function _codeflashOnRender_{safe_name}(id: any, phase: any, actualDuration: any, baseDuration: any) {{
|
||||||
_codeflash_render_count_{safe_name}++;
|
_codeflash_render_count_{safe_name}++;
|
||||||
console.log(`!######{marker_prefix}:${{id}}:${{phase}}:${{actualDuration}}:${{baseDuration}}:${{_codeflash_render_count_{safe_name}}}######!`);
|
console.log(`!######{marker_prefix}:${{id}}:${{phase}}:${{actualDuration}}:${{baseDuration}}:${{_codeflash_render_count_{safe_name}}}######!`);
|
||||||
}}"""
|
}}"""
|
||||||
|
|
|
||||||
|
|
@ -98,6 +98,15 @@ def post_process_react_tests(test_source: str, component_info: ReactComponentInf
|
||||||
"""
|
"""
|
||||||
result = test_source
|
result = test_source
|
||||||
|
|
||||||
|
# Fix outdated @testing-library/jest-dom import paths (v6+ removed /extend-expect subpath)
|
||||||
|
result = result.replace("@testing-library/jest-dom/extend-expect", "@testing-library/jest-dom")
|
||||||
|
# Also fix require() variant
|
||||||
|
result = re.sub(
|
||||||
|
r"""require\s*\(\s*['"]@testing-library/jest-dom/extend-expect['"]\s*\)""",
|
||||||
|
"require('@testing-library/jest-dom')",
|
||||||
|
result,
|
||||||
|
)
|
||||||
|
|
||||||
# Ensure testing-library import
|
# Ensure testing-library import
|
||||||
if "@testing-library/react" not in result:
|
if "@testing-library/react" not in result:
|
||||||
result = "import { render, screen, act } from '@testing-library/react';\n" + result
|
result = "import { render, screen, act } from '@testing-library/react';\n" + result
|
||||||
|
|
@ -168,14 +177,14 @@ def post_process_react_tests(test_source: str, component_info: ReactComponentInf
|
||||||
# This gives per-interaction A/B signal without the LLM needing to know about it.
|
# This gives per-interaction A/B signal without the LLM needing to know about it.
|
||||||
result = inject_interaction_markers(result)
|
result = inject_interaction_markers(result)
|
||||||
|
|
||||||
# Warn if no tests contain interaction calls — mount-phase only markers are
|
# If no tests contain interaction calls, auto-inject a rerender fallback so
|
||||||
# not useful for measuring optimization effectiveness.
|
# that EVERY React perf test produces at least one update-phase marker.
|
||||||
if not has_react_test_interactions(result):
|
if not has_react_test_interactions(result):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"[REACT] Generated tests for %s contain no interactions (fireEvent, userEvent, rerender). "
|
"[REACT] Generated tests for %s contain no interactions — auto-injecting rerender fallback.",
|
||||||
"Tests will produce only mount-phase markers which cannot measure optimization improvements.",
|
|
||||||
component_info.function_name,
|
component_info.function_name,
|
||||||
)
|
)
|
||||||
|
result = _inject_rerender_fallback(result, component_info.function_name)
|
||||||
|
|
||||||
# Check interaction density — fewer than MIN_INTERACTION_CALLS total interactions
|
# Check interaction density — fewer than MIN_INTERACTION_CALLS total interactions
|
||||||
# means the test is unlikely to produce enough update-phase renders for reliable measurement.
|
# means the test is unlikely to produce enough update-phase renders for reliable measurement.
|
||||||
|
|
@ -221,6 +230,48 @@ def _extract_interaction_label(call_text: str) -> str:
|
||||||
return m.group(1) if m else "interaction"
|
return m.group(1) if m else "interaction"
|
||||||
|
|
||||||
|
|
||||||
|
def inject_dom_snapshot_calls(test_source: str) -> str:
|
||||||
|
"""Inject codeflash.snapshotDOM() calls after each user interaction in behavior mode.
|
||||||
|
|
||||||
|
Only active when `captureRender` (not `captureRenderPerf`) is present,
|
||||||
|
meaning the test is running in behavioral verification mode.
|
||||||
|
|
||||||
|
After each fireEvent.*, userEvent.*, or rerender() call, inserts:
|
||||||
|
codeflash.snapshotDOM('after_{label}_{n}');
|
||||||
|
on the next line with matching indentation.
|
||||||
|
"""
|
||||||
|
if "captureRender" not in test_source:
|
||||||
|
return test_source
|
||||||
|
if "captureRenderPerf" in test_source:
|
||||||
|
return test_source
|
||||||
|
|
||||||
|
interaction_counter: dict[str, int] = {}
|
||||||
|
lines = test_source.split("\n")
|
||||||
|
new_lines: list[str] = []
|
||||||
|
|
||||||
|
# Also match rerender() calls — use .* to handle nested parens like getByText('Add')
|
||||||
|
snapshot_interaction_pattern = re.compile(
|
||||||
|
r"^(\s*)((?:await\s+)?(?:fireEvent\.\w+|userEvent\.\w+|(?:\w+\.)?rerender)\s*\(.*\))\s*;?\s*$",
|
||||||
|
re.MULTILINE,
|
||||||
|
)
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
new_lines.append(line)
|
||||||
|
m = snapshot_interaction_pattern.match(line)
|
||||||
|
if m:
|
||||||
|
indent = m.group(1)
|
||||||
|
call_text = m.group(2)
|
||||||
|
label = _extract_interaction_label(call_text)
|
||||||
|
if label == "interaction":
|
||||||
|
# rerender() call
|
||||||
|
label = "rerender"
|
||||||
|
interaction_counter[label] = interaction_counter.get(label, 0) + 1
|
||||||
|
unique_label = f"{label}_{interaction_counter[label]}"
|
||||||
|
new_lines.append(f"{indent}codeflash.snapshotDOM('after_{unique_label}');")
|
||||||
|
|
||||||
|
return "\n".join(new_lines)
|
||||||
|
|
||||||
|
|
||||||
def inject_interaction_markers(test_source: str) -> str:
|
def inject_interaction_markers(test_source: str) -> str:
|
||||||
"""Inject _codeflashMarkInteraction() calls before each fireEvent/userEvent call.
|
"""Inject _codeflashMarkInteraction() calls before each fireEvent/userEvent call.
|
||||||
|
|
||||||
|
|
@ -325,3 +376,54 @@ def has_high_density_interactions(test_source: str) -> bool:
|
||||||
|
|
||||||
interaction_calls = _INTERACTION_PATTERNS.findall(test_source)
|
interaction_calls = _INTERACTION_PATTERNS.findall(test_source)
|
||||||
return len(interaction_calls) >= _MIN_SEQUENTIAL_INTERACTIONS
|
return len(interaction_calls) >= _MIN_SEQUENTIAL_INTERACTIONS
|
||||||
|
|
||||||
|
|
||||||
|
# Pattern to extract props from the first render(<Component ...props... />) call
|
||||||
|
_RENDER_CALL_PROPS_PATTERN = re.compile(
|
||||||
|
r"render\s*\(\s*<(\w+)\s+([^/]*?)\s*/?\s*>",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_render_props(test_source: str, component_name: str) -> str | None:
|
||||||
|
"""Extract the props expression from the first render(<Component ...>) call."""
|
||||||
|
for m in _RENDER_CALL_PROPS_PATTERN.finditer(test_source):
|
||||||
|
if m.group(1) == component_name:
|
||||||
|
props_text = m.group(2).strip()
|
||||||
|
if props_text:
|
||||||
|
return props_text
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _inject_rerender_fallback(test_source: str, component_name: str) -> str:
|
||||||
|
"""Inject a rerender efficiency test block when the test has no interactions.
|
||||||
|
|
||||||
|
This ensures every React perf test produces at least one update-phase marker.
|
||||||
|
"""
|
||||||
|
# Try to extract props from existing render call
|
||||||
|
props_expr = _extract_render_props(test_source, component_name)
|
||||||
|
if props_expr:
|
||||||
|
jsx_open = f"<{component_name} {props_expr} />"
|
||||||
|
else:
|
||||||
|
jsx_open = f"<{component_name} />"
|
||||||
|
|
||||||
|
rerender_block = f"""
|
||||||
|
describe('{component_name} rerender efficiency (auto-generated)', () => {{
|
||||||
|
it('should handle same-props rerenders', () => {{
|
||||||
|
const {{ rerender }} = render({jsx_open});
|
||||||
|
for (let i = 0; i < 10; i++) {{
|
||||||
|
rerender({jsx_open});
|
||||||
|
}}
|
||||||
|
}});
|
||||||
|
}});
|
||||||
|
"""
|
||||||
|
# Ensure {rerender} is in the @testing-library/react import
|
||||||
|
rtl_import_match = re.search(
|
||||||
|
r"import\s*\{([^}]+)\}\s*from\s*['\"]@testing-library/react['\"]", test_source
|
||||||
|
)
|
||||||
|
if rtl_import_match:
|
||||||
|
imports = rtl_import_match.group(1)
|
||||||
|
if "rerender" not in imports:
|
||||||
|
# Don't add 'rerender' to import — it comes from render() return value, not an import
|
||||||
|
pass
|
||||||
|
|
||||||
|
return test_source.rstrip() + "\n" + rerender_block
|
||||||
|
|
|
||||||
|
|
@ -1346,6 +1346,13 @@ def _instrument_js_test_code(
|
||||||
code=code, function_to_optimize=function_to_optimize, capture_func=capture_func, start_counter=final_counter
|
code=code, function_to_optimize=function_to_optimize, capture_func=capture_func, start_counter=final_counter
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# In behavior mode, inject DOM snapshot calls after each interaction so the
|
||||||
|
# comparator can detect post-interaction DOM divergence between original and candidate.
|
||||||
|
if mode == TestingMode.BEHAVIOR:
|
||||||
|
from codeflash.languages.javascript.frameworks.react.testgen import inject_dom_snapshot_calls # noqa: PLC0415
|
||||||
|
|
||||||
|
code = inject_dom_snapshot_calls(code)
|
||||||
|
|
||||||
return code
|
return code
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -174,6 +174,14 @@ class JavaScriptSupport:
|
||||||
logger.debug(f"Skipping non-exported function: {func.name}") # noqa: G004
|
logger.debug(f"Skipping non-exported function: {func.name}") # noqa: G004
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Skip closures inside React components — they capture hook state
|
||||||
|
# and cannot be tested independently
|
||||||
|
if func.parent_function and func.parent_function in react_component_map:
|
||||||
|
logger.debug(
|
||||||
|
f"Skipping closure {func.name} inside React component {func.parent_function}" # noqa: G004
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
# Build parents list
|
# Build parents list
|
||||||
parents: list[FunctionParent] = []
|
parents: list[FunctionParent] = []
|
||||||
if func.class_name:
|
if func.class_name:
|
||||||
|
|
@ -231,8 +239,23 @@ class JavaScriptSupport:
|
||||||
source, include_methods=True, include_arrow_functions=True, require_name=True
|
source, include_methods=True, include_arrow_functions=True, require_name=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Build React component lookup to filter out closures
|
||||||
|
react_component_names: set[str] = set()
|
||||||
|
try:
|
||||||
|
from codeflash.languages.javascript.frameworks.react.discovery import classify_component # noqa: PLC0415
|
||||||
|
|
||||||
|
for func in tree_functions:
|
||||||
|
if classify_component(func, source, analyzer) is not None:
|
||||||
|
react_component_names.add(func.name)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
functions: list[FunctionToOptimize] = []
|
functions: list[FunctionToOptimize] = []
|
||||||
for func in tree_functions:
|
for func in tree_functions:
|
||||||
|
# Skip closures inside React components
|
||||||
|
if func.parent_function and func.parent_function in react_component_names:
|
||||||
|
continue
|
||||||
|
|
||||||
# Build parents list
|
# Build parents list
|
||||||
parents: list[FunctionParent] = []
|
parents: list[FunctionParent] = []
|
||||||
if func.class_name:
|
if func.class_name:
|
||||||
|
|
@ -1212,6 +1235,47 @@ class JavaScriptSupport:
|
||||||
|
|
||||||
# === Code Transformation ===
|
# === Code Transformation ===
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _strip_imports(source: str, analyzer: TreeSitterAnalyzer) -> str:
|
||||||
|
"""Strip import/require statements from source code.
|
||||||
|
|
||||||
|
Used to remove duplicate imports from optimizer output before inserting
|
||||||
|
into the original file (which already has its own imports).
|
||||||
|
"""
|
||||||
|
source_bytes = source.encode("utf8")
|
||||||
|
tree = analyzer.parse(source_bytes)
|
||||||
|
|
||||||
|
# Collect byte ranges of import nodes to remove
|
||||||
|
ranges_to_remove: list[tuple[int, int]] = []
|
||||||
|
for child in tree.root_node.children:
|
||||||
|
if child.type == "import_statement":
|
||||||
|
ranges_to_remove.append((child.start_byte, child.end_byte))
|
||||||
|
# CJS require: const x = require('...')
|
||||||
|
elif child.type in ("lexical_declaration", "variable_declaration"):
|
||||||
|
text = source_bytes[child.start_byte : child.end_byte].decode("utf8")
|
||||||
|
if "require(" in text:
|
||||||
|
ranges_to_remove.append((child.start_byte, child.end_byte))
|
||||||
|
|
||||||
|
if not ranges_to_remove:
|
||||||
|
return source
|
||||||
|
|
||||||
|
# Build result, skipping removed ranges
|
||||||
|
parts: list[bytes] = []
|
||||||
|
prev_end = 0
|
||||||
|
for start, end in sorted(ranges_to_remove):
|
||||||
|
parts.append(source_bytes[prev_end:start])
|
||||||
|
# Skip trailing newline after import
|
||||||
|
if end < len(source_bytes) and source_bytes[end : end + 1] == b"\n":
|
||||||
|
end += 1
|
||||||
|
prev_end = end
|
||||||
|
parts.append(source_bytes[prev_end:])
|
||||||
|
|
||||||
|
result = b"".join(parts).decode("utf8")
|
||||||
|
# Remove leading blank lines left by stripping imports
|
||||||
|
while result.startswith("\n"):
|
||||||
|
result = result[1:]
|
||||||
|
return result
|
||||||
|
|
||||||
def replace_function(self, source: str, function: FunctionToOptimize, new_source: str) -> str:
|
def replace_function(self, source: str, function: FunctionToOptimize, new_source: str) -> str:
|
||||||
"""Replace a function in source code with new implementation.
|
"""Replace a function in source code with new implementation.
|
||||||
|
|
||||||
|
|
@ -1247,6 +1311,15 @@ class JavaScriptSupport:
|
||||||
else:
|
else:
|
||||||
analyzer = TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
|
analyzer = TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
|
||||||
|
|
||||||
|
# Strip imports from optimizer output — the original file already has them,
|
||||||
|
# and import merging is handled by _add_global_declarations_for_language()
|
||||||
|
new_source = self._strip_imports(new_source, analyzer)
|
||||||
|
|
||||||
|
# If stripping imports left nothing, return original
|
||||||
|
if not new_source.strip():
|
||||||
|
logger.warning("new_source was only imports for %s, returning original", function.function_name)
|
||||||
|
return source
|
||||||
|
|
||||||
# Check if new_source contains a JSDoc comment - if so, use full replacement
|
# Check if new_source contains a JSDoc comment - if so, use full replacement
|
||||||
# to include the updated JSDoc along with the function body
|
# to include the updated JSDoc along with the function body
|
||||||
stripped_new_source = new_source.strip()
|
stripped_new_source = new_source.strip()
|
||||||
|
|
@ -1263,16 +1336,64 @@ class JavaScriptSupport:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Could not extract body for %s from optimized code, using full replacement", function.function_name
|
"Could not extract body for %s from optimized code, using full replacement", function.function_name
|
||||||
)
|
)
|
||||||
# Verify that new_source contains actual code before falling back to text replacement
|
if self._contains_function_declaration(new_source, function.function_name, analyzer):
|
||||||
# This prevents deletion of the original function when new_source is invalid
|
return self._replace_function_text_based(source, function, new_source, analyzer)
|
||||||
if not self._contains_function_declaration(new_source, function.function_name, analyzer):
|
# Final fallback: line-range replacement using the function's known line boundaries.
|
||||||
logger.warning("new_source does not contain function %s, returning original", function.function_name)
|
# This handles cases where tree-sitter can't parse the optimized output.
|
||||||
return source
|
logger.warning(
|
||||||
return self._replace_function_text_based(source, function, new_source, analyzer)
|
"Falling back to line-range replacement for %s (lines %d-%d)",
|
||||||
|
function.function_name,
|
||||||
|
function.starting_line,
|
||||||
|
function.ending_line,
|
||||||
|
)
|
||||||
|
return self._replace_function_by_line_range(source, function, new_source)
|
||||||
|
|
||||||
# Find the original function and replace its body
|
# Find the original function and replace its body
|
||||||
return self._replace_function_body(source, function, new_body, analyzer)
|
return self._replace_function_body(source, function, new_body, analyzer)
|
||||||
|
|
||||||
|
def _replace_function_by_line_range(
|
||||||
|
self, source: str, function: FunctionToOptimize, new_source: str
|
||||||
|
) -> str:
|
||||||
|
"""Last-resort replacement: cut the original function by line range and paste new_source."""
|
||||||
|
if function.starting_line is None or function.ending_line is None:
|
||||||
|
return source
|
||||||
|
|
||||||
|
lines = source.splitlines(keepends=True)
|
||||||
|
if lines and not lines[-1].endswith("\n"):
|
||||||
|
lines[-1] += "\n"
|
||||||
|
|
||||||
|
# Get indentation from original function's first line
|
||||||
|
if function.starting_line <= len(lines):
|
||||||
|
original_first = lines[function.starting_line - 1]
|
||||||
|
original_indent = len(original_first) - len(original_first.lstrip())
|
||||||
|
else:
|
||||||
|
original_indent = 0
|
||||||
|
|
||||||
|
new_lines = new_source.splitlines(keepends=True)
|
||||||
|
if new_lines:
|
||||||
|
first_non_blank = next((l for l in new_lines if l.strip()), new_lines[0])
|
||||||
|
new_indent = len(first_non_blank) - len(first_non_blank.lstrip())
|
||||||
|
indent_diff = original_indent - new_indent
|
||||||
|
if indent_diff != 0:
|
||||||
|
adjusted = []
|
||||||
|
for line in new_lines:
|
||||||
|
if line.strip():
|
||||||
|
if indent_diff > 0:
|
||||||
|
adjusted.append(" " * indent_diff + line)
|
||||||
|
else:
|
||||||
|
cur = len(line) - len(line.lstrip())
|
||||||
|
adjusted.append(line[min(cur, abs(indent_diff)) :])
|
||||||
|
else:
|
||||||
|
adjusted.append(line)
|
||||||
|
new_lines = adjusted
|
||||||
|
|
||||||
|
if new_lines and not new_lines[-1].endswith("\n"):
|
||||||
|
new_lines[-1] += "\n"
|
||||||
|
|
||||||
|
before = lines[: function.starting_line - 1]
|
||||||
|
after = lines[function.ending_line :]
|
||||||
|
return "".join(before + new_lines + after)
|
||||||
|
|
||||||
def _contains_function_declaration(self, source: str, function_name: str, analyzer: TreeSitterAnalyzer) -> bool:
|
def _contains_function_declaration(self, source: str, function_name: str, analyzer: TreeSitterAnalyzer) -> bool:
|
||||||
"""Check if source contains a function declaration with the given name.
|
"""Check if source contains a function declaration with the given name.
|
||||||
|
|
||||||
|
|
@ -1912,7 +2033,52 @@ class JavaScriptSupport:
|
||||||
generated_tests = inject_test_globals(generated_tests, test_framework)
|
generated_tests = inject_test_globals(generated_tests, test_framework)
|
||||||
if is_typescript():
|
if is_typescript():
|
||||||
generated_tests = disable_ts_check(generated_tests)
|
generated_tests = disable_ts_check(generated_tests)
|
||||||
return normalize_generated_tests_imports(generated_tests)
|
generated_tests = normalize_generated_tests_imports(generated_tests)
|
||||||
|
|
||||||
|
# Apply React-specific post-processing (act wrapping, interaction markers, etc.)
|
||||||
|
if self._is_react_source(source_file_path):
|
||||||
|
generated_tests = self._apply_react_postprocessing(generated_tests, source_file_path)
|
||||||
|
|
||||||
|
return generated_tests
|
||||||
|
|
||||||
|
def _is_react_source(self, source_file_path: Path) -> bool:
|
||||||
|
"""Check if the source file is a React component."""
|
||||||
|
try:
|
||||||
|
content = source_file_path.read_text("utf-8", errors="replace")
|
||||||
|
return "react" in content.lower() and ("jsx" in content.lower() or source_file_path.suffix in (".tsx", ".jsx"))
|
||||||
|
except OSError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _apply_react_postprocessing(self, generated_tests: GeneratedTestsList, source_file_path: Path) -> GeneratedTestsList:
|
||||||
|
"""Apply React-specific post-processing to generated tests."""
|
||||||
|
from codeflash.languages.javascript.frameworks.react.testgen import ( # noqa: PLC0415
|
||||||
|
post_process_react_tests,
|
||||||
|
)
|
||||||
|
from codeflash.languages.javascript.frameworks.react.discovery import ComponentType, ReactComponentInfo # noqa: PLC0415
|
||||||
|
from codeflash.models.models import GeneratedTests, GeneratedTestsList # noqa: PLC0415
|
||||||
|
|
||||||
|
component_name = source_file_path.stem
|
||||||
|
dummy_info = ReactComponentInfo(
|
||||||
|
function_name=component_name,
|
||||||
|
component_type=ComponentType.FUNCTION,
|
||||||
|
)
|
||||||
|
|
||||||
|
updated_tests = []
|
||||||
|
for test in generated_tests.generated_tests:
|
||||||
|
updated_tests.append(
|
||||||
|
GeneratedTests(
|
||||||
|
generated_original_test_source=test.generated_original_test_source,
|
||||||
|
instrumented_behavior_test_source=post_process_react_tests(
|
||||||
|
test.instrumented_behavior_test_source, dummy_info
|
||||||
|
),
|
||||||
|
instrumented_perf_test_source=post_process_react_tests(
|
||||||
|
test.instrumented_perf_test_source, dummy_info
|
||||||
|
),
|
||||||
|
behavior_file_path=test.behavior_file_path,
|
||||||
|
perf_file_path=test.perf_file_path,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return GeneratedTestsList(generated_tests=updated_tests)
|
||||||
|
|
||||||
def remove_test_functions_from_generated_tests(
|
def remove_test_functions_from_generated_tests(
|
||||||
self, generated_tests: GeneratedTestsList, functions_to_remove: list[str]
|
self, generated_tests: GeneratedTestsList, functions_to_remove: list[str]
|
||||||
|
|
|
||||||
|
|
@ -401,6 +401,29 @@ class TreeSitterAnalyzer:
|
||||||
great_grandparent = grandparent.parent
|
great_grandparent = grandparent.parent
|
||||||
if great_grandparent and great_grandparent.type == "export_statement":
|
if great_grandparent and great_grandparent.type == "export_statement":
|
||||||
return True
|
return True
|
||||||
|
# HOC wrappers: export const Comp = forwardRef/memo(function/arrow)
|
||||||
|
# Tree: function → arguments → call_expression → variable_declarator → lexical_declaration → export_statement
|
||||||
|
if parent and parent.type == "arguments":
|
||||||
|
call_expr = parent.parent
|
||||||
|
if call_expr and call_expr.type == "call_expression":
|
||||||
|
declarator = call_expr.parent
|
||||||
|
if declarator and declarator.type == "variable_declarator":
|
||||||
|
lex_decl = declarator.parent
|
||||||
|
if lex_decl and lex_decl.type in ("lexical_declaration", "variable_declaration"):
|
||||||
|
export_stmt = lex_decl.parent
|
||||||
|
if export_stmt and export_stmt.type == "export_statement":
|
||||||
|
return True
|
||||||
|
# Nested HOC: export const X = memo(forwardRef(fn))
|
||||||
|
if declarator and declarator.type == "arguments":
|
||||||
|
outer_call = declarator.parent
|
||||||
|
if outer_call and outer_call.type == "call_expression":
|
||||||
|
outer_decl = outer_call.parent
|
||||||
|
if outer_decl and outer_decl.type == "variable_declarator":
|
||||||
|
lex_decl = outer_decl.parent
|
||||||
|
if lex_decl and lex_decl.type in ("lexical_declaration", "variable_declaration"):
|
||||||
|
export_stmt = lex_decl.parent
|
||||||
|
if export_stmt and export_stmt.type == "export_statement":
|
||||||
|
return True
|
||||||
|
|
||||||
# For methods in exported classes
|
# For methods in exported classes
|
||||||
if node.type == "method_definition":
|
if node.type == "method_definition":
|
||||||
|
|
@ -609,6 +632,8 @@ class TreeSitterAnalyzer:
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
_HOC_WRAPPER_NAMES = frozenset({"forwardRef", "memo", "React.forwardRef", "React.memo"})
|
||||||
|
|
||||||
def _get_name_from_assignment(self, node: Node, source_bytes: bytes) -> str:
|
def _get_name_from_assignment(self, node: Node, source_bytes: bytes) -> str:
|
||||||
"""Try to extract function name from parent variable declaration or assignment.
|
"""Try to extract function name from parent variable declaration or assignment.
|
||||||
|
|
||||||
|
|
@ -617,6 +642,8 @@ class TreeSitterAnalyzer:
|
||||||
- const foo = function() {}
|
- const foo = function() {}
|
||||||
- let bar = function() {}
|
- let bar = function() {}
|
||||||
- obj.method = () => {}
|
- obj.method = () => {}
|
||||||
|
- const Tooltip = forwardRef(function TooltipInner(...) {...})
|
||||||
|
- const MemoComp = React.memo((props) => {...})
|
||||||
"""
|
"""
|
||||||
parent = node.parent
|
parent = node.parent
|
||||||
if parent is None:
|
if parent is None:
|
||||||
|
|
@ -628,6 +655,30 @@ class TreeSitterAnalyzer:
|
||||||
if name_node:
|
if name_node:
|
||||||
return self.get_node_text(name_node, source_bytes)
|
return self.get_node_text(name_node, source_bytes)
|
||||||
|
|
||||||
|
# Check for HOC wrapper: const Name = forwardRef/memo(function/arrow)
|
||||||
|
# Tree structure: function → arguments → call_expression → variable_declarator
|
||||||
|
if parent.type == "arguments":
|
||||||
|
call_expr = parent.parent
|
||||||
|
if call_expr is not None and call_expr.type == "call_expression":
|
||||||
|
callee = call_expr.child_by_field_name("function")
|
||||||
|
if callee is not None:
|
||||||
|
callee_text = self.get_node_text(callee, source_bytes)
|
||||||
|
if callee_text in self._HOC_WRAPPER_NAMES:
|
||||||
|
declarator = call_expr.parent
|
||||||
|
if declarator is not None and declarator.type == "variable_declarator":
|
||||||
|
name_node = declarator.child_by_field_name("name")
|
||||||
|
if name_node:
|
||||||
|
return self.get_node_text(name_node, source_bytes)
|
||||||
|
# Nested HOC: memo(forwardRef(fn)) — call_expr parent is arguments of outer call
|
||||||
|
if declarator is not None and declarator.type == "arguments":
|
||||||
|
outer_call = declarator.parent
|
||||||
|
if outer_call is not None and outer_call.type == "call_expression":
|
||||||
|
outer_declarator = outer_call.parent
|
||||||
|
if outer_declarator is not None and outer_declarator.type == "variable_declarator":
|
||||||
|
name_node = outer_declarator.child_by_field_name("name")
|
||||||
|
if name_node:
|
||||||
|
return self.get_node_text(name_node, source_bytes)
|
||||||
|
|
||||||
# Check for assignment expression: foo = ...
|
# Check for assignment expression: foo = ...
|
||||||
if parent.type == "assignment_expression":
|
if parent.type == "assignment_expression":
|
||||||
left_node = parent.child_by_field_name("left")
|
left_node = parent.child_by_field_name("left")
|
||||||
|
|
|
||||||
|
|
@ -83,6 +83,7 @@ class TestDiffScope(str, Enum):
|
||||||
RETURN_VALUE = "return_value"
|
RETURN_VALUE = "return_value"
|
||||||
STDOUT = "stdout"
|
STDOUT = "stdout"
|
||||||
DID_PASS = "did_pass" # noqa: S105
|
DID_PASS = "did_pass" # noqa: S105
|
||||||
|
DOM_SNAPSHOT = "dom_snapshot"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
||||||
|
|
@ -96,6 +96,7 @@ from codeflash.models.models import (
|
||||||
OptimizedCandidateResult,
|
OptimizedCandidateResult,
|
||||||
OptimizedCandidateSource,
|
OptimizedCandidateSource,
|
||||||
OriginalCodeBaseline,
|
OriginalCodeBaseline,
|
||||||
|
TestDiffScope,
|
||||||
TestFile,
|
TestFile,
|
||||||
TestFiles,
|
TestFiles,
|
||||||
TestingMode,
|
TestingMode,
|
||||||
|
|
@ -2955,6 +2956,13 @@ class FunctionOptimizer:
|
||||||
logger.info("h3|Test results matched ✅")
|
logger.info("h3|Test results matched ✅")
|
||||||
console.rule()
|
console.rule()
|
||||||
else:
|
else:
|
||||||
|
dom_snapshot_diffs = [d for d in diffs if d.scope == TestDiffScope.DOM_SNAPSHOT]
|
||||||
|
if dom_snapshot_diffs:
|
||||||
|
logger.warning(
|
||||||
|
"[REACT] DOM snapshot divergence detected after %d interaction(s). "
|
||||||
|
"The optimized component produces different DOM output.",
|
||||||
|
len(dom_snapshot_diffs),
|
||||||
|
)
|
||||||
self.repair_if_possible(
|
self.repair_if_possible(
|
||||||
candidate, diffs, eval_ctx, code_context, len(candidate_behavior_results), exp_type
|
candidate, diffs, eval_ctx, code_context, len(candidate_behavior_results), exp_type
|
||||||
)
|
)
|
||||||
|
|
@ -3178,8 +3186,7 @@ class FunctionOptimizer:
|
||||||
coverage_config_file=coverage_config_file,
|
coverage_config_file=coverage_config_file,
|
||||||
skip_sqlite_cleanup=skip_cleanup,
|
skip_sqlite_cleanup=skip_cleanup,
|
||||||
)
|
)
|
||||||
if testing_type == TestingMode.PERFORMANCE:
|
results.perf_stdout = run_result.stdout
|
||||||
results.perf_stdout = run_result.stdout
|
|
||||||
return results, coverage_results
|
return results, coverage_results
|
||||||
# For LINE_PROFILE mode, Python uses .lprof files while JavaScript uses JSON
|
# For LINE_PROFILE mode, Python uses .lprof files while JavaScript uses JSON
|
||||||
# Return TestResults for JavaScript so _line_profiler_step_javascript can parse the JSON
|
# Return TestResults for JavaScript so _line_profiler_step_javascript can parse the JSON
|
||||||
|
|
|
||||||
|
|
@ -46,11 +46,11 @@ function _getReact() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to load better-sqlite3, fall back to JSON if not available
|
|
||||||
let useSqlite = false;
|
|
||||||
|
|
||||||
// Configuration from environment
|
// Configuration from environment
|
||||||
const OUTPUT_FILE = process.env.CODEFLASH_OUTPUT_FILE;
|
const OUTPUT_FILE = process.env.CODEFLASH_OUTPUT_FILE;
|
||||||
|
|
||||||
|
// Enable SQLite when better-sqlite3 is available and an output file is configured
|
||||||
|
let useSqlite = !!(Database && OUTPUT_FILE);
|
||||||
const LOOP_INDEX = parseInt(process.env.CODEFLASH_LOOP_INDEX || '1', 10);
|
const LOOP_INDEX = parseInt(process.env.CODEFLASH_LOOP_INDEX || '1', 10);
|
||||||
const TEST_ITERATION = process.env.CODEFLASH_TEST_ITERATION;
|
const TEST_ITERATION = process.env.CODEFLASH_TEST_ITERATION;
|
||||||
const TEST_MODULE = process.env.CODEFLASH_TEST_MODULE;
|
const TEST_MODULE = process.env.CODEFLASH_TEST_MODULE;
|
||||||
|
|
@ -1091,6 +1091,35 @@ function captureRender(funcName, lineId, renderFn, Component, ...createElementAr
|
||||||
return renderResult;
|
return renderResult;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Capture a DOM snapshot after a user interaction for behavioral verification.
|
||||||
|
*
|
||||||
|
* Writes normalized `document.body.innerHTML` to the SQLite database as an
|
||||||
|
* additional row with function name `__dom_snapshot__`. The existing
|
||||||
|
* `compare-results.js` comparator compares all rows by invocation ID, so
|
||||||
|
* snapshot rows are compared automatically.
|
||||||
|
*
|
||||||
|
* @param {string} label - Unique label for this snapshot (e.g. 'after_click_1')
|
||||||
|
*/
|
||||||
|
function snapshotDOM(label) {
|
||||||
|
if (!useSqlite || !db) return;
|
||||||
|
if (typeof document === 'undefined' || !document.body) return;
|
||||||
|
|
||||||
|
// Normalize HTML: collapse whitespace, strip React-internal attributes
|
||||||
|
let html = document.body.innerHTML;
|
||||||
|
html = html.replace(/\s+/g, ' ').trim();
|
||||||
|
html = html.replace(/\s*data-reactroot\s*/g, '');
|
||||||
|
|
||||||
|
const { testModulePath, testClassName, testFunctionName } = _getTestContext();
|
||||||
|
const invocationId = `snapshot_${label}`;
|
||||||
|
|
||||||
|
recordResult(
|
||||||
|
testModulePath, testClassName, testFunctionName,
|
||||||
|
'__dom_snapshot__', invocationId,
|
||||||
|
[label], html, null, 0
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Capture a React component render call for PERFORMANCE benchmarking only.
|
* Capture a React component render call for PERFORMANCE benchmarking only.
|
||||||
*
|
*
|
||||||
|
|
@ -1334,6 +1363,7 @@ module.exports = {
|
||||||
capturePerf, // Performance benchmarking (prints to stdout only)
|
capturePerf, // Performance benchmarking (prints to stdout only)
|
||||||
captureRender, // React render behavior verification (writes to SQLite)
|
captureRender, // React render behavior verification (writes to SQLite)
|
||||||
captureRenderPerf, // React render performance benchmarking (prints to stdout only)
|
captureRenderPerf, // React render performance benchmarking (prints to stdout only)
|
||||||
|
snapshotDOM, // DOM snapshot after interaction (behavior verification)
|
||||||
captureMultiple,
|
captureMultiple,
|
||||||
writeResults,
|
writeResults,
|
||||||
clearResults,
|
clearResults,
|
||||||
|
|
|
||||||
|
|
@ -167,9 +167,13 @@ function compareResults(originalResults, candidateResults) {
|
||||||
|
|
||||||
if (!isEqual) {
|
if (!isEqual) {
|
||||||
allEquivalent = false;
|
allEquivalent = false;
|
||||||
|
// Use dom_snapshot scope for __dom_snapshot__ rows
|
||||||
|
const scope = original.functionGettingTested === '__dom_snapshot__'
|
||||||
|
? 'dom_snapshot'
|
||||||
|
: 'return_value';
|
||||||
diffs.push({
|
diffs.push({
|
||||||
invocation_id: invocationId,
|
invocation_id: invocationId,
|
||||||
scope: 'return_value',
|
scope,
|
||||||
original: summarizeValue(originalValue),
|
original: summarizeValue(originalValue),
|
||||||
candidate: summarizeValue(candidateValue),
|
candidate: summarizeValue(candidateValue),
|
||||||
test_info: {
|
test_info: {
|
||||||
|
|
|
||||||
10
packages/codeflash/runtime/index.d.ts
vendored
10
packages/codeflash/runtime/index.d.ts
vendored
|
|
@ -36,6 +36,15 @@ export function capturePerf<T extends (...args: any[]) => any>(
|
||||||
...args: Parameters<T>
|
...args: Parameters<T>
|
||||||
): ReturnType<T>;
|
): ReturnType<T>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Capture a DOM snapshot after a user interaction for behavioral verification.
|
||||||
|
* Writes normalized document.body.innerHTML to the SQLite database with
|
||||||
|
* function name '__dom_snapshot__' so the comparator detects DOM divergence.
|
||||||
|
*
|
||||||
|
* @param label - Unique label for this snapshot (e.g. 'after_click_1')
|
||||||
|
*/
|
||||||
|
export function snapshotDOM(label: string): void;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Capture multiple invocations for benchmarking.
|
* Capture multiple invocations for benchmarking.
|
||||||
*
|
*
|
||||||
|
|
@ -126,6 +135,7 @@ export const TEST_ITERATION: string;
|
||||||
declare const codeflash: {
|
declare const codeflash: {
|
||||||
capture: typeof capture;
|
capture: typeof capture;
|
||||||
capturePerf: typeof capturePerf;
|
capturePerf: typeof capturePerf;
|
||||||
|
snapshotDOM: typeof snapshotDOM;
|
||||||
captureMultiple: typeof captureMultiple;
|
captureMultiple: typeof captureMultiple;
|
||||||
writeResults: typeof writeResults;
|
writeResults: typeof writeResults;
|
||||||
clearResults: typeof clearResults;
|
clearResults: typeof clearResults;
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,7 @@ module.exports = {
|
||||||
|
|
||||||
captureRender: capture.captureRender,
|
captureRender: capture.captureRender,
|
||||||
captureRenderPerf: capture.captureRenderPerf,
|
captureRenderPerf: capture.captureRenderPerf,
|
||||||
|
snapshotDOM: capture.snapshotDOM,
|
||||||
|
|
||||||
captureMultiple: capture.captureMultiple,
|
captureMultiple: capture.captureMultiple,
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -173,6 +173,8 @@ class TestRenderEfficiencyCritic:
|
||||||
optimized_render_duration=10.0,
|
optimized_render_duration=10.0,
|
||||||
original_update_render_count=8,
|
original_update_render_count=8,
|
||||||
optimized_update_render_count=8,
|
optimized_update_render_count=8,
|
||||||
|
original_update_duration=80.0,
|
||||||
|
optimized_update_duration=5.0,
|
||||||
) is True
|
) is True
|
||||||
|
|
||||||
def test_rejects_worse_than_best(self):
|
def test_rejects_worse_than_best(self):
|
||||||
|
|
@ -448,6 +450,8 @@ class TestRenderEfficiencyCriticTrustDuration:
|
||||||
optimized_render_duration=10.0,
|
optimized_render_duration=10.0,
|
||||||
original_update_render_count=8,
|
original_update_render_count=8,
|
||||||
optimized_update_render_count=8,
|
optimized_update_render_count=8,
|
||||||
|
original_update_duration=100.0,
|
||||||
|
optimized_update_duration=10.0,
|
||||||
trust_duration=True,
|
trust_duration=True,
|
||||||
) is True
|
) is True
|
||||||
|
|
||||||
|
|
|
||||||
78
tests/react/test_dom_snapshot.py
Normal file
78
tests/react/test_dom_snapshot.py
Normal file
|
|
@ -0,0 +1,78 @@
|
||||||
|
"""Tests for inject_dom_snapshot_calls in React behavioral verification."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from codeflash.languages.javascript.frameworks.react.testgen import inject_dom_snapshot_calls
|
||||||
|
|
||||||
|
|
||||||
|
def test_injects_after_fireEvent():
|
||||||
|
source = """\
|
||||||
|
import codeflash from 'codeflash';
|
||||||
|
const { container } = codeflash.captureRender('Comp', '1', render, Comp);
|
||||||
|
fireEvent.click(screen.getByText('Add'));
|
||||||
|
fireEvent.change(input, { target: { value: 'hi' } });
|
||||||
|
"""
|
||||||
|
result = inject_dom_snapshot_calls(source)
|
||||||
|
assert "codeflash.snapshotDOM('after_click_1');" in result
|
||||||
|
assert "codeflash.snapshotDOM('after_change_1');" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_skips_perf_mode():
|
||||||
|
source = """\
|
||||||
|
import codeflash from 'codeflash';
|
||||||
|
const result = await codeflash.captureRenderPerf('Comp', '1', render, Comp);
|
||||||
|
fireEvent.click(screen.getByText('Add'));
|
||||||
|
"""
|
||||||
|
result = inject_dom_snapshot_calls(source)
|
||||||
|
assert "snapshotDOM" not in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_skips_without_captureRender():
|
||||||
|
source = """\
|
||||||
|
import codeflash from 'codeflash';
|
||||||
|
fireEvent.click(screen.getByText('Add'));
|
||||||
|
"""
|
||||||
|
result = inject_dom_snapshot_calls(source)
|
||||||
|
assert "snapshotDOM" not in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_preserves_indentation():
|
||||||
|
source = """\
|
||||||
|
import codeflash from 'codeflash';
|
||||||
|
const { container } = codeflash.captureRender('Comp', '1', render, Comp);
|
||||||
|
fireEvent.click(btn);
|
||||||
|
fireEvent.change(input, { target: { value: 'x' } });
|
||||||
|
"""
|
||||||
|
result = inject_dom_snapshot_calls(source)
|
||||||
|
lines = result.split("\n")
|
||||||
|
# Find the snapshot lines and check their indentation
|
||||||
|
snapshot_lines = [l for l in lines if "snapshotDOM" in l]
|
||||||
|
assert len(snapshot_lines) == 2
|
||||||
|
assert snapshot_lines[0].startswith(" codeflash.snapshotDOM")
|
||||||
|
assert snapshot_lines[1].startswith(" codeflash.snapshotDOM")
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_semicolons():
|
||||||
|
"""Real-world projects (e.g. Zustand) don't use semicolons."""
|
||||||
|
source = """\
|
||||||
|
const { container } = codeflash.captureRender('Comp', '1', render, Comp)
|
||||||
|
fireEvent.click(screen.getByText('button'))
|
||||||
|
fireEvent.click(screen.getByTestId('test-shallow'))
|
||||||
|
"""
|
||||||
|
result = inject_dom_snapshot_calls(source)
|
||||||
|
assert "after_click_1" in result
|
||||||
|
assert "after_click_2" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_sequential_counter():
|
||||||
|
source = """\
|
||||||
|
import codeflash from 'codeflash';
|
||||||
|
const { container } = codeflash.captureRender('Comp', '1', render, Comp);
|
||||||
|
fireEvent.click(btn);
|
||||||
|
fireEvent.click(btn);
|
||||||
|
fireEvent.click(btn);
|
||||||
|
"""
|
||||||
|
result = inject_dom_snapshot_calls(source)
|
||||||
|
assert "after_click_1" in result
|
||||||
|
assert "after_click_2" in result
|
||||||
|
assert "after_click_3" in result
|
||||||
Loading…
Reference in a new issue