Commit e7ff7772 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'curvefitting' into 'NIFTy_5'

Curvefitting

See merge request ift/nifty-dev!62
parents bfd57f63 10d47d08
...@@ -111,3 +111,12 @@ run_bernoulli: ...@@ -111,3 +111,12 @@ run_bernoulli:
artifacts: artifacts:
paths: paths:
- '*.png' - '*.png'
run_curve_fitting:
stage: demo_runs
script:
- python demos/polynomial_fit.py
- python3 demos/polynomial_fit.py
artifacts:
paths:
- '*.png'
import matplotlib.pyplot as plt
import numpy as np
import nifty5 as ift
np.random.seed(12)
def polynomial(coefficients, sampling_points):
"""Computes values of polynomial whose coefficients are stored in
coefficients at sampling points. This is a quick version of the
PolynomialResponse.
Parameters
----------
coefficients: Model
sampling_points: Numpy array
"""
if not (isinstance(coefficients, ift.Model)
and isinstance(sampling_points, np.ndarray)):
raise TypeError
params = coefficients.value.to_global_data()
out = np.zeros_like(sampling_points)
for ii in range(len(params)):
out += params[ii] * sampling_points**ii
return out
class PolynomialResponse(ift.LinearOperator):
"""Calculates values of a polynomial parameterized by input at sampling points.
Parameters
----------
domain: UnstructuredDomain
The domain on which the coefficients of the polynomial are defined.
sampling_points: Numpy array
x-values of the sampling points.
"""
def __init__(self, domain, sampling_points):
super(PolynomialResponse, self).__init__()
if not (isinstance(domain, ift.UnstructuredDomain)
and isinstance(x, np.ndarray)):
raise TypeError
self._domain = ift.DomainTuple.make(domain)
tgt = ift.UnstructuredDomain(sampling_points.shape)
self._target = ift.DomainTuple.make(tgt)
sh = (self.target.size, domain.size)
self._mat = np.empty(sh)
for d in range(domain.size):
self._mat.T[d] = sampling_points**d
def apply(self, x, mode):
self._check_input(x, mode)
val = x.to_global_data()
if mode == self.TIMES:
# FIXME Use polynomial() here
out = self._mat.dot(val)
else:
# FIXME Can this be optimized?
out = self._mat.conj().T.dot(val)
return ift.from_global_data(self._tgt(mode), out)
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
# Generate some mock data
N_params = 10
N_samples = 100
size = (12,)
x = np.random.random(size) * 10
y = np.sin(x**2) * x**3
var = np.full_like(y, y.var() / 10)
var[-2] *= 4
var[5] /= 2
y[5] -= 0
# Set up minimization problem
p_space = ift.UnstructuredDomain(N_params)
params = ift.Variable(ift.MultiField.from_dict(
{'params': ift.full(p_space, 0.)}))['params']
R = PolynomialResponse(p_space, x)
ift.extra.consistency_check(R)
d_space = R.target
d = ift.from_global_data(d_space, y)
N = ift.DiagonalOperator(ift.from_global_data(d_space, var))
IC = ift.GradientNormController(tol_abs_gradnorm=1e-8)
H = ift.Hamiltonian(ift.GaussianEnergy(R(params), d, N), IC)
H = H.make_invertible(IC)
# Minimize
minimizer = ift.RelaxedNewton(IC)
H, _ = minimizer(H)
# Draw posterior samples
samples = [H.metric.draw_sample(from_inverse=True) + H.position
for _ in range(N_samples)]
# Plotting
plt.errorbar(x, y, np.sqrt(var), fmt='ko', label='Data with error bars')
xmin, xmax = x.min(), x.max()
xs = np.linspace(xmin, xmax, 100)
sc = ift.StatCalculator()
for ii in range(len(samples)):
sc.add(params.at(samples[ii]).value)
ys = polynomial(params.at(samples[ii]), xs)
if ii == 0:
plt.plot(xs, ys, 'k', alpha=.05, label='Posterior samples')
continue
plt.plot(xs, ys, 'k', alpha=.05)
ys = polynomial(params.at(H.position), xs)
plt.plot(xs, ys, 'r', linewidth=2., label='Interpolation')
plt.legend()
plt.savefig('fit.png')
plt.close()
# Print parameters
mean = sc.mean.to_global_data()
sigma = np.sqrt(sc.var.to_global_data())
for ii in range(len(mean)):
print('Coefficient x**{}: {:.2E} +/- {:.2E}'.format(ii, mean[ii],
sigma[ii]))
...@@ -27,6 +27,7 @@ from ..utilities import NiftyMetaBase ...@@ -27,6 +27,7 @@ from ..utilities import NiftyMetaBase
class Domain(NiftyMetaBase()): class Domain(NiftyMetaBase()):
"""The abstract class repesenting a (structured or unstructured) domain. """The abstract class repesenting a (structured or unstructured) domain.
""" """
def __init__(self): def __init__(self):
self._hash = None self._hash = None
......
...@@ -132,6 +132,7 @@ class LOSResponse(LinearOperator): ...@@ -132,6 +132,7 @@ class LOSResponse(LinearOperator):
every calling MPI task (i.e. the full LOS information has to be provided on every calling MPI task (i.e. the full LOS information has to be provided on
every task). every task).
""" """
def __init__(self, domain, starts, ends, sigmas_low=None, sigmas_up=None): def __init__(self, domain, starts, ends, sigmas_low=None, sigmas_up=None):
super(LOSResponse, self).__init__() super(LOSResponse, self).__init__()
......
...@@ -29,6 +29,7 @@ class RelaxedNewton(DescentMinimizer): ...@@ -29,6 +29,7 @@ class RelaxedNewton(DescentMinimizer):
The descent direction is determined by weighting the gradient at the The descent direction is determined by weighting the gradient at the
current parameter position with the inverse local metric. current parameter position with the inverse local metric.
""" """
def __init__(self, controller, line_searcher=None): def __init__(self, controller, line_searcher=None):
if line_searcher is None: if line_searcher is None:
line_searcher = LineSearchStrongWolfe( line_searcher = LineSearchStrongWolfe(
......
...@@ -28,5 +28,6 @@ class SteepestDescent(DescentMinimizer): ...@@ -28,5 +28,6 @@ class SteepestDescent(DescentMinimizer):
Also known as 'gradient descent'. This algorithm simply follows the Also known as 'gradient descent'. This algorithm simply follows the
functional's gradient for minimization. functional's gradient for minimization.
""" """
def get_descent_direction(self, energy): def get_descent_direction(self, energy):
return -energy.gradient return -energy.gradient
...@@ -106,6 +106,7 @@ class _InformationStore(object): ...@@ -106,6 +106,7 @@ class _InformationStore(object):
yy : numpy.ndarray yy : numpy.ndarray
2D circular buffer of scalar products between different elements of y. 2D circular buffer of scalar products between different elements of y.
""" """
def __init__(self, max_history_length, x0, gradient): def __init__(self, max_history_length, x0, gradient):
self.max_history_length = max_history_length self.max_history_length = max_history_length
self.s = [None]*max_history_length self.s = [None]*max_history_length
......
...@@ -35,6 +35,7 @@ def _joint_position(model1, model2): ...@@ -35,6 +35,7 @@ def _joint_position(model1, model2):
class ScalarMul(Model): class ScalarMul(Model):
"""Class representing a model multiplied by a scalar factor.""" """Class representing a model multiplied by a scalar factor."""
def __init__(self, factor, model): def __init__(self, factor, model):
super(ScalarMul, self).__init__(model.position) super(ScalarMul, self).__init__(model.position)
# TODO -> floating # TODO -> floating
...@@ -53,6 +54,7 @@ class ScalarMul(Model): ...@@ -53,6 +54,7 @@ class ScalarMul(Model):
class Add(Model): class Add(Model):
"""Class representing the sum of two models.""" """Class representing the sum of two models."""
def __init__(self, position, model1, model2): def __init__(self, position, model1, model2):
super(Add, self).__init__(position) super(Add, self).__init__(position)
...@@ -83,6 +85,7 @@ class Add(Model): ...@@ -83,6 +85,7 @@ class Add(Model):
class Mul(Model): class Mul(Model):
"""Class representing the pointwise product of two models.""" """Class representing the pointwise product of two models."""
def __init__(self, position, model1, model2): def __init__(self, position, model1, model2):
super(Mul, self).__init__(position) super(Mul, self).__init__(position)
......
...@@ -39,6 +39,7 @@ class Constant(Model): ...@@ -39,6 +39,7 @@ class Constant(Model):
- Position has no influence on value. - Position has no influence on value.
- The Jacobian is a null matrix. - The Jacobian is a null matrix.
""" """
def __init__(self, position, constant): def __init__(self, position, constant):
super(Constant, self).__init__(position) super(Constant, self).__init__(position)
self._constant = constant self._constant = constant
......
...@@ -47,6 +47,7 @@ class Model(NiftyMetaBase()): ...@@ -47,6 +47,7 @@ class Model(NiftyMetaBase()):
one automatically gets the value and Jacobian of the model. The 'at' method one automatically gets the value and Jacobian of the model. The 'at' method
creates a new instance of the class. creates a new instance of the class.
""" """
def __init__(self, position): def __init__(self, position):
self._position = position self._position = position
......
...@@ -27,6 +27,7 @@ from .model import Model ...@@ -27,6 +27,7 @@ from .model import Model
class MultiModel(Model): class MultiModel(Model):
""" """ """ """
def __init__(self, model, key): def __init__(self, model, key):
# TODO Rewrite it such that it takes a dictionary as input. # TODO Rewrite it such that it takes a dictionary as input.
# (just like MultiFields). # (just like MultiFields).
......
...@@ -31,6 +31,7 @@ class Variable(Model): ...@@ -31,6 +31,7 @@ class Variable(Model):
position : Field or MultiField position : Field or MultiField
The current position in parameter space. The current position in parameter space.
""" """
def __init__(self, position): def __init__(self, position):
super(Variable, self).__init__(position) super(Variable, self).__init__(position)
......
...@@ -54,7 +54,7 @@ class HartleyOperator(LinearOperator): ...@@ -54,7 +54,7 @@ class HartleyOperator(LinearOperator):
of the result field, respectivey. of the result field, respectivey.
In many contexts the Hartley transform is a perfect substitute for the In many contexts the Hartley transform is a perfect substitute for the
Fourier transform, but in some situations (e.g. convolution with a general, Fourier transform, but in some situations (e.g. convolution with a general,
non-symmetrc kernel, the full FFT must be used instead. non-symmetric kernel, the full FFT must be used instead.
""" """
def __init__(self, domain, target=None, space=None): def __init__(self, domain, target=None, space=None):
......
...@@ -34,6 +34,7 @@ class NullOperator(LinearOperator): ...@@ -34,6 +34,7 @@ class NullOperator(LinearOperator):
target : DomainTuple or MultiDomain target : DomainTuple or MultiDomain
output domain output domain
""" """
def __init__(self, domain, target): def __init__(self, domain, target):
from ..sugar import makeDomain from ..sugar import makeDomain
self._domain = makeDomain(domain) self._domain = makeDomain(domain)
......
...@@ -42,6 +42,7 @@ class QHTOperator(LinearOperator): ...@@ -42,6 +42,7 @@ class QHTOperator(LinearOperator):
The index of the domain on which the operator acts. The index of the domain on which the operator acts.
target[space] must be a nonharmonic LogRGSpace. target[space] must be a nonharmonic LogRGSpace.
""" """
def __init__(self, target, space=0): def __init__(self, target, space=0):
self._target = DomainTuple.make(target) self._target = DomainTuple.make(target)
self._space = infer_space(self._target, space) self._space = infer_space(self._target, space)
......
...@@ -111,4 +111,4 @@ class ScalingOperator(EndomorphicOperator): ...@@ -111,4 +111,4 @@ class ScalingOperator(EndomorphicOperator):
fct = 1./np.sqrt(fct) if from_inverse else np.sqrt(fct) fct = 1./np.sqrt(fct) if from_inverse else np.sqrt(fct)
cls = Field if isinstance(self._domain, DomainTuple) else MultiField cls = Field if isinstance(self._domain, DomainTuple) else MultiField
return cls.from_random( return cls.from_random(
random_type="normal", domain=self._domain, std=fct, dtype=dtype) random_type="normal", domain=self._domain, std=fct, dtype=dtype)
...@@ -35,6 +35,7 @@ class SelectionOperator(LinearOperator): ...@@ -35,6 +35,7 @@ class SelectionOperator(LinearOperator):
key : :class:`str` key : :class:`str`
String identifier of the wanted subdomain String identifier of the wanted subdomain
""" """
def __init__(self, domain, key): def __init__(self, domain, key):
self._domain = MultiDomain.make(domain) self._domain = MultiDomain.make(domain)
self._key = key self._key = key
......
...@@ -89,7 +89,7 @@ def get_slice_list(shape, axes): ...@@ -89,7 +89,7 @@ def get_slice_list(shape, axes):
slice_list = [ slice_list = [
next(it_iter) next(it_iter)
if axis else slice(None, None) for axis in axes_select if axis else slice(None, None) for axis in axes_select
] ]
yield slice_list yield slice_list
else: else:
yield [slice(None, None)] yield [slice(None, None)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment