Commit 8563ad22 authored by Jakob Knollmueller's avatar Jakob Knollmueller

Merge branch 'nifty4' into 'master'

# Conflicts:
#   1d_separation.py
#   hubble_separation.py
parents 06b39731 a2eabd51
from point_separation import build_problem, problem_iteration from point_separation import build_problem, problem_iteration
from nifty2go import * import nifty4 as ift
import numpy as np import numpy as np
from matplotlib import rc from matplotlib import rc
rc('font',**{'family':'serif','serif':['Palatino']}) rc('font',**{'family':'serif','serif':['Palatino']})
...@@ -8,18 +8,15 @@ from matplotlib import pyplot as plt ...@@ -8,18 +8,15 @@ from matplotlib import pyplot as plt
np.random.seed(42) np.random.seed(42)
if __name__ == '__main__': if __name__ == '__main__':
s_space = RGSpace([1024]) s_space = ift.RGSpace([1024])
FFT = FFTOperator(s_space) h_space = s_space.get_default_codomain()
h_space = FFT.target[0] FFT = ift.FFTOperator(h_space)
p_space = PowerSpace(h_space) p_spec = lambda k: (1./(1+k)**2.5)
sp = Field(p_space, val=1./(1+p_space.k_lengths)**2.5 ) S = ift.create_power_operator(h_space, power_spectrum=p_spec)
sh = power_synthesize(sp) sh = S.draw_sample()
s = FFT.adjoint_times(sh) s = FFT(sh)
# u = np.random.exponential(10,1024) u = ift.Field(s_space, val = -12)
# u = log(Field(s_space, val = u))
u = Field(s_space, val = -12.)
u.val[200] = 1 u.val[200] = 1
u.val[300] = 3 u.val[300] = 3
u.val[500] = 4 u.val[500] = 4
...@@ -30,7 +27,7 @@ if __name__ == '__main__': ...@@ -30,7 +27,7 @@ if __name__ == '__main__':
u.val[652] = 1 u.val[652] = 1
u.val[1002] = 2.5 u.val[1002] = 2.5
d = exp(s) + exp(u) d = ift.exp(s) + ift.exp(u)
data = d.val data = d.val
energy1 = build_problem(data,1.25) energy1 = build_problem(data,1.25)
...@@ -48,17 +45,17 @@ if __name__ == '__main__': ...@@ -48,17 +45,17 @@ if __name__ == '__main__':
f, (ax0, ax1,ax2) = plt.subplots(3, sharex=True, sharey=True) f, (ax0, ax1,ax2) = plt.subplots(3, sharex=True, sharey=True)
plt.suptitle('diffuse components', size=size) plt.suptitle('diffuse components', size=size)
ax0.plot(exp(energy1.s).val, 'k-') ax0.plot(ift.exp(energy1.s).val, 'k-')
ax0.yaxis.set_label_position("right") ax0.yaxis.set_label_position("right")
ax0.set_ylabel(r'$\alpha = 1.25$', size=size) ax0.set_ylabel(r'$\alpha = 1.25$', size=size)
ax0.set_ylim(1e-1,1e3) ax0.set_ylim(1e-1,1e3)
ax0.set_yscale("log") ax0.set_yscale("log")
ax1.plot(exp(energy2.s).val, 'k-') ax1.plot(ift.exp(energy2.s).val, 'k-')
ax1.yaxis.set_label_position("right") ax1.yaxis.set_label_position("right")
ax1.set_ylabel(r'$\alpha = 1.5$', size=size) ax1.set_ylabel(r'$\alpha = 1.5$', size=size)
ax2.plot(exp(energy3.s).val, 'k-') ax2.plot(ift.exp(energy3.s).val, 'k-')
ax2.yaxis.set_label_position("right") ax2.yaxis.set_label_position("right")
ax2.set_ylabel(r'$\alpha = 1.75$', size=size) ax2.set_ylabel(r'$\alpha = 1.75$', size=size)
...@@ -69,17 +66,17 @@ if __name__ == '__main__': ...@@ -69,17 +66,17 @@ if __name__ == '__main__':
plt.suptitle('point-like components', size=size) plt.suptitle('point-like components', size=size)
ax0.plot(exp(energy1.u).val, 'k-') ax0.plot(ift.exp(energy1.u).val, 'k-')
ax0.yaxis.set_label_position("right") ax0.yaxis.set_label_position("right")
ax0.set_ylabel(r'$\alpha = 1.25$', size=size) ax0.set_ylabel(r'$\alpha = 1.25$', size=size)
ax0.set_ylim(1e-1,1e3) ax0.set_ylim(1e-1,1e3)
ax0.set_yscale("log") ax0.set_yscale("log")
ax1.plot(exp(energy2.u).val, 'k-') ax1.plot(ift.exp(energy2.u).val, 'k-')
ax1.yaxis.set_label_position("right") ax1.yaxis.set_label_position("right")
ax1.set_ylabel(r'$\alpha = 1.5$', size=size) ax1.set_ylabel(r'$\alpha = 1.5$', size=size)
ax2.plot(exp(energy3.u).val, 'k-') ax2.plot(ift.exp(energy3.u).val, 'k-')
ax2.yaxis.set_label_position("right") ax2.yaxis.set_label_position("right")
ax2.set_ylabel(r'$\alpha = 1.75$', size=size) ax2.set_ylabel(r'$\alpha = 1.75$', size=size)
...@@ -100,10 +97,10 @@ if __name__ == '__main__': ...@@ -100,10 +97,10 @@ if __name__ == '__main__':
ax0.set_ylabel(r'data', size=size) ax0.set_ylabel(r'data', size=size)
ax1.plot(exp(s).val, 'k-') ax1.plot(ift.exp(s).val, 'k-')
ax1.yaxis.set_label_position("right") ax1.yaxis.set_label_position("right")
ax1.set_ylabel(r'diffuse', size=size) ax1.set_ylabel(r'diffuse', size=size)
ax2.plot(exp(u).val, 'k-') ax2.plot(ift.exp(u).val, 'k-')
ax2.yaxis.set_label_position("right") ax2.yaxis.set_label_position("right")
ax2.set_ylabel(r'point-like', size=size) ax2.set_ylabel(r'point-like', size=size)
......
import matplotlib import matplotlib
# matplotlib.use('agg') matplotlib.use('agg')
matplotlib.use('module://kivy.garden.matplotlib.backend_kivy') # matplotlib.use('module://kivy.garden.matplotlib.backend_kivy')
from point_separation import build_problem, problem_iteration,load_data from point_separation import build_multi_problem, multi_problem_iteration,load_data
from kivy.app import App from kivy.app import App
from kivy.uix.widget import Widget from kivy.uix.widget import Widget
...@@ -12,27 +12,41 @@ from kivy.uix.image import Image ...@@ -12,27 +12,41 @@ from kivy.uix.image import Image
from kivy.properties import ObjectProperty, StringProperty, NumericProperty from kivy.properties import ObjectProperty, StringProperty, NumericProperty
from kivy.uix.boxlayout import BoxLayout from kivy.uix.boxlayout import BoxLayout
from kivy.clock import Clock, mainthread from kivy.clock import Clock, mainthread
from os.path import sep, expanduser, isdir, dirname, join
from kivy.garden.filebrowser import FileBrowser
from kivy.utils import platform
from kivy.uix.textinput import TextInput from kivy.uix.textinput import TextInput
from kivy.uix.screenmanager import ScreenManager, Screen, NoTransition from kivy.uix.screenmanager import ScreenManager, Screen, NoTransition
import numpy as np import numpy as np
import re import re
import threading import threading
# from matplotlib import pyplot as plt from matplotlib import pyplot as plt
import nifty2go as ift import nifty2go as ift
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import time import time
from kivy.uix.progressbar import ProgressBar
class FloatInput(TextInput): class FloatInput(TextInput):
pat = re.compile('[^0-9]') pat = re.compile('[^0-9]')
def insert_text(self, substring, from_undo=False): def insert_text(self, substring, from_undo=False):
print substring
pat = self.pat pat = self.pat
if '.' in self.text: if '.' in self.text:
s = re.sub(pat, '', substring) s = re.sub(pat, '', substring)
else: else:
s = '.'.join([re.sub(pat, '', s) for s in substring.split('.', 1)]) s = '.'.join([re.sub(pat, '', s) for s in substring.split('.', 1)])
return super(FloatInput, self).insert_text(s, from_undo=from_undo) return super(FloatInput, self).insert_text(s, from_undo=from_undo)
class IntInput(TextInput):
pat = re.compile('[^0-9]')
def insert_text(self, substring, from_undo=False):
pat = self.pat
s = re.sub(pat, '', substring)
return super(IntInput, self).insert_text(s, from_undo=from_undo)
class MyImage(BoxLayout): class MyImage(BoxLayout):
...@@ -45,6 +59,9 @@ class MyImage(BoxLayout): ...@@ -45,6 +59,9 @@ class MyImage(BoxLayout):
class MyAlphaWidget(BoxLayout): class MyAlphaWidget(BoxLayout):
alpha = NumericProperty(None) alpha = NumericProperty(None)
pass pass
class IterationWidget(BoxLayout):
iteration = NumericProperty(None)
pass
class ResultsPathWidget(BoxLayout): class ResultsPathWidget(BoxLayout):
pass pass
...@@ -52,10 +69,13 @@ class ResultsPathWidget(BoxLayout): ...@@ -52,10 +69,13 @@ class ResultsPathWidget(BoxLayout):
class DataPathWidget(BoxLayout): class DataPathWidget(BoxLayout):
pass pass
class ImageWidget(ScreenManager): class DisplayWidget(ScreenManager):
def reload(self): def reload(self):
for child in self.children: for child in self.children:
child.reload() child.reload()
class ImageWidget(BoxLayout):
def reload(self):
self.image_widget.reload()
class MenuWidget(BoxLayout): class MenuWidget(BoxLayout):
pass pass
...@@ -80,8 +100,19 @@ class ActionWidget(BoxLayout): ...@@ -80,8 +100,19 @@ class ActionWidget(BoxLayout):
pass pass
class DisplayChoiceWidget(BoxLayout): class DisplayChoiceWidget(BoxLayout):
pass pass
class DisplayOptionWidget(BoxLayout):
class GlobalScreenManager(ScreenManager):
pass
class MainScreen(Screen):
pass
class FileScreen(Screen):
pass pass
class PathScreen(Screen):
def is_dir(self, directory, filename):
return isdir(join(directory, filename))
class MyWidget(BoxLayout): class MyWidget(BoxLayout):
image_widget = ObjectProperty(None) image_widget = ObjectProperty(None)
...@@ -91,6 +122,15 @@ class MyWidget(BoxLayout): ...@@ -91,6 +122,15 @@ class MyWidget(BoxLayout):
result_path = StringProperty(None) result_path = StringProperty(None)
alpha = NumericProperty(None) alpha = NumericProperty(None)
class MyPathBrowser(FileBrowser):
pass
class MyFileBrowser(FileBrowser):
filters = ['*.fits', '*.png', '*.jpg']
pass
class SeparatorApp(App): class SeparatorApp(App):
stop = threading.Event() stop = threading.Event()
...@@ -103,17 +143,23 @@ class SeparatorApp(App): ...@@ -103,17 +143,23 @@ class SeparatorApp(App):
power_image = StringProperty(None) power_image = StringProperty(None)
vmin = None vmin = None
vmax = None vmax = None
iterations = 5 myEnergy = None
iterations = 3
user_path = ''
reconstructing = False
data_loaded = False
def build(self): def build(self):
self.set_default() self.set_default()
self.trigger = 0 self.trigger = 0
self.root = MyWidget() self.root = GlobalScreenManager()
self.root.image_widget.transition = NoTransition() self.image_widget = self.root.main.image_widget.image_widget.image_widget
self.root.transition = NoTransition()
self.image_widget.transition = NoTransition()
return self.root return self.root
def set_default(self): def set_default(self):
self.data_path = 'hst_05195_01_wfpc2_f702w_pc_sci.fits' self.data_path = ''
self.result_path = '' self.result_path = ''
self.alpha = 1.5 self.alpha = 1.5
self.path = '' self.path = ''
...@@ -121,68 +167,115 @@ class SeparatorApp(App): ...@@ -121,68 +167,115 @@ class SeparatorApp(App):
self.diffuse_image = self.path + 'placeholder.png' self.diffuse_image = self.path + 'placeholder.png'
self.points_image = self.path + 'placeholder.png' self.points_image = self.path + 'placeholder.png'
self.power_image = self.path + 'placeholder.png' self.power_image = self.path + 'placeholder.png'
if platform == 'win':
self.user_path = dirname(expanduser('~')) + sep + 'Documents'
else:
self.user_path = expanduser('~') + sep + 'Documents'
def load_data(self, selection):
def load_data(self): print selection
self.data_path = selection[0]
threading.Thread(target=self.load_data_thread).start() threading.Thread(target=self.load_data_thread).start()
self.root.current = 'main'
def load_data_thread(self): def load_data_thread(self):
self.data = load_data(self.data_path) self.data = load_data(self.data_path)
self.vmin = np.log(self.data.min()) self.vmin = np.log(self.data.min())
self.max = np.log(self.data.max()) self.max = np.log(self.data.max())
self.plot_array(np.log(self.data), 'data.png') self.plot_data()
self.set_data_loaded(True)
self.set_data_image() self.set_data_image()
self.update_plots() self.update_plots()
def run_separation(self): def run_separation(self):
threading.Thread(target= self.run_separation_thread).start() if not self.reconstructing and self.data_loaded:
threading.Thread(target= self.run_separation_thread).start()
def run_separation_thread(self): def run_separation_thread(self):
self.myEnergy = build_problem(self.data, self.alpha) self.set_reconstructing(True)
self.plot_array(self.myEnergy.u.val, 'points.png') self.myEnergy = build_multi_problem(self.data, self.alpha)
self.plot_array(self.myEnergy.s.val, 'diffuse.png') self.plot_components(self.path)
self.set_image_paths() self.set_image_paths()
self.update_plots() self.update_plots()
for i in range(self.iterations): for i in range(self.iterations):
self.myEnergy = problem_iteration(self.myEnergy) self.myEnergy = multi_problem_iteration(self.myEnergy)
self.plot_array(self.myEnergy.u.val, 'points.png') self.plot_components(self.path)
self.plot_array(self.myEnergy.s.val, 'diffuse.png')
self.update_plots() self.update_plots()
self.set_reconstructing(False)
def save_results(self): def save_results(self):
pass if self.myEnergy is not None:
self.root.current = 'path'
def select_path(self, path):
self.result_path = path[0]
threading.Thread(target=self.save_data_thread).start()
self.root.current = 'main'
def save_data_thread(self):
np.savetxt(self.result_path + '/points.csv', np.exp(self.myEnergy.u.val))
np.savetxt(self.result_path + '/diffus.csv', np.exp(self.myEnergy.s.val))
self.plot_components(self.result_path)
def set_result_path(self, path): def set_result_path(self, path):
self.result_path = path self.result_path = path
print path
@mainthread @mainthread
def update_plots(self): def update_plots(self):
self.root.image_widget.reload() self.image_widget.reload()
@mainthread
def set_reconstructing(self, reconstructing):
self.reconstructing = reconstructing
@mainthread
def set_data_loaded(self, loaded):
self.data_loaded = loaded
def set_data_path(self, path): def set_data_path(self, path):
self.data_path = path self.data_path = path
@mainthread @mainthread
def set_data_image(self): def set_data_image(self):
self.data_image = self.result_path + 'data.png' self.data_image = self.path + 'data.png'
@mainthread @mainthread
def set_image_paths(self): def set_image_paths(self):
self.points_image = self.result_path + 'points.png' self.points_image = self.path + 'points.png'
self.diffuse_image = self.result_path + 'diffuse.png' self.diffuse_image = self.path + 'diffuse.png'
def plot_array(self, array, path): def plot_data(self):
plt.imsave(path, array, vmin=self.vmin, vmax=self.vmax) if self.data.shape[0] == 1:
plt.imsave(self.path+'data.png', self.data[0], vmin=self.vmin, vmax=self.vmax)
else:
plt.imsave(self.path+ 'data.png', self.data/255.)
def plot_components(self, path):
diffuse = np.empty_like(self.data)
points = np.empty_like(self.data)
for i in range(len(self.myEnergy)):
diffuse[...,i] = np.exp(self.myEnergy[i].s.val)
points[...,i] = np.exp(self.myEnergy[i].u.val)
if len(self.myEnergy) == 1:
plt.imsave(path+'diffuse.png', diffuse[...,0], vmin=self.vmin, vmax=self.vmax)
plt.imsave(path+'points.png', points[...,0], vmin=self.vmin, vmax=self.vmax)
else:
plt.imsave(self.path+ 'diffuse.png', diffuse/255.)
plt.imsave(self.path+ 'points.png', points/255.)
def set_alpha(self, alpha): def set_alpha(self, alpha):
self.alpha = alpha if alpha == '':
print alpha pass
else:
self.alpha = alpha
def set_iterations(self,iterations):
if iterations == '':
pass
else:
self.iterations = int(iterations)
def on_stop(self): def on_stop(self):
self.stop.set() self.stop.set()
if __name__ == '__main__': if __name__ == '__main__':
plt.viridis() plt.gray()
SeparatorApp().run() SeparatorApp().run()
from point_separation import build_problem, problem_iteration, load_data from point_separation import build_problem, problem_iteration, load_data
from nifty2go import * from nifty4 import *
import numpy as np import numpy as np
from matplotlib import rc from matplotlib import rc
rc('font',**{'family':'serif','serif':['Palatino']}) rc('font',**{'family':'serif','serif':['Palatino']})
...@@ -10,6 +10,7 @@ from mpl_toolkits.axes_grid1 import make_axes_locatable ...@@ -10,6 +10,7 @@ from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1 import AxesGrid from mpl_toolkits.axes_grid1 import AxesGrid
np.random.seed(42) np.random.seed(42)
if __name__ == '__main__': if __name__ == '__main__':
path = 'hst_05195_01_wfpc2_f702w_pc_sci.fits' path = 'hst_05195_01_wfpc2_f702w_pc_sci.fits'
...@@ -21,7 +22,8 @@ if __name__ == '__main__': ...@@ -21,7 +22,8 @@ if __name__ == '__main__':
myEnergy = build_problem(data, alpha=alpha) myEnergy = build_problem(data, alpha=alpha)
for i in range(10): for i in range(10):
myEnergy = problem_iteration(myEnergy) myEnergy = problem_iteration(myEnergy)
A = FFTSmoothingOperator(myEnergy.s.domain, sigma=2.)
plt.magma()
size = 15 size = 15
vmin = data.min()+0.01 vmin = data.min()+0.01
vmax = 0.01*data.max() vmax = 0.01*data.max()
...@@ -42,7 +44,7 @@ if __name__ == '__main__': ...@@ -42,7 +44,7 @@ if __name__ == '__main__':
plt.axis('off') plt.axis('off')
ax = plt.gca() ax = plt.gca()
im = ax.imshow(np.exp(myEnergy.u.val), norm=LogNorm(vmin=vmin, vmax=vmax)) im = ax.imshow(A(exp(myEnergy.u)).val, norm=LogNorm(vmin=vmin, vmax=vmax))
divider = make_axes_locatable(ax) divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05) cax = divider.append_axes("right", size="5%", pad=0.05)
cbar = plt.colorbar(im, cax=cax) cbar = plt.colorbar(im, cax=cax)
...@@ -55,7 +57,8 @@ if __name__ == '__main__': ...@@ -55,7 +57,8 @@ if __name__ == '__main__':
plt.title('data', size=size) plt.title('data', size=size)
plt.axis('off') plt.axis('off')
ax = plt.gca() ax = plt.gca()
im = ax.imshow(data, norm=LogNorm(vmin=vmin, vmax=vmax)) dat = Field(myEnergy.s.domain,val=data)
im = ax.imshow((dat).val, norm=LogNorm(vmin=vmin, vmax=vmax))
divider = make_axes_locatable(ax) divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05) cax = divider.append_axes("right", size="5%", pad=0.05)
cbar = plt.colorbar(im, cax=cax) cbar = plt.colorbar(im, cax=cax)
...@@ -66,29 +69,29 @@ if __name__ == '__main__': ...@@ -66,29 +69,29 @@ if __name__ == '__main__':
plt.figure() plt.figure()
fig, ax = plt.subplots(1, 3, figsize=(6, 4)) fig, ax = plt.subplots(1, 3, figsize=(6, 3))
plt.suptitle('zoomed in section', size=size) plt.suptitle('zoomed in section', size=size)
# fig.tight_layout() # fig.tight_layout()
vmin = data.min() + 0.0001 vmin = data.min() + 0.0001
vmax = 0.001*data.max() vmax = 0.001 * data.max()
im = ax[0].imshow(data[600:700,650:750],norm=LogNorm(vmin=vmin, vmax=vmax)) im = ax[0].imshow(data[600:700, 650:750], norm=LogNorm(vmin=vmin, vmax=vmax))
ax[0].set_title('data',size = 15) ax[0].set_title('data', size=15)
ax[0].axis('off') ax[0].axis('off')
ax[1].imshow(exp(myEnergy.s).val[600:700, 650:750],norm=LogNorm(vmin=vmin, vmax=vmax)) ax[1].imshow(exp(myEnergy.s).val[600:700, 650:750], norm=LogNorm(vmin=vmin, vmax=vmax))
ax[1].set_title('diffuse',size = 15) ax[1].set_title('diffuse', size=15)
ax[1].axis('off') ax[1].axis('off')
ax[2].imshow(exp(myEnergy.u).val[600:700, 650:750],norm=LogNorm(vmin=vmin, vmax=vmax)) ax[2].imshow(exp(myEnergy.u).val[600:700, 650:750], norm=LogNorm(vmin=vmin, vmax=vmax))
ax[2].set_title('point-like',size = 15) ax[2].set_title('point-like', size=15)
ax[2].axis('off') ax[2].axis('off')
# cax = fig.add_axes([0., 0.9, 0.03, 0.8]) # cax = fig.add_axes([0., 0.9, 0.03, 0.8])
cb = fig.colorbar(im, ax=ax.ravel().tolist(), orientation='horizontal', pad = 0.01) cb = fig.colorbar(im, ax=ax.ravel().tolist(), orientation='horizontal', pad=0.01)
cb.set_label('flux', size = 15) cb.set_label('flux', size=15)
fig.subplots_adjust(left=None, bottom=None, right=None, top=None, fig.subplots_adjust(left=None, bottom=0.25, right=None, top=None,
wspace=0.01, hspace=None) wspace=0.01, hspace=None)
plt.savefig('hubble_zoom.pdf') plt.savefig('hubble_zoom.pdf')
...@@ -103,16 +106,28 @@ if __name__ == '__main__': ...@@ -103,16 +106,28 @@ if __name__ == '__main__':
cbar_pad=0.1 cbar_pad=0.1
) )
im = grid[0].imshow(data[600:700,650:750],norm=LogNorm(vmin=vmin, vmax=vmax))#, extent=extent, interpolation="not") im = grid[0].imshow(data[600:700, 650:750],
im = grid[1].imshow(exp(myEnergy.s).val[600:700, 650:750],norm=LogNorm(vmin=vmin, vmax=vmax))#, extent=extent, interpolation="not") norm=LogNorm(vmin=vmin, vmax=vmax)) # , extent=extent, interpolation="not")
im = grid[2].imshow(exp(myEnergy.u).val[600:700, 650:750],norm=LogNorm(vmin=vmin, vmax=vmax))#, extent=extent, interpolation="not") im = grid[1].imshow(exp(myEnergy.s).val[600