Commit c1bba9b3 authored by Daniel Böckenhoff's avatar Daniel Böckenhoff
Browse files

tfields saving and writing

parent d4a7807c
...@@ -15,6 +15,10 @@ import sympy ...@@ -15,6 +15,10 @@ import sympy
import scipy as sp import scipy as sp
import scipy.spatial # NOQA: F401 import scipy.spatial # NOQA: F401
import scipy.stats as stats import scipy.stats as stats
import os
from six import string_types
import pathlib
import warnings
np.seterr(all='warn', over='raise') np.seterr(all='warn', over='raise')
...@@ -156,6 +160,95 @@ class AbstractNdarray(np.ndarray): ...@@ -156,6 +160,95 @@ class AbstractNdarray(np.ndarray):
index = -(i + 1) index = -(i + 1)
setattr(self, slot, state[index]) setattr(self, slot, state[index])
def save(self, path, *args, **kwargs):
"""
Saving a tensors object by redirecting to the correct save method depending on path
Args:
path (str or buffer)
*args:
forwarded to extension specific method
**kwargs:
extension (str): only needed if path is buffer
... remaining:forwarded to extension specific method
"""
# get the extension
if isinstance(path, string_types):
extension = pathlib.Path(path).suffix.lstrip('.')
# get the save method
try:
save_method = getattr(self,
'_save_{extension}'.format(**locals()))
except:
raise NotImplementedError("Can not find save method for extension: "
"{extension}.".format(**locals()))
# resolve: relative paths, symlinks and ~
path = os.path.realpath(os.path.abspath(os.path.expanduser(path)))
return save_method(path, **kwargs)
@classmethod
def load(cls, path, *args, **kwargs):
"""
load a file as a tensors object.
Args:
path (str or buffer)
*args:
forwarded to extension specific method
**kwargs:
extension (str): only needed if path is buffer
... remaining:forwarded to extension specific method
"""
extension = kwargs.pop('extension', 'npz')
if isinstance(path, string_types):
path = os.path.realpath(os.path.abspath(os.path.expanduser(path)))
extension = pathlib.Path(path).suffix.lstrip('.')
try:
load_method = getattr(cls, '_load_{e}'.format(e=extension))
except:
raise NotImplementedError("Can not find load method for extension: "
"{extension}.".format(**locals()))
return load_method(path, *args, **kwargs)
def _save_npz(self, path, **kwargs):
"""
Args:
path (open file or str/unicode): destination to save file to.
Examples:
>>> import tfields
>>> from tempfile import NamedTemporaryFile
>>> outFile = NamedTemporaryFile(suffix='.npz')
>>> p = tfields.Points3D([[1., 2., 3.], [4., 5., 6.], [1, 2, -6]])
>>> p.save(outFile.name)
>>> _ = outFile.seek(0)
>>> p1 = tfields.Points3D.load(outFile.name)
>>> assert p.equal(p1)
"""
kwargs = {}
for attr in self._iter_slots():
if not hasattr(self, attr):
# attribute in __slots__ not found.
warnings.warn("When saving instance of class {0} Attribute {1} not set."
"This Attribute is not saved.".format(self.__class__, attr), Warning)
else:
kwargs[attr] = getattr(self, attr)
np.savez(path, self, **kwargs)
@classmethod
def _load_npz(cls, path, **load_kwargs):
"""
Factory method
Given a path to a npz file, construct the object
"""
np_file = np.load(path, **load_kwargs)
keys = np_file.keys()
bulk = np_file['arr_0']
data_kwargs = {key: np_file[key] for key in keys if key not in ['arr_0']}
return cls.__new__(cls, bulk, **data_kwargs)
class Tensors(AbstractNdarray): class Tensors(AbstractNdarray):
""" """
......
...@@ -8,11 +8,10 @@ basic threedimensional tensor ...@@ -8,11 +8,10 @@ basic threedimensional tensor
""" """
import tfields import tfields
import numpy as np import numpy as np
import pathlib
import os import os
import osTools import osTools
import ioTools import ioTools
import pyTools
import mplTools as mpt
import warnings import warnings
import loggingTools import loggingTools
logger = loggingTools.Logger(__name__) logger = loggingTools.Logger(__name__)
...@@ -147,210 +146,6 @@ class Points3D(tfields.Tensors): ...@@ -147,210 +146,6 @@ class Points3D(tfields.Tensors):
kwargs['dim'] = 3 kwargs['dim'] = 3
return super(Points3D, cls).__new__(cls, tensors, **kwargs) return super(Points3D, cls).__new__(cls, tensors, **kwargs)
def save(self, filePath, extension='npz', **kwargs):
"""
save a tensors object.
filePath (str or buffer)
[extension (str): only needed if filePath is buffer]
"""
if not osTools.isBuffer(filePath):
extension = ioTools.getExtension(filePath)
if extension not in ['npz', 'obj', 'txt', 'ply']:
raise NotImplementedError("Extension {0} not implemented.".format(extension))
# create from buffer, if file is tar file
if osTools.isTarPath(filePath):
raise NotImplementedError("Saving to tar file not intended "
"since we want to make the tar-writing outside.")
if extension == 'npz':
self.savez(filePath)
if extension == 'obj':
return self.saveObj(filePath, **kwargs)
if extension == 'txt':
self.saveTxt(filePath)
if extension == 'ply':
self.saveTxt(filePath)
else:
raise NotImplementedError("Do not understand the format of file {0}".format(filePath))
def savePly(self, filePath):
from plyfile import PlyData, PlyElement
with self.tmp_transform(tfields.bases.CARTESIAN):
vertices = np.array([tuple(x) for x in self],
dtype=[('x', '<f4'), ('y', '<f4'), ('z', '<f4')])
element = PlyElement.describe(vertices, 'vertex')
filePath = osTools.resolve(filePath)
PlyData([element]).write(filePath)
def saveTxt(self, filePath):
"""
Save obj as .txt file
"""
if self.coordSys != tfields.bases.CARTESIAN:
cpy = self.copy()
cpy.transform(tfields.bases.CARTESIAN)
else:
cpy = self
with ioTools.TextFile(filePath, 'w') as f:
f.writeMatrix(self, seperator=' ', lineBreak='\n')
def saveObj(self, filePath, **kwargs):
"""
Save obj as wavefront/.obj file
"""
obj = kwargs.pop('object', None)
group = kwargs.pop('group', None)
filePath = filePath.replace('.obj', '')
fileDir, fileName = os.path.split(filePath)
with open(osTools.resolve(filePath + '.obj'), 'w') as f:
f.write("# File saved with tensors Points3D.saveObj method\n\n")
if obj is not None:
f.write("o {0}\n".format(obj))
if group is not None:
f.write("g {0}\n".format(group))
for vertex in self:
f.write("v {v[0]} {v[1]} {v[2]}\n".format(v=vertex))
def savez(self, filePath):
"""
Args:
filePath (open file or str/unicode): destination to save file to.
Examples:
>>> from tempfile import NamedTemporaryFile
>>> outFile = NamedTemporaryFile(suffix='.npz')
>>> p = tfields.Points3D([[1., 2., 3.], [4., 5., 6.], [1, 2, -6]])
>>> p.savez(outFile.name)
>>> _ = outFile.seek(0)
>>> p1 = tfields.Points3D.createFromFile(outFile.name)
>>> p.equal(p1)
True
"""
kwargs = {}
for attr in self.__slots__:
if attr == '_cache':
# do not save the cache
continue
elif not hasattr(self, attr):
# attribute in __slots__ not found.
warnings.warn("When saving instance of class {0} Attribute {1} not set."
"This Attribute is not saved.".format(self.__class__, attr), Warning)
else:
kwargs[attr] = getattr(self, attr)
# kwargs = dict(zip(self.__slots__, [getattr(self, attr) for attr in self.__slots__]))
# resolve paths. Try except for buffer objects
try:
filePath = osTools.resolve(filePath)
except AttributeError:
pass
except:
raise
np.savez(filePath, self, **kwargs)
@classmethod
def createFromObjFile(cls, filePath, *groupNames):
"""
Factory method
Given a filePath to a obj/wavefront file, construct the object
"""
ioCls = ioTools.ObjFile
with ioCls(filePath, 'r') as f:
f.process()
vertices = f.getVertices(*groupNames)
log = logger.new()
if issubclass(cls, tfields.Triangles3D):
raise NotImplementedError("reading Triangles3D from obj. Use faces"
"attribute and the Mesh3D implementation.")
elif issubclass(cls, Points3D):
pass
else:
raise NotImplementedError("{cls} not implemented in createFromObjFile"
.format(**locals()))
return cls.__new__(cls, vertices)
@classmethod
def createFromTxtFile(cls, filePath):
"""
Factory method
Given a filePath to a txt file, construct the object
"""
ioCls = ioTools.HaukeFile
with ioCls(filePath, 'r') as f:
f.process()
return f.getPoints3D()
@classmethod
def createFromNpzFile(cls, filePath, **kwargs):
"""
Factory method
Given a filePath to a npz file, construct the object
"""
npFile = np.load(filePath)
keys = npFile.keys()
array = npFile['arr_0']
additionalKwargs = {key: npFile[key] for key in keys if key not in ['arr_0', '_cache']}
if additionalKwargs is not None:
kwargs.update(additionalKwargs)
return cls.__new__(cls, array, **kwargs)
@classmethod
def createFromInpFile(cls, filePath):
"""
Factory method
Given a filePath to a inp file, construct the object
"""
import ioTools.transcoding
transcoding = ioTools.transcoding.getTranscoding('inp')
content = transcoding.read(filePath)
part = content['parts'][0]
vertices = np.array([part['x'], part['y'], part['z']]).T / 1000
indices = np.array(part['nodeIndex']) - 1
if not list(indices) == range(len(indices)):
raise ValueError("node index skipped")
return cls(vertices)
@classmethod
def createFromFile(cls, filePath, *args, **kwargs):
"""
load a file as a tensors object.
Args:
filePath (str or buffer)
*args:
forwarded to extension specific method
**kwargs:
extension (str): only needed if filePath is buffer
... remaining:forwarded to extension specific method
"""
extension = kwargs.pop('extension', 'npz')
if not osTools.isBuffer(filePath):
filePath = osTools.resolve(filePath)
extension = ioTools.getExtension(filePath)
# create from buffer, if file is tar file
if osTools.isTarPath(filePath):
with ioTools.TarHolder() as th:
buf = th.read(filePath)
inst = cls.createFromFile(buf, extension=extension, *args, **kwargs)
return inst
try:
fun = getattr(cls, 'createFrom{e}File'.format(e=extension.capitalize()))
except:
raise NotImplementedError("Do not understand the format of file {0}".format(filePath))
return fun(filePath, *args, **kwargs)
def plot(self, **kwargs):
"""
Frowarding to mpt.plotArray
"""
artist = mpt.plotArray(self, **kwargs)
return artist
if __name__ == '__main__': if __name__ == '__main__':
import doctest import doctest
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment