Add a new model

In MoZuMa, a model is usually implemented as a class. The model implementation details primarily depend on the type of runner used. For instance, the TorchInferenceRunner expects to receive a subclass of TorchModel.

However, there are a few conventions to follow:

State management

A model with internal state (weights) should implement the ModelWithState protocol to be compatible with state stores.

mozuma.models.types.ModelWithState

Protocol of a model with internal state (weights)

It defines two functions set_state and get_state.

Attributes:

    state_type (StateType): Type of the model state, see states for more information.

state_type: StateType (property, read-only)

Type of the model state

See states for more information.

get_state(self) -> bytes

Get the model internal state

Returns:

    bytes: Serialised state as bytes

set_state(self, state: bytes) -> None

Set the model internal state

Parameters:

    state (bytes, required): Serialised state as bytes
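For illustration, here is a minimal sketch of a class satisfying ModelWithState. The class name, its NumPy-based weights and the pickle serialisation are hypothetical; they only show where state_type, get_state and set_state fit.

import pickle

import numpy as np

from mozuma.states import StateType


class RandomProjection:
    """Hypothetical model whose only state is a projection matrix."""

    def __init__(self):
        self.weights = np.zeros((512, 128))

    @property
    def state_type(self) -> StateType:
        # Identifies the state architecture so state stores can match compatible states
        return StateType(backend="numpy", architecture="random-projection-512x128")

    def get_state(self) -> bytes:
        # Serialise the internal weights to bytes
        return pickle.dumps(self.weights)

    def set_state(self, state: bytes) -> None:
        # Restore the internal weights from bytes
        self.weights = pickle.loads(state)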

Labels

When a model returns label scores, it must define a LabelSet. This should be defined by implementing the ModelWithLabels protocol.

mozuma.models.types.ModelWithLabels

Model that predicts scores for labels

It defines the get_labels function

get_labels(self) -> LabelSet

Getter for the model's LabelSet

Returns:

    LabelSet: The label set corresponding to returned label scores
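As a sketch, a classifier scoring two sentiment labels could implement get_labels as below. The import path and the LabelSet constructor arguments (label_set_unique_id, label_list) are assumptions; refer to the labels documentation for the exact signature.

from mozuma.labels.base import LabelSet  # import path assumed, see the labels documentation


class BinarySentimentClassifier:
    """Hypothetical model returning one score per sentiment label."""

    def get_labels(self) -> LabelSet:
        # The label order must match the order of the scores
        # returned by the model's predictions
        return LabelSet(
            label_set_unique_id="binary-sentiment",
            label_list=["negative", "positive"],
        )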

PyTorch models

PyTorch models should be a subclass of TorchModel.

Note

PyTorch models already implement the ModelWithState protocol by default.

mozuma.torch.modules.TorchModel

Base torch.nn.Module for PyTorch models.

A valid subclass of TorchModel must implement the following methods:

    forward
    to_predictions
    state_type (property)

And can optionally implement:

    get_dataset_transforms
    get_dataloader_collate_fn

Attributes:

    device (torch.device): Mandatory PyTorch device attribute to initialise the model.
    is_trainable (bool): Flag which indicates if the model is trainable. Default: True.

Examples:

This would define a simple PyTorch model consisting of a fully connected layer.

from typing import Callable, List

import torch
from torch import nn
from torchvision import transforms

from mozuma.predictions import BatchModelPrediction
from mozuma.states import StateType
from mozuma.torch.modules import TorchModel


class FC(TorchModel[torch.Tensor, torch.Tensor]):

    def __init__(self, device: torch.device = torch.device("cpu")):
        super().__init__(device=device)
        self.fc = nn.Linear(512, 512)

    def forward(
        self, batch: torch.Tensor
    ) -> torch.Tensor:
        return self.fc(batch)

    def to_predictions(
        self, forward_output: torch.Tensor
    ) -> BatchModelPrediction[torch.Tensor]:
        return BatchModelPrediction(features=forward_output)

    @property
    def state_type(self) -> StateType:
        return StateType(
            backend="pytorch",
            architecture="fc512x512",
        )

    def get_dataset_transforms(self) -> List[Callable]:
        return [transforms.ToTensor()]

Note

This is a generic class taking _BatchType and _ForwardOutputType type arguments. These correspond respectively to the type of data the forward method takes as argument and returns. In most cases, both are torch.Tensor.

Note

By default, MoZuMa models are trainable. Set the is_trainable parameter to False when creating a subclass if it shouldn't be trained.

state_type: StateType (property, read-only)

Identifier for the current model's state architecture

Important

This property must be implemented in subclasses

Note

A PyTorch model's state type should use the pytorch backend

Returns:

    StateType: State architecture object

forward(self, batch: _BatchType) -> _ForwardOutputType

Forward pass of the model

Important

This method must be implemented in subclasses

Applies the module on a batch and returns all potentially interesting data points (features, labels...)

Parameters:

    batch (_BatchType, required): The batch of data to process

Returns:

    _ForwardOutputType: A tensor or a sequence of tensors with relevant information (features, labels, bounding boxes...)


get_dataloader_collate_fn(self) -> Optional[Callable[[Any], Any]]

Optionally returns a collate function to be passed to the data loader

Note

This collate function will be wrapped in mozuma.torch.collate.TorchModelCollateFn. This means that the first argument batch will not contain the indices of the dataset but only the data element.

Returns:

    Callable[[Any], Any] | None: The collate function to be passed to TorchModelCollateFn.
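For example, a model consuming variable-length sequences could return a padding collate function. The class below is a sketch re-using the structure of the FC example above; only get_dataloader_collate_fn is the point of interest, the rest is boilerplate.

from typing import Any, Callable, List, Optional

import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence

from mozuma.predictions import BatchModelPrediction
from mozuma.states import StateType
from mozuma.torch.modules import TorchModel


class PaddedSequenceFC(TorchModel[torch.Tensor, torch.Tensor]):
    """Hypothetical model consuming variable-length sequences of 512-d vectors."""

    def __init__(self, device: torch.device = torch.device("cpu")):
        super().__init__(device=device)
        self.fc = nn.Linear(512, 512)

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        return self.fc(batch)

    def to_predictions(
        self, forward_output: torch.Tensor
    ) -> BatchModelPrediction[torch.Tensor]:
        return BatchModelPrediction(features=forward_output)

    @property
    def state_type(self) -> StateType:
        return StateType(backend="pytorch", architecture="padded-fc512x512")

    def get_dataloader_collate_fn(self) -> Optional[Callable[[Any], Any]]:
        # The batch received by this function only contains data elements;
        # TorchModelCollateFn re-attaches the dataset indices.
        def collate(batch: List[torch.Tensor]) -> torch.Tensor:
            # Pad variable-length sequences into a single batch tensor
            return pad_sequence(batch, batch_first=True)

        return collate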

get_dataset_transforms(self) -> List[Callable]

Transforms to apply to the input dataset.

Note

By default, this method returns an empty list (meaning no transformation) but in most cases, this will need to be overridden.

Returns:

    List[Callable]: A list of callables that will be used to transform the input data.
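For instance, an image model would typically resize, convert and normalise its inputs. The override below is a sketch meant to live inside a TorchModel subclass such as the FC example above; the resize dimensions and normalisation constants are illustrative, not MoZuMa defaults.

from typing import Callable, List

from torchvision import transforms


def get_dataset_transforms(self) -> List[Callable]:
    # Applied in order to each dataset item before batching
    return [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ]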

to_predictions(self, forward_output: _ForwardOutputType) -> mozuma.predictions.BatchModelPrediction[torch.Tensor]

Modifies the output of the forward pass to create the standard BatchModelPrediction object

Important

This method must be implemented in subclasses

Parameters:

    forward_output (_ForwardOutputType, required): The output of the forward pass

Returns:

    BatchModelPrediction[torch.Tensor]: Prediction object with the keys features, label_scores...
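As a sketch, a classification model would route its forward output to label_scores rather than features, so the scores can be matched against the model's LabelSet (see ModelWithLabels above). The method below is meant to live inside a TorchModel subclass; it is shown standalone for brevity.

import torch

from mozuma.predictions import BatchModelPrediction


def to_predictions(
    self, forward_output: torch.Tensor
) -> BatchModelPrediction[torch.Tensor]:
    # Interpret the forward output as one score per label
    return BatchModelPrediction(label_scores=forward_output)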