
Autoencoders

This section houses autoencoders and variational autoencoders.


Basic AE

This is the simplest autoencoder. You can use it like so:

from pl_bolts.models.autoencoders import AE
from pytorch_lightning import Trainer

model = AE(input_height=32)  # input_height is required (e.g. 32 for CIFAR-10-sized images)
trainer = Trainer()
trainer.fit(model)
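
A minimal end-to-end sketch of training the AE on CIFAR-10, assuming the CIFAR10DataModule from pl_bolts.datamodules is available in your install (any DataLoader of image batches works just as well):

from pl_bolts.models.autoencoders import AE
from pl_bolts.datamodules import CIFAR10DataModule
from pytorch_lightning import Trainer

# sketch: assumes CIFAR10DataModule is available in your pl_bolts install
dm = CIFAR10DataModule('.')        # downloads CIFAR-10 into the current directory
model = AE(input_height=32)        # CIFAR-10 images are 32x32
trainer = Trainer(max_epochs=1)
trainer.fit(model, datamodule=dm)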

You can override any part of this AE to build your own variation.

from pl_bolts.models.autoencoders import AE

class MyAEFlavor(AE):

    def init_encoder(self, hidden_dim, latent_dim, input_width, input_height):
        encoder = YourSuperFancyEncoder(...)
        return encoder

You can also use the pretrained models that ship with Bolts.

CIFAR-10 pretrained model:

from pl_bolts.models.autoencoders import AE

ae = AE(input_height=32)
print(AE.pretrained_weights_available())
ae = ae.from_pretrained('cifar10-resnet18')

ae.freeze()
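
Once frozen, the pretrained AE can be used as a fixed feature extractor or to reconstruct images. A minimal sketch, assuming ae.encoder returns the encoder features and calling the model returns reconstructions (attribute names follow the class below; double-check against your pl_bolts version):

import torch

# sketch: use the frozen, pretrained AE for features and reconstructions
x = torch.rand(4, 3, 32, 32)   # stand-in for a batch of normalized CIFAR-10 images

with torch.no_grad():
    feats = ae.encoder(x)      # encoder features, shape (4, enc_out_dim)
    recon = ae(x)              # reconstructed images, shape (4, 3, 32, 32)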

Training:

(training loss curve)

Reconstructions:

Both input and reconstructed images are shown in normalized form, since the model was trained on normalized images.

(input vs. reconstruction image grids)

class pl_bolts.models.autoencoders.AE(input_height, enc_type='resnet18', first_conv=False, maxpool1=False, enc_out_dim=512, latent_dim=256, lr=0.0001, **kwargs)[source]

Bases: pytorch_lightning.LightningModule

Standard AE

Model is available pretrained on different datasets:

Example:

# not pretrained
ae = AE(input_height=32)

# pretrained on cifar10
ae = AE.from_pretrained('cifar10-resnet18')
Parameters
  • input_height (int) – height of the images

  • enc_type (str) – option between resnet18 or resnet50

  • first_conv (bool) – if True, use the standard kernel_size 7, stride 2 first conv; if False, replace it with a kernel_size 3, stride 1 conv

  • maxpool1 (bool) – if True, use the standard maxpool to halve the spatial dimension of the features

  • enc_out_dim (int) – set according to the output feature dimension of the encoder used (512 for resnet18, 2048 for resnet50)

  • latent_dim (int) – dimension of the latent space

  • lr (float) – learning rate for Adam
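
For example, switching encoders means matching enc_out_dim to the chosen backbone. A short sketch using only the parameters documented above:

from pl_bolts.models.autoencoders import AE

# sketch: enc_out_dim must match the encoder
# (512 for resnet18, 2048 for resnet50, per the parameter list above)
ae_resnet18 = AE(input_height=32, enc_type='resnet18', enc_out_dim=512)
ae_resnet50 = AE(input_height=32, enc_type='resnet50', enc_out_dim=2048)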


Variational Autoencoders

Basic VAE

Use the VAE like so:

from pl_bolts.models.autoencoders import VAE
from pytorch_lightning import Trainer

model = VAE(input_height=32)  # input_height is required (e.g. 32 for CIFAR-10-sized images)
trainer = Trainer()
trainer.fit(model)

You can override any part of this VAE to build your own variation.

import torch
from pl_bolts.models.autoencoders import VAE

class MyVAEFlavor(VAE):

    def get_posterior(self, mu, std):
        # do something other than the default, e.g. build the posterior
        # directly from the predicted mu and std
        P = torch.distributions.Normal(loc=mu, scale=std)
        return P

You can also use the pretrained models that ship with Bolts.

CIFAR-10 pretrained model:

from pl_bolts.models.autoencoders import VAE

vae = VAE(input_height=32)
print(VAE.pretrained_weights_available())
vae = vae.from_pretrained('cifar10-resnet18')

vae.freeze()
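
Once frozen, the pretrained VAE can be used to embed images or to generate new samples by decoding random latent vectors. A sketch, assuming the model exposes encoder, fc_mu, and decoder attributes with these roles (names may differ between pl_bolts versions):

import torch

# sketch: embed images and sample new ones with the frozen VAE
# assumes `vae.encoder`, `vae.fc_mu`, and `vae.decoder` exist with these roles
x = torch.rand(4, 3, 32, 32)           # stand-in for normalized CIFAR-10 images

with torch.no_grad():
    mu = vae.fc_mu(vae.encoder(x))     # latent means, shape (4, latent_dim)
    z = torch.randn(16, 256)           # random latents (latent_dim defaults to 256)
    samples = vae.decoder(z)           # decoded images, shape (16, 3, 32, 32)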

Training:

(reconstruction loss and KL divergence curves)

Reconstructions:

Both input and reconstructed images are shown in normalized form, since the model was trained on normalized images.

(input vs. reconstruction image grids)

STL-10 pretrained model:

from pl_bolts.models.autoencoders import VAE

vae = VAE(input_height=96, first_conv=True)
print(VAE.pretrained_weights_available())
vae = vae.from_pretrained('stl10-resnet18')

vae.freeze()

Training:

(reconstruction loss and KL divergence curves)

class pl_bolts.models.autoencoders.VAE(input_height, enc_type='resnet18', first_conv=False, maxpool1=False, enc_out_dim=512, kl_coeff=0.1, latent_dim=256, lr=0.0001, **kwargs)[source]

Bases: pytorch_lightning.LightningModule

Standard VAE with a Gaussian prior and approximate posterior.

Model is available pretrained on different datasets:

Example:

# not pretrained
vae = VAE(input_height=32)

# pretrained on cifar10
vae = VAE.from_pretrained('cifar10-resnet18')

# pretrained on stl10
vae = VAE.from_pretrained('stl10-resnet18')
Parameters
  • input_height (int) – height of the images

  • enc_type (str) – option between resnet18 or resnet50

  • first_conv (bool) – if True, use the standard kernel_size 7, stride 2 first conv; if False, replace it with a kernel_size 3, stride 1 conv

  • maxpool1 (bool) – if True, use the standard maxpool to halve the spatial dimension of the features

  • enc_out_dim (int) – set according to the output feature dimension of the encoder used (512 for resnet18, 2048 for resnet50)

  • kl_coeff (float) – weighting coefficient for the KL term of the loss

  • latent_dim (int) – dimension of the latent space

  • lr (float) – learning rate for Adam
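
To illustrate the role of kl_coeff: the training objective weights the KL divergence term against the reconstruction term. A sketch of that combination (illustrative only, not the exact pl_bolts training_step):

import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence

# sketch: how kl_coeff typically weights the two loss terms
def vae_loss(x, x_hat, mu, std, kl_coeff=0.1):
    recon_loss = F.mse_loss(x_hat, x)
    posterior = Normal(mu, std)
    prior = Normal(torch.zeros_like(mu), torch.ones_like(std))
    kl = kl_divergence(posterior, prior).mean()
    return recon_loss + kl_coeff * kl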
