jaxsnn.base.dataset.linear

Functions

jaxsnn.base.dataset.linear.linear_dataset(rng: jax.Array, size: int, mirror: bool, bias_spike: Optional[float]) → Tuple[jax.Array, jax.Array]
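The signature above documents only types, not behavior. To illustrate the kind of data it returns, here is a minimal, hypothetical sketch with the same signature. It is not the jaxsnn implementation. It assumes that the dataset consists of 2-D points drawn uniformly from the unit square and labeled by which side of the diagonal x0 == x1 they fall on, that targets are one-hot, that `mirror=True` appends the complementary coordinates 1 - x, and that a non-None `bias_spike` appends a constant bias input column. The name `toy_linear_dataset` is made up for this example.

```python
from typing import Optional, Tuple

import jax
import jax.numpy as jnp


def toy_linear_dataset(
    rng: jax.Array, size: int, mirror: bool, bias_spike: Optional[float]
) -> Tuple[jax.Array, jax.Array]:
    """Hypothetical sketch of a linearly separable two-class dataset.

    Assumptions (not taken from jaxsnn): uniform 2-D inputs, diagonal
    decision boundary, one-hot targets, `mirror` appends 1 - x, and
    `bias_spike` appends a constant column.
    """
    # Sample `size` points uniformly from the unit square [0, 1)^2.
    points = jax.random.uniform(rng, (size, 2))
    # Assumed class rule: class 1 if the point lies above the diagonal.
    which_class = (points[:, 0] < points[:, 1]).astype(jnp.int32)
    # Encode the two classes as one-hot target vectors.
    targets = jax.nn.one_hot(which_class, 2)

    inputs = points
    if mirror:
        # Assumed meaning: append the complementary coordinates 1 - x.
        inputs = jnp.hstack((inputs, 1.0 - points))
    if bias_spike is not None:
        # Assumed meaning: append a constant bias input with this value.
        inputs = jnp.hstack((inputs, jnp.full((size, 1), bias_spike)))
    return inputs, targets


# Usage: 8 samples with mirrored inputs and a bias column.
inputs, targets = toy_linear_dataset(jax.random.PRNGKey(0), 8, True, 0.0)
```

Under these assumptions, `inputs` has shape (size, 2), grows to (size, 4) with `mirror=True`, and gains one more column when `bias_spike` is not None. `targets` always has shape (size, 2).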