Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
N
NIFTy
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
13
Issues
13
List
Boards
Labels
Service Desk
Milestones
Merge Requests
8
Merge Requests
8
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Operations
Operations
Incidents
Environments
Packages & Registries
Packages & Registries
Container Registry
Analytics
Analytics
CI / CD
Repository
Value Stream
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ift
NIFTy
Commits
d8c42e70
Commit
d8c42e70
authored
Jul 04, 2018
by
Martin Reinecke
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
make Fields and MultiFields immutable
parent
c3c4a8c4
Changes
23
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
23 changed files
with
99 additions
and
290 deletions
+99
-290
nifty5/domain_tuple.py
nifty5/domain_tuple.py
+10
-0
nifty5/domains/domain.py
nifty5/domains/domain.py
+10
-0
nifty5/domains/log_rg_space.py
nifty5/domains/log_rg_space.py
+10
-14
nifty5/domains/rg_space.py
nifty5/domains/rg_space.py
+6
-11
nifty5/extra/operator_tests.py
nifty5/extra/operator_tests.py
+7
-7
nifty5/field.py
nifty5/field.py
+27
-128
nifty5/minimization/conjugate_gradient.py
nifty5/minimization/conjugate_gradient.py
+1
-1
nifty5/minimization/energy.py
nifty5/minimization/energy.py
+1
-1
nifty5/minimization/energy_sum.py
nifty5/minimization/energy_sum.py
+1
-1
nifty5/minimization/line_energy.py
nifty5/minimization/line_energy.py
+1
-1
nifty5/minimization/quadratic_energy.py
nifty5/minimization/quadratic_energy.py
+0
-1
nifty5/minimization/scipy_minimizer.py
nifty5/minimization/scipy_minimizer.py
+2
-2
nifty5/minimization/vl_bfgs.py
nifty5/minimization/vl_bfgs.py
+4
-4
nifty5/multi/multi_field.py
nifty5/multi/multi_field.py
+2
-41
nifty5/operators/diagonal_operator.py
nifty5/operators/diagonal_operator.py
+2
-2
nifty5/operators/dof_distributor.py
nifty5/operators/dof_distributor.py
+2
-3
nifty5/operators/inversion_enabler.py
nifty5/operators/inversion_enabler.py
+2
-1
nifty5/operators/scaling_operator.py
nifty5/operators/scaling_operator.py
+3
-2
nifty5/operators/selection_operator.py
nifty5/operators/selection_operator.py
+2
-3
nifty5/operators/symmetrizing_operator.py
nifty5/operators/symmetrizing_operator.py
+1
-1
nifty5/sugar.py
nifty5/sugar.py
+4
-23
test/test_field.py
test/test_field.py
+1
-30
test/test_multi_field.py
test/test_multi_field.py
+0
-13
No files found.
nifty5/domain_tuple.py
View file @
d8c42e70
...
...
@@ -104,6 +104,16 @@ class DomainTuple(object):
"""
return
self
.
_shape
@
property
def
local_shape
(
self
):
"""tuple of int: number of pixels along each axis on the local task
The shape of the array-like object required to store information
living on part of the domain which is stored on the local MPI task.
"""
from
.dobj
import
local_shape
return
local_shape
(
self
.
_shape
)
@
property
def
size
(
self
):
"""int : total number of pixels.
...
...
nifty5/domains/domain.py
View file @
d8c42e70
...
...
@@ -88,6 +88,16 @@ class Domain(NiftyMetaBase()):
"""
raise
NotImplementedError
@
property
def
local_shape
(
self
):
"""tuple of int: number of pixels along each axis on the local task
The shape of the array-like object required to store information
living on part of the domain which is stored on the local MPI task.
"""
from
..dobj
import
local_shape
return
local_shape
(
self
.
shape
)
@
abc
.
abstractproperty
def
size
(
self
):
"""int: total number of pixels.
...
...
nifty5/domains/log_rg_space.py
View file @
d8c42e70
...
...
@@ -3,7 +3,7 @@ from ..sugar import exp
import
numpy
as
np
from
..
dobj
import
ibegin
from
..
import
dobj
from
..field
import
Field
from
.structured_domain
import
StructuredDomain
...
...
@@ -62,26 +62,22 @@ class LogRGSpace(StructuredDomain):
np
.
zeros
(
len
(
self
.
shape
)),
True
)
def
get_k_length_array
(
self
):
out
=
Field
(
self
,
dtype
=
np
.
float64
)
oloc
=
out
.
local_data
ib
=
ibegin
(
out
.
val
)
res
=
np
.
arange
(
oloc
.
shape
[
0
],
dtype
=
np
.
float64
)
+
ib
[
0
]
ib
=
dobj
.
ibegin_from_shape
(
self
.
_shape
)
res
=
np
.
arange
(
self
.
local_shape
[
0
],
dtype
=
np
.
float64
)
+
ib
[
0
]
res
=
np
.
minimum
(
res
,
self
.
shape
[
0
]
-
res
)
*
self
.
bindistances
[
0
]
if
len
(
self
.
shape
)
==
1
:
oloc
[()]
=
res
return
out
return
Field
.
from_local_data
(
self
,
res
)
res
*=
res
for
i
in
range
(
1
,
len
(
self
.
shape
)):
tmp
=
np
.
arange
(
oloc
.
shape
[
i
],
dtype
=
np
.
float64
)
+
ib
[
i
]
tmp
=
np
.
arange
(
self
.
local_
shape
[
i
],
dtype
=
np
.
float64
)
+
ib
[
i
]
tmp
=
np
.
minimum
(
tmp
,
self
.
shape
[
i
]
-
tmp
)
*
self
.
bindistances
[
i
]
tmp
*=
tmp
res
=
np
.
add
.
outer
(
res
,
tmp
)
oloc
[()]
=
np
.
sqrt
(
res
)
return
out
return
Field
.
from_local_data
(
self
,
np
.
sqrt
(
res
))
def
get_expk_length_array
(
self
):
# FIXME This is a hack! Only for plotting. Seems not to be the final version.
out
=
exp
(
self
.
get_k_length_array
())
out
.
val
[
1
:]
=
out
.
val
[:
-
1
]
out
.
val
[
0
]
=
0
return
out
out
=
exp
(
self
.
get_k_length_array
())
.
to_global_data
().
copy
()
out
[
1
:]
=
out
[:
-
1
]
out
[
0
]
=
0
return
Field
.
from_global_data
(
self
,
out
)
nifty5/domains/rg_space.py
View file @
d8c42e70
...
...
@@ -95,22 +95,18 @@ class RGSpace(StructuredDomain):
def
get_k_length_array
(
self
):
if
(
not
self
.
harmonic
):
raise
NotImplementedError
out
=
Field
(
self
,
dtype
=
np
.
float64
)
oloc
=
out
.
local_data
ibegin
=
dobj
.
ibegin
(
out
.
val
)
res
=
np
.
arange
(
oloc
.
shape
[
0
],
dtype
=
np
.
float64
)
+
ibegin
[
0
]
ibegin
=
dobj
.
ibegin_from_shape
(
self
.
_shape
)
res
=
np
.
arange
(
self
.
local_shape
[
0
],
dtype
=
np
.
float64
)
+
ibegin
[
0
]
res
=
np
.
minimum
(
res
,
self
.
shape
[
0
]
-
res
)
*
self
.
distances
[
0
]
if
len
(
self
.
shape
)
==
1
:
oloc
[()]
=
res
return
out
return
Field
.
from_local_data
(
self
,
res
)
res
*=
res
for
i
in
range
(
1
,
len
(
self
.
shape
)):
tmp
=
np
.
arange
(
oloc
.
shape
[
i
],
dtype
=
np
.
float64
)
+
ibegin
[
i
]
tmp
=
np
.
arange
(
self
.
local_
shape
[
i
],
dtype
=
np
.
float64
)
+
ibegin
[
i
]
tmp
=
np
.
minimum
(
tmp
,
self
.
shape
[
i
]
-
tmp
)
*
self
.
distances
[
i
]
tmp
*=
tmp
res
=
np
.
add
.
outer
(
res
,
tmp
)
oloc
[()]
=
np
.
sqrt
(
res
)
return
out
return
Field
.
from_local_data
(
self
,
np
.
sqrt
(
res
))
def
get_unique_k_lengths
(
self
):
if
(
not
self
.
harmonic
):
...
...
@@ -147,8 +143,7 @@ class RGSpace(StructuredDomain):
from
..sugar
import
exp
tmp
=
x
*
x
tmp
*=
-
2.
*
np
.
pi
*
np
.
pi
*
sigma
*
sigma
exp
(
tmp
,
out
=
tmp
)
return
tmp
return
exp
(
tmp
)
def
get_fft_smoothing_kernel_function
(
self
,
sigma
):
if
(
not
self
.
harmonic
):
...
...
nifty5/extra/operator_tests.py
View file @
d8c42e70
...
...
@@ -35,9 +35,9 @@ def adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol):
needed_cap
=
op
.
TIMES
|
op
.
ADJOINT_TIMES
if
(
op
.
capability
&
needed_cap
)
!=
needed_cap
:
return
f1
=
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
)
.
lock
()
f2
=
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
)
.
lock
()
res1
=
f1
.
vdot
(
op
.
adjoint_times
(
f2
)
.
lock
()
)
f1
=
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
)
f2
=
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
)
res1
=
f1
.
vdot
(
op
.
adjoint_times
(
f2
))
res2
=
op
.
times
(
f1
).
vdot
(
f2
)
np
.
testing
.
assert_allclose
(
res1
,
res2
,
atol
=
atol
,
rtol
=
rtol
)
...
...
@@ -46,12 +46,12 @@ def inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
needed_cap
=
op
.
TIMES
|
op
.
INVERSE_TIMES
if
(
op
.
capability
&
needed_cap
)
!=
needed_cap
:
return
foo
=
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
)
.
lock
()
res
=
op
(
op
.
inverse_times
(
foo
)
.
lock
()
)
foo
=
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
)
res
=
op
(
op
.
inverse_times
(
foo
))
_assert_allclose
(
res
,
foo
,
atol
=
atol
,
rtol
=
rtol
)
foo
=
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
)
.
lock
()
res
=
op
.
inverse_times
(
op
(
foo
)
.
lock
()
)
foo
=
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
)
res
=
op
.
inverse_times
(
op
(
foo
))
_assert_allclose
(
res
,
foo
,
atol
=
atol
,
rtol
=
rtol
)
...
...
nifty5/field.py
View file @
d8c42e70
This diff is collapsed.
Click to expand it.
nifty5/minimization/conjugate_gradient.py
View file @
d8c42e70
...
...
@@ -66,7 +66,7 @@ class ConjugateGradient(Minimizer):
return
energy
,
status
r
=
energy
.
gradient
d
=
r
.
copy
()
if
preconditioner
is
None
else
preconditioner
(
r
)
d
=
r
if
preconditioner
is
None
else
preconditioner
(
r
)
previous_gamma
=
r
.
vdot
(
d
).
real
if
previous_gamma
==
0
:
...
...
nifty5/minimization/energy.py
View file @
d8c42e70
...
...
@@ -52,7 +52,7 @@ class Energy(NiftyMetaBase()):
def
__init__
(
self
,
position
):
super
(
Energy
,
self
).
__init__
()
self
.
_position
=
position
.
lock
()
self
.
_position
=
position
def
at
(
self
,
position
):
""" Returns a new Energy object, initialized at `position`.
...
...
nifty5/minimization/energy_sum.py
View file @
d8c42e70
...
...
@@ -63,7 +63,7 @@ class EnergySum(Energy):
@
memo
def
gradient
(
self
):
return
my_lincomb
(
map
(
lambda
v
:
v
.
gradient
,
self
.
_energies
),
self
.
_factors
)
.
lock
()
self
.
_factors
)
@
property
@
memo
...
...
nifty5/minimization/line_energy.py
View file @
d8c42e70
...
...
@@ -48,7 +48,7 @@ class LineEnergy(object):
def
__init__
(
self
,
line_position
,
energy
,
line_direction
,
offset
=
0.
):
super
(
LineEnergy
,
self
).
__init__
()
self
.
_line_position
=
float
(
line_position
)
self
.
_line_direction
=
line_direction
.
lock
()
self
.
_line_direction
=
line_direction
if
self
.
_line_position
==
float
(
offset
):
self
.
_energy
=
energy
...
...
nifty5/minimization/quadratic_energy.py
View file @
d8c42e70
...
...
@@ -35,7 +35,6 @@ class QuadraticEnergy(Energy):
else
:
Ax
=
self
.
_A
(
self
.
position
)
self
.
_grad
=
Ax
if
b
is
None
else
Ax
-
b
self
.
_grad
.
lock
()
self
.
_value
=
0.5
*
self
.
position
.
vdot
(
Ax
)
if
b
is
not
None
:
self
.
_value
-=
b
.
vdot
(
self
.
position
)
...
...
nifty5/minimization/scipy_minimizer.py
View file @
d8c42e70
...
...
@@ -33,7 +33,7 @@ def _toFlatNdarray(fld):
def
_toField
(
arr
,
dom
):
return
Field
.
from_global_data
(
dom
,
arr
.
reshape
(
dom
.
shape
))
return
Field
.
from_global_data
(
dom
,
arr
.
reshape
(
dom
.
shape
)
.
copy
()
)
class
_MinHelper
(
object
):
...
...
@@ -44,7 +44,7 @@ class _MinHelper(object):
def
_update
(
self
,
x
):
pos
=
_toField
(
x
,
self
.
_domain
)
if
(
pos
!=
self
.
_energy
.
position
).
any
():
self
.
_energy
=
self
.
_energy
.
at
(
pos
.
locked_copy
()
)
self
.
_energy
=
self
.
_energy
.
at
(
pos
)
def
fun
(
self
,
x
):
self
.
_update
(
x
)
...
...
nifty5/minimization/vl_bfgs.py
View file @
d8c42e70
...
...
@@ -109,8 +109,8 @@ class _InformationStore(object):
self
.
max_history_length
=
max_history_length
self
.
s
=
[
None
]
*
max_history_length
self
.
y
=
[
None
]
*
max_history_length
self
.
last_x
=
x0
.
copy
()
self
.
last_gradient
=
gradient
.
copy
()
self
.
last_x
=
x0
self
.
last_gradient
=
gradient
self
.
k
=
0
mmax
=
max_history_length
...
...
@@ -233,7 +233,7 @@ class _InformationStore(object):
self
.
s
[
self
.
k
%
mmax
]
=
x
-
self
.
last_x
self
.
y
[
self
.
k
%
mmax
]
=
gradient
-
self
.
last_gradient
self
.
last_x
=
x
.
copy
()
self
.
last_gradient
=
gradient
.
copy
()
self
.
last_x
=
x
self
.
last_gradient
=
gradient
self
.
k
+=
1
nifty5/multi/multi_field.py
View file @
d8c42e70
...
...
@@ -69,18 +69,6 @@ class MultiField(object):
dtype
[
key
],
**
kwargs
)
for
key
in
sorted
(
domain
.
keys
())})
def
fill
(
self
,
fill_value
):
"""Fill `self` uniformly with `fill_value`
Parameters
----------
fill_value: float or complex or int
The value to fill the field with.
"""
for
val
in
self
.
_val
.
values
():
val
.
fill
(
fill_value
)
return
self
def
_check_domain
(
self
,
other
):
if
other
.
_domain
!=
self
.
_domain
:
raise
ValueError
(
"domains are incompatible."
)
...
...
@@ -92,27 +80,6 @@ class MultiField(object):
result
+=
sub_field
.
vdot
(
x
[
key
])
return
result
def
lock
(
self
):
for
v
in
self
.
values
():
v
.
lock
()
return
self
@
property
def
locked
(
self
):
return
all
(
v
.
locked
for
v
in
self
.
values
())
def
copy
(
self
):
return
MultiField
({
key
:
val
.
copy
()
for
key
,
val
in
self
.
items
()})
def
locked_copy
(
self
):
if
self
.
locked
:
return
self
return
MultiField
({
key
:
val
.
locked_copy
()
for
key
,
val
in
self
.
items
()})
def
empty_copy
(
self
):
return
MultiField
({
key
:
val
.
empty_copy
()
for
key
,
val
in
self
.
items
()})
@
staticmethod
def
build_dtype
(
dtype
,
domain
):
if
isinstance
(
dtype
,
dict
):
...
...
@@ -121,12 +88,6 @@ class MultiField(object):
dtype
=
np
.
float64
return
{
key
:
dtype
for
key
in
domain
.
keys
()}
@
staticmethod
def
empty
(
domain
,
dtype
=
None
):
dtype
=
MultiField
.
build_dtype
(
dtype
,
domain
)
return
MultiField
({
key
:
Field
.
empty
(
dom
,
dtype
=
dtype
[
key
])
for
key
,
dom
in
domain
.
items
()})
@
staticmethod
def
full
(
domain
,
val
):
return
MultiField
({
key
:
Field
.
full
(
dom
,
val
)
...
...
@@ -241,9 +202,9 @@ for op in ["__add__", "__radd__",
result_val
[
key
]
=
getattr
(
self
[
key
],
op
)(
other
[
key
])
if
op
in
(
"__add__"
,
"__radd__"
):
for
key
in
only_self_keys
:
result_val
[
key
]
=
self
[
key
]
.
copy
()
result_val
[
key
]
=
self
[
key
]
for
key
in
only_other_keys
:
result_val
[
key
]
=
other
[
key
]
.
copy
()
result_val
[
key
]
=
other
[
key
]
elif
op
in
(
"__mul__"
,
"__rmul__"
):
pass
else
:
...
...
nifty5/operators/diagonal_operator.py
View file @
d8c42e70
...
...
@@ -185,7 +185,7 @@ class DiagonalOperator(EndomorphicOperator):
res
=
Field
.
from_random
(
random_type
=
"normal"
,
domain
=
self
.
_domain
,
dtype
=
dtype
)
if
from_inverse
:
res
.
local_data
[()]
/=
np
.
sqrt
(
self
.
_ldiag
)
res
/=
np
.
sqrt
(
self
.
_ldiag
)
else
:
res
.
local_data
[()]
*=
np
.
sqrt
(
self
.
_ldiag
)
res
*=
np
.
sqrt
(
self
.
_ldiag
)
return
res
nifty5/operators/dof_distributor.py
View file @
d8c42e70
...
...
@@ -120,15 +120,14 @@ class DOFDistributor(LinearOperator):
return
res
def
_times
(
self
,
x
):
res
=
Field
.
empty
(
self
.
_target
,
dtype
=
x
.
dtype
)
if
dobj
.
distaxis
(
x
.
val
)
in
x
.
domain
.
axes
[
self
.
_space
]:
arr
=
x
.
to_global_data
()
else
:
arr
=
x
.
local_data
arr
=
arr
.
reshape
(
self
.
_hshape
)
oarr
=
arr
[(
slice
(
None
),
self
.
_dofdex
,
slice
(
None
))]
return
Field
.
from_local_data
(
self
.
_target
,
oarr
.
reshape
(
self
.
_target
.
local_shape
))
oarr
=
res
.
local_data
.
reshape
(
self
.
_pshape
)
oarr
[()]
=
arr
[(
slice
(
None
),
self
.
_dofdex
,
slice
(
None
))]
return
res
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
...
...
nifty5/operators/inversion_enabler.py
View file @
d8c42e70
...
...
@@ -23,6 +23,7 @@ from ..minimization.conjugate_gradient import ConjugateGradient
from
..minimization.iteration_controller
import
IterationController
from
..minimization.quadratic_energy
import
QuadraticEnergy
from
.endomorphic_operator
import
EndomorphicOperator
from
..sugar
import
full
class
InversionEnabler
(
EndomorphicOperator
):
...
...
@@ -65,7 +66,7 @@ class InversionEnabler(EndomorphicOperator):
if
self
.
_op
.
capability
&
mode
:
return
self
.
_op
.
apply
(
x
,
mode
)
x0
=
x
.
empty_copy
().
fill
(
0.
)
x0
=
full
(
x
.
domain
,
0.
)
invmode
=
self
.
_modeTable
[
self
.
INVERSE_BIT
][
self
.
_ilog
[
mode
]]
invop
=
self
.
_op
.
_flip_modes
(
self
.
_ilog
[
invmode
])
prec
=
self
.
_approximation
...
...
nifty5/operators/scaling_operator.py
View file @
d8c42e70
...
...
@@ -22,6 +22,7 @@ from ..field import Field
from
..multi.multi_field
import
MultiField
from
.endomorphic_operator
import
EndomorphicOperator
from
..domain_tuple
import
DomainTuple
from
..sugar
import
full
class
ScalingOperator
(
EndomorphicOperator
):
...
...
@@ -61,9 +62,9 @@ class ScalingOperator(EndomorphicOperator):
self
.
_check_input
(
x
,
mode
)
if
self
.
_factor
==
1.
:
return
x
.
copy
()
return
x
if
self
.
_factor
==
0.
:
return
x
.
empty_copy
().
fill
(
0.
)
return
full
(
self
.
domain
,
0.
)
if
mode
==
self
.
TIMES
:
return
x
*
self
.
_factor
...
...
nifty5/operators/selection_operator.py
View file @
d8c42e70
...
...
@@ -50,10 +50,9 @@ class SelectionOperator(LinearOperator):
return
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
# FIXME Is the copying necessary?
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
return
x
[
self
.
_key
]
.
copy
()
return
x
[
self
.
_key
]
else
:
from
..multi.multi_field
import
MultiField
return
MultiField
({
self
.
_key
:
x
.
copy
()
})
return
MultiField
({
self
.
_key
:
x
})
nifty5/operators/symmetrizing_operator.py
View file @
d8c42e70
...
...
@@ -15,7 +15,7 @@ class SymmetrizingOperator(EndomorphicOperator):
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
tmp
=
x
.
copy
().
val
tmp
=
x
.
val
.
copy
()
ax
=
dobj
.
distaxis
(
tmp
)
globshape
=
tmp
.
shape
for
i
in
range
(
self
.
_ndim
):
...
...
nifty5/sugar.py
View file @
d8c42e70
...
...
@@ -31,7 +31,7 @@ from .logger import logger
__all__
=
[
'PS_field'
,
'power_analyze'
,
'create_power_operator'
,
'create_harmonic_smoothing_operator'
,
'from_random'
,
'full'
,
'
empty'
,
'
from_global_data'
,
'from_local_data'
,
'full'
,
'from_global_data'
,
'from_local_data'
,
'makeDomain'
,
'sqrt'
,
'exp'
,
'log'
,
'tanh'
,
'conjugate'
,
'get_signal_variance'
,
'makeOp'
]
...
...
@@ -203,12 +203,6 @@ def full(domain, val):
return
Field
.
full
(
domain
,
val
)
def
empty
(
domain
,
dtype
):
if
isinstance
(
domain
,
(
dict
,
MultiDomain
)):
return
MultiField
.
empty
(
domain
,
dtype
)
return
Field
.
empty
(
domain
,
dtype
)
def
from_random
(
random_type
,
domain
,
dtype
=
np
.
float64
,
**
kwargs
):
if
isinstance
(
domain
,
(
dict
,
MultiDomain
)):
return
MultiField
.
from_random
(
random_type
,
domain
,
dtype
,
**
kwargs
)
...
...
@@ -248,26 +242,13 @@ _current_module = sys.modules[__name__]
for
f
in
[
"sqrt"
,
"exp"
,
"log"
,
"tanh"
,
"conjugate"
]:
def
func
(
f
):
def
func2
(
x
,
out
=
None
):
def
func2
(
x
):
if
isinstance
(
x
,
MultiField
):
if
out
is
not
None
:
if
(
not
isinstance
(
out
,
MultiField
)
or
x
.
_domain
!=
out
.
_domain
):
raise
ValueError
(
"Bad 'out' argument"
)
for
key
,
value
in
x
.
items
():
func2
(
value
,
out
=
out
[
key
])
return
out
return
MultiField
({
key
:
func2
(
val
)
for
key
,
val
in
x
.
items
()})
elif
isinstance
(
x
,
Field
):
fu
=
getattr
(
dobj
,
f
)
if
out
is
not
None
:
if
not
isinstance
(
out
,
Field
)
or
x
.
_domain
!=
out
.
_domain
:
raise
ValueError
(
"Bad 'out' argument"
)
fu
(
x
.
val
,
out
=
out
.
val
)
return
out
else
:
return
Field
(
domain
=
x
.
_domain
,
val
=
fu
(
x
.
val
))
return
Field
(
domain
=
x
.
_domain
,
val
=
fu
(
x
.
val
))
else
:
return
getattr
(
np
,
f
)(
x
,
out
)
return
getattr
(
np
,
f
)(
x
)
return
func2
setattr
(
_current_module
,
f
,
func
(
f
))
test/test_field.py
View file @
d8c42e70
...
...
@@ -124,21 +124,6 @@ class Test_Functionality(unittest.TestCase):
res
=
m
.
vdot
(
m
,
spaces
=
1
)
assert_allclose
(
res
.
local_data
,
37.5
)
def
test_lock
(
self
):
s1
=
ift
.
RGSpace
((
10
,))
f1
=
ift
.
Field
(
s1
,
27
)
assert_equal
(
f1
.
locked
,
False
)
f1
.
lock
()
assert_equal
(
f1
.
locked
,
True
)
with
assert_raises
(
ValueError
):
f1
+=
f1
assert_equal
(
f1
.
locked_copy
()
is
f1
,
True
)
def
test_fill
(
self
):
s1
=
ift
.
RGSpace
((
10
,))
f1
=
ift
.
Field
(
s1
,
27
)
assert_equal
(
f1
.
fill
(
10
).
local_data
,
10
)
def
test_dataconv
(
self
):
s1
=
ift
.
RGSpace
((
10
,))
ld
=
np
.
arange
(
ift
.
dobj
.
local_shape
(
s1
.
shape
)[
0
])
...
...
@@ -158,12 +143,6 @@ class Test_Functionality(unittest.TestCase):
assert_equal
(
f
.
local_data
,
5
)
f
=
ift
.
Field
(
None
,
5
)
assert_equal
(
f
.
local_data
,
5
)
assert_equal
(
f
.
empty_copy
().
domain
,
f
.
domain
)
assert_equal
(
f
.
empty_copy
().
dtype
,
f
.
dtype
)
assert_equal
(
f
.
copy
().
domain
,
f
.
domain
)
assert_equal
(
f
.
copy
().
dtype
,
f
.
dtype
)
assert_equal
(
f
.
copy
().
local_data
,
f
.
local_data
)
assert_equal
(
f
.
copy
()
is
f
,
False
)
def
test_trivialities
(
self
):
s1
=
ift
.
RGSpace
((
10
,))
...
...
@@ -182,8 +161,7 @@ class Test_Functionality(unittest.TestCase):
def
test_weight
(
self
):
s1
=
ift
.
RGSpace
((
10
,))
f
=
ift
.
Field
(
s1
,
10.
)
f2
=
f
.
copy
()
f
.
weight
(
1
,
out
=
f2
)
f2
=
f
.
weight
(
1
)
assert_equal
(
f
.
weight
(
1
).
local_data
,
f2
.
local_data
)
assert_equal
(
f
.
total_volume
(),
1
)
assert_equal
(
f
.
total_volume
(
0
),
1
)
...
...
@@ -233,10 +211,6 @@ class Test_Functionality(unittest.TestCase):
f1
.
vdot
(
42
)
with
assert_raises
(
ValueError
):
f1
.
vdot
(
ift
.
Field
(
s2
,
1.
))
with
assert_raises
(
TypeError
):
f1
.
copy_content_from
(
1
)
with
assert_raises
(
ValueError
):
f1
.
copy_content_from
(
ift
.
Field
(
s2
,
1.
))
with
assert_raises
(
TypeError
):
ift
.
full
(
s1
,
[
2
,
3
])
...
...
@@ -246,9 +220,6 @@ class Test_Functionality(unittest.TestCase):
assert_equal
(
f
.
local_data
,
27
)
assert_equal
(
f
.
shape
,
(
200
,))
assert_equal
(
f
.
dtype
,
np
.
int
)
fx
=
ift
.
empty
(
f
.
domain
,
f
.
dtype
)
assert_equal
(
f
.
dtype
,
fx
.
dtype
)
assert_equal
(
f
.
shape
,
fx
.
shape
)
fx
=
ift
.
full
(
f
.
domain
,
0
)
assert_equal
(
f
.
dtype
,
fx
.
dtype
)
assert_equal
(
f
.
shape
,
fx
.
shape
)
...
...
test/test_multi_field.py
View file @
d8c42e70
...
...
@@ -32,19 +32,6 @@ class Test_Functionality(unittest.TestCase):
f2
=
ift
.
from_random
(
"normal"
,
domain
=
dom
,
dtype
=
np
.
complex128
)
assert_allclose
(
f1
.
vdot
(
f2
),
np
.
conj
(
f2
.
vdot
(
f1
)))
def
test_lock
(
self
):
f1
=
ift
.
full
(
dom
,
27
)
assert_equal
(
f1
.
locked
,
False
)
f1
.
lock
()
assert_equal
(
f1
.
locked
,
True
)
assert_equal
(
f1
.
locked_copy
()
is
f1
,
True
)
def
test_fill
(
self
):
f1
=
ift
.
full
(
dom
,
27
)
f1
.
fill
(
10
)
for
val
in
f1
.
values
():
assert_equal
((
val
==
10
).
all
(),
True
)
def
test_dataconv
(
self
):
f1
=
ift
.
full
(
dom
,
27
)
f2
=
ift
.
from_global_data
(
dom
,
f1
.
to_global_data
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment