jaxsnn.event.adjoint_lif

Classes

EventPropSpike(time, idx, current)

InputQueue(spikes, head)

LIFParameters(tau_syn, tau_mem, v_th, …)

LIFState(V, I)

StepState(neuron_state, spike_times, …)

WeightInput(input,)

WeightRecurrent(input, recurrent)

partial

partial(func, *args, **keywords) - new function with partial application of the given arguments and keywords.

Functions

jaxsnn.event.adjoint_lif.adjoint_lif_dynamic(params: jaxsnn.base.params.LIFParameters, lambda_0: jax.Array, t: float)
jaxsnn.event.adjoint_lif.adjoint_lif_exponential_flow(params: jaxsnn.base.params.LIFParameters)
jaxsnn.event.adjoint_lif.adjoint_transition_with_recurrence(params: jaxsnn.base.params.LIFParameters, adjoint_state: jaxsnn.event.types.StepState, spike: jaxsnn.event.types.EventPropSpike, layer_start: int, adjoint_spike: jaxsnn.event.types.EventPropSpike, grads: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], weights: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], input_queue_head: int)
jaxsnn.event.adjoint_lif.adjoint_transition_without_recurrence(params: jaxsnn.base.params.LIFParameters, adjoint_state: jaxsnn.event.types.StepState, spike: jaxsnn.event.types.EventPropSpike, layer_start: int, adjoint_spike: jaxsnn.event.types.EventPropSpike, grads: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], weights: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], input_queue_head: int)
jaxsnn.event.adjoint_lif.construct_adjoint_apply_fn(step_fn, step_fn_bwd, n_hidden, n_spikes, wrap_only_step=False)
jaxsnn.event.adjoint_lif.exponential_flow(kernel: jax.Array)
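adjoint_lif_dynamic, adjoint_lif_exponential_flow and exponential_flow have no docstrings here. The sketch below is an illustration only, under assumptions the reference does not state. It assumes that exponential_flow(kernel) returns a function mapping an initial state x0 and a time t to expm(kernel * t) @ x0. It also assumes that, for the LIFState(V, I) ordering, the LIF kernel is A = [[-1/tau_mem, 1/tau_mem], [0, -1/tau_syn]] with no leak potential. The names lif_kernel and exponential_flow_sketch are hypothetical.

For the adjoint it uses a standard result: the adjoint of x' = A x satisfies λ' = -Aᵀλ, so integrating backward in time is the flow of Aᵀ. jaxsnn's own scaling of the adjoint variables may differ. The matrix exponential is computed by eigendecomposition to stay NumPy-only. That is valid here because A has the distinct eigenvalues -1/tau_mem and -1/tau_syn whenever tau_mem != tau_syn. A library would more likely call a general matrix-exponential routine.

```python
import numpy as np


def lif_kernel(tau_mem: float, tau_syn: float) -> np.ndarray:
    # Hypothetical helper. Assumed dynamics for the state [V, I]:
    #   tau_mem * dV/dt = -V + I,   tau_syn * dI/dt = -I
    return np.array([[-1.0 / tau_mem, 1.0 / tau_mem],
                     [0.0, -1.0 / tau_syn]])


def exponential_flow_sketch(kernel: np.ndarray):
    # Hypothetical stand-in for exponential_flow: flow(x0, t) = expm(kernel * t) @ x0.
    # Eigendecomposition requires a diagonalizable kernel
    # (for the LIF kernel: tau_mem != tau_syn, so the two eigenvalues differ).
    eigvals, eigvecs = np.linalg.eig(kernel)
    eigvecs_inv = np.linalg.inv(eigvecs)

    def flow(x0: np.ndarray, t: float) -> np.ndarray:
        # expm(A t) = V diag(exp(w t)) V^-1
        propagator = eigvecs @ np.diag(np.exp(eigvals * t)) @ eigvecs_inv
        return np.real(propagator @ x0)

    return flow


tau_mem, tau_syn = 10e-3, 5e-3
A = lif_kernel(tau_mem, tau_syn)
forward = exponential_flow_sketch(A)    # state [V, I], forward in time
adjoint = exponential_flow_sketch(A.T)  # adjoint [lambda_V, lambda_I], in reversed time

x_t = forward(np.array([0.0, 1.0]), 2e-3)
lam = adjoint(np.array([1.0, 0.0]), 2e-3)
```

In the forward flow the voltage is driven by the current. In the adjoint flow the roles swap, and lambda_I is driven by lambda_V.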
jaxsnn.event.adjoint_lif.filter_spikes(input_spikes: jaxsnn.event.types.EventPropSpike, prev_layer_start: int) → jaxsnn.event.types.EventPropSpike

Filters the input spikes so that only the spikes emitted by the previous layer are kept.
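Here is a minimal NumPy sketch of this filtering. It relies on assumptions the docstring does not state:

- Spike indices are global neuron indices, with layers stored contiguously.
- The input holds no spikes from layers after the previous one.
- Discarded spikes are masked (time set to inf, index to -1, current to 0) rather than removed. This keeps the array shape fixed, as jax.jit requires.
- Kept indices are shifted to be local to the previous layer.

The name filter_spikes_sketch is hypothetical.

```python
from typing import NamedTuple

import numpy as np


class EventPropSpike(NamedTuple):
    # Mirrors the EventPropSpike(time, idx, current) fields listed above.
    time: np.ndarray
    idx: np.ndarray
    current: np.ndarray


def filter_spikes_sketch(input_spikes: EventPropSpike, prev_layer_start: int) -> EventPropSpike:
    # Assumption: idx >= prev_layer_start marks spikes of the previous layer.
    keep = input_spikes.idx >= prev_layer_start
    return EventPropSpike(
        time=np.where(keep, input_spikes.time, np.inf),  # masked spikes never arrive
        idx=np.where(keep, input_spikes.idx - prev_layer_start, -1),  # layer-local index
        current=np.where(keep, input_spikes.current, 0.0),
    )


spikes = EventPropSpike(
    time=np.array([0.1, 0.2, 0.3]),
    idx=np.array([1, 4, 6]),
    current=np.array([0.5, 0.5, 0.5]),
)
filtered = filter_spikes_sketch(spikes, prev_layer_start=4)
```

Masking instead of boolean indexing means the output has the same length as the input. Downstream code must therefore treat an index of -1 as "no spike".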

jaxsnn.event.adjoint_lif.next_queue(known_spikes: jaxsnn.event.types.Spike, layer_start: int, neuron_state: jaxsnn.event.types.LIFState, time: float, t_max: float) → jaxsnn.event.types.Spike

Return the upcoming spike when training with hardware-in-the-loop.

When working with the BSS-2 system, all spikes are known in advance, so the function only needs to find the index and time of the next event. Once the hardware spikes are bound to the known_spikes argument with functools.partial, the resulting function has the same API as next_event.

Args:

known_spikes (Spike): All spikes from BSS-2
layer_start (int): Start index of the current layer
neuron_state (LIFState): The state of the neurons
time (float): Current time
t_max (float): Max time

Returns:

Spike: Spike which will occur next in the layer
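The lookup and the functools.partial binding can be sketched in NumPy as follows. This is not jaxsnn's implementation. The following details are all hypothetical:

- the Spike fields (time, idx);
- the assumption that known_spikes contains no spikes from later layers;
- the layer-local output index;
- the -1 index returned together with t_max when no spike remains;
- the name next_queue_sketch.

Ties (two spikes at exactly the same time) are not handled, because only spikes strictly after time are considered.

```python
import functools
from typing import NamedTuple

import numpy as np


class Spike(NamedTuple):
    time: np.ndarray
    idx: np.ndarray


def next_queue_sketch(known_spikes: Spike, layer_start: int, neuron_state, time: float, t_max: float) -> Spike:
    # neuron_state is unused: the hardware has already produced every spike.
    # Candidates are spikes of this layer strictly after the current time.
    candidate = (known_spikes.idx >= layer_start) & (known_spikes.time > time)
    times = np.where(candidate, known_spikes.time, np.inf)
    first = np.argmin(times)
    found = np.isfinite(times[first])
    # Branch-free selection (np.where instead of if) keeps the sketch traceable
    # if np is swapped for jax.numpy.
    return Spike(
        time=np.where(found, times[first], t_max),
        idx=np.where(found, known_spikes.idx[first] - layer_start, -1),
    )


# Bind the recorded hardware spikes; the result takes (layer_start, neuron_state, time, t_max).
hw_spikes = Spike(time=np.array([1e-3, 2e-3, 4e-3]), idx=np.array([0, 3, 2]))
next_event_fn = functools.partial(next_queue_sketch, hw_spikes)

s1 = next_event_fn(2, None, 1.5e-3, 1e-2)          # neuron 3 (local index 1) at 2 ms
s2 = next_event_fn(2, None, float(s1.time), 1e-2)  # neuron 2 (local index 0) at 4 ms
s3 = next_event_fn(2, None, float(s2.time), 1e-2)  # nothing left: t_max and index -1
```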

jaxsnn.event.adjoint_lif.step_bwd(adjoint_dynamics: Callable, adjoint_tr_dynamics: Callable, t_max: float, res, g)
jaxsnn.event.adjoint_lif.trajectory(step_fn: Callable[[Tuple[jaxsnn.event.types.StepState, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], int], int], Tuple[Tuple[jaxsnn.event.types.StepState, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], int], jaxsnn.event.types.EventPropSpike]], n_hidden: int, n_spikes: int) → Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike], jaxsnn.event.types.EventPropSpike]

Evaluate the step_fn until n_spikes have been simulated.

Uses a scan over step_fn to build and return an apply function.
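The step_fn signature above has the (carry, x) -> (carry, y) shape that jax.lax.scan expects. The core of trajectory can therefore be sketched as one scan over n_spikes steps. This is a simplification, not jaxsnn's code:

- The carry is a single scalar time instead of (StepState, weights, layer_start).
- n_hidden and the input spikes are omitted.
- toy_step, PERIOD and N_NEURONS are hypothetical, a toy layer in which neurons fire round-robin at a fixed period.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Spike(NamedTuple):
    time: jax.Array
    idx: jax.Array


def trajectory_sketch(step_fn, n_spikes: int):
    # Returns apply_fn(init_carry), which evaluates step_fn exactly n_spikes times.
    def apply_fn(init_carry):
        # scan threads the carry through the calls and stacks the per-step
        # outputs, so each Spike field gains a leading axis of length n_spikes.
        _, spikes = jax.lax.scan(step_fn, init_carry, jnp.arange(n_spikes))
        return spikes

    return apply_fn


PERIOD = 0.5   # hypothetical inter-spike interval
N_NEURONS = 3  # hypothetical layer size


def toy_step(time, step_idx):
    # Hypothetical event step: the next neuron in round-robin order fires PERIOD later.
    next_time = time + PERIOD
    return next_time, Spike(time=next_time, idx=step_idx % N_NEURONS)


apply_fn = jax.jit(trajectory_sketch(toy_step, n_spikes=4))
spikes = apply_fn(jnp.array(0.0))
```

Because the number of steps is fixed in advance, the output has a static shape, which is what allows the scan to run under jax.jit. With the documented signature, apply_fn would instead take (layer_start, weights, input_spikes) and presumably build the initial StepState itself before scanning.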