Commit 4fbb6346 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

compactification

parent cb589a45
......@@ -438,7 +438,7 @@ def redistribute(arr, dist=None, nodist=None):
def transpose(arr):
if len(arr.shape) != 2 or arr._distaxis != 0:
raise ValueError("bad input")
raise ValueError("bad input")
ssz0 = arr._data.size//arr.shape[1]
ssz = np.empty(ntask, dtype=np.int)
rszall = arr.size//arr.shape[1]*_shareSize(arr.shape[1], ntask, rank)
......@@ -450,7 +450,7 @@ def transpose(arr):
for i in range(ntask):
lo, hi = _shareRange(arr.shape[1], ntask, i)
ssz[i] = ssz0*(hi-lo)
sbuf[ofs:ofs+ssz[i]] = arr._data[:,lo:hi].flat
sbuf[ofs:ofs+ssz[i]] = arr._data[:, lo:hi].flat
ofs += ssz[i]
rsz[i] = rsz0*_shareSize(arr.shape[0], ntask, i)
ssz *= arr._data.itemsize
......@@ -463,10 +463,11 @@ def transpose(arr):
del sbuf # free memory
arrnew = empty((arr.shape[1], arr.shape[0]), dtype=arr.dtype, distaxis=0)
ofs = 0
sz2 = _shareSize(arr.shape[1], ntask, rank)
for i in range(ntask):
lo, hi = _shareRange(arr.shape[0], ntask, i)
sz = rsz[i]//arr._data.itemsize
arrnew._data[:,lo:hi] = rbuf[ofs:ofs+sz].reshape(hi-lo,-1).T
arrnew._data[:, lo:hi] = rbuf[ofs:ofs+sz].reshape(hi-lo, sz2).T
ofs += sz
return arrnew
......
......@@ -18,12 +18,10 @@ class CriticalPowerCurvature(EndomorphicOperator):
The smoothness prior contribution to the curvature.
"""
# ---Overwritten properties and methods---
def __init__(self, theta, T):
super(CriticalPowerCurvature, self).__init__()
self.theta = DiagonalOperator(theta)
self.T = T
super(CriticalPowerCurvature, self).__init__()
@property
def preconditioner(self):
......@@ -32,8 +30,6 @@ class CriticalPowerCurvature(EndomorphicOperator):
def _times(self, x):
return self.T(x) + self.theta(x)
# ---Mandatory properties and methods---
@property
def domain(self):
return self.theta.domain
......
......@@ -21,7 +21,7 @@ class CriticalPowerEnergy(Energy):
position : Field,
The current position of this energy.
m : Field,
The map whichs power spectrum has to be inferred
The map whose power spectrum has to be inferred
D : EndomorphicOperator,
The curvature of the Gaussian encoding the posterior covariance.
If not specified, the map is assumed to be no reconstruction.
......@@ -100,8 +100,7 @@ class CriticalPowerEnergy(Energy):
@memo
def curvature(self):
curv = CriticalPowerCurvature(theta=self._theta, T=self.T)
return InversionEnabler(curv, inverter=self._inverter,
preconditioner=curv.preconditioner)
return InversionEnabler(curv, inverter=self._inverter)
# ---Added properties and methods---
......
......@@ -91,7 +91,8 @@ class RGRGTransformation(Transformation):
if True:
if oldax != 0:
raise ValueError("bad distribution")
ldat2 = ldat.reshape((ldat.shape[0],-1))
ldat2 = ldat.reshape((ldat.shape[0],
np.prod(ldat.shape[1:])))
shp2d = (x.val.shape[0], np.prod(x.val.shape[1:]))
tmp = dobj.from_local_data(shp2d, ldat2, distaxis=0)
tmp = dobj.transpose(tmp)
......
......@@ -48,50 +48,27 @@ class InversionEnabler(LinearOperator):
def op(self):
return self._op
def _times(self, x):
def _operation(self, x, o1, o2, tdom):
try:
res = self._op._times(x)
return o1(x)
except NotImplementedError:
x0 = Field.zeros(self.target, dtype=x.dtype)
(result, convergence) = self._inverter(QuadraticEnergy(
A=self._op.inverse_times,
b=x, position=x0),
preconditioner=self._preconditioner)
res = result.position
return res
x0 = Field.zeros(tdom, dtype=x.dtype)
energy = QuadraticEnergy(A=o2, b=x, position=x0)
r = self._inverter(energy, preconditioner=self._preconditioner)[0]
return r.position
def _times(self, x):
return self._operation(x, self._op._times,
self._op.inverse_times, self.target)
def _adjoint_times(self, x):
try:
res = self._op._adjoint_times(x)
except NotImplementedError:
x0 = Field.zeros(self.domain, dtype=x.dtype)
(result, convergence) = self._inverter(QuadraticEnergy(
A=self.adjoint_inverse_times,
b=x, position=x0),
preconditioner=self._preconditioner)
res = result.position
return res
return self._operation(x, self._op._adjoint_times,
self._op._adjoint_inverse_times, self.domain)
def _inverse_times(self, x):
try:
res = self._op._inverse_times(x)
except NotImplementedError:
x0 = Field.zeros(self.domain, dtype=x.dtype)
(result, convergence) = self._inverter(QuadraticEnergy(
A=self.times,
b=x, position=x0),
preconditioner=self._preconditioner)
res = result.position
return res
return self._operation(x, self._op._inverse_times,
self._op._times, self.domain)
def _adjoint_inverse_times(self, x):
try:
res = self._op._adjoint_inverse_times(x)
except NotImplementedError:
x0 = Field.zeros(self.target, dtype=x.dtype)
(result, convergence) = self._inverter(QuadraticEnergy(
A=self.adjoint_times,
b=x, position=x0),
preconditioner=self._preconditioner)
res = result.position
return res
return self._operation(x, self._op._adjoint_inverse_times,
self._op._adjoint_times, self.target)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..minimization.quadratic_energy import QuadraticEnergy
from ..field import Field
from .linear_operator import LinearOperator
class InversionEnabler(LinearOperator):
def __init__(self, op, inverter, preconditioner=None):
self._op = op
self._inverter = inverter
if preconditioner is None and hasattr(op, "preconditioner"):
self._preconditioner = op.preconditioner
else:
self._preconditioner = preconditioner
super(InversionEnabler, self).__init__()
@property
def domain(self):
return self._op.domain
@property
def target(self):
return self._op.target
@property
def unitary(self):
return self._op.unitary
@property
def op(self):
return self._op
def _operation(self, x, o1, o2, tdom):
try:
res = o1(x)
except NotImplementedError:
x0 = Field.zeros(tdom, dtype=x.dtype)
energy = QuadraticEnergy(A=o2, b=x, position=x0)
r = self._inverter(energy, preconditioner=self._preconditioner)[0]
res = r.position
return res
def _times(self, x):
return self._operation(x, self._op._times,
self._op.inverse_times, self.target)
def _adjoint_times(self, x):
return self._operation(x, self._op._adjoint_times,
self._op._adjoint_inverse_times, self.domain)
def _inverse_times(self, x):
return self._operation(x, self._op._inverse_times,
self._op._times, self.domain)
def _adjoint_inverse_times(self, x):
return self._operation(x, self._op._adjoint_inverse_times,
self._op._adjoint_times, self.target)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment