Commit b61562e0 authored by Philipp Arras's avatar Philipp Arras

Simplifications for addition of energies

parent f8c5c895
Pipeline #31244 passed with stages
in 1 minute and 24 seconds
......@@ -137,43 +137,17 @@ class Energy(NiftyMetaBase()):
return Add(self, (-1) * other)
class Add(Energy):
"""
Please not: If you add two operators which share some keys in the position
but have different values there, it is not guaranteed which value will be
used for the sum. You shouldn't do that anyways.
"""
def __init__(self, op1, op2):
if isinstance(op1.position, MultiField) and isinstance(op2.position, MultiField):
a = op1.position._val
b = op2.position._val
# Note: In python >3.5 one could do {**a, **b}
ab = a.copy()
ab.update(b)
position = MultiField(ab)
elif isinstance(op1.position, Field) and isinstance(op2.position, Field):
position = op1.position
else:
raise TypeError
super(Add, self).__init__(position)
self._op1 = op1.at(position)
self._op2 = op2.at(position)
def at(self, position):
return self.__class__(self._op1.at(position), self._op2.at(position))
@property
@memo
def value(self):
return self._op1.value + self._op2.value
@property
@memo
def gradient(self):
return self._op1.gradient + self._op2.gradient
@property
@memo
def curvature(self):
return self._op1.curvature + self._op2.curvature
def Add(energy1, energy2):
if isinstance(energy1.position, MultiField) and isinstance(energy2.position, MultiField):
a = energy1.position._val
b = energy2.position._val
# Note: In python >3.5 one could do {**a, **b}
ab = a.copy()
ab.update(b)
position = MultiField(ab)
elif isinstance(energy1.position, Field) and isinstance(energy2.position, Field):
position = energy1.position
else:
raise TypeError
from .energy_sum import EnergySum
return EnergySum(position, [energy1, energy2])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment