Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
bae4a32b
Commit
bae4a32b
authored
Jun 28, 2018
by
Martin Reinecke
Browse files
improve implementation of EnergySum and allow for linear combination of energies
parent
ae802b86
Changes
2
Hide whitespace changes
Inline
Side-by-side
nifty5/minimization/energy.py
View file @
bae4a32b
...
...
@@ -16,6 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import
numpy
as
np
from
..field
import
Field
from
..multi
import
MultiField
from
..utilities
import
NiftyMetaBase
,
memo
...
...
@@ -128,30 +129,18 @@ class Energy(NiftyMetaBase()):
"""
return
None
def
__add__
(
self
,
other
):
if
not
isinstance
(
other
,
Energy
):
raise
TypeError
return
Add
(
self
,
other
)
def
__mul__
(
self
,
factor
):
from
.energy_sum
import
EnergySum
if
not
isinstance
(
factor
,
(
float
,
int
)):
raise
TypeError
(
"Factor must be a real-valued scalar"
)
return
EnergySum
.
make
([
self
],
[
factor
])
def
__sub__
(
self
,
other
):
def
__rmul__
(
self
,
factor
):
return
self
.
__mul__
(
factor
)
def
__add__
(
self
,
other
):
from
.energy_sum
import
EnergySum
from
..sugar
import
full
if
not
isinstance
(
other
,
Energy
):
raise
TypeError
return
Add
(
self
,
(
-
1
)
*
other
)
def
Add
(
energy1
,
energy2
):
if
(
isinstance
(
energy1
.
position
,
MultiField
)
and
isinstance
(
energy2
.
position
,
MultiField
)):
a
=
energy1
.
position
.
_val
b
=
energy2
.
position
.
_val
# Note: In python >3.5 one could do {**a, **b}
ab
=
a
.
copy
()
ab
.
update
(
b
)
position
=
MultiField
(
ab
)
elif
(
isinstance
(
energy1
.
position
,
Field
)
and
isinstance
(
energy2
.
position
,
Field
)):
position
=
energy1
.
position
else
:
raise
TypeError
from
.energy_sum
import
EnergySum
return
EnergySum
(
position
,
[
energy1
,
energy2
])
raise
TypeError
(
"can only add Energies to Energies"
)
return
EnergySum
.
make
([
self
,
other
])
nifty5/minimization/energy_sum.py
View file @
bae4a32b
...
...
@@ -21,44 +21,60 @@ from ..utilities import memo
class
EnergySum
(
Energy
):
def
__init__
(
self
,
position
,
energies
,
minimizer_controller
=
None
,
preconditioner
=
None
,
precon_idx
=
None
):
def
__init__
(
self
,
position
,
energies
,
factors
):
super
(
EnergySum
,
self
).
__init__
(
position
=
position
)
self
.
_energies
=
[
energy
.
at
(
position
)
for
energy
in
energies
]
self
.
_min_controller
=
minimizer_controller
self
.
_preconditioner
=
preconditioner
self
.
_precon_idx
=
precon_idx
self
.
_energies
=
tuple
(
e
.
at
(
position
)
for
e
in
energies
)
self
.
_factors
=
tuple
(
factors
)
@
staticmethod
def
make
(
energies
,
factors
=
None
):
if
factors
is
None
:
factors
=
(
1
,)
*
len
(
energies
)
# unpack energies
eout
=
[]
fout
=
[]
EnergySum
.
_unpackEnergies
(
energies
,
factors
,
1.
,
eout
,
fout
)
for
e
in
eout
[
1
:]:
if
not
e
.
position
.
isEquivalentTo
(
eout
[
0
].
position
):
raise
ValueError
(
"position mismatch"
)
return
EnergySum
(
eout
[
0
].
position
,
eout
,
fout
)
@
staticmethod
def
_unpackEnergies
(
e_in
,
f_in
,
prefactor
,
e_out
,
f_out
):
for
e
,
f
in
zip
(
e_in
,
f_in
):
if
isinstance
(
e
,
EnergySum
):
EnergySum
.
_unpackEnergies
(
e
.
_energies
,
e
.
_factors
,
prefactor
*
f
,
e_out
,
f_out
)
else
:
e_out
.
append
(
e
)
f_out
.
append
(
prefactor
*
f
)
def
at
(
self
,
position
):
return
self
.
__class__
(
position
,
self
.
_energies
,
self
.
_min_controller
,
self
.
_preconditioner
,
self
.
_precon_idx
)
return
self
.
__class__
(
position
,
self
.
_energies
,
self
.
_factors
)
@
property
@
memo
def
value
(
self
):
res
=
self
.
_energies
[
0
].
value
for
e
in
self
.
_energies
[
1
:]:
res
+=
e
.
value
res
=
self
.
_energies
[
0
].
value
*
self
.
_factors
[
0
]
for
e
,
f
in
zip
(
self
.
_energies
[
1
:]
,
self
.
_factors
[
1
:])
:
res
+=
e
.
value
*
f
return
res
@
property
@
memo
def
gradient
(
self
):
res
=
self
.
_energies
[
0
].
gradient
.
copy
()
for
e
in
self
.
_energies
[
1
:]:
res
+=
e
.
gradient
res
=
self
.
_energies
[
0
].
gradient
.
copy
()
if
self
.
_factors
[
0
]
==
1.
\
else
self
.
_energies
[
0
].
gradient
*
self
.
_factors
[
0
]
for
e
,
f
in
zip
(
self
.
_energies
[
1
:],
self
.
_factors
[
1
:]):
res
+=
e
.
gradient
if
f
==
1.
else
f
*
e
.
gradient
return
res
.
lock
()
@
property
@
memo
def
curvature
(
self
):
res
=
self
.
_energies
[
0
].
curvature
for
e
in
self
.
_energies
[
1
:]:
res
=
res
+
e
.
curvature
if
self
.
_min_controller
is
None
:
return
res
precon
=
self
.
_preconditioner
if
precon
is
None
and
self
.
_precon_idx
is
not
None
:
precon
=
self
.
_energies
[
self
.
_precon_idx
].
curvature
from
..operators.inversion_enabler
import
InversionEnabler
return
InversionEnabler
(
res
,
self
.
_min_controller
,
precon
)
res
=
self
.
_energies
[
0
].
curvature
if
self
.
_factors
[
0
]
==
1.
\
else
self
.
_energies
[
0
].
curvature
*
self
.
_factors
[
0
]
for
e
,
f
in
zip
(
self
.
_energies
[
1
:],
self
.
_factors
[
1
:]):
res
=
res
+
(
e
.
curvature
if
f
==
1.
else
e
.
curvature
*
f
)
return
res
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment