jaxsnn.event.types.OptState
class jaxsnn.event.types.OptState(opt_state: optax.OptState, weights: List[Weight], rng: jax.random.KeyArray)
Bases: NamedTuple
Container for the optimizer/training state across steps.
This immutable state groups the underlying optimizer’s internal state, the current set of model weights, and the JAX PRNG key used for stochastic operations (e.g., sampling).
Parameters
opt_state – Optimizer-specific internal state carried across updates (e.g., the state returned by an optax optimizer's init and update methods).
weights – Ordered collection of learnable model parameters to be optimized.
rng – JAX PRNG key used for randomized computations; split it at each step and store the new key in the next state so that no key is reused.
__init__() Initialize self. See help(type(self)) for accurate signature.
Attributes
opt_state: optax.OptState
Alias for field number 0
rng: jax.random.KeyArray
Alias for field number 2
weights: List[Weight]
Alias for field number 1
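The "Alias for field number N" entries mean that each attribute is also reachable by tuple index, in declaration order (opt_state, weights, rng). The sketch below uses a stand-in class with the same field order, with an empty tuple as a placeholder for a real optax state, to show indexing, unpacking, and the non-mutating `_replace` update that NamedTuples provide.

```python
from typing import Any, List, NamedTuple

import jax
import jax.numpy as jnp


class OptState(NamedTuple):
    # Stand-in with the documented field order; opt_state is a placeholder here.
    opt_state: Any
    weights: List[jax.Array]
    rng: jax.Array


state = OptState(opt_state=(), weights=[jnp.ones(2)], rng=jax.random.PRNGKey(0))

# Field numbers follow declaration order: 0 -> opt_state, 1 -> weights, 2 -> rng.
assert state[0] is state.opt_state
assert state[1] is state.weights
assert state[2] is state.rng

# Tuple unpacking uses the same order.
opt_state, weights, rng = state

# NamedTuples are immutable; _replace returns a new instance with one field swapped.
new_rng, _ = jax.random.split(state.rng)
next_state = state._replace(rng=new_rng)
```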