Commit a09c6ed0 authored by Theo Steininger's avatar Theo Steininger

Fixed spaces handling, especially in ComposedOperator. Added...

Fixed spaces handling, especially in ComposedOperator. Added create_composed_fft_operator to sugar.py. Fixed small bug in plotting.
parent 0808fda1
......@@ -23,7 +23,7 @@ class WienerFilterEnergy(Energy):
The prior signal covariance in harmonic space.
"""
def __init__(self, position, d, R, N, S, inverter=None):
def __init__(self, position, d, R, N, S):
super(WienerFilterEnergy, self).__init__(position=position)
self.d = d
self.R = R
......@@ -32,7 +32,7 @@ class WienerFilterEnergy(Energy):
def at(self, position):
return self.__class__(position=position, d=self.d, R=self.R, N=self.N,
S=self.S, inverter=self.inverter)
S=self.S)
@property
@memo
......@@ -49,6 +49,7 @@ class WienerFilterEnergy(Energy):
def curvature(self):
return WienerFilterCurvature(R=self.R, N=self.N, S=self.S)
@property
@memo
def _Dx(self):
return self.curvature(self.position)
......
......@@ -147,10 +147,11 @@ class ComposedOperator(LinearOperator):
def _inverse_times_helper(self, x, spaces, func):
space_index = 0
if spaces is None:
spaces = range(len(self.target))[::-1]
spaces = range(len(self.target))
rev_spaces = spaces[::-1]
for op in reversed(self._operator_store):
active_spaces = spaces[space_index:space_index+len(op.target)]
active_spaces = rev_spaces[space_index:space_index+len(op.target)]
space_index += len(op.target)
x = getattr(op, func)(x, spaces=active_spaces)
x = getattr(op, func)(x, spaces=active_spaces[::-1])
return x
......@@ -270,8 +270,11 @@ class LinearOperator(Loggable, object):
raise ValueError(
"supplied object is not a `Field`.")
if spaces is None:
spaces = self.default_spaces
if spaces is None and self.default_spaces is not None:
if not inverse:
spaces = self.default_spaces
else:
spaces = self.default_spaces[::-1]
# sanitize the `spaces` and `types` input
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
......
......@@ -66,21 +66,12 @@ class ResponseOperator(LinearOperator):
"""
def __init__(self, domain,
sigma=[1.], exposure=[1.],
def __init__(self, domain, sigma=[1.], exposure=[1.],
default_spaces=None):
super(ResponseOperator, self).__init__(default_spaces)
self._domain = self._parse_domain(domain)
shapes = len(self._domain)*[None]
shape_target = []
for ii in xrange(len(shapes)):
shapes[ii] = self._domain[ii].shape
shape_target = np.append(shape_target, self._domain[ii].shape)
self._target = self._parse_domain(FieldArray(shape_target))
kernel_smoothing = len(self._domain)*[None]
kernel_exposure = len(self._domain)*[None]
......@@ -97,6 +88,12 @@ class ResponseOperator(LinearOperator):
self._composed_kernel = ComposedOperator(kernel_smoothing)
self._composed_exposure = ComposedOperator(kernel_exposure)
target_list = []
for space in self.domain:
target_list += [FieldArray(space.shape)]
self._target = self._parse_domain(target_list)
@property
def domain(self):
return self._domain
......@@ -119,8 +116,6 @@ class ResponseOperator(LinearOperator):
def _adjoint_times(self, x, spaces):
# setting correct spaces
res = Field(self.domain, val=x.val)
if spaces is not None:
spaces = range(spaces[0], spaces[0]+len(res.domain))
res = self._composed_exposure.adjoint_times(res, spaces)
res = res.weight(power=-1)
res = self._composed_kernel.adjoint_times(res, spaces)
......
......@@ -6,7 +6,7 @@ from nifty.plotting.plotly_wrapper import PlotlyWrapper
class Axis(PlotlyWrapper):
def __init__(self, label=None, font='Balto', color='', log=False,
font_size=22, show_grid=True, visible=True):
self.label = str(label)
self.label = str(label) if label is not None else None
self.font = font
self.color = color
self.log = log
......@@ -32,6 +32,7 @@ class Axis(PlotlyWrapper):
ply_object['visible'] = self.visible
ply_object['tickfont'] = {'size': self.font_size,
'family': self.font}
ply_object['exponentformat'] = 'power'
# ply_object['domain'] = {'0': '0.04',
# '1': '1'}
return ply_object
......@@ -16,14 +16,10 @@ class Figure2D(FigureFromPlot):
if isinstance(plots[0], Heatmap) and width is None and \
height is None:
(x, y) = plots[0].data.shape
(y, x) = plots[0].data.shape
if x > y:
width = 500
height = int(500*y/x)
else:
height = 500
width = int(500 * y / x)
width = 500
height = int(500*y/x)
if isinstance(plots[0], GLMollweide) or \
isinstance(plots[0], HPMollweide):
......
......@@ -52,7 +52,8 @@ class Heatmap(PlotlyWrapper):
plotly_object['showscale'] = True
plotly_object['colorbar'] = {'tickfont': {'size': self._font_size,
'family': self._font_family}}
'family': self._font_family},
'exponentformat': 'power'}
if self.color_map:
plotly_object['colorscale'] = self.color_map.to_plotly()
if self.webgl:
......@@ -67,7 +68,7 @@ class Heatmap(PlotlyWrapper):
return 700
def default_height(self):
(x, y) = self.data.shape
(y, x) = self.data.shape
return int(700 * y / x)
def default_axes(self):
......
......@@ -16,10 +16,13 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from nifty.sugar import create_composed_fft_operator
class DiagonalProberMixin(object):
def __init__(self, *args, **kwargs):
self.reset()
self.__evaluate_probe_in_signal_space = True
super(DiagonalProberMixin, self).__init__(*args, **kwargs)
def reset(self):
......@@ -30,7 +33,11 @@ class DiagonalProberMixin(object):
super(DiagonalProberMixin, self).reset()
def finish_probe(self, probe, pre_result):
result = probe[1].conjugate()*pre_result
if self.__evaluate_probe_in_signal_space:
fft = create_composed_fft_operator(self._domain, all_to='position')
result = fft(probe[1]).conjugate()*fft(pre_result)
else:
result = probe[1].conjugate()*pre_result
self.__sum_of_probings += result
if self.compute_variance:
self.__sum_of_squares += result.conjugate() * result
......
......@@ -16,10 +16,13 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from nifty.sugar import create_composed_fft_operator
class TraceProberMixin(object):
def __init__(self, *args, **kwargs):
self.reset()
self.__evaluate_probe_in_signal_space = True
super(TraceProberMixin, self).__init__(*args, **kwargs)
def reset(self):
......@@ -30,7 +33,12 @@ class TraceProberMixin(object):
super(TraceProberMixin, self).reset()
def finish_probe(self, probe, pre_result):
result = probe[1].vdot(pre_result, bare=True)
if self.__evaluate_probe_in_signal_space:
fft = create_composed_fft_operator(self._domain, all_to='position')
result = fft(probe[1]).vdot(fft(pre_result), bare=True)
else:
result = probe[1].vdot(pre_result, bare=True)
self.__sum_of_probings += result
if self.compute_variance:
self.__sum_of_squares += result.conjugate() * result
......
......@@ -47,6 +47,7 @@ class Prober(object):
self._random_type = self._parse_random_type(random_type)
self.compute_variance = bool(compute_variance)
self.probe_dtype = np.dtype(probe_dtype)
self._uid_counter = 0
# ---Properties---
......@@ -108,7 +109,8 @@ class Prober(object):
domain=self.domain,
dtype=self.probe_dtype,
distribution_strategy=self.distribution_strategy)
uid = np.random.randint(1e18)
uid = self._uid_counter
self._uid_counter += 1
return (uid, f)
def process_probe(self, callee, probe, index):
......
......@@ -16,13 +16,18 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from nifty import PowerSpace,\
from nifty import Space,\
PowerSpace,\
Field,\
ComposedOperator,\
DiagonalOperator,\
FFTOperator,\
sqrt,\
nifty_configuration
__all__ = ['create_power_operator', 'generate_posterior_sample']
__all__ = ['create_power_operator',
'generate_posterior_sample',
'create_composed_fft_operator']
def create_power_operator(domain, power_spectrum, dtype=None,
......@@ -110,3 +115,23 @@ def generate_posterior_sample(mean, covariance):
mock_m = covariance.inverse_times(mock_j)
sample = mock_signal - mock_m + mean
return sample
def create_composed_fft_operator(domain, codomain=None, all_to='other'):
fft_op_list = []
space_index_list = []
if codomain is None:
codomain = [None]*len(domain)
for i in range(len(domain)):
space = domain[i]
cospace = codomain[i]
if not isinstance(space, Space):
continue
if (all_to == 'other' or
(all_to == 'position' and space.harmonic) or
(all_to == 'harmonic' and not space.harmonic)):
fft_op_list += [FFTOperator(domain=space, target=cospace)]
space_index_list += [i]
result = ComposedOperator(fft_op_list, default_spaces=space_index_list)
return result
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment