Commit d8c42e70 authored by Martin Reinecke's avatar Martin Reinecke

make Fields and MultiFields immutable

parent c3c4a8c4
...@@ -104,6 +104,16 @@ class DomainTuple(object): ...@@ -104,6 +104,16 @@ class DomainTuple(object):
""" """
return self._shape return self._shape
@property
def local_shape(self):
"""tuple of int: number of pixels along each axis on the local task
The shape of the array-like object required to store information
living on part of the domain which is stored on the local MPI task.
"""
from .dobj import local_shape
return local_shape(self._shape)
@property @property
def size(self): def size(self):
"""int : total number of pixels. """int : total number of pixels.
......
...@@ -88,6 +88,16 @@ class Domain(NiftyMetaBase()): ...@@ -88,6 +88,16 @@ class Domain(NiftyMetaBase()):
""" """
raise NotImplementedError raise NotImplementedError
@property
def local_shape(self):
"""tuple of int: number of pixels along each axis on the local task
The shape of the array-like object required to store information
living on part of the domain which is stored on the local MPI task.
"""
from ..dobj import local_shape
return local_shape(self.shape)
@abc.abstractproperty @abc.abstractproperty
def size(self): def size(self):
"""int: total number of pixels. """int: total number of pixels.
......
...@@ -3,7 +3,7 @@ from ..sugar import exp ...@@ -3,7 +3,7 @@ from ..sugar import exp
import numpy as np import numpy as np
from ..dobj import ibegin from .. import dobj
from ..field import Field from ..field import Field
from .structured_domain import StructuredDomain from .structured_domain import StructuredDomain
...@@ -62,26 +62,22 @@ class LogRGSpace(StructuredDomain): ...@@ -62,26 +62,22 @@ class LogRGSpace(StructuredDomain):
np.zeros(len(self.shape)), True) np.zeros(len(self.shape)), True)
def get_k_length_array(self): def get_k_length_array(self):
out = Field(self, dtype=np.float64) ib = dobj.ibegin_from_shape(self._shape)
oloc = out.local_data res = np.arange(self.local_shape[0], dtype=np.float64) + ib[0]
ib = ibegin(out.val)
res = np.arange(oloc.shape[0], dtype=np.float64) + ib[0]
res = np.minimum(res, self.shape[0]-res)*self.bindistances[0] res = np.minimum(res, self.shape[0]-res)*self.bindistances[0]
if len(self.shape) == 1: if len(self.shape) == 1:
oloc[()] = res return Field.from_local_data(self, res)
return out
res *= res res *= res
for i in range(1, len(self.shape)): for i in range(1, len(self.shape)):
tmp = np.arange(oloc.shape[i], dtype=np.float64) + ib[i] tmp = np.arange(self.local_shape[i], dtype=np.float64) + ib[i]
tmp = np.minimum(tmp, self.shape[i]-tmp)*self.bindistances[i] tmp = np.minimum(tmp, self.shape[i]-tmp)*self.bindistances[i]
tmp *= tmp tmp *= tmp
res = np.add.outer(res, tmp) res = np.add.outer(res, tmp)
oloc[()] = np.sqrt(res) return Field.from_local_data(self, np.sqrt(res))
return out
def get_expk_length_array(self): def get_expk_length_array(self):
# FIXME This is a hack! Only for plotting. Seems not to be the final version. # FIXME This is a hack! Only for plotting. Seems not to be the final version.
out = exp(self.get_k_length_array()) out = exp(self.get_k_length_array()).to_global_data().copy()
out.val[1:] = out.val[:-1] out[1:] = out[:-1]
out.val[0] = 0 out[0] = 0
return out return Field.from_global_data(self, out)
...@@ -95,22 +95,18 @@ class RGSpace(StructuredDomain): ...@@ -95,22 +95,18 @@ class RGSpace(StructuredDomain):
def get_k_length_array(self): def get_k_length_array(self):
if (not self.harmonic): if (not self.harmonic):
raise NotImplementedError raise NotImplementedError
out = Field(self, dtype=np.float64) ibegin = dobj.ibegin_from_shape(self._shape)
oloc = out.local_data res = np.arange(self.local_shape[0], dtype=np.float64) + ibegin[0]
ibegin = dobj.ibegin(out.val)
res = np.arange(oloc.shape[0], dtype=np.float64) + ibegin[0]
res = np.minimum(res, self.shape[0]-res)*self.distances[0] res = np.minimum(res, self.shape[0]-res)*self.distances[0]
if len(self.shape) == 1: if len(self.shape) == 1:
oloc[()] = res return Field.from_local_data(self, res)
return out
res *= res res *= res
for i in range(1, len(self.shape)): for i in range(1, len(self.shape)):
tmp = np.arange(oloc.shape[i], dtype=np.float64) + ibegin[i] tmp = np.arange(self.local_shape[i], dtype=np.float64) + ibegin[i]
tmp = np.minimum(tmp, self.shape[i]-tmp)*self.distances[i] tmp = np.minimum(tmp, self.shape[i]-tmp)*self.distances[i]
tmp *= tmp tmp *= tmp
res = np.add.outer(res, tmp) res = np.add.outer(res, tmp)
oloc[()] = np.sqrt(res) return Field.from_local_data(self, np.sqrt(res))
return out
def get_unique_k_lengths(self): def get_unique_k_lengths(self):
if (not self.harmonic): if (not self.harmonic):
...@@ -147,8 +143,7 @@ class RGSpace(StructuredDomain): ...@@ -147,8 +143,7 @@ class RGSpace(StructuredDomain):
from ..sugar import exp from ..sugar import exp
tmp = x*x tmp = x*x
tmp *= -2.*np.pi*np.pi*sigma*sigma tmp *= -2.*np.pi*np.pi*sigma*sigma
exp(tmp, out=tmp) return exp(tmp)
return tmp
def get_fft_smoothing_kernel_function(self, sigma): def get_fft_smoothing_kernel_function(self, sigma):
if (not self.harmonic): if (not self.harmonic):
......
...@@ -35,9 +35,9 @@ def adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol): ...@@ -35,9 +35,9 @@ def adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol):
needed_cap = op.TIMES | op.ADJOINT_TIMES needed_cap = op.TIMES | op.ADJOINT_TIMES
if (op.capability & needed_cap) != needed_cap: if (op.capability & needed_cap) != needed_cap:
return return
f1 = from_random("normal", op.domain, dtype=domain_dtype).lock() f1 = from_random("normal", op.domain, dtype=domain_dtype)
f2 = from_random("normal", op.target, dtype=target_dtype).lock() f2 = from_random("normal", op.target, dtype=target_dtype)
res1 = f1.vdot(op.adjoint_times(f2).lock()) res1 = f1.vdot(op.adjoint_times(f2))
res2 = op.times(f1).vdot(f2) res2 = op.times(f1).vdot(f2)
np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol) np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol)
...@@ -46,12 +46,12 @@ def inverse_implementation(op, domain_dtype, target_dtype, atol, rtol): ...@@ -46,12 +46,12 @@ def inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
needed_cap = op.TIMES | op.INVERSE_TIMES needed_cap = op.TIMES | op.INVERSE_TIMES
if (op.capability & needed_cap) != needed_cap: if (op.capability & needed_cap) != needed_cap:
return return
foo = from_random("normal", op.target, dtype=target_dtype).lock() foo = from_random("normal", op.target, dtype=target_dtype)
res = op(op.inverse_times(foo).lock()) res = op(op.inverse_times(foo))
_assert_allclose(res, foo, atol=atol, rtol=rtol) _assert_allclose(res, foo, atol=atol, rtol=rtol)
foo = from_random("normal", op.domain, dtype=domain_dtype).lock() foo = from_random("normal", op.domain, dtype=domain_dtype)
res = op.inverse_times(op(foo).lock()) res = op.inverse_times(op(foo))
_assert_allclose(res, foo, atol=atol, rtol=rtol) _assert_allclose(res, foo, atol=atol, rtol=rtol)
......
This diff is collapsed.
...@@ -66,7 +66,7 @@ class ConjugateGradient(Minimizer): ...@@ -66,7 +66,7 @@ class ConjugateGradient(Minimizer):
return energy, status return energy, status
r = energy.gradient r = energy.gradient
d = r.copy() if preconditioner is None else preconditioner(r) d = r if preconditioner is None else preconditioner(r)
previous_gamma = r.vdot(d).real previous_gamma = r.vdot(d).real
if previous_gamma == 0: if previous_gamma == 0:
......
...@@ -52,7 +52,7 @@ class Energy(NiftyMetaBase()): ...@@ -52,7 +52,7 @@ class Energy(NiftyMetaBase()):
def __init__(self, position): def __init__(self, position):
super(Energy, self).__init__() super(Energy, self).__init__()
self._position = position.lock() self._position = position
def at(self, position): def at(self, position):
""" Returns a new Energy object, initialized at `position`. """ Returns a new Energy object, initialized at `position`.
......
...@@ -63,7 +63,7 @@ class EnergySum(Energy): ...@@ -63,7 +63,7 @@ class EnergySum(Energy):
@memo @memo
def gradient(self): def gradient(self):
return my_lincomb(map(lambda v: v.gradient, self._energies), return my_lincomb(map(lambda v: v.gradient, self._energies),
self._factors).lock() self._factors)
@property @property
@memo @memo
......
...@@ -48,7 +48,7 @@ class LineEnergy(object): ...@@ -48,7 +48,7 @@ class LineEnergy(object):
def __init__(self, line_position, energy, line_direction, offset=0.): def __init__(self, line_position, energy, line_direction, offset=0.):
super(LineEnergy, self).__init__() super(LineEnergy, self).__init__()
self._line_position = float(line_position) self._line_position = float(line_position)
self._line_direction = line_direction.lock() self._line_direction = line_direction
if self._line_position == float(offset): if self._line_position == float(offset):
self._energy = energy self._energy = energy
......
...@@ -35,7 +35,6 @@ class QuadraticEnergy(Energy): ...@@ -35,7 +35,6 @@ class QuadraticEnergy(Energy):
else: else:
Ax = self._A(self.position) Ax = self._A(self.position)
self._grad = Ax if b is None else Ax - b self._grad = Ax if b is None else Ax - b
self._grad.lock()
self._value = 0.5*self.position.vdot(Ax) self._value = 0.5*self.position.vdot(Ax)
if b is not None: if b is not None:
self._value -= b.vdot(self.position) self._value -= b.vdot(self.position)
......
...@@ -33,7 +33,7 @@ def _toFlatNdarray(fld): ...@@ -33,7 +33,7 @@ def _toFlatNdarray(fld):
def _toField(arr, dom): def _toField(arr, dom):
return Field.from_global_data(dom, arr.reshape(dom.shape)) return Field.from_global_data(dom, arr.reshape(dom.shape).copy())
class _MinHelper(object): class _MinHelper(object):
...@@ -44,7 +44,7 @@ class _MinHelper(object): ...@@ -44,7 +44,7 @@ class _MinHelper(object):
def _update(self, x): def _update(self, x):
pos = _toField(x, self._domain) pos = _toField(x, self._domain)
if (pos != self._energy.position).any(): if (pos != self._energy.position).any():
self._energy = self._energy.at(pos.locked_copy()) self._energy = self._energy.at(pos)
def fun(self, x): def fun(self, x):
self._update(x) self._update(x)
......
...@@ -109,8 +109,8 @@ class _InformationStore(object): ...@@ -109,8 +109,8 @@ class _InformationStore(object):
self.max_history_length = max_history_length self.max_history_length = max_history_length
self.s = [None]*max_history_length self.s = [None]*max_history_length
self.y = [None]*max_history_length self.y = [None]*max_history_length
self.last_x = x0.copy() self.last_x = x0
self.last_gradient = gradient.copy() self.last_gradient = gradient
self.k = 0 self.k = 0
mmax = max_history_length mmax = max_history_length
...@@ -233,7 +233,7 @@ class _InformationStore(object): ...@@ -233,7 +233,7 @@ class _InformationStore(object):
self.s[self.k % mmax] = x - self.last_x self.s[self.k % mmax] = x - self.last_x
self.y[self.k % mmax] = gradient - self.last_gradient self.y[self.k % mmax] = gradient - self.last_gradient
self.last_x = x.copy() self.last_x = x
self.last_gradient = gradient.copy() self.last_gradient = gradient
self.k += 1 self.k += 1
...@@ -69,18 +69,6 @@ class MultiField(object): ...@@ -69,18 +69,6 @@ class MultiField(object):
dtype[key], **kwargs) dtype[key], **kwargs)
for key in sorted(domain.keys())}) for key in sorted(domain.keys())})
def fill(self, fill_value):
"""Fill `self` uniformly with `fill_value`
Parameters
----------
fill_value: float or complex or int
The value to fill the field with.
"""
for val in self._val.values():
val.fill(fill_value)
return self
def _check_domain(self, other): def _check_domain(self, other):
if other._domain != self._domain: if other._domain != self._domain:
raise ValueError("domains are incompatible.") raise ValueError("domains are incompatible.")
...@@ -92,27 +80,6 @@ class MultiField(object): ...@@ -92,27 +80,6 @@ class MultiField(object):
result += sub_field.vdot(x[key]) result += sub_field.vdot(x[key])
return result return result
def lock(self):
for v in self.values():
v.lock()
return self
@property
def locked(self):
return all(v.locked for v in self.values())
def copy(self):
return MultiField({key: val.copy() for key, val in self.items()})
def locked_copy(self):
if self.locked:
return self
return MultiField({key: val.locked_copy()
for key, val in self.items()})
def empty_copy(self):
return MultiField({key: val.empty_copy() for key, val in self.items()})
@staticmethod @staticmethod
def build_dtype(dtype, domain): def build_dtype(dtype, domain):
if isinstance(dtype, dict): if isinstance(dtype, dict):
...@@ -121,12 +88,6 @@ class MultiField(object): ...@@ -121,12 +88,6 @@ class MultiField(object):
dtype = np.float64 dtype = np.float64
return {key: dtype for key in domain.keys()} return {key: dtype for key in domain.keys()}
@staticmethod
def empty(domain, dtype=None):
dtype = MultiField.build_dtype(dtype, domain)
return MultiField({key: Field.empty(dom, dtype=dtype[key])
for key, dom in domain.items()})
@staticmethod @staticmethod
def full(domain, val): def full(domain, val):
return MultiField({key: Field.full(dom, val) return MultiField({key: Field.full(dom, val)
...@@ -241,9 +202,9 @@ for op in ["__add__", "__radd__", ...@@ -241,9 +202,9 @@ for op in ["__add__", "__radd__",
result_val[key] = getattr(self[key], op)(other[key]) result_val[key] = getattr(self[key], op)(other[key])
if op in ("__add__", "__radd__"): if op in ("__add__", "__radd__"):
for key in only_self_keys: for key in only_self_keys:
result_val[key] = self[key].copy() result_val[key] = self[key]
for key in only_other_keys: for key in only_other_keys:
result_val[key] = other[key].copy() result_val[key] = other[key]
elif op in ("__mul__", "__rmul__"): elif op in ("__mul__", "__rmul__"):
pass pass
else: else:
......
...@@ -185,7 +185,7 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -185,7 +185,7 @@ class DiagonalOperator(EndomorphicOperator):
res = Field.from_random(random_type="normal", domain=self._domain, res = Field.from_random(random_type="normal", domain=self._domain,
dtype=dtype) dtype=dtype)
if from_inverse: if from_inverse:
res.local_data[()] /= np.sqrt(self._ldiag) res /= np.sqrt(self._ldiag)
else: else:
res.local_data[()] *= np.sqrt(self._ldiag) res *= np.sqrt(self._ldiag)
return res return res
...@@ -120,15 +120,14 @@ class DOFDistributor(LinearOperator): ...@@ -120,15 +120,14 @@ class DOFDistributor(LinearOperator):
return res return res
def _times(self, x): def _times(self, x):
res = Field.empty(self._target, dtype=x.dtype)
if dobj.distaxis(x.val) in x.domain.axes[self._space]: if dobj.distaxis(x.val) in x.domain.axes[self._space]:
arr = x.to_global_data() arr = x.to_global_data()
else: else:
arr = x.local_data arr = x.local_data
arr = arr.reshape(self._hshape) arr = arr.reshape(self._hshape)
oarr = arr[(slice(None), self._dofdex, slice(None))]
return Field.from_local_data(self._target, oarr.reshape(self._target.local_shape))
oarr = res.local_data.reshape(self._pshape) oarr = res.local_data.reshape(self._pshape)
oarr[()] = arr[(slice(None), self._dofdex, slice(None))]
return res
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
......
...@@ -23,6 +23,7 @@ from ..minimization.conjugate_gradient import ConjugateGradient ...@@ -23,6 +23,7 @@ from ..minimization.conjugate_gradient import ConjugateGradient
from ..minimization.iteration_controller import IterationController from ..minimization.iteration_controller import IterationController
from ..minimization.quadratic_energy import QuadraticEnergy from ..minimization.quadratic_energy import QuadraticEnergy
from .endomorphic_operator import EndomorphicOperator from .endomorphic_operator import EndomorphicOperator
from ..sugar import full
class InversionEnabler(EndomorphicOperator): class InversionEnabler(EndomorphicOperator):
...@@ -65,7 +66,7 @@ class InversionEnabler(EndomorphicOperator): ...@@ -65,7 +66,7 @@ class InversionEnabler(EndomorphicOperator):
if self._op.capability & mode: if self._op.capability & mode:
return self._op.apply(x, mode) return self._op.apply(x, mode)
x0 = x.empty_copy().fill(0.) x0 = full(x.domain, 0.)
invmode = self._modeTable[self.INVERSE_BIT][self._ilog[mode]] invmode = self._modeTable[self.INVERSE_BIT][self._ilog[mode]]
invop = self._op._flip_modes(self._ilog[invmode]) invop = self._op._flip_modes(self._ilog[invmode])
prec = self._approximation prec = self._approximation
......
...@@ -22,6 +22,7 @@ from ..field import Field ...@@ -22,6 +22,7 @@ from ..field import Field
from ..multi.multi_field import MultiField from ..multi.multi_field import MultiField
from .endomorphic_operator import EndomorphicOperator from .endomorphic_operator import EndomorphicOperator
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
from ..sugar import full
class ScalingOperator(EndomorphicOperator): class ScalingOperator(EndomorphicOperator):
...@@ -61,9 +62,9 @@ class ScalingOperator(EndomorphicOperator): ...@@ -61,9 +62,9 @@ class ScalingOperator(EndomorphicOperator):
self._check_input(x, mode) self._check_input(x, mode)
if self._factor == 1.: if self._factor == 1.:
return x.copy() return x
if self._factor == 0.: if self._factor == 0.:
return x.empty_copy().fill(0.) return full(self.domain, 0.)
if mode == self.TIMES: if mode == self.TIMES:
return x*self._factor return x*self._factor
......
...@@ -50,10 +50,9 @@ class SelectionOperator(LinearOperator): ...@@ -50,10 +50,9 @@ class SelectionOperator(LinearOperator):
return self.TIMES | self.ADJOINT_TIMES return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode): def apply(self, x, mode):
# FIXME Is the copying necessary?
self._check_input(x, mode) self._check_input(x, mode)
if mode == self.TIMES: if mode == self.TIMES:
return x[self._key].copy() return x[self._key]
else: else:
from ..multi.multi_field import MultiField from ..multi.multi_field import MultiField
return MultiField({self._key: x.copy()}) return MultiField({self._key: x})
...@@ -15,7 +15,7 @@ class SymmetrizingOperator(EndomorphicOperator): ...@@ -15,7 +15,7 @@ class SymmetrizingOperator(EndomorphicOperator):
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
tmp = x.copy().val tmp = x.val.copy()
ax = dobj.distaxis(tmp) ax = dobj.distaxis(tmp)
globshape = tmp.shape globshape = tmp.shape
for i in range(self._ndim): for i in range(self._ndim):
......
...@@ -31,7 +31,7 @@ from .logger import logger ...@@ -31,7 +31,7 @@ from .logger import logger
__all__ = ['PS_field', 'power_analyze', 'create_power_operator', __all__ = ['PS_field', 'power_analyze', 'create_power_operator',
'create_harmonic_smoothing_operator', 'from_random', 'create_harmonic_smoothing_operator', 'from_random',
'full', 'empty', 'from_global_data', 'from_local_data', 'full', 'from_global_data', 'from_local_data',
'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'conjugate', 'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'conjugate',
'get_signal_variance', 'makeOp'] 'get_signal_variance', 'makeOp']
...@@ -203,12 +203,6 @@ def full(domain, val): ...@@ -203,12 +203,6 @@ def full(domain, val):
return Field.full(domain, val) return Field.full(domain, val)
def empty(domain, dtype):
if isinstance(domain, (dict, MultiDomain)):
return MultiField.empty(domain, dtype)
return Field.empty(domain, dtype)
def from_random(random_type, domain, dtype=np.float64, **kwargs): def from_random(random_type, domain, dtype=np.float64, **kwargs):
if isinstance(domain, (dict, MultiDomain)): if isinstance(domain, (dict, MultiDomain)):
return MultiField.from_random(random_type, domain, dtype, **kwargs) return MultiField.from_random(random_type, domain, dtype, **kwargs)
...@@ -248,26 +242,13 @@ _current_module = sys.modules[__name__] ...@@ -248,26 +242,13 @@ _current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "conjugate"]: for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
def func(f): def func(f):
def func2(x, out=None): def func2(x):
if isinstance(x, MultiField): if isinstance(x, MultiField):
if out is not None:
if (not isinstance(out, MultiField) or
x._domain != out._domain):
raise ValueError("Bad 'out' argument")
for key, value in x.items():
func2(value, out=out[key])
return out
return MultiField({key: func2(val) for key, val in x.items()}) return MultiField({key: func2(val) for key, val in x.items()})
elif isinstance(x, Field): elif isinstance(x, Field):
fu = getattr(dobj, f) fu = getattr(dobj, f)
if out is not None: return Field(domain=x._domain, val=fu(x.val))
if not isinstance(out, Field) or x._domain != out._domain:
raise ValueError("Bad 'out' argument")
fu(x.val, out=out.val)