Commit 71e38722 authored by Lukas Platz's avatar Lukas Platz

concise CorrelatedField offset parametrization

parent ff7dfaa8
...@@ -352,18 +352,19 @@ class _Amplitude(Operator): ...@@ -352,18 +352,19 @@ class _Amplitude(Operator):
class CorrelatedFieldMaker: class CorrelatedFieldMaker:
def __init__(self, amplitude_offset, prefix, total_N): def __init__(self, offset_mean, offset_fluctuations_op, prefix, total_N):
if not isinstance(amplitude_offset, Operator): if not isinstance(offset_fluctuations_op, Operator):
raise TypeError("amplitude_offset needs to be an operator") raise TypeError("offset_fluctuations_op needs to be an operator")
self._a = [] self._a = []
self._position_spaces = [] self._position_spaces = []
self._azm = amplitude_offset self._offset_mean = offset_mean
self._azm = offset_fluctuations_op
self._prefix = prefix self._prefix = prefix
self._total_N = total_N self._total_N = total_N
@staticmethod @staticmethod
def make(offset_amplitude_mean, offset_amplitude_stddev, prefix, def make(offset_mean, offset_variation_mean, offset_variation_stddev, prefix,
total_N=0, total_N=0,
dofdex=None): dofdex=None):
if dofdex is None: if dofdex is None:
...@@ -371,13 +372,13 @@ class CorrelatedFieldMaker: ...@@ -371,13 +372,13 @@ class CorrelatedFieldMaker:
elif len(dofdex) != total_N: elif len(dofdex) != total_N:
raise ValueError("length of dofdex needs to match total_N") raise ValueError("length of dofdex needs to match total_N")
N = max(dofdex) + 1 if total_N > 0 else 0 N = max(dofdex) + 1 if total_N > 0 else 0
zm = _LognormalMomentMatching(offset_amplitude_mean, zm = _LognormalMomentMatching(offset_variation_mean,
offset_amplitude_stddev, offset_variation_stddev,
prefix + 'zeromode', prefix + 'zeromode',
N) N)
if total_N > 0: if total_N > 0:
zm = _Distributor(dofdex, zm.target, UnstructuredDomain(total_N)) @ zm zm = _Distributor(dofdex, zm.target, UnstructuredDomain(total_N)) @ zm
return CorrelatedFieldMaker(zm, prefix, total_N) return CorrelatedFieldMaker(offset_mean, zm, prefix, total_N)
def add_fluctuations(self, def add_fluctuations(self,
position_space, position_space,
...@@ -470,12 +471,13 @@ class CorrelatedFieldMaker: ...@@ -470,12 +471,13 @@ class CorrelatedFieldMaker:
corr = reduce(mul, a) corr = reduce(mul, a)
return ht(azm*corr*ducktape(hspace, None, self._prefix + 'xi')) return ht(azm*corr*ducktape(hspace, None, self._prefix + 'xi'))
def finalize(self, offset=None, prior_info=100): def finalize(self, prior_info=100):
""" """
offset vs zeromode: volume factor offset vs zeromode: volume factor
""" """
op = self._finalize_from_op() op = self._finalize_from_op()
if offset is not None: if self._offset_mean is not None:
offset = self._offset_mean
# Deviations from this offset must not be considered here as they # Deviations from this offset must not be considered here as they
# are learned by the zeromode # are learned by the zeromode
if isinstance(offset, (Field, MultiField)): if isinstance(offset, (Field, MultiField)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment