Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
83b0eccc
Commit
83b0eccc
authored
Aug 02, 2018
by
Martin Reinecke
Browse files
move Operator
parent
73a85004
Changes
11
Hide whitespace changes
Inline
Side-by-side
nifty5/__init__.py
View file @
83b0eccc
...
...
@@ -16,6 +16,7 @@ from .domains.log_rg_space import LogRGSpace
from
.domain_tuple
import
DomainTuple
from
.field
import
Field
from
.operators.operator
import
Operator
from
.operators.central_zero_padder
import
CentralZeroPadder
from
.operators.diagonal_operator
import
DiagonalOperator
from
.operators.dof_distributor
import
DOFDistributor
...
...
@@ -92,7 +93,6 @@ from .energies.kl import SampledKullbachLeiblerDivergence
from
.energies.hamiltonian
import
Hamiltonian
from
.energies.energy_adapter
import
EnergyAdapter
from
.operator
import
Operator
from
.linearization
import
Linearization
# We deliberately don't set __all__ here, because we don't want people to do a
...
...
nifty5/energies/hamiltonian.py
View file @
83b0eccc
...
...
@@ -19,7 +19,7 @@
from
__future__
import
absolute_import
,
division
,
print_function
from
..compat
import
*
from
..operator
import
Operator
from
..operator
s.operator
import
Operator
from
..library.gaussian_energy
import
GaussianEnergy
from
..operators.sampling_enabler
import
SamplingEnabler
...
...
nifty5/energies/kl.py
View file @
83b0eccc
...
...
@@ -19,7 +19,7 @@
from
__future__
import
absolute_import
,
division
,
print_function
from
..compat
import
*
from
..operator
import
Operator
from
..operator
s.operator
import
Operator
from
..utilities
import
my_sum
...
...
nifty5/library/amplitude_model.py
View file @
83b0eccc
...
...
@@ -27,7 +27,7 @@ from ..field import Field
from
..multi.multi_field
import
MultiField
from
..multi.multi_domain
import
MultiDomain
from
..sugar
import
makeOp
,
sqrt
from
..operator
import
Operator
from
..operator
s.operator
import
Operator
def
_ceps_kernel
(
dof_space
,
k
,
a
,
k0
):
...
...
nifty5/library/bernoulli_energy.py
View file @
83b0eccc
...
...
@@ -19,7 +19,7 @@
from
__future__
import
absolute_import
,
division
,
print_function
from
..compat
import
*
from
..operator
import
Operator
from
..operator
s.operator
import
Operator
from
..operators.sandwich_operator
import
SandwichOperator
from
..sugar
import
makeOp
...
...
nifty5/library/correlated_fields.py
View file @
83b0eccc
...
...
@@ -25,7 +25,7 @@ from ..multi.multi_domain import MultiDomain
from
..operators.domain_distributor
import
DomainDistributor
from
..operators.harmonic_transform_operator
import
HarmonicTransformOperator
from
..operators.power_distributor
import
PowerDistributor
from
..operator
import
Operator
from
..operator
s.operator
import
Operator
class
CorrelatedField
(
Operator
):
...
...
nifty5/library/gaussian_energy.py
View file @
83b0eccc
...
...
@@ -19,8 +19,9 @@
from
__future__
import
absolute_import
,
division
,
print_function
from
..compat
import
*
from
..operator
import
Operator
from
..operator
s.operator
import
Operator
from
..operators.sandwich_operator
import
SandwichOperator
from
..domain_tuple
import
DomainTuple
class
GaussianEnergy
(
Operator
):
...
...
@@ -28,6 +29,7 @@ class GaussianEnergy(Operator):
super
(
GaussianEnergy
,
self
).
__init__
()
self
.
_mean
=
mean
self
.
_icov
=
None
if
covariance
is
None
else
covariance
.
inverse
self
.
_target
=
DomainTuple
.
scalar_domain
()
def
__call__
(
self
,
x
):
residual
=
x
if
self
.
_mean
is
None
else
x
-
self
.
_mean
...
...
nifty5/library/poissonian_energy.py
View file @
83b0eccc
...
...
@@ -21,7 +21,7 @@ from __future__ import absolute_import, division, print_function
from
numpy
import
inf
,
isnan
from
..compat
import
*
from
..operator
import
Operator
from
..operator
s.operator
import
Operator
from
..operators.sandwich_operator
import
SandwichOperator
from
..sugar
import
makeOp
...
...
nifty5/operator.py
deleted
100644 → 0
View file @
73a85004
from
__future__
import
absolute_import
,
division
,
print_function
from
.compat
import
*
from
.utilities
import
NiftyMetaBase
class
Operator
(
NiftyMetaBase
()):
"""Transforms values living on one domain into values living on another
domain, and can also provide the Jacobian.
"""
def
chain
(
self
,
x
):
if
not
callable
(
x
):
raise
TypeError
(
"callable needed"
)
ops1
=
self
.
_ops
if
isinstance
(
self
,
OpChain
)
else
(
self
,)
ops2
=
x
.
_ops
if
isinstance
(
x
,
OpChain
)
else
(
x
,)
return
OpChain
(
ops1
+
ops2
)
def
__call__
(
self
,
x
):
"""Returns transformed x
Parameters
----------
x : Linearization
input
Returns
-------
Linearization
output
"""
raise
NotImplementedError
class
OpChain
(
Operator
):
def
__init__
(
self
,
ops
):
self
.
_ops
=
tuple
(
ops
)
def
__call__
(
self
,
x
):
for
op
in
reversed
(
self
.
_ops
):
x
=
op
(
x
)
return
x
nifty5/operators/linear_operator.py
View file @
83b0eccc
...
...
@@ -23,7 +23,7 @@ import abc
import
numpy
as
np
from
..compat
import
*
from
.
.operator
import
Operator
from
.operator
import
Operator
class
LinearOperator
(
Operator
):
...
...
@@ -86,21 +86,6 @@ class LinearOperator(Operator):
def
__init__
(
self
):
pass
@
abc
.
abstractproperty
def
domain
(
self
):
# FIXME Adopt documentation to MultiDomains
"""DomainTuple : the operator's input domain
The domain on which the Operator's input Field lives."""
raise
NotImplementedError
@
abc
.
abstractproperty
def
target
(
self
):
"""DomainTuple : the operator's output domain
The domain on which the Operator's output Field lives."""
raise
NotImplementedError
def
_flip_modes
(
self
,
trafo
):
from
.operator_adapter
import
OperatorAdapter
return
self
if
trafo
==
0
else
OperatorAdapter
(
self
,
trafo
)
...
...
nifty5/operators/operator.py
0 → 100644
View file @
83b0eccc
from
__future__
import
absolute_import
,
division
,
print_function
import
abc
from
..compat
import
*
from
..utilities
import
NiftyMetaBase
class
Operator
(
NiftyMetaBase
()):
"""Transforms values living on one domain into values living on another
domain, and can also provide the Jacobian.
"""
def
domain
(
self
):
"""DomainTuple or MultiDomain : the operator's input domain
The domain on which the Operator's input Field lives."""
return
self
.
_domain
def
target
(
self
):
"""DomainTuple or MultiDomain : the operator's output domain
The domain on which the Operator's output Field lives."""
return
self
.
_target
def
__matmul__
(
self
,
x
):
if
not
isinstance
(
x
,
Operator
):
return
NotImplemented
return
OpChain
.
make
((
self
,
x
))
ops1
=
self
.
_ops
if
isinstance
(
self
,
OpChain
)
else
(
self
,)
ops2
=
x
.
_ops
if
isinstance
(
x
,
OpChain
)
else
(
x
,)
return
OpChain
(
ops1
+
ops2
)
def
chain
(
self
,
x
):
res
=
self
.
__matmul__
(
x
)
if
res
==
NotImplemented
:
raise
TypeError
(
"operator expected"
)
return
res
def
__call__
(
self
,
x
):
"""Returns transformed x
Parameters
----------
x : Linearization
input
Returns
-------
Linearization
output
"""
raise
NotImplementedError
class
_CombinedOperator
(
Operator
):
def
__init__
(
self
,
ops
,
_callingfrommake
=
False
):
if
not
_callingfrommake
:
raise
NotImplementedError
self
.
_ops
=
tuple
(
ops
)
@
classmethod
def
unpack
(
cls
,
ops
,
res
):
for
op
in
ops
:
if
isinstance
(
op
,
cls
):
res
=
cls
.
unpack
(
op
,
res
)
else
:
res
=
res
+
[
op
]
return
res
@
classmethod
def
make
(
cls
,
ops
):
res
=
cls
.
unpack
(
ops
,
[])
if
len
(
res
)
==
1
:
return
res
[
0
]
return
cls
(
res
,
_callingfrommake
=
True
)
class
_OpChain
(
_CombinedOperator
):
def
__init__
(
self
,
ops
,
_callingfrommake
=
False
):
super
(
_OpChain
,
self
).
__init__
(
ops
,
_callingfrommake
)
self
.
_domain
=
self
.
_ops
[
-
1
].
domain
self
.
_target
=
self
.
_ops
[
0
].
target
def
__call__
(
self
,
x
):
for
op
in
reversed
(
self
.
_ops
):
x
=
op
(
x
)
return
x
class
_OpProd
(
_CombinedOperator
):
def
__init__
(
self
,
ops
,
_callingfrommake
=
False
):
super
(
_OpProd
,
self
).
__init__
(
ops
,
_callingfrommake
)
self
.
_domain
=
self
.
_ops
[
0
].
domain
self
.
_target
=
self
.
_ops
[
0
].
target
def
__call__
(
self
,
x
):
return
my_prod
(
map
(
lambda
op
:
op
(
x
)
for
op
in
self
.
_ops
))
class
_OpSum
(
_CombinedOperator
):
def
__init__
(
self
,
ops
,
_callingfrommake
=
False
):
super
(
_OpSum
,
self
).
__init__
(
ops
,
_callingfrommake
)
self
.
_domain
=
domain_union
([
op
.
domain
for
op
in
self
.
_ops
])
self
.
_target
=
domain_union
([
op
.
target
for
op
in
self
.
_ops
])
def
__call__
(
self
,
x
):
raise
NotImplementedError
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment