Commit 48d5c64f authored by Ultima's avatar Ultima
Browse files

Implemented fast_copy routines for fields and d2os.

parent 8afcedf6
......@@ -2056,6 +2056,20 @@ class field(object):
new_field.set_val(new_val=copied_val)
return new_field
def _fast_copy_empty(self):
# make an empty field
new_field = EmptyField()
# repair its class
new_field.__class__ = self.__class__
# copy domain, codomain, ishape and val
for key, value in self.__dict__.items():
if key != 'val':
new_field.__dict__[key] = value
else:
new_field.__dict__[key] = \
self.domain.unary_operation(self.val, op='copy_empty')
return new_field
def copy_empty(self, domain=None, codomain=None, ishape=None, **kwargs):
if domain is None:
domain = self.domain
......@@ -2063,6 +2077,13 @@ class field(object):
codomain = self.codomain
if ishape is None:
ishape = self.ishape
if (domain is self.domain and
codomain is self.codomain and
ishape == self.ishape and
kwargs == {}):
new_field = self._fast_copy_empty()
else:
new_field = field(domain=domain, codomain=codomain, ishape=ishape,
**kwargs)
return new_field
......@@ -3054,3 +3075,8 @@ class field(object):
def __gt__(self, other):
return self._binary_helper(other, op='gt')
class EmptyField(field):
def __init__(self):
pass
\ No newline at end of file
......@@ -148,6 +148,19 @@ class distributed_data_object(object):
hermitian=self.hermitian)
return new_d2o
def _fast_copy_empty(self):
# make an empty d2o
new_copy = EmptyD2o()
# repair its class
new_copy.__class__ = self.__class__
# now copy everthing in the __dict__ except for the data array
for key, value in self.__dict__.items():
if key != 'data':
new_copy.__dict__[key] = value
else:
new_copy.__dict__[key] = np.empty_like(value)
return new_copy
def copy(self, dtype=None, distribution_strategy=None, **kwargs):
temp_d2o = self.copy_empty(dtype=dtype,
distribution_strategy=distribution_strategy,
......@@ -178,9 +191,19 @@ class distributed_data_object(object):
local_shape = self.local_shape
if dtype is None:
dtype = self.dtype
else:
dtype = np.dtype(dtype)
if distribution_strategy is None:
distribution_strategy = self.distribution_strategy
# check if all parameters remain the same -> use the _fast_copy_empty
if (global_shape == self.shape and
local_shape == self.local_shape and
dtype == self.dtype and
distribution_strategy == self.distribution_strategy and
kwargs == self.init_kwargs):
return self._fast_copy_empty()
kwargs.update(self.init_kwargs)
temp_d2o = distributed_data_object(
......@@ -325,11 +348,6 @@ class distributed_data_object(object):
return temp_d2o
def _builtin_helper(self, operator, other, inplace=False):
if isinstance(other, distributed_data_object):
other_is_real = other.isreal()
else:
other_is_real = np.isreal(other)
# Case 1: other is not a scalar
if not (np.isscalar(other) or np.shape(other) == (1,)):
try:
......@@ -340,14 +358,29 @@ class distributed_data_object(object):
temp_data = self.distributor.extract_local_data(other)
temp_data = operator(temp_data)
# Case 2: other is a real scalar -> preserve hermitianity
elif other_is_real or (self.dtype not in (np.dtype('complex128'),
np.dtype('complex256'))):
# Case 2: other is a scalar
else:
if isinstance(other, distributed_data_object):
other_is_real = other.isreal()
else:
other_is_real = np.isreal(other)
if other_is_real:
hermitian_Q = self.hermitian
temp_data = operator(other)
# Case 3: other is complex
else:
hermitian_Q = False
# #Case 2.1 self is real
# if (self.dtype not in (np.dtype('complex128'),
# np.dtype('complex256'))):
# hermitian_Q = self.hermitian
# elif
#
# # Case 3: other is complex
# else:
# hermitian_Q = False
# temp_data = operator(other)
#
temp_data = operator(other)
# write the new data into a new distributed_data_object
if inplace is True:
......@@ -356,9 +389,8 @@ class distributed_data_object(object):
# use common datatype for self and other
new_dtype = np.dtype(np.find_common_type((self.dtype,),
(temp_data.dtype,)))
temp_d2o = self.copy_empty(
dtype=new_dtype)
temp_d2o.set_local_data(data=temp_data)
temp_d2o = self.copy_empty(dtype=new_dtype)
temp_d2o.set_local_data(data=temp_data, copy=False)
temp_d2o.hermitian = hermitian_Q
return temp_d2o
......@@ -2871,7 +2903,9 @@ class d2o_slicing_iter(d2o_iter):
self.d2o.data.flatten(),
root=self.active_node)
class EmptyD2o(distributed_data_object):
def __init__(self):
pass
......
......@@ -76,8 +76,12 @@ class los_response(operator):
number_of_dimensions = len(starts)
# if zero_point is None:
# zero_point = [0.] * number_of_dimensions
if zero_point is None:
zero_point = [0.] * number_of_dimensions
phys_middle = (np.array(domain.get_vol(split=True)) *
domain.get_shape()) / 2.
zero_point = phys_middle * domain.paradict['zerocenter']
if np.shape(zero_point) != (number_of_dimensions,):
raise ValueError(about._errors.cstring(
......
......@@ -224,6 +224,7 @@ class conjugate_gradient(object):
"""
self.x = self.b.copy_empty()
self.x.set_val(new_val = 0)
self.x.set_val(new_val = x0)
if self.W is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment