jaxsnn.event.loss
Classes
|
|
|
|
|
|
|
Functions
-
jaxsnn.event.loss.first_spike(spikes: jaxsnn.event.types.EventPropSpike, size: int, n_outputs: int) → jax.Array
-
jaxsnn.event.loss.loss_and_acc(loss_fn: Callable, weights: List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], dataset: Tuple[jaxsnn.event.types.EventPropSpike, jax.Array]) → jaxsnn.event.types.TestResult
-
jaxsnn.event.loss.loss_and_acc_scan(loss_fn: Callable, weights: List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], dataset: Tuple[jaxsnn.event.types.EventPropSpike, jax.Array]) → jaxsnn.event.types.TestResult
-
jaxsnn.event.loss.loss_wrapper(apply_fn: Callable[[List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], jaxsnn.event.types.EventPropSpike], List[jaxsnn.event.types.EventPropSpike]], loss_fn: Callable[[jax.Array, jax.Array, float], float], tau_mem: float, n_neurons: int, n_outputs: int, weights: List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], batch: Tuple[jaxsnn.event.types.EventPropSpike, jax.Array], vmap: bool = True, external: Optional[List[jaxsnn.event.types.Spike]] = None, carry: Optional[Any] = None) → Tuple[float, Tuple[jax.Array, List[jaxsnn.event.types.EventPropSpike]]]
-
jaxsnn.event.loss.max_over_time(output: jaxsnn.event.types.LIFState) → jax.Array
-
jaxsnn.event.loss.max_over_time_loss(apply_fn: Callable[[List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], jaxsnn.event.types.EventPropSpike], List[jaxsnn.event.types.EventPropSpike]], weights: List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], batch: Tuple[jaxsnn.event.types.EventPropSpike, jax.Array]) → Tuple[Union[jax.Array, numpy.ndarray, float], Tuple[float, List[jaxsnn.event.types.EventPropSpike]]]
-
jaxsnn.event.loss.mse_loss(first_spikes: jax.Array, target: jax.Array, tau_mem: float) → float
-
jaxsnn.event.loss.nll_loss(output: jax.Array, targets: jax.Array) → float
-
jaxsnn.event.loss.target_time_loss(first_spikes: jax.Array, target: jax.Array, tau_mem: float) → float
-
jaxsnn.event.loss.ttfs_loss(first_spikes: jax.Array, target: jax.Array, tau_mem: float) → float