Self-supervised Learning

This Bolts module houses a collection of self-supervised learning models.

Self-supervised learning extracts representations of an input by solving a pretext task. In this package, we implement many of the current state-of-the-art self-supervised algorithms.

Self-supervised models are trained on unlabeled datasets.


Use cases

Here are some use cases for the self-supervised package.

Extracting image features

The models in this module are trained without labels, and their encoders can be reused to extract image representations (features).

In this example, we load a ResNet-18 that was pretrained on ImageNet using CPC as the pretext task.

Example:

from pl_bolts.models.self_supervised import CPCV2

# load resnet18 pretrained using CPC on imagenet
model = CPCV2(pretrained='resnet18')
cpc_resnet18 = model.encoder
cpc_resnet18.freeze()

# it supports any torchvision resnet
model = CPCV2(pretrained='resnet50')

You can now use this encoder to extract image representations that were learned without supervision.

Example:

from torch.utils.data import DataLoader

my_dataset = SomeDataset()
for batch in DataLoader(my_dataset, batch_size=32):
    x, y = batch
    out = cpc_resnet18(x)

Train with unlabeled data

These models are perfect for training from scratch when you have a huge set of unlabeled images:

from pytorch_lightning import Trainer
from torch.utils.data import DataLoader

from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform, SimCLRTrainDataTransform


train_dataset = MyDataset(transforms=SimCLRTrainDataTransform())
val_dataset = MyDataset(transforms=SimCLREvalDataTransform())

# simclr needs a lot of compute!
model = SimCLR(num_samples=len(train_dataset), batch_size=32)
trainer = Trainer(tpu_cores=128)
trainer.fit(
    model,
    DataLoader(train_dataset, batch_size=32),
    DataLoader(val_dataset, batch_size=32),
)

Research

Mix and match any part, or subclass to create your own new method:

from pl_bolts.models.self_supervised import CPCV2
from pl_bolts.losses.self_supervised_learning import FeatureMapContrastiveTask

amdim_task = FeatureMapContrastiveTask(comparisons='01, 11, 02', bidirectional=True)
model = CPCV2(contrastive_task=amdim_task)
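Here the comparisons string names which feature-map pairs to contrast (e.g. '01' pairs feature map 0 of one augmented view with feature map 1 of the other), and bidirectional=True runs each comparison in both directions.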

Contrastive Learning Models

Contrastive self-supervised learning (CSL) learns representations of instances such that similar instances are near each other and far from dissimilar ones. This is often done by comparing triplets of anchor, positive, and negative representations.
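To make the triplet idea concrete, here is a minimal sketch of a triplet-style contrastive loss in plain PyTorch. It is illustrative only (each model below uses its own objective), and triplet_contrastive_loss and the margin value are hypothetical names for this example:

import torch
import torch.nn.functional as F

def triplet_contrastive_loss(anchor, positive, negative, margin=1.0):
    # cosine distance = 1 - cosine similarity, computed row-wise on (batch, dim) inputs
    d_pos = 1 - F.cosine_similarity(anchor, positive)
    d_neg = 1 - F.cosine_similarity(anchor, negative)
    # hinge: penalize only when the negative is not at least `margin` farther than the positive
    return F.relu(d_pos - d_neg + margin).mean()

# toy usage with random "representations"
z_a, z_p, z_n = (torch.randn(8, 128) for _ in range(3))
loss = triplet_contrastive_loss(z_a, z_p, z_n)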

In this section, we list Lightning implementations of popular contrastive learning approaches.

AMDIM

class pl_bolts.models.self_supervised.AMDIM(datamodule='cifar10', encoder='amdim_encoder', contrastive_task=torch.nn.Module, image_channels=3, image_height=32, encoder_feature_dim=320, embedding_fx_dim=1280, conv_block_depth=10, use_bn=False, tclip=20.0, learning_rate=0.0002, data_dir='', num_classes=10, batch_size=200, **kwargs)[source]

Bases: pytorch_lightning.LightningModule

PyTorch Lightning implementation of Augmented Multiscale Deep InfoMax (AMDIM)

Paper authors: Philip Bachman, R Devon Hjelm, William Buchwalter.

Model implemented by: William Falcon

This code is adapted to Lightning from the original authors' repo.

Example

>>> from pl_bolts.models.self_supervised import AMDIM
>>> model = AMDIM(encoder='resnet18')

Train:

trainer = Trainer()
trainer.fit(model)
Parameters
  • datamodule (Union[str, LightningDataModule]) – a LightningDataModule

  • encoder (Union[str, Module, LightningModule]) – an encoder string or model

  • image_channels (int) – number of image channels (e.g. 3 for RGB)

  • image_height (int) – image height in pixels

  • encoder_feature_dim (int) – called ndf in the paper; this is the representation size for the encoder

  • embedding_fx_dim (int) – output dim of the embedding function (nrkhs in the paper, for Reproducing Kernel Hilbert Spaces)

  • conv_block_depth (int) – depth of each encoder block

  • use_bn (bool) – if true, uses batchnorm

  • tclip (int) – soft clipping non-linearity applied to the scores after computing the regularization term and before computing the log-softmax; this is the ‘second trick’ used in the paper (see the sketch after this list)

  • learning_rate (float) – the learning rate

  • data_dir (str) – where to store data

  • num_classes (int) – how many classes in the dataset

  • batch_size (int) – the batch size
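The tclip soft clipping above is, in the AMDIM formulation, a scaled tanh. A minimal sketch, assuming the usual c * tanh(s / c) form with tclip playing the role of c:

import torch

def soft_clip(scores: torch.Tensor, tclip: float = 20.0) -> torch.Tensor:
    # smoothly squashes raw scores into (-tclip, tclip); approximately
    # the identity for |scores| much smaller than tclip
    return tclip * torch.tanh(scores / tclip)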


BYOL

class pl_bolts.models.self_supervised.BYOL(num_classes, learning_rate=0.2, weight_decay=1.5e-05, input_height=32, batch_size=32, num_workers=0, warmup_epochs=10, max_epochs=1000, **kwargs)[source]

Bases: pytorch_lightning.LightningModule

PyTorch Lightning implementation of Bring Your Own Latent (BYOL)

Paper authors: Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec, Pierre H. Richemond, Elena Buchatskaya, Carl Doersch, Bernardo Avila Pires, Zhaohan Daniel Guo, Mohammad Gheshlaghi Azar, Bilal Piot, Koray Kavukcuoglu, Rémi Munos, Michal Valko.

Model implemented by:

Warning

Work in progress. This implementation is still being verified.

TODOs:
  • verify on CIFAR-10

  • verify on STL-10

  • pre-train on imagenet

Example:

import pytorch_lightning as pl
from pl_bolts.models.self_supervised import BYOL
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised.simclr.simclr_transforms import (
    SimCLREvalDataTransform, SimCLRTrainDataTransform)

# model
model = BYOL(num_classes=10)

# data
dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)

trainer = pl.Trainer()
trainer.fit(model, dm)

Train:

trainer = Trainer()
trainer.fit(model)

CLI command:

# cifar10
python byol_module.py --gpus 1

# imagenet
python byol_module.py \
    --gpus 8 \
    --dataset imagenet2012 \
    --data_dir /path/to/imagenet/ \
    --meta_dir /path/to/folder/with/meta.bin/ \
    --batch_size 32
Parameters
  • num_classes (int) – number of classes in the dataset

  • learning_rate (float) – the learning rate

  • weight_decay (float) – optimizer weight decay

  • input_height (int) – image input height

  • batch_size (int) – the batch size

  • num_workers (int) – number of workers

  • warmup_epochs (int) – number of epochs for scheduler warm-up

  • max_epochs (int) – max epochs for scheduler
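For intuition, the BYOL objective from the paper is a symmetrized, normalized MSE between the online network's prediction and the target network's projection (equivalent to 2 - 2 * cosine similarity for unit vectors). A minimal sketch of that loss, independent of this class's internals:

import torch
import torch.nn.functional as F

def byol_loss(p_online: torch.Tensor, z_target: torch.Tensor) -> torch.Tensor:
    # normalized MSE between online prediction and target projection;
    # equals 2 - 2 * cosine similarity for unit-normalized rows
    p = F.normalize(p_online, dim=-1)
    z = F.normalize(z_target, dim=-1)
    return (2 - 2 * (p * z).sum(dim=-1)).mean()

# symmetrized over the two augmented views; target outputs are detached:
# loss = byol_loss(p1, z2.detach()) + byol_loss(p2, z1.detach())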


CPC (V2)

class pl_bolts.models.self_supervised.CPCV2(datamodule=None, encoder='cpc_encoder', patch_size=8, patch_overlap=4, online_ft=True, task='cpc', num_workers=4, learning_rate=0.0001, data_dir='', batch_size=32, pretrained=None, **kwargs)[source]

Bases: pytorch_lightning.LightningModule

PyTorch Lightning implementation of Data-Efficient Image Recognition with Contrastive Predictive Coding

Paper authors: Olivier J. Hénaff, Aravind Srinivas, Jeffrey De Fauw, Ali Razavi, Carl Doersch, S. M. Ali Eslami, Aaron van den Oord.

Model implemented by:

Example

>>> from pl_bolts.models.self_supervised import CPCV2
>>> model = CPCV2()

Train:

trainer = Trainer()
trainer.fit(model)

CLI command:

# cifar10
python cpc_module.py --gpus 1

# imagenet
python cpc_module.py \
    --gpus 8 \
    --dataset imagenet2012 \
    --data_dir /path/to/imagenet/ \
    --meta_dir /path/to/folder/with/meta.bin/ \
    --batch_size 32

To Finetune:

python cpc_finetuner.py --ckpt_path path/to/checkpoint.ckpt --dataset cifar10 --gpus x

Some uses:

# load resnet18 pretrained using CPC on imagenet
model = CPCV2(pretrained='resnet18')
resnet18 = model.encoder
resnet18.freeze()

# it supports any torchvision resnet
model = CPCV2(pretrained='resnet50')

# use it as a feature extractor
x = torch.rand(2, 3, 224, 224)
out = model(x)
Parameters
  • datamodule (Optional[LightningDataModule]) – a LightningDataModule (optional); otherwise set the dataloaders directly

  • encoder (Union[str, Module, LightningModule]) – a string for any of the resnets in torchvision, the original CPC encoder, or a custom nn.Module encoder

  • patch_size (int) – how big to make the image patches

  • patch_overlap (int) – how much overlap each patch should have

  • online_ft (bool) – enable a 1024-unit MLP to fine-tune online

  • task (str) – which self-supervised task to use (‘cpc’, ‘amdim’, etc.)

  • num_workers (int) – number of dataloader workers

  • learning_rate (float) – what learning rate to use

  • data_dir (str) – where to store data

  • batch_size (int) – batch size

  • pretrained (Optional[str]) – name of the encoder whose weights, pretrained with CPC on ImageNet, should be loaded (e.g. ‘resnet18’)


MoCo (V2)

class pl_bolts.models.self_supervised.MocoV2(base_encoder='resnet18', emb_dim=128, num_negatives=65536, encoder_momentum=0.999, softmax_temperature=0.07, learning_rate=0.03, momentum=0.9, weight_decay=0.0001, datamodule=None, data_dir='./', batch_size=256, use_mlp=False, num_workers=8, *args, **kwargs)[source]

Bases: pytorch_lightning.LightningModule

PyTorch Lightning implementation of MoCo

Paper authors: Xinlei Chen, Haoqi Fan, Ross Girshick, Kaiming He.

Code adapted from facebookresearch/moco to Lightning by:

Example

>>> from pl_bolts.models.self_supervised import MocoV2
>>> model = MocoV2()

Train:

trainer = Trainer()
trainer.fit(model)

CLI command:

# cifar10
python moco2_module.py --gpus 1

# imagenet
python moco2_module.py \
    --gpus 8 \
    --dataset imagenet2012 \
    --data_dir /path/to/imagenet/ \
    --meta_dir /path/to/folder/with/meta.bin/ \
    --batch_size 32
Parameters
  • base_encoder (Union[str, Module]) – torchvision model name or torch.nn.Module

  • emb_dim (int) – feature dimension (default: 128)

  • num_negatives (int) – queue size; number of negative keys (default: 65536)

  • encoder_momentum (float) – moco momentum of updating key encoder (default: 0.999)

  • softmax_temperature (float) – softmax temperature (default: 0.07)

  • learning_rate (float) – the learning rate

  • momentum (float) – optimizer momentum

  • weight_decay (float) – optimizer weight decay

  • datamodule (Optional[LightningDataModule]) – the DataModule (train, val, test dataloaders)

  • data_dir (str) – the directory to store data

  • batch_size (int) – batch size

  • use_mlp (bool) – add an mlp to the encoders

  • num_workers (int) – workers for the loaders

_batch_shuffle_ddp(x)[source]

Batch shuffle, for making use of BatchNorm. Only supports DistributedDataParallel (DDP).

_batch_unshuffle_ddp(x, idx_unshuffle)[source]

Undo batch shuffle. Only supports DistributedDataParallel (DDP).

_momentum_update_key_encoder()[source]

Momentum update of the key encoder
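For reference, the momentum update in the MoCo paper moves each key-encoder parameter toward its query-encoder counterpart as an exponential moving average. A minimal sketch, independent of this class's internals:

import torch

@torch.no_grad()
def momentum_update(encoder_q, encoder_k, m=0.999):
    # key encoder tracks the query encoder: k = m * k + (1 - m) * q
    for param_q, param_k in zip(encoder_q.parameters(), encoder_k.parameters()):
        param_k.data.mul_(m).add_(param_q.data, alpha=1 - m)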

forward(img_q, img_k)[source]
Input:

  • img_q: a batch of query images

  • img_k: a batch of key images

Output:

  • logits, targets
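In the usual MoCo setup, the positive key sits at index 0 of the logits and the targets are all zeros, so the training objective reduces to cross-entropy. A sketch of the downstream use, assuming model, img_q, and img_k from the examples above:

import torch.nn.functional as F

# logits: (batch, 1 + queue_size); targets: (batch,) of zeros
logits, targets = model(img_q, img_k)
loss = F.cross_entropy(logits, targets)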

init_encoders(base_encoder)[source]

Override to add your own encoders


SimCLR

PyTorch Lightning implementation of SimCLR

Paper authors: Ting Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey Hinton.

Model implemented by:

To Train:

import pytorch_lightning as pl
from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised.simclr.simclr_transforms import (
    SimCLREvalDataTransform, SimCLRTrainDataTransform)

# data
dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)

# model
model = SimCLR(num_samples=dm.num_samples, batch_size=dm.batch_size)

# fit
trainer = pl.Trainer()
trainer.fit(model, dm)

CIFAR-10 baseline

CIFAR-10 implementation results:

Implementation | Test acc | Encoder  | Optimizer | Batch | Epochs      | Hardware      | LR
Original       | 92.00?   | resnet50 | LARS      | 512   | 1000        | 1 V100 (32GB) | 1.0
Ours           | 85.68    | resnet50 | LARS      | 512   | 960 (12 hr) | 1 V100 (32GB) | 1e-6


CIFAR-10 pretrained model:

from pl_bolts.models.self_supervised import SimCLR

weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp12_87_52/epoch%3D960.ckpt'
simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)

simclr.freeze()
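Once frozen, the model can serve as a fixed feature extractor for downstream tasks.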

Pre-training:

[plot: pretraining validation loss]

Fine-tuning (single-layer MLP, 1024 hidden units):

[plot: finetuning validation accuracy]
[plot: finetuning test accuracy]

To reproduce:

# pretrain
python simclr_module.py \
    --gpus 1 \
    --dataset cifar10 \
    --batch_size 512 \
    --learning_rate 1e-06 \
    --num_workers 8

# finetune
python simclr_finetuner.py \
    --ckpt_path path/to/epoch=xyz.ckpt \
    --gpus 1

SimCLR API

class pl_bolts.models.self_supervised.SimCLR(batch_size, num_samples, warmup_epochs=10, lr=0.0001, opt_weight_decay=1e-06, loss_temperature=0.5, **kwargs)[source]

Bases: pytorch_lightning.LightningModule

Parameters
  • batch_size – the batch size

  • num_samples – num samples in the dataset

  • warmup_epochs – number of epochs to warm up the learning rate

  • lr – the optimizer learning rate

  • opt_weight_decay – the optimizer weight decay

  • loss_temperature – the temperature of the NT-Xent contrastive loss
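loss_temperature scales SimCLR's NT-Xent objective. A minimal sketch of that loss, independent of this class's internals (z1 and z2 are the projections of two augmented views of the same batch):

import torch
import torch.nn.functional as F

def nt_xent(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    # stack both views and compute pairwise cosine similarities
    z = F.normalize(torch.cat([z1, z2]), dim=-1)   # (2N, dim)
    sim = z @ z.t() / temperature                  # (2N, 2N)
    sim.fill_diagonal_(float('-inf'))              # exclude self-pairs
    n = z1.size(0)
    # the positive for row i is its augmented twin in the other view
    targets = torch.cat([torch.arange(n) + n, torch.arange(n)])
    return F.cross_entropy(sim, targets)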
