Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
2ec7955a
Commit
2ec7955a
authored
May 20, 2020
by
Philipp Arras
Browse files
Support multifield and complex output for find_position
parent
f40a1834
Pipeline
#75315
failed with stages
in 39 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty6/sugar.py
View file @
2ec7955a
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
#
# Copyright(C) 2013-20
19
Max-Planck-Society
# Copyright(C) 2013-20
20
Max-Planck-Society
#
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
...
@@ -20,21 +20,24 @@ from time import time
...
@@ -20,21 +20,24 @@ from time import time
import
numpy
as
np
import
numpy
as
np
from
.logger
import
logger
from
.
import
pointwise
,
utilities
from
.
import
utilities
from
.domain_tuple
import
DomainTuple
from
.domain_tuple
import
DomainTuple
from
.domains.power_space
import
PowerSpace
from
.domains.power_space
import
PowerSpace
from
.field
import
Field
from
.field
import
Field
from
.logger
import
logger
from
.minimization.descent_minimizers
import
NewtonCG
from
.minimization.iteration_controllers
import
GradientNormController
from
.minimization.metric_gaussian_kl
import
MetricGaussianKL
from
.multi_domain
import
MultiDomain
from
.multi_domain
import
MultiDomain
from
.multi_field
import
MultiField
from
.multi_field
import
MultiField
from
.operators.block_diagonal_operator
import
BlockDiagonalOperator
from
.operators.block_diagonal_operator
import
BlockDiagonalOperator
from
.operators.diagonal_operator
import
DiagonalOperator
from
.operators.diagonal_operator
import
DiagonalOperator
from
.operators.distributors
import
PowerDistributor
from
.operators.distributors
import
PowerDistributor
from
.operators.energy_operators
import
GaussianEnergy
,
StandardHamiltonian
from
.operators.operator
import
Operator
from
.operators.operator
import
Operator
from
.operators.sampling_enabler
import
SamplingDtypeSetter
from
.operators.scaling_operator
import
ScalingOperator
from
.operators.scaling_operator
import
ScalingOperator
from
.plot
import
Plot
from
.plot
import
Plot
from
.
import
pointwise
__all__
=
[
'PS_field'
,
'power_analyze'
,
'create_power_operator'
,
__all__
=
[
'PS_field'
,
'power_analyze'
,
'create_power_operator'
,
'create_harmonic_smoothing_operator'
,
'from_random'
,
'create_harmonic_smoothing_operator'
,
'from_random'
,
...
@@ -491,31 +494,28 @@ def exec_time(obj, want_metric=True):
...
@@ -491,31 +494,28 @@ def exec_time(obj, want_metric=True):
def
calculate_position
(
operator
,
output
):
def
calculate_position
(
operator
,
output
):
"""Finds approximate preimage of an operator for a given output."""
"""Finds approximate preimage of an operator for a given output."""
from
.minimization.descent_minimizers
import
NewtonCG
from
.minimization.iteration_controllers
import
GradientNormController
from
.minimization.metric_gaussian_kl
import
MetricGaussianKL
from
.operators.scaling_operator
import
ScalingOperator
from
.operators.energy_operators
import
GaussianEnergy
,
StandardHamiltonian
if
not
isinstance
(
operator
,
Operator
):
if
not
isinstance
(
operator
,
Operator
):
raise
TypeError
raise
TypeError
if
output
.
domain
!=
operator
.
target
:
if
output
.
domain
!=
operator
.
target
:
raise
TypeError
raise
TypeError
if
isinstance
(
output
,
MultiField
):
if
isinstance
(
output
,
MultiField
):
cov
=
1e-3
*
max
([
vv
.
max
()
for
vv
in
output
.
val
.
values
()])
**
2
cov
=
1e-3
*
max
([
np
.
max
(
np
.
abs
(
vv
)
)
for
vv
in
output
.
val
.
values
()])
**
2
invcov
=
ScalingOperator
(
output
.
domain
,
cov
).
inverse
invcov
=
ScalingOperator
(
output
.
domain
,
cov
).
inverse
dtype
=
list
(
set
([
ff
.
dtype
for
ff
in
output
.
values
()]))
dtype
=
list
(
set
([
ff
.
dtype
for
ff
in
output
.
values
()]))
if
len
(
dtype
)
!=
1
:
if
len
(
dtype
)
!=
1
:
raise
ValueError
(
'Only MultiFields with one dtype supported.'
)
raise
ValueError
(
'Only MultiFields with one dtype supported.'
)
dtype
=
dtype
[
0
]
dtype
=
dtype
[
0
]
else
:
else
:
cov
=
1e-3
*
output
.
val
.
max
(
)
**
2
cov
=
1e-3
*
np
.
max
(
np
.
abs
(
output
.
val
)
)
**
2
dtype
=
output
.
dtype
dtype
=
output
.
dtype
invcov
=
ScalingOperator
(
output
.
domain
,
cov
).
inverse
invcov
=
ScalingOperator
(
output
.
domain
,
cov
).
inverse
d
=
output
+
invcov
.
draw_sample
(
dtype
,
from_inverse
=
True
)
invcov
=
SamplingDtypeSetter
(
invcov
,
output
.
dtype
)
invcov
=
SamplingDtypeSetter
(
invcov
,
output
.
dtype
)
d
=
output
+
invcov
.
draw_sample
(
from_inverse
=
True
)
lh
=
GaussianEnergy
(
d
,
invcov
)
@
operator
lh
=
GaussianEnergy
(
d
,
invcov
)
@
operator
H
=
StandardHamiltonian
(
H
=
StandardHamiltonian
(
lh
,
ic_samp
=
GradientNormController
(
iteration_limit
=
200
))
lh
,
ic_samp
=
GradientNormController
(
iteration_limit
=
200
))
pos
=
0.1
*
from_random
(
'normal'
,
operator
.
domain
)
pos
=
0.1
*
from_random
(
operator
.
domain
)
minimizer
=
NewtonCG
(
GradientNormController
(
iteration_limit
=
10
,
name
=
'findpos'
))
minimizer
=
NewtonCG
(
GradientNormController
(
iteration_limit
=
10
,
name
=
'findpos'
))
for
ii
in
range
(
3
):
for
ii
in
range
(
3
):
logger
.
info
(
f
'Start iteration
{
ii
+
1
}
/3'
)
logger
.
info
(
f
'Start iteration
{
ii
+
1
}
/3'
)
...
...
test/test_sugar.py
View file @
2ec7955a
...
@@ -52,9 +52,18 @@ def test_exec_time():
...
@@ -52,9 +52,18 @@ def test_exec_time():
ift
.
exec_time
(
oo
,
wm
)
ift
.
exec_time
(
oo
,
wm
)
def
test_calc_pos
():
import
pytest
pmp
=
pytest
.
mark
.
parametrize
@
pmp
(
'mf'
,
[
False
,
True
])
@
pmp
(
'cplx'
,
[
False
,
True
])
def
test_calc_pos
(
mf
,
cplx
):
dom
=
ift
.
RGSpace
(
12
,
harmonic
=
True
)
dom
=
ift
.
RGSpace
(
12
,
harmonic
=
True
)
op
=
ift
.
HarmonicTransformOperator
(
dom
).
ptw
(
"exp"
)
op
=
ift
.
HarmonicTransformOperator
(
dom
).
ptw
(
"exp"
)
if
mf
:
op
=
op
.
ducktape_left
(
'foo'
)
dom
=
ift
.
makeDomain
({
''
:
dom
})
if
cplx
:
op
=
op
+
1j
*
op
fld
=
op
(
0.1
*
ift
.
from_random
(
op
.
domain
,
'normal'
))
fld
=
op
(
0.1
*
ift
.
from_random
(
op
.
domain
,
'normal'
))
pos
=
ift
.
calculate_position
(
op
,
fld
)
pos
=
ift
.
calculate_position
(
op
,
fld
)
ift
.
extra
.
assert_allclose
(
op
(
pos
),
fld
,
1e-1
,
1e-1
)
ift
.
extra
.
assert_allclose
(
op
(
pos
),
fld
,
1e-1
,
1e-1
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment