Commit 85ba4336 authored by Ultima's avatar Ultima
Browse files

Made test_nifty_mpi_data.py PEP8 compliant.

Removed trailing whitespaces in nifty_operators.py
parent 803b053a
## NIFTY (Numerical Information Field Theory) has been developed at the
## Max-Planck-Institute for Astrophysics.
##
## Copyright (C) 2015 Max-Planck-Society
##
## Author: Marco Selig
## Project homepage: <http://www.mpa-garching.mpg.de/ift/nifty/>
##
## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
## See the GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program. If not, see <http://www.gnu.org/licenses/>.
# NIFTY (Numerical Information Field Theory) has been developed at the
# Max-Planck-Institute for Astrophysics.
#
# Copyright (C) 2015 Max-Planck-Society
#
# Author: Marco Selig
# Project homepage: <http://www.mpa-garching.mpg.de/ift/nifty/>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import division
import numpy as np
from nifty.nifty_about import about
from nifty.nifty_core import space, \
point_space, \
nested_space, \
field
point_space, \
nested_space, \
field
from nifty_minimization import conjugate_gradient
from nifty_probing import trace_prober,\
inverse_trace_prober,\
diagonal_prober,\
inverse_diagonal_prober
import nifty.nifty_utilities as utilities
import nifty_simple_math
inverse_trace_prober,\
diagonal_prober,\
inverse_diagonal_prober
import nifty.nifty_utilities as utilities
import nifty_simple_math
##=============================================================================
# =============================================================================
class operator(object):
"""
......@@ -107,8 +107,9 @@ class operator(object):
This is a freeform list of parameters that derivatives of the
operator class can use. Not used in the base operators.
"""
def __init__(self, domain, codomain = None, sym=False, uni=False,
imp=False, target=None, cotarget=None, bare = False,
def __init__(self, domain, codomain=None, sym=False, uni=False,
imp=False, target=None, cotarget=None, bare=False,
para=None):
"""
Sets the attributes for an operator class instance.
......@@ -138,48 +139,46 @@ class operator(object):
-------
None
"""
## Check if the domain is realy a space
# Check if the domain is realy a space
if not isinstance(domain, space):
raise TypeError(about._errors.cstring(
"ERROR: invalid input. domain is not a space."))
self.domain = domain
## Parse codomain
if self.domain.check_codomain(codomain) == True:
# Parse codomain
if self.domain.check_codomain(codomain) == True:
self.codomain = codomain
else:
self.codomain = self.domain.get_codomain()
## Cast the symmetric and unitary input
# Cast the symmetric and unitary input
self.sym = bool(sym)
self.uni = bool(uni)
self.bare = bool(bare)
## If no target is supplied, we assume that the operator is square
## If the operator is symmetric or unitary, we know that the operator
## must be square
# If no target is supplied, we assume that the operator is square
# If the operator is symmetric or unitary, we know that the operator
# must be square
if self.sym == True or self.uni == True:
target = self.domain
if self.sym is True or self.uni is True:
target = self.domain
cotarget = self.codomain
if target is not None:
about.warnings.cprint("WARNING: Ignoring target.")
elif target is None:
target = self.domain
cotarget = self.codomain
elif isinstance(target, space):
self.target = target
## Parse cotarget
if self.target.check_codomain(cotarget) == True:
self.target = target
# Parse cotarget
if self.target.check_codomain(cotarget) == True:
self.codomain = codomain
else:
self.codomain = self.domain.get_codomain()
else:
else:
raise TypeError(about._errors.cstring(
"ERROR: invalid input. Target is not a space."))
"ERROR: invalid input. Target is not a space."))
if self.domain.discrete and self.target.discrete:
self.imp = True
......@@ -188,16 +187,6 @@ class operator(object):
self.para = para
#
# @property
# def val(self):
# return self._val
#
# @val.setter
# def val(self, x):
# self._val = self.domain.cast(x)
def set_val(self, new_val):
"""
Resets the field values.
......@@ -210,12 +199,11 @@ class operator(object):
"""
self.val = self.domain.cast(new_val)
return self.val
def get_val(self):
return self.val
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def nrow(self):
"""
......@@ -256,7 +244,7 @@ class operator(object):
"""
if axis is None:
return np.array([self.nrow(),self.ncol()])
return np.array([self.nrow(), self.ncol()])
elif axis == 0:
return self.nrow()
elif axis == 1:
......@@ -265,9 +253,9 @@ class operator(object):
raise ValueError(about._errors.cstring(
"ERROR: invalid input axis."))
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def set_para(self,newpara):
def set_para(self, newpara):
"""
Sets the parameters and creates the `para` property if it does
not exist
......@@ -284,77 +272,77 @@ class operator(object):
"""
self.para = newpara
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def _multiply(self, x, **kwargs):
## > applies the operator to a given field
def _multiply(self, x, **kwargs):
# > applies the operator to a given field
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'multiply'."))
def _adjoint_multiply(self, x, **kwargs):
## > applies the adjoint operator to a given field
def _adjoint_multiply(self, x, **kwargs):
# > applies the adjoint operator to a given field
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'adjoint_multiply'."))
def _inverse_multiply(self, x, **kwargs):
## > applies the inverse operator to a given field
def _inverse_multiply(self, x, **kwargs):
# > applies the inverse operator to a given field
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'inverse_multiply'."))
def _adjoint_inverse_multiply(self, x, **kwargs):
## > applies the inverse adjoint operator to a given field
def _adjoint_inverse_multiply(self, x, **kwargs):
# > applies the inverse adjoint operator to a given field
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'adjoint_inverse_multiply'."))
def _inverse_adjoint_multiply(self, x, **kwargs):
## > applies the adjoint inverse operator to a given field
def _inverse_adjoint_multiply(self, x, **kwargs):
# > applies the adjoint inverse operator to a given field
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'inverse_adjoint_multiply'."))
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def _briefing(self, x, domain, codomain, inverse): ## > prepares x for `multiply`
## inspect x
def _briefing(self, x, domain, codomain, inverse):
# inspect x
if not isinstance(x, field):
x_ = field(domain, codomain=codomain, val=x)
else:
## check x.domain
# check x.domain
if domain == x.domain:
x_ = x
## transform
# transform
else:
x_ = x.transform(codomain=domain)
## weight if ...
if self.imp == False and domain.discrete == False and inverse == False:
# weight if ...
if self.imp is False and domain.discrete is False and inverse is False:
x_ = x_.weight(power=1)
return x_
def _debriefing(self, x, x_, target, cotarget, inverse):
## > evaluates x and x_ after `multiply`
def _debriefing(self, x, x_, target, cotarget, inverse):
# > evaluates x and x_ after `multiply`
if x_ is None:
return None
else:
## inspect x_
# inspect x_
if not isinstance(x_, field):
x_ = field(target, codomain=cotarget, val=x_)
elif x_.domain != target:
raise ValueError(about._errors.cstring(
"ERROR: invalid output domain."))
## weight if ...
if self.imp == False and target.discrete == False and\
inverse == True:
# weight if ...
if self.imp is False and target.discrete is False and\
inverse is True:
x_ = x_.weight(power=-1)
## inspect x
# inspect x
if isinstance(x, field):
## repair if the originally field was living in the codomain
## of the operators domain
# repair if the originally field was living in the codomain
# of the operators domain
if self.domain == self.target != x.domain:
x_ = x_.transform(codomain=x.domain)
x_ = x_.transform(codomain=x.domain)
if x_.domain == x.domain and (x_.codomain != x.codomain):
x_.set_codomain(new_codomain = x.codomain)
x_.set_codomain(new_codomain=x.codomain)
return x_
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def times(self, x, **kwargs):
"""
......@@ -371,12 +359,12 @@ class operator(object):
Ox : field
Mapped field on the target domain of the operator.
"""
## prepare
# prepare
x_ = self._briefing(x, self.domain, self.codomain, inverse=False)
## apply operator
# apply operator
x_ = self._multiply(x_, **kwargs)
## evaluate
return self._debriefing(x, x_, self.target, self.cotarget,
# evaluate
return self._debriefing(x, x_, self.target, self.cotarget,
inverse=False)
def __call__(self, x, **kwargs):
......@@ -398,19 +386,19 @@ class operator(object):
Mapped field on the domain of the operator.
"""
## check whether self-adjoint
# check whether self-adjoint
if self.sym == True:
return self.times(x, **kwargs)
## check whether unitary
# check whether unitary
if self.uni == True:
return self.inverse_times(x, **kwargs)
## prepare
# prepare
x_ = self._briefing(x, self.target, self.cotarget, inverse=False)
## apply operator
# apply operator
x_ = self._adjoint_multiply(x_, **kwargs)
## evaluate
return self._debriefing(x, x_, self.domain, self.codomain,
# evaluate
return self._debriefing(x, x_, self.domain, self.codomain,
inverse=False)
def inverse_times(self, x, **kwargs):
......@@ -428,16 +416,16 @@ class operator(object):
OIx : field
Mapped field on the target space of the operator.
"""
## check whether self-inverse
# check whether self-inverse
if self.sym == True and self.uni == True:
return self.times(x,**kwargs)
return self.times(x, **kwargs)
## prepare
# prepare
x_ = self._briefing(x, self.target, self.cotarget, inverse=True)
## apply operator
# apply operator
x_ = self._inverse_multiply(x_, **kwargs)
## evaluate
return self._debriefing(x, x_, self.domain, self.codomain,
# evaluate
return self._debriefing(x, x_, self.domain, self.codomain,
inverse=True)
def adjoint_inverse_times(self, x, **kwargs):
......@@ -456,19 +444,19 @@ class operator(object):
Mapped field on the domain of the operator.
"""
## check whether self-adjoint
# check whether self-adjoint
if self.sym == True:
return self.inverse_times(x, **kwargs)
## check whether unitary
# check whether unitary
if self.uni == True:
return self.times(x, **kwargs)
## prepare
# prepare
x_ = self._briefing(x, self.domain, self.codomain, inverse=True)
## apply operator
# apply operator
x_ = self._adjoint_inverse_multiply(x_, **kwargs)
## evaluate
return self._debriefing(x, x_, self.target, self.cotarget,
# evaluate
return self._debriefing(x, x_, self.target, self.cotarget,
inverse=True)
def inverse_adjoint_times(self, x, **kwargs):
......@@ -487,26 +475,24 @@ class operator(object):
Mapped field on the domain of the operator.
"""
## check whether self-adjoint
# check whether self-adjoint
if self.sym == True:
return self.inverse_times(x, **kwargs)
## check whether unitary
# check whether unitary
if self.uni == True:
return self.times(x, **kwargs)
## prepare
# prepare
x_ = self._briefing(x, self.domain, self.codomain, inverse=True)
## apply operator
# apply operator
x_ = self._inverse_adjoint_multiply(x_, **kwargs)
## evaluate
return self._debriefing(x, x_, self.target, self.cotarget,
# evaluate
return self._debriefing(x, x_, self.target, self.cotarget,
inverse=True)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def tr(self, domain=None, codomain=None, random="pm1", nrun=8,
varQ=False, **kwargs):
varQ=False, **kwargs):
"""
Computes the trace of the operator
......@@ -548,19 +534,18 @@ class operator(object):
--------
probing : The class used to perform probing operations
"""
return trace_prober(operator = self,
domain = domain,
codomain = codomain,
random = random,
nrun = nrun,
varQ = varQ,
return trace_prober(operator=self,
domain=domain,
codomain=codomain,
random=random,
nrun=nrun,
varQ=varQ,
**kwargs
)()
def inverse_tr(self, domain=None, codomain=None, random="pm1", nrun=8,
varQ=False, **kwargs):
varQ=False, **kwargs):
"""
Computes the trace of the inverse operator
......@@ -596,19 +581,19 @@ class operator(object):
--------
probing : The class used to perform probing operations
"""
return inverse_trace_prober(operator = self,
domain = domain,
codomain = codomain,
random = random,
nrun = nrun,
varQ = varQ,
**kwargs
)()
return inverse_trace_prober(operator=self,
domain=domain,
codomain=codomain,
random=random,
nrun=nrun,
varQ=varQ,
**kwargs
)()
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def diag(self, domain=None, codomain=None, random="pm1", nrun=8,
varQ=False, bare=False, **kwargs):
varQ=False, bare=False, **kwargs):
"""
Computes the diagonal of the operator via probing.
......@@ -674,23 +659,23 @@ class operator(object):
"""
diag = diagonal_prober(operator = self,
domain = domain,
codomain = codomain,
random = random,
nrun = nrun,
varQ = varQ,
**kwargs
)()
diag = diagonal_prober(operator=self,
domain=domain,
codomain=codomain,
random=random,
nrun=nrun,
varQ=varQ,
**kwargs
)()
if diag is None:
about.warnings.cprint("WARNING: forwarding 'NoneType'.")
return None
if domain is None:
domain = diag.domain
## weight if ...
# weight if ...
if domain.discrete == False and bare == True:
if(isinstance(diag,tuple)): ## diag == (diag,variance)
if(isinstance(diag, tuple)): # diag == (diag,variance)
return (diag[0].weight(power=-1),
diag[1].weight(power=-1))
else:
......@@ -698,8 +683,8 @@ class operator(object):
else:
return diag
def inverse_diag(self, domain=None, codomain=None, random="pm1",
nrun=8, varQ=False, bare=False, **kwargs):
def inverse_diag(self, domain=None, codomain=None, random="pm1",
nrun=8, varQ=False, bare=False, **kwargs):
"""
Computes the diagonal of the inverse operator via probing.
......@@ -766,31 +751,31 @@ class operator(object):
"""
if(domain is None):
domain = self.target
diag = inverse_diagonal_prober(operator = self,
domain = domain,
codomain = codomain,
random = random,
nrun = nrun,
varQ = varQ,
**kwargs
)()
diag = inverse_diagonal_prober(operator=self,
domain=domain,
codomain=codomain,
random=random,
nrun=nrun,
varQ=varQ,
**kwargs
)()
if(diag is None):
# about.warnings.cprint("WARNING: forwarding 'NoneType'.")
# about.warnings.cprint("WARNING: forwarding 'NoneType'.")
return None
if domain is None:
domain = diag.codomain
## weight if ...
# weight if ...
if domain.discrete == False and bare == True:
if(isinstance(diag,tuple)): ## diag == (diag,variance)
if(isinstance(diag, tuple)): # diag == (diag,variance)
return (diag[0].weight(power=-1),
diag[1].weight(power=-1))
else:
return diag.weight(power=-1)
else:
return diag
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def det(self):
"""
Computes the determinant of the operator.
......@@ -842,7 +827,7 @@ class operator(object):
"""
return self.log_det()
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def hat(self, bare=False, domain=None, codomain=None, **kwargs):
"""
......@@ -906,13 +891,13 @@ class operator(object):
"""
if domain is None:
domain = self.domain
diag = self.diag(bare=bare, domain=domain, codomain=codomain,
diag = self.diag(bare=bare, domain=domain, codomain=codomain,
var=False, **kwargs)
if diag is None:
about.warnings.cprint("WARNING: forwarding 'NoneType'.")
return None
return diag
return diag
def inverse_hat(self, bare=False, domain=None, codomain=None, **kwargs):
"""
Translates the inverse operator's diagonal into a field
......@@ -975,13 +960,13 @@ class operator(object):
"""
if domain is None:
domain = self.target
diag = self.inverse_diag(bare=bare, domain=domain, codomain=codomain,
diag = self.inverse_diag(bare=bare, domain=domain, codomain=codomain,
var=False, **kwargs)
if diag is None:
about.warnings.cprint("WARNING: forwarding 'NoneType'.")
return None
return diag
return diag
def hathat(self, domain=None,