Commit 82c170d1 authored by Martin Reinecke's avatar Martin Reinecke

add EnergySum class from D4PO

parent 0c0bd6b3
Pipeline #30074 passed with stages
in 9 minutes and 5 seconds
......@@ -721,6 +721,15 @@ class Field(object):
self._domain.__str__() + \
"\n- val = " + repr(self.val)
def equivalent(self, other):
if self is other:
return True
if not isinstance(other, Field):
return False
if self._domain != other._domain:
return False
return (self._val == other._val).all()
for op in ["__add__", "__radd__", "__iadd__",
"__sub__", "__rsub__", "__isub__",
"__mul__", "__rmul__", "__imul__",
......
......@@ -14,10 +14,11 @@ from .scipy_minimizer import ScipyMinimizer, NewtonCG, L_BFGS_B, ScipyCG
from .energy import Energy
from .quadratic_energy import QuadraticEnergy
from .line_energy import LineEnergy
from .energy_sum import EnergySum
__all__ = ["LineSearch", "LineSearchStrongWolfe",
"IterationController", "GradientNormController",
"Minimizer", "ConjugateGradient", "NonlinearCG", "DescentMinimizer",
"SteepestDescent", "VL_BFGS", "RelaxedNewton", "ScipyMinimizer",
"NewtonCG", "L_BFGS_B", "ScipyCG", "Energy", "QuadraticEnergy",
"LineEnergy", "L_BFGS"]
"LineEnergy", "L_BFGS", "EnergySum"]
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from .energy import Energy
from ..utilities import memo
class EnergySum(Energy):
def __init__(self, position, energies, minimizer_controller=None,
preconditioner=None, precon_idx=None):
super(EnergySum, self).__init__(position=position)
self._energies = [energy.at(position) for energy in energies]
self._min_controller = minimizer_controller
self._preconditioner = preconditioner
self._precon_idx = precon_idx
def at(self, position):
return self.__class__(position, self._energies, self._min_controller,
self._preconditioner, self._precon_idx)
@property
@memo
def value(self):
res = self._energies[0].value
for e in self._energies[1:]:
res += e.value
return res
@property
@memo
def gradient(self):
res = self._energies[0].gradient.copy()
for e in self._energies[1:]:
res += e.gradient
return res.lock()
@property
@memo
def curvature(self):
res = self._energies[0].curvature
for e in self._energies[1:]:
res = res + e.curvature
if self._min_controller is None:
return res
precon = self._preconditioner
if precon is None and self._precon_idx is not None:
precon = self._energies[self._precon_idx].curvature
from ..operators.inversion_enabler import InversionEnabler
from .conjugate_gradient import ConjugateGradient
return InversionEnabler(
res, ConjugateGradient(self._min_controller), precon)
......@@ -29,6 +29,8 @@ class MultiField(object):
val : dict
"""
self._val = val
self._domain = MultiDomain.make(
{key: val.domain for key, val in self._val.items()})
def __getitem__(self, key):
return self._val[key]
......@@ -44,8 +46,7 @@ class MultiField(object):
@property
def domain(self):
return MultiDomain.make(
{key: val.domain for key, val in self._val.items()})
return self._domain
@property
def dtype(self):
......@@ -71,7 +72,7 @@ class MultiField(object):
return self
def _check_domain(self, other):
if other.domain != self.domain:
if other._domain != self._domain:
raise ValueError("domains are incompatible.")
def vdot(self, x):
......@@ -147,6 +148,17 @@ class MultiField(object):
return MultiField({key: sub_field.conjugate()
for key, sub_field in self.items()})
def equivalent(self, other):
if self is other:
return True
if not isinstance(other, MultiField):
return False
if self._domain != other._domain:
return False
for key, val in self._val.items():
if not val.equivalent(other[key]):
return False
return True
for op in ["__add__", "__radd__", "__iadd__",
"__sub__", "__rsub__", "__isub__",
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment