Commit 437a0d80 authored by Martin Reinecke's avatar Martin Reinecke

more tests/enable CorrelatedField

parent 5102e573
......@@ -43,7 +43,7 @@ if __name__ == '__main__':
correlated_field = lambda inp: ht(power_distributor(A(inp))*inp["xi"])
# alternatively to the block above one can do:
# correlated_field,_ = ift.make_correlated_field(position_space, A)
#correlated_field = ift.CorrelatedField(position_space, A)
# apply some nonlinearity
signal = lambda inp: correlated_field(inp).positive_tanh()
......
......@@ -74,7 +74,7 @@ from .library.los_response import LOSResponse
# from .library.point_sources import PointSources
from .library.poissonian_energy import PoissonianEnergy
from .library.wiener_filter_curvature import WienerFilterCurvature
# from .library.correlated_fields import (make_correlated_field,
from .library.correlated_fields import CorrelatedField
# make_mf_correlated_field)
from .library.bernoulli_energy import BernoulliEnergy
......
......@@ -508,7 +508,7 @@ def transpose(arr):
sz = rsz[i]//arr._data.itemsize
arrnew[:, lo:hi] = rbuf[ofs:ofs+sz].reshape(hi-lo, sz2).T
ofs += sz
return from_local_data((arr.shape[1],arr.shape[0]), arrnew, 0)
return from_local_data((arr.shape[1], arr.shape[0]), arrnew, 0)
def default_distaxis():
......
......@@ -20,19 +20,17 @@ from __future__ import absolute_import, division, print_function
from ..compat import *
from ..domain_tuple import DomainTuple
from ..field import Field
from ..models.local_nonlinearity import PointwiseExponential
from ..models.variable import Variable
from ..multi.multi_field import MultiField
from ..multi.multi_domain import MultiDomain
from ..operators.domain_distributor import DomainDistributor
from ..operators.hartley_operator import HartleyOperator
from ..operators.harmonic_transform_operator import HarmonicTransformOperator
from ..operators.power_distributor import PowerDistributor
from ..operator import Operator
def make_correlated_field(s_space, amplitude_model):
class CorrelatedField(Operator):
'''
Method for construction of correlated fields
Class for construction of correlated fields
Parameters
----------
......@@ -40,53 +38,58 @@ def make_correlated_field(s_space, amplitude_model):
amplitude_model : model for correlation structure
'''
h_space = s_space.get_default_codomain()
ht = HarmonicTransformOperator(h_space, s_space)
p_space = amplitude_model.value.domain[0]
power_distributor = PowerDistributor(h_space, p_space)
def __init__(self, s_space, amplitude_model):
self._s_space = s_space
self._amplitude_model = amplitude_model
self._h_space = s_space.get_default_codomain()
self._ht = HarmonicTransformOperator(self._h_space, s_space)
self._p_space = amplitude_model.target[0]
self._power_distributor = PowerDistributor(self._h_space,
self._p_space)
position = MultiField.from_dict({'xi': Field.full(h_space, 0.)})
xi = Variable(position)['xi']
@property
def domain(self):
return MultiDomain.union(
(self._amplitude_model.domain,
MultiDomain.make({"xi": self._h_space})))
A = power_distributor(amplitude_model)
correlated_field_h = A * xi
correlated_field = ht(correlated_field_h)
internals = {'correlated_field_h': correlated_field_h,
'power_distributor': power_distributor,
'ht': ht}
return correlated_field, internals
def __call__(self, x):
A = self._power_distributor(self._amplitude_model(x))
correlated_field_h = A * x["xi"]
correlated_field = self._ht(correlated_field_h)
return correlated_field
def make_mf_correlated_field(s_space_spatial, s_space_energy,
amplitude_model_spatial, amplitude_model_energy):
'''
Method for construction of correlated multi-frequency fields
'''
h_space_spatial = s_space_spatial.get_default_codomain()
h_space_energy = s_space_energy.get_default_codomain()
h_space = DomainTuple.make((h_space_spatial, h_space_energy))
ht1 = HarmonicTransformOperator(h_space, space=0)
ht2 = HarmonicTransformOperator(ht1.target, space=1)
ht = ht2*ht1
p_space_spatial = amplitude_model_spatial.value.domain[0]
p_space_energy = amplitude_model_energy.value.domain[0]
pd_spatial = PowerDistributor(h_space, p_space_spatial, 0)
pd_energy = PowerDistributor(pd_spatial.domain, p_space_energy, 1)
pd = pd_spatial*pd_energy
dom_distr_spatial = DomainDistributor(pd.domain, 0)
dom_distr_energy = DomainDistributor(pd.domain, 1)
a_spatial = dom_distr_spatial(amplitude_model_spatial)
a_energy = dom_distr_energy(amplitude_model_energy)
a = a_spatial*a_energy
A = pd(a)
position = MultiField.from_dict(
{'xi': Field.from_random('normal', h_space)})
xi = Variable(position)['xi']
correlated_field_h = A*xi
correlated_field = ht(correlated_field_h)
return PointwiseExponential(correlated_field)
# def make_mf_correlated_field(s_space_spatial, s_space_energy,
# amplitude_model_spatial, amplitude_model_energy):
# '''
# Method for construction of correlated multi-frequency fields
# '''
# h_space_spatial = s_space_spatial.get_default_codomain()
# h_space_energy = s_space_energy.get_default_codomain()
# h_space = DomainTuple.make((h_space_spatial, h_space_energy))
# ht1 = HarmonicTransformOperator(h_space, space=0)
# ht2 = HarmonicTransformOperator(ht1.target, space=1)
# ht = ht2*ht1
#
# p_space_spatial = amplitude_model_spatial.value.domain[0]
# p_space_energy = amplitude_model_energy.value.domain[0]
#
# pd_spatial = PowerDistributor(h_space, p_space_spatial, 0)
# pd_energy = PowerDistributor(pd_spatial.domain, p_space_energy, 1)
# pd = pd_spatial*pd_energy
#
# dom_distr_spatial = DomainDistributor(pd.domain, 0)
# dom_distr_energy = DomainDistributor(pd.domain, 1)
#
# a_spatial = dom_distr_spatial(amplitude_model_spatial)
# a_energy = dom_distr_energy(amplitude_model_energy)
# a = a_spatial*a_energy
# A = pd(a)
#
# position = MultiField.from_dict(
# {'xi': Field.from_random('normal', h_space)})
# xi = Variable(position)['xi']
# correlated_field_h = A*xi
# correlated_field = ht(correlated_field_h)
# return PointwiseExponential(correlated_field)
......@@ -25,6 +25,17 @@ import numpy as np
class Model_Tests(unittest.TestCase):
@staticmethod
def make_linearization(type, space, seed):
np.random.seed(seed)
S = ift.ScalingOperator(1., space)
s = S.draw_sample()
if type == "Constant":
return ift.Linearization.make_const(s)
elif type == "Variable":
return ift.Linearization.make_var(s)
raise ValueError('unknown type passed')
def make_model(self, type, **kwargs):
if type == 'Constant':
np.random.seed(kwargs['seed'])
......@@ -53,16 +64,15 @@ class Model_Tests(unittest.TestCase):
return lin_op
@expand(product(
['Variable', 'Constant'],
[ift.GLSpace(15),
ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)],
[4, 78, 23]
))
def testBasics(self, type1, space, seed):
model1 = self.make_model(
type1, space_key='s1', space=space, seed=seed)['s1']
ift.extra.check_value_gradient_consistency(model1)
def testBasics(self, space, seed):
var = self.make_linearization("Variable", space, seed)
model = lambda inp: inp
ift.extra.check_value_gradient_consistency(model, var.val)
@expand(product(
['Variable', 'Constant'],
......@@ -73,43 +83,33 @@ class Model_Tests(unittest.TestCase):
[4, 78, 23]
))
def testBinary(self, type1, type2, space, seed):
model1 = self.make_model(
type1, space_key='s1', space=space, seed=seed)['s1']
model2 = self.make_model(
type2, space_key='s2', space=space, seed=seed+1)['s2']
ift.extra.check_value_gradient_consistency(model1*model2)
ift.extra.check_value_gradient_consistency(model1+model2)
ift.extra.check_value_gradient_consistency(model1*3.)
dom1 = ift.MultiDomain.make({'s1': space})
lin1 = self.make_linearization(type1, dom1, seed)
dom2 = ift.MultiDomain.make({'s2': space})
lin2 = self.make_linearization(type2, dom2, seed)
@expand(product(
['Variable', 'Constant'],
[ift.GLSpace(15),
ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)],
[4, 78, 23]
))
def testLinModel(self, type1, space, seed):
model1 = self.make_model(
type1, space_key='s1', space=space, seed=seed)['s1']
lin_op = self.make_linear_operator('ScalingOperator', space=space)
model2 = self.make_model('LinearModel', model=model1, lin_op=lin_op)
ift.extra.check_value_gradient_consistency(model1*model2)
@expand(product(
['Variable', 'Constant'],
[ift.GLSpace(15),
ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)],
[4, 78, 23]
))
def testLocalModel(self, type, space, seed):
model = self.make_model(
type, space_key='s', space=space, seed=seed)['s']
ift.extra.check_value_gradient_consistency(
ift.PointwiseExponential(model))
ift.extra.check_value_gradient_consistency(ift.PointwiseTanh(model))
ift.extra.check_value_gradient_consistency(
ift.PointwisePositiveTanh(model))
dom = ift.MultiDomain.union((dom1, dom2))
model = lambda inp: inp["s1"]*inp["s2"]
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos)
model = lambda inp: inp["s1"]+inp["s2"]
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos)
model = lambda inp: inp["s1"]*3.
pos = ift.from_random("normal", dom1)
ift.extra.check_value_gradient_consistency(model, pos)
model = lambda inp: ift.ScalingOperator(2.456, space)(
inp["s1"]*inp["s2"])
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos)
model = lambda inp: ift.ScalingOperator(2.456, space)(
inp["s1"]*inp["s2"]).positive_tanh()
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos)
if isinstance(space, ift.RGSpace):
model = lambda inp: ift.FFTOperator(space)(inp["s1"]*inp["s2"])
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos)
@expand(product(
[ift.GLSpace(15),
......@@ -128,42 +128,42 @@ class Model_Tests(unittest.TestCase):
ceps_k, sm, sv, im, iv, seed):
# tests amplitude model and coorelated field model
np.random.seed(seed)
model = ift.make_amplitude_model(space, Npixdof, ceps_a, ceps_k, sm,
sv, im, iv)[0]
S = ift.ScalingOperator(1., model.position.domain)
model = model.at(S.draw_sample())
ift.extra.check_value_gradient_consistency(model)
model = ift.AmplitudeModel(space, Npixdof, ceps_a, ceps_k, sm,
sv, im, iv)
S = ift.ScalingOperator(1., model.domain)
pos = S.draw_sample()
ift.extra.check_value_gradient_consistency(model, pos)
model2 = ift.make_correlated_field(space, model)[0]
S = ift.ScalingOperator(1., model2.position.domain)
model2 = model2.at(S.draw_sample())
ift.extra.check_value_gradient_consistency(model2)
model2 = ift.CorrelatedField(space, model)
S = ift.ScalingOperator(1., model2.domain)
pos = S.draw_sample()
ift.extra.check_value_gradient_consistency(model2, pos)
@expand(product(
[ift.GLSpace(15),
ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)],
[4, 78, 23]))
def testPointModel(seld, space, seed):
S = ift.ScalingOperator(1., space)
pos = ift.MultiField.from_dict(
{'points': S.draw_sample()})
alpha = 1.5
q = 0.73
model = ift.PointSources(pos, alpha, q)
# FIXME All those cdfs and ppfs are not that accurate
ift.extra.check_value_gradient_consistency(model, tol=1e-5)
@expand(product(
['Variable', 'Constant'],
[ift.GLSpace(15),
ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)],
[4, 78, 23]
))
def testMultiModel(self, type, space, seed):
model = self.make_model(
type, space_key='s', space=space, seed=seed)['s']
mmodel = ift.MultiModel(model, 'g')
ift.extra.check_value_gradient_consistency(mmodel)
# @expand(product(
# [ift.GLSpace(15),
# ift.RGSpace(64, distances=.789),
# ift.RGSpace([32, 32], distances=.789)],
# [4, 78, 23]))
# def testPointModel(seld, space, seed):
#
# S = ift.ScalingOperator(1., space)
# pos = ift.MultiField.from_dict(
# {'points': S.draw_sample()})
# alpha = 1.5
# q = 0.73
# model = ift.PointSources(pos, alpha, q)
# # FIXME All those cdfs and ppfs are not that accurate
# ift.extra.check_value_gradient_consistency(model, tol=1e-5)
#
# @expand(product(
# ['Variable', 'Constant'],
# [ift.GLSpace(15),
# ift.RGSpace(64, distances=.789),
# ift.RGSpace([32, 32], distances=.789)],
# [4, 78, 23]
# ))
# def testMultiModel(self, type, space, seed):
# model = self.make_model(
# type, space_key='s', space=space, seed=seed)['s']
# mmodel = ift.MultiModel(model, 'g')
# ift.extra.check_value_gradient_consistency(mmodel)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment