jaxsnn.base.compose

Classes

EventPropSpike(time, idx, current)

Functions

jaxsnn.base.compose.serial(*layers: Tuple[Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]], Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike], jaxsnn.event.types.EventPropSpike]]) → Tuple[Callable[[jax.Array, int], List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]], Callable[[List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], jaxsnn.event.types.EventPropSpike], List[jaxsnn.event.types.EventPropSpike]]]

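Chain multiple layers, each given as an (init, apply) pair of functions, into a single init/apply pair. As the signature shows, the combined init function returns a list of weights and the combined apply function returns a list of EventPropSpike values.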

Returns:

InitApply: The combined init/apply pair for the chained layers.
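
The composition pattern suggested by the signature can be sketched in plain Python. This is an illustrative stand-in, not the jaxsnn implementation: serial_sketch, delay_layer and Spike are hypothetical names, the rng argument is passed through unchanged (a real implementation would presumably derive a fresh jax PRNG key per layer), and the integer handed to each apply function is assumed here to be the layer's position, which the signature does not confirm.

```python
from typing import Any, Callable, List, NamedTuple, Tuple


class Spike(NamedTuple):
    """Hypothetical stand-in for EventPropSpike(time, idx, current)."""
    time: list
    idx: list
    current: list


# Shapes mirrored from the documented signature (weights kept generic).
InitFn = Callable[[Any, int], Tuple[int, Any]]   # (rng, input_size) -> (output_size, weights)
ApplyFn = Callable[[int, Any, Spike], Spike]     # (int, weights, spikes) -> spikes


def serial_sketch(*layers: Tuple[InitFn, ApplyFn]):
    """Chain (init, apply) pairs into one (init, apply) pair."""
    init_fns = [init for init, _ in layers]
    apply_fns = [apply for _, apply in layers]

    def init_fn(rng: Any, input_size: int) -> List[Any]:
        weights = []
        size = input_size
        for init in init_fns:
            # Each layer's output size becomes the next layer's input size.
            size, layer_weights = init(rng, size)
            weights.append(layer_weights)
        return weights

    def apply_fn(weights: List[Any], spikes: Spike) -> List[Spike]:
        recording = []
        for position, (apply, layer_weights) in enumerate(zip(apply_fns, weights)):
            # Assumption: the int argument is the layer's position in the chain.
            spikes = apply(position, layer_weights, spikes)
            recording.append(spikes)
        return recording

    return init_fn, apply_fn


def delay_layer(size: int, delay: float) -> Tuple[InitFn, ApplyFn]:
    """Toy layer (hypothetical): its only 'weight' is a delay added to spike times."""

    def init(rng: Any, input_size: int) -> Tuple[int, float]:
        return size, delay

    def apply(position: int, weights: float, spikes: Spike) -> Spike:
        return Spike([t + weights for t in spikes.time], spikes.idx, spikes.current)

    return init, apply


init_fn, apply_fn = serial_sketch(delay_layer(4, 0.5), delay_layer(2, 0.25))
weights = init_fn(None, 3)
recording = apply_fn(weights, Spike([0.0, 1.0], [0, 1], [0.0, 0.0]))
```

In this sketch, weights holds one entry per layer and recording holds the spikes emitted by each layer in order, so the output of the last layer is recording[-1].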