Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Daniel Boeckenhoff
tfields
Commits
f17a7d97
Commit
f17a7d97
authored
May 08, 2020
by
dboe
Browse files
sorted test_core
parent
3c0ecff7
Changes
4
Hide whitespace changes
Inline
Side-by-side
test/test_core.py
View file @
f17a7d97
...
...
@@ -7,35 +7,70 @@ import tfields
ATOL
=
1e-8
class
Base_Check
(
object
):
class
AbstractNdarray_Check
(
object
):
def
demand_equal
(
self
,
other
):
raise
NotImplementedError
(
self
.
__class__
)
def
test_pickle
(
self
):
with
NamedTemporaryFile
(
suffix
=
'.pickle'
)
as
out_file
:
pickle
.
dump
(
self
.
_inst
,
out_file
)
out_file
.
flush
()
out_file
.
seek
(
0
)
reloaded
=
pickle
.
load
(
out_file
)
self
.
demand_equal
(
reloaded
)
def
test_save_npz
(
self
):
out_file
=
NamedTemporaryFile
(
suffix
=
'.npz'
)
self
.
_inst
.
save
(
out_file
.
name
)
_
=
out_file
.
seek
(
0
)
# this is only necessary in the test
load_inst
=
self
.
_inst
.
__class__
.
load
(
out_file
.
name
)
# allow_pickle=True) ?
self
.
demand_equal
(
load_inst
)
def
tearDown
(
self
):
del
self
.
_inst
class
Tensors_Check
(
AbstractNdarray_Check
):
"""
Testing derivatives of Tensors
"""
_inst
=
None
def
demand_equal
(
self
,
other
,
atol
=
False
,
transformed
=
False
):
if
atol
:
self
.
assertTrue
(
self
.
_inst
.
equal
(
other
,
atol
=
ATOL
))
else
:
self
.
assertTrue
(
self
.
_inst
.
equal
(
other
))
if
not
transformed
:
self
.
assertEqual
(
self
.
_inst
.
coord_sys
,
other
.
coord_sys
)
self
.
assertEqual
(
self
.
_inst
.
name
,
other
.
name
)
def
test_self_equality
(
self
):
# Test equality
self
.
assertTrue
(
self
.
_inst
.
equal
(
self
.
_inst
))
self
.
demand_equal
(
self
.
_inst
)
transformer
=
self
.
_inst
.
copy
()
transformer
.
transform
(
tfields
.
bases
.
CYLINDER
)
self
.
demand_equal
(
transformer
,
atol
=
True
,
transformed
=
True
)
def
test_cylinderTrafo
(
self
):
# Test coordinate transformations in circle
transformer
=
self
.
_inst
.
copy
()
transformer
.
transform
(
tfields
.
bases
.
CYLINDER
)
self
.
assertTrue
(
tfields
.
Tensors
(
self
.
_inst
).
equal
(
transformer
,
atol
=
ATOL
))
self
.
assertTrue
(
self
.
_inst
.
equal
(
transformer
,
atol
=
ATOL
))
if
len
(
self
.
_inst
)
>
0
:
self
.
assertFalse
(
np
.
array_equal
(
self
.
_inst
,
transformer
))
transformer
.
transform
(
tfields
.
bases
.
CARTESIAN
)
self
.
assertTrue
(
self
.
_inst
.
equal
(
transformer
,
atol
=
ATOL
)
)
self
.
demand_
equal
(
transformer
,
atol
=
True
,
transformed
=
True
)
def
test_spericalTrafo
(
self
):
# Test coordinate transformations in circle
transformer
=
self
.
_inst
.
copy
()
transformer
.
transform
(
tfields
.
bases
.
SPHERICAL
)
transformer
.
transform
(
tfields
.
bases
.
CARTESIAN
)
self
.
assertTrue
(
self
.
_inst
.
equal
(
transformer
,
atol
=
ATOL
)
)
self
.
demand_
equal
(
transformer
,
atol
=
True
,
transformed
=
True
)
def
test_basic_merge
(
self
):
# create 3 copies with different coord_sys
...
...
@@ -65,22 +100,8 @@ class Base_Check(object):
atol
=
ATOL
)
self
.
assertTrue
(
value
)
def
test_pickle
(
self
):
with
NamedTemporaryFile
(
suffix
=
'.pickle'
)
as
out_file
:
pickle
.
dump
(
self
.
_inst
,
out_file
)
out_file
.
flush
()
out_file
.
seek
(
0
)
reloaded
=
pickle
.
load
(
out_file
)
self
.
assertTrue
(
self
.
_inst
.
equal
(
reloaded
))
def
tearDown
(
self
):
del
self
.
_inst
class
Tensor_Fields_Check
(
object
):
class
TensorFields_Check
(
Tensors_Check
):
def
test_fields
(
self
):
# field is of type list
self
.
assertTrue
(
isinstance
(
self
.
_inst
.
fields
,
list
))
...
...
@@ -92,32 +113,34 @@ class Tensor_Fields_Check(object):
self
.
assertFalse
(
field
is
target_field
)
class
TensorMaps_Check
(
TensorFields_Check
):
def
test_maps
(
self
):
self
.
assertIsNotNone
(
self
.
_inst
.
maps
)
"""
EMPTY TESTS
"""
class
Tensors_Empty_Test
(
Base
_Check
,
unittest
.
TestCase
):
class
Tensors_Empty_Test
(
Tensors
_Check
,
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
_inst
=
tfields
.
Tensors
([],
dim
=
3
)
class
TensorFields_Empty_Test
(
Tensor
s_Empty_Test
,
Tensor_Fields_Check
):
class
TensorFields_Empty_Test
(
Tensor
Fields_Check
,
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
_fields
=
[]
self
.
_inst
=
tfields
.
TensorFields
([],
dim
=
3
)
class
TensorMaps_Empty_Test
(
Tensor
Fields_Empty_Test
):
class
TensorMaps_Empty_Test
(
Tensor
Maps_Check
,
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
_fields
=
[]
self
.
_inst
=
tfields
.
TensorMaps
([],
dim
=
3
)
self
.
_maps
=
[]
self
.
_maps_fields
=
[]
def
test_maps
(
self
):
self
.
assertIsNotNone
(
self
.
_inst
.
maps
)
class
TensorFields_Copy_Test
(
TensorFields_Empty_Test
):
def
setUp
(
self
):
...
...
@@ -153,5 +176,38 @@ class TensorMaps_Copy_Test(TensorMaps_Empty_Test):
maps
=
self
.
_maps
)
class
Container_Check
(
AbstractNdarray_Check
):
def
demand_equal
(
self
,
other
):
raise
NotImplementedError
(
self
.
__class__
)
def
test_item
(
self
):
if
len
(
self
.
_inst
.
items
)
>
0
:
self
.
assertEqual
(
len
(
self
.
_inst
),
len
(
self
.
_inst
))
self
.
assertEqual
(
type
(
self
.
_inst
),
type
(
self
.
_inst
))
class
Container_Test
(
Container_Check
,
unittest
.
TestCase
):
def
setUp
(
self
):
sphere
=
tfields
.
Mesh3D
.
grid
(
(
1
,
1
,
1
),
(
-
np
.
pi
,
np
.
pi
,
3
),
(
-
np
.
pi
/
2
,
np
.
pi
/
2
,
3
),
coord_sys
=
'spherical'
)
sphere2
=
sphere
.
copy
()
*
3
self
.
_inst
=
tfields
.
Container
([
sphere
,
sphere2
])
class
ContainerFolded_Test
(
Container_Check
,
unittest
.
TestCase
):
def
setUp
(
self
):
sphere
=
tfields
.
Mesh3D
.
grid
(
(
1
,
1
,
1
),
(
-
np
.
pi
,
np
.
pi
,
3
),
(
-
np
.
pi
/
2
,
np
.
pi
/
2
,
3
),
coord_sys
=
'spherical'
)
sphere2
=
sphere
.
copy
()
*
3
self
.
_container
=
tfields
.
Container
([
sphere
,
sphere2
])
self
.
_inst
=
tfields
.
Container
(
self
.
_container
)
if
__name__
==
'__main__'
:
unittest
.
main
()
test/test_mesh3D.py
View file @
f17a7d97
...
...
@@ -4,13 +4,13 @@ import unittest
import
sympy
# NOQA: F401
import
os
import
sys
from
.test_core
import
Base
_Check
from
.test_core
import
Tensors
_Check
THIS_DIR
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
os
.
path
.
join
(
os
.
getcwd
(),
os
.
path
.
expanduser
(
__file__
))))
sys
.
path
.
append
(
os
.
path
.
normpath
(
os
.
path
.
join
(
THIS_DIR
)))
class
Sphere_Test
(
Base
_Check
,
unittest
.
TestCase
):
class
Sphere_Test
(
Tensors
_Check
,
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
_inst
=
tfields
.
Mesh3D
.
grid
(
(
1
,
1
,
1
),
...
...
test/test_templating.py
View file @
f17a7d97
...
...
@@ -12,48 +12,53 @@ class Base_Check(object):
self
.
assertTrue
(
len
(
templates
),
len
(
self
.
_instances
))
for
template
,
inst
in
zip
(
templates
,
self
.
_instances
):
merged_cut
=
merged
.
cut
(
template
)
self
.
assertEqual
(
len
(
inst
.
maps
),
len
(
merged_cut
.
maps
))
self
.
assertEqual
(
len
(
merged_cut
.
maps
),
len
(
template
.
maps
))
for
i
,
mp
in
enumerate
(
inst
.
maps
):
self
.
assertEqual
(
len
(
mp
),
len
(
merged_cut
.
maps
[
i
]))
self
.
assertEqual
(
tfields
.
core
.
dim
(
mp
),
tfields
.
core
.
dim
(
merged_cut
.
maps
[
i
]))
self
.
assertEqual
(
tfields
.
core
.
dim
(
template
.
maps
[
i
]),
tfields
.
core
.
dim
(
merged_cut
.
maps
[
i
]))
self
.
assertTrue
(
tfields
.
TensorFields
(
inst
).
equal
(
tfields
.
TensorFields
(
merged_cut
)))
self
.
assertTrue
(
inst
.
equal
(
merged_cut
))
# class Tensors_Empty_Test(Base_Check, unittest.TestCase):
# def setUp(self):
# self._instances = [tfields.Tensors([], dim=3) for i in range(3)]
#
#
# class TensorFields_Empty_Test(Base_Check, unittest.TestCase):
# def setUp(self):
# self._fields = []
# self._instances = [tfields.TensorFields([], dim=3) for i in range(3)]
# class TensorMaps_Empty_Test(Base_Check, unittest.TestCase):
# def setUp(self):
# self._instances = [tfields.TensorMaps([], dim=3) for i in range(3)]
#class TensorFields_Test(TensorFields_Empty_Test):
# def setUp(self):
# base = [(-5, 5, 7)] * 3
# self._fields = [tfields.Tensors.grid(*base, coord_sys='cylinder'),
# tfields.Tensors(range(7**3))]
# tensors = tfields.Tensors.grid(*base)
# self._instances = [tfields.TensorFields(tensors, *self._fields)
# for i in range(3)]
# class TensorMaps_Test(TensorMaps_Empty_Test):
self
.
_check_maps
(
inst
,
template
,
merged_cut
)
def
_check_maps
(
self
,
inst
,
template
,
merged_cut
):
pass
class
Tensors_Empty_Test
(
Base_Check
,
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
_instances
=
[
tfields
.
Tensors
([],
dim
=
3
)
for
i
in
range
(
3
)]
class
TensorFields_Empty_Test
(
Base_Check
,
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
_fields
=
[]
self
.
_instances
=
[
tfields
.
TensorFields
([],
dim
=
3
)
for
i
in
range
(
3
)]
class
TensorMaps_Empty_Test
(
Base_Check
,
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
_instances
=
[
tfields
.
TensorMaps
([],
dim
=
3
)
for
i
in
range
(
3
)]
def
_check_maps
(
self
,
inst
,
template
,
merged_cut
):
self
.
assertEqual
(
len
(
inst
.
maps
),
len
(
merged_cut
.
maps
))
self
.
assertEqual
(
len
(
merged_cut
.
maps
),
len
(
template
.
maps
))
for
i
,
mp
in
enumerate
(
inst
.
maps
):
self
.
assertEqual
(
len
(
mp
),
len
(
merged_cut
.
maps
[
i
]))
self
.
assertEqual
(
tfields
.
core
.
dim
(
mp
),
tfields
.
core
.
dim
(
merged_cut
.
maps
[
i
]))
self
.
assertEqual
(
tfields
.
core
.
dim
(
template
.
maps
[
i
]),
tfields
.
core
.
dim
(
merged_cut
.
maps
[
i
]))
self
.
assertTrue
(
tfields
.
TensorFields
(
inst
).
equal
(
tfields
.
TensorFields
(
merged_cut
)))
self
.
assertTrue
(
inst
.
equal
(
merged_cut
))
class
TensorFields_Test
(
TensorFields_Empty_Test
):
def
setUp
(
self
):
base
=
[(
-
5
,
5
,
7
)]
*
3
self
.
_fields
=
[
tfields
.
Tensors
.
grid
(
*
base
,
coord_sys
=
'cylinder'
),
tfields
.
Tensors
(
range
(
7
**
3
))]
tensors
=
tfields
.
Tensors
.
grid
(
*
base
)
self
.
_instances
=
[
tfields
.
TensorFields
(
tensors
,
*
self
.
_fields
)
for
i
in
range
(
3
)]
class
TensorMaps_Test
(
TensorMaps_Empty_Test
):
def
setUp
(
self
):
base
=
[(
-
1
,
1
,
3
)]
*
3
tensors
=
tfields
.
Tensors
.
grid
(
*
base
)
...
...
tfields/core.py
View file @
f17a7d97
...
...
@@ -220,7 +220,9 @@ class AbstractNdarray(np.ndarray):
@
classmethod
@
contextmanager
def
_bypass_setters
(
cls
,
*
slots
,
empty_means_all
=
True
):
def
_bypass_setters
(
cls
,
*
slots
,
empty_means_all
=
True
,
demand_existence
=
False
):
"""
Temporarily remove the setter in __slot_setters__ corresponding to slot
position in __slot__. You should know what you do, when using this.
...
...
@@ -229,13 +231,22 @@ class AbstractNdarray(np.ndarray):
*slots (str): attribute names in __slots__
empty_means_all (bool): defines behaviour when slots is empty.
When True: if slots is empty mute all slots in __slots__
demand_existence (bool): if false do not check the existence of the
slot in __slots__ - do nothing for that slot. Handle with care!
"""
if
not
slots
and
empty_means_all
:
slots
=
cls
.
__slots__
slot_indices
=
[]
setters
=
[]
for
slot
in
slots
:
slot_index
=
cls
.
__slots__
.
index
(
slot
)
slot_index
=
cls
.
__slots__
.
index
(
slot
)
\
if
slot
in
cls
.
__slots__
else
None
if
slot_index
is
None
:
# slot not in cls.__slots__.
if
demand_existence
:
raise
ValueError
(
"Slot {slot} not existing"
.
format
(
**
locals
()))
continue
if
len
(
cls
.
__slot_setters__
)
<
slot_index
+
1
:
# no setter to be found
continue
...
...
@@ -466,7 +477,7 @@ class AbstractNdarray(np.ndarray):
bulk_type
=
getattr
(
tfields
,
bulk_type
)
list_dict
[
key
].
append
(
bulk_type
.
_from_dict
(
**
sub_dict
[
index
]))
with
cls
.
_bypass_setters
(
'fields'
):
with
cls
.
_bypass_setters
(
'fields'
,
demand_existence
=
False
):
'''
Build the normal way
'''
...
...
@@ -1722,11 +1733,12 @@ class TensorFields(Tensors):
index
=
index
[
0
]
if
item
.
fields
:
# circumvent the setter here.
with
self
.
_bypass_setters
(
'fields'
):
with
self
.
_bypass_setters
(
'fields'
,
demand_existence
=
False
):
item
.
fields
=
[
field
.
__getitem__
(
index
)
for
field
in
item
.
fields
]
except
IndexError
as
err
:
except
IndexError
as
err
:
# noqa: F841
warnings
.
warn
(
"Index error occured for field.__getitem__. Error "
"message: {err}"
.
format
(
**
locals
())
...
...
@@ -1953,10 +1965,9 @@ class Maps(Container):
A Maps object is a container for TensorFields sorted by dimension.
"""
def
__new__
(
cls
,
maps
,
**
kwargs
):
if
not
issubclass
(
type
(
maps
),
Maps
):
dims
=
[
dim
(
obj
)
for
obj
in
maps
]
dims
,
maps
=
tfields
.
lib
.
util
.
multi_sort
(
dims
,
maps
)
kwargs
[
'labels'
]
=
dims
if
issubclass
(
type
(
maps
),
Maps
):
kwargs
[
'labels'
]
=
maps
.
labels
maps
=
maps
.
items
maps_cp
=
[]
for
mp
in
maps
:
...
...
@@ -1969,8 +1980,16 @@ class Maps(Container):
maps
=
maps_cp
obj
=
super
().
__new__
(
cls
,
maps
,
**
kwargs
)
# obj._update()
return
obj
def
_update
(
self
):
maps
=
self
.
items
dims
=
[
dim
(
obj
)
for
obj
in
maps
]
dims
,
maps
=
tfields
.
lib
.
util
.
multi_sort
(
dims
,
maps
)
self
.
items
=
maps
self
.
labels
=
dims
@
property
def
dims
(
self
):
return
self
.
labels
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment