Commit fd208746 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'find_pos_merge' into 'NIFTy_6'

Find pos merge

See merge request !485
parents abf56d58 0921954d
Pipeline #75326 passed with stages
in 8 minutes and 52 seconds
......@@ -83,9 +83,9 @@ class MultiField(Operator):
def domain(self):
return self._domain
# @property
# def dtype(self):
# return {key: val.dtype for key, val in self._val.items()}
@property
def dtype(self):
return {key: val.dtype for key, val in self.items()}
def _transform(self, op):
return MultiField(self._domain, tuple(op(v) for v in self._val))
......
......@@ -124,6 +124,10 @@ class SamplingDtypeSetter(EndomorphicOperator):
need to conincide the with keys of the `MultiDomain`.
"""
def __init__(self, op, dtype):
if isinstance(op, SamplingDtypeSetter):
if op._dtype != dtype:
raise ValueError('Dtype for sampling already set to another dtype.')
op = op._op
if not isinstance(op, EndomorphicOperator):
raise TypeError
if not hasattr(op, 'draw_sample_with_dtype'):
......
......@@ -25,9 +25,13 @@ def _sqrt_helper(v):
def _sinc_helper(v):
tmp = np.sinc(v)
tmp2 = (np.cos(np.pi*v)-tmp)/v
return (tmp, np.where(v==0., 0, tmp2))
fv = np.sinc(v)
df = np.empty(v.shape, dtype=v.dtype)
sel = v != 0.
v = v[sel]
df[sel] = (np.cos(np.pi*v)-fv[sel])/v
df[~sel] = 0
return (fv, df)
def _expm1_helper(v):
......@@ -54,13 +58,13 @@ def _reciprocal_helper(v):
def _abs_helper(v):
if np.issubdtype(v.dtype, np.complexfloating):
raise TypeError("Argument must not be complex")
return (np.abs(v), np.where(v==0, np.nan, np.sign(v)))
return (np.abs(v), np.where(v == 0, np.nan, np.sign(v)))
def _sign_helper(v):
if np.issubdtype(v.dtype, np.complexfloating):
raise TypeError("Argument must not be complex")
return (np.sign(v), np.where(v==0, np.nan, 0))
return (np.sign(v), np.where(v == 0, np.nan, 0))
def _power_helper(v, expo):
......@@ -73,21 +77,21 @@ def _clip_helper(v, a_min, a_max):
tmp = np.clip(v, a_min, a_max)
tmp2 = np.ones(v.shape)
if a_min is not None:
tmp2 = np.where(tmp==a_min, 0., tmp2)
tmp2 = np.where(tmp == a_min, 0., tmp2)
if a_max is not None:
tmp2 = np.where(tmp==a_max, 0., tmp2)
tmp2 = np.where(tmp == a_max, 0., tmp2)
return (tmp, tmp2)
ptw_dict = {
"sqrt": (np.sqrt, _sqrt_helper),
"sin" : (np.sin, lambda v: (np.sin(v), np.cos(v))),
"cos" : (np.cos, lambda v: (np.cos(v), -np.sin(v))),
"tan" : (np.tan, lambda v: (np.tan(v), 1./np.cos(v)**2)),
"sin": (np.sin, lambda v: (np.sin(v), np.cos(v))),
"cos": (np.cos, lambda v: (np.cos(v), -np.sin(v))),
"tan": (np.tan, lambda v: (np.tan(v), 1./np.cos(v)**2)),
"sinc": (np.sinc, _sinc_helper),
"exp" : (np.exp, lambda v: (2*(np.exp(v),))),
"expm1" : (np.expm1, _expm1_helper),
"log" : (np.log, lambda v: (np.log(v), 1./v)),
"exp": (np.exp, lambda v: (2*(np.exp(v),))),
"expm1": (np.expm1, _expm1_helper),
"log": (np.log, lambda v: (np.log(v), 1./v)),
"log10": (np.log10, lambda v: (np.log10(v), (1./np.log(10.))/v)),
"log1p": (np.log1p, lambda v: (np.log1p(v), 1./(1.+v))),
"sinh": (np.sinh, lambda v: (np.sinh(v), np.cosh(v))),
......
......@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
......@@ -20,21 +20,20 @@ from time import time
import numpy as np
from .logger import logger
from . import utilities
from . import pointwise, utilities
from .domain_tuple import DomainTuple
from .domains.power_space import PowerSpace
from .field import Field
from .logger import logger
from .multi_domain import MultiDomain
from .multi_field import MultiField
from .operators.block_diagonal_operator import BlockDiagonalOperator
from .operators.diagonal_operator import DiagonalOperator
from .operators.distributors import PowerDistributor
from .operators.operator import Operator
from .operators.sampling_enabler import SamplingDtypeSetter
from .operators.scaling_operator import ScalingOperator
from .plot import Plot
from . import pointwise
__all__ = ['PS_field', 'power_analyze', 'create_power_operator',
'create_harmonic_smoothing_operator', 'from_random',
......@@ -501,17 +500,26 @@ def calculate_position(operator, output):
if output.domain != operator.target:
raise TypeError
if isinstance(output, MultiField):
cov = 1e-3*max([vv.max() for vv in output.val.values()])**2
cov = 1e-3*max([np.max(np.abs(vv)) for vv in output.val.values()])**2
invcov = ScalingOperator(output.domain, cov).inverse
dtype = list(set([ff.dtype for ff in output.values()]))
if len(dtype) != 1:
raise ValueError('Only MultiFields with one dtype supported.')
dtype = dtype[0]
else:
cov = 1e-3*output.val.max()**2
cov = 1e-3*np.max(np.abs(output.val))**2
dtype = output.dtype
invcov = ScalingOperator(output.domain, cov).inverse
d = output + invcov.draw_sample_with_dtype(dtype=output.dtype, from_inverse=True)
invcov = SamplingDtypeSetter(invcov, output.dtype)
invcov = SamplingDtypeSetter(invcov, output.dtype)
d = output + invcov.draw_sample(from_inverse=True)
lh = GaussianEnergy(d, invcov) @ operator
H = StandardHamiltonian(
lh, ic_samp=GradientNormController(iteration_limit=200))
pos = 0.1 * from_random(operator.domain, 'normal')
minimizer = NewtonCG(GradientNormController(iteration_limit=10))
pos = 0.1*from_random(operator.domain)
minimizer = NewtonCG(GradientNormController(iteration_limit=10, name='findpos'))
for ii in range(3):
logger.info(f'Start iteration {ii+1}/3')
kl = MetricGaussianKL(pos, H, 3, mirror_samples=True)
kl, _ = minimizer(kl)
pos = kl.position
......
......@@ -57,14 +57,28 @@ def test_special_gradients():
'log', 'exp', 'sqrt', 'sin', 'cos', 'tan', 'sinc', 'sinh', 'cosh', 'tanh',
'absolute', 'reciprocal', 'sigmoid', 'log10', 'log1p', "expm1"
])
def test_actual_gradients(f):
@pmp('cplxpos', [True, False])
@pmp('cplxdir', [True, False])
@pmp('holomorphic', [True, False])
def test_actual_gradients(f, cplxpos, cplxdir, holomorphic):
if (cplxpos or cplxdir) and f in ['absolute']:
return
if holomorphic and f in ['absolute']:
# These function are not holomorphic
return
dom = ift.UnstructuredDomain((1,))
fld = ift.full(dom, 2.4)
eps = 1e-8
if cplxpos:
fld = fld + 0.21j
eps = 1e-7
if cplxdir:
eps *= 1j
if holomorphic:
eps *= (1+0.78j)
var0 = ift.Linearization.make_var(fld)
var1 = ift.Linearization.make_var(fld + eps)
f0 = var0.ptw(f).val.val
f1 = var1.ptw(f).val.val
df0 = (f1 - f0)/eps
df1 = _lin2grad(var0.ptw(f))
assert_allclose(df0, df1, rtol=100*eps)
assert_allclose(df0, df1, rtol=100*np.abs(eps))
......@@ -27,7 +27,6 @@ name = (f'plot{nr}.png' for nr in count())
def test_plots():
# FIXME Write to temporary folder?
rg_space1 = ift.makeDomain(ift.RGSpace((10,)))
rg_space2 = ift.makeDomain(ift.RGSpace((8, 6), distances=1))
hp_space = ift.makeDomain(ift.HPSpace(5))
......@@ -75,4 +74,5 @@ def test_mf_plot():
plot = ift.Plot()
plot.add(f1, block=False, title='f_space_idx = 1')
plot.add(f2, freq_space_idx=0, title='f_space_idx = 0')
plot.output(nx=2, ny=1, title='MF-Plots, should look identical', name=next(name))
plot.output(nx=2, ny=1, title='MF-Plots, should look identical',
name=next(name))
......@@ -52,9 +52,18 @@ def test_exec_time():
ift.exec_time(oo, wm)
def test_calc_pos():
import pytest
pmp = pytest.mark.parametrize
@pmp('mf', [False, True])
@pmp('cplx', [False, True])
def test_calc_pos(mf, cplx):
dom = ift.RGSpace(12, harmonic=True)
op = ift.HarmonicTransformOperator(dom).ptw("exp")
if mf:
op = op.ducktape_left('foo')
dom = ift.makeDomain({'': dom})
if cplx:
op = op + 1j*op
fld = op(0.1 * ift.from_random(op.domain, 'normal'))
pos = ift.calculate_position(op, fld)
ift.extra.assert_allclose(op(pos), fld, 1e-1, 1e-1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment