Commit 15fc7b6d authored by Martin Reinecke's avatar Martin Reinecke
Browse files

remove factory method; create the proper operators directly

parent 788af006
Pipeline #17035 passed with stage
in 20 minutes and 8 seconds
...@@ -130,7 +130,7 @@ if __name__ == "__main__": ...@@ -130,7 +130,7 @@ if __name__ == "__main__":
plotter.plot.zmin = 0. plotter.plot.zmin = 0.
plotter.plot.zmax = 3. plotter.plot.zmax = 3.
sm = ift.SmoothingOperator.make(plot_space, sigma=0.03) sm = ift.FFTSmoothingOperator(plot_space, sigma=0.03)
plotter(ift.log(ift.sqrt(sm(ift.Field(plot_space, val=variance.val.real)))), path='uncertainty.html') plotter(ift.log(ift.sqrt(sm(ift.Field(plot_space, val=variance.val.real)))), path='uncertainty.html')
plotter.plot.zmin = np.real(mock_signal.min()); plotter.plot.zmin = np.real(mock_signal.min());
......
...@@ -3,7 +3,7 @@ import numpy as np ...@@ -3,7 +3,7 @@ import numpy as np
from nifty import Field,\ from nifty import Field,\
FieldArray FieldArray
from nifty.operators.linear_operator import LinearOperator from nifty.operators.linear_operator import LinearOperator
from nifty.operators.smoothing_operator import SmoothingOperator from nifty.operators.smoothing_operator import FFTSmoothingOperator
from nifty.operators.composed_operator import ComposedOperator from nifty.operators.composed_operator import ComposedOperator
from nifty.operators.diagonal_operator import DiagonalOperator from nifty.operators.diagonal_operator import DiagonalOperator
...@@ -81,7 +81,7 @@ class ResponseOperator(LinearOperator): ...@@ -81,7 +81,7 @@ class ResponseOperator(LinearOperator):
"exposure do not match") "exposure do not match")
for ii in range(len(kernel_smoothing)): for ii in range(len(kernel_smoothing)):
kernel_smoothing[ii] = SmoothingOperator.make(self._domain[ii], kernel_smoothing[ii] = FFTSmoothingOperator(self._domain[ii],
sigma=sigma[ii]) sigma=sigma[ii])
kernel_exposure[ii] = DiagonalOperator(self._domain[ii], kernel_exposure[ii] = DiagonalOperator(self._domain[ii],
diagonal=exposure[ii]) diagonal=exposure[ii])
......
...@@ -16,4 +16,5 @@ ...@@ -16,4 +16,5 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from .smoothing_operator import SmoothingOperator from .fft_smoothing_operator import FFTSmoothingOperator
from .direct_smoothing_operator import DirectSmoothingOperator
...@@ -83,34 +83,6 @@ class SmoothingOperator(EndomorphicOperator): ...@@ -83,34 +83,6 @@ class SmoothingOperator(EndomorphicOperator):
""" """
@staticmethod
def make(domain, sigma, log_distances=False, default_spaces=None):
_fft_smoothing_spaces = [RGSpace,
GLSpace,
HPSpace]
_direct_smoothing_spaces = [PowerSpace]
domain = SmoothingOperator._parse_domain(domain)
if len(domain) != 1:
raise ValueError("SmoothingOperator only accepts exactly one "
"space as input domain.")
if np.any([isinstance(domain[0], sp)
for sp in _fft_smoothing_spaces]):
from .fft_smoothing_operator import FFTSmoothingOperator
return FFTSmoothingOperator (domain, sigma, default_spaces)
elif np.any([isinstance(domain[0], sp)
for sp in _direct_smoothing_spaces]):
from .direct_smoothing_operator import DirectSmoothingOperator
return DirectSmoothingOperator (domain, sigma, log_distances,\
default_spaces)
else:
raise NotImplementedError("For the given Space smoothing "
" is not available.")
# ---Overwritten properties and methods--- # ---Overwritten properties and methods---
def __init__(self, domain, sigma, log_distances=False, def __init__(self, domain, sigma, log_distances=False,
default_spaces=None): default_spaces=None):
......
...@@ -23,7 +23,8 @@ from numpy.testing import assert_equal, assert_allclose ...@@ -23,7 +23,8 @@ from numpy.testing import assert_equal, assert_allclose
from nifty import Field,\ from nifty import Field,\
RGSpace,\ RGSpace,\
PowerSpace,\ PowerSpace,\
SmoothingOperator FFTSmoothingOperator,\
DirectSmoothingOperator
from itertools import product from itertools import product
from test.common import expand from test.common import expand
...@@ -40,7 +41,7 @@ class SmoothingOperator_Tests(unittest.TestCase): ...@@ -40,7 +41,7 @@ class SmoothingOperator_Tests(unittest.TestCase):
@expand(product(spaces, [0., .5, 5.])) @expand(product(spaces, [0., .5, 5.]))
def test_property(self, space, sigma): def test_property(self, space, sigma):
op = SmoothingOperator.make(space, sigma=sigma) op = FFTSmoothingOperator(space, sigma=sigma)
if op.domain[0] != space: if op.domain[0] != space:
raise TypeError raise TypeError
if op.unitary != False: if op.unitary != False:
...@@ -54,7 +55,7 @@ class SmoothingOperator_Tests(unittest.TestCase): ...@@ -54,7 +55,7 @@ class SmoothingOperator_Tests(unittest.TestCase):
@expand(product(spaces, [0., .5, 5.])) @expand(product(spaces, [0., .5, 5.]))
def test_adjoint_times(self, space, sigma): def test_adjoint_times(self, space, sigma):
op = SmoothingOperator.make(space, sigma=sigma) op = FFTSmoothingOperator(space, sigma=sigma)
rand1 = Field.from_random('normal', domain=space) rand1 = Field.from_random('normal', domain=space)
rand2 = Field.from_random('normal', domain=space) rand2 = Field.from_random('normal', domain=space)
tt1 = rand1.vdot(op.times(rand2)) tt1 = rand1.vdot(op.times(rand2))
...@@ -63,7 +64,7 @@ class SmoothingOperator_Tests(unittest.TestCase): ...@@ -63,7 +64,7 @@ class SmoothingOperator_Tests(unittest.TestCase):
@expand(product(spaces, [0., .5, 5.])) @expand(product(spaces, [0., .5, 5.]))
def test_times(self, space, sigma): def test_times(self, space, sigma):
op = SmoothingOperator.make(space, sigma=sigma) op = FFTSmoothingOperator(space, sigma=sigma)
rand1 = Field(space, val=0.) rand1 = Field(space, val=0.)
rand1.val[0] = 1. rand1.val[0] = 1.
tt1 = op.times(rand1) tt1 = op.times(rand1)
...@@ -74,7 +75,7 @@ class SmoothingOperator_Tests(unittest.TestCase): ...@@ -74,7 +75,7 @@ class SmoothingOperator_Tests(unittest.TestCase):
def test_smooth_regular1(self, sz, d, sigma, tp): def test_smooth_regular1(self, sz, d, sigma, tp):
tol = _get_rtol(tp) tol = _get_rtol(tp)
sp = RGSpace(sz, harmonic=True, distances=d) sp = RGSpace(sz, harmonic=True, distances=d)
smo = SmoothingOperator.make(sp, sigma=sigma) smo = FFTSmoothingOperator(sp, sigma=sigma)
inp = Field.from_random(domain=sp, random_type='normal', std=1, mean=4, inp = Field.from_random(domain=sp, random_type='normal', std=1, mean=4,
dtype=tp) dtype=tp)
out = smo(inp) out = smo(inp)
...@@ -86,7 +87,7 @@ class SmoothingOperator_Tests(unittest.TestCase): ...@@ -86,7 +87,7 @@ class SmoothingOperator_Tests(unittest.TestCase):
def test_smooth_regular2(self, sz1, sz2, d1, d2, sigma, tp): def test_smooth_regular2(self, sz1, sz2, d1, d2, sigma, tp):
tol = _get_rtol(tp) tol = _get_rtol(tp)
sp = RGSpace([sz1, sz2], distances=[d1, d2], harmonic=True) sp = RGSpace([sz1, sz2], distances=[d1, d2], harmonic=True)
smo = SmoothingOperator.make(sp, sigma=sigma) smo = FFTSmoothingOperator(sp, sigma=sigma)
inp = Field.from_random(domain=sp, random_type='normal', std=1, mean=4, inp = Field.from_random(domain=sp, random_type='normal', std=1, mean=4,
dtype=tp) dtype=tp)
out = smo(inp) out = smo(inp)
...@@ -99,7 +100,7 @@ class SmoothingOperator_Tests(unittest.TestCase): ...@@ -99,7 +100,7 @@ class SmoothingOperator_Tests(unittest.TestCase):
tol = _get_rtol(tp) tol = _get_rtol(tp)
sp = RGSpace(sz, harmonic=True) sp = RGSpace(sz, harmonic=True)
ps = PowerSpace(sp, nbin=sz, logarithmic=log) ps = PowerSpace(sp, nbin=sz, logarithmic=log)
smo = SmoothingOperator.make(ps, sigma=sigma) smo = DirectSmoothingOperator(ps, sigma=sigma)
inp = Field.from_random(domain=ps, random_type='normal', std=1, mean=4, inp = Field.from_random(domain=ps, random_type='normal', std=1, mean=4,
dtype=tp) dtype=tp)
out = smo(inp) out = smo(inp)
...@@ -112,7 +113,7 @@ class SmoothingOperator_Tests(unittest.TestCase): ...@@ -112,7 +113,7 @@ class SmoothingOperator_Tests(unittest.TestCase):
tol = _get_rtol(tp) tol = _get_rtol(tp)
sp = RGSpace([sz1, sz2], harmonic=True) sp = RGSpace([sz1, sz2], harmonic=True)
ps = PowerSpace(sp, logarithmic=log) ps = PowerSpace(sp, logarithmic=log)
smo = SmoothingOperator.make(ps, sigma=sigma) smo = DirectSmoothingOperator(ps, sigma=sigma)
inp = Field.from_random(domain=ps, random_type='normal', std=1, mean=4, inp = Field.from_random(domain=ps, random_type='normal', std=1, mean=4,
dtype=tp) dtype=tp)
out = smo(inp) out = smo(inp)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment