Commit 3c536750 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

various tweaks

parent f0ed0b7f
...@@ -26,6 +26,8 @@ from .domain_tuple import DomainTuple ...@@ -26,6 +26,8 @@ from .domain_tuple import DomainTuple
class Field(object): class Field(object):
_scalar_dom = DomainTuple.scalar_domain()
""" The discrete representation of a continuous field over multiple spaces. """ The discrete representation of a continuous field over multiple spaces.
In NIFTy, Fields are used to store data arrays and carry all the needed In NIFTy, Fields are used to store data arrays and carry all the needed
...@@ -55,11 +57,15 @@ class Field(object): ...@@ -55,11 +57,15 @@ class Field(object):
else: else:
raise TypeError("val must be of type dobj.data_object") raise TypeError("val must be of type dobj.data_object")
if domain.shape != val.shape: if domain.shape != val.shape:
raise ValueError("mismatch between the shapes of val and domain") raise ValueError("shape mismatch between val and domain")
self._domain = domain self._domain = domain
self._val = val self._val = val
dobj.lock(self._val) dobj.lock(self._val)
@staticmethod
def scalar(val):
return Field(Field._scalar_dom, val)
# prevent implicit conversion to bool # prevent implicit conversion to bool
def __nonzero__(self): def __nonzero__(self):
raise TypeError("Field does not support implicit conversion to bool") raise TypeError("Field does not support implicit conversion to bool")
......
...@@ -34,7 +34,7 @@ class Linearization(object): ...@@ -34,7 +34,7 @@ class Linearization(object):
@property @property
def gradient(self): def gradient(self):
"""Only available if target is a scalar""" """Only available if target is a scalar"""
return self._jac.adjoint_times(Field(self._jac.target, 1.)) return self._jac.adjoint_times(Field.scalar(1.))
@property @property
def metric(self): def metric(self):
...@@ -102,10 +102,10 @@ class Linearization(object): ...@@ -102,10 +102,10 @@ class Linearization(object):
from .operators.simple_linear_operators import VdotOperator from .operators.simple_linear_operators import VdotOperator
if isinstance(other, (Field, MultiField)): if isinstance(other, (Field, MultiField)):
return Linearization( return Linearization(
Field(DomainTuple.scalar_domain(), self._val.vdot(other)), Field.scalar(self._val.vdot(other)),
VdotOperator(other)(self._jac)) VdotOperator(other)(self._jac))
return Linearization( return Linearization(
Field(DomainTuple.scalar_domain(), self._val.vdot(other._val)), Field.scalar(self._val.vdot(other._val)),
VdotOperator(self._val)(other._jac) + VdotOperator(self._val)(other._jac) +
VdotOperator(other._val)(self._jac)) VdotOperator(other._val)(self._jac))
...@@ -113,7 +113,7 @@ class Linearization(object): ...@@ -113,7 +113,7 @@ class Linearization(object):
from .operators.simple_linear_operators import SumReductionOperator from .operators.simple_linear_operators import SumReductionOperator
from .sugar import full from .sugar import full
return Linearization( return Linearization(
Field(DomainTuple.scalar_domain(), self._val.sum()), Field.scalar(self._val.sum()),
SumReductionOperator(self._jac.target)(self._jac)) SumReductionOperator(self._jac.target)(self._jac))
def exp(self): def exp(self):
......
...@@ -108,8 +108,10 @@ class MultiDomain(object): ...@@ -108,8 +108,10 @@ class MultiDomain(object):
@staticmethod @staticmethod
def union(inp): def union(inp):
inp = set(inp)
if len(inp) == 1: # all domains are identical
return inp.pop()
res = {} res = {}
# FIXME speed up!
for dom in inp: for dom in inp:
for key, subdom in zip(dom._keys, dom._domains): for key, subdom in zip(dom._keys, dom._domains):
if key in res: if key in res:
......
...@@ -54,7 +54,7 @@ class MultiField(object): ...@@ -54,7 +54,7 @@ class MultiField(object):
if domain is None: if domain is None:
domain = MultiDomain.make({key: v._domain domain = MultiDomain.make({key: v._domain
for key, v in dict.items()}) for key, v in dict.items()})
res = tuple(dict[key] if key in dict else Field.full(dom, 0) res = tuple(dict[key] if key in dict else Field(dom, 0)
for key, dom in zip(domain.keys(), domain.domains())) for key, dom in zip(domain.keys(), domain.domains()))
return MultiField(domain, res) return MultiField(domain, res)
...@@ -124,7 +124,7 @@ class MultiField(object): ...@@ -124,7 +124,7 @@ class MultiField(object):
@staticmethod @staticmethod
def full(domain, val): def full(domain, val):
domain = MultiDomain.make(domain) domain = MultiDomain.make(domain)
return MultiField(domain, tuple(Field.full(dom, val) return MultiField(domain, tuple(Field(dom, val)
for dom in domain._domains)) for dom in domain._domains))
def to_global_data(self): def to_global_data(self):
......
...@@ -59,3 +59,8 @@ class EndomorphicOperator(LinearOperator): ...@@ -59,3 +59,8 @@ class EndomorphicOperator(LinearOperator):
A sample from the Gaussian of given covariance. A sample from the Gaussian of given covariance.
""" """
raise NotImplementedError raise NotImplementedError
def _check_input(self, x, mode):
self._check_mode(mode)
if self.domain is not x.domain:
raise ValueError("The operator's and field's domains don't match.")
...@@ -40,10 +40,10 @@ class SquaredNormOperator(EnergyOperator): ...@@ -40,10 +40,10 @@ class SquaredNormOperator(EnergyOperator):
def apply(self, x): def apply(self, x):
if isinstance(x, Linearization): if isinstance(x, Linearization):
val = Field(self._target, x.val.vdot(x.val)) val = Field.scalar(x.val.vdot(x.val))
jac = VdotOperator(2*x.val)(x.jac) jac = VdotOperator(2*x.val)(x.jac)
return Linearization(val, jac) return Linearization(val, jac)
return Field(self._target, x.vdot(x)) return Field.scalar(x.vdot(x))
class QuadraticFormOperator(EnergyOperator): class QuadraticFormOperator(EnergyOperator):
...@@ -58,9 +58,9 @@ class QuadraticFormOperator(EnergyOperator): ...@@ -58,9 +58,9 @@ class QuadraticFormOperator(EnergyOperator):
if isinstance(x, Linearization): if isinstance(x, Linearization):
t1 = self._op(x.val) t1 = self._op(x.val)
jac = VdotOperator(t1)(x.jac) jac = VdotOperator(t1)(x.jac)
val = Field(self._target, 0.5*x.val.vdot(t1)) val = Field.scalar(0.5*x.val.vdot(t1))
return Linearization(val, jac) return Linearization(val, jac)
return Field(self._target, 0.5*x.vdot(self._op(x))) return Field.scalar(0.5*x.vdot(self._op(x)))
class GaussianEnergy(EnergyOperator): class GaussianEnergy(EnergyOperator):
...@@ -106,7 +106,7 @@ class PoissonianEnergy(EnergyOperator): ...@@ -106,7 +106,7 @@ class PoissonianEnergy(EnergyOperator):
x = self._op(x) x = self._op(x)
res = x.sum() - x.log().vdot(self._d) res = x.sum() - x.log().vdot(self._d)
if not isinstance(x, Linearization): if not isinstance(x, Linearization):
return Field(self._target, res) return Field.scalar(res)
metric = SandwichOperator.make(x.jac, makeOp(1./x.val)) metric = SandwichOperator.make(x.jac, makeOp(1./x.val))
return res.add_metric(metric) return res.add_metric(metric)
...@@ -121,7 +121,7 @@ class BernoulliEnergy(EnergyOperator): ...@@ -121,7 +121,7 @@ class BernoulliEnergy(EnergyOperator):
x = self._p(x) x = self._p(x)
v = x.log().vdot(-self._d) - (1.-x).log().vdot(1.-self._d) v = x.log().vdot(-self._d) - (1.-x).log().vdot(1.-self._d)
if not isinstance(x, Linearization): if not isinstance(x, Linearization):
return Field(self._target, v) return Field.scalar(v)
met = makeOp(1./(x.val*(1.-x.val))) met = makeOp(1./(x.val*(1.-x.val)))
met = SandwichOperator.make(x.jac, met) met = SandwichOperator.make(x.jac, met)
return v.add_metric(met) return v.add_metric(met)
......
...@@ -39,9 +39,9 @@ class VdotOperator(LinearOperator): ...@@ -39,9 +39,9 @@ class VdotOperator(LinearOperator):
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_mode(mode)
if mode == self.TIMES: if mode == self.TIMES:
return Field(self._target, self._field.vdot(x)) return Field.scalar(self._field.vdot(x))
return self._field*x.local_data[()] return self._field*x.local_data[()]
...@@ -54,7 +54,7 @@ class SumReductionOperator(LinearOperator): ...@@ -54,7 +54,7 @@ class SumReductionOperator(LinearOperator):
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
if mode == self.TIMES: if mode == self.TIMES:
return Field(self._target, x.sum()) return Field.scalar(x.sum())
return full(self._domain, x.local_data[()]) return full(self._domain, x.local_data[()])
...@@ -90,7 +90,7 @@ class FieldAdapter(LinearOperator): ...@@ -90,7 +90,7 @@ class FieldAdapter(LinearOperator):
if mode == self.TIMES: if mode == self.TIMES:
return x[self._name] return x[self._name]
values = tuple(Field.full(dom, 0.) if key != self._name else x values = tuple(Field(dom, 0.) if key != self._name else x
for key, dom in self._domain.items()) for key, dom in self._domain.items())
return MultiField(self._domain, values) return MultiField(self._domain, values)
...@@ -142,7 +142,7 @@ class NullOperator(LinearOperator): ...@@ -142,7 +142,7 @@ class NullOperator(LinearOperator):
@staticmethod @staticmethod
def _nullfield(dom): def _nullfield(dom):
if isinstance(dom, DomainTuple): if isinstance(dom, DomainTuple):
return Field.full(dom, 0) return Field(dom, 0)
else: else:
return MultiField.full(dom, 0) return MultiField.full(dom, 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment