jaxsnn.event.from_nir

Implements conversion of a NIR graph to a jaxsnn model.

Classes

ConversionConfig(t_max, n_spikes, …)

Configuration for the conversion from NIR to jaxsnn.

LIFParameters(tau_syn, tau_mem, v_th, …)

WeightInput(input,)

Functions

jaxsnn.event.from_nir.EventPropLIF(size: int, n_spikes: int, t_max: float, params: jaxsnn.base.params.LIFParameters, mean=0.5, std=2.0, wrap_only_step: bool = False, duplication: Optional[int] = None) -> Tuple[Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]], Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike], jaxsnn.event.types.EventPropSpike]]

Feed-forward layer of LIF neurons with EventProp gradient computation.

Args:

- size (int): Number of hidden neurons.
- n_spikes (int): Number of spikes which are simulated in this layer.
- t_max (float): Maximum simulation time.
- params (LIFParameters): Parameters of the LIF neurons.
- mean (float, optional): Mean of initial weights. Defaults to 0.5.
- std (float, optional): Standard deviation of initial weights. Defaults to 2.0.
- wrap_only_step (bool, optional): If True, the custom VJP is defined only for the step function; otherwise it is defined for the entire trajectory. Defaults to False.
- duplication (Optional[int], optional): Factor with which input weights are duplicated. Defaults to None.

Returns:

SingleInitApply: Pair of (init, apply) functions.
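The calling convention implied by the return type can be illustrated with a toy layer. This is a minimal sketch, not the jaxsnn implementation: `toy_layer`, the list-based weight and spike types, the `threshold` parameter and the threshold-forwarding `apply` are all hypothetical, and `random.Random` stands in for a JAX PRNG key. Only the shape of the pair follows the signature above: `init(rng, input_size)` returns the output size together with weights drawn from a normal distribution with the given `mean` and `std`, and `apply` maps an integer, the weights and the input events to output events. Treating that integer as a layer index is an assumption.

```python
import random
from typing import Callable, List, Tuple

# Hypothetical toy types (not jaxsnn types):
# weights[i][j] connects input i to output j; spikes are (time, neuron) events.
Weights = List[List[float]]
Spikes = List[Tuple[float, int]]


def toy_layer(
    size: int, mean: float = 0.5, std: float = 2.0, threshold: float = 0.5
) -> Tuple[Callable[[random.Random, int], Tuple[int, Weights]],
           Callable[[int, Weights, Spikes], Spikes]]:
    """Return an (init, apply) pair shaped like the layer constructors above."""

    def init(rng: random.Random, input_size: int) -> Tuple[int, Weights]:
        # Draw every input weight from N(mean, std); report the layer's output size.
        weights = [[rng.gauss(mean, std) for _ in range(size)]
                   for _ in range(input_size)]
        return size, weights

    def apply(layer_index: int, weights: Weights, spikes: Spikes) -> Spikes:
        # Placeholder dynamics (no membrane, no EventProp): forward each input
        # event to every output neuron whose weight from that input exceeds
        # `threshold`. layer_index is unused in this toy.
        out: Spikes = []
        for t, i in spikes:
            out.extend((t, j) for j, w in enumerate(weights[i]) if w > threshold)
        return sorted(out)

    return init, apply
```

In jaxsnn the real `apply` integrates LIF dynamics and returns an `EventPropSpike`; the toy only shows how the two functions fit together.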

jaxsnn.event.from_nir.HardwareLIF(size: int, n_spikes: int, t_max: float, params: jaxsnn.base.params.LIFParameters, mean: float, std: float, duplication: Optional[int] = None) -> Tuple[Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]], Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike, jaxsnn.event.types.Spike], jaxsnn.event.types.EventPropSpike]]

jaxsnn.event.from_nir.LIF(size: int, n_spikes: int, t_max: float, params: jaxsnn.base.params.LIFParameters, mean: float = 0.5, std: float = 2.0, duplication: Optional[int] = None) -> Tuple[Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]], Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike], jaxsnn.event.types.EventPropSpike]]

A feed-forward layer of LIF neurons.

Args:

- size (int): Number of hidden neurons.
- n_spikes (int): Number of spikes which are simulated in this layer.
- t_max (float): Maximum simulation time.
- params (LIFParameters): Parameters of the LIF neurons.
- mean (float, optional): Mean of initial weights. Defaults to 0.5.
- std (float, optional): Standard deviation of initial weights. Defaults to 2.0.

Returns:

SingleInitApply: Pair of (init, apply) functions.
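The dynamics of a current-based LIF layer can be sketched with a fixed-step Euler simulation. This sketch assumes the common CUBA LIF equations with zero leak and reset potentials, tau_mem * dv/dt = I - v and tau_syn * dI/dt = -I, where each input event adds its weight to I. jaxsnn's event-based layers work on spike events rather than a fixed time grid, so `simulate_lif` and its `dt` parameter are hypothetical and illustrative only; the sketch just reuses the roles of `t_max` (simulation horizon) and `n_spikes` (maximum number of spikes recorded in the layer).

```python
from typing import List, Tuple

Spikes = List[Tuple[float, int]]  # (time, neuron) events


def simulate_lif(
    weights: List[List[float]],
    input_spikes: Spikes,
    tau_syn: float,
    tau_mem: float,
    v_th: float,
    t_max: float,
    n_spikes: int,
    dt: float = 1e-4,
) -> Spikes:
    """Euler-integrate a layer of CUBA LIF neurons (v_leak = v_reset = 0)."""
    n_out = len(weights[0])
    v = [0.0] * n_out
    i_syn = [0.0] * n_out
    events = sorted(input_spikes)
    k = 0
    out: Spikes = []
    for step in range(int(round(t_max / dt))):
        t = step * dt
        # Inject every input event that has arrived by time t.
        while k < len(events) and events[k][0] <= t:
            src = events[k][1]
            for j in range(n_out):
                i_syn[j] += weights[src][j]
            k += 1
        for j in range(n_out):
            # Explicit Euler step: the membrane integrates the synaptic current,
            # which itself decays exponentially.
            v[j] += dt / tau_mem * (i_syn[j] - v[j])
            i_syn[j] -= dt / tau_syn * i_syn[j]
            if v[j] >= v_th:
                out.append((t + dt, j))
                v[j] = 0.0  # reset to v_reset = 0
                if len(out) == n_spikes:
                    return out  # spike budget of the layer exhausted
    return out
```

With tau_syn = 5 ms, tau_mem = 10 ms, v_th = 1 and a single input weight of 10, the analytic membrane trace is v(t) = 10 (exp(-t/tau_mem) - exp(-t/tau_syn)), which first crosses threshold at roughly 1.2 ms; the Euler sketch lands close to that.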

jaxsnn.event.from_nir.bias_is_zero(obj)

jaxsnn.event.from_nir.convert_cuba_lif(graph, node_key, config: jaxsnn.event.from_nir.ConversionConfig)

Convert a nir.CubaLIF node to its jaxsnn representation (init_fn, apply_fn). The jaxsnn representation is a HardwareLIF, EventPropLIF or LIF layer.

jaxsnn.event.from_nir.convert_to_number(param)

Convert a parameter to a number.
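Because only homogeneous layers are supported, converting a per-neuron parameter array to a single number can amount to checking that all entries agree and taking one of them. The sketch below assumes this behavior. `ToyLIFParameters` is a hypothetical stand-in for `jaxsnn.base.params.LIFParameters` with only three fields, the CubaLIF node is modeled as a plain dict keyed by the `nir.CubaLIF` attribute names `tau_syn`, `tau_mem` and `v_threshold`, and mapping `v_threshold` to `v_th` is likewise an assumption.

```python
from dataclasses import dataclass
from typing import Any, Mapping, Sequence, Union


@dataclass(frozen=True)
class ToyLIFParameters:
    # Hypothetical subset of jaxsnn.base.params.LIFParameters.
    tau_syn: float
    tau_mem: float
    v_th: float


def is_homogeneous(values: Sequence[float]) -> bool:
    # True when every entry equals the first one (vacuously true if empty).
    return all(x == values[0] for x in values)


def convert_to_number(param: Union[float, Sequence[float]]) -> float:
    # Scalars pass through; 1-D sequences must be non-empty and homogeneous.
    if isinstance(param, (int, float)):
        return float(param)
    values = list(param)
    if not values or not is_homogeneous(values):
        raise ValueError(f"expected a homogeneous parameter, got {values!r}")
    return float(values[0])


def cuba_lif_to_params(node: Mapping[str, Any]) -> ToyLIFParameters:
    # Collapse the per-neuron CubaLIF parameters into one scalar each.
    return ToyLIFParameters(
        tau_syn=convert_to_number(node["tau_syn"]),
        tau_mem=convert_to_number(node["tau_mem"]),
        v_th=convert_to_number(node["v_threshold"]),
    )
```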

jaxsnn.event.from_nir.dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, match_args=True, kw_only=False, slots=False, weakref_slot=False)

Add dunder methods based on the fields defined in the class.

Examines PEP 526 __annotations__ to determine fields.

If init is true, an __init__() method is added to the class. If repr is true, a __repr__() method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a __hash__() method is added. If frozen is true, fields may not be assigned to after instance creation. If match_args is true, the __match_args__ tuple is added. If kw_only is true, then by default all fields are keyword-only. If slots is true, a new class with a __slots__ attribute is returned.

jaxsnn.event.from_nir.from_nir(graph: nir.ir.graph.NIRGraph, config: jaxsnn.event.from_nir.ConversionConfig)

Convert a NIRGraph to its jaxsnn representation (init_fn, apply_fn).

Restrictions for NIRGraph:

- Only linear feed-forward SNNs are supported.
- CubaLIF and Linear layers are supported.
- Affine layers are currently supported only if their bias is zero.
- In terms of parameters, only homogeneous layers are supported.
- The analytical solver is only supported for non-external inputs.
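The structural restrictions (a single linear chain, supported node types, zero-bias Affine layers) can be checked with a small pass over the graph. This is a sketch under stated assumptions: nodes are modeled as dicts with a hypothetical `"type"` string (the real converter works on `nir` node objects), edges are `(pre, post)` key pairs as in `NIRGraph.edges`, and `check_linear_feed_forward` is not a jaxsnn function. It returns the node keys in chain order or raises `ValueError`.

```python
from typing import Dict, List, Tuple

SUPPORTED = {"Input", "Output", "Linear", "Affine", "CubaLIF"}


def check_linear_feed_forward(
    nodes: Dict[str, dict], edges: List[Tuple[str, str]]
) -> List[str]:
    """Return node keys in chain order, or raise if a restriction is violated."""
    successor: Dict[str, str] = {}
    in_degree = {key: 0 for key in nodes}
    for pre, post in edges:
        if pre in successor:
            raise ValueError(f"{pre!r} has more than one outgoing edge")
        successor[pre] = post
        in_degree[post] += 1
    for key, node in nodes.items():
        if node["type"] not in SUPPORTED:
            raise ValueError(f"unsupported node type {node['type']!r}")
        if node["type"] == "Affine" and any(b != 0 for b in node["bias"]):
            raise ValueError(f"Affine node {key!r} has a non-zero bias")
        if in_degree[key] > 1:
            raise ValueError(f"{key!r} has more than one incoming edge")
    starts = [key for key, deg in in_degree.items() if deg == 0]
    if len(starts) != 1:
        raise ValueError("expected exactly one node without inputs")
    # Walk the chain from its unique start; in/out degree <= 1 rules out revisits.
    order = [starts[0]]
    while order[-1] in successor:
        order.append(successor[order[-1]])
    if len(order) != len(nodes):
        raise ValueError("graph is not a single linear chain")
    return order
```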

Example:

```python
import jaxsnn
import nir

nir_graph = nir.NIRGraph(...)
cfg = jaxsnn.ConversionConfig(...)
init, apply = jaxsnn.from_nir(nir_graph, cfg)
```

jaxsnn.event.from_nir.get_edge(graph, pre_node='', post_node='')

Return a list of the edges that either start at pre_node or end at post_node.
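The filter described above can be sketched over a plain edge list. The sketch assumes edges are `(pre, post)` key pairs, as in `NIRGraph.edges`, and takes that list instead of the graph; `get_edge_sketch` is a hypothetical name. With the default empty strings, a side that is not given never matches, provided no node key is the empty string.

```python
from typing import List, Tuple

Edge = Tuple[str, str]


def get_edge_sketch(
    edges: List[Edge], pre_node: str = "", post_node: str = ""
) -> List[Edge]:
    # Keep edges leaving pre_node or entering post_node; "" matches nothing
    # as long as no node key is the empty string.
    return [(pre, post) for pre, post in edges if pre == pre_node or post == post_node]
```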

jaxsnn.event.from_nir.get_keys(graph, node_class)

Return an array of the keys of all nodes of class node_class.
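Selecting node keys by class can be expressed as an `isinstance` filter over the graph's node dictionary. In this sketch `Linear` and `CubaLIF` are empty hypothetical stand-ins for the `nir` node classes, the function takes `NIRGraph.nodes` (a key-to-node mapping) directly, and it returns a list where the documented function returns an array.

```python
from typing import Dict, List


class Linear:  # hypothetical stand-in for nir.Linear
    pass


class CubaLIF:  # hypothetical stand-in for nir.CubaLIF
    pass


def get_keys_sketch(nodes: Dict[str, object], node_class: type) -> List[str]:
    # Keys of all nodes that are instances of node_class, in insertion order.
    return [key for key, node in nodes.items() if isinstance(node, node_class)]
```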

jaxsnn.event.from_nir.get_prev_node(graph, node_key)

Return the node preceding the node with key node_key.
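In a linear feed-forward graph the previous node is the source of the single edge ending at `node_key`. This sketch assumes that definition and returns the node object rather than its key; it takes the `nodes` mapping and the `(pre, post)` edge list explicitly, and `get_prev_node_sketch` is a hypothetical name.

```python
from typing import Dict, List, Tuple


def get_prev_node_sketch(
    nodes: Dict[str, object], edges: List[Tuple[str, str]], node_key: str
) -> object:
    # Sources of all edges that end at node_key.
    sources = [pre for pre, post in edges if post == node_key]
    if len(sources) != 1:
        raise ValueError(f"{node_key!r} has {len(sources)} incoming edges, expected 1")
    return nodes[sources[0]]
```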

jaxsnn.event.from_nir.is_homogeneous(arr)

jaxsnn.event.from_nir.serial(*layers: Tuple[Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]], Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike], jaxsnn.event.types.EventPropSpike]]) -> Tuple[Callable[[jax.Array, int], List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]], Callable[[List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], jaxsnn.event.types.EventPropSpike], List[jaxsnn.event.types.EventPropSpike]]]

Concatenate multiple layers, each given as a pair of init/apply functions, into a single init/apply pair.

Returns:

InitApply: Init/apply pair
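Composing layers this way amounts to threading the output size through the `init` functions and the spikes through the `apply` functions, collecting per-layer weights and per-layer outputs, which matches the list-valued return types in the signature above. This is a minimal sketch, not the jaxsnn code: `serial_sketch` is hypothetical, a shared `random.Random` replaces splitting a JAX PRNG key, and passing the layer index as the integer argument of each layer's `apply` is an assumption.

```python
import random
from typing import Any, Callable, List, Tuple

# Hypothetical signatures mirroring the layer pairs above.
InitFn = Callable[[random.Random, int], Tuple[int, Any]]
ApplyFn = Callable[[int, Any, Any], Any]


def serial_sketch(*layers: Tuple[InitFn, ApplyFn]):
    """Chain (init, apply) layer pairs into one (init, apply) pair."""
    init_fns = [init for init, _ in layers]
    apply_fns = [apply for _, apply in layers]

    def init(rng: random.Random, input_size: int) -> List[Any]:
        weights = []
        size = input_size
        for init_fn in init_fns:
            # Each layer's output size becomes the next layer's input size.
            size, layer_weights = init_fn(rng, size)
            weights.append(layer_weights)
        return weights

    def apply(weights: List[Any], spikes: Any) -> List[Any]:
        outputs = []
        for index, (apply_fn, layer_weights) in enumerate(zip(apply_fns, weights)):
            # Feed the previous layer's output spikes into the next layer.
            spikes = apply_fn(index, layer_weights, spikes)
            outputs.append(spikes)
        return outputs

    return init, apply
```

Returning every layer's output, not just the last one, keeps hidden-layer spikes available, for example for inspection or for losses that depend on intermediate activity.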