Commit 6b54a5ce authored by Martin Reinecke's avatar Martin Reinecke
Browse files


parent 9a1c07ed
......@@ -54,6 +54,7 @@ from .operators.selection_operator import SelectionOperator
from .operators.slope_operator import SlopeOperator
from .operators.smoothness_operator import SmoothnessOperator
from .operators.symmetrizing_operator import SymmetrizingOperator
from .operators.vdot_operator import VdotOperator
from .probing.utils import probe_with_posterior_samples, probe_diagonal, \
......@@ -58,7 +58,7 @@ class Linearization(object):
d1 = makeOp(self._val)
d2 = makeOp(other._val)
return Linearization(self._val*other._val,
self._jac*d2 + d1*other._jac)
d2*self._jac + d1*other._jac)
if isinstance(other, (int, float, complex)):
# if other == 0:
# return ...
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <>.
# Copyright(C) 2013-2018 Max-Planck-Society
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
import numpy as np
from ..compat import *
from ..domain_tuple import DomainTuple
from import UnstructuredDomain
from .linear_operator import LinearOperator
from ..sugar import full
class VdotOperator(LinearOperator):
def __init__(self, field):
super(VdotOperator, self).__init__()
self._field = field
self._target = DomainTuple.make(UnstructuredDomain(1))
def domain(self):
return self._field.domain
def target(self):
return self._target
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return full(self._target, self._field.vdot(x))
return self._field*x.to_global_data()[()]
......@@ -49,6 +49,13 @@ class Consistency_Tests(unittest.TestCase):
op = a+b
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces + _p_spaces + _pow_spaces,
[np.float64, np.complex128]))
def testVdotOperator(self, sp, dtype):
op = ift.VdotOperator(ift.Field.from_random("normal", sp,
ift.extra.consistency_check(op, dtype, dtype)
@expand(product([(ift.RGSpace(10, harmonic=True), 4, 0),
(ift.RGSpace((24, 31), distances=(0.4, 2.34),
harmonic=True), 3, 0),
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment