Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Neel Shah
NIFTy
Commits
a33c8989
Commit
a33c8989
authored
Jun 30, 2017
by
Martin Reinecke
Browse files
first steps
parent
0c8a2610
Changes
106
Hide whitespace changes
Inline
Side-by-side
demos/probing.py
View file @
a33c8989
# -*- coding: utf-8 -*-
from
__future__
import
print_function
from
nifty
import
Field
,
RGSpace
,
DiagonalProberMixin
,
TraceProberMixin
,
\
Prober
,
DiagonalOperator
...
...
@@ -19,12 +20,12 @@ diagOp = DiagonalOperator(domain=x, diagonal=f)
diagProber
=
DiagonalProber
(
domain
=
x
)
diagProber
(
diagOp
)
print
(
f
-
diagProber
.
diagonal
).
norm
()
print
(
(
f
-
diagProber
.
diagonal
).
norm
()
)
multiProber
=
MultiProber
(
domain
=
x
)
multiProber
(
diagOp
)
print
(
f
-
multiProber
.
diagonal
).
norm
()
print
f
.
sum
()
-
multiProber
.
trace
print
(
(
f
-
multiProber
.
diagonal
).
norm
()
)
print
(
f
.
sum
()
-
multiProber
.
trace
)
demos/wiener_filter.py
View file @
a33c8989
from
__future__
import
division
from
nifty
import
*
#import plotly.offline as pl
...
...
@@ -11,7 +12,7 @@ rank = comm.rank
if
__name__
==
"__main__"
:
distribution_strategy
=
'not'
#Setting up physical constants
#total length of Interval or Volume the field lives on, e.g. in meters
L
=
2.
...
...
@@ -21,18 +22,18 @@ if __name__ == "__main__":
field_variance
=
2.
#smoothing length that response (in same unit as L)
response_sigma
=
0.1
#defining resolution (pixels per dimension)
N_pixels
=
512
#Setting up derived constants
k_0
=
1.
/
correlation_length
#note that field_variance**2 = a*k_0/4. for this analytic form of power
#spectrum
a
=
field_variance
**
2
/
k_0
*
4.
pow_spec
=
(
lambda
k
:
a
/
(
1
+
k
/
k_0
)
**
4
)
pow_spec
=
(
lambda
k
:
a
/
(
1
.
+
k
/
k_0
)
**
4
)
pixel_width
=
L
/
N_pixels
# Setting up the geometry
s_space
=
RGSpace
([
N_pixels
,
N_pixels
],
distances
=
pixel_width
)
fft
=
FFTOperator
(
s_space
)
...
...
demos/wiener_filter_hamiltonian.py
View file @
a33c8989
from
__future__
import
division
from
__future__
import
print_function
from
builtins
import
object
from
nifty
import
*
...
...
@@ -62,11 +65,11 @@ if __name__ == "__main__":
p_space
=
PowerSpace
(
h_space
,
distribution_strategy
=
distribution_strategy
)
# create the field instances and power operator
pow_spec
=
(
lambda
k
:
(
42
/
(
k
+
1
)
**
3
))
pow_spec
=
(
lambda
k
:
(
42
.
/
(
k
+
1
)
**
3
))
S
=
create_power_operator
(
h_space
,
power_spectrum
=
pow_spec
,
distribution_strategy
=
distribution_strategy
)
sp
=
Field
(
p_space
,
val
=
lambda
z
:
pow_spec
(
z
)
**
(
1.
/
2
)
,
sp
=
Field
(
p_space
,
val
=
lambda
z
:
pow_spec
(
z
)
**
0.5
,
distribution_strategy
=
distribution_strategy
)
sh
=
sp
.
power_synthesize
(
real_signal
=
True
)
ss
=
fft
.
inverse_times
(
sh
)
...
...
@@ -92,7 +95,7 @@ if __name__ == "__main__":
def
distance_measure
(
energy
,
iteration
):
x
=
energy
.
position
print
(
iteration
,
(
(
x
-
ss
).
norm
()
/
ss
.
norm
()).
real
)
print
(
(
iteration
,
(
x
-
ss
).
norm
()
/
ss
.
norm
()).
real
)
)
# minimizer = SteepestDescent(convergence_tolerance=0,
# iteration_limit=50,
...
...
demos/wiener_filter_harmonic.py
View file @
a33c8989
from
__future__
import
division
from
builtins
import
range
from
nifty
import
*
from
mpi4py
import
MPI
import
plotly.offline
as
py
...
...
@@ -12,7 +14,7 @@ def plot_maps(x, name):
trace
=
[
None
]
*
len
(
x
)
keys
=
x
.
keys
()
keys
=
list
(
x
.
keys
()
)
field
=
x
[
keys
[
0
]]
domain
=
field
.
domain
[
0
]
shape
=
len
(
domain
.
shape
)
...
...
@@ -21,14 +23,14 @@ def plot_maps(x, name):
x_axis
=
np
.
arange
(
0
,
max_n
,
step
)
if
shape
==
1
:
for
ii
in
x
range
(
len
(
x
)):
for
ii
in
range
(
len
(
x
)):
trace
[
ii
]
=
go
.
Scatter
(
x
=
x_axis
,
y
=
x
[
keys
[
ii
]].
val
.
get_full_data
(),
name
=
keys
[
ii
])
fig
=
go
.
Figure
(
data
=
trace
)
py
.
plot
(
fig
,
filename
=
name
)
elif
shape
==
2
:
for
ii
in
x
range
(
len
(
x
)):
for
ii
in
range
(
len
(
x
)):
py
.
plot
([
go
.
Heatmap
(
z
=
x
[
keys
[
ii
]].
val
.
get_full_data
())],
filename
=
keys
[
ii
])
else
:
raise
TypeError
(
"Only 1D and 2D field plots are supported"
)
...
...
@@ -48,12 +50,12 @@ def plot_power(x, name):
trace
=
[
None
]
*
len
(
x
)
keys
=
x
.
keys
()
keys
=
list
(
x
.
keys
()
)
field
=
x
[
keys
[
0
]]
domain
=
field
.
domain
[
0
]
x_axis
=
domain
.
kindex
for
ii
in
x
range
(
len
(
x
)):
for
ii
in
range
(
len
(
x
)):
trace
[
ii
]
=
go
.
Scatter
(
x
=
x_axis
,
y
=
x
[
keys
[
ii
]].
val
.
get_full_data
(),
name
=
keys
[
ii
])
fig
=
go
.
Figure
(
data
=
trace
,
layout
=
layout
)
...
...
demos/wiener_filter_unit.py
View file @
a33c8989
from
__future__
import
division
from
builtins
import
range
from
nifty
import
*
from
mpi4py
import
MPI
import
plotly.offline
as
py
...
...
@@ -12,7 +14,7 @@ def plot_maps(x, name):
trace
=
[
None
]
*
len
(
x
)
keys
=
x
.
keys
()
keys
=
list
(
x
.
keys
()
)
field
=
x
[
keys
[
0
]]
domain
=
field
.
domain
[
0
]
shape
=
len
(
domain
.
shape
)
...
...
@@ -21,14 +23,14 @@ def plot_maps(x, name):
x_axis
=
np
.
arange
(
0
,
max_n
,
step
)
if
shape
==
1
:
for
ii
in
x
range
(
len
(
x
)):
for
ii
in
range
(
len
(
x
)):
trace
[
ii
]
=
go
.
Scatter
(
x
=
x_axis
,
y
=
x
[
keys
[
ii
]].
val
.
get_full_data
(),
name
=
keys
[
ii
])
fig
=
go
.
Figure
(
data
=
trace
)
py
.
plot
(
fig
,
filename
=
name
)
elif
shape
==
2
:
for
ii
in
x
range
(
len
(
x
)):
for
ii
in
range
(
len
(
x
)):
py
.
plot
([
go
.
Heatmap
(
z
=
x
[
keys
[
ii
]].
val
.
get_full_data
())],
filename
=
keys
[
ii
])
else
:
raise
TypeError
(
"Only 1D and 2D field plots are supported"
)
...
...
@@ -48,12 +50,12 @@ def plot_power(x, name):
trace
=
[
None
]
*
len
(
x
)
keys
=
x
.
keys
()
keys
=
list
(
x
.
keys
()
)
field
=
x
[
keys
[
0
]]
domain
=
field
.
domain
[
0
]
x_axis
=
domain
.
kindex
for
ii
in
x
range
(
len
(
x
)):
for
ii
in
range
(
len
(
x
)):
trace
[
ii
]
=
go
.
Scatter
(
x
=
x_axis
,
y
=
x
[
keys
[
ii
]].
val
.
get_full_data
(),
name
=
keys
[
ii
])
fig
=
go
.
Figure
(
data
=
trace
,
layout
=
layout
)
...
...
nifty/__init__.py
View file @
a33c8989
...
...
@@ -26,32 +26,32 @@ logger = MPILogger()
# it is important to import config before d2o such that NIFTy is able to
# pre-create d2o's configuration object with the corrected path
from
config
import
dependency_injector
,
\
from
.
config
import
dependency_injector
,
\
nifty_configuration
,
\
d2o_configuration
from
d2o
import
distributed_data_object
,
d2o_librarian
from
energies
import
*
from
.
energies
import
*
from
field
import
Field
from
.
field
import
Field
from
random
import
Random
from
.
random
import
Random
from
basic_arithmetics
import
*
from
.
basic_arithmetics
import
*
from
nifty_utilities
import
*
from
.
nifty_utilities
import
*
from
field_types
import
*
from
.
field_types
import
*
from
minimization
import
*
from
.
minimization
import
*
from
spaces
import
*
from
.
spaces
import
*
from
operators
import
*
from
.
operators
import
*
from
probing
import
*
from
.
probing
import
*
from
sugar
import
*
from
.
sugar
import
*
import
plotting
from
.
import
plotting
nifty/basic_arithmetics.py
View file @
a33c8989
...
...
@@ -16,6 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
__future__
import
division
import
numpy
as
np
from
d2o
import
distributed_data_object
from
nifty.field
import
Field
...
...
nifty/config/__init__.py
View file @
a33c8989
...
...
@@ -17,7 +17,7 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from
nifty_config
import
dependency_injector
,
\
from
.
nifty_config
import
dependency_injector
,
\
nifty_configuration
from
d2o_config
import
d2o_configuration
from
.
d2o_config
import
d2o_configuration
nifty/config/nifty_config.py
View file @
a33c8989
...
...
@@ -16,7 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import
matplotlib_init
from
.
import
matplotlib_init
import
os
...
...
nifty/domain_object.py
View file @
a33c8989
...
...
@@ -16,14 +16,17 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
__future__
import
division
import
abc
from
nifty.nifty_meta
import
NiftyMeta
from
keepers
import
Loggable
,
\
Versionable
from
future.utils
import
with_metaclass
class
DomainObject
(
Versionable
,
Loggable
,
object
):
class
DomainObject
(
with_metaclass
(
NiftyMeta
,
type
(
'NewBase'
,
(
Versionable
,
Loggable
,
object
),
{}))):
"""The abstract class that can be used as a domain for a field.
This holds all the information and functionality a field needs to know
...
...
@@ -39,8 +42,6 @@ class DomainObject(Versionable, Loggable, object):
"""
__metaclass__
=
NiftyMeta
def
__init__
(
self
):
# _global_id is used in the Versioning module from keepers
self
.
_ignore_for_hash
=
[
'_global_id'
]
...
...
@@ -56,7 +57,7 @@ class DomainObject(Versionable, Loggable, object):
item
=
vars
(
self
)[
key
]
if
key
in
self
.
_ignore_for_hash
or
key
==
'_ignore_for_hash'
:
continue
result_hash
^=
item
.
__hash__
()
^
int
(
hash
(
key
)
/
117
)
result_hash
^=
item
.
__hash__
()
^
int
(
hash
(
key
)
/
/
117
)
return
result_hash
def
__eq__
(
self
,
x
):
...
...
@@ -75,7 +76,7 @@ class DomainObject(Versionable, Loggable, object):
"""
if
isinstance
(
x
,
type
(
self
)):
for
key
in
vars
(
self
).
keys
():
for
key
in
list
(
vars
(
self
).
keys
()
)
:
item1
=
vars
(
self
)[
key
]
if
key
in
self
.
_ignore_for_hash
or
key
==
'_ignore_for_hash'
:
continue
...
...
nifty/energies/__init__.py
View file @
a33c8989
...
...
@@ -16,6 +16,6 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
energy
import
Energy
from
line_energy
import
LineEnergy
from
memoization
import
memo
from
.
energy
import
Energy
from
.
line_energy
import
LineEnergy
from
.
memoization
import
memo
nifty/energies/energy.py
View file @
a33c8989
...
...
@@ -19,9 +19,10 @@
from
nifty.nifty_meta
import
NiftyMeta
from
keepers
import
Loggable
from
future.utils
import
with_metaclass
class
Energy
(
Loggable
,
object
):
class
Energy
(
with_metaclass
(
NiftyMeta
,
type
(
'NewBase'
,
(
Loggable
,
object
)
,
{})))
:
""" Provides the functional used by minimization schemes.
The Energy object is an implementation of a scalar function including its
...
...
@@ -63,8 +64,6 @@ class Energy(Loggable, object):
"""
__metaclass__
=
NiftyMeta
def
__init__
(
self
,
position
):
self
.
_cache
=
{}
try
:
...
...
nifty/field.py
View file @
a33c8989
...
...
@@ -17,6 +17,9 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from
__future__
import
division
from
builtins
import
zip
from
builtins
import
str
from
builtins
import
range
import
itertools
import
numpy
as
np
...
...
@@ -35,6 +38,7 @@ from nifty.spaces.power_space import PowerSpace
import
nifty.nifty_utilities
as
utilities
from
nifty.random
import
Random
from
functools
import
reduce
class
Field
(
Loggable
,
Versionable
,
object
):
...
...
@@ -337,7 +341,7 @@ class Field(Loggable, Versionable, object):
# check if the `spaces` input is valid
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
self
.
domain
))
if
spaces
is
None
:
spaces
=
range
(
len
(
self
.
domain
))
spaces
=
list
(
range
(
len
(
self
.
domain
))
)
if
len
(
spaces
)
==
0
:
raise
ValueError
(
...
...
@@ -518,7 +522,7 @@ class Field(Loggable, Versionable, object):
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
self
.
domain
))
if
spaces
is
None
:
spaces
=
range
(
len
(
self
.
domain
))
spaces
=
list
(
range
(
len
(
self
.
domain
))
)
for
power_space_index
in
spaces
:
power_space
=
self
.
domain
[
power_space_index
]
...
...
@@ -602,7 +606,7 @@ class Field(Loggable, Versionable, object):
domain_axes
[
spaces
[
0
]],
preserve_gaussian_variance
=
preserve_gaussian_variance
)
# hermitianize all remaining spaces using the iterative formula
for
space
in
x
range
(
1
,
len
(
spaces
)):
for
space
in
range
(
1
,
len
(
spaces
)):
(
hh
,
ha
)
=
domain
[
space
].
hermitian_decomposition
(
h
,
domain_axes
[
space
],
...
...
@@ -969,7 +973,7 @@ class Field(Loggable, Versionable, object):
fast_copyable
=
True
try
:
for
i
in
x
range
(
len
(
self
.
domain
)):
for
i
in
range
(
len
(
self
.
domain
)):
if
self
.
domain
[
i
]
is
not
domain
[
i
]:
fast_copyable
=
False
break
...
...
@@ -991,7 +995,7 @@ class Field(Loggable, Versionable, object):
# repair its class
new_field
.
__class__
=
self
.
__class__
# copy domain, codomain and val
for
key
,
value
in
self
.
__dict__
.
items
():
for
key
,
value
in
list
(
self
.
__dict__
.
items
()
)
:
if
key
!=
'_val'
:
new_field
.
__dict__
[
key
]
=
value
else
:
...
...
@@ -1028,7 +1032,7 @@ class Field(Loggable, Versionable, object):
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
self
.
domain
))
if
spaces
is
None
:
spaces
=
range
(
len
(
self
.
domain
))
spaces
=
list
(
range
(
len
(
self
.
domain
))
)
for
ind
,
sp
in
enumerate
(
self
.
domain
):
if
ind
in
spaces
:
...
...
@@ -1166,7 +1170,7 @@ class Field(Loggable, Versionable, object):
def
_contraction_helper
(
self
,
op
,
spaces
):
# build a list of all axes
if
spaces
is
None
:
spaces
=
x
range
(
len
(
self
.
domain
))
spaces
=
range
(
len
(
self
.
domain
))
else
:
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
self
.
domain
))
...
...
@@ -1186,7 +1190,7 @@ class Field(Loggable, Versionable, object):
return
data
else
:
return_domain
=
tuple
(
self
.
domain
[
i
]
for
i
in
x
range
(
len
(
self
.
domain
))
for
i
in
range
(
len
(
self
.
domain
))
if
i
not
in
spaces
)
return_field
=
Field
(
domain
=
return_domain
,
...
...
@@ -1234,7 +1238,7 @@ class Field(Loggable, Versionable, object):
if
isinstance
(
other
,
Field
):
try
:
assert
len
(
other
.
domain
)
==
len
(
self
.
domain
)
for
index
in
x
range
(
len
(
self
.
domain
)):
for
index
in
range
(
len
(
self
.
domain
)):
assert
other
.
domain
[
index
]
==
self
.
domain
[
index
]
except
AssertionError
:
raise
ValueError
(
...
...
@@ -1373,6 +1377,17 @@ class Field(Loggable, Versionable, object):
return
self
.
_binary_helper
(
other
,
op
=
'__rdiv__'
)
def
__rtruediv__
(
self
,
other
):
""" x.__rtruediv__(y) <==> y/x
See Also
--------
_builtin_helper
"""
return
self
.
_binary_helper
(
other
,
op
=
'__rtruediv__'
)
def
__idiv__
(
self
,
other
):
""" x.__idiv__(y) <==> x/=y
...
...
nifty/field_types/__init__.py
View file @
a33c8989
...
...
@@ -16,5 +16,5 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
field_type
import
FieldType
from
field_array
import
FieldArray
from
.
field_type
import
FieldType
from
.
field_array
import
FieldArray
nifty/field_types/field_array.py
View file @
a33c8989
...
...
@@ -16,7 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
field_type
import
FieldType
from
.field_type
import
FieldType
from
functools
import
reduce
class
FieldArray
(
FieldType
):
...
...
nifty/minimization/__init__.py
View file @
a33c8989
...
...
@@ -16,9 +16,9 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
line_searching
import
*
from
conjugate_gradient
import
ConjugateGradient
from
descent_minimizer
import
DescentMinimizer
from
steepest_descent
import
SteepestDescent
from
vl_bfgs
import
VL_BFGS
from
relaxed_newton
import
RelaxedNewton
from
.
line_searching
import
*
from
.
conjugate_gradient
import
ConjugateGradient
from
.
descent_minimizer
import
DescentMinimizer
from
.
steepest_descent
import
SteepestDescent
from
.
vl_bfgs
import
VL_BFGS
from
.
relaxed_newton
import
RelaxedNewton
nifty/minimization/descent_minimizer.py
View file @
a33c8989
...
...
@@ -16,6 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
__future__
import
division
import
abc
from
nifty.nifty_meta
import
NiftyMeta
...
...
@@ -24,9 +25,10 @@ import numpy as np
from
keepers
import
Loggable
from
.line_searching
import
LineSearchStrongWolfe
from
future.utils
import
with_metaclass
class
DescentMinimizer
(
Loggable
,
object
):
class
DescentMinimizer
(
with_metaclass
(
NiftyMeta
,
type
(
'NewBase'
,
(
Loggable
,
object
)
,
{})))
:
""" A base class used by gradient methods to find a local minimum.
Descent minimization methods are used to find a local minimum of a scalar
...
...
@@ -77,8 +79,6 @@ class DescentMinimizer(Loggable, object):
"""
__metaclass__
=
NiftyMeta
def
__init__
(
self
,
line_searcher
=
LineSearchStrongWolfe
(),
callback
=
None
,
convergence_tolerance
=
1E-4
,
convergence_level
=
3
,
iteration_limit
=
None
):
...
...
nifty/minimization/line_searching/__init__.py
View file @
a33c8989
...
...
@@ -16,5 +16,5 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
line_search
import
LineSearch
from
line_search_strong_wolfe
import
LineSearchStrongWolfe
from
.
line_search
import
LineSearch
from
.
line_search_strong_wolfe
import
LineSearchStrongWolfe
nifty/minimization/line_searching/line_search.py
View file @
a33c8989
...
...
@@ -21,9 +21,10 @@ import abc
from
keepers
import
Loggable
from
nifty
import
LineEnergy
from
future.utils
import
with_metaclass
class
LineSearch
(
Loggable
,
object
):
class
LineSearch
(
with_metaclass
(
abc
.
ABCMeta
,
type
(
'NewBase'
,
(
Loggable
,
object
)
,
{})))
:
"""Class for determining the optimal step size along some descent direction.
Initialize the line search procedure which can be used by a specific line
...
...
@@ -40,8 +41,6 @@ class LineSearch(Loggable, object):
Initial guess for the step length.
"""
__metaclass__
=
abc
.
ABCMeta
def
__init__
(
self
):
...
...
nifty/minimization/line_searching/line_search_strong_wolfe.py
View file @
a33c8989
...
...
@@ -16,6 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
__future__
import
division
from
builtins
import
range
import
numpy
as
np
from
.line_search
import
LineSearch
...
...
@@ -134,7 +136,7 @@ class LineSearchStrongWolfe(LineSearch):
phiprime_alpha0
=
phiprime_0
# start the minimization loop
for
i
in
x
range
(
max_iterations
):
for
i
in
range
(
max_iterations
):
energy_alpha1
=
self
.
line_energy
.
at
(
alpha1
)
phi_alpha1
=
energy_alpha1
.
value
if
alpha1
==
0
:
...
...
@@ -243,7 +245,7 @@ class LineSearchStrongWolfe(LineSearch):
alpha_recent
=
0
phi_recent
=
phi_0
for
i
in
x
range
(
max_iterations
):
for
i
in
range
(
max_iterations
):
delta_alpha
=
alpha_hi
-
alpha_lo
if
delta_alpha
<
0
:
a
,
b
=
alpha_hi
,
alpha_lo
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment