When and Why to Use a Custom Training Loop
Keras model.fit covers most training needs, but a lower-level training loop gives you direct control over what happens in each step. This is useful when you need custom losses that depend on intermediate tensors, advanced metric logic, multiple optimizers (for different parts of the model), or special update rules (for example, gradient clipping or conditional updates).
The core idea is to write a single training step that does four things: (1) run a forward pass, (2) compute the loss, (3) compute gradients, and (4) apply gradients with an optimizer. TensorFlow provides tf.GradientTape to record operations for automatic differentiation.
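Gradient clipping, one of the special update rules mentioned above, fits between steps (3) and (4). The sketch below is a minimal example, not a definitive implementation. It uses a hypothetical toy model, random data, and an assumed clip threshold of 1.0, and it rescales all gradients with tf.clip_by_global_norm before applying them.

```python
import tensorflow as tf

# Hypothetical toy model and random data, for illustration only.
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
loss_fn = tf.keras.losses.MeanSquaredError()

x = tf.random.normal([8, 4])
y = tf.random.normal([8, 1])

CLIP_NORM = 1.0  # assumed threshold; tune it for your problem

with tf.GradientTape() as tape:
    y_pred = model(x, training=True)                    # (1) forward pass
    loss = loss_fn(y, y_pred)                           # (2) loss
grads = tape.gradient(loss, model.trainable_variables)  # (3) gradients

# Extra step between (3) and (4): rescale all gradients together so their
# combined L2 norm is at most CLIP_NORM. global_norm is the norm before clipping.
clipped, global_norm = tf.clip_by_global_norm(grads, CLIP_NORM)

optimizer.apply_gradients(zip(clipped, model.trainable_variables))  # (4) update
```

Clipping by the global norm preserves the relative direction of all gradients. Clipping each tensor separately (tf.clip_by_norm) is an alternative, but it can change that direction.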
The Anatomy of a Training Step with tf.GradientTape
1) Forward pass
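In a training step, call the model with training=True so that layers such as Dropout and BatchNormalization use their training-time behavior (Dropout drops units; BatchNormalization normalizes with batch statistics and updates its moving averages).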
y_pred = model(x, training=True)
2) Loss computation
You compute a scalar loss. This can be a standard loss function, a custom function, or a combination (for example, data loss + regularization terms). If your model uses layer regularizers, Keras exposes them through model.losses (a list of tensors) and you typically add them to your main loss.
data_loss = loss_fn(y_true, y_pred)  # should be a scalar per batch (or reducible to one)
reg_loss = tf.add_n(model.losses) if model.losses else 0.0
total_loss = data_loss + reg_loss
3) Gradient computation
Wrap the forward pass and loss computation inside a tf.GradientTape context. Then ask the tape for gradients of the loss with respect to the model’s trainable variables.
with tf.GradientTape() as tape:
    y_pred = model(x, training=True)
    data_loss = loss_fn(y, y_pred)
    reg_loss = tf.add_n(model.losses) if model.losses else 0.0
    total_loss = data_loss + reg_loss
grads = tape.gradient(total_loss, model.trainable_variables)
4) Optimizer application
Apply gradients to variables using optimizer.apply_gradients. A common enhancement is to filter out None gradients (which can happen if a variable is not connected to the loss).
grads_and_vars = [(g, v) for (g, v) in zip(grads, model.trainable_variables) if g is not None]
optimizer.apply_gradients(grads_and_vars)
Wrapping the Step in tf.function for Performance
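In eager mode, every operation is dispatched from Python, which adds per-step overhead. Once your step works correctly in eager mode, decorate it with @tf.function so TensorFlow traces it into a graph that runs faster. Keep in mind that graph execution is stricter: Python side effects (such as print or appending to a list) run only while the function is being traced, and inputs with new shapes or dtypes can trigger retracing.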
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)
        data_loss = loss_fn(y, y_pred)
        reg_loss = tf.add_n(model.losses) if model.losses else 0.0
        total_loss = data_loss + reg_loss
    grads = tape.gradient(total_loss, model.trainable_variables)
    grads_and_vars = [(g, v) for (g, v) in zip(grads, model.trainable_variables) if g is not None]
    optimizer.apply_gradients(grads_and_vars)
    return total_loss, y_pred
During debugging, you can temporarily remove @tf.function to get clearer Python stack traces, then add it back once the step is stable.
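Instead of deleting the decorator, you can call tf.config.run_functions_eagerly(True). This makes every tf.function run as ordinary Python until you switch it off. The minimal sketch below uses a hypothetical step function. It also shows the graph-mode strictness described earlier: a Python side effect such as appending to a list runs on every eager call, but in graph mode it runs only while the function is traced.

```python
import tensorflow as tf

modes = []  # records tf.executing_eagerly() each time the Python body runs

@tf.function
def step(x):
    # Python side effect: runs on every call in eager mode,
    # but only during tracing in graph mode.
    modes.append(tf.executing_eagerly())
    return x * 2

tf.config.run_functions_eagerly(True)   # debug mode: tf.function bodies run as plain Python
result_eager = step(tf.constant(1.0))

tf.config.run_functions_eagerly(False)  # restore normal graph execution
result_graph = step(tf.constant(1.0))
```

Remember to switch the flag back off (or remove the call) before measuring performance, because it disables graph compilation globally.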
Reusable Templates: train_step and test_step with Metrics
A clean pattern is to keep metrics as objects and update them each step. Use separate metrics for training and validation to avoid mixing states.
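The following sketch shows why separate, resettable metric objects matter. It feeds made-up batch losses into two tf.keras.metrics.Mean objects. Each object keeps its own running state, and reset_state() clears that state.

```python
import tensorflow as tf

train_loss = tf.keras.metrics.Mean(name="train_loss")
val_loss = tf.keras.metrics.Mean(name="val_loss")

# Each update_state call adds to the metric's running total and count.
for batch_loss in [1.0, 2.0, 3.0]:  # made-up per-batch losses
    train_loss.update_state(batch_loss)
val_loss.update_state(10.0)

train_mean = float(train_loss.result())  # mean of 1, 2, 3 = 2.0
val_mean = float(val_loss.result())      # 10.0, unaffected by training updates

train_loss.reset_state()                  # clear the state before the next epoch
after_reset = float(train_loss.result())  # no updates since reset, so 0.0
```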
Define loss and metrics
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
train_loss = tf.keras.metrics.Mean(name="train_loss")
train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name="train_acc")
val_loss = tf.keras.metrics.Mean(name="val_loss")
val_acc = tf.keras.metrics.SparseCategoricalAccuracy(name="val_acc")
Training step template
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)
        data_loss = loss_fn(y, y_pred)
        reg_loss = tf.add_n(model.losses) if model.losses else 0.0
        total_loss = data_loss + reg_loss
    grads = tape.gradient(total_loss, model.trainable_variables)
    grads_and_vars = [(g, v) for (g, v) in zip(grads, model.trainable_variables) if g is not None]
    optimizer.apply_gradients(grads_and_vars)
    train_loss.update_state(total_loss)
    train_acc.update_state(y, y_pred)
Validation/test step template
Validation does not compute gradients and uses training=False.
@tf.function
def test_step(x, y):
    y_pred = model(x, training=False)
    data_loss = loss_fn(y, y_pred)
    reg_loss = tf.add_n(model.losses) if model.losses else 0.0
    total_loss = data_loss + reg_loss
    val_loss.update_state(total_loss)
    val_acc.update_state(y, y_pred)
Putting It Together: A Minimal Epoch Loop with Logging
Assuming you already have train_ds and val_ds as tf.data.Dataset objects, the outer loop is plain Python. The key is to reset metric states at the start of each epoch, then print or record their results.
for epoch in range(num_epochs):
    # Reset metric state so each epoch reports only its own batches.
    train_loss.reset_state()
    train_acc.reset_state()
    val_loss.reset_state()
    val_acc.reset_state()
    for x_batch, y_batch in train_ds:
        train_step(x_batch, y_batch)
    for x_batch, y_batch in val_ds:
        test_step(x_batch, y_batch)
    logs = {
        "train_loss": float(train_loss.result()),
        "train_acc": float(train_acc.result()),
        "val_loss": float(val_loss.result()),
        "val_acc": float(val_acc.result()),
    }
    print(f"Epoch {epoch+1}: {logs}")
If you want TensorBoard logging, you can write scalars with tf.summary inside the epoch loop (doing this outside @tf.function is usually simpler) or inside the step using a step counter. Keep logging inside @tf.function minimal to avoid retracing and performance issues.
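Here is a minimal sketch of the epoch-loop option. It assumes a temporary log directory and uses placeholder loss values in place of the real logs dict. tf.summary.create_file_writer and tf.summary.scalar are the standard TensorFlow summary calls, and the step argument sets the x-axis position in TensorBoard.

```python
import os
import tempfile
import tensorflow as tf

logdir = tempfile.mkdtemp()  # hypothetical location; point this at your own log directory
writer = tf.summary.create_file_writer(logdir)

# Stand-in for the per-epoch logs dict built in the loop (placeholder values, not results).
epoch_logs = [{"train_loss": 0.9, "val_loss": 1.0}, {"train_loss": 0.7, "val_loss": 0.8}]

for epoch, logs in enumerate(epoch_logs):
    with writer.as_default():
        for name, value in logs.items():
            # One scalar per metric per epoch; step is the epoch index.
            tf.summary.scalar(name, value, step=epoch)
writer.flush()

event_files = [f for f in os.listdir(logdir) if f.startswith("events.out.tfevents")]
```

Run tensorboard --logdir with the same directory to view the curves.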
Multiple Optimizers: Updating Different Parts of the Model
A common reason to use a custom loop is applying different optimizers to different variable groups (for example, a backbone and a new head). You compute one loss (or multiple losses) and then apply gradients separately.
# Collect the variable lists after the model has been built (for example, after
# one forward pass); before that, trainable_variables may be empty.
backbone_vars = backbone.trainable_variables
head_vars = head.trainable_variables

@tf.function
def train_step_multiopt(x, y):
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)
        total_loss = loss_fn(y, y_pred)
    # One gradient call over both groups, then split the result by position.
    grads = tape.gradient(total_loss, backbone_vars + head_vars)
    grads_backbone = grads[:len(backbone_vars)]
    grads_head = grads[len(backbone_vars):]
    opt_backbone.apply_gradients([(g, v) for g, v in zip(grads_backbone, backbone_vars) if g is not None])
    opt_head.apply_gradients([(g, v) for g, v in zip(grads_head, head_vars) if g is not None])
    train_loss.update_state(total_loss)
    train_acc.update_state(y, y_pred)
This pattern also works for GAN-like setups where you alternate updates, but be careful to compute the right loss for the variables you update.
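In GAN-like setups, each variable group often has its own loss. One option is a persistent tape, which lets you call gradient() more than once. The sketch below makes several simplifying assumptions:

- gen_w and disc_w are hypothetical stand-ins for two models.
- The losses are toy expressions.
- The updates are hand-written SGD steps, to keep the example short.

In a real loop you would call each optimizer's apply_gradients, as in the multi-optimizer step.

```python
import tensorflow as tf

# Hypothetical variable groups standing in for, e.g., a generator and a discriminator.
gen_w = tf.Variable(1.0, name="gen_w")
disc_w = tf.Variable(2.0, name="disc_w")
LEARNING_RATE = 0.1  # assumed step size

with tf.GradientTape(persistent=True) as tape:
    # Each toy loss is the one meant to update its own variable group.
    gen_loss = (gen_w * disc_w - 3.0) ** 2
    disc_loss = (disc_w - 1.0) ** 2

# A persistent tape allows several gradient() calls on the same recording.
gen_grads = tape.gradient(gen_loss, [gen_w])     # gen_loss drives only gen_w
disc_grads = tape.gradient(disc_loss, [disc_w])  # disc_loss drives only disc_w
del tape  # release the resources held by the persistent tape

# Hand-written SGD updates, applied after both gradients were computed.
gen_w.assign_sub(LEARNING_RATE * gen_grads[0])
disc_w.assign_sub(LEARNING_RATE * disc_grads[0])
```

Here gen_loss also depends on disc_w, but its gradient with respect to disc_w is deliberately never requested, so the discriminator is updated only by its own loss.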
How This Compares to model.fit
What you gain: full control over the forward pass, loss composition, gradient handling (clipping, skipping, scaling), metric updates, and how and when optimizers run.
What you must manage: metric state resets, consistent loss reduction, correct training flags, and correct variable selection for gradients.
How to validate parity: run a few epochs with your custom loop and with model.fit using the same model, optimizer settings, and loss, then compare training and validation metrics. Small differences can occur due to randomness, but the trends should match.
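Parity can also be checked in isolation, as sketched here under simplifying assumptions: a hypothetical one-layer model, a single full-batch step, plain SGD, and MSE, so no randomness is involved. Start two copies from identical weights. Take one step with model.fit and one with a hand-written step, then compare the resulting weights.

```python
import numpy as np
import tensorflow as tf

# Fixed toy regression data (hypothetical).
rng = np.random.default_rng(0)
x = rng.normal(size=(16, 3)).astype("float32")
y = rng.normal(size=(16, 1)).astype("float32")

def make_model():
    return tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])

model_fit = make_model()
model_custom = make_model()
model_custom.set_weights(model_fit.get_weights())  # identical starting point

# Path 1: one full-batch step through model.fit.
model_fit.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.1), loss="mse")
model_fit.fit(x, y, batch_size=16, epochs=1, shuffle=False, verbose=0)

# Path 2: the same step written by hand.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
loss_fn = tf.keras.losses.MeanSquaredError()
with tf.GradientTape() as tape:
    loss = loss_fn(y, model_custom(x, training=True))
grads = tape.gradient(loss, model_custom.trainable_variables)
optimizer.apply_gradients(zip(grads, model_custom.trainable_variables))

# Largest absolute difference across all weight tensors.
max_diff = max(
    float(np.max(np.abs(a - b)))
    for a, b in zip(model_fit.get_weights(), model_custom.get_weights())
)
```

A tiny nonzero max_diff can come from floating-point differences between the two execution paths. A large one usually points to a mismatch in loss reduction, regularization, or optimizer settings.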
Common Pitfalls and How to Fix Them
Pitfall 1: Forgetting to reset metrics
If you don’t reset metrics each epoch, they keep accumulating across epochs and your logs will look wrong (often monotonically improving or drifting strangely).
Fix: call reset_state() on every metric at the start of each epoch.
train_loss.reset_state()
train_acc.reset_state()
Pitfall 2: Gradients are None
None gradients mean TensorFlow could not connect the loss to some variables. Common causes include: using the wrong variables list, computing loss from tensors not derived from the model output, accidentally stopping gradients (for example, tf.stop_gradient), or updating a submodel that wasn’t used in the forward pass.
Fix checklist:
Ensure the loss depends on y_pred computed from the model call inside the tape.
Use the correct variable list (for example, model.trainable_variables or the specific submodule variables).
Filter out None gradients to avoid optimizer errors, but also investigate why they are None.
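# Whether a gradient is None is already known when the step is traced,
# so a plain Python check works both eagerly and under @tf.function.
for g, v in zip(grads, model.trainable_variables):
    if g is None:
        raise ValueError(f"No gradient for variable {v.name}")
For checks on tensor values (for example, NaN or Inf in gradients), prefer tf.debugging functions such as tf.debugging.check_numerics over Python print when running under @tf.function, because Python print runs only during tracing.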
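The sketch below reproduces a None gradient on purpose, using a variable that the loss never touches (the variable names are hypothetical). tf.GradientTape.gradient also accepts unconnected_gradients=tf.UnconnectedGradients.ZERO, which returns zeros instead of None. That only hides the symptom, so use it only when you know why a variable is disconnected.

```python
import tensorflow as tf

used = tf.Variable(2.0, name="used")
unused = tf.Variable(5.0, name="unused")  # never touched by the loss below

with tf.GradientTape(persistent=True) as tape:
    loss = used ** 2

# Default behavior: a variable the loss does not depend on gets None.
grads_default = tape.gradient(loss, [used, unused])

# Zeros instead of None: silences the symptom without fixing the cause.
grads_zero = tape.gradient(
    loss, [used, unused], unconnected_gradients=tf.UnconnectedGradients.ZERO
)
del tape  # release the persistent tape

# Names of variables that received no gradient.
missing = [v.name for g, v in zip(grads_default, [used, unused]) if g is None]
```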
Pitfall 3: Shape errors during loss calculation
Loss functions expect specific shapes. A frequent issue is mismatching label shape and prediction shape, or using the wrong loss for the label encoding.
Examples of common mismatches:
Using SparseCategoricalCrossentropy but providing one-hot labels (it expects integer class IDs).
Using CategoricalCrossentropy but providing integer labels (it expects one-hot vectors or probability distributions).
Binary classification logits shaped (batch, 1) but labels shaped (batch,) (or vice versa), causing broadcasting surprises.
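The broadcasting surprise in the last case is easy to reproduce with a hand-written loss. In this sketch (toy values), subtracting (batch, 1) predictions from (batch,) labels silently produces a (batch, batch) matrix instead of an element-wise difference.

```python
import tensorflow as tf

y_true = tf.constant([1.0, 0.0, 1.0, 0.0])          # shape (4,)
y_pred = tf.constant([[0.9], [0.2], [0.8], [0.1]])  # shape (4, 1)

# (4,) - (4, 1) broadcasts to (4, 4): every label is paired with every
# prediction, so the "loss" no longer measures per-example error.
diff = y_true - y_pred
wrong_mse = tf.reduce_mean(tf.square(diff))

# Reshaping the labels to (4, 1) restores the intended element-wise loss.
y_true_aligned = tf.reshape(y_true, (-1, 1))
right_mse = tf.reduce_mean(tf.square(y_true_aligned - y_pred))
```

Nothing raises an error here. The only symptom is a loss value (0.375 instead of 0.025) that trains the model toward the wrong target.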
Fix: inspect shapes and choose the correct loss. Ensure batch dimensions align and that you reduce to a scalar loss.
# Sparse categorical case: integer labels of shape (batch,), logits of shape (batch, num_classes)
tf.debugging.assert_rank(y, 1)
tf.debugging.assert_rank(y_pred, 2)
# Example: make binary labels match (batch, 1) if needed
y = tf.cast(y, tf.float32)
y = tf.reshape(y, (-1, 1))
Pitfall 4: Loss reduction and scaling issues
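Functional losses such as tf.keras.losses.sparse_categorical_crossentropy return one loss value per example, as do loss objects created with reduction="none". If you then sum where you intended to average (or vice versa), gradient magnitudes change by a factor of the batch size. That effectively changes the learning rate and can destabilize training.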
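This sketch makes the scale difference concrete. It uses toy labels and logits and computes the same sparse categorical cross-entropy three ways: with reduction="none", with reduction="sum", and with the default, which averages over the batch. The "sum" value, and therefore its gradients, is batch_size times larger than the default.

```python
import tensorflow as tf

y = tf.constant([0, 1, 2, 1])                 # toy integer labels
logits = tf.constant([[2.0, 0.5, 0.1],
                      [0.3, 1.5, 0.2],
                      [0.1, 0.2, 3.0],
                      [1.0, 1.0, 1.0]])       # toy logits, shape (4, 3)
batch_size = 4

per_example = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction="none")(y, logits)  # shape (4,)
summed = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction="sum")(y, logits)   # scalar: sum over the batch
averaged = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True)(y, logits)                    # default: mean over the batch
```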
Fix: use a loss object with a known reduction (Keras loss classes default to SUM_OVER_BATCH_SIZE, which averages over the batch) or reduce explicitly.
per_example = tf.keras.losses.sparse_categorical_crossentropy(y, y_pred, from_logits=True)
data_loss = tf.reduce_mean(per_example)
Pitfall 5: Retracing with tf.function
If input shapes vary a lot (for example, variable-length sequences without padding), @tf.function may retrace frequently, hurting performance.
Fix: keep shapes consistent (padding/bucketing) or provide an input signature.
# None in a TensorSpec lets that dimension (here, the batch size) vary without retracing.
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, 32], dtype=tf.float32),
    tf.TensorSpec(shape=[None], dtype=tf.int32),
])
def train_step(x, y):
    ...
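A small experiment makes retracing visible. The sketch assumes toy functions that simply sum their input. A Python counter incremented inside each function body only advances when TensorFlow traces it. Without a signature, new batch sizes trigger at least one retrace. With input_signature and a None batch dimension, a single trace serves every batch size.

```python
import tensorflow as tf

traces = {"free": 0, "signed": 0}

@tf.function
def step_free(x):
    traces["free"] += 1  # Python side effect: runs only while tracing
    return tf.reduce_sum(x)

@tf.function(input_signature=[tf.TensorSpec(shape=[None, 32], dtype=tf.float32)])
def step_signed(x):
    traces["signed"] += 1  # Python side effect: runs only while tracing
    return tf.reduce_sum(x)

for batch_size in [2, 3, 5]:
    x = tf.zeros([batch_size, 32])
    step_free(x)
    step_signed(x)
```

Padding or bucketing inputs to a few fixed shapes is the other fix. It keeps the number of traces small without a signature and can also let TensorFlow specialize the graph for each shape.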