
Stores

Stores are used to load and save model states. They expose three main methods: get_state_keys, load and save.

The get_state_keys method lists the pre-trained states available for a given state type. It is usually called with a model's state type.

# List available model states in store for the model state type
state_keys = store.get_state_keys(model.state_type)

print(state_keys)
# Prints for a ResNet pre-trained on ImageNet:
# [
#   StateKey(
#       state_type=StateType(
#           backend='pytorch',
#           architecture='resnet18',
#           extra=('cls1000',)
#       ),
#       training_id='imagenet'
#   )
# ]

See states documentation for more information on how pre-trained states are identified.

Then, a state can be loaded into the model.

# Getting the first state key
resnet_imagenet_state_key = state_keys[0]
# Loading it into the model
store.load(model, state_key=resnet_imagenet_state_key)
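Indexing state_keys[0] relies on the order in which the store returns keys. A more explicit option is to select the key by its training_id, as sketched below. The StateType and StateKey dataclasses here are hypothetical stand-ins that mirror the fields in the printed output above, so the snippet runs without MoZuMa; in real code the keys come from store.get_state_keys.

```python
from dataclasses import dataclass
from typing import List, Tuple


# Hypothetical stand-ins mirroring the printed fields; real code would use
# the keys returned by store.get_state_keys(model.state_type).
@dataclass(frozen=True)
class StateType:
    backend: str
    architecture: str
    extra: Tuple[str, ...] = ()


@dataclass(frozen=True)
class StateKey:
    state_type: StateType
    training_id: str


def find_state_key(state_keys: List[StateKey], training_id: str) -> StateKey:
    """Return the first key with the given training_id, or raise KeyError."""
    for key in state_keys:
        if key.training_id == training_id:
            return key
    raise KeyError(f"No state with training_id={training_id!r}")


state_keys = [
    StateKey(
        state_type=StateType(
            backend="pytorch", architecture="resnet18", extra=("cls1000",)
        ),
        training_id="imagenet",
    )
]
resnet_imagenet_state_key = find_state_key(state_keys, "imagenet")
print(resnet_imagenet_state_key.state_type.architecture)  # resnet18
```

The resulting key is then passed to store.load in the same way as above.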

A model can be saved by specifying a training_id which should uniquely identify the training activity that yielded this model's state.

store.save(model, training_id="2022-01-01-finetuned-imagenet")
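The date-prefixed id in this example is one way of keeping training ids unique. A hypothetical helper following that pattern could look like this; the slug rules (lowercase, spaces and dots replaced by hyphens) are an assumption, chosen because the S3 and GitHub layouts described below use dots as separators.

```python
import datetime


def make_training_id(day: datetime.date, description: str) -> str:
    """Hypothetical helper building ids like '2022-01-01-finetuned-imagenet'."""
    # Lowercase and replace spaces and dots with hyphens; dots are avoided
    # because the stores use them as separators in file and asset names.
    slug = description.strip().lower().replace(" ", "-").replace(".", "-")
    return f"{day.isoformat()}-{slug}"


training_id = make_training_id(datetime.date(2022, 1, 1), "Finetuned ImageNet")
print(training_id)  # 2022-01-01-finetuned-imagenet
```

The result can then be passed as store.save(model, training_id=training_id).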

See AbstractStateStore for more details on these methods.

MoZuMa pre-trained models

MoZuMa provides model weights for all defined models through the MoZuMa Store.

mozuma.stores.Store() -> GitHUBReleaseStore

MoZuMa model state store.

Examples:

The store can be used to list available pre-trained states for a model

store = Store()
states = store.get_state_keys(model.state_type)

And load a given state to a model

store.load(model, state_key=states[0])

Alternative stores

These stores can be used if you want to store model states locally or on S3 storage.

mozuma.stores.local.LocalStateStore dataclass

Local file-based store

Attributes:

  • folder (str): Path to the folder where model states are saved

mozuma.stores.s3.S3StateStore dataclass

State store on top of S3 object storage

Given the state keys, states are organised in folders with the following structure:

base_path/
├─ {backend}/
│  ├─ {architecture}.{extra1}.{extra2}.{training_id}.pt
├─ pytorch/     # e.g. for torch models
│  ├─ resnet18.cls1000.imagenet.pt
│  ├─ clip-image-rn50.clip.pt
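The mapping from a state key to an object path can be sketched as below. It restates the tree above and is not taken from the S3StateStore source; StateType and StateKey are hypothetical stand-ins, and the CLIP example assumes that clip is the training_id and that the key has no extra.

```python
import posixpath
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class StateType:  # hypothetical stand-in for mozuma.states.StateType
    backend: str
    architecture: str
    extra: Tuple[str, ...] = ()


@dataclass(frozen=True)
class StateKey:  # hypothetical stand-in for mozuma.states.StateKey
    state_type: StateType
    training_id: str


def s3_object_key(base_path: str, state_key: StateKey) -> str:
    """Build the object path of a state following the layout above."""
    state_type = state_key.state_type
    # {architecture}.{extra1}.{extra2}...{training_id}; extras may be empty
    parts = [state_type.architecture, *state_type.extra, state_key.training_id]
    return posixpath.join(base_path, state_type.backend, ".".join(parts) + ".pt")


resnet = StateKey(StateType("pytorch", "resnet18", ("cls1000",)), "imagenet")
clip = StateKey(StateType("pytorch", "clip-image-rn50"), "clip")
print(s3_object_key("base_path", resnet))  # base_path/pytorch/resnet18.cls1000.imagenet.pt
print(s3_object_key("base_path", clip))  # base_path/pytorch/clip-image-rn50.clip.pt
```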

Attributes:

  • bucket (str): Bucket used to store states
  • session_kwargs (dict): Arguments passed to initialise boto3.session.Session
  • s3_endpoint_url (str): Endpoint URL used to connect to S3-compatible storage
  • base_path (str): The base path under which states are stored

mozuma.stores.github.GitHUBReleaseStore dataclass

Store implementation leveraging GitHub releases

Model weights are stored as assets in a release.

We recommend setting up GitHub authentication before using the get_state_keys and save methods. These methods call the GitHub releases API, which is limited to 60 requests per hour for unauthenticated clients. Authentication can be configured with:

  • Personal access token: the username and PAT need to be set in the GH_API_BASIC_AUTH environment variable as {username}:{personal_access_token}
  • GitHub token: used in GitHub Actions; it needs to be set in the GH_TOKEN environment variable.
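How MoZuMa turns these variables into requests is not described on this page. As a rough sketch of what the two options amount to for the GitHub REST API, a hypothetical helper might build the Authorization header as follows; giving GH_API_BASIC_AUTH precedence over GH_TOKEN is an assumption.

```python
import base64
import os
from typing import Dict, Mapping


def github_auth_headers(env: Mapping[str, str] = os.environ) -> Dict[str, str]:
    """Hypothetical helper building an Authorization header from the environment.

    GH_API_BASIC_AUTH ("{username}:{personal_access_token}") is sent as HTTP
    Basic auth and GH_TOKEN as a bearer token. No header is returned when
    neither is set, i.e. requests stay unauthenticated.
    """
    basic = env.get("GH_API_BASIC_AUTH")
    if basic:
        encoded = base64.b64encode(basic.encode("utf-8")).decode("ascii")
        return {"Authorization": f"Basic {encoded}"}
    token = env.get("GH_TOKEN")
    if token:
        return {"Authorization": f"Bearer {token}"}
    return {}


print(github_auth_headers({"GH_TOKEN": "abc"}))  # {'Authorization': 'Bearer abc'}
```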

Model states are organised in releases with the following convention:

  • Release names and tags are constructed as {release_name_prefix}.{state_type.backend}.{state_type.architecture}
  • Asset names within a release are constructed as {state_type.extra1}.{state_type.extra2}.{training_id}.state.gzip
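Applied to the ResNet key shown at the top of this page and the default prefix state, the convention gives the names below. The two helpers are hypothetical and only restate the rules above; the handling of keys without extra is inferred rather than documented.

```python
from typing import Sequence


def release_tag(prefix: str, backend: str, architecture: str) -> str:
    """{release_name_prefix}.{state_type.backend}.{state_type.architecture}"""
    return f"{prefix}.{backend}.{architecture}"


def asset_name(extra: Sequence[str], training_id: str) -> str:
    """{extra1}.{extra2}...{training_id}.state.gzip (extras may be empty)"""
    return ".".join([*extra, training_id]) + ".state.gzip"


print(release_tag("state", "pytorch", "resnet18"))  # state.pytorch.resnet18
print(asset_name(("cls1000",), "imagenet"))  # cls1000.imagenet.state.gzip
```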

Attributes:

  • repository_owner (str): The owner of the GitHub repository
  • repository_name (str): The name of the repository used as a store
  • branch_name (str): The branch used to create new releases holding model states. We recommend using an orphan branch named model-store.
  • release_name_prefix (str): The prefix identifying releases that contain model weights. Defaults to state and should not contain a dot.
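The dot restriction on release_name_prefix presumably follows from the naming convention above: release names are dot-separated, so a dot inside the prefix would make them ambiguous. A hypothetical configuration object mirroring these attributes could enforce the rule at construction time; it is not the GitHUBReleaseStore class itself, and the repository names are placeholders.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GitHubStoreConfig:
    """Hypothetical mirror of the GitHUBReleaseStore attributes."""

    repository_owner: str
    repository_name: str
    branch_name: str
    release_name_prefix: str = "state"

    def __post_init__(self) -> None:
        # Release names are {prefix}.{backend}.{architecture}; a dot in the
        # prefix would blur the boundary between prefix and backend.
        if "." in self.release_name_prefix:
            raise ValueError("release_name_prefix must not contain a dot")


config = GitHubStoreConfig("my-org", "my-models", branch_name="model-store")
print(config.release_name_prefix)  # state
```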

Write your own store

A store should inherit from AbstractStateStore and implement the save, load and get_state_keys methods.

mozuma.stores.abstract.AbstractStateStore

Interface to handle model state loading and saving

See states reference for more information on state management.

exists(self, state_key: StateKey) -> bool

Tests whether the state key exists in the current store

Parameters:

  • state_key (StateKey, required): The state key to test

Returns:

  • bool: True if the state key exists, False otherwise

get_state_keys(self, state_type: StateType) -> List[mozuma.states.StateKey]

Lists the available states that are compatible with the given state type.

Parameters:

  • state_type (StateType): Used to filter the compatible state keys

Examples:

This is used to list the pre-trained weights for a given model. The following code returns all the state keys available in the store for the model.

keys = store.get_state_keys(model.state_type)

load(self, model: ~_ModelType, state_key: StateKey) -> None

Loads the model's weights from the store

Parameters:

  • model (ModelWithState): Model to update
  • state_key (StateKey): The identifier for the state to load

save(self, model: ~_ModelType, training_id: str) -> StateKey

Saves the model state to the store

Parameters:

  • model (ModelWithState): Model to save
  • training_id (str): Identifier for the training activity

Returns:

  • StateKey: The identifier for the state that has been created
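To make the contract concrete, here is a minimal in-memory store implementing the four methods above. It is a sketch rather than a MoZuMa class: it does not subclass AbstractStateStore so that it runs standalone, StateType and StateKey are stand-ins, and it assumes models expose state_type, get_state() returning bytes and set_state(bytes). Compatibility is approximated by equality of state types.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass(frozen=True)
class StateType:  # hypothetical stand-in for mozuma.states.StateType
    backend: str
    architecture: str
    extra: Tuple[str, ...] = ()


@dataclass(frozen=True)
class StateKey:  # hypothetical stand-in for mozuma.states.StateKey
    state_type: StateType
    training_id: str


class DummyModel:
    """Hypothetical model exposing state_type, get_state() and set_state()."""

    def __init__(self, state_type: StateType, weights: bytes = b"") -> None:
        self.state_type = state_type
        self.weights = weights

    def get_state(self) -> bytes:
        return self.weights

    def set_state(self, state: bytes) -> None:
        self.weights = state


@dataclass
class InMemoryStateStore:
    """Sketch of a store keeping serialised states in a dict."""

    states: Dict[StateKey, bytes] = field(default_factory=dict)

    def get_state_keys(self, state_type: StateType) -> List[StateKey]:
        # Compatibility approximated by equality of state types
        return [key for key in self.states if key.state_type == state_type]

    def exists(self, state_key: StateKey) -> bool:
        return state_key in self.states

    def load(self, model: DummyModel, state_key: StateKey) -> None:
        model.set_state(self.states[state_key])

    def save(self, model: DummyModel, training_id: str) -> StateKey:
        state_key = StateKey(state_type=model.state_type, training_id=training_id)
        self.states[state_key] = model.get_state()
        return state_key


resnet_type = StateType("pytorch", "resnet18", ("cls1000",))
store = InMemoryStateStore()
key = store.save(DummyModel(resnet_type, b"weights"), training_id="imagenet")
fresh = DummyModel(resnet_type)
store.load(fresh, state_key=key)
print(fresh.weights)  # b'weights'
```

A persistent store keeps the same shape and only changes where save writes the bytes and where load reads them from.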

For stores that download the states of a single model, it can be useful to subclass AbstractListStateStore directly. This makes it easier to define a store from a fixed set of states, as is often the case when integrating weights from external sources (pre-trained states released with a paper, Hugging Face...). See SBERTDistiluseBaseMultilingualCasedV2Store for an example.

mozuma.stores.list.AbstractListStateStore

Helper to define a store from a fixed list of state keys.

The subclasses should implement the following:

available_state_keys: List[mozuma.states.StateKey] property readonly

List of available state keys for this store

Returns:

  • list(StateKey): All available state keys in the store

get_state_keys(self, state_type: StateType) -> List[mozuma.states.StateKey]

Lists the available states that are compatible with the given state type.

Parameters:

  • state_type (StateType): Used to filter the compatible state keys

Examples:

This is used to list the pre-trained weights for a given model. The following code returns all the state keys available in the store for the model.

keys = store.get_state_keys(model.state_type)

load(self, model: ~_ModelType, state_key: StateKey) -> None

Loads the model's weights from the store

Parameters:

  • model (ModelWithState): Model to update
  • state_key (StateKey): The identifier for the state to load

save(self, model: ~_ModelType, training_id: str) -> NoReturn

Saves the model state to the store

Parameters:

  • model (ModelWithState): Model to save
  • training_id (str): Identifier for the training activity

Returns:

  • StateKey: The identifier for the state that has been created

state_downloader(self, model: ~_ModelType, state_key: StateKey) -> None

Downloads and applies a state to a model

Parameters:

  • model (_ModelType, required): The model that will be used to load the state
  • state_key (StateKey, required): The state key identifier
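A sketch of this pattern for a single model follows. It assumes, as the member list above suggests, that subclasses provide available_state_keys and state_downloader while the base class derives get_state_keys and load from them, and that save is unsupported given its NoReturn annotation. All class names are hypothetical, and the weights come from an in-memory dict where a real store would download them from their source.

```python
from dataclasses import dataclass
from typing import Dict, List, NoReturn, Tuple


@dataclass(frozen=True)
class StateType:  # hypothetical stand-in for mozuma.states.StateType
    backend: str
    architecture: str
    extra: Tuple[str, ...] = ()


@dataclass(frozen=True)
class StateKey:  # hypothetical stand-in for mozuma.states.StateKey
    state_type: StateType
    training_id: str


class Model:
    """Hypothetical model exposing state_type and set_state(bytes)."""

    def __init__(self, state_type: StateType) -> None:
        self.state_type = state_type
        self.weights = b""

    def set_state(self, state: bytes) -> None:
        self.weights = state


class ListStoreSketch:
    """Base helper: everything is derived from available_state_keys."""

    @property
    def available_state_keys(self) -> List[StateKey]:
        raise NotImplementedError

    def state_downloader(self, model: Model, state_key: StateKey) -> None:
        raise NotImplementedError

    def get_state_keys(self, state_type: StateType) -> List[StateKey]:
        return [k for k in self.available_state_keys if k.state_type == state_type]

    def load(self, model: Model, state_key: StateKey) -> None:
        if state_key not in self.available_state_keys:
            raise ValueError(f"Unknown state key {state_key}")
        self.state_downloader(model, state_key)

    def save(self, model: Model, training_id: str) -> NoReturn:
        raise NotImplementedError("This store is read-only")


RESNET = StateType("pytorch", "resnet18", ("cls1000",))


class PaperWeightsStore(ListStoreSketch):
    """Hypothetical store for the weights published with a paper."""

    # Stands in for remote files; a real store would keep download URLs here
    _weights: Dict[StateKey, bytes] = {StateKey(RESNET, "paper"): b"paper-weights"}

    @property
    def available_state_keys(self) -> List[StateKey]:
        return list(self._weights)

    def state_downloader(self, model: Model, state_key: StateKey) -> None:
        model.set_state(self._weights[state_key])


store = PaperWeightsStore()
model = Model(RESNET)
store.load(model, store.get_state_keys(model.state_type)[0])
print(model.weights)  # b'paper-weights'
```

Keeping the list of keys separate from the download logic means get_state_keys never touches the network, which suits stores whose states are fixed in advance.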