
pl_bolts.callbacks.verification.batch_gradient module

class pl_bolts.callbacks.verification.batch_gradient.BatchGradientVerification(model)

Bases: pl_bolts.callbacks.verification.base.VerificationBase

Checks whether a model mixes data across the batch dimension. This can happen when reshape and/or permutation operations are applied in the wrong order or to the wrong tensor dimensions.
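A minimal sketch of how such a bug arises, in plain PyTorch (the shapes and variable names are illustrative assumptions, not taken from pl_bolts). Converting a batch-major tensor of shape (B, T, C) to time-major layout (T, B, C) with reshape instead of permute gives the right shape, but it regroups elements from different samples:

```python
import torch

B, T, C = 2, 3, 4
# Batch-major input: x[b] holds all T time steps of sample b.
x = torch.arange(B * T * C, dtype=torch.float32).reshape(B, T, C)

# Intended: switch to time-major layout (T, B, C).
time_major_ok = x.permute(1, 0, 2)   # correct: swaps the axes
time_major_bad = x.reshape(T, B, C)  # wrong: same shape, reinterprets memory

# Column 0 of the permuted tensor contains exactly sample 0 ...
print(torch.equal(time_major_ok[:, 0], x[0]))      # True
# ... while the reshaped tensor has pulled in data from sample 1.
print(torch.equal(time_major_bad[:, 0], x[0]))     # False
print(torch.equal(time_major_bad[2, 0], x[1, 1]))  # True
```

Because both results have shape (T, B, C), no shape check catches the second version. This is the kind of mistake the gradient-based test is designed to detect.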

Parameters

model (Module) – The model to run verification for.

check(input_array, input_mapping=None, output_mapping=None, sample_idx=0)

Runs the test for data mixing across the batch.

Parameters
  • input_array (Any) – A dummy input for the model. Can be a tuple or dict in case the model takes multiple positional or named arguments.

  • input_mapping (Optional[Callable]) – An optional input mapping that returns all batched tensors in an input collection. By default, we handle nested collections (tuples, lists, dicts) of tensors and pull them out. If your batch is a custom object, you need to provide this input mapping yourself. See default_input_mapping() for more information on the default behavior.

  • output_mapping (Optional[Callable]) – An optional output mapping that combines all batched tensors in the output collection into one big batch of shape (B, N), where N is the total number of elements per sample: the flattened size of each tensor's non-batch dimensions, summed over all tensors. By default, we handle nested collections (tuples, lists, dicts) of tensors and combine them automatically. See default_output_mapping() for more information on the default behavior.

  • sample_idx (int) – The index i of the batch sample to run the test for. When computing the gradient of a loss on the i-th output sample with respect to the whole input, we expect the gradient to be non-zero only for the i-th input sample and zero for every other sample in the batch.

Return type

bool

Returns

True if the data in the batch does not mix during the forward pass, and False otherwise.
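The idea behind check() can be sketched in plain PyTorch for a model that takes one tensor and returns one tensor. This is a simplified illustration, not the pl_bolts implementation: the helper batch_gradient_check and the toy model MixesBatch are hypothetical, and nested inputs and outputs (which the library handles through input_mapping and output_mapping) are not supported here. The sketch back-propagates the sum of the sample_idx-th output and then inspects the input gradient as described for sample_idx:

```python
import torch
from torch import nn


def batch_gradient_check(model: nn.Module, x: torch.Tensor, sample_idx: int = 0) -> bool:
    """Hypothetical helper (not the pl_bolts API): True if no data mixes across the batch."""
    x = x.detach().clone().requires_grad_(True)
    output = model(x)
    # A scalar that depends only on the sample_idx-th output sample.
    output[sample_idx].sum().backward()
    grad = x.grad
    # Every other sample in the batch must receive zero gradient.
    mask = torch.ones(x.shape[0], dtype=torch.bool)
    mask[sample_idx] = False
    no_grad_outside = bool((grad[mask] == 0).all())
    # The selected sample itself should receive some gradient.
    has_grad_inside = bool(grad[sample_idx].abs().sum() > 0)
    return no_grad_outside and has_grad_inside


class MixesBatch(nn.Module):
    """Hypothetical buggy model: centers over dim 0 (the batch) instead of the features."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x - x.mean(dim=0, keepdim=True)


torch.manual_seed(0)
x = torch.rand(4, 5)
print(batch_gradient_check(nn.Linear(5, 3), x))  # True
print(batch_gradient_check(MixesBatch(), x))     # False
```

nn.Linear processes each row independently, so the rows other than sample_idx get an exactly zero gradient. MixesBatch subtracts a batch mean, so every sample receives a gradient of -1/B.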

class pl_bolts.callbacks.verification.batch_gradient.BatchGradientVerificationCallback(input_mapping=None, output_mapping=None, sample_idx=0, **kwargs)

Bases: pl_bolts.callbacks.verification.base.VerificationCallbackBase

The callback version of the BatchGradientVerification test. Verification is performed right before training begins.

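Parameters
  • input_mapping (Optional[Callable]) – An optional input mapping that returns all batched tensors in an input collection. See BatchGradientVerification.check() for more information.

  • output_mapping (Optional[Callable]) – An optional output mapping that combines all batched tensors in the output collection into one big batch. See BatchGradientVerification.check() for more information.

  • sample_idx (int) – The index of the batch sample to run the test for. See BatchGradientVerification.check() for more information.

  • **kwargs – Additional keyword arguments passed to the base class VerificationCallbackBase.

message(*args, **kwargs)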

The message to be printed when the model does not pass the verification. If the warning and error messages should differ, override the warning_message() and error_message() methods directly.

Parameters
  • *args – Any positional arguments that are needed to construct the message.

  • **kwargs – Any keyword arguments that are needed to construct the message.

Return type

str

Returns

The message as a string.

on_train_start(trainer, pl_module)

Runs the verification right before training begins.

Return type

None

pl_bolts.callbacks.verification.batch_gradient.collect_tensors(data)

Filters all tensors in a collection and returns them in a list.

Return type

List[Tensor]
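A minimal sketch of such a filter, assuming that the collection is built from tuples, lists and dicts. The name collect_tensors_sketch is hypothetical, and the pl_bolts implementation may traverse collections differently:

```python
from typing import Any, List

import torch
from torch import Tensor


def collect_tensors_sketch(data: Any) -> List[Tensor]:
    """Hypothetical sketch: recursively gather every tensor in a nested tuple/list/dict."""
    if isinstance(data, Tensor):
        return [data]
    if isinstance(data, dict):
        data = list(data.values())
    if isinstance(data, (list, tuple)):
        tensors: List[Tensor] = []
        for item in data:
            tensors.extend(collect_tensors_sketch(item))
        return tensors
    # Anything else (strings, numbers, None, ...) is ignored.
    return []


batch = {"x": torch.zeros(2, 3), "meta": ("id", [torch.ones(2), 7])}
print([t.shape for t in collect_tensors_sketch(batch)])
# [torch.Size([2, 3]), torch.Size([2])]
```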

pl_bolts.callbacks.verification.batch_gradient.default_input_mapping(data)

Finds all tensors in a (nested) collection that have the same batch size.

Parameters

data (Any) – a tensor or a collection of tensors (tuple, list, dict, etc.).

Return type

List[Tensor]

Returns

A list of all tensors that share the same batch size. If the input was already a tensor, a one-element list containing that tensor is returned.

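Example

>>> import torch
>>> from pl_bolts.callbacks.verification.batch_gradient import default_input_mapping
>>> data = (torch.zeros(3, 1), "foo", torch.ones(3, 2), torch.rand(2))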
>>> result = default_input_mapping(data)
>>> len(result)
2
>>> result[0].shape
torch.Size([3, 1])
>>> result[1].shape
torch.Size([3, 2])
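If the batch is a custom object, the default mapping cannot look inside it, so an explicit input_mapping has to return the batched tensors. Below is a minimal sketch built around a hypothetical ImageBatch dataclass and a hypothetical mapping function. The non-tensor names field is left out because it does not carry a batch dimension:

```python
from dataclasses import dataclass
from typing import List

import torch
from torch import Tensor


@dataclass
class ImageBatch:
    """Hypothetical custom batch type."""
    images: Tensor  # shape (B, C, H, W)
    masks: Tensor   # shape (B, H, W)
    names: List[str]


def image_batch_input_mapping(batch: ImageBatch) -> List[Tensor]:
    # Hypothetical mapping: return only the tensors whose first dim is the batch.
    return [batch.images, batch.masks]


batch = ImageBatch(torch.rand(4, 3, 8, 8), torch.rand(4, 8, 8), ["a", "b", "c", "d"])
print([t.shape[0] for t in image_batch_input_mapping(batch)])  # [4, 4]
```

A function like this would be passed as input_mapping=image_batch_input_mapping to BatchGradientVerification.check() or to BatchGradientVerificationCallback.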

pl_bolts.callbacks.verification.batch_gradient.default_output_mapping(data)

Pulls out all tensors in an output collection and combines them into one big batch for verification.

Parameters

data (Any) – a tensor or a (nested) collection of tensors (tuple, list, dict, etc.).

Return type

Tensor

Returns

A float tensor of shape (B, N), where B is the batch size and N is the sum of the flattened sizes of the non-batch dimensions of all tensors in the collection. If the input was already a tensor, the tensor itself is returned.

Example

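>>> import torch
>>> from pl_bolts.callbacks.verification.batch_gradient import default_output_mapping
>>> data = (torch.rand(3, 5), "foo", torch.rand(3, 2, 4))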
>>> result = default_output_mapping(data)
>>> result.shape
torch.Size([3, 13])
>>> data = {"one": torch.rand(3, 5), "two": torch.rand(3, 2, 1)}
>>> result = default_output_mapping(data)
>>> result.shape
torch.Size([3, 7])
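The combination step can be sketched by flattening each tensor after its batch dimension and concatenating the results along dim 1. For (3, 5) and (3, 2, 4) this gives 5 + 2·4 = 13 columns, which matches the example. The sketch is minimal and rests on assumptions. The name combine_outputs_sketch is hypothetical. It takes an already-collected list of tensors rather than a nested collection. It also assumes that all tensors share the same batch size:

```python
from typing import List

import torch
from torch import Tensor


def combine_outputs_sketch(tensors: List[Tensor]) -> Tensor:
    """Hypothetical sketch: reshape each (B, ...) tensor to (B, n_i), concat to (B, sum n_i)."""
    # reshape(B, -1) also handles 1-D tensors of shape (B,), giving (B, 1).
    flat = [t.reshape(t.shape[0], -1).float() for t in tensors]
    return torch.cat(flat, dim=1)


outputs = [torch.rand(3, 5), torch.rand(3, 2, 4)]
print(combine_outputs_sketch(outputs).shape)  # torch.Size([3, 13])
```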