Commit 6e2cbb65 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add MultiModels

parent e0771e65
Pipeline #31760 failed with stages
in 12 minutes and 56 seconds
......@@ -3,8 +3,9 @@ from .linear import LinearModel
from .local_nonlinearity import (LocalModel, PointwiseExponential,
PointwisePositiveTanh, PointwiseTanh)
from .model import Model
from .multi_model import MultiModel
from .variable import Variable
__all__ = ['Model', 'Constant', 'LocalModel', 'Variable',
'LinearModel', 'PointwiseTanh', 'PointwisePositiveTanh',
'PointwiseExponential']
'PointwiseExponential', 'MultiModel']
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..domain_tuple import DomainTuple
from ..multi import MultiField
from ..operators import MultiAdaptor
from .model import Model
class MultiModel(Model):
def __init__(self, model, key):
# TODO Rewrite it such that it takes a dictionary as input. Just like MultiFields.
super(MultiModel, self).__init__(model.position)
self._model = model
self._key = key
val = self._model.value
if not isinstance(val.domain, DomainTuple):
raise TypeError
self._value = MultiField({key: val})
self._gradient = MultiAdaptor(self.value.domain) * self._model.gradient
def at(self, position):
return self.__class__(self._model.at(position), self._key)
......@@ -8,6 +8,7 @@ from .harmonic_transform_operator import HarmonicTransformOperator
from .inversion_enabler import InversionEnabler
from .laplace_operator import LaplaceOperator
from .linear_operator import LinearOperator
from .multi_adaptor import MultiAdaptor
from .power_distributor import PowerDistributor
from .sampling_enabler import SamplingEnabler
from .sandwich_operator import SandwichOperator
......@@ -20,4 +21,4 @@ __all__ = ["LinearOperator", "EndomorphicOperator", "ScalingOperator",
"FFTSmoothingOperator", "GeometryRemover",
"LaplaceOperator", "SmoothnessOperator", "PowerDistributor",
"InversionEnabler", "SandwichOperator", "SamplingEnabler",
"DOFDistributor", "SelectionOperator"]
"DOFDistributor", "SelectionOperator", "MultiAdaptor"]
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..multi import MultiDomain, MultiField
from .linear_operator import LinearOperator
class MultiAdaptor(LinearOperator):
def __init__(self, target):
super(MultiAdaptor, self).__init__()
if not isinstance(target, MultiDomain) or len(target) > 1:
raise TypeError
self._target = target
self._domain = list(target.values())[0]
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
key = list(self.target.keys())[0]
if mode == self.TIMES:
return MultiField({key: x})
else:
return x[key]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment