Commit 455b0dd9 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

cosmetics; start working on scipy minimizers and multifields

parent ef4479aa
...@@ -76,6 +76,8 @@ if __name__ == '__main__': ...@@ -76,6 +76,8 @@ if __name__ == '__main__':
ic_sampling = ift.GradientNormController(iteration_limit=100) ic_sampling = ift.GradientNormController(iteration_limit=100)
ic_newton = ift.GradientNormController(name='Newton', iteration_limit=100) ic_newton = ift.GradientNormController(name='Newton', iteration_limit=100)
minimizer = ift.RelaxedNewton(ic_newton) minimizer = ift.RelaxedNewton(ic_newton)
minimizer = ift.NewtonCG(1e-5, 10, True)
minimizer = ift.L_BFGS_B(1e-10, 1e-5, 100, 10, True)
# build model Hamiltonian # build model Hamiltonian
H = ift.Hamiltonian(likelihood, ic_sampling) H = ift.Hamiltonian(likelihood, ic_sampling)
......
...@@ -18,24 +18,53 @@ ...@@ -18,24 +18,53 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import numpy as np
from .. import dobj from .. import dobj
from ..compat import * from ..compat import *
from ..field import Field from ..field import Field
from ..multi_field import MultiField
from ..domain_tuple import DomainTuple
from ..logger import logger from ..logger import logger
from .iteration_controller import IterationController from .iteration_controller import IterationController
from .minimizer import Minimizer from .minimizer import Minimizer
def _multiToArray(fld):
szall = 0
for val in fld.values():
if val.dtype != np.float64:
raise TypeError("need float64 fields")
szall += val.size
res = np.empty(szall, dtype=np.float64)
ofs = 0
for val in fld.values():
res[ofs:ofs+val.size] = val.local_data.reshape(-1)
ofs += val.size
return res
def _toArray(fld): def _toArray(fld):
return fld.to_global_data().reshape(-1) if isinstance(fld, Field):
return fld.local_data.reshape(-1)
return _multiToArray(fld)
def _toArray_rw(fld): def _toArray_rw(fld):
return fld.to_global_data_rw().reshape(-1) if isinstance(fld, Field):
return fld.local_data.copy().reshape(-1)
return _multiToArray(fld)
def _toField(arr, dom): def _toField(arr, dom):
return Field.from_global_data(dom, arr.reshape(dom.shape).copy()) if isinstance(dom, DomainTuple):
return Field.from_local_data(dom, arr.reshape(dom.shape).copy())
ofs = 0
res = []
for d in dom.domains():
res.append(Field.from_local_data(
d, arr[ofs:ofs+d.size].copy().reshape(d.shape)))
ofs += d.size
return MultiField(dom, tuple(res))
class _MinHelper(object): class _MinHelper(object):
......
...@@ -132,18 +132,16 @@ class _OpProd(Operator): ...@@ -132,18 +132,16 @@ class _OpProd(Operator):
from ..linearization import Linearization from ..linearization import Linearization
from ..sugar import makeOp from ..sugar import makeOp
lin = isinstance(x, Linearization) lin = isinstance(x, Linearization)
v = x._val if lin else x
v1 = v.extract(self._op1.domain)
v2 = v.extract(self._op2.domain)
if not lin: if not lin:
r1 = self._op1(x.extract(self._op1.domain)) return self._op1(v1) * self._op2(v2)
r2 = self._op2(x.extract(self._op2.domain)) lin1 = self._op1(Linearization.make_var(v1))
return r1*r2 lin2 = self._op2(Linearization.make_var(v2))
lin1 = self._op1(
Linearization.make_var(x._val.extract(self._op1.domain)))
lin2 = self._op2(
Linearization.make_var(x._val.extract(self._op2.domain)))
op = (makeOp(lin1._val)(lin2._jac))._myadd( op = (makeOp(lin1._val)(lin2._jac))._myadd(
makeOp(lin2._val)(lin1._jac), False) makeOp(lin2._val)(lin1._jac), False)
jac = op(x.jac) return Linearization(lin1._val*lin2._val, op(x.jac))
return Linearization(lin1._val*lin2._val, jac)
class _OpSum(_CombinedOperator): class _OpSum(_CombinedOperator):
......
...@@ -100,10 +100,7 @@ class Test_Minimizers(unittest.TestCase): ...@@ -100,10 +100,7 @@ class Test_Minimizers(unittest.TestCase):
def __init__(self, loc): def __init__(self, loc):
self._loc = loc.to_global_data_rw() self._loc = loc.to_global_data_rw()
self._capability = self.TIMES self._capability = self.TIMES
self._domain = space
@property
def domain(self):
return space
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment