Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
0c406f77
Commit
0c406f77
authored
May 19, 2021
by
Martin Reinecke
Browse files
Merge remote-tracking branch 'origin/NIFTy_7' into fix_deprecation_warnings
parents
66266a70
f8d19b19
Pipeline
#101746
passed with stages
in 13 minutes and 11 seconds
Changes
12
Pipelines
1
Show whitespace changes
Inline
Side-by-side
.gitlab-ci.yml
View file @
0c406f77
...
@@ -5,11 +5,20 @@ variables:
...
@@ -5,11 +5,20 @@ variables:
OMP_NUM_THREADS
:
1
OMP_NUM_THREADS
:
1
stages
:
stages
:
-
static_checks
-
build_docker
-
build_docker
-
test
-
test
-
release
-
release
-
demo_runs
-
demo_runs
check_no_asserts
:
image
:
debian:testing-slim
stage
:
static_checks
before_script
:
-
ls
script
:
-
if [ `grep -r "^[[:space:]]*assert[ (]" src demos | wc -l` -ne 0 ]; then echo "Have found assert statements. Don't use them! Use \`utilities.myassert\` instead." && exit 1; fi
build_docker_from_scratch
:
build_docker_from_scratch
:
only
:
only
:
-
schedules
-
schedules
...
...
demos/getting_started_density.py
View file @
0c406f77
...
@@ -30,62 +30,6 @@ import numpy as np
...
@@ -30,62 +30,6 @@ import numpy as np
import
nifty7
as
ift
import
nifty7
as
ift
def
density_estimator
(
domain
,
pad
=
1.0
,
cf_fluctuations
=
None
,
cf_azm_uniform
=
None
):
cf_azm_uniform_sane_default
=
(
1e-4
,
1.0
)
cf_fluctuations_sane_default
=
{
"scale"
:
(
0.5
,
0.3
),
"cutoff"
:
(
4.0
,
3.0
),
"loglogslope"
:
(
-
6.0
,
3.0
)
}
domain
=
ift
.
DomainTuple
.
make
(
domain
)
dom_scaling
=
1.
+
np
.
broadcast_to
(
pad
,
(
len
(
domain
.
axes
),
))
if
cf_fluctuations
is
None
:
cf_fluctuations
=
cf_fluctuations_sane_default
if
cf_azm_uniform
is
None
:
cf_azm_uni
=
cf_azm_uniform_sane_default
domain_padded
=
[]
for
d_scl
,
d
in
zip
(
dom_scaling
,
domain
):
if
not
isinstance
(
d
,
ift
.
RGSpace
)
or
d
.
harmonic
:
te
=
[
f
"unexpected domain encountered in `domain`:
{
domain
}
"
]
te
+=
"expected a non-harmonic `ift.RGSpace`"
raise
TypeError
(
"
\n
"
.
join
(
te
))
shape_padded
=
tuple
((
d_scl
*
np
.
array
(
d
.
shape
)).
astype
(
int
))
domain_padded
.
append
(
ift
.
RGSpace
(
shape_padded
,
distances
=
d
.
distances
))
domain_padded
=
ift
.
DomainTuple
.
make
(
domain_padded
)
# Set up the signal model
azm_offset_mean
=
0.0
# The zero-mode should be inferred only from the data
cfmaker
=
ift
.
CorrelatedFieldMaker
(
""
)
for
i
,
d
in
enumerate
(
domain_padded
):
if
isinstance
(
cf_fluctuations
,
(
list
,
tuple
)):
cf_fl
=
cf_fluctuations
[
i
]
else
:
cf_fl
=
cf_fluctuations
cfmaker
.
add_fluctuations_matern
(
d
,
**
cf_fl
,
prefix
=
f
"ax
{
i
}
"
)
scalar_domain
=
ift
.
DomainTuple
.
scalar_domain
()
uniform
=
ift
.
UniformOperator
(
scalar_domain
,
*
cf_azm_uni
)
azm
=
uniform
.
ducktape
(
"zeromode"
)
cfmaker
.
set_amplitude_total_offset
(
azm_offset_mean
,
azm
)
correlated_field
=
cfmaker
.
finalize
(
0
).
clip
(
-
10.
,
10.
)
normalized_amplitudes
=
cfmaker
.
get_normalized_amplitudes
()
domain_shape
=
tuple
(
d
.
shape
for
d
in
domain
)
slc
=
ift
.
SliceOperator
(
correlated_field
.
target
,
domain_shape
)
signal
=
ift
.
exp
(
slc
@
correlated_field
)
model_operators
=
{
"correlated_field"
:
correlated_field
,
"select_subset"
:
slc
,
"amplitude_total_offset"
:
azm
,
"normalized_amplitudes"
:
normalized_amplitudes
}
return
signal
,
model_operators
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
filename
=
"getting_started_density_{}.png"
filename
=
"getting_started_density_{}.png"
ift
.
random
.
push_sseq_from_seed
(
42
)
ift
.
random
.
push_sseq_from_seed
(
42
)
...
@@ -97,7 +41,7 @@ if __name__ == "__main__":
...
@@ -97,7 +41,7 @@ if __name__ == "__main__":
sp2
=
ift
.
RGSpace
(
npix2
)
sp2
=
ift
.
RGSpace
(
npix2
)
position_space
=
ift
.
DomainTuple
.
make
((
sp1
,
sp2
))
position_space
=
ift
.
DomainTuple
.
make
((
sp1
,
sp2
))
signal
,
ops
=
density_estimator
(
position_space
)
signal
,
ops
=
ift
.
density_estimator
(
position_space
)
correlated_field
=
ops
[
"correlated_field"
]
correlated_field
=
ops
[
"correlated_field"
]
data_space
=
signal
.
target
data_space
=
signal
.
target
...
...
src/extra.py
View file @
0c406f77
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
#
# Copyright(C) 2013-202
0
Max-Planck-Society
# Copyright(C) 2013-202
1
Max-Planck-Society
#
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
...
@@ -341,7 +341,7 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
...
@@ -341,7 +341,7 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
val0
=
op
(
loc
)
val0
=
op
(
loc
)
_
,
op0
=
op
.
simplify_for_constant_input
(
cstloc
)
_
,
op0
=
op
.
simplify_for_constant_input
(
cstloc
)
assert
op0
.
domain
is
varloc
.
domain
my
assert
(
op0
.
domain
is
varloc
.
domain
)
val1
=
op0
(
varloc
)
val1
=
op0
(
varloc
)
assert_equal
(
val0
,
val1
)
assert_equal
(
val0
,
val1
)
...
@@ -350,7 +350,7 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
...
@@ -350,7 +350,7 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
oplin0
=
op0
(
lin0
)
oplin0
=
op0
(
lin0
)
oplin
=
op
(
lin
)
oplin
=
op
(
lin
)
assert
oplin
.
jac
.
target
is
oplin0
.
jac
.
target
my
assert
(
oplin
.
jac
.
target
is
oplin0
.
jac
.
target
)
rndinp
=
from_random
(
oplin
.
jac
.
target
)
rndinp
=
from_random
(
oplin
.
jac
.
target
)
assert_allclose
(
oplin
.
jac
.
adjoint
(
rndinp
).
extract
(
varloc
.
domain
),
assert_allclose
(
oplin
.
jac
.
adjoint
(
rndinp
).
extract
(
varloc
.
domain
),
oplin0
.
jac
.
adjoint
(
rndinp
),
1e-13
,
1e-13
)
oplin0
.
jac
.
adjoint
(
rndinp
),
1e-13
,
1e-13
)
...
...
src/library/correlated_fields.py
View file @
0c406f77
...
@@ -43,6 +43,7 @@ from ..operators.operator import Operator
...
@@ -43,6 +43,7 @@ from ..operators.operator import Operator
from
..operators.simple_linear_operators
import
VdotOperator
,
ducktape
from
..operators.simple_linear_operators
import
VdotOperator
,
ducktape
from
..probing
import
StatCalculator
from
..probing
import
StatCalculator
from
..sugar
import
full
,
makeDomain
,
makeField
,
makeOp
from
..sugar
import
full
,
makeDomain
,
makeField
,
makeOp
from
..utilities
import
myassert
def
_log_k_lengths
(
pspace
):
def
_log_k_lengths
(
pspace
):
...
@@ -54,17 +55,17 @@ def _relative_log_k_lengths(power_space):
...
@@ -54,17 +55,17 @@ def _relative_log_k_lengths(power_space):
"""Log-distance to first bin
"""Log-distance to first bin
logkl.shape==power_space.shape, logkl[0]=logkl[1]=0"""
logkl.shape==power_space.shape, logkl[0]=logkl[1]=0"""
power_space
=
DomainTuple
.
make
(
power_space
)
power_space
=
DomainTuple
.
make
(
power_space
)
assert
isinstance
(
power_space
[
0
],
PowerSpace
)
my
assert
(
isinstance
(
power_space
[
0
],
PowerSpace
)
)
assert
len
(
power_space
)
==
1
my
assert
(
len
(
power_space
)
==
1
)
logkl
=
_log_k_lengths
(
power_space
[
0
])
logkl
=
_log_k_lengths
(
power_space
[
0
])
assert
logkl
.
shape
[
0
]
==
power_space
[
0
].
shape
[
0
]
-
1
my
assert
(
logkl
.
shape
[
0
]
==
power_space
[
0
].
shape
[
0
]
-
1
)
logkl
-=
logkl
[
0
]
logkl
-=
logkl
[
0
]
return
np
.
insert
(
logkl
,
0
,
0
)
return
np
.
insert
(
logkl
,
0
,
0
)
def
_log_vol
(
power_space
):
def
_log_vol
(
power_space
):
power_space
=
makeDomain
(
power_space
)
power_space
=
makeDomain
(
power_space
)
assert
isinstance
(
power_space
[
0
],
PowerSpace
)
my
assert
(
isinstance
(
power_space
[
0
],
PowerSpace
)
)
logk_lengths
=
_log_k_lengths
(
power_space
[
0
])
logk_lengths
=
_log_k_lengths
(
power_space
[
0
])
return
logk_lengths
[
1
:]
-
logk_lengths
[:
-
1
]
return
logk_lengths
[
1
:]
-
logk_lengths
[:
-
1
]
...
@@ -89,7 +90,7 @@ def _total_fluctuation_realized(samples):
...
@@ -89,7 +90,7 @@ def _total_fluctuation_realized(samples):
class
_SlopeRemover
(
EndomorphicOperator
):
class
_SlopeRemover
(
EndomorphicOperator
):
def
__init__
(
self
,
domain
,
space
=
0
):
def
__init__
(
self
,
domain
,
space
=
0
):
self
.
_domain
=
makeDomain
(
domain
)
self
.
_domain
=
makeDomain
(
domain
)
assert
isinstance
(
self
.
_domain
[
space
],
PowerSpace
)
my
assert
(
isinstance
(
self
.
_domain
[
space
],
PowerSpace
)
)
logkl
=
_relative_log_k_lengths
(
self
.
_domain
[
space
])
logkl
=
_relative_log_k_lengths
(
self
.
_domain
[
space
])
sc
=
logkl
/
float
(
logkl
[
-
1
])
sc
=
logkl
/
float
(
logkl
[
-
1
])
...
@@ -114,7 +115,7 @@ class _SlopeRemover(EndomorphicOperator):
...
@@ -114,7 +115,7 @@ class _SlopeRemover(EndomorphicOperator):
class
_TwoLogIntegrations
(
LinearOperator
):
class
_TwoLogIntegrations
(
LinearOperator
):
def
__init__
(
self
,
target
,
space
=
0
):
def
__init__
(
self
,
target
,
space
=
0
):
self
.
_target
=
makeDomain
(
target
)
self
.
_target
=
makeDomain
(
target
)
assert
isinstance
(
self
.
target
[
space
],
PowerSpace
)
my
assert
(
isinstance
(
self
.
target
[
space
],
PowerSpace
)
)
dom
=
list
(
self
.
_target
)
dom
=
list
(
self
.
_target
)
dom
[
space
]
=
UnstructuredDomain
((
2
,
self
.
target
[
space
].
shape
[
0
]
-
2
))
dom
[
space
]
=
UnstructuredDomain
((
2
,
self
.
target
[
space
].
shape
[
0
]
-
2
))
self
.
_domain
=
makeDomain
(
dom
)
self
.
_domain
=
makeDomain
(
dom
)
...
@@ -173,7 +174,7 @@ class _Normalization(Operator):
...
@@ -173,7 +174,7 @@ class _Normalization(Operator):
"""
"""
def
__init__
(
self
,
domain
,
space
=
0
):
def
__init__
(
self
,
domain
,
space
=
0
):
self
.
_domain
=
self
.
_target
=
DomainTuple
.
make
(
domain
)
self
.
_domain
=
self
.
_target
=
DomainTuple
.
make
(
domain
)
assert
isinstance
(
self
.
_domain
[
space
],
PowerSpace
)
my
assert
(
isinstance
(
self
.
_domain
[
space
],
PowerSpace
)
)
hspace
=
list
(
self
.
_domain
)
hspace
=
list
(
self
.
_domain
)
hspace
[
space
]
=
hspace
[
space
].
harmonic_partner
hspace
[
space
]
=
hspace
[
space
].
harmonic_partner
hspace
=
makeDomain
(
hspace
)
hspace
=
makeDomain
(
hspace
)
...
@@ -280,10 +281,10 @@ class _Amplitude(Operator):
...
@@ -280,10 +281,10 @@ class _Amplitude(Operator):
asperity > 0 or None
asperity > 0 or None
loglogavgslope probably negative
loglogavgslope probably negative
"""
"""
assert
isinstance
(
fluctuations
,
Operator
)
my
assert
(
isinstance
(
fluctuations
,
Operator
)
)
assert
isinstance
(
flexibility
,
Operator
)
or
flexibility
is
None
my
assert
(
isinstance
(
flexibility
,
Operator
)
or
flexibility
is
None
)
assert
isinstance
(
asperity
,
Operator
)
or
asperity
is
None
my
assert
(
isinstance
(
asperity
,
Operator
)
or
asperity
is
None
)
assert
isinstance
(
loglogavgslope
,
Operator
)
my
assert
(
isinstance
(
loglogavgslope
,
Operator
)
)
if
len
(
dofdex
)
>
0
:
if
len
(
dofdex
)
>
0
:
N_copies
=
max
(
dofdex
)
+
1
N_copies
=
max
(
dofdex
)
+
1
...
@@ -296,7 +297,7 @@ class _Amplitude(Operator):
...
@@ -296,7 +297,7 @@ class _Amplitude(Operator):
N_copies
=
0
N_copies
=
0
space
=
0
space
=
0
distributed_tgt
=
target
=
makeDomain
(
target
)
distributed_tgt
=
target
=
makeDomain
(
target
)
assert
isinstance
(
target
[
space
],
PowerSpace
)
my
assert
(
isinstance
(
target
[
space
],
PowerSpace
)
)
twolog
=
_TwoLogIntegrations
(
target
,
space
)
twolog
=
_TwoLogIntegrations
(
target
,
space
)
dom
=
twolog
.
domain
dom
=
twolog
.
domain
...
@@ -514,7 +515,7 @@ class CorrelatedFieldMaker:
...
@@ -514,7 +515,7 @@ class CorrelatedFieldMaker:
else
:
else
:
N
=
0
N
=
0
target_subdomain
=
makeDomain
(
target_subdomain
)
target_subdomain
=
makeDomain
(
target_subdomain
)
# assert
isinstance(target_subdomain[space], (RGSpace, HPSpace, GLSpace))
#
my
assert
(
isinstance(target_subdomain[space], (RGSpace, HPSpace, GLSpace))
)
for
arg
in
[
fluctuations
,
loglogavgslope
]:
for
arg
in
[
fluctuations
,
loglogavgslope
]:
if
len
(
arg
)
!=
2
:
if
len
(
arg
)
!=
2
:
...
@@ -803,7 +804,7 @@ class CorrelatedFieldMaker:
...
@@ -803,7 +804,7 @@ class CorrelatedFieldMaker:
a_target
=
amp
.
target
a_target
=
amp
.
target
a_space
=
0
if
not
hasattr
(
amp
,
"_space"
)
else
amp
.
_space
a_space
=
0
if
not
hasattr
(
amp
,
"_space"
)
else
amp
.
_space
a_pp
=
amp
.
target
[
a_space
]
a_pp
=
amp
.
target
[
a_space
]
assert
isinstance
(
a_pp
,
PowerSpace
)
my
assert
(
isinstance
(
a_pp
,
PowerSpace
)
)
azm_expander
=
ContractionOperator
(
azm_expander
=
ContractionOperator
(
a_target
,
spaces
=
a_space
a_target
,
spaces
=
a_space
...
...
src/minimization/line_search.py
View file @
0c406f77
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2019
, 2021
Max-Planck-Society
#
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
...
@@ -294,8 +294,8 @@ class LineSearch(metaclass=NiftyMeta):
...
@@ -294,8 +294,8 @@ class LineSearch(metaclass=NiftyMeta):
if
phiprime_lo
*
(
alpha_hi
-
alpha_lo
)
>=
0.
:
if
phiprime_lo
*
(
alpha_hi
-
alpha_lo
)
>=
0.
:
raise
ValueError
(
"inconsistent data"
)
raise
ValueError
(
"inconsistent data"
)
for
i
in
range
(
self
.
max_zoom_iterations
):
for
i
in
range
(
self
.
max_zoom_iterations
):
# assert
phi_lo <= phi_0 + self.c1*alpha_lo*phiprime_0
#
my
assert
(
phi_lo <= phi_0 + self.c1*alpha_lo*phiprime_0
)
# assert
phiprime_lo*(alpha_hi-alpha_lo)<0.
#
my
assert
(
phiprime_lo*(alpha_hi-alpha_lo)<0.
)
delta_alpha
=
alpha_hi
-
alpha_lo
delta_alpha
=
alpha_hi
-
alpha_lo
a
,
b
=
min
(
alpha_lo
,
alpha_hi
),
max
(
alpha_lo
,
alpha_hi
)
a
,
b
=
min
(
alpha_lo
,
alpha_hi
),
max
(
alpha_lo
,
alpha_hi
)
...
...
src/minimization/metric_gaussian_kl.py
View file @
0c406f77
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
#
# Copyright(C) 2013-202
0
Max-Planck-Society
# Copyright(C) 2013-202
1
Max-Planck-Society
#
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
...
@@ -25,6 +25,7 @@ from ..operators.endomorphic_operator import EndomorphicOperator
...
@@ -25,6 +25,7 @@ from ..operators.endomorphic_operator import EndomorphicOperator
from
..operators.energy_operators
import
StandardHamiltonian
from
..operators.energy_operators
import
StandardHamiltonian
from
..probing
import
approximation2endo
from
..probing
import
approximation2endo
from
..sugar
import
makeOp
from
..sugar
import
makeOp
from
..utilities
import
myassert
from
.energy
import
Energy
from
.energy
import
Energy
...
@@ -47,9 +48,9 @@ def _get_lo_hi(comm, n_samples):
...
@@ -47,9 +48,9 @@ def _get_lo_hi(comm, n_samples):
def
_modify_sample_domain
(
sample
,
domain
):
def
_modify_sample_domain
(
sample
,
domain
):
"""Takes only keys from sample which are also in domain and inserts zeros
"""Takes only keys from sample which are also in domain and inserts zeros
for keys which are not in sample.domain."""
for keys which are not in sample.domain."""
from
..multi_domain
import
MultiDomain
from
..field
import
Field
from
..domain_tuple
import
DomainTuple
from
..domain_tuple
import
DomainTuple
from
..field
import
Field
from
..multi_domain
import
MultiDomain
from
..sugar
import
makeDomain
from
..sugar
import
makeDomain
domain
=
makeDomain
(
domain
)
domain
=
makeDomain
(
domain
)
if
isinstance
(
domain
,
DomainTuple
)
and
isinstance
(
sample
,
Field
):
if
isinstance
(
domain
,
DomainTuple
)
and
isinstance
(
sample
,
Field
):
...
@@ -96,7 +97,7 @@ class MetricGaussianKL(Energy):
...
@@ -96,7 +97,7 @@ class MetricGaussianKL(Energy):
if
not
_callingfrommake
:
if
not
_callingfrommake
:
raise
NotImplementedError
raise
NotImplementedError
super
(
MetricGaussianKL
,
self
).
__init__
(
mean
)
super
(
MetricGaussianKL
,
self
).
__init__
(
mean
)
assert
mean
.
domain
is
hamiltonian
.
domain
my
assert
(
mean
.
domain
is
hamiltonian
.
domain
)
self
.
_hamiltonian
=
hamiltonian
self
.
_hamiltonian
=
hamiltonian
self
.
_n_samples
=
int
(
n_samples
)
self
.
_n_samples
=
int
(
n_samples
)
self
.
_mirror_samples
=
bool
(
mirror_samples
)
self
.
_mirror_samples
=
bool
(
mirror_samples
)
...
...
src/operator_tree_optimiser.py
View file @
0c406f77
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
#
# Copyright(C) 2013-202
0
Max-Planck-Society
# Copyright(C) 2013-202
1
Max-Planck-Society
#
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
...
@@ -23,6 +23,7 @@ from .multi_field import MultiField
...
@@ -23,6 +23,7 @@ from .multi_field import MultiField
from
.operators.operator
import
_OpChain
,
_OpProd
,
_OpSum
from
.operators.operator
import
_OpChain
,
_OpProd
,
_OpSum
from
.operators.simple_linear_operators
import
FieldAdapter
from
.operators.simple_linear_operators
import
FieldAdapter
from
.sugar
import
domain_union
,
from_random
from
.sugar
import
domain_union
,
from_random
from
.utilities
import
myassert
def
_optimise_operator
(
op
):
def
_optimise_operator
(
op
):
...
@@ -312,7 +313,7 @@ def optimise_operator(op):
...
@@ -312,7 +313,7 @@ def optimise_operator(op):
test_field
=
from_random
(
op
.
domain
)
test_field
=
from_random
(
op
.
domain
)
if
isinstance
(
op
(
test_field
),
MultiField
):
if
isinstance
(
op
(
test_field
),
MultiField
):
for
key
in
op
(
test_field
).
keys
():
for
key
in
op
(
test_field
).
keys
():
assert
allclose
(
op
(
test_field
).
val
[
key
],
op_optimised
(
test_field
).
val
[
key
],
1e-10
)
my
assert
(
allclose
(
op
(
test_field
).
val
[
key
],
op_optimised
(
test_field
).
val
[
key
],
1e-10
)
)
else
:
else
:
assert
allclose
(
op
(
test_field
).
val
,
op_optimised
(
test_field
).
val
,
1e-10
)
my
assert
(
allclose
(
op
(
test_field
).
val
,
op_optimised
(
test_field
).
val
,
1e-10
)
)
return
op_optimised
return
op_optimised
src/operators/einsum.py
View file @
0c406f77
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
#
# Copyright(C) 2013-202
0
Max-Planck-Society
# Copyright(C) 2013-202
1
Max-Planck-Society
# Authors: Gordian Edenhofer, Philipp Frank
# Authors: Gordian Edenhofer, Philipp Frank
#
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
...
@@ -25,6 +25,7 @@ from ..field import Field
...
@@ -25,6 +25,7 @@ from ..field import Field
from
..linearization
import
Linearization
from
..linearization
import
Linearization
from
..multi_domain
import
MultiDomain
from
..multi_domain
import
MultiDomain
from
..multi_field
import
MultiField
from
..multi_field
import
MultiField
from
..utilities
import
myassert
from
.linear_operator
import
LinearOperator
from
.linear_operator
import
LinearOperator
from
.operator
import
Operator
from
.operator
import
Operator
...
@@ -248,7 +249,7 @@ class LinearEinsum(LinearOperator):
...
@@ -248,7 +249,7 @@ class LinearEinsum(LinearOperator):
if
k_hit
in
_key_order
:
if
k_hit
in
_key_order
:
tgt
+=
[
self
.
_mf
.
domain
[
k_hit
][
dom_k_idx
]]
tgt
+=
[
self
.
_mf
.
domain
[
k_hit
][
dom_k_idx
]]
else
:
else
:
assert
k_hit
==
id
(
self
)
my
assert
(
k_hit
==
id
(
self
)
)
tgt
+=
[
self
.
_domain
[
dom_k_idx
]]
tgt
+=
[
self
.
_domain
[
dom_k_idx
]]
numpy_subscripts
+=
""
.
join
(
subscriptmap
[
o
])
numpy_subscripts
+=
""
.
join
(
subscriptmap
[
o
])
_target
=
DomainTuple
.
make
(
tgt
)
_target
=
DomainTuple
.
make
(
tgt
)
...
...
src/operators/energy_operators.py
View file @
0c406f77
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
#
# Copyright(C) 2013-202
0
Max-Planck-Society
# Copyright(C) 2013-202
1
Max-Planck-Society
#
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
...
@@ -23,6 +23,7 @@ from ..field import Field
...
@@ -23,6 +23,7 @@ from ..field import Field
from
..multi_domain
import
MultiDomain
from
..multi_domain
import
MultiDomain
from
..multi_field
import
MultiField
from
..multi_field
import
MultiField
from
..sugar
import
makeDomain
,
makeOp
from
..sugar
import
makeDomain
,
makeOp
from
..utilities
import
myassert
from
.linear_operator
import
LinearOperator
from
.linear_operator
import
LinearOperator
from
.operator
import
Operator
from
.operator
import
Operator
from
.sampling_enabler
import
SamplingDtypeSetter
,
SamplingEnabler
from
.sampling_enabler
import
SamplingDtypeSetter
,
SamplingEnabler
...
@@ -178,9 +179,9 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
...
@@ -178,9 +179,9 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
.simplify_for_const
import
ConstantEnergyOperator
from
.simplify_for_const
import
ConstantEnergyOperator
assert
len
(
c_inp
.
keys
())
==
1
my
assert
(
len
(
c_inp
.
keys
())
==
1
)
key
=
c_inp
.
keys
()[
0
]
key
=
c_inp
.
keys
()[
0
]
assert
key
in
self
.
_domain
.
keys
()
my
assert
(
key
in
self
.
_domain
.
keys
()
)
cst
=
c_inp
[
key
]
cst
=
c_inp
[
key
]
if
key
==
self
.
_kr
:
if
key
==
self
.
_kr
:
res
=
_SpecialGammaEnergy
(
cst
).
ducktape
(
self
.
_ki
)
res
=
_SpecialGammaEnergy
(
cst
).
ducktape
(
self
.
_ki
)
...
@@ -193,7 +194,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
...
@@ -193,7 +194,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
trlog
/=
2
trlog
/=
2
res
=
res
+
ConstantEnergyOperator
(
-
trlog
)
res
=
res
+
ConstantEnergyOperator
(
-
trlog
)
res
=
res
+
ConstantEnergyOperator
(
0.
)
res
=
res
+
ConstantEnergyOperator
(
0.
)
assert
res
.
target
is
self
.
target
my
assert
(
res
.
target
is
self
.
target
)
return
None
,
res
return
None
,
res
...
...
src/operators/operator.py
View file @
0c406f77
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
#
# Copyright(C) 2013-202
0
Max-Planck-Society
# Copyright(C) 2013-202
1
Max-Planck-Society
#
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
...
@@ -20,7 +20,7 @@ import numpy as np
...
@@ -20,7 +20,7 @@ import numpy as np
from
..
import
pointwise
from
..
import
pointwise
from
..logger
import
logger
from
..logger
import
logger
from
..multi_domain
import
MultiDomain
from
..multi_domain
import
MultiDomain
from
..utilities
import
NiftyMeta
,
indent
from
..utilities
import
NiftyMeta
,
indent
,
myassert
class
Operator
(
metaclass
=
NiftyMeta
):
class
Operator
(
metaclass
=
NiftyMeta
):
...
@@ -270,8 +270,8 @@ class Operator(metaclass=NiftyMeta):
...
@@ -270,8 +270,8 @@ class Operator(metaclass=NiftyMeta):
return
self
@
ducktape
(
self
,
None
,
name
)
return
self
@
ducktape
(
self
,
None
,
name
)
def
ducktape_left
(
self
,
name
):
def
ducktape_left
(
self
,
name
):
from
..sugar
import
is_fieldlike
,
is_linearization
,
is_operator
from
.simple_linear_operators
import
ducktape
from
.simple_linear_operators
import
ducktape
from
..sugar
import
is_operator
,
is_fieldlike
,
is_linearization
if
is_operator
(
self
):
if
is_operator
(
self
):
return
ducktape
(
None
,
self
,
name
)
@
self
return
ducktape
(
None
,
self
,
name
)
@
self
if
is_fieldlike
(
self
)
or
is_linearization
(
self
):
if
is_fieldlike
(
self
)
or
is_linearization
(
self
):
...
@@ -281,11 +281,12 @@ class Operator(metaclass=NiftyMeta):
...
@@ -281,11 +281,12 @@ class Operator(metaclass=NiftyMeta):
return
self
.
__class__
.
__name__
return
self
.
__class__
.
__name__
def
simplify_for_constant_input
(
self
,
c_inp
):
def
simplify_for_constant_input
(
self
,
c_inp
):
from
.energy_operators
import
EnergyOperator
from
.simplify_for_const
import
ConstantEnergyOperator
,
ConstantOperator
from
..multi_field
import
MultiField
from
..domain_tuple
import
DomainTuple
from
..domain_tuple
import
DomainTuple
from
..multi_field
import
MultiField
from
..sugar
import
makeDomain
from
..sugar
import
makeDomain
from
.energy_operators
import
EnergyOperator
from
.simplify_for_const
import
(
ConstantEnergyOperator
,
ConstantOperator
)
if
c_inp
is
None
or
(
isinstance
(
c_inp
,
MultiField
)
and
len
(
c_inp
.
keys
())
==
0
):
if
c_inp
is
None
or
(
isinstance
(
c_inp
,
MultiField
)
and
len
(
c_inp
.
keys
())
==
0
):
return
None
,
self
return
None
,
self
dom
=
c_inp
.
domain
dom
=
c_inp
.
domain
...
@@ -295,7 +296,7 @@ class Operator(metaclass=NiftyMeta):
...
@@ -295,7 +296,7 @@ class Operator(metaclass=NiftyMeta):
# Convention: If c_inp is MultiField, it needs to be defined on a
# Convention: If c_inp is MultiField, it needs to be defined on a
# subdomain of self._domain
# subdomain of self._domain
if
isinstance
(
self
.
domain
,
MultiDomain
):
if
isinstance
(
self
.
domain
,
MultiDomain
):
assert
isinstance
(
dom
,
MultiDomain
)
my
assert
(
isinstance
(
dom
,
MultiDomain
)
)
if
not
set
(
c_inp
.
keys
())
<=
set
(
self
.
domain
.
keys
()):
if
not
set
(
c_inp
.
keys
())
<=
set
(
self
.
domain
.
keys
()):
raise
ValueError
raise
ValueError
...
@@ -312,13 +313,13 @@ class Operator(metaclass=NiftyMeta):
...
@@ -312,13 +313,13 @@ class Operator(metaclass=NiftyMeta):
c_out
,
op
=
self
.
_simplify_for_constant_input_nontrivial
(
c_inp
)
c_out
,
op
=
self
.
_simplify_for_constant_input_nontrivial
(
c_inp
)
vardom
=
makeDomain
({
kk
:
vv
for
kk
,
vv
in
self
.
domain
.
items
()
vardom
=
makeDomain
({
kk
:
vv
for
kk
,
vv
in
self
.
domain
.
items
()
if
kk
not
in
c_inp
.
keys
()})
if
kk
not
in
c_inp
.
keys
()})
assert
op
<