Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
be3d197c
Commit
be3d197c
authored
Sep 04, 2016
by
theos
Browse files
Merged fft_operator/transformation_factory into FFTOperator.
Started to consolidate the code base of LM <-> GL/HP transformations.
parent
007da99a
Changes
10
Hide whitespace changes
Inline
Side-by-side
nifty/field.py
View file @
be3d197c
...
@@ -5,7 +5,7 @@ from d2o import distributed_data_object,\
...
@@ -5,7 +5,7 @@ from d2o import distributed_data_object,\
STRATEGIES
as
DISTRIBUTION_STRATEGIES
STRATEGIES
as
DISTRIBUTION_STRATEGIES
from
nifty.config
import
about
,
\
from
nifty.config
import
about
,
\
nifty_configuration
as
gc
,
\
nifty_configuration
as
gc
from
nifty.field_types
import
FieldType
from
nifty.field_types
import
FieldType
...
...
nifty/operators/fft_operator/fft_operator.py
View file @
be3d197c
from
nifty.config
import
about
from
nifty.config
import
about
import
nifty.nifty_utilities
as
utilities
import
nifty.nifty_utilities
as
utilities
from
nifty.spaces
import
RGSpace
,
\
GLSpace
,
\
HPSpace
,
\
LMSpace
from
nifty.operators.linear_operator
import
LinearOperator
from
nifty.operators.linear_operator
import
LinearOperator
from
transformations
import
TransformationFactory
from
transformations
import
RGRGTransformation
,
\
LMGLTransformation
,
\
LMHPTransformation
,
\
GLLMTransformation
,
\
HPLMTransformation
,
\
TransformationCache
class
FFTOperator
(
LinearOperator
):
class
FFTOperator
(
LinearOperator
):
# ---Class attributes---
default_codomain_dictionary
=
{
RGSpace
:
RGSpace
,
HPSpace
:
LMSpace
,
GLSpace
:
LMSpace
,
LMSpace
:
HPSpace
,
}
transformation_dictionary
=
{(
RGSpace
,
RGSpace
):
RGRGTransformation
,
(
HPSpace
,
LMSpace
):
HPLMTransformation
,
(
GLSpace
,
LMSpace
):
GLLMTransformation
,
(
LMSpace
,
HPSpace
):
LMHPTransformation
,
(
LMSpace
,
GLSpace
):
LMGLTransformation
}
# ---Overwritten properties and methods---
# ---Overwritten properties and methods---
def
__init__
(
self
,
domain
=
(),
field_type
=
(),
target
=
None
):
def
__init__
(
self
,
domain
=
(),
field_type
=
(),
target
=
None
,
module
=
None
):
super
(
FFTOperator
,
self
).
__init__
(
domain
=
domain
,
super
(
FFTOperator
,
self
).
__init__
(
domain
=
domain
,
field_type
=
field_type
)
field_type
=
field_type
)
# Initialize domain and target
if
len
(
self
.
domain
)
!=
1
:
if
len
(
self
.
domain
)
!=
1
:
raise
ValueError
(
about
.
_errors
.
cstring
(
raise
ValueError
(
about
.
_errors
.
cstring
(
'ERROR: TransformationOperator accepts only exactly one '
'ERROR: TransformationOperator accepts only exactly one '
...
@@ -24,17 +51,30 @@ class FFTOperator(LinearOperator):
...
@@ -24,17 +51,30 @@ class FFTOperator(LinearOperator):
))
))
if
target
is
None
:
if
target
is
None
:
target
=
utilities
.
get_default_codomain
(
self
.
domain
[
0
])
target
=
(
self
.
get_default_codomain
(
self
.
domain
[
0
]),
)
self
.
_target
=
self
.
_parse_domain
(
target
)
self
.
_target
=
self
.
_parse_domain
(
target
)
self
.
_forward_transformation
=
TransformationFactory
.
create
(
# Create transformation instances
self
.
domain
[
0
],
self
.
target
[
0
]
try
:
)
forward_class
=
self
.
transformation_dictionary
[
(
self
.
domain
[
0
].
__class__
,
self
.
target
[
0
].
__class__
)]
self
.
_inverse_transformation
=
TransformationFactory
.
create
(
except
KeyError
:
self
.
target
[
0
],
self
.
domain
[
0
]
raise
TypeError
(
about
.
_errors
.
cstring
(
)
"ERROR: No forward transformation for domain-target pair "
"found."
))
try
:
backward_class
=
self
.
transformation_dictionary
[
(
self
.
target
[
0
].
__class__
,
self
.
domain
[
0
].
__class__
)]
except
KeyError
:
raise
TypeError
(
about
.
_errors
.
cstring
(
"ERROR: No backward transformation for domain-target pair "
"found."
))
self
.
_forward_transformation
=
TransformationCache
.
create
(
forward_class
,
self
.
domain
[
0
],
self
.
target
[
0
],
module
=
module
)
self
.
_backward_transformation
=
TransformationCache
.
create
(
backward_class
,
self
.
target
[
0
],
self
.
domain
[
0
],
module
=
module
)
def
_times
(
self
,
x
,
spaces
,
types
):
def
_times
(
self
,
x
,
spaces
,
types
):
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
x
.
domain
))
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
x
.
domain
))
...
@@ -69,7 +109,7 @@ class FFTOperator(LinearOperator):
...
@@ -69,7 +109,7 @@ class FFTOperator(LinearOperator):
else
:
else
:
axes
=
x
.
domain_axes
[
spaces
[
0
]]
axes
=
x
.
domain_axes
[
spaces
[
0
]]
new_val
=
self
.
_
inverse
_transformation
.
transform
(
x
.
val
,
axes
=
axes
)
new_val
=
self
.
_
backward
_transformation
.
transform
(
x
.
val
,
axes
=
axes
)
if
spaces
is
None
:
if
spaces
is
None
:
result_domain
=
self
.
domain
result_domain
=
self
.
domain
...
@@ -99,3 +139,22 @@ class FFTOperator(LinearOperator):
...
@@ -99,3 +139,22 @@ class FFTOperator(LinearOperator):
@
property
@
property
def
unitary
(
self
):
def
unitary
(
self
):
return
True
return
True
# ---Added properties and methods---
@
classmethod
def
get_default_codomain
(
cls
,
domain
):
domain_class
=
domain
.
__class__
try
:
codomain_class
=
cls
.
default_codomain_dictionary
[
domain_class
]
except
KeyError
:
raise
TypeError
(
about
.
_errors
.
cstring
(
"ERROR: unknown domain"
))
try
:
transform_class
=
cls
.
transformation_dictionary
[(
domain_class
,
codomain_class
)]
except
KeyError
:
raise
TypeError
(
about
.
_errors
.
cstring
(
"ERROR: No transformation for domain-codomain pair found."
))
return
transform_class
.
get_codomain
(
domain
)
nifty/operators/fft_operator/transformations/__init__.py
View file @
be3d197c
...
@@ -4,4 +4,4 @@ from hplmtransformation import HPLMTransformation
...
@@ -4,4 +4,4 @@ from hplmtransformation import HPLMTransformation
from
lmgltransformation
import
LMGLTransformation
from
lmgltransformation
import
LMGLTransformation
from
lmhptransformation
import
LMHPTransformation
from
lmhptransformation
import
LMHPTransformation
from
transformation_factory
import
TransformationFactory
from
transformation_cache
import
TransformationCache
\ No newline at end of file
\ No newline at end of file
nifty/operators/fft_operator/transformations/gllmtransformation.py
View file @
be3d197c
import
numpy
as
np
import
numpy
as
np
from
transformation
import
Transformation
from
nifty.config
import
dependency_injector
as
gdi
,
\
from
d2o
import
distributed_data_object
about
from
nifty.config
import
dependency_injector
as
gdi
import
nifty.nifty_utilities
as
utilities
from
nifty
import
GLSpace
,
LMSpace
from
nifty
import
GLSpace
,
LMSpace
from
slicing_transformation
import
SlicingTransformation
import
lm_transformation_factory
as
ltf
import
lm_transformation_factory
as
ltf
g
l
=
gdi
.
get
(
'libsharp_wrapper_gl'
)
l
ibsharp
=
gdi
.
get
(
'libsharp_wrapper_gl'
)
class
GLLMTransformation
(
Transformation
):
class
GLLMTransformation
(
SlicingTransformation
):
# ---Overwritten properties and methods---
def
__init__
(
self
,
domain
,
codomain
=
None
,
module
=
None
):
def
__init__
(
self
,
domain
,
codomain
=
None
,
module
=
None
):
if
'libsharp_wrapper_gl'
not
in
gdi
:
if
'libsharp_wrapper_gl'
not
in
gdi
:
raise
ImportError
(
"The module libsharp is needed but not available"
)
raise
ImportError
(
about
.
_errors
.
cstring
(
"The module libsharp is needed but not available."
))
if
codomain
is
None
:
super
(
GLLMTransformation
,
self
).
__init__
(
domain
,
codomain
,
module
)
self
.
domain
=
domain
self
.
codomain
=
self
.
get_codomain
(
domain
)
# ---Mandatory properties and methods---
elif
self
.
check_codomain
(
domain
,
codomain
):
self
.
domain
=
domain
self
.
codomain
=
codomain
else
:
raise
ValueError
(
"ERROR: Incompatible codomain!"
)
@
staticmethod
@
staticmethod
def
get_codomain
(
domain
):
def
get_codomain
(
domain
):
...
@@ -40,10 +38,12 @@ class GLLMTransformation(Transformation):
...
@@ -40,10 +38,12 @@ class GLLMTransformation(Transformation):
A compatible codomain.
A compatible codomain.
"""
"""
if
domain
is
None
:
if
domain
is
None
:
raise
ValueError
(
'ERROR: cannot generate codomain for None'
)
raise
ValueError
(
about
.
_errors
.
cstring
(
"ERROR: cannot generate codomain for None"
))
if
not
isinstance
(
domain
,
GLSpace
):
if
not
isinstance
(
domain
,
GLSpace
):
raise
TypeError
(
'ERROR: domain needs to be a GLSpace'
)
raise
TypeError
(
about
.
_errors
.
cstring
(
"ERROR: domain needs to be a GLSpace"
))
nlat
=
domain
.
nlat
nlat
=
domain
.
nlat
lmax
=
nlat
-
1
lmax
=
nlat
-
1
...
@@ -53,16 +53,18 @@ class GLLMTransformation(Transformation):
...
@@ -53,16 +53,18 @@ class GLLMTransformation(Transformation):
else
:
else
:
return
LMSpace
(
lmax
=
lmax
,
mmax
=
mmax
,
dtype
=
np
.
complex128
)
return
LMSpace
(
lmax
=
lmax
,
mmax
=
mmax
,
dtype
=
np
.
complex128
)
@
static
method
@
class
method
def
check_codomain
(
domain
,
codomain
):
def
check_codomain
(
cls
,
domain
,
codomain
):
if
not
isinstance
(
domain
,
GLSpace
):
if
not
isinstance
(
domain
,
GLSpace
):
raise
TypeError
(
'ERROR: domain is not a GLSpace'
)
raise
TypeError
(
about
.
_errors
.
cstring
(
"ERROR: domain is not a GLSpace"
))
if
codomain
is
None
:
if
codomain
is
None
:
return
False
return
False
if
not
isinstance
(
codomain
,
LMSpace
):
if
not
isinstance
(
codomain
,
LMSpace
):
raise
TypeError
(
'ERROR: codomain must be a LMSpace.'
)
raise
TypeError
(
about
.
_errors
.
cstring
(
"ERROR: codomain must be a LMSpace."
))
nlat
=
domain
.
nlat
nlat
=
domain
.
nlat
nlon
=
domain
.
nlon
nlon
=
domain
.
nlon
...
@@ -74,74 +76,45 @@ class GLLMTransformation(Transformation):
...
@@ -74,74 +76,45 @@ class GLLMTransformation(Transformation):
return
True
return
True
def
transform
(
self
,
val
,
axes
=
None
,
**
kwargs
):
# ---Added properties and methods---
"""
GL -> LM transform method.
Parameters
----------
val : np.ndarray or distributed_data_object
The value array which is to be transformed
axes : None or tuple
The axes along which the transformation should take place
"""
if
self
.
domain
.
discrete
:
val
=
self
.
domain
.
weight
(
val
,
power
=-
0.5
,
axes
=
axes
)
def
_transformation_of_slice
(
self
,
inp
):
# shorthands for transform parameters
# shorthands for transform parameters
nlat
=
self
.
domain
.
nlat
nlat
=
self
.
domain
.
nlat
nlon
=
self
.
domain
.
nlon
nlon
=
self
.
domain
.
nlon
lmax
=
self
.
codomain
.
lmax
lmax
=
self
.
codomain
.
lmax
mmax
=
self
.
codomain
.
mmax
mmax
=
self
.
codomain
.
mmax
if
isinstance
(
val
,
distributed_data_object
):
if
issubclass
(
inp
.
dtype
.
type
,
np
.
complexfloating
):
temp_val
=
val
.
get_full_data
()
else
:
[
resultReal
,
resultImag
]
=
[
self
.
libsharpMap2Alm
(
x
,
temp_val
=
val
nlat
=
nlat
,
nlon
=
nlon
,
return_val
=
None
lmax
=
lmax
,
mmax
=
mmax
)
for
slice_list
in
utilities
.
get_slice_list
(
temp_val
.
shape
,
axes
):
for
x
in
(
inp
.
real
,
inp
.
imag
)]
if
slice_list
==
[
slice
(
None
,
None
)]:
inp
=
temp_val
resultReal
=
ltf
.
buildIdx
(
resultReal
,
lmax
=
lmax
)
else
:
resultImag
=
ltf
.
buildIdx
(
resultImag
,
lmax
=
lmax
)
if
return_val
is
None
:
# construct correct complex dtype
return_val
=
np
.
empty_like
(
temp_val
)
one
=
resultReal
.
dtype
.
type
(
1
)
inp
=
temp_val
[
slice_list
]
result_dtype
=
np
.
dtype
(
type
(
one
+
1j
))
if
inp
.
dtype
>=
np
.
dtype
(
'complex64'
):
result
=
np
.
empty_like
(
resultReal
,
dtype
=
result_dtype
)
inpReal
=
self
.
GlMap2Alm
(
result
.
real
=
resultReal
np
.
real
(
inp
).
astype
(
np
.
float64
,
copy
=
False
),
nlat
=
nlat
,
result
.
imag
=
resultImag
nlon
=
nlon
,
lmax
=
lmax
,
mmax
=
mmax
)
inpImg
=
self
.
GlMap2Alm
(
np
.
imag
(
inp
).
astype
(
np
.
float64
,
copy
=
False
),
nlat
=
nlat
,
nlon
=
nlon
,
lmax
=
lmax
,
mmax
=
mmax
)
inpReal
=
ltf
.
buildIdx
(
inpReal
,
lmax
=
lmax
)
inpImg
=
ltf
.
buildIdx
(
inpImg
,
lmax
=
lmax
)
inp
=
inpReal
+
inpImg
*
1j
else
:
inp
=
self
.
GlMap2Alm
(
inp
,
nlat
=
nlat
,
nlon
=
nlon
,
lmax
=
lmax
,
mmax
=
mmax
)
inp
=
ltf
.
buildIdx
(
inp
,
lmax
=
lmax
)
if
slice_list
==
[
slice
(
None
,
None
)]:
return_val
=
inp
else
:
return_val
[
slice_list
]
=
inp
if
isinstance
(
val
,
distributed_data_object
):
new_val
=
val
.
copy_empty
(
dtype
=
self
.
codomain
.
dtype
)
new_val
.
set_full_data
(
return_val
,
copy
=
False
)
else
:
else
:
return_val
=
return_val
.
astype
(
self
.
codomain
.
dtype
,
copy
=
False
)
result
=
self
.
libsharpMap2Alm
(
inp
,
nlat
=
nlat
,
nlon
=
nlon
,
lmax
=
lmax
,
mmax
=
mmax
)
result
=
ltf
.
buildIdx
(
result
,
lmax
=
lmax
)
return
re
turn_val
return
re
sult
def
G
lMap2Alm
(
self
,
inp
,
**
kwargs
):
def
l
ibsharp
Map2Alm
(
self
,
inp
,
**
kwargs
):
if
inp
.
dtype
==
np
.
dtype
(
'float32'
):
if
inp
.
dtype
==
np
.
dtype
(
'float32'
):
return
gl
.
map2alm_f
(
inp
,
kwargs
)
return
libsharp
.
map2alm_f
(
inp
,
**
kwargs
)
elif
inp
.
dtype
==
np
.
dtype
(
'float64'
):
return
libsharp
.
map2alm
(
inp
,
**
kwargs
)
else
:
else
:
return
gl
.
map
.
alm
(
inp
,
kwargs
)
about
.
warnings
.
cprint
(
"WARNING: performing dtype conversion for "
"libsharp compatibility."
)
nifty/operators/fft_operator/transformations/rgrgtransformation.py
View file @
be3d197c
...
@@ -7,17 +7,13 @@ from nifty import RGSpace, nifty_configuration
...
@@ -7,17 +7,13 @@ from nifty import RGSpace, nifty_configuration
class
RGRGTransformation
(
Transformation
):
class
RGRGTransformation
(
Transformation
):
def
__init__
(
self
,
domain
,
codomain
=
None
,
module
=
None
):
def
__init__
(
self
,
domain
,
codomain
=
None
,
module
=
None
):
if
codomain
is
None
:
super
(
RGRGTransformation
,
self
).
__init__
(
domain
,
codomain
,
module
)
codomain
=
self
.
get_codomain
(
domain
)
else
:
if
not
self
.
check_codomain
(
domain
,
codomain
):
raise
ValueError
(
"ERROR: incompatible codomain!"
)
if
module
is
None
:
if
module
is
None
:
if
nifty_configuration
[
'fft_module'
]
==
'pyfftw'
:
if
nifty_configuration
[
'fft_module'
]
==
'pyfftw'
:
self
.
_transform
=
FFTW
(
domain
,
codomain
)
self
.
_transform
=
FFTW
(
domain
,
codomain
)
elif
nifty_configuration
[
'fft_module'
]
==
'gfft'
or
\
elif
(
nifty_configuration
[
'fft_module'
]
==
'gfft'
or
nifty_configuration
[
'fft_module'
]
==
'gfft_dummy'
:
nifty_configuration
[
'fft_module'
]
==
'gfft_dummy'
)
:
self
.
_transform
=
\
self
.
_transform
=
\
GFFT
(
domain
,
GFFT
(
domain
,
codomain
,
codomain
,
...
@@ -73,7 +69,9 @@ class RGRGTransformation(Transformation):
...
@@ -73,7 +69,9 @@ class RGRGTransformation(Transformation):
distances
=
1
/
(
np
.
array
(
domain
.
shape
)
*
distances
=
1
/
(
np
.
array
(
domain
.
shape
)
*
np
.
array
(
domain
.
distances
))
np
.
array
(
domain
.
distances
))
if
dtype
is
None
:
if
dtype
is
None
:
dtype
=
np
.
complex
# create a definitely complex dtype from the dtype of domain
one
=
domain
.
dtype
.
type
(
1
)
dtype
=
np
.
dtype
(
type
(
one
+
1j
))
new_space
=
RGSpace
(
domain
.
shape
,
new_space
=
RGSpace
(
domain
.
shape
,
zerocenter
=
zerocenter
,
zerocenter
=
zerocenter
,
...
@@ -86,7 +84,7 @@ class RGRGTransformation(Transformation):
...
@@ -86,7 +84,7 @@ class RGRGTransformation(Transformation):
@
staticmethod
@
staticmethod
def
check_codomain
(
domain
,
codomain
):
def
check_codomain
(
domain
,
codomain
):
if
not
isinstance
(
domain
,
RGSpace
):
if
not
isinstance
(
domain
,
RGSpace
):
raise
TypeError
(
'ERROR: domain
must be
a RGSpace'
)
raise
TypeError
(
'ERROR: domain
is not
a RGSpace'
)
if
codomain
is
None
:
if
codomain
is
None
:
return
False
return
False
...
@@ -101,6 +99,11 @@ class RGRGTransformation(Transformation):
...
@@ -101,6 +99,11 @@ class RGRGTransformation(Transformation):
if
domain
.
harmonic
==
codomain
.
harmonic
:
if
domain
.
harmonic
==
codomain
.
harmonic
:
return
False
return
False
if
codomain
.
harmonic
and
not
issubclass
(
codomain
.
dtype
.
type
,
np
.
complexfloating
):
about
.
warnings
.
cprint
(
"WARNING: codomain is harmonic but dtype is real."
)
# Check if the distances match, i.e. dist' = 1 / (num * dist)
# Check if the distances match, i.e. dist' = 1 / (num * dist)
if
not
np
.
all
(
if
not
np
.
all
(
np
.
absolute
(
np
.
array
(
domain
.
shape
)
*
np
.
absolute
(
np
.
array
(
domain
.
shape
)
*
...
...
nifty/operators/fft_operator/transformations/slicing_transformation.py
0 → 100644
View file @
be3d197c
# -*- coding: utf-8 -*-
import
abc
import
numpy
as
np
import
nifty.nifty_utilities
as
utilities
from
transformation
import
Transformation
class
SlicingTransformation
(
Transformation
):
def
transform
(
self
,
val
,
axes
=
None
,
**
kwargs
):
return_shape
=
np
.
array
(
val
.
shape
)
return_shape
[
list
(
axes
)]
=
self
.
codomain
.
shape
return_shape
=
tuple
(
return_shape
)
return_val
=
None
for
slice_list
in
utilities
.
get_slice_list
(
val
.
shape
,
axes
):
if
return_val
is
None
:
return_val
=
val
.
copy_empty
(
dtype
=
self
.
codomain
.
dtype
,
global_shape
=
return_shape
)
data
=
val
[
slice_list
]
data
=
data
.
get_full_data
()
data
=
self
.
_transformation_of_slice
(
data
)
return_val
[
slice_list
]
=
data
return
return_val
@
abc
.
abstractmethod
def
_transformation_of_slice
(
self
,
inp
):
raise
NotImplementedError
nifty/operators/fft_operator/transformations/transformation.py
View file @
be3d197c
import
abc
class
Transformation
(
object
):
class
Transformation
(
object
):
"""
"""
A generic transformation which defines a static check_codomain
A generic transformation which defines a static check_codomain
method for all transforms.
method for all transforms.
"""
"""
__metaclass__
=
abc
.
ABCMeta
def
__init__
(
self
,
domain
,
codomain
=
None
,
module
=
None
):
def
__init__
(
self
,
domain
,
codomain
=
None
,
module
=
None
):
pass
if
codomain
is
None
:
self
.
domain
=
domain
self
.
codomain
=
self
.
get_codomain
(
domain
)
elif
self
.
check_codomain
(
domain
,
codomain
):
self
.
domain
=
domain
self
.
codomain
=
codomain
else
:
raise
ValueError
(
"ERROR: Incompatible codomain!"
)
@
staticmethod
def
get_codomain
(
domain
,
dtype
=
None
,
zerocenter
=
None
,
**
kwargs
):
raise
NotImplementedError
@
staticmethod
def
check_codomain
(
domain
,
codomain
):
raise
NotImplementedError
def
transform
(
self
,
val
,
axes
=
None
,
**
kwargs
):
def
transform
(
self
,
val
,
axes
=
None
,
**
kwargs
):
raise
NotImplementedError
raise
NotImplementedError
nifty/operators/fft_operator/transformations/transformation_cache.py
0 → 100644
View file @
be3d197c
class
_TransformationCache
(
object
):
def
__init__
(
self
):
self
.
cache
=
{}
def
create
(
self
,
transformation_class
,
domain
,
codomain
,
module
):
key
=
domain
.
__hash__
()
^
((
codomain
.
__hash__
()
/
111
)
^
(
module
.
__hash__
())
/
179
)
if
key
not
in
self
.
cache
:
self
.
cache
[
key
]
=
transformation_class
(
domain
,
codomain
,
module
)
return
self
.
cache
[
key
]
TransformationCache
=
_TransformationCache
()
nifty/operators/fft_operator/transformations/transformation_factory.py
deleted
100644 → 0
View file @
007da99a
from
nifty.spaces
import
RGSpace
,
GLSpace
,
HPSpace
,
LMSpace
from
rgrgtransformation
import
RGRGTransformation
from
gllmtransformation
import
GLLMTransformation
from
hplmtransformation
import
HPLMTransformation
from
lmgltransformation
import
LMGLTransformation
from
lmhptransformation
import
LMHPTransformation
class
_TransformationFactory
(
object
):
"""
Transform factory which generates transform objects
"""
def
__init__
(
self
):
# cache for storing the transform objects
self
.
cache
=
{}
def
_get_transform
(
self
,
domain
,
codomain
,
module
):
if
isinstance
(
domain
,
RGSpace
):
if
isinstance
(
codomain
,
RGSpace
):
return
RGRGTransformation
(
domain
,
codomain
,
module
)
else
:
raise
ValueError
(
'ERROR: incompatible codomain'
)
elif
isinstance
(
domain
,
GLSpace
):
if
isinstance
(
codomain
,
LMSpace
):
return
GLLMTransformation
(
domain
,
codomain
,
module
)
else
:
raise
ValueError
(
'ERROR: incompatible codomain'
)
elif
isinstance
(
domain
,
HPSpace
):
if
isinstance
(
codomain
,
LMSpace
):
return
HPLMTransformation
(
domain
,
codomain
,
module
)
else
:
raise
ValueError
(
'ERROR: incompatible codomain'
)
elif
isinstance
(
domain
,
LMSpace
):
if
isinstance
(
codomain
,<