Add a new model
In MoZuMa, a model is usually implemented as a class.
The model implementation details primarily depend on the type of runner used. For instance, the TorchInferenceRunner expects to receive a subclass of TorchModel.
However, whatever the runner, there are a few conventions to follow:
- Model predictions should be returned as BatchModelPrediction objects; this is required for callbacks to work properly.
- If the model's state needs to be saved, the model should follow the ModelWithState protocol.
- If the model predicts labels, it should follow the ModelWithLabels protocol.
State management
A model with internal state (weights) should implement the ModelWithState protocol to be compatible with state stores.
mozuma.models.types.ModelWithState
Protocol of a model with internal state (weights)
It defines two methods, set_state and get_state, along with the state_type property.
Attributes:

| Name | Type | Description |
|---|---|---|
| state_type | StateType | Type of the model state, see states for more information. |
state_type: StateType (read-only property)
Type of the model state. See states for more information.
get_state(self) -> bytes
Get the model's internal state.
Returns:

| Type | Description |
|---|---|
| bytes | Serialised state as bytes |
set_state(self, state: bytes) -> None
Set the model's internal state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | bytes | Serialised state as bytes | required |
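To make the protocol concrete, here is a minimal sketch of a non-PyTorch model that satisfies it. The sketch does not import MoZuMa. ModelWithState and StateType are local stand-ins that mirror the members documented above, and the stand-in StateType keeps only the backend and architecture fields used later in the TorchModel example. LinearScorer, its dict of weights and the "pickle" backend name are all hypothetical. Pickle is only one convenient way to turn weights into bytes; any reversible serialisation would work.

```python
import pickle
from dataclasses import dataclass
from typing import Dict, Protocol, runtime_checkable


# Local stand-in for mozuma.states.StateType (only two fields kept)
@dataclass(frozen=True)
class StateType:
    backend: str
    architecture: str


# Local mirror of the documented ModelWithState protocol
@runtime_checkable
class ModelWithState(Protocol):
    @property
    def state_type(self) -> StateType: ...

    def get_state(self) -> bytes: ...

    def set_state(self, state: bytes) -> None: ...


class LinearScorer:
    """Hypothetical model whose weights are a plain dict of floats."""

    def __init__(self) -> None:
        self.weights: Dict[str, float] = {"bias": 0.0, "slope": 1.0}

    @property
    def state_type(self) -> StateType:
        # Hypothetical backend/architecture identifiers
        return StateType(backend="pickle", architecture="linear-scorer")

    def get_state(self) -> bytes:
        # Serialise the weights to bytes so that a state store can persist them
        return pickle.dumps(self.weights)

    def set_state(self, state: bytes) -> None:
        # Restore weights previously produced by get_state
        self.weights = pickle.loads(state)


# Copy the state of one model into a fresh instance
source = LinearScorer()
source.weights["slope"] = 2.5
target = LinearScorer()
target.set_state(source.get_state())
```

Because the stand-in protocol is decorated with runtime_checkable, isinstance(target, ModelWithState) checks that all three members are present.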
Labels
When a model returns label scores, it must define a LabelSet. It does so by implementing the ModelWithLabels protocol.
mozuma.models.types.ModelWithLabels
Model that predicts scores for labels
It defines the get_labels method.
get_labels(self) -> LabelSet
Getter for the model's LabelSet.
Returns:

| Type | Description |
|---|---|
| LabelSet | The label set corresponding to returned label scores |
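The point of get_labels is that each label score a model returns can be traced back to a label name. The sketch below illustrates this with a local LabelSet stand-in. Its field names (label_set_unique_id, label_list) are assumptions and may not match MoZuMa's own class. ToyClassifier, predict_scores and top_label are hypothetical. The sketch also assumes that score i refers to the i-th label of the set.

```python
from dataclasses import dataclass
from typing import List, Protocol, Sequence, runtime_checkable


# Local stand-in for mozuma's LabelSet (field names are assumptions)
@dataclass(frozen=True)
class LabelSet:
    label_set_unique_id: str
    label_list: List[str]


# Local mirror of the documented ModelWithLabels protocol
@runtime_checkable
class ModelWithLabels(Protocol):
    def get_labels(self) -> LabelSet: ...


class ToyClassifier:
    """Hypothetical model returning one score per label."""

    def get_labels(self) -> LabelSet:
        return LabelSet("toy-animals", ["cat", "dog", "bird"])

    def predict_scores(self, x: float) -> List[float]:
        # Position i in the score list refers to label_list[i]
        return [x, 1.0 - x, 0.1]


def top_label(model: ModelWithLabels, scores: Sequence[float]) -> str:
    # Map the index of the highest score back to its label name
    labels = model.get_labels().label_list
    best = max(range(len(scores)), key=lambda i: scores[i])
    return labels[best]


clf = ToyClassifier()
print(top_label(clf, clf.predict_scores(0.8)))  # cat
```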
PyTorch models
PyTorch models should be a subclass of TorchModel.
Note
PyTorch models already implement the ModelWithState protocol by default.
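This page does not document how TorchModel serialises its state. One common way for a torch.nn.Module to produce and consume such bytes, offered here as an assumption, is to save its state_dict into an in-memory buffer with torch.save and restore it with torch.load and load_state_dict. The get_state and set_state functions below are illustrative helpers, not TorchModel's actual methods.

```python
import io

import torch
from torch import nn


def get_state(module: nn.Module) -> bytes:
    # Serialise the module's parameters and buffers (its state_dict) to bytes
    buffer = io.BytesIO()
    torch.save(module.state_dict(), buffer)
    return buffer.getvalue()


def set_state(module: nn.Module, state: bytes) -> None:
    # Load a state_dict previously produced by get_state, onto the CPU
    state_dict = torch.load(io.BytesIO(state), map_location="cpu", weights_only=True)
    module.load_state_dict(state_dict)


# Two randomly initialised layers end up with identical weights
source = nn.Linear(4, 2)
target = nn.Linear(4, 2)
set_state(target, get_state(source))
```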
mozuma.torch.modules.TorchModel
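Base torch.nn.Module for PyTorch models.
A valid subclass of TorchModel must implement the following methods:
- forward
- to_predictions
- state_type (property)
And can optionally implement:
- get_dataset_transforms
- get_dataloader_collate_fn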
Attributes:

| Name | Type | Description |
|---|---|---|
| device | torch.device | Mandatory PyTorch device attribute used to initialise the model. |
| is_trainable | bool | Flag indicating whether the model is trainable. Defaults to True. |
Examples:
The following defines a simple PyTorch model consisting of a single fully connected layer.
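```python
from typing import Callable, List

import torch
from torch import nn
from torchvision import transforms

from mozuma.predictions import BatchModelPrediction
from mozuma.states import StateType
from mozuma.torch.modules import TorchModel


class FC(TorchModel[torch.Tensor, torch.Tensor]):
    def __init__(self, device: torch.device = torch.device("cpu")):
        super().__init__(device=device)
        # Single fully connected layer mapping 512 features to 512 features
        self.fc = nn.Linear(512, 512)

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        return self.fc(batch)

    def to_predictions(
        self, forward_output: torch.Tensor
    ) -> BatchModelPrediction[torch.Tensor]:
        # Expose the layer output as the prediction's features
        return BatchModelPrediction(features=forward_output)

    @property
    def state_type(self) -> StateType:
        # Identifies the weights format: PyTorch backend, 512x512 FC architecture
        return StateType(
            backend="pytorch",
            architecture="fc512x512",
        )

    def get_dataset_transforms(self) -> List[Callable]:
        # Convert each input element to a tensor before batching
        return [transforms.ToTensor()]
```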
Note
This is a generic class taking _BatchType and _ForwardOutputType type arguments. They correspond respectively to the type of data that forward takes as argument and the type it returns. In most cases, both are torch.Tensor.
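The role of the two type arguments can be seen without PyTorch. In the sketch below, GenericModel is a hypothetical stand-in that reproduces only TorchModel's generic signature. The hypothetical WordLengths model binds _BatchType to List[str] and _ForwardOutputType to List[int], so the annotations of forward follow directly from the subscripted base class.

```python
from typing import Generic, List, TypeVar

_BatchType = TypeVar("_BatchType")
_ForwardOutputType = TypeVar("_ForwardOutputType")


class GenericModel(Generic[_BatchType, _ForwardOutputType]):
    """Hypothetical stand-in for TorchModel's two type parameters."""

    def forward(self, batch: _BatchType) -> _ForwardOutputType:
        raise NotImplementedError


class WordLengths(GenericModel[List[str], List[int]]):
    # _BatchType = List[str] (input of forward)
    # _ForwardOutputType = List[int] (output of forward)
    def forward(self, batch: List[str]) -> List[int]:
        return [len(word) for word in batch]


print(WordLengths().forward(["new", "model"]))  # [3, 5]
```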
Note
By default, MoZuMa models are trainable. Set the is_trainable parameter to False when creating a subclass that should not be trained.
state_type: StateType (read-only property)
Identifier of the current model's state architecture.
Important
This property must be implemented in subclasses.
Note
PyTorch model architectures should use the pytorch backend.
Returns:

| Type | Description |
|---|---|
| StateType | State architecture object |
forward(self, batch: ~_BatchType) -> ~_ForwardOutputType
Forward pass of the model.
Important
This method must be implemented in subclasses.
Applies the module on a batch and returns all potentially relevant data (features, labels...).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch | _BatchType | The batch of data to process | required |
Returns:

| Type | Description |
|---|---|
| _ForwardOutputType | A tensor or a sequence of tensors with relevant information (features, labels, bounding boxes...) |
get_dataloader_collate_fn(self) -> Optional[Callable[[Any], Any]]
Optionally returns a collate function to be passed to the data loader.
Note
This collate function will be wrapped in mozuma.torch.collate.TorchModelCollateFn. This means that its first argument, batch, will not contain the dataset indices, only the data elements.
Returns:

| Type | Description |
|---|---|
| Optional[Callable[[Any], Any]] | The collate function to be passed to the data loader, or None |
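The note on TorchModelCollateFn means a model's collate function only ever sees data elements. The sketch below illustrates that contract under two assumptions: the dataset yields (index, data) pairs, and the wrapper sets the indices aside before calling the model's function. wrap_collate and pad_collate are hypothetical and do not reproduce TorchModelCollateFn's actual code.

```python
from typing import Any, Callable, List, Sequence, Tuple


def wrap_collate(
    collate_fn: Callable[[List[Any]], Any]
) -> Callable[[Sequence[Tuple[int, Any]]], Tuple[List[int], Any]]:
    """Hypothetical wrapper: split (index, data) pairs, collate only the data."""

    def wrapped(batch: Sequence[Tuple[int, Any]]) -> Tuple[List[int], Any]:
        indices = [index for index, _ in batch]
        data = [element for _, element in batch]
        # The model's collate function never sees the indices
        return indices, collate_fn(data)

    return wrapped


def pad_collate(sequences: List[List[int]]) -> List[List[int]]:
    # Example model collate function: right-pad sequences with zeros
    width = max(len(seq) for seq in sequences)
    return [seq + [0] * (width - len(seq)) for seq in sequences]


collate = wrap_collate(pad_collate)
print(collate([(7, [1, 2]), (9, [3])]))  # ([7, 9], [[1, 2], [3, 0]])
```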
get_dataset_transforms(self) -> List[Callable]
Transforms to apply to the input dataset.
Note
By default, this method returns an empty list (meaning no transformation), but in most cases it will need to be overridden.
Returns:

| Type | Description |
|---|---|
| List[Callable] | A list of callables that will be used to transform the input data. |
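This page does not state how the returned list is applied. The sketch below assumes the transforms are chained in list order, as torchvision's Compose does, so that an empty list leaves each element unchanged. apply_transforms and the string transforms are hypothetical examples for a text model.

```python
from typing import Any, Callable, List


def apply_transforms(transforms: List[Callable[[Any], Any]], element: Any) -> Any:
    # Apply each transform in list order; an empty list returns the input unchanged
    for transform in transforms:
        element = transform(element)
    return element


# Hypothetical transforms for a text model
text_transforms = [str.strip, str.lower]
print(apply_transforms(text_transforms, "  Hello MoZuMa "))  # hello mozuma
print(apply_transforms([], 42))  # 42
```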
to_predictions(self, forward_output: ~_ForwardOutputType) -> mozuma.predictions.BatchModelPrediction[torch.Tensor]
Modifies the output of the forward pass to create the standard BatchModelPrediction object.
Important
This method must be implemented in subclasses.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| forward_output | _ForwardOutputType | The output of the forward pass | required |

Returns:

| Type | Description |
|---|---|
| BatchModelPrediction[torch.Tensor] | Prediction object with the keys |
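When forward returns several tensors, to_predictions is where they get sorted into the prediction object. The sketch below assumes a forward pass that returns a (features, logits) tuple. It uses a local dataclass in place of mozuma.predictions.BatchModelPrediction. Only the features field is confirmed by the FC example above, and label_scores is an assumed field name. Converting logits with softmax is an arbitrary choice of this sketch; a model could equally return raw scores.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

import torch


# Local stand-in for BatchModelPrediction; `label_scores` is an assumed field
@dataclass
class BatchModelPrediction:
    features: Optional[torch.Tensor] = None
    label_scores: Optional[torch.Tensor] = None


def to_predictions(
    forward_output: Tuple[torch.Tensor, torch.Tensor]
) -> BatchModelPrediction:
    # forward returned (features, logits); convert logits to per-label probabilities
    features, logits = forward_output
    return BatchModelPrediction(features=features, label_scores=logits.softmax(dim=1))


# Zero logits give uniform scores across the 3 labels of each of the 2 items
preds = to_predictions((torch.zeros(2, 8), torch.zeros(2, 3)))
```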