jaxsnn.event.root.ttfs
Analytically find the time of the next spike of a LIF neuron for the special case $\tau_{mem} = 2 \tau_{syn}$.
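This special case has a closed-form solution. The short derivation below assumes the current-based dynamics $\tau_{mem} \dot{V} = -V + I$ and $\tau_{syn} \dot{I} = -I$; the exact normalisation of jaxsnn's LIF model is an assumption here. With $\tau_{syn} = \tau_{mem}/2$ and $x = e^{-t/\tau_{mem}}$, the current is $I(t) = I_0 x^2$, and the membrane voltage becomes a quadratic polynomial in $x$:

$$
V(t) = V_0\,x + I_0\,(x - x^2) = a_2\,x - a_1\,x^2, \qquad a_1 = I_0, \quad a_2 = V_0 + I_0 .
$$

Setting $V(t) = v_{th}$ gives $a_1 x^2 - a_2 x + v_{th} = 0$ with discriminant $D = a_2^2 - 4 a_1 v_{th}$. For $a_1 > 0$ the larger root and the corresponding time are

$$
x_+ = \frac{a_2 + \sqrt{D}}{2 a_1}, \qquad t^* = -\tau_{mem} \ln x_+ = \tau_{mem} \ln \frac{2 a_1}{a_2 + \sqrt{D}} .
$$

Because $x$ decreases monotonically from 1 as $t$ grows, the earliest crossing is the largest root $x_+$. A crossing exists only if $D \ge 0$ and the argument of the logarithm exceeds 1 (i.e. $t^* > 0$). The parameter names a_1, a_2 and second_term in the function signatures suggest the same decomposition, but that correspondence is an assumption.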
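When this root solving is parallelized with jax.vmap, every jax.lax.cond whose predicate is batched is converted to jax.lax.select, so both branches are executed for every element. Special care is therefore taken to ensure that neither branch produces NaNs: a NaN computed in the discarded branch does not show up in the forward value, but it still propagates into the gradient.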
Functions
- jaxsnn.event.root.ttfs.ttfs_inner(a_1: jax.Array, a_2: jax.Array, second_term: jax.Array, tau_mem: float, t_max: float)
- jaxsnn.event.root.ttfs.ttfs_inner_most(a_1: jax.Array, denominator: jax.Array, tau_mem: float, t_max: float) → jax.Array
- jaxsnn.event.root.ttfs.ttfs_solver(tau_mem: float, v_th: float, state: jaxsnn.event.types.LIFState, t_max: float)
  Find the next spike time for the special case $\tau_{mem} = 2 \tau_{syn}$.
  - Args:
    tau_mem (float): Membrane time constant.
    v_th (float): Threshold voltage.
    state (LIFState): State of the neuron (voltage, current).
    t_max (float): Maximum time up to which the next spike is searched for.
  - Returns:
    float: Time of the next threshold crossing, or t_max if there is no crossing.
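A minimal JAX sketch of the closed form, written so that it stays NaN-free under jax.vmap, follows. It assumes the dynamics $\tau_{mem} \dot{V} = -V + I$, $\tau_{syn} \dot{I} = -I$ with $\tau_{syn} = \tau_{mem}/2$, for which the spike time is $t^* = \tau_{mem} \ln(2 a_1 / (a_2 + \sqrt{D}))$ with $a_1 = I_0$, $a_2 = V_0 + I_0$ and $D = a_2^2 - 4 a_1 v_{th}$. The function ttfs_solver_sketch and its scalar arguments v_0 and i_0 (used in place of a LIFState) are hypothetical; this is not the jaxsnn implementation.

```python
import jax
import jax.numpy as jnp


def ttfs_solver_sketch(tau_mem, v_th, v_0, i_0, t_max):
    """Hypothetical closed-form spike time for tau_mem = 2 * tau_syn.

    Assumes tau_mem * dV/dt = -V + I and tau_syn * dI/dt = -I, so that
    V(t) = a_2 * x - a_1 * x**2 with x = exp(-t / tau_mem).
    """
    a_1 = jnp.asarray(i_0)
    a_2 = jnp.asarray(v_0) + a_1
    discriminant = a_2**2 - 4.0 * a_1 * v_th

    # Guard the sqrt: the no-real-root case must not yield NaN values or gradients.
    has_root = discriminant > 0.0
    sqrt_d = jnp.sqrt(jnp.where(has_root, discriminant, 1.0))
    denominator = a_2 + sqrt_d

    # Guard the division against (near-)zero denominators in the same way.
    ok = has_root & (jnp.abs(denominator) > 1e-9)
    ratio = 2.0 * a_1 / jnp.where(ok, denominator, 1.0)  # ratio = 1 / x_+

    # The crossing lies in the future only if x_+ < 1, i.e. ratio > 1.
    valid = ok & (ratio > 1.0)
    t_spike = tau_mem * jnp.log(jnp.where(valid, ratio, 2.0))

    # Report t_max when there is no crossing or it lies beyond the search window.
    return jnp.where(valid, jnp.minimum(t_spike, t_max), t_max)


# Batched over initial currents: 3.0 never reaches threshold, 5.0 does,
# and -1.0 (inhibitory) never does.
currents = jnp.array([3.0, 5.0, -1.0])
solve = jax.vmap(ttfs_solver_sketch, in_axes=(None, None, None, 0, None))
print(solve(2.0, 1.0, 0.0, currents, 10.0))  # first and last entries equal t_max
```

Every branch is evaluated with guarded inputs and the results are combined with jnp.where, mirroring the cond-to-select behaviour of jax.vmap, so gradients with respect to the initial state remain finite even for neurons that never spike.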