Commit 4ced6735 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

switch to simpler probing

parent 65d6f135
Pipeline #24792 passed with stage
in 5 minutes and 58 seconds
import nifty4 as ift
import numpy as np
np.random.seed(42)
class DiagonalProber(ift.DiagonalProberMixin, ift.Prober):
pass
class MultiProber(ift.DiagonalProberMixin, ift.TraceProberMixin, ift.Prober):
pass
np.random.seed(42)
x = ift.RGSpace((8, 8))
f = ift.Field.from_random(domain=x, random_type='normal')
diagOp = ift.DiagonalOperator(f)
diagProber = DiagonalProber(domain=x)
diagProber(diagOp)
ift.dobj.mprint((f - diagProber.diagonal).norm())
multiProber = MultiProber(domain=x)
multiProber(diagOp)
ift.dobj.mprint((f - multiProber.diagonal).norm())
ift.dobj.mprint(f.sum() - multiProber.trace)
diag = ift.probe_diagonal(diagOp, 1000)
ift.dobj.mprint((f - diag).norm())
......@@ -29,9 +29,6 @@ from .operators.inversion_enabler import InversionEnabler
from .field import Field, sqrt, exp, log
from .probing.prober import Prober
from .probing.diagonal_prober_mixin import DiagonalProberMixin
from .probing.trace_prober_mixin import TraceProberMixin
from .probing.utils import probe_with_posterior_samples, probe_diagonal
from .minimization.line_search import LineSearch
......@@ -62,5 +59,4 @@ __all__ = ["Domain", "UnstructuredDomain", "StructuredDomain",
"DiagonalOperator", "FFTOperator", "FFTSmoothingOperator",
"DirectSmoothingOperator", "LaplaceOperator",
"PowerProjectionOperator", "InversionEnabler",
"Field", "sqrt", "exp", "log",
"Prober", "DiagonalProberMixin", "TraceProberMixin"]
"Field", "sqrt", "exp", "log"]
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division
from builtins import object
from ..sugar import create_composed_ht_operator
class DiagonalProberMixin(object):
def __init__(self, *args, **kwargs):
self.reset()
self.__evaluate_probe_in_signal_space = False
super(DiagonalProberMixin, self).__init__(*args, **kwargs)
def reset(self):
self.__sum_of_probings = 0
self.__sum_of_squares = 0
self.__diagonal = None
self.__diagonal_variance = None
super(DiagonalProberMixin, self).reset()
def finish_probe(self, probe, pre_result):
if self.__evaluate_probe_in_signal_space:
ht = create_composed_ht_operator(self._domain)
result = ht(probe[1]).conjugate()*ht(pre_result)
else:
result = probe[1].conjugate()*pre_result
self.__sum_of_probings += result
if self.compute_variance:
self.__sum_of_squares += result.conjugate() * result
super(DiagonalProberMixin, self).finish_probe(probe, pre_result)
@property
def diagonal(self):
if self.__diagonal is None:
self.__diagonal = self.__sum_of_probings/self.probe_count
return self.__diagonal
@property
def diagonal_variance(self):
if not self.compute_variance:
raise AttributeError("self.compute_variance is set to False")
if self.__diagonal_variance is None:
# variance = 1/(n-1) (sum(x^2) - 1/n*sum(x)^2)
n = self.probe_count
sum_pr = self.__sum_of_probings
mean = self.diagonal
sum_sq = self.__sum_of_squares
self.__diagonal_variance = ((sum_sq - sum_pr*mean) / (n-1))
return self.__diagonal_variance
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from builtins import str
from builtins import range
from builtins import object
import numpy as np
from ..field import Field, DomainTuple
class Prober(object):
"""
See the following webpages for the principles behind the usage of
mixin-classes
https://www.python.org/download/releases/2.2.3/descrintro/#cooperation
https://rhettinger.wordpress.com/2011/05/26/super-considered-super/
"""
def __init__(self, domain=None, probe_count=8,
random_type='pm1', probe_dtype=np.float,
compute_variance=False, ncpu=1):
self._domain = DomainTuple.make(domain)
self._probe_count = int(probe_count)
self._ncpu = int(ncpu)
if random_type not in ["pm1", "normal"]:
raise ValueError(
"unsupported random type: '" + str(random_type) + "'.")
self._random_type = random_type
self.compute_variance = bool(compute_variance)
self.probe_dtype = np.dtype(probe_dtype)
self._uid_counter = 0
@property
def domain(self):
return self._domain
@property
def probe_count(self):
return self._probe_count
@property
def random_type(self):
return self._random_type
def gen_parallel_probe(self, callee):
for i in range(self.probe_count):
yield (callee, self.get_probe(i))
def probing_run(self, callee):
""" controls the generation, evaluation and finalization of probes """
self.reset()
if self._ncpu == 1:
for index in range(self.probe_count):
current_probe = self.get_probe(index)
pre_result = self.process_probe(callee, current_probe, index)
self.finish_probe(current_probe, pre_result)
else:
from multiprocess import Pool
pool = Pool(self._ncpu)
for i in pool.imap_unordered(self.evaluate_probe_parallel,
self.gen_parallel_probe(callee)):
self.finish_probe(i[0], i[1])
def evaluate_probe_parallel(self, argtuple):
callee = argtuple[0]
probe = argtuple[1]
return (probe, callee(probe[1]))
def reset(self):
pass
def get_probe(self, index):
""" layer of abstraction for potential probe-caching """
return self.generate_probe()
def generate_probe(self):
""" a random-probe generator """
f = Field.from_random(random_type=self.random_type,
domain=self.domain, dtype=self.probe_dtype)
uid = self._uid_counter
self._uid_counter += 1
return (uid, f)
def process_probe(self, callee, probe, index):
""" layer of abstraction for potential result-caching/recycling """
return self.evaluate_probe(callee, probe[1])
def evaluate_probe(self, callee, probe, **kwargs):
""" processes a probe """
return callee(probe, **kwargs)
def finish_probe(self, probe, pre_result):
pass
def __call__(self, callee):
return self.probing_run(callee)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division
from builtins import object
from ..sugar import create_composed_ht_operator
class TraceProberMixin(object):
def __init__(self, *args, **kwargs):
self.reset()
self.__evaluate_probe_in_signal_space = False
super(TraceProberMixin, self).__init__(*args, **kwargs)
def reset(self):
self.__sum_of_probings = 0
self.__sum_of_squares = 0
self.__trace = None
self.__trace_variance = None
super(TraceProberMixin, self).reset()
def finish_probe(self, probe, pre_result):
if self.__evaluate_probe_in_signal_space:
ht = create_composed_ht_operator(self._domain)
result = ht(probe[1]).vdot(ht(pre_result))
else:
result = probe[1].vdot(pre_result)
self.__sum_of_probings += result
if self.compute_variance:
self.__sum_of_squares += result.conjugate() * result
super(TraceProberMixin, self).finish_probe(probe, pre_result)
@property
def trace(self):
if self.__trace is None:
self.__trace = self.__sum_of_probings/self.probe_count
return self.__trace
@property
def trace_variance(self):
if not self.compute_variance:
raise AttributeError("self.compute_variance is set to False")
if self.__trace_variance is None:
# variance = 1/(n-1) (sum(x^2) - 1/n*sum(x)^2)
n = self.probe_count
sum_pr = self.__sum_of_probings
mean = self.trace
sum_sq = self.__sum_of_squares
self.__trace_variance = (sum_sq - sum_pr*mean) / (n-1)
return self.__trace_variance
......@@ -58,7 +58,7 @@ def probe_with_posterior_samples(op, post_op, nprobes):
return sc.mean, sc.var
def probe_diagonal(op, nprobes, random_type="normal"):
def probe_diagonal(op, nprobes, random_type="pm1"):
sc = StatCalculator()
for i in range(nprobes):
input = Field.from_random(random_type, op.domain)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment