jaxsnn.event.training
Classes
|
|
|
Container for the optimizer/training state across steps.
|
Functions
-
jaxsnn.event.training.data_loader(dataset: Tuple[Any, Any], batch_size: int, num_batches: Optional[int] = None, rng: Optional[jax.Array] = None)
-
jaxsnn.event.training.epoch(update_fn: Callable, loss_fn: Callable[[List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], Tuple[jaxsnn.event.types.EventPropSpike, jax.Array]], Tuple[float, Tuple[jax.Array, List[jaxsnn.event.types.EventPropSpike]]]], trainset, testset, opt_state: jaxsnn.event.types.OptState, i: int)
-
jaxsnn.event.training.loss_and_acc(loss_fn: Callable, weights: List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], dataset: Tuple[jaxsnn.event.types.EventPropSpike, jax.Array]) → jaxsnn.event.types.TestResult
-
jaxsnn.event.training.time_it(timed_function: Callable, *args) → Tuple[Any, float]
-
jaxsnn.event.training.update(optimizer, loss_fn: Callable, params: jaxsnn.base.params.LIFParameters, state: jaxsnn.event.types.OptState, batch: Tuple[jaxsnn.event.types.Spike, jax.Array]) → Tuple[jaxsnn.event.types.OptState, Tuple[jax.Array, jax.Array]]