Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
c91ee857
Commit
c91ee857
authored
Sep 21, 2018
by
Martin Reinecke
Browse files
Merge branch 'outer_product' into 'NIFTy_5'
Outer product See merge request ift/nifty-dev!106
parents
71cc7162
9986da5c
Changes
10
Hide whitespace changes
Inline
Side-by-side
nifty5/__init__.py
View file @
c91ee857
...
...
@@ -43,8 +43,9 @@ from .operators.slope_operator import SlopeOperator
from
.operators.smoothness_operator
import
SmoothnessOperator
from
.operators.symmetrizing_operator
import
SymmetrizingOperator
from
.operators.block_diagonal_operator
import
BlockDiagonalOperator
from
.operators.outer_product_operator
import
OuterProduct
from
.operators.simple_linear_operators
import
(
VdotOperator
,
SumReductionOperator
,
ConjugationOperator
,
Realizer
,
VdotOperator
,
ConjugationOperator
,
Realizer
,
FieldAdapter
,
GeometryRemover
,
NullOperator
)
from
.operators.energy_operators
import
(
EnergyOperator
,
GaussianEnergy
,
PoissonianEnergy
,
InverseGammaLikelihood
,
...
...
nifty5/field.py
View file @
c91ee857
...
...
@@ -327,6 +327,23 @@ class Field(object):
return
Field
.
from_local_data
(
self
.
_domain
,
aout
)
def
outer
(
self
,
x
):
""" Computes the outer product of 'self' with x.
Parameters
----------
x : Field
Returns
----------
Field, lives on the product space of self.domain and x.domain
"""
if
not
isinstance
(
x
,
Field
):
raise
TypeError
(
"The multiplier must be an instance of "
+
"the NIFTy field class"
)
from
.operators.outer_product_operator
import
OuterProduct
return
OuterProduct
(
self
,
x
.
domain
)(
x
)
def
vdot
(
self
,
x
=
None
,
spaces
=
None
):
""" Computes the dot product of 'self' with x.
...
...
nifty5/library/correlated_fields.py
View file @
c91ee857
...
...
@@ -67,8 +67,8 @@ def MfCorrelatedField(s_space_spatial, s_space_energy, amplitude_model_spatial,
pd_energy
=
PowerDistributor
(
pd_spatial
.
domain
,
p_space_energy
,
1
)
pd
=
pd_spatial
(
pd_energy
)
dom_distr_spatial
=
ContractionOperator
(
pd
.
domain
,
0
).
adjoint
dom_distr_energy
=
ContractionOperator
(
pd
.
domain
,
1
).
adjoint
dom_distr_spatial
=
ContractionOperator
(
pd
.
domain
,
1
).
adjoint
dom_distr_energy
=
ContractionOperator
(
pd
.
domain
,
0
).
adjoint
a_spatial
=
dom_distr_spatial
(
amplitude_model_spatial
)
a_energy
=
dom_distr_energy
(
amplitude_model_energy
)
...
...
nifty5/linearization.py
View file @
c91ee857
...
...
@@ -126,6 +126,19 @@ class Linearization(object):
def
__rmul__
(
self
,
other
):
return
self
.
__mul__
(
other
)
def
outer
(
self
,
other
):
from
.operators.outer_product_operator
import
OuterProduct
if
isinstance
(
other
,
Linearization
):
return
self
.
new
(
OuterProduct
(
self
.
_val
,
other
.
target
)(
other
.
_val
),
OuterProduct
(
self
.
_jac
(
self
.
_val
),
other
.
target
).
_myadd
(
OuterProduct
(
self
.
_val
,
other
.
target
)(
other
.
_jac
),
False
))
if
np
.
isscalar
(
other
):
return
self
.
__mul__
(
other
)
if
isinstance
(
other
,
(
Field
,
MultiField
)):
return
self
.
new
(
OuterProduct
(
self
.
_val
,
other
.
domain
)(
other
),
OuterProduct
(
self
.
_jac
(
self
.
_val
),
other
.
domain
))
def
vdot
(
self
,
other
):
from
.operators.simple_linear_operators
import
VdotOperator
if
isinstance
(
other
,
(
Field
,
MultiField
)):
...
...
@@ -137,11 +150,27 @@ class Linearization(object):
VdotOperator
(
self
.
_val
)(
other
.
_jac
)
+
VdotOperator
(
other
.
_val
)(
self
.
_jac
))
def
sum
(
self
):
from
.operators.simple_linear_operators
import
SumReductionOperator
return
self
.
new
(
Field
.
scalar
(
self
.
_val
.
sum
()),
SumReductionOperator
(
self
.
_jac
.
target
)(
self
.
_jac
))
def
sum
(
self
,
spaces
=
None
):
from
.operators.contraction_operator
import
ContractionOperator
if
spaces
is
None
:
return
self
.
new
(
Field
.
scalar
(
self
.
_val
.
sum
()),
ContractionOperator
(
self
.
_jac
.
target
,
None
)(
self
.
_jac
))
else
:
return
self
.
new
(
self
.
_val
.
sum
(
spaces
),
ContractionOperator
(
self
.
_jac
.
target
,
spaces
)(
self
.
_jac
))
def
integrate
(
self
,
spaces
=
None
):
from
.operators.contraction_operator
import
ContractionOperator
if
spaces
is
None
:
return
self
.
new
(
Field
.
scalar
(
self
.
_val
.
integrate
()),
ContractionOperator
(
self
.
_jac
.
target
,
None
,
1
)(
self
.
_jac
))
else
:
return
self
.
new
(
self
.
_val
.
integrate
(
spaces
),
ContractionOperator
(
self
.
_jac
.
target
,
spaces
,
1
)(
self
.
_jac
))
def
exp
(
self
):
tmp
=
self
.
_val
.
exp
()
...
...
nifty5/operators/contraction_operator.py
View file @
c91ee857
...
...
@@ -37,28 +37,37 @@ class ContractionOperator(LinearOperator):
----------
domain : Domain, tuple of Domain or DomainTuple
spaces : int or tuple of int
The elements of "domain" which are taken as target.
The elements of "domain" which are contracted.
weight : int, default=0
if nonzero, the fields living on self.domain are weighted with the
specified power.
"""
def
__init__
(
self
,
domain
,
spaces
):
def
__init__
(
self
,
domain
,
spaces
,
weight
=
0
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_spaces
=
utilities
.
parse_spaces
(
spaces
,
len
(
self
.
_domain
))
self
.
_target
=
[
dom
for
i
,
dom
in
enumerate
(
self
.
_domain
)
if
i
in
self
.
_spaces
dom
for
i
,
dom
in
enumerate
(
self
.
_domain
)
if
i
not
in
self
.
_spaces
]
self
.
_target
=
DomainTuple
.
make
(
self
.
_target
)
self
.
_weight
=
weight
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
ADJOINT_TIMES
:
ldat
=
x
.
lo
c
al_data
if
0
in
self
.
_spaces
else
x
.
to_g
lo
b
al_data
()
ldat
=
x
.
to_g
lo
b
al_data
()
if
0
in
self
.
_spaces
else
x
.
lo
c
al_data
shp
=
[]
for
i
,
dom
in
enumerate
(
self
.
_domain
):
tmp
=
dom
.
shape
if
i
>
0
else
dom
.
local_shape
shp
+=
tmp
if
i
in
self
.
_spaces
else
(
1
,)
*
len
(
dom
.
shape
)
shp
+=
tmp
if
i
not
in
self
.
_spaces
else
(
1
,)
*
len
(
dom
.
shape
)
ldat
=
np
.
broadcast_to
(
ldat
.
reshape
(
shp
),
self
.
_domain
.
local_shape
)
return
Field
.
from_local_data
(
self
.
_domain
,
ldat
)
res
=
Field
.
from_local_data
(
self
.
_domain
,
ldat
)
if
self
.
_weight
!=
0
:
res
=
res
.
weight
(
self
.
_weight
,
spaces
=
self
.
_spaces
)
return
res
else
:
return
x
.
sum
(
[
s
for
s
in
range
(
len
(
x
.
domain
))
if
s
not
in
self
.
_spaces
])
if
self
.
_weight
!=
0
:
x
=
x
.
weight
(
self
.
_weight
,
spaces
=
self
.
_spaces
)
res
=
x
.
sum
(
self
.
_spaces
)
return
res
if
isinstance
(
res
,
Field
)
else
Field
.
scalar
(
res
)
nifty5/operators/outer_product_operator.py
0 → 100644
View file @
c91ee857
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
__future__
import
absolute_import
,
division
,
print_function
import
itertools
import
numpy
as
np
from
..
import
dobj
,
utilities
from
..compat
import
*
from
..domain_tuple
import
DomainTuple
from
..domains.rg_space
import
RGSpace
from
..multi_field
import
MultiField
,
MultiDomain
from
..field
import
Field
from
.linear_operator
import
LinearOperator
import
operator
class
OuterProduct
(
LinearOperator
):
"""Performs the pointwise outer product of two fields.
Parameters
---------
field: Field,
domain: DomainTuple, the domain of the input field
---------
"""
def
__init__
(
self
,
field
,
domain
):
self
.
_domain
=
domain
self
.
_field
=
field
self
.
_target
=
DomainTuple
.
make
(
tuple
(
sub_d
for
sub_d
in
field
.
domain
.
_dom
+
domain
.
_dom
))
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
return
Field
.
from_global_data
(
self
.
_target
,
np
.
multiply
.
outer
(
self
.
_field
.
to_global_data
(),
x
.
to_global_data
()))
axes
=
len
(
self
.
_field
.
shape
)
return
Field
.
from_global_data
(
self
.
_domain
,
np
.
tensordot
(
self
.
_field
.
to_global_data
(),
x
.
to_global_data
(),
axes
))
nifty5/operators/simple_linear_operators.py
View file @
c91ee857
...
...
@@ -43,19 +43,6 @@ class VdotOperator(LinearOperator):
return
self
.
_field
*
x
.
local_data
[()]
class
SumReductionOperator
(
LinearOperator
):
def
__init__
(
self
,
domain
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_target
=
DomainTuple
.
scalar_domain
()
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
return
Field
.
scalar
(
x
.
sum
())
return
full
(
self
.
_domain
,
x
.
local_data
[()])
class
ConjugationOperator
(
EndomorphicOperator
):
def
__init__
(
self
,
domain
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
...
...
test/test_field.py
View file @
c91ee857
...
...
@@ -136,6 +136,34 @@ class Test_Functionality(unittest.TestCase):
res
=
m
.
vdot
(
m
,
spaces
=
1
)
assert_allclose
(
res
.
local_data
,
37.5
)
def
test_outer
(
self
):
x1
=
ift
.
RGSpace
((
9
,))
x2
=
ift
.
RGSpace
((
3
,))
m1
=
ift
.
Field
.
full
(
x1
,
.
5
)
m2
=
ift
.
Field
.
full
(
x2
,
3.
)
res
=
m1
.
outer
(
m2
)
assert_allclose
(
res
.
to_global_data
(),
np
.
full
((
9
,
3
,),
1.5
))
def
test_sum
(
self
):
x1
=
ift
.
RGSpace
((
9
,),
distances
=
2.
)
x2
=
ift
.
RGSpace
((
2
,
12
,),
distances
=
(
0.3
,))
m1
=
ift
.
Field
.
from_global_data
(
ift
.
makeDomain
(
x1
),
np
.
arange
(
9
))
m2
=
ift
.
Field
.
full
(
ift
.
makeDomain
((
x1
,
x2
)),
0.45
)
res1
=
m1
.
sum
()
res2
=
m2
.
sum
(
spaces
=
1
)
assert_allclose
(
res1
,
36
)
assert_allclose
(
res2
.
to_global_data
(),
np
.
full
(
9
,
2
*
12
*
0.45
))
def
test_integrate
(
self
):
x1
=
ift
.
RGSpace
((
9
,),
distances
=
2.
)
x2
=
ift
.
RGSpace
((
2
,
12
,),
distances
=
(
0.3
,))
m1
=
ift
.
Field
.
from_global_data
(
ift
.
makeDomain
(
x1
),
np
.
arange
(
9
))
m2
=
ift
.
Field
.
full
(
ift
.
makeDomain
((
x1
,
x2
)),
0.45
)
res1
=
m1
.
integrate
()
res2
=
m2
.
integrate
(
spaces
=
1
)
assert_allclose
(
res1
,
36
*
2
)
assert_allclose
(
res2
.
to_global_data
(),
np
.
full
(
9
,
2
*
12
*
0.45
*
0.3
**
2
))
def
test_dataconv
(
self
):
s1
=
ift
.
RGSpace
((
10
,))
ld
=
np
.
arange
(
ift
.
dobj
.
local_shape
(
s1
.
shape
)[
0
])
...
...
test/test_models/test_model_gradients.py
View file @
c91ee857
...
...
@@ -64,26 +64,29 @@ class Model_Tests(unittest.TestCase):
dom
=
ift
.
MultiDomain
.
union
((
dom1
,
dom2
))
model
=
ift
.
FieldAdapter
(
dom
,
"s1"
)
*
ift
.
FieldAdapter
(
dom
,
"s2"
)
pos
=
ift
.
from_random
(
"normal"
,
dom
)
ift
.
extra
.
check_value_gradient_consistency
(
model
,
pos
)
ift
.
extra
.
check_value_gradient_consistency
(
model
,
pos
,
ntries
=
20
)
model
=
ift
.
FieldAdapter
(
dom
,
"s1"
)
+
ift
.
FieldAdapter
(
dom
,
"s2"
)
pos
=
ift
.
from_random
(
"normal"
,
dom
)
ift
.
extra
.
check_value_gradient_consistency
(
model
,
pos
)
ift
.
extra
.
check_value_gradient_consistency
(
model
,
pos
,
ntries
=
20
)
model
=
ift
.
FieldAdapter
(
dom
,
"s1"
).
scale
(
3.
)
pos
=
ift
.
from_random
(
"normal"
,
dom1
)
ift
.
extra
.
check_value_gradient_consistency
(
model
,
pos
)
ift
.
extra
.
check_value_gradient_consistency
(
model
,
pos
,
ntries
=
20
)
model
=
ift
.
ScalingOperator
(
2.456
,
space
)(
ift
.
FieldAdapter
(
dom
,
"s1"
)
*
ift
.
FieldAdapter
(
dom
,
"s2"
))
pos
=
ift
.
from_random
(
"normal"
,
dom
)
ift
.
extra
.
check_value_gradient_consistency
(
model
,
pos
)
ift
.
extra
.
check_value_gradient_consistency
(
model
,
pos
,
ntries
=
20
)
model
=
ift
.
positive_tanh
(
ift
.
ScalingOperator
(
2.456
,
space
)(
ift
.
FieldAdapter
(
dom
,
"s1"
)
*
ift
.
FieldAdapter
(
dom
,
"s2"
)))
pos
=
ift
.
from_random
(
"normal"
,
dom
)
ift
.
extra
.
check_value_gradient_consistency
(
model
,
pos
)
ift
.
extra
.
check_value_gradient_consistency
(
model
,
pos
,
ntries
=
20
)
pos
=
ift
.
from_random
(
"normal"
,
dom
)
model
=
ift
.
OuterProduct
(
pos
[
's1'
],
ift
.
makeDomain
(
space
))
ift
.
extra
.
check_value_gradient_consistency
(
model
,
pos
[
's2'
],
ntries
=
20
)
if
isinstance
(
space
,
ift
.
RGSpace
):
model
=
ift
.
FFTOperator
(
space
)(
ift
.
FieldAdapter
(
dom
,
"s1"
)
*
ift
.
FieldAdapter
(
dom
,
"s2"
))
pos
=
ift
.
from_random
(
"normal"
,
dom
)
ift
.
extra
.
check_value_gradient_consistency
(
model
,
pos
)
ift
.
extra
.
check_value_gradient_consistency
(
model
,
pos
,
ntries
=
20
)
@
expand
(
product
(
[
ift
.
GLSpace
(
15
),
...
...
@@ -106,12 +109,12 @@ class Model_Tests(unittest.TestCase):
sv
,
im
,
iv
)
S
=
ift
.
ScalingOperator
(
1.
,
model
.
domain
)
pos
=
S
.
draw_sample
()
ift
.
extra
.
check_value_gradient_consistency
(
model
,
pos
)
ift
.
extra
.
check_value_gradient_consistency
(
model
,
pos
,
ntries
=
20
)
model2
=
ift
.
CorrelatedField
(
space
,
model
)
S
=
ift
.
ScalingOperator
(
1.
,
model2
.
domain
)
pos
=
S
.
draw_sample
()
ift
.
extra
.
check_value_gradient_consistency
(
model2
,
pos
)
ift
.
extra
.
check_value_gradient_consistency
(
model2
,
pos
,
ntries
=
20
)
@
expand
(
product
(
[
ift
.
GLSpace
(
15
),
...
...
@@ -125,7 +128,8 @@ class Model_Tests(unittest.TestCase):
q
=
0.73
model
=
ift
.
InverseGammaModel
(
space
,
alpha
,
q
)
# FIXME All those cdfs and ppfs are not very accurate
ift
.
extra
.
check_value_gradient_consistency
(
model
,
pos
,
tol
=
1e-2
)
ift
.
extra
.
check_value_gradient_consistency
(
model
,
pos
,
tol
=
1e-2
,
ntries
=
20
)
# @expand(product(
# ['Variable', 'Constant'],
...
...
test/test_operators/test_adjoint.py
View file @
c91ee857
...
...
@@ -65,12 +65,6 @@ class Consistency_Tests(unittest.TestCase):
dtype
=
dtype
))
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
@
expand
(
product
(
_h_spaces
+
_p_spaces
+
_pow_spaces
,
[
np
.
float64
,
np
.
complex128
]))
def
testSumReductionOperator
(
self
,
sp
,
dtype
):
op
=
ift
.
SumReductionOperator
(
sp
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
@
expand
(
product
([(
ift
.
RGSpace
(
10
,
harmonic
=
True
),
4
,
0
),
(
ift
.
RGSpace
((
24
,
31
),
distances
=
(
0.4
,
2.34
),
harmonic
=
True
),
3
,
0
),
...
...
@@ -193,11 +187,11 @@ class Consistency_Tests(unittest.TestCase):
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
@
expand
(
product
([
0
,
1
,
2
,
3
,
(
0
,
1
),
(
0
,
2
),
(
0
,
1
,
2
),
(
0
,
2
,
3
),
(
1
,
3
)],
[
np
.
float64
,
np
.
complex128
]))
def
testContractionOperator
(
self
,
spaces
,
dtype
):
dom
=
(
ift
.
RGSpace
(
10
),
ift
.
UnstructuredDomain
(
13
),
ift
.
GLSpace
(
5
),
[
0
,
1
,
2
,
-
1
],
[
np
.
float64
,
np
.
complex128
]))
def
testContractionOperator
(
self
,
spaces
,
wgt
,
dtype
):
dom
=
(
ift
.
RGSpace
(
10
),
ift
.
RGSpace
(
13
),
ift
.
GLSpace
(
5
),
ift
.
HPSpace
(
4
))
op
=
ift
.
ContractionOperator
(
dom
,
spaces
)
op
=
ift
.
ContractionOperator
(
dom
,
spaces
,
wgt
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
def
testDomainTupleFieldInserter
(
self
):
...
...
@@ -263,3 +257,17 @@ class Consistency_Tests(unittest.TestCase):
def
testRegridding
(
self
,
domain
,
shape
,
space
):
op
=
ift
.
RegriddingOperator
(
domain
,
shape
,
space
)
ift
.
extra
.
consistency_check
(
op
)
@
expand
(
product
([
ift
.
DomainTuple
.
make
((
ift
.
RGSpace
((
3
,
5
,
4
)),
ift
.
RGSpace
((
16
,),
distances
=
(
7.
,))),),
ift
.
DomainTuple
.
make
(
ift
.
HPSpace
(
12
),)],
[
ift
.
DomainTuple
.
make
((
ift
.
RGSpace
((
2
,)),
ift
.
GLSpace
(
10
)),),
ift
.
DomainTuple
.
make
(
ift
.
RGSpace
((
10
,
12
),
distances
=
(
0.1
,
1.
)),)]
))
def
testOuter
(
self
,
fdomain
,
domain
):
f
=
ift
.
from_random
(
'normal'
,
fdomain
)
op
=
ift
.
OuterProduct
(
f
,
domain
)
ift
.
extra
.
consistency_check
(
op
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment