Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
71262d82
Commit
71262d82
authored
Oct 17, 2019
by
Philipp Arras
Browse files
Add dtype checks
parent
9a2cc287
Pipeline
#62089
passed with stages
in 9 minutes and 6 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty5/operator_spectrum.py
View file @
71262d82
...
...
@@ -18,7 +18,9 @@ import scipy.sparse.linalg as ssl
from
.domain_tuple
import
DomainTuple
from
.domains.unstructured_domain
import
UnstructuredDomain
from
.field
import
Field
from
.multi_domain
import
MultiDomain
from
.multi_field
import
MultiField
from
.operators.linear_operator
import
LinearOperator
from
.operators.sandwich_operator
import
SandwichOperator
from
.sugar
import
from_global_data
,
makeDomain
...
...
@@ -52,12 +54,14 @@ class _DomRemover(LinearOperator):
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
self
.
_check_float_dtype
(
x
)
x
=
x
.
to_global_data
()
if
isinstance
(
self
.
_domain
,
DomainTuple
):
res
=
x
.
ravel
()
if
mode
==
self
.
TIMES
else
x
.
reshape
(
self
.
_domain
.
shape
)
else
:
res
=
np
.
empty
(
self
.
target
.
shape
,
dtype
=
x
.
dtype
)
if
mode
==
self
.
TIMES
else
{}
res
=
np
.
empty
(
self
.
target
.
shape
)
if
mode
==
self
.
TIMES
else
{}
for
ii
,
(
kk
,
dd
)
in
enumerate
(
self
.
domain
.
items
()):
i0
,
i1
=
self
.
_size_array
[
ii
:
ii
+
2
]
if
mode
==
self
.
TIMES
:
...
...
@@ -66,6 +70,18 @@ class _DomRemover(LinearOperator):
res
[
kk
]
=
x
[
i0
:
i1
].
reshape
(
dd
.
shape
)
return
from_global_data
(
self
.
_tgt
(
mode
),
res
)
@
staticmethod
def
_check_float_dtype
(
fld
):
if
isinstance
(
fld
,
MultiField
):
dts
=
[
ff
.
local_data
.
dtype
for
ff
in
fld
.
values
()]
elif
isinstance
(
fld
,
Field
):
dts
=
[
fld
.
local_data
.
dtype
]
else
:
raise
TypeError
for
dt
in
dts
:
if
not
np
.
issubdtype
(
dt
,
np
.
float64
):
raise
TypeError
(
'Operator supports only floating point dtypes'
)
def
operator_spectrum
(
A
,
k
,
hermitian
,
which
=
'LM'
,
tol
=
0
):
'''
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment