TensorFlow for Beginners: Building and Serving Your First Models

Custom Training with GradientTape for More Control

Chapter 6

When and Why to Use a Custom Training Loop

Keras model.fit covers most training needs, but a lower-level training loop gives you direct control over what happens in each step. This is useful when you need custom losses that depend on intermediate tensors, advanced metric logic, multiple optimizers (for different parts of the model), or special update rules (for example, gradient clipping or conditional updates).

The core idea is to write a single training step that does four things: (1) run a forward pass, (2) compute the loss, (3) compute gradients, and (4) apply gradients with an optimizer. TensorFlow provides tf.GradientTape to record operations for automatic differentiation.

The Anatomy of a Training Step with tf.GradientTape

1) Forward pass

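In a training step, call the model with training=True so that layers that behave differently during training and inference run in training mode: Dropout randomly zeroes activations, and BatchNormalization normalizes with the current batch's statistics and updates its moving averages.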

y_pred = model(x, training=True)
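One way to see what the training flag changes is to call a Dropout layer directly in both modes. This is a minimal sketch; the rate of 0.5 and the all-ones input are arbitrary choices for illustration.

```python
import tensorflow as tf

dropout = tf.keras.layers.Dropout(rate=0.5)
x = tf.ones((1, 1000))

# training=True: roughly half the units are zeroed, survivors are scaled by 1 / (1 - 0.5) = 2.
train_out = dropout(x, training=True)
# training=False: Dropout is the identity, so the input passes through unchanged.
infer_out = dropout(x, training=False)

train_values = set(train_out.numpy().ravel().tolist())
print("values seen in training mode:", sorted(train_values))
print("inference output equals input:", bool(tf.reduce_all(infer_out == x)))
```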

2) Loss computation

You compute a scalar loss. This can be a standard loss function, a custom function, or a combination (for example, data loss + regularization terms). If your model uses layer regularizers, Keras exposes them through model.losses (a list of tensors) and you typically add them to your main loss.

data_loss = loss_fn(y_true, y_pred)  # should be scalar per batch (or reducible to scalar)  
reg_loss = tf.add_n(model.losses) if model.losses else 0.0  
total_loss = data_loss + reg_loss
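To check how the pieces add up, the sketch below builds a hypothetical one-layer model whose kernel is initialized to ones, so both the data loss and the L2 penalty can be computed by hand. The layer size, the MeanSquaredError loss, and the 0.01 regularization factor are assumptions chosen only to keep the arithmetic simple.

```python
import tensorflow as tf

# Hypothetical model: one Dense unit, kernel of ones, zero bias, L2 factor 0.01.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(1, kernel_initializer="ones",
                          kernel_regularizer=tf.keras.regularizers.l2(0.01)),
])
loss_fn = tf.keras.losses.MeanSquaredError()

x = tf.constant([[1.0, 2.0, 3.0]])
y_true = tf.constant([[10.0]])

y_pred = model(x, training=True)                           # 1 + 2 + 3 = 6
data_loss = loss_fn(y_true, y_pred)                        # (10 - 6)^2 = 16
reg_loss = tf.add_n(model.losses) if model.losses else 0.0  # 0.01 * (1 + 1 + 1) = 0.03
total_loss = data_loss + reg_loss

print(float(data_loss), float(reg_loss), float(total_loss))
```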

3) Gradient computation

Wrap the forward pass and loss computation inside a tf.GradientTape context. Then ask the tape for gradients of the loss with respect to the model’s trainable variables.

with tf.GradientTape() as tape:  
    y_pred = model(x, training=True)  
    data_loss = loss_fn(y, y_pred)  
    reg_loss = tf.add_n(model.losses) if model.losses else 0.0  
    total_loss = data_loss + reg_loss  

grads = tape.gradient(total_loss, model.trainable_variables)
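The tape mechanics are easiest to see on a single value. In this sketch (values chosen arbitrarily), the tape records operations on a trainable tf.Variable automatically, while a constant tensor has to be watched explicitly with tape.watch before its gradient can be requested.

```python
import tensorflow as tf

w = tf.Variable(3.0)
with tf.GradientTape() as tape:
    loss = w * w                 # trainable variables are watched automatically
grad_w = tape.gradient(loss, w)  # d(w^2)/dw = 2w = 6.0

x = tf.constant(2.0)
with tf.GradientTape() as tape:
    tape.watch(x)                # constants are only tracked when watched explicitly
    y = x * x * x
grad_x = tape.gradient(y, x)     # d(x^3)/dx = 3x^2 = 12.0

print(float(grad_w), float(grad_x))
```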

4) Optimizer application

Apply gradients to variables using optimizer.apply_gradients. A common enhancement is to filter out None gradients (which can happen if a variable is not connected to the loss).

grads_and_vars = [(g, v) for (g, v) in zip(grads, model.trainable_variables) if g is not None]  
optimizer.apply_gradients(grads_and_vars)
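The sketch below shows a None gradient and the filtering step together. A hypothetical layer named used feeds the loss and another named unused does not, so the unused kernel's gradient comes back as None and is dropped before a plain SGD step. The Constant(3.0) initializer and the 0.1 learning rate are assumptions that make the update easy to verify by hand.

```python
import tensorflow as tf

# Hypothetical layers: `used` feeds the loss, `unused` never does.
used = tf.keras.layers.Dense(
    1, use_bias=False, kernel_initializer=tf.keras.initializers.Constant(3.0))
unused = tf.keras.layers.Dense(1, use_bias=False)

x = tf.constant([[1.0]])
used(x)    # the first call builds each layer's kernel
unused(x)
variables = used.trainable_variables + unused.trainable_variables

with tf.GradientTape() as tape:
    y_pred = used(x)                          # 1.0 * 3.0 = 3.0
    loss = tf.reduce_sum(tf.square(y_pred))   # 3.0^2 = 9.0

grads = tape.gradient(loss, variables)        # [6.0, None]: `unused` is not connected
print([g is None for g in grads])

optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
grads_and_vars = [(g, v) for g, v in zip(grads, variables) if g is not None]
optimizer.apply_gradients(grads_and_vars)

print(used.kernel.numpy())   # 3.0 - 0.1 * 6.0 = 2.4, up to float32 rounding
```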

Wrapping the Step in tf.function for Performance

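In eager mode, TensorFlow dispatches each operation from Python one at a time, which adds overhead on every step. Once your step works correctly in eager mode, wrap it with @tf.function so TensorFlow traces it into a graph and reuses that graph on later calls. Graph execution is stricter: Python side effects (such as print or appending to a list) run only while the function is being traced, and new input shapes or dtypes can trigger a retrace.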

@tf.function  
def train_step(x, y):  
    with tf.GradientTape() as tape:  
        y_pred = model(x, training=True)  
        data_loss = loss_fn(y, y_pred)  
        reg_loss = tf.add_n(model.losses) if model.losses else 0.0  
        total_loss = data_loss + reg_loss  
    grads = tape.gradient(total_loss, model.trainable_variables)  
    grads_and_vars = [(g, v) for (g, v) in zip(grads, model.trainable_variables) if g is not None]  
    optimizer.apply_gradients(grads_and_vars)  
    return total_loss, y_pred
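The behavior of Python side effects under @tf.function can be checked directly. In this minimal sketch, trace_log is a hypothetical helper list: the Python append runs only while the function is traced, whereas tf.print is a TensorFlow op and runs on every call.

```python
import tensorflow as tf

trace_log = []  # hypothetical helper list used to observe Python side effects

@tf.function
def double_sum(x):
    trace_log.append("traced")  # Python side effect: runs only while tracing
    tf.print("graph call")      # TensorFlow op: runs on every call
    return tf.reduce_sum(x) * 2.0

for _ in range(3):
    result = double_sum(tf.ones((2, 3)))  # same shape and dtype each time

print(len(trace_log), float(result))
```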

During debugging, you can temporarily remove @tf.function to get clearer Python stack traces, then add it back once stable.
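Besides removing the decorator, TensorFlow offers a global switch, tf.config.run_functions_eagerly(True), that runs every tf.function body as ordinary Python. A sketch, assuming a hypothetical calls list used only to count how often the Python body executes; the try/finally block restores graph execution even if the debugged code raises.

```python
import tensorflow as tf

calls = []  # hypothetical list that records each execution of the Python body

@tf.function
def debug_step(x):
    calls.append(1)
    return tf.reduce_mean(x)

tf.config.run_functions_eagerly(True)       # debug: run tf.function bodies eagerly
try:
    for _ in range(3):
        debug_step(tf.ones((2, 2)))
    eager_calls = len(calls)                # the Python body ran on every call
finally:
    tf.config.run_functions_eagerly(False)  # restore graph execution

calls.clear()
for _ in range(3):
    debug_step(tf.ones((2, 2)))
graph_calls = len(calls)                    # traced once, then the graph was reused

print(eager_calls, graph_calls)
```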

Reusable Templates: train_step and test_step with Metrics

A clean pattern is to keep metrics as objects and update them each step. Use separate metrics for training and validation to avoid mixing states.

Define loss and metrics

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)  

train_loss = tf.keras.metrics.Mean(name="train_loss")  
train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name="train_acc")  

val_loss = tf.keras.metrics.Mean(name="val_loss")  
val_acc = tf.keras.metrics.SparseCategoricalAccuracy(name="val_acc")

Training step template

@tf.function  
def train_step(x, y):  
    with tf.GradientTape() as tape:  
        y_pred = model(x, training=True)  
        data_loss = loss_fn(y, y_pred)  
        reg_loss = tf.add_n(model.losses) if model.losses else 0.0  
        total_loss = data_loss + reg_loss  

    grads = tape.gradient(total_loss, model.trainable_variables)  
    grads_and_vars = [(g, v) for (g, v) in zip(grads, model.trainable_variables) if g is not None]  
    optimizer.apply_gradients(grads_and_vars)  

    train_loss.update_state(total_loss)  
    train_acc.update_state(y, y_pred)

Validation/test step template

Validation does not compute gradients and uses training=False.

@tf.function  
def test_step(x, y):  
    y_pred = model(x, training=False)  
    data_loss = loss_fn(y, y_pred)  
    reg_loss = tf.add_n(model.losses) if model.losses else 0.0  
    total_loss = data_loss + reg_loss  

    val_loss.update_state(total_loss)  
    val_acc.update_state(y, y_pred)

Putting It Together: A Minimal Epoch Loop with Logging

Assuming you already have train_ds and val_ds as batched tf.data.Dataset objects that yield (features, labels) pairs, the outer loop is plain Python. The key is to reset metric states at the start of each epoch, then print or record their results.

for epoch in range(num_epochs):  
    train_loss.reset_state()  
    train_acc.reset_state()  
    val_loss.reset_state()  
    val_acc.reset_state()  

    for x_batch, y_batch in train_ds:  
        train_step(x_batch, y_batch)  

    for x_batch, y_batch in val_ds:  
        test_step(x_batch, y_batch)  

    logs = {  
        "train_loss": float(train_loss.result()),  
        "train_acc": float(train_acc.result()),  
        "val_loss": float(val_loss.result()),  
        "val_acc": float(val_acc.result()),  
    }  
    print(f"Epoch {epoch+1}: {logs}")

If you want TensorBoard logging, you can write scalars with tf.summary inside the epoch loop (outside @tf.function, which is usually simpler) or inside the step using a step counter such as optimizer.iterations. Keep logging inside @tf.function minimal: it runs on every step, and passing changing Python values (such as a plain integer step) as function arguments triggers retracing.
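A minimal per-epoch logging sketch using tf.summary.create_file_writer and tf.summary.scalar. It assumes a temporary directory as the log location, and the hard-coded numbers are placeholders standing in for real metric.result() values, not results. Pointing tensorboard --logdir at that directory would then display one curve per metric name.

```python
import os
import tempfile
import tensorflow as tf

log_dir = tempfile.mkdtemp()                    # assumption: any writable directory works
writer = tf.summary.create_file_writer(log_dir)

# Placeholder values standing in for metric.result() at the end of each epoch.
epoch_logs = [
    {"train_loss": 0.9, "val_loss": 1.0},
    {"train_loss": 0.7, "val_loss": 0.8},
]

for epoch, logs in enumerate(epoch_logs):
    with writer.as_default():
        for name, value in logs.items():
            tf.summary.scalar(name, value, step=epoch)  # one point per epoch per curve
writer.flush()

event_files = [f for f in os.listdir(log_dir) if f.startswith("events.out.tfevents")]
print(log_dir, event_files)
```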

Multiple Optimizers: Updating Different Parts of the Model

A common reason to use a custom loop is applying different optimizers to different variable groups (for example, a backbone and a new head). You compute one loss (or multiple losses) and then apply gradients separately.

backbone_vars = backbone.trainable_variables  
head_vars = head.trainable_variables  

@tf.function  
def train_step_multiopt(x, y):  
    with tf.GradientTape() as tape:  
        y_pred = model(x, training=True)  
        total_loss = loss_fn(y, y_pred)  

    grads = tape.gradient(total_loss, backbone_vars + head_vars)  
    grads_backbone = grads[:len(backbone_vars)]  
    grads_head = grads[len(backbone_vars):]  

    opt_backbone.apply_gradients([(g, v) for g, v in zip(grads_backbone, backbone_vars) if g is not None])  
    opt_head.apply_gradients([(g, v) for g, v in zip(grads_head, head_vars) if g is not None])  

    train_loss.update_state(total_loss)  
    train_acc.update_state(y, y_pred)

This pattern also works for GAN-like setups where you alternate updates, but be careful to compute the right loss for the variables you update.
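To make the separation visible, this sketch splits a hypothetical model into a backbone and a head. The backbone gets an SGD optimizer with a learning rate of 0.0 (effectively frozen, chosen only so the result is easy to check) and the head gets 0.1; after one step, only the head's weights should have moved. Layer sizes and the random batch are assumptions.

```python
import numpy as np
import tensorflow as tf

tf.keras.utils.set_random_seed(0)

# Hypothetical backbone/head split; called in sequence instead of nesting models.
backbone = tf.keras.Sequential([tf.keras.Input(shape=(4,)),
                                tf.keras.layers.Dense(8, activation="tanh")])
head = tf.keras.Sequential([tf.keras.Input(shape=(8,)), tf.keras.layers.Dense(3)])
backbone_vars = backbone.trainable_variables
head_vars = head.trainable_variables

opt_backbone = tf.keras.optimizers.SGD(learning_rate=0.0)  # illustration: no movement
opt_head = tf.keras.optimizers.SGD(learning_rate=0.1)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step_multiopt(x, y):
    with tf.GradientTape() as tape:
        y_pred = head(backbone(x, training=True), training=True)
        total_loss = loss_fn(y, y_pred)
    grads = tape.gradient(total_loss, backbone_vars + head_vars)
    grads_backbone = grads[:len(backbone_vars)]
    grads_head = grads[len(backbone_vars):]
    opt_backbone.apply_gradients(
        [(g, v) for g, v in zip(grads_backbone, backbone_vars) if g is not None])
    opt_head.apply_gradients(
        [(g, v) for g, v in zip(grads_head, head_vars) if g is not None])
    return total_loss

backbone_before = [v.numpy().copy() for v in backbone_vars]
head_before = [v.numpy().copy() for v in head_vars]

x = tf.random.normal((16, 4))
y = tf.random.uniform((16,), maxval=3, dtype=tf.int32)
train_step_multiopt(x, y)

backbone_changed = any(not np.allclose(b, v.numpy())
                       for b, v in zip(backbone_before, backbone_vars))
head_changed = any(not np.allclose(b, v.numpy())
                   for b, v in zip(head_before, head_vars))
print("backbone changed:", backbone_changed, "| head changed:", head_changed)
```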

How This Compares to model.fit

  • What you gain: full control over forward pass, loss composition, gradient handling (clipping, skipping, scaling), metric updates, and how/when optimizers run.

  • What you must manage: metric state resets, consistent loss reduction, correct training flags, and correct variable selection for gradients.

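  • How to validate parity: run a few epochs with your custom loop and with model.fit using the same architecture, initial weights, optimizer settings, loss, and batch order; compare training/validation metrics. Fixing seeds (for example with tf.keras.utils.set_random_seed) and disabling shuffling reduces noise. Small differences can still occur due to randomness such as Dropout masks, but trends should match.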
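One way to run such a parity check is sketched below under several assumptions: a small hypothetical model without Dropout, plain SGD with no momentum, a dataset size divisible by the batch size, and shuffling disabled, so both paths see identical batches in the same order from identical starting weights. Under those conditions the final weights should agree up to floating-point noise. jit_compile=False is passed so the fit path does not use XLA, which could introduce slightly larger rounding differences.

```python
import numpy as np
import tensorflow as tf

tf.keras.utils.set_random_seed(0)  # also seeds NumPy's global generator

# Hypothetical data: 64 samples so every batch of 16 is full.
x = np.random.normal(size=(64, 10)).astype("float32")
y = np.random.randint(0, 3, size=(64,)).astype("int32")

def make_model():
    return tf.keras.Sequential([
        tf.keras.Input(shape=(10,)),
        tf.keras.layers.Dense(16, activation="tanh"),
        tf.keras.layers.Dense(3),
    ])

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
fit_model = make_model()
loop_model = make_model()
loop_model.set_weights(fit_model.get_weights())  # identical starting point

# Path 1: model.fit without shuffling.
fit_model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
                  loss=loss_fn, jit_compile=False)
fit_model.fit(x, y, batch_size=16, epochs=1, shuffle=False, verbose=0)

# Path 2: custom loop over the same batches in the same order.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
for start in range(0, len(x), 16):
    xb, yb = x[start:start + 16], y[start:start + 16]
    with tf.GradientTape() as tape:
        loss = loss_fn(yb, loop_model(xb, training=True))
    grads = tape.gradient(loss, loop_model.trainable_variables)
    optimizer.apply_gradients(zip(grads, loop_model.trainable_variables))

max_diff = max(float(np.max(np.abs(a - b)))
               for a, b in zip(fit_model.get_weights(), loop_model.get_weights()))
print(f"max weight difference: {max_diff:.2e}")
```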

Common Pitfalls and How to Fix Them

Pitfall 1: Forgetting to reset metrics

If you don’t reset metrics each epoch, they keep accumulating across epochs: each reported value becomes an average over every batch seen since training started, so the logs lag behind the model's current behavior and look smoother than they should.

  • Fix: call reset_state() on every metric at the start of each epoch.

train_loss.reset_state()  
train_acc.reset_state()
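The accumulation effect is easy to reproduce with a Mean metric and a few hypothetical per-batch loss values for two epochs.

```python
import tensorflow as tf

loss_metric = tf.keras.metrics.Mean(name="train_loss")

# Hypothetical per-batch losses for two epochs.
epoch_1 = [1.0, 1.0]
epoch_2 = [3.0, 3.0]

for value in epoch_1 + epoch_2:
    loss_metric.update_state(value)
without_reset = float(loss_metric.result())  # averages all four batches: 2.0

loss_metric.reset_state()
for value in epoch_2:
    loss_metric.update_state(value)
with_reset = float(loss_metric.result())     # averages epoch 2 only: 3.0

print(without_reset, with_reset)
```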

Pitfall 2: Gradients are None

None gradients mean TensorFlow could not connect the loss to some variables. Common causes include: using the wrong variables list, computing loss from tensors not derived from the model output, accidentally stopping gradients (for example, tf.stop_gradient), or updating a submodel that wasn’t used in the forward pass.

  • Fix checklist:

    • Ensure the loss depends on y_pred computed from the model call inside the tape.

    • Use the correct variable list (for example, model.trainable_variables or the specific submodule variables).

    • Filter out None gradients to avoid optimizer errors, but also investigate why they are None.

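for g, v in zip(grads, model.trainable_variables):  
    if g is None:  
        raise ValueError(f"No gradient for variable: {v.name}")

Whether a gradient is None is already known when the function is traced, so this plain Python check works both eagerly and inside @tf.function. For checks on tensor values, however, prefer tf.debugging assertions (or tf.print) over Python print, which runs only during tracing.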

Pitfall 3: Shape errors during loss calculation

Loss functions expect specific shapes. A frequent issue is mismatching label shape and prediction shape, or using the wrong loss for the label encoding.

  • Examples of common mismatches:

    • Using SparseCategoricalCrossentropy but providing one-hot labels (should be integer class IDs).

    • Using CategoricalCrossentropy but providing integer labels (should be one-hot or probability distributions).

    • Binary classification logits shaped (batch, 1) but labels shaped (batch,) (or vice versa), causing broadcasting surprises.

Fix: inspect shapes and choose the correct loss. Ensure batch dimensions align and that you reduce to a scalar loss.

# Sparse categorical: integer labels (batch,), logits (batch, num_classes)  
tf.debugging.assert_rank(y, 1)  
tf.debugging.assert_rank(y_pred, 2)  

# Binary: reshape labels from (batch,) to (batch, 1) to match the logits  
y = tf.cast(y, tf.float32)  
y = tf.reshape(y, (-1, 1))
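The broadcasting surprise from the binary-classification example can be reproduced with plain tensor arithmetic. In this sketch the labels and predictions are made-up values, and the squared error is computed by hand rather than with a Keras loss class.

```python
import tensorflow as tf

y_true = tf.constant([0.0, 1.0, 1.0, 0.0])          # shape (4,)
y_pred = tf.constant([[0.1], [0.8], [0.6], [0.3]])  # shape (4, 1)

# Silent broadcasting: (4,) against (4, 1) yields a (4, 4) matrix of pairwise errors.
bad = tf.square(y_true - y_pred)
print(bad.shape)   # (4, 4)

# Fix: give the labels the same (batch, 1) shape before computing the loss.
y_true_fixed = tf.reshape(y_true, (-1, 1))
good = tf.square(y_true_fixed - y_pred)
print(good.shape)  # (4, 1)
mse = float(tf.reduce_mean(good))
print(mse)
```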

Pitfall 4: Loss reduction and scaling issues

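Loss functions such as tf.keras.losses.sparse_categorical_crossentropy return one value per example, and loss objects do the same when constructed with reduction set to none. If you then sum when you intended to average (or vice versa), gradient magnitudes change by a factor of the batch size, which can destabilize training or silently change the effective learning rate.

  • Fix: use a loss object with a known reduction (the default usually resolves to SUM_OVER_BATCH_SIZE, which averages over the batch) or reduce explicitly.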

per_example = tf.keras.losses.sparse_categorical_crossentropy(y, y_pred, from_logits=True)  
data_loss = tf.reduce_mean(per_example)
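The sketch below, using made-up labels and logits, checks that a loss object's default reduction matches an explicit mean of the per-example losses, and that summing instead makes the loss (and therefore its gradients) batch_size times larger.

```python
import tensorflow as tf

y = tf.constant([0, 2, 1, 2])
logits = tf.constant([[2.0, 0.5, -1.0],
                      [0.1, 0.2, 3.0],
                      [1.0, 1.0, 1.0],
                      [0.0, 2.0, 0.5]])

# Default reduction: a single scalar averaged over the batch.
mean_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(y, logits)

# Reduction disabled: one loss value per example, shape (4,).
per_example = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction="none")(y, logits)

explicit_mean = tf.reduce_mean(per_example)  # matches mean_loss
explicit_sum = tf.reduce_sum(per_example)    # 4x larger here, and so are its gradients

print(float(mean_loss), float(explicit_mean), float(explicit_sum))
```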

Pitfall 5: Retracing with tf.function

If input shapes vary a lot (for example, variable-length sequences without padding), @tf.function may retrace frequently, hurting performance.

  • Fix: keep shapes consistent (padding/bucketing) or provide an input signature.

@tf.function(input_signature=[  
    tf.TensorSpec(shape=[None, 32], dtype=tf.float32),  
    tf.TensorSpec(shape=[None], dtype=tf.int32),  
])  
def train_step(x, y):  
    ...
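To see the difference, this sketch counts traces for two otherwise identical functions, one with the input signature above and one without, across three different batch sizes. The list-based counters are hypothetical helpers that rely on Python side effects running only during tracing.

```python
import tensorflow as tf

plain_traces = []      # hypothetical counters: appended once per trace
signature_traces = []

@tf.function
def step_plain(x, y):
    plain_traces.append(x.shape)
    return tf.reduce_mean(x)

@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, 32], dtype=tf.float32),
    tf.TensorSpec(shape=[None], dtype=tf.int32),
])
def step_signature(x, y):
    signature_traces.append(x.shape)
    return tf.reduce_mean(x)

for batch_size in (8, 16, 5):
    x = tf.random.normal((batch_size, 32))
    y = tf.zeros((batch_size,), dtype=tf.int32)
    step_plain(x, y)      # retraces when a new batch size appears
    step_signature(x, y)  # one graph handles any batch size

print(len(plain_traces), len(signature_traces))
```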

Now answer the exercise about the content:

In a custom TensorFlow training loop using tf.GradientTape, what is a key reason to call the model with training=True during the training step?

During the training step, using training=True makes training-specific layers (e.g., Dropout, BatchNorm) behave correctly. Metrics still require manual reset_state(), and retracing is addressed by consistent shapes or an input signature.
