jaxsnn.discrete

Modules

jaxsnn.discrete.decode

jaxsnn.discrete.encode

jaxsnn.discrete.leaky_integrate

jaxsnn.discrete.leaky_integrate_and_fire

jaxsnn.discrete.loss

jaxsnn.discrete.threshold

Functions

jaxsnn.discrete.LI(out_dim, scale_in=0.2)

Layer constructor function for a leaky-integrate (LI) layer.
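The LI dynamics are not described on this page. As a rough illustration only, the sketch below assumes current-based leaky integration with explicit Euler steps, run over time with jax.lax.scan. The names li_init, li_step and li_layer, the time constants, and the use of scale_in as the scale of a normal weight initialization are all assumptions, not the jaxsnn implementation.

```python
import jax
import jax.numpy as jnp


def li_init(key, n_in, out_dim, scale_in=0.2):
    # Hypothetical: scale_in scales a normal weight initialization.
    return scale_in * jax.random.normal(key, (n_in, out_dim))


def li_step(state, x, dt=1e-3, tau_syn_inv=200.0, tau_mem_inv=100.0):
    """One Euler step of a leaky integrator (hypothetical parameters)."""
    v, i = state
    # Membrane voltage leaks toward the synaptic current.
    v_new = v + dt * tau_mem_inv * (i - v)
    # Synaptic current decays and receives the weighted input.
    i_new = i - dt * tau_syn_inv * i + x
    return (v_new, i_new), v_new


def li_layer(weights, inputs):
    """Run the LI dynamics over time; inputs has shape (T, n_in)."""
    n_out = weights.shape[1]
    init = (jnp.zeros(n_out), jnp.zeros(n_out))
    _, voltages = jax.lax.scan(li_step, init, inputs @ weights)
    return voltages  # shape (T, n_out)
```

An LI layer produces no spikes, so its voltage trace is typically used as the network readout.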

jaxsnn.discrete.LIF(out_dim, method=<jax._src.custom_derivatives.custom_vjp object>, scale_in=0.7, scale_rec=0.2)

Layer constructor function for a leaky-integrate-and-fire (LIF) layer.
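The default method of LIF is a jax custom_vjp object, which suggests that the non-differentiable spike threshold is trained with a surrogate gradient. A minimal sketch of that idea, assuming a SuperSpike-style fast-sigmoid surrogate and a reset-to-zero LIF step; spike_fn, lif_init, lif_step, lif_layer, the threshold, the time constants, and the reading of scale_in/scale_rec as weight-initialization scales are hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def spike_fn(x):
    # Forward pass: Heaviside step, emits 1.0 where x > 0.
    return (x > 0).astype(x.dtype)


def spike_fwd(x):
    # Save the input as the residual for the backward pass.
    return spike_fn(x), x


def spike_bwd(x, g):
    # SuperSpike-style surrogate: derivative of a fast sigmoid.
    alpha = 100.0
    return (g / (alpha * jnp.abs(x) + 1.0) ** 2,)


spike_fn.defvjp(spike_fwd, spike_bwd)


def lif_init(key, n_in, out_dim, scale_in=0.7, scale_rec=0.2):
    # Hypothetical: scale_in / scale_rec scale normal weight initializations.
    k_in, k_rec = jax.random.split(key)
    w_in = scale_in * jax.random.normal(k_in, (n_in, out_dim))
    w_rec = scale_rec * jax.random.normal(k_rec, (out_dim, out_dim))
    return w_in, w_rec


def lif_step(state, x, w_rec, dt=1e-3, tau_syn_inv=200.0, tau_mem_inv=100.0, v_th=1.0):
    v, i, z = state
    # Leaky integration of membrane voltage and synaptic current.
    v_dec = v + dt * tau_mem_inv * (i - v)
    i_dec = i - dt * tau_syn_inv * i
    # Spike where the voltage crosses the threshold (surrogate gradient backward).
    z_new = spike_fn(v_dec - v_th)
    # Reset spiking neurons to zero; add feedforward and recurrent input.
    v_new = (1.0 - z_new) * v_dec
    i_new = i_dec + x + z @ w_rec
    return (v_new, i_new, z_new), z_new


def lif_layer(w_in, w_rec, inputs):
    # inputs: (T, n_in) -> spikes: (T, out_dim)
    n_out = w_in.shape[1]
    init = (jnp.zeros(n_out), jnp.zeros(n_out), jnp.zeros(n_out))

    def step(state, x):
        return lif_step(state, x, w_rec)

    _, spikes = jax.lax.scan(step, init, inputs @ w_in)
    return spikes
```

The forward pass stays a hard threshold, while jax.grad sees the smooth surrogate, so gradients can flow through the spike train to w_in and w_rec.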

jaxsnn.discrete.acc_and_loss(snn_apply, weights, batch, decoder)
jaxsnn.discrete.max_over_time_decode(inputs)
jaxsnn.discrete.nll_loss(snn_apply, weights, batch, decoder, expected_spikes=0.5, rho=0.0001) → Tuple[float, jax.Array]
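acc_and_loss, max_over_time_decode and nll_loss carry no description here. The sketch below shows one plausible reading, assuming the decoder takes the maximum of the output voltages over a leading time axis and the loss is the negative log-likelihood of a softmax over classes. The *_sketch functions are hypothetical, and the expected_spikes and rho parameters of nll_loss (presumably a spike-count regularizer) are not modeled.

```python
import jax
import jax.numpy as jnp


def max_over_time_decode_sketch(inputs):
    # Assumption: inputs has shape (T, batch, n_classes); take the max over time.
    return jnp.max(inputs, axis=0)


def nll_loss_sketch(logits, labels_onehot):
    # Negative log-likelihood of one-hot labels under a softmax over classes.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(jnp.sum(labels_onehot * log_probs, axis=-1))


def accuracy_sketch(logits, labels_onehot):
    # Fraction of samples whose highest-scoring class matches the label.
    return jnp.mean(jnp.argmax(logits, axis=-1) == jnp.argmax(labels_onehot, axis=-1))
```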
jaxsnn.discrete.spatio_temporal_encode(input_values: jax.Array, seq_length: int, t_late: float, dt: float)

Encodes n-dimensional input coordinates in the range [0, 1] and simulates the spikes that occur over a number of timesteps (seq_length).

Example:
>>> data = np.array([2, 4, 8, 16])
>>> seq_length = 2  # Simulate two iterations
>>> spatio_temporal_encode(data, seq_length)  # Spikes for each iteration
DeviceArray([[0., 0., 0., 1.],
             [0., 0., 1., 1.]])

Parameters:

input_values (jax.Array): The input coordinates (e.g. 2D points), with values in the range [0, 1]
seq_length (int): The number of iterations (timesteps) to simulate
t_late (float): Latest time at which coordinates may be encoded
dt (float): Time delta between simulation steps

Returns:

An array with an extra dimension of size seq_length containing spikes (1) or no spikes (0).
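As a hedged illustration of spatio-temporal (latency) encoding, the sketch below assumes each value x in [0, 1] produces a single spike at time x * t_late, discretized to step round(x * t_late / dt), with time as the leading axis (matching the (seq_length, n) shape of the example output above). latency_encode_sketch is a hypothetical name, and its behavior may differ from what spatio_temporal_encode actually computes.

```python
import jax
import jax.numpy as jnp


def latency_encode_sketch(input_values, seq_length, t_late, dt):
    # Assumption: each value x in [0, 1] spikes once, at time x * t_late.
    idx = jnp.round(input_values * t_late / dt).astype(jnp.int32)
    # One-hot along a new leading time axis: shape (seq_length, *input_values.shape).
    return jax.nn.one_hot(idx, seq_length, axis=0)
```

Spike times that land at or beyond step seq_length produce no spike, because jax.nn.one_hot encodes out-of-range indices as all zeros.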