Commit 5d607ffc authored by Philipp Arras's avatar Philipp Arras
Browse files

Restructure

parent 32af4710
Pipeline #75652 failed with stages
in 4 minutes and 3 seconds
......@@ -485,24 +485,3 @@ class AveragedEnergy(EnergyOperator):
self._check_input(x)
mymap = map(lambda v: self._h(x+v), self._res_samples)
return utilities.my_sum(mymap)/len(self._res_samples)
class _ConstantEnergyOperator(EnergyOperator):
def __init__(self, dom, output):
from ..sugar import makeDomain
self._domain = makeDomain(dom)
if self.target is not output.domain:
raise TypeError
self._output = output
def apply(self, x):
self._check_input(x)
if x.jac is not None:
val = self._output
jac = NullOperator(self._domain, self._target)
met = NullOperator(self._domain, self._domain) if x.want_metric else None
return x.new(val, jac, met)
return self._output
def __repr__(self):
return 'ConstantEnergyOperator <- {}'.format(self.domain.keys())
......@@ -17,8 +17,9 @@
import numpy as np
from ..utilities import NiftyMeta, indent
from .. import pointwise
from ..multi_domain import MultiDomain
from ..utilities import NiftyMeta, indent
class Operator(metaclass=NiftyMeta):
......@@ -269,19 +270,34 @@ class Operator(metaclass=NiftyMeta):
return self.__class__.__name__
def simplify_for_constant_input(self, c_inp):
from .energy_operators import EnergyOperator, _ConstantEnergyOperator
from .energy_operators import EnergyOperator
from .simplify_for_const import ConstantEnergyOperator, ConstantOperator
if c_inp is None:
return None, self
if c_inp.domain == self.domain:
if isinstance(self.domain, MultiDomain):
assert isinstance(c_inp.domain, MultiDomain)
if c_inp.domain is self.domain:
if isinstance(self, EnergyOperator):
op = _ConstantEnergyOperator(self.domain, self(c_inp))
op = ConstantEnergyOperator(self.domain, self(c_inp))
else:
op = _ConstantOperator(self.domain, self(c_inp))
op = ConstantOperator(self.domain, self(c_inp))
op = ConstantOperator(self.domain, self(c_inp))
return op(c_inp), op
if isinstance(self.domain, MultiDomain) and \
set(c_inp.keys()) > set(self.domain.keys()):
raise NotImplementedError('This branch is not tested yet')
op = ConstantOperator(self.domain, self.force(c_inp))
from ..sugar import makeField
unaffected = makeField({kk: vv for kk, vv in c_inp.items() if kk not in self.domain})
for kk in unaffected:
assert kk not in self.domain
assert kk not in self.target
return op.force(c_inp), op
return self._simplify_for_constant_input_nontrivial(c_inp)
def _simplify_for_constant_input_nontrivial(self, c_inp):
return None, self
from .simplify_for_const import SlowPartialConstantOperator
return None, SlowPartialConstantOperator(self, c_inp)
def ptw(self, op, *args, **kwargs):
return _OpChain.make((_FunctionApplier(self.target, op, *args, **kwargs), self))
......@@ -295,67 +311,6 @@ for f in pointwise.ptw_dict.keys():
setattr(Operator, f, func(f))
class _ConstCollector(object):
def __init__(self):
self._const = None
self._nc = set()
def mult(self, const, fulldom):
if const is None:
self._nc |= set(fulldom)
else:
self._nc |= set(fulldom) - set(const)
if self._const is None:
from ..multi_field import MultiField
self._const = MultiField.from_dict(
{key: const[key] for key in const if key not in self._nc})
else:
from ..multi_field import MultiField
self._const = MultiField.from_dict(
{key: self._const[key]*const[key]
for key in const if key not in self._nc})
def add(self, const, fulldom):
if const is None:
self._nc |= set(fulldom.keys())
else:
from ..multi_field import MultiField
self._nc |= set(fulldom.keys()) - set(const.keys())
if self._const is None:
self._const = MultiField.from_dict(
{key: const[key]
for key in const.keys() if key not in self._nc})
else:
self._const = self._const.unite(const)
self._const = MultiField.from_dict(
{key: self._const[key]
for key in self._const if key not in self._nc})
@property
def constfield(self):
return self._const
class _ConstantOperator(Operator):
def __init__(self, dom, output):
from ..sugar import makeDomain
self._domain = makeDomain(dom)
self._target = output.domain
self._output = output
def apply(self, x):
from .simple_linear_operators import NullOperator
self._check_input(x)
if x.jac is not None:
return x.new(self._output, NullOperator(self._domain, self._target))
return self._output
def __repr__(self):
dom = self.domain.keys() if isinstance(self.domain, MultiDomain) else '()'
tgt = self.target.keys() if isinstance(self.target, MultiDomain) else '()'
return f'{tgt} <- ConstantOperator <- {dom}'
class _FunctionApplier(Operator):
def __init__(self, domain, funcname, *args, **kwargs):
from ..sugar import makeDomain
......@@ -450,16 +405,16 @@ class _OpProd(Operator):
return lin1.new(lin1._val*lin2._val, jac)
def _simplify_for_constant_input_nontrivial(self, c_inp):
from ..multi_domain import MultiDomain
from .simplify_for_const import ConstCollector
f1, o1 = self._op1.simplify_for_constant_input(
c_inp.extract_part(self._op1.domain))
f2, o2 = self._op2.simplify_for_constant_input(
c_inp.extract_part(self._op2.domain))
from ..multi_domain import MultiDomain
if not isinstance(self._target, MultiDomain):
return None, _OpProd(o1, o2)
cc = _ConstCollector()
cc = ConstCollector()
cc.mult(f1, o1.target)
cc.mult(f2, o2.target)
return cc.constfield, _OpProd(o1, o2)
......@@ -496,16 +451,16 @@ class _OpSum(Operator):
return res
def _simplify_for_constant_input_nontrivial(self, c_inp):
from ..multi_domain import MultiDomain
from .simplify_for_const import ConstCollector
f1, o1 = self._op1.simplify_for_constant_input(
c_inp.extract_part(self._op1.domain))
f2, o2 = self._op2.simplify_for_constant_input(
c_inp.extract_part(self._op2.domain))
from ..multi_domain import MultiDomain
if not isinstance(self._target, MultiDomain):
return None, _OpSum(o1, o2)
cc = _ConstCollector()
cc = ConstCollector()
cc.add(f1, o1.target)
cc.add(f2, o2.target)
return cc.constfield, _OpSum(o1, o2)
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from ..multi_domain import MultiDomain
from .energy_operators import EnergyOperator
from .operator import Operator
from .simple_linear_operators import NullOperator
class ConstCollector(object):
def __init__(self):
self._const = None
self._nc = set()
def mult(self, const, fulldom):
if const is None:
self._nc |= set(fulldom)
else:
self._nc |= set(fulldom) - set(const)
if self._const is None:
from ..multi_field import MultiField
self._const = MultiField.from_dict(
{key: const[key] for key in const if key not in self._nc})
else:
from ..multi_field import MultiField
self._const = MultiField.from_dict(
{key: self._const[key]*const[key]
for key in const if key not in self._nc})
def add(self, const, fulldom):
if const is None:
self._nc |= set(fulldom.keys())
else:
from ..multi_field import MultiField
self._nc |= set(fulldom.keys()) - set(const.keys())
if self._const is None:
self._const = MultiField.from_dict(
{key: const[key]
for key in const.keys() if key not in self._nc})
else:
self._const = self._const.unite(const)
self._const = MultiField.from_dict(
{key: self._const[key]
for key in self._const if key not in self._nc})
@property
def constfield(self):
return self._const
class ConstantOperator(Operator):
def __init__(self, dom, output):
from ..sugar import makeDomain
self._domain = makeDomain(dom)
self._target = output.domain
self._output = output
def apply(self, x):
from .simple_linear_operators import NullOperator
self._check_input(x)
if x.jac is not None:
return x.new(self._output, NullOperator(self._domain, self._target))
return self._output
def __repr__(self):
dom = self.domain.keys() if isinstance(self.domain, MultiDomain) else '()'
tgt = self.target.keys() if isinstance(self.target, MultiDomain) else '()'
return f'{tgt} <- ConstantOperator <- {dom}'
class SlowPartialConstOperator(Operator):
pass
class ConstantEnergyOperator(EnergyOperator):
def __init__(self, dom, output):
from ..sugar import makeDomain
self._domain = makeDomain(dom)
if self.target is not output.domain:
raise TypeError
self._output = output
def apply(self, x):
self._check_input(x)
if x.jac is not None:
val = self._output
jac = NullOperator(self._domain, self._target)
met = NullOperator(self._domain, self._domain) if x.want_metric else None
return x.new(val, jac, met)
return self._output
def __repr__(self):
return 'ConstantEnergyOperator <- {}'.format(self.domain.keys())
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment