
Autoencoders

This section houses autoencoders and variational autoencoders.


Basic AE

This is the simplest autoencoder. You can use it like so:

from pl_bolts.models.autoencoders import AE
from pytorch_lightning import Trainer

model = AE(input_height=32)
trainer = Trainer()
trainer.fit(model)

You can override any part of this AE to build your own variation.

from pl_bolts.models.autoencoders import AE

class MyAEFlavor(AE):

    def init_encoder(self, hidden_dim, latent_dim, input_width, input_height):
        # return any nn.Module to replace the default resnet encoder
        encoder = YourSuperFancyEncoder(...)  # placeholder: your custom encoder class
        return encoder

You can also use the pretrained models included in bolts.

CIFAR-10 pretrained model:

from pl_bolts.models.autoencoders import AE

ae = AE(input_height=32)
print(AE.pretrained_weights_available())
ae = ae.from_pretrained('cifar10-resnet18')

ae.freeze()

Training:

[training loss curve]

Reconstructions:

Both the input and the reconstructed images are normalized, since training was done on normalized images.

[input image | reconstruction]

class pl_bolts.models.autoencoders.AE(input_height, enc_type='resnet18', first_conv=False, maxpool1=False, enc_out_dim=512, latent_dim=256, lr=0.0001, **kwargs)[source]

Bases: pytorch_lightning.core.lightning.LightningModule

Standard AE

Model is available pretrained on different datasets:

Example:

# not pretrained
ae = AE(input_height=32)

# pretrained on cifar10
ae = AE(input_height=32).from_pretrained('cifar10-resnet18')
Parameters
  • input_height (int) – height of the images

  • enc_type (str) – option between resnet18 or resnet50

  • first_conv (bool) – use standard kernel_size 7, stride 2 at start or replace it with kernel_size 3, stride 1 conv

  • maxpool1 (bool) – use standard maxpool to reduce spatial dim of feat by a factor of 2

  • enc_out_dim (int) – set according to the out_channel count of encoder used (512 for resnet18, 2048 for resnet50)

  • latent_dim (int) – dim of latent space

  • lr (float) – learning rate for Adam
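To make the roles of enc_out_dim and latent_dim concrete, here is a minimal pure-PyTorch sketch of the autoencoder objective. The toy linear encoder/decoder below is a hypothetical stand-in for the resnet backbones, not the bolts implementation:

```python
import torch
from torch import nn
from torch.nn import functional as F

# Toy stand-in modules (hypothetical): the encoder maps images to enc_out_dim
# features, a linear head projects them to latent_dim, and the decoder
# reconstructs the input from the latent code.
enc_out_dim, latent_dim = 512, 256
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, enc_out_dim), nn.ReLU())
fc = nn.Linear(enc_out_dim, latent_dim)
decoder = nn.Linear(latent_dim, 3 * 32 * 32)

x = torch.rand(4, 3, 32, 32)       # batch of CIFAR-10-sized images
z = fc(encoder(x))                 # latent code, shape (4, 256)
x_hat = decoder(z).view_as(x)      # reconstruction
loss = F.mse_loss(x_hat, x)        # reconstruction loss minimized during training
```

The default enc_out_dim=512 matches the resnet18 feature width; with resnet50 you would use 2048 instead, as noted in the parameter list above.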


Variational Autoencoders

Basic VAE

Use the VAE like so:

from pl_bolts.models.autoencoders import VAE
from pytorch_lightning import Trainer

model = VAE(input_height=32)
trainer = Trainer()
trainer.fit(model)

You can override any part of this VAE to build your own variation.

import torch
from pl_bolts.models.autoencoders import VAE

class MyVAEFlavor(VAE):

    def get_posterior(self, mu, std):
        # do something other than the default, e.g.
        # P = self.get_distribution(self.prior, loc=torch.zeros_like(mu), scale=torch.ones_like(std))
        P = torch.distributions.Normal(loc=mu, scale=std)  # example: a diagonal Gaussian posterior
        return P

You can also use the pretrained models included in bolts.

CIFAR-10 pretrained model:

from pl_bolts.models.autoencoders import VAE

vae = VAE(input_height=32)
print(VAE.pretrained_weights_available())
vae = vae.from_pretrained('cifar10-resnet18')

vae.freeze()

Training:

[reconstruction loss and KL divergence curves]

Reconstructions:

Both the input and the reconstructed images are normalized, since training was done on normalized images.

[input image | reconstruction]
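Because the VAE has a Gaussian prior, a frozen model can also generate new images by decoding latents sampled from N(0, I). A minimal sketch, using a toy linear decoder as a hypothetical stand-in for the trained decoder:

```python
import torch
from torch import nn

# Toy stand-in for the trained decoder (hypothetical; not the bolts resnet decoder)
latent_dim = 256
decoder = nn.Linear(latent_dim, 3 * 32 * 32)

z = torch.randn(16, latent_dim)               # sample latents from the N(0, I) prior
with torch.no_grad():
    samples = decoder(z).view(16, 3, 32, 32)  # decode to image-shaped tensors
```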

STL-10 pretrained model:

from pl_bolts.models.autoencoders import VAE

vae = VAE(input_height=96, first_conv=True)
print(VAE.pretrained_weights_available())
vae = vae.from_pretrained('stl10-resnet18')

vae.freeze()

Training:

[reconstruction loss and KL divergence curves]

class pl_bolts.models.autoencoders.VAE(input_height, enc_type='resnet18', first_conv=False, maxpool1=False, enc_out_dim=512, kl_coeff=0.1, latent_dim=256, lr=0.0001, **kwargs)[source]

Bases: pytorch_lightning.core.lightning.LightningModule

Standard VAE with Gaussian Prior and approx posterior.

Model is available pretrained on different datasets:

Example:

# not pretrained
vae = VAE(input_height=32)

# pretrained on cifar10
vae = VAE(input_height=32).from_pretrained('cifar10-resnet18')

# pretrained on stl10
vae = VAE(input_height=96, first_conv=True).from_pretrained('stl10-resnet18')
Parameters
  • input_height (int) – height of the images

  • enc_type (str) – option between resnet18 or resnet50

  • first_conv (bool) – use standard kernel_size 7, stride 2 at start or replace it with kernel_size 3, stride 1 conv

  • maxpool1 (bool) – use standard maxpool to reduce spatial dim of feat by a factor of 2

  • enc_out_dim (int) – set according to the out_channel count of encoder used (512 for resnet18, 2048 for resnet50)

  • kl_coeff (float) – coefficient for kl term of the loss

  • latent_dim (int) – dim of latent space

  • lr (float) – learning rate for Adam
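The kl_coeff parameter weights the KL term of the VAE objective. A pure-PyTorch sketch of that objective (illustrative, not the exact bolts code): the encoder produces mu and logvar for the approximate posterior, z is drawn with the reparameterization trick, and the KL to the N(0, I) prior has a closed form for diagonal Gaussians. The toy "decoder" matrix here is a hypothetical placeholder.

```python
import torch
from torch.nn import functional as F

kl_coeff = 0.1
batch, latent_dim = 4, 256

x = torch.rand(batch, 3 * 32 * 32)       # flattened input images
mu = torch.zeros(batch, latent_dim)      # would come from the encoder head
logvar = torch.zeros(batch, latent_dim)  # log(sigma^2), also from the encoder head

std = torch.exp(0.5 * logvar)
z = mu + std * torch.randn_like(std)     # reparameterization trick: z = mu + sigma * eps

x_hat = torch.sigmoid(z @ torch.rand(latent_dim, 3 * 32 * 32))  # toy "decoder"
recon = F.mse_loss(x_hat, x)             # reconstruction term

# closed-form KL( N(mu, sigma^2) || N(0, I) ), averaged over the batch
kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch
loss = recon + kl_coeff * kl
```

With mu = 0 and logvar = 0 the posterior equals the prior, so the KL term is exactly zero; during real training the encoder outputs move these away from zero and the KL penalty kicks in.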
