Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
9dea1d88
Commit
9dea1d88
authored
Jun 20, 2020
by
Philipp Arras
Browse files
Implement proper constant support 5/n
parent
b4f32295
Changes
5
Hide whitespace changes
Inline
Side-by-side
src/extra.py
View file @
9dea1d88
...
...
@@ -26,6 +26,7 @@ from .field import Field
from
.linearization
import
Linearization
from
.multi_domain
import
MultiDomain
from
.multi_field
import
MultiField
from
.operators.energy_operators
import
EnergyOperator
from
.operators.linear_operator
import
LinearOperator
from
.operators.operator
import
Operator
from
.sugar
import
from_random
...
...
@@ -320,14 +321,13 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
return
# FIXME ?
keys
=
op
.
domain
.
keys
()
combis
=
[]
for
ll
in
range
(
0
,
len
(
keys
)):
for
ll
in
range
(
1
,
len
(
keys
)):
combis
.
extend
(
list
(
combinations
(
keys
,
ll
)))
if
len
(
combis
)
>
max_combinations
:
random
.
seed
(
42
)
combis
=
random
.
sample
(
combis
,
int
(
max_combinations
))
for
cstkeys
in
combis
:
varkeys
=
set
(
keys
)
-
set
(
cstkeys
)
print
(
f
'Constant:
{
set
(
cstkeys
)
}
, Variable:
{
varkeys
}
'
)
cstloc
=
loc
.
extract_by_keys
(
cstkeys
)
varloc
=
loc
.
extract_by_keys
(
varkeys
)
...
...
@@ -348,13 +348,8 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
foo
=
oplin
.
jac
.
adjoint
(
rndinp
).
extract
(
cstloc
.
domain
)
assert_equal
(
foo
,
0
*
foo
)
# FIXME
# if isinstance(op, EnergyOperator):
# _allzero(oplin.gradient.extract(cstdom))
# if isinstance(op, EnergyOperator) and metric_sampling:
# samp0 = oplin.metric.draw_sample()
# _allzero(samp0.extract(cstdom))
# _nozero(samp0.extract(vardom))
if
isinstance
(
op
,
EnergyOperator
)
and
metric_sampling
:
oplin
.
metric
.
draw_sample
()
assert
op0
.
domain
is
varloc
.
domain
_jac_vs_finite_differences
(
op0
,
varloc
,
np
.
sqrt
(
tol
),
ntries
,
only_r_differentiable
)
...
...
src/operators/operator.py
View file @
9dea1d88
...
...
@@ -302,11 +302,7 @@ class Operator(metaclass=NiftyMeta):
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
.simplify_for_const
import
InsertionOperator
s
=
(
'SlowPartialConstantOperator used. You might want to consider'
' implementing `_simplify_for_constant_input_nontrivial()` for'
' this operator:'
)
logger
.
warning
(
s
)
logger
.
warning
(
self
.
__repr__
())
logger
.
warning
(
'SlowPartialConstantOperator used.'
)
return
None
,
self
@
InsertionOperator
(
self
.
domain
,
c_inp
)
def
ptw
(
self
,
op
,
*
args
,
**
kwargs
):
...
...
test/test_kl.py
View file @
9dea1d88
...
...
@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-20
19
Max-Planck-Society
# Copyright(C) 2013-20
20
Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
...
...
test/test_mpi/test_kl.py
View file @
9dea1d88
...
...
@@ -69,7 +69,7 @@ def test_kl(constants, point_estimates, mirror_samples, mode, mf):
_
,
tmph
=
h
.
simplify_for_constant_input
(
mean0
.
extract_by_keys
(
constants
))
else
:
tmph
=
h
kl1
=
ift
.
MetricGaussianKL
(
mean0
,
tmph
,
2
,
mirror_samples
,
comm
,
locsamp
,
False
,
True
)
kl1
=
ift
.
MetricGaussianKL
(
mean0
.
extract
(
tmph
.
domain
)
,
tmph
,
2
,
mirror_samples
,
comm
,
locsamp
,
False
,
True
)
elif
mode
==
1
:
kl0
=
ift
.
MetricGaussianKL
.
make
(
**
args
)
samples
=
kl0
.
_local_samples
...
...
@@ -80,7 +80,7 @@ def test_kl(constants, point_estimates, mirror_samples, mode, mf):
_
,
tmph
=
h
.
simplify_for_constant_input
(
mean0
.
extract_by_keys
(
constants
))
else
:
tmph
=
h
kl1
=
ift
.
MetricGaussianKL
(
mean0
,
tmph
,
2
,
mirror_samples
,
comm
,
locsamp
,
False
,
True
)
kl1
=
ift
.
MetricGaussianKL
(
mean0
.
extract
(
tmph
.
domain
)
,
tmph
,
2
,
mirror_samples
,
comm
,
locsamp
,
False
,
True
)
# Test number of samples
expected_nsamps
=
2
*
nsamps
if
mirror_samples
else
nsamps
...
...
@@ -92,31 +92,9 @@ def test_kl(constants, point_estimates, mirror_samples, mode, mf):
# Test gradient
if
mf
:
for
kk
in
h
.
domain
.
keys
():
for
kk
in
kl0
.
gradient
.
domain
.
keys
():
res0
=
kl0
.
gradient
[
kk
].
val
if
kk
in
constants
:
res0
=
0
*
res0
res1
=
kl1
.
gradient
[
kk
].
val
assert_equal
(
res0
,
res1
)
else
:
assert_equal
(
kl0
.
gradient
.
val
,
kl1
.
gradient
.
val
)
# Test point_estimates (after drawing samples)
if
mf
:
for
kk
in
point_estimates
:
for
ss
in
kl0
.
samples
:
ss
=
ss
[
kk
].
val
assert_equal
(
ss
,
0
*
ss
)
for
ss
in
kl1
.
samples
:
ss
=
ss
[
kk
].
val
assert_equal
(
ss
,
0
*
ss
)
# Test constants (after some minimization)
if
mf
:
cg
=
ift
.
GradientNormController
(
iteration_limit
=
5
)
minimizer
=
ift
.
NewtonCG
(
cg
)
for
e
in
[
kl0
,
kl1
]:
e
,
_
=
minimizer
(
e
)
diff
=
(
mean0
-
e
.
position
).
to_dict
()
for
kk
in
constants
:
assert_equal
(
diff
[
kk
].
val
,
0
*
diff
[
kk
].
val
)
test/test_operators/test_correlated_fields.py
View file @
9dea1d88
...
...
@@ -108,5 +108,7 @@ def testAmplitudesInvariants(sspace, N):
assert_
(
op
.
target
[
-
1
]
==
fsspace
)
for
ampl
in
fa
.
normalized_amplitudes
:
ift
.
extra
.
check_operator
(
ampl
,
ift
.
from_random
(
ampl
.
domain
),
ntries
=
10
)
ift
.
extra
.
check_operator
(
op
,
ift
.
from_random
(
op
.
domain
),
ntries
=
10
,
max_combinations
=
3
)
ift
.
extra
.
check_operator
(
ampl
,
ift
.
from_random
(
ampl
.
domain
),
ntries
=
3
,
max_combinations
=
5
)
ift
.
extra
.
check_operator
(
op
,
ift
.
from_random
(
op
.
domain
),
ntries
=
3
,
max_combinations
=
5
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment