Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Neel Shah
NIFTy
Commits
918324e1
Commit
918324e1
authored
May 31, 2021
by
Philipp Arras
Browse files
ptw -> direct calls
parent
cc38c3de
Changes
1
Hide whitespace changes
Inline
Side-by-side
src/operators/energy_operators.py
View file @
918324e1
...
...
@@ -200,9 +200,9 @@ class VariableCovarianceGaussianEnergy(LikelihoodOperator):
self
.
_check_input
(
x
)
r
,
i
=
x
[
self
.
_kr
],
x
[
self
.
_ki
]
if
self
.
_cplx
:
res
=
0.5
*
r
.
vdot
(
r
*
i
.
real
).
real
-
i
.
ptw
(
"
log
"
).
sum
()
res
=
0.5
*
r
.
vdot
(
r
*
i
.
real
).
real
-
i
.
log
(
).
sum
()
else
:
res
=
0.5
*
(
r
.
vdot
(
r
*
i
)
-
i
.
ptw
(
"
log
"
).
sum
())
res
=
0.5
*
(
r
.
vdot
(
r
*
i
)
-
i
.
log
(
).
sum
())
if
not
x
.
want_metric
:
return
res
if
self
.
_use_full_fisher
:
...
...
@@ -260,16 +260,16 @@ class _SpecialGammaEnergy(LikelihoodOperator):
self
.
_check_input
(
x
)
r
=
self
.
_resi
if
self
.
_cplx
:
res
=
0.5
*
(
r
*
x
.
real
).
vdot
(
r
).
real
-
x
.
ptw
(
"
log
"
).
sum
()
res
=
0.5
*
(
r
*
x
.
real
).
vdot
(
r
).
real
-
x
.
log
(
).
sum
()
else
:
res
=
0.5
*
((
r
*
x
).
vdot
(
r
)
-
x
.
ptw
(
"
log
"
).
sum
())
res
=
0.5
*
((
r
*
x
).
vdot
(
r
)
-
x
.
log
(
).
sum
())
if
not
x
.
want_metric
:
return
res
return
res
.
add_metric
(
self
.
get_metric_at
(
x
.
val
))
def
get_transformation
(
self
):
sc
=
1.
if
self
.
_cplx
else
np
.
sqrt
(
0.5
)
return
self
.
_dt
,
sc
*
ScalingOperator
(
self
.
_domain
,
1.
).
ptw
(
'
log
'
)
return
self
.
_dt
,
sc
*
ScalingOperator
(
self
.
_domain
,
1.
).
log
(
)
class
GaussianEnergy
(
LikelihoodOperator
):
...
...
@@ -399,7 +399,7 @@ class PoissonianEnergy(LikelihoodOperator):
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
res
=
x
.
sum
()
-
x
.
ptw
(
"
log
"
).
vdot
(
self
.
_d
)
res
=
x
.
sum
()
-
x
.
log
(
).
vdot
(
self
.
_d
)
if
not
x
.
want_metric
:
return
res
return
res
.
add_metric
(
self
.
get_metric_at
(
x
.
val
))
...
...
@@ -446,14 +446,14 @@ class InverseGammaLikelihood(LikelihoodOperator):
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
res
=
x
.
ptw
(
"
log
"
).
vdot
(
self
.
_alphap1
)
+
x
.
ptw
(
"
reciprocal
"
).
vdot
(
self
.
_beta
)
res
=
x
.
log
(
).
vdot
(
self
.
_alphap1
)
+
x
.
reciprocal
(
).
vdot
(
self
.
_beta
)
if
not
x
.
want_metric
:
return
res
return
res
.
add_metric
(
self
.
get_metric_at
(
x
.
val
))
def
get_transformation
(
self
):
fact
=
self
.
_alphap1
.
ptw
(
'
sqrt
'
)
res
=
makeOp
(
fact
)
@
ScalingOperator
(
self
.
_domain
,
1.
).
ptw
(
'
log
'
)
fact
=
self
.
_alphap1
.
sqrt
(
)
res
=
makeOp
(
fact
)
@
ScalingOperator
(
self
.
_domain
,
1.
).
log
(
)
return
self
.
_sampling_dtype
,
res
...
...
@@ -481,7 +481,7 @@ class StudentTEnergy(LikelihoodOperator):
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
res
=
(((
self
.
_theta
+
1
)
/
2
)
*
(
x
**
2
/
self
.
_theta
).
ptw
(
"
log1p
"
)).
sum
()
res
=
(((
self
.
_theta
+
1
)
/
2
)
*
(
x
**
2
/
self
.
_theta
).
log1p
(
)).
sum
()
if
not
x
.
want_metric
:
return
res
return
res
.
add_metric
(
self
.
get_metric_at
(
x
.
val
))
...
...
@@ -492,7 +492,7 @@ class StudentTEnergy(LikelihoodOperator):
else
:
from
..extra
import
full
th
=
full
(
self
.
_domain
,
self
.
_theta
)
return
np
.
float64
,
makeOp
(((
th
+
1
)
/
(
th
+
3
)).
ptw
(
'
sqrt
'
))
return
np
.
float64
,
makeOp
(((
th
+
1
)
/
(
th
+
3
)).
sqrt
(
))
class
BernoulliEnergy
(
LikelihoodOperator
):
...
...
@@ -522,16 +522,16 @@ class BernoulliEnergy(LikelihoodOperator):
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
res
=
-
x
.
ptw
(
"
log
"
).
vdot
(
self
.
_d
)
+
(
1.
-
x
).
ptw
(
"
log
"
).
vdot
(
self
.
_d
-
1.
)
res
=
-
x
.
log
(
).
vdot
(
self
.
_d
)
+
(
1.
-
x
).
log
(
).
vdot
(
self
.
_d
-
1.
)
if
not
x
.
want_metric
:
return
res
return
res
.
add_metric
(
self
.
get_metric_at
(
x
.
val
))
def
get_transformation
(
self
):
from
..extra
import
full
res
=
Adder
(
full
(
self
.
_domain
,
1.
))
@
ScalingOperator
(
self
.
_domain
,
-
1
)
res
=
res
*
ScalingOperator
(
self
.
_domain
,
1
).
ptw
(
'
reciprocal
'
)
return
np
.
float64
,
-
2.
*
res
.
ptw
(
'
sqrt
'
).
ptw
(
'
arctan
'
)
res
=
Adder
(
full
(
self
.
_domain
,
1.
))
@
ScalingOperator
(
self
.
_domain
,
-
1
)
res
=
res
*
ScalingOperator
(
self
.
_domain
,
1
).
reciprocal
(
)
return
np
.
float64
,
-
2.
*
res
.
sqrt
(
).
arctan
(
)
class
StandardHamiltonian
(
EnergyOperator
):
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment