# linearization.py (11.6 KB)
 # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

from .field import Field
from .multi_field import MultiField
from .sugar import makeOp
from .operators.scaling_operator import ScalingOperator


class Linearization(object):
    """Let `A` be an operator and `x` a field. `Linearization` stores the value
    of the operator application (i.e. `A(x)`), the local Jacobian
    (i.e. `dA(x)/dx`) and, optionally, the local metric.

    Parameters
    ----------
    val : Field/MultiField
        the value of the operator application
    jac : LinearOperator
        the Jacobian
    metric : LinearOperator or None (default: None)
        the metric
    want_metric : bool (default: False)
        if True, the metric will be computed for other Linearizations
        derived from this one.

    Raises
    ------
    ValueError
        if `val.domain` differs from `jac.target`

    Notes
    -----
    All arithmetic and pointwise methods return *new* Linearization objects
    (via `new`) whose Jacobian is obtained from this one by the chain rule;
    the receiving object is never modified.
    """

    def __init__(self, val, jac, metric=None, want_metric=False):
        self._val = val
        self._jac = jac
        # The value lives on the Jacobian's target by construction.
        if self._val.domain != self._jac.target:
            raise ValueError("domain mismatch")
        self._want_metric = want_metric
        self._metric = metric

    def new(self, val, jac, metric=None):
        """Create a new Linearization, taking the `want_metric` property from
        this one.

        Parameters
        ----------
        val : Field/MultiField
            the value of the operator application
        jac : LinearOperator
            the Jacobian
        metric : LinearOperator or None (default: None)
            the metric
        """
        return Linearization(val, jac, metric, self._want_metric)

    @property
    def domain(self):
        """DomainTuple/MultiDomain : the Jacobian's domain"""
        return self._jac.domain

    @property
    def target(self):
        """DomainTuple/MultiDomain : the Jacobian's target (i.e. the value's
        domain)"""
        return self._jac.target

    @property
    def val(self):
        """Field/MultiField : the value"""
        return self._val

    @property
    def jac(self):
        """LinearOperator : the Jacobian"""
        return self._jac

    @property
    def gradient(self):
        """Field/MultiField : the gradient

        Computed as the adjoint Jacobian applied to the scalar field 1.

        Notes
        -----
        Only available if target is a scalar
        """
        return self._jac.adjoint_times(Field.scalar(1.))

    @property
    def want_metric(self):
        """bool : the value of `want_metric`"""
        return self._want_metric

    @property
    def metric(self):
        """LinearOperator : the metric

        Notes
        -----
        Only available if target is a scalar
        """
        return self._metric

    def __getitem__(self, name):
        """Extract the entry `name` from a MultiField-valued Linearization.

        The returned Linearization has value `self.val[name]` and no metric.

        Notes
        -----
        NOTE(review): the Jacobian is built as
        `ducktape(None, self.domain, name)` and is not composed with
        `self.jac`; presumably this equals the true Jacobian of
        `val[name]` only when `self.jac` is the identity (as produced by
        `make_var`). Confirm before using on general Linearizations.
        """
        from .operators.simple_linear_operators import ducktape
        return self.new(self._val[name], ducktape(None, self.domain, name))

    def __neg__(self):
        """Return `-self`; value, Jacobian and (if present) metric are
        negated."""
        return self.new(
            -self._val, -self._jac,
            None if self._metric is None else -self._metric)

    def conjugate(self):
        """Return the complex conjugate; value, Jacobian and (if present)
        metric are conjugated."""
        return self.new(
            self._val.conjugate(), self._jac.conjugate(),
            None if self._metric is None else self._metric.conjugate())

    @property
    def real(self):
        """Linearization : the real part of value and Jacobian (the metric is
        dropped)."""
        return self.new(self._val.real, self._jac.real)

    def _myadd(self, other, neg):
        """Return `self + other`, or `self - other` if `neg` is True.

        Parameters
        ----------
        other : Linearization, scalar, Field or MultiField
            For another Linearization, values are combined with
            `flexible_addsub`, Jacobians with `_myadd`, and metrics are
            combined only if both operands carry one (otherwise the result
            has no metric). For constants (Python/NumPy scalars, Field,
            MultiField) only the value is shifted; Jacobian and metric are
            kept.
        neg : bool
            subtract instead of add

        Returns
        -------
        Linearization or NotImplemented
            `NotImplemented` for unsupported operand types, so that Python
            raises a TypeError instead of the operator silently yielding
            None.
        """
        if isinstance(other, Linearization):
            met = None
            if self._metric is not None and other._metric is not None:
                met = self._metric._myadd(other._metric, neg)
            return self.new(
                self._val.flexible_addsub(other._val, neg),
                self._jac._myadd(other._jac, neg), met)
        if isinstance(other,
                      (int, float, complex, np.number, Field, MultiField)):
            if neg:
                return self.new(self._val-other, self._jac, self._metric)
            return self.new(self._val+other, self._jac, self._metric)
        return NotImplemented

    def __add__(self, other):
        """Return `self + other` (see `_myadd`)."""
        return self._myadd(other, False)

    def __radd__(self, other):
        """Return `other + self` (addition is commutative here)."""
        return self._myadd(other, False)

    def __sub__(self, other):
        """Return `self - other` (see `_myadd`)."""
        return self._myadd(other, True)

    def __rsub__(self, other):
        """Return `other - self`, computed as `(-self) + other`."""
        return (-self).__add__(other)

    def __truediv__(self, other):
        """Return `self / other`.

        A Linearization divisor is inverted via `one_over`; any other divisor
        is turned into the factor `1./other`.
        """
        if isinstance(other, Linearization):
            return self.__mul__(other.one_over())
        return self.__mul__(1./other)

    def __rtruediv__(self, other):
        """Return `other / self` as `(1/self) * other`."""
        return self.one_over().__mul__(other)

    def __pow__(self, power):
        """Return `self**power` for a scalar `power` (metric is dropped).

        The Jacobian is `power * val**(power-1)` applied to `jac`. Returns
        `NotImplemented` for non-scalar exponents.
        """
        if not np.isscalar(power):
            return NotImplemented
        return self.new(self._val**power,
                        makeOp(self._val**(power-1)).scale(power)(self._jac))

    def __mul__(self, other):
        """Return `self * other`.

        Parameters
        ----------
        other : Linearization, scalar, Field or MultiField
            - Linearization: product rule; targets must match. The metric
              is dropped.
            - scalar: value, Jacobian and metric are scaled; multiplying by
              1 returns `self` unchanged.
            - Field/MultiField: pointwise product; `other.domain` must match
              `self.target`. The metric is dropped.

        Returns
        -------
        Linearization or NotImplemented
            `NotImplemented` for unsupported operand types (previously the
            method fell through and returned None).

        Raises
        ------
        ValueError
            if the operand domains do not match
        """
        if isinstance(other, Linearization):
            if self.target != other.target:
                raise ValueError("domain mismatch")
            return self.new(
                self._val*other._val,
                (makeOp(other._val)(self._jac))._myadd(
                    makeOp(self._val)(other._jac), False))
        if np.isscalar(other):
            if other == 1:
                return self
            met = None if self._metric is None else self._metric.scale(other)
            return self.new(self._val*other, self._jac.scale(other), met)
        if isinstance(other, (Field, MultiField)):
            if self.target != other.domain:
                raise ValueError("domain mismatch")
            return self.new(self._val*other, makeOp(other)(self._jac))
        return NotImplemented

    def __rmul__(self, other):
        """Return `other * self` (multiplication is commutative here)."""
        return self.__mul__(other)

    def outer(self, other):
        """Return the outer product of `self` with `other`.

        Parameters
        ----------
        other : Linearization, scalar, Field or MultiField
            A scalar reduces to ordinary multiplication.

        Raises
        ------
        TypeError
            for unsupported operand types (previously None was returned)

        Notes
        -----
        NOTE(review): the term for `self`'s own variation is built as
        `OuterProduct(self._jac(self._val), ...)`, i.e. from the Jacobian
        *applied to the value* rather than from the Jacobian as an operator.
        Verify this is the intended derivative before relying on it.
        """
        from .operators.outer_product_operator import OuterProduct
        if isinstance(other, Linearization):
            return self.new(
                OuterProduct(self._val, other.target)(other._val),
                OuterProduct(self._jac(self._val), other.target)._myadd(
                    OuterProduct(self._val, other.target)(other._jac), False))
        if np.isscalar(other):
            return self.__mul__(other)
        if isinstance(other, (Field, MultiField)):
            return self.new(OuterProduct(self._val, other.domain)(other),
                            OuterProduct(self._jac(self._val), other.domain))
        raise TypeError("unsupported operand type for outer: {}".format(
            type(other).__name__))

    def vdot(self, other):
        """Return the scalar product `val.vdot(other)` as a scalar
        Linearization.

        For a Field/MultiField `other` only this Linearization varies; for a
        Linearization `other` both factors contribute (product rule).
        """
        from .operators.simple_linear_operators import VdotOperator
        if isinstance(other, (Field, MultiField)):
            return self.new(
                Field.scalar(self._val.vdot(other)),
                VdotOperator(other)(self._jac))
        return self.new(
            Field.scalar(self._val.vdot(other._val)),
            VdotOperator(self._val)(other._jac) +
            VdotOperator(other._val)(self._jac))

    def sum(self, spaces=None):
        """Sum the value over `spaces` (all spaces if None).

        With `spaces=None` the result value is wrapped in `Field.scalar`.
        The Jacobian is contracted by a `ContractionOperator`.
        """
        from .operators.contraction_operator import ContractionOperator
        if spaces is None:
            return self.new(
                Field.scalar(self._val.sum()),
                ContractionOperator(self._jac.target, None)(self._jac))
        else:
            return self.new(
                self._val.sum(spaces),
                ContractionOperator(self._jac.target, spaces)(self._jac))

    def integrate(self, spaces=None):
        """Integrate the value over `spaces` (all spaces if None).

        Like `sum`, but the `ContractionOperator` is built with weight
        power 1.
        """
        from .operators.contraction_operator import ContractionOperator
        if spaces is None:
            return self.new(
                Field.scalar(self._val.integrate()),
                ContractionOperator(self._jac.target, None, 1)(self._jac))
        else:
            return self.new(
                self._val.integrate(spaces),
                ContractionOperator(self._jac.target, spaces, 1)(self._jac))

    def exp(self):
        """Pointwise exponential; derivative `exp(val)`."""
        tmp = self._val.exp()
        return self.new(tmp, makeOp(tmp)(self._jac))

    def clip(self, min=None, max=None):
        """Clip the value to `[min, max]`; `None` means "no bound".

        The derivative is 1 except where the clipped value equals one of
        the given bounds, where it is 0. With no bounds `self` is returned.
        """
        # Check for "no bounds" before clipping: clipping with both bounds
        # None is wasted work and is rejected by numpy.clip in many versions.
        if (min is None) and (max is None):
            return self
        tmp = self._val.clip(min, max)
        if max is None:
            deriv = 1. - (tmp == min)
        elif min is None:
            deriv = 1. - (tmp == max)
        else:
            deriv = 1. - (tmp == min) - (tmp == max)
        return self.new(tmp, makeOp(deriv)(self._jac))

    def sin(self):
        """Pointwise sine; derivative `cos(val)`."""
        tmp = self._val.sin()
        tmp2 = self._val.cos()
        return self.new(tmp, makeOp(tmp2)(self._jac))

    def cos(self):
        """Pointwise cosine; derivative `-sin(val)`."""
        tmp = self._val.cos()
        tmp2 = - self._val.sin()
        return self.new(tmp, makeOp(tmp2)(self._jac))

    def tan(self):
        """Pointwise tangent; derivative `1/cos(val)**2`."""
        tmp = self._val.tan()
        tmp2 = 1./(self._val.cos()**2)
        return self.new(tmp, makeOp(tmp2)(self._jac))

    def sinc(self):
        """Pointwise sinc, as provided by `Field.sinc`.

        The derivative of the normalized sinc `sin(pi x)/(pi x)` (NumPy
        convention) is `(cos(pi x) - sinc(x))/x`.
        """
        tmp = self._val.sinc()
        # The pi factor is needed because sinc is normalized; the previous
        # `cos(x)` form gave the wrong derivative.
        # NOTE(review): at val == 0 this evaluates 0/0, while the exact
        # derivative there is 0; callers should avoid exact zeros or patch
        # those entries.
        tmp2 = ((np.pi*self._val).cos()-tmp)/self._val
        return self.new(tmp, makeOp(tmp2)(self._jac))

    def log(self):
        """Pointwise natural logarithm; derivative `1/val`."""
        tmp = self._val.log()
        return self.new(tmp, makeOp(1./self._val)(self._jac))

    def sinh(self):
        """Pointwise hyperbolic sine; derivative `cosh(val)`."""
        tmp = self._val.sinh()
        tmp2 = self._val.cosh()
        return self.new(tmp, makeOp(tmp2)(self._jac))

    def cosh(self):
        """Pointwise hyperbolic cosine; derivative `sinh(val)`."""
        tmp = self._val.cosh()
        tmp2 = self._val.sinh()
        return self.new(tmp, makeOp(tmp2)(self._jac))

    def tanh(self):
        """Pointwise hyperbolic tangent; derivative `1 - tanh(val)**2`."""
        tmp = self._val.tanh()
        return self.new(tmp, makeOp(1.-tmp**2)(self._jac))

    def sigmoid(self):
        """Return `0.5*(1 + tanh(val))`; derivative `0.5*(1 - tanh(val)**2)`.
        """
        tmp = self._val.tanh()
        tmp2 = 0.5*(1.+tmp)
        return self.new(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac))

    def absolute(self):
        """Pointwise absolute value; derivative `sign(val)`."""
        tmp = self._val.absolute()
        tmp2 = self._val.sign()
        return self.new(tmp, makeOp(tmp2)(self._jac))

    def one_over(self):
        """Pointwise reciprocal `1/val`; derivative `-1/val**2`."""
        tmp = 1./self._val
        tmp2 = - tmp/self._val
        return self.new(tmp, makeOp(tmp2)(self._jac))

    def add_metric(self, metric):
        """Return a copy of this Linearization carrying `metric`."""
        return self.new(self._val, self._jac, metric)

    def with_want_metric(self):
        """Return a copy of this Linearization with `want_metric=True`."""
        return Linearization(self._val, self._jac, self._metric, True)

    @staticmethod
    def make_var(field, want_metric=False):
        """Return a Linearization of the identity at `field`.

        The Jacobian is `ScalingOperator(1., field.domain)`.
        """
        return Linearization(field, ScalingOperator(1., field.domain),
                             want_metric=want_metric)

    @staticmethod
    def make_const(field, want_metric=False):
        """Return a Linearization of a constant `field`.

        The Jacobian is a `NullOperator` from `field.domain` to itself.
        """
        from .operators.simple_linear_operators import NullOperator
        return Linearization(field, NullOperator(field.domain, field.domain),
                             want_metric=want_metric)

    @staticmethod
    def make_const_empty_input(field, want_metric=False):
        """Return a Linearization of a constant `field` with empty input.

        The Jacobian is a `NullOperator` from the empty MultiDomain to
        `field.domain`.
        """
        from .operators.simple_linear_operators import NullOperator
        from .multi_domain import MultiDomain
        return Linearization(
            field, NullOperator(MultiDomain.make({}), field.domain),
            want_metric=want_metric)

    @staticmethod
    def make_partial_var(field, constants, want_metric=False):
        """Return a Linearization of the identity at the MultiField `field`,
        treating the keys listed in `constants` as constant.

        The Jacobian is block diagonal: a zero scaling for constant keys and
        the identity for all other keys. With no constants this is
        `make_var`.
        """
        from .operators.block_diagonal_operator import BlockDiagonalOperator
        if len(constants) == 0:
            return Linearization.make_var(field, want_metric)
        else:
            ops = [ScalingOperator(0. if key in constants else 1., dom)
                   for key, dom in field.domain.items()]
            bdop = BlockDiagonalOperator(field.domain, tuple(ops))
            return Linearization(field, bdop, want_metric=want_metric)