Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
af46206a
Commit
af46206a
authored
Apr 08, 2020
by
Martin Reinecke
Browse files
domain -> target, round 1
parent
d19a916d
Pipeline
#72562
failed with stages
in 44 seconds
Changes
27
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty6/extra.py
View file @
af46206a
...
...
@@ -92,22 +92,22 @@ def _actual_domain_check_linear(op, domain_dtype=None, inp=None):
inp
=
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
)
elif
inp
is
None
:
raise
ValueError
(
'Need to specify either dtype or inp'
)
assert_
(
inp
.
domain
is
op
.
domain
)
assert_
(
op
(
inp
).
domain
is
op
.
target
)
assert_
(
inp
.
target
is
op
.
domain
)
assert_
(
op
(
inp
).
target
is
op
.
target
)
def
_actual_domain_check_nonlinear
(
op
,
loc
):
assert
isinstance
(
loc
,
(
Field
,
MultiField
))
assert_
(
loc
.
domain
is
op
.
domain
)
assert_
(
loc
.
target
is
op
.
domain
)
for
wm
in
[
False
,
True
]:
lin
=
Linearization
.
make_var
(
loc
,
wm
)
reslin
=
op
(
lin
)
assert_
(
lin
.
domain
is
op
.
domain
)
assert_
(
lin
.
target
is
op
.
domain
)
assert_
(
lin
.
fld
.
domain
is
lin
.
domain
)
assert_
(
lin
.
fld
.
target
is
lin
.
target
)
assert_
(
reslin
.
domain
is
op
.
domain
)
assert_
(
reslin
.
target
is
op
.
target
)
assert_
(
reslin
.
fld
.
domain
is
reslin
.
target
)
assert_
(
reslin
.
fld
.
target
is
reslin
.
target
)
assert_
(
reslin
.
target
is
op
.
target
)
assert_
(
reslin
.
jac
.
domain
is
reslin
.
domain
)
assert_
(
reslin
.
jac
.
target
is
reslin
.
target
)
...
...
@@ -123,7 +123,7 @@ def _domain_check(op):
for
dd
in
[
op
.
domain
,
op
.
target
]:
if
not
isinstance
(
dd
,
(
DomainTuple
,
MultiDomain
)):
raise
TypeError
(
'The domain and the target of an operator need to'
,
'The domain and the target of an operator need to
'
'be instances of either DomainTuple or MultiDomain.'
)
...
...
@@ -220,7 +220,7 @@ def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
def
_get_acceptable_location
(
op
,
loc
,
lin
):
if
not
np
.
isfinite
(
lin
.
fld
.
s_sum
()):
raise
ValueError
(
'Initial value must be finite'
)
dir
=
from_random
(
"normal"
,
loc
.
domain
)
dir
=
from_random
(
"normal"
,
loc
.
target
)
dirder
=
lin
.
jac
(
dir
)
if
dirder
.
norm
()
==
0
:
dir
=
dir
*
(
lin
.
fld
.
norm
()
*
1e-5
)
...
...
nifty6/field.py
View file @
af46206a
...
...
@@ -174,7 +174,8 @@ class Field(Operand):
@
property
def
domain
(
self
):
"""DomainTuple : the field's domain"""
return
self
.
_domain
raise
NotImplementedError
return
None
# self._domain
@
property
def
target
(
self
):
...
...
@@ -286,13 +287,13 @@ class Field(Operand):
Returns
-------
Field
Defined on the product space of self.
domain
and x.
domain
.
Defined on the product space of self.
target
and x.
target
.
"""
if
not
isinstance
(
x
,
Field
):
raise
TypeError
(
"The multiplier must be an instance of "
+
"the Field class"
)
from
.operators.outer_product_operator
import
OuterProduct
return
OuterProduct
(
self
,
x
.
domain
)(
x
)
return
OuterProduct
(
self
,
x
.
target
)(
x
)
def
vdot
(
self
,
x
,
spaces
=
None
):
"""Computes the dot product of 'self' with x.
...
...
nifty6/library/correlated_fields.py
View file @
af46206a
...
...
@@ -110,8 +110,8 @@ def _structured_spaces(domain):
def
_total_fluctuation_realized
(
samples
):
spaces
=
_structured_spaces
(
samples
[
0
].
domain
)
co
=
ContractionOperator
(
samples
[
0
].
domain
,
spaces
)
spaces
=
_structured_spaces
(
samples
[
0
].
target
)
co
=
ContractionOperator
(
samples
[
0
].
target
,
spaces
)
size
=
co
.
domain
.
size
/
co
.
target
.
size
res
=
0.
for
s
in
samples
:
...
...
@@ -606,7 +606,7 @@ class CorrelatedFieldMaker:
@
staticmethod
def
offset_amplitude_realized
(
samples
):
spaces
=
_structured_spaces
(
samples
[
0
].
domain
)
spaces
=
_structured_spaces
(
samples
[
0
].
target
)
res
=
0.
for
s
in
samples
:
res
=
res
+
s
.
mean
(
spaces
)
**
2
...
...
@@ -621,7 +621,7 @@ class CorrelatedFieldMaker:
def
slice_fluctuation_realized
(
samples
,
space
):
"""Computes slice fluctuations from collection of field (defined in signal
space) realizations."""
spaces
=
_structured_spaces
(
samples
[
0
].
domain
)
spaces
=
_structured_spaces
(
samples
[
0
].
target
)
if
space
>=
len
(
spaces
):
raise
ValueError
(
"invalid space specified; got {!r}"
.
format
(
space
))
if
len
(
spaces
)
==
1
:
...
...
@@ -640,7 +640,7 @@ class CorrelatedFieldMaker:
def
average_fluctuation_realized
(
samples
,
space
):
"""Computes average fluctuations from collection of field (defined in signal
space) realizations."""
spaces
=
_structured_spaces
(
samples
[
0
].
domain
)
spaces
=
_structured_spaces
(
samples
[
0
].
target
)
if
space
>=
len
(
spaces
):
raise
ValueError
(
"invalid space specified; got {!r}"
.
format
(
space
))
if
len
(
spaces
)
==
1
:
...
...
@@ -649,7 +649,7 @@ class CorrelatedFieldMaker:
sub_spaces
=
set
(
spaces
)
sub_spaces
.
remove
(
space
)
# Domain containing domain[space] and domain[0] iff total_N>0
sub_dom
=
makeDomain
([
samples
[
0
].
domain
[
ind
]
sub_dom
=
makeDomain
([
samples
[
0
].
target
[
ind
]
for
ind
in
(
set
([
0
])
-
set
(
spaces
))
|
set
([
space
])])
co
=
ContractionOperator
(
sub_dom
,
len
(
sub_dom
)
-
1
)
size
=
co
.
domain
.
size
/
co
.
target
.
size
...
...
nifty6/linearization.py
View file @
af46206a
...
...
@@ -42,7 +42,7 @@ class Linearization(Operand):
def
__init__
(
self
,
fld
,
jac
,
metric
=
None
,
want_metric
=
False
):
self
.
_fld
=
fld
self
.
_jac
=
jac
if
self
.
_fld
.
domain
!=
self
.
_jac
.
target
:
if
self
.
_fld
.
target
!=
self
.
_jac
.
target
:
raise
ValueError
(
"domain mismatch"
)
self
.
_want_metric
=
want_metric
self
.
_metric
=
metric
...
...
@@ -217,8 +217,8 @@ class Linearization(Operand):
return
self
.
__mul__
(
other
)
from
.operators.outer_product_operator
import
OuterProduct
if
other
.
jac
is
None
:
return
self
.
new
(
OuterProduct
(
self
.
_fld
,
other
.
domain
)(
other
),
OuterProduct
(
self
.
_jac
(
self
.
_fld
),
other
.
domain
))
return
self
.
new
(
OuterProduct
(
self
.
_fld
,
other
.
target
)(
other
),
OuterProduct
(
self
.
_jac
(
self
.
_fld
),
other
.
target
))
return
self
.
new
(
OuterProduct
(
self
.
_fld
,
other
.
target
)(
other
.
_fld
),
OuterProduct
(
self
.
_jac
(
self
.
_fld
),
other
.
target
).
_myadd
(
...
...
@@ -318,7 +318,7 @@ class Linearization(Operand):
the requested Linearization
"""
from
.operators.scaling_operator
import
ScalingOperator
return
Linearization
(
field
,
ScalingOperator
(
field
.
domain
,
1.
),
return
Linearization
(
field
,
ScalingOperator
(
field
.
target
,
1.
),
want_metric
=
want_metric
)
@
staticmethod
...
...
@@ -343,7 +343,7 @@ class Linearization(Operand):
The Jacobian is square and contains only zeroes.
"""
from
.operators.simple_linear_operators
import
NullOperator
return
Linearization
(
field
,
NullOperator
(
field
.
domain
,
field
.
domain
),
return
Linearization
(
field
,
NullOperator
(
field
.
target
,
field
.
target
),
want_metric
=
want_metric
)
@
staticmethod
...
...
@@ -371,7 +371,7 @@ class Linearization(Operand):
from
.operators.simple_linear_operators
import
NullOperator
from
.multi_domain
import
MultiDomain
return
Linearization
(
field
,
NullOperator
(
MultiDomain
.
make
({}),
field
.
domain
),
field
,
NullOperator
(
MultiDomain
.
make
({}),
field
.
target
),
want_metric
=
want_metric
)
@
staticmethod
...
...
@@ -405,6 +405,6 @@ class Linearization(Operand):
return
Linearization
.
make_var
(
field
,
want_metric
)
else
:
ops
=
{
key
:
ScalingOperator
(
dom
,
0.
if
key
in
constants
else
1.
)
for
key
,
dom
in
field
.
domain
.
items
()}
bdop
=
BlockDiagonalOperator
(
field
.
domain
,
ops
)
for
key
,
dom
in
field
.
target
.
items
()}
bdop
=
BlockDiagonalOperator
(
field
.
target
,
ops
)
return
Linearization
(
field
,
bdop
,
want_metric
=
want_metric
)
nifty6/minimization/metric_gaussian_kl.py
View file @
af46206a
...
...
@@ -50,18 +50,18 @@ def _allreduce_sum_field(comm, fld):
if
comm
is
None
:
return
fld
if
isinstance
(
fld
,
Field
):
return
Field
(
fld
.
domain
,
_np_allreduce_sum
(
fld
.
val
))
return
Field
(
fld
.
target
,
_np_allreduce_sum
(
fld
.
val
))
res
=
tuple
(
Field
(
f
.
domain
,
_np_allreduce_sum
(
comm
,
f
.
val
))
Field
(
f
.
target
,
_np_allreduce_sum
(
comm
,
f
.
val
))
for
f
in
fld
.
values
())
return
MultiField
(
fld
.
domain
,
res
)
return
MultiField
(
fld
.
target
,
res
)
class
_KLMetric
(
EndomorphicOperator
):
def
__init__
(
self
,
KL
):
self
.
_KL
=
KL
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
self
.
_domain
=
KL
.
position
.
domain
self
.
_domain
=
KL
.
position
.
target
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
...
...
@@ -144,7 +144,7 @@ class MetricGaussianKL(Energy):
if
not
isinstance
(
hamiltonian
,
StandardHamiltonian
):
raise
TypeError
if
hamiltonian
.
domain
is
not
mean
.
domain
:
if
hamiltonian
.
domain
is
not
mean
.
target
:
raise
ValueError
if
not
isinstance
(
n_samples
,
int
):
raise
TypeError
...
...
nifty6/multi_field.py
View file @
af46206a
...
...
@@ -41,7 +41,9 @@ class MultiField(Operand):
raise
ValueError
(
"length mismatch"
)
for
d
,
v
in
zip
(
domain
.
_domains
,
val
):
if
isinstance
(
v
,
Field
):
if
v
.
_domain
!=
d
:
if
v
.
target
!=
d
:
print
(
v
.
target
)
print
(
d
)
raise
ValueError
(
"domain mismatch"
)
else
:
raise
TypeError
(
"bad entry in val (must be Field)"
)
...
...
@@ -52,7 +54,7 @@ class MultiField(Operand):
def
from_dict
(
dict
,
domain
=
None
):
if
domain
is
None
:
for
dd
in
dict
.
values
():
if
not
isinstance
(
dd
.
domain
,
DomainTuple
):
if
not
isinstance
(
dd
.
target
,
DomainTuple
):
raise
TypeError
(
'Values of dictionary need to be Fields '
'defined on DomainTuples.'
)
domain
=
MultiDomain
.
make
({
key
:
v
.
_domain
...
...
@@ -81,7 +83,8 @@ class MultiField(Operand):
@
property
def
domain
(
self
):
return
self
.
_domain
raise
NotImplementedError
return
None
#self._domain
@
property
def
target
(
self
):
...
...
@@ -329,15 +332,15 @@ class MultiField(Operand):
for
i
in
range
(
len
(
self
.
_val
)):
argstmp
,
kwargstmp
=
self
.
_prep_args
(
args
,
kwargs
,
i
)
tmp
.
append
(
self
.
_val
[
i
].
ptw
(
op
,
*
argstmp
,
**
kwargstmp
))
return
MultiField
(
self
.
domain
,
tuple
(
tmp
))
return
MultiField
(
self
.
target
,
tuple
(
tmp
))
def
ptw_with_deriv
(
self
,
op
,
*
args
,
**
kwargs
):
tmp
=
[]
for
i
in
range
(
len
(
self
.
_val
)):
argstmp
,
kwargstmp
=
self
.
_prep_args
(
args
,
kwargs
,
i
)
tmp
.
append
(
self
.
_val
[
i
].
ptw_with_deriv
(
op
,
*
argstmp
,
**
kwargstmp
))
return
(
MultiField
(
self
.
domain
,
tuple
(
v
[
0
]
for
v
in
tmp
)),
MultiField
(
self
.
domain
,
tuple
(
v
[
1
]
for
v
in
tmp
)))
return
(
MultiField
(
self
.
target
,
tuple
(
v
[
0
]
for
v
in
tmp
)),
MultiField
(
self
.
target
,
tuple
(
v
[
1
]
for
v
in
tmp
)))
def
_binary_op
(
self
,
other
,
op
):
f
=
getattr
(
Field
,
op
)
...
...
nifty6/operators/adder.py
View file @
af46206a
...
...
@@ -34,7 +34,7 @@ class Adder(Operator):
def
__init__
(
self
,
a
,
neg
=
False
,
domain
=
None
):
self
.
_a
=
a
if
isinstance
(
a
,
(
Field
,
MultiField
)):
dom
=
a
.
domain
dom
=
a
.
target
elif
np
.
isscalar
(
a
):
dom
=
makeDomain
(
domain
)
else
:
...
...
nifty6/operators/convolution_operators.py
View file @
af46206a
...
...
@@ -62,14 +62,14 @@ def FuncConvolutionOperator(domain, func, space=None):
def
_ConvolutionOperator
(
domain
,
kernel
,
space
=
None
):
domain
=
DomainTuple
.
make
(
domain
)
space
=
utilities
.
infer_space
(
domain
,
space
)
if
len
(
kernel
.
domain
)
!=
1
:
if
len
(
kernel
.
target
)
!=
1
:
raise
ValueError
(
"kernel needs exactly one domain"
)
if
not
isinstance
(
domain
[
space
],
(
HPSpace
,
GLSpace
,
RGSpace
)):
raise
TypeError
(
"need RGSpace, HPSpace, or GLSpace"
)
lm
=
[
d
for
d
in
domain
]
lm
[
space
]
=
lm
[
space
].
get_default_codomain
()
lm
=
DomainTuple
.
make
(
lm
)
if
lm
[
space
]
!=
kernel
.
domain
[
0
]:
if
lm
[
space
]
!=
kernel
.
target
[
0
]:
raise
ValueError
(
"Input domain and kernel are incompatible"
)
HT
=
HarmonicTransformOperator
(
lm
,
domain
[
space
],
space
)
diag
=
DiagonalOperator
(
kernel
*
domain
[
space
].
total_volume
,
lm
,
(
space
,))
...
...
nifty6/operators/diagonal_operator.py
View file @
af46206a
...
...
@@ -56,19 +56,19 @@ class DiagonalOperator(EndomorphicOperator):
if
not
isinstance
(
diagonal
,
Field
):
raise
TypeError
(
"Field object required"
)
if
domain
is
None
:
self
.
_domain
=
diagonal
.
domain
self
.
_domain
=
diagonal
.
target
else
:
self
.
_domain
=
DomainTuple
.
make
(
domain
)
if
spaces
is
None
:
self
.
_spaces
=
None
if
diagonal
.
domain
!=
self
.
_domain
:
if
diagonal
.
target
!=
self
.
_domain
:
raise
ValueError
(
"domain mismatch"
)
else
:
self
.
_spaces
=
utilities
.
parse_spaces
(
spaces
,
len
(
self
.
_domain
))
if
len
(
self
.
_spaces
)
!=
len
(
diagonal
.
domain
):
if
len
(
self
.
_spaces
)
!=
len
(
diagonal
.
target
):
raise
ValueError
(
"spaces and domain must have the same length"
)
for
i
,
j
in
enumerate
(
self
.
_spaces
):
if
diagonal
.
domain
[
i
]
!=
self
.
_domain
[
j
]:
if
diagonal
.
target
[
i
]
!=
self
.
_domain
[
j
]:
raise
ValueError
(
"domain mismatch"
)
if
self
.
_spaces
==
tuple
(
range
(
len
(
self
.
_domain
))):
self
.
_spaces
=
None
# shortcut
...
...
@@ -130,15 +130,15 @@ class DiagonalOperator(EndomorphicOperator):
self
.
_check_input
(
x
,
mode
)
# shortcut for most common cases
if
mode
==
1
or
(
not
self
.
_complex
and
mode
==
2
):
return
Field
(
x
.
domain
,
x
.
val
*
self
.
_ldiag
)
return
Field
(
x
.
target
,
x
.
val
*
self
.
_ldiag
)
xdiag
=
self
.
_ldiag
if
self
.
_complex
and
(
mode
&
10
):
# adjoint or inverse adjoint
xdiag
=
xdiag
.
conj
()
if
mode
&
3
:
return
Field
(
x
.
domain
,
x
.
val
*
xdiag
)
return
Field
(
x
.
domain
,
x
.
val
/
xdiag
)
return
Field
(
x
.
target
,
x
.
val
*
xdiag
)
return
Field
(
x
.
target
,
x
.
val
/
xdiag
)
def
_flip_modes
(
self
,
trafo
):
if
trafo
==
self
.
ADJOINT_BIT
and
not
self
.
_complex
:
# shortcut
...
...
nifty6/operators/distributors.py
View file @
af46206a
...
...
@@ -51,17 +51,17 @@ class DOFDistributor(LinearOperator):
def
__init__
(
self
,
dofdex
,
target
=
None
,
space
=
None
):
if
target
is
None
:
target
=
dofdex
.
domain
target
=
dofdex
.
target
self
.
_target
=
DomainTuple
.
make
(
target
)
space
=
infer_space
(
self
.
_target
,
space
)
partner
=
self
.
_target
[
space
]
if
not
isinstance
(
dofdex
,
Field
):
raise
TypeError
(
"dofdex must be a Field"
)
if
not
len
(
dofdex
.
domain
)
==
1
:
if
not
len
(
dofdex
.
target
)
==
1
:
raise
ValueError
(
"dofdex must be defined on exactly one Space"
)
if
not
np
.
issubdtype
(
dofdex
.
dtype
,
np
.
integer
):
raise
TypeError
(
"dofdex must contain integer numbers"
)
if
partner
!=
dofdex
.
domain
[
0
]:
if
partner
!=
dofdex
.
target
[
0
]:
raise
ValueError
(
"incorrect dofdex domain"
)
ldat
=
dofdex
.
val
...
...
nifty6/operators/endomorphic_operator.py
View file @
af46206a
...
...
@@ -60,5 +60,5 @@ class EndomorphicOperator(LinearOperator):
def
_check_input
(
self
,
x
,
mode
):
self
.
_check_mode
(
mode
)
if
self
.
domain
!=
x
.
domain
:
if
self
.
domain
!=
x
.
target
:
raise
ValueError
(
"The operator's and field's domains don't match."
)
nifty6/operators/energy_operators.py
View file @
af46206a
...
...
@@ -166,7 +166,7 @@ class GaussianEnergy(EnergyOperator):
self
.
_domain
=
None
if
mean
is
not
None
:
self
.
_checkEquivalence
(
mean
.
domain
)
self
.
_checkEquivalence
(
mean
.
target
)
if
inverse_covariance
is
not
None
:
self
.
_checkEquivalence
(
inverse_covariance
.
domain
)
if
domain
is
not
None
:
...
...
@@ -223,7 +223,7 @@ class PoissonianEnergy(EnergyOperator):
if
np
.
any
(
d
.
val
<
0
):
raise
ValueError
self
.
_d
=
d
self
.
_domain
=
DomainTuple
.
make
(
d
.
domain
)
self
.
_domain
=
DomainTuple
.
make
(
d
.
target
)
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
...
...
@@ -257,10 +257,10 @@ class InverseGammaLikelihood(EnergyOperator):
def
__init__
(
self
,
beta
,
alpha
=-
0.5
):
if
not
isinstance
(
beta
,
Field
):
raise
TypeError
self
.
_domain
=
DomainTuple
.
make
(
beta
.
domain
)
self
.
_domain
=
DomainTuple
.
make
(
beta
.
target
)
self
.
_beta
=
beta
if
np
.
isscalar
(
alpha
):
alpha
=
Field
(
beta
.
domain
,
np
.
full
(
beta
.
shape
,
alpha
))
alpha
=
Field
(
beta
.
target
,
np
.
full
(
beta
.
shape
,
alpha
))
elif
not
isinstance
(
alpha
,
Field
):
raise
TypeError
self
.
_alphap1
=
alpha
+
1
...
...
@@ -311,7 +311,7 @@ class BernoulliEnergy(EnergyOperator):
E(f) = -
\\
log
\\
text{Bernoulli}(d|f)
= -d^
\\
dagger
\\
log f - (1-d)^
\\
dagger
\\
log(1-f),
where f is a field defined on `d.
domain
` with the expected
where f is a field defined on `d.
target
` with the expected
frequencies of events.
Parameters
...
...
@@ -326,7 +326,7 @@ class BernoulliEnergy(EnergyOperator):
if
not
np
.
all
(
np
.
logical_or
(
d
.
val
==
0
,
d
.
val
==
1
)):
raise
ValueError
self
.
_d
=
d
self
.
_domain
=
DomainTuple
.
make
(
d
.
domain
)
self
.
_domain
=
DomainTuple
.
make
(
d
.
target
)
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
...
...
nifty6/operators/harmonic_operators.py
View file @
af46206a
...
...
@@ -72,14 +72,14 @@ class FFTOperator(LinearOperator):
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
ncells
=
x
.
domain
[
self
.
_space
].
size
if
x
.
domain
[
self
.
_space
].
harmonic
:
# harmonic -> position
ncells
=
x
.
target
[
self
.
_space
].
size
if
x
.
target
[
self
.
_space
].
harmonic
:
# harmonic -> position
func
=
fft
.
fftn
fct
=
1.
else
:
func
=
fft
.
ifftn
fct
=
ncells
axes
=
x
.
domain
.
axes
[
self
.
_space
]
axes
=
x
.
target
.
axes
[
self
.
_space
]
tdom
=
self
.
_tgt
(
mode
)
tmp
=
func
(
x
.
val
,
axes
=
axes
)
Tval
=
Field
(
tdom
,
tmp
)
...
...
@@ -146,7 +146,7 @@ class HartleyOperator(LinearOperator):
return
self
.
_apply_cartesian
(
x
,
mode
)
def
_apply_cartesian
(
self
,
x
,
mode
):
axes
=
x
.
domain
.
axes
[
self
.
_space
]
axes
=
x
.
target
.
axes
[
self
.
_space
]
tdom
=
self
.
_tgt
(
mode
)
tmp
=
fft
.
hartley
(
x
.
val
,
axes
=
axes
)
Tval
=
Field
(
tdom
,
tmp
)
...
...
@@ -247,10 +247,10 @@ class SHTOperator(LinearOperator):
return
res
/
np
.
sqrt
(
np
.
pi
*
4
)
def
_apply_spherical
(
self
,
x
,
mode
):
axes
=
x
.
domain
.
axes
[
self
.
_space
]
axes
=
x
.
target
.
axes
[
self
.
_space
]
v
=
x
.
val
p2h
=
not
x
.
domain
[
self
.
_space
].
harmonic
p2h
=
not
x
.
target
[
self
.
_space
].
harmonic
tdom
=
self
.
_tgt
(
mode
)
func
=
self
.
_slice_p2h
if
p2h
else
self
.
_slice_h2p
odat
=
np
.
empty
(
tdom
.
shape
,
dtype
=
x
.
dtype
)
...
...
nifty6/operators/inversion_enabler.py
View file @
af46206a
...
...
@@ -65,7 +65,7 @@ class InversionEnabler(EndomorphicOperator):
if
self
.
_op
.
capability
&
mode
:
return
self
.
_op
.
apply
(
x
,
mode
)
x0
=
full
(
x
.
domain
,
0.
)
x0
=
full
(
x
.
target
,
0.
)
invmode
=
self
.
_modeTable
[
self
.
INVERSE_BIT
][
self
.
_ilog
[
mode
]]
invop
=
self
.
_op
.
_flip_modes
(
self
.
_ilog
[
invmode
])
prec
=
self
.
_approximation
...
...
nifty6/operators/linear_operator.py
View file @
af46206a
...
...
@@ -255,4 +255,4 @@ class LinearOperator(Operator):
def
_check_input
(
self
,
x
,
mode
):
self
.
_check_mode
(
mode
)
self
.
_check_domain_equality
(
self
.
_dom
(
mode
),
x
.
domain
)
self
.
_check_domain_equality
(
self
.
_dom
(
mode
),
x
.
target
)
nifty6/operators/mask_operator.py
View file @
af46206a
...
...
@@ -37,7 +37,7 @@ class MaskOperator(LinearOperator):
def
__init__
(
self
,
flags
):
if
not
isinstance
(
flags
,
Field
):
raise
TypeError
self
.
_domain
=
DomainTuple
.
make
(
flags
.
domain
)
self
.
_domain
=
DomainTuple
.
make
(
flags
.
target
)
self
.
_flags
=
np
.
logical_not
(
flags
.
val
)
self
.
_target
=
DomainTuple
.
make
(
UnstructuredDomain
(
self
.
_flags
.
sum
()))
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
...
...
nifty6/operators/operator.py
View file @
af46206a
...
...
@@ -179,7 +179,7 @@ class Operator(metaclass=NiftyMeta):
raise
ValueError
if
x
.
jac
.
_factor
!=
1
:
raise
ValueError
self
.
_check_domain_equality
(
self
.
_domain
,
x
.
domain
)
self
.
_check_domain_equality
(
self
.
_domain
,
x
.
target
)
def
__call__
(
self
,
x
):
if
isinstance
(
x
,
Operator
):
...
...
@@ -205,7 +205,7 @@ class Operator(metaclass=NiftyMeta):
def
simplify_for_constant_input
(
self
,
c_inp
):
if
c_inp
is
None
:
return
None
,
self
if
c_inp
.
domain
==
self
.
domain
:
if
c_inp
.
target
==
self
.
domain
:
op
=
_ConstantOperator
(
self
.
domain
,
self
(
c_inp
))
return
op
(
c_inp
),
op
return
self
.
_simplify_for_constant_input_nontrivial
(
c_inp
)
...
...
@@ -270,7 +270,7 @@ class _ConstantOperator(Operator):
def
__init__
(
self
,
dom
,
output
):
from
..sugar
import
makeDomain
self
.
_domain
=
makeDomain
(
dom
)
self
.
_target
=
output
.
domain
self
.
_target
=
output
.
target
self
.
_output
=
output
def
apply
(
self
,
x
):
...
...
nifty6/operators/outer_product_operator.py
View file @
af46206a
...
...
@@ -35,7 +35,7 @@ class OuterProduct(LinearOperator):
self
.
_domain
=
domain
self
.
_field
=
field
self
.
_target
=
DomainTuple
.
make
(
tuple
(
sub_d
for
sub_d
in
field
.
domain
.
_dom
+
domain
.
_dom
))
tuple
(
sub_d
for
sub_d
in
field
.
target
.
_dom
+
domain
.
_dom
))
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
...
...
nifty6/operators/scaling_operator.py
View file @
af46206a
...
...
@@ -66,7 +66,7 @@ class ScalingOperator(EndomorphicOperator):
if
fct
==
1.
:
return
x
if
fct
==
0.
:
return
full
(
x
.
domain
,
0.
)
return
full
(
x
.
target
,
0.
)
MODES_WITH_ADJOINT
=
self
.
ADJOINT_TIMES
|
self
.
ADJOINT_INVERSE_TIMES
MODES_WITH_INVERSE
=
self
.
INVERSE_TIMES
|
self
.
ADJOINT_INVERSE_TIMES
...
...
nifty6/operators/simple_linear_operators.py
View file @
af46206a
...
...
@@ -36,7 +36,7 @@ class VdotOperator(LinearOperator):
"""
def
__init__
(
self
,
field
):
self
.
_field
=
field
self
.
_domain
=
field
.
domain
self
.
_domain
=
field
.
target
self
.
_target
=
DomainTuple
.
scalar_domain
()
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
...
...
@@ -346,7 +346,7 @@ class PartialExtractor(LinearOperator):
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
return
x
.
extract
(
self
.
_target
)
res0
=
MultiField
.
from_dict
({
key
:
x
[
key
]
for
key
in
x
.
domain
.
keys
()})
res0
=
MultiField
.
from_dict
({
key
:
x
[
key
]
for
key
in
x
.
target
.
keys
()})
res1
=
MultiField
.
full
(
self
.
_compldomain
,
0.
)
return
res0
.
unite
(
res1
)
...
...
Prev
1
2
Next
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment