checkpoint
This commit is contained in:
parent
6f6bceb233
commit
1c80984933
2 changed files with 320 additions and 93 deletions
|
|
@ -4,16 +4,20 @@
|
|||
* This module provides a unified approach to instrumenting JavaScript tests
|
||||
* for both behavior verification and performance measurement.
|
||||
*
|
||||
* Unlike Python which has separate instrumentation methods for generated
|
||||
* vs existing tests, this helper works identically for ALL JavaScript tests.
|
||||
*
|
||||
* Uses SQLite for consistent data format with Python implementation.
|
||||
* The instrumentation mirrors Python's codeflash implementation:
|
||||
* - Static identifiers (testModule, testFunction, lineId) are passed at instrumentation time
|
||||
* - Dynamic invocation counter increments only when same call site is seen again (e.g., in loops)
|
||||
* - Uses hrtime for nanosecond precision timing
|
||||
* - SQLite for consistent data format with Python implementation
|
||||
*
|
||||
* Usage:
|
||||
* const codeflash = require('./codeflash-jest-helper');
|
||||
*
|
||||
* // Wrap function calls to capture behavior
|
||||
* const result = codeflash.capture('functionName', targetFunction, arg1, arg2);
|
||||
* // For behavior verification (writes to SQLite):
|
||||
* const result = codeflash.capture('functionName', lineId, targetFunction, arg1, arg2);
|
||||
*
|
||||
* // For performance benchmarking (stdout only):
|
||||
* const result = codeflash.capturePerf('functionName', lineId, targetFunction, arg1, arg2);
|
||||
*
|
||||
* Environment Variables:
|
||||
* CODEFLASH_OUTPUT_FILE - Path to write results SQLite file
|
||||
|
|
@ -24,7 +28,6 @@
|
|||
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const { performance } = require('perf_hooks');
|
||||
|
||||
// Load the codeflash serializer for robust value serialization
|
||||
const serializer = require('./codeflash-serializer');
|
||||
|
|
@ -46,10 +49,13 @@ const LOOP_INDEX = parseInt(process.env.CODEFLASH_LOOP_INDEX || '1', 10);
|
|||
const TEST_ITERATION = process.env.CODEFLASH_TEST_ITERATION || '0';
|
||||
const TEST_MODULE = process.env.CODEFLASH_TEST_MODULE || '';
|
||||
|
||||
// Current test context
|
||||
// Current test context (set by Jest hooks)
|
||||
let currentTestName = null;
|
||||
let invocationCounter = 0;
|
||||
let lineId = '0';
|
||||
|
||||
// Invocation counter map: tracks how many times each testId has been seen
|
||||
// Key: testId (testModule:testClass:testFunction:lineId:loopIndex)
|
||||
// Value: count (starts at 0, increments each time same key is seen)
|
||||
const invocationCounterMap = new Map();
|
||||
|
||||
// Results buffer (for JSON fallback)
|
||||
const results = [];
|
||||
|
|
@ -57,6 +63,60 @@ const results = [];
|
|||
// SQLite database (lazy initialized)
|
||||
let db = null;
|
||||
|
||||
/**
|
||||
* Get high-resolution time in nanoseconds.
|
||||
* Prefers process.hrtime.bigint() for nanosecond precision,
|
||||
* falls back to performance.now() * 1e6 for non-Node environments.
|
||||
*
|
||||
* @returns {bigint|number} - Time in nanoseconds
|
||||
*/
|
||||
function getTimeNs() {
|
||||
if (typeof process !== 'undefined' && process.hrtime && process.hrtime.bigint) {
|
||||
return process.hrtime.bigint();
|
||||
}
|
||||
// Fallback to performance.now() in milliseconds, converted to nanoseconds
|
||||
const { performance } = require('perf_hooks');
|
||||
return BigInt(Math.floor(performance.now() * 1_000_000));
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate duration in nanoseconds.
|
||||
*
|
||||
* @param {bigint} start - Start time in nanoseconds
|
||||
* @param {bigint} end - End time in nanoseconds
|
||||
* @returns {number} - Duration in nanoseconds (as Number for SQLite compatibility)
|
||||
*/
|
||||
function getDurationNs(start, end) {
|
||||
const duration = end - start;
|
||||
// Convert to Number for SQLite storage (SQLite INTEGER is 64-bit)
|
||||
return Number(duration);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get or create invocation index for a testId.
|
||||
* This mirrors Python's index tracking per wrapper function.
|
||||
*
|
||||
* @param {string} testId - Unique test identifier
|
||||
* @returns {number} - Current invocation index (0-based)
|
||||
*/
|
||||
function getInvocationIndex(testId) {
|
||||
const currentIndex = invocationCounterMap.get(testId);
|
||||
if (currentIndex === undefined) {
|
||||
invocationCounterMap.set(testId, 0);
|
||||
return 0;
|
||||
}
|
||||
invocationCounterMap.set(testId, currentIndex + 1);
|
||||
return currentIndex + 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset invocation counter for a test.
|
||||
* Called at the start of each test to ensure consistent indexing.
|
||||
*/
|
||||
function resetInvocationCounters() {
|
||||
invocationCounterMap.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize the SQLite database.
|
||||
*/
|
||||
|
|
@ -86,15 +146,6 @@ function initDatabase() {
|
|||
|
||||
/**
|
||||
* Safely serialize a value for storage.
|
||||
* Uses the codeflash-serializer which:
|
||||
* - Prefers V8 serialization (fast, handles all JS types natively)
|
||||
* - Falls back to msgpack with custom extensions (for Bun/browser)
|
||||
*
|
||||
* This provides robust serialization for:
|
||||
* - All primitive types (including NaN, Infinity, BigInt, Symbol)
|
||||
* - Complex objects (Map, Set, Date, RegExp, Error)
|
||||
* - TypedArrays and ArrayBuffer
|
||||
* - Circular references
|
||||
*
|
||||
* @param {any} value - Value to serialize
|
||||
* @returns {Buffer} - Serialized value as Buffer
|
||||
|
|
@ -103,8 +154,6 @@ function safeSerialize(value) {
|
|||
try {
|
||||
return serializer.serialize(value);
|
||||
} catch (e) {
|
||||
// If serialization fails, return a JSON error marker
|
||||
// This should be rare with the robust serializer
|
||||
console.warn('[codeflash] Serialization failed:', e.message);
|
||||
return Buffer.from(JSON.stringify({ __type: 'SerializationError', error: e.message }));
|
||||
}
|
||||
|
|
@ -112,7 +161,6 @@ function safeSerialize(value) {
|
|||
|
||||
/**
|
||||
* Safely deserialize a buffer back to a value.
|
||||
* Uses the codeflash-serializer to restore the original value.
|
||||
*
|
||||
* @param {Buffer|Uint8Array} buffer - Serialized buffer
|
||||
* @returns {any} - Deserialized value
|
||||
|
|
@ -129,19 +177,17 @@ function safeDeserialize(buffer) {
|
|||
/**
|
||||
* Record a test result to SQLite or JSON buffer.
|
||||
*
|
||||
* @param {string} testModulePath - Test module path
|
||||
* @param {string|null} testClassName - Test class name (null for Jest)
|
||||
* @param {string} testFunctionName - Test function name
|
||||
* @param {string} funcName - Name of the function being tested
|
||||
* @param {string} invocationId - Unique invocation identifier (lineId_index)
|
||||
* @param {Array} args - Arguments passed to the function
|
||||
* @param {any} returnValue - Return value from the function
|
||||
* @param {Error|null} error - Error thrown by the function (if any)
|
||||
* @param {number} durationNs - Execution time in nanoseconds
|
||||
*/
|
||||
function recordResult(funcName, args, returnValue, error, durationNs) {
|
||||
const invocationId = `${lineId}_${invocationCounter}`;
|
||||
invocationCounter++;
|
||||
|
||||
// Get test module path from file being tested or env
|
||||
const testModulePath = TEST_MODULE || currentTestName || 'unknown';
|
||||
|
||||
function recordResult(testModulePath, testClassName, testFunctionName, funcName, invocationId, args, returnValue, error, durationNs) {
|
||||
// Serialize the return value (args, kwargs (empty for JS), return_value) like Python does
|
||||
const serializedValue = error
|
||||
? safeSerialize(error)
|
||||
|
|
@ -154,12 +200,12 @@ function recordResult(funcName, args, returnValue, error, durationNs) {
|
|||
`);
|
||||
stmt.run(
|
||||
testModulePath, // test_module_path
|
||||
null, // test_class_name (Jest doesn't use classes like Python)
|
||||
currentTestName, // test_function_name
|
||||
testClassName, // test_class_name
|
||||
testFunctionName, // test_function_name
|
||||
funcName, // function_getting_tested
|
||||
LOOP_INDEX, // loop_index
|
||||
invocationId, // iteration_id
|
||||
Math.round(durationNs), // runtime (nanoseconds)
|
||||
durationNs, // runtime (nanoseconds) - no rounding
|
||||
serializedValue, // return_value (serialized)
|
||||
'function_call' // verification_type
|
||||
);
|
||||
|
|
@ -168,12 +214,12 @@ function recordResult(funcName, args, returnValue, error, durationNs) {
|
|||
// Fall back to JSON
|
||||
results.push({
|
||||
testModulePath,
|
||||
testClassName: null,
|
||||
testFunctionName: currentTestName,
|
||||
testClassName,
|
||||
testFunctionName,
|
||||
funcName,
|
||||
loopIndex: LOOP_INDEX,
|
||||
iterationId: invocationId,
|
||||
durationNs: Math.round(durationNs),
|
||||
durationNs,
|
||||
returnValue: error ? null : returnValue,
|
||||
error: error ? { name: error.name, message: error.message } : null,
|
||||
verificationType: 'function_call'
|
||||
|
|
@ -183,42 +229,60 @@ function recordResult(funcName, args, returnValue, error, durationNs) {
|
|||
// JSON fallback
|
||||
results.push({
|
||||
testModulePath,
|
||||
testClassName: null,
|
||||
testFunctionName: currentTestName,
|
||||
testClassName,
|
||||
testFunctionName,
|
||||
funcName,
|
||||
loopIndex: LOOP_INDEX,
|
||||
iterationId: invocationId,
|
||||
durationNs: Math.round(durationNs),
|
||||
durationNs,
|
||||
returnValue: error ? null : returnValue,
|
||||
error: error ? { name: error.name, message: error.message } : null,
|
||||
verificationType: 'function_call'
|
||||
});
|
||||
}
|
||||
|
||||
// Print stdout tag like Python does for test identification
|
||||
const testClassName = '';
|
||||
const testStdoutTag = `${testModulePath}:${testClassName}${currentTestName}:${funcName}:${LOOP_INDEX}:${invocationId}`;
|
||||
console.log(`!$######${testStdoutTag}######$!`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Capture a function call with full behavior tracking.
|
||||
*
|
||||
* This is the main API for instrumenting function calls for BEHAVIOR verification.
|
||||
* It captures inputs (after call, to detect mutations), outputs, errors, and timing.
|
||||
* It captures inputs, outputs, errors, and timing.
|
||||
* Results are written to SQLite for comparison between original and optimized code.
|
||||
*
|
||||
* @param {string} funcName - Name of the function being tested
|
||||
* Static parameters (funcName, lineId) are determined at instrumentation time.
|
||||
* The lineId enables tracking when the same call site is invoked multiple times (e.g., in loops).
|
||||
*
|
||||
* @param {string} funcName - Name of the function being tested (static)
|
||||
* @param {string} lineId - Line number identifier in test file (static)
|
||||
* @param {Function} fn - The function to call
|
||||
* @param {...any} args - Arguments to pass to the function
|
||||
* @returns {any} - The function's return value
|
||||
* @throws {Error} - Re-throws any error from the function
|
||||
*/
|
||||
function capture(funcName, fn, ...args) {
|
||||
function capture(funcName, lineId, fn, ...args) {
|
||||
// Initialize database on first capture
|
||||
initDatabase();
|
||||
|
||||
const startTime = performance.now();
|
||||
// Get test context
|
||||
const testModulePath = TEST_MODULE || currentTestName || 'unknown';
|
||||
const testClassName = null; // Jest doesn't use classes like Python
|
||||
const testFunctionName = currentTestName || 'unknown';
|
||||
|
||||
// Create testId for invocation tracking (matches Python format)
|
||||
const testId = `${testModulePath}:${testClassName}:${testFunctionName}:${lineId}:${LOOP_INDEX}`;
|
||||
|
||||
// Get invocation index (increments if same testId seen again)
|
||||
const invocationIndex = getInvocationIndex(testId);
|
||||
const invocationId = `${lineId}_${invocationIndex}`;
|
||||
|
||||
// Format stdout tag (matches Python format)
|
||||
const testStdoutTag = `${testModulePath}:${testClassName ? testClassName + '.' : ''}${testFunctionName}:${funcName}:${LOOP_INDEX}:${invocationId}`;
|
||||
|
||||
// Print start tag
|
||||
console.log(`!$######${testStdoutTag}######$!`);
|
||||
|
||||
// Timing with nanosecond precision
|
||||
const startTime = getTimeNs();
|
||||
let returnValue;
|
||||
let error = null;
|
||||
|
||||
|
|
@ -229,16 +293,18 @@ function capture(funcName, fn, ...args) {
|
|||
if (returnValue instanceof Promise) {
|
||||
return returnValue.then(
|
||||
(resolved) => {
|
||||
const endTime = performance.now();
|
||||
const durationNs = (endTime - startTime) * 1_000_000;
|
||||
// Note: args is captured AFTER the call to detect mutations
|
||||
recordResult(funcName, args, resolved, null, durationNs);
|
||||
const endTime = getTimeNs();
|
||||
const durationNs = getDurationNs(startTime, endTime);
|
||||
recordResult(testModulePath, testClassName, testFunctionName, funcName, invocationId, args, resolved, null, durationNs);
|
||||
// Print end tag (no duration for behavior mode)
|
||||
console.log(`!######${testStdoutTag}######!`);
|
||||
return resolved;
|
||||
},
|
||||
(err) => {
|
||||
const endTime = performance.now();
|
||||
const durationNs = (endTime - startTime) * 1_000_000;
|
||||
recordResult(funcName, args, null, err, durationNs);
|
||||
const endTime = getTimeNs();
|
||||
const durationNs = getDurationNs(startTime, endTime);
|
||||
recordResult(testModulePath, testClassName, testFunctionName, funcName, invocationId, args, null, err, durationNs);
|
||||
console.log(`!######${testStdoutTag}######!`);
|
||||
throw err;
|
||||
}
|
||||
);
|
||||
|
|
@ -247,10 +313,12 @@ function capture(funcName, fn, ...args) {
|
|||
error = e;
|
||||
}
|
||||
|
||||
const endTime = performance.now();
|
||||
const durationNs = (endTime - startTime) * 1_000_000;
|
||||
// Note: args is captured AFTER the call to detect mutations (same as Python)
|
||||
recordResult(funcName, args, returnValue, error, durationNs);
|
||||
const endTime = getTimeNs();
|
||||
const durationNs = getDurationNs(startTime, endTime);
|
||||
recordResult(testModulePath, testClassName, testFunctionName, funcName, invocationId, args, returnValue, error, durationNs);
|
||||
|
||||
// Print end tag (no duration for behavior mode, matching Python)
|
||||
console.log(`!######${testStdoutTag}######!`);
|
||||
|
||||
if (error) throw error;
|
||||
return returnValue;
|
||||
|
|
@ -263,63 +331,78 @@ function capture(funcName, fn, ...args) {
|
|||
* It prints start/end tags to stdout (no SQLite writes, no serialization overhead).
|
||||
* Used when we've already verified behavior and just need accurate timing.
|
||||
*
|
||||
* The timing measurement is done exactly around the function call for accuracy.
|
||||
*
|
||||
* Output format matches Python's codeflash_performance wrapper:
|
||||
* Start: !$######test_module:test_class.test_name:func_name:loop_index:invocation_id######$!
|
||||
* End: !######test_module:test_class.test_name:func_name:loop_index:invocation_id:duration_ns######!
|
||||
*
|
||||
* @param {string} funcName - Name of the function being tested
|
||||
* @param {string} funcName - Name of the function being tested (static)
|
||||
* @param {string} lineId - Line number identifier in test file (static)
|
||||
* @param {Function} fn - The function to call
|
||||
* @param {...any} args - Arguments to pass to the function
|
||||
* @returns {any} - The function's return value
|
||||
* @throws {Error} - Re-throws any error from the function
|
||||
*/
|
||||
function capturePerf(funcName, fn, ...args) {
|
||||
const invocationId = `${lineId}_${invocationCounter}`;
|
||||
invocationCounter++;
|
||||
|
||||
function capturePerf(funcName, lineId, fn, ...args) {
|
||||
// Get test context
|
||||
const testModulePath = TEST_MODULE || currentTestName || 'unknown';
|
||||
const testClassName = ''; // Jest doesn't use classes like Python
|
||||
const testClassName = null; // Jest doesn't use classes like Python
|
||||
const testFunctionName = currentTestName || 'unknown';
|
||||
|
||||
// Format: test_module:test_class.test_name:func_name:loop_index:invocation_id
|
||||
const testStdoutTag = `${testModulePath}:${testClassName}${currentTestName}:${funcName}:${LOOP_INDEX}:${invocationId}`;
|
||||
// Create testId for invocation tracking (matches Python format)
|
||||
const testId = `${testModulePath}:${testClassName}:${testFunctionName}:${lineId}:${LOOP_INDEX}`;
|
||||
|
||||
// Get invocation index (increments if same testId seen again)
|
||||
const invocationIndex = getInvocationIndex(testId);
|
||||
const invocationId = `${lineId}_${invocationIndex}`;
|
||||
|
||||
// Format stdout tag (matches Python format)
|
||||
const testStdoutTag = `${testModulePath}:${testClassName ? testClassName + '.' : ''}${testFunctionName}:${funcName}:${LOOP_INDEX}:${invocationId}`;
|
||||
|
||||
// Print start tag
|
||||
console.log(`!$######${testStdoutTag}######$!`);
|
||||
|
||||
const startTime = performance.now();
|
||||
// Timing with nanosecond precision - exactly around the function call
|
||||
let returnValue;
|
||||
let error = null;
|
||||
let durationNs;
|
||||
|
||||
try {
|
||||
const startTime = getTimeNs();
|
||||
returnValue = fn(...args);
|
||||
const endTime = getTimeNs();
|
||||
durationNs = getDurationNs(startTime, endTime);
|
||||
|
||||
// Handle promises (async functions)
|
||||
if (returnValue instanceof Promise) {
|
||||
return returnValue.then(
|
||||
(resolved) => {
|
||||
const endTime = performance.now();
|
||||
const durationNs = Math.round((endTime - startTime) * 1_000_000);
|
||||
// For async, we measure until resolution
|
||||
const asyncEndTime = getTimeNs();
|
||||
const asyncDurationNs = getDurationNs(startTime, asyncEndTime);
|
||||
// Print end tag with timing
|
||||
console.log(`!######${testStdoutTag}:${durationNs}######!`);
|
||||
console.log(`!######${testStdoutTag}:${asyncDurationNs}######!`);
|
||||
return resolved;
|
||||
},
|
||||
(err) => {
|
||||
const endTime = performance.now();
|
||||
const durationNs = Math.round((endTime - startTime) * 1_000_000);
|
||||
const asyncEndTime = getTimeNs();
|
||||
const asyncDurationNs = getDurationNs(startTime, asyncEndTime);
|
||||
// Print end tag with timing even on error
|
||||
console.log(`!######${testStdoutTag}:${durationNs}######!`);
|
||||
console.log(`!######${testStdoutTag}:${asyncDurationNs}######!`);
|
||||
throw err;
|
||||
}
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
const endTime = getTimeNs();
|
||||
// For sync errors, we still need to calculate duration
|
||||
// Use a fallback if we didn't capture startTime yet
|
||||
durationNs = 0;
|
||||
error = e;
|
||||
}
|
||||
|
||||
const endTime = performance.now();
|
||||
const durationNs = Math.round((endTime - startTime) * 1_000_000);
|
||||
// Print end tag with timing
|
||||
// Print end tag with timing (no rounding)
|
||||
console.log(`!######${testStdoutTag}:${durationNs}######!`);
|
||||
|
||||
if (error) throw error;
|
||||
|
|
@ -330,12 +413,13 @@ function capturePerf(funcName, fn, ...args) {
|
|||
* Capture multiple invocations for benchmarking.
|
||||
*
|
||||
* @param {string} funcName - Name of the function being tested
|
||||
* @param {string} lineId - Line number identifier
|
||||
* @param {Function} fn - The function to call
|
||||
* @param {Array<Array>} argsList - List of argument arrays to test
|
||||
* @returns {Array} - Array of return values
|
||||
*/
|
||||
function captureMultiple(funcName, fn, argsList) {
|
||||
return argsList.map(args => capture(funcName, fn, ...args));
|
||||
function captureMultiple(funcName, lineId, fn, argsList) {
|
||||
return argsList.map(args => capture(funcName, lineId, fn, ...args));
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -378,7 +462,7 @@ function writeResults() {
|
|||
*/
|
||||
function clearResults() {
|
||||
results.length = 0;
|
||||
invocationCounter = 0;
|
||||
resetInvocationCounters();
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -399,7 +483,7 @@ function getResults() {
|
|||
*/
|
||||
function setTestName(name) {
|
||||
currentTestName = name;
|
||||
invocationCounter = 0;
|
||||
resetInvocationCounters();
|
||||
}
|
||||
|
||||
// Jest lifecycle hooks - these run automatically when this module is imported
|
||||
|
|
@ -411,8 +495,8 @@ if (typeof beforeEach !== 'undefined') {
|
|||
} catch (e) {
|
||||
currentTestName = 'unknown';
|
||||
}
|
||||
invocationCounter = 0;
|
||||
lineId = String(Date.now() % 1000000); // Unique line ID per test
|
||||
// Reset invocation counters for each test
|
||||
resetInvocationCounters();
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -434,6 +518,8 @@ module.exports = {
|
|||
safeSerialize,
|
||||
safeDeserialize,
|
||||
initDatabase,
|
||||
resetInvocationCounters,
|
||||
getInvocationIndex,
|
||||
// Serializer info
|
||||
getSerializerType: serializer.getSerializerType,
|
||||
// Constants
|
||||
|
|
|
|||
|
|
@ -41,6 +41,9 @@ def instrument_javascript_tests(
|
|||
This is a UNIFIED approach - works for both generated and existing tests.
|
||||
The instrumentation wraps function calls to capture inputs, outputs, and timing.
|
||||
|
||||
Static identifiers (funcName, lineId) are determined at instrumentation time.
|
||||
The lineId enables tracking when the same call site is invoked multiple times (e.g., in loops).
|
||||
|
||||
Args:
|
||||
test_source: The JavaScript test source code
|
||||
function_name: The name of the function being tested
|
||||
|
|
@ -105,17 +108,10 @@ def instrument_javascript_tests(
|
|||
# Pattern for standalone function calls (not method calls)
|
||||
pattern = rf"(?<![.\w]){re.escape(function_name)}\s*\(([^)]*)\)"
|
||||
|
||||
def replace_call(match: re.Match) -> str:
|
||||
args = match.group(1).strip()
|
||||
if args:
|
||||
return f"codeflash.{capture_func}('{function_name}', {function_name}, {args})"
|
||||
else:
|
||||
return f"codeflash.{capture_func}('{function_name}', {function_name})"
|
||||
|
||||
# Apply replacement carefully - avoid replacing inside strings
|
||||
# This is a simplified approach that works for most cases
|
||||
instrumented_source = _safe_replace_function_calls(
|
||||
instrumented_source, function_name, replace_call, pattern
|
||||
# Apply replacement line by line to track line numbers
|
||||
# This enables tracking when the same call site is invoked multiple times (e.g., in loops)
|
||||
instrumented_source = _safe_replace_function_calls_with_lineid(
|
||||
instrumented_source, function_name, capture_func, pattern
|
||||
)
|
||||
|
||||
# Comment out expect() assertions - we use captured behavior for verification instead
|
||||
|
|
@ -210,6 +206,151 @@ def _comment_out_expects(source: str) -> str:
|
|||
return '\n'.join(result_lines)
|
||||
|
||||
|
||||
def _safe_replace_function_calls_with_lineid(
|
||||
source: str,
|
||||
function_name: str,
|
||||
capture_func: str,
|
||||
pattern: str,
|
||||
) -> str:
|
||||
"""
|
||||
Replace function calls while tracking line numbers.
|
||||
|
||||
Each function call is wrapped with a static lineId that enables tracking
|
||||
when the same call site is invoked multiple times (e.g., in loops).
|
||||
|
||||
Args:
|
||||
source: The JavaScript source code
|
||||
function_name: Name of the function to wrap
|
||||
capture_func: The capture function to use ('capture' or 'capturePerf')
|
||||
pattern: Regex pattern to match function calls
|
||||
"""
|
||||
lines = source.split('\n')
|
||||
result_lines = []
|
||||
|
||||
for line_num, line in enumerate(lines, start=1):
|
||||
# Process each line independently to track line numbers
|
||||
# Use the line number as the lineId
|
||||
|
||||
# Skip lines that are in strings or comments (simplified check)
|
||||
if line.strip().startswith('//') or line.strip().startswith('/*'):
|
||||
result_lines.append(line)
|
||||
continue
|
||||
|
||||
# Track position within the line
|
||||
processed_line = _replace_calls_in_line(
|
||||
line, function_name, capture_func, pattern, line_num
|
||||
)
|
||||
result_lines.append(processed_line)
|
||||
|
||||
return '\n'.join(result_lines)
|
||||
|
||||
|
||||
def _replace_calls_in_line(
|
||||
line: str,
|
||||
function_name: str,
|
||||
capture_func: str,
|
||||
pattern: str,
|
||||
line_num: int,
|
||||
) -> str:
|
||||
"""
|
||||
Replace function calls in a single line with capture wrapper.
|
||||
|
||||
Handles string literals and avoids replacing inside them.
|
||||
"""
|
||||
result = []
|
||||
i = 0
|
||||
length = len(line)
|
||||
call_index = 0 # Track multiple calls on the same line
|
||||
|
||||
while i < length:
|
||||
char = line[i]
|
||||
|
||||
# Skip string literals
|
||||
if char in "'\"`":
|
||||
quote_char = char
|
||||
result.append(char)
|
||||
i += 1
|
||||
|
||||
# Handle template literals with ${} expressions
|
||||
if quote_char == "`":
|
||||
while i < length:
|
||||
if line[i] == "\\":
|
||||
result.append(line[i:i+2] if i+1 < length else line[i:])
|
||||
i += min(2, length - i)
|
||||
elif line[i] == "$" and i + 1 < length and line[i + 1] == "{":
|
||||
# Template expression - need to handle nested braces
|
||||
result.append(line[i:i+2])
|
||||
i += 2
|
||||
brace_count = 1
|
||||
while i < length and brace_count > 0:
|
||||
if line[i] == "{":
|
||||
brace_count += 1
|
||||
elif line[i] == "}":
|
||||
brace_count -= 1
|
||||
result.append(line[i])
|
||||
i += 1
|
||||
elif line[i] == quote_char:
|
||||
result.append(line[i])
|
||||
i += 1
|
||||
break
|
||||
else:
|
||||
result.append(line[i])
|
||||
i += 1
|
||||
else:
|
||||
# Regular string
|
||||
while i < length:
|
||||
if line[i] == "\\":
|
||||
result.append(line[i:i+2] if i+1 < length else line[i:])
|
||||
i += min(2, length - i)
|
||||
elif line[i] == quote_char:
|
||||
result.append(line[i])
|
||||
i += 1
|
||||
break
|
||||
else:
|
||||
result.append(line[i])
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Check for function call pattern
|
||||
remaining = line[i:]
|
||||
match = re.match(pattern, remaining)
|
||||
if match:
|
||||
# Check that we're not preceded by a dot (method call) or already wrapped
|
||||
if i > 0 and line[i - 1] == ".":
|
||||
result.append(char)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Check we haven't already wrapped this
|
||||
check_start = max(0, i - 20)
|
||||
preceding = line[check_start:i]
|
||||
if "codeflash.capture" in preceding or "codeflash.capturePerf" in preceding:
|
||||
result.append(char)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Generate lineId: line_number_call_index
|
||||
# For multiple calls on same line, we add call_index to distinguish them
|
||||
line_id = f"{line_num}_{call_index}" if call_index > 0 else str(line_num)
|
||||
call_index += 1
|
||||
|
||||
# Build replacement: codeflash.capture('funcName', 'lineId', func, args)
|
||||
args = match.group(1).strip()
|
||||
if args:
|
||||
replacement = f"codeflash.{capture_func}('{function_name}', '{line_id}', {function_name}, {args})"
|
||||
else:
|
||||
replacement = f"codeflash.{capture_func}('{function_name}', '{line_id}', {function_name})"
|
||||
|
||||
result.append(replacement)
|
||||
i += match.end()
|
||||
continue
|
||||
|
||||
result.append(char)
|
||||
i += 1
|
||||
|
||||
return "".join(result)
|
||||
|
||||
|
||||
def _safe_replace_function_calls(
|
||||
source: str,
|
||||
function_name: str,
|
||||
|
|
|
|||
Loading…
Reference in a new issue