Skip to content
Snippets Groups Projects
Commit 2ec3da9a authored by Martin Reinecke's avatar Martin Reinecke
Browse files

optimize the Field constructor

parent 4ed58632
No related branches found
No related tags found
No related merge requests found
......@@ -25,8 +25,8 @@ from .domains.domain import Domain
class DomainTuple(object):
"""Ordered sequence of Domain objects.
This class holds a set of :class:`Domain` objects, which together form the
space on which a :class:`Field` is defined.
This class holds a tuple of :class:`Domain` objects, which together form
the space on which a :class:`Field` is defined.
Notes
-----
......
......
......@@ -33,45 +33,28 @@ class Field(object):
Parameters
----------
domain : None, DomainTuple, tuple of Domain, or Domain
val : Field, data_object or scalar
The values the array should contain after init. A scalar input will
fill the whole array with this scalar. If a data_object is provided,
its dimensions must match the domain's.
domain : DomainTuple
the domain of the new Field
dtype : type
A numpy.type. Most common are float and complex.
val : data_object
This object's global shape must match the domain shape
After construction, the object will no longer be writeable!
Notes
-----
If possible, do not invoke the constructor directly, but use one of the
many convenience functions for Field conatruction!
many convenience functions for Field construction!
"""
def __init__(self, domain=None, val=None, dtype=None):
self._domain = self._infer_domain(domain=domain, val=val)
dtype = self._infer_dtype(dtype=dtype, val=val)
if isinstance(val, Field):
if self._domain != val._domain:
raise ValueError("Domain mismatch")
self._val = val._val
elif (np.isscalar(val)):
self._val = dobj.full(self._domain.shape, dtype=dtype,
fill_value=val)
elif isinstance(val, dobj.data_object):
if self._domain.shape == val.shape:
if dtype == val.dtype:
def __init__(self, domain, val):
if not isinstance(domain, DomainTuple):
raise TypeError("domain must be of type DomainTuple")
if not isinstance(val, dobj.data_object):
raise TypeError("val must be of type dobj.data_object")
if domain.shape != val.shape:
raise ValueError("mismatch between the shapes of val and domain")
self._domain = domain
self._val = val
else:
self._val = dobj.from_object(val, dtype, True, True)
else:
raise ValueError("Shape mismatch")
else:
raise TypeError("unknown source type")
dobj.lock(self._val)
# prevent implicit conversion to bool
......@@ -99,7 +82,10 @@ class Field(object):
"""
if not np.isscalar(val):
raise TypeError("val must be a scalar")
return Field(DomainTuple.make(domain), val)
if not (np.isreal(val) or np.iscomplex(val)):
raise TypeError("need arithmetic scalar")
domain = DomainTuple.make(domain)
return Field(domain, dobj.full(domain.shape, fill_value=val))
@staticmethod
def from_global_data(domain, arr, sum_up=False):
......@@ -118,12 +104,13 @@ class Field(object):
If False, the contens of `arr` are used directly, and must be
identical on all MPI tasks.
"""
return Field(domain, dobj.from_global_data(arr, sum_up))
return Field(DomainTuple.make(domain),
dobj.from_global_data(arr, sum_up))
@staticmethod
def from_local_data(domain, arr):
domain = DomainTuple.make(domain)
return Field(domain, dobj.from_local_data(domain.shape, arr))
return Field(DomainTuple.make(domain),
dobj.from_local_data(domain.shape, arr))
def to_global_data(self):
"""Returns an array containing the full data of the field.
......@@ -167,25 +154,7 @@ class Field(object):
-----
No copy is made. If needed, use an additional copy() invocation.
"""
return Field(new_domain, self._val)
@staticmethod
def _infer_domain(domain, val=None):
if domain is None:
if isinstance(val, Field):
return val._domain
if np.isscalar(val):
return DomainTuple.make(()) # empty domain tuple
raise TypeError("could not infer domain from value")
return DomainTuple.make(domain)
@staticmethod
def _infer_dtype(dtype, val):
if dtype is not None:
return dtype
if val is None:
raise ValueError("could not infer dtype")
return np.result_type(val)
return Field(DomainTuple.make(new_domain), self._val)
@staticmethod
def from_random(random_type, domain, dtype=np.float64, **kwargs):
......@@ -444,7 +413,7 @@ class Field(object):
for i, dom in enumerate(self._domain)
if i not in spaces)
return Field(domain=return_domain, val=data)
return Field(DomainTuple.make(return_domain), data)
def sum(self, spaces=None):
"""Sums up over the sub-domains given by `spaces`.
......
......
......@@ -40,7 +40,7 @@ def PS_field(pspace, func):
if not isinstance(pspace, PowerSpace):
raise TypeError
data = dobj.from_global_data(func(pspace.k_lengths))
return Field(pspace, val=data)
return Field(DomainTuple.make(pspace), data)
def get_signal_variance(spec, space):
......@@ -158,7 +158,7 @@ def _create_power_field(domain, power_spectrum):
if not isinstance(power_spectrum.domain[0], PowerSpace):
raise TypeError("PowerSpace required")
power_domain = power_spectrum.domain[0]
fp = Field(power_domain, val=power_spectrum.val)
fp = power_spectrum
else:
power_domain = PowerSpace(domain)
fp = PS_field(power_domain, power_spectrum)
......
......
......@@ -139,16 +139,16 @@ class Test_Functionality(unittest.TestCase):
assert_equal(d, d2)
def test_empty_domain(self):
f = ift.Field((), 5)
f = ift.Field.full((), 5)
assert_equal(f.local_data, 5)
f = ift.Field(None, 5)
f = ift.Field.full(None, 5)
assert_equal(f.local_data, 5)
def test_trivialities(self):
s1 = ift.RGSpace((10,))
f1 = ift.Field(s1, 27)
f1 = ift.Field.full(s1, 27)
assert_equal(f1.local_data, f1.real.local_data)
f1 = ift.Field(s1, 27.+3j)
f1 = ift.Field.full(s1, 27.+3j)
assert_equal(f1.real.local_data, 27.)
assert_equal(f1.imag.local_data, 3.)
assert_equal(f1.local_data, +f1.local_data)
......@@ -160,7 +160,7 @@ class Test_Functionality(unittest.TestCase):
def test_weight(self):
s1 = ift.RGSpace((10,))
f = ift.Field(s1, 10.)
f = ift.Field.full(s1, 10.)
f2 = f.weight(1)
assert_equal(f.weight(1).local_data, f2.local_data)
assert_equal(f.total_volume(), 1)
......@@ -170,7 +170,7 @@ class Test_Functionality(unittest.TestCase):
assert_equal(f.scalar_weight(0), 0.1)
assert_equal(f.scalar_weight((0,)), 0.1)
s1 = ift.GLSpace(10)
f = ift.Field(s1, 10.)
f = ift.Field.full(s1, 10.)
assert_equal(f.scalar_weight(), None)
assert_equal(f.scalar_weight(0), None)
assert_equal(f.scalar_weight((0,)), None)
......@@ -178,7 +178,7 @@ class Test_Functionality(unittest.TestCase):
@expand(product([ift.RGSpace(10), ift.GLSpace(10)],
[np.float64, np.complex128]))
def test_reduction(self, dom, dt):
s1 = ift.Field(dom, 1., dtype=dt)
s1 = ift.Field.full(dom, dt(1.))
assert_allclose(s1.mean(), 1.)
assert_allclose(s1.mean(0), 1.)
assert_allclose(s1.var(), 0., atol=1e-14)
......@@ -189,13 +189,11 @@ class Test_Functionality(unittest.TestCase):
def test_err(self):
s1 = ift.RGSpace((10,))
s2 = ift.RGSpace((11,))
f1 = ift.Field(s1, 27)
f1 = ift.Field.full(s1, 27)
with assert_raises(ValueError):
f2 = ift.Field(s2, f1)
with assert_raises(ValueError):
f2 = ift.Field(s2, f1.val)
f2 = ift.Field(ift.DomainTuple.make(s2), f1.val)
with assert_raises(TypeError):
f2 = ift.Field(s2, "xyz")
f2 = ift.Field.full(s2, "xyz")
with assert_raises(TypeError):
if f1:
pass
......@@ -203,20 +201,20 @@ class Test_Functionality(unittest.TestCase):
f1.full((2, 4, 6))
with assert_raises(TypeError):
f2 = ift.Field(None, None)
with assert_raises(ValueError):
with assert_raises(TypeError):
f2 = ift.Field(s1, None)
with assert_raises(ValueError):
f1.imag
with assert_raises(TypeError):
f1.vdot(42)
with assert_raises(ValueError):
f1.vdot(ift.Field(s2, 1.))
f1.vdot(ift.Field.full(s2, 1.))
with assert_raises(TypeError):
ift.full(s1, [2, 3])
def test_stdfunc(self):
s = ift.RGSpace((200,))
f = ift.Field(s, 27)
f = ift.Field.full(s, 27)
assert_equal(f.local_data, 27)
assert_equal(f.shape, (200,))
assert_equal(f.dtype, np.int)
......
......
......@@ -133,7 +133,7 @@ class Test_Minimizers(unittest.TestCase):
@expand(product(minimizers+slow_minimizers))
def test_gauss(self, minimizer):
space = ift.UnstructuredDomain((1,))
starting_point = ift.Field(space, val=3.)
starting_point = ift.Field.full(space, 3.)
class ExpEnergy(ift.Energy):
def __init__(self, position):
......@@ -147,14 +147,15 @@ class Test_Minimizers(unittest.TestCase):
@property
def gradient(self):
x = self.position.to_global_data()[0]
return ift.Field(self.position.domain, val=2*x*np.exp(-(x**2)))
return ift.Field.full(self.position.domain,
2*x*np.exp(-(x**2)))
@property
def metric(self):
x = self.position.to_global_data()[0]
v = (2 - 4*x*x)*np.exp(-x**2)
return ift.DiagonalOperator(
ift.Field(self.position.domain, val=v))
ift.Field.full(self.position.domain, v))
try:
minimizer = eval(minimizer)
......@@ -171,7 +172,7 @@ class Test_Minimizers(unittest.TestCase):
@expand(product(minimizers+newton_minimizers+slow_minimizers))
def test_cosh(self, minimizer):
space = ift.UnstructuredDomain((1,))
starting_point = ift.Field(space, val=3.)
starting_point = ift.Field.full(space, 3.)
class CoshEnergy(ift.Energy):
def __init__(self, position):
......@@ -185,14 +186,14 @@ class Test_Minimizers(unittest.TestCase):
@property
def gradient(self):
x = self.position.to_global_data()[0]
return ift.Field(self.position.domain, val=np.sinh(x))
return ift.Field.full(self.position.domain, np.sinh(x))
@property
def metric(self):
x = self.position.to_global_data()[0]
v = np.cosh(x)
return ift.DiagonalOperator(
ift.Field(self.position.domain, val=v))
ift.Field.full(self.position.domain, v))
try:
minimizer = eval(minimizer)
......
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment