Commit 0674df96 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

start experiments with data object interface

parent 6d7d5ede
......@@ -47,3 +47,5 @@ from .sugar import *
from . import plotting
from . import library
from .data_objects import numpy_do as dobj
# Data object module for NIFTy that uses simple numpy ndarrays.
import numpy as np
from numpy import ndarray as data_object
from numpy import full, empty, sqrt, ones, zeros, vdot, abs
def from_object(object, dtype=None, copy=True):
return np.array(object, dtype=dtype, copy=copy)
......@@ -24,6 +24,7 @@ from . import nifty_utilities as utilities
from .random import Random
from .domain_tuple import DomainTuple
from functools import reduce
from . import dobj
class Field(object):
......@@ -78,16 +79,16 @@ class Field(object):
if isinstance(val, Field):
if self.domain != val.domain:
raise ValueError("Domain mismatch")
self._val = np.array(val.val, dtype=dtype, copy=copy)
self._val = dobj.from_object(val.val, dtype=dtype, copy=copy)
elif (np.isscalar(val)):
self._val = np.full(self.domain.shape, dtype=dtype, fill_value=val)
elif isinstance(val, np.ndarray):
self._val = dobj.full(self.domain.shape, dtype=dtype, fill_value=val)
elif isinstance(val, dobj.data_object):
if self.domain.shape == val.shape:
self._val = np.array(val, dtype=dtype, copy=copy)
self._val = dobj.from_object(val, dtype=dtype, copy=copy)
else:
raise ValueError("Shape mismatch")
elif val is None:
self._val = np.empty(self.domain.shape, dtype=dtype)
self._val = dobj.empty(self.domain.shape, dtype=dtype)
else:
raise TypeError("unknown source type")
......@@ -253,7 +254,7 @@ class Field(object):
"synthetization.")
result_domain[i] = self.domain[i].harmonic_partner
spec = np.sqrt(self.val)
spec = dobj.sqrt(self.val)
for i in spaces:
power_space = self.domain[i]
local_blow_up = [slice(None)]*len(spec.shape)
......@@ -449,7 +450,7 @@ class Field(object):
The weighted field.
"""
new_field = Field(val=self, copy=not inplace)
new_field = self if inplace else self.copy()
if spaces is None:
spaces = range(len(self.domain))
......@@ -462,7 +463,7 @@ class Field(object):
if np.isscalar(wgt):
fct *= wgt
else:
new_shape = np.ones(len(self.shape), dtype=np.int)
new_shape = dobj.ones(len(self.shape), dtype=np.int)
new_shape[self.domain.axes[ind][0]:
self.domain.axes[ind][-1]+1] = wgt.shape
wgt = wgt.reshape(new_shape)
......@@ -504,7 +505,7 @@ class Field(object):
fct = tmp
if spaces is None:
return fct*np.vdot(y.val.ravel(), x.val.ravel())
return fct*dobj.vdot(y.val.ravel(), x.val.ravel())
else:
# create a diagonal operator which is capable of taking care of the
# axes-matching
......@@ -522,7 +523,7 @@ class Field(object):
The L2-norm of the field values.
"""
return np.sqrt(np.abs(self.vdot(x=self)))
return dobj.sqrt(dobj.abs(self.vdot(x=self)))
def conjugate(self):
""" Returns the complex conjugate of the field.
......@@ -544,7 +545,7 @@ class Field(object):
return Field(self.domain, -self.val, self.dtype)
def __abs__(self):
return Field(self.domain, np.abs(self.val), self.dtype)
return Field(self.domain, dobj.abs(self.val), self.dtype)
def _contraction_helper(self, op, spaces):
if spaces is None:
......@@ -597,6 +598,13 @@ class Field(object):
def std(self, spaces=None):
return self._contraction_helper('std', spaces)
def copy_content_from(self, other):
if not isinstance(other, Field):
raise TypeError("argument must be a Field")
if other.domain != self.domain:
raise ValueError("domains are incompatible.")
self.val[()] = other.val
# ---General binary methods---
def _binary_helper(self, other, op):
......
......@@ -92,9 +92,6 @@ else:
def general_axpy(a,x,y,out):
if x.domain != y.domain or x.domain != out.domain:
raise ValueError ("Incompatible domains")
x = x.val
y = y.val
out = out.val
if out is x:
if a != 1.:
......@@ -106,7 +103,7 @@ else:
else:
out += x
else:
out[()] = y
out.copy_content_from(y)
if a != 1.:
out += a*x
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment