Object Detection with TensorNets

TensorNets provides a flexible interface for object detection: detection heads such as YOLO and FasterRCNN can be coupled with any of the image classification backbones available in the library.

This guide demonstrates how to use YOLOv2 with a Darknet19 backbone for object detection on the PASCAL VOC dataset.

Step 1: Import Libraries and Define Model

First, import the necessary libraries. Then, define an input placeholder and create the YOLOv2 model. The key idea is to pass the backbone network (nets.Darknet19) as an argument to the detection model.

import tensorflow as tf
import tensornets as nets
import numpy as np
import matplotlib.pyplot as plt
from tensornets.datasets import voc

# For TensorFlow 2.x, enable 1.x compatibility
# import tensorflow.compat.v1 as tf
# tf.disable_v2_behavior()

inputs = tf.placeholder(tf.float32, [None, 416, 416, 3])
# Couple YOLOv2 with the Darknet19 backbone
model = nets.YOLOv2(inputs, nets.Darknet19)
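
The value returned by nets.YOLOv2 is the detection output tensor with helper methods such as pretrained(), preprocess(), and get_boxes() attached (all three are used in the steps below). As an optional sanity check of the graph and of the VOC class list, you can print them; nothing here is required for inference:

# Optional sanity checks before opening a session.
print(inputs)          # the 416x416x3 input placeholder
print(model)           # the symbolic detection output produced by the YOLOv2 head
print(voc.classnames)  # the PASCAL VOC class names, indexed by class id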

Step 2: Load Image and Run Inference

Load a sample image and run it through the model. Similar to classification models, you load pre-trained weights and apply model-specific preprocessing.

# Assumes 'cat.png' is in your working directory
img = nets.utils.load_img('cat.png')

with tf.Session() as sess:
    # Load pre-trained weights for VOC dataset
    sess.run(model.pretrained())

    # Preprocess the image and run inference
    preds = sess.run(model, {inputs: model.preprocess(img)})

    # Decode predictions into per-class bounding boxes, scaled to the original image size
    boxes = model.get_boxes(preds, img.shape[1:3])
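
Once the session block finishes, preds and boxes are plain NumPy/Python objects, so everything below runs without an active session. A quick, optional check (the VOC-trained model yields one boxes entry per class):

# Quick check of what came back from the session.
print(len(boxes))    # one entry per VOC class
print(preds.shape)   # raw network output for the single input image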

Step 3: Interpret and Visualize Results

The model.get_boxes() method returns a list indexed by class: each element holds the bounding boxes detected for that class, and each box is (x1, y1, x2, y2, score).
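
Before plotting, it can be useful to list everything the model found. The loop below is a small sketch built on the structure just described; the 0.5 score threshold is an arbitrary choice, not a library default:

# Print every detection whose confidence exceeds the (arbitrary) threshold.
score_threshold = 0.5
for class_index, class_boxes in enumerate(boxes):
    for box in class_boxes:
        if box[4] >= score_threshold:
            print("%s: score=%.2f, box=(%.1f, %.1f, %.1f, %.1f)"
                  % (voc.classnames[class_index], box[4],
                     box[0], box[1], box[2], box[3]))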

We can then use this information to draw the boxes on the original image.

# The class index for 'cat' in the PASCAL VOC dataset is 7
cat_class_index = 7

# Print the first detected cat box
if len(boxes[cat_class_index]) > 0:  # works whether the entry is a list or a NumPy array
    print("%s: %s" % (voc.classnames[cat_class_index], boxes[cat_class_index][0]))

    # Visualize the bounding box
    box = boxes[cat_class_index][0]
    plt.imshow(img[0].astype(np.uint8))
    plt.gca().add_patch(plt.Rectangle(
        (box[0], box[1]), box[2] - box[0], box[3] - box[1],
        fill=False, edgecolor='r', linewidth=2))
    plt.show()
else:
    print("No cat detected.")

API Differences for Detection Models

Note that there are slight differences in the APIs for various detection models, primarily in how you get the final predictions from the session run.

  • YOLOv3: preds = sess.run(model.preds, {inputs: img})
  • YOLOv2: preds = sess.run(model, {inputs: img})
  • FasterRCNN: preds = sess.run(model, {inputs: img, model.scales: scale})

For FasterRCNN, you also need to feed the image scale (model.scales), since its preprocessing rescales the input image and the network uses that scale factor when generating proposals. See the README and examples for more details.