Skip to content
Snippets Groups Projects
Commit 6e6b4f76 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

running, but not correctly

parent 4067aa04
Branches
Tags
No related merge requests found
......@@ -49,6 +49,8 @@ class Field(object):
def __init__(self, domain, val):
if not isinstance(domain, DomainTuple):
raise TypeError("domain must be of type DomainTuple")
if np.isscalar(val):
val = np.full((), val)
if not isinstance(val, dobj.data_object):
raise TypeError("val must be of type dobj.data_object")
if domain.shape != val.shape:
......
......@@ -11,9 +11,10 @@ from .multi.multi_field import MultiField
class Linearization(object):
def __init__(self, val, jac):
def __init__(self, val, jac, metric=None):
self._val = val
self._jac = jac
self._metric = metric
@property
def domain(self):
......@@ -31,17 +32,25 @@ class Linearization(object):
def jac(self):
return self._jac
@property
def metric(self):
return self._metric
def __neg__(self):
return Linearization(-self._val, self._jac*(-1))
return Linearization(-self._val, self._jac*(-1),
None if self._metric is None else self._metric*(-1))
def __add__(self, other):
if isinstance(other, Linearization):
from .operators.relaxed_sum_operator import RelaxedSumOperator
met = None
if self._metric is not None and other._metric is not None:
met = RelaxedSumOperator((self._metric, other._metric))
return Linearization(
self._val.unite(other._val),
RelaxedSumOperator((self._jac, other._jac)))
RelaxedSumOperator((self._jac, other._jac)), met)
if isinstance(other, (int, float, complex, Field, MultiField)):
return Linearization(self._val+other, self._jac)
return Linearization(self._val+other, self._jac, self._metric)
def __radd__(self, other):
return self.__add__(other)
......@@ -76,6 +85,11 @@ class Linearization(object):
d1 = makeOp(other)
return Linearization(self._val*other, d1*self._jac)
def sum(self):
from .sugar import full
from .operators.vdot_operator import VdotOperator
return Linearization(full((),self._val.sum()),
VdotOperator(full(self._jac.target,1))*self._jac)
@staticmethod
def make_var(field):
from .operators.scaling_operator import ScalingOperator
......
......@@ -60,6 +60,6 @@ class RelaxedSumOperator(LinearOperator):
self._check_mode(mode)
res = None
for op in self._ops:
x = op.apply(x.extract(op._dom(mode)), mode)
res = x if res is None else res.unite(x)
tmp = op.apply(x.extract(op._dom(mode)), mode)
res = tmp if res is None else res.unite(tmp)
return res
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment