diff --git a/nifty5/library/correlated_fields.py b/nifty5/library/correlated_fields.py index 5ba81c5216ab0c261b607a052314307ebc10ee8b..2e20f2ad740943d87d0a6e935d2cc27eea0266fe 100644 --- a/nifty5/library/correlated_fields.py +++ b/nifty5/library/correlated_fields.py @@ -39,15 +39,18 @@ from ..sugar import from_global_data, from_random, full, makeDomain, get_default def _reshaper(x, N): x = np.asfarray(x) if x.shape in [(), (1,)]: - return np.full(N, x) if N != 1 else x.reshape(()) + return np.full(N, x) if N != 0 else x.reshape(()) elif x.shape == (N,): return x else: raise TypeError("Shape of parameters cannot be interpreted") -def _lognormal_moments(mean, sig, N = 1): - mean, sig = (_reshaper(param, N) for param in (mean, sig)) +def _lognormal_moments(mean, sig, N = 0): + if N == 0: + mean, sig = np.asfarray(mean), np.asfarray(sig) + else: + mean, sig = (_reshaper(param, N) for param in (mean, sig)) assert np.all(mean > 0 ) assert np.all(sig > 0) logsig = np.sqrt(np.log((sig/mean)**2 + 1)) @@ -55,12 +58,13 @@ def _lognormal_moments(mean, sig, N = 1): return logmean, logsig -def _normal(mean, sig, key, N = 1): - if N == 1: +def _normal(mean, sig, key, N = 0): + if N == 0: domain = DomainTuple.scalar_domain() + mean, sig = np.asfarray(mean), np.asfarray(sig) else: domain = UnstructuredDomain(N) - mean, sig = (_reshaper(param, N) for param in (mean, sig)) + mean, sig = (_reshaper(param, N) for param in (mean, sig)) return Adder(from_global_data(domain, mean)) @ ( DiagonalOperator(from_global_data(domain,sig)) @ ducktape(domain, None, key)) @@ -280,16 +284,17 @@ class _Amplitude(Operator): assert isinstance(asperity, Operator) assert isinstance(loglogavgslope, Operator) - N_copies = max(dofdex) + 1 - assert N_copies > 0 - if N_copies > 1: + if len(dofdex) > 0: + N_copies = max(dofdex) + 1 space = 1 distributed_tgt = makeDomain((UnstructuredDomain(len(dofdex)), target)) target = makeDomain((UnstructuredDomain(N_copies), target)) Distributor = _Distributor(dofdex, target, distributed_tgt, 0) else: + N_copies = 0 space = 0 - target = makeDomain(target) + distributed_tgt = target = makeDomain(target) + azm_expander = ContractionOperator(distributed_tgt, spaces = space).adjoint assert isinstance(target[space], PowerSpace) twolog = _TwoLogIntegrations(target, space) @@ -337,12 +342,12 @@ class _Amplitude(Operator): sigma = sig_flex*(Adder(shift) @ sig_asp).sqrt() smooth = _SlopeRemover(target, space) @ twolog @ (sigma*xi) op = _Normalization(target, space) @ (slope + smooth) - if space == 1: + if N_copies > 0: op = Distributor @ op sig_fluc = Distributor @ sig_fluc - op = (Distributor @ Adder(vol0)) @ (sig_fluc*(ps_expander @ azm.one_over())*op) + op = Adder(Distributor(vol0)) @ (sig_fluc*(azm_expander @ azm.one_over())*op) else: - op = (Adder(vol0)) @ (sig_fluc*(ps_expander @ azm.one_over())*op) + op = (Adder(vol0)) @ (sig_fluc*(azm_expander @ azm.one_over())*op) self.apply = op.apply self._fluc = fluctuations @@ -365,7 +370,7 @@ class CorrelatedFieldMaker: self._total_N = total_N @staticmethod - def make(offset_amplitude_mean, offset_amplitude_stddev, prefix, total_N = 1): + def make(offset_amplitude_mean, offset_amplitude_stddev, prefix, total_N = 0): offset_amplitude_stddev = float(offset_amplitude_stddev) offset_amplitude_mean = float(offset_amplitude_mean) assert offset_amplitude_stddev > 0 @@ -393,15 +398,15 @@ class CorrelatedFieldMaker: dofdex = np.full(self._total_N, 0) else: assert len(dofdex) == self._total_N - N = max(dofdex) - if self._total_N > 1: + if self._total_N > 0: space = 1 - position_space = makeDomain((UnstructuredDomain(self._total_N), position_space)) + N = max(dofdex) + 1 + position_space = makeDomain((UnstructuredDomain(N), position_space)) else: space = 0 + N = 0 position_space = makeDomain(position_space) - N = 1 power_space = PowerSpace(position_space[space].get_default_codomain()) prefix = str(prefix) #assert isinstance(position_space[space], (RGSpace, HPSpace, GLSpace) @@ -410,14 +415,6 @@ class CorrelatedFieldMaker: fluctuations_stddev, prefix + 'fluctuations', N) - - #if copies: - # fluct = fluct*self._azm.one_over() - #else: - # #print(fluct. - # co = ContractionOperator(self._azm.target, None).adjoint - # fluct = (co @ fluct)*self._azm.one_over() - flex = _LognormalMomentMatching(flexibility_mean, flexibility_stddev, prefix + 'flexibility', N) @@ -442,16 +439,19 @@ class CorrelatedFieldMaker: assert isinstance(zeromode, Operator) self._azm = zeromode n_amplitudes = len(self._a) - if self._total_N > 1: + if self._total_N > 0: hspace = makeDomain([UnstructuredDomain(self._total_N)] + [dd[-1].get_default_codomain() for dd in self._position_spaces]) - spaces = tuple(len(dd) for dd in self._position_spaces) - spaces = 1 + np.cumsum(spaces) + spaces = list(1 + np.arange(n_amplitudes)) + #spaces = tuple(len(dd) for dd in self._position_spaces) + #spaces = 1 + np.cumsum(spaces) + zeroind = (slice(None),) + (0,)*(len(hspace.shape)-1) else: hspace = makeDomain( [dd[-1].get_default_codomain() for dd in self._position_spaces]) spaces = tuple(range(n_amplitudes)) - zeroind = (slice(None),)*(1 - 1//self._total_N) + (0,)*(len(hspace.shape)-1+1//self._total_N) + spaces = list(np.arange(n_amplitudes)) + zeroind = (0,)*len(hspace.shape) foo = np.ones(hspace.shape) foo[zeroind] = 0 @@ -460,7 +460,7 @@ class CorrelatedFieldMaker: self._azm.target, zeroind).adjoint azm = Adder(from_global_data(hspace, foo)) @ ZeroModeInserter @ zeromode - spaces = np.array(range(n_amplitudes)) + 1 - 1//self._total_N + #spaces = np.array(range(n_amplitudes)) + 1 - 1//self._total_N ht = HarmonicTransformOperator(hspace, self._position_spaces[0][self._spaces[0]], space=spaces[0]) @@ -475,12 +475,10 @@ class CorrelatedFieldMaker: self._a[i].target[self._spaces[i]], space=spaces[i])) - #breakpoint() - all_spaces = list(range(len(hspace))) a = ContractionOperator(pd.domain, spaces[1:]).adjoint @ self._a[0] for i in range(1, n_amplitudes): co = ContractionOperator(pd.domain, - all_spaces[:spaces[i]] + all_spaces[spaces[i] + 1:]) + spaces[:i] + spaces[i+1:]) a = a*(co.adjoint @ self._a[i]) return ht(azm*(pd @ a)*ducktape(hspace, None, prefix + 'xi'))