jaxsnn.event.training

Classes

LIFParameters(tau_syn, tau_mem, v_th, …)

OptState(opt_state, weights)

Spike(time, idx)
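The classes above are listed without descriptions. As a rough, hypothetical illustration of the data they carry, the sketch below models `Spike` as two parallel sequences (spike times and the indices of the neurons that emitted them) and `OptState` as an optimizer state paired with the weights being trained. These plain `NamedTuple` stand-ins, the list-based layout, and the helper `sort_by_time` are assumptions for illustration, not jaxsnn's definitions. Event-based simulation generally consumes spikes in ascending time order, which is the ordering the helper produces.

```python
from typing import Any, List, NamedTuple


class Spike(NamedTuple):
    # Hypothetical stand-in: parallel sequences of spike times and neuron indices.
    time: List[float]
    idx: List[int]


class OptState(NamedTuple):
    # Hypothetical stand-in: optimizer state plus the weights being trained.
    opt_state: Any
    weights: List[float]


def sort_by_time(spikes: Spike) -> Spike:
    """Reorder both fields together so events appear in ascending time order."""
    order = sorted(range(len(spikes.time)), key=lambda k: spikes.time[k])
    return Spike(time=[spikes.time[k] for k in order],
                 idx=[spikes.idx[k] for k in order])


events = Spike(time=[3e-3, 1e-3, 2e-3], idx=[2, 0, 1])  # times in seconds (assumed)
print(sort_by_time(events))  # Spike(time=[0.001, 0.002, 0.003], idx=[0, 1, 2])
state = OptState(opt_state=None, weights=[0.5, -0.2])
```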

Functions

jaxsnn.event.training.data_loader(dataset: Any, batch_size: int, rng: Optional[jax.Array] = None)
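A minimal sketch of what a batching helper with this signature might do, assuming `dataset` is an `(inputs, labels)` pair of equal-length sequences. To keep the example dependency-free, a `random.Random` instance stands in for the JAX PRNG key that `rng` actually expects. The trailing partial batch is dropped so every batch has the same size, which matters if batches are later stacked and iterated with `jax.lax.scan`. Both choices are assumptions, not documented behavior.

```python
import random
from typing import List, Optional, Sequence, Tuple


def data_loader(dataset: Tuple[Sequence, Sequence], batch_size: int,
                rng: Optional[random.Random] = None) -> List[Tuple[list, list]]:
    """Hypothetical sketch: split (inputs, labels) into equal-sized batches.

    Assumptions (not taken from jaxsnn): the dataset is a pair of equal-length
    sequences, rng is a random.Random standing in for a JAX PRNG key, and a
    trailing partial batch is dropped so every batch has the same size.
    """
    inputs, labels = dataset
    if len(inputs) != len(labels):
        raise ValueError("inputs and labels must have the same length")
    order = list(range(len(inputs)))
    if rng is not None:
        rng.shuffle(order)  # shuffle sample order only when an rng is given
    n_batches = len(order) // batch_size
    batches = []
    for b in range(n_batches):
        picked = order[b * batch_size:(b + 1) * batch_size]
        batches.append(([inputs[k] for k in picked], [labels[k] for k in picked]))
    return batches


# Usage: 5 samples with batch_size=2 give 2 batches; the 5th sample is dropped.
xs, ys = ["a", "b", "c", "d", "e"], [0, 1, 0, 1, 0]
print(data_loader((xs, ys), batch_size=2))
# [(['a', 'b'], [0, 1]), (['c', 'd'], [0, 1])]
```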
jaxsnn.event.training.epoch(update_fn: Callable, loss_fn: Callable[[List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], Tuple[jaxsnn.event.types.EventPropSpike, jax.Array]], Tuple[float, Tuple[jax.Array, List[jaxsnn.event.types.EventPropSpike]]]], trainset, testset, opt_state: jaxsnn.event.types.OptState, i: int)
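The `update` function's return type (a new `OptState` paired with per-step outputs) has the `(carry, x) -> (carry, y)` shape of a `jax.lax.scan` body. That suggests one way `epoch` could be organized: thread the optimizer state through every training batch, then evaluate the final weights on the test set. The sketch below follows that reading with a plain-Python `scan` loop and a one-weight toy model. The loop structure, the returned tuple, the log line, and the helpers `toy_loss` and `toy_update` are all hypothetical.

```python
from typing import Any, Callable, List, NamedTuple, Tuple


class OptState(NamedTuple):
    # Hypothetical stand-in for jaxsnn.event.types.OptState.
    opt_state: Any
    weights: List[float]


def scan(step: Callable, carry: Any, xs: List[Any]) -> Tuple[Any, List[Any]]:
    """Plain-Python analogue of jax.lax.scan: thread `carry`, collect outputs."""
    ys = []
    for x in xs:
        carry, y = step(carry, x)
        ys.append(y)
    return carry, ys


def epoch(update_fn: Callable, loss_fn: Callable, trainset, testset,
          opt_state: OptState, i: int):
    """Hypothetical epoch: train over all batches, then evaluate on testset."""
    opt_state, train_outputs = scan(update_fn, opt_state, trainset)
    test_loss, _aux = loss_fn(opt_state.weights, testset)
    print(f"epoch {i}: test loss {test_loss:.4f}")
    return opt_state, (test_loss, train_outputs)


def toy_loss(weights, batch):
    # One-weight model y = w * x with mean squared error; no auxiliary output.
    xs, ys = batch
    loss = sum((weights[0] * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)
    return loss, None


def toy_update(state: OptState, batch, lr: float = 0.1):
    # Plain gradient descent on toy_loss using its analytic gradient.
    xs, ys = batch
    w = state.weights[0]
    grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)
    loss, _ = toy_loss(state.weights, batch)
    return OptState(state.opt_state, [w - lr * grad]), loss


initial = OptState(opt_state=None, weights=[0.0])
trainset = [([1.0], [2.0]), ([2.0], [4.0])] * 3  # six batches drawn from y = 2x
testset = ([1.0, 3.0], [2.0, 6.0])
final, (test_loss, train_losses) = epoch(toy_update, toy_loss, trainset, testset, initial, 0)
```

Because the state is only ever passed forward as the scan carry, the same `epoch` body works unchanged whether the carry holds plain floats, as here, or arrays and optimizer state.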
jaxsnn.event.training.loss_and_acc(loss_fn: Callable, weights: List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], dataset: Tuple[jaxsnn.event.types.EventPropSpike, jax.Array]) → jaxsnn.event.types.TestResult
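One plausible reading of `loss_and_acc`, sketched under stated assumptions. `loss_fn` is called on each test sample and returns `(loss, (t_first_spike, recording))`, matching the shape of the `loss_fn` type in `epoch`. `t_first_spike` is assumed to hold the first spike time of each output neuron. Accuracy then uses a time-to-first-spike readout, where the predicted class is the output neuron that fires first. The per-sample Python loop stands in for vectorization (for example with `jax.vmap`). The two-field `TestResult` and the helper `toy_loss` are hypothetical stand-ins; the fields of `jaxsnn.event.types.TestResult` are not documented here.

```python
from typing import Any, Callable, List, NamedTuple, Sequence, Tuple


class TestResult(NamedTuple):
    # Hypothetical fields; jaxsnn's TestResult may hold more or different data.
    loss: float
    accuracy: float


def loss_and_acc(loss_fn: Callable, weights: Any,
                 dataset: Tuple[Sequence, Sequence]) -> TestResult:
    """Average the loss over samples and score a first-spike readout."""
    inputs, labels = dataset
    losses: List[float] = []
    correct = 0
    for x, label in zip(inputs, labels):
        loss, (t_first_spike, _recording) = loss_fn(weights, (x, label))
        losses.append(loss)
        # Predicted class: the output neuron with the earliest first spike.
        predicted = min(range(len(t_first_spike)), key=lambda k: t_first_spike[k])
        correct += int(predicted == label)
    return TestResult(sum(losses) / len(losses), correct / len(labels))


def toy_loss(weights, sample):
    # Hypothetical: scale input "spike times" per output neuron and use the
    # correct neuron's first spike time as the loss (earlier is better).
    x, label = sample
    t_first_spike = [w * t for w, t in zip(weights, x)]
    return t_first_spike[label], (t_first_spike, None)


inputs = [[1.0, 3.0], [2.0, 0.5], [1.0, 2.0]]
labels = [0, 1, 1]
result = loss_and_acc(toy_loss, [1.0, 1.0], (inputs, labels))
print(result)  # mean loss 3.5 / 3, accuracy 2 / 3
```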
jaxsnn.event.training.time_it(timed_function: Callable, *args) → Tuple[Any, float]
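A minimal sketch of a helper matching the `time_it` signature, assuming the returned float is the wall-clock duration in seconds (the unit is not documented).

```python
import time
from typing import Any, Callable, Tuple


def time_it(timed_function: Callable, *args) -> Tuple[Any, float]:
    """Sketch: call timed_function(*args) and return (result, elapsed seconds)."""
    start = time.perf_counter()
    result = timed_function(*args)
    elapsed = time.perf_counter() - start
    return result, elapsed


result, seconds = time_it(sum, range(1_000_000))
print(result)  # 499999500000
```

Because JAX dispatches work asynchronously, a helper like this can stop the clock before the computation has finished. To time the actual work on JAX arrays, materialize the result first, for example with `jax.block_until_ready(result)`, before reading the clock a second time.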
jaxsnn.event.training.update(optimizer, loss_fn: Callable, params: jaxsnn.base.params.LIFParameters, state: jaxsnn.event.types.OptState, batch: Tuple[jaxsnn.event.types.Spike, jax.Array]) → Tuple[jaxsnn.event.types.OptState, Tuple[jax.Array, jax.Array]]
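A dependency-free sketch of a single optimization step using the argument order of `update`. In a JAX codebase the gradient would typically come from `jax.value_and_grad` and the optimizer would be an optax-style transformation. Here, central finite differences are swapped in for autodiff, and a hypothetical `SGD` class mimics the `update(grads, state) -> (updates, state)` pattern. The sketch assumes the two returned arrays are the loss and the gradients. It does not use `params`, whose role is not documented here. Binding the first three arguments with `functools.partial` yields a `(state, batch) -> (state, outputs)` function, which is the shape a `jax.lax.scan` body needs.

```python
from functools import partial
from typing import Any, Callable, List, NamedTuple, Tuple


class OptState(NamedTuple):
    # Hypothetical stand-in for jaxsnn.event.types.OptState.
    opt_state: Any
    weights: List[float]


class SGD:
    """Hypothetical optimizer mimicking an optax-style update(grads, state)."""

    def __init__(self, learning_rate: float):
        self.learning_rate = learning_rate

    def init(self, weights: List[float]) -> int:
        return 0  # the only state kept is a step counter

    def update(self, grads: List[float], state: int) -> Tuple[List[float], int]:
        updates = [-self.learning_rate * g for g in grads]  # plain gradient descent
        return updates, state + 1


def finite_diff_grad(f: Callable, weights: List[float], eps: float = 1e-6) -> List[float]:
    """Central differences, standing in for jax.grad in this sketch."""
    grads = []
    for k in range(len(weights)):
        up = list(weights)
        up[k] += eps
        down = list(weights)
        down[k] -= eps
        grads.append((f(up) - f(down)) / (2 * eps))
    return grads


def update(optimizer, loss_fn: Callable, params, state: OptState, batch):
    """One step: loss and gradient at state.weights, then apply the update."""
    def scalar_loss(weights):
        return loss_fn(weights, batch)[0]  # keep the loss, drop auxiliary output

    loss = scalar_loss(state.weights)
    grads = finite_diff_grad(scalar_loss, state.weights)
    updates, new_opt_state = optimizer.update(grads, state.opt_state)
    new_weights = [w + u for w, u in zip(state.weights, updates)]
    return OptState(new_opt_state, new_weights), (loss, grads)


def toy_loss(weights, batch):
    # Hypothetical one-weight model y = w * x with mean squared error.
    xs, ys = batch
    loss = sum((weights[0] * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)
    return loss, None


opt = SGD(learning_rate=0.1)
state = OptState(opt.init([0.0]), [0.0])
step = partial(update, opt, toy_loss, None)  # (state, batch) -> (state, outputs)
new_state, (loss, grads) = step(state, ([1.0], [2.0]))
print(loss, new_state.opt_state)  # 4.0 1
```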