jaxsnn.event.modules.leaky_integrate_and_fire

Implements different LIF layers, which can be concatenated.

Each layer returns a pair of two functions: the init function and the apply function. These functions can be concatenated with jaxsnn.event.compose.serial, which also returns an init/apply pair consisting of multiple layers. The init function is used to initialize the weights of the network. The apply function performs the inference and is equivalent to the forward function in PyTorch. It receives the input spikes and the weights of the network and returns the hidden spikes.
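The init/apply pattern can be sketched in a simplified form. The following is plain Python; dense_layer and this serial are illustrative stand-ins, not the actual jaxsnn.event.compose.serial implementation:

```python
import random
from typing import List

# Illustrative stand-ins for jaxsnn's init/apply pattern: each layer is an
# (init_fn, apply_fn) pair, and `serial` chains them into one such pair.

def dense_layer(n_hidden: int, mean: float = 0.5, std: float = 2.0):
    def init_fn(rng: random.Random, input_size: int):
        # Initialize an (input_size x n_hidden) weight matrix.
        weights = [[rng.gauss(mean, std) for _ in range(n_hidden)]
                   for _ in range(input_size)]
        return n_hidden, weights  # output size feeds the next layer's init

    def apply_fn(weights, inputs: List[float]) -> List[float]:
        # Toy "forward": a weighted sum standing in for spike propagation.
        return [sum(x * w for x, w in zip(inputs, col))
                for col in zip(*weights)]

    return init_fn, apply_fn

def serial(*layers):
    """Chain several (init_fn, apply_fn) pairs into one pair."""
    init_fns, apply_fns = zip(*layers)

    def init_fn(rng, input_size: int):
        params = []
        for fn in init_fns:
            input_size, weights = fn(rng, input_size)
            params.append(weights)
        return input_size, params

    def apply_fn(params, inputs):
        for fn, weights in zip(apply_fns, params):
            inputs = fn(weights, inputs)
        return inputs

    return init_fn, apply_fn
```

Because each init function returns its output size together with the weights, serial can thread layer sizes through the network; this matches the signatures below, where init functions return Tuple[int, Weights].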

The layers in this module differ in the topology they offer (feed-forward or recurrent) and in the way the gradients are computed (analytically via jax.grad or with an adjoint system, EventProp).

HardwareLIF and HardwareRecurrentLIF allow the execution of the forward pass on the neuromorphic BSS-2 system. The forward pass is executed on the neuromorphic system and the spikes are retrieved. Because the spike data from BSS-2 is missing information about the synaptic current at spike time (which is needed for the EventProp algorithm), a second forward pass is executed in software, using the spike times from the hardware as the solution of the root solving. The adjoint system of the EventProp algorithm is added as a custom vector-Jacobian product (VJP).

Classes

LIFParameters(tau_syn, tau_mem, v_th, …)

partial

partial(func, *args, **keywords) - new function with partial application of the given arguments and keywords.

Functions

jaxsnn.event.modules.leaky_integrate_and_fire.EventPropLIF(n_hidden: int, n_spikes: int, t_max: float, params: jaxsnn.base.params.LIFParameters, mean=0.5, std=2.0, wrap_only_step: bool = False, duplication: Optional[int] = None) → Tuple[Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]], Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike], jaxsnn.event.types.EventPropSpike]]

Feed-forward layer of LIF neurons with EventProp gradient computation.

Args:

    n_hidden (int): Number of hidden neurons
    n_spikes (int): Number of spikes which are simulated in this layer
    t_max (float): Maximum simulation time
    params (LIFParameters): Parameters of the LIF neurons
    mean (float, optional): Mean of initial weights. Defaults to 0.5.
    std (float, optional): Standard deviation of initial weights. Defaults to 2.0.
    wrap_only_step (bool, optional): Whether the custom VJP is defined only for the step function or for the entire trajectory. Defaults to False.
    duplication (Optional[int], optional): Factor with which input weights are duplicated. Defaults to None.

Returns:

SingleInitApply: Pair of init/apply functions.

jaxsnn.event.modules.leaky_integrate_and_fire.HardwareLIF(n_hidden: int, n_spikes: int, t_max: float, params: jaxsnn.base.params.LIFParameters, mean: float, std: float, duplication: Optional[int] = None) → Tuple[Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]], Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike, jaxsnn.event.types.Spike], jaxsnn.event.types.EventPropSpike]]
jaxsnn.event.modules.leaky_integrate_and_fire.HardwareRecurrentLIF(layers: List[int], n_spikes: int, t_max: float, params: jaxsnn.base.params.LIFParameters, mean: List[float], std: List[float], duplication: Optional[int] = None)
jaxsnn.event.modules.leaky_integrate_and_fire.LIF(n_hidden: int, n_spikes: int, t_max: float, params: jaxsnn.base.params.LIFParameters, mean: float = 0.5, std: float = 2.0, duplication: Optional[int] = None) → Tuple[Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]], Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike], jaxsnn.event.types.EventPropSpike]]

A feed-forward layer of LIF neurons.

Args:

    n_hidden (int): Number of hidden neurons
    n_spikes (int): Number of spikes which are simulated in this layer
    t_max (float): Maximum simulation time
    params (LIFParameters): Parameters of the LIF neurons
    mean (float, optional): Mean of initial weights. Defaults to 0.5.
    std (float, optional): Standard deviation of initial weights. Defaults to 2.0.

Returns:

SingleInitApply: Pair of init/apply functions.

jaxsnn.event.modules.leaky_integrate_and_fire.RecurrentEventPropLIF(layers: List[int], n_spikes: int, t_max: float, params: jaxsnn.base.params.LIFParameters, mean: List[float], std: List[float], wrap_only_step: bool = False, duplication: Optional[int] = None) → Tuple[Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]], Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike], jaxsnn.event.types.EventPropSpike]]

Use quadrants of the recurrent weight matrix to set up a multi-layer feed-forward LIF network in one recurrent layer.

When simulating multiple layers separately, the first layer needs to be fully simulated before the resulting spikes are passed to the next layer. By viewing multiple feed-forward layers as one recurrent layer in which only rectangular blocks of the weight matrix are initialized with non-zero entries, multiple feed-forward layers can be simulated together.

Args:

    layers (List[int]): Number of neurons in each feed-forward layer
    n_spikes (int): Number of spikes which are simulated in this layer
    t_max (float): Maximum simulation time
    params (LIFParameters): Parameters of the LIF neurons
    mean (List[float]): Mean of initial weights per layer.
    std (List[float]): Standard deviation of initial weights per layer.
    wrap_only_step (bool, optional): Whether the custom VJP is defined only for the step function or for the entire trajectory. Defaults to False.
    duplication (Optional[int], optional): Factor with which input weights are duplicated. Defaults to None.

Returns:

SingleInitApply: Pair of init/apply functions.
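The block structure described above can be illustrated with a small sketch (plain Python, independent of jaxsnn; the actual initialization differs in detail): for layer sizes [2, 3], the 5x5 recurrent weight matrix has non-zero entries only in the rectangular block connecting layer 1 to layer 2.

```python
import random
from typing import List

def recurrent_from_feedforward(layers: List[int], mean: float = 0.5,
                               std: float = 2.0, seed: int = 0):
    """Build an (n x n) recurrent weight matrix, n = sum(layers), in which
    only the rectangular blocks connecting consecutive layers are non-zero.

    Simplified illustration of the idea behind RecurrentEventPropLIF,
    not jaxsnn's actual initialization code.
    """
    rng = random.Random(seed)
    n = sum(layers)
    weights = [[0.0] * n for _ in range(n)]
    start = 0
    for size, next_size in zip(layers, layers[1:]):
        # Fill only the block: rows of this layer -> columns of the next.
        for i in range(start, start + size):
            for j in range(start + size, start + size + next_size):
                weights[i][j] = rng.gauss(mean, std)
        start += size
    return weights
```

Because the diagonal blocks stay zero, no neuron feeds back into its own layer, so one pass of the recurrent simulation reproduces the feed-forward computation.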

jaxsnn.event.modules.leaky_integrate_and_fire.RecurrentLIF(layers: List[int], n_spikes: int, t_max: float, params: jaxsnn.base.params.LIFParameters, mean: List[float], std: List[float], duplication: Optional[int] = None) → Tuple[Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]], Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike], jaxsnn.event.types.EventPropSpike]]
jaxsnn.event.modules.leaky_integrate_and_fire.adjoint_lif_exponential_flow(params: jaxsnn.base.params.LIFParameters)
jaxsnn.event.modules.leaky_integrate_and_fire.adjoint_transition_with_recurrence(params: jaxsnn.base.params.LIFParameters, adjoint_state: jaxsnn.event.types.StepState, spike: jaxsnn.event.types.EventPropSpike, layer_start: int, adjoint_spike: jaxsnn.event.types.EventPropSpike, grads: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], weights: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], input_queue_head: int)
jaxsnn.event.modules.leaky_integrate_and_fire.adjoint_transition_without_recurrence(params: jaxsnn.base.params.LIFParameters, adjoint_state: jaxsnn.event.types.StepState, spike: jaxsnn.event.types.EventPropSpike, layer_start: int, adjoint_spike: jaxsnn.event.types.EventPropSpike, grads: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], weights: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], input_queue_head: int)
jaxsnn.event.modules.leaky_integrate_and_fire.construct_adjoint_apply_fn(step_fn, step_fn_bwd, n_hidden, n_spikes, wrap_only_step=False)
jaxsnn.event.modules.leaky_integrate_and_fire.construct_init_fn(n_hidden: int, mean: float, std: float, duplication: Optional[int] = None) → Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]]
jaxsnn.event.modules.leaky_integrate_and_fire.construct_recurrent_init_fn(layers: List[int], mean: List[float], std: List[float], duplication: Optional[float] = None) → Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]]
jaxsnn.event.modules.leaky_integrate_and_fire.lif_exponential_flow(params: jaxsnn.base.params.LIFParameters)
jaxsnn.event.modules.leaky_integrate_and_fire.next_event(solver: Callable, neuron_state: jaxsnn.event.types.LIFState, time: float, t_max: float) → jax.Array

Wraps a root solver to provide a cleaner API for returning the next event.

Args:

    solver (Callable): The actual root solver
    neuron_state (LIFState): The state of the neurons
    time (float): Current time
    t_max (float): Maximum time of the simulation

Returns:

Spike: Spike which will occur next

jaxsnn.event.modules.leaky_integrate_and_fire.step(dynamics: Callable, tr_dynamics: Callable, t_max: float, solver: Callable[[jaxsnn.event.types.LIFState, float, float], jaxsnn.event.types.Spike], step_input: Tuple[jaxsnn.event.types.StepState, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], int], *args: int) → Tuple[Tuple[jaxsnn.event.types.StepState, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], int], jaxsnn.event.types.EventPropSpike]

Find next spike (external or internal), and simulate to that point.

Args:

    dynamics (Callable): Function describing the continuous neuron dynamics
    tr_dynamics (Callable): Function describing the transition dynamics
    t_max (float): Maximum time until which to run
    solver (Callable): Parallel root solver which returns the next event
    step_input (StepInput): Tuple of (StepState, weights, int)

Returns:

Tuple[StepInput, Spike]: New state after transition and stored spike
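The event-driven stepping can be sketched with a toy single-neuron example. This is plain Python with deliberately simplified dynamics; the names and structure are illustrative, not the jaxsnn implementation: find the earlier of the next internal threshold crossing and the next queued input spike, integrate the state to that time, then apply the transition.

```python
from typing import List, Optional, Tuple

def toy_step(voltage: float, current: float, t: float,
             input_times: List[float], queue_head: int,
             t_max: float, tau: float = 1.0, v_th: float = 1.0,
             weight: float = 0.8
             ) -> Tuple[float, float, float, int, Optional[float]]:
    """One event-driven step for a single toy neuron.

    Returns (voltage, current, time, queue_head, spike_time), where
    spike_time is None if no internal spike occurred in this step.
    """
    # Next external event, or t_max if the input queue is exhausted.
    t_ext = input_times[queue_head] if queue_head < len(input_times) else t_max
    # "Root solver" stand-in: with these toy dynamics (constant input
    # current, dV/dt = current / tau), the threshold crossing is linear.
    dvdt = current / tau
    t_int = t + (v_th - voltage) / dvdt if dvdt > 0 else t_max
    t_next = min(t_int, t_ext, t_max)
    # Integrate the membrane up to the event time.
    voltage = voltage + dvdt * (t_next - t)
    spike_time = None
    if t_next == t_int and t_next < min(t_ext, t_max):
        voltage = 0.0          # internal spike: reset the membrane
        spike_time = t_next
    elif t_next == t_ext and t_next < t_max:
        current += weight      # external spike: apply the input weight
        queue_head += 1
    return voltage, current, t_next, queue_head, spike_time
```

Repeatedly calling such a step function while carrying (state, weights, queue head) is what the trajectory scan below does.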

jaxsnn.event.modules.leaky_integrate_and_fire.step_bwd(adjoint_dynamics: Callable, adjoint_tr_dynamics: Callable, t_max: float, res, g)
jaxsnn.event.modules.leaky_integrate_and_fire.trajectory(step_fn: Callable[[Tuple[jaxsnn.event.types.StepState, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], int], int], Tuple[Tuple[jaxsnn.event.types.StepState, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], int], jaxsnn.event.types.EventPropSpike]], n_hidden: int, n_spikes: int) → Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike], jaxsnn.event.types.EventPropSpike]

Evaluate the step_fn until n_spikes have been simulated.

Uses a scan over the step_fn to return an apply function.
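The scan pattern can be sketched generically (plain Python stand-in for jax.lax.scan; names are illustrative): the step function is folded n_spikes times over a carried state, and the per-step outputs are collected.

```python
from typing import Callable, List

def scan(step_fn: Callable, init_carry, xs: List):
    """Minimal stand-in for jax.lax.scan: fold step_fn over xs while
    threading a carry, collecting the per-step outputs."""
    carry, outputs = init_carry, []
    for x in xs:
        carry, out = step_fn(carry, x)
        outputs.append(out)
    return carry, outputs

def trajectory(step_fn: Callable, n_spikes: int):
    """Return an apply function that evaluates step_fn until n_spikes
    events have been produced (illustrative sketch)."""
    def apply_fn(initial_state):
        _, spikes = scan(step_fn, initial_state, list(range(n_spikes)))
        return spikes
    return apply_fn
```

With jax.lax.scan the loop is traced once and compiled, which is why the number of simulated spikes must be fixed up front via n_spikes.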

jaxsnn.event.modules.leaky_integrate_and_fire.transition_with_recurrence(params: jaxsnn.base.params.LIFParameters, state: jaxsnn.event.types.StepState, weights: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], spike_mask: jax.Array, recurrent_spike: bool, prev_layer_start: int) → jaxsnn.event.types.StepState
jaxsnn.event.modules.leaky_integrate_and_fire.transition_without_recurrence(params: jaxsnn.base.params.LIFParameters, state: jaxsnn.event.types.StepState, weights: jaxsnn.event.types.WeightInput, spike_mask: jax.Array, recurrent_spike: bool, prev_layer_start: int) → jaxsnn.event.types.StepState
jaxsnn.event.modules.leaky_integrate_and_fire.ttfs_solver(tau_mem: float, v_th: float, state: jaxsnn.event.types.LIFState, t_max: float)

Find the next spike time for the special case $\tau_{mem} = 2 \cdot \tau_{syn}$.

Args:

    tau_mem (float): Membrane time constant
    v_th (float): Threshold voltage
    state (LIFState): State of the neuron (voltage, current)
    t_max (float): Maximum time which is to be searched

Returns:

float: Time of next threshold crossing, or t_max if no crossing occurs
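For this special case the threshold crossing has a closed form: writing $a = e^{-t/\tau_{mem}}$, the membrane voltage of a current-based LIF neuron with $\tau_{mem} = 2 \cdot \tau_{syn}$ becomes a quadratic in $a$, so the crossing time follows from the quadratic formula. A sketch in plain Python, assuming unit-free dynamics where the synaptic current enters the membrane equation directly; the actual ttfs_solver operates on a LIFState and jax arrays:

```python
import math

def ttfs_time(v_0: float, i_0: float, tau_mem: float,
              v_th: float, t_max: float) -> float:
    """Closed-form time to first threshold crossing for tau_mem = 2*tau_syn.

    With a = exp(-t / tau_mem), the membrane voltage is
        V(t) = (v_0 + i_0) * a - i_0 * a**2,
    so V(t) = v_th reduces to a quadratic in a. Returns t_max if the
    threshold is never reached. Illustrative sketch, not jaxsnn's code.
    """
    if i_0 <= 0.0:
        return t_max  # no excitatory drive, no future spike
    disc = (v_0 + i_0) ** 2 - 4.0 * i_0 * v_th
    if disc < 0.0:
        return t_max  # peak voltage stays below threshold
    # The larger root in a corresponds to the earlier crossing time.
    a = ((v_0 + i_0) + math.sqrt(disc)) / (2.0 * i_0)
    if a <= 0.0 or a > 1.0:
        return t_max  # crossing would lie at t <= 0: no future crossing
    return min(-tau_mem * math.log(a), t_max)
```

Because only a logarithm and a square root are involved, this root can be evaluated in parallel for all neurons, which is what makes the solver usable inside the vectorized step function.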