Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
On Thursday, 7th July from 1 to 3 pm there will be a maintenance with a short downtime of GitLab.
Open sidebar
ift
NIFTy
Commits
9b6d9324
Commit
9b6d9324
authored
Dec 08, 2017
by
Martin Reinecke
Browse files
tweaks
parent
5edadf46
Pipeline
#22717
passed with stage
in 4 minutes and 47 seconds
Changes
4
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty/field.py
View file @
9b6d9324
...
...
@@ -59,7 +59,7 @@ class Field(object):
"""
def
__init__
(
self
,
domain
=
None
,
val
=
None
,
dtype
=
None
,
copy
=
False
):
self
.
domain
=
self
.
_
parse
_domain
(
domain
=
domain
,
val
=
val
)
self
.
domain
=
self
.
_
infer
_domain
(
domain
=
domain
,
val
=
val
)
dtype
=
self
.
_infer_dtype
(
dtype
=
dtype
,
val
=
val
)
if
isinstance
(
val
,
Field
):
...
...
@@ -128,7 +128,7 @@ class Field(object):
return
Field
.
empty
(
field
.
domain
,
dtype
)
@
staticmethod
def
_
parse
_domain
(
domain
,
val
=
None
):
def
_
infer
_domain
(
domain
,
val
=
None
):
if
domain
is
None
:
if
isinstance
(
val
,
Field
):
return
val
.
domain
...
...
nifty/library/wiener_filter_curvature.py
View file @
9b6d9324
...
...
@@ -71,8 +71,7 @@ class WienerFilterCurvature(InversionEnabler, EndomorphicOperator):
noise
=
self
.
N
.
diagonal
().
weight
(
-
1
)
mock_noise
=
Field
.
from_random
(
random_type
=
"normal"
,
domain
=
self
.
N
.
domain
,
dtype
=
noise
.
dtype
.
type
)
domain
=
self
.
N
.
domain
,
dtype
=
noise
.
dtype
)
mock_noise
*=
sqrt
(
noise
)
mock_data
=
self
.
R
(
mock_signal
)
+
mock_noise
...
...
nifty/probing/prober.py
View file @
9b6d9324
...
...
@@ -21,7 +21,6 @@ from builtins import range
from
builtins
import
object
import
numpy
as
np
from
..field
import
Field
,
DomainTuple
from
..
import
utilities
class
Prober
(
object
):
...
...
@@ -38,15 +37,16 @@ class Prober(object):
compute_variance
=
False
,
ncpu
=
1
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_probe_count
=
self
.
_parse_probe_count
(
probe_count
)
self
.
_ncpu
=
self
.
_parse_probe_count
(
ncpu
)
self
.
_random_type
=
self
.
_parse_random_type
(
random_type
)
self
.
_probe_count
=
int
(
probe_count
)
self
.
_ncpu
=
int
(
ncpu
)
if
random_type
not
in
[
"pm1"
,
"normal"
]:
raise
ValueError
(
"unsupported random type: '"
+
str
(
random_type
)
+
"'."
)
self
.
_random_type
=
random_type
self
.
compute_variance
=
bool
(
compute_variance
)
self
.
probe_dtype
=
np
.
dtype
(
probe_dtype
)
self
.
_uid_counter
=
0
# ---Properties---
@
property
def
domain
(
self
):
return
self
.
_domain
...
...
@@ -55,22 +55,11 @@ class Prober(object):
def
probe_count
(
self
):
return
self
.
_probe_count
def
_parse_probe_count
(
self
,
probe_count
):
return
int
(
probe_count
)
@
property
def
random_type
(
self
):
return
self
.
_random_type
def
_parse_random_type
(
self
,
random_type
):
if
random_type
not
in
[
"pm1"
,
"normal"
]:
raise
ValueError
(
"unsupported random type: '"
+
str
(
random_type
)
+
"'."
)
return
random_type
# ---Probing methods---
def
gen_parallel_probe
(
self
,
callee
):
def
gen_parallel_probe
(
self
,
callee
):
for
i
in
range
(
self
.
probe_count
):
yield
(
callee
,
self
.
get_probe
(
i
))
...
...
@@ -87,7 +76,7 @@ class Prober(object):
pool
=
Pool
(
self
.
_ncpu
)
for
i
in
pool
.
imap_unordered
(
self
.
evaluate_probe_parallel
,
self
.
gen_parallel_probe
(
callee
)):
self
.
finish_probe
(
i
[
0
],
i
[
1
])
self
.
finish_probe
(
i
[
0
],
i
[
1
])
def
evaluate_probe_parallel
(
self
,
argtuple
):
callee
=
argtuple
[
0
]
...
...
@@ -104,8 +93,7 @@ class Prober(object):
def
generate_probe
(
self
):
""" a random-probe generator """
f
=
Field
.
from_random
(
random_type
=
self
.
random_type
,
domain
=
self
.
domain
,
dtype
=
self
.
probe_dtype
.
type
)
domain
=
self
.
domain
,
dtype
=
self
.
probe_dtype
)
uid
=
self
.
_uid_counter
self
.
_uid_counter
+=
1
return
(
uid
,
f
)
...
...
nifty/spaces/rg_space.py
View file @
9b6d9324
...
...
@@ -53,7 +53,19 @@ class RGSpace(Space):
if
np
.
isscalar
(
shape
):
shape
=
(
shape
,)
self
.
_shape
=
tuple
(
int
(
i
)
for
i
in
shape
)
self
.
_distances
=
self
.
_parse_distances
(
distances
)
if
distances
is
None
:
if
self
.
harmonic
:
self
.
_distances
=
(
1.
,)
*
len
(
self
.
_shape
)
else
:
self
.
_distances
=
tuple
(
1.
/
s
for
s
in
self
.
_shape
)
elif
np
.
isscalar
(
distances
):
self
.
_distances
=
(
float
(
distances
),)
*
len
(
self
.
_shape
)
else
:
temp
=
np
.
empty
(
len
(
self
.
shape
),
dtype
=
np
.
float64
)
temp
[:]
=
distances
self
.
_distances
=
tuple
(
temp
)
self
.
_dvol
=
float
(
reduce
(
lambda
x
,
y
:
x
*
y
,
self
.
_distances
))
self
.
_dim
=
int
(
reduce
(
lambda
x
,
y
:
x
*
y
,
self
.
_shape
))
...
...
@@ -169,14 +181,3 @@ class RGSpace(Space):
distance between neighboring grid points along the n-th dimension.
"""
return
self
.
_distances
def
_parse_distances
(
self
,
distances
):
if
distances
is
None
:
if
self
.
harmonic
:
temp
=
np
.
ones_like
(
self
.
shape
,
dtype
=
np
.
float64
)
else
:
temp
=
1.
/
np
.
array
(
self
.
shape
,
dtype
=
np
.
float64
)
else
:
temp
=
np
.
empty
(
len
(
self
.
shape
),
dtype
=
np
.
float64
)
temp
[:]
=
distances
return
tuple
(
temp
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment