From “training loop” intuition to Keras training
When you call model.fit(), Keras runs a training loop for you. Conceptually, each training step does three things: (1) run a forward pass to produce predictions, (2) compute a loss that measures how wrong the predictions are, and (3) update model weights using an optimizer. Metrics are computed alongside the loss to give human-friendly signals (for example, accuracy) but they do not drive weight updates.
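Under the hood, a single step of this loop can be sketched with GradientTape. This is a minimal illustration, assuming a model, loss_fn, and optimizer already exist; fit() layers batching, metric tracking, and callbacks on top of this core.
import tensorflow as tf

# Minimal sketch of one training step; `model`, `loss_fn`, and
# `optimizer` are assumed to exist already.
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        preds = model(x, training=True)   # (1) forward pass
        loss = loss_fn(y, preds)          # (2) compute the loss
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))  # (3) update weights
    return loss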
The Keras workflow mirrors this mental model in three stages (sketched in code after this list):
- compile: choose the loss, optimizer, and metrics (what to optimize and how to report progress)
- fit: run training for some number of epochs with a batch size and optional callbacks (how long and with what supervision)
- evaluate/predict: measure performance on held-out data and generate outputs (how well it generalizes and what it outputs)
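In code, the three stages line up one-to-one. A hedged skeleton, assuming x_train/y_train plus test data and new inputs already exist:
# Skeleton of compile -> fit -> evaluate/predict; data names are placeholders.
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)
history = model.fit(x_train, y_train, validation_split=0.2, epochs=10)
test_loss, test_acc = model.evaluate(x_test, y_test)
predictions = model.predict(x_new)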
The compile step: loss, optimizer, metrics
Loss: what the model is trying to minimize
The loss function converts model predictions and true targets into a single number. Lower is better. Pick a loss that matches your task and label format.
- Binary classification (two classes): BinaryCrossentropy
- Multi-class classification (one correct class out of N): SparseCategoricalCrossentropy (integer labels) or CategoricalCrossentropy (one-hot labels)
- Regression: MeanSquaredError or MeanAbsoluteError
Two common “gotchas” are (1) using the wrong cross-entropy variant for your label encoding, and (2) mismatching the final activation with the loss. Keras losses often have a from_logits argument: set it to True if your model outputs raw scores (no sigmoid/softmax), and False if your model outputs probabilities (with sigmoid/softmax).
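For example, either of these pairings is consistent. A sketch with a hypothetical 10-class output head:
import tensorflow as tf

# Two consistent pairings for a hypothetical 10-class head.
# A: softmax output -> the model emits probabilities.
head_probs = tf.keras.layers.Dense(10, activation="softmax")
loss_probs = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

# B: plain Dense output -> the model emits raw logits.
head_logits = tf.keras.layers.Dense(10)
loss_logits = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)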
import tensorflow as tf
# Example: multi-class classification with integer labels (0..num_classes-1)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="acc")]
)

Optimizer: how weights are updated
An optimizer decides how to adjust weights to reduce the loss. For beginners, Adam is a strong default because it adapts learning rates per parameter and often converges quickly. SGD (optionally with momentum) can work very well too, but is more sensitive to learning rate and may need more tuning.
- Adam: good default, fast convergence in many problems
- SGD + momentum: can generalize well; often needs careful learning-rate scheduling (see the sketch below)
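A minimal sketch of both; the momentum value is a common starting point, not a rule:
# Side by side; 0.9 momentum is a common default for SGD.
adam = tf.keras.optimizers.Adam(learning_rate=1e-3)
sgd = tf.keras.optimizers.SGD(learning_rate=1e-2, momentum=0.9)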
The most important optimizer hyperparameter is the learning rate. Too high: training may diverge or bounce around. Too low: training is stable but painfully slow and may look “stuck.”
# Two learning-rate choices to compare
adam_fast = tf.keras.optimizers.Adam(learning_rate=1e-3)
adam_slow = tf.keras.optimizers.Adam(learning_rate=1e-4)

Metrics: what you monitor (not what you optimize)
Metrics are computed for reporting and monitoring. They do not affect gradients unless you explicitly build them into the loss. Choose metrics that reflect what you care about in deployment.
- Classification: Accuracy, Precision, Recall, AUC
- Regression: MAE, MSE
It’s common to optimize cross-entropy loss while monitoring accuracy. Loss is smoother and more sensitive to improvements; accuracy can plateau even when the model is still getting “more confident” in correct predictions.
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3),
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=[
        tf.keras.metrics.BinaryAccuracy(name="acc"),
        tf.keras.metrics.AUC(name="auc")
    ]
)

The fit step: epochs, batch size, validation, callbacks
Epochs and batch size: how training is chunked
Batch size is how many examples are used to compute one gradient update. Smaller batches add noise to updates (sometimes helping generalization) but can be slower. Larger batches can be faster on GPUs but may require a different learning rate and can generalize differently.
The number of epochs is how many full passes over the training data you run. More epochs can improve performance until the model starts to overfit.
# Typical fit call
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=20,
    # batch_size=32,  # only for array inputs; omit when passing a batched Dataset
    verbose=2
)

If you are training from a tf.data.Dataset that is already batched, do not pass batch_size to fit(): depending on the version, Keras either ignores it or raises an error. Control batch size in the dataset pipeline instead.
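For instance, a sketch of a batched pipeline; names and sizes are illustrative:
# Batch size lives in the tf.data pipeline, not in fit().
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=10_000)
    .batch(64)
    .prefetch(tf.data.AUTOTUNE)
)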
Validation data: your early warning system
Validation metrics estimate generalization during training. You typically want training loss to go down and validation loss to go down as well. When training keeps improving but validation gets worse, you are likely overfitting.
# If you have arrays instead of a dataset
history = model.fit(
    x_train, y_train,
    validation_split=0.2,
    epochs=30,
    batch_size=64
)

Reading training curves to diagnose underfitting vs. overfitting
Keras returns a History object whose history dict contains per-epoch values (for example, loss, val_loss, acc, val_acc). Plotting these curves is one of the fastest ways to diagnose what to do next.
import matplotlib.pyplot as plt

def plot_history(history, metric="acc"):
    h = history.history
    epochs = range(1, len(h["loss"]) + 1)
    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.plot(epochs, h["loss"], label="train")
    if "val_loss" in h:
        plt.plot(epochs, h["val_loss"], label="val")
    plt.title("Loss")
    plt.xlabel("epoch")
    plt.legend()
    plt.subplot(1, 2, 2)
    if metric in h:
        plt.plot(epochs, h[metric], label="train")
    val_metric = "val_" + metric
    if val_metric in h:
        plt.plot(epochs, h[val_metric], label="val")
    plt.title(metric)
    plt.xlabel("epoch")
    plt.legend()
    plt.tight_layout()
    plt.show()

# plot_history(history, metric="acc")

Underfitting (model too simple, not trained enough, or learning rate too low):
- Training loss is high and decreases slowly; training metric is low.
- Validation loss/metric is similarly poor (no big gap).
- Typical fixes: train longer, increase model capacity, reduce regularization, improve features, or increase learning rate slightly.
Overfitting (model memorizes training data):
- Training loss keeps decreasing, training metric keeps improving.
- Validation loss bottoms out then starts rising; validation metric plateaus or drops.
- Typical fixes: add regularization (dropout/weight decay), use data augmentation, reduce model capacity, or stop earlier with callbacks (see the sketch after this list).
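A sketch of the regularization route; layer sizes, rates, and the 10-class output are illustrative, not a recipe:
import tensorflow as tf
from tensorflow.keras import layers, regularizers

# Dropout plus L2 weight decay on a small classifier; all sizes illustrative.
model = tf.keras.Sequential([
    layers.Dense(128, activation="relu",
                 kernel_regularizer=regularizers.l2(1e-4)),
    layers.Dropout(0.3),
    layers.Dense(10, activation="softmax")
])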
How learning rate affects convergence (and what it looks like)
The learning rate controls step size in weight space. You can often identify learning-rate problems just by the loss curve:
- Too high: loss is unstable, may spike or become NaN; validation is erratic.
- Too low: loss decreases very slowly; you may see almost flat curves for many epochs.
- Reasonable: loss decreases steadily; validation follows training until overfitting begins.
A practical workflow is to start with a sensible default (for Adam, often 1e-3), watch the first few epochs, then adjust. If training is unstable, lower the learning rate by 3–10×. If training is stable but slow, increase by 2–3×.
# Re-compile with a different learning rate (optimizer state resets).
# Re-specify the loss and metrics explicitly: model.metrics also contains
# internal loss trackers, so passing it back into compile() can misbehave.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="acc")]
)

Note: recompiling resets optimizer state (such as Adam’s moving averages). If you want to change the learning rate mid-training without resetting state, use a callback like ReduceLROnPlateau or assign optimizer.learning_rate directly (advanced usage).
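For the direct-assignment route, a minimal sketch, assuming a TF2-style optimizer whose learning rate is backed by a variable:
# Lower the learning rate in place, keeping Adam's accumulated state.
model.optimizer.learning_rate.assign(3e-4)
print(float(model.optimizer.learning_rate))  # confirm the new value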
Essential callbacks: when and how to use them
Callbacks let you inject logic into training: saving checkpoints, stopping early, or adapting the learning rate. They are especially useful because they help you avoid wasting compute and reduce overfitting.
ModelCheckpoint: save the best model as you train
Use ModelCheckpoint when training is long, when you expect overfitting, or when you want the best validation model even if later epochs get worse. A common pattern is to save only the best weights according to val_loss.
ckpt = tf.keras.callbacks.ModelCheckpoint(
    filepath="checkpoints/best.weights.h5",
    monitor="val_loss",
    save_best_only=True,
    save_weights_only=True,
    mode="min",
    verbose=1
)

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=50,
    callbacks=[ckpt]
)

# Restore best weights after training
model.load_weights("checkpoints/best.weights.h5")

Guidance:
- Monitor val_loss for most tasks; it is more sensitive than accuracy.
- Use save_weights_only=True if you can recreate the model code easily; otherwise save the full model (larger files; see the sketch below).
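Saving the full model instead is a one-liner; the path is illustrative, and .keras is the current recommended format:
# Save and reload architecture + weights + optimizer state in one file.
model.save("checkpoints/best_model.keras")
restored = tf.keras.models.load_model("checkpoints/best_model.keras")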
EarlyStopping: stop when validation stops improving
Use EarlyStopping to prevent overfitting and to avoid spending time on epochs that no longer improve validation. The patience parameter allows a few “non-improving” epochs before stopping, which is helpful when validation is noisy.
early = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=5,
    mode="min",
    restore_best_weights=True,
    verbose=1
)

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=100,
    callbacks=[early]
)

Guidance:
- Set restore_best_weights=True so you end with the best validation epoch automatically.
- Choose patience based on noise: 2–3 for stable curves, 5–10 for noisy validation.
ReduceLROnPlateau: lower learning rate when progress stalls
Use ReduceLROnPlateau when loss decreases at first but then plateaus. Reducing the learning rate can help the optimizer “settle” into a better minimum. This is a simple and effective alternative to manually re-tuning learning rates.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.5,
    patience=2,
    min_lr=1e-6,
    mode="min",
    verbose=1
)

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=60,
    callbacks=[reduce_lr]
)

Guidance:
- Start with factor between 0.1 and 0.5.
- Use small patience (1–3) so the learning rate drops soon after a plateau.
- Combine with EarlyStopping so training ends once even reduced learning rates no longer help.
Putting callbacks together (common recipe)
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        "checkpoints/best.weights.h5",
        monitor="val_loss",
        save_best_only=True,
        save_weights_only=True
    ),
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.5,
        patience=2,
        min_lr=1e-6
    ),
    tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=6,
        restore_best_weights=True
    )
]

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=100,
    callbacks=callbacks
)

This combination tends to work well in practice: checkpoints preserve the best model, learning-rate reductions help squeeze out extra validation improvements, and early stopping prevents unnecessary over-training.
Evaluate and predict: measuring and using the trained model
Evaluate: final metrics on a test/holdout set
After training decisions are made using validation data, use a separate test set (or final holdout) to estimate real-world performance. model.evaluate() returns the loss and metric values in the same order you compiled.
results = model.evaluate(test_ds, verbose=0)
print(results)
# If you want named results
named = dict(zip(model.metrics_names, results))
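# Alternative: model.evaluate(test_ds, return_dict=True) returns {name: value} directly.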
print(named)

Predict: generate outputs for new inputs
model.predict() runs a forward pass and returns model outputs. For classification, you often convert probabilities to class decisions.
# Predict probabilities (shape depends on your model)
y_prob = model.predict(new_x)
# Example: binary classification thresholding
# y_pred = (y_prob >= 0.5).astype("int32")
# Example: multi-class class id
# y_pred = y_prob.argmax(axis=-1)

If your model outputs logits (no sigmoid/softmax), apply the appropriate activation at prediction time (sigmoid for binary, softmax for multi-class) before thresholding or taking argmax.
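A sketch of that conversion, assuming a logits-output model:
import tensorflow as tf

# Assumes the model outputs raw logits.
logits = model.predict(new_x)

# Multi-class: softmax, then argmax over classes.
probs = tf.nn.softmax(logits, axis=-1).numpy()
y_pred = probs.argmax(axis=-1)

# Binary: sigmoid, then threshold.
# probs = tf.math.sigmoid(logits).numpy()
# y_pred = (probs >= 0.5).astype("int32")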