Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
N
NIFTy
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
10
Issues
10
List
Boards
Labels
Service Desk
Milestones
Merge Requests
9
Merge Requests
9
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Operations
Operations
Incidents
Environments
Packages & Registries
Packages & Registries
Container Registry
Analytics
Analytics
CI / CD
Repository
Value Stream
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ift
NIFTy
Commits
67660e26
Commit
67660e26
authored
May 22, 2018
by
Martin Reinecke
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'static_restructure' into 'NIFTy_4'
Static restructure See merge request ift/NIFTy!259
parents
096f619e
36640cc0
Pipeline
#29556
passed with stages
in 4 minutes and 45 seconds
Changes
34
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
34 changed files
with
554 additions
and
151 deletions
+554
-151
demos/critical_filtering.py
demos/critical_filtering.py
+2
-2
demos/krylov_sampling.py
demos/krylov_sampling.py
+2
-2
demos/nonlinear_critical_filter.py
demos/nonlinear_critical_filter.py
+2
-2
demos/nonlinear_wiener_filter.py
demos/nonlinear_wiener_filter.py
+2
-2
demos/poisson_demo.py
demos/poisson_demo.py
+1
-1
demos/wiener_filter_via_hamiltonian.py
demos/wiener_filter_via_hamiltonian.py
+1
-1
nifty4/__init__.py
nifty4/__init__.py
+1
-1
nifty4/data_objects/distributed_do.py
nifty4/data_objects/distributed_do.py
+15
-5
nifty4/domain_tuple.py
nifty4/domain_tuple.py
+4
-2
nifty4/domains/domain.py
nifty4/domains/domain.py
+8
-4
nifty4/domains/lm_space.py
nifty4/domains/lm_space.py
+3
-1
nifty4/domains/rg_space.py
nifty4/domains/rg_space.py
+2
-1
nifty4/extra/operator_tests.py
nifty4/extra/operator_tests.py
+15
-8
nifty4/field.py
nifty4/field.py
+20
-81
nifty4/library/krylov_sampling.py
nifty4/library/krylov_sampling.py
+1
-1
nifty4/library/noise_energy.py
nifty4/library/noise_energy.py
+2
-1
nifty4/library/nonlinear_power_energy.py
nifty4/library/nonlinear_power_energy.py
+1
-1
nifty4/library/nonlinearities.py
nifty4/library/nonlinearities.py
+3
-3
nifty4/library/poisson_energy.py
nifty4/library/poisson_energy.py
+1
-1
nifty4/logger.py
nifty4/logger.py
+1
-0
nifty4/multi/__init__.py
nifty4/multi/__init__.py
+2
-1
nifty4/multi/block_diagonal_operator.py
nifty4/multi/block_diagonal_operator.py
+55
-0
nifty4/multi/multi_domain.py
nifty4/multi/multi_domain.py
+72
-2
nifty4/multi/multi_field.py
nifty4/multi/multi_field.py
+38
-10
nifty4/operators/chain_operator.py
nifty4/operators/chain_operator.py
+11
-0
nifty4/operators/inversion_enabler.py
nifty4/operators/inversion_enabler.py
+1
-1
nifty4/operators/linear_operator.py
nifty4/operators/linear_operator.py
+0
-4
nifty4/operators/scaling_operator.py
nifty4/operators/scaling_operator.py
+1
-1
nifty4/operators/sum_operator.py
nifty4/operators/sum_operator.py
+22
-0
nifty4/sugar.py
nifty4/sugar.py
+74
-4
test/test_energies/test_power.py
test/test_energies/test_power.py
+1
-1
test/test_field.py
test/test_field.py
+121
-5
test/test_minimization/test_minimizers.py
test/test_minimization/test_minimizers.py
+2
-2
test/test_multi_field.py
test/test_multi_field.py
+67
-0
No files found.
demos/critical_filtering.py
View file @
67660e26
...
...
@@ -69,8 +69,8 @@ if __name__ == "__main__":
# Creating the mock data
d
=
noiseless_data
+
n
m0
=
ift
.
Field
.
full
(
h_space
,
1e-7
)
t0
=
ift
.
Field
.
full
(
p_space
,
-
4.
)
m0
=
ift
.
full
(
h_space
,
1e-7
)
t0
=
ift
.
full
(
p_space
,
-
4.
)
power0
=
Distributor
.
times
(
ift
.
exp
(
0.5
*
t0
))
plotdict
=
{
"colormap"
:
"Planck-like"
}
...
...
demos/krylov_sampling.py
View file @
67660e26
...
...
@@ -67,8 +67,8 @@ plt.legend()
plt
.
savefig
(
'Krylov_samples_residuals.png'
)
plt
.
close
()
D_hat_old
=
ift
.
Field
.
zeros
(
x_space
).
to_global_data
()
D_hat_new
=
ift
.
Field
.
zeros
(
x_space
).
to_global_data
()
D_hat_old
=
ift
.
full
(
x_space
,
0.
).
to_global_data
()
D_hat_new
=
ift
.
full
(
x_space
,
0.
).
to_global_data
()
for
i
in
range
(
N_samps
):
D_hat_old
+=
sky
(
samps_old
[
i
]).
to_global_data
()
**
2
D_hat_new
+=
sky
(
samps
[
i
]).
to_global_data
()
**
2
...
...
demos/nonlinear_critical_filter.py
View file @
67660e26
...
...
@@ -69,8 +69,8 @@ if __name__ == "__main__":
# Creating the mock data
d
=
noiseless_data
+
n
m0
=
ift
.
Field
.
full
(
h_space
,
1e-7
)
t0
=
ift
.
Field
.
full
(
p_space
,
-
4.
)
m0
=
ift
.
full
(
h_space
,
1e-7
)
t0
=
ift
.
full
(
p_space
,
-
4.
)
power0
=
Distributor
.
times
(
ift
.
exp
(
0.5
*
t0
))
IC1
=
ift
.
GradientNormController
(
name
=
"IC1"
,
iteration_limit
=
100
,
...
...
demos/nonlinear_wiener_filter.py
View file @
67660e26
...
...
@@ -36,7 +36,7 @@ if __name__ == "__main__":
d_space
=
R
.
target
p_op
=
ift
.
create_power_operator
(
h_space
,
p_spec
)
power
=
ift
.
sqrt
(
p_op
(
ift
.
Field
.
full
(
h_space
,
1.
)))
power
=
ift
.
sqrt
(
p_op
(
ift
.
full
(
h_space
,
1.
)))
# Creating the mock data
true_sky
=
nonlinearity
(
HT
(
power
*
sh
))
...
...
@@ -57,7 +57,7 @@ if __name__ == "__main__":
inverter
=
ift
.
ConjugateGradient
(
controller
=
ICI
)
# initial guess
m
=
ift
.
Field
.
full
(
h_space
,
1e-7
)
m
=
ift
.
full
(
h_space
,
1e-7
)
map_energy
=
ift
.
library
.
NonlinearWienerFilterEnergy
(
m
,
d
,
R
,
nonlinearity
,
HT
,
power
,
N
,
S
,
inverter
=
inverter
)
...
...
demos/poisson_demo.py
View file @
67660e26
...
...
@@ -113,7 +113,7 @@ if __name__ == "__main__":
d_domain
,
np
.
random
.
poisson
(
lam
.
local_data
).
astype
(
np
.
float64
))
# initial guess
psi0
=
ift
.
Field
.
full
(
h_domain
,
1e-7
)
psi0
=
ift
.
full
(
h_domain
,
1e-7
)
energy
=
ift
.
library
.
PoissonEnergy
(
psi0
,
data
,
R0
,
nonlin
,
HT
,
Phi_h
,
inverter
)
IC1
=
ift
.
GradientNormController
(
name
=
"IC1"
,
iteration_limit
=
200
,
...
...
demos/wiener_filter_via_hamiltonian.py
View file @
67660e26
...
...
@@ -50,7 +50,7 @@ if __name__ == "__main__":
inverter
=
ift
.
ConjugateGradient
(
controller
=
ctrl
)
controller
=
ift
.
GradientNormController
(
name
=
"min"
,
tol_abs_gradnorm
=
0.1
)
minimizer
=
ift
.
RelaxedNewton
(
controller
=
controller
)
m0
=
ift
.
Field
.
zeros
(
h_space
)
m0
=
ift
.
full
(
h_space
,
0.
)
# Initialize Wiener filter energy
energy
=
ift
.
library
.
WienerFilterEnergy
(
position
=
m0
,
d
=
d
,
R
=
R
,
N
=
N
,
S
=
S
,
...
...
nifty4/__init__.py
View file @
67660e26
...
...
@@ -8,7 +8,7 @@ from .domain_tuple import DomainTuple
from
.operators
import
*
from
.field
import
Field
,
sqrt
,
exp
,
log
from
.field
import
Field
from
.probing.utils
import
probe_with_posterior_samples
,
probe_diagonal
,
\
StatCalculator
...
...
nifty4/data_objects/distributed_do.py
View file @
67660e26
...
...
@@ -20,6 +20,7 @@ import numpy as np
from
.random
import
Random
from
mpi4py
import
MPI
import
sys
from
functools
import
reduce
_comm
=
MPI
.
COMM_WORLD
ntask
=
_comm
.
Get_size
()
...
...
@@ -145,20 +146,29 @@ class data_object(object):
def
sum
(
self
,
axis
=
None
):
return
self
.
_contraction_helper
(
"sum"
,
MPI
.
SUM
,
axis
)
def
prod
(
self
,
axis
=
None
):
return
self
.
_contraction_helper
(
"prod"
,
MPI
.
PROD
,
axis
)
def
min
(
self
,
axis
=
None
):
return
self
.
_contraction_helper
(
"min"
,
MPI
.
MIN
,
axis
)
def
max
(
self
,
axis
=
None
):
return
self
.
_contraction_helper
(
"max"
,
MPI
.
MAX
,
axis
)
def
mean
(
self
):
return
self
.
sum
()
/
self
.
size
def
mean
(
self
,
axis
=
None
):
if
axis
is
None
:
sz
=
self
.
size
else
:
sz
=
reduce
(
lambda
x
,
y
:
x
*
y
,
[
self
.
shape
[
i
]
for
i
in
axis
])
return
self
.
sum
(
axis
)
/
sz
def
std
(
self
):
return
np
.
sqrt
(
self
.
var
())
def
std
(
self
,
axis
=
None
):
return
np
.
sqrt
(
self
.
var
(
axis
))
# FIXME: to be improved!
def
var
(
self
):
def
var
(
self
,
axis
=
None
):
if
axis
is
not
None
and
len
(
axis
)
!=
len
(
self
.
shape
):
raise
ValueError
(
"functionality not yet supported"
)
return
(
abs
(
self
-
self
.
mean
())
**
2
).
mean
()
def
_binary_helper
(
self
,
other
,
op
):
...
...
nifty4/domain_tuple.py
View file @
67660e26
...
...
@@ -34,7 +34,9 @@ class DomainTuple(object):
"""
_tupleCache
=
{}
def
__init__
(
self
,
domain
):
def
__init__
(
self
,
domain
,
_callingfrommake
=
False
):
if
not
_callingfrommake
:
raise
NotImplementedError
self
.
_dom
=
self
.
_parse_domain
(
domain
)
self
.
_axtuple
=
self
.
_get_axes_tuple
()
shape_tuple
=
tuple
(
sp
.
shape
for
sp
in
self
.
_dom
)
...
...
@@ -72,7 +74,7 @@ class DomainTuple(object):
obj
=
DomainTuple
.
_tupleCache
.
get
(
domain
)
if
obj
is
not
None
:
return
obj
obj
=
DomainTuple
(
domain
)
obj
=
DomainTuple
(
domain
,
_callingfrommake
=
True
)
DomainTuple
.
_tupleCache
[
domain
]
=
obj
return
obj
...
...
nifty4/domains/domain.py
View file @
67660e26
...
...
@@ -23,6 +23,8 @@ from ..utilities import NiftyMetaBase
class
Domain
(
NiftyMetaBase
()):
"""The abstract class repesenting a (structured or unstructured) domain.
"""
def
__init__
(
self
):
self
.
_hash
=
None
@
abc
.
abstractmethod
def
__repr__
(
self
):
...
...
@@ -36,10 +38,12 @@ class Domain(NiftyMetaBase()):
Only members that are explicitly added to
:attr:`._needed_for_hash` will be used for hashing.
"""
result_hash
=
0
for
key
in
self
.
_needed_for_hash
:
result_hash
^=
hash
(
vars
(
self
)[
key
])
return
result_hash
if
self
.
_hash
is
None
:
h
=
0
for
key
in
self
.
_needed_for_hash
:
h
^=
hash
(
vars
(
self
)[
key
])
self
.
_hash
=
h
return
self
.
_hash
def
__eq__
(
self
,
x
):
"""Checks whether two domains are equal.
...
...
nifty4/domains/lm_space.py
View file @
67660e26
...
...
@@ -19,7 +19,7 @@
from
__future__
import
division
import
numpy
as
np
from
.structured_domain
import
StructuredDomain
from
..field
import
Field
,
exp
from
..field
import
Field
class
LMSpace
(
StructuredDomain
):
...
...
@@ -100,6 +100,8 @@ class LMSpace(StructuredDomain):
# cf. "All-sky convolution for polarimetry experiments"
# by Challinor et al.
# http://arxiv.org/abs/astro-ph/0008228
from
..sugar
import
exp
res
=
x
+
1.
res
*=
x
res
*=
-
0.5
*
sigma
*
sigma
...
...
nifty4/domains/rg_space.py
View file @
67660e26
...
...
@@ -21,7 +21,7 @@ from builtins import range
from
functools
import
reduce
import
numpy
as
np
from
.structured_domain
import
StructuredDomain
from
..field
import
Field
,
exp
from
..field
import
Field
from
..
import
dobj
...
...
@@ -144,6 +144,7 @@ class RGSpace(StructuredDomain):
@
staticmethod
def
_kernel
(
x
,
sigma
):
from
..sugar
import
exp
tmp
=
x
*
x
tmp
*=
-
2.
*
np
.
pi
*
np
.
pi
*
sigma
*
sigma
exp
(
tmp
,
out
=
tmp
)
...
...
nifty4/extra/operator_tests.py
View file @
67660e26
...
...
@@ -17,17 +17,26 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
import
numpy
as
np
from
..sugar
import
from_random
from
..field
import
Field
__all__
=
[
"consistency_check"
]
def
_assert_allclose
(
f1
,
f2
,
atol
,
rtol
):
if
isinstance
(
f1
,
Field
):
return
np
.
testing
.
assert_allclose
(
f1
.
local_data
,
f2
.
local_data
,
atol
=
atol
,
rtol
=
rtol
)
for
key
,
val
in
f1
.
items
():
_assert_allclose
(
val
,
f2
[
key
],
atol
=
atol
,
rtol
=
rtol
)
def
adjoint_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
):
needed_cap
=
op
.
TIMES
|
op
.
ADJOINT_TIMES
if
(
op
.
capability
&
needed_cap
)
!=
needed_cap
:
return
f1
=
Field
.
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
).
lock
()
f2
=
Field
.
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
).
lock
()
f1
=
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
).
lock
()
f2
=
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
).
lock
()
res1
=
f1
.
vdot
(
op
.
adjoint_times
(
f2
).
lock
())
res2
=
op
.
times
(
f1
).
vdot
(
f2
)
np
.
testing
.
assert_allclose
(
res1
,
res2
,
atol
=
atol
,
rtol
=
rtol
)
...
...
@@ -37,15 +46,13 @@ def inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
needed_cap
=
op
.
TIMES
|
op
.
INVERSE_TIMES
if
(
op
.
capability
&
needed_cap
)
!=
needed_cap
:
return
foo
=
Field
.
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
).
lock
()
foo
=
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
).
lock
()
res
=
op
(
op
.
inverse_times
(
foo
).
lock
())
np
.
testing
.
assert_allclose
(
res
.
to_global_data
(),
res
.
to_global_data
(),
atol
=
atol
,
rtol
=
rtol
)
_assert_allclose
(
res
,
foo
,
atol
=
atol
,
rtol
=
rtol
)
foo
=
Field
.
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
).
lock
()
foo
=
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
).
lock
()
res
=
op
.
inverse_times
(
op
(
foo
).
lock
())
np
.
testing
.
assert_allclose
(
res
.
to_global_data
(),
foo
.
to_global_data
(),
atol
=
atol
,
rtol
=
rtol
)
_assert_allclose
(
res
,
foo
,
atol
=
atol
,
rtol
=
rtol
)
def
full_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
):
...
...
nifty4/field.py
View file @
67660e26
...
...
@@ -106,62 +106,10 @@ class Field(object):
raise
TypeError
(
"val must be a scalar"
)
return
Field
(
DomainTuple
.
make
(
domain
),
val
,
dtype
)
@
staticmethod
def
ones
(
domain
,
dtype
=
None
):
return
Field
(
DomainTuple
.
make
(
domain
),
1.
,
dtype
)
@
staticmethod
def
zeros
(
domain
,
dtype
=
None
):
return
Field
(
DomainTuple
.
make
(
domain
),
0.
,
dtype
)
@
staticmethod
def
empty
(
domain
,
dtype
=
None
):
return
Field
(
DomainTuple
.
make
(
domain
),
None
,
dtype
)
@
staticmethod
def
full_like
(
field
,
val
,
dtype
=
None
):
"""Creates a Field from a template, filled with a constant value.
Parameters
----------
field : Field
the template field, from which the domain is inferred
val : float/complex/int scalar
fill value. Data type of the field is inferred from val.
Returns
-------
Field
the newly created field
"""
if
not
isinstance
(
field
,
Field
):
raise
TypeError
(
"field must be of Field type"
)
return
Field
.
full
(
field
.
_domain
,
val
,
dtype
)
@
staticmethod
def
zeros_like
(
field
,
dtype
=
None
):
if
not
isinstance
(
field
,
Field
):
raise
TypeError
(
"field must be of Field type"
)
if
dtype
is
None
:
dtype
=
field
.
dtype
return
Field
.
zeros
(
field
.
_domain
,
dtype
)
@
staticmethod
def
ones_like
(
field
,
dtype
=
None
):
if
not
isinstance
(
field
,
Field
):
raise
TypeError
(
"field must be of Field type"
)
if
dtype
is
None
:
dtype
=
field
.
dtype
return
Field
.
ones
(
field
.
_domain
,
dtype
)
@
staticmethod
def
empty_like
(
field
,
dtype
=
None
):
if
not
isinstance
(
field
,
Field
):
raise
TypeError
(
"field must be of Field type"
)
if
dtype
is
None
:
dtype
=
field
.
dtype
return
Field
.
empty
(
field
.
_domain
,
dtype
)
@
staticmethod
def
from_global_data
(
domain
,
arr
,
sum_up
=
False
):
"""Returns a Field constructed from `domain` and `arr`.
...
...
@@ -287,6 +235,7 @@ class Field(object):
The value to fill the field with.
"""
self
.
_val
.
fill
(
fill_value
)
return
self
def
lock
(
self
):
"""Write-protect the data content of `self`.
...
...
@@ -370,6 +319,17 @@ class Field(object):
"""
return
Field
(
val
=
self
,
copy
=
True
)
def
empty_copy
(
self
):
""" Returns a Field with identical domain and data type, but
uninitialized data.
Returns
-------
Field
A copy of 'self', with uninitialized data.
"""
return
Field
(
self
.
_domain
,
dtype
=
self
.
dtype
)
def
locked_copy
(
self
):
""" Returns a read-only version of the Field.
...
...
@@ -503,8 +463,8 @@ class Field(object):
or Field (for partial dot products)
"""
if
not
isinstance
(
x
,
Field
):
raise
Valu
eError
(
"The dot-partner must be an instance of "
+
"the NIFTy field class"
)
raise
Typ
eError
(
"The dot-partner must be an instance of "
+
"the NIFTy field class"
)
if
x
.
_domain
!=
self
.
_domain
:
raise
ValueError
(
"Domain mismatch"
)
...
...
@@ -694,7 +654,8 @@ class Field(object):
if
self
.
scalar_weight
(
spaces
)
is
not
None
:
return
self
.
_contraction_helper
(
'mean'
,
spaces
)
# MR FIXME: not very efficient
tmp
=
self
.
weight
(
1
)
# MR FIXME: do we need "spaces" here?
tmp
=
self
.
weight
(
1
,
spaces
)
return
tmp
.
sum
(
spaces
)
*
(
1.
/
tmp
.
total_volume
(
spaces
))
def
var
(
self
,
spaces
=
None
):
...
...
@@ -717,12 +678,10 @@ class Field(object):
# MR FIXME: not very efficient or accurate
m1
=
self
.
mean
(
spaces
)
if
np
.
issubdtype
(
self
.
dtype
,
np
.
complexfloating
):
sq
=
abs
(
self
)
**
2
m1
=
abs
(
m1
)
**
2
sq
=
abs
(
self
-
m1
)
**
2
else
:
sq
=
self
**
2
m1
**=
2
return
sq
.
mean
(
spaces
)
-
m1
sq
=
(
self
-
m1
)
**
2
return
sq
.
mean
(
spaces
)
def
std
(
self
,
spaces
=
None
):
"""Determines the standard deviation over the sub-domains given by
...
...
@@ -742,6 +701,7 @@ class Field(object):
The result of the operation. If it is carried out over the entire
domain, this is a scalar, otherwise a Field.
"""
from
.sugar
import
sqrt
if
self
.
scalar_weight
(
spaces
)
is
not
None
:
return
self
.
_contraction_helper
(
'std'
,
spaces
)
return
sqrt
(
self
.
var
(
spaces
))
...
...
@@ -785,24 +745,3 @@ for op in ["__add__", "__radd__", "__iadd__",
return
NotImplemented
return
func2
setattr
(
Field
,
op
,
func
(
op
))
# Arithmetic functions working on Fields
_current_module
=
sys
.
modules
[
__name__
]
for
f
in
[
"sqrt"
,
"exp"
,
"log"
,
"tanh"
,
"conjugate"
]:
def
func
(
f
):
def
func2
(
x
,
out
=
None
):
fu
=
getattr
(
dobj
,
f
)
if
not
isinstance
(
x
,
Field
):
raise
TypeError
(
"This function only accepts Field objects."
)
if
out
is
not
None
:
if
not
isinstance
(
out
,
Field
)
or
x
.
_domain
!=
out
.
_domain
:
raise
ValueError
(
"Bad 'out' argument"
)
fu
(
x
.
val
,
out
=
out
.
val
)
return
out
else
:
return
Field
(
domain
=
x
.
_domain
,
val
=
fu
(
x
.
val
))
return
func2
setattr
(
_current_module
,
f
,
func
(
f
))
nifty4/library/krylov_sampling.py
View file @
67660e26
...
...
@@ -54,7 +54,7 @@ def generate_krylov_samples(D_inv, S, j, N_samps, controller):
"""
# RL FIXME: make consistent with complex numbers
j
=
S
.
draw_sample
(
from_inverse
=
True
)
if
j
is
None
else
j
energy
=
QuadraticEnergy
(
j
*
0.
,
D_inv
,
j
)
energy
=
QuadraticEnergy
(
j
.
empty_copy
().
fill
(
0.
)
,
D_inv
,
j
)
y
=
[
S
.
draw_sample
()
for
_
in
range
(
N_samps
)]
status
=
controller
.
start
(
energy
)
...
...
nifty4/library/noise_energy.py
View file @
67660e26
...
...
@@ -16,7 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
..field
import
Field
,
exp
from
..field
import
Field
from
..sugar
import
exp
from
..minimization.energy
import
Energy
from
..operators.diagonal_operator
import
DiagonalOperator
import
numpy
as
np
...
...
nifty4/library/nonlinear_power_energy.py
View file @
67660e26
...
...
@@ -16,7 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
..
import
exp
from
..
sugar
import
exp
from
..minimization.energy
import
Energy
from
..operators.smoothness_operator
import
SmoothnessOperator
from
..operators.inversion_enabler
import
InversionEnabler
...
...
nifty4/library/nonlinearities.py
View file @
67660e26
...
...
@@ -16,7 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
..
field
import
Field
,
exp
,
tanh
from
..
sugar
import
full
,
exp
,
tanh
class
Linear
(
object
):
...
...
@@ -24,10 +24,10 @@ class Linear(object):
return
x
def
derivative
(
self
,
x
):
return
Field
.
ones_like
(
x
)
return
full
(
x
.
domain
,
1.
)
def
hessian
(
self
,
x
):
return
Field
.
zeros_like
(
x
)
return
full
(
x
.
domain
,
0.
)
class
Exponential
(
object
):
...
...
nifty4/library/poisson_energy.py
View file @
67660e26
...
...
@@ -20,7 +20,7 @@ from ..minimization.energy import Energy
from
..operators.diagonal_operator
import
DiagonalOperator
from
..operators.sandwich_operator
import
SandwichOperator
from
..operators.inversion_enabler
import
InversionEnabler
from
..
field
import
log
from
..
sugar
import
log
class
PoissonEnergy
(
Energy
):
...
...
nifty4/logger.py
View file @
67660e26
...
...
@@ -16,6 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
def
_logger_init
():
import
logging
from
.
import
dobj
...
...
nifty4/multi/__init__.py
View file @
67660e26
from
.multi_domain
import
MultiDomain
from
.multi_field
import
MultiField
from
.block_diagonal_operator
import
BlockDiagonalOperator
__all__
=
[
"MultiDomain"
,
"MultiField"
]
__all__
=
[
"MultiDomain"
,
"MultiField"
,
"BlockDiagonalOperator"
]
nifty4/multi/block_diagonal_operator.py
0 → 100644
View file @
67660e26
import
numpy
as
np
from
..operators.endomorphic_operator
import
EndomorphicOperator
from
.multi_domain
import
MultiDomain
from
.multi_field
import
MultiField
class
BlockDiagonalOperator
(
EndomorphicOperator
):
def
__init__
(
self
,
operators
):
"""
Parameters
----------
operators : dict
dictionary with operators domain names as keys and
LinearOperators as items
"""
super
(
BlockDiagonalOperator
,
self
).
__init__
()
self
.
_operators
=
operators
self
.
_domain
=
MultiDomain
.
make
(
{
key
:
op
.
domain
for
key
,
op
in
self
.
_operators
.
items
()})
self
.
_cap
=
self
.
_all_ops
for
op
in
self
.
_operators
.
values
():
self
.
_cap
&=
op
.
capability
@
property
def
domain
(
self
):
return
self
.
_domain
@
property
def
capability
(
self
):
return
self
.
_cap
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
return
MultiField
({
key
:
op
.
apply
(
x
[
key
],
mode
=
mode
)
for
key
,
op
in
self
.
_operators
.
items
()})
def
draw_sample
(
self
,
from_inverse
=
False
,
dtype
=
np
.
float64
):
dtype
=
MultiField
.
build_dtype
(
dtype
,
self
.
_domain
)
return
MultiField
({
key
:
op
.
draw_sample
(
from_inverse
,
dtype
[
key
])
for
key
,
op
in
self
.
_operators
.
items
()})
def
_combine_chain
(
self
,
op
):
res
=
{}
for
key
in
self
.
_operators
.
keys
():
res
[
key
]
=
self
.
_operators
[
key
]
*
op
.
_operators
[
key
]
return
BlockDiagonalOperator
(
res
)
def
_combine_sum
(
self
,
op
,
selfneg
,
opneg
):
from
..operators.sum_operator
import
SumOperator
res
=
{}
for
key
in
self
.
_operators
.
keys
():
res
[
key
]
=
SumOperator
.
make
([
self
.
_operators
[
key
],
op
.
_operators
[
key
]],
[
selfneg
,
opneg
])
return
BlockDiagonalOperator
(
res
)
nifty4/multi/multi_domain.py
View file @
67660e26
class
MultiDomain
(
dict
):
pass
import
collections
from
..domain_tuple
import
DomainTuple
__all
=
[
"MultiDomain"
]
class
frozendict
(
collections
.
Mapping
):
"""
An immutable wrapper around dictionaries that implements the complete
:py:class:`collections.Mapping` interface. It can be used as a drop-in
replacement for dictionaries where immutability is desired.
"""
dict_cls
=
dict
def
__init__
(
self
,
*
args
,
**
kwargs
):
self
.
_dict
=
self
.
dict_cls
(
*
args
,
**
kwargs
)
self
.
_hash
=
None
def
__getitem__
(
self
,
key
):
return
self
.
_dict
[
key
]
def
__contains__
(
self
,
key
):
return
key
in
self
.
_dict
def
copy
(
self
,
**
add_or_replace
):
return
self
.
__class__
(
self
,
**
add_or_replace
)
def
__iter__
(
self
):
return
iter
(
self
.
_dict
)
def
__len__
(
self
):
return
len
(
self
.
_dict
)
def
__repr__
(
self
):
return
'<%s %r>'
%
(
self
.
__class__
.
__name__
,
self
.
_dict
)
def
__hash__
(
self
):
if
self
.
_hash
is
None
:
h
=
0
for
key
,
value
in
self
.
_dict
.
items
():
h
^=
hash
((
key
,
value
))
self
.
_hash
=
h
return
self
.
_hash