Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
6492d74b
Commit
6492d74b
authored
Jan 05, 2018
by
Martin Reinecke
Browse files
more polishing
parent
b7934d79
Pipeline
#23368
passed with stage
in 4 minutes and 33 seconds
Changes
10
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty/field.py
View file @
6492d74b
...
...
@@ -32,7 +32,6 @@ class Field(object):
In NIFTY, Fields are used to store data arrays and carry all the needed
metainformation (i.e. the domain) for operators to be able to work on them.
In addition, Field has methods to work with power spectra.
Parameters
----------
...
...
@@ -59,23 +58,23 @@ class Field(object):
"""
def
__init__
(
self
,
domain
=
None
,
val
=
None
,
dtype
=
None
,
copy
=
False
):
self
.
domain
=
self
.
_infer_domain
(
domain
=
domain
,
val
=
val
)
self
.
_
domain
=
self
.
_infer_domain
(
domain
=
domain
,
val
=
val
)
dtype
=
self
.
_infer_dtype
(
dtype
=
dtype
,
val
=
val
)
if
isinstance
(
val
,
Field
):
if
self
.
domain
!=
val
.
domain
:
if
self
.
_
domain
!=
val
.
_
domain
:
raise
ValueError
(
"Domain mismatch"
)
self
.
_val
=
dobj
.
from_object
(
val
.
val
,
dtype
=
dtype
,
copy
=
copy
)
elif
(
np
.
isscalar
(
val
)):
self
.
_val
=
dobj
.
full
(
self
.
domain
.
shape
,
dtype
=
dtype
,
self
.
_val
=
dobj
.
full
(
self
.
_
domain
.
shape
,
dtype
=
dtype
,
fill_value
=
val
)
elif
isinstance
(
val
,
dobj
.
data_object
):
if
self
.
domain
.
shape
==
val
.
shape
:
if
self
.
_
domain
.
shape
==
val
.
shape
:
self
.
_val
=
dobj
.
from_object
(
val
,
dtype
=
dtype
,
copy
=
copy
)
else
:
raise
ValueError
(
"Shape mismatch"
)
elif
val
is
None
:
self
.
_val
=
dobj
.
empty
(
self
.
domain
.
shape
,
dtype
=
dtype
)
self
.
_val
=
dobj
.
empty
(
self
.
_
domain
.
shape
,
dtype
=
dtype
)
else
:
raise
TypeError
(
"unknown source type"
)
...
...
@@ -101,7 +100,7 @@ class Field(object):
def
full_like
(
field
,
val
,
dtype
=
None
):
if
not
isinstance
(
field
,
Field
):
raise
TypeError
(
"field must be of Field type"
)
return
Field
.
full
(
field
.
domain
,
val
,
dtype
)
return
Field
.
full
(
field
.
_
domain
,
val
,
dtype
)
@
staticmethod
def
zeros_like
(
field
,
dtype
=
None
):
...
...
@@ -109,7 +108,7 @@ class Field(object):
raise
TypeError
(
"field must be of Field type"
)
if
dtype
is
None
:
dtype
=
field
.
dtype
return
Field
.
zeros
(
field
.
domain
,
dtype
)
return
Field
.
zeros
(
field
.
_
domain
,
dtype
)
@
staticmethod
def
ones_like
(
field
,
dtype
=
None
):
...
...
@@ -117,7 +116,7 @@ class Field(object):
raise
TypeError
(
"field must be of Field type"
)
if
dtype
is
None
:
dtype
=
field
.
dtype
return
Field
.
ones
(
field
.
domain
,
dtype
)
return
Field
.
ones
(
field
.
_
domain
,
dtype
)
@
staticmethod
def
empty_like
(
field
,
dtype
=
None
):
...
...
@@ -125,13 +124,13 @@ class Field(object):
raise
TypeError
(
"field must be of Field type"
)
if
dtype
is
None
:
dtype
=
field
.
dtype
return
Field
.
empty
(
field
.
domain
,
dtype
)
return
Field
.
empty
(
field
.
_
domain
,
dtype
)
@
staticmethod
def
_infer_domain
(
domain
,
val
=
None
):
if
domain
is
None
:
if
isinstance
(
val
,
Field
):
return
val
.
domain
return
val
.
_
domain
if
np
.
isscalar
(
val
):
return
DomainTuple
.
make
(())
# empty domain tuple
raise
TypeError
(
"could not infer domain from value"
)
...
...
@@ -187,6 +186,10 @@ class Field(object):
def
dtype
(
self
):
return
self
.
_val
.
dtype
@
property
def
domain
(
self
):
return
self
.
_domain
@
property
def
shape
(
self
):
""" Returns the total shape of the Field's data array.
...
...
@@ -195,7 +198,7 @@ class Field(object):
-------
Integer tuple containing the dimensions of the spaces in domain.
"""
return
self
.
domain
.
shape
return
self
.
_
domain
.
shape
@
property
def
dim
(
self
):
...
...
@@ -208,21 +211,21 @@ class Field(object):
out : int
The dimension of the Field.
"""
return
self
.
domain
.
dim
return
self
.
_
domain
.
dim
@
property
def
real
(
self
):
""" The real part of the field (data is not copied)."""
if
not
np
.
issubdtype
(
self
.
dtype
,
np
.
complexfloating
):
raise
ValueError
(
".real called on a non-complex Field"
)
return
Field
(
self
.
domain
,
self
.
val
.
real
)
return
Field
(
self
.
_
domain
,
self
.
val
.
real
)
@
property
def
imag
(
self
):
""" The imaginary part of the field (data is not copied)."""
if
not
np
.
issubdtype
(
self
.
dtype
,
np
.
complexfloating
):
raise
ValueError
(
".imag called on a non-complex Field"
)
return
Field
(
self
.
domain
,
self
.
val
.
imag
)
return
Field
(
self
.
_
domain
,
self
.
val
.
imag
)
def
copy
(
self
):
""" Returns a full copy of the Field.
...
...
@@ -238,13 +241,13 @@ class Field(object):
def
scalar_weight
(
self
,
spaces
=
None
):
if
np
.
isscalar
(
spaces
):
return
self
.
domain
[
spaces
].
scalar_dvol
()
return
self
.
_
domain
[
spaces
].
scalar_dvol
()
if
spaces
is
None
:
spaces
=
range
(
len
(
self
.
domain
))
spaces
=
range
(
len
(
self
.
_
domain
))
res
=
1.
for
i
in
spaces
:
tmp
=
self
.
domain
[
i
].
scalar_dvol
()
tmp
=
self
.
_
domain
[
i
].
scalar_dvol
()
if
tmp
is
None
:
return
None
res
*=
tmp
...
...
@@ -277,17 +280,17 @@ class Field(object):
if
out
is
not
self
:
out
.
copy_content_from
(
self
)
spaces
=
utilities
.
parse_spaces
(
spaces
,
len
(
self
.
domain
))
spaces
=
utilities
.
parse_spaces
(
spaces
,
len
(
self
.
_
domain
))
fct
=
1.
for
ind
in
spaces
:
wgt
=
self
.
domain
[
ind
].
dvol
()
wgt
=
self
.
_
domain
[
ind
].
dvol
()
if
np
.
isscalar
(
wgt
):
fct
*=
wgt
else
:
new_shape
=
np
.
ones
(
len
(
self
.
shape
),
dtype
=
np
.
int
)
new_shape
[
self
.
domain
.
axes
[
ind
][
0
]:
self
.
domain
.
axes
[
ind
][
-
1
]
+
1
]
=
wgt
.
shape
new_shape
[
self
.
_
domain
.
axes
[
ind
][
0
]:
self
.
_
domain
.
axes
[
ind
][
-
1
]
+
1
]
=
wgt
.
shape
wgt
=
wgt
.
reshape
(
new_shape
)
if
dobj
.
distaxis
(
self
.
_val
)
>=
0
and
ind
==
0
:
# we need to distribute the weights along axis 0
...
...
@@ -321,10 +324,10 @@ class Field(object):
raise
ValueError
(
"The dot-partner must be an instance of "
+
"the NIFTy field class"
)
if
x
.
domain
!=
self
.
domain
:
if
x
.
_
domain
!=
self
.
_
domain
:
raise
ValueError
(
"Domain mismatch"
)
ndom
=
len
(
self
.
domain
)
ndom
=
len
(
self
.
_
domain
)
spaces
=
utilities
.
parse_spaces
(
spaces
,
ndom
)
if
len
(
spaces
)
==
ndom
:
...
...
@@ -359,7 +362,7 @@ class Field(object):
-------
The complex conjugated field.
"""
return
Field
(
self
.
domain
,
self
.
val
.
conjugate
(),
self
.
dtype
)
return
Field
(
self
.
_
domain
,
self
.
val
.
conjugate
(),
self
.
dtype
)
# ---General unary/contraction methods---
...
...
@@ -367,18 +370,18 @@ class Field(object):
return
self
.
copy
()
def
__neg__
(
self
):
return
Field
(
self
.
domain
,
-
self
.
val
,
self
.
dtype
)
return
Field
(
self
.
_
domain
,
-
self
.
val
,
self
.
dtype
)
def
__abs__
(
self
):
return
Field
(
self
.
domain
,
dobj
.
abs
(
self
.
val
),
self
.
dtype
)
return
Field
(
self
.
_
domain
,
dobj
.
abs
(
self
.
val
),
self
.
dtype
)
def
_contraction_helper
(
self
,
op
,
spaces
):
if
spaces
is
None
:
return
getattr
(
self
.
val
,
op
)()
spaces
=
utilities
.
parse_spaces
(
spaces
,
len
(
self
.
domain
))
spaces
=
utilities
.
parse_spaces
(
spaces
,
len
(
self
.
_
domain
))
axes_list
=
tuple
(
self
.
domain
.
axes
[
sp_index
]
for
sp_index
in
spaces
)
axes_list
=
tuple
(
self
.
_
domain
.
axes
[
sp_index
]
for
sp_index
in
spaces
)
if
len
(
axes_list
)
>
0
:
axes_list
=
reduce
(
lambda
x
,
y
:
x
+
y
,
axes_list
)
...
...
@@ -391,7 +394,7 @@ class Field(object):
return
data
else
:
return_domain
=
tuple
(
dom
for
i
,
dom
in
enumerate
(
self
.
domain
)
for
i
,
dom
in
enumerate
(
self
.
_
domain
)
if
i
not
in
spaces
)
return
Field
(
domain
=
return_domain
,
val
=
data
,
copy
=
False
)
...
...
@@ -435,21 +438,21 @@ class Field(object):
def
copy_content_from
(
self
,
other
):
if
not
isinstance
(
other
,
Field
):
raise
TypeError
(
"argument must be a Field"
)
if
other
.
domain
!=
self
.
domain
:
if
other
.
_
domain
!=
self
.
_
domain
:
raise
ValueError
(
"domains are incompatible."
)
dobj
.
local_data
(
self
.
val
)[()]
=
dobj
.
local_data
(
other
.
val
)[()]
def
_binary_helper
(
self
,
other
,
op
):
# if other is a field, make sure that the domains match
if
isinstance
(
other
,
Field
):
if
other
.
domain
!=
self
.
domain
:
if
other
.
_
domain
!=
self
.
_
domain
:
raise
ValueError
(
"domains are incompatible."
)
tval
=
getattr
(
self
.
val
,
op
)(
other
.
val
)
return
self
if
tval
is
self
.
val
else
Field
(
self
.
domain
,
tval
)
return
self
if
tval
is
self
.
val
else
Field
(
self
.
_
domain
,
tval
)
if
np
.
isscalar
(
other
)
or
isinstance
(
other
,
dobj
.
data_object
):
tval
=
getattr
(
self
.
val
,
op
)(
other
)
return
self
if
tval
is
self
.
val
else
Field
(
self
.
domain
,
tval
)
return
self
if
tval
is
self
.
val
else
Field
(
self
.
_
domain
,
tval
)
return
NotImplemented
...
...
@@ -511,7 +514,7 @@ class Field(object):
minmax
=
[
self
.
min
(),
self
.
max
()]
mean
=
self
.
mean
()
return
"nifty2go.Field instance
\n
- domain = "
+
\
repr
(
self
.
domain
)
+
\
repr
(
self
.
_
domain
)
+
\
"
\n
- val = "
+
repr
(
self
.
val
)
+
\
"
\n
- min.,max. = "
+
str
(
minmax
)
+
\
"
\n
- mean = "
+
str
(
mean
)
...
...
@@ -523,12 +526,12 @@ def _math_helper(x, function, out):
if
not
isinstance
(
x
,
Field
):
raise
TypeError
(
"This function only accepts Field objects."
)
if
out
is
not
None
:
if
not
isinstance
(
out
,
Field
)
or
x
.
domain
!=
out
.
domain
:
if
not
isinstance
(
out
,
Field
)
or
x
.
_
domain
!=
out
.
_
domain
:
raise
ValueError
(
"Bad 'out' argument"
)
function
(
x
.
val
,
out
=
out
.
val
)
return
out
else
:
return
Field
(
domain
=
x
.
domain
,
val
=
function
(
x
.
val
))
return
Field
(
domain
=
x
.
_
domain
,
val
=
function
(
x
.
val
))
def
sqrt
(
x
,
out
=
None
):
...
...
nifty/library/critical_power_energy.py
View file @
6492d74b
...
...
@@ -59,6 +59,8 @@ class CriticalPowerEnergy(Energy):
self
.
samples
=
samples
self
.
alpha
=
float
(
alpha
)
self
.
q
=
float
(
q
)
self
.
_smoothness_prior
=
smoothness_prior
self
.
_logarithmic
=
logarithmic
self
.
T
=
SmoothnessOperator
(
domain
=
self
.
position
.
domain
[
0
],
strength
=
smoothness_prior
,
logarithmic
=
logarithmic
)
...
...
@@ -93,8 +95,9 @@ class CriticalPowerEnergy(Energy):
def
at
(
self
,
position
):
return
self
.
__class__
(
position
,
self
.
m
,
D
=
self
.
D
,
alpha
=
self
.
alpha
,
q
=
self
.
q
,
smoothness_prior
=
self
.
smoothness_prior
,
logarithmic
=
self
.
logarithmic
,
q
=
self
.
q
,
smoothness_prior
=
self
.
_smoothness_prior
,
logarithmic
=
self
.
_logarithmic
,
samples
=
self
.
samples
,
w
=
self
.
_w
,
inverter
=
self
.
_inverter
)
...
...
@@ -111,11 +114,3 @@ class CriticalPowerEnergy(Energy):
def
curvature
(
self
):
return
CriticalPowerCurvature
(
theta
=
self
.
_theta
,
T
=
self
.
T
,
inverter
=
self
.
_inverter
)
@
property
def
logarithmic
(
self
):
return
self
.
T
.
logarithmic
@
property
def
smoothness_prior
(
self
):
return
self
.
T
.
strength
nifty/library/nonlinear_power_energy.py
View file @
6492d74b
...
...
@@ -47,6 +47,7 @@ class NonlinearPowerEnergy(Energy):
self
.
Instrument
=
Instrument
self
.
nonlinearity
=
nonlinearity
self
.
Projection
=
Projection
self
.
_sigma
=
sigma
self
.
power
=
self
.
Projection
.
adjoint_times
(
exp
(
0.5
*
self
.
position
))
if
sample_list
is
None
:
...
...
@@ -62,7 +63,7 @@ class NonlinearPowerEnergy(Energy):
def
at
(
self
,
position
):
return
self
.
__class__
(
position
,
self
.
d
,
self
.
N
,
self
.
m
,
self
.
D
,
self
.
FFT
,
self
.
Instrument
,
self
.
nonlinearity
,
self
.
Projection
,
sigma
=
self
.
T
.
strength
,
self
.
Projection
,
sigma
=
self
.
_sigma
,
samples
=
len
(
self
.
sample_list
),
sample_list
=
self
.
sample_list
,
inverter
=
self
.
inverter
)
...
...
nifty/library/wiener_filter_curvature.py
View file @
6492d74b
...
...
@@ -68,5 +68,4 @@ class WienerFilterCurvature(EndomorphicOperator):
mock_j
=
self
.
R
.
adjoint_times
(
self
.
N
.
inverse_times
(
mock_data
))
mock_m
=
self
.
inverse_times
(
mock_j
)
sample
=
mock_signal
-
mock_m
return
sample
return
mock_signal
-
mock_m
nifty/operators/chain_operator.py
View file @
6492d74b
...
...
@@ -24,23 +24,26 @@ class ChainOperator(LinearOperator):
super
(
ChainOperator
,
self
).
__init__
()
if
op2
.
target
!=
op1
.
domain
:
raise
ValueError
(
"domain mismatch"
)
self
.
_op1
=
op1
self
.
_op2
=
op2
self
.
_capability
=
op1
.
capability
&
op2
.
capability
op1
=
op1
.
_ops
if
isinstance
(
op1
,
ChainOperator
)
else
(
op1
,)
op2
=
op2
.
_ops
if
isinstance
(
op2
,
ChainOperator
)
else
(
op2
,)
self
.
_ops
=
op1
+
op2
@
property
def
domain
(
self
):
return
self
.
_op
2
.
domain
return
self
.
_op
s
[
-
1
]
.
domain
@
property
def
target
(
self
):
return
self
.
_op
1
.
target
return
self
.
_op
s
[
0
]
.
target
@
property
def
capability
(
self
):
return
self
.
_
op1
.
capability
&
self
.
_op2
.
capability
return
self
.
_capability
def
apply
(
self
,
x
,
mode
):
self
.
_check_mode
(
mode
)
if
mode
==
self
.
TIMES
or
mode
==
self
.
ADJOINT_INVERSE_TIMES
:
return
self
.
_op1
.
apply
(
self
.
_op2
.
apply
(
x
,
mode
),
mode
)
return
self
.
_op2
.
apply
(
self
.
_op1
.
apply
(
x
,
mode
),
mode
)
t_ops
=
self
.
_ops
if
mode
&
self
.
_backwards
else
reversed
(
self
.
_ops
)
for
op
in
t_ops
:
x
=
op
.
apply
(
x
,
mode
)
return
x
nifty/operators/fft_smoothing_operator.py
View file @
6492d74b
from
.
endomorphic
_operator
import
Endomorphic
Operator
from
.
scaling
_operator
import
Scaling
Operator
from
.fft_operator
import
FFTOperator
from
..utilities
import
infer_space
from
.diagonal_operator
import
DiagonalOperator
from
..
import
DomainTuple
class
FFTSmoothingOperator
(
EndomorphicOperator
):
def
__init__
(
self
,
domain
,
sigma
,
space
=
None
):
super
(
FFTSmoothingOperator
,
self
).
__init__
()
dom
=
DomainTuple
.
make
(
domain
)
self
.
_sigma
=
float
(
sigma
)
self
.
_space
=
infer_space
(
dom
,
space
)
self
.
_FFT
=
FFTOperator
(
dom
,
space
=
self
.
_space
)
codomain
=
self
.
_FFT
.
domain
[
self
.
_space
].
get_default_codomain
()
kernel
=
codomain
.
get_k_length_array
()
smoother
=
codomain
.
get_fft_smoothing_kernel_function
(
self
.
_sigma
)
kernel
=
smoother
(
kernel
)
ddom
=
list
(
dom
)
ddom
[
self
.
_space
]
=
codomain
self
.
_diag
=
DiagonalOperator
(
kernel
,
ddom
,
self
.
_space
)
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
self
.
_sigma
==
0
:
return
x
.
copy
()
return
self
.
_FFT
.
adjoint_times
(
self
.
_diag
(
self
.
_FFT
(
x
)))
@
property
def
domain
(
self
):
return
self
.
_FFT
.
domain
@
property
def
capability
(
self
):
return
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
FFTSmoothingOperator
(
domain
,
sigma
,
space
=
None
):
sigma
=
float
(
sigma
)
if
sigma
<
0.
:
raise
ValueError
(
"sigma must be nonnegative"
)
if
sigma
==
0.
:
return
ScalingOperator
(
1.
,
domain
)
domain
=
DomainTuple
.
make
(
domain
)
space
=
infer_space
(
domain
,
space
)
FFT
=
FFTOperator
(
domain
,
space
=
space
)
codomain
=
FFT
.
domain
[
space
].
get_default_codomain
()
kernel
=
codomain
.
get_k_length_array
()
smoother
=
codomain
.
get_fft_smoothing_kernel_function
(
sigma
)
kernel
=
smoother
(
kernel
)
ddom
=
list
(
domain
)
ddom
[
space
]
=
codomain
diag
=
DiagonalOperator
(
kernel
,
ddom
,
space
)
return
FFT
.
adjoint
*
diag
*
FFT
nifty/operators/linear_operator.py
View file @
6492d74b
...
...
@@ -32,6 +32,12 @@ class LinearOperator(with_metaclass(
_adjointMode
=
(
0
,
2
,
1
,
0
,
8
,
0
,
0
,
0
,
4
)
_adjointCapability
=
(
0
,
2
,
1
,
3
,
8
,
10
,
9
,
11
,
4
,
6
,
5
,
7
,
12
,
14
,
13
,
15
)
_addInverse
=
(
0
,
5
,
10
,
15
,
5
,
5
,
15
,
15
,
10
,
15
,
10
,
15
,
15
,
15
,
15
,
15
)
_backwards
=
6
TIMES
=
1
ADJOINT_TIMES
=
2
INVERSE_TIMES
=
4
ADJOINT_INVERSE_TIMES
=
8
INVERSE_ADJOINT_TIMES
=
8
def
_dom
(
self
,
mode
):
return
self
.
domain
if
(
mode
&
9
)
else
self
.
target
...
...
@@ -62,26 +68,6 @@ class LinearOperator(with_metaclass(
"""
raise
NotImplementedError
@
property
def
TIMES
(
self
):
return
1
@
property
def
ADJOINT_TIMES
(
self
):
return
2
@
property
def
INVERSE_TIMES
(
self
):
return
4
@
property
def
ADJOINT_INVERSE_TIMES
(
self
):
return
8
@
property
def
INVERSE_ADJOINT_TIMES
(
self
):
return
8
@
property
def
inverse
(
self
):
from
.inverse_operator
import
InverseOperator
...
...
@@ -127,6 +113,7 @@ class LinearOperator(with_metaclass(
other
=
self
.
_toOperator
(
other
,
self
.
domain
)
return
SumOperator
(
self
,
other
,
neg
=
True
)
# MR FIXME: this might be more complicated ...
def
__rsub__
(
self
,
other
):
from
.sum_operator
import
SumOperator
other
=
self
.
_toOperator
(
other
,
self
.
domain
)
...
...
nifty/operators/scaling_operator.py
View file @
6492d74b
...
...
@@ -21,7 +21,6 @@ import numpy as np
from
..field
import
Field
from
..domain_tuple
import
DomainTuple
from
.endomorphic_operator
import
EndomorphicOperator
from
..
import
dobj
class
ScalingOperator
(
EndomorphicOperator
):
...
...
@@ -54,6 +53,11 @@ class ScalingOperator(EndomorphicOperator):
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
self
.
_factor
==
1.
:
return
x
.
copy
()
if
self
.
_factor
==
0.
:
return
Field
.
zeros_like
(
x
,
dtype
=
x
.
dtype
)
if
mode
==
self
.
TIMES
:
return
x
*
self
.
_factor
elif
mode
==
self
.
ADJOINT_TIMES
:
...
...
@@ -63,6 +67,14 @@ class ScalingOperator(EndomorphicOperator):
else
:
return
x
*
(
1.
/
np
.
conj
(
self
.
_factor
))
@
property
def
inverse
(
self
):
return
ScalingOperator
(
1.
/
self
.
_factor
,
self
.
_domain
)
@
property
def
adjoint
(
self
):
return
ScalingOperator
(
np
.
conj
(
self
.
factor
),
self
.
_domain
)
@
property
def
domain
(
self
):
return
self
.
_domain
...
...
nifty/operators/smoothness_operator.py
View file @
6492d74b
from
.
endomorphic
_operator
import
Endomorphic
Operator
from
.
scaling
_operator
import
Scaling
Operator
from
.laplace_operator
import
LaplaceOperator
class
SmoothnessOperator
(
En
dom
orphicOperator
):
def
SmoothnessOperator
(
dom
ain
,
strength
=
1.
,
logarithmic
=
True
,
space
=
None
):
"""An operator measuring the smoothness on an irregular grid with respect
to some scale.
...
...
@@ -18,44 +18,15 @@ class SmoothnessOperator(EndomorphicOperator):
Parameters
----------
strength: float
,
strength:
nonnegative
float
Specifies the strength of the SmoothnessOperator
logarithmic : boolean
,
logarithmic : boolean
Whether smoothness is calculated on a logarithmic scale or linear scale
default : True
"""
def
__init__
(
self
,
domain
,
strength
=
1.
,
logarithmic
=
True
,
space
=
None
):
super
(
SmoothnessOperator
,
self
).
__init__
()
self
.
_laplace
=
LaplaceOperator
(
domain
,
logarithmic
=
logarithmic
,
space
=
space
)
if
strength
<
0
:
raise
ValueError
(
"ERROR: strength must be >=0."
)
self
.
_strength
=
strength
@
property
def
domain
(
self
):
return
self
.
_laplace
.
_domain
# MR FIXME: shouldn't this operator actually be self-adjoint?
@
property
def
capability
(
self
):
return
self
.
TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
self
.
_strength
==
0.
:
return
x
.
zeros_like
(
x
)
result
=
self
.
_laplace
.
adjoint_times
(
self
.
_laplace
(
x
))
result
*=
self
.
_strength
**
2
return
result
@
property
def
logarithmic
(
self
):
return
self
.
_laplace
.
logarithmic
@
property
def
strength
(
self
):
return
self
.
_strength
if
strength
<
0
:
raise
ValueError
(
"ERROR: strength must be nonnegative."
)
if
strength
==
0.
:
return
ScalingOperator
(
0.
,
domain
)
laplace
=
LaplaceOperator
(
domain
,
logarithmic
=
logarithmic
,
space
=
space
)
return
(
strength
**
2
)
*
laplace
.
adjoint
*
laplace
nifty/operators/sum_operator.py
View file @
6492d74b
...
...
@@ -24,25 +24,37 @@ class SumOperator(LinearOperator):
super
(
SumOperator
,
self
).
__init__
()
if
op1
.
domain
!=
op2
.
domain
or
op1
.
target
!=
op2
.
target
:
raise
ValueError
(
"domain mismatch"
)
self
.
_op1
=
op1
self
.
_op2
=
op2
self
.
_neg
=
bool
(
neg
)
self
.
_capability
=
(
op1
.
capability
&
op2
.
capability
&
(
self
.
TIMES
|
self
.
ADJOINT_TIMES
))
op1
=
op1
.
_ops
if
isinstance
(
op1
,
SumOperator
)
else
(
op1
,)
neg1
=
op1
.
_neg
if
isinstance
(
op1
,
SumOperator
)
else
(
False
,)
op2
=
op2
.
_ops
if
isinstance
(
op2
,
SumOperator
)
else
(
op2
,)
neg2
=
op2
.
_neg
if
isinstance
(
op2
,
SumOperator
)
else
(
False
,)
if
neg
:
neg2
=
tuple
(
not
n
for
n
in
neg2
)
self
.
_ops
=
op1
+
op2
self
.
_neg
=
neg1
+
neg2
@
property
def
domain
(
self
):
return
self
.
_op
1
.
domain
return
self
.
_op
s
[
0
]
.
domain
@
property
def
target
(
self
):
return
self
.
_op
1
.
target
return
self
.
_op
s
[
0
]
.
target
@
property
def
capability
(
self
):
return
(
self
.
_op1
.
capability
&
self
.
_op2
.
capability
&
(
self
.
TIMES
|
self
.
ADJOINT_TIMES
))
return
self
.
_capability
def
apply
(
self
,
x
,
mode
):
self
.
_check_mode
(
mode
)
res1
=
self
.
_op1
.
apply
(
x
,
mode
)
res2
=
self
.
_op2
.
apply
(
x
,
mode
)
return
res1
-
res2
if
self
.
_neg
else
res1
+
res2
for
i
,
op
in
enumerate
(
self
.
_ops
):
if
i
==
0
:
res
=
-
op
.
apply
(
x
,
mode
)
if
self
.
_neg
[
i
]
else
op
.
apply
(
x
,
mode
)
else
:
if
self
.
_neg
[
i
]: