Commit 6492d74b authored by Martin Reinecke's avatar Martin Reinecke

more polishing

parent b7934d79
Pipeline #23368 passed with stage
in 4 minutes and 33 seconds
This diff is collapsed.
......@@ -59,6 +59,8 @@ class CriticalPowerEnergy(Energy):
self.samples = samples
self.alpha = float(alpha)
self.q = float(q)
self._smoothness_prior = smoothness_prior
self._logarithmic = logarithmic
self.T = SmoothnessOperator(domain=self.position.domain[0],
strength=smoothness_prior,
logarithmic=logarithmic)
......@@ -93,8 +95,9 @@ class CriticalPowerEnergy(Energy):
def at(self, position):
return self.__class__(position, self.m, D=self.D, alpha=self.alpha,
q=self.q, smoothness_prior=self.smoothness_prior,
logarithmic=self.logarithmic,
q=self.q,
smoothness_prior=self._smoothness_prior,
logarithmic=self._logarithmic,
samples=self.samples, w=self._w,
inverter=self._inverter)
......@@ -111,11 +114,3 @@ class CriticalPowerEnergy(Energy):
def curvature(self):
return CriticalPowerCurvature(theta=self._theta, T=self.T,
inverter=self._inverter)
@property
def logarithmic(self):
return self.T.logarithmic
@property
def smoothness_prior(self):
return self.T.strength
......@@ -47,6 +47,7 @@ class NonlinearPowerEnergy(Energy):
self.Instrument = Instrument
self.nonlinearity = nonlinearity
self.Projection = Projection
self._sigma = sigma
self.power = self.Projection.adjoint_times(exp(0.5*self.position))
if sample_list is None:
......@@ -62,7 +63,7 @@ class NonlinearPowerEnergy(Energy):
def at(self, position):
return self.__class__(position, self.d, self.N, self.m, self.D,
self.FFT, self.Instrument, self.nonlinearity,
self.Projection, sigma=self.T.strength,
self.Projection, sigma=self._sigma,
samples=len(self.sample_list),
sample_list=self.sample_list,
inverter=self.inverter)
......
......@@ -68,5 +68,4 @@ class WienerFilterCurvature(EndomorphicOperator):
mock_j = self.R.adjoint_times(self.N.inverse_times(mock_data))
mock_m = self.inverse_times(mock_j)
sample = mock_signal - mock_m
return sample
return mock_signal - mock_m
......@@ -24,23 +24,26 @@ class ChainOperator(LinearOperator):
super(ChainOperator, self).__init__()
if op2.target != op1.domain:
raise ValueError("domain mismatch")
self._op1 = op1
self._op2 = op2
self._capability = op1.capability & op2.capability
op1 = op1._ops if isinstance(op1, ChainOperator) else (op1,)
op2 = op2._ops if isinstance(op2, ChainOperator) else (op2,)
self._ops = op1 + op2
@property
def domain(self):
return self._op2.domain
return self._ops[-1].domain
@property
def target(self):
return self._op1.target
return self._ops[0].target
@property
def capability(self):
return self._op1.capability & self._op2.capability
return self._capability
def apply(self, x, mode):
self._check_mode(mode)
if mode == self.TIMES or mode == self.ADJOINT_INVERSE_TIMES:
return self._op1.apply(self._op2.apply(x, mode), mode)
return self._op2.apply(self._op1.apply(x, mode), mode)
t_ops = self._ops if mode & self._backwards else reversed(self._ops)
for op in t_ops:
x = op.apply(x, mode)
return x
from .endomorphic_operator import EndomorphicOperator
from .scaling_operator import ScalingOperator
from .fft_operator import FFTOperator
from ..utilities import infer_space
from .diagonal_operator import DiagonalOperator
from .. import DomainTuple
class FFTSmoothingOperator(EndomorphicOperator):
def __init__(self, domain, sigma, space=None):
super(FFTSmoothingOperator, self).__init__()
dom = DomainTuple.make(domain)
self._sigma = float(sigma)
self._space = infer_space(dom, space)
self._FFT = FFTOperator(dom, space=self._space)
codomain = self._FFT.domain[self._space].get_default_codomain()
kernel = codomain.get_k_length_array()
smoother = codomain.get_fft_smoothing_kernel_function(self._sigma)
kernel = smoother(kernel)
ddom = list(dom)
ddom[self._space] = codomain
self._diag = DiagonalOperator(kernel, ddom, self._space)
def apply(self, x, mode):
self._check_input(x, mode)
if self._sigma == 0:
return x.copy()
return self._FFT.adjoint_times(self._diag(self._FFT(x)))
@property
def domain(self):
return self._FFT.domain
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def FFTSmoothingOperator(domain, sigma, space=None):
sigma = float(sigma)
if sigma < 0.:
raise ValueError("sigma must be nonnegative")
if sigma == 0.:
return ScalingOperator(1., domain)
domain = DomainTuple.make(domain)
space = infer_space(domain, space)
FFT = FFTOperator(domain, space=space)
codomain = FFT.domain[space].get_default_codomain()
kernel = codomain.get_k_length_array()
smoother = codomain.get_fft_smoothing_kernel_function(sigma)
kernel = smoother(kernel)
ddom = list(domain)
ddom[space] = codomain
diag = DiagonalOperator(kernel, ddom, space)
return FFT.adjoint*diag*FFT
......@@ -32,6 +32,12 @@ class LinearOperator(with_metaclass(
_adjointMode = (0, 2, 1, 0, 8, 0, 0, 0, 4)
_adjointCapability = (0, 2, 1, 3, 8, 10, 9, 11, 4, 6, 5, 7, 12, 14, 13, 15)
_addInverse = (0, 5, 10, 15, 5, 5, 15, 15, 10, 15, 10, 15, 15, 15, 15, 15)
_backwards = 6
TIMES = 1
ADJOINT_TIMES = 2
INVERSE_TIMES = 4
ADJOINT_INVERSE_TIMES = 8
INVERSE_ADJOINT_TIMES = 8
def _dom(self, mode):
return self.domain if (mode & 9) else self.target
......@@ -62,26 +68,6 @@ class LinearOperator(with_metaclass(
"""
raise NotImplementedError
@property
def TIMES(self):
return 1
@property
def ADJOINT_TIMES(self):
return 2
@property
def INVERSE_TIMES(self):
return 4
@property
def ADJOINT_INVERSE_TIMES(self):
return 8
@property
def INVERSE_ADJOINT_TIMES(self):
return 8
@property
def inverse(self):
from .inverse_operator import InverseOperator
......@@ -127,6 +113,7 @@ class LinearOperator(with_metaclass(
other = self._toOperator(other, self.domain)
return SumOperator(self, other, neg=True)
# MR FIXME: this might be more complicated ...
def __rsub__(self, other):
from .sum_operator import SumOperator
other = self._toOperator(other, self.domain)
......
......@@ -21,7 +21,6 @@ import numpy as np
from ..field import Field
from ..domain_tuple import DomainTuple
from .endomorphic_operator import EndomorphicOperator
from .. import dobj
class ScalingOperator(EndomorphicOperator):
......@@ -54,6 +53,11 @@ class ScalingOperator(EndomorphicOperator):
def apply(self, x, mode):
self._check_input(x, mode)
if self._factor == 1.:
return x.copy()
if self._factor == 0.:
return Field.zeros_like(x, dtype=x.dtype)
if mode == self.TIMES:
return x*self._factor
elif mode == self.ADJOINT_TIMES:
......@@ -63,6 +67,14 @@ class ScalingOperator(EndomorphicOperator):
else:
return x*(1./np.conj(self._factor))
@property
def inverse(self):
return ScalingOperator(1./self._factor, self._domain)
@property
def adjoint(self):
return ScalingOperator(np.conj(self.factor), self._domain)
@property
def domain(self):
return self._domain
......
from .endomorphic_operator import EndomorphicOperator
from .scaling_operator import ScalingOperator
from .laplace_operator import LaplaceOperator
class SmoothnessOperator(EndomorphicOperator):
def SmoothnessOperator(domain, strength=1., logarithmic=True, space=None):
"""An operator measuring the smoothness on an irregular grid with respect
to some scale.
......@@ -18,44 +18,15 @@ class SmoothnessOperator(EndomorphicOperator):
Parameters
----------
strength: float,
strength: nonnegative float
Specifies the strength of the SmoothnessOperator
logarithmic : boolean,
logarithmic : boolean
Whether smoothness is calculated on a logarithmic scale or linear scale
default : True
"""
def __init__(self, domain, strength=1., logarithmic=True, space=None):
super(SmoothnessOperator, self).__init__()
self._laplace = LaplaceOperator(domain, logarithmic=logarithmic,
space=space)
if strength < 0:
raise ValueError("ERROR: strength must be >=0.")
self._strength = strength
@property
def domain(self):
return self._laplace._domain
# MR FIXME: shouldn't this operator actually be self-adjoint?
@property
def capability(self):
return self.TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if self._strength == 0.:
return x.zeros_like(x)
result = self._laplace.adjoint_times(self._laplace(x))
result *= self._strength**2
return result
@property
def logarithmic(self):
return self._laplace.logarithmic
@property
def strength(self):
return self._strength
if strength < 0:
raise ValueError("ERROR: strength must be nonnegative.")
if strength == 0.:
return ScalingOperator(0., domain)
laplace = LaplaceOperator(domain, logarithmic=logarithmic, space=space)
return (strength**2)*laplace.adjoint*laplace
......@@ -24,25 +24,37 @@ class SumOperator(LinearOperator):
super(SumOperator, self).__init__()
if op1.domain != op2.domain or op1.target != op2.target:
raise ValueError("domain mismatch")
self._op1 = op1
self._op2 = op2
self._neg = bool(neg)
self._capability = (op1.capability & op2.capability &
(self.TIMES | self.ADJOINT_TIMES))
op1 = op1._ops if isinstance(op1, SumOperator) else (op1,)
neg1 = op1._neg if isinstance(op1, SumOperator) else (False,)
op2 = op2._ops if isinstance(op2, SumOperator) else (op2,)
neg2 = op2._neg if isinstance(op2, SumOperator) else (False,)
if neg:
neg2 = tuple(not n for n in neg2)
self._ops = op1 + op2
self._neg = neg1 + neg2
@property
def domain(self):
return self._op1.domain
return self._ops[0].domain
@property
def target(self):
return self._op1.target
return self._ops[0].target
@property
def capability(self):
return (self._op1.capability & self._op2.capability &
(self.TIMES | self.ADJOINT_TIMES))
return self._capability
def apply(self, x, mode):
self._check_mode(mode)
res1 = self._op1.apply(x, mode)
res2 = self._op2.apply(x, mode)
return res1 - res2 if self._neg else res1 + res2
for i, op in enumerate(self._ops):
if i == 0:
res = -op.apply(x, mode) if self._neg[i] else op.apply(x, mode)
else:
if self._neg[i]:
res -= op.apply(x, mode)
else:
res += op.apply(x, mode)
return res
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment