Commit a0d4bb5b authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'NIFTy_4' into new_los

parents 1328c826 484ddb59
Pipeline #28496 passed with stages
in 13 minutes and 35 seconds
image: docker:stable
image: $CONTAINER_TEST_IMAGE
variables:
CONTAINER_TEST_IMAGE: gitlab-registry.mpcdf.mpg.de/ift/nifty:$CI_BUILD_REF_NAME
......@@ -8,7 +8,20 @@ stages:
- test
- release
build_docker:
build_docker_from_scratch:
only:
- schedules
image: docker:stable
stage: build_docker
script:
- docker login -u gitlab-ci-token -p $CI_BUILD_TOKEN gitlab-registry.mpcdf.mpg.de
- docker build -t $CONTAINER_TEST_IMAGE --no-cache .
- docker push $CONTAINER_TEST_IMAGE
build_docker_from_cache:
except:
- schedules
image: docker:stable
stage: build_docker
script:
- docker login -u gitlab-ci-token -p $CI_BUILD_TOKEN gitlab-registry.mpcdf.mpg.de
......@@ -16,7 +29,6 @@ build_docker:
- docker push $CONTAINER_TEST_IMAGE
test_python2_scalar:
image: $CONTAINER_TEST_IMAGE
stage: test
script:
- python setup.py install --user -f
......@@ -25,28 +37,24 @@ test_python2_scalar:
coverage report | grep TOTAL | awk '{ print "TOTAL: "$6; }'
test_python3_scalar:
image: $CONTAINER_TEST_IMAGE
stage: test
script:
- python3 setup.py install --user -f
- nosetests3 -q
test_python2_mpi:
image: $CONTAINER_TEST_IMAGE
stage: test
script:
- python setup.py install --user -f
- OMP_NUM_THREADS=1 mpiexec -n 2 nosetests -q 2> /dev/null
test_python3_mpi:
image: $CONTAINER_TEST_IMAGE
stage: test
script:
- python3 setup.py install --user -f
- OMP_NUM_THREADS=1 mpiexec -n 2 nosetests3 -q 2> /dev/null
pages:
image: $CONTAINER_TEST_IMAGE
stage: release
script:
- python setup.py install --user -f
......
......@@ -24,6 +24,9 @@ from .utilities import memo
from .logger import logger
from .multi import *
__all__ = ["__version__", "dobj", "DomainTuple"] + \
domains.__all__ + operators.__all__ + minimization.__all__ + \
["DomainTuple", "Field", "sqrt", "exp", "log"]
["DomainTuple", "Field", "sqrt", "exp", "log"] + \
multi.__all__
......@@ -66,6 +66,8 @@ class DomainTuple(object):
"""
if isinstance(domain, DomainTuple):
return domain
if isinstance(domain, dict):
return domain
domain = DomainTuple._parse_domain(domain)
obj = DomainTuple._tupleCache.get(domain)
if obj is not None:
......
......@@ -746,20 +746,6 @@ class Field(object):
raise ValueError("domains are incompatible.")
self.local_data[()] = other.local_data[()]
def _binary_helper(self, other, op):
# if other is a field, make sure that the domains match
if isinstance(other, Field):
if other._domain != self._domain:
raise ValueError("domains are incompatible.")
tval = getattr(self.val, op)(other.val)
return self if tval is self.val else Field(self._domain, tval)
if np.isscalar(other) or isinstance(other, dobj.data_object):
tval = getattr(self.val, op)(other)
return self if tval is self.val else Field(self._domain, tval)
return NotImplemented
def __repr__(self):
return "<nifty4.Field>"
......@@ -778,30 +764,38 @@ for op in ["__add__", "__radd__", "__iadd__",
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]:
def func(op):
def func2(self, other):
return self._binary_helper(other, op=op)
# if other is a field, make sure that the domains match
if isinstance(other, Field):
if other._domain != self._domain:
raise ValueError("domains are incompatible.")
tval = getattr(self.val, op)(other.val)
return self if tval is self.val else Field(self._domain, tval)
if np.isscalar(other) or isinstance(other, dobj.data_object):
tval = getattr(self.val, op)(other)
return self if tval is self.val else Field(self._domain, tval)
return NotImplemented
return func2
setattr(Field, op, func(op))
# Arithmetic functions working on Fields
def _math_helper(x, function, out):
function = getattr(dobj, function)
if not isinstance(x, Field):
raise TypeError("This function only accepts Field objects.")
if out is not None:
if not isinstance(out, Field) or x._domain != out._domain:
raise ValueError("Bad 'out' argument")
function(x.val, out=out.val)
return out
else:
return Field(domain=x._domain, val=function(x.val))
_current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
def func(f):
def func2(x, out=None):
return _math_helper(x, f, out)
fu = getattr(dobj, f)
if not isinstance(x, Field):
raise TypeError("This function only accepts Field objects.")
if out is not None:
if not isinstance(out, Field) or x._domain != out._domain:
raise ValueError("Bad 'out' argument")
fu(x.val, out=out.val)
return out
else:
return Field(domain=x._domain, val=fu(x.val))
return func2
setattr(_current_module, f, func(f))
from .multi_domain import MultiDomain
from .multi_field import MultiField
__all__ = ["MultiDomain", "MultiField"]
class MultiDomain(dict):
pass
from ..field import Field
import numpy as np
from .multi_domain import MultiDomain
class MultiField(object):
def __init__(self, val):
"""
Parameters
----------
val : dict
"""
self._val = val
def __getitem__(self, key):
return self._val[key]
def keys(self):
return self._val.keys()
def items(self):
return self._val.items()
def values(self):
return self._val.values()
@property
def domain(self):
return MultiDomain({key: val.domain for key, val in self._val.items()})
@property
def dtype(self):
return {key: val.dtype for key, val in self._val.items()}
def _check_domain(self, other):
if other.domain != self.domain:
raise ValueError("domains are incompatible.")
def vdot(self, x):
result = 0.
self._check_domain(x)
for key, sub_field in self.items():
result += sub_field.vdot(x[key])
return result
def lock(self):
for v in self.values():
v.lock()
return self
def copy(self):
return MultiField({key: val.copy() for key, val in self.items()})
@staticmethod
def build_dtype(dtype, domain):
if isinstance(dtype, dict):
return dtype
if dtype is None:
dtype = np.float64
return {key: dtype for key in domain.keys()}
@staticmethod
def zeros(domain, dtype=None):
dtype = MultiField.build_dtype(dtype, domain)
return MultiField({key: Field.zeros(dom, dtype=dtype[key])
for key, dom in domain.items()})
@staticmethod
def ones(domain, dtype=None):
dtype = MultiField.build_dtype(dtype, domain)
return MultiField({key: Field.ones(dom, dtype=dtype[key])
for key, dom in domain.items()})
@staticmethod
def empty(domain, dtype=None):
dtype = MultiField.build_dtype(dtype, domain)
return MultiField({key: Field.empty(dom, dtype=dtype[key])
for key, dom in domain.items()})
def norm(self):
""" Computes the L2-norm of the field values.
Returns
-------
norm : float
The L2-norm of the field values.
"""
return np.sqrt(np.abs(self.vdot(x=self)))
def __neg__(self):
return MultiField({key: -val for key, val in self.items()})
def conjugate(self):
return MultiField({key: sub_field.conjugate()
for key, sub_field in self.items()})
for op in ["__add__", "__radd__", "__iadd__",
"__sub__", "__rsub__", "__isub__",
"__mul__", "__rmul__", "__imul__",
"__div__", "__rdiv__", "__idiv__",
"__truediv__", "__rtruediv__", "__itruediv__",
"__floordiv__", "__rfloordiv__", "__ifloordiv__",
"__pow__", "__rpow__", "__ipow__",
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]:
def func(op):
def func2(self, other):
if isinstance(other, MultiField):
self._check_domain(other)
result_val = {key: getattr(sub_field, op)(other[key])
for key, sub_field in self.items()}
else:
result_val = {key: getattr(val, op)(other)
for key, val in self.items()}
return MultiField(result_val)
return func2
setattr(MultiField, op, func(op))
......@@ -18,7 +18,6 @@
from ..minimization.quadratic_energy import QuadraticEnergy
from ..minimization.iteration_controller import IterationController
from ..field import Field
from ..logger import logger
from .endomorphic_operator import EndomorphicOperator
import numpy as np
......@@ -68,7 +67,7 @@ class InversionEnabler(EndomorphicOperator):
if self._op.capability & mode:
return self._op.apply(x, mode)
x0 = Field.zeros(self._tgt(mode), dtype=x.dtype)
x0 = x*0.
invmode = self._modeTable[self.INVERSE_BIT][self._ilog[mode]]
invop = self._op._flip_modes(self._ilog[invmode])
prec = self._approximation
......@@ -83,6 +82,6 @@ class InversionEnabler(EndomorphicOperator):
def draw_sample(self, from_inverse=False, dtype=np.float64):
try:
return self._op.draw_sample(from_inverse, dtype)
except:
except NotImplementedError:
samp = self._op.draw_sample(not from_inverse, dtype)
return self.inverse_times(samp) if from_inverse else self(samp)
......@@ -271,8 +271,9 @@ class LinearOperator(NiftyMetaBase()):
raise ValueError("requested operator mode is not supported")
def _check_input(self, x, mode):
if not isinstance(x, Field):
raise ValueError("supplied object is not a `Field`.")
# MR FIXME: temporary fix for working with MultiFields
#if not isinstance(x, Field):
# raise ValueError("supplied object is not a `Field`.")
self._check_mode(mode)
if x.domain != self._dom(mode):
......
......@@ -50,6 +50,7 @@ class SandwichOperator(EndomorphicOperator):
def draw_sample(self, from_inverse=False, dtype=np.float64):
if from_inverse:
raise ValueError("cannot draw from inverse of this operator")
raise NotImplementedError(
"cannot draw from inverse of this operator")
return self._bun.adjoint_times(
self._cheese.draw_sample(from_inverse, dtype))
......@@ -145,7 +145,8 @@ class SumOperator(LinearOperator):
def draw_sample(self, from_inverse=False, dtype=np.float64):
if from_inverse:
raise ValueError("cannot draw from inverse of this operator")
raise NotImplementedError(
"cannot draw from inverse of this operator")
res = self._ops[0].draw_sample(from_inverse, dtype)
for op in self._ops[1:]:
res += op.draw_sample(from_inverse, dtype)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment