Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
On Thursday, 7th July from 1 to 3 pm there will be a maintenance with a short downtime of GitLab.
Open sidebar
Neel Shah
NIFTy
Commits
71262d82
Commit
71262d82
authored
Oct 17, 2019
by
Philipp Arras
Browse files
Add dtype checks
parent
9a2cc287
Changes
1
Hide whitespace changes
Inline
Side-by-side
nifty5/operator_spectrum.py
View file @
71262d82
...
...
@@ -18,7 +18,9 @@ import scipy.sparse.linalg as ssl
from
.domain_tuple
import
DomainTuple
from
.domains.unstructured_domain
import
UnstructuredDomain
from
.field
import
Field
from
.multi_domain
import
MultiDomain
from
.multi_field
import
MultiField
from
.operators.linear_operator
import
LinearOperator
from
.operators.sandwich_operator
import
SandwichOperator
from
.sugar
import
from_global_data
,
makeDomain
...
...
@@ -52,12 +54,14 @@ class _DomRemover(LinearOperator):
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
self
.
_check_float_dtype
(
x
)
x
=
x
.
to_global_data
()
if
isinstance
(
self
.
_domain
,
DomainTuple
):
res
=
x
.
ravel
()
if
mode
==
self
.
TIMES
else
x
.
reshape
(
self
.
_domain
.
shape
)
else
:
res
=
np
.
empty
(
self
.
target
.
shape
,
dtype
=
x
.
dtype
)
if
mode
==
self
.
TIMES
else
{}
res
=
np
.
empty
(
self
.
target
.
shape
)
if
mode
==
self
.
TIMES
else
{}
for
ii
,
(
kk
,
dd
)
in
enumerate
(
self
.
domain
.
items
()):
i0
,
i1
=
self
.
_size_array
[
ii
:
ii
+
2
]
if
mode
==
self
.
TIMES
:
...
...
@@ -66,6 +70,18 @@ class _DomRemover(LinearOperator):
res
[
kk
]
=
x
[
i0
:
i1
].
reshape
(
dd
.
shape
)
return
from_global_data
(
self
.
_tgt
(
mode
),
res
)
@
staticmethod
def
_check_float_dtype
(
fld
):
if
isinstance
(
fld
,
MultiField
):
dts
=
[
ff
.
local_data
.
dtype
for
ff
in
fld
.
values
()]
elif
isinstance
(
fld
,
Field
):
dts
=
[
fld
.
local_data
.
dtype
]
else
:
raise
TypeError
for
dt
in
dts
:
if
not
np
.
issubdtype
(
dt
,
np
.
float64
):
raise
TypeError
(
'Operator supports only floating point dtypes'
)
def
operator_spectrum
(
A
,
k
,
hermitian
,
which
=
'LM'
,
tol
=
0
):
'''
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment