Commit e71ecace authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'fixup_Student' into 'NIFTy_6'

Fix a bug where studentt would only work if theta is a scalar

See merge request ift/nifty!455
parents c5c2881e f6fea72f
Pipeline #74162 passed with stages
in 20 minutes and 36 seconds
......@@ -286,7 +286,7 @@ class StudentTEnergy(EnergyOperator):
domain : `Domain` or `DomainTuple`
Domain of the operator
theta : Scalar
theta : Scalar or Field
Degree of freedom parameter for the student t distribution
......@@ -296,10 +296,10 @@ class StudentTEnergy(EnergyOperator):
def apply(self, x):
res = ((self._theta+1)/2)*(x**2/self._theta).ptw("log1p").sum()
res = (((self._theta+1)/2)*(x**2/self._theta).ptw("log1p")).sum()
if not x.want_metric:
return res
met = ScalingOperator(self.domain, (self._theta+1) / (self._theta+3))
met = makeOp((self._theta+1) / (self._theta+3), self.domain)
return res.add_metric(met)
......@@ -345,7 +345,7 @@ def makeOp(input, dom=None):
if np.isscalar(input):
if not isinstance(dom, (DomainTuple, MultiDomain)):
raise TypeError("need proper `dom` argument")
return SalingOperator(dom, input)
return ScalingOperator(dom, input)
if dom is not None:
if not dom == input.domain:
raise ValueError("domain mismatch")
......@@ -74,6 +74,9 @@ def test_studentt(field):
energy = ift.StudentTEnergy(domain=field.domain, theta=.5)
ift.extra.check_jacobian_consistency(energy, field, tol=1e-6)
theta = ift.from_random('normal',field.domain).exp()
energy = ift.StudentTEnergy(domain=field.domain, theta=theta)
ift.extra.check_jacobian_consistency(energy, field, tol=1e-6)
def test_hamiltonian_and_KL(field):
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment