Commit 3f3cb38e authored by Philipp Frank
Browse files

cleanup

parent 9e71ab4c
......@@ -28,8 +28,13 @@ from ..operators.linear_operator import LinearOperator
from ..operators.multifield2vector import Multifield2Vector
from ..operators.sandwich_operator import SandwichOperator
from ..operators.simple_linear_operators import FieldAdapter
from ..sugar import full, makeField, from_random, is_fieldlike
from ..sugar import full, makeField, makeDomain, from_random, is_fieldlike
from ..minimization.energy_adapter import StochasticEnergyAdapter
from ..utilities import myassert
def _eval(op, position):
return op(position.extract(op.domain))
class MeanFieldVI:
......@@ -59,15 +64,15 @@ class MeanFieldVI:
@property
def mean(self):
return self._mean.force(self._KL.position)
return _eval(self._mean,self._KL.position)
@property
def std(self):
return self._std.force(self._KL.position)
return _eval(self._std,self._KL.position)
@property
def entropy(self):
return self._entropy.force(self._KL.position)
return _eval(self._entropy,self._KL.position)
def draw_sample(self):
_, op = self._generator.simplify_for_constant_input(
......@@ -77,7 +82,6 @@ class MeanFieldVI:
def minimize(self, minimizer):
self._KL, _ = minimizer(self._KL)
class FullCovarianceVI:
def __init__(self, position, hamiltonian, n_samples, mirror_samples,
initial_sig=1, comm=None, nanisinf=False):
......@@ -86,12 +90,10 @@ class FullCovarianceVI:
"""
Flat = Multifield2Vector(position.domain)
flat_domain = Flat.target[0]
N_tri = flat_domain.shape[0]*(flat_domain.shape[0]+1)//2
triangular_space = DomainTuple.make(UnstructuredDomain(N_tri))
tri = FieldAdapter(triangular_space, 'cov')
mat_space = DomainTuple.make((flat_domain,flat_domain))
lat = FieldAdapter(Flat.target,'latent')
LT = LowerTriangularProjector(triangular_space,mat_space)
LT = LowerTriangularInserter(mat_space)
tri = FieldAdapter(LT.domain, 'cov')
mean = FieldAdapter(flat_domain,'mean')
cov = LT @ tri
matmul_setup = lat.adjoint @ lat + cov.ducktape_left('co')
......@@ -100,14 +102,12 @@ class FullCovarianceVI:
self._generator = Flat.adjoint @ (mean + MatMult @ matmul_setup)
Diag = DiagonalSelector(cov.target, Flat.target)
diag_cov = Diag(cov).absolute()
diag_cov = (DiagonalSelector(cov.target) @ cov).absolute()
self._entropy = GaussianEntropy(diag_cov.target) @ diag_cov
diag_tri = np.diag(np.full(flat_domain.shape[0], initial_sig))
diag_tri = diag_tri[np.tril_indices(flat_domain.shape[0])]
pos = MultiField.from_dict(
{'mean':Flat(position),
'cov':makeField(triangular_space, diag_tri)})
{'mean': Flat(position),
'cov': LT.adjoint(makeField(mat_space, diag_tri))})
op = hamiltonian(self._generator) + self._entropy
self._KL = StochasticEnergyAdapter.make(pos, op, ['latent',], n_samples,
mirror_samples, nanisinf=nanisinf, comm=comm)
......@@ -116,11 +116,11 @@ class FullCovarianceVI:
@property
def mean(self):
return self._mean.force(self._KL.position)
return _eval(self._mean,self._KL.position)
@property
def entropy(self):
return self._entropy.force(self._KL.position)
return _eval(self._entropy,self._KL.position)
def draw_sample(self):
_, op = self._generator.simplify_for_constant_input(
......@@ -155,21 +155,21 @@ class GaussianEntropy(EnergyOperator):
return res.add_metric(SandwichOperator.make(res.jac))
class LowerTriangularProjector(LinearOperator):
"""Project the DOFs of a triangular matrix into the matrix form.
class LowerTriangularInserter(LinearOperator):
"""Inserts the DOFs of a lower triangular matrix into a matrix.
Parameters
----------
domain: Domain
A one-dimensional domain containing N(N+1)/2 DOFs of a triangular
matrix.
target: Domain
A two-dimensional domain with NxN entries.
"""
def __init__(self, domain, target):
self._domain = DomainTuple.make(domain)
self._target = DomainTuple.make(target)
def __init__(self, target):
myassert(len(target.shape) == 2)
myassert(target.shape[0] == target.shape[1])
self._target = makeDomain(target)
ndof = (target.shape[0]*(target.shape[0]+1))//2
self._domain = makeDomain(UnstructuredDomain(ndof))
self._indices = np.tril_indices(target.shape[0])
self._capability = self.TIMES | self.ADJOINT_TIMES
......@@ -190,42 +190,16 @@ class DiagonalSelector(LinearOperator):
Parameters
----------
domain: Domain
The two-dimensional domain of the input field
target: Domain
The one-dimensional domain on which the diagonal of the input field is
defined.
The two-dimensional domain of the input field. Must be of shape NxN.
"""
def __init__(self, domain, target):
self._domain = DomainTuple.make(domain)
self._target = DomainTuple.make(target)
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
x = np.diag(x.val)
if mode == self.ADJOINT_TIMES:
x = x.reshape(self._domain.shape)
return makeField(self._tgt(mode), x)
class Respacer(LinearOperator):
"""Re-map a field from one domain to another one with the same amounts of
DOFs. Wraps the numpy.reshape method.
Parameters
----------
domain: Domain
The domain of the input field.
target: Domain
The domain of the output field.
"""
def __init__(self, domain, target):
self._domain = DomainTuple.make(domain)
self._target = DomainTuple.make(target)
def __init__(self, domain):
myassert(len(domain.shape) == 2)
myassert(domain.shape[0] == domain.shape[1])
self._domain = makeDomain(domain)
self._target = makeDomain(UnstructuredDomain(domain.shape[0]))
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
return makeField(self._tgt(mode), x.val.reshape(self._tgt(mode).shape))
return makeField(self._tgt(mode), np.diag(x.val))
......@@ -149,15 +149,15 @@ class StochasticEnergyAdapter(Energy):
@staticmethod
def make(position, op, sampling_keys, n_samples, mirror_samples,
comm=None, nanisinf = False):
"""A variant of `EnergyAdapter` that provides the energy interface for an
operator with a scalar target where parts of the input are averaged
"""A variant of `EnergyAdapter` that provides the energy interface for
an operator with a scalar target where parts of the input are averaged
instead of optimized.
Specifically, a set of standard normally distributed
samples are drawn for the input corresponding to `keys` and each sample
gets partially inserted into `op`. The resulting operators are averaged and
represent a stochastic average of an energy with the remaining subdomain
being the DOFs that are considered to be optimization parameters.
gets partially inserted into `op`. The resulting operators are averaged
and represent a stochastic average of an energy with the remaining
subdomain being the DOFs that are considered to be optimization parameters.
Parameters
----------
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment