Commit 9953fdbd authored by Philipp Arras's avatar Philipp Arras
Browse files

Merge remote-tracking branch 'origin/NIFTy_6' into imag_operator

parents 97790be1 2990b099
Pipeline #75291 passed with stages
in 8 minutes and 22 seconds
...@@ -119,3 +119,8 @@ run_curve_fitting: ...@@ -119,3 +119,8 @@ run_curve_fitting:
artifacts: artifacts:
paths: paths:
- '*.png' - '*.png'
run_visual_mgvi:
stage: demo_runs
script:
- python3 demos/mgvi_visualized.py
Changes since NIFTy 5: Changes since NIFTy 5:
Minimum Python version increased to 3.6
=======================================
New operators
=============
In addition to the below changes, the following operators were introduced:
* UniformOperator: Transforms a Gaussian into a uniform distribution
* VariableCovarianceGaussianEnergy: Energy operator for inferring covariances
* MultiLinearEinsum: Multi-linear version of numpy's einsum with derivates
* LinearEinsum: Linear version of numpy's einsum with one free field
* PartialConjugate: Conjugates parts of a multi-field
* SliceOperator: Geometry preserving mask operator
* SplitOperator: Splits a single field into a multi-field
FFT convention adjusted
=======================
When going to harmonic space, NIFTy's FFT operator now uses a minus sign in the
exponent (and, consequently, a plus sign on the adjoint transform). This
convention is consistent with almost all other numerical FFT libraries.
Interface change in EndomorphicOperator.draw_sample()
=====================================================
Both complex-valued and real-valued Gaussian probability distributions have
Hermitian and positive endomorphisms as covariance. Just by looking at an
endomorphic operator itself it is not clear whether it is viewed as covariance
for real or complex Gaussians when a sample of the respective distribution shall
be drawn. Therefore, we introduce the method `draw_sample_with_dtype()` which
needs to be given the data type of the probability distribution. This function
is implemented for all operators which actually draw random numbers
(`DiagonalOperator` and `ScalingOperator`). The class `SamplingDtypeSetter` acts
as a wrapper for this kind of operators in order to fix the data type of the
distribution. Samples from these operators can be drawn with `.draw_sample()`.
In order to dive into those subtleties I suggest running the following code and
playing around with the dtypes.
```
import nifty6 as ift
import numpy as np
dom = ift.UnstructuredDomain(5)
dtype = [np.float64, np.complex128][1]
invcov = ift.ScalingOperator(dom, 3)
e = ift.GaussianEnergy(mean=ift.from_random(dom, 'normal', dtype=dtype),
inverse_covariance=invcov)
pos = ift.from_random(dom, 'normal', dtype=np.complex128)
lin = e(ift.Linearization.make_var(pos, want_metric=True))
met = lin.metric
print(met)
print(met.draw_sample())
```
MPI parallelisation over samples in MetricGaussianKL MPI parallelisation over samples in MetricGaussianKL
==================================================== ====================================================
...@@ -15,6 +71,13 @@ the generation of reproducible random numbers in the presence of MPI parallelism ...@@ -15,6 +71,13 @@ the generation of reproducible random numbers in the presence of MPI parallelism
and leads to cleaner code overall. Please see the documentation of and leads to cleaner code overall. Please see the documentation of
`nifty6.random` for details. `nifty6.random` for details.
Interface Change for from_random and OuterProduct
=================================================
The sugar.from_random, Field.from_random, MultiField.from_random now take domain
as the first argument and default to 'normal' for the second argument.
Likewise OuterProduct takes domain as the first argument and a field as the second.
Interface Change for non-linear Operators Interface Change for non-linear Operators
========================================= =========================================
......
...@@ -45,7 +45,7 @@ Installation ...@@ -45,7 +45,7 @@ Installation
### Requirements ### Requirements
- [Python 3](https://www.python.org/) (3.5.x or later) - [Python 3](https://www.python.org/) (3.6.x or later)
- [SciPy](https://www.scipy.org/) - [SciPy](https://www.scipy.org/)
Optional dependencies: Optional dependencies:
......
...@@ -43,7 +43,7 @@ if __name__ == '__main__': ...@@ -43,7 +43,7 @@ if __name__ == '__main__':
harmonic_space = position_space.get_default_codomain() harmonic_space = position_space.get_default_codomain()
HT = ift.HarmonicTransformOperator(harmonic_space, position_space) HT = ift.HarmonicTransformOperator(harmonic_space, position_space)
position = ift.from_random('normal', harmonic_space) position = ift.from_random(harmonic_space, 'normal')
# Define power spectrum and amplitudes # Define power spectrum and amplitudes
def sqrtpspec(k): def sqrtpspec(k):
...@@ -58,13 +58,13 @@ if __name__ == '__main__': ...@@ -58,13 +58,13 @@ if __name__ == '__main__':
# Generate mock data # Generate mock data
p = R(sky) p = R(sky)
mock_position = ift.from_random('normal', harmonic_space) mock_position = ift.from_random(harmonic_space, 'normal')
tmp = p(mock_position).val.astype(np.float64) tmp = p(mock_position).val.astype(np.float64)
data = ift.random.current_rng().binomial(1, tmp) data = ift.random.current_rng().binomial(1, tmp)
data = ift.Field.from_raw(R.target, data) data = ift.Field.from_raw(R.target, data)
# Compute likelihood and Hamiltonian # Compute likelihood and Hamiltonian
position = ift.from_random('normal', harmonic_space) position = ift.from_random(harmonic_space, 'normal')
likelihood = ift.BernoulliEnergy(data) @ p likelihood = ift.BernoulliEnergy(data) @ p
ic_newton = ift.DeltaEnergyController( ic_newton = ift.DeltaEnergyController(
name='Newton', iteration_limit=100, tol_rel_deltaE=1e-8) name='Newton', iteration_limit=100, tol_rel_deltaE=1e-8)
......
...@@ -236,7 +236,7 @@ ...@@ -236,7 +236,7 @@
"R = HT #*ift.create_harmonic_smoothing_operator((h_space,), 0, 0.02)\n", "R = HT #*ift.create_harmonic_smoothing_operator((h_space,), 0, 0.02)\n",
"\n", "\n",
"# Fields and data\n", "# Fields and data\n",
"sh = Sh.draw_sample()\n", "sh = Sh.draw_sample_with_dtype(dtype=np.float64)\n",
"noiseless_data=R(sh)\n", "noiseless_data=R(sh)\n",
"noise_amplitude = np.sqrt(0.2)\n", "noise_amplitude = np.sqrt(0.2)\n",
"N = ift.ScalingOperator(s_space, noise_amplitude**2)\n", "N = ift.ScalingOperator(s_space, noise_amplitude**2)\n",
...@@ -394,7 +394,7 @@ ...@@ -394,7 +394,7 @@
"# R is defined below\n", "# R is defined below\n",
"\n", "\n",
"# Fields\n", "# Fields\n",
"sh = Sh.draw_sample()\n", "sh = Sh.draw_sample_with_dtype(dtype=np.float64)\n",
"s = HT(sh)\n", "s = HT(sh)\n",
"n = ift.Field.from_random(domain=s_space, random_type='normal',\n", "n = ift.Field.from_random(domain=s_space, random_type='normal',\n",
" std=noise_amplitude, mean=0)" " std=noise_amplitude, mean=0)"
...@@ -471,7 +471,7 @@ ...@@ -471,7 +471,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 200)" "m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 200, np.float64)"
] ]
}, },
{ {
...@@ -571,7 +571,7 @@ ...@@ -571,7 +571,7 @@
"N = ift.ScalingOperator(s_space, sigma2)\n", "N = ift.ScalingOperator(s_space, sigma2)\n",
"\n", "\n",
"# Fields and data\n", "# Fields and data\n",
"sh = Sh.draw_sample()\n", "sh = Sh.draw_sample_with_dtype(dtype=np.float64)\n",
"n = ift.Field.from_random(domain=s_space, random_type='normal',\n", "n = ift.Field.from_random(domain=s_space, random_type='normal',\n",
" std=np.sqrt(sigma2), mean=0)\n", " std=np.sqrt(sigma2), mean=0)\n",
"\n", "\n",
...@@ -598,7 +598,7 @@ ...@@ -598,7 +598,7 @@
"m = D(j)\n", "m = D(j)\n",
"\n", "\n",
"# Uncertainty\n", "# Uncertainty\n",
"m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 20)\n", "m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 20, np.float64)\n",
"\n", "\n",
"# Get data\n", "# Get data\n",
"s_data = HT(sh).val\n", "s_data = HT(sh).val\n",
...@@ -709,8 +709,15 @@ ...@@ -709,8 +709,15 @@
"\n", "\n",
"https://gitlab.mpcdf.mpg.de/ift/NIFTy\n", "https://gitlab.mpcdf.mpg.de/ift/NIFTy\n",
"\n", "\n",
"NIFTy v5 **more or less stable!**" "NIFTy v6 **more or less stable!**"
] ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
} }
], ],
"metadata": { "metadata": {
...@@ -730,7 +737,7 @@ ...@@ -730,7 +737,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.7.5" "version": "3.8.2"
} }
}, },
"nbformat": 4, "nbformat": 4,
......
...@@ -40,7 +40,7 @@ def make_checkerboard_mask(position_space): ...@@ -40,7 +40,7 @@ def make_checkerboard_mask(position_space):
def make_random_mask(): def make_random_mask():
# Random mask for spherical mode # Random mask for spherical mode
mask = ift.from_random('pm1', position_space) mask = ift.from_random(position_space, 'pm1')
mask = (mask + 1)/2 mask = (mask + 1)/2
return mask.val return mask.val
...@@ -114,8 +114,8 @@ if __name__ == '__main__': ...@@ -114,8 +114,8 @@ if __name__ == '__main__':
N = ift.ScalingOperator(data_space, noise) N = ift.ScalingOperator(data_space, noise)
# Create mock data # Create mock data
MOCK_SIGNAL = S.draw_sample() MOCK_SIGNAL = S.draw_sample_with_dtype(dtype=np.float64)
MOCK_NOISE = N.draw_sample() MOCK_NOISE = N.draw_sample_with_dtype(dtype=np.float64)
data = R(MOCK_SIGNAL) + MOCK_NOISE data = R(MOCK_SIGNAL) + MOCK_NOISE
# Build inverse propagator D and information source j # Build inverse propagator D and information source j
......
...@@ -90,7 +90,7 @@ if __name__ == '__main__': ...@@ -90,7 +90,7 @@ if __name__ == '__main__':
# Generate mock data and define likelihood operator # Generate mock data and define likelihood operator
d_space = R.target[0] d_space = R.target[0]
lamb = R(sky) lamb = R(sky)
mock_position = ift.from_random('normal', domain) mock_position = ift.from_random(domain, 'normal')
data = lamb(mock_position) data = lamb(mock_position)
data = ift.random.current_rng().poisson(data.val.astype(np.float64)) data = ift.random.current_rng().poisson(data.val.astype(np.float64))
data = ift.Field.from_raw(d_space, data) data = ift.Field.from_raw(d_space, data)
...@@ -103,7 +103,7 @@ if __name__ == '__main__': ...@@ -103,7 +103,7 @@ if __name__ == '__main__':
# Compute MAP solution by minimizing the information Hamiltonian # Compute MAP solution by minimizing the information Hamiltonian
H = ift.StandardHamiltonian(likelihood) H = ift.StandardHamiltonian(likelihood)
initial_position = ift.from_random('normal', domain) initial_position = ift.from_random(domain, 'normal')
H = ift.EnergyAdapter(initial_position, H, want_metric=True) H = ift.EnergyAdapter(initial_position, H, want_metric=True)
H, convergence = minimizer(H) H, convergence = minimizer(H)
......
...@@ -98,8 +98,8 @@ if __name__ == '__main__': ...@@ -98,8 +98,8 @@ if __name__ == '__main__':
N = ift.ScalingOperator(data_space, noise) N = ift.ScalingOperator(data_space, noise)
# Generate mock signal and data # Generate mock signal and data
mock_position = ift.from_random('normal', signal_response.domain) mock_position = ift.from_random(signal_response.domain, 'normal')
data = signal_response(mock_position) + N.draw_sample() data = signal_response(mock_position) + N.draw_sample_with_dtype(dtype=np.float64)
# Minimization parameters # Minimization parameters
ic_sampling = ift.AbsDeltaEnergyController( ic_sampling = ift.AbsDeltaEnergyController(
......
...@@ -97,8 +97,8 @@ if __name__ == '__main__': ...@@ -97,8 +97,8 @@ if __name__ == '__main__':
N = ift.ScalingOperator(data_space, noise) N = ift.ScalingOperator(data_space, noise)
# Generate mock signal and data # Generate mock signal and data
mock_position = ift.from_random('normal', signal_response.domain) mock_position = ift.from_random(signal_response.domain, 'normal')
data = signal_response(mock_position) + N.draw_sample() data = signal_response(mock_position) + N.draw_sample_with_dtype(dtype=np.float64)
plot = ift.Plot() plot = ift.Plot()
plot.add(signal(mock_position), title='Ground Truth') plot.add(signal(mock_position), title='Ground Truth')
...@@ -114,7 +114,9 @@ if __name__ == '__main__': ...@@ -114,7 +114,9 @@ if __name__ == '__main__':
ic_newton = ift.AbsDeltaEnergyController(name='Newton', ic_newton = ift.AbsDeltaEnergyController(name='Newton',
deltaE=0.01, deltaE=0.01,
iteration_limit=35) iteration_limit=35)
minimizer = ift.NewtonCG(ic_newton) ic_sampling.enable_logging()
ic_newton.enable_logging()
minimizer = ift.NewtonCG(ic_newton, activate_logging=True)
## number of samples used to estimate the KL ## number of samples used to estimate the KL
N_samples = 20 N_samples = 20
...@@ -143,10 +145,15 @@ if __name__ == '__main__': ...@@ -143,10 +145,15 @@ if __name__ == '__main__':
plot.add([A2.force(KL.position), plot.add([A2.force(KL.position),
A2.force(mock_position)], A2.force(mock_position)],
title="power2") title="power2")
plot.output(nx=2, plot.add((ic_newton.history, ic_sampling.history,
minimizer.inversion_history),
label=['KL', 'Sampling', 'Newton inversion'],
title='Cumulative energies', s=[None, None, 1],
alpha=[None, 0.2, None])
plot.output(nx=3,
ny=2, ny=2,
ysize=10, ysize=10,
xsize=10, xsize=15,
name=filename.format("loop_{:02d}".format(i))) name=filename.format("loop_{:02d}".format(i)))
# Done, draw posterior samples # Done, draw posterior samples
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
# Authors: Reimar Leike, Philipp Arras
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
###############################################################################
# Metric Gaussian Variational Inference (MGVI)
#
# This script demonstrates how MGVI works for an inference problem with only
# two real quantities of interest. This enables us to plot the posterior
# probability density as two-dimensional plot. The posterior samples generated
# by MGVI are contrasted with the maximum-a-posterior (MAP) solution together
# with samples drawn with the Laplace method. This method uses the local
# curvature at the MAP solution as inverse covariance of a Gaussian probability
# density.
###############################################################################
import numpy as np
import pylab as plt
from matplotlib.colors import LogNorm
import nifty6 as ift
if __name__ == '__main__':
dom = ift.UnstructuredDomain(1)
scale = 10
a = ift.FieldAdapter(dom, 'a')
b = ift.FieldAdapter(dom, 'b')
lh = (a.adjoint @ a).scale(scale) + (b.adjoint @ b).scale(-1.35*2).exp()
lh = ift.VariableCovarianceGaussianEnergy(dom, 'a', 'b', np.float64) @ lh
icsamp = ift.AbsDeltaEnergyController(deltaE=0.1, iteration_limit=2)
ham = ift.StandardHamiltonian(lh, icsamp)
x_limits = [-8/scale, 8/scale]
x_limits_scaled = [-8, 8]
y_limits = [-4, 4]
x = np.linspace(*x_limits, num=401)
y = np.linspace(*y_limits, num=401)
xx, yy = np.meshgrid(x, y, indexing='ij')
def np_ham(x, y):
prior = x**2 + y**2
mean = x*scale
lcov = 1.35*2*y
lh = .5*(mean**2*np.exp(-lcov) + lcov)
return lh + prior
z = np.exp(-1*np_ham(xx, yy))
plt.plot(y, np.sum(z, axis=0))
plt.xlabel('y')
plt.ylabel('unnormalized pdf')
plt.title('Marginal density')
plt.pause(2.0)
plt.close()
plt.plot(x*scale, np.sum(z, axis=1))
plt.xlabel('x')
plt.ylabel('unnormalized pdf')
plt.title('Marginal density')
plt.pause(2.0)
plt.close()
pos = ift.from_random(ham.domain, 'normal')
MAP = ift.EnergyAdapter(pos, ham, want_metric=True)
minimizer = ift.NewtonCG(
ift.GradientNormController(iteration_limit=20, name='Mini'))
MAP, _ = minimizer(MAP)
map_xs, map_ys = [], []
for ii in range(10):
samp = (MAP.metric.draw_sample(from_inverse=True) + MAP.position).val
map_xs.append(samp['a'])
map_ys.append(samp['b'])
minimizer = ift.NewtonCG(
ift.GradientNormController(iteration_limit=2, name='Mini'))
pos = ift.from_random(ham.domain, 'normal')
plt.figure(figsize=[12, 8])
for ii in range(15):
if ii % 3 == 0:
mgkl = ift.MetricGaussianKL(pos, ham, 40)
plt.cla()
plt.imshow(z.T, origin='lower', norm=LogNorm(), vmin=1e-3,
vmax=np.max(z), cmap='gist_earth_r',
extent=x_limits_scaled + y_limits)
if ii == 0:
cbar = plt.colorbar()
cbar.ax.set_ylabel('pdf')
xs, ys = [], []
for samp in mgkl.samples:
samp = (samp + pos).val
xs.append(samp['a'])
ys.append(samp['b'])
plt.scatter(np.array(xs)*scale, np.array(ys), label='MGVI samples')
plt.scatter(pos.val['a']*scale, pos.val['b'], label='MGVI latent mean')
plt.scatter(np.array(map_xs)*scale, np.array(map_ys),
label='Laplace samples')
plt.scatter(MAP.position.val['a']*scale, MAP.position.val['b'],
label='Maximum a posterior solution')
plt.legend()
plt.draw()
plt.pause(1.0)
mgkl, _ = minimizer(mgkl)
pos = mgkl.position
ift.logger.info('Finished')
# Uncomment the following line in order to leave the plots open
# plt.show()
rm -rf docs/build docs/source/mod rm -rf docs/build docs/source/mod
sphinx-apidoc -e -o docs/source/mod nifty6 EXCLUDE="nifty6/logger.py nifty6/git_version.py"
sphinx-apidoc -e -o docs/source/mod nifty6 ${EXCLUDE}
sphinx-build -b html docs/source/ docs/build/ sphinx-build -b html docs/source/ docs/build/
...@@ -7,7 +7,7 @@ NIFTy-related publications ...@@ -7,7 +7,7 @@ NIFTy-related publications
author={Arras, Philipp and Baltac, Mihai and Ensslin, Torsten A and Frank, Philipp and Hutschenreuter, Sebastian and Knollmueller, Jakob and Leike, Reimar and Newrzella, Max-Niklas and Platz, Lukas and Reinecke, Martin and others}, author={Arras, Philipp and Baltac, Mihai and Ensslin, Torsten A and Frank, Philipp and Hutschenreuter, Sebastian and Knollmueller, Jakob and Leike, Reimar and Newrzella, Max-Niklas and Platz, Lukas and Reinecke, Martin and others},
journal={Astrophysics Source Code Library}, journal={Astrophysics Source Code Library},
year={2019} year={2019}
} }
@software{nifty, @software{nifty,
author = {{Martin Reinecke, Theo Steininger, Marco Selig}}, author = {{Martin Reinecke, Theo Steininger, Marco Selig}},
...@@ -15,7 +15,7 @@ NIFTy-related publications ...@@ -15,7 +15,7 @@ NIFTy-related publications
url = {https://gitlab.mpcdf.mpg.de/ift/NIFTy}, url = {https://gitlab.mpcdf.mpg.de/ift/NIFTy},
version = {nifty6}, version = {nifty6},
date = {2018-04-05}, date = {2018-04-05},
} }
@article{2013A&A...554A..26S, @article{2013A&A...554A..26S,
author = {{Selig}, M. and {Bell}, M.~R. and {Junklewitz}, H. and {Oppermann}, N. and {Reinecke}, M. and {Greiner}, M. and {Pachajoa}, C. and {En{\ss}lin}, T.~A.}, author = {{Selig}, M. and {Bell}, M.~R. and {Junklewitz}, H. and {Oppermann}, N. and {Reinecke}, M. and {Greiner}, M. and {Pachajoa}, C. and {En{\ss}lin}, T.~A.},
...@@ -33,7 +33,7 @@ NIFTy-related publications ...@@ -33,7 +33,7 @@ NIFTy-related publications
doi = {10.1051/0004-6361/201321236}, doi = {10.1051/0004-6361/201321236},
adsurl = {http://cdsads.u-strasbg.fr/abs/2013A%26A...554A..26S}, adsurl = {http://cdsads.u-strasbg.fr/abs/2013A%26A...554A..26S},
adsnote = {Provided by the SAO/NASA Astrophysics Data System} adsnote = {Provided by the SAO/NASA Astrophysics Data System}
} }
@article{2017arXiv170801073S, @article{2017arXiv170801073S,
author = {{Steininger}, T. and {Dixit}, J. and {Frank}, P. and {Greiner}, M. and {Hutschenreuter}, S. and {Knollm{\"u}ller}, J. and {Leike}, R. and {Porqueres}, N. and {Pumpe}, D. and {Reinecke}, M. and {{\v S}raml}, M. and {Varady}, C. and {En{\ss}lin}, T.}, author = {{Steininger}, T. and {Dixit}, J. and {Frank}, P. and {Greiner}, M. and {Hutschenreuter}, S. and {Knollm{\"u}ller}, J. and {Leike}, R. and {Porqueres}, N. and {Pumpe}, D. and {Reinecke}, M. and {{\v S}raml}, M. and {Varady}, C. and {En{\ss}lin}, T.},
...@@ -47,4 +47,4 @@ NIFTy-related publications ...@@ -47,4 +47,4 @@ NIFTy-related publications
month = aug, month = aug,
adsurl = {http://cdsads.u-strasbg.fr/abs/2017arXiv170801073S}, adsurl = {http://cdsads.u-strasbg.fr/abs/2017arXiv170801073S},
adsnote = {Provided by the SAO/NASA Astrophysics Data System} adsnote = {Provided by the SAO/NASA Astrophysics Data System}
} }
...@@ -25,7 +25,8 @@ from .operators.adder import Adder ...@@ -25,7 +25,8 @@ from .operators.adder import Adder
from .operators.diagonal_operator import DiagonalOperator from .operators.diagonal_operator import DiagonalOperator
from .operators.distributors import DOFDistributor, PowerDistributor from .operators.distributors import DOFDistributor, PowerDistributor
from .operators.domain_tuple_field_inserter import DomainTupleFieldInserter from .operators.domain_tuple_field_inserter import DomainTupleFieldInserter
from .operators.contraction_operator import ContractionOperator from .operators.einsum import LinearEinsum, MultiLinearEinsum
from .operators.contraction_operator import ContractionOperator, IntegrationOperator
from .operators.linear_interpolation import LinearInterpolator from .operators.linear_interpolation import LinearInterpolator
from .operators.endomorphic_operator import EndomorphicOperator from .operators.endomorphic_operator import EndomorphicOperator
from .operators.harmonic_operators import ( from .operators.harmonic_operators import (
...@@ -35,16 +36,16 @@ from .operators.field_zero_padder import FieldZeroPadder ...@@ -35,16 +36,16 @@ from .operators.field_zero_padder import FieldZeroPadder
from .operators.inversion_enabler import InversionEnabler from .operators.inversion_enabler import InversionEnabler
from .operators.mask_operator import MaskOperator from .operators.mask_operator import MaskOperator
from .operators.regridding_operator import RegriddingOperator from .operators.regridding_operator import RegriddingOperator
from .operators.sampling_enabler import SamplingEnabler from .operators.sampling_enabler import SamplingEnabler, SamplingDtypeSetter
from .operators.sandwich_operator import SandwichOperator from .operators.sandwich_operator import SandwichOperator
from .operators.scaling_operator import ScalingOperator from .operators.scaling_operator import ScalingOperator
from .operators.selection_operators import SliceOperator, SplitOperator
from .operators.block_diagonal_operator import BlockDiagonalOperator from .operators.block_diagonal_operator import BlockDiagonalOperator
from .operators.outer_product_operator import OuterProduct from .operators.outer_product_operator import OuterProduct
from .operators.simple_linear_operators import ( from .operators.simple_linear_operators import (
VdotOperator, ConjugationOperator, Realizer, VdotOperator, ConjugationOperator, Realizer, FieldAdapter, ducktape,
FieldAdapter, ducktape, GeometryRemover, NullOperator, GeometryRemover, NullOperator, PartialExtractor, Imaginizer)
MatrixProductOperator, PartialExtractor, SwitchSpacesOperator, from .operators.matrix_product_operator import MatrixProductOperator
Imaginizer)
from .operators.value_inserter import ValueInserter from .operators.value_inserter import ValueInserter
from .operators.energy_operators import ( from .operators.energy_operators import (
EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood, EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood,
...@@ -97,5 +98,7 @@ from .linearization import Linearization ...@@ -97,5 +98,7 @@ from .linearization import Linearization
from .operator_spectrum import operator_spectrum from .operator_spectrum import operator_spectrum
from .operator_tree_optimiser import optimise_operator
# We deliberately don't set __all__ here, because we don't want people to do a # We deliberately don't set __all__ here, because we don't want people to do a
# "from nifty6 import *"; that would swamp the global namespace. # "from nifty6 import *"; that would swamp the global namespace.
...@@ -42,8 +42,8 @@ def _adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol, ...@@ -42,8 +42,8 @@ def _adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol,
needed_cap = op.TIMES | op.ADJOINT_TIMES needed_cap = op.TIMES | op.ADJOINT_TIMES
if (