diff --git a/d2o/distributed_data_object.py b/d2o/distributed_data_object.py
index 4d28701d244dc3e41158827000860c5bac4f366d..06e6ce844042604017f2fb733dc763665fb7f6be 100644
--- a/d2o/distributed_data_object.py
+++ b/d2o/distributed_data_object.py
@@ -17,7 +17,7 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 
 import numpy as np
-
+from keepers import Versionable
 from d2o.config import configuration as gc,\
     dependency_injector as gdi
 
@@ -35,7 +35,7 @@ about_warnings_cprint = lambda z: stdout.write(z + "\n"); stdout.flush()
 about_infos_cprint = lambda z: stdout.write(z + "\n"); stdout.flush()
 
 
-class distributed_data_object(object):
+class distributed_data_object(Versionable, object):
     """A multidimensional array with modular MPI-based distribution schemes.
 
     The purpose of a distributed_data_object (d2o) is to provide the user
@@ -1880,6 +1880,53 @@ class distributed_data_object(object):
 
         self.data = self.distributor.load_data(alias, path)
 
+    def _to_hdf5(self, hdf5_group):
+        """Store this object in the given hdf5 group.
+
+        The distribution strategy is saved as the group attribute
+        'distribution_strategy' and the array as the dataset 'data'.
+        NOTE(review): presumably the hook keepers' Versionable calls
+        when saving -- confirm against the keepers documentation.
+
+        Raises
+        ------
+        ValueError
+            If the distribution strategy is not of global type.
+        AttributeError
+            If the dtype is complex256.
+        """
+        if self.distribution_strategy not in STRATEGIES['global']:
+            raise ValueError(
+                "Only global-type distributed_data_objects can be versioned.")
+
+        # Match on the dtype name: an `is` test depends on dtype object
+        # identity, and np.complex256 is not defined on every platform.
+        if np.dtype(self.dtype).name == 'complex256':
+            raise AttributeError(
+                "Datatype complex256 is not supported by hdf5.")
+
+        hdf5_group.attrs['distribution_strategy'] = self.distribution_strategy
+        hdf5_dataset = hdf5_group.create_dataset('data',
+                                                 shape=self.shape,
+                                                 dtype=self.dtype)
+        self.distributor._data_to_hdf5(hdf5_dataset, self.data)
+
+    @classmethod
+    def _from_hdf5(cls, hdf5_group, repository):
+        """Rebuild a distributed_data_object from an hdf5 group.
+
+        Reads the attribute 'distribution_strategy' and the dataset
+        'data' from `hdf5_group`; `repository` is not used here.
+        """
+        distribution_strategy = hdf5_group.attrs['distribution_strategy']
+        dataset = hdf5_group['data']
+        # NOTE: always builds the base class, never `cls`; e.g. EmptyD2o
+        # below has an __init__ that takes no arguments.
+        result_d2o = distributed_data_object(
+            dataset,
+            distribution_strategy=distribution_strategy)
+        return result_d2o
+
 
 class EmptyD2o(distributed_data_object):
     def __init__(self):
diff --git a/d2o/distributor_factory.py b/d2o/distributor_factory.py
index 9d8b3b81bc206e246b8b056ad2db28a9f90bdafc..a413add232b86f1f71f6691e442483314497f3f0
100644
--- a/d2o/distributor_factory.py
+++ b/d2o/distributor_factory.py
@@ -1980,6 +1980,11 @@ class _slicing_distributor(distributor):
             # close the file
             f.close()
             return data
+
+        def _data_to_hdf5(self, hdf5_dataset, data):
+            """Write the local `data` into its slice of `hdf5_dataset`."""
+            hdf5_dataset[self.local_start:self.local_end] = data
+
     else:
         def save_data(self, *args, **kwargs):
             raise ImportError(about_cstring(
@@ -1989,6 +1994,11 @@ class _slicing_distributor(distributor):
             raise ImportError(about_cstring(
                 "ERROR: h5py is not available"))
 
+        def _data_to_hdf5(self, *args, **kwargs):
+            """Fallback without h5py: always raises ImportError."""
+            raise ImportError(about_cstring(
+                "ERROR: h5py is not available"))
+
     def get_iter(self, d2o):
         return d2o_slicing_iter(d2o)
 
@@ -2311,7 +2321,9 @@ class _not_distributor(distributor):
                                     shape=self.global_shape,
                                     dtype=self.dtype)
             # write the data
-            dset[:] = data
+            # a single writer (rank 0) fills the whole dataset
+            if self.comm.rank == 0:
+                dset[:] = data
             # close the file
             f.close()
 
@@ -2340,6 +2352,12 @@ class _not_distributor(distributor):
             # close the file
             f.close()
             return data
+
+        def _data_to_hdf5(self, hdf5_dataset, data):
+            """Write all of `data` into `hdf5_dataset`, on rank 0 only."""
+            if self.comm.rank == 0:
+                hdf5_dataset[:] = data
+
     else:
         def save_data(self, *args, **kwargs):
             raise ImportError(about_cstring(
@@ -2349,6 +2367,11 @@ class _not_distributor(distributor):
             raise ImportError(about_cstring(
                 "ERROR: h5py is not available"))
 
+        def _data_to_hdf5(self, *args, **kwargs):
+            """Fallback without h5py: always raises ImportError."""
+            raise ImportError(about_cstring(
+                "ERROR: h5py is not available"))
+
     def get_iter(self, d2o):
         return d2o_not_iter(d2o)
 