Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
N
NIFTy
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
13
Issues
13
List
Boards
Labels
Service Desk
Milestones
Merge Requests
8
Merge Requests
8
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Operations
Operations
Incidents
Environments
Packages & Registries
Packages & Registries
Container Registry
Analytics
Analytics
CI / CD
Repository
Value Stream
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ift
NIFTy
Commits
c3ed466f
Commit
c3ed466f
authored
Aug 05, 2018
by
Martin Reinecke
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
no more chains
parent
369c6e7c
Changes
27
Hide whitespace changes
Inline
Side-by-side
Showing
27 changed files
with
82 additions
and
94 deletions
+82
-94
demos/Wiener_Filter.ipynb
demos/Wiener_Filter.ipynb
+2
-2
demos/bernoulli_demo.py
demos/bernoulli_demo.py
+2
-2
demos/getting_started_1.py
demos/getting_started_1.py
+3
-3
demos/getting_started_2.py
demos/getting_started_2.py
+3
-3
demos/getting_started_3.py
demos/getting_started_3.py
+4
-4
demos/polynomial_fit.py
demos/polynomial_fit.py
+1
-1
nifty5/energies/hamiltonian.py
nifty5/energies/hamiltonian.py
+1
-1
nifty5/energies/kl.py
nifty5/energies/kl.py
+1
-1
nifty5/library/amplitude_model.py
nifty5/library/amplitude_model.py
+2
-2
nifty5/library/bernoulli_energy.py
nifty5/library/bernoulli_energy.py
+1
-1
nifty5/library/correlated_fields.py
nifty5/library/correlated_fields.py
+1
-1
nifty5/library/gaussian_energy.py
nifty5/library/gaussian_energy.py
+1
-1
nifty5/library/poissonian_energy.py
nifty5/library/poissonian_energy.py
+1
-1
nifty5/linearization.py
nifty5/linearization.py
+16
-16
nifty5/multi/block_diagonal_operator.py
nifty5/multi/block_diagonal_operator.py
+1
-1
nifty5/operators/harmonic_smoothing_operator.py
nifty5/operators/harmonic_smoothing_operator.py
+1
-1
nifty5/operators/linear_operator.py
nifty5/operators/linear_operator.py
+6
-5
nifty5/operators/operator.py
nifty5/operators/operator.py
+10
-23
nifty5/operators/sandwich_operator.py
nifty5/operators/sandwich_operator.py
+2
-2
nifty5/operators/smoothness_operator.py
nifty5/operators/smoothness_operator.py
+1
-1
nifty5/utilities.py
nifty5/utilities.py
+7
-6
test/test_energies/test_map.py
test/test_energies/test_map.py
+4
-4
test/test_field.py
test/test_field.py
+2
-2
test/test_models/test_model_gradients.py
test/test_models/test_model_gradients.py
+3
-3
test/test_multi_field.py
test/test_multi_field.py
+1
-1
test/test_operators/test_adjoint.py
test/test_operators/test_adjoint.py
+1
-1
test/test_operators/test_composed_operator.py
test/test_operators/test_composed_operator.py
+4
-5
No files found.
demos/Wiener_Filter.ipynb
View file @
c3ed466f
...
...
@@ -429,7 +429,7 @@
"mask[l:h] = 0\n",
"mask = ift.Field.from_global_data(s_space, mask)\n",
"\n",
"R = ift.DiagonalOperator(mask)
.chain
(HT)\n",
"R = ift.DiagonalOperator(mask)(HT)\n",
"n = n.to_global_data_rw()\n",
"n[l:h] = 0\n",
"n = ift.Field.from_global_data(s_space, n)\n",
...
...
@@ -585,7 +585,7 @@
"mask[l:h,l:h] = 0.\n",
"mask = ift.Field.from_global_data(s_space, mask)\n",
"\n",
"R = ift.DiagonalOperator(mask)
.chain
(HT)\n",
"R = ift.DiagonalOperator(mask)(HT)\n",
"n = n.to_global_data_rw()\n",
"n[l:h, l:h] = 0\n",
"n = ift.Field.from_global_data(s_space, n)\n",
...
...
demos/bernoulli_demo.py
View file @
c3ed466f
...
...
@@ -53,7 +53,7 @@ if __name__ == '__main__':
A
=
pd
(
a
)
# Set up a sky model
sky
=
HT
.
chain
(
ift
.
makeOp
(
A
)).
positive_tanh
()
sky
=
HT
(
ift
.
makeOp
(
A
)).
positive_tanh
()
GR
=
ift
.
GeometryRemover
(
position_space
)
# Set up instrumental response
...
...
@@ -61,7 +61,7 @@ if __name__ == '__main__':
# Generate mock data
d_space
=
R
.
target
[
0
]
p
=
R
.
chain
(
sky
)
p
=
R
(
sky
)
mock_position
=
ift
.
from_random
(
'normal'
,
harmonic_space
)
pp
=
p
(
mock_position
)
data
=
np
.
random
.
binomial
(
1
,
pp
.
to_global_data
().
astype
(
np
.
float64
))
...
...
demos/getting_started_1.py
View file @
c3ed466f
...
...
@@ -78,7 +78,7 @@ if __name__ == '__main__':
GR
=
ift
.
GeometryRemover
(
position_space
)
mask
=
ift
.
Field
.
from_global_data
(
position_space
,
mask
)
Mask
=
ift
.
DiagonalOperator
(
mask
)
R
=
GR
.
chain
(
Mask
).
chain
(
HT
)
R
=
GR
(
Mask
(
HT
)
)
data_space
=
GR
.
target
...
...
@@ -93,7 +93,7 @@ if __name__ == '__main__':
# Build propagator D and information source j
j
=
R
.
adjoint_times
(
N
.
inverse_times
(
data
))
D_inv
=
R
.
adjoint
.
chain
(
N
.
inverse
).
chain
(
R
)
+
S
.
inverse
D_inv
=
R
.
adjoint
(
N
.
inverse
(
R
)
)
+
S
.
inverse
# Make it invertible
IC
=
ift
.
GradientNormController
(
iteration_limit
=
500
,
tol_abs_gradnorm
=
1e-3
)
D
=
ift
.
InversionEnabler
(
D_inv
,
IC
,
approximation
=
S
.
inverse
).
inverse
...
...
@@ -112,7 +112,7 @@ if __name__ == '__main__':
title
=
"getting_started_1"
)
else
:
ift
.
plot
(
HT
(
MOCK_SIGNAL
),
title
=
'Mock Signal'
)
ift
.
plot
(
mask_to_nan
(
mask
,
(
GR
.
chain
(
Mask
)).
adjoint
(
data
)),
ift
.
plot
(
mask_to_nan
(
mask
,
(
GR
(
Mask
)).
adjoint
(
data
)),
title
=
'Data'
)
ift
.
plot
(
HT
(
m
),
title
=
'Reconstruction'
)
ift
.
plot
(
mask_to_nan
(
mask
,
HT
(
m
-
MOCK_SIGNAL
)),
title
=
'Residuals'
)
...
...
demos/getting_started_2.py
View file @
c3ed466f
...
...
@@ -70,16 +70,16 @@ if __name__ == '__main__':
A
=
pd
(
a
)
# Set up a sky model
sky
=
ift
.
exp
(
HT
.
chain
(
ift
.
makeOp
(
A
)))
sky
=
ift
.
exp
(
HT
(
ift
.
makeOp
(
A
)))
M
=
ift
.
DiagonalOperator
(
exposure
)
GR
=
ift
.
GeometryRemover
(
position_space
)
# Set up instrumental response
R
=
GR
.
chain
(
M
)
R
=
GR
(
M
)
# Generate mock data
d_space
=
R
.
target
[
0
]
lamb
=
R
.
chain
(
sky
)
lamb
=
R
(
sky
)
mock_position
=
ift
.
from_random
(
'normal'
,
domain
)
data
=
lamb
(
mock_position
)
data
=
np
.
random
.
poisson
(
data
.
to_global_data
().
astype
(
np
.
float64
))
...
...
demos/getting_started_3.py
View file @
c3ed466f
...
...
@@ -44,8 +44,8 @@ if __name__ == '__main__':
domain
=
ift
.
MultiDomain
.
union
(
(
A
.
domain
,
ift
.
MultiDomain
.
make
({
'xi'
:
harmonic_space
})))
correlated_field
=
ht
.
chain
(
power_distributor
.
chain
(
A
)
*
ift
.
FieldAdapter
(
domain
,
"xi"
))
correlated_field
=
ht
(
power_distributor
(
A
)
*
ift
.
FieldAdapter
(
domain
,
"xi"
))
# alternatively to the block above one can do:
# correlated_field = ift.CorrelatedField(position_space, A)
...
...
@@ -57,7 +57,7 @@ if __name__ == '__main__':
R
=
ift
.
LOSResponse
(
position_space
,
starts
=
LOS_starts
,
ends
=
LOS_ends
)
# build signal response model and model likelihood
signal_response
=
R
.
chain
(
signal
)
signal_response
=
R
(
signal
)
# specify noise
data_space
=
R
.
target
noise
=
.
001
...
...
@@ -69,7 +69,7 @@ if __name__ == '__main__':
# set up model likelihood
likelihood
=
ift
.
GaussianEnergy
(
mean
=
data
,
covariance
=
N
)
.
chain
(
signal_response
)
mean
=
data
,
covariance
=
N
)(
signal_response
)
# set up minimization and inversion schemes
ic_cg
=
ift
.
GradientNormController
(
iteration_limit
=
10
)
...
...
demos/polynomial_fit.py
View file @
c3ed466f
...
...
@@ -97,7 +97,7 @@ d = ift.from_global_data(d_space, y)
N
=
ift
.
DiagonalOperator
(
ift
.
from_global_data
(
d_space
,
var
))
IC
=
ift
.
GradientNormController
(
tol_abs_gradnorm
=
1e-8
)
likelihood
=
ift
.
GaussianEnergy
(
d
,
N
)
.
chain
(
R
)
likelihood
=
ift
.
GaussianEnergy
(
d
,
N
)(
R
)
H
=
ift
.
Hamiltonian
(
likelihood
,
IC
)
H
=
ift
.
EnergyAdapter
(
params
,
H
)
H
=
H
.
make_invertible
(
IC
)
...
...
nifty5/energies/hamiltonian.py
View file @
c3ed466f
...
...
@@ -40,7 +40,7 @@ class Hamiltonian(Operator):
def
target
(
self
):
return
DomainTuple
.
scalar_domain
()
def
__call__
(
self
,
x
):
def
apply
(
self
,
x
):
if
self
.
_ic_samp
is
None
or
not
isinstance
(
x
,
Linearization
):
return
self
.
_lh
(
x
)
+
self
.
_prior
(
x
)
else
:
...
...
nifty5/energies/kl.py
View file @
c3ed466f
...
...
@@ -42,6 +42,6 @@ class SampledKullbachLeiblerDivergence(Operator):
def
target
(
self
):
return
DomainTuple
.
scalar_domain
()
def
__call__
(
self
,
x
):
def
apply
(
self
,
x
):
return
(
my_sum
(
map
(
lambda
v
:
self
.
_h
(
x
+
v
),
self
.
_res_samples
))
*
(
1.
/
len
(
self
.
_res_samples
)))
nifty5/library/amplitude_model.py
View file @
c3ed466f
...
...
@@ -130,7 +130,7 @@ class AmplitudeModel(Operator):
cepstrum
=
create_cepstrum_amplitude_field
(
dof_space
,
kern
)
ceps
=
makeOp
(
sqrt
(
cepstrum
))
self
.
_smooth_op
=
sym
.
chain
(
qht
).
chain
(
ceps
)
self
.
_smooth_op
=
sym
(
qht
(
ceps
)
)
self
.
_keys
=
tuple
(
keys
)
@
property
...
...
@@ -141,7 +141,7 @@ class AmplitudeModel(Operator):
def
target
(
self
):
return
self
.
_target
def
__call__
(
self
,
x
):
def
apply
(
self
,
x
):
smooth_spec
=
self
.
_smooth_op
(
x
[
self
.
_keys
[
0
]])
phi
=
x
[
self
.
_keys
[
1
]]
+
self
.
_norm_phi_mean
linear_spec
=
self
.
_slope
(
phi
)
...
...
nifty5/library/bernoulli_energy.py
View file @
c3ed466f
...
...
@@ -39,7 +39,7 @@ class BernoulliEnergy(Operator):
def
target
(
self
):
return
DomainTuple
.
scalar_domain
()
def
__call__
(
self
,
x
):
def
apply
(
self
,
x
):
x
=
self
.
_p
(
x
)
v
=
x
.
log
().
vdot
(
-
self
.
_d
)
-
(
1.
-
x
).
log
().
vdot
(
1.
-
self
.
_d
)
if
not
isinstance
(
x
,
Linearization
):
...
...
nifty5/library/correlated_fields.py
View file @
c3ed466f
...
...
@@ -58,7 +58,7 @@ class CorrelatedField(Operator):
def
target
(
self
):
return
self
.
_ht
.
target
def
__call__
(
self
,
x
):
def
apply
(
self
,
x
):
A
=
self
.
_power_distributor
(
self
.
_amplitude_model
(
x
))
correlated_field_h
=
A
*
x
[
"xi"
]
correlated_field
=
self
.
_ht
(
correlated_field_h
)
...
...
nifty5/library/gaussian_energy.py
View file @
c3ed466f
...
...
@@ -55,7 +55,7 @@ class GaussianEnergy(Operator):
def
target
(
self
):
return
DomainTuple
.
scalar_domain
()
def
__call__
(
self
,
x
):
def
apply
(
self
,
x
):
residual
=
x
if
self
.
_mean
is
None
else
x
-
self
.
_mean
icovres
=
residual
if
self
.
_icov
is
None
else
self
.
_icov
(
residual
)
res
=
.
5
*
residual
.
vdot
(
icovres
)
...
...
nifty5/library/poissonian_energy.py
View file @
c3ed466f
...
...
@@ -41,7 +41,7 @@ class PoissonianEnergy(Operator):
def
target
(
self
):
return
DomainTuple
.
scalar_domain
()
def
__call__
(
self
,
x
):
def
apply
(
self
,
x
):
x
=
self
.
_op
(
x
)
res
=
x
.
sum
()
-
x
.
log
().
vdot
(
self
.
_d
)
if
not
isinstance
(
x
,
Linearization
):
...
...
nifty5/linearization.py
View file @
c3ed466f
...
...
@@ -46,8 +46,8 @@ class Linearization(object):
def
__neg__
(
self
):
return
Linearization
(
-
self
.
_val
,
self
.
_jac
.
chain
(
-
1
),
None
if
self
.
_metric
is
None
else
self
.
_metric
.
chain
(
-
1
))
-
self
.
_val
,
self
.
_jac
*
(
-
1
),
None
if
self
.
_metric
is
None
else
self
.
_metric
*
(
-
1
))
def
__add__
(
self
,
other
):
if
isinstance
(
other
,
Linearization
):
...
...
@@ -77,24 +77,24 @@ class Linearization(object):
d2
=
makeOp
(
other
.
_val
)
return
Linearization
(
self
.
_val
*
other
.
_val
,
d2
.
chain
(
self
.
_jac
)
+
d1
.
chain
(
other
.
_jac
))
d2
(
self
.
_jac
)
+
d1
(
other
.
_jac
))
if
isinstance
(
other
,
(
int
,
float
,
complex
)):
# if other == 0:
# return ...
met
=
None
if
self
.
_metric
is
None
else
self
.
_metric
.
chain
(
other
)
return
Linearization
(
self
.
_val
*
other
,
self
.
_jac
.
chain
(
other
),
met
)
met
=
None
if
self
.
_metric
is
None
else
self
.
_metric
(
other
)
return
Linearization
(
self
.
_val
*
other
,
self
.
_jac
(
other
),
met
)
if
isinstance
(
other
,
(
Field
,
MultiField
)):
d2
=
makeOp
(
other
)
return
Linearization
(
self
.
_val
*
other
,
d2
.
chain
(
self
.
_jac
))
return
Linearization
(
self
.
_val
*
other
,
d2
(
self
.
_jac
))
raise
TypeError
def
__rmul__
(
self
,
other
):
from
.sugar
import
makeOp
if
isinstance
(
other
,
(
int
,
float
,
complex
)):
return
Linearization
(
self
.
_val
*
other
,
self
.
_jac
.
chain
(
other
))
return
Linearization
(
self
.
_val
*
other
,
self
.
_jac
(
other
))
if
isinstance
(
other
,
(
Field
,
MultiField
)):
d1
=
makeOp
(
other
)
return
Linearization
(
self
.
_val
*
other
,
d1
.
chain
(
self
.
_jac
))
return
Linearization
(
self
.
_val
*
other
,
d1
(
self
.
_jac
))
def
vdot
(
self
,
other
):
from
.domain_tuple
import
DomainTuple
...
...
@@ -102,11 +102,11 @@ class Linearization(object):
if
isinstance
(
other
,
(
Field
,
MultiField
)):
return
Linearization
(
Field
(
DomainTuple
.
scalar_domain
(),
self
.
_val
.
vdot
(
other
)),
VdotOperator
(
other
)
.
chain
(
self
.
_jac
))
VdotOperator
(
other
)(
self
.
_jac
))
return
Linearization
(
Field
(
DomainTuple
.
scalar_domain
(),
self
.
_val
.
vdot
(
other
.
_val
)),
VdotOperator
(
self
.
_val
)
.
chain
(
other
.
_jac
)
+
VdotOperator
(
other
.
_val
)
.
chain
(
self
.
_jac
))
VdotOperator
(
self
.
_val
)(
other
.
_jac
)
+
VdotOperator
(
other
.
_val
)(
self
.
_jac
))
def
sum
(
self
):
from
.domain_tuple
import
DomainTuple
...
...
@@ -114,24 +114,24 @@ class Linearization(object):
from
.sugar
import
full
return
Linearization
(
Field
(
DomainTuple
.
scalar_domain
(),
self
.
_val
.
sum
()),
SumReductionOperator
(
self
.
_jac
.
target
)
.
chain
(
self
.
_jac
))
SumReductionOperator
(
self
.
_jac
.
target
)(
self
.
_jac
))
def
exp
(
self
):
tmp
=
self
.
_val
.
exp
()
return
Linearization
(
tmp
,
makeOp
(
tmp
)
.
chain
(
self
.
_jac
))
return
Linearization
(
tmp
,
makeOp
(
tmp
)(
self
.
_jac
))
def
log
(
self
):
tmp
=
self
.
_val
.
log
()
return
Linearization
(
tmp
,
makeOp
(
1.
/
self
.
_val
)
.
chain
(
self
.
_jac
))
return
Linearization
(
tmp
,
makeOp
(
1.
/
self
.
_val
)(
self
.
_jac
))
def
tanh
(
self
):
tmp
=
self
.
_val
.
tanh
()
return
Linearization
(
tmp
,
makeOp
(
1.
-
tmp
**
2
)
.
chain
(
self
.
_jac
))
return
Linearization
(
tmp
,
makeOp
(
1.
-
tmp
**
2
)(
self
.
_jac
))
def
positive_tanh
(
self
):
tmp
=
self
.
_val
.
tanh
()
tmp2
=
0.5
*
(
1.
+
tmp
)
return
Linearization
(
tmp2
,
makeOp
(
0.5
*
(
1.
-
tmp
**
2
))
.
chain
(
self
.
_jac
))
return
Linearization
(
tmp2
,
makeOp
(
0.5
*
(
1.
-
tmp
**
2
))(
self
.
_jac
))
def
add_metric
(
self
,
metric
):
return
Linearization
(
self
.
_val
,
self
.
_jac
,
metric
)
...
...
nifty5/multi/block_diagonal_operator.py
View file @
c3ed466f
...
...
@@ -68,7 +68,7 @@ class BlockDiagonalOperator(EndomorphicOperator):
def
_combine_chain
(
self
,
op
):
if
self
.
_domain
is
not
op
.
_domain
:
raise
ValueError
(
"domain mismatch"
)
res
=
tuple
(
v1
.
chain
(
v2
)
for
v1
,
v2
in
zip
(
self
.
_ops
,
op
.
_ops
))
res
=
tuple
(
v1
(
v2
)
for
v1
,
v2
in
zip
(
self
.
_ops
,
op
.
_ops
))
return
BlockDiagonalOperator
(
self
.
_domain
,
res
)
def
_combine_sum
(
self
,
op
,
selfneg
,
opneg
):
...
...
nifty5/operators/harmonic_smoothing_operator.py
View file @
c3ed466f
...
...
@@ -67,4 +67,4 @@ def HarmonicSmoothingOperator(domain, sigma, space=None):
ddom
=
list
(
domain
)
ddom
[
space
]
=
codomain
diag
=
DiagonalOperator
(
kernel
,
ddom
,
space
)
return
Hartley
.
inverse
.
chain
(
diag
).
chain
(
Hartley
)
return
Hartley
.
inverse
(
diag
(
Hartley
)
)
nifty5/operators/linear_operator.py
View file @
c3ed466f
...
...
@@ -142,9 +142,6 @@ class LinearOperator(Operator):
from
.chain_operator
import
ChainOperator
return
ChainOperator
.
make
([
self
,
other2
])
def
chain
(
self
,
other
):
return
self
.
__matmul__
(
other
)
def
__rmatmul__
(
self
,
other
):
if
np
.
isscalar
(
other
)
and
other
==
1.
:
return
self
...
...
@@ -213,10 +210,14 @@ class LinearOperator(Operator):
def
__call__
(
self
,
x
):
"""Same as :meth:`times`"""
from
..field
import
Field
from
..multi.multi_field
import
MultiField
if
isinstance
(
x
,
(
Field
,
MultiField
)):
return
self
.
apply
(
x
,
self
.
TIMES
)
from
..linearization
import
Linearization
if
isinstance
(
x
,
Linearization
):
return
Linearization
(
self
(
x
.
_val
),
self
.
chain
(
x
.
_jac
))
return
self
.
apply
(
x
,
self
.
TIMES
)
return
Linearization
(
self
(
x
.
_val
),
self
(
x
.
_jac
))
return
self
.
__matmul__
(
x
)
def
times
(
self
,
x
):
""" Applies the Operator to a given Field.
...
...
nifty5/operators/operator.py
View file @
c3ed466f
...
...
@@ -33,26 +33,13 @@ class Operator(NiftyMetaBase()):
return
NotImplemented
return
_OpProd
.
make
((
self
,
x
))
def
chain
(
self
,
x
):
res
=
self
.
__matmul__
(
x
)
if
res
==
NotImplemented
:
raise
TypeError
(
"operator expected"
)
return
res
def
apply
(
self
,
x
):
raise
NotImplementedError
def
__call__
(
self
,
x
):
"""Returns transformed x
Parameters
----------
x : Linearization
input
Returns
-------
Linearization
output
"""
raise
NotImplementedError
if
isinstance
(
x
,
Operator
):
return
_OpChain
.
make
((
self
,
x
))
return
self
.
apply
(
x
)
for
f
in
[
"sqrt"
,
"exp"
,
"log"
,
"tanh"
,
"positive_tanh"
]:
...
...
@@ -78,7 +65,7 @@ class _FunctionApplier(Operator):
def
target
(
self
):
return
self
.
_domain
def
__call__
(
self
,
x
):
def
apply
(
self
,
x
):
return
getattr
(
x
,
self
.
_funcname
)()
...
...
@@ -117,7 +104,7 @@ class _OpChain(_CombinedOperator):
def
target
(
self
):
return
self
.
_ops
[
0
].
target
def
__call__
(
self
,
x
):
def
apply
(
self
,
x
):
for
op
in
reversed
(
self
.
_ops
):
x
=
op
(
x
)
return
x
...
...
@@ -135,7 +122,7 @@ class _OpProd(_CombinedOperator):
def
target
(
self
):
return
self
.
_ops
[
0
].
target
def
__call__
(
self
,
x
):
def
apply
(
self
,
x
):
from
..utilities
import
my_product
return
my_product
(
map
(
lambda
op
:
op
(
x
),
self
.
_ops
))
...
...
@@ -154,7 +141,7 @@ class _OpSum(_CombinedOperator):
def
target
(
self
):
return
self
.
_target
def
__call__
(
self
,
x
):
def
apply
(
self
,
x
):
raise
NotImplementedError
...
...
@@ -193,7 +180,7 @@ class QuadraticFormOperator(Operator):
def
target
(
self
):
return
self
.
_target
def
__call__
(
self
,
x
):
def
apply
(
self
,
x
):
if
isinstance
(
x
,
Linearization
):
jac
=
self
.
_op
(
x
)
val
=
Field
(
self
.
_target
,
0.5
*
x
.
vdot
(
jac
))
...
...
nifty5/operators/sandwich_operator.py
View file @
c3ed466f
...
...
@@ -56,9 +56,9 @@ class SandwichOperator(EndomorphicOperator):
raise
TypeError
(
"cheese must be a linear operator"
)
if
cheese
is
None
:
cheese
=
ScalingOperator
(
1.
,
bun
.
target
)
op
=
bun
.
adjoint
.
chain
(
bun
)
op
=
bun
.
adjoint
(
bun
)
else
:
op
=
bun
.
adjoint
.
chain
(
cheese
).
chain
(
bun
)
op
=
bun
.
adjoint
(
cheese
(
bun
)
)
# if our sandwich is diagonal, we can return immediately
if
isinstance
(
op
,
(
ScalingOperator
,
DiagonalOperator
)):
...
...
nifty5/operators/smoothness_operator.py
View file @
c3ed466f
...
...
@@ -54,4 +54,4 @@ def SmoothnessOperator(domain, strength=1., logarithmic=True, space=None):
if
strength
==
0.
:
return
ScalingOperator
(
0.
,
domain
)
laplace
=
LaplaceOperator
(
domain
,
logarithmic
=
logarithmic
,
space
=
space
)
return
(
strength
**
2
)
*
laplace
.
adjoint
.
chain
(
laplace
)
return
(
strength
**
2
)
*
laplace
.
adjoint
(
laplace
)
nifty5/utilities.py
View file @
c3ed466f
...
...
@@ -23,6 +23,8 @@ from itertools import product
import
numpy
as
np
from
future.utils
import
with_metaclass
import
pyfftw
from
pyfftw.interfaces.numpy_fft
import
rfftn
,
fftn
from
.compat
import
*
...
...
@@ -201,9 +203,11 @@ _fft_extra_args = dict(planner_effort='FFTW_ESTIMATE')
def
fft_prep
():
import
pyfftw
pyfftw
.
interfaces
.
cache
.
enable
()
pyfftw
.
interfaces
.
cache
.
set_keepalive_time
(
1000.
)
if
not
fft_prep
.
_initialized
:
pyfftw
.
interfaces
.
cache
.
enable
()
pyfftw
.
interfaces
.
cache
.
set_keepalive_time
(
1000.
)
fft_prep
.
_initialized
=
True
fft_prep
.
_initialized
=
False
def
hartley
(
a
,
axes
=
None
):
...
...
@@ -214,7 +218,6 @@ def hartley(a, axes=None):
if
iscomplextype
(
a
.
dtype
):
raise
TypeError
(
"Hartley transform requires real-valued arrays."
)
from
pyfftw.interfaces.numpy_fft
import
rfftn
tmp
=
rfftn
(
a
,
axes
=
axes
,
threads
=
nthreads
(),
**
_fft_extra_args
)
def
_fill_array
(
tmp
,
res
,
axes
):
...
...
@@ -258,7 +261,6 @@ def my_fftn_r2c(a, axes=None):
if
iscomplextype
(
a
.
dtype
):
raise
TypeError
(
"Transform requires real-valued input arrays."
)
from
pyfftw.interfaces.numpy_fft
import
rfftn
tmp
=
rfftn
(
a
,
axes
=
axes
,
threads
=
nthreads
(),
**
_fft_extra_args
)
def
_fill_complex_array
(
tmp
,
res
,
axes
):
...
...
@@ -293,7 +295,6 @@ def my_fftn_r2c(a, axes=None):
def
my_fftn
(
a
,
axes
=
None
):
from
pyfftw.interfaces.numpy_fft
import
fftn
return
fftn
(
a
,
axes
=
axes
,
**
_fft_extra_args
)
...
...
test/test_energies/test_map.py
View file @
c3ed466f
...
...
@@ -56,18 +56,18 @@ class Energy_Tests(unittest.TestCase):
def
d_model
():
if
nonlinearity
==
""
:
return
R
.
chain
(
ht
.
chain
(
ift
.
makeOp
(
A
)))
return
R
(
ht
(
ift
.
makeOp
(
A
)))
else
:
tmp
=
ht
.
chain
(
ift
.
makeOp
(
A
))
tmp
=
ht
(
ift
.
makeOp
(
A
))
nonlin
=
getattr
(
tmp
,
nonlinearity
)()
return
R
.
chain
(
nonlin
)
return
R
(
nonlin
)
d
=
d_model
()(
xi0
)
+
n
if
noise
==
1
:
N
=
None
energy
=
ift
.
GaussianEnergy
(
d
,
N
)
.
chain
(
d_model
())
energy
=
ift
.
GaussianEnergy
(
d
,
N
)(
d_model
())
if
nonlinearity
==
""
:
ift
.
extra
.
check_value_gradient_metric_consistency
(
energy
,
xi0
,
ntries
=
10
)
...
...
test/test_field.py
View file @
c3ed466f
...
...
@@ -66,7 +66,7 @@ class Test_Functionality(unittest.TestCase):
op1
=
ift
.
create_power_operator
((
space1
,
space2
),
_spec1
,
0
)
op2
=
ift
.
create_power_operator
((
space1
,
space2
),
_spec2
,
1
)
opfull
=
op2
.
chain
(
op1
)
opfull
=
op2
(
op1
)
samples
=
500
sc1
=
ift
.
StatCalculator
()
...
...
@@ -94,7 +94,7 @@ class Test_Functionality(unittest.TestCase):
S_1
=
ift
.
create_power_operator
((
space1
,
space2
),
_spec1
,
0
)
S_2
=
ift
.
create_power_operator
((
space1
,
space2
),
_spec2
,
1
)
S_full
=
S_2
.
chain
(
S_1
)
S_full
=
S_2
(
S_1
)
samples
=
500
sc1
=
ift
.
StatCalculator
()
...
...
test/test_models/test_model_gradients.py
View file @
c3ed466f
...
...
@@ -71,16 +71,16 @@ class Model_Tests(unittest.TestCase):
model
=
ift
.
FieldAdapter
(
dom
,
"s1"
)
*
3.
pos
=
ift
.
from_random
(
"normal"
,
dom
)
ift
.
extra
.
check_value_gradient_consistency
(
model
,
pos
)
model
=
ift
.
ScalingOperator
(
2.456
,
space
)
.
chain
(
model
=
ift
.
ScalingOperator
(
2.456
,
space
)(
ift
.
FieldAdapter
(
dom
,
"s1"
)
*
ift
.
FieldAdapter
(
dom
,
"s2"
))
pos
=
ift
.
from_random
(
"normal"
,
dom
)
ift
.
extra
.
check_value_gradient_consistency
(
model
,
pos
)
model
=
ift
.
positive_tanh
(
ift
.
ScalingOperator
(
2.456
,
space
)
.
chain
(
model
=
ift
.
positive_tanh
(
ift
.
ScalingOperator
(
2.456
,
space
)(
ift
.
FieldAdapter
(
dom
,
"s1"
)
*
ift
.
FieldAdapter
(
dom
,
"s2"
)))
pos
=
ift
.
from_random
(
"normal"
,
dom
)
ift
.
extra
.
check_value_gradient_consistency
(
model
,
pos
)
if
isinstance
(
space
,
ift
.
RGSpace
):
model
=
ift
.
FFTOperator
(
space
)
.
chain
(
model
=
ift
.
FFTOperator
(
space
)(
ift
.
FieldAdapter
(
dom
,
"s1"
)
*
ift
.
FieldAdapter
(
dom
,
"s2"
))
pos
=
ift
.
from_random
(
"normal"
,
dom
)
ift
.
extra
.
check_value_gradient_consistency
(
model
,
pos
)
...
...
test/test_multi_field.py
View file @
c3ed466f
...
...
@@ -40,7 +40,7 @@ class Test_Functionality(unittest.TestCase):
def
test_blockdiagonal
(
self
):
op
=
ift
.
BlockDiagonalOperator
(
dom
,
(
ift
.
ScalingOperator
(
20.
,
dom
[
"d1"
]),))
op2
=
op
.
chain
(
op
)
op2
=
op
(
op
)
ift
.
extra
.
consistency_check
(
op2
)
assert_equal
(
type
(
op2
),
ift
.
BlockDiagonalOperator
)
f1
=
op2
(
ift
.
full
(
dom
,
1
))
...
...
test/test_operators/test_adjoint.py
View file @
c3ed466f
...
...
@@ -53,7 +53,7 @@ class Consistency_Tests(unittest.TestCase):
dtype
=
dtype
))
op
=
ift
.
SandwichOperator
.
make
(
a
,
b
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
op
=
a
.
chain
(
b
)
op
=
a
(
b
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
op
=
a
+
b
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
...
...
test/test_operators/test_composed_operator.py
View file @
c3ed466f
...
...
@@ -37,7 +37,7 @@ class ComposedOperator_Tests(unittest.TestCase):
op1
=
ift
.
DiagonalOperator
(
diag1
,
cspace
,
spaces
=
(
0
,))
op2
=
ift
.
DiagonalOperator
(
diag2
,
cspace
,
spaces
=
(
1
,))
op
=
op2
.
chain
(
op1
)
op
=
op2
(
op1
)
rand1
=
ift
.
Field
.
from_random
(
'normal'
,
domain
=
(
space1
,
space2
))
rand2
=
ift
.
Field
.
from_random
(
'normal'
,
domain
=
(
space1
,
space2
))
...
...
@@ -54,7 +54,7 @@ class ComposedOperator_Tests(unittest.TestCase):
op1
=
ift
.
DiagonalOperator
(
diag1
,
cspace
,
spaces
=
(
0
,))
op2
=
ift
.
DiagonalOperator
(
diag2
,
cspace
,
spaces
=
(
1
,))
op
=
op2
.
chain
(
op1
)
op
=
op2
(
op1
)
rand1
=
ift
.
Field
.
from_random
(
'normal'
,
domain
=
(
space1
,
space2
))
tt1
=
op
.
inverse_times
(
op
.
times
(
rand1
))
...
...
@@ -75,8 +75,7 @@ class ComposedOperator_Tests(unittest.TestCase):
def
test_chain
(
self
,
space
):
op1
=
ift
.
makeOp
(
ift
.
Field
.
full
(
space
,
2.
))
op2
=
3.
full_op
=
(
op1
.
chain
(
op2
).
chain
(
op2
).
chain
(
op1
).
chain
(
op1
).
chain
(
op1
).
chain
(
op2
))
full_op
=
op1
(
op2
)(
op2
)(
op1
)(
op1
)(
op1
)(
op2
)
x
=
ift
.
Field
.
full
(
space
,
1.
)
res
=
full_op
(
x
)
assert_equal
(
isinstance
(
full_op
,
ift
.
DiagonalOperator
),
True
)
...
...
@@ -86,7 +85,7 @@ class ComposedOperator_Tests(unittest.TestCase):
def
test_mix
(
self
,
space
):
op1
=
ift
.
makeOp
(
ift
.
Field
.
full
(
space
,
2.
))
op2
=
3.
full_op
=
op1
.
chain
(
op2
+
op2
).
chain
(
op1
).
chain
(
op1
)
-
op1
.
chain
(
op2
)
full_op
=
op1
(
op2
+
op2
)(
op1
)(
op1
)
-
op1
(
op2
)
x
=
ift
.
Field
.
full
(
space
,
1.
)
res
=
full_op
(
x
)
assert_equal
(
isinstance
(
full_op
,
ift
.
DiagonalOperator
),
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment