Commit 11387b19 authored by Reimar H Leike's avatar Reimar H Leike
Browse files

extended model tests to test all models

parent 3f730f24
......@@ -72,12 +72,14 @@ class Model_Tests(unittest.TestCase):
ift.RGSpace([32, 32], distances=.789)],
[4, 78, 23]
))
def testMul(self, type1, type2, space, seed):
def testBinary(self, type1, type2, space, seed):
model1 = self.make_model(
type1, space_key='s1', space=space, seed=seed)['s1']
model2 = self.make_model(
type2, space_key='s2', space=space, seed=seed+1)['s2']
ift.extra.check_value_gradient_consistency(model1*model2)
ift.extra.check_value_gradient_consistency(model1+model2)
ift.extra.check_value_gradient_consistency(model1*3.)
@expand(product(
['Variable', 'Constant'],
......@@ -92,3 +94,31 @@ class Model_Tests(unittest.TestCase):
lin_op = self.make_linear_operator('ScalingOperator', space=space)
model2 = self.make_model('LinearModel', model=model1, lin_op=lin_op)
ift.extra.check_value_gradient_consistency(model1*model2)
@expand(product(
['Variable', 'Constant'],
[ift.GLSpace(15),
ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)],
[4, 78, 23]
))
def testLocalModel(self, type, space, seed):
model = self.make_model(
type, space_key='s', space=space, seed=seed)['s']
ift.extra.check_value_gradient_consistency(ift.PointwiseExponential(model))
ift.extra.check_value_gradient_consistency(ift.PointwiseTanh(model))
ift.extra.check_value_gradient_consistency(ift.PointwisePositiveTanh(model))
@expand(product(
['Variable', 'Constant'],
[ift.GLSpace(15),
ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)],
[4, 78, 23]
))
def testMultiModel(self, type, space, seed):
model = self.make_model(
type, space_key='s', space=space, seed=seed)['s']
mmodel = ift.MultiModel(model, 'g')
ift.extra.check_value_gradient_consistency(mmodel)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment