Commit b4016a99 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

cosmetics

parents 3124e8ad 06b39731
......@@ -16,7 +16,8 @@ if __name__ == '__main__':
sh = S.draw_sample()
s = FFT(sh)
u = ift.Field(s_space, val = -12)
u = ift.Field(s_space, val = -12.)
u.val[200] = 1
u.val[300] = 3
u.val[500] = 4
......
......@@ -10,16 +10,15 @@ from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1 import AxesGrid
np.random.seed(42)
if __name__ == '__main__':
path = 'hst_05195_01_wfpc2_f702w_pc_sci.fits'
data = load_data(path)
alpha = 1.3
myEnergy = build_problem(data, alpha=alpha, BFGS=True)
myEnergy = build_problem(data, alpha=alpha, BFGS=False)
for i in range(10):
myEnergy = problem_iteration(myEnergy, iterations=10, BFGS=True)
myEnergy = problem_iteration(myEnergy, iterations=10, BFGS=False)
A = ift.FFTSmoothingOperator(myEnergy.position.domain, sigma=2.)
plt.magma()
size = 15
......@@ -42,7 +41,7 @@ if __name__ == '__main__':
plt.axis('off')
ax = plt.gca()
im = ax.imshow(A(myEnergy.point_like).val, norm=LogNorm(vmin=vmin, vmax=vmax))
im = ax.imshow(myEnergy.point_like.val, norm=LogNorm(vmin=vmin, vmax=vmax))
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
cbar = plt.colorbar(im, cax=cax)
......@@ -67,7 +66,7 @@ if __name__ == '__main__':
plt.figure()
fig, ax = plt.subplots(1, 3, figsize=(6, 3))
fig, ax = plt.subplots(1, 3, figsize=(6, 4))
plt.suptitle('zoomed in section', size=size)
# fig.tight_layout()
......@@ -120,12 +119,3 @@ if __name__ == '__main__':
for cax in grid.cbar_axes:
cax.toggle_label(True)
plt.close()
plt.figure()
power = ift.power_analyze(myEnergy.diffuse)
k_lengths = power.domain.k_lenghts
plt.plot(power.val, k_lengths, 'k-')
plt.yscale('log')
plt.xscale('log')
plt.title('diffuse power')
......@@ -5,7 +5,7 @@ from nifty4.library.nonlinearities import PositiveTanh
class SeparationEnergy(ift.Energy):
def __init__(self, position, data, alpha, correlation, inverter=None):
if (position>9.).any() or (position<-9.).any():
if (position > 9.).any() or (position < -9.).any():
raise ValueError("position outside allowed range")
super(SeparationEnergy, self).__init__(position=position)
......@@ -14,7 +14,10 @@ class SeparationEnergy(ift.Energy):
self._inverter = inverter
self._q = 1e-40
h_space = correlation.domain[0] if correlation is not None else position.domain[0].get_default_codomain()
if correlation is None:
h_space = position.domain[0].get_default_codomain()
else:
h_space = correlation.domain[0]
FFT = ift.FFTOperator(h_space, position.domain[0])
self._ptanh = PositiveTanh()
......@@ -24,8 +27,10 @@ class SeparationEnergy(ift.Energy):
s = ift.log(data * one_m_a)
if correlation is None:
binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic=False)
correlation = ift.power_analyze(FFT.inverse_times(s), binbounds=binbounds)
binbounds = ift.PowerSpace.useful_binbounds(h_space,
logarithmic=False)
correlation = ift.power_analyze(FFT.inverse_times(s),
binbounds=binbounds)
correlation = ift.create_power_operator(h_space, correlation)
self._correlation = correlation
......@@ -45,21 +50,23 @@ class SeparationEnergy(ift.Energy):
s_p = -a_p/one_m_a
diffuse = Sis * s_p
point = (alpha - 1. - qexpmu) * a_p/a
u_p = a_p/a
point = (alpha - 1. - qexpmu) * u_p
det = position / var_x
det += s_p
self._gradient = (diffuse + point + det).lock()
if inverter is not None: # curvature is needed, remember some values
self._qexpmu = qexpmu
self._s_p = s_p
self._FFT = FFT
self._var_x = var_x
self._point = point
self._u_p = a_p/a
if inverter is not None: # curvature is needed
point = qexpmu * u_p ** 2
R = FFT.inverse * s_p
N = self._correlation
S = ift.DiagonalOperator(1./(point + 1./var_x))
self._curvature = ift.library.WienerFilterCurvature(
R=R, N=N, S=S, inverter=self._inverter)
def at(self, position):
return self.__class__(position, self._data, self._alpha, self._correlation, self._inverter)
return self.__class__(position, self._data, self._alpha,
self._correlation, self._inverter)
@property
def diffuse(self):
......@@ -74,7 +81,8 @@ class SeparationEnergy(ift.Energy):
return SeparationEnergy(position, data, alpha, None, inverter)
def with_new_correlation(self):
return SeparationEnergy(self._position, self._data, self._alpha, None, self._inverter)
return SeparationEnergy(self._position, self._data, self._alpha, None,
self._inverter)
@property
def value(self):
......@@ -86,16 +94,11 @@ class SeparationEnergy(ift.Energy):
@property
def curvature(self):
point = self._qexpmu * self._u_p ** 2
R = self._FFT.inverse * self._s_p
N = self._correlation
S = ift.DiagonalOperator(1/(point + 1/self._var_x))
return ift.library.WienerFilterCurvature(R=R, N=N, S=S,
inverter=self._inverter)
return self._curvature
def longest_step(self, dir):
p = self.position.to_global_data()
d = dir.to_global_data()
lim = np.where(d>0, 9, -9)
p = self.position.local_data
d = dir.local_data
lim = np.where(d > 0, 9, -9)
dist = (lim-p)/d
return np.min(dist)
return ift.Field.from_local_data(self.position.domain, dist).min()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment