Commit fe1ce4a4 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'NIFTy_4' into diag_hack

parents 38f2fc5c 83c74e7b
Pipeline #26417 passed with stage
in 5 minutes and 47 seconds
...@@ -31,12 +31,14 @@ class QuadraticEnergy(Energy): ...@@ -31,12 +31,14 @@ class QuadraticEnergy(Energy):
self._b = b self._b = b
if _grad is not None: if _grad is not None:
self._grad = _grad self._grad = _grad
Ax = _grad + self._b Ax = _grad if b is None else _grad + b
else: else:
Ax = self._A(self.position) Ax = self._A(self.position)
self._grad = Ax - self._b self._grad = Ax if b is None else Ax - b
self._grad.lock() self._grad.lock()
self._value = 0.5*self.position.vdot(Ax) - b.vdot(self.position) self._value = 0.5*self.position.vdot(Ax)
if b is not None:
self._value -= b.vdot(self.position)
def at(self, position): def at(self, position):
return QuadraticEnergy(position=position, A=self._A, b=self._b) return QuadraticEnergy(position=position, A=self._A, b=self._b)
......
...@@ -84,6 +84,8 @@ def parse_spaces(spaces, nspc): ...@@ -84,6 +84,8 @@ def parse_spaces(spaces, nspc):
spaces = (safe_cast(int, spaces),) spaces = (safe_cast(int, spaces),)
else: else:
spaces = tuple(safe_cast(int, item) for item in spaces) spaces = tuple(safe_cast(int, item) for item in spaces)
if len(spaces) == 0:
return spaces
tmp = tuple(set(spaces)) tmp = tuple(set(spaces))
if tmp[0] < 0 or tmp[-1] >= nspc: if tmp[0] < 0 or tmp[-1] >= nspc:
raise ValueError("space index out of range") raise ValueError("space index out of range")
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment