Commit 3c104777 authored by Philipp Arras's avatar Philipp Arras

Cosmetics

parent fc6e304a
Pipeline #64967 passed with stages
in 9 minutes and 49 seconds
...@@ -21,19 +21,19 @@ import numpy as np ...@@ -21,19 +21,19 @@ import numpy as np
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
from ..domains.power_space import PowerSpace from ..domains.power_space import PowerSpace
from ..domains.unstructured_domain import UnstructuredDomain from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..multi_field import MultiField
from ..operators.adder import Adder from ..operators.adder import Adder
from ..operators.contraction_operator import ContractionOperator from ..operators.contraction_operator import ContractionOperator
from ..operators.diagonal_operator import DiagonalOperator
from ..operators.distributors import PowerDistributor from ..operators.distributors import PowerDistributor
from ..operators.endomorphic_operator import EndomorphicOperator from ..operators.endomorphic_operator import EndomorphicOperator
from ..operators.harmonic_operators import HarmonicTransformOperator from ..operators.harmonic_operators import HarmonicTransformOperator
from ..operators.linear_operator import LinearOperator from ..operators.linear_operator import LinearOperator
from ..operators.diagonal_operator import DiagonalOperator
from ..operators.operator import Operator from ..operators.operator import Operator
from ..operators.simple_linear_operators import VdotOperator, ducktape from ..operators.simple_linear_operators import ducktape
from ..probing import StatCalculator from ..probing import StatCalculator
from ..sugar import from_global_data, full, makeDomain from ..sugar import from_global_data, full, makeDomain
from ..field import Field
from ..multi_field import MultiField
def _reshaper(x, N): def _reshaper(x, N):
...@@ -68,9 +68,8 @@ def _normal(mean, sig, key, N=0): ...@@ -68,9 +68,8 @@ def _normal(mean, sig, key, N=0):
else: else:
domain = UnstructuredDomain(N) domain = UnstructuredDomain(N)
mean, sig = (_reshaper(param, N) for param in (mean, sig)) mean, sig = (_reshaper(param, N) for param in (mean, sig))
return Adder(from_global_data(domain, mean)) @ ( return Adder(from_global_data(domain, mean)) @ (DiagonalOperator(
DiagonalOperator(from_global_data(domain, sig)) from_global_data(domain, sig)) @ ducktape(domain, None, key))
@ ducktape(domain, None, key))
def _log_k_lengths(pspace): def _log_k_lengths(pspace):
...@@ -144,7 +143,7 @@ class _SlopeRemover(EndomorphicOperator): ...@@ -144,7 +143,7 @@ class _SlopeRemover(EndomorphicOperator):
else: else:
res = x.copy() res = x.copy()
res[self._last] -= (x*self._sc[self._extender]).sum( res[self._last] -= (x*self._sc[self._extender]).sum(
axis=self._space, keepdims=True) axis=self._space, keepdims=True)
return from_global_data(self._tgt(mode), res) return from_global_data(self._tgt(mode), res)
...@@ -197,11 +196,14 @@ class _Normalization(Operator): ...@@ -197,11 +196,14 @@ class _Normalization(Operator):
hspace = list(self._domain) hspace = list(self._domain)
hspace[space] = hspace[space].harmonic_partner hspace[space] = hspace[space].harmonic_partner
hspace = makeDomain(hspace) hspace = makeDomain(hspace)
pd = PowerDistributor(hspace, power_space=self._domain[space], space=space) pd = PowerDistributor(hspace,
power_space=self._domain[space],
space=space)
mode_multiplicity = pd.adjoint(full(pd.target, 1.)).to_global_data_rw() mode_multiplicity = pd.adjoint(full(pd.target, 1.)).to_global_data_rw()
zero_mode = (slice(None),)*self._domain.axes[space][0] + (0,) zero_mode = (slice(None),)*self._domain.axes[space][0] + (0,)
mode_multiplicity[zero_mode] = 0 mode_multiplicity[zero_mode] = 0
self._mode_multiplicity = from_global_data(self._domain, mode_multiplicity) self._mode_multiplicity = from_global_data(self._domain,
mode_multiplicity)
self._specsum = _SpecialSum(self._domain, space) self._specsum = _SpecialSum(self._domain, space)
def apply(self, x): def apply(self, x):
...@@ -320,8 +322,8 @@ class _Amplitude(Operator): ...@@ -320,8 +322,8 @@ class _Amplitude(Operator):
op = Distributor @ op op = Distributor @ op
sig_fluc = Distributor @ sig_fluc sig_fluc = Distributor @ sig_fluc
op = Adder(Distributor(vol0)) @ (sig_fluc*(azm_expander @ azm.one_over())*op) op = Adder(Distributor(vol0)) @ (sig_fluc*(azm_expander @ azm.one_over())*op)
self._fluc = (_Distributor(dofdex, fluctuations.target, distributed_tgt[0]) @ self._fluc = (_Distributor(dofdex, fluctuations.target,
fluctuations) distributed_tgt[0]) @ fluctuations)
else: else:
op = Adder(vol0) @ (sig_fluc*(azm_expander @ azm.one_over())*op) op = Adder(vol0) @ (sig_fluc*(azm_expander @ azm.one_over())*op)
self._fluc = fluctuations self._fluc = fluctuations
...@@ -426,14 +428,14 @@ class CorrelatedFieldMaker: ...@@ -426,14 +428,14 @@ class CorrelatedFieldMaker:
def _finalize_from_op(self): def _finalize_from_op(self):
n_amplitudes = len(self._a) n_amplitudes = len(self._a)
if self._total_N > 0: if self._total_N > 0:
hspace = makeDomain([UnstructuredDomain(self._total_N)] + hspace = makeDomain(
[dd.target[-1].harmonic_partner [UnstructuredDomain(self._total_N)] +
for dd in self._a]) [dd.target[-1].harmonic_partner for dd in self._a])
spaces = tuple(range(1, n_amplitudes + 1)) spaces = tuple(range(1, n_amplitudes + 1))
amp_space = 1 amp_space = 1
else: else:
hspace = makeDomain( hspace = makeDomain(
[dd.target[0].harmonic_partner for dd in self._a]) [dd.target[0].harmonic_partner for dd in self._a])
spaces = tuple(range(n_amplitudes)) spaces = tuple(range(n_amplitudes))
amp_space = 0 amp_space = 0
...@@ -450,14 +452,12 @@ class CorrelatedFieldMaker: ...@@ -450,14 +452,12 @@ class CorrelatedFieldMaker:
pd = PowerDistributor(hspace, self._a[0].target[amp_space], amp_space) pd = PowerDistributor(hspace, self._a[0].target[amp_space], amp_space)
for i in range(1, n_amplitudes): for i in range(1, n_amplitudes):
pd = (pd @ PowerDistributor(pd.domain, pd = (pd @ PowerDistributor(
self._a[i].target[amp_space], pd.domain, self._a[i].target[amp_space], space=spaces[i]))
space=spaces[i]))
a = ContractionOperator(pd.domain, spaces[1:]).adjoint @ self._a[0] a = ContractionOperator(pd.domain, spaces[1:]).adjoint @ self._a[0]
for i in range(1, n_amplitudes): for i in range(1, n_amplitudes):
co = ContractionOperator(pd.domain, co = ContractionOperator(pd.domain, spaces[:i] + spaces[i + 1:])
spaces[:i] + spaces[i+1:])
a = a*(co.adjoint @ self._a[i]) a = a*(co.adjoint @ self._a[i])
return ht(azm*(pd @ a)*ducktape(hspace, None, self._prefix + 'xi')) return ht(azm*(pd @ a)*ducktape(hspace, None, self._prefix + 'xi'))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment