Commit 48e52d16 authored by Vincent Eberle's avatar Vincent Eberle
Browse files

Work in progress. Scalable Energys by Reimar and Vince

parent e68753f6
Pipeline #67245 failed with stages
in 21 minutes and 7 seconds
......@@ -44,6 +44,30 @@ class EnergyOperator(Operator):
_target = DomainTuple.scalar_domain()
def scale(self, scalar):
if scalar >= 0:
return ScaledEnergy(self._target, scalar)(self)
return super(self, Operator).scale(scalar)
class ScaledEnergy(EnergyOperator):
def __init__(self, domain, scalar):
self._domain = domain
assert scalar >= 0
self._scalar = scalar
def apply(self, x):
res = ScalingOperator(self._target, self._scalar)(x)
if isinstance(x, Linearization):
if x.metric is not None:
met = SandwichOperator.make(ScalingOperator(self._target, np.sqrt(self._scalar)), x.metric)
res = res.add_metric(met)
return res
class Squared2NormOperator(EnergyOperator):
"""Computes the square of the L2-norm of the output of an operator.
......@@ -46,6 +46,16 @@ def test_gaussian(field):
energy = ift.GaussianEnergy(domain=field.domain)
ift.extra.check_jacobian_consistency(energy, field)
def test_ScaledEnergy(field):
energy = ift.GaussianEnergy(domain=field.domain)
ift.extra.check_jacobian_consistency(energy.scale(0.3), field)
#lin = ift.Linearization.make_var(field)
met1 = ift.GaussianEnergy(field)
met2 = ift.GaussianEnergy.scale(0.3)(field)
assert met1 == met2 / 0.3
assert isinstance(met2, ift.SandwichOperator)
def test_studentt(field):
energy = ift.StudentTEnergy(domain=field.domain, theta=.5)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment