jaxsnn.event.root.ttfs
Analytically find the time of the next spike of a LIF neuron in the special cases $\tau_{mem} = \tau_{syn}$ and $\tau_{mem} = 2 \tau_{syn}$.
When this root solving is run in parallel with jax.vmap, a jax.lax.cond with a batched predicate is converted to jax.lax.select, so both branches are executed. Special care is therefore taken to ensure that no branch produces NaNs, since they would corrupt the gradient calculation.
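The NaN hazard can be reproduced with jnp.where, which, like a vmapped cond, evaluates both branches. The following is a minimal sketch of the common "double where" workaround. The function names `unsafe_sqrt` and `safe_sqrt` are hypothetical and not part of jaxsnn, and this is not claimed to be the exact technique used in this module.

```python
import jax
import jax.numpy as jnp


def unsafe_sqrt(x):
    # Both branches are evaluated; for x <= 0, sqrt(x) and its derivative are NaN.
    # The unselected branch receives a zero cotangent, but 0 * NaN is still NaN.
    return jnp.where(x > 0, jnp.sqrt(x), 0.0)


def safe_sqrt(x):
    # "Double where": first replace invalid inputs with a harmless value (1.0),
    # so the discarded branch has a finite value and a finite derivative.
    valid = x > 0
    x_safe = jnp.where(valid, x, 1.0)
    return jnp.where(valid, jnp.sqrt(x_safe), 0.0)


xs = jnp.array([-1.0, 4.0])
# Gradients of each element, computed in parallel with vmap.
grad_unsafe = jax.vmap(jax.grad(unsafe_sqrt))(xs)
grad_safe = jax.vmap(jax.grad(safe_sqrt))(xs)
print(grad_unsafe, grad_safe)
```

Both versions return the same forward values; only the gradient at x = -1 differs (NaN for `unsafe_sqrt`, 0 for `safe_sqrt`).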
Classes
- partial(func, *args, **keywords): functools.partial; returns a new function with partial application of the given arguments and keywords.
Functions
- jaxsnn.event.root.ttfs.lambertw(x, k=0, max_steps=5)
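The Lambert W function solves $w e^{w} = x$, and the `max_steps` parameter suggests a fixed number of iterations. As a hedged sketch only, the following computes the principal branch ($k = 0$) with Halley's method and a fixed iteration count, which keeps the loop length static under jit and vmap. The name `lambertw0_halley` and the choice of initial guesses are assumptions, not the module's implementation, and other branches are not covered.

```python
import math

import jax.numpy as jnp


def lambertw0_halley(x, max_steps=5):
    """Hypothetical sketch: principal branch W_0 of the Lambert W function.

    Solves w * exp(w) = x for x > -1/e with a fixed number of Halley steps.
    """
    x = jnp.asarray(x, dtype=jnp.float32)
    # Initial guess: series around the branch point x = -1/e for x < 0,
    # log1p(x) otherwise; maximum(...) keeps the sqrt argument non-negative.
    p = jnp.sqrt(jnp.maximum(2.0 * (math.e * x + 1.0), 0.0))
    w = jnp.where(x < 0.0, -1.0 + p - p**2 / 3.0, jnp.log1p(x))
    for _ in range(max_steps):
        ew = jnp.exp(w)
        f = w * ew - x  # residual of w * exp(w) = x
        wp1 = w + 1.0
        # Halley update for f(w) = w * exp(w) - x, with f' = e^w (w + 1)
        # and f'' = e^w (w + 2).
        w = w - f / (ew * wp1 - (w + 2.0) * f / (2.0 * wp1))
    return w


xs = jnp.array([-0.35, -0.1, 0.0, 0.5, 1.0, 10.0])
w = lambertw0_halley(xs)
```

At exactly $x = -1/e$ the Halley denominator can vanish ($w + 1 = 0$), so a production version would need to guard that point.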
- jaxsnn.event.root.ttfs.ttfs_ratio1_inner(a_1: jax.Array, b: jax.Array, w_arg: jax.Array, tau_mem: float, t_max: float)
- jaxsnn.event.root.ttfs.ttfs_ratio1_inner_most(a_1: jax.Array, b: jax.Array, w_arg: jax.Array, tau_mem: float, t_max: float)
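A sketch of why the Lambert W function $W$ (defined by $W(z)\,e^{W(z)} = z$) appears for $\tau = \tau_{mem} = \tau_{syn}$. It assumes the LIF dynamics $\tau_{mem} \dot V = -V + I$ and $\tau_{syn} \dot I = -I$ with zero leak potential, current in voltage units, and initial state $(V_0, I_0)$. These symbols are illustrative and are not claimed to map onto a_1, b or w_arg.

```latex
\begin{aligned}
I(t) &= I_0\, e^{-t/\tau}, \qquad
V(t) = \left(V_0 + \frac{I_0\, t}{\tau}\right) e^{-t/\tau}, \\
V(t^*) = v_{th}
&\;\Longleftrightarrow\;
u\, e^{u} = z, \qquad
u = -\frac{t^*}{\tau} - \frac{V_0}{I_0}, \quad
z = -\frac{v_{th}}{I_0}\, e^{-V_0/I_0}, \\
t^* &= -\tau \left( W_0(z) + \frac{V_0}{I_0} \right).
\end{aligned}
```

A real solution exists only for $z \ge -1/e$. For $I_0 > 0$ the earliest crossing uses the principal branch $W_0$, since $W_0(z) \ge W_{-1}(z)$ and $t^*$ decreases as $W$ grows. When $z < -1/e$ or $t^*$ is not positive, no crossing occurs and a solver would fall back to t_max; under vmap the argument of $W$ must then be replaced by a safe value so the discarded branch stays finite.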
- jaxsnn.event.root.ttfs.ttfs_ratio2_inner(a_1: jax.Array, a_2: jax.Array, second_term: jax.Array, tau_mem: float, t_max: float)
- jaxsnn.event.root.ttfs.ttfs_ratio2_inner_most(a_1: jax.Array, denominator: jax.Array, tau_mem: float, t_max: float) → jax.Array
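For $\tau_{mem} = 2 \tau_{syn}$ the threshold condition reduces to a quadratic. This is a sketch under the same assumed dynamics: $\tau_{mem} \dot V = -V + I$ and $\tau_{syn} \dot I = -I$ with zero leak potential and initial state $(V_0, I_0)$. Substituting $x = e^{-t/\tau_{mem}}$ gives $e^{-t/\tau_{syn}} = x^2$. The symbols are illustrative and are not claimed to correspond to a_1, a_2, second_term or denominator.

```latex
\begin{aligned}
V(t) &= (V_0 + I_0)\, e^{-t/\tau_{mem}} - I_0\, e^{-2t/\tau_{mem}}, \qquad x = e^{-t/\tau_{mem}}, \\
V(t^*) = v_{th}
&\;\Longleftrightarrow\;
I_0\, x^2 - (V_0 + I_0)\, x + v_{th} = 0, \\
t^* &= \tau_{mem} \ln \frac{2 I_0}{V_0 + I_0 + \sqrt{(V_0 + I_0)^2 - 4 I_0 v_{th}}}.
\end{aligned}
```

For $I_0 > 0$ and $V_0 < v_{th}$, the larger root $x_+$ gives the earliest crossing, because $x$ decreases from 1 as $t$ grows. A crossing exists only if the discriminant is non-negative and $0 < x_+ < 1$. Otherwise a solver would return t_max, and under vmap the square root and logarithm must be fed safe values.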
- jaxsnn.event.root.ttfs.ttfs_solver(tau_mem: float, tau_syn: float, v_th: float, state: jaxsnn.event.types.LIFState, t_max: float)
  Find the next spike time for the special cases $\tau_{mem} = \tau_{syn}$ and $\tau_{mem} = 2 \tau_{syn}$.
  Args:
    tau_mem (float): Membrane time constant.
    tau_syn (float): Synaptic time constant.
    v_th (float): Threshold voltage.
    state (LIFState): State of the neuron (voltage, current).
    t_max (float): Maximum time to search.
  Returns:
    float: Time of the next threshold crossing, or t_max if there is no crossing.
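The following sketch illustrates a closed-form spike time with a fallback to t_max and NaN-safe branches under vmap. It covers only the $\tau_{mem} = 2 \tau_{syn}$ case and works on plain scalars instead of an LIFState. The helper names `ttfs_ratio2_sketch` and `membrane_voltage`, the scalar arguments, and the dynamics (zero leak potential, current in voltage units, $V_0 < v_{th}$) are assumptions. This is not the jaxsnn implementation.

```python
import jax
import jax.numpy as jnp


def membrane_voltage(t, v0, i0, tau_mem):
    # Hypothetical helper: closed-form V(t) for tau_syn = tau_mem / 2 (assumed dynamics).
    x = jnp.exp(-t / tau_mem)
    return (v0 + i0) * x - i0 * x**2


def ttfs_ratio2_sketch(v0, i0, v_th, tau_mem, t_max):
    """Hypothetical sketch: first threshold crossing for tau_mem = 2 * tau_syn.

    Solves i0 * x**2 - (v0 + i0) * x + v_th = 0 for x = exp(-t / tau_mem)
    and returns t_max when no crossing exists in (0, t_max).
    """
    a = v0 + i0
    disc = a**2 - 4.0 * i0 * v_th
    # Require a real, non-grazing root with i0 > 0; disc == 0 is treated as
    # "no spike" so sqrt never sees 0 (its derivative would be infinite).
    has_root = (disc > 0.0) & (i0 > 0.0)
    # Double-where: substitute harmless values where there is no root so the
    # discarded branch stays finite, and so do its gradients, under vmap.
    disc_safe = jnp.where(has_root, disc, 1.0)
    i0_safe = jnp.where(has_root, i0, 1.0)
    # Larger root of the quadratic, i.e. the earliest crossing time.
    x_plus = (a + jnp.sqrt(disc_safe)) / (2.0 * i0_safe)
    # x = exp(-t / tau_mem) must lie in (0, 1) for a crossing at t > 0.
    valid = has_root & (x_plus > 0.0) & (x_plus < 1.0)
    x_safe = jnp.where(valid, x_plus, 0.5)
    t_spike = -tau_mem * jnp.log(x_safe)
    return jnp.where(valid & (t_spike < t_max), t_spike, t_max)


def solve(v0, i0):
    # Fixed parameters for the example: v_th = 1, tau_mem = 10, t_max = 20.
    return ttfs_ratio2_sketch(v0, i0, 1.0, 10.0, 20.0)


v0s = jnp.array([0.0, 0.0, 0.0])
i0s = jnp.array([5.0, 1.0, 0.0])  # spiking, too weak, no input
times = jax.vmap(solve)(v0s, i0s)
grads = jax.vmap(jax.grad(solve, argnums=1))(v0s, i0s)
```

Only the first neuron crosses threshold. The other two fall back to t_max, and all gradients stay finite because no discarded branch ever evaluates sqrt or log outside its domain.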