jaxsnn.event.adjoint_lif
Classes
- partial(func, *args, **keywords): new function with partial application of the given arguments and keywords.
Functions
- jaxsnn.event.adjoint_lif.adjoint_lif_dynamic(params: jaxsnn.base.params.LIFParameters, lambda_0: jax.Array, t: float)
- jaxsnn.event.adjoint_lif.adjoint_lif_exponential_flow(params: jaxsnn.base.params.LIFParameters)
- jaxsnn.event.adjoint_lif.adjoint_transition_with_recurrence(params: jaxsnn.base.params.LIFParameters, adjoint_state: jaxsnn.event.types.StepState, spike: jaxsnn.event.types.EventPropSpike, layer_start: int, adjoint_spike: jaxsnn.event.types.EventPropSpike, grads: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], weights: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], input_queue_head: int)
- jaxsnn.event.adjoint_lif.adjoint_transition_without_recurrence(params: jaxsnn.base.params.LIFParameters, adjoint_state: jaxsnn.event.types.StepState, spike: jaxsnn.event.types.EventPropSpike, layer_start: int, adjoint_spike: jaxsnn.event.types.EventPropSpike, grads: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], weights: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], input_queue_head: int)
- jaxsnn.event.adjoint_lif.construct_adjoint_apply_fn(step_fn, step_fn_bwd, n_hidden, n_spikes, wrap_only_step=False)
- jaxsnn.event.adjoint_lif.exponential_flow(kernel: jax.Array)
- jaxsnn.event.adjoint_lif.filter_spikes(input_spikes: jaxsnn.event.types.EventPropSpike, prev_layer_start: int) → jaxsnn.event.types.EventPropSpike
  Filters the input spikes, keeping only the spikes from the previous layer.
- jaxsnn.event.adjoint_lif.next_queue(known_spikes: jaxsnn.event.types.Spike, layer_start: int, neuron_state: jaxsnn.event.types.LIFState, time: float, t_max: float) → jaxsnn.event.types.Spike
  Return the upcoming spike when training with hardware-in-the-loop.
  When working with the BSS-2 system, all spikes are available in advance, so only the index and time of the next event need to be found. Once the hardware spikes are bound to this function with functools.partial, it has the same API as next_event.
  Args:
    known_spikes (Spike): All spikes from BSS-2
    layer_start (int): Start index of the current layer
    neuron_state (LIFState): The state of the neurons
    time (float): Current time
    t_max (float): Maximum time
  Returns:
    Spike: The spike that will occur next in the layer
- jaxsnn.event.adjoint_lif.step_bwd(adjoint_dynamics: Callable, adjoint_tr_dynamics: Callable, t_max: float, res, g)
- jaxsnn.event.adjoint_lif.trajectory(step_fn: Callable[[Tuple[jaxsnn.event.types.StepState, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], int], int], Tuple[Tuple[jaxsnn.event.types.StepState, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], int], jaxsnn.event.types.EventPropSpike]], n_hidden: int, n_spikes: int) → Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike], jaxsnn.event.types.EventPropSpike]
  Evaluate the step_fn until n_spikes have been simulated.
  Uses a scan over the step_fn to return an apply function.
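The filter_spikes entry only states that spikes from the previous layer are kept. The following is a minimal plain-Python sketch of one plausible reading, not jaxsnn's implementation. It assumes that spike indices are global, that every index at or above prev_layer_start belongs to the previous layer, and that dropped spikes are masked (time = inf, idx = -1, current = 0.0) rather than removed, so the output keeps a fixed length as jit-compiled JAX code would require. EventPropSpikeSketch and filter_spikes_sketch are hypothetical names standing in for EventPropSpike and filter_spikes.

```python
import math
from typing import NamedTuple, Tuple


class EventPropSpikeSketch(NamedTuple):
    # Hypothetical stand-in for jaxsnn.event.types.EventPropSpike:
    # parallel, fixed-length sequences of spike times, neuron indices and currents.
    time: Tuple[float, ...]
    idx: Tuple[int, ...]
    current: Tuple[float, ...]


def filter_spikes_sketch(
    input_spikes: EventPropSpikeSketch, prev_layer_start: int
) -> EventPropSpikeSketch:
    """Keep spikes whose global index is >= prev_layer_start (assumption).

    Dropped entries are masked instead of removed, so the output has the same
    length as the input. Kept indices are shifted to be layer-relative.
    """
    keep = [i >= prev_layer_start for i in input_spikes.idx]
    return EventPropSpikeSketch(
        # Masked spikes get an infinite time so they never become "next".
        time=tuple(t if k else math.inf for t, k in zip(input_spikes.time, keep)),
        # Masked spikes get the sentinel index -1.
        idx=tuple(
            i - prev_layer_start if k else -1 for i, k in zip(input_spikes.idx, keep)
        ),
        # Masked spikes carry no current.
        current=tuple(c if k else 0.0 for c, k in zip(input_spikes.current, keep)),
    )


spikes = EventPropSpikeSketch(time=(0.1, 0.2, 0.3), idx=(0, 3, 5), current=(1.0, 2.0, 3.0))
print(filter_spikes_sketch(spikes, prev_layer_start=3))
# EventPropSpikeSketch(time=(inf, 0.2, 0.3), idx=(-1, 0, 2), current=(0.0, 2.0, 3.0))
```

Masking instead of deleting mirrors how array code usually handles data-dependent filtering (for example with a `where`-style select), because array shapes must stay static under compilation.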
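The next_queue entry says that, once the recorded BSS-2 spikes are bound with functools.partial, the function has the same API as next_event. The sketch below illustrates that binding pattern in plain Python. Everything in it is an assumption rather than jaxsnn code: SpikeTrain, SpikeSketch and next_queue_sketch are hypothetical names, known spikes are assumed to be sorted by time, neurons with global index >= layer_start are taken to belong to the layer, the returned index is layer-relative, and "no further spike before t_max" is signalled by the sentinel SpikeSketch(t_max, -1). neuron_state is accepted but ignored, since the hardware spikes are already known.

```python
from functools import partial
from typing import NamedTuple, Tuple


class SpikeTrain(NamedTuple):
    # Hypothetical container for all recorded hardware spikes, sorted by time.
    time: Tuple[float, ...]
    idx: Tuple[int, ...]


class SpikeSketch(NamedTuple):
    # Hypothetical stand-in for jaxsnn.event.types.Spike (one time, one index).
    time: float
    idx: int


def next_queue_sketch(
    known_spikes: SpikeTrain,
    layer_start: int,
    neuron_state: object,
    time: float,
    t_max: float,
) -> SpikeSketch:
    """Return the first recorded spike of this layer in (time, t_max]."""
    # neuron_state is unused: the spikes were recorded on hardware, so no
    # dynamics need to be integrated to find the next event.
    for t, i in zip(known_spikes.time, known_spikes.idx):
        if i >= layer_start and time < t <= t_max:
            return SpikeSketch(t, i - layer_start)  # layer-relative index
    return SpikeSketch(t_max, -1)  # sentinel: no further spike in this layer


hardware = SpikeTrain(time=(0.1, 0.2, 0.4, 0.5), idx=(0, 2, 3, 1))
# Bind the recorded spikes once; the partial object takes the remaining arguments.
next_fn = partial(next_queue_sketch, hardware)
print(next_fn(2, None, 0.15, 1.0))  # SpikeSketch(time=0.2, idx=0)
print(next_fn(2, None, 0.45, 1.0))  # SpikeSketch(time=1.0, idx=-1)
```

Binding the data with functools.partial leaves a callable whose remaining parameters are exactly those declared after known_spikes, which is what lets a hardware-backed event function be swapped in wherever a simulated one is expected.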
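The trajectory entry describes building an apply function from a scan over step_fn that runs until n_spikes events have been produced. The plain-Python sketch below shows that pattern under simplifying assumptions. The local scan helper follows the contract of jax.lax.scan (f(carry, x) returns (new_carry, y), and the per-step outputs are collected). trajectory_sketch, toy_step and the (layer_start, weights, initial_state) call signature are hypothetical: the documented apply function instead receives an int, the weights and the input spikes, and the way jaxsnn builds the initial StepState from n_hidden and the input spikes is left out here.

```python
from typing import Any, Callable, Iterable, List, Tuple


def scan(
    f: Callable[[Any, Any], Tuple[Any, Any]], init: Any, xs: Iterable[Any]
) -> Tuple[Any, List[Any]]:
    """Loop with the same contract as jax.lax.scan, collecting outputs in a list."""
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def trajectory_sketch(step_fn: Callable, n_spikes: int) -> Callable:
    """Build an apply function that runs step_fn exactly n_spikes times."""

    def apply_fn(layer_start: int, weights: Any, initial_state: Any) -> List[Any]:
        # The carry mirrors step_fn's input tuple (state, weights, layer_start);
        # the scanned value is the step counter 0 .. n_spikes - 1.
        _, spikes = scan(step_fn, (initial_state, weights, layer_start), range(n_spikes))
        return spikes

    return apply_fn


def toy_step(carry, i):
    # Hypothetical toy dynamics: the state is the last spike time, the "weight"
    # is a fixed inter-spike interval, and two neurons fire alternately.
    t, interval, layer_start = carry
    t_next = t + interval
    return (t_next, interval, layer_start), (t_next, layer_start + i % 2)


apply_fn = trajectory_sketch(toy_step, n_spikes=3)
print(apply_fn(0, 0.5, 0.0))  # [(0.5, 0), (1.0, 1), (1.5, 0)]
```

In JAX itself, jax.lax.scan traces the step function once and stacks the per-step outputs into arrays; fixing the number of steps to n_spikes is what keeps the output shape static.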