jaxsnn.event.from_nir
Implements the conversion of a NIR graph to a jaxsnn model.
Classes

ConversionConfig
    Configuration for the conversion from NIR to jaxsnn.
Functions
jaxsnn.event.from_nir.EventPropLIF(size: int, n_spikes: int, t_max: float, params: jaxsnn.base.params.LIFParameters, mean=0.5, std=2.0, wrap_only_step: bool = False, duplication: Optional[int] = None) → Tuple[Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]], Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike], jaxsnn.event.types.EventPropSpike]]

    Feed-forward layer of LIF neurons with EventProp gradient computation.
    Args:
        size (int): Number of hidden neurons.
        n_spikes (int): Number of spikes which are simulated in this layer.
        t_max (float): Maximum simulation time.
        params (LIFParameters): Parameters of the LIF neurons.
        mean (float, optional): Mean of initial weights. Defaults to 0.5.
        std (float, optional): Standard deviation of initial weights. Defaults to 2.0.
        wrap_only_step (bool, optional): Whether the custom VJP is defined only for the step function or for the entire trajectory. Defaults to False.
        duplication (Optional[int], optional): Factor with which input weights are duplicated. Defaults to None.

    Returns:
        SingleInitApply: Pair of init/apply functions.
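The returned pair follows the common stax-style init/apply convention: the init function consumes an RNG source and the input size and produces the layer's weights, while the apply function threads weights and input through the layer. A minimal sketch of that pattern (a hypothetical `toy_lif_layer` in plain Python, using the stdlib `random` module instead of jax so it stands alone; not the actual jaxsnn implementation):

```python
import random
from typing import List, Tuple


def toy_lif_layer(size: int, mean: float = 0.5, std: float = 2.0):
    """Hypothetical sketch of the init/apply pair convention."""

    def init_fn(rng: random.Random, input_size: int) -> Tuple[int, List[List[float]]]:
        # Draw an (input_size x size) weight matrix from N(mean, std).
        weights = [[rng.gauss(mean, std) for _ in range(size)]
                   for _ in range(input_size)]
        return size, weights

    def apply_fn(weights: List[List[float]], inputs: List[float]) -> List[float]:
        # Stand-in for the event-based LIF dynamics: weighted sums only.
        return [sum(row[j] * x for row, x in zip(weights, inputs))
                for j in range(size)]

    return init_fn, apply_fn
```

Usage mirrors the jaxsnn convention: `init, apply = toy_lif_layer(3)`, then `out_size, w = init(random.Random(0), 2)` and `y = apply(w, inputs)`.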
jaxsnn.event.from_nir.HardwareLIF(size: int, n_spikes: int, t_max: float, params: jaxsnn.base.params.LIFParameters, mean: float, std: float, duplication: Optional[int] = None) → Tuple[Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]], Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike, jaxsnn.event.types.Spike], jaxsnn.event.types.EventPropSpike]]
jaxsnn.event.from_nir.LIF(size: int, n_spikes: int, t_max: float, params: jaxsnn.base.params.LIFParameters, mean: float = 0.5, std: float = 2.0, duplication: Optional[int] = None) → Tuple[Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]], Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike], jaxsnn.event.types.EventPropSpike]]

    A feed-forward layer of LIF neurons.

    Args:
        size (int): Number of hidden neurons.
        n_spikes (int): Number of spikes which are simulated in this layer.
        t_max (float): Maximum simulation time.
        params (LIFParameters): Parameters of the LIF neurons.
        mean (float, optional): Mean of initial weights. Defaults to 0.5.
        std (float, optional): Standard deviation of initial weights. Defaults to 2.0.
        duplication (Optional[int], optional): Factor with which input weights are duplicated. Defaults to None.

    Returns:
        SingleInitApply: Pair of init/apply functions.
jaxsnn.event.from_nir.bias_is_zero(obj)
jaxsnn.event.from_nir.convert_cuba_lif(graph, node_key, config: jaxsnn.event.from_nir.ConversionConfig)

    Convert a nir.CubaLIF node to its jaxsnn representation (init_fn, apply_fn). The jaxsnn representation is either a HardwareLIF, EventPropLIF, or LIF layer.
jaxsnn.event.from_nir.convert_to_number(param)

    Convert a parameter to a number.
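NIR stores parameters as arrays, so such a helper typically collapses a scalar or homogeneous array to a plain Python number. A hedged sketch of what that could look like with NumPy (`convert_to_number_sketch` is an illustration, not the actual jaxsnn code):

```python
import numpy as np


def convert_to_number_sketch(param) -> float:
    # Hypothetical helper: collapse a scalar or homogeneous array-like
    # parameter to a plain Python number.
    arr = np.asarray(param)
    if arr.ndim == 0:
        return arr.item()
    # For a homogeneous array, every entry carries the same value,
    # so the first entry represents the whole layer.
    return arr.ravel()[0].item()
```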
jaxsnn.event.from_nir.dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, match_args=True, kw_only=False, slots=False, weakref_slot=False)

    Add dunder methods based on the fields defined in the class. Examines PEP 526 __annotations__ to determine fields.

    If init is true, an __init__() method is added to the class. If repr is true, a __repr__() method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a __hash__() method is added. If frozen is true, fields may not be assigned to after instance creation. If match_args is true, the __match_args__ tuple is added. If kw_only is true, then by default all fields are keyword-only. If slots is true, a new class with a __slots__ attribute is returned.
jaxsnn.event.from_nir.from_nir(graph: nir.ir.graph.NIRGraph, config: jaxsnn.event.from_nir.ConversionConfig)

    Convert an NIRGraph to the jaxsnn representation (init_fn, apply_fn).

    Restrictions for the NIRGraph:
    - Only linear feed-forward SNNs are supported
    - CubaLIF and Linear layers are supported
    - Affine layers with bias == 0 are currently supported
    - In terms of parameters, only homogeneous layers are supported
    - The analytical solver is only supported for non-external inputs
    Example:

    ```python
    nir_graph = nir.NIRGraph(...)
    cfg = jaxsnn.ConversionConfig(...)
    init, apply = jaxsnn.from_nir(nir_graph, cfg)
    ```
jaxsnn.event.from_nir.get_edge(graph, pre_node='', post_node='')

    Return the list of edges that either start at pre_node or end at post_node.
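In NIR, edges are (pre, post) pairs of node keys, so a filter like this amounts to matching on either end of each tuple. A minimal sketch over a plain list of edge tuples (an illustration of the idea, not the jaxsnn implementation, which operates on the graph object):

```python
from typing import List, Tuple


def get_edge_sketch(edges: List[Tuple[str, str]],
                    pre_node: str = '',
                    post_node: str = '') -> List[Tuple[str, str]]:
    # Keep edges whose source matches pre_node or whose target
    # matches post_node.
    return [(pre, post) for pre, post in edges
            if pre == pre_node or post == post_node]
```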
jaxsnn.event.from_nir.get_keys(graph, node_class)

    Return an array of the keys of nodes of the given node class.
jaxsnn.event.from_nir.get_prev_node(graph, node_key)

    Return the node preceding the given node.
jaxsnn.event.from_nir.is_homogeneous(arr)
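Given the homogeneous-parameter restriction noted under from_nir, a check like this presumably tests whether all entries of a parameter array are equal. A hedged guess at the semantics (not the actual jaxsnn code):

```python
import numpy as np


def is_homogeneous_sketch(arr) -> bool:
    # True when every entry of the (non-empty) array equals the first one.
    flat = np.ravel(np.asarray(arr))
    return bool(np.all(flat == flat[0]))
```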
jaxsnn.event.from_nir.serial(*layers: Tuple[Callable[[jax.Array, int], Tuple[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]], Callable[[int, Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent], jaxsnn.event.types.EventPropSpike], jaxsnn.event.types.EventPropSpike]]) → Tuple[Callable[[jax.Array, int], List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]]], Callable[[List[Union[jaxsnn.event.types.WeightInput, jaxsnn.event.types.WeightRecurrent]], jaxsnn.event.types.EventPropSpike], List[jaxsnn.event.types.EventPropSpike]]]

    Concatenate multiple layers of init/apply functions.

    Returns:
        InitApply: Init/apply pair.
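The composition follows the usual stax-style pattern: each layer's init consumes the previous layer's output size, and apply threads the signal through the layers in order, collecting one output per layer. A minimal sketch under those assumptions (`serial_sketch` in plain Python, not the jaxsnn code):

```python
from typing import Any, List


def serial_sketch(*layers):
    """Compose (init_fn, apply_fn) pairs into one init/apply pair."""
    init_fns, apply_fns = zip(*layers)

    def init_fn(rng, input_size: int) -> List[Any]:
        # Feed each layer's output size into the next layer's init.
        params = []
        size = input_size
        for init in init_fns:
            size, layer_params = init(rng, size)
            params.append(layer_params)
        return params

    def apply_fn(params: List[Any], inputs):
        # Thread the signal through the layers, keeping every
        # intermediate output (as the jaxsnn signature suggests).
        outputs = []
        signal = inputs
        for apply, layer_params in zip(apply_fns, params):
            signal = apply(layer_params, signal)
            outputs.append(signal)
        return outputs

    return init_fn, apply_fn
```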