From 87e34527c7ea2b48e0b68530dad3abb63a04ae92 Mon Sep 17 00:00:00 2001 From: Philipp Haim Date: Tue, 12 Nov 2019 21:11:38 +0100 Subject: [PATCH] Normal and log-normal operators get correct domain --- nifty5/library/correlated_fields.py | 39 ++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/nifty5/library/correlated_fields.py b/nifty5/library/correlated_fields.py index 2f3272cc..3410f456 100644 --- a/nifty5/library/correlated_fields.py +++ b/nifty5/library/correlated_fields.py @@ -38,10 +38,9 @@ from ..operators.simple_linear_operators import VdotOperator, ducktape from ..operators.value_inserter import ValueInserter from ..sugar import from_global_data, from_random, full, makeDomain, get_default_codomain -def _reshaper(domain, x, space): - shape = reduce(lambda x,y: x+y, - (domain[i].shape for i in range(len(domain)) if i != space),()) +def _reshaper(x, shape): x = np.array(x) + if x.shape == shape: return np.asfarray(x) elif x.shape in [(), (1,)]: @@ -52,7 +51,7 @@ def _reshaper(domain, x, space): def _lognormal_moment_matching(mean, sig, key, domain = DomainTuple.scalar_domain(), space = 0): domain = makeDomain(domain) - mean, sig = (_reshaper(domain, param, space) for param in (mean, sig)) + mean, sig = (_reshaper(param, domain.shape) for param in (mean, sig)) key = str(key) assert np.all(mean > 0) assert np.all(sig > 0) @@ -64,7 +63,7 @@ def _lognormal_moment_matching(mean, sig, key, def _normal(mean, sig, key, domain = DomainTuple.scalar_domain(), space = 0): domain = makeDomain(domain) - mean, sig = (_reshaper(domain, param, space) for param in (mean, sig)) + mean, sig = (_reshaper(param, domain.shape) for param in (mean, sig)) assert np.all(sig > 0) return Adder(from_global_data(domain, mean)) @ ( sig*ducktape(domain, None, key)) @@ -289,14 +288,22 @@ class CorrelatedFieldMaker: loglogavgslope_stddev, prefix, space = 0): prefix = str(prefix) + parameter_domain = list(makeDomain(target)) + del parameter_domain[space] + if parameter_domain != []: + parameter_domain = makeDomain(parameter_domain) + else: + parameter_domain = DomainTuple.scalar_domain() + fluct = _lognormal_moment_matching(fluctuations_mean, fluctuations_stddev, - prefix + 'fluctuations', space = space) + prefix + 'fluctuations', parameter_domain, space = space) flex = _lognormal_moment_matching(flexibility_mean, flexibility_stddev, - prefix + 'flexibility', space = space) + prefix + 'flexibility', parameter_domain, space = space) asp = _lognormal_moment_matching(asperity_mean, asperity_stddev, - prefix + 'asperity', space = space) + prefix + 'asperity', parameter_domain, space = space) avgsl = _normal(loglogavgslope_mean, loglogavgslope_stddev, - prefix + 'loglogavgslope', space = space) + prefix + 'loglogavgslope', parameter_domain, space = space) + self.add_fluctuations_from_ops(target, fluct, flex, asp, avgsl, prefix + 'spectrum', space) @@ -313,8 +320,7 @@ class CorrelatedFieldMaker: """ prefix = str(prefix) if offset is not None: - offset = float(offset) - + offset = float(offset) hspace = [] zeroind = () for amp, space in zip(self._amplitudes, self._spaces): @@ -325,9 +331,18 @@ class CorrelatedFieldMaker: hspace = makeDomain(hspace) spaces = np.cumsum(self._spaces) + np.arange(len(self._spaces)) + parameter_domain = list(makeDomain(hspace)) + for space in self._spaces: + del parameter_domain[space] + if parameter_domain != []: + parameter_domain = makeDomain(parameter_domain) + else: + parameter_domain = DomainTuple.scalar_domain() + azm = _lognormal_moment_matching(offset_amplitude_mean, offset_amplitude_stddev, - prefix + 'zeromode', space = space) + prefix + 'zeromode', parameter_domain, + space = tuple(self._spaces)) foo = np.ones(hspace.shape) foo[zeroind] = 0 -- GitLab