Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
393327d5
Commit
393327d5
authored
Apr 08, 2020
by
Martin Reinecke
Browse files
Merge branch 'pointwise' into pointwise_ng
parents
01779e03
55ec681a
Changes
12
Hide whitespace changes
Inline
Side-by-side
demos/bernoulli_demo.py
View file @
393327d5
...
...
@@ -52,7 +52,7 @@ if __name__ == '__main__':
A
=
ift
.
create_power_operator
(
harmonic_space
,
sqrtpspec
)
# Set up a sky operator and instrumental response
sky
=
HT
(
A
).
ptw
(
"
sigmoid
"
)
sky
=
ift
.
sigmoid
(
HT
(
A
)
)
GR
=
ift
.
GeometryRemover
(
position_space
)
R
=
GR
...
...
demos/getting_started_2.py
View file @
393327d5
...
...
@@ -80,7 +80,7 @@ if __name__ == '__main__':
A
=
pd
(
a
)
# Define sky operator
sky
=
HT
(
ift
.
makeOp
(
A
))
.
ptw
(
"exp"
)
sky
=
ift
.
exp
(
HT
(
ift
.
makeOp
(
A
)))
M
=
ift
.
DiagonalOperator
(
exposure
)
GR
=
ift
.
GeometryRemover
(
position_space
)
...
...
demos/getting_started_3.py
View file @
393327d5
...
...
@@ -85,7 +85,7 @@ if __name__ == '__main__':
A
=
cfmaker
.
amplitude
# Apply a nonlinearity
signal
=
correlated_field
.
ptw
(
"sigmoid"
)
signal
=
ift
.
sigmoid
(
correlated_field
)
# Build the line-of-sight response and define signal response
LOS_starts
,
LOS_ends
=
random_los
(
100
)
if
mode
==
0
else
radial_los
(
100
)
...
...
@@ -149,7 +149,7 @@ if __name__ == '__main__':
filename_res
=
filename
.
format
(
"results"
)
plot
=
ift
.
Plot
()
plot
.
add
(
sc
.
mean
,
title
=
"Posterior Mean"
)
plot
.
add
(
sc
.
var
.
ptw
(
"sqrt"
),
title
=
"Posterior Standard Deviation"
)
plot
.
add
(
ift
.
sqrt
(
sc
.
var
),
title
=
"Posterior Standard Deviation"
)
powers
=
[
A
.
force
(
s
+
KL
.
position
)
for
s
in
KL
.
samples
]
plot
.
add
(
...
...
demos/getting_started_mf.py
View file @
393327d5
...
...
@@ -84,7 +84,7 @@ if __name__ == '__main__':
DC
=
SingleDomain
(
correlated_field
.
target
,
position_space
)
## Apply a nonlinearity
signal
=
DC
@
correlated_field
.
ptw
(
"sigmoid"
)
signal
=
DC
@
ift
.
sigmoid
(
correlated_field
)
# Build the line-of-sight response and define signal response
LOS_starts
,
LOS_ends
=
random_los
(
100
)
if
mode
==
0
else
radial_los
(
100
)
...
...
@@ -170,7 +170,7 @@ if __name__ == '__main__':
filename_res
=
filename
.
format
(
"results"
)
plot
=
ift
.
Plot
()
plot
.
add
(
sc
.
mean
,
title
=
"Posterior Mean"
)
plot
.
add
(
sc
.
var
.
ptw
(
"sqrt"
),
title
=
"Posterior Standard Deviation"
)
plot
.
add
(
ift
.
sqrt
(
sc
.
var
),
title
=
"Posterior Standard Deviation"
)
powers1
=
[
A1
.
force
(
s
+
KL
.
position
)
for
s
in
KL
.
samples
]
powers2
=
[
A2
.
force
(
s
+
KL
.
position
)
for
s
in
KL
.
samples
]
...
...
nifty6/field.py
View file @
393327d5
...
...
@@ -686,9 +686,6 @@ class Field(Operator):
def
flexible_addsub
(
self
,
other
,
neg
):
return
self
-
other
if
neg
else
self
+
other
def
clip
(
self
,
a_min
=
None
,
a_max
=
None
):
return
self
.
ptw
(
"clip"
,
a_min
,
a_max
)
def
_binary_op
(
self
,
other
,
op
):
# if other is a field, make sure that the domains match
f
=
getattr
(
self
.
_val
,
op
)
...
...
@@ -700,20 +697,24 @@ class Field(Operator):
return
Field
(
self
.
_domain
,
f
(
other
))
return
NotImplemented
def
ptw
(
self
,
op
,
*
args
,
**
kwargs
):
from
.pointwise
import
ptw_dict
def
_prep_args
(
self
,
args
,
kwargs
):
for
arg
in
args
+
tuple
(
kwargs
.
values
()):
if
not
(
arg
is
None
or
np
.
isscalar
(
arg
)
or
arg
.
jac
is
None
):
raise
TypeError
(
"bad argument"
)
argstmp
=
tuple
(
arg
if
arg
is
None
or
np
.
isscalar
(
arg
)
else
arg
.
_val
for
arg
in
args
)
kwargstmp
=
{
key
:
val
if
val
is
None
or
np
.
isscalar
(
val
)
else
val
.
_val
for
key
,
val
in
kwargs
.
items
()}
return
argstmp
,
kwargstmp
def
ptw
(
self
,
op
,
*
args
,
**
kwargs
):
from
.pointwise
import
ptw_dict
argstmp
,
kwargstmp
=
self
.
_prep_args
(
args
,
kwargs
)
return
Field
(
self
.
_domain
,
ptw_dict
[
op
][
0
](
self
.
_val
,
*
argstmp
,
**
kwargstmp
))
def
ptw_with_deriv
(
self
,
op
,
*
args
,
**
kwargs
):
from
.pointwise
import
ptw_dict
argstmp
=
tuple
(
arg
if
arg
is
None
or
np
.
isscalar
(
arg
)
else
arg
.
_val
for
arg
in
args
)
kwargstmp
=
{
key
:
val
if
val
is
None
or
np
.
isscalar
(
val
)
else
val
.
_val
for
key
,
val
in
kwargs
.
items
()}
argstmp
,
kwargstmp
=
self
.
_prep_args
(
args
,
kwargs
)
tmp
=
ptw_dict
[
op
][
1
](
self
.
_val
,
*
argstmp
,
**
kwargstmp
)
return
(
Field
(
self
.
_domain
,
tmp
[
0
]),
Field
(
self
.
_domain
,
tmp
[
1
]))
...
...
nifty6/library/correlated_fields.py
View file @
393327d5
...
...
@@ -225,7 +225,7 @@ class _Normalization(Operator):
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
amp
=
x
.
ptw
(
"exp"
)
spec
=
amp
*
amp
spec
=
amp
*
*
2
# FIXME This normalizes also the zeromode which is supposed to be left
# untouched by this operator
return
self
.
_specsum
(
self
.
_mode_multiplicity
(
spec
))
**
(
-
0.5
)
*
amp
...
...
nifty6/linearization.py
View file @
393327d5
...
...
@@ -294,15 +294,6 @@ class Linearization(Operator):
t1
,
t2
=
self
.
_fld
.
ptw_with_deriv
(
op
,
*
args
,
**
kwargs
)
return
self
.
new
(
t1
,
makeOp
(
t2
)(
self
.
_jac
))
def
clip
(
self
,
a_min
=
None
,
a_max
=
None
):
if
a_min
is
None
and
a_max
is
None
:
return
self
if
not
(
a_min
is
None
or
np
.
isscalar
(
a_min
)
or
a_min
.
jac
is
None
):
return
NotImplemented
if
not
(
a_max
is
None
or
np
.
isscalar
(
a_max
)
or
a_max
.
jac
is
None
):
return
NotImplemented
return
self
.
ptw
(
"clip"
,
a_min
,
a_max
)
def
add_metric
(
self
,
metric
):
return
self
.
new
(
self
.
_fld
,
self
.
_jac
,
metric
)
...
...
nifty6/multi_field.py
View file @
393327d5
...
...
@@ -314,25 +314,27 @@ class MultiField(Operator):
res
[
key
]
=
-
val
if
neg
else
val
return
MultiField
.
from_dict
(
res
)
def
_prep_args
(
self
,
args
,
kwargs
,
i
):
for
arg
in
args
+
tuple
(
kwargs
.
values
()):
if
not
(
arg
is
None
or
np
.
isscalar
(
arg
)
or
arg
.
jac
is
None
):
raise
TypeError
(
"bad argument"
)
argstmp
=
tuple
(
arg
if
arg
is
None
or
np
.
isscalar
(
arg
)
else
arg
.
_val
[
i
]
for
arg
in
args
)
kwargstmp
=
{
key
:
val
if
val
is
None
or
np
.
isscalar
(
val
)
else
val
.
_val
[
i
]
for
key
,
val
in
kwargs
.
items
()}
return
argstmp
,
kwargstmp
def
ptw
(
self
,
op
,
*
args
,
**
kwargs
):
# _check_args(args, kwargs)
tmp
=
[]
for
i
in
range
(
len
(
self
.
_val
)):
argstmp
=
tuple
(
arg
if
arg
is
None
or
np
.
isscalar
(
arg
)
else
arg
.
_val
[
i
]
for
arg
in
args
)
kwargstmp
=
{
key
:
val
if
val
is
None
or
np
.
isscalar
(
val
)
else
val
.
_val
[
i
]
for
key
,
val
in
kwargs
.
items
()}
argstmp
,
kwargstmp
=
self
.
_prep_args
(
args
,
kwargs
,
i
)
tmp
.
append
(
self
.
_val
[
i
].
ptw
(
op
,
*
argstmp
,
**
kwargstmp
))
return
MultiField
(
self
.
domain
,
tuple
(
tmp
))
def
ptw_with_deriv
(
self
,
op
,
*
args
,
**
kwargs
):
# _check_args(args, kwargs)
tmp
=
[]
for
i
in
range
(
len
(
self
.
_val
)):
argstmp
=
tuple
(
arg
if
arg
is
None
or
np
.
isscalar
(
arg
)
else
arg
.
_val
[
i
]
for
arg
in
args
)
kwargstmp
=
{
key
:
val
if
val
is
None
or
np
.
isscalar
(
val
)
else
val
.
_val
[
i
]
for
key
,
val
in
kwargs
.
items
()}
argstmp
,
kwargstmp
=
self
.
_prep_args
(
args
,
kwargs
,
i
)
tmp
.
append
(
self
.
_val
[
i
].
ptw_with_deriv
(
op
,
*
argstmp
,
**
kwargstmp
))
return
(
MultiField
(
self
.
domain
,
tuple
(
v
[
0
]
for
v
in
tmp
)),
MultiField
(
self
.
domain
,
tuple
(
v
[
1
]
for
v
in
tmp
)))
...
...
nifty6/operators/operator.py
View file @
393327d5
...
...
@@ -18,6 +18,7 @@
import
numpy
as
np
from
..utilities
import
NiftyMeta
,
indent
from
..
import
pointwise
class
Operator
(
metaclass
=
NiftyMeta
):
...
...
@@ -221,15 +222,6 @@ class Operator(metaclass=NiftyMeta):
return
NotImplemented
return
self
.
ptw
(
"power"
,
power
)
def
clip
(
self
,
a_min
=
None
,
a_max
=
None
):
if
a_min
is
None
and
a_max
is
None
:
return
self
if
not
(
a_min
is
None
or
np
.
isscalar
(
a_min
)
or
a_min
.
jac
is
None
):
return
NotImplemented
if
not
(
a_max
is
None
or
np
.
isscalar
(
a_max
)
or
a_max
.
jac
is
None
):
return
NotImplemented
return
self
.
ptw
(
"clip"
,
a_min
,
a_max
)
def
apply
(
self
,
x
):
"""Applies the operator to a Field or MultiField.
...
...
@@ -292,6 +284,14 @@ class Operator(metaclass=NiftyMeta):
return
_OpChain
.
make
((
_FunctionApplier
(
self
.
target
,
op
,
*
args
,
**
kwargs
),
self
))
for
f
in
pointwise
.
ptw_dict
.
keys
():
def
func
(
f
):
def
func2
(
self
,
*
args
,
**
kwargs
):
return
self
.
ptw
(
f
,
*
args
,
**
kwargs
)
return
func2
setattr
(
Operator
,
f
,
func
(
f
))
class
_ConstCollector
(
object
):
def
__init__
(
self
):
self
.
_const
=
None
...
...
nifty6/pointwise.py
View file @
393327d5
...
...
@@ -67,7 +67,7 @@ def _power_helper(v, expo):
return
(
np
.
power
(
v
,
expo
),
expo
*
np
.
power
(
v
,
expo
-
1
))
def
_clip_helper
(
v
,
a_min
=
None
,
a_max
=
None
):
def
_clip_helper
(
v
,
a_min
,
a_max
):
if
np
.
issubdtype
(
v
.
dtype
,
np
.
complexfloating
):
raise
TypeError
(
"Argument must not be complex"
)
tmp
=
np
.
clip
(
v
,
a_min
,
a_max
)
...
...
nifty6/sugar.py
View file @
393327d5
...
...
@@ -33,13 +33,15 @@ from .operators.distributors import PowerDistributor
from
.operators.operator
import
Operator
from
.operators.scaling_operator
import
ScalingOperator
from
.plot
import
Plot
from
.
import
pointwise
__all__
=
[
'PS_field'
,
'power_analyze'
,
'create_power_operator'
,
'create_harmonic_smoothing_operator'
,
'from_random'
,
'full'
,
'makeField'
,
'makeDomain'
,
'get_signal_variance'
,
'makeOp'
,
'domain_union'
,
'get_default_codomain'
,
'single_plot'
,
'exec_time'
,
'calculate_position'
]
'calculate_position'
]
+
list
(
pointwise
.
ptw_dict
.
keys
())
def
PS_field
(
pspace
,
func
):
...
...
@@ -341,7 +343,7 @@ def makeOp(input, dom=None):
if
input
is
None
:
return
None
if
np
.
isscalar
(
input
):
if
not
isinstance
(
dom
,
(
DomaiTuple
,
MultiDomain
)):
if
not
isinstance
(
dom
,
(
Domai
n
Tuple
,
MultiDomain
)):
raise
TypeError
(
"need proper `dom` argument"
)
return
SalingOperator
(
dom
,
input
)
if
dom
is
not
None
:
...
...
@@ -373,8 +375,16 @@ def domain_union(domains):
return
MultiDomain
.
union
(
domains
)
def
clip
(
a
,
a_min
=
None
,
a_max
=
None
):
return
a
.
clip
(
a_min
,
a_max
)
# Pointwise functions
_current_module
=
sys
.
modules
[
__name__
]
for
f
in
pointwise
.
ptw_dict
.
keys
():
def
func
(
f
):
def
func2
(
x
,
*
args
,
**
kwargs
):
return
x
.
ptw
(
f
,
*
args
,
**
kwargs
)
return
func2
setattr
(
_current_module
,
f
,
func
(
f
))
def
get_default_codomain
(
domainoid
,
space
=
None
):
...
...
test/test_field.py
View file @
393327d5
...
...
@@ -193,8 +193,8 @@ def test_empty_domain():
def
test_trivialities
():
s1
=
ift
.
RGSpace
((
10
,))
f1
=
ift
.
Field
.
full
(
s1
,
27
)
assert_equal
(
f1
.
clip
(
a_min
=
29
).
val
,
29.
)
assert_equal
(
f1
.
clip
(
a_max
=
25
).
val
,
25.
)
assert_equal
(
f1
.
clip
(
a_min
=
29
,
a_max
=
50
).
val
,
29.
)
assert_equal
(
f1
.
clip
(
a_min
=
0
,
a_max
=
25
).
val
,
25.
)
assert_equal
(
f1
.
val
,
f1
.
real
.
val
)
assert_equal
(
f1
.
val
,
(
+
f1
).
val
)
f1
=
ift
.
Field
.
full
(
s1
,
27.
+
3j
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment