
Self-supervised Learning

This Bolts module houses a collection of self-supervised learning models.

Self-supervised learning extracts representations of an input by solving a pretext task. In this package, we implement many of the current state-of-the-art self-supervised algorithms.

Self-supervised models are trained on unlabeled datasets.


Use cases

Here are some use cases for the self-supervised package.

Extracting image features

The models in this module are trained without labels and can therefore learn useful image representations (features) from unlabeled data.

In this example, we'll load a ResNet-18 that was pretrained on ImageNet using CPC as the pretext task.

Example:

from pl_bolts.models.self_supervised import CPCV2

# load resnet18 pretrained using CPC on imagenet
model = CPCV2(pretrained='resnet18')
# freeze() is a LightningModule method: eval mode + no gradients for all weights, including the encoder
model.freeze()
cpc_resnet18 = model.encoder

# it supports any torchvision resnet
model = CPCV2(pretrained='resnet50')

This means you can now extract image representations that were pretrained via unsupervised learning.

Example:

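from torch.utils.data import DataLoader

my_dataset = SomeDataset()  # placeholder for your own (image, label) dataset
for x, y in DataLoader(my_dataset, batch_size=32):
    out = cpc_resnet18(x)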

Train with unlabeled data

These models are well suited to training from scratch when you have a large set of unlabeled images.

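from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform, SimCLRTrainDataTransform

# MyDataset is a placeholder for your own unlabeled image dataset
train_dataset = MyDataset(transforms=SimCLRTrainDataTransform())
val_dataset = MyDataset(transforms=SimCLREvalDataTransform())

batch_size = 256
# simclr needs a lot of compute!
# gpus, num_samples, batch_size and dataset are required (see the SimCLR API below); values are illustrative
model = SimCLR(gpus=0, num_samples=len(train_dataset), batch_size=batch_size, dataset='my_dataset')
trainer = Trainer(tpu_cores=8)  # Lightning accepts 1 or 8 TPU cores per host
trainer.fit(
    model,
    DataLoader(train_dataset, batch_size=batch_size),
    DataLoader(val_dataset, batch_size=batch_size),
)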

Research

Mix and match any part, or subclass to create your own new method.

from pl_bolts.models.self_supervised import CPCV2
from pl_bolts.losses.self_supervised_learning import FeatureMapContrastiveTask

amdim_task = FeatureMapContrastiveTask(comparisons='01, 11, 02', bidirectional=True)
model = CPCV2(contrastive_task=amdim_task)

Contrastive Learning Models

Contrastive self-supervised learning (CSL) is a self-supervised learning approach where we generate representations of instances such that similar instances are near each other and far from dissimilar ones. This is often done by comparing triplets of positive, anchor and negative representations.
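As a minimal illustration of the triplet idea above, the sketch below scores an anchor against a positive and a negative with cosine similarity and applies a hinge (margin) loss. It is plain Python for readability; the function names and the margin value are illustrative assumptions, not part of the Bolts API.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors (1.0 = same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def triplet_loss(anchor, positive, negative, margin: float = 0.5) -> float:
    """Hinge loss that reaches zero once the positive is at least `margin`
    more similar to the anchor than the negative is."""
    pos_sim = cosine_similarity(anchor, positive)
    neg_sim = cosine_similarity(anchor, negative)
    return max(0.0, margin - (pos_sim - neg_sim))


anchor = [1.0, 0.0]
positive = [0.9, 0.1]  # e.g. an augmented view of the same image
negative = [0.0, 1.0]  # e.g. a different image
print(triplet_loss(anchor, positive, negative))  # 0.0: already well separated
print(triplet_loss(anchor, negative, positive))  # roles swapped: large loss
```

Training pushes the loss toward zero, which pulls positives toward the anchor and pushes negatives away.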

In this section, we list Lightning implementations of popular contrastive learning approaches.

AMDIM

class pl_bolts.models.self_supervised.AMDIM(datamodule='cifar10', encoder='amdim_encoder', contrastive_task=torch.nn.Module, image_channels=3, image_height=32, encoder_feature_dim=320, embedding_fx_dim=1280, conv_block_depth=10, use_bn=False, tclip=20.0, learning_rate=0.0002, data_dir='', num_classes=10, batch_size=200, **kwargs)

Bases: pytorch_lightning.LightningModule

PyTorch Lightning implementation of Augmented Multiscale Deep InfoMax (AMDIM).

Paper authors: Philip Bachman, R Devon Hjelm, William Buchwalter.

Model implemented by: William Falcon

This code is adapted to Lightning from the original authors' repository.

Example

>>> from pl_bolts.models.self_supervised import AMDIM
>>> model = AMDIM(encoder='resnet18')

Train:

from pytorch_lightning import Trainer

trainer = Trainer()
trainer.fit(model)

Parameters
  • datamodule (Union[str, LightningDataModule]) – A LightningDataModule, or the name of a dataset (e.g. 'cifar10')

  • encoder (Union[str, Module, LightningModule]) – An encoder name or an encoder model

  • image_channels (int) – Number of input image channels (default: 3)

  • image_height (int) – Input image height, in pixels

  • encoder_feature_dim (int) – Called ndf in the paper, this is the representation size for the encoder.

  • embedding_fx_dim (int) – Output dimension of the embedding function (called nrkhs in the paper, after Reproducing Kernel Hilbert Space).

  • conv_block_depth (int) – Depth of each encoder block.

  • use_bn (bool) – If True, uses batch normalization.

  • tclip (float) – Soft-clipping non-linearity applied to the scores after computing the regularization term and before computing the log-softmax. This is the ‘second trick’ used in the paper.

  • learning_rate (float) – The learning rate

  • data_dir (str) – Where to store data

  • num_classes (int) – How many classes in the dataset

  • batch_size (int) – The batch size
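The `tclip` soft clipping can be pictured as a scaled tanh: close to the identity for small scores and saturating smoothly at ±tclip for large ones. This is a minimal sketch assuming the `clip * tanh(x / clip)` form used in the original AMDIM code; the helper name `tanh_clip` is an assumption, and Bolts applies the operation to tensors rather than floats.

```python
import math


def tanh_clip(score: float, clip_val: float = 20.0) -> float:
    """Soft-clip a score into the open interval (-clip_val, clip_val).

    Near zero tanh(z) ~ z, so small scores pass through almost unchanged;
    large scores saturate smoothly instead of being hard-clipped.
    """
    return clip_val * math.tanh(score / clip_val)


for s in (0.5, 10.0, 100.0, -100.0):
    print(s, round(tanh_clip(s), 3))
```

Unlike a hard `min`/`max` clip, the gradient never becomes exactly zero, which keeps the log-softmax that follows well behaved.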


BYOL

class pl_bolts.models.self_supervised.BYOL(num_classes, learning_rate=0.2, weight_decay=1.5e-06, input_height=32, batch_size=32, num_workers=0, warmup_epochs=10, max_epochs=1000, **kwargs)

Bases: pytorch_lightning.LightningModule

PyTorch Lightning implementation of Bootstrap Your Own Latent (BYOL).

Paper authors: Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec, Pierre H. Richemond, Elena Buchatskaya, Carl Doersch, Bernardo Avila Pires, Zhaohan Daniel Guo, Mohammad Gheshlaghi Azar, Bilal Piot, Koray Kavukcuoglu, Rémi Munos, Michal Valko.

Model implemented by:

Warning

Work in progress. This implementation is still being verified.

TODOs:
  • verify on CIFAR-10

  • verify on STL-10

  • pre-train on imagenet

Example:

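import pytorch_lightning as pl
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised import BYOL
from pl_bolts.models.self_supervised.simclr.transforms import (
    SimCLREvalDataTransform, SimCLRTrainDataTransform)

model = BYOL(num_classes=10)

dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)

trainer = pl.Trainer()
trainer.fit(model, dm)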

Train:

trainer = Trainer()
trainer.fit(model)

CLI command:

# cifar10
python byol_module.py --gpus 1

# imagenet
python byol_module.py \
    --gpus 8 \
    --dataset imagenet2012 \
    --data_dir /path/to/imagenet/ \
    --meta_dir /path/to/folder/with/meta.bin/ \
    --batch_size 32

Parameters
  • datamodule – The datamodule

  • learning_rate (float) – the learning rate

  • weight_decay (float) – optimizer weight decay

  • input_height (int) – image input height

  • batch_size (int) – the batch size

  • num_workers (int) – number of workers

  • warmup_epochs (int) – number of epochs of learning-rate warmup

  • max_epochs (int) – maximum number of epochs for the learning-rate scheduler
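The exact scheduler driven by `warmup_epochs` and `max_epochs` is not specified here. The sketch below assumes a common choice, linear warmup followed by cosine decay (the shape that SwAV's `start_lr`, `warmup_epochs` and `final_lr` parameters describe further down). The function name and its `start_lr`/`final_lr` defaults are illustrative assumptions.

```python
import math


def warmup_cosine_lr(epoch: int, base_lr: float, warmup_epochs: int, max_epochs: int,
                     start_lr: float = 0.0, final_lr: float = 0.0) -> float:
    """Learning rate for a given (0-indexed) epoch."""
    if epoch < warmup_epochs:
        # Linear ramp from start_lr up to base_lr over the warmup epochs.
        return start_lr + (base_lr - start_lr) * epoch / warmup_epochs
    # Cosine decay from base_lr down to final_lr over the remaining epochs.
    progress = (epoch - warmup_epochs) / max(1, max_epochs - warmup_epochs)
    return final_lr + 0.5 * (base_lr - final_lr) * (1.0 + math.cos(math.pi * progress))


# BYOL defaults from the signature above: learning_rate=0.2, warmup_epochs=10, max_epochs=1000
for epoch in (0, 5, 10, 500, 999):
    print(epoch, round(warmup_cosine_lr(epoch, 0.2, 10, 1000), 4))
```

Warmup avoids large, noisy updates while the randomly initialized network and its batch statistics settle.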


CPC (V2)

PyTorch Lightning implementation of Data-Efficient Image Recognition with Contrastive Predictive Coding.

Paper authors: Olivier J. Hénaff, Aravind Srinivas, Jeffrey De Fauw, Ali Razavi, Carl Doersch, S. M. Ali Eslami, Aaron van den Oord.

Model implemented by:

To Train:

import pytorch_lightning as pl
from pl_bolts.models.self_supervised import CPCV2
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised.cpc import (
    CPCTrainTransformsCIFAR10, CPCEvalTransformsCIFAR10)

# data
dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = CPCTrainTransformsCIFAR10()
dm.val_transforms = CPCEvalTransformsCIFAR10()

# model
model = CPCV2()

# fit
trainer = pl.Trainer()
trainer.fit(model, dm)

To finetune:

python cpc_finetuner.py \
    --ckpt_path path/to/checkpoint.ckpt \
    --dataset cifar10 \
    --gpus 1

CIFAR-10 and STL-10 baselines

CPCv2 does not report baselines on the CIFAR-10 and STL-10 datasets. The results in the following table are taken from the YADIM paper.

CPCv2 implementation results

Dataset  | Test acc (%) | Encoder      | Optimizer | Batch | Epochs                | Hardware       | LR
CIFAR-10 | 84.52        | CPCresnet101 | Adam      | 64    | 1000 (up to 24 hours) | 1 V100 (32GB)  | 4e-5
STL-10   | 78.36        | CPCresnet101 | Adam      | 144   | 1000 (up to 72 hours) | 4 V100 (32GB)  | 1e-4
ImageNet | 54.82        | CPCresnet101 | Adam      | 3072  | 1000 (up to 21 days)  | 64 V100 (32GB) | 4e-5


CIFAR-10 pretrained model:

from pl_bolts.models.self_supervised import CPCV2

weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-cifar10-v4-exp3/epoch%3D474.ckpt'
cpc_v2 = CPCV2.load_from_checkpoint(weight_path, strict=False)

cpc_v2.freeze()

Pre-training: [Figure: pretraining validation loss]

Fine-tuning: [Figure: online finetuning accuracy]

STL-10 pretrained model:

from pl_bolts.models.self_supervised import CPCV2

weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-stl10-v0-exp3/epoch%3D624.ckpt'
cpc_v2 = CPCV2.load_from_checkpoint(weight_path, strict=False)

cpc_v2.freeze()

Pre-training: [Figure: pretraining validation loss]

Fine-tuning: [Figure: online finetuning accuracy]

CPCV2 API

class pl_bolts.models.self_supervised.CPCV2(encoder_name='cpc_encoder', patch_size=8, patch_overlap=4, online_ft=True, task='cpc', num_workers=4, num_classes=10, learning_rate=0.0001, pretrained=None, **kwargs)

Bases: pytorch_lightning.LightningModule

Parameters
  • encoder_name (str) – A string for any of the resnets in torchvision, or the original CPC encoder, or a custom nn.Module encoder

  • patch_size (int) – How big to make the image patches

  • patch_overlap (int) – How much overlap each patch should have

  • online_ft (bool) – If True, enables a 1024-unit MLP to fine-tune online

  • task (str) – Which self-supervised task to use (‘cpc’, ‘amdim’, etc…)

  • num_workers (int) – number of dataloader workers

  • num_classes (int) – number of classes

  • learning_rate (float) – learning rate

  • pretrained (Optional[str]) – If set to an encoder name (e.g. 'resnet18'), loads weights for that encoder pretrained with CPC on ImageNet
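To make `patch_size` and `patch_overlap` concrete: if patches are cut with a sliding window whose stride is `patch_size - patch_overlap` (an assumption about how the CPC transforms tile the image, not something stated above), the number of patches per side follows from simple arithmetic. For a 32×32 CIFAR-10 image with the defaults (8 and 4) that gives a 7×7 grid. The helper `patch_grid` is hypothetical.

```python
def patch_grid(image_size: int, patch_size: int = 8, patch_overlap: int = 4):
    """Top-left offsets of square patches tiled along one image side."""
    stride = patch_size - patch_overlap
    if stride <= 0:
        raise ValueError("patch_overlap must be smaller than patch_size")
    # Keep only windows that fit entirely inside the image.
    n = (image_size - patch_size) // stride + 1
    return [i * stride for i in range(n)]


offsets = patch_grid(32)
print(offsets)                       # [0, 4, 8, 12, 16, 20, 24]
print(len(offsets) ** 2, "patches")  # 49 patches
```

Larger overlap means more (and more redundant) patches for the context network to predict across.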


MoCo (V2)

class pl_bolts.models.self_supervised.MocoV2(base_encoder='resnet18', emb_dim=128, num_negatives=65536, encoder_momentum=0.999, softmax_temperature=0.07, learning_rate=0.03, momentum=0.9, weight_decay=0.0001, data_dir='./', batch_size=256, use_mlp=False, num_workers=8, *args, **kwargs)

Bases: pytorch_lightning.LightningModule

PyTorch Lightning implementation of MoCo.

Paper authors: Xinlei Chen, Haoqi Fan, Ross Girshick, Kaiming He.

Code adapted from facebookresearch/moco to Lightning by:

Example:

from pytorch_lightning import Trainer
from pl_bolts.models.self_supervised import MocoV2

model = MocoV2()
trainer = Trainer()
trainer.fit(model)

CLI command:

# cifar10
python moco2_module.py --gpus 1

# imagenet
python moco2_module.py \
    --gpus 8 \
    --dataset imagenet2012 \
    --data_dir /path/to/imagenet/ \
    --meta_dir /path/to/folder/with/meta.bin/ \
    --batch_size 32

Parameters
  • base_encoder (Union[str, Module]) – torchvision model name or torch.nn.Module

  • emb_dim (int) – feature dimension (default: 128)

  • num_negatives (int) – queue size; number of negative keys (default: 65536)

  • encoder_momentum (float) – momentum for updating the key encoder (default: 0.999)

  • softmax_temperature (float) – softmax temperature (default: 0.07)

  • learning_rate (float) – the learning rate

  • momentum (float) – optimizer momentum

  • weight_decay (float) – optimizer weight decay

  • datamodule – the DataModule (train, val, test dataloaders)

  • data_dir (str) – the directory to store data

  • batch_size (int) – batch size

  • use_mlp (bool) – add an MLP head to the encoders

  • num_workers (int) – workers for the loaders
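`encoder_momentum` controls how the key encoder follows the query encoder. In the MoCo paper the key parameters are an exponential moving average of the query parameters, θ_k ← m·θ_k + (1 − m)·θ_q. The sketch below applies that rule to plain lists of floats standing in for parameter tensors; this is a simplification, and the function name `momentum_update` is hypothetical.

```python
def momentum_update(key_params, query_params, m: float = 0.999):
    """Return key parameters moved a small step toward the query parameters."""
    return [m * k + (1.0 - m) * q for k, q in zip(key_params, query_params)]


key = [0.0, 0.0]
query = [1.0, -1.0]
for _ in range(1000):
    key = momentum_update(key, query)
# With a fixed query the key reaches (1 - m**n) of the way after n steps, about 63% here.
print([round(v, 3) for v in key])
```

A momentum close to 1 keeps the keys in the queue consistent with each other even though they were encoded at different training steps.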

forward(img_q, img_k)
Input:

img_q: a batch of query images
img_k: a batch of key images

Output:

logits, targets
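For a single query, the `logits, targets` pair can be read as follows (a simplified sketch after the MoCo paper, not the Bolts code): the positive logit is the dot product of the query with its own key, the negative logits are dot products with the `num_negatives` keys in the queue, everything is divided by `softmax_temperature`, and because the positive is placed first the cross-entropy target is index 0. Vectors are assumed L2-normalized; the function names are hypothetical.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def moco_logits(q, k, queue, temperature: float = 0.07):
    """Logits for one query: [positive, negative_1, ..., negative_K] / T."""
    raw = [dot(q, k)] + [dot(q, neg) for neg in queue]
    target = 0  # the positive key always sits at index 0
    return [s / temperature for s in raw], target


def cross_entropy(logits, target):
    """-log softmax(logits)[target], computed stably."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(s - m) for s in logits))
    return log_sum_exp - logits[target]


q = [1.0, 0.0]
k = [0.8, 0.6]                     # positive key: another view of the same image
queue = [[0.0, 1.0], [-1.0, 0.0]]  # negative keys from earlier batches
logits, target = moco_logits(q, k, queue)
print(len(logits), target)         # 3 0
print(round(cross_entropy(logits, target), 4))
```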

init_encoders(base_encoder)

Override to add your own encoders.


SimCLR

PyTorch Lightning implementation of SimCLR.

Paper authors: Ting Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey Hinton.

Model implemented by:

To Train:

import pytorch_lightning as pl
from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised.simclr.transforms import (
    SimCLREvalDataTransform, SimCLRTrainDataTransform)

# data
dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)

# model
model = SimCLR(num_samples=dm.num_samples, batch_size=dm.batch_size, dataset='cifar10')

# fit
trainer = pl.Trainer()
trainer.fit(model, dm)

CIFAR-10 baseline

CIFAR-10 implementation results

Implementation | Test acc (%) | Encoder  | Optimizer | Batch | Epochs         | Hardware      | LR
Original       | ~94.00       | resnet50 | LARS      | 2048  | 800            | TPUs          | 1.0/1.5
Ours           | 85.68        | resnet50 | LARS-SGD  | 2048  | 800 (~4 hours) | 8 V100 (16GB) | 1.5
Ours           | 85.68        | resnet50 | LARS-Adam | 2048  | 800 (~4 hours) | 8 V100 (16GB) | 1e-3


CIFAR-10 pretrained model:

from pl_bolts.models.self_supervised import SimCLR

weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp12_87_52/epoch%3D960.ckpt'
simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)

simclr.freeze()

Pre-training: [Figure: pretraining validation loss]

Fine-tuning (single-layer MLP, 1024 hidden units): [Figures: finetuning validation accuracy; finetuning test accuracy]

To reproduce:

# pretrain
python simclr_module.py \
    --gpus 1 \
    --dataset cifar10 \
    --batch_size 512 \
    --learning_rate 1e-06 \
    --num_workers 8

# finetune
python simclr_finetuner.py \
    --ckpt_path path/to/epoch=xyz.ckpt \
    --gpus 1

SimCLR API

class pl_bolts.models.self_supervised.SimCLR(gpus, num_samples, batch_size, dataset, nodes=1, arch='resnet50', hidden_mlp=2048, feat_dim=128, warmup_epochs=10, max_epochs=100, temperature=0.1, first_conv=True, maxpool1=True, optimizer='adam', lars_wrapper=True, exclude_bn_bias=False, start_lr=0.0, learning_rate=0.001, final_lr=0.0, weight_decay=1e-06, **kwargs)

Bases: pytorch_lightning.LightningModule

Parameters
  • batch_size (int) – the batch size

  • num_samples (int) – number of samples in the dataset

  • warmup_epochs (int) – number of epochs over which to warm up the learning rate

  • learning_rate (float) – the optimizer learning rate

  • weight_decay (float) – the optimizer weight decay

  • temperature (float) – the loss temperature

nt_xent_loss(out_1, out_2, temperature, eps=1e-06)

Assumes out_1 and out_2 are L2-normalized.
out_1: [batch_size, dim]
out_2: [batch_size, dim]
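The NT-Xent (normalized temperature-scaled cross-entropy) loss from the SimCLR paper can be sketched in plain Python. Each of the 2N embeddings treats the other view of the same image as its positive and the remaining 2N − 2 embeddings as negatives; the loss averages the cross-entropy over all 2N anchors. This illustrates the technique only; it is not the Bolts implementation (which works on tensors and uses the `eps` argument), and the function name `nt_xent` is hypothetical.

```python
import math


def nt_xent(out_1, out_2, temperature: float = 0.1) -> float:
    """NT-Xent loss for N positive pairs (out_1[i], out_2[i]).

    Both inputs are lists of N L2-normalized vectors.
    """
    z = out_1 + out_2  # 2N embeddings
    n = len(out_1)
    total = 0.0
    for i in range(2 * n):
        j = (i + n) % (2 * n)  # index of the positive (other view of the same image)
        sims = {
            k: sum(a * b for a, b in zip(z[i], z[k])) / temperature
            for k in range(2 * n)
            if k != i  # exclude self-similarity
        }
        m = max(sims.values())
        log_denominator = m + math.log(sum(math.exp(s - m) for s in sims.values()))
        total += log_denominator - sims[j]
    return total / (2 * n)


aligned = nt_xent([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
shuffled = nt_xent([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
print(aligned < shuffled)  # True: matching views give a lower loss
```

A lower temperature sharpens the softmax, so hard negatives (those most similar to the anchor) dominate the gradient.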


SwAV

PyTorch Lightning implementation of SwAV, adapted from the official implementation.

Paper authors: Mathilde Caron, Ishan Misra, Julien Mairal, Priya Goyal, Piotr Bojanowski, Armand Joulin.

Implementation adapted by:

To Train:

import pytorch_lightning as pl
from pl_bolts.models.self_supervised import SwAV
from pl_bolts.datamodules import STL10DataModule
from pl_bolts.models.self_supervised.swav.transforms import (
    SwAVTrainDataTransform, SwAVEvalDataTransform
)
from pl_bolts.transforms.dataset_normalizations import stl10_normalization

# data
batch_size = 128
dm = STL10DataModule(data_dir='.', batch_size=batch_size)
dm.train_dataloader = dm.train_dataloader_mixed
dm.val_dataloader = dm.val_dataloader_mixed

dm.train_transforms = SwAVTrainDataTransform(
    normalize=stl10_normalization()
)

dm.val_transforms = SwAVEvalDataTransform(
    normalize=stl10_normalization()
)

# model
model = SwAV(
    gpus=1,
    num_samples=dm.num_unlabeled_samples,
    dataset='stl10',
    batch_size=batch_size
)

# fit (precision=16 needs a GPU, matching gpus=1 above)
trainer = pl.Trainer(gpus=1, precision=16)
trainer.fit(model, datamodule=dm)

ImageNet baseline

We have included an option to directly load ImageNet weights provided by FAIR into bolts.

You can load the pretrained model using:

ImageNet pretrained model:

from pl_bolts.models.self_supervised import SwAV

weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/swav_imagenet/swav_imagenet.pth.tar'
swav = SwAV.load_from_checkpoint(weight_path, strict=True)

swav.freeze()

STL-10 baseline

The original paper does not provide baselines on STL-10.

STL-10 implementation results

Implementation | Test acc (%) | Encoder       | Optimizer | Batch | Queue used | Epochs      | Hardware      | LR
Ours           | 86.72        | SwAV resnet50 | LARS      | 128   | No         | 100 (~9 hr) | 1 V100 (16GB) | 1e-3


STL-10 pretrained model:

from pl_bolts.models.self_supervised import SwAV

weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/checkpoints/swav_stl10.pth.tar'
swav = SwAV.load_from_checkpoint(weight_path, strict=False)

swav.freeze()

Pre-training: [Figures: pretraining validation loss; online finetuning validation accuracy]

Fine-tuning (single-layer MLP, 1024 hidden units): [Figures: finetuning validation accuracy; finetuning validation loss]

To reproduce:

# pretrain
python swav_module.py \
    --online_ft \
    --gpus 1 \
    --lars_wrapper \
    --batch_size 128 \
    --learning_rate 1e-3 \
    --gaussian_blur \
    --queue_length 0 \
    --jitter_strength 1. \
    --nmb_prototypes 512

# finetune
python swav_finetuner.py \
    --ckpt_path path/to/epoch=xyz.ckpt

SwAV API

class pl_bolts.models.self_supervised.SwAV(gpus, num_samples, batch_size, dataset, nodes=1, arch='resnet50', hidden_mlp=2048, feat_dim=128, warmup_epochs=10, max_epochs=100, nmb_prototypes=3000, freeze_prototypes_epochs=1, temperature=0.1, sinkhorn_iterations=3, queue_length=0, queue_path='queue', epoch_queue_starts=15, crops_for_assign=[0, 1], nmb_crops=[2, 6], first_conv=True, maxpool1=True, optimizer='adam', lars_wrapper=True, exclude_bn_bias=False, start_lr=0.0, learning_rate=0.001, final_lr=0.0, weight_decay=1e-06, epsilon=0.05, **kwargs)

Bases: pytorch_lightning.LightningModule

Parameters
  • gpus (int) – number of gpus per node used in training, passed to SwAV module to manage the queue and select distributed sinkhorn

  • nodes (int) – number of nodes to train on

  • num_samples (int) – number of image samples used for training

  • batch_size (int) – batch size per GPU in ddp

  • dataset (str) – dataset being used for train/val

  • arch (str) – encoder architecture used for pre-training

  • hidden_mlp (int) – hidden-layer size of the non-linear projection head; set to 0 to use a linear projection head

  • feat_dim (int) – output dim of the projection head

  • warmup_epochs (int) – apply linear warmup for this many epochs

  • max_epochs (int) – epoch count for pre-training

  • nmb_prototypes (int) – count of prototype vectors

  • freeze_prototypes_epochs (int) – epoch until which the gradients of the prototype layer are frozen

  • temperature (float) – loss temperature

  • sinkhorn_iterations (int) – iterations for sinkhorn normalization

  • queue_length (int) – length of the queue, used when the batch size is small; must be divisible by the total batch size (i.e. total_gpus * batch_size); set to 0 to remove the queue

  • queue_path (str) – folder within the logs directory

  • epoch_queue_starts (int) – start using the queue after this epoch

  • crops_for_assign (list) – list of crop ids for computing assignment

  • nmb_crops (list) – number of global and local crops, e.g. [2, 6]

  • first_conv (bool) – keep the first conv layer the same as in the original ResNet architecture; if set to False, it is replaced by a conv with kernel size 3 and stride 1 (for CIFAR-10)

  • maxpool1 (bool) – keep the first maxpool layer the same as in the original ResNet architecture; if set to False, the first maxpool is turned off (for CIFAR-10, possibly STL-10)

  • optimizer (str) – optimizer to use

  • lars_wrapper (bool) – use LARS wrapper over the optimizer

  • exclude_bn_bias (bool) – exclude batchnorm and bias layers from weight decay in optimizers

  • start_lr (float) – starting lr for linear warmup

  • learning_rate (float) – learning rate

  • final_lr (float) – final learning rate for cosine learning-rate decay

  • weight_decay (float) – weight decay for optimizer

  • epsilon (float) – epsilon value (Sinkhorn regularization) used when computing the SwAV assignments
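`sinkhorn_iterations` and `epsilon` control how SwAV turns prototype scores into soft cluster assignments. Following the SwAV paper, scores are exponentiated as exp(score / epsilon), then columns and rows are alternately rescaled so the B samples are spread evenly over the K prototypes. The sketch below is a simplified single-process version in plain Python (no distributed reduction, no queue); the function name and example scores are assumptions.

```python
import math


def sinkhorn(scores, epsilon: float = 0.05, iterations: int = 3):
    """Turn a B x K matrix of sample-to-prototype scores into soft assignments.

    Rows of the result sum to 1; columns are pushed toward equal mass B / K.
    """
    b, k = len(scores), len(scores[0])
    # Shift by the global max before exp() to avoid overflow; a constant
    # shift cancels out in the normalizations below.
    top = max(max(row) for row in scores)
    q = [[math.exp((s - top) / epsilon) for s in row] for row in scores]
    total = sum(sum(row) for row in q)
    q = [[v / total for v in row] for row in q]
    for _ in range(iterations):
        # Give every prototype (column) total mass 1 / K.
        col_sums = [sum(q[i][j] for i in range(b)) for j in range(k)]
        q = [[q[i][j] / (col_sums[j] * k) for j in range(k)] for i in range(b)]
        # Give every sample (row) total mass 1 / B.
        row_sums = [sum(row) for row in q]
        q = [[v / (row_sums[i] * b) for v in q[i]] for i in range(b)]
    # Scale rows back up so each is a distribution over the K prototypes.
    return [[v * b for v in row] for row in q]


# 4 samples, 2 prototypes: three samples lean toward prototype 0
scores = [[0.6, 0.4], [0.55, 0.45], [0.52, 0.48], [0.3, 0.7]]
assignments = sinkhorn(scores)
for row in assignments:
    print([round(v, 3) for v in row])
print("column sums:", [round(sum(col), 3) for col in zip(*assignments)])
```

The equal-mass constraint is what prevents the trivial solution in which every sample is assigned to the same prototype; with only a few iterations it is enforced approximately.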
