jaxsnn.event.root.next_finder

Classes

LIFState(V, I)

Spike(time, idx)

Functions

jaxsnn.event.root.next_finder.next_event(solver: Callable, neuron_state: jaxsnn.event.types.LIFState, time: float, t_max: float) → jax.Array

Wrap a root solver to provide a cleaner API for returning the next event.

Args:

solver (Callable): The actual root solver.
neuron_state (LIFState): The state of the neurons.
time (float): Current time.
t_max (float): Maximum time of the simulation.

Returns:

Spike: The next spike to occur.
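The wrapping idea can be sketched as follows: ask the solver for each neuron's time until threshold within the remaining window, then pick the earliest one. This is a minimal sketch under stated assumptions, not the jaxsnn implementation. `LIFState` and `Spike` are local NamedTuple stand-ins for the jaxsnn types; `toy_solver` and `next_event_sketch` are hypothetical; the solver call signature `solver(state, dt)` (returning NaN where a neuron does not fire) and the "no spike gives time `t_max` and index -1" convention are assumptions.

```python
from typing import Callable, NamedTuple

import jax.numpy as jnp


class LIFState(NamedTuple):
    # Hypothetical stand-in for jaxsnn.event.types.LIFState(V, I)
    V: jnp.ndarray  # membrane voltage per neuron
    I: jnp.ndarray  # synaptic current per neuron


class Spike(NamedTuple):
    # Hypothetical stand-in for jaxsnn.event.types.Spike(time, idx)
    time: jnp.ndarray
    idx: jnp.ndarray


def toy_solver(state: LIFState, dt: float) -> jnp.ndarray:
    # Toy model (assumption): V rises linearly as V + I * t toward a threshold
    # of 1.0. Returns the time until threshold per neuron, or NaN if there is
    # no crossing within [0, dt].
    t_cross = (1.0 - state.V) / state.I
    valid = (state.I > 0) & (t_cross >= 0) & (t_cross <= dt)
    return jnp.where(valid, t_cross, jnp.nan)


def next_event_sketch(solver: Callable, neuron_state: LIFState,
                      time: float, t_max: float) -> Spike:
    # Time until each neuron's threshold crossing, searched over the remaining window
    dt = solver(neuron_state, t_max - time)
    # Absolute spike times; neurons without a crossing are pushed to +inf
    pred = jnp.where(jnp.isnan(dt), jnp.inf, time + dt)
    idx = jnp.argmin(pred)
    has_spike = jnp.isfinite(pred[idx])
    # Convention assumed here: no spike -> time t_max and index -1
    return Spike(jnp.where(has_spike, pred[idx], t_max),
                 jnp.where(has_spike, idx, -1))


state = LIFState(V=jnp.array([0.0, 0.5, 0.9]), I=jnp.array([1.0, 1.0, 0.0]))
spike = next_event_sketch(toy_solver, state, time=0.2, t_max=2.0)
# Neuron 1 needs 0.5 time units to reach threshold, so it fires first
```

Replacing NaN with infinity before `jnp.argmin` keeps the selection jit-friendly and avoids relying on how NaN-aware reductions treat all-NaN inputs.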

jaxsnn.event.root.next_finder.next_queue(known_spikes: jaxsnn.event.types.Spike, layer_start: int, neuron_state: jaxsnn.event.types.LIFState, time: float, t_max: float) → jaxsnn.event.types.Spike

Return the upcoming spike when training with hardware-in-the-loop.

When working with the BSS-2 system, we have all the spikes in advance and need to find the index and time of the next event. When the hardware spikes are bound to this function with functools.partial, it has the same API as next_event.

Args:

known_spikes (Spike): All spikes recorded from BSS-2.
layer_start (int): Start index of the current layer.
neuron_state (LIFState): The state of the neurons.
time (float): Current time.
t_max (float): Maximum time of the simulation.

Returns:

Spike: The next spike to occur in the layer.
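Looking up the next event in a list of already recorded spikes, and binding that list with `functools.partial`, can be sketched as below. This is a minimal sketch, not the jaxsnn implementation. `LIFState` and `Spike` are local NamedTuple stand-ins, `next_queue_sketch` is hypothetical, and several details are assumptions: `known_spikes.idx` holds global neuron indices, the current layer's size equals the number of neurons in `neuron_state`, returned indices are local to the layer, and "no spike" yields time `t_max` with index -1. The spike values are made up for illustration.

```python
import functools
from typing import NamedTuple

import jax.numpy as jnp


class LIFState(NamedTuple):
    # Hypothetical stand-in for jaxsnn.event.types.LIFState(V, I)
    V: jnp.ndarray
    I: jnp.ndarray


class Spike(NamedTuple):
    # Hypothetical stand-in for jaxsnn.event.types.Spike(time, idx)
    time: jnp.ndarray
    idx: jnp.ndarray


def next_queue_sketch(known_spikes: Spike, layer_start: int,
                      neuron_state: LIFState, time: float, t_max: float) -> Spike:
    # Assumption: the layer covers global indices [layer_start, layer_start + n_layer)
    n_layer = neuron_state.V.shape[0]
    in_layer = (known_spikes.idx >= layer_start) & (known_spikes.idx < layer_start + n_layer)
    # Only spikes of this layer that lie strictly after the current time qualify
    candidate = in_layer & (known_spikes.time > time)
    masked = jnp.where(candidate, known_spikes.time, jnp.inf)
    pos = jnp.argmin(masked)  # earliest qualifying spike; no sorting required
    has_spike = jnp.isfinite(masked[pos])
    # Convention assumed here: no spike -> time t_max and index -1;
    # otherwise the index is made local to the layer
    return Spike(jnp.where(has_spike, masked[pos], t_max),
                 jnp.where(has_spike, known_spikes.idx[pos] - layer_start, -1))


# Made-up recorded spikes; the current layer covers global neurons 2, 3 and 4
recorded = Spike(time=jnp.array([0.1, 0.3, 0.4, 0.6]), idx=jnp.array([0, 3, 1, 4]))
layer_state = LIFState(V=jnp.zeros(3), I=jnp.zeros(3))

# Binding the recorded spikes and the layer offset leaves a callable taking
# (neuron_state, time, t_max), matching next_event once its solver is bound
next_fn = functools.partial(next_queue_sketch, recorded, 2)
first = next_fn(layer_state, 0.2, 1.0)
second = next_fn(layer_state, first.time, 1.0)
```

Because both bound callables accept `(neuron_state, time, t_max)`, an event loop could swap the software root solver for recorded hardware spikes without changing its call sites; that reading of "same API" is an interpretation of the text above.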