Commit 3c536750 authored by Martin Reinecke's avatar Martin Reinecke

various tweaks

parent f0ed0b7f
......@@ -26,6 +26,8 @@ from .domain_tuple import DomainTuple
class Field(object):
_scalar_dom = DomainTuple.scalar_domain()
""" The discrete representation of a continuous field over multiple spaces.
In NIFTy, Fields are used to store data arrays and carry all the needed
......@@ -55,11 +57,15 @@ class Field(object):
else:
raise TypeError("val must be of type dobj.data_object")
if domain.shape != val.shape:
raise ValueError("mismatch between the shapes of val and domain")
raise ValueError("shape mismatch between val and domain")
self._domain = domain
self._val = val
dobj.lock(self._val)
@staticmethod
def scalar(val):
return Field(Field._scalar_dom, val)
# prevent implicit conversion to bool
def __nonzero__(self):
raise TypeError("Field does not support implicit conversion to bool")
......
......@@ -34,7 +34,7 @@ class Linearization(object):
@property
def gradient(self):
"""Only available if target is a scalar"""
return self._jac.adjoint_times(Field(self._jac.target, 1.))
return self._jac.adjoint_times(Field.scalar(1.))
@property
def metric(self):
......@@ -102,10 +102,10 @@ class Linearization(object):
from .operators.simple_linear_operators import VdotOperator
if isinstance(other, (Field, MultiField)):
return Linearization(
Field(DomainTuple.scalar_domain(), self._val.vdot(other)),
Field.scalar(self._val.vdot(other)),
VdotOperator(other)(self._jac))
return Linearization(
Field(DomainTuple.scalar_domain(), self._val.vdot(other._val)),
Field.scalar(self._val.vdot(other._val)),
VdotOperator(self._val)(other._jac) +
VdotOperator(other._val)(self._jac))
......@@ -113,7 +113,7 @@ class Linearization(object):
from .operators.simple_linear_operators import SumReductionOperator
from .sugar import full
return Linearization(
Field(DomainTuple.scalar_domain(), self._val.sum()),
Field.scalar(self._val.sum()),
SumReductionOperator(self._jac.target)(self._jac))
def exp(self):
......
......@@ -108,8 +108,10 @@ class MultiDomain(object):
@staticmethod
def union(inp):
inp = set(inp)
if len(inp) == 1: # all domains are identical
return inp.pop()
res = {}
# FIXME speed up!
for dom in inp:
for key, subdom in zip(dom._keys, dom._domains):
if key in res:
......
......@@ -54,7 +54,7 @@ class MultiField(object):
if domain is None:
domain = MultiDomain.make({key: v._domain
for key, v in dict.items()})
res = tuple(dict[key] if key in dict else Field.full(dom, 0)
res = tuple(dict[key] if key in dict else Field(dom, 0)
for key, dom in zip(domain.keys(), domain.domains()))
return MultiField(domain, res)
......@@ -124,7 +124,7 @@ class MultiField(object):
@staticmethod
def full(domain, val):
domain = MultiDomain.make(domain)
return MultiField(domain, tuple(Field.full(dom, val)
return MultiField(domain, tuple(Field(dom, val)
for dom in domain._domains))
def to_global_data(self):
......
......@@ -59,3 +59,8 @@ class EndomorphicOperator(LinearOperator):
A sample from the Gaussian of given covariance.
"""
raise NotImplementedError
def _check_input(self, x, mode):
self._check_mode(mode)
if self.domain is not x.domain:
raise ValueError("The operator's and field's domains don't match.")
......@@ -40,10 +40,10 @@ class SquaredNormOperator(EnergyOperator):
def apply(self, x):
if isinstance(x, Linearization):
val = Field(self._target, x.val.vdot(x.val))
val = Field.scalar(x.val.vdot(x.val))
jac = VdotOperator(2*x.val)(x.jac)
return Linearization(val, jac)
return Field(self._target, x.vdot(x))
return Field.scalar(x.vdot(x))
class QuadraticFormOperator(EnergyOperator):
......@@ -58,9 +58,9 @@ class QuadraticFormOperator(EnergyOperator):
if isinstance(x, Linearization):
t1 = self._op(x.val)
jac = VdotOperator(t1)(x.jac)
val = Field(self._target, 0.5*x.val.vdot(t1))
val = Field.scalar(0.5*x.val.vdot(t1))
return Linearization(val, jac)
return Field(self._target, 0.5*x.vdot(self._op(x)))
return Field.scalar(0.5*x.vdot(self._op(x)))
class GaussianEnergy(EnergyOperator):
......@@ -106,7 +106,7 @@ class PoissonianEnergy(EnergyOperator):
x = self._op(x)
res = x.sum() - x.log().vdot(self._d)
if not isinstance(x, Linearization):
return Field(self._target, res)
return Field.scalar(res)
metric = SandwichOperator.make(x.jac, makeOp(1./x.val))
return res.add_metric(metric)
......@@ -121,7 +121,7 @@ class BernoulliEnergy(EnergyOperator):
x = self._p(x)
v = x.log().vdot(-self._d) - (1.-x).log().vdot(1.-self._d)
if not isinstance(x, Linearization):
return Field(self._target, v)
return Field.scalar(v)
met = makeOp(1./(x.val*(1.-x.val)))
met = SandwichOperator.make(x.jac, met)
return v.add_metric(met)
......
......@@ -39,9 +39,9 @@ class VdotOperator(LinearOperator):
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
self._check_mode(mode)
if mode == self.TIMES:
return Field(self._target, self._field.vdot(x))
return Field.scalar(self._field.vdot(x))
return self._field*x.local_data[()]
......@@ -54,7 +54,7 @@ class SumReductionOperator(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return Field(self._target, x.sum())
return Field.scalar(x.sum())
return full(self._domain, x.local_data[()])
......@@ -90,7 +90,7 @@ class FieldAdapter(LinearOperator):
if mode == self.TIMES:
return x[self._name]
values = tuple(Field.full(dom, 0.) if key != self._name else x
values = tuple(Field(dom, 0.) if key != self._name else x
for key, dom in self._domain.items())
return MultiField(self._domain, values)
......@@ -142,7 +142,7 @@ class NullOperator(LinearOperator):
@staticmethod
def _nullfield(dom):
if isinstance(dom, DomainTuple):
return Field.full(dom, 0)
return Field(dom, 0)
else:
return MultiField.full(dom, 0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment