Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
0992e538
Commit
0992e538
authored
Jul 06, 2018
by
Martin Reinecke
Browse files
Merge branch 'tweak_field_construction' into 'NIFTy_5'
optimize the Field constructor See merge request ift/nifty-dev!40
parents
4ed58632
2ec3da9a
Changes
5
Hide whitespace changes
Inline
Side-by-side
nifty5/domain_tuple.py
View file @
0992e538
...
...
@@ -25,8 +25,8 @@ from .domains.domain import Domain
class
DomainTuple
(
object
):
"""Ordered sequence of Domain objects.
This class holds a
set
of :class:`Domain` objects, which together form
the
space on which a :class:`Field` is defined.
This class holds a
tuple
of :class:`Domain` objects, which together form
the
space on which a :class:`Field` is defined.
Notes
-----
...
...
nifty5/field.py
View file @
0992e538
...
...
@@ -33,45 +33,28 @@ class Field(object):
Parameters
----------
domain : None, DomainTuple, tuple of Domain, or Domain
domain : DomainTuple
the domain of the new Field
val : Field, data_object or scalar
The values the array should contain after init. A scalar input will
fill the whole array with this scalar. If a data_object is provided,
its dimensions must match the domain's.
dtype : type
A numpy.type. Most common are float and complex.
val : data_object
This object's global shape must match the domain shape
After construction, the object will no longer be writeable!
Notes
-----
If possible, do not invoke the constructor directly, but use one of the
many convenience functions for Field con
a
truction!
many convenience functions for Field con
s
truction!
"""
def
__init__
(
self
,
domain
=
None
,
val
=
None
,
dtype
=
None
):
self
.
_domain
=
self
.
_infer_domain
(
domain
=
domain
,
val
=
val
)
dtype
=
self
.
_infer_dtype
(
dtype
=
dtype
,
val
=
val
)
if
isinstance
(
val
,
Field
):
if
self
.
_domain
!=
val
.
_domain
:
raise
ValueError
(
"Domain mismatch"
)
self
.
_val
=
val
.
_val
elif
(
np
.
isscalar
(
val
)):
self
.
_val
=
dobj
.
full
(
self
.
_domain
.
shape
,
dtype
=
dtype
,
fill_value
=
val
)
elif
isinstance
(
val
,
dobj
.
data_object
):
if
self
.
_domain
.
shape
==
val
.
shape
:
if
dtype
==
val
.
dtype
:
self
.
_val
=
val
else
:
self
.
_val
=
dobj
.
from_object
(
val
,
dtype
,
True
,
True
)
else
:
raise
ValueError
(
"Shape mismatch"
)
else
:
raise
TypeError
(
"unknown source type"
)
def
__init__
(
self
,
domain
,
val
):
if
not
isinstance
(
domain
,
DomainTuple
):
raise
TypeError
(
"domain must be of type DomainTuple"
)
if
not
isinstance
(
val
,
dobj
.
data_object
):
raise
TypeError
(
"val must be of type dobj.data_object"
)
if
domain
.
shape
!=
val
.
shape
:
raise
ValueError
(
"mismatch between the shapes of val and domain"
)
self
.
_domain
=
domain
self
.
_val
=
val
dobj
.
lock
(
self
.
_val
)
# prevent implicit conversion to bool
...
...
@@ -99,7 +82,10 @@ class Field(object):
"""
if
not
np
.
isscalar
(
val
):
raise
TypeError
(
"val must be a scalar"
)
return
Field
(
DomainTuple
.
make
(
domain
),
val
)
if
not
(
np
.
isreal
(
val
)
or
np
.
iscomplex
(
val
)):
raise
TypeError
(
"need arithmetic scalar"
)
domain
=
DomainTuple
.
make
(
domain
)
return
Field
(
domain
,
dobj
.
full
(
domain
.
shape
,
fill_value
=
val
))
@
staticmethod
def
from_global_data
(
domain
,
arr
,
sum_up
=
False
):
...
...
@@ -118,12 +104,13 @@ class Field(object):
If False, the contens of `arr` are used directly, and must be
identical on all MPI tasks.
"""
return
Field
(
domain
,
dobj
.
from_global_data
(
arr
,
sum_up
))
return
Field
(
DomainTuple
.
make
(
domain
),
dobj
.
from_global_data
(
arr
,
sum_up
))
@
staticmethod
def
from_local_data
(
domain
,
arr
):
domain
=
DomainTuple
.
make
(
domain
)
return
Field
(
domain
,
dobj
.
from_local_data
(
domain
.
shape
,
arr
))
return
Field
(
DomainTuple
.
make
(
domain
)
,
dobj
.
from_local_data
(
domain
.
shape
,
arr
))
def
to_global_data
(
self
):
"""Returns an array containing the full data of the field.
...
...
@@ -167,25 +154,7 @@ class Field(object):
-----
No copy is made. If needed, use an additional copy() invocation.
"""
return
Field
(
new_domain
,
self
.
_val
)
@
staticmethod
def
_infer_domain
(
domain
,
val
=
None
):
if
domain
is
None
:
if
isinstance
(
val
,
Field
):
return
val
.
_domain
if
np
.
isscalar
(
val
):
return
DomainTuple
.
make
(())
# empty domain tuple
raise
TypeError
(
"could not infer domain from value"
)
return
DomainTuple
.
make
(
domain
)
@
staticmethod
def
_infer_dtype
(
dtype
,
val
):
if
dtype
is
not
None
:
return
dtype
if
val
is
None
:
raise
ValueError
(
"could not infer dtype"
)
return
np
.
result_type
(
val
)
return
Field
(
DomainTuple
.
make
(
new_domain
),
self
.
_val
)
@
staticmethod
def
from_random
(
random_type
,
domain
,
dtype
=
np
.
float64
,
**
kwargs
):
...
...
@@ -444,7 +413,7 @@ class Field(object):
for
i
,
dom
in
enumerate
(
self
.
_domain
)
if
i
not
in
spaces
)
return
Field
(
d
omain
=
return_domain
,
val
=
data
)
return
Field
(
D
omain
Tuple
.
make
(
return_domain
)
,
data
)
def
sum
(
self
,
spaces
=
None
):
"""Sums up over the sub-domains given by `spaces`.
...
...
nifty5/sugar.py
View file @
0992e538
...
...
@@ -40,7 +40,7 @@ def PS_field(pspace, func):
if
not
isinstance
(
pspace
,
PowerSpace
):
raise
TypeError
data
=
dobj
.
from_global_data
(
func
(
pspace
.
k_lengths
))
return
Field
(
pspace
,
val
=
data
)
return
Field
(
DomainTuple
.
make
(
pspace
)
,
data
)
def
get_signal_variance
(
spec
,
space
):
...
...
@@ -158,7 +158,7 @@ def _create_power_field(domain, power_spectrum):
if
not
isinstance
(
power_spectrum
.
domain
[
0
],
PowerSpace
):
raise
TypeError
(
"PowerSpace required"
)
power_domain
=
power_spectrum
.
domain
[
0
]
fp
=
Field
(
power_domain
,
val
=
power_spectrum
.
val
)
fp
=
power_spectrum
else
:
power_domain
=
PowerSpace
(
domain
)
fp
=
PS_field
(
power_domain
,
power_spectrum
)
...
...
test/test_field.py
View file @
0992e538
...
...
@@ -139,16 +139,16 @@ class Test_Functionality(unittest.TestCase):
assert_equal
(
d
,
d2
)
def
test_empty_domain
(
self
):
f
=
ift
.
Field
((),
5
)
f
=
ift
.
Field
.
full
((),
5
)
assert_equal
(
f
.
local_data
,
5
)
f
=
ift
.
Field
(
None
,
5
)
f
=
ift
.
Field
.
full
(
None
,
5
)
assert_equal
(
f
.
local_data
,
5
)
def
test_trivialities
(
self
):
s1
=
ift
.
RGSpace
((
10
,))
f1
=
ift
.
Field
(
s1
,
27
)
f1
=
ift
.
Field
.
full
(
s1
,
27
)
assert_equal
(
f1
.
local_data
,
f1
.
real
.
local_data
)
f1
=
ift
.
Field
(
s1
,
27.
+
3j
)
f1
=
ift
.
Field
.
full
(
s1
,
27.
+
3j
)
assert_equal
(
f1
.
real
.
local_data
,
27.
)
assert_equal
(
f1
.
imag
.
local_data
,
3.
)
assert_equal
(
f1
.
local_data
,
+
f1
.
local_data
)
...
...
@@ -160,7 +160,7 @@ class Test_Functionality(unittest.TestCase):
def
test_weight
(
self
):
s1
=
ift
.
RGSpace
((
10
,))
f
=
ift
.
Field
(
s1
,
10.
)
f
=
ift
.
Field
.
full
(
s1
,
10.
)
f2
=
f
.
weight
(
1
)
assert_equal
(
f
.
weight
(
1
).
local_data
,
f2
.
local_data
)
assert_equal
(
f
.
total_volume
(),
1
)
...
...
@@ -170,7 +170,7 @@ class Test_Functionality(unittest.TestCase):
assert_equal
(
f
.
scalar_weight
(
0
),
0.1
)
assert_equal
(
f
.
scalar_weight
((
0
,)),
0.1
)
s1
=
ift
.
GLSpace
(
10
)
f
=
ift
.
Field
(
s1
,
10.
)
f
=
ift
.
Field
.
full
(
s1
,
10.
)
assert_equal
(
f
.
scalar_weight
(),
None
)
assert_equal
(
f
.
scalar_weight
(
0
),
None
)
assert_equal
(
f
.
scalar_weight
((
0
,)),
None
)
...
...
@@ -178,7 +178,7 @@ class Test_Functionality(unittest.TestCase):
@
expand
(
product
([
ift
.
RGSpace
(
10
),
ift
.
GLSpace
(
10
)],
[
np
.
float64
,
np
.
complex128
]))
def
test_reduction
(
self
,
dom
,
dt
):
s1
=
ift
.
Field
(
dom
,
1.
,
dtype
=
dt
)
s1
=
ift
.
Field
.
full
(
dom
,
dt
(
1.
)
)
assert_allclose
(
s1
.
mean
(),
1.
)
assert_allclose
(
s1
.
mean
(
0
),
1.
)
assert_allclose
(
s1
.
var
(),
0.
,
atol
=
1e-14
)
...
...
@@ -189,13 +189,11 @@ class Test_Functionality(unittest.TestCase):
def
test_err
(
self
):
s1
=
ift
.
RGSpace
((
10
,))
s2
=
ift
.
RGSpace
((
11
,))
f1
=
ift
.
Field
(
s1
,
27
)
f1
=
ift
.
Field
.
full
(
s1
,
27
)
with
assert_raises
(
ValueError
):
f2
=
ift
.
Field
(
s2
,
f1
)
with
assert_raises
(
ValueError
):
f2
=
ift
.
Field
(
s2
,
f1
.
val
)
f2
=
ift
.
Field
(
ift
.
DomainTuple
.
make
(
s2
),
f1
.
val
)
with
assert_raises
(
TypeError
):
f2
=
ift
.
Field
(
s2
,
"xyz"
)
f2
=
ift
.
Field
.
full
(
s2
,
"xyz"
)
with
assert_raises
(
TypeError
):
if
f1
:
pass
...
...
@@ -203,20 +201,20 @@ class Test_Functionality(unittest.TestCase):
f1
.
full
((
2
,
4
,
6
))
with
assert_raises
(
TypeError
):
f2
=
ift
.
Field
(
None
,
None
)
with
assert_raises
(
Valu
eError
):
with
assert_raises
(
Typ
eError
):
f2
=
ift
.
Field
(
s1
,
None
)
with
assert_raises
(
ValueError
):
f1
.
imag
with
assert_raises
(
TypeError
):
f1
.
vdot
(
42
)
with
assert_raises
(
ValueError
):
f1
.
vdot
(
ift
.
Field
(
s2
,
1.
))
f1
.
vdot
(
ift
.
Field
.
full
(
s2
,
1.
))
with
assert_raises
(
TypeError
):
ift
.
full
(
s1
,
[
2
,
3
])
def
test_stdfunc
(
self
):
s
=
ift
.
RGSpace
((
200
,))
f
=
ift
.
Field
(
s
,
27
)
f
=
ift
.
Field
.
full
(
s
,
27
)
assert_equal
(
f
.
local_data
,
27
)
assert_equal
(
f
.
shape
,
(
200
,))
assert_equal
(
f
.
dtype
,
np
.
int
)
...
...
test/test_minimization/test_minimizers.py
View file @
0992e538
...
...
@@ -133,7 +133,7 @@ class Test_Minimizers(unittest.TestCase):
@
expand
(
product
(
minimizers
+
slow_minimizers
))
def
test_gauss
(
self
,
minimizer
):
space
=
ift
.
UnstructuredDomain
((
1
,))
starting_point
=
ift
.
Field
(
space
,
val
=
3.
)
starting_point
=
ift
.
Field
.
full
(
space
,
3.
)
class
ExpEnergy
(
ift
.
Energy
):
def
__init__
(
self
,
position
):
...
...
@@ -147,14 +147,15 @@ class Test_Minimizers(unittest.TestCase):
@
property
def
gradient
(
self
):
x
=
self
.
position
.
to_global_data
()[
0
]
return
ift
.
Field
(
self
.
position
.
domain
,
val
=
2
*
x
*
np
.
exp
(
-
(
x
**
2
)))
return
ift
.
Field
.
full
(
self
.
position
.
domain
,
2
*
x
*
np
.
exp
(
-
(
x
**
2
)))
@
property
def
metric
(
self
):
x
=
self
.
position
.
to_global_data
()[
0
]
v
=
(
2
-
4
*
x
*
x
)
*
np
.
exp
(
-
x
**
2
)
return
ift
.
DiagonalOperator
(
ift
.
Field
(
self
.
position
.
domain
,
val
=
v
))
ift
.
Field
.
full
(
self
.
position
.
domain
,
v
))
try
:
minimizer
=
eval
(
minimizer
)
...
...
@@ -171,7 +172,7 @@ class Test_Minimizers(unittest.TestCase):
@
expand
(
product
(
minimizers
+
newton_minimizers
+
slow_minimizers
))
def
test_cosh
(
self
,
minimizer
):
space
=
ift
.
UnstructuredDomain
((
1
,))
starting_point
=
ift
.
Field
(
space
,
val
=
3.
)
starting_point
=
ift
.
Field
.
full
(
space
,
3.
)
class
CoshEnergy
(
ift
.
Energy
):
def
__init__
(
self
,
position
):
...
...
@@ -185,14 +186,14 @@ class Test_Minimizers(unittest.TestCase):
@
property
def
gradient
(
self
):
x
=
self
.
position
.
to_global_data
()[
0
]
return
ift
.
Field
(
self
.
position
.
domain
,
val
=
np
.
sinh
(
x
))
return
ift
.
Field
.
full
(
self
.
position
.
domain
,
np
.
sinh
(
x
))
@
property
def
metric
(
self
):
x
=
self
.
position
.
to_global_data
()[
0
]
v
=
np
.
cosh
(
x
)
return
ift
.
DiagonalOperator
(
ift
.
Field
(
self
.
position
.
domain
,
val
=
v
))
ift
.
Field
.
full
(
self
.
position
.
domain
,
v
))
try
:
minimizer
=
eval
(
minimizer
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment