checkpoint

This commit is contained in:
misrasaurabh1 2026-01-15 15:57:46 -08:00
parent 6f6bceb233
commit 1c80984933
2 changed files with 320 additions and 93 deletions

View file

@ -4,16 +4,20 @@
* This module provides a unified approach to instrumenting JavaScript tests
* for both behavior verification and performance measurement.
*
* Unlike Python which has separate instrumentation methods for generated
* vs existing tests, this helper works identically for ALL JavaScript tests.
*
* Uses SQLite for consistent data format with Python implementation.
* The instrumentation mirrors Python's codeflash implementation:
* - Static identifiers (testModule, testFunction, lineId) are passed at instrumentation time
* - Dynamic invocation counter increments only when same call site is seen again (e.g., in loops)
* - Uses hrtime for nanosecond precision timing
* - SQLite for consistent data format with Python implementation
*
* Usage:
* const codeflash = require('./codeflash-jest-helper');
*
* // Wrap function calls to capture behavior
* const result = codeflash.capture('functionName', targetFunction, arg1, arg2);
* // For behavior verification (writes to SQLite):
* const result = codeflash.capture('functionName', lineId, targetFunction, arg1, arg2);
*
* // For performance benchmarking (stdout only):
* const result = codeflash.capturePerf('functionName', lineId, targetFunction, arg1, arg2);
*
* Environment Variables:
* CODEFLASH_OUTPUT_FILE - Path to write results SQLite file
@ -24,7 +28,6 @@
const fs = require('fs');
const path = require('path');
const { performance } = require('perf_hooks');
// Load the codeflash serializer for robust value serialization
const serializer = require('./codeflash-serializer');
@ -46,10 +49,13 @@ const LOOP_INDEX = parseInt(process.env.CODEFLASH_LOOP_INDEX || '1', 10);
const TEST_ITERATION = process.env.CODEFLASH_TEST_ITERATION || '0';
const TEST_MODULE = process.env.CODEFLASH_TEST_MODULE || '';
// Current test context
// Current test context (set by Jest hooks)
let currentTestName = null;
let invocationCounter = 0;
let lineId = '0';
// Invocation counter map: tracks how many times each testId has been seen
// Key: testId (testModule:testClass:testFunction:lineId:loopIndex)
// Value: count (starts at 0, increments each time same key is seen)
const invocationCounterMap = new Map();
// Results buffer (for JSON fallback)
const results = [];
@ -57,6 +63,60 @@ const results = [];
// SQLite database (lazy initialized)
let db = null;
/**
 * Get a high-resolution timestamp in nanoseconds.
 *
 * Prefers process.hrtime.bigint() (true nanosecond resolution). When that is
 * unavailable, falls back to performance.now() (milliseconds, converted to
 * nanoseconds). The fallback uses the global `performance` object when one
 * exists, so it also works in runtimes that have no `perf_hooks` module;
 * `require('perf_hooks')` is only used as a last resort.
 *
 * @returns {bigint} - Time in nanoseconds (always a BigInt, on every path)
 */
function getTimeNs() {
  if (typeof process !== 'undefined' && typeof process.hrtime?.bigint === 'function') {
    return process.hrtime.bigint();
  }
  // Fallback clock: performance.now() reports fractional milliseconds.
  const clock =
    typeof globalThis.performance?.now === 'function'
      ? globalThis.performance
      : require('perf_hooks').performance;
  // Floor before converting: BigInt() throws on non-integer numbers.
  return BigInt(Math.floor(clock.now() * 1_000_000));
}
/**
 * Compute the elapsed time between two nanosecond timestamps.
 *
 * The subtraction is done on the BigInt timestamps and the difference is then
 * converted to a plain Number, which is the form stored in SQLite and printed
 * in the stdout tags.
 *
 * @param {bigint} start - Timestamp taken before the measured work
 * @param {bigint} end - Timestamp taken after the measured work
 * @returns {number} - Elapsed nanoseconds (end - start) as a Number
 */
function getDurationNs(start, end) {
  return Number(end - start);
}
/**
 * Return the invocation index for a call site and advance its counter.
 *
 * The first time a testId is seen it is registered with index 0. Every later
 * call with the same testId bumps the stored value by one and returns the new
 * value, so repeated hits on one call site (e.g. inside a loop) yield 0, 1, 2, ...
 *
 * @param {string} testId - Unique test identifier used as the counter key
 * @returns {number} - Invocation index for this call (0-based)
 */
function getInvocationIndex(testId) {
  const previous = invocationCounterMap.get(testId);
  const next = previous === undefined ? 0 : previous + 1;
  invocationCounterMap.set(testId, next);
  return next;
}
/**
 * Reset all invocation counters.
 *
 * Empties invocationCounterMap, so the next getInvocationIndex() call for any
 * testId starts again at 0. Takes no arguments and returns nothing.
 */
function resetInvocationCounters() {
// Drops every tracked testId at once, not only those of the current test.
invocationCounterMap.clear();
}
/**
* Initialize the SQLite database.
*/
@ -86,15 +146,6 @@ function initDatabase() {
/**
* Safely serialize a value for storage.
* Uses the codeflash-serializer which:
* - Prefers V8 serialization (fast, handles all JS types natively)
* - Falls back to msgpack with custom extensions (for Bun/browser)
*
* This provides robust serialization for:
* - All primitive types (including NaN, Infinity, BigInt, Symbol)
* - Complex objects (Map, Set, Date, RegExp, Error)
* - TypedArrays and ArrayBuffer
* - Circular references
*
* @param {any} value - Value to serialize
* @returns {Buffer} - Serialized value as Buffer
@ -103,8 +154,6 @@ function safeSerialize(value) {
try {
return serializer.serialize(value);
} catch (e) {
// If serialization fails, return a JSON error marker
// This should be rare with the robust serializer
console.warn('[codeflash] Serialization failed:', e.message);
return Buffer.from(JSON.stringify({ __type: 'SerializationError', error: e.message }));
}
@ -112,7 +161,6 @@ function safeSerialize(value) {
/**
* Safely deserialize a buffer back to a value.
* Uses the codeflash-serializer to restore the original value.
*
* @param {Buffer|Uint8Array} buffer - Serialized buffer
* @returns {any} - Deserialized value
@ -129,19 +177,17 @@ function safeDeserialize(buffer) {
/**
* Record a test result to SQLite or JSON buffer.
*
* @param {string} testModulePath - Test module path
* @param {string|null} testClassName - Test class name (null for Jest)
* @param {string} testFunctionName - Test function name
* @param {string} funcName - Name of the function being tested
* @param {string} invocationId - Unique invocation identifier (lineId_index)
* @param {Array} args - Arguments passed to the function
* @param {any} returnValue - Return value from the function
* @param {Error|null} error - Error thrown by the function (if any)
* @param {number} durationNs - Execution time in nanoseconds
*/
function recordResult(funcName, args, returnValue, error, durationNs) {
const invocationId = `${lineId}_${invocationCounter}`;
invocationCounter++;
// Get test module path from file being tested or env
const testModulePath = TEST_MODULE || currentTestName || 'unknown';
function recordResult(testModulePath, testClassName, testFunctionName, funcName, invocationId, args, returnValue, error, durationNs) {
// Serialize the return value (args, kwargs (empty for JS), return_value) like Python does
const serializedValue = error
? safeSerialize(error)
@ -154,12 +200,12 @@ function recordResult(funcName, args, returnValue, error, durationNs) {
`);
stmt.run(
testModulePath, // test_module_path
null, // test_class_name (Jest doesn't use classes like Python)
currentTestName, // test_function_name
testClassName, // test_class_name
testFunctionName, // test_function_name
funcName, // function_getting_tested
LOOP_INDEX, // loop_index
invocationId, // iteration_id
Math.round(durationNs), // runtime (nanoseconds)
durationNs, // runtime (nanoseconds) - no rounding
serializedValue, // return_value (serialized)
'function_call' // verification_type
);
@ -168,12 +214,12 @@ function recordResult(funcName, args, returnValue, error, durationNs) {
// Fall back to JSON
results.push({
testModulePath,
testClassName: null,
testFunctionName: currentTestName,
testClassName,
testFunctionName,
funcName,
loopIndex: LOOP_INDEX,
iterationId: invocationId,
durationNs: Math.round(durationNs),
durationNs,
returnValue: error ? null : returnValue,
error: error ? { name: error.name, message: error.message } : null,
verificationType: 'function_call'
@ -183,42 +229,60 @@ function recordResult(funcName, args, returnValue, error, durationNs) {
// JSON fallback
results.push({
testModulePath,
testClassName: null,
testFunctionName: currentTestName,
testClassName,
testFunctionName,
funcName,
loopIndex: LOOP_INDEX,
iterationId: invocationId,
durationNs: Math.round(durationNs),
durationNs,
returnValue: error ? null : returnValue,
error: error ? { name: error.name, message: error.message } : null,
verificationType: 'function_call'
});
}
// Print stdout tag like Python does for test identification
const testClassName = '';
const testStdoutTag = `${testModulePath}:${testClassName}${currentTestName}:${funcName}:${LOOP_INDEX}:${invocationId}`;
console.log(`!$######${testStdoutTag}######$!`);
}
/**
* Capture a function call with full behavior tracking.
*
* This is the main API for instrumenting function calls for BEHAVIOR verification.
* It captures inputs (after call, to detect mutations), outputs, errors, and timing.
* It captures inputs, outputs, errors, and timing.
* Results are written to SQLite for comparison between original and optimized code.
*
* @param {string} funcName - Name of the function being tested
* Static parameters (funcName, lineId) are determined at instrumentation time.
* The lineId enables tracking when the same call site is invoked multiple times (e.g., in loops).
*
* @param {string} funcName - Name of the function being tested (static)
* @param {string} lineId - Line number identifier in test file (static)
* @param {Function} fn - The function to call
* @param {...any} args - Arguments to pass to the function
* @returns {any} - The function's return value
* @throws {Error} - Re-throws any error from the function
*/
function capture(funcName, fn, ...args) {
function capture(funcName, lineId, fn, ...args) {
// Initialize database on first capture
initDatabase();
const startTime = performance.now();
// Get test context
const testModulePath = TEST_MODULE || currentTestName || 'unknown';
const testClassName = null; // Jest doesn't use classes like Python
const testFunctionName = currentTestName || 'unknown';
// Create testId for invocation tracking (matches Python format)
const testId = `${testModulePath}:${testClassName}:${testFunctionName}:${lineId}:${LOOP_INDEX}`;
// Get invocation index (increments if same testId seen again)
const invocationIndex = getInvocationIndex(testId);
const invocationId = `${lineId}_${invocationIndex}`;
// Format stdout tag (matches Python format)
const testStdoutTag = `${testModulePath}:${testClassName ? testClassName + '.' : ''}${testFunctionName}:${funcName}:${LOOP_INDEX}:${invocationId}`;
// Print start tag
console.log(`!$######${testStdoutTag}######$!`);
// Timing with nanosecond precision
const startTime = getTimeNs();
let returnValue;
let error = null;
@ -229,16 +293,18 @@ function capture(funcName, fn, ...args) {
if (returnValue instanceof Promise) {
return returnValue.then(
(resolved) => {
const endTime = performance.now();
const durationNs = (endTime - startTime) * 1_000_000;
// Note: args is captured AFTER the call to detect mutations
recordResult(funcName, args, resolved, null, durationNs);
const endTime = getTimeNs();
const durationNs = getDurationNs(startTime, endTime);
recordResult(testModulePath, testClassName, testFunctionName, funcName, invocationId, args, resolved, null, durationNs);
// Print end tag (no duration for behavior mode)
console.log(`!######${testStdoutTag}######!`);
return resolved;
},
(err) => {
const endTime = performance.now();
const durationNs = (endTime - startTime) * 1_000_000;
recordResult(funcName, args, null, err, durationNs);
const endTime = getTimeNs();
const durationNs = getDurationNs(startTime, endTime);
recordResult(testModulePath, testClassName, testFunctionName, funcName, invocationId, args, null, err, durationNs);
console.log(`!######${testStdoutTag}######!`);
throw err;
}
);
@ -247,10 +313,12 @@ function capture(funcName, fn, ...args) {
error = e;
}
const endTime = performance.now();
const durationNs = (endTime - startTime) * 1_000_000;
// Note: args is captured AFTER the call to detect mutations (same as Python)
recordResult(funcName, args, returnValue, error, durationNs);
const endTime = getTimeNs();
const durationNs = getDurationNs(startTime, endTime);
recordResult(testModulePath, testClassName, testFunctionName, funcName, invocationId, args, returnValue, error, durationNs);
// Print end tag (no duration for behavior mode, matching Python)
console.log(`!######${testStdoutTag}######!`);
if (error) throw error;
return returnValue;
@ -263,63 +331,78 @@ function capture(funcName, fn, ...args) {
* It prints start/end tags to stdout (no SQLite writes, no serialization overhead).
* Used when we've already verified behavior and just need accurate timing.
*
* The timing measurement is done exactly around the function call for accuracy.
*
* Output format matches Python's codeflash_performance wrapper:
* Start: !$######test_module:test_class.test_name:func_name:loop_index:invocation_id######$!
* End: !######test_module:test_class.test_name:func_name:loop_index:invocation_id:duration_ns######!
*
* @param {string} funcName - Name of the function being tested
* @param {string} funcName - Name of the function being tested (static)
* @param {string} lineId - Line number identifier in test file (static)
* @param {Function} fn - The function to call
* @param {...any} args - Arguments to pass to the function
* @returns {any} - The function's return value
* @throws {Error} - Re-throws any error from the function
*/
function capturePerf(funcName, fn, ...args) {
const invocationId = `${lineId}_${invocationCounter}`;
invocationCounter++;
function capturePerf(funcName, lineId, fn, ...args) {
// Get test context
const testModulePath = TEST_MODULE || currentTestName || 'unknown';
const testClassName = ''; // Jest doesn't use classes like Python
const testClassName = null; // Jest doesn't use classes like Python
const testFunctionName = currentTestName || 'unknown';
// Format: test_module:test_class.test_name:func_name:loop_index:invocation_id
const testStdoutTag = `${testModulePath}:${testClassName}${currentTestName}:${funcName}:${LOOP_INDEX}:${invocationId}`;
// Create testId for invocation tracking (matches Python format)
const testId = `${testModulePath}:${testClassName}:${testFunctionName}:${lineId}:${LOOP_INDEX}`;
// Get invocation index (increments if same testId seen again)
const invocationIndex = getInvocationIndex(testId);
const invocationId = `${lineId}_${invocationIndex}`;
// Format stdout tag (matches Python format)
const testStdoutTag = `${testModulePath}:${testClassName ? testClassName + '.' : ''}${testFunctionName}:${funcName}:${LOOP_INDEX}:${invocationId}`;
// Print start tag
console.log(`!$######${testStdoutTag}######$!`);
const startTime = performance.now();
// Timing with nanosecond precision - exactly around the function call
let returnValue;
let error = null;
let durationNs;
try {
const startTime = getTimeNs();
returnValue = fn(...args);
const endTime = getTimeNs();
durationNs = getDurationNs(startTime, endTime);
// Handle promises (async functions)
if (returnValue instanceof Promise) {
return returnValue.then(
(resolved) => {
const endTime = performance.now();
const durationNs = Math.round((endTime - startTime) * 1_000_000);
// For async, we measure until resolution
const asyncEndTime = getTimeNs();
const asyncDurationNs = getDurationNs(startTime, asyncEndTime);
// Print end tag with timing
console.log(`!######${testStdoutTag}:${durationNs}######!`);
console.log(`!######${testStdoutTag}:${asyncDurationNs}######!`);
return resolved;
},
(err) => {
const endTime = performance.now();
const durationNs = Math.round((endTime - startTime) * 1_000_000);
const asyncEndTime = getTimeNs();
const asyncDurationNs = getDurationNs(startTime, asyncEndTime);
// Print end tag with timing even on error
console.log(`!######${testStdoutTag}:${durationNs}######!`);
console.log(`!######${testStdoutTag}:${asyncDurationNs}######!`);
throw err;
}
);
}
} catch (e) {
const endTime = getTimeNs();
// For sync errors, we still need to calculate duration
// Use a fallback if we didn't capture startTime yet
durationNs = 0;
error = e;
}
const endTime = performance.now();
const durationNs = Math.round((endTime - startTime) * 1_000_000);
// Print end tag with timing
// Print end tag with timing (no rounding)
console.log(`!######${testStdoutTag}:${durationNs}######!`);
if (error) throw error;
@ -330,12 +413,13 @@ function capturePerf(funcName, fn, ...args) {
* Capture multiple invocations for benchmarking.
*
* @param {string} funcName - Name of the function being tested
* @param {string} lineId - Line number identifier
* @param {Function} fn - The function to call
* @param {Array<Array>} argsList - List of argument arrays to test
* @returns {Array} - Array of return values
*/
// NOTE(review): two captureMultiple signatures appear back to back here (without and
// with `lineId`); this looks like a before/after pair - confirm only one remains in the file.
function captureMultiple(funcName, fn, argsList) {
return argsList.map(args => capture(funcName, fn, ...args));
// Variant with a static lineId: each argument array in argsList is spread into a
// separate capture() call that shares the same funcName and lineId.
function captureMultiple(funcName, lineId, fn, argsList) {
return argsList.map(args => capture(funcName, lineId, fn, ...args));
}
/**
@ -378,7 +462,7 @@ function writeResults() {
*/
// Empties the buffered results and resets invocation tracking.
function clearResults() {
// Truncate the shared `results` array in place.
results.length = 0;
// NOTE(review): `invocationCounter` looks like the older single global counter; its
// map-based replacement is the call on the next line - confirm which one belongs here.
invocationCounter = 0;
// Clear the per-call-site invocation counters.
resetInvocationCounters();
}
/**
@ -399,7 +483,7 @@ function getResults() {
*/
// Sets the current test name and restarts invocation tracking for it.
function setTestName(name) {
// Stored in module state; capture()/capturePerf() read it to build test identifiers.
currentTestName = name;
// NOTE(review): `invocationCounter` looks like the older single global counter; its
// map-based replacement is the call on the next line - confirm which one belongs here.
invocationCounter = 0;
// Start the new test with empty per-call-site counters.
resetInvocationCounters();
}
// Jest lifecycle hooks - these run automatically when this module is imported
@ -411,8 +495,8 @@ if (typeof beforeEach !== 'undefined') {
} catch (e) {
currentTestName = 'unknown';
}
invocationCounter = 0;
lineId = String(Date.now() % 1000000); // Unique line ID per test
// Reset invocation counters for each test
resetInvocationCounters();
});
}
@ -434,6 +518,8 @@ module.exports = {
safeSerialize,
safeDeserialize,
initDatabase,
resetInvocationCounters,
getInvocationIndex,
// Serializer info
getSerializerType: serializer.getSerializerType,
// Constants

View file

@ -41,6 +41,9 @@ def instrument_javascript_tests(
This is a UNIFIED approach - works for both generated and existing tests.
The instrumentation wraps function calls to capture inputs, outputs, and timing.
Static identifiers (funcName, lineId) are determined at instrumentation time.
The lineId enables tracking when the same call site is invoked multiple times (e.g., in loops).
Args:
test_source: The JavaScript test source code
function_name: The name of the function being tested
@ -105,17 +108,10 @@ def instrument_javascript_tests(
# Pattern for standalone function calls (not method calls)
pattern = rf"(?<![.\w]){re.escape(function_name)}\s*\(([^)]*)\)"
def replace_call(match: re.Match) -> str:
args = match.group(1).strip()
if args:
return f"codeflash.{capture_func}('{function_name}', {function_name}, {args})"
else:
return f"codeflash.{capture_func}('{function_name}', {function_name})"
# Apply replacement carefully - avoid replacing inside strings
# This is a simplified approach that works for most cases
instrumented_source = _safe_replace_function_calls(
instrumented_source, function_name, replace_call, pattern
# Apply replacement line by line to track line numbers
# This enables tracking when the same call site is invoked multiple times (e.g., in loops)
instrumented_source = _safe_replace_function_calls_with_lineid(
instrumented_source, function_name, capture_func, pattern
)
# Comment out expect() assertions - we use captured behavior for verification instead
@ -210,6 +206,151 @@ def _comment_out_expects(source: str) -> str:
return '\n'.join(result_lines)
def _safe_replace_function_calls_with_lineid(
source: str,
function_name: str,
capture_func: str,
pattern: str,
) -> str:
"""
Replace function calls while tracking line numbers.
Each function call is wrapped with a static lineId that enables tracking
when the same call site is invoked multiple times (e.g., in loops).
Args:
source: The JavaScript source code
function_name: Name of the function to wrap
capture_func: The capture function to use ('capture' or 'capturePerf')
pattern: Regex pattern to match function calls
"""
lines = source.split('\n')
result_lines = []
for line_num, line in enumerate(lines, start=1):
# Process each line independently to track line numbers
# Use the line number as the lineId
# Skip lines that are in strings or comments (simplified check)
if line.strip().startswith('//') or line.strip().startswith('/*'):
result_lines.append(line)
continue
# Track position within the line
processed_line = _replace_calls_in_line(
line, function_name, capture_func, pattern, line_num
)
result_lines.append(processed_line)
return '\n'.join(result_lines)
def _replace_calls_in_line(
line: str,
function_name: str,
capture_func: str,
pattern: str,
line_num: int,
) -> str:
"""
Replace function calls in a single line with capture wrapper.
Handles string literals and avoids replacing inside them.
"""
result = []
i = 0
length = len(line)
call_index = 0 # Track multiple calls on the same line
while i < length:
char = line[i]
# Skip string literals
if char in "'\"`":
quote_char = char
result.append(char)
i += 1
# Handle template literals with ${} expressions
if quote_char == "`":
while i < length:
if line[i] == "\\":
result.append(line[i:i+2] if i+1 < length else line[i:])
i += min(2, length - i)
elif line[i] == "$" and i + 1 < length and line[i + 1] == "{":
# Template expression - need to handle nested braces
result.append(line[i:i+2])
i += 2
brace_count = 1
while i < length and brace_count > 0:
if line[i] == "{":
brace_count += 1
elif line[i] == "}":
brace_count -= 1
result.append(line[i])
i += 1
elif line[i] == quote_char:
result.append(line[i])
i += 1
break
else:
result.append(line[i])
i += 1
else:
# Regular string
while i < length:
if line[i] == "\\":
result.append(line[i:i+2] if i+1 < length else line[i:])
i += min(2, length - i)
elif line[i] == quote_char:
result.append(line[i])
i += 1
break
else:
result.append(line[i])
i += 1
continue
# Check for function call pattern
remaining = line[i:]
match = re.match(pattern, remaining)
if match:
# Check that we're not preceded by a dot (method call) or already wrapped
if i > 0 and line[i - 1] == ".":
result.append(char)
i += 1
continue
# Check we haven't already wrapped this
check_start = max(0, i - 20)
preceding = line[check_start:i]
if "codeflash.capture" in preceding or "codeflash.capturePerf" in preceding:
result.append(char)
i += 1
continue
# Generate lineId: line_number_call_index
# For multiple calls on same line, we add call_index to distinguish them
line_id = f"{line_num}_{call_index}" if call_index > 0 else str(line_num)
call_index += 1
# Build replacement: codeflash.capture('funcName', 'lineId', func, args)
args = match.group(1).strip()
if args:
replacement = f"codeflash.{capture_func}('{function_name}', '{line_id}', {function_name}, {args})"
else:
replacement = f"codeflash.{capture_func}('{function_name}', '{line_id}', {function_name})"
result.append(replacement)
i += match.end()
continue
result.append(char)
i += 1
return "".join(result)
def _safe_replace_function_calls(
source: str,
function_name: str,