Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
88d89ed6
Commit
88d89ed6
authored
Aug 20, 2016
by
theos
Browse files
Finalized LinearOperator base class. Added a first version of a square operator.
parent
0448204d
Changes
9
Hide whitespace changes
Inline
Side-by-side
nifty/nifty_utilities.py
View file @
88d89ed6
...
...
@@ -236,6 +236,14 @@ def cast_axis_to_tuple(axis, length):
# shift negative indices to positive ones
axis
=
tuple
(
item
if
(
item
>=
0
)
else
(
item
+
length
)
for
item
in
axis
)
# remove duplicate entries
axis
=
tuple
(
set
(
axis
))
# assert that all entries are elements in [0, length]
for
elem
in
axis
:
assert
(
0
<=
elem
<
length
)
return
axis
...
...
nifty/operators/__init__.py
View file @
88d89ed6
...
...
@@ -20,6 +20,13 @@
## along with this program. If not, see <http://www.gnu.org/licenses/>.
from
__future__
import
division
from
linear_operator
import
LinearOperator
,
\
LinearOperatorParadict
from
square_operator
import
SquareOperator
,
\
SquareOperatorParadict
from
nifty_operators
import
operator
,
\
diagonal_operator
,
\
power_operator
,
\
...
...
nifty/operators/linear_operator/__init__.py
0 → 100644
View file @
88d89ed6
# -*- coding: utf-8 -*-
from
linear_operator
import
LinearOperator
from
linear_operator_paradict
import
LinearOperatorParadict
nifty/operators/linear_operator/linear_operator.py
0 → 100644
View file @
88d89ed6
# -*- coding: utf-8 -*-
from
nifty.config
import
about
from
nifty.field
import
Field
from
nifty.spaces
import
Space
from
nifty.field_types
import
FieldType
import
nifty.nifty_utilities
as
utilities
from
linear_operator_paradict
import
LinearOperatorParadict
class
LinearOperator
(
object
):
def
__init__
(
self
,
domain
=
None
,
target
=
None
,
field_type
=
None
,
field_type_target
=
None
,
implemented
=
False
,
symmetric
=
False
,
unitary
=
False
):
self
.
paradict
=
LinearOperatorParadict
()
self
.
_implemented
=
bool
(
implemented
)
self
.
domain
=
self
.
_parse_domain
(
domain
)
self
.
target
=
self
.
_parse_domain
(
target
)
self
.
field_type
=
self
.
_parse_field_type
(
field_type
)
self
.
field_type_target
=
self
.
_parse_field_type
(
field_type_target
)
def
_parse_domain
(
self
,
domain
):
if
domain
is
None
:
domain
=
()
elif
not
isinstance
(
domain
,
tuple
):
domain
=
(
domain
,)
for
d
in
domain
:
if
not
isinstance
(
d
,
Space
):
raise
TypeError
(
about
.
_errors
.
cstring
(
"ERROR: Given object contains something that is not a "
"nifty.space."
))
return
domain
def
_parse_field_type
(
self
,
field_type
):
if
field_type
is
None
:
field_type
=
()
elif
not
isinstance
(
field_type
,
tuple
):
field_type
=
(
field_type
,)
for
ft
in
field_type
:
if
not
isinstance
(
ft
,
FieldType
):
raise
TypeError
(
about
.
_errors
.
cstring
(
"ERROR: Given object is not a nifty.FieldType."
))
return
field_type
@
property
def
implemented
(
self
):
return
self
.
_implemented
def
__call__
(
self
,
*
args
,
**
kwargs
):
return
self
.
times
(
*
args
,
**
kwargs
)
def
times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
spaces
,
types
=
self
.
_check_input_compatibility
(
x
,
spaces
,
types
)
if
not
self
.
implemented
:
x
=
x
.
weight
(
spaces
=
spaces
)
y
=
self
.
_times
(
x
,
spaces
,
types
)
return
y
def
inverse_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
spaces
,
types
=
self
.
_check_input_compatibility
(
x
,
spaces
,
types
)
y
=
self
.
_inverse_times
(
x
,
spaces
,
types
)
if
not
self
.
implemented
:
y
=
y
.
weight
(
power
=-
1
,
spaces
=
spaces
)
return
y
def
adjoint_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
spaces
,
types
=
self
.
_check_input_compatibility
(
x
,
spaces
,
types
)
if
not
self
.
implemented
:
x
=
x
.
weight
(
spaces
=
spaces
)
y
=
self
.
_adjoint_times
(
x
,
spaces
,
types
)
return
y
def
adjoint_inverse_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
spaces
,
types
=
self
.
_check_input_compatibility
(
x
,
spaces
,
types
)
y
=
self
.
_adjoint_inverse_times
(
x
,
spaces
,
types
)
if
not
self
.
implemented
:
y
=
y
.
weight
(
power
=-
1
,
spaces
=
spaces
)
return
y
def
inverse_adjoint_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
spaces
,
types
=
self
.
_check_input_compatibility
(
x
,
spaces
,
types
)
y
=
self
.
_inverse_adjoint_times
(
x
,
spaces
,
types
)
if
not
self
.
implemented
:
y
=
y
.
weight
(
power
=-
1
,
spaces
=
spaces
)
return
y
def
_times
(
self
,
x
,
spaces
,
types
):
raise
NotImplementedError
(
about
.
_errors
.
cstring
(
"ERROR: no generic instance method 'times'."
))
def
_adjoint_times
(
self
,
x
,
spaces
,
types
):
raise
NotImplementedError
(
about
.
_errors
.
cstring
(
"ERROR: no generic instance method 'adjoint_times'."
))
def
_inverse_times
(
self
,
x
,
spaces
,
types
):
raise
NotImplementedError
(
about
.
_errors
.
cstring
(
"ERROR: no generic instance method 'inverse_times'."
))
def
_adjoint_inverse_times
(
self
,
x
,
spaces
,
types
):
raise
NotImplementedError
(
about
.
_errors
.
cstring
(
"ERROR: no generic instance method 'adjoint_inverse_times'."
))
def
_inverse_adjoint_times
(
self
,
x
,
spaces
,
types
):
raise
NotImplementedError
(
about
.
_errors
.
cstring
(
"ERROR: no generic instance method 'inverse_adjoint_times'."
))
def
_check_input_compatibility
(
self
,
x
,
spaces
,
types
):
if
not
isinstance
(
x
,
Field
):
raise
ValueError
(
about
.
_errors
.
cstring
(
"ERROR: supplied object is not a `nifty.Field`."
))
# sanitize the `spaces` and `types` input
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
x
.
domain
))
types
=
utilities
.
cast_axis_to_tuple
(
types
,
len
(
x
.
field_type
))
# if the operator's domain is set to something, there are two valid
# cases:
# 1. Case:
# The user specifies with `spaces` that the operators domain should
# be applied to a certain domain in the domain-tuple of x. This is
# only valid if len(self.domain)==1.
# 2. Case:
# The domains of self and x match completely.
if
spaces
is
None
:
if
self
.
domain
!=
()
and
self
.
domain
!=
x
.
domain
:
raise
ValueError
(
about
.
_errors
.
cstring
(
"ERROR: The operator's and and field's domains don't "
"match."
))
else
:
if
len
(
self
.
domain
)
>
1
:
raise
ValueError
(
about
.
_errors
.
cstring
(
"ERROR: Specifying `spaces` for operators with multiple "
"domain spaces is not valid."
))
elif
len
(
spaces
)
!=
len
(
self
.
domain
):
raise
ValueError
(
about
.
_errors
.
cstring
(
"ERROR: Length of `spaces` does not match the number of "
"spaces in the operator's domain."
))
elif
len
(
spaces
)
==
1
:
if
x
.
domain
[
spaces
[
0
]]
!=
self
.
domain
[
0
]:
raise
ValueError
(
about
.
_errors
.
cstring
(
"ERROR: The operator's and and field's domains don't "
"match."
))
if
types
is
None
:
if
self
.
field_type
!=
()
and
self
.
field_type
!=
x
.
field_type
:
raise
ValueError
(
about
.
_errors
.
cstring
(
"ERROR: The operator's and and field's field_types don't "
"match."
))
else
:
if
len
(
self
.
field_type
)
>
1
:
raise
ValueError
(
about
.
_errors
.
cstring
(
"ERROR: Specifying `types` for operators with multiple "
"field-types is not valid."
))
elif
len
(
types
)
!=
len
(
self
.
field_type
):
raise
ValueError
(
about
.
_errors
.
cstring
(
"ERROR: Length of `types` does not match the number of "
"the operator's field-types."
))
elif
len
(
types
)
==
1
:
if
x
.
field_type
[
types
[
0
]]
!=
self
.
field_type
[
0
]:
raise
ValueError
(
about
.
_errors
.
cstring
(
"ERROR: The operator's and and field's field_type "
"don't match."
))
return
(
spaces
,
types
)
def
__repr__
(
self
):
return
str
(
self
.
__class__
)
nifty/operators/operator/operator_paradict.py
→
nifty/operators/
linear_
operator/
linear_
operator_paradict.py
View file @
88d89ed6
...
...
@@ -3,5 +3,5 @@
from
nifty.paradict
import
Paradict
class
OperatorParadict
(
Paradict
):
class
Linear
OperatorParadict
(
Paradict
):
pass
nifty/operators/operator/operator.py
deleted
100644 → 0
View file @
0448204d
# -*- coding: utf-8 -*-
from
nifty.config
import
about
from
operator_paradict
import
OperatorParadict
class
LinearOperator
(
object
):
def
__init__
(
self
,
domain
=
None
,
target
=
None
,
field_type
=
None
,
field_type_target
=
None
,
implemented
=
False
,
symmetric
=
False
,
unitary
=
False
,
**
kwargs
):
self
.
paradict
=
OperatorParadict
(
**
kwargs
)
self
.
implemented
=
implemented
self
.
symmetric
=
symmetric
self
.
unitary
=
unitary
@
property
def
implemented
(
self
):
return
self
.
_implemented
@
implemented
.
setter
def
implemented
(
self
,
b
):
self
.
_implemented
=
bool
(
b
)
@
property
def
symmetric
(
self
):
return
self
.
_symmetric
@
symmetric
.
setter
def
symmetric
(
self
,
b
):
self
.
_symmetric
=
bool
(
b
)
@
property
def
unitary
(
self
):
return
self
.
_unitary
@
unitary
.
setter
def
unitary
(
self
,
b
):
self
.
_unitary
=
bool
(
b
)
def
times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
raise
NotImplementedError
def
adjoint_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
raise
NotImplementedError
def
inverse_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
raise
NotImplementedError
def
adjoint_inverse_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
raise
NotImplementedError
def
inverse_adjoint_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
raise
NotImplementedError
def
_times
(
self
,
x
,
**
kwargs
):
raise
NotImplementedError
(
about
.
_errors
.
cstring
(
"ERROR: no generic instance method 'times'."
))
def
_adjoint_times
(
self
,
x
,
**
kwargs
):
raise
NotImplementedError
(
about
.
_errors
.
cstring
(
"ERROR: no generic instance method 'adjoint_times'."
))
def
_inverse_times
(
self
,
x
,
**
kwargs
):
raise
NotImplementedError
(
about
.
_errors
.
cstring
(
"ERROR: no generic instance method 'inverse_times'."
))
def
_adjoint_inverse_times
(
self
,
x
,
**
kwargs
):
raise
NotImplementedError
(
about
.
_errors
.
cstring
(
"ERROR: no generic instance method 'adjoint_inverse_times'."
))
def
_inverse_adjoint_times
(
self
,
x
,
**
kwargs
):
raise
NotImplementedError
(
about
.
_errors
.
cstring
(
"ERROR: no generic instance method 'inverse_adjoint_times'."
))
def
_check_input_compatibility
(
self
,
x
,
spaces
,
types
):
# assert: x is a field
# if spaces is None -> assert f.domain == self.domain
# -> same for field_type
# else: check if self.domain/self.field_type == one entry.
#
nifty/operators/square_operator/__init__.py
0 → 100644
View file @
88d89ed6
# -*- coding: utf-8 -*-
from
square_operator
import
SquareOperator
from
square_operator_paradict
import
SquareOperatorParadict
nifty/operators/square_operator/square_operator.py
0 → 100644
View file @
88d89ed6
# -*- coding: utf-8 -*-
from
nifty.config
import
about
from
nifty.operators.linear_operator
import
LinearOperator
from
square_operator_paradict
import
SquareOperatorParadict
class
SquareOperator
(
LinearOperator
):
def
__init__
(
self
,
domain
=
None
,
target
=
None
,
field_type
=
None
,
field_type_target
=
None
,
implemented
=
False
,
symmetric
=
False
,
unitary
=
False
):
if
target
is
not
None
:
about
.
warnings
.
cprint
(
"WARNING: Discarding given target for SquareOperator."
)
target
=
domain
if
field_type_target
is
not
None
:
about
.
warnings
.
cprint
(
"WARNING: Discarding given field_type_target for "
"SquareOperator."
)
field_type_target
=
field_type
LinearOperator
.
__init__
(
self
,
domain
=
domain
,
target
=
target
,
field_type
=
field_type
,
field_type_target
=
field_type_target
,
implemented
=
implemented
)
self
.
paradict
=
SquareOperatorParadict
(
symmetric
=
symmetric
,
unitary
=
unitary
)
def
inverse_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
if
self
.
paradict
[
'symmetric'
]
and
self
.
paradict
[
'unitary'
]:
return
self
.
times
(
x
,
spaces
,
types
)
else
:
return
LinearOperator
.
inverse_times
(
self
,
x
=
x
,
spaces
=
spaces
,
types
=
types
)
def
adjoint_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
if
self
.
paradict
[
'symmetric'
]:
return
self
.
times
(
x
,
spaces
,
types
)
elif
self
.
paradict
[
'unitary'
]:
return
self
.
inverse_times
(
x
,
spaces
,
types
)
else
:
return
LinearOperator
.
adjoint_times
(
self
,
x
=
x
,
spaces
=
spaces
,
types
=
types
)
def
adjoint_inverse_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
if
self
.
paradict
[
'symmetric'
]:
return
self
.
inverse_times
(
x
,
spaces
,
types
)
elif
self
.
paradict
[
'unitary'
]:
return
self
.
times
(
x
,
spaces
,
types
)
else
:
return
LinearOperator
.
adjoint_inverse_times
(
self
,
x
=
x
,
spaces
=
spaces
,
types
=
types
)
def
inverse_adjoint_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
if
self
.
paradict
[
'symmetric'
]:
return
self
.
inverse_times
(
x
,
spaces
,
types
)
elif
self
.
paradict
[
'unitary'
]:
return
self
.
times
(
x
,
spaces
,
types
)
else
:
return
LinearOperator
.
inverse_adjoint_times
(
self
,
x
=
x
,
spaces
=
spaces
,
types
=
types
)
def
trace
(
self
):
pass
def
inverse_trace
(
self
):
pass
def
diagonal
(
self
):
pass
def
inverse_diagonal
(
self
):
pass
def
determinant
(
self
):
pass
def
inverse_determinant
(
self
):
pass
def
log_determinant
(
self
):
pass
def
trace_log
(
self
):
pass
nifty/operators/square_operator/square_operator_paradict.py
0 → 100644
View file @
88d89ed6
# -*- coding: utf-8 -*-
from
nifty.config
import
about
from
nifty.operators.linear_operator
import
LinearOperatorParadict
class
SquareOperatorParadict
(
LinearOperatorParadict
):
def
__init__
(
self
,
symmetric
,
unitary
):
LinearOperatorParadict
.
__init__
(
self
,
symmetric
=
symmetric
,
unitary
=
unitary
)
def
__setitem__
(
self
,
key
,
arg
):
if
key
not
in
[
'symmetric'
,
'unitary'
]:
raise
ValueError
(
about
.
_errors
.
cstring
(
"ERROR: Unsupported SquareOperator parameter: "
+
key
))
if
key
==
'symmetric'
:
temp
=
bool
(
arg
)
elif
key
==
'unitary'
:
temp
=
bool
(
arg
)
self
.
parameters
.
__setitem__
(
key
,
temp
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment