Vision Callbacks

Useful callbacks for vision models

Confused Logit

Shows how the input would have to change to move the prediction from one logit to the other

Example outputs:

Example of a prediction confused between 5 and 8
class ConfusedLogitCallback(top_k, projection_factor=3, min_logit_value=5.0, logging_batch_interval=20, max_logit_difference=0.1)

Bases: pytorch_lightning.Callback

Takes the logit predictions of a model. When the probabilities of two classes are very close, the model is not confident about which of the two classes to pick.

This callback shows how the input would have to change to swing the model from one label prediction to the other.

In this case, the network predicts a 5 but gives almost equal probability to an 8. The images show how the original 5 would have to change to look more like a 5 or more like an 8.

For each confused prediction, the images are generated by taking the gradient of each of the top two (closest) logits with respect to the input.
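The gradient step described above can be sketched in plain PyTorch. This is a minimal illustration, not the callback's actual code: the helper name confused_logit_gradients is hypothetical, and it assumes a classifier that maps a batch of images to an (N, C) logit tensor without mixing examples across the batch (which holds in eval() mode for common layers).

```python
import torch
from torch import nn


def confused_logit_gradients(model: nn.Module, x: torch.Tensor):
    """Return (top2_idx, grad_a, grad_b) for a batch x of shape (N, ...).

    top2_idx[i] holds the indices of the two largest logits for example i;
    grad_a[i] and grad_b[i] are the gradients of those two logits w.r.t. x[i].
    """
    was_training = model.training
    model.eval()  # eval mode so layers such as BatchNorm do not mix examples

    with torch.no_grad():
        logits = model(x)
    top2_idx = torch.topk(logits, k=2, dim=1).indices  # shape (N, 2)

    grads = []
    for k in range(2):
        x_param = x.detach().clone().requires_grad_(True)
        out = model(x_param)
        # select logit top2_idx[i, k] from each row; summing over the batch is
        # safe because each example's logit depends only on its own input
        selected = out.gather(1, top2_idx[:, k:k + 1]).sum()
        (grad,) = torch.autograd.grad(selected, x_param)
        grads.append(grad)

    model.train(was_training)  # restore the caller's mode
    return top2_idx, grads[0], grads[1]
```

For a purely linear model, the returned gradients are just the weight rows of the two top classes reshaped to the image shape, which makes an easy sanity check.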


from import ConfusedLogitCallback
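from pytorch_lightning import Trainer

# top_k is listed first in the signature without a default, so pass it explicitly (2 is illustrative)
trainer = Trainer(callbacks=[ConfusedLogitCallback(top_k=2)])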


Note: whenever it is called, this callback looks for self.last_batch and self.last_logits in the LightningModule.


Note: this callback currently supports TensorBoard only.

Parameters

  • top_k – how many “offending” images to plot

  • projection_factor – how much to multiply the input image by to make it look more like this logit's label

  • min_logit_value – only consider logit values above this threshold

  • logging_batch_interval – how often, in batches, to inspect the logits and possibly plot

  • max_logit_difference – when the top two logits differ by less than this threshold, they are considered confused
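One plausible reading of how top_k, min_logit_value, and max_logit_difference combine is sketched below. It is an assumption, not the callback's source: a row of logits counts as confused when its largest logit exceeds min_logit_value and the gap to the runner-up is below max_logit_difference, and at most top_k such rows are kept. The helper find_confused is hypothetical.

```python
import torch


def find_confused(logits: torch.Tensor, top_k: int,
                  min_logit_value: float = 5.0,
                  max_logit_difference: float = 0.1) -> torch.Tensor:
    """Return indices of up to top_k 'confused' rows of an (N, C) logit tensor."""
    values, _ = torch.topk(logits, k=2, dim=1)  # two largest logits per row
    # the model has a strong opinion: its best logit clears the threshold
    confident = values[:, 0] > min_logit_value
    # ...but the runner-up is almost as large (values are sorted, so gap >= 0)
    close = (values[:, 0] - values[:, 1]) < max_logit_difference
    candidates = torch.nonzero(confident & close, as_tuple=False).flatten()
    return candidates[:top_k]
```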

Authored by:

  • Alfredo Canziani

TensorBoard Image Generator

Generates images from a generative model and plots them to TensorBoard


Bases: pytorch_lightning.Callback

Generates images and logs them to TensorBoard. Your model must implement the forward function for generation.


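import torch
from torch import nn

batch_size, latent_dim = 16, 32  # illustrative sizes
# stand-in generator; in practice this is your own LightningModule
model = nn.Sequential(nn.Linear(latent_dim, 28 * 28), nn.Sigmoid())

# model must have an img_dim attribute
model.img_dim = (1, 28, 28)

# model forward must work for sampling
z = torch.rand(batch_size, latent_dim)
img_samples = model(z)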


from pytorch_lightning import Trainer
from pl_bolts.callbacks import TensorboardGenerativeModelImageSampler

trainer = Trainer(callbacks=[TensorboardGenerativeModelImageSampler()])
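To illustrate why the model needs img_dim and a sampling-capable forward, here is a hypothetical helper (sample_images, not part of pl_bolts) that draws latent vectors, runs the generator, and reshapes flat outputs to img_dim. It assumes forward takes a (num_samples, latent_dim) tensor. Logging is left out, but the result could be tiled with torchvision.utils.make_grid and written with SummaryWriter.add_image.

```python
import torch
from torch import nn


def sample_images(model: nn.Module, num_samples: int, latent_dim: int) -> torch.Tensor:
    """Hypothetical helper: sample latents, generate, return (N, C, H, W) images."""
    # uniform noise, matching the torch.rand call in the requirements snippet
    z = torch.rand(num_samples, latent_dim)

    was_training = model.training
    model.eval()
    with torch.no_grad():
        images = model(z)
    model.train(was_training)  # restore the caller's mode

    # a generator that returns flat vectors is reshaped with model.img_dim
    if images.dim() == 2:
        images = images.view(num_samples, *model.img_dim)
    return images
```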
Read the Docs v: 0.1.1