jaxsnn.event.construct

Classes

WeightInput(input,)

WeightRecurrent(input, recurrent)
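The `WeightInput(input,)` signature looks like the repr of a `typing.NamedTuple`. The sketch below assumes both classes are NamedTuples with `jax.Array` fields. The shapes in the comments are illustrative assumptions and are not taken from jaxsnn. These stand-ins are hypothetical; the real classes live in `jaxsnn.event.types`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


# Hypothetical stand-ins mirroring the documented field names only.
class WeightInput(NamedTuple):
    input: jax.Array  # assumed shape: (n_in, n_hidden)


class WeightRecurrent(NamedTuple):
    input: jax.Array  # assumed shape: (n_in, n_total)
    recurrent: jax.Array  # assumed shape: (n_total, n_total)


w = WeightInput(input=jnp.zeros((3, 4)))
r = WeightRecurrent(input=jnp.zeros((3, 6)), recurrent=jnp.zeros((6, 6)))

# NamedTuples are pytrees, so tree_map returns another WeightRecurrent.
scaled = jax.tree_util.tree_map(lambda x: x + 1.0, r)
```

If the real classes are NamedTuples, JAX treats them as pytrees. In that case `jax.tree_util.tree_map`, `jax.grad` and optimizer updates all return objects of the same class.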

Functions

jaxsnn.event.construct.construct_init_fn(n_hidden: int, mean: float, std: float, duplication: Optional[int] = None) → Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]]
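The following is a hypothetical re-implementation of the documented signature. Its only purpose is to show the `init_fn(rng, input_size)` calling pattern. It assumes three things:

* The returned callable draws Gaussian input weights using the given `mean` and `std`.
* It returns the layer size together with a `WeightInput`.
* This mirrors `init_fun(rng, input_shape) -> (output_shape, params)` in `jax.example_libraries.stax`.

`sketch_init_fn` is an invented name. `duplication` is omitted because its semantics are not described here.

```python
from typing import Callable, NamedTuple, Tuple

import jax


class WeightInput(NamedTuple):  # hypothetical stand-in for jaxsnn.event.types.WeightInput
    input: jax.Array


def sketch_init_fn(
    n_hidden: int, mean: float, std: float
) -> Callable[[jax.Array, int], Tuple[int, WeightInput]]:
    """Hypothetical analogue of construct_init_fn; `duplication` is not modelled."""

    def init_fn(rng: jax.Array, input_shape: int) -> Tuple[int, WeightInput]:
        # Assumed: i.i.d. Gaussian weights of shape (input_shape, n_hidden).
        w = mean + std * jax.random.normal(rng, (input_shape, n_hidden))
        # Assumed: the int is this layer's size, used to size the next layer.
        return n_hidden, WeightInput(input=w)

    return init_fn


init_fn = sketch_init_fn(n_hidden=10, mean=0.5, std=0.1)
out_size, weights = init_fn(jax.random.PRNGKey(0), 4)
```

Returning the size from `init_fn` lets a chain of such functions be initialised one after another. Each layer receives the previous layer's size as its `input_shape`.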
jaxsnn.event.construct.construct_recurrent_init_fn(layers: List[int], mean: List[float], std: List[float], duplication: Optional[float] = None) → Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]]
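The following is a hypothetical sketch of a recurrent initialiser with the documented parameters. It assumes the following:

* All layers in `layers` are concatenated into a single population of `sum(layers)` neurons.
* The input weights feed only the first layer.
* The recurrent matrix holds only the layer `i-1` → layer `i` blocks. This expresses a feed-forward stack as one masked recurrent layer.
* `mean[i]` and `std[i]` parameterise the weights entering layer `i`.

None of these assumptions, nor the returned size `sum(layers)`, is confirmed by this page. `sketch_recurrent_init_fn` is an invented name, and `duplication` is not modelled.

```python
from typing import Callable, List, NamedTuple, Tuple

import jax
import jax.numpy as jnp


class WeightRecurrent(NamedTuple):  # hypothetical stand-in
    input: jax.Array
    recurrent: jax.Array


def sketch_recurrent_init_fn(
    layers: List[int], mean: List[float], std: List[float]
) -> Callable[[jax.Array, int], Tuple[int, WeightRecurrent]]:
    """Hypothetical analogue of construct_recurrent_init_fn; `duplication` is not modelled."""
    n_total = sum(layers)
    # Start index of each layer inside the concatenated neuron vector.
    offsets = [sum(layers[:i]) for i in range(len(layers))]

    def init_fn(rng: jax.Array, input_shape: int) -> Tuple[int, WeightRecurrent]:
        keys = jax.random.split(rng, len(layers))
        # Assumed: input weights reach only the first layer's columns.
        block = mean[0] + std[0] * jax.random.normal(keys[0], (input_shape, layers[0]))
        w_in = jnp.zeros((input_shape, n_total)).at[:, : layers[0]].set(block)
        # Assumed: only layer i-1 -> layer i blocks are non-zero.
        w_rec = jnp.zeros((n_total, n_total))
        for i in range(1, len(layers)):
            src, dst = offsets[i - 1], offsets[i]
            block = mean[i] + std[i] * jax.random.normal(keys[i], (layers[i - 1], layers[i]))
            w_rec = w_rec.at[src : src + layers[i - 1], dst : dst + layers[i]].set(block)
        return n_total, WeightRecurrent(input=w_in, recurrent=w_rec)

    return init_fn


init_fn = sketch_recurrent_init_fn(layers=[5, 3], mean=[0.5, 0.2], std=[0.1, 0.1])
out_size, weights = init_fn(jax.random.PRNGKey(0), 4)
```

With `layers=[5, 3]`, neurons 0 to 4 form the hidden layer and neurons 5 to 7 form the output layer. Only the `recurrent[:5, 5:]` block carries weights.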