jaxsnn.event.types.OptState

class jaxsnn.event.types.OptState(opt_state: optax.OptState, weights: List[Weight], rng: jax.random.KeyArray)

Bases: NamedTuple

Container for the optimizer/training state across steps.

This immutable state groups the underlying optimizer’s internal state, the current set of model weights, and the JAX PRNG key used for stochastic operations (e.g., sampling).

Parameters
  • opt_state – Optimizer-specific internal state carried across updates (e.g., the state returned by an optax GradientTransformation's init() and update() methods).

  • weights – Ordered collection of learnable model parameters to be optimized.

  • rng – JAX PRNG key used for randomized computations. JAX keys are not advanced implicitly, so split the key before use and store the resulting new key for the next step.

__init__()

Initialize self. See help(type(self)) for accurate signature.


Attributes


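opt_state: optax.OptState

Alias for field number 0. The optimizer's internal state, carried from one update to the next.

weights: List[Weight]

Alias for field number 1. The learnable model parameters being optimized.

rng: jax.random.KeyArray

Alias for field number 2. The PRNG key for stochastic operations, split and replaced between steps.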