diff --git a/demos/bernoulli_demo.py b/demos/bernoulli_demo.py index 5ae269ead480611024001231f488d53aa1d00af3..9449424c9bade9e391960be55b367059268a9e76 100644 --- a/demos/bernoulli_demo.py +++ b/demos/bernoulli_demo.py @@ -59,8 +59,9 @@ if __name__ == '__main__': # Generate mock data p = R(sky) mock_position = ift.from_random('normal', harmonic_space) - data = np.random.binomial(1, p(mock_position).local_data.astype(np.float64)) - data = ift.Field.from_local_data(R.target, data) + tmp = p(mock_position).to_global_data().astype(np.float64) + data = np.random.binomial(1, tmp) + data = ift.Field.from_global_data(R.target, data) # Compute likelihood and Hamiltonian position = ift.from_random('normal', harmonic_space) diff --git a/nifty5/domain_tuple.py b/nifty5/domain_tuple.py index 64d5b53adc9e303cc8f3641355fde20ab10d9e5f..edf923ea3cbc398a635674f7ff9e1694258559cf 100644 --- a/nifty5/domain_tuple.py +++ b/nifty5/domain_tuple.py @@ -141,7 +141,7 @@ class DomainTuple(object): def __eq__(self, x): if self is x: return True - return self is DomainTuple.make(x) + return self._dom == x._dom def __ne__(self, x): return not self.__eq__(x) diff --git a/nifty5/extra/energy_and_model_tests.py b/nifty5/extra/energy_and_model_tests.py index 32641d898548b16b52665167b858c7672cd42b79..40b93a14447d7fef1648f6b1370d3ec0abbadc99 100644 --- a/nifty5/extra/energy_and_model_tests.py +++ b/nifty5/extra/energy_and_model_tests.py @@ -60,13 +60,13 @@ def _check_consistency(op, loc, tol, ntries, do_metric): for i in range(50): locmid = loc + 0.5*dir linmid = op(Linearization.make_var(locmid)) - dirder = linmid.jac(dir)/dirnorm - numgrad = (lin2.val-lin.val)/dirnorm + dirder = linmid.jac(dir) + numgrad = (lin2.val-lin.val) xtol = tol * dirder.norm() / np.sqrt(dirder.size) cond = (abs(numgrad-dirder) <= xtol).all() if do_metric: - dgrad = linmid.metric(dir)/dirnorm - dgrad2 = (lin2.gradient-lin.gradient)/dirnorm + dgrad = linmid.metric(dir) + dgrad2 = (lin2.gradient-lin.gradient) cond = cond and (abs(dgrad-dgrad2) <= xtol).all() if cond: break diff --git a/nifty5/field.py b/nifty5/field.py index c6a8f9852e9f8a0aede3511f9bcd557f4962262d..d809d6ea6ff418684709e2a2d2eb4cd25694af4e 100644 --- a/nifty5/field.py +++ b/nifty5/field.py @@ -348,7 +348,7 @@ class Field(object): raise TypeError("The dot-partner must be an instance of " + "the NIFTy field class") - if x._domain is not self._domain: + if x._domain != self._domain: raise ValueError("Domain mismatch") ndom = len(self._domain) @@ -609,7 +609,7 @@ class Field(object): "\n- val = " + repr(self._val) def extract(self, dom): - if dom is not self._domain: + if dom != self._domain: raise ValueError("domain mismatch") return self @@ -623,13 +623,14 @@ class Field(object): # if other is a field, make sure that the domains match f = getattr(self._val, op) if isinstance(other, Field): - if other._domain is not self._domain: + if other._domain != self._domain: raise ValueError("domains are incompatible.") return Field(self._domain, f(other._val)) if np.isscalar(other): return Field(self._domain, f(other)) return NotImplemented + for op in ["__add__", "__radd__", "__sub__", "__rsub__", "__mul__", "__rmul__", diff --git a/nifty5/library/correlated_fields.py b/nifty5/library/correlated_fields.py index 4c33a7fe473030f5e3c1be740ade21d7eede36a4..37d2c0063d20fef91a16148049c2d57b1a4ca94b 100644 --- a/nifty5/library/correlated_fields.py +++ b/nifty5/library/correlated_fields.py @@ -26,11 +26,12 @@ from ..operators.domain_distributor import DomainDistributor from ..operators.harmonic_operators import HarmonicTransformOperator from ..operators.power_distributor import PowerDistributor from ..operators.operator import Operator +from ..operators.simple_linear_operators import FieldAdapter -class CorrelatedField(Operator): +def CorrelatedField(s_space, amplitude_model): ''' - Class for construction of correlated fields + Function for construction of correlated fields Parameters ---------- @@ -38,17 +39,14 @@ class CorrelatedField(Operator): amplitude_model : model for correlation structure ''' - def __init__(self, s_space, amplitude_model): - h_space = s_space.get_default_codomain() - self._ht = HarmonicTransformOperator(h_space, s_space) - p_space = amplitude_model.target[0] - power_distributor = PowerDistributor(h_space, p_space) - self._A = power_distributor(amplitude_model) - self._domain = MultiDomain.union( - (amplitude_model.domain, MultiDomain.make({"xi": h_space}))) - - def apply(self, x): - return self._ht(self._A(x)*x["xi"]) + h_space = s_space.get_default_codomain() + ht = HarmonicTransformOperator(h_space, s_space) + p_space = amplitude_model.target[0] + power_distributor = PowerDistributor(h_space, p_space) + A = power_distributor(amplitude_model) + domain = MultiDomain.union( + (amplitude_model.domain, MultiDomain.make({"xi": h_space}))) + return ht(A*FieldAdapter(domain, "xi")) # def make_mf_correlated_field(s_space_spatial, s_space_energy, diff --git a/nifty5/linearization.py b/nifty5/linearization.py index 2c4b679743f8bcc3c0e75e561af694da363952b3..0e2529be679f4273821ae2303f7469ab96be8a45 100644 --- a/nifty5/linearization.py +++ b/nifty5/linearization.py @@ -13,6 +13,8 @@ class Linearization(object): def __init__(self, val, jac, metric=None): self._val = val self._jac = jac + if self._val.domain != self._jac.target: + raise ValueError("domain mismatch") self._metric = metric @property @@ -61,13 +63,12 @@ class Linearization(object): def __add__(self, other): if isinstance(other, Linearization): - from .operators.relaxed_sum_operator import RelaxedSumOperator met = None if self._metric is not None and other._metric is not None: - met = RelaxedSumOperator((self._metric, other._metric)) + met = self._metric._myadd(other._metric, False) return Linearization( self._val.unite(other._val), - RelaxedSumOperator((self._jac, other._jac)), met) + self._jac._myadd(other._jac, False), met) if isinstance(other, (int, float, complex, Field, MultiField)): return Linearization(self._val+other, self._jac, self._metric) @@ -83,15 +84,20 @@ class Linearization(object): def __mul__(self, other): from .sugar import makeOp if isinstance(other, Linearization): + if self.target != other.target: + raise ValueError("domain mismatch") return Linearization( self._val*other._val, - makeOp(other._val)(self._jac) + makeOp(self._val)(other._jac)) + (makeOp(other._val)(self._jac))._myadd( + makeOp(self._val)(other._jac), False)) if np.isscalar(other): if other == 1: return self met = None if self._metric is None else self._metric.scale(other) return Linearization(self._val*other, self._jac.scale(other), met) if isinstance(other, (Field, MultiField)): + if self.target != other.domain: + raise ValueError("domain mismatch") return Linearization(self._val*other, makeOp(other)(self._jac)) def __rmul__(self, other): diff --git a/nifty5/multi_domain.py b/nifty5/multi_domain.py index 47b8407896bc44ccb061bd39a500686cfe48c8de..39c8f68a635448b4aa4151d4259f25811074bd26 100644 --- a/nifty5/multi_domain.py +++ b/nifty5/multi_domain.py @@ -95,7 +95,7 @@ class MultiDomain(object): def __eq__(self, x): if self is x: return True - return self is MultiDomain.make(x) + return self.items() == x.items() def __ne__(self, x): return not self.__eq__(x) @@ -115,7 +115,7 @@ class MultiDomain(object): for dom in inp: for key, subdom in zip(dom._keys, dom._domains): if key in res: - if res[key] is not subdom: + if res[key] != subdom: raise ValueError("domain mismatch") else: res[key] = subdom diff --git a/nifty5/multi_field.py b/nifty5/multi_field.py index 25862cb7d7579463929023f4206122f6096c0ebd..02dd163082320112010242cf34c298229f14b284 100644 --- a/nifty5/multi_field.py +++ b/nifty5/multi_field.py @@ -42,7 +42,7 @@ class MultiField(object): raise ValueError("length mismatch") for d, v in zip(domain._domains, val): if isinstance(v, Field): - if v._domain is not d: + if v._domain != d: raise ValueError("domain mismatch") else: raise TypeError("bad entry in val (must be Field)") @@ -103,7 +103,7 @@ class MultiField(object): for dom in domain._domains)) def _check_domain(self, other): - if other._domain is not self._domain: + if other._domain != self._domain: raise ValueError("domains are incompatible.") def vdot(self, x): @@ -216,7 +216,7 @@ class MultiField(object): def _binary_op(self, other, op): f = getattr(Field, op) if isinstance(other, MultiField): - if self._domain is not other._domain: + if self._domain != other._domain: raise ValueError("domain mismatch") val = tuple(f(v1, v2) for v1, v2 in zip(self._val, other._val)) diff --git a/nifty5/operators/block_diagonal_operator.py b/nifty5/operators/block_diagonal_operator.py index eaa0b4a7ad0ddd281c0a64f3e00fe366045bad8f..2006ff11655841000c195b55f07783076d45389d 100644 --- a/nifty5/operators/block_diagonal_operator.py +++ b/nifty5/operators/block_diagonal_operator.py @@ -57,14 +57,14 @@ class BlockDiagonalOperator(EndomorphicOperator): # return MultiField(self._domain, val) def _combine_chain(self, op): - if self._domain is not op._domain: + if self._domain != op._domain: raise ValueError("domain mismatch") res = tuple(v1(v2) for v1, v2 in zip(self._ops, op._ops)) return BlockDiagonalOperator(self._domain, res) def _combine_sum(self, op, selfneg, opneg): from ..operators.sum_operator import SumOperator - if self._domain is not op._domain: + if self._domain != op._domain: raise ValueError("domain mismatch") res = tuple(SumOperator.make([v1, v2], [selfneg, opneg]) for v1, v2 in zip(self._ops, op._ops)) diff --git a/nifty5/operators/chain_operator.py b/nifty5/operators/chain_operator.py index 16365a7b3919482b65c1c037fea5bb97079d6d4a..eb849e7db0103da37b1379e2bd4f7281cbb7be2d 100644 --- a/nifty5/operators/chain_operator.py +++ b/nifty5/operators/chain_operator.py @@ -44,7 +44,7 @@ class ChainOperator(LinearOperator): from .diagonal_operator import DiagonalOperator # Step 1: verify domains for i in range(len(ops)-1): - if ops[i+1].target is not ops[i].domain: + if ops[i+1].target != ops[i].domain: raise ValueError("domain mismatch") # Step 2: unpack ChainOperators opsnew = [] diff --git a/nifty5/operators/diagonal_operator.py b/nifty5/operators/diagonal_operator.py index 5864ae038acead4c04fbd381dcf802ba685bd1cf..fbdba3982813f5379068025fc43b19866cc353d7 100644 --- a/nifty5/operators/diagonal_operator.py +++ b/nifty5/operators/diagonal_operator.py @@ -65,7 +65,7 @@ class DiagonalOperator(EndomorphicOperator): self._domain = DomainTuple.make(domain) if spaces is None: self._spaces = None - if diagonal.domain is not self._domain: + if diagonal.domain != self._domain: raise ValueError("domain mismatch") else: self._spaces = utilities.parse_spaces(spaces, len(self._domain)) diff --git a/nifty5/operators/endomorphic_operator.py b/nifty5/operators/endomorphic_operator.py index b1d5a07be1bb575ddf19c456688ad6f696247e8e..96c3da84d164c59389c57146cdd1fcafbae0b108 100644 --- a/nifty5/operators/endomorphic_operator.py +++ b/nifty5/operators/endomorphic_operator.py @@ -62,5 +62,5 @@ class EndomorphicOperator(LinearOperator): def _check_input(self, x, mode): self._check_mode(mode) - if self.domain is not x.domain: + if self.domain != x.domain: raise ValueError("The operator's and field's domains don't match.") diff --git a/nifty5/operators/energy_operators.py b/nifty5/operators/energy_operators.py index 55d8afec5766e032d95ba327e490ac7e38b2f231..79747869604e8bec175febc2ea46cef88fb70e67 100644 --- a/nifty5/operators/energy_operators.py +++ b/nifty5/operators/energy_operators.py @@ -85,7 +85,7 @@ class GaussianEnergy(EnergyOperator): if self._domain is None: self._domain = newdom else: - if self._domain is not newdom: + if self._domain != newdom: raise ValueError("domain mismatch") def apply(self, x): @@ -157,6 +157,5 @@ class SampledKullbachLeiblerDivergence(EnergyOperator): self._res_samples = tuple(res_samples) def apply(self, x): - res = (utilities.my_sum(map(lambda v: self._h(x+v), self._res_samples)) * - (1./len(self._res_samples))) - return res + mymap = map(lambda v: self._h(x+v), self._res_samples) + return utilities.my_sum(mymap) * (1./len(self._res_samples)) diff --git a/nifty5/operators/linear_operator.py b/nifty5/operators/linear_operator.py index 4cca76fa39e5994e5d5d1be194acdecd4b35a72a..4e39d0aee517eb78cbcb6657d202b2c14ccae4c1 100644 --- a/nifty5/operators/linear_operator.py +++ b/nifty5/operators/linear_operator.py @@ -116,10 +116,16 @@ class LinearOperator(Operator): return ChainOperator.make([other, self]) return Operator.__rmatmul__(self, other) + def _myadd(self, other, oneg): + if self.domain == other.domain and self.target == other.target: + from .sum_operator import SumOperator + return SumOperator.make((self, other), (False, oneg)) + from .relaxed_sum_operator import RelaxedSumOperator + return RelaxedSumOperator((self, -other if oneg else other)) + def __add__(self, other): if isinstance(other, LinearOperator): - from .sum_operator import SumOperator - return SumOperator.make([self, other], [False, False]) + return self._myadd(other, False) return Operator.__add__(self, other) def __radd__(self, other): @@ -127,14 +133,12 @@ class LinearOperator(Operator): def __sub__(self, other): if isinstance(other, LinearOperator): - from .sum_operator import SumOperator - return SumOperator.make([self, other], [False, True]) + return self._myadd(other, True) return Operator.__sub__(self, other) def __rsub__(self, other): if isinstance(other, LinearOperator): - from .sum_operator import SumOperator - return SumOperator.make([other, self], [False, True]) + return other._myadd(self, True) return Operator.__rsub__(self, other) @property @@ -260,5 +264,5 @@ class LinearOperator(Operator): def _check_input(self, x, mode): self._check_mode(mode) - if self._dom(mode) is not x.domain: + if self._dom(mode) != x.domain: raise ValueError("The operator's and field's domains don't match.") diff --git a/nifty5/operators/operator.py b/nifty5/operators/operator.py index 24b5e435758c1c8d391a7668237729687d9a468a..04f91b1ececd7d0838266ba34c28f392ac8d293c 100644 --- a/nifty5/operators/operator.py +++ b/nifty5/operators/operator.py @@ -50,15 +50,15 @@ class Operator(NiftyMetaBase()): def __mul__(self, x): if not isinstance(x, Operator): return NotImplemented - return _OpProd.make((self, x)) + return _OpProd(self, x) def apply(self, x): raise NotImplementedError def __call__(self, x): - if isinstance(x, Operator): - return _OpChain.make((self, x)) - return self.apply(x) + if isinstance(x, Operator): + return _OpChain.make((self, x)) + return self.apply(x) for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]: @@ -108,6 +108,9 @@ class _OpChain(_CombinedOperator): super(_OpChain, self).__init__(ops, _callingfrommake) self._domain = self._ops[-1].domain self._target = self._ops[0].target + for i in range(1, len(self._ops)): + if self._ops[i-1].domain != self._ops[i].target: + raise ValueError("domain mismatch") def apply(self, x): for op in reversed(self._ops): @@ -115,21 +118,44 @@ class _OpChain(_CombinedOperator): return x -class _OpProd(_CombinedOperator): - def __init__(self, ops, _callingfrommake=False): - super(_OpProd, self).__init__(ops, _callingfrommake) - self._domain = self._ops[0].domain - self._target = self._ops[0].target +class _OpProd(Operator): + def __init__(self, op1, op2): + from ..sugar import domain_union + self._domain = domain_union((op1.domain, op2.domain)) + self._target = op1.target + if op1.target != op2.target: + raise ValueError("target mismatch") + self._op1 = op1 + self._op2 = op2 def apply(self, x): - return my_product(map(lambda op: op(x), self._ops)) + from ..linearization import Linearization + from ..sugar import makeOp + lin = isinstance(x, Linearization) + if not lin: + r1 = self._op1(x.extract(self._op1.domain)) + r2 = self._op2(x.extract(self._op2.domain)) + return r1*r2 + lin1 = self._op1( + Linearization.make_var(x._val.extract(self._op1.domain))) + lin2 = self._op2( + Linearization.make_var(x._val.extract(self._op2.domain))) + op = (makeOp(lin1._val)(lin2._jac))._myadd( + makeOp(lin2._val)(lin1._jac), False) + jac = op(x.jac) + return Linearization(lin1._val*lin2._val, jac) class _OpSum(_CombinedOperator): def __init__(self, ops, _callingfrommake=False): + from ..sugar import domain_union super(_OpSum, self).__init__(ops, _callingfrommake) self._domain = domain_union([op.domain for op in self._ops]) self._target = domain_union([op.target for op in self._ops]) def apply(self, x): - raise NotImplementedError + res = None + for op in self._ops: + tmp = op(x.extract(op.domain)) + res = tmp if res is None else res.unite(tmp) + return res diff --git a/nifty5/operators/relaxed_sum_operator.py b/nifty5/operators/relaxed_sum_operator.py index 53530e2c593d580e2340364d62ba02ea1a1b0602..2a9cb68a1769867d8de17817124a44c42ec5d2d6 100644 --- a/nifty5/operators/relaxed_sum_operator.py +++ b/nifty5/operators/relaxed_sum_operator.py @@ -38,12 +38,6 @@ class RelaxedSumOperator(LinearOperator): self._capability = self.TIMES | self.ADJOINT_TIMES for op in ops: self._capability &= op.capability - #self._ops = [] - #for op in ops: - # if isinstance(op, RelaxedSumOperator): - # self._ops += op._ops - # else: - # self._ops += [op] @property def adjoint(self): diff --git a/nifty5/operators/simple_linear_operators.py b/nifty5/operators/simple_linear_operators.py index b214d1c1c7802ab3b0045913692d20a85c9c3507..05d5f640afa301be1b0d36db6c2ef40a32ac6b1f 100644 --- a/nifty5/operators/simple_linear_operators.py +++ b/nifty5/operators/simple_linear_operators.py @@ -36,7 +36,7 @@ class VdotOperator(LinearOperator): self._field = field self._domain = field.domain self._target = DomainTuple.scalar_domain() - self._capability = self.TIMES | self.ADJOINT_TIMES + self._capability = self.TIMES | self.ADJOINT_TIMES def apply(self, x, mode): self._check_mode(mode) @@ -49,7 +49,7 @@ class SumReductionOperator(LinearOperator): def __init__(self, domain): self._domain = domain self._target = DomainTuple.scalar_domain() - self._capability = self.TIMES | self.ADJOINT_TIMES + self._capability = self.TIMES | self.ADJOINT_TIMES def apply(self, x, mode): self._check_input(x, mode) @@ -61,7 +61,7 @@ class SumReductionOperator(LinearOperator): class ConjugationOperator(EndomorphicOperator): def __init__(self, domain): self._domain = domain - self._capability = self._all_ops + self._capability = self._all_ops def apply(self, x, mode): self._check_input(x, mode) @@ -71,7 +71,7 @@ class ConjugationOperator(EndomorphicOperator): class Realizer(EndomorphicOperator): def __init__(self, domain): self._domain = domain - self._capability = self.TIMES | self.ADJOINT_TIMES + self._capability = self.TIMES | self.ADJOINT_TIMES def apply(self, x, mode): self._check_input(x, mode) @@ -79,20 +79,17 @@ class Realizer(EndomorphicOperator): class FieldAdapter(LinearOperator): - def __init__(self, dom, name_dom): - self._domain = MultiDomain.make(dom) - self._name = name_dom - self._target = dom[name_dom] - self._capability = self.TIMES | self.ADJOINT_TIMES + def __init__(self, dom, name): + self._target = dom[name] + self._domain = MultiDomain.make({name: self._target}) + self._capability = self.TIMES | self.ADJOINT_TIMES def apply(self, x, mode): self._check_input(x, mode) if mode == self.TIMES: - return x[self._name] - values = tuple(Field(dom, 0.) if key != self._name else x - for key, dom in self._domain.items()) - return MultiField(self._domain, values) + return x.values()[0] + return MultiField(self._domain, (x,)) class GeometryRemover(LinearOperator): @@ -115,7 +112,7 @@ class GeometryRemover(LinearOperator): self._domain = DomainTuple.make(domain) target_list = [UnstructuredDomain(dom.shape) for dom in self._domain] self._target = DomainTuple.make(target_list) - self._capability = self.TIMES | self.ADJOINT_TIMES + self._capability = self.TIMES | self.ADJOINT_TIMES def apply(self, x, mode): self._check_input(x, mode) @@ -137,7 +134,7 @@ class NullOperator(LinearOperator): from ..sugar import makeDomain self._domain = makeDomain(domain) self._target = makeDomain(target) - self._capability = self.TIMES | self.ADJOINT_TIMES + self._capability = self.TIMES | self.ADJOINT_TIMES @staticmethod def _nullfield(dom): diff --git a/nifty5/operators/symmetrizing_operator.py b/nifty5/operators/symmetrizing_operator.py index b8f5370eed61199b093027e3baedb55acf2e8ee6..0681e3dffc83289e74a3ed996552568ed7497718 100644 --- a/nifty5/operators/symmetrizing_operator.py +++ b/nifty5/operators/symmetrizing_operator.py @@ -30,7 +30,7 @@ from .. import utilities class SymmetrizingOperator(EndomorphicOperator): def __init__(self, domain, space=0): self._domain = DomainTuple.make(domain) - self._capability = self.TIMES | self.ADJOINT_TIMES + self._capability = self.TIMES | self.ADJOINT_TIMES self._space = utilities.infer_space(self._domain, space) dom = self._domain[self._space] if not (isinstance(dom, LogRGSpace) and not dom.harmonic): diff --git a/nifty5/sugar.py b/nifty5/sugar.py index ed2bee0f3be7ab6132158890deef4b11ba278b20..fbe3c00353b4ea8f7018ed2d89ed428e0f7132e6 100644 --- a/nifty5/sugar.py +++ b/nifty5/sugar.py @@ -246,7 +246,7 @@ def makeOp(input): def domain_union(domains): if isinstance(domains[0], DomainTuple): - if any(dom is not domains[0] for dom in domains[1:]): + if any(dom != domains[0] for dom in domains[1:]): raise ValueError("domain mismatch") return domains[0] return MultiDomain.union(domains)