diff --git a/nifty5/field.py b/nifty5/field.py index 377e74e7a9be8f8d04bfeae43d03f44a4272f385..1516458f111aa652f69fc3907c4c1d03c2c333cc 100644 --- a/nifty5/field.py +++ b/nifty5/field.py @@ -49,6 +49,8 @@ class Field(object): def __init__(self, domain, val): if not isinstance(domain, DomainTuple): raise TypeError("domain must be of type DomainTuple") + if np.isscalar(val): + val = np.full((), val) if not isinstance(val, dobj.data_object): raise TypeError("val must be of type dobj.data_object") if domain.shape != val.shape: diff --git a/nifty5/operator.py b/nifty5/operator.py index d47d6df1c449929cfb9d7ed08725b2ef88f9f956..c3ff955c3650da5c35c20c698670a524bbea307e 100644 --- a/nifty5/operator.py +++ b/nifty5/operator.py @@ -11,9 +11,10 @@ from .multi.multi_field import MultiField class Linearization(object): - def __init__(self, val, jac): + def __init__(self, val, jac, metric=None): self._val = val self._jac = jac + self._metric = metric @property def domain(self): @@ -31,17 +32,25 @@ class Linearization(object): def jac(self): return self._jac + @property + def metric(self): + return self._metric + def __neg__(self): - return Linearization(-self._val, self._jac*(-1)) + return Linearization(-self._val, self._jac*(-1), + None if self._metric is None else self._metric*(-1)) def __add__(self, other): if isinstance(other, Linearization): from .operators.relaxed_sum_operator import RelaxedSumOperator + met = None + if self._metric is not None and other._metric is not None: + met = RelaxedSumOperator((self._metric, other._metric)) return Linearization( self._val.unite(other._val), - RelaxedSumOperator((self._jac, other._jac))) + RelaxedSumOperator((self._jac, other._jac)), met) if isinstance(other, (int, float, complex, Field, MultiField)): - return Linearization(self._val+other, self._jac) + return Linearization(self._val+other, self._jac, self._metric) def __radd__(self, other): return self.__add__(other) @@ -76,6 +85,11 @@ class Linearization(object): d1 = makeOp(other) return Linearization(self._val*other, d1*self._jac) + def sum(self): + from .sugar import full + from .operators.vdot_operator import VdotOperator + return Linearization(full((),self._val.sum()), + VdotOperator(full(self._jac.target,1))*self._jac) @staticmethod def make_var(field): from .operators.scaling_operator import ScalingOperator diff --git a/nifty5/operators/relaxed_sum_operator.py b/nifty5/operators/relaxed_sum_operator.py index b44bb77b23320a223bc81fb71ad12da7eeaabc57..aeac60ae77742c31664052abc577df4d355da4c9 100644 --- a/nifty5/operators/relaxed_sum_operator.py +++ b/nifty5/operators/relaxed_sum_operator.py @@ -60,6 +60,6 @@ class RelaxedSumOperator(LinearOperator): self._check_mode(mode) res = None for op in self._ops: - x = op.apply(x.extract(op._dom(mode)), mode) - res = x if res is None else res.unite(x) + tmp = op.apply(x.extract(op._dom(mode)), mode) + res = tmp if res is None else res.unite(tmp) return res