Commit 6492d74b authored by Martin Reinecke's avatar Martin Reinecke

more polishing

parent b7934d79
Pipeline #23368 passed with stage
in 4 minutes and 33 seconds
......@@ -32,7 +32,6 @@ class Field(object):
In NIFTY, Fields are used to store data arrays and carry all the needed
metainformation (i.e. the domain) for operators to be able to work on them.
In addition, Field has methods to work with power spectra.
Parameters
----------
......@@ -59,23 +58,23 @@ class Field(object):
"""
def __init__(self, domain=None, val=None, dtype=None, copy=False):
self.domain = self._infer_domain(domain=domain, val=val)
self._domain = self._infer_domain(domain=domain, val=val)
dtype = self._infer_dtype(dtype=dtype, val=val)
if isinstance(val, Field):
if self.domain != val.domain:
if self._domain != val._domain:
raise ValueError("Domain mismatch")
self._val = dobj.from_object(val.val, dtype=dtype, copy=copy)
elif (np.isscalar(val)):
self._val = dobj.full(self.domain.shape, dtype=dtype,
self._val = dobj.full(self._domain.shape, dtype=dtype,
fill_value=val)
elif isinstance(val, dobj.data_object):
if self.domain.shape == val.shape:
if self._domain.shape == val.shape:
self._val = dobj.from_object(val, dtype=dtype, copy=copy)
else:
raise ValueError("Shape mismatch")
elif val is None:
self._val = dobj.empty(self.domain.shape, dtype=dtype)
self._val = dobj.empty(self._domain.shape, dtype=dtype)
else:
raise TypeError("unknown source type")
......@@ -101,7 +100,7 @@ class Field(object):
def full_like(field, val, dtype=None):
if not isinstance(field, Field):
raise TypeError("field must be of Field type")
return Field.full(field.domain, val, dtype)
return Field.full(field._domain, val, dtype)
@staticmethod
def zeros_like(field, dtype=None):
......@@ -109,7 +108,7 @@ class Field(object):
raise TypeError("field must be of Field type")
if dtype is None:
dtype = field.dtype
return Field.zeros(field.domain, dtype)
return Field.zeros(field._domain, dtype)
@staticmethod
def ones_like(field, dtype=None):
......@@ -117,7 +116,7 @@ class Field(object):
raise TypeError("field must be of Field type")
if dtype is None:
dtype = field.dtype
return Field.ones(field.domain, dtype)
return Field.ones(field._domain, dtype)
@staticmethod
def empty_like(field, dtype=None):
......@@ -125,13 +124,13 @@ class Field(object):
raise TypeError("field must be of Field type")
if dtype is None:
dtype = field.dtype
return Field.empty(field.domain, dtype)
return Field.empty(field._domain, dtype)
@staticmethod
def _infer_domain(domain, val=None):
if domain is None:
if isinstance(val, Field):
return val.domain
return val._domain
if np.isscalar(val):
return DomainTuple.make(()) # empty domain tuple
raise TypeError("could not infer domain from value")
......@@ -187,6 +186,10 @@ class Field(object):
def dtype(self):
return self._val.dtype
@property
def domain(self):
return self._domain
@property
def shape(self):
""" Returns the total shape of the Field's data array.
......@@ -195,7 +198,7 @@ class Field(object):
-------
Integer tuple containing the dimensions of the spaces in domain.
"""
return self.domain.shape
return self._domain.shape
@property
def dim(self):
......@@ -208,21 +211,21 @@ class Field(object):
out : int
The dimension of the Field.
"""
return self.domain.dim
return self._domain.dim
@property
def real(self):
""" The real part of the field (data is not copied)."""
if not np.issubdtype(self.dtype, np.complexfloating):
raise ValueError(".real called on a non-complex Field")
return Field(self.domain, self.val.real)
return Field(self._domain, self.val.real)
@property
def imag(self):
""" The imaginary part of the field (data is not copied)."""
if not np.issubdtype(self.dtype, np.complexfloating):
raise ValueError(".imag called on a non-complex Field")
return Field(self.domain, self.val.imag)
return Field(self._domain, self.val.imag)
def copy(self):
""" Returns a full copy of the Field.
......@@ -238,13 +241,13 @@ class Field(object):
def scalar_weight(self, spaces=None):
if np.isscalar(spaces):
return self.domain[spaces].scalar_dvol()
return self._domain[spaces].scalar_dvol()
if spaces is None:
spaces = range(len(self.domain))
spaces = range(len(self._domain))
res = 1.
for i in spaces:
tmp = self.domain[i].scalar_dvol()
tmp = self._domain[i].scalar_dvol()
if tmp is None:
return None
res *= tmp
......@@ -277,17 +280,17 @@ class Field(object):
if out is not self:
out.copy_content_from(self)
spaces = utilities.parse_spaces(spaces, len(self.domain))
spaces = utilities.parse_spaces(spaces, len(self._domain))
fct = 1.
for ind in spaces:
wgt = self.domain[ind].dvol()
wgt = self._domain[ind].dvol()
if np.isscalar(wgt):
fct *= wgt
else:
new_shape = np.ones(len(self.shape), dtype=np.int)
new_shape[self.domain.axes[ind][0]:
self.domain.axes[ind][-1]+1] = wgt.shape
new_shape[self._domain.axes[ind][0]:
self._domain.axes[ind][-1]+1] = wgt.shape
wgt = wgt.reshape(new_shape)
if dobj.distaxis(self._val) >= 0 and ind == 0:
# we need to distribute the weights along axis 0
......@@ -321,10 +324,10 @@ class Field(object):
raise ValueError("The dot-partner must be an instance of " +
"the NIFTy field class")
if x.domain != self.domain:
if x._domain != self._domain:
raise ValueError("Domain mismatch")
ndom = len(self.domain)
ndom = len(self._domain)
spaces = utilities.parse_spaces(spaces, ndom)
if len(spaces) == ndom:
......@@ -359,7 +362,7 @@ class Field(object):
-------
The complex conjugated field.
"""
return Field(self.domain, self.val.conjugate(), self.dtype)
return Field(self._domain, self.val.conjugate(), self.dtype)
# ---General unary/contraction methods---
......@@ -367,18 +370,18 @@ class Field(object):
return self.copy()
def __neg__(self):
return Field(self.domain, -self.val, self.dtype)
return Field(self._domain, -self.val, self.dtype)
def __abs__(self):
return Field(self.domain, dobj.abs(self.val), self.dtype)
return Field(self._domain, dobj.abs(self.val), self.dtype)
def _contraction_helper(self, op, spaces):
if spaces is None:
return getattr(self.val, op)()
spaces = utilities.parse_spaces(spaces, len(self.domain))
spaces = utilities.parse_spaces(spaces, len(self._domain))
axes_list = tuple(self.domain.axes[sp_index] for sp_index in spaces)
axes_list = tuple(self._domain.axes[sp_index] for sp_index in spaces)
if len(axes_list) > 0:
axes_list = reduce(lambda x, y: x+y, axes_list)
......@@ -391,7 +394,7 @@ class Field(object):
return data
else:
return_domain = tuple(dom
for i, dom in enumerate(self.domain)
for i, dom in enumerate(self._domain)
if i not in spaces)
return Field(domain=return_domain, val=data, copy=False)
......@@ -435,21 +438,21 @@ class Field(object):
def copy_content_from(self, other):
if not isinstance(other, Field):
raise TypeError("argument must be a Field")
if other.domain != self.domain:
if other._domain != self._domain:
raise ValueError("domains are incompatible.")
dobj.local_data(self.val)[()] = dobj.local_data(other.val)[()]
def _binary_helper(self, other, op):
# if other is a field, make sure that the domains match
if isinstance(other, Field):
if other.domain != self.domain:
if other._domain != self._domain:
raise ValueError("domains are incompatible.")
tval = getattr(self.val, op)(other.val)
return self if tval is self.val else Field(self.domain, tval)
return self if tval is self.val else Field(self._domain, tval)
if np.isscalar(other) or isinstance(other, dobj.data_object):
tval = getattr(self.val, op)(other)
return self if tval is self.val else Field(self.domain, tval)
return self if tval is self.val else Field(self._domain, tval)
return NotImplemented
......@@ -511,7 +514,7 @@ class Field(object):
minmax = [self.min(), self.max()]
mean = self.mean()
return "nifty2go.Field instance\n- domain = " + \
repr(self.domain) + \
repr(self._domain) + \
"\n- val = " + repr(self.val) + \
"\n - min.,max. = " + str(minmax) + \
"\n - mean = " + str(mean)
......@@ -523,12 +526,12 @@ def _math_helper(x, function, out):
if not isinstance(x, Field):
raise TypeError("This function only accepts Field objects.")
if out is not None:
if not isinstance(out, Field) or x.domain != out.domain:
if not isinstance(out, Field) or x._domain != out._domain:
raise ValueError("Bad 'out' argument")
function(x.val, out=out.val)
return out
else:
return Field(domain=x.domain, val=function(x.val))
return Field(domain=x._domain, val=function(x.val))
def sqrt(x, out=None):
......
......@@ -59,6 +59,8 @@ class CriticalPowerEnergy(Energy):
self.samples = samples
self.alpha = float(alpha)
self.q = float(q)
self._smoothness_prior = smoothness_prior
self._logarithmic = logarithmic
self.T = SmoothnessOperator(domain=self.position.domain[0],
strength=smoothness_prior,
logarithmic=logarithmic)
......@@ -93,8 +95,9 @@ class CriticalPowerEnergy(Energy):
def at(self, position):
return self.__class__(position, self.m, D=self.D, alpha=self.alpha,
q=self.q, smoothness_prior=self.smoothness_prior,
logarithmic=self.logarithmic,
q=self.q,
smoothness_prior=self._smoothness_prior,
logarithmic=self._logarithmic,
samples=self.samples, w=self._w,
inverter=self._inverter)
......@@ -111,11 +114,3 @@ class CriticalPowerEnergy(Energy):
def curvature(self):
return CriticalPowerCurvature(theta=self._theta, T=self.T,
inverter=self._inverter)
@property
def logarithmic(self):
return self.T.logarithmic
@property
def smoothness_prior(self):
return self.T.strength
......@@ -47,6 +47,7 @@ class NonlinearPowerEnergy(Energy):
self.Instrument = Instrument
self.nonlinearity = nonlinearity
self.Projection = Projection
self._sigma = sigma
self.power = self.Projection.adjoint_times(exp(0.5*self.position))
if sample_list is None:
......@@ -62,7 +63,7 @@ class NonlinearPowerEnergy(Energy):
def at(self, position):
return self.__class__(position, self.d, self.N, self.m, self.D,
self.FFT, self.Instrument, self.nonlinearity,
self.Projection, sigma=self.T.strength,
self.Projection, sigma=self._sigma,
samples=len(self.sample_list),
sample_list=self.sample_list,
inverter=self.inverter)
......
......@@ -68,5 +68,4 @@ class WienerFilterCurvature(EndomorphicOperator):
mock_j = self.R.adjoint_times(self.N.inverse_times(mock_data))
mock_m = self.inverse_times(mock_j)
sample = mock_signal - mock_m
return sample
return mock_signal - mock_m
......@@ -24,23 +24,26 @@ class ChainOperator(LinearOperator):
super(ChainOperator, self).__init__()
if op2.target != op1.domain:
raise ValueError("domain mismatch")
self._op1 = op1
self._op2 = op2
self._capability = op1.capability & op2.capability
op1 = op1._ops if isinstance(op1, ChainOperator) else (op1,)
op2 = op2._ops if isinstance(op2, ChainOperator) else (op2,)
self._ops = op1 + op2
@property
def domain(self):
return self._op2.domain
return self._ops[-1].domain
@property
def target(self):
return self._op1.target
return self._ops[0].target
@property
def capability(self):
return self._op1.capability & self._op2.capability
return self._capability
def apply(self, x, mode):
self._check_mode(mode)
if mode == self.TIMES or mode == self.ADJOINT_INVERSE_TIMES:
return self._op1.apply(self._op2.apply(x, mode), mode)
return self._op2.apply(self._op1.apply(x, mode), mode)
t_ops = self._ops if mode & self._backwards else reversed(self._ops)
for op in t_ops:
x = op.apply(x, mode)
return x
from .endomorphic_operator import EndomorphicOperator
from .scaling_operator import ScalingOperator
from .fft_operator import FFTOperator
from ..utilities import infer_space
from .diagonal_operator import DiagonalOperator
from .. import DomainTuple
class FFTSmoothingOperator(EndomorphicOperator):
def __init__(self, domain, sigma, space=None):
super(FFTSmoothingOperator, self).__init__()
dom = DomainTuple.make(domain)
self._sigma = float(sigma)
self._space = infer_space(dom, space)
self._FFT = FFTOperator(dom, space=self._space)
codomain = self._FFT.domain[self._space].get_default_codomain()
kernel = codomain.get_k_length_array()
smoother = codomain.get_fft_smoothing_kernel_function(self._sigma)
kernel = smoother(kernel)
ddom = list(dom)
ddom[self._space] = codomain
self._diag = DiagonalOperator(kernel, ddom, self._space)
def apply(self, x, mode):
self._check_input(x, mode)
if self._sigma == 0:
return x.copy()
return self._FFT.adjoint_times(self._diag(self._FFT(x)))
@property
def domain(self):
return self._FFT.domain
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def FFTSmoothingOperator(domain, sigma, space=None):
sigma = float(sigma)
if sigma < 0.:
raise ValueError("sigma must be nonnegative")
if sigma == 0.:
return ScalingOperator(1., domain)
domain = DomainTuple.make(domain)
space = infer_space(domain, space)
FFT = FFTOperator(domain, space=space)
codomain = FFT.domain[space].get_default_codomain()
kernel = codomain.get_k_length_array()
smoother = codomain.get_fft_smoothing_kernel_function(sigma)
kernel = smoother(kernel)
ddom = list(domain)
ddom[space] = codomain
diag = DiagonalOperator(kernel, ddom, space)
return FFT.adjoint*diag*FFT
......@@ -32,6 +32,12 @@ class LinearOperator(with_metaclass(
_adjointMode = (0, 2, 1, 0, 8, 0, 0, 0, 4)
_adjointCapability = (0, 2, 1, 3, 8, 10, 9, 11, 4, 6, 5, 7, 12, 14, 13, 15)
_addInverse = (0, 5, 10, 15, 5, 5, 15, 15, 10, 15, 10, 15, 15, 15, 15, 15)
_backwards = 6
TIMES = 1
ADJOINT_TIMES = 2
INVERSE_TIMES = 4
ADJOINT_INVERSE_TIMES = 8
INVERSE_ADJOINT_TIMES = 8
def _dom(self, mode):
return self.domain if (mode & 9) else self.target
......@@ -62,26 +68,6 @@ class LinearOperator(with_metaclass(
"""
raise NotImplementedError
@property
def TIMES(self):
return 1
@property
def ADJOINT_TIMES(self):
return 2
@property
def INVERSE_TIMES(self):
return 4
@property
def ADJOINT_INVERSE_TIMES(self):
return 8
@property
def INVERSE_ADJOINT_TIMES(self):
return 8
@property
def inverse(self):
from .inverse_operator import InverseOperator
......@@ -127,6 +113,7 @@ class LinearOperator(with_metaclass(
other = self._toOperator(other, self.domain)
return SumOperator(self, other, neg=True)
# MR FIXME: this might be more complicated ...
def __rsub__(self, other):
from .sum_operator import SumOperator
other = self._toOperator(other, self.domain)
......
......@@ -21,7 +21,6 @@ import numpy as np
from ..field import Field
from ..domain_tuple import DomainTuple
from .endomorphic_operator import EndomorphicOperator
from .. import dobj
class ScalingOperator(EndomorphicOperator):
......@@ -54,6 +53,11 @@ class ScalingOperator(EndomorphicOperator):
def apply(self, x, mode):
self._check_input(x, mode)
if self._factor == 1.:
return x.copy()
if self._factor == 0.:
return Field.zeros_like(x, dtype=x.dtype)
if mode == self.TIMES:
return x*self._factor
elif mode == self.ADJOINT_TIMES:
......@@ -63,6 +67,14 @@ class ScalingOperator(EndomorphicOperator):
else:
return x*(1./np.conj(self._factor))
@property
def inverse(self):
return ScalingOperator(1./self._factor, self._domain)
@property
def adjoint(self):
return ScalingOperator(np.conj(self.factor), self._domain)
@property
def domain(self):
return self._domain
......
from .endomorphic_operator import EndomorphicOperator
from .scaling_operator import ScalingOperator
from .laplace_operator import LaplaceOperator
class SmoothnessOperator(EndomorphicOperator):
def SmoothnessOperator(domain, strength=1., logarithmic=True, space=None):
"""An operator measuring the smoothness on an irregular grid with respect
to some scale.
......@@ -18,44 +18,15 @@ class SmoothnessOperator(EndomorphicOperator):
Parameters
----------
strength: float,
strength: nonnegative float
Specifies the strength of the SmoothnessOperator
logarithmic : boolean,
logarithmic : boolean
Whether smoothness is calculated on a logarithmic scale or linear scale
default : True
"""
def __init__(self, domain, strength=1., logarithmic=True, space=None):
super(SmoothnessOperator, self).__init__()
self._laplace = LaplaceOperator(domain, logarithmic=logarithmic,
space=space)
if strength < 0:
raise ValueError("ERROR: strength must be >=0.")
self._strength = strength
@property
def domain(self):
return self._laplace._domain
# MR FIXME: shouldn't this operator actually be self-adjoint?
@property
def capability(self):
return self.TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if self._strength == 0.:
return x.zeros_like(x)
result = self._laplace.adjoint_times(self._laplace(x))
result *= self._strength**2
return result
@property
def logarithmic(self):
return self._laplace.logarithmic
@property
def strength(self):
return self._strength
if strength < 0:
raise ValueError("ERROR: strength must be nonnegative.")
if strength == 0.:
return ScalingOperator(0., domain)
laplace = LaplaceOperator(domain, logarithmic=logarithmic, space=space)
return (strength**2)*laplace.adjoint*laplace
......@@ -24,25 +24,37 @@ class SumOperator(LinearOperator):
super(SumOperator, self).__init__()
if op1.domain != op2.domain or op1.target != op2.target:
raise ValueError("domain mismatch")
self._op1 = op1
self._op2 = op2
self._neg = bool(neg)
self._capability = (op1.capability & op2.capability &
(self.TIMES | self.ADJOINT_TIMES))
op1 = op1._ops if isinstance(op1, SumOperator) else (op1,)
neg1 = op1._neg if isinstance(op1, SumOperator) else (False,)
op2 = op2._ops if isinstance(op2, SumOperator) else (op2,)
neg2 = op2._neg if isinstance(op2, SumOperator) else (False,)
if neg:
neg2 = tuple(not n for n in neg2)
self._ops = op1 + op2
self._neg = neg1 + neg2
@property
def domain(self):
return self._op1.domain
return self._ops[0].domain
@property
def target(self):
return self._op1.target
return self._ops[0].target
@property
def capability(self):
return (self._op1.capability & self._op2.capability &
(self.TIMES | self.ADJOINT_TIMES))
return self._capability
def apply(self, x, mode):
self._check_mode(mode)
res1 = self._op1.apply(x, mode)
res2 = self._op2.apply(x, mode)
return res1 - res2 if self._neg else res1 + res2
for i, op in enumerate(self._ops):
if i == 0:
res = -op.apply(x, mode) if self._neg[i] else op.apply(x, mode)
else:
if self._neg[i]:
res -= op.apply(x, mode)
else:
res += op.apply(x, mode)