Commit 3541d7e0 authored by Philipp Arras's avatar Philipp Arras
Browse files

Merge remote-tracking branch 'origin/mr_suggestions' into global_newton

parents 9ba7ee05 70f7f62f
__pycache__
# never store the git version file
git_version.py
# custom
setup.cfg
.idea
.DS_Store
*.pyc
*.html
.document
.svn/
*.csv
# from https://github.com/github/gitignore/blob/master/Python.gitignore
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.c
*.o
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/build/
docs/source/mod
# PyBuilder
target/
# IPython Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# dotenv
.env
# virtualenv
venv/
ENV/
# Spyder project settings
.spyderproject
# Rope project settings
.ropeproject
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2017-2018 Max-Planck-Society
# Author: Jakob Knollmueller
#
# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik
import nifty4 as ift
import numpy as np
class StarbladeEnergy(ift.Energy):
"""The Energy for the starblade problem.
It implements the Information Hamiltonian of the separation of d
Parameters
----------
position : Field
The current position of the separation.
data : Field
The image data.
alpha : float
Slope parameter of the point-source prior
q : float
Cutoff parameter of the point-source prior
correlation : Field
A field in the Fourier space which encodes the diagonal of the prior
correlation structure of the diffuse component
inverter : ConjugateGradient (optional)
the minimization strategy to use for operator inversion
If None, the energy will not be able to compute curvatures
"""
def __init__(self, position, data, alpha, q, correlation, inverter=None):
if (position > 9.).any() or (position < -9.).any():
raise ValueError("position outside allowed range")
......@@ -22,8 +61,8 @@ class StarbladeEnergy(ift.Energy):
self._ptanh = ift.library.PositiveTanh()
a = self._ptanh(position)
a_p = self._ptanh.derivative(position)
one_m_a = 1. - a
s = ift.log(data * one_m_a)
one_m_a = 1.-a
s = ift.log(data*one_m_a)
if correlation is None:
binbounds = ift.PowerSpace.useful_binbounds(h_space,
......@@ -33,13 +72,13 @@ class StarbladeEnergy(ift.Energy):
correlation = ift.create_power_operator(h_space, correlation)
self._correlation = correlation
S = FFT * correlation * FFT.adjoint
S = FFT*correlation*FFT.adjoint
usum = ift.log(data * a).sum()
usum = ift.log(data*a).sum()
var_x = 9.
Sis = S.inverse(s)
qexpmu = self._q / (data*a)
qexpmu = self._q/(data*a)
diffuse = 0.5 * s.vdot(Sis)
point = (alpha-1)*usum + qexpmu.sum()
......@@ -56,7 +95,7 @@ class StarbladeEnergy(ift.Energy):
self._gradient = (diffuse + point + det).lock()
if inverter is not None: # curvature is needed
point = qexpmu * u_p ** 2
point = qexpmu * u_p**2
R = FFT.inverse * s_p
N = self._correlation
S = ift.DiagonalOperator(1./(point + 1./var_x))
......
......@@ -38,12 +38,19 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500):
s_space = ift.RGSpace(data.shape, distances=len(data.shape) * [1])
data = ift.Field.from_global_data(s_space, data)
h_space = s_space.get_default_codomain()
FFT = ift.FFTOperator(h_space, s_space)
binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic=False)
correlation = ift.power_analyze(FFT.inverse_times(ift.log(data)),
binbounds=binbounds)
correlation /= (ift.PowerSpace(h_space, binbounds).k_lengths+1.)**2
correlation = ift.create_power_operator(h_space, correlation)
#if BFGS:
# return StarbladeEnergy.make(ift.Field.full(s_space, -1.), data, alpha)
ICI = ift.GradientNormController(iteration_limit=cg_iterations,
tol_abs_gradnorm=1e-5)
inverter = ift.ConjugateGradient(controller=ICI)
return StarbladeEnergy.make(ift.Field.full(s_space, -1.), data, alpha, q,
return StarbladeEnergy(ift.Field.full(s_space, -1.), data, alpha, q, correlation,
inverter)
......@@ -64,7 +71,9 @@ def starblade_iteration(starblade, iterations=3):
#else:
minimizer = ift.RelaxedNewton(controller=controller)
starblade, convergence = minimizer(starblade)
return starblade.with_new_correlation()
# FIXME: this is not final yet!
return starblade
#return starblade.with_new_correlation()
def build_multi_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500):
......@@ -82,7 +91,8 @@ def build_multi_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500):
Maximum number of conjugate gradient iterations for numerical operator
inversion (default: 500).
"""
return [build_starblade(data[..., i].copy(), alpha, q, cg_iterations) for i in range(data.shape[-1])]
return [build_starblade(data[..., i].copy(), alpha, q, cg_iterations)
for i in range(data.shape[-1])]
def multi_starblade_iteration(MultiStarblade, multiprocessing=False):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment