hxtorch.spiking.transforms.decode.MaxOverTime

class hxtorch.spiking.transforms.decode.MaxOverTime(*args, **kwargs)

Bases: torch.nn.modules.module.Module

Simple max-over-time decoding.

__init__(*args, **kwargs) → None

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Methods

forward(input)

Translate an input tensor of shape (batch_size, time_length, population_size) into a tensor of shape (batch_size, population_size), where the time dimension is discarded by picking the maximum value along the time axis.

Attributes

training

forward(input: torch.Tensor) → torch.Tensor

Translate an input tensor of shape (batch_size, time_length, population_size) into a tensor of shape (batch_size, population_size), where the time dimension is discarded by picking the maximum value along the time axis. Hence, this module performs a ‘max-over-time’ operation.
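The reduction described above can be sketched in plain PyTorch as a maximum over dim=1, the time axis in the documented shape convention. This is a minimal illustration, not the hxtorch source: the helper max_over_time and the class MaxOverTimeSketch are hypothetical names introduced here, and the sketch assumes the input is a dense tensor of shape (batch_size, time_length, population_size).

```python
import torch


def max_over_time(input: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: reduce (batch_size, time_length, population_size)
    to (batch_size, population_size) by taking the maximum over time."""
    # dim=1 is the time axis in the documented shape convention
    return torch.amax(input, dim=1)


class MaxOverTimeSketch(torch.nn.Module):
    """Hypothetical stand-in mirroring the documented behaviour of MaxOverTime."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if input.dim() != 3:
            raise ValueError(
                "expected shape (batch_size, time_length, population_size)")
        return max_over_time(input)


# batch_size=2, time_length=4, population_size=3
traces = torch.tensor([
    [[0.1, 0.0, 0.3],
     [0.5, 0.2, 0.1],
     [0.4, 0.9, 0.0],
     [0.2, 0.1, 0.6]],
    [[1.0, 0.0, 0.0],
     [0.0, 2.0, 0.0],
     [0.0, 0.0, 3.0],
     [0.5, 0.5, 0.5]],
])
out = MaxOverTimeSketch()(traces)
print(out.shape)  # torch.Size([2, 3])
print(out)        # per-neuron maxima over the 4 time steps
```

Each output entry is the largest value a neuron reached at any time step, so for the first sample the result is (0.5, 0.9, 0.6).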

Parameters

input – The input tensor to transform. Expected shape: (batch_size, time_length, population_size).

Returns

The tensor holding the max-over-time values, of shape (batch_size, population_size).
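When the decoded values are used in a training loss, the gradient of a max reduction flows back only through the time step at which each maximum occurred. The sketch below demonstrates this with plain torch.max(dim=1), which also reports that time step. It assumes, without confirmation from this page, that MaxOverTime backpropagates like torch.max; with tied maxima, other reductions such as torch.amax may distribute the gradient differently.

```python
import torch

# batch_size=1, time_length=3, population_size=2; track gradients on the input
traces = torch.tensor([[[0.2, 0.7],
                        [0.9, 0.1],
                        [0.4, 0.3]]], requires_grad=True)

# Max over the time axis; also returns the time index of each maximum
values, indices = traces.max(dim=1)
values.sum().backward()

print(values)       # maxima per population neuron
print(indices)      # time step at which each maximum occurred
print(traces.grad)  # 1 at the maximising time step, 0 elsewhere
```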

training: bool