Commit 3f3cb38e authored by Philipp Frank's avatar Philipp Frank
Browse files

cleanup

parent 9e71ab4c
...@@ -28,8 +28,13 @@ from ..operators.linear_operator import LinearOperator ...@@ -28,8 +28,13 @@ from ..operators.linear_operator import LinearOperator
from ..operators.multifield2vector import Multifield2Vector from ..operators.multifield2vector import Multifield2Vector
from ..operators.sandwich_operator import SandwichOperator from ..operators.sandwich_operator import SandwichOperator
from ..operators.simple_linear_operators import FieldAdapter from ..operators.simple_linear_operators import FieldAdapter
from ..sugar import full, makeField, from_random, is_fieldlike from ..sugar import full, makeField, makeDomain, from_random, is_fieldlike
from ..minimization.energy_adapter import StochasticEnergyAdapter from ..minimization.energy_adapter import StochasticEnergyAdapter
from ..utilities import myassert
def _eval(op, position):
return op(position.extract(op.domain))
class MeanFieldVI: class MeanFieldVI:
...@@ -59,15 +64,15 @@ class MeanFieldVI: ...@@ -59,15 +64,15 @@ class MeanFieldVI:
@property @property
def mean(self): def mean(self):
return self._mean.force(self._KL.position) return _eval(self._mean,self._KL.position)
@property @property
def std(self): def std(self):
return self._std.force(self._KL.position) return _eval(self._std,self._KL.position)
@property @property
def entropy(self): def entropy(self):
return self._entropy.force(self._KL.position) return _eval(self._entropy,self._KL.position)
def draw_sample(self): def draw_sample(self):
_, op = self._generator.simplify_for_constant_input( _, op = self._generator.simplify_for_constant_input(
...@@ -77,7 +82,6 @@ class MeanFieldVI: ...@@ -77,7 +82,6 @@ class MeanFieldVI:
def minimize(self, minimizer): def minimize(self, minimizer):
self._KL, _ = minimizer(self._KL) self._KL, _ = minimizer(self._KL)
class FullCovarianceVI: class FullCovarianceVI:
def __init__(self, position, hamiltonian, n_samples, mirror_samples, def __init__(self, position, hamiltonian, n_samples, mirror_samples,
initial_sig=1, comm=None, nanisinf=False): initial_sig=1, comm=None, nanisinf=False):
...@@ -86,12 +90,10 @@ class FullCovarianceVI: ...@@ -86,12 +90,10 @@ class FullCovarianceVI:
""" """
Flat = Multifield2Vector(position.domain) Flat = Multifield2Vector(position.domain)
flat_domain = Flat.target[0] flat_domain = Flat.target[0]
N_tri = flat_domain.shape[0]*(flat_domain.shape[0]+1)//2
triangular_space = DomainTuple.make(UnstructuredDomain(N_tri))
tri = FieldAdapter(triangular_space, 'cov')
mat_space = DomainTuple.make((flat_domain,flat_domain)) mat_space = DomainTuple.make((flat_domain,flat_domain))
lat = FieldAdapter(Flat.target,'latent') lat = FieldAdapter(Flat.target,'latent')
LT = LowerTriangularProjector(triangular_space,mat_space) LT = LowerTriangularInserter(mat_space)
tri = FieldAdapter(LT.domain, 'cov')
mean = FieldAdapter(flat_domain,'mean') mean = FieldAdapter(flat_domain,'mean')
cov = LT @ tri cov = LT @ tri
matmul_setup = lat.adjoint @ lat + cov.ducktape_left('co') matmul_setup = lat.adjoint @ lat + cov.ducktape_left('co')
...@@ -100,14 +102,12 @@ class FullCovarianceVI: ...@@ -100,14 +102,12 @@ class FullCovarianceVI:
self._generator = Flat.adjoint @ (mean + MatMult @ matmul_setup) self._generator = Flat.adjoint @ (mean + MatMult @ matmul_setup)
Diag = DiagonalSelector(cov.target, Flat.target) diag_cov = (DiagonalSelector(cov.target) @ cov).absolute()
diag_cov = Diag(cov).absolute()
self._entropy = GaussianEntropy(diag_cov.target) @ diag_cov self._entropy = GaussianEntropy(diag_cov.target) @ diag_cov
diag_tri = np.diag(np.full(flat_domain.shape[0], initial_sig)) diag_tri = np.diag(np.full(flat_domain.shape[0], initial_sig))
diag_tri = diag_tri[np.tril_indices(flat_domain.shape[0])]
pos = MultiField.from_dict( pos = MultiField.from_dict(
{'mean':Flat(position), {'mean': Flat(position),
'cov':makeField(triangular_space, diag_tri)}) 'cov': LT.adjoint(makeField(mat_space, diag_tri))})
op = hamiltonian(self._generator) + self._entropy op = hamiltonian(self._generator) + self._entropy
self._KL = StochasticEnergyAdapter.make(pos, op, ['latent',], n_samples, self._KL = StochasticEnergyAdapter.make(pos, op, ['latent',], n_samples,
mirror_samples, nanisinf=nanisinf, comm=comm) mirror_samples, nanisinf=nanisinf, comm=comm)
...@@ -116,11 +116,11 @@ class FullCovarianceVI: ...@@ -116,11 +116,11 @@ class FullCovarianceVI:
@property @property
def mean(self): def mean(self):
return self._mean.force(self._KL.position) return _eval(self._mean,self._KL.position)
@property @property
def entropy(self): def entropy(self):
return self._entropy.force(self._KL.position) return _eval(self._entropy,self._KL.position)
def draw_sample(self): def draw_sample(self):
_, op = self._generator.simplify_for_constant_input( _, op = self._generator.simplify_for_constant_input(
...@@ -155,21 +155,21 @@ class GaussianEntropy(EnergyOperator): ...@@ -155,21 +155,21 @@ class GaussianEntropy(EnergyOperator):
return res.add_metric(SandwichOperator.make(res.jac)) return res.add_metric(SandwichOperator.make(res.jac))
class LowerTriangularProjector(LinearOperator): class LowerTriangularInserter(LinearOperator):
"""Project the DOFs of a triangular matrix into the matrix form. """Inserts the DOFs of a lower triangular matrix into a matrix.
Parameters Parameters
---------- ----------
domain: Domain
A one-dimensional domain containing N(N+1)/2 DOFs of a triangular
matrix.
target: Domain target: Domain
A two-dimensional domain with NxN entries. A two-dimensional domain with NxN entries.
""" """
def __init__(self, domain, target): def __init__(self, target):
self._domain = DomainTuple.make(domain) myassert(len(target.shape) == 2)
self._target = DomainTuple.make(target) myassert(target.shape[0] == target.shape[1])
self._target = makeDomain(target)
ndof = (target.shape[0]*(target.shape[0]+1))//2
self._domain = makeDomain(UnstructuredDomain(ndof))
self._indices = np.tril_indices(target.shape[0]) self._indices = np.tril_indices(target.shape[0])
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
...@@ -190,42 +190,16 @@ class DiagonalSelector(LinearOperator): ...@@ -190,42 +190,16 @@ class DiagonalSelector(LinearOperator):
Parameters Parameters
---------- ----------
domain: Domain domain: Domain
The two-dimensional domain of the input field The two-dimensional domain of the input field. Must be of shape NxN.
target: Domain
The one-dimensional domain on which the diagonal of the input field is
defined.
""" """
def __init__(self, domain, target): def __init__(self, domain):
self._domain = DomainTuple.make(domain) myassert(len(domain.shape) == 2)
self._target = DomainTuple.make(target) myassert(domain.shape[0] == domain.shape[1])
self._capability = self.TIMES | self.ADJOINT_TIMES self._domain = makeDomain(domain)
self._target = makeDomain(UnstructuredDomain(domain.shape[0]))
def apply(self, x, mode):
self._check_input(x, mode)
x = np.diag(x.val)
if mode == self.ADJOINT_TIMES:
x = x.reshape(self._domain.shape)
return makeField(self._tgt(mode), x)
class Respacer(LinearOperator):
"""Re-map a field from one domain to another one with the same amounts of
DOFs. Wrapps the numpy.reshape method.
Parameters
----------
domain: Domain
The domain of the input field.
target: Domain
The domain of the output field.
"""
def __init__(self, domain, target):
self._domain = DomainTuple.make(domain)
self._target = DomainTuple.make(target)
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
return makeField(self._tgt(mode), x.val.reshape(self._tgt(mode).shape)) return makeField(self._tgt(mode), np.diag(x.val))
...@@ -149,15 +149,15 @@ class StochasticEnergyAdapter(Energy): ...@@ -149,15 +149,15 @@ class StochasticEnergyAdapter(Energy):
@staticmethod @staticmethod
def make(position, op, sampling_keys, n_samples, mirror_samples, def make(position, op, sampling_keys, n_samples, mirror_samples,
comm=None, nanisinf = False): comm=None, nanisinf = False):
"""A variant of `EnergyAdapter` that provides the energy interface for an """A variant of `EnergyAdapter` that provides the energy interface for
operator with a scalar target where parts of the imput are averaged an operator with a scalar target where parts of the imput are averaged
instead of optmized. instead of optmized.
Specifically, a set of standart normal distributed Specifically, a set of standart normal distributed
samples are drawn for the input corresponding to `keys` and each sample samples are drawn for the input corresponding to `keys` and each sample
gets partially inserted into `op`. The resulting operators are averaged and gets partially inserted into `op`. The resulting operators are averaged
represent a stochastic average of an energy with the remaining subdomain and represent a stochastic average of an energy with the remaining
being the DOFs that are considered to be optimization parameters. subdomain being the DOFs that are considered to be optimization parameters.
Parameters Parameters
---------- ----------
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment