Commit 9a4d6183 authored by Silvan Streit's avatar Silvan Streit
Browse files

Test models: WIP consistency check

* fails for constant models
parent bfd57f63
......@@ -25,15 +25,70 @@ import numpy as np
class Model_Tests(unittest.TestCase):
@expand(product([ift.GLSpace(15),
def make_model(self, type, **kwargs):
if type == 'Constant':
np.random.seed(kwargs['seed'])
S = ift.ScalingOperator(1., kwargs['space'])
s = S.draw_sample()
return ift.Constant(
ift.MultiField.from_dict({kwargs['space_key']: s}),
ift.MultiField.from_dict({kwargs['space_key']: s}))
elif type == 'Variable':
np.random.seed(kwargs['seed'])
S = ift.ScalingOperator(1., kwargs['space'])
s = S.draw_sample()
return ift.Variable(
ift.MultiField.from_dict({kwargs['space_key']: s}))
elif type == 'LinearModel':
return ift.LinearModel(
inp=kwargs['model'], lin_op=kwargs['lin_op'])
else:
raise ValueError('unknown type passed')
def make_linear_operator(self, type, **kwargs):
if type == 'ScalingOperator':
lin_op = ift.ScalingOperator(1., kwargs['space'])
else:
raise ValueError('unknown type passed')
return lin_op
@expand(product(
['Variable', 'Constant'],
[ift.GLSpace(15),
ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)],
[4, 78, 23]
))
def testBasics(self, type1, space, seed):
model1 = self.make_model(
type1, space_key='s1', space=space, seed=seed)['s1']
ift.extra.check_value_gradient_consistency(model1)
@expand(product(
['Variable', 'Constant'],
['Variable'],
[ift.GLSpace(15),
ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)],
[4, 78, 23]
))
def testMul(self, type1, type2, space, seed):
model1 = self.make_model(
type1, space_key='s1', space=space, seed=seed)['s1']
model2 = self.make_model(
type2, space_key='s2', space=space, seed=seed+1)['s2']
ift.extra.check_value_gradient_consistency(model1*model2)
@expand(product(
['Variable', 'Constant'],
[ift.GLSpace(15),
ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)],
[4, 78, 23]))
def testMul(self, space, seed):
np.random.seed(seed)
S = ift.ScalingOperator(1., space)
s1 = S.draw_sample()
s2 = S.draw_sample()
s1_var = ift.Variable(ift.MultiField.from_dict({'s1': s1}))['s1']
s2_var = ift.Variable(ift.MultiField.from_dict({'s2': s2}))['s2']
ift.extra.check_value_gradient_consistency(s1_var*s2_var)
[4, 78, 23]
))
def testLinModel(self, type1, space, seed):
model1 = self.make_model(
type1, space_key='s1', space=space, seed=seed)['s1']
lin_op = self.make_linear_operator('ScalingOperator', space=space)
model2 = self.make_model('LinearModel', model=model1, lin_op=lin_op)
ift.extra.check_value_gradient_consistency(model1*model2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment