Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
N
NIFTy
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
10
Issues
10
List
Boards
Labels
Service Desk
Milestones
Merge Requests
8
Merge Requests
8
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Operations
Operations
Incidents
Environments
Packages & Registries
Packages & Registries
Container Registry
Analytics
Analytics
CI / CD
Repository
Value Stream
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ift
NIFTy
Commits
2ec3da9a
Commit
2ec3da9a
authored
Jul 06, 2018
by
Martin Reinecke
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
optimize the Field constructor
parent
4ed58632
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
49 additions
and
81 deletions
+49
-81
nifty5/domain_tuple.py
nifty5/domain_tuple.py
+2
-2
nifty5/field.py
nifty5/field.py
+25
-56
nifty5/sugar.py
nifty5/sugar.py
+2
-2
test/test_field.py
test/test_field.py
+13
-15
test/test_minimization/test_minimizers.py
test/test_minimization/test_minimizers.py
+7
-6
No files found.
nifty5/domain_tuple.py
View file @
2ec3da9a
...
...
@@ -25,8 +25,8 @@ from .domains.domain import Domain
class
DomainTuple
(
object
):
"""Ordered sequence of Domain objects.
This class holds a
set of :class:`Domain` objects, which together form the
space on which a :class:`Field` is defined.
This class holds a
tuple of :class:`Domain` objects, which together form
the
space on which a :class:`Field` is defined.
Notes
-----
...
...
nifty5/field.py
View file @
2ec3da9a
...
...
@@ -33,45 +33,28 @@ class Field(object):
Parameters
----------
domain : None, DomainTuple, tuple of Domain, or Domain
domain : DomainTuple
the domain of the new Field
val : Field, data_object or scalar
The values the array should contain after init. A scalar input will
fill the whole array with this scalar. If a data_object is provided,
its dimensions must match the domain's.
dtype : type
A numpy.type. Most common are float and complex.
val : data_object
This object's global shape must match the domain shape
After construction, the object will no longer be writeable!
Notes
-----
If possible, do not invoke the constructor directly, but use one of the
many convenience functions for Field con
a
truction!
many convenience functions for Field con
s
truction!
"""
def
__init__
(
self
,
domain
=
None
,
val
=
None
,
dtype
=
None
):
self
.
_domain
=
self
.
_infer_domain
(
domain
=
domain
,
val
=
val
)
dtype
=
self
.
_infer_dtype
(
dtype
=
dtype
,
val
=
val
)
if
isinstance
(
val
,
Field
):
if
self
.
_domain
!=
val
.
_domain
:
raise
ValueError
(
"Domain mismatch"
)
self
.
_val
=
val
.
_val
elif
(
np
.
isscalar
(
val
)):
self
.
_val
=
dobj
.
full
(
self
.
_domain
.
shape
,
dtype
=
dtype
,
fill_value
=
val
)
elif
isinstance
(
val
,
dobj
.
data_object
):
if
self
.
_domain
.
shape
==
val
.
shape
:
if
dtype
==
val
.
dtype
:
self
.
_val
=
val
else
:
self
.
_val
=
dobj
.
from_object
(
val
,
dtype
,
True
,
True
)
else
:
raise
ValueError
(
"Shape mismatch"
)
else
:
raise
TypeError
(
"unknown source type"
)
def
__init__
(
self
,
domain
,
val
):
if
not
isinstance
(
domain
,
DomainTuple
):
raise
TypeError
(
"domain must be of type DomainTuple"
)
if
not
isinstance
(
val
,
dobj
.
data_object
):
raise
TypeError
(
"val must be of type dobj.data_object"
)
if
domain
.
shape
!=
val
.
shape
:
raise
ValueError
(
"mismatch between the shapes of val and domain"
)
self
.
_domain
=
domain
self
.
_val
=
val
dobj
.
lock
(
self
.
_val
)
# prevent implicit conversion to bool
...
...
@@ -99,7 +82,10 @@ class Field(object):
"""
if
not
np
.
isscalar
(
val
):
raise
TypeError
(
"val must be a scalar"
)
return
Field
(
DomainTuple
.
make
(
domain
),
val
)
if
not
(
np
.
isreal
(
val
)
or
np
.
iscomplex
(
val
)):
raise
TypeError
(
"need arithmetic scalar"
)
domain
=
DomainTuple
.
make
(
domain
)
return
Field
(
domain
,
dobj
.
full
(
domain
.
shape
,
fill_value
=
val
))
@
staticmethod
def
from_global_data
(
domain
,
arr
,
sum_up
=
False
):
...
...
@@ -118,12 +104,13 @@ class Field(object):
If False, the contens of `arr` are used directly, and must be
identical on all MPI tasks.
"""
return
Field
(
domain
,
dobj
.
from_global_data
(
arr
,
sum_up
))
return
Field
(
DomainTuple
.
make
(
domain
),
dobj
.
from_global_data
(
arr
,
sum_up
))
@
staticmethod
def
from_local_data
(
domain
,
arr
):
domain
=
DomainTuple
.
make
(
domain
)
return
Field
(
domain
,
dobj
.
from_local_data
(
domain
.
shape
,
arr
))
return
Field
(
DomainTuple
.
make
(
domain
),
dobj
.
from_local_data
(
domain
.
shape
,
arr
))
def
to_global_data
(
self
):
"""Returns an array containing the full data of the field.
...
...
@@ -167,25 +154,7 @@ class Field(object):
-----
No copy is made. If needed, use an additional copy() invocation.
"""
return
Field
(
new_domain
,
self
.
_val
)
@
staticmethod
def
_infer_domain
(
domain
,
val
=
None
):
if
domain
is
None
:
if
isinstance
(
val
,
Field
):
return
val
.
_domain
if
np
.
isscalar
(
val
):
return
DomainTuple
.
make
(())
# empty domain tuple
raise
TypeError
(
"could not infer domain from value"
)
return
DomainTuple
.
make
(
domain
)
@
staticmethod
def
_infer_dtype
(
dtype
,
val
):
if
dtype
is
not
None
:
return
dtype
if
val
is
None
:
raise
ValueError
(
"could not infer dtype"
)
return
np
.
result_type
(
val
)
return
Field
(
DomainTuple
.
make
(
new_domain
),
self
.
_val
)
@
staticmethod
def
from_random
(
random_type
,
domain
,
dtype
=
np
.
float64
,
**
kwargs
):
...
...
@@ -444,7 +413,7 @@ class Field(object):
for
i
,
dom
in
enumerate
(
self
.
_domain
)
if
i
not
in
spaces
)
return
Field
(
domain
=
return_domain
,
val
=
data
)
return
Field
(
DomainTuple
.
make
(
return_domain
),
data
)
def
sum
(
self
,
spaces
=
None
):
"""Sums up over the sub-domains given by `spaces`.
...
...
nifty5/sugar.py
View file @
2ec3da9a
...
...
@@ -40,7 +40,7 @@ def PS_field(pspace, func):
if
not
isinstance
(
pspace
,
PowerSpace
):
raise
TypeError
data
=
dobj
.
from_global_data
(
func
(
pspace
.
k_lengths
))
return
Field
(
pspace
,
val
=
data
)
return
Field
(
DomainTuple
.
make
(
pspace
),
data
)
def
get_signal_variance
(
spec
,
space
):
...
...
@@ -158,7 +158,7 @@ def _create_power_field(domain, power_spectrum):
if
not
isinstance
(
power_spectrum
.
domain
[
0
],
PowerSpace
):
raise
TypeError
(
"PowerSpace required"
)
power_domain
=
power_spectrum
.
domain
[
0
]
fp
=
Field
(
power_domain
,
val
=
power_spectrum
.
val
)
fp
=
power_spectrum
else
:
power_domain
=
PowerSpace
(
domain
)
fp
=
PS_field
(
power_domain
,
power_spectrum
)
...
...
test/test_field.py
View file @
2ec3da9a
...
...
@@ -139,16 +139,16 @@ class Test_Functionality(unittest.TestCase):
assert_equal
(
d
,
d2
)
def
test_empty_domain
(
self
):
f
=
ift
.
Field
((),
5
)
f
=
ift
.
Field
.
full
((),
5
)
assert_equal
(
f
.
local_data
,
5
)
f
=
ift
.
Field
(
None
,
5
)
f
=
ift
.
Field
.
full
(
None
,
5
)
assert_equal
(
f
.
local_data
,
5
)
def
test_trivialities
(
self
):
s1
=
ift
.
RGSpace
((
10
,))
f1
=
ift
.
Field
(
s1
,
27
)
f1
=
ift
.
Field
.
full
(
s1
,
27
)
assert_equal
(
f1
.
local_data
,
f1
.
real
.
local_data
)
f1
=
ift
.
Field
(
s1
,
27.
+
3j
)
f1
=
ift
.
Field
.
full
(
s1
,
27.
+
3j
)
assert_equal
(
f1
.
real
.
local_data
,
27.
)
assert_equal
(
f1
.
imag
.
local_data
,
3.
)
assert_equal
(
f1
.
local_data
,
+
f1
.
local_data
)
...
...
@@ -160,7 +160,7 @@ class Test_Functionality(unittest.TestCase):
def
test_weight
(
self
):
s1
=
ift
.
RGSpace
((
10
,))
f
=
ift
.
Field
(
s1
,
10.
)
f
=
ift
.
Field
.
full
(
s1
,
10.
)
f2
=
f
.
weight
(
1
)
assert_equal
(
f
.
weight
(
1
).
local_data
,
f2
.
local_data
)
assert_equal
(
f
.
total_volume
(),
1
)
...
...
@@ -170,7 +170,7 @@ class Test_Functionality(unittest.TestCase):
assert_equal
(
f
.
scalar_weight
(
0
),
0.1
)
assert_equal
(
f
.
scalar_weight
((
0
,)),
0.1
)
s1
=
ift
.
GLSpace
(
10
)
f
=
ift
.
Field
(
s1
,
10.
)
f
=
ift
.
Field
.
full
(
s1
,
10.
)
assert_equal
(
f
.
scalar_weight
(),
None
)
assert_equal
(
f
.
scalar_weight
(
0
),
None
)
assert_equal
(
f
.
scalar_weight
((
0
,)),
None
)
...
...
@@ -178,7 +178,7 @@ class Test_Functionality(unittest.TestCase):
@
expand
(
product
([
ift
.
RGSpace
(
10
),
ift
.
GLSpace
(
10
)],
[
np
.
float64
,
np
.
complex128
]))
def
test_reduction
(
self
,
dom
,
dt
):
s1
=
ift
.
Field
(
dom
,
1.
,
dtype
=
dt
)
s1
=
ift
.
Field
.
full
(
dom
,
dt
(
1.
)
)
assert_allclose
(
s1
.
mean
(),
1.
)
assert_allclose
(
s1
.
mean
(
0
),
1.
)
assert_allclose
(
s1
.
var
(),
0.
,
atol
=
1e-14
)
...
...
@@ -189,13 +189,11 @@ class Test_Functionality(unittest.TestCase):
def
test_err
(
self
):
s1
=
ift
.
RGSpace
((
10
,))
s2
=
ift
.
RGSpace
((
11
,))
f1
=
ift
.
Field
(
s1
,
27
)
f1
=
ift
.
Field
.
full
(
s1
,
27
)
with
assert_raises
(
ValueError
):
f2
=
ift
.
Field
(
s2
,
f1
)
with
assert_raises
(
ValueError
):
f2
=
ift
.
Field
(
s2
,
f1
.
val
)
f2
=
ift
.
Field
(
ift
.
DomainTuple
.
make
(
s2
),
f1
.
val
)
with
assert_raises
(
TypeError
):
f2
=
ift
.
Field
(
s2
,
"xyz"
)
f2
=
ift
.
Field
.
full
(
s2
,
"xyz"
)
with
assert_raises
(
TypeError
):
if
f1
:
pass
...
...
@@ -203,20 +201,20 @@ class Test_Functionality(unittest.TestCase):
f1
.
full
((
2
,
4
,
6
))
with
assert_raises
(
TypeError
):
f2
=
ift
.
Field
(
None
,
None
)
with
assert_raises
(
Valu
eError
):
with
assert_raises
(
Typ
eError
):
f2
=
ift
.
Field
(
s1
,
None
)
with
assert_raises
(
ValueError
):
f1
.
imag
with
assert_raises
(
TypeError
):
f1
.
vdot
(
42
)
with
assert_raises
(
ValueError
):
f1
.
vdot
(
ift
.
Field
(
s2
,
1.
))
f1
.
vdot
(
ift
.
Field
.
full
(
s2
,
1.
))
with
assert_raises
(
TypeError
):
ift
.
full
(
s1
,
[
2
,
3
])
def
test_stdfunc
(
self
):
s
=
ift
.
RGSpace
((
200
,))
f
=
ift
.
Field
(
s
,
27
)
f
=
ift
.
Field
.
full
(
s
,
27
)
assert_equal
(
f
.
local_data
,
27
)
assert_equal
(
f
.
shape
,
(
200
,))
assert_equal
(
f
.
dtype
,
np
.
int
)
...
...
test/test_minimization/test_minimizers.py
View file @
2ec3da9a
...
...
@@ -133,7 +133,7 @@ class Test_Minimizers(unittest.TestCase):
@
expand
(
product
(
minimizers
+
slow_minimizers
))
def
test_gauss
(
self
,
minimizer
):
space
=
ift
.
UnstructuredDomain
((
1
,))
starting_point
=
ift
.
Field
(
space
,
val
=
3.
)
starting_point
=
ift
.
Field
.
full
(
space
,
3.
)
class
ExpEnergy
(
ift
.
Energy
):
def
__init__
(
self
,
position
):
...
...
@@ -147,14 +147,15 @@ class Test_Minimizers(unittest.TestCase):
@
property
def
gradient
(
self
):
x
=
self
.
position
.
to_global_data
()[
0
]
return
ift
.
Field
(
self
.
position
.
domain
,
val
=
2
*
x
*
np
.
exp
(
-
(
x
**
2
)))
return
ift
.
Field
.
full
(
self
.
position
.
domain
,
2
*
x
*
np
.
exp
(
-
(
x
**
2
)))
@
property
def
metric
(
self
):
x
=
self
.
position
.
to_global_data
()[
0
]
v
=
(
2
-
4
*
x
*
x
)
*
np
.
exp
(
-
x
**
2
)
return
ift
.
DiagonalOperator
(
ift
.
Field
(
self
.
position
.
domain
,
val
=
v
))
ift
.
Field
.
full
(
self
.
position
.
domain
,
v
))
try
:
minimizer
=
eval
(
minimizer
)
...
...
@@ -171,7 +172,7 @@ class Test_Minimizers(unittest.TestCase):
@
expand
(
product
(
minimizers
+
newton_minimizers
+
slow_minimizers
))
def
test_cosh
(
self
,
minimizer
):
space
=
ift
.
UnstructuredDomain
((
1
,))
starting_point
=
ift
.
Field
(
space
,
val
=
3.
)
starting_point
=
ift
.
Field
.
full
(
space
,
3.
)
class
CoshEnergy
(
ift
.
Energy
):
def
__init__
(
self
,
position
):
...
...
@@ -185,14 +186,14 @@ class Test_Minimizers(unittest.TestCase):
@
property
def
gradient
(
self
):
x
=
self
.
position
.
to_global_data
()[
0
]
return
ift
.
Field
(
self
.
position
.
domain
,
val
=
np
.
sinh
(
x
))
return
ift
.
Field
.
full
(
self
.
position
.
domain
,
np
.
sinh
(
x
))
@
property
def
metric
(
self
):
x
=
self
.
position
.
to_global_data
()[
0
]
v
=
np
.
cosh
(
x
)
return
ift
.
DiagonalOperator
(
ift
.
Field
(
self
.
position
.
domain
,
val
=
v
))
ift
.
Field
.
full
(
self
.
position
.
domain
,
v
))
try
:
minimizer
=
eval
(
minimizer
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment