<!-- Source: mirror of https://github.com/codeflash-ai/codeflash-agent.git (synced 2026-05-04) -->
# Model Management
|
|
|
|
Complete model serialization, checkpointing, and deployment utilities for production and inference. These operations provide comprehensive model lifecycle management capabilities.
|
|
|
|
## Capabilities
|
|
|
|
### Model Saving and Loading
|
|
|
|
Save and load complete models with all weights, architecture, and training configuration.
|
|
|
|
```python { .api }
|
|
def save(obj, export_dir, signatures=None, options=None):
|
|
"""
|
|
Exports a tf.Module (and subclasses) obj to SavedModel format.
|
|
|
|
Parameters:
|
|
- obj: A trackable object (e.g. tf.Module or tf.keras.Model) to export
|
|
- export_dir: A directory in which to write the SavedModel
|
|
- signatures: Optional, either a tf.function with an input signature specified or a dictionary
|
|
- options: Optional, tf.saved_model.SaveOptions object that specifies options for saving
|
|
"""
|
|
|
|
def load(export_dir, tags=None, options=None):
|
|
"""
|
|
Load a SavedModel from export_dir.
|
|
|
|
Parameters:
|
|
- export_dir: The SavedModel directory to load from
|
|
- tags: A tag or sequence of tags identifying the MetaGraph to load
|
|
- options: Optional, tf.saved_model.LoadOptions object that specifies options for loading
|
|
|
|
Returns:
|
|
A trackable object with a signatures attribute mapping signature keys to concrete functions
|
|
"""
|
|
|
|
def contains_saved_model(export_dir):
|
|
"""
|
|
Checks whether the provided export directory could contain a SavedModel.
|
|
|
|
Parameters:
|
|
- export_dir: Absolute or relative path to a directory containing the SavedModel
|
|
|
|
Returns:
|
|
True if the export directory contains SavedModel files, False otherwise
|
|
"""
|
|
```
|
|
|
|
### Checkpointing
|
|
|
|
Save and restore model weights and training state for resuming training.
|
|
|
|
```python { .api }
|
|
class Checkpoint:
|
|
"""
|
|
Groups trackable objects, saving and restoring them.
|
|
|
|
Methods:
|
|
- save(file_prefix): Saves a training checkpoint and provides basic checkpoint management
|
|
- restore(save_path): Restore a training checkpoint
|
|
- read(save_path): Returns CheckpointReader for checkpoint inspection
|
|
"""
|
|
|
|
def __init__(self, **kwargs):
|
|
"""
|
|
Groups trackable objects, saving and restoring them.
|
|
|
|
Parameters:
|
|
- **kwargs: Keyword arguments are set as attributes of this object, and are saved with the checkpoint
|
|
"""
|
|
|
|
def save(self, file_prefix, session=None):
|
|
"""
|
|
Saves a training checkpoint and provides basic checkpoint management.
|
|
|
|
Parameters:
|
|
- file_prefix: A prefix to use for the checkpoint filenames
|
|
- session: The session to evaluate variables in. Ignored when executing eagerly
|
|
|
|
Returns:
|
|
The full path to the checkpoint
|
|
"""
|
|
|
|
def restore(self, save_path):
|
|
"""
|
|
Restore a training checkpoint.
|
|
|
|
Parameters:
|
|
- save_path: The path to the checkpoint, as returned by save or tf.train.latest_checkpoint
|
|
|
|
Returns:
|
|
A load status object, which can be used to make assertions about the status of a checkpoint restoration
|
|
"""
|
|
|
|
def read(self, save_path):
|
|
"""
|
|
Returns a CheckpointReader for the checkpoint.
|
|
|
|
Parameters:
|
|
- save_path: The path to the checkpoint, as returned by save or tf.train.latest_checkpoint
|
|
|
|
Returns:
|
|
A CheckpointReader object
|
|
"""
|
|
|
|
class CheckpointManager:
|
|
"""
|
|
Manages multiple checkpoints by keeping some and deleting unneeded ones.
|
|
|
|
Methods:
|
|
- save(checkpoint_number): Creates a new checkpoint
|
|
"""
|
|
|
|
def __init__(self, checkpoint, directory, max_to_keep=5, keep_checkpoint_every_n_hours=None,
|
|
checkpoint_name="ckpt", step_counter=None, checkpoint_interval=None,
|
|
init_fn=None):
|
|
"""
|
|
Manages multiple checkpoints by keeping some and deleting unneeded ones.
|
|
|
|
Parameters:
|
|
- checkpoint: The tf.train.Checkpoint instance to save and manage checkpoints for
|
|
- directory: The path to a directory in which to write checkpoints
|
|
- max_to_keep: An integer, the number of checkpoints to keep
|
|
- keep_checkpoint_every_n_hours: Upon removal, keep checkpoints every N hours
|
|
- checkpoint_name: Custom name for the checkpoint file
|
|
- step_counter: A tf.Variable instance for checking the current step counter value
|
|
- checkpoint_interval: An integer, indicates the minimum step interval between consecutive checkpoints
|
|
- init_fn: Callable. Function executed the first time a checkpoint is saved
|
|
"""
|
|
|
|
def save(self, checkpoint_number=None, check_interval=True):
|
|
"""
|
|
Creates a new checkpoint and manages deletion of old checkpoints.
|
|
|
|
Parameters:
|
|
- checkpoint_number: An optional integer, or an integer-dtype Variable or Tensor, used to number the checkpoint
|
|
- check_interval: An optional boolean. The default behaviour is that checkpoint_interval is ignored when checkpoint_number is provided
|
|
|
|
Returns:
|
|
The path to the new checkpoint. It is also recorded in the checkpoints and latest_checkpoint properties
|
|
"""
|
|
```
|
|
|
|
### Checkpoint Utilities
|
|
|
|
Utility functions for working with checkpoints.
|
|
|
|
```python { .api }
|
|
def list_variables(checkpoint_dir):
|
|
"""
|
|
Returns list of all variables in the checkpoint.
|
|
|
|
Parameters:
|
|
- checkpoint_dir: Directory with checkpoint file or path to checkpoint
|
|
|
|
Returns:
|
|
List of tuples (name, shape) for all variables in the checkpoint
|
|
"""
|
|
|
|
def load_checkpoint(checkpoint_dir):
|
|
"""
|
|
Returns CheckpointReader for checkpoint found in checkpoint_dir.
|
|
|
|
Parameters:
|
|
- checkpoint_dir: Directory with checkpoint file or path to checkpoint
|
|
|
|
Returns:
|
|
CheckpointReader instance
|
|
"""
|
|
|
|
def load_variable(checkpoint_dir, name):
|
|
"""
|
|
Returns the tensor value of the given variable in the checkpoint.
|
|
|
|
Parameters:
|
|
- checkpoint_dir: Directory with checkpoint file or path to checkpoint
|
|
- name: Name of the variable to return
|
|
|
|
Returns:
|
|
A numpy ndarray with a copy of the value of this variable
|
|
"""
|
|
|
|
def latest_checkpoint(checkpoint_dir, latest_filename=None):
|
|
"""
|
|
Finds the filename of latest saved checkpoint file.
|
|
|
|
Parameters:
|
|
- checkpoint_dir: Directory where the variables were saved
|
|
- latest_filename: Optional name for the protocol buffer file that contains the list of most recent checkpoint filenames
|
|
|
|
Returns:
|
|
The full path to the latest checkpoint or None if no checkpoint was found
|
|
"""
|
|
```
|
|
|
|
### SavedModel Utilities
|
|
|
|
Additional utilities for working with SavedModel format.
|
|
|
|
```python { .api }
|
|
class SaveOptions:
|
|
"""
|
|
Options for saving to SavedModel.
|
|
|
|
Parameters:
|
|
- namespace_whitelist: List of strings containing op namespaces to whitelist when saving a model
|
|
- save_debug_info: Boolean indicating whether debug information is saved
|
|
- function_aliases: Optional dictionary of string -> string of function aliases
|
|
- experimental_io_device: string. Applies in a distributed setting
|
|
- experimental_variable_policy: The policy to apply to variables when saving
|
|
"""
|
|
|
|
class LoadOptions:
|
|
"""
|
|
Options for loading a SavedModel.
|
|
|
|
Parameters:
|
|
- allow_partial_checkpoint: Boolean. Defaults to False. When enabled, allows the SavedModel checkpoint to be missing variables
|
|
- experimental_io_device: string. Loads SavedModel and variables on the specified device
|
|
- experimental_skip_checkpoint: boolean. If True, the checkpoint will not be loaded, and the SavedModel will be loaded with randomly initialized variable values
|
|
"""
|
|
|
|
class Asset:
|
|
"""
|
|
Represents a file asset to copy into the SavedModel.
|
|
|
|
Parameters:
|
|
- path: A path, or a 0-D tf.string Tensor with path to the asset
|
|
"""
|
|
```
|
|
|
|
## Usage Examples
|
|
|
|
```python
|
|
import tensorflow as tf
|
|
import os
|
|
|
|
# Create a simple model
|
|
model = tf.keras.Sequential([
|
|
tf.keras.layers.Dense(64, activation='relu', input_shape=(10,)),
|
|
tf.keras.layers.Dense(32, activation='relu'),
|
|
tf.keras.layers.Dense(1, activation='sigmoid')
|
|
])
|
|
|
|
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
|
|
|
|
# Save entire model to SavedModel format
|
|
tf.saved_model.save(model, 'my_saved_model')
|
|
|
|
# Load the saved model
|
|
loaded_model = tf.saved_model.load('my_saved_model')
|
|
|
|
# For Keras models, use keras save/load for full functionality
|
|
model.save('my_keras_model.h5')
|
|
loaded_keras_model = tf.keras.models.load_model('my_keras_model.h5')
|
|
|
|
# Checkpoint example
|
|
checkpoint_dir = './training_checkpoints'
|
|
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
|
|
|
|
# Create checkpoint object
|
|
checkpoint = tf.train.Checkpoint(optimizer=tf.keras.optimizers.Adam(),
|
|
model=model)
|
|
|
|
# Save checkpoint
|
|
checkpoint.save(file_prefix=checkpoint_prefix)
|
|
|
|
# Restore from checkpoint
|
|
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
|
|
|
|
# Using CheckpointManager for automatic cleanup
|
|
manager = tf.train.CheckpointManager(
|
|
checkpoint, directory=checkpoint_dir, max_to_keep=3
|
|
)
|
|
|
|
# Save with automatic cleanup
|
|
save_path = manager.save()
|
|
print(f"Saved checkpoint: {save_path}")
|
|
|
|
# Training loop with checkpointing
|
|
optimizer = tf.keras.optimizers.Adam()
|
|
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
|
|
manager = tf.train.CheckpointManager(checkpoint, './checkpoints', max_to_keep=3)
|
|
|
|
# Restore if checkpoint exists
|
|
checkpoint.restore(manager.latest_checkpoint)
|
|
if manager.latest_checkpoint:
|
|
print(f"Restored from {manager.latest_checkpoint}")
|
|
else:
|
|
print("Initializing from scratch.")
|
|
|
|
# Training step function
|
|
@tf.function
|
|
def train_step(x, y):
|
|
with tf.GradientTape() as tape:
|
|
predictions = model(x, training=True)
|
|
loss = tf.keras.losses.binary_crossentropy(y, predictions)
|
|
|
|
gradients = tape.gradient(loss, model.trainable_variables)
|
|
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
|
|
|
|
return loss
|
|
|
|
# Training loop
|
|
for epoch in range(10):
|
|
# Training code here...
|
|
# x_batch, y_batch = get_batch()
|
|
# loss = train_step(x_batch, y_batch)
|
|
|
|
# Save checkpoint every few epochs
|
|
if epoch % 2 == 0:
|
|
save_path = manager.save()
|
|
print(f"Saved checkpoint for epoch {epoch}: {save_path}")
|
|
|
|
# Inspect checkpoint contents
|
|
checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
|
|
if checkpoint_path:
|
|
variables = tf.train.list_variables(checkpoint_path)
|
|
for name, shape in variables:
|
|
print(f"Variable: {name}, Shape: {shape}")
|
|
|
|
# Load specific variable
|
|
specific_var = tf.train.load_variable(checkpoint_path, 'model/dense/kernel/.ATTRIBUTES/VARIABLE_VALUE')
|
|
print(f"Loaded variable shape: {specific_var.shape}")
|
|
|
|
# Check if directory contains SavedModel
|
|
if tf.saved_model.contains_saved_model('my_saved_model'):
|
|
print("Directory contains a valid SavedModel")
|
|
|
|
# Advanced SavedModel with custom signatures
|
|
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 10], dtype=tf.float32)])
|
|
def inference_func(x):
|
|
return model(x)
|
|
|
|
# Save with custom signature
|
|
tf.saved_model.save(
|
|
model,
|
|
'model_with_signature',
|
|
signatures={'serving_default': inference_func}
|
|
)
|
|
``` |