jaxsnn.event.custom_lax
Implements functionality of jax.lax for easier debugging.
Functions
jaxsnn.event.custom_lax.cond(pred: bool, true_fun: Callable, false_fun: Callable, *operands)
    Calls both functions to reproduce the compiled behaviour of jax.lax.cond.
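"Calls both functions" suggests an eager stand-in that runs every branch. An error in the branch that is not taken then surfaces while debugging, much as it would when jax.lax.cond traces both branches under compilation. The sketch below illustrates that reading. The name `debug_cond` is hypothetical, and this is not the module's actual implementation.

```python
from typing import Any, Callable


def debug_cond(pred: bool, true_fun: Callable, false_fun: Callable, *operands: Any) -> Any:
    # Hypothetical sketch, not jaxsnn's implementation: run BOTH branches
    # eagerly, so a bug in the branch that is not selected still raises.
    true_result = true_fun(*operands)
    false_result = false_fun(*operands)
    # Return only the result of the branch selected by `pred`.
    return true_result if pred else false_result


calls = []


def increment(x):
    calls.append("true")
    return x + 1


def decrement(x):
    calls.append("false")
    return x - 1


result = debug_cond(True, increment, decrement, 41)
# result == 42 and calls == ["true", "false"]: both branches ran.
```

Computing both results before selecting one trades some wasted work for earlier error detection. In a debugging helper, that is usually the desired trade-off.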
jaxsnn.event.custom_lax.scan(inner_fn: Callable, init, inputs, length: Optional[int] = None, reverse: bool = False)
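No description is given for `scan`. Its signature mirrors `jax.lax.scan(f, init, xs, length, reverse)`, so it presumably replaces the compiled loop with an ordinary Python loop. The sketch below assumes that reading and follows the documented `jax.lax.scan` semantics. With `reverse=True`, steps run back to front, but the stacked outputs stay aligned with the inputs. The name `python_scan` is hypothetical.

```python
from typing import Any, Callable, Optional, Tuple

import jax
import jax.numpy as jnp


def python_scan(
    inner_fn: Callable[[Any, Any], Tuple[Any, Any]],
    init: Any,
    inputs: Any,
    length: Optional[int] = None,
    reverse: bool = False,
) -> Tuple[Any, Any]:
    """Hypothetical eager re-implementation of jax.lax.scan (sketch only)."""
    if inputs is None:
        if length is None:
            raise ValueError("length is required when inputs is None")
        num_steps = length
    else:
        # All array leaves of `inputs` are assumed to share the leading axis.
        num_steps = jax.tree_util.tree_leaves(inputs)[0].shape[0]

    step_order = range(num_steps - 1, -1, -1) if reverse else range(num_steps)
    carry = init
    outputs = [None] * num_steps
    for i in step_order:
        # Slice step i out of every leaf of the `inputs` pytree.
        x = None if inputs is None else jax.tree_util.tree_map(lambda leaf: leaf[i], inputs)
        carry, y = inner_fn(carry, x)
        # Store at index i so that, as in jax.lax.scan, outputs stay aligned
        # with inputs even when reverse=True.
        outputs[i] = y

    # Stack the per-step outputs leaf-wise (assumes at least one step).
    stacked = jax.tree_util.tree_map(lambda *ys: jnp.stack(ys), *outputs)
    return carry, stacked


def running_sum(carry, x):
    total = carry + x
    return total, total


xs = jnp.arange(4.0)
final, sums = python_scan(running_sum, jnp.zeros(()), xs)
# final is 6.0; sums holds the running totals 0, 1, 3, 6
```

Because each step is an ordinary Python call, you can set breakpoints or print inside `inner_fn` and see concrete values instead of tracers.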
jaxsnn.event.custom_lax.tree_flatten(tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None) → Tuple[List[Any], jaxlib.xla_extension.pytree.PyTreeDef]
    Flattens a pytree.

    The flattening order (i.e. the order of elements in the output list) is deterministic, corresponding to a left-to-right depth-first tree traversal.

    Args:
        tree: a pytree to flatten.
        is_leaf: an optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
    Returns:
        A pair where the first element is a list of leaf values and the second element is a treedef representing the structure of the flattened tree.
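The following short example shows the traversal order and the effect of `is_leaf`. It calls `jax.tree_util.tree_flatten` directly, on the assumption (based on the matching docstring) that `custom_lax` exposes that same function.

```python
import jax

tree = {"b": 2, "a": [1, 3]}

# Dict keys are visited in sorted order ("a" before "b"), and each subtree
# is traversed depth-first, left to right.
leaves, treedef = jax.tree_util.tree_flatten(tree)
print(leaves)  # [1, 3, 2]

# With is_leaf, every list is kept whole instead of being traversed.
list_leaves, _ = jax.tree_util.tree_flatten(tree, is_leaf=lambda node: isinstance(node, list))
print(list_leaves)  # [[1, 3], 2]
```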
jaxsnn.event.custom_lax.tree_map(f: Callable[..., Any], tree: Any, *rest: Any, is_leaf: Optional[Callable[[Any], bool]] = None) → Any
    Maps a multi-input function over pytree args to produce a new pytree.

    Args:
        f: function that takes 1 + len(rest) arguments, to be applied at the corresponding leaves of the pytrees.
        tree: a pytree to be mapped over, with each leaf providing the first positional argument to f.
        rest: a tuple of pytrees, each of which has the same structure as tree or has tree as a prefix.
        is_leaf: an optionally specified function that will be called at each flattening step. It should return a boolean, which indicates whether the flattening should traverse the current object, or if it should be stopped immediately, with the whole subtree being treated as a leaf.
    Returns:
        A new pytree with the same structure as tree but with the value at each leaf given by f(x, *xs), where x is the value at the corresponding leaf in tree and xs is the tuple of values at corresponding nodes in rest.

    Examples:

        >>> import jax.tree_util
        >>> jax.tree_util.tree_map(lambda x: x + 1, {"x": 7, "y": 42})
        {'x': 8, 'y': 43}

    If multiple inputs are passed, the structure of the tree is taken from the first input; subsequent inputs need only have tree as a prefix:

        >>> jax.tree_util.tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
        [[5, 7, 9], [6, 1, 2]]
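The sketch below combines two equally structured pytrees in a toy gradient-descent step (values invented for illustration). It then uses `is_leaf` so that `f` receives whole lists rather than their elements. As above, it calls `jax.tree_util.tree_map` directly, assuming `custom_lax` re-exports it.

```python
import jax

# Two pytrees with identical structure: a toy gradient-descent update.
params = {"w": 3.0, "b": 1.0}
grads = {"w": 2.0, "b": 4.0}
learning_rate = 0.5
new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
# new_params == {"w": 2.0, "b": -1.0}

# is_leaf stops traversal at lists, so `len` receives each whole list.
lengths = jax.tree_util.tree_map(
    len, {"a": [1, 2], "b": [3]}, is_leaf=lambda node: isinstance(node, list)
)
# lengths == {"a": 2, "b": 1}
```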
jaxsnn.event.custom_lax.tree_unflatten(treedef: jaxlib.xla_extension.pytree.PyTreeDef, leaves: Iterable[Any]) → Any
    Reconstructs a pytree from the treedef and the leaves.

    The inverse of tree_flatten().

    Args:
        treedef: the treedef to reconstruct.
        leaves: the iterable of leaves to use for reconstruction. The iterable must match the leaves of the treedef.
    Returns:
        The reconstructed pytree, containing the leaves placed in the structure described by treedef.
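The round trip below shows `tree_unflatten` as the inverse of `tree_flatten`. It also shows that any iterable with the right number of leaves can be poured back into the same structure. It uses `jax.tree_util` directly, assuming `custom_lax` re-exports these functions.

```python
import jax

tree = {"a": [1, 3], "b": 2}
leaves, treedef = jax.tree_util.tree_flatten(tree)

# Unflattening the unchanged leaves restores the original pytree.
restored = jax.tree_util.tree_unflatten(treedef, leaves)
# restored == tree

# New leaf values (same count, same order) are placed into the same structure.
scaled = jax.tree_util.tree_unflatten(treedef, [10 * leaf for leaf in leaves])
# scaled == {"a": [10, 30], "b": 20}
```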