Commit 5572575f authored by Martin Reinecke's avatar Martin Reinecke

adjust polynomial fitting demo

parent d32f675d
...@@ -12,14 +12,14 @@ def polynomial(coefficients, sampling_points): ...@@ -12,14 +12,14 @@ def polynomial(coefficients, sampling_points):
Parameters Parameters
---------- ----------
coefficients: Model coefficients: Field
sampling_points: Numpy array sampling_points: Numpy array
""" """
if not (isinstance(coefficients, ift.Model) if not (isinstance(coefficients, ift.Field)
and isinstance(sampling_points, np.ndarray)): and isinstance(sampling_points, np.ndarray)):
raise TypeError raise TypeError
params = coefficients.value.to_global_data() params = coefficients.to_global_data()
out = np.zeros_like(sampling_points) out = np.zeros_like(sampling_points)
for ii in range(len(params)): for ii in range(len(params)):
out += params[ii] * sampling_points**ii out += params[ii] * sampling_points**ii
...@@ -88,8 +88,7 @@ y[5] -= 0 ...@@ -88,8 +88,7 @@ y[5] -= 0
# Set up minimization problem # Set up minimization problem
p_space = ift.UnstructuredDomain(N_params) p_space = ift.UnstructuredDomain(N_params)
params = ift.Variable(ift.MultiField.from_dict( params = ift.full(p_space, 0.)
{'params': ift.full(p_space, 0.)}))['params']
R = PolynomialResponse(p_space, x) R = PolynomialResponse(p_space, x)
ift.extra.consistency_check(R) ift.extra.consistency_check(R)
...@@ -98,7 +97,9 @@ d = ift.from_global_data(d_space, y) ...@@ -98,7 +97,9 @@ d = ift.from_global_data(d_space, y)
N = ift.DiagonalOperator(ift.from_global_data(d_space, var)) N = ift.DiagonalOperator(ift.from_global_data(d_space, var))
IC = ift.GradientNormController(tol_abs_gradnorm=1e-8) IC = ift.GradientNormController(tol_abs_gradnorm=1e-8)
H = ift.Hamiltonian(ift.GaussianEnergy(R(params), d, N), IC) likelihood = lambda inp: ift.GaussianEnergy(d, N)(R(inp))
H = ift.Hamiltonian(likelihood, IC)
H = ift.EnergyAdapter(params, H)
H = H.make_invertible(IC) H = H.make_invertible(IC)
# Minimize # Minimize
...@@ -116,13 +117,13 @@ xs = np.linspace(xmin, xmax, 100) ...@@ -116,13 +117,13 @@ xs = np.linspace(xmin, xmax, 100)
sc = ift.StatCalculator() sc = ift.StatCalculator()
for ii in range(len(samples)): for ii in range(len(samples)):
sc.add(params.at(samples[ii]).value) sc.add(samples[ii])
ys = polynomial(params.at(samples[ii]), xs) ys = polynomial(samples[ii], xs)
if ii == 0: if ii == 0:
plt.plot(xs, ys, 'k', alpha=.05, label='Posterior samples') plt.plot(xs, ys, 'k', alpha=.05, label='Posterior samples')
continue continue
plt.plot(xs, ys, 'k', alpha=.05) plt.plot(xs, ys, 'k', alpha=.05)
ys = polynomial(params.at(H.position), xs) ys = polynomial(H.position, xs)
plt.plot(xs, ys, 'r', linewidth=2., label='Interpolation') plt.plot(xs, ys, 'r', linewidth=2., label='Interpolation')
plt.legend() plt.legend()
plt.savefig('fit.png') plt.savefig('fit.png')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment