jaxsnn.event.training

Classes

Any(*args, **kwargs)

Special type indicating an unconstrained type.

LIFParameters(tau_syn, tau_mem, v_th, …)
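The field names suggest the standard current-based leaky integrate-and-fire model: tau_syn is the synaptic-current time constant, tau_mem is the membrane time constant, and v_th is the firing threshold. The sketch below shows one forward-Euler step of that textbook model under those assumptions. `LIFParams` and `lif_step` are hypothetical stand-ins, not the jaxsnn class, whose remaining fields are elided above. The default values, the resting potential of 0 and the reset to 0 are also assumptions. An event-based module need not step through time like this at all.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class LIFParams(NamedTuple):
    # Hypothetical stand-in for LIFParameters; only the three named fields.
    tau_syn: float = 5e-3  # synaptic time constant (s), assumed default
    tau_mem: float = 1e-2  # membrane time constant (s), assumed default
    v_th: float = 1.0      # firing threshold, assumed default


def lif_step(p: LIFParams, v: jax.Array, i: jax.Array, i_in: jax.Array, dt: float):
    """One forward-Euler step of a current-based LIF neuron (rest and reset at 0)."""
    v = v + dt / p.tau_mem * (i - v)       # membrane leaks toward 0 and integrates current
    i = i + dt / p.tau_syn * (-i) + i_in   # synaptic current decays; input adds jumps
    spike = v >= p.v_th                    # threshold crossing
    v = jnp.where(spike, 0.0, v)           # assumed reset to 0 after a spike
    return v, i, spike
```

A membrane responds to input with a one-step delay here, because the voltage update uses the current from before the input arrives.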

OptState(opt_state, weights, rng)

Container for the optimizer/training state across steps.

Spike(time, idx)
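Event-based simulations commonly store spikes as parallel arrays of spike times and emitting-neuron indices. The following minimal sketch assumes that Spike holds one such pair of equal-length arrays and that downstream code wants events in chronological order. `SpikeEvents` and `sort_by_time` are hypothetical names for illustration.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class SpikeEvents(NamedTuple):
    # Hypothetical stand-in for Spike: parallel arrays of equal length.
    time: jax.Array  # spike times
    idx: jax.Array   # index of the neuron that emitted each spike


def sort_by_time(spikes: SpikeEvents) -> SpikeEvents:
    """Reorder events chronologically, keeping each time paired with its index."""
    order = jnp.argsort(spikes.time)
    # A NamedTuple is a pytree, so one tree_map permutes both fields consistently
    # and returns a SpikeEvents again.
    return jax.tree_util.tree_map(lambda x: x[order], spikes)
```

Because such a container is a pytree, it can also be passed unchanged through `jax.jit` and `jax.vmap`.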

Functions

jaxsnn.event.training.data_loader(dataset: Tuple[Any, Any], batch_size: int, num_batches: Optional[int] = None, rng: Optional[jax.Array] = None)
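The following minimal sketch shows what a loader with this signature can look like. It assumes that `dataset` is an `(inputs, labels)` pair sharing a leading sample axis, with `labels` a single array and `inputs` an array or pytree (for example a Spike-like container). It also assumes that `rng`, when given, is a JAX PRNG key used for shuffling, and that an incomplete trailing batch is dropped. `batch_dataset` is hypothetical. The real `data_loader` may return a different structure or handle remainders differently.

```python
from typing import Any, Optional, Tuple

import jax


def batch_dataset(
    dataset: Tuple[Any, Any],
    batch_size: int,
    num_batches: Optional[int] = None,
    rng: Optional[jax.Array] = None,
) -> Tuple[Any, Any]:
    inputs, labels = dataset
    n = labels.shape[0]
    if rng is not None:
        # Shuffle inputs and labels with the same permutation.
        perm = jax.random.permutation(rng, n)
        inputs = jax.tree_util.tree_map(lambda x: x[perm], inputs)
        labels = labels[perm]
    max_batches = n // batch_size  # the incomplete last batch is dropped
    if num_batches is None:
        num_batches = max_batches
    num_batches = min(num_batches, max_batches)
    keep = num_batches * batch_size

    def to_batches(x):
        # (samples, ...) -> (num_batches, batch_size, ...)
        return x[:keep].reshape((num_batches, batch_size) + x.shape[1:])

    return jax.tree_util.tree_map(to_batches, inputs), to_batches(labels)
```

Stacking batches along a leading axis, rather than yielding them one at a time, lets a training loop consume them with `jax.lax.scan`.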
jaxsnn.event.training.epoch(update_fn: Callable, test_fn: Callable[[List[jax.Array], Tuple[jax.Array, jax.Array]], Tuple[Any, str]], trainset, testset, opt_state: jaxsnn.event.types.OptState, i: int)
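In JAX, a common way to structure an epoch is to scan the update step over batches stacked along a leading axis and then evaluate the result. This sketch makes several assumptions for illustration only: `update_fn(state, batch)` returns `(state, aux)`, `trainset` is a pytree whose leaves carry a leading batch axis, `test_fn` receives the state's `weights` and the test set and returns `(result, log_line)`, and `i` is the epoch counter. `run_epoch` is hypothetical, and these calling conventions are not taken from the jaxsnn source.

```python
from typing import Any, Callable, Tuple

import jax


def run_epoch(
    update_fn: Callable,
    test_fn: Callable,
    trainset: Any,
    testset: Any,
    opt_state: Any,
    i: int,
) -> Tuple[Any, Any]:
    # Apply update_fn once per batch; lax.scan traces the loop body once and
    # stacks the per-batch auxiliary outputs along a new leading axis.
    opt_state, train_aux = jax.lax.scan(update_fn, opt_state, trainset)
    # Evaluate the updated weights on the held-out set.
    test_result, log_line = test_fn(opt_state.weights, testset)
    print(f"Epoch {i}: {log_line}")
    return opt_state, (train_aux, test_result)
```

`lax.scan` requires the carried state to keep the same structure, shapes and dtypes from step to step, which a fixed-size training state satisfies.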
jaxsnn.event.training.time_it(timed_function: Callable, *args) → Tuple[Any, float]
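JAX dispatches computations asynchronously, so a wall-clock timer should wait for the result to be ready before it stops. The following minimal sketch matches the shape of `time_it`: it calls the function with the given arguments and returns the result together with the elapsed seconds. `time_call` is hypothetical, and whether jaxsnn's version blocks on the result is an assumption.

```python
import time
from typing import Any, Callable, Tuple

import jax


def time_call(timed_function: Callable, *args) -> Tuple[Any, float]:
    start = time.perf_counter()
    result = timed_function(*args)
    # Wait for asynchronously dispatched work to finish before stopping the clock.
    result = jax.block_until_ready(result)
    return result, time.perf_counter() - start
```

The first call to a jitted function includes compilation time, so time a second call when measuring steady-state speed.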
jaxsnn.event.training.update(optimizer, loss_fn: Callable, params: jaxsnn.base.params.LIFParameters, state: jaxsnn.event.types.OptState, batch: Tuple[jaxsnn.event.types.Spike, jax.Array]) → Tuple[jaxsnn.event.types.OptState, Tuple[jax.Array, jax.Array]]
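The following minimal sketch shows a gradient step with the same inputs and output shape as `update`. It assumes an optimizer that follows the optax convention, `init(weights)` and `update(grads, opt_state, weights) -> (updates, opt_state)`, and a `loss_fn(weights, batch)` that returns `(loss, aux)`. The neuron parameters are assumed to be baked into `loss_fn`, so the sketch omits the `params` argument. `train_step`, `SGD` and `TrainState` (a stand-in with OptState's fields) are hypothetical. The real `update` may pass `params` explicitly or return different auxiliary values.

```python
from typing import Any, Callable, NamedTuple, Tuple

import jax


class TrainState(NamedTuple):
    # Hypothetical stand-in with the same fields as OptState.
    opt_state: Any
    weights: Any
    rng: jax.Array


class SGD(NamedTuple):
    # Hypothetical optimizer following the optax init/update convention.
    learning_rate: float

    def init(self, weights):
        return ()  # plain SGD keeps no optimizer state

    def update(self, grads, opt_state, weights=None):
        updates = jax.tree_util.tree_map(lambda g: -self.learning_rate * g, grads)
        return updates, opt_state


def train_step(
    optimizer, loss_fn: Callable, state: TrainState, batch
) -> Tuple[TrainState, Tuple[jax.Array, Any]]:
    # Differentiate the scalar loss w.r.t. the weights; aux is passed through.
    (loss, aux), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.weights, batch)
    updates, opt_state = optimizer.update(grads, state.opt_state, state.weights)
    weights = jax.tree_util.tree_map(lambda w, u: w + u, state.weights, updates)
    return TrainState(opt_state, weights, state.rng), (loss, aux)
```

If optax is installed, an optimizer such as `optax.sgd(0.1)` follows the same `init`/`update` convention and can replace the toy `SGD`.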