Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
fd208746
Commit
fd208746
authored
May 20, 2020
by
Martin Reinecke
Browse files
Merge branch 'find_pos_merge' into 'NIFTy_6'
Find pos merge See merge request
!485
parents
abf56d58
0921954d
Pipeline
#75326
passed with stages
in 8 minutes and 52 seconds
Changes
7
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty6/multi_field.py
View file @
fd208746
...
...
@@ -83,9 +83,9 @@ class MultiField(Operator):
def
domain
(
self
):
return
self
.
_domain
#
@property
#
def dtype(self):
#
return {key: val.dtype for key, val in self.
_val.
items()}
@
property
def
dtype
(
self
):
return
{
key
:
val
.
dtype
for
key
,
val
in
self
.
items
()}
def
_transform
(
self
,
op
):
return
MultiField
(
self
.
_domain
,
tuple
(
op
(
v
)
for
v
in
self
.
_val
))
...
...
nifty6/operators/sampling_enabler.py
View file @
fd208746
...
...
@@ -124,6 +124,10 @@ class SamplingDtypeSetter(EndomorphicOperator):
need to conincide the with keys of the `MultiDomain`.
"""
def
__init__
(
self
,
op
,
dtype
):
if
isinstance
(
op
,
SamplingDtypeSetter
):
if
op
.
_dtype
!=
dtype
:
raise
ValueError
(
'Dtype for sampling already set to another dtype.'
)
op
=
op
.
_op
if
not
isinstance
(
op
,
EndomorphicOperator
):
raise
TypeError
if
not
hasattr
(
op
,
'draw_sample_with_dtype'
):
...
...
nifty6/pointwise.py
View file @
fd208746
...
...
@@ -25,9 +25,13 @@ def _sqrt_helper(v):
def
_sinc_helper
(
v
):
tmp
=
np
.
sinc
(
v
)
tmp2
=
(
np
.
cos
(
np
.
pi
*
v
)
-
tmp
)
/
v
return
(
tmp
,
np
.
where
(
v
==
0.
,
0
,
tmp2
))
fv
=
np
.
sinc
(
v
)
df
=
np
.
empty
(
v
.
shape
,
dtype
=
v
.
dtype
)
sel
=
v
!=
0.
v
=
v
[
sel
]
df
[
sel
]
=
(
np
.
cos
(
np
.
pi
*
v
)
-
fv
[
sel
])
/
v
df
[
~
sel
]
=
0
return
(
fv
,
df
)
def
_expm1_helper
(
v
):
...
...
@@ -54,13 +58,13 @@ def _reciprocal_helper(v):
def
_abs_helper
(
v
):
if
np
.
issubdtype
(
v
.
dtype
,
np
.
complexfloating
):
raise
TypeError
(
"Argument must not be complex"
)
return
(
np
.
abs
(
v
),
np
.
where
(
v
==
0
,
np
.
nan
,
np
.
sign
(
v
)))
return
(
np
.
abs
(
v
),
np
.
where
(
v
==
0
,
np
.
nan
,
np
.
sign
(
v
)))
def
_sign_helper
(
v
):
if
np
.
issubdtype
(
v
.
dtype
,
np
.
complexfloating
):
raise
TypeError
(
"Argument must not be complex"
)
return
(
np
.
sign
(
v
),
np
.
where
(
v
==
0
,
np
.
nan
,
0
))
return
(
np
.
sign
(
v
),
np
.
where
(
v
==
0
,
np
.
nan
,
0
))
def
_power_helper
(
v
,
expo
):
...
...
@@ -73,21 +77,21 @@ def _clip_helper(v, a_min, a_max):
tmp
=
np
.
clip
(
v
,
a_min
,
a_max
)
tmp2
=
np
.
ones
(
v
.
shape
)
if
a_min
is
not
None
:
tmp2
=
np
.
where
(
tmp
==
a_min
,
0.
,
tmp2
)
tmp2
=
np
.
where
(
tmp
==
a_min
,
0.
,
tmp2
)
if
a_max
is
not
None
:
tmp2
=
np
.
where
(
tmp
==
a_max
,
0.
,
tmp2
)
tmp2
=
np
.
where
(
tmp
==
a_max
,
0.
,
tmp2
)
return
(
tmp
,
tmp2
)
ptw_dict
=
{
"sqrt"
:
(
np
.
sqrt
,
_sqrt_helper
),
"sin"
:
(
np
.
sin
,
lambda
v
:
(
np
.
sin
(
v
),
np
.
cos
(
v
))),
"cos"
:
(
np
.
cos
,
lambda
v
:
(
np
.
cos
(
v
),
-
np
.
sin
(
v
))),
"tan"
:
(
np
.
tan
,
lambda
v
:
(
np
.
tan
(
v
),
1.
/
np
.
cos
(
v
)
**
2
)),
"sin"
:
(
np
.
sin
,
lambda
v
:
(
np
.
sin
(
v
),
np
.
cos
(
v
))),
"cos"
:
(
np
.
cos
,
lambda
v
:
(
np
.
cos
(
v
),
-
np
.
sin
(
v
))),
"tan"
:
(
np
.
tan
,
lambda
v
:
(
np
.
tan
(
v
),
1.
/
np
.
cos
(
v
)
**
2
)),
"sinc"
:
(
np
.
sinc
,
_sinc_helper
),
"exp"
:
(
np
.
exp
,
lambda
v
:
(
2
*
(
np
.
exp
(
v
),))),
"expm1"
:
(
np
.
expm1
,
_expm1_helper
),
"log"
:
(
np
.
log
,
lambda
v
:
(
np
.
log
(
v
),
1.
/
v
)),
"exp"
:
(
np
.
exp
,
lambda
v
:
(
2
*
(
np
.
exp
(
v
),))),
"expm1"
:
(
np
.
expm1
,
_expm1_helper
),
"log"
:
(
np
.
log
,
lambda
v
:
(
np
.
log
(
v
),
1.
/
v
)),
"log10"
:
(
np
.
log10
,
lambda
v
:
(
np
.
log10
(
v
),
(
1.
/
np
.
log
(
10.
))
/
v
)),
"log1p"
:
(
np
.
log1p
,
lambda
v
:
(
np
.
log1p
(
v
),
1.
/
(
1.
+
v
))),
"sinh"
:
(
np
.
sinh
,
lambda
v
:
(
np
.
sinh
(
v
),
np
.
cosh
(
v
))),
...
...
nifty6/sugar.py
View file @
fd208746
...
...
@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-20
19
Max-Planck-Society
# Copyright(C) 2013-20
20
Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
...
...
@@ -20,21 +20,20 @@ from time import time
import
numpy
as
np
from
.logger
import
logger
from
.
import
utilities
from
.
import
pointwise
,
utilities
from
.domain_tuple
import
DomainTuple
from
.domains.power_space
import
PowerSpace
from
.field
import
Field
from
.logger
import
logger
from
.multi_domain
import
MultiDomain
from
.multi_field
import
MultiField
from
.operators.block_diagonal_operator
import
BlockDiagonalOperator
from
.operators.diagonal_operator
import
DiagonalOperator
from
.operators.distributors
import
PowerDistributor
from
.operators.operator
import
Operator
from
.operators.sampling_enabler
import
SamplingDtypeSetter
from
.operators.scaling_operator
import
ScalingOperator
from
.plot
import
Plot
from
.
import
pointwise
__all__
=
[
'PS_field'
,
'power_analyze'
,
'create_power_operator'
,
'create_harmonic_smoothing_operator'
,
'from_random'
,
...
...
@@ -501,17 +500,26 @@ def calculate_position(operator, output):
if
output
.
domain
!=
operator
.
target
:
raise
TypeError
if
isinstance
(
output
,
MultiField
):
cov
=
1e-3
*
max
([
vv
.
max
()
for
vv
in
output
.
val
.
values
()])
**
2
cov
=
1e-3
*
max
([
np
.
max
(
np
.
abs
(
vv
))
for
vv
in
output
.
val
.
values
()])
**
2
invcov
=
ScalingOperator
(
output
.
domain
,
cov
).
inverse
dtype
=
list
(
set
([
ff
.
dtype
for
ff
in
output
.
values
()]))
if
len
(
dtype
)
!=
1
:
raise
ValueError
(
'Only MultiFields with one dtype supported.'
)
dtype
=
dtype
[
0
]
else
:
cov
=
1e-3
*
output
.
val
.
max
()
**
2
cov
=
1e-3
*
np
.
max
(
np
.
abs
(
output
.
val
))
**
2
dtype
=
output
.
dtype
invcov
=
ScalingOperator
(
output
.
domain
,
cov
).
inverse
d
=
output
+
invcov
.
draw_sample_with_dtype
(
dtype
=
output
.
dtype
,
from_inverse
=
True
)
invcov
=
SamplingDtypeSetter
(
invcov
,
output
.
dtype
)
invcov
=
SamplingDtypeSetter
(
invcov
,
output
.
dtype
)
d
=
output
+
invcov
.
draw_sample
(
from_inverse
=
True
)
lh
=
GaussianEnergy
(
d
,
invcov
)
@
operator
H
=
StandardHamiltonian
(
lh
,
ic_samp
=
GradientNormController
(
iteration_limit
=
200
))
pos
=
0.1
*
from_random
(
operator
.
domain
,
'normal'
)
minimizer
=
NewtonCG
(
GradientNormController
(
iteration_limit
=
10
))
pos
=
0.1
*
from_random
(
operator
.
domain
)
minimizer
=
NewtonCG
(
GradientNormController
(
iteration_limit
=
10
,
name
=
'findpos'
))
for
ii
in
range
(
3
):
logger
.
info
(
f
'Start iteration
{
ii
+
1
}
/3'
)
kl
=
MetricGaussianKL
(
pos
,
H
,
3
,
mirror_samples
=
True
)
kl
,
_
=
minimizer
(
kl
)
pos
=
kl
.
position
...
...
test/test_linearization.py
View file @
fd208746
...
...
@@ -57,14 +57,28 @@ def test_special_gradients():
'log'
,
'exp'
,
'sqrt'
,
'sin'
,
'cos'
,
'tan'
,
'sinc'
,
'sinh'
,
'cosh'
,
'tanh'
,
'absolute'
,
'reciprocal'
,
'sigmoid'
,
'log10'
,
'log1p'
,
"expm1"
])
def
test_actual_gradients
(
f
):
@
pmp
(
'cplxpos'
,
[
True
,
False
])
@
pmp
(
'cplxdir'
,
[
True
,
False
])
@
pmp
(
'holomorphic'
,
[
True
,
False
])
def
test_actual_gradients
(
f
,
cplxpos
,
cplxdir
,
holomorphic
):
if
(
cplxpos
or
cplxdir
)
and
f
in
[
'absolute'
]:
return
if
holomorphic
and
f
in
[
'absolute'
]:
# These function are not holomorphic
return
dom
=
ift
.
UnstructuredDomain
((
1
,))
fld
=
ift
.
full
(
dom
,
2.4
)
eps
=
1e-8
if
cplxpos
:
fld
=
fld
+
0.21j
eps
=
1e-7
if
cplxdir
:
eps
*=
1j
if
holomorphic
:
eps
*=
(
1
+
0.78j
)
var0
=
ift
.
Linearization
.
make_var
(
fld
)
var1
=
ift
.
Linearization
.
make_var
(
fld
+
eps
)
f0
=
var0
.
ptw
(
f
).
val
.
val
f1
=
var1
.
ptw
(
f
).
val
.
val
df0
=
(
f1
-
f0
)
/
eps
df1
=
_lin2grad
(
var0
.
ptw
(
f
))
assert_allclose
(
df0
,
df1
,
rtol
=
100
*
eps
)
assert_allclose
(
df0
,
df1
,
rtol
=
100
*
np
.
abs
(
eps
)
)
test/test_plot.py
View file @
fd208746
...
...
@@ -27,7 +27,6 @@ name = (f'plot{nr}.png' for nr in count())
def
test_plots
():
# FIXME Write to temporary folder?
rg_space1
=
ift
.
makeDomain
(
ift
.
RGSpace
((
10
,)))
rg_space2
=
ift
.
makeDomain
(
ift
.
RGSpace
((
8
,
6
),
distances
=
1
))
hp_space
=
ift
.
makeDomain
(
ift
.
HPSpace
(
5
))
...
...
@@ -75,4 +74,5 @@ def test_mf_plot():
plot
=
ift
.
Plot
()
plot
.
add
(
f1
,
block
=
False
,
title
=
'f_space_idx = 1'
)
plot
.
add
(
f2
,
freq_space_idx
=
0
,
title
=
'f_space_idx = 0'
)
plot
.
output
(
nx
=
2
,
ny
=
1
,
title
=
'MF-Plots, should look identical'
,
name
=
next
(
name
))
plot
.
output
(
nx
=
2
,
ny
=
1
,
title
=
'MF-Plots, should look identical'
,
name
=
next
(
name
))
test/test_sugar.py
View file @
fd208746
...
...
@@ -52,9 +52,18 @@ def test_exec_time():
ift
.
exec_time
(
oo
,
wm
)
def
test_calc_pos
():
import
pytest
pmp
=
pytest
.
mark
.
parametrize
@
pmp
(
'mf'
,
[
False
,
True
])
@
pmp
(
'cplx'
,
[
False
,
True
])
def
test_calc_pos
(
mf
,
cplx
):
dom
=
ift
.
RGSpace
(
12
,
harmonic
=
True
)
op
=
ift
.
HarmonicTransformOperator
(
dom
).
ptw
(
"exp"
)
if
mf
:
op
=
op
.
ducktape_left
(
'foo'
)
dom
=
ift
.
makeDomain
({
''
:
dom
})
if
cplx
:
op
=
op
+
1j
*
op
fld
=
op
(
0.1
*
ift
.
from_random
(
op
.
domain
,
'normal'
))
pos
=
ift
.
calculate_position
(
op
,
fld
)
ift
.
extra
.
assert_allclose
(
op
(
pos
),
fld
,
1e-1
,
1e-1
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment