jaxsnn.event.loss
Classes
alias of |
Functions
- jaxsnn.event.loss.first_spike(spikes: jaxsnn.event.types.EventPropSpike, size: int, n_outputs: int) → jax.Array
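The signature suggests that `first_spike` reduces an event-based spike record to one first-spike time per output neuron. Below is a minimal, hypothetical sketch of that reduction in plain JAX. It is not the jaxsnn implementation. It assumes that spikes arrive as parallel arrays of times and neuron indices rather than as an `EventPropSpike`, that output neurons are numbered `0..n_outputs-1`, and that a neuron which never fires gets `inf`. The role of `size` is not documented here, so the sketch leaves it out. `first_spike_sketch` is an illustrative name.

```python
import jax.numpy as jnp


def first_spike_sketch(times, idx, n_outputs):
    """Hypothetical helper (not jaxsnn's API): earliest spike time per output neuron.

    times: (n_events,) spike times; idx: (n_events,) neuron index of each spike.
    Neurons that never spike are assigned +inf (an assumption of this sketch).
    """
    # mask[j, k] is True when event k was emitted by output neuron j
    mask = idx[None, :] == jnp.arange(n_outputs)[:, None]
    # Hide events from other neurons behind +inf, then take the minimum over events
    return jnp.min(jnp.where(mask, times[None, :], jnp.inf), axis=1)


times = jnp.array([3.0, 1.0, 2.0, 5.0])
idx = jnp.array([0, 1, 0, 1])
# first spikes: 2.0 (neuron 0), 1.0 (neuron 1), inf (neuron 2 never fires)
print(first_spike_sketch(times, idx, n_outputs=3))
```

The masked minimum keeps the result a fixed-shape array of length `n_outputs`, so it can be used inside `jax.jit`.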
- jaxsnn.event.loss.max_over_time(output: jaxsnn.base.types.LIState) → jax.Array
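`max_over_time` takes a leaky-integrator state (`LIState`), which suggests a readout that uses each output neuron's peak membrane voltage as its class score. The hypothetical sketch below shows that idea. It assumes the voltage trace is passed directly as an array of shape `(n_steps, n_outputs)` with time along axis 0. Which `LIState` field holds the voltage, and its axis layout, are assumptions here and are not taken from jaxsnn.

```python
import jax.numpy as jnp


def max_over_time_sketch(v):
    """Hypothetical helper (not jaxsnn's API): peak voltage of each output neuron.

    v: (n_steps, n_outputs) membrane voltage trace, time along axis 0 (assumed).
    """
    return jnp.max(v, axis=0)


v = jnp.array([[0.0, 0.2],
               [0.5, 0.1],
               [0.3, 0.4]])
peaks = max_over_time_sketch(v)  # peaks: 0.5 for neuron 0, 0.4 for neuron 1
prediction = jnp.argmax(peaks)   # neuron 0 has the highest peak
print(peaks, prediction)
```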
- jaxsnn.event.loss.mse_loss(first_spikes: jax.Array, target: jax.Array, tau_mem: float) → jax.Array
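From its arguments, `mse_loss` appears to compare first-spike times with target times and to use the membrane time constant `tau_mem` as a time scale. Below is a minimal sketch under that assumption: a mean squared error on times expressed in units of `tau_mem`. Whether jaxsnn averages or sums, and how it treats neurons that never fire, is not stated here. `mse_loss_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def mse_loss_sketch(first_spikes, target, tau_mem):
    """Hypothetical (not jaxsnn's implementation): MSE between first-spike and
    target times, both measured in units of the membrane time constant."""
    return jnp.mean(jnp.square((first_spikes - target) / tau_mem))


tau_mem = 2.0
first_spikes = jnp.array([1.0, 3.0])
target = jnp.array([1.0, 2.0])
loss = mse_loss_sketch(first_spikes, target, tau_mem)  # mean([0.0, 0.25]) = 0.125
# Gradient w.r.t. spike times, as an event-based training loop would use it
grad = jax.grad(mse_loss_sketch)(first_spikes, target, tau_mem)  # [0.0, 0.25]
print(loss, grad)
```

In this sketch, a neuron that never fires (time `inf`) makes the loss infinite, so a real training setup has to handle silent neurons somehow.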
- jaxsnn.event.loss.nll_loss(output: jax.Array, targets: jax.Array) → jax.Array
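`nll_loss` takes only network output and targets, which fits a standard negative log-likelihood over class scores (for example, peak voltages per output neuron). The sketch below assumes `output` holds unnormalized scores of shape `(batch, n_classes)` and `targets` is one-hot with the same shape. Both assumptions, and the use of a softmax, are not confirmed by the signature. `nll_loss_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def nll_loss_sketch(output, targets):
    """Hypothetical (not jaxsnn's implementation): mean NLL of the target class.

    output: (batch, n_classes) unnormalised scores; targets: one-hot, same shape.
    """
    log_probs = jax.nn.log_softmax(output, axis=-1)
    # Pick the log-probability of the target class in each row, average over the batch
    return -jnp.mean(jnp.sum(targets * log_probs, axis=-1))


output = jnp.array([[2.0, 0.0],
                    [0.0, 0.0]])
targets = jnp.array([[1.0, 0.0],
                     [0.0, 1.0]])
# Row 1 is confidently correct, row 2 is a coin flip (contributes log 2)
print(nll_loss_sketch(output, targets))
```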
- jaxsnn.event.loss.target_time_loss(first_spikes: jax.Array, target: jax.Array, tau_mem: float) → jax.Array
- jaxsnn.event.loss.ttfs_loss(first_spikes: jax.Array, target: jax.Array, tau_mem: float) → jax.Array
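A common form of time-to-first-spike (TTFS) classification loss applies a softmax to negative first-spike times scaled by `tau_mem`, so the earliest-firing output gets the highest probability, and then takes the cross-entropy with the correct class. The sketch below uses that technique. It is not necessarily what `ttfs_loss` computes. It also assumes that `target` holds target times and that the correct class is the output with the earliest target time. `ttfs_loss_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def ttfs_loss_sketch(first_spikes, target, tau_mem):
    """Hypothetical TTFS loss (not jaxsnn's implementation): cross-entropy of a
    softmax over -t / tau_mem.

    Assumes the correct class is the output whose target time is earliest.
    """
    label = jnp.argmin(target)
    # Earlier spikes -> larger logits -> higher class probability
    log_probs = jax.nn.log_softmax(-first_spikes / tau_mem)
    return -log_probs[label]


tau_mem = 1.0
target = jnp.array([0.5, 1.5])  # class 0 should fire first
good = jnp.array([0.5, 1.5])    # class 0 does fire first
bad = jnp.array([1.5, 0.5])     # class 1 fires first
print(ttfs_loss_sketch(good, target, tau_mem) < ttfs_loss_sketch(bad, target, tau_mem))  # True
```

If the correct neuron never fires (time `inf`), its logit is `-inf` and this sketch returns an infinite loss.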