Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
e6de9466
Commit
e6de9466
authored
Nov 18, 2017
by
Martin Reinecke
Browse files
performance tweaks
parent
b6ebd0ad
Pipeline
#21872
passed with stage
in 4 minutes and 16 seconds
Changes
5
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty/field.py
View file @
e6de9466
...
...
@@ -417,6 +417,11 @@ class Field(object):
return
self
.
_contraction_helper
(
'sum'
,
spaces
)
def
integrate
(
self
,
spaces
=
None
):
swgt
=
self
.
scalar_weight
(
spaces
)
if
swgt
is
not
None
:
res
=
self
.
sum
(
spaces
)
res
*=
swgt
return
res
tmp
=
self
.
weight
(
1
,
spaces
=
spaces
)
return
tmp
.
sum
(
spaces
)
...
...
nifty/library/critical_power_curvature.py
View file @
e6de9466
...
...
@@ -9,7 +9,6 @@ class CriticalPowerCurvature(EndomorphicOperator):
CriticalPowerEnergy used in some minimization algorithms or
for error estimates of the power spectrum.
Parameters
----------
theta: Field,
...
...
@@ -20,19 +19,19 @@ class CriticalPowerCurvature(EndomorphicOperator):
def
__init__
(
self
,
theta
,
T
):
super
(
CriticalPowerCurvature
,
self
).
__init__
()
self
.
theta
=
DiagonalOperator
(
theta
)
self
.
T
=
T
self
.
_
theta
=
DiagonalOperator
(
theta
)
self
.
_
T
=
T
@
property
def
preconditioner
(
self
):
return
self
.
theta
.
inverse_times
return
self
.
_
theta
.
inverse_times
def
_times
(
self
,
x
):
return
self
.
T
(
x
)
+
self
.
theta
(
x
)
return
self
.
_
T
(
x
)
+
self
.
_
theta
(
x
)
@
property
def
domain
(
self
):
return
self
.
theta
.
domain
return
self
.
_
theta
.
domain
@
property
def
self_adjoint
(
self
):
...
...
nifty/library/critical_power_energy.py
View file @
e6de9466
...
...
@@ -52,8 +52,6 @@ class CriticalPowerEnergy(Energy):
default : None
"""
# ---Overwritten properties and methods---
def
__init__
(
self
,
position
,
m
,
D
=
None
,
alpha
=
1.0
,
q
=
0.
,
smoothness_prior
=
0.
,
logarithmic
=
True
,
samples
=
3
,
w
=
None
,
inverter
=
None
):
...
...
@@ -61,8 +59,8 @@ class CriticalPowerEnergy(Energy):
self
.
m
=
m
self
.
D
=
D
self
.
samples
=
samples
self
.
alpha
=
Field
(
self
.
position
.
domain
,
val
=
alpha
)
self
.
q
=
Field
(
self
.
position
.
domain
,
val
=
q
)
self
.
alpha
=
float
(
alpha
)
self
.
q
=
float
(
q
)
self
.
T
=
SmoothnessOperator
(
domain
=
self
.
position
.
domain
[
0
],
strength
=
smoothness_prior
,
logarithmic
=
logarithmic
)
...
...
@@ -71,8 +69,6 @@ class CriticalPowerEnergy(Energy):
self
.
_w
=
w
self
.
_inverter
=
inverter
# ---Mandatory properties and methods---
def
at
(
self
,
position
):
return
self
.
__class__
(
position
,
self
.
m
,
D
=
self
.
D
,
alpha
=
self
.
alpha
,
q
=
self
.
q
,
smoothness_prior
=
self
.
smoothness_prior
,
...
...
@@ -83,9 +79,9 @@ class CriticalPowerEnergy(Energy):
@
property
@
memo
def
value
(
self
):
energy
=
Field
.
ones_like
(
self
.
position
).
vdot
(
self
.
_theta
)
energy
+=
self
.
position
.
vdot
(
self
.
alpha
-
0.5
)
energy
+=
0.5
*
self
.
position
.
vdot
(
self
.
_Tt
)
energy
=
self
.
_theta
.
integrate
(
)
energy
+=
self
.
position
.
integrate
()
*
(
self
.
alpha
-
0.5
)
energy
+=
0.5
*
self
.
position
.
vdot
(
self
.
_Tt
)
return
energy
.
real
@
property
...
...
@@ -116,15 +112,15 @@ class CriticalPowerEnergy(Energy):
def
w
(
self
):
if
self
.
_w
is
None
:
# self.logger.info("Initializing w")
w
=
Field
(
domain
=
self
.
position
.
domain
,
val
=
0.
,
dtype
=
self
.
m
.
dtype
)
if
self
.
D
is
not
None
:
w
=
Field
.
zeros
(
self
.
position
.
domain
,
dtype
=
self
.
m
.
dtype
)
for
i
in
range
(
self
.
samples
):
# self.logger.info("Drawing sample %i" % i)
posterior_sample
=
generate_posterior_sample
(
self
.
m
,
self
.
D
)
w
+=
self
.
P
(
abs
(
posterior_sample
)
**
2
)
w
+=
self
.
P
(
abs
(
posterior_sample
)
**
2
)
w
/
=
float
(
self
.
samples
)
w
*
=
1.
/
self
.
samples
else
:
w
=
self
.
P
(
abs
(
self
.
m
)
**
2
)
self
.
_w
=
w
...
...
@@ -133,7 +129,7 @@ class CriticalPowerEnergy(Energy):
@
property
@
memo
def
_theta
(
self
):
return
exp
(
-
self
.
position
)
*
(
self
.
q
+
self
.
w
/
2.
)
return
exp
(
-
self
.
position
)
*
(
self
.
q
+
self
.
w
*
0.5
)
@
property
@
memo
...
...
nifty/library/log_normal_wiener_filter_curvature.py
View file @
e6de9466
...
...
@@ -22,20 +22,13 @@ class LogNormalWienerFilterCurvature(EndomorphicOperator):
The prior signal covariance
"""
def
__init__
(
self
,
R
,
N
,
S
,
d
,
position
,
fft4exp
=
None
):
def
__init__
(
self
,
R
,
N
,
S
,
position
,
fft4exp
):
super
(
LogNormalWienerFilterCurvature
,
self
).
__init__
()
self
.
R
=
R
self
.
N
=
N
self
.
S
=
S
self
.
d
=
d
self
.
position
=
position
if
fft4exp
is
None
:
self
.
_fft
=
create_composed_fft_operator
(
self
.
domain
,
all_to
=
'position'
)
else
:
self
.
_fft
=
fft4exp
super
(
LogNormalWienerFilterCurvature
,
self
).
__init__
()
self
.
_fft
=
fft4exp
@
property
def
domain
(
self
):
...
...
@@ -51,33 +44,14 @@ class LogNormalWienerFilterCurvature(EndomorphicOperator):
def
_times
(
self
,
x
):
part1
=
self
.
S
.
inverse_times
(
x
)
# part2 = self._exppRNRexppd * x
part3
=
self
.
_fft
.
adjoint_times
(
self
.
_expp_sspace
*
self
.
_fft
(
x
))
part3
=
self
.
_fft
.
adjoint_times
(
self
.
_expp_sspace
*
self
.
_fft
(
self
.
R
.
adjoint_times
(
self
.
N
.
inverse_times
(
self
.
R
(
part3
)))))
return
part1
+
part3
# + part2
return
part1
+
part3
@
property
@
memo
def
_expp_sspace
(
self
):
return
exp
(
self
.
_fft
(
self
.
position
))
@
property
@
memo
def
_Rexppd
(
self
):
expp
=
self
.
_fft
.
adjoint_times
(
self
.
_expp_sspace
)
return
self
.
R
(
expp
)
-
self
.
d
@
property
@
memo
def
_NRexppd
(
self
):
return
self
.
N
.
inverse_times
(
self
.
_Rexppd
)
@
property
@
memo
def
_exppRNRexppd
(
self
):
return
self
.
_fft
.
adjoint_times
(
self
.
_expp_sspace
*
self
.
_fft
(
self
.
R
.
adjoint_times
(
self
.
_NRexppd
)))
nifty/library/log_normal_wiener_filter_energy.py
View file @
e6de9466
...
...
@@ -48,19 +48,19 @@ class LogNormalWienerFilterEnergy(Energy):
@
memo
def
value
(
self
):
return
0.5
*
(
self
.
position
.
vdot
(
self
.
_Sp
)
+
self
.
curvature
.
op
.
_Rexppd
.
vdot
(
self
.
curvature
.
op
.
_NRexppd
))
self
.
_Rexppd
.
vdot
(
self
.
_NRexppd
))
@
property
@
memo
def
gradient
(
self
):
return
self
.
_Sp
+
self
.
curvature
.
op
.
_exppRNRexppd
return
self
.
_Sp
+
self
.
_exppRNRexppd
@
property
@
memo
def
curvature
(
self
):
return
InversionEnabler
(
LogNormalWienerFilterCurvature
(
R
=
self
.
R
,
N
=
self
.
N
,
S
=
self
.
S
,
d
=
self
.
d
,
position
=
self
.
position
,
position
=
self
.
position
,
fft4exp
=
self
.
_fft
),
inverter
=
self
.
_inverter
)
...
...
@@ -68,3 +68,21 @@ class LogNormalWienerFilterEnergy(Energy):
@
memo
def
_Sp
(
self
):
return
self
.
S
.
inverse_times
(
self
.
position
)
@
property
@
memo
def
_Rexppd
(
self
):
expp
=
self
.
_fft
.
adjoint_times
(
self
.
curvature
.
op
.
_expp_sspace
)
return
self
.
R
(
expp
)
-
self
.
d
@
property
@
memo
def
_NRexppd
(
self
):
return
self
.
N
.
inverse_times
(
self
.
_Rexppd
)
@
property
@
memo
def
_exppRNRexppd
(
self
):
return
self
.
_fft
.
adjoint_times
(
self
.
curvature
.
op
.
_expp_sspace
*
self
.
_fft
(
self
.
R
.
adjoint_times
(
self
.
_NRexppd
)))
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment