Commit d8c42e70 authored by Martin Reinecke's avatar Martin Reinecke

make Fields and MultiFields immutable

parent c3c4a8c4
...@@ -104,6 +104,16 @@ class DomainTuple(object): ...@@ -104,6 +104,16 @@ class DomainTuple(object):
""" """
return self._shape return self._shape
@property
def local_shape(self):
"""tuple of int: number of pixels along each axis on the local task
The shape of the array-like object required to store information
living on part of the domain which is stored on the local MPI task.
"""
from .dobj import local_shape
return local_shape(self._shape)
@property @property
def size(self): def size(self):
"""int : total number of pixels. """int : total number of pixels.
......
...@@ -88,6 +88,16 @@ class Domain(NiftyMetaBase()): ...@@ -88,6 +88,16 @@ class Domain(NiftyMetaBase()):
""" """
raise NotImplementedError raise NotImplementedError
@property
def local_shape(self):
"""tuple of int: number of pixels along each axis on the local task
The shape of the array-like object required to store information
living on part of the domain which is stored on the local MPI task.
"""
from ..dobj import local_shape
return local_shape(self.shape)
@abc.abstractproperty @abc.abstractproperty
def size(self): def size(self):
"""int: total number of pixels. """int: total number of pixels.
......
...@@ -3,7 +3,7 @@ from ..sugar import exp ...@@ -3,7 +3,7 @@ from ..sugar import exp
import numpy as np import numpy as np
from ..dobj import ibegin from .. import dobj
from ..field import Field from ..field import Field
from .structured_domain import StructuredDomain from .structured_domain import StructuredDomain
...@@ -62,26 +62,22 @@ class LogRGSpace(StructuredDomain): ...@@ -62,26 +62,22 @@ class LogRGSpace(StructuredDomain):
np.zeros(len(self.shape)), True) np.zeros(len(self.shape)), True)
def get_k_length_array(self): def get_k_length_array(self):
out = Field(self, dtype=np.float64) ib = dobj.ibegin_from_shape(self._shape)
oloc = out.local_data res = np.arange(self.local_shape[0], dtype=np.float64) + ib[0]
ib = ibegin(out.val)
res = np.arange(oloc.shape[0], dtype=np.float64) + ib[0]
res = np.minimum(res, self.shape[0]-res)*self.bindistances[0] res = np.minimum(res, self.shape[0]-res)*self.bindistances[0]
if len(self.shape) == 1: if len(self.shape) == 1:
oloc[()] = res return Field.from_local_data(self, res)
return out
res *= res res *= res
for i in range(1, len(self.shape)): for i in range(1, len(self.shape)):
tmp = np.arange(oloc.shape[i], dtype=np.float64) + ib[i] tmp = np.arange(self.local_shape[i], dtype=np.float64) + ib[i]
tmp = np.minimum(tmp, self.shape[i]-tmp)*self.bindistances[i] tmp = np.minimum(tmp, self.shape[i]-tmp)*self.bindistances[i]
tmp *= tmp tmp *= tmp
res = np.add.outer(res, tmp) res = np.add.outer(res, tmp)
oloc[()] = np.sqrt(res) return Field.from_local_data(self, np.sqrt(res))
return out
def get_expk_length_array(self): def get_expk_length_array(self):
# FIXME This is a hack! Only for plotting. Seems not to be the final version. # FIXME This is a hack! Only for plotting. Seems not to be the final version.
out = exp(self.get_k_length_array()) out = exp(self.get_k_length_array()).to_global_data().copy()
out.val[1:] = out.val[:-1] out[1:] = out[:-1]
out.val[0] = 0 out[0] = 0
return out return Field.from_global_data(self, out)
...@@ -95,22 +95,18 @@ class RGSpace(StructuredDomain): ...@@ -95,22 +95,18 @@ class RGSpace(StructuredDomain):
def get_k_length_array(self): def get_k_length_array(self):
if (not self.harmonic): if (not self.harmonic):
raise NotImplementedError raise NotImplementedError
out = Field(self, dtype=np.float64) ibegin = dobj.ibegin_from_shape(self._shape)
oloc = out.local_data res = np.arange(self.local_shape[0], dtype=np.float64) + ibegin[0]
ibegin = dobj.ibegin(out.val)
res = np.arange(oloc.shape[0], dtype=np.float64) + ibegin[0]
res = np.minimum(res, self.shape[0]-res)*self.distances[0] res = np.minimum(res, self.shape[0]-res)*self.distances[0]
if len(self.shape) == 1: if len(self.shape) == 1:
oloc[()] = res return Field.from_local_data(self, res)
return out
res *= res res *= res
for i in range(1, len(self.shape)): for i in range(1, len(self.shape)):
tmp = np.arange(oloc.shape[i], dtype=np.float64) + ibegin[i] tmp = np.arange(self.local_shape[i], dtype=np.float64) + ibegin[i]
tmp = np.minimum(tmp, self.shape[i]-tmp)*self.distances[i] tmp = np.minimum(tmp, self.shape[i]-tmp)*self.distances[i]
tmp *= tmp tmp *= tmp
res = np.add.outer(res, tmp) res = np.add.outer(res, tmp)
oloc[()] = np.sqrt(res) return Field.from_local_data(self, np.sqrt(res))
return out
def get_unique_k_lengths(self): def get_unique_k_lengths(self):
if (not self.harmonic): if (not self.harmonic):
...@@ -147,8 +143,7 @@ class RGSpace(StructuredDomain): ...@@ -147,8 +143,7 @@ class RGSpace(StructuredDomain):
from ..sugar import exp from ..sugar import exp
tmp = x*x tmp = x*x
tmp *= -2.*np.pi*np.pi*sigma*sigma tmp *= -2.*np.pi*np.pi*sigma*sigma
exp(tmp, out=tmp) return exp(tmp)
return tmp
def get_fft_smoothing_kernel_function(self, sigma): def get_fft_smoothing_kernel_function(self, sigma):
if (not self.harmonic): if (not self.harmonic):
......
...@@ -35,9 +35,9 @@ def adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol): ...@@ -35,9 +35,9 @@ def adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol):
needed_cap = op.TIMES | op.ADJOINT_TIMES needed_cap = op.TIMES | op.ADJOINT_TIMES
if (op.capability & needed_cap) != needed_cap: if (op.capability & needed_cap) != needed_cap:
return return
f1 = from_random("normal", op.domain, dtype=domain_dtype).lock() f1 = from_random("normal", op.domain, dtype=domain_dtype)
f2 = from_random("normal", op.target, dtype=target_dtype).lock() f2 = from_random("normal", op.target, dtype=target_dtype)
res1 = f1.vdot(op.adjoint_times(f2).lock()) res1 = f1.vdot(op.adjoint_times(f2))
res2 = op.times(f1).vdot(f2) res2 = op.times(f1).vdot(f2)
np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol) np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol)
...@@ -46,12 +46,12 @@ def inverse_implementation(op, domain_dtype, target_dtype, atol, rtol): ...@@ -46,12 +46,12 @@ def inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
needed_cap = op.TIMES | op.INVERSE_TIMES needed_cap = op.TIMES | op.INVERSE_TIMES
if (op.capability & needed_cap) != needed_cap: if (op.capability & needed_cap) != needed_cap:
return return
foo = from_random("normal", op.target, dtype=target_dtype).lock() foo = from_random("normal", op.target, dtype=target_dtype)
res = op(op.inverse_times(foo).lock()) res = op(op.inverse_times(foo))
_assert_allclose(res, foo, atol=atol, rtol=rtol) _assert_allclose(res, foo, atol=atol, rtol=rtol)
foo = from_random("normal", op.domain, dtype=domain_dtype).lock() foo = from_random("normal", op.domain, dtype=domain_dtype)
res = op.inverse_times(op(foo).lock()) res = op.inverse_times(op(foo))
_assert_allclose(res, foo, atol=atol, rtol=rtol) _assert_allclose(res, foo, atol=atol, rtol=rtol)
......
...@@ -35,7 +35,7 @@ class Field(object): ...@@ -35,7 +35,7 @@ class Field(object):
---------- ----------
domain : None, DomainTuple, tuple of Domain, or Domain domain : None, DomainTuple, tuple of Domain, or Domain
val : None, Field, data_object, or scalar val : Field, data_object or scalar
The values the array should contain after init. A scalar input will The values the array should contain after init. A scalar input will
fill the whole array with this scalar. If a data_object is provided, fill the whole array with this scalar. If a data_object is provided,
its dimensions must match the domain's. its dimensions must match the domain's.
...@@ -49,31 +49,29 @@ class Field(object): ...@@ -49,31 +49,29 @@ class Field(object):
many convenience functions for Field conatruction! many convenience functions for Field conatruction!
""" """
def __init__(self, domain=None, val=None, dtype=None, copy=False, def __init__(self, domain=None, val=None, dtype=None):
locked=False):
self._domain = self._infer_domain(domain=domain, val=val) self._domain = self._infer_domain(domain=domain, val=val)
dtype = self._infer_dtype(dtype=dtype, val=val) dtype = self._infer_dtype(dtype=dtype, val=val)
if isinstance(val, Field): if isinstance(val, Field):
if self._domain != val._domain: if self._domain != val._domain:
raise ValueError("Domain mismatch") raise ValueError("Domain mismatch")
self._val = dobj.from_object(val.val, dtype=dtype, copy=copy, self._val = val._val
set_locked=locked)
elif (np.isscalar(val)): elif (np.isscalar(val)):
self._val = dobj.full(self._domain.shape, dtype=dtype, self._val = dobj.full(self._domain.shape, dtype=dtype,
fill_value=val) fill_value=val)
elif isinstance(val, dobj.data_object): elif isinstance(val, dobj.data_object):
if self._domain.shape == val.shape: if self._domain.shape == val.shape:
self._val = dobj.from_object(val, dtype=dtype, copy=copy, if dtype == val.dtype:
set_locked=locked) self._val = val
else:
self._val = dobj.from_object(val, dtype, True, True)
else: else:
raise ValueError("Shape mismatch") raise ValueError("Shape mismatch")
elif val is None:
self._val = dobj.empty(self._domain.shape, dtype=dtype)
else: else:
raise TypeError("unknown source type") raise TypeError("unknown source type")
if locked:
dobj.lock(self._val) dobj.lock(self._val)
# prevent implicit conversion to bool # prevent implicit conversion to bool
...@@ -84,7 +82,7 @@ class Field(object): ...@@ -84,7 +82,7 @@ class Field(object):
raise TypeError("Field does not support implicit conversion to bool") raise TypeError("Field does not support implicit conversion to bool")
@staticmethod @staticmethod
def full(domain, val, dtype=None): def full(domain, val):
"""Creates a Field with a given domain, filled with a constant value. """Creates a Field with a given domain, filled with a constant value.
Parameters Parameters
...@@ -101,11 +99,7 @@ class Field(object): ...@@ -101,11 +99,7 @@ class Field(object):
""" """
if not np.isscalar(val): if not np.isscalar(val):
raise TypeError("val must be a scalar") raise TypeError("val must be a scalar")
return Field(DomainTuple.make(domain), val, dtype) return Field(DomainTuple.make(domain), val)
@staticmethod
def empty(domain, dtype=None):
return Field(DomainTuple.make(domain), None, dtype)
@staticmethod @staticmethod
def from_global_data(domain, arr, sum_up=False): def from_global_data(domain, arr, sum_up=False):
...@@ -152,11 +146,6 @@ class Field(object): ...@@ -152,11 +146,6 @@ class Field(object):
Returns a handle to the part of the array data residing on the local Returns a handle to the part of the array data residing on the local
task (or to the entore array if MPI is not active). task (or to the entore array if MPI is not active).
Notes
-----
If the field is not locked, the array data can be modified.
Use with care!
""" """
return dobj.local_data(self._val) return dobj.local_data(self._val)
...@@ -196,8 +185,6 @@ class Field(object): ...@@ -196,8 +185,6 @@ class Field(object):
return dtype return dtype
if val is None: if val is None:
raise ValueError("could not infer dtype") raise ValueError("could not infer dtype")
if isinstance(val, Field):
return val.dtype
return np.result_type(val) return np.result_type(val)
@staticmethod @staticmethod
...@@ -223,41 +210,6 @@ class Field(object): ...@@ -223,41 +210,6 @@ class Field(object):
val=dobj.from_random(random_type, dtype=dtype, val=dobj.from_random(random_type, dtype=dtype,
shape=domain.shape, **kwargs)) shape=domain.shape, **kwargs))
def fill(self, fill_value):
"""Fill `self` uniformly with `fill_value`
Parameters
----------
fill_value: float or complex or int
The value to fill the field with.
"""
self._val.fill(fill_value)
return self
def lock(self):
"""Write-protect the data content of `self`.
After this call, it will no longer be possible to change the data
entries of `self`. This is convenient if, for example, a
DiagonalOperator wants to ensure that its diagonal cannot be modified
inadvertently, without making copies.
Notes
-----
This will not only prohibit modifications to the entries of `self`, but
also to the entries of any other Field or numpy array pointing to the
same data. If an unlocked instance is needed, use copy().
The fact that there is no `unlock()` method is deliberate.
"""
dobj.lock(self._val)
return self
@property
def locked(self):
"""bool : True iff the field's data content has been locked"""
return dobj.locked(self._val)
@property @property
def val(self): def val(self):
"""dobj.data_object : the data object storing the field's entries """dobj.data_object : the data object storing the field's entries
...@@ -303,43 +255,6 @@ class Field(object): ...@@ -303,43 +255,6 @@ class Field(object):
raise ValueError(".imag called on a non-complex Field") raise ValueError(".imag called on a non-complex Field")
return Field(self._domain, self.val.imag) return Field(self._domain, self.val.imag)
def copy(self):
""" Returns a full copy of the Field.
The returned object will be an identical copy of the original Field.
The copy will be writeable, even if `self` was locked.
Returns
-------
Field
An identical, but unlocked copy of 'self'.
"""
return Field(val=self, copy=True)
def empty_copy(self):
""" Returns a Field with identical domain and data type, but
uninitialized data.
Returns
-------
Field
A copy of 'self', with uninitialized data.
"""
return Field(self._domain, dtype=self.dtype)
def locked_copy(self):
""" Returns a read-only version of the Field.
If `self` is locked, returns `self`. Otherwise returns a locked copy
of `self`.
Returns
-------
Field
A read-only version of `self`.
"""
return self if self.locked else Field(val=self, copy=True, locked=True)
def scalar_weight(self, spaces=None): def scalar_weight(self, spaces=None):
"""Returns the uniform volume element for a sub-domain of `self`. """Returns the uniform volume element for a sub-domain of `self`.
...@@ -392,7 +307,7 @@ class Field(object): ...@@ -392,7 +307,7 @@ class Field(object):
res *= self._domain[i].total_volume res *= self._domain[i].total_volume
return res return res
def weight(self, power=1, spaces=None, out=None): def weight(self, power=1, spaces=None):
""" Weights the pixels of `self` with their invidual pixel-volume. """ Weights the pixels of `self` with their invidual pixel-volume.
Parameters Parameters
...@@ -404,21 +319,12 @@ class Field(object): ...@@ -404,21 +319,12 @@ class Field(object):
Determines on which sub-domain the operation takes place. Determines on which sub-domain the operation takes place.
If None, the entire domain is used. If None, the entire domain is used.
out : Field or None
if not None, the result is returned in a new Field
otherwise the contents of "out" are overwritten with the result.
"out" may be identical to "self"!
Returns Returns
------- -------
Field Field
The weighted field. The weighted field.
""" """
if out is None: aout = self.local_data.copy()
out = self.copy()
else:
if out is not self:
out.copy_content_from(self)
spaces = utilities.parse_spaces(spaces, len(self._domain)) spaces = utilities.parse_spaces(spaces, len(self._domain))
...@@ -435,12 +341,12 @@ class Field(object): ...@@ -435,12 +341,12 @@ class Field(object):
if dobj.distaxis(self._val) >= 0 and ind == 0: if dobj.distaxis(self._val) >= 0 and ind == 0:
# we need to distribute the weights along axis 0 # we need to distribute the weights along axis 0
wgt = dobj.local_data(dobj.from_global_data(wgt)) wgt = dobj.local_data(dobj.from_global_data(wgt))
out.local_data[()] *= wgt**power aout *= wgt**power
fct = fct**power fct = fct**power
if fct != 1.: if fct != 1.:
out *= fct aout *= fct
return out return Field.from_local_data(self._domain, aout)
def vdot(self, x=None, spaces=None): def vdot(self, x=None, spaces=None):
""" Computes the dot product of 'self' with x. """ Computes the dot product of 'self' with x.
...@@ -508,7 +414,7 @@ class Field(object): ...@@ -508,7 +414,7 @@ class Field(object):
# ---General unary/contraction methods--- # ---General unary/contraction methods---
def __pos__(self): def __pos__(self):
return self.copy() return self
def __neg__(self): def __neg__(self):
return Field(self._domain, -self.val) return Field(self._domain, -self.val)
...@@ -538,7 +444,7 @@ class Field(object): ...@@ -538,7 +444,7 @@ class Field(object):
for i, dom in enumerate(self._domain) for i, dom in enumerate(self._domain)
if i not in spaces) if i not in spaces)
return Field(domain=return_domain, val=data, copy=False) return Field(domain=return_domain, val=data)
def sum(self, spaces=None): def sum(self, spaces=None):
"""Sums up over the sub-domains given by `spaces`. """Sums up over the sub-domains given by `spaces`.
...@@ -713,13 +619,6 @@ class Field(object): ...@@ -713,13 +619,6 @@ class Field(object):
return self._contraction_helper('std', spaces) return self._contraction_helper('std', spaces)
return sqrt(self.var(spaces)) return sqrt(self.var(spaces))
def copy_content_from(self, other):
if not isinstance(other, Field):
raise TypeError("argument must be a Field")
if other._domain != self._domain:
raise ValueError("domains are incompatible.")
self.local_data[()] = other.local_data[()]
def __repr__(self): def __repr__(self):
return "<nifty5.Field>" return "<nifty5.Field>"
...@@ -745,13 +644,13 @@ class Field(object): ...@@ -745,13 +644,13 @@ class Field(object):
return self.isEquivalentTo(other) return self.isEquivalentTo(other)
for op in ["__add__", "__radd__", "__iadd__", for op in ["__add__", "__radd__",
"__sub__", "__rsub__", "__isub__", "__sub__", "__rsub__",
"__mul__", "__rmul__", "__imul__", "__mul__", "__rmul__",
"__div__", "__rdiv__", "__idiv__", "__div__", "__rdiv__",
"__truediv__", "__rtruediv__", "__itruediv__", "__truediv__", "__rtruediv__",
"__floordiv__", "__rfloordiv__", "__ifloordiv__", "__floordiv__", "__rfloordiv__",
"__pow__", "__rpow__", "__ipow__", "__pow__", "__rpow__",
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]: "__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]:
def func(op): def func(op):
def func2(self, other): def func2(self, other):
...@@ -761,11 +660,11 @@ for op in ["__add__", "__radd__", "__iadd__", ...@@ -761,11 +660,11 @@ for op in ["__add__", "__radd__", "__iadd__",
if other._domain != self._domain: if other._domain != self._domain:
raise ValueError("domains are incompatible.") raise ValueError("domains are incompatible.")
tval = getattr(self.val, op)(other.val) tval = getattr(self.val, op)(other.val)
return self if tval is self.val else Field(self._domain, tval) return Field(self._domain, tval)
if np.isscalar(other) or isinstance(other, dobj.data_object): if np.isscalar(other) or isinstance(other, dobj.data_object):
tval = getattr(self.val, op)(other) tval = getattr(self.val, op)(other)
return self if tval is self.val else Field(self._domain, tval) return Field(self._domain, tval)
return NotImplemented return NotImplemented
return func2 return func2
......
...@@ -66,7 +66,7 @@ class ConjugateGradient(Minimizer): ...@@ -66,7 +66,7 @@ class ConjugateGradient(Minimizer):
return energy, status return energy, status
r = energy.gradient r = energy.gradient
d = r.copy() if preconditioner is None else preconditioner(r) d = r if preconditioner is None else preconditioner(r)
previous_gamma = r.vdot(d).real previous_gamma = r.vdot(d).real
if previous_gamma == 0: if previous_gamma == 0:
......
...@@ -52,7 +52,7 @@ class Energy(NiftyMetaBase()): ...@@ -52,7 +52,7 @@ class Energy(NiftyMetaBase()):
def __init__(self, position): def __init__(self, position):
super(Energy, self).__init__() super(Energy, self).__init__()
self._position = position.lock() self._position = position
def at(self, position): def at(self, position):
""" Returns a new Energy object, initialized at `position`. """ Returns a new Energy object, initialized at `position`.
......
...@@ -63,7 +63,7 @@ class EnergySum(Energy): ...@@ -63,7 +63,7 @@ class EnergySum(Energy):
@memo @memo
def gradient(self): def gradient(self):
return my_lincomb(map(lambda v: v.gradient, self._energies), return my_lincomb(map(lambda v: v.gradient, self._energies),
self._factors).lock() self._factors)
@property @property
@memo @memo
......
...@@ -48,7 +48,7 @@ class LineEnergy(object): ...@@ -48,7 +48,7 @@ class LineEnergy(object):
def __init__(self, line_position, energy, line_direction, offset=0.): def __init__(self, line_position, energy, line_direction, offset=0.):
super(LineEnergy, self).__init__() super(LineEnergy, self).__init__()
self._line_position = float(line_position)