jaxsnn

Modules

jaxsnn.base

jaxsnn.discrete

jaxsnn.event

Classes

ConversionConfig(t_max, n_spikes, …)

Configuration for the conversion from NIR to jaxsnn.

Functions

jaxsnn.from_nir(graph: nir.ir.graph.NIRGraph, config: jaxsnn.event.from_nir.ConversionConfig)

Convert an NIRGraph to the jax-snn representation, a tuple (init_fn, apply_fn).

Restrictions for the NIRGraph:

  • Only linear feed-forward SNNs are supported.

  • CubaLIF and Linear layers are supported.

  • Affine layers are currently supported only with bias == 0.

  • In terms of parameters, only homogeneous layers are supported.

  • The analytical solver is only supported for non-external inputs.
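The restrictions listed above can be checked before attempting a conversion. The following is a minimal sketch that is not part of jaxsnn: it uses a hypothetical, simplified graph description (a dict mapping node name to a (type, parameters) pair, plus a list of edges) instead of real nir.NIRGraph objects, assumes "Input" and "Output" node types at the ends of the chain, and interprets "homogeneous" as every CubaLIF parameter having the same value for all neurons. The helper names is_linear_chain and check_restrictions are illustrative only.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical, simplified stand-in for a NIR graph (not the nir API):
# node name -> (node type, {parameter name: per-neuron values}).
Node = Tuple[str, Dict[str, List[float]]]
SUPPORTED = {"Input", "Output", "Linear", "Affine", "CubaLIF"}


def is_linear_chain(nodes: Dict[str, Node], edges: Sequence[Tuple[str, str]]) -> bool:
    """True if the edges form one path that visits every node exactly once."""
    in_deg = {n: 0 for n in nodes}
    out_deg = {n: 0 for n in nodes}
    successor = {}
    for src, dst in edges:
        out_deg[src] += 1
        in_deg[dst] += 1
        successor[src] = dst
    # A chain has no branching (fan-out) and no merging (fan-in).
    if any(d > 1 for d in in_deg.values()) or any(d > 1 for d in out_deg.values()):
        return False
    starts = [n for n in nodes if in_deg[n] == 0]
    if len(starts) != 1:
        return False
    # Walk from the unique start node; every node must be reached.
    seen, current = {starts[0]}, starts[0]
    while current in successor:
        current = successor[current]
        if current in seen:
            return False
        seen.add(current)
    return len(seen) == len(nodes)


def check_restrictions(nodes: Dict[str, Node], edges: Sequence[Tuple[str, str]]) -> List[str]:
    """Return the violated restrictions (an empty list if none are found)."""
    problems = []
    if not is_linear_chain(nodes, edges):
        problems.append("graph is not a linear feed-forward chain")
    for name, (kind, params) in nodes.items():
        if kind not in SUPPORTED:
            problems.append(f"{name}: unsupported node type {kind}")
        if kind == "Affine" and any(b != 0 for b in params.get("bias", [])):
            problems.append(f"{name}: Affine bias must be zero")
        if kind == "CubaLIF":
            # Homogeneous (as assumed here): one shared value per parameter.
            for pname, values in params.items():
                if len(set(values)) > 1:
                    problems.append(f"{name}: parameter {pname} is not homogeneous")
    return problems


nodes = {
    "input": ("Input", {}),
    "fc": ("Affine", {"bias": [0.0, 0.0]}),
    "lif": ("CubaLIF", {"tau_mem": [0.01, 0.01], "v_threshold": [1.0, 1.0]}),
    "output": ("Output", {}),
}
edges = [("input", "fc"), ("fc", "lif"), ("lif", "output")]
print(check_restrictions(nodes, edges))  # []
```

A pre-check like this only mirrors the documented restrictions; the analytical-solver restriction on external inputs is not modeled, since it depends on ConversionConfig rather than on the graph alone.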

Example:

```python
import nir
import jaxsnn

nir_graph = nir.NIRGraph(...)
cfg = jaxsnn.ConversionConfig(...)
init, apply = jaxsnn.from_nir(nir_graph, cfg)
```

jaxsnn.from_nir_data(nir_graph_data: nir.data_ir.graph.NIRGraphData, jaxsnn_model, observables=('spikes',)) → Dict[str, jaxsnn.event.types.EventPropSpike]

Convert NIRGraphData to a dict of EventPropSpikes (the jax-snn representation).

Parameters
  • nir_graph_data – NIRGraphData to be converted.

  • jaxsnn_model – jaxsnn model tuple (init, apply).

  • observables – Observables to be converted, by default (‘spikes’,)
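observables takes a tuple of observable names. When passing a single name explicitly, the trailing comma matters: in plain Python, ("spikes") is just the string "spikes", and iterating over it yields characters rather than one name. The snippet below shows only this standard Python behavior; the commented call uses the documented signature with placeholder arguments.

```python
# One-element tuple: the trailing comma is what makes it a tuple.
observables = ("spikes",)
# Without the comma, the parentheses only group an expression: this is a str.
not_a_tuple = ("spikes")

print(type(observables).__name__, list(observables))  # tuple ['spikes']
print(type(not_a_tuple).__name__, list(not_a_tuple))  # str ['s', 'p', 'i', 'k', 'e', 's']

# Passing the default explicitly (placeholder arguments):
# jaxsnn.from_nir_data(nir_graph_data, (init, apply), observables=("spikes",))
```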

jaxsnn.get_logger(name: str)

jaxsnn.to_nir_data(jaxsnn_dict: Dict[str, jaxsnn.event.types.EventPropSpike], jaxsnn_model, observables=('spikes',)) → nir.data_ir.graph.NIRGraphData

Convert a dict of EventPropSpikes (jax-snn representation) to NIRGraphData.

Parameters
  • jaxsnn_dict – Dictionary of EventPropSpike objects, one entry per node of jaxsnn_model, each holding the spikes of that node. Empty events are encoded in jaxsnn by idx = -1 and time = 2 * t_max.

  • jaxsnn_model – A tuple (init_fn, apply_fn). The attributes apply.nodes (the layer sizes) and apply.t_max (the simulation time) of the apply function currently have to be set manually.

  • observables – Observables to be converted, by default (‘spikes’,)
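To illustrate the two conventions in the parameter descriptions above, the sketch below attaches nodes and t_max to an apply function by hand and filters empty events (idx = -1, time = 2 * t_max) out of a padded spike list. Spikes, real_events, and the placeholder apply_fn are hypothetical names; Spikes only mimics the time and idx fields that the descriptions refer to, and the layer sizes and t_max are made-up values. In practice apply_fn would be the function returned by jaxsnn.from_nir.

```python
from typing import List, NamedTuple, Tuple


class Spikes(NamedTuple):
    """Hypothetical stand-in for EventPropSpike, keeping only times and neuron indices."""
    time: List[float]
    idx: List[int]


def apply_fn(*args, **kwargs):
    """Placeholder for the apply function returned by jaxsnn.from_nir."""
    raise NotImplementedError


# to_nir_data reads these attributes; they currently have to be set by hand.
apply_fn.nodes = [3, 2]  # layer sizes (illustrative values)
apply_fn.t_max = 0.05    # simulation time (illustrative value)


def real_events(spikes: Spikes) -> List[Tuple[float, int]]:
    """Drop empty events, which jaxsnn marks with idx == -1 (and time == 2 * t_max)."""
    return [(t, i) for t, i in zip(spikes.time, spikes.idx) if i != -1]


t_max = apply_fn.t_max
# Two real spikes followed by two empty (padding) events.
layer = Spikes(time=[0.004, 0.011, 2 * t_max, 2 * t_max], idx=[1, 0, -1, -1])
print(real_events(layer))  # [(0.004, 1), (0.011, 0)]
```

Filtering on idx alone is enough here because every empty event carries idx = -1; its time of 2 * t_max also places it after the end of the simulation window.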