
Self-supervised Learning Contrastive tasks

This section implements popular contrastive learning tasks used in self-supervised learning.


FeatureMapContrastiveTask

This task compares sets of feature maps.

In general, the feature map comparison pretext task uses triplets of features. The abstract steps of the comparison are:

Generate multiple views of the same image

x1_view_1 = data_augmentation(x1)
x1_view_2 = data_augmentation(x1)

Use a different example to generate additional views (usually within the same batch or a pool of candidates)

x2_view_1 = data_augmentation(x2)
x2_view_2 = data_augmentation(x2)
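The `data_augmentation` call above is left abstract. The hypothetical function below is a minimal sketch that assumes a single (C, H, W) image tensor and a padding smaller than H and W. It applies a random horizontal flip, then reflection-pads the image and crops it back to its original size at a random offset. Self-supervised pipelines usually add stronger augmentations, such as color jitter.

```python
import torch
import torch.nn.functional as F


def data_augmentation(x: torch.Tensor, pad: int = 4) -> torch.Tensor:
    """Hypothetical augmentation for one (C, H, W) image; assumes pad < H and pad < W."""
    _, h, w = x.shape
    # random horizontal flip with probability 0.5
    if torch.rand(()) < 0.5:
        x = torch.flip(x, dims=[2])
    # reflection-pad every side, then crop an (H, W) window at a random offset
    padded = F.pad(x.unsqueeze(0), (pad, pad, pad, pad), mode="reflect").squeeze(0)
    top = int(torch.randint(0, 2 * pad + 1, ()))
    left = int(torch.randint(0, 2 * pad + 1, ()))
    return padded[:, top:top + h, left:left + w]


x1 = torch.rand(3, 32, 32)
x1_view_1 = data_augmentation(x1)
x1_view_2 = data_augmentation(x1)  # a second, independently sampled view
```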

Pick three views to compare: the anchor, the positive and the negative

anchor = x1_view_1
positive = x1_view_2
negative = x2_view_1
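When the negatives come from the same batch, a common trick is to reuse the other examples. Here is a minimal sketch. It assumes two (B, ...) tensors that hold the two views of every example, with B of at least 2. The helper name `make_triplets` is hypothetical.

```python
import torch


def make_triplets(view_1: torch.Tensor, view_2: torch.Tensor):
    """view_1, view_2: (B, ...) tensors; row i of each is a view of example i."""
    anchor = view_1
    positive = view_2  # same example, different augmentation
    # shift the batch by one so row i pairs with a view of example i - 1
    negative = torch.roll(view_1, shifts=1, dims=0)
    return anchor, positive, negative


views_1 = torch.rand(8, 3, 32, 32)
views_2 = torch.rand(8, 3, 32, 32)
anchor, positive, negative = make_triplets(views_1, views_2)
```

Rolling by one guarantees that the negative for example i always comes from a different example (i - 1).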

Generate feature maps for each view

(a0, a1, a2) = encoder(anchor)
(p0, p1, p2) = encoder(positive)
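The pseudocode assumes an encoder that returns several feature maps. The hypothetical `ToyEncoder` below is a minimal sketch of one, assuming 3-channel inputs. It has three strided convolution stages, and each stage halves the spatial size. A 1x1 projection per stage gives every map the same channel count (`emb_dim`, a made-up parameter). Matching channel counts are what let maps from different stages, such as a0 and p1, be scored against each other.

```python
import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Hypothetical encoder returning three feature maps of decreasing resolution."""

    def __init__(self, in_channels: int = 3, emb_dim: int = 32):
        super().__init__()
        stages, heads = [], []
        prev = in_channels
        for width in (16, 32, 64):
            # each stage halves the spatial resolution
            stages.append(nn.Sequential(
                nn.Conv2d(prev, width, kernel_size=3, stride=2, padding=1),
                nn.ReLU(inplace=True),
            ))
            # 1x1 projection so every returned map has emb_dim channels
            heads.append(nn.Conv2d(width, emb_dim, kernel_size=1))
            prev = width
        self.stages = nn.ModuleList(stages)
        self.heads = nn.ModuleList(heads)

    def forward(self, x: torch.Tensor):
        maps = []
        for stage, head in zip(self.stages, self.heads):
            x = stage(x)
            maps.append(head(x))
        return tuple(maps)


encoder = ToyEncoder()
anchor = torch.rand(8, 3, 32, 32)
positive = torch.rand(8, 3, 32, 32)
(a0, a1, a2) = encoder(anchor)    # (8, 32, 16, 16), (8, 32, 8, 8), (8, 32, 4, 4)
(p0, p1, p2) = encoder(positive)
```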

Make a comparison for a set of feature maps

phi = some_score_function()

# the '01' comparison
score = phi(a0, p1)

# and can be bidirectional
score = phi(p0, a1)
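`phi` is left abstract above. The hypothetical function below is one hedged illustration, not the scoring used inside pl_bolts. For each example, it samples one spatial location of the first map. It takes the dot product of that vector with every location of the second map, for every example in the batch, and averages over locations. The result is a (B, B) score matrix. It works across maps of different resolutions as long as the channel counts match.

```python
import torch


def phi(a: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    """Hypothetical score between a (B, C, Ha, Wa) and p (B, C, Hp, Wp) -> (B, B)."""
    b, c = a.shape[:2]
    a_flat = a.flatten(2)  # (B, C, Ha*Wa)
    p_flat = p.flatten(2)  # (B, C, Hp*Wp)
    # sample one spatial location of `a` per example -> (B, C) vectors
    idx = torch.randint(0, a_flat.size(2), (b,), device=a.device)
    a_vec = a_flat.gather(2, idx.view(b, 1, 1).expand(b, c, 1)).squeeze(2)
    # dot product of every sampled vector with every location of every p map
    scores = torch.einsum("ic,jcl->ijl", a_vec, p_flat)  # (B, B, Hp*Wp)
    # average over locations and scale by sqrt(C) to keep magnitudes moderate
    return scores.mean(dim=2) / c ** 0.5


a0, a1 = torch.randn(8, 32, 16, 16), torch.randn(8, 32, 8, 8)
p0, p1 = torch.randn(8, 32, 16, 16), torch.randn(8, 32, 8, 8)
score = phi(a0, p1)      # the '01' comparison
score_rev = phi(p0, a1)  # the reverse direction
```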

In practice, the contrastive task builds a BxB score matrix, where B is the batch size. Each diagonal entry pairs an anchor (from feature-map set 1) with its positive (from feature-map set 2). The off-diagonal entries pair the anchor with features of the other examples in the batch, which act as negatives.
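You can read each row of the BxB matrix as a B-way classification problem whose correct class is the diagonal entry. This gives the familiar InfoNCE (cross-entropy) form of the loss. The sketch below assumes one embedding vector per example for each set. The name `batch_contrastive_loss` is hypothetical, and pl_bolts does not necessarily perform exactly this computation.

```python
import torch
import torch.nn.functional as F


def batch_contrastive_loss(anchors: torch.Tensor, positives: torch.Tensor) -> torch.Tensor:
    """anchors, positives: (B, D); row i of each comes from the same example."""
    scores = anchors @ positives.t()  # (B, B) score matrix
    # the correct "class" for row i is column i (its positive);
    # every other column in that row is a negative
    targets = torch.arange(scores.size(0), device=scores.device)
    return F.cross_entropy(scores, targets)


anchors = torch.randn(8, 128)
positives = anchors + 0.1 * torch.randn(8, 128)  # stand-in for a second view's embeddings
loss = batch_contrastive_loss(anchors, positives)
```

When all scores are equal, the loss is log(B). It approaches zero as the diagonal scores come to dominate their rows.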

class pl_bolts.losses.self_supervised_learning.FeatureMapContrastiveTask(comparisons='00, 11', tclip=10.0, bidirectional=True)

Bases: torch.nn.Module

Performs an anchor/positive/negative comparison for each tuple of feature maps passed in.

from pl_bolts.losses.self_supervised_learning import FeatureMapContrastiveTask

# extract feature maps (encoder, x_pos and x_anchor stand for your own model and inputs)
pos_0, pos_1, pos_2 = encoder(x_pos)
anc_0, anc_1, anc_2 = encoder(x_anchor)

# compare only the 0th feature maps (the trailing commas make these tuples)
task = FeatureMapContrastiveTask('00')
loss, regularizer = task((pos_0,), (anc_0,))

# compare (pos_0, anc_1) and (pos_0, anc_2)
task = FeatureMapContrastiveTask('01, 02')
losses, regularizer = task((pos_0, pos_1, pos_2), (anc_0, anc_1, anc_2))
loss = losses.sum()

# compare pos_0 with a randomly chosen anchor map
task = FeatureMapContrastiveTask('0r')
loss, regularizer = task((pos_0, pos_1, pos_2), (anc_0, anc_1, anc_2))
Parameters
  • comparisons (str) – comma-separated pairs of feature-map indices to compare (zero-indexed; ‘r’ means a randomly chosen map), e.g. ‘00, 1r’

  • tclip (float) – clipping value used for numerical stability

  • bidirectional (bool) – if True, each comparison is also performed in the reverse direction

# with bidirectional the comparisons are done both ways
task = FeatureMapContrastiveTask('01, 02')

# will compare the following:
# 01: (pos_0, anc_1), (anc_0, pos_1)
# 02: (pos_0, anc_2), (anc_0, pos_2)
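Regarding `tclip`: AMDIM, for example, soft-clips its scores with tanh so that very large dot products cannot dominate the softmax. The snippet below is a minimal sketch of that function. It is only an assumption that `tclip` plays the role of `clip_val` in exactly this way.

```python
import torch


def tanh_clip(x: torch.Tensor, clip_val: float = 10.0) -> torch.Tensor:
    # bounds the magnitude of every score by clip_val, smoothly;
    # approximately the identity when |x| is much smaller than clip_val
    return clip_val * torch.tanh(x / clip_val)


scores = torch.tensor([-1000.0, 0.0, 0.5, 1000.0])
clipped = tanh_clip(scores, clip_val=10.0)
```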
forward(anchor_maps, positive_maps)

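Takes two tuples of feature maps, anchor_maps and positive_maps; the maps selected by each comparison must have matching dimensions. Returns a tensor with one loss per comparison, plus a regularization term.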

Example

>>> import torch
>>> from pytorch_lightning import seed_everything
>>> from pl_bolts.losses.self_supervised_learning import FeatureMapContrastiveTask
>>> seed_everything(0)
0
>>> a1 = torch.rand(3, 5, 2, 2)
>>> a2 = torch.rand(3, 5, 2, 2)
>>> b1 = torch.rand(3, 5, 2, 2)
>>> b2 = torch.rand(3, 5, 2, 2)
>>> task = FeatureMapContrastiveTask('01, 11')
>>> losses, regularizer = task((a1, a2), (b1, b2))
>>> losses
tensor([2.2351, 2.1902])
>>> regularizer
tensor(0.0324)
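The sketch below makes the role of the comparison string and `bidirectional` concrete. It is a simplified, assumption-laden illustration of the loop `forward` could run, not the library's code. `run_comparisons` and `compare` are hypothetical, and `compare(x, y)` is assumed to return a `(loss, regularizer)` pair for two maps. Index -1 stands for a randomly chosen map. The two tuples are assumed to have the same length. The reverse direction follows the convention in the comments above: for '01', (pos_0, anc_1) and then (anc_0, pos_1). Each comparison contributes one entry to the returned losses, matching the two values in the example output.

```python
import random
from typing import Callable, List, Sequence, Tuple

import torch

Compare = Callable[[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]


def run_comparisons(
    maps_a: Sequence[torch.Tensor],
    maps_b: Sequence[torch.Tensor],
    pairs: List[Tuple[int, int]],
    compare: Compare,
    bidirectional: bool = True,
):
    losses = []
    regularizer = torch.tensor(0.0)
    for i, j in pairs:
        # -1 means "pick a random feature map"
        if i == -1:
            i = random.randrange(len(maps_a))
        if j == -1:
            j = random.randrange(len(maps_b))
        loss, reg = compare(maps_a[i], maps_b[j])
        if bidirectional:
            # reverse direction: swap which tuple each index is read from
            loss_rev, reg_rev = compare(maps_b[i], maps_a[j])
            loss = 0.5 * (loss + loss_rev)
            reg = 0.5 * (reg + reg_rev)
        losses.append(loss)
        regularizer = regularizer + reg
    # one loss per comparison, plus a single summed regularizer
    return torch.stack(losses), regularizer


# toy compare: difference of means, with a constant regularizer
def toy_compare(x, y):
    return (x - y).mean(), torch.tensor(1.0)


pos = (torch.full((2,), 1.0), torch.full((2,), 2.0))
anc = (torch.full((2,), 10.0), torch.full((2,), 20.0))
losses, regularizer = run_comparisons(pos, anc, [(0, 1), (1, 1)], toy_compare)
```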
static parse_map_indexes(comparisons)

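Parses a comparison string into a list of (first, second) index tuples, with ‘r’ (random) encoded as -1.

Example: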

>>> FeatureMapContrastiveTask.parse_map_indexes('11')
[(1, 1)]
>>> FeatureMapContrastiveTask.parse_map_indexes('11,59')
[(1, 1), (5, 9)]
>>> FeatureMapContrastiveTask.parse_map_indexes('11,59, 2r')
[(1, 1), (5, 9), (2, -1)]
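The doctests above pin down the parsing rules. Comparisons are separated by commas, and surrounding spaces are ignored. Each comparison has two characters, and 'r' becomes -1. Below is a minimal re-implementation consistent with those outputs; it is a sketch, not the library source. Each comparison is read as two single characters, so indices are limited to 0-9. The length check is a hypothetical addition.

```python
from typing import List, Tuple


def parse_map_indexes(comparisons: str) -> List[Tuple[int, int]]:
    pairs = []
    for token in comparisons.split(","):
        token = token.strip()
        # hypothetical validation: each comparison must be exactly two characters
        if len(token) != 2:
            raise ValueError(f"expected two characters per comparison, got {token!r}")
        # 'r' (random map) is encoded as -1, digits are parsed as indices
        pairs.append(tuple(-1 if ch == "r" else int(ch) for ch in token))
    return pairs


print(parse_map_indexes("11,59, 2r"))
```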

Context prediction tasks

The following tasks aim to predict a target using a context representation.

CPCTask

This is the predictive task from CPC (v2).

import torch
from pl_bolts.losses.self_supervised_learning import CPCTask

task = CPCTask(num_input_channels=32)

# (batch, channels, rows, cols)
# each image is a 7x7 grid: 49 feature vectors with 32 dims each
Z = torch.rand(3, 32, 7, 7)

loss = task(Z)
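The sketch below is a heavily simplified, hypothetical illustration of the idea behind the predictive task. It is not how `CPCTask` is implemented. `ToyCPCTask` and its `max_steps` parameter are made up. The context for each row is a cumulative mean over that row and the rows above it, a stand-in for CPC's autoregressive context network. `target_dim` and `embed_scale` are not modeled. Each context vector predicts the vector k rows below it through a per-offset linear map. Every other target vector in the batch serves as a negative in an InfoNCE (cross-entropy) loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyCPCTask(nn.Module):
    """Hypothetical CPC-style predictive loss on a (B, C, H, W) grid of vectors."""

    def __init__(self, num_input_channels: int, max_steps: int = 3):
        super().__init__()
        # one 1x1 convolution (a linear map applied at every location) per offset k
        self.predictors = nn.ModuleList(
            [nn.Conv2d(num_input_channels, num_input_channels, kernel_size=1) for _ in range(max_steps)]
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        _, c, h, _ = z.shape
        # context[:, :, r] is the mean of rows 0..r, so it never sees rows below r
        counts = torch.arange(1, h + 1, device=z.device, dtype=z.dtype).view(1, 1, h, 1)
        context = z.cumsum(dim=2) / counts
        losses = []
        for k, predictor in enumerate(self.predictors, start=1):
            if k >= h:
                break  # no targets left k rows below
            preds = predictor(context[:, :, : h - k])  # context row r predicts row r + k
            targets = z[:, :, k:]
            # flatten every (image, row, col) position into one row: (N, C)
            preds = preds.permute(0, 2, 3, 1).reshape(-1, c)
            targets = targets.permute(0, 2, 3, 1).reshape(-1, c)
            logits = preds @ targets.t()  # (N, N); the diagonal holds the true targets
            labels = torch.arange(logits.size(0), device=z.device)
            losses.append(F.cross_entropy(logits, labels))
        if not losses:
            raise ValueError("need at least two rows to predict anything")
        return torch.stack(losses).mean()


task = ToyCPCTask(num_input_channels=32)
Z = torch.rand(3, 32, 7, 7)
loss = task(Z)
loss.backward()
```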
class pl_bolts.losses.self_supervised_learning.CPCTask(num_input_channels, target_dim=64, embed_scale=0.1)

Bases: torch.nn.Module

Loss used in CPC
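For reference, CPC trains with the InfoNCE objective. The quantities are as follows. $c_t$ is the context. $z_j$ is the encoding of candidate $x_j$. The set $X$ holds one positive $x_{t+k}$ and $N-1$ negatives. $W_k$ is a learned matrix for each prediction offset $k$. The loss is:

```latex
\mathcal{L}_N = -\,\mathbb{E}_{X}\left[\log \frac{f_k(x_{t+k}, c_t)}{\sum_{x_j \in X} f_k(x_j, c_t)}\right],
\qquad f_k(x_j, c_t) = \exp\left(z_j^{\top} W_k c_t\right)
```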
