Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
f0ec3ead
Commit
f0ec3ead
authored
Jun 25, 2016
by
theos
Browse files
Fixed domain, codomain handling in nifty.field.
parent
09f8ee21
Changes
1
Hide whitespace changes
Inline
Side-by-side
nifty_field.py
View file @
f0ec3ead
...
...
@@ -187,7 +187,7 @@ class field(object):
field_type
,
datamodel
,
**
kwargs
):
# check domain
self
.
domain
=
self
.
_parse_domain
(
domain
=
domain
)
self
.
domain_axes
_list
=
self
.
_get_axes_
list
(
self
.
domain
)
self
.
domain_axes
=
self
.
_get_axes_
tuple
(
self
.
domain
)
# check codomain
if
codomain
is
None
:
...
...
@@ -196,7 +196,7 @@ class field(object):
self
.
codomain
=
self
.
_parse_codomain
(
codomain
,
self
.
domain
)
self
.
field_type
=
self
.
_parse_field_type
(
field_type
)
self
.
field_type_axes
_list
=
self
.
_get_axes_
list
(
self
.
field_type
)
self
.
field_type_axes
=
self
.
_get_axes_
tuple
(
self
.
field_type
)
if
dtype
is
None
:
dtype
=
self
.
_infer_dtype
(
domain
=
self
.
domain
,
...
...
@@ -237,7 +237,7 @@ class field(object):
dtype
=
reduce
(
lambda
x
,
y
:
np
.
result_type
(
x
,
y
),
dtype_tuple
)
return
dtype
def
_get_axes_
list
(
self
,
things_with_shape
):
def
_get_axes_
tuple
(
self
,
things_with_shape
):
i
=
0
axes_list
=
[]
for
thing
in
things_with_shape
:
...
...
@@ -246,7 +246,7 @@ class field(object):
l
+=
[
i
]
i
+=
1
axes_list
+=
[
tuple
(
l
)]
return
axes_list
return
tuple
(
axes_list
)
def
_parse_comm
(
self
,
comm
):
# check if comm is a string -> the name of comm is given
...
...
@@ -512,11 +512,11 @@ class field(object):
for
ind
,
sp
in
enumerate
(
self
.
domain
):
casted_x
=
sp
.
complement_cast
(
casted_x
,
axis
=
self
.
domain_axes
_list
[
ind
])
axis
=
self
.
domain_axes
[
ind
])
for
ind
,
ft
in
enumerate
(
self
.
field_type
):
casted_x
=
ft
.
complement_cast
(
casted_x
,
axis
=
self
.
field_type_axes
_list
[
ind
])
axis
=
self
.
field_type_axes
[
ind
])
return
casted_x
...
...
@@ -647,7 +647,7 @@ class field(object):
for
ind
,
sp
in
enumerate
(
self
.
domain
):
new_val
=
sp
.
calc_weight
(
new_val
,
power
=
power
,
axes
=
self
.
domain_axes
_list
[
ind
])
axes
=
self
.
domain_axes
[
ind
])
new_field
.
set_val
(
new_val
=
new_val
,
copy
=
False
)
return
new_field
...
...
@@ -718,15 +718,15 @@ class field(object):
dotted
=
x
.
conjugate
()
*
y
for
ind
in
range
(
-
1
,
-
len
(
self
.
field_type_axes
_list
)
-
1
,
-
1
):
for
ind
in
range
(
-
1
,
-
len
(
self
.
field_type_axes
)
-
1
,
-
1
):
dotted
=
self
.
field_type
[
ind
].
dot_contraction
(
dotted
,
axes
=
self
.
field_type_axes
_list
[
ind
])
axes
=
self
.
field_type_axes
[
ind
])
for
ind
in
range
(
-
1
,
-
len
(
self
.
domain_axes
_list
)
-
1
,
-
1
):
for
ind
in
range
(
-
1
,
-
len
(
self
.
domain_axes
)
-
1
,
-
1
):
dotted
=
self
.
domain
[
ind
].
dot_contraction
(
dotted
,
axes
=
self
.
domain_axes
_list
[
ind
])
axes
=
self
.
domain_axes
[
ind
])
return
dotted
def
vdot
(
self
,
*
args
,
**
kwargs
):
...
...
@@ -792,8 +792,7 @@ class field(object):
return
work_field
def
transform
(
self
,
new_domain
=
None
,
new_codomain
=
None
,
spaces
=
None
,
**
kwargs
):
def
transform
(
self
,
spaces
=
None
,
**
kwargs
):
"""
Computes the transform of the field using the appropriate conjugate
transformation.
...
...
@@ -818,37 +817,36 @@ class field(object):
Otherwise, nothing is returned.
"""
if
new_domain
is
None
:
new_domain
=
self
.
codomain
# try to recycle the old domain
if
new_codomain
is
None
:
try
:
new_codomain
=
self
.
_parse_codomain
(
self
.
domain
,
new_domain
)
except
ValueError
:
new_codomain
=
self
.
_build_codomain
(
new_domain
)
else
:
new_codomain
=
self
.
_parse_codomain
(
new_codomain
,
new_domain
)
try
:
spaces_iterator
=
iter
(
spaces
)
iter
(
spaces
)
except
TypeError
:
if
spaces
is
None
:
spaces
_iterator
=
xrange
(
len
(
self
.
shape
))
spaces
=
xrange
(
len
(
self
.
domain_axes
))
else
:
spaces
_iterator
=
(
spaces
,
)
spaces
=
(
spaces
,
)
new_val
=
self
.
get_val
()
for
ind
in
spaces_iterator
:
new_domain
=
()
new_codomain
=
()
for
ind
in
xrange
(
len
(
self
.
domain
)):
if
ind
in
spaces
:
sp
=
self
.
domain
[
ind
]
cosp
=
self
.
codomain
[
ind
]
new_val
=
sp
.
calc_transform
(
new_val
,
codomain
=
new_domain
[
ind
]
,
axes
=
self
.
domain_axes
_list
[
ind
],
codomain
=
cosp
,
axes
=
self
.
domain_axes
[
ind
],
**
kwargs
)
new_domain
+=
(
self
.
codomain
[
ind
],)
new_codomain
+=
(
self
.
domain
[
ind
],)
else
:
new_domain
+=
(
self
.
domain
[
ind
],)
new_codomain
+=
(
self
.
codomain
[
ind
],)
return_field
=
self
.
copy_empty
(
domain
=
new_domain
,
codomain
=
new_codomain
)
return_field
.
set_val
(
new_val
=
new_val
,
copy
=
False
)
return
return_field
def
smooth
(
self
,
sigma
=
0
,
spaces
=
None
,
**
kwargs
):
...
...
@@ -882,7 +880,7 @@ class field(object):
spaces_iterator
=
iter
(
spaces
)
except
TypeError
:
if
spaces
is
None
:
spaces_iterator
=
xrange
(
len
(
self
.
shape
))
spaces_iterator
=
xrange
(
len
(
self
.
domain
))
else
:
spaces_iterator
=
(
spaces
,
)
...
...
@@ -891,7 +889,7 @@ class field(object):
sp
=
self
.
domain
[
ind
]
new_val
=
sp
.
calc_smooth
(
new_val
,
sigma
=
sigma
,
axes
=
self
.
domain_axes
_list
[
ind
],
axes
=
self
.
domain_axes
[
ind
],
**
kwargs
)
new_field
.
set_val
(
new_val
=
new_val
,
copy
=
False
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment