jaxsnn
Modules
Classes
Configuration for the conversion from NIR to jaxsnn.
Functions
jaxsnn.from_nir(graph: nir.ir.graph.NIRGraph, config: jaxsnn.event.from_nir.ConversionConfig)
Convert a NIRGraph to the jaxsnn representation, a tuple (init_fn, apply_fn).
Restrictions on the NIRGraph:
- Only linear feed-forward SNNs are supported.
- CubaLIF and Linear layers are supported.
- Affine layers are currently supported only if bias == 0.
- Only layers with homogeneous parameters are supported.
- The analytical solver is only supported for non-external inputs.
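The restrictions above can be checked before calling `jaxsnn.from_nir`. The following is a hypothetical sketch that does not use the `nir` package: it assumes each node is given as a `(kind, params)` pair, where `params` maps parameter names to flat lists of floats, and the edges as `(source, target)` name pairs. Treating `"Input"` and `"Output"` as allowed boundary nodes is also an assumption, and `check_restrictions` is an illustrative name, not part of jaxsnn.

```python
from typing import Dict, List, Tuple

# Hypothetical simplified node: (kind, {param_name: [values per neuron]}).
# This is NOT the nir.NIRGraph API.
Node = Tuple[str, Dict[str, List[float]]]

# Input/Output are assumed to be allowed as graph boundary nodes.
SUPPORTED = {"Input", "Output", "Linear", "Affine", "CubaLIF"}


def check_restrictions(nodes: Dict[str, Node],
                       edges: List[Tuple[str, str]]) -> List[str]:
    """Return the violated restrictions (an empty list if none)."""
    problems = []
    # Linear feed-forward: each node has at most one successor and one
    # predecessor, and all nodes lie on a single chain.
    succ, pred = {}, {}
    for src, dst in edges:
        if src in succ or dst in pred:
            problems.append(f"branching at edge {src} -> {dst}")
        succ[src], pred[dst] = dst, src
    starts = [n for n in nodes if n not in pred]
    if len(starts) != 1:
        problems.append("graph is not a single chain")
    else:
        seen, cur = set(), starts[0]
        while cur is not None and cur not in seen:
            seen.add(cur)
            cur = succ.get(cur)
        if len(seen) != len(nodes):
            problems.append("graph is not a single chain")
    for name, (kind, params) in nodes.items():
        if kind not in SUPPORTED:
            problems.append(f"{name}: unsupported node type {kind}")
        elif kind == "Affine" and any(b != 0 for b in params.get("bias", [])):
            problems.append(f"{name}: Affine bias must be zero")
        elif kind == "CubaLIF":
            # Homogeneous layer: every neuron shares the same parameter value.
            for key, values in params.items():
                if len(set(values)) > 1:
                    problems.append(f"{name}: {key} is not homogeneous")
    return problems


nodes = {
    "in": ("Input", {}),
    "fc": ("Affine", {"bias": [0.0, 0.0]}),
    "lif": ("CubaLIF", {"tau_mem": [1e-3, 1e-3], "tau_syn": [5e-4, 6e-4]}),
    "out": ("Output", {}),
}
edges = [("in", "fc"), ("fc", "lif"), ("lif", "out")]
print(check_restrictions(nodes, edges))  # prints ['lif: tau_syn is not homogeneous']
```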
Example:
```python
import nir
import jaxsnn

nir_graph = nir.NIRGraph(...)
cfg = jaxsnn.ConversionConfig(...)
init, apply = jaxsnn.from_nir(nir_graph, cfg)
```
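The returned pair follows a functional `(init_fn, apply_fn)` convention resembling the layers of `jax.example_libraries.stax`: one function creates the parameters, the other is a pure function of parameters and inputs. The pure-Python toy below only illustrates that pattern. The layer `dense` and the signatures `init_fn(seed, n_in)` and `apply_fn(params, x)` are hypothetical and are not the signatures of the functions returned by `jaxsnn.from_nir`.

```python
import random
from typing import Callable, List, Tuple

Params = List[List[float]]  # weights[n_in][n_out]


def dense(n_out: int) -> Tuple[Callable, Callable]:
    """Toy dense layer in the (init_fn, apply_fn) style (hypothetical)."""

    def init_fn(seed: int, n_in: int) -> Tuple[int, Params]:
        # Parameters are created once, separately from the forward pass.
        rng = random.Random(seed)
        weights = [[rng.uniform(-1.0, 1.0) for _ in range(n_out)]
                   for _ in range(n_in)]
        return n_out, weights  # (output size, params)

    def apply_fn(params: Params, x: List[float]) -> List[float]:
        # Pure function: the output depends only on params and x.
        return [sum(xi * row[j] for xi, row in zip(x, params))
                for j in range(n_out)]

    return init_fn, apply_fn


init, apply = dense(n_out=2)
out_size, params = init(seed=0, n_in=3)
y = apply(params, [1.0, 0.0, 0.0])  # equals the first weight row
```

Keeping parameters outside the functions is what lets JAX transformations such as `jax.grad` differentiate `apply_fn` with respect to them.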
jaxsnn.from_nir_data(nir_graph_data: nir.data_ir.graph.NIRGraphData, jaxsnn_model) → Dict[str, jaxsnn.event.types.EventPropSpike]
Convert NIRGraphData to a dict of EventPropSpikes (the jaxsnn representation).
If the incoming data is time-gridded, linear noise is added to the spike times.
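In time-gridded data a spike's position is only known up to its time bin. The following sketch shows one way to turn such a raster into sorted `(idx, time)` events. It assumes that "linear noise" means an offset drawn uniformly within the bin of width `dt`, and that the gridded data is a list of time steps holding a 0/1 entry per neuron. The function `grid_to_events` and its parameters are hypothetical, not the jaxsnn implementation.

```python
import random
from typing import List, Tuple


def grid_to_events(raster: List[List[int]], dt: float,
                   seed: int = 0) -> Tuple[List[int], List[float]]:
    """Convert a binary raster[t][neuron] into time-sorted (idx, time) lists.

    Assumption: a spike in bin t gets a time drawn uniformly between
    t * dt and (t + 1) * dt.
    """
    rng = random.Random(seed)
    events = []
    for t, row in enumerate(raster):
        for neuron, spiked in enumerate(row):
            if spiked:
                events.append((t * dt + rng.uniform(0.0, dt), neuron))
    events.sort()  # sort events by spike time
    idx = [neuron for _, neuron in events]
    times = [time for time, _ in events]
    return idx, times


raster = [[1, 0],
          [0, 0],
          [1, 1]]
idx, times = grid_to_events(raster, dt=1e-3)
```

The order of neurons 0 and 1 in the last bin depends on the drawn noise, since events are sorted by their jittered times.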
jaxsnn.get_logger(name: str)
jaxsnn.to_nir_data(jaxsnn_dict: Dict[str, jaxsnn.event.types.EventPropSpike], jaxsnn_model) → nir.data_ir.graph.NIRGraphData
Parameters
jaxsnn_dict – Dictionary of EventPropSpike objects; each entry holds the spikes for the corresponding input node of jaxsnn_model. Empty events are encoded in jaxsnn as idx = -1 and time = np.inf.
jaxsnn_model – A tuple of (init_fn, apply_fn).
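One plausible use of the empty-event encoding (idx = -1, time = np.inf) is padding spike lists of different lengths to a common fixed size, which is convenient for statically shaped JAX arrays. The sketch below shows that padding and its removal with plain Python lists, using `math.inf` in place of `np.inf`. The helpers `pad_events` and `strip_padding` are hypothetical, not part of jaxsnn.

```python
import math
from typing import List, Tuple

EMPTY_IDX, EMPTY_TIME = -1, math.inf  # encoding of an empty event


def pad_events(idx: List[int], time: List[float],
               size: int) -> Tuple[List[int], List[float]]:
    """Pad an event list to `size` entries with empty events."""
    if len(idx) != len(time) or len(idx) > size:
        raise ValueError("idx/time length mismatch or too many events")
    n_pad = size - len(idx)
    return idx + [EMPTY_IDX] * n_pad, time + [EMPTY_TIME] * n_pad


def strip_padding(idx: List[int],
                  time: List[float]) -> Tuple[List[int], List[float]]:
    """Drop empty events (idx == -1 or time == inf)."""
    kept = [(i, t) for i, t in zip(idx, time)
            if i != EMPTY_IDX and t != EMPTY_TIME]
    return [i for i, _ in kept], [t for _, t in kept]


idx, time = pad_events([3, 1], [0.5e-3, 1.2e-3], size=4)
# idx -> [3, 1, -1, -1]; time -> [0.0005, 0.0012, inf, inf]
```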