Commit c51456ea authored by Martin Reinecke's avatar Martin Reinecke

various tweaks, improve coverage

parent 77c29cf3
Pipeline #29511 failed with stages
in 3 minutes and 54 seconds
...@@ -145,20 +145,29 @@ class data_object(object): ...@@ -145,20 +145,29 @@ class data_object(object):
def sum(self, axis=None): def sum(self, axis=None):
return self._contraction_helper("sum", MPI.SUM, axis) return self._contraction_helper("sum", MPI.SUM, axis)
def prod(self, axis=None):
return self._contraction_helper("prod", MPI.PROD, axis)
def min(self, axis=None): def min(self, axis=None):
return self._contraction_helper("min", MPI.MIN, axis) return self._contraction_helper("min", MPI.MIN, axis)
def max(self, axis=None): def max(self, axis=None):
return self._contraction_helper("max", MPI.MAX, axis) return self._contraction_helper("max", MPI.MAX, axis)
def mean(self): def mean(self, axis=None):
return self.sum()/self.size if axis is None:
sz = self.size
else:
sz = reduce(lambda x, y: x*y, [self.shape[i] for i in axis])
return self.sum(axis)/sz
def std(self): def std(self, axis=None):
return np.sqrt(self.var()) return np.sqrt(self.var(axis))
# FIXME: to be improved! # FIXME: to be improved!
def var(self): def var(self, axis=None):
if axis is not None and len(axis) != len(self.shape):
raise ValueError("functionality not yet supported")
return (abs(self-self.mean())**2).mean() return (abs(self-self.mean())**2).mean()
def _binary_helper(self, other, op): def _binary_helper(self, other, op):
......
...@@ -235,6 +235,7 @@ class Field(object): ...@@ -235,6 +235,7 @@ class Field(object):
The value to fill the field with. The value to fill the field with.
""" """
self._val.fill(fill_value) self._val.fill(fill_value)
return self
def lock(self): def lock(self):
"""Write-protect the data content of `self`. """Write-protect the data content of `self`.
...@@ -318,6 +319,17 @@ class Field(object): ...@@ -318,6 +319,17 @@ class Field(object):
""" """
return Field(val=self, copy=True) return Field(val=self, copy=True)
def empty_copy(self):
""" Returns a Field with identical domain and data type, but
uninitialized data.
Returns
-------
Field
A copy of 'self', with uninitialized data.
"""
return Field(self._domain, dtype=self.dtype)
def locked_copy(self): def locked_copy(self):
""" Returns a read-only version of the Field. """ Returns a read-only version of the Field.
...@@ -451,8 +463,8 @@ class Field(object): ...@@ -451,8 +463,8 @@ class Field(object):
or Field (for partial dot products) or Field (for partial dot products)
""" """
if not isinstance(x, Field): if not isinstance(x, Field):
raise ValueError("The dot-partner must be an instance of " + raise TypeError("The dot-partner must be an instance of " +
"the NIFTy field class") "the NIFTy field class")
if x._domain != self._domain: if x._domain != self._domain:
raise ValueError("Domain mismatch") raise ValueError("Domain mismatch")
...@@ -642,7 +654,8 @@ class Field(object): ...@@ -642,7 +654,8 @@ class Field(object):
if self.scalar_weight(spaces) is not None: if self.scalar_weight(spaces) is not None:
return self._contraction_helper('mean', spaces) return self._contraction_helper('mean', spaces)
# MR FIXME: not very efficient # MR FIXME: not very efficient
tmp = self.weight(1) # MR FIXME: do we need "spaces" here?
tmp = self.weight(1, spaces)
return tmp.sum(spaces)*(1./tmp.total_volume(spaces)) return tmp.sum(spaces)*(1./tmp.total_volume(spaces))
def var(self, spaces=None): def var(self, spaces=None):
...@@ -665,12 +678,10 @@ class Field(object): ...@@ -665,12 +678,10 @@ class Field(object):
# MR FIXME: not very efficient or accurate # MR FIXME: not very efficient or accurate
m1 = self.mean(spaces) m1 = self.mean(spaces)
if np.issubdtype(self.dtype, np.complexfloating): if np.issubdtype(self.dtype, np.complexfloating):
sq = abs(self)**2 sq = abs(self-m1)**2
m1 = abs(m1)**2
else: else:
sq = self**2 sq = (self-m1)**2
m1 **= 2 return sq.mean(spaces)
return sq.mean(spaces) - m1
def std(self, spaces=None): def std(self, spaces=None):
"""Determines the standard deviation over the sub-domains given by """Determines the standard deviation over the sub-domains given by
...@@ -690,8 +701,10 @@ class Field(object): ...@@ -690,8 +701,10 @@ class Field(object):
The result of the operation. If it is carried out over the entire The result of the operation. If it is carried out over the entire
domain, this is a scalar, otherwise a Field. domain, this is a scalar, otherwise a Field.
""" """
from .sugar import sqrt
if self.scalar_weight(spaces) is not None: if self.scalar_weight(spaces) is not None:
return self._contraction_helper('std', spaces) return self._contraction_helper('std', spaces)
print(self.var(spaces))
return sqrt(self.var(spaces)) return sqrt(self.var(spaces))
def copy_content_from(self, other): def copy_content_from(self, other):
......
...@@ -54,7 +54,7 @@ def generate_krylov_samples(D_inv, S, j, N_samps, controller): ...@@ -54,7 +54,7 @@ def generate_krylov_samples(D_inv, S, j, N_samps, controller):
""" """
# RL FIXME: make consistent with complex numbers # RL FIXME: make consistent with complex numbers
j = S.draw_sample(from_inverse=True) if j is None else j j = S.draw_sample(from_inverse=True) if j is None else j
energy = QuadraticEnergy(j*0., D_inv, j) energy = QuadraticEnergy(j.empty_copy().fill(0.), D_inv, j)
y = [S.draw_sample() for _ in range(N_samps)] y = [S.draw_sample() for _ in range(N_samps)]
status = controller.start(energy) status = controller.start(energy)
......
...@@ -57,6 +57,18 @@ class MultiField(object): ...@@ -57,6 +57,18 @@ class MultiField(object):
dtype[key], **kwargs) dtype[key], **kwargs)
for key in domain.keys()}) for key in domain.keys()})
def fill(self, fill_value):
"""Fill `self` uniformly with `fill_value`
Parameters
----------
fill_value: float or complex or int
The value to fill the field with.
"""
for val in self._val.values():
val.fill(fill_value)
return self
def _check_domain(self, other): def _check_domain(self, other):
if other.domain != self.domain: if other.domain != self.domain:
raise ValueError("domains are incompatible.") raise ValueError("domains are incompatible.")
...@@ -76,6 +88,9 @@ class MultiField(object): ...@@ -76,6 +88,9 @@ class MultiField(object):
def copy(self): def copy(self):
return MultiField({key: val.copy() for key, val in self.items()}) return MultiField({key: val.copy() for key, val in self.items()})
def empty_copy(self):
return MultiField({key: val.empty_copy() for key, val in self.items()})
@staticmethod @staticmethod
def build_dtype(dtype, domain): def build_dtype(dtype, domain):
if isinstance(dtype, dict): if isinstance(dtype, dict):
......
...@@ -67,7 +67,7 @@ class InversionEnabler(EndomorphicOperator): ...@@ -67,7 +67,7 @@ class InversionEnabler(EndomorphicOperator):
if self._op.capability & mode: if self._op.capability & mode:
return self._op.apply(x, mode) return self._op.apply(x, mode)
x0 = x*0. x0 = x.empty_copy().fill(0.)
invmode = self._modeTable[self.INVERSE_BIT][self._ilog[mode]] invmode = self._modeTable[self.INVERSE_BIT][self._ilog[mode]]
invop = self._op._flip_modes(self._ilog[invmode]) invop = self._op._flip_modes(self._ilog[invmode])
prec = self._approximation prec = self._approximation
......
...@@ -62,7 +62,7 @@ class ScalingOperator(EndomorphicOperator): ...@@ -62,7 +62,7 @@ class ScalingOperator(EndomorphicOperator):
if self._factor == 1.: if self._factor == 1.:
return x.copy() return x.copy()
if self._factor == 0.: if self._factor == 0.:
return x*0. return x.empty_copy().fill(0.)
if mode == self.TIMES: if mode == self.TIMES:
return x*self._factor return x*self._factor
......
...@@ -218,16 +218,16 @@ for f in ["sqrt", "exp", "log", "tanh", "conjugate"]: ...@@ -218,16 +218,16 @@ for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
func2(value, out=out[key]) func2(value, out=out[key])
return out return out
return MultiField({key: func2(val) for key, val in x.items()}) return MultiField({key: func2(val) for key, val in x.items()})
elif isinstance(x, Field):
if not isinstance(x, Field): fu = getattr(dobj, f)
raise TypeError("This function only accepts Field objects.") if out is not None:
fu = getattr(dobj, f) if not isinstance(out, Field) or x._domain != out._domain:
if out is not None: raise ValueError("Bad 'out' argument")
if not isinstance(out, Field) or x._domain != out._domain: fu(x.val, out=out.val)
raise ValueError("Bad 'out' argument") return out
fu(x.val, out=out.val) else:
return out return Field(domain=x._domain, val=fu(x.val))
else: else:
return Field(domain=x._domain, val=fu(x.val)) return getattr(np, f)(x, out)
return func2 return func2
setattr(_current_module, f, func(f)) setattr(_current_module, f, func(f))
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
import unittest import unittest
import numpy as np import numpy as np
from numpy.testing import assert_equal, assert_allclose from numpy.testing import assert_equal, assert_allclose, assert_raises
from itertools import product from itertools import product
import nifty4 as ift import nifty4 as ift
from test.common import expand from test.common import expand
...@@ -124,6 +124,122 @@ class Test_Functionality(unittest.TestCase): ...@@ -124,6 +124,122 @@ class Test_Functionality(unittest.TestCase):
res = m.vdot(m, spaces=1) res = m.vdot(m, spaces=1)
assert_allclose(res.local_data, 37.5) assert_allclose(res.local_data, 37.5)
def test_lock(self):
s1 = ift.RGSpace((10,))
f1 = ift.Field(s1, 27)
assert_equal(f1.locked, False)
f1.lock()
assert_equal(f1.locked, True)
with assert_raises(ValueError):
f1 += f1
assert_equal(f1.locked_copy() is f1, True)
def test_fill(self):
s1 = ift.RGSpace((10,))
f1 = ift.Field(s1, 27)
assert_equal(f1.fill(10).local_data, 10)
def test_dataconv(self):
s1 = ift.RGSpace((10,))
ld = np.arange(ift.dobj.local_shape(s1.shape)[0])
gd = np.arange(s1.shape[0])
assert_equal(ld, ift.from_local_data(s1, ld).local_data)
assert_equal(gd, ift.from_global_data(s1, gd).to_global_data())
def test_cast_domain(self):
s1 = ift.RGSpace((10,))
s2 = ift.RGSpace((10,), distances=20.)
d = np.arange(s1.shape[0])
d2 = ift.from_global_data(s1, d).cast_domain(s2).to_global_data()
assert_equal(d, d2)
def test_empty_domain(self):
f = ift.Field((), 5)
assert_equal(f.to_global_data(), 5)
f = ift.Field(None, 5)
assert_equal(f.to_global_data(), 5)
assert_equal(f.empty_copy().domain, f.domain)
assert_equal(f.empty_copy().dtype, f.dtype)
assert_equal(f.copy().domain, f.domain)
assert_equal(f.copy().dtype, f.dtype)
assert_equal(f.copy().local_data, f.local_data)
assert_equal(f.copy() is f, False)
def test_trivialities(self):
s1 = ift.RGSpace((10,))
f1 = ift.Field(s1, 27)
assert_equal(f1.local_data, f1.real.local_data)
f1 = ift.Field(s1, 27.+3j)
assert_equal(f1.real.local_data, 27.)
assert_equal(f1.imag.local_data, 3.)
assert_equal(f1.local_data, +f1.local_data)
assert_equal(f1.sum(), f1.sum(0))
f1 = ift.from_global_data(s1, np.arange(10))
assert_equal(f1.min(), 0)
assert_equal(f1.max(), 9)
assert_equal(f1.prod(), 0)
def test_weight(self):
s1 = ift.RGSpace((10,))
f = ift.Field(s1, 10.)
f2 = f.copy()
f.weight(1, out=f2)
assert_equal(f.weight(1).local_data, f2.local_data)
assert_equal(f.total_volume(), 1)
assert_equal(f.total_volume(0), 1)
assert_equal(f.total_volume((0,)), 1)
assert_equal(f.scalar_weight(), 0.1)
assert_equal(f.scalar_weight(0), 0.1)
assert_equal(f.scalar_weight((0,)), 0.1)
s1 = ift.GLSpace(10)
f = ift.Field(s1, 10.)
assert_equal(f.scalar_weight(), None)
assert_equal(f.scalar_weight(0), None)
assert_equal(f.scalar_weight((0,)), None)
@expand(product([ift.RGSpace(10), ift.GLSpace(10)],
[np.float64, np.complex128]))
def test_reduction(self, dom, dt):
s1 = ift.Field(dom, 1., dtype=dt)
assert_allclose(s1.mean(), 1.)
assert_allclose(s1.mean(0), 1.)
assert_allclose(s1.var(), 0., atol=1e-14)
assert_allclose(s1.var(0), 0., atol=1e-14)
assert_allclose(s1.std(), 0., atol=1e-14)
assert_allclose(s1.std(0), 0., atol=1e-14)
def test_err(self):
s1 = ift.RGSpace((10,))
s2 = ift.RGSpace((11,))
f1 = ift.Field(s1, 27)
with assert_raises(ValueError):
f2 = ift.Field(s2, f1)
with assert_raises(ValueError):
f2 = ift.Field(s2, f1.val)
with assert_raises(TypeError):
f2 = ift.Field(s2, "xyz")
with assert_raises(TypeError):
if f1:
pass
with assert_raises(TypeError):
f1.full((2, 4, 6))
with assert_raises(TypeError):
f2 = ift.Field(None, None)
with assert_raises(ValueError):
f2 = ift.Field(s1, None)
with assert_raises(ValueError):
f1.imag
with assert_raises(TypeError):
f1.vdot(42)
with assert_raises(ValueError):
f1.vdot(ift.Field(s2, 1.))
with assert_raises(TypeError):
f1.copy_content_from(1)
with assert_raises(ValueError):
f1.copy_content_from(ift.Field(s2, 1.))
with assert_raises(TypeError):
ift.full(s1, [2, 3])
def test_stdfunc(self): def test_stdfunc(self):
s = ift.RGSpace((200,)) s = ift.RGSpace((200,))
f = ift.Field(s, 27) f = ift.Field(s, 27)
......
...@@ -30,7 +30,7 @@ spaces = [ift.RGSpace([1024], distances=0.123), ift.HPSpace(32)] ...@@ -30,7 +30,7 @@ spaces = [ift.RGSpace([1024], distances=0.123), ift.HPSpace(32)]
minimizers = ['ift.VL_BFGS(IC)', minimizers = ['ift.VL_BFGS(IC)',
'ift.NonlinearCG(IC, "Polak-Ribiere")', 'ift.NonlinearCG(IC, "Polak-Ribiere")',
#'ift.NonlinearCG(IC, "Hestenes-Stiefel"), # 'ift.NonlinearCG(IC, "Hestenes-Stiefel"),
'ift.NonlinearCG(IC, "Fletcher-Reeves")', 'ift.NonlinearCG(IC, "Fletcher-Reeves")',
'ift.NonlinearCG(IC, "5.49")', 'ift.NonlinearCG(IC, "5.49")',
'ift.NewtonCG(xtol=1e-5, maxiter=1000)', 'ift.NewtonCG(xtol=1e-5, maxiter=1000)',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment