Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
a535e0e8
Commit
a535e0e8
authored
Jul 07, 2018
by
Martin Reinecke
Browse files
cleanups and fixes
parent
69a50640
Changes
10
Hide whitespace changes
Inline
Side-by-side
demos/getting_started_3.py
View file @
a535e0e8
...
...
@@ -24,7 +24,7 @@ if __name__ == '__main__':
power_distributor
=
ift
.
PowerDistributor
(
harmonic_space
,
power_space
)
position
=
{}
position
[
'xi'
]
=
ift
.
Field
.
from_random
(
'normal'
,
harmonic_space
)
position
=
ift
.
MultiField
(
position
)
position
=
ift
.
MultiField
.
from_dict
(
position
)
xi
=
ift
.
Variable
(
position
)[
'xi'
]
Amp
=
power_distributor
(
A
)
...
...
nifty5/domain_tuple.py
View file @
a535e0e8
...
...
@@ -142,20 +142,6 @@ class DomainTuple(object):
def
__ne__
(
self
,
x
):
return
not
self
.
__eq__
(
x
)
def
compatibleTo
(
self
,
x
):
return
self
.
__eq__
(
x
)
def
subsetOf
(
self
,
x
):
return
self
.
__eq__
(
x
)
def
unitedWith
(
self
,
x
):
if
self
is
x
:
return
self
x
=
DomainTuple
.
make
(
x
)
if
self
is
not
x
:
raise
ValueError
(
"domain mismatch"
)
return
self
def
__str__
(
self
):
res
=
"DomainTuple, len: "
+
str
(
len
(
self
))
for
i
in
self
:
...
...
nifty5/field.py
View file @
a535e0e8
...
...
@@ -109,7 +109,7 @@ class Field(object):
@
staticmethod
def
from_local_data
(
domain
,
arr
):
return
Field
(
DomainTuple
.
make
(
domain
),
dobj
.
from_local_data
(
domain
.
shape
,
arr
))
dobj
.
from_local_data
(
domain
.
shape
,
arr
))
def
to_global_data
(
self
):
"""Returns an array containing the full data of the field.
...
...
nifty5/library/amplitude_model.py
View file @
a535e0e8
...
...
@@ -58,7 +58,7 @@ def make_amplitude_model(s_space, Npixdof, ceps_a, ceps_k, sm, sv, im, iv,
fields
=
{
keys
[
0
]:
Field
.
from_random
(
'normal'
,
dof_space
),
keys
[
1
]:
Field
.
from_random
(
'normal'
,
param_space
)}
position
=
MultiField
(
fields
)
position
=
MultiField
.
from_dict
(
fields
)
dof_space
=
position
[
keys
[
0
]].
domain
[
0
]
kern
=
lambda
k
:
_ceps_kernel
(
dof_space
,
k
,
ceps_a
,
ceps_k
)
...
...
nifty5/multi/block_diagonal_operator.py
View file @
a535e0e8
...
...
@@ -31,12 +31,15 @@ class BlockDiagonalOperator(EndomorphicOperator):
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
return
MultiField
(
x
.
domain
,
tuple
(
self
.
_operators
[
key
].
apply
(
x
.
_val
[
i
],
mode
=
mode
)
for
i
,
key
in
enumerate
(
x
.
keys
())))
val
=
tuple
(
self
.
_operators
[
key
].
apply
(
x
.
_val
[
i
],
mode
=
mode
)
for
i
,
key
in
enumerate
(
x
.
keys
()))
return
MultiField
(
self
.
_domain
,
val
)
def
draw_sample
(
self
,
from_inverse
=
False
,
dtype
=
np
.
float64
):
dtype
=
MultiField
.
build_dtype
(
dtype
,
self
.
_domain
)
return
MultiField
.
from_dict
({
key
:
op
.
draw_sample
(
from_inverse
,
dtype
[
key
])
for
key
,
op
in
self
.
_operators
.
items
()})
val
=
tuple
(
self
.
_operators
[
key
].
draw_sample
(
from_inverse
,
dtype
[
key
])
for
key
in
self
.
_domain
.
_keys
)
return
MultiField
(
self
.
_domain
,
val
)
def
_combine_chain
(
self
,
op
):
res
=
{}
...
...
nifty5/multi/multi_domain.py
View file @
a535e0e8
...
...
@@ -61,52 +61,3 @@ class MultiDomain(object):
def
__ne__
(
self
,
x
):
return
not
self
.
__eq__
(
x
)
def
compatibleTo
(
self
,
x
):
if
self
is
x
:
return
True
x
=
MultiDomain
.
make
(
x
)
if
self
is
x
:
return
True
if
(
self
,
x
)
in
MultiDomain
.
_compatCache
:
return
True
commonKeys
=
set
(
self
.
keys
())
&
set
(
x
.
keys
())
for
key
in
commonKeys
:
if
self
[
key
]
is
not
x
[
key
]:
return
False
MultiDomain
.
_compatCache
.
add
((
self
,
x
))
MultiDomain
.
_compatCache
.
add
((
x
,
self
))
return
True
def
subsetOf
(
self
,
x
):
if
self
is
x
:
return
True
x
=
MultiDomain
.
make
(
x
)
if
self
is
x
:
return
True
if
len
(
x
)
==
0
:
return
True
if
(
self
,
x
)
in
MultiDomain
.
_subsetCache
:
return
True
for
key
in
self
.
keys
():
if
key
not
in
x
:
return
False
if
self
[
key
]
is
not
x
[
key
]:
return
False
MultiDomain
.
_subsetCache
.
add
((
self
,
x
))
return
True
def
unitedWith
(
self
,
x
):
if
self
is
x
:
return
self
x
=
MultiDomain
.
make
(
x
)
if
self
is
x
:
return
self
if
not
self
.
compatibleTo
(
x
):
raise
ValueError
(
"domain mismatch"
)
res
=
{}
for
key
,
val
in
self
.
items
():
res
[
key
]
=
val
for
key
,
val
in
x
.
items
():
res
[
key
]
=
val
return
MultiDomain
.
make
(
res
)
nifty5/multi/multi_field.py
View file @
a535e0e8
...
...
@@ -103,7 +103,7 @@ class MultiField(object):
# dtype = MultiField.build_dtype(dtype, domain)
return
MultiField
(
domain
,
tuple
(
Field
.
from_random
(
random_type
,
dom
,
dtype
,
**
kwargs
)
for
dom
in
domain
.
_domains
))
for
dom
in
domain
.
_domains
))
def
_check_domain
(
self
,
other
):
if
other
.
_domain
is
not
self
.
_domain
:
...
...
@@ -131,13 +131,14 @@ class MultiField(object):
for
dom
in
domain
.
_domains
))
def
to_global_data
(
self
):
return
{
key
:
val
.
to_global_data
()
for
key
,
val
in
zip
(
self
.
_domain
.
keys
(),
self
.
_val
)}
return
{
key
:
val
.
to_global_data
()
for
key
,
val
in
zip
(
self
.
_domain
.
keys
(),
self
.
_val
)}
@
staticmethod
def
from_global_data
(
domain
,
arr
,
sum_up
=
False
):
return
MultiField
(
domain
,
tuple
(
Field
.
from_global_data
(
domain
[
key
],
arr
[
key
],
sum_up
)
for
key
in
domain
.
keys
()))
arr
[
key
],
sum_up
)
for
key
in
domain
.
keys
()))
def
norm
(
self
):
""" Computes the L2-norm of the field values.
...
...
nifty5/operators/linear_operator.py
View file @
a535e0e8
...
...
@@ -282,5 +282,5 @@ class LinearOperator(NiftyMetaBase()):
def
_check_input
(
self
,
x
,
mode
):
self
.
_check_mode
(
mode
)
if
not
self
.
_dom
(
mode
)
.
subsetOf
(
x
.
domain
)
:
if
self
.
_dom
(
mode
)
is
not
x
.
domain
:
raise
ValueError
(
"The operator's and field's domains don't match."
)
nifty5/operators/selection_operator.py
View file @
a535e0e8
...
...
@@ -17,6 +17,7 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from
.linear_operator
import
LinearOperator
from
..multi.multi_domain
import
MultiDomain
class
SelectionOperator
(
LinearOperator
):
...
...
@@ -31,10 +32,7 @@ class SelectionOperator(LinearOperator):
String identifier of the wanted subdomain
"""
def
__init__
(
self
,
domain
,
key
):
from
..multi.multi_domain
import
MultiDomain
if
not
isinstance
(
domain
,
MultiDomain
):
raise
TypeError
(
"Domain must be a MultiDomain"
)
self
.
_domain
=
domain
self
.
_domain
=
MultiDomain
.
make
(
domain
)
self
.
_key
=
key
@
property
...
...
@@ -55,4 +53,6 @@ class SelectionOperator(LinearOperator):
return
x
[
self
.
_key
]
else
:
from
..multi.multi_field
import
MultiField
return
MultiField
.
from_dict
({
self
.
_key
:
x
})
rval
=
[
None
]
*
len
(
self
.
_domain
)
rval
[
self
.
_domain
.
_dict
[
self
.
_key
]]
=
x
return
MultiField
(
self
.
_domain
,
tuple
(
rval
))
nifty5/operators/sum_operator.py
View file @
a535e0e8
...
...
@@ -46,8 +46,8 @@ class SumOperator(LinearOperator):
dom
=
ops
[
0
].
domain
tgt
=
ops
[
0
].
target
for
op
in
ops
[
1
:]:
dom
=
dom
.
unitedWith
(
op
.
domain
)
tgt
=
tgt
.
unitedWith
(
op
.
target
)
if
dom
is
not
op
.
domain
or
tgt
is
not
op
.
target
:
raise
ValueError
(
"Domain mismatch"
)
# Step 2: unpack SumOperators
opsnew
=
[]
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment