Example: Multi-GPU Usage

TensorNets models can be placed on multiple GPUs for parallel inference or training. This is done by constructing each model inside TensorFlow's tf.device context manager.

This example demonstrates how to place two different models on two separate GPUs for parallel inference.
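Before assigning models to devices, it helps to confirm that TensorFlow can actually see two GPUs. A minimal check, using device_lib (an internal but widely used TensorFlow module; this is a sketch, not part of the TensorNets API):

from tensorflow.python.client import device_lib

# List the devices visible to TensorFlow and keep only the GPUs
gpus = [d.name for d in device_lib.list_local_devices() if d.device_type == 'GPU']
print(gpus)  # expect something like ['/device:GPU:0', '/device:GPU:1']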

Step 1: Define Models on Different Devices

Use tf.device() to assign each model to a specific GPU. In this example, ResNeXt50 is placed on /gpu:0 and DenseNet201 on /gpu:1.

import tensorflow as tf
import tensornets as nets

# For TensorFlow 2.x, enable 1.x compatibility
# import tensorflow.compat.v1 as tf
# tf.disable_v2_behavior()

inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
models = []

# Place the first model on GPU 0
with tf.device('/gpu:0'):
    models.append(nets.ResNeXt50(inputs))

# Place the second model on GPU 1
with tf.device('/gpu:1'):
    models.append(nets.DenseNet201(inputs))
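
As a quick sanity check, each model output is a graph tensor whose .device attribute records the device context it was built under (a standard TF1 tensor attribute). A small sketch:

# Print the device each model was placed on; with the placements above,
# this should show GPU:0 for ResNeXt50 and GPU:1 for DenseNet201.
for model in models:
    print(model.name, '->', model.device)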

Step 2: Prepare Input and Preprocess

Load and preprocess the image. Since we are running multiple models, we can use nets.preprocess, which accepts a list of models and returns a list of preprocessed images, one per model.

img = nets.utils.load_img('cat.png', target_size=256, crop_size=224)

# Preprocess the image for all models at once
# Note: ResNeXt and DenseNet use the same preprocessing (fb_preprocess)
preprocessed_imgs = nets.preprocess(models, img)

Note: If the models had different preprocessing requirements, nets.preprocess would return a list of differently processed NumPy arrays. Since these models share the same preprocessing, the returned list contains identical arrays, and for inference you only need to feed one of them.
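If you did mix models with different preprocessing, a single shared placeholder would no longer work, because each model would need its own input array. A hypothetical sketch of how that could be structured, giving each model its own placeholder (MobileNet100 is used here only as an example of a model whose preprocessing differs from fb_preprocess):

# Hypothetical: separate placeholders so each model can be fed
# its own preprocessed array.
inputs0 = tf.placeholder(tf.float32, [None, 224, 224, 3])
inputs1 = tf.placeholder(tf.float32, [None, 224, 224, 3])

with tf.device('/gpu:0'):
    model0 = nets.ResNeXt50(inputs0)

with tf.device('/gpu:1'):
    model1 = nets.MobileNet100(inputs1)

# At inference time, feed each model's preprocessed array to its own
# placeholder:
#   sess.run([model0, model1], {inputs0: imgs[0], inputs1: imgs[1]})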

Step 3: Run Parallel Inference

Within a tf.Session, load the pre-trained weights for all models and then run inference. Because the two model subgraphs are independent, TensorFlow executes them on their assigned GPUs in parallel within a single sess.run call.

with tf.Session() as sess:
    # Load pre-trained weights for all models
    nets.pretrained(models)

    # Run inference on both models. TensorFlow handles the parallel execution.
    # We feed the same preprocessed image to the input placeholder.
    # `preprocessed_imgs[0]` is used, but any from the list would work here.
    preds = sess.run(models, {inputs: preprocessed_imgs[0]})

    # Decode and print the predictions from each model
    for pred in preds:
        print(nets.utils.decode_predictions(pred, top=2)[0])
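
Note that the hard-coded device placements above will raise an error on a machine with fewer than two GPUs. A common mitigation, shown here as a sketch using the standard TF1 session option, is to let TensorFlow fall back to an available device:

# allow_soft_placement lets TensorFlow fall back to another device
# (e.g. /gpu:0 or the CPU) when a requested one such as /gpu:1 is missing.
config = tf.ConfigProto(allow_soft_placement=True)
with tf.Session(config=config) as sess:
    nets.pretrained(models)
    preds = sess.run(models, {inputs: preprocessed_imgs[0]})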