Object Detection with TensorNets
TensorNets supports object detection by letting you couple detection heads, such as YOLO and FasterRCNN, with any of the image classification backbones available in the library.
This guide demonstrates how to use YOLOv2 with a Darknet19 backbone for object detection on the PASCAL VOC dataset.
Step 1: Import Libraries and Define Model
First, import the necessary libraries. Then, define an input placeholder and create the YOLOv2 model. The key idea is to pass the backbone network (nets.Darknet19) as an argument to the detection model.
import tensorflow as tf
# For TensorFlow 2.x, replace the import above with the 1.x compatibility API:
# import tensorflow.compat.v1 as tf
# tf.disable_v2_behavior()
import tensornets as nets
import numpy as np
import matplotlib.pyplot as plt
from tensornets.datasets import voc
# Input: a batch of 416x416 RGB images
inputs = tf.placeholder(tf.float32, [None, 416, 416, 3])
# Couple YOLOv2 with the Darknet19 backbone
model = nets.YOLOv2(inputs, nets.Darknet19)
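The call above hands the detector the backbone constructor (nets.Darknet19) rather than an already built network, so the detector calls it on its own input. The toy NumPy sketch below mimics that pattern; toy_backbone and toy_yolo_head are hypothetical stand-ins, not TensorNets code. It also shows why a 416×416 input suits YOLOv2: a backbone with a total stride of 32, as Darknet19 has (five 2× poolings), reduces it to a 13×13 grid, and with 5 anchors and the 20 VOC classes each cell carries 5 × (4 + 1 + 20) = 125 values.

```python
import numpy as np

STRIDE = 32  # total downsampling factor, as in Darknet19 (five 2x poolings)


def toy_backbone(x):
    """Hypothetical backbone: average-pool by STRIDE in each spatial dim.

    x has shape (batch, height, width, channels); height and width must be
    multiples of STRIDE.
    """
    n, h, w, c = x.shape
    x = x.reshape(n, h // STRIDE, STRIDE, w // STRIDE, STRIDE, c)
    return x.mean(axis=(2, 4))


def toy_yolo_head(x, stem_fn, num_anchors=5, num_classes=20):
    """Hypothetical detection head that receives the backbone as a callable."""
    features = stem_fn(x)  # the head calls the backbone it was given
    n, grid_h, grid_w, _ = features.shape
    # One prediction vector per anchor per grid cell:
    # 4 box offsets + 1 objectness score + class scores (zeros in this toy).
    return np.zeros((n, grid_h, grid_w, num_anchors * (5 + num_classes)))


inputs = np.zeros((1, 416, 416, 3), dtype=np.float32)
outputs = toy_yolo_head(inputs, toy_backbone)
print(outputs.shape)  # (1, 13, 13, 125)
```

Swapping the backbone in this toy means passing a different callable, which is the same idea as passing a different nets.* backbone to nets.YOLOv2.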
Step 2: Load Image and Run Inference
Load a sample image and run it through the model. As with the classification models, you load pre-trained weights with model.pretrained() and apply model-specific preprocessing with model.preprocess().
# Assumes 'cat.png' is in your working directory
img = nets.utils.load_img('cat.png')
with tf.Session() as sess:
    # Load weights pre-trained on PASCAL VOC
    sess.run(model.pretrained())
    # Preprocess the image and run inference
    preds = sess.run(model, {inputs: model.preprocess(img)})
    # Get bounding boxes from predictions
    boxes = model.get_boxes(preds, img.shape[1:3])
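The second argument to get_boxes, img.shape[1:3], is the (height, width) of the original image: nets.utils.load_img returns a batch of shape (1, height, width, 3), which is why the plotting code later indexes img[0]. Boxes must be expressed in that original size to line up with the image. The sketch below shows that kind of rescaling for a box given in 416×416 network-input pixels; rescale_box is a hypothetical helper, and TensorNets' internal conversion may work differently (YOLO heads, for example, typically predict normalized coordinates).

```python
import numpy as np


def rescale_box(box, net_size, source_size):
    """Hypothetical helper: map (x1, y1, x2, y2, score) from network-input
    pixels to original-image pixels. Sizes are (height, width) pairs."""
    x1, y1, x2, y2, score = box
    net_h, net_w = net_size
    src_h, src_w = source_size
    # Normalize by the network input size, then scale to the source size.
    return (x1 / net_w * src_w, y1 / net_h * src_h,
            x2 / net_w * src_w, y2 / net_h * src_h, score)


img = np.zeros((1, 300, 500, 3), dtype=np.float32)  # stand-in for load_img output
source_size = img.shape[1:3]                         # (300, 500)
print(rescale_box((104, 208, 312, 416, 0.9), (416, 416), source_size))
# (125.0, 150.0, 375.0, 300.0, 0.9)
```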
Step 3: Interpret and Visualize Results
The model.get_boxes() method returns the detections grouped by class: element i holds the bounding boxes found for class i, and each box is (x1, y1, x2, y2, score), where (x1, y1) is the top-left corner and (x2, y2) the bottom-right corner.
We can then use this information to draw the boxes on the original image.
# The class index for 'cat' in the PASCAL VOC dataset is 7
cat_class_index = 7
cat_boxes = boxes[cat_class_index]
# Use len() instead of truthiness: the per-class boxes may be a NumPy array,
# and `if array:` raises an error when the array holds more than one value
if len(cat_boxes) > 0:
    # Print the first detected cat box
    print("%s: %s" % (voc.classnames[cat_class_index], cat_boxes[0]))
    # Visualize the bounding box
    box = cat_boxes[0]
    plt.imshow(img[0].astype(np.uint8))
    plt.gca().add_patch(plt.Rectangle(
        (box[0], box[1]), box[2] - box[0], box[3] - box[1],
        fill=False, edgecolor='r', linewidth=2))
    plt.show()
else:
    print("No cat detected.")
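The example above inspects only the first cat box. Because get_boxes returns one entry per class, you can also walk every class and keep all boxes above a score cutoff. In the sketch below, collect_detections and the 0.5 threshold are hypothetical choices, and boxes is hand-made data in the (x1, y1, x2, y2, score) format described above. It assumes each class's boxes can be viewed as an (N, 5) array; np.asarray handles both arrays and lists of 5-tuples. Each result also carries the width and height that plt.Rectangle expects.

```python
import numpy as np


def collect_detections(boxes, classnames, threshold=0.5):
    """Hypothetical helper: flatten per-class boxes into a list of
    (classname, x, y, width, height, score) tuples, keeping only boxes
    whose score is at least `threshold`."""
    detections = []
    for class_index, class_boxes in enumerate(boxes):
        # reshape(-1, 5) also turns an empty list into a (0, 5) array.
        for x1, y1, x2, y2, score in np.asarray(class_boxes).reshape(-1, 5):
            if score >= threshold:
                detections.append((classnames[class_index],
                                   float(x1), float(y1),
                                   float(x2 - x1), float(y2 - y1),
                                   float(score)))
    return detections


# Hand-made example data: three classes, one of them empty.
classnames = ['bird', 'boat', 'cat']
boxes = [
    np.zeros((0, 5), dtype=np.float32),                      # no birds
    np.array([[10, 20, 50, 60, 0.3]], dtype=np.float32),     # low-score boat
    np.array([[100, 40, 300, 240, 0.8]], dtype=np.float32),  # confident cat
]
for det in collect_detections(boxes, classnames):
    print(det)
```

With real model output you would pass voc.classnames, and each (x, y, width, height) can go straight into plt.Rectangle((x, y), width, height, fill=False).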
API Differences for Detection Models
The detection models differ slightly in their APIs, mainly in what you fetch and feed when calling sess.run():
- YOLOv3: preds = sess.run(model.preds, {inputs: img})
- YOLOv2: preds = sess.run(model, {inputs: img})
- FasterRCNN: preds = sess.run(model, {inputs: img, model.scales: scale})
For FasterRCNN, you must also feed the image scale through the model.scales placeholder. See the README and examples for details on how to obtain it.
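The three call patterns above can be gathered into one helper that picks the right fetch and feed dictionary for each model. The sketch below is hypothetical (run_args is not part of TensorNets) and uses plain stand-in objects instead of real tensors so that it runs on its own; with real models you would pass its results to sess.run(fetches, feed_dict). The scale value 1.5 is an arbitrary placeholder, not a meaningful FasterRCNN scale.

```python
from types import SimpleNamespace


def run_args(kind, model, inputs, img, scale=None):
    """Hypothetical helper: return (fetches, feed_dict) for sess.run,
    following the per-model differences listed above."""
    if kind == 'YOLOv3':
        return model.preds, {inputs: img}   # YOLOv3: fetch model.preds
    if kind == 'YOLOv2':
        return model, {inputs: img}         # YOLOv2: fetch the model itself
    if kind == 'FasterRCNN':
        if scale is None:
            raise ValueError("FasterRCNN also needs the image scale")
        return model, {inputs: img, model.scales: scale}
    raise ValueError("unknown model kind: %s" % kind)


# Stand-ins for real tensors and models (strings serve as dict keys).
inputs = 'inputs'
yolo3 = SimpleNamespace(preds='yolo3_preds')
rcnn = SimpleNamespace(scales='rcnn_scales')
fetches, feed = run_args('FasterRCNN', rcnn, inputs, img='pixels', scale=1.5)
print(fetches is rcnn, feed)
```

Raising an error when FasterRCNN gets no scale turns a missing feed into an immediate, readable failure instead of a placeholder error deep inside the session run.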