Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
1299dd5e
Commit
1299dd5e
authored
Jun 03, 2020
by
Reimar H Leike
Browse files
Merge branch 'fixesinvcov' into 'NIFTy_7'
Fix dtype handling See merge request
!525
parents
682a34b7
bc9b978a
Pipeline
#76028
passed with stages
in 23 minutes and 41 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
src/operators/energy_operators.py
View file @
1299dd5e
...
...
@@ -154,16 +154,21 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
self
.
_ki
=
str
(
inverse_covariance_key
)
dom
=
DomainTuple
.
make
(
domain
)
self
.
_domain
=
MultiDomain
.
make
({
self
.
_kr
:
dom
,
self
.
_ki
:
dom
})
self
.
_dt
=
sampling_dtype
_check_sampling_dtype
(
self
.
_domain
,
sampling_dtype
)
self
.
_dt
=
{
self
.
_kr
:
sampling_dtype
,
self
.
_ki
:
np
.
float64
}
_check_sampling_dtype
(
self
.
_domain
,
self
.
_dt
)
self
.
_cplx
=
np
.
issubdtype
(
sampling_dtype
,
np
.
complexfloating
)
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
r
,
i
=
x
[
self
.
_kr
],
x
[
self
.
_ki
]
res
=
0.5
*
(
r
.
vdot
(
r
*
i
.
real
).
real
-
i
.
ptw
(
"log"
).
sum
())
if
self
.
_cplx
:
res
=
0.5
*
r
.
vdot
(
r
*
i
.
real
).
real
-
i
.
ptw
(
"log"
).
sum
()
else
:
res
=
0.5
*
(
r
.
vdot
(
r
*
i
)
-
i
.
ptw
(
"log"
).
sum
())
if
not
x
.
want_metric
:
return
res
met
=
MultiField
.
from_dict
({
self
.
_kr
:
i
.
val
,
self
.
_ki
:
.
5
*
i
.
val
**
(
-
2
)})
met
=
i
.
val
if
self
.
_cplx
else
0.5
*
i
.
val
met
=
MultiField
.
from_dict
({
self
.
_kr
:
i
.
val
,
self
.
_ki
:
met
**
(
-
2
)})
return
res
.
add_metric
(
SamplingDtypeSetter
(
makeOp
(
met
),
self
.
_dt
))
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment