jaxsnn.event.adjoint_lif

Classes

EventPropSpike(time, idx, current)

InputQueue(spikes, head)

LIFParameters(tau_syn, tau_mem, v_th, …)

LIFState(V, I)

Spike(time, idx)

StepState(neuron_state, spike_times, …)

WeightInput(input,)

WeightRecurrent(input, recurrent)
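The classes above are small records. As an illustration only, the following stand-ins reproduce the listed field names with typing.NamedTuple. The field types, the pop_next helper and the use of (inf, -1) to signal an exhausted queue are assumptions, not the jaxsnn definitions. LIFParameters and StepState are omitted because their field lists are elided above.

```python
import math
from typing import List, NamedTuple, Tuple


# Hypothetical stand-ins: field names follow the class list above,
# field types are assumptions.
class Spike(NamedTuple):
    time: float
    idx: int


class EventPropSpike(NamedTuple):
    time: float
    idx: int
    current: float


class LIFState(NamedTuple):
    V: float  # membrane voltage
    I: float  # synaptic current


class WeightInput(NamedTuple):
    input: List[List[float]]  # input -> layer weights


class WeightRecurrent(NamedTuple):
    input: List[List[float]]      # input -> layer weights
    recurrent: List[List[float]]  # layer -> layer weights


class InputQueue(NamedTuple):
    spikes: Spike  # a Spike whose fields are equal-length lists, sorted by time
    head: int      # position of the next unprocessed spike


def pop_next(queue: InputQueue) -> Tuple[Spike, InputQueue]:
    """Hypothetical helper: return the spike at head and a queue advanced by one."""
    if queue.head >= len(queue.spikes.time):
        # Exhausted queue: signal "no more input" with an infinite time.
        return Spike(math.inf, -1), queue
    spike = Spike(queue.spikes.time[queue.head], queue.spikes.idx[queue.head])
    return spike, queue._replace(head=queue.head + 1)
```

Advancing head instead of removing elements keeps the stored spikes unchanged, which is the same shape-preserving style JAX code needs.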

partial

partial(func, *args, **keywords) - new function with partial application of the given arguments and keywords.
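This entry appears to be the standard library's functools.partial (its docstring matches), picked up from the module's namespace. A typical use is to fix leading arguments, such as a parameter set, ahead of time so that the remaining callable has the signature a solver or scan expects. The decay function below is a hypothetical example, not part of jaxsnn.

```python
import math
from functools import partial


def decay(tau: float, v0: float, t: float) -> float:
    """Hypothetical helper: exponential decay v0 * exp(-t / tau)."""
    return v0 * math.exp(-t / tau)


# Freeze the time constant; the result is a function of (v0, t) only.
decay_10ms = partial(decay, 10e-3)

value = decay_10ms(1.0, 10e-3)  # same as decay(10e-3, 1.0, 10e-3)
```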

Functions

jaxsnn.event.adjoint_lif.adjoint_lif_dynamic(params: jaxsnn.base.params.LIFParameters, lambda_0: jax.Array, t: float)
jaxsnn.event.adjoint_lif.adjoint_lif_exponential_flow(params: jaxsnn.base.params.LIFParameters)
jaxsnn.event.adjoint_lif.adjoint_transition_with_recurrence(params: jaxsnn.base.params.LIFParameters, adjoint_state: jaxsnn.event.types.StepState, spike: jaxsnn.event.types.EventPropSpike, layer_start: int, adjoint_spike: jaxsnn.event.types.EventPropSpike, grads: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], weights: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], input_queue_head: int)
jaxsnn.event.adjoint_lif.adjoint_transition_without_recurrence(params: jaxsnn.base.params.LIFParameters, adjoint_state: jaxsnn.event.types.StepState, spike: jaxsnn.event.types.EventPropSpike, layer_start: int, adjoint_spike: jaxsnn.event.types.EventPropSpike, grads: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], weights: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], input_queue_head: int)
jaxsnn.event.adjoint_lif.construct_adjoint_apply_fn(step_fn, step_fn_bwd, size, n_spikes, wrap_only_step=False)
jaxsnn.event.adjoint_lif.exponential_flow(kernel: jax.Array)
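adjoint_lif_dynamic, adjoint_lif_exponential_flow and exponential_flow carry no docstrings. For linear dynamics dx/dt = A x the exact solution is x(t) = expm(A t) x(0), and the adjoint state of such a system evolves with the transposed matrix A^T when integrated in reversed time. The sketch below builds such a flow from a kernel. The two-state kernel for x = [V, I] (tau_mem dV/dt = -V + I, tau_syn dI/dt = -I) is an assumed textbook LIF form and exponential_flow_sketch is a hypothetical name, so jaxsnn's kernel layout and argument order may differ.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm


def exponential_flow_sketch(kernel):
    """Return flow(x0, t) = expm(kernel * t) @ x0, the exact solution of dx/dt = kernel @ x."""
    def flow(x0, t):
        return expm(kernel * t) @ x0
    return flow


tau_mem, tau_syn = 10e-3, 5e-3
# Assumed LIF kernel for the state x = [V, I]:
#   tau_mem * dV/dt = -V + I,   tau_syn * dI/dt = -I
A = jnp.array([[-1.0 / tau_mem, 1.0 / tau_mem],
               [0.0, -1.0 / tau_syn]])

forward = exponential_flow_sketch(A)    # neuron state, forward in time
adjoint = exponential_flow_sketch(A.T)  # adjoint state, integrated in reversed time

# Zero voltage, unit current: for these constants V(t) = exp(-t/tau_mem) - exp(-t/tau_syn).
x = forward(jnp.array([0.0, 1.0]), 5e-3)
```

Because the flow is exact, an event-based simulator can jump directly from one spike time to the next without a fixed integration step.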
jaxsnn.event.adjoint_lif.filter_spikes(input_spikes: jaxsnn.event.types.EventPropSpike, prev_layer_start: int) → jaxsnn.event.types.EventPropSpike

Filters the input spikes so that only spikes from the previous layer are kept.
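One way to implement such a filter under jit, where arrays must keep a static shape, is to mask entries rather than drop them. The sketch assumes that neurons of all layers share one global index space in which the previous layer's neurons start at prev_layer_start, and that filtered-out spikes are marked with time inf, index -1 and zero current. The stand-in EventPropSpike and filter_spikes_sketch are hypothetical, not the jaxsnn implementation.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class EventPropSpike(NamedTuple):  # hypothetical stand-in holding arrays
    time: jnp.ndarray
    idx: jnp.ndarray
    current: jnp.ndarray


def filter_spikes_sketch(spikes: EventPropSpike, prev_layer_start) -> EventPropSpike:
    # Assumption: the previous layer owns global indices >= prev_layer_start.
    keep = spikes.idx >= prev_layer_start
    return EventPropSpike(
        time=jnp.where(keep, spikes.time, jnp.inf),                # dropped spikes never arrive
        idx=jnp.where(keep, spikes.idx - prev_layer_start, -1),   # local index, -1 marks dropped
        current=jnp.where(keep, spikes.current, 0.0),
    )


spikes = EventPropSpike(jnp.array([0.1, 0.2, 0.3]), jnp.array([0, 3, 5]), jnp.array([1.0, 2.0, 3.0]))
filtered = jax.jit(filter_spikes_sketch)(spikes, 3)
```

Masking keeps the output the same length as the input, so the function is traced once and reused for any input of that size.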

jaxsnn.event.adjoint_lif.step_bwd(adjoint_dynamics: Callable, adjoint_tr_dynamics: Callable, t_max: float, res, g)
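step_bwd has no docstring, but its trailing (res, g) parameters match the backward-rule convention of jax.custom_vjp, which calls bwd(res, g) with the residuals saved by the forward rule and the cotangent of the output. A plausible reading, inferred from the signatures only, is that the leading arguments are bound with partial and the result is registered together with a forward step, for example in construct_adjoint_apply_fn. The toy below shows that pattern on a hypothetical decay function; make_decay and decay_bwd are illustrative names.

```python
from functools import partial

import jax
import jax.numpy as jnp


def decay_bwd(tau, res, g):
    """Backward rule: tau is bound with partial, JAX supplies (res, g)."""
    v0, factor = res
    # out = v0 * exp(-t / tau), so d out/d v0 = factor and d out/d t = -v0 * factor / tau
    return (g * factor, -g * v0 * factor / tau)


def make_decay(tau):
    @jax.custom_vjp
    def decay(v0, t):
        return v0 * jnp.exp(-t / tau)

    def decay_fwd(v0, t):
        factor = jnp.exp(-t / tau)
        # Return the primal output plus the residuals the backward rule needs.
        return v0 * factor, (v0, factor)

    decay.defvjp(decay_fwd, partial(decay_bwd, tau))
    return decay


decay = make_decay(10e-3)
grads = jax.grad(decay, argnums=(0, 1))(jnp.float32(2.0), jnp.float32(5e-3))
```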
jaxsnn.event.adjoint_lif.trajectory(step_fn: Callable[[Tuple[jaxsnn.event.types.StepState, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], int], int], Tuple[Tuple[jaxsnn.event.types.StepState, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], int], jaxsnn.event.types.EventPropSpike]], size: int, n_spikes: int) → Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike], jaxsnn.event.types.EventPropSpike]

Evaluate the step_fn until n_spikes have been simulated.

Uses a scan over step_fn to build the returned apply function.
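A minimal sketch of the scan pattern described above: jax.lax.scan threads a carry through step_fn once per simulated spike and stacks the per-step outputs, so the result always has length n_spikes, a static shape that jit requires. The carry layout, trajectory_sketch and toy_step are hypothetical. According to the return type above, jaxsnn's apply function instead takes an int, the weights and the input spikes.

```python
import jax
import jax.numpy as jnp


def trajectory_sketch(step_fn, n_spikes):
    """Hypothetical: build an apply function that runs step_fn n_spikes times."""
    def apply_fn(init_carry):
        # xs is the step index, matching a step_fn(carry, int) -> (carry, spike) signature.
        _, spikes = jax.lax.scan(step_fn, init_carry, jnp.arange(n_spikes))
        return spikes  # per-step outputs stacked along a new leading axis
    return apply_fn


def toy_step(carry, i):
    """Toy neuron that fires every period seconds; emits (spike_time, step_index)."""
    t, period = carry
    t_next = t + period
    return (t_next, period), (t_next, i)


apply_fn = jax.jit(trajectory_sketch(toy_step, 4))
times, steps = apply_fn((jnp.float32(0.0), jnp.float32(0.5)))
```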