Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
68b71aca
Commit
68b71aca
authored
Oct 29, 2017
by
Martin Reinecke
Browse files
locate all places where adjustments are needed for distributed fields
parent
5bbff04b
Pipeline
#20800
passed with stage
in 4 minutes and 8 seconds
Changes
3
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty/data_objects/my_own_do.py
View file @
68b71aca
...
...
@@ -56,8 +56,8 @@ class data_object(object):
a
=
self
.
_data
if
isinstance
(
other
,
data_object
):
b
=
other
.
_data
#
if a.shape != b.shape:
#
print
("shapes are incompatible.")
if
a
.
shape
!=
b
.
shape
:
raise
ValueError
(
"shapes are incompatible."
)
else
:
b
=
other
...
...
nifty/dobj.py
View file @
68b71aca
from
.data_objects.my_own_do
import
*
#from .data_objects.numpy_do import *
nifty/operators/diagonal_operator.py
View file @
68b71aca
...
...
@@ -22,6 +22,7 @@ from ..field import Field
from
..domain_tuple
import
DomainTuple
from
.endomorphic_operator
import
EndomorphicOperator
from
..nifty_utilities
import
cast_iseq_to_tuple
from
..dobj
import
to_ndarray
as
to_np
class
DiagonalOperator
(
EndomorphicOperator
):
...
...
@@ -97,16 +98,16 @@ class DiagonalOperator(EndomorphicOperator):
self
.
_unitary
=
None
def
_times
(
self
,
x
):
return
self
.
_times_helper
(
x
,
lambda
z
:
z
.
__mul__
)
return
self
.
_times_helper
(
x
,
self
.
_diagonal
)
def
_adjoint_times
(
self
,
x
):
return
self
.
_times_helper
(
x
,
lambda
z
:
z
.
conjugate
().
__mul__
)
return
self
.
_times_helper
(
x
,
self
.
_diagonal
.
conj
()
)
def
_inverse_times
(
self
,
x
):
return
self
.
_times_helper
(
x
,
lambda
z
:
z
.
__rtruediv__
)
return
self
.
_times_helper
(
x
,
1.
/
self
.
_diagonal
)
def
_adjoint_inverse_times
(
self
,
x
):
return
self
.
_times_helper
(
x
,
lambda
z
:
z
.
conjugate
().
__rtruediv__
)
return
self
.
_times_helper
(
x
,
1.
/
self
.
_diagonal
.
conj
()
)
def
diagonal
(
self
):
""" Returns the diagonal of the Operator.
...
...
@@ -137,9 +138,9 @@ class DiagonalOperator(EndomorphicOperator):
self
.
_unitary
=
(
abs
(
self
.
_diagonal
.
val
)
==
1.
).
all
()
return
self
.
_unitary
def
_times_helper
(
self
,
x
,
operation
):
def
_times_helper
(
self
,
x
,
diag
):
if
self
.
_spaces
is
None
:
return
operation
(
self
.
_diagonal
)(
x
)
return
diag
*
x
active_axes
=
[]
for
space_index
in
self
.
_spaces
:
...
...
@@ -147,7 +148,7 @@ class DiagonalOperator(EndomorphicOperator):
reshaper
=
[
shp
if
i
in
active_axes
else
1
for
i
,
shp
in
enumerate
(
x
.
shape
)]
reshaped_local_diagonal
=
self
.
_diagonal
.
val
.
reshape
(
reshaper
)
reshaped_local_diagonal
=
to_np
(
diag
.
val
)
.
reshape
(
reshaper
)
# here the actual multiplication takes place
return
Field
(
x
.
domain
,
val
=
operation
(
reshaped_local_diagonal
)
(
x
.
val
))
return
Field
(
x
.
domain
,
val
=
x
.
val
*
reshaped_local_diagonal
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment