Usage Guide & Core Concepts
This guide covers the fundamental concepts and common patterns for using the TensorNets library effectively.
The Functional API: Tensor In, Tensor Out
The core design principle of TensorNets is simplicity. Every network, whether for classification or detection, is a Python function, not a custom class. This function takes a tf.Tensor as input and returns a tf.Tensor as output.
Because the result is an ordinary tensor, a TensorNets model can be integrated into any existing TensorFlow computational graph.
import tensorflow as tf
import tensornets as nets
inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
# Create a ResNet50 model
model = nets.ResNet50(inputs)
# The output is a standard TensorFlow tensor
assert isinstance(model, tf.Tensor)
# You can use it like any other tensor
processed_output = tf.nn.relu(model)
Loading Pre-trained Weights
One of the primary features of TensorNets is its collection of pre-trained weights. The tensor returned by a model function is augmented with a pretrained() method, so loading these weights is a one-line operation.
with tf.Session() as sess:
    # Load the pre-trained weights for the model
    sess.run(model.pretrained())
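To illustrate what "augmented with a method" can mean in plain Python, here is a minimal sketch that runs without TensorFlow. It is not TensorNets' implementation: FakeTensor, with_helpers, and tiny_net are hypothetical names invented for this example, and the helper bodies are placeholders.

```python
import functools


class FakeTensor:
    """Stand-in for tf.Tensor so the sketch runs without TensorFlow."""

    def __init__(self, name):
        self.name = name


def with_helpers(model_fn):
    """Hypothetical decorator: attach helper callables to the returned object."""

    @functools.wraps(model_fn)
    def wrapper(inputs):
        out = model_fn(inputs)
        # Helpers are set as plain attributes; `out` keeps its original type.
        out.pretrained = lambda: "assign ops for " + out.name
        out.preprocess = lambda img: [px / 255.0 for px in img]
        return out

    return wrapper


@with_helpers
def tiny_net(inputs):
    # A real network would build layers here; this only labels the output.
    return FakeTensor("tiny_net/probs")


model = tiny_net(FakeTensor("inputs"))
print(isinstance(model, FakeTensor))
print(model.pretrained())
```

The point of the pattern is that callers still receive the type they expect, while model-specific helpers travel with the object.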
If you have multiple models, you can load their weights simultaneously using nets.pretrained():
models = [
    nets.ResNet50(inputs),
    nets.MobileNet100(inputs),
]
with tf.Session() as sess:
    # Load weights for all models in the list
    nets.pretrained(models)
Image Preprocessing
Different models are trained with different image preprocessing steps (e.g., normalization, color channel ordering). TensorNets handles this automatically via the preprocess() method attached to the model tensor.
img = nets.utils.load_img('cat.png', target_size=256, crop_size=224)
# Apply the correct preprocessing for ResNet50
preprocessed_img = model.preprocess(img)
This is equivalent to calling nets.preprocess(model, img).
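For intuition about what such a preprocessing step may involve, the sketch below applies a center crop followed by a Caffe-style transform (RGB to BGR, then per-channel mean subtraction) in NumPy. This is an assumption for illustration only: the guide does not say which scheme or constants model.preprocess() uses for any given model, and center_crop and caffe_style_preprocess are hypothetical helper names.

```python
import numpy as np

# ImageNet per-channel means in BGR order, common in Caffe-style pipelines.
# Assumption: illustrative constants, not read from TensorNets.
BGR_MEANS = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def center_crop(img, crop_size):
    """Crop the central crop_size x crop_size window of an (H, W, C) image."""
    h, w = img.shape[:2]
    top = (h - crop_size) // 2
    left = (w - crop_size) // 2
    return img[top:top + crop_size, left:left + crop_size]


def caffe_style_preprocess(batch):
    """Flip RGB to BGR along the last axis, then subtract the channel means."""
    bgr = batch[..., ::-1].astype(np.float32)
    return bgr - BGR_MEANS


rgb = np.full((256, 256, 3), [255, 128, 0], dtype=np.uint8)  # dummy image
batch = center_crop(rgb, 224)[np.newaxis]  # add a batch axis
out = caffe_style_preprocess(batch)
print(out.shape)
```

Other model families typically expect different scaling, which is why calling the model's own preprocess() is safer than hard-coding a transform like this one.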
Model Introspection
TensorNets provides several useful methods for inspecting the internal structure and state of a model. These are attached to the model tensor itself.
model.middles(): Returns a list of tf.Tensor endpoints from representative intermediate layers.
model.outputs(): Returns a list of all tf.Tensor endpoints from every layer.
model.weights(): Returns a list of all tf.Tensor weight and bias variables.
model.summary(): Prints a summary of the model, including the number of layers, weights, and total parameters.
Here's how you can use them:
with tf.Session() as sess:
    img_preprocessed = model.preprocess(img)
    sess.run(model.pretrained())
    # Get the feature maps from intermediate layers
    middles = sess.run(model.middles(), {inputs: img_preprocessed})
    # Get outputs from all layers
    outputs = sess.run(model.outputs(), {inputs: img_preprocessed})

# Print the shapes of the intermediate feature maps
model.print_middles()
# Print a summary of the model architecture
model.summary()
Example output from print_middles() for ResNet50:
Scope: resnet50
conv2/block1/out:0 (?, 56, 56, 256)
conv2/block2/out:0 (?, 56, 56, 256)
...
conv4/block1/out:0 (?, 14, 14, 1024)
...
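The shapes in this listing follow the usual ResNet-50 layout: each stage halves the spatial size and doubles the channel count. The sketch below reproduces that arithmetic for a 224x224 input; resnet50_stage_shapes is a hypothetical helper written for this example, not a TensorNets function.

```python
def resnet50_stage_shapes(input_size=224):
    """Return {stage: (height, width, channels)} for stages conv2..conv5."""
    shapes = {}
    for stage in range(2, 6):
        stride = 2 ** stage          # conv2 -> 4, conv3 -> 8, ...
        channels = 64 * 2 ** stage   # conv2 -> 256, conv3 -> 512, ...
        size = input_size // stride
        shapes["conv%d" % stage] = (size, size, channels)
    return shapes


for name, shape in resnet50_stage_shapes().items():
    print(name, shape)
```

The leading ? in the printed shapes is the batch dimension, which is left unspecified (None) in the placeholder.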
Example output from summary() for ResNet50:
Scope: resnet50
Total layers: 54
Total weights: 320
Total parameters: 25,636,712
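A parameter total of this kind is normally the sum of element counts over all weight tensors. The sketch below shows the arithmetic on a few illustrative shapes; the shapes are made up for the example rather than read from the model, and it is an assumption that summary() counts parameters this way.

```python
from math import prod  # Python 3.8+

# Illustrative weight shapes only; not extracted from a TensorNets model.
shapes = [
    (7, 7, 3, 64),  # conv kernel: height, width, in_channels, out_channels
    (64,),          # conv bias
    (2048, 1000),   # dense kernel
    (1000,),        # dense bias
]

# Each tensor contributes the product of its dimensions.
total = sum(prod(shape) for shape in shapes)
print("Total weights: %d" % len(shapes))
print("Total parameters: {:,}".format(total))
```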
Saving and Loading Custom Weights
If you fine-tune a model or train it from scratch, you can easily save and load the weights using the save() and load() methods.
# After training
with tf.Session() as sess:
    model.init()  # Initialize variables if training from scratch
    # ... your training code ...
    model.save('my_resnet50_weights.npz')

# For deployment or further training
with tf.Session() as sess:
    model.load('my_resnet50_weights.npz')
    # ... your deployment code ...
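The .npz extension suggests a standard NumPy archive of named arrays. Assuming that is the case, the round trip below shows how such a file can be written and inspected with NumPy alone. The key names and array shapes are hypothetical; the layout TensorNets actually writes is not described in this guide.

```python
import os
import tempfile

import numpy as np

# Hypothetical weight names and shapes, for illustration only.
weights = {
    "conv1_kernel": np.ones((7, 7, 3, 64), dtype=np.float32),
    "conv1_bias": np.zeros(64, dtype=np.float32),
}

path = os.path.join(tempfile.mkdtemp(), "toy_weights.npz")
np.savez(path, **weights)  # one named array per keyword argument

# np.load returns a lazy archive; copy the arrays out before it closes.
with np.load(path) as archive:
    restored = {name: archive[name] for name in archive.files}

for name in sorted(restored):
    print(name, restored[name].shape)
```

Inspecting a saved file this way can be a quick check that fine-tuned weights were written before loading them into a new session.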