Commit d68a5a82 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

starting with new __repr__ implementation

parent e6c74a63
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <>.
# Copyright(C) 2013-2018 Max-Planck-Society
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from .random import Random
import sys
ntask = 1
rank = 0
master = True
def is_numpy():
return False
def local_shape(shape, distaxis=0):
return shape
class data_object(object):
def __init__(self, shape, data):
self._data = data
def copy(self):
return data_object(self._data.shape, self._data.copy())
def dtype(self):
return self._data.dtype
def shape(self):
return self._data.shape
def size(self):
return self._data.size
def real(self):
return data_object(self._data.shape, self._data.real)
def imag(self):
return data_object(self._data.shape, self._data.imag)
def conj(self):
return data_object(self._data.shape, self._data.conj())
def conjugate(self):
return data_object(self._data.shape, self._data.conjugate())
def _contraction_helper(self, op, axis):
if axis is not None:
if len(axis) == len(self._data.shape):
axis = None
if axis is None:
return getattr(self._data, op)()
res = getattr(self._data, op)(axis=axis)
return data_object(res.shape, res)
def sum(self, axis=None):
return self._contraction_helper("sum", axis)
def prod(self, axis=None):
return self._contraction_helper("prod", axis)
def min(self, axis=None):
return self._contraction_helper("min", axis)
def max(self, axis=None):
return self._contraction_helper("max", axis)
def mean(self, axis=None):
if axis is None:
sz = self.size
sz = reduce(lambda x, y: x*y, [self.shape[i] for i in axis])
return self.sum(axis)/sz
def std(self, axis=None):
return np.sqrt(self.var(axis))
# FIXME: to be improved!
def var(self, axis=None):
if axis is not None and len(axis) != len(self.shape):
raise ValueError("functionality not yet supported")
return (abs(self-self.mean())**2).mean()
def _binary_helper(self, other, op):
a = self
if isinstance(other, data_object):
b = other
if a._data.shape != b._data.shape:
raise ValueError("shapes are incompatible.")
a = a._data
b = b._data
elif np.isscalar(other):
a = a._data
b = other
elif isinstance(other, np.ndarray):
a = a._data
b = other
return NotImplemented
tval = getattr(a, op)(b)
if tval is a:
return self
return data_object(self._data.shape, tval)
def __neg__(self):
return data_object(self._data.shape, -self._data)
def __abs__(self):
return data_object(self._data.shape, abs(self._data))
def all(self):
return self.sum() == self.size
def any(self):
return self.sum() != 0
def fill(self, value):
for op in ["__add__", "__radd__", "__iadd__",
"__sub__", "__rsub__", "__isub__",
"__mul__", "__rmul__", "__imul__",
"__div__", "__rdiv__", "__idiv__",
"__truediv__", "__rtruediv__", "__itruediv__",
"__floordiv__", "__rfloordiv__", "__ifloordiv__",
"__pow__", "__rpow__", "__ipow__",
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]:
def func(op):
def func2(self, other):
return self._binary_helper(other, op=op)
return func2
setattr(data_object, op, func(op))
def full(shape, fill_value, dtype=None):
return data_object(shape, np.full(shape, fill_value, dtype))
def empty(shape, dtype=None):
return data_object(shape, np.empty(shape, dtype))
def zeros(shape, dtype=None, distaxis=0):
return data_object(shape, np.zeros(shape, dtype))
def ones(shape, dtype=None, distaxis=0):
return data_object(shape, np.ones(shape, dtype))
def empty_like(a, dtype=None):
return data_object(np.empty_like(a._data, dtype))
def vdot(a, b):
return np.vdot(a._data, b._data)
def _math_helper(x, function, out):
function = getattr(np, function)
if out is not None:
function(x._data, out=out._data)
return out
return data_object(x.shape, function(x._data))
_current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
def func(f):
def func2(x, out=None):
return _math_helper(x, f, out)
return func2
setattr(_current_module, f, func(f))
def from_object(object, dtype, copy, set_locked):
if dtype is None:
dtype = object.dtype
dtypes_equal = dtype == object.dtype
if set_locked and dtypes_equal and locked(object):
return object
if not dtypes_equal and not copy:
raise ValueError("cannot change data type without copying")
if set_locked and not copy:
raise ValueError("cannot lock object without copying")
data = np.array(object._data, dtype=dtype, copy=copy)
if set_locked:
data.flags.writeable = False
return data_object(object._shape, data, distaxis=object._distaxis)
# This function draws all random numbers on all tasks, to produce the same
# array independent on the number of tasks
# MR FIXME: depending on what is really wanted/needed (i.e. same result
# independent of number of tasks, performance etc.) we need to adjust the
# algorithm.
def from_random(random_type, shape, dtype=np.float64, **kwargs):
generator_function = getattr(Random, random_type)
ldat = generator_function(dtype=dtype, shape=shape, **kwargs)
return from_local_data(shape, ldat)
def local_data(arr):
return arr._data
def ibegin_from_shape(glob_shape, distaxis=0):
return (0,) * len(glob_shape)
def ibegin(arr):
return (0,) * arr._data.ndim
def np_allreduce_sum(arr):
return arr.copy()
def np_allreduce_min(arr):
return arr.copy()
def distaxis(arr):
return -1
def from_local_data(shape, arr, distaxis=-1):
return data_object(shape, arr)
def from_global_data(arr, sum_up=False):
return data_object(arr.shape, arr)
def to_global_data(arr):
return arr._data
def redistribute(arr, dist=None, nodist=None):
return arr.copy()
def default_distaxis():
return -1
def lock(arr):
arr._data.flags.writeable = False
def locked(arr):
return not arr._data.flags.writeable
......@@ -20,6 +20,7 @@ from __future__ import absolute_import, division, print_function
from .compat import *
from .domains.domain import Domain
from . import utilities
class DomainTuple(object):
......@@ -153,3 +154,7 @@ class DomainTuple(object):
if DomainTuple._scalarDomain is None:
DomainTuple._scalarDomain = DomainTuple.make(())
return DomainTuple._scalarDomain
def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in self._dom)
return "DomainTuple:\n"+utilities.indent(subs)
......@@ -20,7 +20,7 @@ from __future__ import absolute_import, division, print_function
from .compat import *
from .domain_tuple import DomainTuple
from .utilities import frozendict
from .utilities import frozendict, indent
class MultiDomain(object):
......@@ -120,3 +120,8 @@ class MultiDomain(object):
res[key] = subdom
return MultiDomain.make(res)
def __repr__(self):
subs = "\n".join("{}:\n {}".format(key, dom.__repr__())
for key, dom in self.items())
return "MultiDomain:\n"+indent(subs)
......@@ -60,6 +60,12 @@ class EndomorphicOperator(LinearOperator):
raise NotImplementedError
def _dom(self, mode):
return self._domain
def _tgt(self, mode):
return self._domain
def _check_input(self, x, mode):
if self.domain != x.domain:
......@@ -22,8 +22,6 @@ import numpy as np
from ..compat import *
from ..domain_tuple import DomainTuple
from ..field import Field
from ..multi_field import MultiField
from ..sugar import full
from .endomorphic_operator import EndomorphicOperator
......@@ -82,13 +80,17 @@ class ScalingOperator(EndomorphicOperator):
fct = 1./fct
return ScalingOperator(fct, self._domain)
def draw_sample(self, from_inverse=False, dtype=np.float64):
def _get_fct(self, from_inverse):
fct = self._factor
if fct.imag != 0. or fct.real < 0.:
raise ValueError("operator not positive definite")
if fct.real == 0. and from_inverse:
raise ValueError("operator not positive definite")
fct = 1./np.sqrt(fct) if from_inverse else np.sqrt(fct)
cls = Field if isinstance(self._domain, DomainTuple) else MultiField
return cls.from_random(
random_type="normal", domain=self._domain, std=fct, dtype=dtype)
if (fct.imag != 0. or fct.real < 0. or
(fct.real == 0. and from_inverse)):
raise ValueError("operator not positive definite")
return 1./np.sqrt(fct) if from_inverse else np.sqrt(fct)
def process_sample(self, samp, from_inverse):
return samp*self._get_fct(from_inverse)
def draw_sample(self, from_inverse=False, dtype=np.float64):
from ..sugar import from_random
return from_random(random_type="normal", domain=self._domain,
std=self._get_fct(from_inverse), dtype=dtype)
......@@ -30,7 +30,7 @@ from .compat import *
__all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"memo", "NiftyMetaBase", "fft_prep", "hartley", "my_fftn_r2c",
"my_fftn", "my_sum", "my_lincomb_simple", "my_lincomb",
"my_fftn", "my_sum", "my_lincomb_simple", "my_lincomb", "indent",
"my_product", "frozendict", "special_add_at", "iscomplextype"]
......@@ -363,3 +363,7 @@ def special_add_at(a, axis, index, b):
_iscomplex_tpl = (np.complex64, np.complex128)
def iscomplextype(dtype):
return dtype.type in _iscomplex_tpl
def indent(inp):
return "\n".join(((" "+s).rstrip() for s in inp.splitlines()))
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment