outer_product_operator.py
Last changed by Martin Reinecke (commits of Sep 18, 2018 and Jan 07, 2019).

```python
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see .
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

from ..domain_tuple import DomainTuple
from ..field import Field
from .linear_operator import LinearOperator


class OuterProduct(LinearOperator):
    """Outer product of a fixed field with the input field.

    Maps an input ``x`` defined on ``domain`` to the outer product of
    ``field`` and ``x``, which is defined on ``field.domain`` followed by
    ``domain``.

    Parameters
    ----------
    field : Field
        The fixed factor; its domain forms the leading part of the target.
    domain : DomainTuple
        The domain of the input field.
    """
    def __init__(self, field, domain):
        self._domain = domain
        self._field = field
        # Target domain: the field's subdomains followed by the input's.
        self._target = DomainTuple.make(
            tuple(sub_d for sub_d in field.domain._dom + domain._dom))
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        self._check_input(x, mode)
        if mode == self.TIMES:
            # out[i..., j...] = field[i...] * x[j...]
            return Field.from_global_data(
                self._target, np.multiply.outer(
                    self._field.to_global_data(), x.to_global_data()))
        # Adjoint: contract the leading field axes of x with conj(field);
        # the conjugate keeps the operator adjoint for complex fields.
        axes = len(self._field.shape)
        return Field.from_global_data(
            self._domain, np.tensordot(
                self._field.to_global_data().conj(), x.to_global_data(), axes))
```
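Both branches of `apply` reduce to plain NumPy calls on the underlying arrays. The sketch below reproduces them outside NIFTy, assuming the field values are ordinary NumPy arrays such as those returned by `to_global_data()`. The helper names `outer_times`, `outer_adjoint_times` and `adjointness_error` are hypothetical and not part of NIFTy. It checks the adjoint with the usual dot-product test, <y, A x> = <A^H y, x>, which only holds for complex fields if the adjoint contracts with the complex conjugate of the field.

```python
import numpy as np


def outer_times(f, x):
    """Forward application (hypothetical helper): out[i..., j...] = f[i...] * x[j...]."""
    return np.multiply.outer(f, x)


def outer_adjoint_times(f, y):
    """Adjoint application (hypothetical helper).

    Sums over the leading f.ndim axes of y, weighted by conj(f), so the
    result has the shape of the original input x.
    """
    return np.tensordot(np.conj(f), y, axes=f.ndim)


def adjointness_error(f, x_shape, seed=0):
    """Relative mismatch between <y, A x> and <A^H y, x> for random complex x, y."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(x_shape) + 1j * rng.standard_normal(x_shape)
    y_shape = f.shape + tuple(x_shape)
    y = rng.standard_normal(y_shape) + 1j * rng.standard_normal(y_shape)
    # np.vdot flattens both arguments and conjugates the first one.
    lhs = np.vdot(y, outer_times(f, x))          # <y, A x>
    rhs = np.vdot(outer_adjoint_times(f, y), x)  # <A^H y, x>
    return abs(lhs - rhs) / abs(lhs)


if __name__ == "__main__":
    f = np.arange(6.0).reshape(2, 3) * (1 + 0.5j)  # complex "field" values
    print(outer_times(f, np.ones(4)).shape)        # (2, 3, 4)
    print(adjointness_error(f, (4,)) < 1e-12)      # True
```

Dropping `np.conj` in `outer_adjoint_times` leaves real-valued fields unaffected but makes the dot-product test fail as soon as `f` has a nonzero imaginary part.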