Commit 201a6a28 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

more immutability; start importing Python3 builtins everywhere

parent 8924690e
...@@ -16,11 +16,13 @@ ...@@ -16,11 +16,13 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import (absolute_import, division, print_function)
from builtins import *
from functools import reduce
import numpy as np import numpy as np
from .random import Random from .random import Random
from mpi4py import MPI from mpi4py import MPI
import sys import sys
from functools import reduce
_comm = MPI.COMM_WORLD _comm = MPI.COMM_WORLD
ntask = _comm.Get_size() ntask = _comm.Get_size()
...@@ -62,6 +64,9 @@ class data_object(object): ...@@ -62,6 +64,9 @@ class data_object(object):
if local_shape(self._shape, self._distaxis) != self._data.shape: if local_shape(self._shape, self._distaxis) != self._data.shape:
raise ValueError("shape mismatch") raise ValueError("shape mismatch")
def copy(self):
return data_object(self._shape, self._data.copy(), self._distaxis)
# def _sanity_checks(self): # def _sanity_checks(self):
# # check whether the distaxis is consistent # # check whether the distaxis is consistent
# if self._distaxis < -1 or self._distaxis >= len(self._shape): # if self._distaxis < -1 or self._distaxis >= len(self._shape):
......
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import (absolute_import, division, print_function)
from builtins import *
from functools import reduce from functools import reduce
from .domains.domain import Domain from .domains.domain import Domain
......
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import (absolute_import, division, print_function)
from builtins import *
import numpy as np import numpy as np
from ..sugar import from_random from ..sugar import from_random
from ..minimization.energy import Energy from ..minimization.energy import Energy
......
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division from __future__ import (absolute_import, division, print_function)
from builtins import range from builtins import *
import numpy as np import numpy as np
from . import utilities from . import utilities
from .domain_tuple import DomainTuple from .domain_tuple import DomainTuple
......
...@@ -16,6 +16,9 @@ ...@@ -16,6 +16,9 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import (absolute_import, division, print_function)
from builtins import *
def _logger_init(): def _logger_init():
import logging import logging
......
...@@ -16,10 +16,8 @@ ...@@ -16,10 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division from __future__ import (absolute_import, division, print_function)
from builtins import *
from builtins import range
from ..logger import logger from ..logger import logger
from .descent_minimizer import DescentMinimizer from .descent_minimizer import DescentMinimizer
from .line_search_strong_wolfe import LineSearchStrongWolfe from .line_search_strong_wolfe import LineSearchStrongWolfe
......
...@@ -16,9 +16,8 @@ ...@@ -16,9 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division from __future__ import (absolute_import, division, print_function)
from builtins import range from builtins import *
from builtins import object
import numpy as np import numpy as np
from .descent_minimizer import DescentMinimizer from .descent_minimizer import DescentMinimizer
from .line_search_strong_wolfe import LineSearchStrongWolfe from .line_search_strong_wolfe import LineSearchStrongWolfe
......
...@@ -25,7 +25,7 @@ def _joint_position(model1, model2): ...@@ -25,7 +25,7 @@ def _joint_position(model1, model2):
a = model1.position._val a = model1.position._val
b = model2.position._val b = model2.position._val
# Note: In python >3.5 one could do {**a, **b} # Note: In python >3.5 one could do {**a, **b}
ab = a.copy() ab = dict(a)
ab.update(b) ab.update(b)
return MultiField(ab) return MultiField(ab)
......
import collections
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
from ..utilities import frozendict
__all = ["MultiDomain"]
class frozendict(collections.Mapping):
"""
An immutable wrapper around dictionaries that implements the complete
:py:class:`collections.Mapping` interface. It can be used as a drop-in
replacement for dictionaries where immutability is desired.
"""
dict_cls = dict
def __init__(self, *args, **kwargs):
self._dict = self.dict_cls(*args, **kwargs)
self._hash = None
def __getitem__(self, key):
return self._dict[key]
def __contains__(self, key):
return key in self._dict
def copy(self, **add_or_replace):
return self.__class__(self, **add_or_replace)
def __iter__(self):
return iter(self._dict)
def __len__(self):
return len(self._dict)
def __repr__(self):
return '<%s %r>' % (self.__class__.__name__, self._dict)
def __hash__(self):
if self._hash is None:
h = 0
for key, value in self._dict.items():
h ^= hash((key, value))
self._hash = h
return self._hash
class MultiDomain(frozendict): class MultiDomain(frozendict):
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
from ..field import Field from ..field import Field
import numpy as np import numpy as np
from .multi_domain import MultiDomain from .multi_domain import MultiDomain
from ..utilities import frozendict
class MultiField(object): class MultiField(object):
...@@ -28,7 +29,7 @@ class MultiField(object): ...@@ -28,7 +29,7 @@ class MultiField(object):
---------- ----------
val : dict val : dict
""" """
self._val = val self._val = frozendict(val)
self._domain = MultiDomain.make( self._domain = MultiDomain.make(
{key: val.domain for key, val in self._val.items()}) {key: val.domain for key, val in self._val.items()})
......
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division from __future__ import (absolute_import, division, print_function)
from builtins import *
import numpy as np import numpy as np
from ..field import Field from ..field import Field
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
......
...@@ -16,17 +16,19 @@ ...@@ -16,17 +16,19 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import (absolute_import, division, print_function)
from builtins import * from builtins import *
import numpy as np import numpy as np
from itertools import product from itertools import product
import abc import abc
from future.utils import with_metaclass from future.utils import with_metaclass
from functools import reduce from functools import reduce
import collections
__all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space", __all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"memo", "NiftyMetaBase", "fft_prep", "hartley", "my_fftn_r2c", "memo", "NiftyMetaBase", "fft_prep", "hartley", "my_fftn_r2c",
"my_fftn", "my_sum", "my_lincomb_simple", "my_lincomb", "my_fftn", "my_sum", "my_lincomb_simple", "my_lincomb",
"my_product"] "my_product", "frozendict"]
def my_sum(terms): def my_sum(terms):
...@@ -290,3 +292,43 @@ def my_fftn_r2c(a, axes=None): ...@@ -290,3 +292,43 @@ def my_fftn_r2c(a, axes=None):
def my_fftn(a, axes=None): def my_fftn(a, axes=None):
from pyfftw.interfaces.numpy_fft import fftn from pyfftw.interfaces.numpy_fft import fftn
return fftn(a, axes=axes, **_fft_extra_args) return fftn(a, axes=axes, **_fft_extra_args)
class frozendict(collections.Mapping):
"""
An immutable wrapper around dictionaries that implements the complete
:py:class:`collections.Mapping` interface. It can be used as a drop-in
replacement for dictionaries where immutability is desired.
"""
dict_cls = dict
def __init__(self, *args, **kwargs):
self._dict = self.dict_cls(*args, **kwargs)
self._hash = None
def __getitem__(self, key):
return self._dict[key]
def __contains__(self, key):
return key in self._dict
def copy(self, **add_or_replace):
return self.__class__(self, **add_or_replace)
def __iter__(self):
return iter(self._dict)
def __len__(self):
return len(self._dict)
def __repr__(self):
return '<%s %r>' % (self.__class__.__name__, self._dict)
def __hash__(self):
if self._hash is None:
h = 0
for key, value in self._dict.items():
h ^= hash((key, value))
self._hash = h
return self._hash
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment