Commit fc6e304a authored by Gordian Edenhofer's avatar Gordian Edenhofer Committed by Philipp Arras
Browse files

correlated_field.py: Raise errors for users

Instead of validating the input via `assert`s in user-accessible
functions, raise appropriate errors.
parent 8aaafe29
Pipeline #64966 passed with stages
in 9 minutes and 52 seconds
......@@ -51,8 +51,11 @@ def _lognormal_moments(mean, sig, N=0):
mean, sig = np.asfarray(mean), np.asfarray(sig)
else:
mean, sig = (_reshaper(param, N) for param in (mean, sig))
assert np.all(mean > 0)
assert np.all(sig > 0)
if not np.all(mean > 0):
raise ValueError(f"mean must be greater 0; got {mean!r}")
if not np.all(sig > 0):
raise ValueError(f"sig must be greater 0; got {sig!r}")
logsig = np.sqrt(np.log((sig/mean)**2 + 1))
logmean = np.log(mean) - logsig**2/2
return logmean, logsig
......@@ -334,7 +337,8 @@ class _Amplitude(Operator):
class CorrelatedFieldMaker:
def __init__(self, amplitude_offset, prefix, total_N):
assert isinstance(amplitude_offset, Operator)
if not isinstance(amplitude_offset, Operator):
raise TypeError("amplitude_offset needs to be an operator")
self._a = []
self._position_spaces = []
......@@ -348,8 +352,8 @@ class CorrelatedFieldMaker:
dofdex=None):
if dofdex is None:
dofdex = np.full(total_N, 0)
else:
assert len(dofdex) == total_N
elif len(dofdex) != total_N:
raise ValueError("length of dofdex needs to match total_N")
N = max(dofdex) + 1 if total_N > 0 else 0
zm = _LognormalMomentMatching(offset_amplitude_mean,
offset_amplitude_stddev,
......@@ -381,8 +385,8 @@ class CorrelatedFieldMaker:
if dofdex is None:
dofdex = np.full(self._total_N, 0)
else:
assert len(dofdex) == self._total_N
elif len(dofdex) != self._total_N:
raise ValueError("length of dofdex needs to match total_N")
if self._total_N > 0:
space = 1
......@@ -501,7 +505,9 @@ class CorrelatedFieldMaker:
def moment_slice_to_average(self, fluctuations_slice_mean, nsamples=1000):
fluctuations_slice_mean = float(fluctuations_slice_mean)
assert fluctuations_slice_mean > 0
if not fluctuations_slice_mean > 0:
msg = "fluctuations_slice_mean must be greater zero; got {!r}"
raise ValueError(msg.format(fluctuations_slice_mean))
from ..sugar import from_random
scm = 1.
for a in self._a:
......@@ -547,7 +553,8 @@ class CorrelatedFieldMaker:
"""Returns operator which acts on prior or posterior samples"""
if len(self._a) == 0:
raise NotImplementedError
assert space < len(self._a)
if space >= len(self._a):
raise ValueError(f"invalid space specified; got {space!r}")
if len(self._a) == 1:
return self.average_fluctuation(0)
q = 1.
......@@ -563,7 +570,8 @@ class CorrelatedFieldMaker:
"""Returns operator which acts on prior or posterior samples"""
if len(self._a) == 0:
raise NotImplementedError
assert space < len(self._a)
if space >= len(self._a):
raise ValueError(f"invalid space specified; got {space!r}")
if len(self._a) == 1:
return self._a[0].fluctuation_amplitude
return self._a[space].fluctuation_amplitude
......@@ -584,7 +592,8 @@ class CorrelatedFieldMaker:
"""Computes slice fluctuations from collection of field (defined in signal
space) realizations."""
ldom = len(samples[0].domain)
assert space < ldom
if space >= ldom:
raise ValueError(f"invalid space specified; got {space!r}")
if ldom == 1:
return _total_fluctuation_realized(samples)
res1, res2 = 0., 0.
......@@ -601,7 +610,8 @@ class CorrelatedFieldMaker:
"""Computes average fluctuations from collection of field (defined in signal
space) realizations."""
ldom = len(samples[0].domain)
assert space < ldom
if space >= ldom:
raise ValueError(f"invalid space specified; got {space!r}")
if ldom == 1:
return _total_fluctuation_realized(samples)
spaces = ()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment