
Vision DataModules

The following are pre-built datamodules for computer vision.


Supervised learning

These are standard vision datasets with the train, val, and test splits pre-generated as DataLoaders, using the standard transform (and normalization) values.
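
For example, a minimal sketch of the typical workflow (shown here with the MNIST datamodule documented below; the other supervised datamodules follow the same pattern):

from pl_bolts.datamodules import MNISTDataModule

dm = MNISTDataModule(data_dir='.')

# Download/prepare the data (called once, even in distributed settings)
dm.prepare_data()

# Each split is exposed as a ready-made DataLoader
train_loader = dm.train_dataloader(batch_size=32)
val_loader = dm.val_dataloader(batch_size=32)
test_loader = dm.test_dataloader(batch_size=32)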

MNIST

class pl_bolts.datamodules.mnist_datamodule.MNISTDataModule(data_dir, val_split=5000, num_workers=16, normalize=False, *args, **kwargs)[source]

Bases: pl_bolts.datamodules.lightning_datamodule.LightningDataModule

Standard MNIST with train, val, and test splits and transforms.

Transforms:

mnist_transforms = transform_lib.Compose([
    transform_lib.ToTensor()
])

Example:

from pl_bolts.datamodules import MNISTDataModule

dm = MNISTDataModule('.')
model = LitModel(datamodule=dm)
Parameters
  • data_dir (str) – where to save/load the data

  • val_split (int) – how many of the training images to use for the validation split

  • num_workers (int) – how many workers to use for loading data

  • normalize (bool) – if True, applies image normalization

prepare_data()[source]

Saves MNIST files to data_dir

test_dataloader(batch_size=32, transforms=None)[source]

MNIST test set uses the test split

Parameters
  • batch_size – size of batch

  • transforms – custom transforms

train_dataloader(batch_size=32, transforms=None)[source]

MNIST train set removes a subset to use for validation

Parameters
  • batch_size – size of batch

  • transforms – custom transforms

val_dataloader(batch_size=32, transforms=None)[source]

MNIST val set uses a subset of the training set for validation

Parameters
  • batch_size – size of batch

  • transforms – custom transforms

property num_classes[source]

Return: 10
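
Each dataloader also accepts custom transforms via its transforms argument; a minimal sketch, assuming torchvision is installed (the augmentation here is an illustrative choice, not a default):

from torchvision import transforms as transform_lib

from pl_bolts.datamodules import MNISTDataModule

dm = MNISTDataModule('.')
dm.prepare_data()

# Pass custom transforms for the train split only
my_transforms = transform_lib.Compose([
    transform_lib.RandomRotation(10),  # illustrative augmentation
    transform_lib.ToTensor(),
])
train_loader = dm.train_dataloader(batch_size=64, transforms=my_transforms)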

FashionMNIST

class pl_bolts.datamodules.fashion_mnist_datamodule.FashionMNISTDataModule(data_dir, val_split=5000, num_workers=16, *args, **kwargs)[source]

Bases: pl_bolts.datamodules.lightning_datamodule.LightningDataModule

Standard FashionMNIST with train, val, and test splits and transforms.

Transforms:

fashion_mnist_transforms = transform_lib.Compose([
    transform_lib.ToTensor()
])

Example:

from pl_bolts.datamodules import FashionMNISTDataModule

dm = FashionMNISTDataModule('.')
model = LitModel(datamodule=dm)
Parameters
  • data_dir (str) – where to save/load the data

  • val_split (int) – how many of the training images to use for the validation split

  • num_workers (int) – how many workers to use for loading data

prepare_data()[source]

Saves FashionMNIST files to data_dir

test_dataloader(batch_size=32, transforms=None)[source]

FashionMNIST test set uses the test split

Parameters
  • batch_size – size of batch

  • transforms – custom transforms

train_dataloader(batch_size=32, transforms=None)[source]

FashionMNIST train set removes a subset to use for validation

Parameters
  • batch_size – size of batch

  • transforms – custom transforms

val_dataloader(batch_size=32, transforms=None)[source]

FashionMNIST val set uses a subset of the training set for validation

Parameters
  • batch_size – size of batch

  • transforms – custom transforms

property num_classes[source]

Return: 10

CIFAR-10

class pl_bolts.datamodules.cifar10_datamodule.CIFAR10DataModule(data_dir, val_split=5000, num_workers=16, *args, **kwargs)[source]

Bases: pl_bolts.datamodules.lightning_datamodule.LightningDataModule

Standard CIFAR10 with train, val, and test splits and transforms.

Transforms:

cifar10_transforms = transform_lib.Compose([
    transform_lib.ToTensor(),
    transform_lib.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]]
    )
])

Example:

from pl_bolts.datamodules import CIFAR10DataModule

dm = CIFAR10DataModule(PATH)
model = LitModel(datamodule=dm)

Or you can set your own transforms

Example:

dm.train_transforms = ...
dm.test_transforms = ...
dm.val_transforms  = ...
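
For instance, a minimal sketch of assigning your own torchvision transforms (the augmentations are illustrative choices, not defaults):

from torchvision import transforms as transform_lib

dm.train_transforms = transform_lib.Compose([
    transform_lib.RandomCrop(32, padding=4),
    transform_lib.RandomHorizontalFlip(),
    transform_lib.ToTensor(),
])
dm.val_transforms = transform_lib.Compose([transform_lib.ToTensor()])
dm.test_transforms = transform_lib.Compose([transform_lib.ToTensor()])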
Parameters
  • data_dir – where to save/load the data

  • val_split – how many of the training images to use for the validation split

  • num_workers – how many workers to use for loading data

prepare_data()[source]

Saves CIFAR10 files to data_dir

test_dataloader(batch_size)[source]

CIFAR10 test set uses the test split

Parameters

batch_size – size of batch

train_dataloader(batch_size)[source]

CIFAR10 train set removes a subset to use for validation

Parameters

batch_size – size of batch

val_dataloader(batch_size)[source]

CIFAR10 val set uses a subset of the training set for validation

Parameters

batch_size – size of batch

property num_classes[source]

Return: 10

Imagenet

class pl_bolts.datamodules.imagenet_datamodule.ImagenetDataModule(data_dir, meta_dir=None, num_imgs_per_val_class=50, image_size=224, num_workers=16, *args, **kwargs)[source]

Bases: pl_bolts.datamodules.lightning_datamodule.LightningDataModule

Imagenet train, val and test dataloaders.

The train set is the imagenet train split.

The val set is taken from the train set, with num_imgs_per_val_class images per class. For example, if num_imgs_per_val_class=2 then there will be 2,000 images in the validation set.

The test set is the official imagenet validation set.

Example:

from pl_bolts.datamodules import ImagenetDataModule

datamodule = ImagenetDataModule(IMAGENET_PATH)
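
A slightly fuller sketch using the documented constructor arguments (IMAGENET_PATH is a placeholder for your local copy of imagenet2012):

from pl_bolts.datamodules import ImagenetDataModule

datamodule = ImagenetDataModule(
    data_dir=IMAGENET_PATH,      # root of the already-downloaded imagenet2012 data
    num_imgs_per_val_class=50,   # 50 images/class -> 50,000 validation images
    image_size=224,
    num_workers=16,
)
datamodule.prepare_data()  # validates the data via meta.bin; does not download
train_loader = datamodule.train_dataloader(batch_size=128)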
Parameters
  • data_dir (str) – path to the imagenet dataset file

  • meta_dir (Optional[str]) – path to meta.bin file

  • num_imgs_per_val_class (int) – how many images per class for the validation set

  • image_size (int) – final image size

  • num_workers (int) – how many data workers

prepare_data()[source]

This method already assumes you have imagenet2012 downloaded. It validates the data using the meta.bin.

Warning

Please download imagenet on your own first.

test_dataloader(batch_size, num_images_per_class=-1, transforms=None)[source]

Uses the validation split of imagenet2012 for testing

Parameters
  • batch_size – the batch size

  • num_images_per_class – how many images per class to test on

  • transforms – the transforms

train_dataloader(batch_size)[source]

Uses the train split of imagenet2012, setting aside a portion of it for the validation split

Parameters

batch_size – the batch size

train_transform()[source]

The standard imagenet transforms

transform_lib.Compose([
    transform_lib.RandomResizedCrop(self.image_size),
    transform_lib.RandomHorizontalFlip(),
    transform_lib.ToTensor(),
    transform_lib.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])
val_dataloader(batch_size, transforms=None)[source]

Uses the portion of the imagenet2012 train split that was set aside for validation via num_imgs_per_val_class

Parameters
  • batch_size – the batch size

  • transforms – the transforms

val_transform()[source]

The standard imagenet transforms for validation

transform_lib.Compose([
    transform_lib.Resize(self.image_size + 32),
    transform_lib.CenterCrop(self.image_size),
    transform_lib.ToTensor(),
    transform_lib.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])
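
The standard val transform can also be swapped at load time via the documented transforms argument; a minimal sketch (the resize/crop values are illustrative):

from torchvision import transforms as transform_lib

eval_transforms = transform_lib.Compose([
    transform_lib.Resize(256),
    transform_lib.CenterCrop(224),
    transform_lib.ToTensor(),
    transform_lib.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])
val_loader = datamodule.val_dataloader(batch_size=64, transforms=eval_transforms)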
property num_classes[source]

Return: 1000


Semi-supervised learning

The following datasets have support for unlabeled training and semi-supervised learning, where only a few examples are labeled.

Imagenet (ssl)

class pl_bolts.datamodules.ssl_imagenet_datamodule.SSLImagenetDataModule(data_dir, meta_dir=None, num_workers=16, *args, **kwargs)[source]

Bases: pl_bolts.datamodules.lightning_datamodule.LightningDataModule

prepare_data()[source]

Use this to download and prepare data. In distributed settings (multi-GPU, TPU), this will only be called once. It is called before requesting the dataloaders:

Warning

Do not assign anything to the model in this step since this will only be called on 1 GPU.

Pseudocode:

model.prepare_data()
model.train_dataloader()
model.val_dataloader()
model.test_dataloader()

Example:

def prepare_data(self):
    download_imagenet()
    clean_imagenet()
    cache_imagenet()
test_dataloader(batch_size, num_images_per_class, add_normalize=False)[source]

Implement a PyTorch DataLoader for testing.

Returns

Single PyTorch DataLoader.

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

You can also return a list of DataLoaders

Example:

def test_dataloader(self):
    dataset = MNIST(root=PATH, train=False, transform=transforms.ToTensor(), download=False)
    loader = torch.utils.data.DataLoader(dataset=dataset, shuffle=False)
    return loader
train_dataloader(batch_size, num_images_per_class=-1, add_normalize=False)[source]

Implement a PyTorch DataLoader for training.

Returns

Single PyTorch DataLoader.

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Example:

def train_dataloader(self):
    dataset = MNIST(root=PATH, train=True, transform=transforms.ToTensor(), download=False)
    loader = torch.utils.data.DataLoader(dataset=dataset)
    return loader
val_dataloader(batch_size, num_images_per_class=50, add_normalize=False)[source]

Implement a PyTorch DataLoader for validation.

Returns

Single PyTorch DataLoader.

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

You can also return a list of DataLoaders

Example:

def val_dataloader(self):
    dataset = MNIST(root=PATH, train=False, transform=transforms.ToTensor(), download=False)
    loader = torch.utils.data.DataLoader(dataset=dataset, shuffle=False)
    return loader
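
A minimal sketch of typical usage, based on the documented signatures (IMAGENET_PATH is a placeholder; the per-class counts are illustrative):

from pl_bolts.datamodules import SSLImagenetDataModule

dm = SSLImagenetDataModule(data_dir=IMAGENET_PATH)
dm.prepare_data()

# Semi-supervised finetuning: keep only a few labeled images per class
train_loader = dm.train_dataloader(batch_size=128, num_images_per_class=13)
val_loader = dm.val_dataloader(batch_size=128, num_images_per_class=50)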

STL-10

class pl_bolts.datamodules.stl10_datamodule.STL10DataModule(data_dir, unlabeled_val_split=5000, train_val_split=500, num_workers=16, *args, **kwargs)[source]

Bases: pl_bolts.datamodules.lightning_datamodule.LightningDataModule

Standard STL-10 with train, val, and test splits and transforms. STL-10 supports making validation splits from either the labeled or the unlabeled split.

Transforms:

stl10_transforms = transform_lib.Compose([
    transform_lib.ToTensor(),
    transform_lib.Normalize(
        mean=(0.43, 0.42, 0.39),
        std=(0.27, 0.26, 0.27)
    )
])

Example:

from pl_bolts.datamodules import STL10DataModule

dm = STL10DataModule(PATH)
model = LitModel(datamodule=dm)
Parameters
  • data_dir (str) – where to save/load the data

  • unlabeled_val_split (int) – how many images from the unlabeled training split to use for validation

  • train_val_split (int) – how many images from the labeled training split to use for validation

  • num_workers (int) – how many workers to use for loading data

prepare_data()[source]

Downloads the unlabeled, train and test split

test_dataloader(batch_size)[source]

Loads the test split of STL10

Parameters

batch_size – the batch size

train_dataloader(batch_size)[source]

Loads the ‘unlabeled’ split minus a portion set aside for validation via unlabeled_val_split.

Parameters

batch_size – the batch size

train_dataloader_mixed(batch_size)[source]

Loads a portion of the ‘unlabeled’ training data and the ‘train’ (labeled) data. Both portions have a subset removed for validation via unlabeled_val_split and train_val_split.

Parameters

batch_size – the batch size

val_dataloader(batch_size)[source]

Loads the portion of the ‘unlabeled’ training data that was set aside for validation via unlabeled_val_split.

Parameters

batch_size – the batch size

val_dataloader_mixed(batch_size)[source]

Loads the portion of the ‘unlabeled’ training data set aside for validation, along with the portion of the ‘train’ split set aside for validation:

unlabeled_val = the unlabeled_val_split images held out from the ‘unlabeled’ split

labeled_val = the train_val_split images held out from the ‘train’ split

full_val = unlabeled_val + labeled_val

Parameters

batch_size – the batch size
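
A minimal sketch of a semi-supervised setup using the documented loaders (PATH is a placeholder):

from pl_bolts.datamodules import STL10DataModule

dm = STL10DataModule(PATH)
dm.prepare_data()  # downloads the unlabeled, train and test splits

# pretrain on unlabeled data; validate on the mixed labeled+unlabeled split
pretrain_loader = dm.train_dataloader(batch_size=128)
finetune_loader = dm.train_dataloader_mixed(batch_size=128)
val_loader = dm.val_dataloader_mixed(batch_size=128)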
