From 8afcedf6ff46af3b40a48b807cc4a221a7ddf0e6 Mon Sep 17 00:00:00 2001 From: Ultima <theo.steininger@ultimanet.de> Date: Thu, 12 Nov 2015 17:03:00 +0100 Subject: [PATCH] Made a new propagator_operator --- demos/demo_excaliwir.py | 6 +- demos/demo_wf1.py | 36 ++-- nifty_core.py | 42 ++-- operators/nifty_minimization.py | 3 +- operators/nifty_operators.py | 371 +++++++++++++++++++++++--------- operators/nifty_probing.py | 43 ++-- test/test_nifty_spaces.py | 10 +- 7 files changed, 340 insertions(+), 171 deletions(-) diff --git a/demos/demo_excaliwir.py b/demos/demo_excaliwir.py index 8d25f1912..670963086 100644 --- a/demos/demo_excaliwir.py +++ b/demos/demo_excaliwir.py @@ -46,7 +46,7 @@ from nifty import * class problem(object): - def __init__(self, x_space, s2n=12, **kwargs): + def __init__(self, x_space, s2n=6, **kwargs): """ Sets up a Wiener filter problem. @@ -67,7 +67,7 @@ class problem(object): #self.k.set_power_indices(**kwargs) ## set some power spectrum - self.power = (lambda k: 42 / (k + 1) ** 3) + self.power = (lambda k: 42 / (k + 1) ** 2) ## define signal covariance self.S = power_operator(self.k, spec=self.power, bare=True) @@ -256,7 +256,7 @@ class problem(object): ##----------------------------------------------------------------------------- # if(__name__=="__main__"): - x = rg_space((128), zerocenter=True) + x = rg_space((1280), zerocenter=True) p = problem(x, log = False) about.warnings.off() ## pl.close("all") diff --git a/demos/demo_wf1.py b/demos/demo_wf1.py index 5543a479f..149634d36 100644 --- a/demos/demo_wf1.py +++ b/demos/demo_wf1.py @@ -37,14 +37,14 @@ import matplotlib as mpl mpl.use('Agg') import gc import imp -#nifty = imp.load_module('nifty', None, -# '/home/steininger/Downloads/nifty', ('','',5)) +nifty = imp.load_module('nifty', None, + '/home/steininger/Downloads/nifty', ('','',5)) from nifty import * # version 0.8.0 about.warnings.off() # some signal space; e.g., a two-dimensional regular grid -shape = [1024,] +shape = [1024] x_space = rg_space(shape) #y_space = point_space(1280*1280) #x_space = hp_space(32) @@ -91,25 +91,25 @@ m = D(j, W=S, tol=1E-8, limii=100, note=True) #temp_result = (D.inverse_times(m)-xi) -#n_power = x_space.enforce_power(s.var()/np.prod(shape)) -#s_power = S.get_power() +n_power = x_space.enforce_power(s.var()/np.prod(shape)) +s_power = S.get_power() -#s.plot(title="signal", save = 'plot_s.png') -#s.plot(title="signal power", power=True, other=power, -# mono=False, save = 'power_plot_s.png', nbin=1000, log=True, -# vmax = 100, vmin=10e-7) +s.plot(title="signal", save = 'plot_s.png') +s.plot(title="signal power", power=True, other=power, + mono=False, save = 'power_plot_s.png', nbin=1000, log=True, + vmax = 100, vmin=10e-7) -#d_ = field(x_space, val=d.val, target=k_space) -#d_.plot(title="data", vmin=s.min(), vmax=s.max(), save = 'plot_d.png') +d_ = field(x_space, val=d.val, target=k_space) +d_.plot(title="data", vmin=s.min(), vmax=s.max(), save = 'plot_d.png') -#n_ = field(x_space, val=n.val, target=k_space) -#n_.plot(title="data", vmin=s.min(), vmax=s.max(), save = 'plot_n.png') +n_ = field(x_space, val=n.val, target=k_space) +n_.plot(title="data", vmin=s.min(), vmax=s.max(), save = 'plot_n.png') -# -#m.plot(title="reconstructed map", vmin=s.min(), vmax=s.max(), save = 'plot_m.png') -#m.plot(title="reconstructed power", power=True, other=(n_power, s_power), -# save = 'power_plot_m.png', vmin=0.001, vmax=10, mono=False) -# + +m.plot(title="reconstructed map", vmin=s.min(), vmax=s.max(), save = 'plot_m.png') +m.plot(title="reconstructed power", power=True, other=(n_power, s_power), + save = 'power_plot_m.png', vmin=0.001, vmax=10, mono=False) + # diff --git a/nifty_core.py b/nifty_core.py index 9cd62debf..cea49da93 100644 --- a/nifty_core.py +++ b/nifty_core.py @@ -826,7 +826,7 @@ class point_space(space): self.comm = self._parse_comm(comm) self.discrete = True - self.harmonic = False +# self.harmonic = False self.distances = (np.float(1),) @property @@ -1387,7 +1387,7 @@ class point_space(space): if not isinstance(codomain, space): raise TypeError(about._errors.cstring( - "ERROR: invalid input. The given input is no nifty space.")) + "ERROR: invalid input. The given input is not a nifty space.")) if codomain == self: return True @@ -1830,7 +1830,6 @@ class point_space(space): string += 'datamodel: ' + str(self.datamodel) + "\n" string += 'comm: ' + self.comm.name + "\n" string += 'discrete: ' + str(self.discrete) + "\n" - string += 'harmonic: ' + str(self.harmonic) + "\n" string += 'distances: ' + str(self.distances) + "\n" return string @@ -1974,25 +1973,23 @@ class field(object): else: codomain = domain.get_codomain() - # Check if the given field lives in the same fourier-space as the - # new domain - if f.domain.harmonic != domain.harmonic: + # check for ishape + if ishape is None: + ishape = f.ishape + + # Check if the given field lives in a space which is compatible to the + # given domain + if f.domain != domain: # Try to transform the given field to the given domain/codomain f = f.transform(new_domain=domain, new_codomain=codomain) - # Check if the domain is now really the same. - # This is necessary since iso-fourier-conversion is not implemented - if f.domain == domain: - self._init_from_array(domain=domain, - val=f.val, - codomain=codomain, - ishape=ishape, - copy=copy, - **kwargs) - else: - raise ValueError(about._errors.cstring( - "ERROR: Incompatible domain given.")) + self._init_from_array(domain=domain, + val=f.val, + codomain=codomain, + ishape=ishape, + copy=copy, + **kwargs) def _init_from_array(self, val, domain, codomain, ishape, copy, **kwargs): # check domain @@ -2961,10 +2958,15 @@ class field(object): return new_field def _binary_helper(self, other, op='None', inplace=False): - # the other object could be a field/operator. Try to extract its data. + # if other is a field, make sure that the domains match + if isinstance(other, field): + other = field(domain=self.domain, + val=other, + codomain=self.codomain, + copy=False) try: other_val = other.get_val() - except(AttributeError): + except AttributeError: other_val = other # bring other_val into the right shape diff --git a/operators/nifty_minimization.py b/operators/nifty_minimization.py index 7e321af74..cd8b9b70c 100644 --- a/operators/nifty_minimization.py +++ b/operators/nifty_minimization.py @@ -252,8 +252,7 @@ class conjugate_gradient(object): convergence = 0 ii = 1 while(True): - from time import sleep - sleep(0.5) + # print ('gamma', gamma) q = self.A(d) # print ('q', q.val) diff --git a/operators/nifty_operators.py b/operators/nifty_operators.py index 6b800b170..67a463dc0 100644 --- a/operators/nifty_operators.py +++ b/operators/nifty_operators.py @@ -157,7 +157,7 @@ class operator(object): # If the operator is symmetric or unitary, we know that the operator # must be square - if self.sym is True or self.uni is True: + if self.sym or self.uni: target = self.domain cotarget = self.codomain if target is not None: @@ -225,49 +225,33 @@ class operator(object): "ERROR: no generic instance method 'inverse_adjoint_multiply'.")) def _briefing(self, x, domain, codomain, inverse): - # inspect x - if not isinstance(x, field): - y = field(domain, codomain=codomain, val=x) - else: - # check x.domain - if x.domain == domain: - y = x - else: - if x.domain.harmonic != domain.harmonic: - y = x.transform(codomain=domain) - else: - y = x.copy(domain=domain, codomain=codomain) + # make sure, that the result_field of the briefing lives in the + # given domain and codomain + result_field = field(domain=domain, val=x, codomain=codomain, + copy=False) - # weight if ... + # weight if necessary if (not self.imp) and (not domain.discrete) and (not inverse): - y = y.weight(power=1) - return y + result_field = result_field.weight(power=1) + return result_field def _debriefing(self, x, y, target, cotarget, inverse): - # > evaluates x and y after `multiply` - if y is None: - return None - else: - # inspect y - if not isinstance(y, field): - y = field(target, codomain=cotarget, val=y) - elif y.domain != target: - raise ValueError(about._errors.cstring( - "ERROR: invalid output domain.")) - # weight if ... - if (not self.imp) and (not target.discrete) and inverse: - y = y.weight(power=-1) - # inspect x - if isinstance(x, field): - # repair if the originally field was living in the codomain - # of the operators domain - if self.domain == self.target and\ - x.codomain == self.domain and\ - x.codomain != x.domain: - y = y.transform(codomain=x.domain) - if x.domain == y.domain and (x.codomain != y.codomain): - y.set_codomain(new_codomain=x.codomain) - return y + # The debriefing takes care that the result field lives in the same + # fourier-type domain as the input field + assert(isinstance(y, field)) + + # weight if necessary + if (not self.imp) and (not target.discrete) and inverse: + y = y.weight(power=-1) + + # if the operators domain as well as the target have the harmonic + # attribute, try to match the result_field to the input_field + if hasattr(self.domain, 'harmonic') and \ + hasattr(self.target, 'harmonic'): + if x.domain.harmonic != y.domain.harmonic: + y = y.transform() + + return y def times(self, x, **kwargs): """ @@ -1151,7 +1135,7 @@ class diagonal_operator(operator): self.target = self.domain self.cotarget = self.codomain self.imp = True - self.set_diag(new_diag=diag) + self.set_diag(new_diag=diag, bare=bare) def set_diag(self, new_diag, bare=False): """ @@ -1605,14 +1589,16 @@ class diagonal_operator(operator): else: codomain = domain.get_codomain() - if domain.harmonic != self.domain.harmonic: - temp_field = temp_field.transform(codomain=domain) + return field(domain=domain, val=temp_field, codomain=codomain) - if self.domain == domain and self.codomain == codomain: - return temp_field - else: - return temp_field.copy(domain=domain, - codomain=codomain) +# if domain.harmonic != self.domain.harmonic: +# temp_field = temp_field.transform(new_domain=domain) +# +# if self.domain == domain and self.codomain == codomain: +# return temp_field +# else: +# return temp_field.copy(domain=domain, +# codomain=codomain) def __repr__(self): return "<nifty_core.diagonal_operator>" @@ -2388,16 +2374,18 @@ class projection_operator(operator): # check if field is in the same signal/harmonic space as the # domain of the projection operator if self.domain != x.domain: - x = x.transform(codomain=self.domain) + x = x.transform(new_domain=self.domain) vecvec = vecvec_operator(val=x) return self.pseudo_tr(x=vecvec, axis=axis, **kwargs) # Case 2: x is an operator # -> take the diagonal elif isinstance(x, operator): - working_field = x.diag(bare=False) + working_field = x.diag(bare=False, + domain=self.domain, + codomain=self.codomain) if self.domain != working_field.domain: - working_field = working_field.transform(codomain=self.domain) + working_field = working_field.transform(new_domain=self.domain) # Case 3: x is something else else: @@ -2944,49 +2932,53 @@ class response_operator(operator): codomain=self.codomain) def _briefing(self, x, domain, codomain, inverse): - # inspect x - if not isinstance(x, field): - y = field(domain, codomain=codomain, val=x) - else: - # check x.domain - if x.domain == domain: - y = x - else: - if x.domain.harmonic != domain.harmonic: - y = x.transform(codomain=domain) - else: - y = x.copy(domain=domain, codomain=codomain) + # make sure, that the result_field of the briefing lives in the + # given domain and codomain + result_field = field(domain=domain, val=x, codomain=codomain, + copy=False) - # weight if ... + # weight if necessary if (not self.imp) and (not domain.discrete) and (not inverse) and \ self.den: - y = y.weight(power=1) - return y + result_field = result_field.weight(power=1) + return result_field def _debriefing(self, x, y, target, cotarget, inverse): - # > evaluates x and y after `multiply` - if y is None: - return None - else: - # inspect y - if not isinstance(y, field): - y = field(target, codomain=cotarget, val=y) - elif y.domain != target: - raise ValueError(about._errors.cstring( - "ERROR: invalid output domain.")) - # weight if ... - if (not self.imp) and (not target.discrete) and \ - (not self.den ^ inverse): - y = y.weight(power=-1) - # inspect x - if isinstance(x, field): - # repair if the originally field was living in the codomain - # of the operators domain - if self.domain == self.target == x.codomain: - y = y.transform(codomain=x.domain) - if x.domain == y.domain and (x.codomain != y.codomain): - y.set_codomain(new_codomain=x.codomain) - return y + # The debriefing takes care that the result field lives in the same + # fourier-type domain as the input field + assert(isinstance(y, field)) + + # weight if necessary + if (not self.imp) and (not target.discrete) and \ + (not self.den ^ inverse): + y = y.weight(power=-1) + + return y +# +# +# # > evaluates x and y after `multiply` +# if y is None: +# return None +# else: +# # inspect y +# if not isinstance(y, field): +# y = field(target, codomain=cotarget, val=y) +# elif y.domain != target: +# raise ValueError(about._errors.cstring( +# "ERROR: invalid output domain.")) +# # weight if ... +# if (not self.imp) and (not target.discrete) and \ +# (not self.den ^ inverse): +# y = y.weight(power=-1) +# # inspect x +# if isinstance(x, field): +# # repair if the originally field was living in the codomain +# # of the operators domain +# if self.domain == self.target == x.codomain: +# y = y.transform(new_domain=x.domain) +# if x.domain == y.domain and (x.codomain != y.codomain): +# y.set_codomain(new_codomain=x.codomain) +# return y def __repr__(self): return "<nifty_core.response_operator>" @@ -3159,9 +3151,11 @@ class invertible_operator(operator): if not force or x_ is None: return None about.warnings.cprint("WARNING: conjugate gradient failed.") - # weight if ... - if not self.imp: # continiuos domain/target - x_.weight(power=-1, overwrite=True) + # TODO: A weighting here shoud be wrong, as this is done by + # the (de)briefing methods -> Check! +# # weight if ... +# if not self.imp: # continiuos domain/target +# x_.weight(power=-1, overwrite=True) return x_ def _inverse_multiply(self, x, force=False, W=None, spam=None, reset=None, @@ -3230,15 +3224,18 @@ class invertible_operator(operator): if not force or x_ is None: return None about.warnings.cprint("WARNING: conjugate gradient failed.") - # weight if ... - if not self.imp: # continiuos domain/target - x_.weight(power=1, overwrite=True) + # TODO: A weighting here shoud be wrong, as this is done by + # the (de)briefing methods -> Check! +# # weight if ... +# if not self.imp: # continiuos domain/target +# x_.weight(power=1, overwrite=True) return x_ def __repr__(self): return "<nifty_tools.invertible_operator>" + class propagator_operator(operator): """ .. __ @@ -3313,6 +3310,186 @@ class propagator_operator(operator): """ + def __init__(self, S=None, M=None, R=None, N=None): + """ + Sets the standard operator properties and `codomain`, `_A1`, `_A2`, + and `RN` if required. + + Parameters + ---------- + S : operator + Covariance of the signal prior. + M : operator + Likelihood contribution. + R : operator + Response operator translating signal to (noiseless) data. + N : operator + Covariance of the noise prior or the likelihood, respectively. + + """ + + # parse the signal prior covariance + if not isinstance(S, operator): + raise ValueError(about._errors.cstring( + "ERROR: The given S is not an operator.")) + + self.S = S + self.S_inverse_times = self.S.inverse_times + + # take signal-space domain from S as the domain for D + S_is_harmonic = False + if hasattr(S.domain, 'harmonic'): + if S.domain.harmonic: + S_is_harmonic = True + + if S_is_harmonic: + self.domain = S.codomain + self.codomain = S.domain + else: + self.domain = S.domain + self.codomain = S.codomain + + self.target = self.domain + self.cotarget = self.codomain + + # build up the likelihood contribution + (self.M_times, + M_domain, + M_codomain, + M_target, + M_cotarget) = self._build_likelihood_contribution(M, R, N) + + # assert that S and M have matching domains + if not (self.domain == M_domain and + self.codomain == M_codomain and + self.target == M_target and + self.cotarget == M_cotarget): + raise ValueError(about._errors.cstring( + "ERROR: The (co)domains and (co)targets of the prior " + + "signal covariance and the likelihood contribution must be " + + "the same in the sense of '=='.")) + + self.sym = True + self.uni = False + self.imp = True + + def _build_likelihood_contribution(self, M, R, N): + # if a M is given, return its times method and its domains + # supplier and discard R and N + if M is not None: + return (M.times, M.domain, M.codomain, M.target, M.cotarget) + + if N is not None: + if R is not None: + return (lambda z: R.adjoint_times(N.inverse_times(R.times(z))), + R.domain, R.codomain, R.domain, R.codomain) + else: + return (N.inverse_times, + N.domain, N.codomain, N.target, N.cotarget) + else: + raise ValueError(about._errors.cstring( + "ERROR: At least M or N must be given.")) + + def _multiply(self, x, W=None, spam=None, reset=None, note=False, + x0=None, tol=1E-4, clevel=1, limii=None, **kwargs): + + if W is None: + W = self.S + (result, convergence) = conjugate_gradient(self._inverse_multiply, + x, + W=W, + spam=spam, + reset=reset, + note=note)(x0=x0, + tol=tol, + clevel=clevel, + limii=limii) + # evaluate + if not convergence: + about.warnings.cprint("WARNING: conjugate gradient failed.") + + return result + + def _inverse_multiply(self, x, **kwargs): + result = self.S_inverse_times(x) + result += self.M_times(x) + return result + + +class propagator_operator_old(operator): + """ + .. __ + .. / /_ + .. _______ _____ ______ ______ ____ __ ____ __ ____ __ / _/ ______ _____ + .. / _ / / __/ / _ | / _ | / _ / / _ / / _ / / / / _ | / __/ + .. / /_/ / / / / /_/ / / /_/ / / /_/ / / /_/ / / /_/ / / /_ / /_/ / / / + .. / ____/ /__/ \______/ / ____/ \______| \___ / \______| \___/ \______/ /__/ operator class + .. /__/ /__/ /______/ + + NIFTY subclass for propagator operators (of a certain family) + + The propagator operators :math:`D` implemented here have an inverse + formulation like :math:`(S^{-1} + M)`, :math:`(S^{-1} + N^{-1})`, or + :math:`(S^{-1} + R^\dagger N^{-1} R)` as appearing in Wiener filter + theory. + + Parameters + ---------- + S : operator + Covariance of the signal prior. + M : operator + Likelihood contribution. + R : operator + Response operator translating signal to (noiseless) data. + N : operator + Covariance of the noise prior or the likelihood, respectively. + + See Also + -------- + conjugate_gradient + + Notes + ----- + The propagator will puzzle the operators `S` and `M` or `R`, `N` or + only `N` together in the predefined from, a domain is set + automatically. The application of the inverse is done by invoking a + conjugate gradient. + Note that changes to `S`, `M`, `R` or `N` auto-update the propagator. + + Examples + -------- + >>> f = field(rg_space(4), val=[2, 4, 6, 8]) + >>> S = power_operator(f.target, spec=1) + >>> N = diagonal_operator(f.domain, diag=1) + >>> D = propagator_operator(S=S, N=N) # D^{-1} = S^{-1} + N^{-1} + >>> D(f).val + array([ 1., 2., 3., 4.]) + + Attributes + ---------- + domain : space + A space wherein valid arguments live. + codomain : space + An alternative space wherein valid arguments live; commonly the + codomain of the `domain` attribute. + sym : bool + Indicates that the operator is self-adjoint. + uni : bool + Indicates that the operator is not unitary. + imp : bool + Indicates that volume weights are implemented in the `multiply` + instance methods. + target : space + The space wherein the operator output lives. + _A1 : {operator, function} + Application of :math:`S^{-1}` to a field. + _A2 : {operator, function} + Application of all operations not included in `A1` to a field. + RN : {2-tuple of operators}, *optional* + Contains `R` and `N` if given. + + """ + def __init__(self, S=None, M=None, R=None, N=None): """ Sets the standard operator properties and `codomain`, `_A1`, `_A2`, @@ -3433,7 +3610,7 @@ class propagator_operator(operator): return (x, True) # transform else: - return (x.transform(codomain=self.codomain, + return (x.transform(new_domain=self.codomain, overwrite=False), True) @@ -3444,7 +3621,7 @@ class propagator_operator(operator): elif isinstance(x, field): # repair ... if in_codomain == True and x.domain != self.codomain: - x_ = x_.transform(codomain=x.domain) # ... domain + x_ = x_.transform(new_domain=x.domain) # ... domain if x_.codomain != x.codomain: x_.set_codomain(new_codomain=x.codomain) # ... codomain return x_ diff --git a/operators/nifty_probing.py b/operators/nifty_probing.py index aee9d55e4..ea6279eb5 100644 --- a/operators/nifty_probing.py +++ b/operators/nifty_probing.py @@ -318,9 +318,6 @@ class prober(object): f = self.function(probe, **self.kwargs) return f - - - def finalize(self, sum_of_probes, sum_of_squares, num): """ Evaluates the probing results. @@ -391,7 +388,7 @@ class prober(object): for ii in xrange(self.nrun): print ('running probe ', ii) temp_probe = self.generate_probe() - temp_result = self.evaluate_probe(probe = temp_probe) + temp_result = self.evaluate_probe(probe=temp_probe) if temp_result is not None: sum_of_probes += temp_result @@ -403,7 +400,7 @@ class prober(object): # evaluate return self.finalize(sum_of_probes, sum_of_squares, num) - def __call__(self,loop=False,**kwargs): + def __call__(self, loop=False, **kwargs): """ Starts the probing process. @@ -428,15 +425,10 @@ class prober(object): self.configure(**kwargs) return self.probe() - - def __repr__(self): return "<nifty_core.probing>" - - - class _specialized_prober(object): def __init__(self, operator, domain=None, inverseQ=False, **kwargs): # remove a potentially supplied function keyword argument @@ -448,15 +440,15 @@ class _specialized_prober(object): about.warnings.cprint( "WARNING: Dropped the supplied function keyword-argument!") - if domain is None and inverseQ == False: + if domain is None and not inverseQ: kwargs.update({'domain': operator.domain}) - elif domain is None and inverseQ == True: + elif domain is None and inverseQ: kwargs.update({'domain': operator.target}) else: kwargs.update({'domain': domain}) self.operator = operator - self.prober = prober(function = self._probing_function, + self.prober = prober(function=self._probing_function, **kwargs) def _probing_function(self, probe): @@ -465,48 +457,45 @@ class _specialized_prober(object): def __call__(self, *args, **kwargs): return self.prober(*args, **kwargs) - def __getattr__(self, attr): return getattr(self.prober, attr) - - - class trace_prober(_specialized_prober): def __init__(self, operator, **kwargs): - super(trace_prober, self).__init__(operator = operator, + super(trace_prober, self).__init__(operator=operator, inverseQ=False, **kwargs) + def _probing_function(self, probe): return direct_dot(probe.conjugate(), self.operator.times(probe)) + class inverse_trace_prober(_specialized_prober): def __init__(self, operator, **kwargs): - super(inverse_trace_prober, self).__init__(operator = operator, + super(inverse_trace_prober, self).__init__(operator=operator, inverseQ=True, **kwargs) + def _probing_function(self, probe): return direct_dot(probe.conjugate(), self.operator.inverse_times(probe)) + class diagonal_prober(_specialized_prober): def __init__(self, **kwargs): - super(diagonal_prober, self).__init__(inverseQ = False, + super(diagonal_prober, self).__init__(inverseQ=False, **kwargs) + def _probing_function(self, probe): return probe.conjugate()*self.operator.times(probe) + class inverse_diagonal_prober(_specialized_prober): def __init__(self, operator, **kwargs): - super(inverse_diagonal_prober, self).__init__(operator = operator, + super(inverse_diagonal_prober, self).__init__(operator=operator, inverseQ=True, **kwargs) + def _probing_function(self, probe): return probe.conjugate()*self.operator.inverse_times(probe) - - - - - - diff --git a/test/test_nifty_spaces.py b/test/test_nifty_spaces.py index a6410eb83..b22ebbc31 100644 --- a/test/test_nifty_spaces.py +++ b/test/test_nifty_spaces.py @@ -275,10 +275,11 @@ class Test_Common_Point_Like_Space_Interface(unittest.TestCase): assert(isinstance(s.dtype, np.dtype)) assert(isinstance(s.datamodel, str)) assert(isinstance(s.discrete, bool)) - assert(isinstance(s.harmonic, bool)) +# assert(isinstance(s.harmonic, bool)) assert(isinstance(s.distances, tuple)) - if s.harmonic: - assert(isinstance(s.power_indices, power_indices)) + if hasattr(s, 'harmonic'): + if s.harmonic: + assert(isinstance(s.power_indices, power_indices)) @parameterized.expand(point_like_spaces, testcase_func_name=custom_name_func) @@ -325,7 +326,6 @@ class Test_Point_Space(unittest.TestCase): assert_equal(p.datamodel, datamodel) assert_equal(p.discrete, True) - assert_equal(p.harmonic, False) assert_equal(p.distances, (np.float(1.),)) ############################################################################### @@ -693,6 +693,7 @@ class Test_RG_Space(unittest.TestCase): harmonic=harmonic, fft_module=fft_module, datamodel=datamodel) + assert(isinstance(x.harmonic, bool)) assert_equal(x.get_shape(), shape) assert_equal(x.dtype, np.dtype('float64') if complexity == 0 else @@ -1022,6 +1023,7 @@ class Test_Lm_Space(unittest.TestCase): testcase_func_name=custom_name_func) def test_successfull_init(self, lmax, mmax, dtype, datamodel): l = lm_space(lmax, mmax=mmax, dtype=dtype, datamodel=datamodel) + assert(isinstance(l.harmonic, bool)) assert_equal(l.paradict['lmax'], lmax) if mmax is None or mmax > lmax: assert_equal(l.paradict['mmax'], lmax) -- GitLab