Commit 8563ad22 authored by Jakob Knollmueller's avatar Jakob Knollmueller

Merge branch 'nifty4' into 'master'

# Conflicts:
#   1d_separation.py
#   hubble_separation.py
parents 06b39731 a2eabd51
from point_separation import build_problem, problem_iteration
from nifty2go import *
import nifty4 as ift
import numpy as np
from matplotlib import rc
rc('font',**{'family':'serif','serif':['Palatino']})
......@@ -8,18 +8,15 @@ from matplotlib import pyplot as plt
np.random.seed(42)
if __name__ == '__main__':
s_space = RGSpace([1024])
FFT = FFTOperator(s_space)
h_space = FFT.target[0]
p_space = PowerSpace(h_space)
sp = Field(p_space, val=1./(1+p_space.k_lengths)**2.5 )
sh = power_synthesize(sp)
s = FFT.adjoint_times(sh)
# u = np.random.exponential(10,1024)
# u = log(Field(s_space, val = u))
u = Field(s_space, val = -12.)
s_space = ift.RGSpace([1024])
h_space = s_space.get_default_codomain()
FFT = ift.FFTOperator(h_space)
p_spec = lambda k: (1./(1+k)**2.5)
S = ift.create_power_operator(h_space, power_spectrum=p_spec)
sh = S.draw_sample()
s = FFT(sh)
u = ift.Field(s_space, val = -12)
u.val[200] = 1
u.val[300] = 3
u.val[500] = 4
......@@ -30,7 +27,7 @@ if __name__ == '__main__':
u.val[652] = 1
u.val[1002] = 2.5
d = exp(s) + exp(u)
d = ift.exp(s) + ift.exp(u)
data = d.val
energy1 = build_problem(data,1.25)
......@@ -48,17 +45,17 @@ if __name__ == '__main__':
f, (ax0, ax1,ax2) = plt.subplots(3, sharex=True, sharey=True)
plt.suptitle('diffuse components', size=size)
ax0.plot(exp(energy1.s).val, 'k-')
ax0.plot(ift.exp(energy1.s).val, 'k-')
ax0.yaxis.set_label_position("right")
ax0.set_ylabel(r'$\alpha = 1.25$', size=size)
ax0.set_ylim(1e-1,1e3)
ax0.set_yscale("log")
ax1.plot(exp(energy2.s).val, 'k-')
ax1.plot(ift.exp(energy2.s).val, 'k-')
ax1.yaxis.set_label_position("right")
ax1.set_ylabel(r'$\alpha = 1.5$', size=size)
ax2.plot(exp(energy3.s).val, 'k-')
ax2.plot(ift.exp(energy3.s).val, 'k-')
ax2.yaxis.set_label_position("right")
ax2.set_ylabel(r'$\alpha = 1.75$', size=size)
......@@ -69,17 +66,17 @@ if __name__ == '__main__':
plt.suptitle('point-like components', size=size)
ax0.plot(exp(energy1.u).val, 'k-')
ax0.plot(ift.exp(energy1.u).val, 'k-')
ax0.yaxis.set_label_position("right")
ax0.set_ylabel(r'$\alpha = 1.25$', size=size)
ax0.set_ylim(1e-1,1e3)
ax0.set_yscale("log")
ax1.plot(exp(energy2.u).val, 'k-')
ax1.plot(ift.exp(energy2.u).val, 'k-')
ax1.yaxis.set_label_position("right")
ax1.set_ylabel(r'$\alpha = 1.5$', size=size)
ax2.plot(exp(energy3.u).val, 'k-')
ax2.plot(ift.exp(energy3.u).val, 'k-')
ax2.yaxis.set_label_position("right")
ax2.set_ylabel(r'$\alpha = 1.75$', size=size)
......@@ -100,10 +97,10 @@ if __name__ == '__main__':
ax0.set_ylabel(r'data', size=size)
ax1.plot(exp(s).val, 'k-')
ax1.plot(ift.exp(s).val, 'k-')
ax1.yaxis.set_label_position("right")
ax1.set_ylabel(r'diffuse', size=size)
ax2.plot(exp(u).val, 'k-')
ax2.plot(ift.exp(u).val, 'k-')
ax2.yaxis.set_label_position("right")
ax2.set_ylabel(r'point-like', size=size)
......
import matplotlib
# matplotlib.use('agg')
matplotlib.use('module://kivy.garden.matplotlib.backend_kivy')
matplotlib.use('agg')
# matplotlib.use('module://kivy.garden.matplotlib.backend_kivy')
from point_separation import build_problem, problem_iteration,load_data
from point_separation import build_multi_problem, multi_problem_iteration,load_data
from kivy.app import App
from kivy.uix.widget import Widget
......@@ -12,27 +12,41 @@ from kivy.uix.image import Image
from kivy.properties import ObjectProperty, StringProperty, NumericProperty
from kivy.uix.boxlayout import BoxLayout
from kivy.clock import Clock, mainthread
from os.path import sep, expanduser, isdir, dirname, join
from kivy.garden.filebrowser import FileBrowser
from kivy.utils import platform
from kivy.uix.textinput import TextInput
from kivy.uix.screenmanager import ScreenManager, Screen, NoTransition
import numpy as np
import re
import threading
# from matplotlib import pyplot as plt
from matplotlib import pyplot as plt
import nifty2go as ift
import matplotlib.pyplot as plt
import time
from kivy.uix.progressbar import ProgressBar
class FloatInput(TextInput):
pat = re.compile('[^0-9]')
def insert_text(self, substring, from_undo=False):
print substring
pat = self.pat
if '.' in self.text:
s = re.sub(pat, '', substring)
else:
s = '.'.join([re.sub(pat, '', s) for s in substring.split('.', 1)])
return super(FloatInput, self).insert_text(s, from_undo=from_undo)
class IntInput(TextInput):
pat = re.compile('[^0-9]')
def insert_text(self, substring, from_undo=False):
pat = self.pat
s = re.sub(pat, '', substring)
return super(IntInput, self).insert_text(s, from_undo=from_undo)
class MyImage(BoxLayout):
......@@ -45,6 +59,9 @@ class MyImage(BoxLayout):
class MyAlphaWidget(BoxLayout):
alpha = NumericProperty(None)
pass
class IterationWidget(BoxLayout):
iteration = NumericProperty(None)
pass
class ResultsPathWidget(BoxLayout):
pass
......@@ -52,10 +69,13 @@ class ResultsPathWidget(BoxLayout):
class DataPathWidget(BoxLayout):
pass
class ImageWidget(ScreenManager):
class DisplayWidget(ScreenManager):
def reload(self):
for child in self.children:
child.reload()
class ImageWidget(BoxLayout):
def reload(self):
self.image_widget.reload()
class MenuWidget(BoxLayout):
pass
......@@ -80,8 +100,19 @@ class ActionWidget(BoxLayout):
pass
class DisplayChoiceWidget(BoxLayout):
pass
class DisplayOptionWidget(BoxLayout):
class GlobalScreenManager(ScreenManager):
pass
class MainScreen(Screen):
pass
class FileScreen(Screen):
pass
class PathScreen(Screen):
def is_dir(self, directory, filename):
return isdir(join(directory, filename))
class MyWidget(BoxLayout):
image_widget = ObjectProperty(None)
......@@ -91,6 +122,15 @@ class MyWidget(BoxLayout):
result_path = StringProperty(None)
alpha = NumericProperty(None)
class MyPathBrowser(FileBrowser):
pass
class MyFileBrowser(FileBrowser):
filters = ['*.fits', '*.png', '*.jpg']
pass
class SeparatorApp(App):
stop = threading.Event()
......@@ -103,17 +143,23 @@ class SeparatorApp(App):
power_image = StringProperty(None)
vmin = None
vmax = None
iterations = 5
myEnergy = None
iterations = 3
user_path = ''
reconstructing = False
data_loaded = False
def build(self):
self.set_default()
self.trigger = 0
self.root = MyWidget()
self.root.image_widget.transition = NoTransition()
self.root = GlobalScreenManager()
self.image_widget = self.root.main.image_widget.image_widget.image_widget
self.root.transition = NoTransition()
self.image_widget.transition = NoTransition()
return self.root
def set_default(self):
self.data_path = 'hst_05195_01_wfpc2_f702w_pc_sci.fits'
self.data_path = ''
self.result_path = ''
self.alpha = 1.5
self.path = ''
......@@ -121,68 +167,115 @@ class SeparatorApp(App):
self.diffuse_image = self.path + 'placeholder.png'
self.points_image = self.path + 'placeholder.png'
self.power_image = self.path + 'placeholder.png'
if platform == 'win':
self.user_path = dirname(expanduser('~')) + sep + 'Documents'
else:
self.user_path = expanduser('~') + sep + 'Documents'
def load_data(self):
def load_data(self, selection):
print selection
self.data_path = selection[0]
threading.Thread(target=self.load_data_thread).start()
self.root.current = 'main'
def load_data_thread(self):
self.data = load_data(self.data_path)
self.vmin = np.log(self.data.min())
self.max = np.log(self.data.max())
self.plot_array(np.log(self.data), 'data.png')
self.plot_data()
self.set_data_loaded(True)
self.set_data_image()
self.update_plots()
def run_separation(self):
threading.Thread(target= self.run_separation_thread).start()
if not self.reconstructing and self.data_loaded:
threading.Thread(target= self.run_separation_thread).start()
def run_separation_thread(self):
self.myEnergy = build_problem(self.data, self.alpha)
self.plot_array(self.myEnergy.u.val, 'points.png')
self.plot_array(self.myEnergy.s.val, 'diffuse.png')
self.set_reconstructing(True)
self.myEnergy = build_multi_problem(self.data, self.alpha)
self.plot_components(self.path)
self.set_image_paths()
self.update_plots()
for i in range(self.iterations):
self.myEnergy = problem_iteration(self.myEnergy)
self.plot_array(self.myEnergy.u.val, 'points.png')
self.plot_array(self.myEnergy.s.val, 'diffuse.png')
self.myEnergy = multi_problem_iteration(self.myEnergy)
self.plot_components(self.path)
self.update_plots()
self.set_reconstructing(False)
def save_results(self):
pass
if self.myEnergy is not None:
self.root.current = 'path'
def select_path(self, path):
self.result_path = path[0]
threading.Thread(target=self.save_data_thread).start()
self.root.current = 'main'
def save_data_thread(self):
np.savetxt(self.result_path + '/points.csv', np.exp(self.myEnergy.u.val))
np.savetxt(self.result_path + '/diffus.csv', np.exp(self.myEnergy.s.val))
self.plot_components(self.result_path)
def set_result_path(self, path):
self.result_path = path
print path
@mainthread
def update_plots(self):
self.root.image_widget.reload()
self.image_widget.reload()
@mainthread
def set_reconstructing(self, reconstructing):
self.reconstructing = reconstructing
@mainthread
def set_data_loaded(self, loaded):
self.data_loaded = loaded
def set_data_path(self, path):
self.data_path = path
@mainthread
def set_data_image(self):
self.data_image = self.result_path + 'data.png'
self.data_image = self.path + 'data.png'
@mainthread
def set_image_paths(self):
self.points_image = self.result_path + 'points.png'
self.diffuse_image = self.result_path + 'diffuse.png'
self.points_image = self.path + 'points.png'
self.diffuse_image = self.path + 'diffuse.png'
def plot_array(self, array, path):
plt.imsave(path, array, vmin=self.vmin, vmax=self.vmax)
def plot_data(self):
if self.data.shape[0] == 1:
plt.imsave(self.path+'data.png', self.data[0], vmin=self.vmin, vmax=self.vmax)
else:
plt.imsave(self.path+ 'data.png', self.data/255.)
def plot_components(self, path):
diffuse = np.empty_like(self.data)
points = np.empty_like(self.data)
for i in range(len(self.myEnergy)):
diffuse[...,i] = np.exp(self.myEnergy[i].s.val)
points[...,i] = np.exp(self.myEnergy[i].u.val)
if len(self.myEnergy) == 1:
plt.imsave(path+'diffuse.png', diffuse[...,0], vmin=self.vmin, vmax=self.vmax)
plt.imsave(path+'points.png', points[...,0], vmin=self.vmin, vmax=self.vmax)
else:
plt.imsave(self.path+ 'diffuse.png', diffuse/255.)
plt.imsave(self.path+ 'points.png', points/255.)
def set_alpha(self, alpha):
self.alpha = alpha
print alpha
if alpha == '':
pass
else:
self.alpha = alpha
def set_iterations(self,iterations):
if iterations == '':
pass
else:
self.iterations = int(iterations)
def on_stop(self):
self.stop.set()
if __name__ == '__main__':
plt.viridis()
plt.gray()
SeparatorApp().run()
from point_separation import build_problem, problem_iteration, load_data
from nifty2go import *
from nifty4 import *
import numpy as np
from matplotlib import rc
rc('font',**{'family':'serif','serif':['Palatino']})
......@@ -10,6 +10,7 @@ from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1 import AxesGrid
np.random.seed(42)
if __name__ == '__main__':
path = 'hst_05195_01_wfpc2_f702w_pc_sci.fits'
......@@ -21,7 +22,8 @@ if __name__ == '__main__':
myEnergy = build_problem(data, alpha=alpha)
for i in range(10):
myEnergy = problem_iteration(myEnergy)
A = FFTSmoothingOperator(myEnergy.s.domain, sigma=2.)
plt.magma()
size = 15
vmin = data.min()+0.01
vmax = 0.01*data.max()
......@@ -42,7 +44,7 @@ if __name__ == '__main__':
plt.axis('off')
ax = plt.gca()
im = ax.imshow(np.exp(myEnergy.u.val), norm=LogNorm(vmin=vmin, vmax=vmax))
im = ax.imshow(A(exp(myEnergy.u)).val, norm=LogNorm(vmin=vmin, vmax=vmax))
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
cbar = plt.colorbar(im, cax=cax)
......@@ -55,7 +57,8 @@ if __name__ == '__main__':
plt.title('data', size=size)
plt.axis('off')
ax = plt.gca()
im = ax.imshow(data, norm=LogNorm(vmin=vmin, vmax=vmax))
dat = Field(myEnergy.s.domain,val=data)
im = ax.imshow((dat).val, norm=LogNorm(vmin=vmin, vmax=vmax))
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
cbar = plt.colorbar(im, cax=cax)
......@@ -66,29 +69,29 @@ if __name__ == '__main__':
plt.figure()
fig, ax = plt.subplots(1, 3, figsize=(6, 4))
fig, ax = plt.subplots(1, 3, figsize=(6, 3))
plt.suptitle('zoomed in section', size=size)
# fig.tight_layout()
vmin = data.min() + 0.0001
vmax = 0.001*data.max()
im = ax[0].imshow(data[600:700,650:750],norm=LogNorm(vmin=vmin, vmax=vmax))
ax[0].set_title('data',size = 15)
vmax = 0.001 * data.max()
im = ax[0].imshow(data[600:700, 650:750], norm=LogNorm(vmin=vmin, vmax=vmax))
ax[0].set_title('data', size=15)
ax[0].axis('off')
ax[1].imshow(exp(myEnergy.s).val[600:700, 650:750],norm=LogNorm(vmin=vmin, vmax=vmax))
ax[1].set_title('diffuse',size = 15)
ax[1].imshow(exp(myEnergy.s).val[600:700, 650:750], norm=LogNorm(vmin=vmin, vmax=vmax))
ax[1].set_title('diffuse', size=15)
ax[1].axis('off')
ax[2].imshow(exp(myEnergy.u).val[600:700, 650:750],norm=LogNorm(vmin=vmin, vmax=vmax))
ax[2].set_title('point-like',size = 15)
ax[2].imshow(exp(myEnergy.u).val[600:700, 650:750], norm=LogNorm(vmin=vmin, vmax=vmax))
ax[2].set_title('point-like', size=15)
ax[2].axis('off')
# cax = fig.add_axes([0., 0.9, 0.03, 0.8])
cb = fig.colorbar(im, ax=ax.ravel().tolist(), orientation='horizontal', pad = 0.01)
cb.set_label('flux', size = 15)
fig.subplots_adjust(left=None, bottom=None, right=None, top=None,
wspace=0.01, hspace=None)
cb = fig.colorbar(im, ax=ax.ravel().tolist(), orientation='horizontal', pad=0.01)
cb.set_label('flux', size=15)
fig.subplots_adjust(left=None, bottom=0.25, right=None, top=None,
wspace=0.01, hspace=None)
plt.savefig('hubble_zoom.pdf')
......@@ -103,16 +106,28 @@ if __name__ == '__main__':
cbar_pad=0.1
)
im = grid[0].imshow(data[600:700,650:750],norm=LogNorm(vmin=vmin, vmax=vmax))#, extent=extent, interpolation="not")
im = grid[1].imshow(exp(myEnergy.s).val[600:700, 650:750],norm=LogNorm(vmin=vmin, vmax=vmax))#, extent=extent, interpolation="not")
im = grid[2].imshow(exp(myEnergy.u).val[600:700, 650:750],norm=LogNorm(vmin=vmin, vmax=vmax))#, extent=extent, interpolation="not")
im = grid[0].imshow(data[600:700, 650:750],
norm=LogNorm(vmin=vmin, vmax=vmax)) # , extent=extent, interpolation="not")
im = grid[1].imshow(exp(myEnergy.s).val[600:700, 650:750],
norm=LogNorm(vmin=vmin, vmax=vmax)) # , extent=extent, interpolation="not")
im = grid[2].imshow(exp(myEnergy.u).val[600:700, 650:750],
norm=LogNorm(vmin=vmin, vmax=vmax)) # , extent=extent, interpolation="not")
grid[0].axis('off')
grid[1].axis('off')
grid[2].axis('off')
grid[0].set_label('data')
#plt.colorbar(im, cax = grid.cbar_axes[0])
# plt.colorbar(im, cax = grid.cbar_axes[0])
cb = grid.cbar_axes[0].colorbar(im)
for cax in grid.cbar_axes:
cax.toggle_label(True)
plt.close()
plt.figure()
power = power_analyze(exp(myEnergy.s))
k_lengths = power.domain.k_lenghts
plt.plot(power.val, k_lengths, 'k-')
plt.yscale('log')
plt.xscale('log')
plt.title('diffuse power')
from nifty2go import *
import nifty4 as ift
import numpy as np
# from matplotlib import pyplot as plt
from matplotlib import pyplot as plt
from astropy.io import fits
from separation_energy import SeparationEnergy
from nifty2go.library.nonlinearities import PositiveTanh
from nifty4.library.nonlinearities import PositiveTanh
def load_data(path):
# if path[-5:] == '.fits':
data = fits.open(path)[1].data
# else:
# data = plt.imread(path)
if path[-5:] == '.fits':
data = fits.open(path)[1].data
else:
data = plt.imread(path)[:,:,0]
data = data.clip(min=0.001)
data = np.ndarray.astype(data, float)
return data
def build_problem(data, alpha):
s_space = RGSpace(data.shape, distances=len(data.shape) * [1])
data = Field(s_space,val=data)
FFT = FFTOperator(s_space)
h_space = FFT.target[0]
binbounds = PowerSpace.useful_binbounds(h_space, logarithmic = False)
p_space = PowerSpace(h_space, binbounds=binbounds)
initial_spectrum = power_analyze(FFT(log(data)), binbounds=p_space.binbounds)
initial_correlation = create_power_operator(h_space, initial_spectrum)
initial_x = Field(s_space, val=-1.)
alpha = Field(s_space, val=alpha)
q = Field(s_space, val=10e-40)
s_space = ift.RGSpace(data.shape, distances=len(data.shape) * [1])
h_space = s_space.get_default_codomain()
data = ift.Field(s_space,val=data)
FFT = ift.FFTOperator(h_space)
binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic = False)
p_space = ift.PowerSpace(h_space, binbounds=binbounds)
initial_spectrum = ift.power_analyze(FFT.inverse_times(ift.log(data)), binbounds=p_space.binbounds)
initial_correlation = ift.create_power_operator(h_space, initial_spectrum)
initial_x = ift.Field(s_space, val=-1.)
alpha = ift.Field(s_space, val=alpha)
q = ift.Field(s_space, val=10e-40)
pos_tanh = PositiveTanh()
ICI = GradientNormController(verbose=False, name="ICI",
iteration_limit=500,
ICI = ift.GradientNormController(iteration_limit=500,
tol_abs_gradnorm=1e-5)
inverter = ConjugateGradient(controller=ICI)
inverter = ift.ConjugateGradient(controller=ICI)
parameters = dict(data=data, correlation=initial_correlation,
alpha=alpha, q=q,
......@@ -41,15 +40,15 @@ def build_problem(data, alpha):
return separationEnergy
def problem_iteration(energy, iterations=3):
controller = GradientNormController(verbose=True, tol_abs_gradnorm=0.00000001, iteration_limit=iterations)
minimizer = RelaxedNewton(controller=controller)
controller = ift.GradientNormController(name="test1", tol_abs_gradnorm=0.00000001, iteration_limit=iterations)
minimizer = ift.RelaxedNewton(controller=controller)
energy, convergence = minimizer(energy)
new_position = energy.position
h_space = energy.correlation.domain[0]
FFT = energy.FFT
binbounds = PowerSpace.useful_binbounds(h_space, logarithmic=False)
new_power = power_analyze(FFT(energy.s), binbounds=binbounds)
new_correlation = create_power_operator(h_space, new_power)
binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic=False)
new_power = ift.power_analyze(FFT.inverse_times(energy.s), binbounds=binbounds)
new_correlation = ift.create_power_operator(h_space, new_power)
new_parameters = energy.parameters
new_parameters['correlation'] = new_correlation
new_energy = SeparationEnergy(new_position, new_parameters)
......
from nifty2go import Energy, Field, log, exp, DiagonalOperator
from nifty2go.library import WienerFilterCurvature
from nifty4 import Energy, Field, log, exp, DiagonalOperator
from nifty4.library import WienerFilterCurvature
class SeparationEnergy(Energy):
......@@ -19,7 +19,7 @@ class SeparationEnergy(Energy):
self.q = parameters['q']
pos_tanh = parameters['pos_tanh']
self.S = self.FFT.adjoint * self.correlation * self.FFT
self.S = self.FFT * self.correlation * self.FFT.adjoint
self.a = pos_tanh(self.position)
self.a_p = pos_tanh.derivative(self.position)
......@@ -52,7 +52,7 @@ class SeparationEnergy(Energy):
@property
def curvature(self):
point = self.q * exp(-self.u) * self.u_p ** 2
R = self.FFT * self.s_p
R = self.FFT.inverse * self.s_p
N = self.correlation
S = DiagonalOperator(1/(point + 1/self.var_x))
return WienerFilterCurvature(R=R, N=N, S=S, inverter=self.inverter)
......@@ -27,30 +27,19 @@ TextInput:
text: unichr(945)
FloatInput
text: '1.5'
on_text: app.set_alpha(float(self.text))
<ResultsPathWidget>:
orientation:'horizontal'
Label:
text:'path to results'
TextInput:
text: ''
multiline: False
on_text: app.set_result_path(self.text)
<DataPathWidget>:
orientation:'horizontal'
on_text: app.set_alpha(self.text)
<IterationWidget>:
orientation: 'horizontal'
Label:
text: 'location of the data'
TextInput:
text: 'hst_05195_01_wfpc2_f702w_pc_sci.fits'
multiline: False
on_text: app.set_data_path(self.text)
text: 'iterations'
IntInput
text: '3'
on_text: app.set_iterations(self.text)
<FloatInput>:
multiline: False
<IntInput>:
multiline: False
<MyImage>:
text: self.text
......@@ -94,7 +83,7 @@ TextInput:
id: power
source: app.power_image
text: 'power'
<ImageWidget>
<DisplayWidget>
all: all
points: points
diffuse: diffuse
......@@ -154,13 +143,9 @@ TextInput:
orientation: 'vertical'
MyAlphaWidget:
size_hint: 1,0.1
DataPathWidget:
size_hint: 1,0.1
text: 'location of the data'
ResultsPathWidget:
IterationWidget:
size_hint: 1,0.1
text: 'path to results'
DisplayOptionWidget
ActionWidget
......@@ -168,39 +153,84 @@ TextInput:
orientation: 'horizontal'
Button:
text: 'load data'
on_press: app.load_data()
on_press: app.root.current = 'file'
Button:
text: 'run separation'
on_press: app.run_separation()
Button:
text: 'save results'
on_press: app.save_results()