Commit 201a6a28 authored by Martin Reinecke's avatar Martin Reinecke

more immutability; start importing Python3 builtins everywhere

parent 8924690e
......@@ -16,11 +16,13 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import (absolute_import, division, print_function)
from builtins import *
from functools import reduce
import numpy as np
from .random import Random
from mpi4py import MPI
import sys
from functools import reduce
_comm = MPI.COMM_WORLD
ntask = _comm.Get_size()
......@@ -62,6 +64,9 @@ class data_object(object):
if local_shape(self._shape, self._distaxis) != self._data.shape:
raise ValueError("shape mismatch")
def copy(self):
return data_object(self._shape, self._data.copy(), self._distaxis)
# def _sanity_checks(self):
# # check whether the distaxis is consistent
# if self._distaxis < -1 or self._distaxis >= len(self._shape):
......
......@@ -16,6 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import (absolute_import, division, print_function)
from builtins import *
from functools import reduce
from .domains.domain import Domain
......
......@@ -16,6 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import (absolute_import, division, print_function)
from builtins import *
import numpy as np
from ..sugar import from_random
from ..minimization.energy import Energy
......
......@@ -16,8 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division
from builtins import range
from __future__ import (absolute_import, division, print_function)
from builtins import *
import numpy as np
from . import utilities
from .domain_tuple import DomainTuple
......
......@@ -16,6 +16,9 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import (absolute_import, division, print_function)
from builtins import *
def _logger_init():
import logging
......
......@@ -16,10 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division
from builtins import range
from __future__ import (absolute_import, division, print_function)
from builtins import *
from ..logger import logger
from .descent_minimizer import DescentMinimizer
from .line_search_strong_wolfe import LineSearchStrongWolfe
......
......@@ -16,9 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division
from builtins import range
from builtins import object
from __future__ import (absolute_import, division, print_function)
from builtins import *
import numpy as np
from .descent_minimizer import DescentMinimizer
from .line_search_strong_wolfe import LineSearchStrongWolfe
......
......@@ -25,7 +25,7 @@ def _joint_position(model1, model2):
a = model1.position._val
b = model2.position._val
# Note: In python >3.5 one could do {**a, **b}
ab = a.copy()
ab = dict(a)
ab.update(b)
return MultiField(ab)
......
import collections
from ..domain_tuple import DomainTuple
__all = ["MultiDomain"]
class frozendict(collections.Mapping):
"""
An immutable wrapper around dictionaries that implements the complete
:py:class:`collections.Mapping` interface. It can be used as a drop-in
replacement for dictionaries where immutability is desired.
"""
dict_cls = dict
def __init__(self, *args, **kwargs):
self._dict = self.dict_cls(*args, **kwargs)
self._hash = None
def __getitem__(self, key):
return self._dict[key]
def __contains__(self, key):
return key in self._dict
def copy(self, **add_or_replace):
return self.__class__(self, **add_or_replace)
def __iter__(self):
return iter(self._dict)
def __len__(self):
return len(self._dict)
def __repr__(self):
return '<%s %r>' % (self.__class__.__name__, self._dict)
def __hash__(self):
if self._hash is None:
h = 0
for key, value in self._dict.items():
h ^= hash((key, value))
self._hash = h
return self._hash
from ..utilities import frozendict
class MultiDomain(frozendict):
......
......@@ -19,6 +19,7 @@
from ..field import Field
import numpy as np
from .multi_domain import MultiDomain
from ..utilities import frozendict
class MultiField(object):
......@@ -28,7 +29,7 @@ class MultiField(object):
----------
val : dict
"""
self._val = val
self._val = frozendict(val)
self._domain = MultiDomain.make(
{key: val.domain for key, val in self._val.items()})
......
......@@ -16,7 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division
from __future__ import (absolute_import, division, print_function)
from builtins import *
import numpy as np
from ..field import Field
from ..domain_tuple import DomainTuple
......
......@@ -16,17 +16,19 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import (absolute_import, division, print_function)
from builtins import *
import numpy as np
from itertools import product
import abc
from future.utils import with_metaclass
from functools import reduce
import collections
__all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"memo", "NiftyMetaBase", "fft_prep", "hartley", "my_fftn_r2c",
"my_fftn", "my_sum", "my_lincomb_simple", "my_lincomb",
"my_product"]
"my_product", "frozendict"]
def my_sum(terms):
......@@ -290,3 +292,43 @@ def my_fftn_r2c(a, axes=None):
def my_fftn(a, axes=None):
from pyfftw.interfaces.numpy_fft import fftn
return fftn(a, axes=axes, **_fft_extra_args)
class frozendict(collections.Mapping):
"""
An immutable wrapper around dictionaries that implements the complete
:py:class:`collections.Mapping` interface. It can be used as a drop-in
replacement for dictionaries where immutability is desired.
"""
dict_cls = dict
def __init__(self, *args, **kwargs):
self._dict = self.dict_cls(*args, **kwargs)
self._hash = None
def __getitem__(self, key):
return self._dict[key]
def __contains__(self, key):
return key in self._dict
def copy(self, **add_or_replace):
return self.__class__(self, **add_or_replace)
def __iter__(self):
return iter(self._dict)
def __len__(self):
return len(self._dict)
def __repr__(self):
return '<%s %r>' % (self.__class__.__name__, self._dict)
def __hash__(self):
if self._hash is None:
h = 0
for key, value in self._dict.items():
h ^= hash((key, value))
self._hash = h
return self._hash
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment