Commit 2ec3da9a authored by Martin Reinecke's avatar Martin Reinecke

optimize the Field constructor

parent 4ed58632
......@@ -25,8 +25,8 @@ from .domains.domain import Domain
class DomainTuple(object):
"""Ordered sequence of Domain objects.
This class holds a set of :class:`Domain` objects, which together form the
space on which a :class:`Field` is defined.
This class holds a tuple of :class:`Domain` objects, which together form
the space on which a :class:`Field` is defined.
Notes
-----
......
......@@ -33,45 +33,28 @@ class Field(object):
Parameters
----------
domain : None, DomainTuple, tuple of Domain, or Domain
domain : DomainTuple
the domain of the new Field
val : Field, data_object or scalar
The values the array should contain after init. A scalar input will
fill the whole array with this scalar. If a data_object is provided,
its dimensions must match the domain's.
dtype : type
A numpy.type. Most common are float and complex.
val : data_object
This object's global shape must match the domain shape
After construction, the object will no longer be writeable!
Notes
-----
If possible, do not invoke the constructor directly, but use one of the
many convenience functions for Field conatruction!
many convenience functions for Field construction!
"""
def __init__(self, domain=None, val=None, dtype=None):
self._domain = self._infer_domain(domain=domain, val=val)
dtype = self._infer_dtype(dtype=dtype, val=val)
if isinstance(val, Field):
if self._domain != val._domain:
raise ValueError("Domain mismatch")
self._val = val._val
elif (np.isscalar(val)):
self._val = dobj.full(self._domain.shape, dtype=dtype,
fill_value=val)
elif isinstance(val, dobj.data_object):
if self._domain.shape == val.shape:
if dtype == val.dtype:
self._val = val
else:
self._val = dobj.from_object(val, dtype, True, True)
else:
raise ValueError("Shape mismatch")
else:
raise TypeError("unknown source type")
def __init__(self, domain, val):
if not isinstance(domain, DomainTuple):
raise TypeError("domain must be of type DomainTuple")
if not isinstance(val, dobj.data_object):
raise TypeError("val must be of type dobj.data_object")
if domain.shape != val.shape:
raise ValueError("mismatch between the shapes of val and domain")
self._domain = domain
self._val = val
dobj.lock(self._val)
# prevent implicit conversion to bool
......@@ -99,7 +82,10 @@ class Field(object):
"""
if not np.isscalar(val):
raise TypeError("val must be a scalar")
return Field(DomainTuple.make(domain), val)
if not (np.isreal(val) or np.iscomplex(val)):
raise TypeError("need arithmetic scalar")
domain = DomainTuple.make(domain)
return Field(domain, dobj.full(domain.shape, fill_value=val))
@staticmethod
def from_global_data(domain, arr, sum_up=False):
......@@ -118,12 +104,13 @@ class Field(object):
If False, the contens of `arr` are used directly, and must be
identical on all MPI tasks.
"""
return Field(domain, dobj.from_global_data(arr, sum_up))
return Field(DomainTuple.make(domain),
dobj.from_global_data(arr, sum_up))
@staticmethod
def from_local_data(domain, arr):
domain = DomainTuple.make(domain)
return Field(domain, dobj.from_local_data(domain.shape, arr))
return Field(DomainTuple.make(domain),
dobj.from_local_data(domain.shape, arr))
def to_global_data(self):
"""Returns an array containing the full data of the field.
......@@ -167,25 +154,7 @@ class Field(object):
-----
No copy is made. If needed, use an additional copy() invocation.
"""
return Field(new_domain, self._val)
@staticmethod
def _infer_domain(domain, val=None):
if domain is None:
if isinstance(val, Field):
return val._domain
if np.isscalar(val):
return DomainTuple.make(()) # empty domain tuple
raise TypeError("could not infer domain from value")
return DomainTuple.make(domain)
@staticmethod
def _infer_dtype(dtype, val):
if dtype is not None:
return dtype
if val is None:
raise ValueError("could not infer dtype")
return np.result_type(val)
return Field(DomainTuple.make(new_domain), self._val)
@staticmethod
def from_random(random_type, domain, dtype=np.float64, **kwargs):
......@@ -444,7 +413,7 @@ class Field(object):
for i, dom in enumerate(self._domain)
if i not in spaces)
return Field(domain=return_domain, val=data)
return Field(DomainTuple.make(return_domain), data)
def sum(self, spaces=None):
"""Sums up over the sub-domains given by `spaces`.
......
......@@ -40,7 +40,7 @@ def PS_field(pspace, func):
if not isinstance(pspace, PowerSpace):
raise TypeError
data = dobj.from_global_data(func(pspace.k_lengths))
return Field(pspace, val=data)
return Field(DomainTuple.make(pspace), data)
def get_signal_variance(spec, space):
......@@ -158,7 +158,7 @@ def _create_power_field(domain, power_spectrum):
if not isinstance(power_spectrum.domain[0], PowerSpace):
raise TypeError("PowerSpace required")
power_domain = power_spectrum.domain[0]
fp = Field(power_domain, val=power_spectrum.val)
fp = power_spectrum
else:
power_domain = PowerSpace(domain)
fp = PS_field(power_domain, power_spectrum)
......
......@@ -139,16 +139,16 @@ class Test_Functionality(unittest.TestCase):
assert_equal(d, d2)
def test_empty_domain(self):
f = ift.Field((), 5)
f = ift.Field.full((), 5)
assert_equal(f.local_data, 5)
f = ift.Field(None, 5)
f = ift.Field.full(None, 5)
assert_equal(f.local_data, 5)
def test_trivialities(self):
s1 = ift.RGSpace((10,))
f1 = ift.Field(s1, 27)
f1 = ift.Field.full(s1, 27)
assert_equal(f1.local_data, f1.real.local_data)
f1 = ift.Field(s1, 27.+3j)
f1 = ift.Field.full(s1, 27.+3j)
assert_equal(f1.real.local_data, 27.)
assert_equal(f1.imag.local_data, 3.)
assert_equal(f1.local_data, +f1.local_data)
......@@ -160,7 +160,7 @@ class Test_Functionality(unittest.TestCase):
def test_weight(self):
s1 = ift.RGSpace((10,))
f = ift.Field(s1, 10.)
f = ift.Field.full(s1, 10.)
f2 = f.weight(1)
assert_equal(f.weight(1).local_data, f2.local_data)
assert_equal(f.total_volume(), 1)
......@@ -170,7 +170,7 @@ class Test_Functionality(unittest.TestCase):
assert_equal(f.scalar_weight(0), 0.1)
assert_equal(f.scalar_weight((0,)), 0.1)
s1 = ift.GLSpace(10)
f = ift.Field(s1, 10.)
f = ift.Field.full(s1, 10.)
assert_equal(f.scalar_weight(), None)
assert_equal(f.scalar_weight(0), None)
assert_equal(f.scalar_weight((0,)), None)
......@@ -178,7 +178,7 @@ class Test_Functionality(unittest.TestCase):
@expand(product([ift.RGSpace(10), ift.GLSpace(10)],
[np.float64, np.complex128]))
def test_reduction(self, dom, dt):
s1 = ift.Field(dom, 1., dtype=dt)
s1 = ift.Field.full(dom, dt(1.))
assert_allclose(s1.mean(), 1.)
assert_allclose(s1.mean(0), 1.)
assert_allclose(s1.var(), 0., atol=1e-14)
......@@ -189,13 +189,11 @@ class Test_Functionality(unittest.TestCase):
def test_err(self):
s1 = ift.RGSpace((10,))
s2 = ift.RGSpace((11,))
f1 = ift.Field(s1, 27)
f1 = ift.Field.full(s1, 27)
with assert_raises(ValueError):
f2 = ift.Field(s2, f1)
with assert_raises(ValueError):
f2 = ift.Field(s2, f1.val)
f2 = ift.Field(ift.DomainTuple.make(s2), f1.val)
with assert_raises(TypeError):
f2 = ift.Field(s2, "xyz")
f2 = ift.Field.full(s2, "xyz")
with assert_raises(TypeError):
if f1:
pass
......@@ -203,20 +201,20 @@ class Test_Functionality(unittest.TestCase):
f1.full((2, 4, 6))
with assert_raises(TypeError):
f2 = ift.Field(None, None)
with assert_raises(ValueError):
with assert_raises(TypeError):
f2 = ift.Field(s1, None)
with assert_raises(ValueError):
f1.imag
with assert_raises(TypeError):
f1.vdot(42)
with assert_raises(ValueError):
f1.vdot(ift.Field(s2, 1.))
f1.vdot(ift.Field.full(s2, 1.))
with assert_raises(TypeError):
ift.full(s1, [2, 3])
def test_stdfunc(self):
s = ift.RGSpace((200,))
f = ift.Field(s, 27)
f = ift.Field.full(s, 27)
assert_equal(f.local_data, 27)
assert_equal(f.shape, (200,))
assert_equal(f.dtype, np.int)
......
......@@ -133,7 +133,7 @@ class Test_Minimizers(unittest.TestCase):
@expand(product(minimizers+slow_minimizers))
def test_gauss(self, minimizer):
space = ift.UnstructuredDomain((1,))
starting_point = ift.Field(space, val=3.)
starting_point = ift.Field.full(space, 3.)
class ExpEnergy(ift.Energy):
def __init__(self, position):
......@@ -147,14 +147,15 @@ class Test_Minimizers(unittest.TestCase):
@property
def gradient(self):
x = self.position.to_global_data()[0]
return ift.Field(self.position.domain, val=2*x*np.exp(-(x**2)))
return ift.Field.full(self.position.domain,
2*x*np.exp(-(x**2)))
@property
def metric(self):
x = self.position.to_global_data()[0]
v = (2 - 4*x*x)*np.exp(-x**2)
return ift.DiagonalOperator(
ift.Field(self.position.domain, val=v))
ift.Field.full(self.position.domain, v))
try:
minimizer = eval(minimizer)
......@@ -171,7 +172,7 @@ class Test_Minimizers(unittest.TestCase):
@expand(product(minimizers+newton_minimizers+slow_minimizers))
def test_cosh(self, minimizer):
space = ift.UnstructuredDomain((1,))
starting_point = ift.Field(space, val=3.)
starting_point = ift.Field.full(space, 3.)
class CoshEnergy(ift.Energy):
def __init__(self, position):
......@@ -185,14 +186,14 @@ class Test_Minimizers(unittest.TestCase):
@property
def gradient(self):
x = self.position.to_global_data()[0]
return ift.Field(self.position.domain, val=np.sinh(x))
return ift.Field.full(self.position.domain, np.sinh(x))
@property
def metric(self):
x = self.position.to_global_data()[0]
v = np.cosh(x)
return ift.DiagonalOperator(
ift.Field(self.position.domain, val=v))
ift.Field.full(self.position.domain, v))
try:
minimizer = eval(minimizer)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment