jaxsnn.event.modules.leaky_integrate_and_fire
Implementations of different LIF layers, which can be concatenated.
Each layer returns a pair of two functions, the init function and the apply function. These functions can be concatenated with jaxsnn.event.compose.serial, which also returns an init/apply pair consisting of multiple layers. The init function is used to initialize the weights of the network. The apply function performs inference and is equivalent to the forward function in PyTorch: it receives the input spikes and the weights of the network and returns the hidden spikes.
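A minimal sketch of the init/apply convention, assuming NumPy and two hypothetical helpers, `toy_layer` and `serial`, that are not part of jaxsnn. `toy_layer` stands in for a LIF layer with a plain matrix product, and `serial` chains the pairs in the way described for jaxsnn.event.compose.serial. As in the documented signatures, init takes a random generator and the input size and returns the output size and the weights; apply is simplified here to take only weights and inputs.

```python
import numpy as np


def toy_layer(n_out, mean=0.5, std=2.0):
    """Hypothetical stand-in for a layer: returns an (init_fn, apply_fn) pair."""

    def init_fn(rng, n_in):
        # Output size (= input size of the next layer) and freshly drawn weights.
        weights = rng.normal(mean, std, size=(n_in, n_out))
        return n_out, weights

    def apply_fn(weights, inputs):
        # Placeholder forward pass; a LIF layer would map input spikes to hidden spikes.
        return inputs @ weights

    return init_fn, apply_fn


def serial(*layers):
    """Hypothetical composition: chain init and apply functions of several layers."""
    init_fns, apply_fns = zip(*layers)

    def init_fn(rng, n_in):
        all_weights = []
        for init in init_fns:
            n_in, weights = init(rng, n_in)  # thread the size through the layers
            all_weights.append(weights)
        return n_in, all_weights

    def apply_fn(all_weights, inputs):
        for apply, weights in zip(apply_fns, all_weights):
            inputs = apply(weights, inputs)
        return inputs

    return init_fn, apply_fn


init_fn, apply_fn = serial(toy_layer(4), toy_layer(2))
rng = np.random.default_rng(0)
n_out, weights = init_fn(rng, 3)
out = apply_fn(weights, np.ones(3))
```

The composed pair has the same shape as a single layer's pair, which is what allows composed networks to be nested or composed again.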
The layers in this module differ in the topology they offer (feed-forward or recurrent) and in the way the gradients are computed (analytically via jax.grad, or with an adjoint system as in EventProp).
HardwareLIF and HardwareRecurrentLIF execute the forward pass on the neuromorphic BSS-2 system and retrieve the resulting spikes. Because the spike data from BSS-2 lacks the synaptic current at spike time (which the EventProp algorithm needs), a second forward pass is executed in software, in which the spike times from the hardware are used as the solution of the root solving. The adjoint system of the EventProp algorithm is added as a custom vector-Jacobian product (VJP).
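The idea of reusing measured spike times while keeping a software gradient can be illustrated with `jax.custom_vjp`. In this sketch the forward pass simply returns a spike time passed in from outside (standing in for a hardware measurement), and the backward pass differentiates a software model at that measured root via the implicit function theorem, dt/dw = -(∂v/∂w)/(∂v/∂t). The single-neuron model v(t) = w(1 - exp(-t/τ)), the constants `TAU` and `V_TH`, and the name `spike_time` are assumptions for illustration; jaxsnn attaches the EventProp adjoint system rather than this one-neuron gradient.

```python
import jax
import jax.numpy as jnp

TAU = 1.0   # assumed time constant of the toy neuron
V_TH = 1.0  # assumed threshold


def membrane(t, w):
    # Toy software model (assumption): v(t) = w * (1 - exp(-t / TAU)).
    return w * (1.0 - jnp.exp(-t / TAU))


@jax.custom_vjp
def spike_time(w, t_measured):
    # Forward pass: trust the externally measured spike time.
    return t_measured


def spike_time_fwd(w, t_measured):
    return t_measured, (w, t_measured)


def spike_time_bwd(res, g):
    w, t = res
    # Implicit function theorem at the measured root v(t, w) = V_TH.
    dv_dw = jax.grad(membrane, argnums=1)(t, w)
    dv_dt = jax.grad(membrane, argnums=0)(t, w)
    # One cotangent per primal input; the measurement itself gets no gradient.
    return (-g * dv_dw / dv_dt, jnp.zeros_like(t))


spike_time.defvjp(spike_time_fwd, spike_time_bwd)

# With w = 2 the true crossing is t = ln 2 and dt/dw = -0.5.
grad_w = jax.grad(spike_time)(2.0, jnp.log(2.0))
```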
Classes
- partial(func, *args, **keywords): new function with partial application of the given arguments and keywords.
Functions
- jaxsnn.event.modules.leaky_integrate_and_fire.EventPropLIF(n_hidden: int, n_spikes: int, t_max: float, params: jaxsnn.base.params.LIFParameters, mean=0.5, std=2.0, wrap_only_step: bool = False, duplication: Optional[int] = None) → Tuple[Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]], Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike], jaxsnn.event.types.EventPropSpike]]
  Feed-forward layer of LIF neurons with EventProp gradient computation.
  Args:
    n_hidden (int): Number of hidden neurons.
    n_spikes (int): Number of spikes which are simulated in this layer.
    t_max (float): Maximum simulation time.
    params (LIFParameters): Parameters of the LIF neurons.
    mean (float, optional): Mean of initial weights. Defaults to 0.5.
    std (float, optional): Standard deviation of initial weights. Defaults to 2.0.
    wrap_only_step (bool, optional): Whether the custom VJP is defined only for the step function (True) or for the entire trajectory (False). Defaults to False.
    duplication (Optional[int], optional): Factor by which the input weights are duplicated. Defaults to None.
  Returns:
    SingleInitApply: Pair of init and apply functions.
- jaxsnn.event.modules.leaky_integrate_and_fire.HardwareLIF(n_hidden: int, n_spikes: int, t_max: float, params: jaxsnn.base.params.LIFParameters, mean: float, std: float, duplication: Optional[int] = None) → Tuple[Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]], Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike, jaxsnn.event.types.Spike], jaxsnn.event.types.EventPropSpike]]
- jaxsnn.event.modules.leaky_integrate_and_fire.HardwareRecurrentLIF(layers: List[int], n_spikes: int, t_max: float, params: jaxsnn.base.params.LIFParameters, mean: List[float], std: List[float], duplication: Optional[int] = None)
- jaxsnn.event.modules.leaky_integrate_and_fire.LIF(n_hidden: int, n_spikes: int, t_max: float, params: jaxsnn.base.params.LIFParameters, mean: float = 0.5, std: float = 2.0, duplication: Optional[int] = None) → Tuple[Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]], Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike], jaxsnn.event.types.EventPropSpike]]
  A feed-forward layer of LIF neurons.
  Args:
    n_hidden (int): Number of hidden neurons.
    n_spikes (int): Number of spikes which are simulated in this layer.
    t_max (float): Maximum simulation time.
    params (LIFParameters): Parameters of the LIF neurons.
    mean (float, optional): Mean of initial weights. Defaults to 0.5.
    std (float, optional): Standard deviation of initial weights. Defaults to 2.0.
  Returns:
    SingleInitApply: Pair of init and apply functions.
- jaxsnn.event.modules.leaky_integrate_and_fire.RecurrentEventPropLIF(layers: List[int], n_spikes: int, t_max: float, params: jaxsnn.base.params.LIFParameters, mean: List[float], std: List[float], wrap_only_step: bool = False, duplication: Optional[int] = None) → Tuple[Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]], Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike], jaxsnn.event.types.EventPropSpike]]
  Use quadrants of the recurrent weight matrix to set up a multi-layer feed-forward LIF network in one recurrent layer.
  When multiple layers are simulated separately, the first layer has to be fully simulated before its spikes are passed to the next layer. Viewing multiple feed-forward layers as one recurrent layer, in which only the rectangular blocks of the weight matrix that connect consecutive layers are initialized with non-zero entries, allows all of them to be simulated together.
  Args:
    layers (List[int]): Number of neurons in each feed-forward layer.
    n_spikes (int): Number of spikes which are simulated in this layer.
    t_max (float): Maximum simulation time.
    params (LIFParameters): Parameters of the LIF neurons.
    mean (List[float]): Means of the initial weights.
    std (List[float]): Standard deviations of the initial weights.
    wrap_only_step (bool, optional): Whether the custom VJP is defined only for the step function (True) or for the entire trajectory (False). Defaults to False.
    duplication (Optional[int], optional): Factor by which the input weights are duplicated. Defaults to None.
  Returns:
    SingleInitApply: Pair of init and apply functions.
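The weight layout described for RecurrentEventPropLIF can be sketched in NumPy. This sketch assumes a single mean and std for all blocks (the documented signature takes lists), leaves out the input weights, and uses the convention that `weights[i, j]` connects presynaptic neuron i to postsynaptic neuron j; the function name `feed_forward_as_recurrent` is hypothetical.

```python
import numpy as np


def feed_forward_as_recurrent(layers, rng, mean=0.5, std=2.0):
    """Hypothetical sketch: N x N recurrent matrix (N = sum(layers)) whose only
    non-zero entries are the blocks connecting layer k to layer k + 1."""
    n = sum(layers)
    weights = np.zeros((n, n))
    starts = np.cumsum([0] + list(layers))  # first neuron index of each layer
    for k in range(len(layers) - 1):
        src = slice(starts[k], starts[k + 1])      # presynaptic rows: layer k
        dst = slice(starts[k + 1], starts[k + 2])  # postsynaptic columns: layer k + 1
        weights[src, dst] = rng.normal(mean, std, size=(layers[k], layers[k + 1]))
    return weights


w = feed_forward_as_recurrent([3, 2, 4], np.random.default_rng(1))
```

All other entries stay zero, so a spike of a layer-k neuron only reaches layer k + 1, and the recurrent simulation reproduces the feed-forward connectivity.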
- jaxsnn.event.modules.leaky_integrate_and_fire.RecurrentLIF(layers: List[int], n_spikes: int, t_max: float, params: jaxsnn.base.params.LIFParameters, mean: List[float], std: List[float], duplication: Optional[int] = None) → Tuple[Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]], Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike], jaxsnn.event.types.EventPropSpike]]
- jaxsnn.event.modules.leaky_integrate_and_fire.adjoint_lif_exponential_flow(params: jaxsnn.base.params.LIFParameters)
- jaxsnn.event.modules.leaky_integrate_and_fire.adjoint_transition_with_recurrence(params: jaxsnn.base.params.LIFParameters, adjoint_state: jaxsnn.event.types.StepState, spike: jaxsnn.event.types.EventPropSpike, layer_start: int, adjoint_spike: jaxsnn.event.types.EventPropSpike, grads: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], weights: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], input_queue_head: int)
- jaxsnn.event.modules.leaky_integrate_and_fire.adjoint_transition_without_recurrence(params: jaxsnn.base.params.LIFParameters, adjoint_state: jaxsnn.event.types.StepState, spike: jaxsnn.event.types.EventPropSpike, layer_start: int, adjoint_spike: jaxsnn.event.types.EventPropSpike, grads: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], weights: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], input_queue_head: int)
- jaxsnn.event.modules.leaky_integrate_and_fire.construct_adjoint_apply_fn(step_fn, step_fn_bwd, n_hidden, n_spikes, wrap_only_step=False)
- jaxsnn.event.modules.leaky_integrate_and_fire.construct_init_fn(n_hidden: int, mean: float, std: float, duplication: Optional[int] = None) → Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]]
- jaxsnn.event.modules.leaky_integrate_and_fire.construct_recurrent_init_fn(layers: List[int], mean: List[float], std: List[float], duplication: Optional[float] = None) → Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]]
- jaxsnn.event.modules.leaky_integrate_and_fire.lif_exponential_flow(params: jaxsnn.base.params.LIFParameters)
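The name lif_exponential_flow suggests the exact, closed-form flow of the LIF dynamics between events. A sketch of such a flow, assuming current-based LIF neurons with exponential synapses, τ_mem dv/dt = -v + I and dI/dt = -I/τ_syn, a leak potential of 0 and τ_mem ≠ τ_syn; the parameter conventions of jaxsnn.base.params.LIFParameters may differ, and `lif_flow` is a hypothetical name.

```python
import numpy as np


def lif_flow(v, i, dt, tau_mem, tau_syn):
    """Exact solution of tau_mem * dv/dt = -v + i, di/dt = -i / tau_syn over dt
    (assumed model; requires tau_mem != tau_syn)."""
    decay_mem = np.exp(-dt / tau_mem)
    decay_syn = np.exp(-dt / tau_syn)
    # Homogeneous decay of v plus the response to the decaying synaptic current.
    v_new = v * decay_mem + i * tau_syn / (tau_mem - tau_syn) * (decay_mem - decay_syn)
    i_new = i * decay_syn
    return v_new, i_new
```

Because the flow is exact, advancing by a then b gives the same state as advancing by a + b, which is what makes event-driven simulation possible without a fixed time grid.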
- jaxsnn.event.modules.leaky_integrate_and_fire.next_event(solver: Callable, neuron_state: jaxsnn.event.types.LIFState, time: float, t_max: float) → jax.Array
  Wrap a root solver to provide a cleaner API for returning the next event.
  Args:
    solver (Callable): The actual root solver.
    neuron_state (LIFState): State of the neurons.
    time (float): Current time.
    t_max (float): Maximum time of the simulation.
  Returns:
    Spike: Spike which will occur next.
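A sketch of what such a wrapper might look like, assuming the solver returns, for every neuron, the time until its next threshold crossing measured from `time` (or infinity if it does not cross). `ToySpike` and `next_event_sketch` are hypothetical stand-ins, not jaxsnn.event.types.Spike or the library function.

```python
from typing import NamedTuple

import numpy as np


class ToySpike(NamedTuple):
    # Hypothetical stand-in for a spike: event time and index of the firing neuron.
    time: float
    idx: int


def next_event_sketch(solver, neuron_state, time, t_max):
    # Assumption: solver gives per-neuron time-until-crossing (np.inf for none).
    dt = solver(neuron_state)
    idx = int(np.argmin(dt))  # earliest crossing among all neurons
    t_spike = time + dt[idx]
    if t_spike >= t_max:
        # No crossing inside the simulation window.
        return ToySpike(time=t_max, idx=-1)
    return ToySpike(time=float(t_spike), idx=idx)
```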
- jaxsnn.event.modules.leaky_integrate_and_fire.step(dynamics: Callable, tr_dynamics: Callable, t_max: float, solver: Callable[[jaxsnn.event.types.LIFState, float, float], jaxsnn.event.types.Spike], step_input: Tuple[jaxsnn.event.types.StepState, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], int], *args: int) → Tuple[Tuple[jaxsnn.event.types.StepState, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], int], jaxsnn.event.types.EventPropSpike]
  Find the next spike (external or internal) and simulate up to that point.
  Args:
    dynamics (Callable): Function describing the continuous neuron dynamics.
    tr_dynamics (Callable): Function describing the transition dynamics.
    t_max (float): Maximum time until which to run.
    solver (Solver): Parallel root solver which returns the next event.
    step_input (StepInput): Tuple of (StepState, weights, int).
  Returns:
    Tuple[StepInput, Spike]: New state after the transition and the stored spike.
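The structure of one event-driven step can be sketched in NumPy: compare the earliest internal threshold crossing with the next queued input spike, advance the continuous dynamics to the earlier one, then apply the matching transition. The state layout `(v, i, t, head)`, the reset to 0, and the name `step_sketch` are assumptions for illustration; the documented function receives `dynamics`, `tr_dynamics`, and `solver` instead of the callables used here.

```python
import numpy as np


def step_sketch(flow, solver, weights, state, input_times, input_idx, t_max):
    """One event-driven step (hypothetical structure, not the jaxsnn code).

    state = (v, i, t, head): voltages, synaptic currents, current time and the
    index of the next unprocessed input spike. Returns the new state and the
    event as (time, index of the hidden neuron that fired, or -1).
    """
    v, i, t, head = state
    # Earliest internal crossing; solver returns time-until-crossing per neuron.
    dt = solver(v, i)
    n_int = int(np.argmin(dt))
    t_int = t + dt[n_int]
    # Next external input spike, if any remain in the queue.
    t_in = input_times[head] if head < len(input_times) else np.inf

    t_next = min(t_int, t_in, t_max)
    v, i = flow(v, i, t_next - t)  # integrate the dynamics up to the event

    if t_next == t_max:  # no event before the end of the simulation
        return (v, i, t_max, head), (t_max, -1)
    if t_in <= t_int:  # external spike: add its weight row to the currents
        i = i + weights[input_idx[head]]
        return (v, i, t_in, head + 1), (t_in, -1)
    v = v.copy()  # internal spike: reset the neuron that fired
    v[n_int] = 0.0
    return (v, i, t_int, head), (t_int, n_int)
```

For a quick check, a non-leaky integrator `flow = lambda v, i, dt: (v + i * dt, i)` with threshold 1 is enough to drive the step by hand.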
- jaxsnn.event.modules.leaky_integrate_and_fire.step_bwd(adjoint_dynamics: Callable, adjoint_tr_dynamics: Callable, t_max: float, res, g)
- jaxsnn.event.modules.leaky_integrate_and_fire.trajectory(step_fn: Callable[[Tuple[jaxsnn.event.types.StepState, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], int], int], Tuple[Tuple[jaxsnn.event.types.StepState, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], int], jaxsnn.event.types.EventPropSpike]], n_hidden: int, n_spikes: int) → Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike], jaxsnn.event.types.EventPropSpike]
  Evaluate step_fn until n_spikes have been simulated. A scan over step_fn is used to construct the returned apply function.
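A minimal sketch of the scan pattern, assuming a step function with the `(carry, _) -> (new_carry, event)` shape that `jax.lax.scan` expects. `jax.lax.scan` needs a static number of iterations, which fits a fixed `n_spikes` per layer. `trajectory_sketch` and `toy_step` are hypothetical, and the returned apply function takes only an initial state rather than the documented `(int, weights, EventPropSpike)` arguments.

```python
import jax
import jax.numpy as jnp


def trajectory_sketch(step_fn, n_spikes):
    """Hypothetical: build an apply function that runs step_fn n_spikes times."""

    def apply_fn(initial_state):
        # scan threads the carry through step_fn and stacks the emitted events.
        _, events = jax.lax.scan(step_fn, initial_state, None, length=n_spikes)
        return events

    return apply_fn


def toy_step(t, _):
    # Toy step: the next event happens one time unit after the previous one.
    t_next = t + 1.0
    return t_next, t_next


events = trajectory_sketch(toy_step, 3)(jnp.array(0.0))
```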
- jaxsnn.event.modules.leaky_integrate_and_fire.transition_with_recurrence(params: jaxsnn.base.params.LIFParameters, state: jaxsnn.event.types.StepState, weights: Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], spike_mask: jax.Array, recurrent_spike: bool, prev_layer_start: int) → jaxsnn.event.types.StepState
- jaxsnn.event.modules.leaky_integrate_and_fire.transition_without_recurrence(params: jaxsnn.base.params.LIFParameters, state: jaxsnn.event.types.StepState, weights: jaxsnn.event.types.WeightInput, spike_mask: jax.Array, recurrent_spike: bool, prev_layer_start: int) → jaxsnn.event.types.StepState
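The two transition variants are undocumented here, so this is only a plausible sketch of what a transition at a spike does, under stated assumptions: an internal spike resets the firing neuron to `v_reset` and, only when recurrent weights are given, adds that neuron's recurrent weight row to the synaptic currents; an input spike adds its input weight row. The name `transition_sketch`, the explicit index instead of `spike_mask`, and Python branching instead of a traceable mask are all simplifications.

```python
import numpy as np


def transition_sketch(v, i, spike_idx, is_internal, w_in, w_rec=None, v_reset=0.0):
    """Hypothetical transition: returns updated (v, i) without modifying inputs."""
    v, i = v.copy(), i.copy()
    if is_internal:
        v[spike_idx] = v_reset  # the hidden neuron that fired is reset
        if w_rec is not None:   # with recurrence, its spike also drives hidden neurons
            i += w_rec[spike_idx]
    else:
        i += w_in[spike_idx]    # an input spike adds its weight row to the currents
    return v, i
```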
- jaxsnn.event.modules.leaky_integrate_and_fire.ttfs_solver(tau_mem: float, v_th: float, state: jaxsnn.event.types.LIFState, t_max: float)
  Find the next spike time for the special case $\tau_{mem} = 2 \tau_{syn}$.
  Args:
    tau_mem (float): Membrane time constant.
    v_th (float): Threshold voltage.
    state (LIFState): State of the neuron (voltage, current).
    t_max (float): Maximum time which is searched.
  Returns:
    float: Time of the next threshold crossing, or t_max if there is no crossing.
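Why the special case has a closed form can be shown under an assumed model: τ_mem dv/dt = -v + I and dI/dt = -I/τ_syn, with leak potential 0 and a neuron starting below threshold (v0 < v_th). With x = exp(-t/τ_mem) and τ_syn = τ_mem/2, the membrane voltage becomes v(x) = -I0 x² + (v0 + I0) x. A threshold crossing is therefore a root of a quadratic, and the earliest crossing is the larger root x. The sketch below (`ttfs_solver_sketch`, with explicit `v` and `i` in place of `LIFState`) follows this derivation and is not the jaxsnn implementation.

```python
import numpy as np


def ttfs_solver_sketch(tau_mem, v_th, v, i, t_max):
    """Hypothetical closed-form next crossing time for tau_mem = 2 * tau_syn,
    assuming tau_mem * dv/dt = -v + I, dI/dt = -I / tau_syn and v < v_th."""
    v = np.asarray(v, dtype=float)
    i = np.asarray(i, dtype=float)
    # v(x) = -i * x**2 + (v + i) * x with x = exp(-t / tau_mem); solve v(x) = v_th.
    disc = (v + i) ** 2 - 4.0 * i * v_th
    safe_i = np.where(i > 0, i, 1.0)  # avoid division by zero; masked out below
    x = ((v + i) + np.sqrt(np.maximum(disc, 0.0))) / (2.0 * safe_i)
    # A crossing needs a positive current, a real root, and 0 < x <= 1 (t >= 0).
    valid = (i > 0) & (disc >= 0) & (x > 0) & (x <= 1)
    t = -tau_mem * np.log(np.where(valid, x, 1.0))
    return np.where(valid & (t < t_max), t, t_max)


t = ttfs_solver_sketch(1.0, 1.0, [0.0, 0.0], [5.0, 2.0], 10.0)
```

In this example the first neuron crosses near t ≈ 0.32, while the second neuron's peak voltage stays at 0.5, so it reports t_max.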