Commit ef955b69 authored by Ultima's avatar Ultima
Browse files

Some Bugfixes.

propagator_operator still not working to full extend.
parent b2d496b2
......@@ -46,7 +46,7 @@ from nifty import *
class problem(object):
def __init__(self, x_space, s2n=0.5, **kwargs):
def __init__(self, x_space, s2n=12, **kwargs):
"""
Sets up a Wiener filter problem.
......@@ -67,7 +67,7 @@ class problem(object):
#self.k.set_power_indices(**kwargs)
## set some power spectrum
self.power = (lambda k: 42 / (k + 1) ** 5)
self.power = (lambda k: 42 / (k + 1) ** 3)
## define signal covariance
self.S = power_operator(self.k, spec=self.power, bare=True)
......@@ -82,7 +82,8 @@ class problem(object):
d_space = self.R.target
## define noise covariance
self.N = diagonal_operator(d_space, diag=abs(s2n) * self.s.var(), bare=True)
#self.N = diagonal_operator(d_space, diag=abs(s2n) * self.s.var(), bare=True)
self.N = diagonal_operator(d_space, diag=abs(s2n), bare=True)
## define (plain) projector
self.Nj = projection_operator(d_space)
## generate noise
......@@ -95,7 +96,9 @@ class problem(object):
#self.j = self.R.adjoint_times(self.N.inverse_times(self.d), target=self.k)
self.j = self.R.adjoint_times(self.N.inverse_times(self.d))
## define information propagator
self.D = propagator_operator(S=self.S, N=self.N, R=self.R)
self.D = propagator_operator(S=self.S,
N=self.N,
R=self.R)
## reserve map
self.m = None
......@@ -186,7 +189,7 @@ class problem(object):
print ('power', power)
print 'Calculated power'
power = np.clip(power, 0.00000001, np.max(power))
#power = np.clip(power, 0.00000001, np.max(power))
self.store += [{'tr_B1': tr_B1,
'tr_B2': tr_B2,
'num': numerator,
......@@ -253,7 +256,7 @@ class problem(object):
##-----------------------------------------------------------------------------
#
if(__name__=="__main__"):
x = rg_space((128,128), zerocenter=True)
x = rg_space((128), zerocenter=True)
p = problem(x, log = False)
about.warnings.off()
## pl.close("all")
......
......@@ -8,12 +8,21 @@ shape = (256, 256)
x_space = rg_space(shape)
k_space = x_space.get_codomain()
power = lambda k: 42/((1+k*shape[0])**2)
power = lambda k: 42/((1+k*shape[0])**3)
S = power_operator(k_space, codomain=x_space, spec=power)
s = S.get_random_field(domain=x_space)
#n_points = 360.
#starts = [[(np.cos(i/n_points*np.pi)+1)*shape[0]/2.,
# (np.sin(i/n_points*np.pi)+1)*shape[0]/2.] for i in xrange(int(n_points))]
#starts = list(np.array(starts).T)
#
#ends = [[(np.cos(i/n_points*np.pi + np.pi)+1)*shape[0]/2.,
# (np.sin(i/n_points*np.pi + np.pi)+1)*shape[0]/2.] for i in xrange(int(n_points))]
#ends = list(np.array(ends).T)
def make_los(n=10, angle=0, d=1):
starts_list = []
ends_list = []
......@@ -29,9 +38,9 @@ def make_los(n=10, angle=0, d=1):
ends_list = rot_matrix.dot(ends_list.T-0.5*d).T+0.5*d
return (starts_list, ends_list)
temp_coords = (np.empty((0,2)), np.empty((0,2)))
n = 256
m = 256
temp_coords = (np.empty((0, 2)), np.empty((0, 2)))
n = 250
m = 250
for alpha in [np.pi/n*j for j in xrange(n)]:
temp = make_los(n=m, angle=alpha)
temp_coords = np.concatenate([temp_coords, temp], axis=1)
......@@ -39,27 +48,19 @@ for alpha in [np.pi/n*j for j in xrange(n)]:
starts = list(temp_coords[0].T)
ends = list(temp_coords[1].T)
#n_points = 360.
#starts = [[(np.cos(i/n_points*np.pi)+1)*shape[0]/2.,
# (np.sin(i/n_points*np.pi)+1)*shape[0]/2.] for i in xrange(int(n_points))]
#starts = list(np.array(starts).T)
#
#ends = [[(np.cos(i/n_points*np.pi + np.pi)+1)*shape[0]/2.,
# (np.sin(i/n_points*np.pi + np.pi)+1)*shape[0]/2.] for i in xrange(int(n_points))]
#ends = list(np.array(ends).T)
R = los_response(x_space, starts=starts, ends=ends, sigmas_up=0.1, sigmas_low=0.1)
R = los_response(x_space, starts=starts, ends=ends,
sigmas_up=0.1, sigmas_low=0.1)
d_space = R.target
N = diagonal_operator(d_space, diag=s.var(), bare=True)
N = diagonal_operator(d_space, diag=s.var()/100000, bare=True)
n = N.get_random_field(domain=d_space)
d = R(s) + n
j = R.adjoint_times(N.inverse_times(d))
D = propagator_operator(S=S, N=N, R=R)
m = D(j, W=S, tol=1E-14, limii=100, note=True)
m = D(j, W=S, tol=1E-14, limii=50, note=True)
s.plot(title="signal", save='1_plot_s.png')
s.plot(save='plot_s_power.png', power=True, other=power)
......
......@@ -44,25 +44,28 @@ from nifty import * # version 0.8.0
about.warnings.off()
# some signal space; e.g., a two-dimensional regular grid
x_space = rg_space([128, 128]) # define signal space
shape = [1024,]
x_space = rg_space(shape)
#y_space = point_space(1280*1280)
#x_space = hp_space(32)
#x_space = gl_space(800)
k_space = x_space.get_codomain() # get conjugate space
# some power spectrum
power = (lambda k: 42 / (k + 1) ** 3)
power = (lambda k: 42 / (k + 1) ** 4)
S = power_operator(k_space, codomain=x_space, spec=power) # define signal covariance
s = S.get_random_field(domain=x_space) # generate signal
#my_mask = x_space.cast(1)
#my_mask[400:900,400:900] = 0
#stretch = 0.6
#my_mask[shape[0]/2*stretch:shape[0]/2/stretch, shape[1]/2*stretch:shape[1]/2/stretch] = 0
my_mask = 1
R = response_operator(x_space, sigma=0.01, mask=my_mask, assign=None) # define response
R = response_operator(x_space, assign=None) #
#R = identity_operator(x_space)
d_space = R.target # get data space
# some noise variance; e.g., signal-to-noise ratio of 1
......@@ -79,14 +82,34 @@ D = propagator_operator(S=S, N=N, R=R) # define inform
#m = D(j, tol=1E-8, limii=20, note=True, force=True)
ident = identity(x_space)
xi = field(x_space, random='gau', target=k_space)
m = D(xi, W=ident, tol=1E-8, limii=10, note=True, force=True)
temp_result = (D.inverse_times(m)-xi)
print (temp_result.dot(temp_result))
print (temp_result.val)
#xi = field(x_space, random='gau', target=k_space)
m = D(j, W=S, tol=1E-8, limii=100, note=True)
#temp_result = (D.inverse_times(m)-xi)
s.plot(title="signal", save = 'plot_s.png') # plot signal
d_ = field(x_space, val=d.val, target=k_space)
d_.plot(title="data", vmin=s.min(), vmax=s.max(), save = 'plot_d.png') # plot data
m.plot(title="reconstructed map", vmin=s.min(), vmax=s.max(), save = 'plot_m.png') # plot map
#n_power = x_space.enforce_power(s.var()/np.prod(shape))
#s_power = S.get_power()
#s.plot(title="signal", save = 'plot_s.png')
#s.plot(title="signal power", power=True, other=power,
# mono=False, save = 'power_plot_s.png', nbin=1000, log=True,
# vmax = 100, vmin=10e-7)
#d_ = field(x_space, val=d.val, target=k_space)
#d_.plot(title="data", vmin=s.min(), vmax=s.max(), save = 'plot_d.png')
#n_ = field(x_space, val=n.val, target=k_space)
#n_.plot(title="data", vmin=s.min(), vmax=s.max(), save = 'plot_n.png')
#
#m.plot(title="reconstructed map", vmin=s.min(), vmax=s.max(), save = 'plot_m.png')
#m.plot(title="reconstructed power", power=True, other=(n_power, s_power),
# save = 'power_plot_m.png', vmin=0.001, vmax=10, mono=False)
#
#
......@@ -1921,7 +1921,8 @@ class field(object):
"""
def __init__(self, domain, val=None, codomain=None, ishape=None, **kwargs):
def __init__(self, domain=None, val=None, codomain=None, ishape=None,
**kwargs):
"""
Sets the attributes for a field class instance.
......@@ -1944,6 +1945,9 @@ class field(object):
Nothing
"""
# If the given val was a field, try to cast it accordingly to the given
# domain and codomain, etc...
# check domain
if not isinstance(domain, space):
raise TypeError(about._errors.cstring("ERROR: invalid input."))
......@@ -2359,7 +2363,7 @@ class field(object):
# whether the domain matches exactly or not:
# extract the data from x and try to dot with this
return self.dot(x=x.get_val(), axis=axis)
return self.dot(x=x.get_val(), axis=axis, bare=bare)
# Case 3: x is something else
else:
......
......@@ -631,6 +631,7 @@ class distributed_data_object(object):
temp_d2o = self.copy_empty()
temp_data = np.conj(self.get_local_data())
temp_d2o.set_local_data(temp_data)
temp_d2o.hermitian = self.hermitian
return temp_d2o
def conj(self):
......
......@@ -53,7 +53,9 @@ class power_indices(object):
# Initialize the dictonary which stores all individual index-dicts
self.global_dict = {}
# Set self.default_parameters
self.set_default(log=log, nbin=nbin, binbounds=binbounds)
self.set_default(config_dict={'log': log,
'nbin': nbin,
'binbounds': binbounds})
# Redirect the direct calls approaching a power_index instance to the
# default_indices dict
......@@ -217,7 +219,10 @@ class power_indices(object):
# indices, bin them, compute the pundex and then return everything.
else:
# Get the unbinned indices
temp_unbinned_indices = self.get_index_dict(store=False)
temp_unbinned_indices = self.get_index_dict(nbin=None,
binbounds=None,
log=False,
store=False)
# Bin them
(temp_pindex, temp_kindex, temp_rho, temp_pundex) = \
self._bin_power_indices(
......
......@@ -2,23 +2,24 @@
import numpy as np
def hermitianize(x):
## make the point inversions
def hermitianize_gaussian(x):
# make the point inversions
flipped_x = _hermitianize_inverter(x)
flipped_x = flipped_x.conjugate()
## check if x was already hermitian
# check if x was already hermitian
if (x == flipped_x).all():
return x
## average x and flipped_x.
## Correct the variance by multiplying sqrt(0.5)
# average x and flipped_x.
# Correct the variance by multiplying sqrt(0.5)
x = (x + flipped_x) * np.sqrt(0.5)
## The fixed points of the point inversion must not be avaraged.
## Hence one must multiply them again with sqrt(0.5)
## -> Get the middle index of the array
# The fixed points of the point inversion must not be avaraged.
# Hence one must multiply them again with sqrt(0.5)
# -> Get the middle index of the array
mid_index = np.array(x.shape, dtype=np.int)//2
dimensions = mid_index.size
## Use ndindex to iterate over all combinations of zeros and the
## mid_index in order to correct all fixed points.
# Use ndindex to iterate over all combinations of zeros and the
# mid_index in order to correct all fixed points.
for i in np.ndindex((2,)*dimensions):
temp_index = tuple(i*mid_index)
x[temp_index] *= np.sqrt(0.5)
......@@ -30,18 +31,35 @@ def hermitianize(x):
return x
def hermitianize(x):
# make the point inversions
flipped_x = _hermitianize_inverter(x)
flipped_x = flipped_x.conjugate()
# check if x was already hermitian
if (x == flipped_x).all():
return x
# average x and flipped_x.
# Correct the variance by multiplying sqrt(0.5)
x = (x + flipped_x) / 2.
try:
x.hermitian = True
except(AttributeError):
pass
return x
def _hermitianize_inverter(x):
## calculate the number of dimensions the input array has
# calculate the number of dimensions the input array has
dimensions = len(x.shape)
## prepare the slicing object which will be used for mirroring
slice_primitive = [slice(None),]*dimensions
## copy the input data
# prepare the slicing object which will be used for mirroring
slice_primitive = [slice(None), ]*dimensions
# copy the input data
y = x.copy()
## flip in every direction
# flip in every direction
for i in xrange(dimensions):
slice_picker = slice_primitive[:]
slice_picker[i] = slice(1, None,None)
slice_picker[i] = slice(1, None, None)
slice_inverter = slice_primitive[:]
slice_inverter[i] = slice(None, 0, -1)
......@@ -54,7 +72,7 @@ def _hermitianize_inverter(x):
def direct_dot(x, y):
## the input could be fields. Try to extract the data
# the input could be fields. Try to extract the data
try:
x = x.get_val()
except(AttributeError):
......@@ -63,7 +81,7 @@ def direct_dot(x, y):
y = y.get_val()
except(AttributeError):
pass
## try to make a direct vdot
# try to make a direct vdot
try:
return x.vdot(y)
except(AttributeError):
......@@ -74,16 +92,16 @@ def direct_dot(x, y):
except(AttributeError):
pass
## fallback to numpy
# fallback to numpy
return np.vdot(x, y)
def convert_nested_list_to_object_array(x):
## if x is a nested_list full of ndarrays all having the same size,
## np.shape returns the shape of the ndarrays, too, i.e. too many
## dimensions
# if x is a nested_list full of ndarrays all having the same size,
# np.shape returns the shape of the ndarrays, too, i.e. too many
# dimensions
possible_shape = np.shape(x)
## Check if possible_shape goes too deep.
# Check if possible_shape goes too deep.
dimension_counter = 0
current_extract = x
for i in xrange(len(possible_shape)):
......@@ -93,11 +111,11 @@ def convert_nested_list_to_object_array(x):
current_extract = current_extract[0]
dimension_counter += 1
real_shape = possible_shape[:dimension_counter]
## if the numpy array was not encapsulated at all, return x directly
# if the numpy array was not encapsulated at all, return x directly
if real_shape == ():
return x
## Prepare the carrier-object
carrier = np.empty(real_shape, dtype = np.object)
# Prepare the carrier-object
carrier = np.empty(real_shape, dtype=np.object)
for i in xrange(np.prod(real_shape)):
ii = np.unravel_index(i, real_shape)
try:
......@@ -121,10 +139,10 @@ def field_map(ishape, function, *args):
result[ii] = function()
return result
else:
## define a helper function in order to clip the get-indices
## to be suitable for the foreign arrays in args.
## This allows you to do operations, like adding to fields
## with ishape (3,4,3) and (3,4,1)
# define a helper function in order to clip the get-indices
# to be suitable for the foreign arrays in args.
# This allows you to do operations, like adding to fields
# with ishape (3,4,3) and (3,4,1)
def get_clipped(w, ind):
w_shape = np.array(np.shape(w))
get_tuple = tuple(np.clip(ind, 0, w_shape-1))
......@@ -135,5 +153,5 @@ def field_map(ishape, function, *args):
result[ii] = function(*map(
lambda z: get_clipped(z, ii), args)
)
#result[ii] = function(*map(lambda z: z[ii], args))
return result
\ No newline at end of file
# result[ii] = function(*map(lambda z: z[ii], args))
return result
......@@ -2292,7 +2292,7 @@ class projection_operator(operator):
raise TypeError(about._errors.cstring("ERROR: Invalid bands."))
if bands_was_scalar:
new_field = x * (self.assign == bands[0])
new_field = fx * (self.assign == bands[0])
else:
# build up the projection results
# prepare the projector-carrier
......@@ -2392,15 +2392,18 @@ class projection_operator(operator):
vecvec = vecvec_operator(val=x)
return self.pseudo_tr(x=vecvec, axis=axis, **kwargs)
# Case 2: x is not an operator
elif isinstance(x, operator) == False:
# Case 2: x is an operator
# -> take the diagonal
elif isinstance(x, operator):
working_field = x.diag(bare=False)
if self.domain != working_field.domain:
working_field = working_field.transform(codomain=self.domain)
# Case 3: x is something else
else:
raise TypeError(about._errors.cstring(
"ERROR: x must be a field or an operator."))
# Case 3: x is an operator
# -> take the diagonal
working_field = x.diag()
# Check for hidden degrees of freedom and compensate the trace
# accordingly
if self.domain.get_dim() != self.domain.get_dof():
......
## NIFTY (Numerical Information Field Theory) has been developed at the
## Max-Planck-Institute for Astrophysics.
##
## Copyright (C) 2015 Max-Planck-Society
##
## Author: Theo Steininger
## Project homepage: <http://www.mpa-garching.mpg.de/ift/nifty/>
##
## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
## See the GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program. If not, see <http://www.gnu.org/licenses/>.
# NIFTY (Numerical Information Field Theory) has been developed at the
# Max-Planck-Institute for Astrophysics.
#
# Copyright (C) 2015 Max-Planck-Society
#
# Author: Theo Steininger
# Project homepage: <http://www.mpa-garching.mpg.de/ift/nifty/>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import division
......@@ -27,7 +27,7 @@ from nifty.nifty_core import space, \
from nifty.nifty_utilities import direct_dot
##=============================================================================
class prober(object):
"""
......@@ -177,59 +177,61 @@ class prober(object):
zeroth entry and the variance in the first entry. (default: False)
"""
## Case 1: no operator given. Check function and domain for general
## sanity
# Case 1: no operator given. Check function and domain for general
# sanity
if operator is None:
## check whether the given function callable
if function is None or hasattr(function, "__call__") == False:
# check whether the given function callable
if function is None or not hasattr(function, "__call__"):
raise ValueError(about._errors.cstring(
"ERROR: invalid input: No function given or not callable."))
## check given domain
if domain is None or isinstance(domain, space) == False:
"ERROR: invalid input: No function given or not callable."))
# check given domain
if domain is None or not isinstance(domain, space):
raise ValueError(about._errors.cstring(
"ERROR: invalid input: given domain is not a nifty space"))
"ERROR: invalid input: given domain is not a nifty space"))
## Case 2: An operator is given. Take domain and function from that
## if not given explicitly
# Case 2: An operator is given. Take domain and function from that
# if not given explicitly
else:
## Case 2.1 extract function
## explicit function overrides operator function
if function is None or hasattr(function,"__call__") == False:
# Check 2.1 extract function
# explicit function overrides operator function
if function is None or not hasattr(function, "__call__"):
try:
function = operator.times
except(AttributeError):
raise ValueError(about._errors.cstring(
"ERROR: no explicit function given and given operator has no times method!"))
## check whether the given function is correctly bound to the
## operator
"ERROR: no explicit function given and given " +
"operator has no times method!"))
# Check 2.2 check whether the given function is correctly bound to
# the operator
if operator != function.im_self:
raise ValueError(about._errors.cstring(
"ERROR: the given function is not a bound function of the operator!"))
## Case 2.2 extract domain
if domain is None or isinstance(domain, space):
if (function in [operator.inverse_times,
operator.adjoint_times]):
try:
domain = operator.target
except(AttributeError):
raise ValueError(about._errors.cstring(
"ERROR: no explicit domain given and given operator has no target!"))
else:
try:
domain = operator.domain
except(AttributeError):
raise ValueError(about._errors.cstring(
"ERROR: no explicit domain given and given operator has no domain!"))
"ERROR: the given function is not a bound function " +
"of the operator!"))
# Check 2.3 extract domain
if domain is None or not isinstance(domain, space):
if (function in [operator.inverse_times,
operator.adjoint_times]):
try:
domain = operator.target
except(AttributeError):
raise ValueError(about._errors.cstring(
"ERROR: no explicit domain given and given " +
"operator has no target!"))
else:
try:
domain = operator.domain
except(AttributeError):
raise ValueError(about._errors.cstring(
"ERROR: no explicit domain given and given " +
"operator has no domain!"))
self.function = function
self.domain = domain
## Check the given target
# Check the given codomain
if codomain is None:
codomain = self.domain.get_codomain()
else:
......@@ -237,71 +239,16 @@ class prober(object):
self.codomain = codomain
if(random not in ["pm1","gau"]):
if random not in ["pm1", "gau"]:
raise ValueError(about._errors.cstring(
"ERROR: unsupported random key '"+str(random)+"'."))
"ERROR: unsupported random key '" + str(random) + "'."))
self.random = random
## Parse the remaining arguments
# Parse the remaining arguments
self.nrun = int(nrun)
self.varQ = bool(varQ)
self.kwargs = kwargs
"""
from nifty_operators import operator
if(not isinstance(op,operator)):
raise TypeError(about._errors.cstring("ERROR: invalid input."))