Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
2ec7955a
Commit
2ec7955a
authored
May 20, 2020
by
Philipp Arras
Browse files
Support multifield and complex output for find_position
parent
f40a1834
Pipeline
#75315
failed with stages
in 39 seconds
Changes
2
Pipelines
1
Show whitespace changes
Inline
Side-by-side
nifty6/sugar.py
View file @
2ec7955a
...
...
@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-20
19
Max-Planck-Society
# Copyright(C) 2013-20
20
Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
...
...
@@ -20,21 +20,24 @@ from time import time
import
numpy
as
np
from
.logger
import
logger
from
.
import
utilities
from
.
import
pointwise
,
utilities
from
.domain_tuple
import
DomainTuple
from
.domains.power_space
import
PowerSpace
from
.field
import
Field
from
.logger
import
logger
from
.minimization.descent_minimizers
import
NewtonCG
from
.minimization.iteration_controllers
import
GradientNormController
from
.minimization.metric_gaussian_kl
import
MetricGaussianKL
from
.multi_domain
import
MultiDomain
from
.multi_field
import
MultiField
from
.operators.block_diagonal_operator
import
BlockDiagonalOperator
from
.operators.diagonal_operator
import
DiagonalOperator
from
.operators.distributors
import
PowerDistributor
from
.operators.energy_operators
import
GaussianEnergy
,
StandardHamiltonian
from
.operators.operator
import
Operator
from
.operators.sampling_enabler
import
SamplingDtypeSetter
from
.operators.scaling_operator
import
ScalingOperator
from
.plot
import
Plot
from
.
import
pointwise
__all__
=
[
'PS_field'
,
'power_analyze'
,
'create_power_operator'
,
'create_harmonic_smoothing_operator'
,
'from_random'
,
...
...
@@ -491,31 +494,28 @@ def exec_time(obj, want_metric=True):
def
calculate_position
(
operator
,
output
):
"""Finds approximate preimage of an operator for a given output."""
from
.minimization.descent_minimizers
import
NewtonCG
from
.minimization.iteration_controllers
import
GradientNormController
from
.minimization.metric_gaussian_kl
import
MetricGaussianKL
from
.operators.scaling_operator
import
ScalingOperator
from
.operators.energy_operators
import
GaussianEnergy
,
StandardHamiltonian
if
not
isinstance
(
operator
,
Operator
):
raise
TypeError
if
output
.
domain
!=
operator
.
target
:
raise
TypeError
if
isinstance
(
output
,
MultiField
):
cov
=
1e-3
*
max
([
vv
.
max
()
for
vv
in
output
.
val
.
values
()])
**
2
cov
=
1e-3
*
max
([
np
.
max
(
np
.
abs
(
vv
)
)
for
vv
in
output
.
val
.
values
()])
**
2
invcov
=
ScalingOperator
(
output
.
domain
,
cov
).
inverse
dtype
=
list
(
set
([
ff
.
dtype
for
ff
in
output
.
values
()]))
if
len
(
dtype
)
!=
1
:
raise
ValueError
(
'Only MultiFields with one dtype supported.'
)
dtype
=
dtype
[
0
]
else
:
cov
=
1e-3
*
output
.
val
.
max
(
)
**
2
cov
=
1e-3
*
np
.
max
(
np
.
abs
(
output
.
val
)
)
**
2
dtype
=
output
.
dtype
invcov
=
ScalingOperator
(
output
.
domain
,
cov
).
inverse
d
=
output
+
invcov
.
draw_sample
(
dtype
,
from_inverse
=
True
)
invcov
=
SamplingDtypeSetter
(
invcov
,
output
.
dtype
)
invcov
=
SamplingDtypeSetter
(
invcov
,
output
.
dtype
)
d
=
output
+
invcov
.
draw_sample
(
from_inverse
=
True
)
lh
=
GaussianEnergy
(
d
,
invcov
)
@
operator
H
=
StandardHamiltonian
(
lh
,
ic_samp
=
GradientNormController
(
iteration_limit
=
200
))
pos
=
0.1
*
from_random
(
'normal'
,
operator
.
domain
)
pos
=
0.1
*
from_random
(
operator
.
domain
)
minimizer
=
NewtonCG
(
GradientNormController
(
iteration_limit
=
10
,
name
=
'findpos'
))
for
ii
in
range
(
3
):
logger
.
info
(
f
'Start iteration
{
ii
+
1
}
/3'
)
...
...
test/test_sugar.py
View file @
2ec7955a
...
...
@@ -52,9 +52,18 @@ def test_exec_time():
ift
.
exec_time
(
oo
,
wm
)
def
test_calc_pos
():
import
pytest
pmp
=
pytest
.
mark
.
parametrize
@
pmp
(
'mf'
,
[
False
,
True
])
@
pmp
(
'cplx'
,
[
False
,
True
])
def
test_calc_pos
(
mf
,
cplx
):
dom
=
ift
.
RGSpace
(
12
,
harmonic
=
True
)
op
=
ift
.
HarmonicTransformOperator
(
dom
).
ptw
(
"exp"
)
if
mf
:
op
=
op
.
ducktape_left
(
'foo'
)
dom
=
ift
.
makeDomain
({
''
:
dom
})
if
cplx
:
op
=
op
+
1j
*
op
fld
=
op
(
0.1
*
ift
.
from_random
(
op
.
domain
,
'normal'
))
pos
=
ift
.
calculate_position
(
op
,
fld
)
ift
.
extra
.
assert_allclose
(
op
(
pos
),
fld
,
1e-1
,
1e-1
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment