Skip to content
Snippets Groups Projects

Improve exec_time

Merged Philipp Arras requested to merge improve_profiling into NIFTy_8
1 unresolved thread
2 files
+ 48
35
Compare changes
  • Side-by-side
  • Inline
Files
2
+ 46
34
@@ -12,9 +12,13 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
# Copyright(C) 2022 Philipp Arras
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import cProfile
import io
import pstats
import sys
from time import time
@@ -550,53 +554,61 @@ def plot_priorsamples(op, n_samples=5, common_colorbar=True, **kwargs):
p.output(**kwargs)
def exec_time(obj, want_metric=True):
"""Times the execution time of an operator or an energy."""
def exec_time(obj, want_metric=True, verbose=False):
"""Times the execution time of an operator or an energy.
Parameters
----------
obj : Operator or Energy
Operator or Energy that shall be profiled.
want_metric : bool, optional
Determine if Operator shall be called with `want_metric=True`. Only
applicable for EnergyOperators. Default: True.
verbose : bool, optional
If True, more profiling information is printed. Default: False.
"""
from .linearization import Linearization
from .minimization.energy import Energy
if isinstance(obj, Energy):
t0 = time()
obj.at(0.99*obj.position)
logger.info('Energy.at(): {}'.format(time() - t0))
def _profile_func(func, inp, what):
t0 = time()
obj.value
logger.info('Energy.value: {}'.format(time() - t0))
t0 = time()
obj.gradient
logger.info('Energy.gradient: {}'.format(time() - t0))
t0 = time()
obj.metric
logger.info('Energy.metric: {}'.format(time() - t0))
with cProfile.Profile() as pr:
res = func(inp)
logger.info(f'{what}: {(time() - t0)*1000:>8.3f} ms')
if verbose:
s = io.StringIO()
pstats.Stats(pr, stream=s).sort_stats(pstats.SortKey.TIME).print_stats(5)
logger.info(s.getvalue())
return res
def _profile_get_attr(obj, attr, what):
return _profile_func(lambda x: getattr(obj, x), attr, what)
t0 = time()
obj.apply_metric(obj.position)
logger.info('Energy.apply_metric: {}'.format(time() - t0))
if isinstance(obj, Energy):
newpos = 0.99*obj.position
_profile_func(lambda x: x.at(newpos), obj, "Energy.at()\t\t\t\t")
_profile_get_attr(obj, "value", "Energy.value\t\t\t\t")
_profile_get_attr(obj, "gradient", "Energy.gradient\t\t\t\t")
_profile_get_attr(obj, "metric", "Energy.metric\t\t\t\t")
if obj.metric is not None:
_profile_func(lambda x: x.apply_metric(x.position), obj, "Energy.apply_metric\t\t\t")
_profile_func(lambda x: x.metric(x.position), obj, "Energy.metric(position)\t\t\t")
t0 = time()
obj.metric(obj.position)
logger.info('Energy.metric(position): {}'.format(time() - t0))
elif isinstance(obj, Operator):
want_metric = bool(want_metric)
pos = from_random(obj.domain, 'normal')
t0 = time()
obj(pos)
logger.info('Operator call with field: {}'.format(time() - t0))
pos = from_random(obj.domain, 'normal')
lin = Linearization.make_var(pos, want_metric=want_metric)
t0 = time()
res = obj(lin)
logger.info('Operator call with linearization: {}'.format(time() - t0))
_profile_func(lambda x: x(pos), obj, "Operator call with field\t\t")
res = _profile_func(lambda x: x(lin), obj, "Operator call with linearization\t")
_profile_func(lambda x: res.jac(x), pos, "Apply linearization\t\t\t")
_profile_func(lambda x: res.jac.adjoint(x), res.val, "Apply linearization (adjoint)\t\t")
if obj.target is DomainTuple.scalar_domain():
t0 = time()
res.gradient
logger.info('Gradient evaluation: {}'.format(time() - t0))
_profile_get_attr(res, "gradient", "Gradient evaluation\t\t\t")
if want_metric:
t0 = time()
res.metric(pos)
logger.info('Metric apply: {}'.format(time() - t0))
_profile_func(lambda x: res.metric(x), pos, "Metric apply\t\t\t\t")
else:
raise TypeError
Loading