jaxsnn.discrete
Modules
Functions
-
jaxsnn.discrete.LI(out_dim, scale_in=0.2)
Layer constructor function for an LI (leaky-integrator) layer.
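The entry does not spell out the layer dynamics. The following is a minimal sketch that assumes two things: a stax-style constructor that returns an `(init_fn, apply_fn)` pair, and a current-based leaky integrator discretized with forward Euler. The names `li_layer`, `DT`, `TAU_MEM_INV` and `TAU_SYN_INV` are hypothetical, and so is the reading of `scale_in` as a weight-initialization scale. None of them are taken from jaxsnn.

```python
import jax
import jax.numpy as jnp

# Hypothetical constants (not jaxsnn's defaults).
DT = 5e-4            # simulation step in seconds
TAU_MEM_INV = 1e2    # 1 / membrane time constant (tau_mem = 10 ms)
TAU_SYN_INV = 2e2    # 1 / synaptic time constant (tau_syn = 5 ms)


def li_layer(out_dim, scale_in=0.2):
    """Return a stax-style (init_fn, apply_fn) pair for a leaky-integrator layer."""

    def init_fn(rng, in_dim):
        # Assumed role of scale_in: standard deviation of the input weights.
        weights = scale_in * jax.random.normal(rng, (in_dim, out_dim))
        return out_dim, weights

    def apply_fn(weights, inputs):
        # inputs: (time, in_dim) spike raster -> membrane trace (time, out_dim)
        def step(state, x):
            v, i = state
            v = v + DT * TAU_MEM_INV * (i - v)          # membrane leaks toward the current
            i = i - DT * TAU_SYN_INV * i + x @ weights  # current decays, input spikes add
            return (v, i), v

        init = (jnp.zeros(out_dim), jnp.zeros(out_dim))
        _, trace = jax.lax.scan(step, init, inputs)
        return trace

    return init_fn, apply_fn


init_fn, apply_fn = li_layer(out_dim=3)
_, w = init_fn(jax.random.PRNGKey(0), 4)
trace = apply_fn(w, jnp.ones((100, 4)))  # constant input on every step
```

An LI layer never spikes. A readout placed after it would therefore work on the membrane trace it returns.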
-
jaxsnn.discrete.LIF(out_dim, method=<jax._src.custom_derivatives.custom_vjp object>, scale_in=0.7, scale_rec=0.2)
Layer constructor function for a LIF (leaky integrate-and-fire) layer.
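The default `method` is a `jax.custom_vjp` object. That suggests the spike nonlinearity uses a surrogate gradient: a Heaviside step in the forward pass and a smooth derivative in the backward pass. The sketch below illustrates that technique with a SuperSpike-style surrogate `1 / (alpha * |v| + 1)^2`. The names `spike`, `lif_step`, `ALPHA`, `V_TH` and `DT_TAU`, along with the reset-to-zero rule, are assumptions for illustration and are not jaxsnn's implementation.

```python
import jax
import jax.numpy as jnp

ALPHA = 100.0   # hypothetical surrogate steepness
V_TH = 1.0      # hypothetical firing threshold
DT_TAU = 0.1    # hypothetical dt / tau_mem


@jax.custom_vjp
def spike(v):
    # Forward pass: Heaviside step, 1.0 where v > 0 else 0.0.
    return (v > 0.0).astype(v.dtype)


def spike_fwd(v):
    # Save v as the residual for the backward pass.
    return spike(v), v


def spike_bwd(v, g):
    # Backward pass: SuperSpike-style surrogate derivative 1 / (ALPHA * |v| + 1)^2.
    return (g / (ALPHA * jnp.abs(v) + 1.0) ** 2,)


spike.defvjp(spike_fwd, spike_bwd)


def lif_step(v, i_in):
    # Leaky integration, threshold crossing, then reset to zero where a spike occurred.
    v = v + DT_TAU * (i_in - v)
    z = spike(v - V_TH)
    v = v * (1.0 - z)
    return v, z


def total_spikes(scale):
    # Drive two neurons with a constant current of size `scale` for 50 steps.
    inputs = scale * jnp.ones((50, 2))
    _, zs = jax.lax.scan(lif_step, jnp.zeros(2), inputs)
    return zs.sum()


grad = jax.grad(total_spikes)(0.5)  # nonzero even though no spike fires
```

With `scale=0.5` the membrane settles below threshold, so no spikes occur. The surrogate still yields a positive gradient, and this is the property that lets gradient descent push silent neurons toward firing.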
-
jaxsnn.discrete.acc_and_loss(snn_apply, weights, batch, decoder)
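This entry has no description. The sketch below shows, under stated assumptions, what a function with this signature commonly does: it runs the network on the batch inputs, decodes the output trace into per-class scores, and reports classification accuracy together with a cross-entropy loss. The sketch assumes that `batch` is an `(inputs, one_hot_labels)` pair and that `decoder` maps the network output to `(batch, classes)` scores. The name `acc_and_loss_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def acc_and_loss_sketch(snn_apply, weights, batch, decoder):
    # Assumed batch layout: (inputs, one-hot labels of shape (batch, classes)).
    inputs, labels = batch
    outputs = snn_apply(weights, inputs)      # network output trace
    scores = decoder(outputs)                 # (batch, classes)
    log_p = jax.nn.log_softmax(scores, axis=-1)
    loss = -jnp.mean(jnp.sum(labels * log_p, axis=-1))  # mean cross-entropy
    correct = jnp.argmax(scores, axis=-1) == jnp.argmax(labels, axis=-1)
    return jnp.mean(correct), loss


# Toy check: a linear "network" and a max-over-time decoder.
x = jnp.zeros((4, 2, 3)).at[1, 0, 2].set(5.0).at[2, 1, 0].set(5.0)  # (time, batch, in)
labels = jnp.array([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
acc, loss = acc_and_loss_sketch(
    lambda w, inp: inp @ w, jnp.eye(3), (x, labels), lambda out: jnp.max(out, axis=0)
)
```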
-
jaxsnn.discrete.max_over_time_decode(inputs)
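No description is given here. Judging by the name, a plausible reading is that each output neuron's score is the maximum of its trace (for example an LI membrane voltage) over time. The following one-line sketch assumes that time is the leading axis. The name `max_over_time_decode_sketch` is hypothetical.

```python
import jax.numpy as jnp


def max_over_time_decode_sketch(inputs):
    # Assumes time-major input (time, ..., out_dim); reduce over the time axis.
    return jnp.max(inputs, axis=0)


trace = jnp.array([[0.0, 1.0],
                   [2.0, -1.0],
                   [1.0, 0.5]])  # 3 time steps, 2 output neurons
scores = max_over_time_decode_sketch(trace)  # per-neuron peak value
```

The predicted class would then be `jnp.argmax(scores)`, which is neuron 0 here.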
-
jaxsnn.discrete.nll_loss(snn_apply, weights, batch, decoder, expected_spikes=0.5, rho=0.0001) → Tuple[float, jax.Array]
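The parameters `expected_spikes` and `rho` suggest a negative log-likelihood loss combined with a spike-count regularizer weighted by `rho`. The sketch below shows one such combination under that assumption: cross-entropy on decoded scores, plus `rho` times the mean squared deviation of each neuron's spike count from `expected_spikes`. The function name `nll_with_rate_reg`, its argument layout, and the exact form of the regularizer are all hypothetical. The returned pair only mirrors the `Tuple[float, jax.Array]` shape of the signature.

```python
import jax
import jax.numpy as jnp


def nll_with_rate_reg(scores, labels, spikes, expected_spikes=0.5, rho=1e-4):
    # Negative log-likelihood of the one-hot labels under softmax(scores).
    log_p = jax.nn.log_softmax(scores, axis=-1)
    nll = -jnp.mean(jnp.sum(labels * log_p, axis=-1))
    # Assumed regularizer: squared deviation of each neuron's spike count
    # (summed over the leading time axis) from expected_spikes.
    counts = jnp.sum(spikes, axis=0)
    reg = rho * jnp.mean((counts - expected_spikes) ** 2)
    return nll + reg, scores


scores = jnp.zeros((3, 2))                               # uninformative scores
labels = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]])
spikes = jnp.zeros((5, 3, 4))                            # (time, batch, hidden), silent
loss, _ = nll_with_rate_reg(scores, labels, spikes, expected_spikes=0.5, rho=1.0)
```

With uniform scores the NLL term is log 2, and with `rho=1.0` the silent neurons add 0.5² = 0.25, so `loss` is about 0.943.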
-
jaxsnn.discrete.spatio_temporal_encode(input_values: jax.Array, seq_length: int, t_late: float, dt: float)
Encodes n-dimensional input coordinates in the range [0, 1] and simulates the spikes that occur over seq_length timesteps (iterations).
- Example:
>>> data = np.array([2, 4, 8, 16])
>>> seq_length = 2  # Simulate two iterations
>>> spatio_temporal_encode(data, seq_length)  # Spikes for each iteration
DeviceArray([[0., 0., 0., 1.],
             [0., 0., 1., 1.]])
- Parameters:
input_values (jax.Array): The input array of coordinates in [0, 1] (e.g. 2d points)
seq_length (int): The number of iterations to simulate
t_late (float): Latest time at which coordinates may be encoded
dt (float): Time delta between simulation steps
- Returns:
An array with an extra dimension of size seq_length containing spikes (1) or no spikes (0).
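To illustrate the general idea, here is a sketch of a simple latency (time-to-first-spike) encoder that takes the same arguments. It assumes each value `x` in [0, 1] fires exactly once, at time `x * t_late`, rounded to the nearest step of size `dt`. Spikes whose step index falls at or beyond `seq_length` are dropped. Both this mapping and the name `latency_encode_sketch` are assumptions and do not reproduce jaxsnn's exact encoding.

```python
import jax.numpy as jnp


def latency_encode_sketch(input_values, seq_length, t_late, dt):
    # Assumption: a value x in [0, 1] fires once, at time x * t_late.
    spike_times = input_values * t_late
    step_idx = jnp.round(spike_times / dt).astype(jnp.int32)
    # Compare every time step against every input's spike step via broadcasting;
    # indices >= seq_length never match, so those spikes are dropped.
    steps = jnp.arange(seq_length).reshape((seq_length,) + (1,) * input_values.ndim)
    return (steps == step_idx).astype(jnp.float32)


spikes = latency_encode_sketch(jnp.array([0.0, 0.5, 1.0]), seq_length=5, t_late=1.0, dt=0.25)
```

For `[0.0, 0.5, 1.0]` with `t_late=1.0`, `dt=0.25` and `seq_length=5`, the spikes land on steps 0, 2 and 4. The output gains a leading time dimension of size `seq_length`.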