
Self-supervised Callbacks

Useful callbacks for self-supervised learning models


BYOLMAWeightUpdate

The exponential moving average weight-update rule from Bootstrap Your Own Latent (BYOL).

class pl_bolts.callbacks.self_supervised.BYOLMAWeightUpdate(initial_tau=0.996)[source]

Bases: pytorch_lightning.Callback

Weight update rule from BYOL.

Your model should have a:

  • self.online_network.

  • self.target_network.

Updates the target_network params using an exponential moving average update rule weighted by tau. BYOL claims this keeps the online_network from collapsing.

Note

tau is increased automatically from initial_tau toward 1.0 with every training step.
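
As a rough illustration of the update rule described above, the sketch below applies an exponential moving average to plain Python floats that stand in for network weights. The cosine schedule for tau is the one given in the BYOL paper and is an assumption here; the callback's own schedule may differ. The names cosine_tau, ema_update, and max_steps are hypothetical and are not part of pl_bolts.

```python
import math


def cosine_tau(initial_tau, step, max_steps):
    """Assumed schedule (BYOL paper): rises from initial_tau at step 0 to 1.0 at max_steps."""
    return 1.0 - (1.0 - initial_tau) * (math.cos(math.pi * step / max_steps) + 1.0) / 2.0


def ema_update(target_params, online_params, tau):
    """Return new target params: tau * target + (1 - tau) * online."""
    return [tau * t + (1.0 - tau) * o for t, o in zip(target_params, online_params)]


# Toy parameters: plain floats stand in for weight tensors.
online = [1.0, 2.0]
target = [0.0, 0.0]
max_steps = 100  # hypothetical total number of training steps

for step in range(max_steps + 1):
    tau = cosine_tau(0.996, step, max_steps)
    # The target drifts slowly toward the online weights; the drift
    # shrinks as tau approaches 1.0.
    target = ema_update(target, online, tau)
```

With tau close to 1.0 the target changes only slightly per step, so it lags behind the online weights rather than copying them.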

Example:

from pytorch_lightning import Trainer
from pl_bolts.callbacks.self_supervised import BYOLMAWeightUpdate

# model must have 2 attributes
model = Model()
model.online_network = ...
model.target_network = ...

trainer = Trainer(callbacks=[BYOLMAWeightUpdate()])
Parameters

initial_tau (float) – (0.996) starting value of tau; tau is updated automatically at every training step


SSLOnlineEvaluator

Appends an MLP for fine-tuning to the given model. The callback has its own mini inner loop.

class pl_bolts.callbacks.self_supervised.SSLOnlineEvaluator(drop_p=0.2, hidden_dim=1024, z_dim=None, num_classes=None)[source]

Bases: pytorch_lightning.Callback

Attaches an MLP for fine-tuning using the standard self-supervised protocol.

Example:

from pl_bolts.callbacks.self_supervised import SSLOnlineEvaluator

# your model must have 2 attributes
model = Model()
model.z_dim = ... # the representation dim
model.num_classes = ... # the number of classes in the model
Parameters
  • drop_p (float) – (0.2) dropout probability

  • hidden_dim (int) – (1024) the hidden dimension for the finetune MLP

get_representations(pl_module, x)[source]

Override this to customize for the particular model.

Parameters
  • pl_module

  • x
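
The override pattern might look like the sketch below. To stay runnable without pl_bolts installed, it uses a stand-in base class in place of SSLOnlineEvaluator; in practice the subclass would inherit from the real callback. The stand-in's default behavior, the encoder attribute, and the ToyModel class are all assumptions made for illustration.

```python
class EvaluatorStandIn:
    """Stand-in for SSLOnlineEvaluator so this sketch runs without pl_bolts."""

    def get_representations(self, pl_module, x):
        # Assumed default, for illustration only: call the module on the batch.
        return pl_module(x)


class EncoderOnlyEvaluator(EvaluatorStandIn):
    """Hypothetical subclass that takes features from `pl_module.encoder`."""

    def get_representations(self, pl_module, x):
        # `encoder` is a hypothetical attribute of the model.
        return pl_module.encoder(x)


class ToyModel:
    """Toy model: lists of numbers stand in for tensors."""

    def encoder(self, x):
        return [2 * v for v in x]

    def __call__(self, x):
        # The full forward pass adds a projection step on top of the encoder.
        return [v + 1 for v in self.encoder(x)]


# The subclass returns encoder features instead of the full forward output.
features = EncoderOnlyEvaluator().get_representations(ToyModel(), [1, 2, 3])
```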
