jaxsnn.discrete.loss
Functions
- jaxsnn.discrete.loss.acc_and_loss(snn_apply, weights, batch, decoder)
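The entry above gives only the signature. The function name suggests it reports an accuracy alongside a loss. The sketch below shows one common way to compute the accuracy part from decoded network outputs. It assumes the decoder yields a (batch, n_classes) array of scores and that labels are integer class indices. `accuracy` is a hypothetical helper, not part of jaxsnn, and it does not call `snn_apply`.

```python
import jax.numpy as jnp

def accuracy(decoded, labels):
    # Hypothetical helper (not a jaxsnn API).
    # decoded: (batch, n_classes) scores, assumed to come from a decoder.
    # labels:  (batch,) integer class indices.
    # The predicted class is the index of the highest score per sample.
    predictions = jnp.argmax(decoded, axis=-1)
    # Mean of the boolean hit vector is the fraction classified correctly.
    return jnp.mean(predictions == labels)
```

For scores `[[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]` and labels `[1, 0, 0]`, the predictions are `[1, 0, 1]`, so two of three samples are correct.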
- jaxsnn.discrete.loss.nll_loss(snn_apply, weights, batch, decoder, expected_spikes=0.5, rho=0.0001) → Tuple[float, jax.Array]
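"nll" conventionally stands for negative log-likelihood. The sketch below shows only that core term, computed from decoded scores with a log-softmax and one-hot targets. It assumes a (batch, n_classes) score array and integer labels. `nll_from_decoded` is a hypothetical helper, not the jaxsnn implementation. It does not model `expected_spikes` or `rho`, whose roles are not described on this page, and it returns a single scalar rather than the documented tuple.

```python
import jax
import jax.numpy as jnp

def nll_from_decoded(decoded, labels):
    # Hypothetical helper (not a jaxsnn API); omits expected_spikes and rho.
    # decoded: (batch, n_classes) scores, assumed to come from a decoder.
    # labels:  (batch,) integer class indices.
    log_probs = jax.nn.log_softmax(decoded, axis=-1)
    targets = jax.nn.one_hot(labels, decoded.shape[-1])
    # Pick out the log-probability of the true class, average over the batch.
    return -jnp.mean(jnp.sum(targets * log_probs, axis=-1))
```

With all-zero scores every class is equally likely, so the loss equals log(n_classes), about 1.386 for four classes.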
- jaxsnn.discrete.loss.one_hot(x, k, dtype=<class 'jax.numpy.float32'>)
  Create a one-hot encoding of x of size k.
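A one-hot encoding maps each integer label to a length-k vector that is 1 at the label's index and 0 elsewhere. The sketch below shows the usual broadcasting approach in JAX. It is an illustration under the assumption that x holds integer labels in [0, k), not the verified jaxsnn source. JAX also provides `jax.nn.one_hot(x, num_classes, dtype=...)` for the same purpose.

```python
import jax.numpy as jnp

def one_hot(x, k, dtype=jnp.float32):
    # Illustrative sketch; assumes x contains integer labels in [0, k).
    # Appending a trailing axis to x and comparing against 0..k-1
    # broadcasts to a boolean array of shape x.shape + (k,).
    hits = jnp.asarray(x)[..., None] == jnp.arange(k)
    # Cast the boolean matches to the requested numeric dtype.
    return hits.astype(dtype)
```

For example, `one_hot(jnp.array([0, 2, 1]), 3)` gives the rows `[1, 0, 0]`, `[0, 0, 1]` and `[0, 1, 0]`. Labels outside [0, k) produce all-zero rows in this sketch.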