Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Neel Shah
NIFTy
Commits
0bef1520
Commit
0bef1520
authored
Mar 30, 2020
by
Martin Reinecke
Browse files
introduce new class hierarchy
parent
193a276f
Changes
10
Hide whitespace changes
Inline
Side-by-side
nifty6/extra.py
View file @
0bef1520
...
...
@@ -284,7 +284,7 @@ def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100, perf_check=True):
linmid
=
op
(
Linearization
.
make_var
(
locmid
))
dirder
=
linmid
.
jac
(
dir
)
numgrad
=
(
lin2
.
val
-
lin
.
val
)
xtol
=
tol
*
dirder
.
norm
()
/
np
.
sqrt
(
dirder
.
size
)
xtol
=
tol
*
dirder
.
norm
()
/
np
.
sqrt
(
dirder
.
target
.
size
)
hist
.
append
((
numgrad
-
dirder
).
norm
())
# print(len(hist),hist[-1])
if
(
abs
(
numgrad
-
dirder
)
<=
xtol
).
s_all
():
...
...
nifty6/field.py
View file @
0bef1520
...
...
@@ -19,10 +19,11 @@ from functools import reduce
import
numpy
as
np
from
.
import
utilities
from
.operators.operator
import
Operator
from
.domain_tuple
import
DomainTuple
class
Field
(
object
):
class
Field
(
Operator
):
"""The discrete representation of a continuous field over multiple spaces.
Stores data arrays and carries all the needed meta-information (i.e. the
...
...
@@ -161,6 +162,26 @@ class Field(object):
"""
return
self
.
_val
.
copy
()
@
property
def
jac
(
self
):
return
None
@
property
def
want_metric
(
self
):
return
False
@
property
def
metric
(
self
):
raise
NotImplementedError
()
def
__call__
(
self
,
other
):
if
(
other
.
target
==
self
.
domain
):
return
self
raise
ValueError
(
"domain mismatch"
)
def
__matmul__
(
self
,
other
):
return
self
(
other
)
@
property
def
dtype
(
self
):
"""type : the data type of the field's entries"""
...
...
@@ -172,14 +193,9 @@ class Field(object):
return
self
.
_domain
@
property
def
shape
(
self
):
"""tuple of int : the concatenated shapes of all sub-domains"""
return
self
.
_domain
.
shape
@
property
def
size
(
self
):
"""int : total number of pixels in the field"""
return
self
.
_domain
.
size
def
target
(
self
):
"""DomainTuple : the field's domain"""
return
self
.
_domain
@
property
def
real
(
self
):
...
...
@@ -255,7 +271,7 @@ class Field(object):
if
np
.
isscalar
(
wgt
):
fct
*=
wgt
else
:
new_shape
=
np
.
ones
(
len
(
self
.
shape
),
dtype
=
np
.
int
)
new_shape
=
np
.
ones
(
len
(
self
.
_domain
.
shape
),
dtype
=
np
.
int
)
new_shape
[
self
.
_domain
.
axes
[
ind
][
0
]:
self
.
_domain
.
axes
[
ind
][
-
1
]
+
1
]
=
wgt
.
shape
wgt
=
wgt
.
reshape
(
new_shape
)
...
...
nifty6/linearization.py
View file @
0bef1520
...
...
@@ -17,13 +17,14 @@
import
numpy
as
np
from
.operators.operator
import
Operator
from
.field
import
Field
from
.multi_field
import
MultiField
from
.sugar
import
makeOp
from
.
import
utilities
class
Linearization
(
object
):
class
Linearization
(
Operator
):
"""Let `A` be an operator and `x` a field. `Linearization` stores the value
of the operator application (i.e. `A(x)`), the local Jacobian
(i.e. `dA(x)/dx`) and, optionally, the local metric.
...
...
@@ -118,6 +119,14 @@ class Linearization(object):
"""
return
self
.
_metric
def
__call__
(
self
,
other
):
if
(
other
.
target
==
self
.
domain
):
return
self
raise
ValueError
(
"domain mismatch"
)
def
__matmul__
(
self
,
other
):
return
self
(
other
)
def
__getitem__
(
self
,
name
):
return
self
.
new
(
self
.
_val
[
name
],
self
.
_jac
.
ducktape_left
(
name
))
...
...
nifty6/minimization/scipy_minimizer.py
View file @
0bef1520
...
...
@@ -54,7 +54,7 @@ def _toArray_rw(fld):
def
_toField
(
arr
,
template
):
if
isinstance
(
template
,
Field
):
return
Field
(
template
.
domain
,
arr
.
reshape
(
template
.
shape
).
copy
())
return
Field
(
template
.
domain
,
arr
.
reshape
(
template
.
domain
.
shape
).
copy
())
ofs
=
0
res
=
[]
for
v
in
template
.
values
():
...
...
nifty6/multi_field.py
View file @
0bef1520
...
...
@@ -18,12 +18,13 @@
import
numpy
as
np
from
.
import
utilities
from
.operators.operator
import
Operator
from
.field
import
Field
from
.multi_domain
import
MultiDomain
from
.domain_tuple
import
DomainTuple
class
MultiField
(
object
):
class
MultiField
(
Operator
):
def
__init__
(
self
,
domain
,
val
):
"""The discrete representation of a continuous field over a sum space.
...
...
@@ -82,6 +83,10 @@ class MultiField(object):
def
domain
(
self
):
return
self
.
_domain
@
property
def
target
(
self
):
return
self
.
_domain
# @property
# def dtype(self):
# return {key: val.dtype for key, val in self._val.items()}
...
...
@@ -144,6 +149,26 @@ class MultiField(object):
return
{
key
:
val
.
val_rw
()
for
key
,
val
in
zip
(
self
.
_domain
.
keys
(),
self
.
_val
)}
@
property
def
jac
(
self
):
return
None
@
property
def
want_metric
(
self
):
return
False
@
property
def
metric
(
self
):
raise
NotImplementedError
()
def
__call__
(
self
,
other
):
if
(
other
.
target
==
self
.
domain
):
return
self
raise
ValueError
(
"domain mismatch"
)
def
__matmul__
(
self
,
other
):
return
self
(
other
)
@
staticmethod
def
from_raw
(
domain
,
arr
):
return
MultiField
(
...
...
@@ -179,17 +204,6 @@ class MultiField(object):
"""
return
utilities
.
my_sum
(
map
(
lambda
v
:
v
.
s_sum
(),
self
.
_val
))
@
property
def
size
(
self
):
"""Computes the overall degrees of freedom.
Returns
-------
size : int
The sum of the size of the individual fields
"""
return
utilities
.
my_sum
(
map
(
lambda
d
:
d
.
size
,
self
.
_domain
.
domains
()))
def
__neg__
(
self
):
return
self
.
_transform
(
lambda
x
:
-
x
)
...
...
nifty6/operators/energy_operators.py
View file @
0bef1520
...
...
@@ -262,7 +262,7 @@ class InverseGammaLikelihood(EnergyOperator):
self
.
_domain
=
DomainTuple
.
make
(
beta
.
domain
)
self
.
_beta
=
beta
if
np
.
isscalar
(
alpha
):
alpha
=
Field
(
beta
.
domain
,
np
.
full
(
beta
.
shape
,
alpha
))
alpha
=
Field
(
beta
.
domain
,
np
.
full
(
beta
.
target
.
shape
,
alpha
))
elif
not
isinstance
(
alpha
,
Field
):
raise
TypeError
self
.
_alphap1
=
alpha
+
1
...
...
nifty6/operators/operator.py
View file @
0bef1520
...
...
@@ -17,8 +17,6 @@
import
numpy
as
np
from
..field
import
Field
from
..multi_field
import
MultiField
from
..utilities
import
NiftyMeta
,
indent
...
...
@@ -179,6 +177,8 @@ class Operator(metaclass=NiftyMeta):
return
self
.
apply
(
x
.
extract
(
self
.
domain
))
def
_check_input
(
self
,
x
):
from
..field
import
Field
from
..multi_field
import
MultiField
from
..linearization
import
Linearization
from
.scaling_operator
import
ScalingOperator
if
not
isinstance
(
x
,
(
Field
,
MultiField
,
Linearization
)):
...
...
nifty6/operators/outer_product_operator.py
View file @
0bef1520
...
...
@@ -44,6 +44,6 @@ class OuterProduct(LinearOperator):
return
Field
(
self
.
_target
,
np
.
multiply
.
outer
(
self
.
_field
.
val
,
x
.
val
))
axes
=
len
(
self
.
_field
.
shape
)
axes
=
len
(
self
.
_field
.
target
.
shape
)
return
Field
(
self
.
_domain
,
np
.
tensordot
(
self
.
_field
.
val
,
x
.
val
,
axes
))
test/test_field.py
View file @
0bef1520
...
...
@@ -29,8 +29,7 @@ SPACE_COMBINATIONS = [(), SPACES[0], SPACES[1], SPACES]
@
pmp
(
'domain'
,
SPACE_COMBINATIONS
)
@
pmp
(
'attribute_desired_type'
,
[[
'domain'
,
ift
.
DomainTuple
],
[
'val'
,
np
.
ndarray
],
[
'shape'
,
tuple
],
[
'size'
,
(
np
.
int
,
np
.
int64
)]])
[[
'domain'
,
ift
.
DomainTuple
],
[
'val'
,
np
.
ndarray
]])
def
test_return_types
(
domain
,
attribute_desired_type
):
attribute
=
attribute_desired_type
[
0
]
desired_type
=
attribute_desired_type
[
1
]
...
...
@@ -288,18 +287,18 @@ def test_stdfunc():
s
=
ift
.
RGSpace
((
200
,))
f
=
ift
.
Field
.
full
(
s
,
27
)
assert_equal
(
f
.
val
,
27
)
assert_equal
(
f
.
shape
,
(
200
,))
assert_equal
(
f
.
target
.
shape
,
(
200
,))
assert_equal
(
f
.
dtype
,
np
.
int
)
fx
=
ift
.
full
(
f
.
domain
,
0
)
assert_equal
(
f
.
dtype
,
fx
.
dtype
)
assert_equal
(
f
.
shape
,
fx
.
shape
)
assert_equal
(
f
.
target
.
shape
,
fx
.
target
.
shape
)
assert_equal
(
fx
.
val
,
0
)
fx
=
ift
.
full
(
f
.
domain
,
1
)
assert_equal
(
f
.
dtype
,
fx
.
dtype
)
assert_equal
(
f
.
shape
,
fx
.
shape
)
assert_equal
(
f
.
target
.
shape
,
fx
.
target
.
shape
)
assert_equal
(
fx
.
val
,
1
)
fx
=
ift
.
full
(
f
.
domain
,
67.
)
assert_equal
(
f
.
shape
,
fx
.
shape
)
assert_equal
(
f
.
target
.
shape
,
fx
.
target
.
shape
)
assert_equal
(
fx
.
val
,
67.
)
f
=
ift
.
Field
.
from_random
(
"normal"
,
s
)
f2
=
ift
.
Field
.
from_random
(
"normal"
,
s
)
...
...
test/test_multi_field.py
View file @
0bef1520
...
...
@@ -40,7 +40,7 @@ def test_multifield_field_consistency():
f1
=
ift
.
full
(
dom
,
27
)
f2
=
ift
.
makeField
(
dom
[
'd1'
],
f1
[
'd1'
].
val
)
assert_equal
(
f1
.
s_sum
(),
f2
.
s_sum
())
assert_equal
(
f1
.
size
,
f2
.
size
)
assert_equal
(
f1
.
target
.
size
,
f2
.
target
.
size
)
def
test_dataconv
():
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment