Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
5b713064
Commit
5b713064
authored
Sep 24, 2018
by
Philipp Frank
Browse files
merged master
parents
c1a84f56
ad593f04
Changes
17
Hide whitespace changes
Inline
Side-by-side
nifty5/__init__.py
View file @
5b713064
...
...
@@ -43,8 +43,9 @@ from .operators.slope_operator import SlopeOperator
from
.operators.smoothness_operator
import
SmoothnessOperator
from
.operators.symmetrizing_operator
import
SymmetrizingOperator
from
.operators.block_diagonal_operator
import
BlockDiagonalOperator
from
.operators.outer_product_operator
import
OuterProduct
from
.operators.simple_linear_operators
import
(
VdotOperator
,
SumReductionOperator
,
ConjugationOperator
,
Realizer
,
VdotOperator
,
ConjugationOperator
,
Realizer
,
FieldAdapter
,
GeometryRemover
,
NullOperator
)
from
.operators.energy_operators
import
(
EnergyOperator
,
GaussianEnergy
,
PoissonianEnergy
,
InverseGammaLikelihood
,
...
...
nifty5/fft.py
0 → 100644
View file @
5b713064
from
__future__
import
absolute_import
,
division
,
print_function
from
.utilities
import
iscomplextype
import
numpy
as
np
_use_fftw
=
True
if
_use_fftw
:
import
pyfftw
from
pyfftw.interfaces.numpy_fft
import
fftn
,
rfftn
,
ifftn
pyfftw
.
interfaces
.
cache
.
enable
()
pyfftw
.
interfaces
.
cache
.
set_keepalive_time
(
1000.
)
# Optional extra arguments for the FFT calls
# if exact reproducibility is needed,
# set "planner_effort" to "FFTW_ESTIMATE"
import
os
nthreads
=
int
(
os
.
getenv
(
"OMP_NUM_THREADS"
,
"1"
))
_fft_extra_args
=
dict
(
planner_effort
=
'FFTW_ESTIMATE'
,
threads
=
nthreads
)
else
:
from
numpy.fft
import
fftn
,
rfftn
,
ifftn
_fft_extra_args
=
{}
def
hartley
(
a
,
axes
=
None
):
# Check if the axes provided are valid given the shape
if
axes
is
not
None
and
\
not
all
(
axis
<
len
(
a
.
shape
)
for
axis
in
axes
):
raise
ValueError
(
"Provided axes do not match array shape"
)
if
iscomplextype
(
a
.
dtype
):
raise
TypeError
(
"Hartley transform requires real-valued arrays."
)
tmp
=
rfftn
(
a
,
axes
=
axes
,
**
_fft_extra_args
)
def
_fill_array
(
tmp
,
res
,
axes
):
if
axes
is
None
:
axes
=
tuple
(
range
(
tmp
.
ndim
))
lastaxis
=
axes
[
-
1
]
ntmplast
=
tmp
.
shape
[
lastaxis
]
slice1
=
(
slice
(
None
),)
*
lastaxis
+
(
slice
(
0
,
ntmplast
),)
np
.
add
(
tmp
.
real
,
tmp
.
imag
,
out
=
res
[
slice1
])
def
_fill_upper_half
(
tmp
,
res
,
axes
):
lastaxis
=
axes
[
-
1
]
nlast
=
res
.
shape
[
lastaxis
]
ntmplast
=
tmp
.
shape
[
lastaxis
]
nrem
=
nlast
-
ntmplast
slice1
=
[
slice
(
None
)]
*
lastaxis
+
[
slice
(
ntmplast
,
None
)]
slice2
=
[
slice
(
None
)]
*
lastaxis
+
[
slice
(
nrem
,
0
,
-
1
)]
for
i
in
axes
[:
-
1
]:
slice1
[
i
]
=
slice
(
1
,
None
)
slice2
[
i
]
=
slice
(
None
,
0
,
-
1
)
slice1
=
tuple
(
slice1
)
slice2
=
tuple
(
slice2
)
np
.
subtract
(
tmp
[
slice2
].
real
,
tmp
[
slice2
].
imag
,
out
=
res
[
slice1
])
for
i
,
ax
in
enumerate
(
axes
[:
-
1
]):
dim1
=
(
slice
(
None
),)
*
ax
+
(
slice
(
0
,
1
),)
axes2
=
axes
[:
i
]
+
axes
[
i
+
1
:]
_fill_upper_half
(
tmp
[
dim1
],
res
[
dim1
],
axes2
)
_fill_upper_half
(
tmp
,
res
,
axes
)
return
res
return
_fill_array
(
tmp
,
np
.
empty_like
(
a
),
axes
)
# Do a real-to-complex forward FFT and return the _full_ output array
def
my_fftn_r2c
(
a
,
axes
=
None
):
# Check if the axes provided are valid given the shape
if
axes
is
not
None
and
\
not
all
(
axis
<
len
(
a
.
shape
)
for
axis
in
axes
):
raise
ValueError
(
"Provided axes do not match array shape"
)
if
iscomplextype
(
a
.
dtype
):
raise
TypeError
(
"Transform requires real-valued input arrays."
)
tmp
=
rfftn
(
a
,
axes
=
axes
,
**
_fft_extra_args
)
def
_fill_complex_array
(
tmp
,
res
,
axes
):
if
axes
is
None
:
axes
=
tuple
(
range
(
tmp
.
ndim
))
lastaxis
=
axes
[
-
1
]
ntmplast
=
tmp
.
shape
[
lastaxis
]
slice1
=
[
slice
(
None
)]
*
lastaxis
+
[
slice
(
0
,
ntmplast
)]
res
[
tuple
(
slice1
)]
=
tmp
def
_fill_upper_half_complex
(
tmp
,
res
,
axes
):
lastaxis
=
axes
[
-
1
]
nlast
=
res
.
shape
[
lastaxis
]
ntmplast
=
tmp
.
shape
[
lastaxis
]
nrem
=
nlast
-
ntmplast
slice1
=
[
slice
(
None
)]
*
lastaxis
+
[
slice
(
ntmplast
,
None
)]
slice2
=
[
slice
(
None
)]
*
lastaxis
+
[
slice
(
nrem
,
0
,
-
1
)]
for
i
in
axes
[:
-
1
]:
slice1
[
i
]
=
slice
(
1
,
None
)
slice2
[
i
]
=
slice
(
None
,
0
,
-
1
)
# np.conjugate(tmp[slice2], out=res[slice1])
res
[
tuple
(
slice1
)]
=
np
.
conjugate
(
tmp
[
tuple
(
slice2
)])
for
i
,
ax
in
enumerate
(
axes
[:
-
1
]):
dim1
=
tuple
([
slice
(
None
)]
*
ax
+
[
slice
(
0
,
1
)])
axes2
=
axes
[:
i
]
+
axes
[
i
+
1
:]
_fill_upper_half_complex
(
tmp
[
dim1
],
res
[
dim1
],
axes2
)
_fill_upper_half_complex
(
tmp
,
res
,
axes
)
return
res
return
_fill_complex_array
(
tmp
,
np
.
empty_like
(
a
,
dtype
=
tmp
.
dtype
),
axes
)
def
my_fftn
(
a
,
axes
=
None
):
return
fftn
(
a
,
axes
=
axes
,
**
_fft_extra_args
)
nifty5/field.py
View file @
5b713064
...
...
@@ -327,6 +327,23 @@ class Field(object):
return
Field
.
from_local_data
(
self
.
_domain
,
aout
)
def
outer
(
self
,
x
):
""" Computes the outer product of 'self' with x.
Parameters
----------
x : Field
Returns
----------
Field, lives on the product space of self.domain and x.domain
"""
if
not
isinstance
(
x
,
Field
):
raise
TypeError
(
"The multiplier must be an instance of "
+
"the NIFTy field class"
)
from
.operators.outer_product_operator
import
OuterProduct
return
OuterProduct
(
self
,
x
.
domain
)(
x
)
def
vdot
(
self
,
x
=
None
,
spaces
=
None
):
""" Computes the dot product of 'self' with x.
...
...
nifty5/library/correlated_fields.py
View file @
5b713064
...
...
@@ -28,15 +28,18 @@ from ..operators.simple_linear_operators import FieldAdapter
from
..operators.scaling_operator
import
ScalingOperator
def
CorrelatedField
(
s_space
,
amplitude_model
):
def
CorrelatedField
(
s_space
,
amplitude_model
,
name
=
'xi'
):
'''
Function for construction of correlated fields
Parameters
----------
s_space : Field domain
amplitude_model : model for correlation structure
s_space : Domain
Field domain
amplitude_model: Operator
model for correlation structure
name : string
MultiField component name
'''
h_space
=
s_space
.
get_default_codomain
()
ht
=
HarmonicTransformOperator
(
h_space
,
s_space
)
...
...
@@ -45,11 +48,11 @@ def CorrelatedField(s_space, amplitude_model):
A
=
power_distributor
(
amplitude_model
)
vol
=
h_space
.
scalar_dvol
vol
=
ScalingOperator
(
vol
**
(
-
0.5
),
h_space
)
return
ht
(
vol
(
A
)
*
FieldAdapter
(
MultiDomain
.
make
({
"xi"
:
h_space
}),
"xi"
))
return
ht
(
vol
(
A
)
*
FieldAdapter
(
MultiDomain
.
make
({
name
:
h_space
}),
name
))
def
MfCorrelatedField
(
s_space_spatial
,
s_space_energy
,
amplitude_model_spatial
,
amplitude_model_energy
):
amplitude_model_energy
,
name
=
"xi"
):
'''
Method for construction of correlated multi-frequency fields
'''
...
...
@@ -67,11 +70,11 @@ def MfCorrelatedField(s_space_spatial, s_space_energy, amplitude_model_spatial,
pd_energy
=
PowerDistributor
(
pd_spatial
.
domain
,
p_space_energy
,
1
)
pd
=
pd_spatial
(
pd_energy
)
dom_distr_spatial
=
ContractionOperator
(
pd
.
domain
,
0
).
adjoint
dom_distr_energy
=
ContractionOperator
(
pd
.
domain
,
1
).
adjoint
dom_distr_spatial
=
ContractionOperator
(
pd
.
domain
,
1
).
adjoint
dom_distr_energy
=
ContractionOperator
(
pd
.
domain
,
0
).
adjoint
a_spatial
=
dom_distr_spatial
(
amplitude_model_spatial
)
a_energy
=
dom_distr_energy
(
amplitude_model_energy
)
a
=
a_spatial
*
a_energy
A
=
pd
(
a
)
return
ht
(
A
*
FieldAdapter
(
MultiDomain
.
make
({
"xi"
:
h_space
}),
"xi"
))
return
ht
(
A
*
FieldAdapter
(
MultiDomain
.
make
({
name
:
h_space
}),
name
))
nifty5/linearization.py
View file @
5b713064
...
...
@@ -126,6 +126,19 @@ class Linearization(object):
def
__rmul__
(
self
,
other
):
return
self
.
__mul__
(
other
)
def
outer
(
self
,
other
):
from
.operators.outer_product_operator
import
OuterProduct
if
isinstance
(
other
,
Linearization
):
return
self
.
new
(
OuterProduct
(
self
.
_val
,
other
.
target
)(
other
.
_val
),
OuterProduct
(
self
.
_jac
(
self
.
_val
),
other
.
target
).
_myadd
(
OuterProduct
(
self
.
_val
,
other
.
target
)(
other
.
_jac
),
False
))
if
np
.
isscalar
(
other
):
return
self
.
__mul__
(
other
)
if
isinstance
(
other
,
(
Field
,
MultiField
)):
return
self
.
new
(
OuterProduct
(
self
.
_val
,
other
.
domain
)(
other
),
OuterProduct
(
self
.
_jac
(
self
.
_val
),
other
.
domain
))
def
vdot
(
self
,
other
):
from
.operators.simple_linear_operators
import
VdotOperator
if
isinstance
(
other
,
(
Field
,
MultiField
)):
...
...
@@ -137,11 +150,27 @@ class Linearization(object):
VdotOperator
(
self
.
_val
)(
other
.
_jac
)
+
VdotOperator
(
other
.
_val
)(
self
.
_jac
))
def
sum
(
self
):
from
.operators.simple_linear_operators
import
SumReductionOperator
return
self
.
new
(
Field
.
scalar
(
self
.
_val
.
sum
()),
SumReductionOperator
(
self
.
_jac
.
target
)(
self
.
_jac
))
def
sum
(
self
,
spaces
=
None
):
from
.operators.contraction_operator
import
ContractionOperator
if
spaces
is
None
:
return
self
.
new
(
Field
.
scalar
(
self
.
_val
.
sum
()),
ContractionOperator
(
self
.
_jac
.
target
,
None
)(
self
.
_jac
))
else
:
return
self
.
new
(
self
.
_val
.
sum
(
spaces
),
ContractionOperator
(
self
.
_jac
.
target
,
spaces
)(
self
.
_jac
))
def
integrate
(
self
,
spaces
=
None
):
from
.operators.contraction_operator
import
ContractionOperator
if
spaces
is
None
:
return
self
.
new
(
Field
.
scalar
(
self
.
_val
.
integrate
()),
ContractionOperator
(
self
.
_jac
.
target
,
None
,
1
)(
self
.
_jac
))
else
:
return
self
.
new
(
self
.
_val
.
integrate
(
spaces
),
ContractionOperator
(
self
.
_jac
.
target
,
spaces
,
1
)(
self
.
_jac
))
def
exp
(
self
):
tmp
=
self
.
_val
.
exp
()
...
...
@@ -178,6 +207,14 @@ class Linearization(object):
return
Linearization
(
field
,
NullOperator
(
field
.
domain
,
field
.
domain
),
want_metric
=
want_metric
)
@
staticmethod
def
make_const_empty_input
(
field
,
want_metric
=
False
):
from
.operators.simple_linear_operators
import
NullOperator
from
.multi_domain
import
MultiDomain
return
Linearization
(
field
,
NullOperator
(
MultiDomain
.
make
({}),
field
.
domain
),
want_metric
=
want_metric
)
@
staticmethod
def
make_partial_var
(
field
,
constants
,
want_metric
=
False
):
from
.operators.scaling_operator
import
ScalingOperator
...
...
nifty5/operators/contraction_operator.py
View file @
5b713064
...
...
@@ -37,28 +37,37 @@ class ContractionOperator(LinearOperator):
----------
domain : Domain, tuple of Domain or DomainTuple
spaces : int or tuple of int
The elements of "domain" which are taken as target.
The elements of "domain" which are contracted.
weight : int, default=0
if nonzero, the fields living on self.domain are weighted with the
specified power.
"""
def
__init__
(
self
,
domain
,
spaces
):
def
__init__
(
self
,
domain
,
spaces
,
weight
=
0
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_spaces
=
utilities
.
parse_spaces
(
spaces
,
len
(
self
.
_domain
))
self
.
_target
=
[
dom
for
i
,
dom
in
enumerate
(
self
.
_domain
)
if
i
in
self
.
_spaces
dom
for
i
,
dom
in
enumerate
(
self
.
_domain
)
if
i
not
in
self
.
_spaces
]
self
.
_target
=
DomainTuple
.
make
(
self
.
_target
)
self
.
_weight
=
weight
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
ADJOINT_TIMES
:
ldat
=
x
.
lo
c
al_data
if
0
in
self
.
_spaces
else
x
.
to_g
lo
b
al_data
()
ldat
=
x
.
to_g
lo
b
al_data
()
if
0
in
self
.
_spaces
else
x
.
lo
c
al_data
shp
=
[]
for
i
,
dom
in
enumerate
(
self
.
_domain
):
tmp
=
dom
.
shape
if
i
>
0
else
dom
.
local_shape
shp
+=
tmp
if
i
in
self
.
_spaces
else
(
1
,)
*
len
(
dom
.
shape
)
shp
+=
tmp
if
i
not
in
self
.
_spaces
else
(
1
,)
*
len
(
dom
.
shape
)
ldat
=
np
.
broadcast_to
(
ldat
.
reshape
(
shp
),
self
.
_domain
.
local_shape
)
return
Field
.
from_local_data
(
self
.
_domain
,
ldat
)
res
=
Field
.
from_local_data
(
self
.
_domain
,
ldat
)
if
self
.
_weight
!=
0
:
res
=
res
.
weight
(
self
.
_weight
,
spaces
=
self
.
_spaces
)
return
res
else
:
return
x
.
sum
(
[
s
for
s
in
range
(
len
(
x
.
domain
))
if
s
not
in
self
.
_spaces
])
if
self
.
_weight
!=
0
:
x
=
x
.
weight
(
self
.
_weight
,
spaces
=
self
.
_spaces
)
res
=
x
.
sum
(
self
.
_spaces
)
return
res
if
isinstance
(
res
,
Field
)
else
Field
.
scalar
(
res
)
nifty5/operators/harmonic_operators.py
View file @
5b713064
...
...
@@ -20,7 +20,7 @@ from __future__ import absolute_import, division, print_function
import
numpy
as
np
from
..
import
dobj
,
utilities
from
..
import
dobj
,
utilities
,
fft
from
..compat
import
*
from
..domain_tuple
import
DomainTuple
from
..domains.gl_space
import
GLSpace
...
...
@@ -74,8 +74,6 @@ class FFTOperator(LinearOperator):
adom
.
check_codomain
(
target
)
target
.
check_codomain
(
adom
)
utilities
.
fft_prep
()
def
apply
(
self
,
x
,
mode
):
from
pyfftw.interfaces.numpy_fft
import
fftn
,
ifftn
self
.
_check_input
(
x
,
mode
)
...
...
@@ -174,8 +172,6 @@ class HartleyOperator(LinearOperator):
adom
.
check_codomain
(
target
)
target
.
check_codomain
(
adom
)
utilities
.
fft_prep
()
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
utilities
.
iscomplextype
(
x
.
dtype
):
...
...
@@ -190,14 +186,14 @@ class HartleyOperator(LinearOperator):
oldax
=
dobj
.
distaxis
(
x
.
val
)
if
oldax
not
in
axes
:
# straightforward, no redistribution needed
ldat
=
x
.
local_data
ldat
=
utilities
.
hartley
(
ldat
,
axes
=
axes
)
ldat
=
fft
.
hartley
(
ldat
,
axes
=
axes
)
tmp
=
dobj
.
from_local_data
(
x
.
val
.
shape
,
ldat
,
distaxis
=
oldax
)
elif
len
(
axes
)
<
len
(
x
.
shape
)
or
len
(
axes
)
==
1
:
# we can use one Hartley pass in between the redistributions
tmp
=
dobj
.
redistribute
(
x
.
val
,
nodist
=
axes
)
newax
=
dobj
.
distaxis
(
tmp
)
ldat
=
dobj
.
local_data
(
tmp
)
ldat
=
utilities
.
hartley
(
ldat
,
axes
=
axes
)
ldat
=
fft
.
hartley
(
ldat
,
axes
=
axes
)
tmp
=
dobj
.
from_local_data
(
tmp
.
shape
,
ldat
,
distaxis
=
newax
)
tmp
=
dobj
.
redistribute
(
tmp
,
dist
=
oldax
)
else
:
# two separate, full FFTs needed
...
...
@@ -211,7 +207,7 @@ class HartleyOperator(LinearOperator):
rem_axes
=
tuple
(
i
for
i
in
axes
if
i
!=
oldax
)
tmp
=
x
.
val
ldat
=
dobj
.
local_data
(
tmp
)
ldat
=
utilities
.
my_fftn_r2c
(
ldat
,
axes
=
rem_axes
)
ldat
=
fft
.
my_fftn_r2c
(
ldat
,
axes
=
rem_axes
)
if
oldax
!=
0
:
raise
ValueError
(
"bad distribution"
)
ldat2
=
ldat
.
reshape
((
ldat
.
shape
[
0
],
...
...
@@ -220,7 +216,7 @@ class HartleyOperator(LinearOperator):
tmp
=
dobj
.
from_local_data
(
shp2d
,
ldat2
,
distaxis
=
0
)
tmp
=
dobj
.
transpose
(
tmp
)
ldat2
=
dobj
.
local_data
(
tmp
)
ldat2
=
utilities
.
my_fftn
(
ldat2
,
axes
=
(
1
,))
ldat2
=
fft
.
my_fftn
(
ldat2
,
axes
=
(
1
,))
ldat2
=
ldat2
.
real
+
ldat2
.
imag
tmp
=
dobj
.
from_local_data
(
tmp
.
shape
,
ldat2
,
distaxis
=
0
)
tmp
=
dobj
.
transpose
(
tmp
)
...
...
@@ -289,6 +285,10 @@ class SHTOperator(LinearOperator):
else
:
self
.
sjob
.
set_Healpix_geometry
(
target
.
nside
)
def
__reduce__
(
self
):
return
(
_unpickleSHTOperator
,
(
self
.
_domain
,
self
.
_target
[
self
.
_space
],
self
.
_space
))
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
utilities
.
iscomplextype
(
x
.
dtype
):
...
...
@@ -337,6 +337,10 @@ class SHTOperator(LinearOperator):
return
Field
(
tdom
,
dobj
.
ensure_default_distributed
(
odat
))
def
_unpickleSHTOperator
(
*
args
):
return
SHTOperator
(
*
args
)
class
HarmonicTransformOperator
(
LinearOperator
):
"""Transforms between a harmonic domain and a position domain counterpart.
...
...
nifty5/operators/operator_adapter.py
View file @
5b713064
...
...
@@ -69,6 +69,6 @@ class OperatorAdapter(LinearOperator):
def
__repr__
(
self
):
from
..utilities
import
indent
mode
=
[
"adjoint"
,
"inverse"
,
"adjoint inverse"
][
self
.
_trafo
]
mode
=
[
"adjoint"
,
"inverse"
,
"adjoint inverse"
][
self
.
_trafo
-
1
]
res
=
"OperatorAdapter: {}
\n
"
.
format
(
mode
)
return
res
+
indent
(
self
.
_op
.
__repr__
())
nifty5/operators/outer_product_operator.py
0 → 100644
View file @
5b713064
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
__future__
import
absolute_import
,
division
,
print_function
import
itertools
import
numpy
as
np
from
..
import
dobj
,
utilities
from
..compat
import
*
from
..domain_tuple
import
DomainTuple
from
..domains.rg_space
import
RGSpace
from
..multi_field
import
MultiField
,
MultiDomain
from
..field
import
Field
from
.linear_operator
import
LinearOperator
import
operator
class
OuterProduct
(
LinearOperator
):
"""Performs the pointwise outer product of two fields.
Parameters
---------
field: Field,
domain: DomainTuple, the domain of the input field
---------
"""
def
__init__
(
self
,
field
,
domain
):
self
.
_domain
=
domain
self
.
_field
=
field
self
.
_target
=
DomainTuple
.
make
(
tuple
(
sub_d
for
sub_d
in
field
.
domain
.
_dom
+
domain
.
_dom
))
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
return
Field
.
from_global_data
(
self
.
_target
,
np
.
multiply
.
outer
(
self
.
_field
.
to_global_data
(),
x
.
to_global_data
()))
axes
=
len
(
self
.
_field
.
shape
)
return
Field
.
from_global_data
(
self
.
_domain
,
np
.
tensordot
(
self
.
_field
.
to_global_data
(),
x
.
to_global_data
(),
axes
))
nifty5/operators/qht_operator.py
View file @
5b713064
...
...
@@ -20,9 +20,10 @@ from __future__ import absolute_import, division, print_function
from
..
import
dobj
from
..compat
import
*
from
..
import
fft
from
..domain_tuple
import
DomainTuple
from
..field
import
Field
from
..utilities
import
hartley
,
infer_space
from
..utilities
import
infer_space
from
.linear_operator
import
LinearOperator
...
...
@@ -69,5 +70,5 @@ class QHTOperator(LinearOperator):
for
i
in
rng
:
sl
=
(
slice
(
None
),)
*
i
+
(
slice
(
1
,
None
),)
v
,
tmp
=
dobj
.
ensure_not_distributed
(
v
,
(
i
,))
tmp
[
sl
]
=
hartley
(
tmp
[
sl
],
axes
=
(
i
,))
tmp
[
sl
]
=
fft
.
hartley
(
tmp
[
sl
],
axes
=
(
i
,))
return
Field
(
self
.
_tgt
(
mode
),
dobj
.
ensure_default_distributed
(
v
))
nifty5/operators/simple_linear_operators.py
View file @
5b713064
...
...
@@ -43,19 +43,6 @@ class VdotOperator(LinearOperator):
return
self
.
_field
*
x
.
local_data
[()]
class
SumReductionOperator
(
LinearOperator
):
def
__init__
(
self
,
domain
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_target
=
DomainTuple
.
scalar_domain
()
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
return
Field
.
scalar
(
x
.
sum
())
return
full
(
self
.
_domain
,
x
.
local_data
[()])
class
ConjugationOperator
(
EndomorphicOperator
):
def
__init__
(
self
,
domain
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
...
...
nifty5/operators/symmetrizing_operator.py
View file @
5b713064
...
...
@@ -42,4 +42,5 @@ class SymmetrizingOperator(EndomorphicOperator):
lead
=
(
slice
(
None
),)
*
i
v
,
loc
=
dobj
.
ensure_not_distributed
(
v
,
(
i
,))
loc
[
lead
+
(
slice
(
1
,
None
),)]
-=
loc
[
lead
+
(
slice
(
None
,
0
,
-
1
),)]
loc
/=
2
return
Field
(
self
.
target
,
dobj
.
ensure_default_distributed
(
v
))
nifty5/plot.py
View file @
5b713064
...
...
@@ -75,7 +75,7 @@ def _makeplot(name):
plt
.
close
()
return
extension
=
os
.
path
.
splitext
(
name
)[
1
]
if
extension
in
(
".pdf"
,
".png"
):
if
extension
in
(
".pdf"
,
".png"
,
".svg"
):
plt
.
savefig
(
name
)
plt
.
close
()
else
:
...
...
nifty5/utilities.py
View file @
5b713064
...
...
@@ -22,15 +22,13 @@ import collections
from
itertools
import
product
import
numpy
as
np
import
pyfftw
from
future.utils
import
with_metaclass
from
pyfftw.interfaces.numpy_fft
import
fftn
,
rfftn
from
.compat
import
*
__all__
=
[
"get_slice_list"
,
"safe_cast"
,
"parse_spaces"
,
"infer_space"
,
"memo"
,
"NiftyMetaBase"
,
"
fft_prep"
,
"hartley"
,
"my_fftn_r2c
"
,
"my_fftn"
,
"my_sum"
,
"my_lincomb_simple"
,
"my_lincomb"
,
"indent"
,
"memo"
,
"NiftyMetaBase"
,
"
my_sum"
,
"my_lincomb_simple
"
,
"my_lincomb"
,
"indent"
,
"my_product"
,
"frozendict"
,
"special_add_at"
,
"iscomplextype"
]
...
...
@@ -187,117 +185,6 @@ def NiftyMetaBase():
return
with_metaclass
(
NiftyMeta
,
type
(
'NewBase'
,
(
object
,),
{}))
def
nthreads
():
if
nthreads
.
_val
is
None
:
import
os
nthreads
.
_val
=
int
(
os
.
getenv
(
"OMP_NUM_THREADS"
,
"1"
))
return
nthreads
.
_val
nthreads
.
_val
=
None
# Optional extra arguments for the FFT calls
# _fft_extra_args = {}
# if exact reproducibility is needed, use this:
_fft_extra_args
=
dict
(
planner_effort
=
'FFTW_ESTIMATE'
)
def
fft_prep
():
if
not
fft_prep
.
_initialized
:
pyfftw
.
interfaces
.
cache
.
enable
()
pyfftw
.
interfaces
.
cache
.
set_keepalive_time
(
1000.
)
fft_prep
.
_initialized
=
True
fft_prep
.
_initialized
=
False
def
hartley
(
a
,
axes
=
None
):
# Check if the axes provided are valid given the shape
if
axes
is
not
None
and
\
not
all
(
axis
<
len
(
a
.
shape
)
for
axis
in
axes
):
raise
ValueError
(
"Provided axes do not match array shape"
)
if
iscomplextype
(
a
.
dtype
):
raise
TypeError
(
"Hartley transform requires real-valued arrays."
)
tmp
=
rfftn
(
a
,
axes
=
axes
,
threads
=
nthreads
(),
**
_fft_extra_args
)
def
_fill_array
(
tmp
,
res
,
axes
):
if
axes
is
None
:
axes
=
tuple
(
range
(
tmp
.
ndim
))
lastaxis
=
axes
[
-
1
]
ntmplast
=
tmp
.
shape
[
lastaxis
]
slice1
=
(
slice
(
None
),)
*
lastaxis
+
(
slice
(
0
,
ntmplast
),)
np
.
add
(
tmp
.
real
,
tmp
.
imag
,
out
=
res
[
slice1
])
def
_fill_upper_half
(
tmp
,
res
,
axes
):
lastaxis
=
axes
[
-
1
]
nlast
=
res
.
shape
[
lastaxis
]
ntmplast
=
tmp
.
shape
[
lastaxis
]
nrem
=
nlast
-
ntmplast
slice1
=
[
slice
(
None
)]
*
lastaxis
+
[
slice
(
ntmplast
,
None
)]