Face emotion recognition training
This notebook shows how to train a face emotion recognition model on top of ArcFace face features.
Import MTCNN and ArcFace modules from mozuma
from mozuma.models.arcface.pretrained import torch_arcface_insightface
from mozuma.models.mtcnn.pretrained import torch_mtcnn
from mozuma.torch.options import TorchRunnerOptions
from mozuma.torch.runners import TorchInferenceRunner
from mozuma.callbacks.memory import (
CollectBoundingBoxesInMemory,
CollectFeaturesInMemory,
)
from mozuma.torch.datasets import (
ImageBoundingBoxDataset,
ListDataset,
ListDatasetIndexed,
)
from torchvision.datasets import FER2013
import os
Enable logging inside the notebook
import logging
import sys
# Print INFO-and-above log records to stdout so they show up in the
# notebook cell output, prefixed with a timestamp and the level name.
_log_config = dict(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s : %(message)s",
)
logging.basicConfig(**_log_config)
Load a dataset containing images of faces annotated with emotion labels
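First, download the FER2013 dataset from Kaggle and unzip the `train.csv` and `test.csv` files into `~/torchvision-datasets/fer2013/`, which is where torchvision's `FER2013` dataset class looks for them.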
# torchvision's FER2013 reads the Kaggle CSV files from <root>/fer2013/.
# os.path.expanduser("~") honours $HOME on POSIX (same result as reading the
# variable directly) but, unlike os.environ["HOME"], does not raise KeyError
# on systems where HOME is unset (e.g. Windows, where it uses USERPROFILE).
path_to_fer2013 = os.path.join(os.path.expanduser("~"), "torchvision-datasets")
train_set = FER2013(root=path_to_fer2013, split="train")

# FER2013 integer class ids -> human-readable emotion names.
labels_dict = {
    0: "Angry",
    1: "Disgust",
    2: "Fear",
    3: "Happy",
    4: "Sad",
    5: "Surprise",
    6: "Neutral",
}

# Training images converted to RGB (presumably what the MTCNN / ArcFace models
# expect), each paired with its emotion name.
train_samples = [(img.convert("RGB"), labels_dict[label]) for img, label in train_set]
# train_images[i] and train_labels[i] refer to the same dataset sample i.
train_images, train_labels = zip(*train_samples)
Run face detection on the images with TorchMTCNNModule
import torch

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing when the model is moved to "cuda".
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch_mtcnn(device=torch_device)
# Callbacks: keep the detected bounding boxes in memory; they are used to
# crop faces for the ArcFace feature extraction step.
bb = CollectBoundingBoxesInMemory()
# Runner: run face detection over every training image.
runner = TorchInferenceRunner(
    model=model,
    dataset=ListDataset(train_images),
    callbacks=[bb],
    options=TorchRunnerOptions(
        data_loader_options={"batch_size": 32}, device=torch_device, tqdm_enabled=True
    ),
)
runner.run()
Extract face features with TorchArcFaceModule
arcface = torch_arcface_insightface(device=torch_device)

# Pair each training image (looked up by the indices the detection callback
# recorded) with the bounding boxes MTCNN produced for it.
dataset = ImageBoundingBoxDataset(
    image_dataset=ListDatasetIndexed(indices=bb.indices, objects=train_images),
    bounding_boxes=bb.bounding_boxes,
)

# Collects the ArcFace embeddings (and their indices) in memory.
ff = CollectFeaturesInMemory()

arcface_options = TorchRunnerOptions(
    device=torch_device,
    data_loader_options={"batch_size": 32},
    tqdm_enabled=True,
)
runner = TorchInferenceRunner(
    model=arcface,
    dataset=dataset,
    callbacks=[ff],
    options=arcface_options,
)
runner.run()
Train a linear classifier on top of the face features
Import the module for training
from mozuma.models.classification import LinearClassifierTorchModule
from mozuma.torch.datasets import TorchTrainingDataset
from mozuma.torch.runners import TorchTrainingRunner
from mozuma.torch.options import TorchTrainingOptions
from mozuma.labels.base import LabelSet
import torch
import torch.nn.functional as F
import torch.optim as optim
Define the training dataset
Define the labels
labels = list(labels_dict.values())
label_set = LabelSet(label_set_unique_id="emotion", label_list=labels)


def _image_index(feature_index):
    """Return the index of the source image for one entry of ``ff.indices``."""
    # The features come from an ImageBoundingBoxDataset, whose indices
    # presumably pair an image index with a box index, e.g. (image, box).
    # TODO confirm against mozuma. Plain indices are passed through unchanged.
    if isinstance(feature_index, (tuple, list)):
        return feature_index[0]
    return feature_index


# One label per extracted face, looked up through the image it was cropped
# from. A position in ``ff`` is not necessarily the image index: MTCNN may
# find zero or several faces in an image, which shifts every later position.
face_labels = [train_labels[_image_index(idx)] for idx in ff.indices]

# Split face positions 90% / 10% into train and valid sets. Explicit chunk
# sizes also work for tiny inputs, where torch.split(x, int(0.9 * n)) would
# get a split size of 0. The permutation is seeded so the split is
# reproducible across runs.
num_faces = len(face_labels)
num_train = int(num_faces * 0.9)
permutation = torch.randperm(num_faces, generator=torch.Generator().manual_seed(0))
train_part, valid_part = torch.split(permutation, [num_train, num_faces - num_train])
# Plain int lists index both numpy arrays and tensors unambiguously.
train_indices = train_part.tolist()
valid_indices = valid_part.tolist()

# define training set
train_dset = TorchTrainingDataset(
    dataset=ListDatasetIndexed(train_indices, ff.features[train_indices]),
    targets=label_set.get_label_ids([face_labels[pos] for pos in train_indices]),
)
# define valid set
valid_dset = TorchTrainingDataset(
    dataset=ListDatasetIndexed(valid_indices, ff.features[valid_indices]),
    targets=label_set.get_label_ids([face_labels[pos] for pos in valid_indices]),
)
Define the linear classifier
# The linear layer's input width equals the length of one ArcFace embedding;
# the first collected feature vector is used to read it.
first_embedding = ff.features[0]
in_features = len(first_embedding)
classifier = LinearClassifierTorchModule(label_set=label_set, in_features=in_features)
Define the trainer
from ignite.metrics import Precision, Recall, Loss, Accuracy

# Per-class precision / recall (average=False) so they can be combined below.
precision = Precision(average=False)
recall = Recall(average=False)
# Macro-averaged F1. The tiny epsilon avoids 0/0 = NaN for a class that is
# never predicted (precision and recall both 0), which would otherwise make
# the whole mean NaN.
F1 = (precision * recall * 2 / (precision + recall + 1e-20)).mean()
loss_fn = F.cross_entropy
trainer = TorchTrainingRunner(
    model=classifier,
    dataset=(train_dset, valid_dset),
    callbacks=[],
    options=TorchTrainingOptions(
        data_loader_options={"batch_size": 32},
        criterion=loss_fn,
        optimizer=optim.Adam(classifier.parameters(), lr=1e-2),
        metrics={
            "pre": precision,
            "recall": recall,
            "f1": F1,
            "acc": Accuracy(),
            "ce_loss": Loss(loss_fn),
        },
        validate_every=1,
        num_epoch=5,
        tqdm_enabled=True,
    ),
)
trainer.run()