jaxsnn.event.utils

Modules

jaxsnn.event.utils.filter

jaxsnn.event.utils.training

Classes

Any(*args, **kwargs)

Special type indicating an unconstrained type.

EventPropSpike(time, idx, current)

Spike(time, idx)

WeightInput(input,)

WeightRecurrent(input, recurrent)
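The four classes above are listed only by their field names. The `Class(field, ...)` signatures suggest plain `typing.NamedTuple` containers, but that is an assumption. The sketch below defines hypothetical look-alike classes locally, rather than importing from `jaxsnn.event.types`, to show how such records are built and read. In jaxsnn the fields most likely hold arrays with one entry per spike. JAX treats NamedTuples as pytrees, so containers of this kind can pass through `jax.jit` and `jax.grad` unchanged.

```python
from typing import NamedTuple

# Hypothetical stand-ins that only mirror the field names listed above;
# the real classes live in jaxsnn.event.types and may behave differently.
class Spike(NamedTuple):
    time: float  # spike time
    idx: int     # index of the neuron that emitted the spike

class EventPropSpike(NamedTuple):
    time: float
    idx: int
    current: float  # extra per-spike current value (meaning assumed)

class WeightInput(NamedTuple):
    input: list  # feed-forward weights

class WeightRecurrent(NamedTuple):
    input: list      # feed-forward weights
    recurrent: list  # recurrent weights

s = Spike(time=0.5e-3, idx=3)
w = WeightRecurrent(input=[[0.1, 0.2]], recurrent=[[0.0]])

print(s.idx)      # access fields by name
t, i = s          # NamedTuples also unpack like ordinary tuples
print(w._fields)  # ('input', 'recurrent')
```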

Functions

jaxsnn.event.utils.bump_weights(weights: List[jaxsnn.event.types.WeightInput], recording: List[jaxsnn.event.types.Spike]) → List[jaxsnn.event.types.WeightInput]

jaxsnn.event.utils.clip_gradient(grads: List[jaxsnn.event.types.WeightInput]) → List[jaxsnn.event.types.WeightInput]

jaxsnn.event.utils.filter_spikes(input_spikes: jaxsnn.event.types.EventPropSpike, prev_layer_start: int) → jaxsnn.event.types.EventPropSpike

Filters the input spikes, keeping only the spikes emitted by the previous layer.
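A minimal sketch of how such a filter can work, written against NumPy and a hypothetical stand-in for `EventPropSpike`. It assumes that spikes are stored as arrays with one entry per spike and that neurons of the previous layer have global indices `>= prev_layer_start`. It also assumes that rejected spikes are masked rather than dropped (time set to infinity, index to -1, current to 0), which is a guess at the design, not taken from jaxsnn.

```python
import numpy as np
from typing import NamedTuple

class EventPropSpike(NamedTuple):
    # Hypothetical stand-in for jaxsnn.event.types.EventPropSpike
    time: np.ndarray
    idx: np.ndarray
    current: np.ndarray

def filter_spikes_sketch(spikes: EventPropSpike, prev_layer_start: int) -> EventPropSpike:
    """Mask out spikes whose neuron index lies below prev_layer_start.

    Assumption: previous-layer neurons have global indices >= prev_layer_start.
    Masked spikes keep their slot (fixed shape) but get time=inf, idx=-1 and
    current=0 so that downstream code ignores them.
    """
    keep = spikes.idx >= prev_layer_start
    return EventPropSpike(
        time=np.where(keep, spikes.time, np.inf),
        idx=np.where(keep, spikes.idx, -1),
        current=np.where(keep, spikes.current, 0.0),
    )

spikes = EventPropSpike(
    time=np.array([0.1, 0.2, 0.3]),
    idx=np.array([0, 5, 7]),
    current=np.array([1.0, 2.0, 3.0]),
)
filtered = filter_spikes_sketch(spikes, prev_layer_start=5)
print(filtered.idx.tolist())   # [-1, 5, 7]
print(filtered.time.tolist())  # [inf, 0.2, 0.3]
```

The same `where` calls exist in `jax.numpy`, so the sketch ports by swapping the import. Masking instead of boolean indexing keeps the output shape independent of the data, which `jax.jit` requires.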

jaxsnn.event.utils.get_index_trainset(trainset, idx)

jaxsnn.event.utils.load_weights(filenames) → List[jaxsnn.event.types.WeightInput]

jaxsnn.event.utils.load_weights_recurrent(folder: str)

jaxsnn.event.utils.save_weights(weights: List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], folder: str)

jaxsnn.event.utils.save_weights_recurrent(weights: jaxsnn.event.types.WeightRecurrent, folder: str)

jaxsnn.event.utils.time_it(timed_function: Callable, *args) → Tuple[Any, float]
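The signature of `time_it` suggests that it calls `timed_function(*args)` and returns the result together with an elapsed time. The sketch below is a hypothetical re-implementation under that reading. The unit (seconds) and the use of `time.perf_counter` are assumptions. Because JAX dispatches work asynchronously, the sketch waits on a JAX array result before stopping the clock; otherwise only the dispatch would be timed.

```python
import time
from typing import Any, Callable, Tuple

def time_it_sketch(timed_function: Callable, *args) -> Tuple[Any, float]:
    """Call timed_function(*args) and return (result, elapsed seconds).

    Hypothetical re-implementation; seconds and time.perf_counter are
    assumptions, not taken from jaxsnn.
    """
    start = time.perf_counter()
    result = timed_function(*args)
    # JAX arrays expose block_until_ready(); waiting on it makes the timing
    # include the actual computation. Other results are left untouched.
    if hasattr(result, "block_until_ready"):
        result.block_until_ready()
    elapsed = time.perf_counter() - start
    return result, elapsed

value, seconds = time_it_sketch(sum, [1, 2, 3])
print(value, f"{seconds:.6f} s")
```

Note that the first call to a `jax.jit`-compiled function also includes compilation time, so timing a warm-up call separately gives a clearer picture of steady-state cost.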