Commit d57b2f02 authored by theos's avatar theos

Restructured new operator classes -> Introduced endomorphic_operator.py in...

Restructured new operator classes -> Introduced endomorphic_operator.py in order to avoid violation of Liskov's SP for square- and diagonal operator.
Removed paradict from operators.
Added from_random as static method to Field.
parent 88d89ed6
......@@ -45,7 +45,7 @@ from paradict import Paradict
# TODO: Remove this once the transition to field types is done.
from spaces.space import Space as point_space
from nifty_random import random
from random import Random
from nifty_simple_math import *
from nifty_utilities import *
......
......@@ -2,12 +2,12 @@ from __future__ import division
import numpy as np
import pylab as pl
from d2o import distributed_data_object, \
from d2o import distributed_data_object,\
STRATEGIES as DISTRIBUTION_STRATEGIES
from nifty.config import about, \
nifty_configuration as gc, \
dependency_injector as gdi
from nifty.config import about,\
nifty_configuration as gc,\
dependency_injector as gdi
from nifty.field_types import FieldType,\
FieldArray
......@@ -15,6 +15,8 @@ from nifty.field_types import FieldType,\
from nifty.spaces.space import Space
import nifty.nifty_utilities as utilities
from nifty.random import Random
POINT_DISTRIBUTION_STRATEGIES = DISTRIBUTION_STRATEGIES['global']
COMM = getattr(gdi[gc['mpi_module']], gc['default_comm'])
......@@ -24,20 +26,11 @@ class Field(object):
# ---Initialization methods---
def __init__(self, domain=None, val=None, dtype=None, field_type=None,
datamodel=None, copy=False):
if isinstance(val, Field):
if domain is None:
domain = val.domain
if dtype is None:
dtype = val.dtype
if field_type is None:
field_type = val.field_type
if datamodel is None:
datamodel = val.datamodel
self.domain = self._parse_domain(domain=domain)
self.domain = self._parse_domain(domain=domain, val=val)
self.domain_axes = self._get_axes_tuple(self.domain)
self.field_type = self._parse_field_type(field_type)
self.field_type = self._parse_field_type(field_type, val=val)
try:
start = len(reduce(lambda x, y: x+y, self.domain_axes))
......@@ -55,9 +48,12 @@ class Field(object):
self.set_val(new_val=val, copy=copy)
def _parse_domain(self, domain):
def _parse_domain(self, domain, val):
if domain is None:
domain = ()
if isinstance(val, Field):
domain = val.domain
else:
domain = ()
elif not isinstance(domain, tuple):
domain = (domain,)
for d in domain:
......@@ -67,9 +63,12 @@ class Field(object):
"nifty.space."))
return domain
def _parse_field_type(self, field_type):
def _parse_field_type(self, field_type, val):
if field_type is None:
field_type = ()
if isinstance(val, Field):
field_type = val.field_type
else:
field_type = ()
elif not isinstance(field_type, tuple):
field_type = (field_type,)
for ft in field_type:
......@@ -89,8 +88,11 @@ class Field(object):
axes_list += [tuple(l)]
return tuple(axes_list)
def _infer_dtype(self, dtype=None, domain=None, field_type=None):
def _infer_dtype(self, dtype, val, domain, field_type):
if dtype is None:
if isinstance(val, Field) or \
isinstance(val, distributed_data_object):
dtype = val.dtype
dtype_tuple = (np.dtype(gc['default_field_dtype']),)
else:
dtype_tuple = (np.dtype(dtype),)
......@@ -100,17 +102,74 @@ class Field(object):
dtype_tuple += tuple(np.dtype(ft.dtype) for ft in field_type)
dtype = reduce(lambda x, y: np.result_type(x, y), dtype_tuple)
return dtype
def _parse_datamodel(self, datamodel, val):
if datamodel in DISTRIBUTION_STRATEGIES['all']:
pass
elif isinstance(val, distributed_data_object):
datamodel = val.distribution_strategy
if datamodel is None:
if isinstance(val, distributed_data_object):
datamodel = val.distribution_strategy
elif isinstance(val, Field):
datamodel = val.datamodel
else:
about.warnings.cprint("WARNING: Datamodel set to default!")
datamodel = gc['default_datamodel']
elif datamodel not in DISTRIBUTION_STRATEGIES['all']:
raise ValueError(about._errors.cstring(
"ERROR: Invalid datamodel!"))
return datamodel
# ---Factory methods---
@classmethod
def from_random(cls, random_type, domain=None, dtype=None, field_type=None,
datamodel=None, **kwargs):
# create a initially empty field
f = cls(domain=domain, dtype=dtype, field_type=field_type,
datamodel=datamodel)
# now use the processed input in terms of f in order to parse the
# random arguments
random_arguments = cls._parse_random_arguments(random_type=random_type,
f=f,
**kwargs)
# extract the distributed_dato_object from f and apply the appropriate
# random number generator to it
sample = f.get_val(copy=False)
generator_function = getattr(Random, random_type)
sample.apply_generator(
lambda shape: generator_function(dtype=f.dtype,
shape=shape,
**random_arguments))
return f
@staticmethod
def _parse_random_arguments(random_type, f, **kwargs):
if random_type == "pm1":
random_arguments = {}
elif random_type == "normal":
mean = kwargs.get('mean', 0)
std = kwargs.get('std', 1)
random_arguments = {'mean': mean,
'std': std}
elif random_type == "uniform":
low = kwargs.get('low', 0)
high = kwargs.get('high', 1)
random_arguments = {'low': low,
'high': high}
# elif random_type == 'syn':
# pass
else:
datamodel = gc['default_datamodel']
raise KeyError(about._errors.cstring(
"ERROR: unsupported random key '" + str(random_type) + "'."))
return datamodel
return random_arguments
# ---Properties---
def set_val(self, new_val=None, copy=False):
......@@ -462,6 +521,7 @@ class Field(object):
assert len(other.domain) == len(self.domain)
for index in xrange(len(self.domain)):
assert other.domain[index] == self.domain[index]
assert len(other.field_type) == len(self.field_type)
for index in xrange(len(self.field_type)):
assert other.field_type[index] == self.field_type[index]
except AssertionError:
......
......@@ -21,11 +21,9 @@
from __future__ import division
from linear_operator import LinearOperator,\
LinearOperatorParadict
from linear_operator import LinearOperator
from square_operator import SquareOperator,\
SquareOperatorParadict
from endomorphic_operator import EndomorphicOperator
from nifty_operators import operator,\
diagonal_operator,\
......
# -*- coding: utf-8 -*-
from diagonal_operator import DiagonalOperator
# -*- coding: utf-8 -*-
import numpy as np
from d2o import distributed_data_object,\
STRATEGIES as DISTRIBUTION_STRATEGIES
from nifty.config import about,\
nifty_configuration as gc
from nifty.field import Field
from nifty.operators.endomorphic_operator import EndomorphicOperator
class DiagonalOperator(EndomorphicOperator):
# ---Overwritten properties and methods---
def __init__(self, domain=(), field_type=(), implemented=False,
diagonal=None, bare=False, datamodel=None, copy=True):
super(DiagonalOperator, self).__init__(domain=domain,
field_type=field_type,
implemented=implemented)
if datamodel is None:
if isinstance(diagonal, distributed_data_object):
datamodel = diagonal.distribution_strategy
elif isinstance(diagonal, Field):
datamodel = diagonal.datamodel
self.datamodel = self._parse_datamodel(datamodel=datamodel,
val=diagonal)
self.set_diagonal(diagonal=diagonal, bare=bare, copy=copy)
def _times(self, x, spaces, types):
pass
def _adjoint_times(self, x, spaces, types):
pass
def _inverse_times(self, x, spaces, types):
pass
def _adjoint_inverse_times(self, x, spaces, types):
pass
def _inverse_adjoint_times(self, x, spaces, types):
pass
def diagonal(self, bare=False, copy=True):
if bare:
diagonal = self._diagonal.weight(power=-1)
elif copy:
diagonal = self._diagonal.copy()
else:
diagonal = self._diagonal
return diagonal
def inverse_diagonal(self, bare=False):
return 1/self.diagonal(bare=bare, copy=False)
def trace(self, bare=False):
return self.diagonal(bare=bare, copy=False).sum()
def inverse_trace(self, bare=False):
return self.inverse_diagonal(bare=bare, copy=False).sum()
def trace_log(self):
log_diagonal = self.diagonal(copy=False).apply_scalar_function(np.log)
return log_diagonal.sum()
def determinant(self):
return self.diagonal(copy=False).val.prod()
def inverse_determinant(self):
return 1/self.determinant()
def log_determinant(self):
return np.log(self.determinant())
# ---Mandatory properties and methods---
@property
def symmetric(self):
return self._symmetric
@property
def unitary(self):
return self._unitary
# ---Added properties and methods---
@property
def datamodel(self):
return self._datamodel
def _parse_datamodel(self, datamodel, val):
if datamodel is None:
if isinstance(val, distributed_data_object):
datamodel = val.distribution_strategy
elif isinstance(val, Field):
datamodel = val.datamodel
else:
about.warnings.cprint("WARNING: Datamodel set to default!")
datamodel = gc['default_datamodel']
elif datamodel not in DISTRIBUTION_STRATEGIES['all']:
raise ValueError(about._errors.cstring(
"ERROR: Invalid datamodel!"))
return datamodel
def set_diagonal(self, diagonal, bare=False, copy=True):
# use the casting functionality from Field to process `diagonal`
f = Field(domain=self.domain,
val=diagonal,
field_type=self.field_type,
datamodel=self.datamodel,
copy=copy)
# weight if the given values were `bare`
f.weight(inplace=True)
# check if the operator is symmetric:
self._symmetric = (f.val.imag == 0).all()
# check if the operator is unitary:
self._unitary = (f.val * f.val.conjugate() == 1).all()
# store the diagonal-field
self._diagonal = f
# -*- coding: utf-8 -*-
from endmorphic_operator import EndomorphicOperator
# -*- coding: utf-8 -*-
from nifty.config import about
from nifty.operators.linear_operator import LinearOperator
from square_operator_paradict import SquareOperatorParadict
class SquareOperator(LinearOperator):
import abc
def __init__(self, domain=None, target=None,
field_type=None, field_type_target=None,
implemented=False, symmetric=False, unitary=False):
if target is not None:
about.warnings.cprint(
"WARNING: Discarding given target for SquareOperator.")
target = domain
from nifty.operators.linear_operator import LinearOperator
if field_type_target is not None:
about.warnings.cprint(
"WARNING: Discarding given field_type_target for "
"SquareOperator.")
field_type_target = field_type
LinearOperator.__init__(self,
domain=domain,
target=target,
field_type=field_type,
field_type_target=field_type_target,
implemented=implemented)
class EndomorphicOperator(LinearOperator):
__metaclass__ = abc.ABCMeta
self.paradict = SquareOperatorParadict(symmetric=symmetric,
unitary=unitary)
# ---Overwritten properties and methods---
def inverse_times(self, x, spaces=None, types=None):
if self.paradict['symmetric'] and self.paradict['unitary']:
return self.times(x, spaces, types)
else:
return LinearOperator.inverse_times(self,
x=x,
spaces=spaces,
types=types)
return super(EndomorphicOperator, self).inverse_times(
x=x,
spaces=spaces,
types=types)
def adjoint_times(self, x, spaces=None, types=None):
if self.paradict['symmetric']:
......@@ -47,10 +25,10 @@ class SquareOperator(LinearOperator):
elif self.paradict['unitary']:
return self.inverse_times(x, spaces, types)
else:
return LinearOperator.adjoint_times(self,
x=x,
spaces=spaces,
types=types)
return super(EndomorphicOperator, self).adjoint_times(
x=x,
spaces=spaces,
types=types)
def adjoint_inverse_times(self, x, spaces=None, types=None):
if self.paradict['symmetric']:
......@@ -58,10 +36,10 @@ class SquareOperator(LinearOperator):
elif self.paradict['unitary']:
return self.times(x, spaces, types)
else:
return LinearOperator.adjoint_inverse_times(self,
x=x,
spaces=spaces,
types=types)
return super(EndomorphicOperator, self).adjoint_inverse_times(
x=x,
spaces=spaces,
types=types)
def inverse_adjoint_times(self, x, spaces=None, types=None):
if self.paradict['symmetric']:
......@@ -69,10 +47,30 @@ class SquareOperator(LinearOperator):
elif self.paradict['unitary']:
return self.times(x, spaces, types)
else:
return LinearOperator.inverse_adjoint_times(self,
x=x,
spaces=spaces,
types=types)
return super(EndomorphicOperator, self).inverse_adjoint_times(
x=x,
spaces=spaces,
types=types)
# ---Mandatory properties and methods---
@property
def target(self):
return self.domain
@property
def field_type_target(self):
return self.field_type
# ---Added properties and methods---
@abc.abstractproperty
def symmetric(self):
raise NotImplementedError
@abc.abstractproperty
def unitary(self):
raise NotImplementedError
def trace(self):
pass
......
# -*- coding: utf-8 -*-
import abc
from nifty.config import about
from nifty.field import Field
from nifty.spaces import Space
from nifty.field_types import FieldType
import nifty.nifty_utilities as utilities
from linear_operator_paradict import LinearOperatorParadict
class LinearOperator(object):
__metaclass__ = abc.ABCMeta
def __init__(self, domain=None, target=None,
field_type=None, field_type_target=None,
implemented=False, symmetric=False, unitary=False):
self.paradict = LinearOperatorParadict()
def __init__(self, domain=(), field_type=(), implemented=False):
self._domain = self._parse_domain(domain)
self._field_type = self._parse_field_type(field_type)
self._implemented = bool(implemented)
self.domain = self._parse_domain(domain)
self.target = self._parse_domain(target)
@property
def domain(self):
return self._domain
@abc.abstractproperty
def target(self):
raise NotImplementedError
@property
def field_type(self):
return self._field_type
self.field_type = self._parse_field_type(field_type)
self.field_type_target = self._parse_field_type(field_type_target)
@abc.abstractproperty
def field_type_target(self):
raise NotImplementedError
def _parse_domain(self, domain):
if domain is None:
......
# -*- coding: utf-8 -*-
from nifty.paradict import Paradict
class LinearOperatorParadict(Paradict):
pass
# -*- coding: utf-8 -*-
from square_operator import SquareOperator
from square_operator_paradict import SquareOperatorParadict
# -*- coding: utf-8 -*-
from nifty.config import about
from nifty.operators.linear_operator import LinearOperatorParadict
class SquareOperatorParadict(LinearOperatorParadict):
def __init__(self, symmetric, unitary):
LinearOperatorParadict.__init__(self,
symmetric=symmetric,
unitary=unitary)
def __setitem__(self, key, arg):
if key not in ['symmetric', 'unitary']:
raise ValueError(about._errors.cstring(
"ERROR: Unsupported SquareOperator parameter: " + key))
if key == 'symmetric':
temp = bool(arg)
elif key == 'unitary':
temp = bool(arg)
self.parameters.__setitem__(key, temp)
# -*- coding: utf-8 -*-
import numpy as np
class Random(object):
@staticmethod
def pm1(dtype=np.dtype('int'), shape=1):
size = int(np.prod(shape))
if issubclass(dtype.type, np.complexfloating):
x = np.array([1 + 0j, 0 + 1j, -1 + 0j, 0 - 1j], dtype=dtype)
x = x[np.random.randint(4, high=None, size=size)]
else:
x = 2 * np.random.randint(2, high=None, size=size) - 1
return x.astype(dtype).reshape(shape)
@staticmethod
def normal(dtype=np.dtype('float64'), shape=(1,), mean=None, std=None):
size = int(np.prod(shape))
if issubclass(dtype.type, np.complexfloating):
x = np.empty(size, dtype=dtype)
x.real = np.random.normal(loc=0, scale=np.sqrt(0.5), size=size)
x.imag = np.random.normal(loc=0, scale=np.sqrt(0.5), size=size)
else:
x = np.random.normal(loc=0, scale=1, size=size)
x = x.astype(dtype, copy=False)
x = x.reshape(shape)
if std is not None:
x *= dtype.type(std)
if mean is not None:
x += dtype.type(mean)
return x
@staticmethod
def uniform(dtype=np.dtype('float64'), shape=1, low=0, high=1):
size = int(np.prod(shape))
if issubclass(dtype.type, np.complexfloating):
x = np.empty(size, dtype=dtype)
x.real = (high - low) * np.random.random(size=size) + low
x.imag = (high - low) * np.random.random(size=size) + low
elif dtype in [np.dtype('int8'), np.dtype('int16'), np.dtype('int32'),
np.dtype('int64')]:
x = np.random.random_integers(min(low, high),
high=max(low, high),