jaxsnn.event.functional
Classes
Functions
- jaxsnn.event.functional.filter_spikes(input_spikes: jaxsnn.event.types.EventPropSpike, prev_layer_start: int) → jaxsnn.event.types.EventPropSpike
Filters the input spikes by ensuring only the spikes from the previous layer are kept.
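The filtering idea can be illustrated with a minimal sketch in plain JAX. It does not reproduce the library's implementation: `SpikeSketch` is a hypothetical stand-in for `EventPropSpike` that carries only a time and a source index, and the sketch assumes that "from the previous layer" means a source index at or above `prev_layer_start`. Masked spikes are marked with an infinite time and index -1 so that array shapes stay static.

```python
from typing import NamedTuple

import jax.numpy as jnp


class SpikeSketch(NamedTuple):
    """Hypothetical stand-in for EventPropSpike (time and source index only)."""
    time: jnp.ndarray
    idx: jnp.ndarray


def filter_spikes_sketch(spikes: SpikeSketch, prev_layer_start: int) -> SpikeSketch:
    # Assumption: spikes of the previous layer have idx >= prev_layer_start.
    keep = spikes.idx >= prev_layer_start
    # Masked-out spikes get time = inf and idx = -1; shapes are unchanged.
    return SpikeSketch(
        time=jnp.where(keep, spikes.time, jnp.inf),
        idx=jnp.where(keep, spikes.idx, -1),
    )


spikes = SpikeSketch(time=jnp.array([0.1, 0.2, 0.3]), idx=jnp.array([0, 3, 4]))
filtered = filter_spikes_sketch(spikes, prev_layer_start=3)
```

Masking instead of dropping entries keeps the function compatible with `jax.jit`, which requires output shapes that do not depend on the data.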
- jaxsnn.event.functional.step(dynamics: Callable, tr_dynamics: Callable, t_max: float, solver: Callable[[jaxsnn.event.types.LIFState, float, float], jaxsnn.event.types.Spike], step_input: Tuple[jaxsnn.event.types.StepState, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], int], *args: int) → Tuple[Tuple[jaxsnn.event.types.StepState, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], int], jaxsnn.event.types.EventPropSpike]
Find the next spike (external or internal) and simulate to that point.
- Args:
    dynamics (Callable): Function describing the continuous neuron dynamics.
    tr_dynamics (Callable): Function describing the transition dynamics.
    t_max (float): Maximum time until which to run.
    solver (Solver): Parallel root solver which returns the next event.
    step_input (StepInput): Tuple of (StepState, weights, int).
- Returns:
    Tuple[StepInput, Spike]: New state after the transition, and the stored spike.
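The "find the next event, then simulate to it" pattern can be sketched as follows. This is a toy illustration, not the library's `step`: the names `next_event_sketch` and `step_sketch`, the scalar voltage `v`, and the exponential decay with time constant `tau` are all assumptions made for the example. In the real function the internal spike time would come from `solver`, and the update would come from `dynamics` and `tr_dynamics`.

```python
import jax.numpy as jnp


def next_event_sketch(t_internal, t_external, t_max):
    """Pick the earliest of internal spike, external spike, and t_max."""
    times = jnp.array([t_internal, t_external, t_max])
    # kind: 0 = internal spike, 1 = external spike, 2 = no spike before t_max
    kind = jnp.argmin(times)
    return times[kind], kind


def step_sketch(t, v, t_internal, t_external, t_max, tau=1.0):
    """Advance a toy leaky voltage from time t to the next event."""
    t_next, kind = next_event_sketch(t_internal, t_external, t_max)
    # Hypothetical continuous dynamics: pure exponential decay of v.
    v_next = v * jnp.exp(-(t_next - t) / tau)
    return t_next, v_next, kind


t_next, v_next, kind = step_sketch(
    t=0.0, v=1.0, t_internal=0.5, t_external=0.3, t_max=1.0
)
```

Here the external spike at 0.3 precedes the internal one at 0.5, so the state is advanced to 0.3 and the event is tagged as external.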
- jaxsnn.event.functional.trajectory(step_fn: Callable[[Tuple[jaxsnn.event.types.StepState, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], int], int], Tuple[Tuple[jaxsnn.event.types.StepState, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], int], jaxsnn.event.types.EventPropSpike]], n_hidden: int, n_spikes: int) → Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike], jaxsnn.event.types.EventPropSpike]
Evaluate the step_fn until n_spikes have been simulated.
Uses a scan over the step_fn to return an apply function.
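The scan-based construction can be sketched with `jax.lax.scan`. This is a minimal illustration under stated assumptions, not the library's `trajectory`: `toy_step` and `trajectory_sketch` are hypothetical names, the carry is a simple `(time, count)` pair rather than a `StepInput`, and each step emits one "spike" at a fixed interval of 0.1.

```python
import jax
import jax.numpy as jnp


def toy_step(carry, _):
    """Hypothetical step function: advance time by 0.1 and emit one spike."""
    t, count = carry
    t_next = t + 0.1
    # scan expects (new_carry, per_step_output)
    return (t_next, count + 1), t_next


def trajectory_sketch(step_fn, n_spikes):
    """Return an apply function that runs step_fn exactly n_spikes times."""
    def apply_fn(initial_carry):
        # Scan over a dummy index array; the stacked outputs are the spikes.
        _, spikes = jax.lax.scan(step_fn, initial_carry, jnp.arange(n_spikes))
        return spikes
    return apply_fn


apply_fn = trajectory_sketch(toy_step, n_spikes=4)
spike_times = apply_fn((jnp.float32(0.0), jnp.int32(0)))
```

Because the number of scan iterations is fixed by `n_spikes`, the output has a static shape, which is why the simulation is bounded by a spike count rather than run until an open-ended stopping condition.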