From fc6e304ab8d570efa345f3a496bff9764e199232 Mon Sep 17 00:00:00 2001 From: Gordian Edenhofer Date: Wed, 4 Dec 2019 16:49:34 +0100 Subject: [PATCH] correlated_field.py: Raise errors for users Instead of validating the input via `assert`s in user-accessible functions, raise appropriate errors. --- nifty6/library/correlated_fields.py | 34 +++++++++++++++++++---------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/nifty6/library/correlated_fields.py b/nifty6/library/correlated_fields.py index d02a8061..ff538dd1 100644 --- a/nifty6/library/correlated_fields.py +++ b/nifty6/library/correlated_fields.py @@ -51,8 +51,11 @@ def _lognormal_moments(mean, sig, N=0): mean, sig = np.asfarray(mean), np.asfarray(sig) else: mean, sig = (_reshaper(param, N) for param in (mean, sig)) - assert np.all(mean > 0) - assert np.all(sig > 0) + if not np.all(mean > 0): + raise ValueError(f"mean must be greater 0; got {mean!r}") + if not np.all(sig > 0): + raise ValueError(f"sig must be greater 0; got {sig!r}") + logsig = np.sqrt(np.log((sig/mean)**2 + 1)) logmean = np.log(mean) - logsig**2/2 return logmean, logsig @@ -334,7 +337,8 @@ class _Amplitude(Operator): class CorrelatedFieldMaker: def __init__(self, amplitude_offset, prefix, total_N): - assert isinstance(amplitude_offset, Operator) + if not isinstance(amplitude_offset, Operator): + raise TypeError("amplitude_offset needs to be an operator") self._a = [] self._position_spaces = [] @@ -348,8 +352,8 @@ class CorrelatedFieldMaker: dofdex=None): if dofdex is None: dofdex = np.full(total_N, 0) - else: - assert len(dofdex) == total_N + elif len(dofdex) != total_N: + raise ValueError("length of dofdex needs to match total_N") N = max(dofdex) + 1 if total_N > 0 else 0 zm = _LognormalMomentMatching(offset_amplitude_mean, offset_amplitude_stddev, @@ -381,8 +385,8 @@ class CorrelatedFieldMaker: if dofdex is None: dofdex = np.full(self._total_N, 0) - else: - assert len(dofdex) == self._total_N + elif len(dofdex) != self._total_N: + raise ValueError("length of dofdex needs to match total_N") if self._total_N > 0: space = 1 @@ -501,7 +505,9 @@ class CorrelatedFieldMaker: def moment_slice_to_average(self, fluctuations_slice_mean, nsamples=1000): fluctuations_slice_mean = float(fluctuations_slice_mean) - assert fluctuations_slice_mean > 0 + if not fluctuations_slice_mean > 0: + msg = "fluctuations_slice_mean must be greater zero; got {!r}" + raise ValueError(msg.format(fluctuations_slice_mean)) from ..sugar import from_random scm = 1. for a in self._a: @@ -547,7 +553,8 @@ class CorrelatedFieldMaker: """Returns operator which acts on prior or posterior samples""" if len(self._a) == 0: raise NotImplementedError - assert space < len(self._a) + if space >= len(self._a): + raise ValueError(f"invalid space specified; got {space!r}") if len(self._a) == 1: return self.average_fluctuation(0) q = 1. @@ -563,7 +570,8 @@ class CorrelatedFieldMaker: """Returns operator which acts on prior or posterior samples""" if len(self._a) == 0: raise NotImplementedError - assert space < len(self._a) + if space >= len(self._a): + raise ValueError(f"invalid space specified; got {space!r}") if len(self._a) == 1: return self._a[0].fluctuation_amplitude return self._a[space].fluctuation_amplitude @@ -584,7 +592,8 @@ class CorrelatedFieldMaker: """Computes slice fluctuations from collection of field (defined in signal space) realizations.""" ldom = len(samples[0].domain) - assert space < ldom + if space >= ldom: + raise ValueError(f"invalid space specified; got {space!r}") if ldom == 1: return _total_fluctuation_realized(samples) res1, res2 = 0., 0. @@ -601,7 +610,8 @@ class CorrelatedFieldMaker: """Computes average fluctuations from collection of field (defined in signal space) realizations.""" ldom = len(samples[0].domain) - assert space < ldom + if space >= ldom: + raise ValueError(f"invalid space specified; got {space!r}") if ldom == 1: return _total_fluctuation_realized(samples) spaces = () -- GitLab