Skip to content
Snippets Groups Projects

forest_util.py norm: Remove usage of list

Merged Gordian Edenhofer requested to merge fix_jax_array_norm into NIFTy_8
1 file
+ 15
14
Compare changes
  • Side-by-side
  • Inline
+ 15
14
@@ -96,7 +96,9 @@ class ShapeWithDtype():
This class may not be transparent to JAX as it shall not be flattened
itself. If used in a tree-like structure. It should only be used as leave.
"""
def __init__(self, shape: Union[Tuple[int], List[int], int], dtype=None):
def __init__(
self, shape: Union[Tuple[()], Tuple[int], List[int], int], dtype=None
):
"""Instantiates a storage unit for shape and dtype.
Parameters
@@ -235,23 +237,22 @@ def zeros_like(a, dtype=None, shape=None):
return tree_map(partial(_zeros_like, dtype=dtype, shape=shape), a)
def norm(tree, ord, *, ravel: bool):
from jax.numpy.linalg import norm
if ravel:
def el_norm(x):
return jnp.abs(x) if jnp.ndim(x) == 0 else norm(x.ravel(), ord=ord)
else:
def _ravel(x):
return x.ravel() if hasattr(x, "ravel") else jnp.ravel(x)
def el_norm(x):
return jnp.abs(x) if jnp.ndim(x) == 0 else norm(x, ord=ord)
return norm(tree_leaves(tree_map(el_norm, tree)), ord=ord)
def norm(tree, ord, *, ravel: bool):
from jax.numpy.linalg import norm
def el_norm(x):
if jnp.ndim(x) == 0:
return jnp.abs(x)
elif ravel:
return norm(_ravel(x), ord=ord)
else:
return norm(x, ord=ord)
def _ravel(x):
return x.ravel() if hasattr(x, "ravel") else jnp.ravel(x)
return norm(jnp.array(tree_leaves(tree_map(el_norm, tree))), ord=ord)
def dot(a, b, *, precision=None):
Loading