Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
10e3efd2
Commit
10e3efd2
authored
Aug 27, 2017
by
Martin Reinecke
Browse files
Merge branch 'master' into byebye_zerocenter
parents
ffd98202
f64657e5
Pipeline
#17419
passed with stage
in 48 minutes and 10 seconds
Changes
26
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty/__init__.py
View file @
10e3efd2
...
...
@@ -30,6 +30,8 @@ from .config import dependency_injector,\
nifty_configuration
,
\
d2o_configuration
logger
.
logger
.
setLevel
(
nifty_configuration
[
'loglevel'
])
from
d2o
import
distributed_data_object
,
d2o_librarian
from
.field
import
Field
...
...
nifty/config/nifty_config.py
View file @
10e3efd2
...
...
@@ -76,12 +76,26 @@ variable_harmonic_rg_base = keepers.Variable(
lambda
z
:
z
in
[
'real'
,
'complex'
],
genus
=
'str'
)
variable_threads
=
keepers
.
Variable
(
'threads'
,
[
1
],
lambda
z
:
np
.
int
(
abs
(
z
))
==
z
,
genus
=
'int'
)
variable_loglevel
=
keepers
.
Variable
(
'loglevel'
,
[
10
],
lambda
z
:
np
.
int
(
z
)
==
z
and
0
<=
z
<=
50
,
genus
=
'int'
)
nifty_configuration
=
keepers
.
get_Configuration
(
name
=
'NIFTy'
,
variables
=
[
variable_fft_module
,
variable_default_field_dtype
,
variable_default_distribution_strategy
,
variable_harmonic_rg_base
],
variable_harmonic_rg_base
,
variable_threads
,
variable_loglevel
],
file_name
=
'NIFTy.conf'
,
search_paths
=
[
os
.
path
.
expanduser
(
'~'
)
+
"/.config/nifty/"
,
os
.
path
.
expanduser
(
'~'
)
+
"/.config/"
,
...
...
nifty/domain_object.py
View file @
10e3efd2
...
...
@@ -25,8 +25,8 @@ from keepers import Loggable,\
from
future.utils
import
with_metaclass
class
DomainObject
(
with_metaclass
(
NiftyMeta
,
type
(
'NewBase'
,
(
Versionable
,
Loggable
,
object
),
{}))):
class
DomainObject
(
with_metaclass
(
NiftyMeta
,
type
(
'NewBase'
,
(
Versionable
,
Loggable
,
object
),
{}))):
"""The abstract class that can be used as a domain for a field.
This holds all the information and functionality a field needs to know
...
...
nifty/library/critical_filter/critical_power_curvature.py
View file @
10e3efd2
...
...
@@ -33,6 +33,23 @@ class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator):
preconditioner
=
preconditioner
,
**
kwargs
)
def
_add_attributes_to_copy
(
self
,
copy
,
**
kwargs
):
copy
.
_domain
=
self
.
_domain
if
'theta'
in
kwargs
:
theta
=
kwargs
[
'theta'
]
copy
.
theta
=
DiagonalOperator
(
theta
.
domain
,
diagonal
=
theta
)
else
:
copy
.
theta
=
self
.
theta
.
copy
()
if
'T'
in
kwargs
:
copy
.
T
=
kwargs
[
'T'
]
else
:
copy
.
T
=
self
.
T
copy
=
super
(
CriticalPowerCurvature
,
self
).
_add_attributes_to_copy
(
copy
,
**
kwargs
)
return
copy
def
_times
(
self
,
x
,
spaces
):
return
self
.
T
(
x
)
+
self
.
theta
(
x
)
...
...
nifty/library/critical_filter/critical_power_energy.py
View file @
10e3efd2
...
...
@@ -54,7 +54,8 @@ class CriticalPowerEnergy(Energy):
# ---Overwritten properties and methods---
def
__init__
(
self
,
position
,
m
,
D
=
None
,
alpha
=
1.0
,
q
=
0.
,
smoothness_prior
=
0.
,
logarithmic
=
True
,
samples
=
3
,
w
=
None
):
smoothness_prior
=
0.
,
logarithmic
=
True
,
samples
=
3
,
w
=
None
,
old_curvature
=
None
):
super
(
CriticalPowerEnergy
,
self
).
__init__
(
position
=
position
)
self
.
m
=
m
self
.
D
=
D
...
...
@@ -66,6 +67,8 @@ class CriticalPowerEnergy(Energy):
logarithmic
=
logarithmic
)
self
.
rho
=
self
.
position
.
domain
[
0
].
rho
self
.
_w
=
w
if
w
is
not
None
else
None
self
.
_old_curvature
=
old_curvature
self
.
_curvature
=
None
# ---Mandatory properties and methods---
...
...
@@ -73,9 +76,11 @@ class CriticalPowerEnergy(Energy):
return
self
.
__class__
(
position
,
self
.
m
,
D
=
self
.
D
,
alpha
=
self
.
alpha
,
q
=
self
.
q
,
smoothness_prior
=
self
.
smoothness_prior
,
logarithmic
=
self
.
logarithmic
,
w
=
self
.
w
,
samples
=
self
.
samples
)
w
=
self
.
w
,
samples
=
self
.
samples
,
old_curvature
=
self
.
_curvature
)
@
property
@
memo
def
value
(
self
):
energy
=
self
.
_theta
.
sum
()
energy
+=
self
.
position
.
vdot
(
self
.
_rho_prime
,
bare
=
True
)
...
...
@@ -83,6 +88,7 @@ class CriticalPowerEnergy(Energy):
return
energy
.
real
@
property
@
memo
def
gradient
(
self
):
gradient
=
-
self
.
_theta
.
weight
(
-
1
)
gradient
+=
(
self
.
_rho_prime
).
weight
(
-
1
)
...
...
@@ -92,9 +98,14 @@ class CriticalPowerEnergy(Energy):
@
property
def
curvature
(
self
):
curvature
=
CriticalPowerCurvature
(
theta
=
self
.
_theta
.
weight
(
-
1
),
T
=
self
.
T
)
return
curvature
if
self
.
_curvature
is
None
:
if
self
.
_old_curvature
is
None
:
self
.
_curvature
=
CriticalPowerCurvature
(
theta
=
self
.
_theta
.
weight
(
-
1
),
T
=
self
.
T
)
else
:
self
.
_curvature
=
self
.
_old_curvature
.
copy
(
theta
=
self
.
_theta
.
weight
(
-
1
),
T
=
self
.
T
)
return
self
.
_curvature
# ---Added properties and methods---
...
...
@@ -109,9 +120,11 @@ class CriticalPowerEnergy(Energy):
@
property
def
w
(
self
):
if
self
.
_w
is
None
:
self
.
logger
.
info
(
"Initializing w"
)
w
=
Field
(
domain
=
self
.
position
.
domain
,
val
=
0.
,
dtype
=
self
.
m
.
dtype
)
if
self
.
D
is
not
None
:
for
i
in
range
(
self
.
samples
):
self
.
logger
.
info
(
"Drawing sample %i"
%
i
)
posterior_sample
=
generate_posterior_sample
(
self
.
m
,
self
.
D
)
projected_sample
=
posterior_sample
.
power_analyze
(
...
...
nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_curvature.py
View file @
10e3efd2
...
...
@@ -48,6 +48,23 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
preconditioner
=
preconditioner
,
**
kwargs
)
def
_add_attributes_to_copy
(
self
,
copy
,
**
kwargs
):
copy
.
_cache
=
{}
copy
.
_domain
=
self
.
_domain
copy
.
R
=
self
.
R
.
copy
()
copy
.
N
=
self
.
N
.
copy
()
copy
.
S
=
self
.
S
.
copy
()
copy
.
d
=
self
.
d
.
copy
()
if
'position'
in
kwargs
:
copy
.
position
=
kwargs
[
'position'
]
else
:
copy
.
position
=
self
.
position
.
copy
()
copy
.
_fft
=
self
.
_fft
copy
=
super
(
LogNormalWienerFilterCurvature
,
self
).
_add_attributes_to_copy
(
copy
,
**
kwargs
)
return
copy
@
property
def
domain
(
self
):
return
self
.
_domain
...
...
nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_energy.py
View file @
10e3efd2
...
...
@@ -24,7 +24,7 @@ class LogNormalWienerFilterEnergy(Energy):
The prior signal covariance in harmonic space.
"""
def
__init__
(
self
,
position
,
d
,
R
,
N
,
S
,
fft4exp
=
None
):
def
__init__
(
self
,
position
,
d
,
R
,
N
,
S
,
fft4exp
=
None
,
old_curvature
=
None
):
super
(
LogNormalWienerFilterEnergy
,
self
).
__init__
(
position
=
position
)
self
.
d
=
d
self
.
R
=
R
...
...
@@ -37,9 +37,13 @@ class LogNormalWienerFilterEnergy(Energy):
else
:
self
.
_fft
=
fft4exp
self
.
_old_curvature
=
old_curvature
self
.
_curvature
=
None
def
at
(
self
,
position
):
return
self
.
__class__
(
position
=
position
,
d
=
self
.
d
,
R
=
self
.
R
,
N
=
self
.
N
,
S
=
self
.
S
,
fft4exp
=
self
.
_fft
)
S
=
self
.
S
,
fft4exp
=
self
.
_fft
,
old_curvature
=
self
.
_curvature
)
@
property
@
memo
...
...
@@ -53,11 +57,20 @@ class LogNormalWienerFilterEnergy(Energy):
return
self
.
_Sp
+
self
.
_exppRNRexppd
@
property
@
memo
def
curvature
(
self
):
return
LogNormalWienerFilterCurvature
(
R
=
self
.
R
,
N
=
self
.
N
,
S
=
self
.
S
,
d
=
self
.
d
,
position
=
self
.
position
,
fft4exp
=
self
.
_fft
)
if
self
.
_curvature
is
None
:
if
self
.
_old_curvature
is
None
:
self
.
_curvature
=
LogNormalWienerFilterCurvature
(
R
=
self
.
R
,
N
=
self
.
N
,
S
=
self
.
S
,
d
=
self
.
d
,
position
=
self
.
position
,
fft4exp
=
self
.
_fft
)
else
:
self
.
_curvature
=
\
self
.
_old_curvature
.
copy
(
position
=
self
.
position
)
return
self
.
_curvature
@
property
def
_expp
(
self
):
...
...
nifty/library/wiener_filter/wiener_filter_curvature.py
View file @
10e3efd2
...
...
@@ -35,6 +35,15 @@ class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator):
preconditioner
=
preconditioner
,
**
kwargs
)
def
_add_attributes_to_copy
(
self
,
copy
,
**
kwargs
):
copy
.
_domain
=
self
.
_domain
copy
.
R
=
self
.
R
.
copy
()
copy
.
N
=
self
.
N
.
copy
()
copy
.
S
=
self
.
S
.
copy
()
copy
=
super
(
WienerFilterCurvature
,
self
).
_add_attributes_to_copy
(
copy
,
**
kwargs
)
return
copy
@
property
def
domain
(
self
):
return
self
.
_domain
...
...
nifty/library/wiener_filter/wiener_filter_energy.py
View file @
10e3efd2
...
...
@@ -23,16 +23,17 @@ class WienerFilterEnergy(Energy):
The prior signal covariance in harmonic space.
"""
def
__init__
(
self
,
position
,
d
,
R
,
N
,
S
):
def
__init__
(
self
,
position
,
d
,
R
,
N
,
S
,
old_curvature
=
None
):
super
(
WienerFilterEnergy
,
self
).
__init__
(
position
=
position
)
self
.
d
=
d
self
.
R
=
R
self
.
N
=
N
self
.
S
=
S
self
.
_curvature
=
old_curvature
def
at
(
self
,
position
):
return
self
.
__class__
(
position
=
position
,
d
=
self
.
d
,
R
=
self
.
R
,
N
=
self
.
N
,
S
=
self
.
S
)
S
=
self
.
S
,
old_curvature
=
self
.
curvature
)
@
property
@
memo
...
...
@@ -45,9 +46,12 @@ class WienerFilterEnergy(Energy):
return
self
.
_Dx
-
self
.
_j
@
property
@
memo
def
curvature
(
self
):
return
WienerFilterCurvature
(
R
=
self
.
R
,
N
=
self
.
N
,
S
=
self
.
S
)
if
self
.
_curvature
is
None
:
self
.
_curvature
=
WienerFilterCurvature
(
R
=
self
.
R
,
N
=
self
.
N
,
S
=
self
.
S
)
return
self
.
_curvature
@
property
@
memo
...
...
nifty/minimization/conjugate_gradient.py
View file @
10e3efd2
...
...
@@ -132,6 +132,9 @@ class ConjugateGradient(Loggable, object):
iteration_number
=
1
self
.
logger
.
info
(
"Starting conjugate gradient."
)
beta
=
np
.
inf
delta
=
np
.
inf
while
True
:
if
self
.
callback
is
not
None
:
self
.
callback
(
x
,
iteration_number
)
...
...
@@ -140,7 +143,10 @@ class ConjugateGradient(Loggable, object):
alpha
=
previous_gamma
/
d
.
vdot
(
q
).
real
if
not
np
.
isfinite
(
alpha
):
self
.
logger
.
error
(
"Alpha became infinite! Stopping."
)
self
.
logger
.
error
(
"Alpha became infinite! Stopping. Iteration : %08u "
"alpha = %3.1E beta = %3.1E delta = %3.1E"
%
(
iteration_number
,
alpha
,
beta
,
delta
))
return
x0
,
0
x
+=
d
*
alpha
...
...
@@ -174,21 +180,30 @@ class ConjugateGradient(Loggable, object):
if
gamma
==
0
:
convergence
=
self
.
convergence_level
+
1
self
.
logger
.
info
(
"Reached infinite convergence."
)
self
.
logger
.
info
(
"Reached infinite convergence. Iteration : %08u "
"alpha = %3.1E beta = %3.1E delta = %3.1E"
%
(
iteration_number
,
alpha
,
beta
,
delta
))
break
elif
abs
(
delta
)
<
self
.
convergence_tolerance
:
convergence
+=
1
self
.
logger
.
info
(
"Updated convergence level to: %u"
%
convergence
)
if
convergence
==
self
.
convergence_level
:
self
.
logger
.
info
(
"Reached target convergence level."
)
self
.
logger
.
info
(
"Reached target convergence level. Iteration : %08u "
"alpha = %3.1E beta = %3.1E delta = %3.1E"
%
(
iteration_number
,
alpha
,
beta
,
delta
))
break
else
:
convergence
=
max
(
0
,
convergence
-
1
)
if
self
.
iteration_limit
is
not
None
:
if
iteration_number
==
self
.
iteration_limit
:
self
.
logger
.
warn
(
"Reached iteration limit. Stopping."
)
self
.
logger
.
info
(
"Reached iteration limit. Iteration : %08u "
"alpha = %3.1E beta = %3.1E delta = %3.1E"
%
(
iteration_number
,
alpha
,
beta
,
delta
))
break
d
=
s
+
d
*
beta
...
...
nifty/minimization/line_searching/line_search_strong_wolfe.py
View file @
10e3efd2
...
...
@@ -127,10 +127,13 @@ class LineSearchStrongWolfe(LineSearch):
alpha1
=
1.0
/
pk
.
norm
()
# start the minimization loop
for
i
in
range
(
self
.
max_iterations
):
iteration_number
=
0
while
iteration_number
<
self
.
max_iterations
:
iteration_number
+=
1
if
alpha1
==
0
:
self
.
logger
.
warn
(
"Increment size became 0."
)
return
le_0
.
energy
result_energy
=
le_0
.
energy
break
le_alpha1
=
le_0
.
at
(
alpha1
)
phi_alpha1
=
le_alpha1
.
value
...
...
@@ -140,31 +143,37 @@ class LineSearchStrongWolfe(LineSearch):
le_star
=
self
.
_zoom
(
alpha0
,
alpha1
,
phi_0
,
phiprime_0
,
phi_alpha0
,
phiprime_alpha0
,
phi_alpha1
,
le_0
)
return
le_star
.
energy
result_energy
=
le_star
.
energy
break
phiprime_alpha1
=
le_alpha1
.
directional_derivative
if
abs
(
phiprime_alpha1
)
<=
-
self
.
c2
*
phiprime_0
:
return
le_alpha1
.
energy
result_energy
=
le_alpha1
.
energy
break
if
phiprime_alpha1
>=
0
:
le_star
=
self
.
_zoom
(
alpha1
,
alpha0
,
phi_0
,
phiprime_0
,
phi_alpha1
,
phiprime_alpha1
,
phi_alpha0
,
le_0
)
return
le_star
.
energy
result_energy
=
le_star
.
energy
break
# update alphas
alpha0
,
alpha1
=
alpha1
,
min
(
2
*
alpha1
,
self
.
max_step_size
)
if
alpha1
==
self
.
max_step_size
:
print
(
"
r
eached max step size, bailing out"
)
self
.
logger
.
info
(
"
R
eached max step size, bailing out"
)
return
le_alpha1
.
energy
phi_alpha0
=
phi_alpha1
phiprime_alpha0
=
phiprime_alpha1
else
:
# max_iterations was reached
self
.
logger
.
error
(
"The line search algorithm did not converge."
)
return
le_alpha1
.
energy
if
iteration_number
>
1
:
self
.
logger
.
debug
(
"Finished line-search after %08u steps"
%
iteration_number
)
return
result_energy
def
_zoom
(
self
,
alpha_lo
,
alpha_hi
,
phi_0
,
phiprime_0
,
phi_lo
,
phiprime_lo
,
phi_hi
,
le_0
):
...
...
nifty/operators/composed_operator/composed_operator.py
View file @
10e3efd2
...
...
@@ -91,6 +91,14 @@ class ComposedOperator(LinearOperator):
"instances of the LinearOperator-baseclass"
)
self
.
_operator_store
+=
(
op
,)
def
_add_attributes_to_copy
(
self
,
copy
,
**
kwargs
):
copy
.
_operator_store
=
()
for
op
in
self
.
_operator_store
:
copy
.
_operator_store
+=
(
op
.
copy
(),)
copy
=
super
(
ComposedOperator
,
self
).
_add_attributes_to_copy
(
copy
,
**
kwargs
)
return
copy
def
_check_input_compatibility
(
self
,
x
,
spaces
,
inverse
=
False
):
"""
The input check must be disabled for the ComposedOperator, since it
...
...
nifty/operators/diagonal_operator/diagonal_operator.py
View file @
10e3efd2
...
...
@@ -117,8 +117,20 @@ class DiagonalOperator(EndomorphicOperator):
distribution_strategy
=
distribution_strategy
,
val
=
diagonal
)
self
.
_self_adjoint
=
None
self
.
_unitary
=
None
self
.
set_diagonal
(
diagonal
=
diagonal
,
bare
=
bare
,
copy
=
copy
)
def
_add_attributes_to_copy
(
self
,
copy
,
**
kwargs
):
copy
.
_domain
=
self
.
_domain
copy
.
_distribution_strategy
=
self
.
_distribution_strategy
copy
.
set_diagonal
(
diagonal
=
self
.
diagonal
(
bare
=
True
),
bare
=
True
)
copy
.
_self_adjoint
=
self
.
_self_adjoint
copy
.
_unitary
=
self
.
_unitary
copy
=
super
(
DiagonalOperator
,
self
).
_add_attributes_to_copy
(
copy
,
**
kwargs
)
return
copy
def
_times
(
self
,
x
,
spaces
):
return
self
.
_times_helper
(
x
,
spaces
,
operation
=
lambda
z
:
z
.
__mul__
)
...
...
@@ -127,7 +139,8 @@ class DiagonalOperator(EndomorphicOperator):
operation
=
lambda
z
:
z
.
adjoint
().
__mul__
)
def
_inverse_times
(
self
,
x
,
spaces
):
return
self
.
_times_helper
(
x
,
spaces
,
operation
=
lambda
z
:
z
.
__rtruediv__
)
return
self
.
_times_helper
(
x
,
spaces
,
operation
=
lambda
z
:
z
.
__rtruediv__
)
def
_adjoint_inverse_times
(
self
,
x
,
spaces
):
return
self
.
_times_helper
(
x
,
spaces
,
...
...
nifty/operators/fft_operator/fft_operator.py
View file @
10e3efd2
...
...
@@ -147,6 +147,17 @@ class FFTOperator(LinearOperator):
self
.
target_dtype
=
\
None
if
target_dtype
is
None
else
np
.
dtype
(
target_dtype
)
def
_add_attributes_to_copy
(
self
,
copy
,
**
kwargs
):
copy
.
_domain
=
self
.
_domain
copy
.
_target
=
self
.
_target
copy
.
_forward_transformation
=
self
.
_forward_transformation
copy
.
_backward_transformation
=
self
.
_backward_transformation
copy
.
domain_dtype
=
self
.
domain_dtype
copy
.
target_dtype
=
self
.
target_dtype
copy
=
super
(
FFTOperator
,
self
).
_add_attributes_to_copy
(
copy
,
**
kwargs
)
return
copy
def
_times
(
self
,
x
,
spaces
):
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
x
.
domain
))
if
spaces
is
None
:
...
...
nifty/operators/fft_operator/transformations/rg_transforms.py
View file @
10e3efd2
...
...
@@ -23,6 +23,7 @@ import warnings
import
numpy
as
np
from
d2o
import
distributed_data_object
,
STRATEGIES
from
....config
import
dependency_injector
as
gdi
from
....config
import
nifty_configuration
as
gc
from
....
import
nifty_utilities
as
utilities
from
keepers
import
Loggable
...
...
@@ -337,7 +338,8 @@ class FFTWMPITransfromInfo(FFTWTransformInfo):
input_dtype
=
'complex128'
,
output_dtype
=
'complex128'
,
direction
=
'FFTW_FORWARD'
if
codomain
.
harmonic
else
'FFTW_BACKWARD'
,
flags
=
[
"FFTW_ESTIMATE"
],
flags
=
[
'FFTW_ESTIMATE'
],
threads
=
gc
[
'threads'
],
**
kwargs
)
...
...
nifty/operators/invertible_operator_mixin/invertible_operator_mixin.py
View file @
10e3efd2
...
...
@@ -73,6 +73,23 @@ class InvertibleOperatorMixin(object):
self
.
__backward_x0
=
backward_x0
super
(
InvertibleOperatorMixin
,
self
).
__init__
(
*
args
,
**
kwargs
)
def
_add_attributes_to_copy
(
self
,
copy
,
**
kwargs
):
copy
.
__preconditioner
=
self
.
__preconditioner
copy
.
__inverter
=
self
.
__inverter
try
:
copy
.
__forward_x0
=
self
.
__forward_x0
.
copy
()
except
AttributeError
:
copy
.
__forward_x0
=
self
.
__forward_x0
try
:
copy
.
__backward_x0
=
self
.
__backward_x0
.
copy
()
except
AttributeError
:
copy
.
__backward_x0
=
self
.
__backward_x0
copy
=
super
(
InvertibleOperatorMixin
,
self
).
_add_attributes_to_copy
(
copy
,
**
kwargs
)
return
copy
def
_times
(
self
,
x
,
spaces
):
if
self
.
__forward_x0
is
not
None
:
x0
=
self
.
__forward_x0
...
...
nifty/operators/laplace_operator/laplace_operator.py
View file @
10e3efd2
...
...
@@ -64,6 +64,15 @@ class LaplaceOperator(EndomorphicOperator):
self
.
_dposc
[
1
:]
+=
self
.
_dpos
self
.
_dposc
*=
0.5
def
_add_attributes_to_copy
(
self
,
copy
,
**
kwargs
):
copy
.
_domain
=
self
.
_domain
copy
.
_logarithmic
=
self
.
_logarithmic
copy
.
_dpos
=
self
.
_dpos
copy
.
_dposc
=
self
.
_dposc
copy
=
super
(
LaplaceOperator
,
self
).
_add_attributes_to_copy
(
copy
,
**
kwargs
)
return
copy
@
property
def
target
(
self
):
return
self
.
_domain
...
...
nifty/operators/linear_operator/linear_operator.py
View file @
10e3efd2
...
...
@@ -20,13 +20,15 @@ from builtins import str
import
abc
from
...nifty_meta
import
NiftyMeta
from
keepers
import
Loggable
from
keepers
import
Loggable
,
\
Versionable
from
...field
import
Field
from
...
import
nifty_utilities
as
utilities
from
future.utils
import
with_metaclass
class
LinearOperator
(
with_metaclass
(
NiftyMeta
,
type
(
'NewBase'
,
(
Loggable
,
object
),
{}))):
class
LinearOperator
(
with_metaclass
(
NiftyMeta
,
type
(
'NewBase'
,
(
Versionable
,
Loggable
,
object
),
{}))):
"""NIFTY base class for linear operators.
The base NIFTY operator class is an abstract class from which
...
...
@@ -75,6 +77,20 @@ class LinearOperator(with_metaclass(NiftyMeta, type('NewBase', (Loggable, object
def
__init__
(
self
,
default_spaces
=
None
):
self
.
_default_spaces
=
default_spaces
def
copy
(
self
,
**
kwargs
):
class
EmptyCopy
(
self
.
__class__
):
def
__init__
(
self
):
pass
result
=
EmptyCopy
()
result
.
__class__
=
self
.
__class__
result
=
self
.
_add_attributes_to_copy
(
result
,
**
kwargs
)
return
result
def
_add_attributes_to_copy
(
self
,
copy
,
**
kwargs
):
copy
.
_default_spaces
=
self
.
default_spaces
return
copy
@
staticmethod
def
_parse_domain
(
domain
):
return
utilities
.
parse_domain
(
domain
)
...
...
nifty/operators/projection_operator/projection_operator.py
View file @
10e3efd2
...
...
@@ -87,6 +87,13 @@ class ProjectionOperator(EndomorphicOperator):
self
.
_projection_field
=
projection_field
self
.
_unitary
=
None
def
_add_attributes_to_copy
(
self
,
copy
,
**
kwargs
):
copy
.
_projection_field
=
self
.
_projection_field
copy
.
_unitary
=
self
.
_unitary
copy
=
super
(
ProjectionOperator
,
self
).
_add_attributes_to_copy
(
copy
,
**
kwargs
)
return
copy
def
_times
(
self
,
x
,
spaces
):
# if the domain matches directly
# -> multiply the fields directly
...
...
nifty/operators/response_operator/response_operator.py
View file @
10e3efd2
...
...
@@ -82,7 +82,7 @@ class ResponseOperator(LinearOperator):
for
ii
in
range
(
len
(
kernel_smoothing
)):
kernel_smoothing
[
ii
]
=
SmoothingOperator
.
make
(
self
.
_domain
[
ii
],
sigma
=
sigma
[
ii
])
sigma
=
sigma
[
ii
])
kernel_exposure
[
ii
]
=
DiagonalOperator
(
self
.
_domain
[
ii
],
diagonal
=
exposure
[
ii
])
...
...
@@ -95,6 +95,15 @@ class ResponseOperator(LinearOperator):
self
.
_target
=
self
.
_parse_domain
(
target_list
)
def
_add_attributes_to_copy
(
self
,
copy
,
**
kwargs
):
copy
.
_domain
=
self
.
_domain
copy
.
_target
=
self
.
_target
copy
.