In this article, we will present how to interface OpenCV CUDA with NVIDIA TensorRT via the C++ API for fast inference on NVIDIA GPUs.

Deep Learning has revolutionized the field of computer vision by enabling machines to learn and recognize patterns from images and videos. However, training Deep Learning models is a computationally intensive task that requires powerful hardware and large amounts of data. Once a model is trained, it can be used for inference, which involves applying the model to new data to make predictions. Inference is typically faster than training, but it can still be a bottleneck in real-time applications.

Fortunately, there are several tools available that can help accelerate Deep Learning inference on GPUs. One such tool is NVIDIA TensorRT, a high-performance inference engine designed to optimize and deploy Deep Learning models on NVIDIA GPUs. In this article, we will explore how to use TensorRT to accelerate Deep Learning inference when using OpenCV CUDA.

OpenCV CUDA is an extension of the popular OpenCV computer vision library that allows developers to take advantage of NVIDIA GPUs for faster image and video processing. With OpenCV CUDA, developers can use CUDA-enabled functions to perform image and video processing tasks, such as filtering, transformation, and object detection.

OpenCV has a DNN (Deep Neural Networks) module that can be used for inference, and recent versions of OpenCV DNN have added a CUDA backend. While DNN is very easy to integrate into an OpenCV pipeline, we will show that, with a little effort, interfacing with TensorRT is also quite simple and provides a significant performance boost.
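
For reference, running an ONNX model through OpenCV DNN with its CUDA backend only takes a few lines. This is a minimal sketch, assuming OpenCV was built with the CUDA backend enabled; model.onnx and image are placeholder names:

#include <opencv2/dnn.hpp>

// Load the network and select the CUDA backend/target.
cv::dnn::Net net = cv::dnn::readNetFromONNX("model.onnx");
net.setPreferableBackend(cv::dnn::DNN_BACKEND_CUDA);
net.setPreferableTarget(cv::dnn::DNN_TARGET_CUDA);   // or DNN_TARGET_CUDA_FP16

// image is an assumed cv::Mat; preprocessing details (mean/std) are omitted here.
cv::Mat blob = cv::dnn::blobFromImage(image, 1.0 / 255.0, cv::Size(224, 224));
net.setInput(blob);
cv::Mat prob = net.forward();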

To use TensorRT with OpenCV CUDA, we first need to prepare our Deep Learning model for inference with TensorRT. This involves converting the model from its original format (such as TensorFlow or PyTorch) into a format that TensorRT can consume. We will perform this conversion with the trtexec command-line tool.

Build a TensorRT engine

For this demonstration, we will use a ResNet50 model with pretrained weights from PyTorch torchvision. First, the model is exported to ONNX, one of the formats supported by trtexec (UFF, ONNX, Caffe), using a Python script like this one:

import torch
from torchvision.models import resnet50, ResNet50_Weights
import torch.onnx as onnx

batch_size = 32
onnx_path = "resnet.onnx"

inputs = torch.randn(batch_size, 3, 224, 224)
model = resnet50(weights=ResNet50_Weights.DEFAULT)
model.eval()

outputs = model(inputs)

onnx.export(model,                     # model being run
            inputs,                    # model input (or a tuple for multiple inputs)
            onnx_path,                 # where to save the model (can be a file or file-like object)
            export_params=True,        # store the trained parameter weights inside the model file
            opset_version=10,          # the ONNX version to export the model to
            do_constant_folding=True,  # whether to execute constant folding for optimization
            input_names = ['input'],   # the model's input names
            output_names = ['output'], # the model's output names
            dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                          'output' : {0 : 'batch_size'}})

Then, trtexec is used to build and optimize a TensorRT engine targeting fp32+fp16 processing.

trtexec --verbose --onnx=resnet.onnx --saveEngine=resnet16.engine \
    --minShapes="input:1x3x224x224" --optShapes="input:32x3x224x224" \
    --maxShapes="input:256x3x224x224" --fp16

Warning: This will produce an engine optimized for the GPU architecture on which trtexec was run. If you need to support other architectures, you will need to repeat the operation for each one.

The engine can then be profiled for other batch sizes:

trtexec --loadEngine=resnet16.engine --shapes="input:32x3x224x224"

Use the TensorRT C++ API with OpenCV

OpenCV CUDA is a module that allows most OpenCV operations to be run on the GPU using CUDA. Our goal is to pass a cv::cuda::GpuMat that already resides on the GPU directly to the TensorRT C++ API.

The TensorRT C++ API requires a few setup steps to load the engine and create the objects that will later be used to run inference.

First, the inference runtime is created, the engine is deserialized, and the execution context is created.

std::vector<char> buffer = load("resnet16.engine");              // read the serialized engine from disk
std::unique_ptr<IRuntime> runtime{ createInferRuntime(logger) }; // logger implements nvinfer1::ILogger

auto engine = std::unique_ptr<ICudaEngine>(runtime->deserializeCudaEngine(buffer.data(), buffer.size()));
auto context = std::unique_ptr<IExecutionContext>(engine->createExecutionContext());
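
The logger object and the load helper used above are not provided by TensorRT. A minimal sketch of both could look like this: Logger implements the nvinfer1::ILogger interface required by createInferRuntime, and load simply reads the serialized engine file into memory.

#include <NvInfer.h>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

class Logger : public nvinfer1::ILogger
{
    void log(Severity severity, const char* msg) noexcept override
    {
        // Only report warnings and errors; adjust the threshold to taste.
        if (severity <= Severity::kWARNING)
            std::cerr << msg << std::endl;
    }
};
Logger logger;

std::vector<char> load(const std::string& path)
{
    // Open at the end to get the file size, then read the whole file.
    std::ifstream file(path, std::ios::binary | std::ios::ate);
    std::vector<char> buffer(static_cast<size_t>(file.tellg()));
    file.seekg(0, std::ios::beg);
    file.read(buffer.data(), buffer.size());
    return buffer;
}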

Then, we need to allocate the processing buffers.

auto idims = engine->getTensorShape("input");   // [-1, 3, 224, 224] with a dynamic batch
auto odims = engine->getTensorShape("output");  // [-1, 1000] for ResNet50

// Fix the batch size to 1 for this example.
Dims4 inputDims = { 1, idims.d[1], idims.d[2], idims.d[3] };
Dims2 outputDims = { 1, odims.d[1] };
context->setInputShape("input", inputDims);

size_t inputLen = idims.d[1] * idims.d[2] * idims.d[3] * sizeof(float);
size_t outputLen = odims.d[1] * sizeof(float);

float *input_data, *output_data;
cudaMalloc(&input_data, inputLen);
cudaMalloc(&output_data, outputLen);

context->setTensorAddress("input", input_data);
context->setTensorAddress("output", output_data);

Finally, inference can be run repeatedly on different images.

cv::cuda::GpuMat input, output;   // input: 224x224 CV_32FC3, output: 1000x1 CV_32FC1

toNCHW(input, input_data, stream);      // copy the HWC GpuMat into the NCHW input buffer
context->enqueueV3(stream);             // run inference asynchronously on the CUDA stream
fromNCHW(output_data, output, stream);  // copy the NCHW output buffer back into a GpuMat
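
The stream passed to the helpers and to enqueueV3 is a plain cudaStream_t. It can be created with cudaStreamCreate or, to share the same stream with OpenCV CUDA operations, extracted from a cv::cuda::Stream. A minimal sketch, where cvStream is a name we introduce here:

#include <opencv2/core/cuda_stream_accessor.hpp>

cv::cuda::Stream cvStream;                                            // OpenCV-side stream
cudaStream_t stream = cv::cuda::StreamAccessor::getStream(cvStream);  // raw CUDA handle

// ... enqueue toNCHW / enqueueV3 / fromNCHW as above ...

cudaStreamSynchronize(stream);  // wait for the results before using them on the host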

However, OpenCV’s cv::cuda::GpuMat uses an HWC memory layout, while TensorRT engines created from ONNX expect the NCHW (batch N, channels C, height H, width W) format. Two helper functions (toNCHW/fromNCHW) are needed to convert a cv::cuda::GpuMat to/from a buffer accepted by TensorRT.

For simplicity, this example uses a batch size of 1. In the general case, we would handle an array of GpuMat (a sketch of a batched conversion is given after the helper functions below).

OpenCV’s cv::cuda::PtrStepSz<T> gives easy access to the pixels inside the kernel.

__global__
void toNCHWKernel(cv::cuda::PtrStepSz<float3> in, float* out)
{
    const int x = blockIdx.x * blockDim.x + threadIdx.x;
    const int y = blockIdx.y * blockDim.y + threadIdx.y;

    if ((x >= in.cols) || (y >= in.rows))
        return;

    float3 v = in(y, x);
    int step = in.cols * in.rows;
    int idx = y * in.cols + x;
    out[idx] = v.x;
    out[idx + step] = v.y;
    out[idx + 2 * step] = v.z;
}

And the corresponding host code to launch toNCHWKernel:

void toNCHW(const cv::cuda::GpuMat& input, float* output,
            cudaStream_t stream)
{
    const dim3 threads(32, 8);
    const dim3 grid(cv::cuda::device::divUp(input.cols, threads.x),
                    cv::cuda::device::divUp(input.rows, threads.y));

    toNCHWKernel<<<grid,threads,0,stream>>>(input, output);
}

For the ResNet model, the output tensor has a single channel. More general code would handle both the 1-channel and 3-channel cases.

__global__
void fromNCHWKernel(const float* in, cv::cuda::PtrStepSz<float> out)
{
    const int x = blockIdx.x * blockDim.x + threadIdx.x;
    const int y = blockIdx.y * blockDim.y + threadIdx.y;

    if ((x >= out.cols) || (y >= out.rows))
        return;

    int idx = y * out.cols + x;
    float v = in[idx];
    out(y, x) = v;
}

The code to call it:

void fromNCHW(const float* input, cv::cuda::GpuMat& output,
              cudaStream_t stream)
{
    const dim3 threads(32, 8);
    const dim3 grid(cv::cuda::device::divUp(output.cols, threads.x),
                    cv::cuda::device::divUp(output.rows, threads.y));

    fromNCHWKernel<<<grid,threads,0,stream>>>(input, output);
}
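
As noted earlier, supporting a batch is mostly a matter of offsetting into the NCHW buffer. A hypothetical toNCHWBatch wrapper (a name we introduce here) around the toNCHW helper above could look like this; the batch size must of course also be reflected in setInputShape and in the size passed to cudaMalloc:

#include <vector>

void toNCHWBatch(const std::vector<cv::cuda::GpuMat>& inputs, float* output,
                 cudaStream_t stream)
{
    for (size_t i = 0; i < inputs.size(); ++i)
    {
        // Each image occupies 3 * H * W floats in the NCHW buffer.
        const size_t offset = i * 3 * inputs[i].rows * inputs[i].cols;
        toNCHW(inputs[i], output + offset, stream);
    }
}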

Your input and output cv::cuda::GpuMat must have the sizes expected by the network's input and output: (224, 224, CV_32FC3) and (1000, 1, CV_32FC1) for this ResNet50 example.
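
For completeness, here is one possible way to build such an input on the GPU from an 8-bit BGR cv::Mat called frame (a name we introduce here), assuming the usual ImageNet preprocessing used by torchvision; the exact steps depend on how your model was trained:

#include <opencv2/core/cuda.hpp>
#include <opencv2/cudaimgproc.hpp>
#include <opencv2/cudawarping.hpp>
#include <opencv2/cudaarithm.hpp>

cv::cuda::GpuMat gpuFrame, rgb, resized, input;
gpuFrame.upload(frame);                                        // frame: 8-bit BGR cv::Mat
cv::cuda::cvtColor(gpuFrame, rgb, cv::COLOR_BGR2RGB);          // torchvision models expect RGB
cv::cuda::resize(rgb, resized, cv::Size(224, 224));
resized.convertTo(input, CV_32FC3, 1.0 / 255.0);               // to float in [0, 1]
cv::cuda::subtract(input, cv::Scalar(0.485, 0.456, 0.406), input);  // ImageNet mean
cv::cuda::divide(input, cv::Scalar(0.229, 0.224, 0.225), input);    // ImageNet std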

Conclusion

NVIDIA TensorRT is an SDK for high-performance deep learning inference. Unfortunately, OpenCV DNN has no roadmap to support it.

Here, we showed that interfacing OpenCV CUDA with TensorRT is quite straightforward.

Supporting batches of images and splitting the work across multiple GPUs is a matter of a few more lines of code. Don’t hesitate to contact me to discuss your project.