jaxsnn.event.root.next_finder
Classes
Functions
-
jaxsnn.event.root.next_finder.next_event(solver: Callable, neuron_state: jaxsnn.event.types.LIFState, time: float, t_max: float) → jax.Array
Wraps a root solver to provide a cleaner API for returning the next event.
- Args:
solver (Callable): The actual root solver
neuron_state (LIFState): The state of the neurons
time (float): Current time
t_max (float): Maximum time of the simulation
- Returns:
Spike: Spike which will occur next
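The wrapping idea can be illustrated with a minimal sketch. It is not the jaxsnn implementation: the types LIFState and Spike below are plain-Python stand-ins for the jaxsnn classes, toy_solver is a hypothetical solver, and the convention that the solver returns per-neuron times relative to the current time is an assumption. The sketch only shows how a wrapper can turn per-neuron solver output into a single "next spike".

```python
import math
from typing import Callable, List, NamedTuple


class LIFState(NamedTuple):
    """Stand-in for jaxsnn.event.types.LIFState (hypothetical fields)."""
    V: List[float]  # membrane voltages
    I: List[float]  # input currents


class Spike(NamedTuple):
    """Stand-in for a single event: when it happens and which neuron fires."""
    time: float
    idx: int


def toy_solver(state: LIFState, dt: float) -> List[float]:
    """Hypothetical solver: time until each neuron reaches threshold 1.0,
    assuming a linear voltage ramp. Neurons that do not spike within dt
    report dt (assumed convention for "no spike")."""
    times = []
    for v, i in zip(state.V, state.I):
        t = (1.0 - v) / i if i > 0 else math.inf
        times.append(min(t, dt))
    return times


def next_event_sketch(
    solver: Callable, neuron_state: LIFState, time: float, t_max: float
) -> Spike:
    """Ask the solver for per-neuron spike times and return the earliest one."""
    relative = solver(neuron_state, t_max - time)
    idx = min(range(len(relative)), key=relative.__getitem__)
    # Convert the relative time back to absolute simulation time.
    return Spike(time=time + relative[idx], idx=idx)


state = LIFState(V=[0.0, 0.5, 0.2], I=[1.0, 2.0, 0.0])
spike = next_event_sketch(toy_solver, state, time=0.5, t_max=2.0)
print(spike)  # Spike(time=0.75, idx=1)
```

In this toy setup the second neuron needs 0.25 time units to reach threshold, so the event is reported at absolute time 0.75.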
-
jaxsnn.event.root.next_finder.next_queue(known_spikes: jaxsnn.event.types.Spike, layer_start: int, neuron_state: jaxsnn.event.types.LIFState, time: float, t_max: float) → jaxsnn.event.types.Spike
Return the upcoming spike when training with hardware-in-the-loop.
When working with the BSS-2 system, we have all the spikes in advance and need to find the index and time of the next event. When the hardware spikes are bound to this function with functools.partial, it has the same API as next_event.
- Args:
known_spikes (Spike): All spikes from BSS-2
layer_start (int): Start index of the current layer
neuron_state (LIFState): The state of the neurons
time (float): Current time
t_max (float): Max time
- Returns:
Spike: Spike which will occur next in the layer
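The binding with functools.partial described above can be sketched as follows. This is an illustration under stated assumptions, not the jaxsnn code: KnownSpikes and Spike are plain-Python stand-ins, next_queue_sketch is a hypothetical re-implementation, and the details of the lookup (strictly later than the current time, neuron index at or above layer_start, index returned as recorded, Spike(t_max, -1) when nothing is left) are assumptions not confirmed by the source.

```python
from functools import partial
from typing import Any, List, NamedTuple


class KnownSpikes(NamedTuple):
    """Stand-in for the recorded hardware spikes (parallel lists)."""
    time: List[float]
    idx: List[int]


class Spike(NamedTuple):
    """Stand-in for a single event."""
    time: float
    idx: int


def next_queue_sketch(
    known_spikes: KnownSpikes,
    layer_start: int,
    neuron_state: Any,
    time: float,
    t_max: float,
) -> Spike:
    """Look up the next recorded spike of the current layer after `time`.

    neuron_state is unused here; it is accepted only so the call
    signature lines up with a solver-based finder.
    """
    candidates = [
        (t, i)
        for t, i in zip(known_spikes.time, known_spikes.idx)
        if time < t <= t_max and i >= layer_start
    ]
    if not candidates:
        # Assumed convention for "no further spike": t_max and index -1.
        return Spike(time=t_max, idx=-1)
    t, i = min(candidates)
    return Spike(time=t, idx=i)


recorded = KnownSpikes(time=[0.1, 0.3, 0.4, 0.9], idx=[0, 5, 2, 6])

# Binding the recorded spikes leaves four positional arguments,
# the same count that next_event takes.
find_next = partial(next_queue_sketch, recorded)

print(find_next(4, None, 0.2, 1.0))  # Spike(time=0.3, idx=5)
print(find_next(4, None, 0.3, 1.0))  # Spike(time=0.9, idx=6)
print(find_next(4, None, 0.9, 1.0))  # Spike(time=1.0, idx=-1)
```

Note that after binding, the first remaining positional argument is layer_start rather than a solver, so the two functions match in call shape but not in the meaning of that first argument.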