Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
e9a5b0f1
Commit
e9a5b0f1
authored
Apr 01, 2018
by
Martin Reinecke
Browse files
make inverse_draw_sample() largely obsolete
parent
ad129166
Pipeline
#26707
passed with stage
in 8 minutes and 27 seconds
Changes
7
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty4/operators/diagonal_operator.py
View file @
e9a5b0f1
...
...
@@ -174,21 +174,14 @@ class DiagonalOperator(EndomorphicOperator):
raise
ValueError
(
"bad operator flipping mode"
)
return
res
def
draw_sample
(
self
,
dtype
=
np
.
float64
):
def
draw_sample
(
self
,
from_inverse
=
False
,
dtype
=
np
.
float64
):
if
(
np
.
issubdtype
(
self
.
_ldiag
.
dtype
,
np
.
complexfloating
)
or
(
self
.
_ldiag
<=
0.
).
any
()):
raise
ValueError
(
"operator not positive definite"
)
res
=
Field
.
from_random
(
random_type
=
"normal"
,
domain
=
self
.
_domain
,
dtype
=
dtype
)
res
.
local_data
[()]
*=
np
.
sqrt
(
self
.
_ldiag
)
return
res
def
inverse_draw_sample
(
self
,
dtype
=
np
.
float64
):
if
(
np
.
issubdtype
(
self
.
_ldiag
.
dtype
,
np
.
complexfloating
)
or
(
self
.
_ldiag
<=
0.
).
any
()):
raise
ValueError
(
"operator not positive definite"
)
res
=
Field
.
from_random
(
random_type
=
"normal"
,
domain
=
self
.
_domain
,
dtype
=
dtype
)
res
.
local_data
[()]
/=
np
.
sqrt
(
self
.
_ldiag
)
if
from_inverse
:
res
.
local_data
[()]
/=
np
.
sqrt
(
self
.
_ldiag
)
else
:
res
.
local_data
[()]
*=
np
.
sqrt
(
self
.
_ldiag
)
return
res
nifty4/operators/endomorphic_operator.py
View file @
e9a5b0f1
...
...
@@ -36,12 +36,19 @@ class EndomorphicOperator(LinearOperator):
for endomorphic operators."""
return
self
.
domain
def
draw_sample
(
self
,
dtype
=
np
.
float64
):
def
draw_sample
(
self
,
from_inverse
=
False
,
dtype
=
np
.
float64
):
"""Generate a zero-mean sample
Generates a sample from a Gaussian distribution with zero mean and
covariance given by the operator.
Parameters
----------
from_inverse : bool (default : False)
if True, the sample is drawn from the inverse of the operator
dtype : numpy datatype (default : numpy.float64)
the data type to be used for the sample
Returns
-------
Field
...
...
@@ -59,8 +66,4 @@ class EndomorphicOperator(LinearOperator):
-------
A sample from the Gaussian of given covariance
"""
if
self
.
capability
&
self
.
INVERSE_TIMES
:
x
=
self
.
draw_sample
(
dtype
)
return
self
.
inverse_times
(
x
)
else
:
raise
NotImplementedError
return
self
.
draw_sample
(
True
,
dtype
)
nifty4/operators/inversion_enabler.py
View file @
e9a5b0f1
...
...
@@ -20,11 +20,11 @@ from ..minimization.quadratic_energy import QuadraticEnergy
from
..minimization.iteration_controller
import
IterationController
from
..field
import
Field
from
..logger
import
logger
from
.
linear
_operator
import
Linear
Operator
from
.
endomorphic
_operator
import
Endomorphic
Operator
import
numpy
as
np
class
InversionEnabler
(
Linear
Operator
):
class
InversionEnabler
(
Endomorphic
Operator
):
"""Class which augments the capability of another operator object via
numerical inversion.
...
...
@@ -80,14 +80,9 @@ class InversionEnabler(LinearOperator):
logger
.
warning
(
"Error detected during operator inversion"
)
return
r
.
position
def
draw_sample
(
self
,
dtype
=
np
.
float64
):
def
draw_sample
(
self
,
from_inverse
=
False
,
dtype
=
np
.
float64
):
try
:
return
self
.
_op
.
draw_sample
(
dtype
)
return
self
.
_op
.
draw_sample
(
from_inverse
,
dtype
)
except
:
return
self
(
self
.
_op
.
inverse_draw_sample
(
dtype
))
def
inverse_draw_sample
(
self
,
dtype
=
np
.
float64
):
try
:
return
self
.
_op
.
inverse_draw_sample
(
dtype
)
except
:
return
self
.
inverse_times
(
self
.
_op
.
draw_sample
(
dtype
))
samp
=
self
.
_op
.
draw_sample
(
not
from_inverse
,
dtype
)
return
self
.
inverse_times
(
samp
)
if
from_inverse
else
self
(
samp
)
nifty4/operators/operator_adapter.py
View file @
e9a5b0f1
...
...
@@ -49,12 +49,7 @@ class OperatorAdapter(LinearOperator):
def
apply
(
self
,
x
,
mode
):
return
self
.
_op
.
apply
(
x
,
self
.
_modeTable
[
self
.
_mode
][
self
.
_ilog
[
mode
]])
def
draw_sample
(
self
,
dtype
=
np
.
float64
):
def
draw_sample
(
self
,
from_inverse
=
False
,
dtype
=
np
.
float64
):
if
self
.
_mode
&
self
.
INVERSE_BIT
:
return
self
.
_op
.
inverse_draw_sample
(
dtype
)
return
self
.
_op
.
draw_sample
(
dtype
)
def
inverse_draw_sample
(
self
,
dtype
=
np
.
float64
):
if
self
.
_mode
&
self
.
INVERSE_BIT
:
return
self
.
_op
.
draw_sample
(
dtype
)
return
self
.
_op
.
inverse_draw_sample
(
dtype
)
return
self
.
_op
.
draw_sample
(
not
from_inverse
,
dtype
)
return
self
.
_op
.
draw_sample
(
from_inverse
,
dtype
)
nifty4/operators/sandwich_operator.py
View file @
e9a5b0f1
...
...
@@ -48,5 +48,8 @@ class SandwichOperator(EndomorphicOperator):
def
apply
(
self
,
x
,
mode
):
return
self
.
_op
.
apply
(
x
,
mode
)
def
draw_sample
(
self
,
dtype
=
np
.
float64
):
return
self
.
_bun
.
adjoint_times
(
self
.
_cheese
.
draw_sample
(
dtype
))
def
draw_sample
(
self
,
from_inverse
=
False
,
dtype
=
np
.
float64
):
if
from_inverse
:
raise
ValueError
(
"cannot draw from inverse of this operator"
)
return
self
.
_bun
.
adjoint_times
(
self
.
_cheese
.
draw_sample
(
from_inverse
,
dtype
))
nifty4/operators/scaling_operator.py
View file @
e9a5b0f1
...
...
@@ -93,14 +93,10 @@ class ScalingOperator(EndomorphicOperator):
def
capability
(
self
):
return
self
.
_all_ops
def
_sample_helper
(
self
,
fct
,
dtype
):
def
draw_sample
(
self
,
from_inverse
=
False
,
dtype
=
np
.
float64
):
fct
=
self
.
_factor
if
fct
.
imag
!=
0.
or
fct
.
real
<=
0.
:
raise
ValueError
(
"operator not positive definite"
)
fct
=
1.
/
np
.
sqrt
(
fct
)
if
from_inverse
else
np
.
sqrt
(
fct
)
return
Field
.
from_random
(
random_type
=
"normal"
,
domain
=
self
.
_domain
,
std
=
fct
,
dtype
=
dtype
)
def
draw_sample
(
self
,
dtype
=
np
.
float64
):
return
self
.
_sample_helper
(
np
.
sqrt
(
self
.
_factor
),
dtype
)
def
inverse_draw_sample
(
self
,
dtype
=
np
.
float64
):
return
self
.
_sample_helper
(
1.
/
np
.
sqrt
(
self
.
_factor
),
dtype
)
nifty4/operators/sum_operator.py
View file @
e9a5b0f1
...
...
@@ -143,8 +143,10 @@ class SumOperator(LinearOperator):
res
+=
op
.
apply
(
x
,
mode
)
return
res
def
draw_sample
(
self
,
dtype
=
np
.
float64
):
res
=
self
.
_ops
[
0
].
draw_sample
(
dtype
)
def
draw_sample
(
self
,
from_inverse
=
False
,
dtype
=
np
.
float64
):
if
from_inverse
:
raise
ValueError
(
"cannot draw from inverse of this operator"
)
res
=
self
.
_ops
[
0
].
draw_sample
(
from_inverse
,
dtype
)
for
op
in
self
.
_ops
[
1
:]:
res
+=
op
.
draw_sample
(
dtype
)
res
+=
op
.
draw_sample
(
from_inverse
,
dtype
)
return
res
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment