diff --git a/nifty/library/wiener_filter/wiener_filter_energy.py b/nifty/library/wiener_filter/wiener_filter_energy.py index 8eec64c51746c6121008f8890b79fcf1da25654d..61a0d891e2a8a20175485a8a3db2fbfc78653c47 100644 --- a/nifty/library/wiener_filter/wiener_filter_energy.py +++ b/nifty/library/wiener_filter/wiener_filter_energy.py @@ -23,7 +23,7 @@ class WienerFilterEnergy(Energy): The prior signal covariance in harmonic space. """ - def __init__(self, position, d, R, N, S, inverter=None): + def __init__(self, position, d, R, N, S): super(WienerFilterEnergy, self).__init__(position=position) self.d = d self.R = R @@ -32,7 +32,7 @@ class WienerFilterEnergy(Energy): def at(self, position): return self.__class__(position=position, d=self.d, R=self.R, N=self.N, - S=self.S, inverter=self.inverter) + S=self.S) @property @memo @@ -49,6 +49,7 @@ class WienerFilterEnergy(Energy): def curvature(self): return WienerFilterCurvature(R=self.R, N=self.N, S=self.S) + @property @memo def _Dx(self): return self.curvature(self.position) diff --git a/nifty/operators/composed_operator/composed_operator.py b/nifty/operators/composed_operator/composed_operator.py index 3b925d65736848b26d539b772ba064b00f23b6e4..993385df25949dda1c89db37132538ce9626915b 100644 --- a/nifty/operators/composed_operator/composed_operator.py +++ b/nifty/operators/composed_operator/composed_operator.py @@ -147,10 +147,11 @@ class ComposedOperator(LinearOperator): def _inverse_times_helper(self, x, spaces, func): space_index = 0 if spaces is None: - spaces = range(len(self.target))[::-1] + spaces = range(len(self.target)) + rev_spaces = spaces[::-1] for op in reversed(self._operator_store): - active_spaces = spaces[space_index:space_index+len(op.target)] + active_spaces = rev_spaces[space_index:space_index+len(op.target)] space_index += len(op.target) - x = getattr(op, func)(x, spaces=active_spaces) + x = getattr(op, func)(x, spaces=active_spaces[::-1]) return x diff --git a/nifty/operators/linear_operator/linear_operator.py b/nifty/operators/linear_operator/linear_operator.py index 38b52f341861758597c664fc917a39b4d67493b8..c0274e297ef6bc745ded9f99a9a067f72ce49b05 100644 --- a/nifty/operators/linear_operator/linear_operator.py +++ b/nifty/operators/linear_operator/linear_operator.py @@ -270,8 +270,11 @@ class LinearOperator(Loggable, object): raise ValueError( "supplied object is not a `Field`.") - if spaces is None: - spaces = self.default_spaces + if spaces is None and self.default_spaces is not None: + if not inverse: + spaces = self.default_spaces + else: + spaces = self.default_spaces[::-1] # sanitize the `spaces` and `types` input spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) diff --git a/nifty/operators/response_operator/response_operator.py b/nifty/operators/response_operator/response_operator.py index 17ff7f88d1d83af1c5a0d46ad207465ac03db4aa..c36b7d1b13a7dbb57a48970dc68fb177002a9dd4 100644 --- a/nifty/operators/response_operator/response_operator.py +++ b/nifty/operators/response_operator/response_operator.py @@ -66,21 +66,12 @@ class ResponseOperator(LinearOperator): """ - def __init__(self, domain, - sigma=[1.], exposure=[1.], + def __init__(self, domain, sigma=[1.], exposure=[1.], default_spaces=None): super(ResponseOperator, self).__init__(default_spaces) self._domain = self._parse_domain(domain) - shapes = len(self._domain)*[None] - shape_target = [] - for ii in xrange(len(shapes)): - shapes[ii] = self._domain[ii].shape - shape_target = np.append(shape_target, self._domain[ii].shape) - - self._target = self._parse_domain(FieldArray(shape_target)) - kernel_smoothing = len(self._domain)*[None] kernel_exposure = len(self._domain)*[None] @@ -97,6 +88,12 @@ class ResponseOperator(LinearOperator): self._composed_kernel = ComposedOperator(kernel_smoothing) self._composed_exposure = ComposedOperator(kernel_exposure) + target_list = [] + for space in self.domain: + target_list += [FieldArray(space.shape)] + + self._target = self._parse_domain(target_list) + @property def domain(self): return self._domain @@ -119,8 +116,6 @@ class ResponseOperator(LinearOperator): def _adjoint_times(self, x, spaces): # setting correct spaces res = Field(self.domain, val=x.val) - if spaces is not None: - spaces = range(spaces[0], spaces[0]+len(res.domain)) res = self._composed_exposure.adjoint_times(res, spaces) res = res.weight(power=-1) res = self._composed_kernel.adjoint_times(res, spaces) diff --git a/nifty/plotting/descriptors/axis.py b/nifty/plotting/descriptors/axis.py index e6e9dd6906d2f48e8a74b03baaac774ceed60358..c30b67b6a729e3a5b0dfb39e81e7952d954db939 100644 --- a/nifty/plotting/descriptors/axis.py +++ b/nifty/plotting/descriptors/axis.py @@ -6,7 +6,7 @@ from nifty.plotting.plotly_wrapper import PlotlyWrapper class Axis(PlotlyWrapper): def __init__(self, label=None, font='Balto', color='', log=False, font_size=22, show_grid=True, visible=True): - self.label = str(label) + self.label = str(label) if label is not None else None self.font = font self.color = color self.log = log @@ -32,6 +32,7 @@ class Axis(PlotlyWrapper): ply_object['visible'] = self.visible ply_object['tickfont'] = {'size': self.font_size, 'family': self.font} + ply_object['exponentformat'] = 'power' # ply_object['domain'] = {'0': '0.04', # '1': '1'} return ply_object diff --git a/nifty/plotting/figures/figure_2D.py b/nifty/plotting/figures/figure_2D.py index 9165e4bdca2b47cef348a68841d537e112fd6d70..569472f33c9631b58ac44ec2aee2ecbb955c4f1b 100644 --- a/nifty/plotting/figures/figure_2D.py +++ b/nifty/plotting/figures/figure_2D.py @@ -16,14 +16,10 @@ class Figure2D(FigureFromPlot): if isinstance(plots[0], Heatmap) and width is None and \ height is None: - (x, y) = plots[0].data.shape + (y, x) = plots[0].data.shape - if x > y: - width = 500 - height = int(500*y/x) - else: - height = 500 - width = int(500 * y / x) + width = 500 + height = int(500*y/x) if isinstance(plots[0], GLMollweide) or \ isinstance(plots[0], HPMollweide): diff --git a/nifty/plotting/plots/heatmaps/heatmap.py b/nifty/plotting/plots/heatmaps/heatmap.py index 497396f47a4b58a69d7de8639e125b762bba44c1..98a8b3f444231a7a8ae1a0ec558efdd4f481c869 100644 --- a/nifty/plotting/plots/heatmaps/heatmap.py +++ b/nifty/plotting/plots/heatmaps/heatmap.py @@ -52,7 +52,8 @@ class Heatmap(PlotlyWrapper): plotly_object['showscale'] = True plotly_object['colorbar'] = {'tickfont': {'size': self._font_size, - 'family': self._font_family}} + 'family': self._font_family}, + 'exponentformat': 'power'} if self.color_map: plotly_object['colorscale'] = self.color_map.to_plotly() if self.webgl: @@ -67,7 +68,7 @@ class Heatmap(PlotlyWrapper): return 700 def default_height(self): - (x, y) = self.data.shape + (y, x) = self.data.shape return int(700 * y / x) def default_axes(self): diff --git a/nifty/probing/mixin_classes/diagonal_prober_mixin.py b/nifty/probing/mixin_classes/diagonal_prober_mixin.py index 0432050dbe815e3b7b68b9fab15b96083081ab81..e854febf80599e85179409585755088308fa6f7f 100644 --- a/nifty/probing/mixin_classes/diagonal_prober_mixin.py +++ b/nifty/probing/mixin_classes/diagonal_prober_mixin.py @@ -16,10 +16,13 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. +from nifty.sugar import create_composed_fft_operator + class DiagonalProberMixin(object): def __init__(self, *args, **kwargs): self.reset() + self.__evaluate_probe_in_signal_space = True super(DiagonalProberMixin, self).__init__(*args, **kwargs) def reset(self): @@ -30,7 +33,11 @@ class DiagonalProberMixin(object): super(DiagonalProberMixin, self).reset() def finish_probe(self, probe, pre_result): - result = probe[1].conjugate()*pre_result + if self.__evaluate_probe_in_signal_space: + fft = create_composed_fft_operator(self._domain, all_to='position') + result = fft(probe[1]).conjugate()*fft(pre_result) + else: + result = probe[1].conjugate()*pre_result self.__sum_of_probings += result if self.compute_variance: self.__sum_of_squares += result.conjugate() * result diff --git a/nifty/probing/mixin_classes/trace_prober_mixin.py b/nifty/probing/mixin_classes/trace_prober_mixin.py index d2054dd6fbbab3eb22c15298d9efd5179441e282..243e8238fa6d1ee672473685dfcbcde2926bab3c 100644 --- a/nifty/probing/mixin_classes/trace_prober_mixin.py +++ b/nifty/probing/mixin_classes/trace_prober_mixin.py @@ -16,10 +16,13 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. +from nifty.sugar import create_composed_fft_operator + class TraceProberMixin(object): def __init__(self, *args, **kwargs): self.reset() + self.__evaluate_probe_in_signal_space = True super(TraceProberMixin, self).__init__(*args, **kwargs) def reset(self): @@ -30,7 +33,12 @@ class TraceProberMixin(object): super(TraceProberMixin, self).reset() def finish_probe(self, probe, pre_result): - result = probe[1].vdot(pre_result, bare=True) + if self.__evaluate_probe_in_signal_space: + fft = create_composed_fft_operator(self._domain, all_to='position') + result = fft(probe[1]).vdot(fft(pre_result), bare=True) + else: + result = probe[1].vdot(pre_result, bare=True) + self.__sum_of_probings += result if self.compute_variance: self.__sum_of_squares += result.conjugate() * result diff --git a/nifty/probing/prober/prober.py b/nifty/probing/prober/prober.py index f292c495f047e4e376bd77d5c37e164ce289c3f5..65eae6811eee9f12d7e2f0ed355a6e907badd766 100644 --- a/nifty/probing/prober/prober.py +++ b/nifty/probing/prober/prober.py @@ -47,6 +47,7 @@ class Prober(object): self._random_type = self._parse_random_type(random_type) self.compute_variance = bool(compute_variance) self.probe_dtype = np.dtype(probe_dtype) + self._uid_counter = 0 # ---Properties--- @@ -108,7 +109,8 @@ class Prober(object): domain=self.domain, dtype=self.probe_dtype, distribution_strategy=self.distribution_strategy) - uid = np.random.randint(1e18) + uid = self._uid_counter + self._uid_counter += 1 return (uid, f) def process_probe(self, callee, probe, index): diff --git a/nifty/sugar.py b/nifty/sugar.py index 0db2425bcd399c9edf80220d2be1f198343d9ea8..7c1941ef13f26b53cccc23e99059a29caa46fd44 100644 --- a/nifty/sugar.py +++ b/nifty/sugar.py @@ -16,13 +16,18 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. -from nifty import PowerSpace,\ +from nifty import Space,\ + PowerSpace,\ Field,\ + ComposedOperator,\ DiagonalOperator,\ + FFTOperator,\ sqrt,\ nifty_configuration -__all__ = ['create_power_operator', 'generate_posterior_sample'] +__all__ = ['create_power_operator', + 'generate_posterior_sample', + 'create_composed_fft_operator'] def create_power_operator(domain, power_spectrum, dtype=None, @@ -110,3 +115,23 @@ def generate_posterior_sample(mean, covariance): mock_m = covariance.inverse_times(mock_j) sample = mock_signal - mock_m + mean return sample + + +def create_composed_fft_operator(domain, codomain=None, all_to='other'): + fft_op_list = [] + space_index_list = [] + + if codomain is None: + codomain = [None]*len(domain) + for i in range(len(domain)): + space = domain[i] + cospace = codomain[i] + if not isinstance(space, Space): + continue + if (all_to == 'other' or + (all_to == 'position' and space.harmonic) or + (all_to == 'harmonic' and not space.harmonic)): + fft_op_list += [FFTOperator(domain=space, target=cospace)] + space_index_list += [i] + result = ComposedOperator(fft_op_list, default_spaces=space_index_list) + return result