This section houses autoencoders and variational autoencoders.

Basic AE

This is the simplest autoencoder. You can use it like so:

from pytorch_lightning import Trainer
from pl_bolts.models.autoencoders import AE

model = AE()
trainer = Trainer()
trainer.fit(model)

You can override any part of this AE to build your own variation.

from pl_bolts.models.autoencoders import AE

class MyAEFlavor(AE):

    def init_encoder(self, hidden_dim, latent_dim, input_width, input_height):
        encoder = YourSuperFancyEncoder(...)
        return encoder
class pl_bolts.models.autoencoders.AE(datamodule=None, input_channels=1, input_height=28, input_width=28, latent_dim=32, batch_size=32, hidden_dim=128, learning_rate=0.001, num_workers=8, data_dir='.', **kwargs)[source]

Bases: pytorch_lightning.LightningModule


  • datamodule – the datamodule (train, val, test splits)

  • input_channels – num of image channels

  • input_height – image height

  • input_width – image width

  • latent_dim – embedding dim for the encoder

  • batch_size – the batch size

  • hidden_dim – the encoder hidden dim

  • learning_rate – the learning rate

  • num_workers – num of dataloader workers

  • data_dir – where to store data
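The AE's shape contract can be sketched in plain PyTorch. This is a minimal illustration, not the pl_bolts implementation; the layer sizes mirror the defaults above (28x28 single-channel input, hidden_dim=128, latent_dim=32), and TinyAE is a hypothetical name:

```python
import torch
from torch import nn

class TinyAE(nn.Module):
    """Minimal fully-connected autoencoder (illustration only).

    Mirrors the default hyperparameters above: flattened 28x28
    single-channel input, hidden_dim=128, latent_dim=32.
    """

    def __init__(self, input_dim=28 * 28, hidden_dim=128, latent_dim=32):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )

    def forward(self, x):
        z = self.encoder(x)        # compress to the latent code
        return self.decoder(z), z  # reconstruction and code

x = torch.randn(32, 28 * 28)       # a fake batch of flattened images
model = TinyAE()
x_hat, z = model(x)
loss = nn.functional.mse_loss(x_hat, x)  # reconstruction objective
```

The batch is compressed to a 32-dimensional code and decoded back to the input shape; training minimizes the reconstruction loss.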

Variational Autoencoders

Basic VAE

Use the VAE like so:

from pytorch_lightning import Trainer
from pl_bolts.models.autoencoders import VAE

model = VAE()
trainer = Trainer()
trainer.fit(model)

You can override any part of this VAE to build your own variation.

import torch

from pl_bolts.models.autoencoders import VAE

class MyVAEFlavor(VAE):

    def get_posterior(self, mu, std):
        # do something other than the default, e.g. ignore the encoder
        # outputs and use a standard-normal distribution
        P = self.get_distribution(self.prior, loc=torch.zeros_like(mu), scale=torch.ones_like(std))
        return P
class pl_bolts.models.autoencoders.VAE(hidden_dim=128, latent_dim=32, input_channels=3, input_width=224, input_height=224, batch_size=32, learning_rate=0.001, data_dir='.', datamodule=None, num_workers=8, pretrained=None, **kwargs)[source]

Bases: pytorch_lightning.LightningModule

Standard VAE with a Gaussian prior and approximate posterior.
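The phrase "Gaussian prior and approximate posterior" can be made concrete with the reparameterization trick: the encoder predicts mu and std, a latent is sampled as z = mu + std * eps, and a KL term keeps q(z|x) close to N(0, I). A minimal sketch using torch.distributions (illustration only, not the pl_bolts code):

```python
import torch
from torch.distributions import Normal, kl_divergence

mu = torch.zeros(4, 32)   # pretend encoder outputs for a batch of 4
std = torch.ones(4, 32)

posterior = Normal(mu, std)  # approximate posterior q(z|x)
prior = Normal(torch.zeros_like(mu), torch.ones_like(std))  # p(z) = N(0, I)

# Reparameterization trick: the sample stays differentiable w.r.t. mu and std
z = posterior.rsample()

# KL(q || p), summed over latent dims and averaged over the batch
kl = kl_divergence(posterior, prior).sum(dim=1).mean()
```

Here mu = 0 and std = 1, so the posterior matches the prior and the KL term is zero; during training the encoder's outputs move it away, and the KL penalty trades off against reconstruction quality.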

The model is available pretrained on different datasets:


# not pretrained
vae = VAE()

# pretrained on imagenet
vae = VAE(pretrained='imagenet')

# pretrained on cifar10
vae = VAE(pretrained='cifar10')
  • hidden_dim (int) – encoder and decoder hidden dims

  • latent_dim (int) – latent code dim

  • input_channels (int) – num of channels of the input image

  • input_width (int) – image input width

  • input_height (int) – image input height

  • batch_size (int) – the batch size

  • learning_rate (float) – the learning rate

  • data_dir (str) – the directory to store data

  • datamodule (Optional[LightningDataModule]) – the Lightning DataModule

  • num_workers (int) – num of dataloader workers

  • pretrained (Optional[str]) – load weights pretrained on a dataset

Read the Docs v: 0.1.1