Commit d57b2f02 authored by theos's avatar theos
Browse files

Restructured new operator classes -> Introduced endomorphic_operator.py in...

Restructured new operator classes -> Introduced endomorphic_operator.py in order to avoid violation of Liskov's SP for square- and diagonal operator.
Removed paradict from operators.
Added from_random as static method to Field.
parent 88d89ed6
...@@ -45,7 +45,7 @@ from paradict import Paradict ...@@ -45,7 +45,7 @@ from paradict import Paradict
# TODO: Remove this once the transition to field types is done. # TODO: Remove this once the transition to field types is done.
from spaces.space import Space as point_space from spaces.space import Space as point_space
from nifty_random import random from random import Random
from nifty_simple_math import * from nifty_simple_math import *
from nifty_utilities import * from nifty_utilities import *
......
...@@ -2,11 +2,11 @@ from __future__ import division ...@@ -2,11 +2,11 @@ from __future__ import division
import numpy as np import numpy as np
import pylab as pl import pylab as pl
from d2o import distributed_data_object, \ from d2o import distributed_data_object,\
STRATEGIES as DISTRIBUTION_STRATEGIES STRATEGIES as DISTRIBUTION_STRATEGIES
from nifty.config import about, \ from nifty.config import about,\
nifty_configuration as gc, \ nifty_configuration as gc,\
dependency_injector as gdi dependency_injector as gdi
from nifty.field_types import FieldType,\ from nifty.field_types import FieldType,\
...@@ -15,6 +15,8 @@ from nifty.field_types import FieldType,\ ...@@ -15,6 +15,8 @@ from nifty.field_types import FieldType,\
from nifty.spaces.space import Space from nifty.spaces.space import Space
import nifty.nifty_utilities as utilities import nifty.nifty_utilities as utilities
from nifty.random import Random
POINT_DISTRIBUTION_STRATEGIES = DISTRIBUTION_STRATEGIES['global'] POINT_DISTRIBUTION_STRATEGIES = DISTRIBUTION_STRATEGIES['global']
COMM = getattr(gdi[gc['mpi_module']], gc['default_comm']) COMM = getattr(gdi[gc['mpi_module']], gc['default_comm'])
...@@ -24,20 +26,11 @@ class Field(object): ...@@ -24,20 +26,11 @@ class Field(object):
# ---Initialization methods--- # ---Initialization methods---
def __init__(self, domain=None, val=None, dtype=None, field_type=None, def __init__(self, domain=None, val=None, dtype=None, field_type=None,
datamodel=None, copy=False): datamodel=None, copy=False):
if isinstance(val, Field):
if domain is None:
domain = val.domain
if dtype is None:
dtype = val.dtype
if field_type is None:
field_type = val.field_type
if datamodel is None:
datamodel = val.datamodel
self.domain = self._parse_domain(domain=domain) self.domain = self._parse_domain(domain=domain, val=val)
self.domain_axes = self._get_axes_tuple(self.domain) self.domain_axes = self._get_axes_tuple(self.domain)
self.field_type = self._parse_field_type(field_type) self.field_type = self._parse_field_type(field_type, val=val)
try: try:
start = len(reduce(lambda x, y: x+y, self.domain_axes)) start = len(reduce(lambda x, y: x+y, self.domain_axes))
...@@ -55,8 +48,11 @@ class Field(object): ...@@ -55,8 +48,11 @@ class Field(object):
self.set_val(new_val=val, copy=copy) self.set_val(new_val=val, copy=copy)
def _parse_domain(self, domain): def _parse_domain(self, domain, val):
if domain is None: if domain is None:
if isinstance(val, Field):
domain = val.domain
else:
domain = () domain = ()
elif not isinstance(domain, tuple): elif not isinstance(domain, tuple):
domain = (domain,) domain = (domain,)
...@@ -67,8 +63,11 @@ class Field(object): ...@@ -67,8 +63,11 @@ class Field(object):
"nifty.space.")) "nifty.space."))
return domain return domain
def _parse_field_type(self, field_type): def _parse_field_type(self, field_type, val):
if field_type is None: if field_type is None:
if isinstance(val, Field):
field_type = val.field_type
else:
field_type = () field_type = ()
elif not isinstance(field_type, tuple): elif not isinstance(field_type, tuple):
field_type = (field_type,) field_type = (field_type,)
...@@ -89,8 +88,11 @@ class Field(object): ...@@ -89,8 +88,11 @@ class Field(object):
axes_list += [tuple(l)] axes_list += [tuple(l)]
return tuple(axes_list) return tuple(axes_list)
def _infer_dtype(self, dtype=None, domain=None, field_type=None): def _infer_dtype(self, dtype, val, domain, field_type):
if dtype is None: if dtype is None:
if isinstance(val, Field) or \
isinstance(val, distributed_data_object):
dtype = val.dtype
dtype_tuple = (np.dtype(gc['default_field_dtype']),) dtype_tuple = (np.dtype(gc['default_field_dtype']),)
else: else:
dtype_tuple = (np.dtype(dtype),) dtype_tuple = (np.dtype(dtype),)
...@@ -100,18 +102,75 @@ class Field(object): ...@@ -100,18 +102,75 @@ class Field(object):
dtype_tuple += tuple(np.dtype(ft.dtype) for ft in field_type) dtype_tuple += tuple(np.dtype(ft.dtype) for ft in field_type)
dtype = reduce(lambda x, y: np.result_type(x, y), dtype_tuple) dtype = reduce(lambda x, y: np.result_type(x, y), dtype_tuple)
return dtype return dtype
def _parse_datamodel(self, datamodel, val): def _parse_datamodel(self, datamodel, val):
if datamodel in DISTRIBUTION_STRATEGIES['all']: if datamodel is None:
pass if isinstance(val, distributed_data_object):
elif isinstance(val, distributed_data_object):
datamodel = val.distribution_strategy datamodel = val.distribution_strategy
elif isinstance(val, Field):
datamodel = val.datamodel
else: else:
about.warnings.cprint("WARNING: Datamodel set to default!")
datamodel = gc['default_datamodel'] datamodel = gc['default_datamodel']
elif datamodel not in DISTRIBUTION_STRATEGIES['all']:
raise ValueError(about._errors.cstring(
"ERROR: Invalid datamodel!"))
return datamodel return datamodel
# ---Factory methods---
@classmethod
def from_random(cls, random_type, domain=None, dtype=None, field_type=None,
datamodel=None, **kwargs):
# create a initially empty field
f = cls(domain=domain, dtype=dtype, field_type=field_type,
datamodel=datamodel)
# now use the processed input in terms of f in order to parse the
# random arguments
random_arguments = cls._parse_random_arguments(random_type=random_type,
f=f,
**kwargs)
# extract the distributed_dato_object from f and apply the appropriate
# random number generator to it
sample = f.get_val(copy=False)
generator_function = getattr(Random, random_type)
sample.apply_generator(
lambda shape: generator_function(dtype=f.dtype,
shape=shape,
**random_arguments))
return f
@staticmethod
def _parse_random_arguments(random_type, f, **kwargs):
if random_type == "pm1":
random_arguments = {}
elif random_type == "normal":
mean = kwargs.get('mean', 0)
std = kwargs.get('std', 1)
random_arguments = {'mean': mean,
'std': std}
elif random_type == "uniform":
low = kwargs.get('low', 0)
high = kwargs.get('high', 1)
random_arguments = {'low': low,
'high': high}
# elif random_type == 'syn':
# pass
else:
raise KeyError(about._errors.cstring(
"ERROR: unsupported random key '" + str(random_type) + "'."))
return random_arguments
# ---Properties--- # ---Properties---
def set_val(self, new_val=None, copy=False): def set_val(self, new_val=None, copy=False):
new_val = self.cast(new_val) new_val = self.cast(new_val)
...@@ -462,6 +521,7 @@ class Field(object): ...@@ -462,6 +521,7 @@ class Field(object):
assert len(other.domain) == len(self.domain) assert len(other.domain) == len(self.domain)
for index in xrange(len(self.domain)): for index in xrange(len(self.domain)):
assert other.domain[index] == self.domain[index] assert other.domain[index] == self.domain[index]
assert len(other.field_type) == len(self.field_type)
for index in xrange(len(self.field_type)): for index in xrange(len(self.field_type)):
assert other.field_type[index] == self.field_type[index] assert other.field_type[index] == self.field_type[index]
except AssertionError: except AssertionError:
......
...@@ -21,11 +21,9 @@ ...@@ -21,11 +21,9 @@
from __future__ import division from __future__ import division
from linear_operator import LinearOperator,\ from linear_operator import LinearOperator
LinearOperatorParadict
from square_operator import SquareOperator,\ from endomorphic_operator import EndomorphicOperator
SquareOperatorParadict
from nifty_operators import operator,\ from nifty_operators import operator,\
diagonal_operator,\ diagonal_operator,\
......
# -*- coding: utf-8 -*-
from diagonal_operator import DiagonalOperator
# -*- coding: utf-8 -*-
import numpy as np
from d2o import distributed_data_object,\
STRATEGIES as DISTRIBUTION_STRATEGIES
from nifty.config import about,\
nifty_configuration as gc
from nifty.field import Field
from nifty.operators.endomorphic_operator import EndomorphicOperator
class DiagonalOperator(EndomorphicOperator):
# ---Overwritten properties and methods---
def __init__(self, domain=(), field_type=(), implemented=False,
diagonal=None, bare=False, datamodel=None, copy=True):
super(DiagonalOperator, self).__init__(domain=domain,
field_type=field_type,
implemented=implemented)
if datamodel is None:
if isinstance(diagonal, distributed_data_object):
datamodel = diagonal.distribution_strategy
elif isinstance(diagonal, Field):
datamodel = diagonal.datamodel
self.datamodel = self._parse_datamodel(datamodel=datamodel,
val=diagonal)
self.set_diagonal(diagonal=diagonal, bare=bare, copy=copy)
def _times(self, x, spaces, types):
pass
def _adjoint_times(self, x, spaces, types):
pass
def _inverse_times(self, x, spaces, types):
pass
def _adjoint_inverse_times(self, x, spaces, types):
pass
def _inverse_adjoint_times(self, x, spaces, types):
pass
def diagonal(self, bare=False, copy=True):
if bare:
diagonal = self._diagonal.weight(power=-1)
elif copy:
diagonal = self._diagonal.copy()
else:
diagonal = self._diagonal
return diagonal
def inverse_diagonal(self, bare=False):
return 1/self.diagonal(bare=bare, copy=False)
def trace(self, bare=False):
return self.diagonal(bare=bare, copy=False).sum()
def inverse_trace(self, bare=False):
return self.inverse_diagonal(bare=bare, copy=False).sum()
def trace_log(self):
log_diagonal = self.diagonal(copy=False).apply_scalar_function(np.log)
return log_diagonal.sum()
def determinant(self):
return self.diagonal(copy=False).val.prod()
def inverse_determinant(self):
return 1/self.determinant()
def log_determinant(self):
return np.log(self.determinant())
# ---Mandatory properties and methods---
@property
def symmetric(self):
return self._symmetric
@property
def unitary(self):
return self._unitary
# ---Added properties and methods---
@property
def datamodel(self):
return self._datamodel
def _parse_datamodel(self, datamodel, val):
if datamodel is None:
if isinstance(val, distributed_data_object):
datamodel = val.distribution_strategy
elif isinstance(val, Field):
datamodel = val.datamodel
else:
about.warnings.cprint("WARNING: Datamodel set to default!")
datamodel = gc['default_datamodel']
elif datamodel not in DISTRIBUTION_STRATEGIES['all']:
raise ValueError(about._errors.cstring(
"ERROR: Invalid datamodel!"))
return datamodel
def set_diagonal(self, diagonal, bare=False, copy=True):
# use the casting functionality from Field to process `diagonal`
f = Field(domain=self.domain,
val=diagonal,
field_type=self.field_type,
datamodel=self.datamodel,
copy=copy)
# weight if the given values were `bare`
f.weight(inplace=True)
# check if the operator is symmetric:
self._symmetric = (f.val.imag == 0).all()
# check if the operator is unitary:
self._unitary = (f.val * f.val.conjugate() == 1).all()
# store the diagonal-field
self._diagonal = f
# -*- coding: utf-8 -*-
from endmorphic_operator import EndomorphicOperator
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from nifty.config import about import abc
from nifty.operators.linear_operator import LinearOperator
from square_operator_paradict import SquareOperatorParadict
class SquareOperator(LinearOperator):
def __init__(self, domain=None, target=None, from nifty.operators.linear_operator import LinearOperator
field_type=None, field_type_target=None,
implemented=False, symmetric=False, unitary=False):
if target is not None:
about.warnings.cprint(
"WARNING: Discarding given target for SquareOperator.")
target = domain
if field_type_target is not None:
about.warnings.cprint(
"WARNING: Discarding given field_type_target for "
"SquareOperator.")
field_type_target = field_type
LinearOperator.__init__(self, class EndomorphicOperator(LinearOperator):
domain=domain, __metaclass__ = abc.ABCMeta
target=target,
field_type=field_type,
field_type_target=field_type_target,
implemented=implemented)
self.paradict = SquareOperatorParadict(symmetric=symmetric, # ---Overwritten properties and methods---
unitary=unitary)
def inverse_times(self, x, spaces=None, types=None): def inverse_times(self, x, spaces=None, types=None):
if self.paradict['symmetric'] and self.paradict['unitary']: if self.paradict['symmetric'] and self.paradict['unitary']:
return self.times(x, spaces, types) return self.times(x, spaces, types)
else: else:
return LinearOperator.inverse_times(self, return super(EndomorphicOperator, self).inverse_times(
x=x, x=x,
spaces=spaces, spaces=spaces,
types=types) types=types)
...@@ -47,7 +25,7 @@ class SquareOperator(LinearOperator): ...@@ -47,7 +25,7 @@ class SquareOperator(LinearOperator):
elif self.paradict['unitary']: elif self.paradict['unitary']:
return self.inverse_times(x, spaces, types) return self.inverse_times(x, spaces, types)
else: else:
return LinearOperator.adjoint_times(self, return super(EndomorphicOperator, self).adjoint_times(
x=x, x=x,
spaces=spaces, spaces=spaces,
types=types) types=types)
...@@ -58,7 +36,7 @@ class SquareOperator(LinearOperator): ...@@ -58,7 +36,7 @@ class SquareOperator(LinearOperator):
elif self.paradict['unitary']: elif self.paradict['unitary']:
return self.times(x, spaces, types) return self.times(x, spaces, types)
else: else:
return LinearOperator.adjoint_inverse_times(self, return super(EndomorphicOperator, self).adjoint_inverse_times(
x=x, x=x,
spaces=spaces, spaces=spaces,
types=types) types=types)
...@@ -69,11 +47,31 @@ class SquareOperator(LinearOperator): ...@@ -69,11 +47,31 @@ class SquareOperator(LinearOperator):
elif self.paradict['unitary']: elif self.paradict['unitary']:
return self.times(x, spaces, types) return self.times(x, spaces, types)
else: else:
return LinearOperator.inverse_adjoint_times(self, return super(EndomorphicOperator, self).inverse_adjoint_times(
x=x, x=x,
spaces=spaces, spaces=spaces,
types=types) types=types)
# ---Mandatory properties and methods---
@property
def target(self):
return self.domain
@property
def field_type_target(self):
return self.field_type
# ---Added properties and methods---
@abc.abstractproperty
def symmetric(self):
raise NotImplementedError
@abc.abstractproperty
def unitary(self):
raise NotImplementedError
def trace(self): def trace(self):
pass pass
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import abc
from nifty.config import about from nifty.config import about
from nifty.field import Field from nifty.field import Field
from nifty.spaces import Space from nifty.spaces import Space
from nifty.field_types import FieldType from nifty.field_types import FieldType
import nifty.nifty_utilities as utilities import nifty.nifty_utilities as utilities
from linear_operator_paradict import LinearOperatorParadict
class LinearOperator(object): class LinearOperator(object):
__metaclass__ = abc.ABCMeta
def __init__(self, domain=None, target=None, def __init__(self, domain=(), field_type=(), implemented=False):
field_type=None, field_type_target=None, self._domain = self._parse_domain(domain)
implemented=False, symmetric=False, unitary=False): self._field_type = self._parse_field_type(field_type)
self.paradict = LinearOperatorParadict()
self._implemented = bool(implemented) self._implemented = bool(implemented)
self.domain = self._parse_domain(domain) @property
self.target = self._parse_domain(target) def domain(self):
return self._domain
@abc.abstractproperty
def target(self):
raise NotImplementedError
@property
def field_type(self):
return self._field_type
self.field_type = self._parse_field_type(field_type) @abc.abstractproperty
self.field_type_target = self._parse_field_type(field_type_target) def field_type_target(self):
raise NotImplementedError
def _parse_domain(self, domain): def _parse_domain(self, domain):
if domain is None: if domain is None:
......
# -*- coding: utf-8 -*-
from nifty.paradict import Paradict
class LinearOperatorParadict(Paradict):
pass
# -*- coding: utf-8 -*-
from square_operator import SquareOperator
from square_operator_paradict import SquareOperatorParadict
# -*- coding: utf-8 -*-
from nifty.config import about
from nifty.operators.linear_operator import LinearOperatorParadict