
Self-supervised Learning

This Bolts module houses a collection of self-supervised learning models.

Self-supervised learning extracts representations of an input by solving a pretext task. In this package, we implement many of the current state-of-the-art self-supervised algorithms.

Self-supervised models are trained with unlabeled datasets.


Use cases

Here are some use cases for the self-supervised package.

Extracting image features

The models in this module are trained without labels, so their encoders can be reused to extract general-purpose image representations (features).

In this example, we’ll load a ResNet-18 that was pretrained on ImageNet using CPC as the pretext task.

Example:

from pl_bolts.models.self_supervised import CPCV2

# load resnet18 pretrained using CPC on imagenet
model = CPCV2(pretrained='resnet18')
cpc_resnet18 = model.encoder
cpc_resnet18.freeze()

# it supports any torchvision resnet
model = CPCV2(pretrained='resnet50')

This means you can now extract image representations from an encoder that was pretrained via unsupervised learning.

Example:

from torch.utils.data import DataLoader

my_dataset = SomeDataset()  # any image dataset
loader = DataLoader(my_dataset, batch_size=32)
for batch in loader:
    x, y = batch
    out = cpc_resnet18(x)

Train with unlabeled data

These models are ideal for pretraining from scratch when you have a huge set of unlabeled images.

from pytorch_lightning import Trainer
from torch.utils.data import DataLoader

from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform, SimCLRTrainDataTransform


train_dataset = MyDataset(transforms=SimCLRTrainDataTransform())
val_dataset = MyDataset(transforms=SimCLREvalDataTransform())

# SimCLR needs a lot of compute!
model = SimCLR()
trainer = Trainer(tpu_cores=128)
trainer.fit(
    model,
    DataLoader(train_dataset),
    DataLoader(val_dataset),
)

Research

Mix and match any part, or subclass to create your own new method:

from pl_bolts.models.self_supervised import CPCV2
from pl_bolts.losses.self_supervised_learning import FeatureMapContrastiveTask

amdim_task = FeatureMapContrastiveTask(comparisons='01, 11, 02', bidirectional=True)
model = CPCV2(contrastive_task=amdim_task)
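Since the encoder argument accepts any nn.Module (see the CPC (V2) parameters below), you can also swap in your own backbone. A minimal sketch of the wiring, where SmallEncoder is a hypothetical toy module (whether CPC trains well with a given custom encoder depends on its output shape and is not guaranteed here):

import torch.nn as nn

from pl_bolts.models.self_supervised import CPCV2

# hypothetical toy backbone -- any nn.Module can be passed as the encoder
class SmallEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )

    def forward(self, x):
        return self.net(x)

model = CPCV2(encoder=SmallEncoder())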

Contrastive Learning Models

Contrastive self-supervised learning (CSL) generates representations of instances such that similar instances are near each other and dissimilar ones are far apart. This is often done by comparing triplets of anchor, positive, and negative representations.
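To make the triplet idea concrete, here is a minimal standalone sketch using torch.nn.TripletMarginLoss with random embeddings (an illustration of the concept, not the loss used by any particular model below):

import torch
from torch import nn

triplet = nn.TripletMarginLoss(margin=1.0)

anchor = torch.randn(8, 128)     # representations of the anchor images
positive = torch.randn(8, 128)   # representations of similar (e.g. augmented) views
negative = torch.randn(8, 128)   # representations of dissimilar images

# pulls anchor/positive together, pushes anchor/negative apart
loss = triplet(anchor, positive, negative)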

In this section, we list Lightning implementations of popular contrastive learning approaches.

AMDIM

class pl_bolts.models.self_supervised.AMDIM(datamodule='cifar10', encoder='amdim_encoder', contrastive_task=torch.nn.Module, image_channels=3, image_height=32, encoder_feature_dim=320, embedding_fx_dim=1280, conv_block_depth=10, use_bn=False, tclip=20.0, learning_rate=0.0002, data_dir='', num_classes=10, batch_size=200, **kwargs)[source]

Bases: pytorch_lightning.LightningModule

PyTorch Lightning implementation of Augmented Multiscale Deep InfoMax (AMDIM)

Paper authors: Philip Bachman, R Devon Hjelm, William Buchwalter.

Model implemented by: William Falcon

This code was adapted to Lightning from the original authors' repo.

Example

>>> from pl_bolts.models.self_supervised import AMDIM
...
>>> model = AMDIM(encoder='resnet18')

Train:

trainer = Trainer()
trainer.fit(model)
Parameters
  • datamodule (Union[str, LightningDataModule]) – a datamodule name or a LightningDataModule

  • encoder (Union[str, Module, LightningModule]) – an encoder string or model

  • image_channels (int) – number of image channels (default: 3)

  • image_height (int) – image height in pixels

  • encoder_feature_dim (int) – Called ndf in the paper, this is the representation size for the encoder.

  • embedding_fx_dim (int) – Output dim of the embedding function (nrkhs in the paper) (Reproducing Kernel Hilbert Spaces).

  • conv_block_depth (int) – Depth of each encoder block.

  • use_bn (bool) – If True, will use batchnorm.

  • tclip (int) – soft clipping non-linearity applied to the scores after computing the regularization term and before computing the log-softmax. This is the ‘second trick’ used in the paper.

  • learning_rate (float) – the learning rate

  • data_dir (str) – Where to store data

  • num_classes (int) – How many classes in the dataset

  • batch_size (int) – The batch size
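Because the constructor takes a contrastive_task module, AMDIM can be paired with a different feature-map comparison, as in the Research section above. A minimal sketch (the comparisons string here is an illustrative choice, not the paper's setting):

from pl_bolts.models.self_supervised import AMDIM
from pl_bolts.losses.self_supervised_learning import FeatureMapContrastiveTask

# compare only the first feature maps of the two views (illustrative)
task = FeatureMapContrastiveTask(comparisons='11', bidirectional=True)
model = AMDIM(encoder='resnet18', contrastive_task=task)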

CPC (V2)

class pl_bolts.models.self_supervised.CPCV2(datamodule=None, encoder='cpc_encoder', patch_size=8, patch_overlap=4, online_ft=True, task='cpc', num_workers=4, learning_rate=0.0001, data_dir='', batch_size=32, pretrained=None, **kwargs)[source]

Bases: pytorch_lightning.LightningModule

PyTorch Lightning implementation of Data-Efficient Image Recognition with Contrastive Predictive Coding

Paper authors: Olivier J. Hénaff, Aravind Srinivas, Jeffrey De Fauw, Ali Razavi, Carl Doersch, S. M. Ali Eslami, Aaron van den Oord.

Model implemented by:

Example

>>> from pl_bolts.models.self_supervised import CPCV2
...
>>> model = CPCV2()

Train:

trainer = Trainer()
trainer.fit(model)

CLI command:

# cifar10
python cpc_module.py --gpus 1

# imagenet
python cpc_module.py \
    --gpus 8 \
    --dataset imagenet2012 \
    --data_dir /path/to/imagenet/ \
    --meta_dir /path/to/folder/with/meta.bin/ \
    --batch_size 32

Some uses:

# load resnet18 pretrained using CPC on imagenet
model = CPCV2(encoder='resnet18', pretrained=True)
resnet18 = model.encoder
resnet18.freeze()

# it supports any torchvision resnet
model = CPCV2(encoder='resnet50', pretrained=True)

# use it as a feature extractor
x = torch.rand(2, 3, 224, 224)
out = model(x)
Parameters
  • datamodule (Optional[LightningDataModule]) – an optional LightningDataModule; otherwise, set the dataloaders directly

  • encoder (Union[str, Module, LightningModule]) – A string for any of the resnets in torchvision, the original CPC encoder, or a custom nn.Module encoder

  • patch_size (int) – How big to make the image patches

  • patch_overlap (int) – How much overlap each patch should have (see the patching sketch after this list).

  • online_ft (bool) – If True, enables a 1024-unit MLP that is fine-tuned online

  • task (str) – Which self-supervised task to use (‘cpc’, ‘amdim’, etc…)

  • num_workers (int) – number of dataloader workers

  • learning_rate (float) – the learning rate

  • data_dir (str) – where to store data

  • batch_size (int) – batch size

  • pretrained (Optional[str]) – If set, loads weights pretrained (using CPC) on ImageNet
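To make patch_size and patch_overlap concrete, here is a standalone sketch of cutting an image into overlapping patches with tensor.unfold (an illustration of the idea, not the module's internal code):

import torch

x = torch.rand(1, 3, 32, 32)            # toy batch: B x C x H x W
patch_size, patch_overlap = 8, 4
stride = patch_size - patch_overlap     # each patch starts 4 pixels after the last

# unfold height, then width -> (B, C, n_h, n_w, patch_size, patch_size)
patches = x.unfold(2, patch_size, stride).unfold(3, patch_size, stride)
print(patches.shape)                    # torch.Size([1, 3, 7, 7, 8, 8])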

Moco (V2)

class pl_bolts.models.self_supervised.MocoV2(base_encoder='resnet18', emb_dim=128, num_negatives=65536, encoder_momentum=0.999, softmax_temperature=0.07, learning_rate=0.03, momentum=0.9, weight_decay=0.0001, datamodule=None, data_dir='./', batch_size=256, use_mlp=False, num_workers=8, *args, **kwargs)[source]

Bases: pytorch_lightning.LightningModule

PyTorch Lightning implementation of MoCo

Paper authors: Xinlei Chen, Haoqi Fan, Ross Girshick, Kaiming He.

Code adapted from facebookresearch/moco to Lightning by:

Example

>>> from pl_bolts.models.self_supervised import MocoV2
...
>>> model = MocoV2()

Train:

trainer = Trainer()
trainer.fit(model)

CLI command:

# cifar10
python moco2_module.py --gpus 1

# imagenet
python moco2_module.py \
    --gpus 8 \
    --dataset imagenet2012 \
    --data_dir /path/to/imagenet/ \
    --meta_dir /path/to/folder/with/meta.bin/ \
    --batch_size 32
Parameters
  • base_encoder (Union[str, Module]) – torchvision model name or torch.nn.Module

  • emb_dim (int) – feature dimension (default: 128)

  • num_negatives (int) – queue size; number of negative keys (default: 65536)

  • encoder_momentum (float) – moco momentum of updating key encoder (default: 0.999)

  • softmax_temperature (float) – softmax temperature (default: 0.07)

  • learning_rate (float) – the learning rate

  • momentum (float) – optimizer momentum

  • weight_decay (float) – optimizer weight decay

  • datamodule (Optional[LightningDataModule]) – the DataModule (train, val, test dataloaders)

  • data_dir (str) – the directory to store data

  • batch_size (int) – batch size

  • use_mlp (bool) – add an MLP head to the encoders

  • num_workers (int) – workers for the loaders

_batch_shuffle_ddp(x)[source]

Batch shuffle, for making use of BatchNorm. Only supports DistributedDataParallel (DDP).

_batch_unshuffle_ddp(x, idx_unshuffle)[source]

Undo batch shuffle. Only supports DistributedDataParallel (DDP).

_momentum_update_key_encoder()[source]

Momentum update of the key encoder
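The update follows the momentum (EMA) rule from the MoCo paper: theta_k <- m * theta_k + (1 - m) * theta_q. A standalone sketch of that rule (the module's own method reads the momentum from its hyperparameters):

import torch

@torch.no_grad()
def momentum_update(encoder_q, encoder_k, m=0.999):
    # key encoder parameters track the query encoder as an exponential moving average
    for param_q, param_k in zip(encoder_q.parameters(), encoder_k.parameters()):
        param_k.data = param_k.data * m + param_q.data * (1.0 - m)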

forward(img_q, img_k)[source]

Input:
    img_q: a batch of query images
    img_k: a batch of key images

Output:
    logits, targets

init_encoders(base_encoder)[source]

Override to add your own encoders
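A minimal sketch of such an override, assuming init_encoders returns the query and key encoders as in the base implementation (MyBackbone is a hypothetical nn.Module whose output dimension matches emb_dim):

from pl_bolts.models.self_supervised import MocoV2

class MyMoco(MocoV2):
    def init_encoders(self, base_encoder):
        # base_encoder is ignored here; build your own backbones instead
        encoder_q = MyBackbone()   # hypothetical nn.Module, output dim = emb_dim
        encoder_k = MyBackbone()
        return encoder_q, encoder_k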

SimCLR

class pl_bolts.models.self_supervised.SimCLR(datamodule=None, data_dir='', learning_rate=6e-05, weight_decay=0.0005, input_height=32, batch_size=128, online_ft=False, num_workers=4, optimizer='lars', lr_sched_step=30.0, lr_sched_gamma=0.5, lars_momentum=0.9, lars_eta=0.001, loss_temperature=0.5, **kwargs)[source]

Bases: pytorch_lightning.LightningModule

PyTorch Lightning implementation of SimCLR

Paper authors: Ting Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey Hinton.

Model implemented by:

Example

>>> from pl_bolts.models.self_supervised import SimCLR
...
>>> model = SimCLR()

Train:

trainer = Trainer()
trainer.fit(model)

CLI command:

# cifar10
python simclr_module.py --gpus 1

# imagenet
python simclr_module.py \
    --gpus 8 \
    --dataset imagenet2012 \
    --data_dir /path/to/imagenet/ \
    --meta_dir /path/to/folder/with/meta.bin/ \
    --batch_size 32
Parameters
  • datamodule (Optional[LightningDataModule]) – The datamodule

  • data_dir (str) – directory to store data

  • learning_rate (float) – the learning rate

  • weight_decay (float) – optimizer weight decay

  • input_height (int) – image input height

  • batch_size (int) – the batch size

  • online_ft (bool) – whether to fine-tune online

  • num_workers (int) – number of workers

  • optimizer (str) – optimizer name

  • lr_sched_step (float) – step for learning rate scheduler

  • lr_sched_gamma (float) – gamma for learning rate scheduler

  • lars_momentum (float) – momentum parameter for the LARS optimizer

  • lars_eta (float) – eta (trust coefficient) for the LARS optimizer

  • loss_temperature (float) – temperature for the contrastive loss (default: 0.5)
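Putting the main knobs together, a minimal sketch of constructing SimCLR with explicit hyperparameters (the values shown are illustrative, not tuned recommendations):

from pl_bolts.models.self_supervised import SimCLR

model = SimCLR(
    input_height=32,        # e.g. cifar10 images
    batch_size=128,
    optimizer='lars',
    lars_momentum=0.9,
    loss_temperature=0.5,
)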
