Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Neel Shah
NIFTy
Commits
345c6ca5
Commit
345c6ca5
authored
Oct 17, 2019
by
Martin Reinecke
Browse files
Merge remote-tracking branch 'origin/NIFTy_5' into constant_operators
parents
ac61e3f5
156c9d79
Changes
6
Hide whitespace changes
Inline
Side-by-side
nifty5/__init__.py
View file @
345c6ca5
...
...
@@ -46,15 +46,16 @@ from .operators.outer_product_operator import OuterProduct
from
.operators.simple_linear_operators
import
(
VdotOperator
,
ConjugationOperator
,
Realizer
,
FieldAdapter
,
ducktape
,
GeometryRemover
,
NullOperator
,
MatrixProductOperator
)
MatrixProductOperator
,
PartialExtractor
)
from
.operators.value_inserter
import
ValueInserter
from
.operators.energy_operators
import
(
EnergyOperator
,
GaussianEnergy
,
PoissonianEnergy
,
InverseGammaLikelihood
,
BernoulliEnergy
,
StandardHamiltonian
,
AveragedEnergy
)
BernoulliEnergy
,
StandardHamiltonian
,
AveragedEnergy
,
QuadraticFormOperator
,
Squared2NormOperator
)
from
.operators.convolution_operators
import
FuncConvolutionOperator
from
.probing
import
probe_with_posterior_samples
,
probe_diagonal
,
\
StatCalculator
StatCalculator
,
approximation2endo
from
.minimization.line_search
import
LineSearch
from
.minimization.iteration_controllers
import
(
...
...
@@ -97,6 +98,8 @@ from .logger import logger
from
.linearization
import
Linearization
from
.operator_spectrum
import
operator_spectrum
from
.
import
internal_config
_scheme
=
internal_config
.
parallelization_scheme
()
if
_scheme
==
"Samples"
:
...
...
nifty5/operator_spectrum.py
0 → 100644
View file @
345c6ca5
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
import
numpy
as
np
import
scipy.sparse.linalg
as
ssl
from
.domain_tuple
import
DomainTuple
from
.domains.unstructured_domain
import
UnstructuredDomain
from
.field
import
Field
from
.multi_domain
import
MultiDomain
from
.multi_field
import
MultiField
from
.operators.linear_operator
import
LinearOperator
from
.operators.sandwich_operator
import
SandwichOperator
from
.sugar
import
from_global_data
,
makeDomain
class
_DomRemover
(
LinearOperator
):
"""Operator which transforms between a structured MultiDomain
and an unstructured domain.
Parameters
----------
domain: MultiDomain
the full input domain of the operator.
Notes
-----
The operator converts the full domain of its input domain to an
UnstructuredDomain
"""
def
__init__
(
self
,
domain
):
self
.
_domain
=
makeDomain
(
domain
)
if
isinstance
(
self
.
_domain
,
MultiDomain
):
self
.
_size_array
=
np
.
array
([
0
]
+
[
d
.
size
for
d
in
domain
.
values
()])
else
:
self
.
_size_array
=
np
.
array
([
0
,
domain
.
size
])
np
.
cumsum
(
self
.
_size_array
,
out
=
self
.
_size_array
)
target
=
UnstructuredDomain
(
self
.
_size_array
[
-
1
])
self
.
_target
=
makeDomain
(
target
)
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
self
.
_check_float_dtype
(
x
)
x
=
x
.
to_global_data
()
if
isinstance
(
self
.
_domain
,
DomainTuple
):
res
=
x
.
ravel
()
if
mode
==
self
.
TIMES
else
x
.
reshape
(
self
.
_domain
.
shape
)
else
:
res
=
np
.
empty
(
self
.
target
.
shape
)
if
mode
==
self
.
TIMES
else
{}
for
ii
,
(
kk
,
dd
)
in
enumerate
(
self
.
domain
.
items
()):
i0
,
i1
=
self
.
_size_array
[
ii
:
ii
+
2
]
if
mode
==
self
.
TIMES
:
res
[
i0
:
i1
]
=
x
[
kk
].
ravel
()
else
:
res
[
kk
]
=
x
[
i0
:
i1
].
reshape
(
dd
.
shape
)
return
from_global_data
(
self
.
_tgt
(
mode
),
res
)
@
staticmethod
def
_check_float_dtype
(
fld
):
if
isinstance
(
fld
,
MultiField
):
dts
=
[
ff
.
local_data
.
dtype
for
ff
in
fld
.
values
()]
elif
isinstance
(
fld
,
Field
):
dts
=
[
fld
.
local_data
.
dtype
]
else
:
raise
TypeError
for
dt
in
dts
:
if
not
np
.
issubdtype
(
dt
,
np
.
float64
):
raise
TypeError
(
'Operator supports only floating point dtypes'
)
def
operator_spectrum
(
A
,
k
,
hermitian
,
which
=
'LM'
,
tol
=
0
):
'''
Find k eigenvalues and eigenvectors of the endomorphism A.
Parameters
----------
A : LinearOperator
Operator of which eigenvalues shall be computed.
k : int
The number of eigenvalues and eigenvectors desired. `k` must be
smaller than N-1. It is not possible to compute all eigenvectors of a
matrix.
hermitian: bool
Specifies whether A is hermitian or not.
which : str, ['LM' | 'SM' | 'LR' | 'SR' | 'LI' | 'SI'], optional
Which `k` eigenvectors and eigenvalues to find:
'LM' : largest magnitude
'SM' : smallest magnitude
'LR' : largest real part
'SR' : smallest real part
'LI' : largest imaginary part
'SI' : smallest imaginary part
tol : float, optional
Relative accuracy for eigenvalues (stopping criterion)
The default value of 0 implies machine precision.
Returns
-------
w : ndarray
Array of k eigenvalues.
Raises
------
ArpackNoConvergence
When the requested convergence is not obtained.
The currently converged eigenvalues and eigenvectors can be found
as ``eigenvalues`` and ``eigenvectors`` attributes of the exception
object.
'''
if
not
isinstance
(
A
,
LinearOperator
):
raise
TypeError
(
'Operator needs to be linear.'
)
if
A
.
domain
is
not
A
.
target
:
raise
TypeError
(
'Operator needs to be endomorphism.'
)
size
=
A
.
domain
.
size
Ar
=
SandwichOperator
.
make
(
_DomRemover
(
A
.
domain
).
adjoint
,
A
)
M
=
ssl
.
LinearOperator
(
shape
=
2
*
(
size
,),
matvec
=
lambda
x
:
Ar
(
from_global_data
(
Ar
.
domain
,
x
)).
to_global_data
())
f
=
ssl
.
eigsh
if
hermitian
else
ssl
.
eigs
eigs
=
f
(
M
,
k
=
k
,
tol
=
tol
,
return_eigenvectors
=
False
,
which
=
which
)
return
np
.
flip
(
np
.
sort
(
eigs
),
axis
=
0
)
nifty5/operators/energy_operators.py
View file @
345c6ca5
...
...
@@ -352,9 +352,11 @@ class StandardHamiltonian(EnergyOperator):
`<https://arxiv.org/abs/1812.04403>`_
"""
def
__init__
(
self
,
lh
,
ic_samp
=
None
):
def
__init__
(
self
,
lh
,
ic_samp
=
None
,
_c_inp
=
None
):
self
.
_lh
=
lh
self
.
_prior
=
GaussianEnergy
(
domain
=
lh
.
domain
)
if
_c_inp
is
not
None
:
_
,
self
.
_prior
=
self
.
_prior
.
simplify_for_constant_input
(
_c_inp
)
self
.
_ic_samp
=
ic_samp
self
.
_domain
=
lh
.
domain
...
...
@@ -371,9 +373,13 @@ class StandardHamiltonian(EnergyOperator):
def
__repr__
(
self
):
subs
=
'Likelihood:
\n
{}'
.
format
(
utilities
.
indent
(
self
.
_lh
.
__repr__
()))
subs
+=
'
\n
Prior:
Quadratic
{}'
.
format
(
self
.
_
lh
.
domain
.
keys
()
)
subs
+=
'
\n
Prior:
\n
{}'
.
format
(
self
.
_
prior
)
return
'StandardHamiltonian:
\n
'
+
utilities
.
indent
(
subs
)
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
out
,
lh1
=
self
.
_lh
.
simplify_for_constant_input
(
c_inp
)
return
out
,
StandardHamiltonian
(
lh1
,
self
.
_ic_samp
,
_c_inp
=
c_inp
)
class
AveragedEnergy
(
EnergyOperator
):
"""Averages an energy over samples.
...
...
nifty5/operators/inversion_enabler.py
View file @
345c6ca5
...
...
@@ -23,6 +23,7 @@ from ..minimization.iteration_controllers import IterationController
from
..minimization.quadratic_energy
import
QuadraticEnergy
from
..sugar
import
full
from
.endomorphic_operator
import
EndomorphicOperator
from
.linear_operator
import
LinearOperator
class
InversionEnabler
(
EndomorphicOperator
):
...
...
@@ -47,6 +48,12 @@ class InversionEnabler(EndomorphicOperator):
"""
def
__init__
(
self
,
op
,
iteration_controller
,
approximation
=
None
):
# isinstance(op, EndomorphicOperator) does not suffice since op can be
# a ChainOperator
if
not
isinstance
(
op
,
LinearOperator
):
raise
TypeError
(
'Operator needs to be linear.'
)
if
op
.
domain
is
not
op
.
target
:
raise
TypeError
(
'Operator needs to be endomorphic.'
)
self
.
_op
=
op
self
.
_ic
=
iteration_controller
self
.
_approximation
=
approximation
...
...
nifty5/operators/sampling_enabler.py
View file @
345c6ca5
...
...
@@ -42,15 +42,20 @@ class SamplingEnabler(EndomorphicOperator):
operator, which supports the operation modes that the operator doesn't
have. It is used as a preconditioner during the iterative inversion,
to accelerate convergence.
start_from_zero : boolean
If true, the conjugate gradient algorithm starts from a field filled
with zeros. Otherwise, it starts from a prior samples. Default is
False.
"""
def
__init__
(
self
,
likelihood
,
prior
,
iteration_controller
,
approximation
=
None
):
self
.
_op
=
likelihood
+
prior
approximation
=
None
,
start_from_zero
=
False
):
self
.
_likelihood
=
likelihood
self
.
_prior
=
prior
self
.
_ic
=
iteration_controller
self
.
_approximation
=
approximation
self
.
_start_from_zero
=
bool
(
start_from_zero
)
self
.
_op
=
likelihood
+
prior
self
.
_domain
=
self
.
_op
.
domain
self
.
_capability
=
self
.
_op
.
capability
...
...
@@ -60,11 +65,15 @@ class SamplingEnabler(EndomorphicOperator):
except
NotImplementedError
:
if
not
from_inverse
:
raise
ValueError
(
"from_inverse must be True here"
)
s
=
self
.
_prior
.
draw_sample
(
from_inverse
=
True
)
sp
=
self
.
_prior
(
s
)
nj
=
self
.
_likelihood
.
draw_sample
()
energy
=
QuadraticEnergy
(
s
,
self
.
_op
,
sp
+
nj
,
_grad
=
self
.
_likelihood
(
s
)
-
nj
)
if
self
.
_start_from_zero
:
b
=
self
.
_op
.
draw_sample
()
energy
=
QuadraticEnergy
(
0
*
b
,
self
.
_op
,
b
)
else
:
s
=
self
.
_prior
.
draw_sample
(
from_inverse
=
True
)
sp
=
self
.
_prior
(
s
)
nj
=
self
.
_likelihood
.
draw_sample
()
energy
=
QuadraticEnergy
(
s
,
self
.
_op
,
sp
+
nj
,
_grad
=
self
.
_likelihood
(
s
)
-
nj
)
inverter
=
ConjugateGradient
(
self
.
_ic
)
if
self
.
_approximation
is
not
None
:
energy
,
convergence
=
inverter
(
...
...
nifty5/plot.py
View file @
345c6ca5
...
...
@@ -347,8 +347,12 @@ def _plot2D(f, ax, **kwargs):
if
len
(
dom
)
==
2
:
if
(
not
isinstance
(
dom
[
1
],
RGSpace
))
or
len
(
dom
[
1
].
shape
)
!=
1
:
raise
TypeError
(
"need 1D RGSpace as second domain"
)
rgb
=
_rgb_data
(
f
.
to_global_data
())
have_rgb
=
True
if
dom
[
1
].
shape
[
0
]
==
1
:
from
.sugar
import
from_global_data
f
=
from_global_data
(
f
.
domain
[
0
],
f
.
to_global_data
()[...,
0
])
else
:
rgb
=
_rgb_data
(
f
.
to_global_data
())
have_rgb
=
True
foo
=
kwargs
.
pop
(
"norm"
,
None
)
norm
=
{}
if
foo
is
None
else
{
'norm'
:
foo
}
...
...
@@ -477,6 +481,17 @@ class Plot(object):
alpha: float or list of floats
Transparency value.
"""
from
.multi_field
import
MultiField
if
isinstance
(
f
,
MultiField
):
for
kk
in
f
.
domain
.
keys
():
self
.
_plots
.
append
(
f
[
kk
])
mykwargs
=
kwargs
.
copy
()
if
'title'
in
kwargs
:
mykwargs
[
'title'
]
=
"{} {}"
.
format
(
kk
,
kwargs
[
'title'
])
else
:
mykwargs
[
'title'
]
=
"{}"
.
format
(
kk
)
self
.
_kwargs
.
append
(
mykwargs
)
return
self
.
_plots
.
append
(
f
)
self
.
_kwargs
.
append
(
kwargs
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment