Skip to content

Object Detection with VinVL

See the source code, or open this notebook in Colab.

Import the mozuma modules, plus the PyTorch, plotting and image libraries used below

from mozuma.torch.options import TorchRunnerOptions
from mozuma.torch.runners import TorchInferenceRunner
from mozuma.callbacks.memory import (
    CollectBoundingBoxesInMemory,
)
from mozuma.helpers.files import list_files_in_dir
from mozuma.torch.datasets import LocalBinaryFilesDataset, ImageDataset
from mozuma.models.vinvl.pretrained import torch_vinvl_detector

import torch
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from PIL import Image
import os

# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

Load images

# Gather up to 50 ".jpg" fixture files and wrap them as a dataset of images.
base_path = os.path.join("../../tests", "fixtures", "objects")
jpg_files = list_files_in_dir(base_path, allowed_extensions=("jpg",))
file_names = jpg_files[:50]
binary_files = LocalBinaryFilesDataset(file_names)
dataset = ImageDataset(binary_files)

Run object detection with torch_vinvl_detector

# Load VinVL model (it might take a few minutes.)
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA (where torch.device("cuda") would fail at use time).
torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# score_threshold=0.5: presumably detections scoring below 0.5 are dropped
# by the detector -- confirm against torch_vinvl_detector's documentation.
vinvl = torch_vinvl_detector(device=torch_device, score_threshold=0.5)

# Callback that stores the predicted bounding boxes in memory; it is read
# back in the visualisation step (bb.indices / bb.bounding_boxes).
bb = CollectBoundingBoxesInMemory()

# Runner: feeds the dataset to the model in batches of 10 on the same device
# and forwards the predictions to the callbacks, with a tqdm progress bar.
runner = TorchInferenceRunner(
    model=vinvl,
    dataset=dataset,
    callbacks=[bb],
    options=TorchRunnerOptions(
        device=torch_device, data_loader_options={"batch_size": 10}, tqdm_enabled=True
    ),
)
runner.run()

Visualise the detected objects

First get labels and attributes

# For every processed image, draw its detected boxes and each box's score.
# bb.indices and bb.bounding_boxes are parallel: one prediction per image.
for img_path, prediction in zip(bb.indices, bb.bounding_boxes):
    print(f"Object detected for {img_path}")
    # Close the file handle once the image has been decoded into an RGB copy
    # (a bare Image.open() leaves it open for every image in the loop).
    with Image.open(img_path) as img_file:
        img = img_file.convert("RGB")
    fig, ax = plt.subplots()
    ax.imshow(img)
    boxes = prediction.bounding_boxes
    scores = prediction.scores
    # Boxes are (x_min, y_min, x_max, y_max); Rectangle wants the top-left
    # corner plus width and height.
    for (x_min, y_min, x_max, y_max), score in zip(boxes, scores):
        ax.add_patch(
            Rectangle(
                (x_min, y_min),
                x_max - x_min,
                y_max - y_min,
                fill=False,
                edgecolor="red",
                linewidth=2,
                alpha=0.5,
            )
        )
        ax.text(x_min, y_min, f"{score*100:.1f}%", color="blue", fontsize=12)
    # Display this figure now and release it, instead of keeping one open
    # figure per image (matplotlib warns once more than 20 are open).
    plt.show()
    plt.close(fig)