Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
380b5273
Commit
380b5273
authored
Mar 11, 2020
by
Philipp Arras
Browse files
Restructuring
parent
e627438f
Pipeline
#70652
passed with stages
in 17 minutes and 55 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty6/library/special_distributions.py
View file @
380b5273
...
...
@@ -20,12 +20,18 @@ from scipy.stats import invgamma, norm
from
..
import
Adder
from
..domain_tuple
import
DomainTuple
from
..domains.unstructured_domain
import
UnstructuredDomain
from
..field
import
Field
from
..linearization
import
Linearization
from
..operators.operator
import
Operator
from
..sugar
import
makeOp
def
_f_on_np
(
f
,
arr
):
fld
=
Field
.
from_raw
(
UnstructuredDomain
(
arr
.
shape
),
arr
)
return
f
(
fld
).
val
class
_InterpolationOperator
(
Operator
):
"""
Calculates a function pointwise on a field by interpolation.
...
...
@@ -38,44 +44,34 @@ class _InterpolationOperator(Operator):
The domain on which the field shall be defined. This is at the same
time the domain and the target of the operator.
func : function
The function which is applied on the field.
The function which is applied on the field. Assumed to act on numpy
arrays.
xmin : float
The smallest value for which func will be evaluated.
xmax : float
The largest value for which func will be evaluated.
delta : float
Distance between sampling points for linear interpolation.
table_func : {'None', 'exp', 'log', 'power'}
exponent : float
This is only used by the 'power' table_func.
table_func : function
Non-linear function applied to table in order to transform the table
to a more linear space. Assumed to act on `Linearization`s, optional.
inv_table_func : function
Inverse of table_func, optional.
"""
def
__init__
(
self
,
domain
,
func
,
xmin
,
xmax
,
delta
,
table_func
=
None
,
exponent
=
None
):
def
__init__
(
self
,
domain
,
func
,
xmin
,
xmax
,
delta
,
table_func
=
None
,
inv_table_func
=
None
):
self
.
_domain
=
self
.
_target
=
DomainTuple
.
make
(
domain
)
self
.
_xmin
,
self
.
_xmax
=
float
(
xmin
),
float
(
xmax
)
self
.
_d
=
float
(
delta
)
self
.
_xs
=
np
.
arange
(
xmin
,
xmax
+
2
*
self
.
_d
,
self
.
_d
)
self
.
_table
=
func
(
self
.
_xs
)
self
.
_transform
=
table_func
is
not
None
self
.
_args
=
[]
if
exponent
is
not
None
and
table_func
!=
'power'
:
raise
Exception
(
"exponent is only used when table_func is 'power'."
)
if
table_func
is
None
:
pass
elif
table_func
==
'exp'
:
self
.
_table
=
np
.
exp
(
self
.
_table
)
self
.
_inv_table_func
=
'log'
elif
table_func
==
'log'
:
self
.
_table
=
np
.
log
(
self
.
_table
)
self
.
_inv_table_func
=
'exp'
elif
table_func
==
'power'
:
if
not
np
.
isscalar
(
exponent
):
return
NotImplemented
self
.
_table
=
np
.
power
(
self
.
_table
,
exponent
)
self
.
_inv_table_func
=
'__pow__'
self
.
_args
=
[
np
.
power
(
float
(
exponent
),
-
1
)]
else
:
return
NotImplemented
if
table_func
is
not
None
:
if
inv_table_func
is
None
:
raise
ValueError
a
=
func
(
np
.
random
.
randn
(
10
))
a1
=
_f_on_np
(
lambda
x
:
inv_table_func
(
table_func
(
x
)),
a
)
np
.
testing
.
assert_allclose
(
a
,
a1
)
self
.
_table
=
_f_on_np
(
table_func
,
self
.
_table
)
self
.
_inv_table_func
=
inv_table_func
self
.
_deriv
=
(
self
.
_table
[
1
:]
-
self
.
_table
[:
-
1
])
/
self
.
_d
def
apply
(
self
,
x
):
...
...
@@ -86,16 +82,12 @@ class _InterpolationOperator(Operator):
fi
=
np
.
floor
(
val
).
astype
(
int
)
w
=
val
-
fi
res
=
(
1
-
w
)
*
self
.
_table
[
fi
]
+
w
*
self
.
_table
[
fi
+
1
]
resfld
=
Field
(
self
.
_domain
,
res
)
if
not
lin
:
if
self
.
_transform
:
resfld
=
getattr
(
resfld
,
self
.
_inv_table_func
)(
*
self
.
_args
)
return
resfld
lin
=
Linearization
.
make_var
(
resfld
)
if
self
.
_transform
:
lin
=
getattr
(
lin
,
self
.
_inv_table_func
)(
*
self
.
_args
)
jac
=
makeOp
(
Field
(
self
.
_domain
,
self
.
_deriv
[
fi
]))
return
x
.
new
(
lin
.
val
,
lin
.
jac
@
jac
)
res
=
Field
(
self
.
_domain
,
res
)
if
lin
:
res
=
x
.
new
(
res
,
makeOp
(
Field
(
self
.
_domain
,
self
.
_deriv
[
fi
])))
if
self
.
_inv_table_func
is
not
None
:
res
=
self
.
_inv_table_func
(
res
)
return
res
def
InverseGammaOperator
(
domain
,
alpha
,
q
,
delta
=
0.001
):
...
...
@@ -127,8 +119,8 @@ def InverseGammaOperator(domain, alpha, q, delta=0.001):
delta : float
Distance between sampling points for linear interpolation.
"""
func
=
lambda
x
:
invgamma
.
ppf
(
norm
.
cdf
(
x
),
float
(
alpha
))
op
=
_InterpolationOperator
(
domain
,
func
,
-
8.2
,
8.2
,
delta
,
'log'
)
op
=
_InterpolationOperator
(
domain
,
lambda
x
:
invgamma
.
ppf
(
norm
.
cdf
(
x
),
float
(
alpha
))
,
-
8.2
,
8.2
,
delta
,
lambda
x
:
x
.
log
(),
lambda
x
:
x
.
exp
()
)
if
np
.
isscalar
(
q
):
return
op
.
scale
(
q
)
return
makeOp
(
q
)
@
op
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment