Commit 0bef1520 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

introduce new class hierarchy

parent 193a276f
......@@ -284,7 +284,7 @@ def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100, perf_check=True):
linmid = op(Linearization.make_var(locmid))
dirder = linmid.jac(dir)
numgrad = (lin2.val-lin.val)
xtol = tol * dirder.norm() / np.sqrt(dirder.size)
xtol = tol * dirder.norm() / np.sqrt(dirder.target.size)
hist.append((numgrad-dirder).norm())
# print(len(hist),hist[-1])
if (abs(numgrad-dirder) <= xtol).s_all():
......
......@@ -19,10 +19,11 @@ from functools import reduce
import numpy as np
from . import utilities
from .operators.operator import Operator
from .domain_tuple import DomainTuple
class Field(object):
class Field(Operator):
"""The discrete representation of a continuous field over multiple spaces.
Stores data arrays and carries all the needed meta-information (i.e. the
......@@ -161,6 +162,26 @@ class Field(object):
"""
return self._val.copy()
@property
def jac(self):
return None
@property
def want_metric(self):
return False
@property
def metric(self):
raise NotImplementedError()
def __call__(self, other):
if (other.target == self.domain):
return self
raise ValueError("domain mismatch")
def __matmul__(self, other):
return self(other)
@property
def dtype(self):
"""type : the data type of the field's entries"""
......@@ -172,14 +193,9 @@ class Field(object):
return self._domain
@property
def shape(self):
"""tuple of int : the concatenated shapes of all sub-domains"""
return self._domain.shape
@property
def size(self):
"""int : total number of pixels in the field"""
return self._domain.size
def target(self):
"""DomainTuple : the field's domain"""
return self._domain
@property
def real(self):
......@@ -255,7 +271,7 @@ class Field(object):
if np.isscalar(wgt):
fct *= wgt
else:
new_shape = np.ones(len(self.shape), dtype=np.int)
new_shape = np.ones(len(self._domain.shape), dtype=np.int)
new_shape[self._domain.axes[ind][0]:
self._domain.axes[ind][-1]+1] = wgt.shape
wgt = wgt.reshape(new_shape)
......
......@@ -17,13 +17,14 @@
import numpy as np
from .operators.operator import Operator
from .field import Field
from .multi_field import MultiField
from .sugar import makeOp
from . import utilities
class Linearization(object):
class Linearization(Operator):
"""Let `A` be an operator and `x` a field. `Linearization` stores the value
of the operator application (i.e. `A(x)`), the local Jacobian
(i.e. `dA(x)/dx`) and, optionally, the local metric.
......@@ -118,6 +119,14 @@ class Linearization(object):
"""
return self._metric
def __call__(self, other):
if (other.target == self.domain):
return self
raise ValueError("domain mismatch")
def __matmul__(self, other):
return self(other)
def __getitem__(self, name):
return self.new(self._val[name], self._jac.ducktape_left(name))
......
......@@ -54,7 +54,7 @@ def _toArray_rw(fld):
def _toField(arr, template):
if isinstance(template, Field):
return Field(template.domain, arr.reshape(template.shape).copy())
return Field(template.domain, arr.reshape(template.domain.shape).copy())
ofs = 0
res = []
for v in template.values():
......
......@@ -18,12 +18,13 @@
import numpy as np
from . import utilities
from .operators.operator import Operator
from .field import Field
from .multi_domain import MultiDomain
from .domain_tuple import DomainTuple
class MultiField(object):
class MultiField(Operator):
def __init__(self, domain, val):
"""The discrete representation of a continuous field over a sum space.
......@@ -82,6 +83,10 @@ class MultiField(object):
def domain(self):
return self._domain
@property
def target(self):
return self._domain
# @property
# def dtype(self):
# return {key: val.dtype for key, val in self._val.items()}
......@@ -144,6 +149,26 @@ class MultiField(object):
return {key: val.val_rw()
for key, val in zip(self._domain.keys(), self._val)}
@property
def jac(self):
return None
@property
def want_metric(self):
return False
@property
def metric(self):
raise NotImplementedError()
def __call__(self, other):
if (other.target == self.domain):
return self
raise ValueError("domain mismatch")
def __matmul__(self, other):
return self(other)
@staticmethod
def from_raw(domain, arr):
return MultiField(
......@@ -179,17 +204,6 @@ class MultiField(object):
"""
return utilities.my_sum(map(lambda v: v.s_sum(), self._val))
@property
def size(self):
"""Computes the overall degrees of freedom.
Returns
-------
size : int
The sum of the size of the individual fields
"""
return utilities.my_sum(map(lambda d: d.size, self._domain.domains()))
def __neg__(self):
return self._transform(lambda x: -x)
......
......@@ -262,7 +262,7 @@ class InverseGammaLikelihood(EnergyOperator):
self._domain = DomainTuple.make(beta.domain)
self._beta = beta
if np.isscalar(alpha):
alpha = Field(beta.domain, np.full(beta.shape, alpha))
alpha = Field(beta.domain, np.full(beta.target.shape, alpha))
elif not isinstance(alpha, Field):
raise TypeError
self._alphap1 = alpha+1
......
......@@ -17,8 +17,6 @@
import numpy as np
from ..field import Field
from ..multi_field import MultiField
from ..utilities import NiftyMeta, indent
......@@ -179,6 +177,8 @@ class Operator(metaclass=NiftyMeta):
return self.apply(x.extract(self.domain))
def _check_input(self, x):
from ..field import Field
from ..multi_field import MultiField
from ..linearization import Linearization
from .scaling_operator import ScalingOperator
if not isinstance(x, (Field, MultiField, Linearization)):
......
......@@ -44,6 +44,6 @@ class OuterProduct(LinearOperator):
return Field(
self._target, np.multiply.outer(
self._field.val, x.val))
axes = len(self._field.shape)
axes = len(self._field.target.shape)
return Field(
self._domain, np.tensordot(self._field.val, x.val, axes))
......@@ -29,8 +29,7 @@ SPACE_COMBINATIONS = [(), SPACES[0], SPACES[1], SPACES]
@pmp('domain', SPACE_COMBINATIONS)
@pmp('attribute_desired_type',
[['domain', ift.DomainTuple], ['val', np.ndarray],
['shape', tuple], ['size', (np.int, np.int64)]])
[['domain', ift.DomainTuple], ['val', np.ndarray]])
def test_return_types(domain, attribute_desired_type):
attribute = attribute_desired_type[0]
desired_type = attribute_desired_type[1]
......@@ -288,18 +287,18 @@ def test_stdfunc():
s = ift.RGSpace((200,))
f = ift.Field.full(s, 27)
assert_equal(f.val, 27)
assert_equal(f.shape, (200,))
assert_equal(f.target.shape, (200,))
assert_equal(f.dtype, np.int)
fx = ift.full(f.domain, 0)
assert_equal(f.dtype, fx.dtype)
assert_equal(f.shape, fx.shape)
assert_equal(f.target.shape, fx.target.shape)
assert_equal(fx.val, 0)
fx = ift.full(f.domain, 1)
assert_equal(f.dtype, fx.dtype)
assert_equal(f.shape, fx.shape)
assert_equal(f.target.shape, fx.target.shape)
assert_equal(fx.val, 1)
fx = ift.full(f.domain, 67.)
assert_equal(f.shape, fx.shape)
assert_equal(f.target.shape, fx.target.shape)
assert_equal(fx.val, 67.)
f = ift.Field.from_random("normal", s)
f2 = ift.Field.from_random("normal", s)
......
......@@ -40,7 +40,7 @@ def test_multifield_field_consistency():
f1 = ift.full(dom, 27)
f2 = ift.makeField(dom['d1'], f1['d1'].val)
assert_equal(f1.s_sum(), f2.s_sum())
assert_equal(f1.size, f2.size)
assert_equal(f1.target.size, f2.target.size)
def test_dataconv():
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment