Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
7db6d4f9
Commit
7db6d4f9
authored
Aug 08, 2018
by
Martin Reinecke
Browse files
compatification
parent
12c44bb4
Changes
3
Hide whitespace changes
Inline
Side-by-side
nifty5/operators/field_zero_padder.py
View file @
7db6d4f9
...
...
@@ -25,9 +25,8 @@ class FieldZeroPadder(LinearOperator):
raise
ValueError
(
"Shape mismatch"
)
if
any
([
a
<
b
for
a
,
b
in
zip
(
new_shape
,
dom
.
shape
)]):
raise
ValueError
(
"New shape must be larger than old shape"
)
tgt
=
RGSpace
(
new_shape
,
dom
.
distances
)
self
.
_target
=
list
(
self
.
_domain
)
self
.
_target
[
self
.
_space
]
=
tgt
self
.
_target
[
self
.
_space
]
=
RGSpace
(
new_shape
,
dom
.
distances
)
self
.
_target
=
DomainTuple
.
make
(
self
.
_target
)
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
...
...
nifty5/operators/laplace_operator.py
View file @
7db6d4f9
...
...
@@ -68,60 +68,38 @@ class LaplaceOperator(EndomorphicOperator):
self
.
_dposc
[
1
:]
+=
self
.
_dpos
self
.
_dposc
*=
0.5
def
_times
(
self
,
x
):
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
axes
=
x
.
domain
.
axes
[
self
.
_space
]
axis
=
axes
[
0
]
locval
=
x
.
val
if
axis
==
dobj
.
distaxis
(
locval
):
locval
=
dobj
.
redistribute
(
locval
,
nodist
=
(
axis
,))
val
=
dobj
.
local_data
(
locval
)
nval
=
len
(
self
.
_dposc
)
prefix
=
(
slice
(
None
),)
*
axis
sl_l
=
prefix
+
(
slice
(
None
,
-
1
),)
# "left" slice
sl_r
=
prefix
+
(
slice
(
1
,
None
),)
# "right" slice
dpos
=
self
.
_dpos
.
reshape
((
1
,)
*
axis
+
(
nval
-
1
,))
dposc
=
self
.
_dposc
.
reshape
((
1
,)
*
axis
+
(
nval
,))
deriv
=
(
val
[
sl_r
]
-
val
[
sl_l
])
/
dpos
# defined between points
locval
=
x
.
val
if
axis
==
dobj
.
distaxis
(
locval
):
locval
=
dobj
.
redistribute
(
locval
,
nodist
=
(
axis
,))
val
=
dobj
.
local_data
(
locval
)
ret
=
np
.
empty_like
(
val
)
ret
[
sl_l
]
=
deriv
ret
[
prefix
+
(
-
1
,)]
=
0.
ret
[
sl_r
]
-=
deriv
ret
/=
dposc
ret
[
prefix
+
(
slice
(
None
,
2
),)]
=
0.
ret
[
prefix
+
(
-
1
,)]
=
0.
if
mode
==
self
.
TIMES
:
deriv
=
(
val
[
sl_r
]
-
val
[
sl_l
])
/
dpos
# defined between points
ret
[
sl_l
]
=
deriv
ret
[
prefix
+
(
-
1
,)]
=
0.
ret
[
sl_r
]
-=
deriv
ret
/=
dposc
ret
[
prefix
+
(
slice
(
None
,
2
),)]
=
0.
ret
[
prefix
+
(
-
1
,)]
=
0.
else
:
val
=
val
/
dposc
val
[
prefix
+
(
slice
(
None
,
2
),)]
=
0.
val
[
prefix
+
(
-
1
,)]
=
0.
deriv
=
(
val
[
sl_r
]
-
val
[
sl_l
])
/
dpos
# defined between points
ret
[
sl_l
]
=
deriv
ret
[
prefix
+
(
-
1
,)]
=
0.
ret
[
sl_r
]
-=
deriv
ret
=
dobj
.
from_local_data
(
locval
.
shape
,
ret
,
dobj
.
distaxis
(
locval
))
if
dobj
.
distaxis
(
locval
)
!=
dobj
.
distaxis
(
x
.
val
):
ret
=
dobj
.
redistribute
(
ret
,
dist
=
dobj
.
distaxis
(
x
.
val
))
return
Field
(
self
.
domain
,
val
=
ret
)
def
_adjoint_times
(
self
,
x
):
axes
=
x
.
domain
.
axes
[
self
.
_space
]
axis
=
axes
[
0
]
nval
=
len
(
self
.
_dposc
)
prefix
=
(
slice
(
None
),)
*
axis
sl_l
=
prefix
+
(
slice
(
None
,
-
1
),)
# "left" slice
sl_r
=
prefix
+
(
slice
(
1
,
None
),)
# "right" slice
dpos
=
self
.
_dpos
.
reshape
((
1
,)
*
axis
+
(
nval
-
1
,))
dposc
=
self
.
_dposc
.
reshape
((
1
,)
*
axis
+
(
nval
,))
yf
=
x
.
val
if
axis
==
dobj
.
distaxis
(
yf
):
yf
=
dobj
.
redistribute
(
yf
,
nodist
=
(
axis
,))
y
=
dobj
.
local_data
(
yf
)
y
=
y
/
dposc
y
[
prefix
+
(
slice
(
None
,
2
),)]
=
0.
y
[
prefix
+
(
-
1
,)]
=
0.
deriv
=
(
y
[
sl_r
]
-
y
[
sl_l
])
/
dpos
# defined between points
ret
=
np
.
empty_like
(
y
)
ret
[
sl_l
]
=
deriv
ret
[
prefix
+
(
-
1
,)]
=
0.
ret
[
sl_r
]
-=
deriv
ret
=
dobj
.
from_local_data
(
x
.
shape
,
ret
,
dobj
.
distaxis
(
yf
))
if
dobj
.
distaxis
(
yf
)
!=
dobj
.
distaxis
(
x
.
val
):
ret
=
dobj
.
redistribute
(
ret
,
dist
=
dobj
.
distaxis
(
x
.
val
))
return
Field
(
self
.
domain
,
val
=
ret
)
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
return
self
.
_times
(
x
)
return
self
.
_adjoint_times
(
x
)
nifty5/operators/simple_linear_operators.py
View file @
7db6d4f9
...
...
@@ -119,9 +119,7 @@ class GeometryRemover(LinearOperator):
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
return
x
.
cast_domain
(
self
.
_target
)
return
x
.
cast_domain
(
self
.
_domain
)
return
x
.
cast_domain
(
self
.
_tgt
(
mode
))
class
NullOperator
(
LinearOperator
):
...
...
@@ -150,7 +148,4 @@ class NullOperator(LinearOperator):
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
return
self
.
_nullfield
(
self
.
_target
)
return
self
.
_nullfield
(
self
.
_domain
)
return
self
.
_nullfield
(
self
.
_tgt
(
mode
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment