jaxsnn.event.stepping

Modules

jaxsnn.event.stepping.step(dynamics, …)

Find the next spike (external or internal) and simulate up to that point.

jaxsnn.event.stepping.step_existing_events

jaxsnn.event.stepping.types

Functions

jaxsnn.event.stepping.step(dynamics: Callable, tr_dynamics: Callable, t_max: float, solver: Callable[[jaxsnn.event.types.LIFState, float, float], jaxsnn.event.types.Spike], step_input: Tuple[jaxsnn.event.types.StepState, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], int], *args: int) → Tuple[Tuple[jaxsnn.event.types.StepState, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], int], jaxsnn.event.types.EventPropSpike]

Find the next spike (external or internal) and simulate up to that point.

Args:

  • dynamics (Callable) – Function describing the continuous neuron dynamics

  • tr_dynamics (Callable) – Function describing the transition dynamics

  • t_max (float) – Maximum time up to which to run

  • solver (Solver) – Parallel root solver which returns the next event

  • step_input (StepInput) – (StepState, weights, int)
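The callables above are only described in prose. As a hedged illustration of their contract (a sketch under stated assumptions, not jaxsnn's LIF implementation), the block below uses a toy leaky integrator driven by a constant current: `dynamics(state, dt)` evolves every membrane analytically, and `solver(state, t, t_max)` follows the `Callable[[LIFState, float, float], Spike]` shape from the signature by returning the earliest threshold crossing, or `t_max` if nothing fires first. `NeuronState`, `TAU`, `V_TH` and the stand-in `Spike` are hypothetical; the loop over neurons stands in for the parallel (vectorized) root solve.

```python
import math
from typing import NamedTuple, Tuple


class Spike(NamedTuple):
    # Hypothetical stand-in for jaxsnn.event.types.Spike
    time: float
    idx: int  # index of the spiking neuron, -1 if nothing fires before t_max


class NeuronState(NamedTuple):
    # Hypothetical stand-in for LIFState: voltages and constant input currents
    v: Tuple[float, ...]
    current: Tuple[float, ...]


TAU = 1.0   # membrane time constant (assumed)
V_TH = 1.0  # firing threshold (assumed)


def dynamics(state: NeuronState, dt: float) -> NeuronState:
    # Closed-form solution of tau * v' = -v + I over an interval dt
    decay = math.exp(-dt / TAU)
    v = tuple(i + (v0 - i) * decay for v0, i in zip(state.v, state.current))
    return state._replace(v=v)


def solver(state: NeuronState, t: float, t_max: float) -> Spike:
    # Root solver: earliest threshold crossing over all neurons.
    # v(dt) = I + (v0 - I) * exp(-dt / tau) = V_TH
    #   =>  dt = tau * ln((I - v0) / (I - V_TH)), valid only if v0 < V_TH < I
    best = Spike(t_max, -1)
    for idx, (v0, i) in enumerate(zip(state.v, state.current)):
        if v0 < V_TH < i:
            t_cross = t + TAU * math.log((i - v0) / (i - V_TH))
            if t_cross < best.time:
                best = Spike(t_cross, idx)
    return best


state = NeuronState(v=(0.0, 0.0), current=(2.0, 1.5))
spike = solver(state, 0.0, 5.0)      # neuron 0 crosses first, at t = ln 2
after = dynamics(state, spike.time)  # neuron 0 is now (numerically) at V_TH
```

Neuron 1 would cross later, at t = ln 3, so the solver reports neuron 0; with `t_max = 0.5` it would instead return `Spike(0.5, -1)`.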

Returns:

Tuple[StepInput, EventPropSpike]: New state after the transition, and the stored spike
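One call to `step` can be sketched in pure Python under simplifying assumptions: a toy perfect integrator (constant drive, reset to 0 on spiking) replaces jaxsnn's LIF types. The solver proposes the next internal spike, the earliest pending external spike competes with it, `dynamics` evolves every neuron to the winning event's time, and `tr_dynamics` applies the jump (reset or synaptic input). `StepState`, `input_queue`, `Spike`, the weight layout and the `offset` name for the trailing `int` are hypothetical; the real function returns an `EventPropSpike` rather than a plain spike, and this sketch simply passes the trailing `int` through.

```python
from typing import NamedTuple, Tuple


class Spike(NamedTuple):
    time: float
    idx: int  # -1 marks "no event before t_max"


class StepState(NamedTuple):
    # Hypothetical simplified stand-in for jaxsnn's StepState
    v: Tuple[float, ...]            # membrane voltages
    time: float                     # current simulation time
    input_queue: Tuple[Spike, ...]  # pending external spikes, sorted by time


V_TH = 1.0     # firing threshold (assumed)
CURRENT = 1.0  # constant drive shared by all neurons (assumed)


def dynamics(v, dt):
    # Perfect integrator: every voltage grows linearly with slope CURRENT
    return tuple(vi + CURRENT * dt for vi in v)


def solver(v, t, t_max):
    # Earliest threshold crossing over all neurons, capped at t_max
    # (assumes all voltages are below V_TH between events)
    best = Spike(t_max, -1)
    for idx, vi in enumerate(v):
        t_cross = t + (V_TH - vi) / CURRENT
        if t_cross < best.time:
            best = Spike(t_cross, idx)
    return best


def tr_dynamics(v, event, is_input, weights):
    # Input spike: add the source's weight row; internal spike: reset to 0
    if is_input:
        return tuple(vi + w for vi, w in zip(v, weights[event.idx]))
    return tuple(0.0 if i == event.idx else vi for i, vi in enumerate(v))


def step(dynamics, tr_dynamics, t_max, solver, step_input, *args):
    state, weights, offset = step_input
    internal = solver(state.v, state.time, t_max)
    external = state.input_queue[0] if state.input_queue else Spike(t_max, -1)
    is_input = external.time < internal.time
    event = external if is_input else internal
    # Simulate the continuous dynamics up to the chosen event
    v = dynamics(state.v, event.time - state.time)
    if event.idx >= 0:  # an actual event happened before t_max
        v = tr_dynamics(v, event, is_input, weights)
    queue = state.input_queue[1:] if is_input else state.input_queue
    return (StepState(v, event.time, queue), weights, offset), event


weights = [(0.1, 0.0)]  # row 0: input neuron 0 -> (neuron 0, neuron 1)
init = StepState(v=(0.0, 0.5), time=0.0, input_queue=(Spike(0.2, 0),))
carry, first = step(dynamics, tr_dynamics, 10.0, solver, (init, weights, 0))
carry, second = step(dynamics, tr_dynamics, 10.0, solver, carry)
# first: external spike at t = 0.2; second: neuron 1 fires at t ≈ 0.5
```

Feeding each returned carry back in as the next `step_input` is what makes the function usable as the body of a loop or scan.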

jaxsnn.event.stepping.step_existing(dynamics: Callable, tr_dynamics: Callable, t_max: float, event_stepper: Callable[[jaxsnn.event.types.LIFState, float, float], jaxsnn.event.types.Spike], step_input: Tuple[jaxsnn.event.types.StepState, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], int], *args: int) → Tuple[Tuple[jaxsnn.event.types.StepState, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], int], jaxsnn.event.types.EventPropSpike]

Find the next spike (external or internal) and simulate up to that point.

Parameters
  • dynamics – Function describing the continuous neuron dynamics

  • tr_dynamics – Function describing the transition dynamics

  • t_max – Maximum time up to which to run

  • event_stepper – Parallel root solver which returns the next event

  • step_input – (StepState, weights, int)

Returns

Tuple of the new state after the transition and the stored spike
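The trailing `*args: int` in both signatures is not documented here. A plausible reading, which is an assumption rather than something this page confirms, is that it absorbs the per-iteration element when the stepper is bound with `functools.partial` and driven by a scan such as `jax.lax.scan`, whose body is called as `f(carry, x)` and returns `(carry, y)`. The sketch shows that pattern with a pure-Python `scan` stand-in and a hypothetical single-neuron `toy_step` that shares the positional layout of `step` and `step_existing`; the carry plays the role of `(StepState, weights, int)`, and the per-step outputs are collected into a list the way `lax.scan` stacks them.

```python
from functools import partial

V_TH = 1.0     # threshold (assumed)
CURRENT = 0.5  # constant drive of the single toy neuron (assumed)


def scan(f, init, xs):
    # Pure-Python stand-in for jax.lax.scan: f(carry, x) -> (carry, y)
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def dynamics(v, dt):
    # Perfect integrator with constant drive
    return v + CURRENT * dt


def tr_dynamics(v):
    # Reset after a spike
    return 0.0


def solver(v, t, t_max):
    # Next threshold crossing of the single neuron, capped at t_max
    return min(t + (V_TH - v) / CURRENT, t_max)


def toy_step(dynamics, tr_dynamics, t_max, solver, step_input, *args):
    # Same positional layout as step/step_existing; *args swallows the
    # per-iteration element that scan passes in, which is otherwise unused
    (v, t), weights, offset = step_input
    t_next = solver(v, t, t_max)
    v = dynamics(v, t_next - t)
    if t_next < t_max:  # a real spike, not the t_max cap
        v = tr_dynamics(v)
    return ((v, t_next), weights, offset), t_next


# Bind everything except the carried step_input, then iterate a fixed number of times
stepper = partial(toy_step, dynamics, tr_dynamics, 5.0, solver)
final, times = scan(stepper, ((0.0, 0.0), None, 0), range(4))
# times == [2.0, 4.0, 5.0, 5.0]: two spikes, then the run is pinned at t_max
```

Because the number of iterations is fixed in advance, steps after the last real event stay pinned at `t_max`; with JAX, the same `partial` could presumably be passed to `jax.lax.scan` with `xs` set to a range of step indices.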