Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
1cb1c94e
Commit
1cb1c94e
authored
Jul 14, 2018
by
Martin Reinecke
Browse files
generalize DomainDistributor; some more operator tests
parent
6ea15ac4
Changes
6
Hide whitespace changes
Inline
Side-by-side
nifty5/operators/domain_distributor.py
View file @
1cb1c94e
...
...
@@ -25,26 +25,16 @@ from ..compat import *
from
..domain_tuple
import
DomainTuple
from
..field
import
Field
from
.linear_operator
import
LinearOperator
from
..
import
utilities
# MR FIXME: this needs to be rewritten in a generic fashion
class
DomainDistributor
(
LinearOperator
):
def
__init__
(
self
,
target
,
axis
):
if
dobj
.
ntask
>
1
:
raise
NotImplementedError
(
'UpProj class does not support MPI.'
)
assert
len
(
target
)
==
2
assert
axis
in
[
0
,
1
]
if
axis
==
0
:
domain
=
target
[
1
]
self
.
_size
=
target
[
0
].
size
else
:
domain
=
target
[
0
]
self
.
_size
=
target
[
1
].
size
self
.
_axis
=
axis
self
.
_domain
=
DomainTuple
.
make
(
domain
)
def
__init__
(
self
,
target
,
spaces
):
self
.
_target
=
DomainTuple
.
make
(
target
)
self
.
_spaces
=
utilities
.
parse_spaces
(
spaces
,
len
(
self
.
_target
))
self
.
_domain
=
[
tgt
for
i
,
tgt
in
enumerate
(
self
.
_target
)
if
i
in
self
.
_spaces
]
self
.
_domain
=
DomainTuple
.
make
(
self
.
_domain
)
@
property
def
domain
(
self
):
...
...
@@ -57,23 +47,16 @@ class DomainDistributor(LinearOperator):
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
x
=
x
.
local_data
otherDirection
=
np
.
ones
(
self
.
_size
)
if
self
.
_axis
==
0
:
res
=
np
.
outer
(
otherDirection
,
x
)
else
:
res
=
np
.
outer
(
x
,
otherDirection
)
res
=
res
.
reshape
(
dobj
.
local_shape
(
self
.
target
.
shape
))
return
Field
.
from_local_data
(
self
.
target
,
res
)
ldat
=
x
.
local_data
if
0
in
self
.
_spaces
else
x
.
to_global_data
()
shp
=
[]
for
i
,
tgt
in
enumerate
(
self
.
_target
):
tmp
=
tgt
.
shape
if
i
>
0
else
tgt
.
local_shape
shp
+=
tmp
if
i
in
self
.
_spaces
else
(
1
,)
*
len
(
tgt
.
shape
)
ldat
=
np
.
broadcast_to
(
ldat
.
reshape
(
shp
),
self
.
_target
.
local_shape
)
return
Field
.
from_local_data
(
self
.
_target
,
ldat
)
else
:
if
self
.
_axis
==
0
:
x
=
x
.
local_data
.
reshape
(
self
.
_size
,
-
1
)
res
=
np
.
sum
(
x
,
axis
=
0
)
else
:
x
=
x
.
local_data
.
reshape
(
-
1
,
self
.
_size
)
res
=
np
.
sum
(
x
,
axis
=
1
)
res
=
res
.
reshape
(
dobj
.
local_shape
(
self
.
domain
.
shape
))
return
Field
.
from_local_data
(
self
.
domain
,
res
)
return
x
.
sum
([
s
for
s
in
range
(
len
(
x
.
domain
))
if
s
not
in
self
.
_spaces
])
@
property
def
capability
(
self
):
...
...
nifty5/operators/exp_transform.py
View file @
1cb1c94e
...
...
@@ -27,12 +27,13 @@ from ..domains.power_space import PowerSpace
from
..domains.rg_space
import
RGSpace
from
..field
import
Field
from
.linear_operator
import
LinearOperator
from
..
import
utilities
class
ExpTransform
(
LinearOperator
):
def
__init__
(
self
,
target
,
dof
,
space
=
0
):
self
.
_target
=
DomainTuple
.
make
(
target
)
self
.
_space
=
int
(
space
)
self
.
_space
=
utilities
.
infer_space
(
self
.
_target
,
space
)
tgt
=
self
.
_target
[
self
.
_space
]
if
not
((
isinstance
(
tgt
,
RGSpace
)
and
tgt
.
harmonic
)
or
isinstance
(
tgt
,
PowerSpace
)):
...
...
nifty5/operators/field_zero_padder.py
View file @
1cb1c94e
...
...
@@ -8,13 +8,14 @@ from ..domain_tuple import DomainTuple
from
..domains.rg_space
import
RGSpace
from
..field
import
Field
from
.linear_operator
import
LinearOperator
from
..
import
utilities
class
FieldZeroPadder
(
LinearOperator
):
def
__init__
(
self
,
domain
,
factor
,
space
=
0
):
super
(
FieldZeroPadder
,
self
).
__init__
()
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_space
=
int
(
space
)
self
.
_space
=
utilities
.
infer_space
(
self
.
_domain
,
space
)
dom
=
self
.
_domain
[
self
.
_space
]
if
not
isinstance
(
dom
,
RGSpace
):
raise
TypeError
(
"RGSpace required"
)
...
...
@@ -52,11 +53,11 @@ class FieldZeroPadder(LinearOperator):
curax
=
dobj
.
distaxis
(
x
)
if
mode
==
self
.
ADJOINT_TIMES
:
newarr
=
np
.
empty
(
dobj
.
local_shape
(
shp_out
),
dtype
=
x
.
dtype
)
newarr
=
np
.
empty
(
dobj
.
local_shape
(
shp_out
,
curax
),
dtype
=
x
.
dtype
)
newarr
[()]
=
dobj
.
local_data
(
x
)[(
slice
(
None
),)
*
ax
+
(
slice
(
0
,
shp_out
[
ax
]),)]
else
:
newarr
=
np
.
zeros
(
dobj
.
local_shape
(
shp_out
),
dtype
=
x
.
dtype
)
newarr
=
np
.
zeros
(
dobj
.
local_shape
(
shp_out
,
curax
),
dtype
=
x
.
dtype
)
newarr
[(
slice
(
None
),)
*
ax
+
(
slice
(
0
,
shp_in
[
ax
]),)]
=
dobj
.
local_data
(
x
)
newarr
=
dobj
.
from_local_data
(
shp_out
,
newarr
,
distaxis
=
curax
)
...
...
nifty5/operators/qht_operator.py
View file @
1cb1c94e
...
...
@@ -22,7 +22,7 @@ from .. import dobj
from
..compat
import
*
from
..domain_tuple
import
DomainTuple
from
..field
import
Field
from
..utilities
import
hartley
from
..utilities
import
hartley
,
infer_space
from
.linear_operator
import
LinearOperator
...
...
@@ -47,7 +47,7 @@ class QHTOperator(LinearOperator):
"""
def
__init__
(
self
,
domain
,
target
,
space
=
0
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_space
=
in
t
(
space
)
self
.
_space
=
in
fer_space
(
self
.
_domain
,
space
)
from
..domains.log_rg_space
import
LogRGSpace
if
not
isinstance
(
self
.
_domain
[
self
.
_space
],
LogRGSpace
):
...
...
nifty5/operators/symmetrizing_operator.py
View file @
1cb1c94e
...
...
@@ -24,15 +24,16 @@ from ..domain_tuple import DomainTuple
from
..domains.log_rg_space
import
LogRGSpace
from
..field
import
Field
from
.endomorphic_operator
import
EndomorphicOperator
from
..
import
utilities
class
SymmetrizingOperator
(
EndomorphicOperator
):
def
__init__
(
self
,
domain
,
space
=
0
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_space
=
int
(
space
)
self
.
_space
=
utilities
.
infer_space
(
self
.
_domain
,
space
)
dom
=
self
.
_domain
[
self
.
_space
]
if
not
(
isinstance
(
dom
,
LogRGSpace
)
and
not
dom
.
harmonic
):
raise
TypeError
raise
TypeError
(
"nonharmonic LogRGSpace needed"
)
@
property
def
domain
(
self
):
...
...
test/test_operators/test_adjoint.py
View file @
1cb1c94e
...
...
@@ -101,3 +101,25 @@ class Consistency_Tests(unittest.TestCase):
def
testGeometryRemover
(
self
,
sp
,
dtype
):
op
=
ift
.
GeometryRemover
(
sp
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
@
expand
(
product
([
0
,
1
,
2
,
3
,
(
0
,
1
),
(
0
,
2
),
(
0
,
1
,
2
),
(
0
,
2
,
3
),
(
1
,
3
)],
[
np
.
float64
,
np
.
complex128
]))
def
testDomainDistributor
(
self
,
spaces
,
dtype
):
dom
=
(
ift
.
RGSpace
(
10
),
ift
.
UnstructuredDomain
(
13
),
ift
.
GLSpace
(
5
),
ift
.
HPSpace
(
4
))
op
=
ift
.
DomainDistributor
(
dom
,
spaces
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
@
expand
(
product
([
0
,
2
],
[
np
.
float64
,
np
.
complex128
]))
def
testSymmetrizingOperator
(
self
,
space
,
dtype
):
dom
=
(
ift
.
LogRGSpace
(
10
,
[
2.
],
[
1.
]),
ift
.
UnstructuredDomain
(
13
),
ift
.
LogRGSpace
((
5
,
27
),
[
1.
,
2.7
],
[
0.
,
4.
]),
ift
.
HPSpace
(
4
))
op
=
ift
.
SymmetrizingOperator
(
dom
,
space
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
@
expand
(
product
([
0
,
2
],
[
2
,
2.7
],
[
np
.
float64
,
np
.
complex128
]))
def
testZeroPadder
(
self
,
space
,
factor
,
dtype
):
dom
=
(
ift
.
RGSpace
(
10
),
ift
.
UnstructuredDomain
(
13
),
ift
.
RGSpace
(
7
),
ift
.
HPSpace
(
4
))
op
=
ift
.
FieldZeroPadder
(
dom
,
factor
,
space
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment