hxtorch.spiking.functional.lif

Leaky integrate-and-fire neurons

Classes

Unterjubel(*args, **kwargs)

Inject (German ‘unterjubeln’, to slip in) hardware observables into the forward pass while allowing correct gradient flow.
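A minimal sketch of the idea, assuming Unterjubel behaves like a straight-through substitution implemented as a `torch.autograd.Function`: the forward pass returns the hardware value, and the backward pass routes the gradient to the simulated tensor. The class name `UnterjubelSketch` is hypothetical, and the actual hxtorch implementation may differ in detail.

```python
import torch


class UnterjubelSketch(torch.autograd.Function):
    """Hypothetical stand-in: forward returns the hardware observable,
    backward passes the incoming gradient to the simulated tensor."""

    @staticmethod
    def forward(ctx, simulated: torch.Tensor, hardware: torch.Tensor) -> torch.Tensor:
        # Replace the simulated value by the measured one in the forward pass.
        return hardware.clone()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Gradient flows to the simulated tensor; the hardware data gets none.
        return grad_output, None


sim = torch.tensor([0.2, 0.8], requires_grad=True)
hw = torch.tensor([0.25, 0.75])
out = UnterjubelSketch.apply(sim * 2.0, hw)
out.sum().backward()
print(out)       # values taken from hw
print(sim.grad)  # gradient of sum(2 * sim) with respect to sim
```

The forward values come entirely from `hw`, yet `sim.grad` is the gradient of the simulated computation (`2 * sim`), which is what lets a model be trained on hardware observations.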

Functions

hxtorch.spiking.functional.lif.cuba_lif_integration(input: torch.Tensor, *, leak: torch.Tensor, reset: torch.Tensor, threshold: torch.Tensor, tau_syn: torch.Tensor, tau_mem: torch.Tensor, method: torch.Tensor, alpha: torch.Tensor, hw_data: Optional[torch.Tensor] = None, dt: float = 1e-06) → Tuple[torch.Tensor, ...]

Leaky integrate-and-fire neuron integration for simple spiking neurons with exponential synapses. Integrates according to:

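$$
\begin{aligned}
i^{t+1} &= i^t \left(1 - \frac{dt}{\tau_{syn}}\right) + x^t \\
v^{t+1} &= \frac{dt}{\tau_{mem}} \left(v_l - v^t + i^t\right) + v^t \\
z^{t+1} &= 1 \quad \text{if } v^{t+1} > \mathrm{threshold} \\
v^{t+1} &= \mathrm{reset} \quad \text{if } z^{t+1} = 1
\end{aligned}
$$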

Assumes i^0 = 0 and v^0 = v_l, where v_l is the leak voltage. Note: there is a synaptic delay of one time step dt between input and output.
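A scalar walk-through of the update equations above, using arbitrary example values (dt = 1 µs, τ_syn = 5 µs, τ_mem = 10 µs, v_l = 0) and no threshold. It shows that an input event at step t enters the current at t + 1 but reaches the membrane only at t + 2, because the membrane update uses i^t.

```python
# Sub-threshold walk-through of the CUBA update equations (example values).
dt, tau_syn, tau_mem = 1e-6, 5e-6, 10e-6
v_l = 0.0                      # leak voltage
i, v = 0.0, v_l                # i^0 = 0, v^0 = v_l
x = [1.0, 0.0, 0.0]            # single graded input spike at t = 0

trace = []
for t in range(len(x)):
    v_next = dt / tau_mem * (v_l - v + i) + v   # uses i^t, not i^{t+1}
    i_next = i * (1 - dt / tau_syn) + x[t]      # current decays, adds input
    i, v = i_next, v_next
    trace.append((t + 1, round(i, 6), round(v, 6)))

print(trace)  # [(1, 1.0, 0.0), (2, 0.8, 0.1), (3, 0.64, 0.17)]
```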

TODO: Issue 3992

Parameters
  • input – Tensor holding ‘graded_spikes’ in shape (batch, time, neurons).

  • leak – The leak voltage as torch.Tensor.

  • reset – The reset voltage as torch.Tensor.

  • threshold – The threshold voltage as torch.Tensor.

  • tau_syn – The synaptic time constant as torch.Tensor.

  • tau_mem – The membrane time constant as torch.Tensor.

  • method – The method used for the surrogate gradient, e.g., ‘superspike’.

  • alpha – The slope of the surrogate gradient in case of ‘superspike’.

  • hw_data – An optional tuple holding optional hardware observables in the order (spikes, membrane_cadc, membrane_madc).

  • dt – Integration step width.

Returns

Returns a tuple holding tensors with the membrane traces, spikes and synaptic currents. All tensors are of shape (batch, time, neurons).
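A minimal, forward-only sketch of the integration loop, under these assumptions: the name `cuba_lif_integration_sketch` is hypothetical, a hard threshold replaces the surrogate spike function (so no gradient flows through spikes), `hw_data` is not supported, and the outputs follow the order listed under Returns.

```python
import torch


def cuba_lif_integration_sketch(x, *, leak, reset, threshold, tau_syn, tau_mem, dt=1e-6):
    """Hypothetical forward-only CUBA LIF loop (not the hxtorch code).

    x: graded input spikes of shape (batch, time, neurons). A hard threshold
    replaces the surrogate-gradient spike function; no hardware data is used.
    """
    batch, steps, neurons = x.shape
    i = torch.zeros(batch, neurons, dtype=x.dtype)          # i^0 = 0
    v = torch.zeros(batch, neurons, dtype=x.dtype) + leak   # v^0 = v_l
    vs, zs, cs = [], [], []
    for t in range(steps):
        v = dt / tau_mem * (leak - v + i) + v     # membrane driven by i^t
        i = i * (1 - dt / tau_syn) + x[:, t]      # current decays, adds input
        z = (v > threshold).to(x.dtype)           # spike on threshold crossing
        v = (1 - z) * v + z * reset               # reset spiking neurons
        vs.append(v)
        zs.append(z)
        cs.append(i)
    # Stack along dim 1 so every output has shape (batch, time, neurons).
    return torch.stack(vs, dim=1), torch.stack(zs, dim=1), torch.stack(cs, dim=1)


x = torch.zeros(1, 20, 1)
x[0, 0, 0] = 5.0  # one graded input spike at t = 0
v, z, i = cuba_lif_integration_sketch(
    x, leak=0.0, reset=0.0, threshold=1.0, tau_syn=5e-6, tau_mem=10e-6
)
print(z[0, :, 0])  # a single spike, a few steps after the input
```

Keeping the per-step state as (batch, neurons) tensors and stacking once at the end avoids in-place writes into a preallocated output, which keeps the loop friendly to autograd if a surrogate spike function is swapped in.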

hxtorch.spiking.functional.lif.cuba_lif_step(z: torch.Tensor, v: torch.Tensor, i: torch.Tensor, input: torch.Tensor, spikes_hw: torch.Tensor, membrane_hw: torch.Tensor, *, leak: torch.Tensor, reset: torch.Tensor, threshold: torch.Tensor, tau_syn: torch.Tensor, tau_mem: torch.Tensor, method: torch.Tensor, alpha: torch.Tensor, dt: float = 1e-06) → Tuple[torch.Tensor, ...]

Integrate the membrane of the neurons one time step further according to the leaky integrate-and-fire dynamics.

Parameters
  • z – The spike tensor at time step t.

  • v – The membrane tensor at time step t.

  • i – The current tensor at time step t.

  • input – The input tensor at time step t (graded spikes).

  • spikes_hw – The hardware spikes corresponding to the current time step. In case this is None, no HW spikes will be injected.

  • membrane_hw – The hardware CADC traces corresponding to the current time step. In case this is None, no HW CADC values will be injected.

  • leak – The leak voltage as torch.Tensor.

  • reset – The reset voltage as torch.Tensor.

  • threshold – The threshold voltage as torch.Tensor.

  • tau_syn – The synaptic time constant as torch.Tensor.

  • tau_mem – The membrane time constant as torch.Tensor.

  • method – The method used for the surrogate gradient, e.g., ‘superspike’.

  • alpha – The slope of the surrogate gradient in case of ‘superspike’.

  • dt – Integration step width.

Returns

Returns a tuple (z, v, i) holding the tensors of time step t + 1.
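A sketch of a single step, assuming that hardware values, when given, replace the simulated ones in the forward pass only. The injection uses the `sim + (hw - sim).detach()` idiom as a swapped-in stand-in for Unterjubel; `inject` and `cuba_lif_step_sketch` are hypothetical names, and a hard threshold stands in for the surrogate spike function.

```python
import torch
from typing import Optional


def inject(sim: torch.Tensor, hw: Optional[torch.Tensor]) -> torch.Tensor:
    # Forward value from `hw` (if given), gradient with respect to `sim`.
    if hw is None:
        return sim
    return sim + (hw - sim).detach()


def cuba_lif_step_sketch(z, v, i, x, spikes_hw=None, membrane_hw=None, *,
                         leak, reset, threshold, tau_syn, tau_mem, dt=1e-6):
    """Hypothetical single CUBA LIF step returning (z, v, i) of step t + 1.

    The incoming spikes `z` of step t are not needed by this sketch.
    """
    # Membrane update driven by the current of step t.
    v = inject(dt / tau_mem * (leak - v + i) + v, membrane_hw)
    # Synaptic current decays and accumulates the graded input spikes.
    i = i * (1 - dt / tau_syn) + x
    # Hard threshold stands in for the surrogate-gradient spike function.
    z = inject((v > threshold).to(v.dtype), spikes_hw)
    # Reset spiking neurons; the spike is detached so the reset carries no gradient.
    z_d = z.detach()
    v = (1 - z_d) * v + z_d * reset
    return z, v, i


params = dict(leak=0.0, reset=0.0, threshold=1.0, tau_syn=5e-6, tau_mem=1e-5)
z, v, i = cuba_lif_step_sketch(
    torch.zeros(2), torch.tensor([0.95, 0.2]), torch.tensor([2.0, 0.0]),
    torch.tensor([0.0, 1.0]), **params,
)
# Neuron 0 crosses the threshold and is reset; neuron 1 only integrates.
```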

hxtorch.spiking.functional.lif.cuba_refractory_lif_integration(input: torch.Tensor, *, leak: torch.Tensor, reset: torch.Tensor, threshold: torch.Tensor, tau_syn: torch.Tensor, tau_mem: torch.Tensor, refractory_time: torch.Tensor, method: torch.Tensor, alpha: torch.Tensor, hw_data: Optional[torch.Tensor] = None, dt: float = 1e-06) → Tuple[torch.Tensor, ...]

Leaky integrate-and-fire neuron integration for simple spiking neurons with exponential synapses and a refractory period.

Integrates according to:

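$$
\begin{aligned}
i^{t+1} &= i^t \left(1 - \frac{dt}{\tau_{syn}}\right) + x^t \\
v^{t+1} &= \frac{dt}{\tau_{mem}} \left(v_l - v^t + i^{t+1}\right) + v^t \\
z^{t+1} &= 1 \quad \text{if } v^{t+1} > \mathrm{threshold} \\
v^{t+1} &= \mathrm{reset} \quad \text{if } z^{t+1} = 1 \text{ or } ref^{t+1} > 0 \\
ref^{t+1} &= \tau_{ref} \quad \text{if } z^{t+1} = 1 \\
ref^{t+1} &\leftarrow ref^{t+1} - 1
\end{aligned}
$$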

Assumes i^0 = v^0 = 0.
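The update order is easiest to follow on a scalar trace. The sketch below assumes norse-style semantics (the reference cited for refractory_update): a spike restarts a counter of n_ref = τ_ref / dt steps, the counter otherwise counts down to zero, and while it is positive spikes are suppressed and the membrane is clamped to the reset voltage. These semantics and all numeric values are illustrative assumptions.

```python
# Scalar walk-through of the refractory dynamics (plain Python, example values).
dt, tau_syn, tau_mem, tau_ref = 1e-6, 5e-6, 1e-5, 3e-6
v_l, threshold, reset = 0.0, 1.0, 0.0
n_ref = round(tau_ref / dt)          # refractory period in time steps (assumed)
i, v, ref = 0.0, 0.0, 0              # i^0 = v^0 = 0, not refractory
x = [20.0] + [0.0] * 7               # one strong input event at t = 0

spikes, refs, vs = [], [], []
for t in range(len(x)):
    i = i * (1 - dt / tau_syn) + x[t]            # i^{t+1}
    v = dt / tau_mem * (v_l - v + i) + v         # membrane uses i^{t+1}
    z = 1 if v > threshold else 0
    if ref > 0:                                  # still refractory:
        z, v = 0, reset                          # suppress spike, clamp membrane
    elif z:
        v = reset                                # ordinary reset after a spike
    ref = n_ref if z else max(ref - 1, 0)        # restart or count down
    spikes.append(z)
    refs.append(ref)
    vs.append(v)

print(spikes)  # [1, 0, 0, 0, 0, 1, 0, 0]
```

Without the refractory check, the membrane would cross the threshold again at steps 1 to 3 as well; here those crossings are suppressed and the second spike only appears once the counter has expired.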

Parameters
  • input – Tensor holding ‘graded_spikes’ in shape (batch, time, neurons).

  • leak – The leak voltage as torch.Tensor.

  • reset – The reset voltage as torch.Tensor.

  • threshold – The threshold voltage as torch.Tensor.

  • tau_syn – The synaptic time constant as torch.Tensor.

  • tau_mem – The membrane time constant as torch.Tensor.

  • refractory_time – The refractory time constant as torch.Tensor.

  • method – The method used for the surrogate gradient, e.g., ‘superspike’.

  • alpha – The slope of the surrogate gradient in case of ‘superspike’.

  • hw_data – An optional tuple holding optional hardware observables in the order (spikes, membrane_cadc, membrane_madc).

  • dt – Integration step width.

Returns

Returns a tuple holding tensors with the membrane traces, spikes and synaptic currents. All tensors are of shape (batch, time, neurons).

hxtorch.spiking.functional.lif.refractory_update(z: torch.Tensor, v: torch.Tensor, ref_state: torch._VariableFunctionsClass.tensor, spikes_hw: torch.Tensor, membrane_hw: torch.Tensor, *, reset: torch.Tensor, refractory_time: torch.Tensor, dt: float) → Tuple[torch.Tensor, ...]

Update the neuron membrane and spikes to account for the refractory period. This implementation is largely adapted from: https://github.com/norse/norse/blob/main/norse/torch/functional/lif_refrac.py

Parameters
  • z – The spike tensor at time step t.

  • v – The membrane tensor at time step t.

  • ref_state – The refractory state holding the number of time steps the neurons have to remain in the refractory period.

  • spikes_hw – The hardware spikes corresponding to the current time step. In case this is None, no HW spikes will be injected.

  • membrane_hw – The hardware CADC traces corresponding to the current time step. In case this is None, no HW CADC values will be injected.

  • reset – The reset voltage as torch.Tensor.

  • refractory_time – The refractory time constant as torch.Tensor.

  • dt – Integration step width.

Returns

Returns a tuple (z, v, ref_state) holding the tensors of time step t.
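A vectorized sketch of such an update, loosely following the norse approach cited above. Assumptions: `refractory_update_sketch` is a hypothetical name, the counter is reset to round(refractory_time / dt) steps on a spike, `spikes_hw` is omitted, and the membrane injection uses the `v + (hw - v).detach()` idiom instead of Unterjubel.

```python
import torch


def refractory_update_sketch(z, v, ref_state, membrane_hw=None, *, reset, refractory_time, dt):
    """Hypothetical norse-style refractory update (not the hxtorch code)."""
    # Neurons with a positive counter are still refractory.
    ref_mask = (ref_state > 0).to(v.dtype)
    # Clamp their membrane to the reset voltage and suppress their spikes.
    v = (1 - ref_mask) * v + ref_mask * reset
    z = (1 - ref_mask) * z
    # Optionally take the membrane value from hardware (forward value only).
    if membrane_hw is not None:
        v = v + (membrane_hw - v).detach()
    # Fresh spikes restart the counter; otherwise count down to zero.
    steps = torch.round(torch.as_tensor(refractory_time / dt, dtype=v.dtype))
    ref_state = (1 - z) * torch.clamp(ref_state - 1, min=0) + z * steps
    return z, v, ref_state


z, v, ref = refractory_update_sketch(
    torch.tensor([1.0, 1.0, 0.0]),   # neurons 0 and 1 spiked this step
    torch.tensor([0.0, 0.7, 0.4]),
    torch.tensor([0.0, 2.0, 1.0]),   # neurons 1 and 2 are still refractory
    reset=0.0, refractory_time=2e-6, dt=1e-6,
)
# Neuron 0 starts a new refractory period; neuron 1's spike is suppressed.
```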

hxtorch.spiking.functional.lif.spiking_threshold(input: torch.Tensor, method: str, alpha: float) → torch.Tensor

Select the threshold (spike) function and apply it to the input.

Parameters
  • input – Input tensor to the threshold function.

  • method – The string indicator of the threshold function. Currently supported: ‘superspike’.

  • alpha – Parameter controlling the slope of the surrogate derivative in case of ‘superspike’.

Returns

Returns the tensor of the threshold function.
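A sketch of a SuperSpike-style threshold, assuming the common formulation (as in norse): a Heaviside step in the forward pass and the fast-sigmoid derivative 1 / (α|x| + 1)² as surrogate in the backward pass. `SuperSpikeSketch` and `spiking_threshold_sketch` are hypothetical names, and the input is assumed to be the membrane relative to the threshold.

```python
import torch


class SuperSpikeSketch(torch.autograd.Function):
    """Heaviside forward, fast-sigmoid (SuperSpike) surrogate backward."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, alpha: float) -> torch.Tensor:
        ctx.save_for_backward(x)
        ctx.alpha = alpha
        # Emit a spike wherever the input is strictly positive.
        return (x > 0).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (x,) = ctx.saved_tensors
        # Surrogate derivative 1 / (alpha * |x| + 1)^2 replaces the Dirac delta.
        grad = grad_output / (ctx.alpha * x.abs() + 1.0) ** 2
        return grad, None  # no gradient for alpha


def spiking_threshold_sketch(x: torch.Tensor, method: str, alpha: float) -> torch.Tensor:
    # Hypothetical dispatcher mirroring the documented `method` switch.
    if method == "superspike":
        return SuperSpikeSketch.apply(x, alpha)
    raise ValueError(f"unsupported threshold method: {method}")


x = torch.tensor([-1.0, 0.0, 0.5], requires_grad=True)  # e.g. v - threshold
z = spiking_threshold_sketch(x, "superspike", alpha=2.0)
z.sum().backward()
print(z)       # spikes only where x > 0
print(x.grad)  # 1 / (2|x| + 1)^2 per element
```

A larger alpha makes the surrogate narrower around the threshold, so only neurons close to firing receive substantial gradient.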