Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Neel Shah
NIFTy
Commits
8103fbbd
Commit
8103fbbd
authored
May 28, 2021
by
Philipp Arras
Browse files
Cosmetics
parent
cc077788
Changes
6
Show whitespace changes
Inline
Side-by-side
ChangeLog.md
View file @
8103fbbd
demos/getting_started_3.py
View file @
8103fbbd
...
...
@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-202
0
Max-Planck-Society
# Copyright(C) 2013-202
1
Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
...
...
@@ -97,13 +97,13 @@ def main():
data
=
signal_response
(
mock_position
)
+
N
.
draw_sample_with_dtype
(
dtype
=
np
.
float64
)
# Minimization parameters
ic_sampling
=
ift
.
AbsDeltaEnergyController
(
ic_sampling
=
ift
.
AbsDeltaEnergyController
(
name
=
"Sampling (linear)"
,
deltaE
=
0.05
,
iteration_limit
=
100
)
ic_newton
=
ift
.
AbsDeltaEnergyController
(
name
=
'Newton'
,
deltaE
=
0.5
,
iteration_limit
=
35
)
ic_newton
=
ift
.
AbsDeltaEnergyController
(
name
=
'Newton'
,
deltaE
=
0.5
,
convergence_level
=
2
,
iteration_limit
=
35
)
minimizer
=
ift
.
NewtonCG
(
ic_newton
)
ic_sampling_nl
=
ift
.
AbsDeltaEnergyController
(
name
=
'Sampling'
,
deltaE
=
0.5
,
iteration_limit
=
15
,
convergence_level
=
2
)
ic_sampling_nl
=
ift
.
AbsDeltaEnergyController
(
name
=
'Sampling (nonlin)'
,
deltaE
=
0.5
,
iteration_limit
=
15
,
convergence_level
=
2
)
minimizer_sampling
=
ift
.
NewtonCG
(
ic_sampling_nl
)
# Set up likelihood and information Hamiltonian
...
...
docs/source/ift.rst
View file @
8103fbbd
src/minimization/kl_energies.py
View file @
8103fbbd
...
...
@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-202
0
Max-Planck-Society
# Copyright(C) 2013-202
1
Max-Planck-Society
# Authors: Philipp Frank
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
...
...
@@ -21,6 +21,7 @@ from functools import reduce
from
..
import
random
from
..
import
utilities
from
..domain_tuple
import
DomainTuple
from
..linearization
import
Linearization
from
..multi_field
import
MultiField
from
..operators.inversion_enabler
import
InversionEnabler
...
...
@@ -38,17 +39,6 @@ from ..utilities import myassert
from
.energy
import
Energy
from
.descent_minimizers
import
DescentMinimizer
,
ConjugateGradient
def
_is_prior_dtype_float
(
H
):
real
=
True
dts
=
H
.
_prior
.
_met
.
_dtype
if
isinstance
(
dts
,
dict
):
for
k
in
dts
.
keys
():
if
not
np
.
issubdtype
(
dts
[
k
],
np
.
float
):
real
=
False
else
:
real
=
np
.
issubdtype
(
dts
,
np
.
float
)
return
real
def
_get_lo_hi
(
comm
,
n_samples
):
ntask
,
rank
,
_
=
utilities
.
get_MPI_params_from_comm
(
comm
)
...
...
@@ -187,8 +177,17 @@ class _GeoMetricSampler:
n_samples
,
mirror_samples
,
napprox
=
0
,
want_error
=
False
):
if
not
isinstance
(
H
,
StandardHamiltonian
):
raise
NotImplementedError
if
not
_is_prior_dtype_float
(
H
):
# Check domain dtype
dts
=
H
.
_prior
.
_met
.
_dtype
if
isinstance
(
H
.
domain
,
DomainTuple
):
real
=
np
.
issubdtype
(
dts
,
np
.
float
)
else
:
real
=
all
([
np
.
issubdtype
(
dts
[
kk
],
np
.
float
)
for
kk
in
dts
.
keys
()])
if
not
real
:
raise
ValueError
(
"_GeoMetricSampler only supports real valued latent DOFs."
)
# /Check domain dtype
if
isinstance
(
position
,
MultiField
):
self
.
_position
=
position
.
extract
(
H
.
domain
)
else
:
...
...
@@ -206,12 +205,11 @@ class _GeoMetricSampler:
scale
=
SamplingDtypeSetter
(
scale
,
dtype
)
if
sampling
else
scale
fl
=
f_lh
(
Linearization
.
make_var
(
self
.
_position
))
self
.
_g
=
(
Adder
(
-
self
.
_position
)
+
fl
.
jac
.
adjoint
@
Adder
(
-
fl
.
val
)
@
f_lh
)
self
.
_g
=
(
Adder
(
-
self
.
_position
)
+
fl
.
jac
.
adjoint
@
Adder
(
-
fl
.
val
)
@
f_lh
)
self
.
_likelihood
=
SandwichOperator
.
make
(
fl
.
jac
,
scale
)
self
.
_prior
=
SamplingDtypeSetter
(
ScalingOperator
(
fl
.
domain
,
1.
),
np
.
float64
)
self
.
_met
=
self
.
_likelihood
+
self
.
_prior
if
napprox
>=
1
:
if
napprox
>=
1
:
self
.
_approximation
=
makeOp
(
approximation2endo
(
self
.
_met
,
napprox
)).
inverse
else
:
self
.
_approximation
=
None
...
...
src/operators/energy_operators.py
View file @
8103fbbd
...
...
@@ -93,10 +93,11 @@ class LikelihoodOperator(EnergyOperator):
:func:`~nifty7.operators.operator.Operator.get_transformation`.
"""
dtp
,
f
=
self
.
get_transformation
()
ch
=
ScalingOperator
(
f
.
target
,
1.
)
ch
=
None
if
dtp
is
not
None
:
ch
=
SamplingDtypeSetter
(
ch
,
dtp
)
return
SandwichOperator
.
make
(
f
(
Linearization
.
make_var
(
x
)).
jac
,
ch
)
ch
=
SamplingDtypeSetter
(
ScalingOperator
(
f
.
target
,
1.
),
dtp
)
bun
=
f
(
Linearization
.
make_var
(
x
)).
jac
return
SandwichOperator
.
make
(
bun
,
ch
)
class
Squared2NormOperator
(
EnergyOperator
):
...
...
@@ -181,8 +182,8 @@ class VariableCovarianceGaussianEnergy(LikelihoodOperator):
Default is True
"""
def
__init__
(
self
,
domain
,
residual_key
,
inverse_covariance_key
,
sampling_dtype
,
use_full_fisher
=
True
):
def
__init__
(
self
,
domain
,
residual_key
,
inverse_covariance_key
,
sampling_dtype
,
use_full_fisher
=
True
):
self
.
_kr
=
str
(
residual_key
)
self
.
_ki
=
str
(
inverse_covariance_key
)
dom
=
DomainTuple
.
make
(
domain
)
...
...
@@ -190,7 +191,7 @@ class VariableCovarianceGaussianEnergy(LikelihoodOperator):
self
.
_dt
=
{
self
.
_kr
:
sampling_dtype
,
self
.
_ki
:
np
.
float64
}
_check_sampling_dtype
(
self
.
_domain
,
self
.
_dt
)
self
.
_cplx
=
_iscomplex
(
sampling_dtype
)
self
.
_use_fisher
=
use_full_fisher
self
.
_use_
full_
fisher
=
use_full_fisher
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
...
...
@@ -201,7 +202,7 @@ class VariableCovarianceGaussianEnergy(LikelihoodOperator):
res
=
0.5
*
(
r
.
vdot
(
r
*
i
)
-
i
.
ptw
(
"log"
).
sum
())
if
not
x
.
want_metric
:
return
res
if
self
.
_use_fisher
:
if
self
.
_use_
full_
fisher
:
met
=
1.
if
self
.
_cplx
else
0.5
met
=
MultiField
.
from_dict
({
self
.
_kr
:
i
.
val
,
self
.
_ki
:
met
*
i
.
val
**
(
-
2
)},
domain
=
self
.
_domain
)
...
...
@@ -616,6 +617,3 @@ class AveragedEnergy(EnergyOperator):
dtp
,
trafo
=
self
.
_h
.
get_transformation
()
mymap
=
map
(
lambda
v
:
trafo
@
Adder
(
v
),
self
.
_res_samples
)
return
dtp
,
utilities
.
my_sum
(
mymap
)
/
np
.
sqrt
(
len
(
self
.
_res_samples
))
src/operators/operator.py
View file @
8103fbbd
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment