diff --git a/demos/probing.py b/demos/probing.py index 95732a8a59495dd5fa3250e48dee5ec7eecf2b7d..ae93129e291deb003c9bf035caa87dabc6d14545 100644 --- a/demos/probing.py +++ b/demos/probing.py @@ -1,27 +1,12 @@ import nifty4 as ift import numpy as np -np.random.seed(42) - - -class DiagonalProber(ift.DiagonalProberMixin, ift.Prober): - pass - - -class MultiProber(ift.DiagonalProberMixin, ift.TraceProberMixin, ift.Prober): - pass - +np.random.seed(42) x = ift.RGSpace((8, 8)) f = ift.Field.from_random(domain=x, random_type='normal') diagOp = ift.DiagonalOperator(f) -diagProber = DiagonalProber(domain=x) -diagProber(diagOp) -ift.dobj.mprint((f - diagProber.diagonal).norm()) - -multiProber = MultiProber(domain=x) -multiProber(diagOp) -ift.dobj.mprint((f - multiProber.diagonal).norm()) -ift.dobj.mprint(f.sum() - multiProber.trace) +diag = ift.probe_diagonal(diagOp, 1000) +ift.dobj.mprint((f - diag).norm()) diff --git a/nifty4/__init__.py b/nifty4/__init__.py index e16f03d324edfd8e55c39fe1e6b6c97a49750593..2ef9778b756291784fdee9405024b9d13951e171 100644 --- a/nifty4/__init__.py +++ b/nifty4/__init__.py @@ -29,9 +29,6 @@ from .operators.inversion_enabler import InversionEnabler from .field import Field, sqrt, exp, log -from .probing.prober import Prober -from .probing.diagonal_prober_mixin import DiagonalProberMixin -from .probing.trace_prober_mixin import TraceProberMixin from .probing.utils import probe_with_posterior_samples, probe_diagonal from .minimization.line_search import LineSearch @@ -62,5 +59,4 @@ __all__ = ["Domain", "UnstructuredDomain", "StructuredDomain", "DiagonalOperator", "FFTOperator", "FFTSmoothingOperator", "DirectSmoothingOperator", "LaplaceOperator", "PowerProjectionOperator", "InversionEnabler", - "Field", "sqrt", "exp", "log", - "Prober", "DiagonalProberMixin", "TraceProberMixin"] + "Field", "sqrt", "exp", "log"] diff --git a/nifty4/probing/diagonal_prober_mixin.py b/nifty4/probing/diagonal_prober_mixin.py deleted file mode 100644 index 2ed1c2c0e9a9d83ac2fe26e478756d0532e30525..0000000000000000000000000000000000000000 --- a/nifty4/probing/diagonal_prober_mixin.py +++ /dev/null @@ -1,66 +0,0 @@ -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see <http://www.gnu.org/licenses/>. -# -# Copyright(C) 2013-2017 Max-Planck-Society -# -# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik -# and financially supported by the Studienstiftung des deutschen Volkes. - -from __future__ import division -from builtins import object -from ..sugar import create_composed_ht_operator - - -class DiagonalProberMixin(object): - def __init__(self, *args, **kwargs): - self.reset() - self.__evaluate_probe_in_signal_space = False - super(DiagonalProberMixin, self).__init__(*args, **kwargs) - - def reset(self): - self.__sum_of_probings = 0 - self.__sum_of_squares = 0 - self.__diagonal = None - self.__diagonal_variance = None - super(DiagonalProberMixin, self).reset() - - def finish_probe(self, probe, pre_result): - if self.__evaluate_probe_in_signal_space: - ht = create_composed_ht_operator(self._domain) - result = ht(probe[1]).conjugate()*ht(pre_result) - else: - result = probe[1].conjugate()*pre_result - self.__sum_of_probings += result - if self.compute_variance: - self.__sum_of_squares += result.conjugate() * result - super(DiagonalProberMixin, self).finish_probe(probe, pre_result) - - @property - def diagonal(self): - if self.__diagonal is None: - self.__diagonal = self.__sum_of_probings/self.probe_count - return self.__diagonal - - @property - def diagonal_variance(self): - if not self.compute_variance: - raise AttributeError("self.compute_variance is set to False") - if self.__diagonal_variance is None: - # variance = 1/(n-1) (sum(x^2) - 1/n*sum(x)^2) - n = self.probe_count - sum_pr = self.__sum_of_probings - mean = self.diagonal - sum_sq = self.__sum_of_squares - - self.__diagonal_variance = ((sum_sq - sum_pr*mean) / (n-1)) - return self.__diagonal_variance diff --git a/nifty4/probing/prober.py b/nifty4/probing/prober.py deleted file mode 100644 index 9ad804e6c2f95449e300306296ba0bd2f86e666a..0000000000000000000000000000000000000000 --- a/nifty4/probing/prober.py +++ /dev/null @@ -1,113 +0,0 @@ -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see <http://www.gnu.org/licenses/>. -# -# Copyright(C) 2013-2017 Max-Planck-Society -# -# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik -# and financially supported by the Studienstiftung des deutschen Volkes. - -from builtins import str -from builtins import range -from builtins import object -import numpy as np -from ..field import Field, DomainTuple - - -class Prober(object): - """ - See the following webpages for the principles behind the usage of - mixin-classes - - https://www.python.org/download/releases/2.2.3/descrintro/#cooperation - https://rhettinger.wordpress.com/2011/05/26/super-considered-super/ - """ - - def __init__(self, domain=None, probe_count=8, - random_type='pm1', probe_dtype=np.float, - compute_variance=False, ncpu=1): - - self._domain = DomainTuple.make(domain) - self._probe_count = int(probe_count) - self._ncpu = int(ncpu) - if random_type not in ["pm1", "normal"]: - raise ValueError( - "unsupported random type: '" + str(random_type) + "'.") - self._random_type = random_type - self.compute_variance = bool(compute_variance) - self.probe_dtype = np.dtype(probe_dtype) - self._uid_counter = 0 - - @property - def domain(self): - return self._domain - - @property - def probe_count(self): - return self._probe_count - - @property - def random_type(self): - return self._random_type - - def gen_parallel_probe(self, callee): - for i in range(self.probe_count): - yield (callee, self.get_probe(i)) - - def probing_run(self, callee): - """ controls the generation, evaluation and finalization of probes """ - self.reset() - if self._ncpu == 1: - for index in range(self.probe_count): - current_probe = self.get_probe(index) - pre_result = self.process_probe(callee, current_probe, index) - self.finish_probe(current_probe, pre_result) - else: - from multiprocess import Pool - pool = Pool(self._ncpu) - for i in pool.imap_unordered(self.evaluate_probe_parallel, - self.gen_parallel_probe(callee)): - self.finish_probe(i[0], i[1]) - - def evaluate_probe_parallel(self, argtuple): - callee = argtuple[0] - probe = argtuple[1] - return (probe, callee(probe[1])) - - def reset(self): - pass - - def get_probe(self, index): - """ layer of abstraction for potential probe-caching """ - return self.generate_probe() - - def generate_probe(self): - """ a random-probe generator """ - f = Field.from_random(random_type=self.random_type, - domain=self.domain, dtype=self.probe_dtype) - uid = self._uid_counter - self._uid_counter += 1 - return (uid, f) - - def process_probe(self, callee, probe, index): - """ layer of abstraction for potential result-caching/recycling """ - return self.evaluate_probe(callee, probe[1]) - - def evaluate_probe(self, callee, probe, **kwargs): - """ processes a probe """ - return callee(probe, **kwargs) - - def finish_probe(self, probe, pre_result): - pass - - def __call__(self, callee): - return self.probing_run(callee) diff --git a/nifty4/probing/trace_prober_mixin.py b/nifty4/probing/trace_prober_mixin.py deleted file mode 100644 index 2bc3c955dae629edb848091bd144640f681ac2e2..0000000000000000000000000000000000000000 --- a/nifty4/probing/trace_prober_mixin.py +++ /dev/null @@ -1,67 +0,0 @@ -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see <http://www.gnu.org/licenses/>. -# -# Copyright(C) 2013-2017 Max-Planck-Society -# -# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik -# and financially supported by the Studienstiftung des deutschen Volkes. - -from __future__ import division -from builtins import object -from ..sugar import create_composed_ht_operator - - -class TraceProberMixin(object): - def __init__(self, *args, **kwargs): - self.reset() - self.__evaluate_probe_in_signal_space = False - super(TraceProberMixin, self).__init__(*args, **kwargs) - - def reset(self): - self.__sum_of_probings = 0 - self.__sum_of_squares = 0 - self.__trace = None - self.__trace_variance = None - super(TraceProberMixin, self).reset() - - def finish_probe(self, probe, pre_result): - if self.__evaluate_probe_in_signal_space: - ht = create_composed_ht_operator(self._domain) - result = ht(probe[1]).vdot(ht(pre_result)) - else: - result = probe[1].vdot(pre_result) - - self.__sum_of_probings += result - if self.compute_variance: - self.__sum_of_squares += result.conjugate() * result - super(TraceProberMixin, self).finish_probe(probe, pre_result) - - @property - def trace(self): - if self.__trace is None: - self.__trace = self.__sum_of_probings/self.probe_count - return self.__trace - - @property - def trace_variance(self): - if not self.compute_variance: - raise AttributeError("self.compute_variance is set to False") - if self.__trace_variance is None: - # variance = 1/(n-1) (sum(x^2) - 1/n*sum(x)^2) - n = self.probe_count - sum_pr = self.__sum_of_probings - mean = self.trace - sum_sq = self.__sum_of_squares - - self.__trace_variance = (sum_sq - sum_pr*mean) / (n-1) - return self.__trace_variance diff --git a/nifty4/probing/utils.py b/nifty4/probing/utils.py index 7ed73667ff6226e15620d1c412555abbfb26af90..0eaa712078451039238c64588e8335d2e2c52314 100644 --- a/nifty4/probing/utils.py +++ b/nifty4/probing/utils.py @@ -58,7 +58,7 @@ def probe_with_posterior_samples(op, post_op, nprobes): return sc.mean, sc.var -def probe_diagonal(op, nprobes, random_type="normal"): +def probe_diagonal(op, nprobes, random_type="pm1"): sc = StatCalculator() for i in range(nprobes): input = Field.from_random(random_type, op.domain)