Commit e5b55ec6 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'explicit_codomains' into 'NIFTy_5'

allow explicit specification of codomains

See merge request ift/nifty-dev!195
parents 252a5232 46563b88
...@@ -24,7 +24,7 @@ from ..operators.harmonic_operators import HarmonicTransformOperator ...@@ -24,7 +24,7 @@ from ..operators.harmonic_operators import HarmonicTransformOperator
from ..operators.simple_linear_operators import ducktape from ..operators.simple_linear_operators import ducktape
def CorrelatedField(target, amplitude_operator, name='xi'): def CorrelatedField(target, amplitude_operator, name='xi', codomain=None):
"""Constructs an operator which turns a white Gaussian excitation field """Constructs an operator which turns a white Gaussian excitation field
into a correlated field. into a correlated field.
...@@ -42,16 +42,21 @@ def CorrelatedField(target, amplitude_operator, name='xi'): ...@@ -42,16 +42,21 @@ def CorrelatedField(target, amplitude_operator, name='xi'):
amplitude_operator: Operator amplitude_operator: Operator
name : string name : string
:class:`MultiField` key for the xi-field. :class:`MultiField` key for the xi-field.
codomain : Domain
The codomain for target[0]. If not supplied, it is inferred.
Returns Returns
------- -------
Correlated field : Operator Operator
Correlated field
""" """
tgt = DomainTuple.make(target) tgt = DomainTuple.make(target)
if len(tgt) > 1: if len(tgt) > 1:
raise ValueError raise ValueError
h_space = tgt[0].get_default_codomain() if codomain is None:
ht = HarmonicTransformOperator(h_space, tgt[0]) codomain = tgt[0].get_default_codomain()
h_space = codomain
ht = HarmonicTransformOperator(h_space, target=tgt[0])
p_space = amplitude_operator.target[0] p_space = amplitude_operator.target[0]
power_distributor = PowerDistributor(h_space, p_space) power_distributor = PowerDistributor(h_space, p_space)
A = power_distributor(amplitude_operator) A = power_distributor(amplitude_operator)
...@@ -70,7 +75,7 @@ def MfCorrelatedField(target, amplitudes, name='xi'): ...@@ -70,7 +75,7 @@ def MfCorrelatedField(target, amplitudes, name='xi'):
Parameters Parameters
---------- ----------
target : Domain, DomainTuple or tuple of Domain target : Domain, DomainTuple or tuple of Domain
Target of the operator. Must contain exactly one space. Target of the operator. Must contain exactly two spaces.
amplitudes: iterable of Operator amplitudes: iterable of Operator
List of two amplitude operators. List of two amplitude operators.
name : string name : string
...@@ -78,7 +83,8 @@ def MfCorrelatedField(target, amplitudes, name='xi'): ...@@ -78,7 +83,8 @@ def MfCorrelatedField(target, amplitudes, name='xi'):
Returns Returns
------- -------
Correlated field : Operator Operator
Correlated field
""" """
tgt = DomainTuple.make(target) tgt = DomainTuple.make(target)
if len(tgt) != 2: if len(tgt) != 2:
...@@ -88,7 +94,7 @@ def MfCorrelatedField(target, amplitudes, name='xi'): ...@@ -88,7 +94,7 @@ def MfCorrelatedField(target, amplitudes, name='xi'):
hsp = DomainTuple.make([tt.get_default_codomain() for tt in tgt]) hsp = DomainTuple.make([tt.get_default_codomain() for tt in tgt])
ht1 = HarmonicTransformOperator(hsp, target=tgt[0], space=0) ht1 = HarmonicTransformOperator(hsp, target=tgt[0], space=0)
ht2 = HarmonicTransformOperator(ht1.target, space=1) ht2 = HarmonicTransformOperator(ht1.target, target=tgt[1], space=1)
ht = ht2 @ ht1 ht = ht2 @ ht1
psp = [aa.target[0] for aa in amplitudes] psp = [aa.target[0] for aa in amplitudes]
......
...@@ -43,7 +43,8 @@ def _make_dynamic_operator(target, ...@@ -43,7 +43,8 @@ def _make_dynamic_operator(target,
causal, causal,
minimum_phase, minimum_phase,
sigc=None, sigc=None,
quant=None): quant=None,
codomain=None):
if not isinstance(target, RGSpace): if not isinstance(target, RGSpace):
raise TypeError("RGSpace required") raise TypeError("RGSpace required")
if not target.harmonic: if not target.harmonic:
...@@ -64,7 +65,9 @@ def _make_dynamic_operator(target, ...@@ -64,7 +65,9 @@ def _make_dynamic_operator(target,
if cone and (sigc is None or quant is None): if cone and (sigc is None or quant is None):
raise RuntimeError raise RuntimeError
dom = DomainTuple.make(target.get_default_codomain()) if codomain is None:
codomain = target.get_default_codomain()
dom = DomainTuple.make(codomain)
ops = {} ops = {}
FFT = FFTOperator(dom) FFT = FFTOperator(dom)
Real = Realizer(dom) Real = Realizer(dom)
......
...@@ -37,9 +37,11 @@ class QHTOperator(LinearOperator): ...@@ -37,9 +37,11 @@ class QHTOperator(LinearOperator):
space : int space : int
The index of the domain on which the operator acts. The index of the domain on which the operator acts.
target[space] must be a non-harmonic LogRGSpace. target[space] must be a non-harmonic LogRGSpace.
codomain : Domain
The codomain for target[space]. If not supplied, it is inferred.
""" """
def __init__(self, target, space=0): def __init__(self, target, space=0, codomain=None):
self._target = DomainTuple.make(target) self._target = DomainTuple.make(target)
self._space = infer_space(self._target, space) self._space = infer_space(self._target, space)
...@@ -51,8 +53,9 @@ class QHTOperator(LinearOperator): ...@@ -51,8 +53,9 @@ class QHTOperator(LinearOperator):
raise TypeError("target[space] must be a nonharmonic space") raise TypeError("target[space] must be a nonharmonic space")
self._domain = [dom for dom in self._target] self._domain = [dom for dom in self._target]
self._domain[self._space] = \ if codomain is None:
self._target[self._space].get_default_codomain() codomain = self._target[self._space].get_default_codomain()
self._domain[self._space] = codomain
self._domain = DomainTuple.make(self._domain) self._domain = DomainTuple.make(self._domain)
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment