Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
5d607ffc
Commit
5d607ffc
authored
May 27, 2020
by
Philipp Arras
Browse files
Restructure
parent
32af4710
Pipeline
#75652
failed with stages
in 4 minutes and 3 seconds
Changes
3
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty6/operators/energy_operators.py
View file @
5d607ffc
...
...
@@ -485,24 +485,3 @@ class AveragedEnergy(EnergyOperator):
self
.
_check_input
(
x
)
mymap
=
map
(
lambda
v
:
self
.
_h
(
x
+
v
),
self
.
_res_samples
)
return
utilities
.
my_sum
(
mymap
)
/
len
(
self
.
_res_samples
)
class
_ConstantEnergyOperator
(
EnergyOperator
):
def
__init__
(
self
,
dom
,
output
):
from
..sugar
import
makeDomain
self
.
_domain
=
makeDomain
(
dom
)
if
self
.
target
is
not
output
.
domain
:
raise
TypeError
self
.
_output
=
output
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
if
x
.
jac
is
not
None
:
val
=
self
.
_output
jac
=
NullOperator
(
self
.
_domain
,
self
.
_target
)
met
=
NullOperator
(
self
.
_domain
,
self
.
_domain
)
if
x
.
want_metric
else
None
return
x
.
new
(
val
,
jac
,
met
)
return
self
.
_output
def
__repr__
(
self
):
return
'ConstantEnergyOperator <- {}'
.
format
(
self
.
domain
.
keys
())
nifty6/operators/operator.py
View file @
5d607ffc
...
...
@@ -17,8 +17,9 @@
import
numpy
as
np
from
..utilities
import
NiftyMeta
,
indent
from
..
import
pointwise
from
..multi_domain
import
MultiDomain
from
..utilities
import
NiftyMeta
,
indent
class
Operator
(
metaclass
=
NiftyMeta
):
...
...
@@ -269,19 +270,34 @@ class Operator(metaclass=NiftyMeta):
return
self
.
__class__
.
__name__
def
simplify_for_constant_input
(
self
,
c_inp
):
from
.energy_operators
import
EnergyOperator
,
_ConstantEnergyOperator
from
.energy_operators
import
EnergyOperator
from
.simplify_for_const
import
ConstantEnergyOperator
,
ConstantOperator
if
c_inp
is
None
:
return
None
,
self
if
c_inp
.
domain
==
self
.
domain
:
if
isinstance
(
self
.
domain
,
MultiDomain
):
assert
isinstance
(
c_inp
.
domain
,
MultiDomain
)
if
c_inp
.
domain
is
self
.
domain
:
if
isinstance
(
self
,
EnergyOperator
):
op
=
_
ConstantEnergyOperator
(
self
.
domain
,
self
(
c_inp
))
op
=
ConstantEnergyOperator
(
self
.
domain
,
self
(
c_inp
))
else
:
op
=
_ConstantOperator
(
self
.
domain
,
self
(
c_inp
))
op
=
ConstantOperator
(
self
.
domain
,
self
(
c_inp
))
op
=
ConstantOperator
(
self
.
domain
,
self
(
c_inp
))
return
op
(
c_inp
),
op
if
isinstance
(
self
.
domain
,
MultiDomain
)
and
\
set
(
c_inp
.
keys
())
>
set
(
self
.
domain
.
keys
()):
raise
NotImplementedError
(
'This branch is not tested yet'
)
op
=
ConstantOperator
(
self
.
domain
,
self
.
force
(
c_inp
))
from
..sugar
import
makeField
unaffected
=
makeField
({
kk
:
vv
for
kk
,
vv
in
c_inp
.
items
()
if
kk
not
in
self
.
domain
})
for
kk
in
unaffected
:
assert
kk
not
in
self
.
domain
assert
kk
not
in
self
.
target
return
op
.
force
(
c_inp
),
op
return
self
.
_simplify_for_constant_input_nontrivial
(
c_inp
)
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
return
None
,
self
from
.simplify_for_const
import
SlowPartialConstantOperator
return
None
,
SlowPartialConstantOperator
(
self
,
c_inp
)
def
ptw
(
self
,
op
,
*
args
,
**
kwargs
):
return
_OpChain
.
make
((
_FunctionApplier
(
self
.
target
,
op
,
*
args
,
**
kwargs
),
self
))
...
...
@@ -295,67 +311,6 @@ for f in pointwise.ptw_dict.keys():
setattr
(
Operator
,
f
,
func
(
f
))
class
_ConstCollector
(
object
):
def
__init__
(
self
):
self
.
_const
=
None
self
.
_nc
=
set
()
def
mult
(
self
,
const
,
fulldom
):
if
const
is
None
:
self
.
_nc
|=
set
(
fulldom
)
else
:
self
.
_nc
|=
set
(
fulldom
)
-
set
(
const
)
if
self
.
_const
is
None
:
from
..multi_field
import
MultiField
self
.
_const
=
MultiField
.
from_dict
(
{
key
:
const
[
key
]
for
key
in
const
if
key
not
in
self
.
_nc
})
else
:
from
..multi_field
import
MultiField
self
.
_const
=
MultiField
.
from_dict
(
{
key
:
self
.
_const
[
key
]
*
const
[
key
]
for
key
in
const
if
key
not
in
self
.
_nc
})
def
add
(
self
,
const
,
fulldom
):
if
const
is
None
:
self
.
_nc
|=
set
(
fulldom
.
keys
())
else
:
from
..multi_field
import
MultiField
self
.
_nc
|=
set
(
fulldom
.
keys
())
-
set
(
const
.
keys
())
if
self
.
_const
is
None
:
self
.
_const
=
MultiField
.
from_dict
(
{
key
:
const
[
key
]
for
key
in
const
.
keys
()
if
key
not
in
self
.
_nc
})
else
:
self
.
_const
=
self
.
_const
.
unite
(
const
)
self
.
_const
=
MultiField
.
from_dict
(
{
key
:
self
.
_const
[
key
]
for
key
in
self
.
_const
if
key
not
in
self
.
_nc
})
@
property
def
constfield
(
self
):
return
self
.
_const
class
_ConstantOperator
(
Operator
):
def
__init__
(
self
,
dom
,
output
):
from
..sugar
import
makeDomain
self
.
_domain
=
makeDomain
(
dom
)
self
.
_target
=
output
.
domain
self
.
_output
=
output
def
apply
(
self
,
x
):
from
.simple_linear_operators
import
NullOperator
self
.
_check_input
(
x
)
if
x
.
jac
is
not
None
:
return
x
.
new
(
self
.
_output
,
NullOperator
(
self
.
_domain
,
self
.
_target
))
return
self
.
_output
def
__repr__
(
self
):
dom
=
self
.
domain
.
keys
()
if
isinstance
(
self
.
domain
,
MultiDomain
)
else
'()'
tgt
=
self
.
target
.
keys
()
if
isinstance
(
self
.
target
,
MultiDomain
)
else
'()'
return
f
'
{
tgt
}
<- ConstantOperator <-
{
dom
}
'
class
_FunctionApplier
(
Operator
):
def
__init__
(
self
,
domain
,
funcname
,
*
args
,
**
kwargs
):
from
..sugar
import
makeDomain
...
...
@@ -450,16 +405,16 @@ class _OpProd(Operator):
return
lin1
.
new
(
lin1
.
_val
*
lin2
.
_val
,
jac
)
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
..multi_domain
import
MultiDomain
from
.simplify_for_const
import
ConstCollector
f1
,
o1
=
self
.
_op1
.
simplify_for_constant_input
(
c_inp
.
extract_part
(
self
.
_op1
.
domain
))
f2
,
o2
=
self
.
_op2
.
simplify_for_constant_input
(
c_inp
.
extract_part
(
self
.
_op2
.
domain
))
from
..multi_domain
import
MultiDomain
if
not
isinstance
(
self
.
_target
,
MultiDomain
):
return
None
,
_OpProd
(
o1
,
o2
)
cc
=
_ConstCollector
()
cc
=
ConstCollector
()
cc
.
mult
(
f1
,
o1
.
target
)
cc
.
mult
(
f2
,
o2
.
target
)
return
cc
.
constfield
,
_OpProd
(
o1
,
o2
)
...
...
@@ -496,16 +451,16 @@ class _OpSum(Operator):
return
res
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
..multi_domain
import
MultiDomain
from
.simplify_for_const
import
ConstCollector
f1
,
o1
=
self
.
_op1
.
simplify_for_constant_input
(
c_inp
.
extract_part
(
self
.
_op1
.
domain
))
f2
,
o2
=
self
.
_op2
.
simplify_for_constant_input
(
c_inp
.
extract_part
(
self
.
_op2
.
domain
))
from
..multi_domain
import
MultiDomain
if
not
isinstance
(
self
.
_target
,
MultiDomain
):
return
None
,
_OpSum
(
o1
,
o2
)
cc
=
_ConstCollector
()
cc
=
ConstCollector
()
cc
.
add
(
f1
,
o1
.
target
)
cc
.
add
(
f2
,
o2
.
target
)
return
cc
.
constfield
,
_OpSum
(
o1
,
o2
)
...
...
nifty6/operators/simplify_for_const.py
0 → 100644
View file @
5d607ffc
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from
..multi_domain
import
MultiDomain
from
.energy_operators
import
EnergyOperator
from
.operator
import
Operator
from
.simple_linear_operators
import
NullOperator
class
ConstCollector
(
object
):
def
__init__
(
self
):
self
.
_const
=
None
self
.
_nc
=
set
()
def
mult
(
self
,
const
,
fulldom
):
if
const
is
None
:
self
.
_nc
|=
set
(
fulldom
)
else
:
self
.
_nc
|=
set
(
fulldom
)
-
set
(
const
)
if
self
.
_const
is
None
:
from
..multi_field
import
MultiField
self
.
_const
=
MultiField
.
from_dict
(
{
key
:
const
[
key
]
for
key
in
const
if
key
not
in
self
.
_nc
})
else
:
from
..multi_field
import
MultiField
self
.
_const
=
MultiField
.
from_dict
(
{
key
:
self
.
_const
[
key
]
*
const
[
key
]
for
key
in
const
if
key
not
in
self
.
_nc
})
def
add
(
self
,
const
,
fulldom
):
if
const
is
None
:
self
.
_nc
|=
set
(
fulldom
.
keys
())
else
:
from
..multi_field
import
MultiField
self
.
_nc
|=
set
(
fulldom
.
keys
())
-
set
(
const
.
keys
())
if
self
.
_const
is
None
:
self
.
_const
=
MultiField
.
from_dict
(
{
key
:
const
[
key
]
for
key
in
const
.
keys
()
if
key
not
in
self
.
_nc
})
else
:
self
.
_const
=
self
.
_const
.
unite
(
const
)
self
.
_const
=
MultiField
.
from_dict
(
{
key
:
self
.
_const
[
key
]
for
key
in
self
.
_const
if
key
not
in
self
.
_nc
})
@
property
def
constfield
(
self
):
return
self
.
_const
class
ConstantOperator
(
Operator
):
def
__init__
(
self
,
dom
,
output
):
from
..sugar
import
makeDomain
self
.
_domain
=
makeDomain
(
dom
)
self
.
_target
=
output
.
domain
self
.
_output
=
output
def
apply
(
self
,
x
):
from
.simple_linear_operators
import
NullOperator
self
.
_check_input
(
x
)
if
x
.
jac
is
not
None
:
return
x
.
new
(
self
.
_output
,
NullOperator
(
self
.
_domain
,
self
.
_target
))
return
self
.
_output
def
__repr__
(
self
):
dom
=
self
.
domain
.
keys
()
if
isinstance
(
self
.
domain
,
MultiDomain
)
else
'()'
tgt
=
self
.
target
.
keys
()
if
isinstance
(
self
.
target
,
MultiDomain
)
else
'()'
return
f
'
{
tgt
}
<- ConstantOperator <-
{
dom
}
'
class
SlowPartialConstOperator
(
Operator
):
pass
class
ConstantEnergyOperator
(
EnergyOperator
):
def
__init__
(
self
,
dom
,
output
):
from
..sugar
import
makeDomain
self
.
_domain
=
makeDomain
(
dom
)
if
self
.
target
is
not
output
.
domain
:
raise
TypeError
self
.
_output
=
output
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
if
x
.
jac
is
not
None
:
val
=
self
.
_output
jac
=
NullOperator
(
self
.
_domain
,
self
.
_target
)
met
=
NullOperator
(
self
.
_domain
,
self
.
_domain
)
if
x
.
want_metric
else
None
return
x
.
new
(
val
,
jac
,
met
)
return
self
.
_output
def
__repr__
(
self
):
return
'ConstantEnergyOperator <- {}'
.
format
(
self
.
domain
.
keys
())
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment