Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
3c5d3287
Commit
3c5d3287
authored
Sep 18, 2018
by
Martin Reinecke
Browse files
use ContractionOperator for most of the work
parent
d61dd5f4
Changes
7
Hide whitespace changes
Inline
Side-by-side
nifty5/__init__.py
View file @
3c5d3287
...
...
@@ -45,7 +45,7 @@ from .operators.symmetrizing_operator import SymmetrizingOperator
from
.operators.block_diagonal_operator
import
BlockDiagonalOperator
from
.operators.outer_product_operator
import
OuterProduct
from
.operators.simple_linear_operators
import
(
VdotOperator
,
SumReductionOperator
,
IntegralReductionOperator
,
ConjugationOperator
,
Realizer
,
VdotOperator
,
ConjugationOperator
,
Realizer
,
FieldAdapter
,
GeometryRemover
,
NullOperator
)
from
.operators.energy_operators
import
(
EnergyOperator
,
GaussianEnergy
,
PoissonianEnergy
,
InverseGammaLikelihood
,
...
...
nifty5/library/correlated_fields.py
View file @
3c5d3287
...
...
@@ -64,8 +64,8 @@ def MfCorrelatedField(s_space_spatial, s_space_energy, amplitude_model_spatial,
pd_energy
=
PowerDistributor
(
pd_spatial
.
domain
,
p_space_energy
,
1
)
pd
=
pd_spatial
(
pd_energy
)
dom_distr_spatial
=
ContractionOperator
(
pd
.
domain
,
0
).
adjoint
dom_distr_energy
=
ContractionOperator
(
pd
.
domain
,
1
).
adjoint
dom_distr_spatial
=
ContractionOperator
(
pd
.
domain
,
1
).
adjoint
dom_distr_energy
=
ContractionOperator
(
pd
.
domain
,
0
).
adjoint
a_spatial
=
dom_distr_spatial
(
amplitude_model_spatial
)
a_energy
=
dom_distr_energy
(
amplitude_model_energy
)
...
...
nifty5/linearization.py
View file @
3c5d3287
...
...
@@ -132,12 +132,14 @@ class Linearization(object):
return
self
.
new
(
OuterProduct
(
self
.
_val
,
other
.
_val
.
domain
)(
other
.
_val
),
OuterProduct
(
other
.
_val
,
self
.
_jac
.
domain
)(
self
.
_jac
).
_myadd
(
OuterProduct
(
self
.
_val
,
other
.
_jac
.
domain
)(
other
.
_jac
),
False
))
OuterProduct
(
self
.
_val
,
other
.
_jac
.
domain
)(
other
.
_jac
),
False
))
if
np
.
isscalar
(
other
):
return
self
.
__mul__
(
other
)
if
isinstance
(
other
,
(
Field
,
MultiField
)):
return
self
.
new
(
OuterProduct
(
self
.
_val
,
other
.
_val
.
domain
)(
other
.
_val
),
OuterProduct
(
other
.
_val
,
self
.
_jac
.
domain
)(
self
.
_jac
))
return
self
.
new
(
OuterProduct
(
self
.
_val
,
other
.
_val
.
domain
)(
other
.
_val
),
OuterProduct
(
other
.
_val
,
self
.
_jac
.
domain
)(
self
.
_jac
))
def
vdot
(
self
,
other
):
from
.operators.simple_linear_operators
import
VdotOperator
...
...
@@ -151,26 +153,26 @@ class Linearization(object):
VdotOperator
(
other
.
_val
)(
self
.
_jac
))
def
sum
(
self
,
spaces
=
None
):
from
.operators.
simple_linear
_operator
s
import
SumRedu
ctionOperator
from
.operators.
contraction
_operator
import
Contra
ctionOperator
if
spaces
is
None
:
return
self
.
new
(
Field
.
scalar
(
self
.
_val
.
sum
()),
SumRedu
ctionOperator
(
self
.
_jac
.
target
,
None
)(
self
.
_jac
))
Contra
ctionOperator
(
self
.
_jac
.
target
,
None
)(
self
.
_jac
))
else
:
return
self
.
new
(
self
.
_val
.
sum
(
spaces
),
SumRedu
ctionOperator
(
self
.
_jac
.
target
,
spaces
)(
self
.
_jac
))
Contra
ctionOperator
(
self
.
_jac
.
target
,
spaces
)(
self
.
_jac
))
def
integrate
(
self
,
spaces
=
None
):
from
.operators.
simple_linear
_operator
s
import
IntegralRedu
ctionOperator
from
.operators.
contraction
_operator
import
Contra
ctionOperator
if
spaces
is
None
:
return
self
.
new
(
Field
.
scalar
(
self
.
_val
.
integrate
()),
IntegralRedu
ctionOperator
(
self
.
_jac
.
target
,
None
)(
self
.
_jac
))
Contra
ctionOperator
(
self
.
_jac
.
target
,
None
,
1
)(
self
.
_jac
))
else
:
return
self
.
new
(
self
.
_val
.
integrate
(
spaces
),
IntegralRedu
ctionOperator
(
self
.
_jac
.
target
,
spaces
)(
self
.
_jac
))
Contra
ctionOperator
(
self
.
_jac
.
target
,
spaces
,
1
)(
self
.
_jac
))
def
exp
(
self
):
tmp
=
self
.
_val
.
exp
()
...
...
nifty5/operators/contraction_operator.py
View file @
3c5d3287
...
...
@@ -37,28 +37,37 @@ class ContractionOperator(LinearOperator):
----------
domain : Domain, tuple of Domain or DomainTuple
spaces : int or tuple of int
The elements of "domain" which are taken as target.
The elements of "domain" which are contracted.
weight : int, default=0
if nonzero, the fields living on self.domain are weighted with the
specified power.
"""
def
__init__
(
self
,
domain
,
spaces
):
def
__init__
(
self
,
domain
,
spaces
,
weight
=
0
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_spaces
=
utilities
.
parse_spaces
(
spaces
,
len
(
self
.
_domain
))
self
.
_target
=
[
dom
for
i
,
dom
in
enumerate
(
self
.
_domain
)
if
i
in
self
.
_spaces
dom
for
i
,
dom
in
enumerate
(
self
.
_domain
)
if
i
not
in
self
.
_spaces
]
self
.
_target
=
DomainTuple
.
make
(
self
.
_target
)
self
.
_weight
=
weight
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
ADJOINT_TIMES
:
ldat
=
x
.
lo
c
al_data
if
0
in
self
.
_spaces
else
x
.
to_g
lo
b
al_data
()
ldat
=
x
.
to_g
lo
b
al_data
()
if
0
in
self
.
_spaces
else
x
.
lo
c
al_data
shp
=
[]
for
i
,
dom
in
enumerate
(
self
.
_domain
):
tmp
=
dom
.
shape
if
i
>
0
else
dom
.
local_shape
shp
+=
tmp
if
i
in
self
.
_spaces
else
(
1
,)
*
len
(
dom
.
shape
)
shp
+=
tmp
if
i
not
in
self
.
_spaces
else
(
1
,)
*
len
(
dom
.
shape
)
ldat
=
np
.
broadcast_to
(
ldat
.
reshape
(
shp
),
self
.
_domain
.
local_shape
)
return
Field
.
from_local_data
(
self
.
_domain
,
ldat
)
res
=
Field
.
from_local_data
(
self
.
_domain
,
ldat
)
if
self
.
_weight
!=
0
:
res
=
res
.
weight
(
self
.
_weight
,
spaces
=
self
.
_spaces
)
return
res
else
:
return
x
.
sum
(
[
s
for
s
in
range
(
len
(
x
.
domain
))
if
s
not
in
self
.
_spaces
])
if
self
.
_weight
!=
0
:
x
=
x
.
weight
(
self
.
_weight
,
spaces
=
self
.
_spaces
)
res
=
x
.
sum
(
self
.
_spaces
)
return
res
if
isinstance
(
res
,
Field
)
else
Field
.
scalar
(
res
)
nifty5/operators/outer_product_operator.py
View file @
3c5d3287
...
...
@@ -46,14 +46,19 @@ class OuterProduct(LinearOperator):
self
.
_domain
=
domain
self
.
_field
=
field
self
.
_target
=
DomainTuple
.
make
(
tuple
(
sub_d
for
sub_d
in
field
.
domain
.
_dom
+
domain
.
_dom
))
self
.
_target
=
DomainTuple
.
make
(
tuple
(
sub_d
for
sub_d
in
field
.
domain
.
_dom
+
domain
.
_dom
))
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
return
Field
.
from_global_data
(
self
.
_target
,
np
.
multiply
.
outer
(
self
.
_field
.
to_global_data
(),
x
.
to_global_data
()))
return
Field
.
from_global_data
(
self
.
_target
,
np
.
multiply
.
outer
(
self
.
_field
.
to_global_data
(),
x
.
to_global_data
()))
axes
=
len
(
self
.
_field
.
shape
)
return
Field
.
from_global_data
(
self
.
_domain
,
np
.
tensordot
(
self
.
_field
.
to_global_data
(),
x
.
to_global_data
(),
axes
))
return
Field
.
from_global_data
(
self
.
_domain
,
np
.
tensordot
(
self
.
_field
.
to_global_data
(),
x
.
to_global_data
(),
axes
))
nifty5/operators/simple_linear_operators.py
View file @
3c5d3287
...
...
@@ -18,19 +18,15 @@
from
__future__
import
absolute_import
,
division
,
print_function
import
numpy
as
np
from
..compat
import
*
from
..domain_tuple
import
DomainTuple
from
..domains.unstructured_domain
import
UnstructuredDomain
from
..field
import
Field
from
..multi_domain
import
MultiDomain
from
..multi_field
import
MultiField
from
..sugar
import
full
,
makeDomain
from
..sugar
import
full
from
.endomorphic_operator
import
EndomorphicOperator
from
.linear_operator
import
LinearOperator
from
.domain_tuple_field_inserter
import
DomainTupleFieldInserter
from
..
import
utilities
class
VdotOperator
(
LinearOperator
):
...
...
@@ -47,81 +43,6 @@ class VdotOperator(LinearOperator):
return
self
.
_field
*
x
.
local_data
[()]
class
SumReductionOperator
(
LinearOperator
):
def
__init__
(
self
,
domain
,
spaces
=
None
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_spaces
=
utilities
.
parse_spaces
(
spaces
,
len
(
self
.
_domain
))
if
len
(
self
.
_spaces
)
==
len
(
self
.
_domain
):
self
.
_spaces
=
None
if
self
.
_spaces
is
None
:
self
.
_target
=
DomainTuple
.
scalar_domain
()
else
:
self
.
_target
=
makeDomain
(
tuple
(
dom
for
i
,
dom
in
enumerate
(
self
.
_domain
)
if
not
(
i
in
self
.
_spaces
)))
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
if
self
.
_spaces
is
None
:
return
Field
.
scalar
(
x
.
sum
())
else
:
return
x
.
sum
(
self
.
_spaces
)
if
self
.
_spaces
is
None
:
return
full
(
self
.
_domain
,
x
.
local_data
[()])
else
:
one
=
np
.
ones
(
self
.
_domain
.
shape
)
slice_list
=
[
slice
(
None
),
]
*
len
(
self
.
_domain
.
shape
)
p
=
0
for
i
in
range
(
len
(
self
.
_domain
)):
l
=
len
(
self
.
_domain
[
i
].
shape
)
if
i
in
self
.
_spaces
:
slice_list
[
slice
(
p
,
p
+
l
)]
=
(
np
.
newaxis
,)
*
l
p
=
p
+
l
return
Field
.
from_global_data
(
self
.
_domain
,
x
.
to_global_data
()[
tuple
(
slice_list
)]
*
one
)
class
IntegralReductionOperator
(
LinearOperator
):
def
__init__
(
self
,
domain
,
spaces
=
None
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_spaces
=
utilities
.
parse_spaces
(
spaces
,
len
(
self
.
_domain
))
if
len
(
self
.
_spaces
)
==
len
(
self
.
_domain
):
self
.
_spaces
=
None
if
self
.
_spaces
is
None
:
self
.
_target
=
DomainTuple
.
scalar_domain
()
else
:
self
.
_target
=
makeDomain
(
tuple
(
dom
for
i
,
dom
in
enumerate
(
self
.
_domain
)
if
not
(
i
in
self
.
_spaces
)))
self
.
_marg_space
=
makeDomain
(
tuple
(
dom
for
i
,
dom
in
enumerate
(
self
.
_domain
)
if
(
i
in
self
.
_spaces
)))
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
vol
=
1.
if
mode
==
self
.
TIMES
:
if
self
.
_spaces
is
None
:
return
Field
.
scalar
(
x
.
integrate
())
else
:
return
x
.
integrate
(
self
.
_spaces
)
if
self
.
_spaces
is
None
:
for
d
in
self
.
_domain
.
_dom
:
for
dis
in
d
.
distances
:
vol
*=
dis
return
full
(
self
.
_domain
,
x
.
local_data
[()]
*
vol
)
else
:
for
d
in
self
.
_marg_space
.
_dom
:
for
dis
in
d
.
distances
:
vol
*=
dis
one
=
np
.
ones
(
self
.
_domain
.
shape
)
slice_list
=
[
slice
(
None
),
]
*
len
(
self
.
_domain
.
shape
)
p
=
0
for
i
in
range
(
len
(
self
.
_domain
)):
l
=
len
(
self
.
_domain
[
i
].
shape
)
if
i
in
self
.
_spaces
:
slice_list
[
slice
(
p
,
p
+
l
)]
=
(
np
.
newaxis
,)
*
l
p
=
p
+
l
return
Field
.
from_global_data
(
self
.
_domain
,
x
.
to_global_data
()[
tuple
(
slice_list
)]
*
one
*
vol
)
class
ConjugationOperator
(
EndomorphicOperator
):
def
__init__
(
self
,
domain
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
...
...
test/test_operators/test_adjoint.py
View file @
3c5d3287
...
...
@@ -65,12 +65,6 @@ class Consistency_Tests(unittest.TestCase):
dtype
=
dtype
))
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
@
expand
(
product
(
_h_spaces
+
_p_spaces
+
_pow_spaces
,
[
np
.
float64
,
np
.
complex128
]))
def
testSumReductionOperator
(
self
,
sp
,
dtype
):
op
=
ift
.
SumReductionOperator
(
sp
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
@
expand
(
product
([(
ift
.
RGSpace
(
10
,
harmonic
=
True
),
4
,
0
),
(
ift
.
RGSpace
((
24
,
31
),
distances
=
(
0.4
,
2.34
),
harmonic
=
True
),
3
,
0
),
...
...
@@ -193,11 +187,11 @@ class Consistency_Tests(unittest.TestCase):
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
@
expand
(
product
([
0
,
1
,
2
,
3
,
(
0
,
1
),
(
0
,
2
),
(
0
,
1
,
2
),
(
0
,
2
,
3
),
(
1
,
3
)],
[
np
.
float64
,
np
.
complex128
]))
def
testContractionOperator
(
self
,
spaces
,
dtype
):
dom
=
(
ift
.
RGSpace
(
10
),
ift
.
UnstructuredDomain
(
13
),
ift
.
GLSpace
(
5
),
[
0
,
1
,
2
,
-
1
],
[
np
.
float64
,
np
.
complex128
]))
def
testContractionOperator
(
self
,
spaces
,
wgt
,
dtype
):
dom
=
(
ift
.
RGSpace
(
10
),
ift
.
RGSpace
(
13
),
ift
.
GLSpace
(
5
),
ift
.
HPSpace
(
4
))
op
=
ift
.
ContractionOperator
(
dom
,
spaces
)
op
=
ift
.
ContractionOperator
(
dom
,
spaces
,
wgt
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
def
testDomainTupleFieldInserter
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment