jaxsnn.event.hardware.utils
Classes
partial(func, *args, **keywords) - new function with partial application of the given arguments and keywords.
Functions
- jaxsnn.event.hardware.utils.add_linear_noise(spike: jaxsnn.event.types.Spike) → jaxsnn.event.types.Spike
- jaxsnn.event.hardware.utils.add_noise_batch(spikes: jaxsnn.event.types.Spike, rng: jax._src.random.PRNGKey, std: float = 1e-07, bias: float = 1e-07) → jaxsnn.event.types.Spike
- jaxsnn.event.hardware.utils.cut_spikes(spikes: jaxsnn.event.types.Spike, count)
- jaxsnn.event.hardware.utils.cut_spikes_batch(spikes: jaxsnn.event.types.Spike, count)
- jaxsnn.event.hardware.utils.filter_spikes(spikes: jaxsnn.event.types.Spike, layer_start: int, layer_end: Optional[int] = None)
  Return only the spikes of neurons at or after layer_start (and up to layer_end, if given).
  All other spikes are encoded with time=np.inf and index=-1.
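A minimal sketch of the masking behavior described above, assuming a spike batch is a named tuple of parallel time and idx arrays. The Spike class and filter_spikes_sketch below are hypothetical stand-ins, not the jaxsnn types; treating layer_end as exclusive is also an assumption. Whether the real function shifts surviving indices relative to layer_start is not stated, so the sketch leaves them unchanged.

```python
from typing import NamedTuple, Optional

import jax
import jax.numpy as jnp
import numpy as np


class Spike(NamedTuple):
    """Hypothetical stand-in for jaxsnn.event.types.Spike: parallel arrays."""
    time: jax.Array
    idx: jax.Array


def filter_spikes_sketch(
    spikes: Spike, layer_start: int, layer_end: Optional[int] = None
) -> Spike:
    # Keep spikes whose neuron index is >= layer_start
    # (and < layer_end if given; exclusivity is an assumption).
    keep = spikes.idx >= layer_start
    if layer_end is not None:
        keep = keep & (spikes.idx < layer_end)
    # Masked-out spikes are replaced (time=inf, index=-1) instead of dropped.
    return Spike(
        time=jnp.where(keep, spikes.time, np.inf),
        idx=jnp.where(keep, spikes.idx, -1),
    )


spikes = Spike(time=jnp.array([0.1, 0.2, 0.3, 0.4]), idx=jnp.array([0, 3, 5, 7]))
out = filter_spikes_sketch(spikes, layer_start=3, layer_end=7)
# out.idx holds [-1, 3, 5, -1]; out.time holds [inf, 0.2, 0.3, inf]
```

Replacing masked spikes rather than removing them keeps array shapes fixed, which jax.jit requires; boolean indexing would produce data-dependent shapes.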
- jaxsnn.event.hardware.utils.filter_spikes_batch(spikes: jaxsnn.event.types.Spike, layer_start: int, layer_end: Optional[int] = None)
  Return only the spikes of neurons at or after layer_start (and up to layer_end, if given).
  All other spikes are encoded with time=np.inf and index=-1.
- jaxsnn.event.hardware.utils.first_spike(spikes: jaxsnn.event.types.Spike, start: int, stop: int) → jax.Array
- jaxsnn.event.hardware.utils.first_spike_batch(spikes: jaxsnn.event.types.Spike, start: int, stop: int) → jax.Array
  Vectorized version of first_spike. Takes similar arguments as first_spike but with additional array axes over which first_spike is mapped.
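The first_spike_batch docstring uses the same wording as jax.vmap's, which suggests (but does not confirm) that it is first_spike mapped over a leading batch axis. The sketch below shows that pattern. The semantics given to first_spike_sketch (earliest spike time of each neuron in [start, stop), inf for silent neurons) are a hypothetical reading, and Spike is a stand-in named tuple, not the jaxsnn type.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Spike(NamedTuple):
    """Hypothetical stand-in for jaxsnn.event.types.Spike."""
    time: jax.Array
    idx: jax.Array


def first_spike_sketch(spikes: Spike, start: int, stop: int) -> jax.Array:
    # Hypothetical semantics: earliest spike time per neuron in [start, stop).
    neurons = jnp.arange(start, stop)
    # match[i, j] is True when spike j was emitted by neuron start + i.
    match = spikes.idx[None, :] == neurons[:, None]
    times = jnp.where(match, spikes.time[None, :], jnp.inf)
    return jnp.min(times, axis=1)  # inf for neurons that never spiked


def first_spike_batch_sketch(spikes: Spike, start: int, stop: int) -> jax.Array:
    # Map over axis 0 of every array in the Spike pytree; start/stop are
    # closed over so jnp.arange receives concrete Python ints.
    return jax.vmap(lambda s: first_spike_sketch(s, start, stop))(spikes)


batch = Spike(
    time=jnp.array([[0.3, 0.1, 0.2], [0.5, 0.4, 0.6]]),
    idx=jnp.array([[1, 1, 2], [2, 0, 2]]),
)
first_times = first_spike_batch_sketch(batch, 0, 3)  # shape (2, 3)
```

Closing over start and stop, rather than passing them through vmap with in_axes=None, keeps them static and avoids tracing the range bounds.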
- jaxsnn.event.hardware.utils.linear_saturating(weight: jax.Array, scale: float, min_weight: float = -63.0, max_weight: float = 63.0, as_int: bool = True) → jax.Array
  Scale all weights according to:
  w <- clip(scale * w, min_weight, max_weight)
  Parameters:
    weight – The weight array to transform.
    scale – Constant factor the weight array is multiplied by.
    min_weight – Lower bound; after scaling, smaller values are clipped to it.
    max_weight – Upper bound; after scaling, larger values are clipped to it.
    as_int – If True, round to the nearest integer and return an integer-typed array.
  Returns:
    The transformed weight array.
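A minimal sketch of the documented transform, for illustration only. The name linear_saturating_sketch is hypothetical; choosing int32 as the integer type is an assumption (the docstring only says "int type"), and ties are resolved by jnp.round's round-half-to-even rule.

```python
import jax
import jax.numpy as jnp


def linear_saturating_sketch(
    weight: jax.Array,
    scale: float,
    min_weight: float = -63.0,
    max_weight: float = 63.0,
    as_int: bool = True,
) -> jax.Array:
    # w <- clip(scale * w, min_weight, max_weight)
    w = jnp.clip(scale * weight, min_weight, max_weight)
    if as_int:
        # Round to the nearest integer (ties to even), then cast (int32 assumed).
        w = jnp.round(w).astype(jnp.int32)
    return w


w = jnp.array([-2.0, 0.26, 0.5, 3.0])
hw_w = linear_saturating_sketch(w, scale=50.0)
# scale * w = [-100, 13, 25, 150]; after clipping and rounding: [-63, 13, 25, 63]
```

With as_int=False the clipped floats are returned unchanged, which is useful when the saturated weights must stay differentiable.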
- jaxsnn.event.hardware.utils.simulate_hw_weights(weights: List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], scale: float, as_int: bool = False) → List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]
- jaxsnn.event.hardware.utils.simulate_madc(tau_mem: float, tau_syn: float, inputs: jaxsnn.event.types.Spike, weight: float, ts: jax.Array)
- jaxsnn.event.hardware.utils.sort_batch(spikes: jaxsnn.event.types.Spike) → jaxsnn.event.types.Spike
- jaxsnn.event.hardware.utils.spike_similarity_batch(spike1: jaxsnn.event.types.Spike, spike2: jaxsnn.event.types.Spike)
- jaxsnn.event.hardware.utils.spike_to_grenade_input(spike: jaxsnn.event.types.Spike, input_neurons: int)
  Convert the jaxsnn spike representation to the grenade spike representation.
  We represent spikes as a tuple of index and time. An instance of grenade.InputGenerator expects a Python list with shape [batch, neuron_idx, spike_time].
  TODO: move to hxtorch / grenade
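A sketch of the regrouping this docstring describes: from parallel (index, time) arrays to a nested list indexed as [batch][neuron_idx] holding spike times. It assumes a batched Spike whose time and idx arrays have shape (batch, n_spikes) and in which padded or filtered entries carry index -1 or time inf, as filter_spikes produces. The Spike stand-in and spike_to_nested_list_sketch are hypothetical; any unit conversion, sorting, or time offset the real function applies before handing data to grenade is not covered.

```python
from typing import List, NamedTuple

import numpy as np


class Spike(NamedTuple):
    """Hypothetical stand-in: arrays of shape (batch, n_spikes)."""
    time: np.ndarray
    idx: np.ndarray


def spike_to_nested_list_sketch(
    spike: Spike, input_neurons: int
) -> List[List[List[float]]]:
    """Regroup spikes into result[batch][neuron_idx] -> list of spike times."""
    result: List[List[List[float]]] = []
    for times, indices in zip(np.asarray(spike.time), np.asarray(spike.idx)):
        per_neuron: List[List[float]] = [[] for _ in range(input_neurons)]
        for t, i in zip(times, indices):
            # Skip padding/filtered entries: out-of-range index or non-finite time.
            if 0 <= i < input_neurons and np.isfinite(t):
                per_neuron[int(i)].append(float(t))
        result.append(per_neuron)
    return result


spike = Spike(
    time=np.array([[1e-6, 2e-6, np.inf], [3e-6, np.inf, np.inf]]),
    idx=np.array([[1, 0, -1], [1, -1, -1]]),
)
nested = spike_to_nested_list_sketch(spike, input_neurons=2)
# nested == [[[2e-6], [1e-6]], [[], [3e-6]]]
```

Neurons without spikes still get an empty list, so every batch entry has exactly input_neurons sublists.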