Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
On Thursday, 7th July from 1 to 3 pm there will be a maintenance with a short downtime of GitLab.
Open sidebar
Neel Shah
NIFTy
Commits
a7bc5e41
Commit
a7bc5e41
authored
May 19, 2018
by
Martin Reinecke
Browse files
step 1
parent
6fb90ba4
Changes
19
Hide whitespace changes
Inline
Side-by-side
demos/critical_filtering.py
View file @
a7bc5e41
...
...
@@ -69,8 +69,8 @@ if __name__ == "__main__":
# Creating the mock data
d
=
noiseless_data
+
n
m0
=
ift
.
Field
.
full
(
h_space
,
1e-7
)
t0
=
ift
.
Field
.
full
(
p_space
,
-
4.
)
m0
=
ift
.
full
(
h_space
,
1e-7
)
t0
=
ift
.
full
(
p_space
,
-
4.
)
power0
=
Distributor
.
times
(
ift
.
exp
(
0.5
*
t0
))
plotdict
=
{
"colormap"
:
"Planck-like"
}
...
...
demos/krylov_sampling.py
View file @
a7bc5e41
...
...
@@ -67,8 +67,8 @@ plt.legend()
plt
.
savefig
(
'Krylov_samples_residuals.png'
)
plt
.
close
()
D_hat_old
=
ift
.
Field
.
zeros
(
x_space
).
to_global_data
()
D_hat_new
=
ift
.
Field
.
zeros
(
x_space
).
to_global_data
()
D_hat_old
=
ift
.
full
(
x_space
,
0.
).
to_global_data
()
D_hat_new
=
ift
.
full
(
x_space
,
0.
).
to_global_data
()
for
i
in
range
(
N_samps
):
D_hat_old
+=
sky
(
samps_old
[
i
]).
to_global_data
()
**
2
D_hat_new
+=
sky
(
samps
[
i
]).
to_global_data
()
**
2
...
...
demos/nonlinear_critical_filter.py
View file @
a7bc5e41
...
...
@@ -69,8 +69,8 @@ if __name__ == "__main__":
# Creating the mock data
d
=
noiseless_data
+
n
m0
=
ift
.
Field
.
full
(
h_space
,
1e-7
)
t0
=
ift
.
Field
.
full
(
p_space
,
-
4.
)
m0
=
ift
.
full
(
h_space
,
1e-7
)
t0
=
ift
.
full
(
p_space
,
-
4.
)
power0
=
Distributor
.
times
(
ift
.
exp
(
0.5
*
t0
))
IC1
=
ift
.
GradientNormController
(
name
=
"IC1"
,
iteration_limit
=
100
,
...
...
demos/nonlinear_wiener_filter.py
View file @
a7bc5e41
...
...
@@ -36,7 +36,7 @@ if __name__ == "__main__":
d_space
=
R
.
target
p_op
=
ift
.
create_power_operator
(
h_space
,
p_spec
)
power
=
ift
.
sqrt
(
p_op
(
ift
.
Field
.
full
(
h_space
,
1.
)))
power
=
ift
.
sqrt
(
p_op
(
ift
.
full
(
h_space
,
1.
)))
# Creating the mock data
true_sky
=
nonlinearity
(
HT
(
power
*
sh
))
...
...
@@ -57,7 +57,7 @@ if __name__ == "__main__":
inverter
=
ift
.
ConjugateGradient
(
controller
=
ICI
)
# initial guess
m
=
ift
.
Field
.
full
(
h_space
,
1e-7
)
m
=
ift
.
full
(
h_space
,
1e-7
)
map_energy
=
ift
.
library
.
NonlinearWienerFilterEnergy
(
m
,
d
,
R
,
nonlinearity
,
HT
,
power
,
N
,
S
,
inverter
=
inverter
)
...
...
demos/poisson_demo.py
View file @
a7bc5e41
...
...
@@ -113,7 +113,7 @@ if __name__ == "__main__":
d_domain
,
np
.
random
.
poisson
(
lam
.
local_data
).
astype
(
np
.
float64
))
# initial guess
psi0
=
ift
.
Field
.
full
(
h_domain
,
1e-7
)
psi0
=
ift
.
full
(
h_domain
,
1e-7
)
energy
=
ift
.
library
.
PoissonEnergy
(
psi0
,
data
,
R0
,
nonlin
,
HT
,
Phi_h
,
inverter
)
IC1
=
ift
.
GradientNormController
(
name
=
"IC1"
,
iteration_limit
=
200
,
...
...
demos/wiener_filter_via_hamiltonian.py
View file @
a7bc5e41
...
...
@@ -50,7 +50,7 @@ if __name__ == "__main__":
inverter
=
ift
.
ConjugateGradient
(
controller
=
ctrl
)
controller
=
ift
.
GradientNormController
(
name
=
"min"
,
tol_abs_gradnorm
=
0.1
)
minimizer
=
ift
.
RelaxedNewton
(
controller
=
controller
)
m0
=
ift
.
Field
.
zeros
(
h_space
)
m0
=
ift
.
full
(
h_space
,
0.
)
# Initialize Wiener filter energy
energy
=
ift
.
library
.
WienerFilterEnergy
(
position
=
m0
,
d
=
d
,
R
=
R
,
N
=
N
,
S
=
S
,
...
...
nifty4/domain_tuple.py
View file @
a7bc5e41
...
...
@@ -34,7 +34,9 @@ class DomainTuple(object):
"""
_tupleCache
=
{}
def
__init__
(
self
,
domain
):
def
__init__
(
self
,
domain
,
_callingfrommake
=
False
):
if
not
_callingfrommake
:
raise
NotImplementedError
self
.
_dom
=
self
.
_parse_domain
(
domain
)
self
.
_axtuple
=
self
.
_get_axes_tuple
()
shape_tuple
=
tuple
(
sp
.
shape
for
sp
in
self
.
_dom
)
...
...
@@ -72,7 +74,7 @@ class DomainTuple(object):
obj
=
DomainTuple
.
_tupleCache
.
get
(
domain
)
if
obj
is
not
None
:
return
obj
obj
=
DomainTuple
(
domain
)
obj
=
DomainTuple
(
domain
,
_callingfrommake
=
True
)
DomainTuple
.
_tupleCache
[
domain
]
=
obj
return
obj
...
...
nifty4/domains/domain.py
View file @
a7bc5e41
...
...
@@ -23,6 +23,8 @@ from ..utilities import NiftyMetaBase
class
Domain
(
NiftyMetaBase
()):
"""The abstract class repesenting a (structured or unstructured) domain.
"""
def
__init__
(
self
):
self
.
_hash
=
None
@
abc
.
abstractmethod
def
__repr__
(
self
):
...
...
@@ -36,10 +38,12 @@ class Domain(NiftyMetaBase()):
Only members that are explicitly added to
:attr:`._needed_for_hash` will be used for hashing.
"""
result_hash
=
0
for
key
in
self
.
_needed_for_hash
:
result_hash
^=
hash
(
vars
(
self
)[
key
])
return
result_hash
if
self
.
_hash
is
None
:
h
=
0
for
key
in
self
.
_needed_for_hash
:
h
^=
hash
(
vars
(
self
)[
key
])
self
.
_hash
=
h
return
self
.
_hash
def
__eq__
(
self
,
x
):
"""Checks whether two domains are equal.
...
...
nifty4/extra/operator_tests.py
View file @
a7bc5e41
...
...
@@ -17,7 +17,7 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
import
numpy
as
np
from
..
field
import
Field
from
..
sugar
import
from_random
__all__
=
[
"consistency_check"
]
...
...
@@ -26,8 +26,8 @@ def adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol):
needed_cap
=
op
.
TIMES
|
op
.
ADJOINT_TIMES
if
(
op
.
capability
&
needed_cap
)
!=
needed_cap
:
return
f1
=
Field
.
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
).
lock
()
f2
=
Field
.
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
).
lock
()
f1
=
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
).
lock
()
f2
=
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
).
lock
()
res1
=
f1
.
vdot
(
op
.
adjoint_times
(
f2
).
lock
())
res2
=
op
.
times
(
f1
).
vdot
(
f2
)
np
.
testing
.
assert_allclose
(
res1
,
res2
,
atol
=
atol
,
rtol
=
rtol
)
...
...
@@ -37,12 +37,12 @@ def inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
needed_cap
=
op
.
TIMES
|
op
.
INVERSE_TIMES
if
(
op
.
capability
&
needed_cap
)
!=
needed_cap
:
return
foo
=
Field
.
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
).
lock
()
foo
=
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
).
lock
()
res
=
op
(
op
.
inverse_times
(
foo
).
lock
())
np
.
testing
.
assert_allclose
(
res
.
to_global_data
(),
res
.
to_global_data
(),
atol
=
atol
,
rtol
=
rtol
)
foo
=
Field
.
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
).
lock
()
foo
=
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
).
lock
()
res
=
op
.
inverse_times
(
op
(
foo
).
lock
())
np
.
testing
.
assert_allclose
(
res
.
to_global_data
(),
foo
.
to_global_data
(),
atol
=
atol
,
rtol
=
rtol
)
...
...
nifty4/field.py
View file @
a7bc5e41
...
...
@@ -106,62 +106,10 @@ class Field(object):
raise
TypeError
(
"val must be a scalar"
)
return
Field
(
DomainTuple
.
make
(
domain
),
val
,
dtype
)
@
staticmethod
def
ones
(
domain
,
dtype
=
None
):
return
Field
(
DomainTuple
.
make
(
domain
),
1.
,
dtype
)
@
staticmethod
def
zeros
(
domain
,
dtype
=
None
):
return
Field
(
DomainTuple
.
make
(
domain
),
0.
,
dtype
)
@
staticmethod
def
empty
(
domain
,
dtype
=
None
):
return
Field
(
DomainTuple
.
make
(
domain
),
None
,
dtype
)
@
staticmethod
def
full_like
(
field
,
val
,
dtype
=
None
):
"""Creates a Field from a template, filled with a constant value.
Parameters
----------
field : Field
the template field, from which the domain is inferred
val : float/complex/int scalar
fill value. Data type of the field is inferred from val.
Returns
-------
Field
the newly created field
"""
if
not
isinstance
(
field
,
Field
):
raise
TypeError
(
"field must be of Field type"
)
return
Field
.
full
(
field
.
_domain
,
val
,
dtype
)
@
staticmethod
def
zeros_like
(
field
,
dtype
=
None
):
if
not
isinstance
(
field
,
Field
):
raise
TypeError
(
"field must be of Field type"
)
if
dtype
is
None
:
dtype
=
field
.
dtype
return
Field
.
zeros
(
field
.
_domain
,
dtype
)
@
staticmethod
def
ones_like
(
field
,
dtype
=
None
):
if
not
isinstance
(
field
,
Field
):
raise
TypeError
(
"field must be of Field type"
)
if
dtype
is
None
:
dtype
=
field
.
dtype
return
Field
.
ones
(
field
.
_domain
,
dtype
)
@
staticmethod
def
empty_like
(
field
,
dtype
=
None
):
if
not
isinstance
(
field
,
Field
):
raise
TypeError
(
"field must be of Field type"
)
if
dtype
is
None
:
dtype
=
field
.
dtype
return
Field
.
empty
(
field
.
_domain
,
dtype
)
@
staticmethod
def
from_global_data
(
domain
,
arr
,
sum_up
=
False
):
"""Returns a Field constructed from `domain` and `arr`.
...
...
nifty4/library/nonlinearities.py
View file @
a7bc5e41
...
...
@@ -16,7 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
..field
import
Field
,
exp
,
tanh
from
..field
import
exp
,
tanh
from
..sugar
import
full
class
Linear
(
object
):
...
...
@@ -24,10 +25,10 @@ class Linear(object):
return
x
def
derivative
(
self
,
x
):
return
Field
.
ones_like
(
x
)
return
full
(
x
.
domain
,
1.
)
def
hessian
(
self
,
x
):
return
Field
.
zeros_like
(
x
)
return
full
(
x
.
domain
,
0.
)
class
Exponential
(
object
):
...
...
nifty4/multi/multi_domain.py
View file @
a7bc5e41
class
MultiDomain
(
dict
):
pass
import
collections
from
..domain_tuple
import
DomainTuple
__all
=
[
"MultiDomain"
]
class
frozendict
(
collections
.
Mapping
):
"""
An immutable wrapper around dictionaries that implements the complete :py:class:`collections.Mapping`
interface. It can be used as a drop-in replacement for dictionaries where immutability is desired.
"""
dict_cls
=
dict
def
__init__
(
self
,
*
args
,
**
kwargs
):
self
.
_dict
=
self
.
dict_cls
(
*
args
,
**
kwargs
)
self
.
_hash
=
None
def
__getitem__
(
self
,
key
):
return
self
.
_dict
[
key
]
def
__contains__
(
self
,
key
):
return
key
in
self
.
_dict
def
copy
(
self
,
**
add_or_replace
):
return
self
.
__class__
(
self
,
**
add_or_replace
)
def
__iter__
(
self
):
return
iter
(
self
.
_dict
)
def
__len__
(
self
):
return
len
(
self
.
_dict
)
def
__repr__
(
self
):
return
'<%s %r>'
%
(
self
.
__class__
.
__name__
,
self
.
_dict
)
def
__hash__
(
self
):
if
self
.
_hash
is
None
:
h
=
0
for
key
,
value
in
self
.
_dict
.
items
():
h
^=
hash
((
key
,
value
))
self
.
_hash
=
h
return
self
.
_hash
class
MultiDomain
(
frozendict
):
_domainCache
=
{}
def
__init__
(
domain
,
_callingfrommake
=
False
):
if
not
_callingfrommake
:
raise
NotImplementedError
super
(
MultiDomain
,
self
).
__init__
(
domain
)
@
staticmethod
def
make
(
domain
):
if
isinstance
(
domain
,
MultiDomain
):
return
domain
print
type
(
domain
)
if
not
isinstance
(
domain
,
dict
):
raise
TypeError
(
"dict expected"
)
tmp
=
{}
for
key
,
value
in
domain
.
items
():
if
not
isinstance
(
key
,
str
):
raise
TypeError
(
"keys must be strings"
)
tmp
[
key
]
=
DomainTuple
.
make
(
value
)
domain
=
frozendict
(
tmp
)
print
tmp
obj
=
MultiDomain
.
_domainCache
.
get
(
domain
)
if
obj
is
not
None
:
return
obj
obj
=
MultiDomain
(
domain
,
_callingfrommake
=
True
)
MultiDomain
.
_domainCache
[
domain
]
=
obj
return
obj
nifty4/multi/multi_field.py
View file @
a7bc5e41
...
...
@@ -85,21 +85,14 @@ class MultiField(object):
return
{
key
:
dtype
for
key
in
domain
.
keys
()}
@
staticmethod
def
zeros
(
domain
,
dtype
=
None
):
dtype
=
MultiField
.
build_dtype
(
dtype
,
domain
)
return
MultiField
({
key
:
Field
.
zeros
(
dom
,
dtype
=
dtype
[
key
])
for
key
,
dom
in
domain
.
items
()})
@
staticmethod
def
ones
(
domain
,
dtype
=
None
):
def
empty
(
domain
,
dtype
=
None
):
dtype
=
MultiField
.
build_dtype
(
dtype
,
domain
)
return
MultiField
({
key
:
Field
.
ones
(
dom
,
dtype
=
dtype
[
key
])
return
MultiField
({
key
:
Field
.
empty
(
dom
,
dtype
=
dtype
[
key
])
for
key
,
dom
in
domain
.
items
()})
@
staticmethod
def
empty
(
domain
,
dtype
=
None
):
dtype
=
MultiField
.
build_dtype
(
dtype
,
domain
)
return
MultiField
({
key
:
Field
.
empty
(
dom
,
dtype
=
dtype
[
key
])
def
full
(
domain
,
val
):
return
MultiField
({
key
:
Field
.
full
(
dom
,
val
)
for
key
,
dom
in
domain
.
items
()})
def
norm
(
self
):
...
...
nifty4/operators/linear_operator.py
View file @
a7bc5e41
...
...
@@ -271,10 +271,6 @@ class LinearOperator(NiftyMetaBase()):
raise
ValueError
(
"requested operator mode is not supported"
)
def
_check_input
(
self
,
x
,
mode
):
# MR FIXME: temporary fix for working with MultiFields
#if not isinstance(x, Field):
# raise ValueError("supplied object is not a `Field`.")
self
.
_check_mode
(
mode
)
if
x
.
domain
!=
self
.
_dom
(
mode
):
raise
ValueError
(
"The operator's and field's domains don't match."
)
nifty4/operators/scaling_operator.py
View file @
a7bc5e41
...
...
@@ -62,7 +62,7 @@ class ScalingOperator(EndomorphicOperator):
if
self
.
_factor
==
1.
:
return
x
.
copy
()
if
self
.
_factor
==
0.
:
return
x
.
zeros_like
(
x
)
return
x
*
0.
if
mode
==
self
.
TIMES
:
return
x
*
self
.
_factor
...
...
nifty4/sugar.py
View file @
a7bc5e41
...
...
@@ -19,16 +19,18 @@
import
numpy
as
np
from
.domains.power_space
import
PowerSpace
from
.field
import
Field
from
multi.multi_field
import
MultiField
from
multi.multi_domain
import
MultiDomain
from
.operators.diagonal_operator
import
DiagonalOperator
from
.operators.power_distributor
import
PowerDistributor
from
.domain_tuple
import
DomainTuple
from
.
import
dobj
,
utilities
from
.logger
import
logger
__all__
=
[
'PS_field'
,
'
power_analyze
'
,
'
create_power_operator
'
,
'
create_harmonic_smoothing_operator
'
]
__all__
=
[
'PS_field'
,
'power_analyze'
,
'create_power_operator'
,
'
create_harmonic_smoothing_operator'
,
'from_random
'
,
'
full'
,
'empty'
,
'from_global_data'
,
'from_local_data
'
,
'
makeDomain
'
]
def
PS_field
(
pspace
,
func
):
...
...
@@ -161,3 +163,39 @@ def create_harmonic_smoothing_operator(domain, space, sigma):
kfunc
=
domain
[
space
].
get_fft_smoothing_kernel_function
(
sigma
)
return
DiagonalOperator
(
kfunc
(
domain
[
space
].
get_k_length_array
()),
domain
,
space
)
def
full
(
domain
,
val
):
if
isinstance
(
domain
,
(
dict
,
MultiDomain
)):
return
MultiField
.
full
(
domain
,
val
)
return
Field
.
full
(
domain
,
val
)
def
empty
(
domain
,
dtype
):
if
isinstance
(
domain
,
(
dict
,
MultiDomain
)):
return
MultiField
.
empty
(
domain
,
dtype
)
return
Field
.
empty
(
domain
,
dtype
)
def
from_random
(
random_type
,
domain
,
dtype
=
np
.
float64
,
**
kwargs
):
if
isinstance
(
domain
,
(
dict
,
MultiDomain
)):
return
MultiField
.
from_random
(
random_type
,
domain
,
dtype
,
**
kwargs
)
return
Field
.
from_random
(
random_type
,
domain
,
dtype
,
**
kwargs
)
def
from_global_data
(
domain
,
arr
,
sum_up
=
False
):
if
isinstance
(
domain
,
(
dict
,
MultiDomain
)):
return
MultiField
.
from_global_data
(
domain
,
arr
,
sum_up
)
return
Field
.
from_global_data
(
domain
,
arr
,
sum_up
)
def
from_local_data
(
domain
,
arr
):
if
isinstance
(
domain
,
(
dict
,
MultiDomain
)):
return
MultiField
.
from_local_data
(
domain
,
arr
)
return
Field
.
from_local_data
(
domain
,
arr
)
def
makeDomain
(
domain
):
if
isinstance
(
domain
,
dict
):
return
MultiDomain
.
make
(
domain
)
return
DomainTuple
.
make
(
domain
)
test/test_energies/test_power.py
View file @
a7bc5e41
...
...
@@ -50,7 +50,7 @@ class Energy_Tests(unittest.TestCase):
n
=
ift
.
Field
.
from_random
(
domain
=
space
,
random_type
=
'normal'
)
s
=
ht
(
xi
*
A
)
R
=
ift
.
ScalingOperator
(
10.
,
space
)
diag
=
ift
.
Field
.
ones
(
space
)
diag
=
ift
.
full
(
space
,
1.
)
N
=
ift
.
DiagonalOperator
(
diag
)
d
=
R
(
f
(
s
))
+
n
...
...
test/test_field.py
View file @
a7bc5e41
...
...
@@ -130,18 +130,18 @@ class Test_Functionality(unittest.TestCase):
assert_equal
(
f
.
local_data
,
27
)
assert_equal
(
f
.
shape
,
(
200
,))
assert_equal
(
f
.
dtype
,
np
.
int
)
fx
=
ift
.
Field
.
empty_like
(
f
)
fx
=
ift
.
empty
(
f
.
domain
,
f
.
dtype
)
assert_equal
(
f
.
dtype
,
fx
.
dtype
)
assert_equal
(
f
.
shape
,
fx
.
shape
)
fx
=
ift
.
Field
.
zeros_like
(
f
)
fx
=
ift
.
full
(
f
.
domain
,
0
)
assert_equal
(
f
.
dtype
,
fx
.
dtype
)
assert_equal
(
f
.
shape
,
fx
.
shape
)
assert_equal
(
fx
.
local_data
,
0
)
fx
=
ift
.
Field
.
ones_like
(
f
)
fx
=
ift
.
full
(
f
.
domain
,
1
)
assert_equal
(
f
.
dtype
,
fx
.
dtype
)
assert_equal
(
f
.
shape
,
fx
.
shape
)
assert_equal
(
fx
.
local_data
,
1
)
fx
=
ift
.
Field
.
full_like
(
f
,
67.
)
fx
=
ift
.
full
(
f
.
domain
,
67.
)
assert_equal
(
f
.
shape
,
fx
.
shape
)
assert_equal
(
fx
.
local_data
,
67.
)
f
=
ift
.
Field
.
from_random
(
"normal"
,
s
)
...
...
test/test_minimization/test_minimizers.py
View file @
a7bc5e41
...
...
@@ -53,7 +53,7 @@ class Test_Minimizers(unittest.TestCase):
covariance_diagonal
=
ift
.
Field
.
from_random
(
'uniform'
,
domain
=
space
)
+
0.5
covariance
=
ift
.
DiagonalOperator
(
covariance_diagonal
)
required_result
=
ift
.
Field
.
ones
(
space
,
dtype
=
np
.
float64
)
required_result
=
ift
.
full
(
space
,
1.
)
try
:
minimizer
=
eval
(
minimizer
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment