Commit f8796588 authored by Julian Ruestig's avatar Julian Ruestig 📡

Merge branch 'opsumdomains' of https://gitlab.mpcdf.mpg.de/ift/nifty into mf_plus_add

parents d1301f82 3d353eca
Pipeline #45253 passed with stages
in 9 minutes and 4 seconds
......@@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y \
# Testing dependencies
python3-pytest-cov jupyter \
# Optional NIFTy dependencies
libfftw3-dev python3-mpi4py python3-matplotlib \
libfftw3-dev python3-mpi4py python3-matplotlib python3-pynfft \
# more optional NIFTy dependencies
&& pip3 install pyfftw \
&& pip3 install git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git \
......
......@@ -86,6 +86,7 @@ from .library.correlated_fields import (CorrelatedField, MfCorrelatedField,
MfPartiallyCorrelatedField)
from .library.adjust_variances import (make_adjust_variances_hamiltonian,
do_adjust_variances)
from .library.nfft import NFFT
from . import extra
......
......@@ -235,7 +235,6 @@ class data_object(object):
for op in ["__add__", "__radd__", "__iadd__",
"__sub__", "__rsub__", "__isub__",
"__mul__", "__rmul__", "__imul__",
"__div__", "__rdiv__", "__idiv__",
"__truediv__", "__rtruediv__", "__itruediv__",
"__floordiv__", "__rfloordiv__", "__ifloordiv__",
"__pow__", "__rpow__", "__ipow__",
......
......@@ -626,6 +626,11 @@ class Field(object):
raise ValueError("domain mismatch")
return self
def extract_part(self, dom):
if dom != self._domain:
raise ValueError("domain mismatch")
return self
def unite(self, other):
return self+other
......@@ -658,7 +663,6 @@ class Field(object):
for op in ["__add__", "__radd__",
"__sub__", "__rsub__",
"__mul__", "__rmul__",
"__div__", "__rdiv__",
"__truediv__", "__rtruediv__",
"__floordiv__", "__rfloordiv__",
"__pow__", "__rpow__",
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2018-2019 Max-Planck-Society
#
# Resolve is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
import nifty5 as ift
class NFFT(ift.LinearOperator):
"""Performs a non-equidistant Fourier transform, i.e. a Fourier transform
followed by a degridding operation.
Parameters
----------
domain : RGSpace
Domain of the operator. It has to be two-dimensional and have shape
`(2N, 2N)`. The coordinates of the lower left pixel of the dirty image
are `(-N,-N)`, and of the upper right pixel `(N-1,N-1)`.
uv : numpy.ndarray
2D numpy array of type float64 and shape (M,2), where M is the number
of measurements. uv[i,0] and uv[i,1] contain the u and v coordinates
of measurement #i, respectively. All coordinates must lie in the range
`[-0.5; 0,5[`.
"""
def __init__(self, domain, uv):
from pynfft.nfft import NFFT
npix = domain.shape[0]
assert npix == domain.shape[1]
assert len(domain.shape) == 2
assert type(npix) == int, "npix must be integer"
assert npix > 0 and (
npix % 2) == 0, "npix must be an even, positive integer"
assert isinstance(uv, np.ndarray), "uv must be a Numpy array"
assert uv.dtype == np.float64, "uv must be an array of float64"
assert uv.ndim == 2, "uv must be a 2D array"
assert uv.shape[0] > 0, "at least one point needed"
assert uv.shape[1] == 2, "the second dimension of uv must be 2"
assert np.all(uv >= -0.5) and np.all(uv <= 0.5),\
"all coordinates must lie between -0.5 and 0.5"
self._domain = ift.DomainTuple.make(domain)
self._target = ift.DomainTuple.make(
ift.UnstructuredDomain(uv.shape[0]))
self._capability = self.TIMES | self.ADJOINT_TIMES
self.npt = uv.shape[0]
self.plan = NFFT(self.domain.shape, self.npt, m=6)
self.plan.x = uv
self.plan.precompute()
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
self.plan.f_hat = x.to_global_data()
res = self.plan.trafo().copy()
else:
self.plan.f = x.to_global_data()
res = self.plan.adjoint().copy()
return ift.Field.from_global_data(self._tgt(mode), res)
......@@ -217,6 +217,12 @@ class MultiField(object):
return MultiField(subset,
tuple(self[key] for key in subset.keys()))
def extract_part(self, subset):
if subset is self._domain:
return self
return MultiField.from_dict({key: self[key] for key in subset.keys()
if key in self})
def unite(self, other):
"""Merges two MultiFields on potentially different MultiDomains.
......@@ -311,7 +317,6 @@ class MultiField(object):
for op in ["__add__", "__radd__",
"__sub__", "__rsub__",
"__mul__", "__rmul__",
"__div__", "__rdiv__",
"__truediv__", "__rtruediv__",
"__floordiv__", "__rfloordiv__",
"__pow__", "__rpow__",
......
......@@ -138,6 +138,17 @@ class ChainOperator(LinearOperator):
subs = "\n".join(sub.__repr__() for sub in self._ops)
return "ChainOperator:\n" + utilities.indent(subs)
def _simplify_for_constant_input_nontrivial(self, c_inp):
from ..multi_domain import MultiDomain
if not isinstance(self._domain, MultiDomain):
return None, self
newop = None
for op in reversed(self._ops):
c_inp, t_op = op.simplify_for_constant_input(c_inp)
newop = t_op if newop is None else op(newop)
return c_inp, newop
# def draw_sample(self, from_inverse=False, dtype=np.float64):
# from ..sugar import from_random
# if len(self._ops) == 1:
......
......@@ -146,6 +146,17 @@ class Operator(metaclass=NiftyMeta):
def __repr__(self):
return self.__class__.__name__
def simplify_for_constant_input(self, c_inp):
if c_inp is None:
return None, self
if c_inp.domain == self.domain:
op = _ConstantOperator(self.domain, self(c_inp))
return op(c_inp), op
return self._simplify_for_constant_input_nontrivial(c_inp)
def _simplify_for_constant_input_nontrivial(self, c_inp):
return None, self
for f in ["sqrt", "exp", "log", "tanh", "sigmoid", 'sin', 'cos', 'tan',
'sinh', 'cosh', 'absolute', 'sinc', 'one_over']:
......@@ -157,6 +168,72 @@ for f in ["sqrt", "exp", "log", "tanh", "sigmoid", 'sin', 'cos', 'tan',
setattr(Operator, f, func(f))
class _ConstCollector(object):
def __init__(self):
self._const = None
self._nc = set()
def mult(self, const, fulldom):
if const is None:
self._nc |= set(fulldom)
else:
self._nc |= set(fulldom) - set(const)
if self._const is None:
from ..multi_field import MultiField
self._const = MultiField.from_dict(
{key: const[key] for key in const if key not in self._nc})
else:
from ..multi_field import MultiField
self._const = MultiField.from_dict(
{key: self._const[key]*const[key]
for key in const if key not in self._nc})
def add(self, const, fulldom):
if const is None:
self._nc |= set(fulldom.keys())
else:
from ..multi_field import MultiField
self._nc |= set(fulldom.keys()) - set(const.keys())
if self._const is None:
self._const = MultiField.from_dict(
{key: const[key]
for key in const.keys() if key not in self._nc})
else:
self._const = self._const.unite(const)
self._const = MultiField.from_dict(
{key: self._const[key]
for key in self._const if key not in self._nc})
@property
def constfield(self):
return self._const
class _ConstantOperator(Operator):
def __init__(self, dom, output):
from ..sugar import makeDomain
self._domain = makeDomain(dom)
self._target = output.domain
self._output = output
def apply(self, x):
from ..linearization import Linearization
from .simple_linear_operators import NullOperator
from ..domain_tuple import DomainTuple
self._check_input(x)
if not isinstance(x, Linearization):
return self._output
if x.want_metric and self._target is DomainTuple.scalar_domain():
met = NullOperator(self._domain, self._domain)
else:
met = None
return x.new(self._output, NullOperator(self._domain, self._target),
met)
def __repr__(self):
return 'ConstantOperator <- {}'.format(self.domain.keys())
class _FunctionApplier(Operator):
def __init__(self, domain, funcname):
from ..sugar import makeDomain
......@@ -229,6 +306,17 @@ class _OpChain(_CombinedOperator):
x = op(x)
return x
def _simplify_for_constant_input_nontrivial(self, c_inp):
from ..multi_domain import MultiDomain
if not isinstance(self._domain, MultiDomain):
return None, self
newop = None
for op in reversed(self._ops):
c_inp, t_op = op.simplify_for_constant_input(c_inp)
newop = t_op if newop is None else op(newop)
return c_inp, newop
def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in self._ops)
return "_OpChain:\n" + indent(subs)
......@@ -261,6 +349,21 @@ class _OpProd(Operator):
makeOp(lin2._val)(lin1._jac), False)
return lin1.new(lin1._val*lin2._val, op(x.jac))
def _simplify_for_constant_input_nontrivial(self, c_inp):
f1, o1 = self._op1.simplify_for_constant_input(
c_inp.extract_part(self._op1.domain))
f2, o2 = self._op2.simplify_for_constant_input(
c_inp.extract_part(self._op2.domain))
from ..multi_domain import MultiDomain
if not isinstance(self._target, MultiDomain):
return None, _OpProd(o1, o2)
cc = _ConstCollector()
cc.mult(f1, o1.target)
cc.mult(f2, o2.target)
return cc.constfield, _OpProd(o1, o2)
def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in (self._op1, self._op2))
return "_OpProd:\n"+indent(subs)
......@@ -281,7 +384,6 @@ class _OpSum(Operator):
v = x._val if lin else x
v1 = v.extract(self._op1.domain)
v2 = v.extract(self._op2.domain)
res = None
if not lin:
return self._op1(v1).unite(self._op2(v2))
wm = x.want_metric
......@@ -290,9 +392,24 @@ class _OpSum(Operator):
op = lin1._jac._myadd(lin2._jac, False)
res = lin1.new(lin1._val.unite(lin2._val), op(x.jac))
if lin1._metric is not None and lin2._metric is not None:
res = res.add_metric(lin1._metric + lin2._metric)
res = res.add_metric(self._op1(x)._metric + self._op2(x)._metric)
return res
def _simplify_for_constant_input_nontrivial(self, c_inp):
f1, o1 = self._op1.simplify_for_constant_input(
c_inp.extract_part(self._op1.domain))
f2, o2 = self._op2.simplify_for_constant_input(
c_inp.extract_part(self._op2.domain))
from ..multi_domain import MultiDomain
if not isinstance(self._target, MultiDomain):
return None, _OpSum(o1, o2)
cc = _ConstCollector()
cc.add(f1, o1.target)
cc.add(f2, o2.target)
return cc.constfield, _OpSum(o1, o2)
def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in (self._op1, self._op2))
return "_OpSum:\n"+indent(subs)
......@@ -42,7 +42,12 @@ class ScalingOperator(EndomorphicOperator):
only in appropriate ways (e.g. call inverse_times only if `factor` is
nonzero).
This shortcoming will hopefully be fixed in the future.
Along with this behaviour comes the feature that it is possible to draw an
inverse sample from a :class:`ScalingOperator` (which is a zero-field).
This occurs if one draws an inverse sample of a positive definite sum of
two operators each of which are only positive semi-definite. However, it
is unclear whether this beviour does not lead to unwanted effects
somewhere else.
"""
def __init__(self, factor, domain):
......
......@@ -315,3 +315,23 @@ class NullOperator(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
return self._nullfield(self._tgt(mode))
class _PartialExtractor(LinearOperator):
def __init__(self, domain, target):
if not isinstance(domain, MultiDomain):
raise TypeError("MultiDomain expected")
if not isinstance(target, MultiDomain):
raise TypeError("MultiDomain expected")
self._domain = domain
self._target = target
for key in self._target.keys():
if not (self._domain[key] is not self._target[key]):
raise ValueError("domain mismatch")
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return x.extract(self._target)
return MultiField.from_dict({key: x[key] for key in x.domain.keys()})
......@@ -23,6 +23,7 @@ from ..sugar import domain_union
from ..utilities import indent
from .block_diagonal_operator import BlockDiagonalOperator
from .linear_operator import LinearOperator
from .simple_linear_operators import NullOperator
class SumOperator(LinearOperator):
......@@ -59,6 +60,9 @@ class SumOperator(LinearOperator):
negnew += [not n for n in op._neg]
else:
negnew += list(op._neg)
# FIXME: this needs some more work to keep the domain and target unchanged!
# elif isinstance(op, NullOperator):
# pass
else:
opsnew.append(op)
negnew.append(ng)
......@@ -193,6 +197,9 @@ class SumOperator(LinearOperator):
"cannot draw from inverse of this operator")
res = None
for op in self._ops:
from .simple_linear_operators import NullOperator
if isinstance(op, NullOperator):
continue
tmp = op.draw_sample(from_inverse, dtype)
res = tmp if res is None else res.unite(tmp)
return res
......@@ -200,3 +207,29 @@ class SumOperator(LinearOperator):
def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in self._ops)
return "SumOperator:\n"+indent(subs)
def _simplify_for_constant_input_nontrivial(self, c_inp):
f = []
o = []
for op in self._ops:
tf, to = op.simplify_for_constant_input(
c_inp.extract_part(op.domain))
f.append(tf)
o.append(to)
from ..multi_domain import MultiDomain
if not isinstance(self._target, MultiDomain):
fullop = None
for to, n in zip(o, self._neg):
op = to if not n else -to
fullop = op if fullop is None else fullop + op
return None, fullop
from .operator import _ConstCollector
cc = _ConstCollector()
fullop = None
for tf, to, n in zip(f, o, self._neg):
cc.add(tf, to.target)
op = to if not n else -to
fullop = op if fullop is None else fullop + op
return cc.constfield, fullop
......@@ -279,3 +279,10 @@ def testValueInserter(sp, seed):
ind.append(np.random.randint(0, ss-1))
op = ift.ValueInserter(sp, ind)
ift.extra.consistency_check(op)
def testNFFT():
dom = ift.RGSpace(2*(16,))
uv = np.array([[.2, .4], [-.22, .452]])
op = ift.NFFT(dom, uv)
ift.extra.consistency_check(op)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import pytest
from numpy.testing import assert_allclose, assert_equal
import nifty5 as ift
def test_simplification():
from nifty5.operators.operator import _ConstantOperator
f1 = ift.Field.full(ift.RGSpace(10),2.)
op = ift.FFTOperator(f1.domain)
_, op2 = op.simplify_for_constant_input(f1)
assert_equal(isinstance(op2, _ConstantOperator), True)
assert_allclose(op(f1).local_data, op2(f1).local_data)
dom = {"a": ift.RGSpace(10)}
f1 = ift.full(dom,2.)
op = ift.FFTOperator(f1.domain["a"]).ducktape("a")
_, op2 = op.simplify_for_constant_input(f1)
assert_equal(isinstance(op2, _ConstantOperator), True)
assert_allclose(op(f1).local_data, op2(f1).local_data)
dom = {"a": ift.RGSpace(10), "b": ift.RGSpace(5)}
f1 = ift.full(dom,2.)
pdom = {"a": ift.RGSpace(10)}
f2 = ift.full(pdom,2.)
o1 = ift.FFTOperator(f1.domain["a"])
o2 = ift.FFTOperator(f1.domain["b"])
op = (o1.ducktape("a").ducktape_left("a") +
o2.ducktape("b").ducktape_left("b"))
_, op2 = op.simplify_for_constant_input(f2)
assert_equal(isinstance(op2._op1, _ConstantOperator), True)
assert_allclose(op(f1)["a"].local_data, op2(f1)["a"].local_data)
assert_allclose(op(f1)["b"].local_data, op2(f1)["b"].local_data)
lin = ift.Linearization.make_var(ift.MultiField.full(op2.domain, 2.), True)
assert_allclose(op(lin).val["a"].local_data,
op2(lin).val["a"].local_data)
assert_allclose(op(lin).val["b"].local_data,
op2(lin).val["b"].local_data)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment