jaxsnn.discrete.encode

Classes

LIFParameters(tau_syn, tau_mem, v_th, …)

partial

partial(func, *args, **keywords) - new function with partial application of the given arguments and keywords.

Functions

jaxsnn.discrete.encode.constant_current_lif_encode(input_current: jax.Array, seq_length: int)

Treats each input value as a fixed (constant) input current and simulates the spikes it produces over a number of timesteps/iterations (seq_length).

Example:
>>> data = np.array([2, 4, 8, 16])
>>> seq_length = 2 # Simulate two iterations
>>> constant_current_lif_encode(data, seq_length)
 # State in terms of membrane voltage
(DeviceArray([[0.2000, 0.4000, 0.8000, 0.0000],
         [0.3800, 0.7600, 0.0000, 0.0000]]),
 # Spikes for each iteration
 DeviceArray([[0., 0., 0., 1.],
         [0., 0., 1., 1.]]))
Parameters:

input_current (jax.Array): The input array, representing LIF current
seq_length (int): The number of iterations to simulate

Returns:

An array with an extra dimension of size seq_length containing spikes (1) or no spikes (0).
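The encoding described above can be sketched in plain JAX as a scan over a thresholded Euler step. This is a minimal illustration under stated assumptions, not the library's implementation: the names lif_step_sketch and constant_current_encode_sketch are hypothetical, the spike rule (spike when the voltage exceeds v_th, then reset to v_reset) is assumed, and v_th is set to 1.0 because that value reproduces the numbers printed in the example.

```python
import jax
import jax.numpy as jnp


def lif_step_sketch(voltage, input_current, tau_mem=0.01, v_th=1.0,
                    v_leak=0.0, v_reset=0.0, dt=0.001):
    """One Euler step of the membrane equation, followed by threshold and reset.

    Hypothetical helper; v_th=1.0 is an assumption, not the library default.
    """
    voltage = voltage + dt / tau_mem * (v_leak - voltage + input_current)
    spikes = (voltage > v_th).astype(voltage.dtype)
    # Neurons that spiked are clamped back to the reset potential.
    voltage = jnp.where(spikes > 0, v_reset, voltage)
    return voltage, spikes


def constant_current_encode_sketch(input_current, seq_length):
    """Drive the neurons with the same current for seq_length steps."""
    def step(voltage, _):
        voltage, spikes = lif_step_sketch(voltage, input_current)
        # Carry the voltage forward; record voltage and spikes per step.
        return voltage, (voltage, spikes)

    init = jnp.zeros_like(input_current)
    _, (voltages, spikes) = jax.lax.scan(step, init, None, length=seq_length)
    return voltages, spikes


data = jnp.array([2.0, 4.0, 8.0, 16.0])
voltages, spikes = constant_current_encode_sketch(data, seq_length=2)
# voltages: [[0.2, 0.4, 0.8, 0.0], [0.38, 0.76, 0.0, 0.0]]
# spikes:   [[0., 0., 0., 1.], [0., 0., 1., 1.]]
```

In this sketch, lowering v_th to 0.6 would make the 0.8 entry spike in the first step, so the printed example values correspond to a threshold of 1.0 under the assumptions above.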

jaxsnn.discrete.encode.lif_current_encoder(voltage, input_current: jax.Array, params: jaxsnn.base.params.LIFParameters = LIFParameters(tau_syn=0.005, tau_mem=0.01, v_th=0.6, v_leak=0.0, v_reset=0.0), dt: float = 0.001) -> Tuple[jax.Array, jax.Array]

Computes a single Euler integration step of a leaky integrator. More specifically, it implements one integration step of the following ODE:

\[\begin{aligned} \dot{v} &= \frac{1}{\tau_{\text{mem}}} (v_{\text{leak}} - v + i) \\ \dot{i} &= -\frac{1}{\tau_{\text{syn}}} i \end{aligned}\]
Parameters:

voltage (jax.Array): current state of the LIF neuron
input_current (jax.Array): the input current at the current time step
params (LIFParameters): parameters of a leaky integrate-and-fire neuron
dt (float): Integration timestep to use
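
Read as a forward Euler update with step size dt, the ODE above would take the following form. This is a sketch derived from the equations, not taken from the implementation; the step index t is notation introduced here, and thresholding and reset are not part of the ODE and are left out.

\[\begin{aligned} v_{t+1} &= v_t + \frac{dt}{\tau_{\text{mem}}} (v_{\text{leak}} - v_t + i_t) \\ i_{t+1} &= i_t - \frac{dt}{\tau_{\text{syn}}} i_t \end{aligned}\]

With the defaults dt = 0.001 and tau_mem = 0.01, the factor dt / tau_mem is 0.1, so a neuron at v = 0 with v_leak = 0 and input current 2 moves to v = 0.2 in one step.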

jaxsnn.discrete.encode.one_hot(x, k, dtype=jax.numpy.float32)

Create a one-hot encoding of x of size k.
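
A one-hot encoding of this kind can be sketched by comparing each entry of x against the class indices 0..k-1. The function name one_hot_sketch is hypothetical and the broadcasting approach is an assumption about how such a helper is typically written, not a copy of the library code.

```python
import jax.numpy as jnp


def one_hot_sketch(x, k, dtype=jnp.float32):
    """Return an array of shape x.shape + (k,) with a 1 at index x[...]."""
    # Broadcasting compares every label with every class index.
    return (x[..., None] == jnp.arange(k)).astype(dtype)


labels = jnp.array([0, 2, 1])
encoded = one_hot_sketch(labels, 3)
# encoded:
# [[1., 0., 0.],
#  [0., 0., 1.],
#  [0., 1., 0.]]
```

JAX also ships jax.nn.one_hot(x, num_classes), which covers the same use case.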

jaxsnn.discrete.encode.spatio_temporal_encode(input_values: jax.Array, seq_length: int, t_late: float, dt: float)

Encodes n-dimensional input coordinates in the range [0, 1] and simulates the spikes that occur during a number of timesteps/iterations (seq_length).

Example:
>>> data = np.array([2, 4, 8, 16])
>>> seq_length = 2 # Simulate two iterations
>>> spatio_temporal_encode(data, seq_length, t_late, dt)  # t_late and dt as described under Parameters
 # Spikes for each iteration
 DeviceArray([[0., 0., 0., 1.],
         [0., 0., 1., 1.]])
Parameters:

input_values (jax.Array): The input array, representing 2d points
seq_length (int): The number of iterations to simulate
t_early (float): Earliest time at which coordinates may be encoded (listed here, but absent from the signature above)
t_late (float): Latest time at which coordinates may be encoded
dt (float): Time delta between simulation steps

Returns:

An array with an extra dimension of size seq_length containing spikes (1) or no spikes (0).
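
The documentation does not spell out how a coordinate is turned into a spike time. One plausible reading, sketched below, maps a value in [0, 1] linearly to a time in [0, t_late], converts that time to a step index with dt, and emits a single spike at that step. The function name spatio_temporal_encode_sketch, the linear mapping, the clipping, and the placement of the time axis first are all assumptions made for illustration.

```python
import jax.numpy as jnp


def spatio_temporal_encode_sketch(input_values, seq_length, t_late, dt):
    """Emit one spike per input value at a step proportional to the value."""
    # Assumed mapping: value in [0, 1] -> time in [0, t_late] -> step index.
    idx = jnp.round(input_values * t_late / dt).astype(jnp.int32)
    # Keep indices inside the simulated window.
    idx = jnp.clip(idx, 0, seq_length - 1)
    # Build a leading time axis and mark the step at which each value spikes.
    steps = jnp.arange(seq_length).reshape((seq_length,) + (1,) * input_values.ndim)
    return (steps == idx[None, ...]).astype(jnp.float32)


values = jnp.array([0.0, 0.5, 1.0])
spikes_st = spatio_temporal_encode_sketch(values, seq_length=5, t_late=0.004, dt=0.001)
# spikes_st has shape (5, 3); the three inputs spike at steps 0, 2 and 4.
```

Larger values spike later in this sketch, so the input ordering is preserved in the spike timing.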