Commit eac89700 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'fix_linear_interpolation_boundary' into 'NIFTy_5'

Fix linear interpolation boundary

See merge request !289
parents 2198d139 f65415e7
Pipeline #43342 passed with stages
in 8 minutes and 53 seconds
...@@ -109,7 +109,7 @@ class Linearization(object): ...@@ -109,7 +109,7 @@ class Linearization(object):
def __getitem__(self, name): def __getitem__(self, name):
from .operators.simple_linear_operators import ducktape from .operators.simple_linear_operators import ducktape
return self.new(self._val[name], ducktape(None, self.domain, name)) return self.new(self._val[name], self._jac.ducktape_left(name))
def __neg__(self): def __neg__(self):
return self.new(-self._val, -self._jac, return self.new(-self._val, -self._jac,
......
...@@ -328,8 +328,8 @@ class StandardHamiltonian(EnergyOperator): ...@@ -328,8 +328,8 @@ class StandardHamiltonian(EnergyOperator):
return self._lh(x) + self._prior(x) return self._lh(x) + self._prior(x)
else: else:
lhx, prx = self._lh(x), self._prior(x) lhx, prx = self._lh(x), self._prior(x)
mtr = SamplingEnabler(lhx.metric, prx.metric.inverse, mtr = SamplingEnabler(lhx.metric, prx.metric,
self._ic_samp, prx.metric.inverse) self._ic_samp)
return (lhx + prx).add_metric(mtr) return (lhx + prx).add_metric(mtr)
def __repr__(self): def __repr__(self):
......
...@@ -75,8 +75,8 @@ class LinearInterpolator(LinearOperator): ...@@ -75,8 +75,8 @@ class LinearInterpolator(LinearOperator):
# dimensions. # dimensions.
dist = np.array(dist).reshape(-1, 1) dist = np.array(dist).reshape(-1, 1)
pos = sampling_points/dist pos = sampling_points/dist
excess = pos - pos.astype(np.int64) excess = pos - np.floor(pos)
pos = pos.astype(np.int64) pos = np.floor(pos).astype(np.int64)
max_index = np.array(self.domain.shape).reshape(-1, 1) max_index = np.array(self.domain.shape).reshape(-1, 1)
data = np.zeros((len(mg[0]), N_points)) data = np.zeros((len(mg[0]), N_points))
ii = np.zeros((len(mg[0]), N_points), dtype=np.int64) ii = np.zeros((len(mg[0]), N_points), dtype=np.int64)
......
...@@ -33,7 +33,7 @@ class SamplingEnabler(EndomorphicOperator): ...@@ -33,7 +33,7 @@ class SamplingEnabler(EndomorphicOperator):
likelihood : :class:`EndomorphicOperator` likelihood : :class:`EndomorphicOperator`
Metric of the likelihood Metric of the likelihood
prior : :class:`EndomorphicOperator` prior : :class:`EndomorphicOperator`
Inverse metric of the prior Metric of the prior
iteration_controller : :class:`IterationController` iteration_controller : :class:`IterationController`
The iteration controller to use for the iterative numerical inversion The iteration controller to use for the iterative numerical inversion
done by a :class:`ConjugateGradient` object. done by a :class:`ConjugateGradient` object.
......
...@@ -118,13 +118,47 @@ class FieldAdapter(LinearOperator): ...@@ -118,13 +118,47 @@ class FieldAdapter(LinearOperator):
return MultiField(self._tgt(mode), (x,)) return MultiField(self._tgt(mode), (x,))
def __repr__(self): def __repr__(self):
key = self.domain.keys()[0] return 'FieldAdapter'
return 'FieldAdapter: {}'.format(key)
class _SlowFieldAdapter(LinearOperator):
"""Operator for conversion between Fields and MultiFields.
The operator is built so that the MultiDomain is always the target.
Its domain is `tgt[name]`
Parameters
----------
dom : dict or MultiDomain:
the operator's dom
name : String
The relevant key of the MultiDomain.
"""
def __init__(self, dom, name):
from ..sugar import makeDomain
tmp = makeDomain(dom)
if not isinstance(tmp, MultiDomain):
raise TypeError("MultiDomain expected")
self._name = str(name)
self._domain = tmp
self._target = tmp[name]
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if isinstance(x, MultiField):
return x[self._name]
else:
return MultiField.from_dict(self._tgt(mode), {self._name: x})
def __repr__(self):
return '_SlowFieldAdapter'
def ducktape(left, right, name): def ducktape(left, right, name):
"""Convenience function creating an operator that converts between a """Convenience function creating an operator that converts between a
DomainTuple and a single-entry MultiDomain. DomainTuple and a MultiDomain.
Parameters Parameters
---------- ----------
...@@ -146,36 +180,48 @@ def ducktape(left, right, name): ...@@ -146,36 +180,48 @@ def ducktape(left, right, name):
- `left` and `right` must not be both `None`, but one of them can (and - `left` and `right` must not be both `None`, but one of them can (and
probably should) be `None`. In this case, the missing information is probably should) be `None`. In this case, the missing information is
inferred. inferred.
- the returned operator's domains are
- a `DomainTuple` and
- a `MultiDomain` with exactly one entry called `name` and the same
`DomainTuple`
Which of these is the domain and which is the target depends on the
input.
Returns Returns
------- -------
FieldAdapter : an adapter operator converting between the two (possibly FieldAdapter or _SlowFieldAdapter
partially inferred) domains. an adapter operator converting between the two (possibly
partially inferred) domains.
""" """
from ..sugar import makeDomain from ..sugar import makeDomain
from .operator import Operator from .operator import Operator
if isinstance(right, Operator):
right = right.target
elif right is not None:
right = makeDomain(right)
if isinstance(left, Operator):
left = left.domain
elif left is not None:
left = makeDomain(left)
if left is None: # need to infer left from right if left is None: # need to infer left from right
if isinstance(right, Operator):
right = right.target
elif right is not None:
right = makeDomain(right)
if isinstance(right, MultiDomain): if isinstance(right, MultiDomain):
left = right[name] left = right[name]
else: else:
left = MultiDomain.make({name: right}) left = MultiDomain.make({name: right})
else: elif right is None: # need to infer right from left
if isinstance(left, Operator): if isinstance(left, MultiDomain):
left = left.domain right = left[name]
else:
right = MultiDomain.make({name: left})
lmulti = isinstance(left, MultiDomain)
rmulti = isinstance(right, MultiDomain)
if lmulti+rmulti != 1:
raise ValueError("need exactly one MultiDomain")
if lmulti:
if len(left) == 1:
return FieldAdapter(left, name)
else:
return _SlowFieldAdapter(left, name).adjoint
if rmulti:
if len(right) == 1:
return FieldAdapter(left, name)
else: else:
left = makeDomain(left) return _SlowFieldAdapter(right, name)
return FieldAdapter(left, name) raise ValueError("must not arrive here")
class GeometryRemover(LinearOperator): class GeometryRemover(LinearOperator):
......
...@@ -363,7 +363,7 @@ def makeOp(input): ...@@ -363,7 +363,7 @@ def makeOp(input):
return DiagonalOperator(input) return DiagonalOperator(input)
if isinstance(input, MultiField): if isinstance(input, MultiField):
return BlockDiagonalOperator( return BlockDiagonalOperator(
input.domain, {key: makeOp(val) for key, val in enumerate(input)}) input.domain, {key: makeOp(val) for key, val in input.items()})
raise NotImplementedError raise NotImplementedError
......
...@@ -61,8 +61,8 @@ def testBinary(type1, type2, space, seed): ...@@ -61,8 +61,8 @@ def testBinary(type1, type2, space, seed):
_make_linearization(type2, dom2, seed) _make_linearization(type2, dom2, seed)
dom = ift.MultiDomain.union((dom1, dom2)) dom = ift.MultiDomain.union((dom1, dom2))
select_s1 = ift.ducktape(None, dom, "s1") select_s1 = ift.ducktape(None, dom1, "s1")
select_s2 = ift.ducktape(None, dom, "s2") select_s2 = ift.ducktape(None, dom2, "s2")
model = select_s1*select_s2 model = select_s1*select_s2
pos = ift.from_random("normal", dom) pos = ift.from_random("normal", dom)
ift.extra.check_jacobian_consistency(model, pos, ntries=20) ift.extra.check_jacobian_consistency(model, pos, ntries=20)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment