Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
On Thursday, 7th July from 1 to 3 pm there will be a maintenance with a short downtime of GitLab.
Open sidebar
ift
NIFTy
Commits
871c8d12
Commit
871c8d12
authored
Aug 31, 2016
by
Jait Dixit
Browse files
WIP: SmoothOperator for RGSpace
parent
30e5db9f
Changes
4
Hide whitespace changes
Inline
Side-by-side
nifty/operators/__init__.py
View file @
871c8d12
...
...
@@ -27,6 +27,8 @@ from diagonal_operator import DiagonalOperator
from
endomorphic_operator
import
EndomorphicOperator
from
smooth_operator
import
SmoothOperator
from
fft_operator
import
*
from
nifty_operators
import
operator
,
\
...
...
@@ -48,4 +50,4 @@ from nifty_probing import prober,\
inverse_diagonal_prober
from
nifty_minimization
import
conjugate_gradient
,
\
steepest_descent
\ No newline at end of file
steepest_descent
nifty/operators/smooth_operator/__init__.py
View file @
871c8d12
from
smooth_operator
import
SmoothOperator
nifty/operators/smooth_operator/smooth_operator.py
View file @
871c8d12
...
...
@@ -9,52 +9,44 @@ from nifty.operators.fft_operator import FFTOperator
class
SmoothOperator
(
EndomorphicOperator
):
# ---Overwritten properties and methods---
def
__init__
(
self
,
domain
=
(),
field_type
=
(),
inplace
=
False
,
sigma
=
None
,
implemented
=
False
):
def
__init__
(
self
,
domain
=
(),
field_type
=
(),
inplace
=
False
,
sigma
=
None
):
super
(
SmoothOperator
,
self
).
__init__
(
domain
=
domain
,
field_type
=
field_type
,
implemented
=
implemented
)
field_type
=
field_type
)
if
len
(
self
.
domain
)
!=
1
:
raise
ValueError
(
about
.
_errors
.
cstring
(
'ERROR: SmoothOperator accepts only exactly one '
'space as input domain.'
)
)
if
self
.
field_type
!=
():
raise
ValueError
(
about
.
_errors
.
cstring
(
'ERROR:
Transformation
Operator field-type must be an '
'ERROR:
Smooth
Operator field-type must be an '
'empty tuple.'
))
self
.
_sigma
=
sigma
self
.
_inplace
=
inplace
self
.
_implemented
=
bool
(
implemented
)
self
.
_inplace
=
bool
(
inplace
)
def
_inverse_times
(
self
,
x
,
spaces
,
types
):
return
self
.
_smooth_helper
(
x
,
spaces
,
types
,
inverse
=
True
)
def
_times
(
self
,
x
,
spaces
,
types
):
if
sigma
==
0
:
return
x
if
self
.
inplace
else
x
.
copy
()
return
self
.
_smooth_helper
(
x
,
spaces
,
types
)
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
x
.
domain
))
# ---Mandatory properties and methods---
@
property
def
implemented
(
self
):
return
True
if
spaces
is
None
:
return
x
if
self
.
inplace
else
x
.
copy
()
@
property
def
symmetric
(
self
):
return
False
for
space
in
spaces
:
axes
=
x
.
domain_axes
[
space
]
for
space_axis
,
val_axis
in
zip
(
range
(
len
(
x
.
domain
[
space
].
shape
)),
axes
):
transform
=
FFTOperator
(
x
.
domain
[
space
])
kernel
=
x
.
domain
[
space
].
get_codomain_mask
(
self
.
sigma
,
space_axis
)
if
isinstance
(
x
.
domain
[
space
],
RGSpace
):
new_shape
=
np
.
ones
(
len
(
x
.
shape
),
dtype
=
np
.
int
)
new_shape
[
val_axis
]
=
len
(
kernel
)
kernel
=
kernel
.
reshape
(
new_shape
)
# transform
transformed_inp
=
transform
(
x
)
transformed_inp
*=
kernel
elif
isinstance
(
x
.
domain
[
space
],
LMSpace
):
pass
else
:
raise
ValueError
(
about
.
_errors
.
cstring
(
'ERROR: SmoothOperator cannot smooth space '
+
str
(
x
.
domain
[
space
]))
@
property
def
unitary
(
self
):
return
False
# ---Added properties and methods---
@
property
...
...
@@ -64,3 +56,53 @@ class SmoothOperator(EndomorphicOperator):
@
property
def
inplace
(
self
):
return
self
.
_inplace
def
_smooth_helper
(
self
,
x
,
spaces
,
types
,
inverse
=
False
):
if
self
.
sigma
==
0
:
return
x
if
self
.
inplace
else
x
.
copy
()
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
x
.
domain
))
if
spaces
is
None
:
return
x
if
self
.
inplace
else
x
.
copy
()
# copy for doing the actual smoothing
smooth_out
=
x
.
copy
()
space_obj
=
x
.
domain
[
spaces
[
0
]]
axes
=
x
.
domain_axes
[
spaces
[
0
]]
for
space_axis
,
val_axis
in
zip
(
range
(
len
(
space_obj
.
shape
)),
axes
):
transform
=
FFTOperator
(
space_obj
)
kernel
=
space_obj
.
get_codomain_smoothing_kernel
(
self
.
sigma
,
space_axis
)
if
isinstance
(
space_obj
,
RGSpace
):
new_shape
=
np
.
ones
(
len
(
x
.
shape
),
dtype
=
np
.
int
)
new_shape
[
val_axis
]
=
len
(
kernel
)
kernel
=
kernel
.
reshape
(
new_shape
)
# transform
smooth_out
=
transform
(
smooth_out
,
spaces
=
spaces
[
0
])
# multiply kernel
if
inverse
:
smooth_out
.
val
/=
kernel
else
:
smooth_out
.
val
*=
kernel
# inverse transform
smooth_out
=
transform
.
inverse_times
(
smooth_out
,
spaces
=
spaces
[
0
])
elif
isinstance
(
space_obj
,
LMSpace
):
pass
else
:
raise
ValueError
(
about
.
_errors
.
cstring
(
'ERROR: SmoothOperator cannot smooth space '
+
str
(
space_obj
)))
if
self
.
inplace
:
x
.
set_val
(
val
=
smooth_out
.
val
)
return
x
else
:
return
smooth_out
nifty/spaces/rg_space/rg_space.py
View file @
871c8d12
...
...
@@ -310,10 +310,14 @@ class RGSpace(Space):
temp
[:]
=
zerocenter
return
tuple
(
temp
)
def
get_codomain_
mask
(
self
,
sigma
,
axis
):
def
get_codomain_
smoothing_kernel
(
self
,
sigma
,
axis
):
if
sigma
is
None
:
sigma
=
np
.
sqrt
(
2
)
*
np
.
max
(
self
.
distances
)
mask
=
np
.
fft
.
fftfreq
(
self
.
shape
[
axis
],
d
=
self
.
distances
[
axis
])
gaussian
=
lambda
x
:
np
.
exp
(
-
2.
*
np
.
pi
**
2
*
x
**
2
*
sigma
**
2
)
k
=
np
.
fft
.
fftfreq
(
self
.
shape
[
axis
],
d
=
self
.
distances
[
axis
])
return
mask
if
self
.
zerocenter
[
axis
]
else
np
.
fft
.
fftshift
(
mask
)
if
self
.
zerocenter
[
axis
]:
k
=
np
.
fft
.
fftshift
(
k
)
return
np
.
array
(
gaussian
(
k
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment