Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
04c80477
Commit
04c80477
authored
Mar 27, 2018
by
Martin Reinecke
Browse files
Merge branch 'diag_hack' into 'NIFTy_4'
More aggressive combination of diagonal operators See merge request ift/NIFTy!235
parents
6c61cbec
e6b49f93
Pipeline
#26517
passed with stages
in 5 minutes and 30 seconds
Changes
10
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
demos/nonlinear_wiener_filter.py
View file @
04c80477
...
...
@@ -35,7 +35,8 @@ if __name__ == "__main__":
d_space
=
R
.
target
power
=
ift
.
sqrt
(
ift
.
create_power_operator
(
h_space
,
p_spec
).
diagonal
)
p_op
=
ift
.
create_power_operator
(
h_space
,
p_spec
)
power
=
ift
.
sqrt
(
p_op
(
ift
.
Field
.
full
(
h_space
,
1.
)))
# Creating the mock data
true_sky
=
nonlinearity
(
HT
(
power
*
sh
))
...
...
nifty4/library/nonlinear_power_energy.py
View file @
04c80477
...
...
@@ -59,7 +59,7 @@ class NonlinearPowerEnergy(Energy):
self
.
D
=
D
self
.
d
=
d
self
.
N
=
N
self
.
T
=
SmoothnessOperator
(
domain
=
self
.
position
.
domain
[
0
],
self
.
T
=
SmoothnessOperator
(
domain
=
position
.
domain
[
0
],
strength
=
sigma
,
logarithmic
=
True
)
self
.
ht
=
ht
self
.
Instrument
=
Instrument
...
...
@@ -76,19 +76,15 @@ class NonlinearPowerEnergy(Energy):
self
.
inverter
=
inverter
A
=
Distributor
(
exp
(.
5
*
position
))
map_s
=
self
.
ht
(
A
*
xi
)
Tpos
=
self
.
T
(
position
)
self
.
_gradient
=
None
for
xi_sample
in
self
.
xi_sample_list
:
map_s
=
self
.
ht
(
A
*
xi_sample
)
LinR
=
LinearizedPowerResponse
(
self
.
Instrument
,
self
.
nonlinearity
,
self
.
ht
,
self
.
Distributor
,
self
.
position
,
xi_sample
)
map_s
=
ht
(
A
*
xi_sample
)
LinR
=
LinearizedPowerResponse
(
Instrument
,
nonlinearity
,
ht
,
Distributor
,
position
,
xi_sample
)
residual
=
self
.
d
-
\
self
.
Instrument
(
self
.
nonlinearity
(
map_s
))
tmp
=
self
.
N
.
inverse_times
(
residual
)
residual
=
d
-
Instrument
(
nonlinearity
(
map_s
))
tmp
=
N
.
inverse_times
(
residual
)
lh
=
0.5
*
residual
.
vdot
(
tmp
)
grad
=
LinR
.
adjoint_times
(
tmp
)
...
...
@@ -100,7 +96,8 @@ class NonlinearPowerEnergy(Energy):
self
.
_gradient
+=
grad
self
.
_value
*=
1.
/
len
(
self
.
xi_sample_list
)
self
.
_value
+=
0.5
*
self
.
position
.
vdot
(
Tpos
)
Tpos
=
self
.
T
(
position
)
self
.
_value
+=
0.5
*
position
.
vdot
(
Tpos
)
self
.
_gradient
*=
-
1.
/
len
(
self
.
xi_sample_list
)
self
.
_gradient
+=
Tpos
self
.
_gradient
.
lock
()
...
...
nifty4/library/nonlinear_wiener_filter_energy.py
View file @
04c80477
...
...
@@ -31,20 +31,18 @@ class NonlinearWienerFilterEnergy(Energy):
self
.
nonlinearity
=
nonlinearity
self
.
ht
=
ht
self
.
power
=
power
m
=
self
.
ht
(
self
.
power
*
self
.
position
)
self
.
LinearizedResponse
=
LinearizedSignalResponse
(
Instrument
,
nonlinearity
,
ht
,
power
,
m
)
m
=
ht
(
power
*
position
)
residual
=
d
-
Instrument
(
nonlinearity
(
m
))
self
.
N
=
N
self
.
S
=
S
self
.
inverter
=
inverter
t1
=
self
.
S
.
inverse_times
(
self
.
position
)
t2
=
self
.
N
.
inverse_times
(
residual
)
tmp
=
self
.
position
.
vdot
(
t1
)
+
residual
.
vdot
(
t2
)
self
.
_value
=
0.5
*
tmp
.
real
self
.
_gradient
=
t1
-
self
.
LinearizedResponse
.
adjoint_times
(
t2
)
self
.
_gradient
.
lock
()
t1
=
S
.
inverse_times
(
position
)
t2
=
N
.
inverse_times
(
residual
)
self
.
_value
=
0.5
*
(
position
.
vdot
(
t1
)
+
residual
.
vdot
(
t2
)
).
real
self
.
R
=
LinearizedSignalResponse
(
Instrument
,
nonlinearity
,
ht
,
power
,
m
)
self
.
_gradient
=
(
t1
-
self
.
R
.
adjoint_times
(
t2
))
.
lock
()
def
at
(
self
,
position
):
return
self
.
__class__
(
position
,
self
.
d
,
self
.
Instrument
,
...
...
@@ -62,5 +60,4 @@ class NonlinearWienerFilterEnergy(Energy):
@
property
@
memo
def
curvature
(
self
):
return
WienerFilterCurvature
(
R
=
self
.
LinearizedResponse
,
N
=
self
.
N
,
S
=
self
.
S
,
inverter
=
self
.
inverter
)
return
WienerFilterCurvature
(
self
.
R
,
self
.
N
,
self
.
S
,
self
.
inverter
)
nifty4/library/wiener_filter_energy.py
View file @
04c80477
...
...
@@ -51,12 +51,11 @@ class WienerFilterEnergy(Energy):
self
.
_curvature
=
WienerFilterCurvature
(
R
,
N
,
S
,
inverter
)
self
.
_inverter
=
inverter
if
_j
is
None
:
_j
=
self
.
R
.
adjoint_times
(
self
.
N
.
inverse_times
(
d
))
_j
=
R
.
adjoint_times
(
N
.
inverse_times
(
d
))
self
.
_j
=
_j
Dx
=
self
.
_curvature
(
self
.
position
)
self
.
_value
=
0.5
*
self
.
position
.
vdot
(
Dx
)
-
self
.
_j
.
vdot
(
self
.
position
)
self
.
_gradient
=
Dx
-
self
.
_j
self
.
_gradient
.
lock
()
self
.
_value
=
0.5
*
position
.
vdot
(
Dx
)
-
self
.
_j
.
vdot
(
position
)
self
.
_gradient
=
(
Dx
-
self
.
_j
).
lock
()
def
at
(
self
,
position
):
return
self
.
__class__
(
position
=
position
,
d
=
None
,
R
=
self
.
R
,
N
=
self
.
N
,
...
...
nifty4/operators/chain_operator.py
View file @
04c80477
...
...
@@ -61,9 +61,7 @@ class ChainOperator(LinearOperator):
# try to absorb the factor into a DiagonalOperator
for
i
in
range
(
len
(
opsnew
)):
if
isinstance
(
opsnew
[
i
],
DiagonalOperator
):
opsnew
[
i
]
=
DiagonalOperator
(
opsnew
[
i
].
diagonal
*
fct
,
domain
=
opsnew
[
i
].
domain
,
spaces
=
opsnew
[
i
].
_spaces
)
opsnew
[
i
]
=
opsnew
[
i
].
_scale
(
fct
)
fct
=
1.
break
if
fct
!=
1
:
...
...
@@ -75,12 +73,8 @@ class ChainOperator(LinearOperator):
for
op
in
ops
:
if
(
len
(
opsnew
)
>
0
and
isinstance
(
opsnew
[
-
1
],
DiagonalOperator
)
and
isinstance
(
op
,
DiagonalOperator
)
and
op
.
_spaces
==
opsnew
[
-
1
].
_spaces
):
opsnew
[
-
1
]
=
DiagonalOperator
(
opsnew
[
-
1
].
diagonal
*
op
.
diagonal
,
domain
=
opsnew
[
-
1
].
domain
,
spaces
=
opsnew
[
-
1
].
_spaces
)
isinstance
(
op
,
DiagonalOperator
)):
opsnew
[
-
1
]
=
opsnew
[
-
1
].
_combine_prod
(
op
)
else
:
opsnew
.
append
(
op
)
ops
=
opsnew
...
...
@@ -120,9 +114,3 @@ class ChainOperator(LinearOperator):
for
op
in
t_ops
:
x
=
op
.
apply
(
x
,
mode
)
return
x
def
draw_sample
(
self
,
dtype
=
np
.
float64
):
sample
=
self
.
_ops
[
-
1
].
draw_sample
(
dtype
)
for
op
in
reversed
(
self
.
_ops
[:
-
1
]):
sample
=
op
.
process_sample
(
sample
)
return
sample
nifty4/operators/diagonal_operator.py
View file @
04c80477
...
...
@@ -71,32 +71,66 @@ class DiagonalOperator(EndomorphicOperator):
self
.
_spaces
=
utilities
.
parse_spaces
(
spaces
,
len
(
self
.
_domain
))
if
len
(
self
.
_spaces
)
!=
len
(
diagonal
.
domain
):
raise
ValueError
(
"spaces and domain must have the same length"
)
# if nspc==len(self.diagonal.domain),
# we could do some optimization
for
i
,
j
in
enumerate
(
self
.
_spaces
):
if
diagonal
.
domain
[
i
]
!=
self
.
_domain
[
j
]:
raise
ValueError
(
"domain mismatch"
)
if
self
.
_spaces
==
tuple
(
range
(
len
(
self
.
_domain
))):
self
.
_spaces
=
None
# shortcut
self
.
_diagonal
=
diagonal
.
lock
()
if
self
.
_spaces
is
not
None
:
active_axes
=
[]
for
space_index
in
self
.
_spaces
:
active_axes
+=
self
.
_domain
.
axes
[
space_index
]
if
self
.
_spaces
[
0
]
==
0
:
self
.
_ldiag
=
self
.
_
diagonal
.
local_data
self
.
_ldiag
=
diagonal
.
local_data
else
:
self
.
_ldiag
=
self
.
_
diagonal
.
to_global_data
()
self
.
_ldiag
=
diagonal
.
to_global_data
()
locshape
=
dobj
.
local_shape
(
self
.
_domain
.
shape
,
0
)
self
.
_reshaper
=
[
shp
if
i
in
active_axes
else
1
for
i
,
shp
in
enumerate
(
locshape
)]
self
.
_ldiag
=
self
.
_ldiag
.
reshape
(
self
.
_reshaper
)
else
:
self
.
_ldiag
=
self
.
_diagonal
.
local_data
self
.
_ldiag
=
diagonal
.
local_data
self
.
_ldiag
.
flags
.
writeable
=
False
def
_skeleton
(
self
,
spc
):
res
=
DiagonalOperator
.
__new__
(
DiagonalOperator
)
res
.
_domain
=
self
.
_domain
if
self
.
_spaces
is
None
or
spc
is
None
:
res
.
_spaces
=
None
else
:
res
.
_spaces
=
tuple
(
set
(
self
.
_spaces
)
|
set
(
spc
))
return
res
def
_scale
(
self
,
fct
):
if
not
np
.
isscalar
(
fct
):
raise
TypeError
(
"scalar value required"
)
res
=
self
.
_skeleton
(())
res
.
_ldiag
=
self
.
_ldiag
*
fct
return
res
def
_add
(
self
,
sum
):
if
not
np
.
isscalar
(
sum
):
raise
TypeError
(
"scalar value required"
)
res
=
self
.
_skeleton
(())
res
.
_ldiag
=
self
.
_ldiag
+
sum
return
res
def
_combine_prod
(
self
,
op
):
if
not
isinstance
(
op
,
DiagonalOperator
):
raise
TypeError
(
"DiagonalOperator required"
)
res
=
self
.
_skeleton
(
op
.
_spaces
)
res
.
_ldiag
=
self
.
_ldiag
*
op
.
_ldiag
return
res
def
_combine_sum
(
self
,
op
,
selfneg
,
opneg
):
if
not
isinstance
(
op
,
DiagonalOperator
):
raise
TypeError
(
"DiagonalOperator required"
)
res
=
self
.
_skeleton
(
op
.
_spaces
)
res
.
_ldiag
=
(
self
.
_ldiag
*
(
-
1
if
selfneg
else
1
)
+
op
.
_ldiag
*
(
-
1
if
opneg
else
1
))
return
res
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
...
...
@@ -116,11 +150,6 @@ class DiagonalOperator(EndomorphicOperator):
else
:
return
Field
(
x
.
domain
,
val
=
x
.
val
/
self
.
_ldiag
.
conj
())
@
property
def
diagonal
(
self
):
""" Returns the diagonal of the Operator."""
return
self
.
_diagonal
@
property
def
domain
(
self
):
return
self
.
_domain
...
...
@@ -131,19 +160,16 @@ class DiagonalOperator(EndomorphicOperator):
@
property
def
inverse
(
self
):
return
DiagonalOperator
(
1.
/
self
.
_diagonal
,
self
.
_domain
,
self
.
_spaces
)
res
=
self
.
_skeleton
(())
res
.
_ldiag
=
1.
/
self
.
_ldiag
return
res
@
property
def
adjoint
(
self
):
return
DiagonalOperator
(
self
.
_diagonal
.
conjugate
(),
self
.
_domain
,
self
.
_spaces
)
def
process_sample
(
self
,
sample
):
if
np
.
issubdtype
(
self
.
_ldiag
.
dtype
,
np
.
complexfloating
):
raise
ValueError
(
"cannot draw sample from complex-valued operator"
)
res
=
Field
.
empty_like
(
sample
)
res
.
local_data
[()]
=
sample
.
local_data
*
np
.
sqrt
(
self
.
_ldiag
)
if
np
.
issubdtype
(
self
.
_ldiag
.
dtype
,
np
.
floating
):
return
self
res
=
self
.
_skeleton
(())
res
.
_ldiag
=
self
.
_ldiag
.
conjugate
()
return
res
def
draw_sample
(
self
,
dtype
=
np
.
float64
):
...
...
nifty4/operators/laplace_operator.py
View file @
04c80477
...
...
@@ -50,10 +50,8 @@ class LaplaceOperator(EndomorphicOperator):
if
not
isinstance
(
self
.
_domain
[
self
.
_space
],
PowerSpace
):
raise
ValueError
(
"Operator must act on a PowerSpace."
)
self
.
_logarithmic
=
bool
(
logarithmic
)
pos
=
self
.
domain
[
self
.
_space
].
k_lengths
.
copy
()
if
self
.
logarithmic
:
if
logarithmic
:
pos
[
1
:]
=
np
.
log
(
pos
[
1
:])
pos
[
0
]
=
pos
[
1
]
-
1.
...
...
@@ -74,10 +72,6 @@ class LaplaceOperator(EndomorphicOperator):
def
capability
(
self
):
return
self
.
TIMES
|
self
.
ADJOINT_TIMES
@
property
def
logarithmic
(
self
):
return
self
.
_logarithmic
def
_times
(
self
,
x
):
axes
=
x
.
domain
.
axes
[
self
.
_space
]
axis
=
axes
[
0
]
...
...
nifty4/operators/scaling_operator.py
View file @
04c80477
...
...
@@ -61,7 +61,7 @@ class ScalingOperator(EndomorphicOperator):
if
self
.
_factor
==
1.
:
return
x
.
copy
()
if
self
.
_factor
==
0.
:
return
Field
.
zeros_like
(
x
,
dtype
=
x
.
dtype
)
return
Field
.
zeros_like
(
x
)
if
mode
==
self
.
TIMES
:
return
x
*
self
.
_factor
...
...
@@ -81,6 +81,8 @@ class ScalingOperator(EndomorphicOperator):
@
property
def
adjoint
(
self
):
if
np
.
issubdtype
(
type
(
self
.
_factor
),
np
.
floating
):
return
self
return
ScalingOperator
(
np
.
conj
(
self
.
_factor
),
self
.
_domain
)
@
property
...
...
@@ -93,11 +95,6 @@ class ScalingOperator(EndomorphicOperator):
return
self
.
TIMES
|
self
.
ADJOINT_TIMES
return
self
.
_all_ops
def
process_sample
(
self
,
sample
):
if
self
.
_factor
.
imag
!=
0.
or
self
.
_factor
.
real
<=
0.
:
raise
ValueError
(
"Operator not positive definite"
)
return
sample
*
np
.
sqrt
(
self
.
_factor
)
def
_sample_helper
(
self
,
fct
,
dtype
):
if
fct
.
imag
!=
0.
or
fct
.
real
<=
0.
:
raise
ValueError
(
"operator not positive definite"
)
...
...
nifty4/operators/sum_operator.py
View file @
04c80477
...
...
@@ -72,9 +72,7 @@ class SumOperator(LinearOperator):
for
i
in
range
(
len
(
opsnew
)):
if
isinstance
(
opsnew
[
i
],
DiagonalOperator
):
sum
*=
(
-
1
if
negnew
[
i
]
else
1
)
opsnew
[
i
]
=
DiagonalOperator
(
opsnew
[
i
].
diagonal
+
sum
,
domain
=
opsnew
[
i
].
domain
,
spaces
=
opsnew
[
i
].
_spaces
)
opsnew
[
i
]
=
opsnew
[
i
].
_add
(
sum
)
sum
=
0.
break
if
sum
!=
0
:
...
...
@@ -90,15 +88,15 @@ class SumOperator(LinearOperator):
for
i
in
range
(
len
(
ops
)):
if
not
processed
[
i
]:
if
isinstance
(
ops
[
i
],
DiagonalOperator
):
diag
=
ops
[
i
].
diagonal
*
(
-
1
if
neg
[
i
]
else
1
)
op
=
ops
[
i
]
opneg
=
neg
[
i
]
for
j
in
range
(
i
+
1
,
len
(
ops
)):
if
(
isinstance
(
ops
[
j
],
DiagonalOperator
)
and
ops
[
i
].
_spaces
==
ops
[
j
].
_spaces
):
diag
+=
ops
[
j
].
diagonal
*
(
-
1
if
neg
[
j
]
else
1
)
if
isinstance
(
ops
[
j
],
DiagonalOperator
)
:
op
=
op
.
_combine_sum
(
ops
[
j
],
opneg
,
neg
[
j
])
opneg
=
False
processed
[
j
]
=
True
opsnew
.
append
(
DiagonalOperator
(
diag
,
ops
[
i
].
domain
,
ops
[
i
].
_spaces
))
negnew
.
append
(
False
)
opsnew
.
append
(
op
)
negnew
.
append
(
opneg
)
else
:
opsnew
.
append
(
ops
[
i
])
negnew
.
append
(
neg
[
i
])
...
...
test/test_operators/test_diagonal_operator.py
View file @
04c80477
...
...
@@ -90,5 +90,5 @@ class DiagonalOperator_Tests(unittest.TestCase):
def
test_diagonal
(
self
,
space
):
diag
=
ift
.
Field
.
from_random
(
'normal'
,
domain
=
space
)
D
=
ift
.
DiagonalOperator
(
diag
)
diag_op
=
D
.
diagonal
diag_op
=
D
(
ift
.
Field
.
full
(
space
,
1.
))
assert_allclose
(
diag
.
to_global_data
(),
diag_op
.
to_global_data
())
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment