Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
00c9005f
Commit
00c9005f
authored
Jun 04, 2020
by
Philipp Arras
Browse files
Add nontrivial simplify for constant input
parent
0b26ae98
Pipeline
#76113
passed with stages
in 13 minutes and 43 seconds
Changes
5
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
src/operators/energy_operators.py
View file @
00c9005f
...
...
@@ -175,6 +175,47 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
met
=
MultiField
.
from_dict
({
self
.
_kr
:
i
.
val
,
self
.
_ki
:
met
**
(
-
2
)})
return
res
.
add_metric
(
SamplingDtypeSetter
(
makeOp
(
met
),
self
.
_dt
))
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
.simplify_for_const
import
ConstantEnergyOperator
assert
len
(
c_inp
.
keys
())
==
1
key
=
c_inp
.
keys
()[
0
]
assert
key
in
self
.
_domain
.
keys
()
cst
=
c_inp
[
key
]
if
key
==
self
.
_kr
:
res
=
_SpecialGammaEnergy
(
cst
).
ducktape
(
self
.
_ki
)
else
:
dt
=
self
.
_dt
[
self
.
_kr
]
res
=
GaussianEnergy
(
inverse_covariance
=
makeOp
(
cst
),
sampling_dtype
=
dt
).
ducktape
(
self
.
_kr
)
trlog
=
cst
.
log
().
sum
().
val_rw
()
if
not
_iscomplex
(
dt
):
trlog
/=
2
res
=
res
+
ConstantEnergyOperator
(
res
.
domain
,
-
trlog
)
res
=
res
+
ConstantEnergyOperator
(
self
.
_domain
,
0.
)
assert
res
.
domain
is
self
.
domain
assert
res
.
target
is
self
.
target
return
None
,
res
class
_SpecialGammaEnergy
(
EnergyOperator
):
def
__init__
(
self
,
residual
):
self
.
_domain
=
DomainTuple
.
make
(
residual
.
domain
)
self
.
_resi
=
residual
self
.
_cplx
=
_iscomplex
(
self
.
_resi
.
dtype
)
self
.
_scale
=
ScalingOperator
(
self
.
_domain
,
1
if
self
.
_cplx
else
.
5
)
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
r
=
self
.
_resi
if
self
.
_cplx
:
res
=
0.5
*
(
r
*
x
.
real
).
vdot
(
r
).
real
-
x
.
ptw
(
"log"
).
sum
()
else
:
res
=
0.5
*
((
r
*
x
).
vdot
(
r
)
-
x
.
ptw
(
"log"
).
sum
())
if
not
x
.
want_metric
:
return
res
met
=
makeOp
((
self
.
_scale
(
x
.
val
))
**
(
-
2
))
return
res
.
add_metric
(
SamplingDtypeSetter
(
met
,
self
.
_resi
.
dtype
))
class
GaussianEnergy
(
EnergyOperator
):
"""Computes a negative-log Gaussian.
...
...
src/operators/simple_linear_operators.py
View file @
00c9005f
...
...
@@ -347,6 +347,12 @@ class NullOperator(LinearOperator):
tgt
=
self
.
target
.
keys
()
if
isinstance
(
self
.
target
,
MultiDomain
)
else
'()'
return
f
'
{
tgt
}
<- NullOperator <-
{
dom
}
'
def
draw_sample
(
self
,
from_inverse
=
False
):
if
self
.
_domain
is
not
self
.
_target
:
raise
RuntimeError
from
..sugar
import
full
return
full
(
self
.
_domain
,
0.
)
class
PartialExtractor
(
LinearOperator
):
def
__init__
(
self
,
domain
,
target
):
...
...
src/operators/simplify_for_const.py
View file @
00c9005f
...
...
@@ -109,7 +109,10 @@ class SlowPartialConstantOperator(Operator):
class
ConstantEnergyOperator
(
EnergyOperator
):
def
__init__
(
self
,
dom
,
output
):
from
..sugar
import
makeDomain
from
..field
import
Field
self
.
_domain
=
makeDomain
(
dom
)
if
not
isinstance
(
output
,
Field
):
output
=
Field
.
scalar
(
float
(
output
))
if
self
.
target
is
not
output
.
domain
:
raise
TypeError
self
.
_output
=
output
...
...
test/test_energy_gradients.py
View file @
00c9005f
...
...
@@ -88,6 +88,15 @@ def test_variablecovariancegaussian(field):
energy
(
ift
.
Linearization
.
make_var
(
mf
,
want_metric
=
True
)).
metric
.
draw_sample
()
def
test_specialgamma
(
field
):
if
isinstance
(
field
.
domain
,
ift
.
MultiDomain
):
return
energy
=
ift
.
operators
.
energy_operators
.
_SpecialGammaEnergy
(
field
)
loc
=
ift
.
from_random
(
energy
.
domain
).
exp
()
ift
.
extra
.
check_jacobian_consistency
(
energy
,
loc
,
tol
=
1e-6
,
ntries
=
ntries
)
energy
(
ift
.
Linearization
.
make_var
(
loc
,
want_metric
=
True
)).
metric
.
draw_sample
()
def
test_inverse_gamma
(
field
):
if
isinstance
(
field
.
domain
,
ift
.
MultiDomain
):
return
...
...
test/test_gaussian_energy.py
View file @
00c9005f
...
...
@@ -28,6 +28,7 @@ def _flat_PS(k):
pmp
=
pytest
.
mark
.
parametrize
ntries
=
10
@
pmp
(
'space'
,
[
ift
.
GLSpace
(
5
),
...
...
@@ -70,4 +71,55 @@ def test_gaussian_energy(space, nonlinearity, noise, seed):
energy
=
ift
.
GaussianEnergy
(
d
,
N
)
@
d_model
()
ift
.
extra
.
check_jacobian_consistency
(
energy
,
xi0
,
ntries
=
10
,
tol
=
1e-6
)
energy
,
xi0
,
ntries
=
ntries
,
tol
=
1e-6
)
@
pmp
(
'cplx'
,
[
True
,
False
])
def
testgaussianenergy_compatibility
(
cplx
):
dt
=
np
.
complex128
if
cplx
else
np
.
float64
dom
=
ift
.
UnstructuredDomain
(
3
)
e
=
ift
.
VariableCovarianceGaussianEnergy
(
dom
,
'resi'
,
'icov'
,
dt
)
resi
=
ift
.
from_random
(
dom
)
if
cplx
:
resi
=
resi
+
1j
*
ift
.
from_random
(
dom
)
loc0
=
ift
.
MultiField
.
from_dict
({
'resi'
:
resi
})
loc1
=
ift
.
MultiField
.
from_dict
({
'icov'
:
ift
.
from_random
(
dom
).
exp
()})
loc
=
loc0
.
unite
(
loc1
)
val0
=
e
(
loc
).
val
_
,
e0
=
e
.
simplify_for_constant_input
(
loc0
)
val1
=
e0
(
loc
).
val
val2
=
e0
(
loc
.
unite
(
loc0
)).
val
np
.
testing
.
assert_equal
(
val1
,
val2
)
np
.
testing
.
assert_equal
(
val0
,
val1
)
_
,
e1
=
e
.
simplify_for_constant_input
(
loc1
)
val1
=
e1
(
loc
).
val
val2
=
e1
(
loc
.
unite
(
loc1
)).
val
np
.
testing
.
assert_equal
(
val0
,
val1
)
np
.
testing
.
assert_equal
(
val1
,
val2
)
ift
.
extra
.
check_jacobian_consistency
(
e
,
loc
,
ntries
=
ntries
)
ift
.
extra
.
check_jacobian_consistency
(
e0
,
loc
,
ntries
=
ntries
)
ift
.
extra
.
check_jacobian_consistency
(
e1
,
loc
,
ntries
=
ntries
)
# Test jacobian is zero
lin
=
ift
.
Linearization
.
make_var
(
loc
,
want_metric
=
True
)
grad
=
e
(
lin
).
gradient
.
val
grad0
=
e0
(
lin
).
gradient
.
val
grad1
=
e1
(
lin
).
gradient
.
val
samp
=
e
(
lin
).
metric
.
draw_sample
().
val
samp0
=
e0
(
lin
).
metric
.
draw_sample
().
val
samp1
=
e1
(
lin
).
metric
.
draw_sample
().
val
np
.
testing
.
assert_equal
(
samp0
[
'resi'
],
0.
)
np
.
testing
.
assert_equal
(
samp1
[
'icov'
],
0.
)
np
.
testing
.
assert_equal
(
grad0
[
'resi'
],
0.
)
np
.
testing
.
assert_equal
(
grad1
[
'icov'
],
0.
)
np
.
testing
.
assert_
(
all
(
samp
[
'resi'
]
!=
0
))
np
.
testing
.
assert_
(
all
(
samp
[
'icov'
]
!=
0
))
np
.
testing
.
assert_
(
all
(
samp0
[
'icov'
]
!=
0
))
np
.
testing
.
assert_
(
all
(
samp1
[
'resi'
]
!=
0
))
np
.
testing
.
assert_
(
all
(
grad
[
'resi'
]
!=
0
))
np
.
testing
.
assert_
(
all
(
grad
[
'icov'
]
!=
0
))
np
.
testing
.
assert_
(
all
(
grad0
[
'icov'
]
!=
0
))
np
.
testing
.
assert_
(
all
(
grad1
[
'resi'
]
!=
0
))
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment