diff --git a/tfields/core.py b/tfields/core.py index bc3ee167ba3566c349c77c1cc1be9cd70dac8965..490ac9a89b0473c45d48b708e977b11ba5ac8dfc 100644 --- a/tfields/core.py +++ b/tfields/core.py @@ -17,12 +17,10 @@ TODO: """ # builtin import warnings -import pathlib from contextlib import contextmanager from collections import Counter from copy import deepcopy import logging -from six import string_types # 3rd party import numpy as np @@ -55,7 +53,7 @@ def dim(tensor): return tensor.shape[1] -class AbstractObject(object): # pylint: disable=useless-object-inheritance +class AbstractObject(rna.polymorphism.Storable): """ Abstract base class for all tfields objects implementing polymorphisms @@ -64,65 +62,6 @@ class AbstractObject(object): # pylint: disable=useless-object-inheritance see https://stackoverflow.com/questions/3570796/why-use-abstract-base-classes-in-python """ - def save(self, path, *args, **kwargs): - """ - Saving by redirecting to the correct save method depending on path - - Args: - path (str | buffer) - *args: joined with path - **kwargs: - extension (str): only needed if path is buffer - ... remaining:forwarded to extension specific method - """ - # get the extension - if isinstance(path, (string_types, pathlib.Path)): - extension = pathlib.Path(path).suffix.lstrip(".") - else: - raise ValueError("Wrong path type {0}".format(type(path))) - path = str(path) - - # get the save method - try: - save_method = getattr(self, "_save_" + extension) - except AttributeError as err: - raise NotImplementedError( - "Can not find save method for extension: " - "{extension}.".format(**locals()) - ) from err - - path = rna.path.resolve(path, *args) - rna.path.mkdir(path) - return save_method(path, **kwargs) - - @classmethod - def load(cls, path, *args, **kwargs): - """ - load a file as a tensors object. - - Args: - path (str or buffer) - *args: joined with path - **kwargs: - extension (str): only needed if path is buffer - ... remaining:forwarded to extension specific method - """ - if isinstance(path, (string_types, pathlib.Path)): - extension = pathlib.Path(path).suffix.lstrip(".") - path = str(path) - path = rna.path.resolve(path) - else: - extension = kwargs.pop("extension", "npz") - - try: - load_method = getattr(cls, "_load_{e}".format(e=extension)) - except AttributeError as err: - raise NotImplementedError( - "Can not find load method for extension: " - "{extension}.".format(**locals()) - ) from err - return load_method(path, *args, **kwargs) - def _save_npz(self, path): """ Args: