Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
82fde86c
Commit
82fde86c
authored
May 27, 2020
by
Philipp Arras
Browse files
Cleanup not working code
parent
90076c0f
Pipeline
#75661
passed with stages
in 13 minutes and 3 seconds
Changes
3
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty6/operators/energy_operators.py
View file @
82fde86c
...
...
@@ -247,39 +247,6 @@ class GaussianEnergy(EnergyOperator):
return
res
.
add_metric
(
self
.
_met
)
return
res
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
.simplify_for_const
import
ConstantOperator
from
..multi_domain
import
MultiDomain
if
not
self
.
_trivial_invcov
:
raise
NotImplementedError
# FIXME
# No need to implement support for DomainTuple since this done by
# Operator.simplify_for_constant_input()
assert
isinstance
(
self
.
domain
,
MultiDomain
)
c_dom
=
{}
var_dom
=
{}
not_touched_dom
=
{}
for
kk
in
self
.
_domain
.
keys
():
if
kk
in
c_inp
.
domain
.
keys
():
c_dom
[
kk
]
=
self
.
_domain
[
kk
]
else
:
var_dom
[
kk
]
=
self
.
_domain
[
kk
]
for
kk
in
set
(
c_inp
.
keys
())
-
set
(
self
.
_domain
.
keys
()):
not_touched_dom
[
kk
]
=
c_inp
.
domain
[
kk
]
var_dom
=
MultiDomain
.
make
(
var_dom
)
c_dom
=
MultiDomain
.
make
(
c_dom
)
not_touched_dom
=
MultiDomain
.
make
(
not_touched_dom
)
c_mean
=
None
if
self
.
_mean
is
None
else
self
.
_mean
.
extract
(
c_dom
)
var_mean
=
None
if
self
.
_mean
is
None
else
self
.
_mean
.
extract
(
var_dom
)
c_op
=
ConstantOperator
(
c_dom
,
GaussianEnergy
(
c_mean
,
None
,
c_inp
.
domain
)(
c_inp
))
var_op
=
GaussianEnergy
(
var_mean
,
None
,
var_dom
)
#@ rest
newop
=
var_op
+
c_op
return
c_inp
.
extract_part
(
not_touched_dom
),
newop
def
__repr__
(
self
):
dom
=
'()'
if
isinstance
(
self
.
domain
,
DomainTuple
)
else
self
.
domain
.
keys
()
return
f
'GaussianEnergy
{
dom
}
'
...
...
nifty6/operators/operator.py
View file @
82fde86c
...
...
@@ -18,6 +18,7 @@
import
numpy
as
np
from
..
import
pointwise
from
..logger
import
logger
from
..multi_domain
import
MultiDomain
from
..utilities
import
NiftyMeta
,
indent
...
...
@@ -274,8 +275,12 @@ class Operator(metaclass=NiftyMeta):
from
.simplify_for_const
import
ConstantEnergyOperator
,
ConstantOperator
if
c_inp
is
None
:
return
None
,
self
if
isinstance
(
self
.
domain
,
MultiDomain
):
assert
isinstance
(
c_inp
.
domain
,
MultiDomain
)
if
set
(
c_inp
.
keys
())
>
set
(
self
.
domain
.
keys
()):
raise
ValueError
if
c_inp
.
domain
is
self
.
domain
:
if
isinstance
(
self
,
EnergyOperator
):
op
=
ConstantEnergyOperator
(
self
.
domain
,
self
(
c_inp
))
...
...
@@ -283,34 +288,17 @@ class Operator(metaclass=NiftyMeta):
op
=
ConstantOperator
(
self
.
domain
,
self
(
c_inp
))
op
=
ConstantOperator
(
self
.
domain
,
self
(
c_inp
))
return
op
(
c_inp
),
op
if
isinstance
(
self
.
domain
,
MultiDomain
)
and
\
set
(
c_inp
.
keys
())
>
set
(
self
.
domain
.
keys
()):
raise
NotImplementedError
(
'This branch is not tested yet'
)
op
=
ConstantOperator
(
self
.
domain
,
self
.
force
(
c_inp
))
from
..sugar
import
makeField
unaffected
=
makeField
({
kk
:
vv
for
kk
,
vv
in
c_inp
.
items
()
if
kk
not
in
self
.
domain
})
for
kk
in
unaffected
:
assert
kk
not
in
self
.
domain
assert
kk
not
in
self
.
target
return
op
.
force
(
c_inp
),
op
if
not
isinstance
(
c_inp
.
domain
,
MultiDomain
):
raise
RuntimeError
return
self
.
_simplify_for_constant_input_nontrivial
(
c_inp
)
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
.simplify_for_const
import
SlowPartialConstantOperator
from
..multi_field
import
MultiField
try
:
c_out
=
self
.
force
(
c_inp
)
except
KeyError
:
c_out
=
None
if
isinstance
(
c_out
,
MultiField
):
dct
=
{}
for
kk
in
set
(
c_inp
.
keys
())
-
set
(
self
.
domain
.
keys
()):
if
isinstance
(
self
.
target
,
MultiDomain
)
and
kk
in
self
.
target
.
keys
():
raise
NotImplementedError
dct
[
kk
]
=
c_inp
[
kk
]
c_out
=
c_out
.
unite
(
MultiField
.
from_dict
(
dct
))
return
c_out
,
self
@
SlowPartialConstantOperator
(
self
.
domain
,
c_inp
.
keys
())
s
=
(
'SlowPartialConstantOperator used. You might want to consider'
,
' implementing `_simplify_for_constant_input_nontrivial()` for'
,
' this operator.'
)
logger
.
warning
(
s
)
return
None
,
self
@
SlowPartialConstantOperator
(
self
.
domain
,
c_inp
.
keys
())
def
ptw
(
self
,
op
,
*
args
,
**
kwargs
):
return
_OpChain
.
make
((
_FunctionApplier
(
self
.
target
,
op
,
*
args
,
**
kwargs
),
self
))
...
...
nifty6/operators/simplify_for_const.py
View file @
82fde86c
...
...
@@ -16,7 +16,6 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from
..multi_domain
import
MultiDomain
from
..multi_field
import
MultiField
from
.block_diagonal_operator
import
BlockDiagonalOperator
from
.energy_operators
import
EnergyOperator
from
.operator
import
Operator
...
...
@@ -90,21 +89,17 @@ class SlowPartialConstantOperator(Operator):
from
..sugar
import
makeDomain
if
not
isinstance
(
domain
,
MultiDomain
):
raise
TypeError
self
.
_keys
=
set
(
constant_keys
)
&
set
(
domain
.
keys
())
if
len
(
self
.
_keys
)
==
0
:
if
set
(
constant_keys
)
>
set
(
domain
.
keys
())
or
len
(
constant_keys
)
==
0
:
raise
ValueError
self
.
_keys
=
set
(
constant_keys
)
&
set
(
domain
.
keys
())
self
.
_domain
=
self
.
_target
=
makeDomain
(
domain
)
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
if
x
.
jac
is
None
:
return
x
jac
=
{}
for
kk
,
dd
in
self
.
_domain
.
items
():
fac
=
1
if
kk
in
self
.
_keys
:
fac
=
0
jac
[
kk
]
=
ScalingOperator
(
dd
,
fac
)
jac
=
{
kk
:
ScalingOperator
(
dd
,
0
if
kk
in
self
.
_keys
else
1
)
for
kk
,
dd
in
self
.
_domain
.
items
()}
return
x
.
prepend_jac
(
BlockDiagonalOperator
(
x
.
jac
.
domain
,
jac
))
def
__repr__
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment