Commit 75f50648 authored by Martin Reinecke's avatar Martin Reinecke

rework imports; tweak domain comparisons

parent 0992e538
from __future__ import absolute_import, division, print_function
from builtins import (ascii, bytes, chr, dict, filter, hex, input,
map, next, oct, open, pow, range, round,
super, zip)
from functools import reduce
......@@ -16,9 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import (absolute_import, division, print_function)
from builtins import *
from functools import reduce
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from .random import Random
from mpi4py import MPI
......
......@@ -16,7 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from builtins import object
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
......
......@@ -16,9 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import (absolute_import, division, print_function)
from builtins import *
from functools import reduce
from __future__ import absolute_import, division, print_function
from .compat import *
from .domains.domain import Domain
......@@ -138,8 +137,7 @@ class DomainTuple(object):
def __eq__(self, x):
if self is x:
return True
x = DomainTuple.make(x)
return self is x
return self is DomainTuple.make(x)
def __ne__(self, x):
return not self.__eq__(x)
......
from functools import reduce
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..sugar import exp
import numpy as np
from .. import dobj
from ..field import Field
from .structured_domain import StructuredDomain
......
......@@ -16,9 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division
from builtins import range
from functools import reduce
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from .structured_domain import StructuredDomain
from ..field import Field
......
......@@ -16,8 +16,9 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
from .domain import Domain
from functools import reduce
class UnstructuredDomain(Domain):
......
from builtins import *
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..minimization.energy import Energy
from ..utilities import memo, my_sum
......
......@@ -16,8 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import (absolute_import, division, print_function)
from builtins import *
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from ..sugar import from_random
from ..minimization.energy import Energy
......
......@@ -16,12 +16,11 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import (absolute_import, division, print_function)
from builtins import *
from __future__ import absolute_import, division, print_function
from .compat import *
import numpy as np
from . import utilities
from .domain_tuple import DomainTuple
from functools import reduce
from . import dobj
......@@ -338,7 +337,7 @@ class Field(object):
raise TypeError("The dot-partner must be an instance of " +
"the NIFTy field class")
if x._domain != self._domain:
if x._domain is not self._domain:
raise ValueError("Domain mismatch")
ndom = len(self._domain)
......@@ -603,7 +602,7 @@ class Field(object):
return True
if not isinstance(other, Field):
return False
if self._domain != other._domain:
if self._domain is not other._domain:
return False
return (self._val == other._val).all()
......@@ -625,7 +624,7 @@ for op in ["__add__", "__radd__",
def func2(self, other):
# if other is a field, make sure that the domains match
if isinstance(other, Field):
if other._domain != self._domain:
if other._domain is not self._domain:
raise ValueError("domains are incompatible.")
tval = getattr(self.val, op)(other.val)
return Field(self._domain, tval)
......
......@@ -16,8 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import (absolute_import, division, print_function)
from builtins import *
from __future__ import absolute_import, division, print_function
from .compat import *
def _logger_init():
......
......@@ -16,7 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from builtins import *
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..utilities import memo, my_lincomb_simple, my_lincomb
from .energy import Energy
......
......@@ -16,7 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from builtins import range
from __future__ import absolute_import, division, print_function
from ..compat import *
import abc
from ..utilities import NiftyMetaBase
......
......@@ -16,8 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import (absolute_import, division, print_function)
from builtins import *
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..logger import logger
from .descent_minimizer import DescentMinimizer
from .line_search_strong_wolfe import LineSearchStrongWolfe
......
......@@ -16,8 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division
from builtins import range
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from .line_search import LineSearch
from .line_energy import LineEnergy
......
......@@ -16,8 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import (absolute_import, division, print_function)
from builtins import *
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from .descent_minimizer import DescentMinimizer
from .line_search_strong_wolfe import LineSearchStrongWolfe
......
......@@ -38,7 +38,7 @@ class ChainOperator(LinearOperator):
from .diagonal_operator import DiagonalOperator
# Step 1: verify domains
for i in range(len(ops)-1):
if ops[i+1].target != ops[i].domain:
if ops[i+1].target is not ops[i].domain:
raise ValueError("domain mismatch")
# Step 2: unpack ChainOperators
opsnew = []
......
......@@ -16,8 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import (absolute_import, division, print_function)
from builtins import *
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from ..field import Field
from ..domain_tuple import DomainTuple
......@@ -66,7 +66,7 @@ class DiagonalOperator(EndomorphicOperator):
self._domain = DomainTuple.make(domain)
if spaces is None:
self._spaces = None
if diagonal.domain != self._domain:
if diagonal.domain is not self._domain:
raise ValueError("domain mismatch")
else:
self._spaces = utilities.parse_spaces(spaces, len(self._domain))
......
......@@ -73,7 +73,7 @@ class FFTOperator(LinearOperator):
def _apply_cartesian(self, x, mode):
axes = x.domain.axes[self._space]
tdom = self._target if x.domain == self._domain else self._domain
tdom = self._tgt(mode)
oldax = dobj.distaxis(x.val)
if oldax not in axes: # straightforward, no redistribution needed
ldat = x.local_data
......
......@@ -121,7 +121,7 @@ class SHTOperator(LinearOperator):
distaxis = dobj.distaxis(tval)
p2h = not x.domain[self._space].harmonic
tdom = self._target if x.domain == self._domain else self._domain
tdom = self._tgt(mode)
func = self._slice_p2h if p2h else self._slice_h2p
idat = dobj.local_data(tval)
odat = np.empty(dobj.local_shape(tdom.shape, distaxis=distaxis),
......
......@@ -16,7 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from builtins import *
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..utilities import my_sum
from .linear_operator import LinearOperator
import numpy as np
......
......@@ -16,7 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from builtins import object
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..field import Field
......
......@@ -16,13 +16,12 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import (absolute_import, division, print_function)
from builtins import *
from __future__ import absolute_import, division, print_function
from .compat import *
import numpy as np
from itertools import product
import abc
from future.utils import with_metaclass
from functools import reduce
import collections
__all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment