Commit ef4479aa authored by Martin Reinecke's avatar Martin Reinecke

massive reworking

parent a6e1e5e2
......@@ -59,8 +59,9 @@ if __name__ == '__main__':
# Generate mock data
p = R(sky)
mock_position = ift.from_random('normal', harmonic_space)
data = np.random.binomial(1, p(mock_position).local_data.astype(np.float64))
data = ift.Field.from_local_data(R.target, data)
tmp = p(mock_position).to_global_data().astype(np.float64)
data = np.random.binomial(1, tmp)
data = ift.Field.from_global_data(R.target, data)
# Compute likelihood and Hamiltonian
position = ift.from_random('normal', harmonic_space)
......
......@@ -141,7 +141,7 @@ class DomainTuple(object):
def __eq__(self, x):
if self is x:
return True
return self is DomainTuple.make(x)
return self._dom == x._dom
def __ne__(self, x):
return not self.__eq__(x)
......
......@@ -60,13 +60,13 @@ def _check_consistency(op, loc, tol, ntries, do_metric):
for i in range(50):
locmid = loc + 0.5*dir
linmid = op(Linearization.make_var(locmid))
dirder = linmid.jac(dir)/dirnorm
numgrad = (lin2.val-lin.val)/dirnorm
dirder = linmid.jac(dir)
numgrad = (lin2.val-lin.val)
xtol = tol * dirder.norm() / np.sqrt(dirder.size)
cond = (abs(numgrad-dirder) <= xtol).all()
if do_metric:
dgrad = linmid.metric(dir)/dirnorm
dgrad2 = (lin2.gradient-lin.gradient)/dirnorm
dgrad = linmid.metric(dir)
dgrad2 = (lin2.gradient-lin.gradient)
cond = cond and (abs(dgrad-dgrad2) <= xtol).all()
if cond:
break
......
......@@ -348,7 +348,7 @@ class Field(object):
raise TypeError("The dot-partner must be an instance of " +
"the NIFTy field class")
if x._domain is not self._domain:
if x._domain != self._domain:
raise ValueError("Domain mismatch")
ndom = len(self._domain)
......@@ -609,7 +609,7 @@ class Field(object):
"\n- val = " + repr(self._val)
def extract(self, dom):
if dom is not self._domain:
if dom != self._domain:
raise ValueError("domain mismatch")
return self
......@@ -623,13 +623,14 @@ class Field(object):
# if other is a field, make sure that the domains match
f = getattr(self._val, op)
if isinstance(other, Field):
if other._domain is not self._domain:
if other._domain != self._domain:
raise ValueError("domains are incompatible.")
return Field(self._domain, f(other._val))
if np.isscalar(other):
return Field(self._domain, f(other))
return NotImplemented
for op in ["__add__", "__radd__",
"__sub__", "__rsub__",
"__mul__", "__rmul__",
......
......@@ -26,11 +26,12 @@ from ..operators.domain_distributor import DomainDistributor
from ..operators.harmonic_operators import HarmonicTransformOperator
from ..operators.power_distributor import PowerDistributor
from ..operators.operator import Operator
from ..operators.simple_linear_operators import FieldAdapter
class CorrelatedField(Operator):
def CorrelatedField(s_space, amplitude_model):
'''
Class for construction of correlated fields
Function for construction of correlated fields
Parameters
----------
......@@ -38,17 +39,14 @@ class CorrelatedField(Operator):
amplitude_model : model for correlation structure
'''
def __init__(self, s_space, amplitude_model):
h_space = s_space.get_default_codomain()
self._ht = HarmonicTransformOperator(h_space, s_space)
p_space = amplitude_model.target[0]
power_distributor = PowerDistributor(h_space, p_space)
self._A = power_distributor(amplitude_model)
self._domain = MultiDomain.union(
(amplitude_model.domain, MultiDomain.make({"xi": h_space})))
def apply(self, x):
return self._ht(self._A(x)*x["xi"])
h_space = s_space.get_default_codomain()
ht = HarmonicTransformOperator(h_space, s_space)
p_space = amplitude_model.target[0]
power_distributor = PowerDistributor(h_space, p_space)
A = power_distributor(amplitude_model)
domain = MultiDomain.union(
(amplitude_model.domain, MultiDomain.make({"xi": h_space})))
return ht(A*FieldAdapter(domain, "xi"))
# def make_mf_correlated_field(s_space_spatial, s_space_energy,
......
......@@ -13,6 +13,8 @@ class Linearization(object):
def __init__(self, val, jac, metric=None):
self._val = val
self._jac = jac
if self._val.domain != self._jac.target:
raise ValueError("domain mismatch")
self._metric = metric
@property
......@@ -61,13 +63,12 @@ class Linearization(object):
def __add__(self, other):
if isinstance(other, Linearization):
from .operators.relaxed_sum_operator import RelaxedSumOperator
met = None
if self._metric is not None and other._metric is not None:
met = RelaxedSumOperator((self._metric, other._metric))
met = self._metric._myadd(other._metric, False)
return Linearization(
self._val.unite(other._val),
RelaxedSumOperator((self._jac, other._jac)), met)
self._jac._myadd(other._jac, False), met)
if isinstance(other, (int, float, complex, Field, MultiField)):
return Linearization(self._val+other, self._jac, self._metric)
......@@ -83,15 +84,20 @@ class Linearization(object):
def __mul__(self, other):
from .sugar import makeOp
if isinstance(other, Linearization):
if self.target != other.target:
raise ValueError("domain mismatch")
return Linearization(
self._val*other._val,
makeOp(other._val)(self._jac) + makeOp(self._val)(other._jac))
(makeOp(other._val)(self._jac))._myadd(
makeOp(self._val)(other._jac), False))
if np.isscalar(other):
if other == 1:
return self
met = None if self._metric is None else self._metric.scale(other)
return Linearization(self._val*other, self._jac.scale(other), met)
if isinstance(other, (Field, MultiField)):
if self.target != other.domain:
raise ValueError("domain mismatch")
return Linearization(self._val*other, makeOp(other)(self._jac))
def __rmul__(self, other):
......
......@@ -95,7 +95,7 @@ class MultiDomain(object):
def __eq__(self, x):
if self is x:
return True
return self is MultiDomain.make(x)
return self.items() == x.items()
def __ne__(self, x):
return not self.__eq__(x)
......@@ -115,7 +115,7 @@ class MultiDomain(object):
for dom in inp:
for key, subdom in zip(dom._keys, dom._domains):
if key in res:
if res[key] is not subdom:
if res[key] != subdom:
raise ValueError("domain mismatch")
else:
res[key] = subdom
......
......@@ -42,7 +42,7 @@ class MultiField(object):
raise ValueError("length mismatch")
for d, v in zip(domain._domains, val):
if isinstance(v, Field):
if v._domain is not d:
if v._domain != d:
raise ValueError("domain mismatch")
else:
raise TypeError("bad entry in val (must be Field)")
......@@ -103,7 +103,7 @@ class MultiField(object):
for dom in domain._domains))
def _check_domain(self, other):
if other._domain is not self._domain:
if other._domain != self._domain:
raise ValueError("domains are incompatible.")
def vdot(self, x):
......@@ -216,7 +216,7 @@ class MultiField(object):
def _binary_op(self, other, op):
f = getattr(Field, op)
if isinstance(other, MultiField):
if self._domain is not other._domain:
if self._domain != other._domain:
raise ValueError("domain mismatch")
val = tuple(f(v1, v2)
for v1, v2 in zip(self._val, other._val))
......
......@@ -57,14 +57,14 @@ class BlockDiagonalOperator(EndomorphicOperator):
# return MultiField(self._domain, val)
def _combine_chain(self, op):
if self._domain is not op._domain:
if self._domain != op._domain:
raise ValueError("domain mismatch")
res = tuple(v1(v2) for v1, v2 in zip(self._ops, op._ops))
return BlockDiagonalOperator(self._domain, res)
def _combine_sum(self, op, selfneg, opneg):
from ..operators.sum_operator import SumOperator
if self._domain is not op._domain:
if self._domain != op._domain:
raise ValueError("domain mismatch")
res = tuple(SumOperator.make([v1, v2], [selfneg, opneg])
for v1, v2 in zip(self._ops, op._ops))
......
......@@ -44,7 +44,7 @@ class ChainOperator(LinearOperator):
from .diagonal_operator import DiagonalOperator
# Step 1: verify domains
for i in range(len(ops)-1):
if ops[i+1].target is not ops[i].domain:
if ops[i+1].target != ops[i].domain:
raise ValueError("domain mismatch")
# Step 2: unpack ChainOperators
opsnew = []
......
......@@ -65,7 +65,7 @@ class DiagonalOperator(EndomorphicOperator):
self._domain = DomainTuple.make(domain)
if spaces is None:
self._spaces = None
if diagonal.domain is not self._domain:
if diagonal.domain != self._domain:
raise ValueError("domain mismatch")
else:
self._spaces = utilities.parse_spaces(spaces, len(self._domain))
......
......@@ -62,5 +62,5 @@ class EndomorphicOperator(LinearOperator):
def _check_input(self, x, mode):
self._check_mode(mode)
if self.domain is not x.domain:
if self.domain != x.domain:
raise ValueError("The operator's and field's domains don't match.")
......@@ -85,7 +85,7 @@ class GaussianEnergy(EnergyOperator):
if self._domain is None:
self._domain = newdom
else:
if self._domain is not newdom:
if self._domain != newdom:
raise ValueError("domain mismatch")
def apply(self, x):
......@@ -157,6 +157,5 @@ class SampledKullbachLeiblerDivergence(EnergyOperator):
self._res_samples = tuple(res_samples)
def apply(self, x):
res = (utilities.my_sum(map(lambda v: self._h(x+v), self._res_samples)) *
(1./len(self._res_samples)))
return res
mymap = map(lambda v: self._h(x+v), self._res_samples)
return utilities.my_sum(mymap) * (1./len(self._res_samples))
......@@ -116,10 +116,16 @@ class LinearOperator(Operator):
return ChainOperator.make([other, self])
return Operator.__rmatmul__(self, other)
def _myadd(self, other, oneg):
if self.domain == other.domain and self.target == other.target:
from .sum_operator import SumOperator
return SumOperator.make((self, other), (False, oneg))
from .relaxed_sum_operator import RelaxedSumOperator
return RelaxedSumOperator((self, -other if oneg else other))
def __add__(self, other):
if isinstance(other, LinearOperator):
from .sum_operator import SumOperator
return SumOperator.make([self, other], [False, False])
return self._myadd(other, False)
return Operator.__add__(self, other)
def __radd__(self, other):
......@@ -127,14 +133,12 @@ class LinearOperator(Operator):
def __sub__(self, other):
if isinstance(other, LinearOperator):
from .sum_operator import SumOperator
return SumOperator.make([self, other], [False, True])
return self._myadd(other, True)
return Operator.__sub__(self, other)
def __rsub__(self, other):
if isinstance(other, LinearOperator):
from .sum_operator import SumOperator
return SumOperator.make([other, self], [False, True])
return other._myadd(self, True)
return Operator.__rsub__(self, other)
@property
......@@ -260,5 +264,5 @@ class LinearOperator(Operator):
def _check_input(self, x, mode):
self._check_mode(mode)
if self._dom(mode) is not x.domain:
if self._dom(mode) != x.domain:
raise ValueError("The operator's and field's domains don't match.")
......@@ -50,15 +50,15 @@ class Operator(NiftyMetaBase()):
def __mul__(self, x):
if not isinstance(x, Operator):
return NotImplemented
return _OpProd.make((self, x))
return _OpProd(self, x)
def apply(self, x):
raise NotImplementedError
def __call__(self, x):
if isinstance(x, Operator):
return _OpChain.make((self, x))
return self.apply(x)
if isinstance(x, Operator):
return _OpChain.make((self, x))
return self.apply(x)
for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]:
......@@ -108,6 +108,9 @@ class _OpChain(_CombinedOperator):
super(_OpChain, self).__init__(ops, _callingfrommake)
self._domain = self._ops[-1].domain
self._target = self._ops[0].target
for i in range(1, len(self._ops)):
if self._ops[i-1].domain != self._ops[i].target:
raise ValueError("domain mismatch")
def apply(self, x):
for op in reversed(self._ops):
......@@ -115,21 +118,44 @@ class _OpChain(_CombinedOperator):
return x
class _OpProd(_CombinedOperator):
def __init__(self, ops, _callingfrommake=False):
super(_OpProd, self).__init__(ops, _callingfrommake)
self._domain = self._ops[0].domain
self._target = self._ops[0].target
class _OpProd(Operator):
def __init__(self, op1, op2):
from ..sugar import domain_union
self._domain = domain_union((op1.domain, op2.domain))
self._target = op1.target
if op1.target != op2.target:
raise ValueError("target mismatch")
self._op1 = op1
self._op2 = op2
def apply(self, x):
return my_product(map(lambda op: op(x), self._ops))
from ..linearization import Linearization
from ..sugar import makeOp
lin = isinstance(x, Linearization)
if not lin:
r1 = self._op1(x.extract(self._op1.domain))
r2 = self._op2(x.extract(self._op2.domain))
return r1*r2
lin1 = self._op1(
Linearization.make_var(x._val.extract(self._op1.domain)))
lin2 = self._op2(
Linearization.make_var(x._val.extract(self._op2.domain)))
op = (makeOp(lin1._val)(lin2._jac))._myadd(
makeOp(lin2._val)(lin1._jac), False)
jac = op(x.jac)
return Linearization(lin1._val*lin2._val, jac)
class _OpSum(_CombinedOperator):
def __init__(self, ops, _callingfrommake=False):
from ..sugar import domain_union
super(_OpSum, self).__init__(ops, _callingfrommake)
self._domain = domain_union([op.domain for op in self._ops])
self._target = domain_union([op.target for op in self._ops])
def apply(self, x):
raise NotImplementedError
res = None
for op in self._ops:
tmp = op(x.extract(op.domain))
res = tmp if res is None else res.unite(tmp)
return res
......@@ -38,12 +38,6 @@ class RelaxedSumOperator(LinearOperator):
self._capability = self.TIMES | self.ADJOINT_TIMES
for op in ops:
self._capability &= op.capability
#self._ops = []
#for op in ops:
# if isinstance(op, RelaxedSumOperator):
# self._ops += op._ops
# else:
# self._ops += [op]
@property
def adjoint(self):
......
......@@ -36,7 +36,7 @@ class VdotOperator(LinearOperator):
self._field = field
self._domain = field.domain
self._target = DomainTuple.scalar_domain()
self._capability = self.TIMES | self.ADJOINT_TIMES
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_mode(mode)
......@@ -49,7 +49,7 @@ class SumReductionOperator(LinearOperator):
def __init__(self, domain):
self._domain = domain
self._target = DomainTuple.scalar_domain()
self._capability = self.TIMES | self.ADJOINT_TIMES
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
......@@ -61,7 +61,7 @@ class SumReductionOperator(LinearOperator):
class ConjugationOperator(EndomorphicOperator):
def __init__(self, domain):
self._domain = domain
self._capability = self._all_ops
self._capability = self._all_ops
def apply(self, x, mode):
self._check_input(x, mode)
......@@ -71,7 +71,7 @@ class ConjugationOperator(EndomorphicOperator):
class Realizer(EndomorphicOperator):
def __init__(self, domain):
self._domain = domain
self._capability = self.TIMES | self.ADJOINT_TIMES
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
......@@ -79,20 +79,17 @@ class Realizer(EndomorphicOperator):
class FieldAdapter(LinearOperator):
def __init__(self, dom, name_dom):
self._domain = MultiDomain.make(dom)
self._name = name_dom
self._target = dom[name_dom]
self._capability = self.TIMES | self.ADJOINT_TIMES
def __init__(self, dom, name):
self._target = dom[name]
self._domain = MultiDomain.make({name: self._target})
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return x[self._name]
values = tuple(Field(dom, 0.) if key != self._name else x
for key, dom in self._domain.items())
return MultiField(self._domain, values)
return x.values()[0]
return MultiField(self._domain, (x,))
class GeometryRemover(LinearOperator):
......@@ -115,7 +112,7 @@ class GeometryRemover(LinearOperator):
self._domain = DomainTuple.make(domain)
target_list = [UnstructuredDomain(dom.shape) for dom in self._domain]
self._target = DomainTuple.make(target_list)
self._capability = self.TIMES | self.ADJOINT_TIMES
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
......@@ -137,7 +134,7 @@ class NullOperator(LinearOperator):
from ..sugar import makeDomain
self._domain = makeDomain(domain)
self._target = makeDomain(target)
self._capability = self.TIMES | self.ADJOINT_TIMES
self._capability = self.TIMES | self.ADJOINT_TIMES
@staticmethod
def _nullfield(dom):
......
......@@ -30,7 +30,7 @@ from .. import utilities
class SymmetrizingOperator(EndomorphicOperator):
def __init__(self, domain, space=0):
self._domain = DomainTuple.make(domain)
self._capability = self.TIMES | self.ADJOINT_TIMES
self._capability = self.TIMES | self.ADJOINT_TIMES
self._space = utilities.infer_space(self._domain, space)
dom = self._domain[self._space]
if not (isinstance(dom, LogRGSpace) and not dom.harmonic):
......
......@@ -246,7 +246,7 @@ def makeOp(input):
def domain_union(domains):
if isinstance(domains[0], DomainTuple):
if any(dom is not domains[0] for dom in domains[1:]):
if any(dom != domains[0] for dom in domains[1:]):
raise ValueError("domain mismatch")
return domains[0]
return MultiDomain.union(domains)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment