Example: Multi-GPU Usage
TensorNets models can be deployed across multiple GPUs for parallel inference or training by using TensorFlow's `tf.device` context manager. This example demonstrates how to place two different models on two separate GPUs for parallel inference.
Step 1: Define Models on Different Devices
Use `tf.device()` to assign each model to a specific GPU. In this example, `ResNeXt50` is placed on `/gpu:0` and `DenseNet201` is placed on `/gpu:1`.
```python
import tensorflow as tf
import tensornets as nets

# For TensorFlow 2.x, enable 1.x compatibility:
# import tensorflow.compat.v1 as tf
# tf.disable_v2_behavior()

inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
models = []

# Place the first model on GPU 0
with tf.device('/gpu:0'):
    models.append(nets.ResNeXt50(inputs))

# Place the second model on GPU 1
with tf.device('/gpu:1'):
    models.append(nets.DenseNet201(inputs))
```
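The same pattern generalizes to any number of models and GPUs. As a minimal sketch of one common placement policy (the helper name and the round-robin strategy are illustrative, not part of TensorNets):

```python
def round_robin_devices(n_models, n_gpus):
    """Assign each model index a GPU device string in round-robin order."""
    return ['/gpu:%d' % (i % n_gpus) for i in range(n_models)]

# Example: three models spread across two GPUs
print(round_robin_devices(3, 2))  # → ['/gpu:0', '/gpu:1', '/gpu:0']
```

Each returned string can then be passed to `tf.device(...)` when building the corresponding model.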
Step 2: Prepare Input and Preprocess
Load and preprocess the image. Since we are running multiple models, we can use `nets.preprocess`, which returns a list of preprocessed images, one for each model in the `models` list.
```python
img = nets.utils.load_img('cat.png', target_size=256, crop_size=224)

# Preprocess the image for all models at once.
# Note: ResNeXt and DenseNet use the same preprocessing (fb_preprocess).
preprocessed_imgs = nets.preprocess(models, img)
```
Note: if the models had different preprocessing requirements, `nets.preprocess` would return a list of differently processed NumPy arrays. Since these models share the same preprocessing, the returned list contains identical arrays, and for inference you only need to feed one of them.
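To see why the arrays are identical here, the dispatch can be pictured as a lookup from model family to preprocessing function, where both families map to the same entry. The sketch below is illustrative, not TensorNets' actual implementation; the Torch-style normalization shown is the kind of transform commonly referred to as `fb_preprocess`:

```python
import numpy as np

def torch_style_preprocess(x):
    """Scale pixels to [0, 1], then normalize with ImageNet channel statistics."""
    x = x.astype(np.float32) / 255.0
    x -= np.array([0.485, 0.456, 0.406], dtype=np.float32)
    x /= np.array([0.229, 0.224, 0.225], dtype=np.float32)
    return x

# Hypothetical dispatch table: both model families map to the same function,
# so preprocessing once and feeding any one result to all models is safe.
PREPROCESS = {
    'resnext': torch_style_preprocess,
    'densenet': torch_style_preprocess,
}

img = np.full((1, 224, 224, 3), 128, dtype=np.uint8)
outs = [PREPROCESS[name](img) for name in ('resnext', 'densenet')]
print(np.array_equal(outs[0], outs[1]))  # → True
```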
Step 3: Run Parallel Inference
Within a `tf.Session`, load the pre-trained weights for all models and then run the inference. TensorFlow will automatically execute the model computations on their assigned GPUs in parallel.
```python
with tf.Session() as sess:
    # Load pre-trained weights for all models
    nets.pretrained(models)

    # Run inference on both models; TensorFlow handles the parallel execution.
    # We feed the same preprocessed image to the input placeholder.
    # `preprocessed_imgs[0]` is used, but any element of the list would work here.
    preds = sess.run(models, {inputs: preprocessed_imgs[0]})

    # Decode and print the top-2 predictions from each model
    for pred in preds:
        print(nets.utils.decode_predictions(pred, top=2)[0])
```
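`decode_predictions` returns, for each image, a list of (class, description, probability) tuples sorted by confidence. The top-k selection it performs can be sketched with plain NumPy (the labels here are placeholders, not the real ImageNet class list):

```python
import numpy as np

def top_k(probs, labels, k=2):
    """Return the k highest-probability (label, score) pairs, best first."""
    idx = np.argsort(probs)[::-1][:k]
    return [(labels[i], float(probs[i])) for i in idx]

probs = np.array([0.05, 0.70, 0.20, 0.05])
labels = ['goldfish', 'tabby', 'tiger_cat', 'lynx']
print(top_k(probs, labels))  # → [('tabby', 0.7), ('tiger_cat', 0.2)]
```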