jaxsnn.event.to_nir_data

Translate spikes from a jaxsnn spike representation to NIRGraphData.

Classes

EventData(idx, time, n_neurons, t_max)

Event-based data represented as a list of event indices and their corresponding timestamps.
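To illustrate the layout, the sketch below mirrors the EventData fields (idx, time, n_neurons, t_max) with a hypothetical stand-in dataclass, EventDataSketch. It is not the actual nir class. It shows how the two aligned lists can be read back as per-neuron spike trains. The helper spike_trains is also hypothetical.

```python
from dataclasses import dataclass
from typing import Dict, List


# Hypothetical stand-in mirroring the EventData fields listed above;
# this is NOT the nir class, only an illustration of the data layout.
@dataclass
class EventDataSketch:
    idx: List[int]     # index of the neuron that emitted each event
    time: List[float]  # timestamp of each event, aligned with idx
    n_neurons: int     # number of neurons in the population
    t_max: float       # end of the observation window

    def spike_trains(self) -> Dict[int, List[float]]:
        """Group event times by neuron index (illustrative helper)."""
        trains: Dict[int, List[float]] = {n: [] for n in range(self.n_neurons)}
        for i, t in zip(self.idx, self.time):
            trains[i].append(t)
        return trains


events = EventDataSketch(idx=[2, 0, 2], time=[0.1, 0.3, 0.7], n_neurons=3, t_max=1.0)
print(events.spike_trains())  # {0: [0.3], 1: [], 2: [0.1, 0.7]}
```

Neurons that never fire still get an empty list, so every neuron in range(n_neurons) is represented.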

EventPropSpike(time, idx, current)

NIRGraphData(nodes, …)

Dictionary of NIRNodeData where each entry represents a NIRNode of a corresponding NIRGraph with its observables.

NIRNodeData(observables, …)

Dictionary of EventData or TimeGriddedData where each entry represents an observable (e.g., spikes, voltages) of a corresponding NIRNode.
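The two-level nesting of NIRGraphData and NIRNodeData described above can be pictured with plain dictionaries. In this sketch the node names ("hidden", "output"), the type aliases and the helper count_events are hypothetical, and each observable is reduced to a list of (neuron index, time) pairs instead of a real EventData object.

```python
from typing import Dict, List, Tuple

# Hypothetical plain-dict stand-ins for the nesting described above:
#   graph level: node name       -> node data
#   node level:  observable name -> observable data
# Each observable is simplified to a list of (neuron index, time) events.
Events = List[Tuple[int, float]]
NodeDataSketch = Dict[str, Events]
GraphDataSketch = Dict[str, NodeDataSketch]

graph_data: GraphDataSketch = {
    "hidden": {"spikes": [(0, 0.2), (3, 0.4)]},
    "output": {"spikes": [(1, 0.6)]},
}


def count_events(graph: GraphDataSketch, observable: str = "spikes") -> Dict[str, int]:
    """Return the number of events of one observable for every node."""
    return {name: len(node.get(observable, [])) for name, node in graph.items()}


print(count_events(graph_data))  # {'hidden': 2, 'output': 1}
```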

ValuedEventData(idx, time, n_neurons, t_max, …)

Valued event-based data as a list of event indices, event times and event values.
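Valued events carry one extra number per event. As a hedged illustration of how they relate to the TimeGriddedData alternative mentioned above, the sketch below bins valued events onto a regular time grid by summing values per (time step, neuron). The function valued_events_to_grid and the step size dt are assumptions for illustration, not part of the documented API.

```python
import numpy as np


def valued_events_to_grid(idx, time, values, n_neurons, t_max, dt):
    """Hypothetical helper: sum event values into a dense (n_steps, n_neurons) grid."""
    idx = np.asarray(idx)
    time = np.asarray(time, dtype=float)
    values = np.asarray(values, dtype=float)
    # Keep only events inside the window [0, t_max); anything at or beyond
    # t_max would fall outside the last time bin.
    mask = (time >= 0) & (time < t_max)
    n_steps = int(np.ceil(t_max / dt))
    grid = np.zeros((n_steps, n_neurons))
    steps = np.floor(time[mask] / dt).astype(int)
    # np.add.at accumulates correctly when several events share a bin.
    np.add.at(grid, (steps, idx[mask]), values[mask])
    return grid


grid = valued_events_to_grid(
    idx=[0, 1, 0], time=[0.1, 0.5, 0.2], values=[1.0, 2.0, 0.5],
    n_neurons=2, t_max=1.0, dt=0.25,
)
print(grid.shape)  # (4, 2)
```

Summing is only one possible reduction; taking the last value per bin would be equally valid for state-like observables.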

Functions

jaxsnn.event.to_nir_data.to_nir_data(jaxsnn_dict: Dict[str, jaxsnn.event.types.EventPropSpike], jaxsnn_model, observables=('spikes')) → nir.data_ir.graph.NIRGraphData

Convert a dict of EventPropSpikes (jaxsnn representation) to NIRGraphData.

Parameters
  • jaxsnn_dict – Dictionary of EventPropSpike objects, where each entry holds the spikes of the corresponding node of jaxsnn_model. Empty (padding) events in jaxsnn are encoded by idx = -1 and time = 2 * t_max.

  • jaxsnn_model – A tuple (init_fn, apply_fn). The attributes apply_fn.nodes (the layer sizes) and apply_fn.t_max (the simulation time) currently have to be set manually on the apply function.

  • observables – Observables to be converted, by default ('spikes',).
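The parameter notes can be illustrated with placeholder objects. This sketch does not import jaxsnn. EventPropSpikeSketch is a hypothetical NamedTuple with the same field order as EventPropSpike(time, idx, current). init_fn and apply_fn are dummy functions, and the layer sizes and t_max are example values. The sketch shows the nodes and t_max attributes being attached to the apply function by hand, and padded events (idx = -1, time = 2 * t_max) being filtered out. With real jaxsnn objects, the prepared tuple would be passed as jaxsnn_model to to_nir_data.

```python
from typing import NamedTuple

import numpy as np


class EventPropSpikeSketch(NamedTuple):
    # Hypothetical stand-in with the field order of EventPropSpike(time, idx, current).
    time: np.ndarray
    idx: np.ndarray
    current: np.ndarray


def init_fn(*args, **kwargs):  # hypothetical placeholder for a jaxsnn init function
    raise NotImplementedError


def apply_fn(*args, **kwargs):  # hypothetical placeholder for a jaxsnn apply function
    raise NotImplementedError


# Attributes that currently have to be set by hand on the apply function.
apply_fn.nodes = [4, 3]  # layer sizes (example values)
apply_fn.t_max = 1.0     # simulation time (example value)
jaxsnn_model = (init_fn, apply_fn)


def strip_padding(spikes: EventPropSpikeSketch):
    """Drop empty events, which jaxsnn encodes with idx == -1."""
    mask = spikes.idx != -1
    return spikes.time[mask], spikes.idx[mask]


t_max = apply_fn.t_max
spikes = EventPropSpikeSketch(
    time=np.array([0.2, 0.5, 2 * t_max, 2 * t_max]),
    idx=np.array([1, 3, -1, -1]),
    current=np.zeros(4),
)
time, idx = strip_padding(spikes)
print(time, idx)  # [0.2 0.5] [1 3]
```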