Quick Start: Image Classification

This guide provides a minimal, step-by-step example of how to use TensorNets for a simple image classification task. We will use the ResNet50 model to classify an image.

Step 1: Import Libraries

First, import TensorFlow and TensorNets. If you are using TensorFlow 2.x, you will need to import the compatibility module and disable v2 behavior, as TensorNets is built on the TensorFlow 1.x computational graph API.

import tensorflow as tf
# For TensorFlow 2.x, replace the import above with these two lines
# to enable the 1.x graph API before building any model:
# import tensorflow.compat.v1 as tf
# tf.disable_v2_behavior()
import tensornets as nets
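If the same script has to run under both TensorFlow 1.x and 2.x, you can make this choice at import time. The sketch below assumes that checking the major version in `tf.__version__` is enough. The helper names `major_version` and `load_tf1_api` are hypothetical and are not part of TensorNets.

```python
import importlib


def major_version(version: str) -> int:
    """Return the major component of a version string such as '2.15.0'."""
    return int(version.split(".")[0])


def load_tf1_api():
    """Hypothetical helper: return a module exposing the TF 1.x graph API.

    On TensorFlow 2.x this returns tensorflow.compat.v1 with v2 behavior
    disabled. On 1.x it returns the tensorflow module unchanged.
    """
    tf = importlib.import_module("tensorflow")
    if major_version(tf.__version__) >= 2:
        tf = importlib.import_module("tensorflow.compat.v1")
        tf.disable_v2_behavior()
    return tf
```

Call `tf = load_tf1_api()` before creating any placeholders or sessions, because `tf.placeholder` and `tf.Session` are only available through the 1.x API.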

Step 2: Define the Model

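Next, define a placeholder for your input images and instantiate the model. In TensorNets, models are functions that take an input tensor and return an output tensor. ResNet50 expects inputs of shape [batch_size, 224, 224, 3], a batch of 224x224 images with 3 color channels in channels-last order. Setting the batch dimension to None lets the same graph accept any number of images.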

inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
model = nets.ResNet50(inputs)

# The model is a TensorFlow Tensor
assert isinstance(model, tf.Tensor)
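Because the first dimension of the placeholder is `None`, the graph accepts one image or many. The NumPy sketch below is independent of TensorNets, and its helper `to_batch` is hypothetical. It shows how individual 224x224x3 images are stacked into the `[batch_size, 224, 224, 3]` layout that the placeholder expects.

```python
import numpy as np


def to_batch(images):
    """Stack HxWx3 images into an NHWC float32 batch.

    Each image must already be 224x224 with 3 channels, matching the
    placeholder shape [None, 224, 224, 3].
    """
    batch = np.stack([np.asarray(im, dtype=np.float32) for im in images], axis=0)
    if batch.shape[1:] != (224, 224, 3):
        raise ValueError(f"expected images of shape (224, 224, 3), got {batch.shape[1:]}")
    return batch


# A single image becomes a batch of one; four images become a batch of four.
single = to_batch([np.zeros((224, 224, 3), dtype=np.uint8)])
several = to_batch([np.zeros((224, 224, 3)) for _ in range(4)])
print(single.shape, several.shape)  # (1, 224, 224, 3) (4, 224, 224, 3)
```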

Step 3: Load and Preprocess an Image

TensorNets provides a utility function, nets.utils.load_img, that loads an image file and prepares it for the network. We'll use a sample image, cat.png. The image is first resized according to target_size=256 and then cropped to crop_size=224, the 224x224 input size that ResNet50 expects.

# Assumes 'cat.png' is in your working directory
img = nets.utils.load_img('cat.png', target_size=256, crop_size=224)

# The image is now a NumPy array with shape (1, 224, 224, 3)
assert img.shape == (1, 224, 224, 3)
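The target_size/crop_size pair resembles the common ImageNet evaluation recipe of scaling the image and then cutting out its central 224x224 region. The NumPy sketch below illustrates that recipe. It rests on two assumptions that may not match `load_img` exactly: the shorter side is scaled to 256 while keeping the aspect ratio, and the crop is taken from the center. It uses nearest-neighbor resizing only to stay dependency-free. `resize_shorter_side` and `center_crop` are hypothetical helpers.

```python
import numpy as np


def resize_shorter_side(img, target):
    """Nearest-neighbor resize so that min(height, width) == target."""
    h, w = img.shape[:2]
    scale = target / min(h, w)
    new_h, new_w = round(h * scale), round(w * scale)
    # Map each output row/column back to the nearest source row/column.
    rows = np.minimum((np.arange(new_h) / scale).astype(int), h - 1)
    cols = np.minimum((np.arange(new_w) / scale).astype(int), w - 1)
    return img[rows][:, cols]


def center_crop(img, size):
    """Cut a size x size window from the middle of an HxWxC image."""
    h, w = img.shape[:2]
    top, left = (h - size) // 2, (w - size) // 2
    return img[top:top + size, left:left + size]


# A 480x640 RGB image: shorter side -> 256 gives 256x341, then crop to 224x224.
photo = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
resized = resize_shorter_side(photo, 256)
cropped = center_crop(resized, 224)
batch = cropped[np.newaxis]  # add the batch axis -> (1, 224, 224, 3)
print(resized.shape, batch.shape)
```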

Step 4: Run Inference

Now, run the model within a TensorFlow session. This takes three steps:

  1. Apply the model-specific preprocessing to the image using model.preprocess().
  2. Load the pre-trained ImageNet weights into the model using model.pretrained().
  3. Run the session to get the predictions (preds).

with tf.Session() as sess:
    # Apply model-specific preprocessing
    img_preprocessed = model.preprocess(img)

    # Load pre-trained weights
    sess.run(model.pretrained())

    # Run inference
    preds = sess.run(model, {inputs: img_preprocessed})
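What model.preprocess() does depends on the network. As an illustration only, the sketch below implements the Caffe-style scheme used by the original Keras ResNet50 weights: flip RGB to BGR and subtract per-channel ImageNet means. It is an assumption that TensorNets applies exactly this transform to ResNet50, so in real code call model.preprocess() rather than reimplementing it. `caffe_preprocess` and `IMAGENET_BGR_MEAN` are hypothetical names.

```python
import numpy as np

# Per-channel ImageNet means in BGR order, as used by Caffe-style ResNet weights.
IMAGENET_BGR_MEAN = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def caffe_preprocess(batch):
    """Convert an NHWC RGB batch in [0, 255] to mean-subtracted BGR floats."""
    x = np.asarray(batch, dtype=np.float32)
    x = x[..., ::-1]  # reverse the channel axis: RGB -> BGR
    return x - IMAGENET_BGR_MEAN  # broadcasts over N, H and W


img = np.zeros((1, 224, 224, 3), dtype=np.uint8)
img[..., 0] = 255  # a pure red image
out = caffe_preprocess(img)
print(out[0, 0, 0])  # BGR order: the red value now sits in the last channel
```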

Step 5: Decode Predictions

Finally, use the nets.utils.decode_predictions utility to convert the raw model output into human-readable class names.

# Get the top 2 predictions
print(nets.utils.decode_predictions(preds, top=2)[0])

This will produce an output similar to the following, showing the most likely classes for the image:

[(u'n02124075', u'Egyptian_cat', 0.28067636), (u'n02127052', u'lynx', 0.16826575)]
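Decoding amounts to sorting each row of preds (one score per class) and looking up the labels of the highest entries. The sketch below shows that top-k step on a toy five-class output. The label list and the `decode_top_k` helper are hypothetical, whereas nets.utils.decode_predictions works with the 1,000 ImageNet classes and returns WordNet IDs alongside the class names.

```python
import numpy as np

# Hypothetical labels for a toy 5-class model.
LABELS = ["tabby", "Egyptian_cat", "lynx", "tiger_cat", "Persian_cat"]


def decode_top_k(preds, labels, top=2):
    """Return, for each row of preds, the top-k (label, score) pairs."""
    preds = np.asarray(preds)
    results = []
    for row in preds:
        # argsort is ascending, so take the last `top` indices and reverse them.
        best = np.argsort(row)[-top:][::-1]
        results.append([(labels[i], float(row[i])) for i in best])
    return results


preds = np.array([[0.10, 0.28, 0.17, 0.40, 0.05]])
print(decode_top_k(preds, LABELS, top=2)[0])
# [('tiger_cat', 0.4), ('Egyptian_cat', 0.28)]
```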

And that's it! You have successfully used a pre-trained ResNet50 model for image classification with just a few lines of code.