Text-to-image retrieval with CLIP
This is an example of a text-to-image retrieval engine based on OpenAI's CLIP model.
Import the mozuma modules for the task
import torch

from mozuma.torch.runners import TorchInferenceRunner
from mozuma.torch.options import TorchRunnerOptions
from mozuma.callbacks.memory import CollectFeaturesInMemory
from mozuma.torch.datasets import (
    ImageDataset,
    ListDataset,
    LocalBinaryFilesDataset,
)
from mozuma.helpers.files import list_files_in_dir
from mozuma.models.clip.text import CLIPTextModule
from mozuma.models.clip.image import CLIPImageModule
from mozuma.states import StateKey
from mozuma.stores import Store
Load CLIP Image Encoder
# Initialise the image encoder and load its pre-trained "clip" state from the store
image_encoder = CLIPImageModule(clip_model_name="ViT-B/32", device=torch.device("cuda"))
store = Store()
store.load(image_encoder, StateKey(image_encoder.state_type, "clip"))
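If you want to check which pre-trained states are available for this model before loading one, you can query the store. This is a minimal sketch assuming the store exposes a get_state_keys method taking a state type; check the mozuma Store API for the exact call:

# List the state keys the store provides for this model type
# (get_state_keys is an assumed method name)
print(store.get_state_keys(image_encoder.state_type))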
Extract CLIP image features of the Flickr30k dataset
Extracting the features might take a few minutes...
path_to_flickr30k_images = "/mnt/storage01/datasets/flickr30k/full/images"
file_names = list_files_in_dir(path_to_flickr30k_images, allowed_extensions=("jpg",))
dataset = ImageDataset(LocalBinaryFilesDataset(file_names))
# Callback that accumulates the computed features in memory
image_features = CollectFeaturesInMemory()
runner = TorchInferenceRunner(
    dataset=dataset,
    model=image_encoder,
    callbacks=[image_features],
    options=TorchRunnerOptions(
        data_loader_options={"batch_size": 128},
        device=image_encoder.device,
        tqdm_enabled=True,
    ),
)
runner.run()
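Once the run completes, the callback holds one feature vector per image (512-dimensional for ViT-B/32). As a quick sanity check you can inspect the collected array; this assumes the features attribute supports .shape (e.g. a NumPy array):

# Expected shape: (number_of_images, 512) for ViT-B/32
print(image_features.features.shape)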
Load CLIP Text Encoder
# The text encoder reuses the image encoder's CLIP model name so that both
# project into the same embedding space
text_encoder = CLIPTextModule(image_encoder.clip_model_name, device=torch.device("cpu"))
store.load(text_encoder, StateKey(text_encoder.state_type, "clip"))
Extract CLIP text features for the given text queries
text_queries = [
    "Workers look down from up above on a piece of equipment .",
    "Ballet dancers in a studio practice jumping with wonderful form .",
]
dataset = ListDataset(text_queries)
text_features = CollectFeaturesInMemory()
runner = TorchInferenceRunner(
    dataset=dataset,
    model=text_encoder,
    callbacks=[text_features],
    options=TorchRunnerOptions(
        data_loader_options={"batch_size": 1},
        device=text_encoder.device,
        tqdm_enabled=True,
    ),
)
runner.run()
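Both encoders project into the same CLIP embedding space, so the image and text features must have matching dimensionality for the similarity computation below. A quick check, again assuming both callbacks expose their results through a features attribute:

# Feature dimensions must match for the matrix product txt_feat @ img_feat.T
assert image_features.features.shape[1] == text_features.features.shape[1]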
Text-to-image retrieval engine
Pick the top 5 most similar images for each text query
# L2-normalise both feature sets so that dot products are cosine similarities
img_feat = torch.tensor(image_features.features).type(torch.float32)
img_feat /= img_feat.norm(dim=-1, keepdim=True)
txt_feat = torch.tensor(text_features.features).type(torch.float32)
txt_feat /= txt_feat.norm(dim=-1, keepdim=True)
# Scaled cosine similarities, softmaxed over all images for each query
similarity = (100.0 * txt_feat @ img_feat.T).softmax(dim=-1)
values, indices = similarity.topk(5)
Display the results
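A minimal way to print the matches for each query. This sketch assumes that CollectFeaturesInMemory also records the dataset indices (here, the image file paths) in an indices attribute aligned with features:

for query_idx, query in enumerate(text_queries):
    print(f"Query: {query}")
    for score, image_idx in zip(values[query_idx], indices[query_idx]):
        # image_features.indices is assumed to hold the file path of each encoded image
        print(f"  {image_features.indices[int(image_idx)]} (score: {score.item():.3f})")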