
GANs

Collection of Generative Adversarial Networks


Basic GAN

This is a vanilla GAN. The model works with datasets of any size; the results shown here are for MNIST. Replace the generator, the discriminator, or any part of the training loop to build a new method, or simply fine-tune on your own data (a fine-tuning sketch follows the usage snippet below).

Implemented by:

  • William Falcon

Example outputs:

Basic GAN generated samples

Loss curves:

Basic GAN discriminator loss and generator loss
from pl_bolts.models.gans import GAN
from pytorch_lightning import Trainer

# MNIST-sized inputs: 1 channel, 28 x 28 images
gan = GAN(input_channels=1, input_height=28, input_width=28)
trainer = Trainer()
trainer.fit(gan)
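As a sketch of the fine-tuning workflow mentioned above, you can pass your own DataLoader to the Trainer. The dataset path, image size, and transform below are illustrative assumptions rather than part of the bolts API:

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from pytorch_lightning import Trainer

from pl_bolts.models.gans import GAN

# Hypothetical folder of 64x64 RGB images; adjust channels/height/width to your data
transform = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor()])
dataset = ImageFolder("/path/to/your/images", transform=transform)
loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

gan = GAN(input_channels=3, input_height=64, input_width=64)
trainer = Trainer(max_epochs=5)
trainer.fit(gan, loader)
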
class pl_bolts.models.gans.GAN(input_channels, input_height, input_width, latent_dim=32, learning_rate=0.0002, **kwargs)[source]

Bases: pytorch_lightning.LightningModule

Vanilla GAN implementation.

Example:

from pl_bolts.models.gans import GAN
from pytorch_lightning import Trainer

m = GAN(input_channels=1, input_height=28, input_width=28)
Trainer(gpus=2).fit(m)

Example CLI:

# mnist
python basic_gan_module.py --gpus 1

# imagenet
python basic_gan_module.py --gpus 1 --dataset 'imagenet2012' \
    --data_dir /path/to/imagenet/folder/ --meta_dir ~/path/to/meta/bin/folder \
    --batch_size 256 --learning_rate 0.0001
Parameters
  • input_channels (int) – number of channels of an image

  • input_height (int) – image height

  • input_width (int) – image width

  • latent_dim (int) – dimension of the latent noise vector fed to the generator

  • learning_rate (float) – the learning rate
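
For example, an instantiation for 32x32 RGB images might look like the sketch below; the values are illustrative, and only input_channels, input_height and input_width are required:

from pl_bolts.models.gans import GAN

# 3-channel 32x32 images, a 100-dim latent space and a smaller learning rate
gan = GAN(
    input_channels=3,
    input_height=32,
    input_width=32,
    latent_dim=100,
    learning_rate=1e-4,
)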

forward(z)[source]

Generates an image given input noise z

Example:

import torch

batch_size, latent_dim = 16, 32  # latent_dim must match the value used at training time
z = torch.rand(batch_size, latent_dim)
gan = GAN.load_from_checkpoint(PATH)
img = gan(z)
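
As a follow-up, the generated batch can be written to disk with torchvision. This is a minimal sketch, assuming the forward pass returns image tensors shaped (batch, channels, height, width):

import torch
from torchvision.utils import save_image

gan = GAN.load_from_checkpoint(PATH)
gan.eval()

with torch.no_grad():
    samples = gan(torch.rand(16, 32))  # 32 must match the latent_dim used at training

# Writes a 4x4 grid of samples to samples.png
save_image(samples, "samples.png", nrow=4)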