
Self-supervised learning

A collection of utilities for self-supervised learning.


Identity class

class pl_bolts.utils.self_supervised.Identity

Bases: torch.nn.Module

An identity layer for replacing arbitrary layers (for example, the final fc layer) in pretrained models.

Example:

from torchvision.models import resnet18
from pl_bolts.utils.self_supervised import Identity

model = resnet18()
model.fc = Identity()
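To see what this replacement does, here is a minimal sketch in plain PyTorch. It assumes the identity layer simply returns its input unchanged (PyTorch itself ships `torch.nn.Identity` with that behavior). `PassThrough` and `TinyClassifier` are hypothetical names used only for illustration; once `fc` is swapped out, the model returns backbone features instead of class logits.

```python
import torch
from torch import nn


class PassThrough(nn.Module):
    """Hypothetical stand-in for an identity layer: returns its input as-is."""

    def forward(self, x):
        return x


class TinyClassifier(nn.Module):
    """Hypothetical toy model with a feature extractor and an `fc` head, like a ResNet."""

    def __init__(self, feat_dim=16, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(8, feat_dim), nn.ReLU())
        self.fc = nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        return self.fc(self.features(x))


model = TinyClassifier()
x = torch.rand(4, 8)
print(model(x).shape)  # class logits: torch.Size([4, 10])

# Swap the classification head for a pass-through layer
model.fc = PassThrough()
print(model(x).shape)  # raw features: torch.Size([4, 16])
```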

SSL-ready ResNets

Torchvision ResNets with the fc layer removed and the option to return all intermediate feature maps instead of only the last one.

Example:

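import torch
from pl_bolts.utils.self_supervised import torchvision_ssl_encoder

resnet = torchvision_ssl_encoder('resnet18', pretrained=False, return_all_feature_maps=True)
x = torch.rand(3, 3, 32, 32)  # batch of 3 RGB images, 32x32 pixels

# with return_all_feature_maps=True, every feature map is returned, not just the last one
feat_maps = resnet(x)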
pl_bolts.utils.self_supervised.torchvision_ssl_encoder(name, pretrained=False, return_all_feature_maps=False)
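The idea behind return_all_feature_maps can be sketched with a small hypothetical encoder (`TinySSLEncoder`, not part of pl_bolts): the forward pass collects the output of every stage and returns either the whole list or only the final map. This is an illustration under the assumption that the real encoder does something similar over the ResNet stages; the actual pl_bolts implementation may differ in detail.

```python
import torch
from torch import nn


class TinySSLEncoder(nn.Module):
    """Hypothetical toy encoder: three conv stages and no fc layer."""

    def __init__(self, return_all_feature_maps=False):
        super().__init__()
        self.return_all_feature_maps = return_all_feature_maps
        # Each stage halves the spatial resolution and doubles the channels
        self.stages = nn.ModuleList([
            nn.Sequential(nn.Conv2d(3, 16, 3, stride=2, padding=1), nn.ReLU()),
            nn.Sequential(nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.ReLU()),
            nn.Sequential(nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU()),
        ])

    def forward(self, x):
        feature_maps = []
        for stage in self.stages:
            x = stage(x)
            feature_maps.append(x)
        if self.return_all_feature_maps:
            return feature_maps  # one tensor per stage
        return feature_maps[-1]  # only the final feature map


x = torch.rand(3, 3, 32, 32)
all_maps = TinySSLEncoder(return_all_feature_maps=True)(x)
print([tuple(f.shape) for f in all_maps])
# [(3, 16, 16, 16), (3, 32, 8, 8), (3, 64, 4, 4)]
last = TinySSLEncoder()(x)
print(tuple(last.shape))  # (3, 64, 4, 4)
```

Returning a list keeps the default path cheap (callers who only need the final embedding get a single tensor), while dense or multi-scale methods can request every stage.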

SSL backbone finetuner

class pl_bolts.models.self_supervised.ssl_finetuner.SSLFineTuner(backbone, in_features, num_classes, hidden_dim=1024)

Bases: pytorch_lightning.LightningModule

Fine-tunes a self-supervised learning backbone using the standard evaluation protocol: an MLP with a single hidden layer of 1024 units (configurable via hidden_dim).

Example:

import pytorch_lightning as pl

from pl_bolts.models.self_supervised import CPCV2
from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised.cpc.transforms import (
    CPCEvalTransformsCIFAR10,
    CPCTrainTransformsCIFAR10,
)

# pretrained model; PATH must point to a pretrained CPCV2 checkpoint
backbone = CPCV2.load_from_checkpoint(PATH, strict=False)

# dataset + transforms (evaluation transforms are used for validation and test)
dm = CIFAR10DataModule(data_dir='.')
dm.train_transforms = CPCTrainTransformsCIFAR10()
dm.val_transforms = CPCEvalTransformsCIFAR10()
dm.test_transforms = CPCEvalTransformsCIFAR10()

# finetuner
finetuner = SSLFineTuner(backbone, in_features=backbone.z_dim, num_classes=backbone.num_classes)

# train
trainer = pl.Trainer()
trainer.fit(finetuner, dm)

# test
trainer.test(datamodule=dm)
Parameters
  • backbone – a pretrained model

  • in_features – feature dimension of the backbone outputs

  • num_classes – number of classes in the dataset

  • hidden_dim – hidden dimension of the MLP (default 1024, the value commonly used in the self-supervised literature)
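The evaluation protocol can be sketched in plain PyTorch as follows, assuming the backbone stays frozen and only the MLP head is trained. `make_head` and the toy backbone are hypothetical; the actual SSLFineTuner wraps this logic in a LightningModule and may add extra layers (for example dropout), so treat this as an illustration rather than its implementation.

```python
import torch
from torch import nn
import torch.nn.functional as F


def make_head(in_features, num_classes, hidden_dim=1024):
    """Hypothetical helper: MLP with a single hidden layer of `hidden_dim` units."""
    return nn.Sequential(
        nn.Linear(in_features, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, num_classes),
    )


torch.manual_seed(0)
# Hypothetical stand-in for a pretrained backbone producing 64-dim features
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 64))
backbone.eval()
for p in backbone.parameters():
    p.requires_grad_(False)  # freeze the backbone

head = make_head(in_features=64, num_classes=10, hidden_dim=1024)
optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)  # only the head is optimized

x = torch.rand(8, 3, 32, 32)
y = torch.randint(0, 10, (8,))

# Snapshots to check which parameters the step changes
backbone_before = [p.detach().clone() for p in backbone.parameters()]
head_before = [p.detach().clone() for p in head.parameters()]

# One training step: features come from the frozen backbone
with torch.no_grad():
    feats = backbone(x)
logits = head(feats)
loss = F.cross_entropy(logits, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Computing the features under torch.no_grad() avoids building a graph through the backbone, so the step only spends memory and compute on the small head.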

Documentation version: 0.2.0