hxtorch.spiking.modules.HXModuleWrapper

class hxtorch.spiking.modules.HXModuleWrapper(experiment: Experiment, modules: List[HXModule], func: Optional[Callable])

Bases: hxtorch.spiking.modules.hx_module.HXFunctionalModule

Class to wrap HXModules

__init__(experiment: Experiment, modules: List[HXModule], func: Optional[Callable]) → None

A module which wraps a number of HXModules, given in modules, for which a single PyTorch-differentiable function func is defined. For instance, this allows wrapping a Synapse and a Neuron to describe recurrence.

Parameters
  • experiment – The experiment to register this wrapper in.

  • modules – A list of modules to be represented by this wrapper.

  • func – The function describing the unified functionality of all modules assigned to this wrapper. As for HXModules, this needs to be a PyTorch-differentiable function and can be either an autograd.Function or a function composed of PyTorch operations. Its signature is expected as follows:

      1. All positional arguments of each function in modules, appended in the order given in modules.

      2. All keyword arguments of each function in modules. If a keyword occurs multiple times, it is suffixed with _i, where i is an integer incremented with each occurrence.

      3. A keyword argument hw_data if hardware data is expected. It is a tuple holding the data for each module for which data is expected, ordered as in modules.

    The function is expected to output a tensor or a tuple of tensors for each module in modules, which can be assigned to the output handle of the corresponding HXModule.
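To make the calling convention for func concrete, here is a minimal, stdlib-only sketch of how the arguments could be assembled from the wrapped modules. It is an illustration under assumptions, not the hxtorch implementation: ToyModule and assemble_call are hypothetical names, and the suffix numbering (every occurrence of a duplicated keyword gets _0, _1, … in module order, unique keywords keep their name) is an assumed reading of the rule above.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple


@dataclass
class ToyModule:
    """Hypothetical stand-in for an HXModule: only records its func's arguments."""
    name: str
    args: Tuple[Any, ...] = ()
    kwargs: Dict[str, Any] = field(default_factory=dict)


def assemble_call(modules: List[ToyModule],
                  hw_data: Optional[Tuple[Any, ...]] = None
                  ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    # 1. Positional arguments of all modules, appended in module order.
    args: Tuple[Any, ...] = ()
    for module in modules:
        args += tuple(module.args)

    # Count how often each keyword occurs across all modules.
    counts: Dict[str, int] = {}
    for module in modules:
        for key in module.kwargs:
            counts[key] = counts.get(key, 0) + 1

    # 2. Keyword arguments. Duplicated keywords get a suffix _i
    #    (assumed: counting from 0 in module order); unique ones are kept.
    kwargs: Dict[str, Any] = {}
    seen: Dict[str, int] = {}
    for module in modules:
        for key, value in module.kwargs.items():
            if counts[key] > 1:
                i = seen.get(key, 0)
                seen[key] = i + 1
                kwargs[f"{key}_{i}"] = value
            else:
                kwargs[key] = value

    # 3. Optional hw_data: one entry per module expecting hardware data.
    if hw_data is not None:
        kwargs["hw_data"] = hw_data
    return args, kwargs


# Usage: a synapse followed by a neuron, both taking a "scale" keyword.
synapse = ToyModule("synapse", args=("input", "weight"), kwargs={"scale": 1.0})
neuron = ToyModule("neuron", args=("params",), kwargs={"scale": 2.0, "dt": 1e-6})
args, kwargs = assemble_call([synapse, neuron], hw_data=(None,))
```

With these inputs, args is ("input", "weight", "params") and kwargs holds scale_0, scale_1, dt and hw_data, so a func written for this wrapper would accept exactly those parameters.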

Methods

__init__(experiment, modules, func)

A module which wraps a number of HXModules, given in modules, for which a single PyTorch-differentiable function func is defined.

contains(modules)

Checks whether all modules in modules are registered in the wrapper.

exec_forward(input, output, hw_map)

Execute the forward function of the wrapper.

extra_repr()

Add additional information to the module's string representation.

update(modules[, func])

Update the modules and the function in the wrapper.

update_args(modules)

Gathers the args and kwargs of all modules in modules and renames keyword arguments that occur multiple times.

Attributes

contains(modules: List[hxtorch.spiking.modules.hx_module.HXModule]) → bool

Checks whether all modules in modules are registered in the wrapper.

Parameters
  • modules – The modules for which to check if they are registered.

Returns
  A bool indicating whether modules is a subset of the modules registered in the wrapper.
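The subset semantics of contains can be sketched in a few lines. This is a hypothetical stand-in (ToyWrapper is not part of hxtorch) that assumes membership is decided by object identity; for torch.nn.Module this coincides with a plain `in` check, since Module does not override `__eq__`.

```python
from typing import Iterable, List


class ToyWrapper:
    """Hypothetical stand-in for HXModuleWrapper holding registered modules."""

    def __init__(self, modules: List[object]) -> None:
        self.modules = list(modules)

    def contains(self, modules: Iterable[object]) -> bool:
        # True if every given module is registered, i.e. `modules` is a
        # subset of self.modules (assumed: compared by identity).
        return all(any(m is r for r in self.modules) for m in modules)


# Usage: a wrapper around a synapse and a neuron.
syn, neu, other = object(), object(), object()
wrapper = ToyWrapper([syn, neu])
```

Note that under this reading an empty list is trivially contained, and the order of the queried modules does not matter.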

exec_forward(input: Tuple[hxtorch.spiking.handle.TensorHandle], output: Tuple[hxtorch.spiking.handle.TensorHandle], hw_map: Dict[_pygrenade_vx_network.PopulationOnNetwork, Tuple[torch.Tensor]]) → None

Execute the forward function of the wrapper. This method assigns to each output handle in output its corresponding PyTorch tensors and adds the wrapper's func to the PyTorch graph.

Parameters
  • input – A tuple of the input handles, where each handle corresponds to a certain module. The order is defined by modules. Note that a module can have multiple input handles.

  • output – A tuple of output handles, each corresponding to one module. The order is defined by modules.

  • hw_map – The hardware data map.
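The ordering contract between func's result and the output handles can be illustrated with a small sketch. ToyHandle and its value field are hypothetical placeholders, not the real TensorHandle API, and the sketch assumes func returns a tuple with exactly one entry per module, where each entry may itself be a tensor or a tuple of tensors (plain strings stand in for tensors here).

```python
from dataclasses import dataclass
from typing import Any, Optional, Sequence, Tuple


@dataclass
class ToyHandle:
    """Hypothetical output handle; `value` stands in for the handle's tensor(s)."""
    value: Optional[Any] = None


def assign_outputs(result: Tuple[Any, ...], output: Sequence[ToyHandle]) -> None:
    # Assumed: one entry per module, in the order given by `modules`.
    if len(result) != len(output):
        raise ValueError(
            f"func returned {len(result)} entries for {len(output)} handles")
    # Each entry (a tensor or a tuple of tensors) goes to the matching handle.
    for handle, value in zip(output, result):
        handle.value = value


# Usage: synapse yields one current, neuron yields (spikes, membrane).
handles = (ToyHandle(), ToyHandle())
assign_outputs(("current", ("spikes", "membrane")), handles)
```

Checking the length up front makes a mismatch between func and modules fail loudly instead of silently leaving handles unassigned.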

extra_args: Tuple[Any]
extra_kwargs: Dict[str, Any]
extra_repr() → str

Add additional information to the module's string representation.

training: bool
update(modules: List[hxtorch.spiking.modules.hx_module.HXModule], func: Optional[Callable] = None)

Update the modules and the function in the wrapper.

Parameters
  • modules – The new modules to assign to the wrapper.

  • func – The new function to represent the modules in the wrapper.

update_args(modules: List[hxtorch.spiking.modules.hx_module.HXModule])

Gathers the args and kwargs of all modules in modules and renames keyword arguments that occur multiple times.

Parameters
  • modules – The modules represented by the wrapper.