Commit 79df4e2b authored by Martin Reinecke's avatar Martin Reinecke
Browse files

performance

parent b719e4ae
......@@ -26,6 +26,15 @@ from mpi4py import MPI
from ..compat import *
from .random import Random
__all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"empty", "zeros", "ones", "empty_like", "vdot", "exp",
"log", "tanh", "sqrt", "from_object", "from_random",
"local_data", "ibegin", "ibegin_from_shape", "np_allreduce_sum",
"np_allreduce_min", "np_allreduce_max",
"distaxis", "from_local_data", "from_global_data", "to_global_data",
"redistribute", "default_distaxis", "is_numpy",
"lock", "locked", "uniform_full", "transpose"]
_comm = MPI.COMM_WORLD
ntask = _comm.Get_size()
rank = _comm.Get_rank()
......@@ -243,6 +252,12 @@ def full(shape, fill_value, dtype=None, distaxis=0):
fill_value, dtype), distaxis)
def uniform_full(shape, fill_value, dtype=None, distaxis=0):
return data_object(
shape, np.broadcast_to(fill_value, local_shape(shape, distaxis)),
distaxis)
def empty(shape, dtype=None, distaxis=0):
return data_object(shape, np.empty(local_shape(shape, distaxis),
dtype), distaxis)
......
......@@ -25,6 +25,15 @@ from numpy import ones, sqrt, tanh, vdot, zeros
from .random import Random
__all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"empty", "zeros", "ones", "empty_like", "vdot", "exp",
"log", "tanh", "sqrt", "from_object", "from_random",
"local_data", "ibegin", "ibegin_from_shape", "np_allreduce_sum",
"np_allreduce_min", "np_allreduce_max",
"distaxis", "from_local_data", "from_global_data", "to_global_data",
"redistribute", "default_distaxis", "is_numpy",
"lock", "locked", "uniform_full"]
ntask = 1
rank = 0
master = True
......@@ -115,3 +124,7 @@ def lock(arr):
def locked(arr):
return not arr.flags.writeable
def uniform_full(shape, fill_value, dtype=None, distaxis=-1):
return np.broadcast_to(fill_value, shape)
......@@ -28,12 +28,3 @@ try:
from .data_objects.distributed_do import *
except ImportError:
from .data_objects.numpy_do import *
__all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"empty", "zeros", "ones", "empty_like", "vdot", "exp",
"log", "tanh", "sqrt", "from_object", "from_random",
"local_data", "ibegin", "ibegin_from_shape", "np_allreduce_sum",
"np_allreduce_min", "np_allreduce_max",
"distaxis", "from_local_data", "from_global_data", "to_global_data",
"redistribute", "default_distaxis", "is_numpy",
"lock", "locked"]
......@@ -47,11 +47,13 @@ class Field(object):
"""
def __init__(self, domain, val):
self._uni = None
if not isinstance(domain, DomainTuple):
raise TypeError("domain must be of type DomainTuple")
if not isinstance(val, dobj.data_object):
if np.isscalar(val):
val = dobj.from_local_data((), np.full((), val))
self._uni = val
val = dobj.uniform_full(domain.shape, val)
else:
raise TypeError("val must be of type dobj.data_object")
if domain.shape != val.shape:
......@@ -88,7 +90,7 @@ class Field(object):
if not (np.isreal(val) or np.iscomplex(val)):
raise TypeError("need arithmetic scalar")
domain = DomainTuple.make(domain)
return Field(domain, dobj.full(domain.shape, fill_value=val))
return Field(domain, val)
@staticmethod
def from_global_data(domain, arr, sum_up=False):
......@@ -391,10 +393,14 @@ class Field(object):
return self
def __neg__(self):
return Field(self._domain, -self._val)
if self._uni is None:
return Field(self._domain, -self._val)
return Field(self._domain, -self._uni)
def __abs__(self):
return Field(self._domain, abs(self._val))
if self._uni is None:
return Field(self._domain, abs(self._val))
return Field(self._domain, abs(self._uni))
def _contraction_helper(self, op, spaces):
if spaces is None:
......@@ -621,7 +627,9 @@ class Field(object):
return self + other
def positive_tanh(self):
return 0.5*(1.+self.tanh())
if self._uni is None:
return 0.5*(1.+self.tanh())
return Field(self._domain, 0.5*(1.+np.tanh(self._uni)))
for op in ["__add__", "__radd__",
......@@ -662,7 +670,11 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
for f in ["sqrt", "exp", "log", "tanh"]:
def func(f):
def func2(self):
fu = getattr(dobj, f)
return Field(domain=self._domain, val=fu(self.val))
if self._uni is None:
fu = getattr(dobj, f)
return Field(domain=self._domain, val=fu(self.val))
else:
fu = getattr(np, f)
return Field(domain=self._domain, val=fu(self._uni))
return func2
setattr(Field, f, func(f))
......@@ -22,6 +22,7 @@ import numpy as np
from ..compat import *
from .linear_operator import LinearOperator
from .null_operator import NullOperator
class ChainOperator(LinearOperator):
......@@ -52,14 +53,16 @@ class ChainOperator(LinearOperator):
else:
opsnew.append(op)
ops = opsnew
# Step 2.5: check for NullOperators
if any(isinstance(op, NullOperator) for op in ops):
ops = (NullOperator(ops[-1].domain, ops[0].target),)
# Step 3: collect ScalingOperators
fct = 1.
opsnew = []
lastdom = ops[-1].domain
for op in ops:
if (isinstance(op, ScalingOperator) and
not np.issubdtype(type(op._factor), np.complexfloating)):
fct *= op._factor
if (isinstance(op, ScalingOperator) and op._factor.imag == 0):
fct *= op._factor.real
else:
opsnew.append(op)
if fct != 1.:
......
......@@ -4,12 +4,12 @@ from ..compat import *
from .linear_operator import LinearOperator
from ..multi.multi_domain import MultiDomain
from ..multi.multi_field import MultiField
from ..field import Field
class FieldAdapter(LinearOperator):
def __init__(self, dom, name_dom):
self._domain = MultiDomain.make(dom)
self._smalldom = MultiDomain.make({name_dom: self._domain[name_dom]})
self._name = name_dom
self._target = dom[name_dom]
......@@ -23,12 +23,13 @@ class FieldAdapter(LinearOperator):
@property
def capability(self):
return self._all_ops
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return x[self._name]
tmp = MultiField(self._smalldom, (x,))
return tmp.unite(MultiField.full(self._domain, 0.))
values = tuple(Field.full(dom, 0.) if key != self._name else x
for key, dom in self._domain.items())
return MultiField(self._domain, values)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment