Commit 04c80477 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'diag_hack' into 'NIFTy_4'

More aggressive combination of diagonal operators

See merge request ift/NIFTy!235
parents 6c61cbec e6b49f93
Pipeline #26517 passed with stages
in 5 minutes and 30 seconds
......@@ -35,7 +35,8 @@ if __name__ == "__main__":
d_space = R.target
power = ift.sqrt(ift.create_power_operator(h_space, p_spec).diagonal)
p_op = ift.create_power_operator(h_space, p_spec)
power = ift.sqrt(p_op(ift.Field.full(h_space, 1.)))
# Creating the mock data
true_sky = nonlinearity(HT(power*sh))
......
......@@ -59,7 +59,7 @@ class NonlinearPowerEnergy(Energy):
self.D = D
self.d = d
self.N = N
self.T = SmoothnessOperator(domain=self.position.domain[0],
self.T = SmoothnessOperator(domain=position.domain[0],
strength=sigma, logarithmic=True)
self.ht = ht
self.Instrument = Instrument
......@@ -76,19 +76,15 @@ class NonlinearPowerEnergy(Energy):
self.inverter = inverter
A = Distributor(exp(.5 * position))
map_s = self.ht(A * xi)
Tpos = self.T(position)
self._gradient = None
for xi_sample in self.xi_sample_list:
map_s = self.ht(A * xi_sample)
LinR = LinearizedPowerResponse(
self.Instrument, self.nonlinearity, self.ht, self.Distributor,
self.position, xi_sample)
map_s = ht(A*xi_sample)
LinR = LinearizedPowerResponse(Instrument, nonlinearity, ht,
Distributor, position, xi_sample)
residual = self.d - \
self.Instrument(self.nonlinearity(map_s))
tmp = self.N.inverse_times(residual)
residual = d - Instrument(nonlinearity(map_s))
tmp = N.inverse_times(residual)
lh = 0.5 * residual.vdot(tmp)
grad = LinR.adjoint_times(tmp)
......@@ -100,7 +96,8 @@ class NonlinearPowerEnergy(Energy):
self._gradient += grad
self._value *= 1. / len(self.xi_sample_list)
self._value += 0.5 * self.position.vdot(Tpos)
Tpos = self.T(position)
self._value += 0.5 * position.vdot(Tpos)
self._gradient *= -1. / len(self.xi_sample_list)
self._gradient += Tpos
self._gradient.lock()
......
......@@ -31,20 +31,18 @@ class NonlinearWienerFilterEnergy(Energy):
self.nonlinearity = nonlinearity
self.ht = ht
self.power = power
m = self.ht(self.power*self.position)
self.LinearizedResponse = LinearizedSignalResponse(
Instrument, nonlinearity, ht, power, m)
m = ht(power*position)
residual = d - Instrument(nonlinearity(m))
self.N = N
self.S = S
self.inverter = inverter
t1 = self.S.inverse_times(self.position)
t2 = self.N.inverse_times(residual)
tmp = self.position.vdot(t1) + residual.vdot(t2)
self._value = 0.5 * tmp.real
self._gradient = t1 - self.LinearizedResponse.adjoint_times(t2)
self._gradient.lock()
t1 = S.inverse_times(position)
t2 = N.inverse_times(residual)
self._value = 0.5 * (position.vdot(t1) + residual.vdot(t2)).real
self.R = LinearizedSignalResponse(Instrument, nonlinearity, ht, power,
m)
self._gradient = (t1 - self.R.adjoint_times(t2)).lock()
def at(self, position):
return self.__class__(position, self.d, self.Instrument,
......@@ -62,5 +60,4 @@ class NonlinearWienerFilterEnergy(Energy):
@property
@memo
def curvature(self):
return WienerFilterCurvature(R=self.LinearizedResponse, N=self.N,
S=self.S, inverter=self.inverter)
return WienerFilterCurvature(self.R, self.N, self.S, self.inverter)
......@@ -51,12 +51,11 @@ class WienerFilterEnergy(Energy):
self._curvature = WienerFilterCurvature(R, N, S, inverter)
self._inverter = inverter
if _j is None:
_j = self.R.adjoint_times(self.N.inverse_times(d))
_j = R.adjoint_times(N.inverse_times(d))
self._j = _j
Dx = self._curvature(self.position)
self._value = 0.5*self.position.vdot(Dx) - self._j.vdot(self.position)
self._gradient = Dx - self._j
self._gradient.lock()
self._value = 0.5*position.vdot(Dx) - self._j.vdot(position)
self._gradient = (Dx - self._j).lock()
def at(self, position):
return self.__class__(position=position, d=None, R=self.R, N=self.N,
......
......@@ -61,9 +61,7 @@ class ChainOperator(LinearOperator):
# try to absorb the factor into a DiagonalOperator
for i in range(len(opsnew)):
if isinstance(opsnew[i], DiagonalOperator):
opsnew[i] = DiagonalOperator(opsnew[i].diagonal*fct,
domain=opsnew[i].domain,
spaces=opsnew[i]._spaces)
opsnew[i] = opsnew[i]._scale(fct)
fct = 1.
break
if fct != 1:
......@@ -75,12 +73,8 @@ class ChainOperator(LinearOperator):
for op in ops:
if (len(opsnew) > 0 and
isinstance(opsnew[-1], DiagonalOperator) and
isinstance(op, DiagonalOperator) and
op._spaces == opsnew[-1]._spaces):
opsnew[-1] = DiagonalOperator(opsnew[-1].diagonal *
op.diagonal,
domain=opsnew[-1].domain,
spaces=opsnew[-1]._spaces)
isinstance(op, DiagonalOperator)):
opsnew[-1] = opsnew[-1]._combine_prod(op)
else:
opsnew.append(op)
ops = opsnew
......@@ -120,9 +114,3 @@ class ChainOperator(LinearOperator):
for op in t_ops:
x = op.apply(x, mode)
return x
def draw_sample(self, dtype=np.float64):
sample = self._ops[-1].draw_sample(dtype)
for op in reversed(self._ops[:-1]):
sample = op.process_sample(sample)
return sample
......@@ -71,32 +71,66 @@ class DiagonalOperator(EndomorphicOperator):
self._spaces = utilities.parse_spaces(spaces, len(self._domain))
if len(self._spaces) != len(diagonal.domain):
raise ValueError("spaces and domain must have the same length")
# if nspc==len(self.diagonal.domain),
# we could do some optimization
for i, j in enumerate(self._spaces):
if diagonal.domain[i] != self._domain[j]:
raise ValueError("domain mismatch")
if self._spaces == tuple(range(len(self._domain))):
self._spaces = None # shortcut
self._diagonal = diagonal.lock()
if self._spaces is not None:
active_axes = []
for space_index in self._spaces:
active_axes += self._domain.axes[space_index]
if self._spaces[0] == 0:
self._ldiag = self._diagonal.local_data
self._ldiag = diagonal.local_data
else:
self._ldiag = self._diagonal.to_global_data()
self._ldiag = diagonal.to_global_data()
locshape = dobj.local_shape(self._domain.shape, 0)
self._reshaper = [shp if i in active_axes else 1
for i, shp in enumerate(locshape)]
self._ldiag = self._ldiag.reshape(self._reshaper)
else:
self._ldiag = self._diagonal.local_data
self._ldiag = diagonal.local_data
self._ldiag.flags.writeable = False
def _skeleton(self, spc):
res = DiagonalOperator.__new__(DiagonalOperator)
res._domain = self._domain
if self._spaces is None or spc is None:
res._spaces = None
else:
res._spaces = tuple(set(self._spaces) | set(spc))
return res
def _scale(self, fct):
if not np.isscalar(fct):
raise TypeError("scalar value required")
res = self._skeleton(())
res._ldiag = self._ldiag*fct
return res
def _add(self, sum):
if not np.isscalar(sum):
raise TypeError("scalar value required")
res = self._skeleton(())
res._ldiag = self._ldiag + sum
return res
def _combine_prod(self, op):
if not isinstance(op, DiagonalOperator):
raise TypeError("DiagonalOperator required")
res = self._skeleton(op._spaces)
res._ldiag = self._ldiag*op._ldiag
return res
def _combine_sum(self, op, selfneg, opneg):
if not isinstance(op, DiagonalOperator):
raise TypeError("DiagonalOperator required")
res = self._skeleton(op._spaces)
res._ldiag = (self._ldiag * (-1 if selfneg else 1) +
op._ldiag * (-1 if opneg else 1))
return res
def apply(self, x, mode):
self._check_input(x, mode)
......@@ -116,11 +150,6 @@ class DiagonalOperator(EndomorphicOperator):
else:
return Field(x.domain, val=x.val/self._ldiag.conj())
@property
def diagonal(self):
""" Returns the diagonal of the Operator."""
return self._diagonal
@property
def domain(self):
return self._domain
......@@ -131,19 +160,16 @@ class DiagonalOperator(EndomorphicOperator):
@property
def inverse(self):
return DiagonalOperator(1./self._diagonal, self._domain, self._spaces)
res = self._skeleton(())
res._ldiag = 1./self._ldiag
return res
@property
def adjoint(self):
return DiagonalOperator(self._diagonal.conjugate(), self._domain,
self._spaces)
def process_sample(self, sample):
if np.issubdtype(self._ldiag.dtype, np.complexfloating):
raise ValueError("cannot draw sample from complex-valued operator")
res = Field.empty_like(sample)
res.local_data[()] = sample.local_data * np.sqrt(self._ldiag)
if np.issubdtype(self._ldiag.dtype, np.floating):
return self
res = self._skeleton(())
res._ldiag = self._ldiag.conjugate()
return res
def draw_sample(self, dtype=np.float64):
......
......@@ -50,10 +50,8 @@ class LaplaceOperator(EndomorphicOperator):
if not isinstance(self._domain[self._space], PowerSpace):
raise ValueError("Operator must act on a PowerSpace.")
self._logarithmic = bool(logarithmic)
pos = self.domain[self._space].k_lengths.copy()
if self.logarithmic:
if logarithmic:
pos[1:] = np.log(pos[1:])
pos[0] = pos[1]-1.
......@@ -74,10 +72,6 @@ class LaplaceOperator(EndomorphicOperator):
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
@property
def logarithmic(self):
return self._logarithmic
def _times(self, x):
axes = x.domain.axes[self._space]
axis = axes[0]
......
......@@ -61,7 +61,7 @@ class ScalingOperator(EndomorphicOperator):
if self._factor == 1.:
return x.copy()
if self._factor == 0.:
return Field.zeros_like(x, dtype=x.dtype)
return Field.zeros_like(x)
if mode == self.TIMES:
return x*self._factor
......@@ -81,6 +81,8 @@ class ScalingOperator(EndomorphicOperator):
@property
def adjoint(self):
if np.issubdtype(type(self._factor), np.floating):
return self
return ScalingOperator(np.conj(self._factor), self._domain)
@property
......@@ -93,11 +95,6 @@ class ScalingOperator(EndomorphicOperator):
return self.TIMES | self.ADJOINT_TIMES
return self._all_ops
def process_sample(self, sample):
if self._factor.imag != 0. or self._factor.real <= 0.:
raise ValueError("Operator not positive definite")
return sample * np.sqrt(self._factor)
def _sample_helper(self, fct, dtype):
if fct.imag != 0. or fct.real <= 0.:
raise ValueError("operator not positive definite")
......
......@@ -72,9 +72,7 @@ class SumOperator(LinearOperator):
for i in range(len(opsnew)):
if isinstance(opsnew[i], DiagonalOperator):
sum *= (-1 if negnew[i] else 1)
opsnew[i] = DiagonalOperator(opsnew[i].diagonal+sum,
domain=opsnew[i].domain,
spaces=opsnew[i]._spaces)
opsnew[i] = opsnew[i]._add(sum)
sum = 0.
break
if sum != 0:
......@@ -90,15 +88,15 @@ class SumOperator(LinearOperator):
for i in range(len(ops)):
if not processed[i]:
if isinstance(ops[i], DiagonalOperator):
diag = ops[i].diagonal*(-1 if neg[i] else 1)
op = ops[i]
opneg = neg[i]
for j in range(i+1, len(ops)):
if (isinstance(ops[j], DiagonalOperator) and
ops[i]._spaces == ops[j]._spaces):
diag += ops[j].diagonal*(-1 if neg[j] else 1)
if isinstance(ops[j], DiagonalOperator):
op = op._combine_sum(ops[j], opneg, neg[j])
opneg = False
processed[j] = True
opsnew.append(DiagonalOperator(diag, ops[i].domain,
ops[i]._spaces))
negnew.append(False)
opsnew.append(op)
negnew.append(opneg)
else:
opsnew.append(ops[i])
negnew.append(neg[i])
......
......@@ -90,5 +90,5 @@ class DiagonalOperator_Tests(unittest.TestCase):
def test_diagonal(self, space):
diag = ift.Field.from_random('normal', domain=space)
D = ift.DiagonalOperator(diag)
diag_op = D.diagonal
diag_op = D(ift.Field.full(space, 1.))
assert_allclose(diag.to_global_data(), diag_op.to_global_data())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment