Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
67660e26
Commit
67660e26
authored
May 22, 2018
by
Martin Reinecke
Browse files
Merge branch 'static_restructure' into 'NIFTy_4'
Static restructure See merge request ift/NIFTy!259
parents
096f619e
36640cc0
Pipeline
#29556
passed with stages
in 4 minutes and 45 seconds
Changes
34
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
demos/critical_filtering.py
View file @
67660e26
...
...
@@ -69,8 +69,8 @@ if __name__ == "__main__":
# Creating the mock data
d
=
noiseless_data
+
n
m0
=
ift
.
Field
.
full
(
h_space
,
1e-7
)
t0
=
ift
.
Field
.
full
(
p_space
,
-
4.
)
m0
=
ift
.
full
(
h_space
,
1e-7
)
t0
=
ift
.
full
(
p_space
,
-
4.
)
power0
=
Distributor
.
times
(
ift
.
exp
(
0.5
*
t0
))
plotdict
=
{
"colormap"
:
"Planck-like"
}
...
...
demos/krylov_sampling.py
View file @
67660e26
...
...
@@ -67,8 +67,8 @@ plt.legend()
plt
.
savefig
(
'Krylov_samples_residuals.png'
)
plt
.
close
()
D_hat_old
=
ift
.
Field
.
zeros
(
x_space
).
to_global_data
()
D_hat_new
=
ift
.
Field
.
zeros
(
x_space
).
to_global_data
()
D_hat_old
=
ift
.
full
(
x_space
,
0.
).
to_global_data
()
D_hat_new
=
ift
.
full
(
x_space
,
0.
).
to_global_data
()
for
i
in
range
(
N_samps
):
D_hat_old
+=
sky
(
samps_old
[
i
]).
to_global_data
()
**
2
D_hat_new
+=
sky
(
samps
[
i
]).
to_global_data
()
**
2
...
...
demos/nonlinear_critical_filter.py
View file @
67660e26
...
...
@@ -69,8 +69,8 @@ if __name__ == "__main__":
# Creating the mock data
d
=
noiseless_data
+
n
m0
=
ift
.
Field
.
full
(
h_space
,
1e-7
)
t0
=
ift
.
Field
.
full
(
p_space
,
-
4.
)
m0
=
ift
.
full
(
h_space
,
1e-7
)
t0
=
ift
.
full
(
p_space
,
-
4.
)
power0
=
Distributor
.
times
(
ift
.
exp
(
0.5
*
t0
))
IC1
=
ift
.
GradientNormController
(
name
=
"IC1"
,
iteration_limit
=
100
,
...
...
demos/nonlinear_wiener_filter.py
View file @
67660e26
...
...
@@ -36,7 +36,7 @@ if __name__ == "__main__":
d_space
=
R
.
target
p_op
=
ift
.
create_power_operator
(
h_space
,
p_spec
)
power
=
ift
.
sqrt
(
p_op
(
ift
.
Field
.
full
(
h_space
,
1.
)))
power
=
ift
.
sqrt
(
p_op
(
ift
.
full
(
h_space
,
1.
)))
# Creating the mock data
true_sky
=
nonlinearity
(
HT
(
power
*
sh
))
...
...
@@ -57,7 +57,7 @@ if __name__ == "__main__":
inverter
=
ift
.
ConjugateGradient
(
controller
=
ICI
)
# initial guess
m
=
ift
.
Field
.
full
(
h_space
,
1e-7
)
m
=
ift
.
full
(
h_space
,
1e-7
)
map_energy
=
ift
.
library
.
NonlinearWienerFilterEnergy
(
m
,
d
,
R
,
nonlinearity
,
HT
,
power
,
N
,
S
,
inverter
=
inverter
)
...
...
demos/poisson_demo.py
View file @
67660e26
...
...
@@ -113,7 +113,7 @@ if __name__ == "__main__":
d_domain
,
np
.
random
.
poisson
(
lam
.
local_data
).
astype
(
np
.
float64
))
# initial guess
psi0
=
ift
.
Field
.
full
(
h_domain
,
1e-7
)
psi0
=
ift
.
full
(
h_domain
,
1e-7
)
energy
=
ift
.
library
.
PoissonEnergy
(
psi0
,
data
,
R0
,
nonlin
,
HT
,
Phi_h
,
inverter
)
IC1
=
ift
.
GradientNormController
(
name
=
"IC1"
,
iteration_limit
=
200
,
...
...
demos/wiener_filter_via_hamiltonian.py
View file @
67660e26
...
...
@@ -50,7 +50,7 @@ if __name__ == "__main__":
inverter
=
ift
.
ConjugateGradient
(
controller
=
ctrl
)
controller
=
ift
.
GradientNormController
(
name
=
"min"
,
tol_abs_gradnorm
=
0.1
)
minimizer
=
ift
.
RelaxedNewton
(
controller
=
controller
)
m0
=
ift
.
Field
.
zeros
(
h_space
)
m0
=
ift
.
full
(
h_space
,
0.
)
# Initialize Wiener filter energy
energy
=
ift
.
library
.
WienerFilterEnergy
(
position
=
m0
,
d
=
d
,
R
=
R
,
N
=
N
,
S
=
S
,
...
...
nifty4/__init__.py
View file @
67660e26
...
...
@@ -8,7 +8,7 @@ from .domain_tuple import DomainTuple
from
.operators
import
*
from
.field
import
Field
,
sqrt
,
exp
,
log
from
.field
import
Field
from
.probing.utils
import
probe_with_posterior_samples
,
probe_diagonal
,
\
StatCalculator
...
...
nifty4/data_objects/distributed_do.py
View file @
67660e26
...
...
@@ -20,6 +20,7 @@ import numpy as np
from
.random
import
Random
from
mpi4py
import
MPI
import
sys
from
functools
import
reduce
_comm
=
MPI
.
COMM_WORLD
ntask
=
_comm
.
Get_size
()
...
...
@@ -145,20 +146,29 @@ class data_object(object):
def
sum
(
self
,
axis
=
None
):
return
self
.
_contraction_helper
(
"sum"
,
MPI
.
SUM
,
axis
)
def
prod
(
self
,
axis
=
None
):
return
self
.
_contraction_helper
(
"prod"
,
MPI
.
PROD
,
axis
)
def
min
(
self
,
axis
=
None
):
return
self
.
_contraction_helper
(
"min"
,
MPI
.
MIN
,
axis
)
def
max
(
self
,
axis
=
None
):
return
self
.
_contraction_helper
(
"max"
,
MPI
.
MAX
,
axis
)
def
mean
(
self
):
return
self
.
sum
()
/
self
.
size
def
mean
(
self
,
axis
=
None
):
if
axis
is
None
:
sz
=
self
.
size
else
:
sz
=
reduce
(
lambda
x
,
y
:
x
*
y
,
[
self
.
shape
[
i
]
for
i
in
axis
])
return
self
.
sum
(
axis
)
/
sz
def
std
(
self
):
return
np
.
sqrt
(
self
.
var
())
def
std
(
self
,
axis
=
None
):
return
np
.
sqrt
(
self
.
var
(
axis
))
# FIXME: to be improved!
def
var
(
self
):
def
var
(
self
,
axis
=
None
):
if
axis
is
not
None
and
len
(
axis
)
!=
len
(
self
.
shape
):
raise
ValueError
(
"functionality not yet supported"
)
return
(
abs
(
self
-
self
.
mean
())
**
2
).
mean
()
def
_binary_helper
(
self
,
other
,
op
):
...
...
nifty4/domain_tuple.py
View file @
67660e26
...
...
@@ -34,7 +34,9 @@ class DomainTuple(object):
"""
_tupleCache
=
{}
def
__init__
(
self
,
domain
):
def
__init__
(
self
,
domain
,
_callingfrommake
=
False
):
if
not
_callingfrommake
:
raise
NotImplementedError
self
.
_dom
=
self
.
_parse_domain
(
domain
)
self
.
_axtuple
=
self
.
_get_axes_tuple
()
shape_tuple
=
tuple
(
sp
.
shape
for
sp
in
self
.
_dom
)
...
...
@@ -72,7 +74,7 @@ class DomainTuple(object):
obj
=
DomainTuple
.
_tupleCache
.
get
(
domain
)
if
obj
is
not
None
:
return
obj
obj
=
DomainTuple
(
domain
)
obj
=
DomainTuple
(
domain
,
_callingfrommake
=
True
)
DomainTuple
.
_tupleCache
[
domain
]
=
obj
return
obj
...
...
nifty4/domains/domain.py
View file @
67660e26
...
...
@@ -23,6 +23,8 @@ from ..utilities import NiftyMetaBase
class
Domain
(
NiftyMetaBase
()):
"""The abstract class repesenting a (structured or unstructured) domain.
"""
def
__init__
(
self
):
self
.
_hash
=
None
@
abc
.
abstractmethod
def
__repr__
(
self
):
...
...
@@ -36,10 +38,12 @@ class Domain(NiftyMetaBase()):
Only members that are explicitly added to
:attr:`._needed_for_hash` will be used for hashing.
"""
result_hash
=
0
for
key
in
self
.
_needed_for_hash
:
result_hash
^=
hash
(
vars
(
self
)[
key
])
return
result_hash
if
self
.
_hash
is
None
:
h
=
0
for
key
in
self
.
_needed_for_hash
:
h
^=
hash
(
vars
(
self
)[
key
])
self
.
_hash
=
h
return
self
.
_hash
def
__eq__
(
self
,
x
):
"""Checks whether two domains are equal.
...
...
nifty4/domains/lm_space.py
View file @
67660e26
...
...
@@ -19,7 +19,7 @@
from
__future__
import
division
import
numpy
as
np
from
.structured_domain
import
StructuredDomain
from
..field
import
Field
,
exp
from
..field
import
Field
class
LMSpace
(
StructuredDomain
):
...
...
@@ -100,6 +100,8 @@ class LMSpace(StructuredDomain):
# cf. "All-sky convolution for polarimetry experiments"
# by Challinor et al.
# http://arxiv.org/abs/astro-ph/0008228
from
..sugar
import
exp
res
=
x
+
1.
res
*=
x
res
*=
-
0.5
*
sigma
*
sigma
...
...
nifty4/domains/rg_space.py
View file @
67660e26
...
...
@@ -21,7 +21,7 @@ from builtins import range
from
functools
import
reduce
import
numpy
as
np
from
.structured_domain
import
StructuredDomain
from
..field
import
Field
,
exp
from
..field
import
Field
from
..
import
dobj
...
...
@@ -144,6 +144,7 @@ class RGSpace(StructuredDomain):
@
staticmethod
def
_kernel
(
x
,
sigma
):
from
..sugar
import
exp
tmp
=
x
*
x
tmp
*=
-
2.
*
np
.
pi
*
np
.
pi
*
sigma
*
sigma
exp
(
tmp
,
out
=
tmp
)
...
...
nifty4/extra/operator_tests.py
View file @
67660e26
...
...
@@ -17,17 +17,26 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
import
numpy
as
np
from
..sugar
import
from_random
from
..field
import
Field
__all__
=
[
"consistency_check"
]
def
_assert_allclose
(
f1
,
f2
,
atol
,
rtol
):
if
isinstance
(
f1
,
Field
):
return
np
.
testing
.
assert_allclose
(
f1
.
local_data
,
f2
.
local_data
,
atol
=
atol
,
rtol
=
rtol
)
for
key
,
val
in
f1
.
items
():
_assert_allclose
(
val
,
f2
[
key
],
atol
=
atol
,
rtol
=
rtol
)
def
adjoint_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
):
needed_cap
=
op
.
TIMES
|
op
.
ADJOINT_TIMES
if
(
op
.
capability
&
needed_cap
)
!=
needed_cap
:
return
f1
=
Field
.
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
).
lock
()
f2
=
Field
.
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
).
lock
()
f1
=
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
).
lock
()
f2
=
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
).
lock
()
res1
=
f1
.
vdot
(
op
.
adjoint_times
(
f2
).
lock
())
res2
=
op
.
times
(
f1
).
vdot
(
f2
)
np
.
testing
.
assert_allclose
(
res1
,
res2
,
atol
=
atol
,
rtol
=
rtol
)
...
...
@@ -37,15 +46,13 @@ def inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
needed_cap
=
op
.
TIMES
|
op
.
INVERSE_TIMES
if
(
op
.
capability
&
needed_cap
)
!=
needed_cap
:
return
foo
=
Field
.
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
).
lock
()
foo
=
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
).
lock
()
res
=
op
(
op
.
inverse_times
(
foo
).
lock
())
np
.
testing
.
assert_allclose
(
res
.
to_global_data
(),
res
.
to_global_data
(),
atol
=
atol
,
rtol
=
rtol
)
_assert_allclose
(
res
,
foo
,
atol
=
atol
,
rtol
=
rtol
)
foo
=
Field
.
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
).
lock
()
foo
=
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
).
lock
()
res
=
op
.
inverse_times
(
op
(
foo
).
lock
())
np
.
testing
.
assert_allclose
(
res
.
to_global_data
(),
foo
.
to_global_data
(),
atol
=
atol
,
rtol
=
rtol
)
_assert_allclose
(
res
,
foo
,
atol
=
atol
,
rtol
=
rtol
)
def
full_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
):
...
...
nifty4/field.py
View file @
67660e26
...
...
@@ -106,62 +106,10 @@ class Field(object):
raise
TypeError
(
"val must be a scalar"
)
return
Field
(
DomainTuple
.
make
(
domain
),
val
,
dtype
)
@
staticmethod
def
ones
(
domain
,
dtype
=
None
):
return
Field
(
DomainTuple
.
make
(
domain
),
1.
,
dtype
)
@
staticmethod
def
zeros
(
domain
,
dtype
=
None
):
return
Field
(
DomainTuple
.
make
(
domain
),
0.
,
dtype
)
@
staticmethod
def
empty
(
domain
,
dtype
=
None
):
return
Field
(
DomainTuple
.
make
(
domain
),
None
,
dtype
)
@
staticmethod
def
full_like
(
field
,
val
,
dtype
=
None
):
"""Creates a Field from a template, filled with a constant value.
Parameters
----------
field : Field
the template field, from which the domain is inferred
val : float/complex/int scalar
fill value. Data type of the field is inferred from val.
Returns
-------
Field
the newly created field
"""
if
not
isinstance
(
field
,
Field
):
raise
TypeError
(
"field must be of Field type"
)
return
Field
.
full
(
field
.
_domain
,
val
,
dtype
)
@
staticmethod
def
zeros_like
(
field
,
dtype
=
None
):
if
not
isinstance
(
field
,
Field
):
raise
TypeError
(
"field must be of Field type"
)
if
dtype
is
None
:
dtype
=
field
.
dtype
return
Field
.
zeros
(
field
.
_domain
,
dtype
)
@
staticmethod
def
ones_like
(
field
,
dtype
=
None
):
if
not
isinstance
(
field
,
Field
):
raise
TypeError
(
"field must be of Field type"
)
if
dtype
is
None
:
dtype
=
field
.
dtype
return
Field
.
ones
(
field
.
_domain
,
dtype
)
@
staticmethod
def
empty_like
(
field
,
dtype
=
None
):
if
not
isinstance
(
field
,
Field
):
raise
TypeError
(
"field must be of Field type"
)
if
dtype
is
None
:
dtype
=
field
.
dtype
return
Field
.
empty
(
field
.
_domain
,
dtype
)
@
staticmethod
def
from_global_data
(
domain
,
arr
,
sum_up
=
False
):
"""Returns a Field constructed from `domain` and `arr`.
...
...
@@ -287,6 +235,7 @@ class Field(object):
The value to fill the field with.
"""
self
.
_val
.
fill
(
fill_value
)
return
self
def
lock
(
self
):
"""Write-protect the data content of `self`.
...
...
@@ -370,6 +319,17 @@ class Field(object):
"""
return
Field
(
val
=
self
,
copy
=
True
)
def
empty_copy
(
self
):
""" Returns a Field with identical domain and data type, but
uninitialized data.
Returns
-------
Field
A copy of 'self', with uninitialized data.
"""
return
Field
(
self
.
_domain
,
dtype
=
self
.
dtype
)
def
locked_copy
(
self
):
""" Returns a read-only version of the Field.
...
...
@@ -503,8 +463,8 @@ class Field(object):
or Field (for partial dot products)
"""
if
not
isinstance
(
x
,
Field
):
raise
Valu
eError
(
"The dot-partner must be an instance of "
+
"the NIFTy field class"
)
raise
Typ
eError
(
"The dot-partner must be an instance of "
+
"the NIFTy field class"
)
if
x
.
_domain
!=
self
.
_domain
:
raise
ValueError
(
"Domain mismatch"
)
...
...
@@ -694,7 +654,8 @@ class Field(object):
if
self
.
scalar_weight
(
spaces
)
is
not
None
:
return
self
.
_contraction_helper
(
'mean'
,
spaces
)
# MR FIXME: not very efficient
tmp
=
self
.
weight
(
1
)
# MR FIXME: do we need "spaces" here?
tmp
=
self
.
weight
(
1
,
spaces
)
return
tmp
.
sum
(
spaces
)
*
(
1.
/
tmp
.
total_volume
(
spaces
))
def
var
(
self
,
spaces
=
None
):
...
...
@@ -717,12 +678,10 @@ class Field(object):
# MR FIXME: not very efficient or accurate
m1
=
self
.
mean
(
spaces
)
if
np
.
issubdtype
(
self
.
dtype
,
np
.
complexfloating
):
sq
=
abs
(
self
)
**
2
m1
=
abs
(
m1
)
**
2
sq
=
abs
(
self
-
m1
)
**
2
else
:
sq
=
self
**
2
m1
**=
2
return
sq
.
mean
(
spaces
)
-
m1
sq
=
(
self
-
m1
)
**
2
return
sq
.
mean
(
spaces
)
def
std
(
self
,
spaces
=
None
):
"""Determines the standard deviation over the sub-domains given by
...
...
@@ -742,6 +701,7 @@ class Field(object):
The result of the operation. If it is carried out over the entire
domain, this is a scalar, otherwise a Field.
"""
from
.sugar
import
sqrt
if
self
.
scalar_weight
(
spaces
)
is
not
None
:
return
self
.
_contraction_helper
(
'std'
,
spaces
)
return
sqrt
(
self
.
var
(
spaces
))
...
...
@@ -785,24 +745,3 @@ for op in ["__add__", "__radd__", "__iadd__",
return
NotImplemented
return
func2
setattr
(
Field
,
op
,
func
(
op
))
# Arithmetic functions working on Fields
_current_module
=
sys
.
modules
[
__name__
]
for
f
in
[
"sqrt"
,
"exp"
,
"log"
,
"tanh"
,
"conjugate"
]:
def
func
(
f
):
def
func2
(
x
,
out
=
None
):
fu
=
getattr
(
dobj
,
f
)
if
not
isinstance
(
x
,
Field
):
raise
TypeError
(
"This function only accepts Field objects."
)
if
out
is
not
None
:
if
not
isinstance
(
out
,
Field
)
or
x
.
_domain
!=
out
.
_domain
:
raise
ValueError
(
"Bad 'out' argument"
)
fu
(
x
.
val
,
out
=
out
.
val
)
return
out
else
:
return
Field
(
domain
=
x
.
_domain
,
val
=
fu
(
x
.
val
))
return
func2
setattr
(
_current_module
,
f
,
func
(
f
))
nifty4/library/krylov_sampling.py
View file @
67660e26
...
...
@@ -54,7 +54,7 @@ def generate_krylov_samples(D_inv, S, j, N_samps, controller):
"""
# RL FIXME: make consistent with complex numbers
j
=
S
.
draw_sample
(
from_inverse
=
True
)
if
j
is
None
else
j
energy
=
QuadraticEnergy
(
j
*
0.
,
D_inv
,
j
)
energy
=
QuadraticEnergy
(
j
.
empty_copy
().
fill
(
0.
)
,
D_inv
,
j
)
y
=
[
S
.
draw_sample
()
for
_
in
range
(
N_samps
)]
status
=
controller
.
start
(
energy
)
...
...
nifty4/library/noise_energy.py
View file @
67660e26
...
...
@@ -16,7 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
..field
import
Field
,
exp
from
..field
import
Field
from
..sugar
import
exp
from
..minimization.energy
import
Energy
from
..operators.diagonal_operator
import
DiagonalOperator
import
numpy
as
np
...
...
nifty4/library/nonlinear_power_energy.py
View file @
67660e26
...
...
@@ -16,7 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
..
import
exp
from
..
sugar
import
exp
from
..minimization.energy
import
Energy
from
..operators.smoothness_operator
import
SmoothnessOperator
from
..operators.inversion_enabler
import
InversionEnabler
...
...
nifty4/library/nonlinearities.py
View file @
67660e26
...
...
@@ -16,7 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
..
field
import
Field
,
exp
,
tanh
from
..
sugar
import
full
,
exp
,
tanh
class
Linear
(
object
):
...
...
@@ -24,10 +24,10 @@ class Linear(object):
return
x
def
derivative
(
self
,
x
):
return
Field
.
ones_like
(
x
)
return
full
(
x
.
domain
,
1.
)
def
hessian
(
self
,
x
):
return
Field
.
zeros_like
(
x
)
return
full
(
x
.
domain
,
0.
)
class
Exponential
(
object
):
...
...
nifty4/library/poisson_energy.py
View file @
67660e26
...
...
@@ -20,7 +20,7 @@ from ..minimization.energy import Energy
from
..operators.diagonal_operator
import
DiagonalOperator
from
..operators.sandwich_operator
import
SandwichOperator
from
..operators.inversion_enabler
import
InversionEnabler
from
..
field
import
log
from
..
sugar
import
log
class
PoissonEnergy
(
Energy
):
...
...
nifty4/logger.py
View file @
67660e26
...
...
@@ -16,6 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
def
_logger_init
():
import
logging
from
.
import
dobj
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment