
Self-supervised Callbacks

Useful callbacks for self-supervised learning models


BYOLMAWeightUpdate

The exponential moving average weight-update rule from Bootstrap Your Own Latent (BYOL).

class pl_bolts.callbacks.byol_updates.BYOLMAWeightUpdate(initial_tau=0.996)

Bases: pytorch_lightning.Callback

Weight update rule from BYOL.

Your model should have:

  • self.online_network

  • self.target_network

Updates the target_network parameters with an exponential moving average (EMA) of the online_network parameters, weighted by tau. BYOL claims this prevents the online_network from collapsing to a trivial representation.
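As a minimal sketch of the update rule, assuming plain PyTorch modules, each target parameter moves toward its online counterpart as θ_target ← τ·θ_target + (1 − τ)·θ_online. The helper name `ema_update` is hypothetical and this is not necessarily how the callback is written internally; creating the target as a frozen `deepcopy` of the online network is likewise an assumption for illustration.

```python
import copy

import torch
from torch import nn


@torch.no_grad()
def ema_update(online: nn.Module, target: nn.Module, tau: float) -> None:
    """Hypothetical helper: in-place EMA of the online params into the target params."""
    for p_online, p_target in zip(online.parameters(), target.parameters()):
        # target <- tau * target + (1 - tau) * online
        p_target.mul_(tau).add_(p_online, alpha=1.0 - tau)


torch.manual_seed(0)

# Assumption: the target starts as a frozen copy of the online network.
online_network = nn.Linear(4, 2)
target_network = copy.deepcopy(online_network)
for p in target_network.parameters():
    p.requires_grad = False

# Simulate an optimizer step that shifts every online weight by +1,
# then let the target drift slowly toward the online network.
with torch.no_grad():
    for p in online_network.parameters():
        p.add_(1.0)
ema_update(online_network, target_network, tau=0.996)
```

With tau close to 1 the target changes only slightly per step (here each weight moves by 0.004 of the gap), which is what makes it a slowly moving target.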

Note

tau is updated automatically at every training step, increasing from initial_tau toward 1.0 over the course of training.
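The note does not spell out the schedule. The BYOL paper increases tau with a cosine schedule, τ_k = 1 − (1 − τ_base)·(cos(πk/K) + 1)/2, where k is the current step and K the total number of training steps. The sketch below assumes that schedule; the function name `cosine_tau` is hypothetical.

```python
import math


def cosine_tau(step: int, total_steps: int, initial_tau: float = 0.996) -> float:
    """Hypothetical helper: BYOL-style cosine schedule from initial_tau up to 1.0."""
    progress = min(step, total_steps) / total_steps
    return 1.0 - (1.0 - initial_tau) * (math.cos(math.pi * progress) + 1.0) / 2.0


total = 1000
# tau at the start, halfway point, and end of training
schedule = [cosine_tau(k, total) for k in (0, total // 2, total)]
```

At the last step the cosine term vanishes, so the target network stops changing once tau reaches 1.0.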

Example:

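from pytorch_lightning import Trainer
from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate

# the model must have these 2 attributes
model = Model()  # your own LightningModule
model.online_network = ...
model.target_network = ...

trainer = Trainer(callbacks=[BYOLMAWeightUpdate()])

Parameters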

initial_tau (float) – starting value of tau; it is updated automatically at every training step.


SSLOnlineEvaluator

Appends an MLP for fine-tuning to the given model. The callback runs its own small inner training loop for this MLP.

class pl_bolts.callbacks.ssl_online.SSLOnlineEvaluator(dataset, drop_p=0.2, hidden_dim=None, z_dim=None, num_classes=None)

Bases: pytorch_lightning.Callback

Attaches an MLP for fine-tuning using the standard self-supervised evaluation protocol.
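A minimal sketch of one step of that protocol, assuming an encoder that outputs z_dim-dimensional representations: the representations are computed without gradient tracking, so the classification loss cannot update the encoder, and a separate optimizer trains only the classifier. `encoder`, `head`, the sizes, and the learning rate are illustrative placeholders, not pl_bolts names or defaults.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
z_dim, num_classes, batch_size = 16, 10, 8

encoder = nn.Linear(32, z_dim)        # stand-in for the self-supervised encoder
head = nn.Linear(z_dim, num_classes)  # online classifier trained by the callback
head_optimizer = torch.optim.Adam(head.parameters(), lr=1e-4)

x = torch.randn(batch_size, 32)
y = torch.randint(0, num_classes, (batch_size,))

# No gradient tracking here, so the loss below cannot change the encoder.
with torch.no_grad():
    representations = encoder(x)

logits = head(representations)
loss = F.cross_entropy(logits, y)

# Only the classifier's parameters are updated.
head_optimizer.zero_grad()
loss.backward()
head_optimizer.step()

accuracy = (logits.argmax(dim=1) == y).float().mean().item()
```

Keeping the classifier's optimizer separate from the model's own optimizer is what lets the evaluation run alongside self-supervised training without influencing it.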

Example:

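from pytorch_lightning import Trainer
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator

# your model must have 2 attributes
model = Model()  # your own LightningModule
model.z_dim = ...  # the representation dimension
model.num_classes = ...  # the number of classes in the dataset

online_eval = SSLOnlineEvaluator(
    z_dim=model.z_dim,
    num_classes=model.num_classes,
    dataset='imagenet'
)
trainer = Trainer(callbacks=[online_eval])

Parameters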
  • dataset (str) – name of the dataset (e.g. 'imagenet'); if 'stl10', the callback uses the labeled part of each batch

  • drop_p (float) – Dropout probability in the fine-tune MLP

  • hidden_dim (Optional[int]) – Hidden dimension for the fine-tune MLP

  • z_dim (Optional[int]) – Representation dimension

  • num_classes (Optional[int]) – Number of classes
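One plausible way these parameters could map onto the fine-tune head, assuming a linear probe when hidden_dim is None and a one-hidden-layer MLP otherwise. The function `build_eval_head` is hypothetical and is not the library's own evaluator module.

```python
from typing import Optional

import torch
from torch import nn


def build_eval_head(z_dim: int, num_classes: int,
                    hidden_dim: Optional[int] = None, drop_p: float = 0.2) -> nn.Module:
    """Hypothetical head: linear probe if hidden_dim is None, else a small MLP."""
    if hidden_dim is None:
        return nn.Sequential(
            nn.Flatten(),
            nn.Dropout(p=drop_p),
            nn.Linear(z_dim, num_classes),
        )
    return nn.Sequential(
        nn.Flatten(),
        nn.Dropout(p=drop_p),
        nn.Linear(z_dim, hidden_dim, bias=False),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(inplace=True),
        nn.Dropout(p=drop_p),
        nn.Linear(hidden_dim, num_classes),
    )


linear_probe = build_eval_head(z_dim=64, num_classes=10)
mlp_head = build_eval_head(z_dim=64, num_classes=10, hidden_dim=32)
logits = mlp_head(torch.randn(4, 64))  # a batch of 4 representations
```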
