outer_product_operator.py 1.7 KB
 # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

from ..domain_tuple import DomainTuple
from ..field import Field
from .linear_operator import LinearOperator


class OuterProduct(LinearOperator):
    """Performs the point-wise outer product of two fields.

    The operator holds a fixed `field` and multiplies it, as an outer
    product, with the field it is applied to.

    Parameters
    ----------
    domain : DomainTuple, or anything accepted by `DomainTuple.make`
        The domain of the input field.
    field : Field
        The fixed left factor of the outer product.

    Notes
    -----
    The target domain consists of the sub-domains of `field.domain`
    followed by the sub-domains of `domain`.
    """

    def __init__(self, domain, field):
        self._domain = DomainTuple.make(domain)
        self._field = field
        # Build the target from the normalised `self._domain`, not from the
        # raw `domain` argument: the latter only has `_dom` when the caller
        # already passed a DomainTuple.
        self._target = DomainTuple.make(
            tuple(field.domain._dom + self._domain._dom))
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        """Apply the operator to `x` in the requested `mode`.

        Parameters
        ----------
        x : Field
            The input field; validated by `self._check_input`.
        mode : int
            `self.TIMES` or `self.ADJOINT_TIMES`.

        Returns
        -------
        Field
            For `TIMES`, a field on the target domain holding the outer
            product of the stored field's values with `x.val`.  Otherwise a
            field on the operator's domain, obtained by contracting the
            stored field's values with the leading axes of `x.val`.
        """
        self._check_input(x, mode)
        if mode == self.TIMES:
            outer = np.multiply.outer(self._field.val, x.val)
            return Field(self._target, outer)

        # Adjoint direction: sum over as many leading axes of `x.val` as the
        # stored field has dimensions.
        # NOTE(review): no complex conjugate is taken here; if complex-valued
        # stored fields are meant to be supported, confirm whether
        # `self._field.val.conj()` is required for a true adjoint.
        n_contract = len(self._field.shape)
        contracted = np.tensordot(self._field.val, x.val, n_contract)
        return Field(self._domain, contracted)