Commit fc6e304a authored by Gordian Edenhofer's avatar Gordian Edenhofer Committed by Philipp Arras

correlated_field.py: Raise errors for users

Instead of validating the input via `assert`s in user-accessible
functions, raise appropriate errors.
parent 8aaafe29
Pipeline #64966 passed with stages
in 9 minutes and 52 seconds
...@@ -51,8 +51,11 @@ def _lognormal_moments(mean, sig, N=0): ...@@ -51,8 +51,11 @@ def _lognormal_moments(mean, sig, N=0):
mean, sig = np.asfarray(mean), np.asfarray(sig) mean, sig = np.asfarray(mean), np.asfarray(sig)
else: else:
mean, sig = (_reshaper(param, N) for param in (mean, sig)) mean, sig = (_reshaper(param, N) for param in (mean, sig))
assert np.all(mean > 0) if not np.all(mean > 0):
assert np.all(sig > 0) raise ValueError(f"mean must be greater 0; got {mean!r}")
if not np.all(sig > 0):
raise ValueError(f"sig must be greater 0; got {sig!r}")
logsig = np.sqrt(np.log((sig/mean)**2 + 1)) logsig = np.sqrt(np.log((sig/mean)**2 + 1))
logmean = np.log(mean) - logsig**2/2 logmean = np.log(mean) - logsig**2/2
return logmean, logsig return logmean, logsig
...@@ -334,7 +337,8 @@ class _Amplitude(Operator): ...@@ -334,7 +337,8 @@ class _Amplitude(Operator):
class CorrelatedFieldMaker: class CorrelatedFieldMaker:
def __init__(self, amplitude_offset, prefix, total_N): def __init__(self, amplitude_offset, prefix, total_N):
assert isinstance(amplitude_offset, Operator) if not isinstance(amplitude_offset, Operator):
raise TypeError("amplitude_offset needs to be an operator")
self._a = [] self._a = []
self._position_spaces = [] self._position_spaces = []
...@@ -348,8 +352,8 @@ class CorrelatedFieldMaker: ...@@ -348,8 +352,8 @@ class CorrelatedFieldMaker:
dofdex=None): dofdex=None):
if dofdex is None: if dofdex is None:
dofdex = np.full(total_N, 0) dofdex = np.full(total_N, 0)
else: elif len(dofdex) != total_N:
assert len(dofdex) == total_N raise ValueError("length of dofdex needs to match total_N")
N = max(dofdex) + 1 if total_N > 0 else 0 N = max(dofdex) + 1 if total_N > 0 else 0
zm = _LognormalMomentMatching(offset_amplitude_mean, zm = _LognormalMomentMatching(offset_amplitude_mean,
offset_amplitude_stddev, offset_amplitude_stddev,
...@@ -381,8 +385,8 @@ class CorrelatedFieldMaker: ...@@ -381,8 +385,8 @@ class CorrelatedFieldMaker:
if dofdex is None: if dofdex is None:
dofdex = np.full(self._total_N, 0) dofdex = np.full(self._total_N, 0)
else: elif len(dofdex) != self._total_N:
assert len(dofdex) == self._total_N raise ValueError("length of dofdex needs to match total_N")
if self._total_N > 0: if self._total_N > 0:
space = 1 space = 1
...@@ -501,7 +505,9 @@ class CorrelatedFieldMaker: ...@@ -501,7 +505,9 @@ class CorrelatedFieldMaker:
def moment_slice_to_average(self, fluctuations_slice_mean, nsamples=1000): def moment_slice_to_average(self, fluctuations_slice_mean, nsamples=1000):
fluctuations_slice_mean = float(fluctuations_slice_mean) fluctuations_slice_mean = float(fluctuations_slice_mean)
assert fluctuations_slice_mean > 0 if not fluctuations_slice_mean > 0:
msg = "fluctuations_slice_mean must be greater zero; got {!r}"
raise ValueError(msg.format(fluctuations_slice_mean))
from ..sugar import from_random from ..sugar import from_random
scm = 1. scm = 1.
for a in self._a: for a in self._a:
...@@ -547,7 +553,8 @@ class CorrelatedFieldMaker: ...@@ -547,7 +553,8 @@ class CorrelatedFieldMaker:
"""Returns operator which acts on prior or posterior samples""" """Returns operator which acts on prior or posterior samples"""
if len(self._a) == 0: if len(self._a) == 0:
raise NotImplementedError raise NotImplementedError
assert space < len(self._a) if space >= len(self._a):
raise ValueError(f"invalid space specified; got {space!r}")
if len(self._a) == 1: if len(self._a) == 1:
return self.average_fluctuation(0) return self.average_fluctuation(0)
q = 1. q = 1.
...@@ -563,7 +570,8 @@ class CorrelatedFieldMaker: ...@@ -563,7 +570,8 @@ class CorrelatedFieldMaker:
"""Returns operator which acts on prior or posterior samples""" """Returns operator which acts on prior or posterior samples"""
if len(self._a) == 0: if len(self._a) == 0:
raise NotImplementedError raise NotImplementedError
assert space < len(self._a) if space >= len(self._a):
raise ValueError(f"invalid space specified; got {space!r}")
if len(self._a) == 1: if len(self._a) == 1:
return self._a[0].fluctuation_amplitude return self._a[0].fluctuation_amplitude
return self._a[space].fluctuation_amplitude return self._a[space].fluctuation_amplitude
...@@ -584,7 +592,8 @@ class CorrelatedFieldMaker: ...@@ -584,7 +592,8 @@ class CorrelatedFieldMaker:
"""Computes slice fluctuations from collection of field (defined in signal """Computes slice fluctuations from collection of field (defined in signal
space) realizations.""" space) realizations."""
ldom = len(samples[0].domain) ldom = len(samples[0].domain)
assert space < ldom if space >= ldom:
raise ValueError(f"invalid space specified; got {space!r}")
if ldom == 1: if ldom == 1:
return _total_fluctuation_realized(samples) return _total_fluctuation_realized(samples)
res1, res2 = 0., 0. res1, res2 = 0., 0.
...@@ -601,7 +610,8 @@ class CorrelatedFieldMaker: ...@@ -601,7 +610,8 @@ class CorrelatedFieldMaker:
"""Computes average fluctuations from collection of field (defined in signal """Computes average fluctuations from collection of field (defined in signal
space) realizations.""" space) realizations."""
ldom = len(samples[0].domain) ldom = len(samples[0].domain)
assert space < ldom if space >= ldom:
raise ValueError(f"invalid space specified; got {space!r}")
if ldom == 1: if ldom == 1:
return _total_fluctuation_realized(samples) return _total_fluctuation_realized(samples)
spaces = () spaces = ()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment