Daniel Boeckenhoff / tfields · Commits · 716a1b8b

Commit 716a1b8b, authored May 09, 2020 by dboe

first iteration maps

parent d23dec96
Changes: 4
test/test_core.py @ 716a1b8b
...
...
@@ -7,7 +7,7 @@ import tfields
 ATOL = 1e-8


-class AbstractNdarray_Check(object):
+class Base_Check(object):
     def demand_equal(self, other):
         raise NotImplementedError(self.__class__)
...
...
@@ -21,6 +21,11 @@ class AbstractNdarray_Check(object):
         self.demand_equal(reloaded)

+    def test_copy(self):
+        copy = type(self._inst)(self._inst)
+        self.demand_equal(copy)
+        self.assertIsNot(self._inst, copy)
+
     def test_save_npz(self):
         out_file = NamedTemporaryFile(suffix='.npz')
         self._inst.save(out_file.name)
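The new test_copy relies on rebuilding an instance from itself yielding an equal but distinct object. A minimal sketch of that pattern with a plain numpy ndarray subclass (a hypothetical stand-in, not the tfields class):

    import numpy as np


    class Wrapped(np.ndarray):
        """Hypothetical stand-in for a tfields AbstractNdarray subclass."""

        def __new__(cls, data):
            # view casting: wrap the input without copying the buffer
            return np.asarray(data).view(cls)


    inst = Wrapped([1.0, 2.0, 3.0])
    copy = type(inst)(inst)              # same pattern as in test_copy
    assert np.array_equal(inst, copy)    # equal content ...
    assert copy is not inst              # ... but a distinct Python object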
...
...
@@ -39,6 +44,10 @@ class AbstractNdarray_Check(object):
         del self._inst


+class AbstractNdarray_Check(Base_Check):
+    pass
+
+
 class Tensors_Check(AbstractNdarray_Check):
     """
     Testing derivatives of Tensors
...
...
@@ -60,6 +69,7 @@ class Tensors_Check(AbstractNdarray_Check):
        transformer = self._inst.copy()
        transformer.transform(tfields.bases.CYLINDER)
        self.demand_equal(transformer, atol=True, transformed=True)
        self.assertIs(self._inst, np.asarray(self._inst))

    def test_cylinderTrafo(self):
        # Test coordinate transformations in circle
...
...
@@ -181,6 +191,14 @@ class TensorMaps_Copy_Test(TensorMaps_Empty_Test):
             maps=self._maps)


+class Maps_Test(Base_Check, unittest.TestCase):
+    def demand_equal(self, other):
+        self._inst.equal(other)
+
+    def setUp(self):
+        self._inst = tfields.Maps([[[1, 2, 3]]])
+
+
 class Container_Check(AbstractNdarray_Check):
     def demand_equal(self, other):
         raise NotImplementedError(self.__class__)
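For orientation, the fixture tfields.Maps([[[1, 2, 3]]]) passes a list containing one map with a single 3-index entry. A hedged stand-in using numpy and sortedcontainers (not tfields itself) showing how such a list ends up filed under its dimension:

    import numpy as np
    from sortedcontainers import SortedDict

    raw_maps = [[[1, 2, 3]]]                 # one map, one 3-index entry
    by_dim = SortedDict({np.asarray(mp).shape[1]: np.asarray(mp, dtype=int)
                         for mp in raw_maps})
    assert list(by_dim.keys()) == [3]        # filed under dimension 3
    assert by_dim[3].shape == (1, 3)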
...
...
@@ -202,24 +220,24 @@ class Container_Test(Container_Check, unittest.TestCase):
         self._inst = tfields.Container([sphere, sphere2])


-class ContainerNoList_Test(Container_Check, unittest.TestCase):
-    def demand_equal_tensors(self, one, other, atol=False, transformed=False):
-        if atol:
-            self.assertTrue(one.equal(other, atol=ATOL))
-        else:
-            self.assertTrue(one.equal(other))
-        if not transformed:
-            self.assertEqual(one.coord_sys, other.coord_sys)
-        self.assertEqual(one.name, other.name)
-
-    def demand_equal(self, other):
-        self.demand_equal_tensors(self._inst.items.items[0], other)
-        self.demand_equal_tensors(self._inst.items.items[1], other)
-
-    def setUp(self):
-        t = tfields.TensorFields([[1, 2, 3]])
-        self._inst = tfields.Container(t)
-        print(self._inst.items)
+# class ContainerNoList_Test(Container_Check, unittest.TestCase):
+#     def demand_equal_tensors(self, one, other, atol=False, transformed=False):
+#         if atol:
+#             self.assertTrue(one.equal(other, atol=ATOL))
+#         else:
+#             self.assertTrue(one.equal(other))
+#         if not transformed:
+#             self.assertEqual(one.coord_sys, other.coord_sys)
+#         self.assertEqual(one.name, other.name)
+#
+#     def demand_equal(self, other):
+#         self.demand_equal_tensors(self._inst.items.items[0], other)
+#         self.demand_equal_tensors(self._inst.items.items[1], other)
+#
+#     def setUp(self):
+#         t = tfields.TensorFields([[1, 2, 3]])
+#         self._inst = tfields.Container(t)
+#         print(self._inst.items)


 if __name__ == '__main__':
...
...
tfields/__about__.py @ 716a1b8b
...
...
@@ -31,6 +31,7 @@ __dependencies__ = [
     'pathlib2;python_version<"3.0"',
     'pathlib;python_version>="3.0"',
     'rna',
+    'sortedcontainers'
 ]
 __classifiers__ = [
     # find the full list of possible classifiers at https://pypi.org/classifiers/
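The added sortedcontainers dependency backs the Maps rewrite in core.py below: a SortedDict keeps its keys, here the map dimensions, in sorted order. Minimal usage sketch, independent of tfields:

    from sortedcontainers import SortedDict

    maps = SortedDict()
    maps[3] = "triangle faces"
    maps[2] = "edges"
    assert list(maps.keys()) == [2, 3]   # iteration always follows the dimensions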
...
...
tfields/__init__.py @ 716a1b8b
...
...
@@ -6,7 +6,7 @@ from . import lib
 from .lib import *

 # __all__ = ['core', 'points3D']
-from .core import Tensors, TensorFields, TensorMaps, Container
+from .core import Tensors, TensorFields, TensorMaps, Container, Maps
 from .points3D import Points3D

 from .mask import evalf
...
...
tfields/core.py @ 716a1b8b
...
...
@@ -12,16 +12,20 @@ Notes:
    <https://docs.scipy.org/doc/numpy-1.15.1/reference/generated/...
    ... numpy.lib.mixins.NDArrayOperatorsMixin.html>`_
"""
# builtin
import warnings
import pathlib
from six import string_types
from contextlib import contextmanager
from collections import Counter

# 3rd party
import numpy as np
import sympy
import scipy as sp
import sortedcontainers
import rna

import tfields.bases

np.seterr(all="warn", over="raise")
...
...
@@ -31,6 +35,7 @@ def rank(tensor):
    """
    Tensor rank
    """
    tensor = np.asarray(tensor)
    return len(tensor.shape) - 1
...
...
@@ -38,6 +43,7 @@ def dim(tensor):
    """
    Manifold dimension
    """
    tensor = np.asarray(tensor)
    if rank(tensor) == 0:
        return 1
    return tensor.shape[1]
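A quick numpy check of the conventions used by rank() and dim() as shown above: rank is the array depth minus one, dim is the length of the second axis.

    import numpy as np

    tensors = np.asarray([[1, 2, 3], [4, 5, 6]])   # two points in 3d
    assert len(tensors.shape) - 1 == 1             # rank(tensors) would be 1
    assert tensors.shape[1] == 3                   # dim(tensors) would be 3

    scalars = np.asarray([1.0, 2.0])               # rank 0 case: dim defined as 1
    assert len(scalars.shape) - 1 == 0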
...
...
@@ -266,8 +272,7 @@ class AbstractNdarray(np.ndarray):
         >>> import tfields
         >>> m = tfields.TensorMaps(
         ...     [[1,2,3], [3,3,3], [0,0,0], [5,6,7]],
-        ...     maps=[tfields.TensorFields([[0, 1, 2], [1, 2, 3]],
-        ...                                [1, 2])])
+        ...     maps=[[[0, 1, 2], [1, 2, 3]], [1, 2])])
         >>> mc = m.copy()
         >>> mc is m
         False
...
...
@@ -484,8 +489,6 @@ class AbstractNdarray(np.ndarray):
        '''
        bulk = kwargs.pop('bulk')
        bulk_type = kwargs.pop('bulk_type')
        print("-" * 100)
        print(bulk, bulk_type, kwargs)
        obj = cls.__new__(cls, bulk, **kwargs)
        '''
...
...
@@ -1953,8 +1956,8 @@ class Container(AbstractNdarray):
     def __getitem__(self, index):
         return self.items[index]

-    def __setitem__(self, index, value):
-        self.items[index] = value
+    def __setitem__(self, index, item):
+        self.items[index] = item

     def __iter__(self):
         return iter(self.items)
...
...
@@ -1963,39 +1966,71 @@ class Container(AbstractNdarray):
         return len(self.items)


-class Maps(Container):
+class Maps(sortedcontainers.SortedDict):
     """
     A Maps object is a container for TensorFields sorted by dimension.
     Indexing by dimension
     """
-    def __new__(cls, maps, **kwargs):
-        if issubclass(type(maps), Maps):
-            kwargs['labels'] = maps.labels
-            maps = maps.items
-        maps_cp = []
-        for mp in maps:
-            mp = TensorFields(mp, dtype=int)
-            if not mp.rank == 1:
-                raise ValueError("Incorrect map rank {mp.rank}"
-                                 .format(**locals()))
-            maps_cp.append(mp)
-        maps = maps_cp
-
-        obj = super().__new__(cls, maps, **kwargs)
-        # obj._update()
-        return obj
-
-    def _update(self):
-        maps = self.items
-        dims = [dim(obj) for obj in maps]
-        dims, maps = tfields.lib.util.multi_sort(dims, maps)
-        self.items = maps
-        self.labels = dims
+    def __init__(self, maps):
+        if issubclass(type(maps), list):
+            maps = {dim(mp): mp for mp in maps}
+        if not issubclass(type(maps), dict):
+            raise TypeError("Could not interprete input {}".format(maps))
+        maps = {d: self.to_map(maps[d]) for d in maps.keys()}
+        super().__init__(maps)
+
+    @staticmethod
+    def to_map(mp):
+        mp = TensorFields(mp, dtype=int)
+        if not mp.rank == 1:
+            raise ValueError("Incorrect map rank {mp.rank}"
+                             .format(**locals()))
+        return mp
+
+    def __setitem__(self, dim, mp):
+        mp = self.to_map(mp)
+        if not dim == mp.dim:
+            raise KeyError("Incorrect map dimension {mp.dim} for index {dim}"
+                           .format(**locals()))
+        if dim == 0:
+            warnings.warn("Using map dimension {dim}".format(**locals()))
+        super().__setitem__(dim, mp)
+
+    def __getitem__(self, dim):
+        if dim == 0:
+            warnings.warn("Using map dimension {dim}".format(**locals()))
+        return super().__getitem__(dim)
+
+    # def __iter__(self):
+    #     warnings.warn(
+    #         "Deprecated: will bList like iteration"
+    #         .format(**locals())
+    #     )
+    #     return iter([self[k] for k in super().__iter__()])
+
+    def equal(self, other, **kwargs):
+        """
+        Test equality with other object.
+
+        Args:
+            **kwargs: passed to each item on equality check
+        """
+        if not self.keys() == other.keys():
+            return False
+        for dim in self.keys():
+            if not self[dim].equal(other[dim], **kwargs):
+                return False
+        return True

     @property
     def dims(self):
         return self.labels


 class TensorMaps(TensorFields):
...
...
@@ -2108,7 +2143,7 @@ class TensorMaps(TensorFields):
             if isinstance(index, tuple):
                 index = index[0]
             if len(item.maps) == 0:
-                item.maps = [mp.copy() for mp in item.maps]
+                item.maps = Maps(item.maps)
             indices = np.array(range(len(self)))
             keep_indices = indices[index]
             if isinstance(keep_indices, (int, np.int64, np.int32)):
...
...
@@ -2116,12 +2151,12 @@ class TensorMaps(TensorFields):
                 delete_indices = set(indices).difference(set(keep_indices))
                 # correct all maps that contain deleted indices
-                for mp_idx in range(len(self.maps)):
+                for map_dim in self.maps.keys():
                     # build mask, where the map should be deleted
                     map_delete_mask = np.full(
-                        (len(self.maps[mp_idx]),), False, dtype=bool
+                        (len(self.maps[map_dim]),), False, dtype=bool
                     )
-                    for i, mp in enumerate(self.maps[mp_idx]):
+                    for i, mp in enumerate(self.maps[map_dim]):
                         for index in mp:
                             if index in delete_indices:
                                 map_delete_mask[i] = True
...
...
@@ -2130,13 +2165,13 @@ class TensorMaps(TensorFields):
                     # build the correction counters
                     move_up_counter = np.zeros(
-                        self.maps[mp_idx].shape, dtype=int
+                        self.maps[map_dim].shape, dtype=int
                     )
                     for p in delete_indices:
-                        move_up_counter[self.maps[mp_idx] > p] -= 1
+                        move_up_counter[self.maps[map_dim] > p] -= 1

-                    item.maps[mp_idx] = (self.maps[mp_idx] + move_up_counter
+                    item.maps[map_dim] = (self.maps[map_dim] + move_up_counter
                     )[map_mask]
             except IndexError as err:
                 warnings.warn(
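The move_up_counter loop above implements index compaction: after vertices are deleted, every map entry pointing above a deleted index must slide down by the number of deleted indices below it. A small numpy sketch with hypothetical data:

    import numpy as np

    faces = np.array([[0, 2, 3], [2, 3, 4]])   # no face references vertex 1
    delete_indices = {1}                       # vertex 1 is being removed

    move_up_counter = np.zeros(faces.shape, dtype=int)
    for p in delete_indices:
        move_up_counter[faces > p] -= 1        # everything above 1 slides down

    reindexed = faces + move_up_counter
    assert np.array_equal(reindexed, [[0, 1, 2], [1, 2, 3]])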
...
...
@@ -2166,23 +2201,18 @@ class TensorMaps(TensorFields):
         else:
             inst, templates = (return_value, None)

-        # save map_index in order to be able to recover the exact same
-        # order in the template later
-        dim_maps_dict = {}  # {dim: {obj_index(i): (map_index(j), maps_field)}}
+        dim_maps_dict = {}  # {dim: {i: mp}
         for i, obj in enumerate(objects):
-            for j, map_field in enumerate(obj.maps):
-                map_field = map_field + cum_tensor_lengths[i]
-                if map_field.dim not in dim_maps_dict:
-                    dim_maps_dict[map_field.dim] = {}
-                dim_maps_dict[map_field.dim][i] = (j, map_field)
+            for dimension, mp in obj.maps:
+                mp = mp + cum_tensor_lengths[i]
+                dim_maps_dict[mp.dim][i] = mp

         maps = []
         template_maps_list = [[] for i in range(len(objects))]
-        for dimension in sorted(dim_maps_dict.keys()):
+        for dimension in sorted(dim_maps_dict):
             # sort by object index
             obj_indices = sorted(dim_maps_dict[dimension].keys())
-            map_indices = [dim_maps_dict[dimension][i][0] for i in obj_indices]
-            dim_maps = [dim_maps_dict[dimension][i][1] for i in obj_indices]
+            dim_maps = [dim_maps_dict[dimension][i] for i in obj_indices]

             return_value = TensorFields.merged(
                 *dim_maps,
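The cum_tensor_lengths offset above is the usual merge trick: when tensor sets are concatenated, each object's map indices are shifted by the number of tensors that precede it. Sketch with hypothetical data:

    import numpy as np

    # two vertex sets of length 3 and 2, each carrying one map
    maps_per_set = [np.array([[0, 1, 2]]), np.array([[0, 1]])]
    cum_tensor_lengths = [0, 3]                # 0, then len(first vertex set)

    shifted = [mp + cum_tensor_lengths[i] for i, mp in enumerate(maps_per_set)]
    assert np.array_equal(shifted[1], [[3, 4]])   # now indexes the merged set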
...
...
@@ -2191,11 +2221,7 @@ class TensorMaps(TensorFields):
             if return_templates:
                 mp, dimension_map_templates = return_value
                 for i in obj_indices:
-                    j = map_indices[i]
-                    template_maps_list[i].append(
-                        (j, dimension_map_templates[i])
-                    )
+                    template_maps_list[i].append(dimension_map_templates[i])
             else:
                 mp = return_value
             maps.append(mp)
...
...
@@ -2205,8 +2231,7 @@ class TensorMaps(TensorFields):
             for i, template_maps in enumerate(template_maps_list):
                 templates[i] = tfields.TensorMaps(
                     templates[i],
-                    maps=[val[1] for val in sorted(template_maps,
-                                                   key=lambda x: x[0])])
+                    maps=template_maps)
             return inst, templates
         else:
             return inst
...
...
@@ -2242,11 +2267,7 @@ class TensorMaps(TensorFields):
         with other.tmp_transform(self.coord_sys):
             mask = super(TensorMaps, self).equal(other, **kwargs)
             if issubclass(type(other), TensorMaps):
-                if len(self.maps) != len(other.maps):
-                    mask &= False
-                else:
-                    for i, mp in enumerate(self.maps):
-                        mask &= mp.equal(other.maps[i], **kwargs)
+                mask &= self.maps.equal(other.maps, **kwargs)
             return mask

     def stale(self):
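Equality now delegates to Maps.equal, which compares the entries per dimension key. A plain-dict sketch of that comparison, assuming the values compare element-wise:

    import numpy as np

    a = {3: np.array([[0, 1, 2]]), 2: np.array([[0, 1]])}
    b = {3: np.array([[0, 1, 2]]), 2: np.array([[0, 1]])}

    equal = a.keys() == b.keys() and all(np.array_equal(a[d], b[d]) for d in a)
    assert equal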
...
...
@@ -2256,14 +2277,16 @@ class TensorMaps(TensorFields):
         Examples:
             >>> import tfields
-            >>> vectors = tfields.Tensors([[0, 0, 0], [0, 0, 1], [0, -1, 0], [4, 4, 4]])
-            >>> tm = tfields.TensorMaps(vectors, maps=[[[0, 1, 2], [0, 1, 2]],
-            ...                                        [[1, 1], [2, 2]]])
+            >>> vectors = tfields.Tensors(
+            ...     [[0, 0, 0], [0, 0, 1], [0, -1, 0], [4, 4, 4]])
+            >>> tm = tfields.TensorMaps(
+            ...     vectors,
+            ...     maps=[[[0, 1, 2], [0, 1, 2]], [[1, 1], [2, 2]]])
             >>> assert np.array_equal(tm.stale(), [False, False, False, True])
         """
         staleMask = np.full(self.shape[0], False, dtype=bool)
-        used = set([ind for mp in self.maps for ind in mp.flatten()])
+        used = set([ind for mp in self.maps.items() for ind in mp.flatten()])
         for i in range(self.shape[0]):
             if i not in used:
                 staleMask[i] = True
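Stand-alone check of the stale() logic matching the doctest above: a vertex is stale when no map references it. Plain numpy, not the tfields call:

    import numpy as np

    n_vertices = 4
    maps = [np.array([[0, 1, 2], [0, 1, 2]]), np.array([[1, 1], [2, 2]])]

    used = {ind for mp in maps for ind in mp.flatten()}
    stale = np.array([i not in used for i in range(n_vertices)])
    assert np.array_equal(stale, [False, False, False, True])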
...
...
@@ -2314,12 +2337,12 @@ class TensorMaps(TensorFields):
             if duplicate_index != tensor_index:
                 stale_mask[tensor_index] = True
                 # redirect maps
-                for mp_idx in range(len(self.maps)):
-                    for f in range(len(self.maps[mp_idx])):
-                        mp = np.array(self.maps[mp_idx], dtype=int)
+                for map_dim in self.maps:
+                    for f in range(len(self.maps[map_dim])):  # face index
+                        mp = np.array(self.maps[map_dim], dtype=int)
                         if tensor_index in mp[f]:
                             index = tfields.index(mp[f], tensor_index)
-                            inst.maps[mp_idx][f][index] = duplicate_index
+                            inst.maps[map_dim][f][index] = duplicate_index
         return inst.removed(stale_mask)
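Sketch of the redirect step above: references to a duplicate vertex are pointed at its first occurrence so the stale copy can be removed afterwards. Plain numpy with hypothetical data; the diff does this per face via tfields.index:

    import numpy as np

    vertices = np.array([[0., 0., 0.], [1., 0., 0.], [0., 0., 0.]])
    faces = np.array([[0, 1, 2]])

    duplicate_index, tensor_index = 0, 2       # vertex 2 duplicates vertex 0
    assert np.array_equal(vertices[duplicate_index], vertices[tensor_index])

    faces[faces == tensor_index] = duplicate_index
    assert np.array_equal(faces, [[0, 1, 0]])  # vertex 2 is now stale and removable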
...
...
@@ -2354,14 +2377,14 @@ class TensorMaps(TensorFields):
         # delete_indices = np.arange(self.shape[0])[remove_condition]
         # face_keep_masks = self.to_maps_masks(~remove_condition)
-        # for mp_idx, face_keep_mask in enumerate(face_keep_masks):
-        #     move_up_counter = np.zeros(self.maps[mp_idx].shape, dtype=int)
+        # for map_dim, face_keep_mask in enumerate(face_keep_masks):
+        #     move_up_counter = np.zeros(self.maps[map_dim].shape, dtype=int)
         #     # correct map:
         #     for p in delete_indices:
-        #         move_up_counter[self.maps[mp_idx] > p] -= 1
-        #     inst.maps[mp_idx] = (self.maps[mp_idx] + move_up_counter)[face_keep_mask]
+        #         move_up_counter[self.maps[map_dim] > p] -= 1
+        #     inst.maps[map_dim] = (self.maps[map_dim] + move_up_counter)[face_keep_mask]
         # return inst
         return self[~remove_condition]
...
...
@@ -2418,11 +2441,11 @@ class TensorMaps(TensorFields):
         delete_indices = set(indices.flat).difference(set(keep_indices.flat))
         masks = []
-        for mp_idx in range(len(self.maps)):
+        for map_dim in self.maps:
             map_delete_mask = np.full(
-                (len(self.maps[mp_idx]),), False, dtype=bool
+                (len(self.maps[map_dim]),), False, dtype=bool
             )
-            for i, mp in enumerate(self.maps[mp_idx]):
+            for i, mp in enumerate(self.maps[map_dim]):
                 for index in mp:
                     if index in delete_indices:
                         map_delete_mask[i] = True
...
...
@@ -2434,8 +2457,8 @@ class TensorMaps(TensorFields):
         """
         Args:
             *map_descriptions (Tuple(int, List(List(int)))): tuples of
-                map_pos_idx (int): reference to map position
-                    used like: self.maps[map_pos_idx]
+                map_dim (int): reference to map position
+                    used like: self.maps[map_dim]
                 map_indices_list (List(List(int))): each int refers
                     to index in a map.
...
...
@@ -2446,30 +2469,30 @@ class TensorMaps(TensorFields):
         # raise ValueError(map_descriptions)
         parts = []
         for map_description in map_descriptions:
-            map_pos_idx, map_indices_list = map_description
+            map_dim, map_indices_list = map_description
             for map_indices in map_indices_list:
                 obj = self.copy()
                 map_indices = set(map_indices)  # for speed up
                 map_delete_mask = np.array(
                     [True if i not in map_indices else False
-                     for i in range(len(self.maps[map_pos_idx]))]
+                     for i in range(len(self.maps[map_dim]))]
                 )
-                obj.maps[map_pos_idx] = obj.maps[map_pos_idx][~map_delete_mask]
+                obj.maps[map_dim] = obj.maps[map_dim][~map_delete_mask]
                 obj = obj.cleaned(duplicates=False)
                 parts.append(obj)
         return parts

-    def disjoint_map(self, mp_idx):
+    def disjoint_map(self, map_dim):
         """
-        Find the disjoint sets of map = self.maps[mp_idx]
+        Find the disjoint sets of map = self.maps[map_dim]

         As an example, this method is interesting for splitting a mesh
         consisting of seperate parts

         Args:
-            mp_idx (int): reference to map position
-                used like: self.maps[mp_idx]
+            map_dim (int): reference to map position
+                used like: self.maps[map_dim]

         Returns:
             Tuple(int, List(List(int))): map description(tuple): see self.parts
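disjoint_map groups map entries that share vertices into index lists consumable by parts(). A tiny union-find sketch of that grouping (not the tfields.lib.sets implementation):

    def disjoint_group_indices(mp):
        parent = {}

        def find(x):
            parent.setdefault(x, x)
            while parent[x] != x:
                parent[x] = parent[parent[x]]   # path halving
                x = parent[x]
            return x

        def union(a, b):
            parent[find(a)] = find(b)

        for entry in mp:                        # link all vertices of one entry
            for a, b in zip(entry, entry[1:]):
                union(a, b)
        groups = {}
        for i, entry in enumerate(mp):          # bucket entries by their root
            groups.setdefault(find(entry[0]), []).append(i)
        return list(groups.values())


    edges = [[0, 1], [1, 2], [5, 6]]            # two separate chains
    assert disjoint_group_indices(edges) == [[0, 1], [2]]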
...
...
@@ -2490,10 +2513,10 @@ class TensorMaps(TensorFields):
         >>> assert ba.equal(b)
         """
-        maps_list = tfields.lib.sets.disjoint_group_indices(self.maps[mp_idx])
-        return (mp_idx, maps_list)
+        maps_list = tfields.lib.sets.disjoint_group_indices(self.maps[map_dim])
+        return (map_dim, maps_list)

-    def paths(self, mp_idx):
+    def paths(self, map_dim):
         """
         Find the minimal amount of graphs building the original graph with
         maximum of two links per node i.e.
...
...
@@ -2510,6 +2533,7 @@ class TensorMaps(TensorFields):
o o o o
where 8 is a duplicated node (one has two links and one has only one.)
Examples:
>>> import tfields
>>> import numpy as np
...
...
@@ -2540,7 +2564,7 @@ class TensorMaps(TensorFields):
         """
         obj = self.cleaned()
-        flat_map = np.array(obj.maps[mp_idx].flat)
+        flat_map = np.array(obj.maps[map_dim].flat)
         values, counts = np.unique(flat_map, return_counts=True)
         counts = {v: n for v, n in zip(values, counts)}
...
...
@@ -2571,22 +2595,22 @@ class TensorMaps(TensorFields):
         if duplicat_indices:
             duplicates = obj[duplicat_indices]
             obj = type(obj).merged(obj, duplicates)
-        obj.maps = [tfields.Tensors(
-            flat_map.reshape(-1, *obj.maps[mp_idx].shape[1:]))]
-        paths = obj.parts(obj.disjoint_map(0))
+        obj.maps = [flat_map.reshape(-1, *obj.maps[map_dim].shape[1:])]
+        paths = obj.parts(obj.disjoint_map(map_dim))
         # paths = [paths[2]]  # this path did not work previously - debugging  # with vessel
         # remove duplicate map entries and sort
         sorted_paths = []
         for path in paths:
             # find start index
-            values, counts = np.unique(path.maps[0].flat, return_counts=True)
+            values, counts = np.unique(path.maps[map_dim].flat,
+                                       return_counts=True)
             first_node = None
             for v, c in zip(values, counts):
                 if c == 1:
                     first_node = v
                     break
-            edges = [list(edge) for edge in path.maps[0]]
+            edges = [list(edge) for edge in path.maps[map_dim]]
             if first_node is None:
                 first_node = 0  # edges[0][0]
                 path = path[list(range(len(path))) + [0]]
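The start-node search above uses the fact that in an open chain the endpoints are exactly the vertices occurring once in the flattened map. Minimal numpy check:

    import numpy as np

    edges = np.array([[3, 1], [1, 4], [4, 0]])        # open path 3-1-4-0
    values, counts = np.unique(edges.flat, return_counts=True)
    endpoints = values[counts == 1]
    assert set(endpoints) == {0, 3}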
...
...