Commit 093f8b06 authored by Philipp Arras's avatar Philipp Arras
Browse files

Merge branch 'misc_work' into 'NIFTy_5'

Misc work, includes the remove_apply_data branch

See merge request ift/nifty-dev!44
parents 73532911 abd435e1
......@@ -82,7 +82,7 @@ if __name__ == '__main__':
# Minimize the Hamiltonian
H = ift.Hamiltonian(likelihood, ic_sampling)
H = H.makeInvertible(ic_cg)
H = H.make_invertible(ic_cg)
# minimizer = ift.SteepestDescent(ic_newton)
H, convergence = minimizer(H)
......
......@@ -98,7 +98,7 @@ if __name__ == '__main__':
# Minimize the Hamiltonian
H = ift.Hamiltonian(likelihood)
H = H.makeInvertible(ic_cg)
H = H.make_invertible(ic_cg)
H, convergence = minimizer(H)
# Plot results
......
......@@ -95,7 +95,7 @@ if __name__ == '__main__':
for _ in range(N_samples)]
KL = ift.SampledKullbachLeiblerDivergence(H, samples)
KL = KL.makeInvertible(ic_cg)
KL = KL.make_invertible(ic_cg)
KL, convergence = minimizer(KL)
position = KL.position
......
......@@ -76,13 +76,11 @@ from .sugar import *
from .plotting.plot import plot
from .library.amplitude_model import make_amplitude_model
from .library.apply_data import ApplyData
from .library.gaussian_energy import GaussianEnergy
from .library.los_response import LOSResponse
from .library.point_sources import PointSources
from .library.poissonian_energy import PoissonianEnergy
from .library.wiener_filter_curvature import WienerFilterCurvature
from .library.wiener_filter_energy import WienerFilterEnergy
from .library.correlated_fields import (make_correlated_field,
make_mf_correlated_field)
from .library.bernoulli_energy import BernoulliEnergy
......
......@@ -32,10 +32,10 @@ class Hamiltonian(Energy):
lh: Likelihood (energy object)
prior:
"""
super(Hamiltonian, self).__init__(lh.position)
super(Hamiltonian, self).__init__(lh._position)
self._lh = lh
self._ic_samp = iteration_controller_sampling
self._prior = GaussianEnergy(Variable(self.position))
self._prior = GaussianEnergy(Variable(self._position))
def at(self, position):
return self.__class__(self._lh.at(position), self._ic_samp)
......
......@@ -607,11 +607,6 @@ class Field(object):
return False
return (self._val == other._val).all()
def isSubsetOf(self, other):
"""Identical to `Field.isEquivalentTo()`. This method is provided for
easier interoperability with `MultiField`."""
return self.isEquivalentTo(other)
for op in ["__add__", "__radd__",
"__sub__", "__rsub__",
......
......@@ -33,7 +33,7 @@ class GaussianEnergy(Energy):
value = 0.5 * s.vdot(s), i.e. a log-Gauss distribution with unit
covariance
"""
super(GaussianEnergy, self).__init__(inp.position)
super(GaussianEnergy, self).__init__(inp._position)
self._inp = inp
self._mean = mean
self._cov = covariance
......
......@@ -71,16 +71,6 @@ class PointSources(Model):
foo = invgamma.ppf(norm.cdf(field.local_data), alpha, scale=q)
return Field.from_local_data(field.domain, foo)
# MR FIXME: is this function needed?
@staticmethod
def IG_prime(field, alpha, q):
inner = norm.pdf(field.local_data)
outer = invgamma.pdf(invgamma.ppf(norm.cdf(field.local_data), alpha,
scale=q), alpha, scale=q)
# # FIXME
# outer = np.clip(outer, 1e-20, None)
return Field.from_local_data(field.domain, inner/outer)
# MR FIXME: why does this take an np.ndarray instead of a Field?
@staticmethod
def inverseIG(u, alpha, q):
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..minimization.quadratic_energy import QuadraticEnergy
from .wiener_filter_curvature import WienerFilterCurvature
def WienerFilterEnergy(position, d, R, N, S, iteration_controller=None,
iteration_controller_sampling=None):
"""The Energy for the Wiener filter.
It covers the case of linear measurement with
Gaussian noise and Gaussian signal prior with known covariance.
Parameters
----------
position : Field
The current map in harmonic space.
d : Field
the data
R : LinearOperator
The response operator, description of the measurement process. It needs
to map from harmonic signal space to data space.
N : EndomorphicOperator
The noise covariance in data space.
S : EndomorphicOperator
The prior signal covariance in harmonic space.
inverter : Minimizer, optional
the minimization strategy to use for operator inversion
If None, the energy object will not support curvature computation.
sampling_inverter : Minimizer, optional
The minimizer to use during numerical sampling
if None, it is not possible to draw inverse samples
default: None
"""
op = WienerFilterCurvature(R, N, S, iteration_controller,
iteration_controller_sampling)
vec = R.adjoint_times(N.inverse_times(d))
return QuadraticEnergy(position, op, vec)
......@@ -129,7 +129,7 @@ class Energy(NiftyMetaBase()):
"""
return None
def makeInvertible(self, controller, preconditioner=None):
def make_invertible(self, controller, preconditioner=None):
from .iteration_controller import IterationController
if not isinstance(controller, IterationController):
raise TypeError
......@@ -169,8 +169,6 @@ class MetricInversionEnabler(Energy):
self._preconditioner = preconditioner
def at(self, position):
if self._position.isSubsetOf(position):
return self
return MetricInversionEnabler(
self._energy.at(position), self._controller, self._preconditioner)
......
......@@ -36,14 +36,14 @@ class QuadraticEnergy(Energy):
self._grad = _grad
Ax = _grad if b is None else _grad + b
else:
Ax = self._A(self.position)
Ax = self._A(self._position)
self._grad = Ax if b is None else Ax - b
self._value = 0.5*self.position.vdot(Ax)
self._value = 0.5*self._position.vdot(Ax)
if b is not None:
self._value -= b.vdot(self.position)
self._value -= b.vdot(self._position)
def at(self, position):
return QuadraticEnergy(position=position, A=self._A, b=self._b)
return QuadraticEnergy(position, self._A, self._b)
def at_with_grad(self, position, grad):
""" Specialized version of `at`, taking also a gradient.
......@@ -64,8 +64,7 @@ class QuadraticEnergy(Energy):
Energy
Energy object at new position.
"""
return QuadraticEnergy(position=position, A=self._A, b=self._b,
_grad=grad)
return QuadraticEnergy(position, self._A, self._b, grad)
@property
def value(self):
......
......@@ -19,6 +19,7 @@
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..operators.null_operator import NullOperator
from .model import Model
......@@ -44,7 +45,7 @@ class Constant(Model):
self._constant = constant
self._value = self._constant
self._jacobian = 0.
self._jacobian = NullOperator(position.domain, constant.domain)
def at(self, position):
return self.__class__(position, self._constant)
......@@ -36,7 +36,7 @@ class MultiModel(Model):
val = self._model.value
if not isinstance(val.domain, DomainTuple):
raise TypeError
self._value = MultiField({key: val})
self._value = MultiField.from_dict({key: val})
self._jacobian = (MultiAdaptor(self.value.domain) *
self._model.jacobian)
......
......@@ -20,6 +20,7 @@ from __future__ import absolute_import, division, print_function
from ..compat import *
from ..operators.endomorphic_operator import EndomorphicOperator
from .multi_domain import MultiDomain
from .multi_field import MultiField
......@@ -33,8 +34,12 @@ class BlockDiagonalOperator(EndomorphicOperator):
LinearOperators as items
"""
super(BlockDiagonalOperator, self).__init__()
if not isinstance(domain, MultiDomain):
raise TypeError("MultiDomain expected")
if not isinstance(operators, tuple):
raise TypeError("tuple expected")
self._domain = domain
self._ops = tuple(operators[key] for key in self.domain.keys())
self._ops = operators
self._cap = self._all_ops
for op in self._ops:
if op is not None:
......@@ -63,15 +68,13 @@ class BlockDiagonalOperator(EndomorphicOperator):
def _combine_chain(self, op):
if self._domain is not op._domain:
raise ValueError("domain mismatch")
res = {key: v1*v2
for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops)}
res = tuple(v1*v2 for v1, v2 in zip(self._ops, op._ops))
return BlockDiagonalOperator(self._domain, res)
def _combine_sum(self, op, selfneg, opneg):
from ..operators.sum_operator import SumOperator
if self._domain is not op._domain:
raise ValueError("domain mismatch")
res = {}
for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops):
res[key] = SumOperator.make([v1, v2], [selfneg, opneg])
res = tuple(SumOperator.make([v1, v2], [selfneg, opneg])
for v1, v2 in zip(self._ops, op._ops))
return BlockDiagonalOperator(self._domain, res)
......@@ -31,7 +31,7 @@ class MultiField(object):
Parameters
----------
domain: MultiDomain
val: tuple of Fields
val: tuple containing Field or None entries
"""
if not isinstance(domain, MultiDomain):
raise TypeError("domain must be of type MultiDomain")
......@@ -39,12 +39,12 @@ class MultiField(object):
raise TypeError("val must be a tuple")
if len(val) != len(domain):
raise ValueError("length mismatch")
for i, v in enumerate(val):
for d, v in zip(domain._domains, val):
if isinstance(v, Field):
if v._domain is not domain._domains[i]:
if v._domain is not d:
raise ValueError("domain mismatch")
elif v is not None:
raise TypeError("bad entry in val")
raise TypeError("bad entry in val (must be Field or None)")
self._domain = domain
self._val = val
......@@ -192,26 +192,45 @@ class MultiField(object):
return False
return True
def isSubsetOf(self, other):
"""Determines (as quickly as possible) whether `self`'s content is
a subset of `other`'s content."""
if self is other:
return True
if not isinstance(other, MultiField):
return False
if len(set(self._domain.keys()) - set(other._domain.keys())) > 0:
return False
for key in self._domain.keys():
if other._domain[key] is not self._domain[key]:
return False
if not other[key].isSubsetOf(self[key]):
return False
return True
for op in ["__add__", "__radd__"]:
def func(op):
def func2(self, other):
if isinstance(other, MultiField):
if self._domain is not other._domain:
raise ValueError("domain mismatch")
val = []
for v1, v2 in zip(self._val, other._val):
if v1 is not None:
val.append(v1 if v2 is None else (v1+v2))
else:
val.append(None if v2 is None else v2)
val = tuple(val)
else:
val = tuple(other if v1 is None else (v1+other)
for v1 in self._val)
return MultiField(self._domain, val)
return func2
setattr(MultiField, op, func(op))
for op in ["__add__", "__radd__",
"__sub__", "__rsub__",
"__mul__", "__rmul__",
for op in ["__mul__", "__rmul__"]:
def func(op):
def func2(self, other):
if isinstance(other, MultiField):
if self._domain is not other._domain:
raise ValueError("domain mismatch")
val = tuple(None if v1 is None or v2 is None else v1*v2
for v1, v2 in zip(self._val, other._val))
else:
val = tuple(None if v1 is None else (v1*other)
for v1 in self._val)
return MultiField(self._domain, val)
return func2
setattr(MultiField, op, func(op))
for op in ["__sub__", "__rsub__",
"__div__", "__rdiv__",
"__truediv__", "__rtruediv__",
"__floordiv__", "__rfloordiv__",
......@@ -219,27 +238,18 @@ for op in ["__add__", "__radd__",
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]:
def func(op):
def func2(self, other):
res = []
if isinstance(other, MultiField):
if self._domain is not other._domain:
raise ValueError("domain mismatch")
for v1, v2 in zip(self._val, other._val):
if v1 is not None:
if v2 is None:
res.append(getattr(v1, op)(v1*0))
else:
res.append(getattr(v1, op)(v2))
else:
if v2 is None:
res.append(None)
else:
res.append(getattr(v2*0, op)(v2))
return MultiField(self._domain, tuple(res))
val = tuple(getattr(v1, op)(v2)
for v1, v2 in zip (self._val, other._val))
else:
return self._transform(lambda x: getattr(x, op)(other))
val = tuple(getattr(v1, op)(other) for v1 in self._val)
return MultiField(self._domain, val)
return func2
setattr(MultiField, op, func(op))
for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
"__itruediv__", "__ifloordiv__", "__ipow__"]:
def func(op):
......
......@@ -17,15 +17,50 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..domain_tuple import DomainTuple
from ..field import Field
from ..multi.multi_field import MultiField
from ..operators.linear_operator import LinearOperator
class NullOperator(LinearOperator):
"""Operator corresponding to a matrix of all zeros.
Parameters
----------
domain : DomainTuple or MultiDomain
input domain
target : DomainTuple or MultiDomain
output domain
"""
def __init__(self, domain, target):
from ..sugar import makeDomain
self._domain = makeDomain(domain)
self._target = makeDomain(target)
@staticmethod
def _nullfield(dom):
if isinstance (dom, DomainTuple):
return Field.full(dom, 0)
else:
return MultiField(dom, (None,)*len(dom))
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return self._nullfield(self._target)
return self._nullfield(self._domain)
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
def ApplyData(data, var, model_data):
# TODO This is rather confusing. Delete that eventually.
from ..operators.diagonal_operator import DiagonalOperator
from ..models.constant import Constant
from ..sugar import sqrt
sqrt_n = DiagonalOperator(sqrt(var))
data = Constant(model_data.position, data)
return sqrt_n.inverse(model_data - data)
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
......@@ -50,8 +50,8 @@ class SamplingEnabler(EndomorphicOperator):
def __init__(self, likelihood, prior, iteration_controller,
approximation=None):
self._op = likelihood + prior
super(SamplingEnabler, self).__init__()
self._op = likelihood + prior
self._likelihood = likelihood
self._prior = prior
self._ic = iteration_controller
......@@ -61,9 +61,8 @@ class SamplingEnabler(EndomorphicOperator):
try:
return self._op.draw_sample(from_inverse, dtype)
except NotImplementedError:
# MR FIXME: I think there is a silent assumption that
# from_inverse==True when we arrive here.
# Can we make this explicit?
if not from_inverse:
raise ValueError("from_inverse must be True here")
s = self._prior.draw_sample(from_inverse=True)
sp = self._prior(s)
nj = self._likelihood.draw_sample()
......
......@@ -22,14 +22,19 @@ import numpy as np
from ..compat import *
from ..domain_tuple import DomainTuple
from ..domains.log_rg_space import LogRGSpace
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from .linear_operator import LinearOperator
class SlopeOperator(LinearOperator):
def __init__(self, domain, target, sigmas):
# MR FIXME: check explicitly for the required domain types etc.
# Maybe compute domain from target automatically?
if not isinstance(target, LogRGSpace):
raise TypeError
if not (isinstance(domain, UnstructuredDomain) and domain.shape == (2,)):
raise TypeError
self._domain = DomainTuple.make(domain)
self._target = DomainTuple.make(target)
......
......@@ -21,13 +21,15 @@ from __future__ import absolute_import, division, print_function
from .. import dobj
from ..compat import *
from ..domain_tuple import DomainTuple
from ..domains.log_rg_space import LogRGSpace
from ..field import Field
from .endomorphic_operator import EndomorphicOperator
# MR FIXME: we should make sure that the domain is a harmonic RGSpace, correct?
class SymmetrizingOperator(EndomorphicOperator):
def __init__(self, domain):
if not (isinstance(domain, LogRGSpace) and not domain.harmonic):
raise TypeError
self._domain = DomainTuple.make(domain)
self._ndim = len(self.domain.shape)
......@@ -48,7 +50,7 @@ class SymmetrizingOperator(EndomorphicOperator):
tmp2[lead+(slice(1, None),)] -= tmp2[lead+(slice(None, 0, -1),)]
if i == ax:
tmp = dobj.redistribute(tmp, dist=ax)
return Field(self.target, val=tmp)
return Field(self.target, val=tmp)
@property
def capability(self):
......
......@@ -239,7 +239,7 @@ def makeOp(input):
return DiagonalOperator(input)
if isinstance(input, MultiField):
return BlockDiagonalOperator(
input.domain, {key: makeOp(val) for key, val in input.items()})
input.domain, tuple(makeOp(val) for val in input.values()))
raise NotImplementedError
# Arithmetic functions working on Fields
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment