jaxsnn.discrete.loss
Functions
jaxsnn.discrete.loss.acc_and_loss(snn_apply, weights, batch, decoder)
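This reference gives only the signature of acc_and_loss. The name and the parameters it shares with nll_loss suggest that it evaluates the network on one batch and reports classification accuracy together with a loss. What follows is a minimal sketch under that assumption, not the library's implementation. Several details are hypothetical: the batch layout (inputs, integer labels), snn_apply returning (outputs, recording), the decoder producing (batch, n_classes) scores, and the names acc_and_loss_sketch, toy_snn_apply and max_over_time.

```python
import jax
import jax.numpy as jnp


def acc_and_loss_sketch(snn_apply, weights, batch, decoder):
    """Hypothetical stand-in for acc_and_loss: accuracy and mean NLL for one batch."""
    inputs, targets = batch  # ASSUMPTION: batch = (inputs, integer class labels)
    outputs, _ = snn_apply(weights, inputs)  # ASSUMPTION: returns (outputs, recording)
    logits = decoder(outputs)  # ASSUMPTION: decoder yields (batch, n_classes) scores
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Accuracy: fraction of samples whose highest-scoring class equals the label
    acc = jnp.mean(jnp.argmax(logits, axis=-1) == targets)
    # Loss: mean negative log-probability assigned to the correct class
    loss = -jnp.mean(jnp.take_along_axis(log_probs, targets[:, None], axis=-1))
    return acc, loss


# Toy, hypothetical stand-ins for a real SNN and decoder
def toy_snn_apply(weights, inputs):
    # inputs: (time, batch, n_in) -> traces: (time, batch, n_out)
    return inputs @ weights, None


def max_over_time(traces):
    # Decode by taking each output neuron's maximum trace value over time
    return jnp.max(traces, axis=0)


inputs = jnp.stack([jnp.eye(3), jnp.zeros((3, 3))])  # 2 time steps, batch of 3
labels = jnp.array([0, 1, 2])
acc, loss = acc_and_loss_sketch(toy_snn_apply, jnp.eye(3), (inputs, labels), max_over_time)
print(acc, loss)
```

Returning accuracy and loss from a single forward pass avoids evaluating the network twice during validation.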
jaxsnn.discrete.loss.nll_loss(snn_apply, weights, batch, decoder, expected_spikes=0.5, rho=0.0001) → Tuple[float, jax.Array]
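nll_loss returns a pair (Tuple[float, jax.Array]). Its expected_spikes and rho parameters suggest that a weighted spike-activity regularizer is added to a negative log-likelihood term. The sketch below assumes exactly that and is not the library's code. The following are all assumptions: the quadratic spike-count penalty, the batch layout, snn_apply returning (outputs, spike raster), the meaning of the returned array (here the decoded class scores), and the helper names nll_loss_sketch, toy_snn_apply and max_over_time.

```python
import jax
import jax.numpy as jnp


def nll_loss_sketch(snn_apply, weights, batch, decoder, expected_spikes=0.5, rho=1e-4):
    """Hypothetical stand-in for nll_loss: NLL plus a spike-count penalty."""
    inputs, targets = batch  # ASSUMPTION: batch = (inputs, integer class labels)
    outputs, spikes = snn_apply(weights, inputs)  # ASSUMPTION: (outputs, spike raster)
    logits = decoder(outputs)  # ASSUMPTION: (batch, n_classes) scores
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    targets_1h = jax.nn.one_hot(targets, logits.shape[-1])
    nll = -jnp.mean(jnp.sum(targets_1h * log_probs, axis=-1))
    # ASSUMPTION: quadratic penalty pulling each neuron's spike count
    # (summed over time) toward expected_spikes, weighted by rho
    counts = jnp.sum(spikes, axis=0)
    reg = rho * jnp.mean((counts - expected_spikes) ** 2)
    return nll + reg, logits


def toy_snn_apply(weights, inputs):
    # Toy, hypothetical network: linear traces plus a hard threshold as "spikes"
    traces = inputs @ weights  # (time, batch, n_out)
    return traces, (traces > 0.5).astype(jnp.float32)


def max_over_time(traces):
    return jnp.max(traces, axis=0)


k_w, k_x = jax.random.split(jax.random.PRNGKey(0))
weights = jax.random.normal(k_w, (4, 3))
batch = (jax.random.uniform(k_x, (10, 2, 4)), jnp.array([0, 2]))

# A (loss, aux) return value fits jax.value_and_grad(..., has_aux=True)
(loss, logits), grads = jax.value_and_grad(nll_loss_sketch, argnums=1, has_aux=True)(
    toy_snn_apply, weights, batch, max_over_time
)
print(loss, grads.shape)
```

Returning (loss, auxiliary array) is the form that jax.grad and jax.value_and_grad expect with has_aux=True: only the scalar is differentiated, and the auxiliary array is passed through. In this toy network the hard threshold has zero gradient, so only the NLL term contributes to grads.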
jaxsnn.discrete.loss.one_hot(x, k, dtype=jax.numpy.float32)
Create a one-hot encoding of x, where each encoded vector has length k.
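A one-hot encoding of size k maps each integer label in x to a length-k vector that holds a single 1 at the label's index. The sketch below uses the common JAX broadcasting idiom. It is only an illustration, because this reference does not show the library's implementation, and the name one_hot_sketch is hypothetical. jax.nn.one_hot is JAX's built-in equivalent.

```python
import jax
import jax.numpy as jnp


def one_hot_sketch(x, k, dtype=jnp.float32):
    # Compare each label with 0..k-1; broadcasting yields shape x.shape + (k,)
    return (x[..., None] == jnp.arange(k)).astype(dtype)


labels = jnp.array([0, 2, 1])
encoded = one_hot_sketch(labels, 3)
print(encoded)  # rows: [1, 0, 0], [0, 0, 1], [0, 1, 0]

# The same result via the built-in jax.nn.one_hot
print(jax.nn.one_hot(labels, 3))
```

With this idiom, a label outside 0..k-1 produces an all-zero row instead of raising an error.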