jaxsnn.event.loss

Classes

EventPropSpike(time, idx, current)

LIFState(V, I)

Spike(time, idx)

TestResult(loss, accuracy, t_first_spike, …)

Functions

jaxsnn.event.loss.first_spike(spikes: jaxsnn.event.types.EventPropSpike, size: int, n_outputs: int) -> jax.Array
jaxsnn.event.loss.loss_and_acc(loss_fn: Callable, weights: List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], dataset: Tuple[jaxsnn.event.types.EventPropSpike, jax.Array]) -> jaxsnn.event.types.TestResult
jaxsnn.event.loss.loss_and_acc_scan(loss_fn: Callable, weights: List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], dataset: Tuple[jaxsnn.event.types.EventPropSpike, jax.Array]) -> jaxsnn.event.types.TestResult
jaxsnn.event.loss.loss_wrapper(apply_fn: Callable[[List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], jaxsnn.event.types.EventPropSpike], List[jaxsnn.event.types.EventPropSpike]], loss_fn: Callable[[jax.Array, jax.Array, float], float], tau_mem: float, n_neurons: int, n_outputs: int, weights: List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], batch: Tuple[jaxsnn.event.types.EventPropSpike, jax.Array], vmap: bool = True, external: Optional[List[jaxsnn.event.types.Spike]] = None, carry: Optional[Any] = None) -> Tuple[float, Tuple[jax.Array, List[jaxsnn.event.types.EventPropSpike]]]
jaxsnn.event.loss.max_over_time(output: jaxsnn.event.types.LIFState) -> jax.Array
jaxsnn.event.loss.max_over_time_loss(apply_fn: Callable[[List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], jaxsnn.event.types.EventPropSpike], List[jaxsnn.event.types.EventPropSpike]], weights: List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], batch: Tuple[jaxsnn.event.types.EventPropSpike, jax.Array]) -> Tuple[Union[jax.Array, numpy.ndarray, float], Tuple[float, List[jaxsnn.event.types.EventPropSpike]]]
jaxsnn.event.loss.mse_loss(first_spikes: jax.Array, target: jax.Array, tau_mem: float) -> float
jaxsnn.event.loss.nll_loss(output: jax.Array, targets: jax.Array) -> float
jaxsnn.event.loss.target_time_loss(first_spikes: jax.Array, target: jax.Array, tau_mem: float) -> float
jaxsnn.event.loss.ttfs_loss(first_spikes: jax.Array, target: jax.Array, tau_mem: float) -> float