Commit c1bba9b3 authored by Daniel Böckenhoff's avatar Daniel Böckenhoff
Browse files

tfields saving and writing

parent d4a7807c
......@@ -15,6 +15,10 @@ import sympy
import scipy as sp
import scipy.spatial # NOQA: F401
import scipy.stats as stats
import os
from six import string_types
import pathlib
import warnings
np.seterr(all='warn', over='raise')
......@@ -156,6 +160,95 @@ class AbstractNdarray(np.ndarray):
index = -(i + 1)
setattr(self, slot, state[index])
def save(self, path, *args, **kwargs):
"""
Saving a tensors object by redirecting to the correct save method depending on path
Args:
path (str or buffer)
*args:
forwarded to extension specific method
**kwargs:
extension (str): only needed if path is buffer
... remaining:forwarded to extension specific method
"""
# get the extension
if isinstance(path, string_types):
extension = pathlib.Path(path).suffix.lstrip('.')
# get the save method
try:
save_method = getattr(self,
'_save_{extension}'.format(**locals()))
except:
raise NotImplementedError("Can not find save method for extension: "
"{extension}.".format(**locals()))
# resolve: relative paths, symlinks and ~
path = os.path.realpath(os.path.abspath(os.path.expanduser(path)))
return save_method(path, **kwargs)
@classmethod
def load(cls, path, *args, **kwargs):
"""
load a file as a tensors object.
Args:
path (str or buffer)
*args:
forwarded to extension specific method
**kwargs:
extension (str): only needed if path is buffer
... remaining:forwarded to extension specific method
"""
extension = kwargs.pop('extension', 'npz')
if isinstance(path, string_types):
path = os.path.realpath(os.path.abspath(os.path.expanduser(path)))
extension = pathlib.Path(path).suffix.lstrip('.')
try:
load_method = getattr(cls, '_load_{e}'.format(e=extension))
except:
raise NotImplementedError("Can not find load method for extension: "
"{extension}.".format(**locals()))
return load_method(path, *args, **kwargs)
def _save_npz(self, path, **kwargs):
"""
Args:
path (open file or str/unicode): destination to save file to.
Examples:
>>> import tfields
>>> from tempfile import NamedTemporaryFile
>>> outFile = NamedTemporaryFile(suffix='.npz')
>>> p = tfields.Points3D([[1., 2., 3.], [4., 5., 6.], [1, 2, -6]])
>>> p.save(outFile.name)
>>> _ = outFile.seek(0)
>>> p1 = tfields.Points3D.load(outFile.name)
>>> assert p.equal(p1)
"""
kwargs = {}
for attr in self._iter_slots():
if not hasattr(self, attr):
# attribute in __slots__ not found.
warnings.warn("When saving instance of class {0} Attribute {1} not set."
"This Attribute is not saved.".format(self.__class__, attr), Warning)
else:
kwargs[attr] = getattr(self, attr)
np.savez(path, self, **kwargs)
@classmethod
def _load_npz(cls, path, **load_kwargs):
"""
Factory method
Given a path to a npz file, construct the object
"""
np_file = np.load(path, **load_kwargs)
keys = np_file.keys()
bulk = np_file['arr_0']
data_kwargs = {key: np_file[key] for key in keys if key not in ['arr_0']}
return cls.__new__(cls, bulk, **data_kwargs)
class Tensors(AbstractNdarray):
"""
......
......@@ -8,11 +8,10 @@ basic threedimensional tensor
"""
import tfields
import numpy as np
import pathlib
import os
import osTools
import ioTools
import pyTools
import mplTools as mpt
import warnings
import loggingTools
logger = loggingTools.Logger(__name__)
......@@ -147,210 +146,6 @@ class Points3D(tfields.Tensors):
kwargs['dim'] = 3
return super(Points3D, cls).__new__(cls, tensors, **kwargs)
def save(self, filePath, extension='npz', **kwargs):
"""
save a tensors object.
filePath (str or buffer)
[extension (str): only needed if filePath is buffer]
"""
if not osTools.isBuffer(filePath):
extension = ioTools.getExtension(filePath)
if extension not in ['npz', 'obj', 'txt', 'ply']:
raise NotImplementedError("Extension {0} not implemented.".format(extension))
# create from buffer, if file is tar file
if osTools.isTarPath(filePath):
raise NotImplementedError("Saving to tar file not intended "
"since we want to make the tar-writing outside.")
if extension == 'npz':
self.savez(filePath)
if extension == 'obj':
return self.saveObj(filePath, **kwargs)
if extension == 'txt':
self.saveTxt(filePath)
if extension == 'ply':
self.saveTxt(filePath)
else:
raise NotImplementedError("Do not understand the format of file {0}".format(filePath))
def savePly(self, filePath):
from plyfile import PlyData, PlyElement
with self.tmp_transform(tfields.bases.CARTESIAN):
vertices = np.array([tuple(x) for x in self],
dtype=[('x', '<f4'), ('y', '<f4'), ('z', '<f4')])
element = PlyElement.describe(vertices, 'vertex')
filePath = osTools.resolve(filePath)
PlyData([element]).write(filePath)
def saveTxt(self, filePath):
"""
Save obj as .txt file
"""
if self.coordSys != tfields.bases.CARTESIAN:
cpy = self.copy()
cpy.transform(tfields.bases.CARTESIAN)
else:
cpy = self
with ioTools.TextFile(filePath, 'w') as f:
f.writeMatrix(self, seperator=' ', lineBreak='\n')
def saveObj(self, filePath, **kwargs):
"""
Save obj as wavefront/.obj file
"""
obj = kwargs.pop('object', None)
group = kwargs.pop('group', None)
filePath = filePath.replace('.obj', '')
fileDir, fileName = os.path.split(filePath)
with open(osTools.resolve(filePath + '.obj'), 'w') as f:
f.write("# File saved with tensors Points3D.saveObj method\n\n")
if obj is not None:
f.write("o {0}\n".format(obj))
if group is not None:
f.write("g {0}\n".format(group))
for vertex in self:
f.write("v {v[0]} {v[1]} {v[2]}\n".format(v=vertex))
def savez(self, filePath):
"""
Args:
filePath (open file or str/unicode): destination to save file to.
Examples:
>>> from tempfile import NamedTemporaryFile
>>> outFile = NamedTemporaryFile(suffix='.npz')
>>> p = tfields.Points3D([[1., 2., 3.], [4., 5., 6.], [1, 2, -6]])
>>> p.savez(outFile.name)
>>> _ = outFile.seek(0)
>>> p1 = tfields.Points3D.createFromFile(outFile.name)
>>> p.equal(p1)
True
"""
kwargs = {}
for attr in self.__slots__:
if attr == '_cache':
# do not save the cache
continue
elif not hasattr(self, attr):
# attribute in __slots__ not found.
warnings.warn("When saving instance of class {0} Attribute {1} not set."
"This Attribute is not saved.".format(self.__class__, attr), Warning)
else:
kwargs[attr] = getattr(self, attr)
# kwargs = dict(zip(self.__slots__, [getattr(self, attr) for attr in self.__slots__]))
# resolve paths. Try except for buffer objects
try:
filePath = osTools.resolve(filePath)
except AttributeError:
pass
except:
raise
np.savez(filePath, self, **kwargs)
@classmethod
def createFromObjFile(cls, filePath, *groupNames):
"""
Factory method
Given a filePath to a obj/wavefront file, construct the object
"""
ioCls = ioTools.ObjFile
with ioCls(filePath, 'r') as f:
f.process()
vertices = f.getVertices(*groupNames)
log = logger.new()
if issubclass(cls, tfields.Triangles3D):
raise NotImplementedError("reading Triangles3D from obj. Use faces"
"attribute and the Mesh3D implementation.")
elif issubclass(cls, Points3D):
pass
else:
raise NotImplementedError("{cls} not implemented in createFromObjFile"
.format(**locals()))
return cls.__new__(cls, vertices)
@classmethod
def createFromTxtFile(cls, filePath):
"""
Factory method
Given a filePath to a txt file, construct the object
"""
ioCls = ioTools.HaukeFile
with ioCls(filePath, 'r') as f:
f.process()
return f.getPoints3D()
@classmethod
def createFromNpzFile(cls, filePath, **kwargs):
"""
Factory method
Given a filePath to a npz file, construct the object
"""
npFile = np.load(filePath)
keys = npFile.keys()
array = npFile['arr_0']
additionalKwargs = {key: npFile[key] for key in keys if key not in ['arr_0', '_cache']}
if additionalKwargs is not None:
kwargs.update(additionalKwargs)
return cls.__new__(cls, array, **kwargs)
@classmethod
def createFromInpFile(cls, filePath):
"""
Factory method
Given a filePath to a inp file, construct the object
"""
import ioTools.transcoding
transcoding = ioTools.transcoding.getTranscoding('inp')
content = transcoding.read(filePath)
part = content['parts'][0]
vertices = np.array([part['x'], part['y'], part['z']]).T / 1000
indices = np.array(part['nodeIndex']) - 1
if not list(indices) == range(len(indices)):
raise ValueError("node index skipped")
return cls(vertices)
@classmethod
def createFromFile(cls, filePath, *args, **kwargs):
"""
load a file as a tensors object.
Args:
filePath (str or buffer)
*args:
forwarded to extension specific method
**kwargs:
extension (str): only needed if filePath is buffer
... remaining:forwarded to extension specific method
"""
extension = kwargs.pop('extension', 'npz')
if not osTools.isBuffer(filePath):
filePath = osTools.resolve(filePath)
extension = ioTools.getExtension(filePath)
# create from buffer, if file is tar file
if osTools.isTarPath(filePath):
with ioTools.TarHolder() as th:
buf = th.read(filePath)
inst = cls.createFromFile(buf, extension=extension, *args, **kwargs)
return inst
try:
fun = getattr(cls, 'createFrom{e}File'.format(e=extension.capitalize()))
except:
raise NotImplementedError("Do not understand the format of file {0}".format(filePath))
return fun(filePath, *args, **kwargs)
def plot(self, **kwargs):
"""
Frowarding to mpt.plotArray
"""
artist = mpt.plotArray(self, **kwargs)
return artist
if __name__ == '__main__':
import doctest
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment