Commit d166c586 authored by Martin Reinecke's avatar Martin Reinecke

merge

parents 5b3640bb 4ca56a07
Pipeline #64971 failed with stages
in 8 minutes and 35 seconds
...@@ -21,15 +21,17 @@ import numpy as np ...@@ -21,15 +21,17 @@ import numpy as np
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
from ..domains.power_space import PowerSpace from ..domains.power_space import PowerSpace
from ..domains.unstructured_domain import UnstructuredDomain from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..multi_field import MultiField
from ..operators.adder import Adder from ..operators.adder import Adder
from ..operators.contraction_operator import ContractionOperator from ..operators.contraction_operator import ContractionOperator
from ..operators.diagonal_operator import DiagonalOperator
from ..operators.distributors import PowerDistributor from ..operators.distributors import PowerDistributor
from ..operators.endomorphic_operator import EndomorphicOperator from ..operators.endomorphic_operator import EndomorphicOperator
from ..operators.harmonic_operators import HarmonicTransformOperator from ..operators.harmonic_operators import HarmonicTransformOperator
from ..operators.linear_operator import LinearOperator from ..operators.linear_operator import LinearOperator
from ..operators.diagonal_operator import DiagonalOperator
from ..operators.operator import Operator from ..operators.operator import Operator
from ..operators.simple_linear_operators import VdotOperator, ducktape from ..operators.simple_linear_operators import ducktape
from ..probing import StatCalculator from ..probing import StatCalculator
from ..sugar import from_global_data, full, makeDomain from ..sugar import from_global_data, full, makeDomain
...@@ -49,8 +51,11 @@ def _lognormal_moments(mean, sig, N=0): ...@@ -49,8 +51,11 @@ def _lognormal_moments(mean, sig, N=0):
mean, sig = np.asfarray(mean), np.asfarray(sig) mean, sig = np.asfarray(mean), np.asfarray(sig)
else: else:
mean, sig = (_reshaper(param, N) for param in (mean, sig)) mean, sig = (_reshaper(param, N) for param in (mean, sig))
assert np.all(mean > 0) if not np.all(mean > 0):
assert np.all(sig > 0) raise ValueError(f"mean must be greater 0; got {mean!r}")
if not np.all(sig > 0):
raise ValueError(f"sig must be greater 0; got {sig!r}")
logsig = np.sqrt(np.log((sig/mean)**2 + 1)) logsig = np.sqrt(np.log((sig/mean)**2 + 1))
logmean = np.log(mean) - logsig**2/2 logmean = np.log(mean) - logsig**2/2
return logmean, logsig return logmean, logsig
...@@ -63,9 +68,8 @@ def _normal(mean, sig, key, N=0): ...@@ -63,9 +68,8 @@ def _normal(mean, sig, key, N=0):
else: else:
domain = UnstructuredDomain(N) domain = UnstructuredDomain(N)
mean, sig = (_reshaper(param, N) for param in (mean, sig)) mean, sig = (_reshaper(param, N) for param in (mean, sig))
return Adder(from_global_data(domain, mean)) @ ( return Adder(from_global_data(domain, mean)) @ (DiagonalOperator(
DiagonalOperator(from_global_data(domain, sig)) from_global_data(domain, sig)) @ ducktape(domain, None, key))
@ ducktape(domain, None, key))
def _log_k_lengths(pspace): def _log_k_lengths(pspace):
...@@ -99,13 +103,6 @@ def _total_fluctuation_realized(samples): ...@@ -99,13 +103,6 @@ def _total_fluctuation_realized(samples):
return np.sqrt((res/len(samples)).mean()) return np.sqrt((res/len(samples)).mean())
def _stats(op, samples):
sc = StatCalculator()
for s in samples:
sc.add(op(s.extract(op.domain)))
return sc.mean.val, sc.var.sqrt().val
class _LognormalMomentMatching(Operator): class _LognormalMomentMatching(Operator):
def __init__(self, mean, sig, key, N_copies): def __init__(self, mean, sig, key, N_copies):
key = str(key) key = str(key)
...@@ -146,7 +143,7 @@ class _SlopeRemover(EndomorphicOperator): ...@@ -146,7 +143,7 @@ class _SlopeRemover(EndomorphicOperator):
else: else:
res = x.copy() res = x.copy()
res[self._last] -= (x*self._sc[self._extender]).sum( res[self._last] -= (x*self._sc[self._extender]).sum(
axis=self._space, keepdims=True) axis=self._space, keepdims=True)
return from_global_data(self._tgt(mode), res) return from_global_data(self._tgt(mode), res)
...@@ -199,11 +196,14 @@ class _Normalization(Operator): ...@@ -199,11 +196,14 @@ class _Normalization(Operator):
hspace = list(self._domain) hspace = list(self._domain)
hspace[space] = hspace[space].harmonic_partner hspace[space] = hspace[space].harmonic_partner
hspace = makeDomain(hspace) hspace = makeDomain(hspace)
pd = PowerDistributor(hspace, power_space=self._domain[space], space=space) pd = PowerDistributor(hspace,
power_space=self._domain[space],
space=space)
mode_multiplicity = pd.adjoint(full(pd.target, 1.)).val.copy() mode_multiplicity = pd.adjoint(full(pd.target, 1.)).val.copy()
zero_mode = (slice(None),)*self._domain.axes[space][0] + (0,) zero_mode = (slice(None),)*self._domain.axes[space][0] + (0,)
mode_multiplicity[zero_mode] = 0 mode_multiplicity[zero_mode] = 0
self._mode_multiplicity = from_global_data(self._domain, mode_multiplicity) self._mode_multiplicity = from_global_data(self._domain,
mode_multiplicity)
self._specsum = _SpecialSum(self._domain, space) self._specsum = _SpecialSum(self._domain, space)
def apply(self, x): def apply(self, x):
...@@ -322,8 +322,8 @@ class _Amplitude(Operator): ...@@ -322,8 +322,8 @@ class _Amplitude(Operator):
op = Distributor @ op op = Distributor @ op
sig_fluc = Distributor @ sig_fluc sig_fluc = Distributor @ sig_fluc
op = Adder(Distributor(vol0)) @ (sig_fluc*(azm_expander @ azm.one_over())*op) op = Adder(Distributor(vol0)) @ (sig_fluc*(azm_expander @ azm.one_over())*op)
self._fluc = (_Distributor(dofdex, fluctuations.target, distributed_tgt[0]) @ self._fluc = (_Distributor(dofdex, fluctuations.target,
fluctuations) distributed_tgt[0]) @ fluctuations)
else: else:
op = Adder(vol0) @ (sig_fluc*(azm_expander @ azm.one_over())*op) op = Adder(vol0) @ (sig_fluc*(azm_expander @ azm.one_over())*op)
self._fluc = fluctuations self._fluc = fluctuations
...@@ -339,7 +339,8 @@ class _Amplitude(Operator): ...@@ -339,7 +339,8 @@ class _Amplitude(Operator):
class CorrelatedFieldMaker: class CorrelatedFieldMaker:
def __init__(self, amplitude_offset, prefix, total_N): def __init__(self, amplitude_offset, prefix, total_N):
assert isinstance(amplitude_offset, Operator) if not isinstance(amplitude_offset, Operator):
raise TypeError("amplitude_offset needs to be an operator")
self._a = [] self._a = []
self._position_spaces = [] self._position_spaces = []
...@@ -353,8 +354,8 @@ class CorrelatedFieldMaker: ...@@ -353,8 +354,8 @@ class CorrelatedFieldMaker:
dofdex=None): dofdex=None):
if dofdex is None: if dofdex is None:
dofdex = np.full(total_N, 0) dofdex = np.full(total_N, 0)
else: elif len(dofdex) != total_N:
assert len(dofdex) == total_N raise ValueError("length of dofdex needs to match total_N")
N = max(dofdex) + 1 if total_N > 0 else 0 N = max(dofdex) + 1 if total_N > 0 else 0
zm = _LognormalMomentMatching(offset_amplitude_mean, zm = _LognormalMomentMatching(offset_amplitude_mean,
offset_amplitude_stddev, offset_amplitude_stddev,
...@@ -386,8 +387,8 @@ class CorrelatedFieldMaker: ...@@ -386,8 +387,8 @@ class CorrelatedFieldMaker:
if dofdex is None: if dofdex is None:
dofdex = np.full(self._total_N, 0) dofdex = np.full(self._total_N, 0)
else: elif len(dofdex) != self._total_N:
assert len(dofdex) == self._total_N raise ValueError("length of dofdex needs to match total_N")
if self._total_N > 0: if self._total_N > 0:
space = 1 space = 1
...@@ -427,14 +428,14 @@ class CorrelatedFieldMaker: ...@@ -427,14 +428,14 @@ class CorrelatedFieldMaker:
def _finalize_from_op(self): def _finalize_from_op(self):
n_amplitudes = len(self._a) n_amplitudes = len(self._a)
if self._total_N > 0: if self._total_N > 0:
hspace = makeDomain([UnstructuredDomain(self._total_N)] + hspace = makeDomain(
[dd.target[-1].harmonic_partner [UnstructuredDomain(self._total_N)] +
for dd in self._a]) [dd.target[-1].harmonic_partner for dd in self._a])
spaces = tuple(range(1, n_amplitudes + 1)) spaces = tuple(range(1, n_amplitudes + 1))
amp_space = 1 amp_space = 1
else: else:
hspace = makeDomain( hspace = makeDomain(
[dd.target[0].harmonic_partner for dd in self._a]) [dd.target[0].harmonic_partner for dd in self._a])
spaces = tuple(range(n_amplitudes)) spaces = tuple(range(n_amplitudes))
amp_space = 0 amp_space = 0
...@@ -451,14 +452,12 @@ class CorrelatedFieldMaker: ...@@ -451,14 +452,12 @@ class CorrelatedFieldMaker:
pd = PowerDistributor(hspace, self._a[0].target[amp_space], amp_space) pd = PowerDistributor(hspace, self._a[0].target[amp_space], amp_space)
for i in range(1, n_amplitudes): for i in range(1, n_amplitudes):
pd = (pd @ PowerDistributor(pd.domain, pd = (pd @ PowerDistributor(
self._a[i].target[amp_space], pd.domain, self._a[i].target[amp_space], space=spaces[i]))
space=spaces[i]))
a = ContractionOperator(pd.domain, spaces[1:]).adjoint @ self._a[0] a = ContractionOperator(pd.domain, spaces[1:]).adjoint @ self._a[0]
for i in range(1, n_amplitudes): for i in range(1, n_amplitudes):
co = ContractionOperator(pd.domain, co = ContractionOperator(pd.domain, spaces[:i] + spaces[i + 1:])
spaces[:i] + spaces[i+1:])
a = a*(co.adjoint @ self._a[i]) a = a*(co.adjoint @ self._a[i])
return ht(azm*(pd @ a)*ducktape(hspace, None, self._prefix + 'xi')) return ht(azm*(pd @ a)*ducktape(hspace, None, self._prefix + 'xi'))
...@@ -467,23 +466,26 @@ class CorrelatedFieldMaker: ...@@ -467,23 +466,26 @@ class CorrelatedFieldMaker:
""" """
offset vs zeromode: volume factor offset vs zeromode: volume factor
""" """
if offset is not None:
raise NotImplementedError
offset = float(offset)
op = self._finalize_from_op() op = self._finalize_from_op()
if prior_info > 0: if offset is not None:
from ..sugar import from_random # Deviations from this offset must not be considered here as they
samps = [ # are learned by the zeromode
from_random('normal', op.domain) for _ in range(prior_info) if isinstance(offset, (Field, MultiField)):
] op = Adder(offset) @ op
self.statistics_summary(samps) else:
offset = float(offset)
op = Adder(full(op.target, offset)) @ op
self.statistics_summary(prior_info)
return op return op
def statistics_summary(self, samples): def statistics_summary(self, prior_info):
from ..sugar import from_random
if prior_info == 0:
return
lst = [('Offset amplitude', self.amplitude_total_offset), lst = [('Offset amplitude', self.amplitude_total_offset),
('Total fluctuation amplitude', self.total_fluctuation)] ('Total fluctuation amplitude', self.total_fluctuation)]
namps = len(self._a) namps = len(self._a)
if namps > 1: if namps > 1:
for ii in range(namps): for ii in range(namps):
...@@ -493,13 +495,19 @@ class CorrelatedFieldMaker: ...@@ -493,13 +495,19 @@ class CorrelatedFieldMaker:
self.average_fluctuation(ii))) self.average_fluctuation(ii)))
for kk, op in lst: for kk, op in lst:
mean, stddev = _stats(op, samples) sc = StatCalculator()
for _ in range(prior_info):
sc.add(op(from_random('normal', op.domain)))
mean = sc.mean.val
stddev = sc.var.sqrt().val
for m, s in zip(mean.flatten(), stddev.flatten()): for m, s in zip(mean.flatten(), stddev.flatten()):
print('{}: {:.02E} ± {:.02E}'.format(kk, m, s)) print('{}: {:.02E} ± {:.02E}'.format(kk, m, s))
def moment_slice_to_average(self, fluctuations_slice_mean, nsamples=1000): def moment_slice_to_average(self, fluctuations_slice_mean, nsamples=1000):
fluctuations_slice_mean = float(fluctuations_slice_mean) fluctuations_slice_mean = float(fluctuations_slice_mean)
assert fluctuations_slice_mean > 0 if not fluctuations_slice_mean > 0:
msg = "fluctuations_slice_mean must be greater zero; got {!r}"
raise ValueError(msg.format(fluctuations_slice_mean))
from ..sugar import from_random from ..sugar import from_random
scm = 1. scm = 1.
for a in self._a: for a in self._a:
...@@ -545,7 +553,8 @@ class CorrelatedFieldMaker: ...@@ -545,7 +553,8 @@ class CorrelatedFieldMaker:
"""Returns operator which acts on prior or posterior samples""" """Returns operator which acts on prior or posterior samples"""
if len(self._a) == 0: if len(self._a) == 0:
raise NotImplementedError raise NotImplementedError
assert space < len(self._a) if space >= len(self._a):
raise ValueError(f"invalid space specified; got {space!r}")
if len(self._a) == 1: if len(self._a) == 1:
return self.average_fluctuation(0) return self.average_fluctuation(0)
q = 1. q = 1.
...@@ -561,7 +570,8 @@ class CorrelatedFieldMaker: ...@@ -561,7 +570,8 @@ class CorrelatedFieldMaker:
"""Returns operator which acts on prior or posterior samples""" """Returns operator which acts on prior or posterior samples"""
if len(self._a) == 0: if len(self._a) == 0:
raise NotImplementedError raise NotImplementedError
assert space < len(self._a) if space >= len(self._a):
raise ValueError(f"invalid space specified; got {space!r}")
if len(self._a) == 1: if len(self._a) == 1:
return self._a[0].fluctuation_amplitude return self._a[0].fluctuation_amplitude
return self._a[space].fluctuation_amplitude return self._a[space].fluctuation_amplitude
...@@ -582,7 +592,8 @@ class CorrelatedFieldMaker: ...@@ -582,7 +592,8 @@ class CorrelatedFieldMaker:
"""Computes slice fluctuations from collection of field (defined in signal """Computes slice fluctuations from collection of field (defined in signal
space) realizations.""" space) realizations."""
ldom = len(samples[0].domain) ldom = len(samples[0].domain)
assert space < ldom if space >= ldom:
raise ValueError(f"invalid space specified; got {space!r}")
if ldom == 1: if ldom == 1:
return _total_fluctuation_realized(samples) return _total_fluctuation_realized(samples)
res1, res2 = 0., 0. res1, res2 = 0., 0.
...@@ -599,7 +610,8 @@ class CorrelatedFieldMaker: ...@@ -599,7 +610,8 @@ class CorrelatedFieldMaker:
"""Computes average fluctuations from collection of field (defined in signal """Computes average fluctuations from collection of field (defined in signal
space) realizations.""" space) realizations."""
ldom = len(samples[0].domain) ldom = len(samples[0].domain)
assert space < ldom if space >= ldom:
raise ValueError(f"invalid space specified; got {space!r}")
if ldom == 1: if ldom == 1:
return _total_fluctuation_realized(samples) return _total_fluctuation_realized(samples)
spaces = () spaces = ()
......
...@@ -132,6 +132,10 @@ class MultiField(object): ...@@ -132,6 +132,10 @@ class MultiField(object):
return MultiField(domain, tuple(Field(dom, val) return MultiField(domain, tuple(Field(dom, val)
for dom in domain._domains)) for dom in domain._domains))
def to_global_data(self):
return {key: val.val
for key, val in zip(self._domain.keys(), self._val)}
@staticmethod @staticmethod
def from_global_data(domain, arr): def from_global_data(domain, arr):
return MultiField( return MultiField(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment