hxtorch.spiking.HXModuleWrapper
-
class hxtorch.spiking.HXModuleWrapper
(experiment: Experiment, modules: List[HXModule], func: Optional[Callable])
Bases: hxtorch.spiking.modules.hx_module.HXFunctionalModule
Class to wrap HXModules
-
__init__
(experiment: Experiment, modules: List[HXModule], func: Optional[Callable]) → None
A module which wraps a number of HXModules, given in modules, whose combined behavior is described by a single PyTorch-differentiable function func. For instance, this allows wrapping a Synapse and a Neuron to describe recurrence.
Parameters
experiment – The experiment to register this wrapper in.
modules – A list of modules to be represented by this wrapper.
func – The function describing the unified functionality of all modules assigned to this wrapper. As for HXModules, this needs to be a PyTorch-differentiable function and can be either an autograd.Function or a function composed of PyTorch operations. Its signature is expected to be:
1. All positional arguments of each function in modules, appended in the order given in modules.
2. All keyword arguments of each function in modules. If a keyword occurs multiple times, it is suffixed with _i, where i is an integer incremented with each occurrence.
3. A keyword argument hw_data if hardware data is expected, which is a tuple holding the data for each module for which data is expected. The order is defined by modules.
The function is expected to output a tensor or a tuple of tensors for each module in modules, which can be assigned to the output handle of the corresponding HXModule.
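To make the signature convention more concrete, here is a minimal, hypothetical sketch of a func for a wrapper of two modules: a linear synapse followed by a leaky integrator. The argument names, the leak keyword and both computations are assumptions for illustration only; they are not hxtorch's actual Synapse or Neuron functions. Because the sketch uses only plain PyTorch operations, autograd provides the backward pass.

```python
import torch


def wrapper_func(synapse_input, weight, neuron_input, *, leak=0.9, hw_data=None):
    """Hypothetical combined function for a [synapse, neuron] wrapper.

    Positional args: the synapse's (synapse_input, weight), then the
    neuron's (neuron_input,), in the order of `modules`.
    Keyword args: the neuron's `leak`, plus `hw_data` (unused here, since
    this toy sketch runs purely in software).
    Returns one output per module, in the order of `modules`.
    """
    # Synapse: project input of shape (time, batch, in) to (time, batch, out).
    current = synapse_input @ weight.t()
    # Neuron: both modules are computed jointly, so the synaptic current is
    # used directly and `neuron_input` (the synapse's output) is ignored.
    v = torch.zeros_like(current[0])
    trace = []
    for step in range(current.shape[0]):
        v = leak * v + current[step]  # leaky integration, no spiking or reset
        trace.append(v)
    membrane = torch.stack(trace)
    return current, membrane
```

The returned tuple holds one entry per wrapped module, matching the requirement that each output can be assigned to the output handle of the corresponding HXModule.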
Methods
__init__
(experiment, modules, func) A module which wraps a number of HXModules, given in modules, whose combined behavior is described by a single PyTorch-differentiable function func.
contains
(modules) Checks whether all modules in modules are registered in the wrapper.
exec_forward
(input, output, hw_map) Execute the forward function of the wrapper.
extra_repr
() Add additional information to the module's string representation.
update
(modules[, func]) Update the modules and the function in the wrapper.
update_args
(modules) Gathers the args and kwargs of all modules in modules and renames keyword arguments that occur multiple times.
Attributes
-
contains
(modules: List[hxtorch.spiking.modules.hx_module.HXModule]) → bool
Checks whether all modules in modules are registered in the wrapper.
Parameters
modules – The modules for which to check whether they are registered.
Returns
A bool indicating whether modules is a subset of the wrapper's modules.
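The subset semantics of contains can be pictured with the minimal sketch below. WrapperSketch is a hypothetical stand-in that only keeps a list of registered modules; it is not the actual class and ignores everything else the wrapper does.

```python
from typing import List, Sequence


class WrapperSketch:
    """Hypothetical stand-in holding only the registered modules."""

    def __init__(self, modules: Sequence[object]):
        self.modules = list(modules)

    def contains(self, modules: List[object]) -> bool:
        # True if every given module is registered, i.e. `modules` is a subset
        return all(module in self.modules for module in modules)
```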
-
exec_forward
(input: Tuple[hxtorch.spiking.handle.TensorHandle], output: Tuple[hxtorch.spiking.handle.TensorHandle], hw_map: Dict[_pygrenade_vx_network.PopulationOnNetwork, Tuple[torch.Tensor]]) → None
Execute the forward function of the wrapper. This method assigns to each output handle in output its corresponding PyTorch tensors and adds the wrapper's func to the PyTorch graph.
Parameters
input – A tuple of the input handles, where each handle corresponds to a certain module. The order is defined by modules. Note that a module can have multiple input handles.
output – A tuple of output handles, each corresponding to one module. The order is defined by modules.
hw_map – The hardware data map.
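As a rough, software-only illustration of the data flow described above, the sketch below collects the values of all input handles in module order, calls func, and writes one result per module into the output handles. Handle and exec_forward_sketch are hypothetical, and the sketch assumes func returns a tuple with exactly one entry per module. The real method operates on TensorHandle objects, takes hardware data from hw_map and registers func in the PyTorch graph; none of that is reproduced here.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Sequence, Tuple


@dataclass
class Handle:
    """Hypothetical stand-in for a TensorHandle: it only stores a value."""
    value: Any = None


def exec_forward_sketch(
    func: Callable[..., Tuple[Any, ...]],
    inputs: Sequence[Tuple[Handle, ...]],
    outputs: Sequence[Handle],
    kwargs: Optional[Dict[str, Any]] = None,
    hw_data: Optional[Tuple[Any, ...]] = None,
) -> None:
    # Flatten the input handles of all modules in module order; one module
    # may contribute several input handles.
    args = [handle.value for module_inputs in inputs for handle in module_inputs]
    call_kwargs = dict(kwargs or {})
    if hw_data is not None:
        # Only passed when hardware data is expected.
        call_kwargs["hw_data"] = hw_data
    results = func(*args, **call_kwargs)
    if len(results) != len(outputs):
        raise ValueError("func must return one output per module")
    # Assign each module's result to that module's output handle.
    for handle, result in zip(outputs, results):
        handle.value = result
```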
-
extra_repr
() → str Add additional information to the module's string representation.
-
training
: bool
-
update
(modules: List[hxtorch.spiking.modules.hx_module.HXModule], func: Optional[Callable] = None)
Update the modules and the function in the wrapper.
Parameters
modules – The new modules to assign to the wrapper.
func – The new function to represent the modules in the wrapper.
-
update_args
(modules: List[hxtorch.spiking.modules.hx_module.HXModule])
Gathers the args and kwargs of all modules in modules and renames keyword arguments that occur multiple times.
Parameters
modules – The modules represented by the wrapper.
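The merging and renaming step can be sketched in plain Python as follows. merge_args is a hypothetical helper that works on (args, kwargs) pairs rather than on HXModules, and the exact renaming scheme is an assumption: here every occurrence of a repeated keyword gets the suffix _i, with i counting from 0 in module order. The library may number occurrences differently.

```python
from collections import Counter
from typing import Any, Dict, List, Sequence, Tuple


def merge_args(
    per_module: Sequence[Tuple[Tuple[Any, ...], Dict[str, Any]]],
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Merge the (args, kwargs) pairs of several modules into one signature.

    ASSUMPTION: every occurrence of a keyword appearing more than once is
    renamed to `<name>_<i>`, with i counting from 0 in module order.
    """
    counts = Counter(key for _, kwargs in per_module for key in kwargs)
    seen: Counter = Counter()
    args: List[Any] = []
    merged: Dict[str, Any] = {}
    for module_args, module_kwargs in per_module:
        args.extend(module_args)  # positional args appended in module order
        for key, value in module_kwargs.items():
            if counts[key] > 1:
                merged[f"{key}_{seen[key]}"] = value
                seen[key] += 1
            else:
                merged[key] = value
    return tuple(args), merged
```

For example, merging ((1, 2), {"tau": 5, "dt": 1}) with ((3,), {"tau": 7, "leak": 0.9}) yields the positional arguments (1, 2, 3) and the keyword arguments tau_0, dt, tau_1 and leak.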
-