Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
65d6f135
Commit
65d6f135
authored
Feb 12, 2018
by
Martin Reinecke
Browse files
tweak diagonal sampling/probing
parent
6bd3a182
Pipeline
#24787
passed with stage
in 6 minutes and 4 seconds
Changes
5
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
demos/paper_demos/cartesian_wiener_filter.py
View file @
65d6f135
...
...
@@ -106,6 +106,6 @@ if __name__ == "__main__":
ift
.
plot
(
ift
.
Field
(
plot_space
,
val
=
data
.
val
),
name
=
'data.png'
,
**
plotdict
)
ift
.
plot
(
ift
.
Field
(
plot_space
,
val
=
m
.
val
),
name
=
'map.png'
,
**
plotdict
)
# sampling the uncertainty map
mean
,
variance
=
ift
.
probe_with_posterior_samples
(
wiener_curvature
,
m_k
,
ht
,
10
)
mean
,
variance
=
ift
.
probe_with_posterior_samples
(
wiener_curvature
,
ht
,
10
)
ift
.
plot
(
ift
.
Field
(
plot_space
,
val
=
ift
.
sqrt
(
variance
).
val
),
name
=
"uncertainty.png"
,
**
plotdict
)
ift
.
plot
(
ift
.
Field
(
plot_space
,
val
=
mean
.
val
),
name
=
"posterior_mean.png"
,
**
plotdict
)
ift
.
plot
(
ift
.
Field
(
plot_space
,
val
=
(
mean
+
m
)
.
val
),
name
=
"posterior_mean.png"
,
**
plotdict
)
demos/paper_demos/wiener_filter.py
View file @
65d6f135
...
...
@@ -66,6 +66,6 @@ if __name__ == "__main__":
ift
.
plot
(
m
,
name
=
"map.png"
,
**
plotdict
)
# sampling the uncertainty map
mean
,
variance
=
ift
.
probe_with_posterior_samples
(
wiener_curvature
,
m_k
,
ht
,
5
)
mean
,
variance
=
ift
.
probe_with_posterior_samples
(
wiener_curvature
,
ht
,
5
)
ift
.
plot
(
ift
.
sqrt
(
variance
),
name
=
"uncertainty.png"
,
**
plotdict
)
ift
.
plot
(
mean
,
name
=
"posterior_mean.png"
,
**
plotdict
)
ift
.
plot
(
mean
+
m
,
name
=
"posterior_mean.png"
,
**
plotdict
)
demos/wiener_filter_via_hamiltonian.py
View file @
65d6f135
...
...
@@ -78,6 +78,11 @@ if __name__ == "__main__":
sample_variance
=
ift
.
Field
.
zeros
(
s_space
)
sample_mean
=
ift
.
Field
.
zeros
(
s_space
)
mean
,
variance
=
ift
.
probe_with_posterior_samples
(
curv
,
m
,
ht
,
50
)
mean
,
variance
=
ift
.
probe_with_posterior_samples
(
curv
,
ht
,
50
)
ift
.
plot
(
variance
,
name
=
"posterior_variance.png"
,
**
plotdict
)
ift
.
plot
(
mean
,
name
=
"posterior_mean.png"
,
**
plotdict
)
ift
.
plot
(
mean
+
ht
(
m
),
name
=
"posterior_mean.png"
,
**
plotdict
)
# try to do the same with diagonal probing
variance
=
ift
.
probe_diagonal
(
ht
*
curv
.
inverse
*
ht
.
adjoint
,
100
)
#sm = ift.FFTSmoothingOperator(s_space, sigma=0.015)
ift
.
plot
(
variance
,
name
=
"posterior_variance2.png"
,
**
plotdict
)
nifty4/__init__.py
View file @
65d6f135
...
...
@@ -32,7 +32,7 @@ from .field import Field, sqrt, exp, log
from
.probing.prober
import
Prober
from
.probing.diagonal_prober_mixin
import
DiagonalProberMixin
from
.probing.trace_prober_mixin
import
TraceProberMixin
from
.probing.utils
import
probe_with_posterior_samples
from
.probing.utils
import
probe_with_posterior_samples
,
probe_diagonal
from
.minimization.line_search
import
LineSearch
from
.minimization.line_search_strong_wolfe
import
LineSearchStrongWolfe
...
...
nifty4/probing/utils.py
View file @
65d6f135
...
...
@@ -11,13 +11,13 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-201
7
Max-Planck-Society
# Copyright(C) 2013-201
8
Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
builtins
import
object
from
..field
import
Field
class
StatCalculator
(
object
):
def
__init__
(
self
):
...
...
@@ -47,12 +47,21 @@ class StatCalculator(object):
return
self
.
_M2
*
(
1.
/
(
self
.
_count
-
1
))
def
probe_with_posterior_samples
(
op
,
m
,
post_op
,
nprobes
):
def
probe_with_posterior_samples
(
op
,
post_op
,
nprobes
):
sc
=
StatCalculator
()
for
i
in
range
(
nprobes
):
sample
=
post_op
(
op
.
draw_sample
()
+
m
)
sample
=
post_op
(
op
.
draw_sample
())
sc
.
add
(
sample
)
if
nprobes
==
1
:
return
sc
.
mean
,
None
return
sc
.
mean
,
sc
.
var
def
probe_diagonal
(
op
,
nprobes
,
random_type
=
"normal"
):
sc
=
StatCalculator
()
for
i
in
range
(
nprobes
):
input
=
Field
.
from_random
(
random_type
,
op
.
domain
)
output
=
op
(
input
)
sc
.
add
(
output
.
conjugate
()
*
input
)
return
sc
.
mean
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment