diff --git a/nifty4/multi/multi_field.py b/nifty4/multi/multi_field.py index d76223dac27a5f806d49e83d398061048853f265..3bbfd02a86c071fc933cdcb01b538002652c5efc 100644 --- a/nifty4/multi/multi_field.py +++ b/nifty4/multi/multi_field.py @@ -32,6 +32,13 @@ class MultiField(object): def dtype(self): return {key: val.dtype for key, val in self._val.items()} + @staticmethod + def from_random(random_type, domain, dtype=np.float64, **kwargs): + dtype = self.build_dtype(dtype) + return MultiField({key: Field.from_random(random_type, domain[key], + dtype[key], **kwargs) + for key in domain.keys}) + def _check_domain(self, other): if other.domain != self.domain: raise ValueError("domains are incompatible.") diff --git a/nifty4/operators/scaling_operator.py b/nifty4/operators/scaling_operator.py index b339b934051c68d4a3a15b16f4bb7e22311b2cb9..b9a26bfaf32c569e25f88a0e6fb1e41d8590808d 100644 --- a/nifty4/operators/scaling_operator.py +++ b/nifty4/operators/scaling_operator.py @@ -19,6 +19,7 @@ from __future__ import division import numpy as np from ..field import Field +from ..multi.multi_field import MultiField from ..domain_tuple import DomainTuple from .endomorphic_operator import EndomorphicOperator @@ -61,7 +62,7 @@ class ScalingOperator(EndomorphicOperator): if self._factor == 1.: return x.copy() if self._factor == 0.: - return Field.zeros_like(x) + return x.zeros_like(x) if mode == self.TIMES: return x*self._factor @@ -103,5 +104,6 @@ class ScalingOperator(EndomorphicOperator): if fct.real == 0. and from_inverse: raise ValueError("operator not positive definite") fct = 1./np.sqrt(fct) if from_inverse else np.sqrt(fct) - return Field.from_random( + cls = Field if isinstance(self._domain, DomainTuple) else MultiField + return cls.from_random( random_type="normal", domain=self._domain, std=fct, dtype=dtype)