Commit 97974970 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

domain -> target part 2 ... ouch

parent 4f9bda1c
Pipeline #72566 failed with stages
in 19 minutes and 53 seconds
......@@ -27,14 +27,14 @@ class Field(Operand):
"""The discrete representation of a continuous field over multiple spaces.
Stores data arrays and carries all the needed meta-information (i.e. the
domain) for operators to be able to operate on them.
target) for operators to be able to operate on them.
Parameters
----------
domain : DomainTuple
The domain of the new Field.
target : DomainTuple
The target of the new Field.
val : numpy.ndarray
This object's shape must match the domain shape
This object's shape must match the target shape
After construction, the object will no longer be writeable!
Notes
......@@ -45,17 +45,17 @@ class Field(Operand):
_scalar_dom = DomainTuple.scalar_domain()
def __init__(self, domain, val):
if not isinstance(domain, DomainTuple):
raise TypeError("domain must be of type DomainTuple")
def __init__(self, target, val):
if not isinstance(target, DomainTuple):
raise TypeError("target must be of type DomainTuple")
if not isinstance(val, np.ndarray):
if np.isscalar(val):
val = np.full(domain.shape, val)
val = np.full(target.shape, val)
else:
raise TypeError("val must be of type numpy.ndarray")
if domain.shape != val.shape:
raise ValueError("shape mismatch between val and domain")
self._domain = domain
if target.shape != val.shape:
raise ValueError("shape mismatch between val and target")
self._target = target
self._val = val
self._val.flags.writeable = False
......@@ -71,12 +71,12 @@ class Field(Operand):
raise TypeError("Field does not support implicit conversion to bool")
@staticmethod
def full(domain, val):
"""Creates a Field with a given domain, filled with a constant value.
def full(target, val):
"""Creates a Field with a given target, filled with a constant value.
Parameters
----------
domain : Domain, tuple of Domain, or DomainTuple
target : Domain, tuple of Domain, or DomainTuple
Domain of the new Field.
val : float/complex/int scalar
Fill value. Data type of the field is inferred from val.
......@@ -90,49 +90,49 @@ class Field(Operand):
raise TypeError("val must be a scalar")
if not (np.isreal(val) or np.iscomplex(val)):
raise TypeError("need arithmetic scalar")
domain = DomainTuple.make(domain)
return Field(domain, val)
target = DomainTuple.make(target)
return Field(target, val)
@staticmethod
def from_raw(domain, arr):
"""Returns a Field constructed from `domain` and `arr`.
def from_raw(target, arr):
"""Returns a Field constructed from `target` and `arr`.
Parameters
----------
domain : DomainTuple, tuple of Domain, or Domain
The domain of the new Field.
target : DomainTuple, tuple of Domain, or Domain
The target of the new Field.
arr : numpy.ndarray
The data content to be used for the new Field.
Its shape must match the shape of `domain`.
Its shape must match the shape of `target`.
"""
return Field(DomainTuple.make(domain), arr)
return Field(DomainTuple.make(target), arr)
def cast_domain(self, new_domain):
"""Returns a field with the same data, but a different domain
def cast_target(self, new_target):
"""Returns a field with the same data, but a different target
Parameters
----------
new_domain : Domain, tuple of Domain, or DomainTuple
The domain for the returned field. Must be shape-compatible to
new_target : Domain, tuple of Domain, or DomainTuple
The target for the returned field. Must be shape-compatible to
`self`.
Returns
-------
Field
Field defined on `new_domain`, but with the same data as `self`.
Field defined on `new_target`, but with the same data as `self`.
"""
return Field(DomainTuple.make(new_domain), self._val)
return Field(DomainTuple.make(new_target), self._val)
@staticmethod
def from_random(random_type, domain, dtype=np.float64, **kwargs):
def from_random(random_type, target, dtype=np.float64, **kwargs):
"""Draws a random field with the given parameters.
Parameters
----------
random_type : 'pm1', 'normal', or 'uniform'
The random distribution to use.
domain : DomainTuple
The domain of the output random Field.
target : DomainTuple
The target of the output random Field.
dtype : type
The datatype of the output random Field.
......@@ -142,10 +142,10 @@ class Field(Operand):
The newly created Field.
"""
from .random import Random
domain = DomainTuple.make(domain)
target = DomainTuple.make(target)
generator_function = getattr(Random, random_type)
arr = generator_function(dtype=dtype, shape=domain.shape, **kwargs)
return Field(domain, arr)
arr = generator_function(dtype=dtype, shape=target.shape, **kwargs)
return Field(target, arr)
@property
def fld(self):
......@@ -171,32 +171,26 @@ class Field(Operand):
"""type : the data type of the field's entries"""
return self._val.dtype
@property
def domain(self):
"""DomainTuple : the field's domain"""
raise NotImplementedError
return None# self._domain
@property
def target(self):
"""DomainTuple : the field's domain"""
return self._domain
"""DomainTuple : the field's target"""
return self._target
@property
def shape(self):
"""tuple of int : the concatenated shapes of all sub-domains"""
return self._domain.shape
return self._target.shape
@property
def size(self):
"""int : total number of pixels in the field"""
return self._domain.size
return self._target.size
@property
def real(self):
"""Field : The real part of the field"""
if utilities.iscomplextype(self.dtype):
return Field(self._domain, self._val.real)
return Field(self._target, self._val.real)
return self
@property
......@@ -204,7 +198,7 @@ class Field(Operand):
"""Field : The imaginary part of the field"""
if not utilities.iscomplextype(self.dtype):
raise ValueError(".imag called on a non-complex Field")
return Field(self._domain, self._val.imag)
return Field(self._target, self._val.imag)
def scalar_weight(self, spaces=None):
"""Returns the uniform volume element for a sub-domain of `self`.
......@@ -212,8 +206,8 @@ class Field(Operand):
Parameters
----------
spaces : int, tuple of int or None
Indices of the sub-domains of the field's domain to be considered.
If `None`, the entire domain is used.
Indices of the sub-domains of the field's target to be considered.
If `None`, the entire target is used.
Returns
-------
......@@ -221,23 +215,23 @@ class Field(Operand):
If the requested sub-domain has a uniform volume element, it is
returned. Otherwise, `None` is returned.
"""
return self._domain.scalar_weight(spaces)
return self._target.scalar_weight(spaces)
def total_volume(self, spaces=None):
"""Returns the total volume of the field's domain or of a subspace of it.
"""Returns the total volume of the field's target or of a subspace of it.
Parameters
----------
spaces : int, tuple of int or None
Indices of the sub-domains of the field's domain to be considered.
If `None`, the total volume of the whole domain is returned.
Indices of the sub-domains of the field's target to be considered.
If `None`, the total volume of the whole target is returned.
Returns
-------
float
the total volume of the requested (sub-)domain.
"""
return self._domain.total_volume(spaces)
return self._target.total_volume(spaces)
def weight(self, power=1, spaces=None):
"""Weights the pixels of `self` with their invidual pixel volumes.
......@@ -249,7 +243,7 @@ class Field(Operand):
spaces : None, int or tuple of int
Determines on which sub-domain the operation takes place.
If None, the entire domain is used.
If None, the entire target is used.
Returns
-------
......@@ -258,24 +252,24 @@ class Field(Operand):
"""
aout = self.val_rw()
spaces = utilities.parse_spaces(spaces, len(self._domain))
spaces = utilities.parse_spaces(spaces, len(self._target))
fct = 1.
for ind in spaces:
wgt = self._domain[ind].dvol
wgt = self._target[ind].dvol
if np.isscalar(wgt):
fct *= wgt
else:
new_shape = np.ones(len(self.shape), dtype=np.int)
new_shape[self._domain.axes[ind][0]:
self._domain.axes[ind][-1]+1] = wgt.shape
new_shape[self._target.axes[ind][0]:
self._target.axes[ind][-1]+1] = wgt.shape
wgt = wgt.reshape(new_shape)
aout *= wgt**power
fct = fct**power
if fct != 1.:
aout *= fct
return Field(self._domain, aout)
return Field(self._target, aout)
def outer(self, x):
"""Computes the outer product of 'self' with x.
......@@ -301,7 +295,7 @@ class Field(Operand):
Parameters
----------
x : Field
x must be defined on the same domain as `self`.
x must be defined on the same target as `self`.
spaces : None, int or tuple of int
The dot product is only carried out over the sub-domains in this
......@@ -316,10 +310,10 @@ class Field(Operand):
raise TypeError("The dot-partner must be an instance of " +
"the Field class")
if x._domain != self._domain:
if x._target != self._target:
raise ValueError("Domain mismatch")
ndom = len(self._domain)
ndom = len(self._target)
spaces = utilities.parse_spaces(spaces, ndom)
if len(spaces) == ndom:
......@@ -334,7 +328,7 @@ class Field(Operand):
Parameters
----------
x : Field
x must be defined on the same domain as `self`.
x must be defined on the same target as `self`.
Returns
-------
......@@ -345,7 +339,7 @@ class Field(Operand):
raise TypeError("The dot-partner must be an instance of " +
"the Field class")
if x._domain != self._domain:
if x._target != self._target:
raise ValueError("Domain mismatch")
return np.vdot(self._val, x._val)
......@@ -374,7 +368,7 @@ class Field(Operand):
The complex conjugated field.
"""
if utilities.iscomplextype(self._val.dtype):
return Field(self._domain, self._val.conjugate())
return Field(self._target, self._val.conjugate())
return self
# ---General unary/contraction methods---
......@@ -383,18 +377,18 @@ class Field(Operand):
return self
def __neg__(self):
return Field(self._domain, -self._val)
return Field(self._target, -self._val)
def __abs__(self):
return Field(self._domain, abs(self._val))
return Field(self._target, abs(self._val))
def _contraction_helper(self, op, spaces):
if spaces is None:
return Field.scalar(getattr(self._val, op)())
spaces = utilities.parse_spaces(spaces, len(self._domain))
spaces = utilities.parse_spaces(spaces, len(self._target))
axes_list = tuple(self._domain.axes[sp_index] for sp_index in spaces)
axes_list = tuple(self._target.axes[sp_index] for sp_index in spaces)
if len(axes_list) > 0:
axes_list = reduce(lambda x, y: x+y, axes_list)
......@@ -406,11 +400,11 @@ class Field(Operand):
if np.isscalar(data):
return Field.scalar(data)
else:
return_domain = tuple(dom
for i, dom in enumerate(self._domain)
return_target = tuple(dom
for i, dom in enumerate(self._target)
if i not in spaces)
return Field(DomainTuple.make(return_domain), data)
return Field(DomainTuple.make(return_target), data)
def sum(self, spaces=None):
"""Sums up over the sub-domains given by `spaces`.
......@@ -601,7 +595,7 @@ class Field(Operand):
# MR FIXME: not very efficient or accurate
m1 = self.mean(spaces)
from .operators.contraction_operator import ContractionOperator
op = ContractionOperator(self._domain, spaces)
op = ContractionOperator(self._target, spaces)
m1 = op.adjoint_times(m1)
if utilities.iscomplextype(self.dtype):
sq = abs(self-m1)**2
......@@ -667,18 +661,18 @@ class Field(Operand):
return "<nifty6.Field>"
def __str__(self):
return "nifty6.Field instance\n- domain = " + \
self._domain.__str__() + \
return "nifty6.Field instance\n- target = " + \
self._target.__str__() + \
"\n- val = " + repr(self._val)
def extract(self, dom):
if dom != self._domain:
raise ValueError("domain mismatch")
if dom != self._target:
raise ValueError("target mismatch")
return self
def extract_part(self, dom):
if dom != self._domain:
raise ValueError("domain mismatch")
if dom != self._target:
raise ValueError("target mismatch")
return self
def unite(self, other):
......@@ -688,14 +682,14 @@ class Field(Operand):
return self-other if neg else self+other
def _binary_op(self, other, op):
# if other is a field, make sure that the domains match
# if other is a field, make sure that the targets match
f = getattr(self._val, op)
if isinstance(other, Field):
if other._domain != self._domain:
raise ValueError("domains are incompatible.")
return Field(self._domain, f(other._val))
if other._target != self._target:
raise ValueError("targets are incompatible.")
return Field(self._target, f(other._val))
if np.isscalar(other):
return Field(self._domain, f(other))
return Field(self._target, f(other))
return NotImplemented
def _prep_args(self, args, kwargs):
......@@ -711,13 +705,13 @@ class Field(Operand):
def ptw(self, op, *args, **kwargs):
from .pointwise import ptw_dict
argstmp, kwargstmp = self._prep_args(args, kwargs)
return Field(self._domain, ptw_dict[op][0](self._val, *argstmp, **kwargstmp))
return Field(self._target, ptw_dict[op][0](self._val, *argstmp, **kwargstmp))
def ptw_with_deriv(self, op, *args, **kwargs):
from .pointwise import ptw_dict
argstmp, kwargstmp = self._prep_args(args, kwargs)
tmp = ptw_dict[op][1](self._val, *argstmp, **kwargstmp)
return (Field(self._domain, tmp[0]), Field(self._domain, tmp[1]))
return (Field(self._target, tmp[0]), Field(self._target, tmp[1]))
for op in ["__add__", "__radd__",
"__sub__", "__rsub__",
......
......@@ -54,7 +54,7 @@ def _toArray_rw(fld):
def _toField(arr, template):
if isinstance(template, Field):
return Field(template.domain, arr.reshape(template.shape).copy())
return Field(template.target, arr.reshape(template.shape).copy())
ofs = 0
res = []
for v in template.values():
......@@ -70,7 +70,7 @@ def _toField(arr, template):
class _MinHelper(object):
def __init__(self, energy):
self._energy = energy
self._domain = energy.position.domain
self._domain = energy.position.target
def _update(self, x):
pos = _toField(x, self._energy.position)
......
......@@ -25,77 +25,72 @@ from .operand import Operand
class MultiField(Operand):
def __init__(self, domain, val):
def __init__(self, target, val):
"""The discrete representation of a continuous field over a sum space.
Parameters
----------
domain: MultiDomain
target: MultiDomain
val: tuple containing Field entries
"""
if not isinstance(domain, MultiDomain):
raise TypeError("domain must be of type MultiDomain")
if not isinstance(target, MultiDomain):
raise TypeError("target must be of type MultiDomain")
if not isinstance(val, tuple):
raise TypeError("val must be a tuple")
if len(val) != len(domain):
if len(val) != len(target):
raise ValueError("length mismatch")
for d, v in zip(domain._domains, val):
for d, v in zip(target._domains, val):
if isinstance(v, Field):
if v.target != d:
print(v.target)
print(d)
raise ValueError("domain mismatch")
raise ValueError("target mismatch")
else:
raise TypeError("bad entry in val (must be Field)")
self._domain = domain
self._target = target
self._val = val
@staticmethod
def from_dict(dict, domain=None):
if domain is None:
def from_dict(dict, target=None):
if target is None:
for dd in dict.values():
if not isinstance(dd.target, DomainTuple):
raise TypeError('Values of dictionary need to be Fields '
'defined on DomainTuples.')
domain = MultiDomain.make({key: v._domain
target = MultiDomain.make({key: v._target
for key, v in dict.items()})
res = tuple(dict[key] if key in dict else Field(dom, 0.)
for key, dom in zip(domain.keys(), domain.domains()))
return MultiField(domain, res)
for key, dom in zip(target.keys(), target.domains()))
return MultiField(target, res)
def to_dict(self):
return {key: val for key, val in zip(self._domain.keys(), self._val)}
return {key: val for key, val in zip(self._target.keys(), self._val)}
def __getitem__(self, key):
return self._val[self._domain.idx[key]]
return self._val[self._target.idx[key]]
def __contains__(self, key):
return key in self._domain.idx
return key in self._target.idx
def keys(self):
return self._domain.keys()
return self._target.keys()
def items(self):
return zip(self._domain.keys(), self._val)
return zip(self._target.keys(), self._val)
def values(self):
return self._val
@property
def domain(self):
raise NotImplementedError
return None #self._domain
@property
def target(self):
return self._domain
return self._target
# @property
# def dtype(self):
# return {key: val.dtype for key, val in self._val.items()}
def _transform(self, op):
return MultiField(self._domain, tuple(op(v) for v in self._val))
return MultiField(self._target, tuple(op(v) for v in self._val))
@property
def real(self):
......@@ -108,20 +103,20 @@ class MultiField(Operand):
return self._transform(lambda x: x.imag)
@staticmethod
def from_random(random_type, domain, dtype=np.float64, **kwargs):
domain = MultiDomain.make(domain)
# dtype = MultiField.build_dtype(dtype, domain)
def from_random(random_type, target, dtype=np.float64, **kwargs):
target = MultiDomain.make(target)
# dtype = MultiField.build_dtype(dtype, target)
return MultiField(
domain, tuple(Field.from_random(random_type, dom, dtype, **kwargs)
for dom in domain._domains))
target, tuple(Field.from_random(random_type, dom, dtype, **kwargs)
for dom in target._domains))
def _check_domain(self, other):
if other._domain != self._domain:
raise ValueError("domains are incompatible.")
def _check_target(self, other):
if other._target != self._target:
raise ValueError("targets are incompatible.")
def s_vdot(self, x):
result = 0.
self._check_domain(x)
self._check_target(x)
for v1, v2 in zip(self._val, x._val):
result += v1.s_vdot(v2)
return result
......@@ -130,18 +125,18 @@ class MultiField(Operand):
return Field.scalar(self.s_vdot(x))
# @staticmethod
# def build_dtype(dtype, domain):
# def build_dtype(dtype, target):
# if isinstance(dtype, dict):
# return dtype
# if dtype is None:
# dtype = np.float64
# return {key: dtype for key in domain.keys()}
# return {key: dtype for key in target.keys()}
@staticmethod
def full(domain, val):
domain = MultiDomain.make(domain)
return MultiField(domain, tuple(Field(dom, val)
for dom in domain._domains))
def full(target, val):
target = MultiDomain.make(target)
return MultiField(target, tuple(Field(dom, val)
for dom in target._domains))
@property
def fld(self):
......@@ -150,17 +145,17 @@ class MultiField(Operand):
@property
def val(self):
return {key: val.val
for key, val in zip(self._domain.keys(), self._val)}
for key, val in zip(self._target.keys(), self._val)}
def val_rw(self):
return {key: val.val_rw()
for key, val in zip(self._domain.keys(), self._val)}
for key, val in zip(self._target.keys(), self._val)}
@staticmethod
def from_raw(domain, arr):
def from_raw(target, arr):
return MultiField(
domain, tuple(Field(domain[key], arr[key])
for key in domain.keys()))
target, tuple(Field(target[key], arr[key])
for key in target.keys()))
def norm(self, ord=2):
"""Computes the norm of the field values.
......@@ -200,7 +195,7 @@ class MultiField(Operand):
size : int
The sum of the size of the individual fields
"""
return utilities.my_sum(map(lambda d: d.size, self._domain.domains()))
return utilities.my_sum(map(lambda d: d.size, self._target.domains()))
def __neg__(self):
return self._transform(lambda x: -x)
......@@ -227,13 +222,13 @@ class MultiField(Operand):
return False