Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Neel Shah
NIFTy
Commits
9161a4c6
Commit
9161a4c6
authored
Apr 02, 2020
by
Martin Reinecke
Browse files
more
parent
a8303513
Changes
12
Hide whitespace changes
Inline
Side-by-side
demos/getting_started_1.py
View file @
9161a4c6
...
...
@@ -69,7 +69,7 @@ if __name__ == '__main__':
harmonic_space
=
position_space
.
get_default_codomain
()
# Harmonic transform from harmonic space to position space
HT
=
ift
.
M_
HarmonicTransformOperator
(
harmonic_space
,
target
=
position_space
)
HT
=
ift
.
HarmonicTransformOperator
(
harmonic_space
,
target
=
position_space
)
# Set prior correlation covariance with a power spectrum leading to
# homogeneous and isotropic statistics
...
...
@@ -86,15 +86,15 @@ if __name__ == '__main__':
prior_correlation_structure
=
PD
(
ift
.
PS_field
(
power_space
,
power_spectrum
))
# Insert the result into the diagonal of an harmonic space operator
S
=
ift
.
M_
DiagonalOperator
(
prior_correlation_structure
)
S
=
ift
.
DiagonalOperator
(
prior_correlation_structure
)
# S is the prior field covariance
# Build instrument response consisting of a discretization, mask
# and harmonic transformaion
# Masking operator to model that parts of the field have not been observed
mask
=
ift
.
Multi
Field
.
from_raw
(
position_space
,
mask
)
Mask
=
ift
.
SingleLinearAdapter
(
ift
.
MaskOperator
(
mask
.
sing
)
)
mask
=
ift
.
Field
.
from_raw
(
position_space
,
mask
)
Mask
=
ift
.
MaskOperator
(
mask
)
# The response operator consists of
# - a harmonic transform (to get to image space)
...
...
@@ -134,11 +134,11 @@ if __name__ == '__main__':
filename
=
"getting_started_1_mode_{}.png"
.
format
(
mode
)
if
rg
and
len
(
position_space
.
shape
)
==
1
:
plot
.
add
(
[
HT
(
MOCK_SIGNAL
)
.
sing
,
Mask
.
adjoint
(
data
)
.
sing
,
HT
(
m
)
.
sing
],
[
HT
(
MOCK_SIGNAL
),
Mask
.
adjoint
(
data
),
HT
(
m
)],
label
=
[
'Mock signal'
,
'Data'
,
'Reconstruction'
],
alpha
=
[
1
,
.
3
,
1
])
plot
.
add
(
Mask
.
adjoint
(
Mask
(
HT
(
m
-
MOCK_SIGNAL
)))
.
sing
,
title
=
'Residuals'
)
plot
.
add
(
Mask
.
adjoint
(
Mask
(
HT
(
m
-
MOCK_SIGNAL
))),
title
=
'Residuals'
)
plot
.
output
(
nx
=
2
,
ny
=
1
,
xsize
=
10
,
ysize
=
4
,
name
=
filename
)
else
:
plot
.
add
(
HT
(
MOCK_SIGNAL
),
title
=
'Mock Signal'
)
...
...
nifty6/domain_tuple.py
View file @
9161a4c6
...
...
@@ -191,12 +191,7 @@ class DomainTuple(object):
return
self
.
_dom
.
__hash__
()
def
__eq__
(
self
,
x
):
if
self
is
x
:
return
True
from
.multi_domain
import
MultiDomain
if
isinstance
(
x
,
MultiDomain
):
return
self
.
mult
==
x
return
self
.
_dom
==
x
.
_dom
return
(
self
is
x
)
or
(
self
.
_dom
==
x
.
_dom
)
def
__ne__
(
self
,
x
):
return
not
self
.
__eq__
(
x
)
...
...
nifty6/field.py
View file @
9161a4c6
...
...
@@ -596,7 +596,7 @@ class Field(object):
m1
=
self
.
mean
(
spaces
)
from
.operators.contraction_operator
import
ContractionOperator
op
=
ContractionOperator
(
self
.
_domain
,
spaces
)
m1
=
op
.
adjoint_times
(
m1
)
m1
=
op
.
adjoint_times
(
m1
.
mult
).
sing
if
utilities
.
iscomplextype
(
self
.
dtype
):
sq
=
abs
(
self
-
m1
)
**
2
else
:
...
...
nifty6/linearization.py
View file @
9161a4c6
...
...
@@ -333,7 +333,7 @@ class Linearization(Operator):
ind
=
self
.
_val
.
val
==
0
loc
=
tmp2
.
val_rw
()
loc
[
ind
]
=
0
tmp2
=
Field
(
tmp
.
domain
,
loc
)
tmp2
=
make
Field
(
tmp
.
domain
,
loc
)
return
self
.
new
(
tmp
,
makeOp
(
tmp2
)(
self
.
_jac
))
def
log
(
self
):
...
...
nifty6/multi_domain.py
View file @
9161a4c6
...
...
@@ -105,8 +105,6 @@ class MultiDomain(object):
def
__eq__
(
self
,
x
):
if
self
is
x
:
return
True
if
isinstance
(
x
,
DomainTuple
):
x
=
x
.
mult
return
list
(
self
.
items
())
==
list
(
x
.
items
())
def
__ne__
(
self
,
x
):
...
...
@@ -129,7 +127,6 @@ class MultiDomain(object):
return
inp
.
pop
()
res
=
{}
for
dom
in
inp
:
dom
=
dom
.
mult
for
key
,
subdom
in
zip
(
dom
.
_keys
,
dom
.
_domains
):
if
key
in
res
:
if
res
[
key
]
!=
subdom
:
...
...
nifty6/operators/block_diagonal_operator.py
View file @
9161a4c6
...
...
@@ -44,7 +44,7 @@ class BlockDiagonalOperator(EndomorphicOperator):
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
val
=
tuple
(
op
.
apply
(
v
,
mode
=
mode
)
if
op
is
not
None
else
v
val
=
tuple
(
op
.
apply
(
v
.
mult
,
mode
=
mode
)
.
sing
if
op
is
not
None
else
v
for
op
,
v
in
zip
(
self
.
_ops
,
x
.
values
()))
return
MultiField
(
self
.
_domain
,
val
)
...
...
nifty6/operators/diagonal_operator.py
View file @
9161a4c6
...
...
@@ -128,7 +128,7 @@ class DiagonalOperator(EndomorphicOperator_s):
return
self
.
_from_ldiag
(
op
.
_spaces
,
tdiag
)
def
_apply_s
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
self
.
_check_input
_s
(
x
,
mode
)
# shortcut for most common cases
if
mode
==
1
or
(
not
self
.
_complex
and
mode
==
2
):
return
Field
(
x
.
domain
,
x
.
val
*
self
.
_ldiag
)
...
...
nifty6/operators/distributors.py
View file @
9161a4c6
...
...
@@ -22,10 +22,10 @@ from ..domains.dof_space import DOFSpace
from
..domains.power_space
import
PowerSpace
from
..field
import
Field
from
..utilities
import
infer_space
,
special_add_at
from
.linear_operator
import
LinearOperator
from
.linear_operator
import
LinearOperator
_s
class
DOFDistributor
(
LinearOperator
):
class
DOFDistributor
(
LinearOperator
_s
):
"""Operator which distributes actual degrees of freedom (dof) according to
some distribution scheme into a higher dimensional space. This distribution
scheme is defined by the dofdex, a degree of freedom index, which
...
...
@@ -50,17 +50,17 @@ class DOFDistributor(LinearOperator):
"""
def
__init__
(
self
,
dofdex
,
target
=
None
,
space
=
None
):
if
target
is
None
:
target
=
dofdex
.
domain
self
.
_target
=
DomainTuple
.
make
(
target
)
space
=
infer_space
(
self
.
_target
,
space
)
partner
=
self
.
_target
[
space
]
if
not
isinstance
(
dofdex
,
Field
):
raise
TypeError
(
"dofdex must be a Field"
)
if
not
len
(
dofdex
.
domain
)
==
1
:
raise
ValueError
(
"dofdex must be defined on exactly one Space"
)
if
not
np
.
issubdtype
(
dofdex
.
dtype
,
np
.
integer
):
raise
TypeError
(
"dofdex must contain integer numbers"
)
if
target
is
None
:
target
=
dofdex
.
domain
self
.
_target_s
=
DomainTuple
.
make
(
target
)
space
=
infer_space
(
self
.
_target_s
,
space
)
partner
=
self
.
_target_s
[
space
]
if
partner
!=
dofdex
.
domain
[
0
]:
raise
ValueError
(
"incorrect dofdex domain"
)
...
...
@@ -88,18 +88,18 @@ class DOFDistributor(LinearOperator):
def
_init2
(
self
,
dofdex
,
space
,
other_space
):
self
.
_space
=
space
dom
=
list
(
self
.
_target
)
dom
=
list
(
self
.
_target
_s
)
dom
[
self
.
_space
]
=
other_space
self
.
_domain
=
DomainTuple
.
make
(
dom
)
self
.
_domain
_s
=
DomainTuple
.
make
(
dom
)
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
self
.
_dofdex
=
dofdex
.
ravel
()
firstaxis
=
self
.
_target
.
axes
[
self
.
_space
][
0
]
lastaxis
=
self
.
_target
.
axes
[
self
.
_space
][
-
1
]
arrshape
=
self
.
_target
.
shape
firstaxis
=
self
.
_target
_s
.
axes
[
self
.
_space
][
0
]
lastaxis
=
self
.
_target
_s
.
axes
[
self
.
_space
][
-
1
]
arrshape
=
self
.
_target
_s
.
shape
presize
=
np
.
prod
(
arrshape
[
0
:
firstaxis
],
dtype
=
np
.
int
)
postsize
=
np
.
prod
(
arrshape
[
lastaxis
+
1
:],
dtype
=
np
.
int
)
self
.
_hshape
=
(
presize
,
self
.
_domain
[
self
.
_space
].
shape
[
0
],
postsize
)
self
.
_hshape
=
(
presize
,
self
.
_domain
_s
[
self
.
_space
].
shape
[
0
],
postsize
)
self
.
_pshape
=
(
presize
,
self
.
_dofdex
.
size
,
postsize
)
def
_adjoint_times
(
self
,
x
):
...
...
@@ -107,8 +107,8 @@ class DOFDistributor(LinearOperator):
arr
=
arr
.
reshape
(
self
.
_pshape
)
oarr
=
np
.
zeros
(
self
.
_hshape
,
dtype
=
x
.
dtype
)
oarr
=
special_add_at
(
oarr
,
1
,
self
.
_dofdex
,
arr
)
oarr
=
oarr
.
reshape
(
self
.
_domain
.
shape
)
res
=
Field
.
from_raw
(
self
.
_domain
,
oarr
)
oarr
=
oarr
.
reshape
(
self
.
_domain
_s
.
shape
)
res
=
Field
.
from_raw
(
self
.
_domain
_s
,
oarr
)
return
res
def
_times
(
self
,
x
):
...
...
@@ -116,10 +116,10 @@ class DOFDistributor(LinearOperator):
arr
=
arr
.
reshape
(
self
.
_hshape
)
oarr
=
np
.
empty
(
self
.
_pshape
,
dtype
=
x
.
dtype
)
oarr
[()]
=
arr
[(
slice
(
None
),
self
.
_dofdex
,
slice
(
None
))]
return
Field
(
self
.
_target
,
oarr
.
reshape
(
self
.
_target
.
shape
))
return
Field
(
self
.
_target
_s
,
oarr
.
reshape
(
self
.
_target
_s
.
shape
))
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
def
_
apply
_s
(
self
,
x
,
mode
):
self
.
_check_input
_s
(
x
,
mode
)
return
self
.
_times
(
x
)
if
mode
==
self
.
TIMES
else
self
.
_adjoint_times
(
x
)
...
...
@@ -141,9 +141,9 @@ class PowerDistributor(DOFDistributor):
def
__init__
(
self
,
target
,
power_space
=
None
,
space
=
None
):
# Initialize domain and target
self
.
_target
=
DomainTuple
.
make
(
target
)
self
.
_space
=
infer_space
(
self
.
_target
,
space
)
hspace
=
self
.
_target
[
self
.
_space
]
self
.
_target
_s
=
DomainTuple
.
make
(
target
)
self
.
_space
=
infer_space
(
self
.
_target
_s
,
space
)
hspace
=
self
.
_target
_s
[
self
.
_space
]
if
not
hspace
.
harmonic
:
raise
ValueError
(
"Operator requires harmonic target space"
)
if
power_space
is
None
:
...
...
nifty6/operators/energy_operators.py
View file @
9161a4c6
...
...
@@ -42,7 +42,7 @@ class EnergyOperator(Operator):
- Gibbs free energy, i.e. an averaged Hamiltonian, aka Kullback-Leibler
divergence.
"""
_target
=
DomainTuple
.
scalar_domain
()
_target
=
DomainTuple
.
scalar_domain
()
.
mult
class
Squared2NormOperator
(
EnergyOperator
):
...
...
@@ -60,9 +60,8 @@ class Squared2NormOperator(EnergyOperator):
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
if
not
isinstance
(
x
,
Linearization
):
res
=
x
.
vdot
(
x
)
return
res
res
=
x
.
val
.
vdot
(
x
.
val
)
return
x
.
vdot
(
x
).
mult
res
=
x
.
val
.
vdot
(
x
.
val
).
mult
return
x
.
new
(
res
,
VdotOperator
(
2
*
x
.
val
))
...
...
@@ -89,8 +88,8 @@ class QuadraticFormOperator(EnergyOperator):
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
if
not
isinstance
(
x
,
Linearization
):
return
0.5
*
x
.
vdot
(
self
.
_op
(
x
))
res
=
0.5
*
x
.
val
.
vdot
(
self
.
_op
(
x
.
val
))
return
0.5
*
x
.
vdot
(
self
.
_op
(
x
))
.
mult
res
=
0.5
*
x
.
val
.
vdot
(
self
.
_op
(
x
.
val
))
.
mult
return
x
.
new
(
res
,
VdotOperator
(
self
.
_op
(
x
.
val
)))
...
...
@@ -127,7 +126,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
res
=
0.5
*
(
x
[
self
.
_r
].
vdot
(
x
[
self
.
_r
]
*
x
[
self
.
_icov
]).
real
-
x
[
self
.
_icov
].
log
().
sum
())
res
=
0.5
*
(
x
[
self
.
_r
].
vdot
(
x
[
self
.
_r
]
*
x
[
self
.
_icov
]).
real
-
x
[
self
.
_icov
].
log
().
sum
())
.
mult
if
not
isinstance
(
x
,
Linearization
)
or
not
x
.
want_metric
:
return
res
mf
=
{
self
.
_r
:
x
.
val
[
self
.
_icov
],
self
.
_icov
:
.
5
*
x
.
val
[
self
.
_icov
]
**
(
-
2
)}
...
...
nifty6/operators/simple_linear_operators.py
View file @
9161a4c6
...
...
@@ -23,6 +23,7 @@ from ..multi_field import MultiField
from
.linear_operator
import
LinearOperator
from
.endomorphic_operator
import
EndomorphicOperator
from
..
import
utilities
from
..sugar
import
makeDomain
import
numpy
as
np
...
...
@@ -37,14 +38,14 @@ class VdotOperator(LinearOperator):
def
__init__
(
self
,
field
):
self
.
_field
=
field
self
.
_domain
=
field
.
domain
self
.
_target
=
DomainTuple
.
scalar_domain
()
self
.
_target
=
DomainTuple
.
scalar_domain
()
.
mult
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_mode
(
mode
)
if
mode
==
self
.
TIMES
:
return
self
.
_field
.
vdot
(
x
)
return
self
.
_field
*
x
.
val
[()]
return
self
.
_field
.
vdot
(
x
)
.
mult
return
self
.
_field
*
x
.
values
()[
0
].
val
[()]
class
ConjugationOperator
(
EndomorphicOperator
):
...
...
@@ -104,7 +105,7 @@ class Realizer(EndomorphicOperator):
"""
def
__init__
(
self
,
domain
):
self
.
_domain
=
Domain
Tuple
.
make
(
domain
)
self
.
_domain
=
make
Domain
(
domain
)
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
...
...
test/test_gaussian_energy.py
View file @
9161a4c6
...
...
@@ -45,13 +45,13 @@ def test_gaussian_energy(space, nonlinearity, noise, seed):
binbounds
=
ift
.
PowerSpace
.
useful_binbounds
(
hspace
,
logarithmic
=
False
)
pspace
=
ift
.
PowerSpace
(
hspace
,
binbounds
=
binbounds
)
Dist
=
ift
.
PowerDistributor
(
target
=
hspace
,
power_space
=
pspace
)
xi0
=
ift
.
Field
.
from_random
(
domain
=
hspace
,
random_type
=
'normal'
)
xi0
=
ift
.
Field
.
from_random
(
domain
=
hspace
,
random_type
=
'normal'
)
.
mult
def
pspec
(
k
):
return
1
/
(
1
+
k
**
2
)
**
dim
pspec
=
ift
.
PS_field
(
pspace
,
pspec
)
A
=
Dist
(
ift
.
sqrt
(
pspec
))
A
=
Dist
(
ift
.
sqrt
(
pspec
))
.
mult
N
=
ift
.
ScalingOperator
(
space
,
noise
)
n
=
N
.
draw_sample
()
R
=
ift
.
ScalingOperator
(
space
,
10.
)
...
...
test/test_linearization.py
View file @
9161a4c6
...
...
@@ -26,7 +26,7 @@ pmp = pytest.mark.parametrize
def
_lin2grad
(
lin
):
return
lin
.
jac
(
ift
.
full
(
lin
.
domain
,
1.
)).
val
return
lin
.
jac
(
ift
.
full
(
lin
.
domain
,
1.
)).
sing
.
val
def
jt
(
lin
,
check
):
...
...
@@ -35,9 +35,9 @@ def jt(lin, check):
def
test_special_gradients
():
dom
=
ift
.
UnstructuredDomain
((
1
,))
f
=
ift
.
full
(
dom
,
2.4
)
f
=
ift
.
full
(
dom
,
2.4
)
.
mult
var
=
ift
.
Linearization
.
make_var
(
f
)
s
=
f
.
val
s
=
f
.
sing
.
val
jt
(
var
.
clip
(
0
,
10
),
np
.
ones_like
(
s
))
jt
(
var
.
clip
(
-
1
,
0
),
np
.
zeros_like
(
s
))
...
...
@@ -59,12 +59,12 @@ def test_special_gradients():
])
def
test_actual_gradients
(
f
):
dom
=
ift
.
UnstructuredDomain
((
1
,))
fld
=
ift
.
full
(
dom
,
2.4
)
fld
=
ift
.
full
(
dom
,
2.4
)
.
mult
eps
=
1e-8
var0
=
ift
.
Linearization
.
make_var
(
fld
)
var1
=
ift
.
Linearization
.
make_var
(
fld
+
eps
)
f0
=
getattr
(
var0
,
f
)().
val
.
val
f1
=
getattr
(
var1
,
f
)().
val
.
val
f0
=
getattr
(
var0
,
f
)().
val
.
sing
.
val
f1
=
getattr
(
var1
,
f
)().
val
.
sing
.
val
df0
=
(
f1
-
f0
)
/
eps
df1
=
_lin2grad
(
getattr
(
var0
,
f
)())
assert_allclose
(
df0
,
df1
,
rtol
=
100
*
eps
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment