jaxsnn.event.training
Classes
|
|
|
|
|
Functions
- jaxsnn.event.training.data_loader(dataset: Any, batch_size: int, rng: Optional[jax.Array] = None)
- jaxsnn.event.training.epoch(update_fn: Callable, loss_fn: Callable[[List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], Tuple[jaxsnn.event.types.EventPropSpike, jax.Array]], Tuple[float, Tuple[jax.Array, List[jaxsnn.event.types.EventPropSpike]]]], trainset, testset, opt_state: jaxsnn.event.types.OptState, i: int)
- jaxsnn.event.training.loss_and_acc(loss_fn: Callable, weights: List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], dataset: Tuple[jaxsnn.event.types.EventPropSpike, jax.Array]) → jaxsnn.event.types.TestResult
- jaxsnn.event.training.time_it(timed_function: Callable, *args) → Tuple[Any, float]
- jaxsnn.event.training.update(optimizer, loss_fn: Callable, params: jaxsnn.base.params.LIFParameters, state: jaxsnn.event.types.OptState, batch: Tuple[jaxsnn.event.types.Spike, jax.Array]) → Tuple[jaxsnn.event.types.OptState, Tuple[jax.Array, jax.Array]]