jaxsnn.discrete.loss

Functions

jaxsnn.discrete.loss.nll_loss(predictions: jax.Array, targets: jax.Array) → jax.Array
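nll_loss carries no docstring. The sketch below is for illustration only and computes a standard mean negative log-likelihood. It assumes that predictions are log-probabilities of shape (batch, classes) and that targets are one-hot encoded with the same shape. The name nll_loss_sketch is hypothetical. The actual jaxsnn function may differ, for example in whether it applies log_softmax itself or in how it reduces over the batch.

```python
import jax
import jax.numpy as jnp


def nll_loss_sketch(log_probs: jax.Array, targets: jax.Array) -> jax.Array:
    """Hypothetical mean negative log-likelihood (not jaxsnn's own code).

    Assumes `log_probs` has shape (batch, classes) and already holds
    log-probabilities, and `targets` is one-hot with the same shape.
    """
    # Multiplying by the one-hot targets keeps only the correct-class entry
    per_example = jnp.sum(targets * log_probs, axis=-1)
    # Negate and average over the batch to get a scalar loss
    return -jnp.mean(per_example)


# Usage: convert raw scores to log-probabilities, then score against labels
scores = jnp.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
log_probs = jax.nn.log_softmax(scores, axis=-1)
targets = jax.nn.one_hot(jnp.array([0, 2]), 3)
loss = nll_loss_sketch(log_probs, targets)
```

Because the loss is built from jax.numpy operations, it can be passed directly to jax.grad when training.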
jaxsnn.discrete.loss.one_hot(x, k, dtype=<class 'jax.numpy.float32'>)

Create a one-hot encoding of x of size k.
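The following is a minimal sketch of how such an encoding can be built with broadcasting. It assumes x is a 1-D array of integer class labels and that the result has shape (len(x), k). The name one_hot_sketch is hypothetical, and jaxsnn's actual body may differ.

```python
import jax.numpy as jnp


def one_hot_sketch(x, k, dtype=jnp.float32):
    """Hypothetical one-hot encoding of integer labels `x` into `k` classes."""
    x = jnp.asarray(x)
    # Compare each label (as a column) against [0, 1, ..., k-1] (as a row);
    # broadcasting yields a boolean (len(x), k) matrix with one True per row
    return (x[:, None] == jnp.arange(k)).astype(dtype)


labels = jnp.array([0, 2, 1])
encoded = one_hot_sketch(labels, 3)
# encoded is [[1, 0, 0], [0, 0, 1], [0, 1, 0]] as float32
```

With this comparison approach, a label outside the range 0 to k-1 simply produces an all-zero row rather than an error.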