Usage Guide & Core Concepts

This guide covers the fundamental concepts and common patterns for using the TensorNets library effectively.

The Functional API: Tensor In, Tensor Out

The core design principle of TensorNets is simplicity. Every network, whether for classification or detection, is a Python function, not a custom class. This function takes a tf.Tensor as input and returns a tf.Tensor as output.

Because a model is just a tensor, you can plug it into any existing TensorFlow computational graph and combine it with ordinary TensorFlow operations.

import tensorflow as tf
import tensornets as nets

inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])

# Create a ResNet50 model
model = nets.ResNet50(inputs)

# The output is a standard TensorFlow tensor
assert isinstance(model, tf.Tensor)

# You can use it like any other tensor
processed_output = tf.nn.relu(model)
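The pattern can be illustrated without TensorFlow. Below is a minimal, library-free sketch. It assumes a toy `Tensor` class and a hypothetical `toy_resnet` function, neither of which is part of TensorNets. The model is a plain function, and helper methods are attached to the tensor it returns instead of living on a model class.

```python
class Tensor:
    """Toy stand-in for tf.Tensor (hypothetical): just a name and a shape."""

    def __init__(self, name, shape):
        self.name = name
        self.shape = shape


def relu(x):
    # An ordinary op: tensor in, tensor out.
    return Tensor("relu(" + x.name + ")", x.shape)


def toy_resnet(inputs, classes=1000):
    """Hypothetical model function: takes a Tensor, returns a Tensor."""
    batch = inputs.shape[0]
    logits = Tensor("toy_resnet/logits", (batch, classes))
    # Attach a helper to the returned tensor instead of defining a Model class.
    logits.summary = lambda: "Scope: toy_resnet, output shape %s" % (logits.shape,)
    return logits


inputs = Tensor("inputs", (None, 224, 224, 3))
model = toy_resnet(inputs)
assert isinstance(model, Tensor)  # the "model" is just a tensor
out = relu(model)                 # so it composes with ordinary ops
print(out.name, out.shape)
print(model.summary())
```

Keeping models as functions means there is no extra object to unwrap. Whatever consumes a tensor can consume a model.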

Loading Pre-trained Weights

A key feature of TensorNets is its collection of pre-trained weights, and loading them takes a single line. The tensor returned by a model function carries a pretrained() method, which builds the operations that assign the pre-trained values. Execute them with sess.run():

with tf.Session() as sess:
    # Load the pre-trained weights for the model
    sess.run(model.pretrained())
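Passing model.pretrained() to sess.run() reflects TensorFlow 1.x's split between building operations and executing them. The following minimal sketch uses plain Python. It assumes hypothetical `Variable`, `assign`, and `Session` stand-ins, not the real TensorFlow or TensorNets classes, along with made-up variable names and values. Building the assignment ops changes nothing; only running them copies the values in.

```python
class Variable:
    """Hypothetical stand-in for a graph variable."""

    def __init__(self, name, value):
        self.name = name
        self.value = value


def assign(var, new_value):
    """Return a deferred 'op' (a closure) instead of mutating immediately."""
    def op():
        var.value = new_value
        return new_value
    return op


class Session:
    """Hypothetical stand-in for tf.Session: running an op calls it."""

    def run(self, ops):
        # Accept a single op or a list of ops, like sess.run().
        if isinstance(ops, list):
            return [op() for op in ops]
        return ops()


# Hypothetical variables, still holding their initial values.
kernel = Variable("conv1/kernel", 0.0)
bias = Variable("conv1/bias", 0.0)

# Hypothetical pre-trained values, keyed by variable name.
pretrained_values = {"conv1/kernel": 0.25, "conv1/bias": -0.5}


def pretrained():
    # Build one assignment op per variable; nothing is copied yet.
    return [assign(v, pretrained_values[v.name]) for v in (kernel, bias)]


ops = pretrained()
assert kernel.value == 0.0  # building the ops had no side effects
Session().run(ops)
print(kernel.value, bias.value)
```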

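If you have multiple models, you can load all of their weights in one call with nets.pretrained(). Unlike model.pretrained(), it is called directly inside the session block rather than passed to sess.run():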

models = [
    nets.ResNet50(inputs),
    nets.MobileNet100(inputs)
]

with tf.Session() as sess:
    # Load weights for all models in the list
    nets.pretrained(models)
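The multi-model call can be thought of as gathering every model's assignment ops and running them in the active session. The following is a hedged sketch of that idea. The `ToyModel` class and the `pretrained_all` helper are hypothetical stand-ins for model tensors and for nets.pretrained(), and the real function's internals may differ.

```python
class ToyModel:
    """Hypothetical stand-in for a TensorNets model tensor."""

    def __init__(self, scope, n_vars):
        self.scope = scope
        self.weights = {"%s/w%d" % (scope, i): None for i in range(n_vars)}

    def pretrained(self):
        # Return deferred assignment ops, one per variable.
        def make_op(name):
            def op():
                self.weights[name] = "loaded:" + name
            return op
        return [make_op(name) for name in self.weights]


def pretrained_all(models, run=lambda ops: [op() for op in ops]):
    """Collect every model's ops and execute them with a single run call."""
    ops = []
    for model in models:
        ops.extend(model.pretrained())
    run(ops)
    return len(ops)


models = [ToyModel("resnet50", 3), ToyModel("mobilenet100", 2)]
n = pretrained_all(models)
print(n)  # number of assignment ops executed
```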

Image Preprocessing

Different models were trained with different image preprocessing steps (e.g., normalization, color channel ordering), and their inputs must be preprocessed the same way at inference time. TensorNets handles this automatically via the preprocess() method attached to the model tensor.

img = nets.utils.load_img('cat.png', target_size=256, crop_size=224)

# Apply the correct preprocessing for ResNet50
preprocessed_img = model.preprocess(img)

This is equivalent to calling nets.preprocess(model, img).
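Per-model preprocessing can be pictured as a lookup from model to transform. The sketch below assumes two conventions that are widely used for ImageNet models. The first is Caffe-style: RGB to BGR, then per-channel mean subtraction. The second is Inception-style: pixels scaled to [-1, 1]. The `PREPROCESS` table and its model-to-scheme mapping are hypothetical and are not taken from TensorNets' source. The sketch also works out the crop arithmetic for a 256-pixel image cropped to 224, under the further assumption that the crop is centered.

```python
def center_crop_offset(size, crop):
    """Top/left offset of a centered crop (assumed crop placement)."""
    return (size - crop) // 2


def caffe_style(pixel):
    # RGB -> BGR, then subtract the commonly used ImageNet BGR means.
    r, g, b = pixel
    means = (103.939, 116.779, 123.68)
    return tuple(c - m for c, m in zip((b, g, r), means))


def inception_style(pixel):
    # Scale each channel from [0, 255] to [-1, 1].
    return tuple(c / 127.5 - 1.0 for c in pixel)


# Hypothetical mapping from model name to preprocessing scheme.
PREPROCESS = {
    "resnet50": caffe_style,
    "mobilenet100": inception_style,
}


def preprocess(model_name, pixels):
    """Apply the scheme registered for model_name to each RGB pixel."""
    fn = PREPROCESS[model_name]
    return [fn(p) for p in pixels]


print(center_crop_offset(256, 224))  # pixels trimmed from each side
print(preprocess("mobilenet100", [(255, 0, 127.5)]))
```

A single table keeps the choice of scheme in one place, so calling code never has to know which convention a model was trained with.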

Model Introspection

TensorNets provides several useful methods for inspecting the internal structure and state of a model. These are attached to the model tensor itself.

  • model.middles(): Returns a list of tf.Tensor endpoints from representative intermediate layers.
  • model.outputs(): Returns a list of all tf.Tensor endpoints from every layer.
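  • model.weights(): Returns a list of all the model's variables (weights, biases, and so on).
  • model.summary(): Prints a summary of the model, including the number of layers, weights, and total parameters.
  • model.print_middles(): Prints the name and shape of each intermediate endpoint returned by middles().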

Here's how you can use them:

with tf.Session() as sess:
    img_preprocessed = model.preprocess(img)
    sess.run(model.pretrained())

    # Get the feature maps from intermediate layers
    middles = sess.run(model.middles(), {inputs: img_preprocessed})

    # Get outputs from all layers
    outputs = sess.run(model.outputs(), {inputs: img_preprocessed})

# Print the shapes of the intermediate feature maps
model.print_middles()

# Print a summary of the model architecture
model.summary()

Example output from print_middles() for ResNet50:

Scope: resnet50
conv2/block1/out:0 (?, 56, 56, 256)
conv2/block2/out:0 (?, 56, 56, 256)
...
conv4/block1/out:0 (?, 14, 14, 1024)
...
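The spatial sizes in this listing follow from ResNet50's cumulative stride. The stem (a stride-2 convolution followed by a stride-2 max pool) shrinks the input by a factor of 4 before conv2. The first block of each later stage then halves the resolution again. Here is a small sketch of that arithmetic, assuming the 224x224 input used throughout this guide. `None` plays the role of the unknown batch dimension printed as `?`.

```python
def stage_shapes(input_size=224):
    """Expected output shapes of ResNet50's four residual stages."""
    # (stage name, cumulative stride, output channels)
    stages = [("conv2", 4, 256), ("conv3", 8, 512),
              ("conv4", 16, 1024), ("conv5", 32, 2048)]
    # None stands in for the unknown batch dimension ("?" above).
    return {name: (None, input_size // stride, input_size // stride, channels)
            for name, stride, channels in stages}


for name, shape in stage_shapes().items():
    print(name, shape)
```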

Example output from summary() for ResNet50:

Scope: resnet50
Total layers: 54
Total weights: 320
Total parameters: 25,636,712
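"Total weights" counts variables, while "Total parameters" counts the scalar entries inside them. The following minimal sketch shows the distinction. It uses the shapes of ResNet50's first convolution, its batch-normalization variables, and the 1000-way classifier as an assumed example. The variable names are hypothetical, and the exact variable layout inside TensorNets may differ.

```python
from math import prod

# Assumed (hypothetical) variable names and shapes for part of ResNet50.
variables = {
    "conv1/conv/kernel": (7, 7, 3, 64),
    "conv1/conv/bias": (64,),
    "conv1/bn/gamma": (64,),
    "conv1/bn/beta": (64,),
    "conv1/bn/moving_mean": (64,),
    "conv1/bn/moving_variance": (64,),
    "logits/kernel": (2048, 1000),
    "logits/bias": (1000,),
}


def summarize(variables):
    """Return (number of variables, total scalar parameters)."""
    total_params = sum(prod(shape) for shape in variables.values())
    return len(variables), total_params


n_weights, n_params = summarize(variables)
print("Total weights: %d" % n_weights)
print("Total parameters: {:,}".format(n_params))
```

For instance, the first convolution alone contributes 7 * 7 * 3 * 64 + 64 = 9,472 parameters but only two variables.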

Saving and Loading Custom Weights

If you fine-tune a model or train it from scratch, you can easily save and load the weights using the save() and load() methods.

# After training
with tf.Session() as sess:
    model.init()  # Initialize variables if training from scratch
    # ... your training code ...
    model.save('my_resnet50_weights.npz')

# For deployment or further training
with tf.Session() as sess:
    model.load('my_resnet50_weights.npz')
    # ... your deployment code ...
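Saving and restoring weights comes down to serializing a mapping from variable names to values, then writing those values back into matching variables. The sketch below illustrates that round trip under stated assumptions. It uses JSON and plain Python lists only to stay dependency-free, whereas the examples above write .npz files. The `save_weights` and `load_weights` helpers are hypothetical. The sketch also checks sizes on load, a common safeguard when restoring into a freshly built graph.

```python
import json
import os
import tempfile


def save_weights(variables, path):
    """Write {name: list_of_values} to disk (JSON here, for illustration)."""
    with open(path, "w") as f:
        json.dump(variables, f)


def load_weights(variables, path):
    """Copy saved values into existing variables, matching them by name."""
    with open(path) as f:
        saved = json.load(f)
    for name, values in variables.items():
        if len(saved[name]) != len(values):
            raise ValueError("size mismatch for %s" % name)
        values[:] = saved[name]  # update in place, like an assign op


trained = {"conv1/kernel": [0.1, -0.2, 0.3], "conv1/bias": [0.05]}
fresh = {"conv1/kernel": [0.0, 0.0, 0.0], "conv1/bias": [0.0]}

path = os.path.join(tempfile.mkdtemp(), "my_weights.json")
save_weights(trained, path)
load_weights(fresh, path)
print(fresh == trained)
```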