From 2ec3da9a4bedcd7076bd4792c5776323caf979f4 Mon Sep 17 00:00:00 2001 From: Martin Reinecke Date: Fri, 6 Jul 2018 11:56:18 +0200 Subject: [PATCH] optimize the Field constructor --- nifty5/domain_tuple.py | 4 +- nifty5/field.py | 81 +++++++---------------- nifty5/sugar.py | 4 +- test/test_field.py | 28 ++++---- test/test_minimization/test_minimizers.py | 13 ++-- 5 files changed, 49 insertions(+), 81 deletions(-) diff --git a/nifty5/domain_tuple.py b/nifty5/domain_tuple.py index c4465395..ea446d8e 100644 --- a/nifty5/domain_tuple.py +++ b/nifty5/domain_tuple.py @@ -25,8 +25,8 @@ from .domains.domain import Domain class DomainTuple(object): """Ordered sequence of Domain objects. - This class holds a set of :class:`Domain` objects, which together form the - space on which a :class:`Field` is defined. + This class holds a tuple of :class:`Domain` objects, which together form + the space on which a :class:`Field` is defined. Notes ----- diff --git a/nifty5/field.py b/nifty5/field.py index 62476263..f827ddce 100644 --- a/nifty5/field.py +++ b/nifty5/field.py @@ -33,45 +33,28 @@ class Field(object): Parameters ---------- - domain : None, DomainTuple, tuple of Domain, or Domain + domain : DomainTuple + the domain of the new Field - val : Field, data_object or scalar - The values the array should contain after init. A scalar input will - fill the whole array with this scalar. If a data_object is provided, - its dimensions must match the domain's. - - dtype : type - A numpy.type. Most common are float and complex. + val : data_object + This object's global shape must match the domain shape + After construction, the object will no longer be writeable! Notes ----- If possible, do not invoke the constructor directly, but use one of the - many convenience functions for Field conatruction! + many convenience functions for Field construction! """ - def __init__(self, domain=None, val=None, dtype=None): - self._domain = self._infer_domain(domain=domain, val=val) - - dtype = self._infer_dtype(dtype=dtype, val=val) - if isinstance(val, Field): - if self._domain != val._domain: - raise ValueError("Domain mismatch") - self._val = val._val - - elif (np.isscalar(val)): - self._val = dobj.full(self._domain.shape, dtype=dtype, - fill_value=val) - elif isinstance(val, dobj.data_object): - if self._domain.shape == val.shape: - if dtype == val.dtype: - self._val = val - else: - self._val = dobj.from_object(val, dtype, True, True) - else: - raise ValueError("Shape mismatch") - else: - raise TypeError("unknown source type") - + def __init__(self, domain, val): + if not isinstance(domain, DomainTuple): + raise TypeError("domain must be of type DomainTuple") + if not isinstance(val, dobj.data_object): + raise TypeError("val must be of type dobj.data_object") + if domain.shape != val.shape: + raise ValueError("mismatch between the shapes of val and domain") + self._domain = domain + self._val = val dobj.lock(self._val) # prevent implicit conversion to bool @@ -99,7 +82,10 @@ class Field(object): """ if not np.isscalar(val): raise TypeError("val must be a scalar") - return Field(DomainTuple.make(domain), val) + if not (np.isreal(val) or np.iscomplex(val)): + raise TypeError("need arithmetic scalar") + domain = DomainTuple.make(domain) + return Field(domain, dobj.full(domain.shape, fill_value=val)) @staticmethod def from_global_data(domain, arr, sum_up=False): @@ -118,12 +104,13 @@ class Field(object): If False, the contens of `arr` are used directly, and must be identical on all MPI tasks. """ - return Field(domain, dobj.from_global_data(arr, sum_up)) + return Field(DomainTuple.make(domain), + dobj.from_global_data(arr, sum_up)) @staticmethod def from_local_data(domain, arr): - domain = DomainTuple.make(domain) - return Field(domain, dobj.from_local_data(domain.shape, arr)) + return Field(DomainTuple.make(domain), + dobj.from_local_data(domain.shape, arr)) def to_global_data(self): """Returns an array containing the full data of the field. @@ -167,25 +154,7 @@ class Field(object): ----- No copy is made. If needed, use an additional copy() invocation. """ - return Field(new_domain, self._val) - - @staticmethod - def _infer_domain(domain, val=None): - if domain is None: - if isinstance(val, Field): - return val._domain - if np.isscalar(val): - return DomainTuple.make(()) # empty domain tuple - raise TypeError("could not infer domain from value") - return DomainTuple.make(domain) - - @staticmethod - def _infer_dtype(dtype, val): - if dtype is not None: - return dtype - if val is None: - raise ValueError("could not infer dtype") - return np.result_type(val) + return Field(DomainTuple.make(new_domain), self._val) @staticmethod def from_random(random_type, domain, dtype=np.float64, **kwargs): @@ -444,7 +413,7 @@ class Field(object): for i, dom in enumerate(self._domain) if i not in spaces) - return Field(domain=return_domain, val=data) + return Field(DomainTuple.make(return_domain), data) def sum(self, spaces=None): """Sums up over the sub-domains given by `spaces`. diff --git a/nifty5/sugar.py b/nifty5/sugar.py index b5069e2b..0dd149ac 100644 --- a/nifty5/sugar.py +++ b/nifty5/sugar.py @@ -40,7 +40,7 @@ def PS_field(pspace, func): if not isinstance(pspace, PowerSpace): raise TypeError data = dobj.from_global_data(func(pspace.k_lengths)) - return Field(pspace, val=data) + return Field(DomainTuple.make(pspace), data) def get_signal_variance(spec, space): @@ -158,7 +158,7 @@ def _create_power_field(domain, power_spectrum): if not isinstance(power_spectrum.domain[0], PowerSpace): raise TypeError("PowerSpace required") power_domain = power_spectrum.domain[0] - fp = Field(power_domain, val=power_spectrum.val) + fp = power_spectrum else: power_domain = PowerSpace(domain) fp = PS_field(power_domain, power_spectrum) diff --git a/test/test_field.py b/test/test_field.py index c5687e2b..42a57342 100644 --- a/test/test_field.py +++ b/test/test_field.py @@ -139,16 +139,16 @@ class Test_Functionality(unittest.TestCase): assert_equal(d, d2) def test_empty_domain(self): - f = ift.Field((), 5) + f = ift.Field.full((), 5) assert_equal(f.local_data, 5) - f = ift.Field(None, 5) + f = ift.Field.full(None, 5) assert_equal(f.local_data, 5) def test_trivialities(self): s1 = ift.RGSpace((10,)) - f1 = ift.Field(s1, 27) + f1 = ift.Field.full(s1, 27) assert_equal(f1.local_data, f1.real.local_data) - f1 = ift.Field(s1, 27.+3j) + f1 = ift.Field.full(s1, 27.+3j) assert_equal(f1.real.local_data, 27.) assert_equal(f1.imag.local_data, 3.) assert_equal(f1.local_data, +f1.local_data) @@ -160,7 +160,7 @@ class Test_Functionality(unittest.TestCase): def test_weight(self): s1 = ift.RGSpace((10,)) - f = ift.Field(s1, 10.) + f = ift.Field.full(s1, 10.) f2 = f.weight(1) assert_equal(f.weight(1).local_data, f2.local_data) assert_equal(f.total_volume(), 1) @@ -170,7 +170,7 @@ class Test_Functionality(unittest.TestCase): assert_equal(f.scalar_weight(0), 0.1) assert_equal(f.scalar_weight((0,)), 0.1) s1 = ift.GLSpace(10) - f = ift.Field(s1, 10.) + f = ift.Field.full(s1, 10.) assert_equal(f.scalar_weight(), None) assert_equal(f.scalar_weight(0), None) assert_equal(f.scalar_weight((0,)), None) @@ -178,7 +178,7 @@ class Test_Functionality(unittest.TestCase): @expand(product([ift.RGSpace(10), ift.GLSpace(10)], [np.float64, np.complex128])) def test_reduction(self, dom, dt): - s1 = ift.Field(dom, 1., dtype=dt) + s1 = ift.Field.full(dom, dt(1.)) assert_allclose(s1.mean(), 1.) assert_allclose(s1.mean(0), 1.) assert_allclose(s1.var(), 0., atol=1e-14) @@ -189,13 +189,11 @@ class Test_Functionality(unittest.TestCase): def test_err(self): s1 = ift.RGSpace((10,)) s2 = ift.RGSpace((11,)) - f1 = ift.Field(s1, 27) + f1 = ift.Field.full(s1, 27) with assert_raises(ValueError): - f2 = ift.Field(s2, f1) - with assert_raises(ValueError): - f2 = ift.Field(s2, f1.val) + f2 = ift.Field(ift.DomainTuple.make(s2), f1.val) with assert_raises(TypeError): - f2 = ift.Field(s2, "xyz") + f2 = ift.Field.full(s2, "xyz") with assert_raises(TypeError): if f1: pass @@ -203,20 +201,20 @@ class Test_Functionality(unittest.TestCase): f1.full((2, 4, 6)) with assert_raises(TypeError): f2 = ift.Field(None, None) - with assert_raises(ValueError): + with assert_raises(TypeError): f2 = ift.Field(s1, None) with assert_raises(ValueError): f1.imag with assert_raises(TypeError): f1.vdot(42) with assert_raises(ValueError): - f1.vdot(ift.Field(s2, 1.)) + f1.vdot(ift.Field.full(s2, 1.)) with assert_raises(TypeError): ift.full(s1, [2, 3]) def test_stdfunc(self): s = ift.RGSpace((200,)) - f = ift.Field(s, 27) + f = ift.Field.full(s, 27) assert_equal(f.local_data, 27) assert_equal(f.shape, (200,)) assert_equal(f.dtype, np.int) diff --git a/test/test_minimization/test_minimizers.py b/test/test_minimization/test_minimizers.py index 9685b871..8607a8b9 100644 --- a/test/test_minimization/test_minimizers.py +++ b/test/test_minimization/test_minimizers.py @@ -133,7 +133,7 @@ class Test_Minimizers(unittest.TestCase): @expand(product(minimizers+slow_minimizers)) def test_gauss(self, minimizer): space = ift.UnstructuredDomain((1,)) - starting_point = ift.Field(space, val=3.) + starting_point = ift.Field.full(space, 3.) class ExpEnergy(ift.Energy): def __init__(self, position): @@ -147,14 +147,15 @@ class Test_Minimizers(unittest.TestCase): @property def gradient(self): x = self.position.to_global_data()[0] - return ift.Field(self.position.domain, val=2*x*np.exp(-(x**2))) + return ift.Field.full(self.position.domain, + 2*x*np.exp(-(x**2))) @property def metric(self): x = self.position.to_global_data()[0] v = (2 - 4*x*x)*np.exp(-x**2) return ift.DiagonalOperator( - ift.Field(self.position.domain, val=v)) + ift.Field.full(self.position.domain, v)) try: minimizer = eval(minimizer) @@ -171,7 +172,7 @@ class Test_Minimizers(unittest.TestCase): @expand(product(minimizers+newton_minimizers+slow_minimizers)) def test_cosh(self, minimizer): space = ift.UnstructuredDomain((1,)) - starting_point = ift.Field(space, val=3.) + starting_point = ift.Field.full(space, 3.) class CoshEnergy(ift.Energy): def __init__(self, position): @@ -185,14 +186,14 @@ class Test_Minimizers(unittest.TestCase): @property def gradient(self): x = self.position.to_global_data()[0] - return ift.Field(self.position.domain, val=np.sinh(x)) + return ift.Field.full(self.position.domain, np.sinh(x)) @property def metric(self): x = self.position.to_global_data()[0] v = np.cosh(x) return ift.DiagonalOperator( - ift.Field(self.position.domain, val=v)) + ift.Field.full(self.position.domain, v)) try: minimizer = eval(minimizer) -- GitLab