Commit e3819ae7 authored by Philipp Arras's avatar Philipp Arras
Browse files

Cleanup

parent 4ec90e66
......@@ -29,16 +29,18 @@ from ..operators.harmonic_operators import FFTOperator
from ..operators.scaling_operator import ScalingOperator
from ..operators.simple_linear_operators import FieldAdapter, Realizer
from ..sugar import makeOp
from .light_cone_operator import LightConeOperator, field_from_function
from .light_cone_operator import LightConeOperator
def make_dynamic_operator(FFT,harmonic_padding,sm_s0,sm_x0,
keys=['f', 'c'],
causal=True,
cone=True,
minimum_phase=False,
sigc=3.,
quant=5.):
def _field_from_function(domain, func, absolute=False):
domain = DomainTuple.make(domain)
k_array = _make_coords(domain, absolute=absolute)
return Field.from_global_data(domain, func(k_array))
def make_dynamic_operator(FFT, harmonic_padding, sm_s0, sm_x0, keys=['f', 'c'],
causal=True, cone=True, minimum_phase=False, sigc=3.,
quant=5.):
'''
Constructs an operator for a dynamic field prior.
......@@ -81,12 +83,12 @@ def make_dynamic_operator(FFT,harmonic_padding,sm_s0,sm_x0,
'''
ops = {}
if harmonic_padding is None:
CentralPadd = ScalingOperator(1.,FFT.target)
CentralPadd = ScalingOperator(1., FFT.target)
else:
shp = ()
for i in range(len(FFT.target.shape)):
shp += (FFT.target.shape[i] + harmonic_padding[i],)
CentralPadd = FieldZeroPadder(FFT.target,shp,central=True)
CentralPadd = FieldZeroPadder(FFT.target, shp, central=True)
ops['CentralPadd'] = CentralPadd
sdom = CentralPadd.target[0].get_default_codomain()
FFTB = FFTOperator(sdom)(Realizer(sdom))
......@@ -94,19 +96,20 @@ def make_dynamic_operator(FFT,harmonic_padding,sm_s0,sm_x0,
m = FieldAdapter(sdom, keys[0])
dists = m.target[0].distances
if isinstance(sm_x0,float):
if isinstance(sm_x0, float):
sm_x0 = list((sm_x0,)*len(dists))
elif len(sm_x0) != len(dists):
raise(ValueError,"Shape mismatch!")
raise (ValueError, "Shape mismatch!")
def smoothness_prior_func(x):
res = 1.
for i in range(len(dists)):
res = res + (x[i]/sm_x0[i]/dists[i])**2
return sm_s0/res
Sm = makeOp(field_from_function(m.target, smoothness_prior_func))
Sm = makeOp(_field_from_function(m.target, smoothness_prior_func))
m = CentralPadd.adjoint(FFTB(Sm(m)))
ops[keys[0]+'_k'] = m
ops[keys[0] + '_k'] = m
m = -m.log()
if not minimum_phase:
......@@ -114,22 +117,23 @@ def make_dynamic_operator(FFT,harmonic_padding,sm_s0,sm_x0,
ops['Gncc'] = m
if causal:
m = FFT.inverse(Realizer(FFT.target).adjoint(m))
kernel = makeOp(field_from_function(FFT.domain, (lambda x: 1.+np.sign(x[0]))))
kernel = makeOp(
_field_from_function(FFT.domain, (lambda x: 1. + np.sign(x[0]))))
m = kernel(m)
elif minimum_phase:
raise(ValueError,"minimum phase and not causal not possible!")
raise (ValueError, "minimum phase and not causal not possible!")
if cone and len(m.target.shape) > 1:
if isinstance(sigc,float):
sigc = list((sigc,)*(len(m.target.shape)-1))
elif len(sigc) != len(m.target.shape)-1:
raise(ValueError,"Shape mismatch!")
if isinstance(sigc, float):
sigc = list((sigc,)*(len(m.target.shape) - 1))
elif len(sigc) != len(m.target.shape) - 1:
raise (ValueError, "Shape mismatch!")
c = FieldAdapter(UnstructuredDomain(len(sigc)), keys[1])
c = makeOp(Field.from_global_data(c.target, np.array(sigc)))(c)
lightspeed = ScalingOperator(-0.5,c.target)(c).exp()
lightspeed = ScalingOperator(-0.5, c.target)(c).exp()
scaling = np.array(m.target[0].distances[1:])/m.target[0].distances[0]
scaling = DiagonalOperator(Field.from_global_data(c.target,scaling))
scaling = DiagonalOperator(Field.from_global_data(c.target, scaling))
ops['lightspeed'] = scaling(lightspeed)
c = LightConeOperator(c.target, m.target, quant)(c.exp())
......
......@@ -28,7 +28,7 @@ from ..operators.linear_operator import LinearOperator
from ..operators.operator import Operator
def make_coords(domain, absolute=False):
def _make_coords(domain, absolute=False):
domain = DomainTuple.make(domain)
dim = len(domain.shape)
dist = domain[0].distances
......@@ -44,10 +44,6 @@ def make_coords(domain, absolute=False):
k_array[i] += ks.reshape(fst_dims + (shape[i],) + lst_dims)
return k_array
def field_from_function(domain, func, absolute=False):
domain = DomainTuple.make(domain)
k_array = make_coords(domain, absolute=absolute)
return Field.from_global_data(domain, func(k_array))
class LightConeDerivative(LinearOperator):
def __init__(self, domain, target, derivatives):
......@@ -68,9 +64,9 @@ class LightConeDerivative(LinearOperator):
res[i] = np.sum(self._derivatives[i]*x)
return Field.from_global_data(self._tgt(mode), res)
def cone_arrays(c, domain, sigx,want_gradient):
x = make_coords(domain)
def cone_arrays(c, domain, sigx, want_gradient):
x = _make_coords(domain)
a = np.zeros(domain.shape, dtype=np.complex)
if want_gradient:
derivs = np.zeros((c.size,) + domain.shape, dtype=np.complex)
......@@ -96,6 +92,7 @@ def cone_arrays(c, domain, sigx,want_gradient):
derivs = a*derivs.real
return a, derivs
class LightConeOperator(Operator):
def __init__(self, domain, target, sigx):
self._domain = domain
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment