Commit 8afcedf6 authored by Ultima's avatar Ultima

Made a new propagator_operator

parent df8cb230
......@@ -46,7 +46,7 @@ from nifty import *
class problem(object):
def __init__(self, x_space, s2n=12, **kwargs):
def __init__(self, x_space, s2n=6, **kwargs):
"""
Sets up a Wiener filter problem.
......@@ -67,7 +67,7 @@ class problem(object):
#self.k.set_power_indices(**kwargs)
## set some power spectrum
self.power = (lambda k: 42 / (k + 1) ** 3)
self.power = (lambda k: 42 / (k + 1) ** 2)
## define signal covariance
self.S = power_operator(self.k, spec=self.power, bare=True)
......@@ -256,7 +256,7 @@ class problem(object):
##-----------------------------------------------------------------------------
#
if(__name__=="__main__"):
x = rg_space((128), zerocenter=True)
x = rg_space((1280), zerocenter=True)
p = problem(x, log = False)
about.warnings.off()
## pl.close("all")
......
......@@ -37,14 +37,14 @@ import matplotlib as mpl
mpl.use('Agg')
import gc
import imp
#nifty = imp.load_module('nifty', None,
# '/home/steininger/Downloads/nifty', ('','',5))
nifty = imp.load_module('nifty', None,
'/home/steininger/Downloads/nifty', ('','',5))
from nifty import * # version 0.8.0
about.warnings.off()
# some signal space; e.g., a two-dimensional regular grid
shape = [1024,]
shape = [1024]
x_space = rg_space(shape)
#y_space = point_space(1280*1280)
#x_space = hp_space(32)
......@@ -91,25 +91,25 @@ m = D(j, W=S, tol=1E-8, limii=100, note=True)
#temp_result = (D.inverse_times(m)-xi)
#n_power = x_space.enforce_power(s.var()/np.prod(shape))
#s_power = S.get_power()
n_power = x_space.enforce_power(s.var()/np.prod(shape))
s_power = S.get_power()
#s.plot(title="signal", save = 'plot_s.png')
#s.plot(title="signal power", power=True, other=power,
# mono=False, save = 'power_plot_s.png', nbin=1000, log=True,
# vmax = 100, vmin=10e-7)
s.plot(title="signal", save = 'plot_s.png')
s.plot(title="signal power", power=True, other=power,
mono=False, save = 'power_plot_s.png', nbin=1000, log=True,
vmax = 100, vmin=10e-7)
#d_ = field(x_space, val=d.val, target=k_space)
#d_.plot(title="data", vmin=s.min(), vmax=s.max(), save = 'plot_d.png')
d_ = field(x_space, val=d.val, target=k_space)
d_.plot(title="data", vmin=s.min(), vmax=s.max(), save = 'plot_d.png')
#n_ = field(x_space, val=n.val, target=k_space)
#n_.plot(title="data", vmin=s.min(), vmax=s.max(), save = 'plot_n.png')
n_ = field(x_space, val=n.val, target=k_space)
n_.plot(title="data", vmin=s.min(), vmax=s.max(), save = 'plot_n.png')
#
#m.plot(title="reconstructed map", vmin=s.min(), vmax=s.max(), save = 'plot_m.png')
#m.plot(title="reconstructed power", power=True, other=(n_power, s_power),
# save = 'power_plot_m.png', vmin=0.001, vmax=10, mono=False)
#
m.plot(title="reconstructed map", vmin=s.min(), vmax=s.max(), save = 'plot_m.png')
m.plot(title="reconstructed power", power=True, other=(n_power, s_power),
save = 'power_plot_m.png', vmin=0.001, vmax=10, mono=False)
#
......@@ -826,7 +826,7 @@ class point_space(space):
self.comm = self._parse_comm(comm)
self.discrete = True
self.harmonic = False
# self.harmonic = False
self.distances = (np.float(1),)
@property
......@@ -1387,7 +1387,7 @@ class point_space(space):
if not isinstance(codomain, space):
raise TypeError(about._errors.cstring(
"ERROR: invalid input. The given input is no nifty space."))
"ERROR: invalid input. The given input is not a nifty space."))
if codomain == self:
return True
......@@ -1830,7 +1830,6 @@ class point_space(space):
string += 'datamodel: ' + str(self.datamodel) + "\n"
string += 'comm: ' + self.comm.name + "\n"
string += 'discrete: ' + str(self.discrete) + "\n"
string += 'harmonic: ' + str(self.harmonic) + "\n"
string += 'distances: ' + str(self.distances) + "\n"
return string
......@@ -1974,25 +1973,23 @@ class field(object):
else:
codomain = domain.get_codomain()
# Check if the given field lives in the same fourier-space as the
# new domain
if f.domain.harmonic != domain.harmonic:
# check for ishape
if ishape is None:
ishape = f.ishape
# Check if the given field lives in a space which is compatible to the
# given domain
if f.domain != domain:
# Try to transform the given field to the given domain/codomain
f = f.transform(new_domain=domain,
new_codomain=codomain)
# Check if the domain is now really the same.
# This is necessary since iso-fourier-conversion is not implemented
if f.domain == domain:
self._init_from_array(domain=domain,
val=f.val,
codomain=codomain,
ishape=ishape,
copy=copy,
**kwargs)
else:
raise ValueError(about._errors.cstring(
"ERROR: Incompatible domain given."))
self._init_from_array(domain=domain,
val=f.val,
codomain=codomain,
ishape=ishape,
copy=copy,
**kwargs)
def _init_from_array(self, val, domain, codomain, ishape, copy, **kwargs):
# check domain
......@@ -2961,10 +2958,15 @@ class field(object):
return new_field
def _binary_helper(self, other, op='None', inplace=False):
# the other object could be a field/operator. Try to extract its data.
# if other is a field, make sure that the domains match
if isinstance(other, field):
other = field(domain=self.domain,
val=other,
codomain=self.codomain,
copy=False)
try:
other_val = other.get_val()
except(AttributeError):
except AttributeError:
other_val = other
# bring other_val into the right shape
......
......@@ -252,8 +252,7 @@ class conjugate_gradient(object):
convergence = 0
ii = 1
while(True):
from time import sleep
sleep(0.5)
# print ('gamma', gamma)
q = self.A(d)
# print ('q', q.val)
......
......@@ -157,7 +157,7 @@ class operator(object):
# If the operator is symmetric or unitary, we know that the operator
# must be square
if self.sym is True or self.uni is True:
if self.sym or self.uni:
target = self.domain
cotarget = self.codomain
if target is not None:
......@@ -225,49 +225,33 @@ class operator(object):
"ERROR: no generic instance method 'inverse_adjoint_multiply'."))
def _briefing(self, x, domain, codomain, inverse):
# inspect x
if not isinstance(x, field):
y = field(domain, codomain=codomain, val=x)
else:
# check x.domain
if x.domain == domain:
y = x
else:
if x.domain.harmonic != domain.harmonic:
y = x.transform(codomain=domain)
else:
y = x.copy(domain=domain, codomain=codomain)
# make sure, that the result_field of the briefing lives in the
# given domain and codomain
result_field = field(domain=domain, val=x, codomain=codomain,
copy=False)
# weight if ...
# weight if necessary
if (not self.imp) and (not domain.discrete) and (not inverse):
y = y.weight(power=1)
return y
result_field = result_field.weight(power=1)
return result_field
def _debriefing(self, x, y, target, cotarget, inverse):
# > evaluates x and y after `multiply`
if y is None:
return None
else:
# inspect y
if not isinstance(y, field):
y = field(target, codomain=cotarget, val=y)
elif y.domain != target:
raise ValueError(about._errors.cstring(
"ERROR: invalid output domain."))
# weight if ...
if (not self.imp) and (not target.discrete) and inverse:
y = y.weight(power=-1)
# inspect x
if isinstance(x, field):
# repair if the originally field was living in the codomain
# of the operators domain
if self.domain == self.target and\
x.codomain == self.domain and\
x.codomain != x.domain:
y = y.transform(codomain=x.domain)
if x.domain == y.domain and (x.codomain != y.codomain):
y.set_codomain(new_codomain=x.codomain)
return y
# The debriefing takes care that the result field lives in the same
# fourier-type domain as the input field
assert(isinstance(y, field))
# weight if necessary
if (not self.imp) and (not target.discrete) and inverse:
y = y.weight(power=-1)
# if the operators domain as well as the target have the harmonic
# attribute, try to match the result_field to the input_field
if hasattr(self.domain, 'harmonic') and \
hasattr(self.target, 'harmonic'):
if x.domain.harmonic != y.domain.harmonic:
y = y.transform()
return y
def times(self, x, **kwargs):
"""
......@@ -1151,7 +1135,7 @@ class diagonal_operator(operator):
self.target = self.domain
self.cotarget = self.codomain
self.imp = True
self.set_diag(new_diag=diag)
self.set_diag(new_diag=diag, bare=bare)
def set_diag(self, new_diag, bare=False):
"""
......@@ -1605,14 +1589,16 @@ class diagonal_operator(operator):
else:
codomain = domain.get_codomain()
if domain.harmonic != self.domain.harmonic:
temp_field = temp_field.transform(codomain=domain)
return field(domain=domain, val=temp_field, codomain=codomain)
if self.domain == domain and self.codomain == codomain:
return temp_field
else:
return temp_field.copy(domain=domain,
codomain=codomain)
# if domain.harmonic != self.domain.harmonic:
# temp_field = temp_field.transform(new_domain=domain)
#
# if self.domain == domain and self.codomain == codomain:
# return temp_field
# else:
# return temp_field.copy(domain=domain,
# codomain=codomain)
def __repr__(self):
return "<nifty_core.diagonal_operator>"
......@@ -2388,16 +2374,18 @@ class projection_operator(operator):
# check if field is in the same signal/harmonic space as the
# domain of the projection operator
if self.domain != x.domain:
x = x.transform(codomain=self.domain)
x = x.transform(new_domain=self.domain)
vecvec = vecvec_operator(val=x)
return self.pseudo_tr(x=vecvec, axis=axis, **kwargs)
# Case 2: x is an operator
# -> take the diagonal
elif isinstance(x, operator):
working_field = x.diag(bare=False)
working_field = x.diag(bare=False,
domain=self.domain,
codomain=self.codomain)
if self.domain != working_field.domain:
working_field = working_field.transform(codomain=self.domain)
working_field = working_field.transform(new_domain=self.domain)
# Case 3: x is something else
else:
......@@ -2944,49 +2932,53 @@ class response_operator(operator):
codomain=self.codomain)
def _briefing(self, x, domain, codomain, inverse):
# inspect x
if not isinstance(x, field):
y = field(domain, codomain=codomain, val=x)
else:
# check x.domain
if x.domain == domain:
y = x
else:
if x.domain.harmonic != domain.harmonic:
y = x.transform(codomain=domain)
else:
y = x.copy(domain=domain, codomain=codomain)
# make sure, that the result_field of the briefing lives in the
# given domain and codomain
result_field = field(domain=domain, val=x, codomain=codomain,
copy=False)
# weight if ...
# weight if necessary
if (not self.imp) and (not domain.discrete) and (not inverse) and \
self.den:
y = y.weight(power=1)
return y
result_field = result_field.weight(power=1)
return result_field
def _debriefing(self, x, y, target, cotarget, inverse):
# > evaluates x and y after `multiply`
if y is None:
return None
else:
# inspect y
if not isinstance(y, field):
y = field(target, codomain=cotarget, val=y)
elif y.domain != target:
raise ValueError(about._errors.cstring(
"ERROR: invalid output domain."))
# weight if ...
if (not self.imp) and (not target.discrete) and \
(not self.den ^ inverse):
y = y.weight(power=-1)
# inspect x
if isinstance(x, field):
# repair if the originally field was living in the codomain
# of the operators domain
if self.domain == self.target == x.codomain:
y = y.transform(codomain=x.domain)
if x.domain == y.domain and (x.codomain != y.codomain):
y.set_codomain(new_codomain=x.codomain)
return y
# The debriefing takes care that the result field lives in the same
# fourier-type domain as the input field
assert(isinstance(y, field))
# weight if necessary
if (not self.imp) and (not target.discrete) and \
(not self.den ^ inverse):
y = y.weight(power=-1)
return y
#
#
# # > evaluates x and y after `multiply`
# if y is None:
# return None
# else:
# # inspect y
# if not isinstance(y, field):
# y = field(target, codomain=cotarget, val=y)
# elif y.domain != target:
# raise ValueError(about._errors.cstring(
# "ERROR: invalid output domain."))
# # weight if ...
# if (not self.imp) and (not target.discrete) and \
# (not self.den ^ inverse):
# y = y.weight(power=-1)
# # inspect x
# if isinstance(x, field):
# # repair if the originally field was living in the codomain
# # of the operators domain
# if self.domain == self.target == x.codomain:
# y = y.transform(new_domain=x.domain)
# if x.domain == y.domain and (x.codomain != y.codomain):
# y.set_codomain(new_codomain=x.codomain)
# return y
def __repr__(self):
return "<nifty_core.response_operator>"
......@@ -3159,9 +3151,11 @@ class invertible_operator(operator):
if not force or x_ is None:
return None
about.warnings.cprint("WARNING: conjugate gradient failed.")
# weight if ...
if not self.imp: # continiuos domain/target
x_.weight(power=-1, overwrite=True)
# TODO: A weighting here shoud be wrong, as this is done by
# the (de)briefing methods -> Check!
# # weight if ...
# if not self.imp: # continiuos domain/target
# x_.weight(power=-1, overwrite=True)
return x_
def _inverse_multiply(self, x, force=False, W=None, spam=None, reset=None,
......@@ -3230,15 +3224,18 @@ class invertible_operator(operator):
if not force or x_ is None:
return None
about.warnings.cprint("WARNING: conjugate gradient failed.")
# weight if ...
if not self.imp: # continiuos domain/target
x_.weight(power=1, overwrite=True)
# TODO: A weighting here shoud be wrong, as this is done by
# the (de)briefing methods -> Check!
# # weight if ...
# if not self.imp: # continiuos domain/target
# x_.weight(power=1, overwrite=True)
return x_
def __repr__(self):
return "<nifty_tools.invertible_operator>"
class propagator_operator(operator):
"""
.. __
......@@ -3313,6 +3310,186 @@ class propagator_operator(operator):
"""
def __init__(self, S=None, M=None, R=None, N=None):
"""
Sets the standard operator properties and `codomain`, `_A1`, `_A2`,
and `RN` if required.
Parameters
----------
S : operator
Covariance of the signal prior.
M : operator
Likelihood contribution.
R : operator
Response operator translating signal to (noiseless) data.
N : operator
Covariance of the noise prior or the likelihood, respectively.
"""
# parse the signal prior covariance
if not isinstance(S, operator):
raise ValueError(about._errors.cstring(
"ERROR: The given S is not an operator."))
self.S = S
self.S_inverse_times = self.S.inverse_times
# take signal-space domain from S as the domain for D
S_is_harmonic = False
if hasattr(S.domain, 'harmonic'):
if S.domain.harmonic:
S_is_harmonic = True
if S_is_harmonic:
self.domain = S.codomain
self.codomain = S.domain
else:
self.domain = S.domain
self.codomain = S.codomain
self.target = self.domain
self.cotarget = self.codomain
# build up the likelihood contribution
(self.M_times,
M_domain,
M_codomain,
M_target,
M_cotarget) = self._build_likelihood_contribution(M, R, N)
# assert that S and M have matching domains
if not (self.domain == M_domain and
self.codomain == M_codomain and
self.target == M_target and
self.cotarget == M_cotarget):
raise ValueError(about._errors.cstring(
"ERROR: The (co)domains and (co)targets of the prior " +
"signal covariance and the likelihood contribution must be " +
"the same in the sense of '=='."))
self.sym = True
self.uni = False
self.imp = True
def _build_likelihood_contribution(self, M, R, N):
# if a M is given, return its times method and its domains
# supplier and discard R and N
if M is not None:
return (M.times, M.domain, M.codomain, M.target, M.cotarget)
if N is not None:
if R is not None:
return (lambda z: R.adjoint_times(N.inverse_times(R.times(z))),
R.domain, R.codomain, R.domain, R.codomain)
else:
return (N.inverse_times,
N.domain, N.codomain, N.target, N.cotarget)
else:
raise ValueError(about._errors.cstring(
"ERROR: At least M or N must be given."))
def _multiply(self, x, W=None, spam=None, reset=None, note=False,
x0=None, tol=1E-4, clevel=1, limii=None, **kwargs):
if W is None:
W = self.S
(result, convergence) = conjugate_gradient(self._inverse_multiply,
x,
W=W,
spam=spam,
reset=reset,
note=note)(x0=x0,
tol=tol,
clevel=clevel,
limii=limii)
# evaluate
if not convergence:
about.warnings.cprint("WARNING: conjugate gradient failed.")
return result
def _inverse_multiply(self, x, **kwargs):
result = self.S_inverse_times(x)
result += self.M_times(x)
return result
class propagator_operator_old(operator):
"""
.. __
.. / /_
.. _______ _____ ______ ______ ____ __ ____ __ ____ __ / _/ ______ _____
.. / _ / / __/ / _ | / _ | / _ / / _ / / _ / / / / _ | / __/
.. / /_/ / / / / /_/ / / /_/ / / /_/ / / /_/ / / /_/ / / /_ / /_/ / / /
.. / ____/ /__/ \______/ / ____/ \______| \___ / \______| \___/ \______/ /__/ operator class
.. /__/ /__/ /______/
NIFTY subclass for propagator operators (of a certain family)
The propagator operators :math:`D` implemented here have an inverse
formulation like :math:`(S^{-1} + M)`, :math:`(S^{-1} + N^{-1})`, or
:math:`(S^{-1} + R^\dagger N^{-1} R)` as appearing in Wiener filter
theory.
Parameters
----------
S : operator
Covariance of the signal prior.
M : operator
Likelihood contribution.
R : operator
Response operator translating signal to (noiseless) data.
N : operator
Covariance of the noise prior or the likelihood, respectively.
See Also
--------
conjugate_gradient
Notes
-----
The propagator will puzzle the operators `S` and `M` or `R`, `N` or
only `N` together in the predefined from, a domain is set
automatically. The application of the inverse is done by invoking a
conjugate gradient.
Note that changes to `S`, `M`, `R` or `N` auto-update the propagator.
Examples
--------
>>> f = field(rg_space(4), val=[2, 4, 6, 8])
>>> S = power_operator(f.target, spec=1)
>>> N = diagonal_operator(f.domain, diag=1)
>>> D = propagator_operator(S=S, N=N) # D^{-1} = S^{-1} + N^{-1}
>>> D(f).val
array([ 1., 2., 3., 4.])
Attributes
----------
domain : space
A space wherein valid arguments live.
codomain : space
An alternative space wherein valid arguments live; commonly the
codomain of the `domain` attribute.
sym : bool
Indicates that the operator is self-adjoint.
uni : bool
Indicates that the operator is not unitary.
imp : bool
Indicates that volume weights are implemented in the `multiply`
instance methods.