Example: Transfer Learning

Transfer learning is a powerful technique where you adapt a model pre-trained on a large dataset (like ImageNet) to a new, specific task. This is especially useful when your own dataset is too small to train a deep network from scratch.

TensorNets makes transfer learning straightforward by allowing you to replace the final classification layer while loading pre-trained weights for the convolutional base.

This example shows how to adapt a DenseNet169 model for a new task with 50 classes.

Step 1: Define the Model for the New Task

Instantiate the model, but this time specify the number of classes for your new task and set is_training=True so that layers such as batch normalization behave correctly during training.

import tensorflow as tf
import tensornets as nets

# For TensorFlow 2.x, enable 1.x compatibility
# import tensorflow.compat.v1 as tf
# tf.disable_v2_behavior()

# Define placeholders for your data and labels
inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
your_labels = tf.placeholder(tf.float32, [None, 50])  # One-hot encoded labels

# Instantiate the model with 50 output classes and in training mode
model = nets.DenseNet169(inputs, is_training=True, classes=50)
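
The returned model is the output tensor with TensorNets' accessors attached. To sanity-check that the classification head was rebuilt for 50 outputs, you can list the graph's trainable variables with plain TensorFlow; the exact variable names depend on TensorNets' internal scoping, so treat this as an illustrative check only:

# Inspect trainable variables; the final dense layer should map to 50 units
for var in tf.trainable_variables():
    print(var.name, var.shape)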

Step 2: Define Loss and Optimizer

Define a loss function using the model's logits (model.logits). The logits are the raw outputs before the final softmax activation. Then, create an optimizer to minimize this loss.

# Use softmax cross-entropy as the loss function
loss = tf.losses.softmax_cross_entropy(onehot_labels=your_labels, logits=model.logits)

# Use the Adam optimizer to minimize the loss
train_op = tf.train.AdamOptimizer(learning_rate=1e-5).minimize(loss)
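
One caveat when fine-tuning with batch normalization in training mode: TF 1.x graphs conventionally register the moving-average update ops in the tf.GraphKeys.UPDATE_OPS collection, and these must run alongside each training step or the moving statistics will never update. Whether this is needed depends on how TensorNets wires its batch-norm layers, so the following is a defensive sketch that would replace the train_op definition above:

# Ensure batch-norm moving averages are updated on every training step
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(learning_rate=1e-5).minimize(loss)

Similarly, if you want to freeze the convolutional base and train only the new head, you can pass a var_list argument to minimize() containing just the classifier's variables.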

Step 3: Load Pre-trained Weights and Train

In a TensorFlow session, first initialize all variables and then load the pre-trained ImageNet weights. Loading overwrites the convolutional layers with the ImageNet weights, while the new final classification layer keeps its random initialization. Then you can start your training loop.

num_epochs = 10  # Total passes over your dataset; adjust as needed

with tf.Session() as sess:
    # Initialize all variables, including the new classification layer
    # and the optimizer's internal slots
    sess.run(tf.global_variables_initializer())

    # Load the pre-trained weights for the convolutional base;
    # the new final layer keeps its random initialization
    nets.pretrained(model)

    # Assuming `your_numpy_data` is a re-iterable sequence of (x, y) batches,
    # where x is in NHWC format and y is one-hot encoded.
    for epoch in range(num_epochs):
        for (x_batch, y_batch) in your_numpy_data:
            # Preprocess the input images for DenseNet
            x_preprocessed = model.preprocess(x_batch)

            # Run the training operation
            _, current_loss = sess.run([train_op, loss],
                                       {inputs: x_preprocessed, your_labels: y_batch})

            print(f"Epoch {epoch}, Loss: {current_loss}")