Commit 9aa6e3cc authored by dboe's avatar dboe
Browse files

changed namespace in rna

parent a9ffe76e
Pipeline #88826 passed with stages
in 1 minute and 5 seconds
......@@ -17,12 +17,10 @@ TODO:
"""
# builtin
import warnings
import pathlib
from contextlib import contextmanager
from collections import Counter
from copy import deepcopy
import logging
from six import string_types
# 3rd party
import numpy as np
......@@ -55,7 +53,7 @@ def dim(tensor):
return tensor.shape[1]
class AbstractObject(object): # pylint: disable=useless-object-inheritance
class AbstractObject(rna.polymorphism.Storable):
"""
Abstract base class for all tfields objects implementing polymorphisms
......@@ -64,65 +62,6 @@ class AbstractObject(object): # pylint: disable=useless-object-inheritance
see https://stackoverflow.com/questions/3570796/why-use-abstract-base-classes-in-python
"""
def save(self, path, *args, **kwargs):
"""
Saving by redirecting to the correct save method depending on path
Args:
path (str | buffer)
*args: joined with path
**kwargs:
extension (str): only needed if path is buffer
... remaining:forwarded to extension specific method
"""
# get the extension
if isinstance(path, (string_types, pathlib.Path)):
extension = pathlib.Path(path).suffix.lstrip(".")
else:
raise ValueError("Wrong path type {0}".format(type(path)))
path = str(path)
# get the save method
try:
save_method = getattr(self, "_save_" + extension)
except AttributeError as err:
raise NotImplementedError(
"Can not find save method for extension: "
"{extension}.".format(**locals())
) from err
path = rna.path.resolve(path, *args)
rna.path.mkdir(path)
return save_method(path, **kwargs)
@classmethod
def load(cls, path, *args, **kwargs):
"""
load a file as a tensors object.
Args:
path (str or buffer)
*args: joined with path
**kwargs:
extension (str): only needed if path is buffer
... remaining:forwarded to extension specific method
"""
if isinstance(path, (string_types, pathlib.Path)):
extension = pathlib.Path(path).suffix.lstrip(".")
path = str(path)
path = rna.path.resolve(path)
else:
extension = kwargs.pop("extension", "npz")
try:
load_method = getattr(cls, "_load_{e}".format(e=extension))
except AttributeError as err:
raise NotImplementedError(
"Can not find load method for extension: "
"{extension}.".format(**locals())
) from err
return load_method(path, *args, **kwargs)
def _save_npz(self, path):
"""
Args:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment