Skip to content
Snippets Groups Projects
Commit 3f00aed2 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

minor cleanups

parent 59d7def9
No related branches found
No related tags found
No related merge requests found
...@@ -31,10 +31,10 @@ class Hamiltonian(Energy): ...@@ -31,10 +31,10 @@ class Hamiltonian(Energy):
lh: Likelihood (energy object) lh: Likelihood (energy object)
prior: prior:
""" """
super(Hamiltonian, self).__init__(lh.position) super(Hamiltonian, self).__init__(lh._position)
self._lh = lh self._lh = lh
self._ic_samp = iteration_controller_sampling self._ic_samp = iteration_controller_sampling
self._prior = GaussianEnergy(Variable(self.position)) self._prior = GaussianEnergy(Variable(self._position))
def at(self, position): def at(self, position):
return self.__class__(self._lh.at(position), self._ic_samp) return self.__class__(self._lh.at(position), self._ic_samp)
......
...@@ -32,7 +32,7 @@ class GaussianEnergy(Energy): ...@@ -32,7 +32,7 @@ class GaussianEnergy(Energy):
value = 0.5 * s.vdot(s), i.e. a log-Gauss distribution with unit value = 0.5 * s.vdot(s), i.e. a log-Gauss distribution with unit
covariance covariance
""" """
super(GaussianEnergy, self).__init__(inp.position) super(GaussianEnergy, self).__init__(inp._position)
self._inp = inp self._inp = inp
self._mean = mean self._mean = mean
self._cov = covariance self._cov = covariance
......
...@@ -35,14 +35,14 @@ class QuadraticEnergy(Energy): ...@@ -35,14 +35,14 @@ class QuadraticEnergy(Energy):
self._grad = _grad self._grad = _grad
Ax = _grad if b is None else _grad + b Ax = _grad if b is None else _grad + b
else: else:
Ax = self._A(self.position) Ax = self._A(self._position)
self._grad = Ax if b is None else Ax - b self._grad = Ax if b is None else Ax - b
self._value = 0.5*self.position.vdot(Ax) self._value = 0.5*self._position.vdot(Ax)
if b is not None: if b is not None:
self._value -= b.vdot(self.position) self._value -= b.vdot(self._position)
def at(self, position): def at(self, position):
return QuadraticEnergy(position=position, A=self._A, b=self._b) return QuadraticEnergy(position, self._A, self._b)
def at_with_grad(self, position, grad): def at_with_grad(self, position, grad):
""" Specialized version of `at`, taking also a gradient. """ Specialized version of `at`, taking also a gradient.
...@@ -63,8 +63,7 @@ class QuadraticEnergy(Energy): ...@@ -63,8 +63,7 @@ class QuadraticEnergy(Energy):
Energy Energy
Energy object at new position. Energy object at new position.
""" """
return QuadraticEnergy(position=position, A=self._A, b=self._b, return QuadraticEnergy(position, self._A, self._b, grad)
_grad=grad)
@property @property
def value(self): def value(self):
......
...@@ -35,7 +35,7 @@ class MultiModel(Model): ...@@ -35,7 +35,7 @@ class MultiModel(Model):
val = self._model.value val = self._model.value
if not isinstance(val.domain, DomainTuple): if not isinstance(val.domain, DomainTuple):
raise TypeError raise TypeError
self._value = MultiField({key: val}) self._value = MultiField.from_dict({key: val})
self._jacobian = (MultiAdaptor(self.value.domain) * self._jacobian = (MultiAdaptor(self.value.domain) *
self._model.jacobian) self._model.jacobian)
......
...@@ -47,7 +47,7 @@ class SymmetrizingOperator(EndomorphicOperator): ...@@ -47,7 +47,7 @@ class SymmetrizingOperator(EndomorphicOperator):
tmp2[lead+(slice(1, None),)] -= tmp2[lead+(slice(None, 0, -1),)] tmp2[lead+(slice(1, None),)] -= tmp2[lead+(slice(None, 0, -1),)]
if i == ax: if i == ax:
tmp = dobj.redistribute(tmp, dist=ax) tmp = dobj.redistribute(tmp, dist=ax)
return Field(self.target, val=tmp) return Field(self.target, val=tmp)
@property @property
def capability(self): def capability(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment