Commit eac89700 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'fix_linear_interpolation_boundary' into 'NIFTy_5'

Fix linear interpolation boundary

See merge request !289
parents 2198d139 f65415e7
Pipeline #43342 passed with stages
in 8 minutes and 53 seconds
......@@ -109,7 +109,7 @@ class Linearization(object):
def __getitem__(self, name):
from .operators.simple_linear_operators import ducktape
return self.new(self._val[name], ducktape(None, self.domain, name))
return self.new(self._val[name], self._jac.ducktape_left(name))
def __neg__(self):
return self.new(-self._val, -self._jac,
......
......@@ -328,8 +328,8 @@ class StandardHamiltonian(EnergyOperator):
return self._lh(x) + self._prior(x)
else:
lhx, prx = self._lh(x), self._prior(x)
mtr = SamplingEnabler(lhx.metric, prx.metric.inverse,
self._ic_samp, prx.metric.inverse)
mtr = SamplingEnabler(lhx.metric, prx.metric,
self._ic_samp)
return (lhx + prx).add_metric(mtr)
def __repr__(self):
......
......@@ -75,8 +75,8 @@ class LinearInterpolator(LinearOperator):
# dimensions.
dist = np.array(dist).reshape(-1, 1)
pos = sampling_points/dist
excess = pos - pos.astype(np.int64)
pos = pos.astype(np.int64)
excess = pos - np.floor(pos)
pos = np.floor(pos).astype(np.int64)
max_index = np.array(self.domain.shape).reshape(-1, 1)
data = np.zeros((len(mg[0]), N_points))
ii = np.zeros((len(mg[0]), N_points), dtype=np.int64)
......
......@@ -33,7 +33,7 @@ class SamplingEnabler(EndomorphicOperator):
likelihood : :class:`EndomorphicOperator`
Metric of the likelihood
prior : :class:`EndomorphicOperator`
Inverse metric of the prior
Metric of the prior
iteration_controller : :class:`IterationController`
The iteration controller to use for the iterative numerical inversion
done by a :class:`ConjugateGradient` object.
......
......@@ -118,13 +118,47 @@ class FieldAdapter(LinearOperator):
return MultiField(self._tgt(mode), (x,))
def __repr__(self):
key = self.domain.keys()[0]
return 'FieldAdapter: {}'.format(key)
return 'FieldAdapter'
class _SlowFieldAdapter(LinearOperator):
"""Operator for conversion between Fields and MultiFields.
The operator is built so that the MultiDomain is always the target.
Its domain is `tgt[name]`
Parameters
----------
dom : dict or MultiDomain:
the operator's dom
name : String
The relevant key of the MultiDomain.
"""
def __init__(self, dom, name):
from ..sugar import makeDomain
tmp = makeDomain(dom)
if not isinstance(tmp, MultiDomain):
raise TypeError("MultiDomain expected")
self._name = str(name)
self._domain = tmp
self._target = tmp[name]
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if isinstance(x, MultiField):
return x[self._name]
else:
return MultiField.from_dict(self._tgt(mode), {self._name: x})
def __repr__(self):
return '_SlowFieldAdapter'
def ducktape(left, right, name):
"""Convenience function creating an operator that converts between a
DomainTuple and a single-entry MultiDomain.
DomainTuple and a MultiDomain.
Parameters
----------
......@@ -146,36 +180,48 @@ def ducktape(left, right, name):
- `left` and `right` must not be both `None`, but one of them can (and
probably should) be `None`. In this case, the missing information is
inferred.
- the returned operator's domains are
- a `DomainTuple` and
- a `MultiDomain` with exactly one entry called `name` and the same
`DomainTuple`
Which of these is the domain and which is the target depends on the
input.
Returns
-------
FieldAdapter : an adapter operator converting between the two (possibly
partially inferred) domains.
FieldAdapter or _SlowFieldAdapter
an adapter operator converting between the two (possibly
partially inferred) domains.
"""
from ..sugar import makeDomain
from .operator import Operator
if isinstance(right, Operator):
right = right.target
elif right is not None:
right = makeDomain(right)
if isinstance(left, Operator):
left = left.domain
elif left is not None:
left = makeDomain(left)
if left is None: # need to infer left from right
if isinstance(right, Operator):
right = right.target
elif right is not None:
right = makeDomain(right)
if isinstance(right, MultiDomain):
left = right[name]
else:
left = MultiDomain.make({name: right})
else:
if isinstance(left, Operator):
left = left.domain
elif right is None: # need to infer right from left
if isinstance(left, MultiDomain):
right = left[name]
else:
right = MultiDomain.make({name: left})
lmulti = isinstance(left, MultiDomain)
rmulti = isinstance(right, MultiDomain)
if lmulti+rmulti != 1:
raise ValueError("need exactly one MultiDomain")
if lmulti:
if len(left) == 1:
return FieldAdapter(left, name)
else:
return _SlowFieldAdapter(left, name).adjoint
if rmulti:
if len(right) == 1:
return FieldAdapter(left, name)
else:
left = makeDomain(left)
return FieldAdapter(left, name)
return _SlowFieldAdapter(right, name)
raise ValueError("must not arrive here")
class GeometryRemover(LinearOperator):
......
......@@ -363,7 +363,7 @@ def makeOp(input):
return DiagonalOperator(input)
if isinstance(input, MultiField):
return BlockDiagonalOperator(
input.domain, {key: makeOp(val) for key, val in enumerate(input)})
input.domain, {key: makeOp(val) for key, val in input.items()})
raise NotImplementedError
......
......@@ -61,8 +61,8 @@ def testBinary(type1, type2, space, seed):
_make_linearization(type2, dom2, seed)
dom = ift.MultiDomain.union((dom1, dom2))
select_s1 = ift.ducktape(None, dom, "s1")
select_s2 = ift.ducktape(None, dom, "s2")
select_s1 = ift.ducktape(None, dom1, "s1")
select_s2 = ift.ducktape(None, dom2, "s2")
model = select_s1*select_s2
pos = ift.from_random("normal", dom)
ift.extra.check_jacobian_consistency(model, pos, ntries=20)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment