Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
N
NIFTy
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
13
Issues
13
List
Boards
Labels
Service Desk
Milestones
Merge Requests
8
Merge Requests
8
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Operations
Operations
Incidents
Environments
Packages & Registries
Packages & Registries
Container Registry
Analytics
Analytics
CI / CD
Repository
Value Stream
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ift
NIFTy
Commits
d8c42e70
Commit
d8c42e70
authored
Jul 04, 2018
by
Martin Reinecke
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
make Fields and MultiFields immutable
parent
c3c4a8c4
Changes
23
Hide whitespace changes
Inline
Side-by-side
Showing
23 changed files
with
99 additions
and
290 deletions
+99
-290
nifty5/domain_tuple.py
nifty5/domain_tuple.py
+10
-0
nifty5/domains/domain.py
nifty5/domains/domain.py
+10
-0
nifty5/domains/log_rg_space.py
nifty5/domains/log_rg_space.py
+10
-14
nifty5/domains/rg_space.py
nifty5/domains/rg_space.py
+6
-11
nifty5/extra/operator_tests.py
nifty5/extra/operator_tests.py
+7
-7
nifty5/field.py
nifty5/field.py
+27
-128
nifty5/minimization/conjugate_gradient.py
nifty5/minimization/conjugate_gradient.py
+1
-1
nifty5/minimization/energy.py
nifty5/minimization/energy.py
+1
-1
nifty5/minimization/energy_sum.py
nifty5/minimization/energy_sum.py
+1
-1
nifty5/minimization/line_energy.py
nifty5/minimization/line_energy.py
+1
-1
nifty5/minimization/quadratic_energy.py
nifty5/minimization/quadratic_energy.py
+0
-1
nifty5/minimization/scipy_minimizer.py
nifty5/minimization/scipy_minimizer.py
+2
-2
nifty5/minimization/vl_bfgs.py
nifty5/minimization/vl_bfgs.py
+4
-4
nifty5/multi/multi_field.py
nifty5/multi/multi_field.py
+2
-41
nifty5/operators/diagonal_operator.py
nifty5/operators/diagonal_operator.py
+2
-2
nifty5/operators/dof_distributor.py
nifty5/operators/dof_distributor.py
+2
-3
nifty5/operators/inversion_enabler.py
nifty5/operators/inversion_enabler.py
+2
-1
nifty5/operators/scaling_operator.py
nifty5/operators/scaling_operator.py
+3
-2
nifty5/operators/selection_operator.py
nifty5/operators/selection_operator.py
+2
-3
nifty5/operators/symmetrizing_operator.py
nifty5/operators/symmetrizing_operator.py
+1
-1
nifty5/sugar.py
nifty5/sugar.py
+4
-23
test/test_field.py
test/test_field.py
+1
-30
test/test_multi_field.py
test/test_multi_field.py
+0
-13
No files found.
nifty5/domain_tuple.py
View file @
d8c42e70
...
...
@@ -104,6 +104,16 @@ class DomainTuple(object):
"""
return
self
.
_shape
@
property
def
local_shape
(
self
):
"""tuple of int: number of pixels along each axis on the local task
The shape of the array-like object required to store information
living on part of the domain which is stored on the local MPI task.
"""
from
.dobj
import
local_shape
return
local_shape
(
self
.
_shape
)
@
property
def
size
(
self
):
"""int : total number of pixels.
...
...
nifty5/domains/domain.py
View file @
d8c42e70
...
...
@@ -88,6 +88,16 @@ class Domain(NiftyMetaBase()):
"""
raise
NotImplementedError
@
property
def
local_shape
(
self
):
"""tuple of int: number of pixels along each axis on the local task
The shape of the array-like object required to store information
living on part of the domain which is stored on the local MPI task.
"""
from
..dobj
import
local_shape
return
local_shape
(
self
.
shape
)
@
abc
.
abstractproperty
def
size
(
self
):
"""int: total number of pixels.
...
...
nifty5/domains/log_rg_space.py
View file @
d8c42e70
...
...
@@ -3,7 +3,7 @@ from ..sugar import exp
import
numpy
as
np
from
..
dobj
import
ibegin
from
..
import
dobj
from
..field
import
Field
from
.structured_domain
import
StructuredDomain
...
...
@@ -62,26 +62,22 @@ class LogRGSpace(StructuredDomain):
np
.
zeros
(
len
(
self
.
shape
)),
True
)
def
get_k_length_array
(
self
):
out
=
Field
(
self
,
dtype
=
np
.
float64
)
oloc
=
out
.
local_data
ib
=
ibegin
(
out
.
val
)
res
=
np
.
arange
(
oloc
.
shape
[
0
],
dtype
=
np
.
float64
)
+
ib
[
0
]
ib
=
dobj
.
ibegin_from_shape
(
self
.
_shape
)
res
=
np
.
arange
(
self
.
local_shape
[
0
],
dtype
=
np
.
float64
)
+
ib
[
0
]
res
=
np
.
minimum
(
res
,
self
.
shape
[
0
]
-
res
)
*
self
.
bindistances
[
0
]
if
len
(
self
.
shape
)
==
1
:
oloc
[()]
=
res
return
out
return
Field
.
from_local_data
(
self
,
res
)
res
*=
res
for
i
in
range
(
1
,
len
(
self
.
shape
)):
tmp
=
np
.
arange
(
oloc
.
shape
[
i
],
dtype
=
np
.
float64
)
+
ib
[
i
]
tmp
=
np
.
arange
(
self
.
local_
shape
[
i
],
dtype
=
np
.
float64
)
+
ib
[
i
]
tmp
=
np
.
minimum
(
tmp
,
self
.
shape
[
i
]
-
tmp
)
*
self
.
bindistances
[
i
]
tmp
*=
tmp
res
=
np
.
add
.
outer
(
res
,
tmp
)
oloc
[()]
=
np
.
sqrt
(
res
)
return
out
return
Field
.
from_local_data
(
self
,
np
.
sqrt
(
res
))
def
get_expk_length_array
(
self
):
# FIXME This is a hack! Only for plotting. Seems not to be the final version.
out
=
exp
(
self
.
get_k_length_array
())
out
.
val
[
1
:]
=
out
.
val
[:
-
1
]
out
.
val
[
0
]
=
0
return
out
out
=
exp
(
self
.
get_k_length_array
())
.
to_global_data
().
copy
()
out
[
1
:]
=
out
[:
-
1
]
out
[
0
]
=
0
return
Field
.
from_global_data
(
self
,
out
)
nifty5/domains/rg_space.py
View file @
d8c42e70
...
...
@@ -95,22 +95,18 @@ class RGSpace(StructuredDomain):
def
get_k_length_array
(
self
):
if
(
not
self
.
harmonic
):
raise
NotImplementedError
out
=
Field
(
self
,
dtype
=
np
.
float64
)
oloc
=
out
.
local_data
ibegin
=
dobj
.
ibegin
(
out
.
val
)
res
=
np
.
arange
(
oloc
.
shape
[
0
],
dtype
=
np
.
float64
)
+
ibegin
[
0
]
ibegin
=
dobj
.
ibegin_from_shape
(
self
.
_shape
)
res
=
np
.
arange
(
self
.
local_shape
[
0
],
dtype
=
np
.
float64
)
+
ibegin
[
0
]
res
=
np
.
minimum
(
res
,
self
.
shape
[
0
]
-
res
)
*
self
.
distances
[
0
]
if
len
(
self
.
shape
)
==
1
:
oloc
[()]
=
res
return
out
return
Field
.
from_local_data
(
self
,
res
)
res
*=
res
for
i
in
range
(
1
,
len
(
self
.
shape
)):
tmp
=
np
.
arange
(
oloc
.
shape
[
i
],
dtype
=
np
.
float64
)
+
ibegin
[
i
]
tmp
=
np
.
arange
(
self
.
local_
shape
[
i
],
dtype
=
np
.
float64
)
+
ibegin
[
i
]
tmp
=
np
.
minimum
(
tmp
,
self
.
shape
[
i
]
-
tmp
)
*
self
.
distances
[
i
]
tmp
*=
tmp
res
=
np
.
add
.
outer
(
res
,
tmp
)
oloc
[()]
=
np
.
sqrt
(
res
)
return
out
return
Field
.
from_local_data
(
self
,
np
.
sqrt
(
res
))
def
get_unique_k_lengths
(
self
):
if
(
not
self
.
harmonic
):
...
...
@@ -147,8 +143,7 @@ class RGSpace(StructuredDomain):
from
..sugar
import
exp
tmp
=
x
*
x
tmp
*=
-
2.
*
np
.
pi
*
np
.
pi
*
sigma
*
sigma
exp
(
tmp
,
out
=
tmp
)
return
tmp
return
exp
(
tmp
)
def
get_fft_smoothing_kernel_function
(
self
,
sigma
):
if
(
not
self
.
harmonic
):
...
...
nifty5/extra/operator_tests.py
View file @
d8c42e70
...
...
@@ -35,9 +35,9 @@ def adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol):
needed_cap
=
op
.
TIMES
|
op
.
ADJOINT_TIMES
if
(
op
.
capability
&
needed_cap
)
!=
needed_cap
:
return
f1
=
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
)
.
lock
()
f2
=
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
)
.
lock
()
res1
=
f1
.
vdot
(
op
.
adjoint_times
(
f2
)
.
lock
()
)
f1
=
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
)
f2
=
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
)
res1
=
f1
.
vdot
(
op
.
adjoint_times
(
f2
))
res2
=
op
.
times
(
f1
).
vdot
(
f2
)
np
.
testing
.
assert_allclose
(
res1
,
res2
,
atol
=
atol
,
rtol
=
rtol
)
...
...
@@ -46,12 +46,12 @@ def inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
needed_cap
=
op
.
TIMES
|
op
.
INVERSE_TIMES
if
(
op
.
capability
&
needed_cap
)
!=
needed_cap
:
return
foo
=
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
)
.
lock
()
res
=
op
(
op
.
inverse_times
(
foo
)
.
lock
()
)
foo
=
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
)
res
=
op
(
op
.
inverse_times
(
foo
))
_assert_allclose
(
res
,
foo
,
atol
=
atol
,
rtol
=
rtol
)
foo
=
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
)
.
lock
()
res
=
op
.
inverse_times
(
op
(
foo
)
.
lock
()
)
foo
=
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
)
res
=
op
.
inverse_times
(
op
(
foo
))
_assert_allclose
(
res
,
foo
,
atol
=
atol
,
rtol
=
rtol
)
...
...
nifty5/field.py
View file @
d8c42e70
...
...
@@ -35,7 +35,7 @@ class Field(object):
----------
domain : None, DomainTuple, tuple of Domain, or Domain
val :
None, Field, data_object,
or scalar
val :
Field, data_object
or scalar
The values the array should contain after init. A scalar input will
fill the whole array with this scalar. If a data_object is provided,
its dimensions must match the domain's.
...
...
@@ -49,32 +49,30 @@ class Field(object):
many convenience functions for Field conatruction!
"""
def
__init__
(
self
,
domain
=
None
,
val
=
None
,
dtype
=
None
,
copy
=
False
,
locked
=
False
):
def
__init__
(
self
,
domain
=
None
,
val
=
None
,
dtype
=
None
):
self
.
_domain
=
self
.
_infer_domain
(
domain
=
domain
,
val
=
val
)
dtype
=
self
.
_infer_dtype
(
dtype
=
dtype
,
val
=
val
)
if
isinstance
(
val
,
Field
):
if
self
.
_domain
!=
val
.
_domain
:
raise
ValueError
(
"Domain mismatch"
)
self
.
_val
=
dobj
.
from_object
(
val
.
val
,
dtype
=
dtype
,
copy
=
copy
,
set_locked
=
locked
)
self
.
_val
=
val
.
_val
elif
(
np
.
isscalar
(
val
)):
self
.
_val
=
dobj
.
full
(
self
.
_domain
.
shape
,
dtype
=
dtype
,
fill_value
=
val
)
elif
isinstance
(
val
,
dobj
.
data_object
):
if
self
.
_domain
.
shape
==
val
.
shape
:
self
.
_val
=
dobj
.
from_object
(
val
,
dtype
=
dtype
,
copy
=
copy
,
set_locked
=
locked
)
if
dtype
==
val
.
dtype
:
self
.
_val
=
val
else
:
self
.
_val
=
dobj
.
from_object
(
val
,
dtype
,
True
,
True
)
else
:
raise
ValueError
(
"Shape mismatch"
)
elif
val
is
None
:
self
.
_val
=
dobj
.
empty
(
self
.
_domain
.
shape
,
dtype
=
dtype
)
else
:
raise
TypeError
(
"unknown source type"
)
if
locked
:
dobj
.
lock
(
self
.
_val
)
dobj
.
lock
(
self
.
_val
)
# prevent implicit conversion to bool
def
__nonzero__
(
self
):
...
...
@@ -84,7 +82,7 @@ class Field(object):
raise
TypeError
(
"Field does not support implicit conversion to bool"
)
@
staticmethod
def
full
(
domain
,
val
,
dtype
=
None
):
def
full
(
domain
,
val
):
"""Creates a Field with a given domain, filled with a constant value.
Parameters
...
...
@@ -101,11 +99,7 @@ class Field(object):
"""
if
not
np
.
isscalar
(
val
):
raise
TypeError
(
"val must be a scalar"
)
return
Field
(
DomainTuple
.
make
(
domain
),
val
,
dtype
)
@
staticmethod
def
empty
(
domain
,
dtype
=
None
):
return
Field
(
DomainTuple
.
make
(
domain
),
None
,
dtype
)
return
Field
(
DomainTuple
.
make
(
domain
),
val
)
@
staticmethod
def
from_global_data
(
domain
,
arr
,
sum_up
=
False
):
...
...
@@ -152,11 +146,6 @@ class Field(object):
Returns a handle to the part of the array data residing on the local
task (or to the entore array if MPI is not active).
Notes
-----
If the field is not locked, the array data can be modified.
Use with care!
"""
return
dobj
.
local_data
(
self
.
_val
)
...
...
@@ -196,8 +185,6 @@ class Field(object):
return
dtype
if
val
is
None
:
raise
ValueError
(
"could not infer dtype"
)
if
isinstance
(
val
,
Field
):
return
val
.
dtype
return
np
.
result_type
(
val
)
@
staticmethod
...
...
@@ -223,41 +210,6 @@ class Field(object):
val
=
dobj
.
from_random
(
random_type
,
dtype
=
dtype
,
shape
=
domain
.
shape
,
**
kwargs
))
def
fill
(
self
,
fill_value
):
"""Fill `self` uniformly with `fill_value`
Parameters
----------
fill_value: float or complex or int
The value to fill the field with.
"""
self
.
_val
.
fill
(
fill_value
)
return
self
def
lock
(
self
):
"""Write-protect the data content of `self`.
After this call, it will no longer be possible to change the data
entries of `self`. This is convenient if, for example, a
DiagonalOperator wants to ensure that its diagonal cannot be modified
inadvertently, without making copies.
Notes
-----
This will not only prohibit modifications to the entries of `self`, but
also to the entries of any other Field or numpy array pointing to the
same data. If an unlocked instance is needed, use copy().
The fact that there is no `unlock()` method is deliberate.
"""
dobj
.
lock
(
self
.
_val
)
return
self
@
property
def
locked
(
self
):
"""bool : True iff the field's data content has been locked"""
return
dobj
.
locked
(
self
.
_val
)
@
property
def
val
(
self
):
"""dobj.data_object : the data object storing the field's entries
...
...
@@ -303,43 +255,6 @@ class Field(object):
raise
ValueError
(
".imag called on a non-complex Field"
)
return
Field
(
self
.
_domain
,
self
.
val
.
imag
)
def
copy
(
self
):
""" Returns a full copy of the Field.
The returned object will be an identical copy of the original Field.
The copy will be writeable, even if `self` was locked.
Returns
-------
Field
An identical, but unlocked copy of 'self'.
"""
return
Field
(
val
=
self
,
copy
=
True
)
def
empty_copy
(
self
):
""" Returns a Field with identical domain and data type, but
uninitialized data.
Returns
-------
Field
A copy of 'self', with uninitialized data.
"""
return
Field
(
self
.
_domain
,
dtype
=
self
.
dtype
)
def
locked_copy
(
self
):
""" Returns a read-only version of the Field.
If `self` is locked, returns `self`. Otherwise returns a locked copy
of `self`.
Returns
-------
Field
A read-only version of `self`.
"""
return
self
if
self
.
locked
else
Field
(
val
=
self
,
copy
=
True
,
locked
=
True
)
def
scalar_weight
(
self
,
spaces
=
None
):
"""Returns the uniform volume element for a sub-domain of `self`.
...
...
@@ -392,7 +307,7 @@ class Field(object):
res
*=
self
.
_domain
[
i
].
total_volume
return
res
def
weight
(
self
,
power
=
1
,
spaces
=
None
,
out
=
None
):
def
weight
(
self
,
power
=
1
,
spaces
=
None
):
""" Weights the pixels of `self` with their invidual pixel-volume.
Parameters
...
...
@@ -404,21 +319,12 @@ class Field(object):
Determines on which sub-domain the operation takes place.
If None, the entire domain is used.
out : Field or None
if not None, the result is returned in a new Field
otherwise the contents of "out" are overwritten with the result.
"out" may be identical to "self"!
Returns
-------
Field
The weighted field.
"""
if
out
is
None
:
out
=
self
.
copy
()
else
:
if
out
is
not
self
:
out
.
copy_content_from
(
self
)
aout
=
self
.
local_data
.
copy
()
spaces
=
utilities
.
parse_spaces
(
spaces
,
len
(
self
.
_domain
))
...
...
@@ -435,12 +341,12 @@ class Field(object):
if
dobj
.
distaxis
(
self
.
_val
)
>=
0
and
ind
==
0
:
# we need to distribute the weights along axis 0
wgt
=
dobj
.
local_data
(
dobj
.
from_global_data
(
wgt
))
out
.
local_data
[()]
*=
wgt
**
power
aout
*=
wgt
**
power
fct
=
fct
**
power
if
fct
!=
1.
:
out
*=
fct
a
out
*=
fct
return
out
return
Field
.
from_local_data
(
self
.
_domain
,
aout
)
def
vdot
(
self
,
x
=
None
,
spaces
=
None
):
""" Computes the dot product of 'self' with x.
...
...
@@ -508,7 +414,7 @@ class Field(object):
# ---General unary/contraction methods---
def
__pos__
(
self
):
return
self
.
copy
()
return
self
def
__neg__
(
self
):
return
Field
(
self
.
_domain
,
-
self
.
val
)
...
...
@@ -538,7 +444,7 @@ class Field(object):
for
i
,
dom
in
enumerate
(
self
.
_domain
)
if
i
not
in
spaces
)
return
Field
(
domain
=
return_domain
,
val
=
data
,
copy
=
False
)
return
Field
(
domain
=
return_domain
,
val
=
data
)
def
sum
(
self
,
spaces
=
None
):
"""Sums up over the sub-domains given by `spaces`.
...
...
@@ -713,13 +619,6 @@ class Field(object):
return
self
.
_contraction_helper
(
'std'
,
spaces
)
return
sqrt
(
self
.
var
(
spaces
))
def
copy_content_from
(
self
,
other
):
if
not
isinstance
(
other
,
Field
):
raise
TypeError
(
"argument must be a Field"
)
if
other
.
_domain
!=
self
.
_domain
:
raise
ValueError
(
"domains are incompatible."
)
self
.
local_data
[()]
=
other
.
local_data
[()]
def
__repr__
(
self
):
return
"<nifty5.Field>"
...
...
@@ -745,13 +644,13 @@ class Field(object):
return
self
.
isEquivalentTo
(
other
)
for
op
in
[
"__add__"
,
"__radd__"
,
"__iadd__"
,
"__sub__"
,
"__rsub__"
,
"__isub__"
,
"__mul__"
,
"__rmul__"
,
"__imul__"
,
"__div__"
,
"__rdiv__"
,
"__idiv__"
,
"__truediv__"
,
"__rtruediv__"
,
"__itruediv__"
,
"__floordiv__"
,
"__rfloordiv__"
,
"__ifloordiv__"
,
"__pow__"
,
"__rpow__"
,
"__ipow__"
,
for
op
in
[
"__add__"
,
"__radd__"
,
"__sub__"
,
"__rsub__"
,
"__mul__"
,
"__rmul__"
,
"__div__"
,
"__rdiv__"
,
"__truediv__"
,
"__rtruediv__"
,
"__floordiv__"
,
"__rfloordiv__"
,
"__pow__"
,
"__rpow__"
,
"__lt__"
,
"__le__"
,
"__gt__"
,
"__ge__"
,
"__eq__"
,
"__ne__"
]:
def
func
(
op
):
def
func2
(
self
,
other
):
...
...
@@ -761,11 +660,11 @@ for op in ["__add__", "__radd__", "__iadd__",
if
other
.
_domain
!=
self
.
_domain
:
raise
ValueError
(
"domains are incompatible."
)
tval
=
getattr
(
self
.
val
,
op
)(
other
.
val
)
return
self
if
tval
is
self
.
val
else
Field
(
self
.
_domain
,
tval
)
return
Field
(
self
.
_domain
,
tval
)
if
np
.
isscalar
(
other
)
or
isinstance
(
other
,
dobj
.
data_object
):
tval
=
getattr
(
self
.
val
,
op
)(
other
)
return
self
if
tval
is
self
.
val
else
Field
(
self
.
_domain
,
tval
)
return
Field
(
self
.
_domain
,
tval
)
return
NotImplemented
return
func2
...
...
nifty5/minimization/conjugate_gradient.py
View file @
d8c42e70
...
...
@@ -66,7 +66,7 @@ class ConjugateGradient(Minimizer):
return
energy
,
status
r
=
energy
.
gradient
d
=
r
.
copy
()
if
preconditioner
is
None
else
preconditioner
(
r
)
d
=
r
if
preconditioner
is
None
else
preconditioner
(
r
)
previous_gamma
=
r
.
vdot
(
d
).
real
if
previous_gamma
==
0
:
...
...
nifty5/minimization/energy.py
View file @
d8c42e70
...
...
@@ -52,7 +52,7 @@ class Energy(NiftyMetaBase()):
def
__init__
(
self
,
position
):
super
(
Energy
,
self
).
__init__
()
self
.
_position
=
position
.
lock
()
self
.
_position
=
position
def
at
(
self
,
position
):
""" Returns a new Energy object, initialized at `position`.
...
...
nifty5/minimization/energy_sum.py
View file @
d8c42e70
...
...
@@ -63,7 +63,7 @@ class EnergySum(Energy):
@
memo
def
gradient
(
self
):
return
my_lincomb
(
map
(
lambda
v
:
v
.
gradient
,
self
.
_energies
),
self
.
_factors
)
.
lock
()
self
.
_factors
)
@
property
@
memo
...
...
nifty5/minimization/line_energy.py
View file @
d8c42e70
...
...
@@ -48,7 +48,7 @@ class LineEnergy(object):
def
__init__
(
self
,
line_position
,
energy
,
line_direction
,
offset
=
0.
):
super
(
LineEnergy
,
self
).
__init__
()
self
.
_line_position
=
float
(
line_position
)
self
.
_line_direction
=
line_direction
.
lock
()
self
.
_line_direction
=
line_direction
if
self
.
_line_position
==
float
(
offset
):
self
.
_energy
=
energy
...
...
nifty5/minimization/quadratic_energy.py
View file @
d8c42e70
...
...
@@ -35,7 +35,6 @@ class QuadraticEnergy(Energy):
else
:
Ax
=
self
.
_A
(
self
.
position
)
self
.
_grad
=
Ax
if
b
is
None
else
Ax
-
b
self
.
_grad
.
lock
()
self
.
_value
=
0.5
*
self
.
position
.
vdot
(
Ax
)
if
b
is
not
None
:
self
.
_value
-=
b
.
vdot
(
self
.
position
)
...
...
nifty5/minimization/scipy_minimizer.py
View file @
d8c42e70
...
...
@@ -33,7 +33,7 @@ def _toFlatNdarray(fld):
def
_toField
(
arr
,
dom
):
return
Field
.
from_global_data
(
dom
,
arr
.
reshape
(
dom
.
shape
))
return
Field
.
from_global_data
(
dom
,
arr
.
reshape
(
dom
.
shape
)
.
copy
()
)
class
_MinHelper
(
object
):
...
...
@@ -44,7 +44,7 @@ class _MinHelper(object):
def
_update
(
self
,
x
):
pos
=
_toField
(
x
,
self
.
_domain
)
if
(
pos
!=
self
.
_energy
.
position
).
any
():
self
.
_energy
=
self
.
_energy
.
at
(
pos
.
locked_copy
()
)
self
.
_energy
=
self
.
_energy
.
at
(
pos
)
def
fun
(
self
,
x
):
self
.
_update
(
x
)
...
...
nifty5/minimization/vl_bfgs.py
View file @
d8c42e70
...
...
@@ -109,8 +109,8 @@ class _InformationStore(object):
self
.
max_history_length
=
max_history_length
self
.
s
=
[
None
]
*
max_history_length
self
.
y
=
[
None
]
*
max_history_length
self
.
last_x
=
x0
.
copy
()
self
.
last_gradient
=
gradient
.
copy
()
self
.
last_x
=
x0
self
.
last_gradient
=
gradient
self
.
k
=
0
mmax
=
max_history_length
...
...
@@ -233,7 +233,7 @@ class _InformationStore(object):
self
.
s
[
self
.
k
%
mmax
]
=
x
-
self
.
last_x
self
.
y
[
self
.
k
%
mmax
]
=
gradient
-
self
.
last_gradient
self
.
last_x
=
x
.
copy
()
self
.
last_gradient
=
gradient
.
copy
()
self
.
last_x
=
x
self
.
last_gradient
=
gradient
self
.
k
+=
1
nifty5/multi/multi_field.py
View file @
d8c42e70
...
...
@@ -69,18 +69,6 @@ class MultiField(object):
dtype
[
key
],
**
kwargs
)
for
key
in
sorted
(
domain
.
keys
())})
def
fill
(
self
,
fill_value
):
"""Fill `self` uniformly with `fill_value`
Parameters
----------
fill_value: float or complex or int
The value to fill the field with.
"""
for
val
in
self
.
_val
.
values
():
val
.
fill
(
fill_value
)
return
self
def
_check_domain
(
self
,
other
):
if
other
.
_domain
!=
self
.
_domain
:
raise
ValueError
(
"domains are incompatible."
)
...
...
@@ -92,27 +80,6 @@ class MultiField(object):
result
+=
sub_field
.
vdot
(
x
[
key
])
return
result
def
lock
(
self
):
for
v
in
self
.
values
():
v
.
lock
()
return
self
@
property
def
locked
(
self
):
return
all
(
v
.
locked
for
v
in
self
.
values
())
def
copy
(
self
):
return
MultiField
({
key
:
val
.
copy
()
for
key
,
val
in
self
.
items
()})
def
locked_copy
(
self
):
if
self
.
locked
:
return
self
return
MultiField
({
key
:
val
.
locked_copy
()
for
key
,
val
in
self
.
items
()})
def
empty_copy
(
self
):
return
MultiField
({
key
:
val
.
empty_copy
()
for
key
,
val
in
self
.
items
()})
@
staticmethod
def
build_dtype
(
dtype
,
domain
):
if
isinstance
(
dtype
,
dict
):
...
...
@@ -121,12 +88,6 @@ class MultiField(object):
dtype
=
np
.
float64
return
{
key
:
dtype
for
key
in
domain
.
keys
()}
@
staticmethod
def
empty
(
domain
,
dtype
=
None
):
dtype
=
MultiField
.
build_dtype
(
dtype
,
domain
)
return
MultiField
({
key
:
Field
.
empty
(
dom
,
dtype
=
dtype
[
key
])
for
key
,
dom
in
domain
.
items
()})
@
staticmethod
def
full
(
domain
,
val
):
return
MultiField
({
key
:
Field
.
full
(
dom
,
val
)
...
...
@@ -241,9 +202,9 @@ for op in ["__add__", "__radd__",
result_val
[
key
]
=
getattr
(
self
[
key
],
op
)(
other
[
key
])
if
op
in
(
"__add__"
,
"__radd__"
):
for
key
in
only_self_keys
:
result_val
[
key
]
=
self
[
key
]
.
copy
()
result_val
[
key
]
=
self
[
key
]
for
key
in
only_other_keys
:
result_val
[
key
]
=
other
[
key
]
.
copy
()
result_val
[
key
]
=
other
[
key
]
elif
op
in
(
"__mul__"
,
"__rmul__"
):
pass
else
:
...
...
nifty5/operators/diagonal_operator.py
View file @
d8c42e70
...
...
@@ -185,7 +185,7 @@ class DiagonalOperator(EndomorphicOperator):
res
=
Field
.
from_random
(
random_type
=
"normal"
,
domain
=
self
.
_domain
,
dtype
=
dtype
)
if
from_inverse
:
res
.
local_data
[()]
/=
np
.
sqrt
(
self
.
_ldiag
)
res
/=
np
.
sqrt
(
self
.
_ldiag
)
else
:
res
.
local_data
[()]
*=
np
.
sqrt
(
self
.
_ldiag
)
res
*=
np
.
sqrt
(
self
.
_ldiag
)
return
res
nifty5/operators/dof_distributor.py
View file @
d8c42e70
...
...
@@ -120,15 +120,14 @@ class DOFDistributor(LinearOperator):
return
res
def
_times
(
self
,
x
):
res
=
Field
.
empty
(
self
.
_target
,
dtype
=
x
.
dtype
)
if
dobj
.
distaxis
(
x
.
val
)
in
x
.
domain
.
axes
[
self
.
_space
]:
arr
=
x
.
to_global_data
()
else
:
arr
=
x
.
local_data
arr
=
arr
.
reshape
(
self
.
_hshape
)
oarr
=
arr
[(
slice
(
None
),
self
.
_dofdex
,
slice
(
None
))]
return
Field
.
from_local_data
(
self
.
_target
,
oarr
.
reshape
(
self
.
_target
.
local_shape
))
oarr
=
res
.
local_data
.
reshape
(
self
.
_pshape
)
oarr
[()]
=
arr
[(
slice
(
None
),
self
.
_dofdex
,
slice
(
None
))]
return
res
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
...
...
nifty5/operators/inversion_enabler.py
View file @
d8c42e70
...
...
@@ -23,6 +23,7 @@ from ..minimization.conjugate_gradient import ConjugateGradient
from
..minimization.iteration_controller
import
IterationController
from
..minimization.quadratic_energy
import
QuadraticEnergy
from
.endomorphic_operator
import
EndomorphicOperator
from
..sugar
import
full