Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
37582f42
Commit
37582f42
authored
Feb 07, 2017
by
Theo Steininger
Browse files
Unified spaces and field_types into single domain object.
parent
e38eae92
Pipeline
#9997
passed with stage
in 33 minutes and 35 seconds
Changes
17
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty/domain_object.py
0 → 100644
View file @
37582f42
# -*- coding: utf-8 -*-
import
abc
import
numpy
as
np
from
keepers
import
Loggable
,
\
Versionable
class
DomainObject
(
Versionable
,
Loggable
,
object
):
__metaclass__
=
abc
.
ABCMeta
def
__init__
(
self
,
dtype
):
self
.
_dtype
=
np
.
dtype
(
dtype
)
self
.
_ignore_for_hash
=
[]
def
__hash__
(
self
):
# Extract the identifying parts from the vars(self) dict.
result_hash
=
0
for
key
in
sorted
(
vars
(
self
).
keys
()):
item
=
vars
(
self
)[
key
]
if
key
in
self
.
_ignore_for_hash
or
key
==
'_ignore_for_hash'
:
continue
result_hash
^=
item
.
__hash__
()
^
int
(
hash
(
key
)
/
117
)
return
result_hash
def
__eq__
(
self
,
x
):
if
isinstance
(
x
,
type
(
self
)):
return
hash
(
self
)
==
hash
(
x
)
else
:
return
False
def
__ne__
(
self
,
x
):
return
not
self
.
__eq__
(
x
)
@
property
def
dtype
(
self
):
return
self
.
_dtype
@
abc
.
abstractproperty
def
shape
(
self
):
raise
NotImplementedError
(
"There is no generic shape for DomainObject."
)
@
abc
.
abstractproperty
def
dim
(
self
):
raise
NotImplementedError
(
"There is no generic dim for DomainObject."
)
def
pre_cast
(
self
,
x
,
axes
=
None
):
return
x
def
post_cast
(
self
,
x
,
axes
=
None
):
return
x
# ---Serialization---
def
_to_hdf5
(
self
,
hdf5_group
):
hdf5_group
.
attrs
[
'dtype'
]
=
self
.
dtype
.
name
return
None
@
classmethod
def
_from_hdf5
(
cls
,
hdf5_group
,
repository
):
result
=
cls
(
dtype
=
np
.
dtype
(
hdf5_group
.
attrs
[
'dtype'
]))
return
result
nifty/field.py
View file @
37582f42
...
...
@@ -9,9 +9,8 @@ from d2o import distributed_data_object,\
from
nifty.config
import
nifty_configuration
as
gc
from
nifty.
field_types
import
FieldType
from
nifty.
domain_object
import
DomainObject
from
nifty.spaces.space
import
Space
from
nifty.spaces.power_space
import
PowerSpace
import
nifty.nifty_utilities
as
utilities
...
...
@@ -21,25 +20,15 @@ from nifty.random import Random
class
Field
(
Loggable
,
Versionable
,
object
):
# ---Initialization methods---
def
__init__
(
self
,
domain
=
None
,
val
=
None
,
dtype
=
None
,
field_type
=
None
,
def
__init__
(
self
,
domain
=
None
,
val
=
None
,
dtype
=
None
,
distribution_strategy
=
None
,
copy
=
False
):
self
.
domain
=
self
.
_parse_domain
(
domain
=
domain
,
val
=
val
)
self
.
domain_axes
=
self
.
_get_axes_tuple
(
self
.
domain
)
self
.
field_type
=
self
.
_parse_field_type
(
field_type
,
val
=
val
)
try
:
start
=
len
(
reduce
(
lambda
x
,
y
:
x
+
y
,
self
.
domain_axes
))
except
TypeError
:
start
=
0
self
.
field_type_axes
=
self
.
_get_axes_tuple
(
self
.
field_type
,
start
=
start
)
self
.
dtype
=
self
.
_infer_dtype
(
dtype
=
dtype
,
val
=
val
,
domain
=
self
.
domain
,
field_type
=
self
.
field_type
)
domain
=
self
.
domain
)
self
.
distribution_strategy
=
self
.
_parse_distribution_strategy
(
distribution_strategy
=
distribution_strategy
,
...
...
@@ -53,34 +42,18 @@ class Field(Loggable, Versionable, object):
domain
=
val
.
domain
else
:
domain
=
()
elif
isinstance
(
domain
,
Space
):
elif
isinstance
(
domain
,
DomainObject
):
domain
=
(
domain
,)
elif
not
isinstance
(
domain
,
tuple
):
domain
=
tuple
(
domain
)
for
d
in
domain
:
if
not
isinstance
(
d
,
Space
):
if
not
isinstance
(
d
,
DomainObject
):
raise
TypeError
(
"Given domain contains something that is not a "
"
nifty.spa
ce."
)
"
DomainObject instan
ce."
)
return
domain
def
_parse_field_type
(
self
,
field_type
,
val
=
None
):
if
field_type
is
None
:
if
isinstance
(
val
,
Field
):
field_type
=
val
.
field_type
else
:
field_type
=
()
elif
isinstance
(
field_type
,
FieldType
):
field_type
=
(
field_type
,)
elif
not
isinstance
(
field_type
,
tuple
):
field_type
=
tuple
(
field_type
)
for
ft
in
field_type
:
if
not
isinstance
(
ft
,
FieldType
):
raise
TypeError
(
"Given object is not a nifty.FieldType."
)
return
field_type
def
_get_axes_tuple
(
self
,
things_with_shape
,
start
=
0
):
i
=
start
axes_list
=
[]
...
...
@@ -92,7 +65,7 @@ class Field(Loggable, Versionable, object):
axes_list
+=
[
tuple
(
l
)]
return
tuple
(
axes_list
)
def
_infer_dtype
(
self
,
dtype
,
val
,
domain
,
field_type
):
def
_infer_dtype
(
self
,
dtype
,
val
,
domain
):
if
dtype
is
None
:
if
isinstance
(
val
,
Field
)
or
\
isinstance
(
val
,
distributed_data_object
):
...
...
@@ -102,8 +75,6 @@ class Field(Loggable, Versionable, object):
dtype_tuple
=
(
np
.
dtype
(
dtype
),)
if
domain
is
not
None
:
dtype_tuple
+=
tuple
(
np
.
dtype
(
sp
.
dtype
)
for
sp
in
domain
)
if
field_type
is
not
None
:
dtype_tuple
+=
tuple
(
np
.
dtype
(
ft
.
dtype
)
for
ft
in
field_type
)
dtype
=
reduce
(
lambda
x
,
y
:
np
.
result_type
(
x
,
y
),
dtype_tuple
)
...
...
@@ -127,10 +98,10 @@ class Field(Loggable, Versionable, object):
# ---Factory methods---
@
classmethod
def
from_random
(
cls
,
random_type
,
domain
=
None
,
dtype
=
None
,
field_type
=
None
,
def
from_random
(
cls
,
random_type
,
domain
=
None
,
dtype
=
None
,
distribution_strategy
=
None
,
**
kwargs
):
# create a initially empty field
f
=
cls
(
domain
=
domain
,
dtype
=
dtype
,
field_type
=
field_type
,
f
=
cls
(
domain
=
domain
,
dtype
=
dtype
,
distribution_strategy
=
distribution_strategy
)
# now use the processed input in terms of f in order to parse the
...
...
@@ -363,7 +334,6 @@ class Field(Loggable, Versionable, object):
std
=
std
,
domain
=
result_domain
,
dtype
=
harmonic_domain
.
dtype
,
field_type
=
self
.
field_type
,
distribution_strategy
=
self
.
distribution_strategy
)
for
x
in
result_list
]
...
...
@@ -451,9 +421,7 @@ class Field(Loggable, Versionable, object):
@
property
def
shape
(
self
):
shape_tuple
=
()
shape_tuple
+=
tuple
(
sp
.
shape
for
sp
in
self
.
domain
)
shape_tuple
+=
tuple
(
ft
.
shape
for
ft
in
self
.
field_type
)
shape_tuple
=
tuple
(
sp
.
shape
for
sp
in
self
.
domain
)
try
:
global_shape
=
reduce
(
lambda
x
,
y
:
x
+
y
,
shape_tuple
)
except
TypeError
:
...
...
@@ -463,9 +431,7 @@ class Field(Loggable, Versionable, object):
@
property
def
dim
(
self
):
dim_tuple
=
()
dim_tuple
+=
tuple
(
sp
.
dim
for
sp
in
self
.
domain
)
dim_tuple
+=
tuple
(
ft
.
dim
for
ft
in
self
.
field_type
)
dim_tuple
=
tuple
(
sp
.
dim
for
sp
in
self
.
domain
)
try
:
return
reduce
(
lambda
x
,
y
:
x
*
y
,
dim_tuple
)
except
TypeError
:
...
...
@@ -500,20 +466,12 @@ class Field(Loggable, Versionable, object):
casted_x
=
sp
.
pre_cast
(
casted_x
,
axes
=
self
.
domain_axes
[
ind
])
for
ind
,
ft
in
enumerate
(
self
.
field_type
):
casted_x
=
ft
.
pre_cast
(
casted_x
,
axes
=
self
.
field_type_axes
[
ind
])
casted_x
=
self
.
_actual_cast
(
casted_x
,
dtype
=
dtype
)
for
ind
,
sp
in
enumerate
(
self
.
domain
):
casted_x
=
sp
.
post_cast
(
casted_x
,
axes
=
self
.
domain_axes
[
ind
])
for
ind
,
ft
in
enumerate
(
self
.
field_type
):
casted_x
=
ft
.
post_cast
(
casted_x
,
axes
=
self
.
field_type_axes
[
ind
])
return
casted_x
def
_actual_cast
(
self
,
x
,
dtype
=
None
):
...
...
@@ -530,19 +488,16 @@ class Field(Loggable, Versionable, object):
return_x
.
set_full_data
(
x
,
copy
=
False
)
return
return_x
def
copy
(
self
,
domain
=
None
,
dtype
=
None
,
field_type
=
None
,
distribution_strategy
=
None
):
def
copy
(
self
,
domain
=
None
,
dtype
=
None
,
distribution_strategy
=
None
):
copied_val
=
self
.
get_val
(
copy
=
True
)
new_field
=
self
.
copy_empty
(
domain
=
domain
,
dtype
=
dtype
,
field_type
=
field_type
,
distribution_strategy
=
distribution_strategy
)
new_field
.
set_val
(
new_val
=
copied_val
,
copy
=
False
)
return
new_field
def
copy_empty
(
self
,
domain
=
None
,
dtype
=
None
,
field_type
=
None
,
distribution_strategy
=
None
):
def
copy_empty
(
self
,
domain
=
None
,
dtype
=
None
,
distribution_strategy
=
None
):
if
domain
is
None
:
domain
=
self
.
domain
else
:
...
...
@@ -553,11 +508,6 @@ class Field(Loggable, Versionable, object):
else
:
dtype
=
np
.
dtype
(
dtype
)
if
field_type
is
None
:
field_type
=
self
.
field_type
else
:
field_type
=
self
.
_parse_field_type
(
field_type
)
if
distribution_strategy
is
None
:
distribution_strategy
=
self
.
distribution_strategy
...
...
@@ -567,10 +517,6 @@ class Field(Loggable, Versionable, object):
if
self
.
domain
[
i
]
is
not
domain
[
i
]:
fast_copyable
=
False
break
for
i
in
xrange
(
len
(
self
.
field_type
)):
if
self
.
field_type
[
i
]
is
not
field_type
[
i
]:
fast_copyable
=
False
break
except
IndexError
:
fast_copyable
=
False
...
...
@@ -580,7 +526,6 @@ class Field(Loggable, Versionable, object):
else
:
new_field
=
Field
(
domain
=
domain
,
dtype
=
dtype
,
field_type
=
field_type
,
distribution_strategy
=
distribution_strategy
)
return
new_field
...
...
@@ -626,8 +571,6 @@ class Field(Loggable, Versionable, object):
assert
len
(
x
.
domain
)
==
len
(
self
.
domain
)
for
index
in
xrange
(
len
(
self
.
domain
)):
assert
x
.
domain
[
index
]
==
self
.
domain
[
index
]
for
index
in
xrange
(
len
(
self
.
field_type
)):
assert
x
.
field_type
[
index
]
==
self
.
field_type
[
index
]
except
AssertionError
:
raise
ValueError
(
"domains are incompatible."
)
...
...
@@ -707,22 +650,15 @@ class Field(Loggable, Versionable, object):
return_field
.
set_val
(
new_val
,
copy
=
False
)
return
return_field
def
_contraction_helper
(
self
,
op
,
spaces
,
types
):
def
_contraction_helper
(
self
,
op
,
spaces
):
# build a list of all axes
if
spaces
is
None
:
spaces
=
xrange
(
len
(
self
.
domain
))
else
:
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
self
.
domain
))
if
types
is
None
:
types
=
xrange
(
len
(
self
.
field_type
))
else
:
types
=
utilities
.
cast_axis_to_tuple
(
types
,
len
(
self
.
field_type
))
axes_list
=
tuple
(
self
.
domain_axes
[
sp_index
]
for
sp_index
in
spaces
)
axes_list
=
()
axes_list
+=
tuple
(
self
.
domain_axes
[
sp_index
]
for
sp_index
in
spaces
)
axes_list
+=
tuple
(
self
.
field_type_axes
[
ft_index
]
for
ft_index
in
types
)
try
:
axes_list
=
reduce
(
lambda
x
,
y
:
x
+
y
,
axes_list
)
except
TypeError
:
...
...
@@ -739,47 +675,44 @@ class Field(Loggable, Versionable, object):
return_domain
=
tuple
(
self
.
domain
[
i
]
for
i
in
xrange
(
len
(
self
.
domain
))
if
i
not
in
spaces
)
return_field_type
=
tuple
(
self
.
field_type
[
i
]
for
i
in
xrange
(
len
(
self
.
field_type
))
if
i
not
in
types
)
return_field
=
Field
(
domain
=
return_domain
,
val
=
data
,
field_type
=
return_field_type
,
copy
=
False
)
return
return_field
def
sum
(
self
,
spaces
=
None
,
types
=
None
):
return
self
.
_contraction_helper
(
'sum'
,
spaces
,
types
)
def
sum
(
self
,
spaces
=
None
):
return
self
.
_contraction_helper
(
'sum'
,
spaces
)
def
prod
(
self
,
spaces
=
None
,
types
=
None
):
return
self
.
_contraction_helper
(
'prod'
,
spaces
,
types
)
def
prod
(
self
,
spaces
=
None
):
return
self
.
_contraction_helper
(
'prod'
,
spaces
)
def
all
(
self
,
spaces
=
None
,
types
=
None
):
return
self
.
_contraction_helper
(
'all'
,
spaces
,
types
)
def
all
(
self
,
spaces
=
None
):
return
self
.
_contraction_helper
(
'all'
,
spaces
)
def
any
(
self
,
spaces
=
None
,
types
=
None
):
return
self
.
_contraction_helper
(
'any'
,
spaces
,
types
)
def
any
(
self
,
spaces
=
None
):
return
self
.
_contraction_helper
(
'any'
,
spaces
)
def
min
(
self
,
spaces
=
None
,
types
=
None
):
return
self
.
_contraction_helper
(
'min'
,
spaces
,
types
)
def
min
(
self
,
spaces
=
None
):
return
self
.
_contraction_helper
(
'min'
,
spaces
)
def
nanmin
(
self
,
spaces
=
None
,
types
=
None
):
return
self
.
_contraction_helper
(
'nanmin'
,
spaces
,
types
)
def
nanmin
(
self
,
spaces
=
None
):
return
self
.
_contraction_helper
(
'nanmin'
,
spaces
)
def
max
(
self
,
spaces
=
None
,
types
=
None
):
return
self
.
_contraction_helper
(
'max'
,
spaces
,
types
)
def
max
(
self
,
spaces
=
None
):
return
self
.
_contraction_helper
(
'max'
,
spaces
)
def
nanmax
(
self
,
spaces
=
None
,
types
=
None
):
return
self
.
_contraction_helper
(
'nanmax'
,
spaces
,
types
)
def
nanmax
(
self
,
spaces
=
None
):
return
self
.
_contraction_helper
(
'nanmax'
,
spaces
)
def
mean
(
self
,
spaces
=
None
,
types
=
None
):
return
self
.
_contraction_helper
(
'mean'
,
spaces
,
types
)
def
mean
(
self
,
spaces
=
None
):
return
self
.
_contraction_helper
(
'mean'
,
spaces
)
def
var
(
self
,
spaces
=
None
,
types
=
None
):
return
self
.
_contraction_helper
(
'var'
,
spaces
,
types
)
def
var
(
self
,
spaces
=
None
):
return
self
.
_contraction_helper
(
'var'
,
spaces
)
def
std
(
self
,
spaces
=
None
,
types
=
None
):
return
self
.
_contraction_helper
(
'std'
,
spaces
,
types
)
def
std
(
self
,
spaces
=
None
):
return
self
.
_contraction_helper
(
'std'
,
spaces
)
# ---General binary methods---
...
...
@@ -790,9 +723,6 @@ class Field(Loggable, Versionable, object):
assert
len
(
other
.
domain
)
==
len
(
self
.
domain
)
for
index
in
xrange
(
len
(
self
.
domain
)):
assert
other
.
domain
[
index
]
==
self
.
domain
[
index
]
assert
len
(
other
.
field_type
)
==
len
(
self
.
field_type
)
for
index
in
xrange
(
len
(
self
.
field_type
)):
assert
other
.
field_type
[
index
]
==
self
.
field_type
[
index
]
except
AssertionError
:
raise
ValueError
(
"domains are incompatible."
)
...
...
@@ -895,19 +825,14 @@ class Field(Loggable, Versionable, object):
def
_to_hdf5
(
self
,
hdf5_group
):
hdf5_group
.
attrs
[
'dtype'
]
=
self
.
dtype
.
name
hdf5_group
.
attrs
[
'distribution_strategy'
]
=
self
.
distribution_strategy
hdf5_group
.
attrs
[
'field_type_axes'
]
=
str
(
self
.
field_type_axes
)
hdf5_group
.
attrs
[
'domain_axes'
]
=
str
(
self
.
domain_axes
)
hdf5_group
[
'num_domain'
]
=
len
(
self
.
domain
)
hdf5_group
[
'num_ft'
]
=
len
(
self
.
field_type
)
ret_dict
=
{
'val'
:
self
.
val
}
for
i
in
range
(
len
(
self
.
domain
)):
ret_dict
[
's_'
+
str
(
i
)]
=
self
.
domain
[
i
]
for
i
in
range
(
len
(
self
.
field_type
)):
ret_dict
[
'ft_'
+
str
(
i
)]
=
self
.
field_type
[
i
]
return
ret_dict
@
classmethod
...
...
@@ -922,14 +847,7 @@ class Field(Loggable, Versionable, object):
temp_domain
.
append
(
repository
.
get
(
's_'
+
str
(
i
),
hdf5_group
))
new_field
.
domain
=
tuple
(
temp_domain
)
temp_ft
=
[]
for
i
in
range
(
hdf5_group
[
'num_ft'
][()]):
temp_domain
.
append
(
repository
.
get
(
'ft_'
+
str
(
i
),
hdf5_group
))
new_field
.
field_type
=
tuple
(
temp_ft
)
exec
(
'new_field.domain_axes = '
+
hdf5_group
.
attrs
[
'domain_axes'
])
exec
(
'new_field.field_type_axes = '
+
hdf5_group
.
attrs
[
'field_type_axes'
])
new_field
.
_val
=
repository
.
get
(
'val'
,
hdf5_group
)
new_field
.
dtype
=
np
.
dtype
(
hdf5_group
.
attrs
[
'dtype'
])
new_field
.
distribution_strategy
=
\
...
...
nifty/field_types/field_array.py
View file @
37582f42
# -*- coding: utf-8 -*-
import
pickle
from
field_type
import
FieldType
class
FieldArray
(
FieldType
):
def
__init__
(
self
,
dtype
,
shape
):
try
:
new_shape
=
tuple
([
int
(
i
)
for
i
in
shape
])
except
TypeError
:
new_shape
=
(
int
(
shape
),
)
self
.
_shape
=
new_shape
super
(
FieldArray
,
self
).
__init__
(
dtype
=
dtype
)
@
property
def
shape
(
self
):
return
self
.
_shape
@
property
def
dim
(
self
):
return
reduce
(
lambda
x
,
y
:
x
*
y
,
self
.
shape
)
# ---Serialization---
def
_to_hdf5
(
self
,
hdf5_group
):
hdf5_group
[
'shape'
]
=
self
.
shape
hdf5_group
[
'dtype'
]
=
pickle
.
dumps
(
self
.
dtype
)
return
None
@
classmethod
def
_from_hdf5
(
cls
,
hdf5_group
,
loopback_get
):
result
=
cls
(
hdf5_group
[
'shape'
][:],
pickle
.
loads
(
hdf5_group
[
'dtype'
][()])
)
return
result
nifty/field_types/field_type.py
View file @
37582f42
# -*- coding: utf-8 -*-
import
pickle
import
numpy
as
np
from
keepers
import
Versionable
from
nifty.domain_object
import
DomainObject
class
FieldType
(
Versionable
,
object
):
def
__init__
(
self
,
shape
,
dtype
):
try
:
new_shape
=
tuple
([
int
(
i
)
for
i
in
shape
])
except
TypeError
:
new_shape
=
(
int
(
shape
),
)
self
.
_shape
=
new_shape
self
.
_dtype
=
np
.
dtype
(
dtype
)
def
__hash__
(
self
):
# Extract the identifying parts from the vars(self) dict.
result_hash
=
0
for
(
key
,
item
)
in
vars
(
self
).
items
():
result_hash
^=
item
.
__hash__
()
^
int
(
hash
(
key
)
/
117
)
return
result_hash
def
__eq__
(
self
,
x
):
if
isinstance
(
x
,
type
(
self
)):
return
hash
(
self
)
==
hash
(
x
)
else
:
return
False
@
property
def
shape
(
self
):
return
self
.
_shape
@
property
def
dtype
(
self
):
return
self
.
_dtype
@
property
def
dim
(
self
):
raise
NotImplementedError
class
FieldType
(
DomainObject
):
def
process
(
self
,
method_name
,
array
,
inplace
=
True
,
**
kwargs
):
try
:
...
...
@@ -52,25 +17,3 @@ class FieldType(Versionable, object):
result_array
=
array
.
copy
()
return
result_array
def
pre_cast
(
self
,
x
,
axes
=
None
):
return
x
def
post_cast
(
self
,
x
,
axes
=
None
):
return
x
# ---Serialization---
def
_to_hdf5
(
self
,
hdf5_group
):
hdf5_group
[
'shape'
]
=
self
.
shape
hdf5_group
[
'dtype'
]
=
pickle
.
dumps
(
self
.
dtype
)
return
None
@
classmethod
def
_from_hdf5
(
cls
,
hdf5_group
,
loopback_get
):
result
=
cls
(
hdf5_group
[
'shape'
][:],
pickle
.
loads
(
hdf5_group
[
'dtype'
][()])
)
return
result
nifty/nifty_utilities.py
View file @
37582f42
...
...
@@ -281,33 +281,17 @@ def get_default_codomain(domain):
def
parse_domain
(
domain
):
from
nifty.
spaces.space
import
Space
from
nifty.
domain_object
import
DomainObject
if
domain
is
None
:
domain
=
()
elif
isinstance
(
domain
,
Space
):
elif
isinstance
(
domain
,
DomainObject
):
domain
=
(
domain
,)
elif
not
isinstance
(
domain
,
tuple
):
domain
=
tuple
(
domain
)
for
d
in
domain
:
if
not
isinstance
(
d
,
Space
):
if
not
isinstance
(
d
,
DomainObject
):
raise
TypeError
(
"Given object contains something that is not a "
"
nifty.space
."
)
"
instance of DomainObject-class
."
)
return
domain
def
parse_field_type
(
field_type
):
from
nifty.field_types
import
FieldType
if
field_type
is
None
:
field_type
=
()
elif
isinstance
(
field_type
,
FieldType
):
field_type
=
(
field_type
,)
elif
not
isinstance
(
field_type
,
tuple
):
field_type
=
tuple
(
field_type
)
for
ft
in
field_type
:
if
not
isinstance
(
ft
,
FieldType
):
raise
TypeError
(
"Given object is not a nifty.FieldType."
)
return
field_type
nifty/operators/composed_operator/composed_operator.py
View file @
37582f42
...
...
@@ -13,13 +13,13 @@ class ComposedOperator(LinearOperator):
"instances of the LinearOperator-baseclass"
)
self
.
_operator_store
+=
(
op
,)
def
_check_input_compatibility
(
self
,
x
,
spaces
,
types
,
inverse
=
False
):
def
_check_input_compatibility
(
self
,
x
,
spaces
,
inverse
=
False
):
"""
The input check must be disabled for the ComposedOperator, since it
is not easily forecasteable what the output of an operator-call
will look like.
"""
return
(
spaces
,
types
)
return
spaces
# ---Mandatory properties and methods---
@
property
...
...
@@ -38,22 +38,6 @@ class ComposedOperator(LinearOperator):
self
.
_target
+=
op
.
target
return
self
.
_target
@
property
def
field_type
(
self
):
if
not
hasattr
(
self
,
'_field_type'
):
self
.
_field_type
=
()
for
op
in
self
.
_operator_store
:
self
.
_field_type
+=
op
.
field_type
return
self
.
_field_type
@
property
def
field_type_target
(
self
):
if
not
hasattr
(
self
,
'_field_type_target'
):
self
.
_field_type_target
=
()
for
op
in
self
.
_operator_store
:
self
.
_field_type_target
+=
op
.
field_type_target
return
self
.
_field_type_target
@
property