Commit 6492d74b authored by Martin Reinecke's avatar Martin Reinecke

more polishing

parent b7934d79
Pipeline #23368 passed with stage
in 4 minutes and 33 seconds
...@@ -32,7 +32,6 @@ class Field(object): ...@@ -32,7 +32,6 @@ class Field(object):
In NIFTY, Fields are used to store data arrays and carry all the needed In NIFTY, Fields are used to store data arrays and carry all the needed
metainformation (i.e. the domain) for operators to be able to work on them. metainformation (i.e. the domain) for operators to be able to work on them.
In addition, Field has methods to work with power spectra.
Parameters Parameters
---------- ----------
...@@ -59,23 +58,23 @@ class Field(object): ...@@ -59,23 +58,23 @@ class Field(object):
""" """
def __init__(self, domain=None, val=None, dtype=None, copy=False): def __init__(self, domain=None, val=None, dtype=None, copy=False):
self.domain = self._infer_domain(domain=domain, val=val) self._domain = self._infer_domain(domain=domain, val=val)
dtype = self._infer_dtype(dtype=dtype, val=val) dtype = self._infer_dtype(dtype=dtype, val=val)
if isinstance(val, Field): if isinstance(val, Field):
if self.domain != val.domain: if self._domain != val._domain:
raise ValueError("Domain mismatch") raise ValueError("Domain mismatch")
self._val = dobj.from_object(val.val, dtype=dtype, copy=copy) self._val = dobj.from_object(val.val, dtype=dtype, copy=copy)
elif (np.isscalar(val)): elif (np.isscalar(val)):
self._val = dobj.full(self.domain.shape, dtype=dtype, self._val = dobj.full(self._domain.shape, dtype=dtype,
fill_value=val) fill_value=val)
elif isinstance(val, dobj.data_object): elif isinstance(val, dobj.data_object):
if self.domain.shape == val.shape: if self._domain.shape == val.shape:
self._val = dobj.from_object(val, dtype=dtype, copy=copy) self._val = dobj.from_object(val, dtype=dtype, copy=copy)
else: else:
raise ValueError("Shape mismatch") raise ValueError("Shape mismatch")
elif val is None: elif val is None:
self._val = dobj.empty(self.domain.shape, dtype=dtype) self._val = dobj.empty(self._domain.shape, dtype=dtype)
else: else:
raise TypeError("unknown source type") raise TypeError("unknown source type")
...@@ -101,7 +100,7 @@ class Field(object): ...@@ -101,7 +100,7 @@ class Field(object):
def full_like(field, val, dtype=None): def full_like(field, val, dtype=None):
if not isinstance(field, Field): if not isinstance(field, Field):
raise TypeError("field must be of Field type") raise TypeError("field must be of Field type")
return Field.full(field.domain, val, dtype) return Field.full(field._domain, val, dtype)
@staticmethod @staticmethod
def zeros_like(field, dtype=None): def zeros_like(field, dtype=None):
...@@ -109,7 +108,7 @@ class Field(object): ...@@ -109,7 +108,7 @@ class Field(object):
raise TypeError("field must be of Field type") raise TypeError("field must be of Field type")
if dtype is None: if dtype is None:
dtype = field.dtype dtype = field.dtype
return Field.zeros(field.domain, dtype) return Field.zeros(field._domain, dtype)
@staticmethod @staticmethod
def ones_like(field, dtype=None): def ones_like(field, dtype=None):
...@@ -117,7 +116,7 @@ class Field(object): ...@@ -117,7 +116,7 @@ class Field(object):
raise TypeError("field must be of Field type") raise TypeError("field must be of Field type")
if dtype is None: if dtype is None:
dtype = field.dtype dtype = field.dtype
return Field.ones(field.domain, dtype) return Field.ones(field._domain, dtype)
@staticmethod @staticmethod
def empty_like(field, dtype=None): def empty_like(field, dtype=None):
...@@ -125,13 +124,13 @@ class Field(object): ...@@ -125,13 +124,13 @@ class Field(object):
raise TypeError("field must be of Field type") raise TypeError("field must be of Field type")
if dtype is None: if dtype is None:
dtype = field.dtype dtype = field.dtype
return Field.empty(field.domain, dtype) return Field.empty(field._domain, dtype)
@staticmethod @staticmethod
def _infer_domain(domain, val=None): def _infer_domain(domain, val=None):
if domain is None: if domain is None:
if isinstance(val, Field): if isinstance(val, Field):
return val.domain return val._domain
if np.isscalar(val): if np.isscalar(val):
return DomainTuple.make(()) # empty domain tuple return DomainTuple.make(()) # empty domain tuple
raise TypeError("could not infer domain from value") raise TypeError("could not infer domain from value")
...@@ -187,6 +186,10 @@ class Field(object): ...@@ -187,6 +186,10 @@ class Field(object):
def dtype(self): def dtype(self):
return self._val.dtype return self._val.dtype
@property
def domain(self):
return self._domain
@property @property
def shape(self): def shape(self):
""" Returns the total shape of the Field's data array. """ Returns the total shape of the Field's data array.
...@@ -195,7 +198,7 @@ class Field(object): ...@@ -195,7 +198,7 @@ class Field(object):
------- -------
Integer tuple containing the dimensions of the spaces in domain. Integer tuple containing the dimensions of the spaces in domain.
""" """
return self.domain.shape return self._domain.shape
@property @property
def dim(self): def dim(self):
...@@ -208,21 +211,21 @@ class Field(object): ...@@ -208,21 +211,21 @@ class Field(object):
out : int out : int
The dimension of the Field. The dimension of the Field.
""" """
return self.domain.dim return self._domain.dim
@property @property
def real(self): def real(self):
""" The real part of the field (data is not copied).""" """ The real part of the field (data is not copied)."""
if not np.issubdtype(self.dtype, np.complexfloating): if not np.issubdtype(self.dtype, np.complexfloating):
raise ValueError(".real called on a non-complex Field") raise ValueError(".real called on a non-complex Field")
return Field(self.domain, self.val.real) return Field(self._domain, self.val.real)
@property @property
def imag(self): def imag(self):
""" The imaginary part of the field (data is not copied).""" """ The imaginary part of the field (data is not copied)."""
if not np.issubdtype(self.dtype, np.complexfloating): if not np.issubdtype(self.dtype, np.complexfloating):
raise ValueError(".imag called on a non-complex Field") raise ValueError(".imag called on a non-complex Field")
return Field(self.domain, self.val.imag) return Field(self._domain, self.val.imag)
def copy(self): def copy(self):
""" Returns a full copy of the Field. """ Returns a full copy of the Field.
...@@ -238,13 +241,13 @@ class Field(object): ...@@ -238,13 +241,13 @@ class Field(object):
def scalar_weight(self, spaces=None): def scalar_weight(self, spaces=None):
if np.isscalar(spaces): if np.isscalar(spaces):
return self.domain[spaces].scalar_dvol() return self._domain[spaces].scalar_dvol()
if spaces is None: if spaces is None:
spaces = range(len(self.domain)) spaces = range(len(self._domain))
res = 1. res = 1.
for i in spaces: for i in spaces:
tmp = self.domain[i].scalar_dvol() tmp = self._domain[i].scalar_dvol()
if tmp is None: if tmp is None:
return None return None
res *= tmp res *= tmp
...@@ -277,17 +280,17 @@ class Field(object): ...@@ -277,17 +280,17 @@ class Field(object):
if out is not self: if out is not self:
out.copy_content_from(self) out.copy_content_from(self)
spaces = utilities.parse_spaces(spaces, len(self.domain)) spaces = utilities.parse_spaces(spaces, len(self._domain))
fct = 1. fct = 1.
for ind in spaces: for ind in spaces:
wgt = self.domain[ind].dvol() wgt = self._domain[ind].dvol()
if np.isscalar(wgt): if np.isscalar(wgt):
fct *= wgt fct *= wgt
else: else:
new_shape = np.ones(len(self.shape), dtype=np.int) new_shape = np.ones(len(self.shape), dtype=np.int)
new_shape[self.domain.axes[ind][0]: new_shape[self._domain.axes[ind][0]:
self.domain.axes[ind][-1]+1] = wgt.shape self._domain.axes[ind][-1]+1] = wgt.shape
wgt = wgt.reshape(new_shape) wgt = wgt.reshape(new_shape)
if dobj.distaxis(self._val) >= 0 and ind == 0: if dobj.distaxis(self._val) >= 0 and ind == 0:
# we need to distribute the weights along axis 0 # we need to distribute the weights along axis 0
...@@ -321,10 +324,10 @@ class Field(object): ...@@ -321,10 +324,10 @@ class Field(object):
raise ValueError("The dot-partner must be an instance of " + raise ValueError("The dot-partner must be an instance of " +
"the NIFTy field class") "the NIFTy field class")
if x.domain != self.domain: if x._domain != self._domain:
raise ValueError("Domain mismatch") raise ValueError("Domain mismatch")
ndom = len(self.domain) ndom = len(self._domain)
spaces = utilities.parse_spaces(spaces, ndom) spaces = utilities.parse_spaces(spaces, ndom)
if len(spaces) == ndom: if len(spaces) == ndom:
...@@ -359,7 +362,7 @@ class Field(object): ...@@ -359,7 +362,7 @@ class Field(object):
------- -------
The complex conjugated field. The complex conjugated field.
""" """
return Field(self.domain, self.val.conjugate(), self.dtype) return Field(self._domain, self.val.conjugate(), self.dtype)
# ---General unary/contraction methods--- # ---General unary/contraction methods---
...@@ -367,18 +370,18 @@ class Field(object): ...@@ -367,18 +370,18 @@ class Field(object):
return self.copy() return self.copy()
def __neg__(self): def __neg__(self):
return Field(self.domain, -self.val, self.dtype) return Field(self._domain, -self.val, self.dtype)
def __abs__(self): def __abs__(self):
return Field(self.domain, dobj.abs(self.val), self.dtype) return Field(self._domain, dobj.abs(self.val), self.dtype)
def _contraction_helper(self, op, spaces): def _contraction_helper(self, op, spaces):
if spaces is None: if spaces is None:
return getattr(self.val, op)() return getattr(self.val, op)()
spaces = utilities.parse_spaces(spaces, len(self.domain)) spaces = utilities.parse_spaces(spaces, len(self._domain))
axes_list = tuple(self.domain.axes[sp_index] for sp_index in spaces) axes_list = tuple(self._domain.axes[sp_index] for sp_index in spaces)
if len(axes_list) > 0: if len(axes_list) > 0:
axes_list = reduce(lambda x, y: x+y, axes_list) axes_list = reduce(lambda x, y: x+y, axes_list)
...@@ -391,7 +394,7 @@ class Field(object): ...@@ -391,7 +394,7 @@ class Field(object):
return data return data
else: else:
return_domain = tuple(dom return_domain = tuple(dom
for i, dom in enumerate(self.domain) for i, dom in enumerate(self._domain)
if i not in spaces) if i not in spaces)
return Field(domain=return_domain, val=data, copy=False) return Field(domain=return_domain, val=data, copy=False)
...@@ -435,21 +438,21 @@ class Field(object): ...@@ -435,21 +438,21 @@ class Field(object):
def copy_content_from(self, other): def copy_content_from(self, other):
if not isinstance(other, Field): if not isinstance(other, Field):
raise TypeError("argument must be a Field") raise TypeError("argument must be a Field")
if other.domain != self.domain: if other._domain != self._domain:
raise ValueError("domains are incompatible.") raise ValueError("domains are incompatible.")
dobj.local_data(self.val)[()] = dobj.local_data(other.val)[()] dobj.local_data(self.val)[()] = dobj.local_data(other.val)[()]
def _binary_helper(self, other, op): def _binary_helper(self, other, op):
# if other is a field, make sure that the domains match # if other is a field, make sure that the domains match
if isinstance(other, Field): if isinstance(other, Field):
if other.domain != self.domain: if other._domain != self._domain:
raise ValueError("domains are incompatible.") raise ValueError("domains are incompatible.")
tval = getattr(self.val, op)(other.val) tval = getattr(self.val, op)(other.val)
return self if tval is self.val else Field(self.domain, tval) return self if tval is self.val else Field(self._domain, tval)
if np.isscalar(other) or isinstance(other, dobj.data_object): if np.isscalar(other) or isinstance(other, dobj.data_object):
tval = getattr(self.val, op)(other) tval = getattr(self.val, op)(other)
return self if tval is self.val else Field(self.domain, tval) return self if tval is self.val else Field(self._domain, tval)
return NotImplemented return NotImplemented
...@@ -511,7 +514,7 @@ class Field(object): ...@@ -511,7 +514,7 @@ class Field(object):
minmax = [self.min(), self.max()] minmax = [self.min(), self.max()]
mean = self.mean() mean = self.mean()
return "nifty2go.Field instance\n- domain = " + \ return "nifty2go.Field instance\n- domain = " + \
repr(self.domain) + \ repr(self._domain) + \
"\n- val = " + repr(self.val) + \ "\n- val = " + repr(self.val) + \
"\n - min.,max. = " + str(minmax) + \ "\n - min.,max. = " + str(minmax) + \
"\n - mean = " + str(mean) "\n - mean = " + str(mean)
...@@ -523,12 +526,12 @@ def _math_helper(x, function, out): ...@@ -523,12 +526,12 @@ def _math_helper(x, function, out):
if not isinstance(x, Field): if not isinstance(x, Field):
raise TypeError("This function only accepts Field objects.") raise TypeError("This function only accepts Field objects.")
if out is not None: if out is not None:
if not isinstance(out, Field) or x.domain != out.domain: if not isinstance(out, Field) or x._domain != out._domain:
raise ValueError("Bad 'out' argument") raise ValueError("Bad 'out' argument")
function(x.val, out=out.val) function(x.val, out=out.val)
return out return out
else: else:
return Field(domain=x.domain, val=function(x.val)) return Field(domain=x._domain, val=function(x.val))
def sqrt(x, out=None): def sqrt(x, out=None):
......
...@@ -59,6 +59,8 @@ class CriticalPowerEnergy(Energy): ...@@ -59,6 +59,8 @@ class CriticalPowerEnergy(Energy):
self.samples = samples self.samples = samples
self.alpha = float(alpha) self.alpha = float(alpha)
self.q = float(q) self.q = float(q)
self._smoothness_prior = smoothness_prior
self._logarithmic = logarithmic
self.T = SmoothnessOperator(domain=self.position.domain[0], self.T = SmoothnessOperator(domain=self.position.domain[0],
strength=smoothness_prior, strength=smoothness_prior,
logarithmic=logarithmic) logarithmic=logarithmic)
...@@ -93,8 +95,9 @@ class CriticalPowerEnergy(Energy): ...@@ -93,8 +95,9 @@ class CriticalPowerEnergy(Energy):
def at(self, position): def at(self, position):
return self.__class__(position, self.m, D=self.D, alpha=self.alpha, return self.__class__(position, self.m, D=self.D, alpha=self.alpha,
q=self.q, smoothness_prior=self.smoothness_prior, q=self.q,
logarithmic=self.logarithmic, smoothness_prior=self._smoothness_prior,
logarithmic=self._logarithmic,
samples=self.samples, w=self._w, samples=self.samples, w=self._w,
inverter=self._inverter) inverter=self._inverter)
...@@ -111,11 +114,3 @@ class CriticalPowerEnergy(Energy): ...@@ -111,11 +114,3 @@ class CriticalPowerEnergy(Energy):
def curvature(self): def curvature(self):
return CriticalPowerCurvature(theta=self._theta, T=self.T, return CriticalPowerCurvature(theta=self._theta, T=self.T,
inverter=self._inverter) inverter=self._inverter)
@property
def logarithmic(self):
return self.T.logarithmic
@property
def smoothness_prior(self):
return self.T.strength
...@@ -47,6 +47,7 @@ class NonlinearPowerEnergy(Energy): ...@@ -47,6 +47,7 @@ class NonlinearPowerEnergy(Energy):
self.Instrument = Instrument self.Instrument = Instrument
self.nonlinearity = nonlinearity self.nonlinearity = nonlinearity
self.Projection = Projection self.Projection = Projection
self._sigma = sigma
self.power = self.Projection.adjoint_times(exp(0.5*self.position)) self.power = self.Projection.adjoint_times(exp(0.5*self.position))
if sample_list is None: if sample_list is None:
...@@ -62,7 +63,7 @@ class NonlinearPowerEnergy(Energy): ...@@ -62,7 +63,7 @@ class NonlinearPowerEnergy(Energy):
def at(self, position): def at(self, position):
return self.__class__(position, self.d, self.N, self.m, self.D, return self.__class__(position, self.d, self.N, self.m, self.D,
self.FFT, self.Instrument, self.nonlinearity, self.FFT, self.Instrument, self.nonlinearity,
self.Projection, sigma=self.T.strength, self.Projection, sigma=self._sigma,
samples=len(self.sample_list), samples=len(self.sample_list),
sample_list=self.sample_list, sample_list=self.sample_list,
inverter=self.inverter) inverter=self.inverter)
......
...@@ -68,5 +68,4 @@ class WienerFilterCurvature(EndomorphicOperator): ...@@ -68,5 +68,4 @@ class WienerFilterCurvature(EndomorphicOperator):
mock_j = self.R.adjoint_times(self.N.inverse_times(mock_data)) mock_j = self.R.adjoint_times(self.N.inverse_times(mock_data))
mock_m = self.inverse_times(mock_j) mock_m = self.inverse_times(mock_j)
sample = mock_signal - mock_m return mock_signal - mock_m
return sample
...@@ -24,23 +24,26 @@ class ChainOperator(LinearOperator): ...@@ -24,23 +24,26 @@ class ChainOperator(LinearOperator):
super(ChainOperator, self).__init__() super(ChainOperator, self).__init__()
if op2.target != op1.domain: if op2.target != op1.domain:
raise ValueError("domain mismatch") raise ValueError("domain mismatch")
self._op1 = op1 self._capability = op1.capability & op2.capability
self._op2 = op2 op1 = op1._ops if isinstance(op1, ChainOperator) else (op1,)
op2 = op2._ops if isinstance(op2, ChainOperator) else (op2,)
self._ops = op1 + op2
@property @property
def domain(self): def domain(self):
return self._op2.domain return self._ops[-1].domain
@property @property
def target(self): def target(self):
return self._op1.target return self._ops[0].target
@property @property
def capability(self): def capability(self):
return self._op1.capability & self._op2.capability return self._capability
def apply(self, x, mode): def apply(self, x, mode):
self._check_mode(mode) self._check_mode(mode)
if mode == self.TIMES or mode == self.ADJOINT_INVERSE_TIMES: t_ops = self._ops if mode & self._backwards else reversed(self._ops)
return self._op1.apply(self._op2.apply(x, mode), mode) for op in t_ops:
return self._op2.apply(self._op1.apply(x, mode), mode) x = op.apply(x, mode)
return x
from .endomorphic_operator import EndomorphicOperator from .scaling_operator import ScalingOperator
from .fft_operator import FFTOperator from .fft_operator import FFTOperator
from ..utilities import infer_space from ..utilities import infer_space
from .diagonal_operator import DiagonalOperator from .diagonal_operator import DiagonalOperator
from .. import DomainTuple from .. import DomainTuple
class FFTSmoothingOperator(EndomorphicOperator): def FFTSmoothingOperator(domain, sigma, space=None):
def __init__(self, domain, sigma, space=None): sigma = float(sigma)
super(FFTSmoothingOperator, self).__init__() if sigma < 0.:
raise ValueError("sigma must be nonnegative")
dom = DomainTuple.make(domain) if sigma == 0.:
self._sigma = float(sigma) return ScalingOperator(1., domain)
self._space = infer_space(dom, space)
domain = DomainTuple.make(domain)
self._FFT = FFTOperator(dom, space=self._space) space = infer_space(domain, space)
codomain = self._FFT.domain[self._space].get_default_codomain() FFT = FFTOperator(domain, space=space)
kernel = codomain.get_k_length_array() codomain = FFT.domain[space].get_default_codomain()
smoother = codomain.get_fft_smoothing_kernel_function(self._sigma) kernel = codomain.get_k_length_array()
kernel = smoother(kernel) smoother = codomain.get_fft_smoothing_kernel_function(sigma)
ddom = list(dom) kernel = smoother(kernel)
ddom[self._space] = codomain ddom = list(domain)
self._diag = DiagonalOperator(kernel, ddom, self._space) ddom[space] = codomain
diag = DiagonalOperator(kernel, ddom, space)
def apply(self, x, mode): return FFT.adjoint*diag*FFT
self._check_input(x, mode)
if self._sigma == 0:
return x.copy()
return self._FFT.adjoint_times(self._diag(self._FFT(x)))
@property
def domain(self):
return self._FFT.domain
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
...@@ -32,6 +32,12 @@ class LinearOperator(with_metaclass( ...@@ -32,6 +32,12 @@ class LinearOperator(with_metaclass(
_adjointMode = (0, 2, 1, 0, 8, 0, 0, 0, 4) _adjointMode = (0, 2, 1, 0, 8, 0, 0, 0, 4)
_adjointCapability = (0, 2, 1, 3, 8, 10, 9, 11, 4, 6, 5, 7, 12, 14, 13, 15) _adjointCapability = (0, 2, 1, 3, 8, 10, 9, 11, 4, 6, 5, 7, 12, 14, 13, 15)
_addInverse = (0, 5, 10, 15, 5, 5, 15, 15, 10, 15, 10, 15, 15, 15, 15, 15) _addInverse = (0, 5, 10, 15, 5, 5, 15, 15, 10, 15, 10, 15, 15, 15, 15, 15)
_backwards = 6
TIMES = 1
ADJOINT_TIMES = 2
INVERSE_TIMES = 4
ADJOINT_INVERSE_TIMES = 8
INVERSE_ADJOINT_TIMES = 8
def _dom(self, mode): def _dom(self, mode):
return self.domain if (mode & 9) else self.target return self.domain if (mode & 9) else self.target
...@@ -62,26 +68,6 @@ class LinearOperator(with_metaclass( ...@@ -62,26 +68,6 @@ class LinearOperator(with_metaclass(
"""