Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
9c429547
Commit
9c429547
authored
Jan 25, 2017
by
Theo Steininger
Browse files
Finished refactoring of probing classes. Now uses mixin-classes.
parent
cd593c77
Changes
13
Show whitespace changes
Inline
Side-by-side
nifty/nifty_utilities.py
View file @
9c429547
...
...
@@ -278,3 +278,36 @@ def get_default_codomain(domain):
return
LMGLTransformation
.
get_codomain
(
domain
)
else
:
raise
TypeError
(
'ERROR: unknown domain'
)
def
parse_domain
(
domain
):
from
nifty.spaces.space
import
Space
if
domain
is
None
:
domain
=
()
elif
isinstance
(
domain
,
Space
):
domain
=
(
domain
,)
elif
not
isinstance
(
domain
,
tuple
):
domain
=
tuple
(
domain
)
for
d
in
domain
:
if
not
isinstance
(
d
,
Space
):
raise
TypeError
(
"Given object contains something that is not a "
"nifty.space."
)
return
domain
def
parse_field_type
(
field_type
):
from
nifty.field_types
import
FieldType
if
field_type
is
None
:
field_type
=
()
elif
isinstance
(
field_type
,
FieldType
):
field_type
=
(
field_type
,)
elif
not
isinstance
(
field_type
,
tuple
):
field_type
=
tuple
(
field_type
)
for
ft
in
field_type
:
if
not
isinstance
(
ft
,
FieldType
):
raise
TypeError
(
"Given object is not a nifty.FieldType."
)
return
field_type
nifty/operators/fft_operator/__init__.py
View file @
9c429547
from
transformations
import
*
from
fft_operator
import
FFTOperator
nifty/operators/linear_operator/linear_operator.py
View file @
9c429547
...
...
@@ -4,8 +4,6 @@ import abc
from
keepers
import
Loggable
from
nifty.field
import
Field
from
nifty.spaces
import
Space
from
nifty.field_types
import
FieldType
import
nifty.nifty_utilities
as
utilities
...
...
@@ -16,33 +14,10 @@ class LinearOperator(Loggable, object):
pass
def
_parse_domain
(
self
,
domain
):
if
domain
is
None
:
domain
=
()
elif
isinstance
(
domain
,
Space
):
domain
=
(
domain
,)
elif
not
isinstance
(
domain
,
tuple
):
domain
=
tuple
(
domain
)
for
d
in
domain
:
if
not
isinstance
(
d
,
Space
):
raise
TypeError
(
"Given object contains something that is not a "
"nifty.space."
)
return
domain
return
utilities
.
parse_domain
(
domain
)
def
_parse_field_type
(
self
,
field_type
):
if
field_type
is
None
:
field_type
=
()
elif
isinstance
(
field_type
,
FieldType
):
field_type
=
(
field_type
,)
elif
not
isinstance
(
field_type
,
tuple
):
field_type
=
tuple
(
field_type
)
for
ft
in
field_type
:
if
not
isinstance
(
ft
,
FieldType
):
raise
TypeError
(
"Given object is not a nifty.FieldType."
)
return
field_type
return
utilities
.
parse_field_type
(
field_type
)
@
abc
.
abstractproperty
def
domain
(
self
):
...
...
nifty/operators/probing_operator/diagonal_prober.py
deleted
100644 → 0
View file @
cd593c77
# -*- coding: utf-8 -*-
from
prober
import
Prober
class
DiagonalProber
(
Prober
):
# ---Mandatory properties and methods---
def
finish_probe
(
self
,
probe
,
pre_result
):
return
probe
[
1
].
conjugate
()
*
pre_result
nifty/operators/probing_operator/trace_prober.py
deleted
100644 → 0
View file @
cd593c77
# -*- coding: utf-8 -*-
from
prober
import
Prober
class
TraceProber
(
Prober
):
# ---Mandatory properties and methods---
def
finish_probe
(
self
,
probe
,
pre_result
):
return
probe
[
1
].
conjugate
().
weight
(
power
=-
1
).
dot
(
pre_result
)
nifty/
operators/probing_operator
/__init__.py
→
nifty/
probing
/__init__.py
View file @
9c429547
# -*- coding: utf-8 -*-
from
prober
import
Prober
from
diagonal_prober
import
*
from
trace_prober
import
*
from
mixin_classes
import
*
nifty/probing/mixin_classes/__init__.py
0 → 100644
View file @
9c429547
# -*- coding: utf-8 -*-
from
mixin_base
import
MixinBase
from
diagonal_prober_mixin
import
DiagonalProberMixin
from
trace_prober_mixin
import
TraceProberMixin
nifty/probing/mixin_classes/diagonal_prober_mixin.py
0 → 100644
View file @
9c429547
# -*- coding: utf-8 -*-
from
mixin_base
import
MixinBase
class
DiagonalProberMixin
(
MixinBase
):
def
__init__
(
self
):
self
.
reset
()
super
(
DiagonalProberMixin
,
self
).
__init__
()
def
reset
(
self
):
self
.
__sum_of_probings
=
0
self
.
__sum_of_squares
=
0
self
.
__diagonal
=
None
self
.
__diagonal_variance
=
None
super
(
DiagonalProberMixin
,
self
).
reset
()
def
finish_probe
(
self
,
probe
,
pre_result
):
result
=
probe
[
1
].
conjugate
()
*
pre_result
self
.
__sum_of_probings
+=
result
if
self
.
compute_variance
:
self
.
__sum_of_squares
+=
result
.
conjugate
()
*
result
super
(
DiagonalProberMixin
,
self
).
finish_probe
(
probe
,
pre_result
)
@
property
def
diagonal
(
self
):
if
self
.
__diagonal
is
None
:
self
.
__diagonal
=
self
.
__sum_of_probings
/
self
.
probe_count
return
self
.
__diagonal
@
property
def
diagonal_variance
(
self
):
if
not
self
.
compute_variance
:
raise
AttributeError
(
"self.compute_variance is set to False"
)
if
self
.
__diagonal_variance
is
None
:
# variance = 1/(n-1) (sum(x^2) - 1/n*sum(x)^2)
n
=
self
.
probe_count
sum_pr
=
self
.
__sum_of_probings
mean
=
self
.
diagonal
sum_sq
=
self
.
__sum_of_squares
self
.
__diagonal_variance
=
((
sum_sq
-
sum_pr
*
mean
)
/
(
n
-
1
))
return
self
.
__diagonal_variance
nifty/probing/mixin_classes/mixin_base.py
0 → 100644
View file @
9c429547
# -*- coding: utf-8 -*-
class
MixinBase
(
object
):
def
reset
(
self
,
*
args
,
**
kwargs
):
pass
def
finish_probe
(
self
,
*
args
,
**
kwargs
):
pass
nifty/probing/mixin_classes/trace_prober_mixin.py
0 → 100644
View file @
9c429547
# -*- coding: utf-8 -*-
from
mixin_base
import
MixinBase
class
TraceProberMixin
(
MixinBase
):
def
__init__
(
self
):
self
.
reset
()
super
(
TraceProberMixin
,
self
).
__init__
()
def
reset
(
self
):
self
.
__sum_of_probings
=
0
self
.
__sum_of_squares
=
0
self
.
__trace
=
None
self
.
__trace_variance
=
None
super
(
TraceProberMixin
,
self
).
reset
()
def
finish_probe
(
self
,
probe
,
pre_result
):
result
=
probe
[
1
].
dot
(
pre_result
,
bare
=
True
)
self
.
__sum_of_probings
+=
result
if
self
.
compute_variance
:
self
.
__sum_of_squares
+=
result
.
conjugate
()
*
result
super
(
TraceProberMixin
,
self
).
finish_probe
(
probe
,
pre_result
)
@
property
def
trace
(
self
):
if
self
.
__trace
is
None
:
self
.
__trace
=
self
.
__sum_of_probings
/
self
.
probe_count
return
self
.
__trace
@
property
def
trace_variance
(
self
):
if
not
self
.
compute_variance
:
raise
AttributeError
(
"self.compute_variance is set to False"
)
if
self
.
__trace_variance
is
None
:
# variance = 1/(n-1) (sum(x^2) - 1/n*sum(x)^2)
n
=
self
.
probe_count
sum_pr
=
self
.
__sum_of_probings
mean
=
self
.
trace
sum_sq
=
self
.
__sum_of_squares
self
.
__trace_variance
=
((
sum_sq
-
sum_pr
*
mean
)
/
(
n
-
1
))
return
self
.
__trace_variance
nifty/probing/prober/__init__.py
0 → 100644
View file @
9c429547
# -*- coding: utf-8 -*-
from
prober
import
Prober
nifty/
operators/probing_operator/probing_operato
r.py
→
nifty/
probing/prober/probe
r.py
View file @
9c429547
...
...
@@ -4,33 +4,43 @@ import abc
import
numpy
as
np
from
nifty.field_types
import
FieldType
from
nifty.spaces
import
Space
from
nifty.field
import
Field
from
nifty.operators.endomorphic_operator
import
EndomorphicOperator
import
nifty.nifty_utilities
as
utilities
from
nifty
import
nifty_configuration
as
nc
from
d2o
import
STRATEGIES
as
DISTRIBUTION_STRATEGIES
class
Prob
ingOperator
(
EndomorphicOperator
):
class
Prob
er
(
object
):
"""
aka DiagonalProbingOperator
See the following webpages for the principles behind the usage of
mixin-classes
https://www.python.org/download/releases/2.2.3/descrintro/#cooperation
https://rhettinger.wordpress.com/2011/05/26/super-considered-super/
"""
# ---Overwritten properties and methods---
__metaclass__
=
abc
.
ABCMeta
def
__init__
(
self
,
domain
=
None
,
field_type
=
None
,
distribution_strategy
=
None
,
probe_count
=
8
,
random_type
=
'pm1'
,
compute_variance
=
False
):
self
.
_domain
=
self
.
_
parse_domain
(
domain
)
self
.
_field_type
=
self
.
_
parse_field_type
(
field_type
)
self
.
_domain
=
utilities
.
parse_domain
(
domain
)
self
.
_field_type
=
utilities
.
parse_field_type
(
field_type
)
self
.
_distribution_strategy
=
\
self
.
_parse_distribution_strategy
(
distribution_strategy
)
self
.
distribution_strategy
=
distribution_strategy
self
.
probe_count
=
probe_count
self
.
random_type
=
random_type
self
.
_probe_count
=
self
.
_parse_probe_count
(
probe_count
)
self
.
_random_type
=
self
.
_parse_random_type
(
random_type
)
self
.
compute_variance
=
bool
(
compute_variance
)
# ---Mandatory properties and methods---
super
(
Prober
,
self
).
__init__
()
# ---Properties---
@
property
def
domain
(
self
):
...
...
@@ -40,57 +50,49 @@ class ProbingOperator(EndomorphicOperator):
def
field_type
(
self
):
return
self
.
_field_type
# ---Added properties and methods---
@
property
def
distribution_strategy
(
self
):
return
self
.
_distribution_strategy
def
_parse_distribution_strategy
(
self
,
distribution_strategy
):
if
distribution_strategy
is
None
:
distribution_strategy
=
nc
[
'default_distribution_strategy'
]
else
:
distribution_strategy
=
str
(
distribution_strategy
)
if
distribution_strategy
not
in
DISTRIBUTION_STRATEGIES
[
'global'
]:
raise
ValueError
(
"distribution_strategy must be a global-type "
"strategy."
)
return
distribution_strategy
self
.
_distribution_strategy
=
distribution_strategy
@
property
def
probe_count
(
self
):
return
self
.
_probe_count
@
probe_count
.
setter
def
probe_count
(
self
,
probe_count
):
self
.
_probe_count
=
int
(
probe_count
)
def
_parse_probe_count
(
self
,
probe_count
):
return
int
(
probe_count
)
@
property
def
random_type
(
self
):
return
self
.
_random_type
@
random_type
.
setter
def
random_type
(
self
,
random_type
):
def
_parse_random_type
(
self
,
random_type
):
if
random_type
not
in
[
"pm1"
,
"normal"
]:
raise
ValueError
(
"unsupported random type: '"
+
str
(
random_type
)
+
"'."
)
else
:
self
.
_random_type
=
random_type
return
random_type
# ---Probing methods---
def
probing_run
(
self
,
callee
):
""" controls the generation, evaluation and finalization of probes """
sum_of_probes
=
0
sum_of_squares
=
0
self
.
reset
()
for
index
in
xrange
(
self
.
probe_count
):
current_probe
=
self
.
get_probe
(
index
)
pre_result
=
self
.
process_probe
(
callee
,
current_probe
,
index
)
result
=
self
.
finish_probe
(
current_probe
,
pre_result
)
sum_of_probes
+=
result
if
self
.
compute_variance
:
sum_of_squares
+=
result
.
conjugate
()
*
result
self
.
finish_probe
(
current_probe
,
pre_result
)
mean_and_variance
=
self
.
finalize
(
sum_of_probes
,
sum_of_squares
)
return
mean_and_variance
def
reset
(
self
):
super
(
Prober
,
self
).
reset
()
def
get_probe
(
self
,
index
):
""" layer of abstraction for potential probe-caching """
...
...
@@ -113,21 +115,8 @@ class ProbingOperator(EndomorphicOperator):
""" processes a probe """
return
callee
(
probe
,
**
kwargs
)
@
abc
.
abstractmethod
def
finish_probe
(
self
,
probe
,
pre_result
):
return
pre_result
def
finalize
(
self
,
sum_of_probes
,
sum_of_squares
):
probe_count
=
self
.
probe_count
mean_of_probes
=
sum_of_probes
/
probe_count
if
self
.
compute_variance
:
# variance = 1/(n-1) (sum(x^2) - 1/n*sum(x)^2)
variance
=
((
sum_of_squares
-
sum_of_probes
*
mean_of_probes
)
/
(
probe_count
-
1
))
else
:
variance
=
None
return
(
mean_of_probes
,
variance
)
super
(
Prober
,
self
).
finish_probe
(
probe
,
pre_result
)
def
__call__
(
self
,
callee
):
return
self
.
probing_run
(
callee
)
nifty/spaces/space/space.py
View file @
9c429547
...
...
@@ -150,8 +150,7 @@ from keepers import Loggable,\
Versionable
class
Space
(
Versionable
,
Loggable
,
Plottable
,
object
):
class
Space
(
Versionable
,
Loggable
,
object
):
"""
.. __ __
.. /__/ / /_
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment