jaxsnn
Modules
Classes
ConversionConfig – Configuration for the conversion from NIR to jaxsnn.
Functions
- jaxsnn.from_nir(graph: nir.ir.graph.NIRGraph, config: jaxsnn.event.from_nir.ConversionConfig)
Convert a NIRGraph to the jax-snn representation (init_fn, apply_fn).
Restrictions for NIRGraph:
  - Only linear feed-forward SNNs are supported
  - CubaLIF and Linear layers are supported
  - Affine layers with bias==0 are currently supported
  - In terms of parameters, only homogeneous layers are supported
  - The analytical solver is only supported for non-external inputs
Example:
```python
import nir
import jaxsnn

nir_graph = nir.NIRGraph(...)
cfg = jaxsnn.ConversionConfig(...)
init, apply = jaxsnn.from_nir(nir_graph, cfg)
```
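Two of the restrictions listed for `from_nir` (zero bias and homogeneous parameters) can be checked before attempting a conversion. The following is a minimal sketch in plain Python. `is_homogeneous` and `has_zero_bias` are hypothetical helpers, not part of jaxsnn, and the names and values of the example parameters are made-up placeholders. With a real NIR graph the values would come from the corresponding layer parameters.

```python
import math
from typing import Sequence


def is_homogeneous(values: Sequence[float], rel_tol: float = 1e-9) -> bool:
    """Return True if all entries equal the first one (within rel_tol)."""
    if len(values) == 0:
        return True
    first = values[0]
    return all(math.isclose(v, first, rel_tol=rel_tol) for v in values)


def has_zero_bias(bias: Sequence[float]) -> bool:
    """Return True if every bias entry is exactly zero."""
    return all(b == 0 for b in bias)


# Placeholder per-neuron parameters for one layer (assumed, made-up values).
tau_mem = [10.0, 10.0, 10.0]
bias = [0.0, 0.0, 0.0]

layer_ok = is_homogeneous(tau_mem) and has_zero_bias(bias)
print(layer_ok)
```

A check like this only covers the parameter-related restrictions. The topology restriction (linear feed-forward only) would need a separate inspection of the graph's edges.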
- jaxsnn.from_nir_data(nir_graph_data: nir.data_ir.graph.NIRGraphData, jaxsnn_model, observables=('spikes')) → Dict[str, jaxsnn.event.types.EventPropSpike]
Convert NIRGraphData to a dict of EventPropSpikes (jax-snn representation).
- Parameters
nir_graph_data – NIRGraphData to be converted.
jaxsnn_model – jaxsnn model tuple (init, apply).
observables – Observables to be converted, by default ('spikes',).
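The signature shows `observables=('spikes')`, while the parameter description gives `('spikes',)`. In Python these differ: parentheses without a trailing comma produce a plain string, not a one-element tuple. The sketch below shows the difference together with a hypothetical `as_tuple` helper (not part of jaxsnn) that normalizes either form. How jaxsnn itself treats a bare string is not stated here, so passing an explicit tuple is the safer assumption.

```python
from typing import Tuple, Union


def as_tuple(observables: Union[str, Tuple[str, ...]]) -> Tuple[str, ...]:
    """Wrap a bare string in a tuple; pass tuples through unchanged."""
    if isinstance(observables, str):
        return (observables,)
    return tuple(observables)


without_comma = ('spikes')   # just the string 'spikes'
with_comma = ('spikes',)     # a one-element tuple

print(type(without_comma).__name__, type(with_comma).__name__)
print(as_tuple(without_comma) == as_tuple(with_comma))
```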
- jaxsnn.get_logger(name: str)
- jaxsnn.to_nir_data(jaxsnn_dict: Dict[str, jaxsnn.event.types.EventPropSpike], jaxsnn_model, observables=('spikes')) → nir.data_ir.graph.NIRGraphData
Convert a dict of EventPropSpikes (jax-snn representation) to NIRGraphData.
- Parameters
jaxsnn_dict – Dictionary of Spike objects where each entry represents the spikes for a corresponding node of the jaxsnn_model. Empty events in jaxsnn are encoded by idx = -1 and time = 2 * t_max.
jaxsnn_model – A tuple of (init_fn, apply_fn). The attributes apply.nodes (which holds the layer sizes) and apply.t_max (the simulation time) currently have to be set manually on the apply function.
observables – Observables to be converted, by default ('spikes',).
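The conventions described for `to_nir_data` can be illustrated without the library. The sketch below uses a hypothetical `Spike` named tuple as a stand-in for `EventPropSpike` (the real type may carry further fields), pads a spike list with empty events encoded as idx = -1 and time = 2 * t_max, and attaches `nodes` and `t_max` to a placeholder apply function as plain function attributes. The layer sizes, the `t_max` value, and the exact format expected for `apply.nodes` are assumptions.

```python
from typing import List, NamedTuple


class Spike(NamedTuple):
    """Hypothetical stand-in for jaxsnn's EventPropSpike."""
    time: float
    idx: int


def pad_with_empty(spikes: List[Spike], size: int, t_max: float) -> List[Spike]:
    """Pad to a fixed length using the empty-event encoding."""
    empty = Spike(time=2 * t_max, idx=-1)
    return spikes + [empty] * (size - len(spikes))


def drop_empty(spikes: List[Spike]) -> List[Spike]:
    """Keep only real events (empty events have idx == -1)."""
    return [s for s in spikes if s.idx != -1]


def apply(params, inputs):
    """Placeholder for the apply function returned by from_nir."""
    raise NotImplementedError


# Attributes that currently have to be set manually (values are made up).
apply.nodes = [3, 2]
apply.t_max = 4.0

real = [Spike(time=0.5, idx=0), Spike(time=1.25, idx=2)]
padded = pad_with_empty(real, size=4, t_max=apply.t_max)
print(padded[-1])
print(drop_empty(padded) == real)
```

Because empty events carry a time beyond the simulation window, they can also be filtered by comparing `time` against `t_max` instead of checking `idx`.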