ift issues: https://gitlab.mpcdf.mpg.de/groups/ift/-/issues (as of 2024-03-22T09:21:40Z)

#405 Memory leak in re.optimize_kl
https://gitlab.mpcdf.mpg.de/ift/nifty/-/issues/405 (2024-03-22, Simon Straehnz)

I recently tried to do some batch reconstructions again, and apparently there is still a memory leak in `re.optimize_kl` when running multiple reconstructions in a loop. I have attached a modified version of the demo `0_intro.py` that shows the problem. The leak in this example amounts to ca. 233 MiB per iteration, even though dof * samples * 8 bytes (float64) amounts to only 613 KiB.
Edit: it actually is "only" 133 MiB, see below
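One way to quantify a per-iteration leak like this is to snapshot traced allocations between loop iterations; below is a minimal sketch using only the standard library, with a placeholder function standing in for the actual `re.optimize_kl` call from the attached script:

```python
import gc
import tracemalloc

def run_reconstruction():
    # Placeholder for one reconstruction; in the real test this would be
    # a call to jft.optimize_kl(...) as in the attached 0_intro_memtest.py.
    return [float(i) for i in range(1000)]

tracemalloc.start()
growth = []
baseline = None
for i in range(3):
    run_reconstruction()
    gc.collect()  # rule out garbage that merely has not been collected yet
    current, _peak = tracemalloc.get_traced_memory()
    if baseline is None:
        baseline = current
    growth.append(current - baseline)
tracemalloc.stop()
print(growth)  # steadily increasing values would indicate a leak
```

Since the placeholder does not retain references, the growth values stay near zero here; with a leaking `optimize_kl` they would climb by roughly the leaked amount per iteration.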
[0_intro_memtest.py](/uploads/3f6802f83e3ed20e3605c6a8604c2130/0_intro_memtest.py)

#404 Gauss-Markov on GPU
https://gitlab.mpcdf.mpg.de/ift/nifty/-/issues/404 (2024-03-05, Philipp Haim)

The loop in the general GM process introduced in [this](%) commit leads to significant performance degradation of the correlated field on the GPU. Changing the loop to a `cumsum` (at least for the correlated field) should fix this issue.

#403 NIFTy operator in NIFTy.re
https://gitlab.mpcdf.mpg.de/ift/nifty/-/issues/403 (2024-02-27, Jakob Roth)

In the classic NIFTy we have the `JaxOperator` interfacing classic NIFTy with jax. I don't think it was extensively used, but it was certainly handy for some projects.
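Going back to the Gauss-Markov issue (#404) above, the proposed loop-to-`cumsum` change can be sketched on a toy stand-in (a plain random walk, the simplest Gauss-Markov process); this is not the actual NIFTy implementation, just an illustration of the equivalence:

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
xi = random.normal(key, (200,))  # i.i.d. increments driving the process

# Loop formulation: unrolls into a long chain of ops and serializes
# poorly on the GPU.
def process_loop(xi):
    vals = [xi[0]]
    for i in range(1, xi.shape[0]):
        vals.append(vals[-1] + xi[i])
    return jnp.stack(vals)

# Vectorized formulation: a single cumulative sum, one fused kernel.
def process_cumsum(xi):
    return jnp.cumsum(xi)

assert jnp.allclose(process_loop(xi), process_cumsum(xi), atol=1e-3)
```

The two agree up to floating-point summation order; the `cumsum` version avoids the per-step kernel launches that hurt GPU performance.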
The new version of jax_linop (https://github.com/NIFTy-PPL/jax_linop) now also supports binding nonlinear functions to jax. Thus we could also build the reverse binding: a `NiftyOperator` which takes a classic NIFTy operator and returns a JAX primitive. Of course, this operator won't be as performant as a native JAX implementation, but I think it would still be very handy, especially for projects that transition their code base to JAX, since with such an operator it would be possible to gradually move to JAX while keeping some legacy NIFTy code in the beginning.
Once jax_linop is stable, I could implement such an operator. @pfrank @gedenhof, what do you think?

#402 Non-jitted model evaluation in optimize_kl
https://gitlab.mpcdf.mpg.de/ift/nifty/-/issues/402 (2024-02-18, Philipp Haim)

I have a use-case where I want to evaluate a part of my model on a GPU, while having the initial and final parts of the model on the CPU. To my understanding, such a function cannot be jitted (at the moment). However, the conjugate gradient implementation jits the model evaluation both explicitly and implicitly (through the call of jax.lax.while_loop). The model is then executed fully on the CPU. Trying to avoid this compilation by removing [this jit-call in conjugate_gradient](https://gitlab.mpcdf.mpg.de/ift/nifty/-/blob/NIFTy_8/src/re/conjugate_gradient.py?ref_type=heads#L242) leads to an error in the following `while_loop` [here](https://gitlab.mpcdf.mpg.de/ift/nifty/-/blob/NIFTy_8/src/re/conjugate_gradient.py?ref_type=heads#L382) due to some of the arrays being on different devices.
As a sanity check, I have tested `jax.linearize` and `jax.linear_transpose` on the model, which works without any (apparent) issues. So I have hopes that `optimize_kl` could work with such a model if all jit-compilations of the full model call were avoided.
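That sanity check can be reproduced on a toy model; the `model` below is a hypothetical stand-in for the mixed-device model, and everything runs eagerly, i.e. without any `jit`:

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model standing in for the mixed-device model;
# nothing here is jitted.
def model(x):
    return jnp.tanh(x) * 2.0

x0 = jnp.linspace(-1.0, 1.0, 5)

# Linearize the model around x0: y0 is model(x0), jvp applies the Jacobian.
y0, jvp = jax.linearize(model, x0)

# Transpose the linearization to obtain the adjoint (VJP) without tracing
# the full model through jit.
vjp = jax.linear_transpose(jvp, x0)

tangent = jnp.ones_like(x0)
cotangent = jnp.ones_like(y0)
print(jvp(tangent).shape, vjp(cotangent)[0].shape)  # both have shape (5,)
```

Both calls trace the function but do not compile it, which is why they succeed even where whole-function `jit` would fail on mixed device placement.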
If I understand [this issue](https://github.com/google/jax/discussions/17040) correctly, JAX is working on an API that would allow this kind of behavior inside a jitted function; however, I couldn't find any current information on that feature. So while this might eventually resolve the issue, it is unclear when it would be available.
As there are already flags like `kl_jit` and `residual_jit`, would it be possible to add a flag that prevents any jit-compilation within the optimization routine? Are there any other places where an implicit jit-compilation takes place that I am missing?
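As a possibly relevant aside (my addition, not from the issue): JAX already ships a global escape hatch, `jax.disable_jit()`, under which jitted functions and `lax` control flow such as `while_loop` run op by op instead of being compiled. It is global and slow, so it is not the fine-grained flag asked for here, but it illustrates the behavior:

```python
import jax
import jax.numpy as jnp

@jax.jit
def double(x):
    return 2.0 * x

# Inside disable_jit, jitted functions (and lax control-flow primitives)
# execute eagerly, op by op, sidestepping whole-function compilation.
with jax.disable_jit():
    y = double(jnp.arange(3.0))

print(y)  # [0. 2. 4.]
```

An `optimize_kl`-level flag would presumably need to thread something like this (or simply skip its internal `jit` calls) through the sampling and minimization routines.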
If it would help, I could provide a minimal working example for this issue.

#401 Implement Wiener Filter method
https://gitlab.mpcdf.mpg.de/ift/nifty/-/issues/401 (2024-02-01, Gordian Edenhofer)

Implement a method solving the Wiener Filter for a given likelihood. To generalize to non-Gaussian likelihoods, make it optionally depend on the position. This could be implemented by simply wrapping `jft.draw_linear_residual`.

#400 ift.Plot breaks silently with arguments vmin, vmax and LogNorm
https://gitlab.mpcdf.mpg.de/ift/nifty/-/issues/400 (2024-03-15, Vincent Eberle)

When passing vmin and vmax outside of the LogNorm in plain matplotlib, you get the following error message:
e.g.
```Python
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import nifty8 as ift

# `signal` is the sky operator from the surrounding reconstruction script
mock_signal = signal(ift.from_random(signal.domain)).val
plt.imshow(mock_signal.T, origin="lower", vmin=1e-9, vmax=1e-5, norm=LogNorm())
plt.colorbar()
plt.show()
```
```
Passing parameters norm and vmin/vmax simultaneously is not supported. Please pass vmin/vmax directly to the norm when creating it.
```
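As the matplotlib error suggests, the working variant passes the limits to the norm itself; a minimal sketch on random positive data (standing in for the mock signal), using a headless backend:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm

rng = np.random.default_rng(42)
data = 10.0 ** rng.uniform(-9, -5, size=(32, 32))  # positive mock data

# Pass vmin/vmax to the norm, not to imshow, so the two do not clash.
im = plt.imshow(data.T, origin="lower", norm=LogNorm(vmin=1e-9, vmax=1e-5))
plt.colorbar(im)
plt.savefig("mock_signal.png")
```

Presumably `ift.Plot` would need to forward its `vmin`/`vmax` arguments into the user-supplied norm in the same way, or raise the same error matplotlib does, instead of silently producing a wrong plot.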
But using ift.Plot():
e.g.
```
plot = ift.Plot()
plot.add(signal(ift.from_random(signal.domain)).val, vmin=1e-9, vmax=1e-5, norm=LogNorm())
plot.output()
```
will display a weird plot:

![Screenshot_from_2024-01-29_15-31-09](/uploads/12bc10d66f37486a5078a652c6fb9464/Screenshot_from_2024-01-29_15-31-09.png)

#399 Sum of likelihoods with named domains operators
https://gitlab.mpcdf.mpg.de/ift/nifty/-/issues/399 (2024-01-09, Matteo Guardiani)

When summing two likelihoods applied to operators with different named domains, the summation does not work during minimization, because the metric update does not know how to correctly join multiple named inputs.
Minimal breaking example:
```
from jax import random
import nifty8.re as jft
import numpy as np
import jax
jax.config.update("jax_platform_name", "cpu")
seed = 42
key = random.PRNGKey(seed)
shape = (10, 10)
cf_zm = {"offset_mean": 0., "offset_std": (1e-3, 1e-4)}
cf_fl = {
"fluctuations": (1e-1, 5e-3),
"loglogavgslope": (-3., 1e-2),
"flexibility": (1e+0, 5e-1),
"asperity": (5e-1, 5e-2),
}
cfm = jft.CorrelatedFieldMaker("jcf_")
cfm.set_amplitude_total_offset(**cf_zm)
cfm.add_fluctuations(
shape,
distances=1. / shape[0],
**cf_fl,
prefix="",
non_parametric_kind="power",
)
jcf = cfm.finalize()
key, subkey = random.split(key)
pos = jft.random_like(subkey, jcf.domain)
noise_level = 0.2
datar = jcf(pos) + np.random.normal(0, 1, shape)*noise_level
dom_key = 'a'
m = jft.Model(
lambda x: {dom_key: jcf(x)},
domain=jcf.domain)
R = jft.Model(lambda x: x[dom_key], domain=m.target)
like1 = jft.Gaussian(datar, lambda x: 1/noise_level**2 * x) @ R
like2 = jft.Gaussian(datar, lambda x: 2/noise_level**2 * x) @ R
ll = (like1 + like2) @ m
pos = jft.random_like(key, jcf.domain)
n_iterations = 2
n_samples = 2
delta = 1e-3
absdelta = 1e-4
samples, _ = jft.optimize_kl(
ll,
jft.Vector(pos),
n_total_iterations=n_iterations,
n_samples=n_samples,
# Source for the stochasticity for sampling
key=key,
draw_linear_kwargs=dict(cg_name="SL",
cg_kwargs=dict(absdelta=absdelta / 10., maxiter=100)),
nonlinearly_update_kwargs=dict(
minimize_kwargs=dict(
name="SN",
xtol=delta,
cg_kwargs=dict(name=None),
maxiter=5,
)
),
kl_kwargs=dict(
minimize_kwargs=dict(
name="M", absdelta=absdelta, cg_kwargs=dict(name="MCG"), maxiter=35
)
),
sample_mode="nonlinear_resample",
resume=False)
```
gives the following error:
```
~/pro/python/nifty/nifty8/re/likelihood.py in joined_left_sqrt_metric(p, t, **pkw)
641 def joined_left_sqrt_metric(p, t, **pkw):
642 return (
--> 643 self.left_sqrt_metric(p, t[lkey], **pkw) +
644 other.left_sqrt_metric(p, t[rkey], **pkw)
645 )
TypeError: unsupported operand type(s) for +: 'dict' and 'dict'
```

#398 Use ducc0 in NIFTy.re when available
https://gitlab.mpcdf.mpg.de/ift/nifty/-/issues/398 (2023-12-19, Gordian Edenhofer)

Wrap ducc0 functionality using https://gitlab.mpcdf.mpg.de/jroth/extending-jax-and-nifty and use it by default for the Hartley transformation in the CFM.
Related to #371.
#397 NIFTy re demo
https://gitlab.mpcdf.mpg.de/ift/nifty/-/issues/397 (2023-12-08, Jakob Roth)

The nifty.re demo now sets point estimates to showcase this feature: https://gitlab.mpcdf.mpg.de/ift/nifty/-/blob/NIFTy_8/demos/nifty_re.py?ref_type=heads#L202
As the first new nifty.re users now copy this over to their own scripts without thinking about what it implies, I am wondering if we should remove it from the demo. @pfrank, @gedenhof, what do you think? Alternatively, we could also add a comment that in many applications one probably doesn't want to set point estimates, or at least would need to adapt the keys to the domain of the model.

#396 smap axis ordering
https://gitlab.mpcdf.mpg.de/ift/nifty/-/issues/396 (2023-12-08, Laurin Söding)

There seems to be an error with the axis ordering in smap for the corner case of multiple size-1 axes. Here is a minimal example that compares to vmap:
```
import nifty8.re as jft
import jax
import jax.numpy as jnp
def simple_function(x):
return x
a = jnp.arange(100)[..., jnp.newaxis, jnp.newaxis]
print(jft.smap(simple_function, in_axes=(2), out_axes=(1))(a).shape)
print(jax.vmap(simple_function, in_axes=(2), out_axes=(1))(a).shape)
# (1, 1, 100)
# (100, 1, 1)
```https://gitlab.mpcdf.mpg.de/ift/nifty/-/issues/395NIFTy.re: try to get JAXified ducc0 operators?2023-11-29T15:15:02ZMartin ReineckeNIFTy.re: try to get JAXified ducc0 operators?I think it might be really beneficial to get the linear operators provided by `ducc0` (SHT, NUFFT, wgridder) usable from JAX. I think I have made some small progress in understanding what is necessary, but I really need help from someone...I think it might be really beneficial to get the linear operators provided by `ducc0` (SHT, NUFFT, wgridder) usable from JAX. I think I have made some small progress in understanding what is necessary, but I really need help from someone more familiar with JAX and its custom calls.
Any volunteers? @gedenhof, @jroth, @pfrank?

#394 remove .envrc
https://gitlab.mpcdf.mpg.de/ift/nifty/-/issues/394 (2023-11-27, Vincent Eberle)

#393 Register `Likelihood` as PyTree
https://gitlab.mpcdf.mpg.de/ift/nifty/-/issues/393 (2024-01-21, Gordian Edenhofer)

If `Likelihood` were a registered PyTree, we could compile `draw_sample(likelihood, primals, *a, **k)` and the data would be traced instead of inlined (I think)!

#392 Testing of re.optimize_kl
https://gitlab.mpcdf.mpg.de/ift/nifty/-/issues/392 (2023-11-23, Philipp Frank)

Currently the `re.optimize_kl` and `re.OptimizeVI` features remain largely untested. Only the demos ensure that `optimize_kl` runs through as intended for one specific configuration.
We should at least add tests verifying that the different configurations produce the intended outcomes and that the update rules behave as expected.

#391 any check for lognormal_moments
https://gitlab.mpcdf.mpg.de/ift/nifty/-/issues/391 (2023-11-24, Julian Rüstig)

Inside stats_distributions, a check ensuring positivity of the mean and the std for the lognormal distribution throws an error if one supplies an array for the mean or the standard deviation.

#390 update likelihood sum domains
https://gitlab.mpcdf.mpg.de/ift/nifty/-/issues/390 (2023-11-25, Julian Rüstig)

When adding two likelihoods in nifty.re, the domains should be updated.

#389 Incorrect number of samples in optimize_kl
https://gitlab.mpcdf.mpg.de/ift/nifty/-/issues/389 (2023-11-27, Aymeric Galan)

After I updated my script to the latest changes of `re.optimize_kl`, it seems that when I provide a varying number of samples (i.e. `n_samples` is a callable), `optimize_kl()` only returns a number of samples corresponding to the first iteration defined in `n_samples`.
Not sure if it actually uses the right number of samples during the optimization, though (it seems to do so).

#387 remove flake.lock
https://gitlab.mpcdf.mpg.de/ift/nifty/-/issues/387 (2023-11-27, Vincent Eberle)

#386 remove flake.nix
https://gitlab.mpcdf.mpg.de/ift/nifty/-/issues/386 (2023-11-27, Vincent Eberle)

#385 update setup.py
https://gitlab.mpcdf.mpg.de/ift/nifty/-/issues/385 (2023-11-17, Vincent Eberle)