Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
1297585d
Commit
1297585d
authored
Aug 22, 2017
by
Theo Steininger
Browse files
Merge branch 'python3' into 'master'
Add Python3 compatibility See merge request
!156
parents
da29a8f4
a476b875
Pipeline
#17025
passed with stages
in 24 minutes and 49 seconds
Changes
121
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
.gitlab-ci.yml
View file @
1297585d
...
@@ -14,21 +14,22 @@ before_script:
...
@@ -14,21 +14,22 @@ before_script:
-
chmod +x ci/*.sh
-
chmod +x ci/*.sh
-
ci/install_basics.sh
-
ci/install_basics.sh
-
pip install --upgrade -r ci/requirements.txt
-
pip install --upgrade -r ci/requirements.txt
-
pip3 install --upgrade -r ci/requirements.txt
test_min
:
test_min
:
stage
:
test
stage
:
test
script
:
script
:
-
python setup.py build_ext --inplace
-
nosetests -vv
-
nosetests -vv
-
nosetests3 -vv
test_mpi
:
test_mpi
:
stage
:
test
stage
:
test
script
:
script
:
-
ci/install_pyHealpix.sh
-
ci/install_pyHealpix.sh
-
ci/install_mpi4py.sh
-
ci/install_mpi4py.sh
-
python setup.py build_ext --inplace
-
nosetests -vv
-
nosetests -vv
-
nosetests3 -vv
test_mpi_fftw
:
test_mpi_fftw
:
stage
:
test
stage
:
test
...
@@ -36,8 +37,8 @@ test_mpi_fftw:
...
@@ -36,8 +37,8 @@ test_mpi_fftw:
-
ci/install_pyHealpix.sh
-
ci/install_pyHealpix.sh
-
ci/install_mpi4py.sh
-
ci/install_mpi4py.sh
-
ci/install_pyfftw.sh
-
ci/install_pyfftw.sh
-
python setup.py build_ext --inplace
-
nosetests -vv
-
nosetests -vv
-
nosetests3 -vv
test_mpi_fftw_hdf5
:
test_mpi_fftw_hdf5
:
stage
:
test
stage
:
test
...
@@ -46,9 +47,10 @@ test_mpi_fftw_hdf5:
...
@@ -46,9 +47,10 @@ test_mpi_fftw_hdf5:
-
ci/install_mpi4py.sh
-
ci/install_mpi4py.sh
-
ci/install_pyfftw.sh
-
ci/install_pyfftw.sh
-
ci/install_h5py.sh
-
ci/install_h5py.sh
-
python setup.py build_ext --inplace
-
mpiexec --allow-run-as-root -n 2 nosetests -x
-
mpiexec --allow-run-as-root -n 2 nosetests -x
-
mpiexec --allow-run-as-root -n 2 nosetests3 -x
-
mpiexec --allow-run-as-root -n 4 nosetests -x
-
mpiexec --allow-run-as-root -n 4 nosetests -x
-
mpiexec --allow-run-as-root -n 4 nosetests3 -x
-
nosetests -x --with-coverage --cover-package=nifty --cover-branches
-
nosetests -x --with-coverage --cover-package=nifty --cover-branches
-
>
-
>
coverage report | grep TOTAL | awk '{ print "TOTAL: "$6; }'
coverage report | grep TOTAL | awk '{ print "TOTAL: "$6; }'
...
...
ci/install_basics.sh
View file @
1297585d
#!/bin/bash
#!/bin/bash
apt-get
install
-y
build-essential python python-pip python-dev git autoconf libtool gsl-bin libgsl-dev wget
apt-get
install
-y
build-essential python python-pip python-dev git autoconf libtool gsl-bin libgsl-dev wget
python3 python3-pip python3-dev python3-nose
ci/install_mpi4py.sh
View file @
1297585d
...
@@ -2,3 +2,4 @@
...
@@ -2,3 +2,4 @@
apt-get
install
-y
openmpi-bin libopenmpi-dev
apt-get
install
-y
openmpi-bin libopenmpi-dev
pip
install
mpi4py
pip
install
mpi4py
pip3
install
mpi4py
ci/install_pyHealpix.sh
View file @
1297585d
...
@@ -2,4 +2,5 @@
...
@@ -2,4 +2,5 @@
git clone https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git
git clone https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git
(
cd
pyHealpix
&&
autoreconf
-i
&&
./configure
--enable-openmp
&&
make
-j4
install
)
(
cd
pyHealpix
&&
autoreconf
-i
&&
./configure
--enable-openmp
&&
make
-j4
install
)
(
cd
pyHealpix
&&
autoreconf
-i
&&
PYTHON
=
python3 ./configure
--enable-openmp
&&
make
-j4
install
)
rm
-rf
pyHealpix
rm
-rf
pyHealpix
ci/install_pyfftw.sh
View file @
1297585d
...
@@ -4,4 +4,5 @@ apt-get install -y libatlas-base-dev libfftw3-bin libfftw3-dev libfftw3-double3
...
@@ -4,4 +4,5 @@ apt-get install -y libatlas-base-dev libfftw3-bin libfftw3-dev libfftw3-double3
git clone
-b
mpi https://github.com/fredros/pyFFTW.git
git clone
-b
mpi https://github.com/fredros/pyFFTW.git
(
cd
pyFFTW
&&
CC
=
mpicc python setup.py build_ext
install
)
(
cd
pyFFTW
&&
CC
=
mpicc python setup.py build_ext
install
)
(
cd
pyFFTW
&&
CC
=
mpicc python3 setup.py build_ext
install
)
rm
-rf
pyFFTW
rm
-rf
pyFFTW
demos/paper_demos/cartesian_wiener_filter.py
View file @
1297585d
...
@@ -7,8 +7,9 @@ from nifty import plotting
...
@@ -7,8 +7,9 @@ from nifty import plotting
from
keepers
import
Repository
from
keepers
import
Repository
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
ift
.
nifty_configuration
[
'default_distribution_strategy'
]
=
'fftw'
signal_to_noise
=
1.5
# The signal to noise ratio
a
signal_to_noise
=
1.5
# The signal to noise ratio
# Setting up parameters |\label{code:wf_parameters}|
# Setting up parameters |\label{code:wf_parameters}|
...
@@ -29,7 +30,7 @@ if __name__ == "__main__":
...
@@ -29,7 +30,7 @@ if __name__ == "__main__":
harmonic_space_1
=
ift
.
FFTOperator
.
get_default_codomain
(
signal_space_1
)
harmonic_space_1
=
ift
.
FFTOperator
.
get_default_codomain
(
signal_space_1
)
fft_1
=
ift
.
FFTOperator
(
harmonic_space_1
,
target
=
signal_space_1
,
fft_1
=
ift
.
FFTOperator
(
harmonic_space_1
,
target
=
signal_space_1
,
domain_dtype
=
np
.
complex
,
target_dtype
=
np
.
complex
)
domain_dtype
=
np
.
complex
,
target_dtype
=
np
.
complex
)
power_space_1
=
ift
.
PowerSpace
(
harmonic_space_1
,
distribution_strategy
=
'fftw'
)
power_space_1
=
ift
.
PowerSpace
(
harmonic_space_1
)
mock_power_1
=
ift
.
Field
(
power_space_1
,
val
=
power_spectrum_1
,
mock_power_1
=
ift
.
Field
(
power_space_1
,
val
=
power_spectrum_1
,
distribution_strategy
=
'not'
)
distribution_strategy
=
'not'
)
...
@@ -67,20 +68,18 @@ if __name__ == "__main__":
...
@@ -67,20 +68,18 @@ if __name__ == "__main__":
distribution_strategy
=
'not'
)
distribution_strategy
=
'not'
)
diagonal
=
mock_power
.
power_synthesize
(
spaces
=
(
0
,
1
),
mean
=
1
,
std
=
0
,
diagonal
=
mock_power
.
power_synthesize
(
spaces
=
(
0
,
1
),
mean
=
1
,
std
=
0
,
real_signal
=
False
,
real_signal
=
False
)
**
2
distribution_strategy
=
'fftw'
)
**
2
S
=
ift
.
DiagonalOperator
(
domain
=
(
harmonic_space_1
,
harmonic_space_2
),
S
=
ift
.
DiagonalOperator
(
domain
=
(
harmonic_space_1
,
harmonic_space_2
),
diagonal
=
diagonal
)
diagonal
=
diagonal
)
np
.
random
.
seed
(
10
)
np
.
random
.
seed
(
10
)
mock_signal
=
fft
(
mock_power
.
power_synthesize
(
real_signal
=
True
,
mock_signal
=
fft
(
mock_power
.
power_synthesize
(
real_signal
=
True
))
distribution_strategy
=
'fftw'
))
# Setting up a exemplary response
# Setting up a exemplary response
N1_10
=
int
(
N_pixels_1
/
10
)
N1_10
=
int
(
N_pixels_1
/
10
)
mask_1
=
ift
.
Field
(
signal_space_1
,
val
=
1.
,
distribution_strategy
=
'fftw'
)
mask_1
=
ift
.
Field
(
signal_space_1
,
val
=
1.
)
mask_1
.
val
[
N1_10
*
7
:
N1_10
*
9
]
=
0.
mask_1
.
val
[
N1_10
*
7
:
N1_10
*
9
]
=
0.
N2_10
=
int
(
N_pixels_2
/
10
)
N2_10
=
int
(
N_pixels_2
/
10
)
...
@@ -95,12 +94,10 @@ if __name__ == "__main__":
...
@@ -95,12 +94,10 @@ if __name__ == "__main__":
# Setting up the noise covariance and drawing a random noise realization
# Setting up the noise covariance and drawing a random noise realization
N
=
ift
.
DiagonalOperator
(
data_domain
,
diagonal
=
mock_signal
.
var
()
/
signal_to_noise
,
N
=
ift
.
DiagonalOperator
(
data_domain
,
diagonal
=
mock_signal
.
var
()
/
signal_to_noise
,
bare
=
True
,
bare
=
True
)
distribution_strategy
=
'fftw'
)
noise
=
ift
.
Field
.
from_random
(
domain
=
data_domain
,
random_type
=
'normal'
,
noise
=
ift
.
Field
.
from_random
(
domain
=
data_domain
,
random_type
=
'normal'
,
std
=
mock_signal
.
std
()
/
np
.
sqrt
(
signal_to_noise
),
std
=
mock_signal
.
std
()
/
np
.
sqrt
(
signal_to_noise
),
mean
=
0
,
mean
=
0
)
distribution_strategy
=
'fftw'
)
data
=
R
(
mock_signal
)
+
noise
#|\label{code:wf_mock_data}|
data
=
R
(
mock_signal
)
+
noise
#|\label{code:wf_mock_data}|
# Wiener filter
# Wiener filter
...
@@ -133,7 +130,7 @@ if __name__ == "__main__":
...
@@ -133,7 +130,7 @@ if __name__ == "__main__":
plotter
.
plot
.
zmin
=
0.
plotter
.
plot
.
zmin
=
0.
plotter
.
plot
.
zmax
=
3.
plotter
.
plot
.
zmax
=
3.
sm
=
ift
.
SmoothingOperator
(
plot_space
,
sigma
=
0.03
)
sm
=
ift
.
SmoothingOperator
.
make
(
plot_space
,
sigma
=
0.03
)
plotter
(
ift
.
log
(
ift
.
sqrt
(
sm
(
ift
.
Field
(
plot_space
,
val
=
variance
.
val
.
real
)))),
path
=
'uncertainty.html'
)
plotter
(
ift
.
log
(
ift
.
sqrt
(
sm
(
ift
.
Field
(
plot_space
,
val
=
variance
.
val
.
real
)))),
path
=
'uncertainty.html'
)
plotter
.
plot
.
zmin
=
np
.
real
(
mock_signal
.
min
());
plotter
.
plot
.
zmin
=
np
.
real
(
mock_signal
.
min
());
...
...
demos/paper_demos/wiener_filter.py
View file @
1297585d
...
@@ -57,7 +57,7 @@ if __name__ == "__main__":
...
@@ -57,7 +57,7 @@ if __name__ == "__main__":
proby
=
Proby
(
signal_space
,
probe_count
=
800
)
proby
=
Proby
(
signal_space
,
probe_count
=
800
)
proby
(
lambda
z
:
fft
(
wiener_curvature
.
inverse_times
(
fft
.
inverse_times
(
z
))))
#|\label{code:wf_variance_fft_wrap}|
proby
(
lambda
z
:
fft
(
wiener_curvature
.
inverse_times
(
fft
.
inverse_times
(
z
))))
#|\label{code:wf_variance_fft_wrap}|
sm
=
ift
.
SmoothingOperator
(
signal_space
,
sigma
=
0.03
)
sm
=
ift
.
SmoothingOperator
.
make
(
signal_space
,
sigma
=
0.03
)
variance
=
ift
.
sqrt
(
sm
(
proby
.
diagonal
.
weight
(
-
1
)))
#|\label{code:wf_variance_weighting}|
variance
=
ift
.
sqrt
(
sm
(
proby
.
diagonal
.
weight
(
-
1
)))
#|\label{code:wf_variance_weighting}|
repo
=
Repository
(
'repo_800.h5'
)
repo
=
Repository
(
'repo_800.h5'
)
...
...
demos/probing.py
View file @
1297585d
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
from
__future__
import
print_function
from
nifty
import
Field
,
RGSpace
,
DiagonalProberMixin
,
TraceProberMixin
,
\
from
nifty
import
Field
,
RGSpace
,
DiagonalProberMixin
,
TraceProberMixin
,
\
Prober
,
DiagonalOperator
Prober
,
DiagonalOperator
...
@@ -19,12 +20,9 @@ diagOp = DiagonalOperator(domain=x, diagonal=f)
...
@@ -19,12 +20,9 @@ diagOp = DiagonalOperator(domain=x, diagonal=f)
diagProber
=
DiagonalProber
(
domain
=
x
)
diagProber
=
DiagonalProber
(
domain
=
x
)
diagProber
(
diagOp
)
diagProber
(
diagOp
)
print
(
f
-
diagProber
.
diagonal
).
norm
()
print
(
(
f
-
diagProber
.
diagonal
).
norm
()
)
multiProber
=
MultiProber
(
domain
=
x
)
multiProber
=
MultiProber
(
domain
=
x
)
multiProber
(
diagOp
)
multiProber
(
diagOp
)
print
(
f
-
multiProber
.
diagonal
).
norm
()
print
((
f
-
multiProber
.
diagonal
).
norm
())
print
f
.
sum
()
-
multiProber
.
trace
print
(
f
.
sum
()
-
multiProber
.
trace
)
nifty/__init__.py
View file @
1297585d
...
@@ -26,36 +26,34 @@ logger = MPILogger()
...
@@ -26,36 +26,34 @@ logger = MPILogger()
# it is important to import config before d2o such that NIFTy is able to
# it is important to import config before d2o such that NIFTy is able to
# pre-create d2o's configuration object with the corrected path
# pre-create d2o's configuration object with the corrected path
from
config
import
dependency_injector
,
\
from
.
config
import
dependency_injector
,
\
nifty_configuration
,
\
nifty_configuration
,
\
d2o_configuration
d2o_configuration
from
d2o
import
distributed_data_object
,
d2o_librarian
from
d2o
import
distributed_data_object
,
d2o_librarian
from
.field
import
Field
from
.random
import
Random
from
field
import
Field
from
.basic_arithmetics
import
*
from
random
import
Random
from
.nifty_utilities
import
*
from
basic_arithmetic
s
import
*
from
.field_type
s
import
*
from
nifty_utilit
ies
import
*
from
.energ
ies
import
*
from
field_types
import
*
from
.minimization
import
*
from
energi
es
import
*
from
.spac
es
import
*
from
minimization
import
*
from
.operators
import
*
from
spaces
import
*
from
.probing
import
*
from
operators
import
*
from
.sugar
import
*
from
probing
import
*
from
.
import
plotting
from
sugar
import
*
from
.
import
library
import
library
import
plotting
nifty/basic_arithmetics.py
View file @
1297585d
...
@@ -16,9 +16,10 @@
...
@@ -16,9 +16,10 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
# and financially supported by the Studienstiftung des deutschen Volkes.
from
__future__
import
division
import
numpy
as
np
import
numpy
as
np
from
d2o
import
distributed_data_object
from
d2o
import
distributed_data_object
from
nifty
.field
import
Field
from
.field
import
Field
__all__
=
[
'cos'
,
'sin'
,
'cosh'
,
'sinh'
,
'tan'
,
'tanh'
,
'arccos'
,
'arcsin'
,
__all__
=
[
'cos'
,
'sin'
,
'cosh'
,
'sinh'
,
'tan'
,
'tanh'
,
'arccos'
,
'arcsin'
,
...
...
nifty/config/__init__.py
View file @
1297585d
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
# and financially supported by the Studienstiftung des deutschen Volkes.
from
nifty_config
import
dependency_injector
,
\
from
.
nifty_config
import
dependency_injector
,
\
nifty_configuration
nifty_configuration
from
d2o_config
import
d2o_configuration
from
.
d2o_config
import
d2o_configuration
nifty/config/nifty_config.py
View file @
1297585d
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
# and financially supported by the Studienstiftung des deutschen Volkes.
import
matplotlib_init
from
.
import
matplotlib_init
import
os
import
os
...
...
nifty/domain_object.py
View file @
1297585d
...
@@ -16,14 +16,17 @@
...
@@ -16,14 +16,17 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
# and financially supported by the Studienstiftung des deutschen Volkes.
from
__future__
import
division
import
abc
import
abc
from
nifty
.nifty_meta
import
NiftyMeta
from
.nifty_meta
import
NiftyMeta
from
keepers
import
Loggable
,
\
from
keepers
import
Loggable
,
\
Versionable
Versionable
from
future.utils
import
with_metaclass
class
DomainObject
(
Versionable
,
Loggable
,
object
):
class
DomainObject
(
with_metaclass
(
NiftyMeta
,
type
(
'NewBase'
,
(
Versionable
,
Loggable
,
object
),
{}))):
"""The abstract class that can be used as a domain for a field.
"""The abstract class that can be used as a domain for a field.
This holds all the information and functionality a field needs to know
This holds all the information and functionality a field needs to know
...
@@ -39,8 +42,6 @@ class DomainObject(Versionable, Loggable, object):
...
@@ -39,8 +42,6 @@ class DomainObject(Versionable, Loggable, object):
"""
"""
__metaclass__
=
NiftyMeta
def
__init__
(
self
):
def
__init__
(
self
):
# _global_id is used in the Versioning module from keepers
# _global_id is used in the Versioning module from keepers
self
.
_ignore_for_hash
=
[
'_global_id'
]
self
.
_ignore_for_hash
=
[
'_global_id'
]
...
@@ -56,7 +57,7 @@ class DomainObject(Versionable, Loggable, object):
...
@@ -56,7 +57,7 @@ class DomainObject(Versionable, Loggable, object):
item
=
vars
(
self
)[
key
]
item
=
vars
(
self
)[
key
]
if
key
in
self
.
_ignore_for_hash
or
key
==
'_ignore_for_hash'
:
if
key
in
self
.
_ignore_for_hash
or
key
==
'_ignore_for_hash'
:
continue
continue
result_hash
^=
item
.
__hash__
()
^
int
(
hash
(
key
)
/
117
)
result_hash
^=
item
.
__hash__
()
^
int
(
hash
(
key
)
/
/
117
)
return
result_hash
return
result_hash
def
__eq__
(
self
,
x
):
def
__eq__
(
self
,
x
):
...
@@ -75,7 +76,7 @@ class DomainObject(Versionable, Loggable, object):
...
@@ -75,7 +76,7 @@ class DomainObject(Versionable, Loggable, object):
"""
"""
if
isinstance
(
x
,
type
(
self
)):
if
isinstance
(
x
,
type
(
self
)):
for
key
in
vars
(
self
).
keys
():
for
key
in
list
(
vars
(
self
).
keys
()
)
:
item1
=
vars
(
self
)[
key
]
item1
=
vars
(
self
)[
key
]
if
key
in
self
.
_ignore_for_hash
or
key
==
'_ignore_for_hash'
:
if
key
in
self
.
_ignore_for_hash
or
key
==
'_ignore_for_hash'
:
continue
continue
...
...
nifty/energies/__init__.py
View file @
1297585d
...
@@ -16,6 +16,6 @@
...
@@ -16,6 +16,6 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
# and financially supported by the Studienstiftung des deutschen Volkes.
from
energy
import
Energy
from
.
energy
import
Energy
from
line_energy
import
LineEnergy
from
.
line_energy
import
LineEnergy
from
memoization
import
memo
from
.
memoization
import
memo
nifty/energies/energy.py
View file @
1297585d
...
@@ -19,9 +19,10 @@
...
@@ -19,9 +19,10 @@
from
nifty.nifty_meta
import
NiftyMeta
from
nifty.nifty_meta
import
NiftyMeta
from
keepers
import
Loggable
from
keepers
import
Loggable
from
future.utils
import
with_metaclass
class
Energy
(
Loggable
,
object
):
class
Energy
(
with_metaclass
(
NiftyMeta
,
type
(
'NewBase'
,
(
Loggable
,
object
)
,
{})))
:
""" Provides the functional used by minimization schemes.
""" Provides the functional used by minimization schemes.
The Energy object is an implementation of a scalar function including its
The Energy object is an implementation of a scalar function including its
...
@@ -63,8 +64,6 @@ class Energy(Loggable, object):
...
@@ -63,8 +64,6 @@ class Energy(Loggable, object):
"""
"""
__metaclass__
=
NiftyMeta
def
__init__
(
self
,
position
):
def
__init__
(
self
,
position
):
super
(
Energy
,
self
).
__init__
()
super
(
Energy
,
self
).
__init__
()
self
.
_cache
=
{}
self
.
_cache
=
{}
...
...
nifty/energies/line_energy.py
View file @
1297585d
...
@@ -16,6 +16,8 @@
...
@@ -16,6 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
# and financially supported by the Studienstiftung des deutschen Volkes.
from
__future__
import
print_function
class
LineEnergy
(
object
):
class
LineEnergy
(
object
):
""" Evaluates an underlying Energy along a certain line direction.
""" Evaluates an underlying Energy along a certain line direction.
...
@@ -114,6 +116,6 @@ class LineEnergy(object):
...
@@ -114,6 +116,6 @@ class LineEnergy(object):
def
directional_derivative
(
self
):
def
directional_derivative
(
self
):
res
=
self
.
energy
.
gradient
.
vdot
(
self
.
line_direction
)
res
=
self
.
energy
.
gradient
.
vdot
(
self
.
line_direction
)
if
abs
(
res
.
imag
)
/
max
(
abs
(
res
.
real
),
1.
)
>
1e-12
:
if
abs
(
res
.
imag
)
/
max
(
abs
(
res
.
real
),
1.
)
>
1e-12
:
print
"directional derivative has non-negligible "
\
print
(
"directional derivative has non-negligible "
"imaginary part:"
,
res
"imaginary part:"
,
res
)
return
res
.
real
return
res
.
real
nifty/field.py
View file @
1297585d
...
@@ -17,6 +17,9 @@
...
@@ -17,6 +17,9 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
# and financially supported by the Studienstiftung des deutschen Volkes.
from
__future__
import
division
from
__future__
import
division
from
builtins
import
zip
#from builtins import str
from
builtins
import
range
import
ast
import
ast
import
numpy
as
np
import
numpy
as
np
...
@@ -27,14 +30,15 @@ from keepers import Versionable,\
...
@@ -27,14 +30,15 @@ from keepers import Versionable,\
from
d2o
import
distributed_data_object
,
\
from
d2o
import
distributed_data_object
,
\
STRATEGIES
as
DISTRIBUTION_STRATEGIES
STRATEGIES
as
DISTRIBUTION_STRATEGIES
from
nifty
.config
import
nifty_configuration
as
gc
from
.config
import
nifty_configuration
as
gc
from
nifty
.domain_object
import
DomainObject
from
.domain_object
import
DomainObject
from
nifty
.spaces.power_space
import
PowerSpace
from
.spaces.power_space
import
PowerSpace
import
nifty.nifty_utilities
as
utilities
from
.
import
nifty_utilities
as
utilities
from
nifty.random
import
Random
from
.random
import
Random
from
functools
import
reduce
class
Field
(
Loggable
,
Versionable
,
object
):
class
Field
(
Loggable
,
Versionable
,
object
):
...
@@ -346,7 +350,7 @@ class Field(Loggable, Versionable, object):
...
@@ -346,7 +350,7 @@ class Field(Loggable, Versionable, object):
# check if the `spaces` input is valid
# check if the `spaces` input is valid
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
self
.
domain
))
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
self
.
domain
))
if
spaces
is
None
:
if
spaces
is
None
:
spaces
=
range
(
len
(
self
.
domain
))
spaces
=
list
(
range
(
len
(
self
.
domain
))
)
if
len
(
spaces
)
==
0
:
if
len
(
spaces
)
==
0
:
raise
ValueError
(
raise
ValueError
(
...
@@ -526,7 +530,7 @@ class Field(Loggable, Versionable, object):
...
@@ -526,7 +530,7 @@ class Field(Loggable, Versionable, object):
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
self
.
domain
))
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
self
.
domain
))
if
spaces
is
None
:
if
spaces
is
None
:
spaces
=
range
(
len
(
self
.
domain
))
spaces
=
list
(
range
(
len
(
self
.
domain
))
)
for
power_space_index
in
spaces
:
for
power_space_index
in
spaces
:
power_space
=
self
.
domain
[
power_space_index
]
power_space
=
self
.
domain
[
power_space_index
]
...
@@ -821,7 +825,7 @@ class Field(Loggable, Versionable, object):
...
@@ -821,7 +825,7 @@ class Field(Loggable, Versionable, object):
dim_tuple
=
tuple
(
sp
.
dim
for
sp
in
self
.
domain
)
dim_tuple
=
tuple
(
sp
.
dim
for
sp
in
self
.
domain
)
try
:
try
:
return
reduce
(
lambda
x
,
y
:
x
*
y
,
dim_tuple
)
return
int
(
reduce
(
lambda
x
,
y
:
x
*
y
,
dim_tuple
)
)
except
TypeError
:
except
TypeError
:
return
0
return
0
...
@@ -1010,7 +1014,7 @@ class Field(Loggable, Versionable, object):
...
@@ -1010,7 +1014,7 @@ class Field(Loggable, Versionable, object):
fast_copyable
=
True
fast_copyable
=
True
try
:
try
:
for
i
in
x
range
(
len
(
self
.
domain
)):
for
i
in
range
(
len
(
self
.
domain
)):
if
self
.
domain
[
i
]
is
not
domain
[
i
]:
if
self
.
domain
[
i
]
is
not
domain
[
i
]:
fast_copyable
=
False
fast_copyable
=
False
break
break
...
@@ -1032,7 +1036,7 @@ class Field(Loggable, Versionable, object):
...
@@ -1032,7 +1036,7 @@ class Field(Loggable, Versionable, object):
# repair its class
# repair its class
new_field
.
__class__
=
self
.
__class__
new_field
.
__class__
=
self
.
__class__
# copy domain, codomain and val
# copy domain, codomain and val
for
key
,
value
in
self
.
__dict__
.
items
():
for
key
,
value
in
list
(
self
.
__dict__
.
items
()
)
:
if
key
!=
'_val'
:
if
key
!=
'_val'
:
new_field
.
__dict__
[
key
]
=
value
new_field
.
__dict__
[
key
]
=
value
else
:
else
:
...
@@ -1069,7 +1073,7 @@ class Field(Loggable, Versionable, object):
...
@@ -1069,7 +1073,7 @@ class Field(Loggable, Versionable, object):