Commit a4f5f8f0 authored by Martin Reinecke's avatar Martin Reinecke

small enhancements

parent 453fbd4d
Pipeline #26667 failed with stages
in 4 minutes and 50 seconds
...@@ -170,6 +170,11 @@ class Field(object): ...@@ -170,6 +170,11 @@ class Field(object):
""" """
return Field(domain, dobj.from_global_data(arr)) return Field(domain, dobj.from_global_data(arr))
@staticmethod
def from_local_data(domain, arr):
domain = DomainTuple.make(domain)
return Field(domain, dobj.from_local_data(domain.shape, arr))
def to_global_data(self): def to_global_data(self):
"""Returns an array containing the full data of the field. """Returns an array containing the full data of the field.
...@@ -802,6 +807,24 @@ class Field(object): ...@@ -802,6 +807,24 @@ class Field(object):
def __ipow__(self, other): def __ipow__(self, other):
return self._binary_helper(other, op='__ipow__') return self._binary_helper(other, op='__ipow__')
def __lt__(self, other):
return self._binary_helper(other, op='__lt__')
def __le__(self, other):
return self._binary_helper(other, op='__le__')
def __ne__(self, other):
return self._binary_helper(other, op='__ne__')
def __eq__(self, other):
return self._binary_helper(other, op='__eq__')
def __ge__(self, other):
return self._binary_helper(other, op='__ge__')
def __gt__(self, other):
return self._binary_helper(other, op='__gt__')
def __repr__(self): def __repr__(self):
return "<nifty4.Field>" return "<nifty4.Field>"
......
...@@ -40,4 +40,4 @@ def WienerFilterCurvature(R, N, S, inverter): ...@@ -40,4 +40,4 @@ def WienerFilterCurvature(R, N, S, inverter):
The minimizer to use during numerical inversion The minimizer to use during numerical inversion
""" """
op = SandwichOperator(R, N.inverse) + S.inverse op = SandwichOperator(R, N.inverse) + S.inverse
return InversionEnabler(op, inverter, S.times) return InversionEnabler(op, inverter, S)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from .linear_operator import LinearOperator
class AdjointOperator(LinearOperator):
"""Adapter class representing the adjoint of a given operator."""
def __init__(self, op):
super(AdjointOperator, self).__init__()
self._op = op
@property
def domain(self):
return self._op.target
@property
def target(self):
return self._op.domain
@property
def capability(self):
return self._adjointCapability[self._op.capability]
@property
def adjoint(self):
return self._op
def apply(self, x, mode):
return self._op.apply(x, self._adjointMode[mode])
...@@ -20,34 +20,51 @@ from .linear_operator import LinearOperator ...@@ -20,34 +20,51 @@ from .linear_operator import LinearOperator
import numpy as np import numpy as np
class InverseOperator(LinearOperator): class OperatorAdapter(LinearOperator):
"""Adapter class representing the inverse of a given operator.""" """Class representing the inverse and/or adjoint of another operator."""
def __init__(self, op): def __init__(self, op, mode):
super(InverseOperator, self).__init__() super(OperatorAdapter, self).__init__()
self._op = op self._op = op
self._mode = int(mode)
if self._mode not in (2, 4, 8):
raise ValueError("invalid mode")
@property @property
def domain(self): def domain(self):
return self._op.target return self._op._dom(self._mode)
@property @property
def target(self): def target(self):
return self._op.domain return self._op._tgt(self._mode)
@property @property
def capability(self): def capability(self):
return self._inverseCapability[self._op.capability] return self._capTable[self._mode][self._op.capability]
@property @property
def inverse(self): def inverse(self):
return self._op md = 6 if self._mode == 8 else self._mode
newmode = md ^ self.INVERSE_TIMES
newmode = 8 if md == 6 else md
return self._op if newmode == 0 else OperatorAdapter(self._op, newmode)
@property
def adjoint(self):
md = 6 if self._mode == 8 else self._mode
newmode = md ^ self.ADJOINT_TIMES
newmode = 8 if md == 6 else md
return self._op if newmode == 0 else OperatorAdapter(self._op, newmode)
def apply(self, x, mode): def apply(self, x, mode):
return self._op.apply(x, self._inverseMode[mode]) return self._op.apply(x, self._modeTable[self._mode][mode])
def draw_sample(self, dtype=np.float64): def draw_sample(self, dtype=np.float64):
return self._op.inverse_draw_sample(dtype) if self._mode & self.INVERSE_TIMES:
return self._op.inverse_draw_sample(dtype)
return self._op.draw_sample(dtype)
def inverse_draw_sample(self, dtype=np.float64): def inverse_draw_sample(self, dtype=np.float64):
return self._op.draw_sample(dtype) if self._mode & self.INVERSE_TIMES:
return self._op.draw_sample(dtype)
return self._op.inverse_draw_sample(dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment