Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
c73b4467
Commit
c73b4467
authored
May 27, 2020
by
Philipp Arras
Browse files
Add SlowPartialConstOperator
parent
5d607ffc
Changes
3
Hide whitespace changes
Inline
Side-by-side
nifty6/operators/energy_operators.py
View file @
c73b4467
...
...
@@ -245,6 +245,10 @@ class GaussianEnergy(EnergyOperator):
return
res
.
add_metric
(
self
.
_met
)
return
res
def
__repr__
(
self
):
dom
=
'()'
if
isinstance
(
self
.
domain
,
DomainTuple
)
else
self
.
domain
.
keys
()
return
f
'GaussianEnergy
{
dom
}
'
class
PoissonianEnergy
(
EnergyOperator
):
"""Computes likelihood Hamiltonians of expected count field constrained by
...
...
nifty6/operators/operator.py
View file @
c73b4467
...
...
@@ -297,7 +297,20 @@ class Operator(metaclass=NiftyMeta):
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
.simplify_for_const
import
SlowPartialConstantOperator
return
None
,
SlowPartialConstantOperator
(
self
,
c_inp
)
from
..multi_field
import
MultiField
try
:
c_out
=
self
.
force
(
c_inp
)
except
KeyError
:
c_out
=
None
if
isinstance
(
c_out
,
MultiField
):
dct
=
{}
for
kk
in
set
(
c_inp
.
keys
())
-
set
(
self
.
domain
.
keys
()):
if
isinstance
(
self
.
target
,
MultiDomain
)
and
kk
in
self
.
target
.
keys
():
raise
NotImplementedError
dct
[
kk
]
=
c_inp
[
kk
]
c_out
=
c_out
.
unite
(
MultiField
.
from_dict
(
dct
))
return
c_out
,
self
@
SlowPartialConstantOperator
(
self
.
domain
,
c_inp
.
keys
())
def
ptw
(
self
,
op
,
*
args
,
**
kwargs
):
return
_OpChain
.
make
((
_FunctionApplier
(
self
.
target
,
op
,
*
args
,
**
kwargs
),
self
))
...
...
nifty6/operators/simplify_for_const.py
View file @
c73b4467
...
...
@@ -16,8 +16,11 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from
..multi_domain
import
MultiDomain
from
..multi_field
import
MultiField
from
.block_diagonal_operator
import
BlockDiagonalOperator
from
.energy_operators
import
EnergyOperator
from
.operator
import
Operator
from
.scaling_operator
import
ScalingOperator
from
.simple_linear_operators
import
NullOperator
...
...
@@ -82,10 +85,30 @@ class ConstantOperator(Operator):
return
f
'
{
tgt
}
<- ConstantOperator <-
{
dom
}
'
class
SlowPartialConstOperator
(
Operator
):
pass
class
SlowPartialConstantOperator
(
Operator
):
def
__init__
(
self
,
domain
,
constant_keys
):
from
..sugar
import
makeDomain
if
not
isinstance
(
domain
,
MultiDomain
):
raise
TypeError
self
.
_keys
=
set
(
constant_keys
)
&
set
(
domain
.
keys
())
if
len
(
self
.
_keys
)
==
0
:
raise
ValueError
self
.
_domain
=
self
.
_target
=
makeDomain
(
domain
)
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
if
x
.
jac
is
None
:
return
x
jac
=
{}
for
kk
,
dd
in
self
.
_domain
.
items
():
fac
=
1
if
kk
in
self
.
_keys
:
fac
=
0
jac
[
kk
]
=
ScalingOperator
(
dd
,
fac
)
return
x
.
prepend_jac
(
BlockDiagonalOperator
(
x
.
jac
.
domain
,
jac
))
def
__repr__
(
self
):
return
f
'SlowPartialConstantOperator (
{
self
.
_keys
}
)'
class
ConstantEnergyOperator
(
EnergyOperator
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment