Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
0674df96
Commit
0674df96
authored
Sep 24, 2017
by
Martin Reinecke
Browse files
start experiments with data object interface
parent
6d7d5ede
Changes
5
Hide whitespace changes
Inline
Side-by-side
nifty/__init__.py
View file @
0674df96
...
...
@@ -47,3 +47,5 @@ from .sugar import *
from
.
import
plotting
from
.
import
library
from
.data_objects
import
numpy_do
as
dobj
nifty/data_objects/__init__.py
0 → 100644
View file @
0674df96
nifty/data_objects/numpy_do.py
0 → 100644
View file @
0674df96
# Data object module for NIFTy that uses simple numpy ndarrays.
import
numpy
as
np
from
numpy
import
ndarray
as
data_object
from
numpy
import
full
,
empty
,
sqrt
,
ones
,
zeros
,
vdot
,
abs
def
from_object
(
object
,
dtype
=
None
,
copy
=
True
):
return
np
.
array
(
object
,
dtype
=
dtype
,
copy
=
copy
)
nifty/field.py
View file @
0674df96
...
...
@@ -24,6 +24,7 @@ from . import nifty_utilities as utilities
from
.random
import
Random
from
.domain_tuple
import
DomainTuple
from
functools
import
reduce
from
.
import
dobj
class
Field
(
object
):
...
...
@@ -78,16 +79,16 @@ class Field(object):
if
isinstance
(
val
,
Field
):
if
self
.
domain
!=
val
.
domain
:
raise
ValueError
(
"Domain mismatch"
)
self
.
_val
=
np
.
array
(
val
.
val
,
dtype
=
dtype
,
copy
=
copy
)
self
.
_val
=
dobj
.
from_object
(
val
.
val
,
dtype
=
dtype
,
copy
=
copy
)
elif
(
np
.
isscalar
(
val
)):
self
.
_val
=
np
.
full
(
self
.
domain
.
shape
,
dtype
=
dtype
,
fill_value
=
val
)
elif
isinstance
(
val
,
np
.
ndarray
):
self
.
_val
=
dobj
.
full
(
self
.
domain
.
shape
,
dtype
=
dtype
,
fill_value
=
val
)
elif
isinstance
(
val
,
dobj
.
data_object
):
if
self
.
domain
.
shape
==
val
.
shape
:
self
.
_val
=
np
.
array
(
val
,
dtype
=
dtype
,
copy
=
copy
)
self
.
_val
=
dobj
.
from_object
(
val
,
dtype
=
dtype
,
copy
=
copy
)
else
:
raise
ValueError
(
"Shape mismatch"
)
elif
val
is
None
:
self
.
_val
=
np
.
empty
(
self
.
domain
.
shape
,
dtype
=
dtype
)
self
.
_val
=
dobj
.
empty
(
self
.
domain
.
shape
,
dtype
=
dtype
)
else
:
raise
TypeError
(
"unknown source type"
)
...
...
@@ -253,7 +254,7 @@ class Field(object):
"synthetization."
)
result_domain
[
i
]
=
self
.
domain
[
i
].
harmonic_partner
spec
=
np
.
sqrt
(
self
.
val
)
spec
=
dobj
.
sqrt
(
self
.
val
)
for
i
in
spaces
:
power_space
=
self
.
domain
[
i
]
local_blow_up
=
[
slice
(
None
)]
*
len
(
spec
.
shape
)
...
...
@@ -449,7 +450,7 @@ class Field(object):
The weighted field.
"""
new_field
=
Field
(
val
=
self
,
copy
=
not
inplace
)
new_field
=
self
if
inplace
else
self
.
copy
(
)
if
spaces
is
None
:
spaces
=
range
(
len
(
self
.
domain
))
...
...
@@ -462,7 +463,7 @@ class Field(object):
if
np
.
isscalar
(
wgt
):
fct
*=
wgt
else
:
new_shape
=
np
.
ones
(
len
(
self
.
shape
),
dtype
=
np
.
int
)
new_shape
=
dobj
.
ones
(
len
(
self
.
shape
),
dtype
=
np
.
int
)
new_shape
[
self
.
domain
.
axes
[
ind
][
0
]:
self
.
domain
.
axes
[
ind
][
-
1
]
+
1
]
=
wgt
.
shape
wgt
=
wgt
.
reshape
(
new_shape
)
...
...
@@ -504,7 +505,7 @@ class Field(object):
fct
=
tmp
if
spaces
is
None
:
return
fct
*
np
.
vdot
(
y
.
val
.
ravel
(),
x
.
val
.
ravel
())
return
fct
*
dobj
.
vdot
(
y
.
val
.
ravel
(),
x
.
val
.
ravel
())
else
:
# create a diagonal operator which is capable of taking care of the
# axes-matching
...
...
@@ -522,7 +523,7 @@ class Field(object):
The L2-norm of the field values.
"""
return
np
.
sqrt
(
np
.
abs
(
self
.
vdot
(
x
=
self
)))
return
dobj
.
sqrt
(
dobj
.
abs
(
self
.
vdot
(
x
=
self
)))
def
conjugate
(
self
):
""" Returns the complex conjugate of the field.
...
...
@@ -544,7 +545,7 @@ class Field(object):
return
Field
(
self
.
domain
,
-
self
.
val
,
self
.
dtype
)
def
__abs__
(
self
):
return
Field
(
self
.
domain
,
np
.
abs
(
self
.
val
),
self
.
dtype
)
return
Field
(
self
.
domain
,
dobj
.
abs
(
self
.
val
),
self
.
dtype
)
def
_contraction_helper
(
self
,
op
,
spaces
):
if
spaces
is
None
:
...
...
@@ -597,6 +598,13 @@ class Field(object):
def
std
(
self
,
spaces
=
None
):
return
self
.
_contraction_helper
(
'std'
,
spaces
)
def
copy_content_from
(
self
,
other
):
if
not
isinstance
(
other
,
Field
):
raise
TypeError
(
"argument must be a Field"
)
if
other
.
domain
!=
self
.
domain
:
raise
ValueError
(
"domains are incompatible."
)
self
.
val
[()]
=
other
.
val
# ---General binary methods---
def
_binary_helper
(
self
,
other
,
op
):
...
...
nifty/low_level_library.py
View file @
0674df96
...
...
@@ -92,9 +92,6 @@ else:
def
general_axpy
(
a
,
x
,
y
,
out
):
if
x
.
domain
!=
y
.
domain
or
x
.
domain
!=
out
.
domain
:
raise
ValueError
(
"Incompatible domains"
)
x
=
x
.
val
y
=
y
.
val
out
=
out
.
val
if
out
is
x
:
if
a
!=
1.
:
...
...
@@ -106,7 +103,7 @@ else:
else
:
out
+=
x
else
:
out
[()]
=
y
out
.
copy_content_from
(
y
)
if
a
!=
1.
:
out
+=
a
*
x
else
:
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment