Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
ef4479aa
Commit
ef4479aa
authored
Aug 10, 2018
by
Martin Reinecke
Browse files
massive reworking
parent
a6e1e5e2
Changes
19
Hide whitespace changes
Inline
Side-by-side
demos/bernoulli_demo.py
View file @
ef4479aa
...
...
@@ -59,8 +59,9 @@ if __name__ == '__main__':
# Generate mock data
p
=
R
(
sky
)
mock_position
=
ift
.
from_random
(
'normal'
,
harmonic_space
)
data
=
np
.
random
.
binomial
(
1
,
p
(
mock_position
).
local_data
.
astype
(
np
.
float64
))
data
=
ift
.
Field
.
from_local_data
(
R
.
target
,
data
)
tmp
=
p
(
mock_position
).
to_global_data
().
astype
(
np
.
float64
)
data
=
np
.
random
.
binomial
(
1
,
tmp
)
data
=
ift
.
Field
.
from_global_data
(
R
.
target
,
data
)
# Compute likelihood and Hamiltonian
position
=
ift
.
from_random
(
'normal'
,
harmonic_space
)
...
...
nifty5/domain_tuple.py
View file @
ef4479aa
...
...
@@ -141,7 +141,7 @@ class DomainTuple(object):
def
__eq__
(
self
,
x
):
if
self
is
x
:
return
True
return
self
is
DomainTuple
.
make
(
x
)
return
self
.
_dom
==
x
.
_dom
def
__ne__
(
self
,
x
):
return
not
self
.
__eq__
(
x
)
...
...
nifty5/extra/energy_and_model_tests.py
View file @
ef4479aa
...
...
@@ -60,13 +60,13 @@ def _check_consistency(op, loc, tol, ntries, do_metric):
for
i
in
range
(
50
):
locmid
=
loc
+
0.5
*
dir
linmid
=
op
(
Linearization
.
make_var
(
locmid
))
dirder
=
linmid
.
jac
(
dir
)
/
dirnorm
numgrad
=
(
lin2
.
val
-
lin
.
val
)
/
dirnorm
dirder
=
linmid
.
jac
(
dir
)
numgrad
=
(
lin2
.
val
-
lin
.
val
)
xtol
=
tol
*
dirder
.
norm
()
/
np
.
sqrt
(
dirder
.
size
)
cond
=
(
abs
(
numgrad
-
dirder
)
<=
xtol
).
all
()
if
do_metric
:
dgrad
=
linmid
.
metric
(
dir
)
/
dirnorm
dgrad2
=
(
lin2
.
gradient
-
lin
.
gradient
)
/
dirnorm
dgrad
=
linmid
.
metric
(
dir
)
dgrad2
=
(
lin2
.
gradient
-
lin
.
gradient
)
cond
=
cond
and
(
abs
(
dgrad
-
dgrad2
)
<=
xtol
).
all
()
if
cond
:
break
...
...
nifty5/field.py
View file @
ef4479aa
...
...
@@ -348,7 +348,7 @@ class Field(object):
raise
TypeError
(
"The dot-partner must be an instance of "
+
"the NIFTy field class"
)
if
x
.
_domain
is
not
self
.
_domain
:
if
x
.
_domain
!=
self
.
_domain
:
raise
ValueError
(
"Domain mismatch"
)
ndom
=
len
(
self
.
_domain
)
...
...
@@ -609,7 +609,7 @@ class Field(object):
"
\n
- val = "
+
repr
(
self
.
_val
)
def
extract
(
self
,
dom
):
if
dom
is
not
self
.
_domain
:
if
dom
!=
self
.
_domain
:
raise
ValueError
(
"domain mismatch"
)
return
self
...
...
@@ -623,13 +623,14 @@ class Field(object):
# if other is a field, make sure that the domains match
f
=
getattr
(
self
.
_val
,
op
)
if
isinstance
(
other
,
Field
):
if
other
.
_domain
is
not
self
.
_domain
:
if
other
.
_domain
!=
self
.
_domain
:
raise
ValueError
(
"domains are incompatible."
)
return
Field
(
self
.
_domain
,
f
(
other
.
_val
))
if
np
.
isscalar
(
other
):
return
Field
(
self
.
_domain
,
f
(
other
))
return
NotImplemented
for
op
in
[
"__add__"
,
"__radd__"
,
"__sub__"
,
"__rsub__"
,
"__mul__"
,
"__rmul__"
,
...
...
nifty5/library/correlated_fields.py
View file @
ef4479aa
...
...
@@ -26,11 +26,12 @@ from ..operators.domain_distributor import DomainDistributor
from
..operators.harmonic_operators
import
HarmonicTransformOperator
from
..operators.power_distributor
import
PowerDistributor
from
..operators.operator
import
Operator
from
..operators.simple_linear_operators
import
FieldAdapter
class
CorrelatedField
(
Operator
):
def
CorrelatedField
(
s_space
,
amplitude_model
):
'''
Class
for construction of correlated fields
Function
for construction of correlated fields
Parameters
----------
...
...
@@ -38,17 +39,14 @@ class CorrelatedField(Operator):
amplitude_model : model for correlation structure
'''
def
__init__
(
self
,
s_space
,
amplitude_model
):
h_space
=
s_space
.
get_default_codomain
()
self
.
_ht
=
HarmonicTransformOperator
(
h_space
,
s_space
)
p_space
=
amplitude_model
.
target
[
0
]
power_distributor
=
PowerDistributor
(
h_space
,
p_space
)
self
.
_A
=
power_distributor
(
amplitude_model
)
self
.
_domain
=
MultiDomain
.
union
(
(
amplitude_model
.
domain
,
MultiDomain
.
make
({
"xi"
:
h_space
})))
def
apply
(
self
,
x
):
return
self
.
_ht
(
self
.
_A
(
x
)
*
x
[
"xi"
])
h_space
=
s_space
.
get_default_codomain
()
ht
=
HarmonicTransformOperator
(
h_space
,
s_space
)
p_space
=
amplitude_model
.
target
[
0
]
power_distributor
=
PowerDistributor
(
h_space
,
p_space
)
A
=
power_distributor
(
amplitude_model
)
domain
=
MultiDomain
.
union
(
(
amplitude_model
.
domain
,
MultiDomain
.
make
({
"xi"
:
h_space
})))
return
ht
(
A
*
FieldAdapter
(
domain
,
"xi"
))
# def make_mf_correlated_field(s_space_spatial, s_space_energy,
...
...
nifty5/linearization.py
View file @
ef4479aa
...
...
@@ -13,6 +13,8 @@ class Linearization(object):
def
__init__
(
self
,
val
,
jac
,
metric
=
None
):
self
.
_val
=
val
self
.
_jac
=
jac
if
self
.
_val
.
domain
!=
self
.
_jac
.
target
:
raise
ValueError
(
"domain mismatch"
)
self
.
_metric
=
metric
@
property
...
...
@@ -61,13 +63,12 @@ class Linearization(object):
def
__add__
(
self
,
other
):
if
isinstance
(
other
,
Linearization
):
from
.operators.relaxed_sum_operator
import
RelaxedSumOperator
met
=
None
if
self
.
_metric
is
not
None
and
other
.
_metric
is
not
None
:
met
=
RelaxedSumOperator
((
self
.
_metric
,
other
.
_metric
)
)
met
=
self
.
_metric
.
_myadd
(
other
.
_metric
,
False
)
return
Linearization
(
self
.
_val
.
unite
(
other
.
_val
),
RelaxedSumOperator
((
self
.
_jac
,
other
.
_jac
)
),
met
)
self
.
_jac
.
_myadd
(
other
.
_jac
,
False
),
met
)
if
isinstance
(
other
,
(
int
,
float
,
complex
,
Field
,
MultiField
)):
return
Linearization
(
self
.
_val
+
other
,
self
.
_jac
,
self
.
_metric
)
...
...
@@ -83,15 +84,20 @@ class Linearization(object):
def
__mul__
(
self
,
other
):
from
.sugar
import
makeOp
if
isinstance
(
other
,
Linearization
):
if
self
.
target
!=
other
.
target
:
raise
ValueError
(
"domain mismatch"
)
return
Linearization
(
self
.
_val
*
other
.
_val
,
makeOp
(
other
.
_val
)(
self
.
_jac
)
+
makeOp
(
self
.
_val
)(
other
.
_jac
))
(
makeOp
(
other
.
_val
)(
self
.
_jac
)).
_myadd
(
makeOp
(
self
.
_val
)(
other
.
_jac
),
False
))
if
np
.
isscalar
(
other
):
if
other
==
1
:
return
self
met
=
None
if
self
.
_metric
is
None
else
self
.
_metric
.
scale
(
other
)
return
Linearization
(
self
.
_val
*
other
,
self
.
_jac
.
scale
(
other
),
met
)
if
isinstance
(
other
,
(
Field
,
MultiField
)):
if
self
.
target
!=
other
.
domain
:
raise
ValueError
(
"domain mismatch"
)
return
Linearization
(
self
.
_val
*
other
,
makeOp
(
other
)(
self
.
_jac
))
def
__rmul__
(
self
,
other
):
...
...
nifty5/multi_domain.py
View file @
ef4479aa
...
...
@@ -95,7 +95,7 @@ class MultiDomain(object):
def
__eq__
(
self
,
x
):
if
self
is
x
:
return
True
return
self
is
MultiDomain
.
make
(
x
)
return
self
.
items
()
==
x
.
items
(
)
def
__ne__
(
self
,
x
):
return
not
self
.
__eq__
(
x
)
...
...
@@ -115,7 +115,7 @@ class MultiDomain(object):
for
dom
in
inp
:
for
key
,
subdom
in
zip
(
dom
.
_keys
,
dom
.
_domains
):
if
key
in
res
:
if
res
[
key
]
is
not
subdom
:
if
res
[
key
]
!=
subdom
:
raise
ValueError
(
"domain mismatch"
)
else
:
res
[
key
]
=
subdom
...
...
nifty5/multi_field.py
View file @
ef4479aa
...
...
@@ -42,7 +42,7 @@ class MultiField(object):
raise
ValueError
(
"length mismatch"
)
for
d
,
v
in
zip
(
domain
.
_domains
,
val
):
if
isinstance
(
v
,
Field
):
if
v
.
_domain
is
not
d
:
if
v
.
_domain
!=
d
:
raise
ValueError
(
"domain mismatch"
)
else
:
raise
TypeError
(
"bad entry in val (must be Field)"
)
...
...
@@ -103,7 +103,7 @@ class MultiField(object):
for
dom
in
domain
.
_domains
))
def
_check_domain
(
self
,
other
):
if
other
.
_domain
is
not
self
.
_domain
:
if
other
.
_domain
!=
self
.
_domain
:
raise
ValueError
(
"domains are incompatible."
)
def
vdot
(
self
,
x
):
...
...
@@ -216,7 +216,7 @@ class MultiField(object):
def
_binary_op
(
self
,
other
,
op
):
f
=
getattr
(
Field
,
op
)
if
isinstance
(
other
,
MultiField
):
if
self
.
_domain
is
not
other
.
_domain
:
if
self
.
_domain
!=
other
.
_domain
:
raise
ValueError
(
"domain mismatch"
)
val
=
tuple
(
f
(
v1
,
v2
)
for
v1
,
v2
in
zip
(
self
.
_val
,
other
.
_val
))
...
...
nifty5/operators/block_diagonal_operator.py
View file @
ef4479aa
...
...
@@ -57,14 +57,14 @@ class BlockDiagonalOperator(EndomorphicOperator):
# return MultiField(self._domain, val)
def
_combine_chain
(
self
,
op
):
if
self
.
_domain
is
not
op
.
_domain
:
if
self
.
_domain
!=
op
.
_domain
:
raise
ValueError
(
"domain mismatch"
)
res
=
tuple
(
v1
(
v2
)
for
v1
,
v2
in
zip
(
self
.
_ops
,
op
.
_ops
))
return
BlockDiagonalOperator
(
self
.
_domain
,
res
)
def
_combine_sum
(
self
,
op
,
selfneg
,
opneg
):
from
..operators.sum_operator
import
SumOperator
if
self
.
_domain
is
not
op
.
_domain
:
if
self
.
_domain
!=
op
.
_domain
:
raise
ValueError
(
"domain mismatch"
)
res
=
tuple
(
SumOperator
.
make
([
v1
,
v2
],
[
selfneg
,
opneg
])
for
v1
,
v2
in
zip
(
self
.
_ops
,
op
.
_ops
))
...
...
nifty5/operators/chain_operator.py
View file @
ef4479aa
...
...
@@ -44,7 +44,7 @@ class ChainOperator(LinearOperator):
from
.diagonal_operator
import
DiagonalOperator
# Step 1: verify domains
for
i
in
range
(
len
(
ops
)
-
1
):
if
ops
[
i
+
1
].
target
is
not
ops
[
i
].
domain
:
if
ops
[
i
+
1
].
target
!=
ops
[
i
].
domain
:
raise
ValueError
(
"domain mismatch"
)
# Step 2: unpack ChainOperators
opsnew
=
[]
...
...
nifty5/operators/diagonal_operator.py
View file @
ef4479aa
...
...
@@ -65,7 +65,7 @@ class DiagonalOperator(EndomorphicOperator):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
if
spaces
is
None
:
self
.
_spaces
=
None
if
diagonal
.
domain
is
not
self
.
_domain
:
if
diagonal
.
domain
!=
self
.
_domain
:
raise
ValueError
(
"domain mismatch"
)
else
:
self
.
_spaces
=
utilities
.
parse_spaces
(
spaces
,
len
(
self
.
_domain
))
...
...
nifty5/operators/endomorphic_operator.py
View file @
ef4479aa
...
...
@@ -62,5 +62,5 @@ class EndomorphicOperator(LinearOperator):
def
_check_input
(
self
,
x
,
mode
):
self
.
_check_mode
(
mode
)
if
self
.
domain
is
not
x
.
domain
:
if
self
.
domain
!=
x
.
domain
:
raise
ValueError
(
"The operator's and field's domains don't match."
)
nifty5/operators/energy_operators.py
View file @
ef4479aa
...
...
@@ -85,7 +85,7 @@ class GaussianEnergy(EnergyOperator):
if
self
.
_domain
is
None
:
self
.
_domain
=
newdom
else
:
if
self
.
_domain
is
not
newdom
:
if
self
.
_domain
!=
newdom
:
raise
ValueError
(
"domain mismatch"
)
def
apply
(
self
,
x
):
...
...
@@ -157,6 +157,5 @@ class SampledKullbachLeiblerDivergence(EnergyOperator):
self
.
_res_samples
=
tuple
(
res_samples
)
def
apply
(
self
,
x
):
res
=
(
utilities
.
my_sum
(
map
(
lambda
v
:
self
.
_h
(
x
+
v
),
self
.
_res_samples
))
*
(
1.
/
len
(
self
.
_res_samples
)))
return
res
mymap
=
map
(
lambda
v
:
self
.
_h
(
x
+
v
),
self
.
_res_samples
)
return
utilities
.
my_sum
(
mymap
)
*
(
1.
/
len
(
self
.
_res_samples
))
nifty5/operators/linear_operator.py
View file @
ef4479aa
...
...
@@ -116,10 +116,16 @@ class LinearOperator(Operator):
return
ChainOperator
.
make
([
other
,
self
])
return
Operator
.
__rmatmul__
(
self
,
other
)
def
_myadd
(
self
,
other
,
oneg
):
if
self
.
domain
==
other
.
domain
and
self
.
target
==
other
.
target
:
from
.sum_operator
import
SumOperator
return
SumOperator
.
make
((
self
,
other
),
(
False
,
oneg
))
from
.relaxed_sum_operator
import
RelaxedSumOperator
return
RelaxedSumOperator
((
self
,
-
other
if
oneg
else
other
))
def
__add__
(
self
,
other
):
if
isinstance
(
other
,
LinearOperator
):
from
.sum_operator
import
SumOperator
return
SumOperator
.
make
([
self
,
other
],
[
False
,
False
])
return
self
.
_myadd
(
other
,
False
)
return
Operator
.
__add__
(
self
,
other
)
def
__radd__
(
self
,
other
):
...
...
@@ -127,14 +133,12 @@ class LinearOperator(Operator):
def
__sub__
(
self
,
other
):
if
isinstance
(
other
,
LinearOperator
):
from
.sum_operator
import
SumOperator
return
SumOperator
.
make
([
self
,
other
],
[
False
,
True
])
return
self
.
_myadd
(
other
,
True
)
return
Operator
.
__sub__
(
self
,
other
)
def
__rsub__
(
self
,
other
):
if
isinstance
(
other
,
LinearOperator
):
from
.sum_operator
import
SumOperator
return
SumOperator
.
make
([
other
,
self
],
[
False
,
True
])
return
other
.
_myadd
(
self
,
True
)
return
Operator
.
__rsub__
(
self
,
other
)
@
property
...
...
@@ -260,5 +264,5 @@ class LinearOperator(Operator):
def
_check_input
(
self
,
x
,
mode
):
self
.
_check_mode
(
mode
)
if
self
.
_dom
(
mode
)
is
not
x
.
domain
:
if
self
.
_dom
(
mode
)
!=
x
.
domain
:
raise
ValueError
(
"The operator's and field's domains don't match."
)
nifty5/operators/operator.py
View file @
ef4479aa
...
...
@@ -50,15 +50,15 @@ class Operator(NiftyMetaBase()):
def
__mul__
(
self
,
x
):
if
not
isinstance
(
x
,
Operator
):
return
NotImplemented
return
_OpProd
.
make
(
(
self
,
x
)
)
return
_OpProd
(
self
,
x
)
def
apply
(
self
,
x
):
raise
NotImplementedError
def
__call__
(
self
,
x
):
if
isinstance
(
x
,
Operator
):
return
_OpChain
.
make
((
self
,
x
))
return
self
.
apply
(
x
)
if
isinstance
(
x
,
Operator
):
return
_OpChain
.
make
((
self
,
x
))
return
self
.
apply
(
x
)
for
f
in
[
"sqrt"
,
"exp"
,
"log"
,
"tanh"
,
"positive_tanh"
]:
...
...
@@ -108,6 +108,9 @@ class _OpChain(_CombinedOperator):
super
(
_OpChain
,
self
).
__init__
(
ops
,
_callingfrommake
)
self
.
_domain
=
self
.
_ops
[
-
1
].
domain
self
.
_target
=
self
.
_ops
[
0
].
target
for
i
in
range
(
1
,
len
(
self
.
_ops
)):
if
self
.
_ops
[
i
-
1
].
domain
!=
self
.
_ops
[
i
].
target
:
raise
ValueError
(
"domain mismatch"
)
def
apply
(
self
,
x
):
for
op
in
reversed
(
self
.
_ops
):
...
...
@@ -115,21 +118,44 @@ class _OpChain(_CombinedOperator):
return
x
class
_OpProd
(
_CombinedOperator
):
def
__init__
(
self
,
ops
,
_callingfrommake
=
False
):
super
(
_OpProd
,
self
).
__init__
(
ops
,
_callingfrommake
)
self
.
_domain
=
self
.
_ops
[
0
].
domain
self
.
_target
=
self
.
_ops
[
0
].
target
class
_OpProd
(
Operator
):
def
__init__
(
self
,
op1
,
op2
):
from
..sugar
import
domain_union
self
.
_domain
=
domain_union
((
op1
.
domain
,
op2
.
domain
))
self
.
_target
=
op1
.
target
if
op1
.
target
!=
op2
.
target
:
raise
ValueError
(
"target mismatch"
)
self
.
_op1
=
op1
self
.
_op2
=
op2
def
apply
(
self
,
x
):
return
my_product
(
map
(
lambda
op
:
op
(
x
),
self
.
_ops
))
from
..linearization
import
Linearization
from
..sugar
import
makeOp
lin
=
isinstance
(
x
,
Linearization
)
if
not
lin
:
r1
=
self
.
_op1
(
x
.
extract
(
self
.
_op1
.
domain
))
r2
=
self
.
_op2
(
x
.
extract
(
self
.
_op2
.
domain
))
return
r1
*
r2
lin1
=
self
.
_op1
(
Linearization
.
make_var
(
x
.
_val
.
extract
(
self
.
_op1
.
domain
)))
lin2
=
self
.
_op2
(
Linearization
.
make_var
(
x
.
_val
.
extract
(
self
.
_op2
.
domain
)))
op
=
(
makeOp
(
lin1
.
_val
)(
lin2
.
_jac
)).
_myadd
(
makeOp
(
lin2
.
_val
)(
lin1
.
_jac
),
False
)
jac
=
op
(
x
.
jac
)
return
Linearization
(
lin1
.
_val
*
lin2
.
_val
,
jac
)
class
_OpSum
(
_CombinedOperator
):
def
__init__
(
self
,
ops
,
_callingfrommake
=
False
):
from
..sugar
import
domain_union
super
(
_OpSum
,
self
).
__init__
(
ops
,
_callingfrommake
)
self
.
_domain
=
domain_union
([
op
.
domain
for
op
in
self
.
_ops
])
self
.
_target
=
domain_union
([
op
.
target
for
op
in
self
.
_ops
])
def
apply
(
self
,
x
):
raise
NotImplementedError
res
=
None
for
op
in
self
.
_ops
:
tmp
=
op
(
x
.
extract
(
op
.
domain
))
res
=
tmp
if
res
is
None
else
res
.
unite
(
tmp
)
return
res
nifty5/operators/relaxed_sum_operator.py
View file @
ef4479aa
...
...
@@ -38,12 +38,6 @@ class RelaxedSumOperator(LinearOperator):
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
for
op
in
ops
:
self
.
_capability
&=
op
.
capability
#self._ops = []
#for op in ops:
# if isinstance(op, RelaxedSumOperator):
# self._ops += op._ops
# else:
# self._ops += [op]
@
property
def
adjoint
(
self
):
...
...
nifty5/operators/simple_linear_operators.py
View file @
ef4479aa
...
...
@@ -36,7 +36,7 @@ class VdotOperator(LinearOperator):
self
.
_field
=
field
self
.
_domain
=
field
.
domain
self
.
_target
=
DomainTuple
.
scalar_domain
()
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_mode
(
mode
)
...
...
@@ -49,7 +49,7 @@ class SumReductionOperator(LinearOperator):
def
__init__
(
self
,
domain
):
self
.
_domain
=
domain
self
.
_target
=
DomainTuple
.
scalar_domain
()
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
...
...
@@ -61,7 +61,7 @@ class SumReductionOperator(LinearOperator):
class
ConjugationOperator
(
EndomorphicOperator
):
def
__init__
(
self
,
domain
):
self
.
_domain
=
domain
self
.
_capability
=
self
.
_all_ops
self
.
_capability
=
self
.
_all_ops
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
...
...
@@ -71,7 +71,7 @@ class ConjugationOperator(EndomorphicOperator):
class
Realizer
(
EndomorphicOperator
):
def
__init__
(
self
,
domain
):
self
.
_domain
=
domain
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
...
...
@@ -79,20 +79,17 @@ class Realizer(EndomorphicOperator):
class
FieldAdapter
(
LinearOperator
):
def
__init__
(
self
,
dom
,
name_dom
):
self
.
_domain
=
MultiDomain
.
make
(
dom
)
self
.
_name
=
name_dom
self
.
_target
=
dom
[
name_dom
]
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
__init__
(
self
,
dom
,
name
):
self
.
_target
=
dom
[
name
]
self
.
_domain
=
MultiDomain
.
make
({
name
:
self
.
_target
})
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
return
x
[
self
.
_name
]
values
=
tuple
(
Field
(
dom
,
0.
)
if
key
!=
self
.
_name
else
x
for
key
,
dom
in
self
.
_domain
.
items
())
return
MultiField
(
self
.
_domain
,
values
)
return
x
.
values
()[
0
]
return
MultiField
(
self
.
_domain
,
(
x
,))
class
GeometryRemover
(
LinearOperator
):
...
...
@@ -115,7 +112,7 @@ class GeometryRemover(LinearOperator):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
target_list
=
[
UnstructuredDomain
(
dom
.
shape
)
for
dom
in
self
.
_domain
]
self
.
_target
=
DomainTuple
.
make
(
target_list
)
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
...
...
@@ -137,7 +134,7 @@ class NullOperator(LinearOperator):
from
..sugar
import
makeDomain
self
.
_domain
=
makeDomain
(
domain
)
self
.
_target
=
makeDomain
(
target
)
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
@
staticmethod
def
_nullfield
(
dom
):
...
...
nifty5/operators/symmetrizing_operator.py
View file @
ef4479aa
...
...
@@ -30,7 +30,7 @@ from .. import utilities
class
SymmetrizingOperator
(
EndomorphicOperator
):
def
__init__
(
self
,
domain
,
space
=
0
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
self
.
_space
=
utilities
.
infer_space
(
self
.
_domain
,
space
)
dom
=
self
.
_domain
[
self
.
_space
]
if
not
(
isinstance
(
dom
,
LogRGSpace
)
and
not
dom
.
harmonic
):
...
...
nifty5/sugar.py
View file @
ef4479aa
...
...
@@ -246,7 +246,7 @@ def makeOp(input):
def
domain_union
(
domains
):
if
isinstance
(
domains
[
0
],
DomainTuple
):
if
any
(
dom
is
not
domains
[
0
]
for
dom
in
domains
[
1
:]):
if
any
(
dom
!=
domains
[
0
]
for
dom
in
domains
[
1
:]):
raise
ValueError
(
"domain mismatch"
)
return
domains
[
0
]
return
MultiDomain
.
union
(
domains
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment