Commit a4f5f8f0 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

small enhancements

parent 453fbd4d
Pipeline #26667 failed with stages
in 4 minutes and 50 seconds
......@@ -170,6 +170,11 @@ class Field(object):
"""
return Field(domain, dobj.from_global_data(arr))
@staticmethod
def from_local_data(domain, arr):
domain = DomainTuple.make(domain)
return Field(domain, dobj.from_local_data(domain.shape, arr))
def to_global_data(self):
"""Returns an array containing the full data of the field.
......@@ -802,6 +807,24 @@ class Field(object):
def __ipow__(self, other):
return self._binary_helper(other, op='__ipow__')
def __lt__(self, other):
return self._binary_helper(other, op='__lt__')
def __le__(self, other):
return self._binary_helper(other, op='__le__')
def __ne__(self, other):
return self._binary_helper(other, op='__ne__')
def __eq__(self, other):
return self._binary_helper(other, op='__eq__')
def __ge__(self, other):
return self._binary_helper(other, op='__ge__')
def __gt__(self, other):
return self._binary_helper(other, op='__gt__')
def __repr__(self):
return "<nifty4.Field>"
......
......@@ -40,4 +40,4 @@ def WienerFilterCurvature(R, N, S, inverter):
The minimizer to use during numerical inversion
"""
op = SandwichOperator(R, N.inverse) + S.inverse
return InversionEnabler(op, inverter, S.times)
return InversionEnabler(op, inverter, S)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from .linear_operator import LinearOperator
class AdjointOperator(LinearOperator):
"""Adapter class representing the adjoint of a given operator."""
def __init__(self, op):
super(AdjointOperator, self).__init__()
self._op = op
@property
def domain(self):
return self._op.target
@property
def target(self):
return self._op.domain
@property
def capability(self):
return self._adjointCapability[self._op.capability]
@property
def adjoint(self):
return self._op
def apply(self, x, mode):
return self._op.apply(x, self._adjointMode[mode])
......@@ -20,34 +20,51 @@ from .linear_operator import LinearOperator
import numpy as np
class InverseOperator(LinearOperator):
"""Adapter class representing the inverse of a given operator."""
class OperatorAdapter(LinearOperator):
"""Class representing the inverse and/or adjoint of another operator."""
def __init__(self, op):
super(InverseOperator, self).__init__()
def __init__(self, op, mode):
super(OperatorAdapter, self).__init__()
self._op = op
self._mode = int(mode)
if self._mode not in (2, 4, 8):
raise ValueError("invalid mode")
@property
def domain(self):
return self._op.target
return self._op._dom(self._mode)
@property
def target(self):
return self._op.domain
return self._op._tgt(self._mode)
@property
def capability(self):
return self._inverseCapability[self._op.capability]
return self._capTable[self._mode][self._op.capability]
@property
def inverse(self):
return self._op
md = 6 if self._mode == 8 else self._mode
newmode = md ^ self.INVERSE_TIMES
newmode = 8 if md == 6 else md
return self._op if newmode == 0 else OperatorAdapter(self._op, newmode)
@property
def adjoint(self):
md = 6 if self._mode == 8 else self._mode
newmode = md ^ self.ADJOINT_TIMES
newmode = 8 if md == 6 else md
return self._op if newmode == 0 else OperatorAdapter(self._op, newmode)
def apply(self, x, mode):
return self._op.apply(x, self._inverseMode[mode])
return self._op.apply(x, self._modeTable[self._mode][mode])
def draw_sample(self, dtype=np.float64):
return self._op.inverse_draw_sample(dtype)
if self._mode & self.INVERSE_TIMES:
return self._op.inverse_draw_sample(dtype)
return self._op.draw_sample(dtype)
def inverse_draw_sample(self, dtype=np.float64):
return self._op.draw_sample(dtype)
if self._mode & self.INVERSE_TIMES:
return self._op.draw_sample(dtype)
return self._op.inverse_draw_sample(dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment