Quick Start: Image Classification
This guide provides a minimal, step-by-step example of using TensorNets for a simple image classification task. We will use the ResNet50 model to classify an image.
Step 1: Import Libraries
First, import TensorFlow and TensorNets. If you are using TensorFlow 2.x, you will need to import the compatibility module and disable v2 behavior, as TensorNets is built on the TensorFlow 1.x computational graph API.
import tensorflow as tf
import tensornets as nets
# For TensorFlow 2.x, replace the tensorflow import above with these two lines:
# import tensorflow.compat.v1 as tf
# tf.disable_v2_behavior()
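If the same script has to run under both TensorFlow 1.x and 2.x, the choice between the two import styles can be made from the installed version string. The helper below is a minimal sketch; `needs_v1_compat` is a hypothetical name introduced here for illustration and is not part of TensorFlow or TensorNets.

```python
def needs_v1_compat(version):
    """Return True when a TensorFlow version string is 2.x or later.

    Hypothetical helper: it only inspects the major version number.
    """
    major = int(version.split(".")[0])
    return major >= 2


# Example with literal version strings; in practice you would pass tf.__version__.
compat_for_2x = needs_v1_compat("2.4.1")
compat_for_1x = needs_v1_compat("1.15.0")
```

When the helper returns True, use the `tensorflow.compat.v1` import and call `tf.disable_v2_behavior()` as described above.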
Step 2: Define the Model
Next, define a placeholder for your input images and instantiate the model. In TensorNets, models are functions that take an input tensor and return an output tensor. The input shape for ResNet50 is [batch_size, 224, 224, 3].
inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
model = nets.ResNet50(inputs)
# The model is a TensorFlow Tensor
assert isinstance(model, tf.Tensor)
Step 3: Load and Preprocess an Image
TensorNets provides a utility function, nets.utils.load_img, to easily load and prepare an image. We'll use a sample image, cat.png. The preprocessing steps will resize and crop the image to the required 224x224 dimensions.
# Assumes 'cat.png' is in your working directory
img = nets.utils.load_img('cat.png', target_size=256, crop_size=224)
# The image is now a NumPy array with shape (1, 224, 224, 3)
assert img.shape == (1, 224, 224, 3)
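To make the resize-then-crop step concrete, here is a minimal NumPy sketch of a central crop. It is not the implementation of nets.utils.load_img; it assumes the image has already been resized to 256x256 and carries a leading batch dimension, and `center_crop` is a hypothetical helper name.

```python
import numpy as np


def center_crop(batch, crop_size):
    """Cut the central crop_size x crop_size window out of an NHWC batch."""
    _, height, width, _ = batch.shape
    top = (height - crop_size) // 2
    left = (width - crop_size) // 2
    return batch[:, top:top + crop_size, left:left + crop_size, :]


# Stand-in for a resized image: one 256x256 RGB image with a batch dimension.
resized = np.zeros((1, 256, 256, 3), dtype=np.float32)
cropped = center_crop(resized, 224)
```

The result has the same [batch_size, 224, 224, 3] layout that the placeholder in Step 2 expects.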
Step 4: Run Inference
Now you can run the model within a TensorFlow session. We'll perform three key steps:
- Apply the model-specific preprocessing to the image using model.preprocess().
- Load the pre-trained ImageNet weights into the model using model.pretrained().
- Run the session to get the predictions (preds).
with tf.Session() as sess:
    # Apply model-specific preprocessing
    img_preprocessed = model.preprocess(img)
    # Load pre-trained weights
    sess.run(model.pretrained())
    # Run inference
    preds = sess.run(model, {inputs: img_preprocessed})
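For intuition about what model-specific preprocessing can involve, the sketch below shows the Caffe-style convention often used with ResNet50 weights: flip the channels from RGB to BGR and subtract per-channel ImageNet means. Whether model.preprocess() does exactly this is an assumption here; the mean values and the `caffe_style_preprocess` name are illustrative only.

```python
import numpy as np

# Per-channel ImageNet means in BGR order (Caffe convention); assumed values.
BGR_MEANS = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def caffe_style_preprocess(batch):
    """Flip RGB to BGR and subtract per-channel means; returns a new array."""
    bgr = batch[..., ::-1].astype(np.float32)  # astype makes a copy
    return bgr - BGR_MEANS


# A uniform gray RGB image stands in for a loaded picture.
rgb = np.full((1, 224, 224, 3), 128.0, dtype=np.float32)
out = caffe_style_preprocess(rgb)
```

The input array is left unchanged, so the original pixel values remain available for display or debugging.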
Step 5: Decode Predictions
Finally, use the nets.utils.decode_predictions utility to convert the raw model output into human-readable class names.
# Get the top 2 predictions
print(nets.utils.decode_predictions(preds, top=2)[0])
This will produce an output similar to the following, showing the most likely classes for the image:
[(u'n02124075', u'Egyptian_cat', 0.28067636), (u'n02127052', u'lynx', 0.16826575)]
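Conceptually, decoding comes down to picking the highest-scoring class indices and looking up their names. The following NumPy sketch illustrates that idea; it is not the code behind nets.utils.decode_predictions, and the four-entry label list and scores are made up for the example (the real model outputs 1000 ImageNet classes).

```python
import numpy as np


def top_k(probabilities, labels, k=2):
    """Return the k (label, score) pairs with the highest scores, best first."""
    order = np.argsort(probabilities)[::-1][:k]
    return [(labels[i], float(probabilities[i])) for i in order]


# Hypothetical labels and scores, chosen only for illustration.
labels = ["tabby", "Egyptian_cat", "lynx", "tiger_cat"]
probs = np.array([0.10, 0.28, 0.17, 0.05])
result = top_k(probs, labels, k=2)
```

Here `result` lists Egyptian_cat first and lynx second, mirroring the ordering of the sample output above.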
And that's it! You have successfully used a pre-trained ResNet50 model for image classification with just a few lines of code.