Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
N
NIFTy
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container registry
Model registry
Monitor
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
ift
NIFTy
Commits
c58224d8
Commit
c58224d8
authored
1 year ago
by
Gordian Edenhofer
Browse files
Options
Downloads
Patches
Plain Diff
stats_distributions: Return trafos suitable for within jit
parent
5c40e9ee
Branches
Branches containing commit
Tags
Tags containing commit
1 merge request
!904
Re likelihood
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
src/re/num/stats_distributions.py
+44
-43
44 additions, 43 deletions
src/re/num/stats_distributions.py
with
44 additions
and
43 deletions
src/re/num/stats_distributions.py
+
44
−
43
View file @
c58224d8
...
...
@@ -2,7 +2,7 @@ from functools import partial
from
typing
import
Callable
,
Optional
from
jax
import
numpy
as
jnp
from
jax.tree_util
import
tree_map
from
jax.tree_util
import
Partial
,
tree_map
from
..tree_math.vector_math
import
any
as
tree_any
...
...
@@ -12,7 +12,17 @@ log = partial(tree_map, jnp.log)
log1p
=
partial
(
tree_map
,
jnp
.
log1p
)
def
laplace_prior
(
alpha
)
->
Callable
:
def
_standard_to_laplace
(
xi
,
*
,
alpha
):
from
jax.scipy.stats
import
norm
norm_logcdf
=
partial
(
tree_map
,
norm
.
logcdf
)
res
=
(
xi
<
0
)
*
(
norm_logcdf
(
xi
)
+
jnp
.
log
(
2
))
res
-=
(
xi
>
0
)
*
(
norm_logcdf
(
-
xi
)
+
jnp
.
log
(
2
))
return
res
*
alpha
def
laplace_prior
(
alpha
)
->
Partial
:
"""
Takes random normal samples and outputs samples distributed according to
...
...
@@ -20,34 +30,28 @@ def laplace_prior(alpha) -> Callable:
P(x|a) = exp(-|x|/a)/a/2
"""
from
jax.scipy.stats
import
norm
norm_logcdf
=
partial
(
tree_map
,
norm
.
logcdf
)
return
Partial
(
_standard_to_laplace
,
alpha
=
alpha
)
def
standard_to_laplace
(
xi
):
res
=
(
xi
<
0
)
*
(
norm_logcdf
(
xi
)
+
jnp
.
log
(
2
))
res
-=
(
xi
>
0
)
*
(
norm_logcdf
(
-
xi
)
+
jnp
.
log
(
2
))
return
res
*
alpha
return
standard_to_laplace
def
_standard_to_normal
(
xi
,
*
,
mean
,
std
):
return
mean
+
std
*
xi
def
normal_prior
(
mean
,
std
)
->
Callable
:
def
normal_prior
(
mean
,
std
)
->
Partial
:
"""
Match standard normally distributed random variables to non-standard
variables.
"""
def
standard_to_normal
(
xi
):
return
mean
+
std
*
xi
return
Partial
(
_standard_to_normal
,
mean
=
mean
,
std
=
std
)
return
standard_to_normal
def
_normal_to_standard
(
y
,
*
,
mean
,
std
):
return
(
y
-
mean
)
/
std
def
normal_invprior
(
mean
,
std
)
->
Callable
:
"""
Get the inverse transform to `normal_prior`.
"""
def
normal_to_standard
(
y
):
return
(
y
-
mean
)
/
std
return
normal_to_standard
def
normal_invprior
(
mean
,
std
)
->
Partial
:
"""
Get the inverse transform to `normal_prior`.
"""
return
Partial
(
_normal_to_standard
,
mean
=
mean
,
std
=
std
)
def
lognormal_moments
(
mean
,
std
):
...
...
@@ -64,7 +68,11 @@ def lognormal_moments(mean, std):
return
logmean
,
logstd
def
lognormal_prior
(
mean
,
std
,
*
,
_log_mean
=
None
,
_log_std
=
None
)
->
Callable
:
def
_standard_to_lognormal
(
xi
,
*
,
log_mean
,
log_std
):
return
exp
(
_standard_to_normal
(
xi
,
mean
=
log_mean
,
std
=
log_std
))
def
lognormal_prior
(
mean
,
std
,
*
,
_log_mean
=
None
,
_log_std
=
None
)
->
Partial
:
"""
Moment-match standard normally distributed random variables to log-space
Takes random normal samples and outputs samples distributed according to
...
...
@@ -75,31 +83,29 @@ def lognormal_prior(mean, std, *, _log_mean=None, _log_std=None) -> Callable:
such that the mean and standard deviation of the distribution matches the
specified values.
"""
if
_log_mean
is
not
None
and
_log_std
is
not
None
:
standard_to_normal
=
normal_prior
(
_log_mean
,
_log_std
)
else
:
standard_to_normal
=
normal_prior
(
*
lognormal_moments
(
mean
,
std
))
if
_log_mean
is
None
and
_log_std
is
None
:
_log_mean
,
_log_std
=
lognormal_moments
(
mean
,
std
)
return
Partial
(
_standard_to_lognormal
,
log_mean
=
_log_mean
,
log_std
=
_log_std
)
def
standard_to_lognormal
(
xi
):
return
exp
(
standard_to_normal
(
xi
))
return
standard_to_lognormal
def
_lognormal_to_standard
(
y
,
*
,
log_mean
,
log_std
):
return
_normal_to_standard
(
log
(
y
),
mean
=
log_mean
,
std
=
log_std
)
def
lognormal_invprior
(
mean
,
std
,
*
,
_log_mean
=
None
,
_log_std
=
None
)
->
Callable
:
def
lognormal_invprior
(
mean
,
std
,
*
,
_log_mean
=
None
,
_log_std
=
None
)
->
Partial
:
"""
Get the inverse transform to `lognormal_prior`.
"""
if
_log_mean
is
not
None
and
_log_std
is
not
None
:
ln_m
,
ln
_std
=
_
log
_
mean
,
_log_
std
else
:
ln_m
,
ln_std
=
lognormal_moments
(
mean
,
std
)
if
_log_mean
is
None
and
_log_std
is
None
:
_log_mean
,
_log
_std
=
log
normal_moments
(
mean
,
std
)
return
Partial
(
_lognormal_to_standard
,
log_mean
=
_log_mean
,
log_std
=
_log_std
)
def
lognormal_to_standard
(
y
):
return
(
log
(
y
)
-
ln_m
)
/
ln_std
def
_standard_to_uniform
(
xi
,
*
,
a_min
,
scale
):
from
jax.scipy.stats
import
norm
return
lognormal_to_standard
return
a_min
+
scale
*
tree_map
(
norm
.
cdf
,
xi
)
def
uniform_prior
(
a_min
=
0.
,
a_max
=
1.
)
->
Callable
:
def
uniform_prior
(
a_min
=
0.
,
a_max
=
1.
)
->
Partial
:
"""
Transform a standard normal into a uniform distribution.
Parameters
...
...
@@ -111,18 +117,13 @@ def uniform_prior(a_min=0., a_max=1.) -> Callable:
"""
from
jax.scipy.stats
import
norm
norm_cdf
=
partial
(
tree_map
,
norm
.
cdf
)
scale
=
a_max
-
a_min
if
isinstance
(
a_min
,
float
)
and
isinstance
(
a_max
,
float
)
and
a_min
==
0.
and
a_max
==
1.
:
return
norm
_
cdf
return
Partial
(
partial
(
tree_map
,
norm
.
cdf
))
def
standard_to_uniform
(
xi
):
return
a_min
+
scale
*
norm_cdf
(
xi
)
return
standard_to_uniform
scale
=
a_max
-
a_min
return
Partial
(
_standard_to_uniform
,
a_min
=
a_min
,
scale
=
scale
)
def
interpolator
(
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment