Commit 8bdaa000 authored by Jakob Knollmueller's avatar Jakob Knollmueller

MAP

parent b0e621a6
...@@ -26,29 +26,55 @@ if __name__ == '__main__': ...@@ -26,29 +26,55 @@ if __name__ == '__main__':
#specifying location of the input file: #specifying location of the input file:
path = 'data/hst_05195_01_wfpc2_f702w_pc_sci.fits' path = 'data/hst_05195_01_wfpc2_f702w_pc_sci.fits'
path = 'data/frame-i-004874-3-0692.fits' path = 'data/frame-i-004874-3-0692.fits'
path = 'data/frame-i-007812-6-0100.fits'
path = 'data/frame-g-002821-6-0141.fits'
path = 'data/frame-g-002821-6-0141.fits'
path = 'data/frame-i-000752-1-0432.fits'
# path = 'data/frame-i-004858-1-0480.fits'
# data = fits.open(path)[1].data # data = fits.open(path)[1].data
data = fits.open(path)[0].data[1000:15000,1250:1750] # data = fits.open(path)[0].data#[750:,1000:]
xx = 250
yy = int(xx /0.75)
data = fits.open(path)[0].data[500:500+xx,1200:1200+yy]
data_unmod = fits.open(path)[0].data[500:500+xx,1200:1200+yy]
sex = fits.open('data/check.fits')[0].data[500:500+xx,1200:1200+yy]
data -= data.min() - 0.001 data -= data.min() - 0.001
# data = 1.-plt.imread('data/sdss.png').T[0] # data = 1.-plt.imread('data/sdss.png').T[0]
# data = fits.open(path)[1].data # data = fits.open(path)[1].data
data = data.clip(min=0.0001) data = data.clip(min=0.0001)
hdu = fits.PrimaryHDU(data)
hdul = fits.HDUList([hdu])
hdul.writeto('new1.fits')
data = np.ndarray.astype(data, float) data = np.ndarray.astype(data, float)
vmin = np.log(data.min()+0.2) vmin = np.log(data.min()+0.2)
vmax = np.log(data.max()) vmax = np.log(data.max())*0.3
plt.imsave('data.png', np.log(data),vmin=vmin,vmax=vmax) plt.gray()
lin_max = 2.
lin_min=0.1
plt.imsave('log_data.png', np.log(data),vmin=vmin,vmax=vmax)
plt.imsave('data.png', (data), vmax = lin_max,vmin =lin_min)
plt.imsave('sex.png', sex)#, vmax = lin_max, vmin=lin_min)
plt.imsave('log_sex.png', np.log(sex),)#, vmax = lin_max, vmin=lin_min)
plt.imsave('point_sex.png', (data_unmod - sex), vmax = lin_max,vmin =lin_min)
alpha = 1.25 alpha = 1.4
Starblade = sb.build_starblade(data, alpha=alpha) Starblade = sb.build_starblade(data, alpha=alpha,cg_iterations=100,newton_iterations=10)
for i in range(10): for i in range(1000):
Starblade = sb.starblade_iteration(Starblade, samples=i) Starblade = sb.starblade_iteration(Starblade, samples=5)
#plotting on logarithmic scale #plotting on logarithmic scale
plt.imsave('diffuse_component.png', Starblade.s.val, vmin=vmin, vmax=vmax) plt.imsave('log_diffuse_component.png', Starblade.s.val, vmin=vmin, vmax=vmax)
plt.imsave('pointlike_component.png', Starblade.u.val, vmin=vmin, vmax=vmax) plt.imsave('diffuse_component.png', np.exp(Starblade.s.val), vmin=lin_min, vmax=lin_max)
plt.imsave('log_pointlike_component.png', Starblade.u.val, vmin=vmin, vmax=vmax)
plt.imsave('pointlike_component.png', np.exp(Starblade.u.val), vmin=lin_min, vmax=lin_max)
plt.figure() plt.figure()
k_lenghts = Starblade.power_spectrum.domain[0].k_lengths k_lenghts = Starblade.power_spectrum.domain[0].k_lengths
plt.plot(k_lenghts, Starblade.power_spectrum.val) plt.plot(k_lenghts, Starblade.power_spectrum.val)
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2017-2018 Max-Planck-Society
# Author: Jakob Knollmueller
#
# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik
import numpy as np
from astropy.io import fits
from matplotlib import pyplot as plt
import scipy.cluster.vq as sp
import starblade as sb
if __name__ == '__main__':
#specifying location of the input file:
path = '../data/hst_05195_01_wfpc2_f702w_pc_sci.fits'
path = '../data/frame-i-004874-3-0692.fits'
# path = '../data/frame-i-007812-6-0100.fits'
# path = '../data/frame-g-002821-6-0141.fits'
path = '../data/frame-i-000752-1-0432.fits'
# path = '../data/frame-u-000752-1-0432.fits'
# path = '../data/frame-i-006174-2-0094.fits'
# path = 'data/frame-i-004858-1-0480.fits'
# data = fits.open(path)[1].data
# data = fits.open(path)[0].data#[750:,1000:]
##docker run --rm -i -t --name sex -v ~/Projects/starblade/demos/sextractor:/work chbrandt/sextractor my_data.fits -c default.se
xx = 450
yy = int(xx /0.75)
x0 = 400#1050#770#450
y0 = 1000#1550#200#1150
data = fits.open(path)[0].data[x0:x0+xx,y0:y0+yy]
data_unmod = data.copy()
# sex = fits.open('check.fits')[0].data[500:500+xx,1200:1200+yy]
data -= data.min() - 0.001
# data = 1.-plt.imread('data/sdss.png').T[0]
# data = fits.open(path)[1].data
data = data.clip(min=0.0001)
sex = fits.open('check.fits')[0].data
hdu = fits.PrimaryHDU(data)
hdul = fits.HDUList([hdu])
hdul.writeto('my_data.fits',overwrite=True)
data = np.ndarray.astype(data, float)
vmin = np.log(data.min()+0.2)
vmax = np.log(data.max())*0.4
plt.gray()
lin_max = data.max()*0.01
lin_min=0.01
plt.imsave('log_data.png', np.log(data),vmin=vmin,vmax=vmax)
plt.imsave('data.png', (data), vmax = lin_max,vmin =lin_min)
plt.imsave('sex.png', sex, vmax = lin_max, vmin=lin_min)
plt.imsave('log_sex.png', np.log(sex), vmax = vmax, vmin=vmin)
plt.imsave('point_sex.png', (data_unmod - sex), vmax = lin_max,vmin =lin_min)
plt.imsave('log_point_sex.png', np.log((data_unmod - sex).clip(min=0.0001)), vmax = vmax,vmin =vmin)
alpha = 1.4
Starblade = sb.build_starblade(data, alpha=alpha,cg_iterations=100,newton_iterations=3)
for i in range(10):
Starblade = sb.starblade_iteration(Starblade, samples=5)
#plotting on logarithmic scale
plt.imsave('log_diffuse_component.png', Starblade.s.val, vmin=vmin, vmax=vmax)
plt.imsave('diffuse_component.png', np.exp(Starblade.s.val), vmin=lin_min, vmax=lin_max)
plt.imsave('log_pointlike_component.png', Starblade.u.val, vmin=vmin, vmax=vmax)
plt.imsave('pointlike_component.png', np.exp(Starblade.u.val), vmin=lin_min, vmax=lin_max)
plt.figure()
k_lenghts = Starblade.power_spectrum.domain[0].k_lengths
plt.plot(k_lenghts, Starblade.power_spectrum.val)
plt.title('power spectrum')
plt.yscale('log')
plt.xscale('log')
plt.ylabel('power')
plt.xlabel('harmonic mode')
plt.savefig('power_spectrum.png')
plt.close('all')
hdu = fits.PrimaryHDU(np.exp(Starblade.u.val))
hdul = fits.HDUList([hdu])
hdul.writeto('my_points.fits',overwrite=True)
hdu = fits.PrimaryHDU(np.exp(Starblade.s.val))
hdul = fits.HDUList([hdu])
hdul.writeto('my_diffuse.fits',overwrite=True)
star_sex = np.log((data_unmod - sex).clip(min=0.0000000001))
star_blade = Starblade.u.val
diffuse_sex = np.log((sex).clip(min=0.0000000001))
diffuse_blade = Starblade.s.val
stars = np.empty((2,len(star_blade.flatten())))
# stars = np.concatenate((star_blade.flatten(),star_sex.flatten()))
stars[0] = star_blade.flatten()
stars[1] = star_sex.flatten()
# stars = sp.whiten(stars)
sp.kmeans2(stars.T,2,iter=30)
...@@ -250,7 +250,7 @@ class StarbladeApp(App): ...@@ -250,7 +250,7 @@ class StarbladeApp(App):
@mainthread @mainthread
def set_data_image(self): def set_data_image(self):
self.data_image = self.path + 'data.png' self.data_image = self.path + 'log_data.png'
@mainthread @mainthread
def set_image_paths(self): def set_image_paths(self):
self.points_image = self.path + 'points.png' self.points_image = self.path + 'points.png'
...@@ -258,9 +258,9 @@ class StarbladeApp(App): ...@@ -258,9 +258,9 @@ class StarbladeApp(App):
def plot_data(self): def plot_data(self):
if self.data.shape[0] == 1: if self.data.shape[0] == 1:
plt.imsave(self.path+'data.png', self.data[0], vmin=self.vmin, vmax=self.vmax) plt.imsave(self.path+'log_data.png', self.data[0], vmin=self.vmin, vmax=self.vmax)
else: else:
plt.imsave(self.path+ 'data.png', self.data/255.) plt.imsave(self.path+ 'log_data.png', self.data/255.)
def plot_components(self, path): def plot_components(self, path):
diffuse = np.empty_like(self.data) diffuse = np.empty_like(self.data)
......
...@@ -6,57 +6,85 @@ rc('text', usetex=True) ...@@ -6,57 +6,85 @@ rc('text', usetex=True)
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
import starblade as sb import starblade as sb
from scipy.stats import invgamma
np.random.seed(42) np.random.seed(42)
if __name__ == '__main__': if __name__ == '__main__':
s_space = ift.RGSpace([1024]) s_space = ift.RGSpace([128,128])
h_space = s_space.get_default_codomain() h_space = s_space.get_default_codomain()
FFT = ift.FFTOperator(h_space) FFT = ift.FFTOperator(h_space)
p_spec = lambda k: (1./(1+k)**2.5) p_spec = lambda k: (1./(1+k)**4)
mod_p_spec = lambda k: p_spec(k*1024)*1024**4
binbounds = ift.PowerSpace.useful_binbounds(h_space,logarithmic=True)#, nbin=100)
p_space = ift.PowerSpace(h_space,binbounds=binbounds)
k_lengths=p_space.k_lengths
# p_spec = ift.Field(p_space,val=p_spec(p_space.k_lengths))
S = ift.create_power_operator(h_space, power_spectrum=p_spec) S = ift.create_power_operator(h_space, power_spectrum=p_spec)
sh = S.draw_sample() sh = S.draw_sample()
s = FFT(sh) s = FFT(sh)
# k_lengths = sh.domain[0].k_lengths
u = ift.Field(s_space, val = -12)
u.val[200] = 1 u = ift.Field(s_space, val = -12.)
u.val[300] = 3 # u = 3*(ift.Field.from_random('normal',s_space)-1)
u.val[500] = 4 # u.val[20,20] = 3
u.val[700] = 5 # u.val[15,96] = 4
u.val[900] = 2 # u.val[128,128] = 5
u.val[154] = 0.5 # u.val[65,33] = 6
u.val[421] = 0.25 # u.val[156,119] = 4.5
u.val[652] = 1 # u.val[16,125] = 4.5
u.val[1002] = 2.5 # u.val[156,51] = 2
# u.val[235,62] = 3.5
d = ift.exp(s) + ift.exp(u) # u.val[54,125] = 1.3
x=np.random.randint(0,s_space.shape[0],(s_space.shape[0]/2,1))
y=np.random.randint(0,s_space.shape[0],(s_space.shape[0]/2,1))
brightness = np.random.uniform(-2,5,(s_space.shape[0]/2,1))
u.val[x,y] = brightness
brightness =ift.log(ift.Field(s_space,val=invgamma(0.5).rvs(s_space.shape))/1000)
# u.val[200] = 1
# u.val[300] = 3
# u.val[500] = 4
# u.val[700] = 5
# u.val[900] = 2
# u.val[154] = 0.5
# u.val[421] = 0.25
# u.val[652] = 1
# u.val[1002] = 2.5
# R = ift.FFTSmoothingOperator(s_space,sigma=0.001)
# R(u)
u=brightness
d = ift.exp(s) +ift.exp(u)
data = d.val data = d.val
energy1 = sb.build_starblade(data,1.25) energy1 = sb.build_starblade(data,1.25, newton_iterations=5, cg_iterations=500, q=1e-30)#, manual_power_spectrum= mod_p_spec)
energy2 = sb.build_starblade(data,1.5) energy2 = sb.build_starblade(data,1.5, newton_iterations=5, cg_iterations=500, q=1e-30)#, manual_power_spectrum= mod_p_spec)
energy3 = sb.build_starblade(data,1.75) energy3 = sb.build_starblade(data,1.75, newton_iterations=5, cg_iterations=500, q=1e-30)#, manual_power_spectrum=mod_p_spec)
for i in range(20): for i in range(10):
energy1 = sb.starblade_iteration(energy1) energy1 = sb.starblade_iteration(energy1, samples=0)
energy2 = sb.starblade_iteration(energy2) energy2 = sb.starblade_iteration(energy2, samples=0)
energy3 = sb.starblade_iteration(energy3) energy3 = sb.starblade_iteration(energy3, samples=0)
plt.imsave("2d_data.png",np.log(data))
size = 15 size = 15
plt.figure() plt.figure()
# plt.plot(data, 'k-') # plt.plot(data, 'k-')
f, (ax0, ax1,ax2) = plt.subplots(3, sharex=True, sharey=True) f, (ax0, ax1,ax2) = plt.subplots(3, sharex=True, sharey=True)
plt.suptitle('diffuse components', size=size) plt.suptitle('diffuse components', size=size)
ax0.plot(ift.exp(energy1.s).val, 'k-') ax0.plot(ift.exp(energy1.s).val, 'k-',alpha=0.1)
ax0.yaxis.set_label_position("right") ax0.yaxis.set_label_position("right")
ax0.set_ylabel(r'$\alpha = 1.25$', size=size) ax0.set_ylabel(r'$\alpha = 1.25$', size=size)
ax0.set_ylim(1e-1,1e3) ax0.set_ylim(1e-1,1e3)
ax0.set_yscale("log") ax0.set_yscale("log")
ax1.plot(ift.exp(energy2.s).val, 'k-') ax1.plot(ift.exp(energy2.s).val, 'k-',alpha=0.1)
ax1.yaxis.set_label_position("right") ax1.yaxis.set_label_position("right")
ax1.set_ylabel(r'$\alpha = 1.5$', size=size) ax1.set_ylabel(r'$\alpha = 1.5$', size=size)
ax2.plot(ift.exp(energy3.s).val, 'k-') ax2.plot(ift.exp(energy3.s).val, 'k-',alpha=0.1)
ax2.yaxis.set_label_position("right") ax2.yaxis.set_label_position("right")
ax2.set_ylabel(r'$\alpha = 1.75$', size=size) ax2.set_ylabel(r'$\alpha = 1.75$', size=size)
...@@ -67,17 +95,17 @@ if __name__ == '__main__': ...@@ -67,17 +95,17 @@ if __name__ == '__main__':
plt.suptitle('point-like components', size=size) plt.suptitle('point-like components', size=size)
ax0.plot(ift.exp(energy1.u).val, 'k-') ax0.plot(ift.exp(energy1.u).val, 'k-',alpha=0.1)
ax0.yaxis.set_label_position("right") ax0.yaxis.set_label_position("right")
ax0.set_ylabel(r'$\alpha = 1.25$', size=size) ax0.set_ylabel(r'$\alpha = 1.25$', size=size)
ax0.set_ylim(1e-1,1e3) ax0.set_ylim(1e-1,1e3)
ax0.set_yscale("log") ax0.set_yscale("log")
ax1.plot(ift.exp(energy2.u).val, 'k-') ax1.plot(ift.exp(energy2.u).val, 'k-',alpha=0.1)
ax1.yaxis.set_label_position("right") ax1.yaxis.set_label_position("right")
ax1.set_ylabel(r'$\alpha = 1.5$', size=size) ax1.set_ylabel(r'$\alpha = 1.5$', size=size)
ax2.plot(ift.exp(energy3.u).val, 'k-') ax2.plot(ift.exp(energy3.u).val, 'k-',alpha=0.1)
ax2.yaxis.set_label_position("right") ax2.yaxis.set_label_position("right")
ax2.set_ylabel(r'$\alpha = 1.75$', size=size) ax2.set_ylabel(r'$\alpha = 1.75$', size=size)
...@@ -91,17 +119,17 @@ if __name__ == '__main__': ...@@ -91,17 +119,17 @@ if __name__ == '__main__':
f, (ax0, ax1,ax2) = plt.subplots(3, sharex=True, sharey=True) f, (ax0, ax1,ax2) = plt.subplots(3, sharex=True, sharey=True)
plt.suptitle('data and true components', size=size) plt.suptitle('data and true components', size=size)
ax0.plot(data, 'k-') ax0.plot(data, 'k-',alpha=0.1)
ax0.set_yscale("log") ax0.set_yscale("log")
ax0.set_ylim(1e-1,1e3) ax0.set_ylim(1e-1,1e3)
ax0.yaxis.set_label_position("right") ax0.yaxis.set_label_position("right")
ax0.set_ylabel(r'data', size=size) ax0.set_ylabel(r'data', size=size)
ax1.plot(ift.exp(s).val, 'k-') ax1.plot(ift.exp(s).val, 'k-',alpha=0.1)
ax1.yaxis.set_label_position("right") ax1.yaxis.set_label_position("right")
ax1.set_ylabel(r'diffuse', size=size) ax1.set_ylabel(r'diffuse', size=size)
ax2.plot(ift.exp(u).val, 'k-') ax2.plot(ift.exp(u).val, 'k-',alpha=0.1)
ax2.yaxis.set_label_position("right") ax2.yaxis.set_label_position("right")
ax2.set_ylabel(r'point-like', size=size) ax2.set_ylabel(r'point-like', size=size)
...@@ -109,21 +137,25 @@ if __name__ == '__main__': ...@@ -109,21 +137,25 @@ if __name__ == '__main__':
plt.savefig('1d_data.pdf') plt.savefig('1d_data.pdf')
plt.figure() plt.figure()
binbounds = ift.PowerSpace.useful_binbounds(energy2.FFT.domain[0],logarithmic=False, nbin=100) binbounds = ift.PowerSpace.useful_binbounds(energy2.FFT.domain[0],logarithmic=True)#, nbin=100)
power1 = ift.power_analyze(energy2.FFT.inverse((energy1.s)),binbounds=binbounds) power1 = ift.power_analyze(energy2.FFT.inverse((energy1.s)),binbounds=binbounds)
power2 = ift.power_analyze(energy2.FFT.inverse((energy2.s)),binbounds=binbounds) power2 = ift.power_analyze(energy2.FFT.inverse((energy2.s)),binbounds=binbounds)
power3 = ift.power_analyze(energy2.FFT.inverse((energy3.s)),binbounds=binbounds) power3 = ift.power_analyze(energy2.FFT.inverse((energy3.s)),binbounds=binbounds)
pp1 = energy1.power_spectrum
pp2 = energy2.power_spectrum
pp3 = energy3.power_spectrum
real_power = ift.power_analyze(sh) real_power = ift.power_analyze(sh)
power_d = ift.power_analyze(energy2.FFT.inverse(ift.log(energy2.d)),binbounds=binbounds) power_d = ift.power_analyze(energy2.FFT.inverse(ift.log(energy2.d)),binbounds=binbounds)
# power_u = ift.power_analyze(energy2.FFT.inverse(ift.exp(energy2.u)),binbounds=binbounds) # power_u = ift.power_analyze(energy2.FFT.inverse(ift.exp(energy2.u)),binbounds=binbounds)
k_lengths=power1.domain[0].k_lengths
k_lengths = power1.domain[0].k_lengths plt.plot(k_lengths, p_spec(k_lengths*1024.)*1024**4, 'k-', label='theoretical')
plt.plot(k_lengths, p_spec(k_lengths*1024.)*1024**2, 'k-', label='theoretical')
plt.plot(k_lengths, power_d.val, 'k:', label='data') plt.plot(k_lengths, power_d.val, 'k:', label='data')
plt.plot(k_lengths, power1.val, 'k-', label=(r'$\alpha = 1.25$'), alpha=0.6) plt.plot(k_lengths, power1.val, 'k-', label=(r'$\alpha = 1.25$'), alpha=0.6)
plt.plot(k_lengths, power2.val, 'k-', label=(r'$\alpha = 1.5$'),alpha=0.3) plt.plot(k_lengths, power2.val, 'k-', label=(r'$\alpha = 1.5$'),alpha=0.3)
plt.plot(k_lengths, power3.val, 'k-', label=(r'$\alpha = 1.75$'), alpha=0.15) plt.plot(k_lengths, power3.val, 'k-', label=(r'$\alpha = 1.75$'), alpha=0.15)
plt.plot(k_lengths, pp1.val, 'r-', label=(r'$\alpha = 1.25$'), alpha=0.6)
plt.plot(k_lengths, pp2.val, 'r-', label=(r'$\alpha = 1.5$'),alpha=0.3)
plt.plot(k_lengths, pp3.val, 'r-', label=(r'$\alpha = 1.75$'), alpha=0.15)
# plt.plot(k_lengths, power_u.val, 'k:',label='point-like') # plt.plot(k_lengths, power_u.val, 'k:',label='point-like')
plt.legend() plt.legend()
...@@ -133,4 +165,5 @@ if __name__ == '__main__': ...@@ -133,4 +165,5 @@ if __name__ == '__main__':
plt.ylabel('power',size=15) plt.ylabel('power',size=15)
plt.xlabel('harmonic mode',size=15) plt.xlabel('harmonic mode',size=15)
plt.savefig('1d_power.pdf') plt.savefig('1d_power.pdf')
plt.close('all')
...@@ -47,15 +47,17 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=100, newton_iteratio ...@@ -47,15 +47,17 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=100, newton_iteratio
h_space = s_space.get_default_codomain() h_space = s_space.get_default_codomain()
data = ift.Field(s_space,val=data) data = ift.Field(s_space,val=data)
FFT = ift.FFTOperator(h_space, target=s_space) FFT = ift.FFTOperator(h_space, target=s_space)
binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic = False) binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic = True)
p_space = ift.PowerSpace(h_space, binbounds=binbounds) p_space = ift.PowerSpace(h_space, binbounds=binbounds)
if manual_power_spectrum is None: if manual_power_spectrum is None:
initial_spectrum = ift.power_analyze(FFT.inverse_times(ift.log(data)), initial_spectrum = ift.power_analyze(FFT.inverse_times(ift.log(data)),
binbounds=p_space.binbounds) binbounds=p_space.binbounds)
initial_spectrum /= (p_space.k_lengths+1.)**2 initial_spectrum /= (p_space.k_lengths+1.)**4
update_power = True update_power = True
else: else:
initial_spectrum = manual_power_spectrum initial_spectrum = manual_power_spectrum
update_power = False update_power = False
initial_x = ift.Field(s_space, val=-1.) initial_x = ift.Field(s_space, val=-1.)
alpha = ift.Field(s_space, val=alpha) alpha = ift.Field(s_space, val=alpha)
...@@ -87,15 +89,20 @@ def starblade_iteration(starblade, samples=3): ...@@ -87,15 +89,20 @@ def starblade_iteration(starblade, samples=3):
tol_abs_gradnorm=1e-8, tol_abs_gradnorm=1e-8,
iteration_limit=starblade.newton_iterations) iteration_limit=starblade.newton_iterations)
minimizer = ift.RelaxedNewton(controller=controller) minimizer = ift.RelaxedNewton(controller=controller)
# if len(sample_list)>0:
# energy = StarbladeKL(starblade.position, samples=sample_list, parameters=starblade.parameters)
# else:
energy = starblade
energy, convergence = minimizer(energy)
sample_list = [] sample_list = []
for i in range(samples): for i in range(samples):
sample = starblade.curvature.inverse.draw_sample() sample = energy.curvature.inverse.draw_sample()
sample_list.append(sample) sample_list.append(sample)
if len(sample_list)>0: if len(sample_list) == 0:
energy = StarbladeKL(starblade.position, samples=sample_list, parameters=starblade.parameters) sample_list.append(energy.position)
else: # energy = StarbladeKL(energy.position, samples=sample_list, parameters=energy.parameters)
energy = starblade
energy, convergence = minimizer(energy)
new_position = energy.position new_position = energy.position
new_parameters = energy.parameters new_parameters = energy.parameters
if energy.parameters['update_power']: if energy.parameters['update_power']:
...@@ -164,9 +171,9 @@ def update_power(energy): ...@@ -164,9 +171,9 @@ def update_power(energy):
if isinstance(energy, StarbladeKL): if isinstance(energy, StarbladeKL):
power = 0. power = 0.
for en in energy.energy_list: for en in energy.energy_list:
power += ift.power_analyze(energy.parameters['FFT'].inverse_times(en.s), power = ift.power_analyze(energy.parameters['FFT'].inverse_times(en.s),
binbounds=en.parameters['power_spectrum'].domain[0].binbounds) binbounds=en.parameters['power_spectrum'].domain[0].binbounds)
power /= len(energy.energy_list) # power /= len(energy.energy_list)
else: else:
power = ift.power_analyze(energy.FFT.inverse_times(energy.s), power = ift.power_analyze(energy.FFT.inverse_times(energy.s),
binbounds=energy.parameters['power_spectrum'].domain[0].binbounds) binbounds=energy.parameters['power_spectrum'].domain[0].binbounds)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment