jaxsnn.event.hardware.utils

Classes

Spike(time, idx)

WeightInput(input,)

WeightRecurrent(input, recurrent)
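The field names suggest these are simple named containers. Below is a minimal sketch that assumes they are `typing.NamedTuple`s of JAX arrays. NamedTuples are pytrees, so `jax.jit` and `jax.vmap` can traverse them. The stand-ins below are hypothetical, and the real definitions in `jaxsnn.event.types` may differ.

```python
from typing import NamedTuple

import jax.numpy as jnp


# Hypothetical stand-ins mirroring the field names listed above; the real
# classes live in jaxsnn.event.types and may carry extra behavior.
class Spike(NamedTuple):
    time: jnp.ndarray  # spike times
    idx: jnp.ndarray   # index of the neuron that emitted each spike


class WeightInput(NamedTuple):
    input: jnp.ndarray  # weights from the layer inputs


class WeightRecurrent(NamedTuple):
    input: jnp.ndarray      # weights from the layer inputs
    recurrent: jnp.ndarray  # recurrent (layer -> layer) weights


# Three spikes emitted by neurons 0, 2 and 1.
spikes = Spike(time=jnp.array([0.1, 0.3, 0.2]), idx=jnp.array([0, 2, 1]))
w = WeightInput(input=jnp.zeros((3, 4)))
```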

partial

partial(func, *args, **keywords) - new function with partial application of the given arguments and keywords.
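This docstring is the one carried by Python's `functools.partial`, which the module appears to import. The short illustration below uses a hypothetical helper, `clip_scaled`, to show partial application: one keyword is bound once, and the resulting callable takes the remaining arguments.

```python
from functools import partial

import jax.numpy as jnp


def clip_scaled(weight, scale, min_weight=-63.0, max_weight=63.0):
    # Hypothetical helper, used only to illustrate partial application.
    return jnp.clip(scale * weight, min_weight, max_weight)


# Fix `scale` once; the resulting callable only needs `weight`.
scale_by_10 = partial(clip_scaled, scale=10.0)

out = scale_by_10(jnp.array([1.0, 7.0, -9.0]))
# out == [10., 63., -63.]
```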

Functions

jaxsnn.event.hardware.utils.add_linear_noise(spike: jaxsnn.event.types.Spike) → jaxsnn.event.types.Spike
jaxsnn.event.hardware.utils.add_noise_batch(spikes: jaxsnn.event.types.Spike, rng: jax._src.random.PRNGKey, std: float = 1e-07, bias: float = 1e-07) → jaxsnn.event.types.Spike
jaxsnn.event.hardware.utils.cut_spikes(spikes: jaxsnn.event.types.Spike, count)
jaxsnn.event.hardware.utils.cut_spikes_batch(spikes: jaxsnn.event.types.Spike, count)
jaxsnn.event.hardware.utils.filter_spikes(spikes: jaxsnn.event.types.Spike, layer_start: int, layer_end: Optional[int] = None)

Return only the spikes of neurons at or after layer_start (optionally bounded above by layer_end).

All other spikes are encoded with time=np.inf and index=-1.
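The masking described above can be sketched with `jnp.where`. This is an assumption-laden sketch, not the library's code: `Spike` is a local stand-in, and `filter_spikes_sketch` is a hypothetical name. Treating `layer_end` as an exclusive bound is also an assumption. Masked spikes keep their slot, so array shapes stay static, which matters under `jax.jit`.

```python
from typing import NamedTuple, Optional

import jax.numpy as jnp


class Spike(NamedTuple):  # stand-in for jaxsnn.event.types.Spike
    time: jnp.ndarray
    idx: jnp.ndarray


def filter_spikes_sketch(
    spikes: Spike, layer_start: int, layer_end: Optional[int] = None
) -> Spike:
    # Keep spikes whose neuron index is >= layer_start ...
    keep = spikes.idx >= layer_start
    # ... and, if an end is given, < layer_end (assumed exclusive).
    if layer_end is not None:
        keep = keep & (spikes.idx < layer_end)
    # Masked-out spikes keep their slot but get time=inf and idx=-1,
    # so the array shapes do not depend on the data.
    return Spike(
        time=jnp.where(keep, spikes.time, jnp.inf),
        idx=jnp.where(keep, spikes.idx, -1),
    )


spikes = Spike(time=jnp.array([0.1, 0.2, 0.3, 0.4]), idx=jnp.array([0, 3, 5, 7]))
out = filter_spikes_sketch(spikes, layer_start=3, layer_end=6)
# out.idx  == [-1, 3, 5, -1]
# out.time == [inf, 0.2, 0.3, inf]
```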

jaxsnn.event.hardware.utils.filter_spikes_batch(spikes: jaxsnn.event.types.Spike, layer_start: int, layer_end: Optional[int] = None)

Return only the spikes of neurons at or after layer_start (optionally bounded above by layer_end).

All other spikes are encoded with time=np.inf and index=-1.
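One plausible way to obtain batched behavior is to map a single-sample filter over the leading batch axis with `jax.vmap`. This is an assumption, not necessarily how `filter_spikes_batch` is implemented. For brevity the sketch below omits `layer_end`, and the names `filter_spikes_sketch` and `filter_spikes_batch_sketch` are hypothetical.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Spike(NamedTuple):  # stand-in for jaxsnn.event.types.Spike
    time: jnp.ndarray
    idx: jnp.ndarray


def filter_spikes_sketch(spikes: Spike, layer_start: int) -> Spike:
    # Single-sample filter: keep neurons with index >= layer_start.
    keep = spikes.idx >= layer_start
    return Spike(
        time=jnp.where(keep, spikes.time, jnp.inf),
        idx=jnp.where(keep, spikes.idx, -1),
    )


# Map over the leading (batch) axis of both Spike fields; layer_start is
# shared by every sample (in_axes=None).
filter_spikes_batch_sketch = jax.vmap(filter_spikes_sketch, in_axes=(0, None))

batch = Spike(
    time=jnp.array([[0.1, 0.2], [0.3, 0.4]]),
    idx=jnp.array([[0, 2], [2, 1]]),
)
out = filter_spikes_batch_sketch(batch, 2)
# out.idx == [[-1, 2], [2, -1]]
```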

jaxsnn.event.hardware.utils.first_spike(spikes: jaxsnn.event.types.Spike, start: int, stop: int) → jax.Array
jaxsnn.event.hardware.utils.first_spike_batch(spikes: jaxsnn.event.types.Spike, start: int, stop: int) → jax.Array

Vectorized version of first_spike. Takes similar arguments as first_spike but with additional array axes over which first_spike is mapped.
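The docstring above is the one `jax.vmap` attaches to mapped functions, so first_spike_batch is presumably `first_spike` wrapped in `jax.vmap`. The sketch below rests on an assumption inferred only from the name and signature: `first_spike` returns the earliest spike time of each neuron with index in [start, stop), and `inf` for silent neurons. Both function names are hypothetical. `start` and `stop` are closed over as Python ints because they fix the output shape.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Spike(NamedTuple):  # stand-in for jaxsnn.event.types.Spike
    time: jnp.ndarray
    idx: jnp.ndarray


def first_spike_sketch(spikes: Spike, start: int, stop: int) -> jax.Array:
    # For each neuron index in [start, stop), take the earliest time at
    # which that neuron spiked; neurons that never spiked get inf.
    neurons = jnp.arange(start, stop)                # (n_neurons,)
    match = spikes.idx[None, :] == neurons[:, None]  # (n_neurons, n_spikes)
    times = jnp.where(match, spikes.time[None, :], jnp.inf)
    return jnp.min(times, axis=1)


def first_spike_batch_sketch(spikes: Spike, start: int, stop: int) -> jax.Array:
    # Bind start/stop as static Python ints, then map over the leading
    # batch axis of the Spike fields.
    return jax.vmap(lambda s: first_spike_sketch(s, start, stop))(spikes)


spikes = Spike(time=jnp.array([0.5, 0.2, 0.7, 0.1]), idx=jnp.array([1, 1, 2, 0]))
single = first_spike_sketch(spikes, 0, 3)  # approx. [0.1, 0.2, 0.7]

batch = Spike(
    time=jnp.stack([spikes.time, spikes.time + 1.0]),
    idx=jnp.stack([spikes.idx, spikes.idx]),
)
batched = first_spike_batch_sketch(batch, 1, 4)  # shape (2, 3); neuron 3 is silent
```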

jaxsnn.event.hardware.utils.linear_saturating(weight: jax.Array, scale: float, min_weight: float = -63.0, max_weight: float = 63.0, as_int: bool = True) → jax.Array

Scale all weights according to:

w <- clip(scale * w, min_weight, max_weight)

Parameters
  • weight – The weight array to be transformed.

  • scale – Constant factor by which the weights are multiplied before clipping.

  • min_weight – Lower bound; after scaling, smaller values are clipped to this value.

  • max_weight – Upper bound; after scaling, larger values are clipped to this value.

  • as_int – If True, round to the nearest integer and return an integer-typed array.

Returns

The transformed weight tensor.
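The formula above maps directly onto `jnp.clip`. The sketch below uses the hypothetical name `linear_saturating_sketch`. Two details are assumptions: the integer dtype (`int32`) and rounding after clipping. Because the bounds are integers, rounding before or after clipping gives the same result.

```python
import jax
import jax.numpy as jnp


def linear_saturating_sketch(
    weight: jax.Array,
    scale: float,
    min_weight: float = -63.0,
    max_weight: float = 63.0,
    as_int: bool = True,
) -> jax.Array:
    # w <- clip(scale * w, min_weight, max_weight)
    w = jnp.clip(scale * weight, min_weight, max_weight)
    if as_int:
        # Round to the nearest integer and cast to an integer dtype.
        w = jnp.round(w).astype(jnp.int32)
    return w


w = jnp.array([0.12, -0.5, 2.0, 0.034])
out = linear_saturating_sketch(w, scale=50.0)
# out == [6, -25, 63, 2]; 2.0 * 50 = 100 saturates at max_weight
```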

jaxsnn.event.hardware.utils.simulate_hw_weights(weights: List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], scale: float, as_int: bool = False) → List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]
jaxsnn.event.hardware.utils.simulate_madc(tau_mem: float, tau_syn: float, inputs: jaxsnn.event.types.Spike, weight: float, ts: jax.Array)
jaxsnn.event.hardware.utils.sort_batch(spikes: jaxsnn.event.types.Spike) → jaxsnn.event.types.Spike
jaxsnn.event.hardware.utils.spike_similarity_batch(spike1: jaxsnn.event.types.Spike, spike2: jaxsnn.event.types.Spike)
jaxsnn.event.hardware.utils.spike_to_grenade_input(spike: jaxsnn.event.types.Spike, input_neurons: int)

Convert the jaxsnn spike representation to the grenade spike representation.

jaxsnn represents spikes as a tuple of time and index (Spike(time, idx)). An instance of grenade.InputGenerator expects a Python list with shape [batch, neuron_idx, spike_time].

TODO move to hxtorch / grenade
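The conversion described above regroups flat per-sample (time, index) pairs into one list of spike times per input neuron. Below is a minimal sketch in plain NumPy under several assumptions. The name `spike_to_nested_list` is hypothetical, and the sketch does not import grenade. It also skips padding entries (idx=-1 or time=inf, as produced by filter_spikes) and sorts times within each neuron. It does not apply any time-unit conversion the real function may perform.

```python
from typing import List, NamedTuple

import numpy as np


class Spike(NamedTuple):  # stand-in for jaxsnn.event.types.Spike
    time: np.ndarray  # shape (batch, n_spikes)
    idx: np.ndarray   # shape (batch, n_spikes); -1 marks padding


def spike_to_nested_list(spike: Spike, input_neurons: int) -> List[List[List[float]]]:
    # Build result[b][n] = sorted spike times of input neuron n in sample b.
    time = np.asarray(spike.time)
    idx = np.asarray(spike.idx)
    result = []
    for b in range(time.shape[0]):
        per_neuron: List[List[float]] = [[] for _ in range(input_neurons)]
        for t, n in zip(time[b], idx[b]):
            # Skip padding entries (idx=-1 or time=inf).
            if 0 <= n < input_neurons and np.isfinite(t):
                per_neuron[int(n)].append(float(t))
        result.append([sorted(ts) for ts in per_neuron])
    return result


spikes = Spike(time=np.array([[0.3, 0.1, np.inf]]), idx=np.array([[1, 1, -1]]))
result = spike_to_nested_list(spikes, input_neurons=2)
# result == [[[], [0.1, 0.3]]]
```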