Skip to content
Snippets Groups Projects
Commit fa44604e authored by Martin Reinecke's avatar Martin Reinecke
Browse files

reduce direct usage of Field constructor

parent c0d51e88
Branches
Tags
No related merge requests found
Pipeline #
......@@ -29,7 +29,7 @@ if __name__ == "__main__":
# Choose the measurement instrument
# Instrument = SmoothingOperator(s_space, sigma=0.01)
Instrument = ift.DiagonalOperator(ift.Field(s_space, 1.))
Instrument = ift.ScalingOperator(1., s_space)
# Instrument._diagonal.val[200:400, 200:400] = 0
# Instrument._diagonal.val[64:512-64, 64:512-64] = 0
......@@ -74,7 +74,7 @@ if __name__ == "__main__":
# Set starting position
flat_power = ift.Field.full(p_space, 1e-8)
m0 = ift.power_synthesize(flat_power, real_signal=True)
t0 = ift.Field(p_space, val=-7.)
t0 = ift.Field.full(p_space, -7.)
for i in range(500):
S0 = ift.create_power_operator(h_space, power_spectrum=ift.exp(t0))
......
......
......@@ -5,14 +5,14 @@ np.random.seed(42)
def adjust_zero_mode(m0, t0):
mtmp = ift.dobj.to_global_data(m0.val)
mtmp = m0.to_global_data()
zero_position = len(m0.shape)*(0,)
zero_mode = mtmp[zero_position]
mtmp[zero_position] = zero_mode / abs(zero_mode)
ttmp = ift.dobj.to_global_data(t0.val)
ttmp = t0.to_global_data()
ttmp[0] += 2 * np.log(abs(zero_mode))
return (ift.Field(m0.domain, ift.dobj.from_global_data(mtmp)),
ift.Field(t0.domain, ift.dobj.from_global_data(ttmp)))
return (ift.Field.from_global_data(m0.domain, mtmp),
ift.Field.from_global_data(t0.domain, ttmp))
if __name__ == "__main__":
......@@ -48,7 +48,7 @@ if __name__ == "__main__":
# Instrument = SmoothingOperator(s_space, sigma=0.01)
mask = np.ones(s_space.shape)
mask[6000:8000] = 0.
mask = ift.Field(s_space, val=ift.dobj.from_global_data(mask))
mask = ift.Field.from_global_data(s_space, mask)
MaskOperator = ift.DiagonalOperator(mask)
R = ift.GeometryRemover(s_space)
......
......
......@@ -46,11 +46,10 @@ if __name__ == "__main__":
ht = ht_2*ht_1
mock_power = ift.Field(
mock_power = ift.Field.from_global_data(
(power_space_1, power_space_2),
val=ift.dobj.from_global_data(
np.outer(ift.dobj.to_global_data(mock_power_1.val),
ift.dobj.to_global_data(mock_power_2.val))))
np.outer(mock_power_1.to_global_data(),
mock_power_2.to_global_data()))
diagonal = ift.power_synthesize_nonrandom(mock_power, spaces=(0, 1))**2
......@@ -63,12 +62,12 @@ if __name__ == "__main__":
N1_10 = int(N_pixels_1/10)
mask_1 = np.ones(signal_space_1.shape)
mask_1[N1_10*7:N1_10*9] = 0.
mask_1 = ift.Field(signal_space_1, ift.dobj.from_global_data(mask_1))
mask_1 = ift.Field.from_global_data(signal_space_1, mask_1)
N2_10 = int(N_pixels_2/10)
mask_2 = np.ones(signal_space_2.shape)
mask_2[N2_10*7:N2_10*9] = 0.
mask_2 = ift.Field(signal_space_2, ift.dobj.from_global_data(mask_2))
mask_2 = ift.Field.from_global_data(signal_space_2, mask_2)
R = ift.GeometryRemover(signal_domain)
R = R*ift.DiagonalOperator(mask_1, signal_domain, spaces=0)
......
......
......@@ -34,7 +34,7 @@ if __name__ == "__main__":
mask = np.ones(signal_space.shape)
N10 = int(N_pixels/10)
mask[N10*5:N10*9, N10*5:N10*9] = 0.
mask = ift.Field(signal_space, ift.dobj.from_global_data(mask))
mask = ift.Field.from_global_data(signal_space, mask).lock()
R = ift.GeometryRemover(signal_space)
R = R*ift.DiagonalOperator(mask)
R = R*ht
......
......
......@@ -34,7 +34,7 @@ if __name__ == "__main__":
else:
raise NotImplementedError
diag = ift.Field(s_space, ift.dobj.from_global_data(diag))
diag = ift.Field.from_global_data(s_space, diag).lock()
Instrument = ift.DiagonalOperator(diag)
# Add harmonic transformation to the instrument
......
......
......@@ -91,7 +91,7 @@ class LMSpace(StructuredDomain):
for m in range(1, mmax+1):
ldist[idx:idx+2*(lmax+1-m)] = tmp[2*m:]
idx += 2*(lmax+1-m)
return Field((self,), dobj.from_global_data(ldist))
return Field.from_global_data(self, ldist)
def get_unique_k_lengths(self):
return np.arange(self.lmax+1, dtype=np.float64)
......
......
......@@ -123,6 +123,13 @@ class Field(object):
dtype = field.dtype
return Field.empty(field._domain, dtype)
@staticmethod
def from_global_data(domain, dobject):
return Field(domain, dobj.from_global_data(dobject))
def to_global_data(self):
return dobj.to_global_data(self._val)
@staticmethod
def _infer_domain(domain, val=None):
if domain is None:
......@@ -174,6 +181,7 @@ class Field(object):
def lock(self):
dobj.lock(self._val)
return self
@property
def locked(self):
......
......
......@@ -91,7 +91,7 @@ class DOFDistributor(LinearOperator):
np.add.at(oarr, (slice(None), self._dofdex, slice(None)), arr)
if dobj.distaxis(x.val) in x.domain.axes[self._space]:
oarr = dobj.np_allreduce_sum(oarr).reshape(self._domain.shape)
res = Field(self._domain, dobj.from_global_data(oarr))
res = Field.from_global_data(self._domain, oarr)
else:
oarr = oarr.reshape(dobj.local_shape(self._domain.shape,
dobj.distaxis(x.val)))
......@@ -103,7 +103,7 @@ class DOFDistributor(LinearOperator):
def _times(self, x):
res = Field.empty(self._target, dtype=x.dtype)
if dobj.distaxis(x.val) in x.domain.axes[self._space]:
arr = dobj.to_global_data(x.val)
arr = x.to_global_data()
else:
arr = dobj.local_data(x.val)
arr = arr.reshape(self._hshape)
......
......
......@@ -59,14 +59,14 @@ class Test_Functionality(unittest.TestCase):
p1 = ift.PowerSpace(space1)
p1val = _spec1(p1.k_lengths)
fp1 = ift.Field(p1, val=ift.dobj.from_global_data(p1val))
fp1 = ift.Field.from_global_data(p1, p1val)
p2 = ift.PowerSpace(space2)
p2val = _spec2(p2.k_lengths)
fp2 = ift.Field(p2, val=ift.dobj.from_global_data(p2val))
fp2 = ift.Field.from_global_data(p2, p2val)
outer = ift.dobj.from_global_data(np.outer(p1val, p2val))
fp = ift.Field((p1, p2), val=outer)
outer = np.outer(p1val, p2val)
fp = ift.Field.from_global_data((p1, p2), outer)
samples = 500
ps1 = 0.
......@@ -79,10 +79,10 @@ class Test_Functionality(unittest.TestCase):
ps1 += sp.sum(spaces=1)/fp2.sum()
ps2 += sp.sum(spaces=0)/fp1.sum()
assert_allclose(ift.dobj.to_global_data(ps1.val/samples),
ift.dobj.to_global_data(fp1.val), rtol=0.2)
assert_allclose(ift.dobj.to_global_data(ps2.val/samples),
ift.dobj.to_global_data(fp2.val), rtol=0.2)
assert_allclose((ps1/samples).to_global_data(),
fp1.to_global_data(), rtol=0.2)
assert_allclose((ps2/samples).to_global_data(),
fp2.to_global_data(), rtol=0.2)
@expand(product([ift.RGSpace((8,), harmonic=True),
ift.RGSpace((8, 8), harmonic=True, distances=0.123)],
......@@ -95,11 +95,11 @@ class Test_Functionality(unittest.TestCase):
p1 = ift.PowerSpace(space1)
p1val = _spec1(p1.k_lengths)
fp1 = ift.Field(p1, val=ift.dobj.from_global_data(p1val))
fp1 = ift.Field.from_global_data(p1, p1val)
p2 = ift.PowerSpace(space2)
p2val = _spec2(p2.k_lengths)
fp2 = ift.Field(p2, val=ift.dobj.from_global_data(p2val))
fp2 = ift.Field.from_global_data(p2, p2val)
S_1 = ift.create_power_field(space1, lambda x: np.sqrt(_spec1(x)))
S_1 = ift.DiagonalOperator(S_1, domain=fulldomain, spaces=0)
......@@ -118,10 +118,10 @@ class Test_Functionality(unittest.TestCase):
ps1 += sp.sum(spaces=1)/fp2.sum()
ps2 += sp.sum(spaces=0)/fp1.sum()
assert_allclose(ift.dobj.to_global_data(ps1.val/samples),
ift.dobj.to_global_data(fp1.val), rtol=0.2)
assert_allclose(ift.dobj.to_global_data(ps2.val/samples),
ift.dobj.to_global_data(fp2.val), rtol=0.2)
assert_allclose((ps1/samples).to_global_data(),
fp1.to_global_data(), rtol=0.2)
assert_allclose((ps2/samples).to_global_data(),
fp2.to_global_data(), rtol=0.2)
def test_vdot(self):
s = ift.RGSpace((10,))
......@@ -135,4 +135,4 @@ class Test_Functionality(unittest.TestCase):
x2 = ift.RGSpace((150,))
m = ift.Field((x1, x2), val=.5)
res = m.vdot(m, spaces=1)
assert_allclose(ift.dobj.to_global_data(res.val), 37.5)
assert_allclose(res.to_global_data(), 37.5)
......@@ -33,6 +33,7 @@ _p_spaces = _p_RG_spaces + [ift.HPSpace(17), ift.GLSpace(8, 13)]
_pow_spaces = [ift.PowerSpace(ift.RGSpace((17, 38), harmonic=True))]
class Consistency_Tests(unittest.TestCase):
@expand(product(_h_spaces, [np.float64, np.complex128]))
def testPPO(self, sp, dtype):
......
......
......@@ -54,7 +54,7 @@ class SmoothingOperator_Tests(unittest.TestCase):
op = ift.FFTSmoothingOperator(space, sigma=sigma)
fld = np.zeros(space.shape, dtype=np.float64)
fld[0] = 1.
rand1 = ift.Field(space, ift.dobj.from_global_data(fld))
rand1 = ift.Field.from_global_data(space, fld)
tt1 = op.times(rand1)
assert_allclose(1, tt1.sum())
......
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment