From 6492d74b8eb2f64cfac611a68d9d4617d68ea217 Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Fri, 5 Jan 2018 12:08:26 +0100 Subject: [PATCH] more polishing --- nifty/field.py | 77 ++++++++++++----------- nifty/library/critical_power_energy.py | 15 ++--- nifty/library/nonlinear_power_energy.py | 3 +- nifty/library/wiener_filter_curvature.py | 3 +- nifty/operators/chain_operator.py | 19 +++--- nifty/operators/fft_smoothing_operator.py | 51 ++++++--------- nifty/operators/linear_operator.py | 27 +++----- nifty/operators/scaling_operator.py | 14 ++++- nifty/operators/smoothness_operator.py | 49 +++------------ nifty/operators/sum_operator.py | 32 +++++++--- 10 files changed, 130 insertions(+), 160 deletions(-) diff --git a/nifty/field.py b/nifty/field.py index dd4d7b903..4428b2026 100644 --- a/nifty/field.py +++ b/nifty/field.py @@ -32,7 +32,6 @@ class Field(object): In NIFTY, Fields are used to store data arrays and carry all the needed metainformation (i.e. the domain) for operators to be able to work on them. - In addition, Field has methods to work with power spectra. Parameters ---------- @@ -59,23 +58,23 @@ class Field(object): """ def __init__(self, domain=None, val=None, dtype=None, copy=False): - self.domain = self._infer_domain(domain=domain, val=val) + self._domain = self._infer_domain(domain=domain, val=val) dtype = self._infer_dtype(dtype=dtype, val=val) if isinstance(val, Field): - if self.domain != val.domain: + if self._domain != val._domain: raise ValueError("Domain mismatch") self._val = dobj.from_object(val.val, dtype=dtype, copy=copy) elif (np.isscalar(val)): - self._val = dobj.full(self.domain.shape, dtype=dtype, + self._val = dobj.full(self._domain.shape, dtype=dtype, fill_value=val) elif isinstance(val, dobj.data_object): - if self.domain.shape == val.shape: + if self._domain.shape == val.shape: self._val = dobj.from_object(val, dtype=dtype, copy=copy) else: raise ValueError("Shape mismatch") elif val is None: - self._val = dobj.empty(self.domain.shape, dtype=dtype) + self._val = dobj.empty(self._domain.shape, dtype=dtype) else: raise TypeError("unknown source type") @@ -101,7 +100,7 @@ class Field(object): def full_like(field, val, dtype=None): if not isinstance(field, Field): raise TypeError("field must be of Field type") - return Field.full(field.domain, val, dtype) + return Field.full(field._domain, val, dtype) @staticmethod def zeros_like(field, dtype=None): @@ -109,7 +108,7 @@ class Field(object): raise TypeError("field must be of Field type") if dtype is None: dtype = field.dtype - return Field.zeros(field.domain, dtype) + return Field.zeros(field._domain, dtype) @staticmethod def ones_like(field, dtype=None): @@ -117,7 +116,7 @@ class Field(object): raise TypeError("field must be of Field type") if dtype is None: dtype = field.dtype - return Field.ones(field.domain, dtype) + return Field.ones(field._domain, dtype) @staticmethod def empty_like(field, dtype=None): @@ -125,13 +124,13 @@ class Field(object): raise TypeError("field must be of Field type") if dtype is None: dtype = field.dtype - return Field.empty(field.domain, dtype) + return Field.empty(field._domain, dtype) @staticmethod def _infer_domain(domain, val=None): if domain is None: if isinstance(val, Field): - return val.domain + return val._domain if np.isscalar(val): return DomainTuple.make(()) # empty domain tuple raise TypeError("could not infer domain from value") @@ -187,6 +186,10 @@ class Field(object): def dtype(self): return self._val.dtype + @property + def domain(self): + return self._domain + @property def shape(self): """ Returns the total shape of the Field's data array. @@ -195,7 +198,7 @@ class Field(object): ------- Integer tuple containing the dimensions of the spaces in domain. """ - return self.domain.shape + return self._domain.shape @property def dim(self): @@ -208,21 +211,21 @@ class Field(object): out : int The dimension of the Field. """ - return self.domain.dim + return self._domain.dim @property def real(self): """ The real part of the field (data is not copied).""" if not np.issubdtype(self.dtype, np.complexfloating): raise ValueError(".real called on a non-complex Field") - return Field(self.domain, self.val.real) + return Field(self._domain, self.val.real) @property def imag(self): """ The imaginary part of the field (data is not copied).""" if not np.issubdtype(self.dtype, np.complexfloating): raise ValueError(".imag called on a non-complex Field") - return Field(self.domain, self.val.imag) + return Field(self._domain, self.val.imag) def copy(self): """ Returns a full copy of the Field. @@ -238,13 +241,13 @@ class Field(object): def scalar_weight(self, spaces=None): if np.isscalar(spaces): - return self.domain[spaces].scalar_dvol() + return self._domain[spaces].scalar_dvol() if spaces is None: - spaces = range(len(self.domain)) + spaces = range(len(self._domain)) res = 1. for i in spaces: - tmp = self.domain[i].scalar_dvol() + tmp = self._domain[i].scalar_dvol() if tmp is None: return None res *= tmp @@ -277,17 +280,17 @@ class Field(object): if out is not self: out.copy_content_from(self) - spaces = utilities.parse_spaces(spaces, len(self.domain)) + spaces = utilities.parse_spaces(spaces, len(self._domain)) fct = 1. for ind in spaces: - wgt = self.domain[ind].dvol() + wgt = self._domain[ind].dvol() if np.isscalar(wgt): fct *= wgt else: new_shape = np.ones(len(self.shape), dtype=np.int) - new_shape[self.domain.axes[ind][0]: - self.domain.axes[ind][-1]+1] = wgt.shape + new_shape[self._domain.axes[ind][0]: + self._domain.axes[ind][-1]+1] = wgt.shape wgt = wgt.reshape(new_shape) if dobj.distaxis(self._val) >= 0 and ind == 0: # we need to distribute the weights along axis 0 @@ -321,10 +324,10 @@ class Field(object): raise ValueError("The dot-partner must be an instance of " + "the NIFTy field class") - if x.domain != self.domain: + if x._domain != self._domain: raise ValueError("Domain mismatch") - ndom = len(self.domain) + ndom = len(self._domain) spaces = utilities.parse_spaces(spaces, ndom) if len(spaces) == ndom: @@ -359,7 +362,7 @@ class Field(object): ------- The complex conjugated field. """ - return Field(self.domain, self.val.conjugate(), self.dtype) + return Field(self._domain, self.val.conjugate(), self.dtype) # ---General unary/contraction methods--- @@ -367,18 +370,18 @@ class Field(object): return self.copy() def __neg__(self): - return Field(self.domain, -self.val, self.dtype) + return Field(self._domain, -self.val, self.dtype) def __abs__(self): - return Field(self.domain, dobj.abs(self.val), self.dtype) + return Field(self._domain, dobj.abs(self.val), self.dtype) def _contraction_helper(self, op, spaces): if spaces is None: return getattr(self.val, op)() - spaces = utilities.parse_spaces(spaces, len(self.domain)) + spaces = utilities.parse_spaces(spaces, len(self._domain)) - axes_list = tuple(self.domain.axes[sp_index] for sp_index in spaces) + axes_list = tuple(self._domain.axes[sp_index] for sp_index in spaces) if len(axes_list) > 0: axes_list = reduce(lambda x, y: x+y, axes_list) @@ -391,7 +394,7 @@ class Field(object): return data else: return_domain = tuple(dom - for i, dom in enumerate(self.domain) + for i, dom in enumerate(self._domain) if i not in spaces) return Field(domain=return_domain, val=data, copy=False) @@ -435,21 +438,21 @@ class Field(object): def copy_content_from(self, other): if not isinstance(other, Field): raise TypeError("argument must be a Field") - if other.domain != self.domain: + if other._domain != self._domain: raise ValueError("domains are incompatible.") dobj.local_data(self.val)[()] = dobj.local_data(other.val)[()] def _binary_helper(self, other, op): # if other is a field, make sure that the domains match if isinstance(other, Field): - if other.domain != self.domain: + if other._domain != self._domain: raise ValueError("domains are incompatible.") tval = getattr(self.val, op)(other.val) - return self if tval is self.val else Field(self.domain, tval) + return self if tval is self.val else Field(self._domain, tval) if np.isscalar(other) or isinstance(other, dobj.data_object): tval = getattr(self.val, op)(other) - return self if tval is self.val else Field(self.domain, tval) + return self if tval is self.val else Field(self._domain, tval) return NotImplemented @@ -511,7 +514,7 @@ class Field(object): minmax = [self.min(), self.max()] mean = self.mean() return "nifty2go.Field instance\n- domain = " + \ - repr(self.domain) + \ + repr(self._domain) + \ "\n- val = " + repr(self.val) + \ "\n - min.,max. = " + str(minmax) + \ "\n - mean = " + str(mean) @@ -523,12 +526,12 @@ def _math_helper(x, function, out): if not isinstance(x, Field): raise TypeError("This function only accepts Field objects.") if out is not None: - if not isinstance(out, Field) or x.domain != out.domain: + if not isinstance(out, Field) or x._domain != out._domain: raise ValueError("Bad 'out' argument") function(x.val, out=out.val) return out else: - return Field(domain=x.domain, val=function(x.val)) + return Field(domain=x._domain, val=function(x.val)) def sqrt(x, out=None): diff --git a/nifty/library/critical_power_energy.py b/nifty/library/critical_power_energy.py index 3253b12d8..f83374817 100644 --- a/nifty/library/critical_power_energy.py +++ b/nifty/library/critical_power_energy.py @@ -59,6 +59,8 @@ class CriticalPowerEnergy(Energy): self.samples = samples self.alpha = float(alpha) self.q = float(q) + self._smoothness_prior = smoothness_prior + self._logarithmic = logarithmic self.T = SmoothnessOperator(domain=self.position.domain[0], strength=smoothness_prior, logarithmic=logarithmic) @@ -93,8 +95,9 @@ class CriticalPowerEnergy(Energy): def at(self, position): return self.__class__(position, self.m, D=self.D, alpha=self.alpha, - q=self.q, smoothness_prior=self.smoothness_prior, - logarithmic=self.logarithmic, + q=self.q, + smoothness_prior=self._smoothness_prior, + logarithmic=self._logarithmic, samples=self.samples, w=self._w, inverter=self._inverter) @@ -111,11 +114,3 @@ class CriticalPowerEnergy(Energy): def curvature(self): return CriticalPowerCurvature(theta=self._theta, T=self.T, inverter=self._inverter) - - @property - def logarithmic(self): - return self.T.logarithmic - - @property - def smoothness_prior(self): - return self.T.strength diff --git a/nifty/library/nonlinear_power_energy.py b/nifty/library/nonlinear_power_energy.py index 97ec9f786..11962de5e 100644 --- a/nifty/library/nonlinear_power_energy.py +++ b/nifty/library/nonlinear_power_energy.py @@ -47,6 +47,7 @@ class NonlinearPowerEnergy(Energy): self.Instrument = Instrument self.nonlinearity = nonlinearity self.Projection = Projection + self._sigma = sigma self.power = self.Projection.adjoint_times(exp(0.5*self.position)) if sample_list is None: @@ -62,7 +63,7 @@ class NonlinearPowerEnergy(Energy): def at(self, position): return self.__class__(position, self.d, self.N, self.m, self.D, self.FFT, self.Instrument, self.nonlinearity, - self.Projection, sigma=self.T.strength, + self.Projection, sigma=self._sigma, samples=len(self.sample_list), sample_list=self.sample_list, inverter=self.inverter) diff --git a/nifty/library/wiener_filter_curvature.py b/nifty/library/wiener_filter_curvature.py index cabe0ff95..8af4bf5d7 100644 --- a/nifty/library/wiener_filter_curvature.py +++ b/nifty/library/wiener_filter_curvature.py @@ -68,5 +68,4 @@ class WienerFilterCurvature(EndomorphicOperator): mock_j = self.R.adjoint_times(self.N.inverse_times(mock_data)) mock_m = self.inverse_times(mock_j) - sample = mock_signal - mock_m - return sample + return mock_signal - mock_m diff --git a/nifty/operators/chain_operator.py b/nifty/operators/chain_operator.py index b9967ee53..6c1df7261 100644 --- a/nifty/operators/chain_operator.py +++ b/nifty/operators/chain_operator.py @@ -24,23 +24,26 @@ class ChainOperator(LinearOperator): super(ChainOperator, self).__init__() if op2.target != op1.domain: raise ValueError("domain mismatch") - self._op1 = op1 - self._op2 = op2 + self._capability = op1.capability & op2.capability + op1 = op1._ops if isinstance(op1, ChainOperator) else (op1,) + op2 = op2._ops if isinstance(op2, ChainOperator) else (op2,) + self._ops = op1 + op2 @property def domain(self): - return self._op2.domain + return self._ops[-1].domain @property def target(self): - return self._op1.target + return self._ops[0].target @property def capability(self): - return self._op1.capability & self._op2.capability + return self._capability def apply(self, x, mode): self._check_mode(mode) - if mode == self.TIMES or mode == self.ADJOINT_INVERSE_TIMES: - return self._op1.apply(self._op2.apply(x, mode), mode) - return self._op2.apply(self._op1.apply(x, mode), mode) + t_ops = self._ops if mode & self._backwards else reversed(self._ops) + for op in t_ops: + x = op.apply(x, mode) + return x diff --git a/nifty/operators/fft_smoothing_operator.py b/nifty/operators/fft_smoothing_operator.py index e7e790cef..8312b5c91 100644 --- a/nifty/operators/fft_smoothing_operator.py +++ b/nifty/operators/fft_smoothing_operator.py @@ -1,38 +1,25 @@ -from .endomorphic_operator import EndomorphicOperator +from .scaling_operator import ScalingOperator from .fft_operator import FFTOperator from ..utilities import infer_space from .diagonal_operator import DiagonalOperator from .. import DomainTuple -class FFTSmoothingOperator(EndomorphicOperator): - def __init__(self, domain, sigma, space=None): - super(FFTSmoothingOperator, self).__init__() - - dom = DomainTuple.make(domain) - self._sigma = float(sigma) - self._space = infer_space(dom, space) - - self._FFT = FFTOperator(dom, space=self._space) - codomain = self._FFT.domain[self._space].get_default_codomain() - kernel = codomain.get_k_length_array() - smoother = codomain.get_fft_smoothing_kernel_function(self._sigma) - kernel = smoother(kernel) - ddom = list(dom) - ddom[self._space] = codomain - self._diag = DiagonalOperator(kernel, ddom, self._space) - - def apply(self, x, mode): - self._check_input(x, mode) - if self._sigma == 0: - return x.copy() - - return self._FFT.adjoint_times(self._diag(self._FFT(x))) - - @property - def domain(self): - return self._FFT.domain - - @property - def capability(self): - return self.TIMES | self.ADJOINT_TIMES +def FFTSmoothingOperator(domain, sigma, space=None): + sigma = float(sigma) + if sigma < 0.: + raise ValueError("sigma must be nonnegative") + if sigma == 0.: + return ScalingOperator(1., domain) + + domain = DomainTuple.make(domain) + space = infer_space(domain, space) + FFT = FFTOperator(domain, space=space) + codomain = FFT.domain[space].get_default_codomain() + kernel = codomain.get_k_length_array() + smoother = codomain.get_fft_smoothing_kernel_function(sigma) + kernel = smoother(kernel) + ddom = list(domain) + ddom[space] = codomain + diag = DiagonalOperator(kernel, ddom, space) + return FFT.adjoint*diag*FFT diff --git a/nifty/operators/linear_operator.py b/nifty/operators/linear_operator.py index 7d5a9401a..9ea2f89d0 100644 --- a/nifty/operators/linear_operator.py +++ b/nifty/operators/linear_operator.py @@ -32,6 +32,12 @@ class LinearOperator(with_metaclass( _adjointMode = (0, 2, 1, 0, 8, 0, 0, 0, 4) _adjointCapability = (0, 2, 1, 3, 8, 10, 9, 11, 4, 6, 5, 7, 12, 14, 13, 15) _addInverse = (0, 5, 10, 15, 5, 5, 15, 15, 10, 15, 10, 15, 15, 15, 15, 15) + _backwards = 6 + TIMES = 1 + ADJOINT_TIMES = 2 + INVERSE_TIMES = 4 + ADJOINT_INVERSE_TIMES = 8 + INVERSE_ADJOINT_TIMES = 8 def _dom(self, mode): return self.domain if (mode & 9) else self.target @@ -62,26 +68,6 @@ class LinearOperator(with_metaclass( """ raise NotImplementedError - @property - def TIMES(self): - return 1 - - @property - def ADJOINT_TIMES(self): - return 2 - - @property - def INVERSE_TIMES(self): - return 4 - - @property - def ADJOINT_INVERSE_TIMES(self): - return 8 - - @property - def INVERSE_ADJOINT_TIMES(self): - return 8 - @property def inverse(self): from .inverse_operator import InverseOperator @@ -127,6 +113,7 @@ class LinearOperator(with_metaclass( other = self._toOperator(other, self.domain) return SumOperator(self, other, neg=True) + # MR FIXME: this might be more complicated ... def __rsub__(self, other): from .sum_operator import SumOperator other = self._toOperator(other, self.domain) diff --git a/nifty/operators/scaling_operator.py b/nifty/operators/scaling_operator.py index 48b21a3ef..a1c6bd3f1 100644 --- a/nifty/operators/scaling_operator.py +++ b/nifty/operators/scaling_operator.py @@ -21,7 +21,6 @@ import numpy as np from ..field import Field from ..domain_tuple import DomainTuple from .endomorphic_operator import EndomorphicOperator -from .. import dobj class ScalingOperator(EndomorphicOperator): @@ -54,6 +53,11 @@ class ScalingOperator(EndomorphicOperator): def apply(self, x, mode): self._check_input(x, mode) + if self._factor == 1.: + return x.copy() + if self._factor == 0.: + return Field.zeros_like(x, dtype=x.dtype) + if mode == self.TIMES: return x*self._factor elif mode == self.ADJOINT_TIMES: @@ -63,6 +67,14 @@ class ScalingOperator(EndomorphicOperator): else: return x*(1./np.conj(self._factor)) + @property + def inverse(self): + return ScalingOperator(1./self._factor, self._domain) + + @property + def adjoint(self): + return ScalingOperator(np.conj(self.factor), self._domain) + @property def domain(self): return self._domain diff --git a/nifty/operators/smoothness_operator.py b/nifty/operators/smoothness_operator.py index 4f307a784..b669cf353 100644 --- a/nifty/operators/smoothness_operator.py +++ b/nifty/operators/smoothness_operator.py @@ -1,8 +1,8 @@ -from .endomorphic_operator import EndomorphicOperator +from .scaling_operator import ScalingOperator from .laplace_operator import LaplaceOperator -class SmoothnessOperator(EndomorphicOperator): +def SmoothnessOperator(domain, strength=1., logarithmic=True, space=None): """An operator measuring the smoothness on an irregular grid with respect to some scale. @@ -18,44 +18,15 @@ class SmoothnessOperator(EndomorphicOperator): Parameters ---------- - strength: float, + strength: nonnegative float Specifies the strength of the SmoothnessOperator - logarithmic : boolean, + logarithmic : boolean Whether smoothness is calculated on a logarithmic scale or linear scale default : True """ - - def __init__(self, domain, strength=1., logarithmic=True, space=None): - super(SmoothnessOperator, self).__init__() - self._laplace = LaplaceOperator(domain, logarithmic=logarithmic, - space=space) - - if strength < 0: - raise ValueError("ERROR: strength must be >=0.") - self._strength = strength - - @property - def domain(self): - return self._laplace._domain - - # MR FIXME: shouldn't this operator actually be self-adjoint? - @property - def capability(self): - return self.TIMES - - def apply(self, x, mode): - self._check_input(x, mode) - - if self._strength == 0.: - return x.zeros_like(x) - result = self._laplace.adjoint_times(self._laplace(x)) - result *= self._strength**2 - return result - - @property - def logarithmic(self): - return self._laplace.logarithmic - - @property - def strength(self): - return self._strength + if strength < 0: + raise ValueError("ERROR: strength must be nonnegative.") + if strength == 0.: + return ScalingOperator(0., domain) + laplace = LaplaceOperator(domain, logarithmic=logarithmic, space=space) + return (strength**2)*laplace.adjoint*laplace diff --git a/nifty/operators/sum_operator.py b/nifty/operators/sum_operator.py index 5ba82197b..7d6b193e7 100644 --- a/nifty/operators/sum_operator.py +++ b/nifty/operators/sum_operator.py @@ -24,25 +24,37 @@ class SumOperator(LinearOperator): super(SumOperator, self).__init__() if op1.domain != op2.domain or op1.target != op2.target: raise ValueError("domain mismatch") - self._op1 = op1 - self._op2 = op2 - self._neg = bool(neg) + self._capability = (op1.capability & op2.capability & + (self.TIMES | self.ADJOINT_TIMES)) + op1 = op1._ops if isinstance(op1, SumOperator) else (op1,) + neg1 = op1._neg if isinstance(op1, SumOperator) else (False,) + op2 = op2._ops if isinstance(op2, SumOperator) else (op2,) + neg2 = op2._neg if isinstance(op2, SumOperator) else (False,) + if neg: + neg2 = tuple(not n for n in neg2) + self._ops = op1 + op2 + self._neg = neg1 + neg2 @property def domain(self): - return self._op1.domain + return self._ops[0].domain @property def target(self): - return self._op1.target + return self._ops[0].target @property def capability(self): - return (self._op1.capability & self._op2.capability & - (self.TIMES | self.ADJOINT_TIMES)) + return self._capability def apply(self, x, mode): self._check_mode(mode) - res1 = self._op1.apply(x, mode) - res2 = self._op2.apply(x, mode) - return res1 - res2 if self._neg else res1 + res2 + for i, op in enumerate(self._ops): + if i == 0: + res = -op.apply(x, mode) if self._neg[i] else op.apply(x, mode) + else: + if self._neg[i]: + res -= op.apply(x, mode) + else: + res += op.apply(x, mode) + return res -- GitLab