Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
31932e59
Commit
31932e59
authored
May 19, 2020
by
Philipp Arras
Browse files
Merge remote-tracking branch 'origin/NIFTy_6' into integration_operator
parents
83a7a4fc
4d8c1460
Pipeline
#75200
passed with stages
in 8 minutes and 23 seconds
Changes
23
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
ChangeLog
View file @
31932e59
Changes since NIFTy 5:
Minimum Python version increased to 3.6
=======================================
New operators
=============
In addition to the below changes, the following operators were introduced:
* UniformOperator: Transforms a Gaussian into a uniform distribution
* VariableCovarianceGaussianEnergy: Energy operator for inferring covariances
* MultiLinearEinsum: Multi-linear version of numpy's einsum with derivates
* LinearEinsum: Linear version of numpy's einsum with one free field
* PartialConjugate: Conjugates parts of a multi-field
* SliceOperator: Geometry preserving mask operator
* SplitOperator: Splits a single field into a multi-field
FFT convention adjusted
=======================
...
...
README.md
View file @
31932e59
...
...
@@ -45,7 +45,7 @@ Installation
### Requirements
-
[
Python 3
](
https://www.python.org/
)
(
3.
5
.x
or later)
-
[
Python 3
](
https://www.python.org/
)
(
3.
6
.x
or later)
-
[
SciPy
](
https://www.scipy.org/
)
Optional dependencies:
...
...
nifty6/__init__.py
View file @
31932e59
...
...
@@ -25,6 +25,7 @@ from .operators.adder import Adder
from
.operators.diagonal_operator
import
DiagonalOperator
from
.operators.distributors
import
DOFDistributor
,
PowerDistributor
from
.operators.domain_tuple_field_inserter
import
DomainTupleFieldInserter
from
.operators.einsum
import
LinearEinsum
,
MultiLinearEinsum
from
.operators.contraction_operator
import
ContractionOperator
,
IntegrationOperator
from
.operators.linear_interpolation
import
LinearInterpolator
from
.operators.endomorphic_operator
import
EndomorphicOperator
...
...
@@ -38,12 +39,13 @@ from .operators.regridding_operator import RegriddingOperator
from
.operators.sampling_enabler
import
SamplingEnabler
,
SamplingDtypeSetter
from
.operators.sandwich_operator
import
SandwichOperator
from
.operators.scaling_operator
import
ScalingOperator
from
.operators.selection_operators
import
SliceOperator
,
SplitOperator
from
.operators.block_diagonal_operator
import
BlockDiagonalOperator
from
.operators.outer_product_operator
import
OuterProduct
from
.operators.simple_linear_operators
import
(
VdotOperator
,
ConjugationOperator
,
Realizer
,
FieldAdapter
,
ducktape
,
GeometryRemover
,
NullOperator
,
M
atrix
P
roduct
O
perator
,
PartialExtractor
,
SwitchSpaces
Operator
)
VdotOperator
,
ConjugationOperator
,
Realizer
,
FieldAdapter
,
ducktape
,
GeometryRemover
,
NullOperator
,
PartialExtractor
)
from
.operators.m
atrix
_p
roduct
_o
perator
import
MatrixProduct
Operator
from
.operators.value_inserter
import
ValueInserter
from
.operators.energy_operators
import
(
EnergyOperator
,
GaussianEnergy
,
PoissonianEnergy
,
InverseGammaLikelihood
,
...
...
nifty6/library/correlated_fields.py
View file @
31932e59
...
...
@@ -61,7 +61,7 @@ def _lognormal_moments(mean, sig, N=0):
if
not
np
.
all
(
sig
>
0
):
raise
ValueError
(
"sig must be greater 0; got {!r}"
.
format
(
sig
))
logsig
=
np
.
sqrt
(
np
.
log
((
sig
/
mean
)
**
2
+
1
))
logsig
=
np
.
sqrt
(
np
.
log
1p
((
sig
/
mean
)
**
2
))
logmean
=
np
.
log
(
mean
)
-
logsig
**
2
/
2
return
logmean
,
logsig
...
...
nifty6/multi_domain.py
View file @
31932e59
...
...
@@ -22,18 +22,27 @@ from .utilities import frozendict, indent
class
MultiDomain
(
object
):
"""A tuple of domains corresponding to a direct sum.
This class is the domain of the direct sum of fields defined
on (possibly different) domains. To make an instance
of this class, call `MultiDomain.make(inp)`.
This class is the domain of the direct sum of fields defined on (possibly
different) domains. To make an instance of this class, call
`MultiDomain.make(inp)`.
Notes
-----
For consistency and to be independent of the order of insertion, the keys
within a multi-domain are sorted. Hence, renaming a domain may result in it
being placed at a different index within a multi-domain. This is especially
important if a sequence of, e.g., random numbers is distributed sequentially
over a multi-domain. In this example, ordering keys differently will change
the resulting :class:`MultiField`.
"""
_domainCache
=
{}
def
__init__
(
self
,
d
i
ct
,
_callingfrommake
=
False
):
def
__init__
(
self
,
dct
,
_callingfrommake
=
False
):
if
not
_callingfrommake
:
raise
NotImplementedError
(
'To create a MultiDomain call `MultiDomain.make()`.'
)
self
.
_keys
=
tuple
(
sorted
(
d
i
ct
.
keys
()))
self
.
_domains
=
tuple
(
d
i
ct
[
key
]
for
key
in
self
.
_keys
)
self
.
_keys
=
tuple
(
sorted
(
dct
.
keys
()))
self
.
_domains
=
tuple
(
dct
[
key
]
for
key
in
self
.
_keys
)
self
.
_idx
=
frozendict
({
key
:
i
for
i
,
key
in
enumerate
(
self
.
_keys
)})
@
staticmethod
...
...
nifty6/multi_field.py
View file @
31932e59
...
...
@@ -102,6 +102,29 @@ class MultiField(Operator):
@
staticmethod
def
from_random
(
random_type
,
domain
,
dtype
=
np
.
float64
,
**
kwargs
):
"""Draws a random multi-field with the given parameters.
Parameters
----------
random_type : 'pm1', 'normal', or 'uniform'
The random distribution to use.
domain : DomainTuple
The domain of the output random Field.
dtype : type
The datatype of the output random Field.
Returns
-------
MultiField
The newly created :class:`MultiField`.
Notes
-----
The individual fields within this multi-field will be drawn in alphabetical
order of the multi-field's domain keys. As a consequence, renaming these
keys may cause the multi-field to be filled with different random numbers,
even for the same initial RNG state.
"""
domain
=
MultiDomain
.
make
(
domain
)
if
isinstance
(
dtype
,
dict
):
dtype
=
{
kk
:
np
.
dtype
(
dt
)
for
kk
,
dt
in
dtype
.
items
()}
...
...
nifty6/operators/einsum.py
0 → 100644
View file @
31932e59
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
# Authors: Gordian Edenhofer, Philipp Frank
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import
string
import
numpy
as
np
from
..domain_tuple
import
DomainTuple
from
..field
import
Field
from
..linearization
import
Linearization
from
..multi_domain
import
MultiDomain
from
..multi_field
import
MultiField
from
.linear_operator
import
LinearOperator
from
.operator
import
Operator
class
MultiLinearEinsum
(
Operator
):
"""Multi-linear Einsum operator with corresponding derivates
Parameters
----------
domain : MultiDomain or dict{name: DomainTuple}
The operator's input domain.
subscripts : str
The subscripts which is passed to einsum.
key_order: tuple of str, optional
The order of the keys in the multi-field. If not specified, defaults to
the order of the keys in the multi-field.
static_mf: MultiField or dict{name: Field}, optional
A dictionary like type from which Fields are to be taken if the key from
`key_order` is not part of the `domain`. Fields in this object are
supposed to be static as they will not appear as FieldAdapter in the
Linearization.
optimize: bool, String or List, optional
Parameter passed on to einsum_path.
Notes
-----
By convention :class:`MultiLinearEinsum` only performs operations with
lower indices. Therefore no complex conjugation is performed on complex
inputs. To achieve operations with upper/lower indices use
:class:`PartialConjugate` before applying this operator.
"""
def
__init__
(
self
,
domain
,
subscripts
,
key_order
=
None
,
static_mf
=
None
,
optimize
=
'optimal'
):
self
.
_domain
=
MultiDomain
.
make
(
domain
)
if
key_order
is
None
:
self
.
_key_order
=
tuple
(
self
.
_domain
.
keys
())
else
:
self
.
_key_order
=
key_order
if
static_mf
is
not
None
and
key_order
is
None
:
ve
=
"`key_order` mus be specified if additional fields are munged"
raise
ValueError
(
ve
)
self
.
_stat_mf
=
static_mf
iss
,
oss
,
*
rest
=
subscripts
.
split
(
"->"
)
iss_spl
=
iss
.
split
(
","
)
len_consist
=
len
(
self
.
_key_order
)
==
len
(
iss_spl
)
sscr_consist
=
all
(
o
in
iss
for
o
in
oss
)
if
rest
or
not
sscr_consist
or
","
in
oss
or
not
len_consist
:
raise
ValueError
(
f
"invalid subscripts specified; got
{
subscripts
}
"
)
ve
=
f
"invalid order of keys
{
self
.
_key_order
}
for subscripts
{
subscripts
}
"
shapes
,
numpy_subscripts
,
subscriptmap
=
{},
''
,
{}
alphabet
=
list
(
string
.
ascii_lowercase
)[::
-
1
]
for
k
,
ss
in
zip
(
self
.
_key_order
,
iss_spl
):
dom
=
self
.
_domain
[
k
]
if
k
in
self
.
_domain
.
keys
(
)
else
self
.
_stat_mf
[
k
].
domain
if
len
(
dom
)
!=
len
(
ss
):
raise
ValueError
(
ve
)
for
i
,
a
in
enumerate
(
list
(
ss
)):
if
a
not
in
subscriptmap
.
keys
():
subscriptmap
[
a
]
=
[
alphabet
.
pop
()
for
_
in
range
(
len
(
dom
[
i
].
shape
))]
numpy_subscripts
+=
''
.
join
(
subscriptmap
[
a
])
numpy_subscripts
+=
','
shapes
[
k
]
=
dom
.
shape
numpy_subscripts
=
numpy_subscripts
[:
-
1
]
+
'->'
dom_sscr
=
dict
(
zip
(
self
.
_key_order
,
iss_spl
))
tgt
=
[]
for
o
in
oss
:
k_hit
=
tuple
(
k
for
k
,
sscr
in
dom_sscr
.
items
()
if
o
in
sscr
)[
0
]
dom_k_idx
=
dom_sscr
[
k_hit
].
index
(
o
)
if
k_hit
in
self
.
_domain
.
keys
():
tgt
+=
[
self
.
_domain
[
k_hit
][
dom_k_idx
]]
else
:
if
k_hit
not
in
self
.
_stat_mf
.
keys
():
ve
=
f
"
{
k_hit
}
is not in domain nor in static_mf"
raise
ValueError
(
ve
)
tgt
+=
[
self
.
_stat_mf
[
k_hit
].
domain
[
dom_k_idx
]]
numpy_subscripts
+=
''
.
join
(
subscriptmap
[
o
])
self
.
_target
=
DomainTuple
.
make
(
tgt
)
numpy_iss
,
numpy_oss
,
*
_
=
numpy_subscripts
.
split
(
"->"
)
numpy_iss_spl
=
numpy_iss
.
split
(
","
)
self
.
_sscr_endswith
=
dict
()
self
.
_linpaths
=
dict
()
for
k
,
(
i
,
ss
)
in
zip
(
self
.
_key_order
,
enumerate
(
numpy_iss_spl
)):
left_ss_spl
=
(
*
numpy_iss_spl
[:
i
],
*
numpy_iss_spl
[
i
+
1
:],
ss
)
linpath
=
'->'
.
join
((
','
.
join
(
left_ss_spl
),
numpy_oss
))
plc
=
tuple
(
np
.
broadcast_to
(
np
.
nan
,
shapes
[
q
])
for
q
in
shapes
if
q
!=
k
)
plc
+=
(
np
.
broadcast_to
(
np
.
nan
,
shapes
[
k
]),)
self
.
_sscr_endswith
[
k
]
=
linpath
self
.
_linpaths
[
k
]
=
np
.
einsum_path
(
linpath
,
*
plc
,
optimize
=
optimize
)[
0
]
if
isinstance
(
optimize
,
list
):
path
=
optimize
else
:
plc
=
(
np
.
broadcast_to
(
np
.
nan
,
shapes
[
k
])
for
k
in
shapes
)
path
=
np
.
einsum_path
(
numpy_subscripts
,
*
plc
,
optimize
=
optimize
)[
0
]
self
.
_sscr
=
numpy_subscripts
self
.
_ein_kw
=
{
"optimize"
:
path
}
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
if
isinstance
(
x
,
Linearization
):
val
=
x
.
val
.
val
else
:
val
=
x
.
val
v
=
(
val
[
k
]
if
k
in
val
else
self
.
_stat_mf
[
k
].
val
for
k
in
self
.
_key_order
)
res
=
np
.
einsum
(
self
.
_sscr
,
*
v
,
**
self
.
_ein_kw
)
if
isinstance
(
x
,
Linearization
):
jac
=
None
for
wrt
in
self
.
domain
.
keys
():
plc
=
{
k
:
x
.
val
[
k
]
if
k
in
x
.
val
else
self
.
_stat_mf
[
k
]
for
k
in
self
.
_key_order
if
k
!=
wrt
}
mf_wo_k
=
MultiField
.
from_dict
(
plc
)
ss
=
self
.
_sscr_endswith
[
wrt
]
# Use the fact that the insertion order in a dictionary is the
# ordering of keys as to pass on `key_order`
jac_k
=
LinearEinsum
(
self
.
domain
[
wrt
],
mf_wo_k
,
ss
,
key_order
=
tuple
(
plc
.
keys
()),
optimize
=
self
.
_linpaths
[
wrt
],
_target
=
self
.
_target
,
_calling_as_lin
=
True
).
ducktape
(
wrt
)
jac
=
jac
+
jac_k
if
jac
is
not
None
else
jac_k
return
x
.
new
(
Field
.
from_raw
(
self
.
target
,
res
),
jac
)
return
Field
.
from_raw
(
self
.
target
,
res
)
class
LinearEinsum
(
LinearOperator
):
"""Linear Einsum operator with exactly one freely varying field
Parameters
----------
domain : Domain, DomainTuple or tuple of Domain
The operator's input domain.
mf : MultiField
The first part of the left-hand side of the einsum.
subscripts : str
The subscripts which is passed to einsum. Everything before the very
last scripts before the '->' is treated as part of the fixed mulfi-
field while the last scripts are taken to correspond to the freely
varying field.
key_order: tuple of str, optional
The order of the keys in the multi-field. If not specified, defaults to
the order of the keys in the multi-field.
optimize: bool, String or List, optional
Parameter passed on to einsum_path.
Notes
-----
By convention :class:`LinearEinsum` only performs operations with
lower indices. Therefore no complex conjugation is performed on complex
inputs or mf. To achieve operations with upper/lower indices use
:class:`PartialConjugate` before applying this operator.
"""
def
__init__
(
self
,
domain
,
mf
,
subscripts
,
key_order
=
None
,
optimize
=
'optimal'
,
_target
=
None
,
_calling_as_lin
=
False
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
if
_calling_as_lin
:
self
.
_init_wo_preproc
(
mf
,
subscripts
,
key_order
,
optimize
,
_target
)
else
:
self
.
_mf
=
mf
if
key_order
is
None
:
_key_order
=
tuple
(
self
.
_mf
.
domain
.
keys
())
else
:
_key_order
=
key_order
self
.
_ein_kw
=
{
"optimize"
:
optimize
}
iss
,
oss
,
*
rest
=
subscripts
.
split
(
"->"
)
iss_spl
=
iss
.
split
(
","
)
sscr_consist
=
all
(
o
in
iss
for
o
in
oss
)
len_consist
=
len
(
_key_order
)
==
len
(
iss_spl
[:
-
1
])
if
rest
or
not
sscr_consist
or
","
in
oss
or
not
len_consist
:
raise
ValueError
(
f
"invalid subscripts specified; got
{
subscripts
}
"
)
ve
=
f
"invalid order of keys
{
_key_order
}
for subscripts
{
subscripts
}
"
shapes
,
numpy_subscripts
,
subscriptmap
=
(),
''
,
{}
alphabet
=
list
(
string
.
ascii_lowercase
)
for
k
,
ss
in
zip
(
_key_order
,
iss_spl
[:
-
1
]):
dom
=
self
.
_mf
[
k
].
domain
if
len
(
dom
)
!=
len
(
ss
):
raise
ValueError
(
ve
)
for
i
,
a
in
enumerate
(
list
(
ss
)):
if
a
not
in
subscriptmap
.
keys
():
subscriptmap
[
a
]
=
[
alphabet
.
pop
()
for
_
in
range
(
len
(
dom
[
i
].
shape
))]
numpy_subscripts
+=
''
.
join
(
subscriptmap
[
a
])
numpy_subscripts
+=
','
shapes
+=
(
dom
.
shape
,)
if
len
(
self
.
_domain
)
!=
len
(
iss_spl
[
-
1
]):
raise
ValueError
(
ve
)
for
i
,
a
in
enumerate
(
list
(
iss_spl
[
-
1
])):
if
a
not
in
subscriptmap
.
keys
():
subscriptmap
[
a
]
=
[
alphabet
.
pop
()
for
_
in
range
(
len
(
self
.
_domain
[
i
].
shape
))]
numpy_subscripts
+=
''
.
join
(
subscriptmap
[
a
])
shapes
+=
(
self
.
_domain
.
shape
,)
numpy_subscripts
+=
'->'
dom_sscr
=
dict
(
zip
(
_key_order
,
iss_spl
[:
-
1
]))
dom_sscr
[
id
(
self
)]
=
iss_spl
[
-
1
]
tgt
=
[]
for
o
in
oss
:
k_hit
=
tuple
(
k
for
k
,
sscr
in
dom_sscr
.
items
()
if
o
in
sscr
)[
0
]
dom_k_idx
=
dom_sscr
[
k_hit
].
index
(
o
)
if
k_hit
in
_key_order
:
tgt
+=
[
self
.
_mf
.
domain
[
k_hit
][
dom_k_idx
]]
else
:
assert
k_hit
==
id
(
self
)
tgt
+=
[
self
.
_domain
[
dom_k_idx
]]
numpy_subscripts
+=
""
.
join
(
subscriptmap
[
o
])
_target
=
DomainTuple
.
make
(
tgt
)
self
.
_sscr
=
numpy_subscripts
if
isinstance
(
optimize
,
list
):
path
=
optimize
else
:
plc
=
(
np
.
broadcast_to
(
np
.
nan
,
shp
)
for
shp
in
shapes
)
path
=
np
.
einsum_path
(
numpy_subscripts
,
*
plc
,
optimize
=
optimize
)[
0
]
self
.
_init_wo_preproc
(
mf
,
numpy_subscripts
,
_key_order
,
path
,
_target
)
def
_init_wo_preproc
(
self
,
mf
,
subscripts
,
keyorder
,
optimize
,
target
):
self
.
_ein_kw
=
{
"optimize"
:
optimize
}
self
.
_mf
=
mf
self
.
_sscr
=
subscripts
self
.
_key_order
=
keyorder
self
.
_target
=
target
iss
,
oss
,
*
_
=
subscripts
.
split
(
"->"
)
iss_spl
=
iss
.
split
(
","
)
adj_iss
=
","
.
join
((
","
.
join
(
iss_spl
[:
-
1
]),
oss
))
self
.
_adj_sscr
=
"->"
.
join
((
adj_iss
,
iss_spl
[
-
1
]))
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
dom
,
ss
,
mf
=
self
.
target
,
self
.
_sscr
,
self
.
_mf
else
:
dom
,
ss
,
mf
=
self
.
domain
,
self
.
_adj_sscr
,
self
.
_mf
.
conjugate
()
res
=
np
.
einsum
(
ss
,
*
(
mf
[
k
].
val
for
k
in
self
.
_key_order
),
x
.
val
,
**
self
.
_ein_kw
)
return
Field
.
from_raw
(
dom
,
res
)
nifty6/operators/matrix_product_operator.py
0 → 100644
View file @
31932e59
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from
..domain_tuple
import
DomainTuple
from
..field
import
Field
from
.endomorphic_operator
import
EndomorphicOperator
from
..
import
utilities
import
numpy
as
np
class
MatrixProductOperator
(
EndomorphicOperator
):
"""Endomorphic matrix multiplication with input field.
This operator supports scipy.sparse matrices and numpy arrays
as the matrix to be applied.
For numpy array matrices, can apply the matrix over a subspace
of the input.
If the input arrays have more than one dimension, for
scipy.sparse matrices the `flatten` keyword argument must be
set to true. This means that the input field will be flattened
before applying the matrix and reshaped to its original shape
afterwards.
Matrices are tested regarding their compatibility with the
called for application method.
Flattening and subspace application are mutually exclusive.
Parameters
----------
domain: :class:`Domain` or :class:`DomainTuple`
Domain of the operator.
If :class:`DomainTuple` it is assumed to have only one entry.
matrix: scipy.sparse matrix or numpy array
Quadratic matrix of shape `(domain.shape, domain.shape)`
(if `not flatten`) that supports `matrix.transpose()`.
If it is not a numpy array, needs to be applicable to the val
array of input fields by `matrix.dot()`.
spaces: int or tuple of int, optional
The subdomain(s) of "domain" which the operator acts on.
If None, it acts on all elements.
Only possible for numpy array matrices.
If `len(domain) > 1` and `flatten=False`, this parameter is
mandatory.
flatten: boolean, optional
Whether the input value array should be flattened before
applying the matrix and reshaped to its original shape
afterwards.
Needed for scipy.sparse matrices if `len(domain) > 1`.
"""
def
__init__
(
self
,
domain
,
matrix
,
spaces
=
None
,
flatten
=
False
):
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
self
.
_domain
=
DomainTuple
.
make
(
domain
)
mat_dim
=
len
(
matrix
.
shape
)
if
mat_dim
%
2
!=
0
or
\
matrix
.
shape
!=
(
matrix
.
shape
[:
mat_dim
//
2
]
+
matrix
.
shape
[:
mat_dim
//
2
]):
raise
ValueError
(
"Matrix must be quadratic."
)
appl_dim
=
mat_dim
//
2
# matrix application space dimension
# take shortcut for trivial case
if
spaces
is
not
None
:
if
len
(
self
.
_domain
.
shape
)
==
1
and
spaces
==
(
0
,
):
spaces
=
None
if
spaces
is
None
:
self
.
_spaces
=
None
self
.
_active_axes
=
utilities
.
my_sum
(
self
.
_domain
.
axes
)
appl_space_shape
=
self
.
_domain
.
shape
if
flatten
:
appl_space_shape
=
(
utilities
.
my_product
(
appl_space_shape
),
)
else
:
if
flatten
:
raise
ValueError
(
"Cannot flatten input AND apply to a subspace"
)
if
not
isinstance
(
matrix
,
np
.
ndarray
):
raise
ValueError
(
"Application to subspaces only supported for numpy array matrices."
)
self
.
_spaces
=
utilities
.
parse_spaces
(
spaces
,
len
(
self
.
_domain
))
appl_space_shape
=
[]
active_axes
=
[]
for
space_idx
in
spaces
:
appl_space_shape
+=
self
.
_domain
[
space_idx
].
shape
active_axes
+=
self
.
_domain
.
axes
[
space_idx
]
appl_space_shape
=
tuple
(
appl_space_shape
)
self
.
_active_axes
=
tuple
(
active_axes
)
self
.
_mat_last_n
=
tuple
([
-
appl_dim
+
i
for
i
in
range
(
appl_dim
)])
self
.
_mat_first_n
=
np
.
arange
(
appl_dim
)
# Test if the matrix and the array it will be applied to fit
if
matrix
.
shape
[:
appl_dim
]
!=
appl_space_shape
:
raise
ValueError
(
"Matrix and domain shapes are incompatible under the requested "
+
"application scheme.
\n
"
+
f
"Matrix appl shape:
{
matrix
.
shape
[
:
appl_dim
]
}
, "
+
f
"appl_space_shape:
{
appl_space_shape
}
."
)
self
.
_mat
=
matrix
self
.
_mat_tr
=
matrix
.
transpose
().
conjugate
()
self
.
_flatten
=
flatten
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
times
=
(
mode
==
self
.
TIMES
)
m
=
self
.
_mat
if
times
else
self
.
_mat_tr
if
self
.
_spaces
is
None
:
if
not
self
.
_flatten
:
res
=
m
.
dot
(
x
.
val
)
else
:
res
=
m
.
dot
(
x
.
val
.
flatten
()).
reshape
(
self
.
_domain
.
shape
)
return
Field
(
self
.
_domain
,
res
)
mat_axes
=
self
.
_mat_last_n
if
times
else
np
.
flip
(
self
.
_mat_last_n
)
move_axes
=
self
.
_mat_first_n
if
times
else
np
.
flip
(
self
.
_mat_first_n
)
res
=
np
.
tensordot
(
m
,
x
.
val
,
axes
=
(
mat_axes
,
self
.
_active_axes
))
res
=
np
.
moveaxis
(
res
,
move_axes
,
self
.
_active_axes
)
return
Field
(
self
.
_domain
,
res
)
nifty6/operators/partial_conjugate.py
0 → 100644
View file @
31932e59
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Authors: Philipp Frank
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from
.endomorphic_operator
import
EndomorphicOperator
from
..multi_domain
import
MultiDomain
from
..multi_field
import
MultiField
class
PartialConjugate
(
EndomorphicOperator
):
"""Perform partial conjugation of a :class:`MultiField`
Parameters
----------
domain : MultiDomain
The operator's input domain and output target
conjugation_keys : iterable of string
The keys of the :class:`MultiField` for which complex conjugation
should be performed.
"""
def
__init__
(
self
,
domain
,
conjugation_keys
):
if
not
isinstance
(
domain
,
MultiDomain
):
raise
ValueError
(
"MultiDomain expected!"
)
indom
=
(
key
in
domain
.
keys
()
for
key
in
conjugation_keys
)
if
sum
(
indom
)
!=
len
(
conjugation_keys
):
raise
ValueError
(
"conjugation_keys not in domain!"
)
self
.
_domain
=
domain
self
.
_conjugation_keys
=
conjugation_keys
self
.
_capabilities
=
self
.
_all_ops
def
apply
(
self
,
x
,
mode
):