From 4fe143dd8606371e37469dded9fa0115efe59295 Mon Sep 17 00:00:00 2001 From: Philipp Haim Date: Wed, 13 Nov 2019 23:09:51 +0100 Subject: [PATCH] Working for arbitrary constelation of DomainTuple --- nifty5/library/correlated_fields.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/nifty5/library/correlated_fields.py b/nifty5/library/correlated_fields.py index f85a664c..b218c997 100644 --- a/nifty5/library/correlated_fields.py +++ b/nifty5/library/correlated_fields.py @@ -231,21 +231,18 @@ class CorrelatedFieldMaker: twolog = _TwoLogIntegrations(target, space) dt = twolog._logvol axis = target.axes[space][0] - sl = (slice(None),)*axis + first = (slice(None),)*axis + (0,) extender_sl = (None,)*axis + (slice(None),) + (None,)*(target.axes[-1][-1] - axis) - first = sl + (0,) - second = sl + (1,) expander = ContractionOperator(twolog.domain, spaces = space).adjoint sqrt_t = np.zeros(twolog.domain[space].shape) - #sqrt_t[first] = sqrt_t[second] = np.sqrt(dt) sqrt_t[0] = sqrt_t[1] = np.sqrt(dt) sqrt_t = from_global_data(twolog.domain[space], sqrt_t) sqrt_t = DiagonalOperator(sqrt_t, twolog.domain, spaces = space) sigmasq = sqrt_t @ expander @ flexibility dist = np.zeros(twolog.domain[space].shape) - dist[first] += 1. + dist[0] += 1. dist = from_global_data(twolog.domain[space], dist) dist = DiagonalOperator(dist, twolog.domain, spaces = space) @@ -258,8 +255,9 @@ class CorrelatedFieldMaker: smoothslope = _make_slope_Operator(smooth,loglogavgslope, space) # move to tests - #assert_allclose( - # smooth(from_random('normal', smooth.domain)).val[0:2], 0) + assert_allclose( + smooth(from_random('normal', smooth.domain)).val[( + slice(None),)*axis + (slice(0,2),)], 0) consistency_check(twolog) check_jacobian_consistency(smooth, from_random('normal', smooth.domain)) @@ -275,7 +273,7 @@ class CorrelatedFieldMaker: expander = DiagonalOperator(from_global_data(target[space], arr) , target, spaces = space) @ expander mask = np.zeros(target.shape) - mask[0] = vol + mask[first] = vol adder = Adder(from_global_data(target, mask)) ampl = adder @ ((expander @ fluctuations)*normal_ampl) -- GitLab