Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
087530b0
Commit
087530b0
authored
May 23, 2018
by
Martin Reinecke
Browse files
merge NIFTy_4
parents
5429bb64
ec50fcc0
Pipeline
#29620
passed with stages
in 4 minutes and 21 seconds
Changes
40
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty4/library/poisson_energy.py
View file @
087530b0
...
...
@@ -20,7 +20,7 @@ from ..minimization.energy import Energy
from
..operators.diagonal_operator
import
DiagonalOperator
from
..operators.sandwich_operator
import
SandwichOperator
from
..operators.inversion_enabler
import
InversionEnabler
from
..
field
import
log
from
..
sugar
import
log
class
PoissonEnergy
(
Energy
):
...
...
@@ -46,7 +46,7 @@ class PoissonEnergy(Energy):
R1
=
Instrument
*
Rho
*
ht
self
.
_grad
=
(
phipos
+
R1
.
adjoint_times
((
lam
-
d
)
/
(
lam
+
eps
))).
lock
()
self
.
_curv
=
Phi_h
.
inverse
+
SandwichOperator
(
R1
,
W
)
self
.
_curv
=
Phi_h
.
inverse
+
SandwichOperator
.
make
(
R1
,
W
)
def
at
(
self
,
position
):
return
self
.
__class__
(
position
,
self
.
_d
,
self
.
_Instrument
,
...
...
nifty4/library/wiener_filter_curvature.py
View file @
087530b0
...
...
@@ -39,5 +39,5 @@ def WienerFilterCurvature(R, N, S, inverter):
inverter : Minimizer
The minimizer to use during numerical inversion
"""
op
=
SandwichOperator
(
R
,
N
.
inverse
)
+
S
.
inverse
op
=
SandwichOperator
.
make
(
R
,
N
.
inverse
)
+
S
.
inverse
return
InversionEnabler
(
op
,
inverter
,
S
.
inverse
)
nifty4/logger.py
View file @
087530b0
...
...
@@ -16,6 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
def
_logger_init
():
import
logging
from
.
import
dobj
...
...
nifty4/multi/__init__.py
View file @
087530b0
from
.multi_domain
import
MultiDomain
from
.multi_field
import
MultiField
from
.block_diagonal_operator
import
BlockDiagonalOperator
__all__
=
[
"MultiDomain"
,
"MultiField"
]
__all__
=
[
"MultiDomain"
,
"MultiField"
,
"BlockDiagonalOperator"
]
nifty4/multi/block_diagonal_operator.py
0 → 100644
View file @
087530b0
import
numpy
as
np
from
..operators.endomorphic_operator
import
EndomorphicOperator
from
.multi_domain
import
MultiDomain
from
.multi_field
import
MultiField
class
BlockDiagonalOperator
(
EndomorphicOperator
):
def
__init__
(
self
,
operators
):
"""
Parameters
----------
operators : dict
dictionary with operators domain names as keys and
LinearOperators as items
"""
super
(
BlockDiagonalOperator
,
self
).
__init__
()
self
.
_operators
=
operators
self
.
_domain
=
MultiDomain
.
make
(
{
key
:
op
.
domain
for
key
,
op
in
self
.
_operators
.
items
()})
self
.
_cap
=
self
.
_all_ops
for
op
in
self
.
_operators
.
values
():
self
.
_cap
&=
op
.
capability
@
property
def
domain
(
self
):
return
self
.
_domain
@
property
def
capability
(
self
):
return
self
.
_cap
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
return
MultiField
({
key
:
op
.
apply
(
x
[
key
],
mode
=
mode
)
for
key
,
op
in
self
.
_operators
.
items
()})
def
draw_sample
(
self
,
from_inverse
=
False
,
dtype
=
np
.
float64
):
dtype
=
MultiField
.
build_dtype
(
dtype
,
self
.
_domain
)
return
MultiField
({
key
:
op
.
draw_sample
(
from_inverse
,
dtype
[
key
])
for
key
,
op
in
self
.
_operators
.
items
()})
def
_combine_chain
(
self
,
op
):
res
=
{}
for
key
in
self
.
_operators
.
keys
():
res
[
key
]
=
self
.
_operators
[
key
]
*
op
.
_operators
[
key
]
return
BlockDiagonalOperator
(
res
)
def
_combine_sum
(
self
,
op
,
selfneg
,
opneg
):
from
..operators.sum_operator
import
SumOperator
res
=
{}
for
key
in
self
.
_operators
.
keys
():
res
[
key
]
=
SumOperator
.
make
([
self
.
_operators
[
key
],
op
.
_operators
[
key
]],
[
selfneg
,
opneg
])
return
BlockDiagonalOperator
(
res
)
nifty4/multi/multi_domain.py
View file @
087530b0
class
MultiDomain
(
dict
):
pass
import
collections
from
..domain_tuple
import
DomainTuple
__all
=
[
"MultiDomain"
]
class
frozendict
(
collections
.
Mapping
):
"""
An immutable wrapper around dictionaries that implements the complete
:py:class:`collections.Mapping` interface. It can be used as a drop-in
replacement for dictionaries where immutability is desired.
"""
dict_cls
=
dict
def
__init__
(
self
,
*
args
,
**
kwargs
):
self
.
_dict
=
self
.
dict_cls
(
*
args
,
**
kwargs
)
self
.
_hash
=
None
def
__getitem__
(
self
,
key
):
return
self
.
_dict
[
key
]
def
__contains__
(
self
,
key
):
return
key
in
self
.
_dict
def
copy
(
self
,
**
add_or_replace
):
return
self
.
__class__
(
self
,
**
add_or_replace
)
def
__iter__
(
self
):
return
iter
(
self
.
_dict
)
def
__len__
(
self
):
return
len
(
self
.
_dict
)
def
__repr__
(
self
):
return
'<%s %r>'
%
(
self
.
__class__
.
__name__
,
self
.
_dict
)
def
__hash__
(
self
):
if
self
.
_hash
is
None
:
h
=
0
for
key
,
value
in
self
.
_dict
.
items
():
h
^=
hash
((
key
,
value
))
self
.
_hash
=
h
return
self
.
_hash
class
MultiDomain
(
frozendict
):
_domainCache
=
{}
def
__init__
(
self
,
domain
,
_callingfrommake
=
False
):
if
not
_callingfrommake
:
raise
NotImplementedError
super
(
MultiDomain
,
self
).
__init__
(
domain
)
@
staticmethod
def
make
(
domain
):
if
isinstance
(
domain
,
MultiDomain
):
return
domain
if
not
isinstance
(
domain
,
dict
):
raise
TypeError
(
"dict expected"
)
tmp
=
{}
for
key
,
value
in
domain
.
items
():
if
not
isinstance
(
key
,
str
):
raise
TypeError
(
"keys must be strings"
)
tmp
[
key
]
=
DomainTuple
.
make
(
value
)
domain
=
frozendict
(
tmp
)
obj
=
MultiDomain
.
_domainCache
.
get
(
domain
)
if
obj
is
not
None
:
return
obj
obj
=
MultiDomain
(
domain
,
_callingfrommake
=
True
)
MultiDomain
.
_domainCache
[
domain
]
=
obj
return
obj
nifty4/multi/multi_field.py
View file @
087530b0
...
...
@@ -44,7 +44,8 @@ class MultiField(object):
@
property
def
domain
(
self
):
return
MultiDomain
({
key
:
val
.
domain
for
key
,
val
in
self
.
_val
.
items
()})
return
MultiDomain
.
make
(
{
key
:
val
.
domain
for
key
,
val
in
self
.
_val
.
items
()})
@
property
def
dtype
(
self
):
...
...
@@ -57,6 +58,18 @@ class MultiField(object):
dtype
[
key
],
**
kwargs
)
for
key
in
domain
.
keys
()})
def
fill
(
self
,
fill_value
):
"""Fill `self` uniformly with `fill_value`
Parameters
----------
fill_value: float or complex or int
The value to fill the field with.
"""
for
val
in
self
.
_val
.
values
():
val
.
fill
(
fill_value
)
return
self
def
_check_domain
(
self
,
other
):
if
other
.
domain
!=
self
.
domain
:
raise
ValueError
(
"domains are incompatible."
)
...
...
@@ -73,9 +86,22 @@ class MultiField(object):
v
.
lock
()
return
self
@
property
def
locked
(
self
):
return
all
(
v
.
locked
for
v
in
self
.
values
())
def
copy
(
self
):
return
MultiField
({
key
:
val
.
copy
()
for
key
,
val
in
self
.
items
()})
def
locked_copy
(
self
):
if
self
.
locked
:
return
self
return
MultiField
({
key
:
val
.
locked_copy
()
for
key
,
val
in
self
.
items
()})
def
empty_copy
(
self
):
return
MultiField
({
key
:
val
.
empty_copy
()
for
key
,
val
in
self
.
items
()})
@
staticmethod
def
build_dtype
(
dtype
,
domain
):
if
isinstance
(
dtype
,
dict
):
...
...
@@ -85,22 +111,24 @@ class MultiField(object):
return
{
key
:
dtype
for
key
in
domain
.
keys
()}
@
staticmethod
def
zeros
(
domain
,
dtype
=
None
):
def
empty
(
domain
,
dtype
=
None
):
dtype
=
MultiField
.
build_dtype
(
dtype
,
domain
)
return
MultiField
({
key
:
Field
.
zeros
(
dom
,
dtype
=
dtype
[
key
])
return
MultiField
({
key
:
Field
.
empty
(
dom
,
dtype
=
dtype
[
key
])
for
key
,
dom
in
domain
.
items
()})
@
staticmethod
def
ones
(
domain
,
dtype
=
None
):
dtype
=
MultiField
.
build_dtype
(
dtype
,
domain
)
return
MultiField
({
key
:
Field
.
ones
(
dom
,
dtype
=
dtype
[
key
])
def
full
(
domain
,
val
):
return
MultiField
({
key
:
Field
.
full
(
dom
,
val
)
for
key
,
dom
in
domain
.
items
()})
def
to_global_data
(
self
):
return
{
key
:
val
.
to_global_data
()
for
key
,
val
in
self
.
_val
.
items
()}
@
staticmethod
def
empty
(
domain
,
dtype
=
Non
e
):
dtype
=
MultiField
.
build_dtype
(
dtype
,
domain
)
return
MultiField
({
key
:
Field
.
empty
(
dom
,
dtype
=
dtype
[
key
]
)
for
key
,
dom
in
domain
.
items
()})
def
from_global_data
(
domain
,
arr
,
sum_up
=
Fals
e
):
return
MultiField
({
key
:
Field
.
from_global_data
(
domain
[
key
],
val
,
sum_up
)
for
key
,
val
in
arr
.
items
()})
def
norm
(
self
):
""" Computes the L2-norm of the field values.
...
...
nifty4/operators/chain_operator.py
View file @
087530b0
...
...
@@ -64,7 +64,7 @@ class ChainOperator(LinearOperator):
opsnew
[
i
]
=
opsnew
[
i
].
_scale
(
fct
)
fct
=
1.
break
if
fct
!=
1
:
if
fct
!=
1
or
len
(
opsnew
)
==
0
:
# have to add the scaling operator at the end
opsnew
.
append
(
ScalingOperator
(
fct
,
lastdom
))
ops
=
opsnew
...
...
@@ -78,11 +78,24 @@ class ChainOperator(LinearOperator):
else
:
opsnew
.
append
(
op
)
ops
=
opsnew
# Step 5: combine BlockDiagonalOperators where possible
from
..multi.block_diagonal_operator
import
BlockDiagonalOperator
opsnew
=
[]
for
op
in
ops
:
if
(
len
(
opsnew
)
>
0
and
isinstance
(
opsnew
[
-
1
],
BlockDiagonalOperator
)
and
isinstance
(
op
,
BlockDiagonalOperator
)):
opsnew
[
-
1
]
=
opsnew
[
-
1
].
_combine_chain
(
op
)
else
:
opsnew
.
append
(
op
)
ops
=
opsnew
return
ops
@
staticmethod
def
make
(
ops
):
ops
=
tuple
(
ops
)
if
len
(
ops
)
==
0
:
raise
ValueError
(
"ops is empty"
)
ops
=
ChainOperator
.
simplify
(
ops
)
if
len
(
ops
)
==
1
:
return
ops
[
0
]
...
...
nifty4/operators/fft_operator.py
View file @
087530b0
...
...
@@ -61,9 +61,7 @@ class FFTOperator(LinearOperator):
adom
.
check_codomain
(
target
)
target
.
check_codomain
(
adom
)
import
pyfftw
pyfftw
.
interfaces
.
cache
.
enable
()
pyfftw
.
interfaces
.
cache
.
set_keepalive_time
(
1000.
)
utilities
.
fft_prep
()
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
...
...
@@ -74,7 +72,6 @@ class FFTOperator(LinearOperator):
return
self
.
_apply_cartesian
(
x
,
mode
)
def
_apply_cartesian
(
self
,
x
,
mode
):
from
pyfftw.interfaces.numpy_fft
import
fftn
axes
=
x
.
domain
.
axes
[
self
.
_space
]
tdom
=
self
.
_target
if
x
.
domain
==
self
.
_domain
else
self
.
_domain
oldax
=
dobj
.
distaxis
(
x
.
val
)
...
...
@@ -110,7 +107,7 @@ class FFTOperator(LinearOperator):
tmp
=
dobj
.
from_local_data
(
shp2d
,
ldat2
,
distaxis
=
0
)
tmp
=
dobj
.
transpose
(
tmp
)
ldat2
=
dobj
.
local_data
(
tmp
)
ldat2
=
fftn
(
ldat2
,
axes
=
(
1
,))
ldat2
=
utilities
.
my_
fftn
(
ldat2
,
axes
=
(
1
,))
ldat2
=
ldat2
.
real
+
ldat2
.
imag
tmp
=
dobj
.
from_local_data
(
tmp
.
shape
,
ldat2
,
distaxis
=
0
)
tmp
=
dobj
.
transpose
(
tmp
)
...
...
nifty4/operators/inversion_enabler.py
View file @
087530b0
...
...
@@ -67,7 +67,7 @@ class InversionEnabler(EndomorphicOperator):
if
self
.
_op
.
capability
&
mode
:
return
self
.
_op
.
apply
(
x
,
mode
)
x0
=
x
*
0.
x0
=
x
.
empty_copy
().
fill
(
0.
)
invmode
=
self
.
_modeTable
[
self
.
INVERSE_BIT
][
self
.
_ilog
[
mode
]]
invop
=
self
.
_op
.
_flip_modes
(
self
.
_ilog
[
invmode
])
prec
=
self
.
_approximation
...
...
nifty4/operators/linear_operator.py
View file @
087530b0
...
...
@@ -271,10 +271,6 @@ class LinearOperator(NiftyMetaBase()):
raise
ValueError
(
"requested operator mode is not supported"
)
def
_check_input
(
self
,
x
,
mode
):
# MR FIXME: temporary fix for working with MultiFields
#if not isinstance(x, Field):
# raise ValueError("supplied object is not a `Field`.")
self
.
_check_mode
(
mode
)
if
x
.
domain
!=
self
.
_dom
(
mode
):
raise
ValueError
(
"The operator's and field's domains don't match."
)
nifty4/operators/sandwich_operator.py
View file @
087530b0
...
...
@@ -16,31 +16,46 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import
numpy
as
np
from
.diagonal_operator
import
DiagonalOperator
from
.endomorphic_operator
import
EndomorphicOperator
from
.scaling_operator
import
ScalingOperator
import
numpy
as
np
class
SandwichOperator
(
EndomorphicOperator
):
"""Operator which is equivalent to the expression `bun.adjoint*cheese*bun`.
Parameters
----------
bun: LinearOperator
the bun part
cheese: EndomorphicOperator
the cheese part
"""
def
__init__
(
self
,
bun
,
cheese
=
None
):
def
__init__
(
self
,
bun
,
cheese
,
op
,
_callingfrommake
=
False
):
if
not
_callingfrommake
:
raise
NotImplementedError
super
(
SandwichOperator
,
self
).
__init__
()
self
.
_bun
=
bun
self
.
_cheese
=
cheese
self
.
_op
=
op
@
staticmethod
def
make
(
bun
,
cheese
=
None
):
"""Build a SandwichOperator (or something simpler if possible)
Parameters
----------
bun: LinearOperator
the bun part
cheese: EndomorphicOperator
the cheese part
"""
if
cheese
is
None
:
self
.
_
cheese
=
ScalingOperator
(
1.
,
bun
.
target
)
self
.
_
op
=
bun
.
adjoint
*
bun
cheese
=
ScalingOperator
(
1.
,
bun
.
target
)
op
=
bun
.
adjoint
*
bun
else
:
self
.
_cheese
=
cheese
self
.
_op
=
bun
.
adjoint
*
cheese
*
bun
op
=
bun
.
adjoint
*
cheese
*
bun
# if our sandwich is diagonal, we can return immediately
if
isinstance
(
op
,
(
ScalingOperator
,
DiagonalOperator
)):
return
op
return
SandwichOperator
(
bun
,
cheese
,
op
,
_callingfrommake
=
True
)
@
property
def
domain
(
self
):
...
...
@@ -54,8 +69,11 @@ class SandwichOperator(EndomorphicOperator):
return
self
.
_op
.
apply
(
x
,
mode
)
def
draw_sample
(
self
,
from_inverse
=
False
,
dtype
=
np
.
float64
):
# Inverse samples from general sandwiches is not possible
if
from_inverse
:
raise
NotImplementedError
(
"cannot draw from inverse of this operator"
)
# Samples from general sandwiches
return
self
.
_bun
.
adjoint_times
(
self
.
_cheese
.
draw_sample
(
from_inverse
,
dtype
))
nifty4/operators/scaling_operator.py
View file @
087530b0
...
...
@@ -20,8 +20,8 @@ from __future__ import division
import
numpy
as
np
from
..field
import
Field
from
..multi.multi_field
import
MultiField
from
..domain_tuple
import
DomainTuple
from
.endomorphic_operator
import
EndomorphicOperator
from
..domain_tuple
import
DomainTuple
class
ScalingOperator
(
EndomorphicOperator
):
...
...
@@ -49,12 +49,13 @@ class ScalingOperator(EndomorphicOperator):
"""
def
__init__
(
self
,
factor
,
domain
):
from
..sugar
import
makeDomain
super
(
ScalingOperator
,
self
).
__init__
()
if
not
np
.
isscalar
(
factor
):
raise
TypeError
(
"Scalar required"
)
self
.
_factor
=
factor
self
.
_domain
=
Domain
Tuple
.
make
(
domain
)
self
.
_domain
=
make
Domain
(
domain
)
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
...
...
@@ -62,7 +63,7 @@ class ScalingOperator(EndomorphicOperator):
if
self
.
_factor
==
1.
:
return
x
.
copy
()
if
self
.
_factor
==
0.
:
return
x
.
zeros_like
(
x
)
return
x
.
empty_copy
().
fill
(
0.
)
if
mode
==
self
.
TIMES
:
return
x
*
self
.
_factor
...
...
nifty4/operators/sum_operator.py
View file @
087530b0
...
...
@@ -102,12 +102,36 @@ class SumOperator(LinearOperator):
negnew
.
append
(
neg
[
i
])
ops
=
opsnew
neg
=
negnew
# Step 5: combine BlockDiagonalOperators where possible
from
..multi.block_diagonal_operator
import
BlockDiagonalOperator
processed
=
[
False
]
*
len
(
ops
)
opsnew
=
[]
negnew
=
[]
for
i
in
range
(
len
(
ops
)):
if
not
processed
[
i
]:
if
isinstance
(
ops
[
i
],
BlockDiagonalOperator
):
op
=
ops
[
i
]
opneg
=
neg
[
i
]
for
j
in
range
(
i
+
1
,
len
(
ops
)):
if
isinstance
(
ops
[
j
],
BlockDiagonalOperator
):
op
=
op
.
_combine_sum
(
ops
[
j
],
opneg
,
neg
[
j
])
opneg
=
False
processed
[
j
]
=
True
opsnew
.
append
(
op
)
negnew
.
append
(
opneg
)
else
:
opsnew
.
append
(
ops
[
i
])
negnew
.
append
(
neg
[
i
])
ops
=
opsnew
neg
=
negnew
return
ops
,
neg
@
staticmethod
def
make
(
ops
,
neg
):
ops
=
tuple
(
ops
)
neg
=
tuple
(
neg
)
if
len
(
ops
)
==
0
:
raise
ValueError
(
"ops is empty"
)
if
len
(
ops
)
!=
len
(
neg
):
raise
ValueError
(
"length mismatch between ops and neg"
)
ops
,
neg
=
SumOperator
.
simplify
(
ops
,
neg
)
...
...
nifty4/sugar.py
View file @
087530b0
...
...
@@ -16,23 +16,25 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import
sys
import
numpy
as
np
from
.domains.power_space
import
PowerSpace
from
.field
import
Field
from
.multi.multi_field
import
MultiField
from
.multi.multi_domain
import
MultiDomain
from
.operators.diagonal_operator
import
DiagonalOperator
from
.operators.power_distributor
import
PowerDistributor
from
.domain_tuple
import
DomainTuple
from
.
import
dobj
,
utilities
from
.logger
import
logger
__all__
=
[
'PS_field'
,
'
power_analyze
'
,
'
create_power_operator
'
,
'
create_harmonic_smoothing_operator
'
,
__all__
=
[
'PS_field'
,
'power_analyze'
,
'create_power_operator'
,
'
create_harmonic_smoothing_operator'
,
'from_random
'
,
'
full'
,
'empty'
,
'from_global_data'
,
'from_local_data
'
,
'
makeDomain'
,
'sqrt'
,
'exp'
,
'log'
,
'tanh'
,
'conjugate
'
,
'get_signal_variance'
]
def
PS_field
(
pspace
,
func
):
if
not
isinstance
(
pspace
,
PowerSpace
):
raise
TypeError
...
...
@@ -53,15 +55,16 @@ def get_signal_variance(spec, space):
a method that takes one k-value and returns the power spectrum at that
location
space: PowerSpace or any harmonic Domain
If this function is given a harmonic domain, it creates the naturally
binned
PowerSpace to that domain.
The field, for which the signal variance is then computed, is assumed
to have
this PowerSpace as naturally binned PowerSpace
If this function is given a harmonic domain, it creates the naturally
binned
PowerSpace to that domain.
The field, for which the signal variance is then computed, is assumed
to have
this PowerSpace as naturally binned PowerSpace
"""
if
space
.
harmonic
:
space
=
PowerSpace
(
space
)
if
not
isinstance
(
space
,
PowerSpace
):
raise
ValueError
(
"space must be either a harmonic space or Power space."
)
raise
ValueError
(
"space must be either a harmonic space or Power space."
)
field
=
PS_field
(
space
,
spec
)
dist
=
PowerDistributor
(
space
.
harmonic_partner
,
space
)
k_field
=
dist
(
field
)
...
...
@@ -190,3 +193,70 @@ def create_harmonic_smoothing_operator(domain, space, sigma):
kfunc
=
domain
[
space
].
get_fft_smoothing_kernel_function
(
sigma
)
return
DiagonalOperator
(
kfunc
(
domain
[
space
].
get_k_length_array
()),
domain
,
space
)
def
full
(
domain
,
val
):
if
isinstance
(
domain
,
(
dict
,
MultiDomain
)):
return
MultiField
.
full
(
domain
,
val
)
return
Field
.
full
(
domain
,
val
)
def
empty
(
domain
,
dtype
):
if
isinstance
(
domain
,
(
dict
,
MultiDomain
)):
return
MultiField
.
empty
(
domain
,
dtype
)
return
Field
.
empty
(
domain
,
dtype
)
def
from_random
(
random_type
,
domain
,
dtype
=
np
.
float64
,
**
kwargs
):
if
isinstance
(
domain
,
(
dict
,
MultiDomain
)):
return
MultiField
.
from_random
(
random_type
,
domain
,
dtype
,
**
kwargs
)
return
Field
.
from_random
(
random_type
,
domain
,
dtype
,
**
kwargs
)
def
from_global_data
(
domain
,
arr
,
sum_up
=
False
):
if
isinstance
(
domain
,
(
dict
,
MultiDomain
)):
return
MultiField
.
from_global_data
(
domain
,
arr
,
sum_up
)
return
Field
.
from_global_data
(
domain
,
arr
,
sum_up
)
def
from_local_data
(
domain
,
arr
):
if
isinstance
(
domain
,
(
dict
,
MultiDomain
)):
return
MultiField
.
from_local_data
(
domain
,
arr
)
return
Field
.
from_local_data
(
domain
,
arr
)
def
makeDomain
(
domain
):
if
isinstance
(
domain
,
(
MultiDomain
,
dict
)):
return
MultiDomain
.
make
(
domain
)
return
DomainTuple
.
make
(
domain
)
# Arithmetic functions working on Fields
_current_module
=
sys
.
modules
[
__name__
]
for
f
in
[
"sqrt"
,
"exp"
,
"log"
,
"tanh"
,
"conjugate"
]:
def
func
(
f
):
def
func2
(
x
,
out
=
None
):
if
isinstance
(
x
,
MultiField
):
if
out
is
not
None
:
if
(
not
isinstance
(
out
,
MultiField
)
or
x
.
_domain
!=
out
.
_domain
):
raise
ValueError
(
"Bad 'out' argument"
)
for
key
,
value
in
x
.
items
():
func2
(
value
,
out
=
out
[
key
])
return
out
return
MultiField
({
key
:
func2
(
val
)
for
key
,
val
in
x
.
items
()})
elif
isinstance
(
x
,
Field
):
fu
=
getattr
(
dobj
,
f
)
if
out
is
not
None
:
if
not
isinstance
(
out
,
Field
)
or
x
.
_domain
!=
out
.
_domain
: