jaxsnn.event.utils.training

Classes

Any(*args, **kwargs)

Special type indicating an unconstrained type.

Spike(time, idx)

WeightInput(input,)

WeightRecurrent(input, recurrent)
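The page lists these containers only by field name. The definitions below are a minimal mirror that assumes they are `typing.NamedTuple` classes. A one-field named tuple's auto-generated docstring is written as `WeightInput(input,)`, which matches the listing. These stand-ins are hypothetical and are not the `jaxsnn.event.types` source. The field types are left as `Any` because the page does not state them.

```python
from typing import Any, NamedTuple


class Spike(NamedTuple):
    # Hypothetical stand-in: spike times and the index of the neuron that fired.
    time: Any
    idx: Any


class WeightInput(NamedTuple):
    # Hypothetical stand-in: feed-forward weights from the layer input only.
    input: Any


class WeightRecurrent(NamedTuple):
    # Hypothetical stand-in: feed-forward weights plus recurrent weights.
    input: Any
    recurrent: Any


# Named tuples support both attribute access and tuple unpacking.
w = WeightRecurrent(input=[[0.1, 0.2]], recurrent=[[0.0]])
inp, rec = w
print(w.input, w.recurrent)
```

Because they are tuples, instances of such classes are also valid JAX pytrees, so they can be passed through `jax.grad` or `jax.jit` without extra registration.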

Functions

jaxsnn.event.utils.training.bump_weights(weights: List[jaxsnn.event.types.WeightInput], recording: List[jaxsnn.event.types.Spike]) → List[jaxsnn.event.types.WeightInput]
jaxsnn.event.utils.training.clip_gradient(grads: List[jaxsnn.event.types.WeightInput]) → List[jaxsnn.event.types.WeightInput]
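`clip_gradient` and `bump_weights` both map a list of per-layer `WeightInput` tuples to a new list of the same structure. The page documents only their signatures, so the sketch below is illustrative and is not the jaxsnn implementation. It makes the following assumptions:

- Clipping means an elementwise clamp to a hypothetical bound `max_abs`.
- Bumping means adding a hypothetical constant `bump` to the input weights of every neuron that emitted no spike in the recording.
- `input` has shape `(n_inputs, n_neurons)`.

The sketch uses NumPy in place of `jax.numpy`.

```python
from typing import Any, List, NamedTuple

import numpy as np


class WeightInput(NamedTuple):
    input: Any  # assumed shape: (n_inputs, n_neurons)


class Spike(NamedTuple):
    time: Any  # spike times
    idx: Any   # index of the neuron that emitted each spike (non-negative ints)


def clip_gradient_sketch(grads: List[WeightInput], max_abs: float = 1.0) -> List[WeightInput]:
    # Clamp every gradient entry to [-max_abs, max_abs]; max_abs is hypothetical.
    return [WeightInput(np.clip(g.input, -max_abs, max_abs)) for g in grads]


def bump_weights_sketch(
    weights: List[WeightInput], recording: List[Spike], bump: float = 0.5
) -> List[WeightInput]:
    # Assumed behaviour: raise all input weights of neurons that never spiked,
    # so silent neurons receive more drive on the next forward pass.
    # Assumes exactly one Spike recording per weight layer.
    bumped = []
    for w, layer_spikes in zip(weights, recording):
        n_neurons = w.input.shape[1]
        counts = np.bincount(np.asarray(layer_spikes.idx), minlength=n_neurons)
        silent = counts == 0  # boolean mask over neurons (columns)
        bumped.append(WeightInput(w.input + bump * silent))  # broadcasts over rows
    return bumped


grads = [WeightInput(np.array([[2.0, -3.0], [0.5, 0.0]]))]
print(clip_gradient_sketch(grads)[0].input)

weights = [WeightInput(np.zeros((2, 3)))]
recording = [Spike(time=np.array([0.1, 0.4]), idx=np.array([0, 0]))]
print(bump_weights_sketch(weights, recording)[0].input)
```

In the usage above, only neuron 0 spikes, so columns 1 and 2 of the weight matrix are raised by `bump`.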
jaxsnn.event.utils.training.get_index_trainset(trainset, idx)
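`get_index_trainset` carries no annotations, so the page does not state its contract. One plausible reading is assumed here but not confirmed: `trainset` is a tuple of arrays that share a leading sample axis, and `idx` selects the same samples from each array. The hypothetical `get_index_trainset_sketch` below implements that reading with NumPy. With JAX arrays or nested pytrees, `jax.tree_util.tree_map(lambda x: x[idx], trainset)` expresses the same idea.

```python
from typing import Sequence, Tuple

import numpy as np


def get_index_trainset_sketch(trainset: Sequence[np.ndarray], idx) -> Tuple[np.ndarray, ...]:
    # Select the same sample(s) from every array in the training set.
    # Assumes all arrays share a leading sample axis; idx may be an int,
    # a slice, or an integer array.
    return tuple(np.asarray(part)[idx] for part in trainset)


inputs = np.arange(12).reshape(4, 3)  # 4 samples, 3 features each
labels = np.array([0, 1, 1, 0])       # one label per sample
x, y = get_index_trainset_sketch((inputs, labels), np.array([1, 3]))
print(x, y)
```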
jaxsnn.event.utils.training.load_weights(filenames) → List[jaxsnn.event.types.WeightInput]
jaxsnn.event.utils.training.load_weights_recurrent(folder: str)
jaxsnn.event.utils.training.save_weights(weights: List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], folder: str)
jaxsnn.event.utils.training.save_weights_recurrent(weights: jaxsnn.event.types.WeightRecurrent, folder: str)
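The persistence helpers are asymmetric. `save_weights`, `save_weights_recurrent`, and `load_weights_recurrent` each take a `folder`, but `load_weights` takes `filenames`. The round-trip sketch below follows that split and assumes one NumPy `.npy` file per layer. The `weights_<i>.npy` naming and the `_sketch` functions are hypothetical, not the library's on-disk format.

```python
import os
import tempfile
from typing import Any, List, NamedTuple

import numpy as np


class WeightInput(NamedTuple):
    input: Any


def save_weights_sketch(weights: List[WeightInput], folder: str) -> List[str]:
    # Write one .npy file per layer and return the paths in layer order.
    os.makedirs(folder, exist_ok=True)
    filenames = []
    for i, w in enumerate(weights):
        path = os.path.join(folder, f"weights_{i}.npy")  # hypothetical naming
        np.save(path, np.asarray(w.input))
        filenames.append(path)
    return filenames


def load_weights_sketch(filenames: List[str]) -> List[WeightInput]:
    # Rebuild one WeightInput per file, in the order given.
    return [WeightInput(np.load(f)) for f in filenames]


with tempfile.TemporaryDirectory() as folder:
    original = [WeightInput(np.ones((2, 3))), WeightInput(np.zeros((3, 1)))]
    files = save_weights_sketch(original, folder)
    restored = load_weights_sketch(files)  # arrays are read fully into memory
    print([r.input.shape for r in restored])
```

A `WeightRecurrent` layer would need both its `input` and `recurrent` arrays stored, for example as two files per layer. This page does not document how jaxsnn lays these out.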
jaxsnn.event.utils.training.time_it(timed_function: Callable, *args) → Tuple[Any, float]
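`time_it` returns the function's result together with a float. The sketch below assumes the float is the elapsed wall-clock time in seconds for a single call; this page does not document the unit. The name `time_it_sketch` is hypothetical.

```python
import time
from typing import Any, Callable, Tuple


def time_it_sketch(timed_function: Callable, *args) -> Tuple[Any, float]:
    # Call the function once, timing it with a monotonic high-resolution clock.
    start = time.perf_counter()
    result = timed_function(*args)
    duration = time.perf_counter() - start
    return result, duration


value, seconds = time_it_sketch(sum, range(1_000_000))
print(value, f"{seconds:.4f} s")
```

JAX dispatches work asynchronously, so a jitted function can return before its computation has finished. Calling `.block_until_ready()` on the returned array before stopping the clock gives a more faithful measurement. The first call of a jitted function also includes compilation time.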