Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
cb713171
Commit
cb713171
authored
Aug 24, 2016
by
theos
Browse files
Fixed the return type of the FFTOperator.
Fixed violation of LSP for implemented keyword/property.
parent
28b305d3
Changes
5
Hide whitespace changes
Inline
Side-by-side
nifty/field.py
View file @
cb713171
...
...
@@ -49,14 +49,17 @@ class Field(object):
self
.
set_val
(
new_val
=
val
,
copy
=
copy
)
def
_parse_domain
(
self
,
domain
,
val
):
def
_parse_domain
(
self
,
domain
,
val
=
None
):
if
domain
is
None
:
if
isinstance
(
val
,
Field
):
domain
=
val
.
domain
else
:
domain
=
()
elif
not
isinstance
(
domain
,
tupl
e
):
elif
isinstance
(
domain
,
Spac
e
):
domain
=
(
domain
,)
elif
not
isinstance
(
domain
,
tuple
):
domain
=
tuple
(
domain
)
for
d
in
domain
:
if
not
isinstance
(
d
,
Space
):
raise
TypeError
(
about
.
_errors
.
cstring
(
...
...
@@ -64,14 +67,16 @@ class Field(object):
"nifty.space."
))
return
domain
def
_parse_field_type
(
self
,
field_type
,
val
):
def
_parse_field_type
(
self
,
field_type
,
val
=
None
):
if
field_type
is
None
:
if
isinstance
(
val
,
Field
):
field_type
=
val
.
field_type
else
:
field_type
=
()
elif
not
isinstance
(
field_type
,
tupl
e
):
elif
isinstance
(
field_type
,
FieldTyp
e
):
field_type
=
(
field_type
,)
elif
not
isinstance
(
field_type
,
tuple
):
field_type
=
tuple
(
field_type
)
for
ft
in
field_type
:
if
not
isinstance
(
ft
,
FieldType
):
raise
TypeError
(
about
.
_errors
.
cstring
(
...
...
nifty/operators/diagonal_operator/diagonal_operator.py
View file @
cb713171
...
...
@@ -16,11 +16,13 @@ class DiagonalOperator(EndomorphicOperator):
# ---Overwritten properties and methods---
def
__init__
(
self
,
domain
=
(),
field_type
=
(),
implemented
=
False
,
diagonal
=
None
,
bare
=
False
,
datamodel
=
None
,
copy
=
True
):
diagonal
=
None
,
bare
=
False
,
copy
=
True
,
datamodel
=
None
):
super
(
DiagonalOperator
,
self
).
__init__
(
domain
=
domain
,
field_type
=
field_type
,
implemented
=
implemented
)
self
.
_implemented
=
bool
(
implemented
)
if
datamodel
is
None
:
if
isinstance
(
diagonal
,
distributed_data_object
):
datamodel
=
diagonal
.
distribution_strategy
...
...
@@ -80,6 +82,10 @@ class DiagonalOperator(EndomorphicOperator):
# ---Mandatory properties and methods---
@
property
def
implemented
(
self
):
return
self
.
_implemented
@
property
def
symmetric
(
self
):
return
self
.
_symmetric
...
...
@@ -116,8 +122,16 @@ class DiagonalOperator(EndomorphicOperator):
datamodel
=
self
.
datamodel
,
copy
=
copy
)
# weight if the given values were `bare`
f
.
weight
(
inplace
=
True
)
# weight if the given values were `bare` and `implemented` is True
# do inverse weightening if the other way around
if
bare
and
self
.
implemented
:
# If `copy` is True, we won't change external data by weightening
# Otherwise, inplace weightening would change the external field
f
.
weight
(
inplace
=
copy
)
elif
not
bare
and
not
self
.
implemented
:
# If `copy` is True, we won't change external data by weightening
# Otherwise, inplace weightening would change the external field
f
.
weight
(
inplace
=
copy
,
power
=-
1
)
# check if the operator is symmetric:
self
.
_symmetric
=
(
f
.
val
.
imag
==
0
).
all
()
...
...
@@ -127,4 +141,3 @@ class DiagonalOperator(EndomorphicOperator):
# store the diagonal-field
self
.
_diagonal
=
f
nifty/operators/fft_operator/__init__.py
View file @
cb713171
from
transformations
import
*
from
fft_operator
import
FFTOperator
\ No newline at end of file
from
fft_operator
import
FFTOperator
nifty/operators/fft_operator/fft_operator.py
View file @
cb713171
...
...
@@ -8,41 +8,26 @@ class FFTOperator(LinearOperator):
# ---Overwritten properties and methods---
def
__init__
(
self
,
domain
=
(),
field_type
=
(),
target
=
(),
field_type_target
=
(),
implemented
=
True
):
def
__init__
(
self
,
domain
=
(),
field_type
=
(),
target
=
None
):
super
(
FFTOperator
,
self
).
__init__
(
domain
=
domain
,
field_type
=
field_type
,
implemented
=
implemented
)
field_type
=
field_type
)
if
self
.
domain
==
():
raise
TypeError
(
about
.
_errors
.
cstring
(
'ERROR: TransformationOperator needs a single space as '
'input domain.'
))
else
:
if
len
(
self
.
domain
)
>
1
:
raise
TypeError
(
about
.
_errors
.
cstring
(
'ERROR: TransformationOperator accepts only a single '
'space as input domain.'
))
if
len
(
self
.
domain
)
!=
1
:
raise
ValueError
(
about
.
_errors
.
cstring
(
'ERROR: TransformationOperator accepts only exactly one '
'space as input domain.'
))
if
self
.
field_type
!=
():
raise
Typ
eError
(
about
.
_errors
.
cstring
(
'ERROR: TransformationOperator field-type
has to
be an '
raise
Valu
eError
(
about
.
_errors
.
cstring
(
'ERROR: TransformationOperator field-type
must
be an '
'empty tuple.'
))
# currently not sanitizing the target
self
.
_target
=
self
.
_parse_domain
(
utilities
.
get_default_codomain
(
self
.
domain
[
0
])
)
self
.
_field_type_target
=
self
.
_parse_field_type
(
field_type_target
)
if
target
is
None
:
target
=
utilities
.
get_default_codomain
(
self
.
domain
[
0
])
if
self
.
field_type_target
!=
():
raise
TypeError
(
about
.
_errors
.
cstring
(
'ERROR: TransformationOperator target field-type has to be an '
'empty tuple.'
))
self
.
_target
=
self
.
_parse_domain
(
utilities
.
get_default_codomain
(
self
.
domain
[
0
]))
self
.
_forward_transformation
=
TransformationFactory
.
create
(
self
.
domain
[
0
],
self
.
target
[
0
]
...
...
@@ -52,24 +37,37 @@ class FFTOperator(LinearOperator):
self
.
target
[
0
],
self
.
domain
[
0
]
)
def
adjoint
_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
return
self
.
inverse_times
(
x
,
spaces
,
types
)
def
_times
(
self
,
x
,
spaces
,
types
):
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
x
.
domain
)
)
def
adjoint_inverse_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
return
self
.
times
(
x
,
spaces
,
types
)
new_val
=
self
.
_forward_transformation
.
transform
(
x
.
val
,
axes
=
spaces
)
def
inverse_adjoint_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
return
self
.
times
(
x
,
spaces
,
types
)
if
spaces
is
None
:
result_domain
=
self
.
target
else
:
result_domain
=
list
(
x
.
domain
)
result_domain
[
spaces
[
0
]]
=
self
.
target
[
0
]
def
_times
(
self
,
x
,
spaces
,
types
):
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
x
.
domain
)
)
result_field
=
x
.
copy_empty
(
domain
=
result_domain
)
result_field
.
set_val
(
new_val
=
new_val
)
return
self
.
_forward_transformation
.
transform
(
x
.
val
,
axes
=
spaces
)
return
result_field
def
_inverse_times
(
self
,
x
,
spaces
,
types
):
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
x
.
domain
))
return
self
.
_inverse_transformation
.
transform
(
x
.
val
,
axes
=
spaces
)
new_val
=
self
.
_inverse_transformation
.
transform
(
x
.
val
,
axes
=
spaces
)
if
spaces
is
None
:
result_domain
=
self
.
domain
else
:
result_domain
=
list
(
x
.
domain
)
result_domain
[
spaces
[
0
]]
=
self
.
domain
[
0
]
result_field
=
x
.
copy_empty
(
domain
=
result_domain
)
result_field
.
set_val
(
new_val
=
new_val
)
return
result_field
# ---Mandatory properties and methods---
...
...
@@ -79,5 +77,12 @@ class FFTOperator(LinearOperator):
@
property
def
field_type_target
(
self
):
return
self
.
_field_type_target
return
self
.
field_type
@
property
def
implemented
(
self
):
return
True
@
property
def
unitary
(
self
):
return
True
nifty/operators/linear_operator/linear_operator.py
View file @
cb713171
...
...
@@ -12,10 +12,9 @@ import nifty.nifty_utilities as utilities
class
LinearOperator
(
object
):
__metaclass__
=
abc
.
ABCMeta
def
__init__
(
self
,
domain
=
(),
field_type
=
()
,
implemented
=
False
):
def
__init__
(
self
,
domain
=
(),
field_type
=
()):
self
.
_domain
=
self
.
_parse_domain
(
domain
)
self
.
_field_type
=
self
.
_parse_field_type
(
field_type
)
self
.
_implemented
=
bool
(
implemented
)
@
property
def
domain
(
self
):
...
...
@@ -36,8 +35,11 @@ class LinearOperator(object):
def
_parse_domain
(
self
,
domain
):
if
domain
is
None
:
domain
=
()
elif
not
isinstance
(
domain
,
tupl
e
):
elif
isinstance
(
domain
,
Spac
e
):
domain
=
(
domain
,)
elif
not
isinstance
(
domain
,
tuple
):
domain
=
tuple
(
domain
)
for
d
in
domain
:
if
not
isinstance
(
d
,
Space
):
raise
TypeError
(
about
.
_errors
.
cstring
(
...
...
@@ -48,17 +50,20 @@ class LinearOperator(object):
def
_parse_field_type
(
self
,
field_type
):
if
field_type
is
None
:
field_type
=
()
elif
not
isinstance
(
field_type
,
tupl
e
):
elif
isinstance
(
field_type
,
FieldTyp
e
):
field_type
=
(
field_type
,)
elif
not
isinstance
(
field_type
,
tuple
):
field_type
=
tuple
(
field_type
)
for
ft
in
field_type
:
if
not
isinstance
(
ft
,
FieldType
):
raise
TypeError
(
about
.
_errors
.
cstring
(
"ERROR: Given object is not a nifty.FieldType."
))
return
field_type
@
property
@
abc
.
abstract
property
def
implemented
(
self
):
r
eturn
self
.
_i
mplemented
r
aise
NotI
mplemented
Error
@
abc
.
abstractproperty
def
unitary
(
self
):
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment