Why tf.data.Dataset Matters
Training with in-memory NumPy arrays works for small datasets, but it becomes limiting when data grows, when preprocessing is expensive, or when you need consistent, repeatable input behavior. tf.data.Dataset provides a scalable way to stream data, apply transformations (preprocessing, shuffling, batching), and overlap CPU input work with GPU/TPU training.
A dataset pipeline is typically a chain of transformations that turns “raw examples” into “model-ready batches”. The goal is to keep the model fed efficiently while keeping preprocessing correct, reproducible, and maintainable.
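For orientation, here is a minimal, self-contained sketch of such a chain (the steps are introduced one by one in the rest of this section):
import tensorflow as tf
# Toy pipeline: raw numbers in, shuffled, preprocessed, batched tensors out.
toy_ds = (tf.data.Dataset.range(10)                        # raw examples
          .shuffle(10, seed=0)                             # randomize order
          .map(lambda v: tf.cast(v, tf.float32) / 10.0)    # preprocess each element
          .batch(4)                                        # group into mini-batches
          .prefetch(tf.data.AUTOTUNE))                     # overlap input work with training
for batch in toy_ds:
    print(batch)                                           # model-ready batches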
Create Datasets from Tensors (In-Memory to Dataset)
The simplest migration path is to wrap existing arrays/tensors with from_tensor_slices. This gives you a dataset where each element is one example (or one pair of features/label).
import tensorflow as tf
# Example: features and labels already in memory
x = tf.random.uniform([1000, 20]) # 1000 examples, 20 features
y = tf.random.uniform([1000], maxval=3, dtype=tf.int32) # 3 classes
ds = tf.data.Dataset.from_tensor_slices((x, y))
for features, label in ds.take(1):
    print(features.shape, label)
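Each element is a single (features, label) pair. If you want to confirm the per-example structure, element_spec shows it (output as a comment, assuming the tensors above):
print(ds.element_spec)
# (TensorSpec(shape=(20,), dtype=tf.float32, name=None),
#  TensorSpec(shape=(), dtype=tf.int32, name=None))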
Batching and Shuffling
Most models train on batches. Shuffling is important for SGD-style training so batches represent a good mix of examples.
BATCH_SIZE = 32
SHUFFLE_BUFFER = 1000
SEED = 123
train_ds = (ds
.shuffle(SHUFFLE_BUFFER, seed=SEED, reshuffle_each_iteration=True)
.batch(BATCH_SIZE))
shuffle(buffer_size) uses a buffer to randomize order. Larger buffers approximate a more uniform shuffle but use more memory.
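To build intuition for buffer_size, here is a quick toy sketch (reusing SEED from above): with a small buffer, elements can only move a short distance from their original position.
# Small buffer: order is only locally randomized.
print(list(tf.data.Dataset.range(10).shuffle(3, seed=SEED).as_numpy_iterator()))
# Buffer >= dataset size: a full uniform shuffle.
print(list(tf.data.Dataset.range(10).shuffle(10, seed=SEED).as_numpy_iterator()))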
Create Datasets from Files (Conceptual Patterns)
Real projects usually read from files rather than holding everything in memory. The tf.data API supports many sources (text lines, TFRecord, images on disk). The key idea is the same: create a dataset of file paths, then read and parse each file with map.
Pattern: Dataset of File Paths + map(parse)
# Conceptual example: list of file paths (strings)
file_paths = tf.constant([
"/path/to/example_001.tfrecord",
"/path/to/example_002.tfrecord",
])
files_ds = tf.data.Dataset.from_tensor_slices(file_paths)
def read_file(path):
    # Example placeholder: in practice use TFRecordDataset, read_file, etc.
    raw = tf.io.read_file(path)
    return raw
raw_ds = files_ds.map(read_file, num_parallel_calls=tf.data.AUTOTUNE)
In practice, you typically replace read_file with a format-specific reader (for example, tf.data.TFRecordDataset for TFRecords, or tf.io.read_file + decode ops for images). After reading, you parse raw bytes into tensors and then preprocess.
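For example, for TFRecord files you might swap read_file for tf.data.TFRecordDataset and interleave reads across files (a sketch; the file paths above are placeholders):
# Produce a single stream of serialized records, reading several files concurrently.
raw_records_ds = files_ds.interleave(
    tf.data.TFRecordDataset,
    cycle_length=4,
    num_parallel_calls=tf.data.AUTOTUNE)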
Conceptual TFRecord Parsing Example
TFRecords store serialized tf.train.Example records. Parsing uses a feature specification that describes each field.
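For context, a record matching the specification below could have been written roughly like this (a sketch; the output path is a placeholder):
# Serialize one tf.train.Example with a 20-float feature vector and an integer label.
example = tf.train.Example(features=tf.train.Features(feature={
    "features": tf.train.Feature(float_list=tf.train.FloatList(value=[0.0] * 20)),
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
}))
with tf.io.TFRecordWriter("/tmp/data.tfrecord") as writer:
    writer.write(example.SerializeToString())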
# Conceptual: parsing a serialized Example
feature_spec = {
"features": tf.io.FixedLenFeature([20], tf.float32),
"label": tf.io.FixedLenFeature([], tf.int64),
}
def parse_example(serialized):
    ex = tf.io.parse_single_example(serialized, feature_spec)
    x = ex["features"]
    y = tf.cast(ex["label"], tf.int32)
    return x, y
# raw_records_ds = tf.data.TFRecordDataset(["/path/to/data.tfrecord"]) # typical
# parsed_ds = raw_records_ds.map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
Core Transformations: map, batch, shuffle, cache, prefetch
map: Apply Parsing and Preprocessing
map transforms each element. It’s where you parse raw records, normalize numeric features, one-hot encode labels, or apply augmentation.
NUM_CLASSES = 3
def preprocess(features, label):
    # Normalization example (assumes features are numeric)
    features = tf.cast(features, tf.float32)
    features = (features - tf.reduce_mean(features)) / (tf.math.reduce_std(features) + 1e-6)
    # One-hot encode labels
    label = tf.one_hot(label, depth=NUM_CLASSES)
    return features, label
train_ds = train_ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
Use num_parallel_calls=tf.data.AUTOTUNE to parallelize preprocessing across CPU cores.
batch: Create Mini-batches
batch groups elements. If your dataset size isn’t divisible by the batch size, you can keep the remainder batch or drop it.
train_ds = train_ds.batch(BATCH_SIZE, drop_remainder=False)
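To see the difference on a toy dataset (expected output shown as comments):
print([b.numpy().tolist() for b in tf.data.Dataset.range(10).batch(3)])
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]   <- remainder batch kept
print([b.numpy().tolist() for b in tf.data.Dataset.range(10).batch(3, drop_remainder=True)])
# [[0, 1, 2], [3, 4, 5], [6, 7, 8]]        <- remainder dropped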
shuffle: Randomize Example Order
Shuffling should usually happen before batching so each batch is a mix of examples. For deterministic shuffling, provide a seed and control reshuffling behavior.
train_ds = ds.shuffle(1000, seed=SEED, reshuffle_each_iteration=True)
cache: Avoid Recomputing Expensive Steps
cache stores the elements produced by the transformations that come before it, so they are not recomputed on later epochs. This is useful when reading/parsing is expensive and the dataset fits in memory. You can also cache to a file path.
# Cache in memory
train_ds = train_ds.cache()
# Or cache to disk (path is a filename)
# train_ds = train_ds.cache("/tmp/train_cache")
Place cache after deterministic preprocessing (like parsing and normalization) but typically before random augmentation (if you want different augmentations each epoch).
prefetch: Overlap Input Work with Training
prefetch overlaps the CPU pipeline with model execution so the next batch is prepared while the current batch trains.
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
Recommended Ordering (Common Default)
A practical default pattern for training is:
- Read/parse
- Shuffle
- Map preprocessing
- Batch
- Prefetch
Then add cache where it makes sense based on dataset size and whether you want augmentation to vary each epoch, as sketched below.
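Put together, the ordering looks like this (a conceptual sketch reusing names from earlier sections; parse_example and preprocess stand in for your own functions):
# train_ds = (raw_records_ds
#             .map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)  # read/parse
#             .cache()                                                  # optional: after deterministic work
#             .shuffle(SHUFFLE_BUFFER, seed=SEED)                       # shuffle
#             .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)     # preprocessing/augmentation
#             .batch(BATCH_SIZE)                                        # batch
#             .prefetch(tf.data.AUTOTUNE))                              # prefetch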
Preprocessing Examples Inside the Dataset Pipeline
Normalization (Feature Scaling)
Normalization can be done per-example (simple but sometimes less stable) or using dataset-wide statistics (more consistent). Per-example normalization is easy to demonstrate in a pipeline.
def normalize_per_example(x, y):
    x = tf.cast(x, tf.float32)
    mean = tf.reduce_mean(x)
    std = tf.math.reduce_std(x)
    x = (x - mean) / (std + 1e-6)
    return x, y
train_ds = train_ds.map(normalize_per_example, num_parallel_calls=tf.data.AUTOTUNE)
If you have fixed feature-wise means/stds computed offline, apply them directly for consistent scaling.
# Example: feature-wise normalization with known stats
feature_mean = tf.constant([0.1] * 20, dtype=tf.float32)
feature_std = tf.constant([0.5] * 20, dtype=tf.float32)
def normalize_with_stats(x, y):
    x = tf.cast(x, tf.float32)
    x = (x - feature_mean) / (feature_std + 1e-6)
    return x, y
One-Hot Encoding Labels
Many classification models expect one-hot labels when using categorical cross-entropy. You can convert integer labels in the pipeline.
def one_hot_labels(x, y):
    y = tf.one_hot(tf.cast(y, tf.int32), depth=NUM_CLASSES)
    return x, y
Simple Data Augmentation as Tensor Ops
Augmentation is most common for images, but the idea applies broadly: apply random, label-preserving transformations during training only. Below is a simple image-like example using TensorFlow ops.
# Conceptual: x is an image tensor [H, W, C], y is label
def augment_image(x, y):
    x = tf.image.random_flip_left_right(x)
    x = tf.image.random_brightness(x, max_delta=0.1)
    x = tf.clip_by_value(x, 0.0, 1.0)
    return x, y
# Apply augmentation only on training dataset
# train_ds = train_ds.map(augment_image, num_parallel_calls=tf.data.AUTOTUNE)
Keep augmentation out of validation/test pipelines so evaluation reflects real, unmodified data.
Where Should Preprocessing Live: Dataset Pipeline vs Model Layers?
Preprocessing in the Dataset Pipeline
Best for: parsing files, decoding, heavy CPU preprocessing, deterministic transformations, and augmentations that you only want during training.
Pros: can parallelize with map, can cache, keeps the model simpler.
Cons: if you export the model for serving, you must reproduce the same preprocessing in the serving system unless you also embed it in the model.
Preprocessing in Model Layers
Some preprocessing is better inside the model so it is saved and reused consistently at inference time. Examples include Keras preprocessing layers like normalization, rescaling, and text/vectorization layers (depending on your use case).
Best for: transformations that must be identical in training and serving (for example, rescaling images, normalization with learned/adapted statistics, vocabulary lookups).
Pros: exported model is self-contained; fewer chances of train/serve skew.
Cons: some heavy preprocessing may be slower inside the model; caching benefits may be reduced.
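For example, a Keras Normalization layer can adapt to the training features and then travel with the saved model (a sketch assuming a train_ds that yields (features, label) pairs and the NUM_CLASSES constant from earlier):
normalizer = tf.keras.layers.Normalization(axis=-1)
# Learn mean/variance from the training features only (labels stripped).
normalizer.adapt(train_ds.map(lambda features, label: features))
model = tf.keras.Sequential([
    normalizer,                                   # saved with the model, applied at serving time
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
])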
Practical Rule of Thumb
Put parsing/decoding and training-only augmentation in the dataset pipeline.
Put serving-critical normalization/encoding in the model (or ensure the exact same logic is applied at serving).
Repeatable Train/Validation Splitting Patterns
Splitting should be repeatable and should avoid leakage. If you already have separate files for train/validation, prefer that. If you need to split from a single dataset, use a deterministic method.
Pattern A: Split by Taking a Fixed Prefix (Deterministic)
This approach is deterministic if the dataset order is deterministic. It works well when the dataset is already in a stable order (or you apply a deterministic shuffle once).
def split_take_skip(ds, train_fraction=0.8, dataset_size=1000):
    train_size = int(train_fraction * dataset_size)
    train_ds = ds.take(train_size)
    val_ds = ds.skip(train_size)
    return train_ds, val_ds
# Example usage
full_ds = tf.data.Dataset.from_tensor_slices((x, y))
train_ds, val_ds = split_take_skip(full_ds, train_fraction=0.8, dataset_size=1000)
If you want a randomized but repeatable split, shuffle once with a seed, then take/skip.
full_ds = full_ds.shuffle(1000, seed=SEED, reshuffle_each_iteration=False)
train_ds, val_ds = split_take_skip(full_ds, 0.8, 1000)
Pattern B: Split by Hashing an ID (Stable Across Runs)
If each example has a stable identifier (like a filename or unique key), you can split by hashing so the split remains stable even if you change pipeline parallelism or read order.
# Conceptual: ds yields (id_str, features, label)
def split_by_hash(id_str, features, label, val_percent=20):
    bucket = tf.strings.to_hash_bucket_fast(id_str, 100)
    is_val = bucket < val_percent
    return is_val
# Example filtering pattern
# val_ds = ds.filter(lambda id_str, x, y: split_by_hash(id_str, x, y, 20))
# train_ds = ds.filter(lambda id_str, x, y: tf.logical_not(split_by_hash(id_str, x, y, 20)))
This is a strong pattern for file-based datasets where the filename can act as the ID.
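A small self-contained sketch of the same idea (the IDs, features, and 20% threshold are illustrative):
ids = tf.constant(["example_001", "example_002", "example_003", "example_004"])
feats = tf.random.uniform([4, 20])
labels = tf.random.uniform([4], maxval=3, dtype=tf.int32)
ds_with_ids = tf.data.Dataset.from_tensor_slices((ids, feats, labels))
def is_val(id_str, x, y):
    # The bucket depends only on the ID, so the split is stable across runs.
    return tf.strings.to_hash_bucket_fast(id_str, 100) < 20
val_split = ds_with_ids.filter(is_val).map(lambda i, x, y: (x, y))
train_split = ds_with_ids.filter(lambda i, x, y: tf.logical_not(is_val(i, x, y))).map(lambda i, x, y: (x, y))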
Deterministic Behavior and Seeds
Reproducibility requires controlling randomness in multiple places: TensorFlow random ops, dataset shuffling, and sometimes parallelism. Determinism can reduce throughput, so enable it when you need repeatable experiments.
Set Global Seeds
SEED = 123
tf.keras.utils.set_random_seed(SEED)
set_random_seed sets seeds for TensorFlow (and also Python/NumPy when available in your environment), helping keep runs repeatable.
Control Dataset Shuffling
ds = ds.shuffle(1000, seed=SEED, reshuffle_each_iteration=True)
For a fixed split or fixed order across epochs, set reshuffle_each_iteration=False.
Control Determinism in the Input Pipeline
Parallel mapping can change the order elements are produced unless determinism is enforced. You can request deterministic behavior via dataset options.
options = tf.data.Options()
options.deterministic = True
ds = ds.with_options(options)
When options.deterministic is True, transformations like parallel map aim to preserve determinism. If you prioritize throughput over exact reproducibility, you can set it to False.
Putting It Together: A Practical, Scalable Pipeline Template
The following template shows a repeatable structure you can adapt. It starts from tensors for simplicity, but the same structure applies after file parsing.
import tensorflow as tf
SEED = 123
BATCH_SIZE = 32
SHUFFLE_BUFFER = 1000
NUM_CLASSES = 3
tf.keras.utils.set_random_seed(SEED)
x = tf.random.uniform([1000, 20])
y = tf.random.uniform([1000], maxval=NUM_CLASSES, dtype=tf.int32)
full_ds = tf.data.Dataset.from_tensor_slices((x, y))
# Deterministic split: shuffle once, then take/skip
full_ds = full_ds.shuffle(SHUFFLE_BUFFER, seed=SEED, reshuffle_each_iteration=False)
train_size = int(0.8 * 1000)
train_ds = full_ds.take(train_size)
val_ds = full_ds.skip(train_size)
def preprocess(x, y):
    x = tf.cast(x, tf.float32)
    x = (x - tf.reduce_mean(x)) / (tf.math.reduce_std(x) + 1e-6)
    y = tf.one_hot(y, depth=NUM_CLASSES)
    return x, y
options = tf.data.Options()
options.deterministic = True
train_ds = (train_ds
.with_options(options)
.shuffle(SHUFFLE_BUFFER, seed=SEED, reshuffle_each_iteration=True)
.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
.batch(BATCH_SIZE)
.prefetch(tf.data.AUTOTUNE))
val_ds = (val_ds
.with_options(options)
.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
.batch(BATCH_SIZE)
.prefetch(tf.data.AUTOTUNE))
This pattern gives you: (1) a stable train/validation split, (2) training-time reshuffling each epoch, (3) parallel preprocessing, and (4) prefetching to keep training efficient.
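These datasets plug directly into Keras training. A closing sketch with a small placeholder model (the architecture is illustrative; the one-hot labels pair with categorical_crossentropy):
model = tf.keras.Sequential([
    tf.keras.Input(shape=(20,)),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
])
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
# Keras iterates the tf.data pipelines directly; no manual feeding or batching needed.
model.fit(train_ds, validation_data=val_ds, epochs=5)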