Stores
Stores are used to load and save model states.
They define 3 methods: `get_state_keys`, `load` and `save`.
The `get_state_keys` method lists the pre-trained states available for a given state type.
It is usually called with the `state_type` of a `model`.
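```python
from mozuma.stores import Store

store = Store()  # The MoZuMa store (a GitHUBReleaseStore)
# `model` is assumed to be an already instantiated MoZuMa model, e.g. a ResNet

# List available model states in store for the model state type
state_keys = store.get_state_keys(model.state_type)
print(state_keys)
# Prints for a ResNet pre-trained on ImageNet:
# [
#     StateKey(
#         state_type=StateType(
#             backend='pytorch',
#             architecture='resnet18',
#             extra=('cls1000',)
#         ),
#         training_id='imagenet'
#     )
# ]
```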
See the states documentation for more information on how pre-trained states are identified.
Then, a state can be loaded into the `model`.
```python
# Getting the first state key
resnet_imagenet_state_key = state_keys[0]
# Loading it into the model
store.load(model, state_key=resnet_imagenet_state_key)
```
A model can be saved by specifying a `training_id`, which should uniquely identify the training activity that yielded this model's state.
See AbstractStateStore for more details on these methods.
MoZuMa pre-trained models
MoZuMa provides model weights for all defined models through the MoZuMa `Store`.
mozuma.stores.Store() -> GitHUBReleaseStore
Alternative stores
These stores can be used if you want to store model states locally or on S3 storage.
mozuma.stores.local.LocalStateStore (dataclass)
Local file-based store
Attributes:
| Name | Type | Description |
|---|---|---|
| folder | str | Path to the folder to save model's state |
mozuma.stores.s3.S3StateStore (dataclass)
State store on top of S3 object storage
Given the state keys, states are organised in folders with the following structure:
```
base_path/
├─ {backend}/
│  ├─ {architecture}.{extra1}.{extra2}.{training_id}.pt
├─ pytorch/  # e.g. for torch models
│  ├─ resnet18.cls1000.imagenet.pt
│  ├─ clip-image-rn50.clip.pt
```
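The layout above can be expressed as a small function mapping a state key to an S3 object key. This is a sketch, not the `S3StateStore` implementation: the `StateType`/`StateKey` dataclasses are hypothetical stand-ins mirroring the fields printed earlier (`backend`, `architecture`, `extra`, `training_id`), and `s3_object_key` is an illustrative helper name.

```python
from dataclasses import dataclass
from typing import Tuple


# Hypothetical stand-ins mirroring the fields of mozuma's StateType and StateKey
@dataclass(frozen=True)
class StateType:
    backend: str
    architecture: str
    extra: Tuple[str, ...] = ()


@dataclass(frozen=True)
class StateKey:
    state_type: StateType
    training_id: str


def s3_object_key(base_path: str, state_key: StateKey) -> str:
    """Builds base_path/{backend}/{architecture}.{extra1}.{extra2}.{training_id}.pt"""
    st = state_key.state_type
    # Dot-join the architecture, every extra and the training id
    file_name = ".".join([st.architecture, *st.extra, state_key.training_id]) + ".pt"
    # Drop a trailing slash and empty parts to avoid double separators
    parts = [p for p in (base_path.rstrip("/"), st.backend, file_name) if p]
    return "/".join(parts)


key = StateKey(StateType("pytorch", "resnet18", ("cls1000",)), "imagenet")
print(s3_object_key("base_path", key))  # base_path/pytorch/resnet18.cls1000.imagenet.pt
```

A state type without extras, such as the CLIP example, simply yields `{architecture}.{training_id}.pt`.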
Attributes:
| Name | Type | Description |
|---|---|---|
| bucket | str | Bucket to use to store states |
| session_kwargs | dict | Arguments passed to initialise |
| s3_endpoint_url | str | Endpoint URL, to connect to S3-compatible storage |
| base_path | str | The base path to store states |
mozuma.stores.github.GitHUBReleaseStore (dataclass)
Store implementation leveraging GitHub releases
Model weights are stored as assets in a release.
We recommend setting up GitHub authentication to use the `get_state_keys` and `save` methods.
These methods call the releases API, which is limited to 60 requests per hour for unauthenticated clients.
This can be done with either:
- Personal access token: the username and PAT need to be set in the environment variable `GH_API_BASIC_AUTH={username}:{personal_access_token}`
- GitHub token: used in GitHub Actions, it needs to be set in the `GH_TOKEN` environment variable.
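As a rough sketch of how these two variables can be turned into credentials for the GitHub REST API, the hypothetical helper below builds request headers: HTTP basic authentication from `GH_API_BASIC_AUTH`, otherwise a bearer token from `GH_TOKEN`. The helper name and the precedence between the two variables are assumptions, not MoZuMa's actual code.

```python
import base64
import os
from typing import Dict, Mapping, Optional


def github_auth_headers(env: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Hypothetical helper: builds GitHub API auth headers from environment variables."""
    env = os.environ if env is None else env
    # Expected format: "{username}:{personal_access_token}"
    basic = env.get("GH_API_BASIC_AUTH")
    if basic:
        encoded = base64.b64encode(basic.encode("utf-8")).decode("ascii")
        return {"Authorization": f"Basic {encoded}"}
    # e.g. the token made available to GitHub Actions workflows
    token = env.get("GH_TOKEN")
    if token:
        return {"Authorization": f"Bearer {token}"}
    # No credentials: requests are subject to the unauthenticated rate limit
    return {}
```

Passing `env` explicitly makes the helper easy to test; by default it reads `os.environ`.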
Model states are organised in releases with the following convention:
- Release names and tags are constructed as `{release_name_prefix}.{state_type.backend}.{state_type.architecture}`
- Asset names within a release are constructed as `{state_type.extra1}.{state_type.extra2}.{training_id}.state.gzip`
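Following this convention, the release tag and asset name for a given state key could be computed as below. This is a sketch: `release_tag`, `asset_name` and the dataclass stand-ins are hypothetical, and the prefix `"my-prefix"` is made up for illustration since the default value is not given here.

```python
from dataclasses import dataclass
from typing import Tuple


# Hypothetical stand-ins mirroring the fields of mozuma's StateType and StateKey
@dataclass(frozen=True)
class StateType:
    backend: str
    architecture: str
    extra: Tuple[str, ...] = ()


@dataclass(frozen=True)
class StateKey:
    state_type: StateType
    training_id: str


def release_tag(release_name_prefix: str, state_type: StateType) -> str:
    # {release_name_prefix}.{backend}.{architecture}
    return ".".join([release_name_prefix, state_type.backend, state_type.architecture])


def asset_name(state_key: StateKey) -> str:
    # {extra1}.{extra2}.{training_id}.state.gzip
    return ".".join([*state_key.state_type.extra, state_key.training_id, "state", "gzip"])


key = StateKey(StateType("pytorch", "resnet18", ("cls1000",)), "imagenet")
print(release_tag("my-prefix", key.state_type))  # my-prefix.pytorch.resnet18
print(asset_name(key))  # cls1000.imagenet.state.gzip
```

All states sharing a backend and architecture thus end up as assets of the same release.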
Attributes:
| Name | Type | Description |
|---|---|---|
| repository_owner | str | The owner of the GitHub repository |
| repository_name | str | The name of the repository to use as a store |
| branch_name | str | The branch used to create new releases holding model states. By default we recommend using an orphan branch named |
| release_name_prefix | str | The prefix to identify releases containing model weights. Defaults to |
Write your own store
A store should inherit `AbstractStateStore` and implement the `save`, `load` and `get_state_keys` methods.
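To illustrate this contract, here is a minimal in-memory store. It is a self-contained sketch: `StateType`, `StateKey`, `StoreInterface`, `InMemoryStateStore` and `DummyModel` are hypothetical stand-ins. A real store would subclass `AbstractStateStore` rather than the `abc.ABC` stand-in, and the sketch assumes models expose `state_type`, `get_state()` and `set_state()` (method names are an assumption here). "Compatible" state types are simplified to equal ones.

```python
import abc
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass(frozen=True)
class StateType:
    backend: str
    architecture: str
    extra: Tuple[str, ...] = ()


@dataclass(frozen=True)
class StateKey:
    state_type: StateType
    training_id: str


class StoreInterface(abc.ABC):
    """Stand-in for AbstractStateStore: the three methods a store must implement."""

    @abc.abstractmethod
    def save(self, model, training_id: str) -> StateKey: ...

    @abc.abstractmethod
    def load(self, model, state_key: StateKey) -> None: ...

    @abc.abstractmethod
    def get_state_keys(self, state_type: StateType) -> List[StateKey]: ...


class InMemoryStateStore(StoreInterface):
    def __init__(self) -> None:
        self._states: Dict[StateKey, bytes] = {}

    def save(self, model, training_id: str) -> StateKey:
        # The state key combines the model's state type with the training id
        state_key = StateKey(model.state_type, training_id)
        self._states[state_key] = model.get_state()
        return state_key

    def load(self, model, state_key: StateKey) -> None:
        # Raises KeyError if the state was never saved
        model.set_state(self._states[state_key])

    def get_state_keys(self, state_type: StateType) -> List[StateKey]:
        # Simplification: "compatible" is treated as "equal state type"
        return [k for k in self._states if k.state_type == state_type]


@dataclass
class DummyModel:
    state_type: StateType
    weights: bytes = b""

    def get_state(self) -> bytes:
        return self.weights

    def set_state(self, state: bytes) -> None:
        self.weights = state


store = InMemoryStateStore()
trained = DummyModel(StateType("pytorch", "resnet18", ("cls1000",)), b"\x01\x02")
saved_key = store.save(trained, training_id="imagenet")

# A fresh model of the same type lists and loads the saved state
fresh = DummyModel(trained.state_type)
store.load(fresh, store.get_state_keys(fresh.state_type)[0])
print(fresh.weights)  # b'\x01\x02'
```

Keeping states keyed by `StateKey` makes `get_state_keys` a simple filter; a persistent store would replace the dictionary with files, objects or release assets.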
mozuma.stores.abstract.AbstractStateStore
Interface to handle model state loading and saving
See states reference for more information on state management.
exists(self, state_key: StateKey) -> bool
Tests whether the state key exists in the current store
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state_key | StateKey | The state key to test | required |
Returns:
| Type | Description |
|---|---|
| bool | Whether the state key exists in the store |
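One way to provide `exists` is to derive it from `get_state_keys`, as in the hypothetical mixin below. This is a sketch, not necessarily how `AbstractStateStore` implements it; a store offering a direct lookup may prefer to override it. The dataclasses and `FixedStore` are illustrative stand-ins.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class StateType:
    backend: str
    architecture: str
    extra: Tuple[str, ...] = ()


@dataclass(frozen=True)
class StateKey:
    state_type: StateType
    training_id: str


class ExistsFromListingMixin:
    """Hypothetical mixin deriving exists() from get_state_keys()."""

    def get_state_keys(self, state_type: StateType) -> List[StateKey]:
        raise NotImplementedError

    def exists(self, state_key: StateKey) -> bool:
        # A key exists if it is listed among the keys for its own state type
        return state_key in self.get_state_keys(state_key.state_type)


class FixedStore(ExistsFromListingMixin):
    def __init__(self, keys: List[StateKey]) -> None:
        self.keys = keys

    def get_state_keys(self, state_type: StateType) -> List[StateKey]:
        return [k for k in self.keys if k.state_type == state_type]


key = StateKey(StateType("pytorch", "resnet18", ("cls1000",)), "imagenet")
store = FixedStore([key])
print(store.exists(key))  # True
```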
get_state_keys(self, state_type: StateType) -> List[mozuma.states.StateKey]
Lists the available states that are compatible with the given state type.
Parameters:
| Name | Type | Description |
|---|---|---|
| state_type | StateType | Used to filter the compatible state keys |
Examples:
This is used to list the pre-trained weights for a given model.
For instance, `store.get_state_keys(model.state_type)` returns all the state keys available in `store` for `model`.
load(self, model: ~_ModelType, state_key: StateKey) -> None
Loads the model's weights from the store
Parameters:
| Name | Type | Description |
|---|---|---|
| model | ModelWithState | Model to update |
| state_key | StateKey | The identifier for the state to load |
save(self, model: ~_ModelType, training_id: str) -> StateKey
Saves the model state to the store
Parameters:
| Name | Type | Description |
|---|---|---|
| model | ModelWithState | Model to save |
| training_id | str | Identifier for the training activity |
Returns:
| Type | Description |
|---|---|
| StateKey | The identifier for the state that has been created |
For stores used to download the states of a single model, it can be useful to subclass `AbstractListStateStore` directly.
This makes it easier to define a store from a fixed set of states, as is often the case when integrating weights from external sources (pre-trained states for a paper, Hugging Face...).
See `SBERTDistiluseBaseMultilingualCasedV2Store` for an example.
mozuma.stores.list.AbstractListStateStore
Helper to define a store from a fixed list of state keys.
The subclasses should implement the following:
available_state_keys: List[mozuma.states.StateKey] (property, read-only)
List of available state keys for this store
Returns:
| Type | Description |
|---|---|
| list(StateKey) | All available state keys in the store |
get_state_keys(self, state_type: StateType) -> List[mozuma.states.StateKey]
Lists the available states that are compatible with the given state type.
Parameters:
| Name | Type | Description |
|---|---|---|
| state_type | StateType | Used to filter the compatible state keys |
Examples:
This is used to list the pre-trained weights for a given model.
For instance, `store.get_state_keys(model.state_type)` returns all the state keys available in `store` for `model`.
load(self, model: ~_ModelType, state_key: StateKey) -> None
Loads the model's weights from the store
Parameters:
| Name | Type | Description |
|---|---|---|
| model | ModelWithState | Model to update |
| state_key | StateKey | The identifier for the state to load |
save(self, model: ~_ModelType, training_id: str) -> NoReturn
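Saves the model state to the store
As the `NoReturn` annotation indicates, this method does not return normally in list-based stores: saving new states is not supported.
Parameters:
| Name | Type | Description |
|---|---|---|
| model | ModelWithState | Model to save |
| training_id | str | Identifier for the training activity |
Returns:
| Type | Description |
|---|---|
| StateKey | The identifier for the state that has been created |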
state_downloader(self, model: ~_ModelType, state_key: StateKey) -> None
Downloads and applies a state to a model
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | _ModelType | The model that will be used to load the state | required |
| state_key | StateKey | The state key identifier | required |
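Putting these pieces together, a list-based store only declares its keys and how to fetch one state. The sketch below is self-contained and hypothetical: `ListStore` stands in for `AbstractListStateStore`, `load` is assumed to check membership before delegating to `state_downloader`, and `save` is assumed to raise `NotImplementedError`, consistent with its `NoReturn` annotation. A real `state_downloader` would fetch weights from an external source; here it reads from a dictionary so the example runs offline.

```python
import abc
from dataclasses import dataclass
from typing import List, NoReturn, Tuple


@dataclass(frozen=True)
class StateType:
    backend: str
    architecture: str
    extra: Tuple[str, ...] = ()


@dataclass(frozen=True)
class StateKey:
    state_type: StateType
    training_id: str


class ListStore(abc.ABC):
    """Stand-in for AbstractListStateStore."""

    @property
    @abc.abstractmethod
    def available_state_keys(self) -> List[StateKey]: ...

    @abc.abstractmethod
    def state_downloader(self, model, state_key: StateKey) -> None: ...

    def get_state_keys(self, state_type: StateType) -> List[StateKey]:
        return [k for k in self.available_state_keys if k.state_type == state_type]

    def load(self, model, state_key: StateKey) -> None:
        if state_key not in self.available_state_keys:
            raise ValueError(f"Unknown state key: {state_key}")
        self.state_downloader(model, state_key)

    def save(self, model, training_id: str) -> NoReturn:
        raise NotImplementedError("This store is read-only")


# Hypothetical weights, standing in for files published alongside a paper
_REMOTE_WEIGHTS = {"paper-weights": b"\x00\x01"}


class PaperStore(ListStore):
    _STATE_TYPE = StateType("pytorch", "my-arch")

    @property
    def available_state_keys(self) -> List[StateKey]:
        return [StateKey(self._STATE_TYPE, "paper-weights")]

    def state_downloader(self, model, state_key: StateKey) -> None:
        # A real implementation would download the weights, then apply them
        model.set_state(_REMOTE_WEIGHTS[state_key.training_id])


@dataclass
class DummyModel:
    state_type: StateType
    weights: bytes = b""

    def set_state(self, state: bytes) -> None:
        self.weights = state


store = PaperStore()
model = DummyModel(StateType("pytorch", "my-arch"))
store.load(model, store.get_state_keys(model.state_type)[0])
```

The subclass stays small because listing, filtering and loading are shared by the base class.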