projection_operator.py 3.51 KB
Newer Older
Theo Steininger's avatar
Theo Steininger committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# -*- coding: utf-8 -*-
import numpy as np

from nifty.field import Field

from nifty.operators.endomorphic_operator import EndomorphicOperator


class ProjectionOperator(EndomorphicOperator):
    """Operator mapping a field `x` onto ``p * sum(p * x)``.

    `p` is the `projection_field` given at construction, and the operator's
    domain is the domain of `p`. The sum runs over the axes the operator
    acts on and does not complex-conjugate `p`. The operator is singular,
    so `_inverse_times` raises.

    Parameters
    ----------
    projection_field : Field
        The field `p` defining the projection.

    Raises
    ------
    TypeError
        If `projection_field` is not a `Field` instance.
    """

    # ---Overwritten properties and methods---

    def __init__(self, projection_field):
        if not isinstance(projection_field, Field):
            raise TypeError("The projection_field must be a NIFTy-Field "
                            "instance.")
        self._projection_field = projection_field
        # Lazily computed by the `unitary` property.
        self._unitary = None

    def _times(self, x, spaces):
        """Apply the projection ``x -> p * sum(p * x)``.

        Parameters
        ----------
        x : Field
            Input field. Either its domain equals the operator's domain, or
            the axes of `x` selected by `spaces` carry the projection
            field's axes.
        spaces : sequence of int or None
            Indices of the spaces of `x` the projection acts on. ``None``
            selects all axes of `x`.

        Returns
        -------
        Field
            The projected field. In the general path it is created with
            ``x.copy_empty``, i.e. it is laid out like `x`.

        Raises
        ------
        ValueError
            If the number of axes selected in `x` differs from the number
            of axes of the projection field.
        """
        # Fast path: identical domains, so plain field arithmetic suffices.
        if x.domain == self.domain:
            dotted = (self._projection_field * x).sum()
            return self._projection_field * dotted

        # Collect the axes of x on which the projection acts.
        if spaces is None:
            active_axes = list(range(len(x.shape)))
        else:
            active_axes = []
            for space_index in spaces:
                active_axes += x.domain_axes[space_index]

        # If the distribution strategy of x along the active axes matches
        # the projection field's, the projection field's local data can be
        # used as is; otherwise redistribute a copy of it.
        axes_local_distribution_strategy = \
            x.val.get_axes_local_distribution_strategy(active_axes)
        if axes_local_distribution_strategy == \
           self._projection_field.distribution_strategy:
            local_projection_vector = \
                self._projection_field.val.get_local_data(copy=False)
        else:
            # create an array that is sub-slice compatible
            self.logger.warn("The input field is not sub-slice compatible to "
                             "the distribution strategy of the operator. "
                             "Performing an probably expensive "
                             "redistribution.")
            redistr_projection_val = self._projection_field.val.copy(
                distribution_strategy=axes_local_distribution_strategy)
            local_projection_vector = \
                redistr_projection_val.get_local_data(copy=False)

        local_x = x.val.get_local_data(copy=False)

        n_proj_axes = len(local_projection_vector.shape)
        # Every axis of the projection vector must be matched to exactly one
        # active axis of x; otherwise the contraction below is ill-defined
        # (too few -> IndexError, too many -> silently uncontracted axes).
        if len(active_axes) != n_proj_axes:
            raise ValueError(
                "The number of axes selected in the input field (%d) does "
                "not match the number of axes of the projection field (%d)."
                % (len(active_axes), n_proj_axes))

        # einsum index labels: the projection vector uses labels
        # 0..n_proj_axes-1; x uses n_proj_axes + axis for each axis, except
        # on its active axes, which are relabelled to match the projection
        # vector so that they get contracted.
        projector_labels = list(range(n_proj_axes))
        x_labels = np.arange(len(local_x.shape)) + n_proj_axes
        for label, axis in enumerate(active_axes):
            x_labels[axis] = label

        # NOTE(review): this contraction only sums over the locally stored
        # data. If the active axes themselves were split across processes,
        # `dotted` would be a partial sum requiring a global reduction --
        # confirm get_axes_local_distribution_strategy rules this out.
        # In implicit mode einsum sums over the shared labels and keeps the
        # remaining labels of x in ascending (= original axis) order.
        dotted = np.einsum(local_projection_vector, projector_labels,
                           local_x, x_labels)

        # Labels of x that were not contracted, in the order used by `dotted`.
        dotted_labels = x_labels[x_labels >= n_proj_axes]

        # Outer product with the projection vector, laid out like x.
        remultiplied = np.einsum(local_projection_vector, projector_labels,
                                 dotted, dotted_labels,
                                 x_labels)
        result_field = x.copy_empty(dtype=remultiplied.dtype)
        result_field.val.set_local_data(remultiplied, copy=False)
        return result_field

    def _inverse_times(self, x, spaces):
        """Always raise: the projection operator has no inverse."""
        raise NotImplementedError("The ProjectionOperator is a singular "
                                  "operator and therefore has no inverse.")

    # ---Mandatory properties and methods---

    @property
    def domain(self):
        """The domain of the projection field."""
        return self._projection_field.domain

    @property
    def implemented(self):
        """Always ``True``."""
        return True

    @property
    def unitary(self):
        """Whether every entry of the projection field equals 1.

        Computed on first access and cached afterwards.
        """
        if self._unitary is None:
            self._unitary = (self._projection_field.val == 1).all()
        return self._unitary

    @property
    def symmetric(self):
        """Always ``True``."""
        return True