diff --git a/nifty6/library/correlated_fields.py b/nifty6/library/correlated_fields.py index ff538dd104e0e9d30e4e9d286f22877f40bd5aa8..6e3245b9859f9273c6ecf252d9a23a35708b7c0d 100644 --- a/nifty6/library/correlated_fields.py +++ b/nifty6/library/correlated_fields.py @@ -21,19 +21,19 @@ import numpy as np from ..domain_tuple import DomainTuple from ..domains.power_space import PowerSpace from ..domains.unstructured_domain import UnstructuredDomain +from ..field import Field +from ..multi_field import MultiField from ..operators.adder import Adder from ..operators.contraction_operator import ContractionOperator +from ..operators.diagonal_operator import DiagonalOperator from ..operators.distributors import PowerDistributor from ..operators.endomorphic_operator import EndomorphicOperator from ..operators.harmonic_operators import HarmonicTransformOperator from ..operators.linear_operator import LinearOperator -from ..operators.diagonal_operator import DiagonalOperator from ..operators.operator import Operator -from ..operators.simple_linear_operators import VdotOperator, ducktape +from ..operators.simple_linear_operators import ducktape from ..probing import StatCalculator from ..sugar import from_global_data, full, makeDomain -from ..field import Field -from ..multi_field import MultiField def _reshaper(x, N): @@ -68,9 +68,8 @@ def _normal(mean, sig, key, N=0): else: domain = UnstructuredDomain(N) mean, sig = (_reshaper(param, N) for param in (mean, sig)) - return Adder(from_global_data(domain, mean)) @ ( - DiagonalOperator(from_global_data(domain, sig)) - @ ducktape(domain, None, key)) + return Adder(from_global_data(domain, mean)) @ (DiagonalOperator( + from_global_data(domain, sig)) @ ducktape(domain, None, key)) def _log_k_lengths(pspace): @@ -144,7 +143,7 @@ class _SlopeRemover(EndomorphicOperator): else: res = x.copy() res[self._last] -= (x*self._sc[self._extender]).sum( - axis=self._space, keepdims=True) + axis=self._space, keepdims=True) return from_global_data(self._tgt(mode), res) @@ -197,11 +196,14 @@ class _Normalization(Operator): hspace = list(self._domain) hspace[space] = hspace[space].harmonic_partner hspace = makeDomain(hspace) - pd = PowerDistributor(hspace, power_space=self._domain[space], space=space) + pd = PowerDistributor(hspace, + power_space=self._domain[space], + space=space) mode_multiplicity = pd.adjoint(full(pd.target, 1.)).to_global_data_rw() zero_mode = (slice(None),)*self._domain.axes[space][0] + (0,) mode_multiplicity[zero_mode] = 0 - self._mode_multiplicity = from_global_data(self._domain, mode_multiplicity) + self._mode_multiplicity = from_global_data(self._domain, + mode_multiplicity) self._specsum = _SpecialSum(self._domain, space) def apply(self, x): @@ -320,8 +322,8 @@ class _Amplitude(Operator): op = Distributor @ op sig_fluc = Distributor @ sig_fluc op = Adder(Distributor(vol0)) @ (sig_fluc*(azm_expander @ azm.one_over())*op) - self._fluc = (_Distributor(dofdex, fluctuations.target, distributed_tgt[0]) @ - fluctuations) + self._fluc = (_Distributor(dofdex, fluctuations.target, + distributed_tgt[0]) @ fluctuations) else: op = Adder(vol0) @ (sig_fluc*(azm_expander @ azm.one_over())*op) self._fluc = fluctuations @@ -426,14 +428,14 @@ class CorrelatedFieldMaker: def _finalize_from_op(self): n_amplitudes = len(self._a) if self._total_N > 0: - hspace = makeDomain([UnstructuredDomain(self._total_N)] + - [dd.target[-1].harmonic_partner - for dd in self._a]) + hspace = makeDomain( + [UnstructuredDomain(self._total_N)] + + [dd.target[-1].harmonic_partner for dd in self._a]) spaces = tuple(range(1, n_amplitudes + 1)) amp_space = 1 else: hspace = makeDomain( - [dd.target[0].harmonic_partner for dd in self._a]) + [dd.target[0].harmonic_partner for dd in self._a]) spaces = tuple(range(n_amplitudes)) amp_space = 0 @@ -450,14 +452,12 @@ class CorrelatedFieldMaker: pd = PowerDistributor(hspace, self._a[0].target[amp_space], amp_space) for i in range(1, n_amplitudes): - pd = (pd @ PowerDistributor(pd.domain, - self._a[i].target[amp_space], - space=spaces[i])) + pd = (pd @ PowerDistributor( + pd.domain, self._a[i].target[amp_space], space=spaces[i])) a = ContractionOperator(pd.domain, spaces[1:]).adjoint @ self._a[0] for i in range(1, n_amplitudes): - co = ContractionOperator(pd.domain, - spaces[:i] + spaces[i+1:]) + co = ContractionOperator(pd.domain, spaces[:i] + spaces[i + 1:]) a = a*(co.adjoint @ self._a[i]) return ht(azm*(pd @ a)*ducktape(hspace, None, self._prefix + 'xi'))