# D2O
# Copyright (C) 2016 Theo Steininger
#
# Author: Theo Steininger
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from builtins import object
import numpy as np


class d2o_iter(object):
    """Iterate element-wise over the flattened contents of a d2o.

    Subclasses decide where the elements come from by implementing
    ``initialize_current_local_data`` (called once from ``__init__``) and
    ``update_current_local_data`` (called before every element is read).
    Both must leave a flat, indexable buffer in ``self.current_local_data``.
    ``_local_index`` maps the running global position ``self.i`` onto an
    index into that buffer; by default the two coincide.
    """

    def __init__(self, d2o):
        self.d2o = d2o
        self.i = 0
        # int() because np.prod(()) is the float 1.0 for a 0-d shape.
        self.n = int(np.prod(self.d2o.shape))
        self.initialize_current_local_data()

    def __iter__(self):
        return self

    def __next__(self):
        # Check exhaustion first so that no (possibly collective) buffer
        # update runs once all n elements have been produced; this also
        # covers the n == 0 case.
        if self.i >= self.n:
            raise StopIteration()

        self.update_current_local_data()
        item = self.current_local_data[self._local_index(self.i)]
        self.i += 1
        return item

    def _local_index(self, i):
        """Return the index into ``current_local_data`` for global index i."""
        return i

    def initialize_current_local_data(self):
        raise NotImplementedError

    def update_current_local_data(self):
        raise NotImplementedError


class d2o_not_iter(d2o_iter):
    """Iterator that reads every element from the local ``d2o.data`` buffer.

    The whole array is flattened once, so global and buffer indices agree.
    (Presumably used for d2o objects that are not distributed -- confirm.)
    """

    def initialize_current_local_data(self):
        self.current_local_data = self.d2o.data.flatten()

    def update_current_local_data(self):
        # The full flattened array is already in place; nothing to refresh.
        pass


class d2o_slicing_iter(d2o_iter):
    """Iterator for a d2o whose flattened data is split into contiguous
    per-process chunks.

    The chunk containing the current element is fetched via
    ``d2o.comm.bcast`` from its owning process and cached until iteration
    crosses into the next process's chunk.

    NOTE(review): assumes ``comm.bcast`` is an mpi4py-style collective, in
    which case every process must advance its iterator in lockstep.
    """

    def __init__(self, d2o):
        # Column 4 of all_local_slices is used as each process's start
        # offset into the flattened array; it must be sorted ascending for
        # the searchsorted lookup below.  (Inferred from usage -- confirm
        # against the distributor.)
        self.local_dim_offset_list = d2o.distributor.all_local_slices[:, 4]
        self.active_node = None
        # Sets d2o, i and n, then calls initialize_current_local_data().
        super(d2o_slicing_iter, self).__init__(d2o)

    def initialize_current_local_data(self):
        self.update_current_local_data()

    def update_current_local_data(self):
        # Owner of element i is the last process whose start offset is
        # <= i; side='right' skips processes with empty chunks that share
        # their offset with the following process.
        new_active_node = np.searchsorted(self.local_dim_offset_list,
                                          self.i,
                                          side='right') - 1
        if self.active_node != new_active_node:
            self.active_node = new_active_node
            self.current_local_data = self.d2o.comm.bcast(
                self.d2o.get_local_data().flatten(),
                root=self.active_node)

    def _local_index(self, i):
        # current_local_data holds only the active process's chunk, so the
        # global index must be shifted by that chunk's start offset.
        return i - self.local_dim_offset_list[self.active_node]