Commit bad181a2 authored by Jakob Knollmueller's avatar Jakob Knollmueller
Browse files

working prototype

parent 240a3623
import matplotlib
matplotlib.use('agg')
from kivy.app import App
from kivy.uix.widget import Widget
from kivy.uix.button import Button
from kivy.uix.label import Label
from kivy.uix.image import Image
from kivy.properties import ObjectProperty, StringProperty, NumericProperty
from kivy.uix.boxlayout import BoxLayout
from kivy.clock import Clock
from kivy.uix.textinput import TextInput
import re
import threading
from point_separation import build_problem, problem_iteration,load_data
# from matplotlib import pyplot as plt
import nifty2go as ift
import matplotlib.pyplot as plt
class FloatInput(TextInput):
pat = re.compile('[^0-9]')
def insert_text(self, substring, from_undo=False):
pat = self.pat
if '.' in self.text:
s = re.sub(pat, '', substring)
else:
s = '.'.join([re.sub(pat, '', s) for s in substring.split('.', 1)])
return super(FloatInput, self).insert_text(s, from_undo=from_undo)
class MyImage(BoxLayout):
source = StringProperty('')
text = StringProperty('')
img = ObjectProperty(None)
def reload(self):
self.img.reload()
class MyAlphaWidget(BoxLayout):
alpha = NumericProperty(None)
pass
class StoreResultsWidget(BoxLayout):
pass
class LoadDataWidget(BoxLayout):
pass
class ImageWidget(BoxLayout):
data_image = ObjectProperty(None)
diffuse_image = ObjectProperty(None)
point_image = ObjectProperty(None)
power_image = ObjectProperty(None)
def reload_images(self):
self.diffuse_image.reload()
self.point_image.reload()
self.power_image.reload()
class MenuWidget(BoxLayout):
pass
class MyWidget(BoxLayout):
image_widget = ObjectProperty(None)
menu_widget = ObjectProperty(None)
data_path = StringProperty(None)
result_path = StringProperty(None)
alpha = NumericProperty(None)
def reload_images(self):
self.image_widget.reload_images()
def set_result_path(self, path):
self.result_path = path
print path
def set_data_path(self, path):
self.data_path = path
data = load_data(self.data_path)
self.myEnergy = build_problem(data, self.alpha)
self.plotting()
print path
def set_alpha(self, alpha):
self.alpha = alpha
print alpha
def plotting(self):
plt.viridis()
plt.imsave(self.result_path+'points0.png', self.myEnergy.u.val)
plt.imsave(self.result_path+'maps0.png', (self.myEnergy.s).val)
plt.imsave(self.result_path+'data0.png', ift.log(self.myEnergy.d).val)
self.reload_images()
def run_separation(self, ini):
if ini:
self.set_data_path(self.data_path)
self.myEnergy = problem_iteration(self.myEnergy)
self.plotting()
class SeparatorApp(App):
def build(self):
self.trigger = 0
self.root = MyWidget()
Clock.schedule_interval(self.update, 1)
return self.root
def update(self,*args):
self.trigger += 1
if self.trigger < 0:
self.root.run_separation(self.trigger==-10)
if __name__ == '__main__':
MyApp().run()
\ No newline at end of file
SeparatorApp().run()
from nifty2go import *
from nifty2go.library.nonlinearities import PositiveTanh
from nifty2go.library.wiener_filter_curvature import WienerFilterCurvature
import numpy as np
from matplotlib import pyplot as plt
from scipy.misc import imresize
# from matplotlib import pyplot as plt
from astropy.io import fits
from separation_energy import SeparationEnergy
from nifty2go.library.nonlinearities import PositiveTanh
# class MyOtherEnergy(Energy):
# def __init__(self, position, d, S, alpha, q, inverter):
# x = position.val.clip(-9,9)
# position = Field(position.domain,val=x)
# self.inverter = inverter
# super(MyOtherEnergy, self).__init__(position=position)
# self.d = d
# self.S = S
# self.alpha = alpha
# self.q = q
# self.a = PositiveTanh(self.position)
# self.a_deriv = PositiveTanh.derivative(self.position)
# self.u = log(self.d*self.a)
# self.s = log(self.d*(1-self.a))
# self.tanh_x = 2 * self.a - 1
# self.deriv_tanh_x = 2 * self.a_deriv
#
# def at(self, position):
# return self.__class__(position, d=self.d, S=self.S, alpha=self.alpha, q=self.q, inverter = self.inverter)
#
# @property
# def value(self):
# diffuse = 0.5 * self.s.vdot(self.S.inverse(self.s))
# point = (1-self.alpha).vdot(self.u) + self.q.vdot(exp(-self.u))
# det = - log(1-self.a).integrate()
# det += log(self.a_deriv).integrate()
#
# return diffuse + point + det
# @property
# def gradient(self):
# diffuse = - self.S.inverse(self.s)/((1-self.a)*self.d)*self.d*self.a_deriv
# point = - self.q * exp(-self.u)/(self.a*self.d)*self.d*self.a_deriv
# point += (1-self.alpha) /(self.a*self.d)*self.d*self.a_deriv
# det = 1./(1-self.a)*self.a_deriv
# det += -1./(self.a_deriv)*self.tanh_x * self.deriv_tanh_x
# return diffuse + point + det
class MyOtherEnergy(Energy):
def __init__(self, position, d, Sh, alpha, q, inverter, FFT):
x = position.val.clip(-9,9)
position = Field(position.domain,val=x)
self.inverter = inverter
super(MyOtherEnergy, self).__init__(position=position)
self.d = d
self.S = FFT.adjoint * Sh * FFT
self.Sh = Sh
self.FFT = FFT
self.alpha = alpha
self.q = q
self.a = PositiveTanh(self.position)
self.a_p = PositiveTanh.derivative(self.position)
self.tanh_x = 2 * self.a - 1
self.tanh_x_p = 2 * self.a_p
self.a_pp = - self.tanh_x *self.tanh_x_p
self.a_ppp = (3 * self.tanh_x ** 2 - 1) * self.tanh_x_p
self.u = log(self.d*self.a)
self.u_p = self.a_p/self.a
self.u_pp = self.a_pp/self.a - self.u_p ** 2
self.s = log(self.d*(1-self.a))
self.s_p = - self.a_p/(1-self.a)
self.s_pp = self.a_pp/(1-self.a) + self.s_p ** 2
def at(self, position):
return self.__class__(position, d=self.d, Sh=self.Sh,
alpha=self.alpha, q=self.q, inverter=self.inverter,
FFT = self.FFT)
@property
def value(self):
diffuse = 0.5 * self.s.vdot(self.S.inverse(self.s))
point = (-1+self.alpha).weight(0).vdot(self.u) + self.q.weight(0).vdot(exp(-self.u))
det = log(1-self.a).integrate()
det += 0.5/9.*self.position.vdot(self.position)
print diffuse + point + det
return diffuse + point + det
@property
def gradient(self):
diffuse = self.S.inverse(self.s) * self.s_p
point = (-1 + self.alpha).weight(0) * self.u_p - self.q.weight(0) * exp(-self.u) * self.u_p
det = self.position/9.
det +=- 1./(1-self.a) * self.a_p
return (diffuse + point + det)
@property
def curvature(self):
diffuse = self.s_p * self.S.inverse * self.s_p #+ self.s * self.S.inverse * self.s_pp
point = self.q.weight(0) * exp(-self.u) * self.u_p ** 2
# point += (1 - self.alpha - self.q * exp(-self.u)) * self.u_pp
# det = self.s_pp**2 #- (self.a_pp / self.a_p)**2# + self.a_ppp/self.a_p
curv = diffuse + (point )#+ det)
curv = InversionEnabler(curv, self.inverter, self.S.times)
R = self.FFT * self.s_p
N = self.Sh
S = DiagonalOperator(1/(point + 1/9. ))
return WienerFilterCurvature(R=R,N=N,S=S, inverter=self.inverter)
if __name__ == '__main__':
PositiveTanh = PositiveTanh()
# dd = plt.imread('IC1396.jpg') + 0.01
# dd = plt.imread('Andromeda.jpg') + 0.01
# dd = plt.imread('Andromeda_large.png')*255 + 0.01
# dd = plt.imread('andromeda.jpeg') + 0.01
# dd = plt.imread('galaxy.jpg') + 0.01
# dd = plt.imread('M52.jpg') + 0.01
# dd = plt.imread('m51_3.jpg')
# dd = plt.imread('M16.jpg') + 0.01
# dd = plt.imread('m31_wide.jpg') + 0.01
# dd = fits.open('hubble_m51/hlsp_legus_hst_acs_ngc5194-sw_f555w_v1_drc.fits')[0].data + 0.1
dd = fits.open('hst_05195_01_wfpc2_f702w_pc_sci.fits')[1].data.clip(min=0.01) #M100
# dd = fits.open('hst_10402_06_wfpc2_f336w_wf_sci.fits')[1].data.clip(min=0.01) #M94
def load_data(path):
# dd = imresize(dd,20)+0.01
# if path[-5:] == '.fits':
data = fits.open(path)[1].data
# else:
# data = plt.imread(path)
dd = np.ndarray.astype(dd, float)
# dd = dd[1000:2000,1000:2000]
d0 = dd
d1 = dd
d2 = dd
# d0 = dd[:,:,0]
# d1 = dd[:,:,1]
# d2 = dd[:,:,2]
s_space = RGSpace(d0.shape,distances=[1,1])
# s_space = RGSpace([256,256])
# s_space = RGSpace([529,660])
# s_space = RGSpace([2000,3000])
# s_space = RGSpace([494,782])
d0 = Field(s_space,val=d0)
d1 = Field(s_space,val=d1)
d2 = Field(s_space,val=d2)
data = data.clip(min=0.001)
data = np.ndarray.astype(data, float)
return data
def build_problem(data, alpha):
s_space = RGSpace(data.shape, distances=[1, 1])
data = Field(s_space,val=data)
FFT = FFTOperator(s_space)
h_space = FFT.target[0]
iFFT = FFTOperator(h_space)
binbounds = PowerSpace.useful_binbounds(h_space, logarithmic = False)
p_space = PowerSpace(h_space, binbounds=binbounds)
k_lengths = p_space.k_lengths
p_d0 = power_analyze(FFT(log(d0)), binbounds=p_space.binbounds)
spectrum = p_d0
# spectrum = 1e-2/(k_lengths**2+1)
# spectrum = Field(p_space,val=spectrum)
Sh = create_power_operator(h_space, spectrum)
sh = power_synthesize(spectrum)
s = FFT.adjoint_times(sh)
u = Field.from_random('normal', s_space)*1
# points = np.zeros(s_space.shape)
# points[128,128] = 4.
# points[64,128] = 3.
# points[133,200] = 2.
#
#
# u = Field(s_space,val=points)
d = exp(s) + exp(u)
# d = plt.imread('m51_3.jpg')[:,:,0]
# d0=d
# d1=d
# d2=d
d0 = Field(s_space,val=d0)
d1 = Field(s_space,val=d1)
d2 = Field(s_space,val=d2)
plt.imsave('data0.png', log(d0).val)
# d = Field(s_space,val=d)
# plt.figure()
# plt.plot(d.val,'kx')
# plt.savefig('data.png')
# plt.imsave('data.png', (d).val)
#
# plt.close()
S = FFT.inverse * Sh * FFT
ICI = GradientNormController(verbose=True, name="ICI",
initial_spectrum = power_analyze(FFT(log(data)), binbounds=p_space.binbounds)
initial_correlation = create_power_operator(h_space, initial_spectrum)
initial_x = Field(s_space, val=-1.)
alpha = Field(s_space, val=alpha)
q = Field(s_space, val=10e-40)
pos_tanh = PositiveTanh()
ICI = GradientNormController(verbose=False, name="ICI",
iteration_limit=500,
tol_abs_gradnorm=1e-5)
inverter = ConjugateGradient(controller=ICI)
parameters = dict(data=data, correlation=initial_correlation,
alpha=alpha, q=q,
inverter=inverter, FFT=FFT, pos_tanh=pos_tanh)
separationEnergy = SeparationEnergy(position=initial_x, parameters=parameters)
return separationEnergy
def problem_iteration(energy):
controller = GradientNormController(verbose=True, tol_abs_gradnorm=0.00000001, iteration_limit=1)
minimizer = RelaxedNewton(controller=controller)
energy, convergence = minimizer(energy)
new_position = energy.position
h_space = energy.correlation.domain[0]
FFT = energy.FFT
binbounds = PowerSpace.useful_binbounds(h_space, logarithmic=False)
new_power = power_analyze(FFT(energy.s), binbounds=binbounds)
new_correlation = create_power_operator(h_space, new_power)
new_parameters = energy.parameters
new_parameters['correlation'] = new_correlation
new_energy = SeparationEnergy(new_position, new_parameters)
return new_energy
controller1 = GradientNormController(verbose=True, tol_abs_gradnorm=0.00000001, iteration_limit=10)
controller2 = GradientNormController(verbose=True, tol_abs_gradnorm=0.00000001, iteration_limit=3)
# wolfe = LineSearchStrongWolfe(max_zoom_iterations=3)
# wolfe.preferred_initial_step_size = 1e-10
minimizer1 = VL_BFGS(controller=controller1,max_history_length=10)#, line_searcher=wolfe)
# minimizer1 = SteepestDescent(controller=controller1)
# minimizer1 = RelaxedNewton(controller=controller1)
# minimizer1 = VL_BFGS(controller=controller1, max_history_length=1)
# x = Field.from_random('normal',s_space)
x = Field(s_space, val=-1.)
alpha = Field(s_space,val=1.3)
q = Field(s_space, val=1e-40)
if __name__ == '__main__':
path = 'hst_05195_01_wfpc2_f702w_pc_sci.fits'
data = load_data(path)
alpha = 1.5
# myEnergy = MyOtherEnergy(x,d,S,alpha,q)
myEnergy0 = MyOtherEnergy(x.copy(),d0,Sh,alpha,q, inverter, FFT)
myEnergy1 = MyOtherEnergy(x.copy(),d1,Sh,alpha,q, inverter, FFT)
myEnergy = build_problem(data, alpha=alpha)
myEnergy2 = MyOtherEnergy(x.copy(),d2,Sh,alpha,q, inverter, FFT)
for i in range( 100):
if i%1 == 0:
minimizer1 = RelaxedNewton(controller=controller2)
else:
minimizer1 = VL_BFGS(controller=controller1, max_history_length=10) # , line_searcher=wolfe)
# minimizer1 = SteepestDescent(controller=controller1)
myEnergy0, convergence = minimizer1(myEnergy0)
my_s0 = FFT(log((1 - myEnergy0.a) * d0))
my_p0 = power_analyze(my_s0,binbounds=p_space.binbounds)
# my_p0.val[-1] = my_p0.val[-2]
# my_p0.val[:] = 1e3/(k_lengths**1+1)
print i
my_S0 = create_power_operator(h_space,my_p0)
my_S0 = Sh
myEnergy0 = MyOtherEnergy(myEnergy0.position,d0,my_S0,alpha,q, inverter, FFT)
# myEnergy1, convergence = minimizer1(myEnergy1)
my_s1 = FFT(log((1 - myEnergy1.a) * d1))
my_p1 = power_analyze(my_s1,binbounds=p_space.binbounds)
# my_p1.val[-1] = my_p1.val[-2]
# my_p1.val[:] = 1e-3/(k_lengths**1+1)
# my_p1=spectrum
my_S1 = create_power_operator(h_space,my_p1)
my_S1 = Sh
myEnergy1 = MyOtherEnergy(myEnergy1.position,d1,my_S1,alpha,q, inverter, FFT)
# myEnergy2, convergence = minimizer1(myEnergy2)
my_s2 = FFT(log((1 - myEnergy2.a) * d2))
# samples = []
# collector = Field.zeros(p_space)
# for i in range(3):
# sample =(myEnergy2.curvature.generate_posterior_sample2())
# sample = Field(sample.domain,val=sample.val.clip(-9, 9))
# sample += myEnergy2.position
#
# collector += power_analyze( FFT(log((1 - PositiveTanh(sample)) * d2)),
# binbounds=binbounds)
# collector /= 3.
# my_p2 =Field(p_space,val=1e1/(k_lengths**4+1))
# my_p2.val[:] = 1e-3/(k_lengths**2+1)
# my_p2.val[-1] = my_p2.val[-2]
my_p2 = power_analyze(my_s2,binbounds=p_space.binbounds)
my_S2 = create_power_operator(h_space,my_p2)
my_S2 = Sh
myEnergy2 = MyOtherEnergy(myEnergy2.position,d2,my_S2,alpha,q, inverter, FFT)
# myEnergy = MyOtherEnergy(x, d, S, alpha, q)
# myEnergy0, convergence = minimizer1(myEnergy0)
# myEnergy1, convergence = minimizer1(myEnergy1)
# myEnergy2, convergence = minimizer1(myEnergy2)
plt.imsave('points0.png',log(myEnergy0.a*d0).val)
plt.imsave('maps0.png',(log((1-myEnergy0.a)*d0)).val)
plt.imsave('data0.png',log(d0).val)
plt.imsave('data1.png',(d1).val)
plt.imsave('points1.png',(myEnergy1.a*d1).val)
plt.imsave('maps1.png',((1-myEnergy1.a)*d1).val)
plt.imsave('data2.png',(d2).val)
plt.imsave('points2.png',(myEnergy2.a*d2).val)
plt.imsave('maps2.png',((1-myEnergy2.a)*d2).val)
# maps = np.zeros_like(dd)
# maps[:,:,0] = ((1-myEnergy0.a)*d0).val
# maps[:,:,1] = ((1-myEnergy1.a)*d1).val
# maps[:,:,2] = ((1-myEnergy2.a)*d2).val
# plt.imsave('maps.png', maps /255.)
#
# points = np.zeros_like(dd)
# points[:,:,0] = ((myEnergy0.a)*d0).val
# points[:,:,1] = ((myEnergy1.a)*d1).val
# points[:,:,2] = ((myEnergy2.a)*d2).val
# plt.imsave('points.png', points / 255.)
plt.figure()
plt.yscale('log')
plt.xscale('log')
plt.plot(spectrum.val, 'k-')
plt.plot(power_analyze(my_s0,binbounds=p_space.binbounds).val, 'r-', alpha = 0.5)
plt.plot(power_analyze(my_s1,binbounds=p_space.binbounds).val, 'g-', alpha = 0.5)
plt.plot(power_analyze(my_s2,binbounds=p_space.binbounds).val, 'b-', alpha = 0.5)
p_d0 = power_analyze(FFT(log(d0)), binbounds=p_space.binbounds).val
plt.plot(p_d0,'y--',alpha=0.5)
p_d1 = power_analyze(FFT(log(d1)), binbounds=p_space.binbounds).val
plt.plot(p_d1,'y--',alpha=0.5)
p_d2 = power_analyze(FFT(log(d2)), binbounds=p_space.binbounds).val
plt.plot(p_d2,'y--',alpha=0.5)
plt.plot(my_p0.val,'r-',alpha=0.5)
plt.plot(my_p1.val,'g-',alpha=0.5)
plt.plot(my_p2.val,'b-',alpha=0.5)
p_u = power_analyze(FFT(u), binbounds=p_space.binbounds)
p_s = power_analyze(sh, binbounds=p_space.binbounds)
plt.plot(p_u.val, 'k-')
plt.plot(p_s.val,'k-')
p_u0 = power_analyze(FFT(log(myEnergy0.a*d0)), binbounds=p_space.binbounds).val
p_u1 = power_analyze(FFT(log(myEnergy1.a*d1)), binbounds=p_space.binbounds).val
p_u2 = power_analyze(FFT(log(myEnergy2.a*d2)), binbounds=p_space.binbounds).val
plt.plot(p_u0,'r:',alpha=0.5)
plt.plot(p_u1,'g:',alpha=0.5)
plt.plot(p_u2,'b:',alpha=0.5)
plt.savefig('power.png')
plt.close()
# myEnergy, convergence = minimizer1(myEnergy)
# plt.imsave('data.png', (d).val)
# my_s = FFT(log((1 - myEnergy.a) * d))
# my_p = power_analyze(my_s,binbounds=p_space.binbounds)
# my_S = FFT.inverse * create_power_operator(h_space,my_p) * FFT
#
# myEnergy = MyOtherEnergy(myEnergy.position,d,my_S,alpha,q)
#
# plt.imsave('points.png', (myEnergy.a * d).val)
# plt.imsave('maps.png', ((1 - myEnergy.a) * d).val)
myEnergy = problem_iteration(myEnergy)
plt.viridis()
plt.imsave('points0.png',myEnergy.u.val)
plt.imsave('maps0.png',(myEnergy.s).val)
plt.imsave('data0.png',myEnergy.d.val)
#
# plt.figure()
#
# plt.plot(exp(s).val)
# plt.plot(((1-myEnergy.a)*d).val, alpha=0.7)
# plt.savefig('maps.png')
# plt.close()
# plt.figure()
#
# plt.plot(exp(u).val)
# plt.plot(((myEnergy.a)*d).val, alpha=0.7)
# plt.savefig('points.png')
# plt.close()
# es = np.
# if len(s_space.shape) == 2:
# plt.imsave('data0.png',(d0).val)
#
# plt.imsave('points0.png',(myEnergy0.a*d0).val)
# plt.imsave('maps0.png',((1-myEnergy0.a)*d0).val)
#
# plt.imsave('data1.png',(d1).val)
#
# plt.imsave('points1.png',(myEnergy1.a*d1).val)
# plt.imsave('maps1.png',((1-myEnergy1.a)*d1).val)
# plt.imsave('data2.png',(d2).val)
#
# plt.imsave('points2.png',(myEnergy2.a*d2).val)
# plt.imsave('maps2.png',((1-myEnergy2.a)*d2).val)
#
# maps = np.zeros_like(dd)
# maps[:,:,0] = ((1-myEnergy0.a)*d0).val
# maps[:,:,1] = ((1-myEnergy1.a)*d1).val
# maps[:,:,2] = ((1-myEnergy2.a)*d2).val
# plt.imsave('maps.png', maps/255.)
from nifty2go import Energy, Field, log, exp, DiagonalOperator
from nifty2go.library import WienerFilterCurvature
class SeparationEnergy(Energy):
def __init__(self, position, parameters):
x = position.val.clip(-9, 9)
position = Field(position.domain, val=x)
super(SeparationEnergy, self).__init__(position=position)
self.parameters = parameters
self.inverter = parameters['inverter']
self.d = parameters['data']
self.FFT = parameters['FFT']
self.correlation = parameters['correlation']
self.alpha = parameters['alpha']
self.q = parameters['q']
pos_tanh = parameters['pos_tanh']
self.S = self.FFT.adjoint * self.correlation * self.FFT
self.a = pos_tanh(self.position)
self.a_p = pos_tanh.derivative(self.position)
self.u = log(self.d * self.a)
self.u_p = self.a_p/self.a
one_m_a = 1 - self.a
self.s = log(self.d * one_m_a)
self.s_p = - self.a_p / one_m_a
self.var_x = 9.
def at(self, position):
return self.__class__(position, parameters=self.parameters)
@property
def value(self):
diffuse = 0.5 * self.s.vdot(self.S.inverse(self.s))
point = (self.alpha-1).vdot(self.u) + self.q.vdot(exp(-self.u))
det = self.s.integrate()
det += 0.5 / self.var_x * self.position.vdot(self.position)
return diffuse + point + det
@property
def gradient(self):
diffuse = self.S.inverse(self.s) * self.s_p
point = (self.alpha - 1) * self.u_p - self.q * exp(-self.u) * self.u_p
det = self.position / self.var_x
det += self.s_p
return diffuse + point + det
@property
def curvature(self):
point = self.q * exp(-self.u) * self.u_p ** 2
R = self.FFT * self.s_p
N = self.correlation
S = DiagonalOperator(1/(point + 1/self.var_x))
return WienerFilterCurvature(R=R, N=N, S=S, inverter=self.inverter)
#:kivy 1.0.9
TextInput: