d2o_iter.py 1.86 KB
Newer Older
theos's avatar
theos committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# -*- coding: utf-8 -*-

import numpy as np


class d2o_iter(object):
    """Base class for element-wise iterators over a d2o object.

    The iterator keeps a running index ``i`` and hands out
    ``current_local_data[i]`` until ``i`` reaches ``n``, the product of
    the entries of ``d2o.shape``.  How ``current_local_data`` is obtained
    and refreshed is left to subclasses, which must implement
    `initialize_current_local_data` and `update_current_local_data`.

    Both the Python 2 (``next``) and the Python 3 (``__next__``) iterator
    protocols are supported.
    """

    def __init__(self, d2o):
        """Store `d2o`, reset the counter and set up the initial data.

        Parameters
        ----------
        d2o : object
            The object to iterate over.  It must expose a ``shape``
            attribute; anything else that is needed depends on the
            subclass hooks.
        """
        self.d2o = d2o
        self.i = 0
        # Total number of elements that will be yielded.
        self.n = np.prod(self.d2o.shape)
        self.initialize_current_local_data()

    def __iter__(self):
        """Return the iterator itself."""
        return self

    def next(self):
        """Return the next element and advance the counter.

        Raises
        ------
        StopIteration
            If the d2o has no elements or all ``n`` elements have already
            been returned.
        """
        # An empty object is exhausted right away; the update hook is not
        # invoked at all in that case.
        if self.n == 0:
            raise StopIteration()

        # Give the subclass the chance to refresh `current_local_data`
        # for the position that is about to be read.
        self.update_current_local_data()
        if self.i < self.n:
            i = self.i
            self.i += 1
            return self.current_local_data[i]
        else:
            raise StopIteration()

    # Python 3 looks up `__next__` instead of `next`; without this alias
    # the object is rejected by `for` loops and the `next()` builtin.
    __next__ = next

    def initialize_current_local_data(self):
        """Hook: prepare ``current_local_data`` during construction.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError

    def update_current_local_data(self):
        """Hook: refresh ``current_local_data`` before each element access.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError


class d2o_not_iter(d2o_iter):
    """Iterator that serves all elements from one flattened ``d2o.data``.

    The flattened data is fetched once while the iterator is constructed
    and is not refreshed afterwards.
    """

    def initialize_current_local_data(self):
        """Store ``self.d2o.data.flatten()`` as ``current_local_data``."""
        flattened = self.d2o.data.flatten()
        self.current_local_data = flattened

    def update_current_local_data(self):
        """Do nothing; the data stored at initialization stays in use."""
        return None


class d2o_slicing_iter(d2o_iter):
    """Iterator that pulls the data node by node via ``comm.bcast``.

    For the current counter value the responsible node is looked up in
    ``local_dim_offset_list``.  Whenever that node changes, its flattened
    local data is broadcast and stored as ``current_local_data``.

    NOTE(review): the inherited ``next`` indexes ``current_local_data``
    with the global counter ``i``, while the broadcast array only holds
    the active node's local data -- confirm that this is intended when
    more than one node holds data.
    """

    def __init__(self, d2o):
        """Set up the counters and the offset table, then load the data.

        Parameters
        ----------
        d2o : object
            Object to iterate over.  It must provide ``shape``,
            ``distributor.all_local_slices``, ``comm.bcast`` and
            ``get_local_data``.
        """
        self.d2o = d2o
        self.n = np.prod(d2o.shape)
        self.i = 0
        # No node has been selected yet, so the first update always
        # triggers a broadcast.
        self.active_node = None
        # Column 4 of `all_local_slices` is presumably the per-node offset
        # along the distributed axis -- TODO confirm with the distributor.
        self.local_dim_offset_list = d2o.distributor.all_local_slices[:, 4]

        self.initialize_current_local_data()

    def initialize_current_local_data(self):
        """Load the data for the starting position via the update hook."""
        self.update_current_local_data()

    def update_current_local_data(self):
        """Broadcast the local data of the node responsible for ``self.i``.

        Nothing is communicated as long as the responsible node is the
        one whose data is already stored.
        """
        # Index of the last offset that is <= self.i.  The result is at
        # most len(local_dim_offset_list) - 1, so it needs no clamping.
        insertion_point = np.searchsorted(self.local_dim_offset_list,
                                          self.i,
                                          side='right')
        owner = insertion_point - 1

        if self.active_node == owner:
            return

        self.active_node = owner
        flat_local = self.d2o.get_local_data().flatten()
        self.current_local_data = self.d2o.comm.bcast(flat_local,
                                                      root=self.active_node)