jaxsnn.discrete.loss

Functions

jaxsnn.discrete.loss.acc_and_loss(snn_apply, weights, batch, decoder)
jaxsnn.discrete.loss.nll_loss(snn_apply, weights, batch, decoder, expected_spikes=0.5, rho=0.0001) → Tuple[float, jax.Array]
jaxsnn.discrete.loss.one_hot(x, k, dtype=jax.numpy.float32)

Create a one-hot encoding of x of size k.
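The following is a minimal sketch of a one-hot encoding written with `jax.numpy`. It shows the general technique only and may differ from how `jaxsnn.discrete.loss.one_hot` is actually implemented. Each entry of `x` is compared against the class indices `0..k-1` by broadcasting, and the resulting boolean array is cast to `dtype`. The names `labels` and `encoded` are illustrative and are not part of jaxsnn.

```python
import jax.numpy as jnp


def one_hot(x, k, dtype=jnp.float32):
    """Sketch: one-hot encode integer array x into k classes.

    Not necessarily jaxsnn's implementation. Entries outside [0, k)
    produce all-zero rows.
    """
    x = jnp.asarray(x)
    # x[..., None] has shape (..., 1) and jnp.arange(k) has shape (k,),
    # so the comparison broadcasts to a boolean array of shape (..., k).
    return jnp.array(x[..., None] == jnp.arange(k), dtype=dtype)


# Illustrative usage: three integer labels, three classes.
labels = jnp.array([0, 2, 1])
encoded = one_hot(labels, 3)
# encoded: [[1, 0, 0], [0, 0, 1], [0, 1, 0]] as float32
```

JAX also ships a built-in, `jax.nn.one_hot(x, num_classes, dtype=...)`, which produces the same encoding for in-range labels.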