Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
151b37f4
Commit
151b37f4
authored
Jun 21, 2020
by
Philipp Arras
Browse files
Refactoring and add test
parent
d1ab800e
Pipeline
#77030
passed with stages
in 12 minutes and 22 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
src/minimization/metric_gaussian_kl.py
View file @
151b37f4
...
...
@@ -51,16 +51,21 @@ def _modify_sample_domain(sample, domain):
"""Takes only keys from sample which are also in domain and inserts zeros
in sample if key is not in domain."""
from
..multi_domain
import
MultiDomain
if
not
isinstance
(
sample
,
MultiField
):
assert
sample
.
domain
is
domain
return
sample
assert
isinstance
(
domain
,
MultiDomain
)
if
sample
.
domain
is
domain
:
from
..field
import
Field
from
..domain_tuple
import
DomainTuple
from
..sugar
import
makeDomain
domain
=
makeDomain
(
domain
)
if
isinstance
(
domain
,
DomainTuple
)
and
isinstance
(
sample
,
Field
):
if
sample
.
domain
is
not
domain
:
raise
TypeError
return
sample
out
=
{
kk
:
vv
for
kk
,
vv
in
sample
.
items
()
if
kk
in
domain
.
keys
()}
out
=
MultiField
.
from_dict
(
out
,
domain
)
assert
domain
is
out
.
domain
return
out
elif
isinstance
(
domain
,
MultiDomain
)
and
isinstance
(
sample
,
MultiField
):
if
sample
.
domain
is
domain
:
return
sample
out
=
{
kk
:
vv
for
kk
,
vv
in
sample
.
items
()
if
kk
in
domain
.
keys
()}
out
=
MultiField
.
from_dict
(
out
,
domain
)
return
out
raise
TypeError
class
MetricGaussianKL
(
Energy
):
...
...
test/test_operators/test_simplification.py
View file @
151b37f4
...
...
@@ -15,7 +15,7 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from
numpy.testing
import
assert_
,
assert_allclose
from
numpy.testing
import
assert_
,
assert_allclose
,
assert_raises
import
nifty7
as
ift
from
nifty7.operators.simplify_for_const
import
ConstantOperator
...
...
@@ -41,3 +41,28 @@ def test_simplification():
assert_allclose
(
op
(
f1
)[
"a"
].
val
,
op2
.
force
(
f1
)[
"a"
].
val
)
assert_allclose
(
op
(
f1
)[
"b"
].
val
,
op2
.
force
(
f1
)[
"b"
].
val
)
# FIXME Add test for ChainOperator._simplify_for_constant_input_nontrivial()
def
test_modify_sample_domain
():
func
=
ift
.
minimization
.
metric_gaussian_kl
.
_modify_sample_domain
dom0
=
ift
.
RGSpace
(
1
)
dom1
=
ift
.
RGSpace
(
2
)
field
=
ift
.
full
(
dom0
,
1.
)
ift
.
extra
.
assert_equal
(
func
(
field
,
dom0
),
field
)
mdom0
=
ift
.
makeDomain
({
'a'
:
dom0
,
'b'
:
dom1
})
mdom1
=
ift
.
makeDomain
({
'a'
:
dom0
})
mfield0
=
ift
.
full
(
mdom0
,
1.
)
mfield1
=
ift
.
full
(
mdom1
,
1.
)
mfield01
=
ift
.
MultiField
.
from_dict
({
'a'
:
ift
.
full
(
dom0
,
1.
),
'b'
:
ift
.
full
(
dom1
,
0.
)})
ift
.
extra
.
assert_equal
(
func
(
mfield0
,
mdom0
),
mfield0
)
ift
.
extra
.
assert_equal
(
func
(
mfield0
,
mdom1
),
mfield1
)
ift
.
extra
.
assert_equal
(
func
(
mfield1
,
mdom0
),
mfield01
)
ift
.
extra
.
assert_equal
(
func
(
mfield1
,
mdom1
),
mfield1
)
with
assert_raises
(
TypeError
):
func
(
mfield0
,
dom0
)
with
assert_raises
(
TypeError
):
func
(
field
,
dom1
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment