Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
00cc61c9
Commit
00cc61c9
authored
May 13, 2020
by
Philipp Arras
Browse files
Rename to SamplingDtypeSetter and add docstrings
parent
58c2740f
Pipeline
#74862
passed with stages
in 20 minutes and 55 seconds
Changes
5
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty6/__init__.py
View file @
00cc61c9
...
...
@@ -35,7 +35,7 @@ from .operators.field_zero_padder import FieldZeroPadder
from
.operators.inversion_enabler
import
InversionEnabler
from
.operators.mask_operator
import
MaskOperator
from
.operators.regridding_operator
import
RegriddingOperator
from
.operators.sampling_enabler
import
SamplingEnabler
from
.operators.sampling_enabler
import
SamplingEnabler
,
SamplingDtypeSetter
from
.operators.sandwich_operator
import
SandwichOperator
from
.operators.scaling_operator
import
ScalingOperator
from
.operators.block_diagonal_operator
import
BlockDiagonalOperator
...
...
nifty6/library/wiener_filter_curvature.py
View file @
00cc61c9
...
...
@@ -11,12 +11,12 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-20
19
Max-Planck-Society
# Copyright(C) 2013-20
20
Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from
..operators.inversion_enabler
import
InversionEnabler
from
..operators.sampling_enabler
import
SamplingEnabler
from
..operators.sampling_enabler
import
SamplingDtypeSetter
,
SamplingEnabler
from
..operators.sandwich_operator
import
SandwichOperator
...
...
@@ -48,11 +48,9 @@ def WienerFilterCurvature(R, N, S, iteration_controller=None,
Ninv
=
N
.
inverse
Sinv
=
S
.
inverse
if
data_sampling_dtype
is
not
None
:
from
..operators.energy_operators
import
SamplingDtypeEnabler
Ninv
=
SamplingDtypeEnabler
(
Ninv
,
data_sampling_dtype
)
Ninv
=
SamplingDtypeSetter
(
Ninv
,
data_sampling_dtype
)
if
prior_sampling_dtype
is
not
None
:
from
..operators.energy_operators
import
SamplingDtypeEnabler
Sinv
=
SamplingDtypeEnabler
(
Sinv
,
data_sampling_dtype
)
Sinv
=
SamplingDtypeSetter
(
Sinv
,
data_sampling_dtype
)
M
=
SandwichOperator
.
make
(
R
,
Ninv
)
if
iteration_controller_sampling
is
not
None
:
op
=
SamplingEnabler
(
M
,
Sinv
,
iteration_controller_sampling
,
...
...
nifty6/operators/endomorphic_operator.py
View file @
00cc61c9
...
...
@@ -30,23 +30,47 @@ class EndomorphicOperator(LinearOperator):
for endomorphic operators."""
return
self
.
_domain
def
draw_sample
(
self
,
from_inverse
=
False
):
"""Generates a sample from a Gaussian distribution with zero mean and
covariance given by the operator.
May or may not be implemented. Only optional.
Parameters
----------
from_inverse : bool (default : False)
if True, the sample is drawn from the inverse of the operator
Returns
-------
Field or MultiField
A sample from the Gaussian of given covariance.
"""
raise
NotImplementedError
def
draw_sample_with_dtype
(
self
,
dtype
,
from_inverse
=
False
):
"""Generate a
zero-mean sample
FIXME
"""Generate
s
a
sample from a Gaussian distribution with zero mean,
covariance given by the operator and specified data type.
Generates a sample from a Gaussian distribution with zero mean and
covariance given by the operator.
This method is implemented only for operators which actually draw
samples (e.g. `DiagonalOperator`). Operators which process the sample
(like `SandwichOperator`) implement only `draw_sample()`.
May or may not be implemented. Only optional.
Parameters
----------
dtype : numpy datatype FIXME
the data type to be used for the sample
dtype : numpy.dtype or dict of numpy.dtype
Dtype used for sampling from this operator. If the domain of `op`
is a `MultiDomain`, the dtype can either be specified as one value
for all components of the `MultiDomain` or in form of a dictionary
whose keys need to conincide the with keys of the `MultiDomain`.
from_inverse : bool (default : False)
if True, the sample is drawn from the inverse of the operator
Returns
-------
Field
Field
or MultiField
A sample from the Gaussian of given covariance.
"""
raise
NotImplementedError
...
...
nifty6/operators/energy_operators.py
View file @
00cc61c9
...
...
@@ -25,10 +25,9 @@ from ..multi_field import MultiField
from
..sugar
import
makeDomain
,
makeOp
from
.linear_operator
import
LinearOperator
from
.operator
import
Operator
from
.sampling_enabler
import
SamplingEnabler
from
.sampling_enabler
import
SamplingDtypeSetter
,
SamplingEnabler
from
.scaling_operator
import
ScalingOperator
from
.simple_linear_operators
import
VdotOperator
from
.endomorphic_operator
import
EndomorphicOperator
def
_check_sampling_dtype
(
domain
,
dtypes
):
...
...
@@ -60,43 +59,6 @@ def _field_to_dtype(field):
return
dt
class
SamplingDtypeEnabler
(
EndomorphicOperator
):
def
__init__
(
self
,
endomorphic_operator
,
dtype
):
if
not
isinstance
(
endomorphic_operator
,
EndomorphicOperator
):
raise
TypeError
if
not
hasattr
(
endomorphic_operator
,
'draw_sample_with_dtype'
):
raise
TypeError
dom
=
endomorphic_operator
.
domain
if
isinstance
(
dom
,
MultiDomain
):
if
dtype
in
[
np
.
float64
,
np
.
complex128
]:
dtype
=
{
kk
:
dtype
for
kk
in
dom
.
keys
()}
if
set
(
dtype
.
keys
())
!=
set
(
dom
.
keys
()):
raise
TypeError
self
.
_dtype
=
dtype
self
.
_domain
=
dom
self
.
_capability
=
endomorphic_operator
.
_capability
self
.
apply
=
endomorphic_operator
.
apply
self
.
_op
=
endomorphic_operator
def
draw_sample
(
self
,
from_inverse
=
False
):
"""Generate a zero-mean sample
Generates a sample from a Gaussian distribution with zero mean and
covariance given by the operator.
Parameters
----------
from_inverse : bool (default : False)
if True, the sample is drawn from the inverse of the operator
Returns
-------
Field
A sample from the Gaussian of given covariance.
"""
return
self
.
_op
.
draw_sample_with_dtype
(
self
.
_dtype
,
from_inverse
=
from_inverse
)
class
EnergyOperator
(
Operator
):
"""Operator which has a scalar domain as target domain.
...
...
@@ -199,7 +161,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
return
res
mf
=
{
self
.
_r
:
x
.
val
[
self
.
_icov
],
self
.
_icov
:
.
5
*
x
.
val
[
self
.
_icov
]
**
(
-
2
)}
met
=
makeOp
(
MultiField
.
from_dict
(
mf
))
return
res
.
add_metric
(
SamplingDtype
Enabl
er
(
met
,
self
.
_sampling_dtype
))
return
res
.
add_metric
(
SamplingDtype
Sett
er
(
met
,
self
.
_sampling_dtype
))
class
GaussianEnergy
(
EnergyOperator
):
...
...
@@ -262,7 +224,7 @@ class GaussianEnergy(EnergyOperator):
self
.
_op
=
QuadraticFormOperator
(
inverse_covariance
)
self
.
_met
=
inverse_covariance
if
sampling_dtype
is
not
None
:
self
.
_met
=
SamplingDtype
Enabl
er
(
self
.
_met
,
sampling_dtype
)
self
.
_met
=
SamplingDtype
Sett
er
(
self
.
_met
,
sampling_dtype
)
def
_checkEquivalence
(
self
,
newdom
):
newdom
=
makeDomain
(
newdom
)
...
...
@@ -313,7 +275,7 @@ class PoissonianEnergy(EnergyOperator):
res
=
x
.
sum
()
-
x
.
ptw
(
"log"
).
vdot
(
self
.
_d
)
if
not
x
.
want_metric
:
return
res
return
res
.
add_metric
(
SamplingDtype
Enabl
er
(
makeOp
(
1.
/
x
.
val
),
np
.
float64
))
return
res
.
add_metric
(
SamplingDtype
Sett
er
(
makeOp
(
1.
/
x
.
val
),
np
.
float64
))
class
InverseGammaLikelihood
(
EnergyOperator
):
...
...
@@ -359,7 +321,7 @@ class InverseGammaLikelihood(EnergyOperator):
return
res
met
=
makeOp
(
self
.
_alphap1
/
(
x
.
val
**
2
))
if
self
.
_sampling_dtype
is
not
None
:
met
=
SamplingDtype
Enabl
er
(
met
,
self
.
_sampling_dtype
)
met
=
SamplingDtype
Sett
er
(
met
,
self
.
_sampling_dtype
)
return
res
.
add_metric
(
met
)
...
...
@@ -394,7 +356,7 @@ class StudentTEnergy(EnergyOperator):
return
res
met
=
makeOp
((
self
.
_theta
+
1
)
/
(
self
.
_theta
+
3
),
self
.
domain
)
if
self
.
_sampling_dtype
is
not
None
:
met
=
SamplingDtype
Enabl
er
(
met
,
self
.
_sampling_dtype
)
met
=
SamplingDtype
Sett
er
(
met
,
self
.
_sampling_dtype
)
return
res
.
add_metric
(
met
)
...
...
@@ -429,7 +391,7 @@ class BernoulliEnergy(EnergyOperator):
if
not
x
.
want_metric
:
return
res
met
=
makeOp
(
1.
/
(
x
.
val
*
(
1.
-
x
.
val
)))
return
res
.
add_metric
(
SamplingDtype
Enabl
er
(
met
,
np
.
float64
))
return
res
.
add_metric
(
SamplingDtype
Sett
er
(
met
,
np
.
float64
))
class
StandardHamiltonian
(
EnergyOperator
):
...
...
nifty6/operators/sampling_enabler.py
View file @
00cc61c9
...
...
@@ -19,6 +19,7 @@ import numpy as np
from
..minimization.conjugate_gradient
import
ConjugateGradient
from
..minimization.quadratic_energy
import
QuadraticEnergy
from
..multi_domain
import
MultiDomain
from
.endomorphic_operator
import
EndomorphicOperator
from
.operator
import
Operator
...
...
@@ -96,3 +97,52 @@ class SamplingEnabler(EndomorphicOperator):
indent
(
"
\n
"
.
join
((
"Likelihood:"
,
self
.
_likelihood
.
__repr__
(),
"Prior:"
,
self
.
_prior
.
__repr__
())))))
class
SamplingDtypeSetter
(
EndomorphicOperator
):
"""Class that adds the information whether the operator at hand is the
covariance of a real-valued Gaussian or a complex-valued Gaussian
probability distribution.
This wrapper class shall address the following ambiguity which arises when
drawing a sampling from a Gaussian distribution with zero mean and given
covariance. E.g. a `ScalingOperator` with `1.` on its diagonal can be
viewed as the covariance operator of both a real-valued and complex-valued
Gaussian distribution. `SamplingDtypeSetter` specifies this data type.
Parameters
----------
op : EndomorphicOperator
Operator which shall be supplemented with a dtype for sampling. Needs
to be positive definite, hermitian and needs to implement the method
`draw_sample_with_dtype()`. Note that these three properties are not
checked in the constructor.
dtype : numpy.dtype or dict of numpy.dtype
Dtype used for sampling from this operator. If the domain of `op` is a
`MultiDomain`, the dtype can either be specified as one value for all
components of the `MultiDomain` or in form of a dictionary whose keys
need to conincide the with keys of the `MultiDomain`.
"""
def
__init__
(
self
,
op
,
dtype
):
if
not
isinstance
(
op
,
EndomorphicOperator
):
raise
TypeError
if
not
hasattr
(
op
,
'draw_sample_with_dtype'
):
raise
TypeError
if
isinstance
(
dtype
,
dict
):
dtype
=
{
kk
:
np
.
dtype
(
vv
)
for
kk
,
vv
in
dtype
.
items
()}
else
:
dtype
=
np
.
dtype
(
dtype
)
if
isinstance
(
op
.
domain
,
MultiDomain
):
if
isinstance
(
dtype
,
np
.
dtype
):
dtype
=
{
kk
:
dtype
for
kk
in
op
.
domain
.
keys
()}
if
set
(
dtype
.
keys
())
!=
set
(
op
.
domain
.
keys
()):
raise
TypeError
self
.
_dtype
=
dtype
self
.
_domain
=
op
.
domain
self
.
_capability
=
op
.
capability
self
.
apply
=
op
.
apply
self
.
_op
=
op
def
draw_sample
(
self
,
from_inverse
=
False
):
return
self
.
_op
.
draw_sample_with_dtype
(
self
.
_dtype
,
from_inverse
=
from_inverse
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment