Commit 0992e538 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'tweak_field_construction' into 'NIFTy_5'

optimize the Field constructor

See merge request ift/nifty-dev!40
parents 4ed58632 2ec3da9a
...@@ -25,8 +25,8 @@ from .domains.domain import Domain ...@@ -25,8 +25,8 @@ from .domains.domain import Domain
class DomainTuple(object): class DomainTuple(object):
"""Ordered sequence of Domain objects. """Ordered sequence of Domain objects.
This class holds a set of :class:`Domain` objects, which together form the This class holds a tuple of :class:`Domain` objects, which together form
space on which a :class:`Field` is defined. the space on which a :class:`Field` is defined.
Notes Notes
----- -----
......
...@@ -33,45 +33,28 @@ class Field(object): ...@@ -33,45 +33,28 @@ class Field(object):
Parameters Parameters
---------- ----------
domain : None, DomainTuple, tuple of Domain, or Domain domain : DomainTuple
the domain of the new Field
val : Field, data_object or scalar
The values the array should contain after init. A scalar input will
fill the whole array with this scalar. If a data_object is provided,
its dimensions must match the domain's.
dtype : type val : data_object
A numpy.type. Most common are float and complex. This object's global shape must match the domain shape
After construction, the object will no longer be writeable!
Notes Notes
----- -----
If possible, do not invoke the constructor directly, but use one of the If possible, do not invoke the constructor directly, but use one of the
many convenience functions for Field conatruction! many convenience functions for Field construction!
""" """
def __init__(self, domain=None, val=None, dtype=None): def __init__(self, domain, val):
self._domain = self._infer_domain(domain=domain, val=val) if not isinstance(domain, DomainTuple):
raise TypeError("domain must be of type DomainTuple")
dtype = self._infer_dtype(dtype=dtype, val=val) if not isinstance(val, dobj.data_object):
if isinstance(val, Field): raise TypeError("val must be of type dobj.data_object")
if self._domain != val._domain: if domain.shape != val.shape:
raise ValueError("Domain mismatch") raise ValueError("mismatch between the shapes of val and domain")
self._val = val._val self._domain = domain
elif (np.isscalar(val)):
self._val = dobj.full(self._domain.shape, dtype=dtype,
fill_value=val)
elif isinstance(val, dobj.data_object):
if self._domain.shape == val.shape:
if dtype == val.dtype:
self._val = val self._val = val
else:
self._val = dobj.from_object(val, dtype, True, True)
else:
raise ValueError("Shape mismatch")
else:
raise TypeError("unknown source type")
dobj.lock(self._val) dobj.lock(self._val)
# prevent implicit conversion to bool # prevent implicit conversion to bool
...@@ -99,7 +82,10 @@ class Field(object): ...@@ -99,7 +82,10 @@ class Field(object):
""" """
if not np.isscalar(val): if not np.isscalar(val):
raise TypeError("val must be a scalar") raise TypeError("val must be a scalar")
return Field(DomainTuple.make(domain), val) if not (np.isreal(val) or np.iscomplex(val)):
raise TypeError("need arithmetic scalar")
domain = DomainTuple.make(domain)
return Field(domain, dobj.full(domain.shape, fill_value=val))
@staticmethod @staticmethod
def from_global_data(domain, arr, sum_up=False): def from_global_data(domain, arr, sum_up=False):
...@@ -118,12 +104,13 @@ class Field(object): ...@@ -118,12 +104,13 @@ class Field(object):
If False, the contens of `arr` are used directly, and must be If False, the contens of `arr` are used directly, and must be
identical on all MPI tasks. identical on all MPI tasks.
""" """
return Field(domain, dobj.from_global_data(arr, sum_up)) return Field(DomainTuple.make(domain),
dobj.from_global_data(arr, sum_up))
@staticmethod @staticmethod
def from_local_data(domain, arr): def from_local_data(domain, arr):
domain = DomainTuple.make(domain) return Field(DomainTuple.make(domain),
return Field(domain, dobj.from_local_data(domain.shape, arr)) dobj.from_local_data(domain.shape, arr))
def to_global_data(self): def to_global_data(self):
"""Returns an array containing the full data of the field. """Returns an array containing the full data of the field.
...@@ -167,25 +154,7 @@ class Field(object): ...@@ -167,25 +154,7 @@ class Field(object):
----- -----
No copy is made. If needed, use an additional copy() invocation. No copy is made. If needed, use an additional copy() invocation.
""" """
return Field(new_domain, self._val) return Field(DomainTuple.make(new_domain), self._val)
@staticmethod
def _infer_domain(domain, val=None):
if domain is None:
if isinstance(val, Field):
return val._domain
if np.isscalar(val):
return DomainTuple.make(()) # empty domain tuple
raise TypeError("could not infer domain from value")
return DomainTuple.make(domain)
@staticmethod
def _infer_dtype(dtype, val):
if dtype is not None:
return dtype
if val is None:
raise ValueError("could not infer dtype")
return np.result_type(val)
@staticmethod @staticmethod
def from_random(random_type, domain, dtype=np.float64, **kwargs): def from_random(random_type, domain, dtype=np.float64, **kwargs):
...@@ -444,7 +413,7 @@ class Field(object): ...@@ -444,7 +413,7 @@ class Field(object):
for i, dom in enumerate(self._domain) for i, dom in enumerate(self._domain)
if i not in spaces) if i not in spaces)
return Field(domain=return_domain, val=data) return Field(DomainTuple.make(return_domain), data)
def sum(self, spaces=None): def sum(self, spaces=None):
"""Sums up over the sub-domains given by `spaces`. """Sums up over the sub-domains given by `spaces`.
......
...@@ -40,7 +40,7 @@ def PS_field(pspace, func): ...@@ -40,7 +40,7 @@ def PS_field(pspace, func):
if not isinstance(pspace, PowerSpace): if not isinstance(pspace, PowerSpace):
raise TypeError raise TypeError
data = dobj.from_global_data(func(pspace.k_lengths)) data = dobj.from_global_data(func(pspace.k_lengths))
return Field(pspace, val=data) return Field(DomainTuple.make(pspace), data)
def get_signal_variance(spec, space): def get_signal_variance(spec, space):
...@@ -158,7 +158,7 @@ def _create_power_field(domain, power_spectrum): ...@@ -158,7 +158,7 @@ def _create_power_field(domain, power_spectrum):
if not isinstance(power_spectrum.domain[0], PowerSpace): if not isinstance(power_spectrum.domain[0], PowerSpace):
raise TypeError("PowerSpace required") raise TypeError("PowerSpace required")
power_domain = power_spectrum.domain[0] power_domain = power_spectrum.domain[0]
fp = Field(power_domain, val=power_spectrum.val) fp = power_spectrum
else: else:
power_domain = PowerSpace(domain) power_domain = PowerSpace(domain)
fp = PS_field(power_domain, power_spectrum) fp = PS_field(power_domain, power_spectrum)
......
...@@ -139,16 +139,16 @@ class Test_Functionality(unittest.TestCase): ...@@ -139,16 +139,16 @@ class Test_Functionality(unittest.TestCase):
assert_equal(d, d2) assert_equal(d, d2)
def test_empty_domain(self): def test_empty_domain(self):
f = ift.Field((), 5) f = ift.Field.full((), 5)
assert_equal(f.local_data, 5) assert_equal(f.local_data, 5)
f = ift.Field(None, 5) f = ift.Field.full(None, 5)
assert_equal(f.local_data, 5) assert_equal(f.local_data, 5)
def test_trivialities(self): def test_trivialities(self):
s1 = ift.RGSpace((10,)) s1 = ift.RGSpace((10,))
f1 = ift.Field(s1, 27) f1 = ift.Field.full(s1, 27)
assert_equal(f1.local_data, f1.real.local_data) assert_equal(f1.local_data, f1.real.local_data)
f1 = ift.Field(s1, 27.+3j) f1 = ift.Field.full(s1, 27.+3j)
assert_equal(f1.real.local_data, 27.) assert_equal(f1.real.local_data, 27.)
assert_equal(f1.imag.local_data, 3.) assert_equal(f1.imag.local_data, 3.)
assert_equal(f1.local_data, +f1.local_data) assert_equal(f1.local_data, +f1.local_data)
...@@ -160,7 +160,7 @@ class Test_Functionality(unittest.TestCase): ...@@ -160,7 +160,7 @@ class Test_Functionality(unittest.TestCase):
def test_weight(self): def test_weight(self):
s1 = ift.RGSpace((10,)) s1 = ift.RGSpace((10,))
f = ift.Field(s1, 10.) f = ift.Field.full(s1, 10.)
f2 = f.weight(1) f2 = f.weight(1)
assert_equal(f.weight(1).local_data, f2.local_data) assert_equal(f.weight(1).local_data, f2.local_data)
assert_equal(f.total_volume(), 1) assert_equal(f.total_volume(), 1)
...@@ -170,7 +170,7 @@ class Test_Functionality(unittest.TestCase): ...@@ -170,7 +170,7 @@ class Test_Functionality(unittest.TestCase):
assert_equal(f.scalar_weight(0), 0.1) assert_equal(f.scalar_weight(0), 0.1)
assert_equal(f.scalar_weight((0,)), 0.1) assert_equal(f.scalar_weight((0,)), 0.1)
s1 = ift.GLSpace(10) s1 = ift.GLSpace(10)
f = ift.Field(s1, 10.) f = ift.Field.full(s1, 10.)
assert_equal(f.scalar_weight(), None) assert_equal(f.scalar_weight(), None)
assert_equal(f.scalar_weight(0), None) assert_equal(f.scalar_weight(0), None)
assert_equal(f.scalar_weight((0,)), None) assert_equal(f.scalar_weight((0,)), None)
...@@ -178,7 +178,7 @@ class Test_Functionality(unittest.TestCase): ...@@ -178,7 +178,7 @@ class Test_Functionality(unittest.TestCase):
@expand(product([ift.RGSpace(10), ift.GLSpace(10)], @expand(product([ift.RGSpace(10), ift.GLSpace(10)],
[np.float64, np.complex128])) [np.float64, np.complex128]))
def test_reduction(self, dom, dt): def test_reduction(self, dom, dt):
s1 = ift.Field(dom, 1., dtype=dt) s1 = ift.Field.full(dom, dt(1.))
assert_allclose(s1.mean(), 1.) assert_allclose(s1.mean(), 1.)
assert_allclose(s1.mean(0), 1.) assert_allclose(s1.mean(0), 1.)
assert_allclose(s1.var(), 0., atol=1e-14) assert_allclose(s1.var(), 0., atol=1e-14)
...@@ -189,13 +189,11 @@ class Test_Functionality(unittest.TestCase): ...@@ -189,13 +189,11 @@ class Test_Functionality(unittest.TestCase):
def test_err(self): def test_err(self):
s1 = ift.RGSpace((10,)) s1 = ift.RGSpace((10,))
s2 = ift.RGSpace((11,)) s2 = ift.RGSpace((11,))
f1 = ift.Field(s1, 27) f1 = ift.Field.full(s1, 27)
with assert_raises(ValueError): with assert_raises(ValueError):
f2 = ift.Field(s2, f1) f2 = ift.Field(ift.DomainTuple.make(s2), f1.val)
with assert_raises(ValueError):
f2 = ift.Field(s2, f1.val)
with assert_raises(TypeError): with assert_raises(TypeError):
f2 = ift.Field(s2, "xyz") f2 = ift.Field.full(s2, "xyz")
with assert_raises(TypeError): with assert_raises(TypeError):
if f1: if f1:
pass pass
...@@ -203,20 +201,20 @@ class Test_Functionality(unittest.TestCase): ...@@ -203,20 +201,20 @@ class Test_Functionality(unittest.TestCase):
f1.full((2, 4, 6)) f1.full((2, 4, 6))
with assert_raises(TypeError): with assert_raises(TypeError):
f2 = ift.Field(None, None) f2 = ift.Field(None, None)
with assert_raises(ValueError): with assert_raises(TypeError):
f2 = ift.Field(s1, None) f2 = ift.Field(s1, None)
with assert_raises(ValueError): with assert_raises(ValueError):
f1.imag f1.imag
with assert_raises(TypeError): with assert_raises(TypeError):
f1.vdot(42) f1.vdot(42)
with assert_raises(ValueError): with assert_raises(ValueError):
f1.vdot(ift.Field(s2, 1.)) f1.vdot(ift.Field.full(s2, 1.))
with assert_raises(TypeError): with assert_raises(TypeError):
ift.full(s1, [2, 3]) ift.full(s1, [2, 3])
def test_stdfunc(self): def test_stdfunc(self):
s = ift.RGSpace((200,)) s = ift.RGSpace((200,))
f = ift.Field(s, 27) f = ift.Field.full(s, 27)
assert_equal(f.local_data, 27) assert_equal(f.local_data, 27)
assert_equal(f.shape, (200,)) assert_equal(f.shape, (200,))
assert_equal(f.dtype, np.int) assert_equal(f.dtype, np.int)
......
...@@ -133,7 +133,7 @@ class Test_Minimizers(unittest.TestCase): ...@@ -133,7 +133,7 @@ class Test_Minimizers(unittest.TestCase):
@expand(product(minimizers+slow_minimizers)) @expand(product(minimizers+slow_minimizers))
def test_gauss(self, minimizer): def test_gauss(self, minimizer):
space = ift.UnstructuredDomain((1,)) space = ift.UnstructuredDomain((1,))
starting_point = ift.Field(space, val=3.) starting_point = ift.Field.full(space, 3.)
class ExpEnergy(ift.Energy): class ExpEnergy(ift.Energy):
def __init__(self, position): def __init__(self, position):
...@@ -147,14 +147,15 @@ class Test_Minimizers(unittest.TestCase): ...@@ -147,14 +147,15 @@ class Test_Minimizers(unittest.TestCase):
@property @property
def gradient(self): def gradient(self):
x = self.position.to_global_data()[0] x = self.position.to_global_data()[0]
return ift.Field(self.position.domain, val=2*x*np.exp(-(x**2))) return ift.Field.full(self.position.domain,
2*x*np.exp(-(x**2)))
@property @property
def metric(self): def metric(self):
x = self.position.to_global_data()[0] x = self.position.to_global_data()[0]
v = (2 - 4*x*x)*np.exp(-x**2) v = (2 - 4*x*x)*np.exp(-x**2)
return ift.DiagonalOperator( return ift.DiagonalOperator(
ift.Field(self.position.domain, val=v)) ift.Field.full(self.position.domain, v))
try: try:
minimizer = eval(minimizer) minimizer = eval(minimizer)
...@@ -171,7 +172,7 @@ class Test_Minimizers(unittest.TestCase): ...@@ -171,7 +172,7 @@ class Test_Minimizers(unittest.TestCase):
@expand(product(minimizers+newton_minimizers+slow_minimizers)) @expand(product(minimizers+newton_minimizers+slow_minimizers))
def test_cosh(self, minimizer): def test_cosh(self, minimizer):
space = ift.UnstructuredDomain((1,)) space = ift.UnstructuredDomain((1,))
starting_point = ift.Field(space, val=3.) starting_point = ift.Field.full(space, 3.)
class CoshEnergy(ift.Energy): class CoshEnergy(ift.Energy):
def __init__(self, position): def __init__(self, position):
...@@ -185,14 +186,14 @@ class Test_Minimizers(unittest.TestCase): ...@@ -185,14 +186,14 @@ class Test_Minimizers(unittest.TestCase):
@property @property
def gradient(self): def gradient(self):
x = self.position.to_global_data()[0] x = self.position.to_global_data()[0]
return ift.Field(self.position.domain, val=np.sinh(x)) return ift.Field.full(self.position.domain, np.sinh(x))
@property @property
def metric(self): def metric(self):
x = self.position.to_global_data()[0] x = self.position.to_global_data()[0]
v = np.cosh(x) v = np.cosh(x)
return ift.DiagonalOperator( return ift.DiagonalOperator(
ift.Field(self.position.domain, val=v)) ift.Field.full(self.position.domain, v))
try: try:
minimizer = eval(minimizer) minimizer = eval(minimizer)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment