pl_bolts.callbacks.verification.batch_gradient module

class pl_bolts.callbacks.verification.batch_gradient.BatchGradientVerification(model)[source]

Bases: pl_bolts.callbacks.verification.base.VerificationBase

Checks if a model mixes data across the batch dimension. This can happen if reshape and/or permutation operations are carried out in the wrong order or on the wrong tensor dimensions.

check(input_array, input_mapping=None, output_mapping=None, sample_idx=0)[source]

Runs the test for data mixing across the batch.

Parameters

    input_array (Any) – A dummy input for the model. Can be a tuple or dict in case the model takes multiple positional or named arguments.

    input_mapping (Optional[Callable]) – An optional input mapping that returns all batched tensors in an input collection. By default, we handle nested collections (tuples, lists, dicts) of tensors and pull them out. If your batch is a custom object, you need to provide this input mapping yourself. See default_input_mapping() for more information on the default behavior.

    output_mapping (Optional[Callable]) – An optional output mapping that combines all batched tensors in the output collection into one big batch of shape (B, N), where N is the total number of dimensions that follow the batch dimension in each tensor. By default, we handle nested collections (tuples, lists, dicts) of tensors and combine them automatically. See default_output_mapping() for more information on the default behavior.

    sample_idx (int) – The index i of the batch sample to run the test for. When computing the gradient of a loss value on the i-th output w.r.t. the whole input, we expect the gradient to be nonzero only on the i-th input sample and zero on the rest of the batch.

Return type

    bool

Returns

    True if the data in the batch does not mix during the forward pass, and False otherwise.
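The idea behind this check can be sketched in plain PyTorch: backpropagate a loss built from a single output sample and inspect which input samples receive gradient. The helper `mixes_batch_data` below is hypothetical (not the library implementation) and assumes a single-tensor input; `BatchGradientVerification.check` additionally handles nested input and output collections.

```python
import torch
from torch import nn

def mixes_batch_data(model, input_array, sample_idx=0):
    # Hypothetical helper illustrating the gradient trick behind
    # BatchGradientVerification.check (not the library code).
    x = input_array.clone().requires_grad_(True)
    output = model(x)
    # Build a loss that depends only on the sample_idx-th output row ...
    output[sample_idx].sum().backward()
    # ... then any nonzero gradient on the *other* input samples means
    # the model mixed data across the batch dimension.
    mask = torch.ones_like(x.grad)
    mask[sample_idx] = 0
    return bool((x.grad * mask).abs().sum() > 0)

torch.manual_seed(0)
safe_model = nn.Linear(4, 4)     # operates per-sample: no mixing
buggy_model = lambda x: x.t()    # swaps batch and feature dims: mixing
print(mixes_batch_data(safe_model, torch.rand(4, 4)))   # False
print(mixes_batch_data(buggy_model, torch.rand(4, 4)))  # True
```

A transpose is a typical culprit: each row of `buggy_model`'s output is assembled from every sample in the batch, so the gradient leaks across samples and the check flags it.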


class pl_bolts.callbacks.verification.batch_gradient.BatchGradientVerificationCallback(input_mapping=None, output_mapping=None, sample_idx=0, **kwargs)[source]

Bases: pl_bolts.callbacks.verification.base.VerificationCallbackBase

The callback version of the BatchGradientVerification test. Verification is performed right before training begins.

Parameters

    input_mapping (Optional[Callable]) – An optional input mapping that returns all batched tensors in an input collection. See BatchGradientVerification.check() for more information.

    output_mapping (Optional[Callable]) – An optional output mapping that combines all batched tensors in the output collection into one big batch. See BatchGradientVerification.check() for more information.

    sample_idx (int) – The index of the batch sample to run the test for. See BatchGradientVerification.check() for more information.

    **kwargs – Additional arguments for the base class VerificationCallbackBase.

pl_bolts.callbacks.verification.batch_gradient.collect_tensors(data)[source]

Filters all tensors in a collection and returns them in a list.
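The recursive walk over nested collections that this function performs can be sketched as follows; `collect_tensors_sketch` is a hypothetical name for illustration, not the library code.

```python
import torch

def collect_tensors_sketch(data):
    # Recursively gather every tensor from a nested collection of
    # tuples, lists, and dicts (a sketch of collect_tensors' behavior).
    tensors = []
    if isinstance(data, torch.Tensor):
        tensors.append(data)
    elif isinstance(data, dict):
        for value in data.values():
            tensors.extend(collect_tensors_sketch(value))
    elif isinstance(data, (list, tuple)):
        for item in data:
            tensors.extend(collect_tensors_sketch(item))
    # Non-tensor leaves (strings, numbers, ...) are simply skipped.
    return tensors

nested = {"a": torch.zeros(2), "b": [torch.ones(3), ("x", torch.rand(1))]}
print(len(collect_tensors_sketch(nested)))  # 3
```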

pl_bolts.callbacks.verification.batch_gradient.default_input_mapping(data)[source]

Finds all tensors in a (nested) collection that have the same batch size.

Parameters

    data (Any) – a tensor or a collection of tensors (tuple, list, dict, etc.).

Return type

    List[Tensor]

Returns

    A list of all tensors with the same batch dimension. If the input was already a tensor, a one-element list with the tensor is returned.

Example:

>>> data = (torch.zeros(3, 1), "foo", torch.ones(3, 2), torch.rand(2))
>>> result = default_input_mapping(data)
>>> len(result)
2
>>> result[0].shape
torch.Size([3, 1])
>>> result[1].shape
torch.Size([3, 2])

pl_bolts.callbacks.verification.batch_gradient.default_output_mapping(data)[source]

Pulls out all tensors in an output collection and combines them into one big batch for verification.

Parameters

    data (Any) – a tensor or a (nested) collection of tensors (tuple, list, dict, etc.).

Return type

    Tensor

Returns

    A float tensor with shape (B, N) where B is the batch size and N is the sum of (flattened) dimensions of all tensors in the collection. If the input was already a tensor, the tensor itself is returned.

Example:

>>> data = (torch.rand(3, 5), "foo", torch.rand(3, 2, 4))
>>> result = default_output_mapping(data)
>>> result.shape
torch.Size([3, 13])
>>> data = {"one": torch.rand(3, 5), "two": torch.rand(3, 2, 1)}
>>> result = default_output_mapping(data)
>>> result.shape
torch.Size([3, 7])
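The flatten-and-concatenate step described above can be sketched in a few lines of plain PyTorch; `flatten_and_concat` is a hypothetical helper for illustration, operating on an already-collected list of tensors rather than a nested collection.

```python
import torch

def flatten_and_concat(tensors):
    # Flatten every dimension after the batch dimension, then
    # concatenate along dim 1 to get one (B, N) tensor (a sketch of
    # the combining behavior described for default_output_mapping).
    batch_size = tensors[0].shape[0]
    return torch.cat([t.reshape(batch_size, -1) for t in tensors], dim=1)

out = flatten_and_concat([torch.rand(3, 5), torch.rand(3, 2, 4)])
print(out.shape)  # torch.Size([3, 13]): N = 5 + 2*4
```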