Lab 4-01: Keras Functional API
Learning Objectives
- Build models with the Keras Functional API (vs Sequential)
- Create multi-input, multi-output models
- Implement shared layers and branching architectures
- Write custom `tf.keras.layers.Layer` subclasses
- Use callbacks: EarlyStopping, ModelCheckpoint, TensorBoard
Keras Functional API vs Sequential
```python
import tensorflow as tf

# Sequential: only linear stacks (one input, one output, layers in a chain)
model = tf.keras.Sequential([
    tf.keras.Input(shape=(224, 224, 3)),
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10, activation='softmax'),
])

# Functional API: full DAG support; each layer is called on a tensor
inputs = tf.keras.Input(shape=(224, 224, 3))
x = tf.keras.layers.Conv2D(32, 3, activation='relu')(inputs)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
```
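Because the Functional API keeps the graph explicit, any intermediate tensor can become the output of a new model. The sketch below rebuilds the same graph and keeps a handle to the pooled features. The names `conv`, `features`, and `feature_extractor` are illustrative only. The random batch is a stand-in for real images.

```python
import numpy as np
import tensorflow as tf

# Same functional graph as above, keeping handles to a layer and a tensor.
inputs = tf.keras.Input(shape=(224, 224, 3))
conv = tf.keras.layers.Conv2D(32, 3, activation='relu')
features = tf.keras.layers.GlobalAveragePooling2D()(conv(inputs))
outputs = tf.keras.layers.Dense(10, activation='softmax')(features)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Any intermediate tensor can serve as an output. The sub-model reuses the
# same layer objects, so it shares weights with `model` (nothing is copied).
feature_extractor = tf.keras.Model(inputs=inputs, outputs=features)

batch = np.random.rand(2, 224, 224, 3).astype("float32")
probs = model.predict(batch, verbose=0)                   # shape (2, 10)
embeddings = feature_extractor.predict(batch, verbose=0)  # shape (2, 32)
```

A Sequential model offers no direct handle to `features`. This is one reason feature extraction and transfer-learning code usually relies on the Functional API.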
Multi-Input Model
```python
import tensorflow as tf

# Multi-modal: image + metadata
img_input = tf.keras.Input(shape=(128, 128, 1), name="image")
meta_input = tf.keras.Input(shape=(5,), name="metadata")

# Image branch: conv features pooled to a (B, 32) vector
x = tf.keras.layers.Conv2D(32, 3, activation='relu', padding='same')(img_input)
x = tf.keras.layers.GlobalAveragePooling2D()(x)

# Fuse: concatenate image features with metadata -> (B, 37)
combined = tf.keras.layers.Concatenate()([x, meta_input])
out = tf.keras.layers.Dense(1, activation='sigmoid', name="output")(combined)

model = tf.keras.Model(inputs=[img_input, meta_input], outputs=out)
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
```
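The same pattern extends to multiple outputs. Below is a sketch of a hypothetical two-head variant. It assumes a binary `label` head and a scalar regression `score` head, and both names are invented for illustration. Inputs and outputs are passed as dicts, so data, losses, and loss weights are all matched by name. The random arrays are stand-in data.

```python
import numpy as np
import tensorflow as tf

img_input = tf.keras.Input(shape=(128, 128, 1), name="image")
meta_input = tf.keras.Input(shape=(5,), name="metadata")

x = tf.keras.layers.Conv2D(32, 3, activation='relu', padding='same')(img_input)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
combined = tf.keras.layers.Concatenate()([x, meta_input])

# Two heads on the shared trunk (hypothetical targets).
label = tf.keras.layers.Dense(1, activation='sigmoid', name="label")(combined)
score = tf.keras.layers.Dense(1, name="score")(combined)

model = tf.keras.Model(
    inputs={"image": img_input, "metadata": meta_input},
    outputs={"label": label, "score": score},
)
# One loss (and an optional weight) per output, keyed by output name.
model.compile(
    optimizer='adam',
    loss={"label": "binary_crossentropy", "score": "mse"},
    loss_weights={"label": 1.0, "score": 0.5},
)

# Random stand-in data; inputs and targets are dicts keyed by name.
n = 8
x_train = {
    "image": np.random.rand(n, 128, 128, 1).astype("float32"),
    "metadata": np.random.rand(n, 5).astype("float32"),
}
y_train = {
    "label": np.random.randint(0, 2, size=(n, 1)).astype("float32"),
    "score": np.random.rand(n, 1).astype("float32"),
}
history = model.fit(x_train, y_train, epochs=1, batch_size=4, verbose=0)
preds = model.predict(x_train, verbose=0)  # dict with "label" and "score"
```

The total training loss is the weighted sum of the per-head losses. Here the weight of 0.5 down-weights the regression head.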
Custom Layer
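```python
import tensorflow as tf

class ChannelAttention(tf.keras.layers.Layer):
    """Squeeze-and-Excite channel attention (expects channels-last 4-D input)."""

    def __init__(self, reduction_ratio=4, **kwargs):
        super().__init__(**kwargs)
        self.reduction_ratio = reduction_ratio

    def build(self, input_shape):
        C = input_shape[-1]
        # Bottleneck width; max() guards against C < reduction_ratio
        self.fc1 = tf.keras.layers.Dense(max(1, C // self.reduction_ratio), activation='relu')
        self.fc2 = tf.keras.layers.Dense(C, activation='sigmoid')

    def call(self, x):
        # Global average pool → FC → FC → rescale
        gap = tf.reduce_mean(x, axis=[1, 2])           # (B, C)
        attn = self.fc2(self.fc1(gap))                 # (B, C), gates in (0, 1)
        return x * attn[:, tf.newaxis, tf.newaxis, :]  # broadcast over H, W

    def get_config(self):
        # Lets the layer be re-created when a saved model is loaded
        config = super().get_config()
        config.update({"reduction_ratio": self.reduction_ratio})
        return config
```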
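For comparison, the same squeeze-and-excite computation can be composed from built-in layers with the Functional API, with no subclassing. This is a minimal sketch that assumes channels-last 4-D inputs. `se_block` is a hypothetical helper name, and the input shape is arbitrary.

```python
import numpy as np
import tensorflow as tf

def se_block(x, reduction_ratio=4):
    """Hypothetical helper: squeeze-and-excite built from stock Keras layers."""
    channels = x.shape[-1]
    # Squeeze: (B, H, W, C) -> (B, C)
    s = tf.keras.layers.GlobalAveragePooling2D()(x)
    # Excite: bottleneck FC, then FC with sigmoid gates in (0, 1)
    s = tf.keras.layers.Dense(max(1, channels // reduction_ratio), activation='relu')(s)
    s = tf.keras.layers.Dense(channels, activation='sigmoid')(s)
    # Reshape to (B, 1, 1, C) so the gates broadcast over H and W
    s = tf.keras.layers.Reshape((1, 1, channels))(s)
    return tf.keras.layers.Multiply()([x, s])

inputs = tf.keras.Input(shape=(32, 32, 16))
model = tf.keras.Model(inputs, se_block(inputs))

batch = np.random.rand(2, 32, 32, 16).astype("float32")
out = model.predict(batch, verbose=0)
```

The subclassed layer is the better fit when the block should appear as one named, configurable, serializable layer. The helper function is simpler when that is not needed.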
Interview Questions
Q: When should you use the Functional API instead of Sequential?
A: Whenever you need (1) multiple inputs or outputs, (2) shared layers, (3) skip connections (ResNet-style), or (4) branching (Inception-style). Sequential only supports a linear stack with one input and one output.
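Two of these patterns, skip connections and shared layers, are hard to express as a linear stack. The sketch below shows each one with arbitrary shapes. The names `resnet_block` and `siamese` are illustrative only.

```python
import numpy as np
import tensorflow as tf

# (1) Skip connection: add the block input back to its output (ResNet-style).
inp = tf.keras.Input(shape=(32, 32, 16))
h = tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu')(inp)
h = tf.keras.layers.Conv2D(16, 3, padding='same')(h)
h = tf.keras.layers.Add()([h, inp])  # shapes must match: (32, 32, 16)
resnet_block = tf.keras.Model(inp, tf.keras.layers.Activation('relu')(h))

# (2) Shared layer: one Dense instance applied to two inputs (siamese-style),
# so both branches use exactly the same weights.
left = tf.keras.Input(shape=(8,))
right = tf.keras.Input(shape=(8,))
shared = tf.keras.layers.Dense(4, activation='relu')
diff = tf.keras.layers.Subtract()([shared(left), shared(right)])
siamese = tf.keras.Model([left, right], diff)

a = np.random.rand(3, 8).astype("float32")
same_diff = siamese.predict([a, a], verbose=0)  # identical inputs -> zeros
res_out = resnet_block.predict(np.random.rand(1, 32, 32, 16).astype("float32"), verbose=0)
```

Because the Dense layer is created once and called twice, the siamese model holds a single kernel and bias rather than two copies.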
Q: What is tf.GradientTape and when do you use it instead of model.fit()?
A: GradientTape records operations for automatic differentiation, which lets you write a custom training loop with full control over each step. Use it when you need custom loss terms, manual gradient handling (e.g., clipping), multiple optimizers (e.g., GANs), or per-step gradient logging.
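A minimal custom training step with `tf.GradientTape` is sketched below. It covers two of the cases listed above: an extra loss term (a small L2 penalty) and gradient clipping by global norm. The linear model, the synthetic target weights, and the hyperparameters are all arbitrary assumptions chosen for illustration.

```python
import numpy as np
import tensorflow as tf

# Tiny regression model and synthetic data (stand-ins for a real setup).
model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
loss_fn = tf.keras.losses.MeanSquaredError()

x = np.random.rand(64, 3).astype("float32")
y = x @ np.array([[1.0], [-2.0], [0.5]], dtype="float32")

@tf.function
def train_step(xb, yb):
    with tf.GradientTape() as tape:
        pred = model(xb, training=True)
        # Custom loss: MSE plus a small L2 penalty on all trainable weights
        l2 = tf.add_n([tf.nn.l2_loss(w) for w in model.trainable_weights])
        loss = loss_fn(yb, pred) + 1e-4 * l2
    grads = tape.gradient(loss, model.trainable_weights)
    grads, _ = tf.clip_by_global_norm(grads, 1.0)  # manual gradient clipping
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return loss

losses = [float(train_step(x, y)) for _ in range(100)]
```

For plain clipping alone, Keras optimizers also accept `clipnorm` / `clipvalue` arguments. A tape-based loop is most useful when several such customizations are combined.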
Q: How does model.compile() relate to model.fit()?
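A: compile() configures the model by setting the optimizer, loss, and metrics; fit() then runs the training loop with that configuration. You must call compile() before fit(). A fully custom GradientTape loop bypasses both. A middle ground is to override Model.train_step, which keeps fit() and its callbacks.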
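The sketch below shows the compile() then fit() sequence together with the three callbacks from the learning objectives. The toy model and random data are assumptions. The `.keras` checkpoint filename assumes Keras 3 or a recent tf.keras that supports that format, and the temporary directory is arbitrary.

```python
import os
import tempfile
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
# compile(): attach optimizer, loss, and metrics (required before fit()).
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

x = np.random.rand(64, 4).astype("float32")
y = (x.sum(axis=1, keepdims=True) > 2.0).astype("float32")

log_dir = tempfile.mkdtemp()
ckpt_path = os.path.join(log_dir, "best.keras")
callbacks = [
    # Stop once val_loss fails to improve for 3 epochs; restore the best weights.
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True),
    # Save the full model each time val_loss improves.
    tf.keras.callbacks.ModelCheckpoint(ckpt_path, monitor='val_loss', save_best_only=True),
    # Write logs viewable with: tensorboard --logdir <log_dir>/tb
    tf.keras.callbacks.TensorBoard(log_dir=os.path.join(log_dir, "tb")),
]
# fit(): run the training loop and invoke the callbacks at each epoch.
history = model.fit(x, y, validation_split=0.25, epochs=20, batch_size=16,
                    callbacks=callbacks, verbose=0)
```

A fully custom GradientTape loop gets none of this for free. Early stopping, checkpointing, and logging would then have to be written by hand.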