Self-supervised Callbacks¶
Useful callbacks for self-supervised learning models.
Note
We rely on the community to keep these updated and working. If something doesn’t work, we’d really appreciate a contribution to fix!
BYOLMAWeightUpdate¶
The exponential moving average weight-update rule from Bootstrap Your Own Latent (BYOL).
- class pl_bolts.callbacks.byol_updates.BYOLMAWeightUpdate(initial_tau=0.996)[source]
Bases:
pytorch_lightning.callbacks.callback.Callback
Weight update rule from Bootstrap Your Own Latent (BYOL).
Updates the target_network params using an exponential moving average update rule weighted by tau. BYOL claims this keeps the online_network from collapsing.
The PyTorch Lightning module being trained should have:
self.online_network
self.target_network
Note
Automatically increases tau from initial_tau to 1.0 with every training step.
Example:
    # model must have 2 attributes
    model = Model()
    model.online_network = ...
    model.target_network = ...

    trainer = Trainer(callbacks=[BYOLMAWeightUpdate()])
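The update rule itself is a plain exponential moving average, with tau ramped toward 1.0 on a cosine schedule over training. A minimal dependency-free sketch (the helper names `update_tau` and `ema_update`, and the toy weight values, are hypothetical, not part of the callback's API; parameters are modeled as flat lists of floats rather than tensors):

```python
import math

def update_tau(initial_tau, global_step, max_steps):
    # Cosine ramp from initial_tau at step 0 toward 1.0 at max_steps,
    # mirroring the "increases tau with every training step" note above.
    return 1 - (1 - initial_tau) * (math.cos(math.pi * global_step / max_steps) + 1) / 2

def ema_update(online_params, target_params, tau):
    # target <- tau * target + (1 - tau) * online, applied per parameter.
    return [tau * t + (1 - tau) * o for o, t in zip(online_params, target_params)]

online = [1.0, 2.0]   # toy "online network" weights
target = [0.0, 0.0]   # toy "target network" weights
target = ema_update(online, target, tau=0.996)
# with tau=0.996 each target weight moves only 0.4% toward the online weight
```

Because tau is close to 1.0, the target network changes slowly, which is what BYOL claims keeps the online network from collapsing.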
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)[source]
Called when the train batch ends.
Note
The value outputs["loss"] here will be the normalized value w.r.t. accumulate_grad_batches of the loss returned from training_step.
- Return type
SSLOnlineEvaluator¶
Appends an MLP for fine-tuning to the given model. The callback runs its own mini inner training loop.
- class pl_bolts.callbacks.ssl_online.SSLOnlineEvaluator(z_dim, drop_p=0.2, hidden_dim=None, num_classes=None, dataset=None)[source]
Bases:
pytorch_lightning.callbacks.callback.Callback
Warning
The feature SSLOnlineEvaluator is currently marked under review. The compatibility with other Lightning projects is not guaranteed and API may change at any time. The API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html
Attaches an MLP for fine-tuning using the standard self-supervised protocol.
Example:
    # your datamodule must have 2 attributes
    dm = DataModule()
    dm.num_classes = ...  # the num of classes in the datamodule
    dm.name = ...  # name of the datamodule (e.g. ImageNet, STL10, CIFAR10)

    # your model must have 1 attribute
    model = Model()
    model.z_dim = ...  # the representation dim

    online_eval = SSLOnlineEvaluator(z_dim=model.z_dim)
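The protocol behind the callback is: freeze (detach) the backbone's representation, then train a small classifier on top of it to track downstream accuracy during pretraining. A minimal sketch with a logistic-regression probe standing in for the attached MLP (the helper `train_probe` and the toy 1-D features are hypothetical illustrations, not part of the callback's API):

```python
import math

def train_probe(features, labels, lr=0.5, steps=200):
    # Trains a logistic-regression probe on frozen features.
    # Only the probe's weight and bias are updated; the backbone that
    # produced `features` is never touched -- the analogue of detaching
    # representations before the online evaluator's forward pass.
    w, b = 0.0, 0.0
    for _ in range(steps):
        gw = gb = 0.0
        for x, y in zip(features, labels):
            p = 1 / (1 + math.exp(-(w * x + b)))  # sigmoid prediction
            gw += (p - y) * x                      # gradient of BCE loss
            gb += (p - y)
        w -= lr * gw / len(features)
        b -= lr * gb / len(features)
    return w, b

# toy 1-D features from a hypothetical frozen encoder, two classes
feats = [-2.0, -1.5, 1.2, 2.3]
labels = [0, 0, 1, 1]
w, b = train_probe(feats, labels)
acc = sum((1 / (1 + math.exp(-(w * x + b))) > 0.5) == bool(y)
          for x, y in zip(feats, labels)) / len(feats)
```

The probe's accuracy (logged by the real callback as an online metric) indicates how linearly separable the learned representations are, without interfering with self-supervised training.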
- Parameters
  - z_dim – representation dimension produced by the model
  - drop_p – dropout probability for the fine-tune MLP
  - hidden_dim – hidden dimension of the fine-tune MLP
  - num_classes – number of classes in the datamodule
  - dataset – name of the dataset
- load_state_dict(state_dict)[source]
Called when loading a checkpoint; implement to reload callback state given the callback’s state_dict.
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)[source]
Called when the train batch ends.
Note
The value outputs["loss"] here will be the normalized value w.r.t. accumulate_grad_batches of the loss returned from training_step.
- Return type
- on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)[source]
Called when the validation batch ends.
- Return type
- setup(trainer, pl_module, stage=None)[source]
Called when fit, validate, test, predict, or tune begins.
- Return type