Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
087530b0
Commit
087530b0
authored
May 23, 2018
by
Martin Reinecke
Browse files
merge NIFTy_4
parents
5429bb64
ec50fcc0
Pipeline
#29620
passed with stages
in 4 minutes and 21 seconds
Changes
40
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
demos/critical_filtering.py
View file @
087530b0
...
@@ -69,8 +69,8 @@ if __name__ == "__main__":
...
@@ -69,8 +69,8 @@ if __name__ == "__main__":
# Creating the mock data
# Creating the mock data
d
=
noiseless_data
+
n
d
=
noiseless_data
+
n
m0
=
ift
.
Field
.
full
(
h_space
,
1e-7
)
m0
=
ift
.
full
(
h_space
,
1e-7
)
t0
=
ift
.
Field
.
full
(
p_space
,
-
4.
)
t0
=
ift
.
full
(
p_space
,
-
4.
)
power0
=
Distributor
.
times
(
ift
.
exp
(
0.5
*
t0
))
power0
=
Distributor
.
times
(
ift
.
exp
(
0.5
*
t0
))
plotdict
=
{
"colormap"
:
"Planck-like"
}
plotdict
=
{
"colormap"
:
"Planck-like"
}
...
...
demos/krylov_sampling.py
View file @
087530b0
...
@@ -31,7 +31,7 @@ d = R(s_x) + n
...
@@ -31,7 +31,7 @@ d = R(s_x) + n
R_p
=
R
*
FFT
*
A
R_p
=
R
*
FFT
*
A
j
=
R_p
.
adjoint
(
N
.
inverse
(
d
))
j
=
R_p
.
adjoint
(
N
.
inverse
(
d
))
D_inv
=
ift
.
SandwichOperator
(
R_p
,
N
.
inverse
)
+
S
.
inverse
D_inv
=
ift
.
SandwichOperator
.
make
(
R_p
,
N
.
inverse
)
+
S
.
inverse
N_samps
=
200
N_samps
=
200
...
@@ -67,8 +67,8 @@ plt.legend()
...
@@ -67,8 +67,8 @@ plt.legend()
plt
.
savefig
(
'Krylov_samples_residuals.png'
)
plt
.
savefig
(
'Krylov_samples_residuals.png'
)
plt
.
close
()
plt
.
close
()
D_hat_old
=
ift
.
Field
.
zeros
(
x_space
).
to_global_data
()
D_hat_old
=
ift
.
full
(
x_space
,
0.
).
to_global_data
()
D_hat_new
=
ift
.
Field
.
zeros
(
x_space
).
to_global_data
()
D_hat_new
=
ift
.
full
(
x_space
,
0.
).
to_global_data
()
for
i
in
range
(
N_samps
):
for
i
in
range
(
N_samps
):
D_hat_old
+=
sky
(
samps_old
[
i
]).
to_global_data
()
**
2
D_hat_old
+=
sky
(
samps_old
[
i
]).
to_global_data
()
**
2
D_hat_new
+=
sky
(
samps
[
i
]).
to_global_data
()
**
2
D_hat_new
+=
sky
(
samps
[
i
]).
to_global_data
()
**
2
...
...
demos/nonlinear_critical_filter.py
View file @
087530b0
...
@@ -69,8 +69,8 @@ if __name__ == "__main__":
...
@@ -69,8 +69,8 @@ if __name__ == "__main__":
# Creating the mock data
# Creating the mock data
d
=
noiseless_data
+
n
d
=
noiseless_data
+
n
m0
=
ift
.
Field
.
full
(
h_space
,
1e-7
)
m0
=
ift
.
full
(
h_space
,
1e-7
)
t0
=
ift
.
Field
.
full
(
p_space
,
-
4.
)
t0
=
ift
.
full
(
p_space
,
-
4.
)
power0
=
Distributor
.
times
(
ift
.
exp
(
0.5
*
t0
))
power0
=
Distributor
.
times
(
ift
.
exp
(
0.5
*
t0
))
IC1
=
ift
.
GradientNormController
(
name
=
"IC1"
,
iteration_limit
=
100
,
IC1
=
ift
.
GradientNormController
(
name
=
"IC1"
,
iteration_limit
=
100
,
...
...
demos/nonlinear_wiener_filter.py
View file @
087530b0
...
@@ -36,7 +36,7 @@ if __name__ == "__main__":
...
@@ -36,7 +36,7 @@ if __name__ == "__main__":
d_space
=
R
.
target
d_space
=
R
.
target
p_op
=
ift
.
create_power_operator
(
h_space
,
p_spec
)
p_op
=
ift
.
create_power_operator
(
h_space
,
p_spec
)
power
=
ift
.
sqrt
(
p_op
(
ift
.
Field
.
full
(
h_space
,
1.
)))
power
=
ift
.
sqrt
(
p_op
(
ift
.
full
(
h_space
,
1.
)))
# Creating the mock data
# Creating the mock data
true_sky
=
nonlinearity
(
HT
(
power
*
sh
))
true_sky
=
nonlinearity
(
HT
(
power
*
sh
))
...
@@ -57,7 +57,7 @@ if __name__ == "__main__":
...
@@ -57,7 +57,7 @@ if __name__ == "__main__":
inverter
=
ift
.
ConjugateGradient
(
controller
=
ICI
)
inverter
=
ift
.
ConjugateGradient
(
controller
=
ICI
)
# initial guess
# initial guess
m
=
ift
.
Field
.
full
(
h_space
,
1e-7
)
m
=
ift
.
full
(
h_space
,
1e-7
)
map_energy
=
ift
.
library
.
NonlinearWienerFilterEnergy
(
map_energy
=
ift
.
library
.
NonlinearWienerFilterEnergy
(
m
,
d
,
R
,
nonlinearity
,
HT
,
power
,
N
,
S
,
inverter
=
inverter
)
m
,
d
,
R
,
nonlinearity
,
HT
,
power
,
N
,
S
,
inverter
=
inverter
)
...
...
demos/poisson_demo.py
View file @
087530b0
...
@@ -80,12 +80,12 @@ if __name__ == "__main__":
...
@@ -80,12 +80,12 @@ if __name__ == "__main__":
IC
=
ift
.
GradientNormController
(
name
=
"inverter"
,
iteration_limit
=
500
,
IC
=
ift
.
GradientNormController
(
name
=
"inverter"
,
iteration_limit
=
500
,
tol_abs_gradnorm
=
1e-3
)
tol_abs_gradnorm
=
1e-3
)
inverter
=
ift
.
ConjugateGradient
(
controller
=
IC
)
inverter
=
ift
.
ConjugateGradient
(
controller
=
IC
)
D
=
(
ift
.
SandwichOperator
(
R
,
N
.
inverse
)
+
Phi_h
.
inverse
).
inverse
D
=
(
ift
.
SandwichOperator
.
make
(
R
,
N
.
inverse
)
+
Phi_h
.
inverse
).
inverse
D
=
ift
.
InversionEnabler
(
D
,
inverter
,
approximation
=
Phi_h
)
D
=
ift
.
InversionEnabler
(
D
,
inverter
,
approximation
=
Phi_h
)
m
=
HT
(
D
(
j
))
m
=
HT
(
D
(
j
))
# Uncertainty
# Uncertainty
D
=
ift
.
SandwichOperator
(
aHT
,
D
)
# real space propagator
D
=
ift
.
SandwichOperator
.
make
(
aHT
,
D
)
# real space propagator
Dhat
=
ift
.
probe_with_posterior_samples
(
D
.
inverse
,
None
,
Dhat
=
ift
.
probe_with_posterior_samples
(
D
.
inverse
,
None
,
nprobes
=
nprobes
)[
1
]
nprobes
=
nprobes
)[
1
]
sig
=
ift
.
sqrt
(
Dhat
)
sig
=
ift
.
sqrt
(
Dhat
)
...
@@ -113,7 +113,7 @@ if __name__ == "__main__":
...
@@ -113,7 +113,7 @@ if __name__ == "__main__":
d_domain
,
np
.
random
.
poisson
(
lam
.
local_data
).
astype
(
np
.
float64
))
d_domain
,
np
.
random
.
poisson
(
lam
.
local_data
).
astype
(
np
.
float64
))
# initial guess
# initial guess
psi0
=
ift
.
Field
.
full
(
h_domain
,
1e-7
)
psi0
=
ift
.
full
(
h_domain
,
1e-7
)
energy
=
ift
.
library
.
PoissonEnergy
(
psi0
,
data
,
R0
,
nonlin
,
HT
,
Phi_h
,
energy
=
ift
.
library
.
PoissonEnergy
(
psi0
,
data
,
R0
,
nonlin
,
HT
,
Phi_h
,
inverter
)
inverter
)
IC1
=
ift
.
GradientNormController
(
name
=
"IC1"
,
iteration_limit
=
200
,
IC1
=
ift
.
GradientNormController
(
name
=
"IC1"
,
iteration_limit
=
200
,
...
...
demos/wiener_filter_data_space_noiseless.py
View file @
087530b0
...
@@ -13,19 +13,20 @@ class CustomResponse(ift.LinearOperator):
...
@@ -13,19 +13,20 @@ class CustomResponse(ift.LinearOperator):
def
__init__
(
self
,
domain
,
data_points
):
def
__init__
(
self
,
domain
,
data_points
):
self
.
_domain
=
ift
.
DomainTuple
.
make
(
domain
)
self
.
_domain
=
ift
.
DomainTuple
.
make
(
domain
)
self
.
_points
=
data_points
self
.
_points
=
data_points
data_shape
=
ift
.
Field
.
zeros
(
domain
).
val
[
data_points
].
shape
data_shape
=
ift
.
Field
.
full
(
domain
,
0.
).
to_global_data
()[
data_points
]
\
.
shape
self
.
_target
=
ift
.
DomainTuple
.
make
(
ift
.
UnstructuredDomain
(
data_shape
))
self
.
_target
=
ift
.
DomainTuple
.
make
(
ift
.
UnstructuredDomain
(
data_shape
))
def
_times
(
self
,
x
):
def
_times
(
self
,
x
):
d
=
ift
.
Field
.
zeros
(
self
.
_target
)
d
=
np
.
zeros
(
self
.
_target
.
shape
,
dtype
=
np
.
float64
)
d
.
val
[()]
+=
x
.
val
[
self
.
_points
]
d
+=
x
.
to_global_data
()
[
self
.
_points
]
return
d
return
ift
.
from_global_data
(
self
.
_target
,
d
)
def
_adjoint_times
(
self
,
d
):
def
_adjoint_times
(
self
,
d
):
x
=
ift
.
Field
.
zeros
(
self
.
_domain
)
x
=
np
.
zeros
(
self
.
_domain
.
shape
,
dtype
=
np
.
float64
)
x
.
val
[
self
.
_points
]
+=
d
.
val
[
()
]
x
[
self
.
_points
]
+=
d
.
to_global_data
()
return
x
return
ift
.
from_global_data
(
self
.
_domain
,
x
)
@
property
@
property
def
domain
(
self
):
def
domain
(
self
):
return
self
.
_domain
return
self
.
_domain
...
@@ -61,8 +62,8 @@ if __name__ == "__main__":
...
@@ -61,8 +62,8 @@ if __name__ == "__main__":
# Set up derived constants
# Set up derived constants
k_0
=
1.
/
correlation_length
k_0
=
1.
/
correlation_length
#defining a power spectrum with the right correlation length
#
defining a power spectrum with the right correlation length
#we later set the field variance to the desired value
#
we later set the field variance to the desired value
unscaled_pow_spec
=
(
lambda
k
:
1.
/
(
1
+
k
/
k_0
)
**
4
)
unscaled_pow_spec
=
(
lambda
k
:
1.
/
(
1
+
k
/
k_0
)
**
4
)
pixel_width
=
L
/
N_pixels
pixel_width
=
L
/
N_pixels
...
@@ -71,7 +72,7 @@ if __name__ == "__main__":
...
@@ -71,7 +72,7 @@ if __name__ == "__main__":
h_space
=
s_space
.
get_default_codomain
()
h_space
=
s_space
.
get_default_codomain
()
s_var
=
ift
.
get_signal_variance
(
unscaled_pow_spec
,
h_space
)
s_var
=
ift
.
get_signal_variance
(
unscaled_pow_spec
,
h_space
)
pow_spec
=
(
lambda
k
:
unscaled_pow_spec
(
k
)
/
s_var
*
field_variance
**
2
)
pow_spec
=
(
lambda
k
:
unscaled_pow_spec
(
k
)
/
s_var
*
field_variance
**
2
)
HT
=
ift
.
HarmonicTransformOperator
(
h_space
,
s_space
)
HT
=
ift
.
HarmonicTransformOperator
(
h_space
,
s_space
)
# Create mock data
# Create mock data
...
@@ -79,7 +80,8 @@ if __name__ == "__main__":
...
@@ -79,7 +80,8 @@ if __name__ == "__main__":
Sh
=
ift
.
create_power_operator
(
h_space
,
power_spectrum
=
pow_spec
)
Sh
=
ift
.
create_power_operator
(
h_space
,
power_spectrum
=
pow_spec
)
sh
=
Sh
.
draw_sample
()
sh
=
Sh
.
draw_sample
()
Rx
=
CustomResponse
(
s_space
,
[
np
.
arange
(
0
,
N_pixels
,
5
)[:,
np
.
newaxis
],
np
.
arange
(
0
,
N_pixels
,
2
)[
np
.
newaxis
,:]])
Rx
=
CustomResponse
(
s_space
,
[
np
.
arange
(
0
,
N_pixels
,
5
)[:,
np
.
newaxis
],
np
.
arange
(
0
,
N_pixels
,
2
)[
np
.
newaxis
,
:]])
ift
.
extra
.
consistency_check
(
Rx
)
ift
.
extra
.
consistency_check
(
Rx
)
a
=
ift
.
Field
.
from_random
(
'normal'
,
s_space
)
a
=
ift
.
Field
.
from_random
(
'normal'
,
s_space
)
b
=
ift
.
Field
.
from_random
(
'normal'
,
Rx
.
target
)
b
=
ift
.
Field
.
from_random
(
'normal'
,
Rx
.
target
)
...
@@ -97,16 +99,17 @@ if __name__ == "__main__":
...
@@ -97,16 +99,17 @@ if __name__ == "__main__":
tol_abs_gradnorm
=
0.0001
)
tol_abs_gradnorm
=
0.0001
)
inverter
=
ift
.
ConjugateGradient
(
controller
=
IC
)
inverter
=
ift
.
ConjugateGradient
(
controller
=
IC
)
# setting up measurement precision matrix M
# setting up measurement precision matrix M
M
=
(
ift
.
SandwichOperator
(
R
.
adjoint
,
Sh
)
+
N
)
M
=
(
ift
.
SandwichOperator
.
make
(
R
.
adjoint
,
Sh
)
+
N
)
M
=
ift
.
InversionEnabler
(
M
,
inverter
)
M
=
ift
.
InversionEnabler
(
M
,
inverter
)
m
=
Sh
(
R
.
adjoint
(
M
.
inverse_times
(
d
)))
m
=
Sh
(
R
.
adjoint
(
M
.
inverse_times
(
d
)))
# Plotting
# Plotting
backprojection
=
Rx
.
adjoint
(
d
)
backprojection
=
Rx
.
adjoint
(
d
)
reweighted_backprojection
=
backprojection
/
backprojection
.
max
()
*
HT
(
sh
).
max
()
reweighted_backprojection
=
(
backprojection
/
backprojection
.
max
()
*
HT
(
sh
).
max
())
zmax
=
max
(
HT
(
sh
).
max
(),
reweighted_backprojection
.
max
(),
HT
(
m
).
max
())
zmax
=
max
(
HT
(
sh
).
max
(),
reweighted_backprojection
.
max
(),
HT
(
m
).
max
())
zmin
=
min
(
HT
(
sh
).
min
(),
reweighted_backprojection
.
min
(),
HT
(
m
).
min
())
zmin
=
min
(
HT
(
sh
).
min
(),
reweighted_backprojection
.
min
(),
HT
(
m
).
min
())
plotdict
=
{
"colormap"
:
"Planck-like"
,
"zmax"
:
zmax
,
"zmin"
:
zmin
}
plotdict
=
{
"colormap"
:
"Planck-like"
,
"zmax"
:
zmax
,
"zmin"
:
zmin
}
ift
.
plot
(
HT
(
sh
),
name
=
"mock_signal.png"
,
**
plotdict
)
ift
.
plot
(
HT
(
sh
),
name
=
"mock_signal.png"
,
**
plotdict
)
ift
.
plot
(
backprojection
,
name
=
"backprojected_data.png"
,
**
plotdict
)
ift
.
plot
(
backprojection
,
name
=
"backprojected_data.png"
,
**
plotdict
)
ift
.
plot
(
HT
(
m
),
name
=
"reconstruction.png"
,
**
plotdict
)
ift
.
plot
(
HT
(
m
),
name
=
"reconstruction.png"
,
**
plotdict
)
demos/wiener_filter_easy.py
View file @
087530b0
...
@@ -31,7 +31,7 @@ if __name__ == "__main__":
...
@@ -31,7 +31,7 @@ if __name__ == "__main__":
h_space
=
s_space
.
get_default_codomain
()
h_space
=
s_space
.
get_default_codomain
()
s_var
=
ift
.
get_signal_variance
(
unscaled_pow_spec
,
h_space
)
s_var
=
ift
.
get_signal_variance
(
unscaled_pow_spec
,
h_space
)
pow_spec
=
(
lambda
k
:
unscaled_pow_spec
(
k
)
/
s_var
*
field_variance
**
2
)
pow_spec
=
(
lambda
k
:
unscaled_pow_spec
(
k
)
/
s_var
*
field_variance
**
2
)
HT
=
ift
.
HarmonicTransformOperator
(
h_space
,
s_space
)
HT
=
ift
.
HarmonicTransformOperator
(
h_space
,
s_space
)
# Create mock data
# Create mock data
...
@@ -53,7 +53,7 @@ if __name__ == "__main__":
...
@@ -53,7 +53,7 @@ if __name__ == "__main__":
IC
=
ift
.
GradientNormController
(
name
=
"inverter"
,
iteration_limit
=
500
,
IC
=
ift
.
GradientNormController
(
name
=
"inverter"
,
iteration_limit
=
500
,
tol_abs_gradnorm
=
0.1
)
tol_abs_gradnorm
=
0.1
)
inverter
=
ift
.
ConjugateGradient
(
controller
=
IC
)
inverter
=
ift
.
ConjugateGradient
(
controller
=
IC
)
D
=
(
ift
.
SandwichOperator
(
R
,
N
.
inverse
)
+
Sh
.
inverse
).
inverse
D
=
(
ift
.
SandwichOperator
.
make
(
R
,
N
.
inverse
)
+
Sh
.
inverse
).
inverse
D
=
ift
.
InversionEnabler
(
D
,
inverter
,
approximation
=
Sh
)
D
=
ift
.
InversionEnabler
(
D
,
inverter
,
approximation
=
Sh
)
m
=
D
(
j
)
m
=
D
(
j
)
...
...
demos/wiener_filter_via_hamiltonian.py
View file @
087530b0
...
@@ -50,7 +50,7 @@ if __name__ == "__main__":
...
@@ -50,7 +50,7 @@ if __name__ == "__main__":
inverter
=
ift
.
ConjugateGradient
(
controller
=
ctrl
)
inverter
=
ift
.
ConjugateGradient
(
controller
=
ctrl
)
controller
=
ift
.
GradientNormController
(
name
=
"min"
,
tol_abs_gradnorm
=
0.1
)
controller
=
ift
.
GradientNormController
(
name
=
"min"
,
tol_abs_gradnorm
=
0.1
)
minimizer
=
ift
.
RelaxedNewton
(
controller
=
controller
)
minimizer
=
ift
.
RelaxedNewton
(
controller
=
controller
)
m0
=
ift
.
Field
.
zeros
(
h_space
)
m0
=
ift
.
full
(
h_space
,
0.
)
# Initialize Wiener filter energy
# Initialize Wiener filter energy
energy
=
ift
.
library
.
WienerFilterEnergy
(
position
=
m0
,
d
=
d
,
R
=
R
,
N
=
N
,
S
=
S
,
energy
=
ift
.
library
.
WienerFilterEnergy
(
position
=
m0
,
d
=
d
,
R
=
R
,
N
=
N
,
S
=
S
,
...
...
nifty4/__init__.py
View file @
087530b0
...
@@ -8,7 +8,7 @@ from .domain_tuple import DomainTuple
...
@@ -8,7 +8,7 @@ from .domain_tuple import DomainTuple
from
.operators
import
*
from
.operators
import
*
from
.field
import
Field
,
sqrt
,
exp
,
log
from
.field
import
Field
from
.probing.utils
import
probe_with_posterior_samples
,
probe_diagonal
,
\
from
.probing.utils
import
probe_with_posterior_samples
,
probe_diagonal
,
\
StatCalculator
StatCalculator
...
...
nifty4/data_objects/distributed_do.py
View file @
087530b0
...
@@ -20,6 +20,7 @@ import numpy as np
...
@@ -20,6 +20,7 @@ import numpy as np
from
.random
import
Random
from
.random
import
Random
from
mpi4py
import
MPI
from
mpi4py
import
MPI
import
sys
import
sys
from
functools
import
reduce
_comm
=
MPI
.
COMM_WORLD
_comm
=
MPI
.
COMM_WORLD
ntask
=
_comm
.
Get_size
()
ntask
=
_comm
.
Get_size
()
...
@@ -145,20 +146,29 @@ class data_object(object):
...
@@ -145,20 +146,29 @@ class data_object(object):
def
sum
(
self
,
axis
=
None
):
def
sum
(
self
,
axis
=
None
):
return
self
.
_contraction_helper
(
"sum"
,
MPI
.
SUM
,
axis
)
return
self
.
_contraction_helper
(
"sum"
,
MPI
.
SUM
,
axis
)
def
prod
(
self
,
axis
=
None
):
return
self
.
_contraction_helper
(
"prod"
,
MPI
.
PROD
,
axis
)
def
min
(
self
,
axis
=
None
):
def
min
(
self
,
axis
=
None
):
return
self
.
_contraction_helper
(
"min"
,
MPI
.
MIN
,
axis
)
return
self
.
_contraction_helper
(
"min"
,
MPI
.
MIN
,
axis
)
def
max
(
self
,
axis
=
None
):
def
max
(
self
,
axis
=
None
):
return
self
.
_contraction_helper
(
"max"
,
MPI
.
MAX
,
axis
)
return
self
.
_contraction_helper
(
"max"
,
MPI
.
MAX
,
axis
)
def
mean
(
self
):
def
mean
(
self
,
axis
=
None
):
return
self
.
sum
()
/
self
.
size
if
axis
is
None
:
sz
=
self
.
size
else
:
sz
=
reduce
(
lambda
x
,
y
:
x
*
y
,
[
self
.
shape
[
i
]
for
i
in
axis
])
return
self
.
sum
(
axis
)
/
sz
def
std
(
self
):
def
std
(
self
,
axis
=
None
):
return
np
.
sqrt
(
self
.
var
())
return
np
.
sqrt
(
self
.
var
(
axis
))
# FIXME: to be improved!
# FIXME: to be improved!
def
var
(
self
):
def
var
(
self
,
axis
=
None
):
if
axis
is
not
None
and
len
(
axis
)
!=
len
(
self
.
shape
):
raise
ValueError
(
"functionality not yet supported"
)
return
(
abs
(
self
-
self
.
mean
())
**
2
).
mean
()
return
(
abs
(
self
-
self
.
mean
())
**
2
).
mean
()
def
_binary_helper
(
self
,
other
,
op
):
def
_binary_helper
(
self
,
other
,
op
):
...
...
nifty4/domain_tuple.py
View file @
087530b0
...
@@ -34,7 +34,9 @@ class DomainTuple(object):
...
@@ -34,7 +34,9 @@ class DomainTuple(object):
"""
"""
_tupleCache
=
{}
_tupleCache
=
{}
def
__init__
(
self
,
domain
):
def
__init__
(
self
,
domain
,
_callingfrommake
=
False
):
if
not
_callingfrommake
:
raise
NotImplementedError
self
.
_dom
=
self
.
_parse_domain
(
domain
)
self
.
_dom
=
self
.
_parse_domain
(
domain
)
self
.
_axtuple
=
self
.
_get_axes_tuple
()
self
.
_axtuple
=
self
.
_get_axes_tuple
()
shape_tuple
=
tuple
(
sp
.
shape
for
sp
in
self
.
_dom
)
shape_tuple
=
tuple
(
sp
.
shape
for
sp
in
self
.
_dom
)
...
@@ -72,7 +74,7 @@ class DomainTuple(object):
...
@@ -72,7 +74,7 @@ class DomainTuple(object):
obj
=
DomainTuple
.
_tupleCache
.
get
(
domain
)
obj
=
DomainTuple
.
_tupleCache
.
get
(
domain
)
if
obj
is
not
None
:
if
obj
is
not
None
:
return
obj
return
obj
obj
=
DomainTuple
(
domain
)
obj
=
DomainTuple
(
domain
,
_callingfrommake
=
True
)
DomainTuple
.
_tupleCache
[
domain
]
=
obj
DomainTuple
.
_tupleCache
[
domain
]
=
obj
return
obj
return
obj
...
...
nifty4/domains/domain.py
View file @
087530b0
...
@@ -23,6 +23,8 @@ from ..utilities import NiftyMetaBase
...
@@ -23,6 +23,8 @@ from ..utilities import NiftyMetaBase
class
Domain
(
NiftyMetaBase
()):
class
Domain
(
NiftyMetaBase
()):
"""The abstract class repesenting a (structured or unstructured) domain.
"""The abstract class repesenting a (structured or unstructured) domain.
"""
"""
def
__init__
(
self
):
self
.
_hash
=
None
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
__repr__
(
self
):
def
__repr__
(
self
):
...
@@ -36,10 +38,12 @@ class Domain(NiftyMetaBase()):
...
@@ -36,10 +38,12 @@ class Domain(NiftyMetaBase()):
Only members that are explicitly added to
Only members that are explicitly added to
:attr:`._needed_for_hash` will be used for hashing.
:attr:`._needed_for_hash` will be used for hashing.
"""
"""
result_hash
=
0
if
self
.
_hash
is
None
:
for
key
in
self
.
_needed_for_hash
:
h
=
0
result_hash
^=
hash
(
vars
(
self
)[
key
])
for
key
in
self
.
_needed_for_hash
:
return
result_hash
h
^=
hash
(
vars
(
self
)[
key
])
self
.
_hash
=
h
return
self
.
_hash
def
__eq__
(
self
,
x
):
def
__eq__
(
self
,
x
):
"""Checks whether two domains are equal.
"""Checks whether two domains are equal.
...
...
nifty4/domains/lm_space.py
View file @
087530b0
...
@@ -19,7 +19,7 @@
...
@@ -19,7 +19,7 @@
from
__future__
import
division
from
__future__
import
division
import
numpy
as
np
import
numpy
as
np
from
.structured_domain
import
StructuredDomain
from
.structured_domain
import
StructuredDomain
from
..field
import
Field
,
exp
from
..field
import
Field
class
LMSpace
(
StructuredDomain
):
class
LMSpace
(
StructuredDomain
):
...
@@ -100,6 +100,8 @@ class LMSpace(StructuredDomain):
...
@@ -100,6 +100,8 @@ class LMSpace(StructuredDomain):
# cf. "All-sky convolution for polarimetry experiments"
# cf. "All-sky convolution for polarimetry experiments"
# by Challinor et al.
# by Challinor et al.
# http://arxiv.org/abs/astro-ph/0008228
# http://arxiv.org/abs/astro-ph/0008228
from
..sugar
import
exp
res
=
x
+
1.
res
=
x
+
1.
res
*=
x
res
*=
x
res
*=
-
0.5
*
sigma
*
sigma
res
*=
-
0.5
*
sigma
*
sigma
...
...
nifty4/domains/rg_space.py
View file @
087530b0
...
@@ -21,7 +21,7 @@ from builtins import range
...
@@ -21,7 +21,7 @@ from builtins import range
from
functools
import
reduce
from
functools
import
reduce
import
numpy
as
np
import
numpy
as
np
from
.structured_domain
import
StructuredDomain
from
.structured_domain
import
StructuredDomain
from
..field
import
Field
,
exp
from
..field
import
Field
from
..
import
dobj
from
..
import
dobj
...
@@ -144,6 +144,7 @@ class RGSpace(StructuredDomain):
...
@@ -144,6 +144,7 @@ class RGSpace(StructuredDomain):
@
staticmethod
@
staticmethod
def
_kernel
(
x
,
sigma
):
def
_kernel
(
x
,
sigma
):
from
..sugar
import
exp
tmp
=
x
*
x
tmp
=
x
*
x
tmp
*=
-
2.
*
np
.
pi
*
np
.
pi
*
sigma
*
sigma
tmp
*=
-
2.
*
np
.
pi
*
np
.
pi
*
sigma
*
sigma
exp
(
tmp
,
out
=
tmp
)
exp
(
tmp
,
out
=
tmp
)
...
...
nifty4/extra/operator_tests.py
View file @
087530b0
...
@@ -17,17 +17,26 @@
...
@@ -17,17 +17,26 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
# and financially supported by the Studienstiftung des deutschen Volkes.
import
numpy
as
np
import
numpy
as
np
from
..sugar
import
from_random
from
..field
import
Field
from
..field
import
Field
__all__
=
[
"consistency_check"
]
__all__
=
[
"consistency_check"
]
def
_assert_allclose
(
f1
,
f2
,
atol
,
rtol
):
if
isinstance
(
f1
,
Field
):
return
np
.
testing
.
assert_allclose
(
f1
.
local_data
,
f2
.
local_data
,
atol
=
atol
,
rtol
=
rtol
)
for
key
,
val
in
f1
.
items
():
_assert_allclose
(
val
,
f2
[
key
],
atol
=
atol
,
rtol
=
rtol
)
def
adjoint_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
):
def
adjoint_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
):
needed_cap
=
op
.
TIMES
|
op
.
ADJOINT_TIMES
needed_cap
=
op
.
TIMES
|
op
.
ADJOINT_TIMES
if
(
op
.
capability
&
needed_cap
)
!=
needed_cap
:
if
(
op
.
capability
&
needed_cap
)
!=
needed_cap
:
return
return
f1
=
Field
.
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
).
lock
()
f1
=
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
).
lock
()
f2
=
Field
.
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
).
lock
()
f2
=
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
).
lock
()
res1
=
f1
.
vdot
(
op
.
adjoint_times
(
f2
).
lock
())
res1
=
f1
.
vdot
(
op
.
adjoint_times
(
f2
).
lock
())
res2
=
op
.
times
(
f1
).
vdot
(
f2
)
res2
=
op
.
times
(
f1
).
vdot
(
f2
)
np
.
testing
.
assert_allclose
(
res1
,
res2
,
atol
=
atol
,
rtol
=
rtol
)
np
.
testing
.
assert_allclose
(
res1
,
res2
,
atol
=
atol
,
rtol
=
rtol
)
...
@@ -37,15 +46,13 @@ def inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
...
@@ -37,15 +46,13 @@ def inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
needed_cap
=
op
.
TIMES
|
op
.
INVERSE_TIMES
needed_cap
=
op
.
TIMES
|
op
.
INVERSE_TIMES
if
(
op
.
capability
&
needed_cap
)
!=
needed_cap
:
if
(
op
.
capability
&
needed_cap
)
!=
needed_cap
:
return
return
foo
=
Field
.
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
).
lock
()
foo
=
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
).
lock
()
res
=
op
(
op
.
inverse_times
(
foo
).
lock
())
res
=
op
(
op
.
inverse_times
(
foo
).
lock
())
np
.
testing
.
assert_allclose
(
res
.
to_global_data
(),
res
.
to_global_data
(),
_assert_allclose
(
res
,
foo
,
atol
=
atol
,
rtol
=
rtol
)
atol
=
atol
,
rtol
=
rtol
)
foo
=
Field
.
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
).
lock
()
foo
=
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
).
lock
()
res
=
op
.
inverse_times
(
op
(
foo
).
lock
())
res
=
op
.
inverse_times
(
op
(
foo
).
lock
())
np
.
testing
.
assert_allclose
(
res
.
to_global_data
(),
foo
.
to_global_data
(),
_assert_allclose
(
res
,
foo
,
atol
=
atol
,
rtol
=
rtol
)
atol
=
atol
,
rtol
=
rtol
)
def
full_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
):
def
full_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
):
...
...
nifty4/field.py
View file @
087530b0
...
@@ -106,62 +106,10 @@ class Field(object):
...
@@ -106,62 +106,10 @@ class Field(object):
raise
TypeError
(
"val must be a scalar"
)
raise
TypeError
(
"val must be a scalar"
)
return
Field
(
DomainTuple
.
make
(
domain
),
val
,
dtype
)
return
Field
(
DomainTuple
.
make
(
domain
),
val
,
dtype
)
@
staticmethod
def
ones
(
domain
,
dtype
=
None
):
return
Field
(
DomainTuple
.
make
(
domain
),
1.
,
dtype
)
@
staticmethod
def
zeros
(
domain
,
dtype
=
None
):
return
Field
(
DomainTuple
.
make
(
domain
),
0.
,
dtype
)
@
staticmethod
@
staticmethod
def
empty
(
domain
,
dtype
=
None
):
def
empty
(
domain
,
dtype
=
None
):
return
Field
(
DomainTuple
.
make
(
domain
),
None
,
dtype
)
return
Field
(
DomainTuple
.
make
(
domain
),
None
,
dtype
)
@
staticmethod
def
full_like
(
field
,
val
,
dtype
=
None
):
"""Creates a Field from a template, filled with a constant value.
Parameters
----------
field : Field
the template field, from which the domain is inferred
val : float/complex/int scalar
fill value. Data type of the field is inferred from val.
Returns
-------
Field
the newly created field
"""
if
not
isinstance
(
field
,
Field
):
raise
TypeError
(
"field must be of Field type"
)
return
Field
.
full
(
field
.
_domain
,
val
,
dtype
)
@
staticmethod
def
zeros_like
(
field
,
dtype
=
None