Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Neel Shah
NIFTy
Commits
0e8e4be1
Commit
0e8e4be1
authored
May 29, 2018
by
Martin Reinecke
Browse files
Merge branch 'NIFTy_4' into yango_minimizer
parents
cf24f350
0df979f2
Changes
42
Hide whitespace changes
Inline
Side-by-side
demos/critical_filtering.py
View file @
0e8e4be1
...
...
@@ -69,8 +69,8 @@ if __name__ == "__main__":
# Creating the mock data
d
=
noiseless_data
+
n
m0
=
ift
.
Field
.
full
(
h_space
,
1e-7
)
t0
=
ift
.
Field
.
full
(
p_space
,
-
4.
)
m0
=
ift
.
full
(
h_space
,
1e-7
)
t0
=
ift
.
full
(
p_space
,
-
4.
)
power0
=
Distributor
.
times
(
ift
.
exp
(
0.5
*
t0
))
plotdict
=
{
"colormap"
:
"Planck-like"
}
...
...
demos/krylov_sampling.py
View file @
0e8e4be1
...
...
@@ -31,7 +31,7 @@ d = R(s_x) + n
R_p
=
R
*
FFT
*
A
j
=
R_p
.
adjoint
(
N
.
inverse
(
d
))
D_inv
=
ift
.
SandwichOperator
(
R_p
,
N
.
inverse
)
+
S
.
inverse
D_inv
=
ift
.
SandwichOperator
.
make
(
R_p
,
N
.
inverse
)
+
S
.
inverse
N_samps
=
200
...
...
@@ -67,8 +67,8 @@ plt.legend()
plt
.
savefig
(
'Krylov_samples_residuals.png'
)
plt
.
close
()
D_hat_old
=
ift
.
Field
.
zeros
(
x_space
).
to_global_data
()
D_hat_new
=
ift
.
Field
.
zeros
(
x_space
).
to_global_data
()
D_hat_old
=
ift
.
full
(
x_space
,
0.
).
to_global_data
()
D_hat_new
=
ift
.
full
(
x_space
,
0.
).
to_global_data
()
for
i
in
range
(
N_samps
):
D_hat_old
+=
sky
(
samps_old
[
i
]).
to_global_data
()
**
2
D_hat_new
+=
sky
(
samps
[
i
]).
to_global_data
()
**
2
...
...
demos/nonlinear_critical_filter.py
View file @
0e8e4be1
...
...
@@ -69,8 +69,8 @@ if __name__ == "__main__":
# Creating the mock data
d
=
noiseless_data
+
n
m0
=
ift
.
Field
.
full
(
h_space
,
1e-7
)
t0
=
ift
.
Field
.
full
(
p_space
,
-
4.
)
m0
=
ift
.
full
(
h_space
,
1e-7
)
t0
=
ift
.
full
(
p_space
,
-
4.
)
power0
=
Distributor
.
times
(
ift
.
exp
(
0.5
*
t0
))
IC1
=
ift
.
GradientNormController
(
name
=
"IC1"
,
iteration_limit
=
100
,
...
...
demos/nonlinear_wiener_filter.py
View file @
0e8e4be1
...
...
@@ -36,7 +36,7 @@ if __name__ == "__main__":
d_space
=
R
.
target
p_op
=
ift
.
create_power_operator
(
h_space
,
p_spec
)
power
=
ift
.
sqrt
(
p_op
(
ift
.
Field
.
full
(
h_space
,
1.
)))
power
=
ift
.
sqrt
(
p_op
(
ift
.
full
(
h_space
,
1.
)))
# Creating the mock data
true_sky
=
nonlinearity
(
HT
(
power
*
sh
))
...
...
@@ -57,7 +57,7 @@ if __name__ == "__main__":
inverter
=
ift
.
ConjugateGradient
(
controller
=
ICI
)
# initial guess
m
=
ift
.
Field
.
full
(
h_space
,
1e-7
)
m
=
ift
.
full
(
h_space
,
1e-7
)
map_energy
=
ift
.
library
.
NonlinearWienerFilterEnergy
(
m
,
d
,
R
,
nonlinearity
,
HT
,
power
,
N
,
S
,
inverter
=
inverter
)
...
...
demos/poisson_demo.py
View file @
0e8e4be1
...
...
@@ -80,12 +80,12 @@ if __name__ == "__main__":
IC
=
ift
.
GradientNormController
(
name
=
"inverter"
,
iteration_limit
=
500
,
tol_abs_gradnorm
=
1e-3
)
inverter
=
ift
.
ConjugateGradient
(
controller
=
IC
)
D
=
(
ift
.
SandwichOperator
(
R
,
N
.
inverse
)
+
Phi_h
.
inverse
).
inverse
D
=
(
ift
.
SandwichOperator
.
make
(
R
,
N
.
inverse
)
+
Phi_h
.
inverse
).
inverse
D
=
ift
.
InversionEnabler
(
D
,
inverter
,
approximation
=
Phi_h
)
m
=
HT
(
D
(
j
))
# Uncertainty
D
=
ift
.
SandwichOperator
(
aHT
,
D
)
# real space propagator
D
=
ift
.
SandwichOperator
.
make
(
aHT
,
D
)
# real space propagator
Dhat
=
ift
.
probe_with_posterior_samples
(
D
.
inverse
,
None
,
nprobes
=
nprobes
)[
1
]
sig
=
ift
.
sqrt
(
Dhat
)
...
...
@@ -113,7 +113,7 @@ if __name__ == "__main__":
d_domain
,
np
.
random
.
poisson
(
lam
.
local_data
).
astype
(
np
.
float64
))
# initial guess
psi0
=
ift
.
Field
.
full
(
h_domain
,
1e-7
)
psi0
=
ift
.
full
(
h_domain
,
1e-7
)
energy
=
ift
.
library
.
PoissonEnergy
(
psi0
,
data
,
R0
,
nonlin
,
HT
,
Phi_h
,
inverter
)
IC1
=
ift
.
GradientNormController
(
name
=
"IC1"
,
iteration_limit
=
200
,
...
...
demos/wiener_filter_data_space_noiseless.py
0 → 100644
View file @
0e8e4be1
import
numpy
as
np
import
nifty4
as
ift
# TODO: MAKE RESPONSE MPI COMPATIBLE OR USE LOS RESPONSE INSTEAD
class
CustomResponse
(
ift
.
LinearOperator
):
"""
A custom operator that measures a specific points`
An operator that is a delta measurement at certain points
"""
def
__init__
(
self
,
domain
,
data_points
):
self
.
_domain
=
ift
.
DomainTuple
.
make
(
domain
)
self
.
_points
=
data_points
data_shape
=
ift
.
Field
.
full
(
domain
,
0.
).
to_global_data
()[
data_points
]
\
.
shape
self
.
_target
=
ift
.
DomainTuple
.
make
(
ift
.
UnstructuredDomain
(
data_shape
))
def
_times
(
self
,
x
):
d
=
np
.
zeros
(
self
.
_target
.
shape
,
dtype
=
np
.
float64
)
d
+=
x
.
to_global_data
()[
self
.
_points
]
return
ift
.
from_global_data
(
self
.
_target
,
d
)
def
_adjoint_times
(
self
,
d
):
x
=
np
.
zeros
(
self
.
_domain
.
shape
,
dtype
=
np
.
float64
)
x
[
self
.
_points
]
+=
d
.
to_global_data
()
return
ift
.
from_global_data
(
self
.
_domain
,
x
)
@
property
def
domain
(
self
):
return
self
.
_domain
@
property
def
target
(
self
):
return
self
.
_target
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
return
self
.
_times
(
x
)
if
mode
==
self
.
TIMES
else
self
.
_adjoint_times
(
x
)
@
property
def
capability
(
self
):
return
self
.
TIMES
|
self
.
ADJOINT_TIMES
if
__name__
==
"__main__"
:
np
.
random
.
seed
(
43
)
# Set up physical constants
# Total length of interval or volume the field lives on, e.g. in meters
L
=
2.
# Typical distance over which the field is correlated (in same unit as L)
correlation_length
=
0.3
# Variance of field in position space sqrt(<|s_x|^2>) (in same unit as s)
field_variance
=
2.
# Smoothing length of response (in same unit as L)
response_sigma
=
0.01
# typical noise amplitude of the measurement
noise_level
=
0.
# Define resolution (pixels per dimension)
N_pixels
=
256
# Set up derived constants
k_0
=
1.
/
correlation_length
# defining a power spectrum with the right correlation length
# we later set the field variance to the desired value
unscaled_pow_spec
=
(
lambda
k
:
1.
/
(
1
+
k
/
k_0
)
**
4
)
pixel_width
=
L
/
N_pixels
# Set up the geometry
s_space
=
ift
.
RGSpace
([
N_pixels
,
N_pixels
],
distances
=
pixel_width
)
h_space
=
s_space
.
get_default_codomain
()
s_var
=
ift
.
get_signal_variance
(
unscaled_pow_spec
,
h_space
)
pow_spec
=
(
lambda
k
:
unscaled_pow_spec
(
k
)
/
s_var
*
field_variance
**
2
)
HT
=
ift
.
HarmonicTransformOperator
(
h_space
,
s_space
)
# Create mock data
Sh
=
ift
.
create_power_operator
(
h_space
,
power_spectrum
=
pow_spec
)
sh
=
Sh
.
draw_sample
()
Rx
=
CustomResponse
(
s_space
,
[
np
.
arange
(
0
,
N_pixels
,
5
)[:,
np
.
newaxis
],
np
.
arange
(
0
,
N_pixels
,
2
)[
np
.
newaxis
,
:]])
ift
.
extra
.
consistency_check
(
Rx
)
a
=
ift
.
Field
.
from_random
(
'normal'
,
s_space
)
b
=
ift
.
Field
.
from_random
(
'normal'
,
Rx
.
target
)
R
=
Rx
*
HT
noiseless_data
=
R
(
sh
)
N
=
ift
.
ScalingOperator
(
noise_level
**
2
,
R
.
target
)
n
=
N
.
draw_sample
()
d
=
noiseless_data
+
n
# Wiener filter
IC
=
ift
.
GradientNormController
(
name
=
"inverter"
,
iteration_limit
=
1000
,
tol_abs_gradnorm
=
0.0001
)
inverter
=
ift
.
ConjugateGradient
(
controller
=
IC
)
# setting up measurement precision matrix M
M
=
(
ift
.
SandwichOperator
.
make
(
R
.
adjoint
,
Sh
)
+
N
)
M
=
ift
.
InversionEnabler
(
M
,
inverter
)
m
=
Sh
(
R
.
adjoint
(
M
.
inverse_times
(
d
)))
# Plotting
backprojection
=
Rx
.
adjoint
(
d
)
reweighted_backprojection
=
(
backprojection
/
backprojection
.
max
()
*
HT
(
sh
).
max
())
zmax
=
max
(
HT
(
sh
).
max
(),
reweighted_backprojection
.
max
(),
HT
(
m
).
max
())
zmin
=
min
(
HT
(
sh
).
min
(),
reweighted_backprojection
.
min
(),
HT
(
m
).
min
())
plotdict
=
{
"colormap"
:
"Planck-like"
,
"zmax"
:
zmax
,
"zmin"
:
zmin
}
ift
.
plot
(
HT
(
sh
),
name
=
"mock_signal.png"
,
**
plotdict
)
ift
.
plot
(
backprojection
,
name
=
"backprojected_data.png"
,
**
plotdict
)
ift
.
plot
(
HT
(
m
),
name
=
"reconstruction.png"
,
**
plotdict
)
demos/wiener_filter_easy.py
View file @
0e8e4be1
import
numpy
as
np
import
nifty4
as
ift
if
__name__
==
"__main__"
:
np
.
random
.
seed
(
43
)
# Set up physical constants
...
...
@@ -12,21 +13,25 @@ if __name__ == "__main__":
field_variance
=
2.
# Smoothing length of response (in same unit as L)
response_sigma
=
0.01
# typical noise amplitude of the measurement
noise_level
=
1.
# Define resolution (pixels per dimension)
N_pixels
=
256
# Set up derived constants
k_0
=
1.
/
correlation_length
# Note that field_variance**2 = a*k_0/4. for this analytic form of power
# spectrum
a
=
field_variance
**
2
/
k_0
*
4.
pow_spec
=
(
lambda
k
:
a
/
(
1
+
k
/
k_0
)
**
4
)
#defining a power spectrum with the right correlation length
#we later set the field variance to the desired value
unscaled_pow_spec
=
(
lambda
k
:
1.
/
(
1
+
k
/
k_0
)
**
4
)
pixel_width
=
L
/
N_pixels
# Set up the geometry
s_space
=
ift
.
RGSpace
([
N_pixels
,
N_pixels
],
distances
=
pixel_width
)
h_space
=
s_space
.
get_default_codomain
()
s_var
=
ift
.
get_signal_variance
(
unscaled_pow_spec
,
h_space
)
pow_spec
=
(
lambda
k
:
unscaled_pow_spec
(
k
)
/
s_var
*
field_variance
**
2
)
HT
=
ift
.
HarmonicTransformOperator
(
h_space
,
s_space
)
# Create mock data
...
...
@@ -36,11 +41,8 @@ if __name__ == "__main__":
R
=
HT
*
ift
.
create_harmonic_smoothing_operator
((
h_space
,),
0
,
response_sigma
)
noiseless_data
=
R
(
sh
)
signal_to_noise
=
1.
noise_amplitude
=
noiseless_data
.
val
.
std
()
/
signal_to_noise
N
=
ift
.
ScalingOperator
(
noise_amplitude
**
2
,
s_space
)
N
=
ift
.
ScalingOperator
(
noise_level
**
2
,
s_space
)
n
=
N
.
draw_sample
()
d
=
noiseless_data
+
n
...
...
@@ -51,7 +53,7 @@ if __name__ == "__main__":
IC
=
ift
.
GradientNormController
(
name
=
"inverter"
,
iteration_limit
=
500
,
tol_abs_gradnorm
=
0.1
)
inverter
=
ift
.
ConjugateGradient
(
controller
=
IC
)
D
=
(
ift
.
SandwichOperator
(
R
,
N
.
inverse
)
+
Sh
.
inverse
).
inverse
D
=
(
ift
.
SandwichOperator
.
make
(
R
,
N
.
inverse
)
+
Sh
.
inverse
).
inverse
D
=
ift
.
InversionEnabler
(
D
,
inverter
,
approximation
=
Sh
)
m
=
D
(
j
)
...
...
demos/wiener_filter_via_hamiltonian.py
View file @
0e8e4be1
...
...
@@ -50,7 +50,7 @@ if __name__ == "__main__":
inverter
=
ift
.
ConjugateGradient
(
controller
=
ctrl
)
controller
=
ift
.
GradientNormController
(
name
=
"min"
,
tol_abs_gradnorm
=
0.1
)
minimizer
=
ift
.
RelaxedNewton
(
controller
=
controller
)
m0
=
ift
.
Field
.
zeros
(
h_space
)
m0
=
ift
.
full
(
h_space
,
0.
)
# Initialize Wiener filter energy
energy
=
ift
.
library
.
WienerFilterEnergy
(
position
=
m0
,
d
=
d
,
R
=
R
,
N
=
N
,
S
=
S
,
...
...
nifty4/__init__.py
View file @
0e8e4be1
...
...
@@ -8,7 +8,7 @@ from .domain_tuple import DomainTuple
from
.operators
import
*
from
.field
import
Field
,
sqrt
,
exp
,
log
from
.field
import
Field
from
.probing.utils
import
probe_with_posterior_samples
,
probe_diagonal
,
\
StatCalculator
...
...
nifty4/data_objects/distributed_do.py
View file @
0e8e4be1
...
...
@@ -20,6 +20,7 @@ import numpy as np
from
.random
import
Random
from
mpi4py
import
MPI
import
sys
from
functools
import
reduce
_comm
=
MPI
.
COMM_WORLD
ntask
=
_comm
.
Get_size
()
...
...
@@ -145,20 +146,29 @@ class data_object(object):
def
sum
(
self
,
axis
=
None
):
return
self
.
_contraction_helper
(
"sum"
,
MPI
.
SUM
,
axis
)
def
prod
(
self
,
axis
=
None
):
return
self
.
_contraction_helper
(
"prod"
,
MPI
.
PROD
,
axis
)
def
min
(
self
,
axis
=
None
):
return
self
.
_contraction_helper
(
"min"
,
MPI
.
MIN
,
axis
)
def
max
(
self
,
axis
=
None
):
return
self
.
_contraction_helper
(
"max"
,
MPI
.
MAX
,
axis
)
def
mean
(
self
):
return
self
.
sum
()
/
self
.
size
def
mean
(
self
,
axis
=
None
):
if
axis
is
None
:
sz
=
self
.
size
else
:
sz
=
reduce
(
lambda
x
,
y
:
x
*
y
,
[
self
.
shape
[
i
]
for
i
in
axis
])
return
self
.
sum
(
axis
)
/
sz
def
std
(
self
):
return
np
.
sqrt
(
self
.
var
())
def
std
(
self
,
axis
=
None
):
return
np
.
sqrt
(
self
.
var
(
axis
))
# FIXME: to be improved!
def
var
(
self
):
def
var
(
self
,
axis
=
None
):
if
axis
is
not
None
and
len
(
axis
)
!=
len
(
self
.
shape
):
raise
ValueError
(
"functionality not yet supported"
)
return
(
abs
(
self
-
self
.
mean
())
**
2
).
mean
()
def
_binary_helper
(
self
,
other
,
op
):
...
...
nifty4/domain_tuple.py
View file @
0e8e4be1
...
...
@@ -34,7 +34,9 @@ class DomainTuple(object):
"""
_tupleCache
=
{}
def
__init__
(
self
,
domain
):
def
__init__
(
self
,
domain
,
_callingfrommake
=
False
):
if
not
_callingfrommake
:
raise
NotImplementedError
self
.
_dom
=
self
.
_parse_domain
(
domain
)
self
.
_axtuple
=
self
.
_get_axes_tuple
()
shape_tuple
=
tuple
(
sp
.
shape
for
sp
in
self
.
_dom
)
...
...
@@ -72,7 +74,7 @@ class DomainTuple(object):
obj
=
DomainTuple
.
_tupleCache
.
get
(
domain
)
if
obj
is
not
None
:
return
obj
obj
=
DomainTuple
(
domain
)
obj
=
DomainTuple
(
domain
,
_callingfrommake
=
True
)
DomainTuple
.
_tupleCache
[
domain
]
=
obj
return
obj
...
...
nifty4/domains/domain.py
View file @
0e8e4be1
...
...
@@ -23,6 +23,8 @@ from ..utilities import NiftyMetaBase
class
Domain
(
NiftyMetaBase
()):
"""The abstract class repesenting a (structured or unstructured) domain.
"""
def
__init__
(
self
):
self
.
_hash
=
None
@
abc
.
abstractmethod
def
__repr__
(
self
):
...
...
@@ -36,10 +38,12 @@ class Domain(NiftyMetaBase()):
Only members that are explicitly added to
:attr:`._needed_for_hash` will be used for hashing.
"""
result_hash
=
0
for
key
in
self
.
_needed_for_hash
:
result_hash
^=
hash
(
vars
(
self
)[
key
])
return
result_hash
if
self
.
_hash
is
None
:
h
=
0
for
key
in
self
.
_needed_for_hash
:
h
^=
hash
(
vars
(
self
)[
key
])
self
.
_hash
=
h
return
self
.
_hash
def
__eq__
(
self
,
x
):
"""Checks whether two domains are equal.
...
...
nifty4/domains/lm_space.py
View file @
0e8e4be1
...
...
@@ -19,7 +19,7 @@
from
__future__
import
division
import
numpy
as
np
from
.structured_domain
import
StructuredDomain
from
..field
import
Field
,
exp
from
..field
import
Field
class
LMSpace
(
StructuredDomain
):
...
...
@@ -100,6 +100,8 @@ class LMSpace(StructuredDomain):
# cf. "All-sky convolution for polarimetry experiments"
# by Challinor et al.
# http://arxiv.org/abs/astro-ph/0008228
from
..sugar
import
exp
res
=
x
+
1.
res
*=
x
res
*=
-
0.5
*
sigma
*
sigma
...
...
nifty4/domains/rg_space.py
View file @
0e8e4be1
...
...
@@ -21,7 +21,7 @@ from builtins import range
from
functools
import
reduce
import
numpy
as
np
from
.structured_domain
import
StructuredDomain
from
..field
import
Field
,
exp
from
..field
import
Field
from
..
import
dobj
...
...
@@ -144,6 +144,7 @@ class RGSpace(StructuredDomain):
@
staticmethod
def
_kernel
(
x
,
sigma
):
from
..sugar
import
exp
tmp
=
x
*
x
tmp
*=
-
2.
*
np
.
pi
*
np
.
pi
*
sigma
*
sigma
exp
(
tmp
,
out
=
tmp
)
...
...
nifty4/extra/energy_tests.py
View file @
0e8e4be1
...
...
@@ -18,15 +18,19 @@
import
numpy
as
np
from
..field
import
Field
from
..sugar
import
from_random
__all__
=
[
"check_value_gradient_consistency"
,
"check_value_gradient_curvature_consistency"
]
def
_get_acceptable_energy
(
E
):
if
not
np
.
isfinite
(
E
.
value
):
val
=
E
.
value
if
not
np
.
isfinite
(
val
):
raise
ValueError
dir
=
Field
.
from_random
(
"normal"
,
E
.
position
.
domain
)
dir
=
from_random
(
"normal"
,
E
.
position
.
domain
)
dirder
=
E
.
gradient
.
vdot
(
dir
)
dir
*=
np
.
abs
(
val
)
/
np
.
abs
(
dirder
)
*
1e-5
# find a step length that leads to a "reasonable" energy
for
i
in
range
(
50
):
try
:
...
...
@@ -44,12 +48,13 @@ def _get_acceptable_energy(E):
def
check_value_gradient_consistency
(
E
,
tol
=
1e-6
,
ntries
=
100
):
for
_
in
range
(
ntries
):
E2
=
_get_acceptable_energy
(
E
)
val
=
E
.
value
dir
=
E2
.
position
-
E
.
position
Enext
=
E2
dirnorm
=
dir
.
norm
()
dirder
=
E
.
gradient
.
vdot
(
dir
)
/
dirnorm
for
i
in
range
(
50
):
if
abs
((
E2
.
value
-
E
.
val
ue
)
/
dirnorm
-
dirder
)
<
tol
:
if
abs
((
E2
.
value
-
val
)
/
dirnorm
-
dirder
)
<
tol
:
break
dir
*=
0.5
dirnorm
*=
0.5
...
...
nifty4/extra/operator_tests.py
View file @
0e8e4be1
...
...
@@ -17,17 +17,26 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
import
numpy
as
np
from
..sugar
import
from_random
from
..field
import
Field
__all__
=
[
"consistency_check"
]
def
_assert_allclose
(
f1
,
f2
,
atol
,
rtol
):
if
isinstance
(
f1
,
Field
):
return
np
.
testing
.
assert_allclose
(
f1
.
local_data
,
f2
.
local_data
,
atol
=
atol
,
rtol
=
rtol
)
for
key
,
val
in
f1
.
items
():
_assert_allclose
(
val
,
f2
[
key
],
atol
=
atol
,
rtol
=
rtol
)
def
adjoint_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
):
needed_cap
=
op
.
TIMES
|
op
.
ADJOINT_TIMES
if
(
op
.
capability
&
needed_cap
)
!=
needed_cap
:
return
f1
=
Field
.
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
).
lock
()
f2
=
Field
.
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
).
lock
()
f1
=
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
).
lock
()
f2
=
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
).
lock
()
res1
=
f1
.
vdot
(
op
.
adjoint_times
(
f2
).
lock
())
res2
=
op
.
times
(
f1
).
vdot
(
f2
)
np
.
testing
.
assert_allclose
(
res1
,
res2
,
atol
=
atol
,
rtol
=
rtol
)
...
...
@@ -37,15 +46,13 @@ def inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
needed_cap
=
op
.
TIMES
|
op
.
INVERSE_TIMES
if
(
op
.
capability
&
needed_cap
)
!=
needed_cap
:
return
foo
=
Field
.
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
).
lock
()
foo
=
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
).
lock
()
res
=
op
(
op
.
inverse_times
(
foo
).
lock
())
np
.
testing
.
assert_allclose
(
res
.
to_global_data
(),
res
.
to_global_data
(),
atol
=
atol
,
rtol
=
rtol
)
_assert_allclose
(
res
,
foo
,
atol
=
atol
,
rtol
=
rtol
)
foo
=
Field
.
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
).
lock
()
foo
=
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
).
lock
()
res
=
op
.
inverse_times
(
op
(
foo
).
lock
())
np
.
testing
.
assert_allclose
(
res
.
to_global_data
(),
foo
.
to_global_data
(),
atol
=
atol
,
rtol
=
rtol
)
_assert_allclose
(
res
,
foo
,
atol
=
atol
,
rtol
=
rtol
)
def
full_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
):
...
...
nifty4/field.py
View file @
0e8e4be1
...
...
@@ -106,62 +106,10 @@ class Field(object):
raise
TypeError
(
"val must be a scalar"
)
return
Field
(
DomainTuple
.
make
(
domain
),
val
,
dtype
)
@
staticmethod
def
ones
(
domain
,
dtype
=
None
):
return
Field
(
DomainTuple
.
make
(
domain
),
1.
,
dtype
)
@
staticmethod
def
zeros
(
domain
,
dtype
=
None
):
return
Field
(
DomainTuple
.
make
(
domain
),
0.
,
dtype
)
@
staticmethod
def
empty
(
domain
,
dtype
=
None
):
return
Field
(
DomainTuple
.
make
(
domain
),
None
,
dtype
)
@
staticmethod
def
full_like
(
field
,
val
,
dtype
=
None
):
"""Creates a Field from a template, filled with a constant value.
Parameters
----------
field : Field
the template field, from which the domain is inferred
val : float/complex/int scalar
fill value. Data type of the field is inferred from val.
Returns
-------
Field
the newly created field
"""
if
not
isinstance
(
field
,
Field
):
raise
TypeError
(
"field must be of Field type"
)
return
Field
.
full
(
field
.
_domain
,
val
,
dtype
)
@
staticmethod
def
zeros_like
(
field
,
dtype
=
None
):
if
not
isinstance
(
field
,
Field
):
raise
TypeError
(
"field must be of Field type"
)
if
dtype
is
None
:
dtype
=
field
.
dtype
return
Field
.
zeros
(
field
.
_domain
,
dtype
)
@
staticmethod
def
ones_like
(
field
,
dtype
=
None
):
if
not
isinstance
(
field
,
Field
):
raise
TypeError
(
"field must be of Field type"
)
if
dtype
is
None
:
dtype
=
field
.
dtype
return
Field
.
ones
(
field
.
_domain
,
dtype
)
@
staticmethod
def
empty_like
(
field
,
dtype
=
None
):
if
not
isinstance
(
field
,
Field
):
raise
TypeError
(
"field must be of Field type"
)
if
dtype
is
None
:
dtype
=
field
.
dtype
return
Field
.
empty
(
field
.
_domain
,
dtype
)
@
staticmethod
def
from_global_data
(
domain
,
arr
,
sum_up
=
False
):
"""Returns a Field constructed from `domain` and `arr`.
...
...
@@ -287,6 +235,7 @@ class Field(object):
The value to fill the field with.
"""
self
.
_val
.
fill
(
fill_value
)
return
self
def
lock
(
self
):
"""Write-protect the data content of `self`.
...
...
@@ -370,6 +319,17 @@ class Field(object):
"""
return
Field
(
val
=
self
,
copy
=
True
)
def
empty_copy
(
self
):
""" Returns a Field with identical domain and data type, but
uninitialized data.
Returns
-------
Field
A copy of 'self', with uninitialized data.
"""
return
Field
(
self
.
_domain
,
dtype
=
self
.
dtype
)
def
locked_copy
(
self
):
""" Returns a read-only version of the Field.
...
...
@@ -503,8 +463,8 @@ class Field(object):
or Field (for partial dot products)
"""
if
not
isinstance
(
x
,
Field
):
raise
Valu
eError
(
"The dot-partner must be an instance of "
+
"the NIFTy field class"
)
raise
Typ
eError
(
"The dot-partner must be an instance of "
+
"the NIFTy field class"
)
if
x
.
_domain
!=
self
.
_domain
:
raise
ValueError
(
"Domain mismatch"
)
...
...
@@ -694,7 +654,8 @@ class Field(object):
if
self
.
scalar_weight
(
spaces
)
is
not
None
:
return
self
.
_contraction_helper
(
'mean'
,
spaces
)
# MR FIXME: not very efficient
tmp
=
self
.
weight
(
1
)