jaxsnn.base.compose
Functions
jaxsnn.base.compose.serial(
    *layers: Tuple[
        Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]],
        Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike], jaxsnn.event.types.EventPropSpike]
    ]
) → Tuple[
    Callable[[jax.Array, int], List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]],
    Callable[[List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], jaxsnn.event.types.EventPropSpike], List[jaxsnn.event.types.EventPropSpike]]
]

Concatenate multiple layers, each given as an init/apply pair, into a single init/apply pair that applies them in series.
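Returns:
    InitApply: Init/apply pair. The init function takes a jax.Array and an int and returns a list holding the weights of every layer; the apply function takes that list of weights and an input EventPropSpike and returns a List[EventPropSpike].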
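A minimal sketch of how a serial composition with this shape can be written follows. It is not the jaxsnn source. It assumes that each layer's init maps a PRNG key and an input size to (output size, weights), and that each layer's apply maps (int, weights, spikes) to spikes. `serial_sketch` and `toy_layer` are hypothetical names. Plain arrays stand in for WeightInput/WeightRecurrent and EventPropSpike. The leading int passed to each apply is treated as the layer index, which is purely an assumption, because this page does not document its meaning.

```python
from typing import Callable, List, Tuple

import jax
import jax.numpy as jnp

# Hypothetical stand-ins: jaxsnn would use WeightInput/WeightRecurrent for the
# weights and EventPropSpike for the activity; plain arrays keep the sketch small.
Weights = jax.Array
Spikes = jax.Array
LayerInit = Callable[[jax.Array, int], Tuple[int, Weights]]
LayerApply = Callable[[int, Weights, Spikes], Spikes]


def serial_sketch(*layers: Tuple[LayerInit, LayerApply]):
    """Chain (init, apply) pairs in series (illustrative sketch only)."""
    init_fns, apply_fns = zip(*layers)

    def init_fn(rng: jax.Array, input_size: int) -> List[Weights]:
        weights = []
        for layer_init in init_fns:
            # Give every layer its own PRNG key.
            rng, layer_rng = jax.random.split(rng)
            # Assumption: a layer's init returns (output size, weights), and the
            # output size becomes the next layer's input size.
            input_size, w = layer_init(layer_rng, input_size)
            weights.append(w)
        return weights

    def apply_fn(weights: List[Weights], spikes: Spikes) -> List[Spikes]:
        outputs = []
        for index, (layer_apply, w) in enumerate(zip(apply_fns, weights)):
            # Assumption: the leading int is the layer index in this sketch.
            spikes = layer_apply(index, w, spikes)
            outputs.append(spikes)  # keep every layer's output, not just the last
        return outputs

    return init_fn, apply_fn


def toy_layer(n_out: int) -> Tuple[LayerInit, LayerApply]:
    """Hypothetical dense layer used only to exercise serial_sketch."""

    def init(rng: jax.Array, n_in: int) -> Tuple[int, Weights]:
        return n_out, jax.random.normal(rng, (n_in, n_out))

    def apply(layer_index: int, w: Weights, x: Spikes) -> Spikes:
        return x @ w

    return init, apply


init_fn, apply_fn = serial_sketch(toy_layer(4), toy_layer(2))
weights = init_fn(jax.random.PRNGKey(0), 3)
outputs = apply_fn(weights, jnp.ones(3))
print([w.shape for w in weights])  # [(3, 4), (4, 2)]
print([o.shape for o in outputs])  # [(4,), (2,)]
```

The init function threads each layer's output size into the next layer. The apply function returns the output of every layer rather than only the last one, which matches the List return types in the signature above.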