jaxsnn.event.adjoint_lif
Classes
partial(func, *args, **keywords): new function with partial application of the given arguments and keywords.
Functions
- jaxsnn.event.adjoint_lif.adjoint_lif_dynamic(params: jaxsnn.base.params.LIFParameters, lambda_0: jax.Array, t: float)
- jaxsnn.event.adjoint_lif.adjoint_lif_exponential_flow(params: jaxsnn.base.params.LIFParameters)
- jaxsnn.event.adjoint_lif.adjoint_transition_with_recurrence(params: jaxsnn.base.params.LIFParameters, adjoint_state: jaxsnn.event.types.StepState, spike: jaxsnn.event.types.EventPropSpike, layer_start: int, adjoint_spike: jaxsnn.event.types.EventPropSpike, grads: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], weights: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], input_queue_head: int)
- jaxsnn.event.adjoint_lif.adjoint_transition_without_recurrence(params: jaxsnn.base.params.LIFParameters, adjoint_state: jaxsnn.event.types.StepState, spike: jaxsnn.event.types.EventPropSpike, layer_start: int, adjoint_spike: jaxsnn.event.types.EventPropSpike, grads: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], weights: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], input_queue_head: int)
- jaxsnn.event.adjoint_lif.construct_adjoint_apply_fn(step_fn, step_fn_bwd, size, n_spikes, wrap_only_step=False)
- jaxsnn.event.adjoint_lif.exponential_flow(kernel: jax.Array)
- jaxsnn.event.adjoint_lif.filter_spikes(input_spikes: jaxsnn.event.types.EventPropSpike, prev_layer_start: int) → jaxsnn.event.types.EventPropSpike
  Filters the input spikes, keeping only the spikes that come from the previous layer.
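A minimal sketch of this kind of filtering in JAX, under assumptions that the signature does not state: `Spike` is a hypothetical stand-in for `EventPropSpike` carrying only `time` and `idx`, neurons use global indices so that spikes from the previous layer satisfy `idx >= prev_layer_start`, and rejected spikes are masked (time set to infinity, index set to -1) rather than removed, because JAX arrays need static shapes. The actual function may treat indices and any other spike fields differently.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Spike(NamedTuple):
    """Hypothetical stand-in for EventPropSpike: one entry per spike slot."""
    time: jax.Array  # spike times, shape (n,)
    idx: jax.Array   # global neuron indices, shape (n,)


def filter_spikes_sketch(spikes: Spike, prev_layer_start: int) -> Spike:
    # Assumption: the previous layer's neurons have idx >= prev_layer_start.
    keep = spikes.idx >= prev_layer_start
    # Shapes must stay static under jit, so rejected spikes are masked
    # instead of dropped: time becomes +inf and the index becomes -1.
    return Spike(
        time=jnp.where(keep, spikes.time, jnp.inf),
        idx=jnp.where(keep, spikes.idx, -1),
    )


spikes = Spike(time=jnp.array([0.1, 0.2, 0.3]), idx=jnp.array([0, 5, 7]))
filtered = filter_spikes_sketch(spikes, prev_layer_start=5)
```

Masking with an infinite time means downstream code that picks the earliest event never selects a rejected spike.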
- jaxsnn.event.adjoint_lif.step_bwd(adjoint_dynamics: Callable, adjoint_tr_dynamics: Callable, t_max: float, res, g)
- jaxsnn.event.adjoint_lif.trajectory(step_fn: Callable[[Tuple[jaxsnn.event.types.StepState, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], int], int], Tuple[Tuple[jaxsnn.event.types.StepState, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], int], jaxsnn.event.types.EventPropSpike]], size: int, n_spikes: int) → Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike], jaxsnn.event.types.EventPropSpike]
  Evaluates step_fn until n_spikes spikes have been simulated. Uses a scan over step_fn to build and return an apply function.
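The scan pattern can be sketched with `jax.lax.scan`. Apart from `jax.lax.scan` itself, everything below is hypothetical: `Spike` stands in for `EventPropSpike`, `toy_step` is a toy event step in which the earliest pending neuron fires and `weights[i]` is reinterpreted as that neuron's firing period, and the returned `apply_fn` takes only the weights instead of the documented `(int, weights, EventPropSpike)` arguments. The carry mirrors the documented `(StepState, weights, int)` tuple, and each scan step emits exactly one spike, so the result holds `n_spikes` spikes.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class Spike(NamedTuple):
    """Hypothetical stand-in for EventPropSpike."""
    time: jax.Array
    idx: jax.Array


def trajectory_sketch(step_fn: Callable, size: int, n_spikes: int) -> Callable:
    """Scan step_fn n_spikes times; each step emits exactly one spike."""

    def apply_fn(weights: jax.Array) -> Spike:
        # Initial carry: (state, weights, input_queue_head). In this toy the
        # state is the next firing time of each of the `size` neurons.
        init = (jnp.zeros(size), weights, 0)
        # lax.scan stacks the per-step Spike outputs along a new leading axis.
        _, spikes = jax.lax.scan(step_fn, init, jnp.arange(n_spikes))
        return spikes

    return apply_fn


def toy_step(carry, _):
    """Hypothetical event step: the neuron with the earliest pending time fires."""
    next_t, weights, head = carry
    i = jnp.argmin(next_t)
    t = next_t[i]
    # Toy assumption: weights[i] acts as neuron i's firing period.
    next_t = next_t.at[i].add(weights[i])
    return (next_t, weights, head), Spike(time=t, idx=i)


apply_fn = trajectory_sketch(toy_step, size=2, n_spikes=4)
spikes = apply_fn(jnp.array([1.0, 1.5]))
```

With periods 1.0 and 1.5, the four spikes come from neurons 0, 1, 0, 1 at times 0, 0, 1 and 1.5. A fixed-length scan keeps every shape static, which is why the number of simulated spikes has to be fixed up front.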