jaxsnn.event.custom_lax

Implements functionality of jax.lax for easier debugging.

Functions

jaxsnn.event.custom_lax.cond(pred: bool, true_fun: Callable, false_fun: Callable, *operands)

Calls both functions to evaluate the compiled behaviour of jax.lax.scan.
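The description above can be illustrated with a minimal sketch. It assumes that "calling both functions" means evaluating true_fun and false_fun eagerly and then selecting a result by pred; `eager_cond`, `double`, and `negate` are hypothetical names for illustration, not the module's actual implementation.

```python
from typing import Callable

def eager_cond(pred: bool, true_fun: Callable, false_fun: Callable, *operands):
    """Hypothetical stand-in: run both branches, then pick one by `pred`."""
    true_result = true_fun(*operands)    # always evaluated
    false_result = false_fun(*operands)  # always evaluated
    return true_result if pred else false_result

calls = []  # records which branches actually ran

def double(x):
    calls.append("double")
    return 2 * x

def negate(x):
    calls.append("negate")
    return -x

result = eager_cond(True, double, negate, 3)
print(result, calls)
```

Because both branches run as ordinary Python, print statements and breakpoints inside either branch fire regardless of pred, which is likely the debugging benefit the module aims for.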

jaxsnn.event.custom_lax.scan(inner_fn: Callable, init, inputs, length: Optional[int] = None, reverse: bool = False)
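The source gives no description for scan. As a rough sketch, a debugging-friendly scan could be a plain Python loop that mirrors the carry/output convention documented for jax.lax.scan. `loop_scan` and `running_sum` are hypothetical names, and this version assumes inputs is a flat sequence rather than an arbitrary pytree; it is not the module's actual implementation.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple

def loop_scan(
    inner_fn: Callable,
    init: Any,
    inputs: Optional[Sequence],
    length: Optional[int] = None,
    reverse: bool = False,
) -> Tuple[Any, List[Any]]:
    """Hypothetical pure-Python loop with a scan-like signature."""
    if inputs is None:
        if length is None:
            raise ValueError("length is required when inputs is None")
        inputs = [None] * length
    steps = list(inputs)
    if reverse:
        steps.reverse()  # visit the inputs from last to first
    carry = init
    outputs = []
    for x in steps:
        # inner_fn maps (carry, x) to (new_carry, per-step output)
        carry, y = inner_fn(carry, x)
        outputs.append(y)
    if reverse:
        outputs.reverse()  # keep outputs aligned with the input order
    return carry, outputs

def running_sum(carry, x):
    total = carry + x
    return total, total

final, partial_sums = loop_scan(running_sum, 0, [1, 2, 3, 4])
print(final, partial_sums)
```

Unlike a compiled scan, each iteration here executes step by step in Python, so intermediate carries can be inspected directly.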

jaxsnn.event.custom_lax.tree_flatten(tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None) → Tuple[List[Any], jaxlib.xla_extension.pytree.PyTreeDef]

Flattens a pytree.

The flattening order (i.e. the order of elements in the output list) is deterministic, corresponding to a left-to-right depth-first tree traversal.

Args:

tree: a pytree to flatten.

is_leaf: an optionally specified function that will be called at each flattening step. It should return a boolean, with true stopping the traversal and the whole subtree being treated as a leaf, and false indicating the flattening should traverse the current object.

Returns:

A pair where the first element is a list of leaf values and the second element is a treedef representing the structure of the flattened tree.
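A short sketch of the behaviour described above, written against jax.tree_util.tree_flatten (the function documented here appears to be the same one, judging by the tree_map examples in this module, which also call jax.tree_util). The `params` dictionary is made up for illustration.

```python
import jax

# Illustrative pytree: a dict holding a scalar and a list.
params = {"bias": 0.5, "weights": [1.0, 2.0]}

# Default flattening: dict entries are visited in sorted key order.
leaves, treedef = jax.tree_util.tree_flatten(params)
print(leaves)

# With is_leaf, any node for which the predicate returns True is kept
# whole instead of being traversed; here every list becomes a leaf.
coarse_leaves, coarse_treedef = jax.tree_util.tree_flatten(
    params, is_leaf=lambda node: isinstance(node, list)
)
print(coarse_leaves)
```

The first call yields three scalar leaves, while the second keeps the list intact as a single leaf next to the bias value.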

jaxsnn.event.custom_lax.tree_map(f: Callable[..., Any], tree: Any, *rest: Any, is_leaf: Optional[Callable[[Any], bool]] = None) → Any

Maps a multi-input function over pytree args to produce a new pytree.

Args:
f: function that takes 1 + len(rest) arguments, to be applied at the corresponding leaves of the pytrees.

tree: a pytree to be mapped over, with each leaf providing the first positional argument to f.

rest: a tuple of pytrees, each of which has the same structure as tree or has tree as a prefix.

is_leaf: an optionally specified function that will be called at each flattening step. It should return a boolean: true stops the traversal, with the whole subtree being treated as a leaf, and false indicates the flattening should traverse the current object.

Returns:

A new pytree with the same structure as tree but with the value at each leaf given by f(x, *xs) where x is the value at the corresponding leaf in tree and xs is the tuple of values at corresponding nodes in rest.

Examples:

>>> import jax.tree_util
>>> jax.tree_util.tree_map(lambda x: x + 1, {"x": 7, "y": 42})
{'x': 8, 'y': 43}

If multiple inputs are passed, the structure of the tree is taken from the first input; subsequent inputs need only have tree as a prefix:

>>> jax.tree_util.tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
[[5, 7, 9], [6, 1, 2]]

jaxsnn.event.custom_lax.tree_unflatten(treedef: jaxlib.xla_extension.pytree.PyTreeDef, leaves: Iterable[Any]) → Any

Reconstructs a pytree from the treedef and the leaves.

The inverse of tree_flatten().

Args:

treedef: the treedef to reconstruct.

leaves: the iterable of leaves to use for reconstruction. The iterable must match the leaves of the treedef.

Returns:

The reconstructed pytree, containing the leaves placed in the structure described by treedef.
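A round trip through flattening and unflattening can be sketched as follows, again using the jax.tree_util functions directly. The `state` dictionary and the scaling factor are illustrative only.

```python
import jax

# Illustrative pytree: a dict with a scalar and a tuple.
state = {"v": 1.0, "i": (2.0, 3.0)}

# Flatten, transform the flat list of leaves, then rebuild the structure.
leaves, treedef = jax.tree_util.tree_flatten(state)
scaled_leaves = [10 * leaf for leaf in leaves]
rebuilt = jax.tree_util.tree_unflatten(treedef, scaled_leaves)
print(rebuilt)

# Unflattening the untouched leaves recovers the original pytree.
restored = jax.tree_util.tree_unflatten(treedef, leaves)
```

The number and order of the leaves passed in must match what tree_flatten produced for the same treedef.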