feat(13-01): archive valuable Python prompts with metadata

- Archive jit_system_prompt.md (301 lines of JIT testing guidance)
- Archive explain_system_prompt.md (code explanation system prompt)
- Archive explain_user_prompt.md (code explanation user prompt)
- Add metadata headers with original location and archival information
This commit is contained in:
Kevin Turcios 2026-01-31 18:21:48 +00:00
parent aefaffe622
commit 25224fc997
3 changed files with 334 additions and 0 deletions


@@ -0,0 +1,10 @@
# ARCHIVED: explain_system_prompt.md
**Original Location:** django/aiservice/testgen/explain_system_prompt.md
**Archived Date:** 2026-01-31
**Reason:** Valuable content preserved during v1.2 prompt cleanup
**Status:** INACTIVE - Not used by current system
---
**Role**: You are a world-class Python developer with an eagle eye for unintended bugs and edge cases. You carefully
explain code with great detail and accuracy. You organize your explanations in markdown-formatted, bulleted lists.


@@ -0,0 +1,15 @@
# ARCHIVED: explain_user_prompt.md
**Original Location:** django/aiservice/testgen/explain_user_prompt.md
**Archived Date:** 2026-01-31
**Reason:** Valuable content preserved during v1.2 prompt cleanup
**Status:** INACTIVE - Not used by current system
---
Please explain the following Python function '{function_name}'. Review what each element of the function is doing
precisely and what the author's intentions may have been. Organize your explanation as a markdown-formatted, bulleted
list.
```python
{function_code}
```


@@ -0,0 +1,309 @@
# ARCHIVED: jit_system_prompt.md
**Original Location:** django/aiservice/testgen/jit_system_prompt.md
**Archived Date:** 2026-01-31
**Reason:** Valuable content preserved during v1.2 prompt cleanup
**Status:** INACTIVE - Not used by current system
---
# JIT Compilation Test Writing Guidelines
When generating tests for functions that may be optimized with JIT compilation (Numba's `@njit`, PyTorch's `@torch.compile`, TensorFlow's `@tf.function`, or JAX's `@jax.jit`), follow these guidelines to ensure tests remain valid after optimization:
## Key Principles
### 1. Use Concrete, Typed Inputs
JIT compilers require consistent types. Always use concrete values with consistent types across test cases:
- Use numpy arrays instead of Python lists for numerical data
- Specify dtypes explicitly (e.g., `np.array([1, 2, 3], dtype=np.float64)`)
- Avoid mixing types in the same parameter across different test cases
### 2. Avoid Python Object Features
JIT-compiled functions have limited support for Python objects. In tests:
- Do not test with custom Python class instances as inputs (unless the function is clearly designed for them)
- Avoid dictionary inputs with dynamic keys
- Prefer numpy arrays and primitive types (int, float, bool)
### 3. Handle Numerical Precision
JIT compilation may produce slightly different floating-point results. When testing numerical functions:
- Use `np.allclose()`/`torch.allclose()`/`jax.numpy.allclose()`/`tf.experimental.numpy.allclose()`/`pytest.approx()` for floating-point comparisons instead of exact equality
- Allow for small numerical tolerances (e.g., `rtol=1e-7, atol=1e-10`)
### 4. Do Not Mock JIT-Decorated Functions
Never mock or patch JIT-decorated functions directly. Instead:
- Test the function's actual behavior
- Mock only external I/O dependencies if absolutely necessary
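This separation can be sketched with a hypothetical `load_and_process` pipeline (the names are illustrative, not from any real codebase): the numeric kernel runs for real, and only the `np.load` I/O call is patched.

```python
import unittest.mock as mock

import numpy as np

# Hypothetical code under test: `process` is the kernel that would carry
# the JIT decorator (e.g. @njit). It is exercised directly, never patched.
def process(arr):
    return arr * 2.0

def load_and_process(path):
    data = np.load(path)  # external I/O - the only thing we mock
    return process(data)

def test_load_and_process():
    fake = np.array([1.0, 2.0, 3.0], dtype=np.float64)
    # Patch only the I/O boundary; the kernel's real behavior is tested.
    with mock.patch("numpy.load", return_value=fake):
        result = load_and_process("unused.npy")
    assert np.allclose(result, [2.0, 4.0, 6.0])

test_load_and_process()
```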
### 5. Large Scale Tests Should Use Appropriate Data Types
For performance/scalability tests with large data:
- Use contiguous arrays where possible
- Ensure array dtypes are JIT-compatible (float64, float32, int64, int32, etc.)
- Avoid object dtypes in arrays
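These dtype and contiguity rules can be centralized in a small test helper (a sketch; `as_jit_input` and `JIT_SAFE_DTYPES` are our own names, not part of any framework):

```python
import numpy as np

# Dtypes that the major JIT backends accept without object-mode fallback.
JIT_SAFE_DTYPES = {np.float32, np.float64, np.int32, np.int64}

def as_jit_input(data, dtype=np.float64):
    """Coerce test data into a contiguous, JIT-friendly array."""
    arr = np.ascontiguousarray(data, dtype=dtype)
    assert arr.dtype.type in JIT_SAFE_DTYPES, f"unsupported dtype {arr.dtype}"
    assert arr.flags["C_CONTIGUOUS"]
    return arr

arr = as_jit_input([[1, 2], [3, 4]])          # list -> contiguous float64 array
view = as_jit_input(np.zeros((4, 4)).T)       # F-ordered view -> C-contiguous copy
```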
## Example Test Patterns
### Good: Typed numpy input
```python
def test_compute_sum():
arr = np.array([1.0, 2.0, 3.0], dtype=np.float64)
result = compute_sum(arr)
assert np.isclose(result, 6.0)
```
### Good: Tolerance-based comparison
```python
def test_matrix_operation():
a = np.random.rand(100, 100).astype(np.float64)
b = np.random.rand(100, 100).astype(np.float64)
result = matrix_multiply(a, b)
expected = np.dot(a, b)
assert np.allclose(result, expected, rtol=1e-7)
```
### Avoid: Python list input for numerical functions
```python
# Avoid this pattern for functions that may be JIT-compiled
def test_compute_sum():
result = compute_sum([1, 2, 3]) # Python list - may cause JIT issues
```
### Avoid: Mixed types
```python
# Avoid mixing types across test cases
def test_function():
result1 = func(np.array([1, 2, 3])) # int array
result2 = func(np.array([1.0, 2.0])) # float array - different type signature
```
---
## Framework-Specific Guidelines
### Numba (`@njit`, `@jit`)
**Supported Types:**
- Primitive types: `int`, `float`, `bool`, `complex`
- NumPy arrays with numeric dtypes (`float32`, `float64`, `int32`, `int64`, `complex64`, `complex128`)
- NumPy scalars
- Tuples of supported types (homogeneous preferred)
- Named tuples (with type annotations)
**Test Input Guidelines:**
```python
# Good: Explicit dtype, contiguous array
arr = np.array([1.0, 2.0, 3.0], dtype=np.float64)
arr = np.ascontiguousarray(data)  # force C-contiguity on an existing array `data`
# Good: Scalar inputs with consistent types
result = numba_func(1.0, 2.0) # Both floats
result = numba_func(np.float64(1.0), np.float64(2.0)) # Explicit numpy scalars
# Avoid: Python objects, strings, dicts, sets, or lists as inputs
# Avoid: Object dtype arrays (dtype=object)
# Avoid: Structured arrays with complex nested types
```
**Common Pitfalls to Avoid in Tests:**
- Do not pass Python lists directly - convert to numpy arrays first
- Do not use `None` as default arguments in test inputs (unless the function explicitly handles it in nopython mode)
- Do not pass non-contiguous array views without checking if the function supports them
- Do not mix `int32` and `int64` across test cases for the same parameter
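A minimal sketch that follows these rules, with a no-op fallback decorator so the test also runs where numba is not installed (the fallback is our own addition, not a numba feature):

```python
import numpy as np

try:
    from numba import njit  # real JIT compilation when numba is installed
except ImportError:
    # Fallback: a decorator that returns the function unchanged, so the
    # same test exercises the pure-Python path.
    def njit(func=None, **kwargs):
        if func is None:
            return lambda f: f
        return func

@njit
def scaled_sum(arr, factor):
    total = 0.0
    for x in arr:
        total += x * factor
    return total

def test_scaled_sum():
    # Contiguous float64 input and an explicit numpy scalar - the same
    # type signature on every call, no int/float mixing.
    arr = np.ascontiguousarray([1.0, 2.0, 3.0], dtype=np.float64)
    result = scaled_sum(arr, np.float64(2.0))
    assert np.isclose(result, 12.0)

test_scaled_sum()
```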
**Numba-Specific Assertions:**
```python
# Use numpy comparison functions
assert np.allclose(result, expected, rtol=1e-7, atol=1e-14)
assert np.array_equal(int_result, expected_int) # For integer results
```
---
### PyTorch (`@torch.compile`)
**Supported Types:**
- `torch.Tensor` with any dtype (`float32`, `float64`, `int32`, `int64`, `bfloat16`, etc.)
- Python primitives (`int`, `float`, `bool`) - but prefer tensors
- Tuples and lists of tensors
- Dictionaries with string keys and tensor values
**Test Input Guidelines:**
```python
# Good: Explicit dtype and device
x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
x = torch.randn(100, 100, dtype=torch.float64)
# Good: Consistent device placement
device = torch.device('cpu') # or 'cuda' if available
x = torch.tensor([1.0, 2.0], device=device)
# Good: Using torch generator for reproducibility
g = torch.Generator().manual_seed(42)
x = torch.randn(100, 100, generator=g)
# Avoid: Mixing CPU and CUDA tensors in same operation
# Avoid: In-place operations on leaf tensors that require grad (if testing gradients)
# Avoid: Dynamic control flow that changes based on tensor values (causes graph breaks)
```
**Common Pitfalls to Avoid in Tests:**
- Do not use `tensor.item()` inside compiled functions - extract values outside
- Do not modify tensor shapes dynamically within compiled regions
- Do not use Python `print()` or logging inside compiled functions
- Do not rely on specific compilation behavior - test functional correctness only
**PyTorch-Specific Assertions:**
```python
# Use torch comparison functions
assert torch.allclose(result, expected, rtol=1e-5, atol=1e-8)
assert torch.equal(int_result, expected_int) # For exact integer comparison
# For gradient testing
assert torch.allclose(tensor.grad, expected_grad, rtol=1e-4)
```
**Graph Break Considerations:**
When testing `@torch.compile` functions, avoid inputs that cause graph breaks:
- Data-dependent control flow
- Calls to non-compilable Python functions
- Dynamic shapes (unless using `dynamic=True`)
---
### TensorFlow (`@tf.function`)
**Supported Types:**
- `tf.Tensor` with any dtype
- `tf.Variable` (be cautious with mutations)
- Python primitives (traced as constants - be aware of retracing)
- `tf.TensorSpec` for input signatures
- Nested structures of tensors (lists, tuples, dicts)
**Test Input Guidelines:**
```python
# Good: Explicit dtype specification
x = tf.constant([1.0, 2.0, 3.0], dtype=tf.float32)
x = tf.random.normal([100, 100], dtype=tf.float64)
# Good: Using input_signature to prevent retracing
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 10], dtype=tf.float32)])
def my_func(x):
...
# Good: Setting seeds for reproducibility
tf.random.set_seed(42)
x = tf.random.normal([100, 100])
# Avoid: Python objects that can't be converted to tensors
# Avoid: Changing Python argument values between calls (causes retracing)
# Avoid: Using tensor.numpy() inside @tf.function
```
**Common Pitfalls to Avoid in Tests:**
- Do not pass different Python primitive values across test cases (causes retracing)
- Do not use `tensor.numpy()` inside tf.function - only use outside
- Do not test with `tf.Variable` mutations unless the function is designed for it
- Do not rely on Python side effects inside tf.function (they only execute during tracing)
**TensorFlow-Specific Assertions:**
```python
# Use numpy conversion for assertions (outside tf.function)
result_np = result.numpy()
expected_np = expected.numpy()
np.testing.assert_allclose(result_np, expected_np, rtol=1e-5, atol=1e-8)
# Or use tf functions
tf.debugging.assert_near(result, expected, rtol=1e-5, atol=1e-8)
# For exact comparison
tf.debugging.assert_equal(int_result, expected_int)
```
**AutoGraph Considerations:**
- Python `if` statements are converted to `tf.cond` - ensure both branches have same output structure
- Python `for`/`while` loops are converted to `tf.while_loop` - ensure loop variables have consistent types
- Avoid `break`/`continue` in complex nested loops
---
### JAX (`@jax.jit`)
**Supported Types:**
- JAX arrays (`jax.Array`, `jnp.ndarray`)
- NumPy arrays (converted automatically)
- Python primitives (traced as tracers - use `static_argnums` for compile-time constants)
- PyTrees: nested structures of arrays (tuples, lists, dicts, namedtuples)
**Test Input Guidelines:**
```python
# Good: Using jax.numpy for array creation
x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
x = jax.random.normal(jax.random.PRNGKey(42), (100, 100))
# Good: Explicit PRNG key handling (JAX requires explicit randomness)
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (100,))
# Good: Using static_argnums for arguments that drive shape logic
# (requires: from functools import partial)
@partial(jax.jit, static_argnums=(1,))
def func(x, axis):  # 'axis' is a compile-time constant, not traced
    return jnp.sum(x, axis=axis)
# Avoid: In-place mutations (JAX arrays are immutable)
# Avoid: Data-dependent shapes
# Avoid: Side effects (prints, global variable modifications)
```
**Common Pitfalls to Avoid in Tests:**
- Do not test with in-place array modifications - JAX arrays are immutable
- Do not pass different static argument values without recompiling
- Do not use Python random module - use `jax.random` with explicit keys
- Do not rely on side effects - JAX functions must be pure
**JAX-Specific Assertions:**
```python
# Use jax.numpy or numpy for comparisons
assert jnp.allclose(result, expected, rtol=1e-5, atol=1e-8)
assert jnp.array_equal(int_result, expected_int)
# For testing gradients
grad_fn = jax.grad(func)
computed_grad = grad_fn(x)
assert jnp.allclose(computed_grad, expected_grad, rtol=1e-4)
```
**Functional Purity Requirements:**
JAX requires functions to be pure (no side effects). In tests:
- Do not expect `print()` statements to execute during JIT-compiled calls
- Do not test functions that modify global state
- Ensure random operations use explicit `PRNGKey` arguments
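The explicit-seed discipline carries over directly to test fixtures. A plain-numpy analogue of the pattern (a sketch; with JAX installed you would thread `jax.random.PRNGKey` values instead of a seeded `Generator`):

```python
import numpy as np

# Pure-function pattern: randomness enters only through an explicit seed,
# mirroring JAX's PRNGKey discipline. No global RNG state is touched.
def make_inputs(seed):
    rng = np.random.default_rng(seed)  # explicit, local generator
    return rng.normal(size=(4,))

def test_reproducible_inputs():
    # Same seed, same data: the test is deterministic across runs.
    a = make_inputs(42)
    b = make_inputs(42)
    assert np.array_equal(a, b)
    # A different seed yields an independent draw.
    assert not np.array_equal(a, make_inputs(7))

test_reproducible_inputs()
```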
**PyTree Handling:**
```python
# Good: Testing with PyTree inputs
inputs = {'a': jnp.array([1.0, 2.0]), 'b': jnp.array([3.0, 4.0])}
result = jitted_func(inputs)
# Good: Testing with nested structures
inputs = (jnp.array([1.0]), {'x': jnp.array([2.0, 3.0])})
result = jitted_func(inputs)
```
---
## Cross-Framework Compatibility Notes
When a function might be optimized with multiple JIT frameworks:
1. **Prefer numpy arrays as the common input format** - all frameworks can consume them
2. **Use framework-agnostic assertions when possible:**
```python
# Convert to numpy for comparison
result_np = np.asarray(result) # Works for torch, tf, jax, numpy
np.testing.assert_allclose(result_np, expected, rtol=1e-5)
```
3. **Test with float64 by default** for maximum precision compatibility
4. **Avoid framework-specific features in test inputs** unless testing that specific functionality
5. **For PyTorch, TensorFlow, JAX, and MLX**, write separate CPU-only, CUDA-only, and MPS-only test functions, and skip each one when the corresponding device is unavailable.
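One way to sketch that per-device layout, using PyTorch and `unittest` skip markers (a hedged example: each test skips itself when the framework or its device is absent, so the suite passes everywhere):

```python
import importlib.util
import unittest

def _module_available(name):
    return importlib.util.find_spec(name) is not None

TORCH = _module_available("torch")
CUDA = MPS = False
if TORCH:
    import torch
    CUDA = torch.cuda.is_available()
    MPS = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()

class TestPerDevice(unittest.TestCase):
    @unittest.skipUnless(TORCH, "torch not installed")
    def test_cpu_only(self):
        x = torch.ones(3)  # CPU is the default device
        self.assertEqual(x.sum().item(), 3.0)

    @unittest.skipUnless(CUDA, "CUDA device not available")
    def test_cuda_only(self):
        x = torch.ones(3, device="cuda")
        self.assertEqual(x.sum().item(), 3.0)

    @unittest.skipUnless(MPS, "MPS device not available")
    def test_mps_only(self):
        x = torch.ones(3, device="mps")
        self.assertEqual(x.sum().item(), 3.0)

suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestPerDevice)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

Skipped tests do not fail the run, so the same file is valid on CPU-only CI, a CUDA box, and Apple silicon.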