Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
c2482aa1
Commit
c2482aa1
authored
May 10, 2017
by
Theo Steininger
Browse files
Added default_spaces property to Operator classes.
parent
8ad1b902
Pipeline
#12218
passed with stages
in 11 minutes and 12 seconds
Changes
9
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
demos/wiener_filter_hamiltonian.py
View file @
c2482aa1
...
...
@@ -107,7 +107,7 @@ if __name__ == "__main__":
# callback=distance_measure,
# max_history_length=3)
m0
=
Field
(
s_space
,
val
=
1
)
m0
=
Field
(
s_space
,
val
=
1
.
)
energy
=
WienerFilterEnergy
(
position
=
m0
,
D
=
D
,
j
=
j
)
...
...
nifty/nifty_utilities.py
View file @
c2482aa1
...
...
@@ -19,6 +19,7 @@
import
numpy
as
np
from
itertools
import
product
def
get_slice_list
(
shape
,
axes
):
"""
Helper function which generates slice list(s) to traverse over all
...
...
@@ -65,8 +66,7 @@ def get_slice_list(shape, axes):
return
def
cast_axis_to_tuple
(
axis
,
length
):
def
cast_axis_to_tuple
(
axis
,
length
=
None
):
if
axis
is
None
:
return
None
try
:
...
...
@@ -78,16 +78,17 @@ def cast_axis_to_tuple(axis, length):
raise
TypeError
(
"Could not convert axis-input to tuple of ints"
)
# shift negative indices to positive ones
axis
=
tuple
(
item
if
(
item
>=
0
)
else
(
item
+
length
)
for
item
in
axis
)
if
length
is
not
None
:
# shift negative indices to positive ones
axis
=
tuple
(
item
if
(
item
>=
0
)
else
(
item
+
length
)
for
item
in
axis
)
# Deactivated this, in order to allow for the ComposedOperator
# remove duplicate entries
# axis = tuple(set(axis))
# Deactivated this, in order to allow for the ComposedOperator
# remove duplicate entries
# axis = tuple(set(axis))
# assert that all entries are elements in [0, length]
for
elem
in
axis
:
assert
(
0
<=
elem
<
length
)
# assert that all entries are elements in [0, length]
for
elem
in
axis
:
assert
(
0
<=
elem
<
length
)
return
axis
...
...
nifty/operators/diagonal_operator/diagonal_operator.py
View file @
c2482aa1
...
...
@@ -30,9 +30,10 @@ class DiagonalOperator(EndomorphicOperator):
# ---Overwritten properties and methods---
def
__init__
(
self
,
domain
=
(),
diagonal
=
None
,
bare
=
False
,
copy
=
True
,
distribution_strategy
=
None
):
def
__init__
(
self
,
domain
=
(),
diagonal
=
None
,
bare
=
False
,
copy
=
True
,
distribution_strategy
=
None
,
default_spaces
=
None
):
super
(
DiagonalOperator
,
self
).
__init__
(
default_spaces
)
self
.
_domain
=
self
.
_parse_domain
(
domain
)
if
distribution_strategy
is
None
:
...
...
nifty/operators/fft_operator/fft_operator.py
View file @
c2482aa1
...
...
@@ -112,7 +112,8 @@ class FFTOperator(LinearOperator):
# ---Overwritten properties and methods---
def
__init__
(
self
,
domain
,
target
=
None
,
module
=
None
,
domain_dtype
=
None
,
target_dtype
=
None
):
domain_dtype
=
None
,
target_dtype
=
None
,
default_spaces
=
None
):
super
(
FFTOperator
,
self
).
__init__
(
default_spaces
)
# Initialize domain and target
...
...
nifty/operators/invertible_operator_mixin/invertible_operator_mixin.py
View file @
c2482aa1
...
...
@@ -29,6 +29,7 @@ class InvertibleOperatorMixin(object):
else
:
self
.
__inverter
=
ConjugateGradient
(
preconditioner
=
self
.
__preconditioner
)
super
(
InvertibleOperatorMixin
,
self
).
__init__
(
*
args
,
**
kwargs
)
def
_times
(
self
,
x
,
spaces
,
x0
=
None
):
if
x0
is
None
:
...
...
nifty/operators/linear_operator/linear_operator.py
View file @
c2482aa1
...
...
@@ -27,8 +27,8 @@ import nifty.nifty_utilities as utilities
class
LinearOperator
(
Loggable
,
object
):
__metaclass__
=
NiftyMeta
def
__init__
(
self
):
pas
s
def
__init__
(
self
,
default_spaces
=
None
):
self
.
default_spaces
=
default_space
s
def
_parse_domain
(
self
,
domain
):
return
utilities
.
parse_domain
(
domain
)
...
...
@@ -45,6 +45,14 @@ class LinearOperator(Loggable, object):
def
unitary
(
self
):
raise
NotImplementedError
@
property
def
default_spaces
(
self
):
return
self
.
_default_spaces
@
default_spaces
.
setter
def
default_spaces
(
self
,
spaces
):
self
.
_default_spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
)
def
__call__
(
self
,
*
args
,
**
kwargs
):
return
self
.
times
(
*
args
,
**
kwargs
)
...
...
@@ -127,6 +135,9 @@ class LinearOperator(Loggable, object):
raise
ValueError
(
"supplied object is not a `nifty.Field`."
)
if
spaces
is
None
:
spaces
=
self
.
default_spaces
# sanitize the `spaces` and `types` input
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
x
.
domain
))
...
...
nifty/operators/projection_operator/projection_operator.py
View file @
c2482aa1
...
...
@@ -27,7 +27,9 @@ class ProjectionOperator(EndomorphicOperator):
# ---Overwritten properties and methods---
def
__init__
(
self
,
projection_field
):
def
__init__
(
self
,
projection_field
,
default_spaces
=
None
):
super
(
ProjectionOperator
,
self
).
__init__
(
default_spaces
)
if
not
isinstance
(
projection_field
,
Field
):
raise
TypeError
(
"The projection_field must be a NIFTy-Field"
"instance."
)
...
...
nifty/operators/propagator_operator/propagator_operator.py
View file @
c2482aa1
...
...
@@ -26,7 +26,7 @@ class PropagatorOperator(InvertibleOperatorMixin, EndomorphicOperator):
# ---Overwritten properties and methods---
def
__init__
(
self
,
S
=
None
,
M
=
None
,
R
=
None
,
N
=
None
,
inverter
=
None
,
preconditioner
=
None
):
preconditioner
=
None
,
default_spaces
=
None
):
"""
Sets the standard operator properties and `codomain`, `_A1`, `_A2`,
and `RN` if required.
...
...
@@ -66,7 +66,8 @@ class PropagatorOperator(InvertibleOperatorMixin, EndomorphicOperator):
preconditioner
=
self
.
_S_times
super
(
PropagatorOperator
,
self
).
__init__
(
inverter
=
inverter
,
preconditioner
=
preconditioner
)
preconditioner
=
preconditioner
,
default_spaces
=
default_spaces
)
# ---Mandatory properties and methods---
...
...
nifty/operators/smoothing_operator/smoothing_operator.py
View file @
c2482aa1
...
...
@@ -27,7 +27,9 @@ from d2o import STRATEGIES
class
SmoothingOperator
(
EndomorphicOperator
):
# ---Overwritten properties and methods---
def
__init__
(
self
,
domain
=
(),
sigma
=
0
,
log_distances
=
False
):
def
__init__
(
self
,
domain
=
(),
sigma
=
0
,
log_distances
=
False
,
default_spaces
=
None
):
super
(
SmoothingOperator
,
self
).
__init__
(
default_spaces
)
self
.
_domain
=
self
.
_parse_domain
(
domain
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment