Skip to content
Snippets Groups Projects

Fix polynomial fit

Merged Philipp Arras requested to merge fix_polynomial_fit into NIFTy_8
All threads resolved!
+ 19
16
@@ -63,22 +63,19 @@ class PolynomialResponse(ift.LinearOperator):
tgt = ift.UnstructuredDomain(sampling_points.shape)
self._target = ift.DomainTuple.make(tgt)
self._capability = self.TIMES | self.ADJOINT_TIMES
sh = (self.target.size, domain.size)
self._mat = np.empty(sh)
self._mat = np.empty((self.target.size, domain.size))
for d in range(domain.size):
self._mat.T[d] = sampling_points**d
def apply(self, x, mode):
self._check_input(x, mode)
val = x.val_rw()
if mode == self.TIMES:
# FIXME Use polynomial() here
out = self._mat.dot(val)
else:
# FIXME Can this be optimized?
out = self._mat.conj().T.dot(val)
return ift.makeField(self._tgt(mode), out)
m = self._mat
f = m if mode == self.TIMES else m.conj().T
return ift.makeField(self._tgt(mode), f.dot(x.val))
def mock_fct(x):
return np.sin(x**2 / 10) * x**3
def main():
@@ -86,12 +83,14 @@ def main():
N_params = 10
N_samples = 100
size = (12,)
x = ift.random.current_rng().random(size) * 10
y = np.sin(x**2) * x**3
y = mock_fct(x)
var = np.full_like(y, y.var() / 10)
var[-2] *= 4
var[5] /= 2
y[5] -= 0
var[5] -= 1
# Set up minimization problem
p_space = ift.UnstructuredDomain(N_params)
@@ -104,11 +103,11 @@ def main():
N = ift.makeOp(ift.makeField(d_space, var), sampling_dtype=float)
IC = ift.DeltaEnergyController(tol_rel_deltaE=1e-12, iteration_limit=200)
likelihood_energy = ift.GaussianEnergy(d, N) @ R
likelihood_energy = ift.GaussianEnergy(d, N.inverse) @ R
Ham = ift.StandardHamiltonian(likelihood_energy, IC)
H = ift.EnergyAdapter(params, Ham, want_metric=True)
# Minimize
# Minimize KL
minimizer = ift.NewtonCG(IC)
H, _ = minimizer(H)
@@ -122,6 +121,9 @@ def main():
xmin, xmax = x.min(), x.max()
xs = np.linspace(xmin, xmax, 100)
# Ground truth
y_true = mock_fct(xs)
sc = ift.StatCalculator()
for ii in range(len(samples)):
sc.add(samples[ii])
@@ -132,8 +134,9 @@ def main():
plt.plot(xs, ys, 'k', alpha=.05)
ys = polynomial(H.position, xs)
plt.plot(xs, ys, 'r', linewidth=2., label='Interpolation')
plt.plot(xs, y_true, 'b', linewidth=0.5, label='$\mathrm{f}(x) = \sin({x^2}/{10}) * x^2$')
plt.legend()
plt.savefig('fit.png')
plt.savefig('fit.png', dpi=300)
plt.close()
# Print parameters
Loading