Commit b2a5dd5c authored by Martin Reinecke's avatar Martin Reinecke
Browse files

improve Field norm() method

parent c300a904
...@@ -32,7 +32,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full", ...@@ -32,7 +32,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"local_data", "ibegin", "ibegin_from_shape", "np_allreduce_sum", "local_data", "ibegin", "ibegin_from_shape", "np_allreduce_sum",
"np_allreduce_min", "np_allreduce_max", "np_allreduce_min", "np_allreduce_max",
"distaxis", "from_local_data", "from_global_data", "to_global_data", "distaxis", "from_local_data", "from_global_data", "to_global_data",
"redistribute", "default_distaxis", "is_numpy", "redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
"lock", "locked", "uniform_full", "transpose", "to_global_data_rw", "lock", "locked", "uniform_full", "transpose", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed"] "ensure_not_distributed", "ensure_default_distributed"]
...@@ -553,3 +553,22 @@ def ensure_default_distributed(arr): ...@@ -553,3 +553,22 @@ def ensure_default_distributed(arr):
if arr._distaxis != 0: if arr._distaxis != 0:
arr = redistribute(arr, dist=0) arr = redistribute(arr, dist=0)
return arr return arr
def absmax(arr):
if arr._data.size == 0:
tmp = np.array(0, dtype=arr._data.dtype)
else:
tmp = np.linalg.norm(arr._data, ord=np.inf)
res = np.empty_like(tmp)
_comm.Allreduce(tmp, res, MPI.MAX)
return res[()]
def norm(arr, ord=2):
if ord == np.inf:
return absmax(arr)
tmp = np.linalg.norm(np.atleast_1d(arr._data), ord=ord) ** ord
res = np.empty_like(tmp)
_comm.Allreduce(tmp, res, MPI.SUM)
return res[()] ** (1./ord)
...@@ -31,7 +31,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full", ...@@ -31,7 +31,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"local_data", "ibegin", "ibegin_from_shape", "np_allreduce_sum", "local_data", "ibegin", "ibegin_from_shape", "np_allreduce_sum",
"np_allreduce_min", "np_allreduce_max", "np_allreduce_min", "np_allreduce_max",
"distaxis", "from_local_data", "from_global_data", "to_global_data", "distaxis", "from_local_data", "from_global_data", "to_global_data",
"redistribute", "default_distaxis", "is_numpy", "redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
"lock", "locked", "uniform_full", "to_global_data_rw", "lock", "locked", "uniform_full", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed"] "ensure_not_distributed", "ensure_default_distributed"]
...@@ -141,3 +141,11 @@ def ensure_not_distributed(arr, axes): ...@@ -141,3 +141,11 @@ def ensure_not_distributed(arr, axes):
def ensure_default_distributed(arr): def ensure_default_distributed(arr):
return arr return arr
def absmax(arr):
return np.linalg.norm(arr, ord=np.inf)
def norm(arr, ord=2):
return np.linalg.norm(np.atleast_1d(arr), ord=ord)
...@@ -360,25 +360,20 @@ class Field(object): ...@@ -360,25 +360,20 @@ class Field(object):
# For the moment, do this the explicit, non-optimized way # For the moment, do this the explicit, non-optimized way
return (self.conjugate()*x).sum(spaces=spaces) return (self.conjugate()*x).sum(spaces=spaces)
def norm(self): def norm(self, ord=2):
""" Computes the L2-norm of the field values. """ Computes the L2-norm of the field values.
Returns Parameters
------- ----------
float ord : int, default=2
The L2-norm of the field values. accepted values: 1, 2, ..., np.inf
"""
return np.sqrt(abs(self.vdot(x=self)))
def squared_norm(self):
""" Computes the square of the L2-norm of the field values.
Returns Returns
------- -------
float float
The square of the L2-norm of the field values. The L2-norm of the field values.
""" """
return abs(self.vdot(x=self)) return dobj.norm(self._val, ord)
def conjugate(self): def conjugate(self):
""" Returns the complex conjugate of the field. """ Returns the complex conjugate of the field.
......
...@@ -137,15 +137,24 @@ class MultiField(object): ...@@ -137,15 +137,24 @@ class MultiField(object):
domain, tuple(Field.from_global_data(domain[key], arr[key], sum_up) domain, tuple(Field.from_global_data(domain[key], arr[key], sum_up)
for key in domain.keys())) for key in domain.keys()))
def norm(self): def norm(self, ord=2):
""" Computes the L2-norm of the field values. """ Computes the norm of the field values.
Parameters
----------
ord : int, default=2
accepted values: 1, 2, ..., np.inf
Returns Returns
------- -------
norm : float norm : float
The L2-norm of the field values. The norm of the field values.
""" """
return np.sqrt(np.abs(self.vdot(x=self))) nrm = np.asarray([f.norm(ord) for f in self._val])
if ord == np.inf:
return nrm.max()
return (nrm ** ord).sum() ** (1./ord)
# return np.sqrt(np.abs(self.vdot(x=self)))
def sum(self): def sum(self):
""" Computes the sum all field values. """ Computes the sum all field values.
...@@ -168,16 +177,6 @@ class MultiField(object): ...@@ -168,16 +177,6 @@ class MultiField(object):
""" """
return utilities.my_sum(map(lambda d: d.size, self._domain.domains())) return utilities.my_sum(map(lambda d: d.size, self._domain.domains()))
def squared_norm(self):
""" Computes the square of the L2-norm of the field values.
Returns
-------
float
The square of the L2-norm of the field values.
"""
return abs(self.vdot(x=self))
def __neg__(self): def __neg__(self):
return self._transform(lambda x: -x) return self._transform(lambda x: -x)
......
...@@ -110,6 +110,16 @@ class Test_Functionality(unittest.TestCase): ...@@ -110,6 +110,16 @@ class Test_Functionality(unittest.TestCase):
assert_allclose(sc1.mean.local_data, fp1.local_data, rtol=0.2) assert_allclose(sc1.mean.local_data, fp1.local_data, rtol=0.2)
assert_allclose(sc2.mean.local_data, fp2.local_data, rtol=0.2) assert_allclose(sc2.mean.local_data, fp2.local_data, rtol=0.2)
def test_norm(self):
s = ift.RGSpace((10,))
f = ift.Field.from_random("normal", domain=s, dtype=np.complex128)
gd = f.to_global_data()
assert_allclose(f.norm(), np.linalg.norm(gd))
assert_allclose(f.norm(1), np.linalg.norm(gd, ord=1))
assert_allclose(f.norm(2), np.linalg.norm(gd, ord=2))
assert_allclose(f.norm(3), np.linalg.norm(gd, ord=3))
assert_allclose(f.norm(np.inf), np.linalg.norm(gd, ord=np.inf))
def test_vdot(self): def test_vdot(self):
s = ift.RGSpace((10,)) s = ift.RGSpace((10,))
f1 = ift.Field.from_random("normal", domain=s, dtype=np.complex128) f1 = ift.Field.from_random("normal", domain=s, dtype=np.complex128)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment